Add endpoint api and authentication layer

This commit is contained in:
Daniel 2020-12-25 23:38:13 +01:00
parent 3f544cb07f
commit 9100dc999b
11 changed files with 1257 additions and 115 deletions

34
api/auth_wrapper.go Normal file
View file

@ -0,0 +1,34 @@
package api
import "net/http"
// WrapInAuthHandler wraps a simple http.HandlerFunc into a handler that
// exposes the given API permissions.
func WrapInAuthHandler(fn http.HandlerFunc, read, write Permission) http.Handler {
return &wrappedAuthenticatedHandler{
handleFunc: fn,
read: read,
write: write,
}
}
type wrappedAuthenticatedHandler struct {
handleFunc http.HandlerFunc
read Permission
write Permission
}
// ServeHTTP handles the http request.
func (wah *wrappedAuthenticatedHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
wah.handleFunc(w, r)
}
// ReadPermission returns the read permission for the handler.
func (wah *wrappedAuthenticatedHandler) ReadPermission(r *http.Request) Permission {
return wah.read
}
// WritePermission returns the write permission for the handler.
func (wah *wrappedAuthenticatedHandler) WritePermission(r *http.Request) Permission {
return wah.write
}

View file

@ -8,138 +8,345 @@ import (
"sync"
"time"
"github.com/tevino/abool"
"github.com/safing/portbase/modules"
"github.com/safing/portbase/log"
"github.com/safing/portbase/rng"
)
var (
validTokens = make(map[string]time.Time)
validTokensLock sync.Mutex
// Permission defines an API requests permission.
type Permission int8
authFnLock sync.Mutex
authFn Authenticator
const (
// NotFound declares that the operation does not exist.
NotFound Permission = -2
// Require declares that the operation requires permission to be processed,
// but anyone can execute the operation.
Require Permission = -1
// NotSupported declares that the operation is not supported.
NotSupported Permission = 0
// PermitAnyone declares that anyone can execute the operation without any
// authentication.
PermitAnyone Permission = 1
// PermitUser declares that the operation may be executed by authenticated
// third party applications that are categorized as representing a simple
// user and is limited in access.
PermitUser Permission = 2
// PermitAdmin declares that the operation may be executed by authenticated
// third party applications that are categorized as representing an
// administrator and has broad in access.
PermitAdmin Permission = 3
// PermitSelf declares that the operation may only be executed by the
// software itself and its own (first party) components.
PermitSelf Permission = 4
)
// Authenticator is a function that can be set as the authenticator for the API endpoint. If none is set, all requests will be permitted.
type Authenticator func(ctx context.Context, s *http.Server, r *http.Request) (*AuthToken, error)
// AuthToken represents either a set of required or granted permissions.
// All attributes must be set when the struct is built and must not be changed
// later. Functions may be called at any time.
// The Write permission implicitly also includes reading.
type AuthToken struct {
Read Permission
Write Permission
validUntil time.Time
validLock sync.Mutex
}
// Expired returns whether the token has expired.
func (token *AuthToken) Expired() bool {
token.validLock.Lock()
defer token.validLock.Unlock()
return time.Now().After(token.validUntil)
}
// Refresh refreshes the validity of the token with the given TTL.
func (token *AuthToken) Refresh(ttl time.Duration) {
token.validLock.Lock()
defer token.validLock.Unlock()
token.validUntil = time.Now().Add(ttl)
}
// AuthenticatedHandler defines the handler interface to specify custom
// permission for an API handler.
type AuthenticatedHandler interface {
ReadPermission(*http.Request) Permission
WritePermission(*http.Request) Permission
}
const (
cookieName = "Portmaster-API-Token"
cookieTTL = 5 * time.Minute
)
var (
authFnSet = abool.New()
authFn Authenticator
authTokens = make(map[string]*AuthToken)
authTokensLock sync.Mutex
// ErrAPIAccessDeniedMessage should be returned by Authenticator functions in
// order to signify a blocked request, including a error message for the user.
ErrAPIAccessDeniedMessage = errors.New("")
)
const (
cookieName = "Portmaster-API-Token"
cookieTTL = 5 * time.Minute
)
// Authenticator is a function that can be set as the authenticator for the API endpoint. If none is set, all requests will be permitted.
type Authenticator func(ctx context.Context, s *http.Server, r *http.Request) (err error)
// SetAuthenticator sets an authenticator function for the API endpoint. If none is set, all requests will be permitted.
func SetAuthenticator(fn Authenticator) error {
if module.Online() {
return ErrAuthenticationImmutable
}
authFnLock.Lock()
defer authFnLock.Unlock()
if authFn != nil {
if authFnSet.IsSet() {
return ErrAuthenticationAlreadySet
}
authFn = fn
authFnSet.Set()
return nil
}
func authMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tracer := log.Tracer(r.Context())
// get authenticator
authFnLock.Lock()
authenticator := authFn
authFnLock.Unlock()
// permit if no authenticator set
if authenticator == nil {
token := authenticateRequest(w, r, nil)
if token != nil {
if _, apiRequest := getAPIContext(r); apiRequest != nil {
apiRequest.AuthToken = token
}
next.ServeHTTP(w, r)
return
}
// check existing auth cookie
c, err := r.Cookie(cookieName)
if err == nil {
// get token
validTokensLock.Lock()
validUntil, valid := validTokens[c.Value]
validTokensLock.Unlock()
// check if token is valid
if valid && time.Now().Before(validUntil) {
tracer.Tracef("api: auth token %s is valid, refreshing", c.Value)
// refresh cookie
validTokensLock.Lock()
validTokens[c.Value] = time.Now().Add(cookieTTL)
validTokensLock.Unlock()
// continue
next.ServeHTTP(w, r)
return
}
tracer.Tracef("api: provided auth token %s is invalid", c.Value)
}
// get auth decision
err = authenticator(r.Context(), server, r)
if err != nil {
if errors.Is(err, ErrAPIAccessDeniedMessage) {
tracer.Warningf("api: denying api access to %s", r.RemoteAddr)
http.Error(w, err.Error(), http.StatusForbidden)
} else {
tracer.Warningf("api: authenticator failed: %s", err)
http.Error(w, "Internal server error during authentication.", http.StatusInternalServerError)
}
return
}
// generate new token
token, err := rng.Bytes(32) // 256 bit
if err != nil {
tracer.Warningf("api: failed to generate random token: %s", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
}
tokenString := base64.RawURLEncoding.EncodeToString(token)
// write new cookie
http.SetCookie(w, &http.Cookie{
Name: cookieName,
Value: tokenString,
Path: "/",
HttpOnly: true,
SameSite: http.SameSiteStrictMode,
MaxAge: int(cookieTTL.Seconds()),
})
// save cookie
validTokensLock.Lock()
validTokens[tokenString] = time.Now().Add(cookieTTL)
validTokensLock.Unlock()
// serve
tracer.Tracef("api: granted %s, assigned auth token %s", r.RemoteAddr, tokenString)
next.ServeHTTP(w, r)
})
}
func cleanAuthTokens(_ context.Context, _ *modules.Task) error {
validTokensLock.Lock()
defer validTokensLock.Unlock()
func authenticateRequest(w http.ResponseWriter, r *http.Request, targetHandler http.Handler) *AuthToken {
tracer := log.Tracer(r.Context())
now := time.Now()
for token, validUntil := range validTokens {
if now.After(validUntil) {
delete(validTokens, token)
// Check if authenticator is set.
if !authFnSet.IsSet() {
// Return highest available permissions.
return &AuthToken{
Read: PermitSelf,
Write: PermitSelf,
}
}
// Check if request is read only.
readRequest := isReadMethod(r.Method)
// Get required permission for target handler.
requiredPermission := PermitSelf
if authdHandler, ok := targetHandler.(AuthenticatedHandler); ok {
if readRequest {
requiredPermission = authdHandler.ReadPermission(r)
} else {
requiredPermission = authdHandler.WritePermission(r)
}
}
// Check if we need to do any authentication at all.
switch requiredPermission {
case NotFound:
// Not found.
tracer.Trace("api: authenticated handler reported: not found")
http.Error(w, "Not found.", http.StatusNotFound)
return nil
case NotSupported:
// A read or write permission can be marked as not supported.
tracer.Trace("api: authenticated handler reported: not supported")
http.Error(w, "Method not allowed.", http.StatusMethodNotAllowed)
return nil
case PermitAnyone:
// Don't process permissions, as we don't need them.
tracer.Tracef("api: granted %s access to public handler", r.RemoteAddr)
return &AuthToken{
Read: PermitAnyone,
Write: PermitAnyone,
}
case Require:
// Continue processing permissions, but treat as PermitAnyone.
requiredPermission = PermitAnyone
}
// Check for valid permission after handling the specials.
if requiredPermission < PermitAnyone || requiredPermission > PermitSelf {
tracer.Warningf(
"api: handler returned invalid permission: %s (%d)",
requiredPermission,
requiredPermission,
)
http.Error(w, "Internal server error during authentication.", http.StatusInternalServerError)
return nil
}
// Check for an existing auth token.
token := checkAuthToken(r)
// Get auth token from authenticator if none was in the request.
if token == nil {
var err error
token, err = authFn(r.Context(), server, r)
if err != nil {
// Check for internal error.
if !errors.Is(err, ErrAPIAccessDeniedMessage) {
tracer.Warningf("api: authenticator failed: %s", err)
http.Error(w, "Internal server error during authentication.", http.StatusInternalServerError)
return nil
}
// If authentication failed and we require authentication, return an
// authentication error.
if requiredPermission != PermitAnyone {
// Return authentication error.
tracer.Warningf("api: denying api access to %s", r.RemoteAddr)
http.Error(w, err.Error(), http.StatusForbidden)
return nil
}
token = &AuthToken{
Read: PermitAnyone,
Write: PermitAnyone,
}
}
// Apply auth token to request.
err = applyAuthToken(w, token)
if err != nil {
tracer.Warningf("api: failed to create auth token: %s", err)
}
}
// Get effective permission for request.
var requestPermission Permission
if readRequest {
requestPermission = token.Read
} else {
requestPermission = token.Write
}
// Check for valid request permission.
if requestPermission < PermitAnyone || requestPermission > PermitSelf {
tracer.Warningf(
"api: authenticator returned invalid permission: %s (%d)",
requestPermission,
requestPermission,
)
http.Error(w, "Internal server error during authentication.", http.StatusInternalServerError)
return nil
}
// Check permission.
if requestPermission < requiredPermission {
http.Error(w, "Insufficient permissions.", http.StatusForbidden)
return nil
}
tracer.Tracef("api: granted %s access to authenticated handler", r.RemoteAddr)
return token
}
func checkAuthToken(r *http.Request) *AuthToken {
// Get auth token from request.
c, err := r.Cookie(cookieName)
if err != nil {
return nil
}
// Check if auth token is registered.
authTokensLock.Lock()
token, ok := authTokens[c.Value]
authTokensLock.Unlock()
if !ok {
log.Tracer(r.Context()).Tracef("api: provided auth token %s is unknown", c.Value)
return nil
}
// Check if token is still valid.
if token.Expired() {
log.Tracer(r.Context()).Tracef("api: provided auth token %s has expired", c.Value)
return nil
}
// Refresh token and return.
token.Refresh(cookieTTL)
log.Tracer(r.Context()).Tracef("api: auth token %s is valid, refreshing", c.Value)
return token
}
func applyAuthToken(w http.ResponseWriter, token *AuthToken) error {
// Generate new token secret.
secret, err := rng.Bytes(32) // 256 bit
if err != nil {
return err
}
secretHex := base64.RawURLEncoding.EncodeToString(secret)
// Set token cookie in response.
http.SetCookie(w, &http.Cookie{
Name: cookieName,
Value: secretHex,
Path: "/",
HttpOnly: true,
SameSite: http.SameSiteStrictMode,
})
// Set token TTL.
token.Refresh(cookieTTL)
// Save token.
authTokensLock.Lock()
defer authTokensLock.Unlock()
authTokens[secretHex] = token
return nil
}
func cleanAuthTokens(_ context.Context, _ *modules.Task) error {
authTokensLock.Lock()
defer authTokensLock.Unlock()
for secret, token := range authTokens {
if token.Expired() {
delete(authTokens, secret)
}
}
return nil
}
func isReadMethod(method string) bool {
return method == http.MethodGet || method == http.MethodHead
}
func (p Permission) String() string {
switch p {
case NotSupported:
return "NotSupported"
case Require:
return "Require"
case PermitAnyone:
return "PermitAnyone"
case PermitUser:
return "PermitUser"
case PermitAdmin:
return "PermitAdmin"
case PermitSelf:
return "PermitSelf"
case NotFound:
return "NotFound"
default:
return "Unknown"
}
}

196
api/authentication_test.go Normal file
View file

@ -0,0 +1,196 @@
package api
import (
"context"
"errors"
"fmt"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
)
var (
testToken = new(AuthToken)
)
func testAuthenticator(ctx context.Context, s *http.Server, r *http.Request) (*AuthToken, error) {
switch {
case testToken.Read == -127 || testToken.Write == -127:
return nil, errors.New("test error")
case testToken.Read == -128 || testToken.Write == -128:
return nil, fmt.Errorf("%wdenied", ErrAPIAccessDeniedMessage)
default:
return testToken, nil
}
}
type testAuthHandler struct {
Read Permission
Write Permission
}
func (ah *testAuthHandler) ReadPermission(r *http.Request) Permission {
return ah.Read
}
func (ah *testAuthHandler) WritePermission(r *http.Request) Permission {
return ah.Write
}
func (ah *testAuthHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Check if request is as expected.
ar := GetAPIRequest(r)
switch {
case ar == nil:
http.Error(w, "ar == nil", http.StatusInternalServerError)
case ar.AuthToken == nil:
http.Error(w, "ar.AuthToken == nil", http.StatusInternalServerError)
default:
http.Error(w, "auth success", http.StatusOK)
}
}
func makeAuthTestPath(reading bool, p Permission) string {
if reading {
return fmt.Sprintf("/test/auth/read/%s", p)
}
return fmt.Sprintf("/test/auth/write/%s", p)
}
func init() {
// Set test authenticator.
err := SetAuthenticator(testAuthenticator)
if err != nil {
panic(err)
}
}
func TestPermissions(t *testing.T) { //nolint:gocognit
testHandler := &mainHandler{
mux: mainMux,
}
// Define permissions that need testing.
permissionsToTest := []Permission{
NotSupported,
PermitAnyone,
PermitUser,
PermitAdmin,
PermitSelf,
Require,
NotFound,
100, // Test a too high value.
-100, // Test a too low value.
-127, // Simulate authenticator failure.
-128, // Simulate authentication denied message.
}
// Register test handlers.
for _, p := range permissionsToTest {
RegisterHandler(makeAuthTestPath(true, p), &testAuthHandler{Read: p})
RegisterHandler(makeAuthTestPath(false, p), &testAuthHandler{Write: p})
}
// Test all the combinations.
for _, requestPerm := range permissionsToTest {
for _, handlerPerm := range permissionsToTest {
for _, method := range []string{
http.MethodGet,
http.MethodHead,
http.MethodPost,
http.MethodPut,
} {
// Set request permission for test requests.
reading := isReadMethod(method)
if reading {
testToken.Read = requestPerm
testToken.Write = NotSupported
} else {
testToken.Read = NotSupported
testToken.Write = requestPerm
}
// Evaluate expected result.
var expectSuccess bool
switch {
case handlerPerm == PermitAnyone:
// This is fast-tracked. There are not additional checks.
expectSuccess = true
case handlerPerm == Require:
// This is turned into PermitAnyone in the authenticator.
// But authentication is still processed and the result still gets
// sanity checked!
if requestPerm >= PermitAnyone &&
requestPerm <= PermitSelf {
expectSuccess = true
}
// Another special case is when the handler requires permission to be
// processed but the authenticator fails to authenticate the request.
// In this case, a fallback token with PermitAnyone is used.
if requestPerm == -128 {
// -128 is used to simulate a permission denied message.
expectSuccess = true
}
case handlerPerm <= NotSupported:
// Invalid handler permission.
case handlerPerm > PermitSelf:
// Invalid handler permission.
case requestPerm <= NotSupported:
// Invalid request permission.
case requestPerm > PermitSelf:
// Invalid request permission.
case requestPerm < handlerPerm:
// Valid, but insufficient request permission.
default:
expectSuccess = true
}
if expectSuccess {
// Test for success.
if !assert.HTTPBodyContains(
t,
testHandler.ServeHTTP,
method,
makeAuthTestPath(reading, handlerPerm),
nil,
"auth success",
) {
t.Errorf(
"%s with %s (%d) to handler %s (%d)",
method,
requestPerm, requestPerm,
handlerPerm, handlerPerm,
)
}
} else {
// Test for error.
if !assert.HTTPError(t,
testHandler.ServeHTTP,
method,
makeAuthTestPath(reading, handlerPerm),
nil,
) {
t.Errorf(
"%s with %s (%d) to handler %s (%d)",
method,
requestPerm, requestPerm,
handlerPerm, handlerPerm,
)
}
}
}
}
}
}
func TestPermissionDefinitions(t *testing.T) {
if NotSupported != 0 {
t.Fatalf("NotSupported must be zero, was %v", NotSupported)
}
}

View file

@ -278,7 +278,7 @@ func (api *DatabaseAPI) handleGet(opID []byte, key string) {
r, err := api.db.Get(key)
if err == nil {
data, err = marshalRecord(r)
data, err = marshalRecord(r, true)
}
if err != nil {
api.send(opID, dbMsgTypeError, err.Error(), nil)
@ -336,7 +336,7 @@ func (api *DatabaseAPI) processQuery(opID []byte, q *query.Query) (ok bool) {
// process query feed
if r != nil {
// process record
data, err := marshalRecord(r)
data, err := marshalRecord(r, true)
if err != nil {
api.send(opID, dbMsgTypeWarning, err.Error(), nil)
}
@ -412,7 +412,7 @@ func (api *DatabaseAPI) processSub(opID []byte, sub *database.Subscription) {
// process sub feed
if r != nil {
// process record
data, err := marshalRecord(r)
data, err := marshalRecord(r, true)
if err != nil {
api.send(opID, dbMsgTypeWarning, err.Error(), nil)
continue
@ -625,7 +625,7 @@ func (api *DatabaseAPI) shutdown() {
// marsharlRecords locks and marshals the given record, additionally adding
// metadata and returning it as json.
func marshalRecord(r record.Record) ([]byte, error) {
func marshalRecord(r record.Record, withDSDIdentifier bool) ([]byte, error) {
r.Lock()
defer r.Unlock()
@ -651,10 +651,12 @@ func marshalRecord(r record.Record) ([]byte, error) {
}
// Add JSON identifier again.
formatID := varint.Pack8(record.JSON)
finalData := make([]byte, 0, len(formatID)+len(jsonData))
finalData = append(finalData, formatID...)
finalData = append(finalData, jsonData...)
return finalData, nil
if withDSDIdentifier {
formatID := varint.Pack8(record.JSON)
finalData := make([]byte, 0, len(formatID)+len(jsonData))
finalData = append(finalData, formatID...)
finalData = append(finalData, jsonData...)
return finalData, nil
}
return jsonData, nil
}

302
api/endpoints.go Normal file
View file

@ -0,0 +1,302 @@
package api
import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"strconv"
"sync"
"github.com/safing/portbase/database/record"
"github.com/safing/portbase/log"
)
// Endpoint describes an API Endpoint.
// Path and at least one permission are required.
// As is exactly one function.
type Endpoint struct {
Path string
MimeType string
Read Permission
Write Permission
// TODO: We _could_ expose more metadata to be able to build lists of actions
// automatically.
// Name string
// Description string
// Order int
// ExpertiseLevel config.ExpertiseLevel
// ActionFn is for simple actions with a return message for the user.
ActionFn ActionFn `json:"-"`
// DataFn is for returning raw data that the caller for further processing.
DataFn DataFn `json:"-"`
// StructFn is for returning any kind of struct.
StructFn StructFn `json:"-"`
// RecordFn is for returning a database record. It will be properly locked
// and marshalled including metadata.
RecordFn RecordFn `json:"-"`
// HandlerFn is the raw http handler.
HandlerFn http.HandlerFunc `json:"-"`
}
type (
// ActionFn is for simple actions with a return message for the user.
ActionFn func(ar *Request) (msg string, err error)
// DataFn is for returning raw data that the caller for further processing.
DataFn func(ar *Request) (data []byte, err error)
// StructFn is for returning any kind of struct.
StructFn func(ar *Request) (i interface{}, err error)
// RecordFn is for returning a database record. It will be properly locked
// and marshalled including metadata.
RecordFn func(ar *Request) (r record.Record, err error)
)
// MIME Types
const (
MimeTypeJSON string = "application/json"
MimeTypeText string = "text/plain"
apiV1Path = "/api/v1/"
)
func init() {
RegisterHandler(apiV1Path+"{endpointPath:.+}", &endpointHandler{})
}
var (
endpoints = make(map[string]*Endpoint)
endpointsLock sync.RWMutex
// ErrInvalidEndpoint is returned when an invalid endpoint is registered.
ErrInvalidEndpoint = errors.New("endpoint is invalid")
// ErrAlreadyRegistered is returned when there already is an endpoint with
// the same path registered.
ErrAlreadyRegistered = errors.New("an endpoint for this path is already registered")
)
func getAPIContext(r *http.Request) (apiEndpoint *Endpoint, apiRequest *Request) {
// Get request context and check if we already have an action cached.
apiRequest = GetAPIRequest(r)
if apiRequest == nil {
return nil, nil
}
var ok bool
apiEndpoint, ok = apiRequest.HandlerCache.(*Endpoint)
if ok {
return apiEndpoint, apiRequest
}
// If not, get the action from the registry.
endpointPath, ok := apiRequest.URLVars["endpointPath"]
if !ok {
return nil, apiRequest
}
endpointsLock.RLock()
defer endpointsLock.RUnlock()
apiEndpoint, ok = endpoints[endpointPath]
if ok {
// Cache for next operation.
apiRequest.HandlerCache = apiEndpoint
}
return apiEndpoint, apiRequest
}
// RegisterEndpoint registers a new endpoint. An error will be returned if it
// does not pass the sanity checks.
func RegisterEndpoint(e Endpoint) error {
if err := e.check(); err != nil {
return fmt.Errorf("%w: %s", ErrInvalidEndpoint, err)
}
endpointsLock.Lock()
defer endpointsLock.Unlock()
_, ok := endpoints[e.Path]
if ok {
return ErrAlreadyRegistered
}
endpoints[e.Path] = &e
return nil
}
func (e *Endpoint) check() error {
// Check path.
if e.Path == "" {
return errors.New("path is missing")
}
// Check permissions.
if e.Read < Require || e.Read > PermitSelf {
return errors.New("invalid read permission")
}
if e.Write < Require || e.Write > PermitSelf {
return errors.New("invalid write permission")
}
// Check functions.
var defaultMimeType string
fnCnt := 0
if e.ActionFn != nil {
fnCnt++
defaultMimeType = MimeTypeText
}
if e.DataFn != nil {
fnCnt++
defaultMimeType = MimeTypeText
}
if e.StructFn != nil {
fnCnt++
defaultMimeType = MimeTypeJSON
}
if e.RecordFn != nil {
fnCnt++
defaultMimeType = MimeTypeJSON
}
if e.HandlerFn != nil {
fnCnt++
defaultMimeType = MimeTypeText
}
if fnCnt != 1 {
return errors.New("only one function may be set")
}
// Set default mime type.
if e.MimeType == "" {
e.MimeType = defaultMimeType
}
return nil
}
type endpointHandler struct{}
var _ AuthenticatedHandler = &endpointHandler{} // Compile time interface check.
// ReadPermission returns the read permission for the handler.
func (eh *endpointHandler) ReadPermission(r *http.Request) Permission {
apiEndpoint, _ := getAPIContext(r)
if apiEndpoint != nil {
return apiEndpoint.Read
}
return NotFound
}
// WritePermission returns the write permission for the handler.
func (eh *endpointHandler) WritePermission(r *http.Request) Permission {
apiEndpoint, _ := getAPIContext(r)
if apiEndpoint != nil {
return apiEndpoint.Write
}
return NotFound
}
// ServeHTTP handles the http request.
func (eh *endpointHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
apiEndpoint, apiRequest := getAPIContext(r)
if apiEndpoint == nil || apiRequest == nil {
http.NotFound(w, r)
return
}
switch r.Method {
case http.MethodHead:
http.Error(w, "", http.StatusOK)
return
case http.MethodPost, http.MethodPut:
// Read body data.
inputData, ok := readBody(w, r)
if !ok {
return
}
apiRequest.InputData = inputData
case http.MethodGet:
// Nothing special to do here.
default:
http.Error(w, "Unsupported method for the actions API.", http.StatusMethodNotAllowed)
return
}
// Execute action function and get response data
var responseData []byte
var err error
switch {
case apiEndpoint.ActionFn != nil:
var msg string
msg, err = apiEndpoint.ActionFn(apiRequest)
if err == nil {
responseData = []byte(msg)
}
case apiEndpoint.DataFn != nil:
responseData, err = apiEndpoint.DataFn(apiRequest)
case apiEndpoint.StructFn != nil:
var v interface{}
v, err = apiEndpoint.StructFn(apiRequest)
if err == nil && v != nil {
responseData, err = json.Marshal(v)
}
case apiEndpoint.RecordFn != nil:
var rec record.Record
rec, err = apiEndpoint.RecordFn(apiRequest)
if err == nil && r != nil {
responseData, err = marshalRecord(rec, false)
}
case apiEndpoint.HandlerFn != nil:
apiEndpoint.HandlerFn(w, r)
return
default:
http.Error(w, "Internal server error: Missing handler.", http.StatusInternalServerError)
return
}
// Check for handler error.
if err != nil {
http.Error(w, "Internal server error: "+err.Error(), http.StatusInternalServerError)
return
}
// Write response.
w.Header().Set("Content-Type", apiEndpoint.MimeType+"; charset=utf-8")
w.Header().Set("Content-Length", strconv.Itoa(len(responseData)))
w.WriteHeader(http.StatusOK)
_, err = w.Write(responseData)
if err != nil {
log.Tracer(r.Context()).Warningf("api: failed to write response: %s", err)
}
}
func readBody(w http.ResponseWriter, r *http.Request) (inputData []byte, ok bool) {
// Check for too long content in order to prevent death.
if r.ContentLength > 20000000 { // 20MB
http.Error(w, "Too much input data.", http.StatusBadRequest)
return nil, false
}
// Read and close body.
inputData, err := ioutil.ReadAll(r.Body)
if err != nil {
http.Error(w, "Failed to read body: "+err.Error(), http.StatusInternalServerError)
return nil, false
}
r.Body.Close()
return inputData, true
}

65
api/endpoints_meta.go Normal file
View file

@ -0,0 +1,65 @@
package api
import (
"encoding/json"
"errors"
)
func registerMetaEndpoints() error {
if err := RegisterEndpoint(Endpoint{
Path: "endpoints",
Read: PermitAnyone,
MimeType: MimeTypeJSON,
DataFn: listEndpoints,
}); err != nil {
return err
}
if err := RegisterEndpoint(Endpoint{
Path: "permission",
Read: Require,
StructFn: permissions,
}); err != nil {
return err
}
if err := RegisterEndpoint(Endpoint{
Path: "ping",
Read: PermitAnyone,
ActionFn: ping,
}); err != nil {
return err
}
return nil
}
func listEndpoints(ar *Request) (data []byte, err error) {
endpointsLock.Lock()
defer endpointsLock.Unlock()
data, err = json.Marshal(endpoints)
return
}
func permissions(ar *Request) (i interface{}, err error) {
if ar.AuthToken == nil {
return nil, errors.New("authentication token missing")
}
return struct {
Read Permission
Write Permission
ReadPermName string
WritePermName string
}{
Read: ar.AuthToken.Read,
Write: ar.AuthToken.Write,
ReadPermName: ar.AuthToken.Read.String(),
WritePermName: ar.AuthToken.Write.String(),
}, nil
}
func ping(ar *Request) (msg string, err error) {
return "Pong.", nil
}

156
api/endpoints_test.go Normal file
View file

@ -0,0 +1,156 @@
package api
import (
"errors"
"sync"
"testing"
"github.com/safing/portbase/database/record"
"github.com/stretchr/testify/assert"
)
const (
successMsg = "endpoint api success"
failedMsg = "endpoint api failed"
)
type actionTestRecord struct {
record.Base
sync.Mutex
Msg string
}
func TestEndpoints(t *testing.T) {
testHandler := &mainHandler{
mux: mainMux,
}
// ActionFn
assert.NoError(t, RegisterEndpoint(Endpoint{
Path: "test/action",
Read: PermitAnyone,
ActionFn: func(_ *Request) (msg string, err error) {
return successMsg, nil
},
}))
assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/action", nil, successMsg)
assert.NoError(t, RegisterEndpoint(Endpoint{
Path: "test/action-err",
Read: PermitAnyone,
ActionFn: func(_ *Request) (msg string, err error) {
return "", errors.New(failedMsg)
},
}))
assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/action-err", nil, failedMsg)
// DataFn
assert.NoError(t, RegisterEndpoint(Endpoint{
Path: "test/data",
Read: PermitAnyone,
DataFn: func(_ *Request) (data []byte, err error) {
return []byte(successMsg), nil
},
}))
assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/data", nil, successMsg)
assert.NoError(t, RegisterEndpoint(Endpoint{
Path: "test/data-err",
Read: PermitAnyone,
DataFn: func(_ *Request) (data []byte, err error) {
return nil, errors.New(failedMsg)
},
}))
assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/data-err", nil, failedMsg)
// StructFn
assert.NoError(t, RegisterEndpoint(Endpoint{
Path: "test/struct",
Read: PermitAnyone,
StructFn: func(_ *Request) (i interface{}, err error) {
return &actionTestRecord{
Msg: successMsg,
}, nil
},
}))
assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/struct", nil, successMsg)
assert.NoError(t, RegisterEndpoint(Endpoint{
Path: "test/struct-err",
Read: PermitAnyone,
StructFn: func(_ *Request) (i interface{}, err error) {
return nil, errors.New(failedMsg)
},
}))
assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/struct-err", nil, failedMsg)
// RecordFn
assert.NoError(t, RegisterEndpoint(Endpoint{
Path: "test/record",
Read: PermitAnyone,
RecordFn: func(_ *Request) (r record.Record, err error) {
r = &actionTestRecord{
Msg: successMsg,
}
r.CreateMeta()
return r, nil
},
}))
assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/record", nil, successMsg)
assert.NoError(t, RegisterEndpoint(Endpoint{
Path: "test/record-err",
Read: PermitAnyone,
RecordFn: func(_ *Request) (r record.Record, err error) {
return nil, errors.New(failedMsg)
},
}))
assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/record-err", nil, failedMsg)
}
func TestActionRegistration(t *testing.T) {
assert.Error(t, RegisterEndpoint(Endpoint{}))
assert.Error(t, RegisterEndpoint(Endpoint{
Path: "test/err",
Read: NotFound,
}))
assert.Error(t, RegisterEndpoint(Endpoint{
Path: "test/err",
Read: PermitSelf + 1,
}))
assert.Error(t, RegisterEndpoint(Endpoint{
Path: "test/err",
Write: NotFound,
}))
assert.Error(t, RegisterEndpoint(Endpoint{
Path: "test/err",
Write: PermitSelf + 1,
}))
assert.Error(t, RegisterEndpoint(Endpoint{
Path: "test/err",
}))
assert.Error(t, RegisterEndpoint(Endpoint{
Path: "test/err",
ActionFn: func(_ *Request) (msg string, err error) {
return successMsg, nil
},
DataFn: func(_ *Request) (data []byte, err error) {
return []byte(successMsg), nil
},
}))
assert.NoError(t, RegisterEndpoint(Endpoint{
Path: "test/err",
ActionFn: func(_ *Request) (msg string, err error) {
return successMsg, nil
},
}))
}

View file

@ -26,7 +26,12 @@ func prep() error {
if getDefaultListenAddress() == "" {
return errors.New("no default listen address for api available")
}
return registerConfig()
if err := registerConfig(); err != nil {
return err
}
return registerMetaEndpoints()
}
func start() error {
@ -34,10 +39,8 @@ func start() error {
go Serve()
// start api auth token cleaner
authFnLock.Lock()
defer authFnLock.Unlock()
if authFn != nil {
module.NewTask("clean api auth tokens", cleanAuthTokens).Repeat(time.Minute)
if authFnSet.IsSet() {
module.NewTask("clean api auth tokens", cleanAuthTokens).Repeat(5 * time.Minute)
}
return nil

58
api/main_test.go Normal file
View file

@ -0,0 +1,58 @@
package api
import (
"fmt"
"io/ioutil"
"os"
"testing"
"github.com/safing/portbase/dataroot"
"github.com/safing/portbase/modules"
// API depends on the database for the database api.
_ "github.com/safing/portbase/database/dbmodule"
)
func init() {
defaultListenAddress = "127.0.0.1:8817"
}
func TestMain(m *testing.M) {
// enable module for testing
module.Enable()
// tmp dir for data root (db & config)
tmpDir, err := ioutil.TempDir("", "portbase-testing-")
if err != nil {
fmt.Fprintf(os.Stderr, "failed to create tmp dir: %s\n", err)
os.Exit(1)
}
// initialize data dir
err = dataroot.Initialize(tmpDir, 0755)
if err != nil {
fmt.Fprintf(os.Stderr, "failed to initialize data root: %s\n", err)
os.Exit(1)
}
// start modules
var exitCode int
err = modules.Start()
if err != nil {
// starting failed
fmt.Fprintf(os.Stderr, "failed to setup test: %s\n", err)
exitCode = 1
} else {
// run tests
exitCode = m.Run()
}
// shutdown
_ = modules.Shutdown()
if modules.GetExitStatusCode() != 0 {
exitCode = modules.GetExitStatusCode()
fmt.Fprintf(os.Stderr, "failed to cleanly shutdown test: %s\n", err)
}
// clean up and exit
os.RemoveAll(tmpDir)
os.Exit(exitCode)
}

50
api/request.go Normal file
View file

@ -0,0 +1,50 @@
package api
import (
"context"
"net/http"
"github.com/gorilla/mux"
)
// Request is a support struct to pool more request related information.
type Request struct {
// Request is the http request.
Request *http.Request
// InputData contains the request body for write operations.
InputData []byte
// Route of this request.
Route *mux.Route
// URLVars contains the URL variables extracted by the gorilla mux.
URLVars map[string]string
// AuthToken is the request-side authentication token assigned.
AuthToken *AuthToken
// HandlerCache can be used by handlers to cache data between handlers within a request.
HandlerCache interface{}
}
// Ctx is a shortcut to access the request context.
func (ar *Request) Ctx() context.Context {
return ar.Request.Context()
}
// apiRequestContextKey is a key used for the context key/value storage.
type apiRequestContextKey struct{}
var (
requestContextKey = apiRequestContextKey{}
)
// GetAPIRequest returns the API Request of the given http request.
func GetAPIRequest(r *http.Request) *Request {
ar, ok := r.Context().Value(requestContextKey).(*Request)
if ok {
return ar
}
return nil
}

View file

@ -56,7 +56,9 @@ func RegisterMiddleware(middleware Middleware) {
func Serve() {
// configure server
server.Addr = listenAddressConfig()
server.Handler = middlewareHandler
server.Handler = &mainHandler{
mux: mainMux,
}
// start serving
log.Infof("api: starting to listen on %s", server.Addr)
@ -76,7 +78,74 @@ func Serve() {
}
}
// GetMuxVars wraps github.com/gorilla/mux.Vars in order to mitigate context key issues in multi-repo projects.
func GetMuxVars(r *http.Request) map[string]string {
return mux.Vars(r)
type mainHandler struct {
mux *mux.Router
}
func (mh *mainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
_ = module.RunWorker("http request", func(_ context.Context) error {
return mh.handle(w, r)
})
}
func (mh *mainHandler) handle(w http.ResponseWriter, r *http.Request) error {
// Setup context trace logging.
ctx, tracer := log.AddTracer(r.Context())
lrw := NewLoggingResponseWriter(w, r)
// Add request context.
apiRequest := &Request{
Request: r,
}
ctx = context.WithValue(ctx, requestContextKey, apiRequest)
// Add context back to request.
r = r.WithContext(ctx)
tracer.Tracef("api request: %s ___ %s", r.RemoteAddr, r.RequestURI)
defer func() {
// Log request status.
if lrw.Status != 0 {
// If lrw.Status is 0, the request may have been hijacked.
tracer.Debugf("api request: %s %d %s", lrw.Request.RemoteAddr, lrw.Status, lrw.Request.RequestURI)
}
tracer.Submit()
}()
// Get handler for request.
// Gorilla does not support handling this on our own very well.
// See github.com/gorilla/mux.ServeHTTP for reference.
var match mux.RouteMatch
var handler http.Handler
if mh.mux.Match(r, &match) {
handler = match.Handler
apiRequest.Route = match.Route
apiRequest.URLVars = match.Vars
}
// Be sure that URLVars always is a map.
if apiRequest.URLVars == nil {
apiRequest.URLVars = make(map[string]string)
}
// Check authentication.
token := authenticateRequest(lrw, r, handler)
if token == nil {
// Authenticator already replied.
return nil
}
apiRequest.AuthToken = &AuthToken{
Read: token.Read,
Write: token.Write,
}
// Handle request.
switch {
case handler != nil:
handler.ServeHTTP(lrw, r)
case match.MatchErr == mux.ErrMethodMismatch:
http.Error(lrw, "Method not allowed.", http.StatusMethodNotAllowed)
default: // handler == nil or other error
http.Error(lrw, "Not found.", http.StatusNotFound)
}
return nil
}