# Mirror of https://github.com/ruvnet/RuView.git (synced 2026-05-01 23:40:09 +00:00).
# Implements the WiFi DensePose model in PyTorch: CSI phase processing, modality
# translation, and DensePose prediction heads, plus a training utility with loss
# functions and training steps. (489 lines, 18 KiB, Python, no trailing EOL.)
# WiFi DensePose Implementation in PyTorch
# Based on "DensePose From WiFi" by Carnegie Mellon University
# Paper: https://arxiv.org/pdf/2301.00250
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import numpy as np
|
|
import math
|
|
from typing import Dict, List, Tuple, Optional
|
|
from collections import OrderedDict
|
|
|
|
class CSIPhaseProcessor:
    """
    Sanitizes raw CSI phase measurements via unwrapping, smoothing, and
    linear detrending, following the phase-sanitization methodology of
    "DensePose From WiFi" (arXiv:2301.00250).

    The frequency axis is assumed to hold consecutive groups of
    ``num_subcarriers`` samples (e.g. 5 packets x 30 subcarriers = 150);
    each group is processed independently.
    """

    def __init__(self, num_subcarriers: int = 30):
        # Number of subcarriers per sample group along the frequency axis.
        self.num_subcarriers = num_subcarriers

    def unwrap_phase(self, phase_data: torch.Tensor) -> torch.Tensor:
        """
        Unwrap phase values along the frequency axis to remove 2*pi jumps.

        Args:
            phase_data: Raw phase data of shape (batch, freq_samples, tx, rx).

        Returns:
            Unwrapped phase data of the same shape.
        """
        unwrapped = phase_data.clone()
        group = self.num_subcarriers
        # Derive the group count from the data instead of hard-coding 5.
        num_groups = phase_data.shape[1] // group
        two_pi = 2.0 * math.pi

        for g in range(num_groups):
            start = g * group
            end = start + group
            for i in range(start + 1, end):
                diff = unwrapped[:, i] - unwrapped[:, i - 1]
                # Map the step into (-pi, pi]. A single +/-2*pi correction
                # (the naive approach) fails once corrections accumulate
                # beyond one full cycle; the ceil-based form handles any
                # number of wraps, matching numpy.unwrap semantics.
                corrected = diff - two_pi * torch.ceil((diff - math.pi) / two_pi)
                unwrapped[:, i] = unwrapped[:, i - 1] + corrected

        return unwrapped

    def apply_filters(self, phase_data: torch.Tensor) -> torch.Tensor:
        """
        Smooth the phase along the frequency axis with a 3-tap moving
        average to suppress outliers; the two endpoints are left untouched.
        """
        filtered = phase_data.clone()
        # Vectorized equivalent of averaging each interior sample with its
        # two frequency neighbours (reads from the unmodified input, as the
        # original per-index loop did).
        filtered[:, 1:-1] = (
            phase_data[:, :-2] + phase_data[:, 1:-1] + phase_data[:, 2:]
        ) / 3
        return filtered

    def linear_fitting(self, phase_data: torch.Tensor) -> torch.Tensor:
        """
        Remove the linear phase trend (systematic drift) from each
        subcarrier group.

        For each group, the slope alpha1 is estimated from the first and
        last subcarrier and the offset alpha0 as the group mean; the line
        ``alpha1 * f + alpha0`` (f = 1..F) is subtracted. Vectorized over
        batch / tx / rx instead of the original quadruple Python loop.
        """
        fitted = phase_data.clone()
        f_sub = self.num_subcarriers
        num_groups = phase_data.shape[1] // f_sub

        # Frequency indices 1..F, broadcastable over (batch, freq, tx, rx).
        freqs = torch.arange(
            1, f_sub + 1, dtype=phase_data.dtype, device=phase_data.device
        ).view(1, f_sub, 1, 1)

        for g in range(num_groups):
            start = g * f_sub
            end = start + f_sub
            seq = phase_data[:, start:end]                              # (B, F, tx, rx)
            alpha1 = (seq[:, -1] - seq[:, 0]) / (2 * math.pi * f_sub)   # (B, tx, rx)
            alpha0 = seq.mean(dim=1)                                    # (B, tx, rx)
            trend = alpha1.unsqueeze(1) * freqs + alpha0.unsqueeze(1)
            fitted[:, start:end] = seq - trend

        return fitted

    def sanitize_phase(self, raw_phase: torch.Tensor) -> torch.Tensor:
        """
        Full sanitization pipeline: unwrap -> smooth -> linear detrend.
        """
        unwrapped = self.unwrap_phase(raw_phase)
        filtered = self.apply_filters(unwrapped)
        return self.linear_fitting(filtered)
