feat: ADR-080 P1+P2 remediation — refactor, perf, tests, safety

P1 fixes (this sprint):
- P1-6: Extract sensing-server modules (cli, types, csi, pose) from main.rs
- P1-7: DDA ray march for tomography — O(max(n)) replaces O(n^3) voxel scan
- P1-8: Batch neural inference — Tensor::stack/split for single GPU call
- P1-10: Eliminate 112KB/frame alloc — islice replaces deque→list copy

P2 fixes (this quarter):
- P2-11: Python unit tests for 8 modules (rate_limit, auth, error_handler,
  pose_service, stream_service, hardware_service, health_check, metrics)
- P2-13: MAT simulated data safety guard — blocking overlay + pulsing banner
- P2-14: Wire token blacklist into auth verification + logout endpoint
- P2-15: Frame budget benchmark — confirms pipeline well under 50ms budget

Addresses 8 of 10 remaining issues from QE analysis (ADR-080).

Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
ruv 2026-04-06 17:00:27 -04:00
parent 327d0d13f6
commit 5bd0d59aa6
30 changed files with 2635 additions and 27 deletions

View file

@ -330,9 +330,36 @@ impl<B: Backend> InferenceEngine<B> {
Ok(result)
}
/// Run batched inference
/// Run batched inference.
///
/// Stacks all inputs along a new batch dimension, runs a single
/// backend call, then splits the output back into individual tensors.
/// Falls back to sequential inference if stack/split fails.
pub fn infer_batch(&self, inputs: &[Tensor]) -> NnResult<Vec<Tensor>> {
inputs.iter().map(|input| self.infer(input)).collect()
if inputs.is_empty() {
return Ok(Vec::new());
}
if inputs.len() == 1 {
return Ok(vec![self.infer(&inputs[0])?]);
}
// Try batched path: stack -> single call -> split
match Tensor::stack(inputs) {
Ok(batched_input) => {
let n = inputs.len();
let batched_output = self.backend.run_single(&batched_input)?;
match batched_output.split(n) {
Ok(outputs) => Ok(outputs),
Err(_) => {
// Fallback: sequential
inputs.iter().map(|input| self.infer(input)).collect()
}
}
}
Err(_) => {
// Fallback: sequential if shapes are incompatible
inputs.iter().map(|input| self.infer(input)).collect()
}
}
}
/// Get inference statistics

View file

