mirror of
https://github.com/ruvnet/RuView.git
synced 2026-05-19 08:10:07 +00:00
feat: GCloud GPU training pipeline + data collection + benchmarking
- gcloud-train.sh: L4/A100/H100 VM provisioning, Rust build, training with --cuda, artifact download, auto-cleanup ($0.80-$8.50/hr)
- training-config-sweep.json: 10 hyperparameter configs (LR, batch, backbone, windows, loss weights, warmup)
- collect-training-data.py: UDP listener for 2-node ESP32 CSI recording to .csi.jsonl with interactive/batch labeling and manifest generation
- benchmark-model.py: ONNX latency/throughput/PCK/FLOPs profiling with multi-model sweep comparison

Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
parent
9a2bc1839a
commit
c63cf2ee77
4 changed files with 1657 additions and 0 deletions
550
scripts/benchmark-model.py
Normal file
550
scripts/benchmark-model.py
Normal file
|
|
@ -0,0 +1,550 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
WiFi-DensePose Model Benchmarking
|
||||
|
||||
Loads trained ONNX models, runs inference on test data, and reports
|
||||
performance metrics: latency, throughput, PCK@0.2, model size, and
|
||||
estimated FLOPs.
|
||||
|
||||
Can compare multiple models from a hyperparameter sweep.
|
||||
|
||||
Usage:
|
||||
# Benchmark a single model
|
||||
python scripts/benchmark-model.py --model checkpoints/best.onnx
|
||||
|
||||
# Benchmark with recorded test data
|
||||
python scripts/benchmark-model.py --model best.onnx --test-data data/recordings/test.csi.jsonl
|
||||
|
||||
# Compare models from a sweep
|
||||
python scripts/benchmark-model.py --sweep-dir training-results/wdp-train-a100-*/checkpoints/
|
||||
|
||||
# Benchmark with synthetic data (no recordings needed)
|
||||
python scripts/benchmark-model.py --model best.onnx --synthetic --num-samples 200
|
||||
|
||||
# Export results as JSON
|
||||
python scripts/benchmark-model.py --model best.onnx --output results.json
|
||||
|
||||
Prerequisites:
|
||||
pip install onnxruntime numpy
|
||||
Optional: pip install onnx (for FLOPs estimation)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import onnxruntime as ort
|
||||
except ImportError:
|
||||
print("ERROR: onnxruntime not installed. Run: pip install onnxruntime")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
# ── Configuration ────────────────────────────────────────────────────────────
|
||||
|
||||
# Default model input shape (must match TrainingConfig defaults)
NUM_SUBCARRIERS = 56  # width axis of the model input tensor
NUM_ANTENNAS_TX = 3  # TX x RX gives the channel axis (3 x 3 = 9 streams)
NUM_ANTENNAS_RX = 3
WINDOW_FRAMES = 100  # height axis: CSI frames per input window
NUM_KEYPOINTS = 17  # keypoints per sample in the synthetic ground truth
HEATMAP_SIZE = 56  # coordinate range of keypoints; also the default PCK normalizer

# PCK threshold
PCK_THRESHOLD = 0.2  # fraction of the normalizer within which a keypoint counts as correct
|
||||
|
||||
|
||||
# ── Data classes ─────────────────────────────────────────────────────────────
|
||||
|
||||
@dataclass
class BenchmarkResult:
    """Metrics collected for one benchmarked ONNX model.

    Instances are built by ``benchmark_model`` and converted with ``asdict``
    when results are written to JSON.
    """

    model_path: str
    model_size_mb: float
    num_parameters: Optional[int] = None  # None when the count could not be extracted
    estimated_flops: Optional[int] = None  # rough estimate derived from num_parameters

    # Latency
    warmup_runs: int = 10
    benchmark_runs: int = 100
    latency_mean_ms: float = 0.0
    latency_std_ms: float = 0.0
    latency_p50_ms: float = 0.0
    latency_p95_ms: float = 0.0
    latency_p99_ms: float = 0.0
    throughput_fps: float = 0.0  # 1000 / mean latency in ms

    # Accuracy (if ground truth available)
    pck_at_02: Optional[float] = None
    mean_per_joint_error: Optional[float] = None  # mean Euclidean keypoint distance
    num_test_samples: int = 0

    # Input shape
    input_shape: list = field(default_factory=list)  # as reported by the ONNX session
    provider: str = ""  # first execution provider reported by the session
|
||||
|
||||
|
||||
# ── ONNX model loading ──────────────────────────────────────────────────────
|
||||
|
||||
def load_model(model_path: str) -> ort.InferenceSession:
    """Create an ONNX Runtime session for ``model_path``.

    The CUDA provider is requested first whenever ONNX Runtime reports it as
    available; the CPU provider is always listed last as the fallback.
    """
    cuda_available = "CUDAExecutionProvider" in ort.get_available_providers()
    providers = ["CUDAExecutionProvider"] if cuda_available else []
    providers.append("CPUExecutionProvider")

    options = ort.SessionOptions()
    options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    # Use every core; fall back to 4 threads if the core count is unknown.
    options.intra_op_num_threads = os.cpu_count() or 4

    return ort.InferenceSession(model_path, options, providers=providers)
|
||||
|
||||
|
||||
def get_model_info(model_path: str) -> dict:
    """Collect file size, parameter count and a rough FLOPs figure for a model.

    Returns a dict with keys ``size_mb``, ``num_parameters`` and
    ``estimated_flops``. The last two stay ``None`` when the optional ``onnx``
    package is missing or the model cannot be inspected.
    """
    size_bytes = Path(model_path).stat().st_size
    info = {
        "size_mb": round(size_bytes / (1024 * 1024), 2),
        "num_parameters": None,
        "estimated_flops": None,
    }

    # Parameter counting needs the optional onnx package.
    try:
        import onnx

        graph = onnx.load(model_path).graph
        param_count = 0
        for tensor in graph.initializer:
            dims = list(tensor.dims)
            if dims:
                param_count += int(np.prod(dims))
    except ImportError:
        return info
    except Exception as e:
        print(f" Warning: Could not extract parameter count: {e}")
        return info

    info["num_parameters"] = param_count
    # Rough FLOPs estimate: ~2 * params (multiply-accumulate)
    info["estimated_flops"] = param_count * 2
    return info
|
||||
|
||||
|
||||
# ── Synthetic data generation ────────────────────────────────────────────────
|
||||
|
||||
def generate_synthetic_input(
    batch_size: int = 1,
    num_subcarriers: int = NUM_SUBCARRIERS,
    num_tx: int = NUM_ANTENNAS_TX,
    num_rx: int = NUM_ANTENNAS_RX,
    window_frames: int = WINDOW_FRAMES,
) -> np.ndarray:
    """Build a reproducible random CSI tensor in the model's input layout.

    The returned float32 array has shape
    [batch, channels, height, width] with channels = num_tx * num_rx,
    height = window_frames and width = num_subcarriers. Values are drawn
    from a standard normal distribution with a fixed seed, so repeated
    calls with the same arguments return identical data.
    """
    shape = (batch_size, num_tx * num_rx, window_frames, num_subcarriers)
    generator = np.random.default_rng(42)
    samples = generator.normal(loc=0.0, scale=1.0, size=shape)
    return samples.astype(np.float32)
|
||||
|
||||
|
||||
def generate_synthetic_keypoints(
    num_samples: int,
    num_keypoints: int = NUM_KEYPOINTS,
    heatmap_size: int = HEATMAP_SIZE,
) -> np.ndarray:
    """Return reproducible random (x, y) keypoints for PCK evaluation.

    The float32 result has shape [num_samples, num_keypoints, 2] with
    coordinates drawn uniformly from [0, heatmap_size) using a fixed seed.
    """
    shape = (num_samples, num_keypoints, 2)
    coords = np.random.default_rng(123).uniform(0, heatmap_size, size=shape)
    return coords.astype(np.float32)
