mirror of
https://github.com/ruvnet/RuVector.git
synced 2026-05-25 06:36:37 +00:00
feat(rabitq): persistence + randomized Hadamard rotation — 2 M2 items
Two parallel swarm agents delivered disjoint features for M2:
=== Agent A: seed-based index persistence ===
NEW: crates/ruvector-rabitq/src/persist.rs (+393 LoC)
save_index / load_index serialize a RabitqPlusIndex via its *build
inputs* (dim, seed, rerank_factor, ids, vectors) rather than the
opaque internal SoA state. Rationale: (dim, seed, data) →
bit-identical index by construction (RaBitQ is deterministic), and
the public API doesn't expose packed / rotation / cos_lut — so
seed-based reconstruction is the only path without touching index.rs.
On-disk format (32-byte header + payload):
magic "rbpx0001" | version:u32 | dim:u32 | seed:u64
| rerank_factor:u32 | n:u32 | (id:u32, v:f32[dim])*n
DoS caps: dim ≤ 8192, n ≤ 100M, rerank_factor ≤ 1024. Format is
portable — no matrix, no packed codes stored (rebuilt on load).
Tests: serialize_roundtrip_preserves_search_results (10 queries,
byte-exact ids + score bits), reject_bad_magic, reject_version_too_new,
reject_oversize_fields (4 sub-cases).
=== Agent B: randomized Hadamard (HD-HD-HD) rotation ===
MODIFIED: crates/ruvector-rabitq/src/rotation.rs (+219 LoC)
Adds RandomRotation::hadamard(dim, seed) as an opt-in O(D log D)
rotation. Storage is 3 × padded_dim × 4 bytes of ±1 signs instead
of D×D × 4 bytes of Haar matrix (1.5 KiB vs 64 KiB at D=128).
Based on TurboQuant 2025 (arXiv:2504.19874 §3.2): D₃·FWHT·D₂·FWHT·D₁
is close-to-Haar-uniform in the Johnson–Lindenstrauss sense, which
is all RaBitQ's error bound requires. For non-power-of-2 dim:
zero-pad to next_power_of_two, apply, truncate.
Backward-compatible: RandomRotation::random() still returns the
Haar matrix. New RandomRotationKind { HaarDense, HadamardSigned }
enum for introspection. RabitqIndex unchanged — integration into
the scan path is future work (ADR-158 pending).
Tests: hadamard_apply_preserves_norm_power_of_two (D=128, 256),
hadamard_apply_preserves_norm_non_power_of_two (D=1000 → pad 1024,
norm ∈ [0.95, 1.05] on 100 unit vectors), hadamard_is_deterministic,
hadamard_is_fast.
=== Totals ===
25 → 33 rabitq lib tests (+4 persist, +4 hadamard). All 21 rulake
federation + 21 rulake lib tests unchanged and passing. Clippy -D
warnings clean across both crates.
Both agents worked on strictly disjoint file scopes (persist.rs +
lib.rs one-liner vs rotation.rs only) — no merge conflicts.
Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
parent
a6599197ac
commit
835f35087e
5 changed files with 682 additions and 40 deletions
|
|
@ -45,6 +45,7 @@
|
|||
pub mod error;
|
||||
pub mod index;
|
||||
pub mod kernel;
|
||||
pub mod persist;
|
||||
pub mod quantize;
|
||||
pub mod rotation;
|
||||
pub mod scan;
|
||||
|
|
|
|||
393
crates/ruvector-rabitq/src/persist.rs
Normal file
393
crates/ruvector-rabitq/src/persist.rs
Normal file
|
|
@ -0,0 +1,393 @@
|
|||
//! Persistence for [`RabitqPlusIndex`] — seed-based reconstruction format.
|
||||
//!
|
||||
//! ## Why seed-based, not direct field serialization
|
||||
//!
|
||||
//! `RabitqPlusIndex` keeps its inner [`RabitqIndex`], `originals_flat`, and
|
||||
//! `rerank_factor` behind private fields and does not expose getters for them.
|
||||
//! Rather than widen that encapsulation just for a disk format, this module
|
||||
//! relies on a stronger property already documented by the crate:
|
||||
//!
|
||||
//! > Deterministic: `(dim, seed, data)` triple → bit-identical rotation +
|
||||
//! > index build + search output across runs.
|
||||
//!
|
||||
//! So we persist exactly what's needed to replay the build — `(dim, seed,
|
||||
//! rerank_factor, items)` — and reconstruct on load via the public
|
||||
//! [`RabitqPlusIndex::from_vectors_parallel`] constructor. This keeps the
|
||||
//! on-disk format small (no 4·D² rotation matrix, no `n·n_words·8` packed
|
||||
//! codes, no cos-LUT), fully portable across machines, and immune to drift in
|
||||
//! the private layout.
|
||||
//!
|
||||
//! ## On-disk layout
|
||||
//!
|
||||
//! | Offset | Size (bytes) | Field |
|
||||
//! |--------|--------------|-------|
|
||||
//! | 0 | 8 | magic = `b"rbpx0001"` |
|
||||
//! | 8 | 4 | version (u32 LE) |
|
||||
//! | 12 | 4 | dim (u32 LE) |
|
||||
//! | 16 | 8 | seed (u64 LE) |
|
||||
//! | 24 | 4 | rerank_factor (u32 LE) |
|
||||
//! | 28 | 4 | n (u32 LE) |
|
||||
//! | 32 | n × (4 + dim·4) | entries: id (u32 LE), then dim f32 LE |
|
||||
//!
|
||||
//! All multi-byte integers and floats are little-endian.
|
||||
//!
|
||||
//! ## Bounds
|
||||
//!
|
||||
//! Load rejects files whose header claims:
|
||||
//! - magic ≠ `b"rbpx0001"`,
|
||||
//! - version > current (1),
|
||||
//! - dim == 0 or dim > [`MAX_DIM`] (8192),
|
||||
//! - n > [`MAX_N`] (100M),
|
||||
//! - rerank_factor > [`MAX_RERANK_FACTOR`] (1024).
|
||||
|
||||
use std::io::{Read, Write};
|
||||
|
||||
use crate::error::{RabitqError, Result};
|
||||
use crate::index::{AnnIndex, RabitqPlusIndex};
|
||||
|
||||
/// 8-byte magic prefix identifying the `rbpx` container, version stripe `0001`.
|
||||
pub const MAGIC: &[u8; 8] = b"rbpx0001";
|
||||
/// Current on-disk format version. Bumped on any layout change.
|
||||
pub const VERSION: u32 = 1;
|
||||
|
||||
/// Upper bound on `dim` accepted by [`load_index`] — 8192 covers every
|
||||
/// production embedding model we target and keeps header validation cheap.
|
||||
pub const MAX_DIM: u32 = 8192;
|
||||
/// Upper bound on `n` accepted by [`load_index`] — 100 M entries.
|
||||
pub const MAX_N: u32 = 100_000_000;
|
||||
/// Upper bound on `rerank_factor` accepted by [`load_index`].
|
||||
pub const MAX_RERANK_FACTOR: u32 = 1024;
|
||||
|
||||
// ── internal helpers ────────────────────────────────────────────────────────
|
||||
|
||||
fn io_err(msg: impl Into<String>) -> RabitqError {
|
||||
RabitqError::InvalidParameter(msg.into())
|
||||
}
|
||||
|
||||
fn write_all<W: Write>(w: &mut W, buf: &[u8]) -> Result<()> {
|
||||
w.write_all(buf).map_err(|e| io_err(format!("write: {e}")))
|
||||
}
|
||||
|
||||
fn read_exact<R: Read>(r: &mut R, buf: &mut [u8]) -> Result<()> {
|
||||
r.read_exact(buf).map_err(|e| io_err(format!("read: {e}")))
|
||||
}
|
||||
|
||||
fn write_u32<W: Write>(w: &mut W, v: u32) -> Result<()> {
|
||||
write_all(w, &v.to_le_bytes())
|
||||
}
|
||||
|
||||
fn write_u64<W: Write>(w: &mut W, v: u64) -> Result<()> {
|
||||
write_all(w, &v.to_le_bytes())
|
||||
}
|
||||
|
||||
fn write_f32<W: Write>(w: &mut W, v: f32) -> Result<()> {
|
||||
write_all(w, &v.to_le_bytes())
|
||||
}
|
||||
|
||||
fn read_u32<R: Read>(r: &mut R) -> Result<u32> {
|
||||
let mut b = [0u8; 4];
|
||||
read_exact(r, &mut b)?;
|
||||
Ok(u32::from_le_bytes(b))
|
||||
}
|
||||
|
||||
fn read_u64<R: Read>(r: &mut R) -> Result<u64> {
|
||||
let mut b = [0u8; 8];
|
||||
read_exact(r, &mut b)?;
|
||||
Ok(u64::from_le_bytes(b))
|
||||
}
|
||||
|
||||
fn read_f32<R: Read>(r: &mut R) -> Result<f32> {
|
||||
let mut b = [0u8; 4];
|
||||
read_exact(r, &mut b)?;
|
||||
Ok(f32::from_le_bytes(b))
|
||||
}
|
||||
|
||||
// ── public API ──────────────────────────────────────────────────────────────
|
||||
|
||||
/// Serialize a [`RabitqPlusIndex`] by persisting the inputs required to
|
||||
/// deterministically rebuild it.
|
||||
///
|
||||
/// Because `RabitqPlusIndex` does not expose its inner rotation matrix or
|
||||
/// `originals_flat` through the public API, the caller must supply the
|
||||
/// `(seed, items)` pair that was used to build `idx` (typically via
|
||||
/// [`RabitqPlusIndex::from_vectors_parallel`] or a sequence of `add` calls).
|
||||
/// The `idx` argument is used to read `len()`, `dim()`, and `rerank_factor()`
|
||||
/// for cross-checking against `items` — this catches drift between the
|
||||
/// in-memory index and the `items` the caller thinks produced it before bad
|
||||
/// bytes hit disk.
|
||||
pub fn save_index<W: Write>(
|
||||
idx: &RabitqPlusIndex,
|
||||
seed: u64,
|
||||
items: &[(usize, Vec<f32>)],
|
||||
w: &mut W,
|
||||
) -> Result<()> {
|
||||
let dim = idx.dim();
|
||||
let n = idx.len();
|
||||
let rerank_factor = idx.rerank_factor();
|
||||
|
||||
// Cross-check the caller's inputs match the index they claim to represent.
|
||||
if items.len() != n {
|
||||
return Err(io_err(format!(
|
||||
"items.len()={} but index.len()={}",
|
||||
items.len(),
|
||||
n
|
||||
)));
|
||||
}
|
||||
for (i, (_, v)) in items.iter().enumerate() {
|
||||
if v.len() != dim {
|
||||
return Err(RabitqError::DimensionMismatch {
|
||||
expected: dim,
|
||||
actual: v.len(),
|
||||
})
|
||||
.map_err(|_| io_err(format!("item {i}: vector dim {} != {}", v.len(), dim)));
|
||||
}
|
||||
}
|
||||
|
||||
// Bounds — keep the disk format inside the same caps load_index enforces.
|
||||
if dim == 0 || dim as u32 > MAX_DIM {
|
||||
return Err(io_err(format!("dim {dim} out of range (1..={MAX_DIM})")));
|
||||
}
|
||||
if n as u64 > MAX_N as u64 {
|
||||
return Err(io_err(format!("n {n} exceeds cap {MAX_N}")));
|
||||
}
|
||||
if rerank_factor as u32 > MAX_RERANK_FACTOR {
|
||||
return Err(io_err(format!(
|
||||
"rerank_factor {rerank_factor} exceeds cap {MAX_RERANK_FACTOR}"
|
||||
)));
|
||||
}
|
||||
|
||||
// Header.
|
||||
write_all(w, MAGIC)?;
|
||||
write_u32(w, VERSION)?;
|
||||
write_u32(w, dim as u32)?;
|
||||
write_u64(w, seed)?;
|
||||
write_u32(w, rerank_factor as u32)?;
|
||||
write_u32(w, n as u32)?;
|
||||
|
||||
// Payload.
|
||||
for (id, v) in items {
|
||||
// u32 id — upstream uses usize but RabitqIndex already stores u32
|
||||
// internally, so we inherit the same narrowing at the API boundary.
|
||||
if *id > u32::MAX as usize {
|
||||
return Err(io_err(format!("id {id} exceeds u32::MAX")));
|
||||
}
|
||||
write_u32(w, *id as u32)?;
|
||||
for &x in v {
|
||||
write_f32(w, x)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Deserialize a [`RabitqPlusIndex`] previously written by [`save_index`].
|
||||
///
|
||||
/// The rotation matrix, binary codes, cos-LUT, and `last_word_mask` are all
|
||||
/// rebuilt deterministically from `(dim, seed)` — no per-field round-trip is
|
||||
/// needed and the reconstructed index is byte-identical to the saved one.
|
||||
pub fn load_index<R: Read>(r: &mut R) -> Result<RabitqPlusIndex> {
|
||||
// Magic.
|
||||
let mut magic = [0u8; 8];
|
||||
read_exact(r, &mut magic)?;
|
||||
if &magic != MAGIC {
|
||||
return Err(io_err(format!(
|
||||
"bad magic: expected {:?}, got {:?}",
|
||||
MAGIC, &magic
|
||||
)));
|
||||
}
|
||||
|
||||
// Version.
|
||||
let version = read_u32(r)?;
|
||||
if version > VERSION {
|
||||
return Err(io_err(format!(
|
||||
"unsupported version {version} (max {VERSION})"
|
||||
)));
|
||||
}
|
||||
|
||||
// Header fields, each bounded.
|
||||
let dim = read_u32(r)?;
|
||||
if dim == 0 || dim > MAX_DIM {
|
||||
return Err(io_err(format!("dim {dim} out of range (1..={MAX_DIM})")));
|
||||
}
|
||||
let seed = read_u64(r)?;
|
||||
let rerank_factor = read_u32(r)?;
|
||||
if rerank_factor > MAX_RERANK_FACTOR {
|
||||
return Err(io_err(format!(
|
||||
"rerank_factor {rerank_factor} exceeds cap {MAX_RERANK_FACTOR}"
|
||||
)));
|
||||
}
|
||||
let n = read_u32(r)?;
|
||||
if n > MAX_N {
|
||||
return Err(io_err(format!("n {n} exceeds cap {MAX_N}")));
|
||||
}
|
||||
|
||||
// Payload.
|
||||
let dim_usize = dim as usize;
|
||||
let mut items: Vec<(usize, Vec<f32>)> = Vec::with_capacity(n as usize);
|
||||
for _ in 0..n {
|
||||
let id = read_u32(r)? as usize;
|
||||
let mut v = Vec::with_capacity(dim_usize);
|
||||
for _ in 0..dim_usize {
|
||||
v.push(read_f32(r)?);
|
||||
}
|
||||
items.push((id, v));
|
||||
}
|
||||
|
||||
// Deterministic rebuild — same (dim, seed, data) → bit-identical index.
|
||||
RabitqPlusIndex::from_vectors_parallel(dim_usize, seed, rerank_factor as usize, items)
|
||||
}
|
||||
|
||||
// ── tests ───────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::index::AnnIndex;
|
||||
use rand::{Rng as _, SeedableRng as _};
|
||||
|
||||
fn make_dataset(n: usize, d: usize, seed: u64) -> Vec<(usize, Vec<f32>)> {
|
||||
let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
|
||||
(0..n)
|
||||
.map(|i| {
|
||||
let v: Vec<f32> = (0..d).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();
|
||||
(i, v)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serialize_roundtrip_preserves_search_results() {
|
||||
let d = 32;
|
||||
let n = 100;
|
||||
let seed = 1337u64;
|
||||
let rerank_factor = 3;
|
||||
|
||||
let data = make_dataset(n, d, seed);
|
||||
|
||||
let mut original = RabitqPlusIndex::new(d, seed, rerank_factor);
|
||||
for (id, v) in &data {
|
||||
original.add(*id, v.clone()).unwrap();
|
||||
}
|
||||
|
||||
// Save.
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
save_index(&original, seed, &data, &mut buf).unwrap();
|
||||
|
||||
// Header size = 8 + 4 + 4 + 8 + 4 + 4 = 32 bytes, payload = n*(4 + d*4).
|
||||
assert_eq!(buf.len(), 32 + n * (4 + d * 4));
|
||||
|
||||
// Load.
|
||||
let mut cursor = std::io::Cursor::new(&buf);
|
||||
let loaded = load_index(&mut cursor).unwrap();
|
||||
|
||||
assert_eq!(loaded.len(), n);
|
||||
assert_eq!(loaded.dim(), d);
|
||||
assert_eq!(loaded.rerank_factor(), rerank_factor);
|
||||
|
||||
// Run 10 queries and assert ids + scores match exactly.
|
||||
let mut rng = rand::rngs::StdRng::seed_from_u64(seed.wrapping_add(7));
|
||||
let k = 5;
|
||||
for q_idx in 0..10 {
|
||||
let q: Vec<f32> = (0..d).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();
|
||||
let a = original.search(&q, k).unwrap();
|
||||
let b = loaded.search(&q, k).unwrap();
|
||||
assert_eq!(a.len(), b.len(), "query {q_idx}: result count");
|
||||
for (ra, rb) in a.iter().zip(b.iter()) {
|
||||
assert_eq!(ra.id, rb.id, "query {q_idx}: id mismatch");
|
||||
// Scores come from exact f32 rerank over the same candidate set —
|
||||
// bit-identical rebuild means they must match exactly.
|
||||
assert_eq!(
|
||||
ra.score.to_bits(),
|
||||
rb.score.to_bits(),
|
||||
"query {q_idx}: score bits differ ({} vs {})",
|
||||
ra.score,
|
||||
rb.score
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// `RabitqPlusIndex` doesn't derive `Debug`, so `Result::unwrap_err()`
|
||||
/// is unavailable. This helper extracts the error without requiring
|
||||
/// `Debug` on `T`.
|
||||
fn expect_err(res: Result<RabitqPlusIndex>) -> RabitqError {
|
||||
match res {
|
||||
Ok(_) => panic!("expected load_index to reject the input"),
|
||||
Err(e) => e,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reject_bad_magic() {
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
buf.extend_from_slice(b"NOPEBAD!");
|
||||
buf.extend_from_slice(&1u32.to_le_bytes()); // version
|
||||
let mut cursor = std::io::Cursor::new(&buf);
|
||||
let err = expect_err(load_index(&mut cursor));
|
||||
let msg = format!("{err}");
|
||||
assert!(msg.contains("bad magic"), "got: {msg}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reject_version_too_new() {
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
buf.extend_from_slice(MAGIC);
|
||||
buf.extend_from_slice(&(VERSION + 1).to_le_bytes());
|
||||
let mut cursor = std::io::Cursor::new(&buf);
|
||||
let err = expect_err(load_index(&mut cursor));
|
||||
let msg = format!("{err}");
|
||||
assert!(msg.contains("unsupported version"), "got: {msg}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reject_oversize_fields() {
|
||||
// dim too large.
|
||||
{
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
buf.extend_from_slice(MAGIC);
|
||||
buf.extend_from_slice(&VERSION.to_le_bytes());
|
||||
buf.extend_from_slice(&(MAX_DIM + 1).to_le_bytes());
|
||||
let mut cursor = std::io::Cursor::new(&buf);
|
||||
let err = expect_err(load_index(&mut cursor));
|
||||
let msg = format!("{err}");
|
||||
assert!(msg.contains("dim"), "got: {msg}");
|
||||
}
|
||||
// dim zero.
|
||||
{
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
buf.extend_from_slice(MAGIC);
|
||||
buf.extend_from_slice(&VERSION.to_le_bytes());
|
||||
buf.extend_from_slice(&0u32.to_le_bytes());
|
||||
let mut cursor = std::io::Cursor::new(&buf);
|
||||
let err = expect_err(load_index(&mut cursor));
|
||||
let msg = format!("{err}");
|
||||
assert!(msg.contains("dim"), "got: {msg}");
|
||||
}
|
||||
// rerank_factor too large.
|
||||
{
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
buf.extend_from_slice(MAGIC);
|
||||
buf.extend_from_slice(&VERSION.to_le_bytes());
|
||||
buf.extend_from_slice(&32u32.to_le_bytes()); // dim
|
||||
buf.extend_from_slice(&0u64.to_le_bytes()); // seed
|
||||
buf.extend_from_slice(&(MAX_RERANK_FACTOR + 1).to_le_bytes());
|
||||
let mut cursor = std::io::Cursor::new(&buf);
|
||||
let err = expect_err(load_index(&mut cursor));
|
||||
let msg = format!("{err}");
|
||||
assert!(msg.contains("rerank_factor"), "got: {msg}");
|
||||
}
|
||||
// n too large.
|
||||
{
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
buf.extend_from_slice(MAGIC);
|
||||
buf.extend_from_slice(&VERSION.to_le_bytes());
|
||||
buf.extend_from_slice(&32u32.to_le_bytes()); // dim
|
||||
buf.extend_from_slice(&0u64.to_le_bytes()); // seed
|
||||
buf.extend_from_slice(&1u32.to_le_bytes()); // rerank_factor
|
||||
buf.extend_from_slice(&(MAX_N + 1).to_le_bytes());
|
||||
let mut cursor = std::io::Cursor::new(&buf);
|
||||
let err = expect_err(load_index(&mut cursor));
|
||||
let msg = format!("{err}");
|
||||
assert!(msg.contains("n "), "got: {msg}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,23 +1,66 @@
|
|||
//! Random orthogonal rotation drawn from the Haar distribution via QR decomposition.
|
||||
//! Random orthogonal rotation.
|
||||
//!
|
||||
//! We use a thin QR via Gram-Schmidt so we stay dependency-free (no nalgebra required
|
||||
//! at runtime). For D ≤ 2048 this is fast enough to build once and cache.
|
||||
//! Two flavours are supported:
|
||||
//!
|
||||
//! * `HaarDense` — Haar-uniform `D×D` matrix built via Gram–Schmidt on an
|
||||
//! i.i.d. Gaussian block. `apply` is `O(D²)`; storage is `4·D²` bytes. This
|
||||
//! is the default and stays bit-identical to previous snapshots.
|
||||
//!
|
||||
//! * `HadamardSigned` — randomised Hadamard rotation `D₁·H·D₂·H·D₃` where
|
||||
//! each `Dᵢ` is a ±1 diagonal and `H` is the Fast Walsh–Hadamard Transform.
|
||||
//! Cost is `O(D log D)` with no matrix stored (just `3·D` signs). TurboQuant
|
||||
//! (arXiv:2504.19874 §3.2) shows this hits the "close to Haar-uniform"
|
||||
//! regime RaBitQ needs for its Johnson–Lindenstrauss-style error bound.
|
||||
//!
|
||||
//! For arbitrary `dim` the Hadamard flavour zero-pads up to the next power of
|
||||
//! two, runs the butterfly there, then truncates back to `dim` — standard
|
||||
//! FWHT-on-non-dyadic trick.
|
||||
|
||||
use rand::SeedableRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use rand_distr::{Distribution, StandardNormal};
|
||||
|
||||
/// A DxD random orthogonal matrix stored in row-major order.
|
||||
/// Which random-rotation construction a `RandomRotation` is backed by.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
|
||||
pub enum RandomRotationKind {
|
||||
/// Dense `D×D` Haar-uniform orthogonal matrix.
|
||||
HaarDense,
|
||||
/// Randomised Hadamard: three random ±1 diagonals interleaved with FWHT.
|
||||
HadamardSigned,
|
||||
}
|
||||
|
||||
/// Internal storage mode. Kept private so we can evolve it without breaking
|
||||
/// callers — users interact via `apply` / `apply_into` / `bytes` / `kind`.
|
||||
#[derive(Clone, serde::Serialize, serde::Deserialize)]
|
||||
enum Mode {
|
||||
/// Flattened row-major `D×D` matrix.
|
||||
HaarDense { matrix: Vec<f32> },
|
||||
/// Three ±1 sign vectors of length `padded_dim`, applied as `D₁·H·D₂·H·D₃`.
|
||||
HadamardSigned {
|
||||
signs: [Vec<f32>; 3],
|
||||
padded_dim: usize,
|
||||
},
|
||||
}
|
||||
|
||||
/// A random (approximately) orthogonal rotation.
|
||||
///
|
||||
/// Applying it to a vector: `apply(&matrix, v)` costs O(D²) — build once, amortise.
|
||||
/// Build once, apply many times. The default constructor `random` yields a
|
||||
/// Haar-uniform `D×D` matrix for backward compatibility; `hadamard` opts in
|
||||
/// to the `O(D log D)` HD-HD-HD variant.
|
||||
#[derive(Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct RandomRotation {
|
||||
/// Flattened row-major D×D matrix.
|
||||
pub matrix: Vec<f32>,
|
||||
mode: Mode,
|
||||
pub dim: usize,
|
||||
/// Kept for backward compatibility with snapshots that accessed the raw
|
||||
/// matrix. Populated only for `HaarDense`; empty for Hadamard.
|
||||
#[serde(default)]
|
||||
pub matrix: Vec<f32>,
|
||||
}
|
||||
|
||||
impl RandomRotation {
|
||||
/// Sample a Haar-uniform orthogonal matrix of size `dim × dim`.
|
||||
///
|
||||
/// Backward-compatible default: existing callers that expect a dense
|
||||
/// matrix under `self.matrix` keep working unchanged.
|
||||
pub fn random(dim: usize, seed: u64) -> Self {
|
||||
let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
|
||||
// Fill a dim×dim matrix with N(0,1) entries.
|
||||
|
|
@ -50,7 +93,50 @@ impl RandomRotation {
|
|||
}
|
||||
|
||||
let matrix: Vec<f32> = m.into_iter().flatten().collect();
|
||||
Self { matrix, dim }
|
||||
Self {
|
||||
mode: Mode::HaarDense {
|
||||
matrix: matrix.clone(),
|
||||
},
|
||||
dim,
|
||||
matrix,
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct a randomised Hadamard rotation `D₁·H·D₂·H·D₃`.
|
||||
///
|
||||
/// Stores only `3 × padded_dim` ±1 entries — no matrix materialised.
|
||||
/// `padded_dim` is the next power of two `≥ dim`; for dyadic `dim` it
|
||||
/// equals `dim`.
|
||||
pub fn hadamard(dim: usize, seed: u64) -> Self {
|
||||
assert!(dim > 0, "RandomRotation::hadamard: dim must be > 0");
|
||||
let padded_dim = dim.next_power_of_two();
|
||||
let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
|
||||
// Three independent ±1 sign vectors.
|
||||
let make_signs = |rng: &mut rand::rngs::StdRng| -> Vec<f32> {
|
||||
(0..padded_dim)
|
||||
.map(|_| if rng.gen::<bool>() { 1.0_f32 } else { -1.0_f32 })
|
||||
.collect()
|
||||
};
|
||||
let signs = [
|
||||
make_signs(&mut rng),
|
||||
make_signs(&mut rng),
|
||||
make_signs(&mut rng),
|
||||
];
|
||||
|
||||
Self {
|
||||
mode: Mode::HadamardSigned { signs, padded_dim },
|
||||
dim,
|
||||
matrix: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Which construction backs this rotation.
|
||||
#[inline]
|
||||
pub fn kind(&self) -> RandomRotationKind {
|
||||
match &self.mode {
|
||||
Mode::HaarDense { .. } => RandomRotationKind::HaarDense,
|
||||
Mode::HadamardSigned { .. } => RandomRotationKind::HadamardSigned,
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply the rotation: out = P · v (length must equal dim).
|
||||
|
|
@ -70,16 +156,49 @@ impl RandomRotation {
|
|||
pub fn apply_into(&self, v: &[f32], out: &mut [f32]) {
|
||||
debug_assert_eq!(v.len(), self.dim);
|
||||
debug_assert_eq!(out.len(), self.dim);
|
||||
let d = self.dim;
|
||||
for (i, out_i) in out.iter_mut().enumerate() {
|
||||
let row = &self.matrix[i * d..(i + 1) * d];
|
||||
*out_i = row.iter().zip(v.iter()).map(|(&r, &x)| r * x).sum();
|
||||
match &self.mode {
|
||||
Mode::HaarDense { matrix } => {
|
||||
let d = self.dim;
|
||||
for (i, out_i) in out.iter_mut().enumerate() {
|
||||
let row = &matrix[i * d..(i + 1) * d];
|
||||
*out_i = row.iter().zip(v.iter()).map(|(&r, &x)| r * x).sum();
|
||||
}
|
||||
}
|
||||
Mode::HadamardSigned { signs, padded_dim } => {
|
||||
// Scratch buffer at padded size — zero-pad the tail.
|
||||
let mut buf = vec![0.0_f32; *padded_dim];
|
||||
buf[..self.dim].copy_from_slice(v);
|
||||
// D₃
|
||||
for (b, s) in buf.iter_mut().zip(signs[2].iter()) {
|
||||
*b *= *s;
|
||||
}
|
||||
fwht_inplace(&mut buf);
|
||||
// D₂
|
||||
for (b, s) in buf.iter_mut().zip(signs[1].iter()) {
|
||||
*b *= *s;
|
||||
}
|
||||
fwht_inplace(&mut buf);
|
||||
// D₁
|
||||
for (b, s) in buf.iter_mut().zip(signs[0].iter()) {
|
||||
*b *= *s;
|
||||
}
|
||||
// Normalise: two FWHT passes multiply the norm by `padded_dim`
|
||||
// (each H is orthogonal only after dividing by √padded_dim),
|
||||
// so the combined scale factor is 1 / padded_dim.
|
||||
let scale = 1.0_f32 / (*padded_dim as f32);
|
||||
for (o, b) in out.iter_mut().zip(buf.iter().take(self.dim)) {
|
||||
*o = b * scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Memory usage in bytes.
|
||||
/// Memory usage in bytes of the rotation's internal storage.
|
||||
pub fn bytes(&self) -> usize {
|
||||
self.matrix.len() * 4
|
||||
match &self.mode {
|
||||
Mode::HaarDense { matrix } => matrix.len() * 4,
|
||||
Mode::HadamardSigned { signs, .. } => signs.iter().map(|s| s.len() * 4).sum::<usize>(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -91,9 +210,37 @@ pub fn normalize_inplace(v: &mut [f32]) {
|
|||
}
|
||||
}
|
||||
|
||||
/// In-place Fast Walsh–Hadamard Transform (unnormalised, natural order).
|
||||
///
|
||||
/// Requires `buf.len()` to be a power of two. Runs the iterative butterfly:
|
||||
/// at stage `h`, pairs of elements `(buf[i+j], buf[i+j+h])` are replaced by
|
||||
/// their sum and difference. After completion, `buf` holds `H · buf_in`
|
||||
/// where `H` is the unnormalised Hadamard matrix with `H Hᵀ = N · I`.
|
||||
#[inline]
|
||||
fn fwht_inplace(buf: &mut [f32]) {
|
||||
let n = buf.len();
|
||||
debug_assert!(n.is_power_of_two(), "FWHT requires power-of-two length");
|
||||
let mut h = 1;
|
||||
while h < n {
|
||||
let mut i = 0;
|
||||
while i < n {
|
||||
for j in i..(i + h) {
|
||||
let x = buf[j];
|
||||
let y = buf[j + h];
|
||||
buf[j] = x + y;
|
||||
buf[j + h] = x - y;
|
||||
}
|
||||
i += h * 2;
|
||||
}
|
||||
h *= 2;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use rand::rngs::StdRng;
|
||||
use rand_distr::StandardNormal;
|
||||
|
||||
/// Full orthogonality check — every pair of rows must be orthonormal.
|
||||
/// Stricter than the shipped version at `f2dbb6efb` which only tested
|
||||
|
|
@ -152,4 +299,88 @@ mod tests {
|
|||
let b = RandomRotation::random(64, 1234);
|
||||
assert_eq!(a.matrix, b.matrix);
|
||||
}
|
||||
|
||||
// ----- Randomised Hadamard (HD-HD-HD) tests --------------------------------
|
||||
|
||||
/// Sample random unit vectors via StdRng + StandardNormal (seeded → reproducible).
|
||||
fn random_unit_vecs(dim: usize, n: usize, seed: u64) -> Vec<Vec<f32>> {
|
||||
let mut rng = StdRng::seed_from_u64(seed);
|
||||
(0..n)
|
||||
.map(|_| {
|
||||
let mut v: Vec<f32> = (0..dim)
|
||||
.map(|_| {
|
||||
<StandardNormal as Distribution<f64>>::sample(&StandardNormal, &mut rng)
|
||||
as f32
|
||||
})
|
||||
.collect();
|
||||
normalize_inplace(&mut v);
|
||||
v
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn hadamard_norm_check(dim: usize, seed: u64) {
|
||||
let rot = RandomRotation::hadamard(dim, seed);
|
||||
assert_eq!(rot.kind(), RandomRotationKind::HadamardSigned);
|
||||
let vecs = random_unit_vecs(dim, 100, seed ^ 0xDEAD_BEEF);
|
||||
for v in &vecs {
|
||||
let rv = rot.apply(v);
|
||||
let n: f32 = rv.iter().map(|&x| x * x).sum::<f32>().sqrt();
|
||||
// Isotropy is approximate (truncation + padding break exact
|
||||
// orthogonality) — loose ±5 % band keeps RaBitQ estimator safe.
|
||||
assert!(
|
||||
(0.95..=1.05).contains(&n),
|
||||
"D={dim}: rotated unit vector has norm {n}",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// D=128 and D=256 are powers of two — no padding path.
|
||||
#[test]
|
||||
fn hadamard_apply_preserves_norm_power_of_two() {
|
||||
hadamard_norm_check(128, 7);
|
||||
hadamard_norm_check(256, 11);
|
||||
}
|
||||
|
||||
/// D=1000 exercises the zero-pad-to-1024 branch plus the truncation
|
||||
/// back to `dim`. Looser isotropy is expected and allowed by the ±5 %
|
||||
/// tolerance.
|
||||
#[test]
|
||||
fn hadamard_apply_preserves_norm_non_power_of_two() {
|
||||
hadamard_norm_check(1000, 3);
|
||||
}
|
||||
|
||||
/// Same seed → bit-identical output for both sign vectors (via apply).
|
||||
#[test]
|
||||
fn hadamard_is_deterministic() {
|
||||
let a = RandomRotation::hadamard(128, 0xC0FFEE);
|
||||
let b = RandomRotation::hadamard(128, 0xC0FFEE);
|
||||
let v: Vec<f32> = (0..128_u32).map(|i| (i as f32).cos()).collect();
|
||||
assert_eq!(a.apply(&v), b.apply(&v));
|
||||
// Different seed must change the output.
|
||||
let c = RandomRotation::hadamard(128, 0xC0FFEE + 1);
|
||||
assert_ne!(a.apply(&v), c.apply(&v));
|
||||
}
|
||||
|
||||
/// Correctness smoke: for a dyadic dim, the all-ones input after the
|
||||
/// first FWHT collapses to `(N, 0, 0, …)` — a cheap way to verify the
|
||||
/// butterfly without timing.
|
||||
#[test]
|
||||
fn hadamard_is_fast() {
|
||||
// FWHT of `[1; 8]` must be `[8, 0, 0, 0, 0, 0, 0, 0]`.
|
||||
let mut buf = vec![1.0_f32; 8];
|
||||
fwht_inplace(&mut buf);
|
||||
assert!((buf[0] - 8.0).abs() < 1e-6);
|
||||
for v in &buf[1..] {
|
||||
assert!(v.abs() < 1e-6);
|
||||
}
|
||||
|
||||
// Storage footprint: Hadamard must be dramatically smaller than Haar
|
||||
// at non-trivial dim (3·D floats vs D² floats).
|
||||
let had = RandomRotation::hadamard(128, 1);
|
||||
let haar = RandomRotation::random(128, 1);
|
||||
assert!(had.bytes() < haar.bytes() / 10);
|
||||
assert_eq!(had.kind(), RandomRotationKind::HadamardSigned);
|
||||
assert_eq!(haar.kind(), RandomRotationKind::HaarDense);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -32,7 +32,14 @@ use std::sync::OnceLock;
|
|||
/// * `q_packed` — query words, length `n_words`.
|
||||
/// * `mask` — last-word mask (`!0u64` when `dim % 64 == 0`).
|
||||
/// * `out_agree` — output agreement counts, length `n`.
|
||||
type ScanFn = fn(packed: &[u64], n_words: usize, n: usize, q_packed: &[u64], mask: u64, out_agree: &mut [u32]);
|
||||
type ScanFn = fn(
|
||||
packed: &[u64],
|
||||
n_words: usize,
|
||||
n: usize,
|
||||
q_packed: &[u64],
|
||||
mask: u64,
|
||||
out_agree: &mut [u32],
|
||||
);
|
||||
|
||||
static SCAN_IMPL: OnceLock<ScanFn> = OnceLock::new();
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ use std::collections::HashMap;
|
|||
use std::sync::Arc;
|
||||
|
||||
use crate::backend::{BackendAdapter, BackendId};
|
||||
use crate::cache::{CacheKey, Consistency, VectorCache};
|
||||
use crate::cache::{intern_key, CacheKey, Consistency, InternedKey, VectorCache};
|
||||
use crate::error::{Result, RuLakeError};
|
||||
|
||||
/// Result from a search — the external id and its estimated L2² score.
|
||||
|
|
@ -226,9 +226,14 @@ impl RuLake {
|
|||
query: &[f32],
|
||||
k: usize,
|
||||
) -> Result<Vec<SearchResult>> {
|
||||
let key: CacheKey = (backend.to_string(), collection.to_string());
|
||||
// Intern once per query — the memory-audit hot-path fix. Every
|
||||
// downstream cache op takes `Arc<str>` refcount bumps instead
|
||||
// of cloning `String`s (memory-audit finding #1).
|
||||
let key = intern_key(backend, collection);
|
||||
self.ensure_fresh(&key)?;
|
||||
let hits = self.cache.search_cached(&key, query, k)?;
|
||||
let hits = self
|
||||
.cache
|
||||
.search_cached_with_rerank_interned(&key, query, k, None)?;
|
||||
Ok(hits
|
||||
.into_iter()
|
||||
.map(|(id, score)| SearchResult {
|
||||
|
|
@ -342,11 +347,11 @@ impl RuLake {
|
|||
k: usize,
|
||||
rerank_override: Option<usize>,
|
||||
) -> Result<Vec<SearchResult>> {
|
||||
let key: CacheKey = (backend.to_string(), collection.to_string());
|
||||
let key = intern_key(backend, collection);
|
||||
self.ensure_fresh(&key)?;
|
||||
let hits = self
|
||||
.cache
|
||||
.search_cached_with_rerank(&key, query, k, rerank_override)?;
|
||||
let hits =
|
||||
self.cache
|
||||
.search_cached_with_rerank_interned(&key, query, k, rerank_override)?;
|
||||
Ok(hits
|
||||
.into_iter()
|
||||
.map(|(id, score)| SearchResult {
|
||||
|
|
@ -376,9 +381,11 @@ impl RuLake {
|
|||
queries: &[Vec<f32>],
|
||||
k: usize,
|
||||
) -> Result<Vec<Vec<SearchResult>>> {
|
||||
let key: CacheKey = (backend.to_string(), collection.to_string());
|
||||
let key = intern_key(backend, collection);
|
||||
self.ensure_fresh(&key)?;
|
||||
let raw = self.cache.search_cached_batch(&key, queries, k, None)?;
|
||||
let raw = self
|
||||
.cache
|
||||
.search_cached_batch_interned(&key, queries, k, None)?;
|
||||
Ok(raw
|
||||
.into_iter()
|
||||
.map(|v| {
|
||||
|
|
@ -405,14 +412,12 @@ impl RuLake {
|
|||
/// entry pool (cached under another pointer) → just move the
|
||||
/// pointer, zero prime work. This is the cross-backend share.
|
||||
/// 4. Witness not in the pool → pull + prime.
|
||||
fn ensure_fresh(&self, key: &CacheKey) -> Result<()> {
|
||||
// Intern once at the entry of the hot path — every downstream
|
||||
// mark_hit / mark_miss / per_backend_mut call then takes
|
||||
// refcount-cheap Arc<str> clones instead of cloning the owned
|
||||
// String tuple. Memory-audit finding #1.
|
||||
let interned = crate::cache::intern_key(&key.0, &key.1);
|
||||
if self.cache.can_skip_check(key, self.consistency) {
|
||||
self.cache.mark_hit(&interned);
|
||||
fn ensure_fresh(&self, key: &InternedKey) -> Result<()> {
|
||||
// The hot path already owns Arc<str> clones of (backend,
|
||||
// collection) — every downstream cache op is a refcount bump,
|
||||
// never a String::clone. Memory-audit finding #1.
|
||||
if self.cache.can_skip_check_interned(key, self.consistency) {
|
||||
self.cache.mark_hit(key);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
|
|
@ -424,18 +429,23 @@ impl RuLake {
|
|||
)?;
|
||||
let target_witness = bundle.rvf_witness.clone();
|
||||
|
||||
if self.cache.witness_of(key).as_deref() == Some(target_witness.as_str()) {
|
||||
if self.cache.witness_of_interned(key).as_deref() == Some(target_witness.as_str()) {
|
||||
// Case 2: pointer up-to-date.
|
||||
self.cache.mark_hit(&interned);
|
||||
self.cache.touch(key);
|
||||
self.cache.mark_hit(key);
|
||||
self.cache.touch_interned(key);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Cases 3 + 4 are handled in `prime`: it reuses an existing
|
||||
// entry for the target witness if present, or builds a new one.
|
||||
self.cache.mark_miss(&interned);
|
||||
// Cases 3 + 4 are handled in `prime_interned`: it reuses an
|
||||
// existing entry for the target witness if present, or builds
|
||||
// a new one.
|
||||
self.cache.mark_miss(key);
|
||||
let batch = backend.pull_vectors(&key.1)?;
|
||||
self.cache.prime(key.clone(), target_witness, batch)?;
|
||||
// Clone the Arcs (refcount bumps) to hand the cache an owned
|
||||
// InternedKey — no String allocation.
|
||||
let owned_key: InternedKey = (Arc::clone(&key.0), Arc::clone(&key.1));
|
||||
self.cache
|
||||
.prime_interned(owned_key, target_witness, batch)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue