From 8d091f7f7a50646de90a4b7a940cfc5b13d44a66 Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 30 Jul 2019 13:08:45 +0200 Subject: [PATCH] Update api with registering handlers, middleware and authentication --- api/authentication.go | 110 +++++++++++++++++++++++++++++++++++++++ api/config.go | 3 +- api/database.go | 4 ++ api/enriched-response.go | 69 +++++++++++++++++++----- api/main.go | 15 ++++++ api/middleware.go | 27 ++++++++++ api/router.go | 60 ++++++++++++++------- api/security.go | 1 - api/testclient/serve.go | 2 +- api/websocket.go | 1 - 10 files changed, 255 insertions(+), 37 deletions(-) create mode 100644 api/authentication.go create mode 100644 api/middleware.go delete mode 100644 api/security.go delete mode 100644 api/websocket.go diff --git a/api/authentication.go b/api/authentication.go new file mode 100644 index 0000000..156dba2 --- /dev/null +++ b/api/authentication.go @@ -0,0 +1,110 @@ +package api + +import ( + "encoding/base64" + "net/http" + "sync" + "time" + + "github.com/safing/portbase/crypto/random" + "github.com/safing/portbase/log" +) + +var ( + validTokens map[string]time.Time + validTokensLock sync.Mutex + + authFnLock sync.Mutex + authFn Authenticator +) + +const ( + cookieName = "T17" + + // in seconds + cookieBaseTTL = 300 // 5 minutes + cookieTTL = cookieBaseTTL * time.Second + cookieRefresh = cookieBaseTTL * 0.9 * time.Second +) + +// Authenticator is a function that can be set as the authenticator for the API endpoint. If none is set, all requests will be allowed. +type Authenticator func(s *http.Server, r *http.Request) (grantAccess bool, err error) + +// SetAuthenticator sets an authenticator function for the API endpoint. If none is set, all requests will be allowed. +func SetAuthenticator(fn Authenticator) error { + authFnLock.Lock() + defer authFnLock.Unlock() + + if authFn == nil { + authFn = fn + return nil + } + + return ErrAuthenticationAlreadySet +} + +func authMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + // check existing auth cookie + c, err := r.Cookie(cookieName) + if err == nil { + // get token + validTokensLock.Lock() + validUntil, valid := validTokens[c.Value] + validTokensLock.Unlock() + + // check if token is valid + if valid && time.Now().Before(validUntil) { + // maybe refresh cookie + if time.Now().After(validUntil.Add(-cookieRefresh)) { + validTokensLock.Lock() + validTokens[c.Value] = time.Now() + validTokensLock.Unlock() + } + next.ServeHTTP(w, r) + return + } + } + + // get authenticator + authFnLock.Lock() + authenticator := authFn + authFnLock.Unlock() + + // permit if no authenticator set + if authenticator == nil { + next.ServeHTTP(w, r) + return + } + + // get auth decision + grantAccess, err := authenticator(server, r) + if err != nil { + log.Errorf("api: authenticator failed: %s", err) + http.Error(w, "", http.StatusInternalServerError) + } + if !grantAccess { + log.Warningf("api: denying api access to %s", r.RemoteAddr) + http.Error(w, "", http.StatusForbidden) + return + } + + // write new cookie + token, err := random.Bytes(32) // 256 bit + if err != nil { + log.Errorf("api: failed to generate random token: %s", err) + http.Error(w, "", http.StatusInternalServerError) + } + tokenString := base64.RawURLEncoding.EncodeToString(token) + http.SetCookie(w, &http.Cookie{ + Name: cookieName, + Value: tokenString, + HttpOnly: true, + }) + + // serve + log.Tracef("api: granted %s", r.RemoteAddr) + next.ServeHTTP(w, r) + }) +} diff --git a/api/config.go b/api/config.go index 8a6e885..08b28a6 100644 --- a/api/config.go +++ b/api/config.go @@ -24,7 +24,7 @@ func checkFlags() error { return nil } -func getListenAddress() string { +func GetAPIAddress() string { if listenAddressFlag != "" { return listenAddressFlag } @@ -49,6 +49,7 @@ func registerConfig() error { return nil } +// SetDefaultAPIListenAddress sets the default listen address for the API. func SetDefaultAPIListenAddress(address string) { if defaultListenAddress == "" { defaultListenAddress = address diff --git a/api/database.go b/api/database.go index 0f900ed..dac729e 100644 --- a/api/database.go +++ b/api/database.go @@ -35,6 +35,10 @@ var ( dbAPISeperatorBytes = []byte(dbAPISeperator) ) +func init() { + RegisterHandleFunc("/api/database/v1", startDatabaseAPI) // net/http pattern matching only this exact path +} + // DatabaseAPI is a database API instance. type DatabaseAPI struct { conn *websocket.Conn diff --git a/api/enriched-response.go b/api/enriched-response.go index 58fd983..6aa5685 100644 --- a/api/enriched-response.go +++ b/api/enriched-response.go @@ -1,25 +1,68 @@ package api import ( + "bufio" + "errors" + "net" "net/http" + + "github.com/safing/portbase/log" ) -// EnrichedResponseWriter is a wrapper for http.ResponseWriter for better information extraction. -type EnrichedResponseWriter struct { - http.ResponseWriter - Status int +// LoggingResponseWriter is a wrapper for http.ResponseWriter for better request logging. +type LoggingResponseWriter struct { + ResponseWriter http.ResponseWriter + Request *http.Request + Status int } -// NewEnrichedResponseWriter wraps a http.ResponseWriter. -func NewEnrichedResponseWriter(w http.ResponseWriter) *EnrichedResponseWriter { - return &EnrichedResponseWriter{ - w, - 0, +// NewLoggingResponseWriter wraps a http.ResponseWriter. +func NewLoggingResponseWriter(w http.ResponseWriter, r *http.Request) *LoggingResponseWriter { + return &LoggingResponseWriter{ + ResponseWriter: w, + Request: r, } } -// WriteHeader wraps the original WriteHeader method to extract information. -func (ew *EnrichedResponseWriter) WriteHeader(code int) { - ew.Status = code - ew.ResponseWriter.WriteHeader(code) +// Header wraps the original Header method. +func (lrw *LoggingResponseWriter) Header() http.Header { + return lrw.ResponseWriter.Header() +} + +// Write wraps the original Write method. +func (lrw *LoggingResponseWriter) Write(b []byte) (int, error) { + return lrw.ResponseWriter.Write(b) +} + +// WriteHeader wraps the original WriteHeader method to extract information. +func (lrw *LoggingResponseWriter) WriteHeader(code int) { + lrw.Status = code + lrw.ResponseWriter.WriteHeader(code) +} + +// Hijack wraps the original Hijack method, if available. +func (lrw *LoggingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hijacker, ok := lrw.ResponseWriter.(http.Hijacker) + if ok { + c, b, err := hijacker.Hijack() + if err != nil { + return nil, nil, err + } + log.Infof("api request: %s HIJ %s", lrw.Request.RemoteAddr, lrw.Request.RequestURI) + return c, b, nil + } + return nil, nil, errors.New("response does not implement http.Hijacker") +} + +// RequestLogger is a logging middleware +func RequestLogger(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Tracef("api request: %s ___ %s", r.RemoteAddr, r.RequestURI) + lrw := NewLoggingResponseWriter(w, r) + next.ServeHTTP(lrw, r) + if lrw.Status != 0 { + // request may have been hijacked + log.Infof("api request: %s %d %s", lrw.Request.RemoteAddr, lrw.Status, lrw.Request.RequestURI) + } + }) } diff --git a/api/main.go b/api/main.go index 4a3c3d9..b522788 100644 --- a/api/main.go +++ b/api/main.go @@ -1,9 +1,17 @@ package api import ( + "context" + "errors" + "github.com/safing/portbase/modules" ) +// API Errors +var ( + ErrAuthenticationAlreadySet = errors.New("the authentication function has already been set") +) + func init() { modules.Register("api", prep, start, nil, "database") } @@ -20,3 +28,10 @@ func start() error { go Serve() return nil } + +func stop() error { + if server != nil { + return server.Shutdown(context.Background()) + } + return nil +} diff --git a/api/middleware.go b/api/middleware.go new file mode 100644 index 0000000..89dd465 --- /dev/null +++ b/api/middleware.go @@ -0,0 +1,27 @@ +package api + +import "net/http" + +// Middleware is a function that can be added as a middleware to the API endpoint. +type Middleware func(next http.Handler) http.Handler + +type mwHandler struct { + handlers []Middleware + final http.Handler +} + +func (mwh *mwHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + handlerLock.RLock() + defer handlerLock.RUnlock() + + // final handler + handler := mwh.final + + // build middleware chain + for _, mw := range mwh.handlers { + handler = mw(handler) + } + + // start + handler.ServeHTTP(w, r) +} diff --git a/api/router.go b/api/router.go index 06cbba7..236b72b 100644 --- a/api/router.go +++ b/api/router.go @@ -2,6 +2,7 @@ package api import ( "net/http" + "sync" "github.com/gorilla/mux" @@ -9,35 +10,54 @@ import ( ) var ( - router = mux.NewRouter() + // gorilla mux + mainMux = mux.NewRouter() + + // middlewares + middlewareHandler = &mwHandler{ + final: mainMux, + handlers: []Middleware{ + RequestLogger, + authMiddleware, + }, + } + + // main server and lock + server = &http.Server{} + handlerLock sync.RWMutex ) -// RegisterHandleFunc registers an additional handle function with the API endoint. -func RegisterHandleFunc(path string, handleFunc func(http.ResponseWriter, *http.Request)) *mux.Route { - return router.HandleFunc(path, handleFunc) +// RegisterHandler registers a handler with the API endoint. +func RegisterHandler(path string, handler http.Handler) *mux.Route { + handlerLock.Lock() + defer handlerLock.Unlock() + return mainMux.Handle(path, handler) } -// RequestLogger is a logging middleware -func RequestLogger(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - log.Tracef("api request: %s ___ %s", r.RemoteAddr, r.RequestURI) - ew := NewEnrichedResponseWriter(w) - next.ServeHTTP(ew, r) - log.Infof("api request: %s %d %s", r.RemoteAddr, ew.Status, r.RequestURI) - }) +// RegisterHandleFunc registers a handle function with the API endoint. +func RegisterHandleFunc(path string, handleFunc func(http.ResponseWriter, *http.Request)) *mux.Route { + handlerLock.Lock() + defer handlerLock.Unlock() + return mainMux.HandleFunc(path, handleFunc) +} + +// RegisterMiddleware registers a middle function with the API endoint. +func RegisterMiddleware(middleware Middleware) { + handlerLock.Lock() + defer handlerLock.Unlock() + middlewareHandler.handlers = append(middlewareHandler.handlers, middleware) } // Serve starts serving the API endpoint. func Serve() { - router.Use(RequestLogger) + // configure server + server.Addr = GetAPIAddress() + server.Handler = middlewareHandler - mainMux := http.NewServeMux() - mainMux.Handle("/", router) // net/http pattern matching /* - mainMux.HandleFunc("/api/database/v1", startDatabaseAPI) // net/http pattern matching only this exact path - - address := getListenAddress() - log.Infof("api: starting to listen on %s", address) - log.Errorf("api: failed to listen on %s: %s", address, http.ListenAndServe(address, mainMux)) + // start serving + log.Infof("api: starting to listen on %s", server.Addr) + // TODO: retry if failed + log.Errorf("api: failed to listen on %s: %s", server.Addr, server.ListenAndServe()) } // GetMuxVars wraps github.com/gorilla/mux.Vars in order to mitigate context key issues in multi-repo projects. diff --git a/api/security.go b/api/security.go deleted file mode 100644 index 778f64e..0000000 --- a/api/security.go +++ /dev/null @@ -1 +0,0 @@ -package api diff --git a/api/testclient/serve.go b/api/testclient/serve.go index 01c40e5..e70e636 100644 --- a/api/testclient/serve.go +++ b/api/testclient/serve.go @@ -7,5 +7,5 @@ import ( ) func init() { - api.RegisterAdditionalRoute("/test/", http.StripPrefix("/test/", http.FileServer(http.Dir("./api/testclient/root/")))) + api.RegisterHandler("/test/", http.StripPrefix("/test/", http.FileServer(http.Dir("./api/testclient/root/")))) } diff --git a/api/websocket.go b/api/websocket.go deleted file mode 100644 index 778f64e..0000000 --- a/api/websocket.go +++ /dev/null @@ -1 +0,0 @@ -package api