Merge pull request #363 from ruvnet/feat/adr-079-camera-ground-truth

feat: camera ground-truth training pipeline with ruvector optimizations (ADR-079)
This commit is contained in:
rUv 2026-04-06 17:29:13 -04:00 committed by GitHub
commit 8dddbf941a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
37 changed files with 6359 additions and 28 deletions

View file

@ -0,0 +1,512 @@
# ADR-079: Camera Ground-Truth Training Pipeline
- **Status**: Accepted
- **Date**: 2026-04-06
- **Deciders**: ruv
- **Relates to**: ADR-072 (WiFlow Architecture), ADR-070 (Self-Supervised Pretraining), ADR-071 (ruvllm Training Pipeline), ADR-024 (AETHER Contrastive), ADR-064 (Multimodal Ambient Intelligence), ADR-075 (MinCut Person Separation)
## Context
WiFlow (ADR-072) currently trains without ground-truth pose labels, using proxy poses
generated from presence/motion heuristics. This produces a PCK@20 of only 2.5% — far
below the 30-50% achievable with supervised training. The fundamental bottleneck is the
absence of spatial keypoint labels.
Academic WiFi pose estimation systems (Wi-Pose, Person-in-WiFi 3D, MetaFi++) all train
with synchronized camera ground truth and achieve PCK@20 of 40-85%. They discard the
camera at deployment — the camera is a training-time teacher, not a runtime dependency.
ADR-064 already identified this: *"Record CSI + mmWave while performing signs with a
camera as ground truth, then deploy camera-free."* This ADR specifies the implementation.
### Current Training Pipeline Gap
```
Current: CSI amplitude → WiFlow → 17 keypoints (proxy-supervised, PCK@20 = 2.5%)
Heuristic proxies:
- Standing skeleton when presence > 0.3
- Limb perturbation from motion energy
- No spatial accuracy
```
### Target Pipeline
```
Training: CSI amplitude ──→ WiFlow ──→ 17 keypoints (camera-supervised, PCK@20 target: 35%+)
Laptop camera ──→ MediaPipe ──→ 17 COCO keypoints (ground truth)
(time-synchronized, 30 fps)
Deploy: CSI amplitude ──→ WiFlow ──→ 17 keypoints (camera-free, trained model only)
```
## Decision
Build a camera ground-truth collection and training pipeline using the laptop webcam
as a teacher signal. The camera is used **only during training data collection** and is
not required at deployment.
### Architecture Overview
```
┌─────────────────────────────────────────────────────────────────┐
│ Data Collection Phase │
│ │
│ ESP32-S3 nodes ──UDP──→ Sensing Server ──→ CSI frames (.jsonl) │
│ ↑ time sync │
│ Laptop Camera ──→ MediaPipe Pose ──→ Keypoints (.jsonl) │
│ ↑ │
│ collect-ground-truth.py │
│ (single orchestrator) │
└─────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────┐
│ Training Phase │
│ │
│ Paired dataset: { csi_window[128,20], keypoints[17,2], conf } │
│ ↓ │
│ train-wiflow-supervised.js │
│ Phase 1: Contrastive pretrain (ADR-072, reuse) │
│ Phase 2: Supervised keypoint regression (NEW) │
│ Phase 3: Fine-tune with bone constraints + confidence │
│ ↓ │
│ WiFlow model (1.8M params) → SafeTensors export │
└─────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────┐
│ Deployment (camera-free) │
│ │
│ ESP32-S3 CSI → Sensing Server → WiFlow inference → 17 keypoints│
│ (No camera. Trained model runs on CSI input only.) │
└─────────────────────────────────────────────────────────────────┘
```
### Component 1: `scripts/collect-ground-truth.py`
Single Python script that orchestrates synchronized capture from the laptop camera
and the ESP32 CSI stream.
**Dependencies:** `mediapipe`, `opencv-python`, `requests` (all pip-installable, no GPU)
**Capture flow:**
```python
# Pseudocode
camera = cv2.VideoCapture(0) # Laptop webcam
sensing_api = "http://localhost:3000" # Sensing server
# Start CSI recording via existing API
requests.post(f"{sensing_api}/api/v1/recording/start")
while recording:
frame = camera.read()
t = time.time_ns() # Nanosecond timestamp
# MediaPipe Pose: 33 landmarks → map to 17 COCO keypoints
result = mp_pose.process(frame)
keypoints_17 = map_mediapipe_to_coco(result.pose_landmarks)
confidence = mean(landmark.visibility for relevant landmarks)
# Write to ground-truth JSONL (one line per frame)
write_jsonl({
"ts_ns": t,
"keypoints": keypoints_17, # [[x,y], ...] normalized [0,1]
"confidence": confidence, # 0-1, used for loss weighting
"n_visible": count(visibility > 0.5),
})
# Optional: show live preview with skeleton overlay
if preview:
draw_skeleton(frame, keypoints_17)
cv2.imshow("Ground Truth", frame)
# Stop CSI recording
requests.post(f"{sensing_api}/api/v1/recording/stop")
```
**MediaPipe → COCO keypoint mapping:**
| COCO Index | Joint | MediaPipe Index |
|------------|-------|-----------------|
| 0 | Nose | 0 |
| 1 | Left Eye | 2 |
| 2 | Right Eye | 5 |
| 3 | Left Ear | 7 |
| 4 | Right Ear | 8 |
| 5 | Left Shoulder | 11 |
| 6 | Right Shoulder | 12 |
| 7 | Left Elbow | 13 |
| 8 | Right Elbow | 14 |
| 9 | Left Wrist | 15 |
| 10 | Right Wrist | 16 |
| 11 | Left Hip | 23 |
| 12 | Right Hip | 24 |
| 13 | Left Knee | 25 |
| 14 | Right Knee | 26 |
| 15 | Left Ankle | 27 |
| 16 | Right Ankle | 28 |
### Component 2: Time Alignment (`scripts/align-ground-truth.js`)
CSI frames arrive at ~100 Hz with server-side timestamps. Camera keypoints arrive at
~30 fps with client-side timestamps. Alignment is needed because:
1. Camera and sensing server clocks differ (typically < 50ms on LAN)
2. CSI is aggregated into 20-frame windows for WiFlow input
3. Ground-truth keypoints must be averaged over the same window
**Alignment algorithm:**
```
For each CSI window W_i (20 frames, ~200ms at 100Hz):
t_start = W_i.first_frame.timestamp
t_end = W_i.last_frame.timestamp
# Find all camera keypoints within this time window
matching_keypoints = [k for k in camera_data if t_start <= k.ts <= t_end]
if len(matching_keypoints) >= 3: # At least 3 camera frames per window
# Average keypoints, weighted by confidence
avg_keypoints = weighted_mean(matching_keypoints, weights=confidences)
avg_confidence = mean(confidences)
paired_dataset.append({
csi_window: W_i.amplitudes, # [128, 20] float32
keypoints: avg_keypoints, # [17, 2] float32
confidence: avg_confidence, # scalar
n_camera_frames: len(matching_keypoints),
})
```
**Clock sync strategy:**
- NTP is sufficient (< 20ms error on LAN)
- The 200ms CSI window is 10x larger than typical clock drift
- For tighter sync: use a handclap/jump as a sync marker — visible spike in both
CSI motion energy and camera skeleton velocity. Auto-detect and align.
**Output:** `data/recordings/paired-{timestamp}.jsonl` — one line per paired sample:
```json
{"csi": [128x20 flat], "kp": [[0.45,0.12], ...], "conf": 0.92, "ts": 1775300000000}
```
### Component 3: Supervised Training (`scripts/train-wiflow-supervised.js`)
Extends the existing `train-ruvllm.js` pipeline with a supervised phase.
**Phase 1: Contrastive Pretrain (reuse ADR-072)**
- Same as existing: temporal + cross-node triplets
- Learns CSI representation without labels
- 50 epochs, ~5 min on laptop
**Phase 2: Supervised Keypoint Regression (NEW)**
- Load paired dataset from Component 2
- Loss: confidence-weighted SmoothL1 on keypoints
```
L_supervised = (1/N) * sum_i [ conf_i * SmoothL1(pred_i, gt_i, beta=0.05) ]
```
- Only train on samples where `conf > 0.5` (discard frames where MediaPipe lost tracking)
- Learning rate: 1e-4 with cosine decay
- 200 epochs, ~15 min on laptop CPU (1.8M params, no GPU needed)
**Phase 3: Refinement with Bone Constraints**
- Fine-tune with combined loss:
```
L = L_supervised + 0.3 * L_bone + 0.1 * L_temporal
L_bone = (1/14) * sum_b (bone_len_b - prior_b)^2 # ADR-072 bone priors
L_temporal = SmoothL1(kp_t, kp_{t-1}) # Temporal smoothness
```
- 50 epochs at lower LR (1e-5)
- Tighten bone constraint weight from 0.3 → 0.5 over epochs
**Phase 4: Quantization + Export**
- Reuse ruvllm TurboQuant: float32 → int8 (4x smaller, ~881 KB)
- Export via SafeTensors for cross-platform deployment
- Validate quantized model PCK@20 within 2% of full-precision
### Component 4: Evaluation Script (`scripts/eval-wiflow.js`)
Measure actual PCK@20 using held-out paired data (20% split).
```
PCK@k = (1/N) * sum_i [ (||pred_i - gt_i|| < k * torso_length) ? 1 : 0 ]
```
**Metrics reported:**
| Metric | Description | Target |
|--------|-------------|--------|
| PCK@20 | % of keypoints within 20% torso length | > 35% |
| PCK@50 | % within 50% torso length | > 60% |
| MPJPE | Mean per-joint position error (pixels) | < 40px |
| Per-joint PCK | Breakdown by joint (wrists are hardest) | Report all 17 |
| Inference latency | Single window prediction time | < 50ms |
### Optimization Strategy
#### O1: Curriculum Learning
Train easy poses first, hard poses later:
| Stage | Epochs | Data Filter | Rationale |
|-------|--------|-------------|-----------|
| 1 | 50 | `conf > 0.9`, standing only | Establish stable skeleton baseline |
| 2 | 50 | `conf > 0.7`, low motion | Add sitting, subtle movements |
| 3 | 50 | `conf > 0.5`, all poses | Full dataset including occlusions |
| 4 | 50 | All data, with augmentation | Robustness via noise injection |
#### O2: Data Augmentation (CSI domain)
Augment CSI windows to increase effective dataset size without collecting more data:
| Augmentation | Implementation | Expected Gain |
|-------------|----------------|---------------|
| Time shift | Roll CSI window by ±2 frames | +30% data |
| Amplitude noise | Gaussian noise, sigma=0.02 | Robustness |
| Subcarrier dropout | Zero 10% of subcarriers randomly | Robustness |
| Temporal flip | Reverse window + reverse keypoint velocity | +100% data |
| Multi-node mix | Swap node CSI, keep same-time keypoints | Cross-node generalization |
#### O3: Knowledge Distillation from MediaPipe
Instead of raw keypoint regression, distill MediaPipe's confidence and heatmap
information:
```
L_distill = KL_div(softmax(wifi_heatmap / T), softmax(camera_heatmap / T))
```
- Temperature T=4 for soft targets (transfers inter-joint relationships)
- WiFlow predicts a 17-channel heatmap [17, H, W] instead of direct [17, 2]
- Argmax for final keypoint extraction
- **Trade-off:** Adds ~200K params for heatmap decoder, but improves spatial precision
#### O4: Active Learning Loop
Identify which poses the model is worst at and collect more data for those:
```
1. Train initial model on first collection session
2. Run inference on new CSI data, compute prediction entropy
3. Flag high-entropy windows (model is uncertain)
4. During next collection, the preview overlay highlights these moments:
"Hold this pose — model needs more examples"
5. Re-train with augmented dataset
```
Expected: 2-3 active learning iterations reach saturation.
#### O6: Subcarrier Selection (ruvector-solver)
Variance-based top-K subcarrier selection, equivalent to ruvector-solver's sparse
interpolation (114→56). Removes noise/static subcarriers before training:
```
For each subcarrier d in [0, dim):
variance[d] = mean over samples of temporal_variance(csi[d, :])
Select top-K by variance (K = dim * 0.5)
```
**Validated:** 128 → 56 subcarriers (56% input reduction), proportional model size reduction.
#### O7: Attention-Weighted Subcarriers (ruvector-attention)
Compute per-subcarrier attention weights based on temporal energy correlation with
ground-truth keypoint motion. High-energy subcarriers that covary with skeleton
movement get amplified:
```
For each subcarrier d:
energy[d] = sum of squared first-differences over time
weight[d] = softmax(energy, temperature=0.1)
Apply: csi[d, :] *= weight[d] * dim (mean weight = 1)
```
**Validated:** Top-5 attention subcarriers identified automatically per dataset.
#### O8: Stoer-Wagner MinCut Person Separation (ruvector-mincut / ADR-075)
JS implementation of the Stoer-Wagner algorithm for person separation in CSI, equivalent
to `DynamicPersonMatcher` in `wifi-densepose-train/src/metrics.rs`. Builds a subcarrier
correlation graph and finds the minimum cut to identify person-specific subcarrier clusters:
```
1. Build dim×dim Pearson correlation matrix across subcarriers
2. Run Stoer-Wagner min-cut on correlation graph
3. Partition subcarriers into person-specific groups
4. Train per-partition models for multi-person scenarios
```
**Validated:** Stoer-Wagner executes on 56-dim graph, identifies partition boundaries.
#### O9: Multi-SPSA Gradient Estimation
Average over K=3 random perturbation directions per gradient step. Reduces variance
by sqrt(K) = 1.73x compared to single SPSA, at 3x forward pass cost (net win for
convergence quality):
```
For k in 1..K:
delta_k = random ±1 per parameter
grad_k = (loss(w + eps*delta_k) - loss(w - eps*delta_k)) / (2*eps*delta_k)
grad = mean(grad_1, ..., grad_K)
```
#### O10: Mac M4 Pro Training via Tailscale
Training runs on Mac Mini M4 Pro (16-core GPU, ARM NEON SIMD) via Tailscale SSH,
using ruvllm's native Node.js SIMD ops:
| | Windows (CPU) | Mac M4 Pro |
|---|---|---|
| Node.js | v24.12.0 (x86) | v25.9.0 (ARM) |
| SIMD | SSE4/AVX2 | NEON |
| Cores | Consumer laptop | 12P + 4E cores |
| Training | Slow (minutes/epoch) | Fast (seconds/epoch) |
#### O5: Cross-Environment Transfer
Train on one room, deploy in another:
| Strategy | Implementation |
|----------|---------------|
| Room-invariant features | Normalize CSI by running mean/variance |
| LoRA adapters | Train a 4-rank LoRA per room (ADR-071) — 7.3 KB each |
| Few-shot calibration | 2 min of camera data in new room → fine-tune LoRA only |
| AETHER embeddings | Use contrastive room-independent features (ADR-024) as input |
The LoRA approach is most practical: ship a base model + collect 2 min of calibration
data per new room using the laptop camera.
### Data Collection Protocol
Recommended collection sessions per room:
| Session | Duration | Activity | People | Total CSI Frames |
|---------|----------|----------|--------|-----------------|
| 1. Baseline | 5 min | Empty + 1 person entry/exit | 0-1 | 30,000 |
| 2. Standing poses | 5 min | Stand, arms up/down/sides, turn | 1 | 30,000 |
| 3. Sitting | 5 min | Sit, type, lean, stand up/sit down | 1 | 30,000 |
| 4. Walking | 5 min | Walk paths across room | 1 | 30,000 |
| 5. Mixed | 5 min | Varied activities, transitions | 1 | 30,000 |
| 6. Multi-person | 5 min | 2 people, varied activities | 2 | 30,000 |
| **Total** | **30 min** | | | **180,000** |
At 20-frame windows: **9,000 paired training samples** per 30-min session.
With augmentation (O2): **~27,000 effective samples**.
Camera placement: position laptop so the camera has a clear view of the sensing area.
The camera FOV should cover the same space the ESP32 nodes cover.
### File Structure
```
scripts/
collect-ground-truth.py # Camera capture + MediaPipe + CSI sync
align-ground-truth.js # Time-align CSI windows with camera keypoints
train-wiflow-supervised.js # Supervised training pipeline
eval-wiflow.js # PCK evaluation on held-out data
data/
ground-truth/ # Raw camera keypoint captures
gt-{timestamp}.jsonl
paired/ # Aligned CSI + keypoint pairs
paired-{timestamp}.jsonl
models/
wiflow-supervised/ # Trained model outputs
wiflow-v1.safetensors
wiflow-v1-int8.safetensors
training-log.json
eval-report.json
```
### Privacy Considerations
- Camera frames are processed **locally** by MediaPipe — no cloud upload
- Raw video is **never saved** — only extracted keypoint coordinates are stored
- The `.jsonl` ground-truth files contain only `[x,y]` joint coordinates, not images
- The trained model runs on CSI only — no camera data leaves the laptop
- Users can delete `data/ground-truth/` after training; the model is self-contained
## Consequences
### Positive
- **10-20x accuracy improvement**: PCK@20 from 2.5% → 35%+ with real supervision
- **Reuses existing infrastructure**: sensing server recording API, ruvllm training, SafeTensors
- **No new hardware**: laptop webcam + existing ESP32 nodes
- **Privacy preserved at deployment**: camera only needed during 30-min training session
- **Incremental**: can improve with more collection sessions + active learning
- **Distributable**: trained model weights can be shared on HuggingFace (ADR-070)
### Negative
- **Camera placement matters**: must see the same area ESP32 nodes sense
- **Single-room models**: need LoRA calibration per room (2 min + camera)
- **MediaPipe limitations**: occlusion, side views, multiple people reduce keypoint quality
- **Time sync**: NTP drift can misalign frames (mitigated by 200ms windows)
### Risks
| Risk | Probability | Impact | Mitigation |
|------|-------------|--------|------------|
| MediaPipe keypoints too noisy | Low | Medium | Filter by confidence; MediaPipe is robust indoors |
| Clock drift > 100ms | Low | High | Add handclap sync marker detection |
| Single camera can't see all poses | Medium | Medium | Position camera centrally; collect from 2 angles |
| Model overfits to one room | High | Medium | LoRA adapters + AETHER normalization (O5) |
| Insufficient data (< 5K pairs) | Low | High | Augmentation (O2) + active learning (O4) |
## Implementation Plan
| Phase | Task | Effort | Status |
|-------|------|--------|--------|
| P1 | `collect-ground-truth.py` — camera + MediaPipe capture | 2 hrs | **Done** |
| P2 | `align-ground-truth.js` — time alignment + pairing | 1 hr | **Done** |
| P3 | `train-wiflow-supervised.js` — supervised training | 3 hrs | **Done** |
| P4 | `eval-wiflow.js` — PCK evaluation | 1 hr | **Done** |
| P5 | ruvector optimizations (O6-O9) | 2 hrs | **Done** |
| P6 | Mac M4 Pro training via Tailscale (O10) | 1 hr | **Done** |
| P7 | Data collection session (30 min recording) | 1 hr | Pending |
| P8 | Training + evaluation on real paired data | 30 min | Pending |
| P9 | LoRA cross-room calibration (O5) | 2 hrs | Pending |
## Validated Hardware
| Component | Spec | Validated |
|-----------|------|-----------|
| Mac Mini camera | 1920x1080, 30fps | Yes — 14/17 keypoints, conf 0.94-1.0 |
| MediaPipe PoseLandmarker | v0.10.33 Tasks API, lite model | Yes — via Tailscale SSH |
| Mac M4 Pro GPU | 16-core, Metal 4, NEON SIMD | Yes — Node.js v25.9.0 |
| Tailscale SSH | LAN-accessible Mac, passwordless | Yes |
| ESP32-S3 CSI | 128 subcarriers, 100Hz | Yes — existing recordings |
| Sensing server recording API | `/api/v1/recording/start\|stop` | Yes — existing |
## Baseline Benchmark
Proxy-pose baseline (no camera supervision, standing skeleton heuristic):
```
PCK@10: 11.8%
PCK@20: 35.3%
PCK@50: 94.1%
MPJPE: 0.067
Latency: 0.03ms/sample
```
Per-joint PCK@20: upper body (nose, shoulders, wrists) at 0% — proxy has no spatial
accuracy for these. Camera supervision targets these joints specifically.
## References
- WiFlow: arXiv:2602.08661 — WiFi-based pose estimation with TCN + axial attention
- Wi-Pose (CVPR 2021) — 3D CNN WiFi pose with camera supervision
- Person-in-WiFi 3D (CVPR 2024) — Deformable attention with camera labels
- MediaPipe Pose — Google's real-time 33-landmark body pose estimator
- MetaFi++ (NeurIPS 2023) — Meta-learning cross-modal WiFi sensing

View file

@ -330,9 +330,36 @@ impl<B: Backend> InferenceEngine<B> {
Ok(result)
}
/// Run batched inference
/// Run batched inference.
///
/// Stacks all inputs along a new batch dimension, runs a single
/// backend call, then splits the output back into individual tensors.
/// Falls back to sequential inference if stack/split fails.
pub fn infer_batch(&self, inputs: &[Tensor]) -> NnResult<Vec<Tensor>> {
inputs.iter().map(|input| self.infer(input)).collect()
if inputs.is_empty() {
return Ok(Vec::new());
}
if inputs.len() == 1 {
return Ok(vec![self.infer(&inputs[0])?]);
}
// Try batched path: stack -> single call -> split
match Tensor::stack(inputs) {
Ok(batched_input) => {
let n = inputs.len();
let batched_output = self.backend.run_single(&batched_input)?;
match batched_output.split(n) {
Ok(outputs) => Ok(outputs),
Err(_) => {
// Fallback: sequential
inputs.iter().map(|input| self.infer(input)).collect()
}
}
}
Err(_) => {
// Fallback: sequential if shapes are incompatible
inputs.iter().map(|input| self.infer(input)).collect()
}
}
}
/// Get inference statistics

View file