|
||||
|
||||
|
||||
# ── Load test data from .csi.jsonl ──────────────────────────────────────────
|
||||
|
||||
def load_test_data(
    jsonl_path: str,
    window_frames: int = WINDOW_FRAMES,
    num_subcarriers: int = NUM_SUBCARRIERS,
    max_samples: int = 500,
) -> np.ndarray:
    """Load CSI frames from a .csi.jsonl file and window them into model inputs.

    Each non-empty line is expected to be a JSON object whose ``"subcarriers"``
    entry is a non-empty list holding one frame. Lines that are not valid JSON,
    are not JSON objects, or have no usable subcarrier list are skipped.

    Args:
        jsonl_path: Path to the recording.
        window_frames: Frames per window; shorter recordings are zero-padded.
        num_subcarriers: Target frame width. Narrower frames are zero-padded,
            wider frames are downsampled by linear interpolation.
        max_samples: Upper bound on the number of windows returned.

    Returns:
        float32 array of shape [N, 9, window_frames, num_subcarriers], where
        the single recorded stream is repeated across the 9 channels.
    """
    frames = []

    with open(Path(jsonl_path), "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                continue
            # Valid JSON is not necessarily a frame record (e.g. a bare number
            # or list); treat those like any other malformed line.
            if not isinstance(record, dict):
                continue
            subs = record.get("subcarriers")
            if isinstance(subs, list) and subs:
                frames.append(subs)

    if len(frames) < window_frames:
        print(f" Warning: Only {len(frames)} frames, need {window_frames}. Padding with zeros.")
        missing = window_frames - len(frames)
        frames.extend([0.0] * num_subcarriers for _ in range(missing))

    # Normalize subcarrier count
    normalized = []
    for frame in frames:
        if len(frame) < num_subcarriers:
            frame = frame + [0.0] * (num_subcarriers - len(frame))
        elif len(frame) > num_subcarriers:
            # Downsample via linear interpolation
            indices = np.linspace(0, len(frame) - 1, num_subcarriers)
            frame = np.interp(indices, range(len(frame)), frame).tolist()
        normalized.append(frame)
    frames = normalized

    # Create sliding windows with 50% overlap
    channels = NUM_ANTENNAS_TX * NUM_ANTENNAS_RX
    samples = []
    stride = max(1, window_frames // 2)
    for start in range(0, len(frames) - window_frames + 1, stride):
        window = np.array(frames[start : start + window_frames], dtype=np.float32)
        # Expand the single stream to 9 channels (repeat for MIMO):
        # [window, subs] -> [1, window, subs] -> [9, window, subs]
        samples.append(np.repeat(window[np.newaxis, ...], channels, axis=0))
        if len(samples) >= max_samples:
            break

    if not samples:
        # Defensive fallback; keep the shape the caller asked for.
        return generate_synthetic_input(
            1, num_subcarriers=num_subcarriers, window_frames=window_frames
        )

    return np.stack(samples, axis=0)  # [N, 9, window_frames, num_subcarriers]
|
||||
|
||||
|
||||
# ── Benchmarking ─────────────────────────────────────────────────────────────
|
||||
|
||||
def benchmark_latency(
    session: ort.InferenceSession,
    input_data: np.ndarray,
    warmup: int = 10,
    runs: int = 100,
) -> dict:
    """Time single-sample inference and summarize the latency distribution.

    Only the first sample of ``input_data`` is fed to the model. ``warmup``
    untimed calls are made first, followed by ``runs`` timed calls.

    Returns a dict with ``mean_ms``, ``std_ms``, ``p50_ms``, ``p95_ms``,
    ``p99_ms`` and ``throughput_fps`` (1000 / mean latency in ms).
    """
    name = session.get_inputs()[0].name

    # Untimed warmup calls
    for _ in range(warmup):
        session.run(None, {name: input_data[:1]})

    # Timed calls
    samples_ms = []
    for _ in range(runs):
        t0 = time.perf_counter()
        session.run(None, {name: input_data[:1]})
        t1 = time.perf_counter()
        samples_ms.append((t1 - t0) * 1000)  # ms

    timings = np.array(samples_ms)
    mean_ms = float(np.mean(timings))
    p50, p95, p99 = (float(v) for v in np.percentile(timings, [50, 95, 99]))
    return {
        "mean_ms": mean_ms,
        "std_ms": float(np.std(timings)),
        "p50_ms": p50,
        "p95_ms": p95,
        "p99_ms": p99,
        "throughput_fps": 1000.0 / mean_ms,
    }
|
||||
|
||||
|
||||
def compute_pck(
    predictions: np.ndarray,
    ground_truth: np.ndarray,
    threshold: float = PCK_THRESHOLD,
    normalize_by: float = HEATMAP_SIZE,
) -> float:
    """Return the Percentage of Correct Keypoints.

    A keypoint counts as correct when its Euclidean distance to the ground
    truth is strictly below ``threshold * normalize_by``. If the two arrays
    differ in shape the score is 0.0.
    """
    if predictions.shape != ground_truth.shape:
        return 0.0

    radius = threshold * normalize_by
    offsets = predictions - ground_truth
    within = np.linalg.norm(offsets, axis=-1) < radius  # [N, K]
    return float(np.mean(within.astype(float)))
|
||||
|
||||
|
||||
def extract_keypoints_from_heatmaps(heatmaps: np.ndarray) -> np.ndarray:
    """Locate the peak of each heatmap.

    Takes heatmaps of shape [N, K, H, W] and returns float32 coordinates of
    shape [N, K, 2], ordered (x, y), i.e. (column, row) of the maximum.
    """
    n, k, _h, w = heatmaps.shape
    peak = heatmaps.reshape(n, k, -1).argmax(axis=-1)  # flat index, [N, K]
    rows, cols = np.divmod(peak, w)
    return np.stack((cols, rows), axis=-1).astype(np.float32)
|
||||
|
||||
|
||||
def _evaluate_accuracy(session, test_data, gt_keypoints):
    """Score the model's keypoints against ground truth.

    Returns ``(pck, mpjpe, num_samples)``, or ``None`` when no sample yielded
    a 4-D heatmap output or the predicted keypoints cannot be compared with
    the ground truth because their shapes differ.
    """
    input_name = session.get_inputs()[0].name
    preds = []
    targets = []

    # zip() pairs each sample with its own ground-truth row and stops at the
    # shorter input, so predictions and targets always stay aligned.
    for sample, target in zip(test_data, gt_keypoints):
        outputs = session.run(None, {input_name: sample[np.newaxis, ...]})
        # Assume first output is keypoint heatmaps [1, K, H, W]
        heatmaps = outputs[0]
        if heatmaps.ndim == 4:
            preds.append(extract_keypoints_from_heatmaps(heatmaps)[0])
            targets.append(target)

    if not preds:
        return None

    predictions = np.stack(preds)
    gt = np.stack(targets)
    if predictions.shape != gt.shape:
        # Subtracting mismatched arrays would raise and discard the latency
        # numbers already measured for this model.
        print(
            f" Warning: prediction shape {predictions.shape} does not match "
            f"ground truth shape {gt.shape}; skipping accuracy metrics."
        )
        return None

    pck = compute_pck(predictions, gt)
    mpjpe = float(np.mean(np.linalg.norm(predictions - gt, axis=-1)))
    return pck, mpjpe, len(predictions)


def benchmark_model(
    model_path: str,
    test_data: Optional[np.ndarray] = None,
    gt_keypoints: Optional[np.ndarray] = None,
    warmup: int = 10,
    runs: int = 100,
) -> BenchmarkResult:
    """Run full benchmark on a single model.

    Args:
        model_path: Path to the ONNX model.
        test_data: Model inputs, batch on the first axis. When omitted, random
            input is generated from the model's static input shape, or from
            the default synthetic shape if the model shape is dynamic.
        gt_keypoints: Ground-truth keypoints; enables PCK/MPJPE reporting.
        warmup: Untimed warmup iterations.
        runs: Timed iterations.

    Returns:
        A ``BenchmarkResult``. Accuracy fields stay ``None`` when no ground
        truth is given or the model output cannot be compared with it.
    """
    print(f"\nBenchmarking: {model_path}")

    # Load model
    session = load_model(model_path)
    provider = session.get_providers()[0]
    print(f" Provider: {provider}")

    # Model info
    model_info = get_model_info(model_path)
    print(f" Size: {model_info['size_mb']} MB")
    if model_info["num_parameters"]:
        print(f" Parameters: {model_info['num_parameters']:,}")
    if model_info["estimated_flops"]:
        print(f" Estimated FLOPs: {model_info['estimated_flops']:,}")

    # Input shape
    input_meta = session.get_inputs()[0]
    input_shape = input_meta.shape
    print(f" Input: {input_meta.name} {input_shape} ({input_meta.type})")

    # Output shapes
    for out in session.get_outputs():
        print(f" Output: {out.name} {out.shape}")

    # Generate or use provided test data
    if test_data is None:
        # Infer shape from model; only possible when every dim is a plain int
        if input_shape and all(isinstance(d, int) for d in input_shape):
            batch = max(1, input_shape[0] if input_shape[0] > 0 else 1)
            test_data = np.random.randn(*[batch if d <= 0 else d for d in input_shape]).astype(np.float32)
        else:
            test_data = generate_synthetic_input(1)

    # Latency benchmark
    print(f" Running {warmup} warmup + {runs} benchmark iterations...")
    latency = benchmark_latency(session, test_data, warmup=warmup, runs=runs)
    print(f" Latency: {latency['mean_ms']:.2f} +/- {latency['std_ms']:.2f} ms")
    print(f" P50/P95/P99: {latency['p50_ms']:.2f} / {latency['p95_ms']:.2f} / {latency['p99_ms']:.2f} ms")
    print(f" Throughput: {latency['throughput_fps']:.1f} fps")

    # Accuracy (only when ground truth is provided)
    pck = None
    mpjpe = None
    num_samples = 0

    if gt_keypoints is not None:
        accuracy = _evaluate_accuracy(session, test_data, gt_keypoints)
        if accuracy is not None:
            pck, mpjpe, num_samples = accuracy
            print(f" PCK@{PCK_THRESHOLD}: {pck:.4f}")
            print(f" MPJPE: {mpjpe:.2f} px")
            print(f" Samples: {num_samples}")

    return BenchmarkResult(
        model_path=model_path,
        model_size_mb=model_info["size_mb"],
        num_parameters=model_info["num_parameters"],
        estimated_flops=model_info["estimated_flops"],
        warmup_runs=warmup,
        benchmark_runs=runs,
        latency_mean_ms=round(latency["mean_ms"], 3),
        latency_std_ms=round(latency["std_ms"], 3),
        latency_p50_ms=round(latency["p50_ms"], 3),
        latency_p95_ms=round(latency["p95_ms"], 3),
        latency_p99_ms=round(latency["p99_ms"], 3),
        throughput_fps=round(latency["throughput_fps"], 1),
        pck_at_02=round(pck, 4) if pck is not None else None,
        mean_per_joint_error=round(mpjpe, 2) if mpjpe is not None else None,
        num_test_samples=num_samples,
        input_shape=list(input_shape) if input_shape else [],
        provider=provider,
    )
|
||||
|
||||
|
||||
# ── Comparison table ─────────────────────────────────────────────────────────
|
||||
|
||||
def print_comparison_table(results: list[BenchmarkResult]):
    """Print a side-by-side table of benchmark results plus a short summary.

    The summary names the fastest model, the most accurate one (only when at
    least one result carries a PCK score) and the smallest one. Nothing is
    printed for an empty list.
    """
    if not results:
        return

    rule = "=" * 100
    print("\n" + rule)
    print(" Model Comparison")
    print(rule)

    # Header
    header = (
        f"{'Model':<35} {'Size(MB)':>8} {'Params':>10} "
        f"{'Lat(ms)':>8} {'P95(ms)':>8} {'FPS':>7} {'PCK@0.2':>8}"
    )
    print(header)
    print("-" * 100)

    for res in results:
        label = Path(res.model_path).stem[:33]
        if res.num_parameters:
            param_text = f"{res.num_parameters:,}"
        else:
            param_text = "?"
        if res.pck_at_02 is None:
            pck_text = "N/A"
        else:
            pck_text = f"{res.pck_at_02:.4f}"

        row = (
            f"{label:<35} {res.model_size_mb:>8.2f} {param_text:>10} "
            f"{res.latency_mean_ms:>8.2f} {res.latency_p95_ms:>8.2f} "
            f"{res.throughput_fps:>7.1f} {pck_text:>8}"
        )
        print(row)

    print(rule)

    # Best model by latency
    fastest = min(results, key=lambda item: item.latency_mean_ms)
    print(f"\n Fastest: {Path(fastest.model_path).stem} ({fastest.latency_mean_ms:.2f} ms)")

    # Best by PCK (if available)
    scored = [item for item in results if item.pck_at_02 is not None]
    if scored:
        most_accurate = max(scored, key=lambda item: item.pck_at_02)
        print(f" Best accuracy: {Path(most_accurate.model_path).stem} (PCK@0.2={most_accurate.pck_at_02:.4f})")

    # Smallest model
    lightest = min(results, key=lambda item: item.model_size_mb)
    print(f" Smallest: {Path(lightest.model_path).stem} ({lightest.model_size_mb:.2f} MB)")
|
||||
|
||||
|
||||
# ── Main ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
def main():
    """CLI entry point: benchmark one model or a sweep and report the results.

    Exits with status 1 when no model is found or none could be benchmarked,
    and with an argparse usage error when neither ``--model`` nor
    ``--sweep-dir`` is given or ``--gpu`` is requested without CUDA support.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark WiFi-DensePose ONNX models",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    parser.add_argument("--model", type=str, help="Path to a single ONNX model")
    parser.add_argument("--sweep-dir", type=str, help="Directory containing multiple ONNX models to compare")
    parser.add_argument("--test-data", type=str, help="Path to .csi.jsonl test data file")
    parser.add_argument("--synthetic", action="store_true", help="Use synthetic test data")
    parser.add_argument("--num-samples", type=int, default=100, help="Number of synthetic samples (default: 100)")
    parser.add_argument("--warmup", type=int, default=10, help="Warmup iterations (default: 10)")
    parser.add_argument("--runs", type=int, default=100, help="Benchmark iterations (default: 100)")
    parser.add_argument("--output", type=str, help="Save results to JSON file")
    parser.add_argument(
        "--gpu",
        action="store_true",
        help="Require the CUDA execution provider (error out if it is unavailable)",
    )

    args = parser.parse_args()

    if not args.model and not args.sweep_dir:
        parser.error("Specify --model or --sweep-dir")

    # load_model() already prefers CUDA when it is present, so honoring --gpu
    # only requires failing fast instead of silently benchmarking on CPU.
    if args.gpu and "CUDAExecutionProvider" not in ort.get_available_providers():
        parser.error("--gpu requested but CUDAExecutionProvider is not available in this onnxruntime install")

    # Prepare test data
    test_data = None
    gt_keypoints = None

    if args.test_data:
        print(f"Loading test data from: {args.test_data}")
        test_data = load_test_data(args.test_data)
        print(f" Loaded {len(test_data)} windowed samples")
    elif args.synthetic:
        print(f"Generating {args.num_samples} synthetic samples...")
        test_data = generate_synthetic_input(args.num_samples)
        gt_keypoints = generate_synthetic_keypoints(args.num_samples)
        print(f" Input shape: {test_data.shape}")

    # Collect models
    model_paths = []
    if args.model:
        model_paths.append(args.model)
    if args.sweep_dir:
        sweep = Path(args.sweep_dir)
        if sweep.is_dir():
            model_paths.extend(sorted(str(p) for p in sweep.glob("**/*.onnx")))
        else:
            # Not a directory: treat the argument as a glob pattern
            from glob import glob

            model_paths.extend(sorted(glob(str(sweep))))

    if not model_paths:
        print("ERROR: No ONNX models found.")
        sys.exit(1)

    print(f"Found {len(model_paths)} model(s) to benchmark.")

    # Benchmark each model; one failing model must not abort the sweep
    results = []
    for path in model_paths:
        if not Path(path).exists():
            print(f" Skipping (not found): {path}")
            continue
        try:
            result = benchmark_model(
                path,
                test_data=test_data,
                gt_keypoints=gt_keypoints,
                warmup=args.warmup,
                runs=args.runs,
            )
            results.append(result)
        except Exception as e:
            print(f" ERROR benchmarking {path}: {e}")

    # Comparison table
    if len(results) > 1:
        print_comparison_table(results)

    # Save results
    if args.output:
        output_path = Path(args.output)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, "w") as f:
            json.dump(
                {
                    "benchmark_results": [asdict(r) for r in results],
                    "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
                    "num_models": len(results),
                },
                f,
                indent=2,
            )
        print(f"\nResults saved to: {output_path}")

    if not results:
        print("No models were successfully benchmarked.")
        sys.exit(1)
|
||||
|
||||
|
||||
# Run the benchmark CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
||||
483
scripts/collect-training-data.py
Normal file
483
scripts/collect-training-data.py
Normal file
|
|
@ -0,0 +1,483 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
WiFi-DensePose Training Data Collector
|
||||
|
||||
Listens on UDP for CSI data from ESP32 nodes and records to .csi.jsonl
|
||||
files compatible with the Rust training pipeline (MmFiDataset / CsiDataset).
|
||||
|
||||
Supports two packet formats:
|
||||
- ADR-069 feature vectors (magic 0xC5110003, 48 bytes) — 8-dim pre-extracted
|
||||
- ADR-018 raw CSI frames (magic 0xC5110001, variable) — full subcarrier data
|
||||
|
||||
Usage:
|
||||
# Interactive — prompts for scenario labels
|
||||
python scripts/collect-training-data.py --port 5006
|
||||
|
||||
# Scripted — fixed label, 60s per recording
|
||||
python scripts/collect-training-data.py --port 5006 --label walking --duration 60
|
||||
|
||||
# Multiple scenarios in sequence
|
||||
python scripts/collect-training-data.py --port 5006 --scenarios walking,standing,sitting --duration 30
|
||||
|
||||
# Dual-node collection (two ESP32s on different ports)
|
||||
python scripts/collect-training-data.py --port 5005 --port2 5006 --label walking
|
||||
|
||||
# Generate manifest only from existing recordings
|
||||
python scripts/collect-training-data.py --manifest-only --output-dir data/recordings
|
||||
|
||||
Prerequisites:
|
||||
- ESP32 nodes streaming CSI on UDP (see firmware/esp32-csi-node)
|
||||
- Python 3.9+
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
import struct
|
||||
import sys
|
||||
import time
|
||||
import signal
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
# Configure logging once at import time: INFO level, HH:MM:SS timestamps.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%H:%M:%S",
)
# Script-wide logger
log = logging.getLogger("collect-data")
|
||||
|
||||
# ── Packet formats (must match firmware) ─────────────────────────────────────
|
||||
|
||||
# ADR-018 raw CSI frame header
MAGIC_CSI_RAW = 0xC5110001
# ADR-069 feature vector packet
# Layout (little-endian): magic(4) + node_id(1) + one byte ignored by the parser(1)
# + seq(2) + timestamp_us(8) + 8 x float32 features(32)
MAGIC_FEATURES = 0xC5110003
FEATURE_PKT_FMT = "<IBBHq8f"
FEATURE_PKT_SIZE = struct.calcsize(FEATURE_PKT_FMT)  # 48 bytes

# Raw CSI header: magic(4) + node_id(1) + antenna_cfg(1) + n_sub(2) + rssi(1) + noise(1) + channel(1) + reserved(1) + timestamp_ms(4)
RAW_CSI_HDR_FMT = "<IBBHbbBxI"
RAW_CSI_HDR_SIZE = struct.calcsize(RAW_CSI_HDR_FMT)  # 16 bytes
|
||||
|
||||
|
||||
# ── Packet parsing ───────────────────────────────────────────────────────────
|
||||
|
||||
def parse_packet(data: bytes) -> Optional[dict]:
    """Dispatch a UDP payload to the parser matching its magic number.

    Returns the parsed frame dict, or ``None`` when the payload is shorter
    than the magic field, carries an unknown magic, or is too short for the
    format its magic announces.
    """
    size = len(data)
    if size < 4:
        return None

    (magic,) = struct.unpack_from("<I", data)

    if magic == MAGIC_FEATURES and size >= FEATURE_PKT_SIZE:
        return _parse_feature_packet(data)
    if magic == MAGIC_CSI_RAW and size >= RAW_CSI_HDR_SIZE:
        return _parse_raw_csi_packet(data)
    return None
|
||||
|
||||
|
||||
def _parse_feature_packet(data: bytes) -> Optional[dict]:
    """Decode an ADR-069 feature vector packet (48 bytes).

    Returns ``None`` if the buffer cannot be unpacked, the magic does not
    match, or any feature value is NaN or infinite. The eight features are
    also exposed under ``"subcarriers"`` so the record matches the layout of
    raw CSI frames.
    """
    import math

    try:
        fields = struct.unpack_from(FEATURE_PKT_FMT, data)
    except struct.error:
        return None

    # Second byte after the magic is not used by this parser.
    magic, node_id, _unused, seq, ts_us = fields[:5]
    features = list(fields[5:])

    if magic != MAGIC_FEATURES:
        return None

    # Reject NaN/inf
    if not all(math.isfinite(value) for value in features):
        return None

    return {
        "type": "features",
        "node_id": node_id,
        "seq": seq,
        "timestamp_us": ts_us,
        "timestamp": ts_us / 1_000_000.0,
        "features": features,
        "subcarriers": features,  # Use features as subcarrier proxy for training
        "rssi": 0.0,
        "noise_floor": 0.0,
    }
|
||||
|
||||
|
||||
def _parse_raw_csi_packet(data: bytes) -> Optional[dict]:
    """Decode an ADR-018 raw CSI frame into per-subcarrier amplitudes.

    Returns ``None`` if the header cannot be unpacked, the magic does not
    match, or the payload is shorter than the header's subcarrier count
    requires.
    """
    try:
        header = struct.unpack_from(RAW_CSI_HDR_FMT, data)
    except struct.error:
        return None

    magic, node_id, ant_cfg, n_sub, rssi, noise, channel, ts_ms = header
    if magic != MAGIC_CSI_RAW:
        return None

    # Payload: one int16 I value and one int16 Q value per subcarrier
    value_count = n_sub * 2
    if len(data) < RAW_CSI_HDR_SIZE + value_count * 2:
        return None

    iq = struct.unpack_from(f"<{value_count}h", data, RAW_CSI_HDR_SIZE)
    # Amplitude of each (I, Q) pair
    subcarriers = [
        (i_val ** 2 + q_val ** 2) ** 0.5
        for i_val, q_val in zip(iq[0::2], iq[1::2])
    ]

    return {
        "type": "raw_csi",
        "node_id": node_id,
        "antenna_config": ant_cfg,
        "n_subcarriers": n_sub,
        "channel": channel,
        "timestamp": ts_ms / 1000.0,
        "subcarriers": subcarriers,
        "rssi": float(rssi),
        "noise_floor": float(noise),
    }
|
||||
|
||||
|
||||
# ── JSONL recording ──────────────────────────────────────────────────────────
|
||||
|
||||
class CsiRecorder:
    """Append CSI frames to a ``.csi.jsonl`` file and emit a metadata sidecar.

    The session id combines the sanitized session name with a UTC timestamp.
    Frames are written one JSON object per line; ``close()`` writes a
    ``.csi.meta.json`` file describing the session.
    """

    def __init__(self, output_dir: str, session_name: str, label: Optional[str] = None):
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        stamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
        # Spaces and slashes are not wanted in the file name.
        cleaned = session_name.translate(str.maketrans({" ": "_", "/": "_"}))
        self.session_id = f"{cleaned}-{stamp}"
        self.label = label
        self.file_path = self.output_dir / f"{self.session_id}.csi.jsonl"
        self.meta_path = self.output_dir / f"{self.session_id}.csi.meta.json"
        self.frame_count = 0
        self.start_time = time.time()
        self.started_at = datetime.now(timezone.utc).isoformat()
        self._file = None

    def open(self):
        """Open the recording file in append mode."""
        self._file = open(self.file_path, "a", encoding="utf-8")
        log.info(f"Recording to: {self.file_path}")

    def write_frame(self, frame: dict):
        """Serialize one frame as a JSONL line; no-op if the file is not open."""
        if self._file is None:
            return

        # Keys stored at the top level of the record (or dropped, for "type");
        # everything else is nested under "features".
        reserved = ("timestamp", "subcarriers", "rssi", "noise_floor", "type")
        extras = {key: value for key, value in frame.items() if key not in reserved}
        record = {
            "timestamp": frame.get("timestamp", time.time()),
            "subcarriers": frame.get("subcarriers", []),
            "rssi": frame.get("rssi", 0.0),
            "noise_floor": frame.get("noise_floor", 0.0),
            "features": extras,
        }

        self._file.write(json.dumps(record, separators=(",", ":")) + "\n")
        self.frame_count += 1

        # Flush periodically so a crash loses at most 500 frames.
        if self.frame_count % 500 == 0:
            self._file.flush()

    def close(self) -> dict:
        """Close the recording, write the metadata file and return its content."""
        handle = self._file
        if handle:
            handle.flush()
            handle.close()
            self._file = None

        ended_at = datetime.now(timezone.utc).isoformat()
        elapsed = time.time() - self.start_time
        if self.file_path.exists():
            file_size = self.file_path.stat().st_size
        else:
            file_size = 0

        if elapsed > 0:
            fps = round(self.frame_count / elapsed, 1)
        else:
            fps = 0

        meta = {
            "id": self.session_id,
            "name": self.session_id,
            "label": self.label,
            "started_at": self.started_at,
            "ended_at": ended_at,
            "duration_secs": round(elapsed, 2),
            "frame_count": self.frame_count,
            "file_size_bytes": file_size,
            "file_path": str(self.file_path),
            "fps": fps,
        }

        with open(self.meta_path, "w", encoding="utf-8") as f:
            json.dump(meta, f, indent=2)

        log.info(
            f"Recording stopped: {self.frame_count} frames in {elapsed:.1f}s "
            f"({meta['fps']} fps, {file_size / 1024:.1f} KB)"
        )
        return meta
|
||||
|
||||
|
||||
# ── Manifest generation ──────────────────────────────────────────────────────
|
||||
|
||||
def generate_manifest(output_dir: str) -> dict:
    """Scan a recordings directory and write ``manifest.json`` describing it.

    Every ``*.csi.meta.json`` file that holds a JSON object is included as a
    session; unreadable, malformed or non-object files are skipped with a
    warning. The directory is created if it does not exist, so an empty
    manifest can be generated before any recording has been made.

    Returns:
        The manifest dict that was written.
    """
    rec_dir = Path(output_dir)
    # Without this, writing manifest.json into a missing directory raises
    # FileNotFoundError (e.g. --manifest-only before the first recording).
    rec_dir.mkdir(parents=True, exist_ok=True)
    sessions = []

    for meta_file in sorted(rec_dir.glob("*.csi.meta.json")):
        try:
            # Metadata is written as UTF-8, so read it back the same way.
            with open(meta_file, "r", encoding="utf-8") as f:
                meta = json.load(f)
        except (ValueError, OSError) as e:
            # ValueError covers both JSONDecodeError and UnicodeDecodeError.
            log.warning(f"Skipping {meta_file}: {e}")
            continue
        if not isinstance(meta, dict):
            log.warning(f"Skipping {meta_file}: expected a JSON object")
            continue
        sessions.append(meta)

    # Aggregate stats
    total_frames = sum(s.get("frame_count", 0) for s in sessions)
    total_bytes = sum(s.get("file_size_bytes", 0) for s in sessions)
    # Missing and null labels are both reported as "unlabeled"
    labels = sorted({s.get("label", "unlabeled") or "unlabeled" for s in sessions})

    manifest = {
        "dataset": "wifi-densepose-csi",
        "generated_at": datetime.now(timezone.utc).isoformat(),
        "directory": str(rec_dir),
        "num_sessions": len(sessions),
        "total_frames": total_frames,
        "total_size_bytes": total_bytes,
        "total_size_mb": round(total_bytes / (1024 * 1024), 2),
        "labels": labels,
        "sessions": sessions,
    }

    manifest_path = rec_dir / "manifest.json"
    with open(manifest_path, "w", encoding="utf-8") as f:
        json.dump(manifest, f, indent=2)

    log.info(
        f"Manifest: {len(sessions)} sessions, {total_frames} frames, "
        f"{manifest['total_size_mb']} MB, labels={labels}"
    )
    log.info(f"Written to: {manifest_path}")
    return manifest
|
||||
|
||||
|
||||
# ── UDP listener ─────────────────────────────────────────────────────────────
|
||||
|
||||
def collect_session(
    port: int,
    port2: Optional[int],
    output_dir: str,
    label: str,
    duration: float,
    session_name: Optional[str] = None,
) -> dict:
    """Run a single collection session. Returns session metadata.

    Listens on ``port`` (and ``port2`` when given) for ``duration`` seconds,
    writing every recognized packet to a new recording. Ctrl-C ends the
    session early; the frames collected so far are kept.

    Args:
        port: Primary UDP port.
        port2: Optional secondary UDP port for a second node.
        output_dir: Directory for the recording and its metadata.
        label: Activity label stored in the metadata.
        duration: Recording length in seconds.
        session_name: Name used for the files; defaults to the label.

    Returns:
        The metadata dict produced by ``CsiRecorder.close()``.

    Raises:
        OSError: If a port cannot be bound or the recording cannot be opened.
            All sockets opened so far are closed before the error propagates.
    """
    import select

    name = session_name or label or "session"
    recorder = CsiRecorder(output_dir, name, label)

    ports = [port] + ([port2] if port2 else [])
    sockets = []
    dropped = 0

    try:
        for udp_port in ports:
            sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            # Registered before bind() so a bind failure still closes it.
            sockets.append(sock)
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind(("0.0.0.0", udp_port))
            sock.setblocking(False)

        # Open the file only once every port is bound, so a failed bind does
        # not leave an empty recording behind.
        recorder.open()

        log.info(
            f"Collecting '{label}' for {duration}s on port(s) "
            f"{port}{f', {port2}' if port2 else ''}"
        )

        # Monotonic clock: immune to wall-clock adjustments during recording.
        start = time.monotonic()
        deadline = start + duration
        next_progress = start + 5.0

        while True:
            now = time.monotonic()
            if now >= deadline:
                break

            # Wait on all sockets at once so a silent node cannot delay
            # reads from the other one.
            readable, _, _ = select.select(sockets, [], [], min(1.0, deadline - now))
            for s in readable:
                try:
                    data, _addr = s.recvfrom(4096)
                except (BlockingIOError, socket.timeout):
                    continue

                frame = parse_packet(data)
                if frame:
                    recorder.write_frame(frame)
                else:
                    dropped += 1

            # Progress update every 5s (exactly one line per interval)
            now = time.monotonic()
            if now >= next_progress:
                remaining = deadline - now
                if recorder.frame_count > 0 and remaining > 0:
                    log.info(
                        f" {recorder.frame_count} frames collected, "
                        f"{remaining:.0f}s remaining..."
                    )
                next_progress += 5.0
    except KeyboardInterrupt:
        log.info("Interrupted by user.")
    finally:
        for s in sockets:
            s.close()

    if dropped > 0:
        log.warning(f" {dropped} unrecognized packets dropped")

    return recorder.close()
|
||||
|
||||
|
||||
# ── Main ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
def _prompt_duration(default):
    """Ask for a recording duration in seconds.

    Returns *default* when the answer is blank, not a number, or not
    positive; the user is told whenever their answer was rejected.
    """
    raw = input(f"Duration in seconds [{default}]: ").strip()
    if not raw:
        return default
    try:
        value = float(raw)
    except ValueError:
        print(f"  Invalid duration '{raw}'. Using {default}s.")
        return default
    if value <= 0:
        print(f"  Duration must be positive. Using {default}s.")
        return default
    return value


def main():
    """Parse the command line and run the selected collection mode.

    Modes, in order of precedence:

    * ``--manifest-only``: regenerate the manifest and return.
    * ``--scenarios a,b,c``: record every scenario ``--repeats`` times,
      pausing ``--pause`` seconds before each recording after the first.
    * ``--label NAME``: record a single labeled session.
    * otherwise: prompt interactively for labels until the user quits.

    Whenever at least one session was recorded, the manifest is
    regenerated and a short summary is printed.
    """
    parser = argparse.ArgumentParser(
        description="Collect CSI training data from ESP32 nodes via UDP",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Interactive label input
  python scripts/collect-training-data.py --port 5006

  # Fixed label, 60 seconds
  python scripts/collect-training-data.py --port 5006 --label walking --duration 60

  # Multiple scenarios
  python scripts/collect-training-data.py --port 5006 --scenarios walking,standing,sitting --duration 30

  # Dual ESP32 nodes
  python scripts/collect-training-data.py --port 5005 --port2 5006 --label test

  # Generate manifest from existing recordings
  python scripts/collect-training-data.py --manifest-only
""",
    )

    parser.add_argument("--port", type=int, default=5006, help="Primary UDP port (default: 5006)")
    parser.add_argument("--port2", type=int, default=None, help="Secondary UDP port for dual-node")
    parser.add_argument("--output-dir", default="data/recordings", help="Output directory (default: data/recordings)")
    parser.add_argument("--label", default=None, help="Activity label for the recording")
    parser.add_argument("--duration", type=float, default=30.0, help="Recording duration in seconds (default: 30)")
    parser.add_argument("--scenarios", default=None, help="Comma-separated list of scenarios to record sequentially")
    parser.add_argument("--pause", type=float, default=5.0, help="Pause between scenarios in seconds (default: 5)")
    parser.add_argument("--manifest-only", action="store_true", help="Only generate manifest from existing recordings")
    parser.add_argument("--repeats", type=int, default=1, help="Number of repeats per scenario (default: 1)")

    args = parser.parse_args()

    # Manifest-only mode
    if args.manifest_only:
        generate_manifest(args.output_dir)
        return

    # Reject values that would yield empty recordings or make time.sleep raise.
    if args.duration <= 0:
        parser.error("--duration must be greater than 0")
    if args.repeats < 1:
        parser.error("--repeats must be at least 1")
    if args.pause < 0:
        parser.error("--pause must not be negative")

    # Collect scenarios
    all_sessions = []

    if args.scenarios:
        # Multi-scenario sequential collection
        scenarios = [s.strip() for s in args.scenarios.split(",") if s.strip()]
        total = len(scenarios) * args.repeats
        idx = 0

        for repeat in range(args.repeats):
            for scenario in scenarios:
                idx += 1
                print(f"\n{'='*60}")
                print(f"  Scenario {idx}/{total}: '{scenario}' (repeat {repeat+1}/{args.repeats})")
                print(f"  Duration: {args.duration}s")
                print(f"{'='*60}")

                # No pause before the very first recording.
                if idx > 1:
                    print(f"  Starting in {args.pause}s... (get into position)")
                    time.sleep(args.pause)

                meta = collect_session(
                    port=args.port,
                    port2=args.port2,
                    output_dir=args.output_dir,
                    label=scenario,
                    duration=args.duration,
                    session_name=f"{scenario}_r{repeat+1:02d}",
                )
                all_sessions.append(meta)

    elif args.label:
        # Single labeled recording
        meta = collect_session(
            port=args.port,
            port2=args.port2,
            output_dir=args.output_dir,
            label=args.label,
            duration=args.duration,
        )
        all_sessions.append(meta)

    else:
        # Interactive mode — prompt for labels
        print("\nInteractive data collection mode.")
        print("Type a label for each recording, or 'q' to quit.\n")

        while True:
            try:
                label = input("Label (or 'q' to quit): ").strip()
                if label.lower() in ("q", "quit", "exit"):
                    break
                if not label:
                    print("  Empty label. Try again.")
                    continue
                duration = _prompt_duration(args.duration)
            except (EOFError, KeyboardInterrupt):
                # Closed stdin or Ctrl-C at a prompt ends the loop like 'q',
                # so sessions already recorded still reach the manifest.
                print()
                break

            print(f"  Recording '{label}' for {duration}s — starting now...")
            meta = collect_session(
                port=args.port,
                port2=args.port2,
                output_dir=args.output_dir,
                label=label,
                duration=duration,
            )
            all_sessions.append(meta)
            print()

    # Generate manifest
    if all_sessions:
        print(f"\nCollected {len(all_sessions)} session(s).")
        generate_manifest(args.output_dir)

        total_frames = sum(s.get("frame_count", 0) for s in all_sessions)
        print("\nSummary:")
        print(f"  Sessions:     {len(all_sessions)}")
        print(f"  Total frames: {total_frames}")
        print(f"  Output:       {args.output_dir}/")
        print(f"  Manifest:     {args.output_dir}/manifest.json")
    else:
        print("No sessions recorded.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
469
scripts/gcloud-train.sh
Normal file
469
scripts/gcloud-train.sh
Normal file
|
|
@ -0,0 +1,469 @@
|
|||
#!/bin/bash
|
||||
# ==============================================================================
|
||||
# GCloud GPU Training Script for WiFi-DensePose
|
||||
# ==============================================================================
|
||||
#
|
||||
# Creates a GCloud VM with GPU, runs the Rust training pipeline, downloads
|
||||
# the trained model artifacts, and tears down the VM to avoid ongoing costs.
|
||||
#
|
||||
# Usage:
|
||||
# bash scripts/gcloud-train.sh [OPTIONS]
|
||||
#
|
||||
# Options:
|
||||
# --gpu l4|a100|h100 GPU type (default: l4)
|
||||
# --zone ZONE GCloud zone (default: us-central1-a)
|
||||
# --hours N Max VM lifetime in hours (default: 2)
|
||||
# --config FILE Training config JSON (default: scripts/training-config-sweep.json entry 0)
|
||||
# --data-dir DIR Local data directory to upload (default: data/recordings)
|
||||
# --dry-run Run smoke test with synthetic data
|
||||
# --sweep Run full hyperparameter sweep (all configs)
|
||||
# --keep-vm Do not delete VM after training
|
||||
# --instance NAME Custom VM instance name
# --branch NAME Git branch to clone on the VM (default: main)
|
||||
#
|
||||
# Prerequisites:
|
||||
# - gcloud CLI authenticated: gcloud auth login
|
||||
# - Project set: gcloud config set project cognitum-20260110
|
||||
# - Quota for GPUs in the selected zone
|
||||
#
|
||||
# Cost estimates:
|
||||
# L4 (~$0.80/hr) — good for prototyping and small sweeps
|
||||
# A100 40GB (~$3.60/hr) — full training runs
|
||||
# H100 80GB (~$11.00/hr) — large batch / fast iteration
|
||||
# ==============================================================================
|
||||
|
||||
set -euo pipefail

# ── Defaults ──────────────────────────────────────────────────────────────────

PROJECT="cognitum-20260110"
GPU_TYPE="l4"
ZONE="us-central1-a"
MAX_HOURS=2
CONFIG_FILE=""
DATA_DIR="data/recordings"
DRY_RUN=false
SWEEP=false
KEEP_VM=false
INSTANCE_NAME=""
REPO_URL="https://github.com/ruvnet/wifi-densepose.git"
BRANCH="main"

# ── Parse arguments ───────────────────────────────────────────────────────────

# Print the comment header at the top of this script as the help text:
# every '#' line after the shebang, up to the first non-comment line.
# Reading the header this way keeps the help correct when the header grows.
usage() {
    awk 'NR == 1 { next } /^#/ { sub(/^# ?/, ""); print; next } { exit }' "$0"
}

# Abort with a readable message when a value-taking option is the last word
# on the command line (otherwise "$2" trips `set -u` with "unbound variable").
# Call as: require_value "$@"
require_value() {
    if [[ $# -lt 2 ]]; then
        echo "ERROR: Option $1 requires a value" >&2
        exit 1
    fi
}

while [[ $# -gt 0 ]]; do
    case "$1" in
        --gpu)       require_value "$@"; GPU_TYPE="$2"; shift 2 ;;
        --zone)      require_value "$@"; ZONE="$2"; shift 2 ;;
        --hours)     require_value "$@"; MAX_HOURS="$2"; shift 2 ;;
        --config)    require_value "$@"; CONFIG_FILE="$2"; shift 2 ;;
        --data-dir)  require_value "$@"; DATA_DIR="$2"; shift 2 ;;
        --dry-run)   DRY_RUN=true; shift ;;
        --sweep)     SWEEP=true; shift ;;
        --keep-vm)   KEEP_VM=true; shift ;;
        --instance)  require_value "$@"; INSTANCE_NAME="$2"; shift 2 ;;
        --branch)    require_value "$@"; BRANCH="$2"; shift 2 ;;
        -h|--help)
            usage
            exit 0
            ;;
        *)
            echo "ERROR: Unknown option: $1" >&2
            exit 1
            ;;
    esac
