Implement review suggestions

This commit is contained in:
Daniel 2021-01-06 13:34:25 +01:00
parent 5daeac8cf7
commit 3244fefd43
10 changed files with 118 additions and 111 deletions

View file

@ -3,24 +3,20 @@ package api
import "net/http"
// WrapInAuthHandler wraps a simple http.HandlerFunc into a handler that
// exposes the given API permissions.
// exposes the required API permissions for this handler.
func WrapInAuthHandler(fn http.HandlerFunc, read, write Permission) http.Handler {
return &wrappedAuthenticatedHandler{
handleFunc: fn,
read: read,
write: write,
HandlerFunc: fn,
read: read,
write: write,
}
}
type wrappedAuthenticatedHandler struct {
handleFunc http.HandlerFunc
read Permission
write Permission
}
http.HandlerFunc
// ServeHTTP handles the http request.
func (wah *wrappedAuthenticatedHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
wah.handleFunc(w, r)
read Permission
write Permission
}
// ReadPermission returns the read permission for the handler.

View file

@ -49,8 +49,10 @@ const (
PermitSelf Permission = 4
)
// Authenticator is a function that can be set as the authenticator for the API endpoint. If none is set, all requests will be permitted.
type Authenticator func(ctx context.Context, s *http.Server, r *http.Request) (*AuthToken, error)
// AuthenticatorFunc is a function that can be set as the authenticator for the
// API endpoint. If none is set, all requests will have full access.
// The returned AuthToken represents the permissions that the request has.
type AuthenticatorFunc func(r *http.Request, s *http.Server) (*AuthToken, error)
// AuthToken represents either a set of required or granted permissions.
// All attributes must be set when the struct is built and must not be changed
@ -81,7 +83,8 @@ func (token *AuthToken) Refresh(ttl time.Duration) {
}
// AuthenticatedHandler defines the handler interface to specify custom
// permission for an API handler.
// permission for an API handler. The returned permission is the required
// permission for the request to proceed.
type AuthenticatedHandler interface {
ReadPermission(*http.Request) Permission
WritePermission(*http.Request) Permission
@ -94,34 +97,35 @@ const (
var (
authFnSet = abool.New()
authFn Authenticator
authFn AuthenticatorFunc
authTokens = make(map[string]*AuthToken)
authTokensLock sync.Mutex
// ErrAPIAccessDeniedMessage should be returned by Authenticator functions in
// order to signify a blocked request, including a error message for the user.
// ErrAPIAccessDeniedMessage should be wrapped by errors returned by
// AuthenticatorFunc in order to signify a blocked request, including a error
// message for the user. This is an empty message on purpose, as to allow the
// function to define the full text of the error shown to the user.
ErrAPIAccessDeniedMessage = errors.New("")
)
// SetAuthenticator sets an authenticator function for the API endpoint. If none is set, all requests will be permitted.
func SetAuthenticator(fn Authenticator) error {
func SetAuthenticator(fn AuthenticatorFunc) error {
if module.Online() {
return ErrAuthenticationImmutable
}
if authFnSet.IsSet() {
if !authFnSet.SetToIf(false, true) {
return ErrAuthenticationAlreadySet
}
authFn = fn
authFnSet.Set()
return nil
}
func authMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := authenticateRequest(w, r, nil)
token := authenticateRequest(w, r, next)
if token != nil {
if _, apiRequest := getAPIContext(r); apiRequest != nil {
apiRequest.AuthToken = token
@ -136,7 +140,7 @@ func authenticateRequest(w http.ResponseWriter, r *http.Request, targetHandler h
// Check if authenticator is set.
if !authFnSet.IsSet() {
// Return highest available permissions.
// Return highest available permissions for the request.
return &AuthToken{
Read: PermitSelf,
Write: PermitSelf,
@ -180,7 +184,8 @@ func authenticateRequest(w http.ResponseWriter, r *http.Request, targetHandler h
requiredPermission = PermitAnyone
}
// Check for valid permission after handling the specials.
// The required permission must match the request permission values after
// handling the specials.
if requiredPermission < PermitAnyone || requiredPermission > PermitSelf {
tracer.Warningf(
"api: handler returned invalid permission: %s (%d)",
@ -197,11 +202,11 @@ func authenticateRequest(w http.ResponseWriter, r *http.Request, targetHandler h
// Get auth token from authenticator if none was in the request.
if token == nil {
var err error
token, err = authFn(r.Context(), server, r)
token, err = authFn(r, server)
if err != nil {
// Check for internal error.
if !errors.Is(err, ErrAPIAccessDeniedMessage) {
tracer.Warningf("api: authenticator failed: %s", err)
tracer.Errorf("api: authenticator failed: %s", err)
http.Error(w, "Internal server error during authentication.", http.StatusInternalServerError)
return nil
}
@ -254,7 +259,13 @@ func authenticateRequest(w http.ResponseWriter, r *http.Request, targetHandler h
}
tracer.Tracef("api: granted %s access to authenticated handler", r.RemoteAddr)
return token
// Make a copy of the AuthToken in order mitigate the handler poisoning the
// token, as changes would apply to future requests.
return &AuthToken{
Read: token.Read,
Write: token.Write,
}
}
func checkAuthToken(r *http.Request) *AuthToken {

View file

@ -1,7 +1,6 @@
package api
import (
"context"
"errors"
"fmt"
"net/http"
@ -14,7 +13,7 @@ var (
testToken = new(AuthToken)
)
func testAuthenticator(ctx context.Context, s *http.Server, r *http.Request) (*AuthToken, error) {
func testAuthenticator(r *http.Request, s *http.Server) (*AuthToken, error) {
switch {
case testToken.Read == -127 || testToken.Write == -127:
return nil, errors.New("test error")

10
api/doc.go Normal file
View file

@ -0,0 +1,10 @@
/*
Package api provides an API for integration with other components of the same software package and also third party components.
It provides direct database access as well as a simpler way to register API endpoints. You can of course also register raw `http.Handler`s directly.
Optional authentication guards registered handlers. This is achieved by attaching functions to the `http.Handler`s that are registered, which allow them to specify the required permissions for the handler.
The permissions are divided into the roles and assume a single user per host. The Roles are User, Admin and Self. User roles are expected to have mostly read access and react to notifications or system events, like a system tray program. The Admin role is meant for advanced components that also change settings, but are restricted so they cannot break the software. Self is reserved for internal use with full access.
*/
package api

View file

@ -7,6 +7,7 @@ import (
"io/ioutil"
"net/http"
"strconv"
"strings"
"sync"
"github.com/safing/portbase/database/record"
@ -29,36 +30,36 @@ type Endpoint struct {
// Order int
// ExpertiseLevel config.ExpertiseLevel
// ActionFn is for simple actions with a return message for the user.
ActionFn ActionFn `json:"-"`
// ActionFunc is for simple actions with a return message for the user.
ActionFunc ActionFunc `json:"-"`
// DataFn is for returning raw data that the caller for further processing.
DataFn DataFn `json:"-"`
// DataFunc is for returning raw data that the caller for further processing.
DataFunc DataFunc `json:"-"`
// StructFn is for returning any kind of struct.
StructFn StructFn `json:"-"`
// StructFunc is for returning any kind of struct.
StructFunc StructFunc `json:"-"`
// RecordFn is for returning a database record. It will be properly locked
// RecordFunc is for returning a database record. It will be properly locked
// and marshalled including metadata.
RecordFn RecordFn `json:"-"`
RecordFunc RecordFunc `json:"-"`
// HandlerFn is the raw http handler.
HandlerFn http.HandlerFunc `json:"-"`
// HandlerFunc is the raw http handler.
HandlerFunc http.HandlerFunc `json:"-"`
}
type (
// ActionFn is for simple actions with a return message for the user.
ActionFn func(ar *Request) (msg string, err error)
// ActionFunc is for simple actions with a return message for the user.
ActionFunc func(ar *Request) (msg string, err error)
// DataFn is for returning raw data that the caller for further processing.
DataFn func(ar *Request) (data []byte, err error)
// DataFunc is for returning raw data that the caller for further processing.
DataFunc func(ar *Request) (data []byte, err error)
// StructFn is for returning any kind of struct.
StructFn func(ar *Request) (i interface{}, err error)
// StructFunc is for returning any kind of struct.
StructFunc func(ar *Request) (i interface{}, err error)
// RecordFn is for returning a database record. It will be properly locked
// RecordFunc is for returning a database record. It will be properly locked
// and marshalled including metadata.
RecordFn func(ar *Request) (r record.Record, err error)
RecordFunc func(ar *Request) (r record.Record, err error)
)
// MIME Types
@ -135,7 +136,7 @@ func RegisterEndpoint(e Endpoint) error {
func (e *Endpoint) check() error {
// Check path.
if e.Path == "" {
if strings.TrimSpace(e.Path) == "" {
return errors.New("path is missing")
}
@ -150,23 +151,23 @@ func (e *Endpoint) check() error {
// Check functions.
var defaultMimeType string
fnCnt := 0
if e.ActionFn != nil {
if e.ActionFunc != nil {
fnCnt++
defaultMimeType = MimeTypeText
}
if e.DataFn != nil {
if e.DataFunc != nil {
fnCnt++
defaultMimeType = MimeTypeText
}
if e.StructFn != nil {
if e.StructFunc != nil {
fnCnt++
defaultMimeType = MimeTypeJSON
}
if e.RecordFn != nil {
if e.RecordFunc != nil {
fnCnt++
defaultMimeType = MimeTypeJSON
}
if e.HandlerFn != nil {
if e.HandlerFunc != nil {
fnCnt++
defaultMimeType = MimeTypeText
}
@ -214,7 +215,7 @@ func (eh *endpointHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodHead:
http.Error(w, "", http.StatusOK)
w.WriteHeader(http.StatusOK)
return
case http.MethodPost, http.MethodPut:
// Read body data.
@ -235,32 +236,32 @@ func (eh *endpointHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var err error
switch {
case apiEndpoint.ActionFn != nil:
case apiEndpoint.ActionFunc != nil:
var msg string
msg, err = apiEndpoint.ActionFn(apiRequest)
msg, err = apiEndpoint.ActionFunc(apiRequest)
if err == nil {
responseData = []byte(msg)
}
case apiEndpoint.DataFn != nil:
responseData, err = apiEndpoint.DataFn(apiRequest)
case apiEndpoint.DataFunc != nil:
responseData, err = apiEndpoint.DataFunc(apiRequest)
case apiEndpoint.StructFn != nil:
case apiEndpoint.StructFunc != nil:
var v interface{}
v, err = apiEndpoint.StructFn(apiRequest)
v, err = apiEndpoint.StructFunc(apiRequest)
if err == nil && v != nil {
responseData, err = json.Marshal(v)
}
case apiEndpoint.RecordFn != nil:
case apiEndpoint.RecordFunc != nil:
var rec record.Record
rec, err = apiEndpoint.RecordFn(apiRequest)
rec, err = apiEndpoint.RecordFunc(apiRequest)
if err == nil && r != nil {
responseData, err = marshalRecord(rec, false)
}
case apiEndpoint.HandlerFn != nil:
apiEndpoint.HandlerFn(w, r)
case apiEndpoint.HandlerFunc != nil:
apiEndpoint.HandlerFunc(w, r)
return
default:
@ -287,7 +288,7 @@ func (eh *endpointHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func readBody(w http.ResponseWriter, r *http.Request) (inputData []byte, ok bool) {
// Check for too long content in order to prevent death.
if r.ContentLength > 20000000 { // 20MB
http.Error(w, "Too much input data.", http.StatusBadRequest)
http.Error(w, "Too much input data.", http.StatusRequestEntityTooLarge)
return nil, false
}
@ -297,6 +298,5 @@ func readBody(w http.ResponseWriter, r *http.Request) (inputData []byte, ok bool
http.Error(w, "Failed to read body: "+err.Error(), http.StatusInternalServerError)
return nil, false
}
r.Body.Close()
return inputData, true
}

View file

@ -11,25 +11,25 @@ import (
func registerDebugEndpoints() error {
if err := RegisterEndpoint(Endpoint{
Path: "debug/stack",
Read: PermitAnyone,
DataFn: getStack,
}); err != nil {
return err
}
if err := RegisterEndpoint(Endpoint{
Path: "debug/stack/print",
Path: "debug/stack",
Read: PermitAnyone,
ActionFn: printStack,
DataFunc: getStack,
}); err != nil {
return err
}
if err := RegisterEndpoint(Endpoint{
Path: "debug/info",
Read: PermitAnyone,
DataFn: debugInfo,
Path: "debug/stack/print",
Read: PermitAnyone,
ActionFunc: printStack,
}); err != nil {
return err
}
if err := RegisterEndpoint(Endpoint{
Path: "debug/info",
Read: PermitAnyone,
DataFunc: debugInfo,
}); err != nil {
return err
}
@ -70,7 +70,7 @@ func debugInfo(ar *Request) (data []byte, err error) {
// Add debug information.
di.AddVersionInfo()
di.AddPlatformInfo(ar.Ctx())
di.AddPlatformInfo(ar.Context())
di.AddLastReportedModuleError()
di.AddLastUnexpectedLogs()
di.AddGoroutineStack()

View file

@ -10,23 +10,23 @@ func registerMetaEndpoints() error {
Path: "endpoints",
Read: PermitAnyone,
MimeType: MimeTypeJSON,
DataFn: listEndpoints,
DataFunc: listEndpoints,
}); err != nil {
return err
}
if err := RegisterEndpoint(Endpoint{
Path: "permission",
Read: Require,
StructFn: permissions,
Path: "permission",
Read: Require,
StructFunc: permissions,
}); err != nil {
return err
}
if err := RegisterEndpoint(Endpoint{
Path: "ping",
Read: PermitAnyone,
ActionFn: ping,
Path: "ping",
Read: PermitAnyone,
ActionFunc: ping,
}); err != nil {
return err
}

View file

@ -30,7 +30,7 @@ func TestEndpoints(t *testing.T) {
assert.NoError(t, RegisterEndpoint(Endpoint{
Path: "test/action",
Read: PermitAnyone,
ActionFn: func(_ *Request) (msg string, err error) {
ActionFunc: func(_ *Request) (msg string, err error) {
return successMsg, nil
},
}))
@ -39,7 +39,7 @@ func TestEndpoints(t *testing.T) {
assert.NoError(t, RegisterEndpoint(Endpoint{
Path: "test/action-err",
Read: PermitAnyone,
ActionFn: func(_ *Request) (msg string, err error) {
ActionFunc: func(_ *Request) (msg string, err error) {
return "", errors.New(failedMsg)
},
}))
@ -50,7 +50,7 @@ func TestEndpoints(t *testing.T) {
assert.NoError(t, RegisterEndpoint(Endpoint{
Path: "test/data",
Read: PermitAnyone,
DataFn: func(_ *Request) (data []byte, err error) {
DataFunc: func(_ *Request) (data []byte, err error) {
return []byte(successMsg), nil
},
}))
@ -59,7 +59,7 @@ func TestEndpoints(t *testing.T) {
assert.NoError(t, RegisterEndpoint(Endpoint{
Path: "test/data-err",
Read: PermitAnyone,
DataFn: func(_ *Request) (data []byte, err error) {
DataFunc: func(_ *Request) (data []byte, err error) {
return nil, errors.New(failedMsg)
},
}))
@ -70,7 +70,7 @@ func TestEndpoints(t *testing.T) {
assert.NoError(t, RegisterEndpoint(Endpoint{
Path: "test/struct",
Read: PermitAnyone,
StructFn: func(_ *Request) (i interface{}, err error) {
StructFunc: func(_ *Request) (i interface{}, err error) {
return &actionTestRecord{
Msg: successMsg,
}, nil
@ -81,7 +81,7 @@ func TestEndpoints(t *testing.T) {
assert.NoError(t, RegisterEndpoint(Endpoint{
Path: "test/struct-err",
Read: PermitAnyone,
StructFn: func(_ *Request) (i interface{}, err error) {
StructFunc: func(_ *Request) (i interface{}, err error) {
return nil, errors.New(failedMsg)
},
}))
@ -92,7 +92,7 @@ func TestEndpoints(t *testing.T) {
assert.NoError(t, RegisterEndpoint(Endpoint{
Path: "test/record",
Read: PermitAnyone,
RecordFn: func(_ *Request) (r record.Record, err error) {
RecordFunc: func(_ *Request) (r record.Record, err error) {
r = &actionTestRecord{
Msg: successMsg,
}
@ -105,7 +105,7 @@ func TestEndpoints(t *testing.T) {
assert.NoError(t, RegisterEndpoint(Endpoint{
Path: "test/record-err",
Read: PermitAnyone,
RecordFn: func(_ *Request) (r record.Record, err error) {
RecordFunc: func(_ *Request) (r record.Record, err error) {
return nil, errors.New(failedMsg)
},
}))
@ -139,17 +139,17 @@ func TestActionRegistration(t *testing.T) {
assert.Error(t, RegisterEndpoint(Endpoint{
Path: "test/err",
ActionFn: func(_ *Request) (msg string, err error) {
ActionFunc: func(_ *Request) (msg string, err error) {
return successMsg, nil
},
DataFn: func(_ *Request) (data []byte, err error) {
DataFunc: func(_ *Request) (data []byte, err error) {
return []byte(successMsg), nil
},
}))
assert.NoError(t, RegisterEndpoint(Endpoint{
Path: "test/err",
ActionFn: func(_ *Request) (msg string, err error) {
ActionFunc: func(_ *Request) (msg string, err error) {
return successMsg, nil
},
}))

View file

@ -1,7 +1,6 @@
package api
import (
"context"
"net/http"
"github.com/gorilla/mux"
@ -10,7 +9,7 @@ import (
// Request is a support struct to pool more request related information.
type Request struct {
// Request is the http request.
Request *http.Request
*http.Request
// InputData contains the request body for write operations.
InputData []byte
@ -28,11 +27,6 @@ type Request struct {
HandlerCache interface{}
}
// Ctx is a shortcut to access the request context.
func (ar *Request) Ctx() context.Context {
return ar.Request.Context()
}
// apiRequestContextKey is a key used for the context key/value storage.
type apiRequestContextKey struct{}

View file

@ -2,6 +2,7 @@ package api
import (
"context"
"errors"
"net/http"
"sync"
"time"
@ -127,21 +128,17 @@ func (mh *mainHandler) handle(w http.ResponseWriter, r *http.Request) error {
}
// Check authentication.
token := authenticateRequest(lrw, r, handler)
if token == nil {
apiRequest.AuthToken = authenticateRequest(lrw, r, handler)
if apiRequest.AuthToken == nil {
// Authenticator already replied.
return nil
}
apiRequest.AuthToken = &AuthToken{
Read: token.Read,
Write: token.Write,
}
// Handle request.
switch {
case handler != nil:
handler.ServeHTTP(lrw, r)
case match.MatchErr == mux.ErrMethodMismatch:
case errors.Is(match.MatchErr, mux.ErrMethodMismatch):
http.Error(lrw, "Method not allowed.", http.StatusMethodNotAllowed)
default: // handler == nil or other error
http.Error(lrw, "Not found.", http.StatusNotFound)