fix(websocket): respect X-Forwarded headers in same-origin check

- Use X-Forwarded-Proto/X-Forwarded-Scheme for scheme detection
- Use X-Forwarded-Host for host matching behind reverse proxies
- Update tests with remoteAddr for CSWSH protection validation
This commit is contained in:
rcourtman 2026-02-03 21:45:39 +00:00
parent 1490a6e6e3
commit a9ed380718
2 changed files with 53 additions and 32 deletions

View file

@ -105,17 +105,28 @@ func (h *Hub) checkOrigin(r *http.Request) bool {
allowedOrigins := h.allowedOrigins
h.mu.RUnlock()
// Determine the actual origin
// Determine the actual origin (accounting for proxy headers)
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
// Check X-Forwarded-Proto or X-Forwarded-Scheme for proxied requests
if forwardedProto := r.Header.Get("X-Forwarded-Proto"); forwardedProto != "" {
scheme = normalizeForwardedProto(forwardedProto, scheme)
} else if forwardedScheme := r.Header.Get("X-Forwarded-Scheme"); forwardedScheme != "" {
scheme = normalizeForwardedProto(forwardedScheme, scheme)
}
// Use X-Forwarded-Host if present (for proxied requests)
host := r.Host
if forwardedHost := r.Header.Get("X-Forwarded-Host"); forwardedHost != "" {
host = forwardedHost
}
requestOrigin := scheme + "://" + host
// Allow same-origin requests (accounting for proxy headers)
// Allow same-origin requests
if origin == requestOrigin {
return true
}

View file

@ -683,6 +683,7 @@ func TestHub_CheckOrigin(t *testing.T) {
allowedOrigins []string
forwardedProto string
forwardedHost string
remoteAddr string // Simulated peer IP for CSWSH checks
expected bool
}{
// No origin header - always allowed for non-browser clients
@ -748,41 +749,48 @@ func TestHub_CheckOrigin(t *testing.T) {
},
// Private network fallback (no allowed origins configured)
// Note: remoteAddr must be private for CSWSH protection to allow
{
name: "private IP 192.168.x.x allowed when no origins configured",
origin: "http://192.168.1.100:3000",
host: "localhost:8080",
expected: true,
name: "private IP 192.168.x.x allowed when no origins configured",
origin: "http://192.168.1.100:3000",
host: "localhost:8080",
remoteAddr: "192.168.1.100:54321",
expected: true,
},
{
name: "private IP 10.x.x.x allowed when no origins configured",
origin: "http://10.0.0.50:3000",
host: "localhost:8080",
expected: true,
name: "private IP 10.x.x.x allowed when no origins configured",
origin: "http://10.0.0.50:3000",
host: "localhost:8080",
remoteAddr: "10.0.0.50:54321",
expected: true,
},
{
name: "localhost allowed when no origins configured",
origin: "http://localhost:3000",
host: "localhost:8080",
expected: true,
name: "localhost allowed when no origins configured",
origin: "http://localhost:3000",
host: "localhost:8080",
remoteAddr: "127.0.0.1:54321",
expected: true,
},
{
name: "127.0.0.1 allowed when no origins configured",
origin: "http://127.0.0.1:3000",
host: "localhost:8080",
expected: true,
name: "127.0.0.1 allowed when no origins configured",
origin: "http://127.0.0.1:3000",
host: "localhost:8080",
remoteAddr: "127.0.0.1:54321",
expected: true,
},
{
name: ".local domain allowed when no origins configured",
origin: "http://myserver.local:3000",
host: "localhost:8080",
expected: true,
name: ".local domain allowed when no origins configured",
origin: "http://myserver.local:3000",
host: "localhost:8080",
remoteAddr: "192.168.1.50:54321",
expected: true,
},
{
name: ".lan domain allowed when no origins configured",
origin: "http://myserver.lan:3000",
host: "localhost:8080",
expected: true,
name: ".lan domain allowed when no origins configured",
origin: "http://myserver.lan:3000",
host: "localhost:8080",
remoteAddr: "192.168.1.50:54321",
expected: true,
},
{
name: "public IP rejected when no origins configured",
@ -799,10 +807,11 @@ func TestHub_CheckOrigin(t *testing.T) {
// HTTPS origin stripping
{
name: "https origin with private IP",
origin: "https://192.168.1.50:443",
host: "localhost:8080",
expected: true,
name: "https origin with private IP",
origin: "https://192.168.1.50:443",
host: "localhost:8080",
remoteAddr: "192.168.1.50:54321",
expected: true,
},
// Forwarded proto normalization
@ -830,8 +839,9 @@ func TestHub_CheckOrigin(t *testing.T) {
}
req := &http.Request{
Host: tc.host,
Header: make(http.Header),
Host: tc.host,
Header: make(http.Header),
RemoteAddr: tc.remoteAddr,
}
if tc.origin != "" {
req.Header.Set("Origin", tc.origin)