@ -304,6 +304,74 @@ impl Tensor {
}
}
/// Stack multiple tensors along a new batch dimension (dim 0).
///
/// All tensors must have the same shape. The result has one extra
/// leading dimension equal to `tensors.len()`.
pub fn stack(tensors: &[Tensor]) -> NnResult<Tensor> {
if tensors.is_empty() {
return Err(NnError::tensor_op("Cannot stack zero tensors"));
}
let first_shape = tensors[0].shape();
for (i, t) in tensors.iter().enumerate().skip(1) {
if t.shape() != first_shape {
return Err(NnError::tensor_op(&format!(
"Shape mismatch at index {i}: expected {first_shape}, got {}",
t.shape()
)));
}
}
let mut all_data: Vec<f32> = Vec::with_capacity(tensors.len() * first_shape.numel());
for t in tensors {
let data = t.to_vec()?;
all_data.extend_from_slice(&data);
}
let mut new_dims = vec![tensors.len()];
new_dims.extend_from_slice(first_shape.dims());
let arr = ndarray::ArrayD::from_shape_vec(
ndarray::IxDyn(&new_dims),
all_data,
)
.map_err(|e| NnError::tensor_op(&format!("Stack reshape failed: {e}")))?;
Ok(Tensor::FloatND(arr))
}
/// Split a tensor along dim 0 into `n` sub-tensors.
///
/// The first dimension must be evenly divisible by `n`.
pub fn split(self, n: usize) -> NnResult<Vec<Tensor>> {
if n == 0 {
return Err(NnError::tensor_op("Cannot split into 0 pieces"));
}
let shape = self.shape();
let batch = shape.dim(0).ok_or_else(|| NnError::tensor_op("Tensor has no dimensions"))?;
if batch % n != 0 {
return Err(NnError::tensor_op(&format!(
"Batch dim {batch} not divisible by {n}"
)));
}
let chunk_size = batch / n;
let data = self.to_vec()?;
let elem_per_sample = shape.numel() / batch;
let sub_dims: Vec<usize> = {
let mut d = shape.dims().to_vec();
d[0] = chunk_size;
d
};
let mut result = Vec::with_capacity(n);
for i in 0..n {
let start = i * chunk_size * elem_per_sample;
let end = start + chunk_size * elem_per_sample;
let arr = ndarray::ArrayD::from_shape_vec(
ndarray::IxDyn(&sub_dims),
data[start..end].to_vec(),
)
.map_err(|e| NnError::tensor_op(&format!("Split reshape failed: {e}")))?;
result.push(Tensor::FloatND(arr));
}
Ok(result)
}
/// Compute standard deviation
pub fn std(&self) -> NnResult<f32> {
match self {

View file

@ -0,0 +1,105 @@
//! CLI argument definitions and early-exit mode handlers.
use std::path::PathBuf;
use clap::Parser;
/// CLI arguments for the sensing server.
#[derive(Parser, Debug)]
#[command(name = "sensing-server", about = "WiFi-DensePose sensing server")]
pub struct Args {
/// HTTP port for UI and REST API
#[arg(long, default_value = "8080")]
pub http_port: u16,
/// WebSocket port for sensing stream
#[arg(long, default_value = "8765")]
pub ws_port: u16,
/// UDP port for ESP32 CSI frames
#[arg(long, default_value = "5005")]
pub udp_port: u16,
/// Path to UI static files
#[arg(long, default_value = "../../ui")]
pub ui_path: PathBuf,
/// Tick interval in milliseconds (default 100 ms = 10 fps for smooth pose animation)
#[arg(long, default_value = "100")]
pub tick_ms: u64,
/// Bind address (default 127.0.0.1; set to 0.0.0.0 for network access)
#[arg(long, default_value = "127.0.0.1", env = "SENSING_BIND_ADDR")]
pub bind_addr: String,
/// Data source: auto, wifi, esp32, simulate
#[arg(long, default_value = "auto")]
pub source: String,
/// Run vital sign detection benchmark (1000 frames) and exit
#[arg(long)]
pub benchmark: bool,
/// Load model config from an RVF container at startup
#[arg(long, value_name = "PATH")]
pub load_rvf: Option<PathBuf>,
/// Save current model state as an RVF container on shutdown
#[arg(long, value_name = "PATH")]
pub save_rvf: Option<PathBuf>,
/// Load a trained .rvf model for inference
#[arg(long, value_name = "PATH")]
pub model: Option<PathBuf>,
/// Enable progressive loading (Layer A instant start)
#[arg(long)]
pub progressive: bool,
/// Export an RVF container package and exit (no server)
#[arg(long, value_name = "PATH")]
pub export_rvf: Option<PathBuf>,
/// Run training mode (train a model and exit)
#[arg(long)]
pub train: bool,
/// Path to dataset directory (MM-Fi or Wi-Pose)
#[arg(long, value_name = "PATH")]
pub dataset: Option<PathBuf>,
/// Dataset type: "mmfi" or "wipose"
#[arg(long, value_name = "TYPE", default_value = "mmfi")]
pub dataset_type: String,
/// Number of training epochs
#[arg(long, default_value = "100")]
pub epochs: usize,
/// Directory for training checkpoints
#[arg(long, value_name = "DIR")]
pub checkpoint_dir: Option<PathBuf>,
/// Run self-supervised contrastive pretraining (ADR-024)
#[arg(long)]
pub pretrain: bool,
/// Number of pretraining epochs (default 50)
#[arg(long, default_value = "50")]
pub pretrain_epochs: usize,
/// Extract embeddings mode: load model and extract CSI embeddings
#[arg(long)]
pub embed: bool,
/// Build fingerprint index from embeddings (env|activity|temporal|person)
#[arg(long, value_name = "TYPE")]
pub build_index: Option<String>,
/// Node positions for multistatic fusion (format: "x,y,z;x,y,z;...")
#[arg(long, env = "SENSING_NODE_POSITIONS")]
pub node_positions: Option<String>,
/// Start field model calibration on boot (empty room required)
#[arg(long)]
pub calibrate: bool,
}

View file

@ -0,0 +1,675 @@
//! CSI frame parsing, signal field generation, feature extraction,
//! classification, vital signs smoothing, and multi-person estimation.
use std::collections::{HashMap, VecDeque};
use ruvector_mincut::{DynamicMinCut, MinCutBuilder};
use crate::adaptive_classifier;
use crate::types::*;
use crate::vital_signs::VitalSigns;
// ── ESP32 UDP frame parsers ─────────────────────────────────────────────────
/// Parse a 32-byte edge vitals packet (magic 0xC511_0002).
pub fn parse_esp32_vitals(buf: &[u8]) -> Option<Esp32VitalsPacket> {
if buf.len() < 32 { return None; }
let magic = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]);
if magic != 0xC511_0002 { return None; }
let node_id = buf[4];
let flags = buf[5];
let breathing_raw = u16::from_le_bytes([buf[6], buf[7]]);
let heartrate_raw = u32::from_le_bytes([buf[8], buf[9], buf[10], buf[11]]);
let rssi = buf[12] as i8;
let n_persons = buf[13];
let motion_energy = f32::from_le_bytes([buf[16], buf[17], buf[18], buf[19]]);
let presence_score = f32::from_le_bytes([buf[20], buf[21], buf[22], buf[23]]);
let timestamp_ms = u32::from_le_bytes([buf[24], buf[25], buf[26], buf[27]]);
Some(Esp32VitalsPacket {
node_id,
presence: (flags & 0x01) != 0,
fall_detected: (flags & 0x02) != 0,
motion: (flags & 0x04) != 0,
breathing_rate_bpm: breathing_raw as f64 / 100.0,
heartrate_bpm: heartrate_raw as f64 / 10000.0,
rssi, n_persons, motion_energy, presence_score, timestamp_ms,
})
}
/// Parse a WASM output packet (magic 0xC511_0004).
pub fn parse_wasm_output(buf: &[u8]) -> Option<WasmOutputPacket> {
if buf.len() < 8 { return None; }
let magic = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]);
if magic != 0xC511_0004 { return None; }
let node_id = buf[4];
let module_id = buf[5];
let event_count = u16::from_le_bytes([buf[6], buf[7]]) as usize;
let mut events = Vec::with_capacity(event_count);
let mut offset = 8;
for _ in 0..event_count {
if offset + 5 > buf.len() { break; }
let event_type = buf[offset];
let value = f32::from_le_bytes([
buf[offset + 1], buf[offset + 2], buf[offset + 3], buf[offset + 4],
]);
events.push(WasmEvent { event_type, value });
offset += 5;
}
Some(WasmOutputPacket { node_id, module_id, events })
}
pub fn parse_esp32_frame(buf: &[u8]) -> Option<Esp32Frame> {
if buf.len() < 20 { return None; }
let magic = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]);
if magic != 0xC511_0001 { return None; }
let node_id = buf[4];
let n_antennas = buf[5];
let n_subcarriers = buf[6];
let freq_mhz = u16::from_le_bytes([buf[8], buf[9]]);
let sequence = u32::from_le_bytes([buf[10], buf[11], buf[12], buf[13]]);
let rssi_raw = buf[14] as i8;
let rssi = if rssi_raw > 0 { rssi_raw.saturating_neg() } else { rssi_raw };
let noise_floor = buf[15] as i8;
let iq_start = 20;
let n_pairs = n_antennas as usize * n_subcarriers as usize;
let expected_len = iq_start + n_pairs * 2;
if buf.len() < expected_len { return None; }
let mut amplitudes = Vec::with_capacity(n_pairs);
let mut phases = Vec::with_capacity(n_pairs);
for k in 0..n_pairs {
let i_val = buf[iq_start + k * 2] as i8 as f64;
let q_val = buf[iq_start + k * 2 + 1] as i8 as f64;
amplitudes.push((i_val * i_val + q_val * q_val).sqrt());
phases.push(q_val.atan2(i_val));
}
Some(Esp32Frame {
magic, node_id, n_antennas, n_subcarriers, freq_mhz, sequence,
rssi, noise_floor, amplitudes, phases,
})
}
// ── Signal field generation ─────────────────────────────────────────────────
pub fn generate_signal_field(
_mean_rssi: f64, motion_score: f64, breathing_rate_hz: f64,
signal_quality: f64, subcarrier_variances: &[f64],
) -> SignalField {
let grid = 20usize;
let mut values = vec![0.0f64; grid * grid];
let center = (grid as f64 - 1.0) / 2.0;
let max_var = subcarrier_variances.iter().cloned().fold(0.0f64, f64::max);
let norm_factor = if max_var > 1e-9 { max_var } else { 1.0 };
let n_sub = subcarrier_variances.len().max(1);
for (k, &var) in subcarrier_variances.iter().enumerate() {
let weight = (var / norm_factor) * motion_score;
if weight < 1e-6 { continue; }
let angle = (k as f64 / n_sub as f64) * 2.0 * std::f64::consts::PI;
let radius = center * 0.8 * weight.sqrt();
let hx = center + radius * angle.cos();
let hz = center + radius * angle.sin();
for z in 0..grid {
for x in 0..grid {
let dx = x as f64 - hx;
let dz = z as f64 - hz;
let dist2 = dx * dx + dz * dz;
let spread = (0.5 + weight * 2.0).max(0.5);
values[z * grid + x] += weight * (-dist2 / (2.0 * spread * spread)).exp();
}
}
}
for z in 0..grid {
for x in 0..grid {
let dx = x as f64 - center;
let dz = z as f64 - center;
let dist = (dx * dx + dz * dz).sqrt();
let base = signal_quality * (-dist * 0.12).exp();
values[z * grid + x] += base * 0.3;
}
}
if breathing_rate_hz > 0.05 {
let ring_r = center * 0.55;
let ring_width = 1.8f64;
for z in 0..grid {
for x in 0..grid {
let dx = x as f64 - center;
let dz = z as f64 - center;
let dist = (dx * dx + dz * dz).sqrt();
let ring_val = 0.08 * (-(dist - ring_r).powi(2) / (2.0 * ring_width * ring_width)).exp();
values[z * grid + x] += ring_val;
}
}
}
let field_max = values.iter().cloned().fold(0.0f64, f64::max);
let scale = if field_max > 1e-9 { 1.0 / field_max } else { 1.0 };
for v in &mut values { *v = (*v * scale).clamp(0.0, 1.0); }
SignalField { grid_size: [grid, 1, grid], values }
}
// ── Feature extraction ──────────────────────────────────────────────────────
pub fn estimate_breathing_rate_hz(frame_history: &VecDeque<Vec<f64>>, sample_rate_hz: f64) -> f64 {
let n = frame_history.len();
if n < 6 { return 0.0; }
let series: Vec<f64> = frame_history.iter()
.map(|amps| if amps.is_empty() { 0.0 } else { amps.iter().sum::<f64>() / amps.len() as f64 })
.collect();
let mean_s = series.iter().sum::<f64>() / n as f64;
let detrended: Vec<f64> = series.iter().map(|x| x - mean_s).collect();
let n_candidates = 9usize;
let f_low = 0.1f64;
let f_high = 0.5f64;
let mut best_freq = 0.0f64;
let mut best_power = 0.0f64;
for i in 0..n_candidates {
let freq = f_low + (f_high - f_low) * i as f64 / (n_candidates - 1).max(1) as f64;
let omega = 2.0 * std::f64::consts::PI * freq / sample_rate_hz;
let coeff = 2.0 * omega.cos();
let (mut s_prev2, mut s_prev1) = (0.0f64, 0.0f64);
for &x in &detrended {
let s = x + coeff * s_prev1 - s_prev2;
s_prev2 = s_prev1;
s_prev1 = s;
}
let power = s_prev2 * s_prev2 + s_prev1 * s_prev1 - coeff * s_prev1 * s_prev2;
if power > best_power { best_power = power; best_freq = freq; }
}
let avg_power = {
let mut total = 0.0f64;
for i in 0..n_candidates {
let freq = f_low + (f_high - f_low) * i as f64 / (n_candidates - 1).max(1) as f64;
let omega = 2.0 * std::f64::consts::PI * freq / sample_rate_hz;
let coeff = 2.0 * omega.cos();
let (mut s_prev2, mut s_prev1) = (0.0f64, 0.0f64);
for &x in &detrended {
let s = x + coeff * s_prev1 - s_prev2;
s_prev2 = s_prev1;
s_prev1 = s;
}
total += s_prev2 * s_prev2 + s_prev1 * s_prev1 - coeff * s_prev1 * s_prev2;
}
total / n_candidates as f64
};
if best_power > avg_power * 3.0 { best_freq.clamp(f_low, f_high) } else { 0.0 }
}
pub fn compute_subcarrier_importance_weights(sensitivity: &[f64]) -> Vec<f64> {
let n = sensitivity.len();
if n == 0 { return vec![]; }
let max_sens = sensitivity.iter().cloned().fold(f64::NEG_INFINITY, f64::max).max(1e-9);
let mut sorted = sensitivity.to_vec();
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let median = if n % 2 == 0 { (sorted[n / 2 - 1] + sorted[n / 2]) / 2.0 } else { sorted[n / 2] };
sensitivity.iter()
.map(|&s| if s >= median { 1.0 + (s / max_sens).min(1.0) } else { 0.5 })
.collect()
}
pub fn compute_subcarrier_variances(frame_history: &VecDeque<Vec<f64>>, n_sub: usize) -> Vec<f64> {
if frame_history.is_empty() || n_sub == 0 { return vec![0.0; n_sub]; }
let n_frames = frame_history.len() as f64;
let mut means = vec![0.0f64; n_sub];
let mut sq_means = vec![0.0f64; n_sub];
for frame in frame_history.iter() {
for k in 0..n_sub {
let a = if k < frame.len() { frame[k] } else { 0.0 };
means[k] += a;
sq_means[k] += a * a;
}
}
(0..n_sub).map(|k| {
let mean = means[k] / n_frames;
let sq_mean = sq_means[k] / n_frames;
(sq_mean - mean * mean).max(0.0)
}).collect()
}
pub fn extract_features_from_frame(
frame: &Esp32Frame, frame_history: &VecDeque<Vec<f64>>, sample_rate_hz: f64,
) -> (FeatureInfo, ClassificationInfo, f64, Vec<f64>, f64) {
let n_sub = frame.amplitudes.len().max(1);
let n = n_sub as f64;
let mean_rssi = frame.rssi as f64;
let sub_sensitivity: Vec<f64> = frame.amplitudes.iter().map(|a| a.abs()).collect();
let importance_weights = compute_subcarrier_importance_weights(&sub_sensitivity);
let weight_sum: f64 = importance_weights.iter().sum::<f64>();
let mean_amp: f64 = if weight_sum > 0.0 {
frame.amplitudes.iter().zip(importance_weights.iter())
.map(|(a, w)| a * w).sum::<f64>() / weight_sum
} else {
frame.amplitudes.iter().sum::<f64>() / n
};
let intra_variance: f64 = if weight_sum > 0.0 {
frame.amplitudes.iter().zip(importance_weights.iter())
.map(|(a, w)| w * (a - mean_amp).powi(2)).sum::<f64>() / weight_sum
} else {
frame.amplitudes.iter().map(|a| (a - mean_amp).powi(2)).sum::<f64>() / n
};
let sub_variances = compute_subcarrier_variances(frame_history, n_sub);
let temporal_variance: f64 = if sub_variances.is_empty() {
intra_variance
} else {
sub_variances.iter().sum::<f64>() / sub_variances.len() as f64
};
let variance = intra_variance.max(temporal_variance);
let spectral_power: f64 = frame.amplitudes.iter().map(|a| a * a).sum::<f64>() / n;
let half = frame.amplitudes.len() / 2;
let motion_band_power = if half > 0 {
frame.amplitudes[half..].iter().map(|a| (a - mean_amp).powi(2)).sum::<f64>()
/ (frame.amplitudes.len() - half) as f64
} else { 0.0 };
let breathing_band_power = if half > 0 {
frame.amplitudes[..half].iter().map(|a| (a - mean_amp).powi(2)).sum::<f64>() / half as f64
} else { 0.0 };
let peak_idx = frame.amplitudes.iter().enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
.map(|(i, _)| i).unwrap_or(0);
let dominant_freq_hz = peak_idx as f64 * 0.05;
let threshold = mean_amp * 1.2;
let change_points = frame.amplitudes.windows(2)
.filter(|w| (w[0] < threshold) != (w[1] < threshold)).count();
let temporal_motion_score = if let Some(prev_frame) = frame_history.back() {
let n_cmp = n_sub.min(prev_frame.len());
if n_cmp > 0 {
let diff_energy: f64 = (0..n_cmp)
.map(|k| (frame.amplitudes[k] - prev_frame[k]).powi(2)).sum::<f64>() / n_cmp as f64;
let ref_energy = mean_amp * mean_amp + 1e-9;
(diff_energy / ref_energy).sqrt().clamp(0.0, 1.0)
} else { 0.0 }
} else {
(intra_variance / (mean_amp * mean_amp + 1e-9)).sqrt().clamp(0.0, 1.0)
};
let variance_motion = (temporal_variance / 10.0).clamp(0.0, 1.0);
let mbp_motion = (motion_band_power / 25.0).clamp(0.0, 1.0);
let cp_motion = (change_points as f64 / 15.0).clamp(0.0, 1.0);
let motion_score = (temporal_motion_score * 0.4 + variance_motion * 0.2
+ mbp_motion * 0.25 + cp_motion * 0.15).clamp(0.0, 1.0);
let snr_db = (frame.rssi as f64 - frame.noise_floor as f64).max(0.0);
let snr_quality = (snr_db / 40.0).clamp(0.0, 1.0);
let stability = (1.0 - (temporal_variance / (mean_amp * mean_amp + 1e-9)).clamp(0.0, 1.0)).max(0.0);
let signal_quality = (snr_quality * 0.6 + stability * 0.4).clamp(0.0, 1.0);
let breathing_rate_hz = estimate_breathing_rate_hz(frame_history, sample_rate_hz);
let features = FeatureInfo {
mean_rssi, variance, motion_band_power, breathing_band_power,
dominant_freq_hz, change_points, spectral_power,
};
let raw_classification = ClassificationInfo {
motion_level: raw_classify(motion_score),
presence: motion_score > 0.04,
confidence: (0.4 + signal_quality * 0.3 + motion_score * 0.3).clamp(0.0, 1.0),
};
(features, raw_classification, breathing_rate_hz, sub_variances, motion_score)
}
// ── Classification ──────────────────────────────────────────────────────────
pub fn raw_classify(score: f64) -> String {
if score > 0.25 { "active".into() }
else if score > 0.12 { "present_moving".into() }
else if score > 0.04 { "present_still".into() }
else { "absent".into() }
}
pub fn smooth_and_classify(state: &mut AppStateInner, raw: &mut ClassificationInfo, raw_motion: f64) {
state.baseline_frames += 1;
if state.baseline_frames < BASELINE_WARMUP {
state.baseline_motion = state.baseline_motion * 0.9 + raw_motion * 0.1;
} else if raw_motion < state.smoothed_motion + 0.05 {
state.baseline_motion = state.baseline_motion * (1.0 - BASELINE_EMA_ALPHA)
+ raw_motion * BASELINE_EMA_ALPHA;
}
let adjusted = (raw_motion - state.baseline_motion * 0.7).max(0.0);
state.smoothed_motion = state.smoothed_motion * (1.0 - MOTION_EMA_ALPHA) + adjusted * MOTION_EMA_ALPHA;
let sm = state.smoothed_motion;
let candidate = raw_classify(sm);
if candidate == state.current_motion_level {
state.debounce_counter = 0;
state.debounce_candidate = candidate;
} else if candidate == state.debounce_candidate {
state.debounce_counter += 1;
if state.debounce_counter >= DEBOUNCE_FRAMES {
state.current_motion_level = candidate;
state.debounce_counter = 0;
}
} else {
state.debounce_candidate = candidate;
state.debounce_counter = 1;
}
raw.motion_level = state.current_motion_level.clone();
raw.presence = sm > 0.03;
raw.confidence = (0.4 + sm * 0.6).clamp(0.0, 1.0);
}
pub fn smooth_and_classify_node(ns: &mut NodeState, raw: &mut ClassificationInfo, raw_motion: f64) {
ns.baseline_frames += 1;
if ns.baseline_frames < BASELINE_WARMUP {
ns.baseline_motion = ns.baseline_motion * 0.9 + raw_motion * 0.1;
} else if raw_motion < ns.smoothed_motion + 0.05 {
ns.baseline_motion = ns.baseline_motion * (1.0 - BASELINE_EMA_ALPHA) + raw_motion * BASELINE_EMA_ALPHA;
}
let adjusted = (raw_motion - ns.baseline_motion * 0.7).max(0.0);
ns.smoothed_motion = ns.smoothed_motion * (1.0 - MOTION_EMA_ALPHA) + adjusted * MOTION_EMA_ALPHA;
let sm = ns.smoothed_motion;
let candidate = raw_classify(sm);
if candidate == ns.current_motion_level {
ns.debounce_counter = 0;
ns.debounce_candidate = candidate;
} else if candidate == ns.debounce_candidate {
ns.debounce_counter += 1;
if ns.debounce_counter >= DEBOUNCE_FRAMES {
ns.current_motion_level = candidate;
ns.debounce_counter = 0;
}
} else {
ns.debounce_candidate = candidate;
ns.debounce_counter = 1;
}
raw.motion_level = ns.current_motion_level.clone();
raw.presence = sm > 0.03;
raw.confidence = (0.4 + sm * 0.6).clamp(0.0, 1.0);
}
pub fn adaptive_override(state: &AppStateInner, features: &FeatureInfo, classification: &mut ClassificationInfo) {
if let Some(ref model) = state.adaptive_model {
let amps = state.frame_history.back().map(|v| v.as_slice()).unwrap_or(&[]);
let feat_arr = adaptive_classifier::features_from_runtime(
&serde_json::json!({
"variance": features.variance,
"motion_band_power": features.motion_band_power,
"breathing_band_power": features.breathing_band_power,
"spectral_power": features.spectral_power,
"dominant_freq_hz": features.dominant_freq_hz,
"change_points": features.change_points,
"mean_rssi": features.mean_rssi,
}),
amps,
);
let (label, conf) = model.classify(&feat_arr);
classification.motion_level = label.to_string();
classification.presence = label != "absent";
classification.confidence = (conf * 0.7 + classification.confidence * 0.3).clamp(0.0, 1.0);
}
}
// ── Vital signs smoothing ───────────────────────────────────────────────────
fn trimmed_mean(buf: &VecDeque<f64>) -> f64 {
if buf.is_empty() { return 0.0; }
let mut sorted: Vec<f64> = buf.iter().copied().collect();
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let n = sorted.len();
let trim = n / 4;
let middle = &sorted[trim..n - trim.max(0)];
if middle.is_empty() { sorted[n / 2] } else { middle.iter().sum::<f64>() / middle.len() as f64 }
}
pub fn smooth_vitals(state: &mut AppStateInner, raw: &VitalSigns) -> VitalSigns {
let raw_hr = raw.heart_rate_bpm.unwrap_or(0.0);
let raw_br = raw.breathing_rate_bpm.unwrap_or(0.0);
let hr_ok = state.smoothed_hr < 1.0 || (raw_hr - state.smoothed_hr).abs() < HR_MAX_JUMP;
let br_ok = state.smoothed_br < 1.0 || (raw_br - state.smoothed_br).abs() < BR_MAX_JUMP;
if hr_ok && raw_hr > 0.0 {
state.hr_buffer.push_back(raw_hr);
if state.hr_buffer.len() > VITAL_MEDIAN_WINDOW { state.hr_buffer.pop_front(); }
}
if br_ok && raw_br > 0.0 {
state.br_buffer.push_back(raw_br);
if state.br_buffer.len() > VITAL_MEDIAN_WINDOW { state.br_buffer.pop_front(); }
}
let trimmed_hr = trimmed_mean(&state.hr_buffer);
let trimmed_br = trimmed_mean(&state.br_buffer);
if trimmed_hr > 0.0 {
if state.smoothed_hr < 1.0 { state.smoothed_hr = trimmed_hr; }
else if (trimmed_hr - state.smoothed_hr).abs() > HR_DEAD_BAND {
state.smoothed_hr = state.smoothed_hr * (1.0 - VITAL_EMA_ALPHA) + trimmed_hr * VITAL_EMA_ALPHA;
}
}
if trimmed_br > 0.0 {
if state.smoothed_br < 1.0 { state.smoothed_br = trimmed_br; }
else if (trimmed_br - state.smoothed_br).abs() > BR_DEAD_BAND {
state.smoothed_br = state.smoothed_br * (1.0 - VITAL_EMA_ALPHA) + trimmed_br * VITAL_EMA_ALPHA;
}
}
state.smoothed_hr_conf = state.smoothed_hr_conf * 0.92 + raw.heartbeat_confidence * 0.08;
state.smoothed_br_conf = state.smoothed_br_conf * 0.92 + raw.breathing_confidence * 0.08;
VitalSigns {
breathing_rate_bpm: if state.smoothed_br > 1.0 { Some(state.smoothed_br) } else { None },
heart_rate_bpm: if state.smoothed_hr > 1.0 { Some(state.smoothed_hr) } else { None },
breathing_confidence: state.smoothed_br_conf,
heartbeat_confidence: state.smoothed_hr_conf,
signal_quality: raw.signal_quality,
}
}
pub fn smooth_vitals_node(ns: &mut NodeState, raw: &VitalSigns) -> VitalSigns {
let raw_hr = raw.heart_rate_bpm.unwrap_or(0.0);
let raw_br = raw.breathing_rate_bpm.unwrap_or(0.0);
let hr_ok = ns.smoothed_hr < 1.0 || (raw_hr - ns.smoothed_hr).abs() < HR_MAX_JUMP;
let br_ok = ns.smoothed_br < 1.0 || (raw_br - ns.smoothed_br).abs() < BR_MAX_JUMP;
if hr_ok && raw_hr > 0.0 {
ns.hr_buffer.push_back(raw_hr);
if ns.hr_buffer.len() > VITAL_MEDIAN_WINDOW { ns.hr_buffer.pop_front(); }
}
if br_ok && raw_br > 0.0 {
ns.br_buffer.push_back(raw_br);
if ns.br_buffer.len() > VITAL_MEDIAN_WINDOW { ns.br_buffer.pop_front(); }
}
let trimmed_hr = trimmed_mean(&ns.hr_buffer);
let trimmed_br = trimmed_mean(&ns.br_buffer);
if trimmed_hr > 0.0 {
if ns.smoothed_hr < 1.0 { ns.smoothed_hr = trimmed_hr; }
else if (trimmed_hr - ns.smoothed_hr).abs() > HR_DEAD_BAND {
ns.smoothed_hr = ns.smoothed_hr * (1.0 - VITAL_EMA_ALPHA) + trimmed_hr * VITAL_EMA_ALPHA;
}
}
if trimmed_br > 0.0 {
if ns.smoothed_br < 1.0 { ns.smoothed_br = trimmed_br; }
else if (trimmed_br - ns.smoothed_br).abs() > BR_DEAD_BAND {
ns.smoothed_br = ns.smoothed_br * (1.0 - VITAL_EMA_ALPHA) + trimmed_br * VITAL_EMA_ALPHA;
}
}
ns.smoothed_hr_conf = ns.smoothed_hr_conf * 0.92 + raw.heartbeat_confidence * 0.08;
ns.smoothed_br_conf = ns.smoothed_br_conf * 0.92 + raw.breathing_confidence * 0.08;
VitalSigns {
breathing_rate_bpm: if ns.smoothed_br > 1.0 { Some(ns.smoothed_br) } else { None },
heart_rate_bpm: if ns.smoothed_hr > 1.0 { Some(ns.smoothed_hr) } else { None },
breathing_confidence: ns.smoothed_br_conf,
heartbeat_confidence: ns.smoothed_hr_conf,
signal_quality: raw.signal_quality,
}
}
// ── Multi-person estimation ─────────────────────────────────────────────────
pub fn fuse_multi_node_features(
current_features: &FeatureInfo, node_states: &HashMap<u8, NodeState>,
) -> FeatureInfo {
let now = std::time::Instant::now();
let active: Vec<(&FeatureInfo, f64)> = node_states.values()
.filter(|ns| ns.last_frame_time.map_or(false, |t| now.duration_since(t).as_secs() < 10))
.filter_map(|ns| {
let feat = ns.latest_features.as_ref()?;
let rssi = ns.rssi_history.back().copied().unwrap_or(-80.0);
Some((feat, rssi))
})
.collect();
if active.len() <= 1 { return current_features.clone(); }
let max_rssi = active.iter().map(|(_, r)| *r).fold(f64::NEG_INFINITY, f64::max);
let weights: Vec<f64> = active.iter()
.map(|(_, r)| (1.0 + (r - max_rssi + 20.0) / 20.0).clamp(0.1, 1.0)).collect();
let w_sum: f64 = weights.iter().sum::<f64>().max(1e-9);
FeatureInfo {
variance: active.iter().zip(&weights).map(|((f, _), w)| f.variance * w).sum::<f64>() / w_sum,
motion_band_power: active.iter().zip(&weights).map(|((f, _), w)| f.motion_band_power * w).sum::<f64>() / w_sum,
breathing_band_power: active.iter().zip(&weights).map(|((f, _), w)| f.breathing_band_power * w).sum::<f64>() / w_sum,
spectral_power: active.iter().zip(&weights).map(|((f, _), w)| f.spectral_power * w).sum::<f64>() / w_sum,
dominant_freq_hz: active.iter().zip(&weights).map(|((f, _), w)| f.dominant_freq_hz * w).sum::<f64>() / w_sum,
change_points: current_features.change_points,
mean_rssi: active.iter().map(|(f, _)| f.mean_rssi).fold(f64::NEG_INFINITY, f64::max),
}
}
pub fn compute_person_score(feat: &FeatureInfo) -> f64 {
let var_norm = (feat.variance / 300.0).clamp(0.0, 1.0);
let cp_norm = (feat.change_points as f64 / 30.0).clamp(0.0, 1.0);
let motion_norm = (feat.motion_band_power / 250.0).clamp(0.0, 1.0);
let sp_norm = (feat.spectral_power / 500.0).clamp(0.0, 1.0);
var_norm * 0.40 + cp_norm * 0.20 + motion_norm * 0.25 + sp_norm * 0.15
}
pub fn estimate_persons_from_correlation(frame_history: &VecDeque<Vec<f64>>) -> usize {
let n_frames = frame_history.len();
if n_frames < 10 { return 1; }
let window: Vec<&Vec<f64>> = frame_history.iter().rev().take(20).collect();
let n_sub = window[0].len().min(56);
if n_sub < 4 { return 1; }
let k = window.len() as f64;
let mut means = vec![0.0f64; n_sub];
let mut variances = vec![0.0f64; n_sub];
for frame in &window {
for sc in 0..n_sub.min(frame.len()) { means[sc] += frame[sc] / k; }
}
for frame in &window {
for sc in 0..n_sub.min(frame.len()) { variances[sc] += (frame[sc] - means[sc]).powi(2) / k; }
}
let noise_floor = 1.0;
let active: Vec<usize> = (0..n_sub).filter(|&sc| variances[sc] > noise_floor).collect();
let m = active.len();
if m < 3 { return if m == 0 { 0 } else { 1 }; }
let mut edges: Vec<(u64, u64, f64)> = Vec::new();
let source = m as u64;
let sink = (m + 1) as u64;
let stds: Vec<f64> = active.iter().map(|&sc| variances[sc].sqrt().max(1e-9)).collect();
for i in 0..m {
for j in (i + 1)..m {
let mut cov = 0.0f64;
for frame in &window {
let (si, sj) = (active[i], active[j]);
if si < frame.len() && sj < frame.len() {
cov += (frame[si] - means[si]) * (frame[sj] - means[sj]) / k;
}
}
let corr = (cov / (stds[i] * stds[j])).abs();
if corr > 0.1 {
let weight = corr * 10.0;
edges.push((i as u64, j as u64, weight));
edges.push((j as u64, i as u64, weight));
}
}
}
let (max_var_idx, _) = active.iter().enumerate()
.max_by(|(_, &a), (_, &b)| variances[a].partial_cmp(&variances[b]).unwrap())
.unwrap_or((0, &0));
let (min_var_idx, _) = active.iter().enumerate()
.min_by(|(_, &a), (_, &b)| variances[a].partial_cmp(&variances[b]).unwrap())
.unwrap_or((0, &0));
if max_var_idx == min_var_idx { return 1; }
edges.push((source, max_var_idx as u64, 100.0));
edges.push((min_var_idx as u64, sink, 100.0));
let mc: DynamicMinCut = match MinCutBuilder::new().exact().with_edges(edges.clone()).build() {
Ok(mc) => mc,
Err(_) => return 1,
};
let cut_value = mc.min_cut_value();
let total_edge_weight: f64 = edges.iter()
.filter(|(s, t, _)| *s != source && *s != sink && *t != source && *t != sink)
.map(|(_, _, w)| w).sum::<f64>() / 2.0;
if total_edge_weight < 1e-9 { return 1; }
let cut_ratio = cut_value / total_edge_weight;
if cut_ratio > 0.4 { 1 }
else if cut_ratio > 0.15 { 2 }
else { 3 }
}
pub fn score_to_person_count(smoothed_score: f64, prev_count: usize) -> usize {
match prev_count {
0 | 1 => {
if smoothed_score > 0.85 { 3 }
else if smoothed_score > 0.70 { 2 }
else { 1 }
}
2 => {
if smoothed_score > 0.92 { 3 }
else if smoothed_score < 0.55 { 1 }
else { 2 }
}
_ => {
if smoothed_score < 0.55 { 1 }
else if smoothed_score < 0.78 { 2 }
else { 3 }
}
}
}
/// Generate a simulated ESP32 frame for testing/demo mode.
pub fn generate_simulated_frame(tick: u64) -> Esp32Frame {
let t = tick as f64 * 0.1;
let n_sub = 56usize;
let mut amplitudes = Vec::with_capacity(n_sub);
let mut phases = Vec::with_capacity(n_sub);
for i in 0..n_sub {
let base = 15.0 + 5.0 * (i as f64 * 0.1 + t * 0.3).sin();
let noise = (i as f64 * 7.3 + t * 13.7).sin() * 2.0;
amplitudes.push((base + noise).max(0.1));
phases.push((i as f64 * 0.2 + t * 0.5).sin() * std::f64::consts::PI);
}
Esp32Frame {
magic: 0xC511_0001, node_id: 1, n_antennas: 1, n_subcarriers: n_sub as u8,
freq_mhz: 2437, sequence: tick as u32,
rssi: (-40.0 + 5.0 * (t * 0.2).sin()) as i8, noise_floor: -90,
amplitudes, phases,
}
}
/// Generate a simple timestamp (epoch seconds) for recording IDs.
pub fn chrono_timestamp() -> u64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0)
}

