Mirror of https://github.com/safing/portmaster (synced 2025-04-03 10:39:13 +00:00)
Restructure modules (#1572)
* Move portbase into monorepo
* Add new simple module mgr
* [WIP] Switch to new simple module mgr
* Add StateMgr and more worker variants
* [WIP] Switch more modules
* [WIP] Switch more modules
* [WIP] Switch more modules
* [WIP] Switch all SPN modules
* [WIP] Switch all service modules
* [WIP] Convert all workers to the new module system
* [WIP] Add new task system to module manager
* [WIP] Add second take for scheduling workers
* [WIP] Add FIXME for bugs in new scheduler
* [WIP] Add minor improvements to scheduler
* [WIP] Add new worker scheduler
* [WIP] Fix more bugs related to new module system
* [WIP] Fix start handling of the new module system
* [WIP] Improve startup process
* [WIP] Fix minor issues
* [WIP] Fix missing subsystem in settings
* [WIP] Initialize managers in constructor
* [WIP] Move module event initialization to constructors
* [WIP] Fix setting for enabling and disabling the SPN module
* [WIP] Move API registration into module construction
* [WIP] Update states mgr for all modules
* [WIP] Add CmdLine operation support
* Add state helper methods to module group and instance
* Add notification and module status handling to status package
* Fix starting issues
* Remove pilot widget and update security lock to new status data
* Remove debug logs
* Improve http server shutdown
* Add workaround for cleanly shutting down firewall+netquery
* Improve logging
* Add syncing states with notifications for new module system
* Improve starting, stopping, shutdown; resolve FIXMEs/TODOs
* [WIP] Fix most unit tests
* Review new module system and fix minor issues
* Push shutdown and restart events again via API
* Set sleep mode via interface
* Update example/template module
* [WIP] Fix spn/cabin unit test
* Remove deprecated UI elements
* Make log output more similar for the logging transition phase
* Switch spn hub and observer cmds to new module system
* Fix log sources
* Make worker mgr less error prone
* Fix tests and minor issues
* Fix observation hub
* Improve shutdown and restart handling
* Split up big connection.go source file
* Move varint and dsd packages to structures repo
* Improve expansion test
* Fix linter warnings
* Fix interception module on windows
* Fix linter errors

---------

Co-authored-by: Vladimir Stoilov <vladimir@safing.io>
parent 10a77498f4
commit 80664d1a27
647 changed files with 37690 additions and 3366 deletions
Changed files (partial tree):
Earthfile
assets/
base/
    .gitignore, README.md
    api/
        api_bridge.go, auth_wrapper.go, authentication.go, authentication_test.go
        client/
        config.go, database.go, doc.go, endpoints.go, endpoints_config.go, endpoints_debug.go, endpoints_meta.go, endpoints_test.go, enriched-response.go, init_test.go, main.go, module.go, request.go, router.go, testclient/
    apprise/
    config/
        basic_config.go, database.go, doc.go, expertise.go, get-safe.go, get.go, get_test.go, init_test.go, main.go, module.go, option.go, persistence.go, persistence_test.go, perspective.go, registry.go, registry_test.go, release.go, set.go, set_test.go, validate.go, validity.go
    container/
    database/
        accessor/
        boilerplate_test.go, controller.go, controllers.go, database.go, database_test.go, dbmodule/
        doc.go, errors.go, hook.go, hookbase.go, interface.go, interface_cache.go, interface_cache_test.go, iterator/
        main.go, maintenance.go, migration/
        query/
            README.md, condition-and.go, condition-bool.go, condition-error.go, condition-exists.go, condition-float.go, condition-int.go, condition-not.go, condition-or.go, condition-regex.go, condition-string.go, condition-stringslice.go, condition.go, condition_test.go, operators.go, operators_test.go, parser.go, parser_test.go, query.go, query_test.go
        record/
Earthfile
@@ -202,7 +202,7 @@ go-build:
     ENV EXTRA_LD_FLAGS = ""
   END

-  RUN --no-cache go build -ldflags="-X github.com/safing/portbase/info.version=${VERSION} -X github.com/safing/portbase/info.buildSource=${SOURCE} -X github.com/safing/portbase/info.buildTime=${BUILD_TIME} ${EXTRA_LD_FLAGS}" -o "/tmp/build/" ./cmds/${bin}
+  RUN --no-cache go build -ldflags="-X github.com/safing/portmaster/base/info.version=${VERSION} -X github.com/safing/portmaster/base/info.buildSource=${SOURCE} -X github.com/safing/portmaster/base/info.buildTime=${BUILD_TIME}" -o "/tmp/build/" ./cmds/${bin}
   END

   DO +GO_ARCH_STRING --goos="${GOOS}" --goarch="${GOARCH}" --goarm="${GOARM}"
a Go source file under assets/
@@ -11,7 +11,7 @@ import (
     "golang.org/x/image/draw"

-    "github.com/safing/portbase/log"
+    "github.com/safing/portmaster/base/log"
 )

 // Colored Icon IDs.
base/.gitignore (new file, vendored, +8)
@@ -0,0 +1,8 @@
portbase
apitest
misc

go.mod.*
vendor
go.work
go.work.sum
base/README.md (new file, +157)
@@ -0,0 +1,157 @@
> **Check out our main project at [safing/portmaster](https://github.com/safing/portmaster)**

# Portbase

Portbase helps you quickly take off with your project. It gives you all the basics you need for a service (_not_ a tool!).
Here is what is included:

- `log`: really fast and beautiful logging
- `modules`: a multi-stage, dependency-aware boot process for your software; also manages tasks
- `config`: simple, live-updating and extremely fast configuration storage
- `info`: easily tag your builds with versions, commit hashes, and so on
- `formats`: some handy data encoding libs
- `rng`: a feedable CSPRNG for great randomness
- `database`: an intelligent and syncable database with hooks and easy integration with structs; uses buckets with different backends
- `api`: a websocket interface to the database; can be extended with custom http handlers

Before you continue, a word about this project. It was created to hold the base code for both Portmaster and Gate17, and this is also what it will be developed for. If you have a great idea on how to improve Portbase, please, by all means, raise an issue and tell us about it, but please also don't be surprised or offended if we ask you to create your own fork to do what you need. Portbase isn't for everyone; it's quite specific to our needs, but we decided to make it easily available to others.

Portbase is actively maintained; please raise issues.

## log

The main goal of this logging package is to be as fast as possible. Logs are sent to a channel with only minimal processing beforehand, so that the service can continue with the important work and write the logs later.

Second is beauty, both in what information is provided and in how it is presented.

You can use flags to change the log level on a per-source-file basis.

## modules <small>requires `log`</small>

Packages may register themselves as modules to take part in the multi-stage boot and coordinated shutdown.

Registering only requires a name/key and the `prep()`, `start()` and `stop()` functions.

This is how modules are booted:

- `init()` available: ~~flags~~, ~~config~~, ~~logging~~, ~~dependencies~~
  - register flags (with the stdlib `flag` library)
  - register module
- `module.prep()` available: flags, ~~config~~, ~~logging~~, ~~dependencies~~
  - react to flags
  - register config variables
  - if an error occurs, return it
  - return `ErrCleanExit` for a clean, successful exit (e.g. you only printed a version)
- `module.start()` available: flags, config, logging, dependencies
  - start tasks and workers
  - do not log errors while starting, but return them
- `module.stop()` available: flags, config, logging, dependencies
  - stop all work (i.e. goroutines)
  - do not log errors while stopping, but return them

You can start tasks and workers from your module; they are then integrated into the module system, which will allow for insights into and better control over them in the future.
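Illustration only (not part of the README): a minimal sketch of such a module, assuming portbase's `modules.Register(name string, prep, start, stop func() error, dependencies ...string)` signature; the module name and the "config" dependency are placeholders, so verify against the package.

    package mymodule

    import "github.com/safing/portmaster/base/modules"

    var module *modules.Module

    func init() {
        // Register the module with its lifecycle functions and (placeholder) dependencies.
        module = modules.Register("mymodule", prep, start, stop, "config")
    }

    func prep() error {
        // React to flags and register config variables here.
        // Return ErrCleanExit for a clean, successful early exit.
        return nil
    }

    func start() error {
        // Start tasks and workers; return errors instead of logging them.
        return nil
    }

    func stop() error {
        // Stop all goroutines; return errors instead of logging them.
        return nil
    }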
## config <small>requires `log`</small>

The config package stores the configuration in JSON strings. This may sound a bit weird, but it's very practical.

There are three layers of configuration, in order of priority: the user configuration, the default configuration, and the fallback values supplied when registering a config variable.

When using config variables, you get a function that checks every time whether your config variable is still up to date. If it did not change, it's _extremely_ fast. If it did change, it fetches the current value, which takes a short while, but that does not happen often.

    // This is how you would get a string config variable function.
    myVar := GetAsString("my_config_var", "default")
    // You then use myVar() directly every time, except when you must guarantee the same value between two calls.
    if myVar() != "default" {
      log.Infof("my_config_var is set to %s", myVar())
    }
    // no error handling needed! :)

WARNING: While these config variable functions are _extremely_ fast, they are _NOT_ thread/goroutine safe! (Use the `Concurrent` wrapper for that!)

## info

Info provides an easy way to store your version and build information within the binary. If you use the `build` script to build the program, it will automatically set the build information so that you can easily find out when and from which commit a binary was built.

The `build` script extracts information from the host and the git repo and then calls `go build` with some additional arguments.

## formats/varint

This is just a convenience wrapper around `encoding/binary`, because we use varints a lot.
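Illustration only (not part of the README): the kind of varint handling such a wrapper builds on, using nothing but the standard library.

    package main

    import (
        "encoding/binary"
        "fmt"
    )

    func main() {
        // Encode a uint64 as a varint, then decode it again.
        buf := make([]byte, binary.MaxVarintLen64)
        n := binary.PutUvarint(buf, 1234)

        value, read := binary.Uvarint(buf[:n])
        fmt.Println(n, value, read) // 2 1234 2
    }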
## formats/dsd <small>requires `formats/varint`</small>

DSD stands for dynamically structured data. In short, this is a generic packer that reacts to the supplied data type.

- structs are usually JSON encoded
- `[]byte` values and strings stay the same

This makes it easier and more efficient to store different data types in a key/value data store.
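An illustrative sketch only, based on how `dsd.Dump` is used in base/api/client/message.go in this changeset; the `dsd.Load` counterpart and its exact signature are assumptions to verify against the package.

    package main

    import (
        "fmt"

        "github.com/safing/structures/dsd"
    )

    type Profile struct {
        Name  string
        Score int
    }

    func main() {
        // Dump serializes the struct and prefixes the payload with the chosen format (JSON here).
        data, err := dsd.Dump(&Profile{Name: "test", Score: 42}, dsd.JSON)
        if err != nil {
            panic(err)
        }
        fmt.Printf("%s\n", data)

        // Assumed counterpart: Load detects the format from the prefix and
        // unmarshals into the given target.
        loaded := &Profile{}
        if _, err := dsd.Load(data, loaded); err != nil {
            panic(err)
        }
        fmt.Println(loaded.Name, loaded.Score)
    }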
## rng <small>requires `log`, `config`</small>

This package provides a CSPRNG based on the [Fortuna](https://en.wikipedia.org/wiki/Fortuna_(PRNG)) CSPRNG devised by Bruce Schneier and Niels Ferguson, as implemented by Jochen Voss and published [on GitHub](https://github.com/seehuhn/fortuna).

Only the Generator is used from the `fortuna` package. The feeding system implemented here is configurable and designed with efficiency in mind.

While you can feed the RNG yourself, it has two feeders by default:
- It starts with a seed from `crypto/rand` and periodically reseeds from there.
- A really simple tickfeeder extracts entropy from the internal Go scheduler using goroutines and is meant to be used under load.
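Illustration only (not part of the README): fetching random bytes, mirroring how `rng.Bytes` is used for session keys in base/api/authentication.go in this changeset. This assumes the rng module has already been started through the module system.

    package main

    import (
        "encoding/base64"
        "fmt"

        "github.com/safing/portmaster/base/rng"
    )

    func main() {
        // Fetch 32 bytes (256 bit) from the CSPRNG and print them base64-encoded.
        secret, err := rng.Bytes(32)
        if err != nil {
            panic(err)
        }
        fmt.Println(base64.RawURLEncoding.EncodeToString(secret))
    }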
## database <small>requires `log`</small>

_introduction to be written_

## api <small>requires `log`, `database`, `config`</small>

_introduction to be written_

## The main program

If you build everything with modules, your main program should be similar to this; just use an empty import for the modules you need:

    import (
      "os"
      "os/signal"
      "syscall"

      "github.com/safing/portmaster/base/info"
      "github.com/safing/portmaster/base/log"
      "github.com/safing/portmaster/base/modules"

      // include packages here
      _ "path/to/my/custom/module"
    )

    func main() {
      // Set Info
      info.Set("MySoftware", "1.0.0")

      // Start
      err := modules.Start()
      if err != nil {
        if err == modules.ErrCleanExit {
          os.Exit(0)
        } else {
          os.Exit(1)
        }
      }

      // Shutdown
      // catch interrupt for clean shutdown
      signalCh := make(chan os.Signal)
      signal.Notify(
        signalCh,
        os.Interrupt,
        syscall.SIGHUP,
        syscall.SIGINT,
        syscall.SIGTERM,
        syscall.SIGQUIT,
      )
      select {
      case <-signalCh:
        log.Warning("main: program was interrupted")
        modules.Shutdown()
      case <-modules.ShuttingDown():
      }
    }
base/api/api_bridge.go (new file, +173)
@@ -0,0 +1,173 @@
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/safing/portmaster/base/database"
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
"github.com/safing/portmaster/base/database/storage"
|
||||
)
|
||||
|
||||
const (
|
||||
endpointBridgeRemoteAddress = "websocket-bridge"
|
||||
apiDatabaseName = "api"
|
||||
)
|
||||
|
||||
func registerEndpointBridgeDB() error {
|
||||
if _, err := database.Register(&database.Database{
|
||||
Name: apiDatabaseName,
|
||||
Description: "API Bridge",
|
||||
StorageType: "injected",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := database.InjectDatabase("api", &endpointBridgeStorage{})
|
||||
return err
|
||||
}
|
||||
|
||||
type endpointBridgeStorage struct {
|
||||
storage.InjectBase
|
||||
}
|
||||
|
||||
// EndpointBridgeRequest holds a bridged request API request.
|
||||
type EndpointBridgeRequest struct {
|
||||
record.Base
|
||||
sync.Mutex
|
||||
|
||||
Method string
|
||||
Path string
|
||||
Query map[string]string
|
||||
Data []byte
|
||||
MimeType string
|
||||
}
|
||||
|
||||
// EndpointBridgeResponse holds a bridged request API response.
|
||||
type EndpointBridgeResponse struct {
|
||||
record.Base
|
||||
sync.Mutex
|
||||
|
||||
MimeType string
|
||||
Body string
|
||||
}
|
||||
|
||||
// Get returns a database record.
|
||||
func (ebs *endpointBridgeStorage) Get(key string) (record.Record, error) {
|
||||
if key == "" {
|
||||
return nil, database.ErrNotFound
|
||||
}
|
||||
|
||||
return callAPI(&EndpointBridgeRequest{
|
||||
Method: http.MethodGet,
|
||||
Path: key,
|
||||
})
|
||||
}
|
||||
|
||||
// Get returns the metadata of a database record.
|
||||
func (ebs *endpointBridgeStorage) GetMeta(key string) (*record.Meta, error) {
|
||||
// This interface is an API, always return a fresh copy.
|
||||
m := &record.Meta{}
|
||||
m.Update()
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Put stores a record in the database.
|
||||
func (ebs *endpointBridgeStorage) Put(r record.Record) (record.Record, error) {
|
||||
if r.DatabaseKey() == "" {
|
||||
return nil, database.ErrNotFound
|
||||
}
|
||||
|
||||
// Prepare data.
|
||||
var ebr *EndpointBridgeRequest
|
||||
if r.IsWrapped() {
|
||||
// Only allocate a new struct, if we need it.
|
||||
ebr = &EndpointBridgeRequest{}
|
||||
err := record.Unwrap(r, ebr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
var ok bool
|
||||
ebr, ok = r.(*EndpointBridgeRequest)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("record not of type *EndpointBridgeRequest, but %T", r)
|
||||
}
|
||||
}
|
||||
|
||||
// Override path with key to mitigate sneaky stuff.
|
||||
ebr.Path = r.DatabaseKey()
|
||||
return callAPI(ebr)
|
||||
}
|
||||
|
||||
// ReadOnly returns whether the database is read only.
|
||||
func (ebs *endpointBridgeStorage) ReadOnly() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func callAPI(ebr *EndpointBridgeRequest) (record.Record, error) {
|
||||
// Add API prefix to path.
|
||||
requestURL := path.Join(apiV1Path, ebr.Path)
|
||||
// Check if path is correct. (Defense in depth)
|
||||
if !strings.HasPrefix(requestURL, apiV1Path) {
|
||||
return nil, fmt.Errorf("bridged request for %q violates scope", ebr.Path)
|
||||
}
|
||||
|
||||
// Apply default Method.
|
||||
if ebr.Method == "" {
|
||||
if len(ebr.Data) > 0 {
|
||||
ebr.Method = http.MethodPost
|
||||
} else {
|
||||
ebr.Method = http.MethodGet
|
||||
}
|
||||
}
|
||||
|
||||
// Build URL.
|
||||
u, err := url.ParseRequestURI(requestURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build bridged request url: %w", err)
|
||||
}
|
||||
// Build query values.
|
||||
if ebr.Query != nil && len(ebr.Query) > 0 {
|
||||
query := url.Values{}
|
||||
for k, v := range ebr.Query {
|
||||
query.Set(k, v)
|
||||
}
|
||||
u.RawQuery = query.Encode()
|
||||
}
|
||||
|
||||
// Create request and response objects.
|
||||
r := httptest.NewRequest(ebr.Method, u.String(), bytes.NewBuffer(ebr.Data))
|
||||
r.RemoteAddr = endpointBridgeRemoteAddress
|
||||
if ebr.MimeType != "" {
|
||||
r.Header.Set("Content-Type", ebr.MimeType)
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
// Let the API handle the request.
|
||||
server.Handler.ServeHTTP(w, r)
|
||||
switch w.Code {
|
||||
case 200:
|
||||
// Everything okay, continue.
|
||||
case 500:
|
||||
// A Go error was returned internally.
|
||||
// We can safely return this as an error.
|
||||
return nil, fmt.Errorf("bridged api call failed: %s", w.Body.String())
|
||||
default:
|
||||
return nil, fmt.Errorf("bridged api call returned unexpected error code %d", w.Code)
|
||||
}
|
||||
|
||||
response := &EndpointBridgeResponse{
|
||||
MimeType: w.Header().Get("Content-Type"),
|
||||
Body: w.Body.String(),
|
||||
}
|
||||
response.SetKey(apiDatabaseName + ":" + ebr.Path)
|
||||
response.UpdateMeta()
|
||||
|
||||
return response, nil
|
||||
}
base/api/auth_wrapper.go (new file, +30)
@@ -0,0 +1,30 @@
package api

import "net/http"

// WrapInAuthHandler wraps a simple http.HandlerFunc into a handler that
// exposes the required API permissions for this handler.
func WrapInAuthHandler(fn http.HandlerFunc, read, write Permission) http.Handler {
    return &wrappedAuthenticatedHandler{
        HandlerFunc: fn,
        read:        read,
        write:       write,
    }
}

type wrappedAuthenticatedHandler struct {
    http.HandlerFunc

    read  Permission
    write Permission
}

// ReadPermission returns the read permission for the handler.
func (wah *wrappedAuthenticatedHandler) ReadPermission(r *http.Request) Permission {
    return wah.read
}

// WritePermission returns the write permission for the handler.
func (wah *wrappedAuthenticatedHandler) WritePermission(r *http.Request) Permission {
    return wah.write
}
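Illustration only (not part of this changeset): registering an endpoint with explicit permissions via WrapInAuthHandler. This assumes the api package's RegisterHandler, which appears in the tests further down; the path and the handler are hypothetical.

    package myservice

    import (
        "net/http"

        "github.com/safing/portmaster/base/api"
    )

    // statusHandler is a hypothetical plain http.HandlerFunc.
    func statusHandler(w http.ResponseWriter, r *http.Request) {
        _, _ = w.Write([]byte("ok"))
    }

    func registerStatusEndpoint() {
        // Anyone may read this endpoint, but writing requires admin permissions.
        api.RegisterHandler("/my/status", api.WrapInAuthHandler(statusHandler, api.PermitAnyone, api.PermitAdmin))
    }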
base/api/authentication.go (new file, +598)
@@ -0,0 +1,598 @@
package api
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portmaster/base/config"
|
||||
"github.com/safing/portmaster/base/log"
|
||||
"github.com/safing/portmaster/base/rng"
|
||||
"github.com/safing/portmaster/service/mgr"
|
||||
)
|
||||
|
||||
const (
|
||||
sessionCookieName = "Portmaster-API-Token"
|
||||
sessionCookieTTL = 5 * time.Minute
|
||||
)
|
||||
|
||||
var (
|
||||
apiKeys = make(map[string]*AuthToken)
|
||||
apiKeysLock sync.Mutex
|
||||
|
||||
authFnSet = abool.New()
|
||||
authFn AuthenticatorFunc
|
||||
|
||||
sessions = make(map[string]*session)
|
||||
sessionsLock sync.Mutex
|
||||
|
||||
// ErrAPIAccessDeniedMessage should be wrapped by errors returned by
|
||||
// AuthenticatorFunc in order to signify a blocked request, including a error
|
||||
// message for the user. This is an empty message on purpose, as to allow the
|
||||
// function to define the full text of the error shown to the user.
|
||||
ErrAPIAccessDeniedMessage = errors.New("")
|
||||
)
|
||||
|
||||
// Permission defines an API requests permission.
|
||||
type Permission int8
|
||||
|
||||
const (
|
||||
// NotFound declares that the operation does not exist.
|
||||
NotFound Permission = -2
|
||||
|
||||
// Dynamic declares that the operation requires permission to be processed,
|
||||
// but anyone can execute the operation, as it reacts to permissions itself.
|
||||
Dynamic Permission = -1
|
||||
|
||||
// NotSupported declares that the operation is not supported.
|
||||
NotSupported Permission = 0
|
||||
|
||||
// PermitAnyone declares that anyone can execute the operation without any
|
||||
// authentication.
|
||||
PermitAnyone Permission = 1
|
||||
|
||||
// PermitUser declares that the operation may be executed by authenticated
|
||||
// third party applications that are categorized as representing a simple
|
||||
// user and is limited in access.
|
||||
PermitUser Permission = 2
|
||||
|
||||
// PermitAdmin declares that the operation may be executed by authenticated
|
||||
// third party applications that are categorized as representing an
|
||||
// administrator and has broad in access.
|
||||
PermitAdmin Permission = 3
|
||||
|
||||
// PermitSelf declares that the operation may only be executed by the
|
||||
// software itself and its own (first party) components.
|
||||
PermitSelf Permission = 4
|
||||
)
|
||||
|
||||
// AuthenticatorFunc is a function that can be set as the authenticator for the
|
||||
// API endpoint. If none is set, all requests will have full access.
|
||||
// The returned AuthToken represents the permissions that the request has.
|
||||
type AuthenticatorFunc func(r *http.Request, s *http.Server) (*AuthToken, error)
|
||||
|
||||
// AuthToken represents either a set of required or granted permissions.
|
||||
// All attributes must be set when the struct is built and must not be changed
|
||||
// later. Functions may be called at any time.
|
||||
// The Write permission implicitly also includes reading.
|
||||
type AuthToken struct {
|
||||
Read Permission
|
||||
Write Permission
|
||||
ValidUntil *time.Time
|
||||
}
|
||||
|
||||
type session struct {
|
||||
sync.Mutex
|
||||
|
||||
token *AuthToken
|
||||
validUntil time.Time
|
||||
}
|
||||
|
||||
// Expired returns whether the session has expired.
|
||||
func (sess *session) Expired() bool {
|
||||
sess.Lock()
|
||||
defer sess.Unlock()
|
||||
|
||||
return time.Now().After(sess.validUntil)
|
||||
}
|
||||
|
||||
// Refresh refreshes the validity of the session with the given TTL.
|
||||
func (sess *session) Refresh(ttl time.Duration) {
|
||||
sess.Lock()
|
||||
defer sess.Unlock()
|
||||
|
||||
sess.validUntil = time.Now().Add(ttl)
|
||||
}
|
||||
|
||||
// AuthenticatedHandler defines the handler interface to specify custom
|
||||
// permission for an API handler. The returned permission is the required
|
||||
// permission for the request to proceed.
|
||||
type AuthenticatedHandler interface {
|
||||
ReadPermission(r *http.Request) Permission
|
||||
WritePermission(r *http.Request) Permission
|
||||
}
|
||||
|
||||
// SetAuthenticator sets an authenticator function for the API endpoint. If none is set, all requests will be permitted.
|
||||
func SetAuthenticator(fn AuthenticatorFunc) error {
|
||||
if module.online.Load() {
|
||||
return ErrAuthenticationImmutable
|
||||
}
|
||||
|
||||
if !authFnSet.SetToIf(false, true) {
|
||||
return ErrAuthenticationAlreadySet
|
||||
}
|
||||
|
||||
authFn = fn
|
||||
return nil
|
||||
}
|
||||
|
||||
func authenticateRequest(w http.ResponseWriter, r *http.Request, targetHandler http.Handler, readMethod bool) *AuthToken {
|
||||
tracer := log.Tracer(r.Context())
|
||||
|
||||
// Get required permission for target handler.
|
||||
requiredPermission := PermitSelf
|
||||
if authdHandler, ok := targetHandler.(AuthenticatedHandler); ok {
|
||||
if readMethod {
|
||||
requiredPermission = authdHandler.ReadPermission(r)
|
||||
} else {
|
||||
requiredPermission = authdHandler.WritePermission(r)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if we need to do any authentication at all.
|
||||
switch requiredPermission { //nolint:exhaustive
|
||||
case NotFound:
|
||||
// Not found.
|
||||
tracer.Debug("api: no API endpoint registered for this path")
|
||||
http.Error(w, "Not found.", http.StatusNotFound)
|
||||
return nil
|
||||
case NotSupported:
|
||||
// A read or write permission can be marked as not supported.
|
||||
tracer.Trace("api: authenticated handler reported: not supported")
|
||||
http.Error(w, "Method not allowed.", http.StatusMethodNotAllowed)
|
||||
return nil
|
||||
case PermitAnyone:
|
||||
// Don't process permissions, as we don't need them.
|
||||
tracer.Tracef("api: granted %s access to public handler", r.RemoteAddr)
|
||||
return &AuthToken{
|
||||
Read: PermitAnyone,
|
||||
Write: PermitAnyone,
|
||||
}
|
||||
case Dynamic:
|
||||
// Continue processing permissions, but treat as PermitAnyone.
|
||||
requiredPermission = PermitAnyone
|
||||
}
|
||||
|
||||
// The required permission must match the request permission values after
|
||||
// handling the specials.
|
||||
if requiredPermission < PermitAnyone || requiredPermission > PermitSelf {
|
||||
tracer.Warningf(
|
||||
"api: handler returned invalid permission: %s (%d)",
|
||||
requiredPermission,
|
||||
requiredPermission,
|
||||
)
|
||||
http.Error(w, "Internal server error during authentication.", http.StatusInternalServerError)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Authenticate request.
|
||||
token, handled := checkAuth(w, r, requiredPermission > PermitAnyone)
|
||||
switch {
|
||||
case handled:
|
||||
return nil
|
||||
case token == nil:
|
||||
// Use default permissions.
|
||||
token = &AuthToken{
|
||||
Read: PermitAnyone,
|
||||
Write: PermitAnyone,
|
||||
}
|
||||
}
|
||||
|
||||
// Get effective permission for request.
|
||||
var requestPermission Permission
|
||||
if readMethod {
|
||||
requestPermission = token.Read
|
||||
} else {
|
||||
requestPermission = token.Write
|
||||
}
|
||||
|
||||
// Check for valid request permission.
|
||||
if requestPermission < PermitAnyone || requestPermission > PermitSelf {
|
||||
tracer.Warningf(
|
||||
"api: authenticator returned invalid permission: %s (%d)",
|
||||
requestPermission,
|
||||
requestPermission,
|
||||
)
|
||||
http.Error(w, "Internal server error during authentication.", http.StatusInternalServerError)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check permission.
|
||||
if requestPermission < requiredPermission {
|
||||
// If the token is strictly public, return an authentication request.
|
||||
if token.Read == PermitAnyone && token.Write == PermitAnyone {
|
||||
w.Header().Set(
|
||||
"WWW-Authenticate",
|
||||
`Bearer realm="Portmaster API" domain="/"`,
|
||||
)
|
||||
http.Error(w, "Authorization required.", http.StatusUnauthorized)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Otherwise just inform of insufficient permissions.
|
||||
http.Error(w, "Insufficient permissions.", http.StatusForbidden)
|
||||
return nil
|
||||
}
|
||||
|
||||
tracer.Tracef("api: granted %s access to protected handler", r.RemoteAddr)
|
||||
|
||||
// Make a copy of the AuthToken in order mitigate the handler poisoning the
|
||||
// token, as changes would apply to future requests.
|
||||
return &AuthToken{
|
||||
Read: token.Read,
|
||||
Write: token.Write,
|
||||
}
|
||||
}
|
||||
|
||||
func checkAuth(w http.ResponseWriter, r *http.Request, authRequired bool) (token *AuthToken, handled bool) {
|
||||
// Return highest possible permissions in dev mode.
|
||||
if devMode() {
|
||||
return &AuthToken{
|
||||
Read: PermitSelf,
|
||||
Write: PermitSelf,
|
||||
}, false
|
||||
}
|
||||
|
||||
// Database Bridge Access.
|
||||
if r.RemoteAddr == endpointBridgeRemoteAddress {
|
||||
return &AuthToken{
|
||||
Read: dbCompatibilityPermission,
|
||||
Write: dbCompatibilityPermission,
|
||||
}, false
|
||||
}
|
||||
|
||||
// Check for valid API key.
|
||||
token = checkAPIKey(r)
|
||||
if token != nil {
|
||||
return token, false
|
||||
}
|
||||
|
||||
// Check for valid session cookie.
|
||||
token = checkSessionCookie(r)
|
||||
if token != nil {
|
||||
return token, false
|
||||
}
|
||||
|
||||
// Check if an external authentication method is available.
|
||||
if !authFnSet.IsSet() {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Authenticate externally.
|
||||
token, err := authFn(r, server)
|
||||
if err != nil {
|
||||
// Check if the authentication process failed internally.
|
||||
if !errors.Is(err, ErrAPIAccessDeniedMessage) {
|
||||
log.Tracer(r.Context()).Errorf("api: authenticator failed: %s", err)
|
||||
http.Error(w, "Internal server error during authentication.", http.StatusInternalServerError)
|
||||
return nil, true
|
||||
}
|
||||
|
||||
// Return authentication failure message if authentication is required.
|
||||
if authRequired {
|
||||
log.Tracer(r.Context()).Warningf("api: denying api access from %s", r.RemoteAddr)
|
||||
http.Error(w, err.Error(), http.StatusForbidden)
|
||||
return nil, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Abort if no token is returned.
|
||||
if token == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Create session cookie for authenticated request.
|
||||
err = createSession(w, r, token)
|
||||
if err != nil {
|
||||
log.Tracer(r.Context()).Warningf("api: failed to create session: %s", err)
|
||||
}
|
||||
return token, false
|
||||
}
|
||||
|
||||
func checkAPIKey(r *http.Request) *AuthToken {
|
||||
// Get API key from request.
|
||||
key := r.Header.Get("Authorization")
|
||||
if key == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse API key.
|
||||
switch {
|
||||
case strings.HasPrefix(key, "Bearer "):
|
||||
key = strings.TrimPrefix(key, "Bearer ")
|
||||
case strings.HasPrefix(key, "Basic "):
|
||||
user, pass, _ := r.BasicAuth()
|
||||
key = user + pass
|
||||
default:
|
||||
log.Tracer(r.Context()).Tracef(
|
||||
"api: provided api key type %s is unsupported", strings.Split(key, " ")[0],
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
apiKeysLock.Lock()
|
||||
defer apiKeysLock.Unlock()
|
||||
|
||||
// Check if the provided API key exists.
|
||||
token, ok := apiKeys[key]
|
||||
if !ok {
|
||||
log.Tracer(r.Context()).Tracef(
|
||||
"api: provided api key %s... is unknown", key[:4],
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Abort if the token is expired.
|
||||
if token.ValidUntil != nil && time.Now().After(*token.ValidUntil) {
|
||||
log.Tracer(r.Context()).Warningf("api: denying api access from %s using expired token", r.RemoteAddr)
|
||||
return nil
|
||||
}
|
||||
|
||||
return token
|
||||
}
|
||||
|
||||
func updateAPIKeys() {
|
||||
apiKeysLock.Lock()
|
||||
defer apiKeysLock.Unlock()
|
||||
|
||||
log.Debug("api: importing possibly updated API keys from config")
|
||||
|
||||
// Delete current keys.
|
||||
for k := range apiKeys {
|
||||
delete(apiKeys, k)
|
||||
}
|
||||
|
||||
// whether or not we found expired API keys that should be removed
|
||||
// from the setting
|
||||
hasExpiredKeys := false
|
||||
|
||||
// a list of valid API keys. Used when hasExpiredKeys is set to true.
|
||||
// in that case we'll update the setting to only contain validAPIKeys
|
||||
validAPIKeys := []string{}
|
||||
|
||||
// Parse new keys.
|
||||
for _, key := range configuredAPIKeys() {
|
||||
u, err := url.Parse(key)
|
||||
if err != nil {
|
||||
log.Errorf("api: failed to parse configured API key %s: %s", key, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if u.Path == "" {
|
||||
log.Errorf("api: malformed API key %s: missing path section", key)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// Create token with default permissions.
|
||||
token := &AuthToken{
|
||||
Read: PermitAnyone,
|
||||
Write: PermitAnyone,
|
||||
}
|
||||
|
||||
// Update with configured permissions.
|
||||
q := u.Query()
|
||||
// Parse read permission.
|
||||
readPermission, err := parseAPIPermission(q.Get("read"))
|
||||
if err != nil {
|
||||
log.Errorf("api: invalid API key %s: %s", key, err)
|
||||
continue
|
||||
}
|
||||
token.Read = readPermission
|
||||
// Parse write permission.
|
||||
writePermission, err := parseAPIPermission(q.Get("write"))
|
||||
if err != nil {
|
||||
log.Errorf("api: invalid API key %s: %s", key, err)
|
||||
continue
|
||||
}
|
||||
token.Write = writePermission
|
||||
|
||||
expireStr := q.Get("expires")
|
||||
if expireStr != "" {
|
||||
validUntil, err := time.Parse(time.RFC3339, expireStr)
|
||||
if err != nil {
|
||||
log.Errorf("api: invalid API key %s: %s", key, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// continue to the next token if this one is already invalid
|
||||
if time.Now().After(validUntil) {
|
||||
// mark the key as expired so we'll remove it from the setting afterwards
|
||||
hasExpiredKeys = true
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
token.ValidUntil = &validUntil
|
||||
}
|
||||
|
||||
// Save token.
|
||||
apiKeys[u.Path] = token
|
||||
validAPIKeys = append(validAPIKeys, key)
|
||||
}
|
||||
|
||||
if hasExpiredKeys {
|
||||
module.mgr.Go("api key cleanup", func(ctx *mgr.WorkerCtx) error {
|
||||
if err := config.SetConfigOption(CfgAPIKeys, validAPIKeys); err != nil {
|
||||
log.Errorf("api: failed to remove expired API keys: %s", err)
|
||||
} else {
|
||||
log.Infof("api: removed expired API keys from %s", CfgAPIKeys)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func checkSessionCookie(r *http.Request) *AuthToken {
|
||||
// Get session cookie from request.
|
||||
c, err := r.Cookie(sessionCookieName)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if session cookie is registered.
|
||||
sessionsLock.Lock()
|
||||
sess, ok := sessions[c.Value]
|
||||
sessionsLock.Unlock()
|
||||
if !ok {
|
||||
log.Tracer(r.Context()).Tracef("api: provided session cookie %s is unknown", c.Value)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if session is still valid.
|
||||
if sess.Expired() {
|
||||
log.Tracer(r.Context()).Tracef("api: provided session cookie %s has expired", c.Value)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Refresh session and return.
|
||||
sess.Refresh(sessionCookieTTL)
|
||||
log.Tracer(r.Context()).Tracef("api: session cookie %s is valid, refreshing", c.Value)
|
||||
return sess.token
|
||||
}
|
||||
|
||||
func createSession(w http.ResponseWriter, r *http.Request, token *AuthToken) error {
|
||||
// Generate new session key.
|
||||
secret, err := rng.Bytes(32) // 256 bit
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sessionKey := base64.RawURLEncoding.EncodeToString(secret)
|
||||
|
||||
// Set token cookie in response.
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: sessionCookieName,
|
||||
Value: sessionKey,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
})
|
||||
|
||||
// Create session.
|
||||
sess := &session{
|
||||
token: token,
|
||||
}
|
||||
sess.Refresh(sessionCookieTTL)
|
||||
|
||||
// Save session.
|
||||
sessionsLock.Lock()
|
||||
defer sessionsLock.Unlock()
|
||||
sessions[sessionKey] = sess
|
||||
log.Tracer(r.Context()).Debug("api: issued session cookie")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func cleanSessions(_ *mgr.WorkerCtx) error {
|
||||
sessionsLock.Lock()
|
||||
defer sessionsLock.Unlock()
|
||||
|
||||
for sessionKey, sess := range sessions {
|
||||
if sess.Expired() {
|
||||
delete(sessions, sessionKey)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func deleteSession(sessionKey string) {
|
||||
sessionsLock.Lock()
|
||||
defer sessionsLock.Unlock()
|
||||
|
||||
delete(sessions, sessionKey)
|
||||
}
|
||||
|
||||
func getEffectiveMethod(r *http.Request) (eMethod string, readMethod bool, ok bool) {
|
||||
method := r.Method
|
||||
|
||||
// Get CORS request method if OPTIONS request.
|
||||
if r.Method == http.MethodOptions {
|
||||
method = r.Header.Get("Access-Control-Request-Method")
|
||||
if method == "" {
|
||||
return "", false, false
|
||||
}
|
||||
}
|
||||
|
||||
switch method {
|
||||
case http.MethodGet, http.MethodHead:
|
||||
return http.MethodGet, true, true
|
||||
case http.MethodPost, http.MethodPut, http.MethodDelete:
|
||||
return method, false, true
|
||||
default:
|
||||
return "", false, false
|
||||
}
|
||||
}
|
||||
|
||||
func parseAPIPermission(s string) (Permission, error) {
|
||||
switch strings.ToLower(s) {
|
||||
case "", "anyone":
|
||||
return PermitAnyone, nil
|
||||
case "user":
|
||||
return PermitUser, nil
|
||||
case "admin":
|
||||
return PermitAdmin, nil
|
||||
default:
|
||||
return PermitAnyone, fmt.Errorf("invalid permission: %s", s)
|
||||
}
|
||||
}
|
||||
|
||||
func (p Permission) String() string {
|
||||
switch p {
|
||||
case NotSupported:
|
||||
return "NotSupported"
|
||||
case Dynamic:
|
||||
return "Dynamic"
|
||||
case PermitAnyone:
|
||||
return "PermitAnyone"
|
||||
case PermitUser:
|
||||
return "PermitUser"
|
||||
case PermitAdmin:
|
||||
return "PermitAdmin"
|
||||
case PermitSelf:
|
||||
return "PermitSelf"
|
||||
case NotFound:
|
||||
return "NotFound"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// Role returns a string representation of the permission role.
|
||||
func (p Permission) Role() string {
|
||||
switch p {
|
||||
case PermitAnyone:
|
||||
return "Anyone"
|
||||
case PermitUser:
|
||||
return "User"
|
||||
case PermitAdmin:
|
||||
return "Admin"
|
||||
case PermitSelf:
|
||||
return "Self"
|
||||
case Dynamic, NotFound, NotSupported:
|
||||
return "Invalid"
|
||||
default:
|
||||
return "Invalid"
|
||||
}
|
||||
}
base/api/authentication_test.go (new file, +186)
@@ -0,0 +1,186 @@
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var testToken = new(AuthToken)
|
||||
|
||||
func testAuthenticator(r *http.Request, s *http.Server) (*AuthToken, error) {
|
||||
switch {
|
||||
case testToken.Read == -127 || testToken.Write == -127:
|
||||
return nil, errors.New("test error")
|
||||
case testToken.Read == -128 || testToken.Write == -128:
|
||||
return nil, fmt.Errorf("%wdenied", ErrAPIAccessDeniedMessage)
|
||||
default:
|
||||
return testToken, nil
|
||||
}
|
||||
}
|
||||
|
||||
type testAuthHandler struct {
|
||||
Read Permission
|
||||
Write Permission
|
||||
}
|
||||
|
||||
func (ah *testAuthHandler) ReadPermission(r *http.Request) Permission {
|
||||
return ah.Read
|
||||
}
|
||||
|
||||
func (ah *testAuthHandler) WritePermission(r *http.Request) Permission {
|
||||
return ah.Write
|
||||
}
|
||||
|
||||
func (ah *testAuthHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Check if request is as expected.
|
||||
ar := GetAPIRequest(r)
|
||||
switch {
|
||||
case ar == nil:
|
||||
http.Error(w, "ar == nil", http.StatusInternalServerError)
|
||||
case ar.AuthToken == nil:
|
||||
http.Error(w, "ar.AuthToken == nil", http.StatusInternalServerError)
|
||||
default:
|
||||
http.Error(w, "auth success", http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func makeAuthTestPath(reading bool, p Permission) string {
|
||||
if reading {
|
||||
return fmt.Sprintf("/test/auth/read/%s", p)
|
||||
}
|
||||
return fmt.Sprintf("/test/auth/write/%s", p)
|
||||
}
|
||||
|
||||
func TestPermissions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testHandler := &mainHandler{
|
||||
mux: mainMux,
|
||||
}
|
||||
|
||||
// Define permissions that need testing.
|
||||
permissionsToTest := []Permission{
|
||||
NotSupported,
|
||||
PermitAnyone,
|
||||
PermitUser,
|
||||
PermitAdmin,
|
||||
PermitSelf,
|
||||
Dynamic,
|
||||
NotFound,
|
||||
100, // Test a too high value.
|
||||
-100, // Test a too low value.
|
||||
-127, // Simulate authenticator failure.
|
||||
-128, // Simulate authentication denied message.
|
||||
}
|
||||
|
||||
// Register test handlers.
|
||||
for _, p := range permissionsToTest {
|
||||
RegisterHandler(makeAuthTestPath(true, p), &testAuthHandler{Read: p})
|
||||
RegisterHandler(makeAuthTestPath(false, p), &testAuthHandler{Write: p})
|
||||
}
|
||||
|
||||
// Test all the combinations.
|
||||
for _, requestPerm := range permissionsToTest {
|
||||
for _, handlerPerm := range permissionsToTest {
|
||||
for _, method := range []string{
|
||||
http.MethodGet,
|
||||
http.MethodHead,
|
||||
http.MethodPost,
|
||||
http.MethodPut,
|
||||
http.MethodDelete,
|
||||
} {
|
||||
|
||||
// Set request permission for test requests.
|
||||
_, reading, _ := getEffectiveMethod(&http.Request{Method: method})
|
||||
if reading {
|
||||
testToken.Read = requestPerm
|
||||
testToken.Write = NotSupported
|
||||
} else {
|
||||
testToken.Read = NotSupported
|
||||
testToken.Write = requestPerm
|
||||
}
|
||||
|
||||
// Evaluate expected result.
|
||||
var expectSuccess bool
|
||||
switch {
|
||||
case handlerPerm == PermitAnyone:
|
||||
// This is fast-tracked. There are not additional checks.
|
||||
expectSuccess = true
|
||||
case handlerPerm == Dynamic:
|
||||
// This is turned into PermitAnyone in the authenticator.
|
||||
// But authentication is still processed and the result still gets
|
||||
// sanity checked!
|
||||
if requestPerm >= PermitAnyone &&
|
||||
requestPerm <= PermitSelf {
|
||||
expectSuccess = true
|
||||
}
|
||||
// Another special case is when the handler requires permission to be
|
||||
// processed but the authenticator fails to authenticate the request.
|
||||
// In this case, a fallback token with PermitAnyone is used.
|
||||
if requestPerm == -128 {
|
||||
// -128 is used to simulate a permission denied message.
|
||||
expectSuccess = true
|
||||
}
|
||||
case handlerPerm <= NotSupported:
|
||||
// Invalid handler permission.
|
||||
case handlerPerm > PermitSelf:
|
||||
// Invalid handler permission.
|
||||
case requestPerm <= NotSupported:
|
||||
// Invalid request permission.
|
||||
case requestPerm > PermitSelf:
|
||||
// Invalid request permission.
|
||||
case requestPerm < handlerPerm:
|
||||
// Valid, but insufficient request permission.
|
||||
default:
|
||||
expectSuccess = true
|
||||
}
|
||||
|
||||
if expectSuccess {
|
||||
// Test for success.
|
||||
if !assert.HTTPBodyContains(
|
||||
t,
|
||||
testHandler.ServeHTTP,
|
||||
method,
|
||||
makeAuthTestPath(reading, handlerPerm),
|
||||
nil,
|
||||
"auth success",
|
||||
) {
|
||||
t.Errorf(
|
||||
"%s with %s (%d) to handler %s (%d)",
|
||||
method,
|
||||
requestPerm, requestPerm,
|
||||
handlerPerm, handlerPerm,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
// Test for error.
|
||||
if !assert.HTTPError(t,
|
||||
testHandler.ServeHTTP,
|
||||
method,
|
||||
makeAuthTestPath(reading, handlerPerm),
|
||||
nil,
|
||||
) {
|
||||
t.Errorf(
|
||||
"%s with %s (%d) to handler %s (%d)",
|
||||
method,
|
||||
requestPerm, requestPerm,
|
||||
handlerPerm, handlerPerm,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPermissionDefinitions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if NotSupported != 0 {
|
||||
t.Fatalf("NotSupported must be zero, was %v", NotSupported)
|
||||
}
|
||||
}
base/api/client/api.go (new file, +57)
@@ -0,0 +1,57 @@
package client

// Get sends a get command to the API.
func (c *Client) Get(key string, handleFunc func(*Message)) *Operation {
    op := c.NewOperation(handleFunc)
    op.Send(msgRequestGet, key, nil)
    return op
}

// Query sends a query command to the API.
func (c *Client) Query(query string, handleFunc func(*Message)) *Operation {
    op := c.NewOperation(handleFunc)
    op.Send(msgRequestQuery, query, nil)
    return op
}

// Sub sends a sub command to the API.
func (c *Client) Sub(query string, handleFunc func(*Message)) *Operation {
    op := c.NewOperation(handleFunc)
    op.Send(msgRequestSub, query, nil)
    return op
}

// Qsub sends a qsub command to the API.
func (c *Client) Qsub(query string, handleFunc func(*Message)) *Operation {
    op := c.NewOperation(handleFunc)
    op.Send(msgRequestQsub, query, nil)
    return op
}

// Create sends a create command to the API.
func (c *Client) Create(key string, value interface{}, handleFunc func(*Message)) *Operation {
    op := c.NewOperation(handleFunc)
    op.Send(msgRequestCreate, key, value)
    return op
}

// Update sends an update command to the API.
func (c *Client) Update(key string, value interface{}, handleFunc func(*Message)) *Operation {
    op := c.NewOperation(handleFunc)
    op.Send(msgRequestUpdate, key, value)
    return op
}

// Insert sends an insert command to the API.
func (c *Client) Insert(key string, value interface{}, handleFunc func(*Message)) *Operation {
    op := c.NewOperation(handleFunc)
    op.Send(msgRequestInsert, key, value)
    return op
}

// Delete sends a delete command to the API.
func (c *Client) Delete(key string, handleFunc func(*Message)) *Operation {
    op := c.NewOperation(handleFunc)
    op.Send(msgRequestDelete, key, nil)
    return op
}
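Illustration only (not part of this changeset): a minimal sketch of driving this client from another program. The API address and the subscription query are placeholders; the exact query syntax and address depend on the running Portmaster instance.

    package main

    import (
        "time"

        "github.com/safing/portmaster/base/api/client"
        "github.com/safing/portmaster/base/log"
    )

    func main() {
        // Connect to a (placeholder) API address and keep reconnecting on failure.
        c := client.NewClient("127.0.0.1:817")
        go c.StayConnected()
        defer c.Shutdown()

        // Subscribe with a placeholder query; updates arrive via the callback.
        op := c.Sub("example:records/", func(m *client.Message) {
            log.Infof("client example: %s %s", m.Type, m.Key)
        })
        defer op.Cancel()

        time.Sleep(10 * time.Second)
    }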
base/api/client/client.go (new file, +240)
@@ -0,0 +1,240 @@
package client
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portmaster/base/log"
|
||||
)
|
||||
|
||||
const (
|
||||
backOffTimer = 1 * time.Second
|
||||
|
||||
offlineSignal uint8 = 0
|
||||
onlineSignal uint8 = 1
|
||||
)
|
||||
|
||||
// The Client enables easy interaction with the API.
|
||||
type Client struct {
|
||||
sync.Mutex
|
||||
|
||||
server string
|
||||
|
||||
onlineSignal chan struct{}
|
||||
offlineSignal chan struct{}
|
||||
shutdownSignal chan struct{}
|
||||
lastSignal uint8
|
||||
|
||||
send chan *Message
|
||||
resend chan *Message
|
||||
recv chan *Message
|
||||
|
||||
operations map[string]*Operation
|
||||
nextOpID uint64
|
||||
|
||||
lastError string
|
||||
}
|
||||
|
||||
// NewClient returns a new Client.
|
||||
func NewClient(server string) *Client {
|
||||
c := &Client{
|
||||
server: server,
|
||||
onlineSignal: make(chan struct{}),
|
||||
offlineSignal: make(chan struct{}),
|
||||
shutdownSignal: make(chan struct{}),
|
||||
lastSignal: offlineSignal,
|
||||
send: make(chan *Message, 100),
|
||||
resend: make(chan *Message, 1),
|
||||
recv: make(chan *Message, 100),
|
||||
operations: make(map[string]*Operation),
|
||||
}
|
||||
go c.handler()
|
||||
return c
|
||||
}
|
||||
|
||||
// Connect connects to the API once.
|
||||
func (c *Client) Connect() error {
|
||||
defer c.signalOffline()
|
||||
|
||||
err := c.wsConnect()
|
||||
if err != nil && err.Error() != c.lastError {
|
||||
log.Errorf("client: error connecting to Portmaster: %s", err)
|
||||
c.lastError = err.Error()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// StayConnected calls Connect again whenever the connection is lost.
|
||||
func (c *Client) StayConnected() {
|
||||
log.Infof("client: connecting to Portmaster at %s", c.server)
|
||||
|
||||
_ = c.Connect()
|
||||
for {
|
||||
select {
|
||||
case <-time.After(backOffTimer):
|
||||
log.Infof("client: reconnecting...")
|
||||
_ = c.Connect()
|
||||
case <-c.shutdownSignal:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown shuts the client down.
|
||||
func (c *Client) Shutdown() {
|
||||
select {
|
||||
case <-c.shutdownSignal:
|
||||
default:
|
||||
close(c.shutdownSignal)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) signalOnline() {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
if c.lastSignal == offlineSignal {
|
||||
log.Infof("client: went online")
|
||||
c.offlineSignal = make(chan struct{})
|
||||
close(c.onlineSignal)
|
||||
c.lastSignal = onlineSignal
|
||||
|
||||
// resend unsent request
|
||||
for _, op := range c.operations {
|
||||
if op.resuscitationEnabled.IsSet() && op.request.sent != nil && op.request.sent.SetToIf(true, false) {
|
||||
op.client.send <- op.request
|
||||
log.Infof("client: resuscitated %s %s %s", op.request.OpID, op.request.Type, op.request.Key)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) signalOffline() {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
if c.lastSignal == onlineSignal {
|
||||
log.Infof("client: went offline")
|
||||
c.onlineSignal = make(chan struct{})
|
||||
close(c.offlineSignal)
|
||||
c.lastSignal = offlineSignal
|
||||
|
||||
// signal offline status to operations
|
||||
for _, op := range c.operations {
|
||||
op.handle(&Message{
|
||||
OpID: op.ID,
|
||||
Type: MsgOffline,
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// Online returns a closed channel read if the client is connected to the API.
|
||||
func (c *Client) Online() <-chan struct{} {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
return c.onlineSignal
|
||||
}
|
||||
|
||||
// Offline returns a closed channel read if the client is not connected to the API.
|
||||
func (c *Client) Offline() <-chan struct{} {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
return c.offlineSignal
|
||||
}
|
||||
|
||||
func (c *Client) handler() {
|
||||
for {
|
||||
select {
|
||||
|
||||
case m := <-c.recv:
|
||||
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.Lock()
|
||||
op, ok := c.operations[m.OpID]
|
||||
c.Unlock()
|
||||
|
||||
if ok {
|
||||
log.Tracef("client: [%s] received %s msg: %s", m.OpID, m.Type, m.Key)
|
||||
op.handle(m)
|
||||
} else {
|
||||
log.Tracef("client: received message for unknown operation %s", m.OpID)
|
||||
}
|
||||
|
||||
case <-c.shutdownSignal:
|
||||
return
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Operation represents a single operation by a client.
|
||||
type Operation struct {
|
||||
ID string
|
||||
request *Message
|
||||
client *Client
|
||||
handleFunc func(*Message)
|
||||
handler chan *Message
|
||||
resuscitationEnabled *abool.AtomicBool
|
||||
}
|
||||
|
||||
func (op *Operation) handle(m *Message) {
|
||||
if op.handleFunc != nil {
|
||||
op.handleFunc(m)
|
||||
} else {
|
||||
select {
|
||||
case op.handler <- m:
|
||||
default:
|
||||
log.Warningf("client: handler channel of operation %s overflowed", op.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cancel the operation.
|
||||
func (op *Operation) Cancel() {
|
||||
op.client.Lock()
|
||||
defer op.client.Unlock()
|
||||
delete(op.client.operations, op.ID)
|
||||
close(op.handler)
|
||||
}
|
||||
|
||||
// Send sends a request to the API.
|
||||
func (op *Operation) Send(command, text string, data interface{}) {
|
||||
op.request = &Message{
|
||||
OpID: op.ID,
|
||||
Type: command,
|
||||
Key: text,
|
||||
Value: data,
|
||||
sent: abool.NewBool(false),
|
||||
}
|
||||
log.Tracef("client: [%s] sending %s msg: %s", op.request.OpID, op.request.Type, op.request.Key)
|
||||
op.client.send <- op.request
|
||||
}
|
||||
|
||||
// EnableResuscitation will resend the request after reconnecting to the API.
|
||||
func (op *Operation) EnableResuscitation() {
|
||||
op.resuscitationEnabled.Set()
|
||||
}
|
||||
|
||||
// NewOperation returns a new operation.
|
||||
func (c *Client) NewOperation(handleFunc func(*Message)) *Operation {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
c.nextOpID++
|
||||
op := &Operation{
|
||||
ID: fmt.Sprintf("#%d", c.nextOpID),
|
||||
client: c,
|
||||
handleFunc: handleFunc,
|
||||
handler: make(chan *Message, 100),
|
||||
resuscitationEnabled: abool.NewBool(false),
|
||||
}
|
||||
c.operations[op.ID] = op
|
||||
return op
|
||||
}
base/api/client/const.go (new file, +28)
@@ -0,0 +1,28 @@
package client

// Message Types.
const (
    msgRequestGet    = "get"
    msgRequestQuery  = "query"
    msgRequestSub    = "sub"
    msgRequestQsub   = "qsub"
    msgRequestCreate = "create"
    msgRequestUpdate = "update"
    msgRequestInsert = "insert"
    msgRequestDelete = "delete"

    MsgOk      = "ok"
    MsgError   = "error"
    MsgDone    = "done"
    MsgSuccess = "success"
    MsgUpdate  = "upd"
    MsgNew     = "new"
    MsgDelete  = "del"
    MsgWarning = "warning"

    MsgOffline = "offline" // special message type for signaling the handler that the connection was lost

    apiSeperator = "|"
)

var apiSeperatorBytes = []byte(apiSeperator)
base/api/client/message.go (new file, +95)
@@ -0,0 +1,95 @@
package client
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/structures/container"
|
||||
"github.com/safing/structures/dsd"
|
||||
)
|
||||
|
||||
// ErrMalformedMessage is returned when a malformed message was encountered.
|
||||
var ErrMalformedMessage = errors.New("malformed message")
|
||||
|
||||
// Message is an API message.
|
||||
type Message struct {
|
||||
OpID string
|
||||
Type string
|
||||
Key string
|
||||
RawValue []byte
|
||||
Value interface{}
|
||||
sent *abool.AtomicBool
|
||||
}
|
||||
|
||||
// ParseMessage parses the given raw data and returns a Message.
|
||||
func ParseMessage(data []byte) (*Message, error) {
|
||||
parts := bytes.SplitN(data, apiSeperatorBytes, 4)
|
||||
if len(parts) < 2 {
|
||||
return nil, ErrMalformedMessage
|
||||
}
|
||||
|
||||
m := &Message{
|
||||
OpID: string(parts[0]),
|
||||
Type: string(parts[1]),
|
||||
}
|
||||
|
||||
switch m.Type {
|
||||
case MsgOk, MsgUpdate, MsgNew:
|
||||
// parse key and data
|
||||
// 127|ok|<key>|<data>
|
||||
// 127|upd|<key>|<data>
|
||||
// 127|new|<key>|<data>
|
||||
if len(parts) != 4 {
|
||||
return nil, ErrMalformedMessage
|
||||
}
|
||||
m.Key = string(parts[2])
|
||||
m.RawValue = parts[3]
|
||||
case MsgDelete:
|
||||
// parse key
|
||||
// 127|del|<key>
|
||||
if len(parts) != 3 {
|
||||
return nil, ErrMalformedMessage
|
||||
}
|
||||
m.Key = string(parts[2])
|
||||
case MsgWarning, MsgError:
|
||||
// parse message
|
||||
// 127|error|<message>
|
||||
// 127|warning|<message> // error with single record, operation continues
|
||||
if len(parts) != 3 {
|
||||
return nil, ErrMalformedMessage
|
||||
}
|
||||
m.Key = string(parts[2])
|
||||
case MsgDone, MsgSuccess:
|
||||
// nothing more to do
|
||||
// 127|success
|
||||
// 127|done
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Pack serializes a message into a []byte slice.
|
||||
func (m *Message) Pack() ([]byte, error) {
|
||||
c := container.New([]byte(m.OpID), apiSeperatorBytes, []byte(m.Type))
|
||||
|
||||
if m.Key != "" {
|
||||
c.Append(apiSeperatorBytes)
|
||||
c.Append([]byte(m.Key))
|
||||
if len(m.RawValue) > 0 {
|
||||
c.Append(apiSeperatorBytes)
|
||||
c.Append(m.RawValue)
|
||||
} else if m.Value != nil {
|
||||
var err error
|
||||
m.RawValue, err = dsd.Dump(m.Value, dsd.JSON)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.Append(apiSeperatorBytes)
|
||||
c.Append(m.RawValue)
|
||||
}
|
||||
}
|
||||
|
||||
return c.CompileData(), nil
|
||||
}
base/api/client/websocket.go (new file, +121)
@@ -0,0 +1,121 @@
package client
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portmaster/base/log"
|
||||
)
|
||||
|
||||
type wsState struct {
|
||||
wsConn *websocket.Conn
|
||||
wg sync.WaitGroup
|
||||
failing *abool.AtomicBool
|
||||
failSignal chan struct{}
|
||||
}
|
||||
|
||||
func (c *Client) wsConnect() error {
|
||||
state := &wsState{
|
||||
failing: abool.NewBool(false),
|
||||
failSignal: make(chan struct{}),
|
||||
}
|
||||
|
||||
var err error
|
||||
state.wsConn, _, err = websocket.DefaultDialer.Dial(fmt.Sprintf("ws://%s/api/database/v1", c.server), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.signalOnline()
|
||||
|
||||
state.wg.Add(2)
|
||||
go c.wsReader(state)
|
||||
go c.wsWriter(state)
|
||||
|
||||
// wait for end of connection
|
||||
select {
|
||||
case <-state.failSignal:
|
||||
case <-c.shutdownSignal:
|
||||
state.Error("")
|
||||
}
|
||||
_ = state.wsConn.Close()
|
||||
state.wg.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) wsReader(state *wsState) {
|
||||
defer state.wg.Done()
|
||||
for {
|
||||
_, data, err := state.wsConn.ReadMessage()
|
||||
log.Tracef("client: read message")
|
||||
if err != nil {
|
||||
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
||||
state.Error(fmt.Sprintf("client: read error: %s", err))
|
||||
} else {
|
||||
state.Error("client: connection closed by server")
|
||||
}
|
||||
return
|
||||
}
|
||||
log.Tracef("client: received message: %s", string(data))
|
||||
m, err := ParseMessage(data)
|
||||
if err != nil {
|
||||
log.Warningf("client: failed to parse message: %s", err)
|
||||
} else {
|
||||
select {
|
||||
case c.recv <- m:
|
||||
case <-state.failSignal:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) wsWriter(state *wsState) {
|
||||
defer state.wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-state.failSignal:
|
||||
return
|
||||
case m := <-c.resend:
|
||||
data, err := m.Pack()
|
||||
if err == nil {
|
||||
err = state.wsConn.WriteMessage(websocket.BinaryMessage, data)
|
||||
}
|
||||
if err != nil {
|
||||
state.Error(fmt.Sprintf("client: write error: %s", err))
|
||||
return
|
||||
}
|
||||
log.Tracef("client: sent message: %s", string(data))
|
||||
if m.sent != nil {
|
||||
m.sent.Set()
|
||||
}
|
||||
case m := <-c.send:
|
||||
data, err := m.Pack()
|
||||
if err == nil {
|
||||
err = state.wsConn.WriteMessage(websocket.BinaryMessage, data)
|
||||
}
|
||||
if err != nil {
|
||||
c.resend <- m
|
||||
state.Error(fmt.Sprintf("client: write error: %s", err))
|
||||
return
|
||||
}
|
||||
log.Tracef("client: sent message: %s", string(data))
|
||||
if m.sent != nil {
|
||||
m.sent.Set()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (state *wsState) Error(message string) {
|
||||
if state.failing.SetToIf(false, true) {
|
||||
close(state.failSignal)
|
||||
if message != "" {
|
||||
log.Warning(message)
|
||||
}
|
||||
}
|
||||
}
|
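The Error method above is a one-shot failure latch: SetToIf(false, true) guarantees the fail signal channel is closed exactly once, even when the reader and writer goroutines fail concurrently. A minimal sketch of the same pattern in isolation (type and function names are invented):

package client

import "github.com/tevino/abool"

// failLatch mirrors the wsState failure handling: any number of goroutines may
// call fail(), but the channel is closed exactly once and all waiters unblock.
type failLatch struct {
	failing    *abool.AtomicBool
	failSignal chan struct{}
}

func newFailLatch() *failLatch {
	return &failLatch{
		failing:    abool.NewBool(false),
		failSignal: make(chan struct{}),
	}
}

func (l *failLatch) fail() {
	if l.failing.SetToIf(false, true) {
		close(l.failSignal)
	}
}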
91
base/api/config.go
Normal file
|
@@ -0,0 +1,91 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"flag"
|
||||
|
||||
"github.com/safing/portmaster/base/config"
|
||||
)
|
||||
|
||||
// Config Keys.
|
||||
const (
|
||||
CfgDefaultListenAddressKey = "core/listenAddress"
|
||||
CfgAPIKeys = "core/apiKeys"
|
||||
)
|
||||
|
||||
var (
|
||||
listenAddressFlag string
|
||||
listenAddressConfig config.StringOption
|
||||
defaultListenAddress string
|
||||
|
||||
configuredAPIKeys config.StringArrayOption
|
||||
|
||||
devMode config.BoolOption
|
||||
)
|
||||
|
||||
func init() {
|
||||
flag.StringVar(
|
||||
&listenAddressFlag,
|
||||
"api-address",
|
||||
"",
|
||||
"set api listen address; configuration is stronger",
|
||||
)
|
||||
}
|
||||
|
||||
func getDefaultListenAddress() string {
|
||||
// check if overridden
|
||||
if listenAddressFlag != "" {
|
||||
return listenAddressFlag
|
||||
}
|
||||
// return internal default
|
||||
return defaultListenAddress
|
||||
}
|
||||
|
||||
func registerConfig() error {
|
||||
err := config.Register(&config.Option{
|
||||
Name: "API Listen Address",
|
||||
Key: CfgDefaultListenAddressKey,
|
||||
Description: "Defines the IP address and port on which the internal API listens.",
|
||||
OptType: config.OptTypeString,
|
||||
ExpertiseLevel: config.ExpertiseLevelDeveloper,
|
||||
ReleaseLevel: config.ReleaseLevelStable,
|
||||
DefaultValue: getDefaultListenAddress(),
|
||||
ValidationRegex: "^([0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}:[0-9]{1,5}|\\[[:0-9A-Fa-f]+\\]:[0-9]{1,5})$",
|
||||
RequiresRestart: true,
|
||||
Annotations: config.Annotations{
|
||||
config.DisplayOrderAnnotation: 513,
|
||||
config.CategoryAnnotation: "Development",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
listenAddressConfig = config.GetAsString(CfgDefaultListenAddressKey, getDefaultListenAddress())
|
||||
|
||||
err = config.Register(&config.Option{
|
||||
Name: "API Keys",
|
||||
Key: CfgAPIKeys,
|
||||
Description: "Define API keys for privileged access to the API. Every entry is a separate API key with respective permissions. Format is `<key>?read=<perm>&write=<perm>`. Permissions are `anyone`, `user` and `admin`, and may be omitted.",
|
||||
Sensitive: true,
|
||||
OptType: config.OptTypeStringArray,
|
||||
ExpertiseLevel: config.ExpertiseLevelDeveloper,
|
||||
ReleaseLevel: config.ReleaseLevelStable,
|
||||
DefaultValue: []string{},
|
||||
Annotations: config.Annotations{
|
||||
config.DisplayOrderAnnotation: 514,
|
||||
config.CategoryAnnotation: "Development",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
configuredAPIKeys = config.GetAsStringArray(CfgAPIKeys, []string{})
|
||||
|
||||
devMode = config.Concurrent.GetAsBool(config.CfgDevModeKey, false)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetDefaultAPIListenAddress sets the default listen address for the API.
|
||||
func SetDefaultAPIListenAddress(address string) {
|
||||
defaultListenAddress = address
|
||||
}
|
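To illustrate how these two options are meant to be used together: the listen address below is the default also referenced in the debug endpoint descriptions further down, while the API key string is made up and merely follows the format documented in the option description.

package api

// exampleAPIConfig is a hypothetical startup wiring; registerConfig() later
// picks up the default set here.
func exampleAPIConfig() {
	// Must run before the listen address option is registered.
	SetDefaultAPIListenAddress("127.0.0.1:817")

	// A single core/apiKeys entry in the documented format
	// <key>?read=<perm>&write=<perm>, granting user-level read and
	// admin-level write access (the key itself is invented):
	_ = "3xampl3-k3y?read=user&write=admin"
}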
698
base/api/database.go
Normal file
|
@@ -0,0 +1,698 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/tevino/abool"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
"github.com/safing/portmaster/base/database"
|
||||
"github.com/safing/portmaster/base/database/iterator"
|
||||
"github.com/safing/portmaster/base/database/query"
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
"github.com/safing/portmaster/base/log"
|
||||
"github.com/safing/portmaster/service/mgr"
|
||||
"github.com/safing/structures/container"
|
||||
"github.com/safing/structures/dsd"
|
||||
"github.com/safing/structures/varint"
|
||||
)
|
||||
|
||||
const (
|
||||
dbMsgTypeOk = "ok"
|
||||
dbMsgTypeError = "error"
|
||||
dbMsgTypeDone = "done"
|
||||
dbMsgTypeSuccess = "success"
|
||||
dbMsgTypeUpd = "upd"
|
||||
dbMsgTypeNew = "new"
|
||||
dbMsgTypeDel = "del"
|
||||
dbMsgTypeWarning = "warning"
|
||||
|
||||
dbAPISeperator = "|"
|
||||
emptyString = ""
|
||||
)
|
||||
|
||||
var (
|
||||
dbAPISeperatorBytes = []byte(dbAPISeperator)
|
||||
dbCompatibilityPermission = PermitAdmin
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterHandler("/api/database/v1", WrapInAuthHandler(
|
||||
startDatabaseWebsocketAPI,
|
||||
// Default to admin read/write permissions until the database gets support
|
||||
// for api permissions.
|
||||
dbCompatibilityPermission,
|
||||
dbCompatibilityPermission,
|
||||
))
|
||||
}
|
||||
|
||||
// DatabaseAPI is a generic database API interface.
|
||||
type DatabaseAPI struct {
|
||||
queriesLock sync.Mutex
|
||||
queries map[string]*iterator.Iterator
|
||||
|
||||
subsLock sync.Mutex
|
||||
subs map[string]*database.Subscription
|
||||
|
||||
shutdownSignal chan struct{}
|
||||
shuttingDown *abool.AtomicBool
|
||||
db *database.Interface
|
||||
|
||||
sendBytes func(data []byte)
|
||||
}
|
||||
|
||||
// DatabaseWebsocketAPI is a database websocket API interface.
|
||||
type DatabaseWebsocketAPI struct {
|
||||
DatabaseAPI
|
||||
|
||||
sendQueue chan []byte
|
||||
conn *websocket.Conn
|
||||
}
|
||||
|
||||
func allowAnyOrigin(r *http.Request) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// CreateDatabaseAPI creates a new database interface.
|
||||
func CreateDatabaseAPI(sendFunction func(data []byte)) DatabaseAPI {
|
||||
return DatabaseAPI{
|
||||
queries: make(map[string]*iterator.Iterator),
|
||||
subs: make(map[string]*database.Subscription),
|
||||
shutdownSignal: make(chan struct{}),
|
||||
shuttingDown: abool.NewBool(false),
|
||||
db: database.NewInterface(nil),
|
||||
sendBytes: sendFunction,
|
||||
}
|
||||
}
|
||||
|
||||
func startDatabaseWebsocketAPI(w http.ResponseWriter, r *http.Request) {
|
||||
upgrader := websocket.Upgrader{
|
||||
CheckOrigin: allowAnyOrigin,
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 65536,
|
||||
}
|
||||
wsConn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
errMsg := fmt.Sprintf("could not upgrade: %s", err)
|
||||
log.Error(errMsg)
|
||||
http.Error(w, errMsg, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
newDBAPI := &DatabaseWebsocketAPI{
|
||||
DatabaseAPI: DatabaseAPI{
|
||||
queries: make(map[string]*iterator.Iterator),
|
||||
subs: make(map[string]*database.Subscription),
|
||||
shutdownSignal: make(chan struct{}),
|
||||
shuttingDown: abool.NewBool(false),
|
||||
db: database.NewInterface(nil),
|
||||
},
|
||||
|
||||
sendQueue: make(chan []byte, 100),
|
||||
conn: wsConn,
|
||||
}
|
||||
|
||||
newDBAPI.sendBytes = func(data []byte) {
|
||||
newDBAPI.sendQueue <- data
|
||||
}
|
||||
|
||||
module.mgr.Go("database api handler", newDBAPI.handler)
|
||||
module.mgr.Go("database api writer", newDBAPI.writer)
|
||||
|
||||
log.Tracer(r.Context()).Infof("api request: init websocket %s %s", r.RemoteAddr, r.RequestURI)
|
||||
}
|
||||
|
||||
func (api *DatabaseWebsocketAPI) handler(_ *mgr.WorkerCtx) error {
|
||||
defer func() {
|
||||
_ = api.shutdown(nil)
|
||||
}()
|
||||
|
||||
for {
|
||||
_, msg, err := api.conn.ReadMessage()
|
||||
if err != nil {
|
||||
return api.shutdown(err)
|
||||
}
|
||||
|
||||
api.Handle(msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (api *DatabaseWebsocketAPI) writer(ctx *mgr.WorkerCtx) error {
|
||||
defer func() {
|
||||
_ = api.shutdown(nil)
|
||||
}()
|
||||
|
||||
var data []byte
|
||||
var err error
|
||||
|
||||
for {
|
||||
select {
|
||||
// prioritize direct writes
|
||||
case data = <-api.sendQueue:
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-api.shutdownSignal:
|
||||
return nil
|
||||
}
|
||||
|
||||
// log.Tracef("api: sending %s", string(*msg))
|
||||
err = api.conn.WriteMessage(websocket.BinaryMessage, data)
|
||||
if err != nil {
|
||||
return api.shutdown(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (api *DatabaseWebsocketAPI) shutdown(err error) error {
|
||||
// Check if we are the first to shut down.
|
||||
if !api.shuttingDown.SetToIf(false, true) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check the given error.
|
||||
if err != nil {
|
||||
if websocket.IsCloseError(err,
|
||||
websocket.CloseNormalClosure,
|
||||
websocket.CloseGoingAway,
|
||||
websocket.CloseAbnormalClosure,
|
||||
) {
|
||||
log.Infof("api: websocket connection to %s closed", api.conn.RemoteAddr())
|
||||
} else {
|
||||
log.Warningf("api: websocket connection error with %s: %s", api.conn.RemoteAddr(), err)
|
||||
}
|
||||
}
|
||||
|
||||
// Trigger shutdown.
|
||||
close(api.shutdownSignal)
|
||||
_ = api.conn.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle handles a message for the database API.
|
||||
func (api *DatabaseAPI) Handle(msg []byte) {
|
||||
// 123|get|<key>
|
||||
// 123|ok|<key>|<data>
|
||||
// 123|error|<message>
|
||||
// 124|query|<query>
|
||||
// 124|ok|<key>|<data>
|
||||
// 124|done
|
||||
// 124|error|<message>
|
||||
// 124|warning|<message> // error with single record, operation continues
|
||||
// 124|cancel
|
||||
// 125|sub|<query>
|
||||
// 125|upd|<key>|<data>
|
||||
// 125|new|<key>|<data>
|
||||
// 125|del|<key>
|
||||
// 125|warning|<message> // error with single record, operation continues
|
||||
// 125|cancel
|
||||
// 127|qsub|<query>
|
||||
// 127|ok|<key>|<data>
|
||||
// 127|done
|
||||
// 127|error|<message>
|
||||
// 127|upd|<key>|<data>
|
||||
// 127|new|<key>|<data>
|
||||
// 127|del|<key>
|
||||
// 127|warning|<message> // error with single record, operation continues
|
||||
// 127|cancel
|
||||
|
||||
// 128|create|<key>|<data>
|
||||
// 128|success
|
||||
// 128|error|<message>
|
||||
// 129|update|<key>|<data>
|
||||
// 129|success
|
||||
// 129|error|<message>
|
||||
// 130|insert|<key>|<data>
|
||||
// 130|success
|
||||
// 130|error|<message>
|
||||
// 131|delete|<key>
|
||||
// 131|success
|
||||
// 131|error|<message>
|
||||
|
||||
parts := bytes.SplitN(msg, []byte("|"), 3)
|
||||
|
||||
// Handle special command "cancel"
|
||||
if len(parts) == 2 && string(parts[1]) == "cancel" {
|
||||
// 124|cancel
|
||||
// 125|cancel
|
||||
// 127|cancel
|
||||
go api.handleCancel(parts[0])
|
||||
return
|
||||
}
|
||||
|
||||
if len(parts) != 3 {
|
||||
api.send(nil, dbMsgTypeError, "bad request: malformed message", nil)
|
||||
return
|
||||
}
|
||||
|
||||
switch string(parts[1]) {
|
||||
case "get":
|
||||
// 123|get|<key>
|
||||
go api.handleGet(parts[0], string(parts[2]))
|
||||
case "query":
|
||||
// 124|query|<query>
|
||||
go api.handleQuery(parts[0], string(parts[2]))
|
||||
case "sub":
|
||||
// 125|sub|<query>
|
||||
go api.handleSub(parts[0], string(parts[2]))
|
||||
case "qsub":
|
||||
// 127|qsub|<query>
|
||||
go api.handleQsub(parts[0], string(parts[2]))
|
||||
case "create", "update", "insert":
|
||||
// split key and payload
|
||||
dataParts := bytes.SplitN(parts[2], []byte("|"), 2)
|
||||
if len(dataParts) != 2 {
|
||||
api.send(nil, dbMsgTypeError, "bad request: malformed message", nil)
|
||||
return
|
||||
}
|
||||
|
||||
switch string(parts[1]) {
|
||||
case "create":
|
||||
// 128|create|<key>|<data>
|
||||
go api.handlePut(parts[0], string(dataParts[0]), dataParts[1], true)
|
||||
case "update":
|
||||
// 129|update|<key>|<data>
|
||||
go api.handlePut(parts[0], string(dataParts[0]), dataParts[1], false)
|
||||
case "insert":
|
||||
// 130|insert|<key>|<data>
|
||||
go api.handleInsert(parts[0], string(dataParts[0]), dataParts[1])
|
||||
}
|
||||
case "delete":
|
||||
// 131|delete|<key>
|
||||
go api.handleDelete(parts[0], string(parts[2]))
|
||||
default:
|
||||
api.send(parts[0], dbMsgTypeError, "bad request: unknown method", nil)
|
||||
}
|
||||
}
|
||||
|
||||
func (api *DatabaseAPI) send(opID []byte, msgType string, msgOrKey string, data []byte) {
|
||||
c := container.New(opID)
|
||||
c.Append(dbAPISeperatorBytes)
|
||||
c.Append([]byte(msgType))
|
||||
|
||||
if msgOrKey != emptyString {
|
||||
c.Append(dbAPISeperatorBytes)
|
||||
c.Append([]byte(msgOrKey))
|
||||
}
|
||||
|
||||
if len(data) > 0 {
|
||||
c.Append(dbAPISeperatorBytes)
|
||||
c.Append(data)
|
||||
}
|
||||
|
||||
api.sendBytes(c.CompileData())
|
||||
}
|
||||
|
||||
func (api *DatabaseAPI) handleGet(opID []byte, key string) {
|
||||
// 123|get|<key>
|
||||
// 123|ok|<key>|<data>
|
||||
// 123|error|<message>
|
||||
|
||||
var data []byte
|
||||
|
||||
r, err := api.db.Get(key)
|
||||
if err == nil {
|
||||
data, err = MarshalRecord(r, true)
|
||||
}
|
||||
if err != nil {
|
||||
api.send(opID, dbMsgTypeError, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
api.send(opID, dbMsgTypeOk, r.Key(), data)
|
||||
}
|
||||
|
||||
func (api *DatabaseAPI) handleQuery(opID []byte, queryText string) {
|
||||
// 124|query|<query>
|
||||
// 124|ok|<key>|<data>
|
||||
// 124|done
|
||||
// 124|warning|<message>
|
||||
// 124|error|<message>
|
||||
// 124|warning|<message> // error with single record, operation continues
|
||||
// 124|cancel
|
||||
|
||||
var err error
|
||||
|
||||
q, err := query.ParseQuery(queryText)
|
||||
if err != nil {
|
||||
api.send(opID, dbMsgTypeError, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
api.processQuery(opID, q)
|
||||
}
|
||||
|
||||
func (api *DatabaseAPI) processQuery(opID []byte, q *query.Query) (ok bool) {
|
||||
it, err := api.db.Query(q)
|
||||
if err != nil {
|
||||
api.send(opID, dbMsgTypeError, err.Error(), nil)
|
||||
return false
|
||||
}
|
||||
|
||||
// Save query iterator.
|
||||
api.queriesLock.Lock()
|
||||
api.queries[string(opID)] = it
|
||||
api.queriesLock.Unlock()
|
||||
|
||||
// Remove query iterator after it ended.
|
||||
defer func() {
|
||||
api.queriesLock.Lock()
|
||||
defer api.queriesLock.Unlock()
|
||||
delete(api.queries, string(opID))
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-api.shutdownSignal:
|
||||
// cancel query and return
|
||||
it.Cancel()
|
||||
return false
|
||||
case r := <-it.Next:
|
||||
// process query feed
|
||||
if r != nil {
|
||||
// process record
|
||||
data, err := MarshalRecord(r, true)
|
||||
if err != nil {
|
||||
api.send(opID, dbMsgTypeWarning, err.Error(), nil)
|
||||
continue
|
||||
}
|
||||
api.send(opID, dbMsgTypeOk, r.Key(), data)
|
||||
} else {
|
||||
// sub feed ended
|
||||
if it.Err() != nil {
|
||||
api.send(opID, dbMsgTypeError, it.Err().Error(), nil)
|
||||
return false
|
||||
}
|
||||
api.send(opID, dbMsgTypeDone, emptyString, nil)
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// func (api *DatabaseWebsocketAPI) runQuery()
|
||||
|
||||
func (api *DatabaseAPI) handleSub(opID []byte, queryText string) {
|
||||
// 125|sub|<query>
|
||||
// 125|upd|<key>|<data>
|
||||
// 125|new|<key>|<data>
|
||||
// 125|del|<key>
|
||||
// 125|warning|<message> // error with single record, operation continues
|
||||
// 125|cancel
|
||||
var err error
|
||||
|
||||
q, err := query.ParseQuery(queryText)
|
||||
if err != nil {
|
||||
api.send(opID, dbMsgTypeError, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
sub, ok := api.registerSub(opID, q)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
api.processSub(opID, sub)
|
||||
}
|
||||
|
||||
func (api *DatabaseAPI) registerSub(opID []byte, q *query.Query) (sub *database.Subscription, ok bool) {
|
||||
var err error
|
||||
sub, err = api.db.Subscribe(q)
|
||||
if err != nil {
|
||||
api.send(opID, dbMsgTypeError, err.Error(), nil)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return sub, true
|
||||
}
|
||||
|
||||
func (api *DatabaseAPI) processSub(opID []byte, sub *database.Subscription) {
|
||||
// Save subscription.
|
||||
api.subsLock.Lock()
|
||||
api.subs[string(opID)] = sub
|
||||
api.subsLock.Unlock()
|
||||
|
||||
// Remove subscription after it ended.
|
||||
defer func() {
|
||||
api.subsLock.Lock()
|
||||
defer api.subsLock.Unlock()
|
||||
delete(api.subs, string(opID))
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-api.shutdownSignal:
|
||||
// cancel sub and return
|
||||
_ = sub.Cancel()
|
||||
return
|
||||
case r := <-sub.Feed:
|
||||
// process sub feed
|
||||
if r != nil {
|
||||
// process record
|
||||
data, err := MarshalRecord(r, true)
|
||||
if err != nil {
|
||||
api.send(opID, dbMsgTypeWarning, err.Error(), nil)
|
||||
continue
|
||||
}
|
||||
// TODO: use upd, new and delete msgTypes
|
||||
r.Lock()
|
||||
isDeleted := r.Meta().IsDeleted()
|
||||
isNew := r.Meta().Created == r.Meta().Modified
|
||||
r.Unlock()
|
||||
switch {
|
||||
case isDeleted:
|
||||
api.send(opID, dbMsgTypeDel, r.Key(), nil)
|
||||
case isNew:
|
||||
api.send(opID, dbMsgTypeNew, r.Key(), data)
|
||||
default:
|
||||
api.send(opID, dbMsgTypeUpd, r.Key(), data)
|
||||
}
|
||||
} else {
|
||||
// sub feed ended
|
||||
api.send(opID, dbMsgTypeDone, "", nil)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (api *DatabaseAPI) handleQsub(opID []byte, queryText string) {
|
||||
// 127|qsub|<query>
|
||||
// 127|ok|<key>|<data>
|
||||
// 127|done
|
||||
// 127|error|<message>
|
||||
// 127|upd|<key>|<data>
|
||||
// 127|new|<key>|<data>
|
||||
// 127|del|<key>
|
||||
// 127|warning|<message> // error with single record, operation continues
|
||||
// 127|cancel
|
||||
|
||||
var err error
|
||||
|
||||
q, err := query.ParseQuery(queryText)
|
||||
if err != nil {
|
||||
api.send(opID, dbMsgTypeError, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
sub, ok := api.registerSub(opID, q)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
ok = api.processQuery(opID, q)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
api.processSub(opID, sub)
|
||||
}
|
||||
|
||||
func (api *DatabaseAPI) handleCancel(opID []byte) {
|
||||
api.cancelQuery(opID)
|
||||
api.cancelSub(opID)
|
||||
}
|
||||
|
||||
func (api *DatabaseAPI) cancelQuery(opID []byte) {
|
||||
api.queriesLock.Lock()
|
||||
defer api.queriesLock.Unlock()
|
||||
|
||||
// Get subscription from api.
|
||||
it, ok := api.queries[string(opID)]
|
||||
if !ok {
|
||||
// Fail silently as queries end by themselves when finished.
|
||||
return
|
||||
}
|
||||
|
||||
// End query.
|
||||
it.Cancel()
|
||||
|
||||
// The query handler will end the communication with a done message.
|
||||
}
|
||||
|
||||
func (api *DatabaseAPI) cancelSub(opID []byte) {
|
||||
api.subsLock.Lock()
|
||||
defer api.subsLock.Unlock()
|
||||
|
||||
// Get subscription from api.
|
||||
sub, ok := api.subs[string(opID)]
|
||||
if !ok {
|
||||
api.send(opID, dbMsgTypeError, "could not find subscription", nil)
|
||||
return
|
||||
}
|
||||
|
||||
// End subscription.
|
||||
err := sub.Cancel()
|
||||
if err != nil {
|
||||
api.send(opID, dbMsgTypeError, fmt.Sprintf("failed to cancel subscription: %s", err), nil)
|
||||
}
|
||||
|
||||
// The subscription handler will end the communication with a done message.
|
||||
}
|
||||
|
||||
func (api *DatabaseAPI) handlePut(opID []byte, key string, data []byte, create bool) {
|
||||
// 128|create|<key>|<data>
|
||||
// 128|success
|
||||
// 128|error|<message>
|
||||
|
||||
// 129|update|<key>|<data>
|
||||
// 129|success
|
||||
// 129|error|<message>
|
||||
|
||||
if len(data) < 2 {
|
||||
api.send(opID, dbMsgTypeError, "bad request: malformed message", nil)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO - staged for deletion: remove transition code
|
||||
// if data[0] != dsd.JSON {
|
||||
// typedData := make([]byte, len(data)+1)
|
||||
// typedData[0] = dsd.JSON
|
||||
// copy(typedData[1:], data)
|
||||
// data = typedData
|
||||
// }
|
||||
|
||||
r, err := record.NewWrapper(key, nil, data[0], data[1:])
|
||||
if err != nil {
|
||||
api.send(opID, dbMsgTypeError, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
if create {
|
||||
err = api.db.PutNew(r)
|
||||
} else {
|
||||
err = api.db.Put(r)
|
||||
}
|
||||
if err != nil {
|
||||
api.send(opID, dbMsgTypeError, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
api.send(opID, dbMsgTypeSuccess, emptyString, nil)
|
||||
}
|
||||
|
||||
func (api *DatabaseAPI) handleInsert(opID []byte, key string, data []byte) {
|
||||
// 130|insert|<key>|<data>
|
||||
// 130|success
|
||||
// 130|error|<message>
|
||||
|
||||
r, err := api.db.Get(key)
|
||||
if err != nil {
|
||||
api.send(opID, dbMsgTypeError, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
acc := r.GetAccessor(r)
|
||||
|
||||
result := gjson.ParseBytes(data)
|
||||
anythingPresent := false
|
||||
var insertError error
|
||||
result.ForEach(func(key gjson.Result, value gjson.Result) bool {
|
||||
anythingPresent = true
|
||||
if !key.Exists() {
|
||||
insertError = errors.New("values must be in a map")
|
||||
return false
|
||||
}
|
||||
if key.Type != gjson.String {
|
||||
insertError = errors.New("keys must be strings")
|
||||
return false
|
||||
}
|
||||
if !value.Exists() {
|
||||
insertError = errors.New("non-existent value")
|
||||
return false
|
||||
}
|
||||
insertError = acc.Set(key.String(), value.Value())
|
||||
return insertError == nil
|
||||
})
|
||||
|
||||
if insertError != nil {
|
||||
api.send(opID, dbMsgTypeError, insertError.Error(), nil)
|
||||
return
|
||||
}
|
||||
if !anythingPresent {
|
||||
api.send(opID, dbMsgTypeError, "could not find any valid values", nil)
|
||||
return
|
||||
}
|
||||
|
||||
err = api.db.Put(r)
|
||||
if err != nil {
|
||||
api.send(opID, dbMsgTypeError, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
api.send(opID, dbMsgTypeSuccess, emptyString, nil)
|
||||
}
|
||||
|
||||
func (api *DatabaseAPI) handleDelete(opID []byte, key string) {
|
||||
// 131|delete|<key>
|
||||
// 131|success
|
||||
// 131|error|<message>
|
||||
|
||||
err := api.db.Delete(key)
|
||||
if err != nil {
|
||||
api.send(opID, dbMsgTypeError, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
api.send(opID, dbMsgTypeSuccess, emptyString, nil)
|
||||
}
|
||||
|
||||
// MarshalRecord locks and marshals the given record, additionally adding
|
||||
// metadata and returning it as json.
|
||||
func MarshalRecord(r record.Record, withDSDIdentifier bool) ([]byte, error) {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
// Pour record into JSON.
|
||||
jsonData, err := r.Marshal(r, dsd.JSON)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Remove JSON identifier for manual editing.
|
||||
jsonData = bytes.TrimPrefix(jsonData, varint.Pack8(dsd.JSON))
|
||||
|
||||
// Add metadata.
|
||||
jsonData, err = sjson.SetBytes(jsonData, "_meta", r.Meta())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Add database key.
|
||||
jsonData, err = sjson.SetBytes(jsonData, "_meta.Key", r.Key())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Add JSON identifier again.
|
||||
if withDSDIdentifier {
|
||||
formatID := varint.Pack8(dsd.JSON)
|
||||
finalData := make([]byte, 0, len(formatID)+len(jsonData))
|
||||
finalData = append(finalData, formatID...)
|
||||
finalData = append(finalData, jsonData...)
|
||||
return finalData, nil
|
||||
}
|
||||
return jsonData, nil
|
||||
}
|
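The wire protocol documented in Handle above can be exercised with a small standalone client. This is only a sketch: it assumes the API listens on the default 127.0.0.1:817, that the caller is authorized with admin permissions, and the record key is made up.

package main

import (
	"fmt"
	"log"

	"github.com/gorilla/websocket"
)

func main() {
	conn, _, err := websocket.DefaultDialer.Dial("ws://127.0.0.1:817/api/database/v1", nil)
	if err != nil {
		log.Fatalf("dial failed: %s", err)
	}
	defer conn.Close()

	// 123|get|<key> – request a single record (the key is illustrative).
	err = conn.WriteMessage(websocket.BinaryMessage, []byte("123|get|core:status/example"))
	if err != nil {
		log.Fatalf("write failed: %s", err)
	}

	// Expect 123|ok|<key>|<data> on success or 123|error|<message> otherwise.
	_, reply, err := conn.ReadMessage()
	if err != nil {
		log.Fatalf("read failed: %s", err)
	}
	fmt.Printf("reply: %s\n", reply)
}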
10
base/api/doc.go
Normal file
|
@@ -0,0 +1,10 @@
|
|||
/*
|
||||
Package api provides an API for integration with other components of the same software package and also third party components.
|
||||
|
||||
It provides direct database access as well as a simpler way to register API endpoints. You can of course also register raw `http.Handler`s directly.
|
||||
|
||||
Optional authentication guards registered handlers. This is achieved by attaching functions to the `http.Handler`s that are registered, which allow them to specify the required permissions for the handler.
|
||||
|
||||
The permissions are divided into roles and assume a single user per host. The roles are User, Admin and Self. The User role is expected to have mostly read access and to react to notifications or system events, like a system tray program. The Admin role is meant for advanced components that also change settings, but it is restricted so it cannot break the software. Self is reserved for internal use with full access.
|
||||
*/
|
||||
package api
|
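A sketch of how this role model translates into an endpoint registration; the path, handler and texts are invented, while the permission constants and Endpoint fields come from this package (see endpoints.go below).

package api

import "net/http"

// registerExampleEndpoint registers an endpoint that anyone may read but only
// admin-level callers may trigger via POST.
func registerExampleEndpoint() error {
	return RegisterEndpoint(Endpoint{
		Path:        "example/reload",
		Read:        PermitAnyone,
		Write:       PermitAdmin,
		WriteMethod: http.MethodPost,
		ActionFunc: func(ar *Request) (string, error) {
			return "reload triggered", nil
		},
		Name:        "Example Reload",
		Description: "Illustrative endpoint only.",
	})
}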
521
base/api/endpoints.go
Normal file
|
@@ -0,0 +1,521 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
"github.com/safing/portmaster/base/log"
|
||||
"github.com/safing/structures/dsd"
|
||||
)
|
||||
|
||||
// Endpoint describes an API Endpoint.
|
||||
// Path and at least one permission are required.
|
||||
// As is exactly one function.
|
||||
type Endpoint struct { //nolint:maligned
|
||||
// Name is the human readable name of the endpoint.
|
||||
Name string
|
||||
// Description is the human readable description and documentation of the endpoint.
|
||||
Description string
|
||||
// Parameters is the parameter documentation.
|
||||
Parameters []Parameter `json:",omitempty"`
|
||||
|
||||
// Path describes the URL path of the endpoint.
|
||||
Path string
|
||||
|
||||
// MimeType defines the content type of the returned data.
|
||||
MimeType string
|
||||
|
||||
// Read defines the required read permission.
|
||||
Read Permission `json:",omitempty"`
|
||||
|
||||
// ReadMethod sets the required read method for the endpoint.
|
||||
// Available methods are:
|
||||
// GET: Returns data only, no action is taken, nothing is changed.
|
||||
// If omitted, defaults to GET.
|
||||
//
|
||||
// This field is currently being introduced and will only warn and not deny
|
||||
// access if the write method does not match.
|
||||
ReadMethod string `json:",omitempty"`
|
||||
|
||||
// Write defines the required write permission.
|
||||
Write Permission `json:",omitempty"`
|
||||
|
||||
// WriteMethod sets the required write method for the endpoint.
|
||||
// Available methods are:
|
||||
// POST: Create a new resource; Change a status; Execute a function
|
||||
// PUT: Update an existing resource
|
||||
// DELETE: Remove an existing resource
|
||||
// If omitted, defaults to POST.
|
||||
//
|
||||
// This field is currently being introduced and will only warn and not deny
|
||||
// access if the write method does not match.
|
||||
WriteMethod string `json:",omitempty"`
|
||||
|
||||
// ActionFunc is for simple actions with a return message for the user.
|
||||
ActionFunc ActionFunc `json:"-"`
|
||||
|
||||
// DataFunc is for returning raw data for the caller to process further.
|
||||
DataFunc DataFunc `json:"-"`
|
||||
|
||||
// StructFunc is for returning any kind of struct.
|
||||
StructFunc StructFunc `json:"-"`
|
||||
|
||||
// RecordFunc is for returning a database record. It will be properly locked
|
||||
// and marshalled including metadata.
|
||||
RecordFunc RecordFunc `json:"-"`
|
||||
|
||||
// HandlerFunc is the raw http handler.
|
||||
HandlerFunc http.HandlerFunc `json:"-"`
|
||||
}
|
||||
|
||||
// Parameter describes a parameterized variation of an endpoint.
|
||||
type Parameter struct {
|
||||
Method string
|
||||
Field string
|
||||
Value string
|
||||
Description string
|
||||
}
|
||||
|
||||
// HTTPStatusProvider is an interface for errors to provide a custom HTTP
|
||||
// status code.
|
||||
type HTTPStatusProvider interface {
|
||||
HTTPStatus() int
|
||||
}
|
||||
|
||||
// HTTPStatusError represents an error with an HTTP status code.
|
||||
type HTTPStatusError struct {
|
||||
err error
|
||||
code int
|
||||
}
|
||||
|
||||
// Error returns the error message.
|
||||
func (e *HTTPStatusError) Error() string {
|
||||
return e.err.Error()
|
||||
}
|
||||
|
||||
// Unwrap return the wrapped error.
|
||||
func (e *HTTPStatusError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
// HTTPStatus returns the HTTP status code this error.
|
||||
func (e *HTTPStatusError) HTTPStatus() int {
|
||||
return e.code
|
||||
}
|
||||
|
||||
// ErrorWithStatus adds the HTTP status code to the error.
|
||||
func ErrorWithStatus(err error, code int) error {
|
||||
return &HTTPStatusError{
|
||||
err: err,
|
||||
code: code,
|
||||
}
|
||||
}
|
||||
|
||||
type (
|
||||
// ActionFunc is for simple actions with a return message for the user.
|
||||
ActionFunc func(ar *Request) (msg string, err error)
|
||||
|
||||
// DataFunc is for returning raw data for the caller to process further.
|
||||
DataFunc func(ar *Request) (data []byte, err error)
|
||||
|
||||
// StructFunc is for returning any kind of struct.
|
||||
StructFunc func(ar *Request) (i interface{}, err error)
|
||||
|
||||
// RecordFunc is for returning a database record. It will be properly locked
|
||||
// and marshalled including metadata.
|
||||
RecordFunc func(ar *Request) (r record.Record, err error)
|
||||
)
|
||||
|
||||
// MIME Types.
|
||||
const (
|
||||
MimeTypeJSON string = "application/json"
|
||||
MimeTypeText string = "text/plain"
|
||||
|
||||
apiV1Path = "/api/v1/"
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterHandler(apiV1Path+"{endpointPath:.+}", &endpointHandler{})
|
||||
}
|
||||
|
||||
var (
|
||||
endpoints = make(map[string]*Endpoint)
|
||||
endpointsMux = mux.NewRouter()
|
||||
endpointsLock sync.RWMutex
|
||||
|
||||
// ErrInvalidEndpoint is returned when an invalid endpoint is registered.
|
||||
ErrInvalidEndpoint = errors.New("endpoint is invalid")
|
||||
|
||||
// ErrAlreadyRegistered is returned when there already is an endpoint with
|
||||
// the same path registered.
|
||||
ErrAlreadyRegistered = errors.New("an endpoint for this path is already registered")
|
||||
)
|
||||
|
||||
func getAPIContext(r *http.Request) (apiEndpoint *Endpoint, apiRequest *Request) {
|
||||
// Get request context and check if we already have an action cached.
|
||||
apiRequest = GetAPIRequest(r)
|
||||
if apiRequest == nil {
|
||||
return nil, nil
|
||||
}
|
||||
var ok bool
|
||||
apiEndpoint, ok = apiRequest.HandlerCache.(*Endpoint)
|
||||
if ok {
|
||||
return apiEndpoint, apiRequest
|
||||
}
|
||||
|
||||
endpointsLock.RLock()
|
||||
defer endpointsLock.RUnlock()
|
||||
|
||||
// Get handler for request.
|
||||
// Gorilla mux does not expose a clean way to do this matching ourselves.
|
||||
// See github.com/gorilla/mux.ServeHTTP for reference.
|
||||
var match mux.RouteMatch
|
||||
var handler http.Handler
|
||||
if endpointsMux.Match(r, &match) {
|
||||
handler = match.Handler
|
||||
apiRequest.Route = match.Route
|
||||
// Add/Override variables instead of replacing.
|
||||
for k, v := range match.Vars {
|
||||
apiRequest.URLVars[k] = v
|
||||
}
|
||||
} else {
|
||||
return nil, apiRequest
|
||||
}
|
||||
|
||||
apiEndpoint, ok = handler.(*Endpoint)
|
||||
if ok {
|
||||
// Cache for next operation.
|
||||
apiRequest.HandlerCache = apiEndpoint
|
||||
}
|
||||
return apiEndpoint, apiRequest
|
||||
}
|
||||
|
||||
// RegisterEndpoint registers a new endpoint. An error will be returned if it
|
||||
// does not pass the sanity checks.
|
||||
func RegisterEndpoint(e Endpoint) error {
|
||||
if err := e.check(); err != nil {
|
||||
return fmt.Errorf("%w: %w", ErrInvalidEndpoint, err)
|
||||
}
|
||||
|
||||
endpointsLock.Lock()
|
||||
defer endpointsLock.Unlock()
|
||||
|
||||
_, ok := endpoints[e.Path]
|
||||
if ok {
|
||||
return ErrAlreadyRegistered
|
||||
}
|
||||
|
||||
endpoints[e.Path] = &e
|
||||
endpointsMux.Handle(apiV1Path+e.Path, &e)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetEndpointByPath returns the endpoint registered with the given path.
|
||||
func GetEndpointByPath(path string) (*Endpoint, error) {
|
||||
endpointsLock.Lock()
|
||||
defer endpointsLock.Unlock()
|
||||
endpoint, ok := endpoints[path]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no registered endpoint on path: %q", path)
|
||||
}
|
||||
|
||||
return endpoint, nil
|
||||
}
|
||||
|
||||
func (e *Endpoint) check() error {
|
||||
// Check path.
|
||||
if strings.TrimSpace(e.Path) == "" {
|
||||
return errors.New("path is missing")
|
||||
}
|
||||
|
||||
// Check permissions.
|
||||
if e.Read < Dynamic || e.Read > PermitSelf {
|
||||
return errors.New("invalid read permission")
|
||||
}
|
||||
if e.Write < Dynamic || e.Write > PermitSelf {
|
||||
return errors.New("invalid write permission")
|
||||
}
|
||||
|
||||
// Check methods.
|
||||
if e.Read != NotSupported {
|
||||
switch e.ReadMethod {
|
||||
case http.MethodGet:
|
||||
// All good.
|
||||
case "":
|
||||
// Set to default.
|
||||
e.ReadMethod = http.MethodGet
|
||||
default:
|
||||
return errors.New("invalid read method")
|
||||
}
|
||||
} else {
|
||||
e.ReadMethod = ""
|
||||
}
|
||||
if e.Write != NotSupported {
|
||||
switch e.WriteMethod {
|
||||
case http.MethodPost,
|
||||
http.MethodPut,
|
||||
http.MethodDelete:
|
||||
// All good.
|
||||
case "":
|
||||
// Set to default.
|
||||
e.WriteMethod = http.MethodPost
|
||||
default:
|
||||
return errors.New("invalid write method")
|
||||
}
|
||||
} else {
|
||||
e.WriteMethod = ""
|
||||
}
|
||||
|
||||
// Check functions.
|
||||
var defaultMimeType string
|
||||
fnCnt := 0
|
||||
if e.ActionFunc != nil {
|
||||
fnCnt++
|
||||
defaultMimeType = MimeTypeText
|
||||
}
|
||||
if e.DataFunc != nil {
|
||||
fnCnt++
|
||||
defaultMimeType = MimeTypeText
|
||||
}
|
||||
if e.StructFunc != nil {
|
||||
fnCnt++
|
||||
defaultMimeType = MimeTypeJSON
|
||||
}
|
||||
if e.RecordFunc != nil {
|
||||
fnCnt++
|
||||
defaultMimeType = MimeTypeJSON
|
||||
}
|
||||
if e.HandlerFunc != nil {
|
||||
fnCnt++
|
||||
defaultMimeType = MimeTypeText
|
||||
}
|
||||
if fnCnt != 1 {
|
||||
return errors.New("only one function may be set")
|
||||
}
|
||||
|
||||
// Set default mime type.
|
||||
if e.MimeType == "" {
|
||||
e.MimeType = defaultMimeType
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExportEndpoints exports the registered endpoints. The returned data must be
|
||||
// treated as immutable.
|
||||
func ExportEndpoints() []*Endpoint {
|
||||
endpointsLock.RLock()
|
||||
defer endpointsLock.RUnlock()
|
||||
|
||||
// Copy the map into a slice.
|
||||
eps := make([]*Endpoint, 0, len(endpoints))
|
||||
for _, ep := range endpoints {
|
||||
eps = append(eps, ep)
|
||||
}
|
||||
|
||||
sort.Sort(sortByPath(eps))
|
||||
return eps
|
||||
}
|
||||
|
||||
type sortByPath []*Endpoint
|
||||
|
||||
func (eps sortByPath) Len() int { return len(eps) }
|
||||
func (eps sortByPath) Less(i, j int) bool { return eps[i].Path < eps[j].Path }
|
||||
func (eps sortByPath) Swap(i, j int) { eps[i], eps[j] = eps[j], eps[i] }
|
||||
|
||||
type endpointHandler struct{}
|
||||
|
||||
var _ AuthenticatedHandler = &endpointHandler{} // Compile time interface check.
|
||||
|
||||
// ReadPermission returns the read permission for the handler.
|
||||
func (eh *endpointHandler) ReadPermission(r *http.Request) Permission {
|
||||
apiEndpoint, _ := getAPIContext(r)
|
||||
if apiEndpoint != nil {
|
||||
return apiEndpoint.Read
|
||||
}
|
||||
return NotFound
|
||||
}
|
||||
|
||||
// WritePermission returns the write permission for the handler.
|
||||
func (eh *endpointHandler) WritePermission(r *http.Request) Permission {
|
||||
apiEndpoint, _ := getAPIContext(r)
|
||||
if apiEndpoint != nil {
|
||||
return apiEndpoint.Write
|
||||
}
|
||||
return NotFound
|
||||
}
|
||||
|
||||
// ServeHTTP handles the http request.
|
||||
func (eh *endpointHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
apiEndpoint, apiRequest := getAPIContext(r)
|
||||
if apiEndpoint == nil || apiRequest == nil {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
apiEndpoint.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// ServeHTTP handles the http request.
|
||||
func (e *Endpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
_, apiRequest := getAPIContext(r)
|
||||
if apiRequest == nil {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Return OPTIONS request before starting to handle normal requests.
|
||||
if r.Method == http.MethodOptions {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
eMethod, readMethod, ok := getEffectiveMethod(r)
|
||||
if !ok {
|
||||
http.Error(w, "unsupported method for the actions API", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
if readMethod {
|
||||
if eMethod != e.ReadMethod {
|
||||
log.Tracer(r.Context()).Warningf(
|
||||
"api: method %q does not match required read method %q%s",
|
||||
r.Method,
|
||||
e.ReadMethod,
|
||||
" - this will be an error and abort the request in the future",
|
||||
)
|
||||
}
|
||||
} else {
|
||||
if eMethod != e.WriteMethod {
|
||||
log.Tracer(r.Context()).Warningf(
|
||||
"api: method %q does not match required write method %q%s",
|
||||
r.Method,
|
||||
e.WriteMethod,
|
||||
" - this will be an error and abort the request in the future",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
switch eMethod {
|
||||
case http.MethodGet, http.MethodDelete:
|
||||
// Nothing to do for these.
|
||||
case http.MethodPost, http.MethodPut:
|
||||
// Read body data.
|
||||
inputData, ok := readBody(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
apiRequest.InputData = inputData
|
||||
|
||||
// restore request body for any http.HandlerFunc below
|
||||
r.Body = io.NopCloser(bytes.NewReader(inputData))
|
||||
default:
|
||||
// Defensive.
|
||||
http.Error(w, "unsupported method for the actions API", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// Add response headers to request struct so that the endpoint can work with them.
|
||||
apiRequest.ResponseHeader = w.Header()
|
||||
|
||||
// Execute action function and get response data
|
||||
var responseData []byte
|
||||
var err error
|
||||
|
||||
switch {
|
||||
case e.ActionFunc != nil:
|
||||
var msg string
|
||||
msg, err = e.ActionFunc(apiRequest)
|
||||
if !strings.HasSuffix(msg, "\n") {
|
||||
msg += "\n"
|
||||
}
|
||||
if err == nil {
|
||||
responseData = []byte(msg)
|
||||
}
|
||||
|
||||
case e.DataFunc != nil:
|
||||
responseData, err = e.DataFunc(apiRequest)
|
||||
|
||||
case e.StructFunc != nil:
|
||||
var v interface{}
|
||||
v, err = e.StructFunc(apiRequest)
|
||||
if err == nil && v != nil {
|
||||
var mimeType string
|
||||
responseData, mimeType, _, err = dsd.MimeDump(v, r.Header.Get("Accept"))
|
||||
if err == nil {
|
||||
w.Header().Set("Content-Type", mimeType)
|
||||
}
|
||||
}
|
||||
|
||||
case e.RecordFunc != nil:
|
||||
var rec record.Record
|
||||
rec, err = e.RecordFunc(apiRequest)
|
||||
if err == nil && r != nil {
|
||||
responseData, err = MarshalRecord(rec, false)
|
||||
}
|
||||
|
||||
case e.HandlerFunc != nil:
|
||||
e.HandlerFunc(w, r)
|
||||
return
|
||||
|
||||
default:
|
||||
http.Error(w, "missing handler", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Check for handler error.
|
||||
if err != nil {
|
||||
var statusProvider HTTPStatusProvider
|
||||
if errors.As(err, &statusProvider) {
|
||||
http.Error(w, err.Error(), statusProvider.HTTPStatus())
|
||||
} else {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Return no content if there is none, or if request is HEAD.
|
||||
if len(responseData) == 0 || r.Method == http.MethodHead {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
// Set content type if not yet set.
|
||||
if w.Header().Get("Content-Type") == "" {
|
||||
w.Header().Set("Content-Type", e.MimeType+"; charset=utf-8")
|
||||
}
|
||||
|
||||
// Write response.
|
||||
w.Header().Set("Content-Length", strconv.Itoa(len(responseData)))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, err = w.Write(responseData)
|
||||
if err != nil {
|
||||
log.Tracer(r.Context()).Warningf("api: failed to write response: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func readBody(w http.ResponseWriter, r *http.Request) (inputData []byte, ok bool) {
|
||||
// Check for too long content in order to prevent death.
|
||||
if r.ContentLength > 20000000 { // 20MB
|
||||
http.Error(w, "too much input data", http.StatusRequestEntityTooLarge)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Read and close body.
|
||||
inputData, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to read body"+err.Error(), http.StatusInternalServerError)
|
||||
return nil, false
|
||||
}
|
||||
return inputData, true
|
||||
}
|
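Endpoint functions can attach their own HTTP status code by wrapping the error with ErrorWithStatus; a sketch, where the "widget" endpoint and its URL variable are made up:

package api

import (
	"errors"
	"net/http"
)

// getWidget returns 404 instead of the generic 500 when the resource is missing.
func getWidget(ar *Request) (interface{}, error) {
	id, ok := ar.URLVars["id"]
	if !ok || id == "" {
		return nil, ErrorWithStatus(errors.New("widget not found"), http.StatusNotFound)
	}
	return map[string]string{"ID": id}, nil
}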
24
base/api/endpoints_config.go
Normal file
|
@@ -0,0 +1,24 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"github.com/safing/portmaster/base/config"
|
||||
)
|
||||
|
||||
func registerConfigEndpoints() error {
|
||||
if err := RegisterEndpoint(Endpoint{
|
||||
Path: "config/options",
|
||||
Read: PermitAnyone,
|
||||
MimeType: MimeTypeJSON,
|
||||
StructFunc: listConfig,
|
||||
Name: "Export Configuration Options",
|
||||
Description: "Returns a list of all registered configuration options and their metadata. This does not include the current active or default settings.",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func listConfig(ar *Request) (i interface{}, err error) {
|
||||
return config.ExportOptions(), nil
|
||||
}
|
249
base/api/endpoints_debug.go
Normal file
|
@@ -0,0 +1,249 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime/pprof"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/base/info"
|
||||
"github.com/safing/portmaster/base/utils/debug"
|
||||
)
|
||||
|
||||
func registerDebugEndpoints() error {
|
||||
if err := RegisterEndpoint(Endpoint{
|
||||
Path: "ping",
|
||||
Read: PermitAnyone,
|
||||
ActionFunc: ping,
|
||||
Name: "Ping",
|
||||
Description: "Pong.",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := RegisterEndpoint(Endpoint{
|
||||
Path: "ready",
|
||||
Read: PermitAnyone,
|
||||
ActionFunc: ready,
|
||||
Name: "Ready",
|
||||
Description: "Check if Portmaster has completed starting and is ready.",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := RegisterEndpoint(Endpoint{
|
||||
Path: "debug/stack",
|
||||
Read: PermitAnyone,
|
||||
DataFunc: getStack,
|
||||
Name: "Get Goroutine Stack",
|
||||
Description: "Returns the current goroutine stack.",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := RegisterEndpoint(Endpoint{
|
||||
Path: "debug/stack/print",
|
||||
Read: PermitAnyone,
|
||||
ActionFunc: printStack,
|
||||
Name: "Print Goroutine Stack",
|
||||
Description: "Prints the current goroutine stack to stdout.",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := RegisterEndpoint(Endpoint{
|
||||
Path: "debug/cpu",
|
||||
MimeType: "application/octet-stream",
|
||||
Read: PermitAnyone,
|
||||
DataFunc: handleCPUProfile,
|
||||
Name: "Get CPU Profile",
|
||||
Description: strings.ReplaceAll(`Gather and return the CPU profile.
|
||||
This data needs to be gathered over a period of time, which is specified using the duration parameter.
|
||||
|
||||
You can easily view this data in your browser with this command (with Go installed):
|
||||
"go tool pprof -http :8888 http://127.0.0.1:817/api/v1/debug/cpu"
|
||||
`, `"`, "`"),
|
||||
Parameters: []Parameter{{
|
||||
Method: http.MethodGet,
|
||||
Field: "duration",
|
||||
Value: "10s",
|
||||
Description: "Specify the formatting style. The default is simple markdown formatting.",
|
||||
}},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := RegisterEndpoint(Endpoint{
|
||||
Path: "debug/heap",
|
||||
MimeType: "application/octet-stream",
|
||||
Read: PermitAnyone,
|
||||
DataFunc: handleHeapProfile,
|
||||
Name: "Get Heap Profile",
|
||||
Description: strings.ReplaceAll(`Gather and return the heap memory profile.
|
||||
|
||||
You can easily view this data in your browser with this command (with Go installed):
|
||||
"go tool pprof -http :8888 http://127.0.0.1:817/api/v1/debug/heap"
|
||||
`, `"`, "`"),
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := RegisterEndpoint(Endpoint{
|
||||
Path: "debug/allocs",
|
||||
MimeType: "application/octet-stream",
|
||||
Read: PermitAnyone,
|
||||
DataFunc: handleAllocsProfile,
|
||||
Name: "Get Allocs Profile",
|
||||
Description: strings.ReplaceAll(`Gather and return the memory allocation profile.
|
||||
|
||||
You can easily view this data in your browser with this command (with Go installed):
|
||||
"go tool pprof -http :8888 http://127.0.0.1:817/api/v1/debug/allocs"
|
||||
`, `"`, "`"),
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := RegisterEndpoint(Endpoint{
|
||||
Path: "debug/info",
|
||||
Read: PermitAnyone,
|
||||
DataFunc: debugInfo,
|
||||
Name: "Get Debug Information",
|
||||
Description: "Returns debugging information, including the version and platform info, errors, logs and the current goroutine stack.",
|
||||
Parameters: []Parameter{{
|
||||
Method: http.MethodGet,
|
||||
Field: "style",
|
||||
Value: "github",
|
||||
Description: "Specify the formatting style. The default is simple markdown formatting.",
|
||||
}},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ping responds with pong.
|
||||
func ping(ar *Request) (msg string, err error) {
|
||||
return "Pong.", nil
|
||||
}
|
||||
|
||||
// ready checks if Portmaster has completed starting.
|
||||
func ready(ar *Request) (msg string, err error) {
|
||||
if !module.instance.Ready() {
|
||||
return "", ErrorWithStatus(errors.New("portmaster is not ready, reload (F5) to try again"), http.StatusTooEarly)
|
||||
}
|
||||
return "Portmaster is ready.", nil
|
||||
}
|
||||
|
||||
// getStack returns the current goroutine stack.
|
||||
func getStack(_ *Request) (data []byte, err error) {
|
||||
buf := &bytes.Buffer{}
|
||||
err = pprof.Lookup("goroutine").WriteTo(buf, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// printStack prints the current goroutine stack to stderr.
|
||||
func printStack(_ *Request) (msg string, err error) {
|
||||
_, err = fmt.Fprint(os.Stderr, "===== PRINTING STACK =====\n")
|
||||
if err == nil {
|
||||
err = pprof.Lookup("goroutine").WriteTo(os.Stderr, 1)
|
||||
}
|
||||
if err == nil {
|
||||
_, err = fmt.Fprint(os.Stderr, "===== END OF STACK =====\n")
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return "stack printed to stdout", nil
|
||||
}
|
||||
|
||||
// handleCPUProfile returns the CPU profile.
|
||||
func handleCPUProfile(ar *Request) (data []byte, err error) {
|
||||
// Parse duration.
|
||||
duration := 10 * time.Second
|
||||
if durationOption := ar.Request.URL.Query().Get("duration"); durationOption != "" {
|
||||
parsedDuration, err := time.ParseDuration(durationOption)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse duration: %w", err)
|
||||
}
|
||||
duration = parsedDuration
|
||||
}
|
||||
|
||||
// Indicate download and filename.
|
||||
ar.ResponseHeader.Set(
|
||||
"Content-Disposition",
|
||||
fmt.Sprintf(`attachment; filename="portmaster-cpu-profile_v%s.pprof"`, info.Version()),
|
||||
)
|
||||
|
||||
// Start CPU profiling.
|
||||
buf := new(bytes.Buffer)
|
||||
if err := pprof.StartCPUProfile(buf); err != nil {
|
||||
return nil, fmt.Errorf("failed to start cpu profile: %w", err)
|
||||
}
|
||||
|
||||
// Wait for the specified duration.
|
||||
select {
|
||||
case <-time.After(duration):
|
||||
case <-ar.Context().Done():
|
||||
pprof.StopCPUProfile()
|
||||
return nil, context.Canceled
|
||||
}
|
||||
|
||||
// Stop CPU profiling and return data.
|
||||
pprof.StopCPUProfile()
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// handleHeapProfile returns the Heap profile.
|
||||
func handleHeapProfile(ar *Request) (data []byte, err error) {
|
||||
// Indicate download and filename.
|
||||
ar.ResponseHeader.Set(
|
||||
"Content-Disposition",
|
||||
fmt.Sprintf(`attachment; filename="portmaster-memory-heap-profile_v%s.pprof"`, info.Version()),
|
||||
)
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
if err := pprof.Lookup("heap").WriteTo(buf, 0); err != nil {
|
||||
return nil, fmt.Errorf("failed to write heap profile: %w", err)
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// handleAllocsProfile returns the Allocs profile.
|
||||
func handleAllocsProfile(ar *Request) (data []byte, err error) {
|
||||
// Indicate download and filename.
|
||||
ar.ResponseHeader.Set(
|
||||
"Content-Disposition",
|
||||
fmt.Sprintf(`attachment; filename="portmaster-memory-allocs-profile_v%s.pprof"`, info.Version()),
|
||||
)
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
if err := pprof.Lookup("allocs").WriteTo(buf, 0); err != nil {
|
||||
return nil, fmt.Errorf("failed to write allocs profile: %w", err)
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// debugInfo returns the debugging information for support requests.
|
||||
func debugInfo(ar *Request) (data []byte, err error) {
|
||||
// Create debug information helper.
|
||||
di := new(debug.Info)
|
||||
di.Style = ar.Request.URL.Query().Get("style")
|
||||
|
||||
// Add debug information.
|
||||
di.AddVersionInfo()
|
||||
di.AddPlatformInfo(ar.Context())
|
||||
di.AddLastUnexpectedLogs()
|
||||
di.AddGoroutineStack()
|
||||
|
||||
// Return data.
|
||||
return di.Bytes(), nil
|
||||
}
|
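To show how the profiling endpoints are consumed, here is a small sketch that downloads a CPU profile; it assumes the default listen address 127.0.0.1:817 from the descriptions above and that the caller is permitted to access the endpoint.

package main

import (
	"io"
	"log"
	"net/http"
	"os"
)

func main() {
	// Gather a 30 second CPU profile via the debug endpoint.
	resp, err := http.Get("http://127.0.0.1:817/api/v1/debug/cpu?duration=30s")
	if err != nil {
		log.Fatalf("request failed: %s", err)
	}
	defer resp.Body.Close()

	data, err := io.ReadAll(resp.Body)
	if err != nil {
		log.Fatalf("read failed: %s", err)
	}
	// Inspect afterwards with: go tool pprof -http :8888 portmaster-cpu.pprof
	if err := os.WriteFile("portmaster-cpu.pprof", data, 0o644); err != nil {
		log.Fatalf("write failed: %s", err)
	}
}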
140
base/api/endpoints_meta.go
Normal file
|
@@ -0,0 +1,140 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func registerMetaEndpoints() error {
|
||||
if err := RegisterEndpoint(Endpoint{
|
||||
Path: "endpoints",
|
||||
Read: PermitAnyone,
|
||||
MimeType: MimeTypeJSON,
|
||||
DataFunc: listEndpoints,
|
||||
Name: "Export API Endpoints",
|
||||
Description: "Returns a list of all registered endpoints and their metadata.",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := RegisterEndpoint(Endpoint{
|
||||
Path: "auth/permissions",
|
||||
Read: Dynamic,
|
||||
StructFunc: permissions,
|
||||
Name: "View Current Permissions",
|
||||
Description: "Returns the current permissions assigned to the request.",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := RegisterEndpoint(Endpoint{
|
||||
Path: "auth/bearer",
|
||||
Read: Dynamic,
|
||||
HandlerFunc: authBearer,
|
||||
Name: "Request HTTP Bearer Auth",
|
||||
Description: "Returns an HTTP Bearer Auth request, if not authenticated.",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := RegisterEndpoint(Endpoint{
|
||||
Path: "auth/basic",
|
||||
Read: Dynamic,
|
||||
HandlerFunc: authBasic,
|
||||
Name: "Request HTTP Basic Auth",
|
||||
Description: "Returns an HTTP Basic Auth request, if not authenticated.",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := RegisterEndpoint(Endpoint{
|
||||
Path: "auth/reset",
|
||||
Read: PermitAnyone,
|
||||
HandlerFunc: authReset,
|
||||
Name: "Reset Authenticated Session",
|
||||
Description: "Resets authentication status internally and in the browser.",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func listEndpoints(ar *Request) (data []byte, err error) {
|
||||
data, err = json.Marshal(ExportEndpoints())
|
||||
return
|
||||
}
|
||||
|
||||
func permissions(ar *Request) (i interface{}, err error) {
|
||||
if ar.AuthToken == nil {
|
||||
return nil, errors.New("authentication token missing")
|
||||
}
|
||||
|
||||
return struct {
|
||||
Read Permission
|
||||
Write Permission
|
||||
ReadRole string
|
||||
WriteRole string
|
||||
}{
|
||||
Read: ar.AuthToken.Read,
|
||||
Write: ar.AuthToken.Write,
|
||||
ReadRole: ar.AuthToken.Read.Role(),
|
||||
WriteRole: ar.AuthToken.Write.Role(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func authBearer(w http.ResponseWriter, r *http.Request) {
|
||||
// Check if authenticated by checking read permission.
|
||||
ar := GetAPIRequest(r)
|
||||
if ar.AuthToken.Read != PermitAnyone {
|
||||
TextResponse(w, r, "Authenticated.")
|
||||
return
|
||||
}
|
||||
|
||||
// Respond with desired authentication header.
|
||||
w.Header().Set(
|
||||
"WWW-Authenticate",
|
||||
`Bearer realm="Portmaster API" domain="/"`,
|
||||
)
|
||||
http.Error(w, "Authorization required.", http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
func authBasic(w http.ResponseWriter, r *http.Request) {
|
||||
// Check if authenticated by checking read permission.
|
||||
ar := GetAPIRequest(r)
|
||||
if ar.AuthToken.Read != PermitAnyone {
|
||||
TextResponse(w, r, "Authenticated.")
|
||||
return
|
||||
}
|
||||
|
||||
// Respond with desired authentication header.
|
||||
w.Header().Set(
|
||||
"WWW-Authenticate",
|
||||
`Basic realm="Portmaster API" domain="/"`,
|
||||
)
|
||||
http.Error(w, "Authorization required.", http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
func authReset(w http.ResponseWriter, r *http.Request) {
|
||||
// Get session cookie from request and delete session if exists.
|
||||
c, err := r.Cookie(sessionCookieName)
|
||||
if err == nil {
|
||||
deleteSession(c.Value)
|
||||
}
|
||||
|
||||
// Delete session and cookie.
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: sessionCookieName,
|
||||
MaxAge: -1, // MaxAge<0 means delete cookie now, equivalently 'Max-Age: 0'
|
||||
})
|
||||
|
||||
// Request client to also reset all data.
|
||||
w.Header().Set("Clear-Site-Data", "*")
|
||||
|
||||
// Set HTTP Auth Realm without requesting authorization.
|
||||
w.Header().Set("WWW-Authenticate", `None realm="Portmaster API"`)
|
||||
|
||||
// Reply with 401 Unauthorized in order to clear HTTP Basic Auth data.
|
||||
http.Error(w, "Session deleted.", http.StatusUnauthorized)
|
||||
}
|
161
base/api/endpoints_test.go
Normal file
|
@@ -0,0 +1,161 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
)
|
||||
|
||||
const (
|
||||
successMsg = "endpoint api success"
|
||||
failedMsg = "endpoint api failed"
|
||||
)
|
||||
|
||||
type actionTestRecord struct {
|
||||
record.Base
|
||||
sync.Mutex
|
||||
Msg string
|
||||
}
|
||||
|
||||
func TestEndpoints(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testHandler := &mainHandler{
|
||||
mux: mainMux,
|
||||
}
|
||||
|
||||
// ActionFn
|
||||
|
||||
assert.NoError(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/action",
|
||||
Read: PermitAnyone,
|
||||
ActionFunc: func(_ *Request) (msg string, err error) {
|
||||
return successMsg, nil
|
||||
},
|
||||
}))
|
||||
assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/action", nil, successMsg)
|
||||
|
||||
assert.NoError(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/action-err",
|
||||
Read: PermitAnyone,
|
||||
ActionFunc: func(_ *Request) (msg string, err error) {
|
||||
return "", errors.New(failedMsg)
|
||||
},
|
||||
}))
|
||||
assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/action-err", nil, failedMsg)
|
||||
|
||||
// DataFn
|
||||
|
||||
assert.NoError(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/data",
|
||||
Read: PermitAnyone,
|
||||
DataFunc: func(_ *Request) (data []byte, err error) {
|
||||
return []byte(successMsg), nil
|
||||
},
|
||||
}))
|
||||
assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/data", nil, successMsg)
|
||||
|
||||
assert.NoError(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/data-err",
|
||||
Read: PermitAnyone,
|
||||
DataFunc: func(_ *Request) (data []byte, err error) {
|
||||
return nil, errors.New(failedMsg)
|
||||
},
|
||||
}))
|
||||
assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/data-err", nil, failedMsg)
|
||||
|
||||
// StructFn
|
||||
|
||||
assert.NoError(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/struct",
|
||||
Read: PermitAnyone,
|
||||
StructFunc: func(_ *Request) (i interface{}, err error) {
|
||||
return &actionTestRecord{
|
||||
Msg: successMsg,
|
||||
}, nil
|
||||
},
|
||||
}))
|
||||
assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/struct", nil, successMsg)
|
||||
|
||||
assert.NoError(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/struct-err",
|
||||
Read: PermitAnyone,
|
||||
StructFunc: func(_ *Request) (i interface{}, err error) {
|
||||
return nil, errors.New(failedMsg)
|
||||
},
|
||||
}))
|
||||
assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/struct-err", nil, failedMsg)
|
||||
|
||||
// RecordFn
|
||||
|
||||
assert.NoError(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/record",
|
||||
Read: PermitAnyone,
|
||||
RecordFunc: func(_ *Request) (r record.Record, err error) {
|
||||
r = &actionTestRecord{
|
||||
Msg: successMsg,
|
||||
}
|
||||
r.CreateMeta()
|
||||
return r, nil
|
||||
},
|
||||
}))
|
||||
assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/record", nil, successMsg)
|
||||
|
||||
assert.NoError(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/record-err",
|
||||
Read: PermitAnyone,
|
||||
RecordFunc: func(_ *Request) (r record.Record, err error) {
|
||||
return nil, errors.New(failedMsg)
|
||||
},
|
||||
}))
|
||||
assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/record-err", nil, failedMsg)
|
||||
}
|
||||
|
||||
func TestActionRegistration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert.Error(t, RegisterEndpoint(Endpoint{}))
|
||||
|
||||
assert.Error(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/err",
|
||||
Read: NotFound,
|
||||
}))
|
||||
assert.Error(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/err",
|
||||
Read: PermitSelf + 1,
|
||||
}))
|
||||
|
||||
assert.Error(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/err",
|
||||
Write: NotFound,
|
||||
}))
|
||||
assert.Error(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/err",
|
||||
Write: PermitSelf + 1,
|
||||
}))
|
||||
|
||||
assert.Error(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/err",
|
||||
}))
|
||||
|
||||
assert.Error(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/err",
|
||||
ActionFunc: func(_ *Request) (msg string, err error) {
|
||||
return successMsg, nil
|
||||
},
|
||||
DataFunc: func(_ *Request) (data []byte, err error) {
|
||||
return []byte(successMsg), nil
|
||||
},
|
||||
}))
|
||||
|
||||
assert.NoError(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/err",
|
||||
ActionFunc: func(_ *Request) (msg string, err error) {
|
||||
return successMsg, nil
|
||||
},
|
||||
}))
|
||||
}
|
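Note: the registrations above double as usage examples; over the wire, such an endpoint is reached through the regular HTTP server. A rough sketch (requires the net/http and io imports), assuming the v1 prefix resolves to /api/v1/ and a reachable listen address:

func callTestAction() (string, error) {
	// Both the prefix and the address are assumptions for this sketch.
	resp, err := http.Get("http://127.0.0.1:817/api/v1/test/action")
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body) // contains successMsg on success
	return string(body), err
}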
base/api/enriched-response.go (new file, 68 lines)
@@ -0,0 +1,68 @@
package api
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/safing/portmaster/base/log"
|
||||
)
|
||||
|
||||
// LoggingResponseWriter is a wrapper for http.ResponseWriter for better request logging.
|
||||
type LoggingResponseWriter struct {
|
||||
ResponseWriter http.ResponseWriter
|
||||
Request *http.Request
|
||||
Status int
|
||||
}
|
||||
|
||||
// NewLoggingResponseWriter wraps a http.ResponseWriter.
|
||||
func NewLoggingResponseWriter(w http.ResponseWriter, r *http.Request) *LoggingResponseWriter {
|
||||
return &LoggingResponseWriter{
|
||||
ResponseWriter: w,
|
||||
Request: r,
|
||||
}
|
||||
}
|
||||
|
||||
// Header wraps the original Header method.
|
||||
func (lrw *LoggingResponseWriter) Header() http.Header {
|
||||
return lrw.ResponseWriter.Header()
|
||||
}
|
||||
|
||||
// Write wraps the original Write method.
|
||||
func (lrw *LoggingResponseWriter) Write(b []byte) (int, error) {
|
||||
return lrw.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
// WriteHeader wraps the original WriteHeader method to extract information.
|
||||
func (lrw *LoggingResponseWriter) WriteHeader(code int) {
|
||||
lrw.Status = code
|
||||
lrw.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
// Hijack wraps the original Hijack method, if available.
|
||||
func (lrw *LoggingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
hijacker, ok := lrw.ResponseWriter.(http.Hijacker)
|
||||
if ok {
|
||||
c, b, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
log.Tracer(lrw.Request.Context()).Infof("api request: %s HIJ %s", lrw.Request.RemoteAddr, lrw.Request.RequestURI)
|
||||
return c, b, nil
|
||||
}
|
||||
return nil, nil, errors.New("response does not implement http.Hijacker")
|
||||
}
|
||||
|
||||
// RequestLogger is a logging middleware.
|
||||
func RequestLogger(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
log.Tracer(r.Context()).Tracef("api request: %s ___ %s", r.RemoteAddr, r.RequestURI)
|
||||
lrw := NewLoggingResponseWriter(w, r)
|
||||
next.ServeHTTP(lrw, r)
|
||||
if lrw.Status != 0 {
|
||||
// request may have been hijacked
|
||||
log.Tracer(r.Context()).Infof("api request: %s %d %s", lrw.Request.RemoteAddr, lrw.Status, lrw.Request.RequestURI)
|
||||
}
|
||||
})
|
||||
}
|
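Note: RequestLogger composes like any other net/http middleware. A minimal sketch of wrapping a handler with it; the handler and listen address are made up for illustration.

package main

import (
	"net/http"

	"github.com/safing/portmaster/base/api"
)

func main() {
	hello := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_, _ = w.Write([]byte("hello"))
	})

	// Wrap the handler so every request and its status code get traced.
	logged := api.RequestLogger(hello)
	_ = http.ListenAndServe("127.0.0.1:8081", logged)
}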
base/api/init_test.go (new file, 38 lines)
@@ -0,0 +1,38 @@
package api
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/safing/portmaster/base/config"
|
||||
)
|
||||
|
||||
type testInstance struct {
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
var _ instance = &testInstance{}
|
||||
|
||||
func (stub *testInstance) Config() *config.Config {
|
||||
return stub.config
|
||||
}
|
||||
|
||||
func (stub *testInstance) SetCmdLineOperation(f func() error) {}
|
||||
|
||||
func (stub *testInstance) Ready() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
SetDefaultAPIListenAddress("0.0.0.0:8080")
|
||||
instance := &testInstance{}
|
||||
var err error
|
||||
module, err = New(instance)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = SetAuthenticator(testAuthenticator)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
m.Run()
|
||||
}
|
base/api/main.go (new file, 82 lines)
@@ -0,0 +1,82 @@
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/service/mgr"
|
||||
)
|
||||
|
||||
var exportEndpoints bool
|
||||
|
||||
// API Errors.
|
||||
var (
|
||||
ErrAuthenticationAlreadySet = errors.New("the authentication function has already been set")
|
||||
ErrAuthenticationImmutable = errors.New("the authentication function can only be set before the api has started")
|
||||
)
|
||||
|
||||
func init() {
|
||||
flag.BoolVar(&exportEndpoints, "export-api-endpoints", false, "export api endpoint registry and exit")
|
||||
}
|
||||
|
||||
func prep() error {
|
||||
// Register endpoints.
|
||||
if err := registerConfig(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := registerDebugEndpoints(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := registerConfigEndpoints(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := registerMetaEndpoints(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if exportEndpoints {
|
||||
module.instance.SetCmdLineOperation(exportEndpointsCmd)
|
||||
return mgr.ErrExecuteCmdLineOp
|
||||
}
|
||||
|
||||
if getDefaultListenAddress() == "" {
|
||||
return errors.New("no default listen address for api available")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func start() error {
|
||||
startServer()
|
||||
|
||||
updateAPIKeys()
|
||||
module.instance.Config().EventConfigChange.AddCallback("update API keys",
|
||||
func(wc *mgr.WorkerCtx, s struct{}) (cancel bool, err error) {
|
||||
updateAPIKeys()
|
||||
return false, nil
|
||||
})
|
||||
|
||||
// start api auth token cleaner
|
||||
if authFnSet.IsSet() {
|
||||
_ = module.mgr.Repeat("clean api sessions", 5*time.Minute, cleanSessions)
|
||||
}
|
||||
|
||||
return registerEndpointBridgeDB()
|
||||
}
|
||||
|
||||
func stop() error {
|
||||
return stopServer()
|
||||
}
|
||||
|
||||
func exportEndpointsCmd() error {
|
||||
data, err := json.MarshalIndent(ExportEndpoints(), "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = os.Stdout.Write(data)
|
||||
return err
|
||||
}
|
base/api/module.go (new file, 65 lines)
@@ -0,0 +1,65 @@
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/safing/portmaster/base/config"
|
||||
"github.com/safing/portmaster/service/mgr"
|
||||
)
|
||||
|
||||
// API is the HTTP/Websockets API module.
|
||||
type API struct {
|
||||
mgr *mgr.Manager
|
||||
instance instance
|
||||
|
||||
online atomic.Bool
|
||||
}
|
||||
|
||||
func (api *API) Manager() *mgr.Manager {
|
||||
return api.mgr
|
||||
}
|
||||
|
||||
// Start starts the module.
|
||||
func (api *API) Start() error {
|
||||
if err := start(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
api.online.Store(true)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the module.
|
||||
func (api *API) Stop() error {
|
||||
defer api.online.Store(false)
|
||||
return stop()
|
||||
}
|
||||
|
||||
var (
|
||||
shimLoaded atomic.Bool
|
||||
module *API
|
||||
)
|
||||
|
||||
// New returns a new API module.
|
||||
func New(instance instance) (*API, error) {
|
||||
if !shimLoaded.CompareAndSwap(false, true) {
|
||||
return nil, errors.New("only one instance allowed")
|
||||
}
|
||||
m := mgr.New("API")
|
||||
module = &API{
|
||||
mgr: m,
|
||||
instance: instance,
|
||||
}
|
||||
|
||||
if err := prep(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return module, nil
|
||||
}
|
||||
|
||||
type instance interface {
|
||||
Config() *config.Config
|
||||
SetCmdLineOperation(f func() error)
|
||||
Ready() bool
|
||||
}
|
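Note: the instance interface at the end of this file mirrors the testInstance from init_test.go. A sketch of how an embedding program could satisfy it and construct the module; the type and helper names are assumptions, not code from this commit.

package example // illustration only

import (
	"github.com/safing/portmaster/base/api"
	"github.com/safing/portmaster/base/config"
)

// myInstance mirrors the testInstance pattern from init_test.go.
type myInstance struct {
	config *config.Config
}

func (i *myInstance) Config() *config.Config             { return i.config }
func (i *myInstance) SetCmdLineOperation(f func() error) {}
func (i *myInstance) Ready() bool                        { return true }

func setupAPI(cfg *config.Config) (*api.API, error) {
	// New runs prep() and registers the built-in endpoints;
	// calling Start() later brings up the HTTP server.
	return api.New(&myInstance{config: cfg})
}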
base/api/request.go (new file, 60 lines)
@@ -0,0 +1,60 @@
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/safing/portmaster/base/log"
|
||||
)
|
||||
|
||||
// Request is a support struct that pools request-related information.
|
||||
type Request struct {
|
||||
// Request is the http request.
|
||||
*http.Request
|
||||
|
||||
// InputData contains the request body for write operations.
|
||||
InputData []byte
|
||||
|
||||
// Route of this request.
|
||||
Route *mux.Route
|
||||
|
||||
// URLVars contains the URL variables extracted by the gorilla mux.
|
||||
URLVars map[string]string
|
||||
|
||||
// AuthToken is the request-side authentication token assigned.
|
||||
AuthToken *AuthToken
|
||||
|
||||
// ResponseHeader holds the response header.
|
||||
ResponseHeader http.Header
|
||||
|
||||
// HandlerCache can be used by handlers to cache data between handlers within a request.
|
||||
HandlerCache interface{}
|
||||
}
|
||||
|
||||
// apiRequestContextKey is a key used for the context key/value storage.
|
||||
type apiRequestContextKey struct{}
|
||||
|
||||
// RequestContextKey is the key used to add the API request to the context.
|
||||
var RequestContextKey = apiRequestContextKey{}
|
||||
|
||||
// GetAPIRequest returns the API Request of the given http request.
|
||||
func GetAPIRequest(r *http.Request) *Request {
|
||||
ar, ok := r.Context().Value(RequestContextKey).(*Request)
|
||||
if ok {
|
||||
return ar
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TextResponse writes a text response.
|
||||
func TextResponse(w http.ResponseWriter, r *http.Request, text string) {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, err := fmt.Fprintln(w, text)
|
||||
if err != nil {
|
||||
log.Tracer(r.Context()).Warningf("api: failed to write text response: %s", err)
|
||||
}
|
||||
}
|
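Note: raw handlers registered via RegisterHandler can pull the enriched Request back out of the context. A small sketch, assuming the api and net/http imports; the "name" URL variable is hypothetical.

func greetHandler(w http.ResponseWriter, r *http.Request) {
	ar := api.GetAPIRequest(r)
	if ar == nil {
		http.Error(w, "not an api request", http.StatusInternalServerError)
		return
	}
	// "name" is a hypothetical mux variable, e.g. from a route like /v1/greet/{name}.
	api.TextResponse(w, r, "hello "+ar.URLVars["name"])
}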
base/api/router.go (new file, 329 lines)
@@ -0,0 +1,329 @@
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/safing/portmaster/base/log"
|
||||
"github.com/safing/portmaster/base/utils"
|
||||
"github.com/safing/portmaster/service/mgr"
|
||||
)
|
||||
|
||||
// EnableServer defines if the HTTP server should be started.
|
||||
var EnableServer = true
|
||||
|
||||
var (
|
||||
// mainMux is the main mux router.
|
||||
mainMux = mux.NewRouter()
|
||||
|
||||
// server is the main server.
|
||||
server = &http.Server{
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
}
|
||||
handlerLock sync.RWMutex
|
||||
|
||||
allowedDevCORSOrigins = []string{
|
||||
"127.0.0.1",
|
||||
"localhost",
|
||||
}
|
||||
)
|
||||
|
||||
// RegisterHandler registers a handler with the API endpoint.
|
||||
func RegisterHandler(path string, handler http.Handler) *mux.Route {
|
||||
handlerLock.Lock()
|
||||
defer handlerLock.Unlock()
|
||||
return mainMux.Handle(path, handler)
|
||||
}
|
||||
|
||||
// RegisterHandleFunc registers a handle function with the API endpoint.
|
||||
func RegisterHandleFunc(path string, handleFunc func(http.ResponseWriter, *http.Request)) *mux.Route {
|
||||
handlerLock.Lock()
|
||||
defer handlerLock.Unlock()
|
||||
return mainMux.HandleFunc(path, handleFunc)
|
||||
}
|
||||
|
||||
func startServer() {
|
||||
// Check if server is enabled.
|
||||
if !EnableServer {
|
||||
return
|
||||
}
|
||||
|
||||
// Configure server.
|
||||
server.Addr = listenAddressConfig()
|
||||
server.Handler = &mainHandler{
|
||||
// TODO: mainMux should not be modified anymore.
|
||||
mux: mainMux,
|
||||
}
|
||||
|
||||
// Start server manager.
|
||||
module.mgr.Go("http server manager", serverManager)
|
||||
}
|
||||
|
||||
func stopServer() error {
|
||||
// Check if server is enabled.
|
||||
if !EnableServer {
|
||||
return nil
|
||||
}
|
||||
|
||||
if server.Addr != "" {
|
||||
return server.Shutdown(context.Background())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// serverManager runs and supervises the HTTP server of the API endpoint.
|
||||
func serverManager(ctx *mgr.WorkerCtx) error {
|
||||
// start serving
|
||||
log.Infof("api: starting to listen on %s", server.Addr)
|
||||
backoffDuration := 10 * time.Second
|
||||
for {
|
||||
err := module.mgr.Do("http server", func(ctx *mgr.WorkerCtx) error {
|
||||
err := server.ListenAndServe()
|
||||
// return on shutdown error
|
||||
if errors.Is(err, http.ErrServerClosed) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
})
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
// log error and restart
|
||||
log.Errorf("api: http endpoint failed: %s - restarting in %s", err, backoffDuration)
|
||||
time.Sleep(backoffDuration)
|
||||
}
|
||||
}
|
||||
|
||||
type mainHandler struct {
|
||||
mux *mux.Router
|
||||
}
|
||||
|
||||
func (mh *mainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
_ = module.mgr.Do("http request", func(_ *mgr.WorkerCtx) error {
|
||||
return mh.handle(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (mh *mainHandler) handle(w http.ResponseWriter, r *http.Request) error {
|
||||
// Setup context trace logging.
|
||||
ctx, tracer := log.AddTracer(r.Context())
|
||||
// Add request context.
|
||||
apiRequest := &Request{
|
||||
Request: r,
|
||||
}
|
||||
ctx = context.WithValue(ctx, RequestContextKey, apiRequest)
|
||||
// Add context back to request.
|
||||
r = r.WithContext(ctx)
|
||||
lrw := NewLoggingResponseWriter(w, r)
|
||||
|
||||
tracer.Tracef("api request: %s ___ %s %s", r.RemoteAddr, lrw.Request.Method, r.RequestURI)
|
||||
defer func() {
|
||||
// Log request status.
|
||||
if lrw.Status != 0 {
|
||||
// If lrw.Status is 0, the request may have been hijacked.
|
||||
tracer.Debugf("api request: %s %d %s %s", lrw.Request.RemoteAddr, lrw.Status, lrw.Request.Method, lrw.Request.RequestURI)
|
||||
}
|
||||
tracer.Submit()
|
||||
}()
|
||||
|
||||
// Add security headers.
|
||||
w.Header().Set("Referrer-Policy", "same-origin")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
w.Header().Set("X-Frame-Options", "deny")
|
||||
w.Header().Set("X-XSS-Protection", "1; mode=block")
|
||||
w.Header().Set("X-DNS-Prefetch-Control", "off")
|
||||
|
||||
// Add CSP Header in production mode.
|
||||
if !devMode() {
|
||||
w.Header().Set(
|
||||
"Content-Security-Policy",
|
||||
"default-src 'self'; "+
|
||||
"connect-src https://*.safing.io 'self'; "+
|
||||
"style-src 'self' 'unsafe-inline'; "+
|
||||
"img-src 'self' data: blob:",
|
||||
)
|
||||
}
|
||||
|
||||
// Check Cross-Origin Requests.
|
||||
origin := r.Header.Get("Origin")
|
||||
isPreflighCheck := false
|
||||
if origin != "" {
|
||||
|
||||
// Parse origin URL.
|
||||
originURL, err := url.Parse(origin)
|
||||
if err != nil {
|
||||
tracer.Warningf("api: denied request from %s: failed to parse origin header: %s", r.RemoteAddr, err)
|
||||
http.Error(lrw, "Invalid Origin.", http.StatusForbidden)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if the Origin matches the Host.
|
||||
switch {
|
||||
case originURL.Host == r.Host:
|
||||
// Origin (with port) matches Host.
|
||||
case originURL.Hostname() == r.Host:
|
||||
// Origin (without port) matches Host.
|
||||
case originURL.Scheme == "chrome-extension":
|
||||
// Allow access for the browser extension
|
||||
// TODO(ppacher):
|
||||
// This currently allows access from any browser extension.
|
||||
// Can we reduce that to only our browser extension?
|
||||
// Also, what do we need to support Firefox?
|
||||
case devMode() &&
|
||||
utils.StringInSlice(allowedDevCORSOrigins, originURL.Hostname()):
|
||||
// We are in dev mode and the request is coming from the allowed
|
||||
// development origins.
|
||||
default:
|
||||
// Origin and Host do NOT match!
|
||||
tracer.Warningf("api: denied request from %s: Origin (`%s`) and Host (`%s`) do not match", r.RemoteAddr, origin, r.Host)
|
||||
http.Error(lrw, "Cross-Origin Request Denied.", http.StatusForbidden)
|
||||
return nil
|
||||
|
||||
// If the Host header has a port, and the Origin does not, requests will
|
||||
// also end up here, as we cannot properly check for equality.
|
||||
}
|
||||
|
||||
// Add Cross-Origin headers, as we need them in any case now.
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
w.Header().Set("Access-Control-Allow-Methods", "*")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "*")
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
w.Header().Set("Access-Control-Expose-Headers", "*")
|
||||
w.Header().Set("Access-Control-Max-Age", "60")
|
||||
w.Header().Add("Vary", "Origin")
|
||||
|
||||
// If there's an Access-Control-Request-Method header, this is a preflight check.
|
||||
// In that case, we will just check if the preflighMethod is allowed and then return
|
||||
// success here
|
||||
if preflighMethod := r.Header.Get("Access-Control-Request-Method"); r.Method == http.MethodOptions && preflighMethod != "" {
|
||||
isPreflighCheck = true
|
||||
}
|
||||
}
|
||||
|
||||
// Clean URL.
|
||||
cleanedRequestPath := cleanRequestPath(r.URL.Path)
|
||||
|
||||
// If the cleaned path differs from the original one, redirect to it.
|
||||
if r.URL.Path != cleanedRequestPath {
|
||||
redirURL := *r.URL
|
||||
redirURL.Path = cleanedRequestPath
|
||||
http.Redirect(lrw, r, redirURL.String(), http.StatusMovedPermanently)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get handler for request.
|
||||
// Gorilla does not support handling this on our own very well.
|
||||
// See github.com/gorilla/mux.ServeHTTP for reference.
|
||||
var match mux.RouteMatch
|
||||
var handler http.Handler
|
||||
if mh.mux.Match(r, &match) {
|
||||
handler = match.Handler
|
||||
apiRequest.Route = match.Route
|
||||
apiRequest.URLVars = match.Vars
|
||||
}
|
||||
switch {
|
||||
case match.MatchErr == nil:
|
||||
// All good.
|
||||
case errors.Is(match.MatchErr, mux.ErrMethodMismatch):
|
||||
http.Error(lrw, "Method not allowed.", http.StatusMethodNotAllowed)
|
||||
return nil
|
||||
default:
|
||||
tracer.Debug("api: no handler registered for this path")
|
||||
http.Error(lrw, "Not found.", http.StatusNotFound)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Make sure that URLVars is always a map.
|
||||
if apiRequest.URLVars == nil {
|
||||
apiRequest.URLVars = make(map[string]string)
|
||||
}
|
||||
|
||||
// Check method.
|
||||
_, readMethod, ok := getEffectiveMethod(r)
|
||||
if !ok {
|
||||
http.Error(lrw, "Method not allowed.", http.StatusMethodNotAllowed)
|
||||
return nil
|
||||
}
|
||||
|
||||
// At this point we know the method is allowed and there's a handler for the request.
|
||||
// If this is just a CORS-Preflight, we'll accept the request with StatusOK now.
|
||||
// There's no point in trying to authenticate the request because the Browser will
|
||||
// not send authentication along a preflight check.
|
||||
if isPreflighCheck && handler != nil {
|
||||
lrw.WriteHeader(http.StatusOK)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check authentication.
|
||||
apiRequest.AuthToken = authenticateRequest(lrw, r, handler, readMethod)
|
||||
if apiRequest.AuthToken == nil {
|
||||
// Authenticator already replied.
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if we have a handler.
|
||||
if handler == nil {
|
||||
http.Error(lrw, "Not found.", http.StatusNotFound)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Format panics in handler.
|
||||
defer func() {
|
||||
if panicValue := recover(); panicValue != nil {
|
||||
// Log failure.
|
||||
log.Errorf("api: handler panic: %s", panicValue)
|
||||
// Respond with a server error.
|
||||
if devMode() {
|
||||
http.Error(
|
||||
lrw,
|
||||
fmt.Sprintf(
|
||||
"Internal Server Error: %s\n\n%s",
|
||||
panicValue,
|
||||
debug.Stack(),
|
||||
),
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
} else {
|
||||
http.Error(lrw, "Internal Server Error.", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Handle with registered handler.
|
||||
handler.ServeHTTP(lrw, r)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanRequestPath cleans and returns a request path.
|
||||
func cleanRequestPath(requestPath string) string {
|
||||
// If the request URL is empty, return a request for "root".
|
||||
if requestPath == "" || requestPath == "/" {
|
||||
return "/"
|
||||
}
|
||||
// If the request URL does not start with a slash, prepend it.
|
||||
if !strings.HasPrefix(requestPath, "/") {
|
||||
requestPath = "/" + requestPath
|
||||
}
|
||||
|
||||
// Clean path to remove any relative parts.
|
||||
cleanedRequestPath := path.Clean(requestPath)
|
||||
// Because path.Clean removes a trailing slash, we need to add it back here
|
||||
// if the original URL had one.
|
||||
if strings.HasSuffix(requestPath, "/") {
|
||||
cleanedRequestPath += "/"
|
||||
}
|
||||
|
||||
return cleanedRequestPath
|
||||
}
|
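Note: the cleaning keeps trailing slashes but removes relative segments; a few illustrative input/output pairs (expectations derived from the function above, not part of the test suite):

// cleanRequestPath("")              == "/"
// cleanRequestPath("v1/ping")       == "/v1/ping"
// cleanRequestPath("/v1/../config") == "/config"
// cleanRequestPath("/v1/dir/")      == "/v1/dir/"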
base/api/testclient/root/index.html (new file, 49 lines)
@@ -0,0 +1,49 @@
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title></title>
|
||||
<!-- <script src="https://cdn.jsdelivr.net/sockjs/1/sockjs.min.js"></script> -->
|
||||
</head>
|
||||
<body>
|
||||
<script type="text/javascript">
|
||||
|
||||
var ws = new WebSocket('ws://127.0.0.1:18/api/database/v1')
|
||||
|
||||
ws.onopen = function () {
|
||||
console.log('open');
|
||||
};
|
||||
|
||||
ws.onerror = function (error) {
|
||||
console.log('error');
|
||||
console.log(error);
|
||||
};
|
||||
|
||||
ws.onmessage = function (e) {
|
||||
reader = new FileReader()
|
||||
reader.onload = function(e) {
|
||||
console.log(e.target.result)
|
||||
}
|
||||
reader.readAsText(e.data)
|
||||
};
|
||||
|
||||
function send(text) {
|
||||
ws.send(text)
|
||||
}
|
||||
|
||||
// var sock = new SockJS("http://localhost:8080/api/v1");
|
||||
//
|
||||
// sock.onopen = function() {
|
||||
// console.log('open');
|
||||
// };
|
||||
//
|
||||
// sock.onmessage = function(e) {
|
||||
// console.log('message received: ', e.data);
|
||||
// };
|
||||
//
|
||||
// sock.onclose = function(e) {
|
||||
// console.log('close', e);
|
||||
// };
|
||||
</script>
|
||||
yeeee
|
||||
</body>
|
||||
</html>
|
base/api/testclient/serve.go (new file, 11 lines)
@@ -0,0 +1,11 @@
package testclient

import (
	"net/http"

	"github.com/safing/portmaster/base/api"
)

func init() {
	api.RegisterHandler("/test/", http.StripPrefix("/test/", http.FileServer(http.Dir("./api/testclient/root/"))))
}
base/apprise/notify.go (new file, 167 lines)
@@ -0,0 +1,167 @@
package apprise
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/safing/portmaster/base/utils"
|
||||
)
|
||||
|
||||
// Notifier sends messages to an Apprise API.
|
||||
type Notifier struct {
|
||||
// URL defines the Apprise API endpoint.
|
||||
URL string
|
||||
|
||||
// DefaultType defines the default message type.
|
||||
DefaultType MsgType
|
||||
|
||||
// DefaultTag defines the default message tag.
|
||||
DefaultTag string
|
||||
|
||||
// DefaultFormat defines the default message format.
|
||||
DefaultFormat MsgFormat
|
||||
|
||||
// AllowUntagged defines if untagged messages are allowed,
|
||||
// which are sent to all configured apprise endpoints.
|
||||
AllowUntagged bool
|
||||
|
||||
client *http.Client
|
||||
clientLock sync.Mutex
|
||||
}
|
||||
|
||||
// Message represents the message to be sent to the Apprise API.
|
||||
type Message struct {
|
||||
// Title is an optional title to go along with the body.
|
||||
Title string `json:"title,omitempty"`
|
||||
|
||||
// Body is the main message content. This is the only required field.
|
||||
Body string `json:"body"`
|
||||
|
||||
// Type defines the message type you want to send as.
|
||||
// The valid options are info, success, warning, and failure.
|
||||
// If no type is specified then info is the default value used.
|
||||
Type MsgType `json:"type,omitempty"`
|
||||
|
||||
// Tag is used to notify only those tagged accordingly.
|
||||
// Use a comma (,) to OR your tags and a space ( ) to AND them.
|
||||
Tag string `json:"tag,omitempty"`
|
||||
|
||||
// Format optionally identifies the text format of the data you're feeding Apprise.
|
||||
// The valid options are text, markdown, html.
|
||||
// The default value if nothing is specified is text.
|
||||
Format MsgFormat `json:"format,omitempty"`
|
||||
}
|
||||
|
||||
// MsgType defines the message type.
|
||||
type MsgType string
|
||||
|
||||
// Message Types.
|
||||
const (
|
||||
TypeInfo MsgType = "info"
|
||||
TypeSuccess MsgType = "success"
|
||||
TypeWarning MsgType = "warning"
|
||||
TypeFailure MsgType = "failure"
|
||||
)
|
||||
|
||||
// MsgFormat defines the message format.
|
||||
type MsgFormat string
|
||||
|
||||
// Message Formats.
|
||||
const (
|
||||
FormatText MsgFormat = "text"
|
||||
FormatMarkdown MsgFormat = "markdown"
|
||||
FormatHTML MsgFormat = "html"
|
||||
)
|
||||
|
||||
type errorResponse struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
// Send sends a message to the Apprise API.
|
||||
func (n *Notifier) Send(ctx context.Context, m *Message) error {
|
||||
// Check if the message has a body.
|
||||
if m.Body == "" {
|
||||
return errors.New("the message must have a body")
|
||||
}
|
||||
|
||||
// Apply notifier defaults.
|
||||
n.applyDefaults(m)
|
||||
|
||||
// Check if the message is tagged.
|
||||
if m.Tag == "" && !n.AllowUntagged {
|
||||
return errors.New("the message must have a tag")
|
||||
}
|
||||
|
||||
// Marshal the message to JSON.
|
||||
payload, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal message: %w", err)
|
||||
}
|
||||
|
||||
// Create request.
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodPost, n.URL, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Send message to API.
|
||||
resp, err := n.getClient().Do(request)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send message: %w", err)
|
||||
}
|
||||
defer resp.Body.Close() //nolint:errcheck,gosec
|
||||
switch resp.StatusCode {
|
||||
case http.StatusOK, http.StatusCreated, http.StatusNoContent, http.StatusAccepted:
|
||||
return nil
|
||||
default:
|
||||
// Try to tease body contents.
|
||||
if body, err := io.ReadAll(resp.Body); err == nil && len(body) > 0 {
|
||||
// Try to parse json response.
|
||||
errorResponse := &errorResponse{}
|
||||
if err := json.Unmarshal(body, errorResponse); err == nil && errorResponse.Error != "" {
|
||||
return fmt.Errorf("failed to send message: apprise returned %q with an error message: %s", resp.Status, errorResponse.Error)
|
||||
}
|
||||
return fmt.Errorf("failed to send message: %s (body teaser: %s)", resp.Status, utils.SafeFirst16Bytes(body))
|
||||
}
|
||||
return fmt.Errorf("failed to send message: %s", resp.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Notifier) applyDefaults(m *Message) {
|
||||
if m.Type == "" {
|
||||
m.Type = n.DefaultType
|
||||
}
|
||||
if m.Tag == "" {
|
||||
m.Tag = n.DefaultTag
|
||||
}
|
||||
if m.Format == "" {
|
||||
m.Format = n.DefaultFormat
|
||||
}
|
||||
}
|
||||
|
||||
// SetClient sets a custom http client for accessing the Apprise API.
|
||||
func (n *Notifier) SetClient(client *http.Client) {
|
||||
n.clientLock.Lock()
|
||||
defer n.clientLock.Unlock()
|
||||
|
||||
n.client = client
|
||||
}
|
||||
|
||||
func (n *Notifier) getClient() *http.Client {
|
||||
n.clientLock.Lock()
|
||||
defer n.clientLock.Unlock()
|
||||
|
||||
// Create client if needed.
|
||||
if n.client == nil {
|
||||
n.client = &http.Client{}
|
||||
}
|
||||
|
||||
return n.client
|
||||
}
|
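Note: a minimal sketch of using the Notifier against a self-hosted Apprise instance; the URL and message contents are placeholder values.

package example // illustration only

import (
	"context"

	"github.com/safing/portmaster/base/apprise"
)

func notifyExample(ctx context.Context) error {
	n := &apprise.Notifier{
		// Placeholder endpoint of a self-hosted Apprise API instance.
		URL:           "http://localhost:8000/notify/portmaster",
		DefaultType:   apprise.TypeInfo,
		DefaultFormat: apprise.FormatText,
		AllowUntagged: true,
	}
	return n.Send(ctx, &apprise.Message{
		Title: "Example",
		Body:  "Hello from the apprise package.",
	})
}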
base/config/basic_config.go (new file, 106 lines)
@@ -0,0 +1,106 @@
package config
|
||||
|
||||
import (
|
||||
"flag"
|
||||
|
||||
"github.com/safing/portmaster/base/log"
|
||||
"github.com/safing/portmaster/service/mgr"
|
||||
)
|
||||
|
||||
// Configuration Keys.
|
||||
var (
|
||||
CfgDevModeKey = "core/devMode"
|
||||
defaultDevMode bool
|
||||
|
||||
CfgLogLevel = "core/log/level"
|
||||
defaultLogLevel = log.InfoLevel.String()
|
||||
logLevel StringOption
|
||||
)
|
||||
|
||||
func init() {
|
||||
flag.BoolVar(&defaultDevMode, "devmode", false, "enable development mode; configuration is stronger")
|
||||
}
|
||||
|
||||
func registerBasicOptions() error {
|
||||
// Get the default log level from the log package.
|
||||
defaultLogLevel = log.GetLogLevel().Name()
|
||||
|
||||
// Register logging setting.
|
||||
// The log package cannot do that itself, as it would trigger an import loop.
|
||||
if err := Register(&Option{
|
||||
Name: "Log Level",
|
||||
Key: CfgLogLevel,
|
||||
Description: "Configure the logging level.",
|
||||
OptType: OptTypeString,
|
||||
ExpertiseLevel: ExpertiseLevelDeveloper,
|
||||
ReleaseLevel: ReleaseLevelStable,
|
||||
DefaultValue: defaultLogLevel,
|
||||
Annotations: Annotations{
|
||||
DisplayOrderAnnotation: 513,
|
||||
DisplayHintAnnotation: DisplayHintOneOf,
|
||||
CategoryAnnotation: "Development",
|
||||
},
|
||||
PossibleValues: []PossibleValue{
|
||||
{
|
||||
Name: "Critical",
|
||||
Value: "critical",
|
||||
Description: "The critical level only logs errors that lead to a partial, but imminent failure.",
|
||||
},
|
||||
{
|
||||
Name: "Error",
|
||||
Value: "error",
|
||||
Description: "The error level logs errors that potentially break functionality. Everything logged by the critical level is included here too.",
|
||||
},
|
||||
{
|
||||
Name: "Warning",
|
||||
Value: "warning",
|
||||
Description: "The warning level logs minor errors and worse. Everything logged by the error level is included here too.",
|
||||
},
|
||||
{
|
||||
Name: "Info",
|
||||
Value: "info",
|
||||
Description: "The info level logs the main events that are going on and are interesting to the user. Everything logged by the warning level is included here too.",
|
||||
},
|
||||
{
|
||||
Name: "Debug",
|
||||
Value: "debug",
|
||||
Description: "The debug level logs some additional debugging details. Everything logged by the info level is included here too.",
|
||||
},
|
||||
{
|
||||
Name: "Trace",
|
||||
Value: "trace",
|
||||
Description: "The trace level logs loads of detailed information as well as operation and request traces. Everything logged by the debug level is included here too.",
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
logLevel = GetAsString(CfgLogLevel, defaultLogLevel)
|
||||
|
||||
// Register to hook to update the log level.
|
||||
module.EventConfigChange.AddCallback("update log level", setLogLevel)
|
||||
|
||||
return Register(&Option{
|
||||
Name: "Development Mode",
|
||||
Key: CfgDevModeKey,
|
||||
Description: "In Development Mode, security restrictions are lifted/softened to enable unrestricted access for debugging and testing purposes.",
|
||||
OptType: OptTypeBool,
|
||||
ExpertiseLevel: ExpertiseLevelDeveloper,
|
||||
ReleaseLevel: ReleaseLevelStable,
|
||||
DefaultValue: defaultDevMode,
|
||||
Annotations: Annotations{
|
||||
DisplayOrderAnnotation: 512,
|
||||
CategoryAnnotation: "Development",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func loadLogLevel() error {
|
||||
return setDefaultConfigOption(CfgLogLevel, log.GetLogLevel().Name(), false)
|
||||
}
|
||||
|
||||
func setLogLevel(_ *mgr.WorkerCtx, _ struct{}) (cancel bool, err error) {
|
||||
log.SetLogLevel(log.ParseLevel(logLevel()))
|
||||
|
||||
return false, nil
|
||||
}
|
base/config/database.go (new file, 169 lines)
@@ -0,0 +1,169 @@
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/safing/portmaster/base/database"
|
||||
"github.com/safing/portmaster/base/database/iterator"
|
||||
"github.com/safing/portmaster/base/database/query"
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
"github.com/safing/portmaster/base/database/storage"
|
||||
"github.com/safing/portmaster/base/log"
|
||||
)
|
||||
|
||||
var dbController *database.Controller
|
||||
|
||||
// StorageInterface provides a storage.Interface to the configuration manager.
|
||||
type StorageInterface struct {
|
||||
storage.InjectBase
|
||||
}
|
||||
|
||||
// Get returns a database record.
|
||||
func (s *StorageInterface) Get(key string) (record.Record, error) {
|
||||
opt, err := GetOption(key)
|
||||
if err != nil {
|
||||
return nil, storage.ErrNotFound
|
||||
}
|
||||
|
||||
return opt.Export()
|
||||
}
|
||||
|
||||
// Put stores a record in the database.
|
||||
func (s *StorageInterface) Put(r record.Record) (record.Record, error) {
|
||||
if r.Meta().Deleted > 0 {
|
||||
return r, setConfigOption(r.DatabaseKey(), nil, false)
|
||||
}
|
||||
|
||||
acc := r.GetAccessor(r)
|
||||
if acc == nil {
|
||||
return nil, errors.New("invalid data")
|
||||
}
|
||||
|
||||
val, ok := acc.Get("Value")
|
||||
if !ok || val == nil {
|
||||
err := setConfigOption(r.DatabaseKey(), nil, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.Get(r.DatabaseKey())
|
||||
}
|
||||
|
||||
option, err := GetOption(r.DatabaseKey())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var value interface{}
|
||||
switch option.OptType {
|
||||
case OptTypeString:
|
||||
value, ok = acc.GetString("Value")
|
||||
case OptTypeStringArray:
|
||||
value, ok = acc.GetStringArray("Value")
|
||||
case OptTypeInt:
|
||||
value, ok = acc.GetInt("Value")
|
||||
case OptTypeBool:
|
||||
value, ok = acc.GetBool("Value")
|
||||
case optTypeAny:
|
||||
ok = false
|
||||
}
|
||||
if !ok {
|
||||
return nil, errors.New("received invalid value in \"Value\"")
|
||||
}
|
||||
|
||||
if err := setConfigOption(r.DatabaseKey(), value, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return option.Export()
|
||||
}
|
||||
|
||||
// Delete deletes a record from the database.
|
||||
func (s *StorageInterface) Delete(key string) error {
|
||||
return setConfigOption(key, nil, false)
|
||||
}
|
||||
|
||||
// Query returns an iterator for the supplied query.
|
||||
func (s *StorageInterface) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) {
|
||||
optionsLock.RLock()
|
||||
defer optionsLock.RUnlock()
|
||||
|
||||
it := iterator.New()
|
||||
var opts []*Option
|
||||
for _, opt := range options {
|
||||
if strings.HasPrefix(opt.Key, q.DatabaseKeyPrefix()) {
|
||||
opts = append(opts, opt)
|
||||
}
|
||||
}
|
||||
|
||||
go s.processQuery(it, opts)
|
||||
|
||||
return it, nil
|
||||
}
|
||||
|
||||
func (s *StorageInterface) processQuery(it *iterator.Iterator, opts []*Option) {
|
||||
sort.Sort(sortByKey(opts))
|
||||
|
||||
for _, opt := range opts {
|
||||
r, err := opt.Export()
|
||||
if err != nil {
|
||||
it.Finish(err)
|
||||
return
|
||||
}
|
||||
it.Next <- r
|
||||
}
|
||||
|
||||
it.Finish(nil)
|
||||
}
|
||||
|
||||
// ReadOnly returns whether the database is read only.
|
||||
func (s *StorageInterface) ReadOnly() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func registerAsDatabase() error {
|
||||
_, err := database.Register(&database.Database{
|
||||
Name: "config",
|
||||
Description: "Configuration Manager",
|
||||
StorageType: "injected",
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
controller, err := database.InjectDatabase("config", &StorageInterface{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dbController = controller
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleOptionUpdate updates the expertise and release level options,
|
||||
// if required, and eventually pushes an update for the option.
|
||||
// The caller must hold the option lock.
|
||||
func handleOptionUpdate(option *Option, push bool) {
|
||||
if expertiseLevelOptionFlag.IsSet() && option == expertiseLevelOption {
|
||||
updateExpertiseLevel()
|
||||
}
|
||||
|
||||
if releaseLevelOptionFlag.IsSet() && option == releaseLevelOption {
|
||||
updateReleaseLevel()
|
||||
}
|
||||
|
||||
if push {
|
||||
pushUpdate(option)
|
||||
}
|
||||
}
|
||||
|
||||
// pushUpdate pushes a database update notification for the option.
|
||||
// The caller must hold the option lock.
|
||||
func pushUpdate(option *Option) {
|
||||
r, err := option.export()
|
||||
if err != nil {
|
||||
log.Errorf("failed to export option to push update: %s", err)
|
||||
} else {
|
||||
dbController.PushUpdate(r)
|
||||
}
|
||||
}
|
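Note: because the config module registers itself as an injected database, settings can also be read through the database layer. A rough sketch; the "config:" key prefix follows the usual <database>:<key> scheme, and passing nil options to database.NewInterface is an assumption about its defaults, not something shown in this diff.

func readDevModeViaDB() (record.Record, error) {
	db := database.NewInterface(nil) // default interface options (assumption)
	// Database keys are prefixed with the database name, here "config".
	return db.Get("config:core/devMode")
}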
base/config/doc.go (new file, 2 lines)
@@ -0,0 +1,2 @@
// Package config provides a versatile configuration management system.
package config
base/config/expertise.go (new file, 104 lines)
@@ -0,0 +1,104 @@
package config
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
)
|
||||
|
||||
// ExpertiseLevel allows grouping settings by user expertise.
|
||||
// It's useful if complex or technical settings should be hidden
|
||||
// from the average user while still allowing experts and developers
|
||||
// to change deep configuration settings.
|
||||
type ExpertiseLevel uint8
|
||||
|
||||
// Expertise Level constants.
|
||||
const (
|
||||
ExpertiseLevelUser ExpertiseLevel = 0
|
||||
ExpertiseLevelExpert ExpertiseLevel = 1
|
||||
ExpertiseLevelDeveloper ExpertiseLevel = 2
|
||||
|
||||
ExpertiseLevelNameUser = "user"
|
||||
ExpertiseLevelNameExpert = "expert"
|
||||
ExpertiseLevelNameDeveloper = "developer"
|
||||
|
||||
expertiseLevelKey = "core/expertiseLevel"
|
||||
)
|
||||
|
||||
var (
|
||||
expertiseLevelOption *Option
|
||||
expertiseLevel = new(int32)
|
||||
expertiseLevelOptionFlag = abool.New()
|
||||
)
|
||||
|
||||
func init() {
|
||||
registerExpertiseLevelOption()
|
||||
}
|
||||
|
||||
func registerExpertiseLevelOption() {
|
||||
expertiseLevelOption = &Option{
|
||||
Name: "UI Mode",
|
||||
Key: expertiseLevelKey,
|
||||
Description: "Control the default amount of settings and information shown. Hidden settings are still in effect. Can be changed temporarily in the top right corner.",
|
||||
OptType: OptTypeString,
|
||||
ExpertiseLevel: ExpertiseLevelUser,
|
||||
ReleaseLevel: ReleaseLevelStable,
|
||||
DefaultValue: ExpertiseLevelNameUser,
|
||||
Annotations: Annotations{
|
||||
DisplayOrderAnnotation: -16,
|
||||
DisplayHintAnnotation: DisplayHintOneOf,
|
||||
CategoryAnnotation: "User Interface",
|
||||
},
|
||||
PossibleValues: []PossibleValue{
|
||||
{
|
||||
Name: "Simple Interface",
|
||||
Value: ExpertiseLevelNameUser,
|
||||
Description: "Hide complex settings and information.",
|
||||
},
|
||||
{
|
||||
Name: "Advanced Interface",
|
||||
Value: ExpertiseLevelNameExpert,
|
||||
Description: "Show technical details.",
|
||||
},
|
||||
{
|
||||
Name: "Developer Interface",
|
||||
Value: ExpertiseLevelNameDeveloper,
|
||||
Description: "Developer mode. Please be careful!",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := Register(expertiseLevelOption)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
expertiseLevelOptionFlag.Set()
|
||||
}
|
||||
|
||||
func updateExpertiseLevel() {
|
||||
// get value
|
||||
value := expertiseLevelOption.activeFallbackValue
|
||||
if expertiseLevelOption.activeValue != nil {
|
||||
value = expertiseLevelOption.activeValue
|
||||
}
|
||||
if expertiseLevelOption.activeDefaultValue != nil {
|
||||
value = expertiseLevelOption.activeDefaultValue
|
||||
}
|
||||
// set atomic value
|
||||
switch value.stringVal {
|
||||
case ExpertiseLevelNameUser:
|
||||
atomic.StoreInt32(expertiseLevel, int32(ExpertiseLevelUser))
|
||||
case ExpertiseLevelNameExpert:
|
||||
atomic.StoreInt32(expertiseLevel, int32(ExpertiseLevelExpert))
|
||||
case ExpertiseLevelNameDeveloper:
|
||||
atomic.StoreInt32(expertiseLevel, int32(ExpertiseLevelDeveloper))
|
||||
default:
|
||||
atomic.StoreInt32(expertiseLevel, int32(ExpertiseLevelUser))
|
||||
}
|
||||
}
|
||||
|
||||
// GetExpertiseLevel returns the current active expertise level.
|
||||
func GetExpertiseLevel() uint8 {
|
||||
return uint8(atomic.LoadInt32(expertiseLevel))
|
||||
}
|
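Note: other code can use the atomically stored level to gate expert-only behaviour; a small sketch with a hypothetical helper.

// Only expose extra diagnostics to expert and developer users.
if config.GetExpertiseLevel() >= uint8(config.ExpertiseLevelExpert) {
	enableExtraDiagnostics() // hypothetical helper
}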
base/config/get-safe.go (new file, 112 lines)
@@ -0,0 +1,112 @@
package config
|
||||
|
||||
import "sync"
|
||||
|
||||
type safe struct{}
|
||||
|
||||
// Concurrent makes concurrency safe get methods available.
|
||||
var Concurrent = &safe{}
|
||||
|
||||
// GetAsString returns a function that returns the wanted string with high performance.
|
||||
func (cs *safe) GetAsString(name string, fallback string) StringOption {
|
||||
valid := getValidityFlag()
|
||||
option, valueCache := getValueCache(name, nil, OptTypeString)
|
||||
value := fallback
|
||||
if valueCache != nil {
|
||||
value = valueCache.stringVal
|
||||
}
|
||||
var lock sync.Mutex
|
||||
|
||||
return func() string {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
if !valid.IsSet() {
|
||||
valid = getValidityFlag()
|
||||
option, valueCache = getValueCache(name, option, OptTypeString)
|
||||
if valueCache != nil {
|
||||
value = valueCache.stringVal
|
||||
} else {
|
||||
value = fallback
|
||||
}
|
||||
}
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
// GetAsStringArray returns a function that returns the wanted string array with high performance.
|
||||
func (cs *safe) GetAsStringArray(name string, fallback []string) StringArrayOption {
|
||||
valid := getValidityFlag()
|
||||
option, valueCache := getValueCache(name, nil, OptTypeStringArray)
|
||||
value := fallback
|
||||
if valueCache != nil {
|
||||
value = valueCache.stringArrayVal
|
||||
}
|
||||
var lock sync.Mutex
|
||||
|
||||
return func() []string {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
if !valid.IsSet() {
|
||||
valid = getValidityFlag()
|
||||
option, valueCache = getValueCache(name, option, OptTypeStringArray)
|
||||
if valueCache != nil {
|
||||
value = valueCache.stringArrayVal
|
||||
} else {
|
||||
value = fallback
|
||||
}
|
||||
}
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
// GetAsInt returns a function that returns the wanted int with high performance.
|
||||
func (cs *safe) GetAsInt(name string, fallback int64) IntOption {
|
||||
valid := getValidityFlag()
|
||||
option, valueCache := getValueCache(name, nil, OptTypeInt)
|
||||
value := fallback
|
||||
if valueCache != nil {
|
||||
value = valueCache.intVal
|
||||
}
|
||||
var lock sync.Mutex
|
||||
|
||||
return func() int64 {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
if !valid.IsSet() {
|
||||
valid = getValidityFlag()
|
||||
option, valueCache = getValueCache(name, option, OptTypeInt)
|
||||
if valueCache != nil {
|
||||
value = valueCache.intVal
|
||||
} else {
|
||||
value = fallback
|
||||
}
|
||||
}
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
// GetAsBool returns a function that returns the wanted bool with high performance.
|
||||
func (cs *safe) GetAsBool(name string, fallback bool) BoolOption {
|
||||
valid := getValidityFlag()
|
||||
option, valueCache := getValueCache(name, nil, OptTypeBool)
|
||||
value := fallback
|
||||
if valueCache != nil {
|
||||
value = valueCache.boolVal
|
||||
}
|
||||
var lock sync.Mutex
|
||||
|
||||
return func() bool {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
if !valid.IsSet() {
|
||||
valid = getValidityFlag()
|
||||
option, valueCache = getValueCache(name, option, OptTypeBool)
|
||||
if valueCache != nil {
|
||||
value = valueCache.boolVal
|
||||
} else {
|
||||
value = fallback
|
||||
}
|
||||
}
|
||||
return value
|
||||
}
|
||||
}
|
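Note: these Concurrent variants behave like the plain getters in the next file but guard their cached value with a per-closure mutex, so the returned function itself may be shared across goroutines; a one-line sketch:

devMode := config.Concurrent.GetAsBool(config.CfgDevModeKey, false)
// devMode() is now safe to call concurrently from multiple goroutines.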
base/config/get.go (new file, 174 lines)
@@ -0,0 +1,174 @@
package config
|
||||
|
||||
import (
|
||||
"github.com/safing/portmaster/base/log"
|
||||
)
|
||||
|
||||
type (
|
||||
// StringOption defines the returned function by GetAsString.
|
||||
StringOption func() string
|
||||
// StringArrayOption defines the returned function by GetAsStringArray.
|
||||
StringArrayOption func() []string
|
||||
// IntOption defines the returned function by GetAsInt.
|
||||
IntOption func() int64
|
||||
// BoolOption defines the returned function by GetAsBool.
|
||||
BoolOption func() bool
|
||||
)
|
||||
|
||||
func getValueCache(name string, option *Option, requestedType OptionType) (*Option, *valueCache) {
|
||||
// get option
|
||||
if option == nil {
|
||||
var err error
|
||||
option, err = GetOption(name)
|
||||
if err != nil {
|
||||
log.Errorf("config: request for unregistered option: %s", name)
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Check the option type, no locking required as
|
||||
// OptType is immutable once it is set
|
||||
if requestedType != option.OptType {
|
||||
log.Errorf("config: bad type: requested %s as %s, but is %s", name, getTypeName(requestedType), getTypeName(option.OptType))
|
||||
return option, nil
|
||||
}
|
||||
|
||||
option.Lock()
|
||||
defer option.Unlock()
|
||||
|
||||
// check release level
|
||||
if option.ReleaseLevel <= getReleaseLevel() && option.activeValue != nil {
|
||||
return option, option.activeValue
|
||||
}
|
||||
|
||||
if option.activeDefaultValue != nil {
|
||||
return option, option.activeDefaultValue
|
||||
}
|
||||
|
||||
return option, option.activeFallbackValue
|
||||
}
|
||||
|
||||
// GetAsString returns a function that returns the wanted string with high performance.
|
||||
func GetAsString(name string, fallback string) StringOption {
|
||||
valid := getValidityFlag()
|
||||
option, valueCache := getValueCache(name, nil, OptTypeString)
|
||||
value := fallback
|
||||
if valueCache != nil {
|
||||
value = valueCache.stringVal
|
||||
}
|
||||
|
||||
return func() string {
|
||||
if !valid.IsSet() {
|
||||
valid = getValidityFlag()
|
||||
option, valueCache = getValueCache(name, option, OptTypeString)
|
||||
if valueCache != nil {
|
||||
value = valueCache.stringVal
|
||||
} else {
|
||||
value = fallback
|
||||
}
|
||||
}
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
// GetAsStringArray returns a function that returns the wanted string array with high performance.
|
||||
func GetAsStringArray(name string, fallback []string) StringArrayOption {
|
||||
valid := getValidityFlag()
|
||||
option, valueCache := getValueCache(name, nil, OptTypeStringArray)
|
||||
value := fallback
|
||||
if valueCache != nil {
|
||||
value = valueCache.stringArrayVal
|
||||
}
|
||||
|
||||
return func() []string {
|
||||
if !valid.IsSet() {
|
||||
valid = getValidityFlag()
|
||||
option, valueCache = getValueCache(name, option, OptTypeStringArray)
|
||||
if valueCache != nil {
|
||||
value = valueCache.stringArrayVal
|
||||
} else {
|
||||
value = fallback
|
||||
}
|
||||
}
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
// GetAsInt returns a function that returns the wanted int with high performance.
|
||||
func GetAsInt(name string, fallback int64) IntOption {
|
||||
valid := getValidityFlag()
|
||||
option, valueCache := getValueCache(name, nil, OptTypeInt)
|
||||
value := fallback
|
||||
if valueCache != nil {
|
||||
value = valueCache.intVal
|
||||
}
|
||||
|
||||
return func() int64 {
|
||||
if !valid.IsSet() {
|
||||
valid = getValidityFlag()
|
||||
option, valueCache = getValueCache(name, option, OptTypeInt)
|
||||
if valueCache != nil {
|
||||
value = valueCache.intVal
|
||||
} else {
|
||||
value = fallback
|
||||
}
|
||||
}
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
// GetAsBool returns a function that returns the wanted bool with high performance.
|
||||
func GetAsBool(name string, fallback bool) BoolOption {
|
||||
valid := getValidityFlag()
|
||||
option, valueCache := getValueCache(name, nil, OptTypeBool)
|
||||
value := fallback
|
||||
if valueCache != nil {
|
||||
value = valueCache.boolVal
|
||||
}
|
||||
|
||||
return func() bool {
|
||||
if !valid.IsSet() {
|
||||
valid = getValidityFlag()
|
||||
option, valueCache = getValueCache(name, option, OptTypeBool)
|
||||
if valueCache != nil {
|
||||
value = valueCache.boolVal
|
||||
} else {
|
||||
value = fallback
|
||||
}
|
||||
}
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
func getAndFindValue(key string) interface{} {
|
||||
optionsLock.RLock()
|
||||
option, ok := options[key]
|
||||
optionsLock.RUnlock()
|
||||
if !ok {
|
||||
log.Errorf("config: request for unregistered option: %s", key)
|
||||
return nil
|
||||
}
|
||||
|
||||
return option.findValue()
|
||||
}
|
||||
*/
|
||||
|
||||
/*
|
||||
// findValue finds the preferred value in the user or default config.
|
||||
func (option *Option) findValue() interface{} {
|
||||
// lock option
|
||||
option.Lock()
|
||||
defer option.Unlock()
|
||||
|
||||
if option.ReleaseLevel <= getReleaseLevel() && option.activeValue != nil {
|
||||
return option.activeValue
|
||||
}
|
||||
|
||||
if option.activeDefaultValue != nil {
|
||||
return option.activeDefaultValue
|
||||
}
|
||||
|
||||
return option.DefaultValue
|
||||
}
|
||||
*/
|
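Note: putting the pieces together, a module typically registers an option once and then holds on to the returned getter. A sketch with a made-up key, assuming the config import; the option values mirror the quickRegister pattern from the test file below.

var useCache config.BoolOption

func registerMyOptions() error {
	if err := config.Register(&config.Option{
		Name:           "Use Cache",
		Key:            "example/useCache", // hypothetical key
		Description:    "Enable the example cache.",
		OptType:        config.OptTypeBool,
		ExpertiseLevel: config.ExpertiseLevelUser,
		ReleaseLevel:   config.ReleaseLevelStable,
		DefaultValue:   true,
	}); err != nil {
		return err
	}
	useCache = config.GetAsBool("example/useCache", true)
	return nil
}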
base/config/get_test.go (new file, 368 lines)
@@ -0,0 +1,368 @@
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/safing/portmaster/base/log"
|
||||
)
|
||||
|
||||
func parseAndReplaceConfig(jsonData string) error {
|
||||
m, err := JSONToMap([]byte(jsonData))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
validationErrors, _ := ReplaceConfig(m)
|
||||
if len(validationErrors) > 0 {
|
||||
return fmt.Errorf("%d errors, first: %w", len(validationErrors), validationErrors[0])
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseAndReplaceDefaultConfig(jsonData string) error {
|
||||
m, err := JSONToMap([]byte(jsonData))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
validationErrors, _ := ReplaceDefaultConfig(m)
|
||||
if len(validationErrors) > 0 {
|
||||
return fmt.Errorf("%d errors, first: %w", len(validationErrors), validationErrors[0])
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func quickRegister(t *testing.T, key string, optType OptionType, defaultValue interface{}) {
|
||||
t.Helper()
|
||||
|
||||
err := Register(&Option{
|
||||
Name: key,
|
||||
Key: key,
|
||||
Description: "test config",
|
||||
ReleaseLevel: ReleaseLevelStable,
|
||||
ExpertiseLevel: ExpertiseLevelUser,
|
||||
OptType: optType,
|
||||
DefaultValue: defaultValue,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGet(t *testing.T) { //nolint:paralleltest
|
||||
// reset
|
||||
options = make(map[string]*Option)
|
||||
|
||||
err := log.Start()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
quickRegister(t, "monkey", OptTypeString, "c")
|
||||
quickRegister(t, "zebras/zebra", OptTypeStringArray, []string{"a", "b"})
|
||||
quickRegister(t, "elephant", OptTypeInt, -1)
|
||||
quickRegister(t, "hot", OptTypeBool, false)
|
||||
quickRegister(t, "cold", OptTypeBool, true)
|
||||
|
||||
err = parseAndReplaceConfig(`
|
||||
{
|
||||
"monkey": "a",
|
||||
"zebras": {
|
||||
"zebra": ["black", "white"]
|
||||
},
|
||||
"elephant": 2,
|
||||
"hot": true,
|
||||
"cold": false
|
||||
}
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = parseAndReplaceDefaultConfig(`
|
||||
{
|
||||
"monkey": "b",
|
||||
"snake": "0",
|
||||
"elephant": 0
|
||||
}
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
monkey := GetAsString("monkey", "none")
|
||||
if monkey() != "a" {
|
||||
t.Errorf("monkey should be a, is %s", monkey())
|
||||
}
|
||||
|
||||
zebra := GetAsStringArray("zebras/zebra", []string{})
|
||||
if len(zebra()) != 2 || zebra()[0] != "black" || zebra()[1] != "white" {
|
||||
t.Errorf("zebra should be [\"black\", \"white\"], is %v", zebra())
|
||||
}
|
||||
|
||||
elephant := GetAsInt("elephant", -1)
|
||||
if elephant() != 2 {
|
||||
t.Errorf("elephant should be 2, is %d", elephant())
|
||||
}
|
||||
|
||||
hot := GetAsBool("hot", false)
|
||||
if !hot() {
|
||||
t.Errorf("hot should be true, is %v", hot())
|
||||
}
|
||||
|
||||
cold := GetAsBool("cold", true)
|
||||
if cold() {
|
||||
t.Errorf("cold should be false, is %v", cold())
|
||||
}
|
||||
|
||||
err = parseAndReplaceConfig(`
|
||||
{
|
||||
"monkey": "3"
|
||||
}
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if monkey() != "3" {
|
||||
t.Errorf("monkey should be 0, is %s", monkey())
|
||||
}
|
||||
|
||||
if elephant() != 0 {
|
||||
t.Errorf("elephant should be 0, is %d", elephant())
|
||||
}
|
||||
|
||||
zebra()
|
||||
hot()
|
||||
|
||||
// concurrent
|
||||
GetAsString("monkey", "none")()
|
||||
GetAsStringArray("zebras/zebra", []string{})()
|
||||
GetAsInt("elephant", -1)()
|
||||
GetAsBool("hot", false)()
|
||||
|
||||
// perspective
|
||||
|
||||
// load data
|
||||
pLoaded := make(map[string]interface{})
|
||||
err = json.Unmarshal([]byte(`{
|
||||
"monkey": "a",
|
||||
"zebras": {
|
||||
"zebra": ["black", "white"]
|
||||
},
|
||||
"elephant": 2,
|
||||
"hot": true,
|
||||
"cold": false
|
||||
}`), &pLoaded)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// create
|
||||
p, err := NewPerspective(pLoaded)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
monkeyVal, ok := p.GetAsString("monkey")
|
||||
if !ok || monkeyVal != "a" {
|
||||
t.Errorf("[perspective] monkey should be a, is %+v", monkeyVal)
|
||||
}
|
||||
|
||||
zebraVal, ok := p.GetAsStringArray("zebras/zebra")
|
||||
if !ok || len(zebraVal) != 2 || zebraVal[0] != "black" || zebraVal[1] != "white" {
|
||||
t.Errorf("[perspective] zebra should be [\"black\", \"white\"], is %+v", zebraVal)
|
||||
}
|
||||
|
||||
elephantVal, ok := p.GetAsInt("elephant")
|
||||
if !ok || elephantVal != 2 {
|
||||
t.Errorf("[perspective] elephant should be 2, is %+v", elephantVal)
|
||||
}
|
||||
|
||||
hotVal, ok := p.GetAsBool("hot")
|
||||
if !ok || !hotVal {
|
||||
t.Errorf("[perspective] hot should be true, is %+v", hotVal)
|
||||
}
|
||||
|
||||
coldVal, ok := p.GetAsBool("cold")
|
||||
if !ok || coldVal {
|
||||
t.Errorf("[perspective] cold should be false, is %+v", coldVal)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReleaseLevel(t *testing.T) { //nolint:paralleltest
|
||||
// reset
|
||||
options = make(map[string]*Option)
|
||||
registerReleaseLevelOption()
|
||||
|
||||
// setup
|
||||
subsystemOption := &Option{
|
||||
Name: "test subsystem",
|
||||
Key: "subsystem/test",
|
||||
Description: "test config",
|
||||
ReleaseLevel: ReleaseLevelStable,
|
||||
ExpertiseLevel: ExpertiseLevelUser,
|
||||
OptType: OptTypeBool,
|
||||
DefaultValue: false,
|
||||
}
|
||||
err := Register(subsystemOption)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = SetConfigOption("subsystem/test", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
testSubsystem := GetAsBool("subsystem/test", false)
|
||||
|
||||
// test option level stable
|
||||
subsystemOption.ReleaseLevel = ReleaseLevelStable
|
||||
err = SetConfigOption(releaseLevelKey, ReleaseLevelNameStable)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !testSubsystem() {
|
||||
t.Error("should be active")
|
||||
}
|
||||
err = SetConfigOption(releaseLevelKey, ReleaseLevelNameBeta)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !testSubsystem() {
|
||||
t.Error("should be active")
|
||||
}
|
||||
err = SetConfigOption(releaseLevelKey, ReleaseLevelNameExperimental)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !testSubsystem() {
|
||||
t.Error("should be active")
|
||||
}
|
||||
|
||||
// test option level beta
|
||||
subsystemOption.ReleaseLevel = ReleaseLevelBeta
|
||||
err = SetConfigOption(releaseLevelKey, ReleaseLevelNameStable)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if testSubsystem() {
|
||||
t.Errorf("should be inactive: opt=%d system=%d", subsystemOption.ReleaseLevel, getReleaseLevel())
|
||||
}
|
||||
err = SetConfigOption(releaseLevelKey, ReleaseLevelNameBeta)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !testSubsystem() {
|
||||
t.Error("should be active")
|
||||
}
|
||||
err = SetConfigOption(releaseLevelKey, ReleaseLevelNameExperimental)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !testSubsystem() {
|
||||
t.Error("should be active")
|
||||
}
|
||||
|
||||
// test option level experimental
|
||||
subsystemOption.ReleaseLevel = ReleaseLevelExperimental
|
||||
err = SetConfigOption(releaseLevelKey, ReleaseLevelNameStable)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if testSubsystem() {
|
||||
t.Error("should be inactive")
|
||||
}
|
||||
err = SetConfigOption(releaseLevelKey, ReleaseLevelNameBeta)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if testSubsystem() {
|
||||
t.Error("should be inactive")
|
||||
}
|
||||
err = SetConfigOption(releaseLevelKey, ReleaseLevelNameExperimental)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !testSubsystem() {
|
||||
t.Error("should be active")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGetAsStringCached(b *testing.B) {
|
||||
// reset
|
||||
options = make(map[string]*Option)
|
||||
|
||||
// Setup
|
||||
err := parseAndReplaceConfig(`{
|
||||
"monkey": "banana"
|
||||
}`)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
monkey := GetAsString("monkey", "no banana")
|
||||
|
||||
// Reset timer for precise results
|
||||
b.ResetTimer()
|
||||
|
||||
// Start benchmark
|
||||
for range b.N {
|
||||
monkey()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGetAsStringRefetch(b *testing.B) {
|
||||
// Setup
|
||||
err := parseAndReplaceConfig(`{
|
||||
"monkey": "banana"
|
||||
}`)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
// Reset timer for precise results
|
||||
b.ResetTimer()
|
||||
|
||||
// Start benchmark
|
||||
for range b.N {
|
||||
getValueCache("monkey", nil, OptTypeString)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGetAsIntCached(b *testing.B) {
|
||||
// Setup
|
||||
err := parseAndReplaceConfig(`{
|
||||
"elephant": 1
|
||||
}`)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
elephant := GetAsInt("elephant", -1)
|
||||
|
||||
// Reset timer for precise results
|
||||
b.ResetTimer()
|
||||
|
||||
// Start benchmark
|
||||
for range b.N {
|
||||
elephant()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGetAsIntRefetch(b *testing.B) {
|
||||
// Setup
|
||||
err := parseAndReplaceConfig(`{
|
||||
"elephant": 1
|
||||
}`)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
// Reset timer for precise results
|
||||
b.ResetTimer()
|
||||
|
||||
// Start benchmark
|
||||
for range b.N {
|
||||
getValueCache("elephant", nil, OptTypeInt)
|
||||
}
|
||||
}
|
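The benchmarks above exercise the cached getter pattern from this package. As a rough in-package sketch (not part of the commit; the key below is hypothetical), callers register an option once and keep the returned closure, which only re-reads the underlying value after the config changes:

var theme = GetAsString("ui/theme", "light") // hypothetical key

func themeIsDark() bool {
	// Calling the closure is cheap: the value is served from the cache
	// until the config's validity flag is invalidated.
	return theme() == "dark"
}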
base/config/init_test.go (new file, 35 lines)
@@ -0,0 +1,35 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type testInstance struct{}
|
||||
|
||||
var _ instance = testInstance{}
|
||||
|
||||
func (stub testInstance) SetCmdLineOperation(f func() error) {}
|
||||
|
||||
func runTest(m *testing.M) error {
|
||||
ds, err := InitializeUnitTestDataroot("test-config")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize dataroot: %w", err)
|
||||
}
|
||||
defer func() { _ = os.RemoveAll(ds) }()
|
||||
module, err = New(&testInstance{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize module: %w", err)
|
||||
}
|
||||
|
||||
m.Run()
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
if err := runTest(m); err != nil {
|
||||
fmt.Printf("%s\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
base/config/main.go (new file, 155 lines)
@@ -0,0 +1,155 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
|
||||
"github.com/safing/portmaster/base/dataroot"
|
||||
"github.com/safing/portmaster/base/utils"
|
||||
"github.com/safing/portmaster/base/utils/debug"
|
||||
"github.com/safing/portmaster/service/mgr"
|
||||
)
|
||||
|
||||
// ChangeEvent is the name of the config change event.
|
||||
const ChangeEvent = "config change"
|
||||
|
||||
var (
|
||||
dataRoot *utils.DirStructure
|
||||
|
||||
exportConfig bool
|
||||
)
|
||||
|
||||
// SetDataRoot sets the data root from which the config module derives its paths.
|
||||
func SetDataRoot(root *utils.DirStructure) {
|
||||
if dataRoot == nil {
|
||||
dataRoot = root
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
flag.BoolVar(&exportConfig, "export-config-options", false, "export configuration registry and exit")
|
||||
}
|
||||
|
||||
func prep() error {
|
||||
SetDataRoot(dataroot.Root())
|
||||
if dataRoot == nil {
|
||||
return errors.New("data root is not set")
|
||||
}
|
||||
|
||||
if exportConfig {
|
||||
module.instance.SetCmdLineOperation(exportConfigCmd)
|
||||
return mgr.ErrExecuteCmdLineOp
|
||||
}
|
||||
|
||||
return registerBasicOptions()
|
||||
}
|
||||
|
||||
func start() error {
|
||||
configFilePath = filepath.Join(dataRoot.Path, "config.json")
|
||||
|
||||
// Load log level from log package after it started.
|
||||
err := loadLogLevel()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = registerAsDatabase()
|
||||
if err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||
return err
|
||||
}
|
||||
|
||||
err = loadConfig(false)
|
||||
if err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||
return fmt.Errorf("failed to load config file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func exportConfigCmd() error {
|
||||
// Reset the metrics instance name option, as the default
|
||||
// is set to the current hostname.
|
||||
// Config key copied from metrics.CfgOptionInstanceKey.
|
||||
option, err := GetOption("core/metrics/instance")
|
||||
if err == nil {
|
||||
option.DefaultValue = ""
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(ExportOptions(), "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = os.Stdout.Write(data)
|
||||
return err
|
||||
}
|
||||
|
||||
// AddToDebugInfo adds all changed global config options to the given debug.Info.
|
||||
func AddToDebugInfo(di *debug.Info) {
|
||||
var lines []string
|
||||
|
||||
// Collect all changed settings.
|
||||
_ = ForEachOption(func(opt *Option) error {
|
||||
opt.Lock()
|
||||
defer opt.Unlock()
|
||||
|
||||
if opt.ReleaseLevel <= getReleaseLevel() && opt.activeValue != nil {
|
||||
if opt.Sensitive {
|
||||
lines = append(lines, fmt.Sprintf("%s: [redacted]", opt.Key))
|
||||
} else {
|
||||
lines = append(lines, fmt.Sprintf("%s: %v", opt.Key, opt.activeValue.getData(opt)))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
sort.Strings(lines)
|
||||
|
||||
// Add data as section.
|
||||
di.AddSection(
|
||||
fmt.Sprintf("Config: %d", len(lines)),
|
||||
debug.UseCodeSection|debug.AddContentLineBreaks,
|
||||
lines...,
|
||||
)
|
||||
}
|
||||
|
||||
// GetActiveConfigValues returns a map with the active config values.
|
||||
func GetActiveConfigValues() map[string]interface{} {
|
||||
values := make(map[string]interface{})
|
||||
|
||||
// Collect active values from options.
|
||||
_ = ForEachOption(func(opt *Option) error {
|
||||
opt.Lock()
|
||||
defer opt.Unlock()
|
||||
|
||||
if opt.ReleaseLevel <= getReleaseLevel() && opt.activeValue != nil {
|
||||
values[opt.Key] = opt.activeValue.getData(opt)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return values
|
||||
}
|
||||
|
||||
// InitializeUnitTestDataroot initializes a new random tmp directory for running tests.
|
||||
func InitializeUnitTestDataroot(testName string) (string, error) {
|
||||
basePath, err := os.MkdirTemp("", fmt.Sprintf("portmaster-%s", testName))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to make tmp dir: %w", err)
|
||||
}
|
||||
|
||||
ds := utils.NewDirStructure(basePath, 0o0755)
|
||||
SetDataRoot(ds)
|
||||
err = dataroot.Initialize(basePath, 0o0755)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to initialize dataroot: %w", err)
|
||||
}
|
||||
|
||||
return basePath, nil
|
||||
}
|
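A minimal consumption sketch for the helpers above (not part of the diff), assuming the caller lives in or imports this package:

func printActiveConfig() {
	// GetActiveConfigValues only returns options that are set and visible
	// at the current release level.
	for key, value := range GetActiveConfigValues() {
		fmt.Printf("%s = %v\n", key, value)
	}
}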
base/config/module.go (new file, 60 lines)
@@ -0,0 +1,60 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/safing/portmaster/service/mgr"
|
||||
)
|
||||
|
||||
// Config provides configuration management.
|
||||
type Config struct {
|
||||
mgr *mgr.Manager
|
||||
|
||||
instance instance
|
||||
|
||||
EventConfigChange *mgr.EventMgr[struct{}]
|
||||
}
|
||||
|
||||
// Manager returns the module's manager.
|
||||
func (u *Config) Manager() *mgr.Manager {
|
||||
return u.mgr
|
||||
}
|
||||
|
||||
// Start starts the module.
|
||||
func (u *Config) Start() error {
|
||||
return start()
|
||||
}
|
||||
|
||||
// Stop stops the module.
|
||||
func (u *Config) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
module *Config
|
||||
shimLoaded atomic.Bool
|
||||
)
|
||||
|
||||
// New returns a new Config module.
|
||||
func New(instance instance) (*Config, error) {
|
||||
if !shimLoaded.CompareAndSwap(false, true) {
|
||||
return nil, errors.New("only one instance allowed")
|
||||
}
|
||||
m := mgr.New("Config")
|
||||
module = &Config{
|
||||
mgr: m,
|
||||
instance: instance,
|
||||
EventConfigChange: mgr.NewEventMgr[struct{}](ChangeEvent, m),
|
||||
}
|
||||
|
||||
if err := prep(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return module, nil
|
||||
}
|
||||
|
||||
type instance interface {
|
||||
SetCmdLineOperation(f func() error)
|
||||
}
|
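A wiring sketch (not part of the commit; the instance type is made up): the only requirement the module places on the surrounding application is the small instance interface declared above.

type appStub struct{ cmdLineOp func() error }

func (a *appStub) SetCmdLineOperation(f func() error) { a.cmdLineOp = f }

func setupConfig() (*Config, error) {
	cfg, err := New(&appStub{}) // constructs the module and runs prep()
	if err != nil {
		return nil, err
	}
	// Start loads config.json from the data root and applies the log level.
	return cfg, cfg.Start()
}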
base/config/option.go (new file, 418 lines)
@@ -0,0 +1,418 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"sync"
|
||||
|
||||
"github.com/mitchellh/copystructure"
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
"github.com/safing/structures/dsd"
|
||||
)
|
||||
|
||||
// OptionType defines the value type of an option.
|
||||
type OptionType uint8
|
||||
|
||||
// Various attribute options. Use ExternalOptType for extended types in the frontend.
|
||||
const (
|
||||
optTypeAny OptionType = 0
|
||||
OptTypeString OptionType = 1
|
||||
OptTypeStringArray OptionType = 2
|
||||
OptTypeInt OptionType = 3
|
||||
OptTypeBool OptionType = 4
|
||||
)
|
||||
|
||||
func getTypeName(t OptionType) string {
|
||||
switch t {
|
||||
case optTypeAny:
|
||||
return "any"
|
||||
case OptTypeString:
|
||||
return "string"
|
||||
case OptTypeStringArray:
|
||||
return "[]string"
|
||||
case OptTypeInt:
|
||||
return "int"
|
||||
case OptTypeBool:
|
||||
return "bool"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// PossibleValue defines a value that is possible for
|
||||
// a configuration setting.
|
||||
type PossibleValue struct {
|
||||
// Name is a human readable name of the option.
|
||||
Name string
|
||||
// Description is a human readable description of
|
||||
// this value.
|
||||
Description string
|
||||
// Value is the actual value of the option. The type
|
||||
// must match the option's value type.
|
||||
Value interface{}
|
||||
}
|
||||
|
||||
// Annotations can be attached to configuration options to
|
||||
// provide hints for user interfaces or other systems working
|
||||
// or setting configuration options.
|
||||
// Annotation keys should follow the below format to ensure
|
||||
// future well-known annotation additions do not conflict
|
||||
// with vendor/product/package specific annotations.
|
||||
//
|
||||
// Format: <vendor/package>:<scope>:<identifier> //.
|
||||
type Annotations map[string]interface{}
|
||||
|
||||
// MigrationFunc is a function that migrates a config option value.
|
||||
type MigrationFunc func(option *Option, value any) any
|
||||
|
||||
// Well known annotations defined by this package.
|
||||
const (
|
||||
// DisplayHintAnnotation provides a hint for the user
|
||||
// interface on how to render an option.
|
||||
// The value of DisplayHintAnnotation is expected to
|
||||
// be a string. See DisplayHintXXXX constants below
|
||||
// for a list of well-known display hint annotations.
|
||||
DisplayHintAnnotation = "safing/portbase:ui:display-hint"
|
||||
// DisplayOrderAnnotation provides a hint for the user
|
||||
// interface in which order settings should be displayed.
|
||||
// The value of DisplayOrderAnnotation is expected to be
|
||||
// a number (int).
|
||||
DisplayOrderAnnotation = "safing/portbase:ui:order"
|
||||
// UnitAnnotation defines the SI unit of an option (if any).
|
||||
UnitAnnotation = "safing/portbase:ui:unit"
|
||||
// CategoryAnnotation can provide an additional category
|
||||
// to each setting. This category can be used by a user
|
||||
// interface to group certain options together.
|
||||
// User interfaces should treat a CategoryAnnotation, if
|
||||
// supported, with higher priority as a DisplayOrderAnnotation.
|
||||
CategoryAnnotation = "safing/portbase:ui:category"
|
||||
// SubsystemAnnotation can be used to mark an option as part
|
||||
// of a module subsystem.
|
||||
SubsystemAnnotation = "safing/portbase:module:subsystem"
|
||||
// StackableAnnotation can be set on configuration options that
|
||||
// stack on top of the default (or otherwise related) options.
|
||||
// The value of StackableAnnotation is expected to be a boolean but
|
||||
// may be extended to hold references to other options in the
|
||||
// future.
|
||||
StackableAnnotation = "safing/portbase:options:stackable"
|
||||
// RestartPendingAnnotation is automatically set on a configuration option
|
||||
// that requires a restart and has been changed.
|
||||
// The value must always be a boolean with value "true".
|
||||
RestartPendingAnnotation = "safing/portbase:options:restart-pending"
|
||||
// QuickSettingAnnotation can be used to add quick settings to
|
||||
// a configuration option. A quick setting can support the user
|
||||
// by switching between pre-configured values.
|
||||
// The type of a quick-setting annotation is []QuickSetting or QuickSetting.
|
||||
QuickSettingsAnnotation = "safing/portbase:ui:quick-setting"
|
||||
// RequiresAnnotation can be used to mark another option as a
|
||||
// requirement. The type of RequiresAnnotation is []ValueRequirement
|
||||
// or ValueRequirement.
|
||||
RequiresAnnotation = "safing/portbase:config:requires"
|
||||
// RequiresFeatureIDAnnotation can be used to mark a setting as only available
|
||||
// when the user has a certain feature ID in the subscription plan.
|
||||
// The type is []string or string.
|
||||
RequiresFeatureIDAnnotation = "safing/portmaster:ui:config:requires-feature"
|
||||
// SettablePerAppAnnotation can be used to mark a setting as settable per-app and
|
||||
// is a boolean.
|
||||
SettablePerAppAnnotation = "safing/portmaster:settable-per-app"
|
||||
// RequiresUIReloadAnnotation can be used to inform the UI that changing the value
|
||||
// of the annotated setting requires a full reload of the user interface.
|
||||
// The value of this annotation does not matter as the sole presence of
|
||||
// the annotation key is enough. Though, users are advised to set the value
|
||||
// of this annotation to true.
|
||||
RequiresUIReloadAnnotation = "safing/portmaster:ui:requires-reload"
|
||||
)
|
||||
|
||||
// QuickSettingsAction defines the action of a quick setting.
|
||||
type QuickSettingsAction string
|
||||
|
||||
const (
|
||||
// QuickReplace replaces the current setting with the one from
|
||||
// the quick setting.
|
||||
QuickReplace = QuickSettingsAction("replace")
|
||||
// QuickMergeTop merges the value of the quick setting with the
|
||||
// already configured one adding new values on the top. Merging
|
||||
// is only supported for OptTypeStringArray.
|
||||
QuickMergeTop = QuickSettingsAction("merge-top")
|
||||
// QuickMergeBottom merges the value of the quick setting with the
|
||||
// already configured one adding new values at the bottom. Merging
|
||||
// is only supported for OptTypeStringArray.
|
||||
QuickMergeBottom = QuickSettingsAction("merge-bottom")
|
||||
)
|
||||
|
||||
// QuickSetting defines a quick setting for a configuration option and
|
||||
// should be used together with the QuickSettingsAnnotation.
|
||||
type QuickSetting struct {
|
||||
// Name is the name of the quick setting.
|
||||
Name string
|
||||
|
||||
// Value is the value that the quick-setting configures. It must match
|
||||
// the expected value type of the annotated option.
|
||||
Value interface{}
|
||||
|
||||
// Action defines the action of the quick setting.
|
||||
Action QuickSettingsAction
|
||||
}
|
||||
|
||||
// ValueRequirement defines a requirement on another configuration option.
|
||||
type ValueRequirement struct {
|
||||
// Key is the key of the configuration option that is required.
|
||||
Key string
|
||||
|
||||
// Value that is required.
|
||||
Value interface{}
|
||||
}
|
||||
|
||||
// Values for the DisplayHintAnnotation.
|
||||
const (
|
||||
// DisplayHintOneOf is used to mark an option
|
||||
// as a "select"-style option. That is, only one of
|
||||
// the supported values may be set. This option makes
|
||||
// only sense together with the PossibleValues property
|
||||
// of Option.
|
||||
DisplayHintOneOf = "one-of"
|
||||
// DisplayHintOrdered is used to mark a list option as ordered.
|
||||
// That is, the order of items is important and a user interface
|
||||
// is encouraged to provide the user with re-ordering support
|
||||
// (like drag'n'drop).
|
||||
DisplayHintOrdered = "ordered"
|
||||
// DisplayHintFilePicker is used to mark the option as being a file, which
|
||||
// should give the option to use a file picker to select a local file from disk.
|
||||
DisplayHintFilePicker = "file-picker"
|
||||
)
|
||||
|
||||
// Option describes a configuration option.
|
||||
type Option struct {
|
||||
sync.Mutex
|
||||
// Name holds the name of the configuration options.
|
||||
// It should be human readable and is mainly used for
|
||||
// presentation purposes.
|
||||
// Name is considered immutable after the option has
|
||||
// been created.
|
||||
Name string
|
||||
// Key holds the database path for the option. It should
|
||||
// follow the path format `category/sub/key`.
|
||||
// Key is considered immutable after the option has
|
||||
// been created.
|
||||
Key string
|
||||
// Description holds a human readable description of the
|
||||
// option and what it does. The description should be short.
|
||||
// Use the Help property for a longer support text.
|
||||
// Description is considered immutable after the option has
|
||||
// been created.
|
||||
Description string
|
||||
// Help may hold a long version of the description providing
|
||||
// assistance with the configuration option.
|
||||
// Help is considered immutable after the option has
|
||||
// been created.
|
||||
Help string
|
||||
// Sensitive signifies that the configuration values may contain sensitive
|
||||
// content, such as authentication keys.
|
||||
Sensitive bool
|
||||
// OptType defines the type of the option.
|
||||
// OptType is considered immutable after the option has
|
||||
// been created.
|
||||
OptType OptionType
|
||||
// ExpertiseLevel can be used to set the required expertise
|
||||
// level for the option to be displayed to a user.
|
||||
// ExpertiseLevel is considered immutable after the option has
|
||||
// been created.
|
||||
ExpertiseLevel ExpertiseLevel
|
||||
// ReleaseLevel is used to mark the stability of the option.
|
||||
// ReleaseLevel is considered immutable after the option has
|
||||
// been created.
|
||||
ReleaseLevel ReleaseLevel
|
||||
// RequiresRestart should be set to true if a modification of
|
||||
// the options value requires a restart of the whole application
|
||||
// to take effect.
|
||||
// RequiresRestart is considered immutable after the option has
|
||||
// been created.
|
||||
RequiresRestart bool
|
||||
// DefaultValue holds the default value of the option. Note that
|
||||
// this value can be overwritten during runtime (see activeDefaultValue
|
||||
// and activeFallbackValue).
|
||||
// DefaultValue is considered immutable after the option has
|
||||
// been created.
|
||||
DefaultValue interface{}
|
||||
// ValidationRegex may contain a regular expression used to validate
|
||||
// the value of option. If the option type is set to OptTypeStringArray
|
||||
// the validation regex is applied to all entries of the string slice.
|
||||
// Note that it is recommended to keep the validation regex simple so
|
||||
// it can also be used in other languages (mainly JavaScript) to provide
|
||||
// a better user-experience by pre-validating the expression.
|
||||
// ValidationRegex is considered immutable after the option has
|
||||
// been created.
|
||||
ValidationRegex string
|
||||
// ValidationFunc may contain a function to validate more complex values.
|
||||
// The error is returned beyond the scope of this package and may be
|
||||
// displayed to a user.
|
||||
ValidationFunc func(value interface{}) error `json:"-"`
|
||||
// PossibleValues may be set to a slice of values that are allowed
|
||||
// for this configuration setting. Note that PossibleValues makes most
|
||||
// sense when ExternalOptType is set to HintOneOf
|
||||
// PossibleValues is considered immutable after the option has
|
||||
// been created.
|
||||
PossibleValues []PossibleValue `json:",omitempty"`
|
||||
// Annotations adds additional annotations to the configuration options.
|
||||
// See documentation of Annotations for more information.
|
||||
// Annotations is considered mutable and setting/reading annotation keys
|
||||
// must be performed while the option is locked.
|
||||
Annotations Annotations
|
||||
// Migrations holds migration functions that are given the raw option value
|
||||
// before any validation is run. The returned value is then used.
|
||||
Migrations []MigrationFunc `json:"-"`
|
||||
|
||||
activeValue *valueCache // runtime value (loaded from config file or set by user)
|
||||
activeDefaultValue *valueCache // runtime default value (may be set internally)
|
||||
activeFallbackValue *valueCache // default value from option registration
|
||||
compiledRegex *regexp.Regexp
|
||||
}
|
||||
|
||||
// AddAnnotation adds the annotation key to option if it's not already set.
|
||||
func (option *Option) AddAnnotation(key string, value interface{}) {
|
||||
option.Lock()
|
||||
defer option.Unlock()
|
||||
|
||||
if option.Annotations == nil {
|
||||
option.Annotations = make(Annotations)
|
||||
}
|
||||
|
||||
if _, ok := option.Annotations[key]; ok {
|
||||
return
|
||||
}
|
||||
option.Annotations[key] = value
|
||||
}
|
||||
|
||||
// SetAnnotation sets the value of the annotation key, overwriting an
|
||||
// existing value if required.
|
||||
func (option *Option) SetAnnotation(key string, value interface{}) {
|
||||
option.Lock()
|
||||
defer option.Unlock()
|
||||
|
||||
option.setAnnotation(key, value)
|
||||
}
|
||||
|
||||
// setAnnotation sets the value of the annotation key, overwriting an
|
||||
// existing value if required. Does not lock the Option.
|
||||
func (option *Option) setAnnotation(key string, value interface{}) {
|
||||
if option.Annotations == nil {
|
||||
option.Annotations = make(Annotations)
|
||||
}
|
||||
option.Annotations[key] = value
|
||||
}
|
||||
|
||||
// GetAnnotation returns the value of the annotation key.
|
||||
func (option *Option) GetAnnotation(key string) (interface{}, bool) {
|
||||
option.Lock()
|
||||
defer option.Unlock()
|
||||
|
||||
if option.Annotations == nil {
|
||||
return nil, false
|
||||
}
|
||||
val, ok := option.Annotations[key]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// AnnotationEquals returns whether the annotation of the given key matches the
|
||||
// given value.
|
||||
func (option *Option) AnnotationEquals(key string, value any) bool {
|
||||
option.Lock()
|
||||
defer option.Unlock()
|
||||
|
||||
if option.Annotations == nil {
|
||||
return false
|
||||
}
|
||||
setValue, ok := option.Annotations[key]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return reflect.DeepEqual(value, setValue)
|
||||
}
|
||||
|
||||
// copyOrNil returns a copy of the option, or nil if copying failed.
|
||||
func (option *Option) copyOrNil() *Option {
|
||||
copied, err := copystructure.Copy(option)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return copied.(*Option) //nolint:forcetypeassert
|
||||
}
|
||||
|
||||
// IsSetByUser returns whether the option has been set by the user.
|
||||
func (option *Option) IsSetByUser() bool {
|
||||
option.Lock()
|
||||
defer option.Unlock()
|
||||
|
||||
return option.activeValue != nil
|
||||
}
|
||||
|
||||
// UserValue returns the value set by the user or nil if the value has not
|
||||
// been changed from the default.
|
||||
func (option *Option) UserValue() any {
|
||||
option.Lock()
|
||||
defer option.Unlock()
|
||||
|
||||
if option.activeValue == nil {
|
||||
return nil
|
||||
}
|
||||
return option.activeValue.getData(option)
|
||||
}
|
||||
|
||||
// ValidateValue checks if the given value is valid for the option.
|
||||
func (option *Option) ValidateValue(value any) error {
|
||||
option.Lock()
|
||||
defer option.Unlock()
|
||||
|
||||
value = migrateValue(option, value)
|
||||
if _, err := validateValue(option, value); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Export exports an option to a Record.
|
||||
func (option *Option) Export() (record.Record, error) {
|
||||
option.Lock()
|
||||
defer option.Unlock()
|
||||
|
||||
return option.export()
|
||||
}
|
||||
|
||||
func (option *Option) export() (record.Record, error) {
|
||||
data, err := json.Marshal(option)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if option.activeValue != nil {
|
||||
data, err = sjson.SetBytes(data, "Value", option.activeValue.getData(option))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if option.activeDefaultValue != nil {
|
||||
data, err = sjson.SetBytes(data, "DefaultValue", option.activeDefaultValue.getData(option))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
r, err := record.NewWrapper(fmt.Sprintf("config:%s", option.Key), nil, dsd.JSON, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.SetMeta(&record.Meta{})
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
type sortByKey []*Option
|
||||
|
||||
func (opts sortByKey) Len() int { return len(opts) }
|
||||
func (opts sortByKey) Less(i, j int) bool { return opts[i].Key < opts[j].Key }
|
||||
func (opts sortByKey) Swap(i, j int) { opts[i], opts[j] = opts[j], opts[i] }
|
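To make the Option struct above concrete, here is a hedged registration sketch (not part of the commit; the key, values, and category are made up for illustration). Because no ValidationRegex is given, Register derives one from the possible values ("^(light|dark)$").

func registerThemeOption() error {
	return Register(&Option{
		Name:           "Theme",
		Key:            "ui/theme", // hypothetical key
		Description:    "Select the color theme of the user interface.",
		OptType:        OptTypeString,
		ExpertiseLevel: ExpertiseLevelUser,
		ReleaseLevel:   ReleaseLevelStable,
		DefaultValue:   "light",
		PossibleValues: []PossibleValue{
			{Name: "Light", Value: "light"},
			{Name: "Dark", Value: "dark"},
		},
		Annotations: Annotations{
			DisplayHintAnnotation: DisplayHintOneOf,
			CategoryAnnotation:    "User Interface", // hypothetical category
		},
	})
}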
base/config/persistence.go (new file, 234 lines)
@@ -0,0 +1,234 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/safing/portmaster/base/log"
|
||||
)
|
||||
|
||||
var (
|
||||
configFilePath string
|
||||
|
||||
loadedConfigValidationErrors []*ValidationError
|
||||
loadedConfigValidationErrorsLock sync.Mutex
|
||||
)
|
||||
|
||||
// GetLoadedConfigValidationErrors returns the encountered validation errors
|
||||
// from the last time loading config from disk.
|
||||
func GetLoadedConfigValidationErrors() []*ValidationError {
|
||||
loadedConfigValidationErrorsLock.Lock()
|
||||
defer loadedConfigValidationErrorsLock.Unlock()
|
||||
|
||||
return loadedConfigValidationErrors
|
||||
}
|
||||
|
||||
func loadConfig(requireValidConfig bool) error {
|
||||
// check if persistence is configured
|
||||
if configFilePath == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// read config file
|
||||
data, err := os.ReadFile(configFilePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// convert to map
|
||||
newValues, err := JSONToMap(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
validationErrors, _ := ReplaceConfig(newValues)
|
||||
if requireValidConfig && len(validationErrors) > 0 {
|
||||
return fmt.Errorf("encountered %d validation errors during config loading", len(validationErrors))
|
||||
}
|
||||
|
||||
// Save validation errors.
|
||||
loadedConfigValidationErrorsLock.Lock()
|
||||
defer loadedConfigValidationErrorsLock.Unlock()
|
||||
loadedConfigValidationErrors = validationErrors
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveConfig saves the current configuration to file.
|
||||
// It will acquire a read-lock on the global options registry
|
||||
// lock and must lock each option!
|
||||
func SaveConfig() error {
|
||||
optionsLock.RLock()
|
||||
defer optionsLock.RUnlock()
|
||||
|
||||
// check if persistence is configured
|
||||
if configFilePath == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// extract values
|
||||
activeValues := make(map[string]interface{})
|
||||
for key, option := range options {
|
||||
// we cannot immediately unlock the option after
|
||||
// getData() because someone could lock and change it
|
||||
// while we are marshaling the value (e.g. for string slices).
|
||||
// We NEED to keep the option locks until we are finished.
|
||||
option.Lock()
|
||||
defer option.Unlock()
|
||||
|
||||
if option.activeValue != nil {
|
||||
activeValues[key] = option.activeValue.getData(option)
|
||||
}
|
||||
}
|
||||
|
||||
// convert to JSON
|
||||
data, err := MapToJSON(activeValues)
|
||||
if err != nil {
|
||||
log.Errorf("config: failed to save config: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// write file
|
||||
return os.WriteFile(configFilePath, data, 0o0600)
|
||||
}
|
||||
|
||||
// JSONToMap parses and flattens a hierarchical json object.
|
||||
func JSONToMap(jsonData []byte) (map[string]interface{}, error) {
|
||||
loaded := make(map[string]interface{})
|
||||
err := json.Unmarshal(jsonData, &loaded)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return Flatten(loaded), nil
|
||||
}
|
||||
|
||||
// Flatten returns a flattened copy of the given hierarchical config.
|
||||
func Flatten(config map[string]interface{}) (flattenedConfig map[string]interface{}) {
|
||||
flattenedConfig = make(map[string]interface{})
|
||||
flattenMap(flattenedConfig, config, "")
|
||||
return flattenedConfig
|
||||
}
|
||||
|
||||
func flattenMap(rootMap, subMap map[string]interface{}, subKey string) {
|
||||
for key, entry := range subMap {
|
||||
|
||||
// get next level key
|
||||
subbedKey := path.Join(subKey, key)
|
||||
|
||||
// check for next subMap
|
||||
nextSub, ok := entry.(map[string]interface{})
|
||||
if ok {
|
||||
flattenMap(rootMap, nextSub, subbedKey)
|
||||
} else {
|
||||
// we found a leaf value, set it under the flattened key
|
||||
rootMap[subbedKey] = entry
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MapToJSON expands a flattened map and returns it as json.
|
||||
func MapToJSON(config map[string]interface{}) ([]byte, error) {
|
||||
return json.MarshalIndent(Expand(config), "", " ")
|
||||
}
|
||||
|
||||
// Expand returns a hierarchical copy of the given flattened config.
|
||||
func Expand(flattenedConfig map[string]interface{}) (config map[string]interface{}) {
|
||||
config = make(map[string]interface{})
|
||||
for key, entry := range flattenedConfig {
|
||||
PutValueIntoHierarchicalConfig(config, key, entry)
|
||||
}
|
||||
return config
|
||||
}
|
||||
|
||||
// PutValueIntoHierarchicalConfig injects a configuration entry into a hierarchical config map. Conflicting entries will be replaced.
|
||||
func PutValueIntoHierarchicalConfig(config map[string]interface{}, key string, value interface{}) {
|
||||
parts := strings.Split(key, "/")
|
||||
|
||||
// create/check maps for all parts except the last one
|
||||
subMap := config
|
||||
for i, part := range parts {
|
||||
if i == len(parts)-1 {
|
||||
// do not process the last part,
|
||||
// which is not a map, but the value key itself
|
||||
break
|
||||
}
|
||||
|
||||
var nextSubMap map[string]interface{}
|
||||
// get value
|
||||
value, ok := subMap[part]
|
||||
if !ok {
|
||||
// create new map and assign it
|
||||
nextSubMap = make(map[string]interface{})
|
||||
subMap[part] = nextSubMap
|
||||
} else {
|
||||
nextSubMap, ok = value.(map[string]interface{})
|
||||
if !ok {
|
||||
// create new map and assign it
|
||||
nextSubMap = make(map[string]interface{})
|
||||
subMap[part] = nextSubMap
|
||||
}
|
||||
}
|
||||
|
||||
// assign for next parts loop
|
||||
subMap = nextSubMap
|
||||
}
|
||||
|
||||
// assign value to last submap
|
||||
subMap[parts[len(parts)-1]] = value
|
||||
}
|
||||
|
||||
// CleanFlattenedConfig removes all unregistered configuration options from the given flattened config map.
|
||||
func CleanFlattenedConfig(flattenedConfig map[string]interface{}) {
|
||||
optionsLock.RLock()
|
||||
defer optionsLock.RUnlock()
|
||||
|
||||
for key := range flattenedConfig {
|
||||
_, ok := options[key]
|
||||
if !ok {
|
||||
delete(flattenedConfig, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CleanHierarchicalConfig removes all unregistered configuration options from the given hierarchical config map.
|
||||
func CleanHierarchicalConfig(config map[string]interface{}) {
|
||||
optionsLock.RLock()
|
||||
defer optionsLock.RUnlock()
|
||||
|
||||
cleanSubMap(config, "")
|
||||
}
|
||||
|
||||
func cleanSubMap(subMap map[string]interface{}, subKey string) (empty bool) {
|
||||
var foundValid int
|
||||
for key, value := range subMap {
|
||||
value, ok := value.(map[string]interface{})
|
||||
if ok {
|
||||
// we found another section
|
||||
isEmpty := cleanSubMap(value, path.Join(subKey, key))
|
||||
if isEmpty {
|
||||
delete(subMap, key)
|
||||
} else {
|
||||
foundValid++
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// we found an option value
|
||||
if strings.Contains(key, "/") {
|
||||
delete(subMap, key)
|
||||
} else {
|
||||
_, ok := options[path.Join(subKey, key)]
|
||||
if ok {
|
||||
foundValid++
|
||||
} else {
|
||||
delete(subMap, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
return foundValid == 0
|
||||
}
|
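A round-trip sketch of the flatten/expand helpers above (not part of the commit; the nested key is made up): nested maps flatten to slash-delimited keys and expand back to the hierarchical form before being written to disk.

func flattenExample() {
	flat := Flatten(map[string]interface{}{
		"core": map[string]interface{}{
			"log": map[string]interface{}{"level": "info"}, // hypothetical key
		},
	})
	// flat now holds {"core/log/level": "info"}.
	data, err := MapToJSON(flat) // expands again and marshals with indentation
	if err == nil {
		_ = data
	}
}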
base/config/persistence_test.go (new file, 97 lines)
@@ -0,0 +1,97 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var (
|
||||
jsonData = `{
|
||||
"a": "b",
|
||||
"c": {
|
||||
"d": "e",
|
||||
"f": "g",
|
||||
"h": {
|
||||
"i": "j",
|
||||
"k": "l",
|
||||
"m": {
|
||||
"n": "o"
|
||||
}
|
||||
}
|
||||
},
|
||||
"p": "q"
|
||||
}`
|
||||
jsonBytes = []byte(jsonData)
|
||||
|
||||
mapData = map[string]interface{}{
|
||||
"a": "b",
|
||||
"p": "q",
|
||||
"c/d": "e",
|
||||
"c/f": "g",
|
||||
"c/h/i": "j",
|
||||
"c/h/k": "l",
|
||||
"c/h/m/n": "o",
|
||||
}
|
||||
)
|
||||
|
||||
func TestJSONMapConversion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// convert to json
|
||||
j, err := MapToJSON(mapData)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// check if to json matches
|
||||
if !bytes.Equal(jsonBytes, j) {
|
||||
t.Errorf("json does not match, got %s", j)
|
||||
}
|
||||
|
||||
// convert to map
|
||||
m, err := JSONToMap(jsonBytes)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// and back
|
||||
j2, err := MapToJSON(m)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// check if double convert matches
|
||||
if !bytes.Equal(jsonBytes, j2) {
|
||||
t.Errorf("json does not match, got %s", j)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigCleaning(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// load
|
||||
configFlat, err := JSONToMap(jsonBytes)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// clean everything
|
||||
CleanFlattenedConfig(configFlat)
|
||||
if len(configFlat) != 0 {
|
||||
t.Errorf("should be empty: %+v", configFlat)
|
||||
}
|
||||
|
||||
// load manually for hierarchical config
|
||||
configHier := make(map[string]interface{})
|
||||
err = json.Unmarshal(jsonBytes, &configHier)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// clean everything
|
||||
CleanHierarchicalConfig(configHier)
|
||||
if len(configHier) != 0 {
|
||||
t.Errorf("should be empty: %+v", configHier)
|
||||
}
|
||||
}
|
base/config/perspective.go (new file, 133 lines)
@@ -0,0 +1,133 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/safing/portmaster/base/log"
|
||||
)
|
||||
|
||||
// Perspective is a view on configuration data without interfering with the configuration system.
|
||||
type Perspective struct {
|
||||
config map[string]*perspectiveOption
|
||||
}
|
||||
|
||||
type perspectiveOption struct {
|
||||
option *Option
|
||||
valueCache *valueCache
|
||||
}
|
||||
|
||||
// NewPerspective parses the given config and returns it as a new perspective.
|
||||
func NewPerspective(config map[string]interface{}) (*Perspective, error) {
|
||||
// flatten config structure
|
||||
config = Flatten(config)
|
||||
|
||||
perspective := &Perspective{
|
||||
config: make(map[string]*perspectiveOption),
|
||||
}
|
||||
var firstErr error
|
||||
var errCnt int
|
||||
|
||||
optionsLock.RLock()
|
||||
optionsLoop:
|
||||
for key, option := range options {
|
||||
// get option key from config
|
||||
configValue, ok := config[key]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
// migrate value
|
||||
configValue = migrateValue(option, configValue)
|
||||
// validate value
|
||||
valueCache, err := validateValue(option, configValue)
|
||||
if err != nil {
|
||||
errCnt++
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
continue optionsLoop
|
||||
}
|
||||
|
||||
// add to perspective
|
||||
perspective.config[key] = &perspectiveOption{
|
||||
option: option,
|
||||
valueCache: valueCache,
|
||||
}
|
||||
}
|
||||
optionsLock.RUnlock()
|
||||
|
||||
if firstErr != nil {
|
||||
if errCnt > 0 {
|
||||
return perspective, fmt.Errorf("encountered %d errors, first was: %w", errCnt, firstErr)
|
||||
}
|
||||
return perspective, firstErr
|
||||
}
|
||||
|
||||
return perspective, nil
|
||||
}
|
||||
|
||||
func (p *Perspective) getPerspectiveValueCache(name string, requestedType OptionType) *valueCache {
|
||||
// get option
|
||||
pOption, ok := p.config[name]
|
||||
if !ok {
|
||||
// check if option exists at all
|
||||
if _, err := GetOption(name); err != nil {
|
||||
log.Errorf("config: request for unregistered option: %s", name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// check type
|
||||
if requestedType != pOption.option.OptType && requestedType != optTypeAny {
|
||||
log.Errorf("config: bad type: requested %s as %s, but is %s", name, getTypeName(requestedType), getTypeName(pOption.option.OptType))
|
||||
return nil
|
||||
}
|
||||
|
||||
// check release level
|
||||
if pOption.option.ReleaseLevel > getReleaseLevel() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return pOption.valueCache
|
||||
}
|
||||
|
||||
// Has returns whether the given option is set in the perspective.
|
||||
func (p *Perspective) Has(name string) bool {
|
||||
valueCache := p.getPerspectiveValueCache(name, optTypeAny)
|
||||
return valueCache != nil
|
||||
}
|
||||
|
||||
// GetAsString returns the string value of the named option in this perspective.
|
||||
func (p *Perspective) GetAsString(name string) (value string, ok bool) {
|
||||
valueCache := p.getPerspectiveValueCache(name, OptTypeString)
|
||||
if valueCache != nil {
|
||||
return valueCache.stringVal, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// GetAsStringArray returns the []string value of the named option in this perspective.
|
||||
func (p *Perspective) GetAsStringArray(name string) (value []string, ok bool) {
|
||||
valueCache := p.getPerspectiveValueCache(name, OptTypeStringArray)
|
||||
if valueCache != nil {
|
||||
return valueCache.stringArrayVal, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// GetAsInt returns the int64 value of the named option in this perspective.
|
||||
func (p *Perspective) GetAsInt(name string) (value int64, ok bool) {
|
||||
valueCache := p.getPerspectiveValueCache(name, OptTypeInt)
|
||||
if valueCache != nil {
|
||||
return valueCache.intVal, true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetAsBool returns the bool value of the named option in this perspective.
|
||||
func (p *Perspective) GetAsBool(name string) (value bool, ok bool) {
|
||||
valueCache := p.getPerspectiveValueCache(name, OptTypeBool)
|
||||
if valueCache != nil {
|
||||
return valueCache.boolVal, true
|
||||
}
|
||||
return false, false
|
||||
}
|
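A usage sketch for the perspective API above (not part of the commit; the key is hypothetical). A perspective only yields values for options that are already registered and visible at the current release level, as exercised in TestGet earlier in this diff.

func perspectiveExample(settings map[string]interface{}) {
	p, err := NewPerspective(settings)
	if err != nil {
		return // some supplied values failed migration or validation
	}
	if level, ok := p.GetAsString("core/log/level"); ok { // hypothetical key
		_ = level
	}
}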
base/config/registry.go (new file, 106 lines)
@@ -0,0 +1,106 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
optionsLock sync.RWMutex
|
||||
options = make(map[string]*Option)
|
||||
)
|
||||
|
||||
// ForEachOption calls fn for each defined option. If fn returns
|
||||
// an error, the iteration is stopped and the error is returned.
|
||||
// Note that ForEachOption does not guarantee a stable order of
|
||||
// iteration between multiple calls. ForEachOption does NOT lock
|
||||
// opt when calling fn.
|
||||
func ForEachOption(fn func(opt *Option) error) error {
|
||||
optionsLock.RLock()
|
||||
defer optionsLock.RUnlock()
|
||||
|
||||
for _, opt := range options {
|
||||
if err := fn(opt); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExportOptions exports the registered options. The returned data must be
|
||||
// treated as immutable.
|
||||
// The data does not include the current active or default settings.
|
||||
func ExportOptions() []*Option {
|
||||
optionsLock.RLock()
|
||||
defer optionsLock.RUnlock()
|
||||
|
||||
// Copy the map into a slice.
|
||||
opts := make([]*Option, 0, len(options))
|
||||
for _, opt := range options {
|
||||
opts = append(opts, opt)
|
||||
}
|
||||
|
||||
sort.Sort(sortByKey(opts))
|
||||
return opts
|
||||
}
|
||||
|
||||
// GetOption returns the option with name or an error
|
||||
// if the option does not exist. The caller should lock
|
||||
// the returned option itself for further processing.
|
||||
func GetOption(name string) (*Option, error) {
|
||||
optionsLock.RLock()
|
||||
defer optionsLock.RUnlock()
|
||||
|
||||
opt, ok := options[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("option %q does not exist", name)
|
||||
}
|
||||
return opt, nil
|
||||
}
|
||||
|
||||
// Register registers a new configuration option.
|
||||
func Register(option *Option) error {
|
||||
if option.Name == "" {
|
||||
return fmt.Errorf("failed to register option: please set option.Name")
|
||||
}
|
||||
if option.Key == "" {
|
||||
return fmt.Errorf("failed to register option: please set option.Key")
|
||||
}
|
||||
if option.Description == "" {
|
||||
return fmt.Errorf("failed to register option: please set option.Description")
|
||||
}
|
||||
if option.OptType == 0 {
|
||||
return fmt.Errorf("failed to register option: please set option.OptType")
|
||||
}
|
||||
|
||||
if option.ValidationRegex == "" && option.PossibleValues != nil {
|
||||
values := make([]string, len(option.PossibleValues))
|
||||
for idx, val := range option.PossibleValues {
|
||||
values[idx] = fmt.Sprintf("%v", val.Value)
|
||||
}
|
||||
option.ValidationRegex = fmt.Sprintf("^(%s)$", strings.Join(values, "|"))
|
||||
}
|
||||
|
||||
var err error
|
||||
if option.ValidationRegex != "" {
|
||||
option.compiledRegex, err = regexp.Compile(option.ValidationRegex)
|
||||
if err != nil {
|
||||
return fmt.Errorf("config: could not compile option.ValidationRegex: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var vErr *ValidationError
|
||||
option.activeFallbackValue, vErr = validateValue(option, option.DefaultValue)
|
||||
if vErr != nil {
|
||||
return fmt.Errorf("config: invalid default value: %w", vErr)
|
||||
}
|
||||
|
||||
optionsLock.Lock()
|
||||
defer optionsLock.Unlock()
|
||||
options[option.Key] = option
|
||||
|
||||
return nil
|
||||
}
|
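A small iteration sketch over the registry above (not part of the commit). As noted in the ForEachOption docs, the iteration order is not stable, so callers that need ordering should sort the result themselves.

func registeredKeys() []string {
	var keys []string
	_ = ForEachOption(func(opt *Option) error {
		keys = append(keys, opt.Key)
		return nil
	})
	return keys
}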
base/config/registry_test.go (new file, 49 lines)
@@ -0,0 +1,49 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRegistry(t *testing.T) { //nolint:paralleltest
|
||||
// reset
|
||||
options = make(map[string]*Option)
|
||||
|
||||
if err := Register(&Option{
|
||||
Name: "name",
|
||||
Key: "key",
|
||||
Description: "description",
|
||||
ReleaseLevel: ReleaseLevelStable,
|
||||
ExpertiseLevel: ExpertiseLevelUser,
|
||||
OptType: OptTypeString,
|
||||
DefaultValue: "water",
|
||||
ValidationRegex: "^(banana|water)$",
|
||||
}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if err := Register(&Option{
|
||||
Name: "name",
|
||||
Key: "key",
|
||||
Description: "description",
|
||||
ReleaseLevel: ReleaseLevelStable,
|
||||
ExpertiseLevel: ExpertiseLevelUser,
|
||||
OptType: 0,
|
||||
DefaultValue: "default",
|
||||
ValidationRegex: "^[A-Z][a-z]+$",
|
||||
}); err == nil {
|
||||
t.Error("should fail")
|
||||
}
|
||||
|
||||
if err := Register(&Option{
|
||||
Name: "name",
|
||||
Key: "key",
|
||||
Description: "description",
|
||||
ReleaseLevel: ReleaseLevelStable,
|
||||
ExpertiseLevel: ExpertiseLevelUser,
|
||||
OptType: OptTypeString,
|
||||
DefaultValue: "default",
|
||||
ValidationRegex: "[",
|
||||
}); err == nil {
|
||||
t.Error("should fail")
|
||||
}
|
||||
}
|
base/config/release.go (new file, 101 lines)
@@ -0,0 +1,101 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
)
|
||||
|
||||
// ReleaseLevel is used to define the maturity of a
|
||||
// configuration setting.
|
||||
type ReleaseLevel uint8
|
||||
|
||||
// Release Level constants.
|
||||
const (
|
||||
ReleaseLevelStable ReleaseLevel = 0
|
||||
ReleaseLevelBeta ReleaseLevel = 1
|
||||
ReleaseLevelExperimental ReleaseLevel = 2
|
||||
|
||||
ReleaseLevelNameStable = "stable"
|
||||
ReleaseLevelNameBeta = "beta"
|
||||
ReleaseLevelNameExperimental = "experimental"
|
||||
|
||||
releaseLevelKey = "core/releaseLevel"
|
||||
)
|
||||
|
||||
var (
|
||||
releaseLevel = new(int32)
|
||||
releaseLevelOption *Option
|
||||
releaseLevelOptionFlag = abool.New()
|
||||
)
|
||||
|
||||
func init() {
|
||||
registerReleaseLevelOption()
|
||||
}
|
||||
|
||||
func registerReleaseLevelOption() {
|
||||
releaseLevelOption = &Option{
|
||||
Name: "Feature Stability",
|
||||
Key: releaseLevelKey,
|
||||
Description: `May break things. Decide if you want to experiment with unstable features. "Beta" has been tested roughly by the Safing team while "Experimental" is really raw. When "Beta" or "Experimental" are disabled, their settings use the default again.`,
|
||||
OptType: OptTypeString,
|
||||
ExpertiseLevel: ExpertiseLevelDeveloper,
|
||||
ReleaseLevel: ReleaseLevelStable,
|
||||
DefaultValue: ReleaseLevelNameStable,
|
||||
Annotations: Annotations{
|
||||
DisplayOrderAnnotation: -8,
|
||||
DisplayHintAnnotation: DisplayHintOneOf,
|
||||
CategoryAnnotation: "Updates",
|
||||
},
|
||||
PossibleValues: []PossibleValue{
|
||||
{
|
||||
Name: "Stable",
|
||||
Value: ReleaseLevelNameStable,
|
||||
Description: "Only show stable features.",
|
||||
},
|
||||
{
|
||||
Name: "Beta",
|
||||
Value: ReleaseLevelNameBeta,
|
||||
Description: "Show stable and beta features.",
|
||||
},
|
||||
{
|
||||
Name: "Experimental",
|
||||
Value: ReleaseLevelNameExperimental,
|
||||
Description: "Show all features",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := Register(releaseLevelOption)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
releaseLevelOptionFlag.Set()
|
||||
}
|
||||
|
||||
func updateReleaseLevel() {
|
||||
// get value
|
||||
value := releaseLevelOption.activeFallbackValue
|
||||
if releaseLevelOption.activeValue != nil {
|
||||
value = releaseLevelOption.activeValue
|
||||
}
|
||||
if releaseLevelOption.activeDefaultValue != nil {
|
||||
value = releaseLevelOption.activeDefaultValue
|
||||
}
|
||||
// set atomic value
|
||||
switch value.stringVal {
|
||||
case ReleaseLevelNameStable:
|
||||
atomic.StoreInt32(releaseLevel, int32(ReleaseLevelStable))
|
||||
case ReleaseLevelNameBeta:
|
||||
atomic.StoreInt32(releaseLevel, int32(ReleaseLevelBeta))
|
||||
case ReleaseLevelNameExperimental:
|
||||
atomic.StoreInt32(releaseLevel, int32(ReleaseLevelExperimental))
|
||||
default:
|
||||
atomic.StoreInt32(releaseLevel, int32(ReleaseLevelStable))
|
||||
}
|
||||
}
|
||||
|
||||
func getReleaseLevel() ReleaseLevel {
|
||||
return ReleaseLevel(atomic.LoadInt32(releaseLevel))
|
||||
}
|
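A sketch of switching the release level at runtime (not part of the commit). The release level is itself a registered option, so the regular setter applies, as TestReleaseLevel earlier in this diff demonstrates.

func enableBetaFeatures() error {
	if err := SetConfigOption(releaseLevelKey, ReleaseLevelNameBeta); err != nil {
		return err
	}
	_ = getReleaseLevel() // now reports ReleaseLevelBeta
	return nil
}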
base/config/set.go (new file, 235 lines)
@@ -0,0 +1,235 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrInvalidJSON is returned by SetConfig and SetDefaultConfig if they receive invalid json.
|
||||
ErrInvalidJSON = errors.New("json string invalid")
|
||||
|
||||
// ErrInvalidOptionType is returned by SetConfigOption and SetDefaultConfigOption if given an unsupported option type.
|
||||
ErrInvalidOptionType = errors.New("invalid option value type")
|
||||
|
||||
validityFlag = abool.NewBool(true)
|
||||
validityFlagLock sync.RWMutex
|
||||
)
|
||||
|
||||
// getValidityFlag returns a flag that signifies if the configuration has been changed. This flag must not be changed, only read.
|
||||
func getValidityFlag() *abool.AtomicBool {
|
||||
validityFlagLock.RLock()
|
||||
defer validityFlagLock.RUnlock()
|
||||
return validityFlag
|
||||
}
|
||||
|
||||
// signalChanges marks the config's validityFlag as dirty and eventually
|
||||
// triggers a config change event.
|
||||
func signalChanges() {
|
||||
// reset validity flag
|
||||
validityFlagLock.Lock()
|
||||
validityFlag.SetTo(false)
|
||||
validityFlag = abool.NewBool(true)
|
||||
validityFlagLock.Unlock()
|
||||
|
||||
module.EventConfigChange.Submit(struct{}{})
|
||||
}
|
||||
|
||||
// ValidateConfig validates the given configuration and returns all validation
|
||||
// errors as well as whether the given configuration contains unknown keys.
|
||||
func ValidateConfig(newValues map[string]interface{}) (validationErrors []*ValidationError, requiresRestart bool, containsUnknown bool) {
|
||||
// RLock the options because we are not adding or removing
|
||||
// options from the registration but rather only checking the
|
||||
// options value which is guarded by the option's lock itself.
|
||||
optionsLock.RLock()
|
||||
defer optionsLock.RUnlock()
|
||||
|
||||
var checked int
|
||||
for key, option := range options {
|
||||
newValue, ok := newValues[key]
|
||||
if ok {
|
||||
checked++
|
||||
|
||||
func() {
|
||||
option.Lock()
|
||||
defer option.Unlock()
|
||||
|
||||
newValue = migrateValue(option, newValue)
|
||||
_, err := validateValue(option, newValue)
|
||||
if err != nil {
|
||||
validationErrors = append(validationErrors, err)
|
||||
}
|
||||
|
||||
if option.RequiresRestart {
|
||||
requiresRestart = true
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
return validationErrors, requiresRestart, checked < len(newValues)
|
||||
}
|
||||
|
||||
// ReplaceConfig sets the (prioritized) user defined config.
|
||||
func ReplaceConfig(newValues map[string]interface{}) (validationErrors []*ValidationError, requiresRestart bool) {
|
||||
// RLock the options because we are not adding or removing
|
||||
// options from the registration but rather only update the
|
||||
// options value which is guarded by the option's lock itself.
|
||||
optionsLock.RLock()
|
||||
defer optionsLock.RUnlock()
|
||||
|
||||
for key, option := range options {
|
||||
newValue, ok := newValues[key]
|
||||
|
||||
func() {
|
||||
option.Lock()
|
||||
defer option.Unlock()
|
||||
|
||||
option.activeValue = nil
|
||||
if ok {
|
||||
newValue = migrateValue(option, newValue)
|
||||
valueCache, err := validateValue(option, newValue)
|
||||
if err == nil {
|
||||
option.activeValue = valueCache
|
||||
} else {
|
||||
validationErrors = append(validationErrors, err)
|
||||
}
|
||||
}
|
||||
handleOptionUpdate(option, true)
|
||||
|
||||
if option.RequiresRestart {
|
||||
requiresRestart = true
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
signalChanges()
|
||||
|
||||
return validationErrors, requiresRestart
|
||||
}
|
||||
|
||||
// ReplaceDefaultConfig sets the (fallback) default config.
|
||||
func ReplaceDefaultConfig(newValues map[string]interface{}) (validationErrors []*ValidationError, requiresRestart bool) {
|
||||
// RLock the options because we are not adding or removing
|
||||
// options from the registration but rather only update the
|
||||
// options value which is guarded by the option's lock itself.
|
||||
optionsLock.RLock()
|
||||
defer optionsLock.RUnlock()
|
||||
|
||||
for key, option := range options {
|
||||
newValue, ok := newValues[key]
|
||||
|
||||
func() {
|
||||
option.Lock()
|
||||
defer option.Unlock()
|
||||
|
||||
option.activeDefaultValue = nil
|
||||
if ok {
|
||||
newValue = migrateValue(option, newValue)
|
||||
valueCache, err := validateValue(option, newValue)
|
||||
if err == nil {
|
||||
option.activeDefaultValue = valueCache
|
||||
} else {
|
||||
validationErrors = append(validationErrors, err)
|
||||
}
|
||||
}
|
||||
handleOptionUpdate(option, true)
|
||||
|
||||
if option.RequiresRestart {
|
||||
requiresRestart = true
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
signalChanges()
|
||||
|
||||
return validationErrors, requiresRestart
|
||||
}
|
||||
|
||||
// SetConfigOption sets a single value in the (prioritized) user defined config.
|
||||
func SetConfigOption(key string, value any) error {
|
||||
return setConfigOption(key, value, true)
|
||||
}
|
||||
|
||||
func setConfigOption(key string, value any, push bool) (err error) {
|
||||
option, err := GetOption(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
option.Lock()
|
||||
if value == nil {
|
||||
option.activeValue = nil
|
||||
} else {
|
||||
value = migrateValue(option, value)
|
||||
valueCache, vErr := validateValue(option, value)
|
||||
if vErr == nil {
|
||||
option.activeValue = valueCache
|
||||
} else {
|
||||
err = vErr
|
||||
}
|
||||
}
|
||||
|
||||
// Add the "restart pending" annotation if the settings requires a restart.
|
||||
if option.RequiresRestart {
|
||||
option.setAnnotation(RestartPendingAnnotation, true)
|
||||
}
|
||||
|
||||
handleOptionUpdate(option, push)
|
||||
option.Unlock()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// finalize change, activate triggers
|
||||
signalChanges()
|
||||
|
||||
return SaveConfig()
|
||||
}
|
||||
|
||||
// SetDefaultConfigOption sets a single value in the (fallback) default config.
|
||||
func SetDefaultConfigOption(key string, value interface{}) error {
|
||||
return setDefaultConfigOption(key, value, true)
|
||||
}
|
||||
|
||||
func setDefaultConfigOption(key string, value interface{}, push bool) (err error) {
|
||||
option, err := GetOption(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
option.Lock()
|
||||
if value == nil {
|
||||
option.activeDefaultValue = nil
|
||||
} else {
|
||||
value = migrateValue(option, value)
|
||||
valueCache, vErr := validateValue(option, value)
|
||||
if vErr == nil {
|
||||
option.activeDefaultValue = valueCache
|
||||
} else {
|
||||
err = vErr
|
||||
}
|
||||
}
|
||||
|
||||
// Add the "restart pending" annotation if the settings requires a restart.
|
||||
if option.RequiresRestart {
|
||||
option.setAnnotation(RestartPendingAnnotation, true)
|
||||
}
|
||||
|
||||
handleOptionUpdate(option, push)
|
||||
option.Unlock()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// finalize change, activate triggers
|
||||
signalChanges()
|
||||
|
||||
// Do not save the configuration, as it only saves the active values, not the
|
||||
// active default value.
|
||||
return nil
|
||||
}
|
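For orientation, a minimal usage sketch of the setters and batch validation above. The option key is hypothetical and assumed to be registered beforehand, as the test file below does; the import path follows the repository layout.

package main

import (
    "fmt"

    "github.com/safing/portmaster/base/config"
)

func main() {
    // Hypothetical option key; assumes it was registered via config.Register beforehand.
    if err := config.SetConfigOption("monkey", "banana"); err != nil {
        fmt.Println("set failed:", err)
    }

    // Validate a batch of new values without applying them.
    validationErrors, requiresRestart, containsUnknown := config.ValidateConfig(
        map[string]interface{}{"monkey": "water"},
    )
    fmt.Println(len(validationErrors), requiresRestart, containsUnknown)
}
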
193
base/config/set_test.go
Normal file
|
@ -0,0 +1,193 @@
|
|||
//nolint:goconst
|
||||
package config
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestLayersGetters(t *testing.T) { //nolint:paralleltest
|
||||
// reset
|
||||
options = make(map[string]*Option)
|
||||
|
||||
mapData, err := JSONToMap([]byte(`
|
||||
{
|
||||
"monkey": "1",
|
||||
"elephant": 2,
|
||||
"zebras": {
|
||||
"zebra": ["black", "white"],
|
||||
"weird_zebra": ["black", -1]
|
||||
},
|
||||
"env": {
|
||||
"hot": true
|
||||
}
|
||||
}
|
||||
`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
validationErrors, _ := ReplaceConfig(mapData)
|
||||
if len(validationErrors) > 0 {
|
||||
t.Fatalf("%d errors, first: %s", len(validationErrors), validationErrors[0].Error())
|
||||
}
|
||||
|
||||
// Test missing values
|
||||
|
||||
missingString := GetAsString("missing", "fallback")
|
||||
if missingString() != "fallback" {
|
||||
t.Error("expected fallback value: fallback")
|
||||
}
|
||||
|
||||
missingStringArray := GetAsStringArray("missing", []string{"fallback"})
|
||||
if len(missingStringArray()) != 1 || missingStringArray()[0] != "fallback" {
|
||||
t.Error("expected fallback value: [fallback]")
|
||||
}
|
||||
|
||||
missingInt := GetAsInt("missing", -1)
|
||||
if missingInt() != -1 {
|
||||
t.Error("expected fallback value: -1")
|
||||
}
|
||||
|
||||
missingBool := GetAsBool("missing", false)
|
||||
if missingBool() {
|
||||
t.Error("expected fallback value: false")
|
||||
}
|
||||
|
||||
// Test value mismatch
|
||||
|
||||
notString := GetAsString("elephant", "fallback")
|
||||
if notString() != "fallback" {
|
||||
t.Error("expected fallback value: fallback")
|
||||
}
|
||||
|
||||
notStringArray := GetAsStringArray("elephant", []string{"fallback"})
|
||||
if len(notStringArray()) != 1 || notStringArray()[0] != "fallback" {
|
||||
t.Error("expected fallback value: [fallback]")
|
||||
}
|
||||
|
||||
mixedStringArray := GetAsStringArray("zebras/weird_zebra", []string{"fallback"})
|
||||
if len(mixedStringArray()) != 1 || mixedStringArray()[0] != "fallback" {
|
||||
t.Error("expected fallback value: [fallback]")
|
||||
}
|
||||
|
||||
notInt := GetAsInt("monkey", -1)
|
||||
if notInt() != -1 {
|
||||
t.Error("expected fallback value: -1")
|
||||
}
|
||||
|
||||
notBool := GetAsBool("monkey", false)
|
||||
if notBool() {
|
||||
t.Error("expected fallback value: false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLayersSetters(t *testing.T) { //nolint:paralleltest
|
||||
// reset
|
||||
options = make(map[string]*Option)
|
||||
|
||||
_ = Register(&Option{
|
||||
Name: "name",
|
||||
Key: "monkey",
|
||||
Description: "description",
|
||||
ReleaseLevel: ReleaseLevelStable,
|
||||
ExpertiseLevel: ExpertiseLevelUser,
|
||||
OptType: OptTypeString,
|
||||
DefaultValue: "banana",
|
||||
ValidationRegex: "^(banana|water)$",
|
||||
})
|
||||
_ = Register(&Option{
|
||||
Name: "name",
|
||||
Key: "zebras/zebra",
|
||||
Description: "description",
|
||||
ReleaseLevel: ReleaseLevelStable,
|
||||
ExpertiseLevel: ExpertiseLevelUser,
|
||||
OptType: OptTypeStringArray,
|
||||
DefaultValue: []string{"black", "white"},
|
||||
ValidationRegex: "^[a-z]+$",
|
||||
})
|
||||
_ = Register(&Option{
|
||||
Name: "name",
|
||||
Key: "elephant",
|
||||
Description: "description",
|
||||
ReleaseLevel: ReleaseLevelStable,
|
||||
ExpertiseLevel: ExpertiseLevelUser,
|
||||
OptType: OptTypeInt,
|
||||
DefaultValue: 2,
|
||||
ValidationRegex: "",
|
||||
})
|
||||
_ = Register(&Option{
|
||||
Name: "name",
|
||||
Key: "hot",
|
||||
Description: "description",
|
||||
ReleaseLevel: ReleaseLevelStable,
|
||||
ExpertiseLevel: ExpertiseLevelUser,
|
||||
OptType: OptTypeBool,
|
||||
DefaultValue: true,
|
||||
ValidationRegex: "",
|
||||
})
|
||||
|
||||
// correct types
|
||||
if err := SetConfigOption("monkey", "banana"); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := SetConfigOption("zebras/zebra", []string{"black", "white"}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := SetDefaultConfigOption("elephant", 2); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := SetDefaultConfigOption("hot", true); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// incorrect types
|
||||
if err := SetConfigOption("monkey", []string{"black", "white"}); err == nil {
|
||||
t.Error("should fail")
|
||||
}
|
||||
if err := SetConfigOption("zebras/zebra", 2); err == nil {
|
||||
t.Error("should fail")
|
||||
}
|
||||
if err := SetDefaultConfigOption("elephant", true); err == nil {
|
||||
t.Error("should fail")
|
||||
}
|
||||
if err := SetDefaultConfigOption("hot", "banana"); err == nil {
|
||||
t.Error("should fail")
|
||||
}
|
||||
if err := SetDefaultConfigOption("hot", []byte{0}); err == nil {
|
||||
t.Error("should fail")
|
||||
}
|
||||
|
||||
// validation fail
|
||||
if err := SetConfigOption("monkey", "dirt"); err == nil {
|
||||
t.Error("should fail")
|
||||
}
|
||||
if err := SetConfigOption("zebras/zebra", []string{"Element649"}); err == nil {
|
||||
t.Error("should fail")
|
||||
}
|
||||
|
||||
// unregistered checking
|
||||
if err := SetConfigOption("invalid", "banana"); err == nil {
|
||||
t.Error("should fail")
|
||||
}
|
||||
if err := SetConfigOption("invalid", []string{"black", "white"}); err == nil {
|
||||
t.Error("should fail")
|
||||
}
|
||||
if err := SetConfigOption("invalid", 2); err == nil {
|
||||
t.Error("should fail")
|
||||
}
|
||||
if err := SetConfigOption("invalid", true); err == nil {
|
||||
t.Error("should fail")
|
||||
}
|
||||
if err := SetConfigOption("invalid", []byte{0}); err == nil {
|
||||
t.Error("should fail")
|
||||
}
|
||||
|
||||
// delete
|
||||
if err := SetConfigOption("monkey", nil); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := SetDefaultConfigOption("elephant", nil); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := SetDefaultConfigOption("invalid_delete", nil); err == nil {
|
||||
t.Error("should fail")
|
||||
}
|
||||
}
|
239
base/config/validate.go
Normal file
|
@ -0,0 +1,239 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
|
||||
"github.com/safing/portmaster/base/log"
|
||||
)
|
||||
|
||||
type valueCache struct {
|
||||
stringVal string
|
||||
stringArrayVal []string
|
||||
intVal int64
|
||||
boolVal bool
|
||||
}
|
||||
|
||||
func (vc *valueCache) getData(opt *Option) interface{} {
|
||||
switch opt.OptType {
|
||||
case OptTypeBool:
|
||||
return vc.boolVal
|
||||
case OptTypeInt:
|
||||
return vc.intVal
|
||||
case OptTypeString:
|
||||
return vc.stringVal
|
||||
case OptTypeStringArray:
|
||||
return vc.stringArrayVal
|
||||
case optTypeAny:
|
||||
return nil
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// isAllowedPossibleValue checks if value is defined as a PossibleValue
|
||||
// in opt. If there are no possible values defined, the value is considered
|
||||
// allowed and nil is returned. isAllowedPossibleValue ensures the actual
|
||||
// value is an allowed primitive value by using reflection to convert
|
||||
// value and each PossibleValue to a comparable primitive if possible.
|
||||
// In case of complex value types isAllowedPossibleValue uses
|
||||
// reflect.DeepEqual as a fallback.
|
||||
func isAllowedPossibleValue(opt *Option, value interface{}) error {
|
||||
if opt.PossibleValues == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, val := range opt.PossibleValues {
|
||||
compareAgainst := val.Value
|
||||
valueType := reflect.TypeOf(value)
|
||||
|
||||
// loading ints from the configuration JSON does not preserve the correct type
|
||||
// as we get float64 instead. Make sure to convert them before comparing.
|
||||
if reflect.TypeOf(val.Value).ConvertibleTo(valueType) {
|
||||
compareAgainst = reflect.ValueOf(val.Value).Convert(valueType).Interface()
|
||||
}
|
||||
if compareAgainst == value {
|
||||
return nil
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(val.Value, value) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return errors.New("value is not allowed")
|
||||
}
|
||||
|
||||
// migrateValue runs all value migrations.
|
||||
func migrateValue(option *Option, value any) any {
|
||||
for _, migration := range option.Migrations {
|
||||
newValue := migration(option, value)
|
||||
if newValue != value {
|
||||
log.Debugf("config: migrated %s value from %v to %v", option.Key, value, newValue)
|
||||
}
|
||||
value = newValue
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// validateValue ensures that value matches the expected type of option.
|
||||
// It does not create a copy of the value!
|
||||
func validateValue(option *Option, value interface{}) (*valueCache, *ValidationError) { //nolint:gocyclo
|
||||
if option.OptType != OptTypeStringArray {
|
||||
if err := isAllowedPossibleValue(option, value); err != nil {
|
||||
return nil, &ValidationError{
|
||||
Option: option.copyOrNil(),
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var validated *valueCache
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
if option.OptType != OptTypeString {
|
||||
return nil, invalid(option, "expected type %s, got type %T", getTypeName(option.OptType), v)
|
||||
}
|
||||
if option.compiledRegex != nil {
|
||||
if !option.compiledRegex.MatchString(v) {
|
||||
return nil, invalid(option, "did not match validation regex")
|
||||
}
|
||||
}
|
||||
validated = &valueCache{stringVal: v}
|
||||
case []interface{}:
|
||||
vConverted := make([]string, len(v))
|
||||
for pos, entry := range v {
|
||||
s, ok := entry.(string)
|
||||
if !ok {
|
||||
return nil, invalid(option, "entry #%d is not a string", pos+1)
|
||||
}
|
||||
vConverted[pos] = s
|
||||
}
|
||||
// Call validation function again with converted value.
|
||||
var vErr *ValidationError
|
||||
validated, vErr = validateValue(option, vConverted)
|
||||
if vErr != nil {
|
||||
return nil, vErr
|
||||
}
|
||||
case []string:
|
||||
if option.OptType != OptTypeStringArray {
|
||||
return nil, invalid(option, "expected type %s, got type %T", getTypeName(option.OptType), v)
|
||||
}
|
||||
if option.compiledRegex != nil {
|
||||
for pos, entry := range v {
|
||||
if !option.compiledRegex.MatchString(entry) {
|
||||
return nil, invalid(option, "entry #%d did not match validation regex", pos+1)
|
||||
}
|
||||
|
||||
if err := isAllowedPossibleValue(option, entry); err != nil {
|
||||
return nil, invalid(option, "entry #%d is not allowed", pos+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
validated = &valueCache{stringArrayVal: v}
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, float32, float64:
|
||||
// uint64 is omitted, as it does not fit in an int64
|
||||
if option.OptType != OptTypeInt {
|
||||
return nil, invalid(option, "expected type %s, got type %T", getTypeName(option.OptType), v)
|
||||
}
|
||||
if option.compiledRegex != nil {
|
||||
// we need to use %v here so we handle float and int correctly.
|
||||
if !option.compiledRegex.MatchString(fmt.Sprintf("%v", v)) {
|
||||
return nil, invalid(option, "did not match validation regex")
|
||||
}
|
||||
}
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
validated = &valueCache{intVal: int64(v)}
|
||||
case int8:
|
||||
validated = &valueCache{intVal: int64(v)}
|
||||
case int16:
|
||||
validated = &valueCache{intVal: int64(v)}
|
||||
case int32:
|
||||
validated = &valueCache{intVal: int64(v)}
|
||||
case int64:
|
||||
validated = &valueCache{intVal: v}
|
||||
case uint:
|
||||
validated = &valueCache{intVal: int64(v)}
|
||||
case uint8:
|
||||
validated = &valueCache{intVal: int64(v)}
|
||||
case uint16:
|
||||
validated = &valueCache{intVal: int64(v)}
|
||||
case uint32:
|
||||
validated = &valueCache{intVal: int64(v)}
|
||||
case float32:
|
||||
// convert if float has no decimals
|
||||
if math.Remainder(float64(v), 1) == 0 {
|
||||
validated = &valueCache{intVal: int64(v)}
|
||||
} else {
|
||||
return nil, invalid(option, "failed to convert float32 to int64")
|
||||
}
|
||||
case float64:
|
||||
// convert if float has no decimals
|
||||
if math.Remainder(v, 1) == 0 {
|
||||
validated = &valueCache{intVal: int64(v)}
|
||||
} else {
|
||||
return nil, invalid(option, "failed to convert float64 to int64")
|
||||
}
|
||||
default:
|
||||
return nil, invalid(option, "internal error")
|
||||
}
|
||||
case bool:
|
||||
if option.OptType != OptTypeBool {
|
||||
return nil, invalid(option, "expected type %s, got type %T", getTypeName(option.OptType), v)
|
||||
}
|
||||
validated = &valueCache{boolVal: v}
|
||||
default:
|
||||
return nil, invalid(option, "invalid option value type: %T", value)
|
||||
}
|
||||
|
||||
// Check if there is an additional function to validate the value.
|
||||
if option.ValidationFunc != nil {
|
||||
var err error
|
||||
switch option.OptType {
|
||||
case optTypeAny:
|
||||
err = errors.New("internal error")
|
||||
case OptTypeString:
|
||||
err = option.ValidationFunc(validated.stringVal)
|
||||
case OptTypeStringArray:
|
||||
err = option.ValidationFunc(validated.stringArrayVal)
|
||||
case OptTypeInt:
|
||||
err = option.ValidationFunc(validated.intVal)
|
||||
case OptTypeBool:
|
||||
err = option.ValidationFunc(validated.boolVal)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, &ValidationError{
|
||||
Option: option.copyOrNil(),
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return validated, nil
|
||||
}
|
||||
|
||||
// ValidationError error holds details about a config option value validation error.
|
||||
type ValidationError struct {
|
||||
Option *Option
|
||||
Err error
|
||||
}
|
||||
|
||||
// Error returns the formatted error.
|
||||
func (ve *ValidationError) Error() string {
|
||||
return fmt.Sprintf("validation of %s failed: %s", ve.Option.Key, ve.Err)
|
||||
}
|
||||
|
||||
// Unwrap returns the wrapped error.
|
||||
func (ve *ValidationError) Unwrap() error {
|
||||
return ve.Err
|
||||
}
|
||||
|
||||
func invalid(option *Option, format string, a ...interface{}) *ValidationError {
|
||||
return &ValidationError{
|
||||
Option: option.copyOrNil(),
|
||||
Err: fmt.Errorf(format, a...),
|
||||
}
|
||||
}
|
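The float64 handling above matters in practice: values decoded from JSON arrive as float64 even for integer options, and validateValue converts them back. A small sketch, mirroring the registration style used in the tests (option key and field values are illustrative):

package main

import (
    "fmt"

    "github.com/safing/portmaster/base/config"
)

func main() {
    // Illustrative registration, mirroring the test setup for an int option.
    _ = config.Register(&config.Option{
        Name:           "name",
        Key:            "elephant",
        Description:    "description",
        ReleaseLevel:   config.ReleaseLevelStable,
        ExpertiseLevel: config.ExpertiseLevelUser,
        OptType:        config.OptTypeInt,
        DefaultValue:   2,
    })

    // JSON numbers decode to float64; validateValue converts them back to int64.
    mapData, _ := config.JSONToMap([]byte(`{"elephant": 3}`))
    validationErrors, _ := config.ReplaceConfig(mapData)
    fmt.Println(len(validationErrors)) // 0: 3 (float64) is accepted for the int option
}
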
32
base/config/validity.go
Normal file
|
@ -0,0 +1,32 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"github.com/tevino/abool"
|
||||
)
|
||||
|
||||
// ValidityFlag is a flag that signifies if the configuration has been changed. It is not safe for concurrent use.
|
||||
type ValidityFlag struct {
|
||||
flag *abool.AtomicBool
|
||||
}
|
||||
|
||||
// NewValidityFlag returns a flag that signifies if the configuration has been changed.
|
||||
// It always starts out as invalid. Refresh to start with the current value.
|
||||
func NewValidityFlag() *ValidityFlag {
|
||||
vf := &ValidityFlag{
|
||||
flag: abool.New(),
|
||||
}
|
||||
return vf
|
||||
}
|
||||
|
||||
// IsValid returns if the configuration is still valid.
|
||||
func (vf *ValidityFlag) IsValid() bool {
|
||||
return vf.flag.IsSet()
|
||||
}
|
||||
|
||||
// Refresh refreshes the flag and makes it reusable.
|
||||
func (vf *ValidityFlag) Refresh() {
|
||||
validityFlagLock.RLock()
|
||||
defer validityFlagLock.RUnlock()
|
||||
|
||||
vf.flag = validityFlag
|
||||
}
|
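A sketch of the intended usage pattern for the flag above: cache derived state and recompute it only once the configuration has changed (the option key is illustrative):

package main

import (
    "fmt"

    "github.com/safing/portmaster/base/config"
)

func main() {
    // Illustrative option key; GetAsString falls back if it is not set.
    getName := config.GetAsString("core/name", "default")

    validity := config.NewValidityFlag()
    cached := ""

    for i := 0; i < 3; i++ {
        if !validity.IsValid() {
            // Config changed (or first run): refresh the flag and recompute.
            validity.Refresh()
            cached = getName()
        }
        fmt.Println(cached)
    }
}
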
368
base/container/container.go
Normal file
|
@ -0,0 +1,368 @@
|
|||
package container
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"github.com/safing/structures/varint"
|
||||
)
|
||||
|
||||
// Container is a []byte slice on steroids, allowing for quick data appending, prepending and fetching.
|
||||
type Container struct {
|
||||
compartments [][]byte
|
||||
offset int
|
||||
err error
|
||||
}
|
||||
|
||||
// Data Handling
|
||||
|
||||
// NewContainer is DEPRECATED, please use New() instead; it does the same thing.
|
||||
func NewContainer(data ...[]byte) *Container {
|
||||
return &Container{
|
||||
compartments: data,
|
||||
}
|
||||
}
|
||||
|
||||
// New creates a new container with an optional initial []byte slice. Data will NOT be copied.
|
||||
func New(data ...[]byte) *Container {
|
||||
return &Container{
|
||||
compartments: data,
|
||||
}
|
||||
}
|
||||
|
||||
// Prepend prepends data. Data will NOT be copied.
|
||||
func (c *Container) Prepend(data []byte) {
|
||||
if c.offset < 1 {
|
||||
c.renewCompartments()
|
||||
}
|
||||
c.offset--
|
||||
c.compartments[c.offset] = data
|
||||
}
|
||||
|
||||
// Append appends the given data. Data will NOT be copied.
|
||||
func (c *Container) Append(data []byte) {
|
||||
c.compartments = append(c.compartments, data)
|
||||
}
|
||||
|
||||
// PrependNumber prepends a number (varint encoded).
|
||||
func (c *Container) PrependNumber(n uint64) {
|
||||
c.Prepend(varint.Pack64(n))
|
||||
}
|
||||
|
||||
// AppendNumber appends a number (varint encoded).
|
||||
func (c *Container) AppendNumber(n uint64) {
|
||||
c.compartments = append(c.compartments, varint.Pack64(n))
|
||||
}
|
||||
|
||||
// PrependInt prepends an int (varint encoded).
|
||||
func (c *Container) PrependInt(n int) {
|
||||
c.Prepend(varint.Pack64(uint64(n)))
|
||||
}
|
||||
|
||||
// AppendInt appends an int (varint encoded).
|
||||
func (c *Container) AppendInt(n int) {
|
||||
c.compartments = append(c.compartments, varint.Pack64(uint64(n)))
|
||||
}
|
||||
|
||||
// AppendAsBlock appends the length of the data and the data itself. Data will NOT be copied.
|
||||
func (c *Container) AppendAsBlock(data []byte) {
|
||||
c.AppendNumber(uint64(len(data)))
|
||||
c.Append(data)
|
||||
}
|
||||
|
||||
// PrependAsBlock prepends the length of the data and the data itself. Data will NOT be copied.
|
||||
func (c *Container) PrependAsBlock(data []byte) {
|
||||
c.Prepend(data)
|
||||
c.PrependNumber(uint64(len(data)))
|
||||
}
|
||||
|
||||
// AppendContainer appends another Container. Data will NOT be copied.
|
||||
func (c *Container) AppendContainer(data *Container) {
|
||||
c.compartments = append(c.compartments, data.compartments...)
|
||||
}
|
||||
|
||||
// AppendContainerAsBlock appends another Container (length and data). Data will NOT be copied.
|
||||
func (c *Container) AppendContainerAsBlock(data *Container) {
|
||||
c.AppendNumber(uint64(data.Length()))
|
||||
c.compartments = append(c.compartments, data.compartments...)
|
||||
}
|
||||
|
||||
// HoldsData returns true if the Container holds any data.
|
||||
func (c *Container) HoldsData() bool {
|
||||
for i := c.offset; i < len(c.compartments); i++ {
|
||||
if len(c.compartments[i]) > 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Length returns the full length of all bytes held by the container.
|
||||
func (c *Container) Length() (length int) {
|
||||
for i := c.offset; i < len(c.compartments); i++ {
|
||||
length += len(c.compartments[i])
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Replace replaces all held data with a new data slice. Data will NOT be copied.
|
||||
func (c *Container) Replace(data []byte) {
|
||||
c.compartments = [][]byte{data}
|
||||
}
|
||||
|
||||
// CompileData concatenates all bytes held by the container and returns it as one single []byte slice. Data will NOT be copied and is NOT consumed.
|
||||
func (c *Container) CompileData() []byte {
|
||||
if len(c.compartments) != 1 {
|
||||
newBuf := make([]byte, c.Length())
|
||||
copyBuf := newBuf
|
||||
for i := c.offset; i < len(c.compartments); i++ {
|
||||
copy(copyBuf, c.compartments[i])
|
||||
copyBuf = copyBuf[len(c.compartments[i]):]
|
||||
}
|
||||
c.compartments = [][]byte{newBuf}
|
||||
c.offset = 0
|
||||
}
|
||||
return c.compartments[0]
|
||||
}
|
||||
|
||||
// Get returns the given amount of bytes. Data MAY be copied and IS consumed.
|
||||
func (c *Container) Get(n int) ([]byte, error) {
|
||||
buf := c.Peek(n)
|
||||
if len(buf) < n {
|
||||
return nil, errors.New("container: not enough data to return")
|
||||
}
|
||||
c.skip(len(buf))
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
// GetAll returns all data. Data MAY be copied and IS consumed.
|
||||
func (c *Container) GetAll() []byte {
|
||||
// TODO: Improve.
|
||||
buf := c.Peek(c.Length())
|
||||
c.skip(len(buf))
|
||||
return buf
|
||||
}
|
||||
|
||||
// GetAsContainer returns the given amount of bytes in a new container. Data will NOT be copied and IS consumed.
|
||||
func (c *Container) GetAsContainer(n int) (*Container, error) {
|
||||
newC := c.PeekContainer(n)
|
||||
if newC == nil {
|
||||
return nil, errors.New("container: not enough data to return")
|
||||
}
|
||||
c.skip(n)
|
||||
return newC, nil
|
||||
}
|
||||
|
||||
// GetMax returns as much as possible, but the given amount of bytes at maximum. Data MAY be copied and IS consumed.
|
||||
func (c *Container) GetMax(n int) []byte {
|
||||
buf := c.Peek(n)
|
||||
c.skip(len(buf))
|
||||
return buf
|
||||
}
|
||||
|
||||
// WriteToSlice copies data to the given slice until it is full, or the container is empty. It returns the number of bytes written and whether the container is now empty. Data IS copied and IS consumed.
|
||||
func (c *Container) WriteToSlice(slice []byte) (n int, containerEmptied bool) {
|
||||
for i := c.offset; i < len(c.compartments); i++ {
|
||||
copy(slice, c.compartments[i])
|
||||
if len(slice) < len(c.compartments[i]) {
|
||||
// only part was copied
|
||||
n += len(slice)
|
||||
c.compartments[i] = c.compartments[i][len(slice):]
|
||||
c.checkOffset()
|
||||
return n, false
|
||||
}
|
||||
// all was copied
|
||||
n += len(c.compartments[i])
|
||||
slice = slice[len(c.compartments[i]):]
|
||||
c.compartments[i] = nil
|
||||
c.offset = i + 1
|
||||
}
|
||||
c.checkOffset()
|
||||
return n, true
|
||||
}
|
||||
|
||||
// WriteAllTo writes all the data to the given io.Writer. Data IS NOT copied (but may be by writer) and IS NOT consumed.
|
||||
func (c *Container) WriteAllTo(writer io.Writer) error {
|
||||
for i := c.offset; i < len(c.compartments); i++ {
|
||||
written := 0
|
||||
for written < len(c.compartments[i]) {
|
||||
n, err := writer.Write(c.compartments[i][written:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
written += n
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Container) clean() {
|
||||
if c.offset > 100 {
|
||||
c.renewCompartments()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Container) renewCompartments() {
|
||||
baseLength := len(c.compartments) - c.offset + 5
|
||||
newCompartments := make([][]byte, baseLength, baseLength+5)
|
||||
copy(newCompartments[5:], c.compartments[c.offset:])
|
||||
c.compartments = newCompartments
|
||||
c.offset = 4
|
||||
}
|
||||
|
||||
func (c *Container) carbonCopy() *Container {
|
||||
newC := &Container{
|
||||
compartments: make([][]byte, len(c.compartments)),
|
||||
offset: c.offset,
|
||||
err: c.err,
|
||||
}
|
||||
copy(newC.compartments, c.compartments)
|
||||
return newC
|
||||
}
|
||||
|
||||
func (c *Container) checkOffset() {
|
||||
if c.offset >= len(c.compartments) {
|
||||
c.offset = len(c.compartments) / 2
|
||||
}
|
||||
}
|
||||
|
||||
// Block Handling
|
||||
|
||||
// PrependLength prepends the current full length of all bytes in the container.
|
||||
func (c *Container) PrependLength() {
|
||||
c.Prepend(varint.Pack64(uint64(c.Length())))
|
||||
}
|
||||
|
||||
// Peek returns the given amount of bytes. Data MAY be copied and IS NOT consumed.
|
||||
func (c *Container) Peek(n int) []byte {
|
||||
// Check requested length.
|
||||
if n <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if the first slice holds enough data.
|
||||
if len(c.compartments[c.offset]) >= n {
|
||||
return c.compartments[c.offset][:n]
|
||||
}
|
||||
|
||||
// Start gathering data.
|
||||
slice := make([]byte, n)
|
||||
copySlice := slice
|
||||
n = 0
|
||||
for i := c.offset; i < len(c.compartments); i++ {
|
||||
copy(copySlice, c.compartments[i])
|
||||
if len(copySlice) <= len(c.compartments[i]) {
|
||||
n += len(copySlice)
|
||||
return slice[:n]
|
||||
}
|
||||
n += len(c.compartments[i])
|
||||
copySlice = copySlice[len(c.compartments[i]):]
|
||||
}
|
||||
return slice[:n]
|
||||
}
|
||||
|
||||
// PeekContainer returns the given amount of bytes in a new container. Data will NOT be copied and IS NOT consumed.
|
||||
func (c *Container) PeekContainer(n int) (newC *Container) {
|
||||
// Check requested length.
|
||||
if n < 0 {
|
||||
return nil
|
||||
} else if n == 0 {
|
||||
return &Container{}
|
||||
}
|
||||
|
||||
newC = &Container{}
|
||||
for i := c.offset; i < len(c.compartments); i++ {
|
||||
if n >= len(c.compartments[i]) {
|
||||
newC.compartments = append(newC.compartments, c.compartments[i])
|
||||
n -= len(c.compartments[i])
|
||||
} else {
|
||||
newC.compartments = append(newC.compartments, c.compartments[i][:n])
|
||||
n = 0
|
||||
}
|
||||
}
|
||||
if n > 0 {
|
||||
return nil
|
||||
}
|
||||
return newC
|
||||
}
|
||||
|
||||
func (c *Container) skip(n int) {
|
||||
for i := c.offset; i < len(c.compartments); i++ {
|
||||
if len(c.compartments[i]) <= n {
|
||||
n -= len(c.compartments[i])
|
||||
c.offset = i + 1
|
||||
c.compartments[i] = nil
|
||||
if n == 0 {
|
||||
c.checkOffset()
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.compartments[i] = c.compartments[i][n:]
|
||||
c.checkOffset()
|
||||
return
|
||||
}
|
||||
}
|
||||
c.checkOffset()
|
||||
}
|
||||
|
||||
// GetNextBlock returns the next block of data defined by a varint. Data MAY be copied and IS consumed.
|
||||
func (c *Container) GetNextBlock() ([]byte, error) {
|
||||
blockSize, err := c.GetNextN64()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.Get(int(blockSize))
|
||||
}
|
||||
|
||||
// GetNextBlockAsContainer returns the next block of data as a Container defined by a varint. Data will NOT be copied and IS consumed.
|
||||
func (c *Container) GetNextBlockAsContainer() (*Container, error) {
|
||||
blockSize, err := c.GetNextN64()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.GetAsContainer(int(blockSize))
|
||||
}
|
||||
|
||||
// GetNextN8 parses and returns a varint of type uint8.
|
||||
func (c *Container) GetNextN8() (uint8, error) {
|
||||
buf := c.Peek(2)
|
||||
num, n, err := varint.Unpack8(buf)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
c.skip(n)
|
||||
return num, nil
|
||||
}
|
||||
|
||||
// GetNextN16 parses and returns a varint of type uint16.
|
||||
func (c *Container) GetNextN16() (uint16, error) {
|
||||
buf := c.Peek(3)
|
||||
num, n, err := varint.Unpack16(buf)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
c.skip(n)
|
||||
return num, nil
|
||||
}
|
||||
|
||||
// GetNextN32 parses and returns a varint of type uint32.
|
||||
func (c *Container) GetNextN32() (uint32, error) {
|
||||
buf := c.Peek(5)
|
||||
num, n, err := varint.Unpack32(buf)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
c.skip(n)
|
||||
return num, nil
|
||||
}
|
||||
|
||||
// GetNextN64 parses and returns a varint of type uint64.
|
||||
func (c *Container) GetNextN64() (uint64, error) {
|
||||
buf := c.Peek(10)
|
||||
num, n, err := varint.Unpack64(buf)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
c.skip(n)
|
||||
return num, nil
|
||||
}
|
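A short sketch of the block round-trip that the methods above provide; the payloads are illustrative and the import path follows the repository layout:

package main

import (
    "fmt"

    "github.com/safing/portmaster/base/container"
)

func main() {
    // Build a container from two payloads, each prefixed with its varint length.
    c := container.New()
    c.AppendAsBlock([]byte("hello"))
    c.AppendAsBlock([]byte("world"))

    // Read the blocks back in order.
    for c.HoldsData() {
        block, err := c.GetNextBlock()
        if err != nil {
            fmt.Println("read failed:", err)
            return
        }
        fmt.Println(string(block))
    }
}
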
208
base/container/container_test.go
Normal file
|
@ -0,0 +1,208 @@
|
|||
package container
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/safing/portmaster/base/utils"
|
||||
)
|
||||
|
||||
var (
|
||||
testData = []byte("The quick brown fox jumps over the lazy dog")
|
||||
testDataSplitted = [][]byte{
|
||||
[]byte("T"),
|
||||
[]byte("he"),
|
||||
[]byte(" qu"),
|
||||
[]byte("ick "),
|
||||
[]byte("brown"),
|
||||
[]byte(" fox j"),
|
||||
[]byte("umps ov"),
|
||||
[]byte("er the l"),
|
||||
[]byte("azy dog"),
|
||||
}
|
||||
)
|
||||
|
||||
func TestContainerDataHandling(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c1 := New(utils.DuplicateBytes(testData))
|
||||
c1c := c1.carbonCopy()
|
||||
|
||||
c2 := New()
|
||||
for range len(testData) {
|
||||
oneByte := make([]byte, 1)
|
||||
c1c.WriteToSlice(oneByte)
|
||||
c2.Append(oneByte)
|
||||
}
|
||||
c2c := c2.carbonCopy()
|
||||
|
||||
c3 := New()
|
||||
for i := len(c2c.compartments) - 1; i >= c2c.offset; i-- {
|
||||
c3.Prepend(c2c.compartments[i])
|
||||
}
|
||||
c3c := c3.carbonCopy()
|
||||
|
||||
d4 := make([]byte, len(testData)*2)
|
||||
n, _ := c3c.WriteToSlice(d4)
|
||||
d4 = d4[:n]
|
||||
c3c = c3.carbonCopy()
|
||||
|
||||
d5 := make([]byte, len(testData))
|
||||
for i := range len(testData) {
|
||||
c3c.WriteToSlice(d5[i : i+1])
|
||||
}
|
||||
|
||||
c6 := New()
|
||||
c6.Replace(testData)
|
||||
|
||||
c7 := New(testDataSplitted[0])
|
||||
for i := 1; i < len(testDataSplitted); i++ {
|
||||
c7.Append(testDataSplitted[i])
|
||||
}
|
||||
|
||||
c8 := New(testDataSplitted...)
|
||||
for range 110 {
|
||||
c8.Prepend(nil)
|
||||
}
|
||||
c8.clean()
|
||||
|
||||
c9 := c8.PeekContainer(len(testData))
|
||||
|
||||
c10 := c9.PeekContainer(len(testData) - 1)
|
||||
c10.Append(testData[len(testData)-1:])
|
||||
|
||||
compareMany(t, testData, c1.CompileData(), c2.CompileData(), c3.CompileData(), d4, d5, c6.CompileData(), c7.CompileData(), c8.CompileData(), c9.CompileData(), c10.CompileData())
|
||||
}
|
||||
|
||||
func compareMany(t *testing.T, reference []byte, other ...[]byte) {
|
||||
t.Helper()
|
||||
|
||||
for i, cmp := range other {
|
||||
if !bytes.Equal(reference, cmp) {
|
||||
t.Errorf("sample %d does not match reference: sample is '%s'", i+1, string(cmp))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDataFetching(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c1 := New(utils.DuplicateBytes(testData))
|
||||
data := c1.GetMax(1)
|
||||
if string(data[0]) != "T" {
|
||||
t.Errorf("failed to GetMax(1), got %s, expected %s", string(data), "T")
|
||||
}
|
||||
|
||||
_, err := c1.Get(1000)
|
||||
if err == nil {
|
||||
t.Error("should fail")
|
||||
}
|
||||
|
||||
_, err = c1.GetAsContainer(1000)
|
||||
if err == nil {
|
||||
t.Error("should fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBlocks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c1 := New(utils.DuplicateBytes(testData))
|
||||
c1.PrependLength()
|
||||
|
||||
n, err := c1.GetNextN8()
|
||||
if err != nil {
|
||||
t.Errorf("GetNextN8() failed: %s", err)
|
||||
}
|
||||
if n != 43 {
|
||||
t.Errorf("n should be 43, was %d", n)
|
||||
}
|
||||
c1.PrependLength()
|
||||
|
||||
n2, err := c1.GetNextN16()
|
||||
if err != nil {
|
||||
t.Errorf("GetNextN16() failed: %s", err)
|
||||
}
|
||||
if n2 != 43 {
|
||||
t.Errorf("n should be 43, was %d", n2)
|
||||
}
|
||||
c1.PrependLength()
|
||||
|
||||
n3, err := c1.GetNextN32()
|
||||
if err != nil {
|
||||
t.Errorf("GetNextN32() failed: %s", err)
|
||||
}
|
||||
if n3 != 43 {
|
||||
t.Errorf("n should be 43, was %d", n3)
|
||||
}
|
||||
c1.PrependLength()
|
||||
|
||||
n4, err := c1.GetNextN64()
|
||||
if err != nil {
|
||||
t.Errorf("GetNextN64() failed: %s", err)
|
||||
}
|
||||
if n4 != 43 {
|
||||
t.Errorf("n should be 43, was %d", n4)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainerBlockHandling(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c1 := New(utils.DuplicateBytes(testData))
|
||||
c1.PrependLength()
|
||||
c1.AppendAsBlock(testData)
|
||||
c1c := c1.carbonCopy()
|
||||
|
||||
c2 := New(nil)
|
||||
for range c1.Length() {
|
||||
oneByte := make([]byte, 1)
|
||||
c1c.WriteToSlice(oneByte)
|
||||
c2.Append(oneByte)
|
||||
}
|
||||
|
||||
c3 := New(testDataSplitted[0])
|
||||
for i := 1; i < len(testDataSplitted); i++ {
|
||||
c3.Append(testDataSplitted[i])
|
||||
}
|
||||
c3.PrependLength()
|
||||
|
||||
d1, err := c1.GetNextBlock()
|
||||
if err != nil {
|
||||
t.Errorf("GetNextBlock failed: %s", err)
|
||||
}
|
||||
d2, err := c1.GetNextBlock()
|
||||
if err != nil {
|
||||
t.Errorf("GetNextBlock failed: %s", err)
|
||||
}
|
||||
d3, err := c2.GetNextBlock()
|
||||
if err != nil {
|
||||
t.Errorf("GetNextBlock failed: %s", err)
|
||||
}
|
||||
d4, err := c2.GetNextBlock()
|
||||
if err != nil {
|
||||
t.Errorf("GetNextBlock failed: %s", err)
|
||||
}
|
||||
d5, err := c3.GetNextBlock()
|
||||
if err != nil {
|
||||
t.Errorf("GetNextBlock failed: %s", err)
|
||||
}
|
||||
|
||||
compareMany(t, testData, d1, d2, d3, d4, d5)
|
||||
}
|
||||
|
||||
func TestContainerMisc(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c1 := New()
|
||||
d1 := c1.CompileData()
|
||||
if len(d1) > 0 {
|
||||
t.Fatalf("empty container should not hold any data")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeprecated(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
NewContainer(utils.DuplicateBytes(testData))
|
||||
}
|
26
base/container/doc.go
Normal file
|
@ -0,0 +1,26 @@
|
|||
// Package container gives you a []byte slice on steroids, allowing for quick data appending, prepending and fetching as well as transparent error transportation.
|
||||
//
|
||||
// A Container is basically a [][]byte slice that just appends new []byte slices and only copies things around when necessary.
|
||||
//
|
||||
// Byte slices added to the Container are not changed or appended to, so as not to corrupt any other data that may lie before or after the given slice.
|
||||
// If interested, consider the following example to understand why this is important:
|
||||
//
|
||||
// package main
|
||||
//
|
||||
// import (
|
||||
// "fmt"
|
||||
// )
|
||||
//
|
||||
// func main() {
|
||||
// a := []byte{0, 1,2,3,4,5,6,7,8,9}
|
||||
// fmt.Printf("a: %+v\n", a)
|
||||
// fmt.Printf("\nmaking changes...\n(we are not changing a directly)\n\n")
|
||||
// b := a[2:6]
|
||||
// c := append(b, 10, 11)
|
||||
// fmt.Printf("b: %+v\n", b)
|
||||
// fmt.Printf("c: %+v\n", c)
|
||||
// fmt.Printf("a: %+v\n", a)
|
||||
// }
|
||||
//
|
||||
// run it here: https://play.golang.org/p/xu1BXT3QYeE
|
||||
package container
|
21
base/container/serialization.go
Normal file
|
@ -0,0 +1,21 @@
|
|||
package container
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// MarshalJSON serializes the container as a JSON byte array.
|
||||
func (c *Container) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(c.CompileData())
|
||||
}
|
||||
|
||||
// UnmarshalJSON deserializes a container from a JSON byte array.
|
||||
func (c *Container) UnmarshalJSON(data []byte) error {
|
||||
var raw []byte
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.compartments = [][]byte{raw}
|
||||
return nil
|
||||
}
|
116
base/database/accessor/accessor-json-bytes.go
Normal file
|
@ -0,0 +1,116 @@
|
|||
package accessor
|
||||
|
||||
import (
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// JSONBytesAccessor is a JSON byte slice with get and set functions.
|
||||
type JSONBytesAccessor struct {
|
||||
json *[]byte
|
||||
}
|
||||
|
||||
// NewJSONBytesAccessor adds the Accessor interface to a JSON byte slice.
|
||||
func NewJSONBytesAccessor(json *[]byte) *JSONBytesAccessor {
|
||||
return &JSONBytesAccessor{
|
||||
json: json,
|
||||
}
|
||||
}
|
||||
|
||||
// Set sets the value identified by key.
|
||||
func (ja *JSONBytesAccessor) Set(key string, value interface{}) error {
|
||||
result := gjson.GetBytes(*ja.json, key)
|
||||
if result.Exists() {
|
||||
err := checkJSONValueType(result, key, value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
newJSON, err := sjson.SetBytes(*ja.json, key, value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*ja.json = newJSON
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get returns the value found by the given json key and whether it could be successfully extracted.
|
||||
func (ja *JSONBytesAccessor) Get(key string) (value interface{}, ok bool) {
|
||||
result := gjson.GetBytes(*ja.json, key)
|
||||
if !result.Exists() {
|
||||
return nil, false
|
||||
}
|
||||
return result.Value(), true
|
||||
}
|
||||
|
||||
// GetString returns the string found by the given json key and whether it could be successfully extracted.
|
||||
func (ja *JSONBytesAccessor) GetString(key string) (value string, ok bool) {
|
||||
result := gjson.GetBytes(*ja.json, key)
|
||||
if !result.Exists() || result.Type != gjson.String {
|
||||
return emptyString, false
|
||||
}
|
||||
return result.String(), true
|
||||
}
|
||||
|
||||
// GetStringArray returns the []string found by the given json key and whether it could be successfully extracted.
|
||||
func (ja *JSONBytesAccessor) GetStringArray(key string) (value []string, ok bool) {
|
||||
result := gjson.GetBytes(*ja.json, key)
|
||||
if !result.Exists() && !result.IsArray() {
|
||||
return nil, false
|
||||
}
|
||||
slice := result.Array()
|
||||
sliceCopy := make([]string, len(slice))
|
||||
for i, res := range slice {
|
||||
if res.Type == gjson.String {
|
||||
sliceCopy[i] = res.String()
|
||||
} else {
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
return sliceCopy, true
|
||||
}
|
||||
|
||||
// GetInt returns the int found by the given json key and whether it could be successfully extracted.
|
||||
func (ja *JSONBytesAccessor) GetInt(key string) (value int64, ok bool) {
|
||||
result := gjson.GetBytes(*ja.json, key)
|
||||
if !result.Exists() || result.Type != gjson.Number {
|
||||
return 0, false
|
||||
}
|
||||
return result.Int(), true
|
||||
}
|
||||
|
||||
// GetFloat returns the float found by the given json key and whether it could be successfully extracted.
|
||||
func (ja *JSONBytesAccessor) GetFloat(key string) (value float64, ok bool) {
|
||||
result := gjson.GetBytes(*ja.json, key)
|
||||
if !result.Exists() || result.Type != gjson.Number {
|
||||
return 0, false
|
||||
}
|
||||
return result.Float(), true
|
||||
}
|
||||
|
||||
// GetBool returns the bool found by the given json key and whether it could be successfully extracted.
|
||||
func (ja *JSONBytesAccessor) GetBool(key string) (value bool, ok bool) {
|
||||
result := gjson.GetBytes(*ja.json, key)
|
||||
switch {
|
||||
case !result.Exists():
|
||||
return false, false
|
||||
case result.Type == gjson.True:
|
||||
return true, true
|
||||
case result.Type == gjson.False:
|
||||
return false, true
|
||||
default:
|
||||
return false, false
|
||||
}
|
||||
}
|
||||
|
||||
// Exists returns whether the given key exists.
|
||||
func (ja *JSONBytesAccessor) Exists(key string) bool {
|
||||
result := gjson.GetBytes(*ja.json, key)
|
||||
return result.Exists()
|
||||
}
|
||||
|
||||
// Type returns the accessor type as a string.
|
||||
func (ja *JSONBytesAccessor) Type() string {
|
||||
return "JSONBytesAccessor"
|
||||
}
|
140
base/database/accessor/accessor-json-string.go
Normal file
|
@ -0,0 +1,140 @@
|
|||
package accessor
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// JSONAccessor is a JSON string with get and set functions.
|
||||
type JSONAccessor struct {
|
||||
json *string
|
||||
}
|
||||
|
||||
// NewJSONAccessor adds the Accessor interface to a JSON string.
|
||||
func NewJSONAccessor(json *string) *JSONAccessor {
|
||||
return &JSONAccessor{
|
||||
json: json,
|
||||
}
|
||||
}
|
||||
|
||||
// Set sets the value identified by key.
|
||||
func (ja *JSONAccessor) Set(key string, value interface{}) error {
|
||||
result := gjson.Get(*ja.json, key)
|
||||
if result.Exists() {
|
||||
err := checkJSONValueType(result, key, value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
newJSON, err := sjson.Set(*ja.json, key, value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*ja.json = newJSON
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkJSONValueType(jsonValue gjson.Result, key string, value interface{}) error {
|
||||
switch value.(type) {
|
||||
case string:
|
||||
if jsonValue.Type != gjson.String {
|
||||
return fmt.Errorf("tried to set field %s (%s) to a %T value", key, jsonValue.Type.String(), value)
|
||||
}
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64:
|
||||
if jsonValue.Type != gjson.Number {
|
||||
return fmt.Errorf("tried to set field %s (%s) to a %T value", key, jsonValue.Type.String(), value)
|
||||
}
|
||||
case bool:
|
||||
if jsonValue.Type != gjson.True && jsonValue.Type != gjson.False {
|
||||
return fmt.Errorf("tried to set field %s (%s) to a %T value", key, jsonValue.Type.String(), value)
|
||||
}
|
||||
case []string:
|
||||
if !jsonValue.IsArray() {
|
||||
return fmt.Errorf("tried to set field %s (%s) to a %T value", key, jsonValue.Type.String(), value)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get returns the value found by the given json key and whether it could be successfully extracted.
|
||||
func (ja *JSONAccessor) Get(key string) (value interface{}, ok bool) {
|
||||
result := gjson.Get(*ja.json, key)
|
||||
if !result.Exists() {
|
||||
return nil, false
|
||||
}
|
||||
return result.Value(), true
|
||||
}
|
||||
|
||||
// GetString returns the string found by the given json key and whether it could be successfully extracted.
|
||||
func (ja *JSONAccessor) GetString(key string) (value string, ok bool) {
|
||||
result := gjson.Get(*ja.json, key)
|
||||
if !result.Exists() || result.Type != gjson.String {
|
||||
return emptyString, false
|
||||
}
|
||||
return result.String(), true
|
||||
}
|
||||
|
||||
// GetStringArray returns the []string found by the given json key and whether it could be successfully extracted.
|
||||
func (ja *JSONAccessor) GetStringArray(key string) (value []string, ok bool) {
|
||||
result := gjson.Get(*ja.json, key)
|
||||
if !result.Exists() && !result.IsArray() {
|
||||
return nil, false
|
||||
}
|
||||
slice := result.Array()
|
||||
sliceCopy := make([]string, len(slice))
|
||||
for i, res := range slice {
|
||||
if res.Type == gjson.String {
|
||||
sliceCopy[i] = res.String()
|
||||
} else {
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
return sliceCopy, true
|
||||
}
|
||||
|
||||
// GetInt returns the int found by the given json key and whether it could be successfully extracted.
|
||||
func (ja *JSONAccessor) GetInt(key string) (value int64, ok bool) {
|
||||
result := gjson.Get(*ja.json, key)
|
||||
if !result.Exists() || result.Type != gjson.Number {
|
||||
return 0, false
|
||||
}
|
||||
return result.Int(), true
|
||||
}
|
||||
|
||||
// GetFloat returns the float found by the given json key and whether it could be successfully extracted.
|
||||
func (ja *JSONAccessor) GetFloat(key string) (value float64, ok bool) {
|
||||
result := gjson.Get(*ja.json, key)
|
||||
if !result.Exists() || result.Type != gjson.Number {
|
||||
return 0, false
|
||||
}
|
||||
return result.Float(), true
|
||||
}
|
||||
|
||||
// GetBool returns the bool found by the given json key and whether it could be successfully extracted.
|
||||
func (ja *JSONAccessor) GetBool(key string) (value bool, ok bool) {
|
||||
result := gjson.Get(*ja.json, key)
|
||||
switch {
|
||||
case !result.Exists():
|
||||
return false, false
|
||||
case result.Type == gjson.True:
|
||||
return true, true
|
||||
case result.Type == gjson.False:
|
||||
return false, true
|
||||
default:
|
||||
return false, false
|
||||
}
|
||||
}
|
||||
|
||||
// Exists returns whether the given key exists.
|
||||
func (ja *JSONAccessor) Exists(key string) bool {
|
||||
result := gjson.Get(*ja.json, key)
|
||||
return result.Exists()
|
||||
}
|
||||
|
||||
// Type returns the accessor type as a string.
|
||||
func (ja *JSONAccessor) Type() string {
|
||||
return "JSONAccessor"
|
||||
}
|
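A brief sketch of the JSON string accessor in isolation (the JSON document is illustrative; the import path follows the repository layout):

package main

import (
    "fmt"

    "github.com/safing/portmaster/base/database/accessor"
)

func main() {
    doc := `{"name":"fox","legs":4}`
    acc := accessor.NewJSONAccessor(&doc)

    if name, ok := acc.GetString("name"); ok {
        fmt.Println("name:", name)
    }

    // Set refuses type changes on existing fields and rewrites the JSON string in place.
    if err := acc.Set("legs", 3); err != nil {
        fmt.Println("set failed:", err)
    }
    fmt.Println(doc) // {"name":"fox","legs":3}
}
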
169
base/database/accessor/accessor-struct.go
Normal file
|
@ -0,0 +1,169 @@
|
|||
package accessor
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// StructAccessor provides get and set functions for a struct.
|
||||
type StructAccessor struct {
|
||||
object reflect.Value
|
||||
}
|
||||
|
||||
// NewStructAccessor adds the Accessor interface to a struct.
|
||||
func NewStructAccessor(object interface{}) *StructAccessor {
|
||||
return &StructAccessor{
|
||||
object: reflect.ValueOf(object).Elem(),
|
||||
}
|
||||
}
|
||||
|
||||
// Set sets the value identified by key.
|
||||
func (sa *StructAccessor) Set(key string, value interface{}) error {
|
||||
field := sa.object.FieldByName(key)
|
||||
if !field.IsValid() {
|
||||
return errors.New("struct field does not exist")
|
||||
}
|
||||
if !field.CanSet() {
|
||||
return fmt.Errorf("field %s or struct is immutable", field.String())
|
||||
}
|
||||
|
||||
newVal := reflect.ValueOf(value)
|
||||
|
||||
// set directly if type matches
|
||||
if newVal.Kind() == field.Kind() {
|
||||
field.Set(newVal)
|
||||
return nil
|
||||
}
|
||||
|
||||
// handle special cases
|
||||
switch field.Kind() { // nolint:exhaustive
|
||||
|
||||
// ints
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
var newInt int64
|
||||
switch newVal.Kind() { // nolint:exhaustive
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
newInt = newVal.Int()
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
newInt = int64(newVal.Uint())
|
||||
default:
|
||||
return fmt.Errorf("tried to set field %s (%s) to a %s value", key, field.Kind().String(), newVal.Kind().String())
|
||||
}
|
||||
if field.OverflowInt(newInt) {
|
||||
return fmt.Errorf("setting field %s (%s) to %d would overflow", key, field.Kind().String(), newInt)
|
||||
}
|
||||
field.SetInt(newInt)
|
||||
|
||||
// uints
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
var newUint uint64
|
||||
switch newVal.Kind() { // nolint:exhaustive
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
newUint = uint64(newVal.Int())
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
newUint = newVal.Uint()
|
||||
default:
|
||||
return fmt.Errorf("tried to set field %s (%s) to a %s value", key, field.Kind().String(), newVal.Kind().String())
|
||||
}
|
||||
if field.OverflowUint(newUint) {
|
||||
return fmt.Errorf("setting field %s (%s) to %d would overflow", key, field.Kind().String(), newUint)
|
||||
}
|
||||
field.SetUint(newUint)
|
||||
|
||||
// floats
|
||||
case reflect.Float32, reflect.Float64:
|
||||
switch newVal.Kind() { // nolint:exhaustive
|
||||
case reflect.Float32, reflect.Float64:
|
||||
field.SetFloat(newVal.Float())
|
||||
default:
|
||||
return fmt.Errorf("tried to set field %s (%s) to a %s value", key, field.Kind().String(), newVal.Kind().String())
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("tried to set field %s (%s) to a %s value", key, field.Kind().String(), newVal.Kind().String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get returns the value found by the given json key and whether it could be successfully extracted.
|
||||
func (sa *StructAccessor) Get(key string) (value interface{}, ok bool) {
|
||||
field := sa.object.FieldByName(key)
|
||||
if !field.IsValid() || !field.CanInterface() {
|
||||
return nil, false
|
||||
}
|
||||
return field.Interface(), true
|
||||
}
|
||||
|
||||
// GetString returns the string found by the given json key and whether it could be successfully extracted.
|
||||
func (sa *StructAccessor) GetString(key string) (value string, ok bool) {
|
||||
field := sa.object.FieldByName(key)
|
||||
if !field.IsValid() || field.Kind() != reflect.String {
|
||||
return "", false
|
||||
}
|
||||
return field.String(), true
|
||||
}
|
||||
|
||||
// GetStringArray returns the []string found by the given json key and whether it could be successfully extracted.
|
||||
func (sa *StructAccessor) GetStringArray(key string) (value []string, ok bool) {
|
||||
field := sa.object.FieldByName(key)
|
||||
if !field.IsValid() || field.Kind() != reflect.Slice || !field.CanInterface() {
|
||||
return nil, false
|
||||
}
|
||||
v := field.Interface()
|
||||
slice, ok := v.([]string)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return slice, true
|
||||
}
|
||||
|
||||
// GetInt returns the int found by the given json key and whether it could be successfully extracted.
|
||||
func (sa *StructAccessor) GetInt(key string) (value int64, ok bool) {
|
||||
field := sa.object.FieldByName(key)
|
||||
if !field.IsValid() {
|
||||
return 0, false
|
||||
}
|
||||
switch field.Kind() { // nolint:exhaustive
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return field.Int(), true
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
return int64(field.Uint()), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// GetFloat returns the float found by the given json key and whether it could be successfully extracted.
|
||||
func (sa *StructAccessor) GetFloat(key string) (value float64, ok bool) {
|
||||
field := sa.object.FieldByName(key)
|
||||
if !field.IsValid() {
|
||||
return 0, false
|
||||
}
|
||||
switch field.Kind() { // nolint:exhaustive
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return field.Float(), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// GetBool returns the bool found by the given json key and whether it could be successfully extracted.
|
||||
func (sa *StructAccessor) GetBool(key string) (value bool, ok bool) {
|
||||
field := sa.object.FieldByName(key)
|
||||
if !field.IsValid() || field.Kind() != reflect.Bool {
|
||||
return false, false
|
||||
}
|
||||
return field.Bool(), true
|
||||
}
|
||||
|
||||
// Exists returns whether the given key exists.
|
||||
func (sa *StructAccessor) Exists(key string) bool {
|
||||
field := sa.object.FieldByName(key)
|
||||
return field.IsValid()
|
||||
}
|
||||
|
||||
// Type returns the accessor type as a string.
|
||||
func (sa *StructAccessor) Type() string {
|
||||
return "StructAccessor"
|
||||
}
|
18
base/database/accessor/accessor.go
Normal file
|
@ -0,0 +1,18 @@
|
|||
package accessor
|
||||
|
||||
const (
|
||||
emptyString = ""
|
||||
)
|
||||
|
||||
// Accessor provides an interface for the query matcher to retrieve values from an object.
|
||||
type Accessor interface {
|
||||
Get(key string) (value interface{}, ok bool)
|
||||
GetString(key string) (value string, ok bool)
|
||||
GetStringArray(key string) (value []string, ok bool)
|
||||
GetInt(key string) (value int64, ok bool)
|
||||
GetFloat(key string) (value float64, ok bool)
|
||||
GetBool(key string) (value bool, ok bool)
|
||||
Exists(key string) bool
|
||||
Set(key string, value interface{}) error
|
||||
Type() string
|
||||
}
|
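A short sketch showing the struct-backed implementation used through the Accessor interface above (the Profile type is illustrative):

package main

import (
    "fmt"

    "github.com/safing/portmaster/base/database/accessor"
)

// Profile is an illustrative record type; only exported fields are accessible.
type Profile struct {
    Name  string
    Score int64
}

func main() {
    p := &Profile{Name: "fox", Score: 7}

    // The accessor must be given a pointer so fields stay settable.
    var acc accessor.Accessor = accessor.NewStructAccessor(p)

    if err := acc.Set("Score", 9); err != nil {
        fmt.Println("set failed:", err)
    }
    score, _ := acc.GetInt("Score")
    fmt.Println(acc.Type(), score) // StructAccessor 9
}
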
291
base/database/accessor/accessor_test.go
Normal file
|
@ -0,0 +1,291 @@
|
|||
//nolint:maligned,unparam
|
||||
package accessor
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/safing/portmaster/base/utils"
|
||||
)
|
||||
|
||||
type TestStruct struct {
|
||||
S string
|
||||
A []string
|
||||
I int
|
||||
I8 int8
|
||||
I16 int16
|
||||
I32 int32
|
||||
I64 int64
|
||||
UI uint
|
||||
UI8 uint8
|
||||
UI16 uint16
|
||||
UI32 uint32
|
||||
UI64 uint64
|
||||
F32 float32
|
||||
F64 float64
|
||||
B bool
|
||||
}
|
||||
|
||||
var (
|
||||
testStruct = &TestStruct{
|
||||
S: "banana",
|
||||
A: []string{"black", "white"},
|
||||
I: 42,
|
||||
I8: 42,
|
||||
I16: 42,
|
||||
I32: 42,
|
||||
I64: 42,
|
||||
UI: 42,
|
||||
UI8: 42,
|
||||
UI16: 42,
|
||||
UI32: 42,
|
||||
UI64: 42,
|
||||
F32: 42.42,
|
||||
F64: 42.42,
|
||||
B: true,
|
||||
}
|
||||
testJSONBytes, _ = json.Marshal(testStruct) //nolint:errchkjson
|
||||
testJSON = string(testJSONBytes)
|
||||
)
|
||||
|
||||
func testGetString(t *testing.T, acc Accessor, key string, shouldSucceed bool, expectedValue string) {
|
||||
t.Helper()
|
||||
|
||||
v, ok := acc.GetString(key)
|
||||
switch {
|
||||
case !ok && shouldSucceed:
|
||||
t.Errorf("%s failed to get string with key %s", acc.Type(), key)
|
||||
case ok && !shouldSucceed:
|
||||
t.Errorf("%s should have failed to get string with key %s, it returned %v", acc.Type(), key, v)
|
||||
}
|
||||
if v != expectedValue {
|
||||
t.Errorf("%s returned an unexpected value: wanted %v, got %v", acc.Type(), expectedValue, v)
|
||||
}
|
||||
}
|
||||
|
||||
func testGetStringArray(t *testing.T, acc Accessor, key string, shouldSucceed bool, expectedValue []string) {
|
||||
t.Helper()
|
||||
|
||||
v, ok := acc.GetStringArray(key)
|
||||
switch {
|
||||
case !ok && shouldSucceed:
|
||||
t.Errorf("%s failed to get []string with key %s", acc.Type(), key)
|
||||
case ok && !shouldSucceed:
|
||||
t.Errorf("%s should have failed to get []string with key %s, it returned %v", acc.Type(), key, v)
|
||||
}
|
||||
if !utils.StringSliceEqual(v, expectedValue) {
|
||||
t.Errorf("%s returned an unexpected value: wanted %v, got %v", acc.Type(), expectedValue, v)
|
||||
}
|
||||
}
|
||||
|
||||
func testGetInt(t *testing.T, acc Accessor, key string, shouldSucceed bool, expectedValue int64) {
|
||||
t.Helper()
|
||||
|
||||
v, ok := acc.GetInt(key)
|
||||
switch {
|
||||
case !ok && shouldSucceed:
|
||||
t.Errorf("%s failed to get int with key %s", acc.Type(), key)
|
||||
case ok && !shouldSucceed:
|
||||
t.Errorf("%s should have failed to get int with key %s, it returned %v", acc.Type(), key, v)
|
||||
}
|
||||
if v != expectedValue {
|
||||
t.Errorf("%s returned an unexpected value: wanted %v, got %v", acc.Type(), expectedValue, v)
|
||||
}
|
||||
}
|
||||
|
||||
func testGetFloat(t *testing.T, acc Accessor, key string, shouldSucceed bool, expectedValue float64) {
|
||||
t.Helper()
|
||||
|
||||
v, ok := acc.GetFloat(key)
|
||||
switch {
|
||||
case !ok && shouldSucceed:
|
||||
t.Errorf("%s failed to get float with key %s", acc.Type(), key)
|
||||
case ok && !shouldSucceed:
|
||||
t.Errorf("%s should have failed to get float with key %s, it returned %v", acc.Type(), key, v)
|
||||
}
|
||||
if int64(v) != int64(expectedValue) {
|
||||
t.Errorf("%s returned an unexpected value: wanted %v, got %v", acc.Type(), expectedValue, v)
|
||||
}
|
||||
}
|
||||
|
||||
func testGetBool(t *testing.T, acc Accessor, key string, shouldSucceed bool, expectedValue bool) {
|
||||
t.Helper()
|
||||
|
||||
v, ok := acc.GetBool(key)
|
||||
switch {
|
||||
case !ok && shouldSucceed:
|
||||
t.Errorf("%s failed to get bool with key %s", acc.Type(), key)
|
||||
case ok && !shouldSucceed:
|
||||
t.Errorf("%s should have failed to get bool with key %s, it returned %v", acc.Type(), key, v)
|
||||
}
|
||||
if v != expectedValue {
|
||||
t.Errorf("%s returned an unexpected value: wanted %v, got %v", acc.Type(), expectedValue, v)
|
||||
}
|
||||
}
|
||||
|
||||
func testExists(t *testing.T, acc Accessor, key string, shouldSucceed bool) {
|
||||
t.Helper()
|
||||
|
||||
ok := acc.Exists(key)
|
||||
switch {
|
||||
case !ok && shouldSucceed:
|
||||
t.Errorf("%s should report key %s as existing", acc.Type(), key)
|
||||
case ok && !shouldSucceed:
|
||||
t.Errorf("%s should report key %s as non-existing", acc.Type(), key)
|
||||
}
|
||||
}
|
||||
|
||||
func testSet(t *testing.T, acc Accessor, key string, shouldSucceed bool, valueToSet interface{}) {
|
||||
t.Helper()
|
||||
|
||||
err := acc.Set(key, valueToSet)
|
||||
switch {
|
||||
case err != nil && shouldSucceed:
|
||||
t.Errorf("%s failed to set %s to %+v: %s", acc.Type(), key, valueToSet, err)
|
||||
case err == nil && !shouldSucceed:
|
||||
t.Errorf("%s should have failed to set %s to %+v", acc.Type(), key, valueToSet)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessor(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Test interface compliance.
|
||||
accs := []Accessor{
|
||||
NewJSONAccessor(&testJSON),
|
||||
NewJSONBytesAccessor(&testJSONBytes),
|
||||
NewStructAccessor(testStruct),
|
||||
}
|
||||
|
||||
// get
|
||||
for _, acc := range accs {
|
||||
testGetString(t, acc, "S", true, "banana")
|
||||
testGetStringArray(t, acc, "A", true, []string{"black", "white"})
|
||||
testGetInt(t, acc, "I", true, 42)
|
||||
testGetInt(t, acc, "I8", true, 42)
|
||||
testGetInt(t, acc, "I16", true, 42)
|
||||
testGetInt(t, acc, "I32", true, 42)
|
||||
testGetInt(t, acc, "I64", true, 42)
|
||||
testGetInt(t, acc, "UI", true, 42)
|
||||
testGetInt(t, acc, "UI8", true, 42)
|
||||
testGetInt(t, acc, "UI16", true, 42)
|
||||
testGetInt(t, acc, "UI32", true, 42)
|
||||
testGetInt(t, acc, "UI64", true, 42)
|
||||
testGetFloat(t, acc, "F32", true, 42.42)
|
||||
testGetFloat(t, acc, "F64", true, 42.42)
|
||||
testGetBool(t, acc, "B", true, true)
|
||||
}
|
||||
|
||||
// set
|
||||
for _, acc := range accs {
|
||||
testSet(t, acc, "S", true, "coconut")
|
||||
testSet(t, acc, "A", true, []string{"green", "blue"})
|
||||
testSet(t, acc, "I", true, uint32(44))
|
||||
testSet(t, acc, "I8", true, uint64(44))
|
||||
testSet(t, acc, "I16", true, uint8(44))
|
||||
testSet(t, acc, "I32", true, uint16(44))
|
||||
testSet(t, acc, "I64", true, 44)
|
||||
testSet(t, acc, "UI", true, 44)
|
||||
testSet(t, acc, "UI8", true, int64(44))
|
||||
testSet(t, acc, "UI16", true, int32(44))
|
||||
testSet(t, acc, "UI32", true, int8(44))
|
||||
testSet(t, acc, "UI64", true, int16(44))
|
||||
testSet(t, acc, "F32", true, 44.44)
|
||||
testSet(t, acc, "F64", true, 44.44)
|
||||
testSet(t, acc, "B", true, false)
|
||||
}
|
||||
|
||||
// get again to check if new values were set
|
||||
for _, acc := range accs {
|
||||
testGetString(t, acc, "S", true, "coconut")
|
||||
testGetStringArray(t, acc, "A", true, []string{"green", "blue"})
|
||||
testGetInt(t, acc, "I", true, 44)
|
||||
testGetInt(t, acc, "I8", true, 44)
|
||||
testGetInt(t, acc, "I16", true, 44)
|
||||
testGetInt(t, acc, "I32", true, 44)
|
||||
testGetInt(t, acc, "I64", true, 44)
|
||||
testGetInt(t, acc, "UI", true, 44)
|
||||
testGetInt(t, acc, "UI8", true, 44)
|
||||
testGetInt(t, acc, "UI16", true, 44)
|
||||
testGetInt(t, acc, "UI32", true, 44)
|
||||
testGetInt(t, acc, "UI64", true, 44)
|
||||
testGetFloat(t, acc, "F32", true, 44.44)
|
||||
testGetFloat(t, acc, "F64", true, 44.44)
|
||||
testGetBool(t, acc, "B", true, false)
|
||||
}
|
||||
|
||||
// failures
|
||||
for _, acc := range accs {
|
||||
testSet(t, acc, "S", false, true)
|
||||
testSet(t, acc, "S", false, false)
|
||||
testSet(t, acc, "S", false, 1)
|
||||
testSet(t, acc, "S", false, 1.1)
|
||||
|
||||
testSet(t, acc, "A", false, "1")
|
||||
testSet(t, acc, "A", false, true)
|
||||
testSet(t, acc, "A", false, false)
|
||||
testSet(t, acc, "A", false, 1)
|
||||
testSet(t, acc, "A", false, 1.1)
|
||||
|
||||
testSet(t, acc, "I", false, "1")
|
||||
testSet(t, acc, "I8", false, "1")
|
||||
testSet(t, acc, "I16", false, "1")
|
||||
testSet(t, acc, "I32", false, "1")
|
||||
testSet(t, acc, "I64", false, "1")
|
||||
testSet(t, acc, "UI", false, "1")
|
||||
testSet(t, acc, "UI8", false, "1")
|
||||
testSet(t, acc, "UI16", false, "1")
|
||||
testSet(t, acc, "UI32", false, "1")
|
||||
testSet(t, acc, "UI64", false, "1")
|
||||
|
||||
testSet(t, acc, "F32", false, "1.1")
|
||||
testSet(t, acc, "F64", false, "1.1")
|
||||
|
||||
testSet(t, acc, "B", false, "false")
|
||||
testSet(t, acc, "B", false, 1)
|
||||
testSet(t, acc, "B", false, 1.1)
|
||||
}
|
||||
|
||||
// get again to check if values werent changed when an error occurred
|
||||
for _, acc := range accs {
|
||||
testGetString(t, acc, "S", true, "coconut")
|
||||
testGetStringArray(t, acc, "A", true, []string{"green", "blue"})
|
||||
testGetInt(t, acc, "I", true, 44)
|
||||
testGetInt(t, acc, "I8", true, 44)
|
||||
testGetInt(t, acc, "I16", true, 44)
|
||||
testGetInt(t, acc, "I32", true, 44)
|
||||
testGetInt(t, acc, "I64", true, 44)
|
||||
testGetInt(t, acc, "UI", true, 44)
|
||||
testGetInt(t, acc, "UI8", true, 44)
|
||||
testGetInt(t, acc, "UI16", true, 44)
|
||||
testGetInt(t, acc, "UI32", true, 44)
|
||||
testGetInt(t, acc, "UI64", true, 44)
|
||||
testGetFloat(t, acc, "F32", true, 44.44)
|
||||
testGetFloat(t, acc, "F64", true, 44.44)
|
||||
testGetBool(t, acc, "B", true, false)
|
||||
}
|
||||
|
||||
// test existence
|
||||
for _, acc := range accs {
|
||||
testExists(t, acc, "S", true)
|
||||
testExists(t, acc, "A", true)
|
||||
testExists(t, acc, "I", true)
|
||||
testExists(t, acc, "I8", true)
|
||||
testExists(t, acc, "I16", true)
|
||||
testExists(t, acc, "I32", true)
|
||||
testExists(t, acc, "I64", true)
|
||||
testExists(t, acc, "UI", true)
|
||||
testExists(t, acc, "UI8", true)
|
||||
testExists(t, acc, "UI16", true)
|
||||
testExists(t, acc, "UI32", true)
|
||||
testExists(t, acc, "UI64", true)
|
||||
testExists(t, acc, "F32", true)
|
||||
testExists(t, acc, "F64", true)
|
||||
testExists(t, acc, "B", true)
|
||||
}
|
||||
|
||||
// test non-existence
|
||||
for _, acc := range accs {
|
||||
testExists(t, acc, "X", false)
|
||||
}
|
||||
}
|
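The tests above exercise the accessor API end to end. As a quick orientation, here is a minimal sketch of the same API used directly (same package as the tests, assuming an fmt import); the payload and key names are illustrative and mirror the test fixture above.

    // Sketch: reading and updating values through a JSON accessor,
    // using the methods exercised by the tests above.
    payload := `{"S":"banana","I":42,"B":true}`
    acc := NewJSONAccessor(&payload)

    if v, ok := acc.GetString("S"); ok {
        fmt.Println("S =", v) // S = banana
    }
    if err := acc.Set("I", 44); err != nil {
        fmt.Println("set failed:", err)
    }
    fmt.Println(acc.Exists("X")) // false
    fmt.Println(acc.Type())      // accessor type name, as used in the test error messages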
65  base/database/boilerplate_test.go  Normal file
@@ -0,0 +1,65 @@
package database

import (
    "fmt"
    "sync"

    "github.com/safing/portmaster/base/database/record"
)

type Example struct {
    record.Base
    sync.Mutex

    Name  string
    Score int
}

var exampleDB = NewInterface(&Options{
    Internal: true,
    Local:    true,
})

// GetExample gets an Example from the database.
func GetExample(key string) (*Example, error) {
    r, err := exampleDB.Get(key)
    if err != nil {
        return nil, err
    }

    // unwrap
    if r.IsWrapped() {
        // only allocate a new struct, if we need it
        newExample := &Example{}
        err = record.Unwrap(r, newExample)
        if err != nil {
            return nil, err
        }
        return newExample, nil
    }

    // or adjust type
    newExample, ok := r.(*Example)
    if !ok {
        return nil, fmt.Errorf("record not of type *Example, but %T", r)
    }
    return newExample, nil
}

func (e *Example) Save() error {
    return exampleDB.Put(e)
}

func (e *Example) SaveAs(key string) error {
    e.SetKey(key)
    return exampleDB.PutNew(e)
}

func NewExample(key, name string, score int) *Example {
    newExample := &Example{
        Name:  name,
        Score: score,
    }
    newExample.SetKey(key)
    return newExample
}
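A short sketch of this boilerplate in use (same package; the database name and key follow the "<database>:<key>" convention used in database_test.go below, and error handling is trimmed to panics):

    // Sketch: create, save and load an Example record.
    a := NewExample("testing:A", "Herbert", 411)
    if err := a.Save(); err != nil {
        panic(err)
    }

    loaded, err := GetExample("testing:A")
    if err != nil {
        panic(err)
    }
    fmt.Println(loaded.Name, loaded.Score) // Herbert 411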
355  base/database/controller.go  Normal file
@@ -0,0 +1,355 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/base/database/iterator"
|
||||
"github.com/safing/portmaster/base/database/query"
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
"github.com/safing/portmaster/base/database/storage"
|
||||
)
|
||||
|
||||
// A Controller takes care of all the extra database logic.
|
||||
type Controller struct {
|
||||
database *Database
|
||||
storage storage.Interface
|
||||
shadowDelete bool
|
||||
|
||||
hooksLock sync.RWMutex
|
||||
hooks []*RegisteredHook
|
||||
|
||||
subscriptionLock sync.RWMutex
|
||||
subscriptions []*Subscription
|
||||
}
|
||||
|
||||
// newController creates a new controller for a storage.
|
||||
func newController(database *Database, storageInt storage.Interface, shadowDelete bool) *Controller {
|
||||
return &Controller{
|
||||
database: database,
|
||||
storage: storageInt,
|
||||
shadowDelete: shadowDelete,
|
||||
}
|
||||
}
|
||||
|
||||
// ReadOnly returns whether the storage is read only.
|
||||
func (c *Controller) ReadOnly() bool {
|
||||
return c.storage.ReadOnly()
|
||||
}
|
||||
|
||||
// Injected returns whether the storage is injected.
|
||||
func (c *Controller) Injected() bool {
|
||||
return c.storage.Injected()
|
||||
}
|
||||
|
||||
// Get returns the record with the given key.
|
||||
func (c *Controller) Get(key string) (record.Record, error) {
|
||||
if shuttingDown.IsSet() {
|
||||
return nil, ErrShuttingDown
|
||||
}
|
||||
|
||||
if err := c.runPreGetHooks(key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r, err := c.storage.Get(key)
|
||||
if err != nil {
|
||||
// replace not found error
|
||||
if errors.Is(err, storage.ErrNotFound) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
r, err = c.runPostGetHooks(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !r.Meta().CheckValidity() {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// GetMeta returns the metadata of the record with the given key.
|
||||
func (c *Controller) GetMeta(key string) (*record.Meta, error) {
|
||||
if shuttingDown.IsSet() {
|
||||
return nil, ErrShuttingDown
|
||||
}
|
||||
|
||||
var m *record.Meta
|
||||
var err error
|
||||
if metaDB, ok := c.storage.(storage.MetaHandler); ok {
|
||||
m, err = metaDB.GetMeta(key)
|
||||
if err != nil {
|
||||
// replace not found error
|
||||
if errors.Is(err, storage.ErrNotFound) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
r, err := c.storage.Get(key)
|
||||
if err != nil {
|
||||
// replace not found error
|
||||
if errors.Is(err, storage.ErrNotFound) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
m = r.Meta()
|
||||
}
|
||||
|
||||
if !m.CheckValidity() {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Put saves a record in the database, executes any registered
|
||||
// pre-put hooks and finally send an update to all subscribers.
|
||||
// The record must be locked and secured from concurrent access
|
||||
// when calling Put().
|
||||
func (c *Controller) Put(r record.Record) (err error) {
|
||||
if shuttingDown.IsSet() {
|
||||
return ErrShuttingDown
|
||||
}
|
||||
|
||||
if c.ReadOnly() {
|
||||
return ErrReadOnly
|
||||
}
|
||||
|
||||
r, err = c.runPrePutHooks(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !c.shadowDelete && r.Meta().IsDeleted() {
|
||||
// Immediate delete.
|
||||
err = c.storage.Delete(r.DatabaseKey())
|
||||
} else {
|
||||
// Put or shadow delete.
|
||||
r, err = c.storage.Put(r)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if r == nil {
|
||||
return errors.New("storage returned nil record after successful put operation")
|
||||
}
|
||||
|
||||
c.notifySubscribers(r)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PutMany stores many records in the database. It does not
|
||||
// process any hooks or update subscriptions. Use with care!
|
||||
func (c *Controller) PutMany() (chan<- record.Record, <-chan error) {
|
||||
if shuttingDown.IsSet() {
|
||||
errs := make(chan error, 1)
|
||||
errs <- ErrShuttingDown
|
||||
return make(chan record.Record), errs
|
||||
}
|
||||
|
||||
if c.ReadOnly() {
|
||||
errs := make(chan error, 1)
|
||||
errs <- ErrReadOnly
|
||||
return make(chan record.Record), errs
|
||||
}
|
||||
|
||||
if batcher, ok := c.storage.(storage.Batcher); ok {
|
||||
return batcher.PutMany(c.shadowDelete)
|
||||
}
|
||||
|
||||
errs := make(chan error, 1)
|
||||
errs <- ErrNotImplemented
|
||||
return make(chan record.Record), errs
|
||||
}
|
||||
|
||||
// Query executes the given query on the database.
|
||||
func (c *Controller) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) {
|
||||
if shuttingDown.IsSet() {
|
||||
return nil, ErrShuttingDown
|
||||
}
|
||||
|
||||
it, err := c.storage.Query(q, local, internal)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return it, nil
|
||||
}
|
||||
|
||||
// PushUpdate pushes a record update to subscribers.
|
||||
// The caller must hold the record's lock when calling
|
||||
// PushUpdate.
|
||||
func (c *Controller) PushUpdate(r record.Record) {
|
||||
if c != nil {
|
||||
if shuttingDown.IsSet() {
|
||||
return
|
||||
}
|
||||
|
||||
c.notifySubscribers(r)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Controller) addSubscription(sub *Subscription) {
|
||||
if shuttingDown.IsSet() {
|
||||
return
|
||||
}
|
||||
|
||||
c.subscriptionLock.Lock()
|
||||
defer c.subscriptionLock.Unlock()
|
||||
|
||||
c.subscriptions = append(c.subscriptions, sub)
|
||||
}
|
||||
|
||||
// Maintain runs the Maintain method on the storage.
|
||||
func (c *Controller) Maintain(ctx context.Context) error {
|
||||
if shuttingDown.IsSet() {
|
||||
return ErrShuttingDown
|
||||
}
|
||||
|
||||
if maintainer, ok := c.storage.(storage.Maintainer); ok {
|
||||
return maintainer.Maintain(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MaintainThorough runs the MaintainThorough method on the
|
||||
// storage.
|
||||
func (c *Controller) MaintainThorough(ctx context.Context) error {
|
||||
if shuttingDown.IsSet() {
|
||||
return ErrShuttingDown
|
||||
}
|
||||
|
||||
if maintainer, ok := c.storage.(storage.Maintainer); ok {
|
||||
return maintainer.MaintainThorough(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MaintainRecordStates runs the record state lifecycle
|
||||
// maintenance on the storage.
|
||||
func (c *Controller) MaintainRecordStates(ctx context.Context, purgeDeletedBefore time.Time) error {
|
||||
if shuttingDown.IsSet() {
|
||||
return ErrShuttingDown
|
||||
}
|
||||
|
||||
return c.storage.MaintainRecordStates(ctx, purgeDeletedBefore, c.shadowDelete)
|
||||
}
|
||||
|
||||
// Purge deletes all records that match the given query.
|
||||
// It returns the number of successful deletes and an error.
|
||||
func (c *Controller) Purge(ctx context.Context, q *query.Query, local, internal bool) (int, error) {
|
||||
if shuttingDown.IsSet() {
|
||||
return 0, ErrShuttingDown
|
||||
}
|
||||
|
||||
if purger, ok := c.storage.(storage.Purger); ok {
|
||||
return purger.Purge(ctx, q, local, internal, c.shadowDelete)
|
||||
}
|
||||
|
||||
return 0, ErrNotImplemented
|
||||
}
|
||||
|
||||
// Shutdown shuts down the storage.
|
||||
func (c *Controller) Shutdown() error {
|
||||
return c.storage.Shutdown()
|
||||
}
|
||||
|
||||
// notifySubscribers notifies all subscribers that are interested
|
||||
// in r. r must be locked when calling notifySubscribers.
|
||||
// Any subscriber that is not blocking on its feed channel will
|
||||
// be skipped.
|
||||
func (c *Controller) notifySubscribers(r record.Record) {
|
||||
c.subscriptionLock.RLock()
|
||||
defer c.subscriptionLock.RUnlock()
|
||||
|
||||
for _, sub := range c.subscriptions {
|
||||
if r.Meta().CheckPermission(sub.local, sub.internal) && sub.q.Matches(r) {
|
||||
select {
|
||||
case sub.Feed <- r:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Controller) runPreGetHooks(key string) error {
|
||||
c.hooksLock.RLock()
|
||||
defer c.hooksLock.RUnlock()
|
||||
|
||||
for _, hook := range c.hooks {
|
||||
if !hook.h.UsesPreGet() {
|
||||
continue
|
||||
}
|
||||
|
||||
if !hook.q.MatchesKey(key) {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := hook.h.PreGet(key); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) runPostGetHooks(r record.Record) (record.Record, error) {
|
||||
c.hooksLock.RLock()
|
||||
defer c.hooksLock.RUnlock()
|
||||
|
||||
var err error
|
||||
for _, hook := range c.hooks {
|
||||
if !hook.h.UsesPostGet() {
|
||||
continue
|
||||
}
|
||||
|
||||
if !hook.q.Matches(r) {
|
||||
continue
|
||||
}
|
||||
|
||||
r, err = hook.h.PostGet(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (c *Controller) runPrePutHooks(r record.Record) (record.Record, error) {
|
||||
c.hooksLock.RLock()
|
||||
defer c.hooksLock.RUnlock()
|
||||
|
||||
var err error
|
||||
for _, hook := range c.hooks {
|
||||
if !hook.h.UsesPrePut() {
|
||||
continue
|
||||
}
|
||||
|
||||
if !hook.q.Matches(r) {
|
||||
continue
|
||||
}
|
||||
|
||||
r, err = hook.h.PrePut(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
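PutMany hands the caller a write channel and an error channel. A hedged sketch of the intended call pattern, matching how interface.go further down drives it: send records, close the channel to signal completion, then read the error channel. The variables c and records are assumed to exist in the caller, and the underlying storage must implement storage.Batcher.

    // Sketch: batch-writing records through a Controller.
    batch, errs := c.PutMany()
    for _, r := range records {
        batch <- r
    }
    close(batch) // signal that no more records will be sent
    if err := <-errs; err != nil {
        return err
    }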
106  base/database/controllers.go  Normal file
@@ -0,0 +1,106 @@
package database

import (
    "errors"
    "fmt"
    "sync"

    "github.com/safing/portmaster/base/database/storage"
)

// StorageTypeInjected is the type of injected databases.
const StorageTypeInjected = "injected"

var (
    controllers     = make(map[string]*Controller)
    controllersLock sync.RWMutex
)

func getController(name string) (*Controller, error) {
    if !initialized.IsSet() {
        return nil, errors.New("database not initialized")
    }

    // return database if already started
    controllersLock.RLock()
    controller, ok := controllers[name]
    controllersLock.RUnlock()
    if ok {
        return controller, nil
    }

    controllersLock.Lock()
    defer controllersLock.Unlock()

    if shuttingDown.IsSet() {
        return nil, ErrShuttingDown
    }

    // get db registration
    registeredDB, err := getDatabase(name)
    if err != nil {
        return nil, fmt.Errorf("could not start database %s: %w", name, err)
    }

    // Check if database is injected.
    if registeredDB.StorageType == StorageTypeInjected {
        return nil, fmt.Errorf("database storage is not injected")
    }

    // get location
    dbLocation, err := getLocation(name, registeredDB.StorageType)
    if err != nil {
        return nil, fmt.Errorf("could not start database %s (type %s): %w", name, registeredDB.StorageType, err)
    }

    // start database
    storageInt, err := storage.StartDatabase(name, registeredDB.StorageType, dbLocation)
    if err != nil {
        return nil, fmt.Errorf("could not start database %s (type %s): %w", name, registeredDB.StorageType, err)
    }

    controller = newController(registeredDB, storageInt, registeredDB.ShadowDelete)
    controllers[name] = controller
    return controller, nil
}

// InjectDatabase injects an already running database into the system.
func InjectDatabase(name string, storageInt storage.Interface) (*Controller, error) {
    controllersLock.Lock()
    defer controllersLock.Unlock()

    if shuttingDown.IsSet() {
        return nil, ErrShuttingDown
    }

    _, ok := controllers[name]
    if ok {
        return nil, fmt.Errorf(`database "%s" already loaded`, name)
    }

    registryLock.Lock()
    defer registryLock.Unlock()

    // check if database is registered
    registeredDB, ok := registry[name]
    if !ok {
        return nil, fmt.Errorf("database %q not registered", name)
    }
    if registeredDB.StorageType != StorageTypeInjected {
        return nil, fmt.Errorf("database not of type %q", StorageTypeInjected)
    }

    controller := newController(registeredDB, storageInt, false)
    controllers[name] = controller
    return controller, nil
}

// Withdraw withdraws an injected database, but leaves the database registered.
func (c *Controller) Withdraw() {
    if c != nil && c.Injected() {
        controllersLock.Lock()
        defer controllersLock.Unlock()

        delete(controllers, c.database.Name)
    }
}
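Injected databases skip the storage startup path above: the database must already be registered with StorageType set to StorageTypeInjected, and the running storage is then handed over. A minimal sketch, where myStorage is an assumed implementation of storage.Interface and the database name is illustrative:

    // Sketch: injecting an already running storage backend.
    ctrl, err := database.InjectDatabase("runtime", myStorage)
    if err != nil {
        return err
    }
    defer ctrl.Withdraw()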
26  base/database/database.go  Normal file
@@ -0,0 +1,26 @@
package database

import (
    "time"
)

// Database holds information about a registered database.
type Database struct {
    Name         string
    Description  string
    StorageType  string
    ShadowDelete bool // Whether deleted records should be kept until purged.
    Registered   time.Time
    LastUpdated  time.Time
    LastLoaded   time.Time
}

// Loaded updates the LastLoaded timestamp.
func (db *Database) Loaded() {
    db.LastLoaded = time.Now().Round(time.Second)
}

// Updated updates the LastUpdated timestamp.
func (db *Database) Updated() {
    db.LastUpdated = time.Now().Round(time.Second)
}
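For reference, registering such a database mirrors the call used in database_test.go below; the name is illustrative and "hashmap" is one of the storage types imported there.

    // Sketch: registering a database before first use.
    _, err := database.Register(&database.Database{
        Name:         "example",
        Description:  "Example database",
        StorageType:  "hashmap",
        ShadowDelete: false,
    })
    if err != nil {
        return err
    }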
303  base/database/database_test.go  Normal file
@@ -0,0 +1,303 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"reflect"
|
||||
"runtime/pprof"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
q "github.com/safing/portmaster/base/database/query"
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
"github.com/safing/portmaster/base/database/storage"
|
||||
_ "github.com/safing/portmaster/base/database/storage/badger"
|
||||
_ "github.com/safing/portmaster/base/database/storage/bbolt"
|
||||
_ "github.com/safing/portmaster/base/database/storage/fstree"
|
||||
_ "github.com/safing/portmaster/base/database/storage/hashmap"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
testDir, err := os.MkdirTemp("", "portbase-database-testing-")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = InitializeWithPath(testDir)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
exitCode := m.Run()
|
||||
|
||||
// Clean up the test directory.
|
||||
// Do not defer, as we end this function with a os.Exit call.
|
||||
_ = os.RemoveAll(testDir)
|
||||
|
||||
os.Exit(exitCode)
|
||||
}
|
||||
|
||||
func makeKey(dbName, key string) string {
|
||||
return fmt.Sprintf("%s:%s", dbName, key)
|
||||
}
|
||||
|
||||
func testDatabase(t *testing.T, storageType string, shadowDelete bool) { //nolint:maintidx,thelper
|
||||
t.Run(fmt.Sprintf("TestStorage_%s_%v", storageType, shadowDelete), func(t *testing.T) {
|
||||
dbName := fmt.Sprintf("testing-%s-%v", storageType, shadowDelete)
|
||||
fmt.Println(dbName)
|
||||
_, err := Register(&Database{
|
||||
Name: dbName,
|
||||
Description: fmt.Sprintf("Unit Test Database for %s", storageType),
|
||||
StorageType: storageType,
|
||||
ShadowDelete: shadowDelete,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dbController, err := getController(dbName)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// hook
|
||||
hook, err := RegisterHook(q.New(dbName).MustBeValid(), &HookBase{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// interface
|
||||
db := NewInterface(&Options{
|
||||
Local: true,
|
||||
Internal: true,
|
||||
})
|
||||
|
||||
// sub
|
||||
sub, err := db.Subscribe(q.New(dbName).MustBeValid())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
A := NewExample(dbName+":A", "Herbert", 411)
|
||||
err = A.Save()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
B := NewExample(makeKey(dbName, "B"), "Fritz", 347)
|
||||
err = B.Save()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
C := NewExample(makeKey(dbName, "C"), "Norbert", 217)
|
||||
err = C.Save()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
exists, err := db.Exists(makeKey(dbName, "A"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatalf("record %s should exist!", makeKey(dbName, "A"))
|
||||
}
|
||||
|
||||
A1, err := GetExample(makeKey(dbName, "A"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(A, A1) {
|
||||
log.Fatalf("A and A1 mismatch, A1: %v", A1)
|
||||
}
|
||||
|
||||
cnt := countRecords(t, db, q.New(dbName).Where(
|
||||
q.And(
|
||||
q.Where("Name", q.EndsWith, "bert"),
|
||||
q.Where("Score", q.GreaterThan, 100),
|
||||
),
|
||||
))
|
||||
if cnt != 2 {
|
||||
t.Fatalf("expected two records, got %d", cnt)
|
||||
}
|
||||
|
||||
// test putmany
|
||||
if _, ok := dbController.storage.(storage.Batcher); ok {
|
||||
batchPut := db.PutMany(dbName)
|
||||
records := []record.Record{A, B, C, nil} // nil is to signify finish
|
||||
for _, r := range records {
|
||||
err = batchPut(r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// test maintenance
|
||||
if _, ok := dbController.storage.(storage.Maintainer); ok {
|
||||
now := time.Now().UTC()
|
||||
nowUnix := now.Unix()
|
||||
|
||||
// we start with 3 records without expiry
|
||||
cnt := countRecords(t, db, q.New(dbName))
|
||||
if cnt != 3 {
|
||||
t.Fatalf("expected three records, got %d", cnt)
|
||||
}
|
||||
// delete entry
|
||||
A.Meta().Deleted = nowUnix - 61
|
||||
err = A.Save()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// expire entry
|
||||
B.Meta().Expires = nowUnix - 1
|
||||
err = B.Save()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// one left
|
||||
cnt = countRecords(t, db, q.New(dbName))
|
||||
if cnt != 1 {
|
||||
t.Fatalf("expected one record, got %d", cnt)
|
||||
}
|
||||
|
||||
// run maintenance
|
||||
err = dbController.MaintainRecordStates(context.TODO(), now.Add(-60*time.Second))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// one left
|
||||
cnt = countRecords(t, db, q.New(dbName))
|
||||
if cnt != 1 {
|
||||
t.Fatalf("expected one record, got %d", cnt)
|
||||
}
|
||||
|
||||
// check status individually
|
||||
_, err = dbController.storage.Get("A")
|
||||
if !errors.Is(err, storage.ErrNotFound) {
|
||||
t.Errorf("A should be deleted and purged, err=%s", err)
|
||||
}
|
||||
B1, err := dbController.storage.Get("B")
|
||||
if err != nil {
|
||||
t.Fatalf("should exist: %s, original meta: %+v", err, B.Meta())
|
||||
}
|
||||
if B1.Meta().Deleted == 0 {
|
||||
t.Errorf("B should be deleted")
|
||||
}
|
||||
|
||||
// delete last entry
|
||||
C.Meta().Deleted = nowUnix - 1
|
||||
err = C.Save()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// run maintenance
|
||||
err = dbController.MaintainRecordStates(context.TODO(), now)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// check status individually
|
||||
B2, err := dbController.storage.Get("B")
|
||||
if err == nil {
|
||||
t.Errorf("B should be deleted and purged, meta: %+v", B2.Meta())
|
||||
} else if !errors.Is(err, storage.ErrNotFound) {
|
||||
t.Errorf("B should be deleted and purged, err=%s", err)
|
||||
}
|
||||
C2, err := dbController.storage.Get("C")
|
||||
if err == nil {
|
||||
t.Errorf("C should be deleted and purged, meta: %+v", C2.Meta())
|
||||
} else if !errors.Is(err, storage.ErrNotFound) {
|
||||
t.Errorf("C should be deleted and purged, err=%s", err)
|
||||
}
|
||||
|
||||
// none left
|
||||
cnt = countRecords(t, db, q.New(dbName))
|
||||
if cnt != 0 {
|
||||
t.Fatalf("expected no records, got %d", cnt)
|
||||
}
|
||||
}
|
||||
|
||||
err = hook.Cancel()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = sub.Cancel()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDatabaseSystem(t *testing.T) { //nolint:tparallel
|
||||
t.Parallel()
|
||||
|
||||
// panic after 10 seconds, to check for locks
|
||||
finished := make(chan struct{})
|
||||
defer close(finished)
|
||||
go func() {
|
||||
select {
|
||||
case <-finished:
|
||||
case <-time.After(10 * time.Second):
|
||||
fmt.Println("===== TAKING TOO LONG - PRINTING STACK TRACES =====")
|
||||
_ = pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
|
||||
os.Exit(1)
|
||||
}
|
||||
}()
|
||||
|
||||
for _, shadowDelete := range []bool{false, true} {
|
||||
testDatabase(t, "bbolt", shadowDelete)
|
||||
testDatabase(t, "hashmap", shadowDelete)
|
||||
testDatabase(t, "fstree", shadowDelete)
|
||||
// testDatabase(t, "badger", shadowDelete)
|
||||
// TODO: Fix badger tests
|
||||
}
|
||||
|
||||
err := MaintainRecordStates(context.TODO())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = Maintain(context.TODO())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = MaintainThorough(context.TODO())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = Shutdown()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func countRecords(t *testing.T, db *Interface, query *q.Query) int {
|
||||
t.Helper()
|
||||
|
||||
_, err := query.Check()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
it, err := db.Query(query)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cnt := 0
|
||||
for range it.Next {
|
||||
cnt++
|
||||
}
|
||||
if it.Err() != nil {
|
||||
t.Fatal(it.Err())
|
||||
}
|
||||
return cnt
|
||||
}
|
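The query used in the test above (EndsWith on Name, GreaterThan on Score) is the typical read-back pattern. A compact sketch of building and iterating such a query, modeled on countRecords above (same package, with db and t as in that helper; the database name is illustrative):

    // Sketch: query records and iterate the result.
    it, err := db.Query(q.New("testing-hashmap-false").Where(
        q.And(
            q.Where("Name", q.EndsWith, "bert"),
            q.Where("Score", q.GreaterThan, 100),
        ),
    ))
    if err != nil {
        t.Fatal(err)
    }
    for r := range it.Next {
        fmt.Println(r.DatabaseKey())
    }
    if it.Err() != nil {
        t.Fatal(it.Err())
    }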
84  base/database/dbmodule/db.go  Normal file
@@ -0,0 +1,84 @@
package dbmodule

import (
    "errors"
    "sync/atomic"

    "github.com/safing/portmaster/base/database"
    "github.com/safing/portmaster/base/dataroot"
    "github.com/safing/portmaster/base/utils"
    "github.com/safing/portmaster/service/mgr"
)

type DBModule struct {
    mgr      *mgr.Manager
    instance instance
}

func (dbm *DBModule) Manager() *mgr.Manager {
    return dbm.mgr
}

func (dbm *DBModule) Start() error {
    return start()
}

func (dbm *DBModule) Stop() error {
    return stop()
}

var databaseStructureRoot *utils.DirStructure

// SetDatabaseLocation sets the location of the database for initialization. Supply either a path or dir structure.
func SetDatabaseLocation(dirStructureRoot *utils.DirStructure) {
    if databaseStructureRoot == nil {
        databaseStructureRoot = dirStructureRoot
    }
}

func prep() error {
    SetDatabaseLocation(dataroot.Root())
    if databaseStructureRoot == nil {
        return errors.New("database location not specified")
    }

    return nil
}

func start() error {
    startMaintenanceTasks()
    return nil
}

func stop() error {
    return database.Shutdown()
}

var (
    module     *DBModule
    shimLoaded atomic.Bool
)

func New(instance instance) (*DBModule, error) {
    if !shimLoaded.CompareAndSwap(false, true) {
        return nil, errors.New("only one instance allowed")
    }

    if err := prep(); err != nil {
        return nil, err
    }
    m := mgr.New("DBModule")
    module = &DBModule{
        mgr:      m,
        instance: instance,
    }

    err := database.Initialize(databaseStructureRoot)
    if err != nil {
        return nil, err
    }

    return module, nil
}

type instance interface{}
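A hedged sketch of how this module shim might be constructed during startup. Passing nil is enough here only because the instance interface is currently empty, and the data root is assumed to have been initialized beforehand (prep falls back to dataroot.Root(), or SetDatabaseLocation can be called explicitly first):

    // Sketch: constructing the database module.
    dbModule, err := dbmodule.New(nil)
    if err != nil {
        return err
    }
    _ = dbModule.Manager() // handed to the module group elsewhere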
30  base/database/dbmodule/maintenance.go  Normal file
@@ -0,0 +1,30 @@
package dbmodule

import (
    "time"

    "github.com/safing/portmaster/base/database"
    "github.com/safing/portmaster/base/log"
    "github.com/safing/portmaster/service/mgr"
)

func startMaintenanceTasks() {
    _ = module.mgr.Repeat("basic maintenance", 10*time.Minute, maintainBasic)
    _ = module.mgr.Repeat("thorough maintenance", 1*time.Hour, maintainThorough)
    _ = module.mgr.Repeat("record maintenance", 1*time.Hour, maintainRecords)
}

func maintainBasic(ctx *mgr.WorkerCtx) error {
    log.Infof("database: running Maintain")
    return database.Maintain(ctx.Ctx())
}

func maintainThorough(ctx *mgr.WorkerCtx) error {
    log.Infof("database: running MaintainThorough")
    return database.MaintainThorough(ctx.Ctx())
}

func maintainRecords(ctx *mgr.WorkerCtx) error {
    log.Infof("database: running MaintainRecordStates")
    return database.MaintainRecordStates(ctx.Ctx())
}
62  base/database/doc.go  Normal file
@@ -0,0 +1,62 @@
/*
Package database provides a universal interface for interacting with the database.

# A Lazy Database

The database system can handle Go structs as well as serialized data by the dsd package.
While data is in transit within the system, it does not know which form it currently has. Only when it reaches its destination, it must ensure that it is either of a certain type or dump it.

# Record Interface

The database system uses the Record interface to transparently handle all types of structs that get saved in the database. Structs include the Base struct to fulfill most parts of the Record interface.

Boilerplate Code:

    type Example struct {
        record.Base
        sync.Mutex

        Name  string
        Score int
    }

    var (
        db = database.NewInterface(nil)
    )

    // GetExample gets an Example from the database.
    func GetExample(key string) (*Example, error) {
        r, err := db.Get(key)
        if err != nil {
            return nil, err
        }

        // unwrap
        if r.IsWrapped() {
            // only allocate a new struct, if we need it
            new := &Example{}
            err = record.Unwrap(r, new)
            if err != nil {
                return nil, err
            }
            return new, nil
        }

        // or adjust type
        new, ok := r.(*Example)
        if !ok {
            return nil, fmt.Errorf("record not of type *Example, but %T", r)
        }
        return new, nil
    }

    func (e *Example) Save() error {
        return db.Put(e)
    }

    func (e *Example) SaveAs(key string) error {
        e.SetKey(key)
        return db.PutNew(e)
    }
*/
package database
14  base/database/errors.go  Normal file
@@ -0,0 +1,14 @@
package database

import (
    "errors"
)

// Errors.
var (
    ErrNotFound         = errors.New("database entry not found")
    ErrPermissionDenied = errors.New("access to database record denied")
    ErrReadOnly         = errors.New("database is read only")
    ErrShuttingDown     = errors.New("database system is shutting down")
    ErrNotImplemented   = errors.New("not implemented by this storage")
)
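These sentinel errors are meant to be checked with errors.Is, the same way Interface.Exists does further down. A short sketch, where db is an assumed *database.Interface and the key is illustrative:

    // Sketch: distinguishing "not found" from real failures.
    r, err := db.Get("example:some-key")
    switch {
    case err == nil:
        _ = r // use the record
    case errors.Is(err, database.ErrNotFound):
        // treat as absent
    default:
        return err
    }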
91  base/database/hook.go  Normal file
@@ -0,0 +1,91 @@
package database

import (
    "github.com/safing/portmaster/base/database/query"
    "github.com/safing/portmaster/base/database/record"
)

// Hook can be registered for a database query and
// will be executed at certain points during the life
// cycle of a database record.
type Hook interface {
    // UsesPreGet should return true if the hook's PreGet
    // should be called prior to loading a database record
    // from the underlying storage.
    UsesPreGet() bool
    // PreGet is called before a database record is loaded from
    // the underlying storage. A PreGet hook may be used to
    // implement more advanced access control on database keys.
    PreGet(dbKey string) error
    // UsesPostGet should return true if the hook's PostGet
    // should be called after loading a database record from
    // the underlying storage.
    UsesPostGet() bool
    // PostGet is called after a record has been loaded from the
    // underlying storage and may perform additional mutation
    // or access checks based on the record's data.
    // The passed record is already locked by the database system
    // so users can safely access all data of r.
    PostGet(r record.Record) (record.Record, error)
    // UsesPrePut should return true if the hook's PrePut method
    // should be called prior to saving a record in the database.
    UsesPrePut() bool
    // PrePut is called prior to saving (creating or updating) a
    // record in the database storage. It may be used to perform
    // extended validation or mutations on the record.
    // The passed record is already locked by the database system
    // so users can safely access all data of r.
    PrePut(r record.Record) (record.Record, error)
}

// RegisteredHook is a registered database hook.
type RegisteredHook struct {
    q *query.Query
    h Hook
}

// RegisterHook registers a hook for records matching the given
// query in the database.
func RegisterHook(q *query.Query, hook Hook) (*RegisteredHook, error) {
    _, err := q.Check()
    if err != nil {
        return nil, err
    }

    c, err := getController(q.DatabaseName())
    if err != nil {
        return nil, err
    }

    rh := &RegisteredHook{
        q: q,
        h: hook,
    }

    c.hooksLock.Lock()
    defer c.hooksLock.Unlock()
    c.hooks = append(c.hooks, rh)

    return rh, nil
}

// Cancel unregisters the hook from the database. Once
// Cancel has returned, the hook's methods will not be called
// anymore for updates that match the registered query.
func (h *RegisteredHook) Cancel() error {
    c, err := getController(h.q.DatabaseName())
    if err != nil {
        return err
    }

    c.hooksLock.Lock()
    defer c.hooksLock.Unlock()

    for key, hook := range c.hooks {
        if hook.q == h.q {
            c.hooks = append(c.hooks[:key], c.hooks[key+1:]...)
            return nil
        }
    }
    return nil
}
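Registration mirrors the call in database_test.go above: build a query for the records the hook should see, register, and cancel when done. A short sketch, where myHook is an assumed Hook implementation and the database name is illustrative:

    // Sketch: registering a hook for all records of one database.
    hook, err := database.RegisterHook(query.New("example").MustBeValid(), myHook)
    if err != nil {
        return err
    }
    defer func() { _ = hook.Cancel() }()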
38  base/database/hookbase.go  Normal file
@@ -0,0 +1,38 @@
package database

import (
    "github.com/safing/portmaster/base/database/record"
)

// HookBase implements the Hook interface and provides dummy functions to reduce boilerplate.
type HookBase struct{}

// UsesPreGet implements the Hook interface and returns false.
func (b *HookBase) UsesPreGet() bool {
    return false
}

// UsesPostGet implements the Hook interface and returns false.
func (b *HookBase) UsesPostGet() bool {
    return false
}

// UsesPrePut implements the Hook interface and returns false.
func (b *HookBase) UsesPrePut() bool {
    return false
}

// PreGet implements the Hook interface.
func (b *HookBase) PreGet(dbKey string) error {
    return nil
}

// PostGet implements the Hook interface.
func (b *HookBase) PostGet(r record.Record) (record.Record, error) {
    return r, nil
}

// PrePut implements the Hook interface.
func (b *HookBase) PrePut(r record.Record) (record.Record, error) {
    return r, nil
}
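HookBase exists so a concrete hook only has to override the calls it cares about. A minimal sketch of a PrePut-only hook; the type name and its (empty) validation logic are illustrative:

    // Sketch: a hook that only participates in PrePut by embedding HookBase.
    type validatingHook struct {
        database.HookBase
    }

    func (h *validatingHook) UsesPrePut() bool { return true }

    func (h *validatingHook) PrePut(r record.Record) (record.Record, error) {
        // Inspect or mutate r here; it is already locked by the database system.
        return r, nil
    }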
585  base/database/interface.go  Normal file
@@ -0,0 +1,585 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bluele/gcache"
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portmaster/base/database/accessor"
|
||||
"github.com/safing/portmaster/base/database/iterator"
|
||||
"github.com/safing/portmaster/base/database/query"
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
)
|
||||
|
||||
const (
|
||||
getDBFromKey = ""
|
||||
)
|
||||
|
||||
// Interface provides a method to access the database with attached options.
|
||||
type Interface struct {
|
||||
options *Options
|
||||
cache gcache.Cache
|
||||
|
||||
writeCache map[string]record.Record
|
||||
writeCacheLock sync.Mutex
|
||||
triggerCacheWrite chan struct{}
|
||||
}
|
||||
|
||||
// Options holds options that may be set for an Interface instance.
|
||||
type Options struct {
|
||||
// Local specifies if the interface is used by an actor on the local device.
|
||||
// Setting both the Local and Internal flags will bring performance
|
||||
// improvements because less checks are needed.
|
||||
Local bool
|
||||
|
||||
// Internal specifies if the interface is used by an actor within the
|
||||
// software. Setting both the Local and Internal flags will bring performance
|
||||
// improvements because less checks are needed.
|
||||
Internal bool
|
||||
|
||||
// AlwaysMakeSecret will have the interface mark all saved records as secret.
|
||||
// This means that they will be only accessible by an internal interface.
|
||||
AlwaysMakeSecret bool
|
||||
|
||||
// AlwaysMakeCrownjewel will have the interface mark all saved records as
|
||||
// crown jewels. This means that they will be only accessible by a local
|
||||
// interface.
|
||||
AlwaysMakeCrownjewel bool
|
||||
|
||||
// AlwaysSetRelativateExpiry will have the interface set a relative expiry,
|
||||
// based on the current time, on all saved records.
|
||||
AlwaysSetRelativateExpiry int64
|
||||
|
||||
// AlwaysSetAbsoluteExpiry will have the interface set an absolute expiry on
|
||||
// all saved records.
|
||||
AlwaysSetAbsoluteExpiry int64
|
||||
|
||||
// CacheSize defines that a cache should be used for this interface and
|
||||
// defines its size.
|
||||
// Caching comes with an important caveat: If database records are changed
|
||||
// from another interface, the cache will not be invalidated for these
|
||||
// records. It will therefore serve outdated data until that record is
|
||||
// evicted from the cache.
|
||||
CacheSize int
|
||||
|
||||
// DelayCachedWrites defines a database name for which cache writes should
|
||||
// be cached and batched. The database backend must support the Batcher
|
||||
// interface. This option is only valid if used with a cache.
|
||||
// Additionally, this may only be used for internal and local interfaces.
|
||||
// Please note that this means that other interfaces will not be able to
|
||||
// guarantee to serve the latest record if records are written this way.
|
||||
DelayCachedWrites string
|
||||
}
|
||||
|
||||
// Apply applies options to the record metadata.
|
||||
func (o *Options) Apply(r record.Record) {
|
||||
r.UpdateMeta()
|
||||
if o.AlwaysMakeSecret {
|
||||
r.Meta().MakeSecret()
|
||||
}
|
||||
if o.AlwaysMakeCrownjewel {
|
||||
r.Meta().MakeCrownJewel()
|
||||
}
|
||||
if o.AlwaysSetAbsoluteExpiry > 0 {
|
||||
r.Meta().SetAbsoluteExpiry(o.AlwaysSetAbsoluteExpiry)
|
||||
} else if o.AlwaysSetRelativateExpiry > 0 {
|
||||
r.Meta().SetRelativateExpiry(o.AlwaysSetRelativateExpiry)
|
||||
}
|
||||
}
|
||||
|
||||
// HasAllPermissions returns whether the options specify the highest possible
|
||||
// permissions for operations.
|
||||
func (o *Options) HasAllPermissions() bool {
|
||||
return o.Local && o.Internal
|
||||
}
|
||||
|
||||
// hasAccessPermission checks if the interface options permit access to the
|
||||
// given record, locking the record for accessing its attributes.
|
||||
func (o *Options) hasAccessPermission(r record.Record) bool {
|
||||
// Check if the options specify all permissions, which makes checking the
|
||||
// record unnecessary.
|
||||
if o.HasAllPermissions() {
|
||||
return true
|
||||
}
|
||||
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
// Check permissions against record.
|
||||
return r.Meta().CheckPermission(o.Local, o.Internal)
|
||||
}
|
||||
|
||||
// NewInterface returns a new Interface to the database.
|
||||
func NewInterface(opts *Options) *Interface {
|
||||
if opts == nil {
|
||||
opts = &Options{}
|
||||
}
|
||||
|
||||
newIface := &Interface{
|
||||
options: opts,
|
||||
}
|
||||
if opts.CacheSize > 0 {
|
||||
cacheBuilder := gcache.New(opts.CacheSize).ARC()
|
||||
if opts.DelayCachedWrites != "" {
|
||||
cacheBuilder.EvictedFunc(newIface.cacheEvictHandler)
|
||||
newIface.writeCache = make(map[string]record.Record, opts.CacheSize/2)
|
||||
newIface.triggerCacheWrite = make(chan struct{})
|
||||
}
|
||||
newIface.cache = cacheBuilder.Build()
|
||||
}
|
||||
return newIface
|
||||
}
|
||||
|
||||
// Exists return whether a record with the given key exists.
|
||||
func (i *Interface) Exists(key string) (bool, error) {
|
||||
_, err := i.Get(key)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, ErrNotFound):
|
||||
return false, nil
|
||||
case errors.Is(err, ErrPermissionDenied):
|
||||
return true, nil
|
||||
default:
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Get return the record with the given key.
|
||||
func (i *Interface) Get(key string) (record.Record, error) {
|
||||
r, _, err := i.getRecord(getDBFromKey, key, false)
|
||||
return r, err
|
||||
}
|
||||
|
||||
func (i *Interface) getRecord(dbName string, dbKey string, mustBeWriteable bool) (r record.Record, db *Controller, err error) { //nolint:unparam
|
||||
if dbName == "" {
|
||||
dbName, dbKey = record.ParseKey(dbKey)
|
||||
}
|
||||
|
||||
db, err = getController(dbName)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if mustBeWriteable && db.ReadOnly() {
|
||||
return nil, db, ErrReadOnly
|
||||
}
|
||||
|
||||
r = i.checkCache(dbName + ":" + dbKey)
|
||||
if r != nil {
|
||||
if !i.options.hasAccessPermission(r) {
|
||||
return nil, db, ErrPermissionDenied
|
||||
}
|
||||
return r, db, nil
|
||||
}
|
||||
|
||||
r, err = db.Get(dbKey)
|
||||
if err != nil {
|
||||
return nil, db, err
|
||||
}
|
||||
|
||||
if !i.options.hasAccessPermission(r) {
|
||||
return nil, db, ErrPermissionDenied
|
||||
}
|
||||
|
||||
r.Lock()
|
||||
ttl := r.Meta().GetRelativeExpiry()
|
||||
r.Unlock()
|
||||
i.updateCache(
|
||||
r,
|
||||
false, // writing
|
||||
false, // remove
|
||||
ttl, // expiry
|
||||
)
|
||||
|
||||
return r, db, nil
|
||||
}
|
||||
|
||||
func (i *Interface) getMeta(dbName string, dbKey string, mustBeWriteable bool) (m *record.Meta, db *Controller, err error) { //nolint:unparam
|
||||
if dbName == "" {
|
||||
dbName, dbKey = record.ParseKey(dbKey)
|
||||
}
|
||||
|
||||
db, err = getController(dbName)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if mustBeWriteable && db.ReadOnly() {
|
||||
return nil, db, ErrReadOnly
|
||||
}
|
||||
|
||||
r := i.checkCache(dbName + ":" + dbKey)
|
||||
if r != nil {
|
||||
if !i.options.hasAccessPermission(r) {
|
||||
return nil, db, ErrPermissionDenied
|
||||
}
|
||||
return r.Meta(), db, nil
|
||||
}
|
||||
|
||||
m, err = db.GetMeta(dbKey)
|
||||
if err != nil {
|
||||
return nil, db, err
|
||||
}
|
||||
|
||||
if !m.CheckPermission(i.options.Local, i.options.Internal) {
|
||||
return nil, db, ErrPermissionDenied
|
||||
}
|
||||
|
||||
return m, db, nil
|
||||
}
|
||||
|
||||
// InsertValue inserts a value into a record.
|
||||
func (i *Interface) InsertValue(key string, attribute string, value interface{}) error {
|
||||
r, db, err := i.getRecord(getDBFromKey, key, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
var acc accessor.Accessor
|
||||
if r.IsWrapped() {
|
||||
wrapper, ok := r.(*record.Wrapper)
|
||||
if !ok {
|
||||
return errors.New("record is malformed (reports to be wrapped but is not of type *record.Wrapper)")
|
||||
}
|
||||
acc = accessor.NewJSONBytesAccessor(&wrapper.Data)
|
||||
} else {
|
||||
acc = accessor.NewStructAccessor(r)
|
||||
}
|
||||
|
||||
err = acc.Set(attribute, value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set value with %s: %w", acc.Type(), err)
|
||||
}
|
||||
|
||||
i.options.Apply(r)
|
||||
return db.Put(r)
|
||||
}
|
||||
|
||||
// Put saves a record to the database.
|
||||
func (i *Interface) Put(r record.Record) (err error) {
|
||||
// get record or only database
|
||||
var db *Controller
|
||||
if !i.options.HasAllPermissions() {
|
||||
_, db, err = i.getMeta(r.DatabaseName(), r.DatabaseKey(), true)
|
||||
if err != nil && !errors.Is(err, ErrNotFound) {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
db, err = getController(r.DatabaseName())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Check if database is read only.
|
||||
if db.ReadOnly() {
|
||||
return ErrReadOnly
|
||||
}
|
||||
|
||||
r.Lock()
|
||||
i.options.Apply(r)
|
||||
remove := r.Meta().IsDeleted()
|
||||
ttl := r.Meta().GetRelativeExpiry()
|
||||
r.Unlock()
|
||||
|
||||
// The record may not be locked when updating the cache.
|
||||
written := i.updateCache(r, true, remove, ttl)
|
||||
if written {
|
||||
return nil
|
||||
}
|
||||
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
return db.Put(r)
|
||||
}
|
||||
|
||||
// PutNew saves a record to the database as a new record (ie. with new timestamps).
|
||||
func (i *Interface) PutNew(r record.Record) (err error) {
|
||||
// get record or only database
|
||||
var db *Controller
|
||||
if !i.options.HasAllPermissions() {
|
||||
_, db, err = i.getMeta(r.DatabaseName(), r.DatabaseKey(), true)
|
||||
if err != nil && !errors.Is(err, ErrNotFound) {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
db, err = getController(r.DatabaseName())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Check if database is read only.
|
||||
if db.ReadOnly() {
|
||||
return ErrReadOnly
|
||||
}
|
||||
|
||||
r.Lock()
|
||||
if r.Meta() != nil {
|
||||
r.Meta().Reset()
|
||||
}
|
||||
i.options.Apply(r)
|
||||
remove := r.Meta().IsDeleted()
|
||||
ttl := r.Meta().GetRelativeExpiry()
|
||||
r.Unlock()
|
||||
|
||||
// The record may not be locked when updating the cache.
|
||||
written := i.updateCache(r, true, remove, ttl)
|
||||
if written {
|
||||
return nil
|
||||
}
|
||||
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
return db.Put(r)
|
||||
}
|
||||
|
||||
// PutMany stores many records in the database.
|
||||
// Warning: This is nearly a direct database access and omits many things:
|
||||
// - Record locking
|
||||
// - Hooks
|
||||
// - Subscriptions
|
||||
// - Caching
|
||||
// Use with care.
|
||||
func (i *Interface) PutMany(dbName string) (put func(record.Record) error) {
|
||||
interfaceBatch := make(chan record.Record, 100)
|
||||
|
||||
// permission check
|
||||
if !i.options.HasAllPermissions() {
|
||||
return func(r record.Record) error {
|
||||
return ErrPermissionDenied
|
||||
}
|
||||
}
|
||||
|
||||
// get database
|
||||
db, err := getController(dbName)
|
||||
if err != nil {
|
||||
return func(r record.Record) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Check if database is read only.
|
||||
if db.ReadOnly() {
|
||||
return func(r record.Record) error {
|
||||
return ErrReadOnly
|
||||
}
|
||||
}
|
||||
|
||||
// start database access
|
||||
dbBatch, errs := db.PutMany()
|
||||
finished := abool.New()
|
||||
var internalErr error
|
||||
|
||||
// interface options proxy
|
||||
go func() {
|
||||
defer close(dbBatch) // signify that we are finished
|
||||
for {
|
||||
select {
|
||||
case r := <-interfaceBatch:
|
||||
// finished?
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
// apply options
|
||||
i.options.Apply(r)
|
||||
// pass along
|
||||
dbBatch <- r
|
||||
case <-time.After(1 * time.Second):
|
||||
// bail out
|
||||
internalErr = errors.New("timeout: putmany unused for too long")
|
||||
finished.Set()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return func(r record.Record) error {
|
||||
// finished?
|
||||
if finished.IsSet() {
|
||||
// check for internal error
|
||||
if internalErr != nil {
|
||||
return internalErr
|
||||
}
|
||||
// check for previous error
|
||||
select {
|
||||
case err := <-errs:
|
||||
return err
|
||||
default:
|
||||
return errors.New("batch is closed")
|
||||
}
|
||||
}
|
||||
|
||||
// finish?
|
||||
if r == nil {
|
||||
finished.Set()
|
||||
interfaceBatch <- nil // signify that we are finished
|
||||
// do not close, as this fn could be called again with nil.
|
||||
return <-errs
|
||||
}
|
||||
|
||||
// check record scope
|
||||
if r.DatabaseName() != dbName {
|
||||
return errors.New("record out of database scope")
|
||||
}
|
||||
|
||||
// submit
|
||||
select {
|
||||
case interfaceBatch <- r:
|
||||
return nil
|
||||
case err := <-errs:
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetAbsoluteExpiry sets an absolute record expiry.
|
||||
func (i *Interface) SetAbsoluteExpiry(key string, time int64) error {
|
||||
r, db, err := i.getRecord(getDBFromKey, key, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
i.options.Apply(r)
|
||||
r.Meta().SetAbsoluteExpiry(time)
|
||||
return db.Put(r)
|
||||
}
|
||||
|
||||
// SetRelativateExpiry sets a relative (self-updating) record expiry.
|
||||
func (i *Interface) SetRelativateExpiry(key string, duration int64) error {
|
||||
r, db, err := i.getRecord(getDBFromKey, key, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
i.options.Apply(r)
|
||||
r.Meta().SetRelativateExpiry(duration)
|
||||
return db.Put(r)
|
||||
}
|
||||
|
||||
// MakeSecret marks the record as a secret, meaning interfacing processes, such as an UI, are denied access to the record.
|
||||
func (i *Interface) MakeSecret(key string) error {
|
||||
r, db, err := i.getRecord(getDBFromKey, key, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
i.options.Apply(r)
|
||||
r.Meta().MakeSecret()
|
||||
return db.Put(r)
|
||||
}
|
||||
|
||||
// MakeCrownJewel marks a record as a crown jewel, meaning it will only be accessible locally.
|
||||
func (i *Interface) MakeCrownJewel(key string) error {
|
||||
r, db, err := i.getRecord(getDBFromKey, key, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
i.options.Apply(r)
|
||||
r.Meta().MakeCrownJewel()
|
||||
return db.Put(r)
|
||||
}
|
||||
|
||||
// Delete deletes a record from the database.
|
||||
func (i *Interface) Delete(key string) error {
|
||||
r, db, err := i.getRecord(getDBFromKey, key, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if database is read only.
|
||||
if db.ReadOnly() {
|
||||
return ErrReadOnly
|
||||
}
|
||||
|
||||
i.options.Apply(r)
|
||||
r.Meta().Delete()
|
||||
return db.Put(r)
|
||||
}
|
||||
|
||||
// Query executes the given query on the database.
|
||||
// Will not see data that is in the write cache, waiting to be written.
|
||||
// Use with care with caching.
|
||||
func (i *Interface) Query(q *query.Query) (*iterator.Iterator, error) {
|
||||
_, err := q.Check()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db, err := getController(q.DatabaseName())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO: Finish caching system integration.
|
||||
// Flush the cache before we query the database.
|
||||
// i.FlushCache()
|
||||
|
||||
return db.Query(q, i.options.Local, i.options.Internal)
|
||||
}
|
||||
|
||||
// Purge deletes all records that match the given query. It returns the number
|
||||
// of successful deletes and an error.
|
||||
func (i *Interface) Purge(ctx context.Context, q *query.Query) (int, error) {
|
||||
_, err := q.Check()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
db, err := getController(q.DatabaseName())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Check if database is read only before we add to the cache.
|
||||
if db.ReadOnly() {
|
||||
return 0, ErrReadOnly
|
||||
}
|
||||
|
||||
return db.Purge(ctx, q, i.options.Local, i.options.Internal)
|
||||
}
|
||||
|
||||
// Subscribe subscribes to updates matching the given query.
|
||||
func (i *Interface) Subscribe(q *query.Query) (*Subscription, error) {
|
||||
_, err := q.Check()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c, err := getController(q.DatabaseName())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sub := &Subscription{
|
||||
q: q,
|
||||
local: i.options.Local,
|
||||
internal: i.options.Internal,
|
||||
Feed: make(chan record.Record, 1000),
|
||||
}
|
||||
c.addSubscription(sub)
|
||||
return sub, nil
|
||||
}
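For illustration, a hedged usage sketch of Query and Subscribe from outside the package; the database name, key prefix, and field names are hypothetical, and the database system is assumed to be initialized:

```go
package main

import (
	"fmt"

	"github.com/safing/portmaster/base/database"
	"github.com/safing/portmaster/base/database/query"
)

func main() {
	// Assumes a "core" database is registered; prefix and fields are made up.
	db := database.NewInterface(&database.Options{Local: true, Internal: true})

	// One-off query over all matching records.
	q := query.New("core:profiles/")
	q.Where(query.Where("active", query.Is, true))
	it, err := db.Query(q)
	if err != nil {
		fmt.Println("query failed:", err)
		return
	}
	for r := range it.Next {
		fmt.Println("matched:", r.Key())
	}
	if it.Err() != nil {
		fmt.Println("iteration failed:", it.Err())
	}

	// Subscribe to future updates matching a fresh query with the same terms.
	q2 := query.New("core:profiles/")
	q2.Where(query.Where("active", query.Is, true))
	sub, err := db.Subscribe(q2)
	if err != nil {
		fmt.Println("subscribe failed:", err)
		return
	}
	r := <-sub.Feed
	fmt.Println("updated:", r.Key())
}
```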
|
base/database/interface_cache.go (new file, 227 lines)
@@ -0,0 +1,227 @@
package database
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
"github.com/safing/portmaster/base/log"
|
||||
"github.com/safing/portmaster/service/mgr"
|
||||
)
|
||||
|
||||
// DelayedCacheWriter must be run by the caller of an interface that uses delayed cache writing.
|
||||
func (i *Interface) DelayedCacheWriter(wc *mgr.WorkerCtx) error {
|
||||
// Check if the DelayedCacheWriter should be run at all.
|
||||
if i.options.CacheSize <= 0 || i.options.DelayCachedWrites == "" {
|
||||
return errors.New("delayed cache writer is not applicable to this database interface")
|
||||
}
|
||||
|
||||
// Check if the backend supports the Batcher interface.
|
||||
batchPut := i.PutMany(i.options.DelayCachedWrites)
|
||||
// End batchPut immediately and check for an error.
|
||||
err := batchPut(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// percentThreshold defines the minimum fill level of the write cache, as a percentage of the cache size, that must be reached before the cache is flushed to database storage.
|
||||
percentThreshold := 25
|
||||
thresholdWriteTicker := time.NewTicker(5 * time.Second)
|
||||
forceWriteTicker := time.NewTicker(5 * time.Minute)
|
||||
|
||||
for {
|
||||
// Wait for trigger for writing the cache.
|
||||
select {
|
||||
case <-wc.Done():
|
||||
// The caller is shutting down, flush the cache to storage and exit.
|
||||
i.flushWriteCache(0)
|
||||
return nil
|
||||
|
||||
case <-i.triggerCacheWrite:
|
||||
// An entry from the cache was evicted that was also in the write cache.
|
||||
// This makes it likely that other entries that are also present in the
|
||||
// write cache will be evicted soon. Flush the write cache to storage
|
||||
// immediately in order to reduce single writes.
|
||||
i.flushWriteCache(0)
|
||||
|
||||
case <-thresholdWriteTicker.C:
|
||||
// Regularly check whether the write cache has filled up to a certain degree and
|
||||
// flush it to storage before we start evicting to-be-written entries and
|
||||
// slow down the hot path again.
|
||||
i.flushWriteCache(percentThreshold)
|
||||
|
||||
case <-forceWriteTicker.C:
|
||||
// Once in a while, flush the write cache to storage no matter how much
|
||||
// it is filled. We don't want entries lingering around in the write
|
||||
// cache forever. This also reduces the amount of data loss in the event
|
||||
// of a total crash.
|
||||
i.flushWriteCache(0)
|
||||
}
|
||||
}
|
||||
}
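For context, a hedged sketch of how a caller might run this worker for an interface with write-delaying enabled, mirroring the benchmark in interface_cache_test.go further below; the database name "core" and the cache size are hypothetical:

```go
package example

import (
	"github.com/safing/portmaster/base/database"
	"github.com/safing/portmaster/service/mgr"
)

func runWithWriteCache() {
	// Assumes a database called "core" is registered.
	db := database.NewInterface(&database.Options{
		Local:             true,
		Internal:          true,
		CacheSize:         256,
		DelayCachedWrites: "core",
	})

	m := mgr.New("cache example")
	m.Go("delayed cache writer", func(wc *mgr.WorkerCtx) error {
		return db.DelayedCacheWriter(wc)
	})

	// ... use db.Put()/db.Get() here; delayed writes are flushed in the background ...

	// Cancelling the manager makes the worker flush the write cache and exit.
	m.Cancel()
}
```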
|
||||
|
||||
// ClearCache clears the read cache.
|
||||
func (i *Interface) ClearCache() {
|
||||
// Check if cache is in use.
|
||||
if i.cache == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Clear all cache entries.
|
||||
i.cache.Purge()
|
||||
}
|
||||
|
||||
// FlushCache writes (and thus clears) the write cache.
|
||||
func (i *Interface) FlushCache() {
|
||||
// Check if the write cache is in use; if not, there is nothing to flush.
if i.options.DelayCachedWrites == "" {
|
||||
return
|
||||
}
|
||||
|
||||
i.flushWriteCache(0)
|
||||
}
|
||||
|
||||
func (i *Interface) flushWriteCache(percentThreshold int) {
|
||||
i.writeCacheLock.Lock()
|
||||
defer i.writeCacheLock.Unlock()
|
||||
|
||||
// Check if there is anything to do.
|
||||
if len(i.writeCache) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if we have reached the given threshold for writing to storage.
|
||||
if (len(i.writeCache)*100)/i.options.CacheSize < percentThreshold {
|
||||
return
|
||||
}
|
||||
|
||||
// Write the full cache in a batch operation.
|
||||
batchPut := i.PutMany(i.options.DelayCachedWrites)
|
||||
for _, r := range i.writeCache {
|
||||
err := batchPut(r)
|
||||
if err != nil {
|
||||
log.Warningf("database: failed to write write-cached entry to %q database: %s", i.options.DelayCachedWrites, err)
|
||||
}
|
||||
}
|
||||
// Finish batch.
|
||||
err := batchPut(nil)
|
||||
if err != nil {
|
||||
log.Warningf("database: failed to finish flushing write cache to %q database: %s", i.options.DelayCachedWrites, err)
|
||||
}
|
||||
|
||||
// Optimized map clearing following the Go1.11 recommendation.
|
||||
for key := range i.writeCache {
|
||||
delete(i.writeCache, key)
|
||||
}
|
||||
}
|
||||
|
||||
// cacheEvictHandler is run by the cache for every entry that gets evicted
|
||||
// from the cache.
|
||||
func (i *Interface) cacheEvictHandler(keyData, _ interface{}) {
|
||||
// Transform the key into a string.
|
||||
key, ok := keyData.(string)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the evicted record is one that is to be written.
|
||||
// Lock the write cache until the end of the function.
|
||||
// The read cache is locked anyway for the whole duration.
|
||||
i.writeCacheLock.Lock()
|
||||
defer i.writeCacheLock.Unlock()
|
||||
r, ok := i.writeCache[key]
|
||||
if ok {
|
||||
delete(i.writeCache, key)
|
||||
}
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Write record to database in order to mitigate race conditions where the record would appear
|
||||
// as non-existent for a short duration.
|
||||
db, err := getController(r.DatabaseName())
|
||||
if err != nil {
|
||||
log.Warningf("database: failed to write evicted cache entry %q: database %q does not exist", key, r.DatabaseName())
|
||||
return
|
||||
}
|
||||
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
err = db.Put(r)
|
||||
if err != nil {
|
||||
log.Warningf("database: failed to write evicted cache entry %q to database: %s", key, err)
|
||||
}
|
||||
|
||||
// Finally, trigger writing the full write cache because a to-be-written
|
||||
// entry was just evicted from the cache, and this makes it likely that more
|
||||
// to-be-written entries will be evicted shortly.
|
||||
select {
|
||||
case i.triggerCacheWrite <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Interface) checkCache(key string) record.Record {
|
||||
// Check if cache is in use.
|
||||
if i.cache == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if record exists in cache.
|
||||
cacheVal, err := i.cache.Get(key)
|
||||
if err == nil {
|
||||
r, ok := cacheVal.(record.Record)
|
||||
if ok {
|
||||
return r
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateCache updates an entry in the interface cache. The given record must
// not be locked, as updating the cache might write an (unrelated) evicted
// record to the database in the process. If this happens while the
// DelayedCacheWriter flushes the write cache with the same record present,
// this will deadlock.
|
||||
func (i *Interface) updateCache(r record.Record, write bool, remove bool, ttl int64) (written bool) {
|
||||
// Check if cache is in use.
|
||||
if i.cache == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if record should be deleted
|
||||
if remove {
|
||||
// Remove entry from cache.
|
||||
i.cache.Remove(r.Key())
|
||||
// Let write through to database storage.
|
||||
return false
|
||||
}
|
||||
|
||||
// Update cache with record.
|
||||
if ttl >= 0 {
|
||||
_ = i.cache.SetWithExpire(
|
||||
r.Key(),
|
||||
r,
|
||||
time.Duration(ttl)*time.Second,
|
||||
)
|
||||
} else {
|
||||
_ = i.cache.Set(
|
||||
r.Key(),
|
||||
r,
|
||||
)
|
||||
}
|
||||
|
||||
// Add record to write cache instead if:
|
||||
// 1. The record is being written.
|
||||
// 2. Write delaying is active.
|
||||
// 3. Write delaying is active for the database of this record.
|
||||
if write && r.DatabaseName() == i.options.DelayCachedWrites {
|
||||
i.writeCacheLock.Lock()
|
||||
defer i.writeCacheLock.Unlock()
|
||||
i.writeCache[r.Key()] = r
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
base/database/interface_cache_test.go (new file, 159 lines)
@@ -0,0 +1,159 @@
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/safing/portmaster/service/mgr"
|
||||
)
|
||||
|
||||
func benchmarkCacheWriting(b *testing.B, storageType string, cacheSize int, sampleSize int, delayWrites bool) { //nolint:gocognit,gocyclo,thelper
|
||||
b.Run(fmt.Sprintf("CacheWriting_%s_%d_%d_%v", storageType, cacheSize, sampleSize, delayWrites), func(b *testing.B) {
|
||||
// Setup Benchmark.
|
||||
|
||||
// Create database.
|
||||
dbName := fmt.Sprintf("cache-w-benchmark-%s-%d-%d-%v", storageType, cacheSize, sampleSize, delayWrites)
|
||||
_, err := Register(&Database{
|
||||
Name: dbName,
|
||||
Description: fmt.Sprintf("Cache Benchmark Database for %s", storageType),
|
||||
StorageType: storageType,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
// Create benchmark interface.
|
||||
options := &Options{
|
||||
Local: true,
|
||||
Internal: true,
|
||||
CacheSize: cacheSize,
|
||||
}
|
||||
if cacheSize > 0 && delayWrites {
|
||||
options.DelayCachedWrites = dbName
|
||||
}
|
||||
db := NewInterface(options)
|
||||
|
||||
// Start
|
||||
m := mgr.New("Cache writing benchmark test")
|
||||
var wg sync.WaitGroup
|
||||
if cacheSize > 0 && delayWrites {
|
||||
wg.Add(1)
|
||||
m.Go("Cache writing benchmark worker", func(wc *mgr.WorkerCtx) error {
|
||||
err := db.DelayedCacheWriter(wc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
wg.Done()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// Start Benchmark.
|
||||
b.ResetTimer()
|
||||
for i := range b.N {
|
||||
testRecordID := i % sampleSize
|
||||
r := NewExample(
|
||||
dbName+":"+strconv.Itoa(testRecordID),
|
||||
"A",
|
||||
1,
|
||||
)
|
||||
err = db.Put(r)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// End cache writer and wait
|
||||
m.Cancel()
|
||||
wg.Wait()
|
||||
})
|
||||
}
|
||||
|
||||
func benchmarkCacheReadWrite(b *testing.B, storageType string, cacheSize int, sampleSize int, delayWrites bool) { //nolint:gocognit,gocyclo,thelper
|
||||
b.Run(fmt.Sprintf("CacheReadWrite_%s_%d_%d_%v", storageType, cacheSize, sampleSize, delayWrites), func(b *testing.B) {
|
||||
// Setup Benchmark.
|
||||
|
||||
// Create database.
|
||||
dbName := fmt.Sprintf("cache-rw-benchmark-%s-%d-%d-%v", storageType, cacheSize, sampleSize, delayWrites)
|
||||
_, err := Register(&Database{
|
||||
Name: dbName,
|
||||
Description: fmt.Sprintf("Cache Benchmark Database for %s", storageType),
|
||||
StorageType: storageType,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
// Create benchmark interface.
|
||||
options := &Options{
|
||||
Local: true,
|
||||
Internal: true,
|
||||
CacheSize: cacheSize,
|
||||
}
|
||||
if cacheSize > 0 && delayWrites {
|
||||
options.DelayCachedWrites = dbName
|
||||
}
|
||||
db := NewInterface(options)
|
||||
|
||||
// Start
|
||||
m := mgr.New("Cache read/write benchmark test")
|
||||
var wg sync.WaitGroup
|
||||
if cacheSize > 0 && delayWrites {
|
||||
wg.Add(1)
|
||||
m.Go("Cache read/write benchmark worker", func(wc *mgr.WorkerCtx) error {
|
||||
err := db.DelayedCacheWriter(wc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
wg.Done()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// Start Benchmark.
|
||||
b.ResetTimer()
|
||||
writing := true
|
||||
for i := range b.N {
|
||||
testRecordID := i % sampleSize
|
||||
key := dbName + ":" + strconv.Itoa(testRecordID)
|
||||
|
||||
if i > 0 && testRecordID == 0 {
|
||||
writing = !writing // switch between reading and writing every samplesize
|
||||
}
|
||||
|
||||
if writing {
|
||||
r := NewExample(key, "A", 1)
|
||||
err = db.Put(r)
|
||||
} else {
|
||||
_, err = db.Get(key)
|
||||
}
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// End cache writer and wait
|
||||
m.Cancel()
|
||||
wg.Wait()
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkCache(b *testing.B) {
|
||||
for _, storageType := range []string{"bbolt", "hashmap"} {
|
||||
benchmarkCacheWriting(b, storageType, 32, 8, false)
|
||||
benchmarkCacheWriting(b, storageType, 32, 8, true)
|
||||
benchmarkCacheWriting(b, storageType, 32, 1024, false)
|
||||
benchmarkCacheWriting(b, storageType, 32, 1024, true)
|
||||
benchmarkCacheWriting(b, storageType, 512, 1024, false)
|
||||
benchmarkCacheWriting(b, storageType, 512, 1024, true)
|
||||
|
||||
benchmarkCacheReadWrite(b, storageType, 32, 8, false)
|
||||
benchmarkCacheReadWrite(b, storageType, 32, 8, true)
|
||||
benchmarkCacheReadWrite(b, storageType, 32, 1024, false)
|
||||
benchmarkCacheReadWrite(b, storageType, 32, 1024, true)
|
||||
benchmarkCacheReadWrite(b, storageType, 512, 1024, false)
|
||||
benchmarkCacheReadWrite(b, storageType, 512, 1024, true)
|
||||
}
|
||||
}
|
base/database/iterator/iterator.go (new file, 54 lines)
@@ -0,0 +1,54 @@
package iterator
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
)
|
||||
|
||||
// Iterator defines the iterator structure.
|
||||
type Iterator struct {
|
||||
Next chan record.Record
|
||||
Done chan struct{}
|
||||
|
||||
errLock sync.Mutex
|
||||
err error
|
||||
doneClosed *abool.AtomicBool
|
||||
}
|
||||
|
||||
// New creates a new Iterator.
|
||||
func New() *Iterator {
|
||||
return &Iterator{
|
||||
Next: make(chan record.Record, 10),
|
||||
Done: make(chan struct{}),
|
||||
doneClosed: abool.NewBool(false),
|
||||
}
|
||||
}
|
||||
|
||||
// Finish is called by the storage to signal the end of the query results.
|
||||
func (it *Iterator) Finish(err error) {
|
||||
close(it.Next)
|
||||
if it.doneClosed.SetToIf(false, true) {
|
||||
close(it.Done)
|
||||
}
|
||||
|
||||
it.errLock.Lock()
|
||||
defer it.errLock.Unlock()
|
||||
it.err = err
|
||||
}
|
||||
|
||||
// Cancel is called by the iteration consumer to cancel the running query.
|
||||
func (it *Iterator) Cancel() {
|
||||
if it.doneClosed.SetToIf(false, true) {
|
||||
close(it.Done)
|
||||
}
|
||||
}
|
||||
|
||||
// Err returns the iterator error, if exists.
|
||||
func (it *Iterator) Err() error {
|
||||
it.errLock.Lock()
|
||||
defer it.errLock.Unlock()
|
||||
return it.err
|
||||
}
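For illustration, a hedged sketch of how a query caller might drain an iterator; the function and variable names are illustrative only:

```go
package example

import (
	"github.com/safing/portmaster/base/database/iterator"
	"github.com/safing/portmaster/base/database/record"
)

// consume drains an iterator and returns all received records.
func consume(it *iterator.Iterator) ([]record.Record, error) {
	var all []record.Record
	for r := range it.Next {
		// To stop early, call it.Cancel() and keep draining Next instead.
		all = append(all, r)
	}
	// Next is closed by Finish(); afterwards Err() reports the final error, if any.
	return all, it.Err()
}
```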
|
base/database/main.go (new file, 77 lines)
@@ -0,0 +1,77 @@
package database
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portmaster/base/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
databasesSubDir = "databases"
|
||||
)
|
||||
|
||||
var (
|
||||
initialized = abool.NewBool(false)
|
||||
|
||||
shuttingDown = abool.NewBool(false)
|
||||
shutdownSignal = make(chan struct{})
|
||||
|
||||
rootStructure *utils.DirStructure
|
||||
databasesStructure *utils.DirStructure
|
||||
)
|
||||
|
||||
// InitializeWithPath initializes the database at the specified location using a path.
|
||||
func InitializeWithPath(dirPath string) error {
|
||||
return Initialize(utils.NewDirStructure(dirPath, 0o0755))
|
||||
}
|
||||
|
||||
// Initialize initializes the database at the specified location using a dir structure.
|
||||
func Initialize(dirStructureRoot *utils.DirStructure) error {
|
||||
if initialized.SetToIf(false, true) {
|
||||
rootStructure = dirStructureRoot
|
||||
|
||||
// ensure root and databases dirs
|
||||
databasesStructure = rootStructure.ChildDir(databasesSubDir, 0o0700)
|
||||
err := databasesStructure.Ensure()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create/open database directory (%s): %w", rootStructure.Path, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
return errors.New("database already initialized")
|
||||
}
|
||||
|
||||
// Shutdown shuts down the whole database system.
|
||||
func Shutdown() (err error) {
|
||||
if shuttingDown.SetToIf(false, true) {
|
||||
close(shutdownSignal)
|
||||
} else {
|
||||
return
|
||||
}
|
||||
|
||||
controllersLock.RLock()
|
||||
defer controllersLock.RUnlock()
|
||||
|
||||
for _, c := range controllers {
|
||||
err = c.Shutdown()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// getLocation returns the storage location for the given name and type.
|
||||
func getLocation(name, storageType string) (string, error) {
|
||||
location := databasesStructure.ChildDir(name, 0o0700).ChildDir(storageType, 0o0700)
|
||||
// check location
|
||||
err := location.Ensure()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf(`failed to create/check database dir "%s": %w`, location.Path, err)
|
||||
}
|
||||
return location.Path, nil
|
||||
}
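For illustration, a hedged sketch of the initialize/shutdown lifecycle; the data directory path is hypothetical:

```go
package main

import (
	"log"

	"github.com/safing/portmaster/base/database"
)

func main() {
	// Hypothetical data directory; the databases subdirectory is created below it.
	if err := database.InitializeWithPath("/tmp/portmaster-example"); err != nil {
		log.Fatalf("init failed: %s", err)
	}
	defer func() {
		if err := database.Shutdown(); err != nil {
			log.Printf("shutdown failed: %s", err)
		}
	}()

	// ... register databases and use interfaces here ...
}
```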
|
base/database/maintenance.go (new file, 64 lines)
@@ -0,0 +1,64 @@
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Maintain runs the Maintain method on all storages.
|
||||
func Maintain(ctx context.Context) (err error) {
|
||||
// copy the controller list, as we might hold on to it for a long time
|
||||
all := duplicateControllers()
|
||||
|
||||
for _, c := range all {
|
||||
err = c.Maintain(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// MaintainThorough runs the MaintainThorough method on all storages.
|
||||
func MaintainThorough(ctx context.Context) (err error) {
|
||||
// copy the controller list, as we might hold on to it for a long time
|
||||
all := duplicateControllers()
|
||||
|
||||
for _, c := range all {
|
||||
err = c.MaintainThorough(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// MaintainRecordStates runs record state lifecycle maintenance on all storages.
|
||||
func MaintainRecordStates(ctx context.Context) (err error) {
|
||||
// delete immediately for now
|
||||
// TODO: increase purge threshold when starting to sync DBs
|
||||
purgeDeletedBefore := time.Now().UTC()
|
||||
|
||||
// copy the controller list, as we might hold on to it for a long time
|
||||
all := duplicateControllers()
|
||||
|
||||
for _, c := range all {
|
||||
err = c.MaintainRecordStates(ctx, purgeDeletedBefore)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func duplicateControllers() (all []*Controller) {
|
||||
controllersLock.RLock()
|
||||
defer controllersLock.RUnlock()
|
||||
|
||||
all = make([]*Controller, 0, len(controllers))
|
||||
for _, c := range controllers {
|
||||
all = append(all, c)
|
||||
}
|
||||
|
||||
return
|
||||
}
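For illustration, a hedged sketch of driving these maintenance hooks from a periodic task; the interval is arbitrary and not taken from the Portmaster service:

```go
package example

import (
	"context"
	"log"
	"time"

	"github.com/safing/portmaster/base/database"
)

// maintainLoop periodically runs storage maintenance until ctx is cancelled.
func maintainLoop(ctx context.Context) {
	ticker := time.NewTicker(10 * time.Minute)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			if err := database.Maintain(ctx); err != nil {
				log.Printf("maintenance failed: %s", err)
			}
			if err := database.MaintainRecordStates(ctx); err != nil {
				log.Printf("record state maintenance failed: %s", err)
			}
		}
	}
}
```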
|
base/database/migration/error.go (new file, 58 lines)
@@ -0,0 +1,58 @@
package migration
|
||||
|
||||
import "errors"
|
||||
|
||||
// DiagnosticStep describes one migration step in the Diagnostics.
|
||||
type DiagnosticStep struct {
|
||||
Version string
|
||||
Description string
|
||||
}
|
||||
|
||||
// Diagnostics holds a detailed error report about a failed migration.
|
||||
type Diagnostics struct { //nolint:errname
|
||||
// Message holds a human readable message of the encountered
|
||||
// error.
|
||||
Message string
|
||||
// Wrapped must be set to the underlying error that was encountered
|
||||
// while preparing or executing migrations.
|
||||
Wrapped error
|
||||
// StartOfMigration is set to the version of the database before
|
||||
// any migrations are applied.
|
||||
StartOfMigration string
|
||||
// LastSuccessfulMigration is set to the version of the database
|
||||
// which has been applied successfully before the error happened.
|
||||
LastSuccessfulMigration string
|
||||
// TargetVersion is set to the version of the database that the
|
||||
// migration run aimed for. That is, it's the last available version
|
||||
// added to the registry.
|
||||
TargetVersion string
|
||||
// ExecutionPlan is a list of migration steps that were planned to
|
||||
// be executed.
|
||||
ExecutionPlan []DiagnosticStep
|
||||
// FailedMigration is the description of the migration that has
|
||||
// failed.
|
||||
FailedMigration string
|
||||
}
|
||||
|
||||
// Error returns a string representation of the migration error.
|
||||
func (err *Diagnostics) Error() string {
|
||||
msg := ""
|
||||
if err.FailedMigration != "" {
|
||||
msg = err.FailedMigration + ": "
|
||||
}
|
||||
if err.Message != "" {
|
||||
msg += err.Message + ": "
|
||||
}
|
||||
msg += err.Wrapped.Error()
|
||||
return msg
|
||||
}
|
||||
|
||||
// Unwrap returns the actual error that happened when executing
|
||||
// a migration. It implements the interface required by the stdlib
|
||||
// errors package to support errors.Is() and errors.As().
|
||||
func (err *Diagnostics) Unwrap() error {
|
||||
if u := errors.Unwrap(err.Wrapped); u != nil {
|
||||
return u
|
||||
}
|
||||
return err.Wrapped
|
||||
}
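For illustration, a hedged sketch of inspecting a failed migration run through *Diagnostics via errors.As; the registry is assumed to be set up elsewhere:

```go
package example

import (
	"context"
	"errors"
	"fmt"

	"github.com/safing/portmaster/base/database/migration"
)

// reportMigration runs a registry and prints details if it fails.
func reportMigration(ctx context.Context, reg *migration.Registry) {
	err := reg.Migrate(ctx)
	if err == nil {
		return
	}

	var diag *migration.Diagnostics
	if errors.As(err, &diag) {
		fmt.Printf("migration %q failed: %s\n", diag.FailedMigration, diag.Wrapped)
		fmt.Printf("last successful migration: %s\n", diag.LastSuccessfulMigration)
		for _, step := range diag.ExecutionPlan {
			fmt.Printf("planned: %s (%s)\n", step.Version, step.Description)
		}
	}
}
```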
|
base/database/migration/migration.go (new file, 220 lines)
@@ -0,0 +1,220 @@
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-version"
|
||||
|
||||
"github.com/safing/portmaster/base/database"
|
||||
"github.com/safing/portmaster/base/database/record"
|
||||
"github.com/safing/portmaster/base/log"
|
||||
"github.com/safing/structures/dsd"
|
||||
)
|
||||
|
||||
// MigrateFunc is called when a migration should be applied to the
|
||||
// database. It receives the current version (from) and the target
|
||||
// version (to) of the database and a dedicated interface for
|
||||
// interacting with data stored in the DB.
|
||||
// A dedicated log.ContextTracer is added to ctx for each migration
|
||||
// run.
|
||||
type MigrateFunc func(ctx context.Context, from, to *version.Version, dbInterface *database.Interface) error
|
||||
|
||||
// Migration represents a registered data-migration that should be applied to
|
||||
// some database. Migrations are stacked on top and executed in order of increasing
|
||||
// version number (see Version field).
|
||||
type Migration struct {
|
||||
// Description provides a short human-readable description of the
|
||||
// migration.
|
||||
Description string
|
||||
// Version should hold the version of the database/subsystem after
|
||||
// the migration has been applied.
|
||||
Version string
|
||||
// MigrateFunc is executed when the migration should be performed.
|
||||
MigrateFunc MigrateFunc
|
||||
}
|
||||
|
||||
// Registry holds a migration stack.
|
||||
type Registry struct {
|
||||
key string
|
||||
|
||||
lock sync.Mutex
|
||||
migrations []Migration
|
||||
}
|
||||
|
||||
// New creates a new migration registry.
|
||||
// The key should be the name of the database key that is used to store
|
||||
// the version of the last successfully applied migration.
|
||||
func New(key string) *Registry {
|
||||
return &Registry{
|
||||
key: key,
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds one or more migrations to reg.
|
||||
func (reg *Registry) Add(migrations ...Migration) error {
|
||||
reg.lock.Lock()
|
||||
defer reg.lock.Unlock()
|
||||
for _, m := range migrations {
|
||||
if _, err := version.NewSemver(m.Version); err != nil {
|
||||
return fmt.Errorf("migration %q: invalid version %s: %w", m.Description, m.Version, err)
|
||||
}
|
||||
reg.migrations = append(reg.migrations, m)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Migrate migrates the database by executing all registered
|
||||
// migration in order of increasing version numbers. The error
|
||||
// returned, if not nil, is always of type *Diagnostics.
|
||||
func (reg *Registry) Migrate(ctx context.Context) (err error) {
|
||||
reg.lock.Lock()
|
||||
defer reg.lock.Unlock()
|
||||
|
||||
start := time.Now()
|
||||
log.Infof("migration: migration of %s started", reg.key)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
log.Errorf("migration: migration of %s failed after %s: %s", reg.key, time.Since(start), err)
|
||||
} else {
|
||||
log.Infof("migration: migration of %s finished after %s", reg.key, time.Since(start))
|
||||
}
|
||||
}()
|
||||
|
||||
db := database.NewInterface(&database.Options{
|
||||
Local: true,
|
||||
Internal: true,
|
||||
})
|
||||
|
||||
startOfMigration, err := reg.getLatestSuccessfulMigration(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
execPlan, diag, err := reg.getExecutionPlan(startOfMigration)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(execPlan) == 0 {
|
||||
return nil
|
||||
}
|
||||
diag.TargetVersion = execPlan[len(execPlan)-1].Version
|
||||
|
||||
// finally, apply our migrations
|
||||
lastAppliedMigration := startOfMigration
|
||||
for _, m := range execPlan {
|
||||
target, _ := version.NewSemver(m.Version) // we can safely ignore the error here
|
||||
|
||||
migrationCtx, tracer := log.AddTracer(ctx)
|
||||
|
||||
if err := m.MigrateFunc(migrationCtx, lastAppliedMigration, target, db); err != nil {
|
||||
diag.Wrapped = err
|
||||
diag.FailedMigration = m.Description
|
||||
tracer.Errorf("migration: migration for %s failed: %s - %s", reg.key, target.String(), m.Description)
|
||||
tracer.Submit()
|
||||
return diag
|
||||
}
|
||||
|
||||
lastAppliedMigration = target
|
||||
diag.LastSuccessfulMigration = lastAppliedMigration.String()
|
||||
|
||||
if err := reg.saveLastSuccessfulMigration(db, target); err != nil {
|
||||
diag.Message = "failed to persist migration status"
|
||||
diag.Wrapped = err
|
||||
diag.FailedMigration = m.Description
|
||||
}
|
||||
tracer.Infof("migration: applied migration for %s: %s - %s", reg.key, target.String(), m.Description)
|
||||
tracer.Submit()
|
||||
}
|
||||
|
||||
// all migrations have been applied successfully, we're done here
|
||||
return nil
|
||||
}
|
||||
|
||||
func (reg *Registry) getLatestSuccessfulMigration(db *database.Interface) (*version.Version, error) {
|
||||
// find the latest version stored in the database
|
||||
rec, err := db.Get(reg.key)
|
||||
if errors.Is(err, database.ErrNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, &Diagnostics{
|
||||
Message: "failed to query database for migration status",
|
||||
Wrapped: err,
|
||||
}
|
||||
}
|
||||
|
||||
// Unwrap the record to get to the stored raw data.
|
||||
r, ok := rec.(*record.Wrapper)
|
||||
if !ok {
|
||||
return nil, &Diagnostics{
|
||||
Wrapped: errors.New("expected wrapped database record"),
|
||||
}
|
||||
}
|
||||
|
||||
sv, err := version.NewSemver(string(r.Data))
|
||||
if err != nil {
|
||||
return nil, &Diagnostics{
|
||||
Message: "failed to parse version stored in migration status record",
|
||||
Wrapped: err,
|
||||
}
|
||||
}
|
||||
return sv, nil
|
||||
}
|
||||
|
||||
func (reg *Registry) saveLastSuccessfulMigration(db *database.Interface, ver *version.Version) error {
|
||||
r := &record.Wrapper{
|
||||
Data: []byte(ver.String()),
|
||||
Format: dsd.RAW,
|
||||
}
|
||||
r.SetKey(reg.key)
|
||||
|
||||
return db.Put(r)
|
||||
}
|
||||
|
||||
func (reg *Registry) getExecutionPlan(startOfMigration *version.Version) ([]Migration, *Diagnostics, error) {
|
||||
// create a look-up map for migrations indexed by their semver, and a list of
// versions (sorted in increasing order) that we use as our execution plan.
|
||||
lm := make(map[string]Migration)
|
||||
versions := make(version.Collection, 0, len(reg.migrations))
|
||||
for _, m := range reg.migrations {
|
||||
ver, err := version.NewSemver(m.Version)
|
||||
if err != nil {
|
||||
return nil, nil, &Diagnostics{
|
||||
Message: "failed to parse version of migration",
|
||||
Wrapped: err,
|
||||
FailedMigration: m.Description,
|
||||
}
|
||||
}
|
||||
lm[ver.String()] = m // use .String() for a normalized string representation
|
||||
versions = append(versions, ver)
|
||||
}
|
||||
sort.Sort(versions)
|
||||
|
||||
diag := new(Diagnostics)
|
||||
if startOfMigration != nil {
|
||||
diag.StartOfMigration = startOfMigration.String()
|
||||
}
|
||||
|
||||
// prepare our diagnostics and the execution plan
|
||||
execPlan := make([]Migration, 0, len(versions))
|
||||
for _, ver := range versions {
|
||||
// skip any migration that has already been applied.
|
||||
if startOfMigration != nil && startOfMigration.GreaterThanOrEqual(ver) {
|
||||
continue
|
||||
}
|
||||
m := lm[ver.String()]
|
||||
diag.ExecutionPlan = append(diag.ExecutionPlan, DiagnosticStep{
|
||||
Description: m.Description,
|
||||
Version: ver.String(),
|
||||
})
|
||||
execPlan = append(execPlan, m)
|
||||
}
|
||||
|
||||
return execPlan, diag, nil
|
||||
}
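For illustration, a hedged sketch of registering and running a migration; the registry key, version, and migration body are hypothetical:

```go
package main

import (
	"context"
	"log"

	"github.com/hashicorp/go-version"

	"github.com/safing/portmaster/base/database"
	"github.com/safing/portmaster/base/database/migration"
)

func main() {
	// The key is a (hypothetical) database key that stores the last applied version.
	reg := migration.New("core:migrations/example")

	err := reg.Add(migration.Migration{
		Description: "example: rewrite stored keys",
		Version:     "0.2.0",
		MigrateFunc: func(ctx context.Context, from, to *version.Version, db *database.Interface) error {
			// ... transform records via db here ...
			return nil
		},
	})
	if err != nil {
		log.Fatal(err)
	}

	if err := reg.Migrate(context.Background()); err != nil {
		log.Fatal(err)
	}
}
```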
|
base/database/query/README.md (new file, 55 lines)
@@ -0,0 +1,55 @@
# Query
|
||||
|
||||
## Control Flow
|
||||
|
||||
- Grouping with `(` and `)`
|
||||
- Chaining with `and` and `or`
|
||||
- _NO_ mixing! Be explicit and use grouping.
|
||||
- Negation with `not`
|
||||
- in front of expression for group: `not (...)`
|
||||
- inside expression for clause: `name not matches "^King "`
|
||||
|
||||
## Selectors
|
||||
|
||||
Supported by all feeders:
|
||||
- root level field: `field`
|
||||
- sub level field: `field.sub`
|
||||
- array/slice/map access: `map.0`
|
||||
- array/slice/map length: `map.#`
|
||||
|
||||
Please note that some feeders may have other special characters. It is advised to only use alphanumeric characters for keys.
|
||||
|
||||
## Operators
|
||||
|
||||
| Name | Textual | Req. Type | Internal Type | Compared with |
|
||||
|-------------------------|--------------------|-----------|---------------|---------------------------|
|
||||
| Equals | `==` | int | int64 | `==` |
|
||||
| GreaterThan | `>` | int | int64 | `>` |
|
||||
| GreaterThanOrEqual | `>=` | int | int64 | `>=` |
|
||||
| LessThan | `<` | int | int64 | `<` |
|
||||
| LessThanOrEqual | `<=` | int | int64 | `<=` |
|
||||
| FloatEquals | `f==` | float | float64 | `==` |
|
||||
| FloatGreaterThan | `f>` | float | float64 | `>` |
|
||||
| FloatGreaterThanOrEqual | `f>=` | float | float64 | `>=` |
|
||||
| FloatLessThan | `f<` | float | float64 | `<` |
|
||||
| FloatLessThanOrEqual | `f<=` | float | float64 | `<=` |
|
||||
| SameAs | `sameas`, `s==` | string | string | `==` |
|
||||
| Contains | `contains`, `co` | string | string | `strings.Contains()` |
|
||||
| StartsWith | `startswith`, `sw` | string | string | `strings.HasPrefix()` |
|
||||
| EndsWith | `endswith`, `ew` | string | string | `strings.HasSuffix()` |
|
||||
| In | `in` | string | string | for loop with `==` |
|
||||
| Matches | `matches`, `re` | string | string | `regexp.Regexp.Matches()` |
|
||||
| Is | `is` | bool* | bool | `==` |
|
||||
| Exists | `exists`, `ex` | any | n/a | n/a |
|
||||
|
||||
\*accepts strings: 1, t, T, true, True, TRUE, 0, f, F, false, False, FALSE
|
||||
|
||||
## Escaping
|
||||
|
||||
If you need to use a control character within a value (i.e. not for control purposes), escape it with `\`.
|
||||
It is recommended to wrap a value in double quotes instead of escaping control characters, when possible.
|
||||
|
||||
| Location | Characters to be escaped |
|
||||
|---|---|
|
||||
| Within double quotes (`"`) | `"`, `\` |
|
||||
| Everywhere else | `(`, `)`, `"`, `\`, `\t`, `\r`, `\n`, ` ` (space) |
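
## Example

For illustration (not part of the original README), a query using the grammar above could be parsed like this; the prefix `core:profiles/` and the field names are hypothetical:

```go
package main

import (
	"fmt"

	"github.com/safing/portmaster/base/database/query"
)

func main() {
	q, err := query.ParseQuery(`query core:profiles/ where (active is true and score f>= 0.5) orderby name limit 10`)
	if err != nil {
		fmt.Println("parse error:", err)
		return
	}
	fmt.Println("query parsed successfully")
	_ = q
}
```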
|
base/database/query/condition-and.go (new file, 46 lines)
@@ -0,0 +1,46 @@
package query
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/safing/portmaster/base/database/accessor"
|
||||
)
|
||||
|
||||
// And combines multiple conditions with a logical _AND_ operator.
|
||||
func And(conditions ...Condition) Condition {
|
||||
return &andCond{
|
||||
conditions: conditions,
|
||||
}
|
||||
}
|
||||
|
||||
type andCond struct {
|
||||
conditions []Condition
|
||||
}
|
||||
|
||||
func (c *andCond) complies(acc accessor.Accessor) bool {
|
||||
for _, cond := range c.conditions {
|
||||
if !cond.complies(acc) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *andCond) check() (err error) {
|
||||
for _, cond := range c.conditions {
|
||||
err = cond.check()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *andCond) string() string {
|
||||
all := make([]string, 0, len(c.conditions))
|
||||
for _, cond := range c.conditions {
|
||||
all = append(all, cond.string())
|
||||
}
|
||||
return fmt.Sprintf("(%s)", strings.Join(all, " and "))
|
||||
}
|
base/database/query/condition-bool.go (new file, 69 lines)
@@ -0,0 +1,69 @@
package query
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/safing/portmaster/base/database/accessor"
|
||||
)
|
||||
|
||||
type boolCondition struct {
|
||||
key string
|
||||
operator uint8
|
||||
value bool
|
||||
}
|
||||
|
||||
func newBoolCondition(key string, operator uint8, value interface{}) *boolCondition {
|
||||
var parsedValue bool
|
||||
|
||||
switch v := value.(type) {
|
||||
case bool:
|
||||
parsedValue = v
|
||||
case string:
|
||||
var err error
|
||||
parsedValue, err = strconv.ParseBool(v)
|
||||
if err != nil {
|
||||
return &boolCondition{
|
||||
key: fmt.Sprintf("could not parse \"%s\" to bool: %s", v, err),
|
||||
operator: errorPresent,
|
||||
}
|
||||
}
|
||||
default:
|
||||
return &boolCondition{
|
||||
key: fmt.Sprintf("incompatible value %v for bool", value),
|
||||
operator: errorPresent,
|
||||
}
|
||||
}
|
||||
|
||||
return &boolCondition{
|
||||
key: key,
|
||||
operator: operator,
|
||||
value: parsedValue,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *boolCondition) complies(acc accessor.Accessor) bool {
|
||||
comp, ok := acc.GetBool(c.key)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
switch c.operator {
|
||||
case Is:
|
||||
return comp == c.value
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *boolCondition) check() error {
|
||||
if c.operator == errorPresent {
|
||||
return errors.New(c.key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *boolCondition) string() string {
|
||||
return fmt.Sprintf("%s %s %t", escapeString(c.key), getOpName(c.operator), c.value)
|
||||
}
|
base/database/query/condition-error.go (new file, 27 lines)
@@ -0,0 +1,27 @@
package query
|
||||
|
||||
import (
|
||||
"github.com/safing/portmaster/base/database/accessor"
|
||||
)
|
||||
|
||||
type errorCondition struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func newErrorCondition(err error) *errorCondition {
|
||||
return &errorCondition{
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *errorCondition) complies(acc accessor.Accessor) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *errorCondition) check() error {
|
||||
return c.err
|
||||
}
|
||||
|
||||
func (c *errorCondition) string() string {
|
||||
return "[ERROR]"
|
||||
}
|
base/database/query/condition-exists.go (new file, 35 lines)
@@ -0,0 +1,35 @@
package query
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/safing/portmaster/base/database/accessor"
|
||||
)
|
||||
|
||||
type existsCondition struct {
|
||||
key string
|
||||
operator uint8
|
||||
}
|
||||
|
||||
func newExistsCondition(key string, operator uint8) *existsCondition {
|
||||
return &existsCondition{
|
||||
key: key,
|
||||
operator: operator,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *existsCondition) complies(acc accessor.Accessor) bool {
|
||||
return acc.Exists(c.key)
|
||||
}
|
||||
|
||||
func (c *existsCondition) check() error {
|
||||
if c.operator == errorPresent {
|
||||
return errors.New(c.key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *existsCondition) string() string {
|
||||
return fmt.Sprintf("%s %s", escapeString(c.key), getOpName(c.operator))
|
||||
}
|
base/database/query/condition-float.go (new file, 97 lines)
@@ -0,0 +1,97 @@
package query
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/safing/portmaster/base/database/accessor"
|
||||
)
|
||||
|
||||
type floatCondition struct {
|
||||
key string
|
||||
operator uint8
|
||||
value float64
|
||||
}
|
||||
|
||||
func newFloatCondition(key string, operator uint8, value interface{}) *floatCondition {
|
||||
var parsedValue float64
|
||||
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
parsedValue = float64(v)
|
||||
case int8:
|
||||
parsedValue = float64(v)
|
||||
case int16:
|
||||
parsedValue = float64(v)
|
||||
case int32:
|
||||
parsedValue = float64(v)
|
||||
case int64:
|
||||
parsedValue = float64(v)
|
||||
case uint:
|
||||
parsedValue = float64(v)
|
||||
case uint8:
|
||||
parsedValue = float64(v)
|
||||
case uint16:
|
||||
parsedValue = float64(v)
|
||||
case uint32:
|
||||
parsedValue = float64(v)
|
||||
case float32:
|
||||
parsedValue = float64(v)
|
||||
case float64:
|
||||
parsedValue = v
|
||||
case string:
|
||||
var err error
|
||||
parsedValue, err = strconv.ParseFloat(v, 64)
|
||||
if err != nil {
|
||||
return &floatCondition{
|
||||
key: fmt.Sprintf("could not parse %s to float64: %s", v, err),
|
||||
operator: errorPresent,
|
||||
}
|
||||
}
|
||||
default:
|
||||
return &floatCondition{
|
||||
key: fmt.Sprintf("incompatible value %v for float64", value),
|
||||
operator: errorPresent,
|
||||
}
|
||||
}
|
||||
|
||||
return &floatCondition{
|
||||
key: key,
|
||||
operator: operator,
|
||||
value: parsedValue,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *floatCondition) complies(acc accessor.Accessor) bool {
|
||||
comp, ok := acc.GetFloat(c.key)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
switch c.operator {
|
||||
case FloatEquals:
|
||||
return comp == c.value
|
||||
case FloatGreaterThan:
|
||||
return comp > c.value
|
||||
case FloatGreaterThanOrEqual:
|
||||
return comp >= c.value
|
||||
case FloatLessThan:
|
||||
return comp < c.value
|
||||
case FloatLessThanOrEqual:
|
||||
return comp <= c.value
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *floatCondition) check() error {
|
||||
if c.operator == errorPresent {
|
||||
return errors.New(c.key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *floatCondition) string() string {
|
||||
return fmt.Sprintf("%s %s %g", escapeString(c.key), getOpName(c.operator), c.value)
|
||||
}
|
base/database/query/condition-int.go (new file, 93 lines)
@@ -0,0 +1,93 @@
package query
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/safing/portmaster/base/database/accessor"
|
||||
)
|
||||
|
||||
type intCondition struct {
|
||||
key string
|
||||
operator uint8
|
||||
value int64
|
||||
}
|
||||
|
||||
func newIntCondition(key string, operator uint8, value interface{}) *intCondition {
|
||||
var parsedValue int64
|
||||
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
parsedValue = int64(v)
|
||||
case int8:
|
||||
parsedValue = int64(v)
|
||||
case int16:
|
||||
parsedValue = int64(v)
|
||||
case int32:
|
||||
parsedValue = int64(v)
|
||||
case int64:
|
||||
parsedValue = v
|
||||
case uint:
|
||||
parsedValue = int64(v)
|
||||
case uint8:
|
||||
parsedValue = int64(v)
|
||||
case uint16:
|
||||
parsedValue = int64(v)
|
||||
case uint32:
|
||||
parsedValue = int64(v)
|
||||
case string:
|
||||
var err error
|
||||
parsedValue, err = strconv.ParseInt(v, 10, 64)
|
||||
if err != nil {
|
||||
return &intCondition{
|
||||
key: fmt.Sprintf("could not parse %s to int64: %s (hint: use \"sameas\" to compare strings)", v, err),
|
||||
operator: errorPresent,
|
||||
}
|
||||
}
|
||||
default:
|
||||
return &intCondition{
|
||||
key: fmt.Sprintf("incompatible value %v for int64", value),
|
||||
operator: errorPresent,
|
||||
}
|
||||
}
|
||||
|
||||
return &intCondition{
|
||||
key: key,
|
||||
operator: operator,
|
||||
value: parsedValue,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *intCondition) complies(acc accessor.Accessor) bool {
|
||||
comp, ok := acc.GetInt(c.key)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
switch c.operator {
|
||||
case Equals:
|
||||
return comp == c.value
|
||||
case GreaterThan:
|
||||
return comp > c.value
|
||||
case GreaterThanOrEqual:
|
||||
return comp >= c.value
|
||||
case LessThan:
|
||||
return comp < c.value
|
||||
case LessThanOrEqual:
|
||||
return comp <= c.value
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *intCondition) check() error {
|
||||
if c.operator == errorPresent {
|
||||
return errors.New(c.key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *intCondition) string() string {
|
||||
return fmt.Sprintf("%s %s %d", escapeString(c.key), getOpName(c.operator), c.value)
|
||||
}
|
base/database/query/condition-not.go (new file, 36 lines)
@@ -0,0 +1,36 @@
package query
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/safing/portmaster/base/database/accessor"
|
||||
)
|
||||
|
||||
// Not negates the supplied condition.
|
||||
func Not(c Condition) Condition {
|
||||
return ¬Cond{
|
||||
notC: c,
|
||||
}
|
||||
}
|
||||
|
||||
type notCond struct {
|
||||
notC Condition
|
||||
}
|
||||
|
||||
func (c *notCond) complies(acc accessor.Accessor) bool {
|
||||
return !c.notC.complies(acc)
|
||||
}
|
||||
|
||||
func (c *notCond) check() error {
|
||||
return c.notC.check()
|
||||
}
|
||||
|
||||
func (c *notCond) string() string {
|
||||
next := c.notC.string()
|
||||
if strings.HasPrefix(next, "(") {
|
||||
return fmt.Sprintf("not %s", c.notC.string())
|
||||
}
|
||||
splitted := strings.Split(next, " ")
|
||||
return strings.Join(append([]string{splitted[0], "not"}, splitted[1:]...), " ")
|
||||
}
|
base/database/query/condition-or.go (new file, 46 lines)
@@ -0,0 +1,46 @@
package query
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/safing/portmaster/base/database/accessor"
|
||||
)
|
||||
|
||||
// Or combines multiple conditions with a logical _OR_ operator.
|
||||
func Or(conditions ...Condition) Condition {
|
||||
return &orCond{
|
||||
conditions: conditions,
|
||||
}
|
||||
}
|
||||
|
||||
type orCond struct {
|
||||
conditions []Condition
|
||||
}
|
||||
|
||||
func (c *orCond) complies(acc accessor.Accessor) bool {
|
||||
for _, cond := range c.conditions {
|
||||
if cond.complies(acc) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *orCond) check() (err error) {
|
||||
for _, cond := range c.conditions {
|
||||
err = cond.check()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *orCond) string() string {
|
||||
all := make([]string, 0, len(c.conditions))
|
||||
for _, cond := range c.conditions {
|
||||
all = append(all, cond.string())
|
||||
}
|
||||
return fmt.Sprintf("(%s)", strings.Join(all, " or "))
|
||||
}
|
base/database/query/condition-regex.go (new file, 63 lines)
@@ -0,0 +1,63 @@
package query
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
|
||||
"github.com/safing/portmaster/base/database/accessor"
|
||||
)
|
||||
|
||||
type regexCondition struct {
|
||||
key string
|
||||
operator uint8
|
||||
regex *regexp.Regexp
|
||||
}
|
||||
|
||||
func newRegexCondition(key string, operator uint8, value interface{}) *regexCondition {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
r, err := regexp.Compile(v)
|
||||
if err != nil {
|
||||
return ®exCondition{
|
||||
key: fmt.Sprintf("could not compile regex \"%s\": %s", v, err),
|
||||
operator: errorPresent,
|
||||
}
|
||||
}
|
||||
return ®exCondition{
|
||||
key: key,
|
||||
operator: operator,
|
||||
regex: r,
|
||||
}
|
||||
default:
|
||||
return ®exCondition{
|
||||
key: fmt.Sprintf("incompatible value %v for string", value),
|
||||
operator: errorPresent,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *regexCondition) complies(acc accessor.Accessor) bool {
|
||||
comp, ok := acc.GetString(c.key)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
switch c.operator {
|
||||
case Matches:
|
||||
return c.regex.MatchString(comp)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *regexCondition) check() error {
|
||||
if c.operator == errorPresent {
|
||||
return errors.New(c.key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *regexCondition) string() string {
|
||||
return fmt.Sprintf("%s %s %s", escapeString(c.key), getOpName(c.operator), escapeString(c.regex.String()))
|
||||
}
|
base/database/query/condition-string.go (new file, 62 lines)
@@ -0,0 +1,62 @@
package query
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/safing/portmaster/base/database/accessor"
|
||||
)
|
||||
|
||||
type stringCondition struct {
|
||||
key string
|
||||
operator uint8
|
||||
value string
|
||||
}
|
||||
|
||||
func newStringCondition(key string, operator uint8, value interface{}) *stringCondition {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return &stringCondition{
|
||||
key: key,
|
||||
operator: operator,
|
||||
value: v,
|
||||
}
|
||||
default:
|
||||
return &stringCondition{
|
||||
key: fmt.Sprintf("incompatible value %v for string", value),
|
||||
operator: errorPresent,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *stringCondition) complies(acc accessor.Accessor) bool {
|
||||
comp, ok := acc.GetString(c.key)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
switch c.operator {
|
||||
case SameAs:
|
||||
return c.value == comp
|
||||
case Contains:
|
||||
return strings.Contains(comp, c.value)
|
||||
case StartsWith:
|
||||
return strings.HasPrefix(comp, c.value)
|
||||
case EndsWith:
|
||||
return strings.HasSuffix(comp, c.value)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *stringCondition) check() error {
|
||||
if c.operator == errorPresent {
|
||||
return errors.New(c.key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *stringCondition) string() string {
|
||||
return fmt.Sprintf("%s %s %s", escapeString(c.key), getOpName(c.operator), escapeString(c.value))
|
||||
}
|
base/database/query/condition-stringslice.go (new file, 69 lines)
@@ -0,0 +1,69 @@
package query
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/safing/portmaster/base/database/accessor"
|
||||
"github.com/safing/portmaster/base/utils"
|
||||
)
|
||||
|
||||
type stringSliceCondition struct {
|
||||
key string
|
||||
operator uint8
|
||||
value []string
|
||||
}
|
||||
|
||||
func newStringSliceCondition(key string, operator uint8, value interface{}) *stringSliceCondition {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
parsedValue := strings.Split(v, ",")
|
||||
if len(parsedValue) < 2 {
|
||||
return &stringSliceCondition{
|
||||
key: v,
|
||||
operator: errorPresent,
|
||||
}
|
||||
}
|
||||
return &stringSliceCondition{
|
||||
key: key,
|
||||
operator: operator,
|
||||
value: parsedValue,
|
||||
}
|
||||
case []string:
|
||||
return &stringSliceCondition{
|
||||
key: key,
|
||||
operator: operator,
|
||||
value: v,
|
||||
}
|
||||
default:
|
||||
return &stringSliceCondition{
|
||||
key: fmt.Sprintf("incompatible value %v for []string", value),
|
||||
operator: errorPresent,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *stringSliceCondition) complies(acc accessor.Accessor) bool {
|
||||
comp, ok := acc.GetString(c.key)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
switch c.operator {
|
||||
case In:
|
||||
return utils.StringInSlice(c.value, comp)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *stringSliceCondition) check() error {
|
||||
if c.operator == errorPresent {
|
||||
return fmt.Errorf("could not parse \"%s\" to []string", c.key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *stringSliceCondition) string() string {
|
||||
return fmt.Sprintf("%s %s %s", escapeString(c.key), getOpName(c.operator), escapeString(strings.Join(c.value, ",")))
|
||||
}
|
base/database/query/condition.go (new file, 71 lines)
@@ -0,0 +1,71 @@
package query
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/safing/portmaster/base/database/accessor"
|
||||
)
|
||||
|
||||
// Condition is an interface to provide a common api to all condition types.
|
||||
type Condition interface {
|
||||
complies(acc accessor.Accessor) bool
|
||||
check() error
|
||||
string() string
|
||||
}
|
||||
|
||||
// Operators.
|
||||
const (
|
||||
Equals uint8 = iota // int
|
||||
GreaterThan // int
|
||||
GreaterThanOrEqual // int
|
||||
LessThan // int
|
||||
LessThanOrEqual // int
|
||||
FloatEquals // float
|
||||
FloatGreaterThan // float
|
||||
FloatGreaterThanOrEqual // float
|
||||
FloatLessThan // float
|
||||
FloatLessThanOrEqual // float
|
||||
SameAs // string
|
||||
Contains // string
|
||||
StartsWith // string
|
||||
EndsWith // string
|
||||
In // stringSlice
|
||||
Matches // regex
|
||||
Is // bool: accepts 1, t, T, true, True, TRUE, 0, f, F, false, False, FALSE
|
||||
Exists // any
|
||||
|
||||
errorPresent uint8 = 255
|
||||
)
|
||||
|
||||
// Where returns a condition to add to a query.
|
||||
func Where(key string, operator uint8, value interface{}) Condition {
|
||||
switch operator {
|
||||
case Equals,
|
||||
GreaterThan,
|
||||
GreaterThanOrEqual,
|
||||
LessThan,
|
||||
LessThanOrEqual:
|
||||
return newIntCondition(key, operator, value)
|
||||
case FloatEquals,
|
||||
FloatGreaterThan,
|
||||
FloatGreaterThanOrEqual,
|
||||
FloatLessThan,
|
||||
FloatLessThanOrEqual:
|
||||
return newFloatCondition(key, operator, value)
|
||||
case SameAs,
|
||||
Contains,
|
||||
StartsWith,
|
||||
EndsWith:
|
||||
return newStringCondition(key, operator, value)
|
||||
case In:
|
||||
return newStringSliceCondition(key, operator, value)
|
||||
case Matches:
|
||||
return newRegexCondition(key, operator, value)
|
||||
case Is:
|
||||
return newBoolCondition(key, operator, value)
|
||||
case Exists:
|
||||
return newExistsCondition(key, operator)
|
||||
default:
|
||||
return newErrorCondition(fmt.Errorf("no operator with ID %d", operator))
|
||||
}
|
||||
}
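For illustration, a hedged sketch of building a condition tree programmatically with Where, And, and Not; the prefix and field names are hypothetical:

```go
package main

import (
	"fmt"

	"github.com/safing/portmaster/base/database/query"
)

func main() {
	q := query.New("core:profiles/")
	q.Where(query.And(
		query.Where("active", query.Is, true),
		query.Not(query.Where("name", query.StartsWith, "test-")),
	))

	// Check validates all conditions before the query is used.
	if _, err := q.Check(); err != nil {
		fmt.Println("invalid query:", err)
		return
	}
	fmt.Println("query is valid")
}
```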
|
base/database/query/condition_test.go (new file, 86 lines)
@@ -0,0 +1,86 @@
package query
|
||||
|
||||
import "testing"
|
||||
|
||||
func testSuccess(t *testing.T, c Condition) {
|
||||
t.Helper()
|
||||
|
||||
err := c.check()
|
||||
if err != nil {
|
||||
t.Errorf("failed: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInterfaces(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testSuccess(t, newIntCondition("banana", Equals, uint(1)))
|
||||
testSuccess(t, newIntCondition("banana", Equals, uint8(1)))
|
||||
testSuccess(t, newIntCondition("banana", Equals, uint16(1)))
|
||||
testSuccess(t, newIntCondition("banana", Equals, uint32(1)))
|
||||
testSuccess(t, newIntCondition("banana", Equals, int(1)))
|
||||
testSuccess(t, newIntCondition("banana", Equals, int8(1)))
|
||||
testSuccess(t, newIntCondition("banana", Equals, int16(1)))
|
||||
testSuccess(t, newIntCondition("banana", Equals, int32(1)))
|
||||
testSuccess(t, newIntCondition("banana", Equals, int64(1)))
|
||||
testSuccess(t, newIntCondition("banana", Equals, "1"))
|
||||
|
||||
testSuccess(t, newFloatCondition("banana", FloatEquals, uint(1)))
|
||||
testSuccess(t, newFloatCondition("banana", FloatEquals, uint8(1)))
|
||||
testSuccess(t, newFloatCondition("banana", FloatEquals, uint16(1)))
|
||||
testSuccess(t, newFloatCondition("banana", FloatEquals, uint32(1)))
|
||||
testSuccess(t, newFloatCondition("banana", FloatEquals, int(1)))
|
||||
testSuccess(t, newFloatCondition("banana", FloatEquals, int8(1)))
|
||||
testSuccess(t, newFloatCondition("banana", FloatEquals, int16(1)))
|
||||
testSuccess(t, newFloatCondition("banana", FloatEquals, int32(1)))
|
||||
testSuccess(t, newFloatCondition("banana", FloatEquals, int64(1)))
|
||||
testSuccess(t, newFloatCondition("banana", FloatEquals, float32(1)))
|
||||
testSuccess(t, newFloatCondition("banana", FloatEquals, float64(1)))
|
||||
testSuccess(t, newFloatCondition("banana", FloatEquals, "1.1"))
|
||||
|
||||
testSuccess(t, newStringCondition("banana", SameAs, "coconut"))
|
||||
testSuccess(t, newRegexCondition("banana", Matches, "coconut"))
|
||||
testSuccess(t, newStringSliceCondition("banana", FloatEquals, []string{"banana", "coconut"}))
|
||||
testSuccess(t, newStringSliceCondition("banana", FloatEquals, "banana,coconut"))
|
||||
}
|
||||
|
||||
func testCondError(t *testing.T, c Condition) {
|
||||
t.Helper()
|
||||
|
||||
err := c.check()
|
||||
if err == nil {
|
||||
t.Error("should fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConditionErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// test invalid value types
|
||||
testCondError(t, newBoolCondition("banana", Is, 1))
|
||||
testCondError(t, newFloatCondition("banana", FloatEquals, true))
|
||||
testCondError(t, newIntCondition("banana", Equals, true))
|
||||
testCondError(t, newStringCondition("banana", SameAs, 1))
|
||||
testCondError(t, newRegexCondition("banana", Matches, 1))
|
||||
testCondError(t, newStringSliceCondition("banana", Matches, 1))
|
||||
|
||||
// test error presence
|
||||
testCondError(t, newBoolCondition("banana", errorPresent, true))
|
||||
testCondError(t, And(newBoolCondition("banana", errorPresent, true)))
|
||||
testCondError(t, Or(newBoolCondition("banana", errorPresent, true)))
|
||||
testCondError(t, newExistsCondition("banana", errorPresent))
|
||||
testCondError(t, newFloatCondition("banana", errorPresent, 1.1))
|
||||
testCondError(t, newIntCondition("banana", errorPresent, 1))
|
||||
testCondError(t, newStringCondition("banana", errorPresent, "coconut"))
|
||||
testCondError(t, newRegexCondition("banana", errorPresent, "coconut"))
|
||||
}
|
||||
|
||||
func TestWhere(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c := Where("", 254, nil)
|
||||
err := c.check()
|
||||
if err == nil {
|
||||
t.Error("should fail")
|
||||
}
|
||||
}
|
base/database/query/operators.go (new file, 53 lines)
@@ -0,0 +1,53 @@
package query
|
||||
|
||||
var (
|
||||
operatorNames = map[string]uint8{
|
||||
"==": Equals,
|
||||
">": GreaterThan,
|
||||
">=": GreaterThanOrEqual,
|
||||
"<": LessThan,
|
||||
"<=": LessThanOrEqual,
|
||||
"f==": FloatEquals,
|
||||
"f>": FloatGreaterThan,
|
||||
"f>=": FloatGreaterThanOrEqual,
|
||||
"f<": FloatLessThan,
|
||||
"f<=": FloatLessThanOrEqual,
|
||||
"sameas": SameAs,
|
||||
"s==": SameAs,
|
||||
"contains": Contains,
|
||||
"co": Contains,
|
||||
"startswith": StartsWith,
|
||||
"sw": StartsWith,
|
||||
"endswith": EndsWith,
|
||||
"ew": EndsWith,
|
||||
"in": In,
|
||||
"matches": Matches,
|
||||
"re": Matches,
|
||||
"is": Is,
|
||||
"exists": Exists,
|
||||
"ex": Exists,
|
||||
}
|
||||
|
||||
primaryNames = make(map[uint8]string)
|
||||
)
|
||||
|
||||
func init() {
|
||||
for opName, opID := range operatorNames {
|
||||
name, ok := primaryNames[opID]
|
||||
if ok {
|
||||
if len(name) < len(opName) {
|
||||
primaryNames[opID] = opName
|
||||
}
|
||||
} else {
|
||||
primaryNames[opID] = opName
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getOpName(operator uint8) string {
|
||||
name, ok := primaryNames[operator]
|
||||
if ok {
|
||||
return name
|
||||
}
|
||||
return "[unknown]"
|
||||
}
|
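Alias sketch (editor addition, not part of this diff): operatorNames maps several spellings to the same operator ID, and init() keeps the longest spelling as the primary name, which getOpName (and therefore Query.Print) uses. A query parsed with a short alias should therefore print back with the long form; the import path is assumed as above.

package main

import (
	"fmt"

	"github.com/safing/portmaster/base/database/query"
)

func main() {
	// "s==" is an alias for SameAs; printing uses the primary name "sameas".
	q, err := query.ParseQuery(`query test: where banana s== coconut`)
	if err != nil {
		panic(err)
	}
	fmt.Println(q.Print()) // expected: query test: where banana sameas coconut
}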
base/database/query/operators_test.go (new file)
@@ -0,0 +1,11 @@
package query

import "testing"

func TestGetOpName(t *testing.T) {
	t.Parallel()

	if getOpName(254) != "[unknown]" {
		t.Error("unexpected output")
	}
}
base/database/query/parser.go (new file)
@@ -0,0 +1,350 @@
package query

import (
	"errors"
	"fmt"
	"regexp"
	"strconv"
	"strings"
)

type snippet struct {
	text           string
	globalPosition int
}

// ParseQuery parses a plaintext query. Special characters (that must be escaped with a '\') are: `\()` and any whitespaces.
//
//nolint:gocognit
func ParseQuery(query string) (*Query, error) {
	snippets, err := extractSnippets(query)
	if err != nil {
		return nil, err
	}
	snippetsPos := 0

	getSnippet := func() (*snippet, error) {
		// order is important, as parseAndOr will always consume one additional snippet.
		snippetsPos++
		if snippetsPos > len(snippets) {
			return nil, fmt.Errorf("unexpected end at position %d", len(query))
		}
		return snippets[snippetsPos-1], nil
	}
	remainingSnippets := func() int {
		return len(snippets) - snippetsPos
	}

	// check for query word
	queryWord, err := getSnippet()
	if err != nil {
		return nil, err
	}
	if queryWord.text != "query" {
		return nil, errors.New("queries must start with \"query\"")
	}

	// get prefix
	prefix, err := getSnippet()
	if err != nil {
		return nil, err
	}
	q := New(prefix.text)

	for remainingSnippets() > 0 {
		command, err := getSnippet()
		if err != nil {
			return nil, err
		}

		switch command.text {
		case "where":
			if q.where != nil {
				return nil, fmt.Errorf("duplicate \"%s\" clause found at position %d", command.text, command.globalPosition)
			}

			// parse conditions
			condition, err := parseAndOr(getSnippet, remainingSnippets, true)
			if err != nil {
				return nil, err
			}
			// go one back, as parseAndOr had to check if it is done
			snippetsPos--

			q.Where(condition)
		case "orderby":
			if q.orderBy != "" {
				return nil, fmt.Errorf("duplicate \"%s\" clause found at position %d", command.text, command.globalPosition)
			}

			orderBySnippet, err := getSnippet()
			if err != nil {
				return nil, err
			}

			q.OrderBy(orderBySnippet.text)
		case "limit":
			if q.limit != 0 {
				return nil, fmt.Errorf("duplicate \"%s\" clause found at position %d", command.text, command.globalPosition)
			}

			limitSnippet, err := getSnippet()
			if err != nil {
				return nil, err
			}
			limit, err := strconv.ParseUint(limitSnippet.text, 10, 31)
			if err != nil {
				return nil, fmt.Errorf("could not parse integer (%s) at position %d", limitSnippet.text, limitSnippet.globalPosition)
			}

			q.Limit(int(limit))
		case "offset":
			if q.offset != 0 {
				return nil, fmt.Errorf("duplicate \"%s\" clause found at position %d", command.text, command.globalPosition)
			}

			offsetSnippet, err := getSnippet()
			if err != nil {
				return nil, err
			}
			offset, err := strconv.ParseUint(offsetSnippet.text, 10, 31)
			if err != nil {
				return nil, fmt.Errorf("could not parse integer (%s) at position %d", offsetSnippet.text, offsetSnippet.globalPosition)
			}

			q.Offset(int(offset))
		default:
			return nil, fmt.Errorf("unknown clause \"%s\" at position %d", command.text, command.globalPosition)
		}
	}

	return q.Check()
}

func extractSnippets(text string) (snippets []*snippet, err error) {
	skip := false
	start := -1
	inParenthesis := false
	var pos int
	var char rune

	for pos, char = range text {

		// skip
		if skip {
			skip = false
			continue
		}
		if char == '\\' {
			skip = true
		}

		// wait for parenthesis to be over
		if inParenthesis {
			if char == '"' {
				snippets = append(snippets, &snippet{
					text:           prepToken(text[start+1 : pos]),
					globalPosition: start + 1,
				})
				start = -1
				inParenthesis = false
			}
			continue
		}

		// handle segments
		switch char {
		case '\t', '\n', '\r', ' ', '(', ')':
			if start >= 0 {
				snippets = append(snippets, &snippet{
					text:           prepToken(text[start:pos]),
					globalPosition: start + 1,
				})
				start = -1
			}
		default:
			if start == -1 {
				start = pos
			}
		}

		// handle special segment characters
		switch char {
		case '(', ')':
			snippets = append(snippets, &snippet{
				text:           text[pos : pos+1],
				globalPosition: pos + 1,
			})
		case '"':
			if start < pos {
				return nil, fmt.Errorf("parenthesis ('\"') may not be used within words, please escape with '\\' (position: %d)", pos+1)
			}
			inParenthesis = true
		}

	}

	// add last
	if start >= 0 {
		snippets = append(snippets, &snippet{
			text:           prepToken(text[start : pos+1]),
			globalPosition: start + 1,
		})
	}

	return snippets, nil
}

//nolint:gocognit
func parseAndOr(getSnippet func() (*snippet, error), remainingSnippets func() int, rootCondition bool) (Condition, error) {
	var (
		isOr          = false
		typeSet       = false
		wrapInNot     = false
		expectingMore = true
		conditions    []Condition
	)

	for {
		if !expectingMore && rootCondition && remainingSnippets() == 0 {
			// advance snippetsPos by one, as it will be set back by 1
			_, _ = getSnippet()
			if len(conditions) == 1 {
				return conditions[0], nil
			}
			if isOr {
				return Or(conditions...), nil
			}
			return And(conditions...), nil
		}

		firstSnippet, err := getSnippet()
		if err != nil {
			return nil, err
		}

		if !expectingMore && rootCondition {
			switch firstSnippet.text {
			case "orderby", "limit", "offset":
				if len(conditions) == 1 {
					return conditions[0], nil
				}
				if isOr {
					return Or(conditions...), nil
				}
				return And(conditions...), nil
			}
		}

		switch firstSnippet.text {
		case "(":
			condition, err := parseAndOr(getSnippet, remainingSnippets, false)
			if err != nil {
				return nil, err
			}
			if wrapInNot {
				conditions = append(conditions, Not(condition))
				wrapInNot = false
			} else {
				conditions = append(conditions, condition)
			}
			expectingMore = true
		case ")":
			if len(conditions) == 1 {
				return conditions[0], nil
			}
			if isOr {
				return Or(conditions...), nil
			}
			return And(conditions...), nil
		case "and":
			if typeSet && isOr {
				return nil, fmt.Errorf("you may not mix \"and\" and \"or\" (position: %d)", firstSnippet.globalPosition)
			}
			isOr = false
			typeSet = true
			expectingMore = true
		case "or":
			if typeSet && !isOr {
				return nil, fmt.Errorf("you may not mix \"and\" and \"or\" (position: %d)", firstSnippet.globalPosition)
			}
			isOr = true
			typeSet = true
			expectingMore = true
		case "not":
			wrapInNot = true
			expectingMore = true
		default:
			condition, err := parseCondition(firstSnippet, getSnippet)
			if err != nil {
				return nil, err
			}
			if wrapInNot {
				conditions = append(conditions, Not(condition))
				wrapInNot = false
			} else {
				conditions = append(conditions, condition)
			}
			expectingMore = false
		}
	}
}

func parseCondition(firstSnippet *snippet, getSnippet func() (*snippet, error)) (Condition, error) {
	wrapInNot := false

	// get operator name
	opName, err := getSnippet()
	if err != nil {
		return nil, err
	}
	// negate?
	if opName.text == "not" {
		wrapInNot = true
		opName, err = getSnippet()
		if err != nil {
			return nil, err
		}
	}

	// get operator
	operator, ok := operatorNames[opName.text]
	if !ok {
		return nil, fmt.Errorf("unknown operator at position %d", opName.globalPosition)
	}

	// don't need a value for "exists"
	if operator == Exists {
		if wrapInNot {
			return Not(Where(firstSnippet.text, operator, nil)), nil
		}
		return Where(firstSnippet.text, operator, nil), nil
	}

	// get value
	value, err := getSnippet()
	if err != nil {
		return nil, err
	}
	if wrapInNot {
		return Not(Where(firstSnippet.text, operator, value.text)), nil
	}
	return Where(firstSnippet.text, operator, value.text), nil
}

var escapeReplacer = regexp.MustCompile(`\\([^\\])`)

// prepToken removes surrounding parenthesis and escape characters.
func prepToken(text string) string {
	return escapeReplacer.ReplaceAllString(strings.Trim(text, "\""), "$1")
}

// escapeString correctly escapes a snippet for printing.
func escapeString(token string) string {
	// check if token contains characters that need to be escaped
	if strings.ContainsAny(token, "()\"\\\t\r\n ") {
		// put the token in parenthesis and only escape \ and "
		return fmt.Sprintf("\"%s\"", strings.ReplaceAll(token, "\"", "\\\""))
	}
	return token
}
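Parser sketch (editor addition, not part of this diff): ParseQuery tokenizes the text into snippets, expects the literal word "query" followed by a database prefix, and then consumes where/orderby/limit/offset clauses; errors report the position of the offending snippet. A minimal call under the same import-path assumption as above:

package main

import (
	"fmt"

	"github.com/safing/portmaster/base/database/query"
)

func main() {
	q, err := query.ParseQuery(`query core:config/ where enabled is true orderby name limit 10`)
	if err != nil {
		fmt.Println("parse error:", err) // errors include the position in the query text
		return
	}
	fmt.Println(q.DatabaseName(), q.DatabaseKeyPrefix()) // expected: core config/
}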
base/database/query/parser_test.go (new file)
@@ -0,0 +1,177 @@
package query

import (
	"reflect"
	"testing"

	"github.com/davecgh/go-spew/spew"
)

func TestExtractSnippets(t *testing.T) {
	t.Parallel()

	text1 := `query test: where ( "bananas" > 100 and monkeys.# <= "12")or(coconuts < 10 "and" area > 50) or name sameas Julian or name matches ^King\ `
	result1 := []*snippet{
		{text: "query", globalPosition: 1},
		{text: "test:", globalPosition: 7},
		{text: "where", globalPosition: 13},
		{text: "(", globalPosition: 19},
		{text: "bananas", globalPosition: 21},
		{text: ">", globalPosition: 31},
		{text: "100", globalPosition: 33},
		{text: "and", globalPosition: 37},
		{text: "monkeys.#", globalPosition: 41},
		{text: "<=", globalPosition: 51},
		{text: "12", globalPosition: 54},
		{text: ")", globalPosition: 58},
		{text: "or", globalPosition: 59},
		{text: "(", globalPosition: 61},
		{text: "coconuts", globalPosition: 62},
		{text: "<", globalPosition: 71},
		{text: "10", globalPosition: 73},
		{text: "and", globalPosition: 76},
		{text: "area", globalPosition: 82},
		{text: ">", globalPosition: 87},
		{text: "50", globalPosition: 89},
		{text: ")", globalPosition: 91},
		{text: "or", globalPosition: 93},
		{text: "name", globalPosition: 96},
		{text: "sameas", globalPosition: 101},
		{text: "Julian", globalPosition: 108},
		{text: "or", globalPosition: 115},
		{text: "name", globalPosition: 118},
		{text: "matches", globalPosition: 123},
		{text: "^King ", globalPosition: 131},
	}

	snippets, err := extractSnippets(text1)
	if err != nil {
		t.Errorf("failed to extract snippets: %s", err)
	}

	if !reflect.DeepEqual(result1, snippets) {
		t.Errorf("unexpected results:")
		for _, el := range snippets {
			t.Errorf("%+v", el)
		}
	}

	// t.Error(spew.Sprintf("%v", treeElement))
}

func testParsing(t *testing.T, queryText string, expectedResult *Query) {
	t.Helper()

	_, err := expectedResult.Check()
	if err != nil {
		t.Errorf("failed to create query: %s", err)
		return
	}

	q, err := ParseQuery(queryText)
	if err != nil {
		t.Errorf("failed to parse query: %s", err)
		return
	}

	if queryText != q.Print() {
		t.Errorf("string match failed: %s", q.Print())
		return
	}
	if !reflect.DeepEqual(expectedResult, q) {
		t.Error("DeepEqual match failed.")
		t.Error("got:")
		t.Error(spew.Sdump(q))
		t.Error("expected:")
		t.Error(spew.Sdump(expectedResult))
	}
}

func TestParseQuery(t *testing.T) {
	t.Parallel()

	text1 := `query test: where (bananas > 100 and monkeys.# <= 12) or not (coconuts < 10 and area not > 50) or name sameas Julian or name matches "^King " orderby name limit 10 offset 20`
	result1 := New("test:").Where(Or(
		And(
			Where("bananas", GreaterThan, 100),
			Where("monkeys.#", LessThanOrEqual, 12),
		),
		Not(And(
			Where("coconuts", LessThan, 10),
			Not(Where("area", GreaterThan, 50)),
		)),
		Where("name", SameAs, "Julian"),
		Where("name", Matches, "^King "),
	)).OrderBy("name").Limit(10).Offset(20)
	testParsing(t, text1, result1)

	testParsing(t, `query test: orderby name`, New("test:").OrderBy("name"))
	testParsing(t, `query test: limit 10`, New("test:").Limit(10))
	testParsing(t, `query test: offset 10`, New("test:").Offset(10))
	testParsing(t, `query test: where banana matches ^ban`, New("test:").Where(Where("banana", Matches, "^ban")))
	testParsing(t, `query test: where banana exists`, New("test:").Where(Where("banana", Exists, nil)))
	testParsing(t, `query test: where banana not exists`, New("test:").Where(Not(Where("banana", Exists, nil))))

	// test all operators
	testParsing(t, `query test: where banana == 1`, New("test:").Where(Where("banana", Equals, 1)))
	testParsing(t, `query test: where banana > 1`, New("test:").Where(Where("banana", GreaterThan, 1)))
	testParsing(t, `query test: where banana >= 1`, New("test:").Where(Where("banana", GreaterThanOrEqual, 1)))
	testParsing(t, `query test: where banana < 1`, New("test:").Where(Where("banana", LessThan, 1)))
	testParsing(t, `query test: where banana <= 1`, New("test:").Where(Where("banana", LessThanOrEqual, 1)))
	testParsing(t, `query test: where banana f== 1.1`, New("test:").Where(Where("banana", FloatEquals, 1.1)))
	testParsing(t, `query test: where banana f> 1.1`, New("test:").Where(Where("banana", FloatGreaterThan, 1.1)))
	testParsing(t, `query test: where banana f>= 1.1`, New("test:").Where(Where("banana", FloatGreaterThanOrEqual, 1.1)))
	testParsing(t, `query test: where banana f< 1.1`, New("test:").Where(Where("banana", FloatLessThan, 1.1)))
	testParsing(t, `query test: where banana f<= 1.1`, New("test:").Where(Where("banana", FloatLessThanOrEqual, 1.1)))
	testParsing(t, `query test: where banana sameas banana`, New("test:").Where(Where("banana", SameAs, "banana")))
	testParsing(t, `query test: where banana contains banana`, New("test:").Where(Where("banana", Contains, "banana")))
	testParsing(t, `query test: where banana startswith banana`, New("test:").Where(Where("banana", StartsWith, "banana")))
	testParsing(t, `query test: where banana endswith banana`, New("test:").Where(Where("banana", EndsWith, "banana")))
	testParsing(t, `query test: where banana in banana,coconut`, New("test:").Where(Where("banana", In, []string{"banana", "coconut"})))
	testParsing(t, `query test: where banana matches banana`, New("test:").Where(Where("banana", Matches, "banana")))
	testParsing(t, `query test: where banana is true`, New("test:").Where(Where("banana", Is, true)))
	testParsing(t, `query test: where banana exists`, New("test:").Where(Where("banana", Exists, nil)))

	// special
	testParsing(t, `query test: where banana not exists`, New("test:").Where(Not(Where("banana", Exists, nil))))
}

func testParseError(t *testing.T, queryText string, expectedErrorString string) {
	t.Helper()

	_, err := ParseQuery(queryText)
	if err == nil {
		t.Errorf("should fail to parse: %s", queryText)
		return
	}
	if err.Error() != expectedErrorString {
		t.Errorf("unexpected error for query: %s\nwanted: %s\n   got: %s", queryText, expectedErrorString, err)
	}
}

func TestParseErrors(t *testing.T) {
	t.Parallel()

	// syntax
	testParseError(t, `query`, `unexpected end at position 5`)
	testParseError(t, `query test: where`, `unexpected end at position 17`)
	testParseError(t, `query test: where (`, `unexpected end at position 19`)
	testParseError(t, `query test: where )`, `unknown clause ")" at position 19`)
	testParseError(t, `query test: where not`, `unexpected end at position 21`)
	testParseError(t, `query test: where banana`, `unexpected end at position 24`)
	testParseError(t, `query test: where banana >`, `unexpected end at position 26`)
	testParseError(t, `query test: where banana nope`, `unknown operator at position 26`)
	testParseError(t, `query test: where banana exists or`, `unexpected end at position 34`)
	testParseError(t, `query test: where banana exists and`, `unexpected end at position 35`)
	testParseError(t, `query test: where banana exists and (`, `unexpected end at position 37`)
	testParseError(t, `query test: where banana exists and banana is true or`, `you may not mix "and" and "or" (position: 52)`)
	testParseError(t, `query test: where banana exists or banana is true and`, `you may not mix "and" and "or" (position: 51)`)
	// testParseError(t, `query test: where banana exists and (`, ``)

	// value parsing error
	testParseError(t, `query test: where banana == banana`, `could not parse banana to int64: strconv.ParseInt: parsing "banana": invalid syntax (hint: use "sameas" to compare strings)`)
	testParseError(t, `query test: where banana f== banana`, `could not parse banana to float64: strconv.ParseFloat: parsing "banana": invalid syntax`)
	testParseError(t, `query test: where banana in banana`, `could not parse "banana" to []string`)
	testParseError(t, `query test: where banana matches [banana`, "could not compile regex \"[banana\": error parsing regexp: missing closing ]: `[banana`")
	testParseError(t, `query test: where banana is great`, `could not parse "great" to bool: strconv.ParseBool: parsing "great": invalid syntax`)
}
base/database/query/query.go (new file)
@@ -0,0 +1,170 @@
package query

import (
	"fmt"
	"strings"

	"github.com/safing/portmaster/base/database/accessor"
	"github.com/safing/portmaster/base/database/record"
)

// Example:
// q.New("core:/").Where(
// 	q.And(
// 		q.Where("a", q.GreaterThan, 0),
// 		q.Where("b", q.Equals, 0),
// 		q.Or(
// 			q.Where("c", q.StartsWith, "x"),
// 			q.Where("d", q.Contains, "y"),
// 		),
// 	),
// )

// Query contains a compiled query.
type Query struct {
	checked     bool
	dbName      string
	dbKeyPrefix string
	where       Condition
	orderBy     string
	limit       int
	offset      int
}

// New creates a new query with the supplied prefix.
func New(prefix string) *Query {
	dbName, dbKeyPrefix := record.ParseKey(prefix)
	return &Query{
		dbName:      dbName,
		dbKeyPrefix: dbKeyPrefix,
	}
}

// Where adds filtering.
func (q *Query) Where(condition Condition) *Query {
	q.where = condition
	return q
}

// Limit limits the number of returned results.
func (q *Query) Limit(limit int) *Query {
	q.limit = limit
	return q
}

// Offset sets the query offset.
func (q *Query) Offset(offset int) *Query {
	q.offset = offset
	return q
}

// OrderBy orders the results by the given key.
func (q *Query) OrderBy(key string) *Query {
	q.orderBy = key
	return q
}

// Check checks for errors in the query.
func (q *Query) Check() (*Query, error) {
	if q.checked {
		return q, nil
	}

	// check condition
	if q.where != nil {
		err := q.where.check()
		if err != nil {
			return nil, err
		}
	}

	q.checked = true
	return q, nil
}

// MustBeValid checks for errors in the query and panics if there is an error.
func (q *Query) MustBeValid() *Query {
	_, err := q.Check()
	if err != nil {
		panic(err)
	}
	return q
}

// IsChecked returns whether the query was checked.
func (q *Query) IsChecked() bool {
	return q.checked
}

// MatchesKey checks whether the query matches the supplied database key (key without database prefix).
func (q *Query) MatchesKey(dbKey string) bool {
	return strings.HasPrefix(dbKey, q.dbKeyPrefix)
}

// MatchesRecord checks whether the query matches the supplied database record (value only).
func (q *Query) MatchesRecord(r record.Record) bool {
	if q.where == nil {
		return true
	}

	acc := r.GetAccessor(r)
	if acc == nil {
		return false
	}
	return q.where.complies(acc)
}

// MatchesAccessor checks whether the query matches the supplied accessor (value only).
func (q *Query) MatchesAccessor(acc accessor.Accessor) bool {
	if q.where == nil {
		return true
	}
	return q.where.complies(acc)
}

// Matches checks whether the query matches the supplied database record.
func (q *Query) Matches(r record.Record) bool {
	if !q.MatchesKey(r.DatabaseKey()) {
		return false
	}
	return q.MatchesRecord(r)
}

// Print returns the string representation of the query.
func (q *Query) Print() string {
	var where string
	if q.where != nil {
		where = q.where.string()
		if where != "" {
			if strings.HasPrefix(where, "(") {
				where = where[1 : len(where)-1]
			}
			where = fmt.Sprintf(" where %s", where)
		}
	}

	var orderBy string
	if q.orderBy != "" {
		orderBy = fmt.Sprintf(" orderby %s", q.orderBy)
	}

	var limit string
	if q.limit > 0 {
		limit = fmt.Sprintf(" limit %d", q.limit)
	}

	var offset string
	if q.offset > 0 {
		offset = fmt.Sprintf(" offset %d", q.offset)
	}

	return fmt.Sprintf("query %s:%s%s%s%s%s", q.dbName, q.dbKeyPrefix, where, orderBy, limit, offset)
}

// DatabaseName returns the name of the database.
func (q *Query) DatabaseName() string {
	return q.dbName
}

// DatabaseKeyPrefix returns the key prefix for the database.
func (q *Query) DatabaseKeyPrefix() string {
	return q.dbKeyPrefix
}
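Builder sketch (editor addition, not part of this diff): the same query text can be assembled programmatically with the fluent methods defined above and validated once with Check (or MustBeValid). Import path assumed as in the other sketches.

package main

import (
	"fmt"

	q "github.com/safing/portmaster/base/database/query"
)

func main() {
	qry, err := q.New("core:/").
		Where(q.And(
			q.Where("a", q.GreaterThan, 0),
			q.Or(
				q.Where("c", q.StartsWith, "x"),
				q.Where("d", q.Contains, "y"),
			),
		)).
		OrderBy("a").
		Limit(100).
		Check()
	if err != nil {
		panic(err)
	}
	fmt.Println(qry.Print()) // prints the canonical text form of the query
}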
base/database/query/query_test.go (new file)
@@ -0,0 +1,113 @@
//nolint:unparam
package query

import (
	"testing"

	"github.com/safing/portmaster/base/database/record"
	"github.com/safing/structures/dsd"
)

// copied from https://github.com/tidwall/gjson/blob/master/gjson_test.go
var testJSON = `{"age":100, "name":{"here":"B\\\"R"},
	"noop":{"what is a wren?":"a bird"},
	"happy":true,"immortal":false,
	"items":[1,2,3,{"tags":[1,2,3],"points":[[1,2],[3,4]]},4,5,6,7],
	"arr":["1",2,"3",{"hello":"world"},"4",5],
	"vals":[1,2,3,{"sadf":sdf"asdf"}],"name":{"first":"tom","last":null},
	"created":"2014-05-16T08:28:06.989Z",
	"loggy":{
		"programmers": [
			{
				"firstName": "Brett",
				"lastName": "McLaughlin",
				"email": "aaaa",
				"tag": "good"
			},
			{
				"firstName": "Jason",
				"lastName": "Hunter",
				"email": "bbbb",
				"tag": "bad"
			},
			{
				"firstName": "Elliotte",
				"lastName": "Harold",
				"email": "cccc",
				"tag":, "good"
			},
			{
				"firstName": 1002.3,
				"age": 101
			}
		]
	},
	"lastly":{"yay":"final"},
	"temperature": 120.413
}`

func testQuery(t *testing.T, r record.Record, shouldMatch bool, condition Condition) {
	t.Helper()

	q := New("test:").Where(condition).MustBeValid()
	// fmt.Printf("%s\n", q.Print())

	matched := q.Matches(r)
	switch {
	case !matched && shouldMatch:
		t.Errorf("should match: %s", q.Print())
	case matched && !shouldMatch:
		t.Errorf("should not match: %s", q.Print())
	}
}

func TestQuery(t *testing.T) {
	t.Parallel()

	// if !gjson.Valid(testJSON) {
	// 	t.Fatal("test json is invalid")
	// }
	r, err := record.NewWrapper("", nil, dsd.JSON, []byte(testJSON))
	if err != nil {
		t.Fatal(err)
	}

	testQuery(t, r, true, Where("age", Equals, 100))
	testQuery(t, r, true, Where("age", GreaterThan, uint8(99)))
	testQuery(t, r, true, Where("age", GreaterThanOrEqual, 99))
	testQuery(t, r, true, Where("age", GreaterThanOrEqual, 100))
	testQuery(t, r, true, Where("age", LessThan, 101))
	testQuery(t, r, true, Where("age", LessThanOrEqual, "101"))
	testQuery(t, r, true, Where("age", LessThanOrEqual, 100))

	testQuery(t, r, true, Where("temperature", FloatEquals, 120.413))
	testQuery(t, r, true, Where("temperature", FloatGreaterThan, 120))
	testQuery(t, r, true, Where("temperature", FloatGreaterThanOrEqual, 120))
	testQuery(t, r, true, Where("temperature", FloatGreaterThanOrEqual, 120.413))
	testQuery(t, r, true, Where("temperature", FloatLessThan, 121))
	testQuery(t, r, true, Where("temperature", FloatLessThanOrEqual, "121"))
	testQuery(t, r, true, Where("temperature", FloatLessThanOrEqual, "120.413"))

	testQuery(t, r, true, Where("lastly.yay", SameAs, "final"))
	testQuery(t, r, true, Where("lastly.yay", Contains, "ina"))
	testQuery(t, r, true, Where("lastly.yay", StartsWith, "fin"))
	testQuery(t, r, true, Where("lastly.yay", EndsWith, "nal"))
	testQuery(t, r, true, Where("lastly.yay", In, "draft,final"))
	testQuery(t, r, true, Where("lastly.yay", In, "final,draft"))

	testQuery(t, r, true, Where("happy", Is, true))
	testQuery(t, r, true, Where("happy", Is, "true"))
	testQuery(t, r, true, Where("happy", Is, "t"))
	testQuery(t, r, true, Not(Where("happy", Is, "0")))
	testQuery(t, r, true, And(
		Where("happy", Is, "1"),
		Not(Or(
			Where("happy", Is, false),
			Where("happy", Is, "f"),
		)),
	))

	testQuery(t, r, true, Where("happy", Exists, nil))

	testQuery(t, r, true, Where("created", Matches, "^2014-[0-9]{2}-[0-9]{2}T"))
}
base/database/record/base.go (new file)
@@ -0,0 +1,156 @@
package record

import (
	"errors"

	"github.com/safing/portmaster/base/database/accessor"
	"github.com/safing/portmaster/base/log"
	"github.com/safing/structures/container"
	"github.com/safing/structures/dsd"
)

// TODO(ppacher):
// we can reduce the record.Record interface a lot by moving
// most of those functions that require the Record as its first
// parameter to static package functions
// (i.e. Marshal, MarshalRecord, GetAccessor, ...).
// We should also consider giving Base a GetBase() *Base method
// that returns itself. This way we can remove almost all Base
// only methods from the record.Record interface. That is, we can
// remove all those CreateMeta, UpdateMeta, ... stuff from the
// interface definition (not the actual functions!). This would make
// the record.Record interface slim and only provide methods that
// most users actually need. All those database/storage related methods
// can still be accessed by using GetBase().XXX() instead. We can also
// expose the dbName and dbKey and meta properties directly which would
// make a nice JSON blob when marshalled.

// Base provides a quick way to comply with the Model interface.
type Base struct {
	dbName string
	dbKey  string
	meta   *Meta
}

// SetKey sets the key on the database record. The key may only be set once and
// future calls to SetKey will be ignored. If you want to copy/move the record
// to another database key, you will need to create a copy and assign a new key.
// A key must be set before the record is used in any database operation.
func (b *Base) SetKey(key string) {
	if !b.KeyIsSet() {
		b.dbName, b.dbKey = ParseKey(key)
	} else {
		log.Errorf("database: key is already set: tried to replace %q with %q", b.Key(), key)
	}
}

// ResetKey resets the database name and key.
// Use with caution!
func (b *Base) ResetKey() {
	b.dbName = ""
	b.dbKey = ""
}

// Key returns the key of the database record.
// As the key must be set before any usage and can only be set once, this
// function may be used without locking the record.
func (b *Base) Key() string {
	return b.dbName + ":" + b.dbKey
}

// KeyIsSet returns true if the database key is set.
// As the key must be set before any usage and can only be set once, this
// function may be used without locking the record.
func (b *Base) KeyIsSet() bool {
	return b.dbName != ""
}

// DatabaseName returns the name of the database.
// As the key must be set before any usage and can only be set once, this
// function may be used without locking the record.
func (b *Base) DatabaseName() string {
	return b.dbName
}

// DatabaseKey returns the database key of the database record.
// As the key must be set before any usage and can only be set once, this
// function may be used without locking the record.
func (b *Base) DatabaseKey() string {
	return b.dbKey
}

// Meta returns the metadata object for this record.
func (b *Base) Meta() *Meta {
	return b.meta
}

// CreateMeta sets a default metadata object for this record.
func (b *Base) CreateMeta() {
	b.meta = &Meta{}
}

// UpdateMeta creates the metadata if it does not exist and updates it.
func (b *Base) UpdateMeta() {
	if b.meta == nil {
		b.CreateMeta()
	}
	b.meta.Update()
}

// SetMeta sets the metadata on the database record, it should only be called after loading the record. Use MoveTo to save the record with another key.
func (b *Base) SetMeta(meta *Meta) {
	b.meta = meta
}

// Marshal marshals the object, without the database key or metadata. It returns nil if the record is deleted.
func (b *Base) Marshal(self Record, format uint8) ([]byte, error) {
	if b.Meta() == nil {
		return nil, errors.New("missing meta")
	}

	if b.Meta().Deleted > 0 {
		return nil, nil
	}

	dumped, err := dsd.Dump(self, format)
	if err != nil {
		return nil, err
	}
	return dumped, nil
}

// MarshalRecord packs the object, including metadata, into a byte array for saving in a database.
func (b *Base) MarshalRecord(self Record) ([]byte, error) {
	if b.Meta() == nil {
		return nil, errors.New("missing meta")
	}

	// version
	c := container.New([]byte{1})

	// meta encoding
	metaSection, err := dsd.Dump(b.meta, dsd.GenCode)
	if err != nil {
		return nil, err
	}
	c.AppendAsBlock(metaSection)

	// data
	dataSection, err := b.Marshal(self, dsd.JSON)
	if err != nil {
		return nil, err
	}
	c.Append(dataSection)

	return c.CompileData(), nil
}

// IsWrapped returns whether the record is a Wrapper.
func (b *Base) IsWrapped() bool {
	return false
}

// GetAccessor returns an accessor for this record, if available.
func (b *Base) GetAccessor(self Record) accessor.Accessor {
	return accessor.NewStructAccessor(self)
}
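Base sketch (editor addition, not part of this diff): Base is meant to be embedded by concrete record types; on its own it already enforces the set-once key and lazy metadata shown above. The expected outputs below assume ParseKey splits on the first colon, as suggested by Key(), and the import path follows base.go's own imports.

package main

import (
	"fmt"

	"github.com/safing/portmaster/base/database/record"
)

func main() {
	var b record.Base

	b.SetKey("core:config/example") // splits into database name and database key
	b.SetKey("other:key")           // ignored; SetKey only logs an error on a second call

	fmt.Println(b.Key())          // expected: core:config/example
	fmt.Println(b.DatabaseName()) // expected: core
	fmt.Println(b.DatabaseKey())  // expected: config/example

	b.UpdateMeta()               // creates the Meta object on first use, then updates it
	fmt.Println(b.Meta() != nil) // expected: true
}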
Some files were not shown because too many files have changed in this diff.