From 3c85487998f45a3d5ce87bb5ff289a12b7f8907e Mon Sep 17 00:00:00 2001 From: rUv Date: Wed, 31 Dec 2025 04:12:48 +0000 Subject: [PATCH] feat(onnx-embeddings-wasm): add model loader with HuggingFace support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds loader.js with: - Pre-configured model URLs for 6 popular models - ModelLoader class with caching and progress reporting - createEmbedder() helper for quick setup - embed() and similarity() one-liner helpers Supported models: - all-MiniLM-L6-v2 (default) - all-MiniLM-L12-v2 - bge-small-en-v1.5 - bge-base-en-v1.5 - e5-small-v2 - gte-small Usage: ```javascript import { createEmbedder } from './loader.js'; const embedder = await createEmbedder('all-MiniLM-L6-v2'); const embedding = embedder.embedOne("Hello world"); ``` ๐Ÿค– Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- examples/onnx-embeddings-wasm/loader.js | 348 ++++++++++++++++++++ examples/onnx-embeddings-wasm/test-full.mjs | 142 ++++++++ 2 files changed, 490 insertions(+) create mode 100644 examples/onnx-embeddings-wasm/loader.js create mode 100644 examples/onnx-embeddings-wasm/test-full.mjs diff --git a/examples/onnx-embeddings-wasm/loader.js b/examples/onnx-embeddings-wasm/loader.js new file mode 100644 index 000000000..b20b167fd --- /dev/null +++ b/examples/onnx-embeddings-wasm/loader.js @@ -0,0 +1,348 @@ +/** + * Model Loader for RuVector ONNX Embeddings WASM + * + * Provides easy loading of pre-trained models from HuggingFace Hub + */ + +/** + * Pre-configured models with their HuggingFace URLs + */ +export const MODELS = { + // Sentence Transformers - Small & Fast + 'all-MiniLM-L6-v2': { + name: 'all-MiniLM-L6-v2', + dimension: 384, + maxLength: 256, + size: '23MB', + description: 'Fast, general-purpose embeddings', + model: 'https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/model.onnx', + tokenizer: 'https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/tokenizer.json', + }, + 'all-MiniLM-L12-v2': { + name: 'all-MiniLM-L12-v2', + dimension: 384, + maxLength: 256, + size: '33MB', + description: 'Better quality, balanced speed', + model: 'https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/onnx/model.onnx', + tokenizer: 'https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/tokenizer.json', + }, + + // BGE Models - State of the art + 'bge-small-en-v1.5': { + name: 'bge-small-en-v1.5', + dimension: 384, + maxLength: 512, + size: '33MB', + description: 'State-of-the-art small model', + model: 'https://huggingface.co/BAAI/bge-small-en-v1.5/resolve/main/onnx/model.onnx', + tokenizer: 'https://huggingface.co/BAAI/bge-small-en-v1.5/resolve/main/tokenizer.json', + }, + 'bge-base-en-v1.5': { + name: 'bge-base-en-v1.5', + dimension: 768, + maxLength: 512, + size: '110MB', + description: 'Best overall quality', + model: 'https://huggingface.co/BAAI/bge-base-en-v1.5/resolve/main/onnx/model.onnx', + tokenizer: 'https://huggingface.co/BAAI/bge-base-en-v1.5/resolve/main/tokenizer.json', + }, + + // E5 Models - Microsoft + 'e5-small-v2': { + name: 'e5-small-v2', + dimension: 384, + maxLength: 512, + size: '33MB', + description: 'Excellent for search & retrieval', + model: 'https://huggingface.co/intfloat/e5-small-v2/resolve/main/onnx/model.onnx', + tokenizer: 'https://huggingface.co/intfloat/e5-small-v2/resolve/main/tokenizer.json', + }, + + // GTE Models - Alibaba + 'gte-small': { + name: 'gte-small', + dimension: 384, + maxLength: 512, + size: '33MB', + description: 'Good multilingual support', + model: 'https://huggingface.co/thenlper/gte-small/resolve/main/onnx/model.onnx', + tokenizer: 'https://huggingface.co/thenlper/gte-small/resolve/main/tokenizer.json', + }, +}; + +/** + * Default model for quick start + */ +export const DEFAULT_MODEL = 'all-MiniLM-L6-v2'; + +/** + * Model loader with caching support + */ +export class ModelLoader { + constructor(options = {}) { + this.cache = options.cache ?? true; + this.cacheStorage = options.cacheStorage ?? 'ruvector-models'; + this.onProgress = options.onProgress ?? null; + } + + /** + * Load a pre-configured model by name + * @param {string} modelName - Model name from MODELS + * @returns {Promise<{modelBytes: Uint8Array, tokenizerJson: string, config: object}>} + */ + async loadModel(modelName = DEFAULT_MODEL) { + const modelConfig = MODELS[modelName]; + if (!modelConfig) { + throw new Error(`Unknown model: ${modelName}. Available: ${Object.keys(MODELS).join(', ')}`); + } + + console.log(`Loading model: ${modelConfig.name} (${modelConfig.size})`); + + const [modelBytes, tokenizerJson] = await Promise.all([ + this.fetchWithCache(modelConfig.model, `${modelName}-model.onnx`, 'arraybuffer'), + this.fetchWithCache(modelConfig.tokenizer, `${modelName}-tokenizer.json`, 'text'), + ]); + + return { + modelBytes: new Uint8Array(modelBytes), + tokenizerJson, + config: modelConfig, + }; + } + + /** + * Load model from custom URLs + * @param {string} modelUrl - URL to ONNX model + * @param {string} tokenizerUrl - URL to tokenizer.json + * @returns {Promise<{modelBytes: Uint8Array, tokenizerJson: string}>} + */ + async loadFromUrls(modelUrl, tokenizerUrl) { + const [modelBytes, tokenizerJson] = await Promise.all([ + this.fetchWithCache(modelUrl, null, 'arraybuffer'), + this.fetchWithCache(tokenizerUrl, null, 'text'), + ]); + + return { + modelBytes: new Uint8Array(modelBytes), + tokenizerJson, + }; + } + + /** + * Load model from local files (Node.js) + * @param {string} modelPath - Path to ONNX model + * @param {string} tokenizerPath - Path to tokenizer.json + * @returns {Promise<{modelBytes: Uint8Array, tokenizerJson: string}>} + */ + async loadFromFiles(modelPath, tokenizerPath) { + // Node.js environment + if (typeof process !== 'undefined' && process.versions?.node) { + const fs = await import('fs/promises'); + const [modelBytes, tokenizerJson] = await Promise.all([ + fs.readFile(modelPath), + fs.readFile(tokenizerPath, 'utf8'), + ]); + return { + modelBytes: new Uint8Array(modelBytes), + tokenizerJson, + }; + } + throw new Error('loadFromFiles is only available in Node.js'); + } + + /** + * Fetch with optional caching (uses Cache API in browsers) + */ + async fetchWithCache(url, cacheKey, responseType) { + // Try cache first (browser only) + if (this.cache && typeof caches !== 'undefined' && cacheKey) { + try { + const cache = await caches.open(this.cacheStorage); + const cached = await cache.match(cacheKey); + if (cached) { + console.log(` Cache hit: ${cacheKey}`); + return responseType === 'arraybuffer' + ? await cached.arrayBuffer() + : await cached.text(); + } + } catch (e) { + // Cache API not available, continue with fetch + } + } + + // Fetch from network + console.log(` Downloading: ${url}`); + const response = await this.fetchWithProgress(url); + + if (!response.ok) { + throw new Error(`Failed to fetch ${url}: ${response.status} ${response.statusText}`); + } + + // Clone for caching + const responseClone = response.clone(); + + // Cache the response (browser only) + if (this.cache && typeof caches !== 'undefined' && cacheKey) { + try { + const cache = await caches.open(this.cacheStorage); + await cache.put(cacheKey, responseClone); + } catch (e) { + // Cache write failed, continue + } + } + + return responseType === 'arraybuffer' + ? await response.arrayBuffer() + : await response.text(); + } + + /** + * Fetch with progress reporting + */ + async fetchWithProgress(url) { + const response = await fetch(url); + + if (!this.onProgress || !response.body) { + return response; + } + + const contentLength = response.headers.get('content-length'); + if (!contentLength) { + return response; + } + + const total = parseInt(contentLength, 10); + let loaded = 0; + + const reader = response.body.getReader(); + const chunks = []; + + while (true) { + const { done, value } = await reader.read(); + if (done) break; + + chunks.push(value); + loaded += value.length; + + this.onProgress({ + loaded, + total, + percent: Math.round((loaded / total) * 100), + }); + } + + const body = new Uint8Array(loaded); + let position = 0; + for (const chunk of chunks) { + body.set(chunk, position); + position += chunk.length; + } + + return new Response(body, { + headers: response.headers, + status: response.status, + statusText: response.statusText, + }); + } + + /** + * Clear cached models + */ + async clearCache() { + if (typeof caches !== 'undefined') { + await caches.delete(this.cacheStorage); + console.log('Model cache cleared'); + } + } + + /** + * List available models + */ + static listModels() { + return Object.entries(MODELS).map(([key, config]) => ({ + id: key, + ...config, + })); + } +} + +/** + * Quick helper to create an embedder with a pre-configured model + * + * @example + * ```javascript + * import { createEmbedder } from './loader.js'; + * + * const embedder = await createEmbedder('all-MiniLM-L6-v2'); + * const embedding = embedder.embedOne("Hello world"); + * ``` + */ +export async function createEmbedder(modelName = DEFAULT_MODEL, wasmModule = null) { + // Import WASM module if not provided + if (!wasmModule) { + wasmModule = await import('./pkg/ruvector_onnx_embeddings_wasm.js'); + await wasmModule.default(); + } + + const loader = new ModelLoader(); + const { modelBytes, tokenizerJson, config } = await loader.loadModel(modelName); + + const embedderConfig = new wasmModule.WasmEmbedderConfig() + .setMaxLength(config.maxLength) + .setNormalize(true) + .setPooling(0); // Mean pooling + + const embedder = wasmModule.WasmEmbedder.withConfig( + modelBytes, + tokenizerJson, + embedderConfig + ); + + return embedder; +} + +/** + * Quick helper for one-off embedding (loads model, embeds, returns) + * + * @example + * ```javascript + * import { embed } from './loader.js'; + * + * const embedding = await embed("Hello world"); + * const embeddings = await embed(["Hello", "World"]); + * ``` + */ +export async function embed(text, modelName = DEFAULT_MODEL) { + const embedder = await createEmbedder(modelName); + + if (Array.isArray(text)) { + return embedder.embedBatch(text); + } + return embedder.embedOne(text); +} + +/** + * Quick helper for similarity comparison + * + * @example + * ```javascript + * import { similarity } from './loader.js'; + * + * const score = await similarity("I love dogs", "I adore puppies"); + * console.log(score); // ~0.85 + * ``` + */ +export async function similarity(text1, text2, modelName = DEFAULT_MODEL) { + const embedder = await createEmbedder(modelName); + return embedder.similarity(text1, text2); +} + +export default { + MODELS, + DEFAULT_MODEL, + ModelLoader, + createEmbedder, + embed, + similarity, +}; diff --git a/examples/onnx-embeddings-wasm/test-full.mjs b/examples/onnx-embeddings-wasm/test-full.mjs new file mode 100644 index 000000000..b43ad6cbf --- /dev/null +++ b/examples/onnx-embeddings-wasm/test-full.mjs @@ -0,0 +1,142 @@ +#!/usr/bin/env node +/** + * Full end-to-end test with model download + * + * Downloads all-MiniLM-L6-v2 and runs embedding tests + */ + +import { ModelLoader, MODELS, DEFAULT_MODEL } from './loader.js'; +import { + WasmEmbedder, + WasmEmbedderConfig, + cosineSimilarity, +} from './pkg/ruvector_onnx_embeddings_wasm.js'; + +console.log('๐Ÿงช RuVector ONNX Embeddings WASM - Full E2E Test\n'); +console.log('='.repeat(60)); + +// List available models +console.log('\n๐Ÿ“ฆ Available Models:'); +ModelLoader.listModels().forEach(m => { + const isDefault = m.id === DEFAULT_MODEL ? ' โญ DEFAULT' : ''; + console.log(` โ€ข ${m.id} (${m.dimension}d, ${m.size})${isDefault}`); + console.log(` ${m.description}`); +}); + +console.log('\n' + '='.repeat(60)); +console.log(`\n๐Ÿ”„ Loading model: ${DEFAULT_MODEL}...\n`); + +// Load model with progress +const loader = new ModelLoader({ + cache: false, // Disable cache for testing + onProgress: ({ loaded, total, percent }) => { + process.stdout.write(`\r Progress: ${percent}% (${(loaded/1024/1024).toFixed(1)}MB / ${(total/1024/1024).toFixed(1)}MB)`); + } +}); + +try { + const { modelBytes, tokenizerJson, config } = await loader.loadModel(DEFAULT_MODEL); + console.log('\n'); + console.log(` โœ… Model loaded: ${config.name}`); + console.log(` โœ… Model size: ${(modelBytes.length / 1024 / 1024).toFixed(2)} MB`); + console.log(` โœ… Tokenizer size: ${(tokenizerJson.length / 1024).toFixed(2)} KB`); + + // Create embedder + console.log('\n๐Ÿ”ง Creating embedder...'); + const embedderConfig = new WasmEmbedderConfig() + .setMaxLength(config.maxLength) + .setNormalize(true) + .setPooling(0); + + const embedder = WasmEmbedder.withConfig(modelBytes, tokenizerJson, embedderConfig); + console.log(` โœ… Embedder created`); + console.log(` โœ… Dimension: ${embedder.dimension()}`); + console.log(` โœ… Max length: ${embedder.maxLength()}`); + + // Test 1: Single embedding + console.log('\n' + '='.repeat(60)); + console.log('\n๐Ÿ“ Test 1: Single Embedding'); + const text1 = "The quick brown fox jumps over the lazy dog."; + console.log(` Input: "${text1}"`); + + const start1 = performance.now(); + const embedding1 = embedder.embedOne(text1); + const time1 = performance.now() - start1; + + console.log(` โœ… Output dimension: ${embedding1.length}`); + console.log(` โœ… First 5 values: [${Array.from(embedding1.slice(0, 5)).map(v => v.toFixed(4)).join(', ')}]`); + console.log(` โœ… Time: ${time1.toFixed(2)}ms`); + + // Test 2: Semantic similarity + console.log('\n' + '='.repeat(60)); + console.log('\n๐Ÿ“ Test 2: Semantic Similarity'); + + const pairs = [ + ["I love programming in Rust", "Rust is my favorite programming language"], + ["The weather is nice today", "It's sunny outside"], + ["I love programming in Rust", "The weather is nice today"], + ["Machine learning is fascinating", "AI and deep learning are interesting"], + ]; + + for (const [a, b] of pairs) { + const start = performance.now(); + const sim = embedder.similarity(a, b); + const time = performance.now() - start; + + const label = sim > 0.5 ? '๐ŸŸข Similar' : '๐Ÿ”ด Different'; + console.log(`\n "${a.substring(0, 30)}..."`); + console.log(` "${b.substring(0, 30)}..."`); + console.log(` ${label}: ${sim.toFixed(4)} (${time.toFixed(1)}ms)`); + } + + // Test 3: Batch embedding + console.log('\n' + '='.repeat(60)); + console.log('\n๐Ÿ“ Test 3: Batch Embedding'); + + const texts = [ + "Artificial intelligence is transforming technology.", + "Machine learning models learn from data.", + "Deep learning uses neural networks.", + "Vector databases enable semantic search.", + ]; + + console.log(` Embedding ${texts.length} texts...`); + const start3 = performance.now(); + const batchEmbeddings = embedder.embedBatch(texts); + const time3 = performance.now() - start3; + + const embeddingDim = embedder.dimension(); + const numEmbeddings = batchEmbeddings.length / embeddingDim; + + console.log(` โœ… Total values: ${batchEmbeddings.length}`); + console.log(` โœ… Embeddings: ${numEmbeddings} x ${embeddingDim}d`); + console.log(` โœ… Time: ${time3.toFixed(2)}ms (${(time3/texts.length).toFixed(2)}ms per text)`); + + // Compute pairwise similarities + console.log('\n Pairwise similarities:'); + for (let i = 0; i < numEmbeddings; i++) { + for (let j = i + 1; j < numEmbeddings; j++) { + const emb_i = batchEmbeddings.slice(i * embeddingDim, (i + 1) * embeddingDim); + const emb_j = batchEmbeddings.slice(j * embeddingDim, (j + 1) * embeddingDim); + const sim = cosineSimilarity(emb_i, emb_j); + console.log(` [${i}] vs [${j}]: ${sim.toFixed(4)}`); + } + } + + // Summary + console.log('\n' + '='.repeat(60)); + console.log('\nโœ… All tests passed!'); + console.log('='.repeat(60)); + + console.log('\n๐Ÿ“Š Performance Summary:'); + console.log(` โ€ข Model: ${config.name}`); + console.log(` โ€ข Dimension: ${embeddingDim}`); + console.log(` โ€ข Single embed: ~${time1.toFixed(0)}ms`); + console.log(` โ€ข Batch (4 texts): ~${time3.toFixed(0)}ms`); + console.log(` โ€ข Throughput: ~${(1000 / (time3/texts.length)).toFixed(0)} texts/sec`); + +} catch (error) { + console.error('\nโŒ Error:', error.message); + console.error(error.stack); + process.exit(1); +}