mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-05-02 13:30:13 +00:00
Export restricted outbound HTTP security helpers
This commit is contained in:
parent
de99fcb1f0
commit
9c8387be6f
8 changed files with 641 additions and 562 deletions
307
pkg/securityutil/httpurl.go
Normal file
307
pkg/securityutil/httpurl.go
Normal file
|
|
@ -0,0 +1,307 @@
|
|||
package securityutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const requestPlaceholderURL = "http://pulse.invalid"
|
||||
|
||||
func cloneURL(u *url.URL) *url.URL {
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *u
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func joinURLPath(basePath, relativePath string) string {
|
||||
parts := []string{basePath}
|
||||
if trimmed := strings.Trim(relativePath, "/"); trimmed != "" {
|
||||
parts = append(parts, trimmed)
|
||||
}
|
||||
|
||||
joined := path.Join(parts...)
|
||||
switch joined {
|
||||
case ".", "/":
|
||||
return ""
|
||||
default:
|
||||
if strings.HasPrefix(joined, "/") {
|
||||
return joined
|
||||
}
|
||||
return "/" + joined
|
||||
}
|
||||
}
|
||||
|
||||
// NormalizeAbsoluteHTTPURL validates a fully-qualified HTTP(S) URL.
|
||||
func NormalizeAbsoluteHTTPURL(raw string) (*url.URL, error) {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return nil, fmt.Errorf("URL is required")
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(trimmed)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid URL: %w", err)
|
||||
}
|
||||
if parsed.Scheme != "http" && parsed.Scheme != "https" {
|
||||
return nil, fmt.Errorf("URL scheme must be http or https")
|
||||
}
|
||||
if parsed.Host == "" {
|
||||
return nil, fmt.Errorf("URL host is required")
|
||||
}
|
||||
if parsed.User != nil {
|
||||
return nil, fmt.Errorf("URL userinfo is not allowed")
|
||||
}
|
||||
if parsed.Hostname() == "" {
|
||||
return nil, fmt.Errorf("URL hostname is required")
|
||||
}
|
||||
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
// NormalizeHTTPBaseURL validates a base HTTP(S) URL and optionally adds a default scheme.
|
||||
func NormalizeHTTPBaseURL(raw string, defaultScheme string) (*url.URL, error) {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return nil, fmt.Errorf("base URL is required")
|
||||
}
|
||||
if defaultScheme != "" && !strings.Contains(trimmed, "://") {
|
||||
trimmed = defaultScheme + "://" + trimmed
|
||||
}
|
||||
|
||||
parsed, err := NormalizeAbsoluteHTTPURL(trimmed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if parsed.RawQuery != "" || parsed.Fragment != "" {
|
||||
return nil, fmt.Errorf("base URL must not include query or fragment")
|
||||
}
|
||||
|
||||
cleanedPath := path.Clean(parsed.Path)
|
||||
switch cleanedPath {
|
||||
case ".", "/":
|
||||
parsed.Path = ""
|
||||
default:
|
||||
if !strings.HasPrefix(cleanedPath, "/") {
|
||||
cleanedPath = "/" + cleanedPath
|
||||
}
|
||||
parsed.Path = cleanedPath
|
||||
}
|
||||
parsed.RawPath = ""
|
||||
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
// IsLoopbackHost reports whether host resolves to localhost or a loopback IP literal.
|
||||
func IsLoopbackHost(host string) bool {
|
||||
normalized := strings.ToLower(strings.Trim(host, "[]"))
|
||||
if normalized == "" {
|
||||
return false
|
||||
}
|
||||
if normalized == "localhost" || strings.HasSuffix(normalized, ".localhost") {
|
||||
return true
|
||||
}
|
||||
|
||||
ip := net.ParseIP(normalized)
|
||||
return ip != nil && ip.IsLoopback()
|
||||
}
|
||||
|
||||
// NormalizePulseHTTPBaseURL validates a Pulse control-plane base URL.
|
||||
// HTTPS is required for non-loopback hosts; loopback localhost may use HTTP.
|
||||
func NormalizePulseHTTPBaseURL(raw string) (*url.URL, error) {
|
||||
return normalizePulseBaseURL(raw, false)
|
||||
}
|
||||
|
||||
// NormalizePulseWebSocketBaseURL validates a Pulse command-channel base URL.
|
||||
// Non-loopback hosts are normalized to WSS; loopback localhost may use WS.
|
||||
func NormalizePulseWebSocketBaseURL(raw string) (*url.URL, error) {
|
||||
return normalizePulseBaseURL(raw, true)
|
||||
}
|
||||
|
||||
func normalizePulseBaseURL(raw string, websocket bool) (*url.URL, error) {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return nil, fmt.Errorf("Pulse URL is required")
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(trimmed)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Pulse URL %q is invalid: %w", raw, err)
|
||||
}
|
||||
if parsed.Scheme == "" {
|
||||
if websocket {
|
||||
return nil, fmt.Errorf("Pulse URL %q must include scheme (https://, wss://, or loopback http:// / ws://)", raw)
|
||||
}
|
||||
return nil, fmt.Errorf("Pulse URL %q must include scheme (https:// or loopback http://)", raw)
|
||||
}
|
||||
if parsed.Host == "" || parsed.Hostname() == "" {
|
||||
return nil, fmt.Errorf("Pulse URL %q must include host", raw)
|
||||
}
|
||||
if parsed.User != nil {
|
||||
return nil, fmt.Errorf("Pulse URL %q must not include user credentials", raw)
|
||||
}
|
||||
if parsed.RawQuery != "" || parsed.Fragment != "" {
|
||||
return nil, fmt.Errorf("Pulse URL %q must not include query or fragment", raw)
|
||||
}
|
||||
|
||||
if port := parsed.Port(); port != "" {
|
||||
portNum, err := strconv.Atoi(port)
|
||||
if err != nil || portNum < 1 || portNum > 65535 {
|
||||
return nil, fmt.Errorf("invalid port %q: must be between 1 and 65535", port)
|
||||
}
|
||||
}
|
||||
|
||||
switch scheme := strings.ToLower(parsed.Scheme); scheme {
|
||||
case "https":
|
||||
if websocket {
|
||||
parsed.Scheme = "wss"
|
||||
} else {
|
||||
parsed.Scheme = "https"
|
||||
}
|
||||
case "http":
|
||||
if !IsLoopbackHost(parsed.Hostname()) {
|
||||
if websocket {
|
||||
return nil, fmt.Errorf("Pulse URL %q must use https/wss unless host is loopback", raw)
|
||||
}
|
||||
return nil, fmt.Errorf("Pulse URL %q must use https unless host is loopback", raw)
|
||||
}
|
||||
if websocket {
|
||||
parsed.Scheme = "ws"
|
||||
} else {
|
||||
parsed.Scheme = "http"
|
||||
}
|
||||
case "wss":
|
||||
if !websocket {
|
||||
return nil, fmt.Errorf("Pulse URL %q has unsupported scheme %q", raw, parsed.Scheme)
|
||||
}
|
||||
parsed.Scheme = "wss"
|
||||
case "ws":
|
||||
if !websocket {
|
||||
return nil, fmt.Errorf("Pulse URL %q has unsupported scheme %q", raw, parsed.Scheme)
|
||||
}
|
||||
if !IsLoopbackHost(parsed.Hostname()) {
|
||||
return nil, fmt.Errorf("Pulse URL %q must use https/wss unless host is loopback", raw)
|
||||
}
|
||||
parsed.Scheme = "ws"
|
||||
default:
|
||||
return nil, fmt.Errorf("Pulse URL %q has unsupported scheme %q", raw, parsed.Scheme)
|
||||
}
|
||||
|
||||
parsed.Host = strings.ToLower(parsed.Host)
|
||||
parsed.Path = strings.TrimRight(parsed.Path, "/")
|
||||
parsed.RawPath = strings.TrimRight(parsed.RawPath, "/")
|
||||
parsed.RawQuery = ""
|
||||
parsed.Fragment = ""
|
||||
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
// AppendURLPath appends path segments onto a validated base URL.
|
||||
func AppendURLPath(base *url.URL, segments ...string) *url.URL {
|
||||
cloned := cloneURL(base)
|
||||
if cloned == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
parts := []string{cloned.Path}
|
||||
for _, segment := range segments {
|
||||
trimmed := strings.Trim(segment, "/")
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
parts = append(parts, trimmed)
|
||||
}
|
||||
|
||||
joined := path.Join(parts...)
|
||||
if joined == "." || joined == "/" {
|
||||
cloned.Path = ""
|
||||
} else if strings.HasPrefix(joined, "/") {
|
||||
cloned.Path = joined
|
||||
} else {
|
||||
cloned.Path = "/" + joined
|
||||
}
|
||||
cloned.RawPath = ""
|
||||
cloned.Fragment = ""
|
||||
|
||||
return cloned
|
||||
}
|
||||
|
||||
// ResolveRelativeURL validates a rooted relative path and resolves it against base.
|
||||
func ResolveRelativeURL(base *url.URL, relativePath string) (*url.URL, error) {
|
||||
if base == nil {
|
||||
return nil, fmt.Errorf("base URL is required")
|
||||
}
|
||||
|
||||
trimmed := strings.TrimSpace(relativePath)
|
||||
if trimmed == "" {
|
||||
return nil, fmt.Errorf("relative path is required")
|
||||
}
|
||||
if strings.Contains(trimmed, `\`) {
|
||||
return nil, fmt.Errorf("relative path must not contain backslashes")
|
||||
}
|
||||
|
||||
ref, err := url.Parse(trimmed)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid relative path: %w", err)
|
||||
}
|
||||
if ref.IsAbs() || ref.Host != "" || ref.User != nil {
|
||||
return nil, fmt.Errorf("relative path must not include scheme or host")
|
||||
}
|
||||
if !strings.HasPrefix(ref.Path, "/") {
|
||||
return nil, fmt.Errorf("relative path must start with '/'")
|
||||
}
|
||||
|
||||
cleanedPath := path.Clean(ref.Path)
|
||||
if !strings.HasPrefix(cleanedPath, "/") {
|
||||
cleanedPath = "/" + cleanedPath
|
||||
}
|
||||
target := cloneURL(base)
|
||||
if target == nil {
|
||||
return nil, fmt.Errorf("base URL is required")
|
||||
}
|
||||
target.Path = joinURLPath(base.Path, cleanedPath)
|
||||
escapedPath := path.Clean(ref.EscapedPath())
|
||||
if !strings.HasPrefix(escapedPath, "/") {
|
||||
escapedPath = "/" + escapedPath
|
||||
}
|
||||
target.RawPath = joinURLPath(base.EscapedPath(), escapedPath)
|
||||
if target.RawPath == target.Path {
|
||||
target.RawPath = ""
|
||||
}
|
||||
target.RawQuery = ref.RawQuery
|
||||
target.Fragment = ""
|
||||
return target, nil
|
||||
}
|
||||
|
||||
// NewValidatedRequestWithContext builds an HTTP request from a pre-validated URL.
|
||||
func NewValidatedRequestWithContext(ctx context.Context, method string, target *url.URL, body io.Reader) (*http.Request, error) {
|
||||
if target == nil {
|
||||
return nil, fmt.Errorf("target URL is required")
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, requestPlaceholderURL, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.URL = cloneURL(target)
|
||||
req.Host = req.URL.Host
|
||||
req.RequestURI = ""
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// NewRelativeRequestWithContext validates a rooted relative path and builds a request from it.
|
||||
func NewRelativeRequestWithContext(ctx context.Context, method string, base *url.URL, relativePath string, body io.Reader) (*http.Request, error) {
|
||||
target, err := ResolveRelativeURL(base, relativePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewValidatedRequestWithContext(ctx, method, target, body)
|
||||
}
|
||||
240
pkg/securityutil/outbound_http.go
Normal file
240
pkg/securityutil/outbound_http.go
Normal file
|
|
@ -0,0 +1,240 @@
|
|||
package securityutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const defaultRestrictedRedirectLimit = 10
|
||||
|
||||
// RestrictedOutboundHTTPOptions controls outbound HTTP validation and transport policy.
|
||||
type RestrictedOutboundHTTPOptions struct {
|
||||
AllowedSchemes []string
|
||||
AllowPrivateIPs bool
|
||||
AllowLoopback bool
|
||||
TLSConfig *tls.Config
|
||||
ResolveIPAddrs func(ctx context.Context, host string) ([]net.IPAddr, error)
|
||||
}
|
||||
|
||||
var resolveOutboundFetchIPs = net.DefaultResolver.LookupIPAddr
|
||||
|
||||
func allowedOutboundSchemes(opts RestrictedOutboundHTTPOptions) []string {
|
||||
if len(opts.AllowedSchemes) == 0 {
|
||||
return []string{"http", "https"}
|
||||
}
|
||||
return opts.AllowedSchemes
|
||||
}
|
||||
|
||||
func isAllowedOutboundScheme(scheme string, allowed []string) bool {
|
||||
for _, candidate := range allowed {
|
||||
if strings.EqualFold(strings.TrimSpace(candidate), scheme) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func validateOutboundIP(ip net.IP, opts RestrictedOutboundHTTPOptions) error {
|
||||
if ip == nil {
|
||||
return fmt.Errorf("invalid IP address")
|
||||
}
|
||||
if ip.IsLoopback() && !opts.AllowLoopback {
|
||||
return fmt.Errorf("loopback addresses are not allowed")
|
||||
}
|
||||
if ip.Equal(net.ParseIP("169.254.169.254")) {
|
||||
return fmt.Errorf("metadata service address is not allowed")
|
||||
}
|
||||
if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
|
||||
return fmt.Errorf("link-local addresses are not allowed")
|
||||
}
|
||||
if ip.IsMulticast() {
|
||||
return fmt.Errorf("multicast addresses are not allowed")
|
||||
}
|
||||
if ip.IsUnspecified() {
|
||||
return fmt.Errorf("unspecified addresses are not allowed")
|
||||
}
|
||||
if !opts.AllowPrivateIPs && ip.IsPrivate() {
|
||||
return fmt.Errorf("private addresses are not allowed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func resolveOutboundIPAddrs(ctx context.Context, host string, opts RestrictedOutboundHTTPOptions) ([]net.IPAddr, error) {
|
||||
if resolver := opts.ResolveIPAddrs; resolver != nil {
|
||||
return resolver(ctx, host)
|
||||
}
|
||||
return resolveOutboundFetchIPs(ctx, host)
|
||||
}
|
||||
|
||||
func resolvePermittedOutboundIP(ctx context.Context, host string, opts RestrictedOutboundHTTPOptions) (net.IP, error) {
|
||||
host = strings.TrimSpace(host)
|
||||
if host == "" {
|
||||
return nil, fmt.Errorf("URL hostname is required")
|
||||
}
|
||||
|
||||
switch strings.ToLower(host) {
|
||||
case "metadata.google.internal", "metadata.goog":
|
||||
return nil, fmt.Errorf("metadata service host is not allowed")
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if err := validateOutboundIP(ip, opts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ip, nil
|
||||
}
|
||||
|
||||
baseCtx := ctx
|
||||
if baseCtx == nil {
|
||||
baseCtx = context.Background()
|
||||
}
|
||||
resolveCtx, cancel := context.WithTimeout(baseCtx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
addrs, err := resolveOutboundIPAddrs(resolveCtx, host, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve hostname %s: %w", host, err)
|
||||
}
|
||||
if len(addrs) == 0 {
|
||||
return nil, fmt.Errorf("hostname %s did not resolve", host)
|
||||
}
|
||||
|
||||
var blockedErr error
|
||||
for _, addr := range addrs {
|
||||
if err := validateOutboundIP(addr.IP, opts); err != nil {
|
||||
blockedErr = err
|
||||
continue
|
||||
}
|
||||
return addr.IP, nil
|
||||
}
|
||||
|
||||
if blockedErr != nil {
|
||||
return nil, fmt.Errorf("hostname %s resolves only to blocked addresses: %w", host, blockedErr)
|
||||
}
|
||||
return nil, fmt.Errorf("hostname %s did not resolve", host)
|
||||
}
|
||||
|
||||
// ValidateOutboundFetchURL validates a fully-qualified HTTP(S) URL against the restricted outbound policy.
|
||||
func ValidateOutboundFetchURL(ctx context.Context, raw string, opts RestrictedOutboundHTTPOptions) (*url.URL, error) {
|
||||
parsed, err := NormalizeAbsoluteHTTPURL(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
allowedSchemes := allowedOutboundSchemes(opts)
|
||||
if !isAllowedOutboundScheme(parsed.Scheme, allowedSchemes) {
|
||||
return nil, fmt.Errorf("URL scheme must be one of: %s", strings.Join(allowedSchemes, ", "))
|
||||
}
|
||||
if parsed.Fragment != "" {
|
||||
return nil, fmt.Errorf("URL fragments are not allowed")
|
||||
}
|
||||
|
||||
if _, err := resolvePermittedOutboundIP(ctx, parsed.Hostname(), opts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
func canonicalOriginHost(u *url.URL) string {
|
||||
if u == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
host := strings.ToLower(u.Hostname())
|
||||
port := u.Port()
|
||||
if port == "" {
|
||||
switch strings.ToLower(u.Scheme) {
|
||||
case "http":
|
||||
port = "80"
|
||||
case "https":
|
||||
port = "443"
|
||||
}
|
||||
}
|
||||
if host == "" || port == "" {
|
||||
return strings.ToLower(u.Host)
|
||||
}
|
||||
return net.JoinHostPort(host, port)
|
||||
}
|
||||
|
||||
func sameOriginRedirectPolicy(opts RestrictedOutboundHTTPOptions) func(req *http.Request, via []*http.Request) error {
|
||||
return func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) == 0 {
|
||||
return nil
|
||||
}
|
||||
if len(via) >= defaultRestrictedRedirectLimit {
|
||||
return fmt.Errorf("stopped after %d redirects", defaultRestrictedRedirectLimit)
|
||||
}
|
||||
|
||||
validated, err := ValidateOutboundFetchURL(req.Context(), req.URL.String(), opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
origin := via[0].URL
|
||||
if !strings.EqualFold(validated.Scheme, origin.Scheme) || canonicalOriginHost(validated) != canonicalOriginHost(origin) {
|
||||
return fmt.Errorf("redirects must stay on the same origin")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func cloneRestrictedTransport(tlsConfig *tls.Config) *http.Transport {
|
||||
transport, ok := http.DefaultTransport.(*http.Transport)
|
||||
var clone *http.Transport
|
||||
if ok && transport != nil {
|
||||
clone = transport.Clone()
|
||||
} else {
|
||||
clone = &http.Transport{Proxy: http.ProxyFromEnvironment}
|
||||
}
|
||||
|
||||
switch {
|
||||
case tlsConfig != nil:
|
||||
clone.TLSClientConfig = tlsConfig.Clone()
|
||||
case clone.TLSClientConfig != nil:
|
||||
clone.TLSClientConfig = clone.TLSClientConfig.Clone()
|
||||
default:
|
||||
clone.TLSClientConfig = &tls.Config{}
|
||||
}
|
||||
|
||||
if clone.TLSClientConfig.MinVersion < tls.VersionTLS12 {
|
||||
clone.TLSClientConfig.MinVersion = tls.VersionTLS12
|
||||
}
|
||||
|
||||
return clone
|
||||
}
|
||||
|
||||
// NewRestrictedOutboundHTTPClient returns an HTTP client that validates redirects and pins direct outbound dials
|
||||
// to the first permitted resolved IP for the requested host.
|
||||
func NewRestrictedOutboundHTTPClient(timeout time.Duration, opts RestrictedOutboundHTTPOptions) *http.Client {
|
||||
transport := cloneRestrictedTransport(opts.TLSConfig)
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse outbound address %q: %w", addr, err)
|
||||
}
|
||||
|
||||
permittedIP, err := resolvePermittedOutboundIP(ctx, host, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dialer := net.Dialer{Timeout: 10 * time.Second}
|
||||
return dialer.DialContext(ctx, network, net.JoinHostPort(permittedIP.String(), port))
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
CheckRedirect: sameOriginRedirectPolicy(opts),
|
||||
}
|
||||
if timeout > 0 {
|
||||
client.Timeout = timeout
|
||||
}
|
||||
return client
|
||||
}
|
||||
62
pkg/securityutil/websocket_origin.go
Normal file
62
pkg/securityutil/websocket_origin.go
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
package securityutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// NormalizeWebSocketOriginHost normalizes Origin/Host values for same-origin comparison.
|
||||
func NormalizeWebSocketOriginHost(host string) string {
|
||||
normalized := strings.TrimSpace(strings.ToLower(host))
|
||||
if normalized == "" {
|
||||
return normalized
|
||||
}
|
||||
|
||||
parsedHost, parsedPort, err := net.SplitHostPort(normalized)
|
||||
if err != nil {
|
||||
return normalized
|
||||
}
|
||||
if parsedPort == "80" || parsedPort == "443" {
|
||||
return parsedHost
|
||||
}
|
||||
return net.JoinHostPort(parsedHost, parsedPort)
|
||||
}
|
||||
|
||||
// SameHostWebSocketOrigin validates that an Origin header is http(s) and matches the request host.
|
||||
func SameHostWebSocketOrigin(origin string, requestHost string) bool {
|
||||
parsed, err := url.Parse(strings.TrimSpace(origin))
|
||||
if err != nil || parsed.Host == "" {
|
||||
return false
|
||||
}
|
||||
if parsed.Scheme != "http" && parsed.Scheme != "https" {
|
||||
return false
|
||||
}
|
||||
|
||||
return NormalizeWebSocketOriginHost(parsed.Host) == NormalizeWebSocketOriginHost(requestHost)
|
||||
}
|
||||
|
||||
// HTTPOriginForWebSocketBaseURL returns the http(s) Origin header for a Pulse websocket base URL.
|
||||
func HTTPOriginForWebSocketBaseURL(raw string) (string, error) {
|
||||
parsed, err := NormalizePulseWebSocketBaseURL(raw)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
switch parsed.Scheme {
|
||||
case "ws":
|
||||
parsed.Scheme = "http"
|
||||
case "wss":
|
||||
parsed.Scheme = "https"
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported websocket origin scheme %q", parsed.Scheme)
|
||||
}
|
||||
|
||||
parsed.Path = ""
|
||||
parsed.RawPath = ""
|
||||
parsed.RawQuery = ""
|
||||
parsed.Fragment = ""
|
||||
|
||||
return parsed.String(), nil
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue