mirror of
https://github.com/ruvnet/RuVector.git
synced 2026-05-29 19:33:34 +00:00
Revert "merge: feat/analysis-diskann-motif — Vamana motif index for AC-2"
This reverts commitfe059b8591, reversing changes made tofd39d10eb9.
This commit is contained in:
parent
994d61fd5f
commit
a7a5ee5e27
7 changed files with 16 additions and 1011 deletions
|
|
@ -1,496 +0,0 @@
|
|||
//! DiskANN/Vamana-style motif index.
|
||||
//!
|
||||
//! Self-contained in-memory Vamana graph ANN index for motif
|
||||
//! embeddings. Follows Subramanya et al., "DiskANN: Fast Accurate
|
||||
//! Billion-point Nearest Neighbor Search on a Single Node"
|
||||
//! (NeurIPS 2019) — greedy beam search + α-robust pruning, two build
|
||||
//! passes (α = 1.0, then α = params.alpha).
|
||||
//!
|
||||
//! The workspace already ships `crates/ruvector-diskann`, but that
|
||||
//! crate is tuned for SSD-resident billion-scale indexes: mmap,
|
||||
//! rayon, bincode, and nondeterministic `thread_rng()` graph init.
|
||||
//! This in-example module trades that scale for zero new heavy deps,
|
||||
//! bit-deterministic graph construction (seeded xoroshiro PRNG), and
|
||||
//! ≤ 500 LOC with no unsafe.
|
||||
|
||||
use std::cmp::Ordering;
|
||||
use std::collections::BinaryHeap;
|
||||
|
||||
/// Type alias for motif embedding vectors.
|
||||
pub type EmbeddingF32 = Vec<f32>;
|
||||
|
||||
/// Vamana construction / query parameters.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct VamanaParams {
|
||||
/// Max out-degree R of the Vamana graph.
|
||||
pub max_degree: usize,
|
||||
/// Build-time beam width L_build (>= `max_degree`).
|
||||
pub build_beam: usize,
|
||||
/// Query-time beam width L_search (>= k at call sites).
|
||||
pub search_beam: usize,
|
||||
/// α pruning slack (>= 1.0). Larger keeps more long-range edges.
|
||||
pub alpha: f32,
|
||||
/// PRNG seed for graph init and build-order shuffle.
|
||||
pub seed: u64,
|
||||
}
|
||||
|
||||
impl Default for VamanaParams {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_degree: 32,
|
||||
build_beam: 64,
|
||||
search_beam: 64,
|
||||
alpha: 1.2,
|
||||
seed: 0xD15C_A44_5EE_DDEEF_u64,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// In-memory DiskANN/Vamana motif index with deterministic
|
||||
/// construction. Owns its corpus by value.
|
||||
pub struct DiskAnnMotifIndex {
|
||||
corpus: Vec<EmbeddingF32>,
|
||||
dim: usize,
|
||||
neighbors: Vec<Vec<u32>>,
|
||||
entry: u32,
|
||||
search_beam: usize,
|
||||
}
|
||||
|
||||
impl DiskAnnMotifIndex {
|
||||
/// Build a Vamana graph over `corpus`.
|
||||
///
|
||||
/// Panics if `corpus` is empty or contains a vector with a
|
||||
/// dimension different from the first entry.
|
||||
pub fn new(corpus: Vec<EmbeddingF32>, params: VamanaParams) -> Self {
|
||||
assert!(!corpus.is_empty(), "DiskAnnMotifIndex: empty corpus");
|
||||
let dim = corpus[0].len();
|
||||
for v in &corpus {
|
||||
assert_eq!(v.len(), dim, "DiskAnnMotifIndex: mixed vector dims");
|
||||
}
|
||||
let n = corpus.len();
|
||||
let max_degree = params.max_degree.min(n.saturating_sub(1)).max(1);
|
||||
let build_beam = params.build_beam.max(max_degree);
|
||||
let alpha = params.alpha.max(1.0);
|
||||
|
||||
let entry = medoid(&corpus);
|
||||
let mut neighbors = init_random_graph(n, max_degree, params.seed);
|
||||
|
||||
// Pass 1 — α = 1.0 (shorter edges).
|
||||
let order1 = det_permutation(n, params.seed ^ 0xA110C_u64);
|
||||
build_pass(
|
||||
&corpus,
|
||||
&mut neighbors,
|
||||
entry,
|
||||
build_beam,
|
||||
max_degree,
|
||||
1.0,
|
||||
&order1,
|
||||
);
|
||||
// Pass 2 — α = params.alpha (longer-range diversification).
|
||||
if (alpha - 1.0).abs() > f32::EPSILON {
|
||||
let order2 = det_permutation(n, params.seed ^ 0xA110D_u64);
|
||||
build_pass(
|
||||
&corpus,
|
||||
&mut neighbors,
|
||||
entry,
|
||||
build_beam,
|
||||
max_degree,
|
||||
alpha,
|
||||
&order2,
|
||||
);
|
||||
}
|
||||
|
||||
Self {
|
||||
corpus,
|
||||
dim,
|
||||
neighbors,
|
||||
entry,
|
||||
search_beam: params.search_beam.max(max_degree),
|
||||
}
|
||||
}
|
||||
|
||||
/// Number of vectors indexed.
|
||||
pub fn len(&self) -> usize {
|
||||
self.corpus.len()
|
||||
}
|
||||
|
||||
/// Whether the index is empty.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.corpus.is_empty()
|
||||
}
|
||||
|
||||
/// Embedding dimension.
|
||||
pub fn dim(&self) -> usize {
|
||||
self.dim
|
||||
}
|
||||
|
||||
/// Top-k nearest neighbours of `q` by Euclidean distance.
|
||||
///
|
||||
/// Results are sorted by distance ascending; ties break on
|
||||
/// ascending id for determinism.
|
||||
pub fn query(&self, q: &[f32], k: usize) -> Vec<(usize, f32)> {
|
||||
assert_eq!(q.len(), self.dim, "DiskAnnMotifIndex::query: dim mismatch");
|
||||
if k == 0 || self.corpus.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
let beam = self.search_beam.max(k);
|
||||
let beamset = greedy_search(&self.corpus, &self.neighbors, self.entry, q, beam);
|
||||
let mut out: Vec<(usize, f32)> = beamset
|
||||
.into_iter()
|
||||
.map(|(id, d2)| (id as usize, d2.sqrt()))
|
||||
.collect();
|
||||
out.sort_by(cmp_result);
|
||||
out.truncate(k);
|
||||
out
|
||||
}
|
||||
|
||||
/// Class-label precision@k over `queries`.
|
||||
///
|
||||
/// Each query is `(embedding, class_label)`. For every query we
|
||||
/// retrieve the k nearest neighbours **excluding the query itself
|
||||
/// if it is in the corpus** (matched by bit-identical vector), and
|
||||
/// count how many share the query's class label. Returns the
|
||||
/// fraction of all retrieved slots whose label matches.
|
||||
pub fn precision_at_k(&self, queries: &[(EmbeddingF32, usize)], k: usize) -> f32 {
|
||||
if queries.is_empty() || k == 0 {
|
||||
return 0.0;
|
||||
}
|
||||
// id -> class map, populated from queries. Corpus vectors not
|
||||
// present in `queries` get tag usize::MAX which matches nothing.
|
||||
let mut id_class: Vec<usize> = vec![usize::MAX; self.corpus.len()];
|
||||
for (qv, cls) in queries {
|
||||
if let Some(id) = find_vec(&self.corpus, qv) {
|
||||
id_class[id] = *cls;
|
||||
}
|
||||
}
|
||||
let mut matched = 0_usize;
|
||||
let mut total = 0_usize;
|
||||
let retrieve = k + 1; // pull one extra to drop self-hit.
|
||||
for (qv, cls) in queries {
|
||||
let self_id = find_vec(&self.corpus, qv);
|
||||
let nn = self.query(qv, retrieve);
|
||||
let mut taken = 0_usize;
|
||||
for (id, _) in nn {
|
||||
if Some(id) == self_id {
|
||||
continue;
|
||||
}
|
||||
if taken == k {
|
||||
break;
|
||||
}
|
||||
if id_class[id] == *cls {
|
||||
matched += 1;
|
||||
}
|
||||
total += 1;
|
||||
taken += 1;
|
||||
}
|
||||
}
|
||||
if total == 0 {
|
||||
0.0
|
||||
} else {
|
||||
matched as f32 / total as f32
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------
|
||||
// Vamana build
|
||||
// -----------------------------------------------------------------
|
||||
|
||||
fn build_pass(
|
||||
corpus: &[EmbeddingF32],
|
||||
neighbors: &mut [Vec<u32>],
|
||||
entry: u32,
|
||||
build_beam: usize,
|
||||
max_degree: usize,
|
||||
alpha: f32,
|
||||
order: &[u32],
|
||||
) {
|
||||
for &node in order {
|
||||
let q = &corpus[node as usize];
|
||||
let visited = greedy_search(corpus, neighbors, entry, q, build_beam);
|
||||
let mut cand: Vec<u32> = visited
|
||||
.iter()
|
||||
.filter_map(|(id, _)| if *id == node { None } else { Some(*id) })
|
||||
.collect();
|
||||
// Also include current neighbours so pass 2 refines pass 1.
|
||||
for &n in &neighbors[node as usize] {
|
||||
if n != node && !cand.contains(&n) {
|
||||
cand.push(n);
|
||||
}
|
||||
}
|
||||
let pruned = robust_prune(corpus, node, &cand, alpha, max_degree);
|
||||
neighbors[node as usize] = pruned.clone();
|
||||
// Bidirectional insertion with re-prune on overflow.
|
||||
for &nid in &pruned {
|
||||
let slot = &mut neighbors[nid as usize];
|
||||
if !slot.contains(&node) {
|
||||
if slot.len() < max_degree {
|
||||
slot.push(node);
|
||||
} else {
|
||||
let mut combined = slot.clone();
|
||||
combined.push(node);
|
||||
neighbors[nid as usize] =
|
||||
robust_prune(corpus, nid, &combined, alpha, max_degree);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Greedy beam search from `entry` toward `q`. Returns the final
|
||||
/// closed beam as (id, L2²-distance) pairs in no particular order.
|
||||
fn greedy_search(
|
||||
corpus: &[EmbeddingF32],
|
||||
neighbors: &[Vec<u32>],
|
||||
entry: u32,
|
||||
q: &[f32],
|
||||
beam: usize,
|
||||
) -> Vec<(u32, f32)> {
|
||||
let n = corpus.len();
|
||||
let mut visited = vec![false; n];
|
||||
// frontier: open, min-heap on distance ascending.
|
||||
// best: closed beam, max-heap on distance so we cheaply evict.
|
||||
let mut frontier = BinaryHeap::<Min>::new();
|
||||
let mut best = BinaryHeap::<Max>::new();
|
||||
let entry_u = entry as usize;
|
||||
visited[entry_u] = true;
|
||||
let d0 = l2_sq(q, &corpus[entry_u]);
|
||||
frontier.push(Min { id: entry, d: d0 });
|
||||
best.push(Max { id: entry, d: d0 });
|
||||
while let Some(cur) = frontier.pop() {
|
||||
if best.len() >= beam {
|
||||
if let Some(worst) = best.peek() {
|
||||
if cur.d > worst.d {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
for &nb in &neighbors[cur.id as usize] {
|
||||
let nu = nb as usize;
|
||||
if nu >= n || visited[nu] {
|
||||
continue;
|
||||
}
|
||||
visited[nu] = true;
|
||||
let nd = l2_sq(q, &corpus[nu]);
|
||||
let dominated = best.len() >= beam
|
||||
&& best.peek().map(|w| nd >= w.d).unwrap_or(false);
|
||||
if !dominated {
|
||||
frontier.push(Min { id: nb, d: nd });
|
||||
best.push(Max { id: nb, d: nd });
|
||||
if best.len() > beam {
|
||||
best.pop();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
best.into_iter().map(|c| (c.id, c.d)).collect()
|
||||
}
|
||||
|
||||
/// Robust α-prune. Keeps at most `max_degree` neighbours from the
|
||||
/// distance-sorted candidate list, skipping any candidate dominated
|
||||
/// by an already-selected neighbour under the α test.
|
||||
fn robust_prune(
|
||||
corpus: &[EmbeddingF32],
|
||||
node: u32,
|
||||
candidates: &[u32],
|
||||
alpha: f32,
|
||||
max_degree: usize,
|
||||
) -> Vec<u32> {
|
||||
if candidates.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
let node_v = &corpus[node as usize];
|
||||
let mut sorted: Vec<(u32, f32)> = candidates
|
||||
.iter()
|
||||
.filter(|&&c| c != node)
|
||||
.map(|&c| (c, l2_sq(node_v, &corpus[c as usize])))
|
||||
.collect();
|
||||
sorted.sort_by(cmp_id_asc);
|
||||
let mut out: Vec<u32> = Vec::with_capacity(max_degree);
|
||||
for (cand_id, cand_d) in &sorted {
|
||||
if out.len() >= max_degree {
|
||||
break;
|
||||
}
|
||||
let dominated = out.iter().any(|&sel| {
|
||||
let inter = l2_sq(&corpus[sel as usize], &corpus[*cand_id as usize]);
|
||||
alpha * inter <= *cand_d
|
||||
});
|
||||
if !dominated {
|
||||
out.push(*cand_id);
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Deterministic medoid: corpus point with smallest summed L2² to
|
||||
/// every other corpus point. O(n²). Ties break on smaller id.
|
||||
fn medoid(corpus: &[EmbeddingF32]) -> u32 {
|
||||
let n = corpus.len();
|
||||
let mut best: u32 = 0;
|
||||
let mut best_cost = f32::INFINITY;
|
||||
for i in 0..n {
|
||||
let mut s = 0.0_f32;
|
||||
for j in 0..n {
|
||||
if i == j {
|
||||
continue;
|
||||
}
|
||||
s += l2_sq(&corpus[i], &corpus[j]);
|
||||
}
|
||||
if s < best_cost {
|
||||
best_cost = s;
|
||||
best = i as u32;
|
||||
}
|
||||
}
|
||||
best
|
||||
}
|
||||
|
||||
fn init_random_graph(n: usize, max_degree: usize, seed: u64) -> Vec<Vec<u32>> {
|
||||
let mut rng = Rng::new(seed);
|
||||
let mut neighbors = vec![Vec::with_capacity(max_degree); n];
|
||||
if n <= 1 {
|
||||
return neighbors;
|
||||
}
|
||||
let degree = max_degree.min(n - 1);
|
||||
for i in 0..n {
|
||||
let slot = &mut neighbors[i];
|
||||
let mut attempts = 0_usize;
|
||||
while slot.len() < degree && attempts < degree * 6 {
|
||||
let j = (rng.next_u64() % n as u64) as u32;
|
||||
if j != i as u32 && !slot.contains(&j) {
|
||||
slot.push(j);
|
||||
}
|
||||
attempts += 1;
|
||||
}
|
||||
}
|
||||
neighbors
|
||||
}
|
||||
|
||||
fn det_permutation(n: usize, seed: u64) -> Vec<u32> {
|
||||
let mut v: Vec<u32> = (0..n as u32).collect();
|
||||
let mut rng = Rng::new(seed);
|
||||
for i in (1..n).rev() {
|
||||
let j = (rng.next_u64() % (i as u64 + 1)) as usize;
|
||||
v.swap(i, j);
|
||||
}
|
||||
v
|
||||
}
|
||||
|
||||
fn find_vec(corpus: &[EmbeddingF32], needle: &[f32]) -> Option<usize> {
|
||||
for (i, v) in corpus.iter().enumerate() {
|
||||
if v.len() == needle.len() && v.iter().zip(needle).all(|(a, b)| a.to_bits() == b.to_bits())
|
||||
{
|
||||
return Some(i);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------
|
||||
// Math / PRNG / heap helpers
|
||||
// -----------------------------------------------------------------
|
||||
|
||||
#[inline]
|
||||
fn l2_sq(a: &[f32], b: &[f32]) -> f32 {
|
||||
let mut s = 0.0_f32;
|
||||
let n = a.len().min(b.len());
|
||||
for i in 0..n {
|
||||
let d = a[i] - b[i];
|
||||
s += d * d;
|
||||
}
|
||||
s
|
||||
}
|
||||
|
||||
/// Xoroshiro128++ — tiny, deterministic, no deps.
|
||||
struct Rng {
|
||||
s0: u64,
|
||||
s1: u64,
|
||||
}
|
||||
|
||||
impl Rng {
|
||||
fn new(seed: u64) -> Self {
|
||||
let mut z = seed.wrapping_add(0x9E37_79B9_7F4A_7C15);
|
||||
let s0 = splitmix(&mut z);
|
||||
let s1 = splitmix(&mut z);
|
||||
let s0 = if s0 == 0 { 0xD1B5_4A32_D192_ED03 } else { s0 };
|
||||
let s1 = if s1 == 0 { 0x6A09_E667_BB67_AE85 } else { s1 };
|
||||
Self { s0, s1 }
|
||||
}
|
||||
|
||||
fn next_u64(&mut self) -> u64 {
|
||||
let r = self.s0.wrapping_add(self.s1).rotate_left(17).wrapping_add(self.s0);
|
||||
let s1 = self.s1 ^ self.s0;
|
||||
self.s0 = self.s0.rotate_left(49) ^ s1 ^ (s1 << 21);
|
||||
self.s1 = s1.rotate_left(28);
|
||||
r
|
||||
}
|
||||
}
|
||||
|
||||
fn splitmix(z: &mut u64) -> u64 {
|
||||
*z = z.wrapping_add(0x9E37_79B9_7F4A_7C15);
|
||||
let mut x = *z;
|
||||
x = (x ^ (x >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
|
||||
x = (x ^ (x >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
|
||||
x ^ (x >> 31)
|
||||
}
|
||||
|
||||
/// Min-heap (smaller d pops first) when BinaryHeap normally pops max.
|
||||
struct Min {
|
||||
id: u32,
|
||||
d: f32,
|
||||
}
|
||||
impl PartialEq for Min {
|
||||
fn eq(&self, o: &Self) -> bool {
|
||||
self.d == o.d && self.id == o.id
|
||||
}
|
||||
}
|
||||
impl Eq for Min {}
|
||||
impl PartialOrd for Min {
|
||||
fn partial_cmp(&self, o: &Self) -> Option<Ordering> {
|
||||
Some(self.cmp(o))
|
||||
}
|
||||
}
|
||||
impl Ord for Min {
|
||||
fn cmp(&self, o: &Self) -> Ordering {
|
||||
o.d.partial_cmp(&self.d)
|
||||
.unwrap_or(Ordering::Equal)
|
||||
.then_with(|| o.id.cmp(&self.id))
|
||||
}
|
||||
}
|
||||
|
||||
/// Max-heap (larger d pops first) — used to evict the worst beam.
|
||||
struct Max {
|
||||
id: u32,
|
||||
d: f32,
|
||||
}
|
||||
impl PartialEq for Max {
|
||||
fn eq(&self, o: &Self) -> bool {
|
||||
self.d == o.d && self.id == o.id
|
||||
}
|
||||
}
|
||||
impl Eq for Max {}
|
||||
impl PartialOrd for Max {
|
||||
fn partial_cmp(&self, o: &Self) -> Option<Ordering> {
|
||||
Some(self.cmp(o))
|
||||
}
|
||||
}
|
||||
impl Ord for Max {
|
||||
fn cmp(&self, o: &Self) -> Ordering {
|
||||
self.d
|
||||
.partial_cmp(&o.d)
|
||||
.unwrap_or(Ordering::Equal)
|
||||
.then_with(|| self.id.cmp(&o.id))
|
||||
}
|
||||
}
|
||||
|
||||
fn cmp_result(a: &(usize, f32), b: &(usize, f32)) -> Ordering {
|
||||
a.1.partial_cmp(&b.1)
|
||||
.unwrap_or(Ordering::Equal)
|
||||
.then_with(|| a.0.cmp(&b.0))
|
||||
}
|
||||
|
||||
fn cmp_id_asc(a: &(u32, f32), b: &(u32, f32)) -> Ordering {
|
||||
a.1.partial_cmp(&b.1)
|
||||
.unwrap_or(Ordering::Equal)
|
||||
.then_with(|| a.0.cmp(&b.0))
|
||||
}
|
||||
|
||||
|
|
@ -14,7 +14,6 @@
|
|||
//! The public surface is the `Analysis` struct re-exported from
|
||||
//! here.
|
||||
|
||||
pub mod diskann_motif;
|
||||
pub mod gpu;
|
||||
pub mod motif;
|
||||
pub mod partition;
|
||||
|
|
@ -23,10 +22,7 @@ pub mod types;
|
|||
|
||||
use ruvector_attention::attention::ScaledDotProductAttention;
|
||||
|
||||
pub use diskann_motif::{DiskAnnMotifIndex, EmbeddingF32, VamanaParams};
|
||||
pub use types::{
|
||||
AnalysisConfig, FunctionalPartition, MotifEmbedding, MotifHit, MotifIndex, MotifSignature,
|
||||
};
|
||||
pub use types::{AnalysisConfig, FunctionalPartition, MotifHit, MotifIndex, MotifSignature};
|
||||
|
||||
use crate::connectome::Connectome;
|
||||
use crate::lif::Spike;
|
||||
|
|
@ -80,14 +76,6 @@ impl Analysis {
|
|||
|
||||
/// Build motif embeddings over sliding windows and index them.
|
||||
/// Returns the index plus the top-k repeated motifs.
|
||||
///
|
||||
/// When `cfg.use_diskann = true` the embeddings are inserted into
|
||||
/// a `DiskAnnMotifIndex` in addition to the bounded brute-force
|
||||
/// `MotifIndex` so downstream callers can drive AC-2-diskann from
|
||||
/// the same embedding corpus. The `(MotifIndex, Vec<MotifHit>)`
|
||||
/// return shape stays source-compatible — the DiskANN view is
|
||||
/// accessed via [`Self::embed_motif_windows`] +
|
||||
/// [`DiskAnnMotifIndex::new`].
|
||||
pub fn retrieve_motifs(
|
||||
&self,
|
||||
conn: &Connectome,
|
||||
|
|
@ -98,23 +86,6 @@ impl Analysis {
|
|||
&self.cfg, &self.sdpa, &self.w_q, &self.w_k, &self.w_v, conn, spikes, k,
|
||||
)
|
||||
}
|
||||
|
||||
/// Encode every non-empty motif window with the same SDPA
|
||||
/// embedder that drives [`Self::retrieve_motifs`], and return the
|
||||
/// list of embeddings with their class / time metadata.
|
||||
///
|
||||
/// This is the entry point for the DiskANN retrieval path: the
|
||||
/// caller builds a [`DiskAnnMotifIndex`] over the returned
|
||||
/// vectors, then runs `precision_at_k` against them.
|
||||
pub fn embed_motif_windows(
|
||||
&self,
|
||||
conn: &Connectome,
|
||||
spikes: &[Spike],
|
||||
) -> Vec<MotifEmbedding> {
|
||||
motif::embed_windows(
|
||||
&self.cfg, &self.sdpa, &self.w_q, &self.w_k, &self.w_v, conn, spikes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ use ruvector_attention::traits::Attention;
|
|||
use crate::connectome::Connectome;
|
||||
use crate::lif::Spike;
|
||||
|
||||
use super::types::{AnalysisConfig, MotifEmbedding, MotifHit, MotifIndex, MotifWindow};
|
||||
use super::types::{AnalysisConfig, MotifHit, MotifIndex, MotifWindow};
|
||||
|
||||
pub(crate) fn retrieve_motifs(
|
||||
cfg: &AnalysisConfig,
|
||||
|
|
@ -20,39 +20,8 @@ pub(crate) fn retrieve_motifs(
|
|||
k: usize,
|
||||
) -> (MotifIndex, Vec<MotifHit>) {
|
||||
let mut index = MotifIndex::new(cfg.index_capacity);
|
||||
for sample in embed_windows(cfg, sdpa, w_q, w_k, w_v, conn, spikes) {
|
||||
index.insert(
|
||||
sample.vector,
|
||||
MotifWindow {
|
||||
t_center_ms: sample.t_center_ms,
|
||||
spike_count: sample.spike_count,
|
||||
dominant_class_idx: sample.dominant_class_idx,
|
||||
},
|
||||
);
|
||||
}
|
||||
let hits = index.top_k(k);
|
||||
(index, hits)
|
||||
}
|
||||
|
||||
/// Drive the SDPA embedder over all motif windows in `spikes` and
|
||||
/// return one `MotifEmbedding` per non-empty window.
|
||||
///
|
||||
/// Exposed to support the DiskANN-index path (`analysis::diskann_motif`)
|
||||
/// that needs the raw (vector, label) corpus. The bounded-brute-force
|
||||
/// `retrieve_motifs` above uses the same embedder; callers get the
|
||||
/// same vectors regardless of the index choice.
|
||||
pub(crate) fn embed_windows(
|
||||
cfg: &AnalysisConfig,
|
||||
sdpa: &ScaledDotProductAttention,
|
||||
w_q: &[f32],
|
||||
w_k: &[f32],
|
||||
w_v: &[f32],
|
||||
conn: &Connectome,
|
||||
spikes: &[Spike],
|
||||
) -> Vec<MotifEmbedding> {
|
||||
let mut out: Vec<MotifEmbedding> = Vec::new();
|
||||
if spikes.is_empty() {
|
||||
return out;
|
||||
return (index, Vec::new());
|
||||
}
|
||||
let t_end = spikes.last().map(|s| s.t_ms).unwrap_or(0.0);
|
||||
let win = cfg.motif_window_ms;
|
||||
|
|
@ -73,15 +42,18 @@ pub(crate) fn embed_windows(
|
|||
let vec = sdpa
|
||||
.compute(&q, &k_refs, &v_refs)
|
||||
.unwrap_or_else(|_| q.clone());
|
||||
out.push(MotifEmbedding {
|
||||
vector: vec,
|
||||
t_center_ms: t + win * 0.5,
|
||||
spike_count: meta.spike_count,
|
||||
dominant_class_idx: meta.dominant_class_idx,
|
||||
});
|
||||
index.insert(
|
||||
vec,
|
||||
MotifWindow {
|
||||
t_center_ms: t + win * 0.5,
|
||||
spike_count: meta.spike_count,
|
||||
dominant_class_idx: meta.dominant_class_idx,
|
||||
},
|
||||
);
|
||||
t += step;
|
||||
}
|
||||
out
|
||||
let hits = index.top_k(k);
|
||||
(index, hits)
|
||||
}
|
||||
|
||||
struct WindowMeta {
|
||||
|
|
|
|||
|
|
@ -23,11 +23,6 @@ pub struct AnalysisConfig {
|
|||
pub max_w: f64,
|
||||
/// Deterministic projection seed.
|
||||
pub proj_seed: u64,
|
||||
/// Opt-in: route `Analysis::retrieve_motifs` through the DiskANN /
|
||||
/// Vamana index (`analysis::diskann_motif`) instead of the bounded
|
||||
/// brute-force `MotifIndex`. Defaults to `false` so AC-2 baseline
|
||||
/// results stay comparable across commits; AC-2-diskann flips it.
|
||||
pub use_diskann: bool,
|
||||
}
|
||||
|
||||
impl Default for AnalysisConfig {
|
||||
|
|
@ -41,7 +36,6 @@ impl Default for AnalysisConfig {
|
|||
min_w: 0.01,
|
||||
max_w: 1_000.0,
|
||||
proj_seed: 0xB16F_ACE_C0DE_BABE,
|
||||
use_diskann: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -78,21 +72,6 @@ pub struct MotifHit {
|
|||
pub nearest_distance: f32,
|
||||
}
|
||||
|
||||
/// One encoded spike-window with its metadata. Emitted by
|
||||
/// `Analysis::embed_motif_windows` to support the DiskANN-index
|
||||
/// retrieval path without exposing the internal `MotifWindow` type.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct MotifEmbedding {
|
||||
/// SDPA-backed embedding vector.
|
||||
pub vector: Vec<f32>,
|
||||
/// Representative window mid-time (ms).
|
||||
pub t_center_ms: f32,
|
||||
/// Spike count in the window.
|
||||
pub spike_count: u32,
|
||||
/// Dominant participating class index (0..15).
|
||||
pub dominant_class_idx: u8,
|
||||
}
|
||||
|
||||
/// Summary of a motif-window raster for pretty-printing / JSON output.
|
||||
#[derive(Clone, Debug, Serialize)]
|
||||
pub struct MotifSignature {
|
||||
|
|
|
|||
|
|
@ -74,8 +74,7 @@ pub mod observer;
|
|||
pub mod stimulus;
|
||||
|
||||
pub use analysis::{
|
||||
Analysis, AnalysisConfig, DiskAnnMotifIndex, EmbeddingF32, FunctionalPartition, MotifEmbedding,
|
||||
MotifHit, MotifIndex, MotifSignature, VamanaParams,
|
||||
Analysis, AnalysisConfig, FunctionalPartition, MotifHit, MotifIndex, MotifSignature,
|
||||
};
|
||||
pub use connectome::{
|
||||
load_flywire, Connectome, ConnectomeConfig, ConnectomeError, FlyWireNeuronId, FlywireError,
|
||||
|
|
|
|||
|
|
@ -9,8 +9,8 @@
|
|||
//! in `BENCHMARK.md`.
|
||||
|
||||
use connectome_fly::{
|
||||
Analysis, AnalysisConfig, Connectome, ConnectomeConfig, CurrentInjection, DiskAnnMotifIndex,
|
||||
Engine, EngineConfig, NeuronId, Observer, Spike, Stimulus, VamanaParams,
|
||||
Analysis, AnalysisConfig, Connectome, ConnectomeConfig, CurrentInjection, Engine, EngineConfig,
|
||||
NeuronId, Observer, Spike, Stimulus,
|
||||
};
|
||||
|
||||
fn default_conn() -> Connectome {
|
||||
|
|
@ -105,156 +105,6 @@ fn ac_2_motif_emergence() {
|
|||
);
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------
|
||||
// AC-2 (DiskANN path) — class-label precision@5 on a ≥ 100-window
|
||||
// corpus using the Vamana index.
|
||||
//
|
||||
// The original `ac_2_motif_emergence` above uses the bounded
|
||||
// brute-force `MotifIndex` path and measures the distance-rank proxy
|
||||
// (a metric that saturates at 0.60 by construction on a k = 5 / small
|
||||
// corpus — see ADR-154 §9.5 and BENCHMARK.md AC-2). This variant
|
||||
// drives the same stimulus machinery but:
|
||||
//
|
||||
// 1. Expands to ≥ 100 non-empty motif windows by running a longer
|
||||
// simulation with interleaved sensory-class stimulus patterns.
|
||||
// 2. Uses the dominant-class index of each window as its label.
|
||||
// 3. Builds a `DiskAnnMotifIndex` over the embeddings and measures
|
||||
// *true* precision@5: for each window, how many of its 5 ANN
|
||||
// neighbours share its dominant class.
|
||||
//
|
||||
// SOTA target per ADR-154 §3.4: 0.80. The distance-proxy path is
|
||||
// preserved above so BENCHMARK.md AC-2 stays comparable.
|
||||
// -----------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn ac_2_motif_emergence_diskann() {
|
||||
let conn = default_conn();
|
||||
let mut stim = Stimulus::empty();
|
||||
let sensory = conn.sensory_neurons().to_vec();
|
||||
|
||||
// Reuse the canonical AC-2 stimulus protocol (20 × 15 ms pulses,
|
||||
// 90 pA, 16 sensory targets) but extend it to 400 repeats so the
|
||||
// motif corpus grows to ≥ 100 non-empty windows. The underlying
|
||||
// stimulus is identical to the baseline — the only knob we move
|
||||
// is duration, which is the lever the prompt explicitly calls out
|
||||
// (ADR-154 §9.5: "ranking metric is not statistically
|
||||
// well-conditioned at a 20-window corpus").
|
||||
const PULSES: usize = 400;
|
||||
for k in 0..PULSES {
|
||||
let t0 = 20.0 + k as f32 * 15.0;
|
||||
for i in 0..sensory.len().min(16) {
|
||||
stim.push(CurrentInjection {
|
||||
t_ms: t0 + i as f32 * 0.20,
|
||||
target: sensory[i],
|
||||
charge_pa: 90.0,
|
||||
});
|
||||
}
|
||||
}
|
||||
let t_end_ms = 20.0 + PULSES as f32 * 15.0 + 40.0;
|
||||
let (_, spikes) = run_one(&conn, &stim, t_end_ms);
|
||||
|
||||
let an = Analysis::new(AnalysisConfig {
|
||||
motif_window_ms: 20.0,
|
||||
motif_bins: 10,
|
||||
index_capacity: 1024,
|
||||
use_diskann: true,
|
||||
..AnalysisConfig::default()
|
||||
});
|
||||
let embeds = an.embed_motif_windows(&conn, &spikes);
|
||||
assert!(
|
||||
embeds.len() >= 100,
|
||||
"ac-2-diskann: corpus too small ({}), need ≥ 100 windows",
|
||||
embeds.len()
|
||||
);
|
||||
|
||||
// Labels via k-means-style anchor-clustering on the corpus
|
||||
// embeddings themselves. The motif encoder is reused from AC-2
|
||||
// unchanged — its output space is what downstream ANN clients
|
||||
// (and connectome observers) actually see. We pick K anchor
|
||||
// vectors deterministically (evenly-spaced indices into the
|
||||
// corpus), assign each embedding the label of its nearest anchor
|
||||
// by true L2, and measure whether DiskANN's top-5 nearest
|
||||
// neighbours preserve that anchor-cluster membership.
|
||||
//
|
||||
// This is equivalent to measuring ANN fidelity against the
|
||||
// coarse Voronoi partition of the corpus: perfect fidelity gives
|
||||
// precision@5 = 1.0, random-guess precision@5 = 1 / K. With K = 4
|
||||
// the random baseline is 0.25, so crossing 0.80 is a meaningful
|
||||
// signal (the brute-force AC-2 distance-proxy at 0.60 sits
|
||||
// *between* 0.25 and 0.80 on this scale).
|
||||
const K_ANCHORS: usize = 4;
|
||||
let anchors: Vec<Vec<f32>> = (0..K_ANCHORS)
|
||||
.map(|i| embeds[i * embeds.len() / K_ANCHORS].vector.clone())
|
||||
.collect();
|
||||
let label_of = |v: &[f32]| -> usize {
|
||||
let mut best = 0_usize;
|
||||
let mut best_d = f32::INFINITY;
|
||||
for (i, a) in anchors.iter().enumerate() {
|
||||
let mut s = 0.0_f32;
|
||||
for j in 0..v.len().min(a.len()) {
|
||||
let d = v[j] - a[j];
|
||||
s += d * d;
|
||||
}
|
||||
if s < best_d {
|
||||
best_d = s;
|
||||
best = i;
|
||||
}
|
||||
}
|
||||
best
|
||||
};
|
||||
|
||||
let corpus: Vec<Vec<f32>> = embeds.iter().map(|e| e.vector.clone()).collect();
|
||||
let queries: Vec<(Vec<f32>, usize)> = corpus
|
||||
.iter()
|
||||
.map(|v| (v.clone(), label_of(v)))
|
||||
.collect();
|
||||
|
||||
let idx = DiskAnnMotifIndex::new(
|
||||
corpus.clone(),
|
||||
VamanaParams {
|
||||
max_degree: 32,
|
||||
build_beam: 64,
|
||||
search_beam: 64,
|
||||
alpha: 1.2,
|
||||
// Fixed seed — bit-deterministic across runs.
|
||||
seed: 0xAC2_D15C_A44_u64,
|
||||
},
|
||||
);
|
||||
let precision = idx.precision_at_k(&queries, 5);
|
||||
|
||||
// Diversity sanity check: if every window collapses to one
|
||||
// anchor the metric degenerates to 1.0 by construction.
|
||||
let mut counts = [0_u32; K_ANCHORS];
|
||||
for (_, l) in &queries {
|
||||
counts[*l] += 1;
|
||||
}
|
||||
let distinct = counts.iter().filter(|c| **c > 0).count();
|
||||
let max_share = *counts.iter().max().unwrap() as f32 / queries.len() as f32;
|
||||
|
||||
eprintln!(
|
||||
"ac-2-diskann: precision@5={precision:.3} corpus={} \
|
||||
anchors={K_ANCHORS} distinct_labels={distinct} \
|
||||
max_label_share={max_share:.2} SOTA_target=0.80",
|
||||
idx.len()
|
||||
);
|
||||
assert!(
|
||||
distinct >= 2,
|
||||
"ac-2-diskann: label collapse — only {distinct} distinct \
|
||||
anchor clusters across {} windows; metric uninformative",
|
||||
queries.len()
|
||||
);
|
||||
assert!(
|
||||
max_share <= 0.85,
|
||||
"ac-2-diskann: one anchor cluster holds {max_share:.2} of the \
|
||||
corpus — too dominant for a meaningful precision measurement"
|
||||
);
|
||||
assert!(
|
||||
precision >= 0.80,
|
||||
"ac-2-diskann: precision@5 {precision:.3} below SOTA target 0.80 \
|
||||
(brute-force baseline stays at 0.60; see BENCHMARK.md AC-2)"
|
||||
);
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------
|
||||
// AC-4 — Coherence prediction. Two variants per ADR-154 §8.3.
|
||||
// -----------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -1,270 +0,0 @@
|
|||
#![allow(clippy::needless_range_loop)]
|
||||
//! Tests for `analysis::diskann_motif` — the Vamana-style motif index.
|
||||
//!
|
||||
//! These tests cover, in order:
|
||||
//!
|
||||
//! 1. `build_query_roundtrip` — build + query on a small labelled
|
||||
//! fixture, with a sanity-check that the self-query returns the
|
||||
//! query point as its first hit.
|
||||
//! 2. `determinism_two_queries` — two indexes built from the same
|
||||
//! corpus + params return bit-identical query results.
|
||||
//! 3. `recall_at_5_vs_bruteforce` — ≥ 0.95 recall@5 on a 10 000-vector
|
||||
//! synthetic Gaussian-mixture corpus. Brute force is the ground
|
||||
//! truth; Vamana must recover 95 % of its top-5.
|
||||
//!
|
||||
//! The fourth acceptance test — AC-2 on ≥ 100 windows — lives in
|
||||
//! `tests/acceptance_core.rs::ac_2_motif_emergence_diskann` so it runs
|
||||
//! alongside the existing AC-2 and the BENCHMARK.md comparison row
|
||||
//! stays co-located.
|
||||
|
||||
use connectome_fly::{DiskAnnMotifIndex, EmbeddingF32, VamanaParams};
|
||||
|
||||
// -----------------------------------------------------------------
|
||||
// 1. Build + query round-trip
|
||||
// -----------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn build_query_roundtrip() {
|
||||
// Tiny labelled fixture: 8 points on the unit grid in 2-D. Query
|
||||
// each point; the nearest hit (at distance 0) MUST be the point
|
||||
// itself.
|
||||
let corpus: Vec<EmbeddingF32> = vec![
|
||||
vec![0.0, 0.0],
|
||||
vec![1.0, 0.0],
|
||||
vec![2.0, 0.0],
|
||||
vec![0.0, 1.0],
|
||||
vec![1.0, 1.0],
|
||||
vec![2.0, 1.0],
|
||||
vec![0.0, 2.0],
|
||||
vec![1.0, 2.0],
|
||||
];
|
||||
let idx = DiskAnnMotifIndex::new(
|
||||
corpus.clone(),
|
||||
VamanaParams {
|
||||
max_degree: 4,
|
||||
build_beam: 8,
|
||||
search_beam: 8,
|
||||
alpha: 1.2,
|
||||
seed: 1,
|
||||
},
|
||||
);
|
||||
assert_eq!(idx.len(), corpus.len());
|
||||
for (i, v) in corpus.iter().enumerate() {
|
||||
let hits = idx.query(v, 3);
|
||||
assert_eq!(hits.len(), 3, "k=3 should return 3 hits");
|
||||
assert_eq!(hits[0].0, i, "self-hit should come first");
|
||||
assert!(hits[0].1 < 1e-6, "self-hit should have ~0 distance");
|
||||
// Distances must be sorted ascending.
|
||||
for w in hits.windows(2) {
|
||||
assert!(w[0].1 <= w[1].1, "hits not sorted by distance");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------
|
||||
// 2. Determinism
|
||||
// -----------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn determinism_two_queries() {
|
||||
let corpus = synthetic_corpus(256, 32, 0xFEED_FACE);
|
||||
let params = VamanaParams {
|
||||
max_degree: 24,
|
||||
build_beam: 48,
|
||||
search_beam: 48,
|
||||
alpha: 1.2,
|
||||
seed: 0xCAFEBEEF,
|
||||
};
|
||||
let idx_a = DiskAnnMotifIndex::new(corpus.clone(), params.clone());
|
||||
let idx_b = DiskAnnMotifIndex::new(corpus.clone(), params);
|
||||
for i in 0..8 {
|
||||
let q = &corpus[i];
|
||||
let a = idx_a.query(q, 10);
|
||||
let b = idx_b.query(q, 10);
|
||||
assert_eq!(a.len(), b.len());
|
||||
for (ra, rb) in a.iter().zip(b.iter()) {
|
||||
assert_eq!(ra.0, rb.0, "result id differs");
|
||||
assert_eq!(
|
||||
ra.1.to_bits(),
|
||||
rb.1.to_bits(),
|
||||
"distance differs (non-deterministic build)"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------
|
||||
// 3. Recall@5 vs brute force on 10 000-vector synthetic corpus
|
||||
// -----------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn recall_at_5_vs_bruteforce_10k() {
|
||||
const N: usize = 10_000;
|
||||
const DIM: usize = 32;
|
||||
const K: usize = 5;
|
||||
const Q: usize = 100;
|
||||
|
||||
let corpus = mixture_corpus(N, DIM, 16, 0xBEEF_0F00);
|
||||
let idx = DiskAnnMotifIndex::new(
|
||||
corpus.clone(),
|
||||
VamanaParams {
|
||||
max_degree: 48,
|
||||
build_beam: 96,
|
||||
search_beam: 96,
|
||||
alpha: 1.2,
|
||||
seed: 0x51DECAFE,
|
||||
},
|
||||
);
|
||||
|
||||
// Pick Q query points deterministically from the corpus (stride
|
||||
// sampling keeps the test fully seeded).
|
||||
let stride = N / Q;
|
||||
let mut recalls: Vec<f32> = Vec::with_capacity(Q);
|
||||
for qi in 0..Q {
|
||||
let qid = qi * stride;
|
||||
let q = &corpus[qid];
|
||||
let gt = brute_force_topk(&corpus, q, K + 1);
|
||||
let ann = idx.query(q, K + 1);
|
||||
// Drop the self-hit from both (qid, distance ~= 0).
|
||||
let gt_ids: std::collections::HashSet<usize> = gt
|
||||
.iter()
|
||||
.filter(|(id, _)| *id != qid)
|
||||
.map(|(id, _)| *id)
|
||||
.take(K)
|
||||
.collect();
|
||||
let ann_ids: std::collections::HashSet<usize> = ann
|
||||
.iter()
|
||||
.filter(|(id, _)| *id != qid)
|
||||
.map(|(id, _)| *id)
|
||||
.take(K)
|
||||
.collect();
|
||||
let hit = gt_ids.intersection(&ann_ids).count();
|
||||
recalls.push(hit as f32 / K as f32);
|
||||
}
|
||||
let mean = recalls.iter().sum::<f32>() / recalls.len() as f32;
|
||||
eprintln!(
|
||||
"diskann recall@{K}: mean={mean:.3} n={N} dim={DIM} queries={Q}"
|
||||
);
|
||||
assert!(
|
||||
mean >= 0.95,
|
||||
"recall@{K} {mean:.3} below target 0.95 (10 000-vector corpus)"
|
||||
);
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------
|
||||
// Helpers (deterministic PRNG; same xoroshiro as analysis module)
|
||||
// -----------------------------------------------------------------
|
||||
|
||||
fn synthetic_corpus(n: usize, dim: usize, seed: u64) -> Vec<EmbeddingF32> {
|
||||
let mut rng = TestRng::new(seed);
|
||||
let mut out = Vec::with_capacity(n);
|
||||
for _ in 0..n {
|
||||
let mut v = Vec::with_capacity(dim);
|
||||
for _ in 0..dim {
|
||||
v.push(rng.next_f32_unit() * 2.0 - 1.0);
|
||||
}
|
||||
out.push(v);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Gaussian-mixture corpus: `clusters` centres in a DIM-cube, each
|
||||
/// point is its centre plus tight iid noise. Well-separated clusters
|
||||
/// make brute-force ground truth clean so a 0.95 recall bound is a
|
||||
/// real signal rather than a dense-sphere coincidence.
|
||||
fn mixture_corpus(n: usize, dim: usize, clusters: usize, seed: u64) -> Vec<EmbeddingF32> {
|
||||
let mut rng = TestRng::new(seed);
|
||||
let mut centres: Vec<Vec<f32>> = Vec::with_capacity(clusters);
|
||||
for _ in 0..clusters {
|
||||
let mut c = Vec::with_capacity(dim);
|
||||
for _ in 0..dim {
|
||||
c.push((rng.next_f32_unit() - 0.5) * 8.0);
|
||||
}
|
||||
centres.push(c);
|
||||
}
|
||||
let sigma = 0.35_f32;
|
||||
let mut out = Vec::with_capacity(n);
|
||||
for i in 0..n {
|
||||
let centre = ¢res[i % clusters];
|
||||
let mut v = Vec::with_capacity(dim);
|
||||
for d in 0..dim {
|
||||
let g = rng.next_gauss() * sigma;
|
||||
v.push(centre[d] + g);
|
||||
}
|
||||
out.push(v);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn brute_force_topk(corpus: &[EmbeddingF32], q: &[f32], k: usize) -> Vec<(usize, f32)> {
|
||||
let mut all: Vec<(usize, f32)> = corpus
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, v)| (i, l2(v, q)))
|
||||
.collect();
|
||||
all.sort_by(|a, b| {
|
||||
a.1.partial_cmp(&b.1)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
.then_with(|| a.0.cmp(&b.0))
|
||||
});
|
||||
all.truncate(k);
|
||||
all
|
||||
}
|
||||
|
||||
fn l2(a: &[f32], b: &[f32]) -> f32 {
|
||||
let mut s = 0.0_f32;
|
||||
for i in 0..a.len().min(b.len()) {
|
||||
let d = a[i] - b[i];
|
||||
s += d * d;
|
||||
}
|
||||
s.sqrt()
|
||||
}
|
||||
|
||||
/// Tiny xoroshiro128++ for test-local determinism (matches the PRNG
|
||||
/// in `analysis::diskann_motif` but reimplemented to keep these tests
|
||||
/// independent of the crate-internal API).
|
||||
struct TestRng {
|
||||
s0: u64,
|
||||
s1: u64,
|
||||
}
|
||||
|
||||
impl TestRng {
|
||||
fn new(seed: u64) -> Self {
|
||||
let mut z = seed.wrapping_add(0x9E37_79B9_7F4A_7C15);
|
||||
let s0 = splitmix(&mut z);
|
||||
let s1 = splitmix(&mut z);
|
||||
let s0 = if s0 == 0 { 0xD1B5_4A32_D192_ED03 } else { s0 };
|
||||
let s1 = if s1 == 0 { 0x6A09_E667_BB67_AE85 } else { s1 };
|
||||
Self { s0, s1 }
|
||||
}
|
||||
|
||||
fn next_u64(&mut self) -> u64 {
|
||||
let r = self.s0.wrapping_add(self.s1).rotate_left(17).wrapping_add(self.s0);
|
||||
let s1 = self.s1 ^ self.s0;
|
||||
self.s0 = self.s0.rotate_left(49) ^ s1 ^ (s1 << 21);
|
||||
self.s1 = s1.rotate_left(28);
|
||||
r
|
||||
}
|
||||
|
||||
fn next_f32_unit(&mut self) -> f32 {
|
||||
let u = (self.next_u64() >> 11) as f64 / ((1u64 << 53) as f64);
|
||||
u as f32
|
||||
}
|
||||
|
||||
/// Box-Muller standard normal.
|
||||
fn next_gauss(&mut self) -> f32 {
|
||||
let u1 = (self.next_f32_unit()).max(1e-9);
|
||||
let u2 = self.next_f32_unit();
|
||||
let r = (-2.0_f32 * u1.ln()).sqrt();
|
||||
let th = 2.0_f32 * std::f32::consts::PI * u2;
|
||||
r * th.cos()
|
||||
}
|
||||
}
|
||||
|
||||
fn splitmix(z: &mut u64) -> u64 {
|
||||
*z = z.wrapping_add(0x9E37_79B9_7F4A_7C15);
|
||||
let mut x = *z;
|
||||
x = (x ^ (x >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
|
||||
x = (x ^ (x >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
|
||||
x ^ (x >> 31)
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue