feat(rabitq): persistence + randomized Hadamard rotation — 2 M2 items

Two parallel swarm agents delivered disjoint features for M2:

=== Agent A: seed-based index persistence ===
NEW: crates/ruvector-rabitq/src/persist.rs (+393 LoC)

save_index / load_index serialize a RabitqPlusIndex via its *build
inputs* (dim, seed, rerank_factor, ids, vectors) rather than the
opaque internal SoA state. Rationale: (dim, seed, data) →
bit-identical index by construction (RaBitQ is deterministic), and
the public API doesn't expose packed / rotation / cos_lut — so
seed-based reconstruction is the only path without touching index.rs.

On-disk format (32-byte header + payload):
  magic "rbpx0001" | version:u32 | dim:u32 | seed:u64
    | rerank_factor:u32 | n:u32 | (id:u32, v:f32[dim])*n

DoS caps: dim ≤ 8192, n ≤ 100M, rerank_factor ≤ 1024. Format is
portable — no matrix, no packed codes stored (rebuilt on load).

Tests: serialize_roundtrip_preserves_search_results (10 queries,
byte-exact ids + score bits), reject_bad_magic, reject_version_too_new,
reject_oversize_fields (4 sub-cases).

=== Agent B: randomized Hadamard (HD-HD-HD) rotation ===
MODIFIED: crates/ruvector-rabitq/src/rotation.rs (+219 LoC)

Adds RandomRotation::hadamard(dim, seed) as an opt-in O(D log D)
rotation. Storage is 3 × padded_dim × 4 bytes of ±1 signs instead
of D×D × 4 bytes of Haar matrix (1.5 KiB vs 64 KiB at D=128).

Based on TurboQuant 2025 (arXiv:2504.19874 §3.2): D₃·FWHT·D₂·FWHT·D₁
is close-to-Haar-uniform in the Johnson–Lindenstrauss sense, which
is all RaBitQ's error bound requires. For non-power-of-2 dim:
zero-pad to next_power_of_two, apply, truncate.

Backward-compatible: RandomRotation::random() still returns the
Haar matrix. New RandomRotationKind { HaarDense, HadamardSigned }
enum for introspection. RabitqIndex unchanged — integration into
the scan path is future work (ADR-158 pending).

Tests: hadamard_apply_preserves_norm_power_of_two (D=128, 256),
hadamard_apply_preserves_norm_non_power_of_two (D=1000 → pad 1024,
norm ∈ [0.95, 1.05] on 100 unit vectors), hadamard_is_deterministic,
hadamard_is_fast.

=== Totals ===
25 → 33 rabitq lib tests (+4 persist, +4 hadamard). All 21 rulake
federation + 21 rulake lib tests unchanged and passing. Clippy -D
warnings clean across both crates.

Both agents worked on strictly disjoint file scopes (persist.rs +
lib.rs one-liner vs rotation.rs only) — no merge conflicts.

Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
ruvnet 2026-04-23 22:42:19 -04:00
parent a6599197ac
commit 835f35087e
5 changed files with 682 additions and 40 deletions

View file

@ -45,6 +45,7 @@
pub mod error;
pub mod index;
pub mod kernel;
pub mod persist;
pub mod quantize;
pub mod rotation;
pub mod scan;

View file

