// Source: Pulse/internal/relay/client.go (968 lines, 28 KiB, Go)
package relay
import (
"context"
"crypto/ecdh"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"math/rand/v2"
"net/http"
"os"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/rcourtman/pulse-go-rewrite/internal/utils"
"github.com/rs/zerolog"
)
// ErrNotConnected is returned when attempting to send on a disconnected client.
var ErrNotConnected = errors.New("relay client not connected")

const (
	// Reconnect backoff parameters.
	baseReconnectDelay = 5 * time.Second
	maxReconnectDelay  = 5 * time.Minute
	reconnectJitter    = 0.1

	// WebSocket parameters.
	wsPingInterval   = 25 * time.Second // how often writePump pings the relay
	wsWriteWait      = 10 * time.Second // per-write deadline on the socket
	wsHandshakeWait  = 15 * time.Second // dial handshake and REGISTER response deadline
	sendChBufferSize = 256              // outbound frame queue depth per connection

	// wsReadLimit bounds inbound WebSocket message size to a single relay frame.
	wsReadLimit = HeaderSize + MaxPayloadSize
)

// maxConcurrentDataHandlers limits active DATA stream handlers per connection.
// This prevents unbounded goroutine growth if the relay floods DATA frames.
// NOTE(review): declared as a var rather than a const, presumably so tests can
// override the limit — confirm before converting to const.
var maxConcurrentDataHandlers = 64

const (
	// wsMaxMessageSize aliases wsReadLimit; applied again after registration.
	wsMaxMessageSize = wsReadLimit
	// wsPongWait is the read deadline extended on every pong (and every read).
	wsPongWait = 60 * time.Second
	// proxyStreamTimeout is the safety-net ceiling on one proxied stream.
	proxyStreamTimeout = 15 * time.Minute
	// relayOverloadedReason is the error reason returned with 503 when all
	// DATA handler slots are busy.
	relayOverloadedReason = "relay proxy overloaded"
)
// TokenValidator reports whether an API token (presented in CHANNEL_OPEN) is
// valid.
type TokenValidator func(token string) bool

// channelState holds per-channel state including auth and encryption.
type channelState struct {
	apiToken   string             // validated API token from CHANNEL_OPEN
	encryption *ChannelEncryption // nil until key exchange completes
	ephemeral  *ecdh.PrivateKey   // ephemeral keypair, cleared after handshake
}
// ClientDeps holds injectable dependencies for the relay client.
type ClientDeps struct {
	LicenseTokenFunc   func() string  // returns the raw license JWT
	TokenValidator     TokenValidator // validates API tokens from CHANNEL_OPEN
	LocalAddr          string         // e.g. "127.0.0.1:7655"
	ServerVersion      string         // Pulse version for ClientVersion in REGISTER
	IdentityPubKey     string         // base64-encoded Ed25519 public key for MITM prevention
	IdentityPrivateKey string         // base64-encoded Ed25519 private key for signing KEY_EXCHANGE
}

