perf(trm): Add NaN guards and buffer reuse optimizations

- Add NaN protection to sigmoid activation by clamping inputs to [-20, 20] (mlp.rs)
- Add NaN protection to confidence scoring output (confidence.rs)
- Implement mean_pool_into for zero-allocation pooling (engine.rs)
- Reuse latent buffer across iterations using std::mem::take
- Pre-allocate answer pooling buffer in reasoning loop
- Mark use_simd config as reserved for future implementation

These optimizations reduce heap allocations in the hot path and
prevent potential NaN propagation from unbounded exp() operations.

All 63 tests pass with no regressions.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
rUv 2025-12-11 19:52:20 +00:00
parent 5512729caf
commit aeb95acf2b
4 changed files with 48 additions and 10 deletions

View file

@ -107,8 +107,9 @@ impl ConfidenceScorer {
output += hidden_buffer[i] * self.w2[i];
}
// Sigmoid to bound in [0, 1]
1.0 / (1.0 + (-output).exp())
// Sigmoid to bound in [0, 1] with NaN protection
let clamped = output.clamp(-20.0, 20.0);
1.0 / (1.0 + (-clamped).exp())
}
/// Score with additional entropy-based adjustment

View file

@ -38,7 +38,8 @@ pub struct TrmConfig {
/// Confidence threshold for early stopping
pub confidence_threshold: f32,
/// Enable SIMD optimizations
/// Enable SIMD optimizations (reserved for future use)
#[serde(default)]
pub use_simd: bool,
/// Residual scale for answer refinement

View file

@ -138,6 +138,34 @@ impl TrmEngine {
pooled
}
/// Mean pool an embedding into a pre-allocated buffer (avoids allocation).
///
/// The length of `output` is the target dimension:
/// - equal lengths: `input` is copied verbatim;
/// - shorter `input`: it is copied to the front of `output` and the rest is
///   zero-filled;
/// - longer `input`: it is cut into consecutive chunks of `output.len()`
///   elements that are summed element-wise, each term scaled by
///   `1 / (input.len() / output.len())` (integer division). A trailing
///   partial chunk is accumulated with that same scale.
///
/// An empty `output` is left untouched.
fn mean_pool_into(&self, input: &[f32], output: &mut [f32]) {
    let target_dim = output.len();

    // Nothing can be written into an empty buffer; returning here also keeps
    // the chunk-count division below from dividing by zero.
    if target_dim == 0 {
        return;
    }

    if input.len() == target_dim {
        output.copy_from_slice(input);
        return;
    }

    if input.len() < target_dim {
        let (head, tail) = output.split_at_mut(input.len());
        head.copy_from_slice(input);
        tail.fill(0.0);
        return;
    }

    output.fill(0.0);
    let num_chunks = input.len() / target_dim;
    let scale = 1.0 / num_chunks as f32;

    // `chunks(target_dim)` never yields more than `target_dim` items, so
    // zipping against `output` needs no per-element index check.
    for chunk in input.chunks(target_dim) {
        for (acc, &v) in output.iter_mut().zip(chunk) {
            *acc += v * scale;
        }
    }
}
/// Run a single latent update iteration
fn latent_update_step(
&self,
@ -157,14 +185,17 @@ impl TrmEngine {
) -> TrmResult {
let start_time = Instant::now();
// Initialize trajectory
// Initialize trajectory - pool question once (doesn't change)
let question_pooled = self.mean_pool(question, self.config.embedding_dim);
let mut trajectory = TrmTrajectory::new(question_pooled.clone());
// Initialize latent state
let mut latent = self.latent_buffer.clone();
// Initialize latent state - reuse buffer
let mut latent = std::mem::take(&mut self.latent_buffer);
latent.fill(0.0);
// Pre-allocate answer pooling buffer to avoid repeated allocations
let mut answer_pooled = vec![0.0; self.config.embedding_dim];
let mut prev_confidence = 0.0;
let mut early_stopped = false;
let mut iterations_used = 0;
@ -173,8 +204,8 @@ impl TrmEngine {
for k in 0..k_iterations {
let iter_start = Instant::now();
// Pool current answer
let answer_pooled = self.mean_pool(answer, self.config.embedding_dim);
// Pool current answer into pre-allocated buffer
self.mean_pool_into(answer, &mut answer_pooled);
// N latent update iterations
for _ in 0..self.config.latent_iterations {
@ -228,6 +259,9 @@ impl TrmEngine {
prev_confidence = confidence;
}
// Restore latent buffer for reuse
self.latent_buffer = latent;
let total_latency = start_time.elapsed().as_micros() as u64;
let final_confidence = trajectory.final_confidence();

View file

@ -159,10 +159,12 @@ impl MlpLatentUpdater {
}
}
/// Sigmoid activation with a clamped input.
///
/// Replaces every element `x` of `data` with `1 / (1 + exp(-x))`, where `x`
/// is first clamped to `[-20, 20]` so the argument of `exp()` stays bounded.
/// The sigmoid is applied exactly once per element.
fn sigmoid_inplace(data: &mut [f32]) {
    for x in data.iter_mut() {
        // Clamp to prevent overflow in exp()
        let clamped = (*x).clamp(-20.0, 20.0);
        *x = 1.0 / (1.0 + (-clamped).exp());
    }
}