Pulse/internal/relay/client_test.go

1772 lines
49 KiB
Go

package relay
import (
"bytes"
"context"
"crypto/cipher"
"crypto/ecdh"
"crypto/x509"
"encoding/json"
"encoding/pem"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"sync"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/rs/zerolog"
)
var testUpgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
}
// mockRelayServer creates an httptest.Server that speaks the relay protocol.
// It returns the server and a channel to receive the instance-side connection.
func mockRelayServer(t *testing.T, handler func(conn *websocket.Conn)) *httptest.Server {
t.Helper()
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := testUpgrader.Upgrade(w, r, nil)
if err != nil {
t.Logf("upgrade error: %v", err)
return
}
defer conn.Close()
handler(conn)
}))
}
func wsURL(server *httptest.Server) string {
return "ws" + strings.TrimPrefix(server.URL, "http")
}
func wssURL(server *httptest.Server) string {
return "wss" + strings.TrimPrefix(server.URL, "https")
}
func TestHTTPProxyTrimsRequestBoundaryWhitespace(t *testing.T) {
logger := zerolog.New(zerolog.NewTestWriter(t))
t.Run("handle request trims method and path whitespace", func(t *testing.T) {
mockAPI := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
t.Fatalf("method: got %q, want %q", r.Method, http.MethodGet)
}
if r.URL.Path != "/api/resources" {
t.Fatalf("path: got %q, want %q", r.URL.Path, "/api/resources")
}
if token := r.Header.Get("X-API-Token"); token != "test-token" {
t.Fatalf("token: got %q, want %q", token, "test-token")
}
if header := r.Header.Get("Content-Type"); header != "application/json" {
t.Fatalf("content-type: got %q, want %q", header, "application/json")
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
}))
defer mockAPI.Close()
proxy := NewHTTPProxy(strings.TrimPrefix(mockAPI.URL, "http://"), logger)
payload, _ := json.Marshal(ProxyRequest{
ID: "req_trim",
Method: " GET ",
Path: " /api/resources ",
Headers: map[string]string{
" Content-Type ": "application/json",
},
})
respPayload, err := proxy.HandleRequest(payload, "test-token")
if err != nil {
t.Fatalf("HandleRequest() error = %v", err)
}
var resp ProxyResponse
if err := json.Unmarshal(respPayload, &resp); err != nil {
t.Fatalf("unmarshal response: %v", err)
}
if resp.Status != http.StatusOK {
t.Fatalf("status: got %d, want %d", resp.Status, http.StatusOK)
}
})
t.Run("handle stream request trims method and path whitespace", func(t *testing.T) {
streamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Fatalf("method: got %q, want %q", r.Method, http.MethodPost)
}
if r.URL.Path != "/api/ai/chat" {
t.Fatalf("path: got %q, want %q", r.URL.Path, "/api/ai/chat")
}
if token := r.Header.Get("X-API-Token"); token != "test-token" {
t.Fatalf("token: got %q, want %q", token, "test-token")
}
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
fmt.Fprint(w, "data: done\n\n")
if flusher, ok := w.(http.Flusher); ok {
flusher.Flush()
}
}))
defer streamServer.Close()
proxy := NewHTTPProxy(strings.TrimPrefix(streamServer.URL, "http://"), logger)
payload, _ := json.Marshal(ProxyRequest{
ID: "stream_trim",
Method: " POST ",
Path: " /api/ai/chat ",
})
var frames [][]byte
err := proxy.HandleStreamRequest(context.Background(), payload, "test-token", func(data []byte) {
frames = append(frames, append([]byte(nil), data...))
})
if err != nil {
t.Fatalf("HandleStreamRequest() error = %v", err)
}
if len(frames) == 0 {
t.Fatal("expected at least one stream frame")
}
})
}
func writeServerCertPEM(t *testing.T, server *httptest.Server) string {
t.Helper()
if len(server.TLS.Certificates) == 0 || len(server.TLS.Certificates[0].Certificate) == 0 {
t.Fatal("test server missing TLS certificate")
}
block := &pem.Block{Type: "CERTIFICATE", Bytes: server.TLS.Certificates[0].Certificate[0]}
if _, err := x509.ParseCertificate(block.Bytes); err != nil {
t.Fatalf("parse test certificate: %v", err)
}
path := filepath.Join(t.TempDir(), "relay-ca.pem")
if err := os.WriteFile(path, pem.EncodeToMemory(block), 0600); err != nil {
t.Fatalf("write relay CA bundle: %v", err)
}
return path
}
type blockingTestAEAD struct {
openStarted chan struct{}
releaseOpen chan struct{}
}
func (a *blockingTestAEAD) NonceSize() int { return nonceSize }
func (a *blockingTestAEAD) Overhead() int { return 0 }
func (a *blockingTestAEAD) Seal(_ []byte, _ []byte, plaintext []byte, _ []byte) []byte {
return append([]byte(nil), plaintext...)
}
func (a *blockingTestAEAD) Open(_ []byte, _ []byte, ciphertext []byte, _ []byte) ([]byte, error) {
select {
case <-a.openStarted:
default:
close(a.openStarted)
}
<-a.releaseOpen
return append([]byte(nil), ciphertext...), nil
}
var _ cipher.AEAD = (*blockingTestAEAD)(nil)
func TestClient_RegisterAndChannelLifecycle(t *testing.T) {
logger := zerolog.New(zerolog.NewTestWriter(t))
// Track what the mock relay receives
var mu sync.Mutex
var registerReceived bool
var channelOpenAckReceived bool
dataResponseCh := make(chan ProxyResponse, 1)
// Mock local Pulse API
mockAPI := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]string{"path": r.URL.Path})
}))
defer mockAPI.Close()
localAddr := strings.TrimPrefix(mockAPI.URL, "http://")
// Mock relay server
relayServer := mockRelayServer(t, func(conn *websocket.Conn) {
// 1. Read REGISTER
_, msg, err := conn.ReadMessage()
if err != nil {
t.Logf("read register: %v", err)
return
}
frame, err := DecodeFrame(msg)
if err != nil {
t.Logf("decode register: %v", err)
return
}
if frame.Type != FrameRegister {
t.Logf("expected REGISTER, got %s", FrameTypeName(frame.Type))
return
}
var regPayload RegisterPayload
_ = UnmarshalControlPayload(frame.Payload, &regPayload)
mu.Lock()
registerReceived = regPayload.LicenseToken == "test-license-jwt" &&
regPayload.IdentityPubKey == "test-identity-pub-key"
mu.Unlock()
// 2. Send REGISTER_ACK
ack, _ := NewControlFrame(FrameRegisterAck, 0, RegisterAckPayload{
InstanceID: "inst_abc",
SessionToken: "sess_xyz",
ExpiresAt: time.Now().Add(time.Hour).Unix(),
})
ackBytes, _ := EncodeFrame(ack)
_ = conn.WriteMessage(websocket.BinaryMessage, ackBytes)
// 3. Send CHANNEL_OPEN
chOpen, _ := NewControlFrame(FrameChannelOpen, 1, ChannelOpenPayload{
ChannelID: 1,
AuthToken: "valid-api-token",
})
chOpenBytes, _ := EncodeFrame(chOpen)
time.Sleep(50 * time.Millisecond) // let client set up
_ = conn.WriteMessage(websocket.BinaryMessage, chOpenBytes)
// 4. Read CHANNEL_OPEN ack from instance
_, msg, err = conn.ReadMessage()
if err != nil {
t.Logf("read channel open ack: %v", err)
return
}
frame, _ = DecodeFrame(msg)
mu.Lock()
channelOpenAckReceived = frame.Type == FrameChannelOpen
mu.Unlock()
// 5. Send DATA request
proxyReq := ProxyRequest{
ID: "req_test",
Method: "GET",
Path: "/api/status",
}
proxyReqBytes, _ := json.Marshal(proxyReq)
dataFrame := NewFrame(FrameData, 1, proxyReqBytes)
dataBytes, _ := EncodeFrame(dataFrame)
_ = conn.WriteMessage(websocket.BinaryMessage, dataBytes)
// 6. Read DATA response
_, msg, err = conn.ReadMessage()
if err != nil {
t.Logf("read data response: %v", err)
return
}
frame, _ = DecodeFrame(msg)
if frame.Type == FrameData {
var resp ProxyResponse
_ = json.Unmarshal(frame.Payload, &resp)
dataResponseCh <- resp
}
// 7. Send CHANNEL_CLOSE
chClose, _ := NewControlFrame(FrameChannelClose, 1, ChannelClosePayload{
ChannelID: 1,
Reason: "test done",
})
chCloseBytes, _ := EncodeFrame(chClose)
_ = conn.WriteMessage(websocket.BinaryMessage, chCloseBytes)
// Keep connection open so the client stays connected during assertions
time.Sleep(2 * time.Second)
})
defer relayServer.Close()
cfg := Config{
Enabled: true,
ServerURL: wsURL(relayServer),
}
deps := ClientDeps{
LicenseTokenFunc: func() string { return "test-license-jwt" },
TokenValidator: func(token string) bool { return token == "valid-api-token" },
LocalAddr: localAddr,
ServerVersion: "1.0.0-test",
IdentityPubKey: "test-identity-pub-key",
}
client := NewClient(cfg, deps, logger)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// Run in background
errCh := make(chan error, 1)
go func() {
errCh <- client.Run(ctx)
}()
// Wait for DATA response
select {
case resp := <-dataResponseCh:
if resp.ID != "req_test" {
t.Errorf("response ID: got %q, want %q", resp.ID, "req_test")
}
if resp.Status != 200 {
t.Errorf("response status: got %d, want 200", resp.Status)
}
case <-time.After(3 * time.Second):
t.Fatal("timed out waiting for DATA response")
}
// Verify state
mu.Lock()
if !registerReceived {
t.Error("REGISTER not received or had wrong license token")
}
if !channelOpenAckReceived {
t.Error("CHANNEL_OPEN ack not received")
}
mu.Unlock()
// Wait a bit for CHANNEL_CLOSE to be processed
time.Sleep(200 * time.Millisecond)
status := client.Status()
if !status.Connected {
t.Error("expected client to be connected")
}
if status.InstanceID != "inst_abc" {
t.Errorf("instance_id: got %q, want %q", status.InstanceID, "inst_abc")
}
cancel()
select {
case <-errCh:
case <-time.After(3 * time.Second):
t.Fatal("client.Run didn't return after cancel")
}
}
func TestClient_RejectInvalidToken(t *testing.T) {
logger := zerolog.New(zerolog.NewTestWriter(t))
channelCloseCh := make(chan ChannelClosePayload, 1)
relayServer := mockRelayServer(t, func(conn *websocket.Conn) {
// Read REGISTER
_, msg, _ := conn.ReadMessage()
frame, _ := DecodeFrame(msg)
if frame.Type != FrameRegister {
return
}
// Send REGISTER_ACK
ack, _ := NewControlFrame(FrameRegisterAck, 0, RegisterAckPayload{
InstanceID: "inst_abc",
SessionToken: "sess_xyz",
ExpiresAt: time.Now().Add(time.Hour).Unix(),
})
ackBytes, _ := EncodeFrame(ack)
_ = conn.WriteMessage(websocket.BinaryMessage, ackBytes)
// Send CHANNEL_OPEN with bad token
time.Sleep(50 * time.Millisecond)
chOpen, _ := NewControlFrame(FrameChannelOpen, 99, ChannelOpenPayload{
ChannelID: 99,
AuthToken: "INVALID-TOKEN",
})
chOpenBytes, _ := EncodeFrame(chOpen)
_ = conn.WriteMessage(websocket.BinaryMessage, chOpenBytes)
// Read CHANNEL_CLOSE (reject)
_, msg, err := conn.ReadMessage()
if err != nil {
return
}
frame, _ = DecodeFrame(msg)
if frame.Type == FrameChannelClose {
var closePayload ChannelClosePayload
_ = UnmarshalControlPayload(frame.Payload, &closePayload)
channelCloseCh <- closePayload
}
time.Sleep(100 * time.Millisecond)
})
defer relayServer.Close()
cfg := Config{
Enabled: true,
ServerURL: wsURL(relayServer),
}
deps := ClientDeps{
LicenseTokenFunc: func() string { return "test-jwt" },
TokenValidator: func(token string) bool { return token == "good-token" },
LocalAddr: "127.0.0.1:9999",
ServerVersion: "1.0.0",
IdentityPubKey: "test-pub-key",
}
client := NewClient(cfg, deps, logger)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
errCh := make(chan error, 1)
go func() { errCh <- client.Run(ctx) }()
select {
case closePayload := <-channelCloseCh:
if closePayload.ChannelID != 99 {
t.Errorf("channel ID: got %d, want 99", closePayload.ChannelID)
}
if closePayload.Reason != "invalid auth token" {
t.Errorf("reason: got %q, want %q", closePayload.Reason, "invalid auth token")
}
case <-time.After(3 * time.Second):
t.Fatal("timed out waiting for CHANNEL_CLOSE")
}
cancel()
<-errCh
}
func TestClient_UsesRelayCABundleFromEnvironment(t *testing.T) {
logger := zerolog.New(zerolog.NewTestWriter(t))
mockAPI := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}))
defer mockAPI.Close()
relayServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := testUpgrader.Upgrade(w, r, nil)
if err != nil {
t.Logf("upgrade error: %v", err)
return
}
defer conn.Close()
_, msg, err := conn.ReadMessage()
if err != nil {
t.Logf("read register: %v", err)
return
}
frame, err := DecodeFrame(msg)
if err != nil {
t.Logf("decode register: %v", err)
return
}
if frame.Type != FrameRegister {
t.Logf("expected REGISTER, got %s", FrameTypeName(frame.Type))
return
}
ack, _ := NewControlFrame(FrameRegisterAck, 0, RegisterAckPayload{
InstanceID: "inst_tls",
SessionToken: "sess_tls",
ExpiresAt: time.Now().Add(time.Hour).Unix(),
})
ackBytes, _ := EncodeFrame(ack)
_ = conn.WriteMessage(websocket.BinaryMessage, ackBytes)
<-time.After(250 * time.Millisecond)
}))
defer relayServer.Close()
t.Setenv("SSL_CERT_FILE", writeServerCertPEM(t, relayServer))
client := NewClient(
Config{
Enabled: true,
ServerURL: wssURL(relayServer),
},
ClientDeps{
LicenseTokenFunc: func() string { return "test-license-jwt" },
TokenValidator: func(string) bool { return true },
LocalAddr: strings.TrimPrefix(mockAPI.URL, "http://"),
ServerVersion: "1.0.0-test",
IdentityPubKey: "test-identity-pub-key",
},
logger,
)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
errCh := make(chan error, 1)
go func() { errCh <- client.Run(ctx) }()
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
if status := client.Status(); status.Connected && status.InstanceID == "inst_tls" {
cancel()
select {
case <-errCh:
case <-time.After(time.Second):
t.Fatal("client.Run didn't return after cancel")
}
return
}
time.Sleep(25 * time.Millisecond)
}
status := client.Status()
cancel()
select {
case <-errCh:
case <-time.After(time.Second):
t.Fatal("client.Run didn't return after cancel")
}
t.Fatalf("expected TLS relay client to connect, got status=%+v", status)
}
func TestClient_DrainTriggersReconnect(t *testing.T) {
logger := zerolog.New(zerolog.NewTestWriter(t))
connectCount := 0
var connectMu sync.Mutex
relayServer := mockRelayServer(t, func(conn *websocket.Conn) {
// Read REGISTER
_, msg, _ := conn.ReadMessage()
frame, _ := DecodeFrame(msg)
if frame.Type != FrameRegister {
return
}
connectMu.Lock()
connectCount++
count := connectCount
connectMu.Unlock()
// Send REGISTER_ACK
ack, _ := NewControlFrame(FrameRegisterAck, 0, RegisterAckPayload{
InstanceID: "inst_abc",
SessionToken: "sess_xyz",
ExpiresAt: time.Now().Add(time.Hour).Unix(),
})
ackBytes, _ := EncodeFrame(ack)
_ = conn.WriteMessage(websocket.BinaryMessage, ackBytes)
if count == 1 {
// First connection: send DRAIN after a short delay
time.Sleep(100 * time.Millisecond)
drain, _ := NewControlFrame(FrameDrain, 0, DrainPayload{
Reason: "server shutting down",
})
drainBytes, _ := EncodeFrame(drain)
_ = conn.WriteMessage(websocket.BinaryMessage, drainBytes)
} else {
// Second connection: keep alive briefly
time.Sleep(500 * time.Millisecond)
}
})
defer relayServer.Close()
cfg := Config{
Enabled: true,
ServerURL: wsURL(relayServer),
}
deps := ClientDeps{
LicenseTokenFunc: func() string { return "test-jwt" },
TokenValidator: func(token string) bool { return true },
LocalAddr: "127.0.0.1:9999",
ServerVersion: "1.0.0",
IdentityPubKey: "test-pub-key",
}
client := NewClient(cfg, deps, logger)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
errCh := make(chan error, 1)
go func() { errCh <- client.Run(ctx) }()
// Wait for second connection
deadline := time.After(8 * time.Second)
for {
connectMu.Lock()
c := connectCount
connectMu.Unlock()
if c >= 2 {
break
}
select {
case <-deadline:
t.Fatalf("timed out waiting for reconnect, connect count: %d", c)
case <-time.After(100 * time.Millisecond):
}
}
cancel()
<-errCh
}
func TestClient_SessionTokenReuse(t *testing.T) {
logger := zerolog.New(zerolog.NewTestWriter(t))
type regCapture struct {
sessionToken string
instanceHint string
}
captures := make(chan regCapture, 2)
relayServer := mockRelayServer(t, func(conn *websocket.Conn) {
_, msg, _ := conn.ReadMessage()
frame, _ := DecodeFrame(msg)
if frame.Type != FrameRegister {
return
}
var regPayload RegisterPayload
_ = UnmarshalControlPayload(frame.Payload, &regPayload)
captures <- regCapture{
sessionToken: regPayload.SessionToken,
instanceHint: regPayload.InstanceHint,
}
ack, _ := NewControlFrame(FrameRegisterAck, 0, RegisterAckPayload{
InstanceID: "inst_abc",
SessionToken: "server-issued-session-token",
ExpiresAt: time.Now().Add(time.Hour).Unix(),
})
ackBytes, _ := EncodeFrame(ack)
_ = conn.WriteMessage(websocket.BinaryMessage, ackBytes)
// Close after a short delay to trigger reconnect
time.Sleep(100 * time.Millisecond)
conn.Close()
})
defer relayServer.Close()
cfg := Config{
Enabled: true,
ServerURL: wsURL(relayServer),
InstanceSecret: "my-raw-secret",
}
deps := ClientDeps{
LicenseTokenFunc: func() string { return "test-jwt" },
TokenValidator: func(token string) bool { return true },
LocalAddr: "127.0.0.1:9999",
ServerVersion: "1.0.0",
IdentityPubKey: "test-pub-key",
}
client := NewClient(cfg, deps, logger)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
errCh := make(chan error, 1)
go func() { errCh <- client.Run(ctx) }()
// First connection: no session token, InstanceHint is the raw secret
select {
case cap := <-captures:
if cap.sessionToken != "" {
t.Errorf("first connection session_token: got %q, want empty", cap.sessionToken)
}
if cap.instanceHint != "my-raw-secret" {
t.Errorf("first connection instance_hint: got %q, want %q", cap.instanceHint, "my-raw-secret")
}
case <-time.After(3 * time.Second):
t.Fatal("timed out waiting for first REGISTER")
}
// Second connection: session token reused, InstanceHint switches to derived instance ID
select {
case cap := <-captures:
if cap.sessionToken != "server-issued-session-token" {
t.Errorf("second connection session_token: got %q, want %q", cap.sessionToken, "server-issued-session-token")
}
// On reconnect, InstanceHint should be the derived instance_id (from
// REGISTER_ACK), NOT the raw secret. The relay server's session
// reconnect path looks up by instance_id directly.
if cap.instanceHint != "inst_abc" {
t.Errorf("second connection instance_hint: got %q, want %q (derived instance ID)", cap.instanceHint, "inst_abc")
}
case <-time.After(8 * time.Second):
t.Fatal("timed out waiting for second REGISTER")
}
cancel()
<-errCh
}
func TestClient_RejectsOversizedWebSocketMessage(t *testing.T) {
logger := zerolog.New(zerolog.NewTestWriter(t))
relayServer := mockRelayServer(t, func(conn *websocket.Conn) {
// Read REGISTER
_, msg, _ := conn.ReadMessage()
frame, _ := DecodeFrame(msg)
if frame.Type != FrameRegister {
return
}
// Send REGISTER_ACK
ack, _ := NewControlFrame(FrameRegisterAck, 0, RegisterAckPayload{
InstanceID: "inst_oversized",
SessionToken: "sess_oversized",
ExpiresAt: time.Now().Add(time.Hour).Unix(),
})
ackBytes, _ := EncodeFrame(ack)
_ = conn.WriteMessage(websocket.BinaryMessage, ackBytes)
// Send one oversized websocket binary message. This should trip
// conn.SetReadLimit and force the client to drop/reconnect.
oversized := make([]byte, wsReadLimit+1)
_ = conn.WriteMessage(websocket.BinaryMessage, oversized)
time.Sleep(200 * time.Millisecond)
})
defer relayServer.Close()
cfg := Config{
Enabled: true,
ServerURL: wsURL(relayServer),
}
deps := ClientDeps{
LicenseTokenFunc: func() string { return "test-jwt" },
TokenValidator: func(token string) bool { return true },
LocalAddr: "127.0.0.1:9999",
ServerVersion: "1.0.0",
IdentityPubKey: "test-pub-key",
}
client := NewClient(cfg, deps, logger)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
errCh := make(chan error, 1)
go func() { errCh <- client.Run(ctx) }()
deadline := time.After(3 * time.Second)
observed := false
for !observed {
status := client.Status()
if strings.Contains(status.LastError, "read limit exceeded") {
observed = true
break
}
select {
case <-deadline:
t.Fatalf("timed out waiting for read-limit error, last error=%q", status.LastError)
case <-time.After(50 * time.Millisecond):
}
}
cancel()
select {
case <-errCh:
case <-time.After(2 * time.Second):
t.Fatal("client.Run did not return after cancel on idle connection")
}
}
// testIdentityKeyPair generates an Ed25519 keypair for testing and returns
// the base64 private key and public key.
func testIdentityKeyPair(t *testing.T) (privB64, pubB64 string) {
t.Helper()
priv, pub, _, err := GenerateIdentityKeyPair()
if err != nil {
t.Fatal(err)
}
return priv, pub
}
func TestClient_EncryptedChannelLifecycle(t *testing.T) {
logger := zerolog.New(zerolog.NewTestWriter(t))
identityPriv, identityPub := testIdentityKeyPair(t)
dataResponseCh := make(chan ProxyResponse, 1)
// Mock local Pulse API
mockAPI := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]string{"path": r.URL.Path, "encrypted": "true"})
}))
defer mockAPI.Close()
localAddr := strings.TrimPrefix(mockAPI.URL, "http://")
relayServer := mockRelayServer(t, func(conn *websocket.Conn) {
// 1. Read REGISTER
_, msg, err := conn.ReadMessage()
if err != nil {
t.Logf("read register: %v", err)
return
}
frame, _ := DecodeFrame(msg)
if frame.Type != FrameRegister {
t.Logf("expected REGISTER, got %s", FrameTypeName(frame.Type))
return
}
// 2. Send REGISTER_ACK
ack, _ := NewControlFrame(FrameRegisterAck, 0, RegisterAckPayload{
InstanceID: "inst_enc",
SessionToken: "sess_enc",
ExpiresAt: time.Now().Add(time.Hour).Unix(),
})
ackBytes, _ := EncodeFrame(ack)
_ = conn.WriteMessage(websocket.BinaryMessage, ackBytes)
// 3. Send CHANNEL_OPEN
chOpen, _ := NewControlFrame(FrameChannelOpen, 10, ChannelOpenPayload{
ChannelID: 10,
AuthToken: "valid-token",
})
chOpenBytes, _ := EncodeFrame(chOpen)
time.Sleep(50 * time.Millisecond)
_ = conn.WriteMessage(websocket.BinaryMessage, chOpenBytes)
// 4. Read CHANNEL_OPEN ack
_, msg, _ = conn.ReadMessage()
frame, _ = DecodeFrame(msg)
if frame.Type != FrameChannelOpen {
t.Logf("expected CHANNEL_OPEN ack, got %s", FrameTypeName(frame.Type))
return
}
// 5. Initiate key exchange: generate app's ephemeral keypair
appPriv, err := GenerateEphemeralKeyPair()
if err != nil {
t.Logf("generate app keypair: %v", err)
return
}
// Send KEY_EXCHANGE from "app" (no signature — app doesn't sign)
kexPayload := MarshalKeyExchangePayload(appPriv.PublicKey().Bytes(), nil)
kexFrame := NewFrame(FrameKeyExchange, 10, kexPayload)
kexBytes, _ := EncodeFrame(kexFrame)
_ = conn.WriteMessage(websocket.BinaryMessage, kexBytes)
// 6. Read instance's KEY_EXCHANGE response
_, msg, err = conn.ReadMessage()
if err != nil {
t.Logf("read key exchange: %v", err)
return
}
frame, _ = DecodeFrame(msg)
if frame.Type != FrameKeyExchange {
t.Logf("expected KEY_EXCHANGE, got %s", FrameTypeName(frame.Type))
return
}
instancePub, sig, err := UnmarshalKeyExchangePayload(frame.Payload)
if err != nil {
t.Logf("unmarshal key exchange: %v", err)
return
}
// Verify signature
if err := VerifyKeyExchangeSignature(instancePub, sig, identityPub); err != nil {
t.Logf("key exchange signature verification failed: %v", err)
return
}
// Derive keys on mock-relay side (acting as app)
instancePubKey, err := ecdh.X25519().NewPublicKey(instancePub)
if err != nil {
t.Logf("parse instance pubkey: %v", err)
return
}
appEnc, err := DeriveChannelKeys(appPriv, instancePubKey, false)
if err != nil {
t.Logf("derive channel keys: %v", err)
return
}
// 7. Send encrypted DATA request
proxyReq := ProxyRequest{
ID: "req_encrypted",
Method: "GET",
Path: "/api/status",
}
proxyReqBytes, _ := json.Marshal(proxyReq)
encryptedReq, err := appEnc.Encrypt(proxyReqBytes)
if err != nil {
t.Logf("encrypt request: %v", err)
return
}
dataFrame := NewFrame(FrameData, 10, encryptedReq)
dataBytes, _ := EncodeFrame(dataFrame)
_ = conn.WriteMessage(websocket.BinaryMessage, dataBytes)
// 8. Read encrypted DATA response
_, msg, err = conn.ReadMessage()
if err != nil {
t.Logf("read data response: %v", err)
return
}
frame, _ = DecodeFrame(msg)
if frame.Type == FrameData {
decrypted, err := appEnc.Decrypt(frame.Payload)
if err != nil {
t.Logf("decrypt response: %v", err)
return
}
var resp ProxyResponse
_ = json.Unmarshal(decrypted, &resp)
dataResponseCh <- resp
}
time.Sleep(2 * time.Second)
})
defer relayServer.Close()
cfg := Config{
Enabled: true,
ServerURL: wsURL(relayServer),
}
deps := ClientDeps{
LicenseTokenFunc: func() string { return "test-jwt" },
TokenValidator: func(token string) bool { return token == "valid-token" },
LocalAddr: localAddr,
ServerVersion: "1.0.0-test",
IdentityPubKey: identityPub,
IdentityPrivateKey: identityPriv,
}
client := NewClient(cfg, deps, logger)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
errCh := make(chan error, 1)
go func() {
errCh <- client.Run(ctx)
}()
select {
case resp := <-dataResponseCh:
if resp.ID != "req_encrypted" {
t.Errorf("response ID: got %q, want %q", resp.ID, "req_encrypted")
}
if resp.Status != 200 {
t.Errorf("response status: got %d, want 200", resp.Status)
}
case <-time.After(3 * time.Second):
t.Fatal("timed out waiting for encrypted DATA response")
}
cancel()
select {
case <-errCh:
case <-time.After(3 * time.Second):
t.Fatal("client.Run didn't return after cancel")
}
}
func TestClient_HandleDataDecryptsBeforeReturning(t *testing.T) {
logger := zerolog.New(zerolog.NewTestWriter(t))
mockAPI := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}))
defer mockAPI.Close()
blockingAEAD := &blockingTestAEAD{
openStarted: make(chan struct{}),
releaseOpen: make(chan struct{}),
}
client := &Client{
proxy: NewHTTPProxy(strings.TrimPrefix(mockAPI.URL, "http://"), logger),
logger: logger,
channels: map[uint32]*channelState{
7: {
apiToken: "valid-token",
encryption: &ChannelEncryption{
sendCipher: &channelCipher{aead: blockingAEAD},
recvCipher: &channelCipher{aead: blockingAEAD},
},
},
},
}
defer client.proxy.Close()
// Any payload with a zero nonce works with the blocking AEAD.
ciphertext := append(make([]byte, nonceSize), []byte(`{"id":"req-1","method":"GET","path":"/api/ai/approvals"}`)...)
frame := NewFrame(FrameData, 7, ciphertext)
sendCh := make(chan []byte, 1)
dataLimiter := make(chan struct{}, 1)
done := make(chan struct{})
go func() {
client.handleData(context.Background(), frame, sendCh, dataLimiter)
close(done)
}()
select {
case <-blockingAEAD.openStarted:
case <-time.After(time.Second):
t.Fatal("timed out waiting for DATA decrypt to start")
}
select {
case <-done:
t.Fatal("handleData returned before encrypted payload decryption completed")
case <-time.After(100 * time.Millisecond):
}
close(blockingAEAD.releaseOpen)
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("handleData did not return after decrypt completed")
}
}
func TestClient_DataWithoutKeyExchange(t *testing.T) {
// Verifies backward compatibility: unencrypted DATA still works when no KEY_EXCHANGE occurs.
logger := zerolog.New(zerolog.NewTestWriter(t))
dataResponseCh := make(chan ProxyResponse, 1)
mockAPI := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]string{"path": r.URL.Path})
}))
defer mockAPI.Close()
localAddr := strings.TrimPrefix(mockAPI.URL, "http://")
identityPriv, identityPub := testIdentityKeyPair(t)
relayServer := mockRelayServer(t, func(conn *websocket.Conn) {
// REGISTER
_, msg, _ := conn.ReadMessage()
frame, _ := DecodeFrame(msg)
if frame.Type != FrameRegister {
return
}
ack, _ := NewControlFrame(FrameRegisterAck, 0, RegisterAckPayload{
InstanceID: "inst_plain",
SessionToken: "sess_plain",
ExpiresAt: time.Now().Add(time.Hour).Unix(),
})
ackBytes, _ := EncodeFrame(ack)
_ = conn.WriteMessage(websocket.BinaryMessage, ackBytes)
// CHANNEL_OPEN (no KEY_EXCHANGE follows)
chOpen, _ := NewControlFrame(FrameChannelOpen, 5, ChannelOpenPayload{
ChannelID: 5,
AuthToken: "plain-token",
})
chOpenBytes, _ := EncodeFrame(chOpen)
time.Sleep(50 * time.Millisecond)
_ = conn.WriteMessage(websocket.BinaryMessage, chOpenBytes)
// Read CHANNEL_OPEN ack
_, msg, _ = conn.ReadMessage()
frame, _ = DecodeFrame(msg)
if frame.Type != FrameChannelOpen {
return
}
// Send unencrypted DATA
proxyReq := ProxyRequest{
ID: "req_plain",
Method: "GET",
Path: "/api/health",
}
proxyReqBytes, _ := json.Marshal(proxyReq)
dataFrame := NewFrame(FrameData, 5, proxyReqBytes)
dataBytes, _ := EncodeFrame(dataFrame)
_ = conn.WriteMessage(websocket.BinaryMessage, dataBytes)
// Read unencrypted DATA response
_, msg, _ = conn.ReadMessage()
frame, _ = DecodeFrame(msg)
if frame.Type == FrameData {
var resp ProxyResponse
// Should be plain JSON, not encrypted
if err := json.Unmarshal(frame.Payload, &resp); err != nil {
t.Logf("unmarshal response: %v", err)
return
}
dataResponseCh <- resp
}
time.Sleep(2 * time.Second)
})
defer relayServer.Close()
cfg := Config{
Enabled: true,
ServerURL: wsURL(relayServer),
}
deps := ClientDeps{
LicenseTokenFunc: func() string { return "test-jwt" },
TokenValidator: func(token string) bool { return token == "plain-token" },
LocalAddr: localAddr,
ServerVersion: "1.0.0-test",
IdentityPubKey: identityPub,
IdentityPrivateKey: identityPriv,
}
client := NewClient(cfg, deps, logger)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
errCh := make(chan error, 1)
go func() {
errCh <- client.Run(ctx)
}()
select {
case resp := <-dataResponseCh:
if resp.ID != "req_plain" {
t.Errorf("response ID: got %q, want %q", resp.ID, "req_plain")
}
if resp.Status != 200 {
t.Errorf("response status: got %d, want 200", resp.Status)
}
case <-time.After(3 * time.Second):
t.Fatal("timed out waiting for unencrypted DATA response")
}
cancel()
<-errCh
}
func TestClient_KeyExchangeRejectedWithoutIdentityKey(t *testing.T) {
// Verifies that KEY_EXCHANGE fails closed when IdentityPrivateKey is empty:
// the instance sends CHANNEL_CLOSE, removes the channel locally, and
// ignores any subsequent DATA on that channel.
logger := zerolog.New(zerolog.NewTestWriter(t))
channelCloseCh := make(chan ChannelClosePayload, 1)
dataResponseCh := make(chan struct{}, 1)
relayServer := mockRelayServer(t, func(conn *websocket.Conn) {
// REGISTER
_, msg, _ := conn.ReadMessage()
frame, _ := DecodeFrame(msg)
if frame.Type != FrameRegister {
return
}
ack, _ := NewControlFrame(FrameRegisterAck, 0, RegisterAckPayload{
InstanceID: "inst_nosign",
SessionToken: "sess_nosign",
ExpiresAt: time.Now().Add(time.Hour).Unix(),
})
ackBytes, _ := EncodeFrame(ack)
_ = conn.WriteMessage(websocket.BinaryMessage, ackBytes)
// CHANNEL_OPEN
chOpen, _ := NewControlFrame(FrameChannelOpen, 20, ChannelOpenPayload{
ChannelID: 20,
AuthToken: "token-nosign",
})
chOpenBytes, _ := EncodeFrame(chOpen)
time.Sleep(50 * time.Millisecond)
_ = conn.WriteMessage(websocket.BinaryMessage, chOpenBytes)
// Read CHANNEL_OPEN ack
_, msg, _ = conn.ReadMessage()
frame, _ = DecodeFrame(msg)
if frame.Type != FrameChannelOpen {
return
}
// Send KEY_EXCHANGE from "app"
appPriv, _ := GenerateEphemeralKeyPair()
kexPayload := MarshalKeyExchangePayload(appPriv.PublicKey().Bytes(), nil)
kexFrame := NewFrame(FrameKeyExchange, 20, kexPayload)
kexBytes, _ := EncodeFrame(kexFrame)
_ = conn.WriteMessage(websocket.BinaryMessage, kexBytes)
// Should receive CHANNEL_CLOSE (not KEY_EXCHANGE response)
_, msg, err := conn.ReadMessage()
if err != nil {
return
}
frame, _ = DecodeFrame(msg)
if frame.Type == FrameChannelClose {
var closePayload ChannelClosePayload
_ = UnmarshalControlPayload(frame.Payload, &closePayload)
channelCloseCh <- closePayload
} else {
t.Logf("expected CHANNEL_CLOSE, got %s", FrameTypeName(frame.Type))
return
}
// Non-cooperative peer: send DATA on the closed channel anyway.
// The instance must ignore it (channel removed from map).
time.Sleep(50 * time.Millisecond)
proxyReq := ProxyRequest{
ID: "req_should_be_ignored",
Method: "GET",
Path: "/api/status",
}
proxyReqBytes, _ := json.Marshal(proxyReq)
dataFrame := NewFrame(FrameData, 20, proxyReqBytes)
dataBytes, _ := EncodeFrame(dataFrame)
_ = conn.WriteMessage(websocket.BinaryMessage, dataBytes)
// Wait briefly for any response — there should be none.
_ = conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
_, _, readErr := conn.ReadMessage()
if readErr == nil {
// Got a response — the channel wasn't properly removed
dataResponseCh <- struct{}{}
}
_ = conn.SetReadDeadline(time.Time{})
time.Sleep(100 * time.Millisecond)
})
defer relayServer.Close()
cfg := Config{
Enabled: true,
ServerURL: wsURL(relayServer),
}
deps := ClientDeps{
LicenseTokenFunc: func() string { return "test-jwt" },
TokenValidator: func(token string) bool { return token == "token-nosign" },
LocalAddr: "127.0.0.1:9999",
ServerVersion: "1.0.0-test",
IdentityPubKey: "some-pub-key",
IdentityPrivateKey: "", // deliberately empty
}
client := NewClient(cfg, deps, logger)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
errCh := make(chan error, 1)
go func() { errCh <- client.Run(ctx) }()
select {
case closePayload := <-channelCloseCh:
if closePayload.ChannelID != 20 {
t.Errorf("channel ID: got %d, want 20", closePayload.ChannelID)
}
if closePayload.Reason != "key exchange signing unavailable" {
t.Errorf("reason: got %q, want %q", closePayload.Reason, "key exchange signing unavailable")
}
case <-time.After(3 * time.Second):
t.Fatal("timed out waiting for CHANNEL_CLOSE from failed KEY_EXCHANGE")
}
// Verify no DATA response was sent for the post-close frame
select {
case <-dataResponseCh:
t.Fatal("instance processed DATA on a channel that should have been removed after KEY_EXCHANGE rejection")
case <-time.After(800 * time.Millisecond):
// Good — no response
}
// Verify channel is gone from client state
status := client.Status()
if status.ActiveChannels != 0 {
t.Errorf("active channels: got %d, want 0", status.ActiveChannels)
}
cancel()
<-errCh
}
func TestQueueFrameLogsStructuredContextOnEncodeFailure(t *testing.T) {
var logOutput bytes.Buffer
logger := zerolog.New(&logOutput)
sendCh := make(chan []byte, 1)
// Oversized payload ensures EncodeFrame fails.
queueFrame(sendCh, NewFrame(FrameData, 7, make([]byte, MaxPayloadSize+1)), logger)
got := logOutput.String()
for _, expected := range []string{
`"component":"relay_client"`,
`"action":"encode_frame"`,
`"frame_type":"DATA"`,
`"channel":7`,
`"payload_bytes":65537`,
`"message":"Failed to encode frame for send"`,
} {
if !strings.Contains(got, expected) {
t.Fatalf("expected log output to include %s, got %q", expected, got)
}
}
}
func TestQueueFrameLogsStructuredContextOnDrop(t *testing.T) {
var logOutput bytes.Buffer
logger := zerolog.New(&logOutput)
sendCh := make(chan []byte, 1)
sendCh <- []byte("full")
queueFrame(sendCh, NewFrame(FramePing, 11, nil), logger)
got := logOutput.String()
for _, expected := range []string{
`"component":"relay_client"`,
`"action":"drop_frame"`,
`"frame_type":"PING"`,
`"channel":11`,
`"payload_bytes":0`,
`"send_queue_depth":1`,
`"send_queue_capacity":1`,
`"message":"Send channel full, dropping frame"`,
} {
if !strings.Contains(got, expected) {
t.Fatalf("expected log output to include %s, got %q", expected, got)
}
}
}
func TestStatusReportsReconnectDelayDuringLicensePause(t *testing.T) {
logger := zerolog.New(zerolog.NewTestWriter(t))
attempted := make(chan struct{}, 1)
relayServer := mockRelayServer(t, func(conn *websocket.Conn) {
select {
case attempted <- struct{}{}:
default:
}
_, msg, err := conn.ReadMessage()
if err != nil {
return
}
frame, err := DecodeFrame(msg)
if err != nil || frame.Type != FrameRegister {
return
}
ack, _ := NewControlFrame(FrameRegisterAck, 0, RegisterAckPayload{
InstanceID: "inst_status_pause",
SessionToken: "sess_status_pause",
ExpiresAt: time.Now().Add(time.Hour).Unix(),
})
ackBytes, _ := EncodeFrame(ack)
_ = conn.WriteMessage(websocket.BinaryMessage, ackBytes)
errFrame, _ := NewControlFrame(FrameError, 0, ErrorPayload{
Code: ErrCodeLicenseExpired,
Message: "license expired",
})
errBytes, _ := EncodeFrame(errFrame)
_ = conn.WriteMessage(websocket.BinaryMessage, errBytes)
time.Sleep(100 * time.Millisecond)
})
defer relayServer.Close()
deps := ClientDeps{
LicenseTokenFunc: func() string { return "expired-jwt" },
TokenValidator: func(token string) bool { return true },
LocalAddr: "127.0.0.1:9999",
ServerVersion: "test",
IdentityPubKey: "test-pub-key",
}
client := NewClient(Config{Enabled: true, ServerURL: wsURL(relayServer)}, deps, logger)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
errCh := make(chan error, 1)
go func() { errCh <- client.Run(ctx) }()
select {
case <-attempted:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for relay registration attempt")
}
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
status := client.Status()
if status.LastError != "" && status.ReconnectIn != "" {
if status.Connected {
t.Fatalf("expected disconnected status during license pause, got %+v", status)
}
return
}
time.Sleep(25 * time.Millisecond)
}
t.Fatalf("timed out waiting for reconnect_in status; final status=%+v", client.Status())
}
func TestClient_SendPushNotification(t *testing.T) {
logger := zerolog.New(zerolog.NewTestWriter(t))
pushFrameCh := make(chan Frame, 1)
relayServer := mockRelayServer(t, func(conn *websocket.Conn) {
// REGISTER
_, msg, _ := conn.ReadMessage()
frame, _ := DecodeFrame(msg)
if frame.Type != FrameRegister {
return
}
// REGISTER_ACK
ack, _ := NewControlFrame(FrameRegisterAck, 0, RegisterAckPayload{
InstanceID: "inst_push",
SessionToken: "sess_push",
ExpiresAt: time.Now().Add(time.Hour).Unix(),
})
ackBytes, _ := EncodeFrame(ack)
_ = conn.WriteMessage(websocket.BinaryMessage, ackBytes)
// Read frames from the instance; expect PUSH_NOTIFICATION
for {
_, msg, err := conn.ReadMessage()
if err != nil {
return
}
frame, err := DecodeFrame(msg)
if err != nil {
continue
}
if frame.Type == FramePushNotification {
pushFrameCh <- frame
return
}
}
})
defer relayServer.Close()
cfg := Config{
Enabled: true,
ServerURL: wsURL(relayServer),
}
deps := ClientDeps{
LicenseTokenFunc: func() string { return "test-jwt" },
TokenValidator: func(token string) bool { return true },
LocalAddr: "127.0.0.1:9999",
ServerVersion: "1.0.0",
IdentityPubKey: "test-pub-key",
}
client := NewClient(cfg, deps, logger)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
errCh := make(chan error, 1)
go func() { errCh <- client.Run(ctx) }()
// Wait for connection to be established
deadline := time.After(3 * time.Second)
for {
if client.Status().Connected {
break
}
select {
case <-deadline:
t.Fatal("timed out waiting for connection")
case <-time.After(50 * time.Millisecond):
}
}
// Send push notification
notification := NewPatrolFindingNotification("finding-test", "critical", "performance", "Test Push")
if err := client.SendPushNotification(notification); err != nil {
t.Fatalf("SendPushNotification() error = %v", err)
}
// Verify the frame was received by the mock server
select {
case frame := <-pushFrameCh:
if frame.Type != FramePushNotification {
t.Errorf("frame type: got 0x%02X, want 0x%02X", frame.Type, FramePushNotification)
}
if frame.Channel != 0 {
t.Errorf("channel: got %d, want 0 (control channel)", frame.Channel)
}
var payload PushNotificationPayload
if err := UnmarshalControlPayload(frame.Payload, &payload); err != nil {
t.Fatalf("unmarshal push payload: %v", err)
}
if payload.Type != PushTypePatrolCritical {
t.Errorf("payload type: got %q, want %q", payload.Type, PushTypePatrolCritical)
}
if payload.Title != "Test Push" {
t.Errorf("payload title: got %q, want %q", payload.Title, "Test Push")
}
if payload.InstanceID != "inst_push" {
t.Errorf("payload instance_id: got %q, want %q", payload.InstanceID, "inst_push")
}
case <-time.After(3 * time.Second):
t.Fatal("timed out waiting for PUSH_NOTIFICATION frame")
}
cancel()
<-errCh
}
func TestClient_SendPushNotificationDisconnected(t *testing.T) {
logger := zerolog.New(zerolog.NewTestWriter(t))
cfg := Config{
Enabled: true,
ServerURL: "ws://127.0.0.1:1", // unreachable
}
deps := ClientDeps{
LicenseTokenFunc: func() string { return "test-jwt" },
TokenValidator: func(token string) bool { return true },
LocalAddr: "127.0.0.1:9999",
ServerVersion: "1.0.0",
IdentityPubKey: "test-pub-key",
}
// Create client but don't run it — stays disconnected
client := NewClient(cfg, deps, logger)
notification := NewPatrolFindingNotification("finding-test", "warning", "capacity", "Test")
err := client.SendPushNotification(notification)
if err == nil {
t.Fatal("expected error when sending on disconnected client")
}
if err != ErrNotConnected {
t.Errorf("expected ErrNotConnected, got %v", err)
}
}
func TestClient_RegisterFailsWithoutLicenseTokenProvider(t *testing.T) {
logger := zerolog.New(zerolog.NewTestWriter(t))
relayServer := mockRelayServer(t, func(conn *websocket.Conn) {
time.Sleep(100 * time.Millisecond)
})
defer relayServer.Close()
cfg := Config{
Enabled: true,
ServerURL: wsURL(relayServer),
}
deps := ClientDeps{
LicenseTokenFunc: nil, // explicit hardening guard
TokenValidator: func(token string) bool { return true },
LocalAddr: "127.0.0.1:9999",
ServerVersion: "1.0.0",
IdentityPubKey: "test-pub-key",
}
client := NewClient(cfg, deps, logger)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
_, err := client.connectAndHandle(ctx)
if err == nil {
t.Fatal("expected error when LicenseTokenFunc is nil")
}
if !strings.Contains(err.Error(), "license token provider not configured") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestClient_RejectChannelWhenTokenValidatorMissing(t *testing.T) {
logger := zerolog.New(zerolog.NewTestWriter(t))
cfg := Config{
Enabled: true,
ServerURL: "wss://relay.example.com",
}
deps := ClientDeps{
LicenseTokenFunc: func() string { return "test-jwt" },
TokenValidator: nil, // explicit hardening guard
LocalAddr: "127.0.0.1:9999",
ServerVersion: "1.0.0",
IdentityPubKey: "test-pub-key",
}
client := NewClient(cfg, deps, logger)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
// With nil TokenValidator, Run() should fail fast at startup validation
// rather than connecting to the server.
err := client.Run(ctx)
if err == nil {
t.Fatal("expected error when TokenValidator is nil")
}
if !strings.Contains(err.Error(), "token validator") {
t.Fatalf("expected token validator error, got: %v", err)
}
}
func TestClient_OverloadedDataReturnsBusyResponse(t *testing.T) {
logger := zerolog.New(zerolog.NewTestWriter(t))
origLimit := maxConcurrentDataHandlers
maxConcurrentDataHandlers = 1
defer func() { maxConcurrentDataHandlers = origLimit }()
releaseSlow := make(chan struct{})
mockAPI := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/slow" {
<-releaseSlow
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]string{"path": r.URL.Path})
}))
defer mockAPI.Close()
localAddr := strings.TrimPrefix(mockAPI.URL, "http://")
overloadedRespCh := make(chan ProxyResponse, 1)
relayServer := mockRelayServer(t, func(conn *websocket.Conn) {
// REGISTER
_, msg, _ := conn.ReadMessage()
frame, _ := DecodeFrame(msg)
if frame.Type != FrameRegister {
return
}
ack, _ := NewControlFrame(FrameRegisterAck, 0, RegisterAckPayload{
InstanceID: "inst_overload",
SessionToken: "sess_overload",
ExpiresAt: time.Now().Add(time.Hour).Unix(),
})
ackBytes, _ := EncodeFrame(ack)
_ = conn.WriteMessage(websocket.BinaryMessage, ackBytes)
// Open channel
time.Sleep(50 * time.Millisecond)
chOpen, _ := NewControlFrame(FrameChannelOpen, 1, ChannelOpenPayload{
ChannelID: 1,
AuthToken: "valid-token",
})
chOpenBytes, _ := EncodeFrame(chOpen)
_ = conn.WriteMessage(websocket.BinaryMessage, chOpenBytes)
// Read CHANNEL_OPEN ack
_, msg, err := conn.ReadMessage()
if err != nil {
return
}
frame, _ = DecodeFrame(msg)
if frame.Type != FrameChannelOpen {
return
}
// First request occupies the only in-flight slot.
firstReq := ProxyRequest{
ID: "req_slow",
Method: "GET",
Path: "/api/slow",
}
firstReqBytes, _ := json.Marshal(firstReq)
firstFrame := NewFrame(FrameData, 1, firstReqBytes)
firstData, _ := EncodeFrame(firstFrame)
_ = conn.WriteMessage(websocket.BinaryMessage, firstData)
// Second request should get immediate 503 overload response.
secondReq := ProxyRequest{
ID: "req_overload",
Method: "GET",
Path: "/api/fast",
}
secondReqBytes, _ := json.Marshal(secondReq)
secondFrame := NewFrame(FrameData, 1, secondReqBytes)
secondData, _ := EncodeFrame(secondFrame)
_ = conn.WriteMessage(websocket.BinaryMessage, secondData)
_ = conn.SetReadDeadline(time.Now().Add(3 * time.Second))
_, msg, err = conn.ReadMessage()
if err != nil {
return
}
frame, _ = DecodeFrame(msg)
if frame.Type != FrameData {
return
}
var resp ProxyResponse
if err := json.Unmarshal(frame.Payload, &resp); err != nil {
return
}
overloadedRespCh <- resp
close(releaseSlow)
time.Sleep(100 * time.Millisecond)
})
defer relayServer.Close()
cfg := Config{
Enabled: true,
ServerURL: wsURL(relayServer),
}
deps := ClientDeps{
LicenseTokenFunc: func() string { return "test-jwt" },
TokenValidator: func(token string) bool { return token == "valid-token" },
LocalAddr: localAddr,
ServerVersion: "1.0.0",
IdentityPubKey: "test-pub-key",
}
client := NewClient(cfg, deps, logger)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
errCh := make(chan error, 1)
go func() { errCh <- client.Run(ctx) }()
select {
case resp := <-overloadedRespCh:
if resp.ID != "req_overload" {
t.Errorf("response ID: got %q, want %q", resp.ID, "req_overload")
}
if resp.Status != http.StatusServiceUnavailable {
t.Errorf("response status: got %d, want %d", resp.Status, http.StatusServiceUnavailable)
}
case <-time.After(3 * time.Second):
close(releaseSlow)
t.Fatal("timed out waiting for overload response")
}
cancel()
<-errCh
}
func TestNextConsecutiveFailures(t *testing.T) {
tests := []struct {
name string
current int
connected bool
want int
}{
{
name: "increments when connection attempt never established",
current: 2,
connected: false,
want: 3,
},
{
name: "resets streak after a registered session disconnects",
current: 5,
connected: true,
want: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := nextConsecutiveFailures(tt.current, tt.connected)
if got != tt.want {
t.Fatalf("nextConsecutiveFailures(%d, %v) = %d, want %d", tt.current, tt.connected, got, tt.want)
}
})
}
}