diff --git a/api/database.go b/api/database.go new file mode 100644 index 0000000..0af80bd --- /dev/null +++ b/api/database.go @@ -0,0 +1,359 @@ +package api + +import ( + "bytes" + "fmt" + "net/http" + + "github.com/gorilla/websocket" + "github.com/tevino/abool" + + "github.com/Safing/portbase/container" + "github.com/Safing/portbase/database" + "github.com/Safing/portbase/database/query" + "github.com/Safing/portbase/database/record" + "github.com/Safing/portbase/log" +) + +const ( + dbMsgTypeOk = "ok" + dbMsgTypeError = "error" + dbMsgTypeDone = "done" + dbMsgTypeSuccess = "success" + dbMsgTypeUpd = "upd" + dbMsgTypeNew = "new" + dbMsgTypeDelete = "delete" + dbMsgTypeWarning = "warning" +) + +// DatabaseAPI is a database API instance. +type DatabaseAPI struct { + conn *websocket.Conn + sendQueue chan []byte + subs map[string]*database.Subscription + + shutdownSignal chan struct{} + shuttingDown *abool.AtomicBool + db *database.Interface +} + +func allowAnyOrigin(r *http.Request) bool { + return true +} + +func startDatabaseAPI(w http.ResponseWriter, r *http.Request) { + + upgrader := websocket.Upgrader{ + CheckOrigin: allowAnyOrigin, + ReadBufferSize: 1024, + WriteBufferSize: 65536, + } + wsConn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + errMsg := fmt.Sprintf("could not upgrade to websocket: %s", err) + log.Error(errMsg) + http.Error(w, errMsg, 400) + return + } + + new := &DatabaseAPI{ + conn: wsConn, + sendQueue: make(chan []byte, 100), + subs: make(map[string]*database.Subscription), + shutdownSignal: make(chan struct{}), + shuttingDown: abool.NewBool(false), + db: database.NewInterface(nil), + } + + go new.handler() + go new.writer() +} + +func (api *DatabaseAPI) handler() { + + // 123|get| + // 123|ok|| + // 123|error| + // 124|query| + // 124|ok|| + // 124|done + // 124|error| + // 125|sub| + // 125|upd|| + // 125|new|| + // 125|delete|| + // 125|warning| // does not cancel the subscription + // 127|qsub| + // 127|ok|| + // 127|done + // 127|error| + // 127|upd|| + // 127|new|| + // 127|delete|| + // 127|warning| // does not cancel the subscription + + // 128|create|| + // 128|success + // 128|error| + // 129|update|| + // 129|success + // 129|error| + // 130|insert|| + // 130|success + // 130|error| + + for { + + _, msg, err := api.conn.ReadMessage() + if err != nil { + if !api.shuttingDown.IsSet() { + api.shutdown() + if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { + log.Warningf("api: websocket write error: %s", err) + } + } + return + } + + parts := bytes.SplitN(msg, []byte("|"), 2) + if len(parts) != 3 { + api.send(nil, dbMsgTypeError, []byte("bad request: malformed message")) + continue + } + + switch string(parts[1]) { + case "get": + // 123|get| + go api.handleGet(parts[0], string(parts[2])) + case "query": + // 124|query| + go api.handleQuery(parts[0], string(parts[2])) + case "sub": + // 125|sub| + go api.handleSub(parts[0], string(parts[2])) + case "qsub": + // 127|qsub| + go api.handleQsub(parts[0], string(parts[2])) + case "create", "update", "insert": + + // split key and payload + dataParts := bytes.SplitN(parts[2], []byte("|"), 1) + if len(dataParts) != 2 { + api.send(nil, dbMsgTypeError, []byte("bad request: malformed message")) + continue + } + + switch string(parts[1]) { + case "create": + // 128|create|| + go api.handleCreate(parts[0], string(dataParts[0]), dataParts[1]) + case "update": + // 129|update|| + go api.handleUpdate(parts[0], string(dataParts[0]), dataParts[1]) + case "insert": + // 130|insert|| + go api.handleInsert(parts[0], string(dataParts[0]), dataParts[1]) + } + + default: + api.send(parts[0], dbMsgTypeError, []byte("bad request: unknown method")) + } + } +} + +func (api *DatabaseAPI) writer() { + var data []byte + var err error + + for { + data = nil + + select { + // prioritize direct writes + case data = <-api.sendQueue: + if data == nil || len(data) == 0 { + api.shutdown() + return + } + case <-api.shutdownSignal: + return + } + + // log.Tracef("api: sending %s", string(*msg)) + err = api.conn.WriteMessage(websocket.BinaryMessage, data) + if err != nil { + if !api.shuttingDown.IsSet() { + api.shutdown() + if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { + log.Warningf("api: websocket write error: %s", err) + } + } + return + } + + } +} + +func (api *DatabaseAPI) send(opID []byte, msgType string, data []byte) { + c := container.New(opID) + c.Append([]byte(fmt.Sprintf("|%s|", msgType))) + c.Append(data) + api.sendQueue <- c.CompileData() +} + +func (api *DatabaseAPI) handleGet(opID []byte, key string) { + // 123|get| + // 123|ok|| + // 123|error| + + var data []byte + + r, err := api.db.Get(key) + if err == nil { + data, err = r.Marshal(r, record.JSON) + } else { + api.send(opID, dbMsgTypeError, []byte(err.Error())) + return + } + api.send(opID, dbMsgTypeOk, data) +} + +func (api *DatabaseAPI) handleQuery(opID []byte, queryText string) { + // 124|query| + // 124|ok|| + // 124|done + // 124|warning| + // 124|error| + + var err error + + q, err := query.ParseQuery(queryText) + if err != nil { + api.send(opID, dbMsgTypeError, []byte(err.Error())) + return + } + + api.processQuery(opID, q) +} + +func (api *DatabaseAPI) processQuery(opID []byte, q *query.Query) (ok bool) { + it, err := api.db.Query(q) + if err != nil { + api.send(opID, dbMsgTypeError, []byte(err.Error())) + return false + } + + for r := range it.Next { + data, err := r.Marshal(r, record.JSON) + if err != nil { + api.send(opID, dbMsgTypeWarning, []byte(err.Error())) + } + api.send(opID, dbMsgTypeOk, data) + } + if it.Error != nil { + api.send(opID, dbMsgTypeError, []byte(err.Error())) + return false + } + + api.send(opID, dbMsgTypeDone, nil) + return true +} + +// func (api *DatabaseAPI) runQuery() + +func (api *DatabaseAPI) handleSub(opID []byte, queryText string) { + // 125|sub| + // 125|upd|| + // 125|new|| + // 125|delete| + // 125|warning| // does not cancel the subscription + var err error + + q, err := query.ParseQuery(queryText) + if err != nil { + api.send(opID, dbMsgTypeError, []byte(err.Error())) + return + } + + sub, ok := api.registerSub(opID, q) + if !ok { + return + } + api.processSub(opID, sub) +} + +func (api *DatabaseAPI) registerSub(opID []byte, q *query.Query) (sub *database.Subscription, ok bool) { + var err error + sub, err = api.db.Subscribe(q) + if err != nil { + api.send(opID, dbMsgTypeError, []byte(err.Error())) + return nil, false + } + return sub, true +} + +func (api *DatabaseAPI) processSub(opID []byte, sub *database.Subscription) { + for r := range sub.Feed { + data, err := r.Marshal(r, record.JSON) + if err != nil { + api.send(opID, dbMsgTypeWarning, []byte(err.Error())) + } + // TODO: use upd, new and delete msgTypes + api.send(opID, dbMsgTypeOk, data) + } + if sub.Err != nil { + api.send(opID, dbMsgTypeError, []byte(sub.Err.Error())) + } +} + +func (api *DatabaseAPI) handleQsub(opID []byte, queryText string) { + // 127|qsub| + // 127|ok|| + // 127|done + // 127|error| + // 127|upd|| + // 127|new|| + // 127|delete| + // 127|warning| // does not cancel the subscription + + var err error + + q, err := query.ParseQuery(queryText) + if err != nil { + api.send(opID, dbMsgTypeError, []byte(err.Error())) + return + } + + sub, ok := api.registerSub(opID, q) + if !ok { + return + } + ok = api.processQuery(opID, q) + if !ok { + return + } + api.processSub(opID, sub) +} + +func (api *DatabaseAPI) handleCreate(opID []byte, key string, data []byte) { + // 128|create|| + // 128|success + // 128|error| +} +func (api *DatabaseAPI) handleUpdate(opID []byte, key string, data []byte) { + // 129|update|| + // 129|success + // 129|error| +} +func (api *DatabaseAPI) handleInsert(opID []byte, key string, data []byte) { + // 130|insert|| + // 130|success + // 130|error| +} + +func (api *DatabaseAPI) shutdown() { + if api.shuttingDown.SetToIf(false, true) { + close(api.shutdownSignal) + api.conn.Close() + } +} diff --git a/api/actions.go b/api/old/actions.go similarity index 100% rename from api/actions.go rename to api/old/actions.go diff --git a/api/api.go b/api/old/api.go similarity index 100% rename from api/api.go rename to api/old/api.go diff --git a/api/handlers.go b/api/old/handlers.go similarity index 100% rename from api/handlers.go rename to api/old/handlers.go diff --git a/api/logger.go b/api/old/logger.go similarity index 100% rename from api/logger.go rename to api/old/logger.go diff --git a/api/old/router.go b/api/old/router.go new file mode 100644 index 0000000..01b06d1 --- /dev/null +++ b/api/old/router.go @@ -0,0 +1,28 @@ +// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. + +package api + +import ( + "net/http" + + "github.com/gorilla/mux" +) + +func NewRouter() *mux.Router { + router := mux.NewRouter().StrictSlash(true) + + for _, route := range routes { + var handler http.Handler + + handler = route.Handler + handler = Logger(handler, route.Name) + + router. + Methods(route.Method). + PathPrefix(route.Path). + Name(route.Name). + Handler(handler) + } + + return router +} diff --git a/api/routes.go b/api/old/routes.go similarity index 100% rename from api/routes.go rename to api/old/routes.go diff --git a/api/session.go b/api/old/session.go similarity index 100% rename from api/session.go rename to api/old/session.go diff --git a/api/writer.go b/api/old/writer.go similarity index 100% rename from api/writer.go rename to api/old/writer.go diff --git a/api/router.go b/api/router.go index 01b06d1..c06a5e8 100644 --- a/api/router.go +++ b/api/router.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package api import ( @@ -8,21 +6,24 @@ import ( "github.com/gorilla/mux" ) -func NewRouter() *mux.Router { - router := mux.NewRouter().StrictSlash(true) +var ( + additionalRoutes map[string]func(arg1 http.ResponseWriter, arg2 *http.Request) +) - for _, route := range routes { - var handler http.Handler +func RegisterAdditionalRoute(path string, handleFunc func(arg1 http.ResponseWriter, arg2 *http.Request)) { + if additionalRoutes == nil { + additionalRoutes = make(map[string]func(arg1 http.ResponseWriter, arg2 *http.Request)) + } + additionalRoutes[path] = handleFunc +} - handler = route.Handler - handler = Logger(handler, route.Name) +func Serve() { - router. - Methods(route.Method). - PathPrefix(route.Path). - Name(route.Name). - Handler(handler) + router := mux.NewRouter() + router.HandleFunc("/api/database/v1", startDatabaseAPI) + + for path, handleFunc := range additionalRoutes { + router.HandleFunc(path, handleFunc) } - return router } diff --git a/api/security.go b/api/security.go new file mode 100644 index 0000000..778f64e --- /dev/null +++ b/api/security.go @@ -0,0 +1 @@ +package api diff --git a/api/websocket.go b/api/websocket.go new file mode 100644 index 0000000..778f64e --- /dev/null +++ b/api/websocket.go @@ -0,0 +1 @@ +package api diff --git a/database/controller.go b/database/controller.go index 438072f..02e8d63 100644 --- a/database/controller.go +++ b/database/controller.go @@ -129,7 +129,7 @@ func (c *Controller) Put(r record.Record) (err error) { // process subscriptions for _, sub := range c.subscriptions { - if sub.q.Matches(r) { + if r.Meta().CheckPermission(sub.local, sub.internal) && sub.q.Matches(r) { select { case sub.Feed <- r: default: diff --git a/database/database_test.go b/database/database_test.go index 4cc94c3..e997760 100644 --- a/database/database_test.go +++ b/database/database_test.go @@ -36,15 +36,15 @@ func testDatabase(t *testing.T, storageType string) { t.Fatal(err) } + // interface + db := NewInterface(nil) + // sub - sub, err := Subscribe(q.New(dbName).MustBeValid()) + sub, err := db.Subscribe(q.New(dbName).MustBeValid()) if err != nil { t.Fatal(err) } - // interface - db := NewInterface(nil) - A := NewExample(makeKey(dbName, "A"), "Herbert", 411) err = A.Save() if err != nil { diff --git a/database/interface.go b/database/interface.go index de5910c..8db89d5 100644 --- a/database/interface.go +++ b/database/interface.go @@ -228,6 +228,11 @@ func (i *Interface) Delete(key string) error { // Query executes the given query on the database. func (i *Interface) Query(q *query.Query) (*iterator.Iterator, error) { + _, err := q.Check() + if err != nil { + return nil, err + } + db, err := getController(q.DatabaseName()) if err != nil { return nil, err @@ -235,3 +240,30 @@ func (i *Interface) Query(q *query.Query) (*iterator.Iterator, error) { return db.Query(q, i.options.Local, i.options.Internal) } + +// Subscribe subscribes to updates matching the given query. +func (i *Interface) Subscribe(q *query.Query) (*Subscription, error) { + _, err := q.Check() + if err != nil { + return nil, err + } + + c, err := getController(q.DatabaseName()) + if err != nil { + return nil, err + } + + c.readLock.Lock() + defer c.readLock.Unlock() + c.writeLock.Lock() + defer c.writeLock.Unlock() + + sub := &Subscription{ + q: q, + local: i.options.Local, + internal: i.options.Internal, + Feed: make(chan record.Record, 100), + } + c.subscriptions = append(c.subscriptions, sub) + return sub, nil +} diff --git a/database/subscription.go b/database/subscription.go index d95ac94..478ed96 100644 --- a/database/subscription.go +++ b/database/subscription.go @@ -7,36 +7,15 @@ import ( // Subscription is a database subscription for updates. type Subscription struct { - q *query.Query + q *query.Query + local bool + internal bool + canceled bool + Feed chan record.Record Err error } -// Subscribe subscribes to updates matching the given query. -func Subscribe(q *query.Query) (*Subscription, error) { - _, err := q.Check() - if err != nil { - return nil, err - } - - c, err := getController(q.DatabaseName()) - if err != nil { - return nil, err - } - - c.readLock.Lock() - defer c.readLock.Unlock() - c.writeLock.Lock() - defer c.writeLock.Unlock() - - sub := &Subscription{ - q: q, - Feed: make(chan record.Record, 100), - } - c.subscriptions = append(c.subscriptions, sub) - return sub, nil -} - // Cancel cancels the subscription. func (s *Subscription) Cancel() error { c, err := getController(s.q.DatabaseName()) @@ -49,6 +28,12 @@ func (s *Subscription) Cancel() error { c.writeLock.Lock() defer c.writeLock.Unlock() + if s.canceled { + return nil + } + s.canceled = true + close(s.Feed) + for key, sub := range c.subscriptions { if sub.q == s.q { c.subscriptions = append(c.subscriptions[:key], c.subscriptions[key+1:]...)