//! Embedding service with tokenization and caching //! //! Provides text-to-vector conversion with LRU caching for efficiency. use crate::config::EmbeddingConfig; use crate::error::Result; use ahash::AHashMap; use lru::LruCache; use parking_lot::Mutex; use std::num::NonZeroUsize; /// Result of embedding a text #[derive(Debug, Clone)] pub struct Embedding { /// The embedding vector pub vector: Vec, /// Token count pub token_count: usize, /// Whether text was truncated pub truncated: bool, /// Cache hit indicator pub from_cache: bool, } /// Token from tokenization #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Token { /// Token ID pub id: u32, /// Token text pub text: String, } /// Tokenizer for text processing pub struct Tokenizer { /// Vocabulary mapping vocab: AHashMap, /// Reverse mapping id_to_token: Vec, /// Special tokens special_tokens: SpecialTokens, } /// Special token IDs #[derive(Debug, Clone)] struct SpecialTokens { pad: u32, unk: u32, bos: u32, eos: u32, } impl Tokenizer { /// Create a new basic tokenizer pub fn new(vocab_size: usize) -> Self { let mut vocab = AHashMap::new(); let mut id_to_token = Vec::with_capacity(vocab_size); // Add special tokens let special = ["", "", "", "", ""]; for (i, tok) in special.iter().enumerate() { vocab.insert(tok.to_string(), i as u32); id_to_token.push(tok.to_string()); } // Build basic character/word vocabulary let chars: Vec = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 .,!?;:'\"-_()[]{}" .chars() .collect(); for ch in chars { let s = ch.to_string(); if !vocab.contains_key(&s) && vocab.len() < vocab_size { let id = vocab.len() as u32; vocab.insert(s.clone(), id); id_to_token.push(s); } } Self { vocab, id_to_token, special_tokens: SpecialTokens { pad: 0, unk: 1, bos: 2, eos: 3, }, } } /// Tokenize text into token IDs pub fn tokenize(&self, text: &str) -> Vec { let mut tokens = vec![self.special_tokens.bos]; // Simple character-level tokenization for word in text.split_whitespace() { for ch in word.chars() { let s = ch.to_string(); let id = self .vocab .get(&s) .copied() .unwrap_or(self.special_tokens.unk); tokens.push(id); } // Add space token if let Some(&space_id) = self.vocab.get(" ") { tokens.push(space_id); } } tokens.push(self.special_tokens.eos); tokens } /// Get vocabulary size pub fn vocab_size(&self) -> usize { self.vocab.len() } /// Decode tokens back to text pub fn decode(&self, tokens: &[u32]) -> String { tokens .iter() .filter_map(|&id| self.id_to_token.get(id as usize)) .cloned() .collect::>() .join("") } } /// Service for text embedding with caching pub struct EmbeddingService { /// Embedding dimension dimension: usize, /// Maximum tokens max_tokens: usize, /// Tokenizer tokenizer: Tokenizer, /// LRU cache for embeddings cache: Mutex>, /// Embedding matrix (token_id -> embedding) embedding_matrix: Vec>, /// Position embeddings position_embeddings: Vec>, /// Statistics stats: EmbeddingStats, } /// Embedding service statistics struct EmbeddingStats { cache_hits: std::sync::atomic::AtomicU64, cache_misses: std::sync::atomic::AtomicU64, total_tokens: std::sync::atomic::AtomicU64, } impl EmbeddingService { /// Create a new embedding service pub fn new(config: &EmbeddingConfig) -> Result { let tokenizer = Tokenizer::new(10000); let vocab_size = tokenizer.vocab_size(); // Initialize embedding matrix with random values let mut rng = rand::thread_rng(); use rand::Rng; let embedding_matrix: Vec> = (0..vocab_size) .map(|_| { let mut vec: Vec = (0..config.dimension) .map(|_| rng.gen_range(-0.1..0.1)) .collect(); // Normalize let norm: f32 = vec.iter().map(|x| x * x).sum::().sqrt(); if norm > 0.0 { vec.iter_mut().for_each(|x| *x /= norm); } vec }) .collect(); // Position embeddings (sinusoidal) let position_embeddings: Vec> = (0..config.max_tokens) .map(|pos| { (0..config.dimension) .map(|i| { let angle = pos as f32 / (10000.0_f32).powf(2.0 * (i / 2) as f32 / config.dimension as f32); if i % 2 == 0 { angle.sin() } else { angle.cos() } }) .collect() }) .collect(); let cache_size = NonZeroUsize::new(10000).unwrap(); Ok(Self { dimension: config.dimension, max_tokens: config.max_tokens, tokenizer, cache: Mutex::new(LruCache::new(cache_size)), embedding_matrix, position_embeddings, stats: EmbeddingStats { cache_hits: std::sync::atomic::AtomicU64::new(0), cache_misses: std::sync::atomic::AtomicU64::new(0), total_tokens: std::sync::atomic::AtomicU64::new(0), }, }) } /// Embed a text string pub fn embed(&self, text: &str) -> Result { // Check cache let hash = self.hash_text(text); { let mut cache = self.cache.lock(); if let Some(cached) = cache.get(&hash) { self.stats .cache_hits .fetch_add(1, std::sync::atomic::Ordering::Relaxed); let mut result = cached.clone(); result.from_cache = true; return Ok(result); } } self.stats .cache_misses .fetch_add(1, std::sync::atomic::Ordering::Relaxed); // Tokenize let tokens = self.tokenizer.tokenize(text); let token_count = tokens.len(); let truncated = token_count > self.max_tokens; let tokens: Vec = tokens.into_iter().take(self.max_tokens).collect(); self.stats .total_tokens .fetch_add(tokens.len() as u64, std::sync::atomic::Ordering::Relaxed); // Compute embedding let vector = self.compute_embedding(&tokens); let embedding = Embedding { vector, token_count: tokens.len(), truncated, from_cache: false, }; // Cache result { let mut cache = self.cache.lock(); cache.put(hash, embedding.clone()); } Ok(embedding) } /// Embed multiple texts (batched for efficiency) pub fn embed_batch(&self, texts: &[&str]) -> Result> { texts.iter().map(|t| self.embed(t)).collect() } /// Embed with specific pooling strategy pub fn embed_with_pooling(&self, text: &str, pooling: PoolingStrategy) -> Result { let tokens = self.tokenizer.tokenize(text); let tokens: Vec = tokens.into_iter().take(self.max_tokens).collect(); let vector = match pooling { PoolingStrategy::Mean => self.mean_pooling(&tokens), PoolingStrategy::Max => self.max_pooling(&tokens), PoolingStrategy::CLS => self.cls_pooling(&tokens), PoolingStrategy::LastToken => self.last_token_pooling(&tokens), }; Ok(Embedding { vector, token_count: tokens.len(), truncated: tokens.len() >= self.max_tokens, from_cache: false, }) } /// Get embedding statistics pub fn get_stats(&self) -> EmbeddingServiceStats { EmbeddingServiceStats { cache_hits: self .stats .cache_hits .load(std::sync::atomic::Ordering::Relaxed), cache_misses: self .stats .cache_misses .load(std::sync::atomic::Ordering::Relaxed), total_tokens: self .stats .total_tokens .load(std::sync::atomic::Ordering::Relaxed), cache_size: self.cache.lock().len(), } } /// Clear the embedding cache pub fn clear_cache(&self) { self.cache.lock().clear(); } fn hash_text(&self, text: &str) -> u64 { use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; let mut hasher = DefaultHasher::new(); text.hash(&mut hasher); hasher.finish() } fn compute_embedding(&self, tokens: &[u32]) -> Vec { self.mean_pooling(tokens) } fn mean_pooling(&self, tokens: &[u32]) -> Vec { let mut result = vec![0.0f32; self.dimension]; for (pos, &token_id) in tokens.iter().enumerate() { let token_emb = self.get_token_embedding(token_id); let pos_emb = self.get_position_embedding(pos); for i in 0..self.dimension { result[i] += token_emb[i] + pos_emb[i]; } } // Average let n = tokens.len() as f32; if n > 0.0 { result.iter_mut().for_each(|x| *x /= n); } // Normalize let norm: f32 = result.iter().map(|x| x * x).sum::().sqrt(); if norm > 0.0 { result.iter_mut().for_each(|x| *x /= norm); } result } fn max_pooling(&self, tokens: &[u32]) -> Vec { let mut result = vec![f32::NEG_INFINITY; self.dimension]; for (pos, &token_id) in tokens.iter().enumerate() { let token_emb = self.get_token_embedding(token_id); let pos_emb = self.get_position_embedding(pos); for i in 0..self.dimension { let val = token_emb[i] + pos_emb[i]; if val > result[i] { result[i] = val; } } } // Normalize let norm: f32 = result.iter().map(|x| x * x).sum::().sqrt(); if norm > 0.0 { result.iter_mut().for_each(|x| *x /= norm); } result } fn cls_pooling(&self, tokens: &[u32]) -> Vec { if let Some(&first_token) = tokens.first() { let token_emb = self.get_token_embedding(first_token); let pos_emb = self.get_position_embedding(0); let mut result: Vec = token_emb .iter() .zip(pos_emb.iter()) .map(|(t, p)| t + p) .collect(); // Normalize let norm: f32 = result.iter().map(|x| x * x).sum::().sqrt(); if norm > 0.0 { result.iter_mut().for_each(|x| *x /= norm); } result } else { vec![0.0; self.dimension] } } fn last_token_pooling(&self, tokens: &[u32]) -> Vec { if let Some(&last_token) = tokens.last() { let pos = tokens.len().saturating_sub(1); let token_emb = self.get_token_embedding(last_token); let pos_emb = self.get_position_embedding(pos); let mut result: Vec = token_emb .iter() .zip(pos_emb.iter()) .map(|(t, p)| t + p) .collect(); // Normalize let norm: f32 = result.iter().map(|x| x * x).sum::().sqrt(); if norm > 0.0 { result.iter_mut().for_each(|x| *x /= norm); } result } else { vec![0.0; self.dimension] } } fn get_token_embedding(&self, token_id: u32) -> &[f32] { let idx = (token_id as usize).min(self.embedding_matrix.len() - 1); &self.embedding_matrix[idx] } fn get_position_embedding(&self, pos: usize) -> &[f32] { let idx = pos.min(self.position_embeddings.len() - 1); &self.position_embeddings[idx] } } /// Pooling strategy for embeddings #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum PoolingStrategy { /// Mean pooling (average all tokens) Mean, /// Max pooling (element-wise max) Max, /// CLS token pooling (first token) CLS, /// Last token pooling LastToken, } /// Public statistics #[derive(Debug, Clone)] pub struct EmbeddingServiceStats { /// Cache hits pub cache_hits: u64, /// Cache misses pub cache_misses: u64, /// Total tokens processed pub total_tokens: u64, /// Current cache size pub cache_size: usize, } #[cfg(test)] mod tests { use super::*; #[test] fn test_embedding_dimension() { let config = EmbeddingConfig::default(); let service = EmbeddingService::new(&config).unwrap(); let embedding = service.embed("Hello world").unwrap(); assert_eq!(embedding.vector.len(), config.dimension); } #[test] fn test_embedding_normalized() { let config = EmbeddingConfig::default(); let service = EmbeddingService::new(&config).unwrap(); let embedding = service.embed("Test text").unwrap(); let norm: f32 = embedding.vector.iter().map(|x| x * x).sum::().sqrt(); assert!((norm - 1.0).abs() < 0.01); } #[test] fn test_same_text_same_embedding() { let config = EmbeddingConfig::default(); let service = EmbeddingService::new(&config).unwrap(); let e1 = service.embed("Same text").unwrap(); let e2 = service.embed("Same text").unwrap(); assert_eq!(e1.vector, e2.vector); assert!(e2.from_cache); } #[test] fn test_different_texts_different_embeddings() { let config = EmbeddingConfig::default(); let service = EmbeddingService::new(&config).unwrap(); let e1 = service.embed("Hello world").unwrap(); let e2 = service.embed("Goodbye moon").unwrap(); // Character-level tokenizer produces similar embeddings for similar text // Just verify they're not identical let diff: f32 = e1 .vector .iter() .zip(e2.vector.iter()) .map(|(a, b)| (a - b).abs()) .sum(); assert!( diff > 0.0, "Different texts should produce different embeddings" ); } #[test] fn test_tokenizer() { let tokenizer = Tokenizer::new(1000); let tokens = tokenizer.tokenize("Hello world"); assert!(!tokens.is_empty()); assert_eq!(tokens[0], 2); // BOS assert_eq!(*tokens.last().unwrap(), 3); // EOS } #[test] fn test_batch_embedding() { let config = EmbeddingConfig::default(); let service = EmbeddingService::new(&config).unwrap(); let texts = vec!["text one", "text two", "text three"]; let embeddings = service.embed_batch(&texts).unwrap(); assert_eq!(embeddings.len(), 3); for emb in &embeddings { assert_eq!(emb.vector.len(), config.dimension); } } #[test] fn test_pooling_strategies() { let config = EmbeddingConfig::default(); let service = EmbeddingService::new(&config).unwrap(); let text = "Test pooling strategies"; let mean = service .embed_with_pooling(text, PoolingStrategy::Mean) .unwrap(); let max = service .embed_with_pooling(text, PoolingStrategy::Max) .unwrap(); let cls = service .embed_with_pooling(text, PoolingStrategy::CLS) .unwrap(); let last = service .embed_with_pooling(text, PoolingStrategy::LastToken) .unwrap(); assert_eq!(mean.vector.len(), config.dimension); assert_eq!(max.vector.len(), config.dimension); assert_eq!(cls.vector.len(), config.dimension); assert_eq!(last.vector.len(), config.dimension); let mean_dot_max: f32 = mean .vector .iter() .zip(max.vector.iter()) .map(|(a, b)| a * b) .sum(); assert!(mean_dot_max < 0.999); } #[test] fn test_cache_stats() { let config = EmbeddingConfig::default(); let service = EmbeddingService::new(&config).unwrap(); service.embed("test 1").unwrap(); service.embed("test 2").unwrap(); service.embed("test 1").unwrap(); // Cache hit let stats = service.get_stats(); assert_eq!(stats.cache_hits, 1); assert_eq!(stats.cache_misses, 2); } #[test] fn test_truncation() { let mut config = EmbeddingConfig::default(); config.max_tokens = 10; let service = EmbeddingService::new(&config).unwrap(); let long_text = "This is a very long text that will definitely be truncated because it exceeds the maximum token limit"; let embedding = service.embed(long_text).unwrap(); assert!(embedding.truncated); } #[test] fn test_clear_cache() { let config = EmbeddingConfig::default(); let service = EmbeddingService::new(&config).unwrap(); service.embed("test").unwrap(); assert_eq!(service.get_stats().cache_size, 1); service.clear_cache(); assert_eq!(service.get_stats().cache_size, 0); } }