feat(onnx-embeddings-wasm): add model loader with HuggingFace support

Adds loader.js with:
- Pre-configured model URLs for 6 popular models
- ModelLoader class with caching and progress reporting
- createEmbedder() helper for quick setup
- embed() and similarity() one-liner helpers

Supported models:
- all-MiniLM-L6-v2 (default)
- all-MiniLM-L12-v2
- bge-small-en-v1.5
- bge-base-en-v1.5
- e5-small-v2
- gte-small

Usage:
```javascript
import { createEmbedder } from './loader.js';
const embedder = await createEmbedder('all-MiniLM-L6-v2');
const embedding = embedder.embedOne("Hello world");
```

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
rUv 2025-12-31 04:12:48 +00:00
parent 633aa2796d
commit 3c85487998
2 changed files with 490 additions and 0 deletions

View file

@ -0,0 +1,348 @@
/**
* Model Loader for RuVector ONNX Embeddings WASM
*
* Provides easy loading of pre-trained models from HuggingFace Hub
*/
/**
 * Build one registry entry for a model hosted on the HuggingFace Hub.
 *
 * The entry name is the part of `repo` after the first '/', and both download
 * URLs are derived from the repository's `resolve/main` path.
 */
const describeHubModel = (repo, dimension, maxLength, size, description) => {
  const name = repo.slice(repo.indexOf('/') + 1);
  const base = `https://huggingface.co/${repo}/resolve/main`;
  return {
    name,
    dimension,
    maxLength,
    size,
    description,
    model: `${base}/onnx/model.onnx`,
    tokenizer: `${base}/tokenizer.json`,
  };
};

/**
 * Pre-configured models with their HuggingFace URLs, keyed by model name.
 *
 * Each entry carries: name, dimension, maxLength, size, description,
 * model (URL of the ONNX file) and tokenizer (URL of tokenizer.json).
 */
export const MODELS = Object.fromEntries(
  [
    // Sentence Transformers - Small & Fast
    describeHubModel('sentence-transformers/all-MiniLM-L6-v2', 384, 256, '23MB', 'Fast, general-purpose embeddings'),
    describeHubModel('sentence-transformers/all-MiniLM-L12-v2', 384, 256, '33MB', 'Better quality, balanced speed'),
    // BGE Models - State of the art
    describeHubModel('BAAI/bge-small-en-v1.5', 384, 512, '33MB', 'State-of-the-art small model'),
    describeHubModel('BAAI/bge-base-en-v1.5', 768, 512, '110MB', 'Best overall quality'),
    // E5 Models - Microsoft
    describeHubModel('intfloat/e5-small-v2', 384, 512, '33MB', 'Excellent for search & retrieval'),
    // GTE Models - Alibaba
    describeHubModel('thenlper/gte-small', 384, 512, '33MB', 'Good multilingual support'),
  ].map((entry) => [entry.name, entry]),
);
/**
 * Default model for quick start.
 *
 * Used as the default value of the `modelName` parameter wherever this module
 * accepts one; it must be a key of MODELS, otherwise loading it is rejected
 * as an unknown model.
 */
export const DEFAULT_MODEL = 'all-MiniLM-L6-v2';
/**
 * Model loader with caching support.
 *
 * Downloads ONNX model bytes and tokenizer JSON, optionally storing them via
 * the Cache API when a global `caches` object exists, and can report download
 * progress through a callback.
 */
export class ModelLoader {
  /**
   * @param {object} [options]
   * @param {boolean} [options.cache=true] - Use the Cache API when it is available
   * @param {string} [options.cacheStorage='ruvector-models'] - Name passed to `caches.open`
   * @param {?function({loaded: number, total: number, percent: number}): void} [options.onProgress=null]
   *   Called after every received chunk while downloading
   */
  constructor(options = {}) {
    this.cache = options.cache ?? true;
    this.cacheStorage = options.cacheStorage ?? 'ruvector-models';
    this.onProgress = options.onProgress ?? null;
  }

