Merge pull request #209 from safing/feature/database-allow-custom-interface

Add database custom interface functions
This commit is contained in:
Daniel Hovie 2023-07-20 14:51:24 +02:00 committed by GitHub
commit 29ac7d1aae
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 187 additions and 142 deletions

View file

@ -7,6 +7,7 @@ linters:
- containedctx
- contextcheck
- cyclop
- depguard
- exhaustivestruct
- exhaustruct
- forbidigo
@ -22,6 +23,7 @@ linters:
- interfacer
- ireturn
- lll
- musttag
- nestif
- nilnil
- nlreturn

View file

@ -44,7 +44,7 @@ var (
func init() {
RegisterHandler("/api/database/v1", WrapInAuthHandler(
startDatabaseAPI,
startDatabaseWebsocketAPI,
// Default to admin read/write permissions until the database gets support
// for api permissions.
dbCompatibilityPermission,
@ -52,11 +52,8 @@ func init() {
))
}
// DatabaseAPI is a database API instance.
// DatabaseAPI is a generic database API interface.
type DatabaseAPI struct {
conn *websocket.Conn
sendQueue chan []byte
queriesLock sync.Mutex
queries map[string]*iterator.Iterator
@ -66,13 +63,35 @@ type DatabaseAPI struct {
shutdownSignal chan struct{}
shuttingDown *abool.AtomicBool
db *database.Interface
sendBytes func(data []byte)
}
// DatabaseWebsocketAPI is a database websocket API interface.
type DatabaseWebsocketAPI struct {
DatabaseAPI
sendQueue chan []byte
conn *websocket.Conn
}
func allowAnyOrigin(r *http.Request) bool {
return true
}
func startDatabaseAPI(w http.ResponseWriter, r *http.Request) {
// CreateDatabaseAPI creates a new database interface.
func CreateDatabaseAPI(sendFunction func(data []byte)) DatabaseAPI {
return DatabaseAPI{
queries: make(map[string]*iterator.Iterator),
subs: make(map[string]*database.Subscription),
shutdownSignal: make(chan struct{}),
shuttingDown: abool.NewBool(false),
db: database.NewInterface(nil),
sendBytes: sendFunction,
}
}
func startDatabaseWebsocketAPI(w http.ResponseWriter, r *http.Request) {
upgrader := websocket.Upgrader{
CheckOrigin: allowAnyOrigin,
ReadBufferSize: 1024,
@ -86,14 +105,21 @@ func startDatabaseAPI(w http.ResponseWriter, r *http.Request) {
return
}
newDBAPI := &DatabaseAPI{
conn: wsConn,
sendQueue: make(chan []byte, 100),
newDBAPI := &DatabaseWebsocketAPI{
DatabaseAPI: DatabaseAPI{
queries: make(map[string]*iterator.Iterator),
subs: make(map[string]*database.Subscription),
shutdownSignal: make(chan struct{}),
shuttingDown: abool.NewBool(false),
db: database.NewInterface(nil),
},
sendQueue: make(chan []byte, 100),
conn: wsConn,
}
newDBAPI.sendBytes = func(data []byte) {
newDBAPI.sendQueue <- data
}
module.StartWorker("database api handler", newDBAPI.handler)
@ -102,11 +128,77 @@ func startDatabaseAPI(w http.ResponseWriter, r *http.Request) {
log.Tracer(r.Context()).Infof("api request: init websocket %s %s", r.RemoteAddr, r.RequestURI)
}
func (api *DatabaseAPI) handler(context.Context) error {
func (api *DatabaseWebsocketAPI) handler(context.Context) error {
defer func() {
_ = api.shutdown(nil)
}()
for {
_, msg, err := api.conn.ReadMessage()
if err != nil {
return api.shutdown(err)
}
api.Handle(msg)
}
}
func (api *DatabaseWebsocketAPI) writer(ctx context.Context) error {
defer func() {
_ = api.shutdown(nil)
}()
var data []byte
var err error
for {
select {
// prioritize direct writes
case data = <-api.sendQueue:
if len(data) == 0 {
return nil
}
case <-ctx.Done():
return nil
case <-api.shutdownSignal:
return nil
}
// log.Tracef("api: sending %s", string(*msg))
err = api.conn.WriteMessage(websocket.BinaryMessage, data)
if err != nil {
return api.shutdown(err)
}
}
}
func (api *DatabaseWebsocketAPI) shutdown(err error) error {
// Check if we are the first to shut down.
if !api.shuttingDown.SetToIf(false, true) {
return nil
}
// Check the given error.
if err != nil {
if websocket.IsCloseError(err,
websocket.CloseNormalClosure,
websocket.CloseGoingAway,
websocket.CloseAbnormalClosure,
) {
log.Infof("api: websocket connection to %s closed", api.conn.RemoteAddr())
} else {
log.Warningf("api: websocket connection error with %s: %s", api.conn.RemoteAddr(), err)
}
}
// Trigger shutdown.
close(api.shutdownSignal)
_ = api.conn.Close()
return nil
}
// Handle handles a message for the database API.
func (api *DatabaseAPI) Handle(msg []byte) {
// 123|get|<key>
// 123|ok|<key>|<data>
// 123|error|<message>
@ -145,13 +237,6 @@ func (api *DatabaseAPI) handler(context.Context) error {
// 131|success
// 131|error|<message>
for {
_, msg, err := api.conn.ReadMessage()
if err != nil {
return api.shutdown(err)
}
parts := bytes.SplitN(msg, []byte("|"), 3)
// Handle special command "cancel"
@ -160,12 +245,12 @@ func (api *DatabaseAPI) handler(context.Context) error {
// 125|cancel
// 127|cancel
go api.handleCancel(parts[0])
continue
return
}
if len(parts) != 3 {
api.send(nil, dbMsgTypeError, "bad request: malformed message", nil)
continue
return
}
switch string(parts[1]) {
@ -186,7 +271,7 @@ func (api *DatabaseAPI) handler(context.Context) error {
dataParts := bytes.SplitN(parts[2], []byte("|"), 2)
if len(dataParts) != 2 {
api.send(nil, dbMsgTypeError, "bad request: malformed message", nil)
continue
return
}
switch string(parts[1]) {
@ -206,61 +291,6 @@ func (api *DatabaseAPI) handler(context.Context) error {
default:
api.send(parts[0], dbMsgTypeError, "bad request: unknown method", nil)
}
}
}
func (api *DatabaseAPI) writer(ctx context.Context) error {
defer func() {
_ = api.shutdown(nil)
}()
var data []byte
var err error
for {
select {
// prioritize direct writes
case data = <-api.sendQueue:
if len(data) == 0 {
return nil
}
case <-ctx.Done():
return nil
case <-api.shutdownSignal:
return nil
}
// log.Tracef("api: sending %s", string(*msg))
err = api.conn.WriteMessage(websocket.BinaryMessage, data)
if err != nil {
return api.shutdown(err)
}
}
}
func (api *DatabaseAPI) shutdown(err error) error {
// Check if we are the first to shut down.
if !api.shuttingDown.SetToIf(false, true) {
return nil
}
// Check the given error.
if err != nil {
if websocket.IsCloseError(err,
websocket.CloseNormalClosure,
websocket.CloseGoingAway,
websocket.CloseAbnormalClosure,
) {
log.Infof("api: websocket connection to %s closed", api.conn.RemoteAddr())
} else {
log.Warningf("api: websocket connection error with %s: %s", api.conn.RemoteAddr(), err)
}
}
// Trigger shutdown.
close(api.shutdownSignal)
_ = api.conn.Close()
return nil
}
func (api *DatabaseAPI) send(opID []byte, msgType string, msgOrKey string, data []byte) {
@ -278,7 +308,7 @@ func (api *DatabaseAPI) send(opID []byte, msgType string, msgOrKey string, data
c.Append(data)
}
api.sendQueue <- c.CompileData()
api.sendBytes(c.CompileData())
}
func (api *DatabaseAPI) handleGet(opID []byte, key string) {
@ -343,7 +373,7 @@ func (api *DatabaseAPI) processQuery(opID []byte, q *query.Query) (ok bool) {
case <-api.shutdownSignal:
// cancel query and return
it.Cancel()
return
return false
case r := <-it.Next:
// process query feed
if r != nil {
@ -367,7 +397,7 @@ func (api *DatabaseAPI) processQuery(opID []byte, q *query.Query) (ok bool) {
}
}
// func (api *DatabaseAPI) runQuery()
// func (api *DatabaseWebsocketAPI) runQuery()
func (api *DatabaseAPI) handleSub(opID []byte, queryText string) {
// 125|sub|<query>
@ -629,7 +659,7 @@ func (api *DatabaseAPI) handleDelete(opID []byte, key string) {
api.send(opID, dbMsgTypeSuccess, emptyString, nil)
}
// MarshalRecords locks and marshals the given record, additionally adding
// MarshalRecord locks and marshals the given record, additionally adding
// metadata and returning it as json.
func MarshalRecord(r record.Record, withDSDIdentifier bool) ([]byte, error) {
r.Lock()

View file

@ -208,7 +208,7 @@ func getAPIContext(r *http.Request) (apiEndpoint *Endpoint, apiRequest *Request)
// does not pass the sanity checks.
func RegisterEndpoint(e Endpoint) error {
if err := e.check(); err != nil {
return fmt.Errorf("%w: %s", ErrInvalidEndpoint, err)
return fmt.Errorf("%w: %w", ErrInvalidEndpoint, err)
}
endpointsLock.Lock()
@ -224,6 +224,18 @@ func RegisterEndpoint(e Endpoint) error {
return nil
}
// GetEndpointByPath returns the endpoint registered with the given path.
func GetEndpointByPath(path string) (*Endpoint, error) {
endpointsLock.Lock()
defer endpointsLock.Unlock()
endpoint, ok := endpoints[path]
if !ok {
return nil, fmt.Errorf("no registered endpoint on path: %q", path)
}
return endpoint, nil
}
func (e *Endpoint) check() error {
// Check path.
if strings.TrimSpace(e.Path) == "" {

View file

@ -33,11 +33,12 @@ type Request struct {
// apiRequestContextKey is a key used for the context key/value storage.
type apiRequestContextKey struct{}
var requestContextKey = apiRequestContextKey{}
// RequestContextKey is the key used to add the API request to the context.
var RequestContextKey = apiRequestContextKey{}
// GetAPIRequest returns the API Request of the given http request.
func GetAPIRequest(r *http.Request) *Request {
ar, ok := r.Context().Value(requestContextKey).(*Request)
ar, ok := r.Context().Value(RequestContextKey).(*Request)
if ok {
return ar
}

View file

@ -118,7 +118,7 @@ func (mh *mainHandler) handle(w http.ResponseWriter, r *http.Request) error {
apiRequest := &Request{
Request: r,
}
ctx = context.WithValue(ctx, requestContextKey, apiRequest)
ctx = context.WithValue(ctx, RequestContextKey, apiRequest)
// Add context back to request.
r = r.WithContext(ctx)
lrw := NewLoggingResponseWriter(w, r)

View file

@ -45,7 +45,7 @@ func (i *Interface) DelayedCacheWriter(ctx context.Context) error {
i.flushWriteCache(0)
case <-thresholdWriteTicker.C:
// Often check if the the write cache has filled up to a certain degree and
// Often check if the write cache has filled up to a certain degree and
// flush it to storage before we start evicting to-be-written entries and
// slow down the hot path again.
i.flushWriteCache(percentThreshold)

View file

@ -62,7 +62,7 @@ func (s *Sinkhole) PutMany(shadowDelete bool) (chan<- record.Record, <-chan erro
// start handler
go func() {
for range batch {
// nom, nom, nom
// discard everything
}
errs <- nil
}()

View file

@ -184,7 +184,7 @@ func Errorf(format string, things ...interface{}) {
}
}
// Critical is used to log events that completely break the system. Operation connot continue. User/Admin must be informed.
// Critical is used to log events that completely break the system. Operation cannot continue. User/Admin must be informed.
func Critical(msg string) {
atomic.AddUint64(critLogLines, 1)
if fastcheck(CriticalLevel) {
@ -192,7 +192,7 @@ func Critical(msg string) {
}
}
// Criticalf is used to log events that completely break the system. Operation connot continue. User/Admin must be informed.
// Criticalf is used to log events that completely break the system. Operation cannot continue. User/Admin must be informed.
func Criticalf(format string, things ...interface{}) {
atomic.AddUint64(critLogLines, 1)
if fastcheck(CriticalLevel) {

View file

@ -30,7 +30,7 @@ var (
)
// Module represents a module.
type Module struct {
type Module struct { //nolint:maligned
sync.RWMutex
Name string

View file

@ -44,7 +44,7 @@ func (f *Feeder) NeedsEntropy() bool {
return f.needsEntropy.IsSet()
}
// SupplyEntropy supplies entropy to to the Feeder, it will block until the Feeder has read from it.
// SupplyEntropy supplies entropy to the Feeder, it will block until the Feeder has read from it.
func (f *Feeder) SupplyEntropy(data []byte, entropy int) {
f.input <- &entropyData{
data: data,
@ -52,7 +52,7 @@ func (f *Feeder) SupplyEntropy(data []byte, entropy int) {
}
}
// SupplyEntropyIfNeeded supplies entropy to to the Feeder, but will not block if no entropy is currently needed.
// SupplyEntropyIfNeeded supplies entropy to the Feeder, but will not block if no entropy is currently needed.
func (f *Feeder) SupplyEntropyIfNeeded(data []byte, entropy int) {
if f.needsEntropy.IsSet() {
return
@ -67,14 +67,14 @@ func (f *Feeder) SupplyEntropyIfNeeded(data []byte, entropy int) {
}
}
// SupplyEntropyAsInt supplies entropy to to the Feeder, it will block until the Feeder has read from it.
// SupplyEntropyAsInt supplies entropy to the Feeder, it will block until the Feeder has read from it.
func (f *Feeder) SupplyEntropyAsInt(n int64, entropy int) {
b := make([]byte, 8)
binary.LittleEndian.PutUint64(b, uint64(n))
f.SupplyEntropy(b, entropy)
}
// SupplyEntropyAsIntIfNeeded supplies entropy to to the Feeder, but will not block if no entropy is currently needed.
// SupplyEntropyAsIntIfNeeded supplies entropy to the Feeder, but will not block if no entropy is currently needed.
func (f *Feeder) SupplyEntropyAsIntIfNeeded(n int64, entropy int) {
if f.needsEntropy.IsSet() { // avoid allocating a slice if possible
b := make([]byte, 8)

View file

@ -7,7 +7,7 @@ import (
"github.com/safing/portbase/utils/renameio"
)
func ExampleTempFile_justone() {
func ExampleTempFile_justone() { //nolint:testableexamples
persist := func(temperature float64) error {
t, err := renameio.TempFile("", "/srv/www/metrics.txt")
if err != nil {
@ -28,7 +28,7 @@ func ExampleTempFile_justone() {
}
}
func ExampleTempFile_many() {
func ExampleTempFile_many() { //nolint:testableexamples
// Prepare for writing files to /srv/www, effectively caching calls to
// TempDir which TempFile would otherwise need to make.
dir := renameio.TempDir("/srv/www")

View file

@ -1,4 +1,4 @@
// go:build !windows
//go:build !windows
package utils