mirror of
https://github.com/safing/portbase
synced 2025-04-23 02:39:09 +00:00
335 lines
8.1 KiB
Go
335 lines
8.1 KiB
Go
package runtime
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/armon/go-radix"
|
|
"golang.org/x/sync/errgroup"
|
|
|
|
"github.com/safing/portbase/database"
|
|
"github.com/safing/portbase/database/iterator"
|
|
"github.com/safing/portbase/database/query"
|
|
"github.com/safing/portbase/database/record"
|
|
"github.com/safing/portbase/database/storage"
|
|
"github.com/safing/portbase/log"
|
|
)
|
|
|
|
// Sentinel errors reported by the runtime registry.
var (
	// ErrKeyTaken signals that a value provider could not be
	// registered because the requested database key or prefix
	// is already occupied by another provider.
	ErrKeyTaken = errors.New("runtime key or prefix already used")

	// ErrKeyUnmanaged signals that an operation (Put or Query)
	// addressed a key that no registered provider is responsible for.
	ErrKeyUnmanaged = errors.New("runtime key not managed by any provider")

	// ErrInjected signals that Registry.InjectAsDatabase was called
	// on a registry that has already been injected.
	ErrInjected = errors.New("registry already injected")
)
|
|
|
|
// Registry keeps track of registered runtime
|
|
// value providers and exposes them via an
|
|
// injected database. Users normally just need
|
|
// to use the defaul registry provided by this
|
|
// package but may consider creating a dedicated
|
|
// runtime registry on their own. Registry uses
|
|
// a radix tree for value providers and their
|
|
// chosen database key/prefix.
|
|
type Registry struct {
|
|
l sync.RWMutex
|
|
providers *radix.Tree
|
|
dbController *database.Controller
|
|
dbName string
|
|
}
|
|
|
|
// keyedValueProvider simply wraps a value provider with it's
|
|
// registration prefix.
|
|
type keyedValueProvider struct {
|
|
ValueProvider
|
|
key string
|
|
}
|
|
|
|
// NewRegistry returns a new registry.
|
|
func NewRegistry() *Registry {
|
|
return &Registry{
|
|
providers: radix.New(),
|
|
}
|
|
}
|
|
|
|
// isPrefixKey reports whether key denotes a prefix registration,
// that is, whether it ends in a trailing slash.
func isPrefixKey(key string) bool {
	const prefixMarker = "/"

	return strings.HasSuffix(key, prefixMarker)
}
|
|
|
|
// DatabaseName returns the name of the database where the
|
|
// registry has been injected. It returns an empty string
|
|
// if InjectAsDatabase has not been called.
|
|
func (r *Registry) DatabaseName() string {
|
|
r.l.RLock()
|
|
defer r.l.RUnlock()
|
|
|
|
return r.dbName
|
|
}
|
|
|
|
// InjectAsDatabase injects the registry as the storage
|
|
// database for name.
|
|
func (r *Registry) InjectAsDatabase(name string) error {
|
|
r.l.Lock()
|
|
defer r.l.Unlock()
|
|
|
|
if r.dbController != nil {
|
|
return ErrInjected
|
|
}
|
|
|
|
ctrl, err := database.InjectDatabase(name, r.asStorage())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
r.dbName = name
|
|
r.dbController = ctrl
|
|
|
|
return nil
|
|
}
|
|
|
|
// Register registers a new value provider p under keyOrPrefix. The
|
|
// returned PushFunc can be used to send update notitifcations to
|
|
// database subscribers. Note that keyOrPrefix must end in '/' to be
|
|
// accepted as a prefix.
|
|
func (r *Registry) Register(keyOrPrefix string, p ValueProvider) (PushFunc, error) {
|
|
r.l.Lock()
|
|
defer r.l.Unlock()
|
|
|
|
// search if there's a provider registered for a prefix
|
|
// that matches or is equal to keyOrPrefix.
|
|
key, _, ok := r.providers.LongestPrefix(keyOrPrefix)
|
|
if ok && (isPrefixKey(key) || key == keyOrPrefix) {
|
|
return nil, fmt.Errorf("%w: found provider on %s", ErrKeyTaken, key)
|
|
}
|
|
|
|
// if keyOrPrefix is a prefix there must not be any provider
|
|
// registered for a key that matches keyOrPrefix.
|
|
if isPrefixKey(keyOrPrefix) {
|
|
foundProvider := ""
|
|
r.providers.WalkPrefix(keyOrPrefix, func(s string, _ interface{}) bool {
|
|
foundProvider = s
|
|
return true
|
|
})
|
|
if foundProvider != "" {
|
|
return nil, fmt.Errorf("%w: found provider on %s", ErrKeyTaken, foundProvider)
|
|
}
|
|
}
|
|
|
|
r.providers.Insert(keyOrPrefix, &keyedValueProvider{
|
|
ValueProvider: TraceProvider(p),
|
|
key: keyOrPrefix,
|
|
})
|
|
|
|
log.Tracef("runtime: registered new provider at %s", keyOrPrefix)
|
|
|
|
return func(records ...record.Record) {
|
|
r.l.RLock()
|
|
defer r.l.RUnlock()
|
|
|
|
if r.dbController == nil {
|
|
return
|
|
}
|
|
|
|
for _, rec := range records {
|
|
r.dbController.PushUpdate(rec)
|
|
}
|
|
}, nil
|
|
}
|
|
|
|
// Get returns the runtime value that is identified by key.
|
|
// It implements the storage.Interface.
|
|
func (r *Registry) Get(key string) (record.Record, error) {
|
|
provider := r.getMatchingProvider(key)
|
|
if provider == nil {
|
|
return nil, database.ErrNotFound
|
|
}
|
|
|
|
records, err := provider.Get(key)
|
|
if err != nil {
|
|
// instead of returning ErrWriteOnly to the database interface
|
|
// we wrap it in ErrNotFound so the records effectively gets
|
|
// hidden.
|
|
if errors.Is(err, ErrWriteOnly) {
|
|
return nil, database.ErrNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
// Get performs an exact match so filter out
|
|
// and values that do not match key.
|
|
for _, r := range records {
|
|
if r.DatabaseKey() == key {
|
|
return r, nil
|
|
}
|
|
}
|
|
|
|
return nil, database.ErrNotFound
|
|
}
|
|
|
|
// Put stores the record m in the runtime database. Note that
|
|
// ErrReadOnly is returned if there's no value provider responsible
|
|
// for m.Key().
|
|
func (r *Registry) Put(m record.Record) (record.Record, error) {
|
|
provider := r.getMatchingProvider(m.DatabaseKey())
|
|
if provider == nil {
|
|
// if there's no provider for the given value
|
|
// return ErrKeyUnmanaged.
|
|
return nil, ErrKeyUnmanaged
|
|
}
|
|
|
|
res, err := provider.Set(m)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return res, nil
|
|
}
|
|
|
|
// Query performs a query on the runtime registry returning all
|
|
// records across all value providers that match q.
|
|
// Query implements the storage.Storage interface.
|
|
func (r *Registry) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) {
|
|
if _, err := q.Check(); err != nil {
|
|
return nil, fmt.Errorf("invalid query: %w", err)
|
|
}
|
|
|
|
searchPrefix := q.DatabaseKeyPrefix()
|
|
providers := r.collectProviderByPrefix(searchPrefix)
|
|
if len(providers) == 0 {
|
|
return nil, fmt.Errorf("%w: for key %s", ErrKeyUnmanaged, searchPrefix)
|
|
}
|
|
|
|
iter := iterator.New()
|
|
|
|
grp := new(errgroup.Group)
|
|
for idx := range providers {
|
|
p := providers[idx]
|
|
|
|
grp.Go(func() (err error) {
|
|
defer recovery(&err)
|
|
|
|
key := p.key
|
|
if len(searchPrefix) > len(key) {
|
|
key = searchPrefix
|
|
}
|
|
|
|
records, err := p.Get(key)
|
|
if err != nil {
|
|
if errors.Is(err, ErrWriteOnly) {
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
|
|
for _, r := range records {
|
|
r.Lock()
|
|
var (
|
|
matchesKey = q.MatchesKey(r.DatabaseKey())
|
|
isValid = r.Meta().CheckValidity()
|
|
isAllowed = r.Meta().CheckPermission(local, internal)
|
|
|
|
allowed = matchesKey && isValid && isAllowed
|
|
)
|
|
if allowed {
|
|
allowed = q.MatchesRecord(r)
|
|
}
|
|
r.Unlock()
|
|
|
|
if !allowed {
|
|
log.Tracef("runtime: not sending %s for query %s. matchesKey=%v isValid=%v isAllowed=%v", r.DatabaseKey(), searchPrefix, matchesKey, isValid, isAllowed)
|
|
continue
|
|
}
|
|
|
|
select {
|
|
case iter.Next <- r:
|
|
case <-iter.Done:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return nil
|
|
})
|
|
}
|
|
|
|
go func() {
|
|
err := grp.Wait()
|
|
iter.Finish(err)
|
|
}()
|
|
|
|
return iter, nil
|
|
}
|
|
|
|
func (r *Registry) getMatchingProvider(key string) *keyedValueProvider {
|
|
r.l.RLock()
|
|
defer r.l.RUnlock()
|
|
|
|
providerKey, provider, ok := r.providers.LongestPrefix(key)
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
if !isPrefixKey(providerKey) && providerKey != key {
|
|
return nil
|
|
}
|
|
|
|
return provider.(*keyedValueProvider) //nolint:forcetypeassert
|
|
}
|
|
|
|
func (r *Registry) collectProviderByPrefix(prefix string) []*keyedValueProvider {
|
|
r.l.RLock()
|
|
defer r.l.RUnlock()
|
|
|
|
// if there's a LongestPrefix provider that's the only one
|
|
// we need to ask
|
|
if _, p, ok := r.providers.LongestPrefix(prefix); ok {
|
|
return []*keyedValueProvider{p.(*keyedValueProvider)} //nolint:forcetypeassert
|
|
}
|
|
|
|
var providers []*keyedValueProvider
|
|
r.providers.WalkPrefix(prefix, func(key string, p interface{}) bool {
|
|
providers = append(providers, p.(*keyedValueProvider)) //nolint:forcetypeassert
|
|
return false
|
|
})
|
|
|
|
return providers
|
|
}
|
|
|
|
// GetRegistrationKeys returns a list of all provider registration
|
|
// keys or prefixes.
|
|
func (r *Registry) GetRegistrationKeys() []string {
|
|
r.l.RLock()
|
|
defer r.l.RUnlock()
|
|
|
|
var keys []string
|
|
|
|
r.providers.Walk(func(key string, p interface{}) bool {
|
|
keys = append(keys, key)
|
|
return false
|
|
})
|
|
return keys
|
|
}
|
|
|
|
// asStorage returns a storage.Interface compatible struct
|
|
// that is backed by r.
|
|
func (r *Registry) asStorage() storage.Interface {
|
|
return &storageWrapper{
|
|
Registry: r,
|
|
}
|
|
}
|
|
|
|
// recovery converts a panic into an error stored in *err. It must
// be deferred directly (defer recovery(&err)) for recover to take
// effect. A recovered error value is stored as is, any other value
// is formatted with %v. Without a panic *err is left untouched.
func recovery(err *error) {
	switch cause := recover().(type) {
	case nil:
		// no panic in flight, nothing to report.
	case error:
		*err = cause
	default:
		*err = fmt.Errorf("%v", cause)
	}
}
|