  /**
   * Load a pre-configured model by name
   * @param {string} modelName - Model name from MODELS
   * @returns {Promise<{modelBytes: Uint8Array, tokenizerJson: string, config: object}>}
   * @throws {Error} If `modelName` is not a key of MODELS or a download fails
   */
  async loadModel(modelName = DEFAULT_MODEL) {
    // Own-property lookup: inherited keys such as 'constructor' or 'toString'
    // must be reported as unknown models instead of slipping past the guard.
    const modelConfig = Object.prototype.hasOwnProperty.call(MODELS, modelName)
      ? MODELS[modelName]
      : undefined;
    if (!modelConfig) {
      throw new Error(`Unknown model: ${modelName}. Available: ${Object.keys(MODELS).join(', ')}`);
    }
    console.log(`Loading model: ${modelConfig.name} (${modelConfig.size})`);
    const [modelBytes, tokenizerJson] = await Promise.all([
      this.fetchWithCache(modelConfig.model, `${modelName}-model.onnx`, 'arraybuffer'),
      this.fetchWithCache(modelConfig.tokenizer, `${modelName}-tokenizer.json`, 'text'),
    ]);
    return {
      modelBytes: new Uint8Array(modelBytes),
      tokenizerJson,
      config: modelConfig,
    };
  }

  /**
   * Load model from custom URLs.
   *
   * No cache key is supplied for these downloads, so they are never cached.
   * @param {string} modelUrl - URL to ONNX model
   * @param {string} tokenizerUrl - URL to tokenizer.json
   * @returns {Promise<{modelBytes: Uint8Array, tokenizerJson: string}>}
   */
  async loadFromUrls(modelUrl, tokenizerUrl) {
    const [modelBytes, tokenizerJson] = await Promise.all([
      this.fetchWithCache(modelUrl, null, 'arraybuffer'),
      this.fetchWithCache(tokenizerUrl, null, 'text'),
    ]);
    return {
      modelBytes: new Uint8Array(modelBytes),
      tokenizerJson,
    };
  }

  /**
   * Load model from local files (Node.js)
   * @param {string} modelPath - Path to ONNX model
   * @param {string} tokenizerPath - Path to tokenizer.json
   * @returns {Promise<{modelBytes: Uint8Array, tokenizerJson: string}>}
   * @throws {Error} When not running under Node.js
   */
  async loadFromFiles(modelPath, tokenizerPath) {
    // Node.js environment
    if (typeof process !== 'undefined' && process.versions?.node) {
      const fs = await import('fs/promises');
      const [modelBytes, tokenizerJson] = await Promise.all([
        fs.readFile(modelPath),
        fs.readFile(tokenizerPath, 'utf8'),
      ]);
      return {
        modelBytes: new Uint8Array(modelBytes),
        tokenizerJson,
      };
    }
    throw new Error('loadFromFiles is only available in Node.js');
  }

  /**
   * Fetch with optional caching (uses Cache API in browsers)
   *
   * Caching is used only when it is enabled on this loader, a global `caches`
   * object exists, and `cacheKey` is truthy. Cache failures are best-effort:
   * they are logged as warnings and the download proceeds normally.
   * @param {string} url - Resource to download
   * @param {?string} cacheKey - Key for the cache entry, or null to skip caching
   * @param {string} responseType - 'arraybuffer' for binary; anything else yields text
   * @returns {Promise<ArrayBuffer|string>}
   * @throws {Error} If the response status is not OK
   */
  async fetchWithCache(url, cacheKey, responseType) {
    const useCache = Boolean(this.cache && typeof caches !== 'undefined' && cacheKey);

    // Try cache first (browser only)
    if (useCache) {
      try {
        const cache = await caches.open(this.cacheStorage);
        const cached = await cache.match(cacheKey);
        if (cached) {
          console.log(` Cache hit: ${cacheKey}`);
          return responseType === 'arraybuffer'
            ? await cached.arrayBuffer()
            : await cached.text();
        }
      } catch (e) {
        // Cache API unusable here; fall through to the network.
        console.warn(` Cache read failed for ${cacheKey}: ${e?.message ?? e}`);
      }
    }

    // Fetch from network
    console.log(` Downloading: ${url}`);
    const response = await this.fetchWithProgress(url);
    if (!response.ok) {
      throw new Error(`Failed to fetch ${url}: ${response.status} ${response.statusText}`);
    }

    // Cache the response (browser only). The clone is created only when it
    // will actually be consumed: an unread clone forces the whole body to be
    // buffered a second time.
    if (useCache) {
      try {
        const cache = await caches.open(this.cacheStorage);
        await cache.put(cacheKey, response.clone());
      } catch (e) {
        // Cache write failed; the download itself is still usable.
        console.warn(` Cache write failed for ${cacheKey}: ${e?.message ?? e}`);
      }
    }

    return responseType === 'arraybuffer'
      ? await response.arrayBuffer()
      : await response.text();
  }

