Pure JS implementation of WiFlow (arXiv:2602.08661) adapted for ESP32:
- TCN temporal encoder (dilated causal conv, k=7, dilations 1/2/4/8)
- Asymmetric spatial encoder (1x3 residual blocks, stride-2)
- Axial self-attention (width + height, 8 heads, 256 channels)
- Pose decoder (adaptive pooling → 17x2 COCO keypoints)
- SmoothL1 + bone constraint loss (14 skeleton connections)
- 1.8M params (1.6 MB at INT8), 198M FLOPs

Integrated with the camera-free pipeline (pose proxy labels from RSSI triangulation + subcarrier asymmetry + vibration).

Co-Authored-By: claude-flow <ruv@ruv.net>
#!/usr/bin/env node
/**
 * WiFlow Pose Estimation Training Pipeline
 *
 * Trains the WiFlow architecture (arXiv:2602.08661) on collected CSI data
 * using the ruvllm training infrastructure. Extends train-ruvllm.js patterns
 * with WiFlow-specific stages:
 *
 * Phase 0: CSI data loading + amplitude extraction + 20-frame windowing
 * Phase 1: Contrastive pretraining (temporal consistency loss)
 * Phase 2: Supervised pose training (SmoothL1 + bone constraint loss)
 * Phase 3: LoRA room-specific adaptation
 * Phase 4: Quantization (TurboQuant INT8 target: ~2.5 MB)
 * Phase 5: Export (SafeTensors + ONNX-compatible + RVF)
 *
 * Usage:
 *   node scripts/train-wiflow.js --data data/recordings/pretrain-*.csi.jsonl
 *   node scripts/train-wiflow.js --data data/recordings/*.csi.jsonl --epochs 50 --output models/wiflow-v1
 *   node scripts/train-wiflow.js --data data/recordings/*.csi.jsonl --contrastive-only
 *
 * ADR: docs/adr/ADR-072-wiflow-architecture.md
 */

'use strict';

const fs = require('fs');
const path = require('path');
const { parseArgs } = require('util');

// ---------------------------------------------------------------------------
// Resolve dependencies
// ---------------------------------------------------------------------------
const RUVLLM_PATH = path.resolve(__dirname, '..', 'vendor', 'ruvector', 'npm', 'packages', 'ruvllm', 'src');

const {
  ContrastiveTrainer,
  cosineSimilarity,
  tripletLoss,
  infoNCELoss,
  computeGradient,
} = require(path.join(RUVLLM_PATH, 'contrastive.js'));

const {
  TrainingPipeline,
} = require(path.join(RUVLLM_PATH, 'training.js'));

const {
  LoraAdapter,
  LoraManager,
} = require(path.join(RUVLLM_PATH, 'lora.js'));

const {
  EwcManager,
} = require(path.join(RUVLLM_PATH, 'sona.js'));

const {
  SafeTensorsWriter,
  ModelExporter,
} = require(path.join(RUVLLM_PATH, 'export.js'));

const {
  WiFlowModel,
  COCO_KEYPOINTS,
  BONE_CONNECTIONS,
  BONE_LENGTH_PRIORS,
  smoothL1,
  createRng,
  gaussianRng,
  estimateFLOPs,
} = require(path.join(__dirname, 'wiflow-model.js'));

// ---------------------------------------------------------------------------
// CLI argument parsing
// ---------------------------------------------------------------------------
const { values: args } = parseArgs({
  options: {
    data: { type: 'string', short: 'd' },
    output: { type: 'string', short: 'o', default: 'models/wiflow-v1' },
    epochs: { type: 'string', short: 'e', default: '30' },
    'batch-size': { type: 'string', default: '8' },
    'learning-rate': { type: 'string', default: '0.001' },
    'lora-rank': { type: 'string', default: '4' },
    'quantize-bits': { type: 'string', default: '8' },
    'contrastive-only': { type: 'boolean', default: false },
    'max-samples': { type: 'string', default: '0' },
    'time-steps': { type: 'string', default: '20' },
    'subcarriers': { type: 'string', default: '128' },
    seed: { type: 'string', default: '42' },
    verbose: { type: 'boolean', short: 'v', default: false },
  },
  strict: true,
});

if (!args.data) {
  console.error('Usage: node scripts/train-wiflow.js --data <path-to-csi-jsonl> [--output dir] [--epochs N]');
  process.exit(1);
}

const CONFIG = {
  dataGlob: args.data,
  outputDir: args.output,
  epochs: parseInt(args.epochs, 10),
  batchSize: parseInt(args['batch-size'], 10),
  learningRate: parseFloat(args['learning-rate']),
  loraRank: parseInt(args['lora-rank'], 10),
  quantizeBits: parseInt(args['quantize-bits'], 10),
  contrastiveOnly: args['contrastive-only'],
  maxSamples: parseInt(args['max-samples'], 10) || 0,
  timeSteps: parseInt(args['time-steps'], 10),
  subcarriers: parseInt(args['subcarriers'], 10),
  seed: parseInt(args.seed, 10),
  verbose: args.verbose,

  // Contrastive hyperparameters
  margin: 0.3,
  temperature: 0.07,
  contrastiveEpochs: 3,

  // Bone constraint weight
  boneWeight: 0.2,
};
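
// Example effective CONFIG (hypothetical invocation, shown for orientation):
//   node scripts/train-wiflow.js --data 'data/recordings/*.csi.jsonl' --epochs 50
// yields epochs: 50 with the remaining defaults, e.g. batchSize: 8,
// learningRate: 0.001, loraRank: 4, quantizeBits: 8, timeSteps: 20,
// subcarriers: 128, seed: 42, margin: 0.3, temperature: 0.07, boneWeight: 0.2.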

// ---------------------------------------------------------------------------
// Data loading and CSI amplitude extraction
// ---------------------------------------------------------------------------

/**
 * Parse a CSI JSONL file and extract raw CSI, feature, and vitals frames.
 */
function loadCsiData(filePath) {
  const rawCsi = [];
  const features = [];
  const vitals = [];

  const content = fs.readFileSync(filePath, 'utf-8');
  for (const line of content.split('\n')) {
    if (!line.trim()) continue;
    try {
      const frame = JSON.parse(line);
      switch (frame.type) {
        case 'raw_csi':
          rawCsi.push({
            timestamp: frame.timestamp,
            nodeId: frame.node_id,
            subcarriers: frame.subcarriers,
            iqHex: frame.iq_hex,
            rssi: frame.rssi,
          });
          break;
        case 'feature':
          features.push({
            timestamp: frame.timestamp,
            nodeId: frame.node_id,
            features: frame.features,
            rssi: frame.rssi,
          });
          break;
        case 'vitals':
          vitals.push({
            timestamp: frame.timestamp,
            nodeId: frame.node_id,
            presenceScore: frame.presence_score,
            motionEnergy: frame.motion_energy,
            breathingBpm: frame.breathing_bpm,
            heartrateBpm: frame.heartrate_bpm,
            nPersons: frame.n_persons,
          });
          break;
      }
    } catch (_) { /* skip malformed lines */ }
  }

  return { rawCsi, features, vitals };
}
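
// Illustrative JSONL records (field values are hypothetical; the key names
// match what the parser above reads):
//   {"type":"raw_csi","timestamp":1714.2,"node_id":1,"subcarriers":64,"iq_hex":"0000fe01...","rssi":-48}
//   {"type":"vitals","timestamp":1714.5,"node_id":1,"presence_score":0.82,"motion_energy":1.4,"breathing_bpm":14,"heartrate_bpm":62,"n_persons":1}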

/**
 * Parse an IQ hex string into complex pairs [I0, Q0, I1, Q1, ...].
 * Each I/Q value is a signed byte (two's complement).
 */
function parseIqHex(hexStr) {
  const bytes = [];
  for (let i = 0; i < hexStr.length; i += 2) {
    let val = parseInt(hexStr.slice(i, i + 2), 16);
    if (val > 127) val -= 256; // interpret as signed byte
    bytes.push(val);
  }
  return bytes;
}
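
// Worked example (illustrative input): parseIqHex('037f81')
//   '03' -> 3, '7f' -> 127, '81' -> 0x81 = 129 > 127, so 129 - 256 = -127
//   => [3, 127, -127]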

/**
 * Extract amplitude from IQ data for a given number of subcarriers.
 * Returns a Float32Array of amplitudes [nSubcarriers].
 * Skips the first I/Q pair (DC offset) per the WiFlow paper recommendation.
 */
function extractAmplitude(iqBytes, nSubcarriers) {
  const amp = new Float32Array(nSubcarriers);
  // Each subcarrier has 2 bytes (I, Q); the first pair is often DC/padding
  const start = 2; // skip first IQ pair (indices 0, 1)
  for (let sc = 0; sc < nSubcarriers; sc++) {
    const idx = start + sc * 2;
    if (idx + 1 < iqBytes.length) {
      const I = iqBytes[idx];
      const Q = iqBytes[idx + 1];
      amp[sc] = Math.sqrt(I * I + Q * Q); // |I + jQ|
    }
  }
  return amp;
}
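
// Worked example (illustrative): a subcarrier with I = 3, Q = -4 yields
// amplitude sqrt(3*3 + (-4)*(-4)) = 5. Phase is discarded; only the
// magnitude feeds the model.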

/**
 * Normalize amplitude to zero-mean, unit-variance per subcarrier across the time window.
 */
function normalizeAmplitude(window) {
  // window: array of Float32Array [nSubcarriers]
  const T = window.length;
  if (T === 0) return [];
  const nSc = window[0].length;

  // Compute per-subcarrier mean and std
  const mean = new Float32Array(nSc);
  const std = new Float32Array(nSc);
  for (let sc = 0; sc < nSc; sc++) {
    let sum = 0;
    for (let t = 0; t < T; t++) sum += window[t][sc];
    mean[sc] = sum / T;
    let varSum = 0;
    for (let t = 0; t < T; t++) varSum += (window[t][sc] - mean[sc]) ** 2;
    std[sc] = Math.sqrt(varSum / T) || 1; // guard against zero variance
  }

  return window.map(frame => {
    const normed = new Float32Array(nSc);
    for (let sc = 0; sc < nSc; sc++) {
      normed[sc] = (frame[sc] - mean[sc]) / std[sc];
    }
    return normed;
  });
}
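
// Worked example (illustrative): a subcarrier whose amplitudes over the
// window are [4, 6] has mean 5 and std 1, so it normalizes to [-1, 1];
// a flat subcarrier (std 0) falls back to std = 1 and normalizes to all zeros.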

/**
 * Create sliding windows of CSI amplitude data.
 * Returns an array of { input: Float32Array[nSc * T], timestamp, nodeId }.
 */
function createWindows(rawCsi, nSubcarriers, timeSteps) {
  // Group by nodeId, sort by timestamp
  const byNode = {};
  for (const frame of rawCsi) {
    if (!byNode[frame.nodeId]) byNode[frame.nodeId] = [];
    byNode[frame.nodeId].push(frame);
  }

  const windows = [];

  for (const nodeId of Object.keys(byNode)) {
    const frames = byNode[nodeId].sort((a, b) => a.timestamp - b.timestamp);

    // Extract amplitudes
    const amplitudes = frames.map(f => {
      const iq = parseIqHex(f.iqHex);
      return extractAmplitude(iq, nSubcarriers);
    });

    // Create sliding windows with stride 1
    for (let i = 0; i <= amplitudes.length - timeSteps; i++) {
      const windowFrames = amplitudes.slice(i, i + timeSteps);
      const normalized = normalizeAmplitude(windowFrames);

      // Flatten to [nSubcarriers, timeSteps] (channel-first)
      const input = new Float32Array(nSubcarriers * timeSteps);
      for (let sc = 0; sc < nSubcarriers; sc++) {
        for (let t = 0; t < timeSteps; t++) {
          input[sc * timeSteps + t] = normalized[t][sc];
        }
      }

      windows.push({
        input,
        timestamp: frames[i + timeSteps - 1].timestamp,
        startTimestamp: frames[i].timestamp,
        nodeId: parseInt(nodeId, 10),
      });
    }
  }

  return windows;
}
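
// Indexing note (follows the flatten loop above): element (sc, t) of a window
// lives at input[sc * timeSteps + t], so with the defaults (128 subcarriers,
// 20 steps) each window is a 2560-value Float32Array, and consecutive windows
// share 19 frames (stride 1).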

/**
 * Generate pose proxy labels from vitals and motion data.
 * This is the camera-free pipeline: there are no ground-truth keypoints,
 * so coarse pose proxies are generated from sensor data instead.
 *
 * Strategy (matching the code below):
 * - Person detected (nearest vitals within 2 s and presence > 0.1):
 *   place a standing skeleton at the center
 * - Motion energy adds Gaussian jitter to every keypoint (scale capped at 0.15)
 * - Breathing adds sinusoidal micro-motion to shoulders and hips
 * - Multiple people: the strategy calls for horizontally offset skeletons;
 *   this implementation emits a single centered skeleton per window
 * - No presence: return null (skip)
 */
function generatePoseProxy(timestamp, nodeId, vitals, rng) {
  // Find the nearest vitals for this timestamp and node
  let nearest = null;
  let bestDist = Infinity;
  for (const v of vitals) {
    if (v.nodeId !== nodeId) continue;
    const dist = Math.abs(v.timestamp - timestamp);
    if (dist < bestDist) {
      bestDist = dist;
      nearest = v;
    }
  }

  if (!nearest || bestDist > 2.0 || nearest.presenceScore <= 0.1) {
    return null; // No person detected
  }

  // Base standing skeleton (COCO 17 keypoints, normalized [0,1])
  const baseKeypoints = new Float32Array([
    0.50, 0.10, // 0: nose
    0.48, 0.08, // 1: left_eye
    0.52, 0.08, // 2: right_eye
    0.45, 0.09, // 3: left_ear
    0.55, 0.09, // 4: right_ear
    0.40, 0.25, // 5: left_shoulder
    0.60, 0.25, // 6: right_shoulder
    0.35, 0.40, // 7: left_elbow
    0.65, 0.40, // 8: right_elbow
    0.32, 0.55, // 9: left_wrist
    0.68, 0.55, // 10: right_wrist
    0.43, 0.55, // 11: left_hip
    0.57, 0.55, // 12: right_hip
    0.42, 0.72, // 13: left_knee
    0.58, 0.72, // 14: right_knee
    0.41, 0.90, // 15: left_ankle
    0.59, 0.90, // 16: right_ankle
  ]);

  const keypoints = new Float32Array(baseKeypoints);
  const gauss = gaussianRng(rng);

  // Add motion-based perturbation
  const motionScale = Math.min(nearest.motionEnergy / 10.0, 0.15);
  for (let i = 0; i < keypoints.length; i++) {
    keypoints[i] += gauss() * motionScale;
    // Clamp to [0.01, 0.99]
    keypoints[i] = Math.max(0.01, Math.min(0.99, keypoints[i]));
  }

  // Add breathing-related micro-motion to the torso
  if (nearest.breathingBpm > 0) {
    const breathPhase = (nearest.timestamp * nearest.breathingBpm / 60.0) * 2 * Math.PI;
    const breathAmp = 0.005; // very small
    for (const idx of [5, 6, 11, 12]) { // shoulders and hips
      keypoints[idx * 2 + 1] += Math.sin(breathPhase) * breathAmp;
    }
  }

  return {
    keypoints,
    confidence: nearest.presenceScore,
    isProxy: true,
  };
}
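
// Example of a returned proxy (illustrative numbers): for vitals with
// presence_score 0.8 and motion_energy 1.0, motionScale is 0.1, so each
// keypoint is jittered by roughly N(0, 0.1) around the base skeleton and the
// result is { keypoints: Float32Array(34), confidence: 0.8, isProxy: true }.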

/**
 * Resolve a simple glob pattern (basename wildcards only) to a file list.
 */
function resolveGlob(pattern) {
  if (!pattern.includes('*')) {
    return fs.existsSync(pattern) ? [pattern] : [];
  }
  const dir = path.dirname(pattern);
  const base = path.basename(pattern);
  // Escape regex metacharacters (e.g. '.') before expanding '*' wildcards
  const escaped = base.replace(/[.+?^${}()|[\]\\]/g, '\\$&');
  const regex = new RegExp('^' + escaped.replace(/\*/g, '.*') + '$');
  if (!fs.existsSync(dir)) return [];
  return fs.readdirSync(dir)
    .filter(f => regex.test(f))
    .map(f => path.join(dir, f));
}

// ---------------------------------------------------------------------------
// Quantization (from train-ruvllm.js)
// ---------------------------------------------------------------------------

function quantizeWeights(weights, bits) {
  const maxVal = 2 ** bits - 1;
  let wMin = Infinity, wMax = -Infinity;
  for (let i = 0; i < weights.length; i++) {
    if (weights[i] < wMin) wMin = weights[i];
    if (weights[i] > wMax) wMax = weights[i];
  }
  const range = wMax - wMin || 1e-10;
  const scale = range / maxVal;
  const zeroPoint = Math.round(-wMin / scale);

  const qValues = new Uint8Array(weights.length);
  for (let i = 0; i < weights.length; i++) {
    const q = Math.round((weights[i] - wMin) / scale);
    qValues[i] = Math.max(0, Math.min(maxVal, q));
  }

  let packed;
  if (bits === 8) {
    packed = new Uint8Array(qValues);
  } else if (bits === 4) {
    // Two 4-bit values per byte: even index in the high nibble, odd in the low
    packed = new Uint8Array(Math.ceil(weights.length / 2));
    for (let i = 0; i < weights.length; i += 2) {
      const hi = qValues[i] & 0x0F;
      const lo = (i + 1 < weights.length) ? (qValues[i + 1] & 0x0F) : 0;
      packed[i >> 1] = (hi << 4) | lo;
    }
  } else if (bits === 2) {
    // Four 2-bit values per byte, most significant pair first
    packed = new Uint8Array(Math.ceil(weights.length / 4));
    for (let i = 0; i < weights.length; i += 4) {
      let byte = 0;
      for (let k = 0; k < 4; k++) {
        const val = (i + k < weights.length) ? (qValues[i + k] & 0x03) : 0;
        byte |= val << (6 - k * 2);
      }
      packed[Math.floor(i / 4)] = byte;
    }
  } else {
    packed = new Uint8Array(qValues); // fall back to one byte per weight
  }

  const originalSize = weights.length * 4; // fp32 bytes
  return {
    quantized: packed, scale, zeroPoint, bits,
    numWeights: weights.length, originalSize,
    quantizedSize: packed.length,
    compressionRatio: originalSize / packed.length,
  };
}

function dequantizeWeights(packed, scale, zeroPoint, bits, numWeights) {
  const result = new Float32Array(numWeights);
  if (bits === 8) {
    for (let i = 0; i < numWeights; i++) result[i] = (packed[i] - zeroPoint) * scale;
  } else if (bits === 4) {
    for (let i = 0; i < numWeights; i++) {
      const byteIdx = i >> 1;
      const nibble = (i % 2 === 0) ? (packed[byteIdx] >> 4) & 0x0F : packed[byteIdx] & 0x0F;
      result[i] = (nibble - zeroPoint) * scale;
    }
  } else if (bits === 2) {
    for (let i = 0; i < numWeights; i++) {
      const byteIdx = Math.floor(i / 4);
      const shift = 6 - (i % 4) * 2;
      const val = (packed[byteIdx] >> shift) & 0x03;
      result[i] = (val - zeroPoint) * scale;
    }
  }
  return result;
}

// Root-mean-square error between the original and dequantized weights
function quantizationQuality(original, dequantized) {
  let sumSqErr = 0;
  const n = Math.min(original.length, dequantized.length);
  for (let i = 0; i < n; i++) sumSqErr += (original[i] - dequantized[i]) ** 2;
  return Math.sqrt(sumSqErr / n);
}
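
// Worked round-trip example (illustrative, bits = 8): weights [-1, 0, 2] give
// wMin = -1, range = 3, scale = 3/255 ~ 0.01176, zeroPoint = round(1/scale) = 85.
// Quantize: q = round((w - wMin)/scale) -> [0, 85, 255].
// Dequantize: (q - zeroPoint)*scale -> [-1.0, 0.0, 2.0] (up to rounding),
// so quantizationQuality() on this toy set is ~0.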

// ---------------------------------------------------------------------------
// Deterministic shuffle (Fisher-Yates driven by xorshift32)
// ---------------------------------------------------------------------------

function shuffleArray(arr, seed) {
  const result = [...arr];
  let s = seed || 1; // xorshift must not start at 0
  for (let i = result.length - 1; i > 0; i--) {
    s ^= s << 13; s ^= s >>> 17; s ^= s << 5; // xorshift32 step (unsigned right shift)
    const j = (s >>> 0) % (i + 1);
    [result[i], result[j]] = [result[j], result[i]];
  }
  return result;
}
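
// Determinism note (illustrative): shuffleArray([1, 2, 3, 4], 42) returns the
// same permutation on every run, which keeps batch order (and therefore
// training curves) reproducible for a given --seed.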

// ---------------------------------------------------------------------------
// Main training pipeline
// ---------------------------------------------------------------------------

async function main() {
  const startTime = Date.now();
  console.log('=== WiFlow Pose Estimation Training Pipeline ===');
  console.log(`Config: epochs=${CONFIG.epochs} batch=${CONFIG.batchSize} lr=${CONFIG.learningRate}`);
  console.log(`  subcarriers=${CONFIG.subcarriers} timeSteps=${CONFIG.timeSteps} seed=${CONFIG.seed}`);
  console.log('');

  // -----------------------------------------------------------------------
  // Step 1: Load CSI data
  // -----------------------------------------------------------------------
  console.log('[1/7] Loading CSI data...');
  const files = resolveGlob(CONFIG.dataGlob);
  if (files.length === 0) {
    console.error(`No files found matching: ${CONFIG.dataGlob}`);
    process.exit(1);
  }

  let allRawCsi = [];
  let allFeatures = [];
  let allVitals = [];

  for (const file of files) {
    console.log(`  Loading: ${path.basename(file)}`);
    const { rawCsi, features, vitals } = loadCsiData(file);
    allRawCsi = allRawCsi.concat(rawCsi);
    allFeatures = allFeatures.concat(features);
    allVitals = allVitals.concat(vitals);
  }

  console.log(`  Raw CSI frames: ${allRawCsi.length}`);
  console.log(`  Feature frames: ${allFeatures.length}`);
  console.log(`  Vitals frames: ${allVitals.length}`);
  console.log(`  Nodes: ${[...new Set(allRawCsi.map(f => f.nodeId))].join(', ')}`);

  if (allRawCsi.length === 0) {
    console.error('No raw CSI frames found. WiFlow requires raw IQ data (type="raw_csi").');
    process.exit(1);
  }

  // Check subcarrier counts in the data
  const scCounts = new Map();
  for (const f of allRawCsi) {
    scCounts.set(f.subcarriers, (scCounts.get(f.subcarriers) || 0) + 1);
  }
  console.log(`  Subcarrier distribution: ${[...scCounts.entries()].map(([k, v]) => `${k}sc: ${v} frames`).join(', ')}`);

  // Use the target subcarrier count; frames with different counts will be resampled
  const targetSc = CONFIG.subcarriers;

  // -----------------------------------------------------------------------
  // Step 2: Create amplitude windows
  // -----------------------------------------------------------------------
  console.log('\n[2/7] Extracting amplitude and creating windows...');

  // If frames have differing subcarrier counts, resample them to the target
  const resampledCsi = allRawCsi.map(f => {
    if (f.subcarriers === targetSc) return f;
    // Frames with fewer subcarriers (e.g. 64) are stretched to the target
    // count and frames with more are compressed, both via linear interpolation
    const iq = parseIqHex(f.iqHex);
    const amp = extractAmplitude(iq, f.subcarriers);
    // Resample amplitude to targetSc via linear interpolation
    const resampled = new Float32Array(targetSc);
    for (let i = 0; i < targetSc; i++) {
      const srcIdx = (i / targetSc) * f.subcarriers;
      const lo = Math.floor(srcIdx);
      const hi = Math.min(lo + 1, f.subcarriers - 1);
      const frac = srcIdx - lo;
      resampled[i] = amp[lo] * (1 - frac) + amp[hi] * frac;
    }
    // Re-encode as synthetic iqHex (amplitude only, Q = 0)
    const newIq = [];
    newIq.push(0, 0); // DC offset placeholder
    for (let i = 0; i < targetSc; i++) {
      const v = Math.round(Math.min(127, Math.max(-128, resampled[i])));
      newIq.push(v, 0); // I = amplitude, Q = 0
    }
    const hexStr = newIq.map(b => {
      const unsigned = b < 0 ? b + 256 : b;
      return unsigned.toString(16).padStart(2, '0');
    }).join('');
    return { ...f, iqHex: hexStr, subcarriers: targetSc };
  });

  const windows = createWindows(resampledCsi, targetSc, CONFIG.timeSteps);
  console.log(`  Windows created: ${windows.length} (from ${allRawCsi.length} raw frames)`);
  console.log(`  Window shape: [${targetSc}, ${CONFIG.timeSteps}] = ${targetSc * CONFIG.timeSteps} values`);

  if (windows.length === 0) {
    console.error(`Not enough consecutive frames to create ${CONFIG.timeSteps}-step windows.`);
    process.exit(1);
  }

  // -----------------------------------------------------------------------
  // Step 3: Initialize WiFlow model
  // -----------------------------------------------------------------------
  console.log('\n[3/7] Initializing WiFlow model...');
  const model = new WiFlowModel({
    inputChannels: targetSc,
    timeSteps: CONFIG.timeSteps,
    numKeypoints: 17,
    numHeads: 8,
    seed: CONFIG.seed,
  });

  const breakdown = model.paramBreakdown();
  console.log(`  Parameter count: ${model.numParams().toLocaleString()}`);
  console.log(`    TCN: ${breakdown.tcn.toLocaleString()}`);
  console.log(`    Spatial encoder: ${breakdown.spatialEncoder.toLocaleString()}`);
  console.log(`    Axial attention: ${breakdown.axialAttention.toLocaleString()}`);
  console.log(`    Decoder: ${breakdown.decoder.toLocaleString()}`);

  const flops = estimateFLOPs({ inputChannels: targetSc, timeSteps: CONFIG.timeSteps });
  console.log(`  Estimated FLOPs: ${(flops.total / 1e6).toFixed(1)}M`);

  // Verify the forward pass works
  console.log('  Verifying forward pass...');
  const testInput = new Float32Array(targetSc * CONFIG.timeSteps);
  const rng = createRng(CONFIG.seed);
  for (let i = 0; i < testInput.length; i++) testInput[i] = (rng() - 0.5) * 2;

  const t0 = Date.now();
  const testOutput = model.forward(testInput);
  const fwdMs = Date.now() - t0;
  console.log(`  Forward pass: ${fwdMs}ms, output shape: [${testOutput.length / 2}, 2]`);
  console.log(`  Sample keypoint (nose): x=${testOutput[0].toFixed(3)}, y=${testOutput[1].toFixed(3)}`);

  // -----------------------------------------------------------------------
  // Phase 1: Contrastive pretraining (temporal consistency)
  // -----------------------------------------------------------------------
  console.log('\n[4/7] Phase 1: Contrastive pretraining...');

  // Generate temporal triplets from windows
  const triplets = [];
  const nodeWindows = {};
  for (const w of windows) {
    if (!nodeWindows[w.nodeId]) nodeWindows[w.nodeId] = [];
    nodeWindows[w.nodeId].push(w);
  }

  for (const nodeId of Object.keys(nodeWindows)) {
    const nw = nodeWindows[nodeId];
    for (let i = 0; i < nw.length; i++) {
      // Positive: an adjacent window (temporal consistency)
      for (let j = i + 1; j < Math.min(i + 3, nw.length); j++) {
        // Negative: any window more than 3 steps from the anchor
        for (let k = 0; k < nw.length; k++) {
          if (k >= i - 3 && k <= i + 3) continue; // skip temporally nearby windows
          triplets.push({
            anchor: nw[i],
            positive: nw[j],
            negative: nw[k],
          });
          if (triplets.length > 5000) break; // cap triplets
        }
        if (triplets.length > 5000) break;
      }
      if (triplets.length > 5000) break;
    }
  }

  console.log(`  Temporal triplets: ${triplets.length}`);

  if (triplets.length > 0) {
    // Use the ruvllm ContrastiveTrainer for metric tracking
    const contrastiveTrainer = new ContrastiveTrainer({
      epochs: CONFIG.contrastiveEpochs,
      batchSize: CONFIG.batchSize,
      margin: CONFIG.margin,
      temperature: CONFIG.temperature,
      hardNegativeRatio: 0.5,
      learningRate: CONFIG.learningRate,
      outputPath: path.join(CONFIG.outputDir, 'contrastive'),
    });

    // Use the model's forward pass to generate embeddings for contrastive
    // learning. The decoder output serves as the embedding (34-dim for
    // 17 keypoints * 2).
    const sampleTriplets = triplets.slice(0, Math.min(50, triplets.length));
    for (const t of sampleTriplets) {
      const anchorEmb = Array.from(model.forward(t.anchor.input));
      const posEmb = Array.from(model.forward(t.positive.input));
      const negEmb = Array.from(model.forward(t.negative.input));
      contrastiveTrainer.addTriplet(
        `a-${t.anchor.timestamp}`, anchorEmb,
        `p-${t.positive.timestamp}`, posEmb,
        `n-${t.negative.timestamp}`, negEmb,
        false
      );
    }

    const contrastiveResult = contrastiveTrainer.train();
    console.log(`  Contrastive loss: ${contrastiveResult.finalLoss.toFixed(6)}`);
    console.log(`  Duration: ${contrastiveResult.durationMs}ms`);

    // Apply gradient updates to decoder weights via temporal consistency
    console.log('  Applying decoder weight updates for temporal consistency...');
    const decoderLr = CONFIG.learningRate * 0.1;

    for (let epoch = 0; epoch < CONFIG.contrastiveEpochs; epoch++) {
      let epochLoss = 0;
      const shuffled = shuffleArray(sampleTriplets, epoch * 31 + 17);

      for (const t of shuffled) {
        const anchorOut = model.forward(t.anchor.input);
        const posOut = model.forward(t.positive.input);
        const negOut = model.forward(t.negative.input);

        const loss = tripletLoss(
          Array.from(anchorOut), Array.from(posOut), Array.from(negOut), CONFIG.margin
        );
        epochLoss += loss;

        if (loss > 0) {
          // Update decoder weights to pull the anchor toward the positive
          // and push it away from the negative
          const grad = computeGradient(
            Array.from(anchorOut), Array.from(posOut), Array.from(negOut), decoderLr
          );
          // Apply the gradient to the decoder bias (simplified update)
          for (let j = 0; j < Math.min(grad.length, model.decoder.bias.length); j++) {
            model.decoder.bias[j] += grad[j] * 0.01;
          }
        }
      }

      epochLoss /= shuffled.length || 1;
      if (CONFIG.verbose || epoch === CONFIG.contrastiveEpochs - 1) {
        console.log(`  Epoch ${epoch + 1}/${CONFIG.contrastiveEpochs}: loss=${epochLoss.toFixed(6)}`);
      }
    }
  }

  if (CONFIG.contrastiveOnly) {
    console.log('\n  --contrastive-only flag set, skipping supervised training.');
    await exportModel(model, CONFIG, startTime, { contrastiveOnly: true });
    return;
  }

  // -----------------------------------------------------------------------
  // Phase 2: Supervised pose training (SmoothL1 + bone constraint)
  // -----------------------------------------------------------------------
  console.log('\n[5/7] Phase 2: Supervised pose training...');

  // Generate pose proxy labels for each window
  const proxyRng = createRng(CONFIG.seed + 100);
  const labeledWindows = [];
  const unlabeledWindows = [];

  for (const w of windows) {
    const proxy = generatePoseProxy(w.timestamp, w.nodeId, allVitals, proxyRng);
    if (proxy) {
      labeledWindows.push({ ...w, target: proxy.keypoints, confidence: proxy.confidence });
    } else {
      unlabeledWindows.push(w);
    }
  }

  // Limit samples if --max-samples is set (useful for fast iteration)
  if (CONFIG.maxSamples > 0 && labeledWindows.length > CONFIG.maxSamples) {
    labeledWindows.length = CONFIG.maxSamples;
  }

  console.log(`  Labeled windows (pose proxy): ${labeledWindows.length}`);
  console.log(`  Unlabeled windows: ${unlabeledWindows.length}`);

  if (labeledWindows.length > 0) {
    // Training loop with SmoothL1 + bone constraint
    const lr = CONFIG.learningRate;
    let bestLoss = Infinity;
    const patience = 10;
    let patienceCounter = 0;

    for (let epoch = 0; epoch < CONFIG.epochs; epoch++) {
      let epochLossH = 0;
      let epochLossB = 0;
      let epochPCK = 0;
      let nSamples = 0;

      const shuffled = shuffleArray(labeledWindows, epoch * 41 + 7);
      const batches = [];
      for (let i = 0; i < shuffled.length; i += CONFIG.batchSize) {
        batches.push(shuffled.slice(i, i + CONFIG.batchSize));
      }

      for (const batch of batches) {
        for (const sample of batch) {
          const predicted = model.forward(sample.input);
          const lossResult = model.computeLoss(predicted, sample.target, true);

          epochLossH += lossResult.smoothL1;
          epochLossB += lossResult.boneLoss;

          // Compute PCK@20
          epochPCK += WiFlowModel.pck(predicted, sample.target, 0.2);
          nSamples++;

          // Gradient update on the decoder (simplified: decoder weights only)
          const grad = model.computeLossGrad(predicted, sample.target);
          const decoderDim = model.decoder.outDim;
          const featureDim = model.decoder.inFeatures;

          // Update decoder bias, scaled by proxy-label confidence
          for (let j = 0; j < decoderDim; j++) {
            model.decoder.bias[j] -= lr * grad[j] * sample.confidence;
          }

          // Update decoder weights (approximate: small perturbation).
          // Full backprop through TCN/spatial/attention is expensive in pure JS,
          // so we use decoder-only updates on top of contrastive-pretrained features.
          for (let j = 0; j < decoderDim; j++) {
            for (let i = 0; i < Math.min(featureDim, 48); i++) {
              model.decoder.weight[i * decoderDim + j] -= lr * grad[j] * 0.001;
            }
          }
        }
      }

      epochLossH /= nSamples || 1;
      epochLossB /= nSamples || 1;
      epochPCK /= nSamples || 1;
      const totalLoss = epochLossH + CONFIG.boneWeight * epochLossB;

      if (CONFIG.verbose || epoch % 5 === 0 || epoch === CONFIG.epochs - 1) {
        console.log(`  Epoch ${epoch + 1}/${CONFIG.epochs}: L_H=${epochLossH.toFixed(4)} L_B=${epochLossB.toFixed(4)} total=${totalLoss.toFixed(4)} PCK@20=${(epochPCK * 100).toFixed(1)}%`);
      }

      // Early stopping
      if (totalLoss < bestLoss) {
        bestLoss = totalLoss;
        patienceCounter = 0;
      } else {
        patienceCounter++;
        if (patienceCounter >= patience) {
          console.log(`  Early stopping at epoch ${epoch + 1} (patience=${patience})`);
          break;
        }
      }
    }
  } else {
    console.log('  WARN: No pose proxy labels generated. Skipping supervised training.');
  }

  // -----------------------------------------------------------------------
  // Phase 3: LoRA room-specific adaptation
  // -----------------------------------------------------------------------
  console.log('\n[6/7] Phase 3: LoRA adaptation...');

  const loraManager = new LoraManager({
    rank: CONFIG.loraRank,
    alpha: CONFIG.loraRank * 2,
    dropout: 0.1,
    targetModules: ['decoder'],
  });

  const nodeIds = [...new Set(windows.map(w => w.nodeId))];

  for (const nodeId of nodeIds) {
    console.log(`  Training LoRA adapter for node ${nodeId}...`);
    const nodeAdapter = loraManager.create(
      `wiflow-node-${nodeId}`,
      { rank: CONFIG.loraRank, alpha: CONFIG.loraRank * 2, dropout: 0.1 },
      2048, // decoder input dim (256 * 8)
      34    // decoder output dim (17 * 2)
    );

    const nodeData = labeledWindows.filter(w => w.nodeId === nodeId);
    if (nodeData.length > 0) {
      const nodePipeline = new TrainingPipeline({
        learningRate: CONFIG.learningRate * 0.5,
        batchSize: Math.min(CONFIG.batchSize, nodeData.length),
        epochs: 5,
        scheduler: 'cosine',
        ewcLambda: 2000,
      }, nodeAdapter);

      const pipelineData = nodeData.map(w => ({
        input: Array.from(model.forward(w.input)),
        target: Array.from(w.target),
        quality: w.confidence,
      }));
      nodePipeline.addData(pipelineData);
      const nodeResult = nodePipeline.train();
      console.log(`    Node ${nodeId}: ${nodeData.length} samples, loss=${nodeResult.finalLoss.toFixed(6)}`);
    }
  }

  console.log(`  LoRA adapters: ${loraManager.list().join(', ')}`);

  // -----------------------------------------------------------------------
  // Phase 4 + 5: Quantization + Export
  // -----------------------------------------------------------------------
  await exportModel(model, CONFIG, startTime, {
    loraManager,
    labeledWindows,
    windows,
    nodeIds,
    allRawCsi,
    allVitals,
    allFeatures,
  });
}

/**
 * Quantize and export the trained model (SafeTensors, HF config, JSON, RVF, LoRA).
 */
async function exportModel(model, config, startTime, context) {
  console.log('\n[7/7] Quantization + Export...');

  fs.mkdirSync(config.outputDir, { recursive: true });

  // Quantization
  const allWeights = model.getAllWeights();
  console.log(`  Total weights: ${allWeights.length.toLocaleString()} (${(allWeights.length * 4 / 1024 / 1024).toFixed(2)} MB fp32)`);

  const quantResults = {};
  for (const bits of [2, 4, 8]) {
    const qr = quantizeWeights(allWeights, bits);
    const deq = dequantizeWeights(qr.quantized, qr.scale, qr.zeroPoint, bits, qr.numWeights);
    const rmse = quantizationQuality(allWeights, deq);
    quantResults[bits] = { ...qr, rmse };
    console.log(`    ${bits}-bit: ${qr.compressionRatio.toFixed(1)}x compression, RMSE=${rmse.toFixed(6)}, size=${(qr.quantizedSize / 1024).toFixed(1)} KB`);
  }

  // SafeTensors export
  const exporter = new ModelExporter();
  const exportData = {
    metadata: {
      name: 'wifi-densepose-wiflow',
      version: '1.0.0',
      architecture: 'wiflow-tcn-asymconv-axialattn',
      training: {
        steps: config.epochs,
        learningRate: config.learningRate,
      },
      custom: {
        inputChannels: config.subcarriers,
        timeSteps: config.timeSteps,
        numKeypoints: 17,
        numHeads: 8,
        totalParams: model.numParams(),
        paramBreakdown: model.paramBreakdown(),
        flops: estimateFLOPs({ inputChannels: config.subcarriers, timeSteps: config.timeSteps }),
        seed: config.seed,
        quantizationBits: config.quantizeBits,
      },
    },
    tensors: model.toTensorMap(),
  };

  const safetensorsBuffer = exporter.toSafeTensors(exportData);
  fs.writeFileSync(path.join(config.outputDir, 'model.safetensors'), safetensorsBuffer);
  console.log(`  SafeTensors: ${path.join(config.outputDir, 'model.safetensors')} (${(safetensorsBuffer.length / 1024).toFixed(1)} KB)`);

  // HuggingFace config
  const hfExport = exporter.toHuggingFace(exportData);
  fs.writeFileSync(path.join(config.outputDir, 'config.json'), hfExport.config);

  // JSON export
  const jsonExport = exporter.toJSON(exportData);
  fs.writeFileSync(path.join(config.outputDir, 'model.json'), jsonExport);

  // Quantized models
  const quantDir = path.join(config.outputDir, 'quantized');
  fs.mkdirSync(quantDir, { recursive: true });
  for (const [bits, qr] of Object.entries(quantResults)) {
    const qPath = path.join(quantDir, `wiflow-q${bits}.bin`);
    fs.writeFileSync(qPath, Buffer.from(qr.quantized));
    console.log(`  Quantized ${bits}-bit: ${qPath} (${(qr.quantizedSize / 1024).toFixed(1)} KB)`);
  }

  // LoRA adapters
  if (context.loraManager) {
    const loraDir = path.join(config.outputDir, 'lora');
    fs.mkdirSync(loraDir, { recursive: true });
    for (const adapterId of context.loraManager.list()) {
      const adapter = context.loraManager.get(adapterId);
      const loraPath = path.join(loraDir, `${adapterId}.json`);
      fs.writeFileSync(loraPath, adapter.toJSON());
      console.log(`  LoRA adapter: ${loraPath}`);
    }
  }

  // RVF manifest
  const rvfPath = path.join(config.outputDir, 'model.rvf.jsonl');
  const rvfLines = [
    JSON.stringify({ type: 'metadata', ...exportData.metadata }),
    JSON.stringify({ type: 'wiflow', architecture: 'tcn-asymconv-axialattn', stages: 4 }),
    JSON.stringify({ type: 'quantization', default_bits: config.quantizeBits, variants: [2, 4, 8] }),
  ];
  fs.writeFileSync(rvfPath, rvfLines.join('\n'));

  // Training metrics
  const metricsPath = path.join(config.outputDir, 'training-metrics.json');
  const metrics = {
    timestamp: new Date().toISOString(),
    totalDurationMs: Date.now() - startTime,
    model: {
      architecture: 'wiflow',
      totalParams: model.numParams(),
      paramBreakdown: model.paramBreakdown(),
      flops: estimateFLOPs({ inputChannels: config.subcarriers, timeSteps: config.timeSteps }),
    },
    data: {
      rawCsiFrames: context.allRawCsi ? context.allRawCsi.length : 0,
      windows: context.windows ? context.windows.length : 0,
      labeledWindows: context.labeledWindows ? context.labeledWindows.length : 0,
      nodes: context.nodeIds || [],
    },
    quantization: Object.fromEntries(
      Object.entries(quantResults).map(([bits, qr]) => [
        `q${bits}`,
        { compressionRatio: qr.compressionRatio, rmse: qr.rmse, sizeKB: qr.quantizedSize / 1024 },
      ])
    ),
    config,
  };
  fs.writeFileSync(metricsPath, JSON.stringify(metrics, null, 2));
  console.log(`  Metrics: ${metricsPath}`);

  const elapsed = ((Date.now() - startTime) / 1000).toFixed(1);
  console.log(`\n=== Training complete in ${elapsed}s ===`);
  console.log(`  Output: ${config.outputDir}`);
  console.log(`  Model size: ${(allWeights.length * 4 / 1024 / 1024).toFixed(2)} MB (fp32), ${(quantResults[8].quantizedSize / 1024 / 1024).toFixed(2)} MB (int8)`);
}

// ---------------------------------------------------------------------------
// Run
// ---------------------------------------------------------------------------
main().catch(err => {
  console.error('Training failed:', err);
  process.exit(1);
});