mirror of
https://github.com/ruvnet/RuView.git
synced 2026-04-28 05:59:32 +00:00
Add 4 scripts for camera-supervised WiFlow pose training: - collect-ground-truth.py: synchronized webcam + CSI capture via MediaPipe PoseLandmarker (17 COCO keypoints at 30fps) - align-ground-truth.js: time-align camera keypoints with CSI windows using binary search, confidence-weighted averaging - train-wiflow-supervised.js: 3-phase supervised training (contrastive pretrain → supervised keypoint regression → bone-constrained refinement) with curriculum learning and CSI augmentation - eval-wiflow.js: PCK@10/20/50, MPJPE, per-joint breakdown, baseline proxy mode for benchmarking Baseline benchmark (proxy poses, no camera supervision): PCK@10: 11.8% | PCK@20: 35.3% | PCK@50: 94.1% | MPJPE: 0.067 Camera pipeline validated over Tailscale to Mac Mini M4 Pro (1920x1080, 14/17 keypoints visible, MediaPipe confidence 0.94-1.0). Target after camera-supervised training: PCK@20 > 50% Closes #362 Co-Authored-By: claude-flow <ruv@ruv.net>
341 lines
11 KiB
Python
341 lines
11 KiB
Python
#!/usr/bin/env python3
|
|
"""Camera ground-truth collection for WiFi pose estimation training (ADR-079).
|
|
|
|
Captures webcam keypoints via MediaPipe PoseLandmarker (Tasks API) and
|
|
synchronizes with ESP32 CSI recording from the sensing server.
|
|
|
|
Output: JSONL file in data/ground-truth/ with per-frame 17-keypoint COCO poses.
|
|
|
|
Usage:
|
|
python scripts/collect-ground-truth.py --preview --duration 60
|
|
python scripts/collect-ground-truth.py --server http://192.168.1.10:3000
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import signal
|
|
import sys
|
|
import time
|
|
import urllib.request
|
|
import urllib.error
|
|
from pathlib import Path
|
|
from datetime import datetime
|
|
|
|
import cv2
|
|
import numpy as np
|
|
|
|
import mediapipe as mp
|
|
from mediapipe.tasks.python import BaseOptions
|
|
from mediapipe.tasks.python.vision import (
|
|
PoseLandmarker,
|
|
PoseLandmarkerOptions,
|
|
RunningMode,
|
|
)
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# MediaPipe 33 landmarks -> 17 COCO keypoints
|
|
# ---------------------------------------------------------------------------
|
|
# COCO idx : MP idx : joint name
|
|
# 0 : 0 : nose
|
|
# 1 : 2 : left_eye
|
|
# 2 : 5 : right_eye
|
|
# 3 : 7 : left_ear
|
|
# 4 : 8 : right_ear
|
|
# 5 : 11 : left_shoulder
|
|
# 6 : 12 : right_shoulder
|
|
# 7 : 13 : left_elbow
|
|
# 8 : 14 : right_elbow
|
|
# 9 : 15 : left_wrist
|
|
# 10 : 16 : right_wrist
|
|
# 11 : 23 : left_hip
|
|
# 12 : 24 : right_hip
|
|
# 13 : 25 : left_knee
|
|
# 14 : 26 : right_knee
|
|
# 15 : 27 : left_ankle
|
|
# 16 : 28 : right_ankle
|
|
|
|
MP_TO_COCO = [0, 2, 5, 7, 8, 11, 12, 13, 14, 15, 16, 23, 24, 25, 26, 27, 28]
|
|
|
|
COCO_BONES = [
|
|
(5, 7), (7, 9), (6, 8), (8, 10), # arms
|
|
(5, 6), # shoulders
|
|
(11, 13), (13, 15), (12, 14), (14, 16), # legs
|
|
(11, 12), # hips
|
|
(5, 11), (6, 12), # torso
|
|
(0, 1), (0, 2), (1, 3), (2, 4), # face
|
|
]
|
|
|
|
MODEL_URL = (
|
|
"https://storage.googleapis.com/mediapipe-models/"
|
|
"pose_landmarker/pose_landmarker_lite/float16/latest/"
|
|
"pose_landmarker_lite.task"
|
|
)
|
|
MODEL_FILENAME = "pose_landmarker_lite.task"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def ensure_model(cache_dir: Path) -> Path:
|
|
"""Download the PoseLandmarker model if not already cached."""
|
|
model_path = cache_dir / MODEL_FILENAME
|
|
if model_path.exists():
|
|
return model_path
|
|
|
|
cache_dir.mkdir(parents=True, exist_ok=True)
|
|
print(f"Downloading {MODEL_FILENAME} ...")
|
|
try:
|
|
urllib.request.urlretrieve(MODEL_URL, str(model_path))
|
|
print(f" saved to {model_path}")
|
|
except Exception as exc:
|
|
print(f"ERROR: Failed to download model: {exc}", file=sys.stderr)
|
|
print(
|
|
"Download manually from:\n"
|
|
f" {MODEL_URL}\n"
|
|
f"and place at {model_path}",
|
|
file=sys.stderr,
|
|
)
|
|
sys.exit(1)
|
|
return model_path
|
|
|
|
|
|
def post_json(url: str, payload: dict | None = None, timeout: float = 5.0) -> bool:
|
|
"""POST JSON to a URL. Returns True on success, False on failure."""
|
|
data = json.dumps(payload or {}).encode("utf-8")
|
|
req = urllib.request.Request(
|
|
url,
|
|
data=data,
|
|
headers={"Content-Type": "application/json"},
|
|
method="POST",
|
|
)
|
|
try:
|
|
with urllib.request.urlopen(req, timeout=timeout) as resp:
|
|
return 200 <= resp.status < 300
|
|
except Exception as exc:
|
|
print(f"WARNING: POST {url} failed: {exc}", file=sys.stderr)
|
|
return False
|
|
|
|
|
|
def draw_skeleton(frame: np.ndarray, keypoints: list[list[float]], w: int, h: int):
|
|
"""Draw COCO skeleton overlay on a BGR frame."""
|
|
pts = []
|
|
for x, y in keypoints:
|
|
px, py = int(x * w), int(y * h)
|
|
pts.append((px, py))
|
|
cv2.circle(frame, (px, py), 4, (0, 255, 0), -1)
|
|
|
|
for i, j in COCO_BONES:
|
|
if i < len(pts) and j < len(pts):
|
|
cv2.line(frame, pts[i], pts[j], (0, 200, 255), 2)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Main collection loop
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser(
|
|
description="Collect camera ground-truth keypoints for WiFi pose training (ADR-079)."
|
|
)
|
|
parser.add_argument(
|
|
"--server",
|
|
default="http://localhost:3000",
|
|
help="Sensing server URL (default: http://localhost:3000)",
|
|
)
|
|
parser.add_argument(
|
|
"--preview",
|
|
action="store_true",
|
|
help="Show live skeleton overlay window",
|
|
)
|
|
parser.add_argument(
|
|
"--duration",
|
|
type=int,
|
|
default=300,
|
|
help="Recording duration in seconds (default: 300)",
|
|
)
|
|
parser.add_argument(
|
|
"--camera",
|
|
type=int,
|
|
default=0,
|
|
help="Camera device index (default: 0)",
|
|
)
|
|
parser.add_argument(
|
|
"--output",
|
|
default="data/ground-truth",
|
|
help="Output directory (default: data/ground-truth)",
|
|
)
|
|
args = parser.parse_args()
|
|
|
|
# --- Resolve paths relative to repo root ---
|
|
repo_root = Path(__file__).resolve().parent.parent
|
|
output_dir = repo_root / args.output
|
|
output_dir.mkdir(parents=True, exist_ok=True)
|
|
cache_dir = repo_root / "data" / ".cache"
|
|
|
|
# --- Download / locate model ---
|
|
model_path = ensure_model(cache_dir)
|
|
|
|
# --- Open camera ---
|
|
cap = cv2.VideoCapture(args.camera)
|
|
if not cap.isOpened():
|
|
print(
|
|
f"ERROR: Cannot open camera index {args.camera}. "
|
|
"Check that a webcam is connected and not in use by another app.",
|
|
file=sys.stderr,
|
|
)
|
|
sys.exit(1)
|
|
|
|
frame_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
|
frame_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
|
print(f"Camera opened: {frame_w}x{frame_h}")
|
|
|
|
# --- Create PoseLandmarker ---
|
|
options = PoseLandmarkerOptions(
|
|
base_options=BaseOptions(model_asset_path=str(model_path)),
|
|
running_mode=RunningMode.IMAGE,
|
|
num_poses=1,
|
|
min_pose_detection_confidence=0.5,
|
|
min_pose_presence_confidence=0.5,
|
|
min_tracking_confidence=0.5,
|
|
)
|
|
landmarker = PoseLandmarker.create_from_options(options)
|
|
|
|
# --- Output file ---
|
|
timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
out_path = output_dir / f"keypoints_{timestamp_str}.jsonl"
|
|
out_file = open(out_path, "w", encoding="utf-8")
|
|
print(f"Output: {out_path}")
|
|
|
|
# --- Start CSI recording ---
|
|
recording_url_start = f"{args.server}/api/v1/recording/start"
|
|
recording_url_stop = f"{args.server}/api/v1/recording/stop"
|
|
csi_started = post_json(recording_url_start)
|
|
if csi_started:
|
|
print("CSI recording started on sensing server.")
|
|
else:
|
|
print(
|
|
"WARNING: Could not start CSI recording. "
|
|
"Camera keypoints will still be captured.",
|
|
file=sys.stderr,
|
|
)
|
|
|
|
# --- Graceful shutdown ---
|
|
shutdown_requested = False
|
|
|
|
def _handle_signal(signum, frame):
|
|
nonlocal shutdown_requested
|
|
shutdown_requested = True
|
|
|
|
signal.signal(signal.SIGINT, _handle_signal)
|
|
signal.signal(signal.SIGTERM, _handle_signal)
|
|
|
|
# --- Collection loop ---
|
|
start_time = time.monotonic()
|
|
frame_count = 0
|
|
total_confidence = 0.0
|
|
total_visible = 0
|
|
|
|
print(f"Collecting for {args.duration}s ... (press 'q' in preview to stop)")
|
|
|
|
try:
|
|
while not shutdown_requested:
|
|
elapsed = time.monotonic() - start_time
|
|
if elapsed >= args.duration:
|
|
break
|
|
|
|
ret, frame = cap.read()
|
|
if not ret:
|
|
print("WARNING: Failed to read frame, retrying ...", file=sys.stderr)
|
|
time.sleep(0.01)
|
|
continue
|
|
|
|
ts_ns = time.time_ns()
|
|
|
|
# Convert BGR -> RGB for MediaPipe
|
|
rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
|
mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb)
|
|
|
|
result = landmarker.detect(mp_image)
|
|
|
|
n_persons = len(result.pose_landmarks)
|
|
|
|
if n_persons > 0:
|
|
landmarks = result.pose_landmarks[0]
|
|
keypoints = []
|
|
visibilities = []
|
|
for coco_idx in range(17):
|
|
mp_idx = MP_TO_COCO[coco_idx]
|
|
lm = landmarks[mp_idx]
|
|
keypoints.append([round(lm.x, 5), round(lm.y, 5)])
|
|
visibilities.append(lm.visibility if lm.visibility else 0.0)
|
|
|
|
confidence = float(np.mean(visibilities))
|
|
n_visible = int(sum(1 for v in visibilities if v > 0.5))
|
|
else:
|
|
keypoints = []
|
|
confidence = 0.0
|
|
n_visible = 0
|
|
|
|
record = {
|
|
"ts_ns": ts_ns,
|
|
"keypoints": keypoints,
|
|
"confidence": round(confidence, 4),
|
|
"n_visible": n_visible,
|
|
"n_persons": n_persons,
|
|
}
|
|
out_file.write(json.dumps(record) + "\n")
|
|
frame_count += 1
|
|
total_confidence += confidence
|
|
total_visible += n_visible
|
|
|
|
# Preview overlay
|
|
if args.preview and keypoints:
|
|
draw_skeleton(frame, keypoints, frame_w, frame_h)
|
|
|
|
if args.preview:
|
|
remaining = max(0, int(args.duration - elapsed))
|
|
cv2.putText(
|
|
frame,
|
|
f"Frames: {frame_count} Visible: {n_visible}/17 Time: {remaining}s",
|
|
(10, 30),
|
|
cv2.FONT_HERSHEY_SIMPLEX,
|
|
0.7,
|
|
(255, 255, 255),
|
|
2,
|
|
)
|
|
cv2.imshow("Ground Truth Collection (ADR-079)", frame)
|
|
if cv2.waitKey(1) & 0xFF == ord("q"):
|
|
break
|
|
|
|
finally:
|
|
# --- Cleanup ---
|
|
out_file.close()
|
|
cap.release()
|
|
if args.preview:
|
|
cv2.destroyAllWindows()
|
|
landmarker.close()
|
|
|
|
# Stop CSI recording
|
|
if csi_started:
|
|
if post_json(recording_url_stop):
|
|
print("CSI recording stopped.")
|
|
else:
|
|
print("WARNING: Failed to stop CSI recording.", file=sys.stderr)
|
|
|
|
# --- Summary ---
|
|
avg_conf = total_confidence / frame_count if frame_count > 0 else 0.0
|
|
avg_vis = total_visible / frame_count if frame_count > 0 else 0.0
|
|
print()
|
|
print("=== Collection Summary ===")
|
|
print(f" Total frames: {frame_count}")
|
|
print(f" Avg confidence: {avg_conf:.3f}")
|
|
print(f" Avg visible joints: {avg_vis:.1f} / 17")
|
|
print(f" Output: {out_path}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|