ruvector/crates/ruvector-dag/examples/attention_selection.rs
Claude ec323f5a4d chore(dag): optimize codebase - fix warnings and format code
- Fix unused variable warnings with underscore prefixes
- Add #[allow(dead_code)] for API-reserved fields
- Run cargo fmt for consistent formatting
- Apply cargo clippy --fix for lint improvements
- Reduce ruvector-dag lib warnings from 17 to 0
- Improve code quality across 60 files

Changes include:
- qudag/client.rs: prefix unused params (_pattern, _proposal_id, _since_round)
- sona/engine.rs: prefix unused param (_similar), add deprecated match arms
- sona/reasoning_bank.rs: prefix unused var (_dim)
- attention/*.rs: consistent formatting and minor improvements
- examples/exotic/*.rs: formatting for all 7 coherence-sensing examples
2025-12-30 02:08:55 +00:00

99 lines
3.2 KiB
Rust

//! Attention mechanism selection example
use ruvector_dag::attention::{
CausalConeAttention, CausalConeConfig, DagAttention, TopologicalAttention, TopologicalConfig,
};
use ruvector_dag::dag::{OperatorNode, OperatorType, QueryDag};
/// Builds a small vector-search query DAG, scores it with two attention
/// mechanisms (topological and causal-cone), and prints each mechanism's
/// per-node scores followed by a side-by-side comparison table.
///
/// # Panics
///
/// Panics if either attention mechanism's `forward` returns an error for the
/// example DAG, or if a scored node id cannot be looked up in the DAG.
fn main() {
    println!("=== Attention Mechanism Selection ===\n");

    // Create a sample DAG
    let dag = create_vector_search_dag();
    println!("Created vector search DAG:");
    println!(" Nodes: {}", dag.node_count());
    println!(" Edges: {}", dag.edge_count());

    // Test Topological Attention
    println!("\n--- Topological Attention ---");
    println!("Emphasizes node depth in the DAG hierarchy");
    let topo = TopologicalAttention::new(TopologicalConfig {
        decay_factor: 0.9,
        max_depth: 10,
    });
    let scores = topo
        .forward(&dag)
        .expect("topological attention failed on the example DAG");

    println!("\nAttention scores:");
    // The score map's iteration order is not guaranteed to follow node ids
    // (a std HashMap randomizes it per run), so sort by id for stable output.
    let mut topo_entries: Vec<_> = scores.iter().collect();
    topo_entries.sort_by_key(|&(node_id, _)| *node_id);
    for (node_id, score) in topo_entries {
        let node = dag
            .get_node(*node_id)
            .expect("every scored node id should exist in the DAG");
        println!(" Node {}: {:.4} - {:?}", node_id, score, node.op_type);
    }
    let sum: f32 = scores.values().sum();
    println!("\nSum of scores: {:.4} (should be ~1.0)", sum);

    // Test Causal Cone Attention
    println!("\n--- Causal Cone Attention ---");
    println!("Focuses on downstream dependencies");
    let causal = CausalConeAttention::new(CausalConeConfig {
        time_window_ms: 1000,
        future_discount: 0.85,
        ancestor_weight: 0.5,
    });
    let causal_scores = causal
        .forward(&dag)
        .expect("causal cone attention failed on the example DAG");

    println!("\nCausal cone scores:");
    // Same reasoning as above: sort so the listing order is deterministic.
    let mut causal_entries: Vec<_> = causal_scores.iter().collect();
    causal_entries.sort_by_key(|&(node_id, _)| *node_id);
    for (node_id, score) in causal_entries {
        let node = dag
            .get_node(*node_id)
            .expect("every scored node id should exist in the DAG");
        println!(" Node {}: {:.4} - {:?}", node_id, score, node.op_type);
    }

    // Compare mechanisms
    println!("\n--- Comparison ---");
    println!("Node | Topological | Causal Cone | Difference");
    println!("-----|-------------|-------------|------------");
    // Visit every id in 0..node_count() plus any other id either mechanism
    // scored, deduplicated and ascending. This avoids assuming node ids are
    // exactly the contiguous range 0..n. A node missing from one mechanism's
    // output is shown with a score of 0.0.
    let node_ids: std::collections::BTreeSet<_> = (0..dag.node_count())
        .chain(scores.keys().copied())
        .chain(causal_scores.keys().copied())
        .collect();
    for node_id in node_ids {
        let topo_score = scores.get(&node_id).unwrap_or(&0.0);
        let causal_score = causal_scores.get(&node_id).unwrap_or(&0.0);
        let diff = (topo_score - causal_score).abs();
        println!(
            "{:4} | {:11.4} | {:11.4} | {:11.4}",
            node_id, topo_score, causal_score, diff
        );
    }

    println!("\n=== Example Complete ===");
}
/// Builds the example query plan for a filtered vector search.
///
/// ```text
/// HnswScan(embeddings_idx) ─┐
///                           ├─> NestedLoopJoin ─> Filter(category = 'tech') ─> Limit(10) ─> Result
/// SeqScan(metadata) ────────┘
/// ```
///
/// The leading integer passed to each `OperatorNode` constructor (0 through 5)
/// appears to be the node id — TODO confirm against `OperatorNode`'s docs.
///
/// # Panics
///
/// Panics (via `unwrap`) if `QueryDag::add_edge` rejects any of the edges.
fn create_vector_search_dag() -> QueryDag {
    let mut plan = QueryDag::new();

    // Leaf operators: the primary HNSW vector search and the metadata table scan.
    let hnsw_scan = plan.add_node(OperatorNode::hnsw_scan(0, "embeddings_idx", 64));
    let metadata_scan = plan.add_node(OperatorNode::seq_scan(1, "metadata"));

    // Join embeddings with metadata, then filter by category and cap the row count.
    let joined = plan.add_node(OperatorNode::new(2, OperatorType::NestedLoopJoin));
    let filtered = plan.add_node(OperatorNode::filter(3, "category = 'tech'"));
    let limited = plan.add_node(OperatorNode::limit(4, 10));
    let output = plan.add_node(OperatorNode::new(5, OperatorType::Result));

    // Each edge runs from an input operator to the operator that consumes it.
    let wiring = [
        (hnsw_scan, joined),
        (metadata_scan, joined),
        (joined, filtered),
        (filtered, limited),
        (limited, output),
    ];
    for &(source, target) in &wiring {
        plan.add_edge(source, target).unwrap();
    }

    plan
}