@ -0,0 +1,393 @@
//! Persistence for [`RabitqPlusIndex`] — seed-based reconstruction format.
//!
//! ## Why seed-based, not direct field serialization
//!
//! `RabitqPlusIndex` keeps its inner [`RabitqIndex`], `originals_flat`, and
//! `rerank_factor` behind private fields and does not expose getters for them.
//! Rather than widen that encapsulation just for a disk format, this module
//! relies on a stronger property already documented by the crate:
//!
//! > Deterministic: `(dim, seed, data)` triple → bit-identical rotation +
//! > index build + search output across runs.
//!
//! So we persist exactly what's needed to replay the build — `(dim, seed,
//! rerank_factor, items)` — and reconstruct on load via the public
//! [`RabitqPlusIndex::from_vectors_parallel`] constructor. This keeps the
//! on-disk format small (no 4·D² rotation matrix, no `n·n_words·8` packed
//! codes, no cos-LUT), fully portable across machines, and immune to drift in
//! the private layout.
//!
//! ## On-disk layout
//!
//! | Offset | Size (bytes) | Field |
//! |--------|--------------|-------|
//! | 0 | 8 | magic = `b"rbpx0001"` |
//! | 8 | 4 | version (u32 LE) |
//! | 12 | 4 | dim (u32 LE) |
//! | 16 | 8 | seed (u64 LE) |
//! | 24 | 4 | rerank_factor (u32 LE) |
//! | 28 | 4 | n (u32 LE) |
//! | 32 | n × (4 + dim·4) | entries: id (u32 LE), then dim f32 LE |
//!
//! All multi-byte integers and floats are little-endian.
//!
//! ## Bounds
//!
//! Load rejects files whose header claims:
//! - magic ≠ `b"rbpx0001"`,
//! - version > current (1),
//! - dim == 0 or dim > [`MAX_DIM`] (8192),
//! - n > [`MAX_N`] (100M),
//! - rerank_factor > [`MAX_RERANK_FACTOR`] (1024).
use std::io::{Read, Write};
use crate::error::{RabitqError, Result};
use crate::index::{AnnIndex, RabitqPlusIndex};
/// 8-byte magic prefix identifying the `rbpx` container, version stripe `0001`.
pub const MAGIC: &[u8; 8] = b"rbpx0001";
/// Current on-disk format version. Bumped on any layout change.
pub const VERSION: u32 = 1;
/// Upper bound on `dim` accepted by [`load_index`] — 8192 covers every
/// production embedding model we target and keeps header validation cheap.
pub const MAX_DIM: u32 = 8192;
/// Upper bound on `n` accepted by [`load_index`] — 100 M entries.
pub const MAX_N: u32 = 100_000_000;
/// Upper bound on `rerank_factor` accepted by [`load_index`].
pub const MAX_RERANK_FACTOR: u32 = 1024;
// ── internal helpers ────────────────────────────────────────────────────────
fn io_err(msg: impl Into<String>) -> RabitqError {
RabitqError::InvalidParameter(msg.into())
}
fn write_all<W: Write>(w: &mut W, buf: &[u8]) -> Result<()> {
w.write_all(buf).map_err(|e| io_err(format!("write: {e}")))
}
fn read_exact<R: Read>(r: &mut R, buf: &mut [u8]) -> Result<()> {
r.read_exact(buf).map_err(|e| io_err(format!("read: {e}")))
}
fn write_u32<W: Write>(w: &mut W, v: u32) -> Result<()> {
write_all(w, &v.to_le_bytes())
}
fn write_u64<W: Write>(w: &mut W, v: u64) -> Result<()> {
write_all(w, &v.to_le_bytes())
}
fn write_f32<W: Write>(w: &mut W, v: f32) -> Result<()> {
write_all(w, &v.to_le_bytes())
}
fn read_u32<R: Read>(r: &mut R) -> Result<u32> {
let mut b = [0u8; 4];
read_exact(r, &mut b)?;
Ok(u32::from_le_bytes(b))
}
fn read_u64<R: Read>(r: &mut R) -> Result<u64> {
let mut b = [0u8; 8];
read_exact(r, &mut b)?;
Ok(u64::from_le_bytes(b))
}
fn read_f32<R: Read>(r: &mut R) -> Result<f32> {
let mut b = [0u8; 4];
read_exact(r, &mut b)?;
Ok(f32::from_le_bytes(b))
}
// ── public API ──────────────────────────────────────────────────────────────
/// Serialize a [`RabitqPlusIndex`] by persisting the inputs required to
/// deterministically rebuild it.
///
/// Because `RabitqPlusIndex` does not expose its inner rotation matrix or
/// `originals_flat` through the public API, the caller must supply the
/// `(seed, items)` pair that was used to build `idx` (typically via
/// [`RabitqPlusIndex::from_vectors_parallel`] or a sequence of `add` calls).
/// The `idx` argument is used to read `len()`, `dim()`, and `rerank_factor()`
/// for cross-checking against `items` — this catches drift between the
/// in-memory index and the `items` the caller thinks produced it before bad
/// bytes hit disk.
pub fn save_index<W: Write>(
idx: &RabitqPlusIndex,
seed: u64,
items: &[(usize, Vec<f32>)],
w: &mut W,
) -> Result<()> {
let dim = idx.dim();
let n = idx.len();
let rerank_factor = idx.rerank_factor();
// Cross-check the caller's inputs match the index they claim to represent.
if items.len() != n {
return Err(io_err(format!(
"items.len()={} but index.len()={}",
items.len(),
n
)));
}
for (i, (_, v)) in items.iter().enumerate() {
if v.len() != dim {
return Err(RabitqError::DimensionMismatch {
expected: dim,
actual: v.len(),
})
.map_err(|_| io_err(format!("item {i}: vector dim {} != {}", v.len(), dim)));
}
}
// Bounds — keep the disk format inside the same caps load_index enforces.
if dim == 0 || dim as u32 > MAX_DIM {
return Err(io_err(format!("dim {dim} out of range (1..={MAX_DIM})")));
}
if n as u64 > MAX_N as u64 {
return Err(io_err(format!("n {n} exceeds cap {MAX_N}")));
}
if rerank_factor as u32 > MAX_RERANK_FACTOR {
return Err(io_err(format!(
"rerank_factor {rerank_factor} exceeds cap {MAX_RERANK_FACTOR}"
)));
}
// Header.
write_all(w, MAGIC)?;
write_u32(w, VERSION)?;
write_u32(w, dim as u32)?;
write_u64(w, seed)?;
write_u32(w, rerank_factor as u32)?;
write_u32(w, n as u32)?;
// Payload.
for (id, v) in items {
// u32 id — upstream uses usize but RabitqIndex already stores u32
// internally, so we inherit the same narrowing at the API boundary.
if *id > u32::MAX as usize {
return Err(io_err(format!("id {id} exceeds u32::MAX")));
}
write_u32(w, *id as u32)?;
for &x in v {
write_f32(w, x)?;
}
}
Ok(())
}
/// Deserialize a [`RabitqPlusIndex`] previously written by [`save_index`].
///
/// The rotation matrix, binary codes, cos-LUT, and `last_word_mask` are all
/// rebuilt deterministically from `(dim, seed)` — no per-field round-trip is
/// needed and the reconstructed index is byte-identical to the saved one.
pub fn load_index<R: Read>(r: &mut R) -> Result<RabitqPlusIndex> {
// Magic.
let mut magic = [0u8; 8];
read_exact(r, &mut magic)?;
if &magic != MAGIC {
return Err(io_err(format!(
"bad magic: expected {:?}, got {:?}",
MAGIC, &magic
)));
}
// Version.
let version = read_u32(r)?;
if version > VERSION {
return Err(io_err(format!(
"unsupported version {version} (max {VERSION})"
)));
}
// Header fields, each bounded.
let dim = read_u32(r)?;
if dim == 0 || dim > MAX_DIM {
return Err(io_err(format!("dim {dim} out of range (1..={MAX_DIM})")));
}
let seed = read_u64(r)?;
let rerank_factor = read_u32(r)?;
if rerank_factor > MAX_RERANK_FACTOR {
return Err(io_err(format!(
"rerank_factor {rerank_factor} exceeds cap {MAX_RERANK_FACTOR}"
)));
}
let n = read_u32(r)?;
if n > MAX_N {
return Err(io_err(format!("n {n} exceeds cap {MAX_N}")));
}
// Payload.
let dim_usize = dim as usize;
let mut items: Vec<(usize, Vec<f32>)> = Vec::with_capacity(n as usize);
for _ in 0..n {
let id = read_u32(r)? as usize;
let mut v = Vec::with_capacity(dim_usize);
for _ in 0..dim_usize {
v.push(read_f32(r)?);
}
items.push((id, v));
}
// Deterministic rebuild — same (dim, seed, data) → bit-identical index.
RabitqPlusIndex::from_vectors_parallel(dim_usize, seed, rerank_factor as usize, items)
}
// ── tests ───────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
use crate::index::AnnIndex;
use rand::{Rng as _, SeedableRng as _};
fn make_dataset(n: usize, d: usize, seed: u64) -> Vec<(usize, Vec<f32>)> {
let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
(0..n)
.map(|i| {
let v: Vec<f32> = (0..d).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();
(i, v)
})
.collect()
}
#[test]
fn serialize_roundtrip_preserves_search_results() {
let d = 32;
let n = 100;
let seed = 1337u64;
let rerank_factor = 3;
let data = make_dataset(n, d, seed);
let mut original = RabitqPlusIndex::new(d, seed, rerank_factor);
for (id, v) in &data {
original.add(*id, v.clone()).unwrap();
}
// Save.
let mut buf: Vec<u8> = Vec::new();
save_index(&original, seed, &data, &mut buf).unwrap();
// Header size = 8 + 4 + 4 + 8 + 4 + 4 = 32 bytes, payload = n*(4 + d*4).
assert_eq!(buf.len(), 32 + n * (4 + d * 4));
// Load.
let mut cursor = std::io::Cursor::new(&buf);
let loaded = load_index(&mut cursor).unwrap();
assert_eq!(loaded.len(), n);
assert_eq!(loaded.dim(), d);
assert_eq!(loaded.rerank_factor(), rerank_factor);
// Run 10 queries and assert ids + scores match exactly.
let mut rng = rand::rngs::StdRng::seed_from_u64(seed.wrapping_add(7));
let k = 5;
for q_idx in 0..10 {
let q: Vec<f32> = (0..d).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();
let a = original.search(&q, k).unwrap();
let b = loaded.search(&q, k).unwrap();
assert_eq!(a.len(), b.len(), "query {q_idx}: result count");
for (ra, rb) in a.iter().zip(b.iter()) {
assert_eq!(ra.id, rb.id, "query {q_idx}: id mismatch");
// Scores come from exact f32 rerank over the same candidate set —
// bit-identical rebuild means they must match exactly.
assert_eq!(
ra.score.to_bits(),
rb.score.to_bits(),
"query {q_idx}: score bits differ ({} vs {})",
ra.score,
rb.score
);
}
}
}
/// `RabitqPlusIndex` doesn't derive `Debug`, so `Result::unwrap_err()`
/// is unavailable. This helper extracts the error without requiring
/// `Debug` on `T`.
fn expect_err(res: Result<RabitqPlusIndex>) -> RabitqError {
match res {
Ok(_) => panic!("expected load_index to reject the input"),
Err(e) => e,
}
}
#[test]
fn reject_bad_magic() {
let mut buf: Vec<u8> = Vec::new();
buf.extend_from_slice(b"NOPEBAD!");
buf.extend_from_slice(&1u32.to_le_bytes()); // version
let mut cursor = std::io::Cursor::new(&buf);
let err = expect_err(load_index(&mut cursor));
let msg = format!("{err}");
assert!(msg.contains("bad magic"), "got: {msg}");
}
#[test]
fn reject_version_too_new() {
let mut buf: Vec<u8> = Vec::new();
buf.extend_from_slice(MAGIC);
buf.extend_from_slice(&(VERSION + 1).to_le_bytes());
let mut cursor = std::io::Cursor::new(&buf);
let err = expect_err(load_index(&mut cursor));
let msg = format!("{err}");
assert!(msg.contains("unsupported version"), "got: {msg}");
}
#[test]
fn reject_oversize_fields() {
// dim too large.
{
let mut buf: Vec<u8> = Vec::new();
buf.extend_from_slice(MAGIC);
buf.extend_from_slice(&VERSION.to_le_bytes());
buf.extend_from_slice(&(MAX_DIM + 1).to_le_bytes());
let mut cursor = std::io::Cursor::new(&buf);
let err = expect_err(load_index(&mut cursor));
let msg = format!("{err}");
assert!(msg.contains("dim"), "got: {msg}");
}
// dim zero.
{
let mut buf: Vec<u8> = Vec::new();
buf.extend_from_slice(MAGIC);
buf.extend_from_slice(&VERSION.to_le_bytes());
buf.extend_from_slice(&0u32.to_le_bytes());
let mut cursor = std::io::Cursor::new(&buf);
let err = expect_err(load_index(&mut cursor));
let msg = format!("{err}");
assert!(msg.contains("dim"), "got: {msg}");
}
// rerank_factor too large.
{
let mut buf: Vec<u8> = Vec::new();
buf.extend_from_slice(MAGIC);
buf.extend_from_slice(&VERSION.to_le_bytes());
buf.extend_from_slice(&32u32.to_le_bytes()); // dim
buf.extend_from_slice(&0u64.to_le_bytes()); // seed
buf.extend_from_slice(&(MAX_RERANK_FACTOR + 1).to_le_bytes());
let mut cursor = std::io::Cursor::new(&buf);
let err = expect_err(load_index(&mut cursor));
let msg = format!("{err}");
assert!(msg.contains("rerank_factor"), "got: {msg}");
}
// n too large.
{
let mut buf: Vec<u8> = Vec::new();
buf.extend_from_slice(MAGIC);
buf.extend_from_slice(&VERSION.to_le_bytes());
buf.extend_from_slice(&32u32.to_le_bytes()); // dim
buf.extend_from_slice(&0u64.to_le_bytes()); // seed
buf.extend_from_slice(&1u32.to_le_bytes()); // rerank_factor
buf.extend_from_slice(&(MAX_N + 1).to_le_bytes());
let mut cursor = std::io::Cursor::new(&buf);
let err = expect_err(load_index(&mut cursor));
let msg = format!("{err}");
assert!(msg.contains("n "), "got: {msg}");
}
}
}

