Fix various issues during portmaster development

This commit is contained in:
Daniel 2019-01-24 15:23:35 +01:00
parent a943e3315f
commit 358e684909
5 changed files with 92 additions and 41 deletions

View file

@ -285,8 +285,34 @@ func (api *DatabaseAPI) processQuery(opID []byte, q *query.Query) (ok bool) {
return false return false
} }
for {
select {
case <-api.shutdownSignal:
// cancel query and return
it.Cancel()
return
case r := <-it.Next:
// process query feed
if r != nil {
// process record
r.Lock()
data, err := r.Marshal(r, record.JSON)
r.Unlock()
if err != nil {
api.send(opID, dbMsgTypeWarning, err.Error(), nil)
}
api.send(opID, dbMsgTypeOk, r.Key(), data)
} else {
// sub feed ended
if it.Err() != nil {
api.send(opID, dbMsgTypeError, it.Err().Error(), nil)
return false
}
api.send(opID, dbMsgTypeDone, emptyString, nil) api.send(opID, dbMsgTypeDone, emptyString, nil)
return true return true
}
}
}
} }
// func (api *DatabaseAPI) runQuery() // func (api *DatabaseAPI) runQuery()
@ -323,7 +349,16 @@ func (api *DatabaseAPI) registerSub(opID []byte, q *query.Query) (sub *database.
} }
func (api *DatabaseAPI) processSub(opID []byte, sub *database.Subscription) { func (api *DatabaseAPI) processSub(opID []byte, sub *database.Subscription) {
for r := range sub.Feed { for {
select {
case <-api.shutdownSignal:
// cancel sub and return
sub.Cancel()
return
case r := <-sub.Feed:
// process sub feed
if r != nil {
// process record
r.Lock() r.Lock()
data, err := r.Marshal(r, record.JSON) data, err := r.Marshal(r, record.JSON)
r.Unlock() r.Unlock()
@ -337,10 +372,14 @@ func (api *DatabaseAPI) processSub(opID []byte, sub *database.Subscription) {
} else { } else {
api.send(opID, dbMsgTypeUpd, r.Key(), data) api.send(opID, dbMsgTypeUpd, r.Key(), data)
} }
} } else {
// sub feed ended
if sub.Err != nil { if sub.Err != nil {
api.send(opID, dbMsgTypeError, sub.Err.Error(), nil) api.send(opID, dbMsgTypeError, sub.Err.Error(), nil)
} }
}
}
}
} }
func (api *DatabaseAPI) handleQsub(opID []byte, queryText string) { func (api *DatabaseAPI) handleQsub(opID []byte, queryText string) {

View file

@ -9,20 +9,18 @@ import (
) )
var ( var (
additionalRoutes map[string]http.Handler router = mux.NewRouter()
) )
// RegisterAdditionalRoute registers an additional route with the API endoint. // RegisterHandleFunc registers an additional handle function with the API endoint.
func RegisterAdditionalRoute(path string, handler http.Handler) { func RegisterHandleFunc(path string, handleFunc func(http.ResponseWriter, *http.Request)) *mux.Route {
if additionalRoutes == nil { return router.HandleFunc(path, handleFunc)
additionalRoutes = make(map[string]http.Handler)
}
additionalRoutes[path] = handler
} }
// RequestLogger is a logging middleware // RequestLogger is a logging middleware
func RequestLogger(next http.Handler) http.Handler { func RequestLogger(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Tracef("api request: %s ___ %s", r.RemoteAddr, r.RequestURI)
ew := NewEnrichedResponseWriter(w) ew := NewEnrichedResponseWriter(w)
next.ServeHTTP(ew, r) next.ServeHTTP(ew, r)
log.Infof("api request: %s %d %s", r.RemoteAddr, ew.Status, r.RequestURI) log.Infof("api request: %s %d %s", r.RemoteAddr, ew.Status, r.RequestURI)
@ -31,20 +29,13 @@ func RequestLogger(next http.Handler) http.Handler {
// Serve starts serving the API endpoint. // Serve starts serving the API endpoint.
func Serve() { func Serve() {
router := mux.NewRouter()
// router.HandleFunc("/api/database/v1", startDatabaseAPI)
for path, handler := range additionalRoutes {
router.Handle(path, handler)
}
router.Use(RequestLogger) router.Use(RequestLogger)
http.Handle("/", router) mainMux := http.NewServeMux()
http.HandleFunc("/api/database/v1", startDatabaseAPI) mainMux.Handle("/", router) // net/http pattern matching /*
mainMux.HandleFunc("/api/database/v1", startDatabaseAPI) // net/http pattern matching only this exact path
address := getListenAddress() address := getListenAddress()
log.Infof("api: starting to listen on %s", address) log.Infof("api: starting to listen on %s", address)
log.Errorf("api: failed to listen on %s: %s", address, http.ListenAndServe(address, nil)) log.Errorf("api: failed to listen on %s: %s", address, http.ListenAndServe(address, mainMux))
} }

View file

@ -131,7 +131,11 @@ func TestDatabaseSystem(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = Initialize(testDir) ok := SetLocation(testDir)
if !ok {
t.Fatal("database location already set")
}
err = Initialize()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -15,6 +15,11 @@ var (
maintenanceWg sync.WaitGroup maintenanceWg sync.WaitGroup
) )
// SetDatabaseLocation sets the location of the database. Must be called before modules.Start and will be overridden by command line options. Intended for unit tests.
func SetDatabaseLocation(location string) {
databaseDir = location
}
func init() { func init() {
flag.StringVar(&databaseDir, "db", "", "set database directory") flag.StringVar(&databaseDir, "db", "", "set database directory")
@ -25,11 +30,15 @@ func prep() error {
if databaseDir == "" { if databaseDir == "" {
return errors.New("no database location specified, set with `-db=/path/to/db`") return errors.New("no database location specified, set with `-db=/path/to/db`")
} }
ok := database.SetLocation(databaseDir)
if !ok {
return errors.New("database location already set")
}
return nil return nil
} }
func start() error { func start() error {
err := database.Initialize(databaseDir) err := database.Initialize()
if err == nil { if err == nil {
startMaintainer() startMaintainer()
} }

View file

@ -15,10 +15,18 @@ var (
shutdownSignal = make(chan struct{}) shutdownSignal = make(chan struct{})
) )
// Initialize initialized the database // SetLocation sets the location of the database. This is separate from the initialization to provide the location to other modules earlier.
func Initialize(location string) error { func SetLocation(location string) (ok bool) {
if initialized.SetToIf(false, true) { if !initialized.IsSet() && rootDir == "" {
rootDir = location rootDir = location
return true
}
return false
}
// Initialize initialized the database
func Initialize() error {
if initialized.SetToIf(false, true) {
err := ensureDirectory(rootDir) err := ensureDirectory(rootDir)
if err != nil { if err != nil {