- gcloud-train.sh: L4/A100/H100 VM provisioning, Rust build, training with --cuda, artifact download, auto-cleanup ($0.80-$11.00/hr)
+ +Usage: + # Benchmark a single model + python scripts/benchmark-model.py --model checkpoints/best.onnx + + # Benchmark with recorded test data + python scripts/benchmark-model.py --model best.onnx --test-data data/recordings/test.csi.jsonl + + # Compare models from a sweep + python scripts/benchmark-model.py --sweep-dir training-results/wdp-train-a100-*/checkpoints/ + + # Benchmark with synthetic data (no recordings needed) + python scripts/benchmark-model.py --model best.onnx --synthetic --num-samples 200 + + # Export results as JSON + python scripts/benchmark-model.py --model best.onnx --output results.json + +Prerequisites: + pip install onnxruntime numpy + Optional: pip install onnx (for FLOPs estimation) +""" + +from __future__ import annotations + +import argparse +import json +import os +import sys +import time +from dataclasses import dataclass, field, asdict +from pathlib import Path +from typing import Optional + +import numpy as np + +try: + import onnxruntime as ort +except ImportError: + print("ERROR: onnxruntime not installed. 
+# PCK correctness threshold: a predicted keypoint counts as correct when it
+# lies within (threshold * normalization size) pixels of the ground truth
+PCK_THRESHOLD = 0.2
"num_parameters": None, + "estimated_flops": None, + } + + # Try to count parameters via onnx + try: + import onnx + model = onnx.load(model_path) + total_params = 0 + for initializer in model.graph.initializer: + shape = list(initializer.dims) + if shape: + total_params += int(np.prod(shape)) + info["num_parameters"] = total_params + + # Rough FLOPs estimate: ~2 * params (multiply-accumulate) + info["estimated_flops"] = total_params * 2 + except ImportError: + pass + except Exception as e: + print(f" Warning: Could not extract parameter count: {e}") + + return info + + +# ── Synthetic data generation ──────────────────────────────────────────────── + +def generate_synthetic_input( + batch_size: int = 1, + num_subcarriers: int = NUM_SUBCARRIERS, + num_tx: int = NUM_ANTENNAS_TX, + num_rx: int = NUM_ANTENNAS_RX, + window_frames: int = WINDOW_FRAMES, +) -> np.ndarray: + """Generate synthetic CSI input tensor matching the model's expected shape. + + The WiFi-DensePose model expects input shape: + [batch, channels, height, width] + where channels = num_tx * num_rx, height = window_frames, width = num_subcarriers. 
+ """ + channels = num_tx * num_rx # 3x3 = 9 MIMO streams + # Simulate CSI amplitude data with realistic distribution + rng = np.random.default_rng(42) + data = rng.normal(loc=0.0, scale=1.0, size=(batch_size, channels, window_frames, num_subcarriers)) + return data.astype(np.float32) + + +def generate_synthetic_keypoints( + num_samples: int, + num_keypoints: int = NUM_KEYPOINTS, + heatmap_size: int = HEATMAP_SIZE, +) -> np.ndarray: + """Generate synthetic ground truth keypoint coordinates for PCK evaluation.""" + rng = np.random.default_rng(123) + # Keypoints as (x, y) in [0, heatmap_size) range + return rng.uniform(0, heatmap_size, size=(num_samples, num_keypoints, 2)).astype(np.float32) + + +# ── Load test data from .csi.jsonl ────────────────────────────────────────── + +def load_test_data( + jsonl_path: str, + window_frames: int = WINDOW_FRAMES, + num_subcarriers: int = NUM_SUBCARRIERS, + max_samples: int = 500, +) -> np.ndarray: + """Load CSI frames from a .csi.jsonl file and window them into model inputs.""" + frames = [] + path = Path(jsonl_path) + + with open(path, "r") as f: + for line in f: + line = line.strip() + if not line: + continue + try: + record = json.loads(line) + subs = record.get("subcarriers", []) + if len(subs) > 0: + frames.append(subs) + except json.JSONDecodeError: + continue + + if len(frames) < window_frames: + print(f" Warning: Only {len(frames)} frames, need {window_frames}. 
Padding with zeros.") + while len(frames) < window_frames: + frames.append([0.0] * num_subcarriers) + + # Normalize subcarrier count + normalized = [] + for frame in frames: + if len(frame) < num_subcarriers: + frame = frame + [0.0] * (num_subcarriers - len(frame)) + elif len(frame) > num_subcarriers: + # Downsample via linear interpolation + indices = np.linspace(0, len(frame) - 1, num_subcarriers) + frame = np.interp(indices, range(len(frame)), frame).tolist() + normalized.append(frame) + + frames = normalized + + # Create sliding windows + samples = [] + stride = max(1, window_frames // 2) + for i in range(0, len(frames) - window_frames + 1, stride): + window = frames[i : i + window_frames] + # Shape: [channels=1, window_frames, num_subcarriers] + # Expand single stream to 9 channels (repeat for MIMO) + arr = np.array(window, dtype=np.float32) + arr = np.expand_dims(arr, axis=0) # [1, window_frames, num_subcarriers] + arr = np.repeat(arr, NUM_ANTENNAS_TX * NUM_ANTENNAS_RX, axis=0) # [9, window, subs] + samples.append(arr) + + if len(samples) >= max_samples: + break + + if not samples: + return generate_synthetic_input(1) + + return np.stack(samples, axis=0) # [N, 9, window_frames, num_subcarriers] + + +# ── Benchmarking ───────────────────────────────────────────────────────────── + +def benchmark_latency( + session: ort.InferenceSession, + input_data: np.ndarray, + warmup: int = 10, + runs: int = 100, +) -> dict: + """Measure inference latency over multiple runs.""" + input_name = session.get_inputs()[0].name + + # Warmup + for _ in range(warmup): + session.run(None, {input_name: input_data[:1]}) + + # Timed runs + latencies = [] + for _ in range(runs): + start = time.perf_counter() + session.run(None, {input_name: input_data[:1]}) + end = time.perf_counter() + latencies.append((end - start) * 1000) # ms + + latencies = np.array(latencies) + return { + "mean_ms": float(np.mean(latencies)), + "std_ms": float(np.std(latencies)), + "p50_ms": 
float(np.percentile(latencies, 50)), + "p95_ms": float(np.percentile(latencies, 95)), + "p99_ms": float(np.percentile(latencies, 99)), + "throughput_fps": 1000.0 / float(np.mean(latencies)), + } + + +def compute_pck( + predictions: np.ndarray, + ground_truth: np.ndarray, + threshold: float = PCK_THRESHOLD, + normalize_by: float = HEATMAP_SIZE, +) -> float: + """Compute Percentage of Correct Keypoints at a given threshold. + + PCK@t = fraction of predicted keypoints within t * normalize_by of ground truth. + """ + if predictions.shape != ground_truth.shape: + return 0.0 + + # Euclidean distance per keypoint + distances = np.linalg.norm(predictions - ground_truth, axis=-1) # [N, K] + threshold_pixels = threshold * normalize_by + correct = (distances < threshold_pixels).astype(float) + return float(np.mean(correct)) + + +def extract_keypoints_from_heatmaps(heatmaps: np.ndarray) -> np.ndarray: + """Convert heatmap outputs [N, K, H, W] to keypoint coordinates [N, K, 2].""" + n, k, h, w = heatmaps.shape + flat = heatmaps.reshape(n, k, -1) + max_idx = np.argmax(flat, axis=-1) # [N, K] + y = max_idx // w + x = max_idx % w + return np.stack([x, y], axis=-1).astype(np.float32) + + +def benchmark_model( + model_path: str, + test_data: Optional[np.ndarray] = None, + gt_keypoints: Optional[np.ndarray] = None, + warmup: int = 10, + runs: int = 100, +) -> BenchmarkResult: + """Run full benchmark on a single model.""" + print(f"\nBenchmarking: {model_path}") + + # Load model + session = load_model(model_path) + provider = session.get_providers()[0] + print(f" Provider: {provider}") + + # Model info + model_info = get_model_info(model_path) + print(f" Size: {model_info['size_mb']} MB") + if model_info["num_parameters"]: + print(f" Parameters: {model_info['num_parameters']:,}") + if model_info["estimated_flops"]: + print(f" Estimated FLOPs: {model_info['estimated_flops']:,}") + + # Input shape + input_meta = session.get_inputs()[0] + input_shape = input_meta.shape + print(f" Input: 
{input_meta.name} {input_shape} ({input_meta.type})") + + # Output shapes + for out in session.get_outputs(): + print(f" Output: {out.name} {out.shape}") + + # Generate or use provided test data + if test_data is None: + # Infer shape from model + if input_shape and all(isinstance(d, int) for d in input_shape): + batch = max(1, input_shape[0] if input_shape[0] > 0 else 1) + test_data = np.random.randn(*[batch if d <= 0 else d for d in input_shape]).astype(np.float32) + else: + test_data = generate_synthetic_input(1) + + # Latency benchmark + print(f" Running {warmup} warmup + {runs} benchmark iterations...") + latency = benchmark_latency(session, test_data, warmup=warmup, runs=runs) + print(f" Latency: {latency['mean_ms']:.2f} +/- {latency['std_ms']:.2f} ms") + print(f" P50/P95/P99: {latency['p50_ms']:.2f} / {latency['p95_ms']:.2f} / {latency['p99_ms']:.2f} ms") + print(f" Throughput: {latency['throughput_fps']:.1f} fps") + + # Accuracy (if ground truth provided or we can do synthetic evaluation) + pck = None + mpjpe = None + num_samples = 0 + + if gt_keypoints is not None and test_data is not None: + input_name = session.get_inputs()[0].name + all_preds = [] + + for i in range(len(test_data)): + outputs = session.run(None, {input_name: test_data[i : i + 1]}) + # Assume first output is keypoint heatmaps [1, K, H, W] + heatmaps = outputs[0] + if heatmaps.ndim == 4: + kp = extract_keypoints_from_heatmaps(heatmaps) + all_preds.append(kp[0]) + + if all_preds: + predictions = np.stack(all_preds) + gt = gt_keypoints[: len(predictions)] + pck = compute_pck(predictions, gt) + distances = np.linalg.norm(predictions - gt, axis=-1) + mpjpe = float(np.mean(distances)) + num_samples = len(predictions) + print(f" PCK@{PCK_THRESHOLD}: {pck:.4f}") + print(f" MPJPE: {mpjpe:.2f} px") + print(f" Samples: {num_samples}") + + result = BenchmarkResult( + model_path=model_path, + model_size_mb=model_info["size_mb"], + num_parameters=model_info["num_parameters"], + 
estimated_flops=model_info["estimated_flops"], + warmup_runs=warmup, + benchmark_runs=runs, + latency_mean_ms=round(latency["mean_ms"], 3), + latency_std_ms=round(latency["std_ms"], 3), + latency_p50_ms=round(latency["p50_ms"], 3), + latency_p95_ms=round(latency["p95_ms"], 3), + latency_p99_ms=round(latency["p99_ms"], 3), + throughput_fps=round(latency["throughput_fps"], 1), + pck_at_02=round(pck, 4) if pck is not None else None, + mean_per_joint_error=round(mpjpe, 2) if mpjpe is not None else None, + num_test_samples=num_samples, + input_shape=list(input_shape) if input_shape else [], + provider=provider, + ) + + return result + + +# ── Comparison table ───────────────────────────────────────────────────────── + +def print_comparison_table(results: list[BenchmarkResult]): + """Print a formatted comparison table of multiple models.""" + if not results: + return + + print("\n" + "=" * 100) + print(" Model Comparison") + print("=" * 100) + + # Header + print( + f"{'Model':<35} {'Size(MB)':>8} {'Params':>10} " + f"{'Lat(ms)':>8} {'P95(ms)':>8} {'FPS':>7} {'PCK@0.2':>8}" + ) + print("-" * 100) + + for r in results: + name = Path(r.model_path).stem[:33] + params = f"{r.num_parameters:,}" if r.num_parameters else "?" 
+ pck = f"{r.pck_at_02:.4f}" if r.pck_at_02 is not None else "N/A" + + print( + f"{name:<35} {r.model_size_mb:>8.2f} {params:>10} " + f"{r.latency_mean_ms:>8.2f} {r.latency_p95_ms:>8.2f} " + f"{r.throughput_fps:>7.1f} {pck:>8}" + ) + + print("=" * 100) + + # Best model by latency + best_latency = min(results, key=lambda r: r.latency_mean_ms) + print(f"\n Fastest: {Path(best_latency.model_path).stem} ({best_latency.latency_mean_ms:.2f} ms)") + + # Best by PCK (if available) + pck_results = [r for r in results if r.pck_at_02 is not None] + if pck_results: + best_pck = max(pck_results, key=lambda r: r.pck_at_02) + print(f" Best accuracy: {Path(best_pck.model_path).stem} (PCK@0.2={best_pck.pck_at_02:.4f})") + + # Smallest model + smallest = min(results, key=lambda r: r.model_size_mb) + print(f" Smallest: {Path(smallest.model_path).stem} ({smallest.model_size_mb:.2f} MB)") + + +# ── Main ───────────────────────────────────────────────────────────────────── + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark WiFi-DensePose ONNX models", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument("--model", type=str, help="Path to a single ONNX model") + parser.add_argument("--sweep-dir", type=str, help="Directory containing multiple ONNX models to compare") + parser.add_argument("--test-data", type=str, help="Path to .csi.jsonl test data file") + parser.add_argument("--synthetic", action="store_true", help="Use synthetic test data") + parser.add_argument("--num-samples", type=int, default=100, help="Number of synthetic samples (default: 100)") + parser.add_argument("--warmup", type=int, default=10, help="Warmup iterations (default: 10)") + parser.add_argument("--runs", type=int, default=100, help="Benchmark iterations (default: 100)") + parser.add_argument("--output", type=str, help="Save results to JSON file") + parser.add_argument("--gpu", action="store_true", help="Force GPU execution provider") + + args = 
parser.parse_args() + + if not args.model and not args.sweep_dir: + parser.error("Specify --model or --sweep-dir") + + # Prepare test data + test_data = None + gt_keypoints = None + + if args.test_data: + print(f"Loading test data from: {args.test_data}") + test_data = load_test_data(args.test_data) + print(f" Loaded {len(test_data)} windowed samples") + elif args.synthetic: + print(f"Generating {args.num_samples} synthetic samples...") + test_data = generate_synthetic_input(args.num_samples) + gt_keypoints = generate_synthetic_keypoints(args.num_samples) + print(f" Input shape: {test_data.shape}") + + # Collect models + model_paths = [] + if args.model: + model_paths.append(args.model) + if args.sweep_dir: + sweep = Path(args.sweep_dir) + if sweep.is_dir(): + model_paths.extend(sorted(str(p) for p in sweep.glob("**/*.onnx"))) + else: + # Glob pattern + from glob import glob + model_paths.extend(sorted(glob(str(sweep)))) + + if not model_paths: + print("ERROR: No ONNX models found.") + sys.exit(1) + + print(f"Found {len(model_paths)} model(s) to benchmark.") + + # Benchmark each model + results = [] + for path in model_paths: + if not Path(path).exists(): + print(f" Skipping (not found): {path}") + continue + try: + result = benchmark_model( + path, + test_data=test_data, + gt_keypoints=gt_keypoints, + warmup=args.warmup, + runs=args.runs, + ) + results.append(result) + except Exception as e: + print(f" ERROR benchmarking {path}: {e}") + + # Comparison table + if len(results) > 1: + print_comparison_table(results) + + # Save results + if args.output: + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w") as f: + json.dump( + { + "benchmark_results": [asdict(r) for r in results], + "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), + "num_models": len(results), + }, + f, + indent=2, + ) + print(f"\nResults saved to: {output_path}") + + if not results: + print("No models were successfully 
benchmarked.") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/scripts/collect-training-data.py b/scripts/collect-training-data.py new file mode 100644 index 00000000..6b2e4d6b --- /dev/null +++ b/scripts/collect-training-data.py @@ -0,0 +1,483 @@ +#!/usr/bin/env python3 +""" +WiFi-DensePose Training Data Collector + +Listens on UDP for CSI data from ESP32 nodes and records to .csi.jsonl +files compatible with the Rust training pipeline (MmFiDataset / CsiDataset). + +Supports two packet formats: + - ADR-069 feature vectors (magic 0xC5110003, 48 bytes) — 8-dim pre-extracted + - ADR-018 raw CSI frames (magic 0xC5110001, variable) — full subcarrier data + +Usage: + # Interactive — prompts for scenario labels + python scripts/collect-training-data.py --port 5006 + + # Scripted — fixed label, 60s per recording + python scripts/collect-training-data.py --port 5006 --label walking --duration 60 + + # Multiple scenarios in sequence + python scripts/collect-training-data.py --port 5006 --scenarios walking,standing,sitting --duration 30 + + # Dual-node collection (two ESP32s on different ports) + python scripts/collect-training-data.py --port 5005 --port2 5006 --label walking + + # Generate manifest only from existing recordings + python scripts/collect-training-data.py --manifest-only --output-dir data/recordings + +Prerequisites: + - ESP32 nodes streaming CSI on UDP (see firmware/esp32-csi-node) + - Python 3.9+ +""" + +from __future__ import annotations + +import argparse +import json +import logging +import os +import socket +import struct +import sys +import time +import signal +from datetime import datetime, timezone +from pathlib import Path +from typing import Optional + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + datefmt="%H:%M:%S", +) +log = logging.getLogger("collect-data") + +# ── Packet formats (must match firmware) ───────────────────────────────────── + +# ADR-018 raw CSI frame header 
+MAGIC_CSI_RAW = 0xC5110001 +# ADR-069 feature vector packet +MAGIC_FEATURES = 0xC5110003 +FEATURE_PKT_FMT = " Optional[dict]: + """Parse a UDP packet into a frame dict, or None if unrecognized.""" + if len(data) < 4: + return None + + magic = struct.unpack_from("= FEATURE_PKT_SIZE: + return _parse_feature_packet(data) + elif magic == MAGIC_CSI_RAW and len(data) >= RAW_CSI_HDR_SIZE: + return _parse_raw_csi_packet(data) + else: + return None + + +def _parse_feature_packet(data: bytes) -> Optional[dict]: + """Parse ADR-069 feature vector packet (48 bytes).""" + try: + magic, node_id, _, seq, ts_us, *features = struct.unpack_from(FEATURE_PKT_FMT, data) + except struct.error: + return None + + if magic != MAGIC_FEATURES: + return None + + # Reject NaN/inf + import math + if any(math.isnan(f) or math.isinf(f) for f in features): + return None + + return { + "type": "features", + "node_id": node_id, + "seq": seq, + "timestamp_us": ts_us, + "timestamp": ts_us / 1_000_000.0, + "features": features, + "subcarriers": features, # Use features as subcarrier proxy for training + "rssi": 0.0, + "noise_floor": 0.0, + } + + +def _parse_raw_csi_packet(data: bytes) -> Optional[dict]: + """Parse ADR-018 raw CSI frame with full subcarrier data.""" + try: + magic, node_id, ant_cfg, n_sub, rssi, noise, channel, ts_ms = struct.unpack_from( + RAW_CSI_HDR_FMT, data + ) + except struct.error: + return None + + if magic != MAGIC_CSI_RAW: + return None + + # Subcarrier data follows header as int16 I/Q pairs + payload_offset = RAW_CSI_HDR_SIZE + expected_bytes = n_sub * 2 * 2 # n_sub * (I + Q) * int16 + if len(data) < payload_offset + expected_bytes: + return None + + iq_data = struct.unpack_from(f"<{n_sub * 2}h", data, payload_offset) + # Convert I/Q pairs to amplitude + subcarriers = [] + for i in range(0, len(iq_data), 2): + real, imag = iq_data[i], iq_data[i + 1] + amplitude = (real ** 2 + imag ** 2) ** 0.5 + subcarriers.append(amplitude) + + return { + "type": "raw_csi", + "node_id": 
node_id, + "antenna_config": ant_cfg, + "n_subcarriers": n_sub, + "channel": channel, + "timestamp": ts_ms / 1000.0, + "subcarriers": subcarriers, + "rssi": float(rssi), + "noise_floor": float(noise), + } + + +# ── JSONL recording ────────────────────────────────────────────────────────── + +class CsiRecorder: + """Records CSI frames to .csi.jsonl files compatible with the Rust pipeline.""" + + def __init__(self, output_dir: str, session_name: str, label: Optional[str] = None): + self.output_dir = Path(output_dir) + self.output_dir.mkdir(parents=True, exist_ok=True) + + ts = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") + safe_name = session_name.replace(" ", "_").replace("/", "_") + self.session_id = f"{safe_name}-{ts}" + self.label = label + self.file_path = self.output_dir / f"{self.session_id}.csi.jsonl" + self.meta_path = self.output_dir / f"{self.session_id}.csi.meta.json" + self.frame_count = 0 + self.start_time = time.time() + self.started_at = datetime.now(timezone.utc).isoformat() + self._file = None + + def open(self): + self._file = open(self.file_path, "a", encoding="utf-8") + log.info(f"Recording to: {self.file_path}") + + def write_frame(self, frame: dict): + """Write a single frame as a JSONL line.""" + if self._file is None: + return + + record = { + "timestamp": frame.get("timestamp", time.time()), + "subcarriers": frame.get("subcarriers", []), + "rssi": frame.get("rssi", 0.0), + "noise_floor": frame.get("noise_floor", 0.0), + "features": { + k: v for k, v in frame.items() + if k not in ("timestamp", "subcarriers", "rssi", "noise_floor", "type") + }, + } + + line = json.dumps(record, separators=(",", ":")) + self._file.write(line + "\n") + self.frame_count += 1 + + if self.frame_count % 500 == 0: + self._file.flush() + + def close(self) -> dict: + """Close the recording and write metadata. 
Returns session info.""" + if self._file: + self._file.flush() + self._file.close() + self._file = None + + ended_at = datetime.now(timezone.utc).isoformat() + elapsed = time.time() - self.start_time + file_size = self.file_path.stat().st_size if self.file_path.exists() else 0 + + meta = { + "id": self.session_id, + "name": self.session_id, + "label": self.label, + "started_at": self.started_at, + "ended_at": ended_at, + "duration_secs": round(elapsed, 2), + "frame_count": self.frame_count, + "file_size_bytes": file_size, + "file_path": str(self.file_path), + "fps": round(self.frame_count / elapsed, 1) if elapsed > 0 else 0, + } + + with open(self.meta_path, "w", encoding="utf-8") as f: + json.dump(meta, f, indent=2) + + log.info( + f"Recording stopped: {self.frame_count} frames in {elapsed:.1f}s " + f"({meta['fps']} fps, {file_size / 1024:.1f} KB)" + ) + return meta + + +# ── Manifest generation ────────────────────────────────────────────────────── + +def generate_manifest(output_dir: str) -> dict: + """Scan recordings directory and generate a dataset manifest JSON.""" + rec_dir = Path(output_dir) + sessions = [] + + for meta_file in sorted(rec_dir.glob("*.csi.meta.json")): + try: + with open(meta_file, "r") as f: + meta = json.load(f) + sessions.append(meta) + except (json.JSONDecodeError, OSError) as e: + log.warning(f"Skipping {meta_file}: {e}") + + # Aggregate stats + total_frames = sum(s.get("frame_count", 0) for s in sessions) + total_bytes = sum(s.get("file_size_bytes", 0) for s in sessions) + labels = sorted(set(s.get("label", "unlabeled") or "unlabeled" for s in sessions)) + + manifest = { + "dataset": "wifi-densepose-csi", + "generated_at": datetime.now(timezone.utc).isoformat(), + "directory": str(rec_dir), + "num_sessions": len(sessions), + "total_frames": total_frames, + "total_size_bytes": total_bytes, + "total_size_mb": round(total_bytes / (1024 * 1024), 2), + "labels": labels, + "sessions": sessions, + } + + manifest_path = rec_dir / 
"manifest.json" + with open(manifest_path, "w", encoding="utf-8") as f: + json.dump(manifest, f, indent=2) + + log.info( + f"Manifest: {len(sessions)} sessions, {total_frames} frames, " + f"{manifest['total_size_mb']} MB, labels={labels}" + ) + log.info(f"Written to: {manifest_path}") + return manifest + + +# ── UDP listener ───────────────────────────────────────────────────────────── + +def collect_session( + port: int, + port2: Optional[int], + output_dir: str, + label: str, + duration: float, + session_name: Optional[str] = None, +) -> dict: + """Run a single collection session. Returns session metadata.""" + name = session_name or label or "session" + recorder = CsiRecorder(output_dir, name, label) + recorder.open() + + # Bind primary socket + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(("0.0.0.0", port)) + sock.settimeout(1.0) + sockets = [sock] + + # Bind secondary socket if specified + if port2: + sock2 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock2.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock2.bind(("0.0.0.0", port2)) + sock2.settimeout(0.1) + sockets.append(sock2) + + log.info( + f"Collecting '{label}' for {duration}s on port(s) " + f"{port}{f', {port2}' if port2 else ''}" + ) + + start = time.time() + dropped = 0 + + try: + while time.time() - start < duration: + for s in sockets: + try: + data, addr = s.recvfrom(4096) + except socket.timeout: + continue + + frame = parse_packet(data) + if frame: + recorder.write_frame(frame) + else: + dropped += 1 + + # Progress update every 5s + elapsed = time.time() - start + if recorder.frame_count > 0 and int(elapsed) % 5 == 0 and int(elapsed) > 0: + remaining = duration - elapsed + if remaining > 0 and int(elapsed * 10) % 50 == 0: + log.info( + f" {recorder.frame_count} frames collected, " + f"{remaining:.0f}s remaining..." 
+ ) + except KeyboardInterrupt: + log.info("Interrupted by user.") + finally: + for s in sockets: + s.close() + + if dropped > 0: + log.warning(f" {dropped} unrecognized packets dropped") + + return recorder.close() + + +# ── Main ───────────────────────────────────────────────────────────────────── + +def main(): + parser = argparse.ArgumentParser( + description="Collect CSI training data from ESP32 nodes via UDP", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Interactive label input + python scripts/collect-training-data.py --port 5006 + + # Fixed label, 60 seconds + python scripts/collect-training-data.py --port 5006 --label walking --duration 60 + + # Multiple scenarios + python scripts/collect-training-data.py --port 5006 --scenarios walking,standing,sitting --duration 30 + + # Dual ESP32 nodes + python scripts/collect-training-data.py --port 5005 --port2 5006 --label test + + # Generate manifest from existing recordings + python scripts/collect-training-data.py --manifest-only +""", + ) + + parser.add_argument("--port", type=int, default=5006, help="Primary UDP port (default: 5006)") + parser.add_argument("--port2", type=int, default=None, help="Secondary UDP port for dual-node") + parser.add_argument("--output-dir", default="data/recordings", help="Output directory (default: data/recordings)") + parser.add_argument("--label", default=None, help="Activity label for the recording") + parser.add_argument("--duration", type=float, default=30.0, help="Recording duration in seconds (default: 30)") + parser.add_argument("--scenarios", default=None, help="Comma-separated list of scenarios to record sequentially") + parser.add_argument("--pause", type=float, default=5.0, help="Pause between scenarios in seconds (default: 5)") + parser.add_argument("--manifest-only", action="store_true", help="Only generate manifest from existing recordings") + parser.add_argument("--repeats", type=int, default=1, help="Number of repeats per 
scenario (default: 1)") + + args = parser.parse_args() + + # Manifest-only mode + if args.manifest_only: + generate_manifest(args.output_dir) + return + + # Collect scenarios + all_sessions = [] + + if args.scenarios: + # Multi-scenario sequential collection + scenarios = [s.strip() for s in args.scenarios.split(",") if s.strip()] + total = len(scenarios) * args.repeats + idx = 0 + + for repeat in range(args.repeats): + for scenario in scenarios: + idx += 1 + print(f"\n{'='*60}") + print(f" Scenario {idx}/{total}: '{scenario}' (repeat {repeat+1}/{args.repeats})") + print(f" Duration: {args.duration}s") + print(f"{'='*60}") + + if idx > 1: + print(f" Starting in {args.pause}s... (get into position)") + time.sleep(args.pause) + + meta = collect_session( + port=args.port, + port2=args.port2, + output_dir=args.output_dir, + label=scenario, + duration=args.duration, + session_name=f"{scenario}_r{repeat+1:02d}", + ) + all_sessions.append(meta) + + elif args.label: + # Single labeled recording + meta = collect_session( + port=args.port, + port2=args.port2, + output_dir=args.output_dir, + label=args.label, + duration=args.duration, + ) + all_sessions.append(meta) + + else: + # Interactive mode — prompt for labels + print("\nInteractive data collection mode.") + print("Type a label for each recording, or 'q' to quit.\n") + + while True: + label = input("Label (or 'q' to quit): ").strip() + if label.lower() in ("q", "quit", "exit"): + break + if not label: + print(" Empty label. 
Try again.") + continue + + duration = args.duration + try: + dur_input = input(f"Duration in seconds [{duration}]: ").strip() + if dur_input: + duration = float(dur_input) + except ValueError: + pass + + print(f" Recording '{label}' for {duration}s — starting now...") + meta = collect_session( + port=args.port, + port2=args.port2, + output_dir=args.output_dir, + label=label, + duration=duration, + ) + all_sessions.append(meta) + print() + + # Generate manifest + if all_sessions: + print(f"\nCollected {len(all_sessions)} session(s).") + manifest = generate_manifest(args.output_dir) + + total_frames = sum(s.get("frame_count", 0) for s in all_sessions) + print(f"\nSummary:") + print(f" Sessions: {len(all_sessions)}") + print(f" Total frames: {total_frames}") + print(f" Output: {args.output_dir}/") + print(f" Manifest: {args.output_dir}/manifest.json") + else: + print("No sessions recorded.") + + +if __name__ == "__main__": + main() diff --git a/scripts/gcloud-train.sh b/scripts/gcloud-train.sh new file mode 100644 index 00000000..f7bb0e35 --- /dev/null +++ b/scripts/gcloud-train.sh @@ -0,0 +1,469 @@ +#!/bin/bash +# ============================================================================== +# GCloud GPU Training Script for WiFi-DensePose +# ============================================================================== +# +# Creates a GCloud VM with GPU, runs the Rust training pipeline, downloads +# the trained model artifacts, and tears down the VM to avoid ongoing costs. 
+
#
# Usage:
#   bash scripts/gcloud-train.sh [OPTIONS]
#
# Options:
#   --gpu l4|a100|h100   GPU type (default: l4)
#   --zone ZONE          GCloud zone (default: us-central1-a)
#   --hours N            Max VM lifetime in hours (default: 2)
#   --config FILE        Training config JSON (default: none — trainer built-in defaults; sweep file is only used with --sweep)
#   --data-dir DIR       Local data directory to upload (default: data/recordings)
#   --dry-run            Run smoke test with synthetic data
#   --sweep              Run full hyperparameter sweep (all configs)
#   --keep-vm            Do not delete VM after training
#   --instance NAME      Custom VM instance name; --branch NAME selects the git branch to build (default: main)
#
# Prerequisites:
#   - gcloud CLI authenticated: gcloud auth login
#   - Project set: gcloud config set project cognitum-20260110
#   - Quota for GPUs in the selected zone
#
# Cost estimates:
#   L4 (~$0.80/hr)         — good for prototyping and small sweeps
#   A100 40GB (~$3.60/hr)  — full training runs
#   H100 80GB (~$11.00/hr) — large batch / fast iteration
# ==============================================================================

set -euo pipefail

# ── Defaults ────────────────────────────────────────────────────────────────

PROJECT="cognitum-20260110"
GPU_TYPE="l4"
ZONE="us-central1-a"
MAX_HOURS=2
CONFIG_FILE=""
DATA_DIR="data/recordings"
DRY_RUN=false
SWEEP=false
KEEP_VM=false
INSTANCE_NAME=""
REPO_URL="https://github.com/ruvnet/wifi-densepose.git"
BRANCH="main"

# ── Parse arguments ─────────────────────────────────────────────────────────

while [[ $# -gt 0 ]]; do
  case "$1" in
    --gpu) GPU_TYPE="$2"; shift 2 ;;
    --zone) ZONE="$2"; shift 2 ;;
    --hours) MAX_HOURS="$2"; shift 2 ;;
    --config) CONFIG_FILE="$2"; shift 2 ;;
    --data-dir) DATA_DIR="$2"; shift 2 ;;
    --dry-run) DRY_RUN=true; shift ;;
    --sweep) SWEEP=true; shift ;;
    --keep-vm) KEEP_VM=true; shift ;;
    --instance) INSTANCE_NAME="$2"; shift 2 ;;
    --branch) BRANCH="$2"; shift 2 ;;
    -h|--help)
      head -35 "$0" | tail -30
      exit 0
      ;;
    *)
      echo "ERROR: Unknown option: $1"
      exit 1
      ;;
  esac
