package relay

import (
	"context"
	"crypto/ecdh"
	"crypto/tls"
	"crypto/x509"
	"encoding/json"
	"errors"
	"fmt"
	"math/rand/v2"
	"net/http"
	"os"
	"strings"
	"sync"
	"time"

	"github.com/gorilla/websocket"
	"github.com/rcourtman/pulse-go-rewrite/internal/utils"
	"github.com/rs/zerolog"
)

// ErrNotConnected is returned when attempting to send on a disconnected client.
var ErrNotConnected = errors.New("relay client not connected")

const (
	// Reconnect backoff parameters
	baseReconnectDelay = 5 * time.Second
	maxReconnectDelay  = 5 * time.Minute
	reconnectJitter    = 0.1

	// WebSocket parameters
	wsPingInterval   = 25 * time.Second
	wsWriteWait      = 10 * time.Second
	wsHandshakeWait  = 15 * time.Second
	sendChBufferSize = 256

	// wsReadLimit bounds inbound WebSocket message size to a single relay frame.
	wsReadLimit = HeaderSize + MaxPayloadSize
)

// maxConcurrentDataHandlers limits active DATA stream handlers per connection.
// This prevents unbounded goroutine growth if the relay floods DATA frames.
var maxConcurrentDataHandlers = 64

const (
	wsMaxMessageSize      = wsReadLimit
	wsPongWait            = 60 * time.Second
	proxyStreamTimeout    = 15 * time.Minute
	relayOverloadedReason = "relay proxy overloaded"
)
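
// Liveness invariant worth noting: wsPingInterval (25s) stays well below
// wsPongWait (60s), so each outbound ping can produce a pong that extends
// the read deadline before the previous deadline expires.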

// TokenValidator reports whether the provided API token is valid.
type TokenValidator func(token string) bool

// channelState holds per-channel state including auth and encryption.
type channelState struct {
	apiToken   string
	encryption *ChannelEncryption // nil until key exchange completes
	ephemeral  *ecdh.PrivateKey   // ephemeral keypair, cleared after handshake
}

// ClientDeps holds injectable dependencies for the relay client.
type ClientDeps struct {
	LicenseTokenFunc   func() string  // returns the raw license JWT
	TokenValidator     TokenValidator // validates API tokens from CHANNEL_OPEN
	LocalAddr          string         // e.g. "127.0.0.1:7655"
	ServerVersion      string         // Pulse version for ClientVersion in REGISTER
	IdentityPubKey     string         // base64-encoded Ed25519 public key for MITM prevention
	IdentityPrivateKey string         // base64-encoded Ed25519 private key for signing KEY_EXCHANGE
}

// ClientStatus represents the current state of the relay client.
type ClientStatus struct {
	Connected      bool   `json:"connected"`
	InstanceID     string `json:"instance_id,omitempty"`
	ActiveChannels int    `json:"active_channels"`
	LastError      string `json:"last_error,omitempty"`
	ReconnectIn    string `json:"reconnect_in,omitempty"`
}
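
// For orientation, a sketch of the JSON shape this struct serializes to
// (values are illustrative, not taken from a live instance):
//
//	{"connected":true,"instance_id":"inst-abc123","active_channels":2}
//
// When disconnected with a retry pending, last_error and reconnect_in (a
// rounded duration string such as "42s") are populated instead.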

// Client maintains a persistent connection to the relay server.
type Client struct {
	config Config
	deps   ClientDeps
	proxy  *HTTPProxy
	logger zerolog.Logger
	// startupErr captures invalid static config/dependency inputs.
	startupErr error

	// Connection state (protected by mu)
	mu           sync.RWMutex
	conn         *websocket.Conn
	sendCh       chan<- []byte // per-connection send channel (nil when disconnected)
	instanceID   string
	sessionToken string
	channels     map[uint32]*channelState // channelID → channel state
	connected    bool
	lastError    string
	nextRetryAt  time.Time

	// Lifecycle
	lifecycleMu sync.Mutex
	cancel      context.CancelFunc
	done        chan struct{}
}

// NewClient creates a new relay client.
func NewClient(cfg Config, deps ClientDeps, logger zerolog.Logger) *Client {
	cfg, deps, warnings, startupErr := normalizeClientInputs(cfg, deps)
	for _, warning := range warnings {
		logger.Warn().Str("warning", warning).Msg("Normalized relay client configuration")
	}
	if startupErr != nil {
		logger.Error().Err(startupErr).Msg("Invalid relay client configuration")
	}

	return &Client{
		config:     cfg,
		deps:       deps,
		proxy:      NewHTTPProxy(deps.LocalAddr, logger),
		logger:     logger,
		channels:   make(map[uint32]*channelState),
		startupErr: startupErr,
	}
}
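
// Illustrative wiring sketch; the caller-side names (cfg, licenseStore,
// auth, logger) are assumed placeholders, not part of this package:
//
//	client := NewClient(cfg, ClientDeps{
//		LicenseTokenFunc: licenseStore.Token,    // hypothetical provider
//		TokenValidator:   auth.ValidateAPIToken, // hypothetical validator
//		LocalAddr:        "127.0.0.1:7655",
//	}, logger)
//	go func() { _ = client.Run(ctx) }()
//	defer client.Close()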

// Run starts the reconnect loop. It blocks until ctx is cancelled, or returns
// immediately if the client was constructed with invalid configuration.
func (c *Client) Run(ctx context.Context) error {
	runCtx, runCancel := context.WithCancel(ctx)
	runDone := make(chan struct{})

	c.lifecycleMu.Lock()
	c.cancel = runCancel
	c.done = runDone
	c.lifecycleMu.Unlock()

	defer func() {
		c.lifecycleMu.Lock()
		if c.done == runDone {
			c.done = nil
			c.cancel = nil
		}
		c.lifecycleMu.Unlock()
		close(runDone)
	}()

	ctx = runCtx

	if c.startupErr != nil {
		c.mu.Lock()
		c.lastError = c.startupErr.Error()
		c.connected = false
		c.nextRetryAt = time.Time{}
		c.mu.Unlock()
		return c.startupErr
	}

	consecutiveFailures := 0

	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}

		connected, err := c.connectAndHandle(ctx)
		if err != nil {
			if ctx.Err() != nil {
				return ctx.Err()
			}

			if isSessionResumeRejected(err) {
				c.mu.Lock()
				c.lastError = err.Error()
				c.connected = false
				c.nextRetryAt = time.Time{}
				c.mu.Unlock()

				consecutiveFailures = 0
				c.logger.Warn().Err(err).Msg("relay session resume rejected, retrying fresh registration")
				continue
			}

			consecutiveFailures = nextConsecutiveFailures(consecutiveFailures, connected)
			c.mu.Lock()
			c.lastError = err.Error()
			c.connected = false
			c.nextRetryAt = time.Time{}
			c.mu.Unlock()

			delay := c.backoffDelay(consecutiveFailures)
			if delay <= 0 {
				c.logger.Warn().
					Int("failures", consecutiveFailures).
					Dur("computed_retry_in", delay).
					Msg("computed non-positive relay retry delay; using base reconnect delay")
				delay = baseReconnectDelay
			}

			if consecutiveFailures >= 3 {
				c.logger.Warn().Err(err).
					Int("failures", consecutiveFailures).
					Dur("retry_in", delay).
					Msg("relay connection failed repeatedly")
			} else {
				c.logger.Debug().Err(err).
					Dur("retry_in", delay).
					Msg("relay connection interrupted, reconnecting")
			}

			// If it's a license error, pause longer.
			if isLicenseError(err) {
				delay = maxReconnectDelay
				c.logger.Warn().Msg("license error from relay server, pausing reconnect")
			}

			c.mu.Lock()
			c.nextRetryAt = time.Now().Add(delay)
			c.mu.Unlock()

			timer := time.NewTimer(delay)
			select {
			case <-ctx.Done():
				c.mu.Lock()
				c.nextRetryAt = time.Time{}
				c.mu.Unlock()
				if !timer.Stop() {
					select {
					case <-timer.C:
					default:
					}
				}
				return ctx.Err()
			case <-timer.C:
				c.mu.Lock()
				c.nextRetryAt = time.Time{}
				c.mu.Unlock()
			}
		} else {
			consecutiveFailures = 0
		}
	}
}

// nextConsecutiveFailures advances the reconnect failure streak.
// A session that successfully registers is treated as a recovery point, so
// subsequent disconnects start a new streak at 1 instead of compounding old failures.
func nextConsecutiveFailures(current int, connected bool) int {
	if connected {
		return 1
	}
	return current + 1
}
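
// Example: for the attempt sequence fail, fail, connect-then-drop, fail,
// the streak evolves 1, 2, 1, 2: the successful registration in the third
// attempt resets the streak rather than compounding the earlier failures.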

// Close stops the client and closes the connection.
func (c *Client) Close() {
	c.lifecycleMu.Lock()
	cancel := c.cancel
	done := c.done
	c.lifecycleMu.Unlock()

	if cancel == nil {
		return
	}

	cancel()
	// Wait for Run to finish
	timer := time.NewTimer(5 * time.Second)
	defer timer.Stop()
	select {
	case <-done:
	case <-timer.C:
	}

	c.proxy.Close()
}

// Status returns the current client status.
func (c *Client) Status() ClientStatus {
	c.mu.RLock()
	defer c.mu.RUnlock()
	reconnectIn := ""
	if !c.nextRetryAt.IsZero() {
		remaining := time.Until(c.nextRetryAt)
		if remaining > 0 {
			reconnectIn = remaining.Round(time.Second).String()
		}
	}
	return ClientStatus{
		Connected:      c.connected,
		InstanceID:     c.instanceID,
		ActiveChannels: len(c.channels),
		LastError:      c.lastError,
		ReconnectIn:    reconnectIn,
	}
}

// SendPushNotification sends a push notification through the relay.
// Returns ErrNotConnected if the client has not completed registration
// with the relay server.
func (c *Client) SendPushNotification(notification PushNotificationPayload) error {
	c.mu.RLock()
	instanceID := c.instanceID
	ch := c.sendCh
	connected := c.connected
	c.mu.RUnlock()

	if notification.InstanceID == "" {
		notification.InstanceID = instanceID
	}

	frame, err := NewControlFrame(FramePushNotification, 0, notification)
	if err != nil {
		return fmt.Errorf("build push frame: %w", err)
	}
	data, err := EncodeFrame(frame)
	if err != nil {
		return fmt.Errorf("encode push frame: %w", err)
	}

	if ch == nil || !connected {
		return ErrNotConnected
	}

	select {
	case ch <- data:
		return nil
	default:
		return errors.New("send channel full")
	}
}
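
// Illustrative call site (only InstanceID is referenced in this file; the
// remaining PushNotificationPayload fields are defined elsewhere):
//
//	err := client.SendPushNotification(PushNotificationPayload{
//		InstanceID: "", // empty: filled in from the registered instance ID
//	})
//	if errors.Is(err, ErrNotConnected) {
//		// relay session not up yet; queue locally or drop
//	}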

// connectAndHandle establishes a relay session and handles frames until disconnect.
// The returned bool reports whether registration succeeded for this attempt.
func (c *Client) connectAndHandle(ctx context.Context) (bool, error) {
	dialer := websocket.Dialer{
		HandshakeTimeout: wsHandshakeWait,
	}
	tlsConfig, err := relayTLSConfig()
	if err != nil {
		return false, fmt.Errorf("build relay tls config: %w", err)
	}
	dialer.TLSClientConfig = tlsConfig

	c.logger.Info().Str("url", c.config.ServerURL).Msg("connecting to relay server")

	conn, _, err := dialer.DialContext(ctx, c.config.ServerURL, nil)
	if err != nil {
		return false, fmt.Errorf("dial relay: %w", err)
	}
	conn.SetReadLimit(wsReadLimit)

	// Per-connection send channel: no races, because each writePump gets its own channel.
	sendCh := make(chan []byte, sendChBufferSize)

	c.mu.Lock()
	c.conn = conn
	c.channels = make(map[uint32]*channelState)
	c.mu.Unlock()

	defer func() {
		c.mu.Lock()
		instanceID := c.instanceID
		activeChannels := len(c.channels)
		c.sendCh = nil
		c.conn = nil
		c.channels = make(map[uint32]*channelState)
		c.connected = false
		c.mu.Unlock()
		conn.Close()
		c.logger.Info().
			Str("instance_id", instanceID).
			Int("active_channels", activeChannels).
			Msg("Relay connection closed")
	}()

	// Register with relay server
	if err := c.register(conn); err != nil {
		return false, fmt.Errorf("register: %w", err)
	}

	// Enforce connection liveness: each Pong extends the read deadline.
	conn.SetReadLimit(int64(wsMaxMessageSize))
	_ = conn.SetReadDeadline(time.Now().Add(wsPongWait))
	conn.SetPongHandler(func(string) error {
		return conn.SetReadDeadline(time.Now().Add(wsPongWait))
	})

	// Expose sendCh only after successful registration so
	// SendPushNotification callers can't enqueue frames during
	// the handshake window before a relay session exists.
	c.mu.Lock()
	c.sendCh = sendCh
	c.connected = true
	c.lastError = ""
	c.nextRetryAt = time.Time{}
	c.mu.Unlock()

	c.logger.Info().Str("instance_id", c.instanceID).Msg("registered with relay server")

	// Per-connection context: cancelled when this connection ends (for any
	// reason), which tears down the write pump and any in-flight stream
	// goroutines spawned by handleData. Without this, stream goroutines
	// would keep running against a stale sendCh until the whole client stops.
	connCtx, connCancel := context.WithCancel(ctx)
	defer connCancel()
	go func() {
		<-connCtx.Done()
		_ = conn.Close()
	}()

	go c.writePump(connCtx, conn, sendCh)

	// Read pump (blocking); passes connCtx so handleData streams inherit it.
	// Rate-limit concurrent DATA stream handlers per connection.
	dataLimiter := make(chan struct{}, maxConcurrentDataHandlers)

	return true, c.readPump(connCtx, conn, sendCh, dataLimiter)
}
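
// Connection lifecycle, for orientation: dial, register (REGISTER and
// REGISTER_ACK), arm pong-based liveness, expose sendCh, start writePump,
// then block in readPump. Every exit path funnels through the deferred
// cleanup above, which nils out the shared state and closes the socket.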

func relayTLSConfig() (*tls.Config, error) {
	tlsConfig := &tls.Config{MinVersion: tls.VersionTLS12}

	caBundle := strings.TrimSpace(os.Getenv("SSL_CERT_FILE"))
	if caBundle == "" {
		return tlsConfig, nil
	}

	caData, err := os.ReadFile(caBundle)
	if err != nil {
		return nil, fmt.Errorf("read relay CA bundle %s: %w", caBundle, err)
	}

	pool, err := x509.SystemCertPool()
	if err != nil || pool == nil {
		pool = x509.NewCertPool()
	}
	if ok := pool.AppendCertsFromPEM(caData); !ok {
		return nil, fmt.Errorf("relay CA bundle %s does not contain any certificates", caBundle)
	}

	tlsConfig.RootCAs = pool
	return tlsConfig, nil
}
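
// Operational note: the CA override above is driven entirely by the standard
// SSL_CERT_FILE environment variable; a hypothetical invocation (path and
// binary name are illustrative) might look like:
//
//	SSL_CERT_FILE=/etc/pulse/relay-ca.pem ./pulse
//
// The bundle is appended to the system pool when one is available, so a
// private relay CA can coexist with the public roots.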

func (c *Client) register(conn *websocket.Conn) error {
	if c.deps.LicenseTokenFunc == nil {
		return fmt.Errorf("license token provider not configured")
	}
	token := strings.TrimSpace(c.deps.LicenseTokenFunc())
	if token == "" {
		return fmt.Errorf("no license token available")
	}

	payload := RegisterPayload{
		LicenseToken:   token,
		InstanceHint:   c.config.InstanceSecret,
		ClientVersion:  c.deps.ServerVersion,
		IdentityPubKey: c.deps.IdentityPubKey,
	}
	attemptedSessionResume := false

	// Reuse session token if we have one from a previous connection.
	// When reconnecting, send the derived instance_id (from the prior
	// REGISTER_ACK) as InstanceHint instead of the raw secret. The
	// relay server's session reconnect path looks up by instance_id
	// directly; it needs the derived ID, not the secret used to
	// compute it during the initial registration.
	//
	// Only set SessionToken when we also have the derived instanceID,
	// because the bridge takes the session fast-path whenever both
	// SessionToken and InstanceHint are present. If SessionToken is
	// set without the derived ID, the bridge would query by raw secret
	// and fail with AUTH_FAILED (no license fallback).
	c.mu.RLock()
	if c.sessionToken != "" && c.instanceID != "" {
		payload.SessionToken = c.sessionToken
		payload.InstanceHint = c.instanceID
		attemptedSessionResume = true
	}
	c.mu.RUnlock()

	frame, err := NewControlFrame(FrameRegister, 0, payload)
	if err != nil {
		return fmt.Errorf("build register frame: %w", err)
	}

	data, err := EncodeFrame(frame)
	if err != nil {
		return fmt.Errorf("encode register frame: %w", err)
	}

	if err := conn.SetWriteDeadline(time.Now().Add(wsWriteWait)); err != nil {
		return fmt.Errorf("set write deadline: %w", err)
	}
	if err := conn.WriteMessage(websocket.BinaryMessage, data); err != nil {
		return fmt.Errorf("send register: %w", err)
	}

	// Wait for REGISTER_ACK or ERROR
	if err := conn.SetReadDeadline(time.Now().Add(wsHandshakeWait)); err != nil {
		return fmt.Errorf("set read deadline: %w", err)
	}
	_, msg, err := conn.ReadMessage()
	if err != nil {
		return fmt.Errorf("read register response: %w", err)
	}
	if err := conn.SetReadDeadline(time.Time{}); err != nil {
		return fmt.Errorf("clear read deadline: %w", err)
	}

	frame, err = DecodeFrame(msg)
	if err != nil {
		return fmt.Errorf("decode register response: %w", err)
	}

	switch frame.Type {
	case FrameRegisterAck:
		var ack RegisterAckPayload
		if err := UnmarshalControlPayload(frame.Payload, &ack); err != nil {
			return fmt.Errorf("unmarshal register ack: %w", err)
		}
		c.mu.Lock()
		c.instanceID = ack.InstanceID
		c.sessionToken = ack.SessionToken
		c.mu.Unlock()
		return nil

	case FrameError:
		var errPayload ErrorPayload
		if err := UnmarshalControlPayload(frame.Payload, &errPayload); err != nil {
			return fmt.Errorf("unmarshal error: %w", err)
		}
		if attemptedSessionResume && shouldResetSessionAfterRegisterError(errPayload.Code) {
			c.mu.Lock()
			c.sessionToken = ""
			c.instanceID = ""
			c.mu.Unlock()
			return &sessionResumeRejectedError{code: errPayload.Code, message: errPayload.Message}
		}
		return fmt.Errorf("relay error (%s): %s", errPayload.Code, errPayload.Message)

	default:
		return fmt.Errorf("unexpected frame type during registration: %s", FrameTypeName(frame.Type))
	}
}

func (c *Client) readPump(ctx context.Context, conn *websocket.Conn, sendCh chan<- []byte, dataLimiter chan struct{}) error {
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}

		_, msg, err := conn.ReadMessage()
		if err != nil {
			return fmt.Errorf("read message: %w", err)
		}
		_ = conn.SetReadDeadline(time.Now().Add(wsPongWait))

		frame, err := DecodeFrame(msg)
		if err != nil {
			c.logger.Warn().Err(err).Msg("failed to decode frame, skipping")
			continue
		}

		switch frame.Type {
		case FrameChannelOpen:
			c.handleChannelOpen(frame, sendCh)

		case FrameKeyExchange:
			c.handleKeyExchange(frame, sendCh)

		case FrameData:
			c.handleData(ctx, frame, sendCh, dataLimiter)

		case FrameChannelClose:
			c.handleChannelClose(frame)

		case FramePing:
			queueFrame(sendCh, NewPongFrame(), c.logger)

		case FrameDrain:
			var drain DrainPayload
			if err := UnmarshalControlPayload(frame.Payload, &drain); err != nil {
				c.logger.Debug().Err(err).Msg("Failed to unmarshal DRAIN payload")
			} else {
				c.logger.Info().Str("reason", drain.Reason).Msg("Relay server draining, will reconnect")
			}
			return nil // exit readPump, triggers reconnect

		case FrameError:
			var errPayload ErrorPayload
			if err := UnmarshalControlPayload(frame.Payload, &errPayload); err != nil {
				c.logger.Warn().Err(err).Msg("Failed to unmarshal ERROR payload")
			} else {
				c.logger.Warn().Str("code", errPayload.Code).Str("message", errPayload.Message).Msg("Relay error")
				if errPayload.Code == ErrCodeLicenseInvalid || errPayload.Code == ErrCodeLicenseExpired {
					return &licenseError{code: errPayload.Code, message: errPayload.Message}
				}
			}

		default:
			c.logger.Debug().Str("type", FrameTypeName(frame.Type)).Msg("ignoring unhandled frame type")
		}
	}
}

func (c *Client) writePump(ctx context.Context, conn *websocket.Conn, sendCh <-chan []byte) {
	// Always close the socket on writer exit so blocked readers unblock and reconnect.
	defer conn.Close()

	ticker := time.NewTicker(wsPingInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			// Send close message
			if err := conn.SetWriteDeadline(time.Now().Add(wsWriteWait)); err != nil {
				c.logger.Debug().Err(err).Msg("set write deadline failed")
			}
			if err := conn.WriteMessage(websocket.CloseMessage,
				websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
				c.logger.Debug().Err(err).Msg("WS close frame write failed")
			}
			return

		case data, ok := <-sendCh:
			if !ok {
				return
			}
			if err := conn.SetWriteDeadline(time.Now().Add(wsWriteWait)); err != nil {
				c.logger.Debug().Err(err).Msg("set write deadline failed")
			}
			if err := conn.WriteMessage(websocket.BinaryMessage, data); err != nil {
				c.logger.Debug().Err(err).Msg("write failed")
				return
			}

		case <-ticker.C:
			if err := conn.SetWriteDeadline(time.Now().Add(wsWriteWait)); err != nil {
				c.logger.Debug().Err(err).Msg("set write deadline failed")
			}
			if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil {
				c.logger.Debug().Err(err).Msg("WS ping failed")
				return
			}
		}
	}
}

func (c *Client) handleChannelOpen(frame Frame, sendCh chan<- []byte) {
	var payload ChannelOpenPayload
	if err := UnmarshalControlPayload(frame.Payload, &payload); err != nil {
		c.logger.Warn().Err(err).Msg("failed to unmarshal CHANNEL_OPEN")
		return
	}

	if c.deps.TokenValidator == nil {
		c.logger.Error().Uint32("channel", payload.ChannelID).Msg("Rejecting channel: token validator not configured")
		closeFrame, err := NewControlFrame(FrameChannelClose, payload.ChannelID, ChannelClosePayload{
			ChannelID: payload.ChannelID,
			Reason:    "token validation unavailable",
		})
		if err != nil {
			c.logger.Warn().Err(err).Uint32("channel", payload.ChannelID).Msg("Failed to build CHANNEL_CLOSE frame")
		} else {
			queueFrame(sendCh, closeFrame, c.logger)
		}
		return
	}

	payload.AuthToken = strings.TrimSpace(payload.AuthToken)

	// Validate the auth token
	if payload.AuthToken == "" || !c.deps.TokenValidator(payload.AuthToken) {
		c.logger.Warn().Uint32("channel", payload.ChannelID).Msg("Rejecting channel: invalid auth token")
		closeFrame, err := NewControlFrame(FrameChannelClose, payload.ChannelID, ChannelClosePayload{
			ChannelID: payload.ChannelID,
			Reason:    "invalid auth token",
		})
		if err != nil {
			c.logger.Warn().Err(err).Uint32("channel", payload.ChannelID).Msg("Failed to build CHANNEL_CLOSE frame")
		} else {
			queueFrame(sendCh, closeFrame, c.logger)
		}
		return
	}

	// Accept: store channel and echo CHANNEL_OPEN back
	c.mu.Lock()
	c.channels[payload.ChannelID] = &channelState{apiToken: payload.AuthToken}
	c.mu.Unlock()

	c.logger.Info().Uint32("channel", payload.ChannelID).Msg("channel opened")

	// Echo CHANNEL_OPEN to acknowledge
	ackFrame, err := NewControlFrame(FrameChannelOpen, payload.ChannelID, payload)
	if err != nil {
		c.logger.Warn().Err(err).Uint32("channel", payload.ChannelID).Msg("Failed to build CHANNEL_OPEN ack frame")
	} else {
		queueFrame(sendCh, ackFrame, c.logger)
	}
}

func (c *Client) handleData(connCtx context.Context, frame Frame, sendCh chan<- []byte, dataLimiter chan struct{}) {
	channelID := frame.Channel

	// Snapshot channel state under lock so the goroutine below doesn't race
	// with handleKeyExchange writing state.encryption.
	c.mu.RLock()
	state, ok := c.channels[channelID]
	var enc *ChannelEncryption
	var apiToken string
	if ok {
		enc = state.encryption
		apiToken = state.apiToken
	}
	c.mu.RUnlock()

	if !ok {
		c.logger.Warn().Uint32("channel", channelID).Msg("DATA for unknown channel")
		return
	}

	payload := frame.Payload
	// Preserve relay frame arrival order through decryption. The channel nonce
	// guard is strictly monotonic; decrypting in background goroutines lets
	// later frames win the race and falsely trip replay protection.
	if enc != nil {
		decrypted, err := enc.Decrypt(payload)
		if err != nil {
			c.logger.Warn().Err(err).Uint32("channel", channelID).Msg("failed to decrypt DATA payload")
			return
		}
		payload = decrypted
	}

	select {
	case dataLimiter <- struct{}{}:
	default:
		c.handleOverloadedData(channelID, payload, enc, sendCh)
		return
	}

	// Handle in a background goroutine so we don't block the read pump.
	go func(payload []byte) {
		defer func() { <-dataLimiter }()

		// Derive from the connection context so streams are cancelled on disconnect.
		// The 15-minute timeout is a safety net for runaway streams.
		ctx, cancel := context.WithTimeout(connCtx, proxyStreamTimeout)
		defer cancel()

		err := c.proxy.HandleStreamRequest(ctx, payload, apiToken, func(respPayload []byte) {
			if enc != nil {
				encrypted, err := enc.Encrypt(respPayload)
				if err != nil {
					c.logger.Warn().Err(err).Uint32("channel", channelID).Msg("failed to encrypt DATA response")
					return
				}
				respPayload = encrypted
			}
			respFrame := NewFrame(FrameData, channelID, respPayload)
			queueFrame(sendCh, respFrame, c.logger)
		})
		if err != nil && connCtx.Err() == nil {
			c.logger.Warn().Err(err).Uint32("channel", channelID).Msg("stream proxy error")
		}
	}(payload)
}

func (c *Client) handleOverloadedData(channelID uint32, payload []byte, enc *ChannelEncryption, sendCh chan<- []byte) {
	requestID := extractProxyRequestID(payload)
	respPayload := c.proxy.errorResponse(requestID, http.StatusServiceUnavailable, relayOverloadedReason)
	if enc != nil {
		encrypted, err := enc.Encrypt(respPayload)
		if err != nil {
			c.logger.Warn().Err(err).Uint32("channel", channelID).Msg("Failed to encrypt overload response")
			return
		}
		respPayload = encrypted
	}

	c.logger.Warn().
		Uint32("channel", channelID).
		Str("request_id", requestID).
		Int("max_in_flight", maxConcurrentDataHandlers).
		Msg("DATA handler limit reached, rejecting request")
	queueFrame(sendCh, NewFrame(FrameData, channelID, respPayload), c.logger)
}

func extractProxyRequestID(payload []byte) string {
	var req ProxyRequest
	if err := json.Unmarshal(payload, &req); err != nil {
		return ""
	}
	return req.ID
}

func (c *Client) handleKeyExchange(frame Frame, sendCh chan<- []byte) {
	channelID := frame.Channel

	c.mu.RLock()
	state, ok := c.channels[channelID]
	c.mu.RUnlock()

	if !ok {
		c.logger.Warn().Uint32("channel", channelID).Msg("KEY_EXCHANGE for unknown channel")
		return
	}

	// Unmarshal the app's ephemeral public key
	appPubBytes, _, err := UnmarshalKeyExchangePayload(frame.Payload)
	if err != nil {
		c.logger.Warn().Err(err).Uint32("channel", channelID).Msg("failed to unmarshal KEY_EXCHANGE")
		return
	}

	appPubKey, err := ecdh.X25519().NewPublicKey(appPubBytes)
	if err != nil {
		c.logger.Warn().Err(err).Uint32("channel", channelID).Msg("invalid X25519 public key in KEY_EXCHANGE")
		return
	}

	// Generate instance's ephemeral X25519 keypair
	instancePriv, err := GenerateEphemeralKeyPair()
	if err != nil {
		c.logger.Error().Err(err).Uint32("channel", channelID).Msg("failed to generate ephemeral keypair")
		return
	}

	// Derive channel keys
	encryption, err := DeriveChannelKeys(instancePriv, appPubKey, true)
	if err != nil {
		c.logger.Error().Err(err).Uint32("channel", channelID).Msg("failed to derive channel keys")
		return
	}

	// Sign instance's ephemeral public key with Ed25519 identity key.
	// Fail closed: refuse key exchange if we can't sign (prevents unsigned/MITM-vulnerable channels).
	if c.deps.IdentityPrivateKey == "" {
		c.logger.Error().Uint32("channel", channelID).Msg("rejecting KEY_EXCHANGE: identity private key not configured")
		c.closeAndRemoveChannel(channelID, "key exchange signing unavailable", sendCh)
		return
	}

	instancePubBytes := instancePriv.PublicKey().Bytes()
	sig, err := SignKeyExchange(instancePubBytes, c.deps.IdentityPrivateKey)
	if err != nil {
		c.logger.Error().Err(err).Uint32("channel", channelID).Msg("failed to sign KEY_EXCHANGE")
		c.closeAndRemoveChannel(channelID, "key exchange signing failed", sendCh)
		return
	}

	// Send KEY_EXCHANGE response with instance public key + signature
	respPayload := MarshalKeyExchangePayload(instancePubBytes, sig)
	respFrame := NewFrame(FrameKeyExchange, channelID, respPayload)
	queueFrame(sendCh, respFrame, c.logger)

	// Store encryption state
	c.mu.Lock()
	state.encryption = encryption
	c.mu.Unlock()

	c.logger.Info().Uint32("channel", channelID).Msg("key exchange completed, channel encrypted")
}
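
// Handshake recap for the exchange above: the app sends its ephemeral X25519
// public key; this instance generates its own ephemeral keypair, derives the
// shared channel keys, signs its ephemeral public key with the long-lived
// Ed25519 identity key, and replies with public key plus signature. The app
// is expected to verify that signature against the instance's identity
// public key; how it obtains and trusts that key is outside this file.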

// closeAndRemoveChannel sends CHANNEL_CLOSE to the peer and removes the
// channel locally so no further DATA frames are processed.
func (c *Client) closeAndRemoveChannel(channelID uint32, reason string, sendCh chan<- []byte) {
	c.mu.Lock()
	delete(c.channels, channelID)
	c.mu.Unlock()

	closeFrame, err := NewControlFrame(FrameChannelClose, channelID, ChannelClosePayload{
		ChannelID: channelID,
		Reason:    reason,
	})
	if err != nil {
		c.logger.Warn().Err(err).Uint32("channel", channelID).Msg("Failed to build CHANNEL_CLOSE frame")
	} else {
		queueFrame(sendCh, closeFrame, c.logger)
	}
}

func (c *Client) handleChannelClose(frame Frame) {
	var payload ChannelClosePayload
	if err := UnmarshalControlPayload(frame.Payload, &payload); err != nil {
		c.logger.Debug().Err(err).Uint32("channel", frame.Channel).Msg("Failed to unmarshal CHANNEL_CLOSE payload, using frame channel")
		// Fall back to using frame channel ID
		payload.ChannelID = frame.Channel
	}

	c.mu.Lock()
	delete(c.channels, payload.ChannelID)
	c.mu.Unlock()

	c.logger.Info().Uint32("channel", payload.ChannelID).Str("reason", payload.Reason).Msg("channel closed")
}

// queueFrame encodes and sends a frame to the send channel (non-blocking).
func queueFrame(sendCh chan<- []byte, f Frame, logger zerolog.Logger) {
	frameLog := logger.With().
		Str("component", "relay_client").
		Str("frame_type", FrameTypeName(f.Type)).
		Uint32("channel", f.Channel).
		Int("payload_bytes", len(f.Payload)).
		Logger()

	data, err := EncodeFrame(f)
	if err != nil {
		frameLog.Warn().
			Err(err).
			Str("action", "encode_frame").
			Msg("Failed to encode frame for send")
		return
	}

	select {
	case sendCh <- data:
	default:
		frameLog.Warn().
			Str("action", "drop_frame").
			Int("send_queue_depth", len(sendCh)).
			Int("send_queue_capacity", cap(sendCh)).
			Msg("Send channel full, dropping frame")
	}
}

func (c *Client) backoffDelay(failures int) time.Duration {
	return utils.ExponentialBackoff(baseReconnectDelay, maxReconnectDelay, failures, reconnectJitter, rand.Float64)
}
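
// Rough reconnect schedule, assuming utils.ExponentialBackoff doubles the
// base delay per failure before capping and jittering (the exact curve lives
// in internal/utils): roughly 5s, 10s, 20s, 40s, 1m20s, 2m40s, then pinned
// at 5m, each sample within about ±10% jitter.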

// licenseError is returned when the relay server rejects us due to license issues.
type licenseError struct {
	code    string
	message string
}

func (e *licenseError) Error() string {
	return fmt.Sprintf("license error (%s): %s", e.code, e.message)
}

func isLicenseError(err error) bool {
	var target *licenseError
	return errors.As(err, &target)
}

type sessionResumeRejectedError struct {
	code    string
	message string
}

func (e *sessionResumeRejectedError) Error() string {
	return fmt.Sprintf("session resume rejected (%s): %s", e.code, e.message)
}

func isSessionResumeRejected(err error) bool {
	var target *sessionResumeRejectedError
	return errors.As(err, &target)
}

func shouldResetSessionAfterRegisterError(code string) bool {
	return code == ErrCodeAuthFailed || code == ErrCodeNotFound
}