done

# MAX_HOURS is later multiplied by 60 in shell arithmetic to arm the VM's
# auto-shutdown timer, so it has to be a positive whole number.
if ! [[ "$MAX_HOURS" =~ ^[1-9][0-9]*$ ]]; then
    echo "ERROR: --hours must be a positive integer (got '$MAX_HOURS')" >&2
    exit 1
fi
|
||||
|
||||
# ── GPU configuration map ────────────────────────────────────────────────────
|
||||
|
||||
# Map the requested GPU type to its accelerator name, machine type and boot
# disk size (GB). A `case` statement is used instead of associative arrays so
# the script does not require Bash 4+ (`declare -A` is unavailable in the
# Bash 3.2 that some systems still ship as /bin/bash).
case "$GPU_TYPE" in
    l4)
        ACCELERATOR="nvidia-l4"
        MACHINE_TYPE="g2-standard-8"
        BOOT_DISK_GB="200"
        ;;
    a100)
        ACCELERATOR="nvidia-tesla-a100"
        MACHINE_TYPE="a2-highgpu-1g"
        BOOT_DISK_GB="300"
        ;;
    h100)
        ACCELERATOR="nvidia-h100-80gb"
        MACHINE_TYPE="a3-highgpu-1g"
        BOOT_DISK_GB="300"
        ;;
    *)
        echo "ERROR: Unknown GPU type '$GPU_TYPE'. Choose: l4, a100, h100"
        exit 1
        ;;
