Merge pull request from safing/feature/config-improvements

Improve config import and export utils
This commit is contained in:
Daniel Hovie 2023-10-03 11:38:11 +02:00 committed by GitHub
commit b41b567d2a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 290 additions and 93 deletions

View file

@ -2,7 +2,6 @@ package api
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
@ -15,6 +14,7 @@ import (
"github.com/gorilla/mux"
"github.com/safing/portbase/database/record"
"github.com/safing/portbase/formats/dsd"
"github.com/safing/portbase/log"
"github.com/safing/portbase/modules"
)
@ -461,7 +461,11 @@ func (e *Endpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var v interface{}
v, err = e.StructFunc(apiRequest)
if err == nil && v != nil {
responseData, err = json.Marshal(v)
var mimeType string
responseData, mimeType, _, err = dsd.MimeDump(v, r.Header.Get("Accept"))
if err == nil {
w.Header().Set("Content-Type", mimeType)
}
}
case e.RecordFunc != nil:
@ -482,7 +486,6 @@ func (e *Endpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Check for handler error.
if err != nil {
// if statusProvider, ok := err.(HTTPStatusProvider); ok {
var statusProvider HTTPStatusProvider
if errors.As(err, &statusProvider) {
http.Error(w, err.Error(), statusProvider.HTTPStatus())
@ -498,8 +501,12 @@ func (e *Endpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
// Set content type if not yet set.
if w.Header().Get("Content-Type") == "" {
w.Header().Set("Content-Type", e.MimeType+"; charset=utf-8")
}
// Write response.
w.Header().Set("Content-Type", e.MimeType+"; charset=utf-8")
w.Header().Set("Content-Length", strconv.Itoa(len(responseData)))
w.WriteHeader(http.StatusOK)
_, err = w.Write(responseData)

View file

@ -14,7 +14,7 @@ func parseAndReplaceConfig(jsonData string) error {
return err
}
validationErrors := replaceConfig(m)
validationErrors, _ := ReplaceConfig(m)
if len(validationErrors) > 0 {
return fmt.Errorf("%d errors, first: %w", len(validationErrors), validationErrors[0])
}
@ -27,7 +27,7 @@ func parseAndReplaceDefaultConfig(jsonData string) error {
return err
}
validationErrors := replaceDefaultConfig(m)
validationErrors, _ := ReplaceDefaultConfig(m)
if len(validationErrors) > 0 {
return fmt.Errorf("%d errors, first: %w", len(validationErrors), validationErrors[0])
}

View file

@ -69,7 +69,7 @@ func start() error {
err = loadConfig(false)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
return err
return fmt.Errorf("failed to load config file: %w", err)
}
return nil
}

View file

@ -3,6 +3,7 @@ package config
import (
"encoding/json"
"fmt"
"reflect"
"regexp"
"sync"
@ -108,11 +109,13 @@ const (
// requirement. The type of RequiresAnnotation is []ValueRequirement
// or ValueRequirement.
RequiresAnnotation = "safing/portbase:config:requires"
// RequiresFeaturePlan can be used to mark a setting as only available
// RequiresFeatureIDAnnotation can be used to mark a setting as only available
// when the user has a certain feature ID in the subscription plan.
// The type is []string or string.
RequiresFeatureID = "safing/portmaster:ui:config:requires-feature"
RequiresFeatureIDAnnotation = "safing/portmaster:ui:config:requires-feature"
// SettablePerAppAnnotation can be used to mark a setting as settable per-app and
// is a boolean.
SettablePerAppAnnotation = "safing/portmaster:settable-per-app"
// RequiresUIReloadAnnotation can be used to inform the UI that changing the value
// of the annotated setting requires a full reload of the user interface.
// The value of this annotation does not matter as the sole presence of
@ -308,6 +311,22 @@ func (option *Option) GetAnnotation(key string) (interface{}, bool) {
return val, ok
}
// AnnotationEquals returns whether the annotation of the given key matches the
// given value.
func (option *Option) AnnotationEquals(key string, value any) bool {
option.Lock()
defer option.Unlock()
if option.Annotations == nil {
return false
}
setValue, ok := option.Annotations[key]
if !ok {
return false
}
return reflect.DeepEqual(value, setValue)
}
// copyOrNil returns a copy of the option, or nil if copying failed.
func (option *Option) copyOrNil() *Option {
copied, err := copystructure.Copy(option)
@ -325,6 +344,29 @@ func (option *Option) IsSetByUser() bool {
return option.activeValue != nil
}
// UserValue returns the value set by the user or nil if the value has not
// been changed from the default.
func (option *Option) UserValue() any {
option.Lock()
defer option.Unlock()
if option.activeValue == nil {
return nil
}
return option.activeValue.getData(option)
}
// ValidateValue checks if the given value is valid for the option.
func (option *Option) ValidateValue(value any) error {
option.Lock()
defer option.Unlock()
if _, err := validateValue(option, value); err != nil {
return err
}
return nil
}
// Export expors an option to a Record.
func (option *Option) Export() (record.Record, error) {
option.Lock()

View file

@ -45,7 +45,7 @@ func loadConfig(requireValidConfig bool) error {
return err
}
validationErrors := replaceConfig(newValues)
validationErrors, _ := ReplaceConfig(newValues)
if requireValidConfig && len(validationErrors) > 0 {
return fmt.Errorf("encountered %d validation errors during config loading", len(validationErrors))
}

View file

@ -37,70 +37,112 @@ func signalChanges() {
module.TriggerEvent(ChangeEvent, nil)
}
// replaceConfig sets the (prioritized) user defined config.
func replaceConfig(newValues map[string]interface{}) []*ValidationError {
var validationErrors []*ValidationError
// ValidateConfig validates the given configuration and returns all validation
// errors as well as whether the given configuration contains unknown keys.
func ValidateConfig(newValues map[string]interface{}) (validationErrors []*ValidationError, requiresRestart bool, containsUnknown bool) {
// RLock the options because we are not adding or removing
// options from the registration but rather only update the
// options value which is guarded by the option's lock itself
// options from the registration but rather only checking the
// options value which is guarded by the option's lock itself.
optionsLock.RLock()
defer optionsLock.RUnlock()
var checked int
for key, option := range options {
newValue, ok := newValues[key]
option.Lock()
option.activeValue = nil
if ok {
valueCache, err := validateValue(option, newValue)
if err == nil {
option.activeValue = valueCache
} else {
validationErrors = append(validationErrors, err)
}
}
checked++
handleOptionUpdate(option, true)
option.Unlock()
func() {
option.Lock()
defer option.Unlock()
_, err := validateValue(option, newValue)
if err != nil {
validationErrors = append(validationErrors, err)
}
if option.RequiresRestart {
requiresRestart = true
}
}()
}
}
signalChanges()
return validationErrors
return validationErrors, requiresRestart, checked < len(newValues)
}
// replaceDefaultConfig sets the (fallback) default config.
func replaceDefaultConfig(newValues map[string]interface{}) []*ValidationError {
var validationErrors []*ValidationError
// ReplaceConfig sets the (prioritized) user defined config.
func ReplaceConfig(newValues map[string]interface{}) (validationErrors []*ValidationError, requiresRestart bool) {
// RLock the options because we are not adding or removing
// options from the registration but rather only update the
// options value which is guarded by the option's lock itself
// options value which is guarded by the option's lock itself.
optionsLock.RLock()
defer optionsLock.RUnlock()
for key, option := range options {
newValue, ok := newValues[key]
option.Lock()
option.activeDefaultValue = nil
if ok {
valueCache, err := validateValue(option, newValue)
if err == nil {
option.activeDefaultValue = valueCache
} else {
validationErrors = append(validationErrors, err)
func() {
option.Lock()
defer option.Unlock()
option.activeValue = nil
if ok {
valueCache, err := validateValue(option, newValue)
if err == nil {
option.activeValue = valueCache
} else {
validationErrors = append(validationErrors, err)
}
}
}
handleOptionUpdate(option, true)
option.Unlock()
handleOptionUpdate(option, true)
if option.RequiresRestart {
requiresRestart = true
}
}()
}
signalChanges()
return validationErrors
return validationErrors, requiresRestart
}
// ReplaceDefaultConfig sets the (fallback) default config.
func ReplaceDefaultConfig(newValues map[string]interface{}) (validationErrors []*ValidationError, requiresRestart bool) {
// RLock the options because we are not adding or removing
// options from the registration but rather only update the
// options value which is guarded by the option's lock itself.
optionsLock.RLock()
defer optionsLock.RUnlock()
for key, option := range options {
newValue, ok := newValues[key]
func() {
option.Lock()
defer option.Unlock()
option.activeDefaultValue = nil
if ok {
valueCache, err := validateValue(option, newValue)
if err == nil {
option.activeDefaultValue = valueCache
} else {
validationErrors = append(validationErrors, err)
}
}
handleOptionUpdate(option, true)
if option.RequiresRestart {
requiresRestart = true
}
}()
}
signalChanges()
return validationErrors, requiresRestart
}
// SetConfigOption sets a single value in the (prioritized) user defined config.

View file

@ -24,7 +24,7 @@ func TestLayersGetters(t *testing.T) { //nolint:paralleltest
t.Fatal(err)
}
validationErrors := replaceConfig(mapData)
validationErrors, _ := ReplaceConfig(mapData)
if len(validationErrors) > 0 {
t.Fatalf("%d errors, first: %s", len(validationErrors), validationErrors[0].Error())
}

View file

@ -10,6 +10,7 @@ import (
"io"
"github.com/fxamacker/cbor/v2"
"github.com/ghodss/yaml"
"github.com/vmihailenco/msgpack/v5"
"github.com/safing/portbase/formats/varint"
@ -41,6 +42,12 @@ func LoadAsFormat(data []byte, format uint8, t interface{}) (err error) {
return fmt.Errorf("dsd: failed to unpack json: %w, data: %s", err, utils.SafeFirst16Bytes(data))
}
return nil
case YAML:
err = yaml.Unmarshal(data, t)
if err != nil {
return fmt.Errorf("dsd: failed to unpack yaml: %w, data: %s", err, utils.SafeFirst16Bytes(data))
}
return nil
case CBOR:
err = cbor.Unmarshal(data, t)
if err != nil {
@ -121,6 +128,11 @@ func dumpWithoutIdentifier(t interface{}, format uint8, indent string) ([]byte,
if err != nil {
return nil, err
}
case YAML:
data, err = yaml.Marshal(t)
if err != nil {
return nil, err
}
case CBOR:
data, err = cbor.Marshal(t)
if err != nil {

View file

@ -19,6 +19,7 @@ const (
GenCode = 71 // G
JSON = 74 // J
MsgPack = 77 // M
YAML = 89 // Y
// Compression types.
GZIP = 90 // Z
@ -48,6 +49,8 @@ func ValidateSerializationFormat(format uint8) (validatedFormat uint8, ok bool)
return format, true
case JSON:
return format, true
case YAML:
return format, true
case MsgPack:
return format, true
default:

View file

@ -5,8 +5,8 @@ import (
"errors"
"fmt"
"io"
"mime"
"net/http"
"strings"
)
// HTTP Related Errors.
@ -37,21 +37,8 @@ func loadFromHTTP(body io.Reader, mimeType string, t interface{}) (format uint8,
return 0, fmt.Errorf("dsd: failed to read http body: %w", err)
}
// Get mime type from header, then check, clean and verify it.
if mimeType == "" {
return 0, ErrMissingContentType
}
mimeType, _, err = mime.ParseMediaType(mimeType)
if err != nil {
return 0, fmt.Errorf("dsd: failed to parse content type: %w", err)
}
format, ok := MimeTypeToFormat[mimeType]
if !ok {
return 0, ErrIncompatibleFormat
}
// Parse data..
return format, LoadAsFormat(data, format, t)
// Load depending on mime type.
return MimeLoad(data, mimeType, t)
}
// RequestHTTPResponseFormat sets the Accept header to the given format.
@ -61,11 +48,6 @@ func RequestHTTPResponseFormat(r *http.Request, format uint8) (mimeType string,
if !ok {
return "", ErrIncompatibleFormat
}
// Omit charset.
mimeType, _, err = mime.ParseMediaType(mimeType)
if err != nil {
return "", fmt.Errorf("dsd: failed to parse content type: %w", err)
}
// Request response format.
r.Header.Set("Accept", mimeType)
@ -76,6 +58,7 @@ func RequestHTTPResponseFormat(r *http.Request, format uint8) (mimeType string,
// DumpToHTTPRequest dumps the given data to the HTTP request using the given
// format. It also sets the Accept header to the same format.
func DumpToHTTPRequest(r *http.Request, t interface{}, format uint8) error {
// Get mime type and set request format.
mimeType, err := RequestHTTPResponseFormat(r, format)
if err != nil {
return err
@ -87,7 +70,7 @@ func DumpToHTTPRequest(r *http.Request, t interface{}, format uint8) error {
return fmt.Errorf("dsd: failed to serialize: %w", err)
}
// Set body.
// Add data to request.
r.Header.Set("Content-Type", mimeType)
r.Body = io.NopCloser(bytes.NewReader(data))
@ -97,16 +80,8 @@ func DumpToHTTPRequest(r *http.Request, t interface{}, format uint8) error {
// DumpToHTTPResponse dumpts the given data to the HTTP response, using the
// format defined in the request's Accept header.
func DumpToHTTPResponse(w http.ResponseWriter, r *http.Request, t interface{}) error {
// Get format from Accept header.
// TODO: Improve parsing of Accept header.
mimeType := r.Header.Get("Accept")
format, ok := MimeTypeToFormat[mimeType]
if !ok {
return ErrIncompatibleFormat
}
// Serialize data.
data, err := dumpWithoutIdentifier(t, format, "")
// Serialize data based on accept header.
data, mimeType, _, err := MimeDump(t, r.Header.Get("Accept"))
if err != nil {
return fmt.Errorf("dsd: failed to serialize: %w", err)
}
@ -120,16 +95,71 @@ func DumpToHTTPResponse(w http.ResponseWriter, r *http.Request, t interface{}) e
return nil
}
// MimeLoad loads the given data into the interface based on the given mime type.
func MimeLoad(data []byte, mimeType string, t interface{}) (format uint8, err error) {
// Find format.
format = FormatFromMime(mimeType)
if format == 0 {
return 0, ErrIncompatibleFormat
}
// Load data.
err = LoadAsFormat(data, format, t)
return format, err
}
// MimeDump dumps the given interface based on the given mime type accept header.
func MimeDump(t any, accept string) (data []byte, mimeType string, format uint8, err error) {
// Find format.
accept = extractMimeType(accept)
switch accept {
case "", "*":
format = DefaultSerializationFormat
default:
format = MimeTypeToFormat[accept]
if format == 0 {
return nil, "", 0, ErrIncompatibleFormat
}
}
mimeType = FormatToMimeType[format]
// Serialize and return.
data, err = dumpWithoutIdentifier(t, format, "")
return data, mimeType, format, err
}
// FormatFromMime returns the format for the given mime type.
// Will return AUTO format for unsupported or unrecognized mime types.
func FormatFromMime(mimeType string) (format uint8) {
return MimeTypeToFormat[extractMimeType(mimeType)]
}
func extractMimeType(mimeType string) string {
if strings.Contains(mimeType, ",") {
mimeType, _, _ = strings.Cut(mimeType, ",")
}
if strings.Contains(mimeType, ";") {
mimeType, _, _ = strings.Cut(mimeType, ";")
}
if strings.Contains(mimeType, "/") {
_, mimeType, _ = strings.Cut(mimeType, "/")
}
return strings.ToLower(mimeType)
}
// Format and MimeType mappings.
var (
FormatToMimeType = map[uint8]string{
JSON: "application/json; charset=utf-8",
CBOR: "application/cbor",
JSON: "application/json",
MsgPack: "application/msgpack",
YAML: "application/yaml",
}
MimeTypeToFormat = map[string]uint8{
"application/json": JSON,
"application/cbor": CBOR,
"application/msgpack": MsgPack,
"cbor": CBOR,
"json": JSON,
"msgpack": MsgPack,
"yaml": YAML,
"yml": YAML,
}
)

36
formats/dsd/http_test.go Normal file
View file

@ -0,0 +1,36 @@
package dsd
import (
"mime"
"testing"
"github.com/stretchr/testify/assert"
)
func TestMimeTypes(t *testing.T) {
t.Parallel()
// Test static maps.
for _, mimeType := range FormatToMimeType {
cleaned, _, err := mime.ParseMediaType(mimeType)
assert.NoError(t, err, "mime type must be parse-able")
assert.Equal(t, mimeType, cleaned, "mime type should be clean in map already")
}
for mimeType := range MimeTypeToFormat {
cleaned, _, err := mime.ParseMediaType(mimeType)
assert.NoError(t, err, "mime type must be parse-able")
assert.Equal(t, mimeType, cleaned, "mime type should be clean in map already")
}
// Test assumptions.
for mimeType, mimeTypeCleaned := range map[string]string{
"application/xml, image/webp": "xml",
"application/xml;q=0.9, image/webp": "xml",
"*": "*",
"*/*": "*",
"text/yAMl": "yaml",
} {
cleaned := extractMimeType(mimeType)
assert.Equal(t, mimeTypeCleaned, cleaned, "assumption for %q should hold", mimeType)
}
}

4
go.mod
View file

@ -10,6 +10,7 @@ require (
github.com/davecgh/go-spew v1.1.1
github.com/dgraph-io/badger v1.6.2
github.com/fxamacker/cbor/v2 v2.5.0
github.com/ghodss/yaml v1.0.0
github.com/gofrs/uuid v4.4.0+incompatible
github.com/gorilla/mux v1.8.0
github.com/gorilla/websocket v1.5.0
@ -22,7 +23,7 @@ require (
github.com/shirou/gopsutil v3.21.11+incompatible
github.com/stretchr/testify v1.8.1
github.com/tevino/abool v1.2.0
github.com/tidwall/gjson v1.16.0
github.com/tidwall/gjson v1.17.0
github.com/tidwall/sjson v1.2.5
github.com/vmihailenco/msgpack/v5 v5.3.5
go.etcd.io/bbolt v1.3.7
@ -62,5 +63,6 @@ require (
golang.org/x/net v0.15.0 // indirect
golang.org/x/time v0.3.0 // indirect
google.golang.org/protobuf v1.31.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

7
go.sum
View file

@ -50,6 +50,8 @@ github.com/fxamacker/cbor v1.5.1/go.mod h1:3aPGItF174ni7dDzd6JZ206H8cmr4GDNBGpPa
github.com/fxamacker/cbor/v2 v2.4.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo=
github.com/fxamacker/cbor/v2 v2.5.0 h1:oHsG0V/Q6E/wqTS2O1Cozzsy69nqCiguo5Q1a1ADivE=
github.com/fxamacker/cbor/v2 v2.5.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo=
github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk=
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE=
github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78=
@ -157,8 +159,8 @@ github.com/tevino/abool v1.2.0 h1:heAkClL8H6w+mK5md9dzsuohKeXHUpY7Vw0ZCKW+huA=
github.com/tevino/abool v1.2.0/go.mod h1:qc66Pna1RiIsPa7O4Egxxs9OqkuxDX55zznh9K07Tzg=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.16.0 h1:SyXa+dsSPpUlcwEDuKuEBJEz5vzTvOea+9rjyYodQFg=
github.com/tidwall/gjson v1.16.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.17.0 h1:/Jocvlh98kcTfpN2+JzGQWQcqrPQwDrVEMApx/M5ZwM=
github.com/tidwall/gjson v1.17.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
@ -263,6 +265,7 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View file

@ -104,7 +104,7 @@ func (m *Module) InjectEvent(sourceEventName, targetModuleName, targetEventName
func (m *Module) runEventHook(hook *eventHook, event string, data interface{}) {
// check if source module is ready for handling
if m.Status() != StatusOnline {
// target module has not yet fully started, wait until start is complete
// source module has not yet fully started, wait until start is complete
select {
case <-m.StartCompleted():
// continue with hook execution

View file

@ -62,10 +62,30 @@ func TestCallLimiter(t *testing.T) {
}
testWg.Wait()
if execs <= 8 {
if execs <= 5 {
t.Errorf("unexpected low exec count: %d", execs)
}
if execs >= 12 {
if execs >= 15 {
t.Errorf("unexpected high exec count: %d", execs)
}
// Wait for pause to reset.
time.Sleep(pause)
// Check if the limiter correctly handles panics.
testWg.Add(100)
for i := 0; i < 100; i++ {
go func() {
defer func() {
_ = recover()
testWg.Done()
}()
oa.Do(func() {
time.Sleep(1 * time.Millisecond)
panic("test")
})
}()
time.Sleep(100 * time.Microsecond)
}
testWg.Wait()
}