From df62abdf1b743d4101c97999a45c86eb1ae41aa7 Mon Sep 17 00:00:00 2001
From: Vladimir Stoilov <vladimir@safing.io>
Date: Fri, 2 Jun 2023 11:41:38 +0300
Subject: [PATCH] Add database custom interface functions

---
 api/database.go  | 273 ++++++++++++++++++++++++++---------------------
 api/endpoints.go |  11 ++
 api/request.go   |   4 +-
 api/router.go    |   2 +-
 log/input.go     |   4 +-
 5 files changed, 166 insertions(+), 128 deletions(-)

diff --git a/api/database.go b/api/database.go
index 804dc1a..554e7f4 100644
--- a/api/database.go
+++ b/api/database.go
@@ -44,7 +44,7 @@ var (
 
 func init() {
 	RegisterHandler("/api/database/v1", WrapInAuthHandler(
-		startDatabaseAPI,
+		startDatabaseWebsocketAPI,
 		// Default to admin read/write permissions until the database gets support
 		// for api permissions.
 		dbCompatibilityPermission,
@@ -54,9 +54,6 @@ func init() {
 
 // DatabaseAPI is a database API instance.
 type DatabaseAPI struct {
-	conn      *websocket.Conn
-	sendQueue chan []byte
-
 	queriesLock sync.Mutex
 	queries     map[string]*iterator.Iterator
 
@@ -66,13 +63,33 @@ type DatabaseAPI struct {
 	shutdownSignal chan struct{}
 	shuttingDown   *abool.AtomicBool
 	db             *database.Interface
+
+	sendBytes func(data []byte)
+}
+
+type DatabaseWebsocketAPI struct {
+	DatabaseAPI
+
+	sendQueue chan []byte
+	conn      *websocket.Conn
 }
 
 func allowAnyOrigin(r *http.Request) bool {
 	return true
 }
 
-func startDatabaseAPI(w http.ResponseWriter, r *http.Request) {
+func CreateDatabaseAPI(sendFunction func(data []byte)) DatabaseAPI {
+	return DatabaseAPI{
+		queries:        make(map[string]*iterator.Iterator),
+		subs:           make(map[string]*database.Subscription),
+		shutdownSignal: make(chan struct{}),
+		shuttingDown:   abool.NewBool(false),
+		db:             database.NewInterface(nil),
+		sendBytes:      sendFunction,
+	}
+}
+
+func startDatabaseWebsocketAPI(w http.ResponseWriter, r *http.Request) {
 	upgrader := websocket.Upgrader{
 		CheckOrigin:     allowAnyOrigin,
 		ReadBufferSize:  1024,
@@ -86,14 +103,21 @@ func startDatabaseAPI(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	newDBAPI := &DatabaseAPI{
-		conn:           wsConn,
-		sendQueue:      make(chan []byte, 100),
-		queries:        make(map[string]*iterator.Iterator),
-		subs:           make(map[string]*database.Subscription),
-		shutdownSignal: make(chan struct{}),
-		shuttingDown:   abool.NewBool(false),
-		db:             database.NewInterface(nil),
+	newDBAPI := &DatabaseWebsocketAPI{
+		DatabaseAPI: DatabaseAPI{
+			queries:        make(map[string]*iterator.Iterator),
+			subs:           make(map[string]*database.Subscription),
+			shutdownSignal: make(chan struct{}),
+			shuttingDown:   abool.NewBool(false),
+			db:             database.NewInterface(nil),
+		},
+
+		sendQueue: make(chan []byte, 100),
+		conn:      wsConn,
+	}
+
+	newDBAPI.sendBytes = func(data []byte) {
+		newDBAPI.sendQueue <- data
 	}
 
 	module.StartWorker("database api handler", newDBAPI.handler)
@@ -102,11 +126,76 @@ func startDatabaseAPI(w http.ResponseWriter, r *http.Request) {
 	log.Tracer(r.Context()).Infof("api request: init websocket %s %s", r.RemoteAddr, r.RequestURI)
 }
 
-func (api *DatabaseAPI) handler(context.Context) error {
+func (api *DatabaseWebsocketAPI) handler(context.Context) error {
 	defer func() {
 		_ = api.shutdown(nil)
 	}()
 
+	for {
+		_, msg, err := api.conn.ReadMessage()
+		if err != nil {
+			return api.shutdown(err)
+		}
+
+		api.Handle(msg)
+	}
+}
+
+func (api *DatabaseWebsocketAPI) writer(ctx context.Context) error {
+	defer func() {
+		_ = api.shutdown(nil)
+	}()
+
+	var data []byte
+	var err error
+
+	for {
+		select {
+		// prioritize direct writes
+		case data = <-api.sendQueue:
+			if len(data) == 0 {
+				return nil
+			}
+		case <-ctx.Done():
+			return nil
+		case <-api.shutdownSignal:
+			return nil
+		}
+
+		// log.Tracef("api: sending %s", string(*msg))
+		err = api.conn.WriteMessage(websocket.BinaryMessage, data)
+		if err != nil {
+			return api.shutdown(err)
+		}
+	}
+}
+
+func (api *DatabaseWebsocketAPI) shutdown(err error) error {
+	// Check if we are the first to shut down.
+	if !api.shuttingDown.SetToIf(false, true) {
+		return nil
+	}
+
+	// Check the given error.
+	if err != nil {
+		if websocket.IsCloseError(err,
+			websocket.CloseNormalClosure,
+			websocket.CloseGoingAway,
+			websocket.CloseAbnormalClosure,
+		) {
+			log.Infof("api: websocket connection to %s closed", api.conn.RemoteAddr())
+		} else {
+			log.Warningf("api: websocket connection error with %s: %s", api.conn.RemoteAddr(), err)
+		}
+	}
+
+	// Trigger shutdown.
+	close(api.shutdownSignal)
+	_ = api.conn.Close()
+	return nil
+}
+
+func (api *DatabaseAPI) Handle(msg []byte) {
 	// 123|get|<key>
 	//    123|ok|<key>|<data>
 	//    123|error|<message>
@@ -145,124 +234,62 @@ func (api *DatabaseAPI) handler(context.Context) error {
 	//    131|success
 	//    131|error|<message>
 
-	for {
+	parts := bytes.SplitN(msg, []byte("|"), 3)
 
-		_, msg, err := api.conn.ReadMessage()
-		if err != nil {
-			return api.shutdown(err)
-		}
+	// Handle special command "cancel"
+	if len(parts) == 2 && string(parts[1]) == "cancel" {
+		// 124|cancel
+		// 125|cancel
+		// 127|cancel
+		go api.handleCancel(parts[0])
+		return
+	}
 
-		parts := bytes.SplitN(msg, []byte("|"), 3)
+	if len(parts) != 3 {
+		api.send(nil, dbMsgTypeError, "bad request: malformed message", nil)
+		return
+	}
 
-		// Handle special command "cancel"
-		if len(parts) == 2 && string(parts[1]) == "cancel" {
-			// 124|cancel
-			// 125|cancel
-			// 127|cancel
-			go api.handleCancel(parts[0])
-			continue
-		}
-
-		if len(parts) != 3 {
+	switch string(parts[1]) {
+	case "get":
+		// 123|get|<key>
+		go api.handleGet(parts[0], string(parts[2]))
+	case "query":
+		// 124|query|<query>
+		go api.handleQuery(parts[0], string(parts[2]))
+	case "sub":
+		// 125|sub|<query>
+		go api.handleSub(parts[0], string(parts[2]))
+	case "qsub":
+		// 127|qsub|<query>
+		go api.handleQsub(parts[0], string(parts[2]))
+	case "create", "update", "insert":
+		// split key and payload
+		dataParts := bytes.SplitN(parts[2], []byte("|"), 2)
+		if len(dataParts) != 2 {
 			api.send(nil, dbMsgTypeError, "bad request: malformed message", nil)
-			continue
+			return
 		}
 
 		switch string(parts[1]) {
-		case "get":
-			// 123|get|<key>
-			go api.handleGet(parts[0], string(parts[2]))
-		case "query":
-			// 124|query|<query>
-			go api.handleQuery(parts[0], string(parts[2]))
-		case "sub":
-			// 125|sub|<query>
-			go api.handleSub(parts[0], string(parts[2]))
-		case "qsub":
-			// 127|qsub|<query>
-			go api.handleQsub(parts[0], string(parts[2]))
-		case "create", "update", "insert":
-			// split key and payload
-			dataParts := bytes.SplitN(parts[2], []byte("|"), 2)
-			if len(dataParts) != 2 {
-				api.send(nil, dbMsgTypeError, "bad request: malformed message", nil)
-				continue
-			}
-
-			switch string(parts[1]) {
-			case "create":
-				// 128|create|<key>|<data>
-				go api.handlePut(parts[0], string(dataParts[0]), dataParts[1], true)
-			case "update":
-				// 129|update|<key>|<data>
-				go api.handlePut(parts[0], string(dataParts[0]), dataParts[1], false)
-			case "insert":
-				// 130|insert|<key>|<data>
-				go api.handleInsert(parts[0], string(dataParts[0]), dataParts[1])
-			}
-		case "delete":
-			// 131|delete|<key>
-			go api.handleDelete(parts[0], string(parts[2]))
-		default:
-			api.send(parts[0], dbMsgTypeError, "bad request: unknown method", nil)
+		case "create":
+			// 128|create|<key>|<data>
+			go api.handlePut(parts[0], string(dataParts[0]), dataParts[1], true)
+		case "update":
+			// 129|update|<key>|<data>
+			go api.handlePut(parts[0], string(dataParts[0]), dataParts[1], false)
+		case "insert":
+			// 130|insert|<key>|<data>
+			go api.handleInsert(parts[0], string(dataParts[0]), dataParts[1])
 		}
+	case "delete":
+		// 131|delete|<key>
+		go api.handleDelete(parts[0], string(parts[2]))
+	default:
+		api.send(parts[0], dbMsgTypeError, "bad request: unknown method", nil)
 	}
 }
 
-func (api *DatabaseAPI) writer(ctx context.Context) error {
-	defer func() {
-		_ = api.shutdown(nil)
-	}()
-
-	var data []byte
-	var err error
-
-	for {
-		select {
-		// prioritize direct writes
-		case data = <-api.sendQueue:
-			if len(data) == 0 {
-				return nil
-			}
-		case <-ctx.Done():
-			return nil
-		case <-api.shutdownSignal:
-			return nil
-		}
-
-		// log.Tracef("api: sending %s", string(*msg))
-		err = api.conn.WriteMessage(websocket.BinaryMessage, data)
-		if err != nil {
-			return api.shutdown(err)
-		}
-	}
-}
-
-func (api *DatabaseAPI) shutdown(err error) error {
-	// Check if we are the first to shut down.
-	if !api.shuttingDown.SetToIf(false, true) {
-		return nil
-	}
-
-	// Check the given error.
-	if err != nil {
-		if websocket.IsCloseError(err,
-			websocket.CloseNormalClosure,
-			websocket.CloseGoingAway,
-			websocket.CloseAbnormalClosure,
-		) {
-			log.Infof("api: websocket connection to %s closed", api.conn.RemoteAddr())
-		} else {
-			log.Warningf("api: websocket connection error with %s: %s", api.conn.RemoteAddr(), err)
-		}
-	}
-
-	// Trigger shutdown.
-	close(api.shutdownSignal)
-	_ = api.conn.Close()
-	return nil
-}
-
 func (api *DatabaseAPI) send(opID []byte, msgType string, msgOrKey string, data []byte) {
 	c := container.New(opID)
 	c.Append(dbAPISeperatorBytes)
@@ -278,7 +305,7 @@ func (api *DatabaseAPI) send(opID []byte, msgType string, msgOrKey string, data
 		c.Append(data)
 	}
 
-	api.sendQueue <- c.CompileData()
+	api.sendBytes(c.CompileData())
 }
 
 func (api *DatabaseAPI) handleGet(opID []byte, key string) {
@@ -367,7 +394,7 @@ func (api *DatabaseAPI) processQuery(opID []byte, q *query.Query) (ok bool) {
 	}
 }
 
-// func (api *DatabaseAPI) runQuery()
+// func (api *DatabaseWebsocketAPI) runQuery()
 
 func (api *DatabaseAPI) handleSub(opID []byte, queryText string) {
 	// 125|sub|<query>
diff --git a/api/endpoints.go b/api/endpoints.go
index 3a005af..65de1b1 100644
--- a/api/endpoints.go
+++ b/api/endpoints.go
@@ -224,6 +224,17 @@ func RegisterEndpoint(e Endpoint) error {
 	return nil
 }
 
+func GetEndpointByPath(path string) (*Endpoint, error) {
+	endpointsLock.Lock()
+	defer endpointsLock.Unlock()
+	endpoint, ok := endpoints[path]
+	if !ok {
+		return nil, fmt.Errorf("no registered endpoint on path: %q", path)
+	}
+
+	return endpoint, nil
+}
+
 func (e *Endpoint) check() error {
 	// Check path.
 	if strings.TrimSpace(e.Path) == "" {
diff --git a/api/request.go b/api/request.go
index 88d77af..9891d9b 100644
--- a/api/request.go
+++ b/api/request.go
@@ -33,11 +33,11 @@ type Request struct {
 // apiRequestContextKey is a key used for the context key/value storage.
 type apiRequestContextKey struct{}
 
-var requestContextKey = apiRequestContextKey{}
+var RequestContextKey = apiRequestContextKey{}
 
 // GetAPIRequest returns the API Request of the given http request.
 func GetAPIRequest(r *http.Request) *Request {
-	ar, ok := r.Context().Value(requestContextKey).(*Request)
+	ar, ok := r.Context().Value(RequestContextKey).(*Request)
 	if ok {
 		return ar
 	}
diff --git a/api/router.go b/api/router.go
index d9a93a5..029235e 100644
--- a/api/router.go
+++ b/api/router.go
@@ -118,7 +118,7 @@ func (mh *mainHandler) handle(w http.ResponseWriter, r *http.Request) error {
 	apiRequest := &Request{
 		Request: r,
 	}
-	ctx = context.WithValue(ctx, requestContextKey, apiRequest)
+	ctx = context.WithValue(ctx, RequestContextKey, apiRequest)
 	// Add context back to request.
 	r = r.WithContext(ctx)
 	lrw := NewLoggingResponseWriter(w, r)
diff --git a/log/input.go b/log/input.go
index 609cff0..ef8564a 100644
--- a/log/input.go
+++ b/log/input.go
@@ -184,7 +184,7 @@ func Errorf(format string, things ...interface{}) {
 	}
 }
 
-// Critical is used to log events that completely break the system. Operation connot continue. User/Admin must be informed.
+// Critical is used to log events that completely break the system. Operation cannot continue. User/Admin must be informed.
 func Critical(msg string) {
 	atomic.AddUint64(critLogLines, 1)
 	if fastcheck(CriticalLevel) {
@@ -192,7 +192,7 @@ func Critical(msg string) {
 	}
 }
 
-// Criticalf is used to log events that completely break the system. Operation connot continue. User/Admin must be informed.
+// Criticalf is used to log events that completely break the system. Operation cannot continue. User/Admin must be informed.
 func Criticalf(format string, things ...interface{}) {
 	atomic.AddUint64(critLogLines, 1)
 	if fastcheck(CriticalLevel) {