diff --git a/crates/ruvector-rabitq/src/lib.rs b/crates/ruvector-rabitq/src/lib.rs index 0fbfd02df..ecfa27441 100644 --- a/crates/ruvector-rabitq/src/lib.rs +++ b/crates/ruvector-rabitq/src/lib.rs @@ -45,6 +45,7 @@ pub mod error; pub mod index; pub mod kernel; +pub mod persist; pub mod quantize; pub mod rotation; pub mod scan; diff --git a/crates/ruvector-rabitq/src/persist.rs b/crates/ruvector-rabitq/src/persist.rs new file mode 100644 index 000000000..3aed5f909 --- /dev/null +++ b/crates/ruvector-rabitq/src/persist.rs @@ -0,0 +1,393 @@ +//! Persistence for [`RabitqPlusIndex`] — seed-based reconstruction format. +//! +//! ## Why seed-based, not direct field serialization +//! +//! `RabitqPlusIndex` keeps its inner [`RabitqIndex`] and `originals_flat` behind +//! private fields and does not expose getters for them (`rerank_factor()` is readable). +//! Rather than widen that encapsulation just for a disk format, this module +//! relies on a stronger property already documented by the crate: +//! +//! > Deterministic: `(dim, seed, data)` triple → bit-identical rotation + +//! > index build + search output across runs. +//! +//! So we persist exactly what's needed to replay the build — `(dim, seed, +//! rerank_factor, items)` — and reconstruct on load via the public +//! [`RabitqPlusIndex::from_vectors_parallel`] constructor. This keeps the +//! on-disk format small (no 4·D² rotation matrix, no `n·n_words·8` packed +//! codes, no cos-LUT), fully portable across machines, and immune to drift in +//! the private layout. +//! +//! ## On-disk layout +//! +//! | Offset | Size (bytes) | Field | +//! |--------|--------------|-------| +//! | 0 | 8 | magic = `b"rbpx0001"` | +//! | 8 | 4 | version (u32 LE) | +//! | 12 | 4 | dim (u32 LE) | +//! | 16 | 8 | seed (u64 LE) | +//! | 24 | 4 | rerank_factor (u32 LE) | +//! | 28 | 4 | n (u32 LE) | +//! | 32 | n × (4 + dim·4) | entries: id (u32 LE), then dim f32 LE | +//! +//! All multi-byte integers and floats are little-endian. +//! +//! 
## Bounds +//! +//! Load rejects files whose header claims: +//! - magic ≠ `b"rbpx0001"`, +//! - version > current (1), +//! - dim == 0 or dim > [`MAX_DIM`] (8192), +//! - n > [`MAX_N`] (100M), +//! - rerank_factor > [`MAX_RERANK_FACTOR`] (1024). + +use std::io::{Read, Write}; + +use crate::error::{RabitqError, Result}; +use crate::index::{AnnIndex, RabitqPlusIndex}; + +/// 8-byte magic prefix identifying the `rbpx` container, version stripe `0001`. +pub const MAGIC: &[u8; 8] = b"rbpx0001"; +/// Current on-disk format version. Bumped on any layout change. +pub const VERSION: u32 = 1; + +/// Upper bound on `dim` accepted by [`load_index`] — 8192 covers every +/// production embedding model we target and keeps header validation cheap. +pub const MAX_DIM: u32 = 8192; +/// Upper bound on `n` accepted by [`load_index`] — 100 M entries. +pub const MAX_N: u32 = 100_000_000; +/// Upper bound on `rerank_factor` accepted by [`load_index`]. +pub const MAX_RERANK_FACTOR: u32 = 1024; + +// ── internal helpers ──────────────────────────────────────────────────────── + +fn io_err(msg: impl Into) -> RabitqError { + RabitqError::InvalidParameter(msg.into()) +} + +fn write_all(w: &mut W, buf: &[u8]) -> Result<()> { + w.write_all(buf).map_err(|e| io_err(format!("write: {e}"))) +} + +fn read_exact(r: &mut R, buf: &mut [u8]) -> Result<()> { + r.read_exact(buf).map_err(|e| io_err(format!("read: {e}"))) +} + +fn write_u32(w: &mut W, v: u32) -> Result<()> { + write_all(w, &v.to_le_bytes()) +} + +fn write_u64(w: &mut W, v: u64) -> Result<()> { + write_all(w, &v.to_le_bytes()) +} + +fn write_f32(w: &mut W, v: f32) -> Result<()> { + write_all(w, &v.to_le_bytes()) +} + +fn read_u32(r: &mut R) -> Result { + let mut b = [0u8; 4]; + read_exact(r, &mut b)?; + Ok(u32::from_le_bytes(b)) +} + +fn read_u64(r: &mut R) -> Result { + let mut b = [0u8; 8]; + read_exact(r, &mut b)?; + Ok(u64::from_le_bytes(b)) +} + +fn read_f32(r: &mut R) -> Result { + let mut b = [0u8; 4]; + read_exact(r, &mut b)?; + 
Ok(f32::from_le_bytes(b)) +} + +// ── public API ────────────────────────────────────────────────────────────── + +/// Serialize a [`RabitqPlusIndex`] by persisting the inputs required to +/// deterministically rebuild it. +/// +/// Because `RabitqPlusIndex` does not expose its inner rotation matrix or +/// `originals_flat` through the public API, the caller must supply the +/// `(seed, items)` pair that was used to build `idx` (typically via +/// [`RabitqPlusIndex::from_vectors_parallel`] or a sequence of `add` calls). +/// The `idx` argument is used to read `len()`, `dim()`, and `rerank_factor()` +/// for cross-checking against `items` — this catches drift between the +/// in-memory index and the `items` the caller thinks produced it before bad +/// bytes hit disk. +pub fn save_index( + idx: &RabitqPlusIndex, + seed: u64, + items: &[(usize, Vec)], + w: &mut W, +) -> Result<()> { + let dim = idx.dim(); + let n = idx.len(); + let rerank_factor = idx.rerank_factor(); + + // Cross-check the caller's inputs match the index they claim to represent. + if items.len() != n { + return Err(io_err(format!( + "items.len()={} but index.len()={}", + items.len(), + n + ))); + } + for (i, (_, v)) in items.iter().enumerate() { + if v.len() != dim { + return Err(RabitqError::DimensionMismatch { + expected: dim, + actual: v.len(), + }) + .map_err(|_| io_err(format!("item {i}: vector dim {} != {}", v.len(), dim))); + } + } + + // Bounds — keep the disk format inside the same caps load_index enforces. + if dim == 0 || dim as u32 > MAX_DIM { + return Err(io_err(format!("dim {dim} out of range (1..={MAX_DIM})"))); + } + if n as u64 > MAX_N as u64 { + return Err(io_err(format!("n {n} exceeds cap {MAX_N}"))); + } + if rerank_factor as u32 > MAX_RERANK_FACTOR { + return Err(io_err(format!( + "rerank_factor {rerank_factor} exceeds cap {MAX_RERANK_FACTOR}" + ))); + } + + // Header. 
+ write_all(w, MAGIC)?; + write_u32(w, VERSION)?; + write_u32(w, dim as u32)?; + write_u64(w, seed)?; + write_u32(w, rerank_factor as u32)?; + write_u32(w, n as u32)?; + + // Payload. + for (id, v) in items { + // u32 id — upstream uses usize but RabitqIndex already stores u32 + // internally, so we inherit the same narrowing at the API boundary. + if *id > u32::MAX as usize { + return Err(io_err(format!("id {id} exceeds u32::MAX"))); + } + write_u32(w, *id as u32)?; + for &x in v { + write_f32(w, x)?; + } + } + Ok(()) +} + +/// Deserialize a [`RabitqPlusIndex`] previously written by [`save_index`]. +/// +/// The rotation matrix, binary codes, cos-LUT, and `last_word_mask` are all +/// rebuilt deterministically from `(dim, seed, items)` — no per-field round-trip is +/// needed and the reconstructed index is byte-identical to the saved one. +pub fn load_index(r: &mut R) -> Result { + // Magic. + let mut magic = [0u8; 8]; + read_exact(r, &mut magic)?; + if &magic != MAGIC { + return Err(io_err(format!( + "bad magic: expected {:?}, got {:?}", + MAGIC, &magic + ))); + } + + // Version. + let version = read_u32(r)?; + if version > VERSION { + return Err(io_err(format!( + "unsupported version {version} (max {VERSION})" + ))); + } + + // Header fields, each bounded. + let dim = read_u32(r)?; + if dim == 0 || dim > MAX_DIM { + return Err(io_err(format!("dim {dim} out of range (1..={MAX_DIM})"))); + } + let seed = read_u64(r)?; + let rerank_factor = read_u32(r)?; + if rerank_factor > MAX_RERANK_FACTOR { + return Err(io_err(format!( + "rerank_factor {rerank_factor} exceeds cap {MAX_RERANK_FACTOR}" + ))); + } + let n = read_u32(r)?; + if n > MAX_N { + return Err(io_err(format!("n {n} exceeds cap {MAX_N}"))); + } + + // Payload. + let dim_usize = dim as usize; + let mut items: Vec<(usize, Vec)> = Vec::with_capacity(n as usize); + for _ in 0..n { + let id = read_u32(r)? 
as usize; + let mut v = Vec::with_capacity(dim_usize); + for _ in 0..dim_usize { + v.push(read_f32(r)?); + } + items.push((id, v)); + } + + // Deterministic rebuild — same (dim, seed, data) → bit-identical index. + RabitqPlusIndex::from_vectors_parallel(dim_usize, seed, rerank_factor as usize, items) +} + +// ── tests ─────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use crate::index::AnnIndex; + use rand::{Rng as _, SeedableRng as _}; + + fn make_dataset(n: usize, d: usize, seed: u64) -> Vec<(usize, Vec)> { + let mut rng = rand::rngs::StdRng::seed_from_u64(seed); + (0..n) + .map(|i| { + let v: Vec = (0..d).map(|_| rng.gen::() * 2.0 - 1.0).collect(); + (i, v) + }) + .collect() + } + + #[test] + fn serialize_roundtrip_preserves_search_results() { + let d = 32; + let n = 100; + let seed = 1337u64; + let rerank_factor = 3; + + let data = make_dataset(n, d, seed); + + let mut original = RabitqPlusIndex::new(d, seed, rerank_factor); + for (id, v) in &data { + original.add(*id, v.clone()).unwrap(); + } + + // Save. + let mut buf: Vec = Vec::new(); + save_index(&original, seed, &data, &mut buf).unwrap(); + + // Header size = 8 + 4 + 4 + 8 + 4 + 4 = 32 bytes, payload = n*(4 + d*4). + assert_eq!(buf.len(), 32 + n * (4 + d * 4)); + + // Load. + let mut cursor = std::io::Cursor::new(&buf); + let loaded = load_index(&mut cursor).unwrap(); + + assert_eq!(loaded.len(), n); + assert_eq!(loaded.dim(), d); + assert_eq!(loaded.rerank_factor(), rerank_factor); + + // Run 10 queries and assert ids + scores match exactly. 
+ let mut rng = rand::rngs::StdRng::seed_from_u64(seed.wrapping_add(7)); + let k = 5; + for q_idx in 0..10 { + let q: Vec = (0..d).map(|_| rng.gen::() * 2.0 - 1.0).collect(); + let a = original.search(&q, k).unwrap(); + let b = loaded.search(&q, k).unwrap(); + assert_eq!(a.len(), b.len(), "query {q_idx}: result count"); + for (ra, rb) in a.iter().zip(b.iter()) { + assert_eq!(ra.id, rb.id, "query {q_idx}: id mismatch"); + // Scores come from exact f32 rerank over the same candidate set — + // bit-identical rebuild means they must match exactly. + assert_eq!( + ra.score.to_bits(), + rb.score.to_bits(), + "query {q_idx}: score bits differ ({} vs {})", + ra.score, + rb.score + ); + } + } + } + + /// `RabitqPlusIndex` doesn't derive `Debug`, so `Result::unwrap_err()` + /// is unavailable. This helper extracts the error without requiring + /// `Debug` on `T`. + fn expect_err(res: Result) -> RabitqError { + match res { + Ok(_) => panic!("expected load_index to reject the input"), + Err(e) => e, + } + } + + #[test] + fn reject_bad_magic() { + let mut buf: Vec = Vec::new(); + buf.extend_from_slice(b"NOPEBAD!"); + buf.extend_from_slice(&1u32.to_le_bytes()); // version + let mut cursor = std::io::Cursor::new(&buf); + let err = expect_err(load_index(&mut cursor)); + let msg = format!("{err}"); + assert!(msg.contains("bad magic"), "got: {msg}"); + } + + #[test] + fn reject_version_too_new() { + let mut buf: Vec = Vec::new(); + buf.extend_from_slice(MAGIC); + buf.extend_from_slice(&(VERSION + 1).to_le_bytes()); + let mut cursor = std::io::Cursor::new(&buf); + let err = expect_err(load_index(&mut cursor)); + let msg = format!("{err}"); + assert!(msg.contains("unsupported version"), "got: {msg}"); + } + + #[test] + fn reject_oversize_fields() { + // dim too large. 
+ { + let mut buf: Vec = Vec::new(); + buf.extend_from_slice(MAGIC); + buf.extend_from_slice(&VERSION.to_le_bytes()); + buf.extend_from_slice(&(MAX_DIM + 1).to_le_bytes()); + let mut cursor = std::io::Cursor::new(&buf); + let err = expect_err(load_index(&mut cursor)); + let msg = format!("{err}"); + assert!(msg.contains("dim"), "got: {msg}"); + } + // dim zero. + { + let mut buf: Vec = Vec::new(); + buf.extend_from_slice(MAGIC); + buf.extend_from_slice(&VERSION.to_le_bytes()); + buf.extend_from_slice(&0u32.to_le_bytes()); + let mut cursor = std::io::Cursor::new(&buf); + let err = expect_err(load_index(&mut cursor)); + let msg = format!("{err}"); + assert!(msg.contains("dim"), "got: {msg}"); + } + // rerank_factor too large. + { + let mut buf: Vec = Vec::new(); + buf.extend_from_slice(MAGIC); + buf.extend_from_slice(&VERSION.to_le_bytes()); + buf.extend_from_slice(&32u32.to_le_bytes()); // dim + buf.extend_from_slice(&0u64.to_le_bytes()); // seed + buf.extend_from_slice(&(MAX_RERANK_FACTOR + 1).to_le_bytes()); + let mut cursor = std::io::Cursor::new(&buf); + let err = expect_err(load_index(&mut cursor)); + let msg = format!("{err}"); + assert!(msg.contains("rerank_factor"), "got: {msg}"); + } + // n too large. + { + let mut buf: Vec = Vec::new(); + buf.extend_from_slice(MAGIC); + buf.extend_from_slice(&VERSION.to_le_bytes()); + buf.extend_from_slice(&32u32.to_le_bytes()); // dim + buf.extend_from_slice(&0u64.to_le_bytes()); // seed + buf.extend_from_slice(&1u32.to_le_bytes()); // rerank_factor + buf.extend_from_slice(&(MAX_N + 1).to_le_bytes()); + let mut cursor = std::io::Cursor::new(&buf); + let err = expect_err(load_index(&mut cursor)); + let msg = format!("{err}"); + assert!(msg.contains("n "), "got: {msg}"); + } + } +} diff --git a/crates/ruvector-rabitq/src/rotation.rs b/crates/ruvector-rabitq/src/rotation.rs index 483e04275..79607893c 100644 --- a/crates/ruvector-rabitq/src/rotation.rs +++ b/crates/ruvector-rabitq/src/rotation.rs @@ -1,23 +1,66 @@ -//! 
Random orthogonal rotation drawn from the Haar distribution via QR decomposition. +//! Random orthogonal rotation. //! -//! We use a thin QR via Gram-Schmidt so we stay dependency-free (no nalgebra required -//! at runtime). For D ≤ 2048 this is fast enough to build once and cache. +//! Two flavours are supported: +//! +//! * `HaarDense` — Haar-uniform `D×D` matrix built via Gram–Schmidt on an +//! i.i.d. Gaussian block. `apply` is `O(D²)`; storage is `4·D²` bytes. This +//! is the default and stays bit-identical to previous snapshots. +//! +//! * `HadamardSigned` — randomised Hadamard rotation `D₁·H·D₂·H·D₃` where +//! each `Dᵢ` is a ±1 diagonal and `H` is the Fast Walsh–Hadamard Transform. +//! Cost is `O(D log D)` with no matrix stored (just `3·D` signs). TurboQuant +//! (arXiv:2504.19874 §3.2) shows this hits the "close to Haar-uniform" +//! regime RaBitQ needs for its Johnson–Lindenstrauss-style error bound. +//! +//! For arbitrary `dim` the Hadamard flavour zero-pads up to the next power of +//! two, runs the butterfly there, then truncates back to `dim` — standard +//! FWHT-on-non-dyadic trick. -use rand::SeedableRng; +use rand::{Rng, SeedableRng}; use rand_distr::{Distribution, StandardNormal}; -/// A DxD random orthogonal matrix stored in row-major order. +/// Which random-rotation construction a `RandomRotation` is backed by. +#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +pub enum RandomRotationKind { + /// Dense `D×D` Haar-uniform orthogonal matrix. + HaarDense, + /// Randomised Hadamard: three random ±1 diagonals interleaved with FWHT. + HadamardSigned, +} + +/// Internal storage mode. Kept private so we can evolve it without breaking +/// callers — users interact via `apply` / `apply_into` / `bytes` / `kind`. +#[derive(Clone, serde::Serialize, serde::Deserialize)] +enum Mode { + /// Flattened row-major `D×D` matrix. 
+ HaarDense { matrix: Vec }, + /// Three ±1 sign vectors of length `padded_dim`, applied as `D₁·H·D₂·H·D₃`. + HadamardSigned { + signs: [Vec; 3], + padded_dim: usize, + }, +} + +/// A random (approximately) orthogonal rotation. /// -/// Applying it to a vector: `apply(&matrix, v)` costs O(D²) — build once, amortise. +/// Build once, apply many times. The default constructor `random` yields a +/// Haar-uniform `D×D` matrix for backward compatibility; `hadamard` opts in +/// to the `O(D log D)` HD-HD-HD variant. #[derive(Clone, serde::Serialize, serde::Deserialize)] pub struct RandomRotation { - /// Flattened row-major D×D matrix. - pub matrix: Vec, + mode: Mode, pub dim: usize, + /// Kept for backward compatibility with snapshots that accessed the raw + /// matrix. Populated only for `HaarDense`; empty for Hadamard. + #[serde(default)] + pub matrix: Vec, } impl RandomRotation { /// Sample a Haar-uniform orthogonal matrix of size `dim × dim`. + /// + /// Backward-compatible default: existing callers that expect a dense + /// matrix under `self.matrix` keep working unchanged. pub fn random(dim: usize, seed: u64) -> Self { let mut rng = rand::rngs::StdRng::seed_from_u64(seed); // Fill a dim×dim matrix with N(0,1) entries. @@ -50,7 +93,50 @@ impl RandomRotation { } let matrix: Vec = m.into_iter().flatten().collect(); - Self { matrix, dim } + Self { + mode: Mode::HaarDense { + matrix: matrix.clone(), + }, + dim, + matrix, + } + } + + /// Construct a randomised Hadamard rotation `D₁·H·D₂·H·D₃`. + /// + /// Stores only `3 × padded_dim` ±1 entries — no matrix materialised. + /// `padded_dim` is the next power of two `≥ dim`; for dyadic `dim` it + /// equals `dim`. + pub fn hadamard(dim: usize, seed: u64) -> Self { + assert!(dim > 0, "RandomRotation::hadamard: dim must be > 0"); + let padded_dim = dim.next_power_of_two(); + let mut rng = rand::rngs::StdRng::seed_from_u64(seed); + // Three independent ±1 sign vectors. 
+ let make_signs = |rng: &mut rand::rngs::StdRng| -> Vec { + (0..padded_dim) + .map(|_| if rng.gen::() { 1.0_f32 } else { -1.0_f32 }) + .collect() + }; + let signs = [ + make_signs(&mut rng), + make_signs(&mut rng), + make_signs(&mut rng), + ]; + + Self { + mode: Mode::HadamardSigned { signs, padded_dim }, + dim, + matrix: Vec::new(), + } + } + + /// Which construction backs this rotation. + #[inline] + pub fn kind(&self) -> RandomRotationKind { + match &self.mode { + Mode::HaarDense { .. } => RandomRotationKind::HaarDense, + Mode::HadamardSigned { .. } => RandomRotationKind::HadamardSigned, + } } /// Apply the rotation: out = P · v (length must equal dim). @@ -70,16 +156,49 @@ impl RandomRotation { pub fn apply_into(&self, v: &[f32], out: &mut [f32]) { debug_assert_eq!(v.len(), self.dim); debug_assert_eq!(out.len(), self.dim); - let d = self.dim; - for (i, out_i) in out.iter_mut().enumerate() { - let row = &self.matrix[i * d..(i + 1) * d]; - *out_i = row.iter().zip(v.iter()).map(|(&r, &x)| r * x).sum(); + match &self.mode { + Mode::HaarDense { matrix } => { + let d = self.dim; + for (i, out_i) in out.iter_mut().enumerate() { + let row = &matrix[i * d..(i + 1) * d]; + *out_i = row.iter().zip(v.iter()).map(|(&r, &x)| r * x).sum(); + } + } + Mode::HadamardSigned { signs, padded_dim } => { + // Scratch buffer at padded size — zero-pad the tail. + let mut buf = vec![0.0_f32; *padded_dim]; + buf[..self.dim].copy_from_slice(v); + // D₃ + for (b, s) in buf.iter_mut().zip(signs[2].iter()) { + *b *= *s; + } + fwht_inplace(&mut buf); + // D₂ + for (b, s) in buf.iter_mut().zip(signs[1].iter()) { + *b *= *s; + } + fwht_inplace(&mut buf); + // D₁ + for (b, s) in buf.iter_mut().zip(signs[0].iter()) { + *b *= *s; + } + // Normalise: two FWHT passes multiply the norm by `padded_dim` + // (each H is orthogonal only after dividing by √padded_dim), + // so the combined scale factor is 1 / padded_dim. 
+ let scale = 1.0_f32 / (*padded_dim as f32); + for (o, b) in out.iter_mut().zip(buf.iter().take(self.dim)) { + *o = b * scale; + } + } } } - /// Memory usage in bytes. + /// Memory usage in bytes of the rotation's internal storage. pub fn bytes(&self) -> usize { - self.matrix.len() * 4 + match &self.mode { + Mode::HaarDense { matrix } => matrix.len() * 4, + Mode::HadamardSigned { signs, .. } => signs.iter().map(|s| s.len() * 4).sum::(), + } } } @@ -91,9 +210,37 @@ pub fn normalize_inplace(v: &mut [f32]) { } } +/// In-place Fast Walsh–Hadamard Transform (unnormalised, natural order). +/// +/// Requires `buf.len()` to be a power of two. Runs the iterative butterfly: +/// at stage `h`, pairs of elements `(buf[i+j], buf[i+j+h])` are replaced by +/// their sum and difference. After completion, `buf` holds `H · buf_in` +/// where `H` is the unnormalised Hadamard matrix with `H Hᵀ = N · I`. +#[inline] +fn fwht_inplace(buf: &mut [f32]) { + let n = buf.len(); + debug_assert!(n.is_power_of_two(), "FWHT requires power-of-two length"); + let mut h = 1; + while h < n { + let mut i = 0; + while i < n { + for j in i..(i + h) { + let x = buf[j]; + let y = buf[j + h]; + buf[j] = x + y; + buf[j + h] = x - y; + } + i += h * 2; + } + h *= 2; + } +} + #[cfg(test)] mod tests { use super::*; + use rand::rngs::StdRng; + use rand_distr::StandardNormal; /// Full orthogonality check — every pair of rows must be orthonormal. /// Stricter than the shipped version at `f2dbb6efb` which only tested @@ -152,4 +299,88 @@ mod tests { let b = RandomRotation::random(64, 1234); assert_eq!(a.matrix, b.matrix); } + + // ----- Randomised Hadamard (HD-HD-HD) tests -------------------------------- + + /// Sample random unit vectors via StdRng + StandardNormal (seeded → reproducible). 
+ fn random_unit_vecs(dim: usize, n: usize, seed: u64) -> Vec> { + let mut rng = StdRng::seed_from_u64(seed); + (0..n) + .map(|_| { + let mut v: Vec = (0..dim) + .map(|_| { + >::sample(&StandardNormal, &mut rng) + as f32 + }) + .collect(); + normalize_inplace(&mut v); + v + }) + .collect() + } + + fn hadamard_norm_check(dim: usize, seed: u64) { + let rot = RandomRotation::hadamard(dim, seed); + assert_eq!(rot.kind(), RandomRotationKind::HadamardSigned); + let vecs = random_unit_vecs(dim, 100, seed ^ 0xDEAD_BEEF); + for v in &vecs { + let rv = rot.apply(v); + let n: f32 = rv.iter().map(|&x| x * x).sum::().sqrt(); + // Isotropy is approximate (truncation + padding break exact + // orthogonality) — loose ±5 % band keeps RaBitQ estimator safe. + assert!( + (0.95..=1.05).contains(&n), + "D={dim}: rotated unit vector has norm {n}", + ); + } + } + + /// D=128 and D=256 are powers of two — no padding path. + #[test] + fn hadamard_apply_preserves_norm_power_of_two() { + hadamard_norm_check(128, 7); + hadamard_norm_check(256, 11); + } + + /// D=1000 exercises the zero-pad-to-1024 branch plus the truncation + /// back to `dim`. Looser isotropy is expected and allowed by the ±5 % + /// tolerance. + #[test] + fn hadamard_apply_preserves_norm_non_power_of_two() { + hadamard_norm_check(1000, 3); + } + + /// Same seed → same three sign vectors, checked via bit-identical `apply` output. + #[test] + fn hadamard_is_deterministic() { + let a = RandomRotation::hadamard(128, 0xC0FFEE); + let b = RandomRotation::hadamard(128, 0xC0FFEE); + let v: Vec = (0..128_u32).map(|i| (i as f32).cos()).collect(); + assert_eq!(a.apply(&v), b.apply(&v)); + // Different seed must change the output. + let c = RandomRotation::hadamard(128, 0xC0FFEE + 1); + assert_ne!(a.apply(&v), c.apply(&v)); + } + + /// Correctness smoke: for a dyadic dim, the all-ones input after the + /// first FWHT collapses to `(N, 0, 0, …)` — a cheap way to verify the + /// butterfly without timing. 
+ #[test] + fn hadamard_is_fast() { + // FWHT of `[1; 8]` must be `[8, 0, 0, 0, 0, 0, 0, 0]`. + let mut buf = vec![1.0_f32; 8]; + fwht_inplace(&mut buf); + assert!((buf[0] - 8.0).abs() < 1e-6); + for v in &buf[1..] { + assert!(v.abs() < 1e-6); + } + + // Storage footprint: Hadamard must be dramatically smaller than Haar + // at non-trivial dim (3·D floats vs D² floats). + let had = RandomRotation::hadamard(128, 1); + let haar = RandomRotation::random(128, 1); + assert!(had.bytes() < haar.bytes() / 10); + assert_eq!(had.kind(), RandomRotationKind::HadamardSigned); + assert_eq!(haar.kind(), RandomRotationKind::HaarDense); + } } diff --git a/crates/ruvector-rabitq/src/scan.rs b/crates/ruvector-rabitq/src/scan.rs index 3376a9d3c..140ddf597 100644 --- a/crates/ruvector-rabitq/src/scan.rs +++ b/crates/ruvector-rabitq/src/scan.rs @@ -32,7 +32,14 @@ use std::sync::OnceLock; /// * `q_packed` — query words, length `n_words`. /// * `mask` — last-word mask (`!0u64` when `dim % 64 == 0`). /// * `out_agree` — output agreement counts, length `n`. -type ScanFn = fn(packed: &[u64], n_words: usize, n: usize, q_packed: &[u64], mask: u64, out_agree: &mut [u32]); +type ScanFn = fn( + packed: &[u64], + n_words: usize, + n: usize, + q_packed: &[u64], + mask: u64, + out_agree: &mut [u32], +); static SCAN_IMPL: OnceLock = OnceLock::new(); diff --git a/crates/ruvector-rulake/src/lake.rs b/crates/ruvector-rulake/src/lake.rs index b099172f4..bcced1f79 100644 --- a/crates/ruvector-rulake/src/lake.rs +++ b/crates/ruvector-rulake/src/lake.rs @@ -8,7 +8,7 @@ use std::collections::HashMap; use std::sync::Arc; use crate::backend::{BackendAdapter, BackendId}; -use crate::cache::{CacheKey, Consistency, VectorCache}; +use crate::cache::{intern_key, CacheKey, Consistency, InternedKey, VectorCache}; use crate::error::{Result, RuLakeError}; /// Result from a search — the external id and its estimated L2² score. 
@@ -226,9 +226,14 @@ impl RuLake { query: &[f32], k: usize, ) -> Result> { - let key: CacheKey = (backend.to_string(), collection.to_string()); + // Intern once per query — the memory-audit hot-path fix. Every + // downstream cache op takes `Arc` refcount bumps instead + // of cloning `String`s (memory-audit finding #1). + let key = intern_key(backend, collection); self.ensure_fresh(&key)?; - let hits = self.cache.search_cached(&key, query, k)?; + let hits = self + .cache + .search_cached_with_rerank_interned(&key, query, k, None)?; Ok(hits .into_iter() .map(|(id, score)| SearchResult { @@ -342,11 +347,11 @@ impl RuLake { k: usize, rerank_override: Option, ) -> Result> { - let key: CacheKey = (backend.to_string(), collection.to_string()); + let key = intern_key(backend, collection); self.ensure_fresh(&key)?; - let hits = self - .cache - .search_cached_with_rerank(&key, query, k, rerank_override)?; + let hits = + self.cache + .search_cached_with_rerank_interned(&key, query, k, rerank_override)?; Ok(hits .into_iter() .map(|(id, score)| SearchResult { @@ -376,9 +381,11 @@ impl RuLake { queries: &[Vec], k: usize, ) -> Result>> { - let key: CacheKey = (backend.to_string(), collection.to_string()); + let key = intern_key(backend, collection); self.ensure_fresh(&key)?; - let raw = self.cache.search_cached_batch(&key, queries, k, None)?; + let raw = self + .cache + .search_cached_batch_interned(&key, queries, k, None)?; Ok(raw .into_iter() .map(|v| { @@ -405,14 +412,12 @@ impl RuLake { /// entry pool (cached under another pointer) → just move the /// pointer, zero prime work. This is the cross-backend share. /// 4. Witness not in the pool → pull + prime. - fn ensure_fresh(&self, key: &CacheKey) -> Result<()> { - // Intern once at the entry of the hot path — every downstream - // mark_hit / mark_miss / per_backend_mut call then takes - // refcount-cheap Arc clones instead of cloning the owned - // String tuple. Memory-audit finding #1. 
- let interned = crate::cache::intern_key(&key.0, &key.1); - if self.cache.can_skip_check(key, self.consistency) { - self.cache.mark_hit(&interned); + fn ensure_fresh(&self, key: &InternedKey) -> Result<()> { + // The hot path already owns Arc clones of (backend, + // collection) — every downstream cache op is a refcount bump, + // never a String::clone. Memory-audit finding #1. + if self.cache.can_skip_check_interned(key, self.consistency) { + self.cache.mark_hit(key); return Ok(()); } @@ -424,18 +429,23 @@ impl RuLake { )?; let target_witness = bundle.rvf_witness.clone(); - if self.cache.witness_of(key).as_deref() == Some(target_witness.as_str()) { + if self.cache.witness_of_interned(key).as_deref() == Some(target_witness.as_str()) { // Case 2: pointer up-to-date. - self.cache.mark_hit(&interned); - self.cache.touch(key); + self.cache.mark_hit(key); + self.cache.touch_interned(key); return Ok(()); } - // Cases 3 + 4 are handled in `prime`: it reuses an existing - // entry for the target witness if present, or builds a new one. - self.cache.mark_miss(&interned); + // Cases 3 + 4 are handled in `prime_interned`: it reuses an + // existing entry for the target witness if present, or builds + // a new one. + self.cache.mark_miss(key); let batch = backend.pull_vectors(&key.1)?; - self.cache.prime(key.clone(), target_witness, batch)?; + // Clone the Arcs (refcount bumps) to hand the cache an owned + // InternedKey — no String allocation. + let owned_key: InternedKey = (Arc::clone(&key.0), Arc::clone(&key.1)); + self.cache + .prime_interned(owned_key, target_witness, batch)?; Ok(()) }