View file

@ -9,11 +9,15 @@
//! Replaces both ws_server.py and the Python HTTP server.
mod adaptive_classifier;
pub mod cli;
pub mod csi;
mod field_bridge;
mod multistatic_bridge;
pub mod pose;
mod rvf_container;
mod rvf_pipeline;
mod tracker_bridge;
pub mod types;
mod vital_signs;
// Training pipeline modules (exposed via lib.rs)

View file

@ -0,0 +1,194 @@
//! Skeleton derivation, pose estimation, and temporal smoothing.
use crate::types::*;
/// Expected bone lengths in pixel-space for the COCO-17 skeleton.
pub const POSE_BONE_PAIRS: &[(usize, usize)] = &[
(5, 7), (7, 9), (6, 8), (8, 10),
(5, 11), (6, 12),
(11, 13), (13, 15), (12, 14), (14, 16),
(5, 6), (11, 12),
];
const TORSO_KP: [usize; 4] = [5, 6, 11, 12];
const EXTREMITY_KP: [usize; 4] = [9, 10, 15, 16];
pub fn derive_single_person_pose(
update: &SensingUpdate, person_idx: usize, total_persons: usize,
) -> PersonDetection {
let cls = &update.classification;
let feat = &update.features;
let phase_offset = person_idx as f64 * 2.094;
let half = (total_persons as f64 - 1.0) / 2.0;
let person_x_offset = (person_idx as f64 - half) * 120.0;
let conf_decay = 1.0 - person_idx as f64 * 0.15;
let motion_score = (feat.motion_band_power / 15.0).clamp(0.0, 1.0);
let is_walking = motion_score > 0.55;
let breath_amp = (feat.breathing_band_power * 4.0).clamp(0.0, 12.0);
let breath_phase = if let Some(ref vs) = update.vital_signs {
let bpm = vs.breathing_rate_bpm.unwrap_or(15.0);
let freq = (bpm / 60.0).clamp(0.1, 0.5);
(update.tick as f64 * freq * 0.02 * std::f64::consts::TAU + phase_offset).sin()
} else {
(update.tick as f64 * 0.02 + phase_offset).sin()
};
let lean_x = (feat.dominant_freq_hz / 5.0 - 1.0).clamp(-1.0, 1.0) * 18.0;
let stride_x = if is_walking {
let stride_phase = (feat.motion_band_power * 0.7 + update.tick as f64 * 0.06 + phase_offset).sin();
stride_phase * 20.0 * motion_score
} else { 0.0 };
let burst = (feat.change_points as f64 / 20.0).clamp(0.0, 0.3);
let noise_seed = person_idx as f64 * 97.1;
let noise_val = (noise_seed.sin() * 43758.545).fract();
let snr_factor = ((feat.variance - 0.5) / 10.0).clamp(0.0, 1.0);
let base_confidence = cls.confidence * (0.6 + 0.4 * snr_factor) * conf_decay;
let base_x = 320.0 + stride_x + lean_x * 0.5 + person_x_offset;
let base_y = 240.0 - motion_score * 8.0;
let kp_names = [
"nose", "left_eye", "right_eye", "left_ear", "right_ear",
"left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
"left_wrist", "right_wrist", "left_hip", "right_hip",
"left_knee", "right_knee", "left_ankle", "right_ankle",
];
let kp_offsets: [(f64, f64); 17] = [
(0.0, -80.0), (-8.0, -88.0), (8.0, -88.0), (-16.0, -82.0), (16.0, -82.0),
(-30.0, -50.0), (30.0, -50.0), (-45.0, -15.0), (45.0, -15.0),
(-50.0, 20.0), (50.0, 20.0), (-20.0, 20.0), (20.0, 20.0),
(-22.0, 70.0), (22.0, 70.0), (-24.0, 120.0), (24.0, 120.0),
];
let keypoints: Vec<PoseKeypoint> = kp_names.iter().zip(kp_offsets.iter())
.enumerate()
.map(|(i, (name, (dx, dy)))| {
let breath_dx = if TORSO_KP.contains(&i) {
let sign = if *dx < 0.0 { -1.0 } else { 1.0 };
sign * breath_amp * breath_phase * 0.5
} else { 0.0 };
let breath_dy = if TORSO_KP.contains(&i) {
let sign = if *dy < 0.0 { -1.0 } else { 1.0 };
sign * breath_amp * breath_phase * 0.3
} else { 0.0 };
let extremity_jitter = if EXTREMITY_KP.contains(&i) {
let phase = noise_seed + i as f64 * 2.399;
(phase.sin() * burst * motion_score * 4.0, (phase * 1.31).cos() * burst * motion_score * 3.0)
} else { (0.0, 0.0) };
let kp_noise_x = ((noise_seed + i as f64 * 1.618).sin() * 43758.545).fract()
* feat.variance.sqrt().clamp(0.0, 3.0) * motion_score;
let kp_noise_y = ((noise_seed + i as f64 * 2.718).cos() * 31415.926).fract()
* feat.variance.sqrt().clamp(0.0, 3.0) * motion_score * 0.6;
let swing_dy = if is_walking {
let stride_phase = (feat.motion_band_power * 0.7 + update.tick as f64 * 0.12 + phase_offset).sin();
match i {
7 | 9 => -stride_phase * 20.0 * motion_score,
8 | 10 => stride_phase * 20.0 * motion_score,
13 | 15 => stride_phase * 25.0 * motion_score,
14 | 16 => -stride_phase * 25.0 * motion_score,
_ => 0.0,
}
} else { 0.0 };
let final_x = base_x + dx + breath_dx + extremity_jitter.0 + kp_noise_x;
let final_y = base_y + dy + breath_dy + extremity_jitter.1 + kp_noise_y + swing_dy;
let kp_conf = if EXTREMITY_KP.contains(&i) {
base_confidence * (0.7 + 0.3 * snr_factor) * (0.85 + 0.15 * noise_val)
} else {
base_confidence * (0.88 + 0.12 * ((i as f64 * 0.7 + noise_seed).cos()))
};
PoseKeypoint { name: name.to_string(), x: final_x, y: final_y, z: lean_x * 0.02, confidence: kp_conf.clamp(0.1, 1.0) }
})
.collect();
let xs: Vec<f64> = keypoints.iter().map(|k| k.x).collect();
let ys: Vec<f64> = keypoints.iter().map(|k| k.y).collect();
let min_x = xs.iter().cloned().fold(f64::MAX, f64::min) - 10.0;
let min_y = ys.iter().cloned().fold(f64::MAX, f64::min) - 10.0;
let max_x = xs.iter().cloned().fold(f64::MIN, f64::max) + 10.0;
let max_y = ys.iter().cloned().fold(f64::MIN, f64::max) + 10.0;
PersonDetection {
id: (person_idx + 1) as u32,
confidence: cls.confidence * conf_decay,
keypoints,
bbox: BoundingBox { x: min_x, y: min_y, width: (max_x - min_x).max(80.0), height: (max_y - min_y).max(160.0) },
zone: format!("zone_{}", person_idx + 1),
}
}
pub fn derive_pose_from_sensing(update: &SensingUpdate) -> Vec<PersonDetection> {
let cls = &update.classification;
if !cls.presence { return vec![]; }
let person_count = update.estimated_persons.unwrap_or(1).max(1);
(0..person_count).map(|idx| derive_single_person_pose(update, idx, person_count)).collect()
}
/// Apply temporal EMA smoothing and bone-length clamping to person detections.
pub fn apply_temporal_smoothing(persons: &mut [PersonDetection], ns: &mut NodeState) {
if persons.is_empty() { return; }
let alpha = ns.ema_alpha();
let person = &mut persons[0];
let current_kps: Vec<[f64; 3]> = person.keypoints.iter()
.map(|kp| [kp.x, kp.y, kp.z]).collect();
let smoothed = if let Some(ref prev) = ns.prev_keypoints {
let mut out = Vec::with_capacity(current_kps.len());
for (cur, prv) in current_kps.iter().zip(prev.iter()) {
out.push([
alpha * cur[0] + (1.0 - alpha) * prv[0],
alpha * cur[1] + (1.0 - alpha) * prv[1],
alpha * cur[2] + (1.0 - alpha) * prv[2],
]);
}
clamp_bone_lengths_f64(&mut out, prev);
out
} else {
current_kps.clone()
};
for (kp, s) in person.keypoints.iter_mut().zip(smoothed.iter()) {
kp.x = s[0]; kp.y = s[1]; kp.z = s[2];
}
ns.prev_keypoints = Some(smoothed);
}
fn clamp_bone_lengths_f64(pose: &mut Vec<[f64; 3]>, prev: &[[f64; 3]]) {
for &(p, c) in POSE_BONE_PAIRS {
if p >= pose.len() || c >= pose.len() { continue; }
let prev_len = dist_f64(&prev[p], &prev[c]);
if prev_len < 1e-6 { continue; }
let cur_len = dist_f64(&pose[p], &pose[c]);
if cur_len < 1e-6 { continue; }
let ratio = cur_len / prev_len;
let lo = 1.0 - MAX_BONE_CHANGE_RATIO;
let hi = 1.0 + MAX_BONE_CHANGE_RATIO;
if ratio < lo || ratio > hi {
let target = prev_len * ratio.clamp(lo, hi);
let scale = target / cur_len;
for dim in 0..3 {
let diff = pose[c][dim] - pose[p][dim];
pose[c][dim] = pose[p][dim] + diff * scale;
}
}
}
}
fn dist_f64(a: &[f64; 3], b: &[f64; 3]) -> f64 {
let dx = b[0] - a[0];
let dy = b[1] - a[1];
let dz = b[2] - a[2];
(dx * dx + dy * dy + dz * dz).sqrt()
}