esac

# ── Instance naming ──────────────────────────────────────────────────────────

# TIMESTAMP is also reused further down to name the temporary data tarball.
TIMESTAMP=$(date +%Y%m%d-%H%M%S)
if [[ -z "$INSTANCE_NAME" ]]; then
    INSTANCE_NAME="wdp-train-${GPU_TYPE}-${TIMESTAMP}"
fi
|
||||
|
||||
# ── Announce plan ────────────────────────────────────────────────────────────
|
||||
|
||||
# Show the resolved settings before anything billable happens.
PLAN_RULE="============================================================"
cat <<PLAN_BANNER
$PLAN_RULE
 WiFi-DensePose GCloud GPU Training
$PLAN_RULE
 Project: $PROJECT
 Instance: $INSTANCE_NAME
 Zone: $ZONE
 GPU: $GPU_TYPE ($ACCELERATOR)
 Machine: $MACHINE_TYPE
 Boot disk: ${BOOT_DISK_GB}GB
 Max runtime: ${MAX_HOURS}h
 Data dir: $DATA_DIR
 Dry run: $DRY_RUN
 Sweep: $SWEEP
 Branch: $BRANCH
$PLAN_RULE

PLAN_BANNER

# ── Verify gcloud auth ──────────────────────────────────────────────────────