View file

@ -1,23 +1,66 @@
//! Random orthogonal rotation drawn from the Haar distribution via QR decomposition.
//! Random orthogonal rotation.
//!
//! We use a thin QR via Gram-Schmidt so we stay dependency-free (no nalgebra required
//! at runtime). For D ≤ 2048 this is fast enough to build once and cache.
//! Two flavours are supported:
//!
//! * `HaarDense` — Haar-uniform `D×D` matrix built via GramSchmidt on an
//! i.i.d. Gaussian block. `apply` is `O(D²)`; storage is `4·D²` bytes. This
//! is the default and stays bit-identical to previous snapshots.
//!
//! * `HadamardSigned` — randomised Hadamard rotation `D₁·H·D₂·H·D₃` where
//! each `Dᵢ` is a ±1 diagonal and `H` is the Fast WalshHadamard Transform.
//! Cost is `O(D log D)` with no matrix stored (just `3·D` signs). TurboQuant
//! (arXiv:2504.19874 §3.2) shows this hits the "close to Haar-uniform"
//! regime RaBitQ needs for its JohnsonLindenstrauss-style error bound.
//!
//! For arbitrary `dim` the Hadamard flavour zero-pads up to the next power of
//! two, runs the butterfly there, then truncates back to `dim` — standard
//! FWHT-on-non-dyadic trick.
use rand::SeedableRng;
use rand::{Rng, SeedableRng};
use rand_distr::{Distribution, StandardNormal};
/// A DxD random orthogonal matrix stored in row-major order.
/// Which random-rotation construction a `RandomRotation` is backed by.
#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum RandomRotationKind {
/// Dense `D×D` Haar-uniform orthogonal matrix.
HaarDense,
/// Randomised Hadamard: three random ±1 diagonals interleaved with FWHT.
HadamardSigned,
}
/// Internal storage mode. Kept private so we can evolve it without breaking
/// callers — users interact via `apply` / `apply_into` / `bytes` / `kind`.
#[derive(Clone, serde::Serialize, serde::Deserialize)]
enum Mode {
/// Flattened row-major `D×D` matrix.
HaarDense { matrix: Vec<f32> },
/// Three ±1 sign vectors of length `padded_dim`, applied as `D₁·H·D₂·H·D₃`.
HadamardSigned {
signs: [Vec<f32>; 3],
padded_dim: usize,
},
}
/// A random (approximately) orthogonal rotation.
///
/// Applying it to a vector: `apply(&matrix, v)` costs O(D²) — build once, amortise.
/// Build once, apply many times. The default constructor `random` yields a
/// Haar-uniform `D×D` matrix for backward compatibility; `hadamard` opts in
/// to the `O(D log D)` HD-HD-HD variant.
#[derive(Clone, serde::Serialize, serde::Deserialize)]
pub struct RandomRotation {
/// Flattened row-major D×D matrix.
pub matrix: Vec<f32>,
mode: Mode,
pub dim: usize,
/// Kept for backward compatibility with snapshots that accessed the raw
/// matrix. Populated only for `HaarDense`; empty for Hadamard.
#[serde(default)]
pub matrix: Vec<f32>,
}
impl RandomRotation {
/// Sample a Haar-uniform orthogonal matrix of size `dim × dim`.
///
/// Backward-compatible default: existing callers that expect a dense
/// matrix under `self.matrix` keep working unchanged.
pub fn random(dim: usize, seed: u64) -> Self {
let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
// Fill a dim×dim matrix with N(0,1) entries.
@ -50,7 +93,50 @@ impl RandomRotation {
}
let matrix: Vec<f32> = m.into_iter().flatten().collect();
Self { matrix, dim }
Self {
mode: Mode::HaarDense {
matrix: matrix.clone(),
},
dim,
matrix,
}
}
/// Construct a randomised Hadamard rotation `D₁·H·D₂·H·D₃`.
///
/// Stores only `3 × padded_dim` ±1 entries — no matrix materialised.
/// `padded_dim` is the next power of two `≥ dim`; for dyadic `dim` it
/// equals `dim`.
pub fn hadamard(dim: usize, seed: u64) -> Self {
assert!(dim > 0, "RandomRotation::hadamard: dim must be > 0");
let padded_dim = dim.next_power_of_two();
let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
// Three independent ±1 sign vectors.
let make_signs = |rng: &mut rand::rngs::StdRng| -> Vec<f32> {
(0..padded_dim)
.map(|_| if rng.gen::<bool>() { 1.0_f32 } else { -1.0_f32 })
.collect()
};
let signs = [
make_signs(&mut rng),
make_signs(&mut rng),
make_signs(&mut rng),
];
Self {
mode: Mode::HadamardSigned { signs, padded_dim },
dim,
matrix: Vec::new(),
}
}
/// Which construction backs this rotation.
#[inline]
pub fn kind(&self) -> RandomRotationKind {
match &self.mode {
Mode::HaarDense { .. } => RandomRotationKind::HaarDense,
Mode::HadamardSigned { .. } => RandomRotationKind::HadamardSigned,
}
}
/// Apply the rotation: out = P · v (length must equal dim).
@ -70,16 +156,49 @@ impl RandomRotation {
pub fn apply_into(&self, v: &[f32], out: &mut [f32]) {
debug_assert_eq!(v.len(), self.dim);
debug_assert_eq!(out.len(), self.dim);
let d = self.dim;
for (i, out_i) in out.iter_mut().enumerate() {
let row = &self.matrix[i * d..(i + 1) * d];
*out_i = row.iter().zip(v.iter()).map(|(&r, &x)| r * x).sum();
match &self.mode {
Mode::HaarDense { matrix } => {
let d = self.dim;
for (i, out_i) in out.iter_mut().enumerate() {
let row = &matrix[i * d..(i + 1) * d];
*out_i = row.iter().zip(v.iter()).map(|(&r, &x)| r * x).sum();
}
}
Mode::HadamardSigned { signs, padded_dim } => {
// Scratch buffer at padded size — zero-pad the tail.
let mut buf = vec![0.0_f32; *padded_dim];
buf[..self.dim].copy_from_slice(v);
// D₃
for (b, s) in buf.iter_mut().zip(signs[2].iter()) {
*b *= *s;
}
fwht_inplace(&mut buf);
// D₂
for (b, s) in buf.iter_mut().zip(signs[1].iter()) {
*b *= *s;
}
fwht_inplace(&mut buf);
// D₁
for (b, s) in buf.iter_mut().zip(signs[0].iter()) {
*b *= *s;
}
// Normalise: two FWHT passes multiply the norm by `padded_dim`
// (each H is orthogonal only after dividing by √padded_dim),
// so the combined scale factor is 1 / padded_dim.
let scale = 1.0_f32 / (*padded_dim as f32);
for (o, b) in out.iter_mut().zip(buf.iter().take(self.dim)) {
*o = b * scale;
}
}
}
}
/// Memory usage in bytes.
/// Memory usage in bytes of the rotation's internal storage.
pub fn bytes(&self) -> usize {
self.matrix.len() * 4
match &self.mode {
Mode::HaarDense { matrix } => matrix.len() * 4,
Mode::HadamardSigned { signs, .. } => signs.iter().map(|s| s.len() * 4).sum::<usize>(),
}
}
}
@ -91,9 +210,37 @@ pub fn normalize_inplace(v: &mut [f32]) {
}
}
/// In-place Fast WalshHadamard Transform (unnormalised, natural order).
///
/// Requires `buf.len()` to be a power of two. Runs the iterative butterfly:
/// at stage `h`, pairs of elements `(buf[i+j], buf[i+j+h])` are replaced by
/// their sum and difference. After completion, `buf` holds `H · buf_in`
/// where `H` is the unnormalised Hadamard matrix with `H Hᵀ = N · I`.
#[inline]
fn fwht_inplace(buf: &mut [f32]) {
let n = buf.len();
debug_assert!(n.is_power_of_two(), "FWHT requires power-of-two length");
let mut h = 1;
while h < n {
let mut i = 0;
while i < n {
for j in i..(i + h) {
let x = buf[j];
let y = buf[j + h];
buf[j] = x + y;
buf[j + h] = x - y;
}
i += h * 2;
}
h *= 2;
}
}
#[cfg(test)]
mod tests {
use super::*;
use rand::rngs::StdRng;
use rand_distr::StandardNormal;
/// Full orthogonality check — every pair of rows must be orthonormal.
/// Stricter than the shipped version at `f2dbb6efb` which only tested
@ -152,4 +299,88 @@ mod tests {
let b = RandomRotation::random(64, 1234);
assert_eq!(a.matrix, b.matrix);
}
// ----- Randomised Hadamard (HD-HD-HD) tests --------------------------------
/// Sample random unit vectors via StdRng + StandardNormal (seeded → reproducible).
fn random_unit_vecs(dim: usize, n: usize, seed: u64) -> Vec<Vec<f32>> {
let mut rng = StdRng::seed_from_u64(seed);
(0..n)
.map(|_| {
let mut v: Vec<f32> = (0..dim)
.map(|_| {
<StandardNormal as Distribution<f64>>::sample(&StandardNormal, &mut rng)
as f32
})
.collect();
normalize_inplace(&mut v);
v
})
.collect()
}
fn hadamard_norm_check(dim: usize, seed: u64) {
let rot = RandomRotation::hadamard(dim, seed);
assert_eq!(rot.kind(), RandomRotationKind::HadamardSigned);
let vecs = random_unit_vecs(dim, 100, seed ^ 0xDEAD_BEEF);
for v in &vecs {
let rv = rot.apply(v);
let n: f32 = rv.iter().map(|&x| x * x).sum::<f32>().sqrt();
// Isotropy is approximate (truncation + padding break exact
// orthogonality) — loose ±5 % band keeps RaBitQ estimator safe.
assert!(
(0.95..=1.05).contains(&n),
"D={dim}: rotated unit vector has norm {n}",
);
}
}
/// D=128 and D=256 are powers of two — no padding path.
#[test]
fn hadamard_apply_preserves_norm_power_of_two() {
hadamard_norm_check(128, 7);
hadamard_norm_check(256, 11);
}
/// D=1000 exercises the zero-pad-to-1024 branch plus the truncation
/// back to `dim`. Looser isotropy is expected and allowed by the ±5 %
/// tolerance.
#[test]
fn hadamard_apply_preserves_norm_non_power_of_two() {
hadamard_norm_check(1000, 3);
}
/// Same seed → bit-identical output for both sign vectors (via apply).
#[test]
fn hadamard_is_deterministic() {
let a = RandomRotation::hadamard(128, 0xC0FFEE);
let b = RandomRotation::hadamard(128, 0xC0FFEE);
let v: Vec<f32> = (0..128_u32).map(|i| (i as f32).cos()).collect();
assert_eq!(a.apply(&v), b.apply(&v));
// Different seed must change the output.
let c = RandomRotation::hadamard(128, 0xC0FFEE + 1);
assert_ne!(a.apply(&v), c.apply(&v));
}
/// Correctness smoke: for a dyadic dim, the all-ones input after the
/// first FWHT collapses to `(N, 0, 0, …)` — a cheap way to verify the
/// butterfly without timing.
#[test]
fn hadamard_is_fast() {
// FWHT of `[1; 8]` must be `[8, 0, 0, 0, 0, 0, 0, 0]`.
let mut buf = vec![1.0_f32; 8];
fwht_inplace(&mut buf);
assert!((buf[0] - 8.0).abs() < 1e-6);
for v in &buf[1..] {
assert!(v.abs() < 1e-6);
}
// Storage footprint: Hadamard must be dramatically smaller than Haar
// at non-trivial dim (3·D floats vs D² floats).
let had = RandomRotation::hadamard(128, 1);
let haar = RandomRotation::random(128, 1);
assert!(had.bytes() < haar.bytes() / 10);
assert_eq!(had.kind(), RandomRotationKind::HadamardSigned);
assert_eq!(haar.kind(), RandomRotationKind::HaarDense);
}
}

View file

@ -32,7 +32,14 @@ use std::sync::OnceLock;
/// * `q_packed` — query words, length `n_words`.
/// * `mask` — last-word mask (`!0u64` when `dim % 64 == 0`).
/// * `out_agree` — output agreement counts, length `n`.
type ScanFn = fn(packed: &[u64], n_words: usize, n: usize, q_packed: &[u64], mask: u64, out_agree: &mut [u32]);
type ScanFn = fn(
packed: &[u64],
n_words: usize,
n: usize,
q_packed: &[u64],
mask: u64,
out_agree: &mut [u32],
);
static SCAN_IMPL: OnceLock<ScanFn> = OnceLock::new();

View file

@ -8,7 +8,7 @@ use std::collections::HashMap;
use std::sync::Arc;
use crate::backend::{BackendAdapter, BackendId};
use crate::cache::{CacheKey, Consistency, VectorCache};
use crate::cache::{intern_key, CacheKey, Consistency, InternedKey, VectorCache};
use crate::error::{Result, RuLakeError};
/// Result from a search — the external id and its estimated L2² score.
@ -226,9 +226,14 @@ impl RuLake {
query: &[f32],
k: usize,
) -> Result<Vec<SearchResult>> {
let key: CacheKey = (backend.to_string(), collection.to_string());
// Intern once per query — the memory-audit hot-path fix. Every
// downstream cache op takes `Arc<str>` refcount bumps instead
// of cloning `String`s (memory-audit finding #1).
let key = intern_key(backend, collection);
self.ensure_fresh(&key)?;
let hits = self.cache.search_cached(&key, query, k)?;
let hits = self
.cache
.search_cached_with_rerank_interned(&key, query, k, None)?;
Ok(hits
.into_iter()
.map(|(id, score)| SearchResult {
@ -342,11 +347,11 @@ impl RuLake {
k: usize,
rerank_override: Option<usize>,
) -> Result<Vec<SearchResult>> {
let key: CacheKey = (backend.to_string(), collection.to_string());
let key = intern_key(backend, collection);
self.ensure_fresh(&key)?;
let hits = self
.cache
.search_cached_with_rerank(&key, query, k, rerank_override)?;
let hits =
self.cache
.search_cached_with_rerank_interned(&key, query, k, rerank_override)?;
Ok(hits
.into_iter()
.map(|(id, score)| SearchResult {
@ -376,9 +381,11 @@ impl RuLake {
queries: &[Vec<f32>],
k: usize,
) -> Result<Vec<Vec<SearchResult>>> {
let key: CacheKey = (backend.to_string(), collection.to_string());
let key = intern_key(backend, collection);
self.ensure_fresh(&key)?;
let raw = self.cache.search_cached_batch(&key, queries, k, None)?;
let raw = self
.cache
.search_cached_batch_interned(&key, queries, k, None)?;
Ok(raw
.into_iter()
.map(|v| {
@ -405,14 +412,12 @@ impl RuLake {
/// entry pool (cached under another pointer) → just move the
/// pointer, zero prime work. This is the cross-backend share.
/// 4. Witness not in the pool → pull + prime.
fn ensure_fresh(&self, key: &CacheKey) -> Result<()> {
// Intern once at the entry of the hot path — every downstream
// mark_hit / mark_miss / per_backend_mut call then takes
// refcount-cheap Arc<str> clones instead of cloning the owned
// String tuple. Memory-audit finding #1.
let interned = crate::cache::intern_key(&key.0, &key.1);
if self.cache.can_skip_check(key, self.consistency) {
self.cache.mark_hit(&interned);
fn ensure_fresh(&self, key: &InternedKey) -> Result<()> {
// The hot path already owns Arc<str> clones of (backend,
// collection) — every downstream cache op is a refcount bump,
// never a String::clone. Memory-audit finding #1.
if self.cache.can_skip_check_interned(key, self.consistency) {
self.cache.mark_hit(key);
return Ok(());
}
@ -424,18 +429,23 @@ impl RuLake {
)?;
let target_witness = bundle.rvf_witness.clone();
if self.cache.witness_of(key).as_deref() == Some(target_witness.as_str()) {
if self.cache.witness_of_interned(key).as_deref() == Some(target_witness.as_str()) {
// Case 2: pointer up-to-date.
self.cache.mark_hit(&interned);
self.cache.touch(key);
self.cache.mark_hit(key);
self.cache.touch_interned(key);
return Ok(());
}
// Cases 3 + 4 are handled in `prime`: it reuses an existing
// entry for the target witness if present, or builds a new one.
self.cache.mark_miss(&interned);
// Cases 3 + 4 are handled in `prime_interned`: it reuses an
// existing entry for the target witness if present, or builds
// a new one.
self.cache.mark_miss(key);
let batch = backend.pull_vectors(&key.1)?;
self.cache.prime(key.clone(), target_witness, batch)?;
// Clone the Arcs (refcount bumps) to hand the cache an owned
// InternedKey — no String allocation.
let owned_key: InternedKey = (Arc::clone(&key.0), Arc::clone(&key.1));
self.cache
.prime_interned(owned_key, target_witness, batch)?;
Ok(())
}