Update api with registering handlers, middleware and authentication

This commit is contained in:
Daniel 2019-07-30 13:08:45 +02:00
parent a7105dc6ba
commit 8d091f7f7a
10 changed files with 255 additions and 37 deletions

110
api/authentication.go Normal file
View file

@ -0,0 +1,110 @@
package api
import (
"encoding/base64"
"net/http"
"sync"
"time"
"github.com/safing/portbase/crypto/random"
"github.com/safing/portbase/log"
)
var (
validTokens map[string]time.Time
validTokensLock sync.Mutex
authFnLock sync.Mutex
authFn Authenticator
)
const (
cookieName = "T17"
// in seconds
cookieBaseTTL = 300 // 5 minutes
cookieTTL = cookieBaseTTL * time.Second
cookieRefresh = cookieBaseTTL * 0.9 * time.Second
)
// Authenticator is a function that can be set as the authenticator for the API endpoint. If none is set, all requests will be allowed.
type Authenticator func(s *http.Server, r *http.Request) (grantAccess bool, err error)
// SetAuthenticator sets an authenticator function for the API endpoint. If none is set, all requests will be allowed.
func SetAuthenticator(fn Authenticator) error {
authFnLock.Lock()
defer authFnLock.Unlock()
if authFn == nil {
authFn = fn
return nil
}
return ErrAuthenticationAlreadySet
}
func authMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// check existing auth cookie
c, err := r.Cookie(cookieName)
if err == nil {
// get token
validTokensLock.Lock()
validUntil, valid := validTokens[c.Value]
validTokensLock.Unlock()
// check if token is valid
if valid && time.Now().Before(validUntil) {
// maybe refresh cookie
if time.Now().After(validUntil.Add(-cookieRefresh)) {
validTokensLock.Lock()
validTokens[c.Value] = time.Now()
validTokensLock.Unlock()
}
next.ServeHTTP(w, r)
return
}
}
// get authenticator
authFnLock.Lock()
authenticator := authFn
authFnLock.Unlock()
// permit if no authenticator set
if authenticator == nil {
next.ServeHTTP(w, r)
return
}
// get auth decision
grantAccess, err := authenticator(server, r)
if err != nil {
log.Errorf("api: authenticator failed: %s", err)
http.Error(w, "", http.StatusInternalServerError)
}
if !grantAccess {
log.Warningf("api: denying api access to %s", r.RemoteAddr)
http.Error(w, "", http.StatusForbidden)
return
}
// write new cookie
token, err := random.Bytes(32) // 256 bit
if err != nil {
log.Errorf("api: failed to generate random token: %s", err)
http.Error(w, "", http.StatusInternalServerError)
}
tokenString := base64.RawURLEncoding.EncodeToString(token)
http.SetCookie(w, &http.Cookie{
Name: cookieName,
Value: tokenString,
HttpOnly: true,
})
// serve
log.Tracef("api: granted %s", r.RemoteAddr)
next.ServeHTTP(w, r)
})
}

View file

