Fix CORS handling

This commit is contained in:
Daniel 2021-11-28 23:48:45 +01:00
parent b3dd9a1b3f
commit 1695420b0e
6 changed files with 82 additions and 38 deletions

View file

@ -133,16 +133,13 @@ func SetAuthenticator(fn AuthenticatorFunc) error {
return nil
}
func authenticateRequest(w http.ResponseWriter, r *http.Request, targetHandler http.Handler) *AuthToken {
func authenticateRequest(w http.ResponseWriter, r *http.Request, targetHandler http.Handler, readMethod bool) *AuthToken {
tracer := log.Tracer(r.Context())
// Check if request is read only.
readRequest := isReadMethod(r.Method)
// Get required permission for target handler.
requiredPermission := PermitSelf
if authdHandler, ok := targetHandler.(AuthenticatedHandler); ok {
if readRequest {
if readMethod {
requiredPermission = authdHandler.ReadPermission(r)
} else {
requiredPermission = authdHandler.WritePermission(r)
@ -200,7 +197,7 @@ func authenticateRequest(w http.ResponseWriter, r *http.Request, targetHandler h
// Get effective permission for request.
var requestPermission Permission
if readRequest {
if readMethod {
requestPermission = token.Read
} else {
requestPermission = token.Write
@ -221,7 +218,10 @@ func authenticateRequest(w http.ResponseWriter, r *http.Request, targetHandler h
if requestPermission < requiredPermission {
// If the token is strictly public, return an authentication request.
if token.Read == PermitAnyone && token.Write == PermitAnyone {
w.Header().Set("WWW-Authenticate", "Bearer realm=Portmaster API")
w.Header().Set(
"WWW-Authenticate",
`Bearer realm="Portmaster API" domain="/"`,
)
http.Error(w, "Authorization required.", http.StatusUnauthorized)
return nil
}
@ -477,12 +477,24 @@ func deleteSession(sessionKey string) {
delete(sessions, sessionKey)
}
func isReadMethod(method string) bool {
func getEffectiveMethod(r *http.Request) (eMethod string, readMethod bool, ok bool) {
method := r.Method
// Get CORS request method if OPTIONS request.
if r.Method == http.MethodOptions {
method = r.Header.Get("Access-Control-Request-Method")
if method == "" {
return "", false, false
}
}
switch method {
case http.MethodGet, http.MethodHead, http.MethodOptions:
return true
case http.MethodGet, http.MethodHead:
return http.MethodGet, true, true
case http.MethodPost, http.MethodPut, http.MethodDelete:
return method, false, true
default:
return false
return "", false, false
}
}

View file

@ -102,7 +102,7 @@ func TestPermissions(t *testing.T) { //nolint:gocognit
} {
// Set request permission for test requests.
reading := isReadMethod(method)
reading := method == http.MethodGet
if reading {
testToken.Read = requestPerm
testToken.Write = NotSupported

View file

@ -371,9 +371,20 @@ func (e *Endpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
// TODO: Return errors instead of warnings, also update the field docs.
if isReadMethod(r.Method) {
if r.Method != e.ReadMethod {
// Return OPTIONS request before starting to handle normal requests.
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent)
return
}
eMethod, readMethod, ok := getEffectiveMethod(r)
if !ok {
http.Error(w, "unsupported method for the actions API", http.StatusMethodNotAllowed)
return
}
if readMethod {
if eMethod != e.ReadMethod {
log.Tracer(r.Context()).Warningf(
"api: method %q does not match required read method %q%s",
" - this will be an error and abort the request in the future",
@ -382,7 +393,7 @@ func (e *Endpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
)
}
} else {
if r.Method != e.WriteMethod {
if eMethod != e.WriteMethod {
log.Tracer(r.Context()).Warningf(
"api: method %q does not match required write method %q%s",
" - this will be an error and abort the request in the future",
@ -392,10 +403,9 @@ func (e *Endpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
switch r.Method {
case http.MethodHead:
w.WriteHeader(http.StatusOK)
return
switch eMethod {
case http.MethodGet, http.MethodDelete:
// Nothing to do for these.
case http.MethodPost, http.MethodPut:
// Read body data.
inputData, ok := readBody(w, r)
@ -403,12 +413,8 @@ func (e *Endpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
apiRequest.InputData = inputData
case http.MethodGet:
// Nothing special to do here.
case http.MethodOptions:
w.WriteHeader(http.StatusNoContent)
return
default:
// Defensive.
http.Error(w, "unsupported method for the actions API", http.StatusMethodNotAllowed)
return
}
@ -466,8 +472,8 @@ func (e *Endpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
// Check if there is no response data.
if len(responseData) == 0 {
// Return no content if there is none, or if request is HEAD.
if len(responseData) == 0 || r.Method == http.MethodHead {
w.WriteHeader(http.StatusNoContent)
return
}

View file

@ -93,7 +93,10 @@ func authBearer(w http.ResponseWriter, r *http.Request) {
}
// Respond with desired authentication header.
w.Header().Set("WWW-Authenticate", "Bearer realm=Portmaster API")
w.Header().Set(
"WWW-Authenticate",
`Bearer realm="Portmaster API" domain="/"`,
)
http.Error(w, "Authorization required.", http.StatusUnauthorized)
}
@ -106,7 +109,10 @@ func authBasic(w http.ResponseWriter, r *http.Request) {
}
// Respond with desired authentication header.
w.Header().Set("WWW-Authenticate", "Basic realm=Portmaster API")
w.Header().Set(
"WWW-Authenticate",
`Basic realm="Portmaster API" domain="/"`,
)
http.Error(w, "Authorization required.", http.StatusUnauthorized)
}
@ -127,7 +133,7 @@ func authReset(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Clear-Site-Data", "*")
// Set HTTP Auth Realm without requesting authorization.
w.Header().Set("WWW-Authenticate", "None realm=Portmaster API")
w.Header().Set("WWW-Authenticate", `None realm="Portmaster API"`)
// Reply with 401 Unauthorized in order to clear HTTP Basic Auth data.
http.Error(w, "Session deleted.", http.StatusUnauthorized)

View file

@ -125,14 +125,31 @@ func (mh *mainHandler) handle(w http.ResponseWriter, r *http.Request) error {
apiRequest.Route = match.Route
apiRequest.URLVars = match.Vars
}
switch {
case match.MatchErr == nil:
// All good.
case errors.Is(match.MatchErr, mux.ErrMethodMismatch):
http.Error(lrw, "Method not allowed.", http.StatusMethodNotAllowed)
return nil
default:
http.Error(lrw, "Not found.", http.StatusNotFound)
return nil
}
// Be sure that URLVars always is a map.
if apiRequest.URLVars == nil {
apiRequest.URLVars = make(map[string]string)
}
// Check method.
_, readMethod, ok := getEffectiveMethod(r)
if !ok {
http.Error(lrw, "Method not allowed.", http.StatusMethodNotAllowed)
return nil
}
// Check authentication.
apiRequest.AuthToken = authenticateRequest(lrw, r, handler)
apiRequest.AuthToken = authenticateRequest(lrw, r, handler, readMethod)
if apiRequest.AuthToken == nil {
// Authenticator already replied.
return nil
@ -164,18 +181,21 @@ func (mh *mainHandler) handle(w http.ResponseWriter, r *http.Request) error {
} else if origin := r.Header.Get("Origin"); origin != "" {
// Allow cross origin requests from localhost in dev mode.
if u, err := url.Parse(origin); err == nil &&
utils.StringInSlice(allowedDevCORSOrigins, u.Host) {
utils.StringInSlice(allowedDevCORSOrigins, u.Hostname()) {
w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Access-Control-Allow-Methods", "*")
w.Header().Set("Access-Control-Allow-Headers", "*")
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Access-Control-Expose-Headers", "*")
w.Header().Set("Access-Control-Max-Age", "60")
w.Header().Add("Vary", "Origin")
}
}
// Handle request.
switch {
case handler != nil:
if handler != nil {
handler.ServeHTTP(lrw, r)
case errors.Is(match.MatchErr, mux.ErrMethodMismatch):
http.Error(lrw, "Method not allowed.", http.StatusMethodNotAllowed)
default: // handler == nil or other error
} else {
http.Error(lrw, "Not found.", http.StatusNotFound)
}

View file

@ -47,7 +47,7 @@ func prepConfig() error {
instanceOption = config.Concurrent.GetAsString(CfgOptionInstanceKey, instanceFlag)
err = config.Register(&config.Option{
Name: "Push Metrics",
Name: "Push Prometheus Metrics",
Key: CfgOptionPushKey,
Description: "Push metrics to this URL in the prometheus format.",
OptType: config.OptTypeString,