diff --git a/api/authentication.go b/api/authentication.go index 063c1b6..79bb2fe 100644 --- a/api/authentication.go +++ b/api/authentication.go @@ -1,18 +1,21 @@ package api import ( + "context" "encoding/base64" "errors" "net/http" "sync" "time" + "github.com/safing/portbase/modules" + "github.com/safing/portbase/log" "github.com/safing/portbase/rng" ) var ( - validTokens map[string]time.Time + validTokens = make(map[string]time.Time) validTokensLock sync.Mutex authFnLock sync.Mutex @@ -24,24 +27,22 @@ var ( ) const ( - cookieName = "T17" + cookieName = "Portmaster-API-Token" - // in seconds - cookieBaseTTL = 300 // 5 minutes - cookieTTL = cookieBaseTTL * time.Second - cookieRefresh = cookieBaseTTL * 0.9 * time.Second + cookieTTL = 5 * time.Minute ) -// Authenticator is a function that can be set as the authenticator for the API endpoint. If none is set, all requests will be allowed. -type Authenticator func(s *http.Server, r *http.Request) (err error) +// Authenticator is a function that can be set as the authenticator for the API endpoint. If none is set, all requests will be permitted. +type Authenticator func(ctx context.Context, s *http.Server, r *http.Request) (err error) -// SetAuthenticator sets an authenticator function for the API endpoint. If none is set, all requests will be allowed. +// SetAuthenticator sets an authenticator function for the API endpoint. If none is set, all requests will be permitted. func SetAuthenticator(fn Authenticator) error { authFnLock.Lock() defer authFnLock.Unlock() if authFn == nil { authFn = fn + module.NewTask("clean api auth tokens", cleanAuthTokens).Repeat(time.Minute) return nil } @@ -50,27 +51,7 @@ func SetAuthenticator(fn Authenticator) error { func authMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - - // check existing auth cookie - c, err := r.Cookie(cookieName) - if err == nil { - // get token - validTokensLock.Lock() - validUntil, valid := validTokens[c.Value] - validTokensLock.Unlock() - - // check if token is valid - if valid && time.Now().Before(validUntil) { - // maybe refresh cookie - if time.Now().After(validUntil.Add(-cookieRefresh)) { - validTokensLock.Lock() - validTokens[c.Value] = time.Now() - validTokensLock.Unlock() - } - next.ServeHTTP(w, r) - return - } - } + tracer := log.Tracer(r.Context()) // get authenticator authFnLock.Lock() @@ -83,35 +64,78 @@ func authMiddleware(next http.Handler) http.Handler { return } + // check existing auth cookie + c, err := r.Cookie(cookieName) + if err == nil { + // get token + validTokensLock.Lock() + validUntil, valid := validTokens[c.Value] + validTokensLock.Unlock() + + // check if token is valid + if valid && time.Now().Before(validUntil) { + tracer.Tracef("api: auth token %s is valid, refreshing", c.Value) + // refresh cookie + validTokensLock.Lock() + validTokens[c.Value] = time.Now().Add(cookieTTL) + validTokensLock.Unlock() + // continue + next.ServeHTTP(w, r) + return + } + + tracer.Tracef("api: provided auth token %s is invalid", c.Value) + } + // get auth decision - err = authenticator(server, r) + err = authenticator(r.Context(), server, r) if err != nil { if errors.Is(err, ErrAPIAccessDeniedMessage) { - log.Warningf("api: denying api access to %s", r.RemoteAddr) + tracer.Warningf("api: denying api access to %s", r.RemoteAddr) http.Error(w, err.Error(), http.StatusForbidden) } else { - log.Warningf("api: authenticator failed: %s", err) + tracer.Warningf("api: authenticator failed: %s", err) http.Error(w, "Internal server error during authentication.", http.StatusInternalServerError) } return } - // write new cookie + // generate new token token, err := rng.Bytes(32) // 256 bit if err != nil { - log.Warningf("api: failed to generate random token: %s", err) + tracer.Warningf("api: failed to generate random token: %s", err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) } tokenString := base64.RawURLEncoding.EncodeToString(token) + // write new cookie http.SetCookie(w, &http.Cookie{ Name: cookieName, Value: tokenString, HttpOnly: true, + SameSite: http.SameSiteStrictMode, MaxAge: int(cookieTTL.Seconds()), }) + // save cookie + validTokensLock.Lock() + validTokens[tokenString] = time.Now().Add(cookieTTL) + validTokensLock.Unlock() // serve - log.Tracef("api: granted %s", r.RemoteAddr) + tracer.Tracef("api: granted %s, assigned auth token %s", r.RemoteAddr, tokenString) next.ServeHTTP(w, r) }) } + +func cleanAuthTokens(_ context.Context, _ *modules.Task) error { + validTokensLock.Lock() + defer validTokensLock.Unlock() + + now := time.Now() + for token, validUntil := range validTokens { + if now.After(validUntil) { + delete(validTokens, token) + } + } + + return nil +} diff --git a/api/database.go b/api/database.go index 39d128b..9830571 100644 --- a/api/database.go +++ b/api/database.go @@ -127,8 +127,14 @@ func (api *DatabaseAPI) handler() { if err != nil { if !api.shuttingDown.IsSet() { api.shutdown() - if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { - log.Warningf("api: websocket write error: %s", err) + if websocket.IsCloseError(err, + websocket.CloseNormalClosure, + websocket.CloseGoingAway, + websocket.CloseAbnormalClosure, + ) { + log.Infof("api: websocket connection to %s closed", api.conn.RemoteAddr()) + } else { + log.Warningf("api: websocket read error from %s: %s", api.conn.RemoteAddr(), err) } } return @@ -204,8 +210,14 @@ func (api *DatabaseAPI) writer() { if err != nil { if !api.shuttingDown.IsSet() { api.shutdown() - if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { - log.Warningf("api: websocket write error: %s", err) + if websocket.IsCloseError(err, + websocket.CloseNormalClosure, + websocket.CloseGoingAway, + websocket.CloseAbnormalClosure, + ) { + log.Infof("api: websocket connection to %s closed", api.conn.RemoteAddr()) + } else { + log.Warningf("api: websocket write error to %s: %s", api.conn.RemoteAddr(), err) } } return diff --git a/api/enriched-response.go b/api/enriched-response.go index f1e34c2..138124d 100644 --- a/api/enriched-response.go +++ b/api/enriched-response.go @@ -48,7 +48,7 @@ func (lrw *LoggingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) if err != nil { return nil, nil, err } - log.Infof("api request: %s HIJ %s", lrw.Request.RemoteAddr, lrw.Request.RequestURI) + log.Tracer(lrw.Request.Context()).Infof("api request: %s HIJ %s", lrw.Request.RemoteAddr, lrw.Request.RequestURI) return c, b, nil } return nil, nil, errors.New("response does not implement http.Hijacker") @@ -57,12 +57,12 @@ func (lrw *LoggingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) // RequestLogger is a logging middleware. func RequestLogger(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - log.Tracef("api request: %s ___ %s", r.RemoteAddr, r.RequestURI) + log.Tracer(r.Context()).Tracef("api request: %s ___ %s", r.RemoteAddr, r.RequestURI) lrw := NewLoggingResponseWriter(w, r) next.ServeHTTP(lrw, r) if lrw.Status != 0 { // request may have been hijacked - log.Infof("api request: %s %d %s", lrw.Request.RemoteAddr, lrw.Status, lrw.Request.RequestURI) + log.Tracer(r.Context()).Infof("api request: %s %d %s", lrw.Request.RemoteAddr, lrw.Status, lrw.Request.RequestURI) } }) } diff --git a/api/middleware.go b/api/middleware.go index 89dd465..7877037 100644 --- a/api/middleware.go +++ b/api/middleware.go @@ -1,6 +1,11 @@ package api -import "net/http" +import ( + "context" + "net/http" + + "github.com/safing/portbase/log" +) // Middleware is a function that can be added as a middleware to the API endpoint. type Middleware func(next http.Handler) http.Handler @@ -18,10 +23,30 @@ func (mwh *mwHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { handler := mwh.final // build middleware chain - for _, mw := range mwh.handlers { - handler = mw(handler) + // loop in reverse to build the handler chain in the correct order + for i := len(mwh.handlers) - 1; i >= 0; i-- { + handler = mwh.handlers[i](handler) } // start handler.ServeHTTP(w, r) } + +// ModuleWorker is an http middleware that wraps the request in a module worker. +func ModuleWorker(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = module.RunWorker("http request", func(_ context.Context) error { + next.ServeHTTP(w, r) + return nil + }) + }) +} + +// LogTracer is an http middleware that attaches a log tracer to the request context. +func LogTracer(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx, tracer := log.AddTracer(r.Context()) + next.ServeHTTP(w, r.WithContext(ctx)) + tracer.Submit() + }) +} diff --git a/api/router.go b/api/router.go index bed878e..c4ad388 100644 --- a/api/router.go +++ b/api/router.go @@ -19,6 +19,8 @@ var ( middlewareHandler = &mwHandler{ final: mainMux, handlers: []Middleware{ + ModuleWorker, + LogTracer, RequestLogger, authMiddleware, },