mirror of
https://github.com/safing/portbase
synced 2025-09-01 18:19:57 +00:00
Merge pull request #55 from safing/feature/api-improvements
API Improvements
This commit is contained in:
commit
df271461fc
5 changed files with 109 additions and 46 deletions
|
@ -1,18 +1,21 @@
|
||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/safing/portbase/modules"
|
||||||
|
|
||||||
"github.com/safing/portbase/log"
|
"github.com/safing/portbase/log"
|
||||||
"github.com/safing/portbase/rng"
|
"github.com/safing/portbase/rng"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
validTokens map[string]time.Time
|
validTokens = make(map[string]time.Time)
|
||||||
validTokensLock sync.Mutex
|
validTokensLock sync.Mutex
|
||||||
|
|
||||||
authFnLock sync.Mutex
|
authFnLock sync.Mutex
|
||||||
|
@ -24,24 +27,22 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
cookieName = "T17"
|
cookieName = "Portmaster-API-Token"
|
||||||
|
|
||||||
// in seconds
|
cookieTTL = 5 * time.Minute
|
||||||
cookieBaseTTL = 300 // 5 minutes
|
|
||||||
cookieTTL = cookieBaseTTL * time.Second
|
|
||||||
cookieRefresh = cookieBaseTTL * 0.9 * time.Second
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Authenticator is a function that can be set as the authenticator for the API endpoint. If none is set, all requests will be allowed.
|
// Authenticator is a function that can be set as the authenticator for the API endpoint. If none is set, all requests will be permitted.
|
||||||
type Authenticator func(s *http.Server, r *http.Request) (err error)
|
type Authenticator func(ctx context.Context, s *http.Server, r *http.Request) (err error)
|
||||||
|
|
||||||
// SetAuthenticator sets an authenticator function for the API endpoint. If none is set, all requests will be allowed.
|
// SetAuthenticator sets an authenticator function for the API endpoint. If none is set, all requests will be permitted.
|
||||||
func SetAuthenticator(fn Authenticator) error {
|
func SetAuthenticator(fn Authenticator) error {
|
||||||
authFnLock.Lock()
|
authFnLock.Lock()
|
||||||
defer authFnLock.Unlock()
|
defer authFnLock.Unlock()
|
||||||
|
|
||||||
if authFn == nil {
|
if authFn == nil {
|
||||||
authFn = fn
|
authFn = fn
|
||||||
|
module.NewTask("clean api auth tokens", cleanAuthTokens).Repeat(time.Minute)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -50,27 +51,7 @@ func SetAuthenticator(fn Authenticator) error {
|
||||||
|
|
||||||
func authMiddleware(next http.Handler) http.Handler {
|
func authMiddleware(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
tracer := log.Tracer(r.Context())
|
||||||
// check existing auth cookie
|
|
||||||
c, err := r.Cookie(cookieName)
|
|
||||||
if err == nil {
|
|
||||||
// get token
|
|
||||||
validTokensLock.Lock()
|
|
||||||
validUntil, valid := validTokens[c.Value]
|
|
||||||
validTokensLock.Unlock()
|
|
||||||
|
|
||||||
// check if token is valid
|
|
||||||
if valid && time.Now().Before(validUntil) {
|
|
||||||
// maybe refresh cookie
|
|
||||||
if time.Now().After(validUntil.Add(-cookieRefresh)) {
|
|
||||||
validTokensLock.Lock()
|
|
||||||
validTokens[c.Value] = time.Now()
|
|
||||||
validTokensLock.Unlock()
|
|
||||||
}
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// get authenticator
|
// get authenticator
|
||||||
authFnLock.Lock()
|
authFnLock.Lock()
|
||||||
|
@ -83,35 +64,78 @@ func authMiddleware(next http.Handler) http.Handler {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check existing auth cookie
|
||||||
|
c, err := r.Cookie(cookieName)
|
||||||
|
if err == nil {
|
||||||
|
// get token
|
||||||
|
validTokensLock.Lock()
|
||||||
|
validUntil, valid := validTokens[c.Value]
|
||||||
|
validTokensLock.Unlock()
|
||||||
|
|
||||||
|
// check if token is valid
|
||||||
|
if valid && time.Now().Before(validUntil) {
|
||||||
|
tracer.Tracef("api: auth token %s is valid, refreshing", c.Value)
|
||||||
|
// refresh cookie
|
||||||
|
validTokensLock.Lock()
|
||||||
|
validTokens[c.Value] = time.Now().Add(cookieTTL)
|
||||||
|
validTokensLock.Unlock()
|
||||||
|
// continue
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tracer.Tracef("api: provided auth token %s is invalid", c.Value)
|
||||||
|
}
|
||||||
|
|
||||||
// get auth decision
|
// get auth decision
|
||||||
err = authenticator(server, r)
|
err = authenticator(r.Context(), server, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, ErrAPIAccessDeniedMessage) {
|
if errors.Is(err, ErrAPIAccessDeniedMessage) {
|
||||||
log.Warningf("api: denying api access to %s", r.RemoteAddr)
|
tracer.Warningf("api: denying api access to %s", r.RemoteAddr)
|
||||||
http.Error(w, err.Error(), http.StatusForbidden)
|
http.Error(w, err.Error(), http.StatusForbidden)
|
||||||
} else {
|
} else {
|
||||||
log.Warningf("api: authenticator failed: %s", err)
|
tracer.Warningf("api: authenticator failed: %s", err)
|
||||||
http.Error(w, "Internal server error during authentication.", http.StatusInternalServerError)
|
http.Error(w, "Internal server error during authentication.", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// write new cookie
|
// generate new token
|
||||||
token, err := rng.Bytes(32) // 256 bit
|
token, err := rng.Bytes(32) // 256 bit
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warningf("api: failed to generate random token: %s", err)
|
tracer.Warningf("api: failed to generate random token: %s", err)
|
||||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
tokenString := base64.RawURLEncoding.EncodeToString(token)
|
tokenString := base64.RawURLEncoding.EncodeToString(token)
|
||||||
|
// write new cookie
|
||||||
http.SetCookie(w, &http.Cookie{
|
http.SetCookie(w, &http.Cookie{
|
||||||
Name: cookieName,
|
Name: cookieName,
|
||||||
Value: tokenString,
|
Value: tokenString,
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
|
SameSite: http.SameSiteStrictMode,
|
||||||
MaxAge: int(cookieTTL.Seconds()),
|
MaxAge: int(cookieTTL.Seconds()),
|
||||||
})
|
})
|
||||||
|
// save cookie
|
||||||
|
validTokensLock.Lock()
|
||||||
|
validTokens[tokenString] = time.Now().Add(cookieTTL)
|
||||||
|
validTokensLock.Unlock()
|
||||||
|
|
||||||
// serve
|
// serve
|
||||||
log.Tracef("api: granted %s", r.RemoteAddr)
|
tracer.Tracef("api: granted %s, assigned auth token %s", r.RemoteAddr, tokenString)
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func cleanAuthTokens(_ context.Context, _ *modules.Task) error {
|
||||||
|
validTokensLock.Lock()
|
||||||
|
defer validTokensLock.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
for token, validUntil := range validTokens {
|
||||||
|
if now.After(validUntil) {
|
||||||
|
delete(validTokens, token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -127,8 +127,14 @@ func (api *DatabaseAPI) handler() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !api.shuttingDown.IsSet() {
|
if !api.shuttingDown.IsSet() {
|
||||||
api.shutdown()
|
api.shutdown()
|
||||||
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
if websocket.IsCloseError(err,
|
||||||
log.Warningf("api: websocket write error: %s", err)
|
websocket.CloseNormalClosure,
|
||||||
|
websocket.CloseGoingAway,
|
||||||
|
websocket.CloseAbnormalClosure,
|
||||||
|
) {
|
||||||
|
log.Infof("api: websocket connection to %s closed", api.conn.RemoteAddr())
|
||||||
|
} else {
|
||||||
|
log.Warningf("api: websocket read error from %s: %s", api.conn.RemoteAddr(), err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
@ -204,8 +210,14 @@ func (api *DatabaseAPI) writer() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !api.shuttingDown.IsSet() {
|
if !api.shuttingDown.IsSet() {
|
||||||
api.shutdown()
|
api.shutdown()
|
||||||
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
if websocket.IsCloseError(err,
|
||||||
log.Warningf("api: websocket write error: %s", err)
|
websocket.CloseNormalClosure,
|
||||||
|
websocket.CloseGoingAway,
|
||||||
|
websocket.CloseAbnormalClosure,
|
||||||
|
) {
|
||||||
|
log.Infof("api: websocket connection to %s closed", api.conn.RemoteAddr())
|
||||||
|
} else {
|
||||||
|
log.Warningf("api: websocket write error to %s: %s", api.conn.RemoteAddr(), err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|
|
@ -48,7 +48,7 @@ func (lrw *LoggingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
log.Infof("api request: %s HIJ %s", lrw.Request.RemoteAddr, lrw.Request.RequestURI)
|
log.Tracer(lrw.Request.Context()).Infof("api request: %s HIJ %s", lrw.Request.RemoteAddr, lrw.Request.RequestURI)
|
||||||
return c, b, nil
|
return c, b, nil
|
||||||
}
|
}
|
||||||
return nil, nil, errors.New("response does not implement http.Hijacker")
|
return nil, nil, errors.New("response does not implement http.Hijacker")
|
||||||
|
@ -57,12 +57,12 @@ func (lrw *LoggingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error)
|
||||||
// RequestLogger is a logging middleware.
|
// RequestLogger is a logging middleware.
|
||||||
func RequestLogger(next http.Handler) http.Handler {
|
func RequestLogger(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
log.Tracef("api request: %s ___ %s", r.RemoteAddr, r.RequestURI)
|
log.Tracer(r.Context()).Tracef("api request: %s ___ %s", r.RemoteAddr, r.RequestURI)
|
||||||
lrw := NewLoggingResponseWriter(w, r)
|
lrw := NewLoggingResponseWriter(w, r)
|
||||||
next.ServeHTTP(lrw, r)
|
next.ServeHTTP(lrw, r)
|
||||||
if lrw.Status != 0 {
|
if lrw.Status != 0 {
|
||||||
// request may have been hijacked
|
// request may have been hijacked
|
||||||
log.Infof("api request: %s %d %s", lrw.Request.RemoteAddr, lrw.Status, lrw.Request.RequestURI)
|
log.Tracer(r.Context()).Infof("api request: %s %d %s", lrw.Request.RemoteAddr, lrw.Status, lrw.Request.RequestURI)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,11 @@
|
||||||
package api
|
package api
|
||||||
|
|
||||||
import "net/http"
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/safing/portbase/log"
|
||||||
|
)
|
||||||
|
|
||||||
// Middleware is a function that can be added as a middleware to the API endpoint.
|
// Middleware is a function that can be added as a middleware to the API endpoint.
|
||||||
type Middleware func(next http.Handler) http.Handler
|
type Middleware func(next http.Handler) http.Handler
|
||||||
|
@ -18,10 +23,30 @@ func (mwh *mwHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
handler := mwh.final
|
handler := mwh.final
|
||||||
|
|
||||||
// build middleware chain
|
// build middleware chain
|
||||||
for _, mw := range mwh.handlers {
|
// loop in reverse to build the handler chain in the correct order
|
||||||
handler = mw(handler)
|
for i := len(mwh.handlers) - 1; i >= 0; i-- {
|
||||||
|
handler = mwh.handlers[i](handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
// start
|
// start
|
||||||
handler.ServeHTTP(w, r)
|
handler.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ModuleWorker is an http middleware that wraps the request in a module worker.
|
||||||
|
func ModuleWorker(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_ = module.RunWorker("http request", func(_ context.Context) error {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// LogTracer is an http middleware that attaches a log tracer to the request context.
|
||||||
|
func LogTracer(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx, tracer := log.AddTracer(r.Context())
|
||||||
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
tracer.Submit()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -19,6 +19,8 @@ var (
|
||||||
middlewareHandler = &mwHandler{
|
middlewareHandler = &mwHandler{
|
||||||
final: mainMux,
|
final: mainMux,
|
||||||
handlers: []Middleware{
|
handlers: []Middleware{
|
||||||
|
ModuleWorker,
|
||||||
|
LogTracer,
|
||||||
RequestLogger,
|
RequestLogger,
|
||||||
authMiddleware,
|
authMiddleware,
|
||||||
},
|
},
|
||||||
|
|
Loading…
Add table
Reference in a new issue