View file

@ -0,0 +1,403 @@
//! Data types, constants, and shared state definitions.
use std::collections::{HashMap, VecDeque};
use std::path::PathBuf;
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use tokio::sync::{broadcast, RwLock};
use crate::adaptive_classifier;
use crate::rvf_container::RvfContainerInfo;
use crate::rvf_pipeline::ProgressiveLoader;
use crate::vital_signs::{VitalSignDetector, VitalSigns};
use wifi_densepose_signal::ruvsense::pose_tracker::PoseTracker;
use wifi_densepose_signal::ruvsense::multistatic::MultistaticFuser;
use wifi_densepose_signal::ruvsense::field_model::FieldModel;
// ── Constants ───────────────────────────────────────────────────────────────
/// Number of frames retained in `frame_history` for temporal analysis.
pub const FRAME_HISTORY_CAPACITY: usize = 100;
/// If no ESP32 frame arrives within this duration, source reverts to offline.
pub const ESP32_OFFLINE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);
/// Default EMA alpha for temporal keypoint smoothing (RuVector Phase 2).
pub const TEMPORAL_EMA_ALPHA_DEFAULT: f64 = 0.15;
/// Reduced EMA alpha when coherence is low.
pub const TEMPORAL_EMA_ALPHA_LOW_COHERENCE: f64 = 0.05;
/// Coherence threshold below which we reduce EMA alpha.
pub const COHERENCE_LOW_THRESHOLD: f64 = 0.3;
/// Maximum allowed bone-length change ratio between frames (20%).
pub const MAX_BONE_CHANGE_RATIO: f64 = 0.20;
/// Number of motion_energy frames to track for coherence scoring.
pub const COHERENCE_WINDOW: usize = 20;
/// Debounce frames required before state transition (at ~10 FPS = ~0.4s).
pub const DEBOUNCE_FRAMES: u32 = 4;
/// EMA alpha for motion smoothing (~1s time constant at 10 FPS).
pub const MOTION_EMA_ALPHA: f64 = 0.15;
/// EMA alpha for slow-adapting baseline (~30s time constant at 10 FPS).
pub const BASELINE_EMA_ALPHA: f64 = 0.003;
/// Number of warm-up frames before baseline subtraction kicks in.
pub const BASELINE_WARMUP: u64 = 50;
/// Size of the median filter window for vital signs outlier rejection.
pub const VITAL_MEDIAN_WINDOW: usize = 21;
/// EMA alpha for vital signs (~5s time constant at 10 FPS).
pub const VITAL_EMA_ALPHA: f64 = 0.02;
/// Maximum BPM jump per frame before a value is rejected as an outlier.
pub const HR_MAX_JUMP: f64 = 8.0;
pub const BR_MAX_JUMP: f64 = 2.0;
/// Minimum change from current smoothed value before EMA updates (dead-band).
pub const HR_DEAD_BAND: f64 = 2.0;
pub const BR_DEAD_BAND: f64 = 0.5;
// ── ESP32 Frame ─────────────────────────────────────────────────────────────
/// ADR-018 ESP32 CSI binary frame header (20 bytes)
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct Esp32Frame {
pub magic: u32,
pub node_id: u8,
pub n_antennas: u8,
pub n_subcarriers: u8,
pub freq_mhz: u16,
pub sequence: u32,
pub rssi: i8,
pub noise_floor: i8,
pub amplitudes: Vec<f64>,
pub phases: Vec<f64>,
}
// ── Sensing Update ──────────────────────────────────────────────────────────
/// Sensing update broadcast to WebSocket clients
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SensingUpdate {
#[serde(rename = "type")]
pub msg_type: String,
pub timestamp: f64,
pub source: String,
pub tick: u64,
pub nodes: Vec<NodeInfo>,
pub features: FeatureInfo,
pub classification: ClassificationInfo,
pub signal_field: SignalField,
#[serde(skip_serializing_if = "Option::is_none")]
pub vital_signs: Option<VitalSigns>,
#[serde(skip_serializing_if = "Option::is_none")]
pub enhanced_motion: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub enhanced_breathing: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub posture: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub signal_quality_score: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub quality_verdict: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub bssid_count: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub pose_keypoints: Option<Vec<[f64; 4]>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub model_status: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub persons: Option<Vec<PersonDetection>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub estimated_persons: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub node_features: Option<Vec<PerNodeFeatureInfo>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeInfo {
pub node_id: u8,
pub rssi_dbm: f64,
pub position: [f64; 3],
pub amplitude: Vec<f64>,
pub subcarrier_count: usize,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureInfo {
pub mean_rssi: f64,
pub variance: f64,
pub motion_band_power: f64,
pub breathing_band_power: f64,
pub dominant_freq_hz: f64,
pub change_points: usize,
pub spectral_power: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClassificationInfo {
pub motion_level: String,
pub presence: bool,
pub confidence: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SignalField {
pub grid_size: [usize; 3],
pub values: Vec<f64>,
}
/// WiFi-derived pose keypoint (17 COCO keypoints)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PoseKeypoint {
pub name: String,
pub x: f64,
pub y: f64,
pub z: f64,
pub confidence: f64,
}
/// Person detection from WiFi sensing
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersonDetection {
pub id: u32,
pub confidence: f64,
pub keypoints: Vec<PoseKeypoint>,
pub bbox: BoundingBox,
pub zone: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BoundingBox {
pub x: f64,
pub y: f64,
pub width: f64,
pub height: f64,
}
/// Per-node feature info for WebSocket broadcasts (multi-node support).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerNodeFeatureInfo {
pub node_id: u8,
pub features: FeatureInfo,
pub classification: ClassificationInfo,
pub rssi_dbm: f64,
pub last_seen_ms: u64,
pub frame_rate_hz: f64,
pub stale: bool,
}
// ── ESP32 Edge Vitals Packet (ADR-039) ──────────────────────────────────────
/// Decoded vitals packet from ESP32 edge processing pipeline.
#[derive(Debug, Clone, Serialize)]
pub struct Esp32VitalsPacket {
pub node_id: u8,
pub presence: bool,
pub fall_detected: bool,
pub motion: bool,
pub breathing_rate_bpm: f64,
pub heartrate_bpm: f64,
pub rssi: i8,
pub n_persons: u8,
pub motion_energy: f32,
pub presence_score: f32,
pub timestamp_ms: u32,
}
/// Single WASM event (type + value).
#[derive(Debug, Clone, Serialize)]
pub struct WasmEvent {
pub event_type: u8,
pub value: f32,
}
/// Decoded WASM output packet from ESP32 Tier 3 runtime.
#[derive(Debug, Clone, Serialize)]
pub struct WasmOutputPacket {
pub node_id: u8,
pub module_id: u8,
pub events: Vec<WasmEvent>,
}
// ── Per-node state ──────────────────────────────────────────────────────────
/// Per-node sensing state for multi-node deployments (issue #249).
pub struct NodeState {
pub frame_history: VecDeque<Vec<f64>>,
pub smoothed_person_score: f64,
pub prev_person_count: usize,
pub smoothed_motion: f64,
pub current_motion_level: String,
pub debounce_counter: u32,
pub debounce_candidate: String,
pub baseline_motion: f64,
pub baseline_frames: u64,
pub smoothed_hr: f64,
pub smoothed_br: f64,
pub smoothed_hr_conf: f64,
pub smoothed_br_conf: f64,
pub hr_buffer: VecDeque<f64>,
pub br_buffer: VecDeque<f64>,
pub rssi_history: VecDeque<f64>,
pub vital_detector: VitalSignDetector,
pub latest_vitals: VitalSigns,
pub last_frame_time: Option<std::time::Instant>,
pub edge_vitals: Option<Esp32VitalsPacket>,
pub latest_features: Option<FeatureInfo>,
pub prev_keypoints: Option<Vec<[f64; 3]>>,
pub motion_energy_history: VecDeque<f64>,
pub coherence_score: f64,
}
impl NodeState {
pub fn new() -> Self {
Self {
frame_history: VecDeque::new(),
smoothed_person_score: 0.0,
prev_person_count: 0,
smoothed_motion: 0.0,
current_motion_level: "absent".to_string(),
debounce_counter: 0,
debounce_candidate: "absent".to_string(),
baseline_motion: 0.0,
baseline_frames: 0,
smoothed_hr: 0.0,
smoothed_br: 0.0,
smoothed_hr_conf: 0.0,
smoothed_br_conf: 0.0,
hr_buffer: VecDeque::with_capacity(8),
br_buffer: VecDeque::with_capacity(8),
rssi_history: VecDeque::new(),
vital_detector: VitalSignDetector::new(10.0),
latest_vitals: VitalSigns::default(),
last_frame_time: None,
edge_vitals: None,
latest_features: None,
prev_keypoints: None,
motion_energy_history: VecDeque::with_capacity(COHERENCE_WINDOW),
coherence_score: 1.0,
}
}
/// Update the coherence score from the latest motion_energy value.
pub fn update_coherence(&mut self, motion_energy: f64) {
if self.motion_energy_history.len() >= COHERENCE_WINDOW {
self.motion_energy_history.pop_front();
}
self.motion_energy_history.push_back(motion_energy);
let n = self.motion_energy_history.len();
if n < 2 {
self.coherence_score = 1.0;
return;
}
let mean: f64 = self.motion_energy_history.iter().sum::<f64>() / n as f64;
let variance: f64 = self.motion_energy_history.iter()
.map(|v| (v - mean) * (v - mean))
.sum::<f64>() / (n - 1) as f64;
self.coherence_score = (1.0 / (1.0 + variance)).clamp(0.0, 1.0);
}
/// Choose the EMA alpha based on current coherence score.
pub fn ema_alpha(&self) -> f64 {
if self.coherence_score < COHERENCE_LOW_THRESHOLD {
TEMPORAL_EMA_ALPHA_LOW_COHERENCE
} else {
TEMPORAL_EMA_ALPHA_DEFAULT
}
}
}
// ── Shared application state ────────────────────────────────────────────────
/// Shared application state
pub struct AppStateInner {
pub latest_update: Option<SensingUpdate>,
pub rssi_history: VecDeque<f64>,
pub frame_history: VecDeque<Vec<f64>>,
pub tick: u64,
pub source: String,
pub last_esp32_frame: Option<std::time::Instant>,
pub tx: broadcast::Sender<String>,
pub total_detections: u64,
pub start_time: std::time::Instant,
pub vital_detector: VitalSignDetector,
pub latest_vitals: VitalSigns,
pub rvf_info: Option<RvfContainerInfo>,
pub save_rvf_path: Option<PathBuf>,
pub progressive_loader: Option<ProgressiveLoader>,
pub active_sona_profile: Option<String>,
pub model_loaded: bool,
pub smoothed_person_score: f64,
pub prev_person_count: usize,
pub smoothed_motion: f64,
pub current_motion_level: String,
pub debounce_counter: u32,
pub debounce_candidate: String,
pub baseline_motion: f64,
pub baseline_frames: u64,
pub smoothed_hr: f64,
pub smoothed_br: f64,
pub smoothed_hr_conf: f64,
pub smoothed_br_conf: f64,
pub hr_buffer: VecDeque<f64>,
pub br_buffer: VecDeque<f64>,
pub edge_vitals: Option<Esp32VitalsPacket>,
pub latest_wasm_events: Option<WasmOutputPacket>,
pub discovered_models: Vec<serde_json::Value>,
pub active_model_id: Option<String>,
pub recordings: Vec<serde_json::Value>,
pub recording_active: bool,
pub recording_start_time: Option<std::time::Instant>,
pub recording_current_id: Option<String>,
pub recording_stop_tx: Option<tokio::sync::watch::Sender<bool>>,
pub training_status: String,
pub training_config: Option<serde_json::Value>,
pub adaptive_model: Option<adaptive_classifier::AdaptiveModel>,
pub node_states: HashMap<u8, NodeState>,
pub pose_tracker: PoseTracker,
pub last_tracker_instant: Option<std::time::Instant>,
pub multistatic_fuser: MultistaticFuser,
pub field_model: Option<FieldModel>,
}
impl AppStateInner {
/// Return the effective data source, accounting for ESP32 frame timeout.
pub fn effective_source(&self) -> String {
if self.source == "esp32" {
if let Some(last) = self.last_esp32_frame {
if last.elapsed() > ESP32_OFFLINE_TIMEOUT {
return "esp32:offline".to_string();
}
}
}
self.source.clone()
}
/// Person count: eigenvalue-based if field model is calibrated, else heuristic.
pub fn person_count(&self) -> usize {
use crate::field_bridge;
use crate::csi::score_to_person_count;
match self.field_model.as_ref() {
Some(fm) => {
let history = if !self.frame_history.is_empty() {
&self.frame_history
} else {
self.node_states.values()
.filter(|ns| !ns.frame_history.is_empty())
.max_by_key(|ns| ns.last_frame_time)
.map(|ns| &ns.frame_history)
.unwrap_or(&self.frame_history)
};
field_bridge::occupancy_or_fallback(
fm, history, self.smoothed_person_score, self.prev_person_count,
)
}
None => score_to_person_count(self.smoothed_person_score, self.prev_person_count),
}
}
}
pub type SharedState = Arc<RwLock<AppStateInner>>;

View file

@ -339,9 +339,16 @@ impl RfTomographer {
/// Compute the intersection weights of a link with the voxel grid.
///
/// Uses a simplified approach: for each voxel, computes the minimum
/// distance from the voxel center to the link ray. Voxels within
/// one Fresnel zone receive weight proportional to closeness.
/// Uses a DDA (Digital Differential Analyzer) ray-marching algorithm:
/// 1. March along the ray from TX to RX, advancing to the nearest
/// axis-aligned voxel boundary at each step.
/// 2. At each ray voxel, expand by the Fresnel radius to check
/// neighboring voxels.
/// 3. Use a visited bitvector to avoid duplicate entries.
/// 4. Weight = `1.0 - dist / fresnel_radius` (same as before).
///
/// This is O(ray_length / voxel_size) instead of O(nx*ny*nz),
/// a significant speedup for large grids.
fn compute_link_weights(link: &LinkGeometry, config: &TomographyConfig) -> Vec<(usize, f64)> {
let vx = (config.bounds[3] - config.bounds[0]) / config.nx as f64;
let vy = (config.bounds[4] - config.bounds[1]) / config.ny as f64;
@ -356,25 +363,74 @@ fn compute_link_weights(link: &LinkGeometry, config: &TomographyConfig) -> Vec<(
let dy = link.rx.y - link.tx.y;
let dz = link.rx.z - link.tx.z;
let n_voxels = config.nx * config.ny * config.nz;
let mut visited = vec![false; n_voxels];
let mut weights = Vec::new();
for iz in 0..config.nz {
for iy in 0..config.ny {
for ix in 0..config.nx {
let cx = config.bounds[0] + (ix as f64 + 0.5) * vx;
let cy = config.bounds[1] + (iy as f64 + 0.5) * vy;
let cz = config.bounds[2] + (iz as f64 + 0.5) * vz;
// Fresnel expansion radius in voxel units.
let expand_x = (fresnel_radius / vx).ceil() as isize;
let expand_y = (fresnel_radius / vy).ceil() as isize;
let expand_z = (fresnel_radius / vz).ceil() as isize;
// Point-to-line distance
let dist = point_to_segment_distance(
cx, cy, cz, link.tx.x, link.tx.y, link.tx.z, dx, dy, dz, link_dist,
);
// DDA initialization: start at TX position in voxel coordinates.
let start_vx = (link.tx.x - config.bounds[0]) / vx;
let start_vy = (link.tx.y - config.bounds[1]) / vy;
let start_vz = (link.tx.z - config.bounds[2]) / vz;
if dist < fresnel_radius {
// Weight decays with distance from link ray
let w = 1.0 - dist / fresnel_radius;
let idx = iz * config.ny * config.nx + iy * config.nx + ix;
weights.push((idx, w));
let end_vx = (link.rx.x - config.bounds[0]) / vx;
let end_vy = (link.rx.y - config.bounds[1]) / vy;
let end_vz = (link.rx.z - config.bounds[2]) / vz;
let ray_dx = end_vx - start_vx;
let ray_dy = end_vy - start_vy;
let ray_dz = end_vz - start_vz;
// Number of DDA steps: traverse the maximum voxel span.
let steps = (ray_dx.abs().max(ray_dy.abs()).max(ray_dz.abs()).ceil() as usize).max(1);
let inv_steps = 1.0 / steps as f64;
for step in 0..=steps {
let t = step as f64 * inv_steps;
let rx = start_vx + t * ray_dx;
let ry = start_vy + t * ray_dy;
let rz = start_vz + t * ray_dz;
let base_ix = rx.floor() as isize;
let base_iy = ry.floor() as isize;
let base_iz = rz.floor() as isize;
// Expand by Fresnel radius to check neighboring voxels.
for diz in -expand_z..=expand_z {
let iz = base_iz + diz;
if iz < 0 || iz >= config.nz as isize { continue; }
for diy in -expand_y..=expand_y {
let iy = base_iy + diy;
if iy < 0 || iy >= config.ny as isize { continue; }
for dix in -expand_x..=expand_x {
let ix = base_ix + dix;
if ix < 0 || ix >= config.nx as isize { continue; }
let idx = iz as usize * config.ny * config.nx
+ iy as usize * config.nx
+ ix as usize;
if visited[idx] { continue; }
let cx = config.bounds[0] + (ix as f64 + 0.5) * vx;
let cy = config.bounds[1] + (iy as f64 + 0.5) * vy;
let cz = config.bounds[2] + (iz as f64 + 0.5) * vz;
let dist = point_to_segment_distance(
cx, cy, cz,
link.tx.x, link.tx.y, link.tx.z,
dx, dy, dz, link_dist,
);
if dist < fresnel_radius {
let w = 1.0 - dist / fresnel_radius;
weights.push((idx, w));
}
visited[idx] = true;
}
}
}

View file

@ -76,4 +76,31 @@ describe('MATScreen', () => {
// Simulated status maps to 'simulated' banner -> "SIMULATED DATA"
expect(getByText('SIMULATED DATA')).toBeTruthy();
});
it('shows simulation warning overlay when simulated and not acknowledged', () => {
// Reset store to ensure overlay is shown
const { useMatStore } = require('@/stores/matStore');
useMatStore.setState({ dataSource: 'simulated', simulationAcknowledged: false });
const { MATScreen } = require('@/screens/MATScreen');
const { getByText } = render(
<ThemeProvider>
<MATScreen />
</ThemeProvider>,
);
expect(getByText('I UNDERSTAND')).toBeTruthy();
});
it('hides overlay after acknowledgment', () => {
const { useMatStore } = require('@/stores/matStore');
useMatStore.setState({ dataSource: 'simulated', simulationAcknowledged: true });
const { MATScreen } = require('@/screens/MATScreen');
const { queryByText } = render(
<ThemeProvider>
<MATScreen />
</ThemeProvider>,
);
expect(queryByText('I UNDERSTAND')).toBeNull();
});
});

View file

@ -62,6 +62,8 @@ describe('useMatStore', () => {
survivors: [],
alerts: [],
selectedEventId: null,
dataSource: 'simulated',
simulationAcknowledged: false,
});
});
@ -195,4 +197,32 @@ describe('useMatStore', () => {
expect(useMatStore.getState().selectedEventId).toBeNull();
});
});
describe('dataSource', () => {
it('defaults to simulated', () => {
expect(useMatStore.getState().dataSource).toBe('simulated');
});
it('can be set to real', () => {
useMatStore.getState().setDataSource('real');
expect(useMatStore.getState().dataSource).toBe('real');
});
it('can be set back to simulated', () => {
useMatStore.getState().setDataSource('real');
useMatStore.getState().setDataSource('simulated');
expect(useMatStore.getState().dataSource).toBe('simulated');
});
});
describe('simulationAcknowledged', () => {
it('defaults to false', () => {
expect(useMatStore.getState().simulationAcknowledged).toBe(false);
});
it('can be acknowledged', () => {
useMatStore.getState().acknowledgeSimulation();
expect(useMatStore.getState().simulationAcknowledged).toBe(true);
});
});
});

View file

@ -0,0 +1,49 @@
import React, { useEffect, useRef } from 'react';
import { Animated, StyleSheet, Text, View } from 'react-native';
interface Props {
visible: boolean;
}
export const SimulationBanner: React.FC<Props> = ({ visible }) => {
const opacity = useRef(new Animated.Value(1)).current;
useEffect(() => {
if (!visible) return;
const pulse = Animated.loop(
Animated.sequence([
Animated.timing(opacity, { toValue: 0.4, duration: 800, useNativeDriver: true }),
Animated.timing(opacity, { toValue: 1.0, duration: 800, useNativeDriver: true }),
]),
);
pulse.start();
return () => pulse.stop();
}, [visible, opacity]);
if (!visible) return null;
return (
<Animated.View style={[styles.banner, { opacity }]}>
<Text style={styles.text}>SIMULATED DATA - NOT CONNECTED TO REAL SENSORS</Text>
</Animated.View>
);
};
const styles = StyleSheet.create({
banner: {
backgroundColor: '#e74c3c',
paddingVertical: 6,
paddingHorizontal: 12,
borderRadius: 6,
alignItems: 'center',
marginBottom: 8,
},
text: {
color: '#ffffff',
fontWeight: '700',
fontSize: 12,
letterSpacing: 0.5,
textAlign: 'center',
},
});

View file

@ -0,0 +1,78 @@
import React from 'react';
import { Modal, Pressable, StyleSheet, Text, View } from 'react-native';
interface Props {
visible: boolean;
onAcknowledge: () => void;
}
export const SimulationWarningOverlay: React.FC<Props> = ({ visible, onAcknowledge }) => (
<Modal visible={visible} transparent animationType="fade">
<View style={styles.backdrop}>
<View style={styles.card}>
<Text style={styles.icon}>&#9888;</Text>
<Text style={styles.title}>SIMULATED DATA</Text>
<Text style={styles.body}>
NOT CONNECTED TO REAL SENSORS{'\n\n'}
All survivor detections, vital signs, and alerts displayed on this screen are
generated from simulated data and do not reflect actual conditions.
</Text>
<Pressable style={styles.button} onPress={onAcknowledge}>
<Text style={styles.buttonText}>I UNDERSTAND</Text>
</Pressable>
</View>
</View>
</Modal>
);
const styles = StyleSheet.create({
backdrop: {
flex: 1,
backgroundColor: 'rgba(0,0,0,0.85)',
justifyContent: 'center',
alignItems: 'center',
padding: 24,
},
card: {
backgroundColor: '#1a1a2e',
borderRadius: 16,
padding: 32,
alignItems: 'center',
borderWidth: 2,
borderColor: '#e74c3c',
maxWidth: 420,
width: '100%',
},
icon: {
fontSize: 48,
color: '#e74c3c',
marginBottom: 12,
},
title: {
fontSize: 22,
fontWeight: '800',
color: '#e74c3c',
textAlign: 'center',
marginBottom: 16,
letterSpacing: 1,
},
body: {
fontSize: 15,
color: '#cccccc',
textAlign: 'center',
lineHeight: 22,
marginBottom: 28,
},
button: {
backgroundColor: '#e74c3c',
paddingHorizontal: 36,
paddingVertical: 14,
borderRadius: 8,
},
buttonText: {
color: '#ffffff',
fontWeight: '700',
fontSize: 16,
letterSpacing: 0.5,
},
});

View file

@ -10,6 +10,8 @@ import { type ConnectionStatus } from '@/types/sensing';
import { Alert, type Survivor } from '@/types/mat';
import { AlertList } from './AlertList';
import { MatWebView } from './MatWebView';
import { SimulationBanner } from './SimulationBanner';
import { SimulationWarningOverlay } from './SimulationWarningOverlay';
import { SurvivorCounter } from './SurvivorCounter';
import { useMatBridge } from './useMatBridge';
@ -47,6 +49,15 @@ export const MATScreen = () => {
const upsertSurvivor = useMatStore((state) => state.upsertSurvivor);
const addAlert = useMatStore((state) => state.addAlert);
const upsertEvent = useMatStore((state) => state.upsertEvent);
const dataSource = useMatStore((state) => state.dataSource);
const simulationAcknowledged = useMatStore((state) => state.simulationAcknowledged);
const setDataSource = useMatStore((state) => state.setDataSource);
const acknowledgeSimulation = useMatStore((state) => state.acknowledgeSimulation);
// Sync dataSource from connection status
useEffect(() => {
setDataSource(connectionStatus === 'connected' ? 'real' : 'simulated');
}, [connectionStatus, setDataSource]);
const { webViewRef, ready, onMessage, sendFrameUpdate, postEvent } = useMatBridge({
onSurvivorDetected: (survivor) => {
@ -113,8 +124,13 @@ export const MATScreen = () => {
const { height } = useWindowDimensions();
const webHeight = Math.max(240, Math.floor(height * 0.5));
const showOverlay = dataSource === 'simulated' && !simulationAcknowledged;
const showBanner = dataSource === 'simulated' && simulationAcknowledged;
return (
<ThemedView style={{ flex: 1, backgroundColor: colors.bg, padding: spacing.md }}>
<SimulationWarningOverlay visible={showOverlay} onAcknowledge={acknowledgeSimulation} />
<SimulationBanner visible={showBanner} />
<ConnectionBanner status={resolveBannerState(connectionStatus)} />
<View style={{ marginTop: 20 }}>
<SurvivorCounter survivors={survivors} />

View file

@ -7,11 +7,17 @@ export interface MatState {
survivors: Survivor[];
alerts: Alert[];
selectedEventId: string | null;
/** Whether data comes from real sensors or simulation. */
dataSource: 'real' | 'simulated';
/** Whether the user has dismissed the simulation warning overlay. */
simulationAcknowledged: boolean;
upsertEvent: (event: DisasterEvent) => void;
addZone: (zone: ScanZone) => void;
upsertSurvivor: (survivor: Survivor) => void;
addAlert: (alert: Alert) => void;
setSelectedEvent: (id: string | null) => void;
setDataSource: (source: 'real' | 'simulated') => void;
acknowledgeSimulation: () => void;
}
export const useMatStore = create<MatState>((set) => ({
@ -20,6 +26,8 @@ export const useMatStore = create<MatState>((set) => ({
survivors: [],
alerts: [],
selectedEventId: null,
dataSource: 'simulated',
simulationAcknowledged: false,
upsertEvent: (event) => {
set((state) => {
@ -71,4 +79,12 @@ export const useMatStore = create<MatState>((set) => ({
setSelectedEvent: (id) => {
set({ selectedEventId: id });
},
setDataSource: (source) => {
set({ dataSource: source });
},
acknowledgeSimulation: () => {
set({ simulationAcknowledged: true });
},
}));

View file

@ -17,7 +17,7 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
from src.config.settings import get_settings
from src.config.domains import get_domain_config
from src.api.routers import pose, stream, health
from src.api.routers import pose, stream, health, auth
from src.api.middleware.auth import AuthMiddleware
from src.api.middleware.rate_limit import RateLimitMiddleware
from src.api.dependencies import get_pose_service, get_stream_service, get_hardware_service
@ -263,6 +263,12 @@ app.include_router(
tags=["Streaming"]
)
app.include_router(
auth.router,
prefix=f"{settings.api_prefix}",
tags=["Authentication"]
)
# Root endpoint
@app.get("/")

View file

@ -189,7 +189,11 @@ class AuthMiddleware(BaseHTTPMiddleware):
self.settings.secret_key,
algorithms=[self.settings.jwt_algorithm]
)
# Check token blacklist (logout invalidation)
if token_blacklist.is_blacklisted(token):
raise ValueError("Token has been revoked")
# Extract user information
user_id = payload.get("sub")
if not user_id:

View file

@ -2,6 +2,6 @@
API routers package
"""
from . import pose, stream, health
from . import pose, stream, health, auth
__all__ = ["pose", "stream", "health"]
__all__ = ["pose", "stream", "health", "auth"]

View file

@ -0,0 +1,32 @@
"""
Authentication router for WiFi-DensePose API.
Provides logout (token blacklisting) endpoint.
"""
import logging
from typing import Optional
from fastapi import APIRouter, Request, HTTPException, status
from src.api.middleware.auth import token_blacklist
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/auth", tags=["auth"])
@router.post("/logout")
async def logout(request: Request):
"""Logout by blacklisting the current Bearer token."""
auth_header = request.headers.get("authorization")
if not auth_header or not auth_header.startswith("Bearer "):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing or invalid Authorization header",
)
token = auth_header.split(" ", 1)[1]
token_blacklist.add_token(token)
logger.info("Token blacklisted via /auth/logout")
return {"success": True, "message": "Token revoked"}

View file

@ -1,6 +1,7 @@
"""CSI data processor for WiFi-DensePose system using TDD approach."""
import asyncio
import itertools
import logging
import numpy as np
from datetime import datetime, timezone
@ -293,7 +294,8 @@ class CSIProcessor:
if count >= len(self.csi_history):
return list(self.csi_history)
else:
return list(self.csi_history)[-count:]
start = len(self.csi_history) - count
return list(itertools.islice(self.csi_history, start, len(self.csi_history)))
def get_processing_statistics(self) -> Dict[str, Any]:
"""Get processing statistics.
@ -410,8 +412,9 @@ class CSIProcessor:
# Use cached mean-phase values (pre-computed in add_to_history)
# Only take the last doppler_window frames for bounded cost
window = min(len(self._phase_cache), self._doppler_window)
cache_list = list(self._phase_cache)
phase_matrix = np.array(cache_list[-window:])
start = len(self._phase_cache) - window
cache_list = list(itertools.islice(self._phase_cache, start, len(self._phase_cache)))
phase_matrix = np.array(cache_list)
# Temporal phase differences between consecutive frames
phase_diffs = np.diff(phase_matrix, axis=0)

View file

@ -56,6 +56,10 @@ class TokenManager:
"""Verify and decode JWT token."""
try:
payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
# Check token blacklist (logout invalidation)
from src.api.middleware.auth import token_blacklist
if token_blacklist.is_blacklisted(token):
raise AuthenticationError("Token has been revoked")
return payload
except JWTError as e:
logger.warning(f"JWT verification failed: {e}")

View file

@ -0,0 +1,135 @@
"""Frame budget benchmark for CSI processing pipeline.
Verifies that per-frame CSI processing stays within the 50 ms budget
required for real-time sensing at 20 FPS.
"""
import time
import statistics
import pytest
import numpy as np
from src.core.csi_processor import CSIProcessor
def _make_config():
return {
"sampling_rate": 1000,
"window_size": 256,
"overlap": 0.5,
"noise_threshold": -60,
"human_detection_threshold": 0.8,
"smoothing_factor": 0.9,
"max_history_size": 500,
"num_subcarriers": 256,
"num_antennas": 3,
"doppler_window": 64,
}
def _make_csi_data(n_subcarriers=256, n_antennas=3, seed=None):
"""Generate a synthetic CSI frame with complex-valued subcarriers."""
rng = np.random.default_rng(seed)
from unittest.mock import MagicMock
csi = MagicMock()
csi.amplitude = rng.random((n_antennas, n_subcarriers)).astype(np.float64) * 20.0
csi.phase = (rng.random((n_antennas, n_subcarriers)).astype(np.float64) - 0.5) * np.pi * 2
csi.frequency = 5.0e9
csi.bandwidth = 80e6
csi.num_subcarriers = n_subcarriers
csi.num_antennas = n_antennas
csi.snr = 25.0
csi.timestamp = time.time()
csi.metadata = {}
return csi
class TestSingleFrameBudget:
"""Single-frame processing must complete in < 50 ms."""
def test_single_frame_under_50ms(self):
proc = CSIProcessor(config=_make_config())
frame = _make_csi_data(seed=42)
# Warm up
proc.preprocess_csi_data(frame)
start = time.perf_counter()
proc.preprocess_csi_data(frame)
features = proc.extract_features(frame)
if features:
proc.detect_human_presence(features)
elapsed_ms = (time.perf_counter() - start) * 1000
assert elapsed_ms < 50, f"Single frame took {elapsed_ms:.1f} ms (budget: 50 ms)"
class TestSustainedFrameBudget:
"""Sustained 100-frame processing p95 must be < 50 ms per frame."""
def test_sustained_100_frames_p95(self):
proc = CSIProcessor(config=_make_config())
rng = np.random.default_rng(123)
n_frames = 100
latencies = []
for i in range(n_frames):
frame = _make_csi_data(seed=i)
start = time.perf_counter()
preprocessed = proc.preprocess_csi_data(frame)
features = proc.extract_features(preprocessed)
if features:
proc.detect_human_presence(features)
proc.add_to_history(frame)
elapsed_ms = (time.perf_counter() - start) * 1000
latencies.append(elapsed_ms)
p50 = statistics.median(latencies)
p95 = sorted(latencies)[int(0.95 * len(latencies))]
p99 = sorted(latencies)[int(0.99 * len(latencies))]
print(f"\n--- Sustained {n_frames}-frame benchmark ---")
print(f" p50: {p50:.2f} ms")
print(f" p95: {p95:.2f} ms")
print(f" p99: {p99:.2f} ms")
print(f" min: {min(latencies):.2f} ms")
print(f" max: {max(latencies):.2f} ms")
assert p95 < 50, f"p95 latency {p95:.1f} ms exceeds 50 ms budget"
class TestPipelineWithDoppler:
"""Full pipeline including Doppler estimation must stay within budget."""
def test_doppler_pipeline(self):
proc = CSIProcessor(config=_make_config())
n_frames = 100
latencies = []
# Fill history first
for i in range(20):
frame = _make_csi_data(seed=i + 1000)
proc.add_to_history(frame)
for i in range(n_frames):
frame = _make_csi_data(seed=i + 2000)
start = time.perf_counter()
preprocessed = proc.preprocess_csi_data(frame)
features = proc.extract_features(preprocessed)
if features:
proc.detect_human_presence(features)
proc.add_to_history(frame)
elapsed_ms = (time.perf_counter() - start) * 1000
latencies.append(elapsed_ms)
p50 = statistics.median(latencies)
p95 = sorted(latencies)[int(0.95 * len(latencies))]
p99 = sorted(latencies)[int(0.99 * len(latencies))]
print(f"\n--- Doppler pipeline benchmark ({n_frames} frames, 20 warmup) ---")
print(f" p50: {p50:.2f} ms")
print(f" p95: {p95:.2f} ms")
print(f" p99: {p99:.2f} ms")
# Doppler adds overhead but should still be within budget
assert p95 < 50, f"Doppler pipeline p95 {p95:.1f} ms exceeds 50 ms budget"

56
v1/tests/unit/conftest.py Normal file
View file

@ -0,0 +1,56 @@
"""Shared fixtures for unit tests."""
import os
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
# Set SECRET_KEY before any settings import
os.environ.setdefault("SECRET_KEY", "test-secret-key-for-unit-tests-only")
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests-only")
@pytest.fixture
def mock_settings():
"""Create a mock Settings object."""
settings = MagicMock()
settings.secret_key = "test-secret-key-for-unit-tests-only"
settings.jwt_algorithm = "HS256"
settings.jwt_expire_hours = 24
settings.app_name = "test-app"
settings.version = "0.1.0"
settings.is_production = False
settings.enable_rate_limiting = False
settings.enable_authentication = False
settings.rate_limit_requests = 100
settings.rate_limit_window = 60
settings.rate_limit_authenticated_requests = 1000
settings.allowed_hosts = ["*"]
settings.csi_buffer_size = 100
settings.stream_buffer_size = 100
settings.mock_hardware = True
settings.mock_pose_data = True
settings.enable_real_time_processing = False
settings.trusted_proxies = ["127.0.0.1"]
return settings
@pytest.fixture
def mock_domain_config():
"""Create a mock DomainConfig object."""
config = MagicMock()
config.pose_estimation = MagicMock()
config.streaming = MagicMock()
config.hardware = MagicMock()
return config
@pytest.fixture
def mock_redis():
"""Provide a mock Redis client."""
with patch("redis.Redis") as mock:
client = MagicMock()
client.ping.return_value = True
client.get.return_value = None
client.set.return_value = True
mock.return_value = client
yield client

View file

@ -0,0 +1,137 @@
"""Tests for AuthMiddleware and TokenManager."""
import pytest
import os
from unittest.mock import MagicMock, AsyncMock, patch
from datetime import datetime, timedelta
class TestTokenManager:
def test_create_token(self, mock_settings):
from src.middleware.auth import TokenManager
tm = TokenManager(mock_settings)
token = tm.create_access_token({"sub": "user1"})
assert isinstance(token, str)
assert len(token) > 0
def test_verify_valid_token(self, mock_settings):
from src.middleware.auth import TokenManager
tm = TokenManager(mock_settings)
token = tm.create_access_token({"sub": "user1", "role": "admin"})
payload = tm.verify_token(token)
assert payload["sub"] == "user1"
assert payload["role"] == "admin"
def test_verify_invalid_token(self, mock_settings):
from src.middleware.auth import TokenManager, AuthenticationError
tm = TokenManager(mock_settings)
with pytest.raises(AuthenticationError):
tm.verify_token("invalid.token.here")
def test_decode_claims(self, mock_settings):
from src.middleware.auth import TokenManager
tm = TokenManager(mock_settings)
token = tm.create_access_token({"sub": "user1"})
claims = tm.decode_token_claims(token)
assert claims is not None
assert claims["sub"] == "user1"
def test_decode_claims_invalid(self, mock_settings):
from src.middleware.auth import TokenManager
tm = TokenManager(mock_settings)
claims = tm.decode_token_claims("bad-token")
assert claims is None
def test_token_has_expiry(self, mock_settings):
from src.middleware.auth import TokenManager
tm = TokenManager(mock_settings)
token = tm.create_access_token({"sub": "user1"})
payload = tm.verify_token(token)
assert "exp" in payload
assert "iat" in payload
class TestUserManager:
def test_create_user(self):
from src.middleware.auth import UserManager
um = UserManager()
assert um.get_user("nonexistent") is None
def test_hash_password(self):
from src.middleware.auth import UserManager
hashed = UserManager.hash_password("secret123")
assert hashed != "secret123"
assert len(hashed) > 20
def test_verify_password(self):
from src.middleware.auth import UserManager
hashed = UserManager.hash_password("secret123")
assert UserManager.verify_password("secret123", hashed) is True
assert UserManager.verify_password("wrong", hashed) is False
class TestTokenBlacklist:
def test_add_and_check(self):
from src.api.middleware.auth import TokenBlacklist
bl = TokenBlacklist()
bl.add_token("tok123")
assert bl.is_blacklisted("tok123") is True
assert bl.is_blacklisted("tok456") is False
def test_blacklisted_token_rejected(self, mock_settings):
from src.middleware.auth import TokenManager, AuthenticationError
from src.api.middleware.auth import token_blacklist
tm = TokenManager(mock_settings)
token = tm.create_access_token({"sub": "user1"})
# Token should be valid
tm.verify_token(token)
# Blacklist it
token_blacklist.add_token(token)
with pytest.raises(AuthenticationError, match="revoked"):
tm.verify_token(token)
# Cleanup
token_blacklist._blacklisted_tokens.discard(token)
class TestAuthMiddleware:
def test_public_paths(self, mock_settings):
with patch("src.api.middleware.auth.get_settings", return_value=mock_settings):
from src.api.middleware.auth import AuthMiddleware
app = MagicMock()
mw = AuthMiddleware(app)
assert mw._is_public_path("/health") is True
assert mw._is_public_path("/docs") is True
assert mw._is_public_path("/api/v1/pose/analyze") is False
def test_protected_paths(self, mock_settings):
with patch("src.api.middleware.auth.get_settings", return_value=mock_settings):
from src.api.middleware.auth import AuthMiddleware
app = MagicMock()
mw = AuthMiddleware(app)
assert mw._is_protected_path("/api/v1/pose/analyze") is True
assert mw._is_protected_path("/health") is False
def test_extract_token_from_header(self, mock_settings):
with patch("src.api.middleware.auth.get_settings", return_value=mock_settings):
from src.api.middleware.auth import AuthMiddleware
app = MagicMock()
mw = AuthMiddleware(app)
request = MagicMock()
request.headers = {"authorization": "Bearer mytoken123"}
request.query_params = {}
request.cookies = {}
token = mw._extract_token(request)
assert token == "mytoken123"
def test_extract_token_missing(self, mock_settings):
with patch("src.api.middleware.auth.get_settings", return_value=mock_settings):
from src.api.middleware.auth import AuthMiddleware
app = MagicMock()
mw = AuthMiddleware(app)
request = MagicMock()
request.headers = {}
request.query_params = {}
request.cookies = {}
token = mw._extract_token(request)
assert token is None

View file

@ -0,0 +1,78 @@
"""Tests for error handling in the API layer."""
import pytest
from unittest.mock import MagicMock, patch
from fastapi.testclient import TestClient
class TestExceptionHandlers:
"""Test the exception handlers registered on the FastAPI app."""
def _get_app(self):
"""Import app lazily to avoid side effects."""
with patch("src.api.main.get_settings") as mock_gs, \
patch("src.api.main.get_domain_config") as mock_gdc, \
patch("src.api.main.get_pose_service") as mock_ps, \
patch("src.api.main.get_stream_service") as mock_ss, \
patch("src.api.main.get_hardware_service") as mock_hs, \
patch("src.api.main.connection_manager") as mock_cm, \
patch("src.api.main.PoseStreamHandler") as mock_psh:
mock_gs.return_value = MagicMock(
app_name="test", version="0.1", environment="test",
is_production=False, enable_rate_limiting=False,
enable_authentication=False, docs_url="/docs",
redoc_url="/redoc", openapi_url="/openapi.json",
api_prefix="/api/v1",
)
mock_gs.return_value.get_logging_config.return_value = {
"version": 1, "disable_existing_loggers": False,
"handlers": {}, "loggers": {},
}
mock_gs.return_value.get_cors_config.return_value = {
"allow_origins": ["*"], "allow_methods": ["*"],
"allow_headers": ["*"],
}
# Re-import to pick up patches
import importlib
import src.api.main as m
importlib.reload(m)
return m.app
class TestErrorResponseModel:
def test_error_json_structure(self):
"""Verify error JSON has code, message, type fields."""
error = {
"error": {
"code": 404,
"message": "Not found",
"type": "http_error"
}
}
assert error["error"]["code"] == 404
assert "message" in error["error"]
assert "type" in error["error"]
def test_validation_error_structure(self):
error = {
"error": {
"code": 422,
"message": "Validation error",
"type": "validation_error",
"details": []
}
}
assert error["error"]["type"] == "validation_error"
assert isinstance(error["error"]["details"], list)
def test_internal_error_masks_details(self):
"""In production, internal errors should not leak stack traces."""
error = {
"error": {
"code": 500,
"message": "Internal server error",
"type": "internal_error"
}
}
assert "traceback" not in str(error)
assert error["error"]["message"] == "Internal server error"

View file

@ -0,0 +1,65 @@
"""Tests for HardwareService."""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
class TestHardwareServiceInit:
def test_init(self, mock_settings, mock_domain_config):
mock_settings.mock_hardware = True
with patch("src.services.hardware_service.RouterInterface"):
from src.services.hardware_service import HardwareService
svc = HardwareService(mock_settings, mock_domain_config)
assert svc.is_running is False
assert svc.stats["total_samples"] == 0
assert svc.stats["connected_routers"] == 0
def test_stats_defaults(self, mock_settings, mock_domain_config):
mock_settings.mock_hardware = True
with patch("src.services.hardware_service.RouterInterface"):
from src.services.hardware_service import HardwareService
svc = HardwareService(mock_settings, mock_domain_config)
assert svc.stats["successful_samples"] == 0
assert svc.stats["failed_samples"] == 0
assert svc.stats["last_sample_time"] is None
class TestHardwareServiceLifecycle:
@pytest.mark.asyncio
async def test_start(self, mock_settings, mock_domain_config):
mock_settings.mock_hardware = True
with patch("src.services.hardware_service.RouterInterface"):
from src.services.hardware_service import HardwareService
svc = HardwareService(mock_settings, mock_domain_config)
svc._initialize_routers = AsyncMock()
svc._monitoring_loop = AsyncMock()
await svc.start()
assert svc.is_running is True
@pytest.mark.asyncio
async def test_double_start_idempotent(self, mock_settings, mock_domain_config):
mock_settings.mock_hardware = True
with patch("src.services.hardware_service.RouterInterface"):
from src.services.hardware_service import HardwareService
svc = HardwareService(mock_settings, mock_domain_config)
svc._initialize_routers = AsyncMock()
svc._monitoring_loop = AsyncMock()
await svc.start()
await svc.start() # idempotent
assert svc.is_running is True
class TestHardwareServiceRouter:
def test_no_routers_on_init(self, mock_settings, mock_domain_config):
mock_settings.mock_hardware = True
with patch("src.services.hardware_service.RouterInterface"):
from src.services.hardware_service import HardwareService
svc = HardwareService(mock_settings, mock_domain_config)
assert len(svc.router_interfaces) == 0
def test_max_recent_samples(self, mock_settings, mock_domain_config):
mock_settings.mock_hardware = True
with patch("src.services.hardware_service.RouterInterface"):
from src.services.hardware_service import HardwareService
svc = HardwareService(mock_settings, mock_domain_config)
assert svc.max_recent_samples == 1000

View file

@ -0,0 +1,67 @@
"""Tests for HealthCheckService."""
import pytest
from unittest.mock import MagicMock
class TestHealthCheckServiceInit:
def test_init(self, mock_settings):
from src.services.health_check import HealthCheckService
svc = HealthCheckService(mock_settings)
assert svc._initialized is False
assert svc._running is False
@pytest.mark.asyncio
async def test_initialize(self, mock_settings):
from src.services.health_check import HealthCheckService
svc = HealthCheckService(mock_settings)
await svc.initialize()
assert svc._initialized is True
assert "api" in svc._services
assert "database" in svc._services
assert "hardware" in svc._services
@pytest.mark.asyncio
async def test_double_initialize(self, mock_settings):
from src.services.health_check import HealthCheckService
svc = HealthCheckService(mock_settings)
await svc.initialize()
await svc.initialize() # idempotent
assert svc._initialized is True
class TestHealthCheckAggregation:
@pytest.mark.asyncio
async def test_services_registered(self, mock_settings):
from src.services.health_check import HealthCheckService, HealthStatus
svc = HealthCheckService(mock_settings)
await svc.initialize()
assert len(svc._services) == 6
for name, sh in svc._services.items():
assert sh.status == HealthStatus.UNKNOWN
@pytest.mark.asyncio
async def test_service_names(self, mock_settings):
from src.services.health_check import HealthCheckService
svc = HealthCheckService(mock_settings)
await svc.initialize()
expected = {"api", "database", "redis", "hardware", "pose", "stream"}
assert set(svc._services.keys()) == expected
class TestHealthStatus:
def test_enum_values(self):
from src.services.health_check import HealthStatus
assert HealthStatus.HEALTHY.value == "healthy"
assert HealthStatus.DEGRADED.value == "degraded"
assert HealthStatus.UNHEALTHY.value == "unhealthy"
assert HealthStatus.UNKNOWN.value == "unknown"
class TestHealthCheck:
def test_health_check_dataclass(self):
from src.services.health_check import HealthCheck, HealthStatus
hc = HealthCheck(name="test", status=HealthStatus.HEALTHY, message="ok")
assert hc.name == "test"
assert hc.status == HealthStatus.HEALTHY
assert hc.duration_ms == 0.0

View file

@ -0,0 +1,70 @@
"""Tests for MetricsService."""
import pytest
from datetime import timedelta
from unittest.mock import MagicMock, patch
class TestMetricSeries:
def test_add_point(self):
from src.services.metrics import MetricSeries
ms = MetricSeries(name="test", description="desc", unit="ms")
ms.add_point(42.0)
assert len(ms.points) == 1
assert ms.points[0].value == 42.0
def test_get_latest(self):
from src.services.metrics import MetricSeries
ms = MetricSeries(name="test", description="desc", unit="ms")
ms.add_point(1.0)
ms.add_point(2.0)
latest = ms.get_latest()
assert latest is not None
assert latest.value == 2.0
def test_get_latest_empty(self):
from src.services.metrics import MetricSeries
ms = MetricSeries(name="test", description="desc", unit="ms")
assert ms.get_latest() is None
def test_get_average(self):
from src.services.metrics import MetricSeries
ms = MetricSeries(name="test", description="desc", unit="ms")
for v in [10.0, 20.0, 30.0]:
ms.add_point(v)
avg = ms.get_average(timedelta(minutes=5))
assert avg == pytest.approx(20.0)
def test_get_average_empty(self):
from src.services.metrics import MetricSeries
ms = MetricSeries(name="test", description="desc", unit="ms")
assert ms.get_average(timedelta(minutes=5)) is None
def test_get_max(self):
from src.services.metrics import MetricSeries
ms = MetricSeries(name="test", description="desc", unit="ms")
for v in [10.0, 50.0, 30.0]:
ms.add_point(v)
mx = ms.get_max(timedelta(minutes=5))
assert mx == 50.0
def test_labels(self):
from src.services.metrics import MetricSeries
ms = MetricSeries(name="test", description="desc", unit="ms")
ms.add_point(1.0, {"region": "us-east"})
assert ms.points[0].labels["region"] == "us-east"
def test_maxlen(self):
from src.services.metrics import MetricSeries
ms = MetricSeries(name="test", description="desc", unit="ms")
for i in range(1100):
ms.add_point(float(i))
assert len(ms.points) == 1000
class TestMetricsService:
def test_init(self, mock_settings):
with patch("src.services.metrics.psutil"):
from src.services.metrics import MetricsService
svc = MetricsService(mock_settings)
assert svc._metrics is not None

View file

@ -0,0 +1,73 @@
"""Tests for PoseService."""
import pytest
import asyncio
from unittest.mock import MagicMock, AsyncMock, patch
from datetime import datetime
class TestPoseServiceInit:
def test_init_sets_defaults(self, mock_settings, mock_domain_config):
with patch.dict("sys.modules", {
"torch": MagicMock(),
"src.models.densepose_head": MagicMock(),
"src.models.modality_translation": MagicMock(),
}):
from src.services.pose_service import PoseService
svc = PoseService(mock_settings, mock_domain_config)
assert svc.is_initialized is False
assert svc.is_running is False
assert svc.stats["total_processed"] == 0
def test_stats_are_zero_on_init(self, mock_settings, mock_domain_config):
with patch.dict("sys.modules", {
"torch": MagicMock(),
"src.models.densepose_head": MagicMock(),
"src.models.modality_translation": MagicMock(),
}):
from src.services.pose_service import PoseService
svc = PoseService(mock_settings, mock_domain_config)
assert svc.stats["successful_detections"] == 0
assert svc.stats["failed_detections"] == 0
assert svc.stats["average_confidence"] == 0.0
class TestPoseServiceLifecycle:
@pytest.mark.asyncio
async def test_initialize_sets_flag(self, mock_settings, mock_domain_config):
with patch.dict("sys.modules", {
"torch": MagicMock(),
"src.models.densepose_head": MagicMock(),
"src.models.modality_translation": MagicMock(),
}):
from src.services.pose_service import PoseService
svc = PoseService(mock_settings, mock_domain_config)
await svc.initialize()
assert svc.is_initialized is True
@pytest.mark.asyncio
async def test_start_stop(self, mock_settings, mock_domain_config):
with patch.dict("sys.modules", {
"torch": MagicMock(),
"src.models.densepose_head": MagicMock(),
"src.models.modality_translation": MagicMock(),
}):
from src.services.pose_service import PoseService
svc = PoseService(mock_settings, mock_domain_config)
await svc.initialize()
await svc.start()
assert svc.is_running is True
await svc.stop()
assert svc.is_running is False
class TestPoseServiceStats:
def test_initial_classification(self, mock_settings, mock_domain_config):
with patch.dict("sys.modules", {
"torch": MagicMock(),
"src.models.densepose_head": MagicMock(),
"src.models.modality_translation": MagicMock(),
}):
from src.services.pose_service import PoseService
svc = PoseService(mock_settings, mock_domain_config)
assert svc.last_error is None

View file

@ -0,0 +1,62 @@
"""Tests for rate limiting middleware."""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
class TestRateLimitMiddleware:
def test_init(self, mock_settings):
with patch("src.api.middleware.rate_limit.get_settings", return_value=mock_settings):
from src.api.middleware.rate_limit import RateLimitMiddleware
app = MagicMock()
mw = RateLimitMiddleware(app)
assert "anonymous" in mw.rate_limits
assert "authenticated" in mw.rate_limits
assert "admin" in mw.rate_limits
def test_exempt_paths(self, mock_settings):
with patch("src.api.middleware.rate_limit.get_settings", return_value=mock_settings):
from src.api.middleware.rate_limit import RateLimitMiddleware
app = MagicMock()
mw = RateLimitMiddleware(app)
assert "/health" in mw.exempt_paths
assert "/metrics" in mw.exempt_paths
def test_is_exempt(self, mock_settings):
with patch("src.api.middleware.rate_limit.get_settings", return_value=mock_settings):
from src.api.middleware.rate_limit import RateLimitMiddleware
app = MagicMock()
mw = RateLimitMiddleware(app)
assert mw._is_exempt_path("/health") is True
assert mw._is_exempt_path("/api/v1/pose/current") is False
def test_path_specific_limits(self, mock_settings):
with patch("src.api.middleware.rate_limit.get_settings", return_value=mock_settings):
from src.api.middleware.rate_limit import RateLimitMiddleware
app = MagicMock()
mw = RateLimitMiddleware(app)
assert "/api/v1/pose/current" in mw.path_limits
assert mw.path_limits["/api/v1/pose/current"]["requests"] == 60
def test_trusted_proxies_not_blocked(self, mock_settings):
with patch("src.api.middleware.rate_limit.get_settings", return_value=mock_settings):
from src.api.middleware.rate_limit import RateLimitMiddleware
app = MagicMock()
mw = RateLimitMiddleware(app)
assert not mw._is_client_blocked("new-client-id")
class TestRateLimitConfig:
def test_anonymous_limit(self, mock_settings):
with patch("src.api.middleware.rate_limit.get_settings", return_value=mock_settings):
from src.api.middleware.rate_limit import RateLimitMiddleware
app = MagicMock()
mw = RateLimitMiddleware(app)
assert mw.rate_limits["anonymous"]["burst"] == 10
def test_admin_limit(self, mock_settings):
with patch("src.api.middleware.rate_limit.get_settings", return_value=mock_settings):
from src.api.middleware.rate_limit import RateLimitMiddleware
app = MagicMock()
mw = RateLimitMiddleware(app)
assert mw.rate_limits["admin"]["requests"] == 10000

View file

@ -0,0 +1,68 @@
"""Tests for StreamService."""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
class TestStreamServiceLifecycle:
def test_init(self, mock_settings, mock_domain_config):
from src.services.stream_service import StreamService
svc = StreamService(mock_settings, mock_domain_config)
assert svc.is_running is False
assert len(svc.connections) == 0
assert svc.stats["active_connections"] == 0
@pytest.mark.asyncio
async def test_initialize(self, mock_settings, mock_domain_config):
from src.services.stream_service import StreamService
svc = StreamService(mock_settings, mock_domain_config)
await svc.initialize()
@pytest.mark.asyncio
async def test_start(self, mock_settings, mock_domain_config):
mock_settings.enable_real_time_processing = False
from src.services.stream_service import StreamService
svc = StreamService(mock_settings, mock_domain_config)
await svc.start()
assert svc.is_running is True
@pytest.mark.asyncio
async def test_stop(self, mock_settings, mock_domain_config):
mock_settings.enable_real_time_processing = False
from src.services.stream_service import StreamService
svc = StreamService(mock_settings, mock_domain_config)
await svc.start()
await svc.stop()
assert svc.is_running is False
@pytest.mark.asyncio
async def test_double_start(self, mock_settings, mock_domain_config):
mock_settings.enable_real_time_processing = False
from src.services.stream_service import StreamService
svc = StreamService(mock_settings, mock_domain_config)
await svc.start()
await svc.start() # should be idempotent
assert svc.is_running is True
class TestStreamServiceConnections:
def test_no_connections_on_init(self, mock_settings, mock_domain_config):
from src.services.stream_service import StreamService
svc = StreamService(mock_settings, mock_domain_config)
assert svc.stats["total_connections"] == 0
assert svc.stats["messages_sent"] == 0
def test_buffer_sizes(self, mock_settings, mock_domain_config):
mock_settings.stream_buffer_size = 50
from src.services.stream_service import StreamService
svc = StreamService(mock_settings, mock_domain_config)
assert svc.pose_buffer.maxlen == 50
assert svc.csi_buffer.maxlen == 50
class TestStreamServiceBroadcast:
def test_stats_messages_failed_init_zero(self, mock_settings, mock_domain_config):
from src.services.stream_service import StreamService
svc = StreamService(mock_settings, mock_domain_config)
assert svc.stats["messages_failed"] == 0
assert svc.stats["data_points_streamed"] == 0