class ModalityTranslationNetwork(nn.Module):
    """
    Maps CSI-domain measurements into an image-like spatial feature map.

    Flattened amplitude and phase tensors (150 x 3 x 3 each by default)
    are encoded by parallel MLPs, fused into a 24x24 grid, refined with
    convolutions, and upsampled to a 3-channel map of size
    (output_height, output_width).
    """

    def __init__(self, input_dim: int = 1350, hidden_dim: int = 512,
                 output_height: int = 720, output_width: int = 1280):
        super().__init__()

        self.input_dim = input_dim
        self.output_height = output_height
        self.output_width = output_width

        # Amplitude and phase branches share the same architecture
        # (separate weights): input_dim -> 512 -> 256 -> 128.
        self.amplitude_encoder = self._build_encoder(input_dim, hidden_dim)
        self.phase_encoder = self._build_encoder(input_dim, hidden_dim)

        # Fuse the concatenated 256-d embedding into a flat 24x24 map.
        self.fusion_mlp = nn.Sequential(
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim // 4, 24 * 24),  # Reshape to 24x24
            nn.ReLU(),
        )

        # Convolutional refinement, compressed to a 6x6 grid.
        self.spatial_conv = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((6, 6)),
        )

        # Transposed-conv pyramid: 6x6 -> 12 -> 24 -> 48 -> 96.
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 8, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(),
        )

        # Project to 3 channels before the final resize.
        self.final_conv = nn.Conv2d(8, 3, kernel_size=1)

    @staticmethod
    def _build_encoder(input_dim: int, hidden_dim: int) -> nn.Sequential:
        # Shared MLP shape for the amplitude and phase branches.
        return nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
        )

    def forward(self, amplitude_tensor: torch.Tensor, phase_tensor: torch.Tensor) -> torch.Tensor:
        """
        Args:
            amplitude_tensor: CSI amplitudes, flattened to (B, input_dim).
            phase_tensor: sanitized CSI phases, same shape.

        Returns:
            Feature map of shape (B, 3, output_height, output_width).
        """
        n = amplitude_tensor.shape[0]

        # Encode each modality from its flattened view.
        amp_embed = self.amplitude_encoder(amplitude_tensor.view(n, -1))
        ph_embed = self.phase_encoder(phase_tensor.view(n, -1))

        # Concatenate embeddings and project onto a 24x24 grid.
        grid = self.fusion_mlp(torch.cat([amp_embed, ph_embed], dim=1))
        grid = grid.view(n, 1, 24, 24)

        # Refine, upsample, and project (spatial size is 96x96 here).
        refined = self.spatial_conv(grid)
        expanded = self.upsample(refined)
        projected = self.final_conv(expanded)

        # Bilinear resize to the requested output resolution.
        return F.interpolate(
            projected,
            size=(self.output_height, self.output_width),
            mode='bilinear',
            align_corners=False,
        )
