ruvector/crates/ruvector-core/src/embeddings.rs
rUv d316a52d42 fix(ci): Fix formatting and workflow permission issues
- Run cargo fmt across all crates (468 files formatted)
- Add permissions for PR comments in benchmarks.yml
- Add continue-on-error for PR comment steps
- Remove Docker service from postgres-extension-ci (pgrx manages own postgres)
- Add permissions to postgres-extension-ci.yml

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-26 22:11:57 +00:00

413 lines
13 KiB
Rust

//! Text Embedding Providers
//!
//! This module provides a pluggable embedding system for AgenticDB.
//!
//! ## Available Providers
//!
//! - **HashEmbedding**: Fast hash-based placeholder (default, not semantic)
//! - **CandleEmbedding**: Real embeddings using candle-transformers (feature: `real-embeddings`)
//! - **ApiEmbedding**: External API calls (OpenAI, Anthropic, Cohere, etc.)
//!
//! ## Usage
//!
//! ```rust,no_run
//! use ruvector_core::embeddings::{EmbeddingProvider, HashEmbedding, ApiEmbedding};
//! use ruvector_core::AgenticDB;
//!
//! // Default: Hash-based (fast, but not semantic)
//! let hash_provider = HashEmbedding::new(384);
//! let embedding = hash_provider.embed("hello world")?;
//!
//! // API-based (requires API key)
//! let api_provider = ApiEmbedding::openai("sk-...", "text-embedding-3-small");
//! let embedding = api_provider.embed("hello world")?;
//! # Ok::<(), Box<dyn std::error::Error>>(())
//! ```
use crate::error::{Result, RuvectorError};
use std::sync::Arc;
/// Trait for text embedding providers
pub trait EmbeddingProvider: Send + Sync {
/// Generate embedding vector for the given text
fn embed(&self, text: &str) -> Result<Vec<f32>>;
/// Get the dimensionality of embeddings produced by this provider
fn dimensions(&self) -> usize;
/// Get a description of this provider (for logging/debugging)
fn name(&self) -> &str;
}
/// Hash-based embedding provider (placeholder, not semantic)
///
/// ⚠️ **WARNING**: This does NOT produce semantic embeddings!
/// - "dog" and "cat" will NOT be similar
/// - "dog" and "god" WILL be similar (same characters)
///
/// Use this only for:
/// - Testing
/// - Prototyping
/// - When semantic similarity is not required
#[derive(Debug, Clone)]
pub struct HashEmbedding {
dimensions: usize,
}
impl HashEmbedding {
/// Create a new hash-based embedding provider
pub fn new(dimensions: usize) -> Self {
Self { dimensions }
}
}
impl EmbeddingProvider for HashEmbedding {
fn embed(&self, text: &str) -> Result<Vec<f32>> {
let mut embedding = vec![0.0; self.dimensions];
let bytes = text.as_bytes();
for (i, byte) in bytes.iter().enumerate() {
embedding[i % self.dimensions] += (*byte as f32) / 255.0;
}
// Normalize
let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
for val in &mut embedding {
*val /= norm;
}
}
Ok(embedding)
}
fn dimensions(&self) -> usize {
self.dimensions
}
fn name(&self) -> &str {
"HashEmbedding (placeholder)"
}
}
/// Real embeddings using candle-transformers
///
/// Requires feature flag: `real-embeddings`
///
/// ⚠️ **Note**: Full candle integration is complex and model-specific.
/// For production use, we recommend:
/// 1. Using the API-based providers (simpler, always up-to-date)
/// 2. Using ONNX Runtime with pre-exported models
/// 3. Implementing your own candle wrapper for your specific model
///
/// This is a stub implementation showing the structure.
/// Users should implement `EmbeddingProvider` trait for their specific models.
#[cfg(feature = "real-embeddings")]
pub mod candle {
use super::*;
/// Candle-based embedding provider stub
///
/// This is a placeholder. For real implementation:
/// 1. Add candle dependencies for your specific model type
/// 2. Implement model loading and inference
/// 3. Handle tokenization appropriately
///
/// Example structure:
/// ```rust,ignore
/// pub struct CandleEmbedding {
/// model: YourModelType,
/// tokenizer: Tokenizer,
/// device: Device,
/// dimensions: usize,
/// }
/// ```
pub struct CandleEmbedding {
dimensions: usize,
model_id: String,
}
impl CandleEmbedding {
/// Create a stub candle embedding provider
///
/// **This is not a real implementation!**
/// For production, implement with actual model loading.
///
/// # Example
/// ```rust,no_run
/// # #[cfg(feature = "real-embeddings")]
/// # {
/// use ruvector_core::embeddings::candle::CandleEmbedding;
///
/// // This returns an error - real implementation required
/// let result = CandleEmbedding::from_pretrained(
/// "sentence-transformers/all-MiniLM-L6-v2",
/// false
/// );
/// assert!(result.is_err());
/// # }
/// ```
pub fn from_pretrained(model_id: &str, _use_gpu: bool) -> Result<Self> {
Err(RuvectorError::ModelLoadError(format!(
"Candle embedding support is a stub. Please:\n\
1. Use ApiEmbedding for production (recommended)\n\
2. Or implement CandleEmbedding for model: {}\n\
3. See docs for ONNX Runtime integration examples",
model_id
)))
}
}
impl EmbeddingProvider for CandleEmbedding {
fn embed(&self, _text: &str) -> Result<Vec<f32>> {
Err(RuvectorError::ModelInferenceError(
"Candle embedding not implemented - use ApiEmbedding instead".to_string(),
))
}
fn dimensions(&self) -> usize {
self.dimensions
}
fn name(&self) -> &str {
"CandleEmbedding (stub - not implemented)"
}
}
}
#[cfg(feature = "real-embeddings")]
pub use candle::CandleEmbedding;
/// API-based embedding provider (OpenAI, Anthropic, Cohere, etc.)
///
/// Supports any API that accepts JSON and returns embeddings in a standard format.
///
/// # Example (OpenAI)
/// ```rust,no_run
/// use ruvector_core::embeddings::{EmbeddingProvider, ApiEmbedding};
///
/// let provider = ApiEmbedding::openai("sk-...", "text-embedding-3-small");
/// let embedding = provider.embed("hello world")?;
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```
#[cfg(feature = "api-embeddings")]
#[derive(Clone)]
pub struct ApiEmbedding {
api_key: String,
endpoint: String,
model: String,
dimensions: usize,
client: reqwest::blocking::Client,
}
#[cfg(feature = "api-embeddings")]
impl ApiEmbedding {
/// Create a new API embedding provider
///
/// # Arguments
/// * `api_key` - API key for authentication
/// * `endpoint` - API endpoint URL
/// * `model` - Model identifier
/// * `dimensions` - Expected embedding dimensions
pub fn new(api_key: String, endpoint: String, model: String, dimensions: usize) -> Self {
Self {
api_key,
endpoint,
model,
dimensions,
client: reqwest::blocking::Client::new(),
}
}
/// Create OpenAI embedding provider
///
/// # Models
/// - `text-embedding-3-small` - 1536 dimensions, $0.02/1M tokens
/// - `text-embedding-3-large` - 3072 dimensions, $0.13/1M tokens
/// - `text-embedding-ada-002` - 1536 dimensions (legacy)
pub fn openai(api_key: &str, model: &str) -> Self {
let dimensions = match model {
"text-embedding-3-large" => 3072,
_ => 1536, // text-embedding-3-small and ada-002
};
Self::new(
api_key.to_string(),
"https://api.openai.com/v1/embeddings".to_string(),
model.to_string(),
dimensions,
)
}
/// Create Cohere embedding provider
///
/// # Models
/// - `embed-english-v3.0` - 1024 dimensions
/// - `embed-multilingual-v3.0` - 1024 dimensions
pub fn cohere(api_key: &str, model: &str) -> Self {
Self::new(
api_key.to_string(),
"https://api.cohere.ai/v1/embed".to_string(),
model.to_string(),
1024,
)
}
/// Create Voyage AI embedding provider
///
/// # Models
/// - `voyage-2` - 1024 dimensions
/// - `voyage-large-2` - 1536 dimensions
pub fn voyage(api_key: &str, model: &str) -> Self {
let dimensions = if model.contains("large") { 1536 } else { 1024 };
Self::new(
api_key.to_string(),
"https://api.voyageai.com/v1/embeddings".to_string(),
model.to_string(),
dimensions,
)
}
}
#[cfg(feature = "api-embeddings")]
impl EmbeddingProvider for ApiEmbedding {
fn embed(&self, text: &str) -> Result<Vec<f32>> {
let request_body = serde_json::json!({
"input": text,
"model": self.model,
});
let response = self
.client
.post(&self.endpoint)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.json(&request_body)
.send()
.map_err(|e| {
RuvectorError::ModelInferenceError(format!("API request failed: {}", e))
})?;
if !response.status().is_success() {
let status = response.status();
let error_text = response
.text()
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(RuvectorError::ModelInferenceError(format!(
"API returned error {}: {}",
status, error_text
)));
}
let response_json: serde_json::Value = response.json().map_err(|e| {
RuvectorError::ModelInferenceError(format!("Failed to parse response: {}", e))
})?;
// Handle different API response formats
let embedding = if let Some(data) = response_json.get("data") {
// OpenAI format: {"data": [{"embedding": [...]}]}
data.as_array()
.and_then(|arr| arr.first())
.and_then(|obj| obj.get("embedding"))
.and_then(|emb| emb.as_array())
.ok_or_else(|| {
RuvectorError::ModelInferenceError("Invalid OpenAI response format".to_string())
})?
} else if let Some(embeddings) = response_json.get("embeddings") {
// Cohere format: {"embeddings": [[...]]}
embeddings
.as_array()
.and_then(|arr| arr.first())
.and_then(|emb| emb.as_array())
.ok_or_else(|| {
RuvectorError::ModelInferenceError("Invalid Cohere response format".to_string())
})?
} else {
return Err(RuvectorError::ModelInferenceError(
"Unknown API response format".to_string(),
));
};
let embedding_vec: Result<Vec<f32>> = embedding
.iter()
.map(|v| {
v.as_f64().map(|f| f as f32).ok_or_else(|| {
RuvectorError::ModelInferenceError("Invalid embedding value".to_string())
})
})
.collect();
embedding_vec
}
fn dimensions(&self) -> usize {
self.dimensions
}
fn name(&self) -> &str {
"ApiEmbedding"
}
}
/// Type-erased embedding provider for dynamic dispatch
pub type BoxedEmbeddingProvider = Arc<dyn EmbeddingProvider>;
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_hash_embedding() {
let provider = HashEmbedding::new(128);
let emb1 = provider.embed("hello world").unwrap();
let emb2 = provider.embed("hello world").unwrap();
assert_eq!(emb1.len(), 128);
assert_eq!(emb1, emb2, "Same text should produce same embedding");
// Check normalization
let norm: f32 = emb1.iter().map(|x| x * x).sum::<f32>().sqrt();
assert!((norm - 1.0).abs() < 1e-5, "Embedding should be normalized");
}
#[test]
fn test_hash_embedding_different_text() {
let provider = HashEmbedding::new(128);
let emb1 = provider.embed("hello").unwrap();
let emb2 = provider.embed("world").unwrap();
assert_ne!(
emb1, emb2,
"Different text should produce different embeddings"
);
}
#[cfg(feature = "real-embeddings")]
#[test]
#[ignore] // Requires model download
fn test_candle_embedding() {
let provider =
CandleEmbedding::from_pretrained("sentence-transformers/all-MiniLM-L6-v2", false)
.unwrap();
let embedding = provider.embed("hello world").unwrap();
assert_eq!(embedding.len(), 384);
// Check normalization
let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
assert!((norm - 1.0).abs() < 1e-5, "Embedding should be normalized");
}
#[test]
#[ignore] // Requires API key
fn test_api_embedding_openai() {
let api_key = std::env::var("OPENAI_API_KEY").unwrap();
let provider = ApiEmbedding::openai(&api_key, "text-embedding-3-small");
let embedding = provider.embed("hello world").unwrap();
assert_eq!(embedding.len(), 1536);
}
}