@ -304,6 +304,74 @@ impl Tensor {
}
}
/// Stack multiple tensors along a new batch dimension (dim 0).
///
/// All tensors must have the same shape. The result has one extra
/// leading dimension equal to `tensors.len()`.
pub fn stack(tensors: &[Tensor]) -> NnResult<Tensor> {
if tensors.is_empty() {
return Err(NnError::tensor_op("Cannot stack zero tensors"));
}
let first_shape = tensors[0].shape();
for (i, t) in tensors.iter().enumerate().skip(1) {
if t.shape() != first_shape {
return Err(NnError::tensor_op(&format!(
"Shape mismatch at index {i}: expected {first_shape}, got {}",
t.shape()
)));
}
}
let mut all_data: Vec<f32> = Vec::with_capacity(tensors.len() * first_shape.numel());
for t in tensors {
let data = t.to_vec()?;
all_data.extend_from_slice(&data);
}
let mut new_dims = vec![tensors.len()];
new_dims.extend_from_slice(first_shape.dims());
let arr = ndarray::ArrayD::from_shape_vec(
ndarray::IxDyn(&new_dims),
all_data,
)
.map_err(|e| NnError::tensor_op(&format!("Stack reshape failed: {e}")))?;
Ok(Tensor::FloatND(arr))
}
/// Split a tensor along dim 0 into `n` sub-tensors.
///
/// The first dimension must be evenly divisible by `n`.
pub fn split(self, n: usize) -> NnResult<Vec<Tensor>> {
if n == 0 {
return Err(NnError::tensor_op("Cannot split into 0 pieces"));
}
let shape = self.shape();
let batch = shape.dim(0).ok_or_else(|| NnError::tensor_op("Tensor has no dimensions"))?;
if batch % n != 0 {
return Err(NnError::tensor_op(&format!(
"Batch dim {batch} not divisible by {n}"
)));
}
let chunk_size = batch / n;
let data = self.to_vec()?;
let elem_per_sample = shape.numel() / batch;
let sub_dims: Vec<usize> = {
let mut d = shape.dims().to_vec();
d[0] = chunk_size;
d
};
let mut result = Vec::with_capacity(n);
for i in 0..n {
let start = i * chunk_size * elem_per_sample;
let end = start + chunk_size * elem_per_sample;
let arr = ndarray::ArrayD::from_shape_vec(
ndarray::IxDyn(&sub_dims),
data[start..end].to_vec(),
)
.map_err(|e| NnError::tensor_op(&format!("Split reshape failed: {e}")))?;
result.push(Tensor::FloatND(arr));
}
Ok(result)
}
/// Compute standard deviation
pub fn std(&self) -> NnResult<f32> {
match self {

View file

@ -0,0 +1,105 @@
//! CLI argument definitions and early-exit mode handlers.
use std::path::PathBuf;
use clap::Parser;
/// CLI arguments for the sensing server.
#[derive(Parser, Debug)]
#[command(name = "sensing-server", about = "WiFi-DensePose sensing server")]
pub struct Args {
/// HTTP port for UI and REST API
#[arg(long, default_value = "8080")]
pub http_port: u16,
/// WebSocket port for sensing stream
#[arg(long, default_value = "8765")]
pub ws_port: u16,
/// UDP port for ESP32 CSI frames
#[arg(long, default_value = "5005")]
pub udp_port: u16,
/// Path to UI static files
#[arg(long, default_value = "../../ui")]
pub ui_path: PathBuf,
/// Tick interval in milliseconds (default 100 ms = 10 fps for smooth pose animation)
#[arg(long, default_value = "100")]
pub tick_ms: u64,
/// Bind address (default 127.0.0.1; set to 0.0.0.0 for network access)
#[arg(long, default_value = "127.0.0.1", env = "SENSING_BIND_ADDR")]
pub bind_addr: String,
/// Data source: auto, wifi, esp32, simulate
#[arg(long, default_value = "auto")]
pub source: String,
/// Run vital sign detection benchmark (1000 frames) and exit
#[arg(long)]
pub benchmark: bool,
/// Load model config from an RVF container at startup
#[arg(long, value_name = "PATH")]
pub load_rvf: Option<PathBuf>,
/// Save current model state as an RVF container on shutdown
#[arg(long, value_name = "PATH")]
pub save_rvf: Option<PathBuf>,
/// Load a trained .rvf model for inference
#[arg(long, value_name = "PATH")]
pub model: Option<PathBuf>,
/// Enable progressive loading (Layer A instant start)
#[arg(long)]
pub progressive: bool,
/// Export an RVF container package and exit (no server)
#[arg(long, value_name = "PATH")]
pub export_rvf: Option<PathBuf>,
/// Run training mode (train a model and exit)
#[arg(long)]
pub train: bool,
/// Path to dataset directory (MM-Fi or Wi-Pose)
#[arg(long, value_name = "PATH")]
pub dataset: Option<PathBuf>,
/// Dataset type: "mmfi" or "wipose"
#[arg(long, value_name = "TYPE", default_value = "mmfi")]
pub dataset_type: String,
/// Number of training epochs
#[arg(long, default_value = "100")]
pub epochs: usize,
/// Directory for training checkpoints
#[arg(long, value_name = "DIR")]
pub checkpoint_dir: Option<PathBuf>,
/// Run self-supervised contrastive pretraining (ADR-024)
#[arg(long)]
pub pretrain: bool,
/// Number of pretraining epochs (default 50)
#[arg(long, default_value = "50")]
pub pretrain_epochs: usize,
/// Extract embeddings mode: load model and extract CSI embeddings
#[arg(long)]
pub embed: bool,
/// Build fingerprint index from embeddings (env|activity|temporal|person)
#[arg(long, value_name = "TYPE")]
pub build_index: Option<String>,
/// Node positions for multistatic fusion (format: "x,y,z;x,y,z;...")
#[arg(long, env = "SENSING_NODE_POSITIONS")]
pub node_positions: Option<String>,
/// Start field model calibration on boot (empty room required)
#[arg(long)]
pub calibrate: bool,
}

View file

@ -0,0 +1,675 @@
//! CSI frame parsing, signal field generation, feature extraction,
//! classification, vital signs smoothing, and multi-person estimation.
use std::collections::{HashMap, VecDeque};
use ruvector_mincut::{DynamicMinCut, MinCutBuilder};
use crate::adaptive_classifier;
use crate::types::*;
use crate::vital_signs::VitalSigns;
// ── ESP32 UDP frame parsers ─────────────────────────────────────────────────
/// Parse a 32-byte edge vitals packet (magic 0xC511_0002).
pub fn parse_esp32_vitals(buf: &[u8]) -> Option<Esp32VitalsPacket> {
if buf.len() < 32 { return None; }
let magic = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]);
if magic != 0xC511_0002 { return None; }
let node_id = buf[4];
let flags = buf[5];
let breathing_raw = u16::from_le_bytes([buf[6], buf[7]]);
let heartrate_raw = u32::from_le_bytes([buf[8], buf[9], buf[10], buf[11]]);
let rssi = buf[12] as i8;
let n_persons = buf[13];
let motion_energy = f32::from_le_bytes([buf[16], buf[17], buf[18], buf[19]]);
let presence_score = f32::from_le_bytes([buf[20], buf[21], buf[22], buf[23]]);
let timestamp_ms = u32::from_le_bytes([buf[24], buf[25], buf[26], buf[27]]);
Some(Esp32VitalsPacket {
node_id,
presence: (flags & 0x01) != 0,
fall_detected: (flags & 0x02) != 0,
motion: (flags & 0x04) != 0,
breathing_rate_bpm: breathing_raw as f64 / 100.0,
heartrate_bpm: heartrate_raw as f64 / 10000.0,
rssi, n_persons, motion_energy, presence_score, timestamp_ms,
})
}
/// Parse a WASM output packet (magic 0xC511_0004).
pub fn parse_wasm_output(buf: &[u8]) -> Option<WasmOutputPacket> {
if buf.len() < 8 { return None; }
let magic = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]);
if magic != 0xC511_0004 { return None; }
let node_id = buf[4];
let module_id = buf[5];
let event_count = u16::from_le_bytes([buf[6], buf[7]]) as usize;
let mut events = Vec::with_capacity(event_count);
let mut offset = 8;
for _ in 0..event_count {
if offset + 5 > buf.len() { break; }
let event_type = buf[offset];
let value = f32::from_le_bytes([
buf[offset + 1], buf[offset + 2], buf[offset + 3], buf[offset + 4],
]);
events.push(WasmEvent { event_type, value });
offset += 5;
}
Some(WasmOutputPacket { node_id, module_id, events })
}
pub fn parse_esp32_frame(buf: &[u8]) -> Option<Esp32Frame> {
if buf.len() < 20 { return None; }
let magic = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]);
if magic != 0xC511_0001 { return None; }
let node_id = buf[4];
let n_antennas = buf[5];
let n_subcarriers = buf[6];
let freq_mhz = u16::from_le_bytes([buf[8], buf[9]]);
let sequence = u32::from_le_bytes([buf[10], buf[11], buf[12], buf[13]]);
let rssi_raw = buf[14] as i8;
let rssi = if rssi_raw > 0 { rssi_raw.saturating_neg() } else { rssi_raw };
let noise_floor = buf[15] as i8;
let iq_start = 20;
let n_pairs = n_antennas as usize * n_subcarriers as usize;
let expected_len = iq_start + n_pairs * 2;
if buf.len() < expected_len { return None; }
let mut amplitudes = Vec::with_capacity(n_pairs);
let mut phases = Vec::with_capacity(n_pairs);
for k in 0..n_pairs {
let i_val = buf[iq_start + k * 2] as i8 as f64;
let q_val = buf[iq_start + k * 2 + 1] as i8 as f64;
amplitudes.push((i_val * i_val + q_val * q_val).sqrt());
phases.push(q_val.atan2(i_val));
}
Some(Esp32Frame {
magic, node_id, n_antennas, n_subcarriers, freq_mhz, sequence,
rssi, noise_floor, amplitudes, phases,
})
}
// ── Signal field generation ─────────────────────────────────────────────────
pub fn generate_signal_field(
_mean_rssi: f64, motion_score: f64, breathing_rate_hz: f64,
signal_quality: f64, subcarrier_variances: &[f64],
) -> SignalField {
let grid = 20usize;
let mut values = vec![0.0f64; grid * grid];
let center = (grid as f64 - 1.0) / 2.0;
let max_var = subcarrier_variances.iter().cloned().fold(0.0f64, f64::max);
let norm_factor = if max_var > 1e-9 { max_var } else { 1.0 };
let n_sub = subcarrier_variances.len().max(1);
for (k, &var) in subcarrier_variances.iter().enumerate() {
let weight = (var / norm_factor) * motion_score;
if weight < 1e-6 { continue; }
let angle = (k as f64 / n_sub as f64) * 2.0 * std::f64::consts::PI;
let radius = center * 0.8 * weight.sqrt();
let hx = center + radius * angle.cos();
let hz = center + radius * angle.sin();
for z in 0..grid {
for x in 0..grid {
let dx = x as f64 - hx;
let dz = z as f64 - hz;
let dist2 = dx * dx + dz * dz;
let spread = (0.5 + weight * 2.0).max(0.5);
values[z * grid + x] += weight * (-dist2 / (2.0 * spread * spread)).exp();
}
}
}
for z in 0..grid {
for x in 0..grid {
let dx = x as f64 - center;
let dz = z as f64 - center;
let dist = (dx * dx + dz * dz).sqrt();
let base = signal_quality * (-dist * 0.12).exp();
values[z * grid + x] += base * 0.3;
}
}
if breathing_rate_hz > 0.05 {
let ring_r = center * 0.55;
let ring_width = 1.8f64;
for z in 0..grid {
for x in 0..grid {
let dx = x as f64 - center;
let dz = z as f64 - center;
let dist = (dx * dx + dz * dz).sqrt();
let ring_val = 0.08 * (-(dist - ring_r).powi(2) / (2.0 * ring_width * ring_width)).exp();
values[z * grid + x] += ring_val;
}
}
}
let field_max = values.iter().cloned().fold(0.0f64, f64::max);
let scale = if field_max > 1e-9 { 1.0 / field_max } else { 1.0 };
for v in &mut values { *v = (*v * scale).clamp(0.0, 1.0); }
SignalField { grid_size: [grid, 1, grid], values }
}
// ── Feature extraction ──────────────────────────────────────────────────────
pub fn estimate_breathing_rate_hz(frame_history: &VecDeque<Vec<f64>>, sample_rate_hz: f64) -> f64 {
let n = frame_history.len();
if n < 6 { return 0.0; }
let series: Vec<f64> = frame_history.iter()
.map(|amps| if amps.is_empty() { 0.0 } else { amps.iter().sum::<f64>() / amps.len() as f64 })
.collect();
let mean_s = series.iter().sum::<f64>() / n as f64;
let detrended: Vec<f64> = series.iter().map(|x| x - mean_s).collect();
let n_candidates = 9usize;
let f_low = 0.1f64;
let f_high = 0.5f64;
let mut best_freq = 0.0f64;
let mut best_power = 0.0f64;
for i in 0..n_candidates {
let freq = f_low + (f_high - f_low) * i as f64 / (n_candidates - 1).max(1) as f64;
let omega = 2.0 * std::f64::consts::PI * freq / sample_rate_hz;
let coeff = 2.0 * omega.cos();
let (mut s_prev2, mut s_prev1) = (0.0f64, 0.0f64);
for &x in &detrended {
let s = x + coeff * s_prev1 - s_prev2;
s_prev2 = s_prev1;
s_prev1 = s;
}
let power = s_prev2 * s_prev2 + s_prev1 * s_prev1 - coeff * s_prev1 * s_prev2;
if power > best_power { best_power = power; best_freq = freq; }
}
let avg_power = {
let mut total = 0.0f64;
for i in 0..n_candidates {
let freq = f_low + (f_high - f_low) * i as f64 / (n_candidates - 1).max(1) as f64;
let omega = 2.0 * std::f64::consts::PI * freq / sample_rate_hz;
let coeff = 2.0 * omega.cos();
let (mut s_prev2, mut s_prev1) = (0.0f64, 0.0f64);
for &x in &detrended {
let s = x + coeff * s_prev1 - s_prev2;
s_prev2 = s_prev1;
s_prev1 = s;
}
total += s_prev2 * s_prev2 + s_prev1 * s_prev1 - coeff * s_prev1 * s_prev2;
}
total / n_candidates as f64
};
if best_power > avg_power * 3.0 { best_freq.clamp(f_low, f_high) } else { 0.0 }
}
pub fn compute_subcarrier_importance_weights(sensitivity: &[f64]) -> Vec<f64> {
let n = sensitivity.len();
if n == 0 { return vec![]; }
let max_sens = sensitivity.iter().cloned().fold(f64::NEG_INFINITY, f64::max).max(1e-9);
let mut sorted = sensitivity.to_vec();
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let median = if n % 2 == 0 { (sorted[n / 2 - 1] + sorted[n / 2]) / 2.0 } else { sorted[n / 2] };
sensitivity.iter()
.map(|&s| if s >= median { 1.0 + (s / max_sens).min(1.0) } else { 0.5 })
.collect()
}
pub fn compute_subcarrier_variances(frame_history: &VecDeque<Vec<f64>>, n_sub: usize) -> Vec<f64> {
if frame_history.is_empty() || n_sub == 0 { return vec![0.0; n_sub]; }
let n_frames = frame_history.len() as f64;
let mut means = vec![0.0f64; n_sub];
let mut sq_means = vec![0.0f64; n_sub];
for frame in frame_history.iter() {
for k in 0..n_sub {
let a = if k < frame.len() { frame[k] } else { 0.0 };
means[k] += a;
sq_means[k] += a * a;
}
}
(0..n_sub).map(|k| {
let mean = means[k] / n_frames;
let sq_mean = sq_means[k] / n_frames;
(sq_mean - mean * mean).max(0.0)
}).collect()
}
pub fn extract_features_from_frame(
frame: &Esp32Frame, frame_history: &VecDeque<Vec<f64>>, sample_rate_hz: f64,
) -> (FeatureInfo, ClassificationInfo, f64, Vec<f64>, f64) {
let n_sub = frame.amplitudes.len().max(1);
let n = n_sub as f64;
let mean_rssi = frame.rssi as f64;
let sub_sensitivity: Vec<f64> = frame.amplitudes.iter().map(|a| a.abs()).collect();
let importance_weights = compute_subcarrier_importance_weights(&sub_sensitivity);
let weight_sum: f64 = importance_weights.iter().sum::<f64>();
let mean_amp: f64 = if weight_sum > 0.0 {
frame.amplitudes.iter().zip(importance_weights.iter())
.map(|(a, w)| a * w).sum::<f64>() / weight_sum
} else {
frame.amplitudes.iter().sum::<f64>() / n
};
let intra_variance: f64 = if weight_sum > 0.0 {
frame.amplitudes.iter().zip(importance_weights.iter())
.map(|(a, w)| w * (a - mean_amp).powi(2)).sum::<f64>() / weight_sum
} else {
frame.amplitudes.iter().map(|a| (a - mean_amp).powi(2)).sum::<f64>() / n
};
let sub_variances = compute_subcarrier_variances(frame_history, n_sub);
let temporal_variance: f64 = if sub_variances.is_empty() {
intra_variance
} else {
sub_variances.iter().sum::<f64>() / sub_variances.len() as f64
};
let variance = intra_variance.max(temporal_variance);
let spectral_power: f64 = frame.amplitudes.iter().map(|a| a * a).sum::<f64>() / n;
let half = frame.amplitudes.len() / 2;
let motion_band_power = if half > 0 {
frame.amplitudes[half..].iter().map(|a| (a - mean_amp).powi(2)).sum::<f64>()
/ (frame.amplitudes.len() - half) as f64
} else { 0.0 };
let breathing_band_power = if half > 0 {
frame.amplitudes[..half].iter().map(|a| (a - mean_amp).powi(2)).sum::<f64>() / half as f64
} else { 0.0 };
let peak_idx = frame.amplitudes.iter().enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
.map(|(i, _)| i).unwrap_or(0);
let dominant_freq_hz = peak_idx as f64 * 0.05;
let threshold = mean_amp * 1.2;
let change_points = frame.amplitudes.windows(2)
.filter(|w| (w[0] < threshold) != (w[1] < threshold)).count();
let temporal_motion_score = if let Some(prev_frame) = frame_history.back() {
let n_cmp = n_sub.min(prev_frame.len());
if n_cmp > 0 {
let diff_energy: f64 = (0..n_cmp)
.map(|k| (frame.amplitudes[k] - prev_frame[k]).powi(2)).sum::<f64>() / n_cmp as f64;
let ref_energy = mean_amp * mean_amp + 1e-9;
(diff_energy / ref_energy).sqrt().clamp(0.0, 1.0)
} else { 0.0 }
} else {
(intra_variance / (mean_amp * mean_amp + 1e-9)).sqrt().clamp(0.0, 1.0)
};
let variance_motion = (temporal_variance / 10.0).clamp(0.0, 1.0);
let mbp_motion = (motion_band_power / 25.0).clamp(0.0, 1.0);
let cp_motion = (change_points as f64 / 15.0).clamp(0.0, 1.0);
let motion_score = (temporal_motion_score * 0.4 + variance_motion * 0.2
+ mbp_motion * 0.25 + cp_motion * 0.15).clamp(0.0, 1.0);
let snr_db = (frame.rssi as f64 - frame.noise_floor as f64).max(0.0);
let snr_quality = (snr_db / 40.0).clamp(0.0, 1.0);
let stability = (1.0 - (temporal_variance / (mean_amp * mean_amp + 1e-9)).clamp(0.0, 1.0)).max(0.0);
let signal_quality = (snr_quality * 0.6 + stability * 0.4).clamp(0.0, 1.0);
let breathing_rate_hz = estimate_breathing_rate_hz(frame_history, sample_rate_hz);
let features = FeatureInfo {
mean_rssi, variance, motion_band_power, breathing_band_power,
dominant_freq_hz, change_points, spectral_power,
};
let raw_classification = ClassificationInfo {
motion_level: raw_classify(motion_score),
presence: motion_score > 0.04,
confidence: (0.4 + signal_quality * 0.3 + motion_score * 0.3).clamp(0.0, 1.0),
};
(features, raw_classification, breathing_rate_hz, sub_variances, motion_score)
}
// ── Classification ──────────────────────────────────────────────────────────
pub fn raw_classify(score: f64) -> String {
if score > 0.25 { "active".into() }
else if score > 0.12 { "present_moving".into() }
else if score > 0.04 { "present_still".into() }
else { "absent".into() }
}
pub fn smooth_and_classify(state: &mut AppStateInner, raw: &mut ClassificationInfo, raw_motion: f64) {
state.baseline_frames += 1;
if state.baseline_frames < BASELINE_WARMUP {
state.baseline_motion = state.baseline_motion * 0.9 + raw_motion * 0.1;
} else if raw_motion < state.smoothed_motion + 0.05 {
state.baseline_motion = state.baseline_motion * (1.0 - BASELINE_EMA_ALPHA)
+ raw_motion * BASELINE_EMA_ALPHA;
}
let adjusted = (raw_motion - state.baseline_motion * 0.7).max(0.0);
state.smoothed_motion = state.smoothed_motion * (1.0 - MOTION_EMA_ALPHA) + adjusted * MOTION_EMA_ALPHA;
let sm = state.smoothed_motion;
let candidate = raw_classify(sm);
if candidate == state.current_motion_level {
state.debounce_counter = 0;
state.debounce_candidate = candidate;
} else if candidate == state.debounce_candidate {
state.debounce_counter += 1;
if state.debounce_counter >= DEBOUNCE_FRAMES {
state.current_motion_level = candidate;
state.debounce_counter = 0;
}
} else {
state.debounce_candidate = candidate;
state.debounce_counter = 1;
}
raw.motion_level = state.current_motion_level.clone();
raw.presence = sm > 0.03;
raw.confidence = (0.4 + sm * 0.6).clamp(0.0, 1.0);
}
pub fn smooth_and_classify_node(ns: &mut NodeState, raw: &mut ClassificationInfo, raw_motion: f64) {
ns.baseline_frames += 1;
if ns.baseline_frames < BASELINE_WARMUP {
ns.baseline_motion = ns.baseline_motion * 0.9 + raw_motion * 0.1;
} else if raw_motion < ns.smoothed_motion + 0.05 {
ns.baseline_motion = ns.baseline_motion * (1.0 - BASELINE_EMA_ALPHA) + raw_motion * BASELINE_EMA_ALPHA;
}
let adjusted = (raw_motion - ns.baseline_motion * 0.7).max(0.0);
ns.smoothed_motion = ns.smoothed_motion * (1.0 - MOTION_EMA_ALPHA) + adjusted * MOTION_EMA_ALPHA;
let sm = ns.smoothed_motion;
let candidate = raw_classify(sm);
if candidate == ns.current_motion_level {
ns.debounce_counter = 0;
ns.debounce_candidate = candidate;
} else if candidate == ns.debounce_candidate {
ns.debounce_counter += 1;
if ns.debounce_counter >= DEBOUNCE_FRAMES {
ns.current_motion_level = candidate;
ns.debounce_counter = 0;
}
} else {
ns.debounce_candidate = candidate;
ns.debounce_counter = 1;
}
raw.motion_level = ns.current_motion_level.clone();
raw.presence = sm > 0.03;
raw.confidence = (0.4 + sm * 0.6).clamp(0.0, 1.0);
}
pub fn adaptive_override(state: &AppStateInner, features: &FeatureInfo, classification: &mut ClassificationInfo) {
if let Some(ref model) = state.adaptive_model {
let amps = state.frame_history.back().map(|v| v.as_slice()).unwrap_or(&[]);
let feat_arr = adaptive_classifier::features_from_runtime(
&serde_json::json!({
"variance": features.variance,
"motion_band_power": features.motion_band_power,
"breathing_band_power": features.breathing_band_power,
"spectral_power": features.spectral_power,
"dominant_freq_hz": features.dominant_freq_hz,
"change_points": features.change_points,
"mean_rssi": features.mean_rssi,
}),
amps,
);
let (label, conf) = model.classify(&feat_arr);
classification.motion_level = label.to_string();
classification.presence = label != "absent";
classification.confidence = (conf * 0.7 + classification.confidence * 0.3).clamp(0.0, 1.0);
}
}
// ── Vital signs smoothing ───────────────────────────────────────────────────
fn trimmed_mean(buf: &VecDeque<f64>) -> f64 {
if buf.is_empty() { return 0.0; }
let mut sorted: Vec<f64> = buf.iter().copied().collect();
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let n = sorted.len();
let trim = n / 4;
let middle = &sorted[trim..n - trim.max(0)];
if middle.is_empty() { sorted[n / 2] } else { middle.iter().sum::<f64>() / middle.len() as f64 }
}
pub fn smooth_vitals(state: &mut AppStateInner, raw: &VitalSigns) -> VitalSigns {
let raw_hr = raw.heart_rate_bpm.unwrap_or(0.0);
let raw_br = raw.breathing_rate_bpm.unwrap_or(0.0);
let hr_ok = state.smoothed_hr < 1.0 || (raw_hr - state.smoothed_hr).abs() < HR_MAX_JUMP;
let br_ok = state.smoothed_br < 1.0 || (raw_br - state.smoothed_br).abs() < BR_MAX_JUMP;
if hr_ok && raw_hr > 0.0 {
state.hr_buffer.push_back(raw_hr);
if state.hr_buffer.len() > VITAL_MEDIAN_WINDOW { state.hr_buffer.pop_front(); }
}
if br_ok && raw_br > 0.0 {
state.br_buffer.push_back(raw_br);
if state.br_buffer.len() > VITAL_MEDIAN_WINDOW { state.br_buffer.pop_front(); }
}
let trimmed_hr = trimmed_mean(&state.hr_buffer);
let trimmed_br = trimmed_mean(&state.br_buffer);
if trimmed_hr > 0.0 {
if state.smoothed_hr < 1.0 { state.smoothed_hr = trimmed_hr; }
else if (trimmed_hr - state.smoothed_hr).abs() > HR_DEAD_BAND {
state.smoothed_hr = state.smoothed_hr * (1.0 - VITAL_EMA_ALPHA) + trimmed_hr * VITAL_EMA_ALPHA;
}
}
if trimmed_br > 0.0 {
if state.smoothed_br < 1.0 { state.smoothed_br = trimmed_br; }
else if (trimmed_br - state.smoothed_br).abs() > BR_DEAD_BAND {
state.smoothed_br = state.smoothed_br * (1.0 - VITAL_EMA_ALPHA) + trimmed_br * VITAL_EMA_ALPHA;
}
}
state.smoothed_hr_conf = state.smoothed_hr_conf * 0.92 + raw.heartbeat_confidence * 0.08;
state.smoothed_br_conf = state.smoothed_br_conf * 0.92 + raw.breathing_confidence * 0.08;
VitalSigns {
breathing_rate_bpm: if state.smoothed_br > 1.0 { Some(state.smoothed_br) } else { None },
heart_rate_bpm: if state.smoothed_hr > 1.0 { Some(state.smoothed_hr) } else { None },
breathing_confidence: state.smoothed_br_conf,
heartbeat_confidence: state.smoothed_hr_conf,
signal_quality: raw.signal_quality,
}
}
pub fn smooth_vitals_node(ns: &mut NodeState, raw: &VitalSigns) -> VitalSigns {
let raw_hr = raw.heart_rate_bpm.unwrap_or(0.0);
let raw_br = raw.breathing_rate_bpm.unwrap_or(0.0);
let hr_ok = ns.smoothed_hr < 1.0 || (raw_hr - ns.smoothed_hr).abs() < HR_MAX_JUMP;
let br_ok = ns.smoothed_br < 1.0 || (raw_br - ns.smoothed_br).abs() < BR_MAX_JUMP;
if hr_ok && raw_hr > 0.0 {
ns.hr_buffer.push_back(raw_hr);
if ns.hr_buffer.len() > VITAL_MEDIAN_WINDOW { ns.hr_buffer.pop_front(); }
}
if br_ok && raw_br > 0.0 {
ns.br_buffer.push_back(raw_br);
if ns.br_buffer.len() > VITAL_MEDIAN_WINDOW { ns.br_buffer.pop_front(); }
}
let trimmed_hr = trimmed_mean(&ns.hr_buffer);
let trimmed_br = trimmed_mean(&ns.br_buffer);
if trimmed_hr > 0.0 {
if ns.smoothed_hr < 1.0 { ns.smoothed_hr = trimmed_hr; }
else if (trimmed_hr - ns.smoothed_hr).abs() > HR_DEAD_BAND {
ns.smoothed_hr = ns.smoothed_hr * (1.0 - VITAL_EMA_ALPHA) + trimmed_hr * VITAL_EMA_ALPHA;
}
}
if trimmed_br > 0.0 {
if ns.smoothed_br < 1.0 { ns.smoothed_br = trimmed_br; }
else if (trimmed_br - ns.smoothed_br).abs() > BR_DEAD_BAND {
ns.smoothed_br = ns.smoothed_br * (1.0 - VITAL_EMA_ALPHA) + trimmed_br * VITAL_EMA_ALPHA;
}
}
ns.smoothed_hr_conf = ns.smoothed_hr_conf * 0.92 + raw.heartbeat_confidence * 0.08;
ns.smoothed_br_conf = ns.smoothed_br_conf * 0.92 + raw.breathing_confidence * 0.08;
VitalSigns {
breathing_rate_bpm: if ns.smoothed_br > 1.0 { Some(ns.smoothed_br) } else { None },
heart_rate_bpm: if ns.smoothed_hr > 1.0 { Some(ns.smoothed_hr) } else { None },
breathing_confidence: ns.smoothed_br_conf,
heartbeat_confidence: ns.smoothed_hr_conf,
signal_quality: raw.signal_quality,
}
}
// ── Multi-person estimation ─────────────────────────────────────────────────
pub fn fuse_multi_node_features(
current_features: &FeatureInfo, node_states: &HashMap<u8, NodeState>,
) -> FeatureInfo {
let now = std::time::Instant::now();
let active: Vec<(&FeatureInfo, f64)> = node_states.values()
.filter(|ns| ns.last_frame_time.map_or(false, |t| now.duration_since(t).as_secs() < 10))
.filter_map(|ns| {
let feat = ns.latest_features.as_ref()?;
let rssi = ns.rssi_history.back().copied().unwrap_or(-80.0);
Some((feat, rssi))
})
.collect();
if active.len() <= 1 { return current_features.clone(); }
let max_rssi = active.iter().map(|(_, r)| *r).fold(f64::NEG_INFINITY, f64::max);
let weights: Vec<f64> = active.iter()
.map(|(_, r)| (1.0 + (r - max_rssi + 20.0) / 20.0).clamp(0.1, 1.0)).collect();
let w_sum: f64 = weights.iter().sum::<f64>().max(1e-9);
FeatureInfo {
variance: active.iter().zip(&weights).map(|((f, _), w)| f.variance * w).sum::<f64>() / w_sum,
motion_band_power: active.iter().zip(&weights).map(|((f, _), w)| f.motion_band_power * w).sum::<f64>() / w_sum,
breathing_band_power: active.iter().zip(&weights).map(|((f, _), w)| f.breathing_band_power * w).sum::<f64>() / w_sum,
spectral_power: active.iter().zip(&weights).map(|((f, _), w)| f.spectral_power * w).sum::<f64>() / w_sum,
dominant_freq_hz: active.iter().zip(&weights).map(|((f, _), w)| f.dominant_freq_hz * w).sum::<f64>() / w_sum,
change_points: current_features.change_points,
mean_rssi: active.iter().map(|(f, _)| f.mean_rssi).fold(f64::NEG_INFINITY, f64::max),
}
}
pub fn compute_person_score(feat: &FeatureInfo) -> f64 {
let var_norm = (feat.variance / 300.0).clamp(0.0, 1.0);
let cp_norm = (feat.change_points as f64 / 30.0).clamp(0.0, 1.0);
let motion_norm = (feat.motion_band_power / 250.0).clamp(0.0, 1.0);
let sp_norm = (feat.spectral_power / 500.0).clamp(0.0, 1.0);
var_norm * 0.40 + cp_norm * 0.20 + motion_norm * 0.25 + sp_norm * 0.15
}
pub fn estimate_persons_from_correlation(frame_history: &VecDeque<Vec<f64>>) -> usize {
let n_frames = frame_history.len();
if n_frames < 10 { return 1; }
let window: Vec<&Vec<f64>> = frame_history.iter().rev().take(20).collect();
let n_sub = window[0].len().min(56);
if n_sub < 4 { return 1; }
let k = window.len() as f64;
let mut means = vec![0.0f64; n_sub];
let mut variances = vec![0.0f64; n_sub];
for frame in &window {
for sc in 0..n_sub.min(frame.len()) { means[sc] += frame[sc] / k; }
}
for frame in &window {
for sc in 0..n_sub.min(frame.len()) { variances[sc] += (frame[sc] - means[sc]).powi(2) / k; }
}
let noise_floor = 1.0;
let active: Vec<usize> = (0..n_sub).filter(|&sc| variances[sc] > noise_floor).collect();
let m = active.len();
if m < 3 { return if m == 0 { 0 } else { 1 }; }
let mut edges: Vec<(u64, u64, f64)> = Vec::new();
let source = m as u64;
let sink = (m + 1) as u64;
let stds: Vec<f64> = active.iter().map(|&sc| variances[sc].sqrt().max(1e-9)).collect();
for i in 0..m {
for j in (i + 1)..m {
let mut cov = 0.0f64;
for frame in &window {
let (si, sj) = (active[i], active[j]);
if si < frame.len() && sj < frame.len() {
cov += (frame[si] - means[si]) * (frame[sj] - means[sj]) / k;
}
}
let corr = (cov / (stds[i] * stds[j])).abs();
if corr > 0.1 {
let weight = corr * 10.0;
edges.push((i as u64, j as u64, weight));
edges.push((j as u64, i as u64, weight));
}
}
}
let (max_var_idx, _) = active.iter().enumerate()
.max_by(|(_, &a), (_, &b)| variances[a].partial_cmp(&variances[b]).unwrap())
.unwrap_or((0, &0));
let (min_var_idx, _) = active.iter().enumerate()
.min_by(|(_, &a), (_, &b)| variances[a].partial_cmp(&variances[b]).unwrap())
.unwrap_or((0, &0));
if max_var_idx == min_var_idx { return 1; }
edges.push((source, max_var_idx as u64, 100.0));
edges.push((min_var_idx as u64, sink, 100.0));
let mc: DynamicMinCut = match MinCutBuilder::new().exact().with_edges(edges.clone()).build() {
Ok(mc) => mc,
Err(_) => return 1,
};
let cut_value = mc.min_cut_value();
let total_edge_weight: f64 = edges.iter()
.filter(|(s, t, _)| *s != source && *s != sink && *t != source && *t != sink)
.map(|(_, _, w)| w).sum::<f64>() / 2.0;
if total_edge_weight < 1e-9 { return 1; }
let cut_ratio = cut_value / total_edge_weight;
if cut_ratio > 0.4 { 1 }
else if cut_ratio > 0.15 { 2 }
else { 3 }
}
pub fn score_to_person_count(smoothed_score: f64, prev_count: usize) -> usize {
match prev_count {
0 | 1 => {
if smoothed_score > 0.85 { 3 }
else if smoothed_score > 0.70 { 2 }
else { 1 }
}
2 => {
if smoothed_score > 0.92 { 3 }
else if smoothed_score < 0.55 { 1 }
else { 2 }
}
_ => {
if smoothed_score < 0.55 { 1 }
else if smoothed_score < 0.78 { 2 }
else { 3 }
}
}
}
/// Generate a simulated ESP32 frame for testing/demo mode.
pub fn generate_simulated_frame(tick: u64) -> Esp32Frame {
let t = tick as f64 * 0.1;
let n_sub = 56usize;
let mut amplitudes = Vec::with_capacity(n_sub);
let mut phases = Vec::with_capacity(n_sub);
for i in 0..n_sub {
let base = 15.0 + 5.0 * (i as f64 * 0.1 + t * 0.3).sin();
let noise = (i as f64 * 7.3 + t * 13.7).sin() * 2.0;
amplitudes.push((base + noise).max(0.1));
phases.push((i as f64 * 0.2 + t * 0.5).sin() * std::f64::consts::PI);
}
Esp32Frame {
magic: 0xC511_0001, node_id: 1, n_antennas: 1, n_subcarriers: n_sub as u8,
freq_mhz: 2437, sequence: tick as u32,
rssi: (-40.0 + 5.0 * (t * 0.2).sin()) as i8, noise_floor: -90,
amplitudes, phases,
}
}
/// Generate a simple timestamp (epoch seconds) for recording IDs.
pub fn chrono_timestamp() -> u64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0)
}