# Succeeds when the first ACTIVE account reported by gcloud looks like an
# e-mail address (contains '@').
has_active_gcloud_account() {
    gcloud auth list --filter=status:ACTIVE --format="value(account)" 2>/dev/null | head -1 | grep -q '@'
}

if ! has_active_gcloud_account; then
    echo "ERROR: No active gcloud account. Run: gcloud auth login"
    exit 1
fi

# Later gcloud calls rely on this configured project.
gcloud config set project "$PROJECT" --quiet
|
||||
|
||||
# ── Build startup script ─────────────────────────────────────────────────────
|
||||
|
||||
# Script executed on the VM at first boot. The heredoc delimiter is quoted, so
# nothing in it is expanded locally; it runs verbatim on the VM and logs to
# /var/log/wdp-setup.log. It creates /tmp/wdp-setup-done when it has finished,
# which is the marker the local side polls for.
STARTUP_SCRIPT=$(cat <<'STARTUP_EOF'
#!/bin/bash
set -euo pipefail
exec > /var/log/wdp-setup.log 2>&1

# The startup environment may not define HOME; both `set -u` and the cargo
# env file below need it, so fall back to the home of root.
export HOME="${HOME:-/root}"
export DEBIAN_FRONTEND=noninteractive

echo "=== WiFi-DensePose GPU VM Setup ==="
echo "Started: $(date)"

# Wait for GPU driver: up to 60 polls, 5 seconds apart
echo "Waiting for NVIDIA driver..."
for i in $(seq 1 60); do
    if nvidia-smi &>/dev/null; then
        # i counts polls, not seconds: (i - 1) sleeps of 5s have elapsed.
        echo "GPU ready after $(( (i - 1) * 5 ))s"
        nvidia-smi
        break
    fi
    sleep 5
done

if ! nvidia-smi &>/dev/null; then
    echo "ERROR: GPU driver not available after 300s"
    exit 1
fi

# Install Rust toolchain
# NOTE(review): this installs into the HOME of the startup-script user, while
# the later SSH steps source $HOME/.cargo/env of the login user - confirm
# both resolve to the same toolchain.
echo "Installing Rust toolchain..."
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable
source "$HOME/.cargo/env"
rustc --version
cargo --version

# Install system dependencies
echo "Installing system dependencies..."
apt-get update -qq
apt-get install -y -qq pkg-config libssl-dev cmake clang

# Find libtorch from the PyTorch installation of the Deep Learning VM image
echo "Locating libtorch..."
PYTORCH_LIB=$(python3 -c "import torch; print(torch.__path__[0] + '/lib')" 2>/dev/null || echo "")
if [[ -n "$PYTORCH_LIB" && -d "$PYTORCH_LIB" ]]; then
    echo "Found libtorch at: $PYTORCH_LIB"
else
    echo "WARNING: PyTorch not found in system Python. Installing via pip..."
    pip3 install torch --index-url https://download.pytorch.org/whl/cu121
    PYTORCH_LIB=$(python3 -c "import torch; print(torch.__path__[0] + '/lib')")
fi
export LIBTORCH="$PYTORCH_LIB"
export LD_LIBRARY_PATH="${LIBTORCH}${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}"

# Persist env vars
# /etc/environment holds literal KEY=VALUE pairs and is not shell-expanded,
# so only fully resolved values are written. A PATH line containing a literal
# $PATH would replace the login PATH with that text; cargo is put on PATH by
# sourcing $HOME/.cargo/env instead.
cat >> /etc/environment <<ENV_VARS
LIBTORCH=$LIBTORCH
LD_LIBRARY_PATH=$LIBTORCH
ENV_VARS

echo "=== Setup complete: $(date) ==="
touch /tmp/wdp-setup-done
STARTUP_EOF
)
|
||||
|
||||
# ── Step 1: Create the VM ────────────────────────────────────────────────────
|
||||
|
||||
echo "[1/7] Creating VM instance: $INSTANCE_NAME ..."

