From df62abdf1b743d4101c97999a45c86eb1ae41aa7 Mon Sep 17 00:00:00 2001 From: Vladimir Stoilov <vladimir@safing.io> Date: Fri, 2 Jun 2023 11:41:38 +0300 Subject: [PATCH] Add database custom interface functions --- api/database.go | 273 ++++++++++++++++++++++++++--------------------- api/endpoints.go | 11 ++ api/request.go | 4 +- api/router.go | 2 +- log/input.go | 4 +- 5 files changed, 166 insertions(+), 128 deletions(-) diff --git a/api/database.go b/api/database.go index 804dc1a..554e7f4 100644 --- a/api/database.go +++ b/api/database.go @@ -44,7 +44,7 @@ var ( func init() { RegisterHandler("/api/database/v1", WrapInAuthHandler( - startDatabaseAPI, + startDatabaseWebsocketAPI, // Default to admin read/write permissions until the database gets support // for api permissions. dbCompatibilityPermission, @@ -54,9 +54,6 @@ func init() { // DatabaseAPI is a database API instance. type DatabaseAPI struct { - conn *websocket.Conn - sendQueue chan []byte - queriesLock sync.Mutex queries map[string]*iterator.Iterator @@ -66,13 +63,33 @@ type DatabaseAPI struct { shutdownSignal chan struct{} shuttingDown *abool.AtomicBool db *database.Interface + + sendBytes func(data []byte) +} + +type DatabaseWebsocketAPI struct { + DatabaseAPI + + sendQueue chan []byte + conn *websocket.Conn } func allowAnyOrigin(r *http.Request) bool { return true } -func startDatabaseAPI(w http.ResponseWriter, r *http.Request) { +func CreateDatabaseAPI(sendFunction func(data []byte)) DatabaseAPI { + return DatabaseAPI{ + queries: make(map[string]*iterator.Iterator), + subs: make(map[string]*database.Subscription), + shutdownSignal: make(chan struct{}), + shuttingDown: abool.NewBool(false), + db: database.NewInterface(nil), + sendBytes: sendFunction, + } +} + +func startDatabaseWebsocketAPI(w http.ResponseWriter, r *http.Request) { upgrader := websocket.Upgrader{ CheckOrigin: allowAnyOrigin, ReadBufferSize: 1024, @@ -86,14 +103,21 @@ func startDatabaseAPI(w http.ResponseWriter, r *http.Request) { return } - newDBAPI := &DatabaseAPI{ - conn: wsConn, - sendQueue: make(chan []byte, 100), - queries: make(map[string]*iterator.Iterator), - subs: make(map[string]*database.Subscription), - shutdownSignal: make(chan struct{}), - shuttingDown: abool.NewBool(false), - db: database.NewInterface(nil), + newDBAPI := &DatabaseWebsocketAPI{ + DatabaseAPI: DatabaseAPI{ + queries: make(map[string]*iterator.Iterator), + subs: make(map[string]*database.Subscription), + shutdownSignal: make(chan struct{}), + shuttingDown: abool.NewBool(false), + db: database.NewInterface(nil), + }, + + sendQueue: make(chan []byte, 100), + conn: wsConn, + } + + newDBAPI.sendBytes = func(data []byte) { + newDBAPI.sendQueue <- data } module.StartWorker("database api handler", newDBAPI.handler) @@ -102,11 +126,76 @@ func startDatabaseAPI(w http.ResponseWriter, r *http.Request) { log.Tracer(r.Context()).Infof("api request: init websocket %s %s", r.RemoteAddr, r.RequestURI) } -func (api *DatabaseAPI) handler(context.Context) error { +func (api *DatabaseWebsocketAPI) handler(context.Context) error { defer func() { _ = api.shutdown(nil) }() + for { + _, msg, err := api.conn.ReadMessage() + if err != nil { + return api.shutdown(err) + } + + api.Handle(msg) + } +} + +func (api *DatabaseWebsocketAPI) writer(ctx context.Context) error { + defer func() { + _ = api.shutdown(nil) + }() + + var data []byte + var err error + + for { + select { + // prioritize direct writes + case data = <-api.sendQueue: + if len(data) == 0 { + return nil + } + case <-ctx.Done(): + return nil + case <-api.shutdownSignal: + return nil + } + + // log.Tracef("api: sending %s", string(*msg)) + err = api.conn.WriteMessage(websocket.BinaryMessage, data) + if err != nil { + return api.shutdown(err) + } + } +} + +func (api *DatabaseWebsocketAPI) shutdown(err error) error { + // Check if we are the first to shut down. + if !api.shuttingDown.SetToIf(false, true) { + return nil + } + + // Check the given error. + if err != nil { + if websocket.IsCloseError(err, + websocket.CloseNormalClosure, + websocket.CloseGoingAway, + websocket.CloseAbnormalClosure, + ) { + log.Infof("api: websocket connection to %s closed", api.conn.RemoteAddr()) + } else { + log.Warningf("api: websocket connection error with %s: %s", api.conn.RemoteAddr(), err) + } + } + + // Trigger shutdown. + close(api.shutdownSignal) + _ = api.conn.Close() + return nil +} + +func (api *DatabaseAPI) Handle(msg []byte) { // 123|get|<key> // 123|ok|<key>|<data> // 123|error|<message> @@ -145,124 +234,62 @@ func (api *DatabaseAPI) handler(context.Context) error { // 131|success // 131|error|<message> - for { + parts := bytes.SplitN(msg, []byte("|"), 3) - _, msg, err := api.conn.ReadMessage() - if err != nil { - return api.shutdown(err) - } + // Handle special command "cancel" + if len(parts) == 2 && string(parts[1]) == "cancel" { + // 124|cancel + // 125|cancel + // 127|cancel + go api.handleCancel(parts[0]) + return + } - parts := bytes.SplitN(msg, []byte("|"), 3) + if len(parts) != 3 { + api.send(nil, dbMsgTypeError, "bad request: malformed message", nil) + return + } - // Handle special command "cancel" - if len(parts) == 2 && string(parts[1]) == "cancel" { - // 124|cancel - // 125|cancel - // 127|cancel - go api.handleCancel(parts[0]) - continue - } - - if len(parts) != 3 { + switch string(parts[1]) { + case "get": + // 123|get|<key> + go api.handleGet(parts[0], string(parts[2])) + case "query": + // 124|query|<query> + go api.handleQuery(parts[0], string(parts[2])) + case "sub": + // 125|sub|<query> + go api.handleSub(parts[0], string(parts[2])) + case "qsub": + // 127|qsub|<query> + go api.handleQsub(parts[0], string(parts[2])) + case "create", "update", "insert": + // split key and payload + dataParts := bytes.SplitN(parts[2], []byte("|"), 2) + if len(dataParts) != 2 { api.send(nil, dbMsgTypeError, "bad request: malformed message", nil) - continue + return } switch string(parts[1]) { - case "get": - // 123|get|<key> - go api.handleGet(parts[0], string(parts[2])) - case "query": - // 124|query|<query> - go api.handleQuery(parts[0], string(parts[2])) - case "sub": - // 125|sub|<query> - go api.handleSub(parts[0], string(parts[2])) - case "qsub": - // 127|qsub|<query> - go api.handleQsub(parts[0], string(parts[2])) - case "create", "update", "insert": - // split key and payload - dataParts := bytes.SplitN(parts[2], []byte("|"), 2) - if len(dataParts) != 2 { - api.send(nil, dbMsgTypeError, "bad request: malformed message", nil) - continue - } - - switch string(parts[1]) { - case "create": - // 128|create|<key>|<data> - go api.handlePut(parts[0], string(dataParts[0]), dataParts[1], true) - case "update": - // 129|update|<key>|<data> - go api.handlePut(parts[0], string(dataParts[0]), dataParts[1], false) - case "insert": - // 130|insert|<key>|<data> - go api.handleInsert(parts[0], string(dataParts[0]), dataParts[1]) - } - case "delete": - // 131|delete|<key> - go api.handleDelete(parts[0], string(parts[2])) - default: - api.send(parts[0], dbMsgTypeError, "bad request: unknown method", nil) + case "create": + // 128|create|<key>|<data> + go api.handlePut(parts[0], string(dataParts[0]), dataParts[1], true) + case "update": + // 129|update|<key>|<data> + go api.handlePut(parts[0], string(dataParts[0]), dataParts[1], false) + case "insert": + // 130|insert|<key>|<data> + go api.handleInsert(parts[0], string(dataParts[0]), dataParts[1]) } + case "delete": + // 131|delete|<key> + go api.handleDelete(parts[0], string(parts[2])) + default: + api.send(parts[0], dbMsgTypeError, "bad request: unknown method", nil) } } -func (api *DatabaseAPI) writer(ctx context.Context) error { - defer func() { - _ = api.shutdown(nil) - }() - - var data []byte - var err error - - for { - select { - // prioritize direct writes - case data = <-api.sendQueue: - if len(data) == 0 { - return nil - } - case <-ctx.Done(): - return nil - case <-api.shutdownSignal: - return nil - } - - // log.Tracef("api: sending %s", string(*msg)) - err = api.conn.WriteMessage(websocket.BinaryMessage, data) - if err != nil { - return api.shutdown(err) - } - } -} - -func (api *DatabaseAPI) shutdown(err error) error { - // Check if we are the first to shut down. - if !api.shuttingDown.SetToIf(false, true) { - return nil - } - - // Check the given error. - if err != nil { - if websocket.IsCloseError(err, - websocket.CloseNormalClosure, - websocket.CloseGoingAway, - websocket.CloseAbnormalClosure, - ) { - log.Infof("api: websocket connection to %s closed", api.conn.RemoteAddr()) - } else { - log.Warningf("api: websocket connection error with %s: %s", api.conn.RemoteAddr(), err) - } - } - - // Trigger shutdown. - close(api.shutdownSignal) - _ = api.conn.Close() - return nil -} - func (api *DatabaseAPI) send(opID []byte, msgType string, msgOrKey string, data []byte) { c := container.New(opID) c.Append(dbAPISeperatorBytes) @@ -278,7 +305,7 @@ func (api *DatabaseAPI) send(opID []byte, msgType string, msgOrKey string, data c.Append(data) } - api.sendQueue <- c.CompileData() + api.sendBytes(c.CompileData()) } func (api *DatabaseAPI) handleGet(opID []byte, key string) { @@ -367,7 +394,7 @@ func (api *DatabaseAPI) processQuery(opID []byte, q *query.Query) (ok bool) { } } -// func (api *DatabaseAPI) runQuery() +// func (api *DatabaseWebsocketAPI) runQuery() func (api *DatabaseAPI) handleSub(opID []byte, queryText string) { // 125|sub|<query> diff --git a/api/endpoints.go b/api/endpoints.go index 3a005af..65de1b1 100644 --- a/api/endpoints.go +++ b/api/endpoints.go @@ -224,6 +224,17 @@ func RegisterEndpoint(e Endpoint) error { return nil } +func GetEndpointByPath(path string) (*Endpoint, error) { + endpointsLock.Lock() + defer endpointsLock.Unlock() + endpoint, ok := endpoints[path] + if !ok { + return nil, fmt.Errorf("no registered endpoint on path: %q", path) + } + + return endpoint, nil +} + func (e *Endpoint) check() error { // Check path. if strings.TrimSpace(e.Path) == "" { diff --git a/api/request.go b/api/request.go index 88d77af..9891d9b 100644 --- a/api/request.go +++ b/api/request.go @@ -33,11 +33,11 @@ type Request struct { // apiRequestContextKey is a key used for the context key/value storage. type apiRequestContextKey struct{} -var requestContextKey = apiRequestContextKey{} +var RequestContextKey = apiRequestContextKey{} // GetAPIRequest returns the API Request of the given http request. func GetAPIRequest(r *http.Request) *Request { - ar, ok := r.Context().Value(requestContextKey).(*Request) + ar, ok := r.Context().Value(RequestContextKey).(*Request) if ok { return ar } diff --git a/api/router.go b/api/router.go index d9a93a5..029235e 100644 --- a/api/router.go +++ b/api/router.go @@ -118,7 +118,7 @@ func (mh *mainHandler) handle(w http.ResponseWriter, r *http.Request) error { apiRequest := &Request{ Request: r, } - ctx = context.WithValue(ctx, requestContextKey, apiRequest) + ctx = context.WithValue(ctx, RequestContextKey, apiRequest) // Add context back to request. r = r.WithContext(ctx) lrw := NewLoggingResponseWriter(w, r) diff --git a/log/input.go b/log/input.go index 609cff0..ef8564a 100644 --- a/log/input.go +++ b/log/input.go @@ -184,7 +184,7 @@ func Errorf(format string, things ...interface{}) { } } -// Critical is used to log events that completely break the system. Operation connot continue. User/Admin must be informed. +// Critical is used to log events that completely break the system. Operation cannot continue. User/Admin must be informed. func Critical(msg string) { atomic.AddUint64(critLogLines, 1) if fastcheck(CriticalLevel) { @@ -192,7 +192,7 @@ func Critical(msg string) { } } -// Criticalf is used to log events that completely break the system. Operation connot continue. User/Admin must be informed. +// Criticalf is used to log events that completely break the system. Operation cannot continue. User/Admin must be informed. func Criticalf(format string, things ...interface{}) { atomic.AddUint64(critLogLines, 1) if fastcheck(CriticalLevel) {