View file

@ -9,11 +9,15 @@
//! Replaces both ws_server.py and the Python HTTP server.
mod adaptive_classifier;
pub mod cli;
pub mod csi;
mod field_bridge;
mod multistatic_bridge;
pub mod pose;
mod rvf_container;
mod rvf_pipeline;
mod tracker_bridge;
pub mod types;
mod vital_signs;
// Training pipeline modules (exposed via lib.rs)

View file

@ -0,0 +1,194 @@
//! Skeleton derivation, pose estimation, and temporal smoothing.
use crate::types::*;
/// Expected bone lengths in pixel-space for the COCO-17 skeleton.
pub const POSE_BONE_PAIRS: &[(usize, usize)] = &[
(5, 7), (7, 9), (6, 8), (8, 10),
(5, 11), (6, 12),
(11, 13), (13, 15), (12, 14), (14, 16),
(5, 6), (11, 12),
];
const TORSO_KP: [usize; 4] = [5, 6, 11, 12];
const EXTREMITY_KP: [usize; 4] = [9, 10, 15, 16];
pub fn derive_single_person_pose(
update: &SensingUpdate, person_idx: usize, total_persons: usize,
) -> PersonDetection {
let cls = &update.classification;
let feat = &update.features;
let phase_offset = person_idx as f64 * 2.094;
let half = (total_persons as f64 - 1.0) / 2.0;
let person_x_offset = (person_idx as f64 - half) * 120.0;
let conf_decay = 1.0 - person_idx as f64 * 0.15;
let motion_score = (feat.motion_band_power / 15.0).clamp(0.0, 1.0);
let is_walking = motion_score > 0.55;
let breath_amp = (feat.breathing_band_power * 4.0).clamp(0.0, 12.0);
let breath_phase = if let Some(ref vs) = update.vital_signs {
let bpm = vs.breathing_rate_bpm.unwrap_or(15.0);
let freq = (bpm / 60.0).clamp(0.1, 0.5);
(update.tick as f64 * freq * 0.02 * std::f64::consts::TAU + phase_offset).sin()
} else {
(update.tick as f64 * 0.02 + phase_offset).sin()
};
let lean_x = (feat.dominant_freq_hz / 5.0 - 1.0).clamp(-1.0, 1.0) * 18.0;
let stride_x = if is_walking {
let stride_phase = (feat.motion_band_power * 0.7 + update.tick as f64 * 0.06 + phase_offset).sin();
stride_phase * 20.0 * motion_score
} else { 0.0 };
let burst = (feat.change_points as f64 / 20.0).clamp(0.0, 0.3);
let noise_seed = person_idx as f64 * 97.1;
let noise_val = (noise_seed.sin() * 43758.545).fract();
let snr_factor = ((feat.variance - 0.5) / 10.0).clamp(0.0, 1.0);
let base_confidence = cls.confidence * (0.6 + 0.4 * snr_factor) * conf_decay;
let base_x = 320.0 + stride_x + lean_x * 0.5 + person_x_offset;
let base_y = 240.0 - motion_score * 8.0;
let kp_names = [
"nose", "left_eye", "right_eye", "left_ear", "right_ear",
"left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
"left_wrist", "right_wrist", "left_hip", "right_hip",
"left_knee", "right_knee", "left_ankle", "right_ankle",
];
let kp_offsets: [(f64, f64); 17] = [
(0.0, -80.0), (-8.0, -88.0), (8.0, -88.0), (-16.0, -82.0), (16.0, -82.0),
(-30.0, -50.0), (30.0, -50.0), (-45.0, -15.0), (45.0, -15.0),
(-50.0, 20.0), (50.0, 20.0), (-20.0, 20.0), (20.0, 20.0),
(-22.0, 70.0), (22.0, 70.0), (-24.0, 120.0), (24.0, 120.0),
];
let keypoints: Vec<PoseKeypoint> = kp_names.iter().zip(kp_offsets.iter())
.enumerate()
.map(|(i, (name, (dx, dy)))| {
let breath_dx = if TORSO_KP.contains(&i) {
let sign = if *dx < 0.0 { -1.0 } else { 1.0 };
sign * breath_amp * breath_phase * 0.5
} else { 0.0 };
let breath_dy = if TORSO_KP.contains(&i) {
let sign = if *dy < 0.0 { -1.0 } else { 1.0 };
sign * breath_amp * breath_phase * 0.3
} else { 0.0 };
let extremity_jitter = if EXTREMITY_KP.contains(&i) {
let phase = noise_seed + i as f64 * 2.399;
(phase.sin() * burst * motion_score * 4.0, (phase * 1.31).cos() * burst * motion_score * 3.0)
} else { (0.0, 0.0) };
let kp_noise_x = ((noise_seed + i as f64 * 1.618).sin() * 43758.545).fract()
* feat.variance.sqrt().clamp(0.0, 3.0) * motion_score;
let kp_noise_y = ((noise_seed + i as f64 * 2.718).cos() * 31415.926).fract()
* feat.variance.sqrt().clamp(0.0, 3.0) * motion_score * 0.6;
let swing_dy = if is_walking {
let stride_phase = (feat.motion_band_power * 0.7 + update.tick as f64 * 0.12 + phase_offset).sin();
match i {
7 | 9 => -stride_phase * 20.0 * motion_score,
8 | 10 => stride_phase * 20.0 * motion_score,
13 | 15 => stride_phase * 25.0 * motion_score,
14 | 16 => -stride_phase * 25.0 * motion_score,
_ => 0.0,
}
} else { 0.0 };
let final_x = base_x + dx + breath_dx + extremity_jitter.0 + kp_noise_x;
let final_y = base_y + dy + breath_dy + extremity_jitter.1 + kp_noise_y + swing_dy;
let kp_conf = if EXTREMITY_KP.contains(&i) {
base_confidence * (0.7 + 0.3 * snr_factor) * (0.85 + 0.15 * noise_val)
} else {
base_confidence * (0.88 + 0.12 * ((i as f64 * 0.7 + noise_seed).cos()))
};
PoseKeypoint { name: name.to_string(), x: final_x, y: final_y, z: lean_x * 0.02, confidence: kp_conf.clamp(0.1, 1.0) }
})
.collect();
let xs: Vec<f64> = keypoints.iter().map(|k| k.x).collect();
let ys: Vec<f64> = keypoints.iter().map(|k| k.y).collect();
let min_x = xs.iter().cloned().fold(f64::MAX, f64::min) - 10.0;
let min_y = ys.iter().cloned().fold(f64::MAX, f64::min) - 10.0;
let max_x = xs.iter().cloned().fold(f64::MIN, f64::max) + 10.0;
let max_y = ys.iter().cloned().fold(f64::MIN, f64::max) + 10.0;
PersonDetection {
id: (person_idx + 1) as u32,
confidence: cls.confidence * conf_decay,
keypoints,
bbox: BoundingBox { x: min_x, y: min_y, width: (max_x - min_x).max(80.0), height: (max_y - min_y).max(160.0) },
zone: format!("zone_{}", person_idx + 1),
}
}
pub fn derive_pose_from_sensing(update: &SensingUpdate) -> Vec<PersonDetection> {
let cls = &update.classification;
if !cls.presence { return vec![]; }
let person_count = update.estimated_persons.unwrap_or(1).max(1);
(0..person_count).map(|idx| derive_single_person_pose(update, idx, person_count)).collect()
}
/// Apply temporal EMA smoothing and bone-length clamping to person detections.
pub fn apply_temporal_smoothing(persons: &mut [PersonDetection], ns: &mut NodeState) {
if persons.is_empty() { return; }
let alpha = ns.ema_alpha();
let person = &mut persons[0];
let current_kps: Vec<[f64; 3]> = person.keypoints.iter()
.map(|kp| [kp.x, kp.y, kp.z]).collect();
let smoothed = if let Some(ref prev) = ns.prev_keypoints {
let mut out = Vec::with_capacity(current_kps.len());
for (cur, prv) in current_kps.iter().zip(prev.iter()) {
out.push([
alpha * cur[0] + (1.0 - alpha) * prv[0],
alpha * cur[1] + (1.0 - alpha) * prv[1],
alpha * cur[2] + (1.0 - alpha) * prv[2],
]);
}
clamp_bone_lengths_f64(&mut out, prev);
out
} else {
current_kps.clone()
};
for (kp, s) in person.keypoints.iter_mut().zip(smoothed.iter()) {
kp.x = s[0]; kp.y = s[1]; kp.z = s[2];
}
ns.prev_keypoints = Some(smoothed);
}
fn clamp_bone_lengths_f64(pose: &mut Vec<[f64; 3]>, prev: &[[f64; 3]]) {
for &(p, c) in POSE_BONE_PAIRS {
if p >= pose.len() || c >= pose.len() { continue; }
let prev_len = dist_f64(&prev[p], &prev[c]);
if prev_len < 1e-6 { continue; }
let cur_len = dist_f64(&pose[p], &pose[c]);
if cur_len < 1e-6 { continue; }
let ratio = cur_len / prev_len;
let lo = 1.0 - MAX_BONE_CHANGE_RATIO;
let hi = 1.0 + MAX_BONE_CHANGE_RATIO;
if ratio < lo || ratio > hi {
let target = prev_len * ratio.clamp(lo, hi);
let scale = target / cur_len;
for dim in 0..3 {
let diff = pose[c][dim] - pose[p][dim];
pose[c][dim] = pose[p][dim] + diff * scale;
}
}
}
}
fn dist_f64(a: &[f64; 3], b: &[f64; 3]) -> f64 {
let dx = b[0] - a[0];
let dy = b[1] - a[1];
let dz = b[2] - a[2];
(dx * dx + dy * dy + dz * dz).sqrt()
}