  /**
   * Fetch with progress reporting
   *
   * The response is returned untouched when there is no progress callback, no
   * readable body, a non-OK status, or no usable positive `content-length`.
   * Otherwise the body is read chunk by chunk, `onProgress` is called after
   * each chunk, and a new Response holding the complete body is returned.
   * @param {string} url - Resource to download
   * @returns {Promise<Response>}
   */
  async fetchWithProgress(url) {
    const response = await fetch(url);
    if (!this.onProgress || !response.body || !response.ok) {
      return response;
    }
    const total = Number.parseInt(response.headers.get('content-length') ?? '', 10);
    if (!Number.isFinite(total) || total <= 0) {
      // Without a meaningful total a percentage cannot be computed.
      return response;
    }

    let loaded = 0;
    const reader = response.body.getReader();
    const chunks = [];
    while (true) {
      const { done, value } = await reader.read();
      if (done) break;
      chunks.push(value);
      loaded += value.length;
      this.onProgress({
        loaded,
        total,
        // The received byte count can exceed the advertised length (for
        // example when the header describes a compressed transfer), so the
        // percentage is capped at 100.
        percent: Math.min(100, Math.round((loaded / total) * 100)),
      });
    }

    // Reassemble the chunks into one contiguous buffer.
    const body = new Uint8Array(loaded);
    let position = 0;
    for (const chunk of chunks) {
      body.set(chunk, position);
      position += chunk.length;
    }
    return new Response(body, {
      headers: response.headers,
      status: response.status,
      statusText: response.statusText,
    });
  }

  /**
   * Clear cached models by deleting this loader's cache storage.
   * Does nothing when no global `caches` object exists.
   * @returns {Promise<void>}
   */
  async clearCache() {
    if (typeof caches !== 'undefined') {
      await caches.delete(this.cacheStorage);
      console.log('Model cache cleared');
    }
  }

  /**
   * List available models
   * @returns {Array<object>} One object per MODELS entry: its key as `id` plus the entry's fields
   */
  static listModels() {
    return Object.entries(MODELS).map(([key, config]) => ({
      id: key,
      ...config,
    }));
  }
}
/**
 * Quick helper to create an embedder with a pre-configured model
 *
 * Loads the named model through a fresh ModelLoader and builds a WasmEmbedder
 * configured with the model's max length, normalization enabled and mean
 * pooling.
 *
 * @param {string} [modelName=DEFAULT_MODEL] - Model name from MODELS
 * @param {?object} [wasmModule=null] - Already-loaded WASM bindings; when
 *   omitted the bundled bindings are imported (and initialized if they expose
 *   an initializer)
 * @returns {Promise<object>} The embedder returned by `WasmEmbedder.withConfig`
 *
 * @example
 * ```javascript
 * import { createEmbedder } from './loader.js';
 *
 * const embedder = await createEmbedder('all-MiniLM-L6-v2');
 * const embedding = embedder.embedOne("Hello world");
 * ```
 */
export async function createEmbedder(modelName = DEFAULT_MODEL, wasmModule = null) {
  // Import WASM module if not provided (kept in a local so the parameter is
  // never reassigned).
  let wasm = wasmModule;
  if (!wasm) {
    wasm = await import('./pkg/ruvector_onnx_embeddings_wasm.js');
    // Not every build of the bindings exposes an initializer as its default
    // export (presumably this depends on the wasm-bindgen target - TODO
    // confirm), so it is invoked only when it is actually callable.
    if (typeof wasm.default === 'function') {
      await wasm.default();
    }
  }

  const loader = new ModelLoader();
  const { modelBytes, tokenizerJson, config } = await loader.loadModel(modelName);

  const embedderConfig = new wasm.WasmEmbedderConfig()
    .setMaxLength(config.maxLength)
    .setNormalize(true)
    .setPooling(0); // Mean pooling

  return wasm.WasmEmbedder.withConfig(modelBytes, tokenizerJson, embedderConfig);
}
/**
 * One-shot embedding helper: builds an embedder for `modelName`, embeds the
 * input and returns the result. An array input goes through `embedBatch`,
 * any other input through `embedOne`.
 *
 * @example
 * ```javascript
 * import { embed } from './loader.js';
 *
 * const embedding = await embed("Hello world");
 * const embeddings = await embed(["Hello", "World"]);
 * ```
 */
