diff --git a/api/authentication.go b/api/authentication.go index 063c1b6..79bb2fe 100644 --- a/api/authentication.go +++ b/api/authentication.go @@ -1,18 +1,21 @@ package api import ( + "context" "encoding/base64" "errors" "net/http" "sync" "time" + "github.com/safing/portbase/modules" + "github.com/safing/portbase/log" "github.com/safing/portbase/rng" ) var ( - validTokens map[string]time.Time + validTokens = make(map[string]time.Time) validTokensLock sync.Mutex authFnLock sync.Mutex @@ -24,24 +27,22 @@ var ( ) const ( - cookieName = "T17" + cookieName = "Portmaster-API-Token" - // in seconds - cookieBaseTTL = 300 // 5 minutes - cookieTTL = cookieBaseTTL * time.Second - cookieRefresh = cookieBaseTTL * 0.9 * time.Second + cookieTTL = 5 * time.Minute ) -// Authenticator is a function that can be set as the authenticator for the API endpoint. If none is set, all requests will be allowed. -type Authenticator func(s *http.Server, r *http.Request) (err error) +// Authenticator is a function that can be set as the authenticator for the API endpoint. If none is set, all requests will be permitted. +type Authenticator func(ctx context.Context, s *http.Server, r *http.Request) (err error) -// SetAuthenticator sets an authenticator function for the API endpoint. If none is set, all requests will be allowed. +// SetAuthenticator sets an authenticator function for the API endpoint. If none is set, all requests will be permitted. func SetAuthenticator(fn Authenticator) error { authFnLock.Lock() defer authFnLock.Unlock() if authFn == nil { authFn = fn + module.NewTask("clean api auth tokens", cleanAuthTokens).Repeat(time.Minute) return nil } @@ -50,27 +51,7 @@ func SetAuthenticator(fn Authenticator) error { func authMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - - // check existing auth cookie - c, err := r.Cookie(cookieName) - if err == nil { - // get token - validTokensLock.Lock() - validUntil, valid := validTokens[c.Value] - validTokensLock.Unlock() - - // check if token is valid - if valid && time.Now().Before(validUntil) { - // maybe refresh cookie - if time.Now().After(validUntil.Add(-cookieRefresh)) { - validTokensLock.Lock() - validTokens[c.Value] = time.Now() - validTokensLock.Unlock() - } - next.ServeHTTP(w, r) - return - } - } + tracer := log.Tracer(r.Context()) // get authenticator authFnLock.Lock() @@ -83,35 +64,78 @@ func authMiddleware(next http.Handler) http.Handler { return } + // check existing auth cookie + c, err := r.Cookie(cookieName) + if err == nil { + // get token + validTokensLock.Lock() + validUntil, valid := validTokens[c.Value] + validTokensLock.Unlock() + + // check if token is valid + if valid && time.Now().Before(validUntil) { + tracer.Tracef("api: auth token %s is valid, refreshing", c.Value) + // refresh cookie + validTokensLock.Lock() + validTokens[c.Value] = time.Now().Add(cookieTTL) + validTokensLock.Unlock() + // continue + next.ServeHTTP(w, r) + return + } + + tracer.Tracef("api: provided auth token %s is invalid", c.Value) + } + // get auth decision - err = authenticator(server, r) + err = authenticator(r.Context(), server, r) if err != nil { if errors.Is(err, ErrAPIAccessDeniedMessage) { - log.Warningf("api: denying api access to %s", r.RemoteAddr) + tracer.Warningf("api: denying api access to %s", r.RemoteAddr) http.Error(w, err.Error(), http.StatusForbidden) } else { - log.Warningf("api: authenticator failed: %s", err) + tracer.Warningf("api: authenticator failed: %s", err) http.Error(w, "Internal server error during authentication.", http.StatusInternalServerError) } return } - // write new cookie + // generate new token token, err := rng.Bytes(32) // 256 bit if err != nil { - log.Warningf("api: failed to generate random token: %s", err) + tracer.Warningf("api: failed to generate random token: %s", err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) } tokenString := base64.RawURLEncoding.EncodeToString(token) + // write new cookie http.SetCookie(w, &http.Cookie{ Name: cookieName, Value: tokenString, HttpOnly: true, + SameSite: http.SameSiteStrictMode, MaxAge: int(cookieTTL.Seconds()), }) + // save cookie + validTokensLock.Lock() + validTokens[tokenString] = time.Now().Add(cookieTTL) + validTokensLock.Unlock() // serve - log.Tracef("api: granted %s", r.RemoteAddr) + tracer.Tracef("api: granted %s, assigned auth token %s", r.RemoteAddr, tokenString) next.ServeHTTP(w, r) }) } + +func cleanAuthTokens(_ context.Context, _ *modules.Task) error { + validTokensLock.Lock() + defer validTokensLock.Unlock() + + now := time.Now() + for token, validUntil := range validTokens { + if now.After(validUntil) { + delete(validTokens, token) + } + } + + return nil +}