Tighten redirect and download path handling

This commit is contained in:
rcourtman 2026-03-31 09:17:52 +01:00
parent 66b448d63b
commit a7326d7047
9 changed files with 150 additions and 35 deletions

View file

@ -52,6 +52,40 @@ var requiredHostAgentBinaries = []HostAgentBinary{
},
}
var supportedHostAgentTargets = []HostAgentBinary{
{Platform: "freebsd", Arch: "amd64"},
}
// IsSupportedHostAgentPlatform reports whether platform is one of the release-supported host-agent platforms.
func IsSupportedHostAgentPlatform(platform string) bool {
for _, binary := range requiredHostAgentBinaries {
if binary.Platform == platform {
return true
}
}
for _, binary := range supportedHostAgentTargets {
if binary.Platform == platform {
return true
}
}
return false
}
// IsSupportedHostAgentTarget reports whether the platform/arch pair is one of the release-supported host-agent binaries.
func IsSupportedHostAgentTarget(platform, arch string) bool {
for _, binary := range requiredHostAgentBinaries {
if binary.Platform == platform && binary.Arch == arch {
return true
}
}
for _, binary := range supportedHostAgentTargets {
if binary.Platform == platform && binary.Arch == arch {
return true
}
}
return false
}
var downloadMu sync.Mutex
var (

View file

@ -11,6 +11,7 @@ import (
"time"
"github.com/rcourtman/pulse-go-rewrite/internal/alerts"
"github.com/rcourtman/pulse-go-rewrite/internal/pathutil"
"github.com/rs/zerolog/log"
)
@ -119,11 +120,20 @@ func NewIncidentStore(cfg IncidentStoreConfig) *IncidentStore {
}
if store.dataDir != "" {
store.filePath = filepath.Join(store.dataDir, incidentFileName)
if err := store.loadFromDisk(); err != nil {
log.Warn().Err(err).Msg("Failed to load incident history from disk")
} else if len(store.incidents) > 0 {
log.Info().Int("count", len(store.incidents)).Msg("Loaded incident history from disk")
normalizedDataDir, err := pathutil.NormalizeDir(store.dataDir)
if err != nil {
log.Warn().Err(err).Str("dataDir", store.dataDir).Msg("Failed to normalize incident data dir")
store.dataDir = ""
} else {
store.dataDir = normalizedDataDir
store.filePath = filepath.Join(store.dataDir, incidentFileName)
}
if store.filePath != "" {
if err := store.loadFromDisk(); err != nil {
log.Warn().Err(err).Msg("Failed to load incident history from disk")
} else if len(store.incidents) > 0 {
log.Info().Int("count", len(store.incidents)).Msg("Loaded incident history from disk")
}
}
}

View file

@ -304,12 +304,9 @@ func (r *Router) handleOIDCCallback(w http.ResponseWriter, req *http.Request) {
LogAuditEventForTenant(GetOrgID(req.Context()), "oidc_login", username, GetClientIP(req), req.URL.Path, true, "OIDC login success")
target := sanitizeOIDCReturnTo(entry.ReturnTo)
if target == "" {
target = "/"
}
target = addQueryParam(target, "oidc", "success")
http.Redirect(w, req, target, http.StatusFound)
http.Redirect(w, req, buildLocalRedirectTarget(entry.ReturnTo, map[string]string{
"oidc": "success",
}), http.StatusFound)
}
func (r *Router) getOIDCService(ctx context.Context, redirectURL string) (*OIDCService, error) {
@ -357,16 +354,36 @@ func sanitizeOIDCReturnTo(raw string) string {
}
func (r *Router) redirectOIDCError(w http.ResponseWriter, req *http.Request, returnTo string, code string) {
target := returnTo
target := buildLocalRedirectTarget(returnTo, map[string]string{
"oidc": "error",
"oidc_error": code,
})
http.Redirect(w, req, target, http.StatusFound)
}
func buildLocalRedirectTarget(returnTo string, queryParams map[string]string) string {
target := sanitizeOIDCReturnTo(returnTo)
if target == "" {
target = "/"
}
target = addQueryParam(target, "oidc", "error")
if code != "" {
target = addQueryParam(target, "oidc_error", code)
parsed, err := url.Parse(target)
if err != nil || parsed.IsAbs() || parsed.Host != "" {
parsed = &url.URL{Path: "/"}
}
if parsed.Path == "" {
parsed.Path = "/"
}
http.Redirect(w, req, target, http.StatusFound)
query := parsed.Query()
for key, value := range queryParams {
if key == "" || value == "" {
continue
}
query.Set(key, value)
}
parsed.RawQuery = query.Encode()
return parsed.RequestURI()
}
func addQueryParam(path, key, value string) string {

View file

@ -82,6 +82,15 @@ func TestRedirectOIDCError(t *testing.T) {
}
}
func TestBuildLocalRedirectTargetRejectsAbsoluteURL(t *testing.T) {
target := buildLocalRedirectTarget("https://evil.example.com/pwn", map[string]string{
"oidc": "error",
})
if target != "/?oidc=error" {
t.Fatalf("target = %q, want /?oidc=error", target)
}
}
func TestEnsureOIDCConfig_Defaults(t *testing.T) {
cfg := &config.Config{PublicURL: "https://pulse.example.com"}
router := &Router{config: cfg}

View file

@ -7129,6 +7129,18 @@ func (r *Router) handleDownloadHostAgent(w http.ResponseWriter, req *http.Reques
http.Error(w, "Invalid arch parameter", http.StatusBadRequest)
return
}
if archParam != "" && platformParam == "" {
http.Error(w, "arch parameter requires platform", http.StatusBadRequest)
return
}
if platformParam != "" && !agentbinaries.IsSupportedHostAgentPlatform(platformParam) {
http.Error(w, "Unsupported platform parameter", http.StatusBadRequest)
return
}
if platformParam != "" && archParam != "" && !agentbinaries.IsSupportedHostAgentTarget(platformParam, archParam) {
http.Error(w, "Unsupported platform/arch combination", http.StatusBadRequest)
return
}
checkedPaths, served := r.tryServeHostAgentBinary(w, req, platformParam, archParam)
if served {

View file

@ -20,12 +20,12 @@ func setupTempPulseBin(t *testing.T) string {
func TestHandleDownloadHostAgentServesWindowsExe(t *testing.T) {
binDir := setupTempPulseBin(t)
filePath := filepath.Join(binDir, "pulse-host-agent-windows-unit-test.exe")
filePath := filepath.Join(binDir, "pulse-host-agent-windows-amd64.exe")
if err := os.WriteFile(filePath, []byte("exe-binary"), 0o755); err != nil {
t.Fatalf("failed to write test binary: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "/download/pulse-host-agent?platform=windows&arch=unit-test", nil)
req := httptest.NewRequest(http.MethodGet, "/download/pulse-host-agent?platform=windows&arch=amd64", nil)
rr := httptest.NewRecorder()
router := &Router{}
@ -65,7 +65,7 @@ func TestHandleDownloadHostAgentServesLinuxArm64(t *testing.T) {
func TestHandleDownloadHostAgentServesChecksumForWindowsExe(t *testing.T) {
const (
arch = "unit-sha"
arch = "amd64"
filename = "pulse-host-agent-windows-" + arch + ".exe"
)
binDir := setupTempPulseBin(t)
@ -92,6 +92,30 @@ func TestHandleDownloadHostAgentServesChecksumForWindowsExe(t *testing.T) {
}
}
func TestHandleDownloadHostAgentRejectsArchWithoutPlatform(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/download/pulse-host-agent?arch=amd64", nil)
rr := httptest.NewRecorder()
router := &Router{}
router.handleDownloadHostAgent(rr, req)
if rr.Code != http.StatusBadRequest {
t.Fatalf("expected 400 Bad Request, got %d", rr.Code)
}
}
func TestHandleDownloadHostAgentRejectsUnsupportedTarget(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/download/pulse-host-agent?platform=windows&arch=unit-test", nil)
rr := httptest.NewRecorder()
router := &Router{}
router.handleDownloadHostAgent(rr, req)
if rr.Code != http.StatusBadRequest {
t.Fatalf("expected 400 Bad Request, got %d", rr.Code)
}
}
func TestHandleDownloadHostAgentAllowsHEAD(t *testing.T) {
binDir := setupTempPulseBin(t)
filePath := filepath.Join(binDir, "pulse-host-agent-linux-amd64")

View file

@ -264,12 +264,9 @@ func (r *Router) handleSAMLACS(w http.ResponseWriter, req *http.Request) {
LogAuditEventForTenant(GetOrgID(req.Context()), "saml_login", username, GetClientIP(req), req.URL.Path, true, "SAML login success via "+providerID)
// Redirect to return URL - sanitize relayState to prevent open redirect
target := sanitizeOIDCReturnTo(relayState)
if target == "" {
target = "/"
}
target = addQueryParam(target, "saml", "success")
http.Redirect(w, req, target, http.StatusFound)
http.Redirect(w, req, buildLocalRedirectTarget(relayState, map[string]string{
"saml": "success",
}), http.StatusFound)
}
// handleSAMLMetadata returns the SP metadata XML
@ -497,16 +494,10 @@ func (r *Router) clearSession(w http.ResponseWriter, req *http.Request) {
}
func (r *Router) redirectSAMLError(w http.ResponseWriter, req *http.Request, returnTo string, code string) {
// Sanitize returnTo to prevent open redirect attacks
target := sanitizeOIDCReturnTo(returnTo)
if target == "" {
target = "/"
}
target = addQueryParam(target, "saml", "error")
if code != "" {
target = addQueryParam(target, "saml_error", code)
}
http.Redirect(w, req, target, http.StatusFound)
http.Redirect(w, req, buildLocalRedirectTarget(returnTo, map[string]string{
"saml": "error",
"saml_error": code,
}), http.StatusFound)
}
// extractSAMLProviderID extracts the provider ID from a SAML endpoint path

View file

@ -125,6 +125,18 @@ func TestRedirectSAMLError(t *testing.T) {
}
}
func TestRedirectSAMLErrorRejectsAbsoluteReturnTo(t *testing.T) {
router := &Router{}
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
router.redirectSAMLError(rec, req, "https://evil.example.com/pwn", "session_failed")
if loc := rec.Header().Get("Location"); loc != "/?saml=error&saml_error=session_failed" {
t.Fatalf("unexpected redirect location %q", loc)
}
}
func TestInitializeSAMLProviders(t *testing.T) {
provider := testSAMLProvider("okta", true)
provider.SAML = &config.SAMLProviderConfig{

View file

@ -13,6 +13,7 @@ import (
"time"
"github.com/rcourtman/pulse-go-rewrite/internal/alerts"
"github.com/rcourtman/pulse-go-rewrite/internal/pathutil"
"github.com/rcourtman/pulse-go-rewrite/internal/utils"
"github.com/rs/zerolog/log"
_ "modernc.org/sqlite"
@ -67,6 +68,11 @@ func NewNotificationQueue(dataDir string) (*NotificationQueue, error) {
if dataDir == "" {
dataDir = filepath.Join(utils.GetDataDir(), "notifications")
}
normalizedDir, err := pathutil.NormalizeDir(dataDir)
if err != nil {
return nil, fmt.Errorf("failed to normalize notification queue directory: %w", err)
}
dataDir = normalizedDir
// Ensure directory exists
if err := os.MkdirAll(dataDir, 0755); err != nil {