export async function embed(text, modelName = DEFAULT_MODEL) {
  const embedder = await createEmbedder(modelName);
  return Array.isArray(text)
    ? embedder.embedBatch(text)
    : embedder.embedOne(text);
}
/**
 * One-shot similarity helper: builds an embedder for `modelName` and returns
 * the result of its `similarity` method for the two texts.
 *
 * @example
 * ```javascript
 * import { similarity } from './loader.js';
 *
 * const score = await similarity("I love dogs", "I adore puppies");
 * console.log(score); // ~0.85
 * ```
 */
export async function similarity(text1, text2, modelName = DEFAULT_MODEL) {
  return (await createEmbedder(modelName)).similarity(text1, text2);
}
/**
 * Default export: the module's named exports grouped into a single object.
 */
export default {
MODELS,
DEFAULT_MODEL,
ModelLoader,
createEmbedder,
embed,
similarity,
};

View file

@ -0,0 +1,142 @@
#!/usr/bin/env node
/**
* Full end-to-end test with model download
*
* Downloads all-MiniLM-L6-v2 and runs embedding tests
*/
import { ModelLoader, MODELS, DEFAULT_MODEL } from './loader.js';
import {
WasmEmbedder,
WasmEmbedderConfig,
cosineSimilarity,
} from './pkg/ruvector_onnx_embeddings_wasm.js';
const banner = '='.repeat(60);

console.log('🧪 RuVector ONNX Embeddings WASM - Full E2E Test\n');
console.log(banner);

// Print every registered model, flagging the default one
console.log('\n📦 Available Models:');
for (const { id, dimension, size, description } of ModelLoader.listModels()) {
  const marker = id === DEFAULT_MODEL ? ' ⭐ DEFAULT' : '';
  console.log(`${id} (${dimension}d, ${size})${marker}`);
  console.log(` ${description}`);
}

console.log('\n' + banner);
console.log(`\n🔄 Loading model: ${DEFAULT_MODEL}...\n`);

// Rewrites a single status line on stdout after each received chunk
function reportProgress({ loaded, total, percent }) {
  const loadedMb = (loaded/1024/1024).toFixed(1);
  const totalMb = (total/1024/1024).toFixed(1);
  process.stdout.write(`\r Progress: ${percent}% (${loadedMb}MB / ${totalMb}MB)`);
}

// Loader with progress reporting; caching is disabled for testing
const loader = new ModelLoader({
  cache: false,
  onProgress: reportProgress,
});
// Minimal assertion helper: a failed check throws, which is caught below and
// makes the script exit with a non-zero status instead of reporting success.
const check = (condition, message) => {
  if (!condition) {
    throw new Error(`Check failed: ${message}`);
  }
};

