mirror of
https://github.com/ruvnet/RuView.git
synced 2026-04-28 14:09:33 +00:00
- gcloud-train.sh: L4/A100/H100 VM provisioning, Rust build, training with --cuda, artifact download, auto-cleanup ($0.80-$8.50/hr) - training-config-sweep.json: 10 hyperparameter configs (LR, batch, backbone, windows, loss weights, warmup) - collect-training-data.py: UDP listener for 2-node ESP32 CSI recording to .csi.jsonl with interactive/batch labeling and manifest generation - benchmark-model.py: ONNX latency/throughput/PCK/FLOPs profiling with multi-model sweep comparison Co-Authored-By: claude-flow <ruv@ruv.net>
483 lines
17 KiB
Python
483 lines
17 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
WiFi-DensePose Training Data Collector
|
|
|
|
Listens on UDP for CSI data from ESP32 nodes and records to .csi.jsonl
|
|
files compatible with the Rust training pipeline (MmFiDataset / CsiDataset).
|
|
|
|
Supports two packet formats:
|
|
- ADR-069 feature vectors (magic 0xC5110003, 48 bytes) — 8-dim pre-extracted
|
|
- ADR-018 raw CSI frames (magic 0xC5110001, variable) — full subcarrier data
|
|
|
|
Usage:
|
|
# Interactive — prompts for scenario labels
|
|
python scripts/collect-training-data.py --port 5006
|
|
|
|
# Scripted — fixed label, 60s per recording
|
|
python scripts/collect-training-data.py --port 5006 --label walking --duration 60
|
|
|
|
# Multiple scenarios in sequence
|
|
python scripts/collect-training-data.py --port 5006 --scenarios walking,standing,sitting --duration 30
|
|
|
|
# Dual-node collection (two ESP32s on different ports)
|
|
python scripts/collect-training-data.py --port 5005 --port2 5006 --label walking
|
|
|
|
# Generate manifest only from existing recordings
|
|
python scripts/collect-training-data.py --manifest-only --output-dir data/recordings
|
|
|
|
Prerequisites:
|
|
- ESP32 nodes streaming CSI on UDP (see firmware/esp32-csi-node)
|
|
- Python 3.9+
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import logging
|
|
import os
|
|
import socket
|
|
import struct
|
|
import sys
|
|
import time
|
|
import signal
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
logging.basicConfig(
|
|
level=logging.INFO,
|
|
format="%(asctime)s [%(levelname)s] %(message)s",
|
|
datefmt="%H:%M:%S",
|
|
)
|
|
log = logging.getLogger("collect-data")
|
|
|
|
# ── Packet formats (must match firmware) ─────────────────────────────────────
|
|
|
|
# ADR-018 raw CSI frame header
|
|
MAGIC_CSI_RAW = 0xC5110001
|
|
# ADR-069 feature vector packet
|
|
MAGIC_FEATURES = 0xC5110003
|
|
FEATURE_PKT_FMT = "<IBBHq8f"
|
|
FEATURE_PKT_SIZE = struct.calcsize(FEATURE_PKT_FMT) # 48 bytes
|
|
|
|
# Raw CSI header: magic(4) + node_id(1) + antenna_cfg(1) + n_sub(2) + rssi(1) + noise(1) + channel(1) + reserved(1) + timestamp_ms(4)
|
|
RAW_CSI_HDR_FMT = "<IBBHbbBxI"
|
|
RAW_CSI_HDR_SIZE = struct.calcsize(RAW_CSI_HDR_FMT) # 16 bytes
|
|
|
|
|
|
# ── Packet parsing ───────────────────────────────────────────────────────────
|
|
|
|
def parse_packet(data: bytes) -> Optional[dict]:
|
|
"""Parse a UDP packet into a frame dict, or None if unrecognized."""
|
|
if len(data) < 4:
|
|
return None
|
|
|
|
magic = struct.unpack_from("<I", data)[0]
|
|
|
|
if magic == MAGIC_FEATURES and len(data) >= FEATURE_PKT_SIZE:
|
|
return _parse_feature_packet(data)
|
|
elif magic == MAGIC_CSI_RAW and len(data) >= RAW_CSI_HDR_SIZE:
|
|
return _parse_raw_csi_packet(data)
|
|
else:
|
|
return None
|
|
|
|
|
|
def _parse_feature_packet(data: bytes) -> Optional[dict]:
|
|
"""Parse ADR-069 feature vector packet (48 bytes)."""
|
|
try:
|
|
magic, node_id, _, seq, ts_us, *features = struct.unpack_from(FEATURE_PKT_FMT, data)
|
|
except struct.error:
|
|
return None
|
|
|
|
if magic != MAGIC_FEATURES:
|
|
return None
|
|
|
|
# Reject NaN/inf
|
|
import math
|
|
if any(math.isnan(f) or math.isinf(f) for f in features):
|
|
return None
|
|
|
|
return {
|
|
"type": "features",
|
|
"node_id": node_id,
|
|
"seq": seq,
|
|
"timestamp_us": ts_us,
|
|
"timestamp": ts_us / 1_000_000.0,
|
|
"features": features,
|
|
"subcarriers": features, # Use features as subcarrier proxy for training
|
|
"rssi": 0.0,
|
|
"noise_floor": 0.0,
|
|
}
|
|
|
|
|
|
def _parse_raw_csi_packet(data: bytes) -> Optional[dict]:
|
|
"""Parse ADR-018 raw CSI frame with full subcarrier data."""
|
|
try:
|
|
magic, node_id, ant_cfg, n_sub, rssi, noise, channel, ts_ms = struct.unpack_from(
|
|
RAW_CSI_HDR_FMT, data
|
|
)
|
|
except struct.error:
|
|
return None
|
|
|
|
if magic != MAGIC_CSI_RAW:
|
|
return None
|
|
|
|
# Subcarrier data follows header as int16 I/Q pairs
|
|
payload_offset = RAW_CSI_HDR_SIZE
|
|
expected_bytes = n_sub * 2 * 2 # n_sub * (I + Q) * int16
|
|
if len(data) < payload_offset + expected_bytes:
|
|
return None
|
|
|
|
iq_data = struct.unpack_from(f"<{n_sub * 2}h", data, payload_offset)
|
|
# Convert I/Q pairs to amplitude
|
|
subcarriers = []
|
|
for i in range(0, len(iq_data), 2):
|
|
real, imag = iq_data[i], iq_data[i + 1]
|
|
amplitude = (real ** 2 + imag ** 2) ** 0.5
|
|
subcarriers.append(amplitude)
|
|
|
|
return {
|
|
"type": "raw_csi",
|
|
"node_id": node_id,
|
|
"antenna_config": ant_cfg,
|
|
"n_subcarriers": n_sub,
|
|
"channel": channel,
|
|
"timestamp": ts_ms / 1000.0,
|
|
"subcarriers": subcarriers,
|
|
"rssi": float(rssi),
|
|
"noise_floor": float(noise),
|
|
}
|
|
|
|
|
|
# ── JSONL recording ──────────────────────────────────────────────────────────
|
|
|
|
class CsiRecorder:
|
|
"""Records CSI frames to .csi.jsonl files compatible with the Rust pipeline."""
|
|
|
|
def __init__(self, output_dir: str, session_name: str, label: Optional[str] = None):
|
|
self.output_dir = Path(output_dir)
|
|
self.output_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
ts = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
|
|
safe_name = session_name.replace(" ", "_").replace("/", "_")
|
|
self.session_id = f"{safe_name}-{ts}"
|
|
self.label = label
|
|
self.file_path = self.output_dir / f"{self.session_id}.csi.jsonl"
|
|
self.meta_path = self.output_dir / f"{self.session_id}.csi.meta.json"
|
|
self.frame_count = 0
|
|
self.start_time = time.time()
|
|
self.started_at = datetime.now(timezone.utc).isoformat()
|
|
self._file = None
|
|
|
|
def open(self):
|
|
self._file = open(self.file_path, "a", encoding="utf-8")
|
|
log.info(f"Recording to: {self.file_path}")
|
|
|
|
def write_frame(self, frame: dict):
|
|
"""Write a single frame as a JSONL line."""
|
|
if self._file is None:
|
|
return
|
|
|
|
record = {
|
|
"timestamp": frame.get("timestamp", time.time()),
|
|
"subcarriers": frame.get("subcarriers", []),
|
|
"rssi": frame.get("rssi", 0.0),
|
|
"noise_floor": frame.get("noise_floor", 0.0),
|
|
"features": {
|
|
k: v for k, v in frame.items()
|
|
if k not in ("timestamp", "subcarriers", "rssi", "noise_floor", "type")
|
|
},
|
|
}
|
|
|
|
line = json.dumps(record, separators=(",", ":"))
|
|
self._file.write(line + "\n")
|
|
self.frame_count += 1
|
|
|
|
if self.frame_count % 500 == 0:
|
|
self._file.flush()
|
|
|
|
def close(self) -> dict:
|
|
"""Close the recording and write metadata. Returns session info."""
|
|
if self._file:
|
|
self._file.flush()
|
|
self._file.close()
|
|
self._file = None
|
|
|
|
ended_at = datetime.now(timezone.utc).isoformat()
|
|
elapsed = time.time() - self.start_time
|
|
file_size = self.file_path.stat().st_size if self.file_path.exists() else 0
|
|
|
|
meta = {
|
|
"id": self.session_id,
|
|
"name": self.session_id,
|
|
"label": self.label,
|
|
"started_at": self.started_at,
|
|
"ended_at": ended_at,
|
|
"duration_secs": round(elapsed, 2),
|
|
"frame_count": self.frame_count,
|
|
"file_size_bytes": file_size,
|
|
"file_path": str(self.file_path),
|
|
"fps": round(self.frame_count / elapsed, 1) if elapsed > 0 else 0,
|
|
}
|
|
|
|
with open(self.meta_path, "w", encoding="utf-8") as f:
|
|
json.dump(meta, f, indent=2)
|
|
|
|
log.info(
|
|
f"Recording stopped: {self.frame_count} frames in {elapsed:.1f}s "
|
|
f"({meta['fps']} fps, {file_size / 1024:.1f} KB)"
|
|
)
|
|
return meta
|
|
|
|
|
|
# ── Manifest generation ──────────────────────────────────────────────────────
|
|
|
|
def generate_manifest(output_dir: str) -> dict:
|
|
"""Scan recordings directory and generate a dataset manifest JSON."""
|
|
rec_dir = Path(output_dir)
|
|
sessions = []
|
|
|
|
for meta_file in sorted(rec_dir.glob("*.csi.meta.json")):
|
|
try:
|
|
with open(meta_file, "r") as f:
|
|
meta = json.load(f)
|
|
sessions.append(meta)
|
|
except (json.JSONDecodeError, OSError) as e:
|
|
log.warning(f"Skipping {meta_file}: {e}")
|
|
|
|
# Aggregate stats
|
|
total_frames = sum(s.get("frame_count", 0) for s in sessions)
|
|
total_bytes = sum(s.get("file_size_bytes", 0) for s in sessions)
|
|
labels = sorted(set(s.get("label", "unlabeled") or "unlabeled" for s in sessions))
|
|
|
|
manifest = {
|
|
"dataset": "wifi-densepose-csi",
|
|
"generated_at": datetime.now(timezone.utc).isoformat(),
|
|
"directory": str(rec_dir),
|
|
"num_sessions": len(sessions),
|
|
"total_frames": total_frames,
|
|
"total_size_bytes": total_bytes,
|
|
"total_size_mb": round(total_bytes / (1024 * 1024), 2),
|
|
"labels": labels,
|
|
"sessions": sessions,
|
|
}
|
|
|
|
manifest_path = rec_dir / "manifest.json"
|
|
with open(manifest_path, "w", encoding="utf-8") as f:
|
|
json.dump(manifest, f, indent=2)
|
|
|
|
log.info(
|
|
f"Manifest: {len(sessions)} sessions, {total_frames} frames, "
|
|
f"{manifest['total_size_mb']} MB, labels={labels}"
|
|
)
|
|
log.info(f"Written to: {manifest_path}")
|
|
return manifest
|
|
|
|
|
|
# ── UDP listener ─────────────────────────────────────────────────────────────
|
|
|
|
def collect_session(
|
|
port: int,
|
|
port2: Optional[int],
|
|
output_dir: str,
|
|
label: str,
|
|
duration: float,
|
|
session_name: Optional[str] = None,
|
|
) -> dict:
|
|
"""Run a single collection session. Returns session metadata."""
|
|
name = session_name or label or "session"
|
|
recorder = CsiRecorder(output_dir, name, label)
|
|
recorder.open()
|
|
|
|
# Bind primary socket
|
|
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
|
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
|
sock.bind(("0.0.0.0", port))
|
|
sock.settimeout(1.0)
|
|
sockets = [sock]
|
|
|
|
# Bind secondary socket if specified
|
|
if port2:
|
|
sock2 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
|
sock2.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
|
sock2.bind(("0.0.0.0", port2))
|
|
sock2.settimeout(0.1)
|
|
sockets.append(sock2)
|
|
|
|
log.info(
|
|
f"Collecting '{label}' for {duration}s on port(s) "
|
|
f"{port}{f', {port2}' if port2 else ''}"
|
|
)
|
|
|
|
start = time.time()
|
|
dropped = 0
|
|
|
|
try:
|
|
while time.time() - start < duration:
|
|
for s in sockets:
|
|
try:
|
|
data, addr = s.recvfrom(4096)
|
|
except socket.timeout:
|
|
continue
|
|
|
|
frame = parse_packet(data)
|
|
if frame:
|
|
recorder.write_frame(frame)
|
|
else:
|
|
dropped += 1
|
|
|
|
# Progress update every 5s
|
|
elapsed = time.time() - start
|
|
if recorder.frame_count > 0 and int(elapsed) % 5 == 0 and int(elapsed) > 0:
|
|
remaining = duration - elapsed
|
|
if remaining > 0 and int(elapsed * 10) % 50 == 0:
|
|
log.info(
|
|
f" {recorder.frame_count} frames collected, "
|
|
f"{remaining:.0f}s remaining..."
|
|
)
|
|
except KeyboardInterrupt:
|
|
log.info("Interrupted by user.")
|
|
finally:
|
|
for s in sockets:
|
|
s.close()
|
|
|
|
if dropped > 0:
|
|
log.warning(f" {dropped} unrecognized packets dropped")
|
|
|
|
return recorder.close()
|
|
|
|
|
|
# ── Main ─────────────────────────────────────────────────────────────────────
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser(
|
|
description="Collect CSI training data from ESP32 nodes via UDP",
|
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
|
epilog="""
|
|
Examples:
|
|
# Interactive label input
|
|
python scripts/collect-training-data.py --port 5006
|
|
|
|
# Fixed label, 60 seconds
|
|
python scripts/collect-training-data.py --port 5006 --label walking --duration 60
|
|
|
|
# Multiple scenarios
|
|
python scripts/collect-training-data.py --port 5006 --scenarios walking,standing,sitting --duration 30
|
|
|
|
# Dual ESP32 nodes
|
|
python scripts/collect-training-data.py --port 5005 --port2 5006 --label test
|
|
|
|
# Generate manifest from existing recordings
|
|
python scripts/collect-training-data.py --manifest-only
|
|
""",
|
|
)
|
|
|
|
parser.add_argument("--port", type=int, default=5006, help="Primary UDP port (default: 5006)")
|
|
parser.add_argument("--port2", type=int, default=None, help="Secondary UDP port for dual-node")
|
|
parser.add_argument("--output-dir", default="data/recordings", help="Output directory (default: data/recordings)")
|
|
parser.add_argument("--label", default=None, help="Activity label for the recording")
|
|
parser.add_argument("--duration", type=float, default=30.0, help="Recording duration in seconds (default: 30)")
|
|
parser.add_argument("--scenarios", default=None, help="Comma-separated list of scenarios to record sequentially")
|
|
parser.add_argument("--pause", type=float, default=5.0, help="Pause between scenarios in seconds (default: 5)")
|
|
parser.add_argument("--manifest-only", action="store_true", help="Only generate manifest from existing recordings")
|
|
parser.add_argument("--repeats", type=int, default=1, help="Number of repeats per scenario (default: 1)")
|
|
|
|
args = parser.parse_args()
|
|
|
|
# Manifest-only mode
|
|
if args.manifest_only:
|
|
generate_manifest(args.output_dir)
|
|
return
|
|
|
|
# Collect scenarios
|
|
all_sessions = []
|
|
|
|
if args.scenarios:
|
|
# Multi-scenario sequential collection
|
|
scenarios = [s.strip() for s in args.scenarios.split(",") if s.strip()]
|
|
total = len(scenarios) * args.repeats
|
|
idx = 0
|
|
|
|
for repeat in range(args.repeats):
|
|
for scenario in scenarios:
|
|
idx += 1
|
|
print(f"\n{'='*60}")
|
|
print(f" Scenario {idx}/{total}: '{scenario}' (repeat {repeat+1}/{args.repeats})")
|
|
print(f" Duration: {args.duration}s")
|
|
print(f"{'='*60}")
|
|
|
|
if idx > 1:
|
|
print(f" Starting in {args.pause}s... (get into position)")
|
|
time.sleep(args.pause)
|
|
|
|
meta = collect_session(
|
|
port=args.port,
|
|
port2=args.port2,
|
|
output_dir=args.output_dir,
|
|
label=scenario,
|
|
duration=args.duration,
|
|
session_name=f"{scenario}_r{repeat+1:02d}",
|
|
)
|
|
all_sessions.append(meta)
|
|
|
|
elif args.label:
|
|
# Single labeled recording
|
|
meta = collect_session(
|
|
port=args.port,
|
|
port2=args.port2,
|
|
output_dir=args.output_dir,
|
|
label=args.label,
|
|
duration=args.duration,
|
|
)
|
|
all_sessions.append(meta)
|
|
|
|
else:
|
|
# Interactive mode — prompt for labels
|
|
print("\nInteractive data collection mode.")
|
|
print("Type a label for each recording, or 'q' to quit.\n")
|
|
|
|
while True:
|
|
label = input("Label (or 'q' to quit): ").strip()
|
|
if label.lower() in ("q", "quit", "exit"):
|
|
break
|
|
if not label:
|
|
print(" Empty label. Try again.")
|
|
continue
|
|
|
|
duration = args.duration
|
|
try:
|
|
dur_input = input(f"Duration in seconds [{duration}]: ").strip()
|
|
if dur_input:
|
|
duration = float(dur_input)
|
|
except ValueError:
|
|
pass
|
|
|
|
print(f" Recording '{label}' for {duration}s — starting now...")
|
|
meta = collect_session(
|
|
port=args.port,
|
|
port2=args.port2,
|
|
output_dir=args.output_dir,
|
|
label=label,
|
|
duration=duration,
|
|
)
|
|
all_sessions.append(meta)
|
|
print()
|
|
|
|
# Generate manifest
|
|
if all_sessions:
|
|
print(f"\nCollected {len(all_sessions)} session(s).")
|
|
manifest = generate_manifest(args.output_dir)
|
|
|
|
total_frames = sum(s.get("frame_count", 0) for s in all_sessions)
|
|
print(f"\nSummary:")
|
|
print(f" Sessions: {len(all_sessions)}")
|
|
print(f" Total frames: {total_frames}")
|
|
print(f" Output: {args.output_dir}/")
|
|
print(f" Manifest: {args.output_dir}/manifest.json")
|
|
else:
|
|
print("No sessions recorded.")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|