mirror of
https://github.com/safing/portbase
synced 2025-09-01 01:59:48 +00:00
Add endpoint api and authentication layer
This commit is contained in:
parent
3f544cb07f
commit
9100dc999b
11 changed files with 1257 additions and 115 deletions
34
api/auth_wrapper.go
Normal file
34
api/auth_wrapper.go
Normal file
|
@ -0,0 +1,34 @@
|
|||
package api
|
||||
|
||||
import "net/http"
|
||||
|
||||
// WrapInAuthHandler wraps a simple http.HandlerFunc into a handler that
|
||||
// exposes the given API permissions.
|
||||
func WrapInAuthHandler(fn http.HandlerFunc, read, write Permission) http.Handler {
|
||||
return &wrappedAuthenticatedHandler{
|
||||
handleFunc: fn,
|
||||
read: read,
|
||||
write: write,
|
||||
}
|
||||
}
|
||||
|
||||
type wrappedAuthenticatedHandler struct {
|
||||
handleFunc http.HandlerFunc
|
||||
read Permission
|
||||
write Permission
|
||||
}
|
||||
|
||||
// ServeHTTP handles the http request.
|
||||
func (wah *wrappedAuthenticatedHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
wah.handleFunc(w, r)
|
||||
}
|
||||
|
||||
// ReadPermission returns the read permission for the handler.
|
||||
func (wah *wrappedAuthenticatedHandler) ReadPermission(r *http.Request) Permission {
|
||||
return wah.read
|
||||
}
|
||||
|
||||
// WritePermission returns the write permission for the handler.
|
||||
func (wah *wrappedAuthenticatedHandler) WritePermission(r *http.Request) Permission {
|
||||
return wah.write
|
||||
}
|
|
@ -8,138 +8,345 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portbase/modules"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/rng"
|
||||
)
|
||||
|
||||
var (
|
||||
validTokens = make(map[string]time.Time)
|
||||
validTokensLock sync.Mutex
|
||||
// Permission defines an API requests permission.
|
||||
type Permission int8
|
||||
|
||||
authFnLock sync.Mutex
|
||||
authFn Authenticator
|
||||
const (
|
||||
// NotFound declares that the operation does not exist.
|
||||
NotFound Permission = -2
|
||||
|
||||
// Require declares that the operation requires permission to be processed,
|
||||
// but anyone can execute the operation.
|
||||
Require Permission = -1
|
||||
|
||||
// NotSupported declares that the operation is not supported.
|
||||
NotSupported Permission = 0
|
||||
|
||||
// PermitAnyone declares that anyone can execute the operation without any
|
||||
// authentication.
|
||||
PermitAnyone Permission = 1
|
||||
|
||||
// PermitUser declares that the operation may be executed by authenticated
|
||||
// third party applications that are categorized as representing a simple
|
||||
// user and is limited in access.
|
||||
PermitUser Permission = 2
|
||||
|
||||
// PermitAdmin declares that the operation may be executed by authenticated
|
||||
// third party applications that are categorized as representing an
|
||||
// administrator and has broad in access.
|
||||
PermitAdmin Permission = 3
|
||||
|
||||
// PermitSelf declares that the operation may only be executed by the
|
||||
// software itself and its own (first party) components.
|
||||
PermitSelf Permission = 4
|
||||
)
|
||||
|
||||
// Authenticator is a function that can be set as the authenticator for the API endpoint. If none is set, all requests will be permitted.
|
||||
type Authenticator func(ctx context.Context, s *http.Server, r *http.Request) (*AuthToken, error)
|
||||
|
||||
// AuthToken represents either a set of required or granted permissions.
|
||||
// All attributes must be set when the struct is built and must not be changed
|
||||
// later. Functions may be called at any time.
|
||||
// The Write permission implicitly also includes reading.
|
||||
type AuthToken struct {
|
||||
Read Permission
|
||||
Write Permission
|
||||
|
||||
validUntil time.Time
|
||||
validLock sync.Mutex
|
||||
}
|
||||
|
||||
// Expired returns whether the token has expired.
|
||||
func (token *AuthToken) Expired() bool {
|
||||
token.validLock.Lock()
|
||||
defer token.validLock.Unlock()
|
||||
|
||||
return time.Now().After(token.validUntil)
|
||||
}
|
||||
|
||||
// Refresh refreshes the validity of the token with the given TTL.
|
||||
func (token *AuthToken) Refresh(ttl time.Duration) {
|
||||
token.validLock.Lock()
|
||||
defer token.validLock.Unlock()
|
||||
|
||||
token.validUntil = time.Now().Add(ttl)
|
||||
}
|
||||
|
||||
// AuthenticatedHandler defines the handler interface to specify custom
|
||||
// permission for an API handler.
|
||||
type AuthenticatedHandler interface {
|
||||
ReadPermission(*http.Request) Permission
|
||||
WritePermission(*http.Request) Permission
|
||||
}
|
||||
|
||||
const (
|
||||
cookieName = "Portmaster-API-Token"
|
||||
cookieTTL = 5 * time.Minute
|
||||
)
|
||||
|
||||
var (
|
||||
authFnSet = abool.New()
|
||||
authFn Authenticator
|
||||
|
||||
authTokens = make(map[string]*AuthToken)
|
||||
authTokensLock sync.Mutex
|
||||
|
||||
// ErrAPIAccessDeniedMessage should be returned by Authenticator functions in
|
||||
// order to signify a blocked request, including a error message for the user.
|
||||
ErrAPIAccessDeniedMessage = errors.New("")
|
||||
)
|
||||
|
||||
const (
|
||||
cookieName = "Portmaster-API-Token"
|
||||
|
||||
cookieTTL = 5 * time.Minute
|
||||
)
|
||||
|
||||
// Authenticator is a function that can be set as the authenticator for the API endpoint. If none is set, all requests will be permitted.
|
||||
type Authenticator func(ctx context.Context, s *http.Server, r *http.Request) (err error)
|
||||
|
||||
// SetAuthenticator sets an authenticator function for the API endpoint. If none is set, all requests will be permitted.
|
||||
func SetAuthenticator(fn Authenticator) error {
|
||||
if module.Online() {
|
||||
return ErrAuthenticationImmutable
|
||||
}
|
||||
|
||||
authFnLock.Lock()
|
||||
defer authFnLock.Unlock()
|
||||
|
||||
if authFn != nil {
|
||||
if authFnSet.IsSet() {
|
||||
return ErrAuthenticationAlreadySet
|
||||
}
|
||||
|
||||
authFn = fn
|
||||
authFnSet.Set()
|
||||
return nil
|
||||
}
|
||||
|
||||
func authMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
tracer := log.Tracer(r.Context())
|
||||
|
||||
// get authenticator
|
||||
authFnLock.Lock()
|
||||
authenticator := authFn
|
||||
authFnLock.Unlock()
|
||||
|
||||
// permit if no authenticator set
|
||||
if authenticator == nil {
|
||||
token := authenticateRequest(w, r, nil)
|
||||
if token != nil {
|
||||
if _, apiRequest := getAPIContext(r); apiRequest != nil {
|
||||
apiRequest.AuthToken = token
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// check existing auth cookie
|
||||
c, err := r.Cookie(cookieName)
|
||||
if err == nil {
|
||||
// get token
|
||||
validTokensLock.Lock()
|
||||
validUntil, valid := validTokens[c.Value]
|
||||
validTokensLock.Unlock()
|
||||
|
||||
// check if token is valid
|
||||
if valid && time.Now().Before(validUntil) {
|
||||
tracer.Tracef("api: auth token %s is valid, refreshing", c.Value)
|
||||
// refresh cookie
|
||||
validTokensLock.Lock()
|
||||
validTokens[c.Value] = time.Now().Add(cookieTTL)
|
||||
validTokensLock.Unlock()
|
||||
// continue
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
tracer.Tracef("api: provided auth token %s is invalid", c.Value)
|
||||
}
|
||||
|
||||
// get auth decision
|
||||
err = authenticator(r.Context(), server, r)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrAPIAccessDeniedMessage) {
|
||||
tracer.Warningf("api: denying api access to %s", r.RemoteAddr)
|
||||
http.Error(w, err.Error(), http.StatusForbidden)
|
||||
} else {
|
||||
tracer.Warningf("api: authenticator failed: %s", err)
|
||||
http.Error(w, "Internal server error during authentication.", http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// generate new token
|
||||
token, err := rng.Bytes(32) // 256 bit
|
||||
if err != nil {
|
||||
tracer.Warningf("api: failed to generate random token: %s", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
}
|
||||
tokenString := base64.RawURLEncoding.EncodeToString(token)
|
||||
// write new cookie
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: cookieName,
|
||||
Value: tokenString,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
MaxAge: int(cookieTTL.Seconds()),
|
||||
})
|
||||
// save cookie
|
||||
validTokensLock.Lock()
|
||||
validTokens[tokenString] = time.Now().Add(cookieTTL)
|
||||
validTokensLock.Unlock()
|
||||
|
||||
// serve
|
||||
tracer.Tracef("api: granted %s, assigned auth token %s", r.RemoteAddr, tokenString)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func cleanAuthTokens(_ context.Context, _ *modules.Task) error {
|
||||
validTokensLock.Lock()
|
||||
defer validTokensLock.Unlock()
|
||||
func authenticateRequest(w http.ResponseWriter, r *http.Request, targetHandler http.Handler) *AuthToken {
|
||||
tracer := log.Tracer(r.Context())
|
||||
|
||||
now := time.Now()
|
||||
for token, validUntil := range validTokens {
|
||||
if now.After(validUntil) {
|
||||
delete(validTokens, token)
|
||||
// Check if authenticator is set.
|
||||
if !authFnSet.IsSet() {
|
||||
// Return highest available permissions.
|
||||
return &AuthToken{
|
||||
Read: PermitSelf,
|
||||
Write: PermitSelf,
|
||||
}
|
||||
}
|
||||
|
||||
// Check if request is read only.
|
||||
readRequest := isReadMethod(r.Method)
|
||||
|
||||
// Get required permission for target handler.
|
||||
requiredPermission := PermitSelf
|
||||
if authdHandler, ok := targetHandler.(AuthenticatedHandler); ok {
|
||||
if readRequest {
|
||||
requiredPermission = authdHandler.ReadPermission(r)
|
||||
} else {
|
||||
requiredPermission = authdHandler.WritePermission(r)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if we need to do any authentication at all.
|
||||
switch requiredPermission {
|
||||
case NotFound:
|
||||
// Not found.
|
||||
tracer.Trace("api: authenticated handler reported: not found")
|
||||
http.Error(w, "Not found.", http.StatusNotFound)
|
||||
return nil
|
||||
case NotSupported:
|
||||
// A read or write permission can be marked as not supported.
|
||||
tracer.Trace("api: authenticated handler reported: not supported")
|
||||
http.Error(w, "Method not allowed.", http.StatusMethodNotAllowed)
|
||||
return nil
|
||||
case PermitAnyone:
|
||||
// Don't process permissions, as we don't need them.
|
||||
tracer.Tracef("api: granted %s access to public handler", r.RemoteAddr)
|
||||
return &AuthToken{
|
||||
Read: PermitAnyone,
|
||||
Write: PermitAnyone,
|
||||
}
|
||||
case Require:
|
||||
// Continue processing permissions, but treat as PermitAnyone.
|
||||
requiredPermission = PermitAnyone
|
||||
}
|
||||
|
||||
// Check for valid permission after handling the specials.
|
||||
if requiredPermission < PermitAnyone || requiredPermission > PermitSelf {
|
||||
tracer.Warningf(
|
||||
"api: handler returned invalid permission: %s (%d)",
|
||||
requiredPermission,
|
||||
requiredPermission,
|
||||
)
|
||||
http.Error(w, "Internal server error during authentication.", http.StatusInternalServerError)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check for an existing auth token.
|
||||
token := checkAuthToken(r)
|
||||
|
||||
// Get auth token from authenticator if none was in the request.
|
||||
if token == nil {
|
||||
var err error
|
||||
token, err = authFn(r.Context(), server, r)
|
||||
if err != nil {
|
||||
// Check for internal error.
|
||||
if !errors.Is(err, ErrAPIAccessDeniedMessage) {
|
||||
tracer.Warningf("api: authenticator failed: %s", err)
|
||||
http.Error(w, "Internal server error during authentication.", http.StatusInternalServerError)
|
||||
return nil
|
||||
}
|
||||
|
||||
// If authentication failed and we require authentication, return an
|
||||
// authentication error.
|
||||
if requiredPermission != PermitAnyone {
|
||||
// Return authentication error.
|
||||
tracer.Warningf("api: denying api access to %s", r.RemoteAddr)
|
||||
http.Error(w, err.Error(), http.StatusForbidden)
|
||||
return nil
|
||||
}
|
||||
|
||||
token = &AuthToken{
|
||||
Read: PermitAnyone,
|
||||
Write: PermitAnyone,
|
||||
}
|
||||
}
|
||||
|
||||
// Apply auth token to request.
|
||||
err = applyAuthToken(w, token)
|
||||
if err != nil {
|
||||
tracer.Warningf("api: failed to create auth token: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Get effective permission for request.
|
||||
var requestPermission Permission
|
||||
if readRequest {
|
||||
requestPermission = token.Read
|
||||
} else {
|
||||
requestPermission = token.Write
|
||||
}
|
||||
|
||||
// Check for valid request permission.
|
||||
if requestPermission < PermitAnyone || requestPermission > PermitSelf {
|
||||
tracer.Warningf(
|
||||
"api: authenticator returned invalid permission: %s (%d)",
|
||||
requestPermission,
|
||||
requestPermission,
|
||||
)
|
||||
http.Error(w, "Internal server error during authentication.", http.StatusInternalServerError)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check permission.
|
||||
if requestPermission < requiredPermission {
|
||||
http.Error(w, "Insufficient permissions.", http.StatusForbidden)
|
||||
return nil
|
||||
}
|
||||
|
||||
tracer.Tracef("api: granted %s access to authenticated handler", r.RemoteAddr)
|
||||
return token
|
||||
}
|
||||
|
||||
func checkAuthToken(r *http.Request) *AuthToken {
|
||||
// Get auth token from request.
|
||||
c, err := r.Cookie(cookieName)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if auth token is registered.
|
||||
authTokensLock.Lock()
|
||||
token, ok := authTokens[c.Value]
|
||||
authTokensLock.Unlock()
|
||||
if !ok {
|
||||
log.Tracer(r.Context()).Tracef("api: provided auth token %s is unknown", c.Value)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if token is still valid.
|
||||
if token.Expired() {
|
||||
log.Tracer(r.Context()).Tracef("api: provided auth token %s has expired", c.Value)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Refresh token and return.
|
||||
token.Refresh(cookieTTL)
|
||||
log.Tracer(r.Context()).Tracef("api: auth token %s is valid, refreshing", c.Value)
|
||||
return token
|
||||
}
|
||||
|
||||
func applyAuthToken(w http.ResponseWriter, token *AuthToken) error {
|
||||
// Generate new token secret.
|
||||
secret, err := rng.Bytes(32) // 256 bit
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
secretHex := base64.RawURLEncoding.EncodeToString(secret)
|
||||
|
||||
// Set token cookie in response.
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: cookieName,
|
||||
Value: secretHex,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
})
|
||||
|
||||
// Set token TTL.
|
||||
token.Refresh(cookieTTL)
|
||||
|
||||
// Save token.
|
||||
authTokensLock.Lock()
|
||||
defer authTokensLock.Unlock()
|
||||
authTokens[secretHex] = token
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func cleanAuthTokens(_ context.Context, _ *modules.Task) error {
|
||||
authTokensLock.Lock()
|
||||
defer authTokensLock.Unlock()
|
||||
|
||||
for secret, token := range authTokens {
|
||||
if token.Expired() {
|
||||
delete(authTokens, secret)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func isReadMethod(method string) bool {
|
||||
return method == http.MethodGet || method == http.MethodHead
|
||||
}
|
||||
|
||||
func (p Permission) String() string {
|
||||
switch p {
|
||||
case NotSupported:
|
||||
return "NotSupported"
|
||||
case Require:
|
||||
return "Require"
|
||||
case PermitAnyone:
|
||||
return "PermitAnyone"
|
||||
case PermitUser:
|
||||
return "PermitUser"
|
||||
case PermitAdmin:
|
||||
return "PermitAdmin"
|
||||
case PermitSelf:
|
||||
return "PermitSelf"
|
||||
case NotFound:
|
||||
return "NotFound"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
|
196
api/authentication_test.go
Normal file
196
api/authentication_test.go
Normal file
|
@ -0,0 +1,196 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var (
|
||||
testToken = new(AuthToken)
|
||||
)
|
||||
|
||||
func testAuthenticator(ctx context.Context, s *http.Server, r *http.Request) (*AuthToken, error) {
|
||||
switch {
|
||||
case testToken.Read == -127 || testToken.Write == -127:
|
||||
return nil, errors.New("test error")
|
||||
case testToken.Read == -128 || testToken.Write == -128:
|
||||
return nil, fmt.Errorf("%wdenied", ErrAPIAccessDeniedMessage)
|
||||
default:
|
||||
return testToken, nil
|
||||
}
|
||||
}
|
||||
|
||||
type testAuthHandler struct {
|
||||
Read Permission
|
||||
Write Permission
|
||||
}
|
||||
|
||||
func (ah *testAuthHandler) ReadPermission(r *http.Request) Permission {
|
||||
return ah.Read
|
||||
}
|
||||
|
||||
func (ah *testAuthHandler) WritePermission(r *http.Request) Permission {
|
||||
return ah.Write
|
||||
}
|
||||
|
||||
func (ah *testAuthHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Check if request is as expected.
|
||||
ar := GetAPIRequest(r)
|
||||
switch {
|
||||
case ar == nil:
|
||||
http.Error(w, "ar == nil", http.StatusInternalServerError)
|
||||
case ar.AuthToken == nil:
|
||||
http.Error(w, "ar.AuthToken == nil", http.StatusInternalServerError)
|
||||
default:
|
||||
http.Error(w, "auth success", http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func makeAuthTestPath(reading bool, p Permission) string {
|
||||
if reading {
|
||||
return fmt.Sprintf("/test/auth/read/%s", p)
|
||||
}
|
||||
return fmt.Sprintf("/test/auth/write/%s", p)
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Set test authenticator.
|
||||
err := SetAuthenticator(testAuthenticator)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPermissions(t *testing.T) { //nolint:gocognit
|
||||
testHandler := &mainHandler{
|
||||
mux: mainMux,
|
||||
}
|
||||
|
||||
// Define permissions that need testing.
|
||||
permissionsToTest := []Permission{
|
||||
NotSupported,
|
||||
PermitAnyone,
|
||||
PermitUser,
|
||||
PermitAdmin,
|
||||
PermitSelf,
|
||||
Require,
|
||||
NotFound,
|
||||
100, // Test a too high value.
|
||||
-100, // Test a too low value.
|
||||
-127, // Simulate authenticator failure.
|
||||
-128, // Simulate authentication denied message.
|
||||
}
|
||||
|
||||
// Register test handlers.
|
||||
for _, p := range permissionsToTest {
|
||||
RegisterHandler(makeAuthTestPath(true, p), &testAuthHandler{Read: p})
|
||||
RegisterHandler(makeAuthTestPath(false, p), &testAuthHandler{Write: p})
|
||||
}
|
||||
|
||||
// Test all the combinations.
|
||||
for _, requestPerm := range permissionsToTest {
|
||||
for _, handlerPerm := range permissionsToTest {
|
||||
for _, method := range []string{
|
||||
http.MethodGet,
|
||||
http.MethodHead,
|
||||
http.MethodPost,
|
||||
http.MethodPut,
|
||||
} {
|
||||
|
||||
// Set request permission for test requests.
|
||||
reading := isReadMethod(method)
|
||||
if reading {
|
||||
testToken.Read = requestPerm
|
||||
testToken.Write = NotSupported
|
||||
} else {
|
||||
testToken.Read = NotSupported
|
||||
testToken.Write = requestPerm
|
||||
}
|
||||
|
||||
// Evaluate expected result.
|
||||
var expectSuccess bool
|
||||
switch {
|
||||
case handlerPerm == PermitAnyone:
|
||||
// This is fast-tracked. There are not additional checks.
|
||||
expectSuccess = true
|
||||
case handlerPerm == Require:
|
||||
// This is turned into PermitAnyone in the authenticator.
|
||||
// But authentication is still processed and the result still gets
|
||||
// sanity checked!
|
||||
if requestPerm >= PermitAnyone &&
|
||||
requestPerm <= PermitSelf {
|
||||
expectSuccess = true
|
||||
}
|
||||
// Another special case is when the handler requires permission to be
|
||||
// processed but the authenticator fails to authenticate the request.
|
||||
// In this case, a fallback token with PermitAnyone is used.
|
||||
if requestPerm == -128 {
|
||||
// -128 is used to simulate a permission denied message.
|
||||
expectSuccess = true
|
||||
}
|
||||
case handlerPerm <= NotSupported:
|
||||
// Invalid handler permission.
|
||||
case handlerPerm > PermitSelf:
|
||||
// Invalid handler permission.
|
||||
case requestPerm <= NotSupported:
|
||||
// Invalid request permission.
|
||||
case requestPerm > PermitSelf:
|
||||
// Invalid request permission.
|
||||
case requestPerm < handlerPerm:
|
||||
// Valid, but insufficient request permission.
|
||||
default:
|
||||
expectSuccess = true
|
||||
}
|
||||
|
||||
if expectSuccess {
|
||||
|
||||
// Test for success.
|
||||
if !assert.HTTPBodyContains(
|
||||
t,
|
||||
testHandler.ServeHTTP,
|
||||
method,
|
||||
makeAuthTestPath(reading, handlerPerm),
|
||||
nil,
|
||||
"auth success",
|
||||
) {
|
||||
t.Errorf(
|
||||
"%s with %s (%d) to handler %s (%d)",
|
||||
method,
|
||||
requestPerm, requestPerm,
|
||||
handlerPerm, handlerPerm,
|
||||
)
|
||||
}
|
||||
|
||||
} else {
|
||||
|
||||
// Test for error.
|
||||
if !assert.HTTPError(t,
|
||||
testHandler.ServeHTTP,
|
||||
method,
|
||||
makeAuthTestPath(reading, handlerPerm),
|
||||
nil,
|
||||
) {
|
||||
t.Errorf(
|
||||
"%s with %s (%d) to handler %s (%d)",
|
||||
method,
|
||||
requestPerm, requestPerm,
|
||||
handlerPerm, handlerPerm,
|
||||
)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPermissionDefinitions(t *testing.T) {
|
||||
if NotSupported != 0 {
|
||||
t.Fatalf("NotSupported must be zero, was %v", NotSupported)
|
||||
}
|
||||
}
|
|
@ -278,7 +278,7 @@ func (api *DatabaseAPI) handleGet(opID []byte, key string) {
|
|||
|
||||
r, err := api.db.Get(key)
|
||||
if err == nil {
|
||||
data, err = marshalRecord(r)
|
||||
data, err = marshalRecord(r, true)
|
||||
}
|
||||
if err != nil {
|
||||
api.send(opID, dbMsgTypeError, err.Error(), nil)
|
||||
|
@ -336,7 +336,7 @@ func (api *DatabaseAPI) processQuery(opID []byte, q *query.Query) (ok bool) {
|
|||
// process query feed
|
||||
if r != nil {
|
||||
// process record
|
||||
data, err := marshalRecord(r)
|
||||
data, err := marshalRecord(r, true)
|
||||
if err != nil {
|
||||
api.send(opID, dbMsgTypeWarning, err.Error(), nil)
|
||||
}
|
||||
|
@ -412,7 +412,7 @@ func (api *DatabaseAPI) processSub(opID []byte, sub *database.Subscription) {
|
|||
// process sub feed
|
||||
if r != nil {
|
||||
// process record
|
||||
data, err := marshalRecord(r)
|
||||
data, err := marshalRecord(r, true)
|
||||
if err != nil {
|
||||
api.send(opID, dbMsgTypeWarning, err.Error(), nil)
|
||||
continue
|
||||
|
@ -625,7 +625,7 @@ func (api *DatabaseAPI) shutdown() {
|
|||
|
||||
// marsharlRecords locks and marshals the given record, additionally adding
|
||||
// metadata and returning it as json.
|
||||
func marshalRecord(r record.Record) ([]byte, error) {
|
||||
func marshalRecord(r record.Record, withDSDIdentifier bool) ([]byte, error) {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
|
@ -651,10 +651,12 @@ func marshalRecord(r record.Record) ([]byte, error) {
|
|||
}
|
||||
|
||||
// Add JSON identifier again.
|
||||
formatID := varint.Pack8(record.JSON)
|
||||
finalData := make([]byte, 0, len(formatID)+len(jsonData))
|
||||
finalData = append(finalData, formatID...)
|
||||
finalData = append(finalData, jsonData...)
|
||||
|
||||
return finalData, nil
|
||||
if withDSDIdentifier {
|
||||
formatID := varint.Pack8(record.JSON)
|
||||
finalData := make([]byte, 0, len(formatID)+len(jsonData))
|
||||
finalData = append(finalData, formatID...)
|
||||
finalData = append(finalData, jsonData...)
|
||||
return finalData, nil
|
||||
}
|
||||
return jsonData, nil
|
||||
}
|
||||
|
|
302
api/endpoints.go
Normal file
302
api/endpoints.go
Normal file
|
@ -0,0 +1,302 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/safing/portbase/database/record"
|
||||
"github.com/safing/portbase/log"
|
||||
)
|
||||
|
||||
// Endpoint describes an API Endpoint.
|
||||
// Path and at least one permission are required.
|
||||
// As is exactly one function.
|
||||
type Endpoint struct {
|
||||
Path string
|
||||
MimeType string
|
||||
Read Permission
|
||||
Write Permission
|
||||
|
||||
// TODO: We _could_ expose more metadata to be able to build lists of actions
|
||||
// automatically.
|
||||
// Name string
|
||||
// Description string
|
||||
// Order int
|
||||
// ExpertiseLevel config.ExpertiseLevel
|
||||
|
||||
// ActionFn is for simple actions with a return message for the user.
|
||||
ActionFn ActionFn `json:"-"`
|
||||
|
||||
// DataFn is for returning raw data that the caller for further processing.
|
||||
DataFn DataFn `json:"-"`
|
||||
|
||||
// StructFn is for returning any kind of struct.
|
||||
StructFn StructFn `json:"-"`
|
||||
|
||||
// RecordFn is for returning a database record. It will be properly locked
|
||||
// and marshalled including metadata.
|
||||
RecordFn RecordFn `json:"-"`
|
||||
|
||||
// HandlerFn is the raw http handler.
|
||||
HandlerFn http.HandlerFunc `json:"-"`
|
||||
}
|
||||
|
||||
type (
|
||||
// ActionFn is for simple actions with a return message for the user.
|
||||
ActionFn func(ar *Request) (msg string, err error)
|
||||
|
||||
// DataFn is for returning raw data that the caller for further processing.
|
||||
DataFn func(ar *Request) (data []byte, err error)
|
||||
|
||||
// StructFn is for returning any kind of struct.
|
||||
StructFn func(ar *Request) (i interface{}, err error)
|
||||
|
||||
// RecordFn is for returning a database record. It will be properly locked
|
||||
// and marshalled including metadata.
|
||||
RecordFn func(ar *Request) (r record.Record, err error)
|
||||
)
|
||||
|
||||
// MIME Types
|
||||
const (
|
||||
MimeTypeJSON string = "application/json"
|
||||
MimeTypeText string = "text/plain"
|
||||
|
||||
apiV1Path = "/api/v1/"
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterHandler(apiV1Path+"{endpointPath:.+}", &endpointHandler{})
|
||||
}
|
||||
|
||||
var (
|
||||
endpoints = make(map[string]*Endpoint)
|
||||
endpointsLock sync.RWMutex
|
||||
|
||||
// ErrInvalidEndpoint is returned when an invalid endpoint is registered.
|
||||
ErrInvalidEndpoint = errors.New("endpoint is invalid")
|
||||
|
||||
// ErrAlreadyRegistered is returned when there already is an endpoint with
|
||||
// the same path registered.
|
||||
ErrAlreadyRegistered = errors.New("an endpoint for this path is already registered")
|
||||
)
|
||||
|
||||
func getAPIContext(r *http.Request) (apiEndpoint *Endpoint, apiRequest *Request) {
|
||||
// Get request context and check if we already have an action cached.
|
||||
apiRequest = GetAPIRequest(r)
|
||||
if apiRequest == nil {
|
||||
return nil, nil
|
||||
}
|
||||
var ok bool
|
||||
apiEndpoint, ok = apiRequest.HandlerCache.(*Endpoint)
|
||||
if ok {
|
||||
return apiEndpoint, apiRequest
|
||||
}
|
||||
|
||||
// If not, get the action from the registry.
|
||||
endpointPath, ok := apiRequest.URLVars["endpointPath"]
|
||||
if !ok {
|
||||
return nil, apiRequest
|
||||
}
|
||||
|
||||
endpointsLock.RLock()
|
||||
defer endpointsLock.RUnlock()
|
||||
|
||||
apiEndpoint, ok = endpoints[endpointPath]
|
||||
if ok {
|
||||
// Cache for next operation.
|
||||
apiRequest.HandlerCache = apiEndpoint
|
||||
}
|
||||
return apiEndpoint, apiRequest
|
||||
}
|
||||
|
||||
// RegisterEndpoint registers a new endpoint. An error will be returned if it
|
||||
// does not pass the sanity checks.
|
||||
func RegisterEndpoint(e Endpoint) error {
|
||||
if err := e.check(); err != nil {
|
||||
return fmt.Errorf("%w: %s", ErrInvalidEndpoint, err)
|
||||
}
|
||||
|
||||
endpointsLock.Lock()
|
||||
defer endpointsLock.Unlock()
|
||||
|
||||
_, ok := endpoints[e.Path]
|
||||
if ok {
|
||||
return ErrAlreadyRegistered
|
||||
}
|
||||
|
||||
endpoints[e.Path] = &e
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Endpoint) check() error {
|
||||
// Check path.
|
||||
if e.Path == "" {
|
||||
return errors.New("path is missing")
|
||||
}
|
||||
|
||||
// Check permissions.
|
||||
if e.Read < Require || e.Read > PermitSelf {
|
||||
return errors.New("invalid read permission")
|
||||
}
|
||||
if e.Write < Require || e.Write > PermitSelf {
|
||||
return errors.New("invalid write permission")
|
||||
}
|
||||
|
||||
// Check functions.
|
||||
var defaultMimeType string
|
||||
fnCnt := 0
|
||||
if e.ActionFn != nil {
|
||||
fnCnt++
|
||||
defaultMimeType = MimeTypeText
|
||||
}
|
||||
if e.DataFn != nil {
|
||||
fnCnt++
|
||||
defaultMimeType = MimeTypeText
|
||||
}
|
||||
if e.StructFn != nil {
|
||||
fnCnt++
|
||||
defaultMimeType = MimeTypeJSON
|
||||
}
|
||||
if e.RecordFn != nil {
|
||||
fnCnt++
|
||||
defaultMimeType = MimeTypeJSON
|
||||
}
|
||||
if e.HandlerFn != nil {
|
||||
fnCnt++
|
||||
defaultMimeType = MimeTypeText
|
||||
}
|
||||
if fnCnt != 1 {
|
||||
return errors.New("only one function may be set")
|
||||
}
|
||||
|
||||
// Set default mime type.
|
||||
if e.MimeType == "" {
|
||||
e.MimeType = defaultMimeType
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type endpointHandler struct{}
|
||||
|
||||
var _ AuthenticatedHandler = &endpointHandler{} // Compile time interface check.
|
||||
|
||||
// ReadPermission returns the read permission for the handler.
|
||||
func (eh *endpointHandler) ReadPermission(r *http.Request) Permission {
|
||||
apiEndpoint, _ := getAPIContext(r)
|
||||
if apiEndpoint != nil {
|
||||
return apiEndpoint.Read
|
||||
}
|
||||
return NotFound
|
||||
}
|
||||
|
||||
// WritePermission returns the write permission for the handler.
|
||||
func (eh *endpointHandler) WritePermission(r *http.Request) Permission {
|
||||
apiEndpoint, _ := getAPIContext(r)
|
||||
if apiEndpoint != nil {
|
||||
return apiEndpoint.Write
|
||||
}
|
||||
return NotFound
|
||||
}
|
||||
|
||||
// ServeHTTP handles the http request.
|
||||
func (eh *endpointHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
apiEndpoint, apiRequest := getAPIContext(r)
|
||||
if apiEndpoint == nil || apiRequest == nil {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodHead:
|
||||
http.Error(w, "", http.StatusOK)
|
||||
return
|
||||
case http.MethodPost, http.MethodPut:
|
||||
// Read body data.
|
||||
inputData, ok := readBody(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
apiRequest.InputData = inputData
|
||||
case http.MethodGet:
|
||||
// Nothing special to do here.
|
||||
default:
|
||||
http.Error(w, "Unsupported method for the actions API.", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// Execute action function and get response data
|
||||
var responseData []byte
|
||||
var err error
|
||||
|
||||
switch {
|
||||
case apiEndpoint.ActionFn != nil:
|
||||
var msg string
|
||||
msg, err = apiEndpoint.ActionFn(apiRequest)
|
||||
if err == nil {
|
||||
responseData = []byte(msg)
|
||||
}
|
||||
|
||||
case apiEndpoint.DataFn != nil:
|
||||
responseData, err = apiEndpoint.DataFn(apiRequest)
|
||||
|
||||
case apiEndpoint.StructFn != nil:
|
||||
var v interface{}
|
||||
v, err = apiEndpoint.StructFn(apiRequest)
|
||||
if err == nil && v != nil {
|
||||
responseData, err = json.Marshal(v)
|
||||
}
|
||||
|
||||
case apiEndpoint.RecordFn != nil:
|
||||
var rec record.Record
|
||||
rec, err = apiEndpoint.RecordFn(apiRequest)
|
||||
if err == nil && r != nil {
|
||||
responseData, err = marshalRecord(rec, false)
|
||||
}
|
||||
|
||||
case apiEndpoint.HandlerFn != nil:
|
||||
apiEndpoint.HandlerFn(w, r)
|
||||
return
|
||||
|
||||
default:
|
||||
http.Error(w, "Internal server error: Missing handler.", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Check for handler error.
|
||||
if err != nil {
|
||||
http.Error(w, "Internal server error: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Write response.
|
||||
w.Header().Set("Content-Type", apiEndpoint.MimeType+"; charset=utf-8")
|
||||
w.Header().Set("Content-Length", strconv.Itoa(len(responseData)))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, err = w.Write(responseData)
|
||||
if err != nil {
|
||||
log.Tracer(r.Context()).Warningf("api: failed to write response: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func readBody(w http.ResponseWriter, r *http.Request) (inputData []byte, ok bool) {
|
||||
// Check for too long content in order to prevent death.
|
||||
if r.ContentLength > 20000000 { // 20MB
|
||||
http.Error(w, "Too much input data.", http.StatusBadRequest)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Read and close body.
|
||||
inputData, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to read body: "+err.Error(), http.StatusInternalServerError)
|
||||
return nil, false
|
||||
}
|
||||
r.Body.Close()
|
||||
return inputData, true
|
||||
}
|
65
api/endpoints_meta.go
Normal file
65
api/endpoints_meta.go
Normal file
|
@ -0,0 +1,65 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
)
|
||||
|
||||
func registerMetaEndpoints() error {
|
||||
if err := RegisterEndpoint(Endpoint{
|
||||
Path: "endpoints",
|
||||
Read: PermitAnyone,
|
||||
MimeType: MimeTypeJSON,
|
||||
DataFn: listEndpoints,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := RegisterEndpoint(Endpoint{
|
||||
Path: "permission",
|
||||
Read: Require,
|
||||
StructFn: permissions,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := RegisterEndpoint(Endpoint{
|
||||
Path: "ping",
|
||||
Read: PermitAnyone,
|
||||
ActionFn: ping,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func listEndpoints(ar *Request) (data []byte, err error) {
|
||||
endpointsLock.Lock()
|
||||
defer endpointsLock.Unlock()
|
||||
|
||||
data, err = json.Marshal(endpoints)
|
||||
return
|
||||
}
|
||||
|
||||
func permissions(ar *Request) (i interface{}, err error) {
|
||||
if ar.AuthToken == nil {
|
||||
return nil, errors.New("authentication token missing")
|
||||
}
|
||||
|
||||
return struct {
|
||||
Read Permission
|
||||
Write Permission
|
||||
ReadPermName string
|
||||
WritePermName string
|
||||
}{
|
||||
Read: ar.AuthToken.Read,
|
||||
Write: ar.AuthToken.Write,
|
||||
ReadPermName: ar.AuthToken.Read.String(),
|
||||
WritePermName: ar.AuthToken.Write.String(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ping(ar *Request) (msg string, err error) {
|
||||
return "Pong.", nil
|
||||
}
|
156
api/endpoints_test.go
Normal file
156
api/endpoints_test.go
Normal file
|
@ -0,0 +1,156 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/safing/portbase/database/record"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const (
|
||||
successMsg = "endpoint api success"
|
||||
failedMsg = "endpoint api failed"
|
||||
)
|
||||
|
||||
type actionTestRecord struct {
|
||||
record.Base
|
||||
sync.Mutex
|
||||
Msg string
|
||||
}
|
||||
|
||||
func TestEndpoints(t *testing.T) {
|
||||
testHandler := &mainHandler{
|
||||
mux: mainMux,
|
||||
}
|
||||
|
||||
// ActionFn
|
||||
|
||||
assert.NoError(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/action",
|
||||
Read: PermitAnyone,
|
||||
ActionFn: func(_ *Request) (msg string, err error) {
|
||||
return successMsg, nil
|
||||
},
|
||||
}))
|
||||
assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/action", nil, successMsg)
|
||||
|
||||
assert.NoError(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/action-err",
|
||||
Read: PermitAnyone,
|
||||
ActionFn: func(_ *Request) (msg string, err error) {
|
||||
return "", errors.New(failedMsg)
|
||||
},
|
||||
}))
|
||||
assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/action-err", nil, failedMsg)
|
||||
|
||||
// DataFn
|
||||
|
||||
assert.NoError(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/data",
|
||||
Read: PermitAnyone,
|
||||
DataFn: func(_ *Request) (data []byte, err error) {
|
||||
return []byte(successMsg), nil
|
||||
},
|
||||
}))
|
||||
assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/data", nil, successMsg)
|
||||
|
||||
assert.NoError(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/data-err",
|
||||
Read: PermitAnyone,
|
||||
DataFn: func(_ *Request) (data []byte, err error) {
|
||||
return nil, errors.New(failedMsg)
|
||||
},
|
||||
}))
|
||||
assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/data-err", nil, failedMsg)
|
||||
|
||||
// StructFn
|
||||
|
||||
assert.NoError(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/struct",
|
||||
Read: PermitAnyone,
|
||||
StructFn: func(_ *Request) (i interface{}, err error) {
|
||||
return &actionTestRecord{
|
||||
Msg: successMsg,
|
||||
}, nil
|
||||
},
|
||||
}))
|
||||
assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/struct", nil, successMsg)
|
||||
|
||||
assert.NoError(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/struct-err",
|
||||
Read: PermitAnyone,
|
||||
StructFn: func(_ *Request) (i interface{}, err error) {
|
||||
return nil, errors.New(failedMsg)
|
||||
},
|
||||
}))
|
||||
assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/struct-err", nil, failedMsg)
|
||||
|
||||
// RecordFn
|
||||
|
||||
assert.NoError(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/record",
|
||||
Read: PermitAnyone,
|
||||
RecordFn: func(_ *Request) (r record.Record, err error) {
|
||||
r = &actionTestRecord{
|
||||
Msg: successMsg,
|
||||
}
|
||||
r.CreateMeta()
|
||||
return r, nil
|
||||
},
|
||||
}))
|
||||
assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/record", nil, successMsg)
|
||||
|
||||
assert.NoError(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/record-err",
|
||||
Read: PermitAnyone,
|
||||
RecordFn: func(_ *Request) (r record.Record, err error) {
|
||||
return nil, errors.New(failedMsg)
|
||||
},
|
||||
}))
|
||||
assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/record-err", nil, failedMsg)
|
||||
}
|
||||
|
||||
func TestActionRegistration(t *testing.T) {
|
||||
assert.Error(t, RegisterEndpoint(Endpoint{}))
|
||||
|
||||
assert.Error(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/err",
|
||||
Read: NotFound,
|
||||
}))
|
||||
assert.Error(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/err",
|
||||
Read: PermitSelf + 1,
|
||||
}))
|
||||
|
||||
assert.Error(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/err",
|
||||
Write: NotFound,
|
||||
}))
|
||||
assert.Error(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/err",
|
||||
Write: PermitSelf + 1,
|
||||
}))
|
||||
|
||||
assert.Error(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/err",
|
||||
}))
|
||||
|
||||
assert.Error(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/err",
|
||||
ActionFn: func(_ *Request) (msg string, err error) {
|
||||
return successMsg, nil
|
||||
},
|
||||
DataFn: func(_ *Request) (data []byte, err error) {
|
||||
return []byte(successMsg), nil
|
||||
},
|
||||
}))
|
||||
|
||||
assert.NoError(t, RegisterEndpoint(Endpoint{
|
||||
Path: "test/err",
|
||||
ActionFn: func(_ *Request) (msg string, err error) {
|
||||
return successMsg, nil
|
||||
},
|
||||
}))
|
||||
}
|
13
api/main.go
13
api/main.go
|
@ -26,7 +26,12 @@ func prep() error {
|
|||
if getDefaultListenAddress() == "" {
|
||||
return errors.New("no default listen address for api available")
|
||||
}
|
||||
return registerConfig()
|
||||
|
||||
if err := registerConfig(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return registerMetaEndpoints()
|
||||
}
|
||||
|
||||
func start() error {
|
||||
|
@ -34,10 +39,8 @@ func start() error {
|
|||
go Serve()
|
||||
|
||||
// start api auth token cleaner
|
||||
authFnLock.Lock()
|
||||
defer authFnLock.Unlock()
|
||||
if authFn != nil {
|
||||
module.NewTask("clean api auth tokens", cleanAuthTokens).Repeat(time.Minute)
|
||||
if authFnSet.IsSet() {
|
||||
module.NewTask("clean api auth tokens", cleanAuthTokens).Repeat(5 * time.Minute)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
58
api/main_test.go
Normal file
58
api/main_test.go
Normal file
|
@ -0,0 +1,58 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/safing/portbase/dataroot"
|
||||
"github.com/safing/portbase/modules"
|
||||
|
||||
// API depends on the database for the database api.
|
||||
_ "github.com/safing/portbase/database/dbmodule"
|
||||
)
|
||||
|
||||
func init() {
|
||||
defaultListenAddress = "127.0.0.1:8817"
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// enable module for testing
|
||||
module.Enable()
|
||||
|
||||
// tmp dir for data root (db & config)
|
||||
tmpDir, err := ioutil.TempDir("", "portbase-testing-")
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "failed to create tmp dir: %s\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
// initialize data dir
|
||||
err = dataroot.Initialize(tmpDir, 0755)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "failed to initialize data root: %s\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// start modules
|
||||
var exitCode int
|
||||
err = modules.Start()
|
||||
if err != nil {
|
||||
// starting failed
|
||||
fmt.Fprintf(os.Stderr, "failed to setup test: %s\n", err)
|
||||
exitCode = 1
|
||||
} else {
|
||||
// run tests
|
||||
exitCode = m.Run()
|
||||
}
|
||||
|
||||
// shutdown
|
||||
_ = modules.Shutdown()
|
||||
if modules.GetExitStatusCode() != 0 {
|
||||
exitCode = modules.GetExitStatusCode()
|
||||
fmt.Fprintf(os.Stderr, "failed to cleanly shutdown test: %s\n", err)
|
||||
}
|
||||
// clean up and exit
|
||||
os.RemoveAll(tmpDir)
|
||||
os.Exit(exitCode)
|
||||
}
|
50
api/request.go
Normal file
50
api/request.go
Normal file
|
@ -0,0 +1,50 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
// Request is a support struct to pool more request related information.
|
||||
type Request struct {
|
||||
// Request is the http request.
|
||||
Request *http.Request
|
||||
|
||||
// InputData contains the request body for write operations.
|
||||
InputData []byte
|
||||
|
||||
// Route of this request.
|
||||
Route *mux.Route
|
||||
|
||||
// URLVars contains the URL variables extracted by the gorilla mux.
|
||||
URLVars map[string]string
|
||||
|
||||
// AuthToken is the request-side authentication token assigned.
|
||||
AuthToken *AuthToken
|
||||
|
||||
// HandlerCache can be used by handlers to cache data between handlers within a request.
|
||||
HandlerCache interface{}
|
||||
}
|
||||
|
||||
// Ctx is a shortcut to access the request context.
|
||||
func (ar *Request) Ctx() context.Context {
|
||||
return ar.Request.Context()
|
||||
}
|
||||
|
||||
// apiRequestContextKey is a key used for the context key/value storage.
|
||||
type apiRequestContextKey struct{}
|
||||
|
||||
var (
|
||||
requestContextKey = apiRequestContextKey{}
|
||||
)
|
||||
|
||||
// GetAPIRequest returns the API Request of the given http request.
|
||||
func GetAPIRequest(r *http.Request) *Request {
|
||||
ar, ok := r.Context().Value(requestContextKey).(*Request)
|
||||
if ok {
|
||||
return ar
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -56,7 +56,9 @@ func RegisterMiddleware(middleware Middleware) {
|
|||
func Serve() {
|
||||
// configure server
|
||||
server.Addr = listenAddressConfig()
|
||||
server.Handler = middlewareHandler
|
||||
server.Handler = &mainHandler{
|
||||
mux: mainMux,
|
||||
}
|
||||
|
||||
// start serving
|
||||
log.Infof("api: starting to listen on %s", server.Addr)
|
||||
|
@ -76,7 +78,74 @@ func Serve() {
|
|||
}
|
||||
}
|
||||
|
||||
// GetMuxVars wraps github.com/gorilla/mux.Vars in order to mitigate context key issues in multi-repo projects.
|
||||
func GetMuxVars(r *http.Request) map[string]string {
|
||||
return mux.Vars(r)
|
||||
type mainHandler struct {
|
||||
mux *mux.Router
|
||||
}
|
||||
|
||||
func (mh *mainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
_ = module.RunWorker("http request", func(_ context.Context) error {
|
||||
return mh.handle(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (mh *mainHandler) handle(w http.ResponseWriter, r *http.Request) error {
|
||||
// Setup context trace logging.
|
||||
ctx, tracer := log.AddTracer(r.Context())
|
||||
lrw := NewLoggingResponseWriter(w, r)
|
||||
// Add request context.
|
||||
apiRequest := &Request{
|
||||
Request: r,
|
||||
}
|
||||
ctx = context.WithValue(ctx, requestContextKey, apiRequest)
|
||||
// Add context back to request.
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
tracer.Tracef("api request: %s ___ %s", r.RemoteAddr, r.RequestURI)
|
||||
defer func() {
|
||||
// Log request status.
|
||||
if lrw.Status != 0 {
|
||||
// If lrw.Status is 0, the request may have been hijacked.
|
||||
tracer.Debugf("api request: %s %d %s", lrw.Request.RemoteAddr, lrw.Status, lrw.Request.RequestURI)
|
||||
}
|
||||
tracer.Submit()
|
||||
}()
|
||||
|
||||
// Get handler for request.
|
||||
// Gorilla does not support handling this on our own very well.
|
||||
// See github.com/gorilla/mux.ServeHTTP for reference.
|
||||
var match mux.RouteMatch
|
||||
var handler http.Handler
|
||||
if mh.mux.Match(r, &match) {
|
||||
handler = match.Handler
|
||||
apiRequest.Route = match.Route
|
||||
apiRequest.URLVars = match.Vars
|
||||
}
|
||||
|
||||
// Be sure that URLVars always is a map.
|
||||
if apiRequest.URLVars == nil {
|
||||
apiRequest.URLVars = make(map[string]string)
|
||||
}
|
||||
|
||||
// Check authentication.
|
||||
token := authenticateRequest(lrw, r, handler)
|
||||
if token == nil {
|
||||
// Authenticator already replied.
|
||||
return nil
|
||||
}
|
||||
apiRequest.AuthToken = &AuthToken{
|
||||
Read: token.Read,
|
||||
Write: token.Write,
|
||||
}
|
||||
|
||||
// Handle request.
|
||||
switch {
|
||||
case handler != nil:
|
||||
handler.ServeHTTP(lrw, r)
|
||||
case match.MatchErr == mux.ErrMethodMismatch:
|
||||
http.Error(lrw, "Method not allowed.", http.StatusMethodNotAllowed)
|
||||
default: // handler == nil or other error
|
||||
http.Error(lrw, "Not found.", http.StatusNotFound)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue