Fix CORS handling

This commit is contained in:
Daniel 2021-11-28 23:48:45 +01:00
parent b3dd9a1b3f
commit 1695420b0e
6 changed files with 82 additions and 38 deletions

View file

@ -133,16 +133,13 @@ func SetAuthenticator(fn AuthenticatorFunc) error {
return nil return nil
} }
func authenticateRequest(w http.ResponseWriter, r *http.Request, targetHandler http.Handler) *AuthToken { func authenticateRequest(w http.ResponseWriter, r *http.Request, targetHandler http.Handler, readMethod bool) *AuthToken {
tracer := log.Tracer(r.Context()) tracer := log.Tracer(r.Context())
// Check if request is read only.
readRequest := isReadMethod(r.Method)
// Get required permission for target handler. // Get required permission for target handler.
requiredPermission := PermitSelf requiredPermission := PermitSelf
if authdHandler, ok := targetHandler.(AuthenticatedHandler); ok { if authdHandler, ok := targetHandler.(AuthenticatedHandler); ok {
if readRequest { if readMethod {
requiredPermission = authdHandler.ReadPermission(r) requiredPermission = authdHandler.ReadPermission(r)
} else { } else {
requiredPermission = authdHandler.WritePermission(r) requiredPermission = authdHandler.WritePermission(r)
@ -200,7 +197,7 @@ func authenticateRequest(w http.ResponseWriter, r *http.Request, targetHandler h
// Get effective permission for request. // Get effective permission for request.
var requestPermission Permission var requestPermission Permission
if readRequest { if readMethod {
requestPermission = token.Read requestPermission = token.Read
} else { } else {
requestPermission = token.Write requestPermission = token.Write
@ -221,7 +218,10 @@ func authenticateRequest(w http.ResponseWriter, r *http.Request, targetHandler h
if requestPermission < requiredPermission { if requestPermission < requiredPermission {
// If the token is strictly public, return an authentication request. // If the token is strictly public, return an authentication request.
if token.Read == PermitAnyone && token.Write == PermitAnyone { if token.Read == PermitAnyone && token.Write == PermitAnyone {
w.Header().Set("WWW-Authenticate", "Bearer realm=Portmaster API") w.Header().Set(
"WWW-Authenticate",
`Bearer realm="Portmaster API" domain="/"`,
)
http.Error(w, "Authorization required.", http.StatusUnauthorized) http.Error(w, "Authorization required.", http.StatusUnauthorized)
return nil return nil
} }
@ -477,12 +477,24 @@ func deleteSession(sessionKey string) {
delete(sessions, sessionKey) delete(sessions, sessionKey)
} }
func isReadMethod(method string) bool { func getEffectiveMethod(r *http.Request) (eMethod string, readMethod bool, ok bool) {
method := r.Method
// Get CORS request method if OPTIONS request.
if r.Method == http.MethodOptions {
method = r.Header.Get("Access-Control-Request-Method")
if method == "" {
return "", false, false
}
}
switch method { switch method {
case http.MethodGet, http.MethodHead, http.MethodOptions: case http.MethodGet, http.MethodHead:
return true return http.MethodGet, true, true
case http.MethodPost, http.MethodPut, http.MethodDelete:
return method, false, true
default: default:
return false return "", false, false
} }
} }

View file

@ -102,7 +102,7 @@ func TestPermissions(t *testing.T) { //nolint:gocognit
} { } {
// Set request permission for test requests. // Set request permission for test requests.
reading := isReadMethod(method) reading := method == http.MethodGet
if reading { if reading {
testToken.Read = requestPerm testToken.Read = requestPerm
testToken.Write = NotSupported testToken.Write = NotSupported

View file

@ -371,9 +371,20 @@ func (e *Endpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
// TODO: Return errors instead of warnings, also update the field docs. // Return OPTIONS request before starting to handle normal requests.
if isReadMethod(r.Method) { if r.Method == http.MethodOptions {
if r.Method != e.ReadMethod { w.WriteHeader(http.StatusNoContent)
return
}
eMethod, readMethod, ok := getEffectiveMethod(r)
if !ok {
http.Error(w, "unsupported method for the actions API", http.StatusMethodNotAllowed)
return
}
if readMethod {
if eMethod != e.ReadMethod {
log.Tracer(r.Context()).Warningf( log.Tracer(r.Context()).Warningf(
"api: method %q does not match required read method %q%s", "api: method %q does not match required read method %q%s",
" - this will be an error and abort the request in the future", " - this will be an error and abort the request in the future",
@ -382,7 +393,7 @@ func (e *Endpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
) )
} }
} else { } else {
if r.Method != e.WriteMethod { if eMethod != e.WriteMethod {
log.Tracer(r.Context()).Warningf( log.Tracer(r.Context()).Warningf(
"api: method %q does not match required write method %q%s", "api: method %q does not match required write method %q%s",
" - this will be an error and abort the request in the future", " - this will be an error and abort the request in the future",
@ -392,10 +403,9 @@ func (e *Endpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
} }
switch r.Method { switch eMethod {
case http.MethodHead: case http.MethodGet, http.MethodDelete:
w.WriteHeader(http.StatusOK) // Nothing to do for these.
return
case http.MethodPost, http.MethodPut: case http.MethodPost, http.MethodPut:
// Read body data. // Read body data.
inputData, ok := readBody(w, r) inputData, ok := readBody(w, r)
@ -403,12 +413,8 @@ func (e *Endpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
apiRequest.InputData = inputData apiRequest.InputData = inputData
case http.MethodGet:
// Nothing special to do here.
case http.MethodOptions:
w.WriteHeader(http.StatusNoContent)
return
default: default:
// Defensive.
http.Error(w, "unsupported method for the actions API", http.StatusMethodNotAllowed) http.Error(w, "unsupported method for the actions API", http.StatusMethodNotAllowed)
return return
} }
@ -466,8 +472,8 @@ func (e *Endpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
// Check if there is no response data. // Return no content if there is none, or if request is HEAD.
if len(responseData) == 0 { if len(responseData) == 0 || r.Method == http.MethodHead {
w.WriteHeader(http.StatusNoContent) w.WriteHeader(http.StatusNoContent)
return return
} }

View file

@ -93,7 +93,10 @@ func authBearer(w http.ResponseWriter, r *http.Request) {
} }
// Respond with desired authentication header. // Respond with desired authentication header.
w.Header().Set("WWW-Authenticate", "Bearer realm=Portmaster API") w.Header().Set(
"WWW-Authenticate",
`Bearer realm="Portmaster API" domain="/"`,
)
http.Error(w, "Authorization required.", http.StatusUnauthorized) http.Error(w, "Authorization required.", http.StatusUnauthorized)
} }
@ -106,7 +109,10 @@ func authBasic(w http.ResponseWriter, r *http.Request) {
} }
// Respond with desired authentication header. // Respond with desired authentication header.
w.Header().Set("WWW-Authenticate", "Basic realm=Portmaster API") w.Header().Set(
"WWW-Authenticate",
`Basic realm="Portmaster API" domain="/"`,
)
http.Error(w, "Authorization required.", http.StatusUnauthorized) http.Error(w, "Authorization required.", http.StatusUnauthorized)
} }
@ -127,7 +133,7 @@ func authReset(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Clear-Site-Data", "*") w.Header().Set("Clear-Site-Data", "*")
// Set HTTP Auth Realm without requesting authorization. // Set HTTP Auth Realm without requesting authorization.
w.Header().Set("WWW-Authenticate", "None realm=Portmaster API") w.Header().Set("WWW-Authenticate", `None realm="Portmaster API"`)
// Reply with 401 Unauthorized in order to clear HTTP Basic Auth data. // Reply with 401 Unauthorized in order to clear HTTP Basic Auth data.
http.Error(w, "Session deleted.", http.StatusUnauthorized) http.Error(w, "Session deleted.", http.StatusUnauthorized)

View file

@ -125,14 +125,31 @@ func (mh *mainHandler) handle(w http.ResponseWriter, r *http.Request) error {
apiRequest.Route = match.Route apiRequest.Route = match.Route
apiRequest.URLVars = match.Vars apiRequest.URLVars = match.Vars
} }
switch {
case match.MatchErr == nil:
// All good.
case errors.Is(match.MatchErr, mux.ErrMethodMismatch):
http.Error(lrw, "Method not allowed.", http.StatusMethodNotAllowed)
return nil
default:
http.Error(lrw, "Not found.", http.StatusNotFound)
return nil
}
// Be sure that URLVars always is a map. // Be sure that URLVars always is a map.
if apiRequest.URLVars == nil { if apiRequest.URLVars == nil {
apiRequest.URLVars = make(map[string]string) apiRequest.URLVars = make(map[string]string)
} }
// Check method.
_, readMethod, ok := getEffectiveMethod(r)
if !ok {
http.Error(lrw, "Method not allowed.", http.StatusMethodNotAllowed)
return nil
}
// Check authentication. // Check authentication.
apiRequest.AuthToken = authenticateRequest(lrw, r, handler) apiRequest.AuthToken = authenticateRequest(lrw, r, handler, readMethod)
if apiRequest.AuthToken == nil { if apiRequest.AuthToken == nil {
// Authenticator already replied. // Authenticator already replied.
return nil return nil
@ -164,18 +181,21 @@ func (mh *mainHandler) handle(w http.ResponseWriter, r *http.Request) error {
} else if origin := r.Header.Get("Origin"); origin != "" { } else if origin := r.Header.Get("Origin"); origin != "" {
// Allow cross origin requests from localhost in dev mode. // Allow cross origin requests from localhost in dev mode.
if u, err := url.Parse(origin); err == nil && if u, err := url.Parse(origin); err == nil &&
utils.StringInSlice(allowedDevCORSOrigins, u.Host) { utils.StringInSlice(allowedDevCORSOrigins, u.Hostname()) {
w.Header().Set("Access-Control-Allow-Origin", origin) w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Access-Control-Allow-Methods", "*")
w.Header().Set("Access-Control-Allow-Headers", "*")
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Access-Control-Expose-Headers", "*")
w.Header().Set("Access-Control-Max-Age", "60")
w.Header().Add("Vary", "Origin")
} }
} }
// Handle request. // Handle request.
switch { if handler != nil {
case handler != nil:
handler.ServeHTTP(lrw, r) handler.ServeHTTP(lrw, r)
case errors.Is(match.MatchErr, mux.ErrMethodMismatch): } else {
http.Error(lrw, "Method not allowed.", http.StatusMethodNotAllowed)
default: // handler == nil or other error
http.Error(lrw, "Not found.", http.StatusNotFound) http.Error(lrw, "Not found.", http.StatusNotFound)
} }

View file

@ -47,7 +47,7 @@ func prepConfig() error {
instanceOption = config.Concurrent.GetAsString(CfgOptionInstanceKey, instanceFlag) instanceOption = config.Concurrent.GetAsString(CfgOptionInstanceKey, instanceFlag)
err = config.Register(&config.Option{ err = config.Register(&config.Option{
Name: "Push Metrics", Name: "Push Prometheus Metrics",
Key: CfgOptionPushKey, Key: CfgOptionPushKey,
Description: "Push metrics to this URL in the prometheus format.", Description: "Push metrics to this URL in the prometheus format.",
OptType: config.OptTypeString, OptType: config.OptTypeString,