// ClientStatus represents the current state of the relay client.
// It is the JSON-serializable snapshot returned by Client.Status.
type ClientStatus struct {
	Connected      bool   `json:"connected"`
	InstanceID     string `json:"instance_id,omitempty"`
	ActiveChannels int    `json:"active_channels"`
	LastError      string `json:"last_error,omitempty"`
	ReconnectIn    string `json:"reconnect_in,omitempty"` // human-readable countdown, e.g. "42s"
}
// Client maintains a persistent connection to the relay server.
//
// A single Client outlives many physical WebSocket connections: Run dials,
// registers, and reconnects with backoff, while per-connection state (conn,
// sendCh, channels) is swapped out under mu on each (re)connect.
type Client struct {
	config Config
	deps   ClientDeps
	proxy  *HTTPProxy // forwards proxied streams to the local server at deps.LocalAddr
	logger zerolog.Logger
	// startupErr captures invalid static config/dependency inputs.
	startupErr error
	// Connection state (protected by mu)
	mu           sync.RWMutex
	conn         *websocket.Conn
	sendCh       chan<- []byte // per-connection send channel (nil when disconnected)
	instanceID   string        // assigned by REGISTER_ACK
	sessionToken string        // used to resume the session on reconnect
	channels     map[uint32]*channelState // channelID → channel state
	connected    bool
	lastError    string
	nextRetryAt  time.Time // next scheduled reconnect attempt (zero when none)
	// Lifecycle
	lifecycleMu sync.Mutex
	cancel      context.CancelFunc // cancels the active Run loop
	done        chan struct{}      // closed when the active Run returns
}
// NewClient constructs a relay client from the given configuration and
// dependencies. Inputs are normalized first: normalization warnings are
// logged, and a fatal configuration error is recorded as startupErr so Run
// can surface it instead of attempting to connect.
func NewClient(cfg Config, deps ClientDeps, logger zerolog.Logger) *Client {
	normCfg, normDeps, warnings, startupErr := normalizeClientInputs(cfg, deps)
	for _, w := range warnings {
		logger.Warn().Str("warning", w).Msg("Normalized relay client configuration")
	}
	if startupErr != nil {
		logger.Error().Err(startupErr).Msg("Invalid relay client configuration")
	}
	client := &Client{
		config:     normCfg,
		deps:       normDeps,
		proxy:      NewHTTPProxy(normDeps.LocalAddr, logger),
		logger:     logger,
		channels:   map[uint32]*channelState{},
		startupErr: startupErr,
	}
	return client
}
// Run starts the reconnect loop. Blocks until ctx is cancelled.
//
// Each iteration dials the relay, registers, and services the connection until
// it drops. Failures retry with exponential backoff; a rejected session resume
// retries immediately with a fresh registration. Returns ctx.Err() on
// cancellation, or startupErr when the client was built from invalid
// configuration.
func (c *Client) Run(ctx context.Context) error {
	runCtx, runCancel := context.WithCancel(ctx)
	runDone := make(chan struct{})
	// Publish cancel/done so Close can stop this loop and wait for it to exit.
	c.lifecycleMu.Lock()
	c.cancel = runCancel
	c.done = runDone
	c.lifecycleMu.Unlock()
	defer func() {
		// Only clear the lifecycle handles if they are still ours; a
		// concurrent restart may have replaced them with a newer Run's.
		c.lifecycleMu.Lock()
		if c.done == runDone {
			c.done = nil
			c.cancel = nil
		}
		c.lifecycleMu.Unlock()
		close(runDone)
	}()
	ctx = runCtx
	// Invalid static configuration: record it for Status and bail out.
	if c.startupErr != nil {
		c.mu.Lock()
		c.lastError = c.startupErr.Error()
		c.connected = false
		c.nextRetryAt = time.Time{}
		c.mu.Unlock()
		return c.startupErr
	}
	consecutiveFailures := 0
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}
		connected, err := c.connectAndHandle(ctx)
		if err != nil {
			if ctx.Err() != nil {
				return ctx.Err()
			}
			// register() already cleared the cached session token on a
			// rejected resume; retry immediately without backoff so a fresh
			// registration happens promptly.
			if isSessionResumeRejected(err) {
				c.mu.Lock()
				c.lastError = err.Error()
				c.connected = false
				c.nextRetryAt = time.Time{}
				c.mu.Unlock()
				consecutiveFailures = 0
				c.logger.Warn().Err(err).Msg("relay session resume rejected, retrying fresh registration")
				continue
			}
			consecutiveFailures = nextConsecutiveFailures(consecutiveFailures, connected)
			c.mu.Lock()
			c.lastError = err.Error()
			c.connected = false
			c.nextRetryAt = time.Time{}
			c.mu.Unlock()
			delay := c.backoffDelay(consecutiveFailures)
			// Defensive: a non-positive delay would spin the loop.
			if delay <= 0 {
				c.logger.Warn().
					Int("failures", consecutiveFailures).
					Dur("computed_retry_in", delay).
					Msg("computed non-positive relay retry delay; using base reconnect delay")
				delay = baseReconnectDelay
			}
			// Escalate log level only after repeated failures so transient
			// blips stay quiet.
			if consecutiveFailures >= 3 {
				c.logger.Warn().Err(err).
					Int("failures", consecutiveFailures).
					Dur("retry_in", delay).
					Msg("relay connection failed repeatedly")
			} else {
				c.logger.Debug().Err(err).
					Dur("retry_in", delay).
					Msg("relay connection interrupted, reconnecting")
			}
			// If it's a license error, pause longer
			if isLicenseError(err) {
				delay = maxReconnectDelay
				c.logger.Warn().Msg("license error from relay server, pausing reconnect")
			}
			// Expose the retry deadline so Status can report a countdown.
			c.mu.Lock()
			c.nextRetryAt = time.Now().Add(delay)
			c.mu.Unlock()
			timer := time.NewTimer(delay)
			select {
			case <-ctx.Done():
				c.mu.Lock()
				c.nextRetryAt = time.Time{}
				c.mu.Unlock()
				// Drain the timer channel if Stop lost the race with firing.
				if !timer.Stop() {
					select {
					case <-timer.C:
					default:
					}
				}
				return ctx.Err()
			case <-timer.C:
				c.mu.Lock()
				c.nextRetryAt = time.Time{}
				c.mu.Unlock()
			}
		} else {
			consecutiveFailures = 0
		}
	}
}
// nextConsecutiveFailures advances the reconnect failure streak.
// A session that successfully registers is treated as a recovery point, so
// subsequent disconnects start a new streak at 1 instead of compounding old failures.
func nextConsecutiveFailures(current int, connected bool) int {
	if !connected {
		return current + 1
	}
	return 1
}
// Close stops the client, waits (up to five seconds) for Run to exit, and
// then shuts down the local HTTP proxy. Calling Close when Run was never
// started is a no-op.
func (c *Client) Close() {
	c.lifecycleMu.Lock()
	cancel, done := c.cancel, c.done
	c.lifecycleMu.Unlock()
	if cancel == nil {
		// Run was never started (or already fully cleaned up).
		return
	}
	cancel()
	// Bound the wait on Run's exit so Close can never hang indefinitely.
	deadline := time.NewTimer(5 * time.Second)
	defer deadline.Stop()
	select {
	case <-done:
	case <-deadline.C:
	}
	c.proxy.Close()
}
// Status reports a snapshot of the client's connection state, including a
// human-readable countdown until the next scheduled reconnect attempt.
func (c *Client) Status() ClientStatus {
	c.mu.RLock()
	defer c.mu.RUnlock()
	var reconnectIn string
	if !c.nextRetryAt.IsZero() {
		if remaining := time.Until(c.nextRetryAt); remaining > 0 {
			reconnectIn = remaining.Round(time.Second).String()
		}
	}
	status := ClientStatus{
		Connected:      c.connected,
		InstanceID:     c.instanceID,
		ActiveChannels: len(c.channels),
		LastError:      c.lastError,
		ReconnectIn:    reconnectIn,
	}
	return status
}
// SendPushNotification sends a push notification through the relay.
//
// The notification's InstanceID defaults to the registered instance ID when
// unset. Returns ErrNotConnected if the client has not completed registration
// with the relay server, and a non-nil error when the frame cannot be built
// or the send queue is full (the frame is dropped rather than blocking).
func (c *Client) SendPushNotification(notification PushNotificationPayload) error {
	c.mu.RLock()
	instanceID := c.instanceID
	ch := c.sendCh
	connected := c.connected
	c.mu.RUnlock()
	// Fail fast before doing any encoding work: sendCh is only published
	// after a successful REGISTER handshake (see connectAndHandle), so a nil
	// channel or !connected means the frame could never be delivered anyway.
	if ch == nil || !connected {
		return ErrNotConnected
	}
	if notification.InstanceID == "" {
		notification.InstanceID = instanceID
	}
	frame, err := NewControlFrame(FramePushNotification, 0, notification)
	if err != nil {
		return fmt.Errorf("build push frame: %w", err)
	}
	data, err := EncodeFrame(frame)
	if err != nil {
		return fmt.Errorf("encode push frame: %w", err)
	}
	// Non-blocking send: if the write pump is saturated, drop rather than
	// stall the caller.
	select {
	case ch <- data:
		return nil
	default:
		return fmt.Errorf("send channel full")
	}
}
// connectAndHandle establishes a relay session and handles frames until disconnect.
// The returned bool reports whether registration succeeded for this attempt.
func (c *Client) connectAndHandle(ctx context.Context) (bool, error) {
	dialer := websocket.Dialer{
		HandshakeTimeout: wsHandshakeWait,
	}
	tlsConfig, err := relayTLSConfig()
	if err != nil {
		return false, fmt.Errorf("build relay tls config: %w", err)
	}
	dialer.TLSClientConfig = tlsConfig
	c.logger.Info().Str("url", c.config.ServerURL).Msg("connecting to relay server")
	conn, _, err := dialer.DialContext(ctx, c.config.ServerURL, nil)
	if err != nil {
		return false, fmt.Errorf("dial relay: %w", err)
	}
	// Bound inbound message size before the REGISTER exchange.
	conn.SetReadLimit(wsReadLimit)
	// Per-connection send channel — no races because each writePump gets its own
	sendCh := make(chan []byte, sendChBufferSize)
	c.mu.Lock()
	c.conn = conn
	c.channels = make(map[uint32]*channelState)
	c.mu.Unlock()
	// Teardown: clear shared state first (so Status and SendPushNotification
	// observe the disconnect), then close the socket.
	defer func() {
		c.mu.Lock()
		instanceID := c.instanceID
		activeChannels := len(c.channels)
		c.sendCh = nil
		c.conn = nil
		c.channels = make(map[uint32]*channelState)
		c.connected = false
		c.mu.Unlock()
		conn.Close()
		c.logger.Info().
			Str("instance_id", instanceID).
			Int("active_channels", activeChannels).
			Msg("Relay connection closed")
	}()
	// Register with relay server
	if err := c.register(conn); err != nil {
		return false, fmt.Errorf("register: %w", err)
	}
	// Enforce connection liveness: each Pong extends the read deadline.
	// NOTE(review): wsMaxMessageSize aliases wsReadLimit, so this second
	// SetReadLimit is redundant with the one above — harmless, but could be
	// dropped.
	conn.SetReadLimit(int64(wsMaxMessageSize))
	_ = conn.SetReadDeadline(time.Now().Add(wsPongWait))
	conn.SetPongHandler(func(string) error {
		return conn.SetReadDeadline(time.Now().Add(wsPongWait))
	})
	// Expose sendCh only after successful registration so
	// SendPushNotification callers can't enqueue frames during
	// the handshake window before a relay session exists.
	c.mu.Lock()
	c.sendCh = sendCh
	c.connected = true
	c.lastError = ""
	c.nextRetryAt = time.Time{}
	c.mu.Unlock()
	c.logger.Info().Str("instance_id", c.instanceID).Msg("registered with relay server")
	// Per-connection context: cancelled when this connection ends (for any
	// reason), which tears down the write pump and any in-flight stream
	// goroutines spawned by handleData. Without this, stream goroutines
	// would keep running against a stale sendCh until the whole client stops.
	connCtx, connCancel := context.WithCancel(ctx)
	defer connCancel()
	// Force-close the socket on context cancellation so a blocked read
	// unblocks immediately.
	go func() {
		<-connCtx.Done()
		_ = conn.Close()
	}()
	go c.writePump(connCtx, conn, sendCh)
	// Read pump (blocking) — passes connCtx so handleData streams inherit it
	// Rate-limit concurrent DATA stream handlers per connection.
	dataLimiter := make(chan struct{}, maxConcurrentDataHandlers)
	return true, c.readPump(connCtx, conn, sendCh, dataLimiter)
}
// relayTLSConfig builds the TLS configuration used to dial the relay. By
// default it enforces TLS 1.2+ with system roots; when SSL_CERT_FILE points
// at a PEM bundle, those certificates are appended to the trust pool.
func relayTLSConfig() (*tls.Config, error) {
	cfg := &tls.Config{MinVersion: tls.VersionTLS12}
	bundlePath := strings.TrimSpace(os.Getenv("SSL_CERT_FILE"))
	if bundlePath == "" {
		// No custom CA bundle: rely on the system trust store.
		return cfg, nil
	}
	pemData, err := os.ReadFile(bundlePath)
	if err != nil {
		return nil, fmt.Errorf("read relay CA bundle %s: %w", bundlePath, err)
	}
	// Start from the system pool when available; fall back to an empty pool.
	pool, err := x509.SystemCertPool()
	if err != nil || pool == nil {
		pool = x509.NewCertPool()
	}
	if !pool.AppendCertsFromPEM(pemData) {
		return nil, fmt.Errorf("relay CA bundle %s does not contain any certificates", bundlePath)
	}
	cfg.RootCAs = pool
	return cfg, nil
}
// register performs the REGISTER handshake on a freshly dialed connection.
// It sends the license token (or resumes a prior session when possible) and
// waits, under a deadline, for REGISTER_ACK or ERROR. On success the assigned
// instance ID and session token are cached for future resumes; on a resume
// rejection the cached credentials are cleared and a typed error is returned
// so Run can retry immediately with a fresh registration.
func (c *Client) register(conn *websocket.Conn) error {
	if c.deps.LicenseTokenFunc == nil {
		return fmt.Errorf("license token provider not configured")
	}
	token := strings.TrimSpace(c.deps.LicenseTokenFunc())
	if token == "" {
		return fmt.Errorf("no license token available")
	}
	payload := RegisterPayload{
		LicenseToken:   token,
		InstanceHint:   c.config.InstanceSecret,
		ClientVersion:  c.deps.ServerVersion,
		IdentityPubKey: c.deps.IdentityPubKey,
	}
	attemptedSessionResume := false
	// Reuse session token if we have one from a previous connection.
	// When reconnecting, send the derived instance_id (from the prior
	// REGISTER_ACK) as InstanceHint instead of the raw secret. The
	// relay server's session reconnect path looks up by instance_id
	// directly — it needs the derived ID, not the secret used to
	// compute it during the initial registration.
	//
	// Only set SessionToken when we also have the derived instanceID,
	// because the bridge takes the session fast-path whenever both
	// SessionToken and InstanceHint are present. If SessionToken is
	// set without the derived ID, the bridge would query by raw secret
	// and fail with AUTH_FAILED (no license fallback).
	c.mu.RLock()
	if c.sessionToken != "" && c.instanceID != "" {
		payload.SessionToken = c.sessionToken
		payload.InstanceHint = c.instanceID
		attemptedSessionResume = true
	}
	c.mu.RUnlock()
	frame, err := NewControlFrame(FrameRegister, 0, payload)
	if err != nil {
		return fmt.Errorf("build register frame: %w", err)
	}
	data, err := EncodeFrame(frame)
	if err != nil {
		return fmt.Errorf("encode register frame: %w", err)
	}
	if err := conn.SetWriteDeadline(time.Now().Add(wsWriteWait)); err != nil {
		return fmt.Errorf("set write deadline: %w", err)
	}
	if err := conn.WriteMessage(websocket.BinaryMessage, data); err != nil {
		return fmt.Errorf("send register: %w", err)
	}
	// Wait for REGISTER_ACK or ERROR
	if err := conn.SetReadDeadline(time.Now().Add(wsHandshakeWait)); err != nil {
		return fmt.Errorf("set read deadline: %w", err)
	}
	_, msg, err := conn.ReadMessage()
	if err != nil {
		return fmt.Errorf("read register response: %w", err)
	}
	// Clear the handshake deadline; connectAndHandle installs the pong-based
	// liveness deadline after register returns.
	if err := conn.SetReadDeadline(time.Time{}); err != nil {
		return fmt.Errorf("clear read deadline: %w", err)
	}
	frame, err = DecodeFrame(msg)
	if err != nil {
		return fmt.Errorf("decode register response: %w", err)
	}
	switch frame.Type {
	case FrameRegisterAck:
		var ack RegisterAckPayload
		if err := UnmarshalControlPayload(frame.Payload, &ack); err != nil {
			return fmt.Errorf("unmarshal register ack: %w", err)
		}
		// Cache the assigned identity for session resume on reconnect.
		c.mu.Lock()
		c.instanceID = ack.InstanceID
		c.sessionToken = ack.SessionToken
		c.mu.Unlock()
		return nil
	case FrameError:
		var errPayload ErrorPayload
		if err := UnmarshalControlPayload(frame.Payload, &errPayload); err != nil {
			return fmt.Errorf("unmarshal error: %w", err)
		}
		// A failed resume with a stale-session error code invalidates the
		// cached credentials so the next attempt registers fresh.
		if attemptedSessionResume && shouldResetSessionAfterRegisterError(errPayload.Code) {
			c.mu.Lock()
			c.sessionToken = ""
			c.instanceID = ""
			c.mu.Unlock()
			return &sessionResumeRejectedError{code: errPayload.Code, message: errPayload.Message}
		}
		return fmt.Errorf("relay error (%s): %s", errPayload.Code, errPayload.Message)
	default:
		return fmt.Errorf("unexpected frame type during registration: %s", FrameTypeName(frame.Type))
	}
}
// readPump reads and dispatches inbound frames until the connection errors,
// the context is cancelled, or the server announces a drain. Control frames
// are handled inline on this goroutine; DATA frames hand off to background
// handlers bounded by dataLimiter. The return value drives reconnection:
// nil (DRAIN) and read errors both cause connectAndHandle to return.
func (c *Client) readPump(ctx context.Context, conn *websocket.Conn, sendCh chan<- []byte, dataLimiter chan struct{}) error {
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}
		_, msg, err := conn.ReadMessage()
		if err != nil {
			return fmt.Errorf("read message: %w", err)
		}
		// Any successful read counts as liveness, not just pongs.
		_ = conn.SetReadDeadline(time.Now().Add(wsPongWait))
		frame, err := DecodeFrame(msg)
		if err != nil {
			// A malformed frame is dropped rather than killing the session.
			c.logger.Warn().Err(err).Msg("failed to decode frame, skipping")
			continue
		}
		switch frame.Type {
		case FrameChannelOpen:
			c.handleChannelOpen(frame, sendCh)
		case FrameKeyExchange:
			c.handleKeyExchange(frame, sendCh)
		case FrameData:
			c.handleData(ctx, frame, sendCh, dataLimiter)
		case FrameChannelClose:
			c.handleChannelClose(frame)
		case FramePing:
			// Application-level ping: answer with a pong frame.
			queueFrame(sendCh, NewPongFrame(), c.logger)
		case FrameDrain:
			var drain DrainPayload
			if err := UnmarshalControlPayload(frame.Payload, &drain); err != nil {
				c.logger.Debug().Err(err).Msg("Failed to unmarshal DRAIN payload")
			} else {
				c.logger.Info().Str("reason", drain.Reason).Msg("Relay server draining, will reconnect")
			}
			return nil // exit readPump, triggers reconnect
		case FrameError:
			var errPayload ErrorPayload
			if err := UnmarshalControlPayload(frame.Payload, &errPayload); err != nil {
				c.logger.Warn().Err(err).Msg("Failed to unmarshal ERROR payload")
			} else {
				c.logger.Warn().Str("code", errPayload.Code).Str("message", errPayload.Message).Msg("Relay error")
				// License failures get a typed error so Run backs off longer.
				if errPayload.Code == ErrCodeLicenseInvalid || errPayload.Code == ErrCodeLicenseExpired {
					return &licenseError{code: errPayload.Code, message: errPayload.Message}
				}
			}
		default:
			c.logger.Debug().Str("type", FrameTypeName(frame.Type)).Msg("ignoring unhandled frame type")
		}
	}
}
// writePump serializes all outbound socket writes for one connection: queued
// frames from sendCh, periodic keepalive pings, and a normal-closure frame on
// context cancellation. Funneling every write through this goroutine gives
// gorilla/websocket the single concurrent writer it requires.
func (c *Client) writePump(ctx context.Context, conn *websocket.Conn, sendCh <-chan []byte) {
	// Always close the socket on writer exit so blocked readers unblock and reconnect.
	defer conn.Close()
	ticker := time.NewTicker(wsPingInterval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			// Send close message
			if err := conn.SetWriteDeadline(time.Now().Add(wsWriteWait)); err != nil {
				c.logger.Debug().Err(err).Msg("set write deadline failed")
			}
			if err := conn.WriteMessage(websocket.CloseMessage,
				websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
				c.logger.Debug().Err(err).Msg("WS close frame write failed")
			}
			return
		case data, ok := <-sendCh:
			if !ok {
				// Channel closed by the producer side; nothing more to write.
				return
			}
			if err := conn.SetWriteDeadline(time.Now().Add(wsWriteWait)); err != nil {
				c.logger.Debug().Err(err).Msg("set write deadline failed")
			}
			if err := conn.WriteMessage(websocket.BinaryMessage, data); err != nil {
				// Write failure tears down the connection (via the deferred
				// Close), which unblocks the read pump and triggers reconnect.
				c.logger.Debug().Err(err).Msg("write failed")
				return
			}
		case <-ticker.C:
			// Keepalive ping; the peer's pong extends the read deadline
			// (see the pong handler installed in connectAndHandle).
			if err := conn.SetWriteDeadline(time.Now().Add(wsWriteWait)); err != nil {
				c.logger.Debug().Err(err).Msg("set write deadline failed")
			}
			if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil {
				c.logger.Debug().Err(err).Msg("WS ping failed")
				return
			}
		}
	}
}
// handleChannelOpen processes a CHANNEL_OPEN request from the relay. The
// channel is accepted (and echoed back as an ack) only when its auth token
// validates against the configured TokenValidator; otherwise a CHANNEL_CLOSE
// carrying the rejection reason is queued back to the relay.
//
// Both rejection paths go through closeAndRemoveChannel so that failures to
// build the CHANNEL_CLOSE frame are logged consistently (previously the
// nil-validator branch dropped such errors silently). The map delete inside
// closeAndRemoveChannel is a harmless no-op here since the channel was never
// stored.
func (c *Client) handleChannelOpen(frame Frame, sendCh chan<- []byte) {
	var payload ChannelOpenPayload
	if err := UnmarshalControlPayload(frame.Payload, &payload); err != nil {
		c.logger.Warn().Err(err).Msg("failed to unmarshal CHANNEL_OPEN")
		return
	}
	// Fail closed when no validator is wired in: reject every channel.
	if c.deps.TokenValidator == nil {
		c.logger.Error().Uint32("channel", payload.ChannelID).Msg("Rejecting channel: token validator not configured")
		c.closeAndRemoveChannel(payload.ChannelID, "token validation unavailable", sendCh)
		return
	}
	payload.AuthToken = strings.TrimSpace(payload.AuthToken)
	// Validate the auth token
	if payload.AuthToken == "" || !c.deps.TokenValidator(payload.AuthToken) {
		c.logger.Warn().Uint32("channel", payload.ChannelID).Msg("Rejecting channel: invalid auth token")
		c.closeAndRemoveChannel(payload.ChannelID, "invalid auth token", sendCh)
		return
	}
	// Accept: record per-channel state keyed by channel ID. Encryption is
	// attached later by handleKeyExchange.
	c.mu.Lock()
	c.channels[payload.ChannelID] = &channelState{apiToken: payload.AuthToken}
	c.mu.Unlock()
	c.logger.Info().Uint32("channel", payload.ChannelID).Msg("channel opened")
	// Echo CHANNEL_OPEN to acknowledge
	ackFrame, err := NewControlFrame(FrameChannelOpen, payload.ChannelID, payload)
	if err != nil {
		c.logger.Warn().Err(err).Uint32("channel", payload.ChannelID).Msg("Failed to build CHANNEL_OPEN ack frame")
		return
	}
	queueFrame(sendCh, ackFrame, c.logger)
}
// handleData services one inbound DATA frame: decrypt inline on the read
// pump (to keep nonce ordering intact), then proxy the request to the local
// HTTP server in a background goroutine bounded by dataLimiter. When every
// handler slot is busy, the request is rejected with a 503 instead of
// blocking the read pump.
func (c *Client) handleData(connCtx context.Context, frame Frame, sendCh chan<- []byte, dataLimiter chan struct{}) {
	channelID := frame.Channel
	// Snapshot channel state under lock so the goroutine below doesn't race
	// with handleKeyExchange writing state.encryption.
	c.mu.RLock()
	state, ok := c.channels[channelID]
	var enc *ChannelEncryption
	var apiToken string
	if ok {
		enc = state.encryption
		apiToken = state.apiToken
	}
	c.mu.RUnlock()
	if !ok {
		c.logger.Warn().Uint32("channel", channelID).Msg("DATA for unknown channel")
		return
	}
	payload := frame.Payload
	// Preserve relay frame arrival order through decryption. The channel nonce
	// guard is strictly monotonic; decrypting in background goroutines lets
	// later frames win the race and falsely trip replay protection.
	if enc != nil {
		decrypted, err := enc.Decrypt(payload)
		if err != nil {
			c.logger.Warn().Err(err).Uint32("channel", channelID).Msg("failed to decrypt DATA payload")
			return
		}
		payload = decrypted
	}
	// Try to claim a handler slot without blocking; reject with 503 if full.
	select {
	case dataLimiter <- struct{}{}:
	default:
		c.handleOverloadedData(channelID, payload, enc, sendCh)
		return
	}
	// Handle in background goroutine so we don't block the read pump
	go func(payload []byte) {
		// Release the handler slot when the stream finishes.
		defer func() { <-dataLimiter }()
		// Derive from the connection context so streams are cancelled on disconnect.
		// The 15-minute timeout is a safety net for runaway streams.
		ctx, cancel := context.WithTimeout(connCtx, proxyStreamTimeout)
		defer cancel()
		err := c.proxy.HandleStreamRequest(ctx, payload, apiToken, func(respPayload []byte) {
			// Responses are encrypted with the same channel keys before
			// being queued back to the relay.
			if enc != nil {
				encrypted, err := enc.Encrypt(respPayload)
				if err != nil {
					c.logger.Warn().Err(err).Uint32("channel", channelID).Msg("failed to encrypt DATA response")
					return
				}
				respPayload = encrypted
			}
			respFrame := NewFrame(FrameData, channelID, respPayload)
			queueFrame(sendCh, respFrame, c.logger)
		})
		// Suppress errors caused purely by connection teardown.
		if err != nil && connCtx.Err() == nil {
			c.logger.Warn().Err(err).Uint32("channel", channelID).Msg("stream proxy error")
		}
	}(payload)
}
// handleOverloadedData replies to a DATA frame that arrived while every
// concurrent stream handler slot was occupied. It synthesizes a
// 503 Service Unavailable proxy response — encrypted when the channel has
// completed key exchange — and queues it back to the relay.
func (c *Client) handleOverloadedData(channelID uint32, payload []byte, enc *ChannelEncryption, sendCh chan<- []byte) {
	requestID := extractProxyRequestID(payload)
	resp := c.proxy.errorResponse(requestID, http.StatusServiceUnavailable, relayOverloadedReason)
	if enc != nil {
		sealed, err := enc.Encrypt(resp)
		if err != nil {
			c.logger.Warn().Err(err).Uint32("channel", channelID).Msg("Failed to encrypt overload response")
			return
		}
		resp = sealed
	}
	c.logger.Warn().
		Uint32("channel", channelID).
		Str("request_id", requestID).
		Int("max_in_flight", maxConcurrentDataHandlers).
		Msg("DATA handler limit reached, rejecting request")
	queueFrame(sendCh, NewFrame(FrameData, channelID, resp), c.logger)
}
// extractProxyRequestID best-effort parses the proxy request ID out of a raw
// DATA payload so overload rejections can be correlated with their request.
// Returns "" when the payload is not a valid JSON ProxyRequest.
func extractProxyRequestID(payload []byte) string {
	var req ProxyRequest
	if json.Unmarshal(payload, &req) != nil {
		return ""
	}
	return req.ID
}
// handleKeyExchange performs the instance side of the channel key exchange:
// it derives shared channel keys from the app's ephemeral X25519 public key,
// signs our ephemeral public key with the Ed25519 identity key (failing
// closed — the channel is torn down — when signing is unavailable), replies
// with a KEY_EXCHANGE frame, and stores the encryption state.
//
// Storing encryption state after queueing the response is safe because frames
// are dispatched sequentially on the read pump, so no DATA frame for this
// channel is processed until this handler returns.
func (c *Client) handleKeyExchange(frame Frame, sendCh chan<- []byte) {
	channelID := frame.Channel
	c.mu.RLock()
	state, ok := c.channels[channelID]
	c.mu.RUnlock()
	if !ok {
		c.logger.Warn().Uint32("channel", channelID).Msg("KEY_EXCHANGE for unknown channel")
		return
	}
	// Unmarshal the app's ephemeral public key
	appPubBytes, _, err := UnmarshalKeyExchangePayload(frame.Payload)
	if err != nil {
		c.logger.Warn().Err(err).Uint32("channel", channelID).Msg("failed to unmarshal KEY_EXCHANGE")
		return
	}
	appPubKey, err := ecdh.X25519().NewPublicKey(appPubBytes)
	if err != nil {
		c.logger.Warn().Err(err).Uint32("channel", channelID).Msg("invalid X25519 public key in KEY_EXCHANGE")
		return
	}
	// Generate instance's ephemeral X25519 keypair
	instancePriv, err := GenerateEphemeralKeyPair()
	if err != nil {
		c.logger.Error().Err(err).Uint32("channel", channelID).Msg("failed to generate ephemeral keypair")
		return
	}
	// Derive channel keys
	encryption, err := DeriveChannelKeys(instancePriv, appPubKey, true)
	if err != nil {
		c.logger.Error().Err(err).Uint32("channel", channelID).Msg("failed to derive channel keys")
		return
	}
	// Sign instance's ephemeral public key with Ed25519 identity key.
	// Fail closed: refuse key exchange if we can't sign (prevents unsigned/MITM-vulnerable channels).
	if c.deps.IdentityPrivateKey == "" {
		c.logger.Error().Uint32("channel", channelID).Msg("rejecting KEY_EXCHANGE: identity private key not configured")
		c.closeAndRemoveChannel(channelID, "key exchange signing unavailable", sendCh)
		return
	}
	instancePubBytes := instancePriv.PublicKey().Bytes()
	sig, err := SignKeyExchange(instancePubBytes, c.deps.IdentityPrivateKey)
	if err != nil {
		c.logger.Error().Err(err).Uint32("channel", channelID).Msg("failed to sign KEY_EXCHANGE")
		c.closeAndRemoveChannel(channelID, "key exchange signing failed", sendCh)
		return
	}
	// Send KEY_EXCHANGE response with instance public key + signature
	respPayload := MarshalKeyExchangePayload(instancePubBytes, sig)
	respFrame := NewFrame(FrameKeyExchange, channelID, respPayload)
	queueFrame(sendCh, respFrame, c.logger)
	// Store encryption state
	// NOTE(review): state.ephemeral is never populated here despite its
	// struct comment ("cleared after handshake") — confirm whether the field
	// is still used elsewhere.
	c.mu.Lock()
	state.encryption = encryption
	c.mu.Unlock()
	c.logger.Info().Uint32("channel", channelID).Msg("key exchange completed, channel encrypted")
}
// closeAndRemoveChannel drops the channel from local state (so no further
// DATA frames are processed for it) and notifies the peer with a
// CHANNEL_CLOSE frame carrying the given reason.
func (c *Client) closeAndRemoveChannel(channelID uint32, reason string, sendCh chan<- []byte) {
	c.mu.Lock()
	delete(c.channels, channelID)
	c.mu.Unlock()
	payload := ChannelClosePayload{ChannelID: channelID, Reason: reason}
	closeFrame, err := NewControlFrame(FrameChannelClose, channelID, payload)
	if err != nil {
		c.logger.Warn().Err(err).Uint32("channel", channelID).Msg("Failed to build CHANNEL_CLOSE frame")
		return
	}
	queueFrame(sendCh, closeFrame, c.logger)
}
// handleChannelClose removes the channel named in a CHANNEL_CLOSE frame from
// local state. When the payload fails to decode, the frame's own channel ID
// is used so the channel is still torn down.
func (c *Client) handleChannelClose(frame Frame) {
	var payload ChannelClosePayload
	if err := UnmarshalControlPayload(frame.Payload, &payload); err != nil {
		c.logger.Debug().Err(err).Uint32("channel", frame.Channel).Msg("Failed to unmarshal CHANNEL_CLOSE payload, using frame channel")
		payload.ChannelID = frame.Channel
	}
	c.mu.Lock()
	delete(c.channels, payload.ChannelID)
	c.mu.Unlock()
	c.logger.Info().Uint32("channel", payload.ChannelID).Str("reason", payload.Reason).Msg("channel closed")
}
// queueFrame encodes f and offers it to sendCh without blocking. Both an
// encode failure and a full send queue drop the frame with a structured
// warning rather than stalling the caller.
func queueFrame(sendCh chan<- []byte, f Frame, logger zerolog.Logger) {
	frameLogger := logger.With().
		Str("component", "relay_client").
		Str("frame_type", FrameTypeName(f.Type)).
		Uint32("channel", f.Channel).
		Int("payload_bytes", len(f.Payload)).
		Logger()
	encoded, err := EncodeFrame(f)
	if err != nil {
		frameLogger.Warn().
			Err(err).
			Str("action", "encode_frame").
			Msg("Failed to encode frame for send")
		return
	}
	select {
	case sendCh <- encoded:
	default:
		frameLogger.Warn().
			Str("action", "drop_frame").
			Int("send_queue_depth", len(sendCh)).
			Int("send_queue_capacity", cap(sendCh)).
			Msg("Send channel full, dropping frame")
	}
}
// backoffDelay computes the reconnect delay for the given consecutive
// failure count: exponential growth from baseReconnectDelay capped at
// maxReconnectDelay, with randomized jitter (fraction reconnectJitter)
// applied via utils.ExponentialBackoff.
func (c *Client) backoffDelay(failures int) time.Duration {
	return utils.ExponentialBackoff(baseReconnectDelay, maxReconnectDelay, failures, reconnectJitter, rand.Float64)
}
// licenseError is returned when the relay server rejects us due to license issues.
type licenseError struct {
	code    string
	message string
}