# --metadata-from-file needs a real path. A process substitution written
# inside double quotes is never performed, so the startup script is written
# to a temporary file instead.
STARTUP_FILE=$(mktemp "${TMPDIR:-/tmp}/wdp-startup.XXXXXX")
printf '%s\n' "$STARTUP_SCRIPT" > "$STARTUP_FILE"

# Set to true only once this run has created the VM, so that a failed create
# (for example a name clash with an existing instance) never deletes a
# machine this run does not own.
VM_CREATED=false

# Safety net for `set -e`: on any abnormal exit, remove the temp file and,
# unless --keep-vm was given, delete the VM so it does not keep billing.
cleanup_on_exit() {
    local status=$?
    rm -f "$STARTUP_FILE"
    if [[ $status -ne 0 && "$VM_CREATED" == "true" && "$KEEP_VM" == "false" ]]; then
        # The VM may already be gone if an earlier step deleted it itself.
        if gcloud compute instances describe "$INSTANCE_NAME" --zone="$ZONE" \
            --format="value(name)" >/dev/null 2>&1; then
            echo "Aborted (exit $status). Deleting VM '$INSTANCE_NAME' to avoid ongoing costs..." >&2
            gcloud compute instances delete "$INSTANCE_NAME" --zone="$ZONE" --quiet \
                || echo "WARNING: could not delete VM '$INSTANCE_NAME' - delete it manually!" >&2
        fi
    fi
}
trap cleanup_on_exit EXIT

gcloud compute instances create "$INSTANCE_NAME" \
    --project="$PROJECT" \
    --zone="$ZONE" \
    --machine-type="$MACHINE_TYPE" \
    --accelerator="type=$ACCELERATOR,count=1" \
    --image-family="common-cu121-ubuntu-2204" \
    --image-project="deeplearning-platform-release" \
    --boot-disk-size="${BOOT_DISK_GB}GB" \
    --boot-disk-type="pd-ssd" \
    --maintenance-policy=TERMINATE \
    --metadata="install-nvidia-driver=True" \
    --metadata-from-file="startup-script=${STARTUP_FILE}" \
    --scopes="default,storage-rw" \
    --labels="purpose=wdp-training,gpu=${GPU_TYPE}" \
    --quiet

VM_CREATED=true

echo " VM created. Waiting for startup script to complete..."
|
||||
|
||||
# ── Step 2: Wait for setup ───────────────────────────────────────────────────
|
||||
|
||||
echo "[2/7] Waiting for setup to complete (GPU driver + Rust toolchain)..."

# Poll over SSH for the marker file the startup script creates when it is
# done: at most 60 attempts with a 15 second pause between them.
SETUP_READY=false
for ((attempt = 1; attempt <= 60; attempt++)); do
    if gcloud compute ssh "$INSTANCE_NAME" --zone="$ZONE" --command="test -f /tmp/wdp-setup-done" --quiet 2>/dev/null; then
        echo " Setup complete after $((attempt * 15))s"
        SETUP_READY=true
        break
    fi
    # No point in sleeping after the final attempt.
    if ((attempt < 60)); then
        sleep 15
    fi
done

