diff --git a/api/config.go b/api/config.go new file mode 100644 index 0000000..420572c --- /dev/null +++ b/api/config.go @@ -0,0 +1,49 @@ +package api + +import ( + "flag" + + "github.com/Safing/portbase/config" + "github.com/Safing/portbase/log" +) + +var ( + listenAddressFlag string + listenAddressConfig config.StringOption +) + +func init() { + flag.StringVar(&listenAddressFlag, "api-address", "", "override api listen address") +} + +func checkFlags() error { + if listenAddressFlag != "" { + log.Warning("api: api/listenAddress config is being overridden by -api-address flag") + } + return nil +} + +func getListenAddress() string { + if listenAddressFlag != "" { + return listenAddressFlag + } + return listenAddressConfig() +} + +func registerConfig() error { + err := config.Register(&config.Option{ + Name: "API Address", + Key: "api/listenAddress", + Description: "Define on what IP and port the API should listen on. Be careful, changing this may become a security issue.", + ExpertiseLevel: config.ExpertiseLevelExpert, + OptType: config.OptTypeString, + DefaultValue: "127.0.0.1:18", + ValidationRegex: "^([0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}:[0-9]{1,5}|\\[[:0-9A-Fa-f]+\\]:[0-9]{1,5})$", + }) + if err != nil { + return err + } + listenAddressConfig = config.GetAsString("api/listenAddress", "127.0.0.1:18") + + return nil +} diff --git a/api/database.go b/api/database.go index ce815db..d703945 100644 --- a/api/database.go +++ b/api/database.go @@ -27,12 +27,12 @@ const ( dbMsgTypeDel = "del" dbMsgTypeWarning = "warning" - dbApiSeperator = "|" + dbAPISeperator = "|" emptyString = "" ) var ( - dbApiSeperatorBytes = []byte(dbApiSeperator) + dbAPISeperatorBytes = []byte(dbAPISeperator) ) // DatabaseAPI is a database API instance. @@ -76,6 +76,8 @@ func startDatabaseAPI(w http.ResponseWriter, r *http.Request) { go new.handler() go new.writer() + + log.Infof("api request: init websocket %s %s", r.RemoteAddr, r.RequestURI) } func (api *DatabaseAPI) handler() { @@ -210,16 +212,16 @@ func (api *DatabaseAPI) writer() { func (api *DatabaseAPI) send(opID []byte, msgType string, msgOrKey string, data []byte) { c := container.New(opID) - c.Append(dbApiSeperatorBytes) + c.Append(dbAPISeperatorBytes) c.Append([]byte(msgType)) if msgOrKey != emptyString { - c.Append(dbApiSeperatorBytes) + c.Append(dbAPISeperatorBytes) c.Append([]byte(msgOrKey)) } if len(data) > 0 { - c.Append(dbApiSeperatorBytes) + c.Append(dbAPISeperatorBytes) c.Append(data) } @@ -270,19 +272,47 @@ func (api *DatabaseAPI) processQuery(opID []byte, q *query.Query) (ok bool) { } for r := range it.Next { + r.Lock() data, err := r.Marshal(r, record.JSON) + r.Unlock() if err != nil { api.send(opID, dbMsgTypeWarning, err.Error(), nil) } api.send(opID, dbMsgTypeOk, r.Key(), data) } - if it.Err != nil { - api.send(opID, dbMsgTypeError, it.Err.Error(), nil) + if it.Err() != nil { + api.send(opID, dbMsgTypeError, it.Err().Error(), nil) return false } - api.send(opID, dbMsgTypeDone, emptyString, nil) - return true + for { + select { + case <-api.shutdownSignal: + // cancel query and return + it.Cancel() + return + case r := <-it.Next: + // process query feed + if r != nil { + // process record + r.Lock() + data, err := r.Marshal(r, record.JSON) + r.Unlock() + if err != nil { + api.send(opID, dbMsgTypeWarning, err.Error(), nil) + } + api.send(opID, dbMsgTypeOk, r.Key(), data) + } else { + // sub feed ended + if it.Err() != nil { + api.send(opID, dbMsgTypeError, it.Err().Error(), nil) + return false + } + api.send(opID, dbMsgTypeDone, emptyString, nil) + return true + } + } + } } // func (api *DatabaseAPI) runQuery() @@ -319,20 +349,36 @@ func (api *DatabaseAPI) registerSub(opID []byte, q *query.Query) (sub *database. } func (api *DatabaseAPI) processSub(opID []byte, sub *database.Subscription) { - for r := range sub.Feed { - data, err := r.Marshal(r, record.JSON) - if err != nil { - api.send(opID, dbMsgTypeWarning, err.Error(), nil) + for { + select { + case <-api.shutdownSignal: + // cancel sub and return + sub.Cancel() + return + case r := <-sub.Feed: + // process sub feed + if r != nil { + // process record + r.Lock() + data, err := r.Marshal(r, record.JSON) + r.Unlock() + if err != nil { + api.send(opID, dbMsgTypeWarning, err.Error(), nil) + continue + } + // TODO: use upd, new and delete msgTypes + if r.Meta().IsDeleted() { + api.send(opID, dbMsgTypeDel, r.Key(), nil) + } else { + api.send(opID, dbMsgTypeUpd, r.Key(), data) + } + } else { + // sub feed ended + if sub.Err != nil { + api.send(opID, dbMsgTypeError, sub.Err.Error(), nil) + } + } } - // TODO: use upd, new and delete msgTypes - if r.Meta().Deleted > 0 { - api.send(opID, dbMsgTypeDel, r.Key(), nil) - } else { - api.send(opID, dbMsgTypeUpd, r.Key(), data) - } - } - if sub.Err != nil { - api.send(opID, dbMsgTypeError, sub.Err.Error(), nil) } } diff --git a/api/enriched-response.go b/api/enriched-response.go index b34d68e..58fd983 100644 --- a/api/enriched-response.go +++ b/api/enriched-response.go @@ -4,11 +4,13 @@ import ( "net/http" ) +// EnrichedResponseWriter is a wrapper for http.ResponseWriter for better information extraction. type EnrichedResponseWriter struct { http.ResponseWriter Status int } +// NewEnrichedResponseWriter wraps a http.ResponseWriter. func NewEnrichedResponseWriter(w http.ResponseWriter) *EnrichedResponseWriter { return &EnrichedResponseWriter{ w, @@ -16,6 +18,7 @@ func NewEnrichedResponseWriter(w http.ResponseWriter) *EnrichedResponseWriter { } } +// WriteHeader wraps the original WriteHeader method to extract information. func (ew *EnrichedResponseWriter) WriteHeader(code int) { ew.Status = code ew.ResponseWriter.WriteHeader(code) diff --git a/api/main.go b/api/main.go index ba337fe..77e4339 100644 --- a/api/main.go +++ b/api/main.go @@ -5,18 +5,18 @@ import ( ) func init() { - modules.Register("api", prep, start, stop, "database") + modules.Register("api", prep, start, nil, "database") } func prep() error { - return nil + err := checkFlags() + if err != nil { + return err + } + return registerConfig() } func start() error { go Serve() return nil } - -func stop() error { - return nil -} diff --git a/api/router.go b/api/router.go index 138c4c1..5e4e42e 100644 --- a/api/router.go +++ b/api/router.go @@ -9,39 +9,33 @@ import ( ) var ( - additionalRoutes map[string]http.Handler + router = mux.NewRouter() ) -func RegisterAdditionalRoute(path string, handler http.Handler) { - if additionalRoutes == nil { - additionalRoutes = make(map[string]http.Handler) - } - additionalRoutes[path] = handler +// RegisterHandleFunc registers an additional handle function with the API endoint. +func RegisterHandleFunc(path string, handleFunc func(http.ResponseWriter, *http.Request)) *mux.Route { + return router.HandleFunc(path, handleFunc) } +// RequestLogger is a logging middleware func RequestLogger(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Tracef("api request: %s ___ %s", r.RemoteAddr, r.RequestURI) ew := NewEnrichedResponseWriter(w) next.ServeHTTP(ew, r) log.Infof("api request: %s %d %s", r.RemoteAddr, ew.Status, r.RequestURI) }) } +// Serve starts serving the API endpoint. func Serve() { - - router := mux.NewRouter() - // router.HandleFunc("/api/database/v1", startDatabaseAPI) - - for path, handler := range additionalRoutes { - router.Handle(path, handler) - } - router.Use(RequestLogger) - http.Handle("/", router) - http.HandleFunc("/api/database/v1", startDatabaseAPI) + mainMux := http.NewServeMux() + mainMux.Handle("/", router) // net/http pattern matching /* + mainMux.HandleFunc("/api/database/v1", startDatabaseAPI) // net/http pattern matching only this exact path - address := "127.0.0.1:18" + address := getListenAddress() log.Infof("api: starting to listen on %s", address) - log.Errorf("api: failed to listen on %s: %s", address, http.ListenAndServe(address, nil)) + log.Errorf("api: failed to listen on %s: %s", address, http.ListenAndServe(address, mainMux)) } diff --git a/build b/build index bbde1f9..8c730c0 100755 --- a/build +++ b/build @@ -49,4 +49,4 @@ echo "Run the compiled binary with the -version flag to see the information incl # build BUILD_PATH="github.com/Safing/portbase/info" -go build -ldflags "-X ${BUILD_PATH}.commit=${BUILD_COMMIT} -X ${BUILD_PATH}.buildOptions=${BUILD_BUILDOPTIONS} -X ${BUILD_PATH}.buildUser=${BUILD_USER} -X ${BUILD_PATH}.buildHost=${BUILD_HOST} -X ${BUILD_PATH}.buildDate=${BUILD_DATE} -X ${BUILD_PATH}.buildSource=${BUILD_SOURCE}" $* +go build -ldflags "-X ${BUILD_PATH}.commit=${BUILD_COMMIT} -X ${BUILD_PATH}.buildOptions=${BUILD_BUILDOPTIONS} -X ${BUILD_PATH}.buildUser=${BUILD_USER} -X ${BUILD_PATH}.buildHost=${BUILD_HOST} -X ${BUILD_PATH}.buildDate=${BUILD_DATE} -X ${BUILD_PATH}.buildSource=${BUILD_SOURCE}" $@ diff --git a/database/controller.go b/database/controller.go index d981fc3..6f70177 100644 --- a/database/controller.go +++ b/database/controller.go @@ -13,22 +13,27 @@ import ( // A Controller takes care of all the extra database logic. type Controller struct { - storage storage.Interface + storage storage.Interface - hooks []*RegisteredHook + hooks []*RegisteredHook subscriptions []*Subscription writeLock sync.RWMutex - readLock sync.RWMutex - migrating *abool.AtomicBool // TODO + // Lock: nobody may write + // RLock: concurrent writing + readLock sync.RWMutex + // Lock: nobody may read + // RLock: concurrent reading + + migrating *abool.AtomicBool // TODO hibernating *abool.AtomicBool // TODO } // newController creates a new controller for a storage. func newController(storageInt storage.Interface) (*Controller, error) { return &Controller{ - storage: storageInt, - migrating: abool.NewBool(false), + storage: storageInt, + migrating: abool.NewBool(false), hibernating: abool.NewBool(false), }, nil } @@ -101,9 +106,6 @@ func (c *Controller) Put(r record.Record) (err error) { return ErrReadOnly } - r.Lock() - defer r.Unlock() - // process hooks for _, hook := range c.hooks { if hook.h.UsesPrePut() && hook.q.Matches(r) { @@ -160,6 +162,9 @@ func (c *Controller) Query(q *query.Query, local, internal bool) (*iterator.Iter // PushUpdate pushes a record update to subscribers. func (c *Controller) PushUpdate(r record.Record) { if c != nil { + c.readLock.RLock() + defer c.readLock.RUnlock() + for _, sub := range c.subscriptions { if r.Meta().CheckPermission(sub.local, sub.internal) && sub.q.Matches(r) { select { @@ -171,19 +176,28 @@ func (c *Controller) PushUpdate(r record.Record) { } } +func (c *Controller) addSubscription(sub *Subscription) { + c.readLock.Lock() + defer c.readLock.Unlock() + c.writeLock.Lock() + defer c.writeLock.Unlock() + + c.subscriptions = append(c.subscriptions, sub) +} + func (c *Controller) readUnlockerAfterQuery(it *iterator.Iterator) { - <- it.Done + <-it.Done c.readLock.RUnlock() } -// Maintain runs the Maintain method no the storage. +// Maintain runs the Maintain method on the storage. func (c *Controller) Maintain() error { c.writeLock.RLock() defer c.writeLock.RUnlock() return c.storage.Maintain() } -// MaintainThorough runs the MaintainThorough method no the storage. +// MaintainThorough runs the MaintainThorough method on the storage. func (c *Controller) MaintainThorough() error { c.writeLock.RLock() defer c.writeLock.RUnlock() diff --git a/database/database_test.go b/database/database_test.go index 9aa5e7d..462d054 100644 --- a/database/database_test.go +++ b/database/database_test.go @@ -98,8 +98,8 @@ func testDatabase(t *testing.T, storageType string) { for _ = range it.Next { cnt++ } - if it.Err != nil { - t.Fatal(it.Err) + if it.Err() != nil { + t.Fatal(it.Err()) } if cnt != 2 { t.Fatal("expected two records") @@ -131,7 +131,11 @@ func TestDatabaseSystem(t *testing.T) { t.Fatal(err) } - err = Initialize(testDir) + ok := SetLocation(testDir) + if !ok { + t.Fatal("database location already set") + } + err = Initialize() if err != nil { t.Fatal(err) } diff --git a/database/dbmodule/db.go b/database/dbmodule/db.go index 1b7d1d8..351518b 100644 --- a/database/dbmodule/db.go +++ b/database/dbmodule/db.go @@ -15,6 +15,11 @@ var ( maintenanceWg sync.WaitGroup ) +// SetDatabaseLocation sets the location of the database. Must be called before modules.Start and will be overridden by command line options. Intended for unit tests. +func SetDatabaseLocation(location string) { + databaseDir = location +} + func init() { flag.StringVar(&databaseDir, "db", "", "set database directory") @@ -25,13 +30,17 @@ func prep() error { if databaseDir == "" { return errors.New("no database location specified, set with `-db=/path/to/db`") } + ok := database.SetLocation(databaseDir) + if !ok { + return errors.New("database location already set") + } return nil } func start() error { - err := database.Initialize(databaseDir) + err := database.Initialize() if err == nil { - go maintainer() + startMaintainer() } return err } diff --git a/database/dbmodule/maintenance.go b/database/dbmodule/maintenance.go index 4bdd6fe..0c403cb 100644 --- a/database/dbmodule/maintenance.go +++ b/database/dbmodule/maintenance.go @@ -1,36 +1,45 @@ package dbmodule import ( - "time" + "time" - "github.com/Safing/portbase/database" - "github.com/Safing/portbase/log" + "github.com/Safing/portbase/database" + "github.com/Safing/portbase/log" ) -func maintainer() { - ticker := time.NewTicker(10 * time.Minute) - longTicker := time.NewTicker(1 * time.Hour) - maintenanceWg.Add(1) +var ( + maintenanceShortTickDuration = 10 * time.Minute + maintenanceLongTickDuration = 1 * time.Hour +) - for { - select { - case <- ticker.C: - err := database.Maintain() - if err != nil { - log.Errorf("database: maintenance error: %s", err) - } - case <- longTicker.C: - err := database.MaintainRecordStates() - if err != nil { - log.Errorf("database: record states maintenance error: %s", err) - } - err = database.MaintainThorough() - if err != nil { - log.Errorf("database: thorough maintenance error: %s", err) - } - case <-shutdownSignal: - maintenanceWg.Done() - return - } - } +func startMaintainer() { + maintenanceWg.Add(1) + go maintenanceWorker() +} + +func maintenanceWorker() { + ticker := time.NewTicker(maintenanceShortTickDuration) + longTicker := time.NewTicker(maintenanceLongTickDuration) + + for { + select { + case <-ticker.C: + err := database.Maintain() + if err != nil { + log.Errorf("database: maintenance error: %s", err) + } + case <-longTicker.C: + err := database.MaintainRecordStates() + if err != nil { + log.Errorf("database: record states maintenance error: %s", err) + } + err = database.MaintainThorough() + if err != nil { + log.Errorf("database: thorough maintenance error: %s", err) + } + case <-shutdownSignal: + maintenanceWg.Done() + return + } + } } diff --git a/database/interface.go b/database/interface.go index e4422e5..69063bc 100644 --- a/database/interface.go +++ b/database/interface.go @@ -178,6 +178,9 @@ func (i *Interface) Put(r record.Record) error { return err } + r.Lock() + defer r.Unlock() + i.options.Apply(r) i.updateCache(r) @@ -191,6 +194,12 @@ func (i *Interface) PutNew(r record.Record) error { return err } + r.Lock() + defer r.Unlock() + + if r.Meta() == nil { + r.SetMeta(&record.Meta{}) + } r.Meta().Reset() i.options.Apply(r) i.updateCache(r) @@ -296,17 +305,12 @@ func (i *Interface) Subscribe(q *query.Query) (*Subscription, error) { return nil, err } - c.readLock.Lock() - defer c.readLock.Unlock() - c.writeLock.Lock() - defer c.writeLock.Unlock() - sub := &Subscription{ q: q, local: i.options.Local, internal: i.options.Internal, Feed: make(chan record.Record, 100), } - c.subscriptions = append(c.subscriptions, sub) + c.addSubscription(sub) return sub, nil } diff --git a/database/iterator/iterator.go b/database/iterator/iterator.go index 70451ea..6abf766 100644 --- a/database/iterator/iterator.go +++ b/database/iterator/iterator.go @@ -1,6 +1,10 @@ package iterator import ( + "sync" + + "github.com/tevino/abool" + "github.com/Safing/portbase/database/record" ) @@ -8,19 +12,43 @@ import ( type Iterator struct { Next chan record.Record Done chan struct{} - Err error + + errLock sync.Mutex + err error + doneClosed *abool.AtomicBool } // New creates a new Iterator. func New() *Iterator { return &Iterator{ - Next: make(chan record.Record, 10), - Done: make(chan struct{}), + Next: make(chan record.Record, 10), + Done: make(chan struct{}), + doneClosed: abool.NewBool(false), } } +// Finish is called be the storage to signal the end of the query results. func (it *Iterator) Finish(err error) { close(it.Next) - close(it.Done) - it.Err = err + if it.doneClosed.SetToIf(false, true) { + close(it.Done) + } + + it.errLock.Lock() + defer it.errLock.Unlock() + it.err = err +} + +// Cancel is called by the iteration consumer to cancel the running query. +func (it *Iterator) Cancel() { + if it.doneClosed.SetToIf(false, true) { + close(it.Done) + } +} + +// Err returns the iterator error, if exists. +func (it *Iterator) Err() error { + it.errLock.Lock() + defer it.errLock.Unlock() + return it.err } diff --git a/database/main.go b/database/main.go index 7bb781d..943d4be 100644 --- a/database/main.go +++ b/database/main.go @@ -15,10 +15,18 @@ var ( shutdownSignal = make(chan struct{}) ) -// Initialize initialized the database -func Initialize(location string) error { - if initialized.SetToIf(false, true) { +// SetLocation sets the location of the database. This is separate from the initialization to provide the location to other modules earlier. +func SetLocation(location string) (ok bool) { + if !initialized.IsSet() && rootDir == "" { rootDir = location + return true + } + return false +} + +// Initialize initialized the database +func Initialize() error { + if initialized.SetToIf(false, true) { err := ensureDirectory(rootDir) if err != nil { diff --git a/database/maintenance.go b/database/maintenance.go index 05d8182..afce639 100644 --- a/database/maintenance.go +++ b/database/maintenance.go @@ -1,92 +1,92 @@ package database import ( - "time" + "time" - "github.com/Safing/portbase/database/query" - "github.com/Safing/portbase/database/record" + "github.com/Safing/portbase/database/query" + "github.com/Safing/portbase/database/record" ) // Maintain runs the Maintain method on all storages. func Maintain() (err error) { - controllers := duplicateControllers() - for _, c := range controllers { - err = c.Maintain() - if err != nil { - return - } - } - return + controllers := duplicateControllers() + for _, c := range controllers { + err = c.Maintain() + if err != nil { + return + } + } + return } // MaintainThorough runs the MaintainThorough method on all storages. func MaintainThorough() (err error) { - all := duplicateControllers() - for _, c := range all { - err = c.MaintainThorough() - if err != nil { - return - } - } - return + all := duplicateControllers() + for _, c := range all { + err = c.MaintainThorough() + if err != nil { + return + } + } + return } // MaintainRecordStates runs record state lifecycle maintenance on all storages. func MaintainRecordStates() error { - all := duplicateControllers() - now := time.Now().Unix() - thirtyDaysAgo := time.Now().Add(-30*24*time.Hour).Unix() + all := duplicateControllers() + now := time.Now().Unix() + thirtyDaysAgo := time.Now().Add(-30 * 24 * time.Hour).Unix() - for _, c := range all { + for _, c := range all { - if c.ReadOnly() || c.Injected() { - continue - } + if c.ReadOnly() || c.Injected() { + continue + } - q, err := query.New("").Check() - if err != nil { - return err - } + q, err := query.New("").Check() + if err != nil { + return err + } - it, err := c.Query(q, true, true) - if err != nil { - return err - } + it, err := c.Query(q, true, true) + if err != nil { + return err + } - var toDelete []record.Record - var toExpire []record.Record + var toDelete []record.Record + var toExpire []record.Record - for r := range it.Next { - switch { - case r.Meta().Deleted < thirtyDaysAgo: - toDelete = append(toDelete, r) - case r.Meta().Expires < now: - toExpire = append(toExpire, r) - } - } - if it.Err != nil { - return err - } + for r := range it.Next { + switch { + case r.Meta().Deleted < thirtyDaysAgo: + toDelete = append(toDelete, r) + case r.Meta().Expires < now: + toExpire = append(toExpire, r) + } + } + if it.Err() != nil { + return err + } - for _, r := range toDelete { - c.storage.Delete(r.DatabaseKey()) - } - for _, r := range toExpire { - r.Meta().Delete() - return c.Put(r) - } + for _, r := range toDelete { + c.storage.Delete(r.DatabaseKey()) + } + for _, r := range toExpire { + r.Meta().Delete() + return c.Put(r) + } - } - return nil + } + return nil } func duplicateControllers() (all []*Controller) { - controllersLock.Lock() - defer controllersLock.Unlock() + controllersLock.Lock() + defer controllersLock.Unlock() - for _, c := range controllers { - all = append(all, c) - } + for _, c := range controllers { + all = append(all, c) + } - return + return } diff --git a/database/record/base.go b/database/record/base.go index 0c7597f..340ef37 100644 --- a/database/record/base.go +++ b/database/record/base.go @@ -21,6 +21,11 @@ func (b *Base) Key() string { return fmt.Sprintf("%s:%s", b.dbName, b.dbKey) } +// KeyIsSet returns true if the database key is set. +func (b *Base) KeyIsSet() bool { + return len(b.dbName) > 0 && len(b.dbKey) > 0 +} + // DatabaseName returns the name of the database. func (b *Base) DatabaseName() string { return b.dbName @@ -47,6 +52,11 @@ func (b *Base) Meta() *Meta { return b.meta } +// CreateMeta sets a default metadata object for this record. +func (b *Base) CreateMeta() { + b.meta = &Meta{} +} + // SetMeta sets the metadata on the database record, it should only be called after loading the record. Use MoveTo to save the record with another key. func (b *Base) SetMeta(meta *Meta) { b.meta = meta diff --git a/database/record/meta.go b/database/record/meta.go index a2c8bb4..33a529b 100644 --- a/database/record/meta.go +++ b/database/record/meta.go @@ -78,6 +78,11 @@ func (m *Meta) Delete() { m.Deleted = time.Now().Unix() } +// IsDeleted returns whether the record is deleted. +func (m *Meta) IsDeleted() bool { + return m.Deleted > 0 +} + // CheckValidity checks whether the database record is valid. func (m *Meta) CheckValidity() (valid bool) { switch { diff --git a/database/record/record.go b/database/record/record.go index 3895292..68082de 100644 --- a/database/record/record.go +++ b/database/record/record.go @@ -6,7 +6,8 @@ import ( // Record provides an interface for uniformally handling database records. type Record interface { - Key() string // test:config + Key() string // test:config + KeyIsSet() bool DatabaseName() string // test DatabaseKey() string // config diff --git a/database/storage/badger/badger.go b/database/storage/badger/badger.go index 3df5a37..add2867 100644 --- a/database/storage/badger/badger.go +++ b/database/storage/badger/badger.go @@ -121,7 +121,11 @@ func (b *Badger) queryExecutor(queryIter *iterator.Iterator, q *query.Query, loc for it.Seek(prefix); it.ValidForPrefix(prefix); it.Next() { item := it.Item() - data, err := item.Value() + var data []byte + err := item.Value(func(val []byte) error { + data = val + return nil + }) if err != nil { return err } @@ -148,10 +152,14 @@ func (b *Badger) queryExecutor(queryIter *iterator.Iterator, q *query.Query, loc return err } select { + case <-queryIter.Done: + return nil case queryIter.Next <- new: default: select { case queryIter.Next <- new: + case <-queryIter.Done: + return nil case <-time.After(1 * time.Minute): return errors.New("query timeout") } @@ -162,11 +170,7 @@ func (b *Badger) queryExecutor(queryIter *iterator.Iterator, q *query.Query, loc return nil }) - if err != nil { - queryIter.Err = err - } - close(queryIter.Next) - close(queryIter.Done) + queryIter.Finish(err) } // ReadOnly returns whether the database is read only. diff --git a/modules/modules.go b/modules/modules.go index 5b4f1ba..f9ebb89 100644 --- a/modules/modules.go +++ b/modules/modules.go @@ -24,7 +24,7 @@ var ( type Module struct { Name string Active *abool.AtomicBool - inTransition bool + inTransition *abool.AtomicBool prep func() error start func() error @@ -42,6 +42,7 @@ func Register(name string, prep, start, stop func() error, dependencies ...strin newModule := &Module{ Name: name, Active: abool.NewBool(false), + inTransition: abool.NewBool(false), prep: prep, start: start, stop: stop, diff --git a/modules/modules_test.go b/modules/modules_test.go index 8f3bfbc..20e8236 100644 --- a/modules/modules_test.go +++ b/modules/modules_test.go @@ -5,8 +5,8 @@ package modules import ( "errors" "fmt" - "testing" "sync" + "testing" "time" ) @@ -79,7 +79,7 @@ func TestOrdering(t *testing.T) { func resetModules() { for _, module := range modules { module.Active.UnSet() - module.inTransition = false + module.inTransition.UnSet() } } diff --git a/modules/start.go b/modules/start.go index 798cf70..cc3290a 100644 --- a/modules/start.go +++ b/modules/start.go @@ -72,7 +72,7 @@ moduleLoop: switch { case module.Active.IsSet(): active++ - case module.inTransition: + case module.inTransition.IsSet(): modulesInProgress = true default: for _, depName := range module.dependencies { @@ -116,7 +116,7 @@ func startModules() error { for _, module := range readyToStart { modulesStarting.Add(1) - module.inTransition = true + module.inTransition.Set() nextModule := module // workaround go vet alert go func() { startErr := nextModule.start() @@ -125,7 +125,7 @@ func startModules() error { } else { log.Infof("modules: started %s", nextModule.Name) nextModule.Active.Set() - nextModule.inTransition = false + nextModule.inTransition.UnSet() reports <- nil } modulesStarting.Done() diff --git a/modules/stop.go b/modules/stop.go index 2696aca..553c5d9 100644 --- a/modules/stop.go +++ b/modules/stop.go @@ -9,7 +9,7 @@ import ( ) var ( - shutdownSignal = make(chan struct{}) + shutdownSignal = make(chan struct{}) shutdownSignalClosed = abool.NewBool(false) ) @@ -42,7 +42,7 @@ func checkStopStatus() (readyToStop []*Module, done bool) { // make list out of map, minus modules in transition for _, module := range activeModules { - if !module.inTransition { + if !module.inTransition.IsSet() { readyToStop = append(readyToStop, module) } } @@ -74,7 +74,7 @@ func Shutdown() error { } for _, module := range readyToStop { - module.inTransition = true + module.inTransition.Set() nextModule := module // workaround go vet alert go func() { err := nextModule.stop() @@ -84,7 +84,7 @@ func Shutdown() error { reports <- nil } nextModule.Active.UnSet() - nextModule.inTransition = false + nextModule.inTransition.UnSet() }() } diff --git a/utils/slices.go b/utils/slices.go index c91496f..f185b3b 100644 --- a/utils/slices.go +++ b/utils/slices.go @@ -1,22 +1,25 @@ package utils +// IndexOfString returns the index of given string and -1 if its not part of the slice. +func IndexOfString(a []string, s string) int { + for i, entry := range a { + if entry == s { + return i + } + } + return -1 +} + // StringInSlice returns whether the given string is in the string slice. func StringInSlice(a []string, s string) bool { - for _, entry := range a { - if entry == s { - return true - } - } - return false + return IndexOfString(a, s) >= 0 } // RemoveFromStringSlice removes the given string from the slice and returns a new slice. func RemoveFromStringSlice(a []string, s string) []string { - for key, entry := range a { - if entry == s { - a = append(a[:key], a[key+1:]...) - return a - } + i := IndexOfString(a, s) + if i > 0 { + a = append(a[:i], a[i+1:]...) } return a }