feat: GCloud GPU training pipeline + data collection + benchmarking

- gcloud-train.sh: L4/A100/H100 VM provisioning, Rust build, training
  with --cuda, artifact download, auto-cleanup ($0.80-$8.50/hr)
- training-config-sweep.json: 10 hyperparameter configs (LR, batch,
  backbone, windows, loss weights, warmup)
- collect-training-data.py: UDP listener for 2-node ESP32 CSI recording
  to .csi.jsonl with interactive/batch labeling and manifest generation
- benchmark-model.py: ONNX latency/throughput/PCK/FLOPs profiling with
  multi-model sweep comparison

Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
ruv 2026-04-02 22:04:57 -04:00
parent 9a2bc1839a
commit c63cf2ee77
4 changed files with 1657 additions and 0 deletions

550
scripts/benchmark-model.py Normal file
View file

@ -0,0 +1,550 @@
#!/usr/bin/env python3
"""
WiFi-DensePose Model Benchmarking
Loads trained ONNX models, runs inference on test data, and reports
performance metrics: latency, throughput, PCK@0.2, model size, and
estimated FLOPs.
Can compare multiple models from a hyperparameter sweep.
Usage:
# Benchmark a single model
python scripts/benchmark-model.py --model checkpoints/best.onnx
# Benchmark with recorded test data
python scripts/benchmark-model.py --model best.onnx --test-data data/recordings/test.csi.jsonl
# Compare models from a sweep
python scripts/benchmark-model.py --sweep-dir training-results/wdp-train-a100-*/checkpoints/
# Benchmark with synthetic data (no recordings needed)
python scripts/benchmark-model.py --model best.onnx --synthetic --num-samples 200
# Export results as JSON
python scripts/benchmark-model.py --model best.onnx --output results.json
Prerequisites:
pip install onnxruntime numpy
Optional: pip install onnx (for FLOPs estimation)
"""
from __future__ import annotations
import argparse
import json
import os
import sys
import time
from dataclasses import dataclass, field, asdict
from pathlib import Path
from typing import Optional
import numpy as np
try:
import onnxruntime as ort
except ImportError:
print("ERROR: onnxruntime not installed. Run: pip install onnxruntime")
sys.exit(1)
# ── Configuration ────────────────────────────────────────────────────────────
# Default model input shape (must match TrainingConfig defaults)
NUM_SUBCARRIERS = 56
NUM_ANTENNAS_TX = 3
NUM_ANTENNAS_RX = 3
WINDOW_FRAMES = 100
NUM_KEYPOINTS = 17
HEATMAP_SIZE = 56
# PCK threshold
PCK_THRESHOLD = 0.2
# ── Data classes ─────────────────────────────────────────────────────────────
@dataclass
class BenchmarkResult:
model_path: str
model_size_mb: float
num_parameters: Optional[int] = None
estimated_flops: Optional[int] = None
# Latency
warmup_runs: int = 10
benchmark_runs: int = 100
latency_mean_ms: float = 0.0
latency_std_ms: float = 0.0
latency_p50_ms: float = 0.0
latency_p95_ms: float = 0.0
latency_p99_ms: float = 0.0
throughput_fps: float = 0.0
# Accuracy (if ground truth available)
pck_at_02: Optional[float] = None
mean_per_joint_error: Optional[float] = None
num_test_samples: int = 0
# Input shape
input_shape: list = field(default_factory=list)
provider: str = ""
# ── ONNX model loading ──────────────────────────────────────────────────────
def load_model(model_path: str) -> ort.InferenceSession:
"""Load an ONNX model with the best available execution provider."""
providers = []
if "CUDAExecutionProvider" in ort.get_available_providers():
providers.append("CUDAExecutionProvider")
providers.append("CPUExecutionProvider")
sess_opts = ort.SessionOptions()
sess_opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_opts.intra_op_num_threads = os.cpu_count() or 4
session = ort.InferenceSession(model_path, sess_opts, providers=providers)
return session
def get_model_info(model_path: str) -> dict:
"""Extract model metadata: size, parameter count, FLOPs estimate."""
path = Path(model_path)
size_mb = path.stat().st_size / (1024 * 1024)
info = {
"size_mb": round(size_mb, 2),
"num_parameters": None,
"estimated_flops": None,
}
# Try to count parameters via onnx
try:
import onnx
model = onnx.load(model_path)
total_params = 0
for initializer in model.graph.initializer:
shape = list(initializer.dims)
if shape:
total_params += int(np.prod(shape))
info["num_parameters"] = total_params
# Rough FLOPs estimate: ~2 * params (multiply-accumulate)
info["estimated_flops"] = total_params * 2
except ImportError:
pass
except Exception as e:
print(f" Warning: Could not extract parameter count: {e}")
return info
# ── Synthetic data generation ────────────────────────────────────────────────
def generate_synthetic_input(
batch_size: int = 1,
num_subcarriers: int = NUM_SUBCARRIERS,
num_tx: int = NUM_ANTENNAS_TX,
num_rx: int = NUM_ANTENNAS_RX,
window_frames: int = WINDOW_FRAMES,
) -> np.ndarray:
"""Generate synthetic CSI input tensor matching the model's expected shape.
The WiFi-DensePose model expects input shape:
[batch, channels, height, width]
where channels = num_tx * num_rx, height = window_frames, width = num_subcarriers.
"""
channels = num_tx * num_rx # 3x3 = 9 MIMO streams
# Simulate CSI amplitude data with realistic distribution
rng = np.random.default_rng(42)
data = rng.normal(loc=0.0, scale=1.0, size=(batch_size, channels, window_frames, num_subcarriers))
return data.astype(np.float32)
def generate_synthetic_keypoints(
num_samples: int,
num_keypoints: int = NUM_KEYPOINTS,
heatmap_size: int = HEATMAP_SIZE,
) -> np.ndarray:
"""Generate synthetic ground truth keypoint coordinates for PCK evaluation."""
rng = np.random.default_rng(123)
# Keypoints as (x, y) in [0, heatmap_size) range
return rng.uniform(0, heatmap_size, size=(num_samples, num_keypoints, 2)).astype(np.float32)
# ── Load test data from .csi.jsonl ──────────────────────────────────────────
def load_test_data(
jsonl_path: str,
window_frames: int = WINDOW_FRAMES,
num_subcarriers: int = NUM_SUBCARRIERS,
max_samples: int = 500,
) -> np.ndarray:
"""Load CSI frames from a .csi.jsonl file and window them into model inputs."""
frames = []
path = Path(jsonl_path)
with open(path, "r") as f:
for line in f:
line = line.strip()
if not line:
continue
try:
record = json.loads(line)
subs = record.get("subcarriers", [])
if len(subs) > 0:
frames.append(subs)
except json.JSONDecodeError:
continue
if len(frames) < window_frames:
print(f" Warning: Only {len(frames)} frames, need {window_frames}. Padding with zeros.")
while len(frames) < window_frames:
frames.append([0.0] * num_subcarriers)
# Normalize subcarrier count
normalized = []
for frame in frames:
if len(frame) < num_subcarriers:
frame = frame + [0.0] * (num_subcarriers - len(frame))
elif len(frame) > num_subcarriers:
# Downsample via linear interpolation
indices = np.linspace(0, len(frame) - 1, num_subcarriers)
frame = np.interp(indices, range(len(frame)), frame).tolist()
normalized.append(frame)
frames = normalized
# Create sliding windows
samples = []
stride = max(1, window_frames // 2)
for i in range(0, len(frames) - window_frames + 1, stride):
window = frames[i : i + window_frames]
# Shape: [channels=1, window_frames, num_subcarriers]
# Expand single stream to 9 channels (repeat for MIMO)
arr = np.array(window, dtype=np.float32)
arr = np.expand_dims(arr, axis=0) # [1, window_frames, num_subcarriers]
arr = np.repeat(arr, NUM_ANTENNAS_TX * NUM_ANTENNAS_RX, axis=0) # [9, window, subs]
samples.append(arr)
if len(samples) >= max_samples:
break
if not samples:
return generate_synthetic_input(1)
return np.stack(samples, axis=0) # [N, 9, window_frames, num_subcarriers]
# ── Benchmarking ─────────────────────────────────────────────────────────────
def benchmark_latency(
session: ort.InferenceSession,
input_data: np.ndarray,
warmup: int = 10,
runs: int = 100,
) -> dict:
"""Measure inference latency over multiple runs."""
input_name = session.get_inputs()[0].name
# Warmup
for _ in range(warmup):
session.run(None, {input_name: input_data[:1]})
# Timed runs
latencies = []
for _ in range(runs):
start = time.perf_counter()
session.run(None, {input_name: input_data[:1]})
end = time.perf_counter()
latencies.append((end - start) * 1000) # ms
latencies = np.array(latencies)
return {
"mean_ms": float(np.mean(latencies)),
"std_ms": float(np.std(latencies)),
"p50_ms": float(np.percentile(latencies, 50)),
"p95_ms": float(np.percentile(latencies, 95)),
"p99_ms": float(np.percentile(latencies, 99)),
"throughput_fps": 1000.0 / float(np.mean(latencies)),
}
def compute_pck(
predictions: np.ndarray,
ground_truth: np.ndarray,
threshold: float = PCK_THRESHOLD,
normalize_by: float = HEATMAP_SIZE,
) -> float:
"""Compute Percentage of Correct Keypoints at a given threshold.
PCK@t = fraction of predicted keypoints within t * normalize_by of ground truth.
"""
if predictions.shape != ground_truth.shape:
return 0.0
# Euclidean distance per keypoint
distances = np.linalg.norm(predictions - ground_truth, axis=-1) # [N, K]
threshold_pixels = threshold * normalize_by
correct = (distances < threshold_pixels).astype(float)
return float(np.mean(correct))
def extract_keypoints_from_heatmaps(heatmaps: np.ndarray) -> np.ndarray:
"""Convert heatmap outputs [N, K, H, W] to keypoint coordinates [N, K, 2]."""
n, k, h, w = heatmaps.shape
flat = heatmaps.reshape(n, k, -1)
max_idx = np.argmax(flat, axis=-1) # [N, K]
y = max_idx // w
x = max_idx % w
return np.stack([x, y], axis=-1).astype(np.float32)
def benchmark_model(
model_path: str,
test_data: Optional[np.ndarray] = None,
gt_keypoints: Optional[np.ndarray] = None,
warmup: int = 10,
runs: int = 100,
) -> BenchmarkResult:
"""Run full benchmark on a single model."""
print(f"\nBenchmarking: {model_path}")
# Load model
session = load_model(model_path)
provider = session.get_providers()[0]
print(f" Provider: {provider}")
# Model info
model_info = get_model_info(model_path)
print(f" Size: {model_info['size_mb']} MB")
if model_info["num_parameters"]:
print(f" Parameters: {model_info['num_parameters']:,}")
if model_info["estimated_flops"]:
print(f" Estimated FLOPs: {model_info['estimated_flops']:,}")
# Input shape
input_meta = session.get_inputs()[0]
input_shape = input_meta.shape
print(f" Input: {input_meta.name} {input_shape} ({input_meta.type})")
# Output shapes
for out in session.get_outputs():
print(f" Output: {out.name} {out.shape}")
# Generate or use provided test data
if test_data is None:
# Infer shape from model
if input_shape and all(isinstance(d, int) for d in input_shape):
batch = max(1, input_shape[0] if input_shape[0] > 0 else 1)
test_data = np.random.randn(*[batch if d <= 0 else d for d in input_shape]).astype(np.float32)
else:
test_data = generate_synthetic_input(1)
# Latency benchmark
print(f" Running {warmup} warmup + {runs} benchmark iterations...")
latency = benchmark_latency(session, test_data, warmup=warmup, runs=runs)
print(f" Latency: {latency['mean_ms']:.2f} +/- {latency['std_ms']:.2f} ms")
print(f" P50/P95/P99: {latency['p50_ms']:.2f} / {latency['p95_ms']:.2f} / {latency['p99_ms']:.2f} ms")
print(f" Throughput: {latency['throughput_fps']:.1f} fps")
# Accuracy (if ground truth provided or we can do synthetic evaluation)
pck = None
mpjpe = None
num_samples = 0
if gt_keypoints is not None and test_data is not None:
input_name = session.get_inputs()[0].name
all_preds = []
for i in range(len(test_data)):
outputs = session.run(None, {input_name: test_data[i : i + 1]})
# Assume first output is keypoint heatmaps [1, K, H, W]
heatmaps = outputs[0]
if heatmaps.ndim == 4:
kp = extract_keypoints_from_heatmaps(heatmaps)
all_preds.append(kp[0])
if all_preds:
predictions = np.stack(all_preds)
gt = gt_keypoints[: len(predictions)]
pck = compute_pck(predictions, gt)
distances = np.linalg.norm(predictions - gt, axis=-1)
mpjpe = float(np.mean(distances))
num_samples = len(predictions)
print(f" PCK@{PCK_THRESHOLD}: {pck:.4f}")
print(f" MPJPE: {mpjpe:.2f} px")
print(f" Samples: {num_samples}")
result = BenchmarkResult(
model_path=model_path,
model_size_mb=model_info["size_mb"],
num_parameters=model_info["num_parameters"],
estimated_flops=model_info["estimated_flops"],
warmup_runs=warmup,
benchmark_runs=runs,
latency_mean_ms=round(latency["mean_ms"], 3),
latency_std_ms=round(latency["std_ms"], 3),
latency_p50_ms=round(latency["p50_ms"], 3),
latency_p95_ms=round(latency["p95_ms"], 3),
latency_p99_ms=round(latency["p99_ms"], 3),
throughput_fps=round(latency["throughput_fps"], 1),
pck_at_02=round(pck, 4) if pck is not None else None,
mean_per_joint_error=round(mpjpe, 2) if mpjpe is not None else None,
num_test_samples=num_samples,
input_shape=list(input_shape) if input_shape else [],
provider=provider,
)
return result
# ── Comparison table ─────────────────────────────────────────────────────────
def print_comparison_table(results: list[BenchmarkResult]):
"""Print a formatted comparison table of multiple models."""
if not results:
return
print("\n" + "=" * 100)
print(" Model Comparison")
print("=" * 100)
# Header
print(
f"{'Model':<35} {'Size(MB)':>8} {'Params':>10} "
f"{'Lat(ms)':>8} {'P95(ms)':>8} {'FPS':>7} {'PCK@0.2':>8}"
)
print("-" * 100)
for r in results:
name = Path(r.model_path).stem[:33]
params = f"{r.num_parameters:,}" if r.num_parameters else "?"
pck = f"{r.pck_at_02:.4f}" if r.pck_at_02 is not None else "N/A"
print(
f"{name:<35} {r.model_size_mb:>8.2f} {params:>10} "
f"{r.latency_mean_ms:>8.2f} {r.latency_p95_ms:>8.2f} "
f"{r.throughput_fps:>7.1f} {pck:>8}"
)
print("=" * 100)
# Best model by latency
best_latency = min(results, key=lambda r: r.latency_mean_ms)
print(f"\n Fastest: {Path(best_latency.model_path).stem} ({best_latency.latency_mean_ms:.2f} ms)")
# Best by PCK (if available)
pck_results = [r for r in results if r.pck_at_02 is not None]
if pck_results:
best_pck = max(pck_results, key=lambda r: r.pck_at_02)
print(f" Best accuracy: {Path(best_pck.model_path).stem} (PCK@0.2={best_pck.pck_at_02:.4f})")
# Smallest model
smallest = min(results, key=lambda r: r.model_size_mb)
print(f" Smallest: {Path(smallest.model_path).stem} ({smallest.model_size_mb:.2f} MB)")
# ── Main ─────────────────────────────────────────────────────────────────────
def main():
parser = argparse.ArgumentParser(
description="Benchmark WiFi-DensePose ONNX models",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument("--model", type=str, help="Path to a single ONNX model")
parser.add_argument("--sweep-dir", type=str, help="Directory containing multiple ONNX models to compare")
parser.add_argument("--test-data", type=str, help="Path to .csi.jsonl test data file")
parser.add_argument("--synthetic", action="store_true", help="Use synthetic test data")
parser.add_argument("--num-samples", type=int, default=100, help="Number of synthetic samples (default: 100)")
parser.add_argument("--warmup", type=int, default=10, help="Warmup iterations (default: 10)")
parser.add_argument("--runs", type=int, default=100, help="Benchmark iterations (default: 100)")
parser.add_argument("--output", type=str, help="Save results to JSON file")
parser.add_argument("--gpu", action="store_true", help="Force GPU execution provider")
args = parser.parse_args()
if not args.model and not args.sweep_dir:
parser.error("Specify --model or --sweep-dir")
# Prepare test data
test_data = None
gt_keypoints = None
if args.test_data:
print(f"Loading test data from: {args.test_data}")
test_data = load_test_data(args.test_data)
print(f" Loaded {len(test_data)} windowed samples")
elif args.synthetic:
print(f"Generating {args.num_samples} synthetic samples...")
test_data = generate_synthetic_input(args.num_samples)
gt_keypoints = generate_synthetic_keypoints(args.num_samples)
print(f" Input shape: {test_data.shape}")
# Collect models
model_paths = []
if args.model:
model_paths.append(args.model)
if args.sweep_dir:
sweep = Path(args.sweep_dir)
if sweep.is_dir():
model_paths.extend(sorted(str(p) for p in sweep.glob("**/*.onnx")))
else:
# Glob pattern
from glob import glob
model_paths.extend(sorted(glob(str(sweep))))
if not model_paths:
print("ERROR: No ONNX models found.")
sys.exit(1)
print(f"Found {len(model_paths)} model(s) to benchmark.")
# Benchmark each model
results = []
for path in model_paths:
if not Path(path).exists():
print(f" Skipping (not found): {path}")
continue
try:
result = benchmark_model(
path,
test_data=test_data,
gt_keypoints=gt_keypoints,
warmup=args.warmup,
runs=args.runs,
)
results.append(result)
except Exception as e:
print(f" ERROR benchmarking {path}: {e}")
# Comparison table
if len(results) > 1:
print_comparison_table(results)
# Save results
if args.output:
output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w") as f:
json.dump(
{
"benchmark_results": [asdict(r) for r in results],
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
"num_models": len(results),
},
f,
indent=2,
)
print(f"\nResults saved to: {output_path}")
if not results:
print("No models were successfully benchmarked.")
sys.exit(1)
if __name__ == "__main__":
main()

View file

@ -0,0 +1,483 @@
#!/usr/bin/env python3
"""
WiFi-DensePose Training Data Collector
Listens on UDP for CSI data from ESP32 nodes and records to .csi.jsonl
files compatible with the Rust training pipeline (MmFiDataset / CsiDataset).
Supports two packet formats:
- ADR-069 feature vectors (magic 0xC5110003, 48 bytes) 8-dim pre-extracted
- ADR-018 raw CSI frames (magic 0xC5110001, variable) full subcarrier data
Usage:
# Interactive — prompts for scenario labels
python scripts/collect-training-data.py --port 5006
# Scripted — fixed label, 60s per recording
python scripts/collect-training-data.py --port 5006 --label walking --duration 60
# Multiple scenarios in sequence
python scripts/collect-training-data.py --port 5006 --scenarios walking,standing,sitting --duration 30
# Dual-node collection (two ESP32s on different ports)
python scripts/collect-training-data.py --port 5005 --port2 5006 --label walking
# Generate manifest only from existing recordings
python scripts/collect-training-data.py --manifest-only --output-dir data/recordings
Prerequisites:
- ESP32 nodes streaming CSI on UDP (see firmware/esp32-csi-node)
- Python 3.9+
"""
from __future__ import annotations
import argparse
import json
import logging
import os
import socket
import struct
import sys
import time
import signal
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
datefmt="%H:%M:%S",
)
log = logging.getLogger("collect-data")
# ── Packet formats (must match firmware) ─────────────────────────────────────
# ADR-018 raw CSI frame header
MAGIC_CSI_RAW = 0xC5110001
# ADR-069 feature vector packet
MAGIC_FEATURES = 0xC5110003
FEATURE_PKT_FMT = "<IBBHq8f"
FEATURE_PKT_SIZE = struct.calcsize(FEATURE_PKT_FMT) # 48 bytes
# Raw CSI header: magic(4) + node_id(1) + antenna_cfg(1) + n_sub(2) + rssi(1) + noise(1) + channel(1) + reserved(1) + timestamp_ms(4)
RAW_CSI_HDR_FMT = "<IBBHbbBxI"
RAW_CSI_HDR_SIZE = struct.calcsize(RAW_CSI_HDR_FMT) # 16 bytes
# ── Packet parsing ───────────────────────────────────────────────────────────
def parse_packet(data: bytes) -> Optional[dict]:
"""Parse a UDP packet into a frame dict, or None if unrecognized."""
if len(data) < 4:
return None
magic = struct.unpack_from("<I", data)[0]
if magic == MAGIC_FEATURES and len(data) >= FEATURE_PKT_SIZE:
return _parse_feature_packet(data)
elif magic == MAGIC_CSI_RAW and len(data) >= RAW_CSI_HDR_SIZE:
return _parse_raw_csi_packet(data)
else:
return None
def _parse_feature_packet(data: bytes) -> Optional[dict]:
"""Parse ADR-069 feature vector packet (48 bytes)."""
try:
magic, node_id, _, seq, ts_us, *features = struct.unpack_from(FEATURE_PKT_FMT, data)
except struct.error:
return None
if magic != MAGIC_FEATURES:
return None
# Reject NaN/inf
import math
if any(math.isnan(f) or math.isinf(f) for f in features):
return None
return {
"type": "features",
"node_id": node_id,
"seq": seq,
"timestamp_us": ts_us,
"timestamp": ts_us / 1_000_000.0,
"features": features,
"subcarriers": features, # Use features as subcarrier proxy for training
"rssi": 0.0,
"noise_floor": 0.0,
}
def _parse_raw_csi_packet(data: bytes) -> Optional[dict]:
"""Parse ADR-018 raw CSI frame with full subcarrier data."""
try:
magic, node_id, ant_cfg, n_sub, rssi, noise, channel, ts_ms = struct.unpack_from(
RAW_CSI_HDR_FMT, data
)
except struct.error:
return None
if magic != MAGIC_CSI_RAW:
return None
# Subcarrier data follows header as int16 I/Q pairs
payload_offset = RAW_CSI_HDR_SIZE
expected_bytes = n_sub * 2 * 2 # n_sub * (I + Q) * int16
if len(data) < payload_offset + expected_bytes:
return None
iq_data = struct.unpack_from(f"<{n_sub * 2}h", data, payload_offset)
# Convert I/Q pairs to amplitude
subcarriers = []
for i in range(0, len(iq_data), 2):
real, imag = iq_data[i], iq_data[i + 1]
amplitude = (real ** 2 + imag ** 2) ** 0.5
subcarriers.append(amplitude)
return {
"type": "raw_csi",
"node_id": node_id,
"antenna_config": ant_cfg,
"n_subcarriers": n_sub,
"channel": channel,
"timestamp": ts_ms / 1000.0,
"subcarriers": subcarriers,
"rssi": float(rssi),
"noise_floor": float(noise),
}
# ── JSONL recording ──────────────────────────────────────────────────────────
class CsiRecorder:
"""Records CSI frames to .csi.jsonl files compatible with the Rust pipeline."""
def __init__(self, output_dir: str, session_name: str, label: Optional[str] = None):
self.output_dir = Path(output_dir)
self.output_dir.mkdir(parents=True, exist_ok=True)
ts = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
safe_name = session_name.replace(" ", "_").replace("/", "_")
self.session_id = f"{safe_name}-{ts}"
self.label = label
self.file_path = self.output_dir / f"{self.session_id}.csi.jsonl"
self.meta_path = self.output_dir / f"{self.session_id}.csi.meta.json"
self.frame_count = 0
self.start_time = time.time()
self.started_at = datetime.now(timezone.utc).isoformat()
self._file = None
def open(self):
self._file = open(self.file_path, "a", encoding="utf-8")
log.info(f"Recording to: {self.file_path}")
def write_frame(self, frame: dict):
"""Write a single frame as a JSONL line."""
if self._file is None:
return
record = {
"timestamp": frame.get("timestamp", time.time()),
"subcarriers": frame.get("subcarriers", []),
"rssi": frame.get("rssi", 0.0),
"noise_floor": frame.get("noise_floor", 0.0),
"features": {
k: v for k, v in frame.items()
if k not in ("timestamp", "subcarriers", "rssi", "noise_floor", "type")
},
}
line = json.dumps(record, separators=(",", ":"))
self._file.write(line + "\n")
self.frame_count += 1
if self.frame_count % 500 == 0:
self._file.flush()
def close(self) -> dict:
"""Close the recording and write metadata. Returns session info."""
if self._file:
self._file.flush()
self._file.close()
self._file = None
ended_at = datetime.now(timezone.utc).isoformat()
elapsed = time.time() - self.start_time
file_size = self.file_path.stat().st_size if self.file_path.exists() else 0
meta = {
"id": self.session_id,
"name": self.session_id,
"label": self.label,
"started_at": self.started_at,
"ended_at": ended_at,
"duration_secs": round(elapsed, 2),
"frame_count": self.frame_count,
"file_size_bytes": file_size,
"file_path": str(self.file_path),
"fps": round(self.frame_count / elapsed, 1) if elapsed > 0 else 0,
}
with open(self.meta_path, "w", encoding="utf-8") as f:
json.dump(meta, f, indent=2)
log.info(
f"Recording stopped: {self.frame_count} frames in {elapsed:.1f}s "
f"({meta['fps']} fps, {file_size / 1024:.1f} KB)"
)
return meta
# ── Manifest generation ──────────────────────────────────────────────────────
def generate_manifest(output_dir: str) -> dict:
"""Scan recordings directory and generate a dataset manifest JSON."""
rec_dir = Path(output_dir)
sessions = []
for meta_file in sorted(rec_dir.glob("*.csi.meta.json")):
try:
with open(meta_file, "r") as f:
meta = json.load(f)
sessions.append(meta)
except (json.JSONDecodeError, OSError) as e:
log.warning(f"Skipping {meta_file}: {e}")
# Aggregate stats
total_frames = sum(s.get("frame_count", 0) for s in sessions)
total_bytes = sum(s.get("file_size_bytes", 0) for s in sessions)
labels = sorted(set(s.get("label", "unlabeled") or "unlabeled" for s in sessions))
manifest = {
"dataset": "wifi-densepose-csi",
"generated_at": datetime.now(timezone.utc).isoformat(),
"directory": str(rec_dir),
"num_sessions": len(sessions),
"total_frames": total_frames,
"total_size_bytes": total_bytes,
"total_size_mb": round(total_bytes / (1024 * 1024), 2),
"labels": labels,
"sessions": sessions,
}
manifest_path = rec_dir / "manifest.json"
with open(manifest_path, "w", encoding="utf-8") as f:
json.dump(manifest, f, indent=2)
log.info(
f"Manifest: {len(sessions)} sessions, {total_frames} frames, "
f"{manifest['total_size_mb']} MB, labels={labels}"
)
log.info(f"Written to: {manifest_path}")
return manifest
# ── UDP listener ─────────────────────────────────────────────────────────────
def collect_session(
port: int,
port2: Optional[int],
output_dir: str,
label: str,
duration: float,
session_name: Optional[str] = None,
) -> dict:
"""Run a single collection session. Returns session metadata."""
name = session_name or label or "session"
recorder = CsiRecorder(output_dir, name, label)
recorder.open()
# Bind primary socket
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(("0.0.0.0", port))
sock.settimeout(1.0)
sockets = [sock]
# Bind secondary socket if specified
if port2:
sock2 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock2.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock2.bind(("0.0.0.0", port2))
sock2.settimeout(0.1)
sockets.append(sock2)
log.info(
f"Collecting '{label}' for {duration}s on port(s) "
f"{port}{f', {port2}' if port2 else ''}"
)
start = time.time()
dropped = 0
try:
while time.time() - start < duration:
for s in sockets:
try:
data, addr = s.recvfrom(4096)
except socket.timeout:
continue
frame = parse_packet(data)
if frame:
recorder.write_frame(frame)
else:
dropped += 1
# Progress update every 5s
elapsed = time.time() - start
if recorder.frame_count > 0 and int(elapsed) % 5 == 0 and int(elapsed) > 0:
remaining = duration - elapsed
if remaining > 0 and int(elapsed * 10) % 50 == 0:
log.info(
f" {recorder.frame_count} frames collected, "
f"{remaining:.0f}s remaining..."
)
except KeyboardInterrupt:
log.info("Interrupted by user.")
finally:
for s in sockets:
s.close()
if dropped > 0:
log.warning(f" {dropped} unrecognized packets dropped")
return recorder.close()
# ── Main ─────────────────────────────────────────────────────────────────────
def main():
parser = argparse.ArgumentParser(
description="Collect CSI training data from ESP32 nodes via UDP",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Interactive label input
python scripts/collect-training-data.py --port 5006
# Fixed label, 60 seconds
python scripts/collect-training-data.py --port 5006 --label walking --duration 60
# Multiple scenarios
python scripts/collect-training-data.py --port 5006 --scenarios walking,standing,sitting --duration 30
# Dual ESP32 nodes
python scripts/collect-training-data.py --port 5005 --port2 5006 --label test
# Generate manifest from existing recordings
python scripts/collect-training-data.py --manifest-only
""",
)
parser.add_argument("--port", type=int, default=5006, help="Primary UDP port (default: 5006)")
parser.add_argument("--port2", type=int, default=None, help="Secondary UDP port for dual-node")
parser.add_argument("--output-dir", default="data/recordings", help="Output directory (default: data/recordings)")
parser.add_argument("--label", default=None, help="Activity label for the recording")
parser.add_argument("--duration", type=float, default=30.0, help="Recording duration in seconds (default: 30)")
parser.add_argument("--scenarios", default=None, help="Comma-separated list of scenarios to record sequentially")
parser.add_argument("--pause", type=float, default=5.0, help="Pause between scenarios in seconds (default: 5)")
parser.add_argument("--manifest-only", action="store_true", help="Only generate manifest from existing recordings")
parser.add_argument("--repeats", type=int, default=1, help="Number of repeats per scenario (default: 1)")
args = parser.parse_args()
# Manifest-only mode
if args.manifest_only:
generate_manifest(args.output_dir)
return
# Collect scenarios
all_sessions = []
if args.scenarios:
# Multi-scenario sequential collection
scenarios = [s.strip() for s in args.scenarios.split(",") if s.strip()]
total = len(scenarios) * args.repeats
idx = 0
for repeat in range(args.repeats):
for scenario in scenarios:
idx += 1
print(f"\n{'='*60}")
print(f" Scenario {idx}/{total}: '{scenario}' (repeat {repeat+1}/{args.repeats})")
print(f" Duration: {args.duration}s")
print(f"{'='*60}")
if idx > 1:
print(f" Starting in {args.pause}s... (get into position)")
time.sleep(args.pause)
meta = collect_session(
port=args.port,
port2=args.port2,
output_dir=args.output_dir,
label=scenario,
duration=args.duration,
session_name=f"{scenario}_r{repeat+1:02d}",
)
all_sessions.append(meta)
elif args.label:
# Single labeled recording
meta = collect_session(
port=args.port,
port2=args.port2,
output_dir=args.output_dir,
label=args.label,
duration=args.duration,
)
all_sessions.append(meta)
else:
# Interactive mode — prompt for labels
print("\nInteractive data collection mode.")
print("Type a label for each recording, or 'q' to quit.\n")
while True:
label = input("Label (or 'q' to quit): ").strip()
if label.lower() in ("q", "quit", "exit"):
break
if not label:
print(" Empty label. Try again.")
continue
duration = args.duration
try:
dur_input = input(f"Duration in seconds [{duration}]: ").strip()
if dur_input:
duration = float(dur_input)
except ValueError:
pass
print(f" Recording '{label}' for {duration}s — starting now...")
meta = collect_session(
port=args.port,
port2=args.port2,
output_dir=args.output_dir,
label=label,
duration=duration,
)
all_sessions.append(meta)
print()
# Generate manifest
if all_sessions:
print(f"\nCollected {len(all_sessions)} session(s).")
manifest = generate_manifest(args.output_dir)
total_frames = sum(s.get("frame_count", 0) for s in all_sessions)
print(f"\nSummary:")
print(f" Sessions: {len(all_sessions)}")
print(f" Total frames: {total_frames}")
print(f" Output: {args.output_dir}/")
print(f" Manifest: {args.output_dir}/manifest.json")
else:
print("No sessions recorded.")
if __name__ == "__main__":
main()

469
scripts/gcloud-train.sh Normal file
View file

@ -0,0 +1,469 @@
#!/bin/bash
# ==============================================================================
# GCloud GPU Training Script for WiFi-DensePose
# ==============================================================================
#
# Creates a GCloud VM with GPU, runs the Rust training pipeline, downloads
# the trained model artifacts, and tears down the VM to avoid ongoing costs.
#
# Usage:
# bash scripts/gcloud-train.sh [OPTIONS]
#
# Options:
# --gpu l4|a100|h100 GPU type (default: l4)
# --zone ZONE GCloud zone (default: us-central1-a)
# --hours N Max VM lifetime in hours (default: 2)
# --config FILE Training config JSON (default: scripts/training-config-sweep.json entry 0)
# --data-dir DIR Local data directory to upload (default: data/recordings)
# --dry-run Run smoke test with synthetic data
# --sweep Run full hyperparameter sweep (all configs)
# --keep-vm Do not delete VM after training
# --instance NAME Custom VM instance name
#
# Prerequisites:
# - gcloud CLI authenticated: gcloud auth login
# - Project set: gcloud config set project cognitum-20260110
# - Quota for GPUs in the selected zone
#
# Cost estimates:
# L4 (~$0.80/hr) — good for prototyping and small sweeps
# A100 40GB (~$3.60/hr) — full training runs
# H100 80GB (~$11.00/hr) — large batch / fast iteration
# ==============================================================================
set -euo pipefail
# ── Defaults ──────────────────────────────────────────────────────────────────
PROJECT="cognitum-20260110"
GPU_TYPE="l4"
ZONE="us-central1-a"
MAX_HOURS=2
CONFIG_FILE=""
DATA_DIR="data/recordings"
DRY_RUN=false
SWEEP=false
KEEP_VM=false
INSTANCE_NAME=""
REPO_URL="https://github.com/ruvnet/wifi-densepose.git"
BRANCH="main"
# ── Parse arguments ───────────────────────────────────────────────────────────
while [[ $# -gt 0 ]]; do
case "$1" in
--gpu) GPU_TYPE="$2"; shift 2 ;;
--zone) ZONE="$2"; shift 2 ;;
--hours) MAX_HOURS="$2"; shift 2 ;;
--config) CONFIG_FILE="$2"; shift 2 ;;
--data-dir) DATA_DIR="$2"; shift 2 ;;
--dry-run) DRY_RUN=true; shift ;;
--sweep) SWEEP=true; shift ;;
--keep-vm) KEEP_VM=true; shift ;;
--instance) INSTANCE_NAME="$2"; shift 2 ;;
--branch) BRANCH="$2"; shift 2 ;;
-h|--help)
head -35 "$0" | tail -30
exit 0
;;
*)
echo "ERROR: Unknown option: $1"
exit 1
;;
esac
done
# ── GPU configuration map ────────────────────────────────────────────────────
declare -A GPU_ACCELERATOR=(
[l4]="nvidia-l4"
[a100]="nvidia-tesla-a100"
[h100]="nvidia-h100-80gb"
)
declare -A GPU_MACHINE_TYPE=(
[l4]="g2-standard-8"
[a100]="a2-highgpu-1g"
[h100]="a3-highgpu-1g"
)
declare -A GPU_BOOT_DISK=(
[l4]="200"
[a100]="300"
[h100]="300"
)
if [[ -z "${GPU_ACCELERATOR[$GPU_TYPE]+x}" ]]; then
echo "ERROR: Unknown GPU type '$GPU_TYPE'. Choose: l4, a100, h100"
exit 1
fi
ACCELERATOR="${GPU_ACCELERATOR[$GPU_TYPE]}"
MACHINE_TYPE="${GPU_MACHINE_TYPE[$GPU_TYPE]}"
BOOT_DISK_GB="${GPU_BOOT_DISK[$GPU_TYPE]}"
# ── Instance naming ──────────────────────────────────────────────────────────
TIMESTAMP=$(date +%Y%m%d-%H%M%S)
if [[ -z "$INSTANCE_NAME" ]]; then
INSTANCE_NAME="wdp-train-${GPU_TYPE}-${TIMESTAMP}"
fi
# ── Announce plan ────────────────────────────────────────────────────────────
echo "============================================================"
echo " WiFi-DensePose GCloud GPU Training"
echo "============================================================"
echo " Project: $PROJECT"
echo " Instance: $INSTANCE_NAME"
echo " Zone: $ZONE"
echo " GPU: $GPU_TYPE ($ACCELERATOR)"
echo " Machine: $MACHINE_TYPE"
echo " Boot disk: ${BOOT_DISK_GB}GB"
echo " Max runtime: ${MAX_HOURS}h"
echo " Data dir: $DATA_DIR"
echo " Dry run: $DRY_RUN"
echo " Sweep: $SWEEP"
echo " Branch: $BRANCH"
echo "============================================================"
echo ""
# ── Verify gcloud auth ──────────────────────────────────────────────────────
if ! gcloud auth list --filter=status:ACTIVE --format="value(account)" 2>/dev/null | head -1 | grep -q '@'; then
echo "ERROR: No active gcloud account. Run: gcloud auth login"
exit 1
fi
gcloud config set project "$PROJECT" --quiet
# ── Build startup script ─────────────────────────────────────────────────────
STARTUP_SCRIPT=$(cat <<'STARTUP_EOF'
#!/bin/bash
set -euo pipefail
exec > /var/log/wdp-setup.log 2>&1
echo "=== WiFi-DensePose GPU VM Setup ==="
echo "Started: $(date)"
# Wait for GPU driver
echo "Waiting for NVIDIA driver..."
for i in $(seq 1 60); do
if nvidia-smi &>/dev/null; then
echo "GPU ready after ${i}s"
nvidia-smi
break
fi
sleep 5
done
if ! nvidia-smi &>/dev/null; then
echo "ERROR: GPU driver not available after 300s"
exit 1
fi
# Install Rust toolchain
echo "Installing Rust toolchain..."
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable
source "$HOME/.cargo/env"
rustc --version
cargo --version
# Install system dependencies
echo "Installing system dependencies..."
apt-get update -qq
apt-get install -y -qq pkg-config libssl-dev cmake clang
# Find libtorch from the Deep Learning VM's PyTorch installation
echo "Locating libtorch..."
PYTORCH_LIB=$(python3 -c "import torch; print(torch.__path__[0] + '/lib')" 2>/dev/null || echo "")
if [[ -n "$PYTORCH_LIB" && -d "$PYTORCH_LIB" ]]; then
export LIBTORCH="$PYTORCH_LIB"
export LD_LIBRARY_PATH="${LIBTORCH}:${LD_LIBRARY_PATH:-}"
echo "Found libtorch at: $LIBTORCH"
else
echo "WARNING: PyTorch not found in system Python. Installing via pip..."
pip3 install torch --index-url https://download.pytorch.org/whl/cu121
PYTORCH_LIB=$(python3 -c "import torch; print(torch.__path__[0] + '/lib')")
export LIBTORCH="$PYTORCH_LIB"
export LD_LIBRARY_PATH="${LIBTORCH}:${LD_LIBRARY_PATH:-}"
fi
# Persist env vars
cat >> /etc/environment <<ENV_VARS
LIBTORCH=$LIBTORCH
LD_LIBRARY_PATH=$LIBTORCH:\$LD_LIBRARY_PATH
PATH=$HOME/.cargo/bin:\$PATH
ENV_VARS
echo "=== Setup complete: $(date) ==="
touch /tmp/wdp-setup-done
STARTUP_EOF
)
# ── Step 1: Create the VM ────────────────────────────────────────────────────
echo "[1/7] Creating VM instance: $INSTANCE_NAME ..."
gcloud compute instances create "$INSTANCE_NAME" \
--project="$PROJECT" \
--zone="$ZONE" \
--machine-type="$MACHINE_TYPE" \
--accelerator="type=$ACCELERATOR,count=1" \
--image-family="common-cu121-ubuntu-2204" \
--image-project="deeplearning-platform-release" \
--boot-disk-size="${BOOT_DISK_GB}GB" \
--boot-disk-type="pd-ssd" \
--maintenance-policy=TERMINATE \
--metadata="install-nvidia-driver=True" \
--metadata-from-file="startup-script=<(echo "$STARTUP_SCRIPT")" \
--scopes="default,storage-rw" \
--labels="purpose=wdp-training,gpu=${GPU_TYPE}" \
--quiet
echo " VM created. Waiting for startup script to complete..."
# ── Step 2: Wait for setup ───────────────────────────────────────────────────
echo "[2/7] Waiting for setup to complete (GPU driver + Rust toolchain)..."
for i in $(seq 1 60); do
if gcloud compute ssh "$INSTANCE_NAME" --zone="$ZONE" --command="test -f /tmp/wdp-setup-done" --quiet 2>/dev/null; then
echo " Setup complete after $((i * 15))s"
break
fi
if [[ $i -eq 60 ]]; then
echo "ERROR: Setup timed out after 15 minutes."
echo "Check logs: gcloud compute ssh $INSTANCE_NAME --zone=$ZONE --command='cat /var/log/wdp-setup.log'"
if [[ "$KEEP_VM" == "false" ]]; then
echo "Cleaning up VM..."
gcloud compute instances delete "$INSTANCE_NAME" --zone="$ZONE" --quiet
fi
exit 1
fi
sleep 15
done
# ── Step 3: Clone repo and build ─────────────────────────────────────────────
echo "[3/7] Cloning repository and building training binary..."
gcloud compute ssh "$INSTANCE_NAME" --zone="$ZONE" --command="$(cat <<CLONE_EOF
set -euo pipefail
source \$HOME/.cargo/env
# Clone the repo
if [[ ! -d ~/wifi-densepose ]]; then
git clone --depth 1 --branch "$BRANCH" "$REPO_URL" ~/wifi-densepose
fi
# Set libtorch environment
export LIBTORCH=\$(python3 -c "import torch; print(torch.__path__[0] + '/lib')")
export LD_LIBRARY_PATH="\${LIBTORCH}:\${LD_LIBRARY_PATH:-}"
# Build the training binary with tch-backend
cd ~/wifi-densepose/rust-port/wifi-densepose-rs
echo "Building with LIBTORCH=\$LIBTORCH ..."
cargo build --release --features tch-backend --bin train 2>&1 | tail -5
echo "Build complete."
ls -lh target/release/train
CLONE_EOF
)"
# ── Step 4: Upload training data ─────────────────────────────────────────────
echo "[4/7] Uploading training data..."
if [[ -d "$DATA_DIR" ]] && [[ "$(ls -A "$DATA_DIR" 2>/dev/null)" ]]; then
# Create a tarball of the data directory
DATA_TAR="/tmp/wdp-training-data-${TIMESTAMP}.tar.gz"
tar czf "$DATA_TAR" -C "$(dirname "$DATA_DIR")" "$(basename "$DATA_DIR")"
DATA_SIZE=$(du -h "$DATA_TAR" | cut -f1)
echo " Uploading ${DATA_SIZE} of training data..."
gcloud compute scp "$DATA_TAR" "${INSTANCE_NAME}:~/training-data.tar.gz" --zone="$ZONE" --quiet
gcloud compute ssh "$INSTANCE_NAME" --zone="$ZONE" --command="
mkdir -p ~/wifi-densepose/data
tar xzf ~/training-data.tar.gz -C ~/wifi-densepose/data/
echo 'Data extracted:'
find ~/wifi-densepose/data -name '*.jsonl' -o -name '*.csi.jsonl' | head -20
"
rm -f "$DATA_TAR"
else
echo " No local data at '$DATA_DIR'. Training will use --dry-run or MM-Fi."
if [[ "$DRY_RUN" == "false" && "$SWEEP" == "false" ]]; then
echo " WARNING: No data and --dry-run not set. Forcing --dry-run."
DRY_RUN=true
fi
fi
# ── Step 5: Upload config and run training ────────────────────────────────────
echo "[5/7] Running training..."
# Upload sweep config if doing a sweep
if [[ "$SWEEP" == "true" ]]; then
SWEEP_FILE="scripts/training-config-sweep.json"
if [[ -f "$SWEEP_FILE" ]]; then
gcloud compute scp "$SWEEP_FILE" "${INSTANCE_NAME}:~/sweep-configs.json" --zone="$ZONE" --quiet
else
echo "ERROR: Sweep config not found at $SWEEP_FILE"
exit 1
fi
fi
# Upload single config if specified
if [[ -n "$CONFIG_FILE" ]]; then
gcloud compute scp "$CONFIG_FILE" "${INSTANCE_NAME}:~/train-config.json" --zone="$ZONE" --quiet
fi
# Build the training command
TRAIN_CMD_BASE="
set -euo pipefail
source \$HOME/.cargo/env
export LIBTORCH=\$(python3 -c \"import torch; print(torch.__path__[0] + '/lib')\")
export LD_LIBRARY_PATH=\"\${LIBTORCH}:\${LD_LIBRARY_PATH:-}\"
cd ~/wifi-densepose/rust-port/wifi-densepose-rs
# Set auto-shutdown timer (safety net)
sudo shutdown -P +$((MAX_HOURS * 60)) &
TRAIN_BIN=./target/release/train
"
if [[ "$SWEEP" == "true" ]]; then
# Run all configs in the sweep file
gcloud compute ssh "$INSTANCE_NAME" --zone="$ZONE" --command="$(cat <<SWEEP_EOF
$TRAIN_CMD_BASE
echo "=== Hyperparameter Sweep ==="
SWEEP_FILE=~/sweep-configs.json
NUM_CONFIGS=\$(python3 -c "import json; print(len(json.load(open('\$SWEEP_FILE'))['configs']))")
echo "Running \$NUM_CONFIGS configurations..."
mkdir -p ~/results
for i in \$(seq 0 \$((NUM_CONFIGS - 1))); do
echo ""
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
echo " Config \$((i+1)) / \$NUM_CONFIGS"
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
# Extract single config to temp file
python3 -c "
import json, sys
sweep = json.load(open('\$SWEEP_FILE'))
cfg = sweep['configs'][\$i]
# Merge with base config
base = sweep.get('base', {})
merged = {**base, **cfg}
# Set checkpoint dir per config
merged['checkpoint_dir'] = f'checkpoints/sweep_{i:02d}'
merged['log_dir'] = f'logs/sweep_{i:02d}'
json.dump(merged, open('/tmp/sweep_config_\${i}.json', 'w'), indent=2)
print(f\"Config \${i}: lr={merged.get('learning_rate', '?')}, bs={merged.get('batch_size', '?')}, bb={merged.get('backbone_channels', '?')}\")
"
START_TIME=\$(date +%s)
\$TRAIN_BIN --config /tmp/sweep_config_\${i}.json --cuda $( [[ "$DRY_RUN" == "true" ]] && echo "--dry-run" ) 2>&1 | tee ~/results/sweep_\${i}.log || true
END_TIME=\$(date +%s)
ELAPSED=\$(( END_TIME - START_TIME ))
echo " Completed in \${ELAPSED}s"
done
echo ""
echo "=== Sweep Complete ==="
echo "Results in ~/results/"
ls -lh ~/results/
SWEEP_EOF
)"
elif [[ -n "$CONFIG_FILE" ]]; then
# Single config run
gcloud compute ssh "$INSTANCE_NAME" --zone="$ZONE" --command="$(cat <<SINGLE_EOF
$TRAIN_CMD_BASE
echo "=== Training with custom config ==="
\$TRAIN_BIN --config ~/train-config.json --cuda $( [[ "$DRY_RUN" == "true" ]] && echo "--dry-run" ) 2>&1 | tee ~/train.log
SINGLE_EOF
)"
else
# Default config run
gcloud compute ssh "$INSTANCE_NAME" --zone="$ZONE" --command="$(cat <<DEFAULT_EOF
$TRAIN_CMD_BASE
echo "=== Training with default config ==="
\$TRAIN_BIN --cuda $( [[ "$DRY_RUN" == "true" ]] && echo "--dry-run --dry-run-samples 256" ) 2>&1 | tee ~/train.log
DEFAULT_EOF
)"
fi
# ── Step 6: Download results ─────────────────────────────────────────────────
echo "[6/7] Downloading trained model artifacts..."
LOCAL_RESULTS="training-results/${INSTANCE_NAME}"
mkdir -p "$LOCAL_RESULTS"
# Package results on the VM
gcloud compute ssh "$INSTANCE_NAME" --zone="$ZONE" --command="
cd ~/wifi-densepose/rust-port/wifi-densepose-rs
tar czf ~/training-artifacts.tar.gz \
checkpoints/ \
logs/ \
2>/dev/null || true
# Also grab sweep results if they exist
if [[ -d ~/results ]]; then
tar czf ~/sweep-results.tar.gz -C ~ results/ 2>/dev/null || true
fi
ls -lh ~/training-artifacts.tar.gz ~/sweep-results.tar.gz 2>/dev/null || true
"
# Download artifacts
gcloud compute scp "${INSTANCE_NAME}:~/training-artifacts.tar.gz" \
"${LOCAL_RESULTS}/training-artifacts.tar.gz" --zone="$ZONE" --quiet 2>/dev/null || true
if [[ "$SWEEP" == "true" ]]; then
gcloud compute scp "${INSTANCE_NAME}:~/sweep-results.tar.gz" \
"${LOCAL_RESULTS}/sweep-results.tar.gz" --zone="$ZONE" --quiet 2>/dev/null || true
fi
# Download training log
gcloud compute scp "${INSTANCE_NAME}:~/train.log" \
"${LOCAL_RESULTS}/train.log" --zone="$ZONE" --quiet 2>/dev/null || true
# Extract locally
if [[ -f "${LOCAL_RESULTS}/training-artifacts.tar.gz" ]]; then
tar xzf "${LOCAL_RESULTS}/training-artifacts.tar.gz" -C "$LOCAL_RESULTS/"
echo " Artifacts extracted to: $LOCAL_RESULTS/"
find "$LOCAL_RESULTS" -name "*.pt" -o -name "*.onnx" -o -name "*.rvf" 2>/dev/null | head -20
fi
# ── Step 7: Cleanup ──────────────────────────────────────────────────────────
if [[ "$KEEP_VM" == "true" ]]; then
echo "[7/7] Keeping VM alive (--keep-vm). Remember to delete it manually:"
echo " gcloud compute instances delete $INSTANCE_NAME --zone=$ZONE --quiet"
echo " SSH: gcloud compute ssh $INSTANCE_NAME --zone=$ZONE"
else
echo "[7/7] Deleting VM to avoid ongoing costs..."
gcloud compute instances delete "$INSTANCE_NAME" --zone="$ZONE" --quiet
echo " VM deleted."
fi
# ── Summary ──────────────────────────────────────────────────────────────────
echo ""
echo "============================================================"
echo " Training Complete"
echo "============================================================"
echo " Results: $LOCAL_RESULTS/"
echo " GPU: $GPU_TYPE ($ZONE)"
echo " Instance: $INSTANCE_NAME"
if [[ "$KEEP_VM" == "true" ]]; then
echo " VM: STILL RUNNING (delete manually!)"
fi
echo "============================================================"

View file

@ -0,0 +1,155 @@
{
"description": "WiFi-DensePose hyperparameter sweep — 10 configurations exploring learning rate, batch size, backbone width, window length, loss ratios, and warmup schedules.",
"base": {
"num_subcarriers": 56,
"native_subcarriers": 114,
"num_antennas_tx": 3,
"num_antennas_rx": 3,
"heatmap_size": 56,
"num_keypoints": 17,
"num_body_parts": 24,
"weight_decay": 1e-4,
"num_epochs": 50,
"lr_gamma": 0.1,
"grad_clip_norm": 1.0,
"val_every_epochs": 1,
"early_stopping_patience": 10,
"save_top_k": 3,
"use_gpu": true,
"gpu_device_id": 0,
"num_workers": 4,
"seed": 42
},
"configs": [
{
"_name": "baseline",
"_description": "Default config — reference baseline",
"learning_rate": 1e-3,
"batch_size": 8,
"backbone_channels": 256,
"window_frames": 100,
"warmup_epochs": 5,
"lr_milestones": [30, 45],
"lambda_kp": 0.3,
"lambda_dp": 0.6,
"lambda_tr": 0.1
},
{
"_name": "low_lr_large_batch",
"_description": "Lower LR with larger batch — stable convergence",
"learning_rate": 1e-4,
"batch_size": 16,
"backbone_channels": 256,
"window_frames": 100,
"warmup_epochs": 10,
"lr_milestones": [30, 45],
"lambda_kp": 0.3,
"lambda_dp": 0.6,
"lambda_tr": 0.1
},
{
"_name": "high_lr_small_batch",
"_description": "Higher LR with small batch — fast exploration",
"learning_rate": 2e-3,
"batch_size": 4,
"backbone_channels": 256,
"window_frames": 100,
"warmup_epochs": 3,
"lr_milestones": [20, 40],
"lambda_kp": 0.3,
"lambda_dp": 0.6,
"lambda_tr": 0.1
},
{
"_name": "narrow_backbone",
"_description": "128-channel backbone — faster training, lower VRAM",
"learning_rate": 1e-3,
"batch_size": 16,
"backbone_channels": 128,
"window_frames": 100,
"warmup_epochs": 5,
"lr_milestones": [30, 45],
"lambda_kp": 0.3,
"lambda_dp": 0.6,
"lambda_tr": 0.1
},
{
"_name": "short_window",
"_description": "50-frame window — lower latency, tests temporal sensitivity",
"learning_rate": 5e-4,
"batch_size": 16,
"backbone_channels": 256,
"window_frames": 50,
"warmup_epochs": 5,
"lr_milestones": [30, 45],
"lambda_kp": 0.3,
"lambda_dp": 0.6,
"lambda_tr": 0.1
},
{
"_name": "keypoint_heavy",
"_description": "Heavier keypoint loss — prioritize skeleton accuracy",
"learning_rate": 5e-4,
"batch_size": 8,
"backbone_channels": 256,
"window_frames": 100,
"warmup_epochs": 5,
"lr_milestones": [30, 45],
"lambda_kp": 0.5,
"lambda_dp": 0.4,
"lambda_tr": 0.1
},
{
"_name": "contrastive_heavy",
"_description": "Strong contrastive/transfer loss — self-supervised pretraining focus",
"learning_rate": 5e-4,
"batch_size": 8,
"backbone_channels": 256,
"window_frames": 100,
"warmup_epochs": 10,
"lr_milestones": [30, 45],
"lambda_kp": 0.2,
"lambda_dp": 0.3,
"lambda_tr": 0.5
},
{
"_name": "wide_backbone_long_warmup",
"_description": "256-ch backbone + long warmup + moderate LR",
"learning_rate": 5e-4,
"batch_size": 8,
"backbone_channels": 256,
"window_frames": 100,
"warmup_epochs": 10,
"lr_milestones": [35, 48],
"lambda_kp": 0.3,
"lambda_dp": 0.6,
"lambda_tr": 0.1
},
{
"_name": "narrow_short_aggressive",
"_description": "128-ch + 50-frame + high LR — fast cheap exploration",
"learning_rate": 2e-3,
"batch_size": 16,
"backbone_channels": 128,
"window_frames": 50,
"warmup_epochs": 3,
"lr_milestones": [20, 40],
"lambda_kp": 0.4,
"lambda_dp": 0.5,
"lambda_tr": 0.1
},
{
"_name": "balanced_medium",
"_description": "Balanced loss, medium LR, medium batch — robust default",
"learning_rate": 5e-4,
"batch_size": 8,
"backbone_channels": 256,
"window_frames": 100,
"warmup_epochs": 5,
"lr_milestones": [25, 40],
"lambda_kp": 0.35,
"lambda_dp": 0.45,
"lambda_tr": 0.2
}
]
}