Use gorilla/mux for endpoint paths

This commit is contained in:
Daniel 2021-05-12 11:24:30 +02:00
parent 8814d279bd
commit 400f4c12ed
3 changed files with 49 additions and 25 deletions

View file

@ -13,7 +13,6 @@ import (
"github.com/safing/portbase/database"
"github.com/safing/portbase/database/record"
"github.com/safing/portbase/database/storage"
"github.com/safing/portbase/log"
)
const (
@ -59,8 +58,6 @@ type EndpointBridgeResponse struct {
// Get returns a database record.
func (ebs *endpointBridgeStorage) Get(key string) (record.Record, error) {
log.Errorf("api bridge: getting %s", key)
if key == "" {
return nil, database.ErrNotFound
}
@ -81,8 +78,6 @@ func (ebs *endpointBridgeStorage) GetMeta(key string) (*record.Meta, error) {
// Put stores a record in the database.
func (ebs *endpointBridgeStorage) Put(r record.Record) (record.Record, error) {
log.Errorf("api bridge: putting %s", r.Key())
if r.DatabaseKey() == "" {
return nil, database.ErrNotFound
}
@ -103,7 +98,6 @@ func (ebs *endpointBridgeStorage) Put(r record.Record) (record.Record, error) {
return nil, fmt.Errorf("record not of type *EndpointBridgeRequest, but %T", r)
}
}
log.Errorf("api bridge: putting %+v", ebr)
// Override path with key to mitigate sneaky stuff.
ebr.Path = r.DatabaseKey()
@ -145,7 +139,6 @@ func callAPI(ebr *EndpointBridgeRequest) (record.Record, error) {
}
u.RawQuery = query.Encode()
}
log.Errorf("api bridge: calling %s", u.String())
// Create request and response objects.
r := httptest.NewRequest(ebr.Method, u.String(), bytes.NewBuffer(ebr.Data))

View file

@ -11,6 +11,8 @@ import (
"strings"
"sync"
"github.com/gorilla/mux"
"github.com/safing/portbase/database/record"
"github.com/safing/portbase/log"
)
@ -84,6 +86,7 @@ func init() {
var (
endpoints = make(map[string]*Endpoint)
endpointsMux = mux.NewRouter()
endpointsLock sync.RWMutex
// ErrInvalidEndpoint is returned when an invalid endpoint is registered.
@ -106,16 +109,28 @@ func getAPIContext(r *http.Request) (apiEndpoint *Endpoint, apiRequest *Request)
return apiEndpoint, apiRequest
}
// If not, get the action from the registry.
endpointPath, ok := apiRequest.URLVars["endpointPath"]
if !ok {
return nil, apiRequest
}
endpointsLock.RLock()
defer endpointsLock.RUnlock()
apiEndpoint, ok = endpoints[endpointPath]
// Get handler for request.
// Gorilla does not support handling this on our own very well.
// See github.com/gorilla/mux.ServeHTTP for reference.
var match mux.RouteMatch
var handler http.Handler
if endpointsMux.Match(r, &match) {
handler = match.Handler
apiRequest.Route = match.Route
// Add/Override variables instead of replacing.
for k, v := range match.Vars {
apiRequest.URLVars[k] = v
}
} else {
return nil, apiRequest
}
log.Errorf("handler: %+v", handler)
apiEndpoint, ok = handler.(*Endpoint)
log.Errorf("apiEndpoint: %+v", apiEndpoint)
if ok {
// Cache for next operation.
apiRequest.HandlerCache = apiEndpoint
@ -139,6 +154,7 @@ func RegisterEndpoint(e Endpoint) error {
}
endpoints[e.Path] = &e
endpointsMux.Handle(apiV1Path+e.Path, &e)
return nil
}
@ -243,6 +259,17 @@ func (eh *endpointHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
apiEndpoint.ServeHTTP(w, r)
}
// ServeHTTP handles the http request.
func (e *Endpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
_, apiRequest := getAPIContext(r)
if apiRequest == nil {
http.NotFound(w, r)
return
}
switch r.Method {
case http.MethodHead:
w.WriteHeader(http.StatusOK)
@ -269,32 +296,32 @@ func (eh *endpointHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var err error
switch {
case apiEndpoint.ActionFunc != nil:
case e.ActionFunc != nil:
var msg string
msg, err = apiEndpoint.ActionFunc(apiRequest)
msg, err = e.ActionFunc(apiRequest)
if err == nil {
responseData = []byte(msg)
}
case apiEndpoint.DataFunc != nil:
responseData, err = apiEndpoint.DataFunc(apiRequest)
case e.DataFunc != nil:
responseData, err = e.DataFunc(apiRequest)
case apiEndpoint.StructFunc != nil:
case e.StructFunc != nil:
var v interface{}
v, err = apiEndpoint.StructFunc(apiRequest)
v, err = e.StructFunc(apiRequest)
if err == nil && v != nil {
responseData, err = json.Marshal(v)
}
case apiEndpoint.RecordFunc != nil:
case e.RecordFunc != nil:
var rec record.Record
rec, err = apiEndpoint.RecordFunc(apiRequest)
rec, err = e.RecordFunc(apiRequest)
if err == nil && r != nil {
responseData, err = marshalRecord(rec, false)
}
case apiEndpoint.HandlerFunc != nil:
apiEndpoint.HandlerFunc(w, r)
case e.HandlerFunc != nil:
e.HandlerFunc(w, r)
return
default:
@ -309,7 +336,7 @@ func (eh *endpointHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
// Write response.
w.Header().Set("Content-Type", apiEndpoint.MimeType+"; charset=utf-8")
w.Header().Set("Content-Type", e.MimeType+"; charset=utf-8")
w.Header().Set("Content-Length", strconv.Itoa(len(responseData)))
w.WriteHeader(http.StatusOK)
_, err = w.Write(responseData)

View file

@ -50,6 +50,10 @@ func prep() error {
return err
}
if err := registerModulesEndpoints(); err != nil {
return err
}
return registerMetaEndpoints()
}