class DensePoseHead(nn.Module):
    """
    DensePose head: per-pixel body-part classification plus continuous
    UV surface-coordinate regression.
    """

    def __init__(self, input_channels=256, num_parts=24, output_size=(112, 112)):
        super().__init__()

        self.num_parts = num_parts
        self.output_size = output_size

        # Three shared 3x3 conv layers feeding every branch.
        self.shared_conv = nn.Sequential(
            nn.Conv2d(input_channels, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
        )

        # Part classifier predicts num_parts classes plus background.
        self.part_classifier = nn.Conv2d(512, num_parts + 1, kernel_size=1)

        # Independent 1x1 regressors for the U and V surface coordinates.
        self.u_regressor = nn.Conv2d(512, num_parts, kernel_size=1)
        self.v_regressor = nn.Conv2d(512, num_parts, kernel_size=1)

    def forward(self, x):
        """
        Args:
            x: feature map of shape (B, input_channels, H, W).

        Returns:
            Dict with 'part_logits' (B, num_parts+1, *output_size) and
            'u_coords' / 'v_coords' (B, num_parts, *output_size) in [0, 1].
        """
        shared = self.shared_conv(x)

        # Predict at the target resolution.
        shared = F.interpolate(shared, size=self.output_size,
                               mode='bilinear', align_corners=False)

        # Sigmoid keeps the regressed UV coordinates inside [0, 1].
        return {
            'part_logits': self.part_classifier(shared),
            'u_coords': torch.sigmoid(self.u_regressor(shared)),
            'v_coords': torch.sigmoid(self.v_regressor(shared)),
        }
class KeypointHead(nn.Module):
    """
    Keypoint head: predicts one heatmap per body keypoint
    (COCO-style 17 keypoints by default).
    """

    def __init__(self, input_channels=256, num_keypoints=17, output_size=(56, 56)):
        super().__init__()

        self.num_keypoints = num_keypoints
        self.output_size = output_size

        # Three 3x3 conv layers followed by a 1x1 projection to one
        # channel per keypoint.
        self.conv_layers = nn.Sequential(
            nn.Conv2d(input_channels, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, num_keypoints, kernel_size=1),
        )

    def forward(self, x):
        """
        Args:
            x: feature map of shape (B, input_channels, H, W).

        Returns:
            Keypoint heatmaps of shape (B, num_keypoints, *output_size).
        """
        maps = self.conv_layers(x)
        # Resize heatmaps to the target resolution.
        return F.interpolate(maps, size=self.output_size,
                             mode='bilinear', align_corners=False)
class WiFiDensePoseRCNN(nn.Module):
    """
    End-to-end WiFi-DensePose network: CSI phase sanitization ->
    modality translation -> conv backbone -> DensePose and keypoint heads.
    """

    def __init__(self):
        super().__init__()

        # Classical (non-learned) CSI phase sanitization.
        self.phase_processor = CSIPhaseProcessor()

        # CSI -> image-like spatial features.
        self.modality_translation = ModalityTranslationNetwork()

        # Simplified backbone (the paper uses ResNet-FPN).
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),

            # Simplified ResNet-style blocks.
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )

        # Prediction heads.
        self.densepose_head = DensePoseHead(input_channels=256)
        self.keypoint_head = KeypointHead(input_channels=256)

        # Fixed-size pooling so the heads always see a 7x7 grid.
        self.global_pool = nn.AdaptiveAvgPool2d((7, 7))

    def forward(self, amplitude_data, phase_data):
        """
        Args:
            amplitude_data: CSI amplitudes, shape (B, 150, 3, 3) — TODO confirm.
            phase_data: raw CSI phases, same shape as amplitude_data.

        Returns:
            Dict with 'spatial_features', 'backbone_features' (consumed by
            the transfer-learning term in WiFiDensePoseLoss), 'densepose',
            and 'keypoints'.
        """
        # Sanitize raw phases (unwrap, filter, detrend).
        sanitized_phase = self.phase_processor.sanitize_phase(phase_data)

        # Translate CSI measurements to the spatial domain.
        spatial_features = self.modality_translation(amplitude_data, sanitized_phase)

        # Backbone features, then pooled to a fixed 7x7 resolution.
        backbone_features = self.backbone(spatial_features)
        pooled_features = self.global_pool(backbone_features)

        # Per-task predictions.
        densepose_output = self.densepose_head(pooled_features)
        keypoint_heatmaps = self.keypoint_head(pooled_features)

        # 'backbone_features' was previously missing from this dict,
        # which made WiFiDensePoseLoss's transfer-learning term
        # unreachable; adding the key is backward compatible.
        return {
            'spatial_features': spatial_features,
            'backbone_features': backbone_features,
            'densepose': densepose_output,
            'keypoints': keypoint_heatmaps,
        }
class WiFiDensePoseLoss(nn.Module):
    """
    Weighted multi-task loss: DensePose (part classification + UV
    regression), keypoint heatmaps, and an optional feature-transfer
    (teacher distillation) term.
    """

    def __init__(self, lambda_dp=0.6, lambda_kp=0.3, lambda_tr=0.1):
        super().__init__()

        # Per-task weights.
        self.lambda_dp = lambda_dp
        self.lambda_kp = lambda_kp
        self.lambda_tr = lambda_tr

        # Underlying criteria.
        self.cross_entropy = nn.CrossEntropyLoss()
        self.mse_loss = nn.MSELoss()
        self.smooth_l1 = nn.SmoothL1Loss()

    def forward(self, predictions, targets, teacher_features=None):
        """
        Args:
            predictions: model output dict ('densepose', 'keypoints', ...).
            targets: ground-truth dict with matching structure.
            teacher_features: optional teacher features, matched against
                predictions['backbone_features'] when both are present.

        Returns:
            (total_loss, loss_dict) where loss_dict holds the unweighted
            per-task terms plus the weighted 'total'.
        """
        loss_dict = {}
        total = 0.0

        # DensePose term: part cross-entropy plus averaged UV smooth-L1.
        if 'densepose' in predictions and 'densepose' in targets:
            pred_dp = predictions['densepose']
            tgt_dp = targets['densepose']
            cls_term = self.cross_entropy(pred_dp['part_logits'], tgt_dp['part_labels'])
            uv_term = (self.smooth_l1(pred_dp['u_coords'], tgt_dp['u_coords'])
                       + self.smooth_l1(pred_dp['v_coords'], tgt_dp['v_coords'])) / 2
            dp_term = cls_term + uv_term
            loss_dict['densepose'] = dp_term
            total = total + self.lambda_dp * dp_term

        # Keypoint term: plain MSE on the heatmaps.
        if 'keypoints' in predictions and 'keypoints' in targets:
            kp_term = self.mse_loss(predictions['keypoints'], targets['keypoints'])
            loss_dict['keypoint'] = kp_term
            total = total + self.lambda_kp * kp_term

        # Transfer term: pull backbone features toward the teacher's.
        if teacher_features is not None and 'backbone_features' in predictions:
            tr_term = self.mse_loss(predictions['backbone_features'], teacher_features)
            loss_dict['transfer'] = tr_term
            total = total + self.lambda_tr * tr_term

        loss_dict['total'] = total
        return total, loss_dict