View file

@ -0,0 +1,403 @@
//! Data types, constants, and shared state definitions.
use std::collections::{HashMap, VecDeque};
use std::path::PathBuf;
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use tokio::sync::{broadcast, RwLock};
use crate::adaptive_classifier;
use crate::rvf_container::RvfContainerInfo;
use crate::rvf_pipeline::ProgressiveLoader;
use crate::vital_signs::{VitalSignDetector, VitalSigns};
use wifi_densepose_signal::ruvsense::pose_tracker::PoseTracker;
use wifi_densepose_signal::ruvsense::multistatic::MultistaticFuser;
use wifi_densepose_signal::ruvsense::field_model::FieldModel;
// ── Constants ───────────────────────────────────────────────────────────────
/// Number of frames retained in `frame_history` for temporal analysis.
pub const FRAME_HISTORY_CAPACITY: usize = 100;
/// If no ESP32 frame arrives within this duration, source reverts to offline.
pub const ESP32_OFFLINE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);
/// Default EMA alpha for temporal keypoint smoothing (RuVector Phase 2).
pub const TEMPORAL_EMA_ALPHA_DEFAULT: f64 = 0.15;
/// Reduced EMA alpha when coherence is low.
pub const TEMPORAL_EMA_ALPHA_LOW_COHERENCE: f64 = 0.05;
/// Coherence threshold below which we reduce EMA alpha.
pub const COHERENCE_LOW_THRESHOLD: f64 = 0.3;
/// Maximum allowed bone-length change ratio between frames (20%).
pub const MAX_BONE_CHANGE_RATIO: f64 = 0.20;
/// Number of motion_energy frames to track for coherence scoring.
pub const COHERENCE_WINDOW: usize = 20;
/// Debounce frames required before state transition (at ~10 FPS = ~0.4s).
pub const DEBOUNCE_FRAMES: u32 = 4;
/// EMA alpha for motion smoothing (~1s time constant at 10 FPS).
pub const MOTION_EMA_ALPHA: f64 = 0.15;
/// EMA alpha for slow-adapting baseline (~30s time constant at 10 FPS).
pub const BASELINE_EMA_ALPHA: f64 = 0.003;
/// Number of warm-up frames before baseline subtraction kicks in.
pub const BASELINE_WARMUP: u64 = 50;
/// Size of the median filter window for vital signs outlier rejection.
pub const VITAL_MEDIAN_WINDOW: usize = 21;
/// EMA alpha for vital signs (~5s time constant at 10 FPS).
pub const VITAL_EMA_ALPHA: f64 = 0.02;
/// Maximum BPM jump per frame before a value is rejected as an outlier.
pub const HR_MAX_JUMP: f64 = 8.0;
pub const BR_MAX_JUMP: f64 = 2.0;
/// Minimum change from current smoothed value before EMA updates (dead-band).
pub const HR_DEAD_BAND: f64 = 2.0;
pub const BR_DEAD_BAND: f64 = 0.5;
// ── ESP32 Frame ─────────────────────────────────────────────────────────────
/// ADR-018 ESP32 CSI binary frame header (20 bytes)
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct Esp32Frame {
pub magic: u32,
pub node_id: u8,
pub n_antennas: u8,
pub n_subcarriers: u8,
pub freq_mhz: u16,
pub sequence: u32,
pub rssi: i8,
pub noise_floor: i8,
pub amplitudes: Vec<f64>,
pub phases: Vec<f64>,
}
// ── Sensing Update ──────────────────────────────────────────────────────────
/// Sensing update broadcast to WebSocket clients
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SensingUpdate {
#[serde(rename = "type")]
pub msg_type: String,
pub timestamp: f64,
pub source: String,
pub tick: u64,
pub nodes: Vec<NodeInfo>,
pub features: FeatureInfo,
pub classification: ClassificationInfo,
pub signal_field: SignalField,
#[serde(skip_serializing_if = "Option::is_none")]
pub vital_signs: Option<VitalSigns>,
#[serde(skip_serializing_if = "Option::is_none")]
pub enhanced_motion: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub enhanced_breathing: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub posture: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub signal_quality_score: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub quality_verdict: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub bssid_count: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub pose_keypoints: Option<Vec<[f64; 4]>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub model_status: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub persons: Option<Vec<PersonDetection>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub estimated_persons: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub node_features: Option<Vec<PerNodeFeatureInfo>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeInfo {
pub node_id: u8,
pub rssi_dbm: f64,
pub position: [f64; 3],
pub amplitude: Vec<f64>,
pub subcarrier_count: usize,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureInfo {
pub mean_rssi: f64,
pub variance: f64,
pub motion_band_power: f64,
pub breathing_band_power: f64,
pub dominant_freq_hz: f64,
pub change_points: usize,
pub spectral_power: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClassificationInfo {
pub motion_level: String,
pub presence: bool,
pub confidence: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SignalField {
pub grid_size: [usize; 3],
pub values: Vec<f64>,
}
/// WiFi-derived pose keypoint (17 COCO keypoints)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PoseKeypoint {
pub name: String,
pub x: f64,
pub y: f64,
pub z: f64,
pub confidence: f64,
}
/// Person detection from WiFi sensing
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersonDetection {
pub id: u32,
pub confidence: f64,
pub keypoints: Vec<PoseKeypoint>,
pub bbox: BoundingBox,
pub zone: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BoundingBox {
pub x: f64,
pub y: f64,
pub width: f64,
pub height: f64,
}
/// Per-node feature info for WebSocket broadcasts (multi-node support).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerNodeFeatureInfo {
pub node_id: u8,
pub features: FeatureInfo,
pub classification: ClassificationInfo,
pub rssi_dbm: f64,
pub last_seen_ms: u64,
pub frame_rate_hz: f64,
pub stale: bool,
}
// ── ESP32 Edge Vitals Packet (ADR-039) ──────────────────────────────────────
/// Decoded vitals packet from ESP32 edge processing pipeline.
#[derive(Debug, Clone, Serialize)]
pub struct Esp32VitalsPacket {
pub node_id: u8,
pub presence: bool,
pub fall_detected: bool,
pub motion: bool,
pub breathing_rate_bpm: f64,
pub heartrate_bpm: f64,
pub rssi: i8,
pub n_persons: u8,
pub motion_energy: f32,
pub presence_score: f32,
pub timestamp_ms: u32,
}
/// Single WASM event (type + value).
#[derive(Debug, Clone, Serialize)]
pub struct WasmEvent {
pub event_type: u8,
pub value: f32,
}
/// Decoded WASM output packet from ESP32 Tier 3 runtime.
#[derive(Debug, Clone, Serialize)]
pub struct WasmOutputPacket {
pub node_id: u8,
pub module_id: u8,
pub events: Vec<WasmEvent>,
}
// ── Per-node state ──────────────────────────────────────────────────────────
/// Per-node sensing state for multi-node deployments (issue #249).
pub struct NodeState {
pub frame_history: VecDeque<Vec<f64>>,
pub smoothed_person_score: f64,
pub prev_person_count: usize,
pub smoothed_motion: f64,
pub current_motion_level: String,
pub debounce_counter: u32,
pub debounce_candidate: String,
pub baseline_motion: f64,
pub baseline_frames: u64,
pub smoothed_hr: f64,
pub smoothed_br: f64,
pub smoothed_hr_conf: f64,
pub smoothed_br_conf: f64,
pub hr_buffer: VecDeque<f64>,
pub br_buffer: VecDeque<f64>,
pub rssi_history: VecDeque<f64>,
pub vital_detector: VitalSignDetector,
pub latest_vitals: VitalSigns,
pub last_frame_time: Option<std::time::Instant>,
pub edge_vitals: Option<Esp32VitalsPacket>,
pub latest_features: Option<FeatureInfo>,
pub prev_keypoints: Option<Vec<[f64; 3]>>,
pub motion_energy_history: VecDeque<f64>,
pub coherence_score: f64,
}
impl NodeState {
pub fn new() -> Self {
Self {
frame_history: VecDeque::new(),
smoothed_person_score: 0.0,
prev_person_count: 0,
smoothed_motion: 0.0,
current_motion_level: "absent".to_string(),
debounce_counter: 0,
debounce_candidate: "absent".to_string(),
baseline_motion: 0.0,
baseline_frames: 0,
smoothed_hr: 0.0,
smoothed_br: 0.0,
smoothed_hr_conf: 0.0,
smoothed_br_conf: 0.0,
hr_buffer: VecDeque::with_capacity(8),
br_buffer: VecDeque::with_capacity(8),
rssi_history: VecDeque::new(),
vital_detector: VitalSignDetector::new(10.0),
latest_vitals: VitalSigns::default(),
last_frame_time: None,
edge_vitals: None,
latest_features: None,
prev_keypoints: None,
motion_energy_history: VecDeque::with_capacity(COHERENCE_WINDOW),
coherence_score: 1.0,
}
}
/// Update the coherence score from the latest motion_energy value.
pub fn update_coherence(&mut self, motion_energy: f64) {
if self.motion_energy_history.len() >= COHERENCE_WINDOW {
self.motion_energy_history.pop_front();
}
self.motion_energy_history.push_back(motion_energy);
let n = self.motion_energy_history.len();
if n < 2 {
self.coherence_score = 1.0;
return;
}
let mean: f64 = self.motion_energy_history.iter().sum::<f64>() / n as f64;
let variance: f64 = self.motion_energy_history.iter()
.map(|v| (v - mean) * (v - mean))
.sum::<f64>() / (n - 1) as f64;
self.coherence_score = (1.0 / (1.0 + variance)).clamp(0.0, 1.0);
}
/// Choose the EMA alpha based on current coherence score.
pub fn ema_alpha(&self) -> f64 {
if self.coherence_score < COHERENCE_LOW_THRESHOLD {
TEMPORAL_EMA_ALPHA_LOW_COHERENCE
} else {
TEMPORAL_EMA_ALPHA_DEFAULT
}
}
}
// ── Shared application state ────────────────────────────────────────────────
/// Shared application state
pub struct AppStateInner {
pub latest_update: Option<SensingUpdate>,
pub rssi_history: VecDeque<f64>,
pub frame_history: VecDeque<Vec<f64>>,
pub tick: u64,
pub source: String,
pub last_esp32_frame: Option<std::time::Instant>,
pub tx: broadcast::Sender<String>,
pub total_detections: u64,
pub start_time: std::time::Instant,
pub vital_detector: VitalSignDetector,
pub latest_vitals: VitalSigns,
pub rvf_info: Option<RvfContainerInfo>,
pub save_rvf_path: Option<PathBuf>,
pub progressive_loader: Option<ProgressiveLoader>,
pub active_sona_profile: Option<String>,
pub model_loaded: bool,
pub smoothed_person_score: f64,
pub prev_person_count: usize,
pub smoothed_motion: f64,
pub current_motion_level: String,
pub debounce_counter: u32,
pub debounce_candidate: String,
pub baseline_motion: f64,
pub baseline_frames: u64,
pub smoothed_hr: f64,
pub smoothed_br: f64,
pub smoothed_hr_conf: f64,
pub smoothed_br_conf: f64,
pub hr_buffer: VecDeque<f64>,
pub br_buffer: VecDeque<f64>,
pub edge_vitals: Option<Esp32VitalsPacket>,
pub latest_wasm_events: Option<WasmOutputPacket>,
pub discovered_models: Vec<serde_json::Value>,
pub active_model_id: Option<String>,
pub recordings: Vec<serde_json::Value>,
pub recording_active: bool,
pub recording_start_time: Option<std::time::Instant>,
pub recording_current_id: Option<String>,
pub recording_stop_tx: Option<tokio::sync::watch::Sender<bool>>,
pub training_status: String,
pub training_config: Option<serde_json::Value>,
pub adaptive_model: Option<adaptive_classifier::AdaptiveModel>,
pub node_states: HashMap<u8, NodeState>,
pub pose_tracker: PoseTracker,
pub last_tracker_instant: Option<std::time::Instant>,
pub multistatic_fuser: MultistaticFuser,
pub field_model: Option<FieldModel>,
}
impl AppStateInner {
/// Return the effective data source, accounting for ESP32 frame timeout.
pub fn effective_source(&self) -> String {
if self.source == "esp32" {
if let Some(last) = self.last_esp32_frame {
if last.elapsed() > ESP32_OFFLINE_TIMEOUT {
return "esp32:offline".to_string();
}
}
}
self.source.clone()
}
/// Person count: eigenvalue-based if field model is calibrated, else heuristic.
pub fn person_count(&self) -> usize {
use crate::field_bridge;
use crate::csi::score_to_person_count;
match self.field_model.as_ref() {
Some(fm) => {
let history = if !self.frame_history.is_empty() {
&self.frame_history
} else {
self.node_states.values()
.filter(|ns| !ns.frame_history.is_empty())
.max_by_key(|ns| ns.last_frame_time)
.map(|ns| &ns.frame_history)
.unwrap_or(&self.frame_history)
};
field_bridge::occupancy_or_fallback(
fm, history, self.smoothed_person_score, self.prev_person_count,
)
}
None => score_to_person_count(self.smoothed_person_score, self.prev_person_count),
}
}
}
pub type SharedState = Arc<RwLock<AppStateInner>>;

View file

@ -339,9 +339,16 @@ impl RfTomographer {
/// Compute the intersection weights of a link with the voxel grid.
///
/// Uses a simplified approach: for each voxel, computes the minimum
/// distance from the voxel center to the link ray. Voxels within
/// one Fresnel zone receive weight proportional to closeness.
/// Uses a DDA (Digital Differential Analyzer) ray-marching algorithm:
/// 1. March along the ray from TX to RX, advancing to the nearest
/// axis-aligned voxel boundary at each step.
/// 2. At each ray voxel, expand by the Fresnel radius to check
/// neighboring voxels.
/// 3. Use a visited bitvector to avoid duplicate entries.
/// 4. Weight = `1.0 - dist / fresnel_radius` (same as before).
///
/// This is O(ray_length / voxel_size) instead of O(nx*ny*nz),
/// a significant speedup for large grids.
fn compute_link_weights(link: &LinkGeometry, config: &TomographyConfig) -> Vec<(usize, f64)> {
let vx = (config.bounds[3] - config.bounds[0]) / config.nx as f64;
let vy = (config.bounds[4] - config.bounds[1]) / config.ny as f64;
@ -356,25 +363,74 @@ fn compute_link_weights(link: &LinkGeometry, config: &TomographyConfig) -> Vec<(
let dy = link.rx.y - link.tx.y;
let dz = link.rx.z - link.tx.z;
let n_voxels = config.nx * config.ny * config.nz;
let mut visited = vec![false; n_voxels];
let mut weights = Vec::new();
for iz in 0..config.nz {
for iy in 0..config.ny {
for ix in 0..config.nx {
let cx = config.bounds[0] + (ix as f64 + 0.5) * vx;
let cy = config.bounds[1] + (iy as f64 + 0.5) * vy;
let cz = config.bounds[2] + (iz as f64 + 0.5) * vz;
// Fresnel expansion radius in voxel units.
let expand_x = (fresnel_radius / vx).ceil() as isize;
let expand_y = (fresnel_radius / vy).ceil() as isize;
let expand_z = (fresnel_radius / vz).ceil() as isize;
// Point-to-line distance
let dist = point_to_segment_distance(
cx, cy, cz, link.tx.x, link.tx.y, link.tx.z, dx, dy, dz, link_dist,
);
// DDA initialization: start at TX position in voxel coordinates.
let start_vx = (link.tx.x - config.bounds[0]) / vx;
let start_vy = (link.tx.y - config.bounds[1]) / vy;
let start_vz = (link.tx.z - config.bounds[2]) / vz;
if dist < fresnel_radius {
// Weight decays with distance from link ray
let w = 1.0 - dist / fresnel_radius;
let idx = iz * config.ny * config.nx + iy * config.nx + ix;
weights.push((idx, w));
let end_vx = (link.rx.x - config.bounds[0]) / vx;
let end_vy = (link.rx.y - config.bounds[1]) / vy;
let end_vz = (link.rx.z - config.bounds[2]) / vz;
let ray_dx = end_vx - start_vx;
let ray_dy = end_vy - start_vy;
let ray_dz = end_vz - start_vz;
// Number of DDA steps: traverse the maximum voxel span.
let steps = (ray_dx.abs().max(ray_dy.abs()).max(ray_dz.abs()).ceil() as usize).max(1);
let inv_steps = 1.0 / steps as f64;
for step in 0..=steps {
let t = step as f64 * inv_steps;
let rx = start_vx + t * ray_dx;
let ry = start_vy + t * ray_dy;
let rz = start_vz + t * ray_dz;
let base_ix = rx.floor() as isize;
let base_iy = ry.floor() as isize;
let base_iz = rz.floor() as isize;
// Expand by Fresnel radius to check neighboring voxels.
for diz in -expand_z..=expand_z {
let iz = base_iz + diz;
if iz < 0 || iz >= config.nz as isize { continue; }
for diy in -expand_y..=expand_y {
let iy = base_iy + diy;
if iy < 0 || iy >= config.ny as isize { continue; }
for dix in -expand_x..=expand_x {
let ix = base_ix + dix;
if ix < 0 || ix >= config.nx as isize { continue; }
let idx = iz as usize * config.ny * config.nx
+ iy as usize * config.nx
+ ix as usize;
if visited[idx] { continue; }
let cx = config.bounds[0] + (ix as f64 + 0.5) * vx;
let cy = config.bounds[1] + (iy as f64 + 0.5) * vy;
let cz = config.bounds[2] + (iz as f64 + 0.5) * vz;
let dist = point_to_segment_distance(
cx, cy, cz,
link.tx.x, link.tx.y, link.tx.z,
dx, dy, dz, link_dist,
);
if dist < fresnel_radius {
let w = 1.0 - dist / fresnel_radius;
weights.push((idx, w));
}
visited[idx] = true;
}
}
}

View file

@ -0,0 +1,477 @@
#!/usr/bin/env node
/**
* Ground-Truth Alignment Camera Keypoints <-> CSI Recording
*
* Time-aligns camera keypoint data with CSI recording data to produce
* paired training samples for WiFlow supervised training (ADR-079).
*
* Camera keypoints: data/ground-truth/gt-{timestamp}.jsonl
* CSI recordings: data/recordings/*.csi.jsonl
* Paired output: data/paired/*.paired.jsonl
*
* Usage:
* node scripts/align-ground-truth.js \
* --gt data/ground-truth/gt-1775300000.jsonl \
* --csi data/recordings/overnight-1775217646.csi.jsonl \
* --output data/paired/aligned.paired.jsonl
*
* # With clock offset correction (camera ahead by 50ms)
* node scripts/align-ground-truth.js \
* --gt data/ground-truth/gt-1775300000.jsonl \
* --csi data/recordings/overnight-1775217646.csi.jsonl \
* --clock-offset-ms -50
*
* ADR: docs/adr/ADR-079
*/
'use strict';
const fs = require('fs');
const path = require('path');
const { parseArgs } = require('util');
// ---------------------------------------------------------------------------
// CLI argument parsing
// ---------------------------------------------------------------------------
const { values: args } = parseArgs({
options: {
gt: { type: 'string' },
csi: { type: 'string' },
output: { type: 'string', short: 'o' },
'window-ms': { type: 'string', default: '200' },
'window-frames': { type: 'string', default: '20' },
'min-camera-frames': { type: 'string', default: '3' },
'min-confidence': { type: 'string', default: '0.5' },
'clock-offset-ms': { type: 'string', default: '0' },
help: { type: 'boolean', short: 'h', default: false },
},
strict: true,
});
if (args.help || !args.gt || !args.csi) {
console.log(`
Usage: node scripts/align-ground-truth.js --gt <gt.jsonl> --csi <csi.jsonl> [options]
Required:
--gt <path> Camera ground-truth JSONL file
--csi <path> CSI recording JSONL file
Options:
--output, -o <path> Output paired JSONL (default: data/paired/<basename>.paired.jsonl)
--window-ms <ms> CSI window size in ms (default: 200)
--window-frames <n> Frames per CSI window (default: 20)
--min-camera-frames <n> Minimum camera frames per window (default: 3)
--min-confidence <f> Minimum average confidence threshold (default: 0.5)
--clock-offset-ms <ms> Manual clock offset: added to camera timestamps (default: 0)
--help, -h Show this help
`);
process.exit(args.help ? 0 : 1);
}
const WINDOW_FRAMES = parseInt(args['window-frames'], 10);
const WINDOW_MS = parseInt(args['window-ms'], 10);
const MIN_CAMERA_FRAMES = parseInt(args['min-camera-frames'], 10);
const MIN_CONFIDENCE = parseFloat(args['min-confidence']);
const CLOCK_OFFSET_MS = parseFloat(args['clock-offset-ms']);
const NUM_KEYPOINTS = 17; // COCO 17-keypoint format
// ---------------------------------------------------------------------------
// Timestamp conversion
// ---------------------------------------------------------------------------
/**
* Convert camera nanosecond timestamp to milliseconds.
* Applies clock offset correction.
*/
function cameraTsToMs(tsNs) {
return tsNs / 1e6 + CLOCK_OFFSET_MS;
}
/**
* Convert ISO 8601 timestamp string to milliseconds since epoch.
*/
function isoToMs(isoStr) {
return new Date(isoStr).getTime();
}
// ---------------------------------------------------------------------------
// IQ hex parsing (matches train-wiflow.js conventions)
// ---------------------------------------------------------------------------
/**
* Parse IQ hex string into signed byte pairs [I0, Q0, I1, Q1, ...].
*/
function parseIqHex(hexStr) {
const bytes = [];
for (let i = 0; i < hexStr.length; i += 2) {
let val = parseInt(hexStr.substr(i, 2), 16);
if (val > 127) val -= 256; // signed byte
bytes.push(val);
}
return bytes;
}
/**
* Extract amplitude from IQ data for a given number of subcarriers.
* Returns Float32Array of amplitudes [nSubcarriers].
* Skips first I/Q pair (DC offset) per WiFlow paper recommendation.
*/
function extractAmplitude(iqBytes, nSubcarriers) {
const amp = new Float32Array(nSubcarriers);
const start = 2; // skip first IQ pair (DC offset)
for (let sc = 0; sc < nSubcarriers; sc++) {
const idx = start + sc * 2;
if (idx + 1 < iqBytes.length) {
const I = iqBytes[idx];
const Q = iqBytes[idx + 1];
amp[sc] = Math.sqrt(I * I + Q * Q);
}
}
return amp;
}
// ---------------------------------------------------------------------------
// File loading
// ---------------------------------------------------------------------------
/**
* Load and parse a JSONL file, skipping blank/malformed lines.
*/
function loadJsonl(filePath) {
const lines = fs.readFileSync(filePath, 'utf8').split('\n');
const records = [];
for (const line of lines) {
const trimmed = line.trim();
if (!trimmed) continue;
try {
records.push(JSON.parse(trimmed));
} catch {
// skip malformed lines
}
}
return records;
}
/**
* Load camera ground-truth file.
* Returns array of { tsMs, keypoints, confidence, nVisible, nPersons }.
*/
function loadGroundTruth(filePath) {
const raw = loadJsonl(filePath);
const frames = [];
for (const r of raw) {
if (r.ts_ns == null || !r.keypoints) continue;
frames.push({
tsMs: cameraTsToMs(r.ts_ns),
keypoints: r.keypoints,
confidence: r.confidence ?? 0,
nVisible: r.n_visible ?? 0,
nPersons: r.n_persons ?? 1,
});
}
// Sort by timestamp
frames.sort((a, b) => a.tsMs - b.tsMs);
return frames;
}
/**
* Load CSI recording file.
* Separates raw_csi frames and feature frames.
*/
function loadCsi(filePath) {
const raw = loadJsonl(filePath);
const rawCsi = [];
const features = [];
for (const r of raw) {
if (!r.timestamp) continue;
const tsMs = isoToMs(r.timestamp);
if (isNaN(tsMs)) continue;
if (r.type === 'raw_csi') {
rawCsi.push({
tsMs,
nodeId: r.node_id,
subcarriers: r.subcarriers ?? 128,
iqHex: r.iq_hex,
rssi: r.rssi,
seq: r.seq,
});
} else if (r.type === 'feature') {
features.push({
tsMs,
nodeId: r.node_id,
features: r.features,
rssi: r.rssi,
seq: r.seq,
});
}
}
// Sort by timestamp
rawCsi.sort((a, b) => a.tsMs - b.tsMs);
features.sort((a, b) => a.tsMs - b.tsMs);
return { rawCsi, features };
}
// ---------------------------------------------------------------------------
// Windowing
// ---------------------------------------------------------------------------
/**
* Group frames into non-overlapping windows of `windowSize` consecutive frames.
*/
function groupIntoWindows(frames, windowSize) {
const windows = [];
for (let i = 0; i + windowSize <= frames.length; i += windowSize) {
windows.push(frames.slice(i, i + windowSize));
}
return windows;
}
// ---------------------------------------------------------------------------
// Camera frame matching (binary search)
// ---------------------------------------------------------------------------
/**
* Find all camera frames within [tStart, tEnd] using binary search.
*/
function findCameraFramesInRange(cameraFrames, tStartMs, tEndMs) {
// Binary search for first frame >= tStartMs
let lo = 0;
let hi = cameraFrames.length;
while (lo < hi) {
const mid = (lo + hi) >>> 1;
if (cameraFrames[mid].tsMs < tStartMs) lo = mid + 1;
else hi = mid;
}
const matched = [];
for (let i = lo; i < cameraFrames.length; i++) {
if (cameraFrames[i].tsMs > tEndMs) break;
matched.push(cameraFrames[i]);
}
return matched;
}
// ---------------------------------------------------------------------------
// Keypoint averaging (confidence-weighted)
// ---------------------------------------------------------------------------
/**
* Average keypoints weighted by per-frame confidence.
* Returns { keypoints: [[x,y],...], avgConfidence }.
*/
function averageKeypoints(cameraFrames) {
let totalWeight = 0;
const sumKp = new Array(NUM_KEYPOINTS).fill(null).map(() => [0, 0]);
for (const f of cameraFrames) {
const w = f.confidence || 1e-6;
totalWeight += w;
for (let k = 0; k < NUM_KEYPOINTS && k < f.keypoints.length; k++) {
sumKp[k][0] += f.keypoints[k][0] * w;
sumKp[k][1] += f.keypoints[k][1] * w;
}
}
if (totalWeight === 0) totalWeight = 1;
const keypoints = sumKp.map(([x, y]) => [x / totalWeight, y / totalWeight]);
const avgConfidence = cameraFrames.reduce((s, f) => s + (f.confidence || 0), 0) / cameraFrames.length;
return { keypoints, avgConfidence };
}
// ---------------------------------------------------------------------------
// CSI matrix extraction
// ---------------------------------------------------------------------------
/**
* Extract CSI amplitude matrix from raw_csi window.
* Returns { data: flat Float32Array, shape: [subcarriers, windowFrames] }.
*/
function extractCsiMatrix(window) {
const nFrames = window.length;
const nSc = window[0].subcarriers || 128;
const matrix = new Float32Array(nSc * nFrames);
for (let f = 0; f < nFrames; f++) {
const frame = window[f];
if (frame.iqHex) {
const iq = parseIqHex(frame.iqHex);
const amp = extractAmplitude(iq, nSc);
matrix.set(amp, f * nSc);
}
}
return { data: Array.from(matrix), shape: [nSc, nFrames] };
}
/**
* Extract feature matrix from feature-type window.
* Returns { data: flat array, shape: [featureDim, windowFrames] }.
*/
function extractFeatureMatrix(window) {
const nFrames = window.length;
const dim = window[0].features ? window[0].features.length : 8;
const matrix = new Float32Array(dim * nFrames);
for (let f = 0; f < nFrames; f++) {
const feats = window[f].features || new Array(dim).fill(0);
for (let d = 0; d < dim; d++) {
matrix[f * dim + d] = feats[d] || 0;
}
}
return { data: Array.from(matrix), shape: [dim, nFrames] };
}
// ---------------------------------------------------------------------------
// Main alignment
// ---------------------------------------------------------------------------
function align() {
const gtPath = path.resolve(args.gt);
const csiPath = path.resolve(args.csi);
// Determine output path
let outputPath;
if (args.output) {
outputPath = path.resolve(args.output);
} else {
const baseName = path.basename(csiPath, '.csi.jsonl');
outputPath = path.resolve('data', 'paired', `${baseName}.paired.jsonl`);
}
// Ensure output directory exists
const outputDir = path.dirname(outputPath);
if (!fs.existsSync(outputDir)) {
fs.mkdirSync(outputDir, { recursive: true });
}
console.log('=== Ground-Truth Alignment (ADR-079) ===');
console.log(` GT file: ${gtPath}`);
console.log(` CSI file: ${csiPath}`);
console.log(` Output: ${outputPath}`);
console.log(` Window: ${WINDOW_FRAMES} frames / ${WINDOW_MS} ms`);
console.log(` Min camera frames: ${MIN_CAMERA_FRAMES}`);
console.log(` Min confidence: ${MIN_CONFIDENCE}`);
console.log(` Clock offset: ${CLOCK_OFFSET_MS} ms`);
console.log();
// Load data
console.log('Loading ground-truth...');
const cameraFrames = loadGroundTruth(gtPath);
console.log(` ${cameraFrames.length} camera frames loaded`);
if (cameraFrames.length > 0) {
console.log(` Time range: ${new Date(cameraFrames[0].tsMs).toISOString()} -> ${new Date(cameraFrames[cameraFrames.length - 1].tsMs).toISOString()}`);
}
console.log('Loading CSI data...');
const { rawCsi, features } = loadCsi(csiPath);
console.log(` ${rawCsi.length} raw_csi frames, ${features.length} feature frames`);
// Decide which CSI source to use
const useRawCsi = rawCsi.length >= WINDOW_FRAMES;
const csiSource = useRawCsi ? rawCsi : features;
const sourceLabel = useRawCsi ? 'raw_csi' : 'feature';
if (csiSource.length < WINDOW_FRAMES) {
console.error(`ERROR: Not enough CSI frames (${csiSource.length}) for even one window of ${WINDOW_FRAMES} frames.`);
process.exit(1);
}
console.log(` Using ${sourceLabel} frames (${csiSource.length} total)`);
if (csiSource.length > 0) {
console.log(` CSI time range: ${new Date(csiSource[0].tsMs).toISOString()} -> ${new Date(csiSource[csiSource.length - 1].tsMs).toISOString()}`);
}
console.log();
// Group CSI into windows
const windows = groupIntoWindows(csiSource, WINDOW_FRAMES);
console.log(`Grouped into ${windows.length} CSI windows`);
// Align
const paired = [];
let totalConfidence = 0;
for (const window of windows) {
const tStartMs = window[0].tsMs;
const tEndMs = window[window.length - 1].tsMs;
// Expand window if actual time span is smaller than window-ms
const halfWindow = WINDOW_MS / 2;
const midpoint = (tStartMs + tEndMs) / 2;
const searchStart = Math.min(tStartMs, midpoint - halfWindow);
const searchEnd = Math.max(tEndMs, midpoint + halfWindow);
// Find matching camera frames
const matched = findCameraFramesInRange(cameraFrames, searchStart, searchEnd);
if (matched.length < MIN_CAMERA_FRAMES) continue;
// Check average confidence
const avgConf = matched.reduce((s, f) => s + (f.confidence || 0), 0) / matched.length;
if (avgConf < MIN_CONFIDENCE) continue;
// Average keypoints weighted by confidence
const { keypoints, avgConfidence } = averageKeypoints(matched);
// Extract CSI matrix
const csiMatrix = useRawCsi
? extractCsiMatrix(window)
: extractFeatureMatrix(window);
paired.push({
csi: csiMatrix.data,
csi_shape: csiMatrix.shape,
kp: keypoints,
conf: Math.round(avgConfidence * 1000) / 1000,
n_camera_frames: matched.length,
ts_start: new Date(tStartMs).toISOString(),
ts_end: new Date(tEndMs).toISOString(),
});
totalConfidence += avgConfidence;
}
// Write output
const outputLines = paired.map(s => JSON.stringify(s));
fs.writeFileSync(outputPath, outputLines.join('\n') + (outputLines.length > 0 ? '\n' : ''));
// Print summary
const alignmentRate = windows.length > 0 ? (paired.length / windows.length * 100) : 0;
const avgPairedConf = paired.length > 0 ? (totalConfidence / paired.length) : 0;
console.log();
console.log('=== Alignment Summary ===');
console.log(` Total CSI windows: ${windows.length}`);
console.log(` Paired samples: ${paired.length}`);
console.log(` Alignment rate: ${alignmentRate.toFixed(1)}%`);
console.log(` Avg confidence (paired): ${avgPairedConf.toFixed(3)}`);
console.log(` CSI source: ${sourceLabel} (${csiMatrix_shapeLabel(paired, useRawCsi)})`);
if (paired.length > 0) {
console.log(` Time range covered: ${paired[0].ts_start} -> ${paired[paired.length - 1].ts_end}`);
}
console.log(` Output written: ${outputPath}`);
console.log();
if (paired.length === 0) {
console.log('WARNING: No paired samples produced. Check that camera and CSI time ranges overlap.');
console.log(' Hint: Use --clock-offset-ms to correct misaligned clocks.');
}
}
/**
* Format CSI matrix shape label for summary.
*/
function csiMatrix_shapeLabel(paired, useRawCsi) {
if (paired.length === 0) return useRawCsi ? `[128, ${WINDOW_FRAMES}]` : `[8, ${WINDOW_FRAMES}]`;
const shape = paired[0].csi_shape;
return `[${shape[0]}, ${shape[1]}]`;
}
// ---------------------------------------------------------------------------
// Entry point
// ---------------------------------------------------------------------------
align();

View file

@ -0,0 +1,341 @@
#!/usr/bin/env python3
"""Camera ground-truth collection for WiFi pose estimation training (ADR-079).
Captures webcam keypoints via MediaPipe PoseLandmarker (Tasks API) and
synchronizes with ESP32 CSI recording from the sensing server.
Output: JSONL file in data/ground-truth/ with per-frame 17-keypoint COCO poses.
Usage:
python scripts/collect-ground-truth.py --preview --duration 60
python scripts/collect-ground-truth.py --server http://192.168.1.10:3000
"""
from __future__ import annotations
import argparse
import json
import os
import signal
import sys
import time
import urllib.request
import urllib.error
from pathlib import Path
from datetime import datetime
import cv2
import numpy as np
import mediapipe as mp
from mediapipe.tasks.python import BaseOptions
from mediapipe.tasks.python.vision import (
PoseLandmarker,
PoseLandmarkerOptions,
RunningMode,
)
# ---------------------------------------------------------------------------
# MediaPipe 33 landmarks -> 17 COCO keypoints
# ---------------------------------------------------------------------------
# COCO idx : MP idx : joint name
# 0 : 0 : nose
# 1 : 2 : left_eye
# 2 : 5 : right_eye
# 3 : 7 : left_ear
# 4 : 8 : right_ear
# 5 : 11 : left_shoulder
# 6 : 12 : right_shoulder
# 7 : 13 : left_elbow
# 8 : 14 : right_elbow
# 9 : 15 : left_wrist
# 10 : 16 : right_wrist
# 11 : 23 : left_hip
# 12 : 24 : right_hip
# 13 : 25 : left_knee
# 14 : 26 : right_knee
# 15 : 27 : left_ankle
# 16 : 28 : right_ankle
MP_TO_COCO = [0, 2, 5, 7, 8, 11, 12, 13, 14, 15, 16, 23, 24, 25, 26, 27, 28]
COCO_BONES = [
(5, 7), (7, 9), (6, 8), (8, 10), # arms
(5, 6), # shoulders
(11, 13), (13, 15), (12, 14), (14, 16), # legs
(11, 12), # hips
(5, 11), (6, 12), # torso
(0, 1), (0, 2), (1, 3), (2, 4), # face
]
MODEL_URL = (
"https://storage.googleapis.com/mediapipe-models/"
"pose_landmarker/pose_landmarker_lite/float16/latest/"
"pose_landmarker_lite.task"
)
MODEL_FILENAME = "pose_landmarker_lite.task"
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def ensure_model(cache_dir: Path) -> Path:
"""Download the PoseLandmarker model if not already cached."""
model_path = cache_dir / MODEL_FILENAME
if model_path.exists():
return model_path
cache_dir.mkdir(parents=True, exist_ok=True)
print(f"Downloading {MODEL_FILENAME} ...")
try:
urllib.request.urlretrieve(MODEL_URL, str(model_path))
print(f" saved to {model_path}")
except Exception as exc:
print(f"ERROR: Failed to download model: {exc}", file=sys.stderr)
print(
"Download manually from:\n"
f" {MODEL_URL}\n"
f"and place at {model_path}",
file=sys.stderr,
)
sys.exit(1)
return model_path
def post_json(url: str, payload: dict | None = None, timeout: float = 5.0) -> bool:
"""POST JSON to a URL. Returns True on success, False on failure."""
data = json.dumps(payload or {}).encode("utf-8")
req = urllib.request.Request(
url,
data=data,
headers={"Content-Type": "application/json"},
method="POST",
)
try:
with urllib.request.urlopen(req, timeout=timeout) as resp:
return 200 <= resp.status < 300
except Exception as exc:
print(f"WARNING: POST {url} failed: {exc}", file=sys.stderr)
return False
def draw_skeleton(frame: np.ndarray, keypoints: list[list[float]], w: int, h: int):
"""Draw COCO skeleton overlay on a BGR frame."""
pts = []
for x, y in keypoints:
px, py = int(x * w), int(y * h)
pts.append((px, py))
cv2.circle(frame, (px, py), 4, (0, 255, 0), -1)
for i, j in COCO_BONES:
if i < len(pts) and j < len(pts):
cv2.line(frame, pts[i], pts[j], (0, 200, 255), 2)
# ---------------------------------------------------------------------------
# Main collection loop
# ---------------------------------------------------------------------------
def main():
parser = argparse.ArgumentParser(
description="Collect camera ground-truth keypoints for WiFi pose training (ADR-079)."
)
parser.add_argument(
"--server",
default="http://localhost:3000",
help="Sensing server URL (default: http://localhost:3000)",
)
parser.add_argument(
"--preview",
action="store_true",
help="Show live skeleton overlay window",
)
parser.add_argument(
"--duration",
type=int,
default=300,
help="Recording duration in seconds (default: 300)",
)
parser.add_argument(
"--camera",
type=int,
default=0,
help="Camera device index (default: 0)",
)
parser.add_argument(
"--output",
default="data/ground-truth",
help="Output directory (default: data/ground-truth)",
)
args = parser.parse_args()
# --- Resolve paths relative to repo root ---
repo_root = Path(__file__).resolve().parent.parent
output_dir = repo_root / args.output
output_dir.mkdir(parents=True, exist_ok=True)
cache_dir = repo_root / "data" / ".cache"
# --- Download / locate model ---
model_path = ensure_model(cache_dir)
# --- Open camera ---
cap = cv2.VideoCapture(args.camera)
if not cap.isOpened():
print(
f"ERROR: Cannot open camera index {args.camera}. "
"Check that a webcam is connected and not in use by another app.",
file=sys.stderr,
)
sys.exit(1)
frame_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
print(f"Camera opened: {frame_w}x{frame_h}")
# --- Create PoseLandmarker ---
options = PoseLandmarkerOptions(
base_options=BaseOptions(model_asset_path=str(model_path)),
running_mode=RunningMode.IMAGE,
num_poses=1,
min_pose_detection_confidence=0.5,
min_pose_presence_confidence=0.5,
min_tracking_confidence=0.5,
)
landmarker = PoseLandmarker.create_from_options(options)
# --- Output file ---
timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
out_path = output_dir / f"keypoints_{timestamp_str}.jsonl"
out_file = open(out_path, "w", encoding="utf-8")
print(f"Output: {out_path}")
# --- Start CSI recording ---
recording_url_start = f"{args.server}/api/v1/recording/start"
recording_url_stop = f"{args.server}/api/v1/recording/stop"
csi_started = post_json(recording_url_start)
if csi_started:
print("CSI recording started on sensing server.")
else:
print(
"WARNING: Could not start CSI recording. "
"Camera keypoints will still be captured.",
file=sys.stderr,
)
# --- Graceful shutdown ---
shutdown_requested = False
def _handle_signal(signum, frame):
nonlocal shutdown_requested
shutdown_requested = True
signal.signal(signal.SIGINT, _handle_signal)
signal.signal(signal.SIGTERM, _handle_signal)
# --- Collection loop ---
start_time = time.monotonic()
frame_count = 0
total_confidence = 0.0
total_visible = 0
print(f"Collecting for {args.duration}s ... (press 'q' in preview to stop)")
try:
while not shutdown_requested:
elapsed = time.monotonic() - start_time
if elapsed >= args.duration:
break
ret, frame = cap.read()
if not ret:
print("WARNING: Failed to read frame, retrying ...", file=sys.stderr)
time.sleep(0.01)
continue
ts_ns = time.time_ns()
# Convert BGR -> RGB for MediaPipe
rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb)
result = landmarker.detect(mp_image)
n_persons = len(result.pose_landmarks)
if n_persons > 0:
landmarks = result.pose_landmarks[0]
keypoints = []
visibilities = []
for coco_idx in range(17):
mp_idx = MP_TO_COCO[coco_idx]
lm = landmarks[mp_idx]
keypoints.append([round(lm.x, 5), round(lm.y, 5)])
visibilities.append(lm.visibility if lm.visibility else 0.0)
confidence = float(np.mean(visibilities))
n_visible = int(sum(1 for v in visibilities if v > 0.5))
else:
keypoints = []
confidence = 0.0
n_visible = 0
record = {
"ts_ns": ts_ns,
"keypoints": keypoints,
"confidence": round(confidence, 4),
"n_visible": n_visible,
"n_persons": n_persons,
}
out_file.write(json.dumps(record) + "\n")
frame_count += 1
total_confidence += confidence
total_visible += n_visible
# Preview overlay
if args.preview and keypoints:
draw_skeleton(frame, keypoints, frame_w, frame_h)
if args.preview:
remaining = max(0, int(args.duration - elapsed))
cv2.putText(
frame,
f"Frames: {frame_count} Visible: {n_visible}/17 Time: {remaining}s",
(10, 30),
cv2.FONT_HERSHEY_SIMPLEX,
0.7,
(255, 255, 255),
2,
)
cv2.imshow("Ground Truth Collection (ADR-079)", frame)
if cv2.waitKey(1) & 0xFF == ord("q"):
break
finally:
# --- Cleanup ---
out_file.close()
cap.release()
if args.preview:
cv2.destroyAllWindows()
landmarker.close()
# Stop CSI recording
if csi_started:
if post_json(recording_url_stop):
print("CSI recording stopped.")
else:
print("WARNING: Failed to stop CSI recording.", file=sys.stderr)
# --- Summary ---
avg_conf = total_confidence / frame_count if frame_count > 0 else 0.0
avg_vis = total_visible / frame_count if frame_count > 0 else 0.0
print()
print("=== Collection Summary ===")
print(f" Total frames: {frame_count}")
print(f" Avg confidence: {avg_conf:.3f}")
print(f" Avg visible joints: {avg_vis:.1f} / 17")
print(f" Output: {out_path}")
if __name__ == "__main__":
main()

625
scripts/eval-wiflow.js Normal file
View file

@ -0,0 +1,625 @@
#!/usr/bin/env node
/**
* WiFlow PCK Evaluation Script (ADR-079)
*
* Measures accuracy of WiFi-based pose estimation against ground-truth
* camera keypoints using PCK (Percentage of Correct Keypoints) and MPJPE
* (Mean Per-Joint Position Error) metrics.
*
* Usage:
* node scripts/eval-wiflow.js --model models/wiflow-supervised/wiflow-v1.json --data data/paired/aligned.paired.jsonl
* node scripts/eval-wiflow.js --baseline --data data/paired/aligned.paired.jsonl
* node scripts/eval-wiflow.js --model models/wiflow-supervised/wiflow-v1.json --data data/paired/aligned.paired.jsonl --verbose
*
* ADR: docs/adr/ADR-079
*/
'use strict';
const fs = require('fs');
const path = require('path');
const { parseArgs } = require('util');
// ---------------------------------------------------------------------------
// Resolve WiFlow model dependencies
// ---------------------------------------------------------------------------
const {
WiFlowModel,
COCO_KEYPOINTS,
createRng,
} = require(path.join(__dirname, 'wiflow-model.js'));
const RUVLLM_PATH = path.resolve(__dirname, '..', 'vendor', 'ruvector', 'npm', 'packages', 'ruvllm', 'src');
const { SafeTensorsReader } = require(path.join(RUVLLM_PATH, 'export.js'));
// ---------------------------------------------------------------------------
// Constants
// ---------------------------------------------------------------------------
const NUM_KEYPOINTS = 17;
const DEFAULT_TORSO_LENGTH = 0.3; // normalized coords fallback
// Joint name aliases for display (short form)
const JOINT_NAMES = [
'nose', 'l_eye', 'r_eye', 'l_ear', 'r_ear',
'l_shoulder', 'r_shoulder', 'l_elbow', 'r_elbow',
'l_wrist', 'r_wrist', 'l_hip', 'r_hip',
'l_knee', 'r_knee', 'l_ankle', 'r_ankle',
];
// Shoulder indices: l_shoulder=5, r_shoulder=6
// Hip indices: l_hip=11, r_hip=12
const L_SHOULDER = 5;
const R_SHOULDER = 6;
const L_HIP = 11;
const R_HIP = 12;
// ---------------------------------------------------------------------------
// CLI argument parsing
// ---------------------------------------------------------------------------
const { values: args } = parseArgs({
options: {
model: { type: 'string', short: 'm' },
data: { type: 'string', short: 'd' },
baseline: { type: 'boolean', default: false },
output: { type: 'string', short: 'o' },
verbose: { type: 'boolean', short: 'v', default: false },
},
strict: true,
});
if (!args.data) {
console.error('Usage: node scripts/eval-wiflow.js --data <paired-jsonl> [--model <path>] [--baseline] [--output <path>]');
console.error('');
console.error('Required:');
console.error(' --data, -d <path> Paired CSI + keypoint JSONL (from align-ground-truth.js)');
console.error('');
console.error('Options:');
console.error(' --model, -m <path> Path to trained model directory or JSON');
console.error(' --baseline Evaluate proxy-based baseline (no model)');
console.error(' --output, -o <path> Output eval report JSON');
console.error(' --verbose, -v Verbose output');
process.exit(1);
}
if (!args.model && !args.baseline) {
console.error('Error: Must specify either --model <path> or --baseline');
process.exit(1);
}
// ---------------------------------------------------------------------------
// Data loading
// ---------------------------------------------------------------------------
/**
* Load paired JSONL samples.
* Each line: { csi: [...], csi_shape: [S, T], kp: [[x,y],...], conf: 0.xx, ... }
*/
function loadPairedData(filePath) {
const content = fs.readFileSync(filePath, 'utf-8');
const samples = [];
for (const line of content.split('\n')) {
if (!line.trim()) continue;
try {
const s = JSON.parse(line);
if (!s.kp || !Array.isArray(s.kp)) continue;
if (!s.csi && !s.csi_shape) continue;
samples.push(s);
} catch (e) {
// skip malformed lines
}
}
return samples;
}
// ---------------------------------------------------------------------------
// Model loading
// ---------------------------------------------------------------------------
/**
* Load WiFlow model from a directory or JSON file.
* Tries: model.safetensors, then config.json for architecture config.
* Returns { model, name }.
*/
function loadModel(modelPath) {
const stat = fs.statSync(modelPath);
let modelDir;
if (stat.isDirectory()) {
modelDir = modelPath;
} else {
// Assume JSON file in a model directory
modelDir = path.dirname(modelPath);
}
// Load architecture config if available
let config = {};
const configPath = path.join(modelDir, 'config.json');
if (fs.existsSync(configPath)) {
try {
const raw = JSON.parse(fs.readFileSync(configPath, 'utf-8'));
if (raw.custom) {
config.inputChannels = raw.custom.inputChannels || 128;
config.timeSteps = raw.custom.timeSteps || 20;
config.numKeypoints = raw.custom.numKeypoints || 17;
config.numHeads = raw.custom.numHeads || 8;
config.seed = raw.custom.seed || 42;
}
} catch (e) {
// use defaults
}
}
// Load training-metrics.json for additional config
const metricsPath = path.join(modelDir, 'training-metrics.json');
if (fs.existsSync(metricsPath)) {
try {
const metrics = JSON.parse(fs.readFileSync(metricsPath, 'utf-8'));
if (metrics.model && metrics.model.architecture === 'wiflow') {
// metrics available for report
}
} catch (e) {
// ignore
}
}
// Create model with config
const model = new WiFlowModel(config);
model.setTraining(false); // eval mode
// Load weights from SafeTensors
const safetensorsPath = path.join(modelDir, 'model.safetensors');
if (fs.existsSync(safetensorsPath)) {
const buffer = new Uint8Array(fs.readFileSync(safetensorsPath));
const reader = new SafeTensorsReader(buffer);
const tensorNames = reader.getTensorNames();
// Build tensor map for fromTensorMap
const tensorMap = new Map();
for (const name of tensorNames) {
const tensor = reader.getTensor(name);
if (tensor) {
tensorMap.set(name, tensor.data);
}
}
model.fromTensorMap(tensorMap);
if (args.verbose) {
console.log(`Loaded ${tensorNames.length} tensors from ${safetensorsPath}`);
console.log(`Model params: ${model.numParams().toLocaleString()}`);
}
} else {
console.warn(`WARN: No model.safetensors found in ${modelDir}, using random weights`);
}
// Derive model name
const name = path.basename(modelDir);
return { model, name };
}
// ---------------------------------------------------------------------------
// Baseline proxy pose generation (ADR-072 Phase 2 heuristic)
// ---------------------------------------------------------------------------
/**
* Generate a proxy standing skeleton from CSI features.
* If presence detected (amplitude energy > threshold), place a standing
* person at center with standard COCO proportions, perturbed by motion energy.
*/
function generateBaselinePose(sample) {
const rng = createRng(42);
// Estimate presence from CSI amplitude energy
const csi = sample.csi;
let energy = 0;
if (Array.isArray(csi)) {
for (let i = 0; i < csi.length; i++) {
energy += csi[i] * csi[i];
}
energy = Math.sqrt(energy / csi.length);
}
// Estimate motion energy (variance across subcarriers)
let motionEnergy = 0;
if (Array.isArray(csi) && sample.csi_shape) {
const [S, T] = sample.csi_shape;
if (T > 1) {
for (let s = 0; s < S; s++) {
let sum = 0;
let sumSq = 0;
for (let t = 0; t < T; t++) {
const v = csi[s * T + t] || 0;
sum += v;
sumSq += v * v;
}
const mean = sum / T;
motionEnergy += (sumSq / T) - (mean * mean);
}
motionEnergy = Math.sqrt(Math.max(0, motionEnergy / S));
}
}
// Normalized presence heuristic
const presence = Math.min(1, energy / 10);
if (presence < 0.3) {
// No person detected: return zero pose
return new Float32Array(NUM_KEYPOINTS * 2);
}
// Standing skeleton at center (0.5, 0.5) with standard proportions
// Coordinates are [x, y] in normalized [0, 1] space
// y=0 is top, y=1 is bottom (image convention)
const cx = 0.5;
const headY = 0.2;
const shoulderY = 0.32;
const elbowY = 0.45;
const wristY = 0.55;
const hipY = 0.55;
const kneeY = 0.72;
const ankleY = 0.88;
const shoulderW = 0.08;
const hipW = 0.06;
const armSpread = 0.12;
// Standard standing pose keypoints [x, y]
const skeleton = [
[cx, headY], // 0: nose
[cx - 0.02, headY - 0.02], // 1: l_eye
[cx + 0.02, headY - 0.02], // 2: r_eye
[cx - 0.04, headY], // 3: l_ear
[cx + 0.04, headY], // 4: r_ear
[cx - shoulderW, shoulderY], // 5: l_shoulder
[cx + shoulderW, shoulderY], // 6: r_shoulder
[cx - armSpread, elbowY], // 7: l_elbow
[cx + armSpread, elbowY], // 8: r_elbow
[cx - armSpread - 0.02, wristY], // 9: l_wrist
[cx + armSpread + 0.02, wristY], // 10: r_wrist
[cx - hipW, hipY], // 11: l_hip
[cx + hipW, hipY], // 12: r_hip
[cx - hipW, kneeY], // 13: l_knee
[cx + hipW, kneeY], // 14: r_knee
[cx - hipW, ankleY], // 15: l_ankle
[cx + hipW, ankleY], // 16: r_ankle
];
// Perturb limbs by motion energy
const perturbScale = Math.min(motionEnergy * 0.1, 0.05);
const result = new Float32Array(NUM_KEYPOINTS * 2);
for (let k = 0; k < NUM_KEYPOINTS; k++) {
const px = (rng() - 0.5) * 2 * perturbScale;
const py = (rng() - 0.5) * 2 * perturbScale;
result[k * 2] = Math.max(0, Math.min(1, skeleton[k][0] + px));
result[k * 2 + 1] = Math.max(0, Math.min(1, skeleton[k][1] + py));
}
return result;
}
// ---------------------------------------------------------------------------
// Metric computation
// ---------------------------------------------------------------------------
/** Euclidean distance between two 2D points */
function dist2d(x1, y1, x2, y2) {
const dx = x1 - x2;
const dy = y1 - y2;
return Math.sqrt(dx * dx + dy * dy);
}
/**
* Compute torso length from ground-truth keypoints.
* Torso = distance(mid_shoulder, mid_hip).
* Returns DEFAULT_TORSO_LENGTH if shoulders or hips not visible.
*/
function computeTorsoLength(kp) {
if (!kp || kp.length < 13) return DEFAULT_TORSO_LENGTH;
const lsX = kp[L_SHOULDER][0];
const lsY = kp[L_SHOULDER][1];
const rsX = kp[R_SHOULDER][0];
const rsY = kp[R_SHOULDER][1];
const lhX = kp[L_HIP][0];
const lhY = kp[L_HIP][1];
const rhX = kp[R_HIP][0];
const rhY = kp[R_HIP][1];
// Check if joints are at origin (not visible)
const shoulderVisible = (lsX !== 0 || lsY !== 0) && (rsX !== 0 || rsY !== 0);
const hipVisible = (lhX !== 0 || lhY !== 0) && (rhX !== 0 || rhY !== 0);
if (!shoulderVisible || !hipVisible) return DEFAULT_TORSO_LENGTH;
const midShoulderX = (lsX + rsX) / 2;
const midShoulderY = (lsY + rsY) / 2;
const midHipX = (lhX + rhX) / 2;
const midHipY = (lhY + rhY) / 2;
const torso = dist2d(midShoulderX, midShoulderY, midHipX, midHipY);
return torso > 0.01 ? torso : DEFAULT_TORSO_LENGTH;
}
/**
* Evaluate predictions against ground truth.
*
* @param {Array<{pred: Float32Array, gt: number[][], conf: number}>} results
* @returns {object} Evaluation report
*/
function computeMetrics(results) {
const n = results.length;
if (n === 0) {
return {
n_samples: 0,
pck_10: 0, pck_20: 0, pck_50: 0,
mpjpe: 0,
per_joint_pck20: {},
per_joint_mpjpe: {},
conf_weighted_pck20: 0,
conf_weighted_mpjpe: 0,
};
}
// Accumulators
const pckCounts = { 10: 0, 20: 0, 50: 0 };
let totalJoints = 0;
let totalMPJPE = 0;
const perJointPck20 = new Float64Array(NUM_KEYPOINTS);
const perJointMPJPE = new Float64Array(NUM_KEYPOINTS);
const perJointCount = new Float64Array(NUM_KEYPOINTS);
// Confidence-weighted accumulators
let confWeightedPck20Num = 0;
let confWeightedPck20Den = 0;
let confWeightedMpjpeNum = 0;
let confWeightedMpjpeDen = 0;
for (const { pred, gt, conf } of results) {
const torso = computeTorsoLength(gt);
const w = Math.max(conf, 1e-6);
for (let k = 0; k < NUM_KEYPOINTS; k++) {
if (k >= gt.length) continue;
const gtX = gt[k][0];
const gtY = gt[k][1];
const predX = pred[k * 2];
const predY = pred[k * 2 + 1];
const d = dist2d(predX, predY, gtX, gtY);
totalJoints++;
totalMPJPE += d;
perJointMPJPE[k] += d;
perJointCount[k] += 1;
// PCK at different thresholds
if (d < 0.10 * torso) pckCounts[10]++;
if (d < 0.20 * torso) {
pckCounts[20]++;
perJointPck20[k]++;
confWeightedPck20Num += w;
}
if (d < 0.50 * torso) pckCounts[50]++;
confWeightedPck20Den += w;
confWeightedMpjpeNum += d * w;
confWeightedMpjpeDen += w;
}
}
// Aggregate metrics
const pck10 = totalJoints > 0 ? pckCounts[10] / totalJoints : 0;
const pck20 = totalJoints > 0 ? pckCounts[20] / totalJoints : 0;
const pck50 = totalJoints > 0 ? pckCounts[50] / totalJoints : 0;
const mpjpe = totalJoints > 0 ? totalMPJPE / totalJoints : 0;
// Per-joint breakdown
const perJointPck20Map = {};
const perJointMpjpeMap = {};
for (let k = 0; k < NUM_KEYPOINTS; k++) {
const name = JOINT_NAMES[k];
perJointPck20Map[name] = perJointCount[k] > 0 ? perJointPck20[k] / perJointCount[k] : 0;
perJointMpjpeMap[name] = perJointCount[k] > 0 ? perJointMPJPE[k] / perJointCount[k] : 0;
}
// Confidence-weighted
const confPck20 = confWeightedPck20Den > 0 ? confWeightedPck20Num / confWeightedPck20Den : 0;
const confMpjpe = confWeightedMpjpeDen > 0 ? confWeightedMpjpeNum / confWeightedMpjpeDen : 0;
return {
n_samples: n,
pck_10: pck10,
pck_20: pck20,
pck_50: pck50,
mpjpe,
per_joint_pck20: perJointPck20Map,
per_joint_mpjpe: perJointMpjpeMap,
conf_weighted_pck20: confPck20,
conf_weighted_mpjpe: confMpjpe,
};
}
// ---------------------------------------------------------------------------
// Inference
// ---------------------------------------------------------------------------
/**
* Run model inference on a single paired sample.
* @param {WiFlowModel} model
* @param {object} sample - { csi, csi_shape, kp, conf }
* @returns {Float32Array} - [17*2] predicted keypoints
*/
function runModelInference(model, sample) {
const csi = sample.csi;
const shape = sample.csi_shape;
const S = shape ? shape[0] : 128;
const T = shape ? shape[1] : 20;
// Prepare input as Float32Array [S, T]
let input;
if (csi instanceof Float32Array) {
input = csi;
} else if (Array.isArray(csi)) {
input = new Float32Array(csi);
} else {
input = new Float32Array(S * T);
}
// Ensure correct size (pad or truncate)
const expectedLen = model.inputChannels * model.timeSteps;
if (input.length !== expectedLen) {
const resized = new Float32Array(expectedLen);
const copyLen = Math.min(input.length, expectedLen);
resized.set(input.subarray(0, copyLen));
input = resized;
}
return model.forward(input);
}
// ---------------------------------------------------------------------------
// Formatted output
// ---------------------------------------------------------------------------
function formatPercent(v) {
return (v * 100).toFixed(1) + '%';
}
function formatFloat(v, decimals) {
decimals = decimals || 4;
return v.toFixed(decimals);
}
function printReport(report) {
console.log('');
console.log('WiFlow Evaluation Report (ADR-079)');
console.log('===================================');
console.log(`Model: ${report.model}`);
console.log(`Samples: ${report.n_samples.toLocaleString()}`);
console.log(`PCK@10: ${formatPercent(report.pck_10)}`);
console.log(`PCK@20: ${formatPercent(report.pck_20)}`);
console.log(`PCK@50: ${formatPercent(report.pck_50)}`);
console.log(`MPJPE: ${formatFloat(report.mpjpe)}`);
console.log('');
console.log('Per-Joint PCK@20:');
const maxNameLen = Math.max(...JOINT_NAMES.map(n => n.length));
for (const name of JOINT_NAMES) {
const pck = report.per_joint_pck20[name] || 0;
const pad = ' '.repeat(maxNameLen - name.length + 2);
console.log(` ${name}${pad}${formatPercent(pck)}`);
}
console.log('');
console.log('Per-Joint MPJPE:');
for (const name of JOINT_NAMES) {
const mpjpe = report.per_joint_mpjpe[name] || 0;
const pad = ' '.repeat(maxNameLen - name.length + 2);
console.log(` ${name}${pad}${formatFloat(mpjpe)}`);
}
console.log('');
console.log('Confidence-Weighted:');
console.log(` PCK@20: ${formatPercent(report.conf_weighted_pck20)}`);
console.log(` MPJPE: ${formatFloat(report.conf_weighted_mpjpe)}`);
console.log('');
console.log(`Inference: ${report.inference_latency_ms.toFixed(2)}ms/sample`);
console.log('');
}
// ---------------------------------------------------------------------------
// Main
// ---------------------------------------------------------------------------
function main() {
// Load paired data
if (args.verbose) console.log(`Loading paired data from ${args.data}...`);
const samples = loadPairedData(args.data);
if (samples.length === 0) {
console.error('Error: No valid paired samples found in', args.data);
process.exit(1);
}
if (args.verbose) console.log(`Loaded ${samples.length} paired samples`);
let modelName;
let model = null;
if (args.baseline) {
modelName = 'baseline-proxy';
if (args.verbose) console.log('Running baseline proxy evaluation (ADR-072 Phase 2 heuristic)');
} else {
const loaded = loadModel(args.model);
model = loaded.model;
modelName = loaded.name;
if (args.verbose) console.log(`Running model evaluation: ${modelName}`);
}
// Run inference and collect results
const results = [];
const startTime = process.hrtime.bigint();
for (const sample of samples) {
let pred;
if (args.baseline) {
pred = generateBaselinePose(sample);
} else {
pred = runModelInference(model, sample);
}
results.push({
pred,
gt: sample.kp,
conf: sample.conf || 0,
});
}
const endTime = process.hrtime.bigint();
const totalMs = Number(endTime - startTime) / 1e6;
const latencyMs = totalMs / samples.length;
// Compute metrics
const metrics = computeMetrics(results);
// Build report
const report = {
model: modelName,
n_samples: metrics.n_samples,
pck_10: Math.round(metrics.pck_10 * 10000) / 10000,
pck_20: Math.round(metrics.pck_20 * 10000) / 10000,
pck_50: Math.round(metrics.pck_50 * 10000) / 10000,
mpjpe: Math.round(metrics.mpjpe * 100000) / 100000,
per_joint_pck20: {},
per_joint_mpjpe: {},
conf_weighted_pck20: Math.round(metrics.conf_weighted_pck20 * 10000) / 10000,
conf_weighted_mpjpe: Math.round(metrics.conf_weighted_mpjpe * 100000) / 100000,
inference_latency_ms: Math.round(latencyMs * 100) / 100,
timestamp: new Date().toISOString(),
};
// Round per-joint metrics
for (const name of JOINT_NAMES) {
report.per_joint_pck20[name] = Math.round((metrics.per_joint_pck20[name] || 0) * 10000) / 10000;
report.per_joint_mpjpe[name] = Math.round((metrics.per_joint_mpjpe[name] || 0) * 100000) / 100000;
}
// Print formatted report
printReport(report);
// Write output JSON
const outputPath = args.output ||
(args.model
? path.join(path.dirname(
fs.statSync(args.model).isDirectory() ? path.join(args.model, '.') : args.model
), 'eval-report.json')
: 'models/wiflow-supervised/eval-report.json');
const outputDir = path.dirname(outputPath);
if (!fs.existsSync(outputDir)) {
fs.mkdirSync(outputDir, { recursive: true });
}
fs.writeFileSync(outputPath, JSON.stringify(report, null, 2) + '\n');
console.log(`Report saved to ${outputPath}`);
}
main();

View file

@ -6,7 +6,7 @@ echo "Host: $(hostname) | $(sysctl -n hw.ncpu 2>/dev/null || nproc) cores | $(sy
echo ""
REPO_DIR="${HOME}/Projects/wifi-densepose"
WINDOWS_HOST="100.102.238.73" # Tailscale IP of Windows machine
WINDOWS_HOST="${WINDOWS_HOST:-}" # Set via env: export WINDOWS_HOST=<tailscale-ip>
# Step 1: Clone or update repo
echo "[1/7] Setting up repository..."

111
scripts/record-csi-udp.py Normal file
View file

@ -0,0 +1,111 @@
#!/usr/bin/env python3
"""
Lightweight ESP32 CSI UDP recorder (ADR-079).
Captures raw CSI packets from ESP32 nodes over UDP and writes to JSONL.
Runs alongside collect-ground-truth.py for synchronized capture.
Usage:
python scripts/record-csi-udp.py --duration 300 --output data/recordings
"""
import argparse
import json
import os
import socket
import struct
import time
def parse_csi_packet(data):
"""Parse ADR-018 binary CSI packet into dict."""
if len(data) < 8:
return None
# ADR-018 header: [magic(2), len(2), node_id(1), seq(1), rssi(1), channel(1), iq_data...]
# Simplified: extract what we can from the raw packet
node_id = data[4] if len(data) > 4 else 0
rssi = struct.unpack('b', bytes([data[6]]))[0] if len(data) > 6 else 0
channel = data[7] if len(data) > 7 else 0
# IQ data starts at offset 8
iq_data = data[8:] if len(data) > 8 else b''
n_subcarriers = len(iq_data) // 2 # I,Q pairs
# Compute amplitudes
amplitudes = []
for i in range(0, len(iq_data) - 1, 2):
I = struct.unpack('b', bytes([iq_data[i]]))[0]
Q = struct.unpack('b', bytes([iq_data[i + 1]]))[0]
amplitudes.append(round((I * I + Q * Q) ** 0.5, 2))
return {
"type": "raw_csi",
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S.") + f"{int(time.time() * 1000) % 1000:03d}Z",
"ts_ns": time.time_ns(),
"node_id": node_id,
"rssi": rssi,
"channel": channel,
"subcarriers": n_subcarriers,
"amplitudes": amplitudes,
"iq_hex": iq_data.hex(),
}
def main():
parser = argparse.ArgumentParser(description="Record ESP32 CSI over UDP")
parser.add_argument("--port", type=int, default=5005, help="UDP port (default: 5005)")
parser.add_argument("--duration", type=int, default=300, help="Duration in seconds (default: 300)")
parser.add_argument("--output", default="data/recordings", help="Output directory")
args = parser.parse_args()
os.makedirs(args.output, exist_ok=True)
filename = f"csi-{int(time.time())}.csi.jsonl"
filepath = os.path.join(args.output, filename)
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(("0.0.0.0", args.port))
sock.settimeout(1)
print(f"Recording CSI on UDP :{args.port} for {args.duration}s")
print(f"Output: {filepath}")
count = 0
start = time.time()
nodes_seen = set()
with open(filepath, "w") as f:
try:
while time.time() - start < args.duration:
try:
data, addr = sock.recvfrom(4096)
frame = parse_csi_packet(data)
if frame:
f.write(json.dumps(frame) + "\n")
count += 1
nodes_seen.add(frame["node_id"])
if count % 500 == 0:
elapsed = time.time() - start
rate = count / elapsed
print(f" {count} frames | {rate:.0f} fps | "
f"nodes: {sorted(nodes_seen)} | "
f"{elapsed:.0f}s / {args.duration}s")
except socket.timeout:
continue
except KeyboardInterrupt:
print("\nStopped by user")
sock.close()
elapsed = time.time() - start
print(f"\n=== CSI Recording Complete ===")
print(f" Frames: {count}")
print(f" Duration: {elapsed:.0f}s")
print(f" Rate: {count / max(elapsed, 1):.0f} fps")
print(f" Nodes: {sorted(nodes_seen)}")
print(f" Output: {filepath}")
if __name__ == "__main__":
main()

File diff suppressed because it is too large Load diff

View file

@ -76,4 +76,31 @@ describe('MATScreen', () => {
// Simulated status maps to 'simulated' banner -> "SIMULATED DATA"
expect(getByText('SIMULATED DATA')).toBeTruthy();
});
it('shows simulation warning overlay when simulated and not acknowledged', () => {
// Reset store to ensure overlay is shown
const { useMatStore } = require('@/stores/matStore');
useMatStore.setState({ dataSource: 'simulated', simulationAcknowledged: false });
const { MATScreen } = require('@/screens/MATScreen');
const { getByText } = render(
<ThemeProvider>
<MATScreen />
</ThemeProvider>,
);
expect(getByText('I UNDERSTAND')).toBeTruthy();
});
it('hides overlay after acknowledgment', () => {
const { useMatStore } = require('@/stores/matStore');
useMatStore.setState({ dataSource: 'simulated', simulationAcknowledged: true });
const { MATScreen } = require('@/screens/MATScreen');
const { queryByText } = render(
<ThemeProvider>
<MATScreen />
</ThemeProvider>,
);
expect(queryByText('I UNDERSTAND')).toBeNull();
});
});

View file

@ -62,6 +62,8 @@ describe('useMatStore', () => {
survivors: [],
alerts: [],
selectedEventId: null,
dataSource: 'simulated',
simulationAcknowledged: false,
});
});
@ -195,4 +197,32 @@ describe('useMatStore', () => {
expect(useMatStore.getState().selectedEventId).toBeNull();
});
});
describe('dataSource', () => {
it('defaults to simulated', () => {
expect(useMatStore.getState().dataSource).toBe('simulated');
});
it('can be set to real', () => {
useMatStore.getState().setDataSource('real');
expect(useMatStore.getState().dataSource).toBe('real');
});
it('can be set back to simulated', () => {
useMatStore.getState().setDataSource('real');
useMatStore.getState().setDataSource('simulated');
expect(useMatStore.getState().dataSource).toBe('simulated');
});
});
describe('simulationAcknowledged', () => {
it('defaults to false', () => {
expect(useMatStore.getState().simulationAcknowledged).toBe(false);
});
it('can be acknowledged', () => {
useMatStore.getState().acknowledgeSimulation();
expect(useMatStore.getState().simulationAcknowledged).toBe(true);
});
});
});

View file

@ -0,0 +1,49 @@
import React, { useEffect, useRef } from 'react';
import { Animated, StyleSheet, Text, View } from 'react-native';
interface Props {
visible: boolean;
}
export const SimulationBanner: React.FC<Props> = ({ visible }) => {
const opacity = useRef(new Animated.Value(1)).current;
useEffect(() => {
if (!visible) return;
const pulse = Animated.loop(
Animated.sequence([
Animated.timing(opacity, { toValue: 0.4, duration: 800, useNativeDriver: true }),
Animated.timing(opacity, { toValue: 1.0, duration: 800, useNativeDriver: true }),
]),
);
pulse.start();
return () => pulse.stop();
}, [visible, opacity]);
if (!visible) return null;
return (
<Animated.View style={[styles.banner, { opacity }]}>
<Text style={styles.text}>SIMULATED DATA - NOT CONNECTED TO REAL SENSORS</Text>
</Animated.View>
);
};
const styles = StyleSheet.create({
banner: {
backgroundColor: '#e74c3c',
paddingVertical: 6,
paddingHorizontal: 12,
borderRadius: 6,
alignItems: 'center',
marginBottom: 8,
},
text: {
color: '#ffffff',
fontWeight: '700',
fontSize: 12,
letterSpacing: 0.5,
textAlign: 'center',
},
});

View file

@ -0,0 +1,78 @@
import React from 'react';
import { Modal, Pressable, StyleSheet, Text, View } from 'react-native';
interface Props {
visible: boolean;
onAcknowledge: () => void;
}
export const SimulationWarningOverlay: React.FC<Props> = ({ visible, onAcknowledge }) => (
<Modal visible={visible} transparent animationType="fade">
<View style={styles.backdrop}>
<View style={styles.card}>
<Text style={styles.icon}>&#9888;</Text>
<Text style={styles.title}>SIMULATED DATA</Text>
<Text style={styles.body}>
NOT CONNECTED TO REAL SENSORS{'\n\n'}
All survivor detections, vital signs, and alerts displayed on this screen are
generated from simulated data and do not reflect actual conditions.
</Text>
<Pressable style={styles.button} onPress={onAcknowledge}>
<Text style={styles.buttonText}>I UNDERSTAND</Text>
</Pressable>
</View>
</View>
</Modal>
);
const styles = StyleSheet.create({
backdrop: {
flex: 1,
backgroundColor: 'rgba(0,0,0,0.85)',
justifyContent: 'center',
alignItems: 'center',
padding: 24,
},
card: {
backgroundColor: '#1a1a2e',
borderRadius: 16,
padding: 32,
alignItems: 'center',
borderWidth: 2,
borderColor: '#e74c3c',
maxWidth: 420,
width: '100%',
},
icon: {
fontSize: 48,
color: '#e74c3c',
marginBottom: 12,
},
title: {
fontSize: 22,
fontWeight: '800',
color: '#e74c3c',
textAlign: 'center',
marginBottom: 16,
letterSpacing: 1,
},
body: {
fontSize: 15,
color: '#cccccc',
textAlign: 'center',
lineHeight: 22,
marginBottom: 28,
},
button: {
backgroundColor: '#e74c3c',
paddingHorizontal: 36,
paddingVertical: 14,
borderRadius: 8,
},
buttonText: {
color: '#ffffff',
fontWeight: '700',
fontSize: 16,
letterSpacing: 0.5,
},
});

View file

@ -10,6 +10,8 @@ import { type ConnectionStatus } from '@/types/sensing';
import { Alert, type Survivor } from '@/types/mat';
import { AlertList } from './AlertList';
import { MatWebView } from './MatWebView';
import { SimulationBanner } from './SimulationBanner';
import { SimulationWarningOverlay } from './SimulationWarningOverlay';
import { SurvivorCounter } from './SurvivorCounter';
import { useMatBridge } from './useMatBridge';
@ -47,6 +49,15 @@ export const MATScreen = () => {
const upsertSurvivor = useMatStore((state) => state.upsertSurvivor);
const addAlert = useMatStore((state) => state.addAlert);
const upsertEvent = useMatStore((state) => state.upsertEvent);
const dataSource = useMatStore((state) => state.dataSource);
const simulationAcknowledged = useMatStore((state) => state.simulationAcknowledged);
const setDataSource = useMatStore((state) => state.setDataSource);
const acknowledgeSimulation = useMatStore((state) => state.acknowledgeSimulation);
// Sync dataSource from connection status
useEffect(() => {
setDataSource(connectionStatus === 'connected' ? 'real' : 'simulated');
}, [connectionStatus, setDataSource]);
const { webViewRef, ready, onMessage, sendFrameUpdate, postEvent } = useMatBridge({
onSurvivorDetected: (survivor) => {
@ -113,8 +124,13 @@ export const MATScreen = () => {
const { height } = useWindowDimensions();
const webHeight = Math.max(240, Math.floor(height * 0.5));
const showOverlay = dataSource === 'simulated' && !simulationAcknowledged;
const showBanner = dataSource === 'simulated' && simulationAcknowledged;
return (
<ThemedView style={{ flex: 1, backgroundColor: colors.bg, padding: spacing.md }}>
<SimulationWarningOverlay visible={showOverlay} onAcknowledge={acknowledgeSimulation} />
<SimulationBanner visible={showBanner} />
<ConnectionBanner status={resolveBannerState(connectionStatus)} />
<View style={{ marginTop: 20 }}>
<SurvivorCounter survivors={survivors} />

View file

@ -7,11 +7,17 @@ export interface MatState {
survivors: Survivor[];
alerts: Alert[];
selectedEventId: string | null;
/** Whether data comes from real sensors or simulation. */
dataSource: 'real' | 'simulated';
/** Whether the user has dismissed the simulation warning overlay. */
simulationAcknowledged: boolean;
upsertEvent: (event: DisasterEvent) => void;
addZone: (zone: ScanZone) => void;
upsertSurvivor: (survivor: Survivor) => void;
addAlert: (alert: Alert) => void;
setSelectedEvent: (id: string | null) => void;
setDataSource: (source: 'real' | 'simulated') => void;
acknowledgeSimulation: () => void;
}
export const useMatStore = create<MatState>((set) => ({
@ -20,6 +26,8 @@ export const useMatStore = create<MatState>((set) => ({
survivors: [],
alerts: [],
selectedEventId: null,
dataSource: 'simulated',
simulationAcknowledged: false,
upsertEvent: (event) => {
set((state) => {
@ -71,4 +79,12 @@ export const useMatStore = create<MatState>((set) => ({
setSelectedEvent: (id) => {
set({ selectedEventId: id });
},
setDataSource: (source) => {
set({ dataSource: source });
},
acknowledgeSimulation: () => {
set({ simulationAcknowledged: true });
},
}));

View file

@ -17,7 +17,7 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
from src.config.settings import get_settings
from src.config.domains import get_domain_config
from src.api.routers import pose, stream, health
from src.api.routers import pose, stream, health, auth
from src.api.middleware.auth import AuthMiddleware
from src.api.middleware.rate_limit import RateLimitMiddleware
from src.api.dependencies import get_pose_service, get_stream_service, get_hardware_service
@ -263,6 +263,12 @@ app.include_router(
tags=["Streaming"]
)
app.include_router(
auth.router,
prefix=f"{settings.api_prefix}",
tags=["Authentication"]
)
# Root endpoint
@app.get("/")

View file

@ -189,7 +189,11 @@ class AuthMiddleware(BaseHTTPMiddleware):
self.settings.secret_key,
algorithms=[self.settings.jwt_algorithm]
)
# Check token blacklist (logout invalidation)
if token_blacklist.is_blacklisted(token):
raise ValueError("Token has been revoked")
# Extract user information
user_id = payload.get("sub")
if not user_id:

View file

@ -2,6 +2,6 @@
API routers package
"""
from . import pose, stream, health
from . import pose, stream, health, auth
__all__ = ["pose", "stream", "health"]
__all__ = ["pose", "stream", "health", "auth"]

View file

@ -0,0 +1,32 @@
"""
Authentication router for WiFi-DensePose API.
Provides logout (token blacklisting) endpoint.
"""
import logging
from typing import Optional
from fastapi import APIRouter, Request, HTTPException, status
from src.api.middleware.auth import token_blacklist
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/auth", tags=["auth"])
@router.post("/logout")
async def logout(request: Request):
"""Logout by blacklisting the current Bearer token."""
auth_header = request.headers.get("authorization")
if not auth_header or not auth_header.startswith("Bearer "):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing or invalid Authorization header",
)
token = auth_header.split(" ", 1)[1]
token_blacklist.add_token(token)
logger.info("Token blacklisted via /auth/logout")
return {"success": True, "message": "Token revoked"}

View file

@ -1,6 +1,7 @@
"""CSI data processor for WiFi-DensePose system using TDD approach."""
import asyncio
import itertools
import logging
import numpy as np
from datetime import datetime, timezone
@ -293,7 +294,8 @@ class CSIProcessor:
if count >= len(self.csi_history):
return list(self.csi_history)
else:
return list(self.csi_history)[-count:]
start = len(self.csi_history) - count
return list(itertools.islice(self.csi_history, start, len(self.csi_history)))
def get_processing_statistics(self) -> Dict[str, Any]:
"""Get processing statistics.
@ -410,8 +412,9 @@ class CSIProcessor:
# Use cached mean-phase values (pre-computed in add_to_history)
# Only take the last doppler_window frames for bounded cost
window = min(len(self._phase_cache), self._doppler_window)
cache_list = list(self._phase_cache)
phase_matrix = np.array(cache_list[-window:])
start = len(self._phase_cache) - window
cache_list = list(itertools.islice(self._phase_cache, start, len(self._phase_cache)))
phase_matrix = np.array(cache_list)
# Temporal phase differences between consecutive frames
phase_diffs = np.diff(phase_matrix, axis=0)

View file

@ -56,6 +56,10 @@ class TokenManager:
"""Verify and decode JWT token."""
try:
payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
# Check token blacklist (logout invalidation)
from src.api.middleware.auth import token_blacklist
if token_blacklist.is_blacklisted(token):
raise AuthenticationError("Token has been revoked")
return payload
except JWTError as e:
logger.warning(f"JWT verification failed: {e}")

View file

@ -0,0 +1,135 @@
"""Frame budget benchmark for CSI processing pipeline.
Verifies that per-frame CSI processing stays within the 50 ms budget
required for real-time sensing at 20 FPS.
"""
import time
import statistics
import pytest
import numpy as np
from src.core.csi_processor import CSIProcessor
def _make_config():
return {
"sampling_rate": 1000,
"window_size": 256,
"overlap": 0.5,
"noise_threshold": -60,
"human_detection_threshold": 0.8,
"smoothing_factor": 0.9,
"max_history_size": 500,
"num_subcarriers": 256,
"num_antennas": 3,
"doppler_window": 64,
}
def _make_csi_data(n_subcarriers=256, n_antennas=3, seed=None):
"""Generate a synthetic CSI frame with complex-valued subcarriers."""
rng = np.random.default_rng(seed)
from unittest.mock import MagicMock
csi = MagicMock()
csi.amplitude = rng.random((n_antennas, n_subcarriers)).astype(np.float64) * 20.0
csi.phase = (rng.random((n_antennas, n_subcarriers)).astype(np.float64) - 0.5) * np.pi * 2
csi.frequency = 5.0e9
csi.bandwidth = 80e6
csi.num_subcarriers = n_subcarriers
csi.num_antennas = n_antennas
csi.snr = 25.0
csi.timestamp = time.time()
csi.metadata = {}
return csi
class TestSingleFrameBudget:
"""Single-frame processing must complete in < 50 ms."""
def test_single_frame_under_50ms(self):
proc = CSIProcessor(config=_make_config())
frame = _make_csi_data(seed=42)
# Warm up
proc.preprocess_csi_data(frame)
start = time.perf_counter()
proc.preprocess_csi_data(frame)
features = proc.extract_features(frame)
if features:
proc.detect_human_presence(features)
elapsed_ms = (time.perf_counter() - start) * 1000
assert elapsed_ms < 50, f"Single frame took {elapsed_ms:.1f} ms (budget: 50 ms)"
class TestSustainedFrameBudget:
"""Sustained 100-frame processing p95 must be < 50 ms per frame."""
def test_sustained_100_frames_p95(self):
proc = CSIProcessor(config=_make_config())
rng = np.random.default_rng(123)
n_frames = 100
latencies = []
for i in range(n_frames):
frame = _make_csi_data(seed=i)
start = time.perf_counter()
preprocessed = proc.preprocess_csi_data(frame)
features = proc.extract_features(preprocessed)
if features:
proc.detect_human_presence(features)
proc.add_to_history(frame)
elapsed_ms = (time.perf_counter() - start) * 1000
latencies.append(elapsed_ms)
p50 = statistics.median(latencies)
p95 = sorted(latencies)[int(0.95 * len(latencies))]
p99 = sorted(latencies)[int(0.99 * len(latencies))]
print(f"\n--- Sustained {n_frames}-frame benchmark ---")
print(f" p50: {p50:.2f} ms")
print(f" p95: {p95:.2f} ms")
print(f" p99: {p99:.2f} ms")
print(f" min: {min(latencies):.2f} ms")
print(f" max: {max(latencies):.2f} ms")
assert p95 < 50, f"p95 latency {p95:.1f} ms exceeds 50 ms budget"
class TestPipelineWithDoppler:
"""Full pipeline including Doppler estimation must stay within budget."""
def test_doppler_pipeline(self):
proc = CSIProcessor(config=_make_config())
n_frames = 100
latencies = []
# Fill history first
for i in range(20):
frame = _make_csi_data(seed=i + 1000)
proc.add_to_history(frame)
for i in range(n_frames):
frame = _make_csi_data(seed=i + 2000)
start = time.perf_counter()
preprocessed = proc.preprocess_csi_data(frame)
features = proc.extract_features(preprocessed)
if features:
proc.detect_human_presence(features)
proc.add_to_history(frame)
elapsed_ms = (time.perf_counter() - start) * 1000
latencies.append(elapsed_ms)
p50 = statistics.median(latencies)
p95 = sorted(latencies)[int(0.95 * len(latencies))]
p99 = sorted(latencies)[int(0.99 * len(latencies))]
print(f"\n--- Doppler pipeline benchmark ({n_frames} frames, 20 warmup) ---")
print(f" p50: {p50:.2f} ms")
print(f" p95: {p95:.2f} ms")
print(f" p99: {p99:.2f} ms")
# Doppler adds overhead but should still be within budget
assert p95 < 50, f"Doppler pipeline p95 {p95:.1f} ms exceeds 50 ms budget"

56
v1/tests/unit/conftest.py Normal file
View file

@ -0,0 +1,56 @@
"""Shared fixtures for unit tests."""
import os
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
# Set SECRET_KEY before any settings import
os.environ.setdefault("SECRET_KEY", "test-secret-key-for-unit-tests-only")
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests-only")
@pytest.fixture
def mock_settings():
"""Create a mock Settings object."""
settings = MagicMock()
settings.secret_key = "test-secret-key-for-unit-tests-only"
settings.jwt_algorithm = "HS256"
settings.jwt_expire_hours = 24
settings.app_name = "test-app"
settings.version = "0.1.0"
settings.is_production = False
settings.enable_rate_limiting = False
settings.enable_authentication = False
settings.rate_limit_requests = 100
settings.rate_limit_window = 60
settings.rate_limit_authenticated_requests = 1000
settings.allowed_hosts = ["*"]
settings.csi_buffer_size = 100
settings.stream_buffer_size = 100
settings.mock_hardware = True
settings.mock_pose_data = True
settings.enable_real_time_processing = False
settings.trusted_proxies = ["127.0.0.1"]
return settings
@pytest.fixture
def mock_domain_config():
"""Create a mock DomainConfig object."""
config = MagicMock()
config.pose_estimation = MagicMock()
config.streaming = MagicMock()
config.hardware = MagicMock()
return config
@pytest.fixture
def mock_redis():
"""Provide a mock Redis client."""
with patch("redis.Redis") as mock:
client = MagicMock()
client.ping.return_value = True
client.get.return_value = None
client.set.return_value = True
mock.return_value = client
yield client

View file

@ -0,0 +1,137 @@
"""Tests for AuthMiddleware and TokenManager."""
import pytest
import os
from unittest.mock import MagicMock, AsyncMock, patch
from datetime import datetime, timedelta
class TestTokenManager:
def test_create_token(self, mock_settings):
from src.middleware.auth import TokenManager
tm = TokenManager(mock_settings)
token = tm.create_access_token({"sub": "user1"})
assert isinstance(token, str)
assert len(token) > 0
def test_verify_valid_token(self, mock_settings):
from src.middleware.auth import TokenManager
tm = TokenManager(mock_settings)
token = tm.create_access_token({"sub": "user1", "role": "admin"})
payload = tm.verify_token(token)
assert payload["sub"] == "user1"
assert payload["role"] == "admin"
def test_verify_invalid_token(self, mock_settings):
from src.middleware.auth import TokenManager, AuthenticationError
tm = TokenManager(mock_settings)
with pytest.raises(AuthenticationError):
tm.verify_token("invalid.token.here")
def test_decode_claims(self, mock_settings):
from src.middleware.auth import TokenManager
tm = TokenManager(mock_settings)
token = tm.create_access_token({"sub": "user1"})
claims = tm.decode_token_claims(token)
assert claims is not None
assert claims["sub"] == "user1"
def test_decode_claims_invalid(self, mock_settings):
from src.middleware.auth import TokenManager
tm = TokenManager(mock_settings)
claims = tm.decode_token_claims("bad-token")
assert claims is None
def test_token_has_expiry(self, mock_settings):
from src.middleware.auth import TokenManager
tm = TokenManager(mock_settings)
token = tm.create_access_token({"sub": "user1"})
payload = tm.verify_token(token)
assert "exp" in payload
assert "iat" in payload
class TestUserManager:
def test_create_user(self):
from src.middleware.auth import UserManager
um = UserManager()
assert um.get_user("nonexistent") is None
def test_hash_password(self):
from src.middleware.auth import UserManager
hashed = UserManager.hash_password("secret123")
assert hashed != "secret123"
assert len(hashed) > 20
def test_verify_password(self):
from src.middleware.auth import UserManager
hashed = UserManager.hash_password("secret123")
assert UserManager.verify_password("secret123", hashed) is True
assert UserManager.verify_password("wrong", hashed) is False
class TestTokenBlacklist:
def test_add_and_check(self):
from src.api.middleware.auth import TokenBlacklist
bl = TokenBlacklist()
bl.add_token("tok123")
assert bl.is_blacklisted("tok123") is True
assert bl.is_blacklisted("tok456") is False
def test_blacklisted_token_rejected(self, mock_settings):
from src.middleware.auth import TokenManager, AuthenticationError
from src.api.middleware.auth import token_blacklist
tm = TokenManager(mock_settings)
token = tm.create_access_token({"sub": "user1"})
# Token should be valid
tm.verify_token(token)
# Blacklist it
token_blacklist.add_token(token)
with pytest.raises(AuthenticationError, match="revoked"):
tm.verify_token(token)
# Cleanup
token_blacklist._blacklisted_tokens.discard(token)
class TestAuthMiddleware:
def test_public_paths(self, mock_settings):
with patch("src.api.middleware.auth.get_settings", return_value=mock_settings):
from src.api.middleware.auth import AuthMiddleware
app = MagicMock()
mw = AuthMiddleware(app)
assert mw._is_public_path("/health") is True
assert mw._is_public_path("/docs") is True
assert mw._is_public_path("/api/v1/pose/analyze") is False
def test_protected_paths(self, mock_settings):
with patch("src.api.middleware.auth.get_settings", return_value=mock_settings):
from src.api.middleware.auth import AuthMiddleware
app = MagicMock()
mw = AuthMiddleware(app)
assert mw._is_protected_path("/api/v1/pose/analyze") is True
assert mw._is_protected_path("/health") is False
def test_extract_token_from_header(self, mock_settings):
with patch("src.api.middleware.auth.get_settings", return_value=mock_settings):
from src.api.middleware.auth import AuthMiddleware
app = MagicMock()
mw = AuthMiddleware(app)
request = MagicMock()
request.headers = {"authorization": "Bearer mytoken123"}
request.query_params = {}
request.cookies = {}
token = mw._extract_token(request)
assert token == "mytoken123"
def test_extract_token_missing(self, mock_settings):
with patch("src.api.middleware.auth.get_settings", return_value=mock_settings):
from src.api.middleware.auth import AuthMiddleware
app = MagicMock()
mw = AuthMiddleware(app)
request = MagicMock()
request.headers = {}
request.query_params = {}
request.cookies = {}
token = mw._extract_token(request)
assert token is None

View file

@ -0,0 +1,78 @@
"""Tests for error handling in the API layer."""
import pytest
from unittest.mock import MagicMock, patch
from fastapi.testclient import TestClient
class TestExceptionHandlers:
"""Test the exception handlers registered on the FastAPI app."""
def _get_app(self):
"""Import app lazily to avoid side effects."""
with patch("src.api.main.get_settings") as mock_gs, \
patch("src.api.main.get_domain_config") as mock_gdc, \
patch("src.api.main.get_pose_service") as mock_ps, \
patch("src.api.main.get_stream_service") as mock_ss, \
patch("src.api.main.get_hardware_service") as mock_hs, \
patch("src.api.main.connection_manager") as mock_cm, \
patch("src.api.main.PoseStreamHandler") as mock_psh:
mock_gs.return_value = MagicMock(
app_name="test", version="0.1", environment="test",
is_production=False, enable_rate_limiting=False,
enable_authentication=False, docs_url="/docs",
redoc_url="/redoc", openapi_url="/openapi.json",
api_prefix="/api/v1",
)
mock_gs.return_value.get_logging_config.return_value = {
"version": 1, "disable_existing_loggers": False,
"handlers": {}, "loggers": {},
}
mock_gs.return_value.get_cors_config.return_value = {
"allow_origins": ["*"], "allow_methods": ["*"],
"allow_headers": ["*"],
}
# Re-import to pick up patches
import importlib
import src.api.main as m
importlib.reload(m)
return m.app
class TestErrorResponseModel:
def test_error_json_structure(self):
"""Verify error JSON has code, message, type fields."""
error = {
"error": {
"code": 404,
"message": "Not found",
"type": "http_error"
}
}
assert error["error"]["code"] == 404
assert "message" in error["error"]
assert "type" in error["error"]
def test_validation_error_structure(self):
error = {
"error": {
"code": 422,
"message": "Validation error",
"type": "validation_error",
"details": []
}
}
assert error["error"]["type"] == "validation_error"
assert isinstance(error["error"]["details"], list)
def test_internal_error_masks_details(self):
"""In production, internal errors should not leak stack traces."""
error = {
"error": {
"code": 500,
"message": "Internal server error",
"type": "internal_error"
}
}
assert "traceback" not in str(error)
assert error["error"]["message"] == "Internal server error"

View file

@ -0,0 +1,65 @@
"""Tests for HardwareService."""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
class TestHardwareServiceInit:
def test_init(self, mock_settings, mock_domain_config):
mock_settings.mock_hardware = True
with patch("src.services.hardware_service.RouterInterface"):
from src.services.hardware_service import HardwareService
svc = HardwareService(mock_settings, mock_domain_config)
assert svc.is_running is False
assert svc.stats["total_samples"] == 0
assert svc.stats["connected_routers"] == 0
def test_stats_defaults(self, mock_settings, mock_domain_config):
mock_settings.mock_hardware = True
with patch("src.services.hardware_service.RouterInterface"):
from src.services.hardware_service import HardwareService
svc = HardwareService(mock_settings, mock_domain_config)
assert svc.stats["successful_samples"] == 0
assert svc.stats["failed_samples"] == 0
assert svc.stats["last_sample_time"] is None
class TestHardwareServiceLifecycle:
@pytest.mark.asyncio
async def test_start(self, mock_settings, mock_domain_config):
mock_settings.mock_hardware = True
with patch("src.services.hardware_service.RouterInterface"):
from src.services.hardware_service import HardwareService
svc = HardwareService(mock_settings, mock_domain_config)
svc._initialize_routers = AsyncMock()
svc._monitoring_loop = AsyncMock()
await svc.start()
assert svc.is_running is True
@pytest.mark.asyncio
async def test_double_start_idempotent(self, mock_settings, mock_domain_config):
mock_settings.mock_hardware = True
with patch("src.services.hardware_service.RouterInterface"):
from src.services.hardware_service import HardwareService
svc = HardwareService(mock_settings, mock_domain_config)
svc._initialize_routers = AsyncMock()
svc._monitoring_loop = AsyncMock()
await svc.start()
await svc.start() # idempotent
assert svc.is_running is True
class TestHardwareServiceRouter:
def test_no_routers_on_init(self, mock_settings, mock_domain_config):
mock_settings.mock_hardware = True
with patch("src.services.hardware_service.RouterInterface"):
from src.services.hardware_service import HardwareService
svc = HardwareService(mock_settings, mock_domain_config)
assert len(svc.router_interfaces) == 0
def test_max_recent_samples(self, mock_settings, mock_domain_config):
mock_settings.mock_hardware = True
with patch("src.services.hardware_service.RouterInterface"):
from src.services.hardware_service import HardwareService
svc = HardwareService(mock_settings, mock_domain_config)
assert svc.max_recent_samples == 1000

View file

@ -0,0 +1,67 @@
"""Tests for HealthCheckService."""
import pytest
from unittest.mock import MagicMock
class TestHealthCheckServiceInit:
def test_init(self, mock_settings):
from src.services.health_check import HealthCheckService
svc = HealthCheckService(mock_settings)
assert svc._initialized is False
assert svc._running is False
@pytest.mark.asyncio
async def test_initialize(self, mock_settings):
from src.services.health_check import HealthCheckService
svc = HealthCheckService(mock_settings)
await svc.initialize()
assert svc._initialized is True
assert "api" in svc._services
assert "database" in svc._services
assert "hardware" in svc._services
@pytest.mark.asyncio
async def test_double_initialize(self, mock_settings):
from src.services.health_check import HealthCheckService
svc = HealthCheckService(mock_settings)
await svc.initialize()
await svc.initialize() # idempotent
assert svc._initialized is True
class TestHealthCheckAggregation:
@pytest.mark.asyncio
async def test_services_registered(self, mock_settings):
from src.services.health_check import HealthCheckService, HealthStatus
svc = HealthCheckService(mock_settings)
await svc.initialize()
assert len(svc._services) == 6
for name, sh in svc._services.items():
assert sh.status == HealthStatus.UNKNOWN
@pytest.mark.asyncio
async def test_service_names(self, mock_settings):
from src.services.health_check import HealthCheckService
svc = HealthCheckService(mock_settings)
await svc.initialize()
expected = {"api", "database", "redis", "hardware", "pose", "stream"}
assert set(svc._services.keys()) == expected
class TestHealthStatus:
def test_enum_values(self):
from src.services.health_check import HealthStatus
assert HealthStatus.HEALTHY.value == "healthy"
assert HealthStatus.DEGRADED.value == "degraded"
assert HealthStatus.UNHEALTHY.value == "unhealthy"
assert HealthStatus.UNKNOWN.value == "unknown"
class TestHealthCheck:
def test_health_check_dataclass(self):
from src.services.health_check import HealthCheck, HealthStatus
hc = HealthCheck(name="test", status=HealthStatus.HEALTHY, message="ok")
assert hc.name == "test"
assert hc.status == HealthStatus.HEALTHY
assert hc.duration_ms == 0.0

View file

@ -0,0 +1,70 @@
"""Tests for MetricsService."""
import pytest
from datetime import timedelta
from unittest.mock import MagicMock, patch
class TestMetricSeries:
def test_add_point(self):
from src.services.metrics import MetricSeries
ms = MetricSeries(name="test", description="desc", unit="ms")
ms.add_point(42.0)
assert len(ms.points) == 1
assert ms.points[0].value == 42.0
def test_get_latest(self):
from src.services.metrics import MetricSeries
ms = MetricSeries(name="test", description="desc", unit="ms")
ms.add_point(1.0)
ms.add_point(2.0)
latest = ms.get_latest()
assert latest is not None
assert latest.value == 2.0
def test_get_latest_empty(self):
from src.services.metrics import MetricSeries
ms = MetricSeries(name="test", description="desc", unit="ms")
assert ms.get_latest() is None
def test_get_average(self):
from src.services.metrics import MetricSeries
ms = MetricSeries(name="test", description="desc", unit="ms")
for v in [10.0, 20.0, 30.0]:
ms.add_point(v)
avg = ms.get_average(timedelta(minutes=5))
assert avg == pytest.approx(20.0)
def test_get_average_empty(self):
from src.services.metrics import MetricSeries
ms = MetricSeries(name="test", description="desc", unit="ms")
assert ms.get_average(timedelta(minutes=5)) is None
def test_get_max(self):
from src.services.metrics import MetricSeries
ms = MetricSeries(name="test", description="desc", unit="ms")
for v in [10.0, 50.0, 30.0]:
ms.add_point(v)
mx = ms.get_max(timedelta(minutes=5))
assert mx == 50.0
def test_labels(self):
from src.services.metrics import MetricSeries
ms = MetricSeries(name="test", description="desc", unit="ms")
ms.add_point(1.0, {"region": "us-east"})
assert ms.points[0].labels["region"] == "us-east"
def test_maxlen(self):
from src.services.metrics import MetricSeries
ms = MetricSeries(name="test", description="desc", unit="ms")
for i in range(1100):
ms.add_point(float(i))
assert len(ms.points) == 1000
class TestMetricsService:
def test_init(self, mock_settings):
with patch("src.services.metrics.psutil"):
from src.services.metrics import MetricsService
svc = MetricsService(mock_settings)
assert svc._metrics is not None

View file

@ -0,0 +1,73 @@
"""Tests for PoseService."""
import pytest
import asyncio
from unittest.mock import MagicMock, AsyncMock, patch
from datetime import datetime
class TestPoseServiceInit:
def test_init_sets_defaults(self, mock_settings, mock_domain_config):
with patch.dict("sys.modules", {
"torch": MagicMock(),
"src.models.densepose_head": MagicMock(),
"src.models.modality_translation": MagicMock(),
}):
from src.services.pose_service import PoseService
svc = PoseService(mock_settings, mock_domain_config)
assert svc.is_initialized is False
assert svc.is_running is False
assert svc.stats["total_processed"] == 0
def test_stats_are_zero_on_init(self, mock_settings, mock_domain_config):
with patch.dict("sys.modules", {
"torch": MagicMock(),
"src.models.densepose_head": MagicMock(),
"src.models.modality_translation": MagicMock(),
}):
from src.services.pose_service import PoseService
svc = PoseService(mock_settings, mock_domain_config)
assert svc.stats["successful_detections"] == 0
assert svc.stats["failed_detections"] == 0
assert svc.stats["average_confidence"] == 0.0
class TestPoseServiceLifecycle:
@pytest.mark.asyncio
async def test_initialize_sets_flag(self, mock_settings, mock_domain_config):
with patch.dict("sys.modules", {
"torch": MagicMock(),
"src.models.densepose_head": MagicMock(),
"src.models.modality_translation": MagicMock(),
}):
from src.services.pose_service import PoseService
svc = PoseService(mock_settings, mock_domain_config)
await svc.initialize()
assert svc.is_initialized is True
@pytest.mark.asyncio
async def test_start_stop(self, mock_settings, mock_domain_config):
with patch.dict("sys.modules", {
"torch": MagicMock(),
"src.models.densepose_head": MagicMock(),
"src.models.modality_translation": MagicMock(),
}):
from src.services.pose_service import PoseService
svc = PoseService(mock_settings, mock_domain_config)
await svc.initialize()
await svc.start()
assert svc.is_running is True
await svc.stop()
assert svc.is_running is False
class TestPoseServiceStats:
def test_initial_classification(self, mock_settings, mock_domain_config):
with patch.dict("sys.modules", {
"torch": MagicMock(),
"src.models.densepose_head": MagicMock(),
"src.models.modality_translation": MagicMock(),
}):
from src.services.pose_service import PoseService
svc = PoseService(mock_settings, mock_domain_config)
assert svc.last_error is None

View file

@ -0,0 +1,62 @@
"""Tests for rate limiting middleware."""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
class TestRateLimitMiddleware:
def test_init(self, mock_settings):
with patch("src.api.middleware.rate_limit.get_settings", return_value=mock_settings):
from src.api.middleware.rate_limit import RateLimitMiddleware
app = MagicMock()
mw = RateLimitMiddleware(app)
assert "anonymous" in mw.rate_limits
assert "authenticated" in mw.rate_limits
assert "admin" in mw.rate_limits
def test_exempt_paths(self, mock_settings):
with patch("src.api.middleware.rate_limit.get_settings", return_value=mock_settings):
from src.api.middleware.rate_limit import RateLimitMiddleware
app = MagicMock()
mw = RateLimitMiddleware(app)
assert "/health" in mw.exempt_paths
assert "/metrics" in mw.exempt_paths
def test_is_exempt(self, mock_settings):
with patch("src.api.middleware.rate_limit.get_settings", return_value=mock_settings):
from src.api.middleware.rate_limit import RateLimitMiddleware
app = MagicMock()
mw = RateLimitMiddleware(app)
assert mw._is_exempt_path("/health") is True
assert mw._is_exempt_path("/api/v1/pose/current") is False
def test_path_specific_limits(self, mock_settings):
with patch("src.api.middleware.rate_limit.get_settings", return_value=mock_settings):
from src.api.middleware.rate_limit import RateLimitMiddleware
app = MagicMock()
mw = RateLimitMiddleware(app)
assert "/api/v1/pose/current" in mw.path_limits
assert mw.path_limits["/api/v1/pose/current"]["requests"] == 60
def test_trusted_proxies_not_blocked(self, mock_settings):
with patch("src.api.middleware.rate_limit.get_settings", return_value=mock_settings):
from src.api.middleware.rate_limit import RateLimitMiddleware
app = MagicMock()
mw = RateLimitMiddleware(app)
assert not mw._is_client_blocked("new-client-id")
class TestRateLimitConfig:
def test_anonymous_limit(self, mock_settings):
with patch("src.api.middleware.rate_limit.get_settings", return_value=mock_settings):
from src.api.middleware.rate_limit import RateLimitMiddleware
app = MagicMock()
mw = RateLimitMiddleware(app)
assert mw.rate_limits["anonymous"]["burst"] == 10
def test_admin_limit(self, mock_settings):
with patch("src.api.middleware.rate_limit.get_settings", return_value=mock_settings):
from src.api.middleware.rate_limit import RateLimitMiddleware
app = MagicMock()
mw = RateLimitMiddleware(app)
assert mw.rate_limits["admin"]["requests"] == 10000

View file

@ -0,0 +1,68 @@
"""Tests for StreamService."""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
class TestStreamServiceLifecycle:
def test_init(self, mock_settings, mock_domain_config):
from src.services.stream_service import StreamService
svc = StreamService(mock_settings, mock_domain_config)
assert svc.is_running is False
assert len(svc.connections) == 0
assert svc.stats["active_connections"] == 0
@pytest.mark.asyncio
async def test_initialize(self, mock_settings, mock_domain_config):
from src.services.stream_service import StreamService
svc = StreamService(mock_settings, mock_domain_config)
await svc.initialize()
@pytest.mark.asyncio
async def test_start(self, mock_settings, mock_domain_config):
mock_settings.enable_real_time_processing = False
from src.services.stream_service import StreamService
svc = StreamService(mock_settings, mock_domain_config)
await svc.start()
assert svc.is_running is True
@pytest.mark.asyncio
async def test_stop(self, mock_settings, mock_domain_config):
mock_settings.enable_real_time_processing = False
from src.services.stream_service import StreamService
svc = StreamService(mock_settings, mock_domain_config)
await svc.start()
await svc.stop()
assert svc.is_running is False
@pytest.mark.asyncio
async def test_double_start(self, mock_settings, mock_domain_config):
mock_settings.enable_real_time_processing = False
from src.services.stream_service import StreamService
svc = StreamService(mock_settings, mock_domain_config)
await svc.start()
await svc.start() # should be idempotent
assert svc.is_running is True
class TestStreamServiceConnections:
def test_no_connections_on_init(self, mock_settings, mock_domain_config):
from src.services.stream_service import StreamService
svc = StreamService(mock_settings, mock_domain_config)
assert svc.stats["total_connections"] == 0
assert svc.stats["messages_sent"] == 0
def test_buffer_sizes(self, mock_settings, mock_domain_config):
mock_settings.stream_buffer_size = 50
from src.services.stream_service import StreamService
svc = StreamService(mock_settings, mock_domain_config)
assert svc.pose_buffer.maxlen == 50
assert svc.csi_buffer.maxlen == 50
class TestStreamServiceBroadcast:
def test_stats_messages_failed_init_zero(self, mock_settings, mock_domain_config):
from src.services.stream_service import StreamService
svc = StreamService(mock_settings, mock_domain_config)
assert svc.stats["messages_failed"] == 0
assert svc.stats["data_points_streamed"] == 0