diff --git a/crates/ruvector-rulake/src/cache.rs b/crates/ruvector-rulake/src/cache.rs index a46adaa8..900767ce 100644 --- a/crates/ruvector-rulake/src/cache.rs +++ b/crates/ruvector-rulake/src/cache.rs @@ -420,6 +420,63 @@ impl VectorCache { .collect()) } + /// Batch form of [`search_cached_with_rerank`]: all `queries` hit + /// the same cached entry, so we take the mutex once and look up + /// the witness + pos_to_id mapping once. This is the surface a + /// GPU / SIMD kernel needs to amortize per-query setup; today's + /// CPU impl already saves the repeated mutex acquisition and lets + /// the caller run `ensure_fresh` once per batch instead of once + /// per query. + /// + /// `rerank_factor_override` selects the index call per query: + /// `None` uses `index.search(q, k)`, `Some(rf)` uses + /// `index.search_with_rerank(q, k, rf)`. + /// + /// Preserves the query order: `result[i]` is the top-k for + /// `queries[i]`. Each hit is returned as `(pos_to_id[r.id], r.score)`, + /// i.e. the index's `id` is translated through the entry's + /// `pos_to_id` mapping. + /// + /// # Errors + /// + /// - `RuLakeError::UnknownCollection` if `key` has no pointer, or the + /// witness it points at has no cache entry. + /// - `RuLakeError::DimensionMismatch` if any query's length differs + /// from the entry's `dim`. All queries are checked before the first + /// search runs, so a bad query rejects the whole batch. + /// - Any error propagated from the index search calls. + /// + /// # Panics + /// + /// Panics if `self.inner.lock()` returns an error (a poisoned lock, + /// assuming a std-style mutex), or if `pos_to_id[r.id]` has no element + /// for an `id` the index returned. + pub fn search_cached_batch( + &self, + key: &CacheKey, + queries: &[Vec], + k: usize, + rerank_factor_override: Option, + ) -> crate::Result>> { + let mut inner = self.inner.lock().unwrap(); + // The witness is cloned, presumably so the borrow of `inner.pointers` + // ends before `inner.entries` is borrowed mutably below. + let witness = inner + .pointers + .get(key) + .ok_or_else(|| crate::RuLakeError::UnknownCollection { + backend: key.0.clone(), + collection: key.1.clone(), + })? 
+ .clone(); + let entry = inner.entries.get_mut(&witness).ok_or_else(|| { + crate::RuLakeError::UnknownCollection { + backend: key.0.clone(), + collection: key.1.clone(), + } + })?; + // Validate every query before any search runs. + let dim = entry.dim; + for q in queries { + if q.len() != dim { + return Err(crate::RuLakeError::DimensionMismatch { + expected: dim, + actual: q.len(), + }); + } + } + // Only reached once the whole batch passed validation, so a rejected + // batch does not refresh `last_used`. + entry.last_used = Instant::now(); + let mut raw: Vec> = Vec::with_capacity(queries.len()); + for q in queries { + let r = match rerank_factor_override { + None => entry.index.search(q, k)?, + Some(rf) => entry.index.search_with_rerank(q, k, rf)?, + }; + raw.push(r); + } + // Copy the mapping out so the lock can be dropped before the ids are + // translated. NOTE(review): this clones `pos_to_id` on every call; + // confirm that is cheap enough for large entries. + let pos_to_id = entry.pos_to_id.clone(); + drop(inner); + Ok(raw + .into_iter() + .map(|v| v.into_iter().map(|r| (pos_to_id[r.id], r.score)).collect()) + .collect()) + } + pub fn touch(&self, key: &CacheKey) { let mut inner = self.inner.lock().unwrap(); inner.last_checked.insert(key.clone(), Instant::now()); diff --git a/crates/ruvector-rulake/src/lake.rs b/crates/ruvector-rulake/src/lake.rs index 49b293f1..5289a7f8 100644 --- a/crates/ruvector-rulake/src/lake.rs +++ b/crates/ruvector-rulake/src/lake.rs @@ -298,6 +298,42 @@ impl RuLake { .collect()) } + /// Batched single-collection search. All `queries` run against the + /// same `(backend, collection)` key, so coherence is checked once + /// and the cache mutex is held for a single round of N scans + /// instead of N separate acquires. Preserves query order: + /// `result[i]` is the top-k for `queries[i]`. + /// + /// No rerank override is passed to the cache (`None`), and every + /// returned `SearchResult` carries the `backend` and `collection` + /// arguments alongside the hit's `id` and `score`. + /// + /// This is also the plug-point for the future `VectorKernel` trait + /// (ADR-157): GPU / SIMD kernels overtake the CPU path only at batch + /// sizes above their `min_batch`, so a per-query API would never + /// let dispatch pick them. The batch API makes the dispatch + /// decision tractable. + /// + /// # Errors + /// + /// Propagates any error from `ensure_fresh` and from + /// `self.cache.search_cached_batch`. 
+ pub fn search_batch( + &self, + backend: &str, + collection: &str, + queries: &[Vec], + k: usize, + ) -> Result>> { + let key: CacheKey = (backend.to_string(), collection.to_string()); + // One freshness check covers the whole batch. + self.ensure_fresh(&key)?; + let raw = self.cache.search_cached_batch(&key, queries, k, None)?; + // Wrap each `(id, score)` pair with the backend/collection names. + Ok(raw + .into_iter() + .map(|v| { + v.into_iter() + .map(|(id, score)| SearchResult { + backend: backend.to_string(), + collection: collection.to_string(), + id, + score, + }) + .collect() + }) + .collect()) + } + /// Coherence check: ask the backend for its current bundle and /// compare its witness with whatever the cache currently points at. /// diff --git a/crates/ruvector-rulake/tests/federation_smoke.rs b/crates/ruvector-rulake/tests/federation_smoke.rs index 6746673b..9aa623f5 100644 --- a/crates/ruvector-rulake/tests/federation_smoke.rs +++ b/crates/ruvector-rulake/tests/federation_smoke.rs @@ -762,6 +762,75 @@ fn refresh_from_bundle_dir_reports_all_three_states() { let _ = std::fs::remove_dir_all(&tmp); } +#[test] +fn search_batch_matches_per_query_results() { + // search_batch must return the same top-k as calling search_one + // for each query individually. Byte-exact required: same seed, + // same rerank factor, same cache entry — the only difference is + // the API. + let d = 32; + let n = 500; + let seed = 61; + let data = clustered(n, d, 8, seed); + let backend = Arc::new(LocalBackend::new("batch")); + backend + .put_collection("c", d, (0..n as u64).collect(), data) + .unwrap(); + let lake = RuLake::new(20, seed).with_consistency(Consistency::Eventual { ttl_ms: 60_000 }); + lake.register_backend(backend).unwrap(); + + let queries = clustered(16, d, 8, seed ^ 0xa5a5); + + // Prime: one single-query search before the comparison. 
+ lake.search_one("batch", "c", &queries[0], 5).unwrap(); + + // Same 16 queries through both APIs, k = 5 each. + let single: Vec> = queries + .iter() + .map(|q| lake.search_one("batch", "c", q, 5).unwrap()) + .collect(); + let batch = lake.search_batch("batch", "c", &queries, 5).unwrap(); + assert_eq!(single.len(), batch.len()); + for (i, (a, b)) in single.iter().zip(batch.iter()).enumerate() { + assert_eq!( + a, b, + "batch and single diverged at query {i}: single={:?} batch={:?}", + a, b + ); + } +} + +#[test] +fn search_batch_acquires_cache_lock_once() { + // A single search_batch(N=32) should register as one coherence-skip + // hit (Eventual) or one miss+prime on first call, NOT N individual + // hits. This is how operators can see the lock-amortization working. + // NOTE(review): the assertions below read the `hits` and `primes` + // counters from `cache_stats()`; they observe coherence checks, not + // mutex acquisitions, so the test name is broader than what is checked. + let d = 16; + let backend = Arc::new(LocalBackend::new("amort")); + backend + .put_collection("c", d, (0..50).collect(), vec![vec![0.0; d]; 50]) + .unwrap(); + let lake = RuLake::new(20, 42).with_consistency(Consistency::Eventual { ttl_ms: 60_000 }); + lake.register_backend(backend).unwrap(); + + // Prime, then snapshot the counters so only the batch call is measured. + lake.search_one("amort", "c", &vec![0.0f32; d], 1).unwrap(); + let before = lake.cache_stats(); + + let queries = vec![vec![0.0f32; d]; 32]; + let _ = lake.search_batch("amort", "c", &queries, 1).unwrap(); + let after = lake.cache_stats(); + + // A per-query loop runs ensure_fresh per query and would bump hits by + // N=32 (presumably — search_one is not shown here). + // The batch path bumps it by exactly 1. + let delta_hits = after.hits - before.hits; + assert_eq!( + delta_hits, 1, + "batch of 32 should register as 1 coherence check, got {delta_hits}" + ); + // No new primes. + assert_eq!(after.primes, before.primes); +} + #[test] fn frozen_consistency_never_rechecks_after_prime() { // Frozen asserts immutability. After the first miss-prime, backend