# Training utilities
class WiFiDensePoseTrainer:
    """
    Thin training wrapper bundling loss, optimizer, LR schedule, and
    checkpointing for a WiFiDensePose model.
    """

    def __init__(self, model, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.model = model.to(device)
        self.device = device
        self.criterion = WiFiDensePoseLoss()
        self.optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        # Step-wise decay at 48k and 96k iterations; this is a per-batch
        # schedule because train_step() calls scheduler.step() every batch.
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=[48000, 96000], gamma=0.1
        )

    def train_step(self, amplitude_data, phase_data, targets):
        """
        Run one forward/backward/optimizer step.

        Returns:
            (loss_value, loss_dict) — scalar loss and per-task breakdown.
        """
        self.model.train()
        self.optimizer.zero_grad()

        outputs = self.model(amplitude_data, phase_data)
        loss, loss_dict = self.criterion(outputs, targets)

        loss.backward()
        self.optimizer.step()
        self.scheduler.step()

        return loss.item(), loss_dict

    def save_model(self, path):
        """
        Save model, optimizer, and scheduler state to `path`.
        """
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            # Scheduler state was previously dropped, which silently reset
            # the LR schedule on every resume.
            'scheduler_state_dict': self.scheduler.state_dict(),
        }, path)

    def load_model(self, path):
        """
        Restore state saved by save_model().

        map_location lets checkpoints written on GPU load cleanly on CPU
        (and vice versa); older checkpoints without scheduler state are
        still accepted.
        """
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler_state = checkpoint.get('scheduler_state_dict')
        if scheduler_state is not None:
            self.scheduler.load_state_dict(scheduler_state)
# Example usage
def create_sample_data(batch_size=1, device='cpu'):
    """
    Build random CSI inputs and matching dummy targets for smoke tests.

    Returns:
        (amplitude, phase, targets): amplitude/phase have shape
        (batch_size, 150, 3, 3); targets mirrors the model's output
        structure (24 DensePose parts + background label, 17 keypoints).
    """
    csi_shape = (batch_size, 150, 3, 3)
    amplitude = torch.randn(*csi_shape).to(device)
    phase = torch.randn(*csi_shape).to(device)

    # Dummy supervision: part labels in [0, 24], UV maps in [0, 1),
    # and random keypoint heatmaps.
    targets = {
        'densepose': {
            'part_labels': torch.randint(0, 25, (batch_size, 112, 112)).to(device),
            'u_coords': torch.rand(batch_size, 24, 112, 112).to(device),
            'v_coords': torch.rand(batch_size, 24, 112, 112).to(device),
        },
        'keypoints': torch.rand(batch_size, 17, 56, 56).to(device),
    }

    return amplitude, phase, targets
if __name__ == "__main__":
    # Build the model and its trainer.
    model = WiFiDensePoseRCNN()
    trainer = WiFiDensePoseTrainer(model)

    param_count = sum(p.numel() for p in model.parameters())
    print("WiFi DensePose model initialized!")
    print(f"Model parameters: {param_count:,}")

    # Generate random CSI inputs and dummy targets.
    amplitude, phase, targets = create_sample_data()

    # Inference-only forward pass.
    with torch.no_grad():
        outputs = model(amplitude, phase)
        print(f"Spatial features shape: {outputs['spatial_features'].shape}")
        print(f"DensePose part logits shape: {outputs['densepose']['part_logits'].shape}")
        print(f"Keypoint heatmaps shape: {outputs['keypoints'].shape}")

    # Single optimization step as a smoke test.
    loss, loss_dict = trainer.train_step(amplitude, phase, targets)
    print(f"Training loss: {loss:.4f}")
    print(f"Loss breakdown: {loss_dict}")