// Error implements the error interface.
func (e *licenseError) Error() string {
	return fmt.Sprintf("license error (%s): %s", e.code, e.message)
}

// isLicenseError reports whether err is (or wraps) a *licenseError.
// errors.As keeps detection working even when the error is wrapped with
// fmt.Errorf("...: %w", err), and matches the style of
// isSessionResumeRejected (the previous direct type assertion missed
// wrapped errors).
func isLicenseError(err error) bool {
	var target *licenseError
	return errors.As(err, &target)
}
// sessionResumeRejectedError indicates the relay refused to resume a prior
// session, so the client should retry with a fresh registration.
type sessionResumeRejectedError struct {
	code    string
	message string
}

// Error implements the error interface.
func (e *sessionResumeRejectedError) Error() string {
	return fmt.Sprintf("session resume rejected (%s): %s", e.code, e.message)
}

// isSessionResumeRejected reports whether err is (or wraps) a session-resume
// rejection.
func isSessionResumeRejected(err error) bool {
	var rejected *sessionResumeRejectedError
	return errors.As(err, &rejected)
}
// shouldResetSessionAfterRegisterError reports whether a REGISTER error code
// means the cached session credentials are stale and must be discarded
// before the next attempt.
func shouldResetSessionAfterRegisterError(code string) bool {
	switch code {
	case ErrCodeAuthFailed, ErrCodeNotFound:
		return true
	default:
		return false
	}
}