mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-05-08 01:37:54 +00:00
229 lines
7.5 KiB
Go
229 lines
7.5 KiB
Go
package api
|
|
|
|
import (
|
|
"context"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
|
pulsews "github.com/rcourtman/pulse-go-rewrite/internal/websocket"
|
|
)
|
|
|
|
func resetTrustedProxyCIDRsForTests() {
|
|
trustedProxyCIDRs = nil
|
|
trustedProxyOnce = sync.Once{}
|
|
}
|
|
|
|
func newWebSocketRouterWithServer(t *testing.T, allowedOrigins []string, tokenRecord config.APITokenRecord, serverFactory func(*testing.T, http.Handler) *httptest.Server) (*httptest.Server, func()) {
|
|
t.Helper()
|
|
|
|
cfg := newTestConfigWithTokens(t, tokenRecord)
|
|
|
|
hub := pulsews.NewHub(nil)
|
|
hub.SetAllowedOrigins(allowedOrigins)
|
|
go hub.Run()
|
|
|
|
router := NewRouter(cfg, nil, nil, hub, nil, "1.0.0")
|
|
server := serverFactory(t, router.Handler())
|
|
|
|
cleanup := func() {
|
|
server.Close()
|
|
hub.Stop()
|
|
}
|
|
|
|
return server, cleanup
|
|
}
|
|
|
|
func newWebSocketRouter(t *testing.T, allowedOrigins []string, tokenRecord config.APITokenRecord) (*httptest.Server, func()) {
|
|
t.Helper()
|
|
return newWebSocketRouterWithServer(t, allowedOrigins, tokenRecord, newIPv4HTTPServer)
|
|
}
|
|
|
|
func newWebSocketRouterIPv6(t *testing.T, allowedOrigins []string, tokenRecord config.APITokenRecord) (*httptest.Server, func()) {
|
|
t.Helper()
|
|
return newWebSocketRouterWithServer(t, allowedOrigins, tokenRecord, newIPv6HTTPServer)
|
|
}
|
|
|
|
func TestWebSocketOriginRejectedWhenNotAllowed(t *testing.T) {
|
|
rawToken := "ws-origin-reject-123.12345678"
|
|
record := newTokenRecord(t, rawToken, []string{config.ScopeMonitoringRead}, nil)
|
|
|
|
server, cleanup := newWebSocketRouter(t, []string{"https://allowed.example.com"}, record)
|
|
defer cleanup()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/ws?org_id=default"
|
|
headers := http.Header{}
|
|
headers.Set("X-API-Token", rawToken)
|
|
headers.Set("Origin", "https://evil.example.com")
|
|
|
|
conn, resp, err := websocket.DefaultDialer.Dial(wsURL, headers)
|
|
if err == nil {
|
|
conn.Close()
|
|
t.Fatalf("expected websocket origin rejection")
|
|
}
|
|
if resp == nil {
|
|
t.Fatalf("expected HTTP response for rejected origin")
|
|
}
|
|
if resp.StatusCode != http.StatusForbidden {
|
|
t.Fatalf("expected status %d, got %d", http.StatusForbidden, resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestWebSocketOriginAllowedWhenConfigured(t *testing.T) {
|
|
rawToken := "ws-origin-allow-123.12345678"
|
|
record := newTokenRecord(t, rawToken, []string{config.ScopeMonitoringRead}, nil)
|
|
|
|
server, cleanup := newWebSocketRouter(t, []string{"https://allowed.example.com"}, record)
|
|
defer cleanup()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/ws?org_id=default"
|
|
headers := http.Header{}
|
|
headers.Set("X-API-Token", rawToken)
|
|
headers.Set("Origin", "https://allowed.example.com")
|
|
|
|
conn, resp, err := websocket.DefaultDialer.Dial(wsURL, headers)
|
|
if err != nil {
|
|
t.Fatalf("expected websocket connection, got %v", err)
|
|
}
|
|
if resp == nil || resp.StatusCode != http.StatusSwitchingProtocols {
|
|
t.Fatalf("expected 101 switching protocols, got %v", resp)
|
|
}
|
|
conn.Close()
|
|
}
|
|
|
|
func TestWebSocketOriginRejectedWhenNoAllowedOriginsAndPublicOrigin(t *testing.T) {
|
|
rawToken := "ws-origin-default-reject-123.12345678"
|
|
record := newTokenRecord(t, rawToken, []string{config.ScopeMonitoringRead}, nil)
|
|
|
|
server, cleanup := newWebSocketRouter(t, []string{}, record)
|
|
defer cleanup()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/ws?org_id=default"
|
|
headers := http.Header{}
|
|
headers.Set("X-API-Token", rawToken)
|
|
headers.Set("Origin", "https://evil.example.com")
|
|
|
|
conn, resp, err := websocket.DefaultDialer.Dial(wsURL, headers)
|
|
if err == nil {
|
|
conn.Close()
|
|
t.Fatalf("expected websocket origin rejection with empty allowed origins")
|
|
}
|
|
if resp == nil {
|
|
t.Fatalf("expected HTTP response for rejected origin")
|
|
}
|
|
if resp.StatusCode != http.StatusForbidden {
|
|
t.Fatalf("expected status %d, got %d", http.StatusForbidden, resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestWebSocketOriginAllowsPrivateWhenNoAllowedOrigins(t *testing.T) {
|
|
rawToken := "ws-origin-default-allow-123.12345678"
|
|
record := newTokenRecord(t, rawToken, []string{config.ScopeMonitoringRead}, nil)
|
|
|
|
server, cleanup := newWebSocketRouter(t, []string{}, record)
|
|
defer cleanup()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/ws?org_id=default"
|
|
headers := http.Header{}
|
|
headers.Set("X-API-Token", rawToken)
|
|
headers.Set("Origin", "http://localhost:3000")
|
|
|
|
conn, resp, err := websocket.DefaultDialer.Dial(wsURL, headers)
|
|
if err != nil {
|
|
t.Fatalf("expected websocket connection, got %v", err)
|
|
}
|
|
if resp == nil || resp.StatusCode != http.StatusSwitchingProtocols {
|
|
t.Fatalf("expected 101 switching protocols, got %v", resp)
|
|
}
|
|
conn.Close()
|
|
}
|
|
|
|
func TestWebSocketOriginAllowsSameHostTLSTerminationWithoutTrustedProxy(t *testing.T) {
|
|
rawToken := "ws-origin-same-host-123.12345678"
|
|
record := newTokenRecord(t, rawToken, []string{config.ScopeMonitoringRead}, nil)
|
|
|
|
server, cleanup := newWebSocketRouter(t, []string{}, record)
|
|
defer cleanup()
|
|
|
|
dialer := websocket.Dialer{
|
|
NetDialContext: func(ctx context.Context, network, _ string) (net.Conn, error) {
|
|
var dialer net.Dialer
|
|
return dialer.DialContext(ctx, network, server.Listener.Addr().String())
|
|
},
|
|
}
|
|
|
|
headers := http.Header{}
|
|
headers.Set("X-API-Token", rawToken)
|
|
headers.Set("Origin", "https://tenant.example.com")
|
|
|
|
conn, resp, err := dialer.Dial("ws://tenant.example.com/ws?org_id=default", headers)
|
|
if err != nil {
|
|
t.Fatalf("expected websocket connection when proxy preserves host but terminates tls upstream, got %v", err)
|
|
}
|
|
if resp == nil || resp.StatusCode != http.StatusSwitchingProtocols {
|
|
conn.Close()
|
|
t.Fatalf("expected 101 switching protocols, got %v", resp)
|
|
}
|
|
conn.Close()
|
|
}
|
|
|
|
func TestWebSocketOriginAllowsTrustedForwardedHostedOrigin(t *testing.T) {
|
|
t.Setenv("PULSE_TRUSTED_PROXY_CIDRS", "127.0.0.1/32")
|
|
resetTrustedProxyCIDRsForTests()
|
|
|
|
rawToken := "ws-origin-forwarded-allow-123.12345678"
|
|
record := newTokenRecord(t, rawToken, []string{config.ScopeMonitoringRead}, nil)
|
|
|
|
server, cleanup := newWebSocketRouter(t, []string{}, record)
|
|
defer cleanup()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/ws?org_id=default"
|
|
headers := http.Header{}
|
|
headers.Set("X-API-Token", rawToken)
|
|
headers.Set("Origin", "https://tenant.example.com")
|
|
headers.Set("X-Forwarded-Proto", "https")
|
|
headers.Set("X-Forwarded-Host", "tenant.example.com")
|
|
|
|
conn, resp, err := websocket.DefaultDialer.Dial(wsURL, headers)
|
|
if err != nil {
|
|
t.Fatalf("expected websocket connection behind trusted proxy, got %v", err)
|
|
}
|
|
if resp == nil || resp.StatusCode != http.StatusSwitchingProtocols {
|
|
conn.Close()
|
|
t.Fatalf("expected 101 switching protocols, got %v", resp)
|
|
}
|
|
conn.Close()
|
|
}
|
|
|
|
func TestWebSocketOriginAllowsTrustedForwardedHostedOriginIPv6Loopback(t *testing.T) {
|
|
t.Setenv("PULSE_TRUSTED_PROXY_CIDRS", "::1/128")
|
|
resetTrustedProxyCIDRsForTests()
|
|
|
|
rawToken := "ws-origin-forwarded-ipv6-allow-123.12345678"
|
|
record := newTokenRecord(t, rawToken, []string{config.ScopeMonitoringRead}, nil)
|
|
|
|
server, cleanup := newWebSocketRouterIPv6(t, []string{}, record)
|
|
defer cleanup()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/ws?org_id=default"
|
|
headers := http.Header{}
|
|
headers.Set("X-API-Token", rawToken)
|
|
headers.Set("Origin", "https://tenant.example.com")
|
|
headers.Set("X-Forwarded-Proto", "https")
|
|
headers.Set("X-Forwarded-Host", "tenant.example.com")
|
|
|
|
conn, resp, err := websocket.DefaultDialer.Dial(wsURL, headers)
|
|
if err != nil {
|
|
t.Fatalf("expected websocket connection behind trusted IPv6 loopback proxy, got %v", err)
|
|
}
|
|
if resp == nil || resp.StatusCode != http.StatusSwitchingProtocols {
|
|
conn.Close()
|
|
t.Fatalf("expected 101 switching protocols, got %v", resp)
|
|
}
|
|
conn.Close()
|
|
}
|