try {
  // Download the default model and tokenizer
  const { modelBytes, tokenizerJson, config } = await loader.loadModel(DEFAULT_MODEL);
  console.log('\n');
  console.log(` ✅ Model loaded: ${config.name}`);
  console.log(` ✅ Model size: ${(modelBytes.length / 1024 / 1024).toFixed(2)} MB`);
  console.log(` ✅ Tokenizer size: ${(tokenizerJson.length / 1024).toFixed(2)} KB`);
  check(modelBytes.length > 0, 'model download is empty');
  check(tokenizerJson.length > 0, 'tokenizer download is empty');

  // Create embedder
  console.log('\n🔧 Creating embedder...');
  const embedderConfig = new WasmEmbedderConfig()
    .setMaxLength(config.maxLength)
    .setNormalize(true)
    .setPooling(0);
  const embedder = WasmEmbedder.withConfig(modelBytes, tokenizerJson, embedderConfig);
  console.log(` ✅ Embedder created`);
  console.log(` ✅ Dimension: ${embedder.dimension()}`);
  console.log(` ✅ Max length: ${embedder.maxLength()}`);

  // Test 1: Single embedding
  console.log('\n' + '='.repeat(60));
  console.log('\n📝 Test 1: Single Embedding');
  const text1 = "The quick brown fox jumps over the lazy dog.";
  console.log(` Input: "${text1}"`);
  const start1 = performance.now();
  const embedding1 = embedder.embedOne(text1);
  const time1 = performance.now() - start1;
  console.log(` ✅ Output dimension: ${embedding1.length}`);
  console.log(` ✅ First 5 values: [${Array.from(embedding1.slice(0, 5)).map(v => v.toFixed(4)).join(', ')}]`);
  console.log(` ✅ Time: ${time1.toFixed(2)}ms`);
  check(
    embedding1.length === embedder.dimension(),
    `single embedding has ${embedding1.length} values, expected ${embedder.dimension()}`,
  );

  // Test 2: Semantic similarity
  console.log('\n' + '='.repeat(60));
  console.log('\n📝 Test 2: Semantic Similarity');
  const pairs = [
    ["I love programming in Rust", "Rust is my favorite programming language"],
    ["The weather is nice today", "It's sunny outside"],
    ["I love programming in Rust", "The weather is nice today"],
    ["Machine learning is fascinating", "AI and deep learning are interesting"],
  ];
  for (const [a, b] of pairs) {
    const start = performance.now();
    const sim = embedder.similarity(a, b);
    const time = performance.now() - start;
    const label = sim > 0.5 ? '🟢 Similar' : '🔴 Different';
    console.log(`\n "${a.substring(0, 30)}..."`);
    console.log(` "${b.substring(0, 30)}..."`);
    console.log(` ${label}: ${sim.toFixed(4)} (${time.toFixed(1)}ms)`);
    check(Number.isFinite(sim), `similarity for "${a}" / "${b}" is not a finite number`);
  }

  // Test 3: Batch embedding
  console.log('\n' + '='.repeat(60));
  console.log('\n📝 Test 3: Batch Embedding');
  const texts = [
    "Artificial intelligence is transforming technology.",
    "Machine learning models learn from data.",
    "Deep learning uses neural networks.",
    "Vector databases enable semantic search.",
  ];
  console.log(` Embedding ${texts.length} texts...`);
  const start3 = performance.now();
  const batchEmbeddings = embedder.embedBatch(texts);
  const time3 = performance.now() - start3;
  // The batch result is treated as one flat array of texts.length * dimension values
  const embeddingDim = embedder.dimension();
  const numEmbeddings = batchEmbeddings.length / embeddingDim;
  console.log(` ✅ Total values: ${batchEmbeddings.length}`);
  console.log(` ✅ Embeddings: ${numEmbeddings} x ${embeddingDim}d`);
  console.log(` ✅ Time: ${time3.toFixed(2)}ms (${(time3/texts.length).toFixed(2)}ms per text)`);
  check(
    numEmbeddings === texts.length,
    `batch produced ${numEmbeddings} embeddings for ${texts.length} texts`,
  );

  // Compute pairwise similarities
  console.log('\n Pairwise similarities:');
  for (let i = 0; i < numEmbeddings; i++) {
    for (let j = i + 1; j < numEmbeddings; j++) {
      const emb_i = batchEmbeddings.slice(i * embeddingDim, (i + 1) * embeddingDim);
      const emb_j = batchEmbeddings.slice(j * embeddingDim, (j + 1) * embeddingDim);
      const sim = cosineSimilarity(emb_i, emb_j);
      console.log(` [${i}] vs [${j}]: ${sim.toFixed(4)}`);
      check(Number.isFinite(sim), `pairwise similarity [${i}] vs [${j}] is not a finite number`);
    }
  }

  // Summary
  console.log('\n' + '='.repeat(60));
  console.log('\n✅ All tests passed!');
  console.log('='.repeat(60));
  console.log('\n📊 Performance Summary:');
  console.log(` • Model: ${config.name}`);
  console.log(` • Dimension: ${embeddingDim}`);
  console.log(` • Single embed: ~${time1.toFixed(0)}ms`);
  console.log(` • Batch (${texts.length} texts): ~${time3.toFixed(0)}ms`);
  console.log(` • Throughput: ~${(1000 / (time3/texts.length)).toFixed(0)} texts/sec`);
} catch (error) {
  console.error('\n❌ Error:', error.message);
  console.error(error.stack);
  process.exit(1);
}