@ -24,7 +24,7 @@ func checkFlags() error {
return nil return nil
} }
func getListenAddress() string { func GetAPIAddress() string {
if listenAddressFlag != "" { if listenAddressFlag != "" {
return listenAddressFlag return listenAddressFlag
} }
@ -49,6 +49,7 @@ func registerConfig() error {
return nil return nil
} }
// SetDefaultAPIListenAddress sets the default listen address for the API.
func SetDefaultAPIListenAddress(address string) { func SetDefaultAPIListenAddress(address string) {
if defaultListenAddress == "" { if defaultListenAddress == "" {
defaultListenAddress = address defaultListenAddress = address

View file

@ -35,6 +35,10 @@ var (
dbAPISeperatorBytes = []byte(dbAPISeperator) dbAPISeperatorBytes = []byte(dbAPISeperator)
) )
func init() {
RegisterHandleFunc("/api/database/v1", startDatabaseAPI) // net/http pattern matching only this exact path
}
// DatabaseAPI is a database API instance. // DatabaseAPI is a database API instance.
type DatabaseAPI struct { type DatabaseAPI struct {
conn *websocket.Conn conn *websocket.Conn

View file

@ -1,25 +1,68 @@
package api package api
import ( import (
"bufio"
"errors"
"net"
"net/http" "net/http"
"github.com/safing/portbase/log"
) )
// EnrichedResponseWriter is a wrapper for http.ResponseWriter for better information extraction. // LoggingResponseWriter is a wrapper for http.ResponseWriter for better request logging.
type EnrichedResponseWriter struct { type LoggingResponseWriter struct {
http.ResponseWriter ResponseWriter http.ResponseWriter
Request *http.Request
Status int Status int
} }
// NewEnrichedResponseWriter wraps a http.ResponseWriter. // NewLoggingResponseWriter wraps a http.ResponseWriter.
func NewEnrichedResponseWriter(w http.ResponseWriter) *EnrichedResponseWriter { func NewLoggingResponseWriter(w http.ResponseWriter, r *http.Request) *LoggingResponseWriter {
return &EnrichedResponseWriter{ return &LoggingResponseWriter{
w, ResponseWriter: w,
0, Request: r,
} }
} }
// WriteHeader wraps the original WriteHeader method to extract information. // Header wraps the original Header method.
func (ew *EnrichedResponseWriter) WriteHeader(code int) { func (lrw *LoggingResponseWriter) Header() http.Header {
ew.Status = code return lrw.ResponseWriter.Header()
ew.ResponseWriter.WriteHeader(code) }
// Write wraps the original Write method.
func (lrw *LoggingResponseWriter) Write(b []byte) (int, error) {
return lrw.ResponseWriter.Write(b)
}
// WriteHeader wraps the original WriteHeader method to extract information.
func (lrw *LoggingResponseWriter) WriteHeader(code int) {
lrw.Status = code
lrw.ResponseWriter.WriteHeader(code)
}
// Hijack wraps the original Hijack method, if available.
func (lrw *LoggingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker, ok := lrw.ResponseWriter.(http.Hijacker)
if ok {
c, b, err := hijacker.Hijack()
if err != nil {
return nil, nil, err
}
log.Infof("api request: %s HIJ %s", lrw.Request.RemoteAddr, lrw.Request.RequestURI)
return c, b, nil
}
return nil, nil, errors.New("response does not implement http.Hijacker")
}
// RequestLogger is a logging middleware
func RequestLogger(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Tracef("api request: %s ___ %s", r.RemoteAddr, r.RequestURI)
lrw := NewLoggingResponseWriter(w, r)
next.ServeHTTP(lrw, r)
if lrw.Status != 0 {
// request may have been hijacked
log.Infof("api request: %s %d %s", lrw.Request.RemoteAddr, lrw.Status, lrw.Request.RequestURI)
}
})
} }

View file

@ -1,9 +1,17 @@
package api package api
import ( import (
"context"
"errors"
"github.com/safing/portbase/modules" "github.com/safing/portbase/modules"
) )
// API Errors
var (
ErrAuthenticationAlreadySet = errors.New("the authentication function has already been set")
)
func init() { func init() {
modules.Register("api", prep, start, nil, "database") modules.Register("api", prep, start, nil, "database")
} }
@ -20,3 +28,10 @@ func start() error {
go Serve() go Serve()
return nil return nil
} }
func stop() error {
if server != nil {
return server.Shutdown(context.Background())
}
return nil
}

27
api/middleware.go Normal file
View file

@ -0,0 +1,27 @@
package api
import "net/http"
// Middleware is a function that can be added as a middleware to the API endpoint.
type Middleware func(next http.Handler) http.Handler
type mwHandler struct {
handlers []Middleware
final http.Handler
}
func (mwh *mwHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
handlerLock.RLock()
defer handlerLock.RUnlock()
// final handler
handler := mwh.final
// build middleware chain
for _, mw := range mwh.handlers {
handler = mw(handler)
}
// start
handler.ServeHTTP(w, r)
}

View file

@ -2,6 +2,7 @@ package api
import ( import (
"net/http" "net/http"
"sync"
"github.com/gorilla/mux" "github.com/gorilla/mux"
@ -9,35 +10,54 @@ import (
) )
var ( var (
router = mux.NewRouter() // gorilla mux
mainMux = mux.NewRouter()
// middlewares
middlewareHandler = &mwHandler{
final: mainMux,
handlers: []Middleware{
RequestLogger,
authMiddleware,
},
}
// main server and lock
server = &http.Server{}
handlerLock sync.RWMutex
) )
// RegisterHandleFunc registers an additional handle function with the API endoint. // RegisterHandler registers a handler with the API endoint.
func RegisterHandleFunc(path string, handleFunc func(http.ResponseWriter, *http.Request)) *mux.Route { func RegisterHandler(path string, handler http.Handler) *mux.Route {
return router.HandleFunc(path, handleFunc) handlerLock.Lock()
defer handlerLock.Unlock()
return mainMux.Handle(path, handler)
} }
// RequestLogger is a logging middleware // RegisterHandleFunc registers a handle function with the API endoint.
func RequestLogger(next http.Handler) http.Handler { func RegisterHandleFunc(path string, handleFunc func(http.ResponseWriter, *http.Request)) *mux.Route {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handlerLock.Lock()
log.Tracef("api request: %s ___ %s", r.RemoteAddr, r.RequestURI) defer handlerLock.Unlock()
ew := NewEnrichedResponseWriter(w) return mainMux.HandleFunc(path, handleFunc)
next.ServeHTTP(ew, r) }
log.Infof("api request: %s %d %s", r.RemoteAddr, ew.Status, r.RequestURI)
}) // RegisterMiddleware registers a middle function with the API endoint.
func RegisterMiddleware(middleware Middleware) {
handlerLock.Lock()
defer handlerLock.Unlock()
middlewareHandler.handlers = append(middlewareHandler.handlers, middleware)
} }
// Serve starts serving the API endpoint. // Serve starts serving the API endpoint.
func Serve() { func Serve() {
router.Use(RequestLogger) // configure server
server.Addr = GetAPIAddress()
server.Handler = middlewareHandler
mainMux := http.NewServeMux() // start serving
mainMux.Handle("/", router) // net/http pattern matching /* log.Infof("api: starting to listen on %s", server.Addr)
mainMux.HandleFunc("/api/database/v1", startDatabaseAPI) // net/http pattern matching only this exact path // TODO: retry if failed
log.Errorf("api: failed to listen on %s: %s", server.Addr, server.ListenAndServe())
address := getListenAddress()
log.Infof("api: starting to listen on %s", address)
log.Errorf("api: failed to listen on %s: %s", address, http.ListenAndServe(address, mainMux))
} }
// GetMuxVars wraps github.com/gorilla/mux.Vars in order to mitigate context key issues in multi-repo projects. // GetMuxVars wraps github.com/gorilla/mux.Vars in order to mitigate context key issues in multi-repo projects.

View file

@ -1 +0,0 @@
package api

View file

@ -7,5 +7,5 @@ import (
) )
func init() { func init() {
api.RegisterAdditionalRoute("/test/", http.StripPrefix("/test/", http.FileServer(http.Dir("./api/testclient/root/")))) api.RegisterHandler("/test/", http.StripPrefix("/test/", http.FileServer(http.Dir("./api/testclient/root/"))))
} }

View file

@ -1 +0,0 @@
package api