done

 
+# ── GPU configuration map ──────────────────────────────────────────────────── + +declare -A GPU_ACCELERATOR=( + [l4]="nvidia-l4" + [a100]="nvidia-tesla-a100" + [h100]="nvidia-h100-80gb" +) + +declare -A GPU_MACHINE_TYPE=( + [l4]="g2-standard-8" + [a100]="a2-highgpu-1g" + [h100]="a3-highgpu-1g" +) + +declare -A GPU_BOOT_DISK=( + [l4]="200" + [a100]="300" + [h100]="300" +) + +if [[ -z "${GPU_ACCELERATOR[$GPU_TYPE]+x}" ]]; then + echo "ERROR: Unknown GPU type '$GPU_TYPE'. Choose: l4, a100, h100" + exit 1 +fi + +ACCELERATOR="${GPU_ACCELERATOR[$GPU_TYPE]}" +MACHINE_TYPE="${GPU_MACHINE_TYPE[$GPU_TYPE]}" +BOOT_DISK_GB="${GPU_BOOT_DISK[$GPU_TYPE]}" + +# ── Instance naming ────────────────────────────────────────────────────────── + +TIMESTAMP=$(date +%Y%m%d-%H%M%S) +if [[ -z "$INSTANCE_NAME" ]]; then + INSTANCE_NAME="wdp-train-${GPU_TYPE}-${TIMESTAMP}" +fi + +# ── Announce plan ──────────────────────────────────────────────────────────── + +echo "============================================================" +echo " WiFi-DensePose GCloud GPU Training" +echo "============================================================" +echo " Project: $PROJECT" +echo " Instance: $INSTANCE_NAME" +echo " Zone: $ZONE" +echo " GPU: $GPU_TYPE ($ACCELERATOR)" +echo " Machine: $MACHINE_TYPE" +echo " Boot disk: ${BOOT_DISK_GB}GB" +echo " Max runtime: ${MAX_HOURS}h" +echo " Data dir: $DATA_DIR" +echo " Dry run: $DRY_RUN" +echo " Sweep: $SWEEP" +echo " Branch: $BRANCH" +echo "============================================================" +echo "" + +# ── Verify gcloud auth ────────────────────────────────────────────────────── + +if ! gcloud auth list --filter=status:ACTIVE --format="value(account)" 2>/dev/null | head -1 | grep -q '@'; then + echo "ERROR: No active gcloud account. 
Run: gcloud auth login" + exit 1 +fi + +gcloud config set project "$PROJECT" --quiet + +# ── Build startup script ───────────────────────────────────────────────────── + +STARTUP_SCRIPT=$(cat <<'STARTUP_EOF' +#!/bin/bash +set -euo pipefail +exec > /var/log/wdp-setup.log 2>&1 + +echo "=== WiFi-DensePose GPU VM Setup ===" +echo "Started: $(date)" + +# Wait for GPU driver +echo "Waiting for NVIDIA driver..." +for i in $(seq 1 60); do + if nvidia-smi &>/dev/null; then + echo "GPU ready after ${i}s" + nvidia-smi + break + fi + sleep 5 +done + +if ! nvidia-smi &>/dev/null; then + echo "ERROR: GPU driver not available after 300s" + exit 1 +fi + +# Install Rust toolchain +echo "Installing Rust toolchain..." +curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable +source "$HOME/.cargo/env" +rustc --version +cargo --version + +# Install system dependencies +echo "Installing system dependencies..." +apt-get update -qq +apt-get install -y -qq pkg-config libssl-dev cmake clang + +# Find libtorch from the Deep Learning VM's PyTorch installation +echo "Locating libtorch..." +PYTORCH_LIB=$(python3 -c "import torch; print(torch.__path__[0] + '/lib')" 2>/dev/null || echo "") +if [[ -n "$PYTORCH_LIB" && -d "$PYTORCH_LIB" ]]; then + export LIBTORCH="$PYTORCH_LIB" + export LD_LIBRARY_PATH="${LIBTORCH}:${LD_LIBRARY_PATH:-}" + echo "Found libtorch at: $LIBTORCH" +else + echo "WARNING: PyTorch not found in system Python. Installing via pip..." + pip3 install torch --index-url https://download.pytorch.org/whl/cu121 + PYTORCH_LIB=$(python3 -c "import torch; print(torch.__path__[0] + '/lib')") + export LIBTORCH="$PYTORCH_LIB" + export LD_LIBRARY_PATH="${LIBTORCH}:${LD_LIBRARY_PATH:-}" +fi + +# Persist env vars +cat >> /etc/environment </dev/null; then + echo " Setup complete after $((i * 15))s" + break + fi + if [[ $i -eq 60 ]]; then + echo "ERROR: Setup timed out after 15 minutes." 
+ echo "Check logs: gcloud compute ssh $INSTANCE_NAME --zone=$ZONE --command='cat /var/log/wdp-setup.log'" + if [[ "$KEEP_VM" == "false" ]]; then + echo "Cleaning up VM..." + gcloud compute instances delete "$INSTANCE_NAME" --zone="$ZONE" --quiet + fi + exit 1 + fi + sleep 15 +done + +# ── Step 3: Clone repo and build ───────────────────────────────────────────── + +echo "[3/7] Cloning repository and building training binary..." + +gcloud compute ssh "$INSTANCE_NAME" --zone="$ZONE" --command="$(cat <&1 | tail -5 + +echo "Build complete." +ls -lh target/release/train +CLONE_EOF +)" + +# ── Step 4: Upload training data ───────────────────────────────────────────── + +echo "[4/7] Uploading training data..." + +if [[ -d "$DATA_DIR" ]] && [[ "$(ls -A "$DATA_DIR" 2>/dev/null)" ]]; then + # Create a tarball of the data directory + DATA_TAR="/tmp/wdp-training-data-${TIMESTAMP}.tar.gz" + tar czf "$DATA_TAR" -C "$(dirname "$DATA_DIR")" "$(basename "$DATA_DIR")" + DATA_SIZE=$(du -h "$DATA_TAR" | cut -f1) + echo " Uploading ${DATA_SIZE} of training data..." + + gcloud compute scp "$DATA_TAR" "${INSTANCE_NAME}:~/training-data.tar.gz" --zone="$ZONE" --quiet + gcloud compute ssh "$INSTANCE_NAME" --zone="$ZONE" --command=" + mkdir -p ~/wifi-densepose/data + tar xzf ~/training-data.tar.gz -C ~/wifi-densepose/data/ + echo 'Data extracted:' + find ~/wifi-densepose/data -name '*.jsonl' -o -name '*.csi.jsonl' | head -20 + " + rm -f "$DATA_TAR" +else + echo " No local data at '$DATA_DIR'. Training will use --dry-run or MM-Fi." + if [[ "$DRY_RUN" == "false" && "$SWEEP" == "false" ]]; then + echo " WARNING: No data and --dry-run not set. Forcing --dry-run." + DRY_RUN=true + fi +fi + +# ── Step 5: Upload config and run training ──────────────────────────────────── + +echo "[5/7] Running training..." 
+ +# Upload sweep config if doing a sweep +if [[ "$SWEEP" == "true" ]]; then + SWEEP_FILE="scripts/training-config-sweep.json" + if [[ -f "$SWEEP_FILE" ]]; then + gcloud compute scp "$SWEEP_FILE" "${INSTANCE_NAME}:~/sweep-configs.json" --zone="$ZONE" --quiet + else + echo "ERROR: Sweep config not found at $SWEEP_FILE" + exit 1 + fi +fi + +# Upload single config if specified +if [[ -n "$CONFIG_FILE" ]]; then + gcloud compute scp "$CONFIG_FILE" "${INSTANCE_NAME}:~/train-config.json" --zone="$ZONE" --quiet +fi + +# Build the training command +TRAIN_CMD_BASE=" +set -euo pipefail +source \$HOME/.cargo/env +export LIBTORCH=\$(python3 -c \"import torch; print(torch.__path__[0] + '/lib')\") +export LD_LIBRARY_PATH=\"\${LIBTORCH}:\${LD_LIBRARY_PATH:-}\" +cd ~/wifi-densepose/rust-port/wifi-densepose-rs + +# Set auto-shutdown timer (safety net) +sudo shutdown -P +$((MAX_HOURS * 60)) & + +TRAIN_BIN=./target/release/train +" + +if [[ "$SWEEP" == "true" ]]; then + # Run all configs in the sweep file + gcloud compute ssh "$INSTANCE_NAME" --zone="$ZONE" --command="$(cat <&1 | tee ~/results/sweep_\${i}.log || true + + END_TIME=\$(date +%s) + ELAPSED=\$(( END_TIME - START_TIME )) + echo " Completed in \${ELAPSED}s" +done + +echo "" +echo "=== Sweep Complete ===" +echo "Results in ~/results/" +ls -lh ~/results/ +SWEEP_EOF +)" +elif [[ -n "$CONFIG_FILE" ]]; then + # Single config run + gcloud compute ssh "$INSTANCE_NAME" --zone="$ZONE" --command="$(cat <&1 | tee ~/train.log +SINGLE_EOF +)" +else + # Default config run + gcloud compute ssh "$INSTANCE_NAME" --zone="$ZONE" --command="$(cat <&1 | tee ~/train.log +DEFAULT_EOF +)" +fi + +# ── Step 6: Download results ───────────────────────────────────────────────── + +echo "[6/7] Downloading trained model artifacts..." 
+ +LOCAL_RESULTS="training-results/${INSTANCE_NAME}" +mkdir -p "$LOCAL_RESULTS" + +# Package results on the VM +gcloud compute ssh "$INSTANCE_NAME" --zone="$ZONE" --command=" +cd ~/wifi-densepose/rust-port/wifi-densepose-rs +tar czf ~/training-artifacts.tar.gz \ + checkpoints/ \ + logs/ \ + 2>/dev/null || true + +# Also grab sweep results if they exist +if [[ -d ~/results ]]; then + tar czf ~/sweep-results.tar.gz -C ~ results/ 2>/dev/null || true +fi + +ls -lh ~/training-artifacts.tar.gz ~/sweep-results.tar.gz 2>/dev/null || true +" + +# Download artifacts +gcloud compute scp "${INSTANCE_NAME}:~/training-artifacts.tar.gz" \ + "${LOCAL_RESULTS}/training-artifacts.tar.gz" --zone="$ZONE" --quiet 2>/dev/null || true + +if [[ "$SWEEP" == "true" ]]; then + gcloud compute scp "${INSTANCE_NAME}:~/sweep-results.tar.gz" \ + "${LOCAL_RESULTS}/sweep-results.tar.gz" --zone="$ZONE" --quiet 2>/dev/null || true +fi + +# Download training log +gcloud compute scp "${INSTANCE_NAME}:~/train.log" \ + "${LOCAL_RESULTS}/train.log" --zone="$ZONE" --quiet 2>/dev/null || true + +# Extract locally +if [[ -f "${LOCAL_RESULTS}/training-artifacts.tar.gz" ]]; then + tar xzf "${LOCAL_RESULTS}/training-artifacts.tar.gz" -C "$LOCAL_RESULTS/" + echo " Artifacts extracted to: $LOCAL_RESULTS/" + find "$LOCAL_RESULTS" -name "*.pt" -o -name "*.onnx" -o -name "*.rvf" 2>/dev/null | head -20 +fi + +# ── Step 7: Cleanup ────────────────────────────────────────────────────────── + +if [[ "$KEEP_VM" == "true" ]]; then + echo "[7/7] Keeping VM alive (--keep-vm). Remember to delete it manually:" + echo " gcloud compute instances delete $INSTANCE_NAME --zone=$ZONE --quiet" + echo " SSH: gcloud compute ssh $INSTANCE_NAME --zone=$ZONE" +else + echo "[7/7] Deleting VM to avoid ongoing costs..." + gcloud compute instances delete "$INSTANCE_NAME" --zone="$ZONE" --quiet + echo " VM deleted." 
+fi + +# ── Summary ────────────────────────────────────────────────────────────────── + +echo "" +echo "============================================================" +echo " Training Complete" +echo "============================================================" +echo " Results: $LOCAL_RESULTS/" +echo " GPU: $GPU_TYPE ($ZONE)" +echo " Instance: $INSTANCE_NAME" +if [[ "$KEEP_VM" == "true" ]]; then + echo " VM: STILL RUNNING (delete manually!)" +fi +echo "============================================================" diff --git a/scripts/training-config-sweep.json b/scripts/training-config-sweep.json new file mode 100644 index 00000000..ffacb69b --- /dev/null +++ b/scripts/training-config-sweep.json @@ -0,0 +1,155 @@ +{ + "description": "WiFi-DensePose hyperparameter sweep — 10 configurations exploring learning rate, batch size, backbone width, window length, loss ratios, and warmup schedules.", + "base": { + "num_subcarriers": 56, + "native_subcarriers": 114, + "num_antennas_tx": 3, + "num_antennas_rx": 3, + "heatmap_size": 56, + "num_keypoints": 17, + "num_body_parts": 24, + "weight_decay": 1e-4, + "num_epochs": 50, + "lr_gamma": 0.1, + "grad_clip_norm": 1.0, + "val_every_epochs": 1, + "early_stopping_patience": 10, + "save_top_k": 3, + "use_gpu": true, + "gpu_device_id": 0, + "num_workers": 4, + "seed": 42 + }, + "configs": [ + { + "_name": "baseline", + "_description": "Default config — reference baseline", + "learning_rate": 1e-3, + "batch_size": 8, + "backbone_channels": 256, + "window_frames": 100, + "warmup_epochs": 5, + "lr_milestones": [30, 45], + "lambda_kp": 0.3, + "lambda_dp": 0.6, + "lambda_tr": 0.1 + }, + { + "_name": "low_lr_large_batch", + "_description": "Lower LR with larger batch — stable convergence", + "learning_rate": 1e-4, + "batch_size": 16, + "backbone_channels": 256, + "window_frames": 100, + "warmup_epochs": 10, + "lr_milestones": [30, 45], + "lambda_kp": 0.3, + "lambda_dp": 0.6, + "lambda_tr": 0.1 + }, + { + "_name": "high_lr_small_batch", 
+ "_description": "Higher LR with small batch — fast exploration", + "learning_rate": 2e-3, + "batch_size": 4, + "backbone_channels": 256, + "window_frames": 100, + "warmup_epochs": 3, + "lr_milestones": [20, 40], + "lambda_kp": 0.3, + "lambda_dp": 0.6, + "lambda_tr": 0.1 + }, + { + "_name": "narrow_backbone", + "_description": "128-channel backbone — faster training, lower VRAM", + "learning_rate": 1e-3, + "batch_size": 16, + "backbone_channels": 128, + "window_frames": 100, + "warmup_epochs": 5, + "lr_milestones": [30, 45], + "lambda_kp": 0.3, + "lambda_dp": 0.6, + "lambda_tr": 0.1 + }, + { + "_name": "short_window", + "_description": "50-frame window — lower latency, tests temporal sensitivity", + "learning_rate": 5e-4, + "batch_size": 16, + "backbone_channels": 256, + "window_frames": 50, + "warmup_epochs": 5, + "lr_milestones": [30, 45], + "lambda_kp": 0.3, + "lambda_dp": 0.6, + "lambda_tr": 0.1 + }, + { + "_name": "keypoint_heavy", + "_description": "Heavier keypoint loss — prioritize skeleton accuracy", + "learning_rate": 5e-4, + "batch_size": 8, + "backbone_channels": 256, + "window_frames": 100, + "warmup_epochs": 5, + "lr_milestones": [30, 45], + "lambda_kp": 0.5, + "lambda_dp": 0.4, + "lambda_tr": 0.1 + }, + { + "_name": "contrastive_heavy", + "_description": "Strong contrastive/transfer loss — self-supervised pretraining focus", + "learning_rate": 5e-4, + "batch_size": 8, + "backbone_channels": 256, + "window_frames": 100, + "warmup_epochs": 10, + "lr_milestones": [30, 45], + "lambda_kp": 0.2, + "lambda_dp": 0.3, + "lambda_tr": 0.5 + }, + { + "_name": "wide_backbone_long_warmup", + "_description": "256-ch backbone + long warmup + moderate LR", + "learning_rate": 5e-4, + "batch_size": 8, + "backbone_channels": 256, + "window_frames": 100, + "warmup_epochs": 10, + "lr_milestones": [35, 48], + "lambda_kp": 0.3, + "lambda_dp": 0.6, + "lambda_tr": 0.1 + }, + { + "_name": "narrow_short_aggressive", + "_description": "128-ch + 50-frame + high LR — fast cheap 
exploration", + "learning_rate": 2e-3, + "batch_size": 16, + "backbone_channels": 128, + "window_frames": 50, + "warmup_epochs": 3, + "lr_milestones": [20, 40], + "lambda_kp": 0.4, + "lambda_dp": 0.5, + "lambda_tr": 0.1 + }, + { + "_name": "balanced_medium", + "_description": "Balanced loss, medium LR, medium batch — robust default", + "learning_rate": 5e-4, + "batch_size": 8, + "backbone_channels": 256, + "window_frames": 100, + "warmup_epochs": 5, + "lr_milestones": [25, 40], + "lambda_kp": 0.35, + "lambda_dp": 0.45, + "lambda_tr": 0.2 + } + ] +}