mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-05-07 00:37:36 +00:00
198 lines
6.5 KiB
Go
198 lines
6.5 KiB
Go
package api
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
|
pulsews "github.com/rcourtman/pulse-go-rewrite/internal/websocket"
|
|
)
|
|
|
|
func resetTrustedProxyCIDRsForTests() {
|
|
trustedProxyCIDRs = nil
|
|
trustedProxyOnce = sync.Once{}
|
|
}
|
|
|
|
func newWebSocketRouterWithServer(t *testing.T, allowedOrigins []string, tokenRecord config.APITokenRecord, serverFactory func(*testing.T, http.Handler) *httptest.Server) (*httptest.Server, func()) {
|
|
t.Helper()
|
|
|
|
cfg := newTestConfigWithTokens(t, tokenRecord)
|
|
|
|
hub := pulsews.NewHub(nil)
|
|
hub.SetAllowedOrigins(allowedOrigins)
|
|
go hub.Run()
|
|
|
|
router := NewRouter(cfg, nil, nil, hub, nil, "1.0.0")
|
|
server := serverFactory(t, router.Handler())
|
|
|
|
cleanup := func() {
|
|
server.Close()
|
|
hub.Stop()
|
|
}
|
|
|
|
return server, cleanup
|
|
}
|
|
|
|
func newWebSocketRouter(t *testing.T, allowedOrigins []string, tokenRecord config.APITokenRecord) (*httptest.Server, func()) {
|
|
t.Helper()
|
|
return newWebSocketRouterWithServer(t, allowedOrigins, tokenRecord, newIPv4HTTPServer)
|
|
}
|
|
|
|
func newWebSocketRouterIPv6(t *testing.T, allowedOrigins []string, tokenRecord config.APITokenRecord) (*httptest.Server, func()) {
|
|
t.Helper()
|
|
return newWebSocketRouterWithServer(t, allowedOrigins, tokenRecord, newIPv6HTTPServer)
|
|
}
|
|
|
|
func TestWebSocketOriginRejectedWhenNotAllowed(t *testing.T) {
|
|
rawToken := "ws-origin-reject-123.12345678"
|
|
record := newTokenRecord(t, rawToken, []string{config.ScopeMonitoringRead}, nil)
|
|
|
|
server, cleanup := newWebSocketRouter(t, []string{"https://allowed.example.com"}, record)
|
|
defer cleanup()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/ws?org_id=default"
|
|
headers := http.Header{}
|
|
headers.Set("X-API-Token", rawToken)
|
|
headers.Set("Origin", "https://evil.example.com")
|
|
|
|
conn, resp, err := websocket.DefaultDialer.Dial(wsURL, headers)
|
|
if err == nil {
|
|
conn.Close()
|
|
t.Fatalf("expected websocket origin rejection")
|
|
}
|
|
if resp == nil {
|
|
t.Fatalf("expected HTTP response for rejected origin")
|
|
}
|
|
if resp.StatusCode != http.StatusForbidden {
|
|
t.Fatalf("expected status %d, got %d", http.StatusForbidden, resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestWebSocketOriginAllowedWhenConfigured(t *testing.T) {
|
|
rawToken := "ws-origin-allow-123.12345678"
|
|
record := newTokenRecord(t, rawToken, []string{config.ScopeMonitoringRead}, nil)
|
|
|
|
server, cleanup := newWebSocketRouter(t, []string{"https://allowed.example.com"}, record)
|
|
defer cleanup()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/ws?org_id=default"
|
|
headers := http.Header{}
|
|
headers.Set("X-API-Token", rawToken)
|
|
headers.Set("Origin", "https://allowed.example.com")
|
|
|
|
conn, resp, err := websocket.DefaultDialer.Dial(wsURL, headers)
|
|
if err != nil {
|
|
t.Fatalf("expected websocket connection, got %v", err)
|
|
}
|
|
if resp == nil || resp.StatusCode != http.StatusSwitchingProtocols {
|
|
t.Fatalf("expected 101 switching protocols, got %v", resp)
|
|
}
|
|
conn.Close()
|
|
}
|
|
|
|
func TestWebSocketOriginRejectedWhenNoAllowedOriginsAndPublicOrigin(t *testing.T) {
|
|
rawToken := "ws-origin-default-reject-123.12345678"
|
|
record := newTokenRecord(t, rawToken, []string{config.ScopeMonitoringRead}, nil)
|
|
|
|
server, cleanup := newWebSocketRouter(t, []string{}, record)
|
|
defer cleanup()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/ws?org_id=default"
|
|
headers := http.Header{}
|
|
headers.Set("X-API-Token", rawToken)
|
|
headers.Set("Origin", "https://evil.example.com")
|
|
|
|
conn, resp, err := websocket.DefaultDialer.Dial(wsURL, headers)
|
|
if err == nil {
|
|
conn.Close()
|
|
t.Fatalf("expected websocket origin rejection with empty allowed origins")
|
|
}
|
|
if resp == nil {
|
|
t.Fatalf("expected HTTP response for rejected origin")
|
|
}
|
|
if resp.StatusCode != http.StatusForbidden {
|
|
t.Fatalf("expected status %d, got %d", http.StatusForbidden, resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestWebSocketOriginAllowsPrivateWhenNoAllowedOrigins(t *testing.T) {
|
|
rawToken := "ws-origin-default-allow-123.12345678"
|
|
record := newTokenRecord(t, rawToken, []string{config.ScopeMonitoringRead}, nil)
|
|
|
|
server, cleanup := newWebSocketRouter(t, []string{}, record)
|
|
defer cleanup()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/ws?org_id=default"
|
|
headers := http.Header{}
|
|
headers.Set("X-API-Token", rawToken)
|
|
headers.Set("Origin", "http://localhost:3000")
|
|
|
|
conn, resp, err := websocket.DefaultDialer.Dial(wsURL, headers)
|
|
if err != nil {
|
|
t.Fatalf("expected websocket connection, got %v", err)
|
|
}
|
|
if resp == nil || resp.StatusCode != http.StatusSwitchingProtocols {
|
|
t.Fatalf("expected 101 switching protocols, got %v", resp)
|
|
}
|
|
conn.Close()
|
|
}
|
|
|
|
func TestWebSocketOriginAllowsTrustedForwardedHostedOrigin(t *testing.T) {
|
|
t.Setenv("PULSE_TRUSTED_PROXY_CIDRS", "127.0.0.1/32")
|
|
resetTrustedProxyCIDRsForTests()
|
|
|
|
rawToken := "ws-origin-forwarded-allow-123.12345678"
|
|
record := newTokenRecord(t, rawToken, []string{config.ScopeMonitoringRead}, nil)
|
|
|
|
server, cleanup := newWebSocketRouter(t, []string{}, record)
|
|
defer cleanup()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/ws?org_id=default"
|
|
headers := http.Header{}
|
|
headers.Set("X-API-Token", rawToken)
|
|
headers.Set("Origin", "https://tenant.example.com")
|
|
headers.Set("X-Forwarded-Proto", "https")
|
|
headers.Set("X-Forwarded-Host", "tenant.example.com")
|
|
|
|
conn, resp, err := websocket.DefaultDialer.Dial(wsURL, headers)
|
|
if err != nil {
|
|
t.Fatalf("expected websocket connection behind trusted proxy, got %v", err)
|
|
}
|
|
if resp == nil || resp.StatusCode != http.StatusSwitchingProtocols {
|
|
conn.Close()
|
|
t.Fatalf("expected 101 switching protocols, got %v", resp)
|
|
}
|
|
conn.Close()
|
|
}
|
|
|
|
func TestWebSocketOriginAllowsTrustedForwardedHostedOriginIPv6Loopback(t *testing.T) {
|
|
t.Setenv("PULSE_TRUSTED_PROXY_CIDRS", "::1/128")
|
|
resetTrustedProxyCIDRsForTests()
|
|
|
|
rawToken := "ws-origin-forwarded-ipv6-allow-123.12345678"
|
|
record := newTokenRecord(t, rawToken, []string{config.ScopeMonitoringRead}, nil)
|
|
|
|
server, cleanup := newWebSocketRouterIPv6(t, []string{}, record)
|
|
defer cleanup()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/ws?org_id=default"
|
|
headers := http.Header{}
|
|
headers.Set("X-API-Token", rawToken)
|
|
headers.Set("Origin", "https://tenant.example.com")
|
|
headers.Set("X-Forwarded-Proto", "https")
|
|
headers.Set("X-Forwarded-Host", "tenant.example.com")
|
|
|
|
conn, resp, err := websocket.DefaultDialer.Dial(wsURL, headers)
|
|
if err != nil {
|
|
t.Fatalf("expected websocket connection behind trusted IPv6 loopback proxy, got %v", err)
|
|
}
|
|
if resp == nil || resp.StatusCode != http.StatusSwitchingProtocols {
|
|
conn.Close()
|
|
t.Fatalf("expected 101 switching protocols, got %v", resp)
|
|
}
|
|
conn.Close()
|
|
}
|