if [[ "$SETUP_READY" != "true" ]]; then
    echo "ERROR: Setup timed out after 15 minutes."
    echo "Check logs: gcloud compute ssh $INSTANCE_NAME --zone=$ZONE --command='cat /var/log/wdp-setup.log'"
    if [[ "$KEEP_VM" == "false" ]]; then
        echo "Cleaning up VM..."
        gcloud compute instances delete "$INSTANCE_NAME" --zone="$ZONE" --quiet
    fi
    exit 1
fi
|
||||
|
||||
# ── Step 3: Clone repo and build ─────────────────────────────────────────────
|
||||
|
||||
echo "[3/7] Cloning repository and building training binary..."

# The heredoc is unquoted: $BRANCH and $REPO_URL are filled in locally, while
# every \$ is passed through and evaluated on the VM.
# The commands run as the SSH login user, whose home directory is not
# guaranteed to contain the toolchain the startup script installed, so rustup
# is installed for this user when ~/.cargo/env is missing.
gcloud compute ssh "$INSTANCE_NAME" --zone="$ZONE" --command="$(cat <<CLONE_EOF
set -euo pipefail

# Make sure the login user has a Rust toolchain
if [[ ! -f "\$HOME/.cargo/env" ]]; then
    echo "Installing Rust toolchain for \$(whoami)..."
    curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable
fi
source "\$HOME/.cargo/env"

# Clone the repo
if [[ ! -d ~/wifi-densepose ]]; then
    git clone --depth 1 --branch "$BRANCH" "$REPO_URL" ~/wifi-densepose
fi

# Set libtorch environment
export LIBTORCH=\$(python3 -c "import torch; print(torch.__path__[0] + '/lib')")
export LD_LIBRARY_PATH="\${LIBTORCH}:\${LD_LIBRARY_PATH:-}"

# Build the training binary with tch-backend
cd ~/wifi-densepose/rust-port/wifi-densepose-rs
echo "Building with LIBTORCH=\$LIBTORCH ..."
cargo build --release --features tch-backend --bin train 2>&1 | tail -5

echo "Build complete."
ls -lh target/release/train
CLONE_EOF
)"
|
||||
|
||||
# ── Step 4: Upload training data ─────────────────────────────────────────────
|
||||
|
||||
echo "[4/7] Uploading training data..."

if [[ ! -d "$DATA_DIR" || -z "$(ls -A "$DATA_DIR" 2>/dev/null)" ]]; then
    # Nothing to upload: the data directory is missing or empty.
    echo " No local data at '$DATA_DIR'. Training will use --dry-run or MM-Fi."
    if [[ "$DRY_RUN" == "false" && "$SWEEP" == "false" ]]; then
        echo " WARNING: No data and --dry-run not set. Forcing --dry-run."
        DRY_RUN=true
    fi
else
    # Pack the data directory (by its base name) into one tarball
    DATA_TAR="/tmp/wdp-training-data-${TIMESTAMP}.tar.gz"
    DATA_PARENT="$(dirname "$DATA_DIR")"
    DATA_BASENAME="$(basename "$DATA_DIR")"
    tar czf "$DATA_TAR" -C "$DATA_PARENT" "$DATA_BASENAME"
    DATA_SIZE=$(du -h "$DATA_TAR" | cut -f1)
    echo " Uploading ${DATA_SIZE} of training data..."

    # Commands run on the VM to unpack the tarball under the repo data dir
    REMOTE_EXTRACT="
mkdir -p ~/wifi-densepose/data
tar xzf ~/training-data.tar.gz -C ~/wifi-densepose/data/
echo 'Data extracted:'
find ~/wifi-densepose/data -name '*.jsonl' -o -name '*.csi.jsonl' | head -20
"

    gcloud compute scp "$DATA_TAR" "${INSTANCE_NAME}:~/training-data.tar.gz" --zone="$ZONE" --quiet
    gcloud compute ssh "$INSTANCE_NAME" --zone="$ZONE" --command="$REMOTE_EXTRACT"
    rm -f "$DATA_TAR"
fi
|
||||
|
||||
# ── Step 5: Upload config and run training ────────────────────────────────────
|
||||
|
||||
echo "[5/7] Running training..."

# Sweep mode: the sweep definition is copied to ~/sweep-configs.json on the VM
if [[ "$SWEEP" == "true" ]]; then
    SWEEP_FILE="scripts/training-config-sweep.json"
    if [[ ! -f "$SWEEP_FILE" ]]; then
        echo "ERROR: Sweep config not found at $SWEEP_FILE"
        exit 1
    fi
    gcloud compute scp "$SWEEP_FILE" "${INSTANCE_NAME}:~/sweep-configs.json" --zone="$ZONE" --quiet
fi

# Single config: copied to ~/train-config.json on the VM when one was given
if [[ -n "$CONFIG_FILE" ]]; then
    gcloud compute scp "$CONFIG_FILE" "${INSTANCE_NAME}:~/train-config.json" --zone="$ZONE" --quiet
fi

# Shared preamble for every remote training command. The shutdown delay is
# computed here, locally; everything written as \$ is evaluated on the VM.
SHUTDOWN_MINUTES=$((MAX_HOURS * 60))
TRAIN_CMD_BASE="
set -euo pipefail
source \$HOME/.cargo/env
export LIBTORCH=\$(python3 -c \"import torch; print(torch.__path__[0] + '/lib')\")
export LD_LIBRARY_PATH=\"\${LIBTORCH}:\${LD_LIBRARY_PATH:-}\"
cd ~/wifi-densepose/rust-port/wifi-densepose-rs

# Set auto-shutdown timer (safety net)
sudo shutdown -P +${SHUTDOWN_MINUTES} &

TRAIN_BIN=./target/release/train
"
|
||||
|
||||
# Dry-run flags are resolved locally and substituted into the remote commands.
# The default-config run additionally limits the synthetic sample count.
DRY_RUN_FLAG=""
DEFAULT_DRY_RUN_FLAGS=""
if [[ "$DRY_RUN" == "true" ]]; then
    DRY_RUN_FLAG="--dry-run"
    DEFAULT_DRY_RUN_FLAGS="--dry-run --dry-run-samples 256"
fi

# In the heredocs below, unescaped \$VAR references are filled in locally;
# everything written as \\\$ in the script source (shown here as \$) is
# evaluated on the VM.
if [[ "$SWEEP" == "true" ]]; then
    # Run all configs in the sweep file
    gcloud compute ssh "$INSTANCE_NAME" --zone="$ZONE" --command="$(cat <<SWEEP_EOF
$TRAIN_CMD_BASE

echo "=== Hyperparameter Sweep ==="
SWEEP_FILE=~/sweep-configs.json
NUM_CONFIGS=\$(python3 -c "import json; print(len(json.load(open('\$SWEEP_FILE'))['configs']))")
echo "Running \$NUM_CONFIGS configurations..."

mkdir -p ~/results

for i in \$(seq 0 \$((NUM_CONFIGS - 1))); do
    echo ""
    echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
    echo " Config \$((i+1)) / \$NUM_CONFIGS"
    echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"

    # Extract single config to temp file. The shell loop index is handed to
    # Python as a real variable: the f-strings below format it as {i:02d},
    # which needs a Python name, not a shell expansion.
    python3 -c "
import json

i = \$i
with open('\$SWEEP_FILE') as fh:
    sweep = json.load(fh)
cfg = sweep['configs'][i]
# Merge with base config (per-config values win)
base = sweep.get('base', {})
merged = {**base, **cfg}
# Set checkpoint dir per config
merged['checkpoint_dir'] = f'checkpoints/sweep_{i:02d}'
merged['log_dir'] = f'logs/sweep_{i:02d}'
with open(f'/tmp/sweep_config_{i}.json', 'w') as fh:
    json.dump(merged, fh, indent=2)
print(f\"Config {i}: lr={merged.get('learning_rate', '?')}, bs={merged.get('batch_size', '?')}, bb={merged.get('backbone_channels', '?')}\")
"

    START_TIME=\$(date +%s)

    # A failing config must not abort the sweep, but it should be visible.
    if ! \$TRAIN_BIN --config /tmp/sweep_config_\${i}.json --cuda $DRY_RUN_FLAG 2>&1 | tee ~/results/sweep_\${i}.log; then
        echo " WARNING: config \$i failed; continuing with the next one"
    fi

    END_TIME=\$(date +%s)
    ELAPSED=\$(( END_TIME - START_TIME ))
    echo " Completed in \${ELAPSED}s"
done

echo ""
echo "=== Sweep Complete ==="
echo "Results in ~/results/"
ls -lh ~/results/
SWEEP_EOF
)"
elif [[ -n "$CONFIG_FILE" ]]; then
    # Single config run
    gcloud compute ssh "$INSTANCE_NAME" --zone="$ZONE" --command="$(cat <<SINGLE_EOF
$TRAIN_CMD_BASE
echo "=== Training with custom config ==="
\$TRAIN_BIN --config ~/train-config.json --cuda $DRY_RUN_FLAG 2>&1 | tee ~/train.log
SINGLE_EOF
)"
else
    # Default config run
    gcloud compute ssh "$INSTANCE_NAME" --zone="$ZONE" --command="$(cat <<DEFAULT_EOF
$TRAIN_CMD_BASE
echo "=== Training with default config ==="
\$TRAIN_BIN --cuda $DEFAULT_DRY_RUN_FLAGS 2>&1 | tee ~/train.log
DEFAULT_EOF
)"
fi
|
||||
|
||||
# ── Step 6: Download results ─────────────────────────────────────────────────
|
||||
|
||||
echo "[6/7] Downloading trained model artifacts..."

LOCAL_RESULTS="training-results/${INSTANCE_NAME}"
mkdir -p "$LOCAL_RESULTS"

# Best-effort copy of one file from the VM home directory into LOCAL_RESULTS.
# $1 is the remote path, $2 the local file name; a missing remote file is
# silently tolerated.
fetch_artifact() {
    gcloud compute scp "${INSTANCE_NAME}:$1" \
        "${LOCAL_RESULTS}/$2" --zone="$ZONE" --quiet 2>/dev/null || true
}

