diff --git a/api/authentication.go b/api/authentication.go index a9b376f..063c1b6 100644 --- a/api/authentication.go +++ b/api/authentication.go @@ -2,6 +2,7 @@ package api import ( "encoding/base64" + "errors" "net/http" "sync" "time" @@ -16,6 +17,10 @@ var ( authFnLock sync.Mutex authFn Authenticator + + // ErrAPIAccessDeniedMessage should be returned by Authenticator functions in + // order to signify a blocked request, including a error message for the user. + ErrAPIAccessDeniedMessage = errors.New("") ) const ( @@ -28,7 +33,7 @@ const ( ) // Authenticator is a function that can be set as the authenticator for the API endpoint. If none is set, all requests will be allowed. -type Authenticator func(s *http.Server, r *http.Request) (grantAccess bool, err error) +type Authenticator func(s *http.Server, r *http.Request) (err error) // SetAuthenticator sets an authenticator function for the API endpoint. If none is set, all requests will be allowed. func SetAuthenticator(fn Authenticator) error { @@ -79,15 +84,15 @@ func authMiddleware(next http.Handler) http.Handler { } // get auth decision - grantAccess, err := authenticator(server, r) + err = authenticator(server, r) if err != nil { - log.Warningf("api: authenticator failed: %s", err) - http.Error(w, "Bad Request: Could not identify client", http.StatusBadRequest) - return - } - if !grantAccess { - log.Warningf("api: denying api access to %s", r.RemoteAddr) - http.Error(w, "Forbidden", http.StatusForbidden) + if errors.Is(err, ErrAPIAccessDeniedMessage) { + log.Warningf("api: denying api access to %s", r.RemoteAddr) + http.Error(w, err.Error(), http.StatusForbidden) + } else { + log.Warningf("api: authenticator failed: %s", err) + http.Error(w, "Internal server error during authentication.", http.StatusInternalServerError) + } return } diff --git a/run/main.go b/run/main.go index 903519e..9050df5 100644 --- a/run/main.go +++ b/run/main.go @@ -4,6 +4,7 @@ import ( "bufio" "flag" "fmt" + "io" "os" "os/signal" "runtime/pprof" @@ -36,6 +37,10 @@ func Run() int { return 0 } + if printStackOnExit { + printStackTo(os.Stdout) + } + _ = modules.Shutdown() return modules.GetExitStatusCode() } @@ -78,20 +83,13 @@ signalLoop: }() if printStackOnExit { - fmt.Println("=== PRINTING TRACES ===") - fmt.Println("=== GOROUTINES ===") - _ = pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) - fmt.Println("=== BLOCKING ===") - _ = pprof.Lookup("block").WriteTo(os.Stdout, 1) - fmt.Println("=== MUTEXES ===") - _ = pprof.Lookup("mutex").WriteTo(os.Stdout, 1) - fmt.Println("=== END TRACES ===") + printStackTo(os.Stdout) } go func() { - time.Sleep(60 * time.Second) - fmt.Fprintln(os.Stderr, "===== TAKING TOO LONG FOR SHUTDOWN - PRINTING STACK TRACES =====") - _ = pprof.Lookup("goroutine").WriteTo(os.Stderr, 1) + time.Sleep(3 * time.Minute) + fmt.Fprintln(os.Stderr, "===== TAKING TOO LONG FOR SHUTDOWN =====") + printStackTo(os.Stderr) os.Exit(1) }() @@ -124,3 +122,14 @@ func inputSignals(signalCh chan os.Signal) { } } } + +func printStackTo(writer io.Writer) { + fmt.Fprintln(writer, "=== PRINTING TRACES ===") + fmt.Fprintln(writer, "=== GOROUTINES ===") + _ = pprof.Lookup("goroutine").WriteTo(writer, 1) + fmt.Fprintln(writer, "=== BLOCKING ===") + _ = pprof.Lookup("block").WriteTo(writer, 1) + fmt.Fprintln(writer, "=== MUTEXES ===") + _ = pprof.Lookup("mutex").WriteTo(writer, 1) + fmt.Fprintln(writer, "=== END TRACES ===") +}