diff --git a/api/database.go b/api/database.go index 862feb4..d703945 100644 --- a/api/database.go +++ b/api/database.go @@ -285,8 +285,34 @@ func (api *DatabaseAPI) processQuery(opID []byte, q *query.Query) (ok bool) { return false } - api.send(opID, dbMsgTypeDone, emptyString, nil) - return true + for { + select { + case <-api.shutdownSignal: + // cancel query and return + it.Cancel() + return + case r := <-it.Next: + // process query feed + if r != nil { + // process record + r.Lock() + data, err := r.Marshal(r, record.JSON) + r.Unlock() + if err != nil { + api.send(opID, dbMsgTypeWarning, err.Error(), nil) + } + api.send(opID, dbMsgTypeOk, r.Key(), data) + } else { + // sub feed ended + if it.Err() != nil { + api.send(opID, dbMsgTypeError, it.Err().Error(), nil) + return false + } + api.send(opID, dbMsgTypeDone, emptyString, nil) + return true + } + } + } } // func (api *DatabaseAPI) runQuery() @@ -323,23 +349,36 @@ func (api *DatabaseAPI) registerSub(opID []byte, q *query.Query) (sub *database. } func (api *DatabaseAPI) processSub(opID []byte, sub *database.Subscription) { - for r := range sub.Feed { - r.Lock() - data, err := r.Marshal(r, record.JSON) - r.Unlock() - if err != nil { - api.send(opID, dbMsgTypeWarning, err.Error(), nil) - continue + for { + select { + case <-api.shutdownSignal: + // cancel sub and return + sub.Cancel() + return + case r := <-sub.Feed: + // process sub feed + if r != nil { + // process record + r.Lock() + data, err := r.Marshal(r, record.JSON) + r.Unlock() + if err != nil { + api.send(opID, dbMsgTypeWarning, err.Error(), nil) + continue + } + // TODO: use upd, new and delete msgTypes + if r.Meta().IsDeleted() { + api.send(opID, dbMsgTypeDel, r.Key(), nil) + } else { + api.send(opID, dbMsgTypeUpd, r.Key(), data) + } + } else { + // sub feed ended + if sub.Err != nil { + api.send(opID, dbMsgTypeError, sub.Err.Error(), nil) + } + } } - // TODO: use upd, new and delete msgTypes - if r.Meta().IsDeleted() { - api.send(opID, dbMsgTypeDel, r.Key(), nil) - } else { - api.send(opID, dbMsgTypeUpd, r.Key(), data) - } - } - if sub.Err != nil { - api.send(opID, dbMsgTypeError, sub.Err.Error(), nil) } } diff --git a/api/router.go b/api/router.go index a2be39a..5e4e42e 100644 --- a/api/router.go +++ b/api/router.go @@ -9,20 +9,18 @@ import ( ) var ( - additionalRoutes map[string]http.Handler + router = mux.NewRouter() ) -// RegisterAdditionalRoute registers an additional route with the API endoint. -func RegisterAdditionalRoute(path string, handler http.Handler) { - if additionalRoutes == nil { - additionalRoutes = make(map[string]http.Handler) - } - additionalRoutes[path] = handler +// RegisterHandleFunc registers an additional handle function with the API endoint. +func RegisterHandleFunc(path string, handleFunc func(http.ResponseWriter, *http.Request)) *mux.Route { + return router.HandleFunc(path, handleFunc) } // RequestLogger is a logging middleware func RequestLogger(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Tracef("api request: %s ___ %s", r.RemoteAddr, r.RequestURI) ew := NewEnrichedResponseWriter(w) next.ServeHTTP(ew, r) log.Infof("api request: %s %d %s", r.RemoteAddr, ew.Status, r.RequestURI) @@ -31,20 +29,13 @@ func RequestLogger(next http.Handler) http.Handler { // Serve starts serving the API endpoint. func Serve() { - - router := mux.NewRouter() - // router.HandleFunc("/api/database/v1", startDatabaseAPI) - - for path, handler := range additionalRoutes { - router.Handle(path, handler) - } - router.Use(RequestLogger) - http.Handle("/", router) - http.HandleFunc("/api/database/v1", startDatabaseAPI) + mainMux := http.NewServeMux() + mainMux.Handle("/", router) // net/http pattern matching /* + mainMux.HandleFunc("/api/database/v1", startDatabaseAPI) // net/http pattern matching only this exact path address := getListenAddress() log.Infof("api: starting to listen on %s", address) - log.Errorf("api: failed to listen on %s: %s", address, http.ListenAndServe(address, nil)) + log.Errorf("api: failed to listen on %s: %s", address, http.ListenAndServe(address, mainMux)) } diff --git a/database/database_test.go b/database/database_test.go index 99bef6c..462d054 100644 --- a/database/database_test.go +++ b/database/database_test.go @@ -131,7 +131,11 @@ func TestDatabaseSystem(t *testing.T) { t.Fatal(err) } - err = Initialize(testDir) + ok := SetLocation(testDir) + if !ok { + t.Fatal("database location already set") + } + err = Initialize() if err != nil { t.Fatal(err) } diff --git a/database/dbmodule/db.go b/database/dbmodule/db.go index 81f7800..351518b 100644 --- a/database/dbmodule/db.go +++ b/database/dbmodule/db.go @@ -15,6 +15,11 @@ var ( maintenanceWg sync.WaitGroup ) +// SetDatabaseLocation sets the location of the database. Must be called before modules.Start and will be overridden by command line options. Intended for unit tests. +func SetDatabaseLocation(location string) { + databaseDir = location +} + func init() { flag.StringVar(&databaseDir, "db", "", "set database directory") @@ -25,11 +30,15 @@ func prep() error { if databaseDir == "" { return errors.New("no database location specified, set with `-db=/path/to/db`") } + ok := database.SetLocation(databaseDir) + if !ok { + return errors.New("database location already set") + } return nil } func start() error { - err := database.Initialize(databaseDir) + err := database.Initialize() if err == nil { startMaintainer() } diff --git a/database/main.go b/database/main.go index 7bb781d..943d4be 100644 --- a/database/main.go +++ b/database/main.go @@ -15,10 +15,18 @@ var ( shutdownSignal = make(chan struct{}) ) -// Initialize initialized the database -func Initialize(location string) error { - if initialized.SetToIf(false, true) { +// SetLocation sets the location of the database. This is separate from the initialization to provide the location to other modules earlier. +func SetLocation(location string) (ok bool) { + if !initialized.IsSet() && rootDir == "" { rootDir = location + return true + } + return false +} + +// Initialize initialized the database +func Initialize() error { + if initialized.SetToIf(false, true) { err := ensureDirectory(rootDir) if err != nil {