# Package results on the VM (all steps are best-effort)
PACKAGE_CMD="
cd ~/wifi-densepose/rust-port/wifi-densepose-rs
tar czf ~/training-artifacts.tar.gz \
checkpoints/ \
logs/ \
2>/dev/null || true

# Also grab sweep results if they exist
if [[ -d ~/results ]]; then
tar czf ~/sweep-results.tar.gz -C ~ results/ 2>/dev/null || true
fi

ls -lh ~/training-artifacts.tar.gz ~/sweep-results.tar.gz 2>/dev/null || true
"
gcloud compute ssh "$INSTANCE_NAME" --zone="$ZONE" --command="$PACKAGE_CMD"

# Download artifacts, sweep results (sweep mode only) and the training log
fetch_artifact "~/training-artifacts.tar.gz" "training-artifacts.tar.gz"
if [[ "$SWEEP" == "true" ]]; then
    fetch_artifact "~/sweep-results.tar.gz" "sweep-results.tar.gz"
fi
fetch_artifact "~/train.log" "train.log"

# Extract locally and list the model files that came back
ARTIFACT_TAR="${LOCAL_RESULTS}/training-artifacts.tar.gz"
if [[ -f "$ARTIFACT_TAR" ]]; then
    tar xzf "$ARTIFACT_TAR" -C "$LOCAL_RESULTS/"
    echo " Artifacts extracted to: $LOCAL_RESULTS/"
    find "$LOCAL_RESULTS" -name "*.pt" -o -name "*.onnx" -o -name "*.rvf" 2>/dev/null | head -20
fi
|
||||
|
||||
# ── Step 7: Cleanup ──────────────────────────────────────────────────────────
|
||||
|
||||
# Delete the VM unless the user asked to keep it.
if [[ "$KEEP_VM" != "true" ]]; then
    echo "[7/7] Deleting VM to avoid ongoing costs..."
    gcloud compute instances delete "$INSTANCE_NAME" --zone="$ZONE" --quiet
    echo " VM deleted."
else
    printf '%s\n' \
        "[7/7] Keeping VM alive (--keep-vm). Remember to delete it manually:" \
        " gcloud compute instances delete $INSTANCE_NAME --zone=$ZONE --quiet" \
        " SSH: gcloud compute ssh $INSTANCE_NAME --zone=$ZONE"
fi

# ── Summary ──────────────────────────────────────────────────────────────────

SUMMARY_RULE="============================================================"
printf '%s\n' \
    "" \
    "$SUMMARY_RULE" \
    " Training Complete" \
    "$SUMMARY_RULE" \
    " Results: $LOCAL_RESULTS/" \
    " GPU: $GPU_TYPE ($ZONE)" \
    " Instance: $INSTANCE_NAME"
# Extra reminder line only when the VM was left running.
if [[ "$KEEP_VM" == "true" ]]; then
    printf '%s\n' " VM: STILL RUNNING (delete manually!)"
fi
printf '%s\n' "$SUMMARY_RULE"
|
||||
155
scripts/training-config-sweep.json
Normal file
155
scripts/training-config-sweep.json
Normal file
|
|
@ -0,0 +1,155 @@
|
|||
{
|
||||
"description": "WiFi-DensePose hyperparameter sweep — 10 configurations exploring learning rate, batch size, backbone width, window length, loss ratios, and warmup schedules.",
|
||||
"base": {
|
||||
"num_subcarriers": 56,
|
||||
"native_subcarriers": 114,
|
||||
"num_antennas_tx": 3,
|
||||
"num_antennas_rx": 3,
|
||||
"heatmap_size": 56,
|
||||
"num_keypoints": 17,
|
||||
"num_body_parts": 24,
|
||||
"weight_decay": 1e-4,
|
||||
"num_epochs": 50,
|
||||
"lr_gamma": 0.1,
|
||||
"grad_clip_norm": 1.0,
|
||||
"val_every_epochs": 1,
|
||||
"early_stopping_patience": 10,
|
||||
"save_top_k": 3,
|
||||
"use_gpu": true,
|
||||
"gpu_device_id": 0,
|
||||
"num_workers": 4,
|
||||
"seed": 42
|
||||
},
|
||||
"configs": [
|
||||
{
|
||||
"_name": "baseline",
|
||||
"_description": "Default config — reference baseline",
|
||||
"learning_rate": 1e-3,
|
||||
"batch_size": 8,
|
||||
"backbone_channels": 256,
|
||||
"window_frames": 100,
|
||||
"warmup_epochs": 5,
|
||||
"lr_milestones": [30, 45],
|
||||
"lambda_kp": 0.3,
|
||||
"lambda_dp": 0.6,
|
||||
"lambda_tr": 0.1
|
||||
},
|
||||
{
|
||||
"_name": "low_lr_large_batch",
|
||||
"_description": "Lower LR with larger batch — stable convergence",
|
||||
"learning_rate": 1e-4,
|
||||
"batch_size": 16,
|
||||
"backbone_channels": 256,
|
||||
"window_frames": 100,
|
||||
"warmup_epochs": 10,
|
||||
"lr_milestones": [30, 45],
|
||||
"lambda_kp": 0.3,
|
||||
"lambda_dp": 0.6,
|
||||
"lambda_tr": 0.1
|
||||
},
|
||||
{
|
||||
"_name": "high_lr_small_batch",
|
||||
"_description": "Higher LR with small batch — fast exploration",
|
||||
"learning_rate": 2e-3,
|
||||
"batch_size": 4,
|
||||
"backbone_channels": 256,
|
||||
"window_frames": 100,
|
||||
"warmup_epochs": 3,
|
||||
"lr_milestones": [20, 40],
|
||||
"lambda_kp": 0.3,
|
||||
"lambda_dp": 0.6,
|
||||
"lambda_tr": 0.1
|
||||
},
|
||||
{
|
||||
"_name": "narrow_backbone",
|
||||
"_description": "128-channel backbone — faster training, lower VRAM",
|
||||
"learning_rate": 1e-3,
|
||||
"batch_size": 16,
|
||||
"backbone_channels": 128,
|
||||
"window_frames": 100,
|
||||
"warmup_epochs": 5,
|
||||
"lr_milestones": [30, 45],
|
||||
"lambda_kp": 0.3,
|
||||
"lambda_dp": 0.6,
|
||||
"lambda_tr": 0.1
|
||||
},
|
||||
{
|
||||
"_name": "short_window",
|
||||
"_description": "50-frame window — lower latency, tests temporal sensitivity",
|
||||
"learning_rate": 5e-4,
|
||||
"batch_size": 16,
|
||||
"backbone_channels": 256,
|
||||
"window_frames": 50,
|
||||
"warmup_epochs": 5,
|
||||
"lr_milestones": [30, 45],
|
||||
"lambda_kp": 0.3,
|
||||
"lambda_dp": 0.6,
|
||||
"lambda_tr": 0.1
|
||||
},
|
||||
{
|
||||
"_name": "keypoint_heavy",
|
||||
"_description": "Heavier keypoint loss — prioritize skeleton accuracy",
|
||||
"learning_rate": 5e-4,
|
||||
"batch_size": 8,
|
||||
"backbone_channels": 256,
|
||||
"window_frames": 100,
|
||||
"warmup_epochs": 5,
|
||||
"lr_milestones": [30, 45],
|
||||
"lambda_kp": 0.5,
|
||||
"lambda_dp": 0.4,
|
||||
"lambda_tr": 0.1
|
||||
},
|
||||
{
|
||||
"_name": "contrastive_heavy",
|
||||
"_description": "Strong contrastive/transfer loss — self-supervised pretraining focus",
|
||||
"learning_rate": 5e-4,
|
||||
"batch_size": 8,
|
||||
"backbone_channels": 256,
|
||||
"window_frames": 100,
|
||||
"warmup_epochs": 10,
|
||||
"lr_milestones": [30, 45],
|
||||
"lambda_kp": 0.2,
|
||||
"lambda_dp": 0.3,
|
||||
"lambda_tr": 0.5
|
||||
},
|
||||
{
|
||||
"_name": "wide_backbone_long_warmup",
|
||||
"_description": "256-ch backbone + long warmup + moderate LR",
|
||||
"learning_rate": 5e-4,
|
||||
"batch_size": 8,
|
||||
"backbone_channels": 256,
|
||||
"window_frames": 100,
|
||||
"warmup_epochs": 10,
|
||||
"lr_milestones": [35, 48],
|
||||
"lambda_kp": 0.3,
|
||||
"lambda_dp": 0.6,
|
||||
"lambda_tr": 0.1
|
||||
},
|
||||
{
|
||||
"_name": "narrow_short_aggressive",
|
||||
"_description": "128-ch + 50-frame + high LR — fast cheap exploration",
|
||||
"learning_rate": 2e-3,
|
||||
"batch_size": 16,
|
||||
"backbone_channels": 128,
|
||||
"window_frames": 50,
|
||||
"warmup_epochs": 3,
|
||||
"lr_milestones": [20, 40],
|
||||
"lambda_kp": 0.4,
|
||||
"lambda_dp": 0.5,
|
||||
"lambda_tr": 0.1
|
||||
},
|
||||
{
|
||||
"_name": "balanced_medium",
|
||||
"_description": "Balanced loss, medium LR, medium batch — robust default",
|
||||
"learning_rate": 5e-4,
|
||||
"batch_size": 8,
|
||||
"backbone_channels": 256,
|
||||
"window_frames": 100,
|
||||
"warmup_epochs": 5,
|
||||
"lr_milestones": [25, 40],
|
||||
"lambda_kp": 0.35,
|
||||
"lambda_dp": 0.45,
|
||||
"lambda_tr": 0.2
|
||||
}
|
||||
]
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue