Release to master

Daniel 2020-11-24 22:20:43 +01:00
commit 118cd3797f
68 changed files with 3761 additions and 1138 deletions

Gopkg.lock (generated)

@@ -25,6 +25,14 @@
revision = "fba169763ea663f7496376e5cdf709e4c7504704"
version = "v0.1"
+[[projects]]
+digest = "1:5680f8c40e48f07cb77aece3165a866aaf8276305258b3b70db8ec7ad6ddb78d"
+name = "github.com/armon/go-radix"
+packages = ["."]
+pruneopts = ""
+revision = "1a2de0c21c94309923825da3df33a4381872c795"
+version = "v1.0.0"
[[projects]]
branch = "master"
digest = "1:baf770c4efa1883bb5e444614e85b8028bbad33913aca290a43298f65d9df485"
@@ -132,6 +140,22 @@
revision = "b65e62901fc1c0d968042419e74789f6af455eb9"
version = "v1.4.2"
+[[projects]]
+digest = "1:eaed935e3637c60ad9897e54ab3419c18b91775d6e3af339dec54aeefb48b8d6"
+name = "github.com/hashicorp/errwrap"
+packages = ["."]
+pruneopts = ""
+revision = "7b00e5db719c64d14dd0caaacbd13e76254d02c0"
+version = "v1.1.0"
+[[projects]]
+digest = "1:c6e569ffa34fcd24febd3562bff0520a104d15d1a600199cb3141debf2e58c89"
+name = "github.com/hashicorp/go-multierror"
+packages = ["."]
+pruneopts = ""
+revision = "2004d9dba6b07a5b8d133209244f376680f9d472"
+version = "v1.1.0"
[[projects]]
digest = "1:2f0c811248aeb64978037b357178b1593372439146bda860cb16f2c80785ea93"
name = "github.com/hashicorp/go-version"
@@ -214,7 +238,10 @@
[[projects]]
digest = "1:83fd2513b9f6ae0997bf646db6b74e9e00131e31002116fda597175f25add42d"
name = "github.com/stretchr/testify"
-packages = ["assert"]
+packages = [
+"assert",
+"require",
+]
pruneopts = ""
revision = "f654a9112bbeac49ca2cd45bfbe11533c4666cf8"
version = "v1.6.1"
@@ -278,6 +305,14 @@
pruneopts = ""
revision = "0ba52f642ac2f9371a88bfdde41f4b4e195a37c0"
+[[projects]]
+branch = "master"
+digest = "1:10d47e7094ce8dd202cca920e4c58a68ba1d113908c30fb0cc8590b7d333a348"
+name = "golang.org/x/sync"
+packages = ["errgroup"]
+pruneopts = ""
+revision = "67f06af15bc961c363a7260195bcd53487529a21"
[[projects]]
branch = "master"
digest = "1:bf837d996e7dfe7b819cbe53c8c9733e93228577f0561e43996b9ef0ea8a68a9"
@@ -339,6 +374,7 @@
analyzer-version = 1
input-imports = [
"github.com/aead/serpent",
+"github.com/armon/go-radix",
"github.com/bluele/gcache",
"github.com/davecgh/go-spew/spew",
"github.com/dgraph-io/badger",
@@ -346,15 +382,18 @@
"github.com/google/renameio",
"github.com/gorilla/mux",
"github.com/gorilla/websocket",
+"github.com/hashicorp/go-multierror",
"github.com/hashicorp/go-version",
"github.com/seehuhn/fortuna",
"github.com/shirou/gopsutil/host",
"github.com/spf13/cobra",
"github.com/stretchr/testify/assert",
+"github.com/stretchr/testify/require",
"github.com/tevino/abool",
"github.com/tidwall/gjson",
"github.com/tidwall/sjson",
"go.etcd.io/bbolt",
+"golang.org/x/sync/errgroup",
"golang.org/x/sys/windows",
]
solver-name = "gps-cdcl"


@@ -9,7 +9,7 @@ import (
// Config Keys
const (
-CfgDefaultListenAddressKey = "api/listenAddress"
+CfgDefaultListenAddressKey = "core/listenAddress"
)
var (
@@ -41,19 +41,22 @@ func registerConfig() error {
err := config.Register(&config.Option{
Name: "API Address",
Key: CfgDefaultListenAddressKey,
-Description: "Define on which IP and port the API should listen on.",
-Order: 128,
+Description: "Defines the IP address and port for the internal API.",
OptType: config.OptTypeString,
ExpertiseLevel: config.ExpertiseLevelDeveloper,
ReleaseLevel: config.ReleaseLevelStable,
DefaultValue: getDefaultListenAddress(),
ValidationRegex: "^([0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}:[0-9]{1,5}|\\[[:0-9A-Fa-f]+\\]:[0-9]{1,5})$",
RequiresRestart: true,
+Annotations: config.Annotations{
+config.DisplayOrderAnnotation: 513,
+config.CategoryAnnotation: "Development",
+},
})
if err != nil {
return err
}
-listenAddressConfig = config.GetAsString("api/listenAddress", getDefaultListenAddress())
+listenAddressConfig = config.GetAsString(CfgDefaultListenAddressKey, getDefaultListenAddress())
return nil
}


@@ -5,6 +5,12 @@ import (
"errors"
"fmt"
"net/http"
+"sync"
+"github.com/tidwall/sjson"
+"github.com/safing/portbase/database/iterator"
+"github.com/safing/portbase/formats/varint"
"github.com/gorilla/websocket"
"github.com/tevino/abool"
@@ -43,7 +49,12 @@ func init() {
type DatabaseAPI struct {
conn *websocket.Conn
sendQueue chan []byte
-subs map[string]*database.Subscription
+queriesLock sync.Mutex
+queries map[string]*iterator.Iterator
+subsLock sync.Mutex
+subs map[string]*database.Subscription
shutdownSignal chan struct{}
shuttingDown *abool.AtomicBool
@@ -72,6 +83,7 @@ func startDatabaseAPI(w http.ResponseWriter, r *http.Request) {
new := &DatabaseAPI{
conn: wsConn,
sendQueue: make(chan []byte, 100),
+queries: make(map[string]*iterator.Iterator),
subs: make(map[string]*database.Subscription),
shutdownSignal: make(chan struct{}),
shuttingDown: abool.NewBool(false),
@@ -94,11 +106,13 @@ func (api *DatabaseAPI) handler() {
// 124|done
// 124|error|<message>
// 124|warning|<message> // error with single record, operation continues
+// 124|cancel
// 125|sub|<query>
// 125|upd|<key>|<data>
// 125|new|<key>|<data>
// 127|del|<key>
// 125|warning|<message> // error with single record, operation continues
+// 125|cancel
// 127|qsub|<query>
// 127|ok|<key>|<data>
// 127|done
@@ -107,6 +121,7 @@ func (api *DatabaseAPI) handler() {
// 127|new|<key>|<data>
// 127|del|<key>
// 127|warning|<message> // error with single record, operation continues
+// 127|cancel
// 128|create|<key>|<data>
// 128|success
@@ -141,6 +156,16 @@ func (api *DatabaseAPI) handler() {
}
parts := bytes.SplitN(msg, []byte("|"), 3)
// Handle special command "cancel"
if len(parts) == 2 && string(parts[1]) == "cancel" {
// 124|cancel
// 125|cancel
// 127|cancel
go api.handleCancel(parts[0])
continue
}
if len(parts) != 3 {
api.send(nil, dbMsgTypeError, "bad request: malformed message", nil)
continue
@@ -253,7 +278,7 @@ func (api *DatabaseAPI) handleGet(opID []byte, key string) {
r, err := api.db.Get(key)
if err == nil {
-data, err = r.Marshal(r, record.JSON)
+data, err = marshalRecord(r)
}
if err != nil {
api.send(opID, dbMsgTypeError, err.Error(), nil)
@@ -269,6 +294,7 @@ func (api *DatabaseAPI) handleQuery(opID []byte, queryText string) {
// 124|warning|<message>
// 124|error|<message>
// 124|warning|<message> // error with single record, operation continues
+// 124|cancel
var err error
@@ -288,19 +314,17 @@ func (api *DatabaseAPI) processQuery(opID []byte, q *query.Query) (ok bool) {
return false
}
-for r := range it.Next {
-r.Lock()
-data, err := r.Marshal(r, record.JSON)
-r.Unlock()
-if err != nil {
-api.send(opID, dbMsgTypeWarning, err.Error(), nil)
-}
-api.send(opID, dbMsgTypeOk, r.Key(), data)
-}
-if it.Err() != nil {
-api.send(opID, dbMsgTypeError, it.Err().Error(), nil)
-return false
-}
+// Save query iterator.
+api.queriesLock.Lock()
+api.queries[string(opID)] = it
+api.queriesLock.Unlock()
+// Remove query iterator after it ended.
+defer func() {
+api.queriesLock.Lock()
+defer api.queriesLock.Unlock()
+delete(api.queries, string(opID))
+}()
for {
select {
@@ -312,9 +336,7 @@ func (api *DatabaseAPI) processQuery(opID []byte, q *query.Query) (ok bool) {
// process query feed
if r != nil {
// process record
-r.Lock()
-data, err := r.Marshal(r, record.JSON)
-r.Unlock()
+data, err := marshalRecord(r)
if err != nil {
api.send(opID, dbMsgTypeWarning, err.Error(), nil)
}
@@ -340,6 +362,7 @@ func (api *DatabaseAPI) handleSub(opID []byte, queryText string) {
// 125|new|<key>|<data>
// 125|delete|<key>
// 125|warning|<message> // error with single record, operation continues
+// 125|cancel
var err error
q, err := query.ParseQuery(queryText)
@@ -362,10 +385,23 @@ func (api *DatabaseAPI) registerSub(opID []byte, q *query.Query) (sub *database.
api.send(opID, dbMsgTypeError, err.Error(), nil)
return nil, false
}
return sub, true
}
func (api *DatabaseAPI) processSub(opID []byte, sub *database.Subscription) {
// Save subscription.
api.subsLock.Lock()
api.subs[string(opID)] = sub
api.subsLock.Unlock()
// Remove subscription after it ended.
defer func() {
api.subsLock.Lock()
defer api.subsLock.Unlock()
delete(api.subs, string(opID))
}()
for {
select {
case <-api.shutdownSignal:
@@ -376,9 +412,7 @@ func (api *DatabaseAPI) processSub(opID []byte, sub *database.Subscription) {
// process sub feed
if r != nil {
// process record
-r.Lock()
-data, err := r.Marshal(r, record.JSON)
-r.Unlock()
+data, err := marshalRecord(r)
if err != nil {
api.send(opID, dbMsgTypeWarning, err.Error(), nil)
continue
@@ -414,6 +448,7 @@ func (api *DatabaseAPI) handleQsub(opID []byte, queryText string) {
// 127|new|<key>|<data>
// 127|delete|<key>
// 127|warning|<message> // error with single record, operation continues
+// 127|cancel
var err error
@@ -434,6 +469,48 @@ func (api *DatabaseAPI) handleQsub(opID []byte, queryText string) {
api.processSub(opID, sub)
}
func (api *DatabaseAPI) handleCancel(opID []byte) {
api.cancelQuery(opID)
api.cancelSub(opID)
}
func (api *DatabaseAPI) cancelQuery(opID []byte) {
api.queriesLock.Lock()
defer api.queriesLock.Unlock()
// Get subscription from api.
it, ok := api.queries[string(opID)]
if !ok {
// Fail silently as queries end by themselves when finished.
return
}
// End query.
it.Cancel()
// The query handler will end the communication with a done message.
}
func (api *DatabaseAPI) cancelSub(opID []byte) {
api.subsLock.Lock()
defer api.subsLock.Unlock()
// Get subscription from api.
sub, ok := api.subs[string(opID)]
if !ok {
api.send(opID, dbMsgTypeError, "could not find subscription", nil)
return
}
// End subscription.
err := sub.Cancel()
if err != nil {
api.send(opID, dbMsgTypeError, fmt.Sprintf("failed to cancel subscription: %s", err), nil)
}
// The subscription handler will end the communication with a done message.
}
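
For orientation, here is a minimal client-side sketch of the cancel flow documented above. It assumes a gorilla/websocket client; the endpoint URL, operation ID and query text are purely illustrative and not taken from this commit.

```go
package main

import (
	"log"

	"github.com/gorilla/websocket"
)

func main() {
	// The endpoint URL is an assumption for illustration only.
	c, _, err := websocket.DefaultDialer.Dial("ws://127.0.0.1:817/api/database/v1", nil)
	if err != nil {
		log.Fatal(err)
	}
	defer c.Close()

	// Start a streaming query under operation ID "124" (query syntax assumed).
	if err := c.WriteMessage(websocket.TextMessage, []byte("124|query|query config/")); err != nil {
		log.Fatal(err)
	}

	// ... read "124|ok|<key>|<data>" messages from c here ...

	// Abort the operation early; per handleCancel above, the query handler
	// then finishes the exchange with a "124|done" message.
	if err := c.WriteMessage(websocket.TextMessage, []byte("124|cancel")); err != nil {
		log.Fatal(err)
	}
}
```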
func (api *DatabaseAPI) handlePut(opID []byte, key string, data []byte, create bool) {
// 128|create|<key>|<data>
// 128|success
@@ -545,3 +622,39 @@ func (api *DatabaseAPI) shutdown() {
api.conn.Close()
}
}
// marshalRecord locks and marshals the given record, additionally adding
// metadata and returning it as JSON.
func marshalRecord(r record.Record) ([]byte, error) {
r.Lock()
defer r.Unlock()
// Pour record into JSON.
jsonData, err := r.Marshal(r, record.JSON)
if err != nil {
return nil, err
}
// Remove JSON identifier for manual editing.
jsonData = bytes.TrimPrefix(jsonData, varint.Pack8(record.JSON))
// Add metadata.
jsonData, err = sjson.SetBytes(jsonData, "_meta", r.Meta())
if err != nil {
return nil, err
}
// Add database key.
jsonData, err = sjson.SetBytes(jsonData, "_meta.Key", r.Key())
if err != nil {
return nil, err
}
// Add JSON identifier again.
formatID := varint.Pack8(record.JSON)
finalData := make([]byte, 0, len(formatID)+len(jsonData))
finalData = append(finalData, formatID...)
finalData = append(finalData, jsonData...)
return finalData, nil
}
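
To make the metadata-injection step easier to follow, here is a small standalone sketch of the same sjson calls; the JSON payload, key and timestamp are invented for the example.

```go
package main

import (
	"fmt"
	"log"

	"github.com/tidwall/sjson"
)

func main() {
	// A record that is already marshaled to JSON (example payload).
	jsonData := []byte(`{"Value":"127.0.0.1:817"}`)

	// Inject metadata under "_meta", as marshalRecord does with r.Meta().
	jsonData, err := sjson.SetBytes(jsonData, "_meta", map[string]interface{}{"Modified": 1606252843})
	if err != nil {
		log.Fatal(err)
	}

	// Add the database key under "_meta.Key".
	jsonData, err = sjson.SetBytes(jsonData, "_meta.Key", "config:core/listenAddress")
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println(string(jsonData))
	// Prints something like:
	// {"Value":"127.0.0.1:817","_meta":{"Modified":1606252843,"Key":"config:core/listenAddress"}}
}
```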


@@ -24,11 +24,8 @@ type StorageInterface struct {
// Get returns a database record.
func (s *StorageInterface) Get(key string) (record.Record, error) {
-optionsLock.Lock()
-defer optionsLock.Unlock()
-opt, ok := options[key]
-if !ok {
+opt, err := GetOption(key)
+if err != nil {
return nil, storage.ErrNotFound
}
@@ -55,11 +52,9 @@ func (s *StorageInterface) Put(r record.Record) (record.Record, error) {
return s.Get(r.DatabaseKey())
}
-optionsLock.RLock()
-option, ok := options[r.DatabaseKey()]
-optionsLock.RUnlock()
-if !ok {
-return nil, errors.New("config option does not exist")
+option, err := GetOption(r.DatabaseKey())
+if err != nil {
+return nil, err
}
var value interface{}
@@ -77,8 +72,7 @@ func (s *StorageInterface) Put(r record.Record) (record.Record, error) {
return nil, errors.New("received invalid value in \"Value\"")
}
-err := setConfigOption(r.DatabaseKey(), value, false)
-if err != nil {
+if err := setConfigOption(r.DatabaseKey(), value, false); err != nil {
return nil, err
}
return option.Export()
@@ -91,9 +85,8 @@ func (s *StorageInterface) Delete(key string) error {
// Query returns a an iterator for the supplied query.
func (s *StorageInterface) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) {
-optionsLock.Lock()
-defer optionsLock.Unlock()
+optionsLock.RLock()
+defer optionsLock.RUnlock()
it := iterator.New()
var opts []*Option
@@ -109,8 +102,7 @@ func (s *StorageInterface) Query(q *query.Query, local, internal bool) (*iterato
}
func (s *StorageInterface) processQuery(it *iterator.Iterator, opts []*Option) {
-sort.Sort(sortableOptions(opts))
+sort.Sort(sortByKey(opts))
for _, opt := range opts {
r, err := opt.Export()
@@ -148,17 +140,27 @@ func registerAsDatabase() error {
return nil
}
-func pushFullUpdate() {
-optionsLock.RLock()
-defer optionsLock.RUnlock()
-for _, option := range options {
+// handleOptionUpdate updates the expertise and release level options,
+// if required, and eventually pushes an update for the option.
+// The caller must hold the option lock.
+func handleOptionUpdate(option *Option, push bool) {
+if expertiseLevelOptionFlag.IsSet() && option == expertiseLevelOption {
+updateExpertiseLevel()
+}
+if releaseLevelOptionFlag.IsSet() && option == releaseLevelOption {
+updateReleaseLevel()
+}
+if push {
pushUpdate(option)
}
}
+// pushUpdate pushes a database update notification for option.
+// The caller must hold the option lock.
func pushUpdate(option *Option) {
-r, err := option.Export()
+r, err := option.export()
if err != nil {
log.Errorf("failed to export option to push update: %s", err)
} else {


@@ -3,17 +3,22 @@
package config
import (
-"fmt"
"sync/atomic"
"github.com/tevino/abool"
)
+// ExpertiseLevel allows to group settings by user expertise.
+// It's useful if complex or technical settings should be hidden
+// from the average user while still allowing experts and developers
+// to change deep configuration settings.
+type ExpertiseLevel uint8
// Expertise Level constants
const (
-ExpertiseLevelUser uint8 = 0
-ExpertiseLevelExpert uint8 = 1
-ExpertiseLevelDeveloper uint8 = 2
+ExpertiseLevelUser ExpertiseLevel = 0
+ExpertiseLevelExpert ExpertiseLevel = 1
+ExpertiseLevelDeveloper ExpertiseLevel = 2
ExpertiseLevelNameUser = "user"
ExpertiseLevelNameExpert = "expert"
@@ -23,33 +28,46 @@ const (
)
var (
-expertiseLevel *int32
expertiseLevelOption *Option
+expertiseLevel = new(int32)
expertiseLevelOptionFlag = abool.New()
)
func init() {
-var expertiseLevelVal int32
-expertiseLevel = &expertiseLevelVal
registerExpertiseLevelOption()
}
func registerExpertiseLevelOption() {
expertiseLevelOption = &Option{
-Name: "Expertise Level",
+Name: "UI Mode",
Key: expertiseLevelKey,
-Description: "The Expertise Level controls the perceived complexity. Higher settings will show you more complex settings and information. This might also affect various other things relying on this setting. Modified settings in higher expertise levels stay in effect when switching back. (Unlike the Release Level)",
+Description: "Control the default amount of settings and information shown. Hidden settings are still in effect. Can be changed temporarily in the top right corner.",
OptType: OptTypeString,
ExpertiseLevel: ExpertiseLevelUser,
-ReleaseLevel: ExpertiseLevelUser,
-RequiresRestart: false,
-DefaultValue: ExpertiseLevelNameUser,
-ExternalOptType: "string list",
-ValidationRegex: fmt.Sprintf("^(%s|%s|%s)$", ExpertiseLevelNameUser, ExpertiseLevelNameExpert, ExpertiseLevelNameDeveloper),
+ReleaseLevel: ReleaseLevelStable,
+DefaultValue: ExpertiseLevelNameUser,
+Annotations: Annotations{
+DisplayOrderAnnotation: -16,
+DisplayHintAnnotation: DisplayHintOneOf,
+CategoryAnnotation: "User Interface",
+},
+PossibleValues: []PossibleValue{
+{
+Name: "Simple",
+Value: ExpertiseLevelNameUser,
+Description: "Hide complex settings and information.",
+},
+{
+Name: "Advanced",
+Value: ExpertiseLevelNameExpert,
+Description: "Show technical details.",
+},
+{
+Name: "Developer",
+Value: ExpertiseLevelNameDeveloper,
+Description: "Developer mode. Please be careful!",
+},
+},
}
err := Register(expertiseLevelOption)
@@ -61,10 +79,6 @@ func registerExpertiseLevelOption() {
}
func updateExpertiseLevel() {
-// check if already registered
-if !expertiseLevelOptionFlag.IsSet() {
-return
-}
// get value
value := expertiseLevelOption.activeFallbackValue
if expertiseLevelOption.activeValue != nil {


@@ -15,26 +15,24 @@ type (
BoolOption func() bool
)
-func getValueCache(name string, option *Option, requestedType uint8) (*Option, *valueCache) {
+func getValueCache(name string, option *Option, requestedType OptionType) (*Option, *valueCache) {
// get option
if option == nil {
-var ok bool
-optionsLock.RLock()
-option, ok = options[name]
-optionsLock.RUnlock()
-if !ok {
+var err error
+option, err = GetOption(name)
+if err != nil {
log.Errorf("config: request for unregistered option: %s", name)
return nil, nil
}
}
-// check type
+// Check the option type, no locking required as
+// OptType is immutable once it is set
if requestedType != option.OptType {
log.Errorf("config: bad type: requested %s as %s, but is %s", name, getTypeName(requestedType), getTypeName(option.OptType))
return option, nil
}
-// lock option
option.Lock()
defer option.Unlock()


@@ -7,25 +7,25 @@ import (
"github.com/safing/portbase/log"
)
-func parseAndSetConfig(jsonData string) error {
+func parseAndReplaceConfig(jsonData string) error {
m, err := JSONToMap([]byte(jsonData))
if err != nil {
return err
}
-return setConfig(m)
+return replaceConfig(m)
}
-func parseAndSetDefaultConfig(jsonData string) error {
+func parseAndReplaceDefaultConfig(jsonData string) error {
m, err := JSONToMap([]byte(jsonData))
if err != nil {
return err
}
-return SetDefaultConfig(m)
+return replaceDefaultConfig(m)
}
-func quickRegister(t *testing.T, key string, optType uint8, defaultValue interface{}) {
+func quickRegister(t *testing.T, key string, optType OptionType, defaultValue interface{}) {
err := Register(&Option{
Name: key,
Key: key,
@@ -55,7 +55,7 @@ func TestGet(t *testing.T) { //nolint:gocognit
quickRegister(t, "hot", OptTypeBool, false)
quickRegister(t, "cold", OptTypeBool, true)
-err = parseAndSetConfig(`
+err = parseAndReplaceConfig(`
{
"monkey": "a",
"zebras": {
@@ -70,7 +70,7 @@ func TestGet(t *testing.T) { //nolint:gocognit
t.Fatal(err)
}
-err = parseAndSetDefaultConfig(`
+err = parseAndReplaceDefaultConfig(`
{
"monkey": "b",
"snake": "0",
@@ -106,7 +106,7 @@ func TestGet(t *testing.T) { //nolint:gocognit
t.Errorf("cold should be false, is %v", cold())
}
-err = parseAndSetConfig(`
+err = parseAndReplaceConfig(`
{
"monkey": "3"
}
@@ -284,7 +284,7 @@ func BenchmarkGetAsStringCached(b *testing.B) {
options = make(map[string]*Option)
// Setup
-err := parseAndSetConfig(`{
+err := parseAndReplaceConfig(`{
"monkey": "banana"
}`)
if err != nil {
@@ -303,7 +303,7 @@ func BenchmarkGetAsStringCached(b *testing.B) {
func BenchmarkGetAsStringRefetch(b *testing.B) {
// Setup
-err := parseAndSetConfig(`{
+err := parseAndReplaceConfig(`{
"monkey": "banana"
}`)
if err != nil {
@@ -321,7 +321,7 @@ func BenchmarkGetAsStringRefetch(b *testing.B) {
func BenchmarkGetAsIntCached(b *testing.B) {
// Setup
-err := parseAndSetConfig(`{
+err := parseAndReplaceConfig(`{
"elephant": 1
}`)
if err != nil {
@@ -340,7 +340,7 @@ func BenchmarkGetAsIntCached(b *testing.B) {
func BenchmarkGetAsIntRefetch(b *testing.B) {
// Setup
-err := parseAndSetConfig(`{
+err := parseAndReplaceConfig(`{
"elephant": 1
}`)
if err != nil {


@@ -11,16 +11,22 @@ import (
"github.com/safing/portbase/database/record"
)
+// OptionType defines the value type of an option.
+type OptionType uint8
// Various attribute options. Use ExternalOptType for extended types in the frontend.
const (
-OptTypeString uint8 = 1
-OptTypeStringArray uint8 = 2
-OptTypeInt uint8 = 3
-OptTypeBool uint8 = 4
+optTypeAny OptionType = 0
+OptTypeString OptionType = 1
+OptTypeStringArray OptionType = 2
+OptTypeInt OptionType = 3
+OptTypeBool OptionType = 4
)
-func getTypeName(t uint8) string {
+func getTypeName(t OptionType) string {
switch t {
+case optTypeAny:
+return "any"
case OptTypeString:
return "string"
case OptTypeStringArray:
@@ -34,25 +40,195 @@ func getTypeName(t uint8) string {
}
}
// PossibleValue defines a value that is possible for
// a configuration setting.
type PossibleValue struct {
// Name is a human readable name of the option.
Name string
// Description is a human readable description of
// this value.
Description string
// Value is the actual value of the option. The type
// must match the option's value type.
Value interface{}
}
// Annotations can be attached to configuration options to
// provide hints for user interfaces or other systems working
// or setting configuration options.
// Annotation keys should follow the below format to ensure
// future well-known annotation additions do not conflict
// with vendor/product/package specific annotations.
//
// Format: <vendor/package>:<scope>:<identifier>
type Annotations map[string]interface{}
// Well known annotations defined by this package.
const (
// DisplayHintAnnotation provides a hint for the user
// interface on how to render an option.
// The value of DisplayHintAnnotation is expected to
// be a string. See DisplayHintXXXX constants below
// for a list of well-known display hint annotations.
DisplayHintAnnotation = "safing/portbase:ui:display-hint"
// DisplayOrderAnnotation provides a hint for the user
// interface in which order settings should be displayed.
// The value of DisplayOrderAnnotation is expected to be
// a number (int).
DisplayOrderAnnotation = "safing/portbase:ui:order"
// UnitAnnotation defines the SI unit of an option (if any).
UnitAnnotation = "safing/portbase:ui:unit"
// CategoryAnnotation can provide an additional category
// for each setting. This category can be used by a user
// interface to group certain options together.
// User interfaces should treat a CategoryAnnotation, if
// supported, with higher priority than a DisplayOrderAnnotation.
CategoryAnnotation = "safing/portbase:ui:category"
// SubsystemAnnotation can be used to mark an option as part
// of a module subsystem.
SubsystemAnnotation = "safing/portbase:module:subsystem"
// StackableAnnotation can be set on configuration options that
// stack on top of the default (or otherwise related) options.
// The value of StackableAnnotation is expected to be a boolean but
// may be extended to hold references to other options in the
// future.
StackableAnnotation = "safing/portbase:options:stackable"
// QuickSettingsAnnotation can be used to add quick settings to
// a configuration option. A quick setting can support the user
// by switching between pre-configured values.
// The type of a quick-setting annotation is []QuickSetting or QuickSetting.
QuickSettingsAnnotation = "safing/portbase:ui:quick-setting"
// RequiresAnnotation can be used to mark another option as a
// requirement. The type of RequiresAnnotation is []ValueRequirement
// or ValueRequirement.
RequiresAnnotation = "safing/portbase:config:requires"
)
// QuickSettingsAction defines the action of a quick setting.
type QuickSettingsAction string
const (
// QuickReplace replaces the current setting with the one from
// the quick setting.
QuickReplace = QuickSettingsAction("replace")
// QuickMergeTop merges the value of the quick setting with the
// already configured one adding new values on the top. Merging
// is only supported for OptTypeStringArray.
QuickMergeTop = QuickSettingsAction("merge-top")
// QuickMergeBottom merges the value of the quick setting with the
// already configured one adding new values at the bottom. Merging
// is only supported for OptTypeStringArray.
QuickMergeBottom = QuickSettingsAction("merge-bottom")
)
// QuickSetting defines a quick setting for a configuration option and
// should be used together with the QuickSettingsAnnotation.
type QuickSetting struct {
// Name is the name of the quick setting.
Name string
// Value is the value that the quick-setting configures. It must match
// the expected value type of the annotated option.
Value interface{}
// Action defines the action of the quick setting.
Action QuickSettingsAction
}
// ValueRequirement defines a requirement on another configuration option.
type ValueRequirement struct {
// Key is the key of the configuration option that is required.
Key string
// Value that is required.
Value interface{}
}
// Values for the DisplayHintAnnotation
const (
// DisplayHintOneOf is used to mark an option
// as a "select"-style option. That is, only one of
// the supported values may be set. This option makes
// only sense together with the PossibleValues property
// of Option.
DisplayHintOneOf = "one-of"
// DisplayHintOrdered Used to mark a list option as ordered.
// That is, the order of items is important and a user interface
// is encouraged to provide the user with re-ordering support
// (like drag'n'drop).
DisplayHintOrdered = "ordered"
)
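
As a rough sketch of how the new annotations, possible values and quick settings might be used together when registering an option; the option key, values and display order below are invented and error handling is minimal.

```go
package main

import (
	"log"

	"github.com/safing/portbase/config"
)

func main() {
	err := config.Register(&config.Option{
		Name:           "Theme",
		Key:            "ui/theme",
		Description:    "Choose the user interface theme.",
		OptType:        config.OptTypeString,
		ExpertiseLevel: config.ExpertiseLevelUser,
		ReleaseLevel:   config.ReleaseLevelStable,
		DefaultValue:   "light",
		PossibleValues: []config.PossibleValue{
			{Name: "Light", Value: "light", Description: "Bright theme."},
			{Name: "Dark", Value: "dark", Description: "Dark theme."},
		},
		Annotations: config.Annotations{
			config.DisplayHintAnnotation:  config.DisplayHintOneOf,
			config.DisplayOrderAnnotation: 32,
			config.CategoryAnnotation:     "User Interface",
			config.QuickSettingsAnnotation: []config.QuickSetting{
				{Name: "Use Dark Mode", Value: "dark", Action: config.QuickReplace},
			},
		},
	})
	if err != nil {
		log.Fatal(err)
	}
}
```

If ValidationRegex is left empty, Register (changed further down in this commit) derives one from PossibleValues, in this case ^(light|dark)$.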
// Option describes a configuration option.
type Option struct {
sync.Mutex
-Name string
-Key string // in path format: category/sub/key
+// Name holds the name of the configuration options.
+// It should be human readable and is mainly used for
+// presentation purposes.
+// Name is considered immutable after the option has
+// been created.
+Name string
+// Key holds the database path for the option. It should
+// follow the path format `category/sub/key`.
+// Key is considered immutable after the option has
+// been created.
+Key string
+// Description holds a human readable description of the
+// option and what it does. The description should be short.
+// Use the Help property for a longer support text.
+// Description is considered immutable after the option has
+// been created.
Description string
-Help string
-Order int
-OptType uint8
-ExpertiseLevel uint8
-ReleaseLevel uint8
+// Help may hold a long version of the description providing
+// assistance with the configuration option.
+// Help is considered immutable after the option has
+// been created.
+Help string
+// OptType defines the type of the option.
+// OptType is considered immutable after the option has
+// been created.
+OptType OptionType
+// ExpertiseLevel can be used to set the required expertise
+// level for the option to be displayed to a user.
+// ExpertiseLevel is considered immutable after the option has
+// been created.
+ExpertiseLevel ExpertiseLevel
+// ReleaseLevel is used to mark the stability of the option.
+// ReleaseLevel is considered immutable after the option has
+// been created.
+ReleaseLevel ReleaseLevel
+// RequiresRestart should be set to true if a modification of
+// the options value requires a restart of the whole application
+// to take effect.
+// RequiresRestart is considered immutable after the option has
+// been created.
RequiresRestart bool
-DefaultValue interface{}
-ExternalOptType string
+// DefaultValue holds the default value of the option. Note that
+// this value can be overwritten during runtime (see activeDefaultValue
+// and activeFallbackValue).
+// DefaultValue is considered immutable after the option has
+// been created.
+DefaultValue interface{}
+// ValidationRegex may contain a regular expression used to validate
+// the value of option. If the option type is set to OptTypeStringArray
+// the validation regex is applied to all entries of the string slice.
+// Note that it is recommended to keep the validation regex simple so
+// it can also be used in other languages (mainly JavaScript) to provide
+// a better user-experience by pre-validating the expression.
+// ValidationRegex is considered immutable after the option has
+// been created.
ValidationRegex string
+// PossibleValues may be set to a slice of values that are allowed
+// for this configuration setting. Note that PossibleValues makes most
+// sense when ExternalOptType is set to HintOneOf
+// PossibleValues is considered immutable after the option has
+// been created.
+PossibleValues []PossibleValue `json:",omitempty"`
+// Annotations adds additional annotations to the configuration options.
+// See documentation of Annotations for more information.
+// Annotations is considered mutable and setting/reading annotation keys
+// must be performed while the option is locked.
+Annotations Annotations
activeValue *valueCache // runtime value (loaded from config file or set by user)
activeDefaultValue *valueCache // runtime default value (may be set internally)
@@ -60,11 +236,54 @@ type Option struct {
compiledRegex *regexp.Regexp
}
// AddAnnotation adds the annotation key to option if it's not already set.
func (option *Option) AddAnnotation(key string, value interface{}) {
option.Lock()
defer option.Unlock()
if option.Annotations == nil {
option.Annotations = make(Annotations)
}
if _, ok := option.Annotations[key]; ok {
return
}
option.Annotations[key] = value
}
// SetAnnotation sets the value of the annotation key, overwriting an
// existing value if required.
func (option *Option) SetAnnotation(key string, value interface{}) {
option.Lock()
defer option.Unlock()
if option.Annotations == nil {
option.Annotations = make(Annotations)
}
option.Annotations[key] = value
}
// GetAnnotation returns the value of the annotation key
func (option *Option) GetAnnotation(key string) (interface{}, bool) {
option.Lock()
defer option.Unlock()
if option.Annotations == nil {
return nil, false
}
val, ok := option.Annotations[key]
return val, ok
}
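
A short sketch of the difference between AddAnnotation and SetAnnotation at runtime; the option key reuses the API listen address registered earlier in this commit and assumes that registration has already happened.

```go
package main

import (
	"fmt"

	"github.com/safing/portbase/config"
)

func main() {
	opt, err := config.GetOption("core/listenAddress")
	if err != nil {
		fmt.Println(err)
		return
	}

	// AddAnnotation only sets the key if it is not present yet ...
	opt.AddAnnotation(config.DisplayOrderAnnotation, 513)
	// ... while SetAnnotation overwrites any existing value.
	opt.SetAnnotation(config.CategoryAnnotation, "Development")

	if category, ok := opt.GetAnnotation(config.CategoryAnnotation); ok {
		fmt.Println("category:", category)
	}
}
```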
// Export expors an option to a Record.
func (option *Option) Export() (record.Record, error) {
option.Lock()
defer option.Unlock()
+return option.export()
+}
+func (option *Option) export() (record.Record, error) {
data, err := json.Marshal(option)
if err != nil {
return nil, err
@@ -93,20 +312,8 @@ func (option *Option) Export() (record.Record, error) {
return r, nil
}
-type sortableOptions []*Option
-// Len is the number of elements in the collection.
-func (opts sortableOptions) Len() int {
-return len(opts)
-}
-// Less reports whether the element with
-// index i should sort before the element with index j.
-func (opts sortableOptions) Less(i, j int) bool {
-return opts[i].Key < opts[j].Key
-}
-// Swap swaps the elements with indexes i and j.
-func (opts sortableOptions) Swap(i, j int) {
-opts[i], opts[j] = opts[j], opts[i]
-}
+type sortByKey []*Option
+func (opts sortByKey) Len() int { return len(opts) }
+func (opts sortByKey) Less(i, j int) bool { return opts[i].Key < opts[j].Key }
+func (opts sortByKey) Swap(i, j int) { opts[i], opts[j] = opts[j], opts[i] }


@@ -31,11 +31,16 @@ func loadConfig() error {
return err
}
-// apply
-return setConfig(newValues)
+return replaceConfig(newValues)
}
+// saveConfig saves the current configuration to file.
+// It will acquire a read-lock on the global options registry
+// lock and must lock each option!
func saveConfig() error {
+optionsLock.RLock()
+defer optionsLock.RUnlock()
// check if persistence is configured
if configFilePath == "" {
return nil
@@ -43,15 +48,18 @@ func saveConfig() error {
// extract values
activeValues := make(map[string]interface{})
-optionsLock.RLock()
for key, option := range options {
+// We cannot immediately unlock the option after
+// getData() because someone could lock and change it
+// while we are marshaling the value (i.e. for string slices).
+// We NEED to keep the option locks until we have finished.
option.Lock()
+defer option.Unlock()
if option.activeValue != nil {
activeValues[key] = option.activeValue.getData(option)
}
-option.Unlock()
}
-optionsLock.RUnlock()
// convert to JSON
data, err := MapToJSON(activeValues)


@@ -27,7 +27,7 @@ func NewPerspective(config map[string]interface{}) (*Perspective, error) {
var firstErr error
var errCnt int
-optionsLock.Lock()
+optionsLock.RLock()
optionsLoop:
for key, option := range options {
// get option key from config
@@ -51,7 +51,7 @@ optionsLoop:
valueCache: valueCache,
}
}
-optionsLock.Unlock()
+optionsLock.RUnlock()
if firstErr != nil {
if errCnt > 0 {
@@ -63,22 +63,19 @@ optionsLoop:
return perspective, nil
}
-func (p *Perspective) getPerspectiveValueCache(name string, requestedType uint8) *valueCache {
+func (p *Perspective) getPerspectiveValueCache(name string, requestedType OptionType) *valueCache {
// get option
pOption, ok := p.config[name]
if !ok {
// check if option exists at all
-optionsLock.RLock()
-_, ok = options[name]
-optionsLock.RUnlock()
-if !ok {
+if _, err := GetOption(name); err != nil {
log.Errorf("config: request for unregistered option: %s", name)
}
return nil
}
// check type
-if requestedType != pOption.option.OptType {
+if requestedType != pOption.option.OptType && requestedType != optTypeAny {
log.Errorf("config: bad type: requested %s as %s, but is %s", name, getTypeName(requestedType), getTypeName(pOption.option.OptType))
return nil
}
@@ -91,6 +88,12 @@ func (p *Perspective) getPerspectiveValueCache(name string, requestedType uint8)
return pOption.valueCache
}
// Has returns whether the given option is set in the perspective.
func (p *Perspective) Has(name string) bool {
valueCache := p.getPerspectiveValueCache(name, optTypeAny)
return valueCache != nil
}
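
A minimal sketch of how Has and the typed getters work together on a Perspective; the key and value are illustrative and assume the corresponding option is registered.

```go
package main

import (
	"fmt"

	"github.com/safing/portbase/config"
)

func main() {
	// Build a perspective from a hypothetical configuration snapshot.
	p, err := config.NewPerspective(map[string]interface{}{
		"core/listenAddress": "127.0.0.1:817",
	})
	if err != nil {
		fmt.Println(err)
		return
	}

	// Has reports whether the key is set in this perspective at all.
	if p.Has("core/listenAddress") {
		if addr, ok := p.GetAsString("core/listenAddress"); ok {
			fmt.Println("listen address:", addr)
		}
	}
}
```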
// GetAsString returns a function that returns the wanted string with high performance.
func (p *Perspective) GetAsString(name string) (value string, ok bool) {
valueCache := p.getPerspectiveValueCache(name, OptTypeString)


@@ -3,6 +3,7 @@
package config
import (
"fmt"
"regexp"
+"strings"
"sync"
)
@@ -11,6 +12,37 @@ var (
options = make(map[string]*Option)
)
// ForEachOption calls fn for each defined option. If fn returns
// an error the iteration is stopped and the error is returned.
// Note that ForEachOption does not guarantee a stable order of
// iteration between multiple calls. ForEachOption does NOT lock
// opt when calling fn.
func ForEachOption(fn func(opt *Option) error) error {
optionsLock.RLock()
defer optionsLock.RUnlock()
for _, opt := range options {
if err := fn(opt); err != nil {
return err
}
}
return nil
}
// GetOption returns the option with name or an error
// if the option does not exist. The caller should lock
// the returned option itself for further processing.
func GetOption(name string) (*Option, error) {
optionsLock.RLock()
defer optionsLock.RUnlock()
opt, ok := options[name]
if !ok {
return nil, fmt.Errorf("option %q does not exist", name)
}
return opt, nil
}
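
A small sketch of the two new registry helpers; the lookup key is illustrative.

```go
package main

import (
	"fmt"

	"github.com/safing/portbase/config"
)

func main() {
	// Collect the keys of all registered options. ForEachOption does not
	// lock the options, but Key is documented as immutable, so reading it is fine.
	var keys []string
	_ = config.ForEachOption(func(opt *config.Option) error {
		keys = append(keys, opt.Key)
		return nil
	})
	fmt.Println(keys)

	// Look up a single option by key.
	opt, err := config.GetOption("core/listenAddress")
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println(opt.Name)
}
```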
// Register registers a new configuration option.
func Register(option *Option) error {
if option.Name == "" {
@@ -26,8 +58,15 @@ func Register(option *Option) error {
return fmt.Errorf("failed to register option: please set option.OptType")
}
-var err error
+if option.ValidationRegex == "" && option.PossibleValues != nil {
+values := make([]string, len(option.PossibleValues))
+for idx, val := range option.PossibleValues {
+values[idx] = fmt.Sprintf("%v", val.Value)
+}
+option.ValidationRegex = fmt.Sprintf("^(%s)$", strings.Join(values, "|"))
+}
+var err error
if option.ValidationRegex != "" {
option.compiledRegex, err = regexp.Compile(option.ValidationRegex)
if err != nil {


@@ -3,17 +3,20 @@
package config
import (
-"fmt"
"sync/atomic"
"github.com/tevino/abool"
)
+// ReleaseLevel is used to define the maturity of a
+// configuration setting.
+type ReleaseLevel uint8
// Release Level constants
const (
-ReleaseLevelStable uint8 = 0
-ReleaseLevelBeta uint8 = 1
-ReleaseLevelExperimental uint8 = 2
+ReleaseLevelStable ReleaseLevel = 0
+ReleaseLevelBeta ReleaseLevel = 1
+ReleaseLevelExperimental ReleaseLevel = 2
ReleaseLevelNameStable = "stable"
ReleaseLevelNameBeta = "beta"
@@ -23,33 +26,46 @@ const (
)
var (
-releaseLevel *int32
+releaseLevel = new(int32)
releaseLevelOption *Option
releaseLevelOptionFlag = abool.New()
)
func init() {
-var releaseLevelVal int32
-releaseLevel = &releaseLevelVal
registerReleaseLevelOption()
}
func registerReleaseLevelOption() {
releaseLevelOption = &Option{
-Name: "Release Level",
+Name: "Feature Stability",
Key: releaseLevelKey,
-Description: "The Release Level changes which features are available to you. Some beta or experimental features are also available in the stable release channel. Unavailable settings are set to the default value.",
+Description: `May break things. Decide if you want to experiment with unstable features. "Beta" has been tested roughly by the Safing team while "Experimental" is really raw. When "Beta" or "Experimental" are disabled, their settings use the default again.`,
OptType: OptTypeString,
ExpertiseLevel: ExpertiseLevelExpert,
ReleaseLevel: ReleaseLevelStable,
-RequiresRestart: false,
-DefaultValue: ReleaseLevelNameStable,
-ExternalOptType: "string list",
-ValidationRegex: fmt.Sprintf("^(%s|%s|%s)$", ReleaseLevelNameStable, ReleaseLevelNameBeta, ReleaseLevelNameExperimental),
+DefaultValue: ReleaseLevelNameStable,
+Annotations: Annotations{
+DisplayOrderAnnotation: -8,
+DisplayHintAnnotation: DisplayHintOneOf,
+CategoryAnnotation: "Updates",
+},
+PossibleValues: []PossibleValue{
+{
+Name: "Stable",
+Value: ReleaseLevelNameStable,
+Description: "Only show stable features.",
+},
+{
+Name: "Beta",
+Value: ReleaseLevelNameBeta,
+Description: "Show stable and beta features.",
+},
+{
+Name: "Experimental",
+Value: ReleaseLevelNameExperimental,
+Description: "Show all features",
+},
+},
}
err := Register(releaseLevelOption)
@@ -61,10 +77,6 @@ func registerReleaseLevelOption() {
}
func updateReleaseLevel() {
-// check if already registered
-if !releaseLevelOptionFlag.IsSet() {
-return
-}
// get value
value := releaseLevelOption.activeFallbackValue
if releaseLevelOption.activeValue != nil {
@@ -86,6 +98,6 @@ func updateReleaseLevel() {
}
}
-func getReleaseLevel() uint8 {
-return uint8(atomic.LoadInt32(releaseLevel))
+func getReleaseLevel() ReleaseLevel {
+return ReleaseLevel(atomic.LoadInt32(releaseLevel))
}


@@ -26,11 +26,9 @@ func getValidityFlag() *abool.AtomicBool {
return validityFlag
}
+// signalChanges marks the config's validityFlag as dirty and eventually
+// triggers a config change event.
func signalChanges() {
-// refetch and save release level and expertise level
-updateReleaseLevel()
-updateExpertiseLevel()
// reset validity flag
validityFlagLock.Lock()
validityFlag.SetTo(false)
@@ -40,14 +38,20 @@ func signalChanges() {
module.TriggerEvent(configChangeEvent, nil)
}
-// setConfig sets the (prioritized) user defined config.
-func setConfig(newValues map[string]interface{}) error {
+// replaceConfig sets the (prioritized) user defined config.
+func replaceConfig(newValues map[string]interface{}) error {
var firstErr error
var errCnt int
-optionsLock.Lock()
+// RLock the options because we are not adding or removing
+// options from the registration but rather only update the
+// options value which is guarded by the option's lock itself
+optionsLock.RLock()
+defer optionsLock.RUnlock()
for key, option := range options {
newValue, ok := newValues[key]
option.Lock()
option.activeValue = nil
if ok {
@@ -61,12 +65,12 @@ func setConfig(newValues map[string]interface{}) error {
}
}
}
+handleOptionUpdate(option, true)
option.Unlock()
}
-optionsLock.Unlock()
signalChanges()
-go pushFullUpdate()
if firstErr != nil {
if errCnt > 0 {
@@ -78,14 +82,20 @@ func setConfig(newValues map[string]interface{}) error {
return nil
}
-// SetDefaultConfig sets the (fallback) default config.
-func SetDefaultConfig(newValues map[string]interface{}) error {
+// replaceDefaultConfig sets the (fallback) default config.
+func replaceDefaultConfig(newValues map[string]interface{}) error {
var firstErr error
var errCnt int
-optionsLock.Lock()
+// RLock the options because we are not adding or removing
+// options from the registration but rather only update the
+// options value which is guarded by the option's lock itself
+optionsLock.RLock()
+defer optionsLock.RUnlock()
for key, option := range options {
newValue, ok := newValues[key]
option.Lock()
option.activeDefaultValue = nil
if ok {
@@ -99,12 +109,11 @@ func SetDefaultConfig(newValues map[string]interface{}) error {
}
}
}
+handleOptionUpdate(option, true)
option.Unlock()
}
-optionsLock.Unlock()
signalChanges()
-go pushFullUpdate()
if firstErr != nil {
if errCnt > 0 {
@@ -122,11 +131,9 @@ func SetConfigOption(key string, value interface{}) error {
}
func setConfigOption(key string, value interface{}, push bool) (err error) {
-optionsLock.Lock()
-option, ok := options[key]
-optionsLock.Unlock()
-if !ok {
-return fmt.Errorf("config option %s does not exist", key)
+option, err := GetOption(key)
+if err != nil {
+return err
}
option.Lock()
@@ -139,16 +146,17 @@ func setConfigOption(key string, value interface{}, push bool) (err error) {
option.activeValue = valueCache
}
}
+handleOptionUpdate(option, push)
option.Unlock()
if err != nil {
return err
}
// finalize change, activate triggers
signalChanges()
-if push {
-go pushUpdate(option)
-}
return saveConfig()
}
@@ -158,11 +166,9 @@ func SetDefaultConfigOption(key string, value interface{}) error {
}
func setDefaultConfigOption(key string, value interface{}, push bool) (err error) {
-optionsLock.Lock()
-option, ok := options[key]
-optionsLock.Unlock()
-if !ok {
-return fmt.Errorf("config option %s does not exist", key)
+option, err := GetOption(key)
+if err != nil {
+return err
}
option.Lock()
@@ -175,15 +181,16 @@ func setDefaultConfigOption(key string, value interface{}, push bool) (err error
option.activeDefaultValue = valueCache
}
}
+handleOptionUpdate(option, push)
option.Unlock()
if err != nil {
return err
}
// finalize change, activate triggers
signalChanges()
-if push {
-go pushUpdate(option)
-}
return saveConfig()
}


@@ -24,7 +24,7 @@ func TestLayersGetters(t *testing.T) {
t.Fatal(err)
}
-err = setConfig(mapData)
+err = replaceConfig(mapData)
if err != nil {
t.Fatal(err)
}


@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"math"
+"reflect"
)
type valueCache struct {
@@ -28,7 +29,50 @@ func (vc *valueCache) getData(opt *Option) interface{} {
}
}
// isAllowedPossibleValue checks if value is defined as a PossibleValue
// in opt. If there are no possible values defined, the value is considered
// allowed and nil is returned. isAllowedPossibleValue ensures the actual
// value is an allowed primitive value by using reflection to convert
// value and each PossibleValue to a comparable primitive if possible.
// In case of complex value types isAllowedPossibleValue uses
// reflect.DeepEqual as a fallback.
func isAllowedPossibleValue(opt *Option, value interface{}) error {
if opt.PossibleValues == nil {
return nil
}
for _, val := range opt.PossibleValues {
compareAgainst := val.Value
valueType := reflect.TypeOf(value)
// loading int's from the configuration JSON does not preserve the correct type
// as we get float64 instead. Make sure to convert them before.
if reflect.TypeOf(val.Value).ConvertibleTo(valueType) {
compareAgainst = reflect.ValueOf(val.Value).Convert(valueType).Interface()
}
if compareAgainst == value {
return nil
}
if reflect.DeepEqual(val.Value, value) {
return nil
}
}
return fmt.Errorf("value is not allowed")
}
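
A hedged sketch of the resulting behavior: values outside PossibleValues should now be rejected when setting an option. The option below is invented, and the snippet assumes the config module has been initialized elsewhere.

```go
package main

import (
	"fmt"

	"github.com/safing/portbase/config"
)

func main() {
	// Register an option that only allows two values (illustrative key).
	_ = config.Register(&config.Option{
		Name:         "Log Level",
		Key:          "example/logLevel",
		Description:  "Set the log level.",
		OptType:      config.OptTypeString,
		DefaultValue: "info",
		PossibleValues: []config.PossibleValue{
			{Name: "Info", Value: "info"},
			{Name: "Debug", Value: "debug"},
		},
	})

	// "debug" is listed and should be accepted ...
	fmt.Println(config.SetConfigOption("example/logLevel", "debug"))
	// ... while "trace" is not listed and should fail validation
	// (both via isAllowedPossibleValue and via the regex Register derives).
	fmt.Println(config.SetConfigOption("example/logLevel", "trace"))
}
```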
+// validateValue ensures that value matches the expected type of option.
+// It does not create a copy of the value!
func validateValue(option *Option, value interface{}) (*valueCache, error) { //nolint:gocyclo
+if option.OptType != OptTypeStringArray {
+if err := isAllowedPossibleValue(option, value); err != nil {
+return nil, fmt.Errorf("validation of option %s failed for %v: %w", option.Key, value, err)
+}
+}
+reflect.TypeOf(value).ConvertibleTo(reflect.TypeOf(""))
switch v := value.(type) {
case string:
if option.OptType != OptTypeString {
@@ -61,6 +105,10 @@ func validateValue(option *Option, value interface{}) (*valueCache, error) { //n
if !option.compiledRegex.MatchString(entry) {
return nil, fmt.Errorf("validation of option %s failed: string \"%s\" at index %d did not match validation regex", option.Key, entry, pos)
}
+if err := isAllowedPossibleValue(option, entry); err != nil {
+return nil, fmt.Errorf("validation of option %s failed: string %q at index %d is not allowed", option.Key, entry, pos)
+}
}
}
return &valueCache{stringArrayVal: v}, nil


@@ -6,8 +6,6 @@ import (
"sync"
"time"
-"github.com/tevino/abool"
"github.com/safing/portbase/database/iterator"
"github.com/safing/portbase/database/query"
"github.com/safing/portbase/database/record"
@@ -19,18 +17,11 @@ type Controller struct {
storage storage.Interface
shadowDelete bool
-hooks []*RegisteredHook
-subscriptions []*Subscription
-writeLock sync.RWMutex
-// Lock: nobody may write
-// RLock: concurrent writing
-readLock sync.RWMutex
-// Lock: nobody may read
-// RLock: concurrent reading
-migrating *abool.AtomicBool // TODO
-hibernating *abool.AtomicBool // TODO
+hooksLock sync.RWMutex
+hooks []*RegisteredHook
+subscriptionLock sync.RWMutex
+subscriptions []*Subscription
}
// newController creates a new controller for a storage.
@@ -38,8 +29,6 @@ func newController(storageInt storage.Interface, shadowDelete bool) *Controller
return &Controller{
storage: storageInt,
shadowDelete: shadowDelete,
-migrating: abool.NewBool(false),
-hibernating: abool.NewBool(false),
}
}
@@ -55,21 +44,12 @@ func (c *Controller) Injected() bool {
// Get return the record with the given key.
func (c *Controller) Get(key string) (record.Record, error) {
-c.readLock.RLock()
-defer c.readLock.RUnlock()
if shuttingDown.IsSet() {
return nil, ErrShuttingDown
}
-// process hooks
-for _, hook := range c.hooks {
-if hook.h.UsesPreGet() && hook.q.MatchesKey(key) {
-err := hook.h.PreGet(key)
-if err != nil {
-return nil, err
-}
-}
-}
+if err := c.runPreGetHooks(key); err != nil {
+return nil, err
+}
r, err := c.storage.Get(key)
@@ -84,14 +64,9 @@ func (c *Controller) Get(key string) (record.Record, error) {
r.Lock()
defer r.Unlock()
-// process hooks
-for _, hook := range c.hooks {
-if hook.h.UsesPostGet() && hook.q.Matches(r) {
-r, err = hook.h.PostGet(r)
-if err != nil {
-return nil, err
-}
-}
-}
+r, err = c.runPostGetHooks(r)
+if err != nil {
+return nil, err
+}
if !r.Meta().CheckValidity() {
@@ -101,11 +76,11 @@ func (c *Controller) Get(key string) (record.Record, error) {
return r, nil
}
// Put saves a record in the database, executes any registered
// pre-put hooks and finally sends an update to all subscribers.
// The record must be locked and secured from concurrent access
// when calling Put().
func (c *Controller) Put(r record.Record) (err error) {
	if shuttingDown.IsSet() {
		return ErrShuttingDown
	}
@@ -114,51 +89,35 @@ func (c *Controller) Put(r record.Record) (err error) {
		return ErrReadOnly
	}

	r, err = c.runPrePutHooks(r)
	if err != nil {
		return err
	}

	if !c.shadowDelete && r.Meta().IsDeleted() {
		// Immediate delete.
		err = c.storage.Delete(r.DatabaseKey())
	} else {
		// Put or shadow delete.
		r, err = c.storage.Put(r)
	}
	if err != nil {
		return err
	}

	if r == nil {
		return errors.New("storage returned nil record after successful put operation")
	}

	c.notifySubscribers(r)

	return nil
}
// PutMany stores many records in the database. It does not
// process any hooks or update subscriptions. Use with care!
func (c *Controller) PutMany() (chan<- record.Record, <-chan error) {
	if shuttingDown.IsSet() {
		errs := make(chan error, 1)
		errs <- ErrShuttingDown
@ -182,67 +141,44 @@ func (c *Controller) PutMany() (chan<- record.Record, <-chan error) {
// Query executes the given query on the database. // Query executes the given query on the database.
func (c *Controller) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) { func (c *Controller) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) {
c.readLock.RLock()
if shuttingDown.IsSet() { if shuttingDown.IsSet() {
c.readLock.RUnlock()
return nil, ErrShuttingDown return nil, ErrShuttingDown
} }
it, err := c.storage.Query(q, local, internal) it, err := c.storage.Query(q, local, internal)
if err != nil { if err != nil {
c.readLock.RUnlock()
return nil, err return nil, err
} }
go c.readUnlockerAfterQuery(it)
return it, nil return it, nil
} }
// PushUpdate pushes a record update to subscribers.
// The caller must hold the record's lock when calling
// PushUpdate.
func (c *Controller) PushUpdate(r record.Record) {
	if c != nil {
		if shuttingDown.IsSet() {
			return
		}

		c.notifySubscribers(r)
	}
}
func (c *Controller) addSubscription(sub *Subscription) {
	if shuttingDown.IsSet() {
		return
	}

	c.subscriptionLock.Lock()
	defer c.subscriptionLock.Unlock()

	c.subscriptions = append(c.subscriptions, sub)
}
// Maintain runs the Maintain method on the storage. // Maintain runs the Maintain method on the storage.
func (c *Controller) Maintain(ctx context.Context) error { func (c *Controller) Maintain(ctx context.Context) error {
c.writeLock.RLock()
defer c.writeLock.RUnlock()
if shuttingDown.IsSet() { if shuttingDown.IsSet() {
return ErrShuttingDown return ErrShuttingDown
} }
@ -253,11 +189,9 @@ func (c *Controller) Maintain(ctx context.Context) error {
return nil return nil
} }
// MaintainThorough runs the MaintainThorough method on the storage. // MaintainThorough runs the MaintainThorough method on the
// storage.
func (c *Controller) MaintainThorough(ctx context.Context) error { func (c *Controller) MaintainThorough(ctx context.Context) error {
c.writeLock.RLock()
defer c.writeLock.RUnlock()
if shuttingDown.IsSet() { if shuttingDown.IsSet() {
return ErrShuttingDown return ErrShuttingDown
} }
@ -268,11 +202,9 @@ func (c *Controller) MaintainThorough(ctx context.Context) error {
return nil return nil
} }
// MaintainRecordStates runs the record state lifecycle maintenance on the storage. // MaintainRecordStates runs the record state lifecycle
// maintenance on the storage.
func (c *Controller) MaintainRecordStates(ctx context.Context, purgeDeletedBefore time.Time) error { func (c *Controller) MaintainRecordStates(ctx context.Context, purgeDeletedBefore time.Time) error {
c.writeLock.RLock()
defer c.writeLock.RUnlock()
if shuttingDown.IsSet() { if shuttingDown.IsSet() {
return ErrShuttingDown return ErrShuttingDown
} }
@ -280,11 +212,9 @@ func (c *Controller) MaintainRecordStates(ctx context.Context, purgeDeletedBefor
return c.storage.MaintainRecordStates(ctx, purgeDeletedBefore, c.shadowDelete) return c.storage.MaintainRecordStates(ctx, purgeDeletedBefore, c.shadowDelete)
} }
// Purge deletes all records that match the given query. It returns the number of successful deletes and an error. // Purge deletes all records that match the given query.
// It returns the number of successful deletes and an error.
func (c *Controller) Purge(ctx context.Context, q *query.Query, local, internal bool) (int, error) { func (c *Controller) Purge(ctx context.Context, q *query.Query, local, internal bool) (int, error) {
c.writeLock.RLock()
defer c.writeLock.RUnlock()
if shuttingDown.IsSet() { if shuttingDown.IsSet() {
return 0, ErrShuttingDown return 0, ErrShuttingDown
} }
@ -292,16 +222,96 @@ func (c *Controller) Purge(ctx context.Context, q *query.Query, local, internal
if purger, ok := c.storage.(storage.Purger); ok { if purger, ok := c.storage.(storage.Purger); ok {
return purger.Purge(ctx, q, local, internal, c.shadowDelete) return purger.Purge(ctx, q, local, internal, c.shadowDelete)
} }
return 0, ErrNotImplemented return 0, ErrNotImplemented
} }
// Shutdown shuts down the storage. // Shutdown shuts down the storage.
func (c *Controller) Shutdown() error { func (c *Controller) Shutdown() error {
// acquire full locks
c.readLock.Lock()
defer c.readLock.Unlock()
c.writeLock.Lock()
defer c.writeLock.Unlock()
return c.storage.Shutdown() return c.storage.Shutdown()
} }
// notifySubscribers notifies all subscribers that are interested
// in r. r must be locked when calling notifySubscribers.
// Any subscriber that is not blocking on its feed channel will
// be skipped.
func (c *Controller) notifySubscribers(r record.Record) {
c.subscriptionLock.RLock()
defer c.subscriptionLock.RUnlock()
for _, sub := range c.subscriptions {
if r.Meta().CheckPermission(sub.local, sub.internal) && sub.q.Matches(r) {
select {
case sub.Feed <- r:
default:
}
}
}
}
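Since notifySubscribers skips any subscriber whose Feed channel is not ready to receive, consumers are expected to drain their feed promptly. A minimal consumer sketch, assuming a *Subscription has already been obtained elsewhere; handleUpdate is a placeholder:

go func() {
	for r := range sub.Feed {
		// Handle updates quickly, or hand them off to a buffered worker;
		// a slow consumer here means notifySubscribers will drop updates for it.
		handleUpdate(r) // placeholder handler
	}
}()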
func (c *Controller) runPreGetHooks(key string) error {
c.hooksLock.RLock()
defer c.hooksLock.RUnlock()
for _, hook := range c.hooks {
if !hook.h.UsesPreGet() {
continue
}
if !hook.q.MatchesKey(key) {
continue
}
if err := hook.h.PreGet(key); err != nil {
return err
}
}
return nil
}
func (c *Controller) runPostGetHooks(r record.Record) (record.Record, error) {
c.hooksLock.RLock()
defer c.hooksLock.RUnlock()
var err error
for _, hook := range c.hooks {
if !hook.h.UsesPostGet() {
continue
}
if !hook.q.Matches(r) {
continue
}
r, err = hook.h.PostGet(r)
if err != nil {
return nil, err
}
}
return r, nil
}
func (c *Controller) runPrePutHooks(r record.Record) (record.Record, error) {
c.hooksLock.RLock()
defer c.hooksLock.RUnlock()
var err error
for _, hook := range c.hooks {
if !hook.h.UsesPrePut() {
continue
}
if !hook.q.Matches(r) {
continue
}
r, err = hook.h.PrePut(r)
if err != nil {
return nil, err
}
}
return r, nil
}


@ -1,7 +1,6 @@
package database package database
import ( import (
"errors"
"time" "time"
) )
@ -16,11 +15,6 @@ type Database struct {
LastLoaded time.Time LastLoaded time.Time
} }
// MigrateTo migrates the database to another storage type.
func (db *Database) MigrateTo(newStorageType string) error {
return errors.New("not implemented yet") // TODO
}
// Loaded updates the LastLoaded timestamp. // Loaded updates the LastLoaded timestamp.
func (db *Database) Loaded() { func (db *Database) Loaded() {
db.LastLoaded = time.Now().Round(time.Second) db.LastLoaded = time.Now().Round(time.Second)


@ -22,6 +22,26 @@ import (
_ "github.com/safing/portbase/database/storage/hashmap" _ "github.com/safing/portbase/database/storage/hashmap"
) )
func TestMain(m *testing.M) {
testDir, err := ioutil.TempDir("", "portbase-database-testing-")
if err != nil {
panic(err)
}
err = InitializeWithPath(testDir)
if err != nil {
panic(err)
}
exitCode := m.Run()
// Clean up the test directory.
// Do not defer, as we end this function with a os.Exit call.
os.RemoveAll(testDir)
os.Exit(exitCode)
}
func makeKey(dbName, key string) string { func makeKey(dbName, key string) string {
return fmt.Sprintf("%s:%s", dbName, key) return fmt.Sprintf("%s:%s", dbName, key)
} }
@ -220,24 +240,18 @@ func testDatabase(t *testing.T, storageType string, shadowDelete bool) { //nolin
func TestDatabaseSystem(t *testing.T) {
	// panic after 10 seconds, to check for locks
	finished := make(chan struct{})
	defer close(finished)
	go func() {
		select {
		case <-finished:
		case <-time.After(10 * time.Second):
			fmt.Println("===== TAKING TOO LONG - PRINTING STACK TRACES =====")
			_ = pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
			os.Exit(1)
		}
	}()
testDir, err := ioutil.TempDir("", "portbase-database-testing-")
if err != nil {
t.Fatal(err)
}
err = InitializeWithPath(testDir)
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(testDir) // clean up
for _, shadowDelete := range []bool{false, true} { for _, shadowDelete := range []bool{false, true} {
testDatabase(t, "bbolt", shadowDelete) testDatabase(t, "bbolt", shadowDelete)
testDatabase(t, "hashmap", shadowDelete) testDatabase(t, "hashmap", shadowDelete)
@ -246,7 +260,7 @@ func TestDatabaseSystem(t *testing.T) {
// TODO: Fix badger tests // TODO: Fix badger tests
} }
	err := MaintainRecordStates(context.TODO())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }


@ -5,15 +5,36 @@ import (
"github.com/safing/portbase/database/record" "github.com/safing/portbase/database/record"
) )
// Hook can be registered for a database query and
// will be executed at certain points during the life
// cycle of a database record.
type Hook interface {
	// UsesPreGet should return true if the hook's PreGet
	// should be called prior to loading a database record
	// from the underlying storage.
	UsesPreGet() bool
	// PreGet is called before a database record is loaded from
	// the underlying storage. A PreGet hook may be used to
	// implement more advanced access control on database keys.
	PreGet(dbKey string) error
	// UsesPostGet should return true if the hook's PostGet
	// should be called after loading a database record from
	// the underlying storage.
	UsesPostGet() bool
	// PostGet is called after a record has been loaded from the
	// underlying storage and may perform additional mutation
	// or access checks based on the record's data.
	// The passed record is already locked by the database system
	// so users can safely access all data of r.
	PostGet(r record.Record) (record.Record, error)
	// UsesPrePut should return true if the hook's PrePut method
	// should be called prior to saving a record in the database.
	UsesPrePut() bool
	// PrePut is called prior to saving (creating or updating) a
	// record in the database storage. It may be used to perform
	// extended validation or mutations on the record.
	// The passed record is already locked by the database system
	// so users can safely access all data of r.
	PrePut(r record.Record) (record.Record, error)
}
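For illustration, a minimal Hook implementation that only uses the PreGet phase could look like the sketch below. The type name is made up, and query.New with a key prefix is assumed from the query package, as it is not part of this diff:

package main

import (
	"github.com/safing/portbase/database"
	"github.com/safing/portbase/database/query"
	"github.com/safing/portbase/database/record"
)

// auditHook is an illustrative Hook that only acts before records are loaded.
type auditHook struct{}

func (auditHook) UsesPreGet() bool          { return true }
func (auditHook) PreGet(dbKey string) error { return nil } // e.g. deny access to certain keys

func (auditHook) UsesPostGet() bool                              { return false }
func (auditHook) PostGet(r record.Record) (record.Record, error) { return r, nil }

func (auditHook) UsesPrePut() bool                              { return false }
func (auditHook) PrePut(r record.Record) (record.Record, error) { return r, nil }

func main() {
	// Attach the hook to all records of a hypothetical "config" database.
	_, _ = database.RegisterHook(query.New("config:"), auditHook{})
}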
@ -23,7 +44,8 @@ type RegisteredHook struct {
h Hook h Hook
} }
// RegisterHook registers a hook for records matching the given query in the database. // RegisterHook registers a hook for records matching the given
// query in the database.
func RegisterHook(q *query.Query, hook Hook) (*RegisteredHook, error) { func RegisterHook(q *query.Query, hook Hook) (*RegisteredHook, error) {
_, err := q.Check() _, err := q.Check()
if err != nil { if err != nil {
@@ -35,30 +57,29 @@ func RegisterHook(q *query.Query, hook Hook) (*RegisteredHook, error) {
		return nil, err
	}

	rh := &RegisteredHook{
		q: q,
		h: hook,
	}

	c.hooksLock.Lock()
	defer c.hooksLock.Unlock()
	c.hooks = append(c.hooks, rh)

	return rh, nil
}
// Cancel unregisters the hook from the database. Once
// Cancel returned the hook's methods will not be called
// anymore for updates that matched the registered query.
func (h *RegisteredHook) Cancel() error {
	c, err := getController(h.q.DatabaseName())
	if err != nil {
		return err
	}

	c.hooksLock.Lock()
	defer c.hooksLock.Unlock()

	for key, hook := range c.hooks {
		if hook.q == h.q {


@ -4,11 +4,11 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"sync"
"time" "time"
"github.com/tevino/abool"
"github.com/bluele/gcache" "github.com/bluele/gcache"
"github.com/tevino/abool"
"github.com/safing/portbase/database/accessor" "github.com/safing/portbase/database/accessor"
"github.com/safing/portbase/database/iterator" "github.com/safing/portbase/database/iterator"
@ -24,17 +24,56 @@ const (
type Interface struct { type Interface struct {
options *Options options *Options
cache gcache.Cache cache gcache.Cache
writeCache map[string]record.Record
writeCacheLock sync.Mutex
triggerCacheWrite chan struct{}
} }
// Options holds options that may be set for an Interface instance.
type Options struct {
	// Local specifies if the interface is used by an actor on the local device.
	// Setting both the Local and Internal flags will bring performance
	// improvements because fewer checks are needed.
	Local bool

	// Internal specifies if the interface is used by an actor within the
	// software. Setting both the Local and Internal flags will bring performance
	// improvements because fewer checks are needed.
	Internal bool

	// AlwaysMakeSecret will have the interface mark all saved records as secret.
	// This means that they will be only accessible by an internal interface.
	AlwaysMakeSecret bool

	// AlwaysMakeCrownjewel will have the interface mark all saved records as
	// crown jewels. This means that they will be only accessible by a local
	// interface.
	AlwaysMakeCrownjewel bool

	// AlwaysSetRelativateExpiry will have the interface set a relative expiry,
	// based on the current time, on all saved records.
	AlwaysSetRelativateExpiry int64

	// AlwaysSetAbsoluteExpiry will have the interface set an absolute expiry on
	// all saved records.
	AlwaysSetAbsoluteExpiry int64

	// CacheSize defines that a cache should be used for this interface and
	// defines its size.
	// Caching comes with an important caveat: If database records are changed
	// from another interface, the cache will not be invalidated for these
	// records. It will therefore serve outdated data until that record is
	// evicted from the cache.
	CacheSize int

	// DelayCachedWrites defines a database name for which cache writes should
	// be cached and batched. The database backend must support the Batcher
	// interface. This option is only valid if used with a cache.
	// Additionally, this may only be used for internal and local interfaces.
	// Please note that this means that other interfaces will not be able to
	// guarantee to serve the latest record if records are written this way.
	DelayCachedWrites string
}
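Putting these options together, a sketch of an internal, local interface with a small read cache and delayed writes for a hypothetical "core" database might look like this; DelayCachedWrites only takes effect while the DelayedCacheWriter from interface_cache.go is running:

package main

import (
	"context"

	"github.com/safing/portbase/database"
	"github.com/safing/portbase/log"
)

func main() {
	// An internal, local interface with a small read cache and write batching.
	db := database.NewInterface(&database.Options{
		Local:             true,
		Internal:          true,
		CacheSize:         256,
		DelayCachedWrites: "core", // illustrative database name
	})

	// Keep the delayed cache writer running for the lifetime of the service.
	go func() {
		if err := db.DelayedCacheWriter(context.Background()); err != nil {
			log.Warningf("database: delayed cache writer failed: %s", err)
		}
	}()
}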
// Apply applies options to the record metadata. // Apply applies options to the record metadata.
@ -53,6 +92,28 @@ func (o *Options) Apply(r record.Record) {
} }
} }
// HasAllPermissions returns whether the options specify the highest possible
// permissions for operations.
func (o *Options) HasAllPermissions() bool {
return o.Local && o.Internal
}
// hasAccessPermission checks if the interface options permit access to the
// given record, locking the record for accessing it's attributes.
func (o *Options) hasAccessPermission(r record.Record) bool {
// Check if the options specify all permissions, which makes checking the
// record unnecessary.
if o.HasAllPermissions() {
return true
}
r.Lock()
defer r.Unlock()
// Check permissions against record.
return r.Meta().CheckPermission(o.Local, o.Internal)
}
// NewInterface returns a new Interface to the database.
func NewInterface(opts *Options) *Interface {
	if opts == nil {
@@ -63,57 +124,40 @@ func NewInterface(opts *Options) *Interface {
		options: opts,
	}

	if opts.CacheSize > 0 {
		cacheBuilder := gcache.New(opts.CacheSize).ARC()
		if opts.DelayCachedWrites != "" {
			cacheBuilder.EvictedFunc(new.cacheEvictHandler)
			new.writeCache = make(map[string]record.Record, opts.CacheSize/2)
			new.triggerCacheWrite = make(chan struct{})
		}
		new.cache = cacheBuilder.Build()
	}

	return new
}
func (i *Interface) checkCache(key string) (record.Record, bool) {
if i.cache != nil {
cacheVal, err := i.cache.Get(key)
if err == nil {
r, ok := cacheVal.(record.Record)
if ok {
return r, true
}
}
}
return nil, false
}
func (i *Interface) updateCache(r record.Record) {
if i.cache != nil {
_ = i.cache.Set(r.Key(), r)
}
}
// Exists returns whether a record with the given key exists.
func (i *Interface) Exists(key string) (bool, error) {
	_, err := i.Get(key)
	if err != nil {
		switch {
		case errors.Is(err, ErrNotFound):
			return false, nil
		case errors.Is(err, ErrPermissionDenied):
			return true, nil
		default:
			return false, err
		}
	}
	return true, nil
}
// Get returns the record with the given key.
func (i *Interface) Get(key string) (record.Record, error) {
	r, _, err := i.getRecord(getDBFromKey, key, false)
	return r, err
}
func (i *Interface) getRecord(dbName string, dbKey string, mustBeWriteable bool) (r record.Record, db *Controller, err error) {
	if dbName == "" {
		dbName, dbKey = record.ParseKey(dbKey)
	}
@@ -124,27 +168,42 @@ func (i *Interface) getRecord(dbName string, dbKey string, check bool, mustBeWri
	}

	if mustBeWriteable && db.ReadOnly() {
		return nil, db, ErrReadOnly
	}

	r = i.checkCache(dbName + ":" + dbKey)
	if r != nil {
		if !i.options.hasAccessPermission(r) {
			return nil, db, ErrPermissionDenied
		}
		return r, db, nil
	}

	r, err = db.Get(dbKey)
	if err != nil {
		return nil, db, err
	}

	if !i.options.hasAccessPermission(r) {
		return nil, db, ErrPermissionDenied
	}

	r.Lock()
	ttl := r.Meta().GetRelativeExpiry()
	r.Unlock()
	i.updateCache(
		r,
		false, // writing
		false, // remove
		ttl,   // expiry
	)

	return r, db, nil
}
// InsertValue inserts a value into a record. // InsertValue inserts a value into a record.
func (i *Interface) InsertValue(key string, attribute string, value interface{}) error { func (i *Interface) InsertValue(key string, attribute string, value interface{}) error {
r, db, err := i.getRecord(getDBFromKey, key, true, true) r, db, err := i.getRecord(getDBFromKey, key, true)
if err != nil { if err != nil {
return err return err
} }
@ -176,8 +235,46 @@ func (i *Interface) InsertValue(key string, attribute string, value interface{})
func (i *Interface) Put(r record.Record) (err error) { func (i *Interface) Put(r record.Record) (err error) {
// get record or only database // get record or only database
var db *Controller var db *Controller
if !i.options.Internal || !i.options.Local { if !i.options.HasAllPermissions() {
_, db, err = i.getRecord(r.DatabaseName(), r.DatabaseKey(), true, true) _, db, err = i.getRecord(r.DatabaseName(), r.DatabaseKey(), true)
if err != nil && err != ErrNotFound {
return err
}
} else {
db, err = getController(r.DatabaseName())
if err != nil {
return err
}
}
// Check if database is read only before we add to the cache.
if db.ReadOnly() {
return ErrReadOnly
}
r.Lock()
i.options.Apply(r)
remove := r.Meta().IsDeleted()
ttl := r.Meta().GetRelativeExpiry()
r.Unlock()
// The record may not be locked when updating the cache.
written := i.updateCache(r, true, remove, ttl)
if written {
return nil
}
r.Lock()
defer r.Unlock()
return db.Put(r)
}
// PutNew saves a record to the database as a new record (ie. with new timestamps).
func (i *Interface) PutNew(r record.Record) (err error) {
// get record or only database
var db *Controller
if !i.options.HasAllPermissions() {
_, db, err = i.getRecord(r.DatabaseName(), r.DatabaseKey(), true)
if err != nil && err != ErrNotFound { if err != nil && err != ErrNotFound {
return err return err
} }
@@ -189,38 +286,22 @@ func (i *Interface) Put(r record.Record) (err error) {
	}

	r.Lock()
	if r.Meta() != nil {
		r.Meta().Reset()
	}
	i.options.Apply(r)

	remove := r.Meta().IsDeleted()
	ttl := r.Meta().GetRelativeExpiry()
	r.Unlock()

	// The record may not be locked when updating the cache.
	written := i.updateCache(r, true, remove, ttl)
	if written {
		return nil
	}

	r.Lock()
	defer r.Unlock()
	return db.Put(r)
}
@ -233,7 +314,7 @@ func (i *Interface) PutMany(dbName string) (put func(record.Record) error) {
interfaceBatch := make(chan record.Record, 100) interfaceBatch := make(chan record.Record, 100)
// permission check // permission check
if !i.options.Internal || !i.options.Local { if !i.options.HasAllPermissions() {
return func(r record.Record) error { return func(r record.Record) error {
return ErrPermissionDenied return ErrPermissionDenied
} }
@ -316,7 +397,7 @@ func (i *Interface) PutMany(dbName string) (put func(record.Record) error) {
// SetAbsoluteExpiry sets an absolute record expiry. // SetAbsoluteExpiry sets an absolute record expiry.
func (i *Interface) SetAbsoluteExpiry(key string, time int64) error { func (i *Interface) SetAbsoluteExpiry(key string, time int64) error {
r, db, err := i.getRecord(getDBFromKey, key, true, true) r, db, err := i.getRecord(getDBFromKey, key, true)
if err != nil { if err != nil {
return err return err
} }
@ -331,7 +412,7 @@ func (i *Interface) SetAbsoluteExpiry(key string, time int64) error {
// SetRelativateExpiry sets a relative (self-updating) record expiry. // SetRelativateExpiry sets a relative (self-updating) record expiry.
func (i *Interface) SetRelativateExpiry(key string, duration int64) error { func (i *Interface) SetRelativateExpiry(key string, duration int64) error {
r, db, err := i.getRecord(getDBFromKey, key, true, true) r, db, err := i.getRecord(getDBFromKey, key, true)
if err != nil { if err != nil {
return err return err
} }
@ -346,7 +427,7 @@ func (i *Interface) SetRelativateExpiry(key string, duration int64) error {
// MakeSecret marks the record as a secret, meaning interfacing processes, such as an UI, are denied access to the record. // MakeSecret marks the record as a secret, meaning interfacing processes, such as an UI, are denied access to the record.
func (i *Interface) MakeSecret(key string) error { func (i *Interface) MakeSecret(key string) error {
r, db, err := i.getRecord(getDBFromKey, key, true, true) r, db, err := i.getRecord(getDBFromKey, key, true)
if err != nil { if err != nil {
return err return err
} }
@ -361,7 +442,7 @@ func (i *Interface) MakeSecret(key string) error {
// MakeCrownJewel marks a record as a crown jewel, meaning it will only be accessible locally. // MakeCrownJewel marks a record as a crown jewel, meaning it will only be accessible locally.
func (i *Interface) MakeCrownJewel(key string) error { func (i *Interface) MakeCrownJewel(key string) error {
r, db, err := i.getRecord(getDBFromKey, key, true, true) r, db, err := i.getRecord(getDBFromKey, key, true)
if err != nil { if err != nil {
return err return err
} }
@ -376,7 +457,7 @@ func (i *Interface) MakeCrownJewel(key string) error {
// Delete deletes a record from the database. // Delete deletes a record from the database.
func (i *Interface) Delete(key string) error { func (i *Interface) Delete(key string) error {
r, db, err := i.getRecord(getDBFromKey, key, true, true) r, db, err := i.getRecord(getDBFromKey, key, true)
if err != nil { if err != nil {
return err return err
} }

database/interface_cache.go Normal file

@ -0,0 +1,207 @@
package database
import (
"context"
"errors"
"time"
"github.com/safing/portbase/database/record"
"github.com/safing/portbase/log"
)
// DelayedCacheWriter must be run by the caller of an interface that uses delayed cache writing.
func (i *Interface) DelayedCacheWriter(ctx context.Context) error {
// Check if the DelayedCacheWriter should be run at all.
if i.options.CacheSize <= 0 || i.options.DelayCachedWrites == "" {
return errors.New("delayed cache writer is not applicable to this database interface")
}
// Check if backend support the Batcher interface.
batchPut := i.PutMany(i.options.DelayCachedWrites)
// End batchPut immediately and check for an error.
err := batchPut(nil)
if err != nil {
return err
}
// percentThreshold defines the minimum percentage of entries in the write cache in relation to the cache size that need to be present in order for flushing the cache to the database storage.
percentThreshold := 25
thresholdWriteTicker := time.NewTicker(5 * time.Second)
forceWriteTicker := time.NewTicker(5 * time.Minute)
for {
// Wait for trigger for writing the cache.
select {
case <-ctx.Done():
// The caller is shutting down, flush the cache to storage and exit.
i.flushWriteCache(0)
return nil
case <-i.triggerCacheWrite:
// An entry from the cache was evicted that was also in the write cache.
// This makes it likely that other entries that are also present in the
// write cache will be evicted soon. Flush the write cache to storage
// immediately in order to reduce single writes.
i.flushWriteCache(0)
case <-thresholdWriteTicker.C:
			// Regularly check if the write cache has filled up to a certain degree and
// flush it to storage before we start evicting to-be-written entries and
// slow down the hot path again.
i.flushWriteCache(percentThreshold)
case <-forceWriteTicker.C:
// Once in a while, flush the write cache to storage no matter how much
// it is filled. We don't want entries lingering around in the write
// cache forever. This also reduces the amount of data loss in the event
// of a total crash.
i.flushWriteCache(0)
}
}
}
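In a service built on the modules package, this writer is typically kept alive for the whole runtime. A sketch, assuming a service-worker helper such as StartServiceWorker, which is not shown in this diff:

// Restart the delayed cache writer whenever it returns an error (assumed helper).
module.StartServiceWorker("database delayed cache writer", 0, db.DelayedCacheWriter)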
func (i *Interface) flushWriteCache(percentThreshold int) {
i.writeCacheLock.Lock()
defer i.writeCacheLock.Unlock()
// Check if there is anything to do.
if len(i.writeCache) == 0 {
return
}
// Check if we reach the given threshold for writing to storage.
if (len(i.writeCache)*100)/i.options.CacheSize < percentThreshold {
return
}
// Write the full cache in a batch operation.
batchPut := i.PutMany(i.options.DelayCachedWrites)
for _, r := range i.writeCache {
err := batchPut(r)
if err != nil {
log.Warningf("database: failed to write write-cached entry to %q database: %s", i.options.DelayCachedWrites, err)
}
}
// Finish batch.
err := batchPut(nil)
if err != nil {
log.Warningf("database: failed to finish flushing write cache to %q database: %s", i.options.DelayCachedWrites, err)
}
// Optimized map clearing following the Go1.11 recommendation.
for key := range i.writeCache {
delete(i.writeCache, key)
}
}
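As a worked example of the threshold check above: with CacheSize set to 512 and the 25 percent threshold used by the five-second ticker, a flush only happens once at least 128 records are pending in the write cache, while the five-minute ticker and the shutdown path flush regardless of fill level.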
// cacheEvictHandler is run by the cache for every entry that gets evicted
// from the cache.
func (i *Interface) cacheEvictHandler(keyData, _ interface{}) {
// Transform the key into a string.
key, ok := keyData.(string)
if !ok {
return
}
// Check if the evicted record is one that is to be written.
// Lock the write cache until the end of the function.
// The read cache is locked anyway for the whole duration.
i.writeCacheLock.Lock()
defer i.writeCacheLock.Unlock()
r, ok := i.writeCache[key]
if ok {
delete(i.writeCache, key)
}
if !ok {
return
}
// Write record to database in order to mitigate race conditions where the record would appear
// as non-existent for a short duration.
db, err := getController(r.DatabaseName())
if err != nil {
log.Warningf("database: failed to write evicted cache entry %q: database %q does not exist", key, r.DatabaseName())
return
}
r.Lock()
defer r.Unlock()
err = db.Put(r)
if err != nil {
log.Warningf("database: failed to write evicted cache entry %q to database: %s", key, err)
}
// Finally, trigger writing the full write cache because a to-be-written
// entry was just evicted from the cache, and this makes it likely that more
// to-be-written entries will be evicted shortly.
select {
case i.triggerCacheWrite <- struct{}{}:
default:
}
}
func (i *Interface) checkCache(key string) record.Record {
// Check if cache is in use.
if i.cache == nil {
return nil
}
// Check if record exists in cache.
cacheVal, err := i.cache.Get(key)
if err == nil {
r, ok := cacheVal.(record.Record)
if ok {
return r
}
}
return nil
}
// updateCache updates an entry in the interface cache. The given record may
// not be locked, as updating the cache might write an (unrelated) evicted
// record to the database in the process. If this happens while the
// DelayedCacheWriter flushes the write cache with the same record present,
// this will deadlock.
func (i *Interface) updateCache(r record.Record, write bool, remove bool, ttl int64) (written bool) {
// Check if cache is in use.
if i.cache == nil {
return false
}
// Check if record should be deleted
if remove {
// Remove entry from cache.
i.cache.Remove(r.Key())
// Let write through to database storage.
return false
}
// Update cache with record.
if ttl >= 0 {
_ = i.cache.SetWithExpire(
r.Key(),
r,
time.Duration(ttl)*time.Second,
)
} else {
_ = i.cache.Set(
r.Key(),
r,
)
}
// Add record to write cache instead if:
// 1. The record is being written.
// 2. Write delaying is active.
// 3. Write delaying is active for the database of this record.
if write && r.DatabaseName() == i.options.DelayCachedWrites {
i.writeCacheLock.Lock()
defer i.writeCacheLock.Unlock()
i.writeCache[r.Key()] = r
return true
}
return false
}


@ -0,0 +1,158 @@
package database
import (
"context"
"fmt"
"strconv"
"sync"
"testing"
)
func benchmarkCacheWriting(b *testing.B, storageType string, cacheSize int, sampleSize int, delayWrites bool) { //nolint:gocognit,gocyclo
b.Run(fmt.Sprintf("CacheWriting_%s_%d_%d_%v", storageType, cacheSize, sampleSize, delayWrites), func(b *testing.B) {
// Setup Benchmark.
// Create database.
dbName := fmt.Sprintf("cache-w-benchmark-%s-%d-%d-%v", storageType, cacheSize, sampleSize, delayWrites)
_, err := Register(&Database{
Name: dbName,
Description: fmt.Sprintf("Cache Benchmark Database for %s", storageType),
StorageType: storageType,
})
if err != nil {
b.Fatal(err)
}
// Create benchmark interface.
options := &Options{
Local: true,
Internal: true,
CacheSize: cacheSize,
}
if cacheSize > 0 && delayWrites {
options.DelayCachedWrites = dbName
}
db := NewInterface(options)
// Start
ctx, cancelCtx := context.WithCancel(context.Background())
var wg sync.WaitGroup
if cacheSize > 0 && delayWrites {
wg.Add(1)
go func() {
err := db.DelayedCacheWriter(ctx)
if err != nil {
panic(err)
}
wg.Done()
}()
}
// Start Benchmark.
b.ResetTimer()
for i := 0; i < b.N; i++ {
testRecordID := i % sampleSize
r := NewExample(
dbName+":"+strconv.Itoa(testRecordID),
"A",
1,
)
err = db.Put(r)
if err != nil {
b.Fatal(err)
}
}
// End cache writer and wait
cancelCtx()
wg.Wait()
})
}
func benchmarkCacheReadWrite(b *testing.B, storageType string, cacheSize int, sampleSize int, delayWrites bool) { //nolint:gocognit,gocyclo
b.Run(fmt.Sprintf("CacheReadWrite_%s_%d_%d_%v", storageType, cacheSize, sampleSize, delayWrites), func(b *testing.B) {
// Setup Benchmark.
// Create database.
dbName := fmt.Sprintf("cache-rw-benchmark-%s-%d-%d-%v", storageType, cacheSize, sampleSize, delayWrites)
_, err := Register(&Database{
Name: dbName,
Description: fmt.Sprintf("Cache Benchmark Database for %s", storageType),
StorageType: storageType,
})
if err != nil {
b.Fatal(err)
}
// Create benchmark interface.
options := &Options{
Local: true,
Internal: true,
CacheSize: cacheSize,
}
if cacheSize > 0 && delayWrites {
options.DelayCachedWrites = dbName
}
db := NewInterface(options)
// Start
ctx, cancelCtx := context.WithCancel(context.Background())
var wg sync.WaitGroup
if cacheSize > 0 && delayWrites {
wg.Add(1)
go func() {
err := db.DelayedCacheWriter(ctx)
if err != nil {
panic(err)
}
wg.Done()
}()
}
// Start Benchmark.
b.ResetTimer()
writing := true
for i := 0; i < b.N; i++ {
testRecordID := i % sampleSize
key := dbName + ":" + strconv.Itoa(testRecordID)
if i > 0 && testRecordID == 0 {
writing = !writing // switch between reading and writing every samplesize
}
if writing {
r := NewExample(key, "A", 1)
err = db.Put(r)
} else {
_, err = db.Get(key)
}
if err != nil {
b.Fatal(err)
}
}
// End cache writer and wait
cancelCtx()
wg.Wait()
})
}
func BenchmarkCache(b *testing.B) {
for _, storageType := range []string{"bbolt", "hashmap"} {
benchmarkCacheWriting(b, storageType, 32, 8, false)
benchmarkCacheWriting(b, storageType, 32, 8, true)
benchmarkCacheWriting(b, storageType, 32, 1024, false)
benchmarkCacheWriting(b, storageType, 32, 1024, true)
benchmarkCacheWriting(b, storageType, 512, 1024, false)
benchmarkCacheWriting(b, storageType, 512, 1024, true)
benchmarkCacheReadWrite(b, storageType, 32, 8, false)
benchmarkCacheReadWrite(b, storageType, 32, 8, true)
benchmarkCacheReadWrite(b, storageType, 32, 1024, false)
benchmarkCacheReadWrite(b, storageType, 32, 1024, true)
benchmarkCacheReadWrite(b, storageType, 512, 1024, false)
benchmarkCacheReadWrite(b, storageType, 512, 1024, true)
}
}


@ -2,13 +2,29 @@ package record
import ( import (
"errors" "errors"
"fmt"
"github.com/safing/portbase/container" "github.com/safing/portbase/container"
"github.com/safing/portbase/database/accessor" "github.com/safing/portbase/database/accessor"
"github.com/safing/portbase/formats/dsd" "github.com/safing/portbase/formats/dsd"
"github.com/safing/portbase/log"
) )
// TODO(ppacher):
// we can reduce the record.Record interface a lot by moving
// most of those functions that require the Record as it's first
// parameter to static package functions
// (i.e. Marshal, MarshalRecord, GetAccessor, ...).
// We should also consider given Base a GetBase() *Base method
// that returns itself. This way we can remove almost all Base
// only methods from the record.Record interface. That is, we can
// remove all those CreateMeta, UpdateMeta, ... stuff from the
// interface definition (not the actual functions!). This would make
// the record.Record interface slim and only provide methods that
// most users actually need. All those database/storage related methods
// can still be accessed by using GetBase().XXX() instead. We can also
// expose the dbName and dbKey and meta properties directly which would
// make a nice JSON blob when marshalled.
// Base provides a quick way to comply with the Model interface. // Base provides a quick way to comply with the Model interface.
type Base struct { type Base struct {
dbName string dbName string
@ -16,37 +32,46 @@ type Base struct {
meta *Meta meta *Meta
} }
// SetKey sets the key on the database record. The key may only be set once and
// future calls to SetKey will be ignored. If you want to copy/move the record
// to another database key, you will need to create a copy and assign a new key.
// A key must be set before the record is used in any database operation.
func (b *Base) SetKey(key string) {
if !b.KeyIsSet() {
b.dbName, b.dbKey = ParseKey(key)
} else {
log.Errorf("database: key is already set: tried to replace %q with %q", b.Key(), key)
}
}
// Key returns the key of the database record. // Key returns the key of the database record.
// As the key must be set before any usage and can only be set once, this
// function may be used without locking the record.
func (b *Base) Key() string { func (b *Base) Key() string {
return fmt.Sprintf("%s:%s", b.dbName, b.dbKey) return b.dbName + ":" + b.dbKey
} }
// KeyIsSet returns true if the database key is set. // KeyIsSet returns true if the database key is set.
// As the key must be set before any usage and can only be set once, this
// function may be used without locking the record.
func (b *Base) KeyIsSet() bool { func (b *Base) KeyIsSet() bool {
return len(b.dbName) > 0 && len(b.dbKey) > 0 return b.dbName != ""
} }
// DatabaseName returns the name of the database. // DatabaseName returns the name of the database.
// As the key must be set before any usage and can only be set once, this
// function may be used without locking the record.
func (b *Base) DatabaseName() string { func (b *Base) DatabaseName() string {
return b.dbName return b.dbName
} }
// DatabaseKey returns the database key of the database record. // DatabaseKey returns the database key of the database record.
// As the key must be set before any usage and can only be set once, this
// function may be used without locking the record.
func (b *Base) DatabaseKey() string { func (b *Base) DatabaseKey() string {
return b.dbKey return b.dbKey
} }
// SetKey sets the key on the database record, it should only be called after loading the record. Use MoveTo to save the record with another key.
func (b *Base) SetKey(key string) {
b.dbName, b.dbKey = ParseKey(key)
}
// MoveTo sets a new key for the record and resets all metadata, except for the secret and crownjewel status.
func (b *Base) MoveTo(key string) {
b.SetKey(key)
b.meta.Reset()
}
// Meta returns the metadata object for this record. // Meta returns the metadata object for this record.
func (b *Base) Meta() *Meta { func (b *Base) Meta() *Meta {
return b.meta return b.meta
@ -60,7 +85,7 @@ func (b *Base) CreateMeta() {
// UpdateMeta creates the metadata if it does not exist and updates it. // UpdateMeta creates the metadata if it does not exist and updates it.
func (b *Base) UpdateMeta() { func (b *Base) UpdateMeta() {
if b.meta == nil { if b.meta == nil {
b.meta = &Meta{} b.CreateMeta()
} }
b.meta.Update() b.meta.Update()
} }


@@ -7,8 +7,8 @@ import (

// ParseKey splits a key into its database name and key parts.
func ParseKey(key string) (dbName, dbKey string) {
	splitted := strings.SplitN(key, ":", 2)
	if len(splitted) < 2 {
		return splitted[0], ""
	}
	return splitted[0], strings.Join(splitted[1:], ":")
}
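Since SplitN only splits on the first colon, database keys may themselves contain colons. Illustrative values:

dbName, dbKey := record.ParseKey("config:core/log/level")
// dbName == "config", dbKey == "core/log/level"

dbName, dbKey = record.ParseKey("config")
// dbName == "config", dbKey == ""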


@ -31,9 +31,10 @@ func (m *Meta) GetAbsoluteExpiry() int64 {
} }
// GetRelativeExpiry returns the current relative expiry time - ie. seconds until expiry. // GetRelativeExpiry returns the current relative expiry time - ie. seconds until expiry.
// A negative value signifies that the record does not expire.
func (m *Meta) GetRelativeExpiry() int64 { func (m *Meta) GetRelativeExpiry() int64 {
if m.Deleted < 0 { if m.Expires == 0 {
return -m.Deleted return -1
} }
abs := m.Expires - time.Now().Unix() abs := m.Expires - time.Now().Unix()


@ -6,13 +6,12 @@ import (
// Record provides an interface for uniformly handling database records.
type Record interface {
	SetKey(key string) // test:config
	Key() string       // test:config
	KeyIsSet() bool
	DatabaseName() string // test
	DatabaseKey() string  // config

	Meta() *Meta
	SetMeta(meta *Meta)
	CreateMeta()


@@ -115,26 +115,9 @@ func (b *BBolt) PutMany(shadowDelete bool) (chan<- record.Record, <-chan error)
	err := b.db.Batch(func(tx *bbolt.Tx) error {
		bucket := tx.Bucket(bucketName)
		for r := range batch {
			txErr := b.batchPutOrDelete(bucket, shadowDelete, r)
			if txErr != nil {
				return txErr
			}
		}
		return nil
@ -145,6 +128,25 @@ func (b *BBolt) PutMany(shadowDelete bool) (chan<- record.Record, <-chan error)
return batch, errs return batch, errs
} }
func (b *BBolt) batchPutOrDelete(bucket *bbolt.Bucket, shadowDelete bool, r record.Record) (err error) {
r.Lock()
defer r.Unlock()
if !shadowDelete && r.Meta().IsDeleted() {
// Immediate delete.
err = bucket.Delete([]byte(r.DatabaseKey()))
} else {
// Put or shadow delete.
var data []byte
data, err = r.MarshalRecord(r)
if err == nil {
err = bucket.Put([]byte(r.DatabaseKey()), data)
}
}
return err
}
// Delete deletes a record from the database. // Delete deletes a record from the database.
func (b *BBolt) Delete(key string) error { func (b *BBolt) Delete(key string) error {
err := b.db.Update(func(tx *bbolt.Tx) error { err := b.db.Update(func(tx *bbolt.Tx) error {


@ -66,11 +66,7 @@ func (hm *HashMap) PutMany(shadowDelete bool) (chan<- record.Record, <-chan erro
// start handler // start handler
go func() { go func() {
for r := range batch { for r := range batch {
if !shadowDelete && r.Meta().IsDeleted() { hm.batchPutOrDelete(shadowDelete, r)
delete(hm.db, r.DatabaseKey())
} else {
hm.db[r.DatabaseKey()] = r
}
} }
errs <- nil errs <- nil
}() }()
@ -78,6 +74,20 @@ func (hm *HashMap) PutMany(shadowDelete bool) (chan<- record.Record, <-chan erro
return batch, errs return batch, errs
} }
func (hm *HashMap) batchPutOrDelete(shadowDelete bool, r record.Record) {
r.Lock()
defer r.Unlock()
hm.dbLock.Lock()
defer hm.dbLock.Unlock()
if !shadowDelete && r.Meta().IsDeleted() {
delete(hm.db, r.DatabaseKey())
} else {
hm.db[r.DatabaseKey()] = r
}
}
// Delete deletes a record from the database. // Delete deletes a record from the database.
func (hm *HashMap) Delete(key string) error { func (hm *HashMap) Delete(key string) error {
hm.dbLock.Lock() hm.dbLock.Lock()
@ -108,17 +118,16 @@ func (hm *HashMap) queryExecutor(queryIter *iterator.Iterator, q *query.Query, l
mapLoop: mapLoop:
for key, record := range hm.db { for key, record := range hm.db {
record.Lock()
if !q.MatchesKey(key) ||
!q.MatchesRecord(record) ||
!record.Meta().CheckValidity() ||
!record.Meta().CheckPermission(local, internal) {
switch { record.Unlock()
case !q.MatchesKey(key):
continue
case !q.MatchesRecord(record):
continue
case !record.Meta().CheckValidity():
continue
case !record.Meta().CheckPermission(local, internal):
continue continue
} }
record.Unlock()
select { select {
case <-queryIter.Done: case <-queryIter.Done:


@ -10,7 +10,6 @@ type Subscription struct {
q *query.Query q *query.Query
local bool local bool
internal bool internal bool
canceled bool
Feed chan record.Record Feed chan record.Record
} }
@ -22,20 +21,13 @@ func (s *Subscription) Cancel() error {
return err return err
} }
c.readLock.Lock() c.subscriptionLock.Lock()
defer c.readLock.Unlock() defer c.subscriptionLock.Unlock()
c.writeLock.Lock()
defer c.writeLock.Unlock()
if s.canceled {
return nil
}
s.canceled = true
close(s.Feed)
for key, sub := range c.subscriptions { for key, sub := range c.subscriptions {
if sub.q == s.q { if sub.q == s.q {
c.subscriptions = append(c.subscriptions[:key], c.subscriptions[key+1:]...) c.subscriptions = append(c.subscriptions[:key], c.subscriptions[key+1:]...)
close(s.Feed) // this close is guarded by the controllers subscriptionLock.
return nil return nil
} }
} }


@ -32,7 +32,6 @@ func (s Severity) String() string {
} }
func formatLine(line *logLine, duplicates uint64, useColor bool) string { func formatLine(line *logLine, duplicates uint64, useColor bool) string {
colorStart := "" colorStart := ""
colorEnd := "" colorEnd := ""
if useColor { if useColor {


@ -33,6 +33,16 @@ import (
// Severity describes a log level. // Severity describes a log level.
type Severity uint32 type Severity uint32
// Message describes a log level message and is implemented
// by logLine.
type Message interface {
Text() string
Severity() Severity
Time() time.Time
File() string
LineNumber() int
}
type logLine struct { type logLine struct {
msg string msg string
tracer *ContextTracer tracer *ContextTracer
@ -42,6 +52,26 @@ type logLine struct {
line int line int
} }
func (ll *logLine) Text() string {
return ll.msg
}
func (ll *logLine) Severity() Severity {
return ll.level
}
func (ll *logLine) Time() time.Time {
return ll.timestamp
}
func (ll *logLine) File() string {
return ll.file
}
func (ll *logLine) LineNumber() int {
return ll.line
}
func (ll *logLine) Equal(ol *logLine) bool { func (ll *logLine) Equal(ol *logLine) bool {
switch { switch {
case ll.msg != ol.msg: case ll.msg != ol.msg:


@ -7,11 +7,71 @@ import (
"time" "time"
) )
type (
// Adapter is used to write logs.
Adapter interface {
// Write is called for each log message.
Write(msg Message, duplicates uint64)
}
// AdapterFunc is a convenience type for implementing
// Adapter.
	AdapterFunc func(msg Message, duplicates uint64)

	// FormatFunc formats msg into a string.
	FormatFunc func(msg Message, duplicates uint64) string
// SimpleFileAdapter implements Adapter and writes all
// messages to File.
SimpleFileAdapter struct {
Format FormatFunc
File *os.File
}
)
var ( var (
// StdoutAdapter is a simple file adapter that writes
// all logs to os.Stdout using a predefined format.
StdoutAdapter = &SimpleFileAdapter{
File: os.Stdout,
Format: defaultColorFormater,
}
// StderrAdapter is a simple file adapter that writes
	// all logs to os.Stderr using a predefined format.
StderrAdapter = &SimpleFileAdapter{
File: os.Stderr,
Format: defaultColorFormater,
}
)
var (
adapter Adapter = StdoutAdapter
schedulingEnabled = false schedulingEnabled = false
writeTrigger = make(chan struct{}) writeTrigger = make(chan struct{})
) )
// SetAdapter configures the logging adapter to use.
// This must be called before the log package is initialized.
func SetAdapter(a Adapter) {
if initializing.IsSet() || a == nil {
return
}
adapter = a
}
// Write implements Adapter and calls fn.
func (fn AdapterFunc) Write(msg Message, duplicates uint64) {
fn(msg, duplicates)
}
// Write implements Adapter and writes msg to the underlying file.
func (fileAdapter *SimpleFileAdapter) Write(msg Message, duplicates uint64) {
fmt.Fprintln(fileAdapter.File, fileAdapter.Format(msg, duplicates))
}
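A custom sink can be plugged in before the log package is initialized, for example to forward messages elsewhere. A minimal sketch using AdapterFunc; the output format is illustrative:

package main

import (
	"fmt"

	"github.com/safing/portbase/log"
)

func main() {
	// Forward formatted log lines to a custom sink; the adapter must be set
	// before the log package is initialized.
	log.SetAdapter(log.AdapterFunc(func(msg log.Message, duplicates uint64) {
		// msg exposes Text(), Severity(), Time(), File() and LineNumber().
		fmt.Printf("[%s] %s (x%d)\n", msg.Severity(), msg.Text(), duplicates+1)
	}))
}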
// EnableScheduling enables external scheduling of the logger. This will require manually triggering writes via TriggerWrite whenever logs should be written. Please note that full buffers will also trigger writing. Must be called before Start() to have an effect.
func EnableScheduling() { func EnableScheduling() {
if !initializing.IsSet() { if !initializing.IsSet() {
@ -34,10 +94,8 @@ func TriggerWriterChannel() chan struct{} {
return writeTrigger return writeTrigger
} }
func defaultColorFormater(line Message, duplicates uint64) string {
	return formatLine(line.(*logLine), duplicates, true)
}
func startWriter() { func startWriter() {
@ -132,7 +190,7 @@ StackTrace:
} }
// if currentLine and line are _not_ equal, output currentLine // if currentLine and line are _not_ equal, output currentLine
writeLine(currentLine, duplicates) adapter.Write(currentLine, duplicates)
// reset duplicate counter // reset duplicate counter
duplicates = 0 duplicates = 0
// set new currentLine // set new currentLine
@ -144,7 +202,7 @@ StackTrace:
// write final line // write final line
if currentLine != nil { if currentLine != nil {
writeLine(currentLine, duplicates) adapter.Write(currentLine, duplicates)
} }
// reset state // reset state
currentLine = nil //nolint:ineffassign currentLine = nil //nolint:ineffassign
@ -166,7 +224,7 @@ func finalizeWriting() {
for { for {
select { select {
case line := <-logBuffer: case line := <-logBuffer:
writeLine(line, 0) adapter.Write(line, 0)
case <-time.After(10 * time.Millisecond): case <-time.After(10 * time.Millisecond):
fmt.Printf("%s%s %s EOF%s\n", InfoLevel.color(), time.Now().Format(timeFormat), leftArrow, endColor()) fmt.Printf("%s%s %s EOF%s\n", InfoLevel.color(), time.Now().Format(timeFormat), leftArrow, endColor())
return return


@ -1,18 +1,28 @@
/* // Package modules provides a full module and task management ecosystem to
Package modules provides a full module and task management ecosystem to cleanly put all big and small moving parts of a service together. // cleanly put all big and small moving parts of a service together.
//
Modules are started in a multi-stage process and may depend on other modules: // Modules are started in a multi-stage process and may depend on other
- Go's init(): register flags // modules:
- prep: check flags, register config variables // - Go's init(): register flags
- start: start actual work, access config // - prep: check flags, register config variables
- stop: gracefully shut down // - start: start actual work, access config
// - stop: gracefully shut down
Workers: A simple function that is run by the module while catching panics and reporting them. Ideal for long running (possibly) idle goroutines. Can be automatically restarted if execution ends with an error. //
// **Workers**
Tasks: Functions that take somewhere between a couple seconds and a couple minutes to execute and should be queued, scheduled or repeated. // A simple function that is run by the module while catching
// panics and reporting them. Ideal for long running (possibly) idle goroutines.
MicroTasks: Functions that take less than a second to execute, but require lots of resources. Running such functions as MicroTasks will reduce concurrent execution and shall improve performance. // Can be automatically restarted if execution ends with an error.
//
Ideally, _any_ execution by a module is done through these methods. This will not only ensure that all panics are caught, but will also give better insights into how your service performs. // **Tasks**
*/ // Functions that take somewhere between a couple seconds and a couple
// minutes to execute and should be queued, scheduled or repeated.
//
// **MicroTasks**
// Functions that take less than a second to execute, but require
// lots of resources. Running such functions as MicroTasks will reduce concurrent
// execution and shall improve performance.
//
// Ideally, _any_ execution by a module is done through these methods. This will
// not only ensure that all panics are caught, but will also give better insights
// into how your service performs.
package modules package modules
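As a rough sketch of this lifecycle, a module registers its stages and launches a worker once its start stage runs. The Register and StartWorker signatures are assumed here, as they are not part of this diff:

package example

import (
	"context"

	"github.com/safing/portbase/modules"
)

// A module with prep/start/stop stages and a dependency on the database module.
var module = modules.Register("example", prep, start, stop, "database")

func prep() error { return nil } // check flags, register config options
func stop() error { return nil } // graceful shutdown

func start() error {
	// Workers catch panics and can be restarted by the module system.
	module.StartWorker("example worker", func(ctx context.Context) error {
		<-ctx.Done() // replace with actual work
		return nil
	})
	return nil
}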


@ -12,17 +12,20 @@ var (
modulesChangeNotifyFn func(*Module) modulesChangeNotifyFn func(*Module)
) )
// Enable enables the module. Only has an effect if module management is enabled. // Enable enables the module. Only has an effect if module management
// is enabled.
func (m *Module) Enable() (changed bool) { func (m *Module) Enable() (changed bool) {
return m.enabled.SetToIf(false, true) return m.enabled.SetToIf(false, true)
} }
// Disable disables the module. Only has an effect if module management is enabled. // Disable disables the module. Only has an effect if module management
// is enabled.
func (m *Module) Disable() (changed bool) { func (m *Module) Disable() (changed bool) {
return m.enabled.SetToIf(true, false) return m.enabled.SetToIf(true, false)
} }
// SetEnabled sets the module to the desired enabled state. Only has an effect if module management is enabled. // SetEnabled sets the module to the desired enabled state. Only has
// an effect if module management is enabled.
func (m *Module) SetEnabled(enable bool) (changed bool) { func (m *Module) SetEnabled(enable bool) (changed bool) {
if enable { if enable {
return m.Enable() return m.Enable()
@ -35,16 +38,36 @@ func (m *Module) Enabled() bool {
return m.enabled.IsSet()
}

// EnabledAsDependency returns whether or not the module is currently
// enabled as a dependency.
func (m *Module) EnabledAsDependency() bool {
return m.enabledAsDependency.IsSet()
}

// EnableModuleManagement enables the module management functionality
// within modules. The supplied notify function will be called whenever
// the status of a module changes. The affected module will be in the
// parameter. You will need to manually enable modules, else nothing
// will start.
// EnableModuleManagement returns true if module management was not yet
// enabled and the given changeNotifyFn has been registered.
//
// Example:
//
// EnableModuleManagement(func(m *modules.Module) {
//   // some module has changed ...
//   // do whatever you like
//
//   // Run the built-in module management
//   modules.ManageModules()
// })
//
func EnableModuleManagement(changeNotifyFn func(*Module)) bool {
if moduleMgmtEnabled.SetToIf(false, true) {
modulesChangeNotifyFn = changeNotifyFn
return true
}
return false
}
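A hedged sketch of how a caller might use the new boolean return value to detect that module management was already enabled elsewhere; handleModuleChange and the log message are illustrative only.

if !modules.EnableModuleManagement(handleModuleChange) {
	// Module management was already enabled; our notify function
	// was not registered.
	log.Warning("example: module management already enabled")
}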
func (m *Module) notifyOfChange() {

@ -56,7 +79,8 @@ func (m *Module) notifyOfChange() {
}
}

// ManageModules triggers the module manager to react to recent changes of
// enabled modules.
func ManageModules() error {
// check if enabled
if !moduleMgmtEnabled.IsSet() {
@ -273,12 +273,14 @@ func (m *Module) stopAllTasks(reports chan *report) {
"module-failed-stop",
fmt.Sprintf("failed to stop module: %s", err.Error()),
)
}

// Always set to offline in order to let other modules shut down in order.
m.Lock()
m.status = StatusOffline
m.Unlock()
m.notifyOfChange()

// send report
reports <- &report{
module: m,
@ -4,119 +4,127 @@ import (
"context"
"flag"
"fmt"

"github.com/safing/portbase/config"
_ "github.com/safing/portbase/database/dbmodule" // database module is required
"github.com/safing/portbase/modules"
"github.com/safing/portbase/runtime"
)

const configChangeEvent = "config change"

var (
// DefaultManager is the default subsystem registry.
DefaultManager *Manager

module *modules.Module
printGraphFlag bool
)

// Register registers a new subsystem. It's like Manager.Register
// but uses DefaultManager and panics on error.
func Register(id, name, description string, module *modules.Module, configKeySpace string, option *config.Option) {
err := DefaultManager.Register(id, name, description, module, configKeySpace, option)
if err != nil {
panic(err)
}
}

func init() {
// The subsystem layer takes over module management. Note that
// no one must have called EnableModuleManagement. Otherwise
// the subsystem layer will silently fail managing module
// dependencies!
// TODO(ppacher): we SHOULD panic here!
// TASK(#1431)
modules.EnableModuleManagement(func(m *modules.Module) {
if DefaultManager == nil {
return
}
DefaultManager.handleModuleUpdate(m)
})

module = modules.Register("subsystems", prep, start, nil, "config", "database", "runtime", "base")
module.Enable()

// TODO(ppacher): can we create the default registry during prep phase?
var err error
DefaultManager, err = NewManager(runtime.DefaultRegistry)
if err != nil {
panic("Failed to create default registry: " + err.Error())
}

flag.BoolVar(&printGraphFlag, "print-subsystem-graph", false, "print the subsystem module dependency graph")
}

func prep() error {
if printGraphFlag {
DefaultManager.PrintGraph()
return modules.ErrCleanExit
}

// We need to listen for configuration changes so we can
// start/stop dependent modules in case a subsystem is
// (de-)activated.
if err := module.RegisterEventHook(
"config",
configChangeEvent,
"control subsystems",
func(ctx context.Context, _ interface{}) error {
err := DefaultManager.CheckConfig(ctx)
if err != nil {
module.Error(
"modulemgmt-failed",
fmt.Sprintf("The subsystem framework failed to start or stop one or more modules.\nError: %s\nCheck logs for more information.", err),
)
return nil
}
module.Resolve("modulemgmt-failed")
return nil
},
); err != nil {
return fmt.Errorf("register event hook: %w", err)
}

return nil
}

func start() error {
// Registration of subsystems is only allowed during
// preparation. Make sure any further call to Register()
// panics.
if err := DefaultManager.Start(); err != nil {
return err
}

module.StartWorker("initial subsystem configuration", DefaultManager.CheckConfig)

return nil
}

// PrintGraph prints the subsystem and module graph.
func (mng *Manager) PrintGraph() {
mng.l.RLock()
defer mng.l.RUnlock()

fmt.Println("subsystems dependency graph:")

// unmark subsystems module
module.Disable()

// mark roots
for _, sub := range mng.subsys {
sub.module.Enable() // mark as tree root
}

for _, sub := range mng.subsys {
printModuleGraph("", sub.module, true)
}

fmt.Println("\nsubsystem module groups:")
_ = start() // no errors for what we need here
for _, sub := range mng.subsys {
fmt.Printf("├── %s\n", sub.Name)
for _, mod := range sub.Modules[1:] {
fmt.Printf("│ ├── %s\n", mod.Name)
@ -0,0 +1,268 @@
package subsystems
import (
"context"
"errors"
"fmt"
"sort"
"strings"
"sync"
"time"
"github.com/safing/portbase/config"
"github.com/safing/portbase/database/record"
"github.com/safing/portbase/modules"
"github.com/safing/portbase/runtime"
"github.com/tevino/abool"
)
var (
// ErrManagerStarted is returned when a subsystem registration attempt
// occurs after the manager has been started.
ErrManagerStarted = errors.New("subsystem manager already started")
// ErrDuplicateSubsystem is returned when the subsystem to be registered
// is already known (duplicated subsystem ID).
ErrDuplicateSubsystem = errors.New("subsystem is already registered")
)
// Manager manages subsystems, provides access via runtime
// value providers, and can take over module management.
type Manager struct {
l sync.RWMutex
subsys map[string]*Subsystem
pushUpdate runtime.PushFunc
immutable *abool.AtomicBool
debounceUpdate *abool.AtomicBool
runtime *runtime.Registry
}
// NewManager returns a new subsystem manager that registers
// itself at rtReg.
func NewManager(rtReg *runtime.Registry) (*Manager, error) {
mng := &Manager{
subsys: make(map[string]*Subsystem),
immutable: abool.New(),
debounceUpdate: abool.New(),
}
push, err := rtReg.Register("subsystems/", runtime.SimpleValueGetterFunc(mng.Get))
if err != nil {
return nil, err
}
mng.pushUpdate = push
mng.runtime = rtReg
return mng, nil
}
// Start starts managing subsystems. Note that it's not possible
// to define new subsystems once Start() has been called.
func (mng *Manager) Start() error {
mng.immutable.Set()
seen := make(map[string]struct{}, len(mng.subsys))
configKeyPrefixes := make(map[string]*Subsystem, len(mng.subsys))
// mark all sub-systems as seen. This prevents sub-systems
// from being added as a sub-system's dependency in addDependencies.
for _, sub := range mng.subsys {
seen[sub.module.Name] = struct{}{}
configKeyPrefixes[sub.ConfigKeySpace] = sub
}
// aggregate all modules dependencies (and the subsystem module itself)
// into the Modules slice. Configuration options from dependent modules
// will be marked using config.SubsystemAnnotation if not already set.
for _, sub := range mng.subsys {
sub.Modules = append(sub.Modules, statusFromModule(sub.module))
sub.addDependencies(sub.module, seen)
}
// Annotate all configuration options with their respective subsystem.
_ = config.ForEachOption(func(opt *config.Option) error {
subsys, ok := configKeyPrefixes[opt.Key]
if !ok {
return nil
}
// Add a new subsystem annotation if it is not already set!
opt.AddAnnotation(config.SubsystemAnnotation, subsys.ID)
return nil
})
return nil
}
// Get implements runtime.ValueProvider
func (mng *Manager) Get(keyOrPrefix string) ([]record.Record, error) {
mng.l.RLock()
defer mng.l.RUnlock()
dbName := mng.runtime.DatabaseName()
records := make([]record.Record, 0, len(mng.subsys))
for _, subsys := range mng.subsys {
subsys.Lock()
if !subsys.KeyIsSet() {
subsys.SetKey(dbName + ":subsystems/" + subsys.ID)
}
if strings.HasPrefix(subsys.DatabaseKey(), keyOrPrefix) {
records = append(records, subsys)
}
subsys.Unlock()
}
// make sure the order is always the same
sort.Sort(bySubsystemID(records))
return records, nil
}
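For illustration, a hedged sketch of reading the aggregated subsystem status directly through this value provider. Calling DefaultManager.Get directly is an assumption made for brevity; the usual path is the injected runtime database.

func printSubsystemStatus() error {
	records, err := subsystems.DefaultManager.Get("subsystems/")
	if err != nil {
		return err
	}
	for _, r := range records {
		sub := r.(*subsystems.Subsystem) // Get only returns *Subsystem values
		fmt.Printf("%s: failure status %d\n", sub.ID, sub.FailureStatus)
	}
	return nil
}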
// Register registers a new subsystem. The given option must be a bool option.
// Should be called in init() directly after the modules.Register() function.
// The config option must not yet be registered and will be registered for
// you. Pass a nil option to force enable.
//
// TODO(ppacher): IMHO the subsystem package is not responsible for registering
// the "toggle option". This would also remove runtime
// dependency to the config package. Users should either pass
// the BoolOptionFunc and the expertise/release level directly
// or just pass the configuration key so that information can
// be looked up by the registry.
func (mng *Manager) Register(id, name, description string, module *modules.Module, configKeySpace string, option *config.Option) error {
mng.l.Lock()
defer mng.l.Unlock()
if mng.immutable.IsSet() {
return ErrManagerStarted
}
if _, ok := mng.subsys[id]; ok {
return ErrDuplicateSubsystem
}
s := &Subsystem{
ID: id,
Name: name,
Description: description,
ConfigKeySpace: configKeySpace,
module: module,
toggleOption: option,
}
s.CreateMeta()
if s.toggleOption != nil {
s.ToggleOptionKey = s.toggleOption.Key
s.ExpertiseLevel = s.toggleOption.ExpertiseLevel
s.ReleaseLevel = s.toggleOption.ReleaseLevel
if err := config.Register(s.toggleOption); err != nil {
return fmt.Errorf("failed to register subsystem option: %w", err)
}
s.toggleValue = config.GetAsBool(s.ToggleOptionKey, false)
} else {
s.toggleValue = func() bool { return true }
}
mng.subsys[id] = s
return nil
}
func (mng *Manager) shouldServeUpdates() bool {
if !mng.immutable.IsSet() {
// the manager must be marked as immutable before we
// are going to handle any module changes.
return false
}
if modules.IsShuttingDown() {
// we don't care if we are shutting down anyway
return false
}
return true
}
// CheckConfig checks subsystem configuration values and enables
// or disables subsystems and their dependencies as required.
func (mng *Manager) CheckConfig(ctx context.Context) error {
return mng.handleConfigChanges(ctx)
}
func (mng *Manager) handleModuleUpdate(m *modules.Module) {
if !mng.shouldServeUpdates() {
return
}
// Read lock is fine as the subsystems are write-locked on their own
mng.l.RLock()
defer mng.l.RUnlock()
subsys, ms := mng.findParentSubsystem(m)
if subsys == nil {
// the updated module is not handled by any
// subsystem. We're done here.
return
}
subsys.Lock()
defer subsys.Unlock()
updated := compareAndUpdateStatus(m, ms)
if updated {
subsys.makeSummary()
}
if updated {
mng.pushUpdate(subsys)
}
}
func (mng *Manager) handleConfigChanges(_ context.Context) error {
if !mng.shouldServeUpdates() {
return nil
}
if mng.debounceUpdate.SetToIf(false, true) {
time.Sleep(100 * time.Millisecond)
mng.debounceUpdate.UnSet()
} else {
return nil
}
mng.l.RLock()
defer mng.l.RUnlock()
var changed bool
for _, subsystem := range mng.subsys {
if subsystem.module.SetEnabled(subsystem.toggleValue()) {
changed = true
}
}
if !changed {
return nil
}
return modules.ManageModules()
}
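The debounce in handleConfigChanges can be read as a small reusable pattern; a hedged standalone sketch using the same abool primitives (the helper name and the surrounding imports are assumptions):

// debounce lets the first caller proceed after the delay; callers that
// arrive while the flag is set return immediately.
func debounce(flag *abool.AtomicBool, delay time.Duration) (proceed bool) {
	if flag.SetToIf(false, true) {
		time.Sleep(delay)
		flag.UnSet()
		return true
	}
	return false
}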
func (mng *Manager) findParentSubsystem(m *modules.Module) (*Subsystem, *ModuleStatus) {
for _, subsys := range mng.subsys {
for _, ms := range subsys.Modules {
if ms.Name == m.Name {
return subsys, ms
}
}
}
return nil, nil
}
// helper type to sort a slice of []*Subsystem (cast as []record.Record) by
// id. Only use if it's guaranteed that all record.Records are *Subsystem.
// Otherwise Less() will panic.
type bySubsystemID []record.Record
func (sl bySubsystemID) Less(i, j int) bool { return sl[i].(*Subsystem).ID < sl[j].(*Subsystem).ID }
func (sl bySubsystemID) Swap(i, j int) { sl[i], sl[j] = sl[j], sl[i] }
func (sl bySubsystemID) Len() int { return len(sl) }
@ -5,30 +5,49 @@ import (
"github.com/safing/portbase/config"
"github.com/safing/portbase/database/record"
"github.com/safing/portbase/modules"
)

// Subsystem describes a subset of modules that represent a part of a
// service or program to the user. Subsystems can be (de-)activated causing
// all related modules to be brought down or up.
type Subsystem struct { //nolint:maligned // not worth the effort
record.Base
sync.Mutex

// ID is a unique identifier for the subsystem.
ID string
// Name holds a human readable name of the subsystem.
Name string
// Description may hold an optional description of
// the subsystem's purpose.
Description string
// Modules contains all modules that are related to the subsystem.
// Note that this slice also contains a reference to the subsystem
// module itself.
Modules []*ModuleStatus
// FailureStatus is the worst failure status that is currently
// set in one of the subsystem's dependencies.
FailureStatus uint8
// ToggleOptionKey holds the key of the configuration option
// that is used to completely enable or disable this subsystem.
ToggleOptionKey string
// ExpertiseLevel defines the complexity of the subsystem and is
// copied from the subsystem's toggleOption.
ExpertiseLevel config.ExpertiseLevel
// ReleaseLevel defines the stability of the subsystem and is
// copied from the subsystem's toggleOption.
ReleaseLevel config.ReleaseLevel
// ConfigKeySpace defines the database key prefix that all
// options that belong to this subsystem have. Note that this
// value is mainly used to mark all related options with a
// config.SubsystemAnnotation. Options that are part of
// this subsystem but don't start with the correct prefix can
// still be marked by manually setting the appropriate annotation.
ConfigKeySpace string

module *modules.Module
toggleOption *config.Option
toggleValue config.BoolOption
}

// ModuleStatus describes the status of a module.

@ -46,15 +65,13 @@ type ModuleStatus struct {
FailureMsg string
}

func (sub *Subsystem) addDependencies(module *modules.Module, seen map[string]struct{}) {
for _, module := range module.Dependencies() {
if _, ok := seen[module.Name]; !ok {
seen[module.Name] = struct{}{}

sub.Modules = append(sub.Modules, statusFromModule(module))
sub.addDependencies(module, seen)
}
}
}

@ -90,6 +107,7 @@ func compareAndUpdateStatus(module *modules.Module, status *ModuleStatus) (chang
failureStatus, failureID, failureMsg := module.FailureStatus()
if status.FailureStatus != failureStatus ||
status.FailureID != failureID {
status.FailureStatus = failureStatus
status.FailureID = failureID
status.FailureMsg = failureMsg
@ -1,161 +0,0 @@
package subsystems
import (
"context"
"fmt"
"sync"
"time"
"github.com/tevino/abool"
"github.com/safing/portbase/config"
"github.com/safing/portbase/modules"
)
var (
subsystems []*Subsystem
subsystemsMap = make(map[string]*Subsystem)
subsystemsLock sync.Mutex
subsystemsLocked = abool.New()
handlingConfigChanges = abool.New()
)
// Register registers a new subsystem. The given option must be a bool option. Should be called in init() directly after the modules.Register() function. The config option must not yet be registered and will be registered for you. Pass a nil option to force enable.
func Register(id, name, description string, module *modules.Module, configKeySpace string, option *config.Option) {
// lock slice and map
subsystemsLock.Lock()
defer subsystemsLock.Unlock()
// check if registration is closed
if subsystemsLocked.IsSet() {
panic("subsystems can only be registered in prep phase or earlier")
}
// check if already registered
_, ok := subsystemsMap[name]
if ok {
panic(fmt.Sprintf(`subsystem "%s" already registered`, name))
}
// create new
new := &Subsystem{
ID: id,
Name: name,
Description: description,
module: module,
toggleOption: option,
ConfigKeySpace: configKeySpace,
}
if new.toggleOption != nil {
new.ToggleOptionKey = new.toggleOption.Key
new.ExpertiseLevel = new.toggleOption.ExpertiseLevel
new.ReleaseLevel = new.toggleOption.ReleaseLevel
}
// register config
if option != nil {
err := config.Register(option)
if err != nil {
panic(fmt.Sprintf("failed to register config: %s", err))
}
new.toggleValue = config.GetAsBool(new.ToggleOptionKey, false)
} else {
// force enabled
new.toggleValue = func() bool { return true }
}
// add to lists
subsystemsMap[name] = new
subsystems = append(subsystems, new)
}
func handleModuleChanges(m *modules.Module) {
// check if ready
if !subsystemsLocked.IsSet() {
return
}
// check if shutting down
if modules.IsShuttingDown() {
return
}
// find module status
var moduleSubsystem *Subsystem
var moduleStatus *ModuleStatus
subsystemLoop:
for _, subsystem := range subsystems {
for _, status := range subsystem.Modules {
if m.Name == status.Name {
moduleSubsystem = subsystem
moduleStatus = status
break subsystemLoop
}
}
}
// abort if not found
if moduleSubsystem == nil || moduleStatus == nil {
return
}
// update status
moduleSubsystem.Lock()
changed := compareAndUpdateStatus(m, moduleStatus)
if changed {
moduleSubsystem.makeSummary()
}
moduleSubsystem.Unlock()
// save
if changed {
moduleSubsystem.Save()
}
}
func handleConfigChanges(ctx context.Context, data interface{}) error {
// check if ready
if !subsystemsLocked.IsSet() {
return nil
}
// potentially catch multiple changes
if handlingConfigChanges.SetToIf(false, true) {
time.Sleep(100 * time.Millisecond)
handlingConfigChanges.UnSet()
} else {
return nil
}
// don't do anything if we are already shutting down globally
if modules.IsShuttingDown() {
return nil
}
// only run one instance at any time
subsystemsLock.Lock()
defer subsystemsLock.Unlock()
var changed bool
for _, subsystem := range subsystems {
if subsystem.module.SetEnabled(subsystem.toggleValue()) {
// if changed
changed = true
}
}
// trigger module management if any setting was changed
if changed {
err := modules.ManageModules()
if err != nil {
module.Error(
"modulemgmt-failed",
fmt.Sprintf("The subsystem framework failed to start or stop one or more modules.\nError: %s\nCheck logs for more information.", err),
)
} else {
module.Resolve("modulemgmt-failed")
}
}
return nil
}
@ -50,7 +50,7 @@ func TestSubsystems(t *testing.T) {
DefaultValue: false,
},
)
sub1 := DefaultManager.subsys["feature-one"]

feature2 := modules.Register("feature2", nil, nil, nil)
Register(
@ -352,6 +352,8 @@ func (t *Task) executeWithLocking() {
// notify that we finished
t.cancelCtx()

// refresh context
// RACE CONDITION with L314!
t.ctx, t.cancelCtx = context.WithCancel(t.module.Ctx)
t.lock.Unlock()
@ -3,66 +3,49 @@ package notifications
import (
"context"
"time"
)

func cleaner(ctx context.Context) error { //nolint:unparam // Conforms to worker interface
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()

for {
select {
case <-ctx.Done():
return nil
case <-ticker.C:
deleteExpiredNotifs()
}
}
}

func deleteExpiredNotifs() {
// Get a copy of the notification map.
notsCopy := getNotsCopy()

// Delete all expired notifications.
for _, n := range notsCopy {
if n.isExpired() {
n.delete(true)
}
}
}

func (n *Notification) isExpired() bool {
n.Lock()
defer n.Unlock()

return n.Expires > 0 && n.Expires < time.Now().Unix()
}

func getNotsCopy() []*Notification {
notsLock.RLock()
defer notsLock.RUnlock()

notsCopy := make([]*Notification, 0, len(nots))
for _, n := range nots {
notsCopy = append(notsCopy, n)
}

return notsCopy
}
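For context, a hedged sketch of wiring the cleaner into the module lifecycle using the StartWorker API shown earlier in this commit; the actual package may start it differently (for example via a dedicated service-worker helper).

func start() error {
	// Tie the cleaner goroutine to the module's lifecycle so panics
	// are caught and reported by the module framework.
	module.StartWorker("notification cleaner", cleaner)
	return nil
}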
@ -19,9 +19,6 @@ var (
notsLock sync.RWMutex

dbController *database.Controller
)

// Storage interface errors

@ -31,13 +28,6 @@ var (
ErrNoDelete = errors.New("notifications may not be deleted, they must be handled")
)

// StorageInterface provides a storage.Interface to the configuration manager.
type StorageInterface struct {
storage.InjectBase

@ -64,22 +54,27 @@ func registerAsDatabase() error {

// Get returns a database record.
func (s *StorageInterface) Get(key string) (record.Record, error) {
// Get EventID from key.
if !strings.HasPrefix(key, "all/") {
return nil, storage.ErrNotFound
}
key = strings.TrimPrefix(key, "all/")

// Get notification from storage.
n, ok := getNotification(key)
if !ok {
return nil, storage.ErrNotFound
}

return n, nil
}

func getNotification(eventID string) (n *Notification, ok bool) {
notsLock.RLock()
defer notsLock.RUnlock()

n, ok = nots[eventID]
return
}
// Query returns an iterator for the supplied query.
@ -92,23 +87,40 @@ func (s *StorageInterface) Query(q *query.Query, local, internal bool) (*iterato
}

func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) {
// Get a copy of the notification map.
notsCopy := getNotsCopy()

// send all notifications
for _, n := range notsCopy {
if inQuery(n, q) {
select {
case it.Next <- n:
case <-it.Done:
// make sure we don't leak this goroutine if the iterator gets cancelled
return
}
}
}

it.Finish(nil)
}
} }
func inQuery(n *Notification, q *query.Query) bool {
n.lock.Lock()
defer n.lock.Unlock()
switch {
case n.Meta().IsDeleted():
return false
case !q.MatchesKey(n.DatabaseKey()):
return false
case !q.MatchesRecord(n):
return false
}
return true
}
// Put stores a record in the database.
func (s *StorageInterface) Put(r record.Record) (record.Record, error) {
// record is already locked!
@ -126,76 +138,79 @@ func (s *StorageInterface) Put(r record.Record) (record.Record, error) {
return nil, ErrInvalidPath
}

return applyUpdate(n, key)
}

func applyUpdate(n *Notification, key string) (*Notification, error) {
existing, ok := getNotification(key)

// ignore if already deleted
if !ok || existing.Meta().IsDeleted() {
// this is a completely new notification
// we pass pushUpdate==false because the storage
// controller will push an update on put anyway.
n.save(false)
return n, nil
}

// Save when we're finished, if needed.
save := false
defer func() {
if save {
existing.save(false)
}
}()

existing.Lock()
defer existing.Unlock()

if existing.State == Executed {
return existing, fmt.Errorf("action already executed")
}

// check if the notification has been marked as
// "executed externally".
if n.State == Executed {
log.Tracef("notifications: action for %s executed externally", n.EventID)
existing.State = Executed
save = true

// in case the action has been executed immediately by the
// sender we may need to update the SelectedActionID.
// Though, we guard the assignments with value check
// so partial updates that only change the
// State property do not overwrite existing values.
if n.SelectedActionID != "" {
existing.SelectedActionID = n.SelectedActionID
}
}

if n.SelectedActionID != "" && existing.State == Active {
log.Tracef("notifications: selected action for %s: %s", n.EventID, n.SelectedActionID)
existing.selectAndExecuteAction(n.SelectedActionID)
save = true
}

return existing, nil
}
// Delete deletes a record from the database.
func (s *StorageInterface) Delete(key string) error {
// Get EventID from key.
if !strings.HasPrefix(key, "all/") {
return storage.ErrNotFound
}
key = strings.TrimPrefix(key, "all/")

// Get notification from storage.
n, ok := getNotification(key)
if !ok {
return storage.ErrNotFound
}

n.delete(true)
return nil
}
// ReadOnly returns whether the database is read only.
@ -1,49 +1,100 @@
package notifications

import (
"context"
"fmt"
"sync"
"time"

"github.com/safing/portbase/database/record"
"github.com/safing/portbase/log"
"github.com/safing/portbase/utils"
)

// Type describes the type of a notification.
type Type uint8

// Notification types.
const (
Info Type = 0
Warning Type = 1
Prompt Type = 2
)

// State describes the state of a notification.
type State string

// NotificationActionFn defines the function signature for notification action
// functions.
type NotificationActionFn func(context.Context, *Notification) error

// Possible notification states.
// State transitions can only happen from top to bottom.
const (
// Active describes a notification that is active, not expired and,
// if actions are available, still waits for the user to select an
// action.
Active State = "active"
// Responded describes a notification where the user has already
// selected which action to take but that action is still to be
// performed.
Responded State = "responded"
// Executed describes a notification where the user has selected
// an action and that action has been performed.
Executed State = "executed"
)

// Notification represents a notification that is to be delivered to the user.
type Notification struct {
record.Base

// EventID is used to identify a specific notification. It consists of
// the module name and a per-module unique event id.
// The following format is recommended:
//   <module-id>:<event-id>
EventID string
// GUID is a unique identifier for each notification instance. That is
// two notifications with the same EventID must still have unique GUIDs.
// The GUID is mainly used for system (Windows) integration and is
// automatically populated by the notification package. Average users
// don't need to care about this field.
GUID string
// Type is the notification type. It can be one of Info, Warning or Prompt.
Type Type
// Title is an optional and very short title for the message that gives a
// hint about what the notification is about.
Title string
// Category is an optional category for the notification that allows for
// tagging and grouping notifications by category.
Category string
// Message is the default message shown to the user if no localized version
// of the notification is available. Note that the message should already
// have any parameterized values replaced.
Message string
// EventData contains an additional payload for the notification. This payload
// may contain contextual data and may be used by a localization framework
// to populate the notification message template.
// If EventData implements sync.Locker it will be locked and unlocked together with the
// notification. Otherwise, EventData is expected to be immutable once the
// notification has been saved and handed over to the notification or database package.
EventData interface{}
// Expires holds the unix epoch timestamp at which the notification expires
// and can be cleaned up.
// Users can safely ignore expired notifications and should handle expiry the
// same as deletion.
Expires int64
// State describes the current state of a notification. See State for
// a list of available values and their meaning.
State State
// AvailableActions defines a list of actions that a user can choose from.
AvailableActions []*Action
// SelectedActionID is updated to match the ID of one of the AvailableActions
// based on the user selection.
SelectedActionID string

lock sync.Mutex
actionFunction NotificationActionFn // call function to process action
actionTrigger chan string // and/or send to a channel
expiredTrigger chan struct{} // closed on expire
}
// Action describes an action that can be taken for a notification.
@ -52,9 +103,6 @@ type Action struct {
Text string
}

// Get returns the notification identified by the given id or nil if it doesn't exist.
func Get(id string) *Notification {
notsLock.RLock()
@ -87,43 +135,85 @@ func NotifyPrompt(id, msg string, actions ...Action) *Notification {
return notify(Prompt, id, msg, actions...)
}
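A hedged usage sketch for the prompt helpers from an external caller; the event ID, action IDs, and log message are invented for illustration, and the action function locks the notification as required by the SetActionFunction documentation below.

func askUser() {
	n := notifications.NotifyPrompt(
		"example:allow-connection",
		"Allow this connection?",
		notifications.Action{ID: "allow", Text: "Allow"},
		notifications.Action{ID: "deny", Text: "Deny"},
	)
	n.SetActionFunction(func(ctx context.Context, n *notifications.Notification) error {
		// Executed as a module worker once the user selects an action.
		n.Lock()
		defer n.Unlock()
		log.Tracef("example: user selected %s", n.SelectedActionID)
		return nil
	})
}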
func notify(nType Type, id, msg string, actions ...Action) *Notification {
acts := make([]*Action, len(actions))
for idx := range actions {
a := actions[idx]
acts[idx] = &a
}

return Notify(&Notification{
EventID: id,
Type: nType,
Message: msg,
AvailableActions: acts,
})
}
// Notify sends the given notification.
func Notify(n *Notification) *Notification {
// While this function is very similar to Save(), it is much nicer to use in
// order to just fire off one notification, as it does not require some more
// uncommon Go syntax.
n.save(true)
return n
}
// Save saves the notification.
func (n *Notification) Save() {
n.save(true)
}
// save saves the notification to the internal storage. It locks the
// notification, so it must not be locked when save is called.
func (n *Notification) save(pushUpdate bool) {
var id string
// Save notification after pre-save processing.
defer func() {
if id != "" {
// Lock and save to notification storage.
notsLock.Lock()
defer notsLock.Unlock()
nots[id] = n
}
}()
// We do not access EventData here, so it is enough to just lock the
// notification itself.
n.lock.Lock()
defer n.lock.Unlock()
// Move Title to Message, as that is the required field.
if n.Message == "" {
n.Message = n.Title
n.Title = ""
} }
// Check if required data is present.
if n.Message == "" {
log.Warning("notifications: ignoring notification without Message")
return
}
// Derive EventID from Message if not given.
if n.EventID == "" {
n.EventID = fmt.Sprintf(
"unknown:%s",
utils.DerivedInstanceUUID(n.Message).String(),
)
}
// Save ID for deletion
id = n.EventID
// Generate random GUID if not set.
if n.GUID == "" {
n.GUID = utils.RandomUUID(n.EventID).String()
}

// Make ack notification if there are no defined actions.
if len(n.AvailableActions) == 0 {
n.AvailableActions = []*Action{
{
@ -131,55 +221,31 @@ func (n *Notification) Save() *Notification {
Text: "OK",
},
}
}

// Make sure we always have a notification state assigned.
if n.State == "" {
n.State = Active
}

// check key
if !n.KeyIsSet() {
n.SetKey(fmt.Sprintf("notifications:all/%s", n.EventID))
}

// Update meta data.
n.UpdateMeta()

// Push update via the database system if needed.
if pushUpdate {
log.Tracef("notifications: pushing update for %s to subscribers", n.Key())
dbController.PushUpdate(n)
}
}
// SetActionFunction sets a trigger function to be executed when the user reacted to the notification.
// The provided function will be started as its own goroutine and will have to lock everything it accesses, even the provided notification.
func (n *Notification) SetActionFunction(fn NotificationActionFn) *Notification {
n.lock.Lock()
defer n.lock.Unlock()
n.actionFunction = fn
@ -200,52 +266,72 @@ func (n *Notification) Response() <-chan string {
// Update updates/resends a notification if it was not already responded to.
func (n *Notification) Update(expires int64) {
// Save when we're finished, if needed.
save := false
defer func() {
if save {
n.save(true)
}
}()

n.lock.Lock()
defer n.lock.Unlock()

// Don't update if notification isn't active.
if n.State != Active {
return
}

// Don't update too quickly.
if n.Meta().Modified > time.Now().Add(-10*time.Second).Unix() {
return
}

// Update expiry and save.
n.Expires = expires
save = true
}
// Delete (prematurely) cancels and deletes a notification.
func (n *Notification) Delete() error {
n.delete(true)
return nil
}

// delete deletes the notification from the internal storage. It locks the
// notification, so it must not be locked when delete is called.
func (n *Notification) delete(pushUpdate bool) {
var id string

// Delete notification after processing deletion.
defer func() {
// Lock and delete from notification storage.
notsLock.Lock()
defer notsLock.Unlock()
delete(nots, id)
}()

// We do not access EventData here, so it is enough to just lock the
// notification itself.
n.lock.Lock()
defer n.lock.Unlock()

// Save ID for deletion
id = n.EventID

// Mark notification as deleted.
n.Meta().Delete()

// Close expiry channel if available.
if n.expiredTrigger != nil {
close(n.expiredTrigger)
n.expiredTrigger = nil
}

// Push update via the database system if needed.
if pushUpdate {
dbController.PushUpdate(n)
}
}
// Expired notifies the caller when the notification has expired.
@ -262,23 +348,29 @@ func (n *Notification) Expired() <-chan struct{} {
// selectAndExecuteAction sets the user response and executes/triggers the action, if possible.
func (n *Notification) selectAndExecuteAction(id string) {
if n.State != Active {
return
}

n.State = Responded
n.SelectedActionID = id

executed := false
if n.actionFunction != nil {
module.StartWorker("notification action execution", func(ctx context.Context) error {
return n.actionFunction(ctx, n)
})
executed = true
}

if n.actionTrigger != nil {
// satisfy all listeners (if they are listening)
// TODO(ppacher): if we miss to notify the waiter here (because
// nobody is listening on actionTrigger) we will likely
// never be able to execute the action again (simply because
// we won't try). May consider replacing the single actionTrigger
// channel with a per-listener (buffered) one so we just send
// the value and close the channel.
triggerAll:
for {
select {
@ -290,42 +382,30 @@ func (n *Notification) selectAndExecuteAction(id string) {
}
}

if executed {
n.State = Executed
}
}
// Lock locks the Notification. If EventData is set and
// implements sync.Locker it is locked as well. Users that
// want to replace the EventData on a notification must
// ensure to unlock the current value on their own. If the
// new EventData implements sync.Locker as well, it must
// be locked prior to unlocking the notification.
func (n *Notification) Lock() {
n.lock.Lock()
if locker, ok := n.EventData.(sync.Locker); ok {
locker.Lock()
}
}

// Unlock unlocks the Notification and the EventData, if
// it implements sync.Locker. See Lock() for more information
// on how to replace and work with EventData.
func (n *Notification) Unlock() {
n.lock.Unlock()
if locker, ok := n.EventData.(sync.Locker); ok {
locker.Unlock()
}
}
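For illustration, a hedged sketch of an event payload that participates in the notification's locking because it implements sync.Locker. The payload type, event ID, and values are invented.

// examplePayload is a hypothetical EventData value. Because it embeds
// sync.Mutex it satisfies sync.Locker, so Notification.Lock/Unlock will
// lock and unlock it together with the notification.
type examplePayload struct {
	sync.Mutex
	ConnectionCount int
}

func notifyWithPayload() {
	n := notifications.Notify(&notifications.Notification{
		EventID:   "example:connections",
		Type:      notifications.Info,
		Message:   "Connection count changed.",
		EventData: &examplePayload{ConnectionCount: 42},
	})

	n.Lock() // also locks the payload via sync.Locker
	if data, ok := n.EventData.(*examplePayload); ok {
		data.ConnectionCount++
	}
	n.Unlock()
}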
func duplicateActions(original []*Action) (duplicate []*Action) {
duplicate = make([]*Action, len(original))
for _, action := range original {
duplicate = append(duplicate, &Action{
ID: action.ID,
Text: action.Text,
})
}
return
}
runtime/module_api.go
@ -0,0 +1,40 @@
package runtime
import (
"github.com/safing/portbase/database"
"github.com/safing/portbase/modules"
)
var (
// DefaultRegistry is the default registry
// that is used by the module-level API.
DefaultRegistry = NewRegistry()
)
func init() {
modules.Register("runtime", nil, startModule, nil, "database")
}
func startModule() error {
_, err := database.Register(&database.Database{
Name: "runtime",
Description: "Runtime database",
StorageType: "injected",
ShadowDelete: false,
})
if err != nil {
return err
}
if err := DefaultRegistry.InjectAsDatabase("runtime"); err != nil {
return err
}
return nil
}
// Register is like Registry.Register but uses
// the package DefaultRegistry.
func Register(key string, provider ValueProvider) (PushFunc, error) {
return DefaultRegistry.Register(key, provider)
}
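A hedged usage sketch of the module-level API above; the key, record type, and values are invented, and the record construction mirrors the test helper further below.

package example

import (
	"sync"

	"github.com/safing/portbase/database/record"
	"github.com/safing/portbase/runtime"
)

// statusRecord is a hypothetical runtime record.
type statusRecord struct {
	record.Base
	sync.Mutex
	Healthy bool
}

func provideStatus(keyOrPrefix string) ([]record.Record, error) {
	r := &statusRecord{Healthy: true}
	r.CreateMeta()
	r.SetKey("runtime:example/status")
	return []record.Record{r}, nil
}

func registerStatusProvider() error {
	// Expose the provider at "example/status" in the runtime database.
	push, err := runtime.Register("example/status", runtime.SimpleValueGetterFunc(provideStatus))
	if err != nil {
		return err
	}
	_ = push // call push(rec) with rec locked whenever the value changes
	return nil
}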
runtime/provider.go
@ -0,0 +1,72 @@
package runtime
import (
"errors"
"github.com/safing/portbase/database/record"
)
var (
// ErrReadOnly should be returned from ValueProvider.Set if a
// runtime record is considered read-only.
ErrReadOnly = errors.New("runtime record is read-only")
// ErrWriteOnly should be returned from ValueProvider.Get if
// a runtime record is considered write-only.
ErrWriteOnly = errors.New("runtime record is write-only")
)
type (
// PushFunc is returned when registering a new value provider
// and can be used to inform the database system about the
// availability of a new runtime record value. Similar to
// database.Controller.PushUpdate, the caller must hold
// the lock for each record passed to PushFunc.
PushFunc func(...record.Record)
// ValueProvider provides access to a runtime-computed
// database record.
ValueProvider interface {
// Set is called when the value is set from outside.
// If the runtime value is considered read-only ErrReadOnly
// should be returned. It is guaranteed that the key of
// the record passed to Set is prefixed with the key used
// to register the value provider.
Set(record.Record) (record.Record, error)
// Get should return one or more records that match keyOrPrefix.
// keyOrPrefix is guaranteed to be at least the prefix used to
// register the ValueProvider.
Get(keyOrPrefix string) ([]record.Record, error)
}
// SimpleValueSetterFunc is a convenience type for implementing a
// write-only value provider.
SimpleValueSetterFunc func(record.Record) (record.Record, error)
// SimpleValueGetterFunc is a convenience type for implementing a
// read-only value provider.
SimpleValueGetterFunc func(keyOrPrefix string) ([]record.Record, error)
)
// Set implements ValueProvider.Set and calls fn.
func (fn SimpleValueSetterFunc) Set(r record.Record) (record.Record, error) {
return fn(r)
}
// Get implements ValueProvider.Get and returns ErrWriteOnly.
func (SimpleValueSetterFunc) Get(_ string) ([]record.Record, error) {
return nil, ErrWriteOnly
}
// Set implements ValueProvider.Set and returns ErrReadOnly.
func (SimpleValueGetterFunc) Set(r record.Record) (record.Record, error) {
return nil, ErrReadOnly
}
// Get implements ValueProvider.Get and calls fn.
func (fn SimpleValueGetterFunc) Get(keyOrPrefix string) ([]record.Record, error) {
return fn(keyOrPrefix)
}
// compile time checks
var _ ValueProvider = SimpleValueGetterFunc(nil)
var _ ValueProvider = SimpleValueSetterFunc(nil)
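To complement the read-only helper, a hedged sketch of a write-only provider; the key and the handling logic are invented, and reads on it will yield ErrWriteOnly via the helper above.

// registerCommandSink registers a hypothetical write-only runtime value:
// writes are accepted and echoed back, reads are rejected by the helper.
func registerCommandSink(reg *runtime.Registry) error {
	_, err := reg.Register("example/command", runtime.SimpleValueSetterFunc(
		func(r record.Record) (record.Record, error) {
			// Handle the written record here, e.g. trigger an action.
			return r, nil
		},
	))
	return err
}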
runtime/registry.go
@ -0,0 +1,334 @@
package runtime
import (
"errors"
"fmt"
"strings"
"sync"
"github.com/armon/go-radix"
"github.com/safing/portbase/database"
"github.com/safing/portbase/database/iterator"
"github.com/safing/portbase/database/query"
"github.com/safing/portbase/database/record"
"github.com/safing/portbase/database/storage"
"github.com/safing/portbase/log"
"golang.org/x/sync/errgroup"
)
var (
// ErrKeyTaken is returned when trying to register
// a value provider at database key or prefix that
// is already occupied by another provider.
ErrKeyTaken = errors.New("runtime key or prefix already used")
// ErrKeyUnmanaged is returned when a Put operation
// on an unmanaged key is performed.
ErrKeyUnmanaged = errors.New("runtime key not managed by any provider")
// ErrInjected is returned by Registry.InjectAsDatabase
// if the registry has already been injected.
ErrInjected = errors.New("registry already injected")
)
// Registry keeps track of registered runtime
// value providers and exposes them via an
// injected database. Users normally just need
// to use the default registry provided by this
// package but may consider creating a dedicated
// runtime registry on their own. Registry uses
// a radix tree for value providers and their
// chosen database key/prefix.
type Registry struct {
l sync.RWMutex
providers *radix.Tree
dbController *database.Controller
dbName string
}
// keyedValueProvider simply wraps a value provider with its
// registration prefix.
type keyedValueProvider struct {
ValueProvider
key string
}
// NewRegistry returns a new registry.
func NewRegistry() *Registry {
return &Registry{
providers: radix.New(),
}
}
func isPrefixKey(key string) bool {
return strings.HasSuffix(key, "/")
}
// DatabaseName returns the name of the database where the
// registry has been injected. It returns an empty string
// if InjectAsDatabase has not been called.
func (r *Registry) DatabaseName() string {
r.l.RLock()
defer r.l.RUnlock()
return r.dbName
}
// InjectAsDatabase injects the registry as the storage
// database for name.
func (r *Registry) InjectAsDatabase(name string) error {
r.l.Lock()
defer r.l.Unlock()
if r.dbController != nil {
return ErrInjected
}
ctrl, err := database.InjectDatabase(name, r.asStorage())
if err != nil {
return err
}
r.dbName = name
r.dbController = ctrl
return nil
}
// Register registers a new value provider p under keyOrPrefix. The
// returned PushFunc can be used to send update notifications to
// database subscribers. Note that keyOrPrefix must end in '/' to be
// accepted as a prefix.
func (r *Registry) Register(keyOrPrefix string, p ValueProvider) (PushFunc, error) {
r.l.Lock()
defer r.l.Unlock()
// search if there's a provider registered for a prefix
// that matches or is equal to keyOrPrefix.
key, _, ok := r.providers.LongestPrefix(keyOrPrefix)
if ok && (isPrefixKey(key) || key == keyOrPrefix) {
return nil, fmt.Errorf("%w: found provider on %s", ErrKeyTaken, key)
}
// if keyOrPrefix is a prefix there must not be any provider
// registered for a key that matches keyOrPrefix.
if isPrefixKey(keyOrPrefix) {
foundProvider := ""
r.providers.WalkPrefix(keyOrPrefix, func(s string, _ interface{}) bool {
foundProvider = s
return true
})
if foundProvider != "" {
return nil, fmt.Errorf("%w: found provider on %s", ErrKeyTaken, foundProvider)
}
}
r.providers.Insert(keyOrPrefix, &keyedValueProvider{
ValueProvider: TraceProvider(p),
key: keyOrPrefix,
})
log.Tracef("runtime: registered new provider at %s", keyOrPrefix)
return func(records ...record.Record) {
r.l.RLock()
defer r.l.RUnlock()
if r.dbController == nil {
return
}
for _, rec := range records {
r.dbController.PushUpdate(rec)
}
}, nil
}
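As a usage note, a hedged sketch of how the returned PushFunc might be used after a value changes; push is the function returned by Register, and newStatusRecord is an assumed helper that builds a record below the registered prefix.

// onStatusChange pushes the updated record to database subscribers.
func onStatusChange(push runtime.PushFunc) {
	rec := newStatusRecord()
	rec.Lock() // callers must hold the record lock, see PushFunc docs
	push(rec)
	rec.Unlock()
}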
// Get returns the runtime value that is identified by key.
// It implements the storage.Interface.
func (r *Registry) Get(key string) (record.Record, error) {
provider := r.getMatchingProvider(key)
if provider == nil {
return nil, database.ErrNotFound
}
records, err := provider.Get(key)
if err != nil {
// instead of returning ErrWriteOnly to the database interface
// we wrap it in ErrNotFound so the records effectively gets
// hidden.
if errors.Is(err, ErrWriteOnly) {
return nil, database.ErrNotFound
}
return nil, err
}
// Get performs an exact match, so filter out
// any values that do not match key.
for _, r := range records {
if r.DatabaseKey() == key {
return r, nil
}
}
return nil, database.ErrNotFound
}
// Put stores the record m in the runtime database. Note that
// ErrReadOnly is returned if there's no value provider responsible
// for m.Key().
func (r *Registry) Put(m record.Record) (record.Record, error) {
provider := r.getMatchingProvider(m.DatabaseKey())
if provider == nil {
// if there's no provider for the given value
// return ErrKeyUnmanaged.
return nil, ErrKeyUnmanaged
}
res, err := provider.Set(m)
if err != nil {
return nil, err
}
return res, nil
}
// Query performs a query on the runtime registry returning all
// records across all value providers that match q.
// Query implements the storage.Storage interface.
func (r *Registry) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) {
if _, err := q.Check(); err != nil {
return nil, fmt.Errorf("invalid query: %w", err)
}
searchPrefix := q.DatabaseKeyPrefix()
providers := r.collectProviderByPrefix(searchPrefix)
if len(providers) == 0 {
return nil, fmt.Errorf("%w: for key %s", ErrKeyUnmanaged, searchPrefix)
}
iter := iterator.New()
grp := new(errgroup.Group)
for idx := range providers {
p := providers[idx]
grp.Go(func() (err error) {
defer recovery(&err)
key := p.key
if len(searchPrefix) > len(key) {
key = searchPrefix
}
records, err := p.Get(key)
if err != nil {
if errors.Is(err, ErrWriteOnly) {
return nil
}
return err
}
for _, r := range records {
r.Lock()
var (
matchesKey = q.MatchesKey(r.DatabaseKey())
isValid = r.Meta().CheckValidity()
isAllowed = r.Meta().CheckPermission(local, internal)
allowed = matchesKey && isValid && isAllowed
)
if allowed {
allowed = q.MatchesRecord(r)
}
r.Unlock()
if !allowed {
log.Tracef("runtime: not sending %s for query %s. matchesKey=%v isValid=%v isAllowed=%v", r.DatabaseKey(), searchPrefix, matchesKey, isValid, isAllowed)
continue
}
select {
case iter.Next <- r:
case <-iter.Done:
return nil
}
}
return nil
})
}
go func() {
err := grp.Wait()
iter.Finish(err)
}()
return iter, nil
}
func (r *Registry) getMatchingProvider(key string) *keyedValueProvider {
r.l.RLock()
defer r.l.RUnlock()
providerKey, provider, ok := r.providers.LongestPrefix(key)
if !ok {
return nil
}
if !isPrefixKey(providerKey) && providerKey != key {
return nil
}
return provider.(*keyedValueProvider)
}
func (r *Registry) collectProviderByPrefix(prefix string) []*keyedValueProvider {
r.l.RLock()
defer r.l.RUnlock()
// If a provider is registered at a prefix of the search prefix
// (longest prefix match), that's the only one we need to ask.
if _, p, ok := r.providers.LongestPrefix(prefix); ok {
return []*keyedValueProvider{p.(*keyedValueProvider)}
}
var providers []*keyedValueProvider
r.providers.WalkPrefix(prefix, func(key string, p interface{}) bool {
providers = append(providers, p.(*keyedValueProvider))
return false
})
return providers
}
// GetRegistrationKeys returns a list of all provider registration
// keys or prefixes.
func (r *Registry) GetRegistrationKeys() []string {
r.l.RLock()
defer r.l.RUnlock()
var keys []string
r.providers.Walk(func(key string, p interface{}) bool {
keys = append(keys, key)
return false
})
return keys
}
// asStorage returns a storage.Interface compatible struct
// that is backed by r.
func (r *Registry) asStorage() storage.Interface {
return &storageWrapper{
Registry: r,
}
}
func recovery(err *error) {
if x := recover(); x != nil {
if e, ok := x.(error); ok {
*err = e
return
}
*err = fmt.Errorf("%v", x)
}
}
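
// Illustrative sketch, not part of this commit: the recovery helper above
// turns a panic inside an errgroup goroutine into a regular error instead of
// crashing the process. examplePanicSafe is hypothetical.
func examplePanicSafe() error {
    grp := new(errgroup.Group)
    grp.Go(func() (err error) {
        defer recovery(&err)
        panic("something went wrong")
    })
    // Wait returns the recovered panic as an error.
    return grp.Wait()
}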

runtime/registry_test.go

@ -0,0 +1,150 @@
package runtime
import (
"errors"
"sync"
"testing"
"github.com/safing/portbase/database/query"
"github.com/safing/portbase/database/record"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type testRecord struct {
record.Base
sync.Mutex
Value string
}
func makeTestRecord(key, value string) record.Record {
r := &testRecord{Value: value}
r.CreateMeta()
r.SetKey("runtime:" + key)
return r
}
type testProvider struct {
k string
r []record.Record
}
func (tp *testProvider) Get(key string) ([]record.Record, error) {
return tp.r, nil
}
func (tp *testProvider) Set(r record.Record) (record.Record, error) {
return nil, errors.New("not implemented")
}
func getTestRegistry(t *testing.T) *Registry {
t.Helper()
r := NewRegistry()
providers := []testProvider{
{
k: "p1/",
r: []record.Record{
makeTestRecord("p1/f1/v1", "p1.1"),
makeTestRecord("p1/f2/v2", "p1.2"),
makeTestRecord("p1/v3", "p1.3"),
},
},
{
k: "p2/f1",
r: []record.Record{
makeTestRecord("p2/f1/v1", "p2.1"),
makeTestRecord("p2/f1/f2/v2", "p2.2"),
makeTestRecord("p2/f1/v3", "p2.3"),
},
},
}
for idx := range providers {
p := providers[idx]
_, err := r.Register(p.k, &p)
require.NoError(t, err)
}
return r
}
func TestRegistryGet(t *testing.T) {
var (
r record.Record
err error
)
reg := getTestRegistry(t)
r, err = reg.Get("p1/f1/v1")
require.NoError(t, err)
require.NotNil(t, r)
assert.Equal(t, "p1.1", r.(*testRecord).Value)
r, err = reg.Get("p1/v3")
require.NoError(t, err)
require.NotNil(t, r)
assert.Equal(t, "p1.3", r.(*testRecord).Value)
r, err = reg.Get("p1/v4")
require.Error(t, err)
assert.Nil(t, r)
r, err = reg.Get("no-provider/foo")
require.Error(t, err)
assert.Nil(t, r)
}
func TestRegistryQuery(t *testing.T) {
reg := getTestRegistry(t)
q := query.New("runtime:p")
iter, err := reg.Query(q, true, true)
require.NoError(t, err)
require.NotNil(t, iter)
var records []record.Record //nolint:prealloc
for r := range iter.Next {
records = append(records, r)
}
assert.Len(t, records, 6)
q = query.New("runtime:p1/f")
iter, err = reg.Query(q, true, true)
require.NoError(t, err)
require.NotNil(t, iter)
records = nil
for r := range iter.Next {
records = append(records, r)
}
assert.Len(t, records, 2)
}
func TestRegistryRegister(t *testing.T) {
r := NewRegistry()
cases := []struct {
inp string
err bool
}{
{"runtime:foo/bar/bar", false},
{"runtime:foo/bar/bar2", false},
{"runtime:foo/bar", false},
{"runtime:foo/bar", true}, // already used
{"runtime:foo/bar/", true}, // cannot register a prefix if there are providers below
{"runtime:foo/baz/", false},
{"runtime:foo/baz2/", false},
{"runtime:foo/baz3", false},
{"runtime:foo/baz/bar", true},
}
for _, c := range cases {
_, err := r.Register(c.inp, nil)
if c.err {
assert.Error(t, err, c.inp)
} else {
assert.NoError(t, err, c.inp)
}
}
}


@ -0,0 +1,45 @@
package runtime
import "github.com/safing/portbase/database/record"
// singleRecordReader is a convenience type for exposing a single
// record.Record read-only. Note that users must lock the whole record
// themselves before performing any manipulation on it.
type singleRecordReader struct {
record.Record
}
// ProvideRecord returns a ValueProvider that exposes read-only
// access to r. Users of ProvideRecord need to ensure they lock
// the whole record before performing modifications on it.
//
// Example:
//
// type MyValue struct {
// record.Base
// Value string
// }
// r := new(MyValue)
// pushUpdate, _ := runtime.Register("my/key", ProvideRecord(r))
// r.Lock()
// r.Value = "foobar"
// pushUpdate(r)
// r.Unlock()
//
func ProvideRecord(r record.Record) ValueProvider {
return &singleRecordReader{r}
}
// Set implements ValueProvider.Set and returns ErrReadOnly.
func (sr *singleRecordReader) Set(_ record.Record) (record.Record, error) {
return nil, ErrReadOnly
}
// Get implements ValueProvider.Get and returns the wrapped record.Record
// but only if keyOrPrefix exactly matches the record's database key.
func (sr *singleRecordReader) Get(keyOrPrefix string) ([]record.Record, error) {
if keyOrPrefix != sr.Record.Key() {
return nil, nil
}
return []record.Record{sr.Record}, nil
}

runtime/storage.go

@ -0,0 +1,32 @@
package runtime
import (
"github.com/safing/portbase/database/iterator"
"github.com/safing/portbase/database/query"
"github.com/safing/portbase/database/record"
"github.com/safing/portbase/database/storage"
)
// storageWrapper is a simple wrapper around storage.InjectBase and
// Registry and makes sure the supported methods are handled by
// the registry rather than the InjectBase defaults.
// storageWrapper is mainly there to keep the method landscape of
// Registry as small as possible.
type storageWrapper struct {
storage.InjectBase
Registry *Registry
}
func (sw *storageWrapper) Get(key string) (record.Record, error) {
return sw.Registry.Get(key)
}
func (sw *storageWrapper) Put(r record.Record) (record.Record, error) {
return sw.Registry.Put(r)
}
func (sw *storageWrapper) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) {
return sw.Registry.Query(q, local, internal)
}
func (sw *storageWrapper) ReadOnly() bool { return false }

runtime/trace_provider.go

@ -0,0 +1,37 @@
package runtime
import (
"time"
"github.com/safing/portbase/database/record"
"github.com/safing/portbase/log"
)
// traceValueProvider can be used to wrap an
// existing value provider to trace all calls to
// its Set and Get methods.
type traceValueProvider struct {
ValueProvider
}
// TraceProvider returns a new ValueProvider that wraps
// vp but traces all Set and Get method calls.
func TraceProvider(vp ValueProvider) ValueProvider {
return &traceValueProvider{vp}
}
func (tvp *traceValueProvider) Set(r record.Record) (res record.Record, err error) {
defer func(start time.Time) {
log.Tracef("runtime: setting record %q: duration=%s err=%v", r.Key(), time.Since(start), err)
}(time.Now())
return tvp.ValueProvider.Set(r)
}
func (tvp *traceValueProvider) Get(keyOrPrefix string) (records []record.Record, err error) {
defer func(start time.Time) {
log.Tracef("runtime: loading records %q: duration=%s err=%v #records=%d", keyOrPrefix, time.Since(start), err, len(records))
}(time.Now())
return tvp.ValueProvider.Get(keyOrPrefix)
}
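
// Illustrative sketch, not part of this commit: Register already wraps every
// provider via TraceProvider, but a provider can also be wrapped manually to
// trace calls made before registration. exampleTrace is hypothetical.
func exampleTrace(rec record.Record) ValueProvider {
    tracedProvider := TraceProvider(ProvideRecord(rec))
    // Every Get/Set on tracedProvider is now logged with its duration.
    return tracedProvider
}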


@ -30,7 +30,7 @@ func init() {
		module,
		"config:template", // key space for configuration options registered
		&config.Option{
-			Name:           "Enable Template Subsystem",
+			Name:           "Template Subsystem",
			Key:            "config:subsystems/template",
			Description:    "This option enables the Template Subsystem [TEMPLATE]",
			OptType:        config.OptTypeBool,
@ -46,7 +46,7 @@ func prep() error {
	// register options
	err := config.Register(&config.Option{
		Name:           "language",
-		Key:            "config:template/language",
+		Key:            "template/language",
		Description:    "Sets the language for the template [TEMPLATE]",
		OptType:        config.OptTypeString,
		ExpertiseLevel: config.ExpertiseLevelUser, // default


@ -1,7 +1,7 @@
package updater

import (
-	"fmt"
+	"path"
	"regexp"
	"strings"
)
@ -13,34 +13,45 @@ var (

// GetIdentifierAndVersion splits the given file path into its identifier and version.
func GetIdentifierAndVersion(versionedPath string) (identifier, version string, ok bool) {
-	// extract version
-	rawVersion := fileVersionRegex.FindString(versionedPath)
+	dirPath, filename := path.Split(versionedPath)
+
+	// Extract version from filename.
+	rawVersion := fileVersionRegex.FindString(filename)
	if rawVersion == "" {
+		// No version present in file, making it invalid.
		return "", "", false
	}
-	// replace - with . and trim _
+
+	// Trim the `_v` that gets caught by the regex and
+	// replace `-` with `.` to get the version string.
	version = strings.Replace(strings.TrimLeft(rawVersion, "_v"), "-", ".", -1)
-	// put together without version
-	i := strings.Index(versionedPath, rawVersion)
+
+	// Put the filename back together without version.
+	i := strings.Index(filename, rawVersion)
	if i < 0 {
		// extracted version not in string (impossible)
		return "", "", false
	}
-	return versionedPath[:i] + versionedPath[i+len(rawVersion):], version, true
+	filename = filename[:i] + filename[i+len(rawVersion):]
+
+	// Put the full path back together and return it.
+	// `dirPath + filename` is guaranteed by path.Split()
+	return dirPath + filename, version, true
}

// GetVersionedPath combines the identifier and version and returns it as a file path.
func GetVersionedPath(identifier, version string) (versionedPath string) {
-	// split in half
-	splittedFilePath := strings.SplitN(identifier, ".", 2)
-	// replace . with -
+	identifierPath, filename := path.Split(identifier)
+
+	// Split the filename where the version should go.
+	splittedFilename := strings.SplitN(filename, ".", 2)
+	// Replace `.` with `-` for the filename format.
	transformedVersion := strings.Replace(version, ".", "-", -1)
-	// put together
-	if len(splittedFilePath) == 1 {
-		return fmt.Sprintf("%s_v%s", splittedFilePath[0], transformedVersion)
+
+	// Put everything back together and return it.
+	versionedPath = identifierPath + splittedFilename[0] + "_v" + transformedVersion
+	if len(splittedFilename) > 1 {
+		versionedPath += "." + splittedFilename[1]
	}
-	return fmt.Sprintf("%s_v%s.%s", splittedFilePath[0], transformedVersion, splittedFilePath[1])
+	return versionedPath
}
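
// Illustrative sketch, not part of this commit: round-tripping an identifier
// through GetVersionedPath and GetIdentifierAndVersion. The exact version
// format accepted back depends on fileVersionRegex (defined in this file);
// this assumes it matches suffixes like "_v0-1-2".
func exampleVersionedPath() {
    versioned := GetVersionedPath("app/portmaster-core.exe", "0.1.2")
    // versioned == "app/portmaster-core_v0-1-2.exe"

    identifier, version, ok := GetIdentifierAndVersion(versioned)
    // identifier == "app/portmaster-core.exe", version == "0.1.2", ok == true
    _, _, _ = identifier, version, ok
}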


@ -1,13 +0,0 @@
package updater
import (
"compress/gzip"
"io"
)
// UnpackGZIP unpacks a GZIP compressed reader r
// and returns a new reader. It's suitable to be
// used with registry.GetPackedFile.
func UnpackGZIP(r io.Reader) (io.Reader, error) {
return gzip.NewReader(r)
}


@ -26,6 +26,7 @@ type ResourceRegistry struct {
	UpdateURLs       []string
	UserAgent        string
	MandatoryUpdates []string
+	AutoUnpack       []string

	Beta    bool
	DevMode bool
@ -170,6 +171,14 @@ func (reg *ResourceRegistry) Purge(keep int) {
	}
}

+// Reset resets the internal state of the registry, removing all added resources.
+func (reg *ResourceRegistry) Reset() {
+	reg.Lock()
+	defer reg.Unlock()
+
+	reg.resources = make(map[string]*Resource)
+}
+
// Cleanup removes temporary files.
func (reg *ResourceRegistry) Cleanup() error {
	// delete download tmp dir


@ -338,65 +338,101 @@ func (res *Resource) Blacklist(version string) error {

// Purge deletes old updates, retaining a certain amount, specified by
// the keep parameter. Purge will always keep at least 2 versions so
-// specifying a smaller keep value will have no effect. Note that
-// blacklisted versions are not counted for the keep parameter.
-// After purging a new version will be selected.
-func (res *Resource) Purge(keep int) {
+// specifying a smaller keep value will have no effect.
+func (res *Resource) Purge(keepExtra int) { //nolint:gocognit
	res.Lock()
	defer res.Unlock()

-	// safeguard
-	if keep < 2 {
-		keep = 2
-	}
-
-	// keep versions
-	var validVersions int
-	var skippedActiveVersion bool
-	var skippedSelectedVersion bool
-	var purgeFrom int
-	for i, rv := range res.Versions {
-		// continue to purging?
-		if validVersions >= keep && // skip at least <keep> versions
-			skippedActiveVersion && // skip until active version
-			skippedSelectedVersion { // skip until selected version
-			purgeFrom = i
-			break
-		}
-
-		// keep active version
-		if !skippedActiveVersion && rv == res.ActiveVersion {
-			skippedActiveVersion = true
-		}
-
-		// keep selected version
-		if !skippedSelectedVersion && rv == res.SelectedVersion {
-			skippedSelectedVersion = true
-		}
-
-		// count valid (not blacklisted) versions
-		if !rv.Blacklisted {
-			validVersions++
-		}
-	}
-
-	// check if there is anything to purge
-	if purgeFrom < keep || purgeFrom > len(res.Versions) {
-		return
-	}
-
-	// purge phase
-	for _, rv := range res.Versions[purgeFrom:] {
-		// delete
-		err := os.Remove(rv.storagePath())
-		if err != nil {
-			log.Warningf("%s: failed to purge old resource %s: %s", res.registry.Name, rv.storagePath(), err)
-		}
-	}
-
-	// remove entries of deleted files
-	res.Versions = res.Versions[purgeFrom:]
-	res.selectVersion()
+	// If there is any blacklisted version within the resource, pause purging.
+	// In this case we may need extra available versions beyond what would be
+	// available after purging.
+	for _, rv := range res.Versions {
+		if rv.Blacklisted {
+			log.Debugf(
+				"%s: pausing purging of resource %s, as it contains blacklisted items",
+				res.registry.Name,
+				rv.resource.Identifier,
+			)
+			return
+		}
+	}
+
+	// Safeguard the amount of extra version to keep.
+	if keepExtra < 2 {
+		keepExtra = 2
+	}
+
+	// Search for purge boundary.
+	var purgeBoundary int
+	var skippedActiveVersion bool
+	var skippedSelectedVersion bool
+	var skippedStableVersion bool
+boundarySearch:
+	for i, rv := range res.Versions {
+		// Check if required versions are already skipped.
+		switch {
+		case !skippedActiveVersion && res.ActiveVersion != nil:
+			// Skip versions until the active version, if it's set.
+		case !skippedSelectedVersion && res.SelectedVersion != nil:
+			// Skip versions until the selected version, if it's set.
+		case !skippedStableVersion:
+			// Skip versions until the stable version.
+		default:
+			// All required version skipped, set purge boundary.
+			purgeBoundary = i + keepExtra
+			break boundarySearch
+		}
+
+		// Check if current instance is a required version.
+		if rv == res.ActiveVersion {
+			skippedActiveVersion = true
+		}
+		if rv == res.SelectedVersion {
+			skippedSelectedVersion = true
+		}
+		if rv.StableRelease {
+			skippedStableVersion = true
+		}
+	}
+
+	// Check if there is anything to purge at all.
+	if purgeBoundary <= keepExtra || purgeBoundary >= len(res.Versions) {
+		return
+	}
+
+	// Purge everything beyond the purge boundary.
+	for _, rv := range res.Versions[purgeBoundary:] {
+		storagePath := rv.storagePath()
+
+		// Remove resource file.
+		err := os.Remove(storagePath)
+		if err != nil {
+			log.Warningf("%s: failed to purge resource %s v%s: %s", res.registry.Name, rv.resource.Identifier, rv.VersionNumber, err)
+		} else {
+			log.Tracef("%s: purged resource %s v%s", res.registry.Name, rv.resource.Identifier, rv.VersionNumber)
+		}
+
+		// Remove unpacked version of resource.
+		ext := filepath.Ext(storagePath)
+		if ext == "" {
+			// Nothing to do if file does not have an extension.
+			continue
+		}
+		unpackedPath := strings.TrimSuffix(storagePath, ext)
+
+		// Remove if it exists, or an error occurs on access.
+		_, err = os.Stat(unpackedPath)
+		if err == nil || !os.IsNotExist(err) {
+			err = os.Remove(unpackedPath)
+			if err != nil {
+				log.Warningf("%s: failed to purge unpacked resource %s v%s: %s", res.registry.Name, rv.resource.Identifier, rv.VersionNumber, err)
+			} else {
+				log.Tracef("%s: purged unpacked resource %s v%s", res.registry.Name, rv.resource.Identifier, rv.VersionNumber)
+			}
+		}
+	}
+
+	// remove entries of deleted files
+	res.Versions = res.Versions[purgeBoundary:]
}

func (rv *ResourceVersion) versionedPath() string {


@ -39,12 +39,23 @@ func (reg *ResourceRegistry) ScanStorage(root string) error {
	// walk fs
	_ = filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
+		// skip tmp dir (including errors trying to read it)
+		if strings.HasPrefix(path, reg.tmpDir.Path) {
+			return filepath.SkipDir
+		}
+
+		// handle walker error
		if err != nil {
			lastError = fmt.Errorf("%s: could not read %s: %w", reg.Name, path, err)
			log.Warning(lastError.Error())
			return nil
		}
+
+		// ignore directories
+		if info.IsDir() {
+			return nil
+		}
+
		// get relative path to storage
		relativePath, err := filepath.Rel(reg.storageDir.Path, path)
		if err != nil {
@ -52,10 +63,6 @@ func (reg *ResourceRegistry) ScanStorage(root string) error {
			log.Warning(lastError.Error())
			return nil
		}
-
-		// ignore files in tmp dir
-		if strings.HasPrefix(relativePath, reg.tmpDir.Path) {
-			return nil
-		}

		// convert to identifier and version
		relativePath = filepath.ToSlash(relativePath)

updater/unpacking.go

@ -0,0 +1,176 @@
package updater
import (
"archive/zip"
"compress/gzip"
"fmt"
"io"
"os"
"path"
"path/filepath"
"strings"
"github.com/safing/portbase/log"
"github.com/hashicorp/go-multierror"
"github.com/safing/portbase/utils"
)
// UnpackGZIP unpacks a GZIP compressed reader r
// and returns a new reader. It's suitable to be
// used with registry.GetPackedFile.
func UnpackGZIP(r io.Reader) (io.Reader, error) {
return gzip.NewReader(r)
}
// UnpackResources unpacks all resources defined in the AutoUnpack list.
func (reg *ResourceRegistry) UnpackResources() error {
reg.RLock()
defer reg.RUnlock()
var multierr *multierror.Error
for _, res := range reg.resources {
if utils.StringInSlice(reg.AutoUnpack, res.Identifier) {
err := res.UnpackArchive()
if err != nil {
multierr = multierror.Append(multierr, err)
}
}
}
return multierr.ErrorOrNil()
}
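
// Illustrative sketch, not part of this commit: a registry configured to
// automatically unpack a zipped resource. The field values are hypothetical;
// AutoUnpack lists the identifiers that UnpackResources should unpack.
func exampleAutoUnpack() error {
    reg := &ResourceRegistry{
        Name:       "updates",
        UpdateURLs: []string{"https://updates.example.com"},
        AutoUnpack: []string{"ui/example-ui.zip"},
    }
    // After resources have been added or downloaded, unpack everything
    // listed in AutoUnpack; errors are collected into one multierror.
    return reg.UnpackResources()
}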
const (
zipSuffix = ".zip"
)
// UnpackArchive unpacks the archive the resource refers to. The contents are
// unpacked into a directory with the same name as the file, excluding the
// suffix. If the destination folder already exists, it is assumed that the
// contents have already been correctly unpacked.
func (res *Resource) UnpackArchive() error {
res.Lock()
defer res.Unlock()
// Only unpack selected versions.
if res.SelectedVersion == nil {
return nil
}
switch {
case strings.HasSuffix(res.Identifier, zipSuffix):
return res.unpackZipArchive()
default:
return fmt.Errorf("unsupported file type for unpacking")
}
}
func (res *Resource) unpackZipArchive() (err error) {
// Get file and directory paths.
archiveFile := res.SelectedVersion.storagePath()
destDir := strings.TrimSuffix(archiveFile, zipSuffix)
tmpDir := filepath.Join(
res.registry.tmpDir.Path,
filepath.FromSlash(strings.TrimSuffix(
path.Base(res.SelectedVersion.versionedPath()),
zipSuffix,
)),
)
// Check status of destination.
dstStat, err := os.Stat(destDir)
switch {
case os.IsNotExist(err):
// The destination does not exist, continue with unpacking.
case err != nil:
return fmt.Errorf("cannot access destination for unpacking: %w", err)
case !dstStat.IsDir():
return fmt.Errorf("destination for unpacking is blocked by file: %s", dstStat.Name())
default:
// Archive already seems to be unpacked.
return nil
}
// Create the tmp directory for unpacking.
err = res.registry.tmpDir.EnsureAbsPath(tmpDir)
if err != nil {
return fmt.Errorf("failed to create tmp dir for unpacking: %s", err)
}
// Defer clean up of directories.
defer func() {
// Always clean up the tmp dir.
_ = os.RemoveAll(tmpDir)
// Cleanup the destination in case of an error.
if err != nil {
_ = os.RemoveAll(destDir)
}
}()
// Open the archive for reading.
var archiveReader *zip.ReadCloser
archiveReader, err = zip.OpenReader(archiveFile)
if err != nil {
return
}
defer archiveReader.Close()
// Save all files to the tmp dir.
for _, file := range archiveReader.File {
err = copyFromZipArchive(
file,
filepath.Join(tmpDir, filepath.FromSlash(file.Name)),
)
if err != nil {
return
}
}
// Make the final move.
err = os.Rename(tmpDir, destDir)
if err != nil {
return
}
// Fix permissions on the destination dir.
err = res.registry.storageDir.EnsureAbsPath(destDir)
if err != nil {
return
}
log.Infof("%s: unpacked %s", res.registry.Name, res.SelectedVersion.versionedPath())
return nil
}
func copyFromZipArchive(archiveFile *zip.File, dstPath string) error {
// If file is a directory, create it and continue.
if archiveFile.FileInfo().IsDir() {
err := os.Mkdir(dstPath, archiveFile.Mode())
if err != nil {
return err
}
return nil
}
// Open archived file for reading.
fileReader, err := archiveFile.Open()
if err != nil {
return err
}
defer fileReader.Close()
// Open destination file for writing.
dstFile, err := os.OpenFile(dstPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, archiveFile.Mode())
if err != nil {
return err
}
defer dstFile.Close()
// Copy full file from archive to dst.
if _, err := io.Copy(dstFile, fileReader); err != nil {
return err
}
return nil
}


@ -6,7 +6,9 @@ import (
	"fmt"
	"io/ioutil"
	"net/http"
+	"path"
	"path/filepath"
+	"strings"

	"github.com/safing/portbase/utils"
@ -45,31 +47,48 @@ func (reg *ResourceRegistry) downloadIndex(ctx context.Context, client *http.Cli
	}

	// parse
-	new := make(map[string]string)
-	err = json.Unmarshal(data, &new)
+	newIndexData := make(map[string]string)
+	err = json.Unmarshal(data, &newIndexData)
	if err != nil {
		return fmt.Errorf("failed to parse index %s: %w", idx.Path, err)
	}

	// check for content
-	if len(new) == 0 {
+	if len(newIndexData) == 0 {
		return fmt.Errorf("index %s is empty", idx.Path)
	}

+	// Check if all resources are within the indexes' authority.
+	authoritativePath := path.Dir(idx.Path) + "/"
+	if authoritativePath == "./" {
+		// Fix path for indexes at the storage root.
+		authoritativePath = ""
+	}
+	cleanedData := make(map[string]string, len(newIndexData))
+	for key, version := range newIndexData {
+		if strings.HasPrefix(key, authoritativePath) {
+			cleanedData[key] = version
+		} else {
+			log.Warningf("%s: index %s oversteps it's authority by defining version for %s", reg.Name, idx.Path, key)
+		}
+	}
+
	// add resources to registry
-	err = reg.AddResources(new, false, idx.Stable, idx.Beta)
+	err = reg.AddResources(cleanedData, false, idx.Stable, idx.Beta)
	if err != nil {
		log.Warningf("%s: failed to add resources: %s", reg.Name, err)
	}

	// check if dest dir exists
-	err = reg.storageDir.EnsureRelPath(filepath.Dir(idx.Path))
+	indexDir := filepath.FromSlash(path.Dir(idx.Path))
+	err = reg.storageDir.EnsureRelPath(indexDir)
	if err != nil {
		log.Warningf("%s: failed to ensure directory for updated index %s: %s", reg.Name, idx.Path, err)
	}

	// save index
-	err = ioutil.WriteFile(filepath.Join(reg.storageDir.Path, idx.Path), data, 0644)
+	indexPath := filepath.FromSlash(idx.Path)
+	err = ioutil.WriteFile(filepath.Join(reg.storageDir.Path, indexPath), data, 0644)
	if err != nil {
		log.Warningf("%s: failed to save updated index %s: %s", reg.Name, idx.Path, err)
	}
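
// Illustrative sketch, not part of this commit: the authority check above
// means an index at "all/intel/index.json" may only define versions for
// resources under "all/intel/". exampleAuthority is hypothetical.
func exampleAuthority() bool {
    authoritativePath := path.Dir("all/intel/index.json") + "/" // "all/intel/"
    return strings.HasPrefix("all/intel/geoip.mmdb.gz", authoritativePath) // true
}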

utils/osdetail/binmeta.go

@ -0,0 +1,86 @@
package osdetail
import (
"path/filepath"
"regexp"
"strings"
)
var (
segmentsSplitter = regexp.MustCompile("[^A-Za-z0-9]*[A-Z]?[a-z0-9]*")
nameOnly = regexp.MustCompile("^[A-Za-z0-9]+$")
delimiters = regexp.MustCompile("^[^A-Za-z0-9]+")
)
// GenerateBinaryNameFromPath generates a more human readable binary name from
// the given path. This function is used as a fallback in the GetBinaryName
// functions.
func GenerateBinaryNameFromPath(path string) string {
// Get file name from path.
_, fileName := filepath.Split(path)
// Split up into segments.
segments := segmentsSplitter.FindAllString(fileName, -1)
// Remove last segment if it's an extension.
if len(segments) >= 2 &&
strings.HasPrefix(segments[len(segments)-1], ".") {
segments = segments[:len(segments)-1]
}
// Go through segments and collect name parts.
nameParts := make([]string, 0, len(segments))
var fragments string
for _, segment := range segments {
// Group very short segments.
if len(segment) <= 3 {
fragments += segment
continue
} else if fragments != "" {
nameParts = append(nameParts, fragments)
fragments = ""
}
// Add segment to name.
nameParts = append(nameParts, segment)
}
// Add last fragment.
if fragments != "" {
nameParts = append(nameParts, fragments)
}
// Post-process name parts
for i := range nameParts {
// Remove any leading delimiters.
nameParts[i] = delimiters.ReplaceAllString(nameParts[i], "")
// Title-case name-only parts.
if nameOnly.MatchString(nameParts[i]) {
nameParts[i] = strings.Title(nameParts[i])
}
}
return strings.Join(nameParts, " ")
}
func cleanFileDescription(fields []string) string {
// If there is a 1 or 2 character delimiter field, only use fields before it.
endIndex := len(fields)
for i, field := range fields {
// Ignore the first field as well as fields with more than two characters.
if i >= 1 && len(field) <= 2 && !nameOnly.MatchString(field) {
endIndex = i
break
}
}
// Concatenate name
binName := strings.Join(fields[:endIndex], " ")
// If there are multiple sentences, only use the first.
if strings.Contains(binName, ". ") {
binName = strings.SplitN(binName, ". ", 2)[0]
}
return binName
}


@ -0,0 +1,15 @@
//+build !windows
package osdetail
// GetBinaryNameFromSystem queries the operating system for a human readable
// name for the given binary path.
func GetBinaryNameFromSystem(path string) (string, error) {
return "", ErrNotSupported
}
// GetBinaryIconFromSystem queries the operating system for the associated icon
// for a given binary path.
func GetBinaryIconFromSystem(path string) (string, error) {
return "", ErrNotSupported
}


@ -0,0 +1,35 @@
package osdetail
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestGenerateBinaryNameFromPath(t *testing.T) {
assert.Equal(t, "Nslookup", GenerateBinaryNameFromPath("nslookup.exe"))
assert.Equal(t, "System Settings", GenerateBinaryNameFromPath("SystemSettings.exe"))
assert.Equal(t, "One Drive Setup", GenerateBinaryNameFromPath("OneDriveSetup.exe"))
assert.Equal(t, "Msedge", GenerateBinaryNameFromPath("msedge.exe"))
assert.Equal(t, "SIH Client", GenerateBinaryNameFromPath("SIHClient.exe"))
assert.Equal(t, "Openvpn Gui", GenerateBinaryNameFromPath("openvpn-gui.exe"))
assert.Equal(t, "Portmaster Core v0-1-2", GenerateBinaryNameFromPath("portmaster-core_v0-1-2.exe"))
assert.Equal(t, "Win Store App", GenerateBinaryNameFromPath("WinStore.App.exe"))
assert.Equal(t, "Test Script", GenerateBinaryNameFromPath(".test-script"))
assert.Equal(t, "Browser Broker", GenerateBinaryNameFromPath("browser_broker.exe"))
}
func TestCleanFileDescription(t *testing.T) {
assert.Equal(t, "Product Name", cleanFileDescription(strings.Fields("Product Name. Does this and that.")))
assert.Equal(t, "Product Name", cleanFileDescription(strings.Fields("Product Name - Does this and that.")))
assert.Equal(t, "Product Name", cleanFileDescription(strings.Fields("Product Name / Does this and that.")))
assert.Equal(t, "Product Name", cleanFileDescription(strings.Fields("Product Name :: Does this and that.")))
assert.Equal(t, "/ Product Name", cleanFileDescription(strings.Fields("/ Product Name")))
assert.Equal(t, "Product", cleanFileDescription(strings.Fields("Product / Name")))
assert.Equal(t,
"Product Name a Does this and that.",
cleanFileDescription(strings.Fields("Product Name a Does this and that.")),
)
}


@ -0,0 +1,96 @@
package osdetail
import (
"bufio"
"bytes"
"fmt"
"strings"
)
const powershellGetFileDescription = `Get-ItemProperty %q | Format-List`
// GetBinaryNameFromSystem queries the operating system for a human readable
// name for the given binary path.
func GetBinaryNameFromSystem(path string) (string, error) {
// Get FileProperties via Powershell call.
output, err := runPowershellCmd(fmt.Sprintf(powershellGetFileDescription, path))
if err != nil {
return "", fmt.Errorf("failed to get file properties of %s: %s", path, err)
}
// Create scanner for the output.
scanner := bufio.NewScanner(bytes.NewBufferString(output))
scanner.Split(bufio.ScanLines)
// Search for the FileDescription line.
for scanner.Scan() {
// Split line up into fields.
fields := strings.Fields(scanner.Text())
// Discard lines with less than two fields.
if len(fields) < 2 {
continue
}
// Skip all lines that we aren't looking for.
if fields[0] != "FileDescription:" {
continue
}
// Clean and return.
return cleanFileDescription(fields[1:]), nil
}
// No usable FileDescription found; the caller may fall back to a generated name.
return "", ErrNotFound
}
const powershellGetIcon = `Add-Type -AssemblyName System.Drawing
$Icon = [System.Drawing.Icon]::ExtractAssociatedIcon(%q)
$MemoryStream = New-Object System.IO.MemoryStream
$Icon.save($MemoryStream)
$Bytes = $MemoryStream.ToArray()
$MemoryStream.Flush()
$MemoryStream.Dispose()
[convert]::ToBase64String($Bytes)`
// TODO: This returns a small and crappy icon.
// Saving a better icon to file works:
/*
Add-Type -AssemblyName System.Drawing
$ImgList = New-Object System.Windows.Forms.ImageList
$ImgList.ImageSize = New-Object System.Drawing.Size(256,256)
$ImgList.ColorDepth = 32
$Icon = [System.Drawing.Icon]::ExtractAssociatedIcon("C:\Program Files (x86)\Mozilla Firefox\firefox.exe")
$ImgList.Images.Add($Icon);
$BigIcon = $ImgList.Images.Item(0)
$BigIcon.Save("test.png")
*/
// But not saving to a memory stream:
/*
Add-Type -AssemblyName System.Drawing
$ImgList = New-Object System.Windows.Forms.ImageList
$ImgList.ImageSize = New-Object System.Drawing.Size(256,256)
$ImgList.ColorDepth = 32
$Icon = [System.Drawing.Icon]::ExtractAssociatedIcon("C:\Program Files (x86)\Mozilla Firefox\firefox.exe")
$ImgList.Images.Add($Icon);
$MemoryStream = New-Object System.IO.MemoryStream
$BigIcon = $ImgList.Images.Item(0)
$BigIcon.Save($MemoryStream)
$Bytes = $MemoryStream.ToArray()
$MemoryStream.Flush()
$MemoryStream.Dispose()
[convert]::ToBase64String($Bytes)
*/
// GetBinaryIconFromSystem queries the operating system for the associated icon
// for a given binary path and returns it as a data-URL.
func GetBinaryIconFromSystem(path string) (string, error) {
// Get Associated File Icon via Powershell call.
output, err := runPowershellCmd(fmt.Sprintf(powershellGetIcon, path))
if err != nil {
return "", fmt.Errorf("failed to get file properties of %s: %s", path, err)
}
return "data:image/png;base64," + output, nil
}

utils/osdetail/errors.go

@ -0,0 +1,12 @@
package osdetail
import "errors"
var (
// ErrNotSupported is returned when an operation is not supported on the current platform.
ErrNotSupported = errors.New("not supported")
// ErrNotFound is returned when the desired data is not found.
ErrNotFound = errors.New("not found")
// ErrEmptyOutput is a special error that is returned when an operation has no error, but also returns no data.
ErrEmptyOutput = errors.New("command succeeded with empty output")
)


@ -0,0 +1,47 @@
package osdetail
import (
"bytes"
"errors"
"os/exec"
"strings"
)
func runPowershellCmd(script string) (output string, err error) {
// Create command to execute.
cmd := exec.Command(
"powershell.exe",
"-NoProfile",
"-NonInteractive",
script,
)
// Create and assign output buffers.
var stdoutBuf bytes.Buffer
var stderrBuf bytes.Buffer
cmd.Stdout = &stdoutBuf
cmd.Stderr = &stderrBuf
// Run command and collect output.
err = cmd.Run()
stdout, stderr := stdoutBuf.String(), stderrBuf.String()
if err != nil {
return "", err
}
// Powershell might not return an error, but just write to stderr instead.
if stderr != "" {
return "", errors.New(strings.SplitN(stderr, "\n", 2)[0])
}
// Debugging output:
// fmt.Printf("powershell stdout: %s\n", stdout)
// fmt.Printf("powershell stderr: %s\n", stderr)
// Finalize stdout.
cleanedOutput := strings.TrimSpace(stdout)
if cleanedOutput == "" {
return "", ErrEmptyOutput
}
return cleanedOutput, nil
}
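
// Illustrative sketch, not part of this commit: calling runPowershellCmd with
// a simple one-liner. The script shown is hypothetical; any command that
// writes to stdout is handled the same way.
func exampleRunPowershell() (string, error) {
    // Returns the Windows product name, e.g. "Windows 10 Pro".
    return runPowershellCmd(`(Get-ItemProperty "HKLM:\SOFTWARE\Microsoft\Windows NT\CurrentVersion").ProductName`)
}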

utils/osdetail/test/.gitignore

@ -0,0 +1,2 @@
test
test.exe


@ -7,9 +7,42 @@ import (
)

func main() {
+	fmt.Println("Binary Names:")
+	printBinaryName("openvpn-gui.exe", `C:\Program Files\OpenVPN\bin\openvpn-gui.exe`)
+	printBinaryName("firefox.exe", `C:\Program Files (x86)\Mozilla Firefox\firefox.exe`)
+	printBinaryName("powershell.exe", `C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe`)
+	printBinaryName("explorer.exe", `C:\Windows\explorer.exe`)
+	printBinaryName("svchost.exe", `C:\Windows\System32\svchost.exe`)
+
+	fmt.Println("\n\nBinary Icons:")
+	printBinaryIcon("openvpn-gui.exe", `C:\Program Files\OpenVPN\bin\openvpn-gui.exe`)
+	printBinaryIcon("firefox.exe", `C:\Program Files (x86)\Mozilla Firefox\firefox.exe`)
+	printBinaryIcon("powershell.exe", `C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe`)
+	printBinaryIcon("explorer.exe", `C:\Windows\explorer.exe`)
+	printBinaryIcon("svchost.exe", `C:\Windows\System32\svchost.exe`)
+
+	fmt.Println("\n\nSvcHost Service Names:")
	names, err := osdetail.GetAllServiceNames()
	if err != nil {
		panic(err)
	}
	fmt.Printf("%+v\n", names)
}
+
+func printBinaryName(name, path string) {
+	binName, err := osdetail.GetBinaryName(path)
+	if err != nil {
+		fmt.Printf("%s: ERROR: %s\n", name, err)
+	} else {
+		fmt.Printf("%s: %s\n", name, binName)
+	}
+}
+
+func printBinaryIcon(name, path string) {
+	binIcon, err := osdetail.GetBinaryIcon(path)
+	if err != nil {
+		fmt.Printf("%s: ERROR: %s\n", name, err)
+	} else {
+		fmt.Printf("%s: %s\n", name, binIcon)
+	}
+}

Binary file not shown.