diff --git a/npm/packages/rudag/.gitignore b/npm/packages/rudag/.gitignore new file mode 100644 index 000000000..59a08e350 --- /dev/null +++ b/npm/packages/rudag/.gitignore @@ -0,0 +1,26 @@ +# Dependencies +node_modules/ + +# Build artifacts (generated by npm run build) +# Note: dist/, pkg/, pkg-node/ are kept for publishing + +# TypeScript cache +*.tsbuildinfo + +# npm +npm-debug.log* +.npm + +# Test coverage +coverage/ +.nyc_output/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo + +# OS +.DS_Store +Thumbs.db diff --git a/npm/packages/rudag/bin/cli.js b/npm/packages/rudag/bin/cli.js new file mode 100644 index 000000000..c40fd13b8 --- /dev/null +++ b/npm/packages/rudag/bin/cli.js @@ -0,0 +1,272 @@ +#!/usr/bin/env node + +/** + * rudag CLI - Command-line interface for DAG operations + */ + +const { RuDag, DagOperator, AttentionMechanism, MemoryStorage } = require('../dist/index.js'); +const fs = require('fs'); +const path = require('path'); + +const args = process.argv.slice(2); +const command = args[0]; + +const help = ` +rudag - Self-learning DAG query optimization CLI + +Usage: rudag [options] + +Commands: + create Create a new DAG + load Load DAG from file + info Show DAG information + topo Print topological sort + critical Find critical path + attention [type] Compute attention scores (type: topo|critical|uniform) + convert Convert between JSON and binary formats + help Show this help message + +Examples: + rudag create my-query > my-query.dag + rudag info my-query.dag + rudag critical my-query.dag + rudag attention my-query.dag critical + +Options: + --json Output in JSON format + --verbose Verbose output +`; + +async function main() { + if (!command || command === 'help' || command === '--help') { + console.log(help); + process.exit(0); + } + + const isJson = args.includes('--json'); + const verbose = args.includes('--verbose'); + + try { + switch (command) { + case 'create': { + const name = args[1] || 'untitled'; + const dag = new RuDag({ name, storage: null, autoSave: false }); + await dag.init(); + + // Create a simple example DAG + const scan = dag.addNode(DagOperator.SCAN, 10.0); + const filter = dag.addNode(DagOperator.FILTER, 2.0); + const project = dag.addNode(DagOperator.PROJECT, 1.0); + + dag.addEdge(scan, filter); + dag.addEdge(filter, project); + + if (isJson) { + console.log(dag.toJSON()); + } else { + const bytes = dag.toBytes(); + process.stdout.write(Buffer.from(bytes)); + } + break; + } + + case 'load': { + const file = args[1]; + if (!file) { + console.error('Error: No file specified'); + process.exit(1); + } + + const data = fs.readFileSync(file); + let dag; + + if (file.endsWith('.json')) { + dag = await RuDag.fromJSON(data.toString(), { storage: null }); + } else { + dag = await RuDag.fromBytes(new Uint8Array(data), { storage: null }); + } + + console.log(`Loaded DAG with ${dag.nodeCount} nodes and ${dag.edgeCount} edges`); + break; + } + + case 'info': { + const file = args[1]; + if (!file) { + console.error('Error: No file specified'); + process.exit(1); + } + + const data = fs.readFileSync(file); + let dag; + + if (file.endsWith('.json')) { + dag = await RuDag.fromJSON(data.toString(), { storage: null }); + } else { + dag = await RuDag.fromBytes(new Uint8Array(data), { storage: null }); + } + + const info = { + file, + nodes: dag.nodeCount, + edges: dag.edgeCount, + criticalPath: dag.criticalPath(), + }; + + if (isJson) { + console.log(JSON.stringify(info, null, 2)); + } else { + console.log(`File: ${info.file}`); + console.log(`Nodes: ${info.nodes}`); + console.log(`Edges: ${info.edges}`); + console.log(`Critical Path: ${info.criticalPath.path.join(' -> ')}`); + console.log(`Total Cost: ${info.criticalPath.cost}`); + } + break; + } + + case 'topo': { + const file = args[1]; + if (!file) { + console.error('Error: No file specified'); + process.exit(1); + } + + const data = fs.readFileSync(file); + let dag; + + if (file.endsWith('.json')) { + dag = await RuDag.fromJSON(data.toString(), { storage: null }); + } else { + dag = await RuDag.fromBytes(new Uint8Array(data), { storage: null }); + } + + const topo = dag.topoSort(); + + if (isJson) { + console.log(JSON.stringify(topo)); + } else { + console.log('Topological order:', topo.join(' -> ')); + } + break; + } + + case 'critical': { + const file = args[1]; + if (!file) { + console.error('Error: No file specified'); + process.exit(1); + } + + const data = fs.readFileSync(file); + let dag; + + if (file.endsWith('.json')) { + dag = await RuDag.fromJSON(data.toString(), { storage: null }); + } else { + dag = await RuDag.fromBytes(new Uint8Array(data), { storage: null }); + } + + const result = dag.criticalPath(); + + if (isJson) { + console.log(JSON.stringify(result)); + } else { + console.log('Critical Path:', result.path.join(' -> ')); + console.log('Total Cost:', result.cost); + } + break; + } + + case 'attention': { + const file = args[1]; + const type = args[2] || 'critical'; + + if (!file) { + console.error('Error: No file specified'); + process.exit(1); + } + + const data = fs.readFileSync(file); + let dag; + + if (file.endsWith('.json')) { + dag = await RuDag.fromJSON(data.toString(), { storage: null }); + } else { + dag = await RuDag.fromBytes(new Uint8Array(data), { storage: null }); + } + + let mechanism; + switch (type) { + case 'topo': + case 'topological': + mechanism = AttentionMechanism.TOPOLOGICAL; + break; + case 'critical': + case 'critical_path': + mechanism = AttentionMechanism.CRITICAL_PATH; + break; + case 'uniform': + mechanism = AttentionMechanism.UNIFORM; + break; + default: + console.error(`Unknown attention type: ${type}`); + process.exit(1); + } + + const scores = dag.attention(mechanism); + + if (isJson) { + console.log(JSON.stringify({ type, scores })); + } else { + console.log(`Attention type: ${type}`); + scores.forEach((score, i) => { + console.log(` Node ${i}: ${score.toFixed(4)}`); + }); + } + break; + } + + case 'convert': { + const inFile = args[1]; + const outFile = args[2]; + + if (!inFile || !outFile) { + console.error('Error: Input and output files required'); + process.exit(1); + } + + const data = fs.readFileSync(inFile); + let dag; + + if (inFile.endsWith('.json')) { + dag = await RuDag.fromJSON(data.toString(), { storage: null }); + } else { + dag = await RuDag.fromBytes(new Uint8Array(data), { storage: null }); + } + + if (outFile.endsWith('.json')) { + fs.writeFileSync(outFile, dag.toJSON()); + } else { + fs.writeFileSync(outFile, Buffer.from(dag.toBytes())); + } + + console.log(`Converted ${inFile} -> ${outFile}`); + break; + } + + default: + console.error(`Unknown command: ${command}`); + console.log('Run "rudag help" for usage information'); + process.exit(1); + } + } catch (error) { + console.error('Error:', error.message); + if (verbose) { + console.error(error.stack); + } + process.exit(1); + } +} + +main(); diff --git a/npm/packages/rudag/package.json b/npm/packages/rudag/package.json new file mode 100644 index 000000000..2b92c4138 --- /dev/null +++ b/npm/packages/rudag/package.json @@ -0,0 +1,96 @@ +{ + "name": "@ruvector/rudag", + "version": "0.1.0", + "description": "Self-learning DAG query optimization with WASM acceleration and IndexedDB persistence for browsers", + "main": "./dist/index.js", + "module": "./dist/index.mjs", + "types": "./dist/index.d.ts", + "bin": { + "rudag": "./bin/cli.js" + }, + "exports": { + ".": { + "types": "./dist/index.d.ts", + "browser": { + "import": "./dist/browser.mjs", + "require": "./dist/browser.js" + }, + "node": { + "import": "./dist/node.mjs", + "require": "./dist/node.js" + }, + "default": "./dist/index.js" + }, + "./browser": { + "types": "./dist/browser.d.ts", + "import": "./dist/browser.mjs", + "require": "./dist/browser.js" + }, + "./node": { + "types": "./dist/node.d.ts", + "import": "./dist/node.mjs", + "require": "./dist/node.js" + }, + "./wasm": { + "types": "./pkg/ruvector_dag_wasm.d.ts", + "import": "./pkg/ruvector_dag_wasm.js", + "require": "./pkg-node/ruvector_dag_wasm.js" + } + }, + "files": [ + "dist", + "pkg", + "pkg-node", + "bin", + "README.md", + "LICENSE" + ], + "scripts": { + "build:wasm": "npm run build:wasm:bundler && npm run build:wasm:node", + "build:wasm:bundler": "cd ../../../crates/ruvector-dag-wasm && wasm-pack build --target bundler --out-dir ../../npm/packages/rudag/pkg", + "build:wasm:node": "cd ../../../crates/ruvector-dag-wasm && wasm-pack build --target nodejs --out-dir ../../npm/packages/rudag/pkg-node", + "build:ts": "tsc && tsc -p tsconfig.esm.json", + "build": "npm run build:wasm && npm run build:ts", + "test": "node --test dist/**/*.test.js", + "prepublishOnly": "npm run build" + }, + "keywords": [ + "dag", + "query-optimization", + "self-learning", + "wasm", + "webassembly", + "indexeddb", + "browser", + "persistence", + "machine-learning", + "attention-mechanism", + "neural-network", + "ruvector" + ], + "author": "rUv Team ", + "license": "MIT OR Apache-2.0", + "repository": { + "type": "git", + "url": "https://github.com/ruvnet/ruvector.git", + "directory": "npm/packages/rudag" + }, + "bugs": { + "url": "https://github.com/ruvnet/ruvector/issues" + }, + "homepage": "https://github.com/ruvnet/ruvector/tree/main/crates/ruvector-dag", + "engines": { + "node": ">= 16" + }, + "publishConfig": { + "registry": "https://registry.npmjs.org/", + "access": "public" + }, + "devDependencies": { + "@types/node": "^20.19.25", + "typescript": "^5.9.3" + }, + "dependencies": { + "idb": "^8.0.0" + } +} diff --git a/npm/packages/rudag/src/browser.ts b/npm/packages/rudag/src/browser.ts new file mode 100644 index 000000000..7f5fa009e --- /dev/null +++ b/npm/packages/rudag/src/browser.ts @@ -0,0 +1,73 @@ +/** + * Browser-specific entry point with IndexedDB support + */ + +export * from './index'; + +// Re-export with browser-specific defaults +import { RuDag, DagStorage } from './index'; + +/** + * Create a browser-optimized DAG with IndexedDB persistence + */ +export async function createBrowserDag(name?: string): Promise { + const storage = new DagStorage(); + const dag = new RuDag({ name, storage }); + await dag.init(); + return dag; +} + +/** + * Browser storage manager for DAGs + */ +export class BrowserDagManager { + private storage: DagStorage; + private initialized = false; + + constructor() { + this.storage = new DagStorage(); + } + + async init(): Promise { + if (this.initialized) return; + await this.storage.init(); + this.initialized = true; + } + + async createDag(name?: string): Promise { + await this.init(); + const dag = new RuDag({ name, storage: this.storage }); + await dag.init(); + return dag; + } + + async loadDag(id: string): Promise { + await this.init(); + return RuDag.load(id, this.storage); + } + + async listDags() { + await this.init(); + return this.storage.list(); + } + + async deleteDag(id: string): Promise { + await this.init(); + return this.storage.delete(id); + } + + async clearAll(): Promise { + await this.init(); + return this.storage.clear(); + } + + async getStats() { + await this.init(); + return this.storage.stats(); + } + + close(): void { + this.storage.close(); + this.initialized = false; + } +} diff --git a/npm/packages/rudag/src/dag.ts b/npm/packages/rudag/src/dag.ts new file mode 100644 index 000000000..f2b460493 --- /dev/null +++ b/npm/packages/rudag/src/dag.ts @@ -0,0 +1,401 @@ +/** + * High-level DAG API with WASM acceleration + * Provides a TypeScript-friendly interface to the WASM DAG implementation + */ + +import { createStorage, DagStorage, MemoryStorage, StoredDag } from './storage'; + +// WASM module type definitions +interface WasmDagModule { + WasmDag: { + new(): WasmDagInstance; + from_bytes(data: Uint8Array): WasmDagInstance; + from_json(json: string): WasmDagInstance; + }; +} + +interface WasmDagInstance { + add_node(op: number, cost: number): number; + add_edge(from: number, to: number): boolean; + node_count(): number; + edge_count(): number; + topo_sort(): Uint32Array; + critical_path(): unknown; + attention(mechanism: number): Float32Array; + to_bytes(): Uint8Array; + to_json(): string; + free(): void; +} + +/** + * Operator types for DAG nodes + */ +export enum DagOperator { + SCAN = 0, + FILTER = 1, + PROJECT = 2, + JOIN = 3, + AGGREGATE = 4, + SORT = 5, + LIMIT = 6, + UNION = 7, + CUSTOM = 255, +} + +/** + * Attention mechanism types + */ +export enum AttentionMechanism { + TOPOLOGICAL = 0, + CRITICAL_PATH = 1, + UNIFORM = 2, +} + +/** + * Node representation + */ +export interface DagNode { + id: number; + operator: DagOperator | number; + cost: number; + metadata?: Record; +} + +/** + * Edge representation + */ +export interface DagEdge { + from: number; + to: number; +} + +/** + * Critical path result + */ +export interface CriticalPath { + path: number[]; + cost: number; +} + +/** + * DAG configuration options + */ +export interface RuDagOptions { + id?: string; + name?: string; + storage?: DagStorage | MemoryStorage | null; + autoSave?: boolean; +} + +let wasmModule: WasmDagModule | null = null; + +/** + * Initialize WASM module + */ +async function initWasm(): Promise { + if (wasmModule) return wasmModule; + + try { + // Try browser bundler version first + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const mod = await import('../pkg/ruvector_dag_wasm.js') as any; + if (typeof mod.default === 'function') { + await mod.default(); + } + wasmModule = mod as WasmDagModule; + return wasmModule; + } catch { + try { + // Fallback to Node.js version + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const mod = await import('../pkg-node/ruvector_dag_wasm.js') as any; + wasmModule = mod as WasmDagModule; + return wasmModule; + } catch (e) { + throw new Error(`Failed to load WASM module: ${e}`); + } + } +} + +/** + * RuDag - High-performance DAG with WASM acceleration and persistence + */ +export class RuDag { + private wasm: WasmDagInstance | null = null; + private nodes: Map = new Map(); + private storage: DagStorage | MemoryStorage | null; + private id: string; + private name?: string; + private autoSave: boolean; + private initialized = false; + + constructor(options: RuDagOptions = {}) { + this.id = options.id || `dag-${Date.now()}-${Math.random().toString(36).slice(2, 8)}`; + this.name = options.name; + this.storage = options.storage === undefined ? createStorage() : options.storage; + this.autoSave = options.autoSave ?? true; + } + + /** + * Initialize the DAG + */ + async init(): Promise { + if (this.initialized) return this; + + const mod = await initWasm(); + this.wasm = new mod.WasmDag(); + + if (this.storage) { + await this.storage.init(); + } + + this.initialized = true; + return this; + } + + /** + * Ensure DAG is initialized + */ + private ensureInit(): WasmDagInstance { + if (!this.wasm) { + throw new Error('DAG not initialized. Call init() first.'); + } + return this.wasm; + } + + /** + * Add a node to the DAG + */ + addNode(operator: DagOperator | number, cost: number, metadata?: Record): number { + const wasm = this.ensureInit(); + const id = wasm.add_node(operator, cost); + + this.nodes.set(id, { + id, + operator, + cost, + metadata, + }); + + if (this.autoSave) { + this.save().catch(() => {}); // Background save + } + + return id; + } + + /** + * Add an edge between nodes + */ + addEdge(from: number, to: number): boolean { + const wasm = this.ensureInit(); + const success = wasm.add_edge(from, to); + + if (success && this.autoSave) { + this.save().catch(() => {}); // Background save + } + + return success; + } + + /** + * Get node count + */ + get nodeCount(): number { + return this.ensureInit().node_count(); + } + + /** + * Get edge count + */ + get edgeCount(): number { + return this.ensureInit().edge_count(); + } + + /** + * Get topological sort + */ + topoSort(): number[] { + const result = this.ensureInit().topo_sort(); + return Array.from(result); + } + + /** + * Find critical path + */ + criticalPath(): CriticalPath { + const result = this.ensureInit().critical_path(); + + if (typeof result === 'string') { + return JSON.parse(result); + } + return result as CriticalPath; + } + + /** + * Compute attention scores + */ + attention(mechanism: AttentionMechanism = AttentionMechanism.CRITICAL_PATH): number[] { + const result = this.ensureInit().attention(mechanism); + return Array.from(result); + } + + /** + * Get node by ID + */ + getNode(id: number): DagNode | undefined { + return this.nodes.get(id); + } + + /** + * Get all nodes + */ + getNodes(): DagNode[] { + return Array.from(this.nodes.values()); + } + + /** + * Serialize to bytes + */ + toBytes(): Uint8Array { + return this.ensureInit().to_bytes(); + } + + /** + * Serialize to JSON + */ + toJSON(): string { + return this.ensureInit().to_json(); + } + + /** + * Save DAG to storage + */ + async save(): Promise { + if (!this.storage) return null; + + const data = this.toBytes(); + return this.storage.save(this.id, data, { + name: this.name, + metadata: { + nodeCount: this.nodeCount, + edgeCount: this.edgeCount, + nodes: Object.fromEntries(this.nodes), + }, + }); + } + + /** + * Load DAG from storage by ID + */ + static async load(id: string, storage?: DagStorage | MemoryStorage): Promise { + const store = storage || createStorage(); + await store.init(); + + const record = await store.get(id); + if (!record) return null; + + return RuDag.fromBytes(record.data, { + id: record.id, + name: record.name, + storage: store, + }); + } + + /** + * Create DAG from bytes + */ + static async fromBytes(data: Uint8Array, options: RuDagOptions = {}): Promise { + const mod = await initWasm(); + const dag = new RuDag(options); + dag.wasm = mod.WasmDag.from_bytes(data); + dag.initialized = true; + + if (dag.storage) { + await dag.storage.init(); + } + + return dag; + } + + /** + * Create DAG from JSON + */ + static async fromJSON(json: string, options: RuDagOptions = {}): Promise { + const mod = await initWasm(); + const dag = new RuDag(options); + dag.wasm = mod.WasmDag.from_json(json); + dag.initialized = true; + + if (dag.storage) { + await dag.storage.init(); + } + + return dag; + } + + /** + * List all stored DAGs + */ + static async listStored(storage?: DagStorage | MemoryStorage): Promise { + const store = storage || createStorage(); + await store.init(); + return store.list(); + } + + /** + * Delete a stored DAG + */ + static async deleteStored(id: string, storage?: DagStorage | MemoryStorage): Promise { + const store = storage || createStorage(); + await store.init(); + return store.delete(id); + } + + /** + * Get storage statistics + */ + static async storageStats(storage?: DagStorage | MemoryStorage): Promise<{ count: number; totalSize: number }> { + const store = storage || createStorage(); + await store.init(); + return store.stats(); + } + + /** + * Get DAG ID + */ + getId(): string { + return this.id; + } + + /** + * Get DAG name + */ + getName(): string | undefined { + return this.name; + } + + /** + * Set DAG name + */ + setName(name: string): void { + this.name = name; + if (this.autoSave) { + this.save().catch(() => {}); + } + } + + /** + * Cleanup resources + */ + dispose(): void { + if (this.wasm) { + this.wasm.free(); + this.wasm = null; + } + if (this.storage) { + this.storage.close(); + } + this.initialized = false; + } +} diff --git a/npm/packages/rudag/src/index.test.ts b/npm/packages/rudag/src/index.test.ts new file mode 100644 index 000000000..05f52e66c --- /dev/null +++ b/npm/packages/rudag/src/index.test.ts @@ -0,0 +1,216 @@ +/** + * Tests for @ruvector/rudag + */ + +import { test, describe, beforeEach, afterEach } from 'node:test'; +import assert from 'node:assert'; +import { RuDag, DagOperator, AttentionMechanism, MemoryStorage, createStorage } from './index'; + +describe('RuDag', () => { + let dag: RuDag; + + beforeEach(async () => { + dag = new RuDag({ storage: new MemoryStorage(), autoSave: false }); + await dag.init(); + }); + + afterEach(() => { + dag.dispose(); + }); + + test('should create empty DAG', () => { + assert.strictEqual(dag.nodeCount, 0); + assert.strictEqual(dag.edgeCount, 0); + }); + + test('should add nodes', () => { + const id1 = dag.addNode(DagOperator.SCAN, 10.0); + const id2 = dag.addNode(DagOperator.FILTER, 2.0); + + assert.strictEqual(id1, 0); + assert.strictEqual(id2, 1); + assert.strictEqual(dag.nodeCount, 2); + }); + + test('should add edges', () => { + const n1 = dag.addNode(DagOperator.SCAN, 10.0); + const n2 = dag.addNode(DagOperator.FILTER, 2.0); + + const success = dag.addEdge(n1, n2); + assert.strictEqual(success, true); + assert.strictEqual(dag.edgeCount, 1); + }); + + test('should reject cycles', () => { + const n1 = dag.addNode(DagOperator.SCAN, 1.0); + const n2 = dag.addNode(DagOperator.FILTER, 1.0); + const n3 = dag.addNode(DagOperator.PROJECT, 1.0); + + dag.addEdge(n1, n2); + dag.addEdge(n2, n3); + + // This should fail - would create cycle + const success = dag.addEdge(n3, n1); + assert.strictEqual(success, false); + }); + + test('should compute topological sort', () => { + const n1 = dag.addNode(DagOperator.SCAN, 1.0); + const n2 = dag.addNode(DagOperator.FILTER, 1.0); + const n3 = dag.addNode(DagOperator.PROJECT, 1.0); + + dag.addEdge(n1, n2); + dag.addEdge(n2, n3); + + const topo = dag.topoSort(); + assert.deepStrictEqual(topo, [0, 1, 2]); + }); + + test('should find critical path', () => { + const n1 = dag.addNode(DagOperator.SCAN, 10.0); + const n2 = dag.addNode(DagOperator.FILTER, 2.0); + const n3 = dag.addNode(DagOperator.PROJECT, 1.0); + + dag.addEdge(n1, n2); + dag.addEdge(n2, n3); + + const result = dag.criticalPath(); + assert.deepStrictEqual(result.path, [0, 1, 2]); + assert.strictEqual(result.cost, 13); // 10 + 2 + 1 + }); + + test('should compute attention scores', () => { + dag.addNode(DagOperator.SCAN, 1.0); + dag.addNode(DagOperator.FILTER, 2.0); + dag.addNode(DagOperator.PROJECT, 3.0); + + const uniform = dag.attention(AttentionMechanism.UNIFORM); + assert.strictEqual(uniform.length, 3); + // All should be approximately 0.333 + assert.ok(Math.abs(uniform[0] - 0.333) < 0.01); + + const topo = dag.attention(AttentionMechanism.TOPOLOGICAL); + assert.strictEqual(topo.length, 3); + + const critical = dag.attention(AttentionMechanism.CRITICAL_PATH); + assert.strictEqual(critical.length, 3); + }); + + test('should serialize to JSON', () => { + dag.addNode(DagOperator.SCAN, 1.0); + dag.addNode(DagOperator.FILTER, 2.0); + dag.addEdge(0, 1); + + const json = dag.toJSON(); + assert.ok(json.includes('nodes')); + assert.ok(json.includes('edges')); + }); + + test('should serialize to bytes', () => { + dag.addNode(DagOperator.SCAN, 1.0); + dag.addNode(DagOperator.FILTER, 2.0); + dag.addEdge(0, 1); + + const bytes = dag.toBytes(); + assert.ok(bytes instanceof Uint8Array); + assert.ok(bytes.length > 0); + }); + + test('should round-trip through JSON', async () => { + const n1 = dag.addNode(DagOperator.SCAN, 10.0); + const n2 = dag.addNode(DagOperator.FILTER, 2.0); + dag.addEdge(n1, n2); + + const json = dag.toJSON(); + const restored = await RuDag.fromJSON(json, { storage: null }); + + assert.strictEqual(restored.nodeCount, 2); + assert.strictEqual(restored.edgeCount, 1); + + restored.dispose(); + }); + + test('should round-trip through bytes', async () => { + const n1 = dag.addNode(DagOperator.SCAN, 10.0); + const n2 = dag.addNode(DagOperator.FILTER, 2.0); + dag.addEdge(n1, n2); + + const bytes = dag.toBytes(); + const restored = await RuDag.fromBytes(bytes, { storage: null }); + + assert.strictEqual(restored.nodeCount, 2); + assert.strictEqual(restored.edgeCount, 1); + + restored.dispose(); + }); +}); + +describe('MemoryStorage', () => { + let storage: MemoryStorage; + + beforeEach(async () => { + storage = new MemoryStorage(); + await storage.init(); + }); + + test('should save and retrieve DAG', async () => { + const data = new Uint8Array([1, 2, 3, 4]); + await storage.save('test-dag', data, { name: 'Test DAG' }); + + const retrieved = await storage.get('test-dag'); + assert.ok(retrieved); + assert.strictEqual(retrieved.id, 'test-dag'); + assert.strictEqual(retrieved.name, 'Test DAG'); + assert.deepStrictEqual(Array.from(retrieved.data), [1, 2, 3, 4]); + }); + + test('should list all DAGs', async () => { + await storage.save('dag-1', new Uint8Array([1])); + await storage.save('dag-2', new Uint8Array([2])); + + const list = await storage.list(); + assert.strictEqual(list.length, 2); + }); + + test('should delete DAG', async () => { + await storage.save('to-delete', new Uint8Array([1])); + assert.ok(await storage.get('to-delete')); + + await storage.delete('to-delete'); + assert.strictEqual(await storage.get('to-delete'), null); + }); + + test('should find by name', async () => { + await storage.save('dag-1', new Uint8Array([1]), { name: 'query' }); + await storage.save('dag-2', new Uint8Array([2]), { name: 'query' }); + await storage.save('dag-3', new Uint8Array([3]), { name: 'other' }); + + const results = await storage.findByName('query'); + assert.strictEqual(results.length, 2); + }); + + test('should calculate stats', async () => { + await storage.save('dag-1', new Uint8Array(100)); + await storage.save('dag-2', new Uint8Array(200)); + + const stats = await storage.stats(); + assert.strictEqual(stats.count, 2); + assert.strictEqual(stats.totalSize, 300); + }); + + test('should clear all', async () => { + await storage.save('dag-1', new Uint8Array([1])); + await storage.save('dag-2', new Uint8Array([2])); + + await storage.clear(); + const list = await storage.list(); + assert.strictEqual(list.length, 0); + }); +}); + +describe('createStorage', () => { + test('should create MemoryStorage in Node.js', () => { + const storage = createStorage(); + assert.ok(storage instanceof MemoryStorage); + }); +}); diff --git a/npm/packages/rudag/src/index.ts b/npm/packages/rudag/src/index.ts new file mode 100644 index 000000000..5af27a31a --- /dev/null +++ b/npm/packages/rudag/src/index.ts @@ -0,0 +1,60 @@ +/** + * @ruvector/rudag - Self-learning DAG query optimization + * + * Provides WASM-accelerated DAG operations with IndexedDB persistence + * for browser environments. + */ + +export { + RuDag, + DagOperator, + AttentionMechanism, + type DagNode, + type DagEdge, + type CriticalPath, + type RuDagOptions, +} from './dag'; + +export { + DagStorage, + MemoryStorage, + createStorage, + isIndexedDBAvailable, + type StoredDag, + type DagStorageOptions, +} from './storage'; + +// Version info +export const VERSION = '0.1.0'; + +/** + * Quick start example: + * + * ```typescript + * import { RuDag, DagOperator, AttentionMechanism } from '@ruvector/rudag'; + * + * // Create and initialize a DAG + * const dag = await new RuDag({ name: 'my-query' }).init(); + * + * // Add nodes (query operators) + * const scan = dag.addNode(DagOperator.SCAN, 10.0); + * const filter = dag.addNode(DagOperator.FILTER, 2.0); + * const project = dag.addNode(DagOperator.PROJECT, 1.0); + * + * // Connect nodes + * dag.addEdge(scan, filter); + * dag.addEdge(filter, project); + * + * // Get critical path + * const { path, cost } = dag.criticalPath(); + * console.log(`Critical path: ${path.join(' -> ')}, total cost: ${cost}`); + * + * // Compute attention scores + * const scores = dag.attention(AttentionMechanism.CRITICAL_PATH); + * console.log('Attention scores:', scores); + * + * // DAG is auto-saved to IndexedDB + * // Load it later + * const loadedDag = await RuDag.load(dag.getId()); + * ``` + */ diff --git a/npm/packages/rudag/src/node.ts b/npm/packages/rudag/src/node.ts new file mode 100644 index 000000000..d82b03ad4 --- /dev/null +++ b/npm/packages/rudag/src/node.ts @@ -0,0 +1,127 @@ +/** + * Node.js-specific entry point with filesystem support + */ + +export * from './index'; + +import { RuDag, MemoryStorage } from './index'; +import * as fs from 'fs'; +import * as path from 'path'; + +/** + * Create a Node.js DAG with memory storage + */ +export async function createNodeDag(name?: string): Promise { + const storage = new MemoryStorage(); + const dag = new RuDag({ name, storage }); + await dag.init(); + return dag; +} + +/** + * File-based storage for Node.js environments + */ +export class FileDagStorage { + private basePath: string; + + constructor(basePath: string = '.rudag') { + this.basePath = basePath; + } + + async init(): Promise { + if (!fs.existsSync(this.basePath)) { + fs.mkdirSync(this.basePath, { recursive: true }); + } + } + + private getFilePath(id: string): string { + return path.join(this.basePath, `${id}.dag`); + } + + private getMetaPath(id: string): string { + return path.join(this.basePath, `${id}.meta.json`); + } + + async save(id: string, data: Uint8Array, options: { name?: string; metadata?: Record } = {}): Promise { + await this.init(); + + fs.writeFileSync(this.getFilePath(id), Buffer.from(data)); + fs.writeFileSync(this.getMetaPath(id), JSON.stringify({ + id, + name: options.name, + metadata: options.metadata, + createdAt: Date.now(), + updatedAt: Date.now(), + })); + } + + async load(id: string): Promise { + const filePath = this.getFilePath(id); + if (!fs.existsSync(filePath)) { + return null; + } + return new Uint8Array(fs.readFileSync(filePath)); + } + + async delete(id: string): Promise { + const filePath = this.getFilePath(id); + const metaPath = this.getMetaPath(id); + + if (fs.existsSync(filePath)) { + fs.unlinkSync(filePath); + } + if (fs.existsSync(metaPath)) { + fs.unlinkSync(metaPath); + } + return true; + } + + async list(): Promise { + await this.init(); + + const files = fs.readdirSync(this.basePath); + return files + .filter(f => f.endsWith('.dag')) + .map(f => f.replace('.dag', '')); + } +} + +/** + * Node.js DAG manager with file persistence + */ +export class NodeDagManager { + private storage: FileDagStorage; + + constructor(basePath?: string) { + this.storage = new FileDagStorage(basePath); + } + + async init(): Promise { + await this.storage.init(); + } + + async createDag(name?: string): Promise { + const dag = new RuDag({ name, storage: null, autoSave: false }); + await dag.init(); + return dag; + } + + async saveDag(dag: RuDag): Promise { + const data = dag.toBytes(); + await this.storage.save(dag.getId(), data, { name: dag.getName() }); + } + + async loadDag(id: string): Promise { + const data = await this.storage.load(id); + if (!data) return null; + return RuDag.fromBytes(data, { id }); + } + + async deleteDag(id: string): Promise { + return this.storage.delete(id); + } + + async listDags(): Promise { + return this.storage.list(); + } +} diff --git a/npm/packages/rudag/src/storage.ts b/npm/packages/rudag/src/storage.ts new file mode 100644 index 000000000..bb5473d32 --- /dev/null +++ b/npm/packages/rudag/src/storage.ts @@ -0,0 +1,281 @@ +/** + * IndexedDB-based persistence layer for DAG storage + * Provides browser-compatible persistent storage for DAGs + */ + +const DB_NAME = 'rudag-storage'; +const DB_VERSION = 1; +const STORE_NAME = 'dags'; + +export interface StoredDag { + id: string; + name?: string; + data: Uint8Array; + createdAt: number; + updatedAt: number; + metadata?: Record; +} + +export interface DagStorageOptions { + dbName?: string; + version?: number; +} + +/** + * Check if IndexedDB is available (browser environment) + */ +export function isIndexedDBAvailable(): boolean { + return typeof indexedDB !== 'undefined'; +} + +/** + * IndexedDB storage class for DAG persistence + */ +export class DagStorage { + private dbName: string; + private version: number; + private db: IDBDatabase | null = null; + + constructor(options: DagStorageOptions = {}) { + this.dbName = options.dbName || DB_NAME; + this.version = options.version || DB_VERSION; + } + + /** + * Initialize the database connection + */ + async init(): Promise { + if (!isIndexedDBAvailable()) { + throw new Error('IndexedDB is not available in this environment'); + } + + return new Promise((resolve, reject) => { + const request = indexedDB.open(this.dbName, this.version); + + request.onerror = () => reject(request.error); + + request.onsuccess = () => { + this.db = request.result; + resolve(); + }; + + request.onupgradeneeded = (event) => { + const db = (event.target as IDBOpenDBRequest).result; + + if (!db.objectStoreNames.contains(STORE_NAME)) { + const store = db.createObjectStore(STORE_NAME, { keyPath: 'id' }); + store.createIndex('name', 'name', { unique: false }); + store.createIndex('createdAt', 'createdAt', { unique: false }); + store.createIndex('updatedAt', 'updatedAt', { unique: false }); + } + }; + }); + } + + /** + * Ensure database is initialized + */ + private ensureInit(): IDBDatabase { + if (!this.db) { + throw new Error('Database not initialized. Call init() first.'); + } + return this.db; + } + + /** + * Save a DAG to storage + */ + async save(id: string, data: Uint8Array, options: { name?: string; metadata?: Record } = {}): Promise { + const db = this.ensureInit(); + const now = Date.now(); + + // Check if exists for update timestamp + const existing = await this.get(id); + + const record: StoredDag = { + id, + name: options.name, + data, + createdAt: existing?.createdAt || now, + updatedAt: now, + metadata: options.metadata, + }; + + return new Promise((resolve, reject) => { + const transaction = db.transaction([STORE_NAME], 'readwrite'); + const store = transaction.objectStore(STORE_NAME); + const request = store.put(record); + + request.onsuccess = () => resolve(record); + request.onerror = () => reject(request.error); + }); + } + + /** + * Get a DAG from storage + */ + async get(id: string): Promise { + const db = this.ensureInit(); + + return new Promise((resolve, reject) => { + const transaction = db.transaction([STORE_NAME], 'readonly'); + const store = transaction.objectStore(STORE_NAME); + const request = store.get(id); + + request.onsuccess = () => resolve(request.result || null); + request.onerror = () => reject(request.error); + }); + } + + /** + * Delete a DAG from storage + */ + async delete(id: string): Promise { + const db = this.ensureInit(); + + return new Promise((resolve, reject) => { + const transaction = db.transaction([STORE_NAME], 'readwrite'); + const store = transaction.objectStore(STORE_NAME); + const request = store.delete(id); + + request.onsuccess = () => resolve(true); + request.onerror = () => reject(request.error); + }); + } + + /** + * List all DAGs in storage + */ + async list(): Promise { + const db = this.ensureInit(); + + return new Promise((resolve, reject) => { + const transaction = db.transaction([STORE_NAME], 'readonly'); + const store = transaction.objectStore(STORE_NAME); + const request = store.getAll(); + + request.onsuccess = () => resolve(request.result); + request.onerror = () => reject(request.error); + }); + } + + /** + * Search DAGs by name + */ + async findByName(name: string): Promise { + const db = this.ensureInit(); + + return new Promise((resolve, reject) => { + const transaction = db.transaction([STORE_NAME], 'readonly'); + const store = transaction.objectStore(STORE_NAME); + const index = store.index('name'); + const request = index.getAll(name); + + request.onsuccess = () => resolve(request.result); + request.onerror = () => reject(request.error); + }); + } + + /** + * Clear all DAGs from storage + */ + async clear(): Promise { + const db = this.ensureInit(); + + return new Promise((resolve, reject) => { + const transaction = db.transaction([STORE_NAME], 'readwrite'); + const store = transaction.objectStore(STORE_NAME); + const request = store.clear(); + + request.onsuccess = () => resolve(); + request.onerror = () => reject(request.error); + }); + } + + /** + * Get storage statistics + */ + async stats(): Promise<{ count: number; totalSize: number }> { + const dags = await this.list(); + const totalSize = dags.reduce((sum, dag) => sum + dag.data.byteLength, 0); + return { count: dags.length, totalSize }; + } + + /** + * Close the database connection + */ + close(): void { + if (this.db) { + this.db.close(); + this.db = null; + } + } +} + +/** + * In-memory storage fallback for Node.js or environments without IndexedDB + */ +export class MemoryStorage { + private store: Map = new Map(); + + async init(): Promise { + // No-op for memory storage + } + + async save(id: string, data: Uint8Array, options: { name?: string; metadata?: Record } = {}): Promise { + const now = Date.now(); + const existing = this.store.get(id); + + const record: StoredDag = { + id, + name: options.name, + data, + createdAt: existing?.createdAt || now, + updatedAt: now, + metadata: options.metadata, + }; + + this.store.set(id, record); + return record; + } + + async get(id: string): Promise { + return this.store.get(id) || null; + } + + async delete(id: string): Promise { + return this.store.delete(id); + } + + async list(): Promise { + return Array.from(this.store.values()); + } + + async findByName(name: string): Promise { + return Array.from(this.store.values()).filter(dag => dag.name === name); + } + + async clear(): Promise { + this.store.clear(); + } + + async stats(): Promise<{ count: number; totalSize: number }> { + const dags = Array.from(this.store.values()); + const totalSize = dags.reduce((sum, dag) => sum + dag.data.byteLength, 0); + return { count: dags.length, totalSize }; + } + + close(): void { + // No-op for memory storage + } +} + +/** + * Create appropriate storage based on environment + */ +export function createStorage(options: DagStorageOptions = {}): DagStorage | MemoryStorage { + if (isIndexedDBAvailable()) { + return new DagStorage(options); + } + return new MemoryStorage(); +} diff --git a/npm/packages/rudag/tsconfig.esm.json b/npm/packages/rudag/tsconfig.esm.json new file mode 100644 index 000000000..889fa57bb --- /dev/null +++ b/npm/packages/rudag/tsconfig.esm.json @@ -0,0 +1,11 @@ +{ + "extends": "./tsconfig.json", + "compilerOptions": { + "module": "ESNext", + "outDir": "./dist", + "declaration": false, + "declarationMap": false + }, + "include": ["src/**/*"], + "exclude": ["node_modules", "dist", "pkg", "pkg-node"] +} diff --git a/npm/packages/rudag/tsconfig.json b/npm/packages/rudag/tsconfig.json new file mode 100644 index 000000000..6b29bf2d4 --- /dev/null +++ b/npm/packages/rudag/tsconfig.json @@ -0,0 +1,20 @@ +{ + "compilerOptions": { + "target": "ES2020", + "module": "CommonJS", + "lib": ["ES2020", "DOM"], + "declaration": true, + "declarationMap": true, + "sourceMap": true, + "outDir": "./dist", + "rootDir": "./src", + "strict": true, + "esModuleInterop": true, + "skipLibCheck": true, + "forceConsistentCasingInFileNames": true, + "resolveJsonModule": true, + "moduleResolution": "node" + }, + "include": ["src/**/*"], + "exclude": ["node_modules", "dist", "pkg", "pkg-node"] +} diff --git a/patches/hnsw_rs/.gitignore b/patches/hnsw_rs/.gitignore new file mode 100644 index 000000000..11dd2f172 --- /dev/null +++ b/patches/hnsw_rs/.gitignore @@ -0,0 +1,9 @@ +target/** +Runs +Cargo.lock +rls* +dumpreloadtest* +*.pdf +*.html +.idea/ +.vscode/ diff --git a/patches/hnsw_rs/Cargo.toml b/patches/hnsw_rs/Cargo.toml new file mode 100644 index 000000000..7b1147d83 --- /dev/null +++ b/patches/hnsw_rs/Cargo.toml @@ -0,0 +1,111 @@ +[package] +name = "hnsw_rs" +version = "0.3.3" +authors = ["jeanpierre.both@gmail.com"] +description = "Ann based on Hierarchical Navigable Small World Graphs from Yu.A. Malkov and D.A Yashunin" +license = "MIT/Apache-2.0" +readme = "README.md" +keywords = ["algorithms", "ann", "hnsw"] +repository = "https://github.com/jean-pierreBoth/hnswlib-rs" +documentation = "https://docs.rs/hnsw_rs" +edition = "2021" + + +# declare a feature with no dependancy to get some modulated debug print +# to be run with cargo build --features verbose_1 +#verbose_1 = [ ] + +[profile.release] +lto = true +opt-level = 3 + +[lib] +# cargo rustc --lib -- --crate-type cdylib [or staticlib] or rlib (default) +# if we want to avoid specifying in advance crate-type +path = "src/lib.rs" +#crate-type = ["cdylib"] + + +[[example]] +name = "random" +path = "examples/random.rs" + + +[[example]] +name = "ann-glove" +path = "examples/ann-glove25-angular.rs" + + +[[example]] +name = "ann-mnist" +path = "examples/ann-mnist-784-euclidean.rs" + +[[example]] +name = "ann-sift1m" +path = "examples/ann-sift1m-128-euclidean.rs" + +[[example]] +name = "levenshtein" +path = "examples/levensthein.rs" + + +[dependencies] +# default is version spec is ^ meaning can update up to max non null version number +# cargo doc --no-deps avoid dependencies doc generation +# + +serde = { version = "1.0", features = ["derive"] } +bincode = { version = "1.3" } + +cfg-if = { version = "1.0" } + +# for // +parking_lot = "0.12" +rayon = { version = "1.11" } +num_cpus = { version = "1.16" } + +cpu-time = { version = "1.0" } +num-traits = { version = "0.2" } + + +# for hashing . hashbrown still needed beccause of get_key_value(&key) +hashbrown = { version = "0.15" } +indexmap = { version = ">= 2.11, < 2.13" } + +rand = { version = "0.8" } +lazy_static = { version = "1.4" } + +# +mmap-rs = { version = "0.6" } +# +# decreasing order of log for debug build : (max_level_)trace debug info warn error off +# decreasing order of log for release build (release_max_level_) .. idem +#log = { version = "0.4", features = ["max_level_debug", "release_max_level_info"] } +log = { version = "0.4" } +env_logger = { version = "0.11" } + +anyhow = { version = "1.0" } + +# anndists = { path = "../anndists" } +anndists = { version = "0.1" } +# anndists = { git = "https://github.com/jean-pierreBoth/anndists" } + +# for benchmark reading, so the lbrary do not depend on hdf5 nor ndarray +[dev-dependencies] +# hdf5 = { version = "0.8" } +# metno is needed as hdf5 is blocked to hdfsys 1.12 +hdf5 = {package = "hdf5-metno", version = "0.10.0" } + +ndarray = { version = ">=0.16.0, <0.18" } +skiplist = { version = "0.6" } +tempfile = { version = "3" } +itertools = {version = "0.14"} + +[features] + +default = [] + +# feature for std simd on nightly +stdsimd = ["anndists/stdsimd"] +# feature for simd on stable for x86* +simdeez_f = ["anndists/simdeez_f"] diff --git a/patches/hnsw_rs/Changes.md b/patches/hnsw_rs/Changes.md new file mode 100644 index 000000000..3e7e4abf7 --- /dev/null +++ b/patches/hnsw_rs/Changes.md @@ -0,0 +1,56 @@ +- version 0.3.3 + small fix on filter (thanks to VillSnow). include ndarray 0.17 as possible dep. fixed compiler warning on elided lifetimes + +- version 0.3.2 + update dependencies to ndarray 0.16 , rand 0.9 indexmap 2.9, hdf5. edition=2024 + +- version 0.3.1 + + Possibility to reduce the number of levels used Hnsw structure with the function hnsw::modify_level_scale. + This often increases significantly recall while incurring a moderate cpu cost. It is also possible + to have same recall with smaller *max_nb_conn* parameters so reducing memory usage. + See README.md at [bigann](https://github.com/jean-pierreBoth/bigann). + Modification inspired by the article by [Munyampirwa](https://arxiv.org/abs/2412.01940) + + Clippy cleaning and minor arguments change (PathBuf to Path String to &str) in dump/reload + with the help of bwsw (https://github.com/bwsw) + + +- **version 0.3.0**: + + The distances implementation is now in a separate crate [anndsits](https://crates.io/crates/anndists). Using hnsw_rs::prelude:::* should make the change transparent. + + The mmap implementation makes it possible to use the [coreset](https://github.com/jean-pierreBoth/coreset) crate to compute coreset and clusters of data stored in hnsw dumps. + +- version 0.2.1: + + when using mmap, the points less frequently used (points in lower layers) are preferentially mmap-ed while upper layers are preferentially + explcitly read from file. + + Hnswio is now Sync. + + feature stdsimd, based on std::simd, runs with nightly on Hamming with u32,u64 and DisL1,DistL2, DistDot with f32 + +- The **version 0.2** introduces + 1. possibility to use mmap on the data file storing the vectors represented in the hnsw structure. This is mostly usefule for + large vectors, where data needs more space than the graph part. + As a consequence the format of this file changed. Old format can be read but new dumps will be in the new format. + In case of mmap usage, a dump after inserting new elements must ensure that the old file is not overwritten, so a unique file name is + generated if necessary. See documentation of module Hnswio + + 1. the filtering trait + + +- Upgrade of many dependencies. Change from simple_logger to env_logger. The logger is initialized one for all in file src/lib.rs and cannot be intialized twice. The level of log can be modulated by the RUST_LOG env variable on a module basis or switched off. See the *env_logger* crate doc. + +- A rust crate *edlib_rs* provides an interface to the *excellent* edlib C++ library [(Cf edlib)](https://github.com/Martinsos/edlib) can be found at [edlib_rs](https://github.com/jean-pierreBoth/edlib-rs) or on crate.io. It can be used to define a user adhoc distance on &[u8] with normal, prefix or infix mode (which is useful in genomics alignment). + +- The library do not depend anymore on hdf5 and ndarray. They are dev-dependancies needed for examples, this simplify compatibility issues. +- Added insertion methods for slices for easier use with the ndarray crate. + +- simd/avx2 requires now the feature "simdeez_f". So by default the crate can compile on M1 chip and transitions to std::simd. + +- Added DistPtr and possiblity to dump/reload with this distance type. (See *load_hnsw_with_dist* function) + +- Implementation of Hamming for f64 exclusively in the context SuperMinHash in crate [probminhash](https://crates.io/crates/probminhash) + diff --git a/patches/hnsw_rs/LICENSE-APACHE b/patches/hnsw_rs/LICENSE-APACHE new file mode 100644 index 000000000..d8afa4c9a --- /dev/null +++ b/patches/hnsw_rs/LICENSE-APACHE @@ -0,0 +1,13 @@ +Copyright 2020 jean-pierre.both + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/patches/hnsw_rs/LICENSE-MIT b/patches/hnsw_rs/LICENSE-MIT new file mode 100644 index 000000000..531ac4659 --- /dev/null +++ b/patches/hnsw_rs/LICENSE-MIT @@ -0,0 +1,25 @@ +Copyright (c) 2020 jean-pierre.both + +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/patches/hnsw_rs/README.md b/patches/hnsw_rs/README.md new file mode 100644 index 000000000..f905b8923 --- /dev/null +++ b/patches/hnsw_rs/README.md @@ -0,0 +1,168 @@ +# hnsw-rs + +This crate provides a Rust implementation of the paper by Yu.A. Malkov and D.A Yashunin: + +"Efficient and Robust approximate nearest neighbours using Hierarchical Navigable Small World Graphs" (2016,2018) +[arxiv](https://arxiv.org/abs/1603.09320) + + + +## Functionalities + +The crate is built on top of the [anndists](https://crates.io/crates/anndists) and can use the following distances: + +* usual distances as L1, L2, Cosine, Jaccard, Hamming for vectors of standard numeric types, Levenshtein distance on u16. + +* Hellinger distance and Jeffreys divergence between probability distributions (f32 and f64). It must be noted that the Jeffreys divergence +(a symetrized Kullback-Leibler divergence) do not satisfy the triangle inequality. (Neither Cosine distance !). + +* Jensen-Shannon distance between probability distributions (f32 and f64). It is defined as the **square root** of the Jensen-Shannon divergence and is a bounded metric. See [Nielsen F. in Entropy 2019, 21(5), 485](https://doi.org/10.3390/e21050485). + +* A Trait to enable the user to implement its own distances. + It takes as data slices of types T satisfying T:Serialize+Clone+Send+Sync. It is also possible to use C extern functions or closures. + +* An interface towards C and more specifically to the [Julia](https://julialang.org/) language. +See the companion Julia package [HnswAnn.jl](https://gitlab.com/jpboth/HnswAnn.jl) and the building paragraph for some help for Julia users. + +The hnsw implementation provides: + +* Multithreaded insertion and search requests. + +* Dump and reload functions (*See module hnswio*) to store the data and the graph once it is built. These facilities rely partly on Serde so T needs to implement Serialize and Deserialized as derived by Serde. + It is also possible to reload only the graph and not the data themselves. A specific type (struct NoData, associated to the NoDist distance is dedicated to this functionality. + +* A flattening conversion of the Hnsw structure to keep only neighborhood relationships between points (without their internal data) internal to the Hnsw structure (*see module flatten.rs, FlatPoint and FlatNeighborhood*). It is thus possible to keep some topology information with low memory usage. + +* Filtering: It is possible to add filters so only results which satisfies the filter is in the result set. The filtering is done during the search, so it is not a post filter. There is currently two ways of using the filter, one can add allowed ids in a sorted vector and send as a parameter, or one can define a function which will be called before an id is added to the result set. +Examples on both these strategies are in the examples or tests directory. One can also implement the trait Filterable for new types, if one would like the filter to be kept in a bitvector, for example. + +* Possibilty to use mmap on dumped data (not on graph part) which is useful for large data vectors. This enables coreset and clusters computation in streaming, see [coreset](https://github.com/jean-pierreBoth/coreset) and soon on [crates.io](https://crates.io/crates). + +## Implementation + +The graph construction and searches are multithreaded with the **parking_lot** crate (See **parallel_insert_data** and **parallel_search_neighbours** functions and also examples files). +Distances are provided by the crate [anndists](https://github.com/jean-pierreBoth/anndists), see *Building*. + +## Building + +### Simd + +Two features activate simd in the crate **anndists** : + +* The feature "simdeez_f" provide simd for x86_64 processors. +Compile with **cargo build --release --features "simdeez_f"** or change the default features in Cargo.toml. +To compile this crate on a M1 chip just do not activate this feature. + +* The feature "stdsimd" provides portable simd through std::simd but **requires rust nightly**. +Setting this feature in features default (or by cargo command) activates the portable_simd feature on rust nightly. + Not all couples (Distance, type) are provided yet. (See the crate anndists) + +### Julia interface + +By default the crate is a standalone project and builds a static libray and executable. +To be used with the companion Julia package it is necessary to build a dynamic library. +This can be done by just uncommenting (i.e get rid of the #) in file Cargo.toml the line: + +*#crate-type = ["cdylib"]* + +and rerun the command: cargo build --release. + +This will generate a .so file in the target/release directory. + +## Algorithm and Input Parameters + +The algorithm stores points in layers (at most 16), and a graph is constructed to enable a search from less densely populated levels to most densely populated levels by constructing links from less dense layers to the most dense layer (level 0). + +Roughly the algorithm goes along runs as follows: + +Upon insertion, the level ***l*** of a new point is sampled with an exponential law, limiting the number of levels to 16, +so that level 0 is the most densely populated layer, upper layers being exponentially less populated as level increases. +The nearest neighbour of the point is searched in lookup tables from the upper level to the level just above its layer (***l***), so we should arrive near the new point at its level at a relatively low cost. Then the ***max_nb_connection*** nearest neighbours are searched in neighbours of neighbours table (with a reverse updating of tables) recursively from its layer ***l*** down to the most populated level 0. + +The parameter of the exponential law to sample point levels is set to `ln(max_nb_connection)/scale`. +By default *scale* is set to 1. It is possible to reduce the *scale* parameter and thus reduce the number of levels used (See Hnsw::modify_level_scale) without increasing max_nb_connection. +This often provide better recalls without increasing *max_nb_connection* and thus spare memory usage. (See examples) + + +The main parameters occuring in constructing the graph or in searching are: + +* max_nb_connection (in hnsw initialization) + The maximum number of links from one point to others. Values ranging from 16 to 64 are standard initialising values, the higher the more time consuming. + +* ef_construction (in hnsw initialization) + This parameter controls the width of the search for neighbours during insertion. Values from 200 to 800 are standard initialising values, the higher the more time consuming. + +* max_layer (in hnsw initialization) + The maximum number of layers in graph. Must be less or equal than 16. + +* ef_arg (in search methods) + This parameter controls the width of the search in the lowest level, it must be greater than number of neighbours asked but can be less than ***ef_construction***. + As a rule of thumb could be between the number of neighbours we will ask for (knbn arg in search method) and max_nb_connection. + +* keep_pruned and extend_candidates. + These parameters are described in the paper by Malkov and Yashunin can be used to + modify the search strategy. The interested user should check the paper to see the impact. By default + the values are as recommended in the paper. + +## Benchmarks and Examples [(examples)](./examples) + +Some examples are taken from the [ann-benchmarks site](https://github.com/erikbern/ann-benchmarks) +and recall rates and request/s are given in comments in the examples files for some input parameters. +The annhdf5 module implements reading the standardized data files +of the [ann-benchmarks site](https://github.com/erikbern/ann-benchmarks), +just download the necessary benchmark data files and modify path in sources accordingly. +Then run: cargo build --release --features="simdeez_f" --examples . +It is possible in these examples to change from parallel searches to serial searches to check for speeds +or modify parameters to see the impact on performance. + +With a i9-13900HX 24 cores laptop we get the following results: +1. fashion-mnist-784-euclidean : search requests run at 62000 req/s with a recall rate of 0.977 +2. ann-glove-25-angular : search for the first 100 neighbours run with recall 0.979 at 12000 req/s +3. sift1m benchmark: (1 million points in 128 dimension) search requests for the 10 first neighbours runs at 15000 req/s with a recall rate of 0.9907 or at 8300 req/s with a recall rate of 0.9959, depending on the parameters. + +Moreover a tiny crate [bigann](https://github.com/jean-pierreBoth/bigann) +gives results on the first 10 Million points of the [BIGANN](https://big-ann-benchmarks.com/neurips21.html) benchmark. The benchmark is also described at [IRISA](http://corpus-texmex.irisa.fr/). This crate can used to play with parameters on this data. Results give a recall between 0.92 and 0.99 depending on number of requests and parameters. + +Some lines extracted from this Mnist benchmark show how it works for f32 and L2 norm + +```rust + // reading data + let anndata = AnnBenchmarkData::new(fname).unwrap(); + let nb_elem = anndata.train_data.len(); + let max_nb_connection = 24; + let nb_layer = 16.min((nb_elem as f32).ln().trunc() as usize); + let ef_c = 400; + // allocating network + let mut hnsw = Hnsw::::new(max_nb_connection, nb_elem, nb_layer, ef_c, DistL2{}); + hnsw.set_extend_candidates(false); + // parallel insertion of train data + let data_for_par_insertion = anndata.train_data.iter().map( |x| (&x.0, x.1)).collect(); + hnsw.parallel_insert(&data_for_par_insertion); + // + hnsw.dump_layer_info(); + // Now the bench with 10 neighbours + let mut knn_neighbours_for_tests = Vec::>::with_capacity(nb_elem); + hnsw.set_searching_mode(true); + let knbn = 10; + let ef_c = max_nb_connection; + // search 10 nearest neighbours for test data + knn_neighbours_for_tests = hnsw.parallel_search(&anndata.test_data, knbn, ef_c); + .... +``` + +## Contributions + +[Sannsyn](https://sannsyn.com/en/) contributed to Drop implementation and FilterT trait. +Petter Egesund added the DistLevenshtein distance. + +## Evolutions are described [here](./Changes.md) + +## License + +Licensed under either of + +* Apache License, Version 2.0, [LICENSE-APACHE](LICENSE-APACHE) or +* MIT license [LICENSE-MIT](LICENSE-MIT) or + +at your option. + diff --git a/patches/hnsw_rs/examples/ann-glove25-angular.rs b/patches/hnsw_rs/examples/ann-glove25-angular.rs new file mode 100644 index 000000000..861444b27 --- /dev/null +++ b/patches/hnsw_rs/examples/ann-glove25-angular.rs @@ -0,0 +1,220 @@ +#![allow(clippy::needless_range_loop)] + +use cpu_time::ProcessTime; +use std::time::{Duration, SystemTime}; + +// glove 25 // 2.7 Ghz 4 cores 8Mb L3 k = 10 +// ============================================ +// +// max_nb_conn ef_cons ef_search scale_factor extend keep pruned recall req/s last ratio +// 24 800 64 1. 1 0 0.928 4090 1.003 +// 24 800 64 1. 1 1 0.927 4594 1.003 +// 24 400, 48 1. 1 0 0.919 6349 1.0044 +// 24 800 48 1 1 1 0.918 5785 1.005 +// 24 400 32 1. 0 0 0.898 8662 +// 24 400 64 1. 1 0 0.930 4711 1.0027 +// 24 400 64 1. 1 1 0.921 4550 1.0039 +// 24 1600 48 1 1 0 0.924 5380 1.0034 + +// 32 400 48 1 1 0 0.93 4706 1.0026 +// 32 800 64 1 1 0 0.94 3780. 1.0015 +// 32 1600 48 1 1 0 0.934 4455 1.0023 +// 48 1600 48 1 1 0 0.945 3253 1.00098 + +// 24 400 48 1 1 0 0.92 6036. 1.0038 +// 48 800 48 1 1 0 0.935 4018 1.002 +// 48 800 64 1 1 0 0.942 3091 1.0014 +// 48 800 64 1 1 1 0.9435 2640 1.00126 + +// k = 100 + +// 24 800 48 1 1 0 0.96 2432 1.004 +// 48 800 128 1 1 0 0.979 1626 1.001 + +// glove 25 // 8 cores i7 2.3 Ghz 8Mb L3 knbn = 100 +// ================================================== + +// 48 800 48 1 1 0 0.935 13400 1.002 +// 48 800 128 1 1 0 0.979 5227 1.002 + +// 24 core Core(TM) i9-13900HX simdeez knbn = 10 +// ================================================== +// 48 800 48 1 1 0 0.936 30748 1.002 + +// 24 core Core(TM) i9-13900HX simdeez knbn = 100 +// ================================================== +// 48 800 128 1 1 0 0.979 12000 1.002 + +// results with scale modification 0.5 +//==================================== + +// 24 core Core(TM) i9-13900HX simdeez knbn = 10 +// ================================================== +// 24 800 48 0.5 1 0 0.931 40700 1.002 +// 48 800 48 0.5 1 0 0.941 30001 1.001 + +// 24 core Core(TM) i9-13900HX simdeez knbn = 100 +// ================================================== +// 24 800 128 0.5 1 0 0.974 16521 1.002 +// 48 800 128 0.5 1 0 0.985 11484 1.001 + +use anndists::dist::*; +use hnsw_rs::prelude::*; +use log::info; + +mod utils; + +use utils::*; + +pub fn main() { + let _ = env_logger::builder().is_test(true).try_init().unwrap(); + let parallel = true; + // + let fname = String::from("/home/jpboth/Data/ANN/glove-25-angular.hdf5"); + println!("\n\n test_load_hdf5 {:?}", fname); + // now recall that data are stored in row order. + let mut anndata = annhdf5::AnnBenchmarkData::new(fname).unwrap(); + // pre normalisation to use Dot computations instead of Cosine + anndata.do_l2_normalization(); + // run bench + let nb_elem = anndata.train_data.len(); + let knbn_max = anndata.test_distances.dim().1; + info!( + "Train size : {}, test size : {}", + nb_elem, + anndata.test_data.len() + ); + info!("Nb neighbours answers for test data : {} \n\n", knbn_max); + // + let max_nb_connection = 24; + let ef_c = 800; + println!( + " max_nb_conn : {:?}, ef_construction : {:?} ", + max_nb_connection, ef_c + ); + let nb_layer = 16.min((nb_elem as f32).ln().trunc() as usize); + println!( + " number of elements to insert {:?} , setting max nb layer to {:?} ef_construction {:?}", + nb_elem, nb_layer, ef_c + ); + let nb_search = anndata.test_data.len(); + println!(" number of search {:?}", nb_search); + // Hnsw allocation + let mut hnsw = + Hnsw::::new(max_nb_connection, nb_elem, nb_layer, ef_c, DistDot {}); + // + hnsw.set_extend_candidates(true); + hnsw.modify_level_scale(0.5); + // + // parallel insertion + let start = ProcessTime::now(); + let now = SystemTime::now(); + let data_for_par_insertion = anndata + .train_data + .iter() + .map(|x| (x.0.as_slice(), x.1)) + .collect(); + if parallel { + println!(" \n parallel insertion"); + hnsw.parallel_insert_slice(&data_for_par_insertion); + } else { + println!(" \n serial insertion"); + for d in data_for_par_insertion { + hnsw.insert_slice(d); + } + } + let cpu_time: Duration = start.elapsed(); + // + println!( + "\n hnsw data insertion cpu time {:?} system time {:?} ", + cpu_time, + now.elapsed() + ); + hnsw.dump_layer_info(); + println!(" hnsw data nb point inserted {:?}", hnsw.get_nb_point()); + // + // Now the bench with 10 neighbours + // + let knbn = 10; + let ef_search = 48; + search(&mut hnsw, knbn, ef_search, &anndata); + + let knbn = 100; + let ef_search = 128; + search(&mut hnsw, knbn, ef_search, &anndata); +} + +pub fn search( + hnsw: &mut Hnsw, + knbn: usize, + ef_search: usize, + anndata: &annhdf5::AnnBenchmarkData, +) where + Dist: Distance + Send + Sync, +{ + println!("\n\n ef_search : {:?} knbn : {:?} ", ef_search, knbn); + let parallel = true; + // + let nb_elem = anndata.train_data.len(); + let nb_search = anndata.test_data.len(); + // + let mut recalls = Vec::::with_capacity(nb_elem); + let mut nb_returned = Vec::::with_capacity(nb_elem); + let mut last_distances_ratio = Vec::::with_capacity(nb_elem); + let mut knn_neighbours_for_tests = Vec::>::with_capacity(nb_elem); + hnsw.set_searching_mode(true); + println!("searching with ef : {:?}", ef_search); + let start = ProcessTime::now(); + let now = SystemTime::now(); + // search + if parallel { + println!(" \n parallel search"); + knn_neighbours_for_tests = hnsw.parallel_search(&anndata.test_data, knbn, ef_search); + } else { + println!(" \n serial search"); + for i in 0..anndata.test_data.len() { + let knn_neighbours: Vec = + hnsw.search(&anndata.test_data[i], knbn, ef_search); + knn_neighbours_for_tests.push(knn_neighbours); + } + } + let cpu_time = start.elapsed(); + let search_cpu_time = cpu_time.as_micros() as f32; + let search_sys_time = now.elapsed().unwrap().as_micros() as f32; + println!( + "total cpu time for search requests {:?} , system time {:?} ", + search_cpu_time, + now.elapsed() + ); + // now compute recall rate + for i in 0..anndata.test_data.len() { + let max_dist = anndata.test_distances.row(i)[knbn - 1]; + let knn_neighbours_d: Vec = knn_neighbours_for_tests[i] + .iter() + .map(|p| p.distance) + .collect(); + nb_returned.push(knn_neighbours_d.len()); + let recall = knn_neighbours_d.iter().filter(|d| *d <= &max_dist).count(); + recalls.push(recall); + let mut ratio = 0.; + if !knn_neighbours_d.is_empty() { + ratio = knn_neighbours_d[knn_neighbours_d.len() - 1] / max_dist; + } + last_distances_ratio.push(ratio); + } + let mean_recall = (recalls.iter().sum::() as f32) / ((knbn * recalls.len()) as f32); + println!( + "\n mean fraction nb returned by search {:?} ", + (nb_returned.iter().sum::() as f32) / ((nb_returned.len() * knbn) as f32) + ); + println!( + "\n last distances ratio {:?} ", + last_distances_ratio.iter().sum::() / last_distances_ratio.len() as f32 + ); + println!( + "\n recall rate for {:?} is {:?} , nb req /s {:?}", + anndata.fname, + mean_recall, + (nb_search as f32) * 1.0e+6_f32 / search_sys_time + ); +} diff --git a/patches/hnsw_rs/examples/ann-mnist-784-euclidean.rs b/patches/hnsw_rs/examples/ann-mnist-784-euclidean.rs new file mode 100644 index 000000000..2045fd582 --- /dev/null +++ b/patches/hnsw_rs/examples/ann-mnist-784-euclidean.rs @@ -0,0 +1,162 @@ +#![allow(clippy::needless_range_loop)] + +use cpu_time::ProcessTime; +use std::time::{Duration, SystemTime}; + +// search in serial mode i7-core @2.7Ghz for 10 fist neighbours +// max_nb_conn ef_cons ef_search scale_factor extend keep pruned recall req/s last ratio +// +// 12 400 12 1 0 0 0.917 6486 1.005 +// 24 400 24 1 1 0 0.9779 3456 1.001 + +// parallel mode 4 i7-core @2.7Ghz +// max_nb_conn ef_cons ef_search scale_factor extend keep pruned recall req/s last ratio +// 24 400 24 1 0 0 0.977 12566 1.001 +// 24 400 12 1 0 0 0.947 18425 1.003 + +// 8 hyperthreaded i7-core @ 2.3 Ghz +// 24 400 24 1 0 0 0.977 22197 1.001 + +// 24 core Core(TM) i9-13900HX simdeez +// 24 400 24 1 0 0 0.977 62000 1.001 + +// 24 core Core(TM) i9-13900HX simdeez with modify_level_scale at 0.5 +// 24 400 24 0.5 0 0 0.990 58722 1.000 + +use anndists::dist::*; +use hnsw_rs::prelude::*; +use log::info; + +mod utils; +use utils::*; + +pub fn main() { + let mut parallel = true; + // + let fname = String::from("/home/jpboth/Data/ANN/fashion-mnist-784-euclidean.hdf5"); + println!("\n\n test_load_hdf5 {:?}", fname); + // now recall that data are stored in row order. + let anndata = annhdf5::AnnBenchmarkData::new(fname).unwrap(); + let knbn_max = anndata.test_distances.dim().1; + let nb_elem = anndata.train_data.len(); + info!( + "Train size : {}, test size : {}", + nb_elem, + anndata.test_data.len() + ); + info!("Nb neighbours answers for test data : {}", knbn_max); + // + let max_nb_connection = 24; + let nb_layer = 16.min((nb_elem as f32).ln().trunc() as usize); + let ef_c = 400; + println!( + " number of elements to insert {:?} , setting max nb layer to {:?} ef_construction {:?}", + nb_elem, nb_layer, ef_c + ); + println!( + " =====================================================================================" + ); + let nb_search = anndata.test_data.len(); + println!(" number of search {:?}", nb_search); + + let mut hnsw = Hnsw::::new(max_nb_connection, nb_elem, nb_layer, ef_c, DistL2 {}); + hnsw.set_extend_candidates(false); + // + hnsw.modify_level_scale(0.25); + // parallel insertion + let mut start = ProcessTime::now(); + let mut now = SystemTime::now(); + let data_for_par_insertion = anndata + .train_data + .iter() + .map(|x| (x.0.as_slice(), x.1)) + .collect(); + if parallel { + println!(" \n parallel insertion"); + hnsw.parallel_insert_slice(&data_for_par_insertion); + } else { + println!(" \n serial insertion"); + for d in data_for_par_insertion { + hnsw.insert_slice(d); + } + } + let mut cpu_time: Duration = start.elapsed(); + // + println!( + "\n hnsw data insertion cpu time {:?} system time {:?} ", + cpu_time, + now.elapsed() + ); + hnsw.dump_layer_info(); + println!(" hnsw data nb point inserted {:?}", hnsw.get_nb_point()); + // + // Now the bench with 10 neighbours + // + let mut recalls = Vec::::with_capacity(nb_elem); + let mut nb_returned = Vec::::with_capacity(nb_elem); + let mut last_distances_ratio = Vec::::with_capacity(nb_elem); + let mut knn_neighbours_for_tests = Vec::>::with_capacity(nb_elem); + hnsw.set_searching_mode(true); + let knbn = 10; + let ef_c = max_nb_connection; + println!("\n searching with ef : {:?}", ef_c); + start = ProcessTime::now(); + now = SystemTime::now(); + // search + parallel = true; + if parallel { + println!(" \n parallel search"); + knn_neighbours_for_tests = hnsw.parallel_search(&anndata.test_data, knbn, ef_c); + } else { + println!(" \n serial search"); + for i in 0..anndata.test_data.len() { + let knn_neighbours: Vec = hnsw.search(&anndata.test_data[i], knbn, ef_c); + knn_neighbours_for_tests.push(knn_neighbours); + } + } + cpu_time = start.elapsed(); + let search_sys_time = now.elapsed().unwrap().as_micros() as f32; + let search_cpu_time = cpu_time.as_micros() as f32; + println!( + "total cpu time for search requests {:?} , system time {:?} ", + search_cpu_time, search_sys_time + ); + // now compute recall rate + for i in 0..anndata.test_data.len() { + let true_distances = anndata.test_distances.row(i); + let max_dist = true_distances[knbn - 1]; + let mut _knn_neighbours_id: Vec = + knn_neighbours_for_tests[i].iter().map(|p| p.d_id).collect(); + let knn_neighbours_dist: Vec = knn_neighbours_for_tests[i] + .iter() + .map(|p| p.distance) + .collect(); + nb_returned.push(knn_neighbours_dist.len()); + // count how many distances of knn_neighbours_dist are less than + let recall = knn_neighbours_dist + .iter() + .filter(|x| *x <= &max_dist) + .count(); + recalls.push(recall); + let mut ratio = 0.; + if !knn_neighbours_dist.is_empty() { + ratio = knn_neighbours_dist[knn_neighbours_dist.len() - 1] / max_dist; + } + last_distances_ratio.push(ratio); + } + let mean_recall = (recalls.iter().sum::() as f32) / ((knbn * recalls.len()) as f32); + println!( + "\n mean fraction nb returned by search {:?} ", + (nb_returned.iter().sum::() as f32) / ((nb_returned.len() * knbn) as f32) + ); + println!( + "\n last distances ratio {:?} ", + last_distances_ratio.iter().sum::() / last_distances_ratio.len() as f32 + ); + println!( + "\n recall rate for {:?} is {:?} , nb req /s {:?}", + anndata.fname, + mean_recall, + (nb_search as f32) * 1.0e+6_f32 / search_sys_time + ); +} diff --git a/patches/hnsw_rs/examples/ann-sift1m-128-euclidean.rs b/patches/hnsw_rs/examples/ann-sift1m-128-euclidean.rs new file mode 100644 index 000000000..fa2ef9987 --- /dev/null +++ b/patches/hnsw_rs/examples/ann-sift1m-128-euclidean.rs @@ -0,0 +1,196 @@ +#![allow(clippy::needless_range_loop)] + +use cpu_time::ProcessTime; +use env_logger::Builder; +use std::time::{Duration, SystemTime}; + +use anndists::dist::*; +use log::info; + +// search in paralle mode 8 core i7-10875H @2.3Ghz time 100 neighbours + +// max_nb_conn ef_cons ef_search scale_factor extend keep pruned recall req/s last ratio +// +// 64 800 64 1 0 0 0.976 4894 1.001 +// 64 800 128 1 0 0 0.985 3811 1.00064 +// 64 800 128 1 1 0 0.9854 3765 1.0 + +// 64 1600 64 1 0 0 0.9877 3419. 1.0005 + +// search in parallel mode 8 core i7-10875H @2.3Ghz time for 10 neighbours + +// 64 1600 64 1 0 0 0.9907 6100 1.0004 +// 64 1600 128 1 0 0 0.9959 3077. 1.0001 + +// 24 core Core(TM) i9-13900HX simdeez + +// 64 1600 64 1 0 0 0.9907 15258 1.0004 +// 64 1600 128 1 0 0 0.9957 8296 1.0002 + +// 24 core Core(TM) i9-13900HX simdeez with level scale modification factor 0.5 +//============================================================================= + +// 48 1600 64 0.5 0 0 0.9938 14073 1.0002 +// 48 1600 128 0.5 0 0 0.9992 7906 1.0000 + +// with an AMD ryzen 9 7950X 16-Core simdeez with level scale modification factor 0.5 +//============================================================================= +// 48 1600 64 0.5 0 0 0.9938 17000 1.0002 +// 48 1600 128 0.5 0 0 0.9992 9600 1.0000 + +use hnsw_rs::prelude::*; + +mod utils; +use utils::*; + +pub fn main() { + // + Builder::from_default_env().init(); + // + let parallel = true; + // + let fname = String::from("/home/jpboth/Data/ANN/sift1m-128-euclidean.hdf5"); + println!("\n\n test_load_hdf5 {:?}", fname); + // now recall that data are stored in row order. + let anndata = annhdf5::AnnBenchmarkData::new(fname).unwrap(); + // run bench + let knbn_max = anndata.test_distances.dim().1; + let nb_elem = anndata.train_data.len(); + info!( + " train size : {}, test size : {}", + nb_elem, + anndata.test_data.len() + ); + info!(" nb neighbours answers for test data : {}", knbn_max); + // + let max_nb_connection = 48; + let nb_layer = 16.min((nb_elem as f32).ln().trunc() as usize); + let ef_c = 1600; + // + println!( + " number of elements to insert {:?} , setting max nb layer to {:?} ef_construction {:?}", + nb_elem, nb_layer, ef_c + ); + println!( + " =====================================================================================" + ); + // + let mut hnsw = Hnsw::::new(max_nb_connection, nb_elem, nb_layer, ef_c, DistL2 {}); + // + let extend_flag = false; + info!("extend flag = {:?} ", extend_flag); + hnsw.set_extend_candidates(extend_flag); + hnsw.modify_level_scale(0.5); + // + // parallel insertion + let start = ProcessTime::now(); + let now = SystemTime::now(); + let data_for_par_insertion = anndata + .train_data + .iter() + .map(|x| (x.0.as_slice(), x.1)) + .collect(); + if parallel { + println!(" \n parallel insertion"); + hnsw.parallel_insert_slice(&data_for_par_insertion); + } else { + println!(" \n serial insertion"); + for d in data_for_par_insertion { + hnsw.insert_slice(d); + } + } + let cpu_time: Duration = start.elapsed(); + // + println!( + "\n hnsw data insertion cpu time {:?} system time {:?} ", + cpu_time, + now.elapsed() + ); + hnsw.dump_layer_info(); + println!(" hnsw data nb point inserted {:?}", hnsw.get_nb_point()); + // + // + let knbn = 10.min(knbn_max); + let ef_search = 64; + println!("searching with ef = {}", ef_search); + search(&mut hnsw, knbn, ef_search, &anndata); + // + println!("searching with ef = {}", ef_search); + let ef_search = 128; + search(&mut hnsw, knbn, ef_search, &anndata); +} + +pub fn search( + hnsw: &mut Hnsw, + knbn: usize, + ef_search: usize, + anndata: &annhdf5::AnnBenchmarkData, +) where + Dist: Distance + Send + Sync, +{ + println!("\n\n ef_search : {:?} knbn : {:?} ", ef_search, knbn); + let parallel = true; + // + let nb_elem = anndata.train_data.len(); + let nb_search = anndata.test_data.len(); + // + let mut recalls = Vec::::with_capacity(nb_elem); + let mut nb_returned = Vec::::with_capacity(nb_elem); + let mut last_distances_ratio = Vec::::with_capacity(nb_elem); + let mut knn_neighbours_for_tests = Vec::>::with_capacity(nb_elem); + hnsw.set_searching_mode(true); + println!("searching with ef : {:?}", ef_search); + let start = ProcessTime::now(); + let now = SystemTime::now(); + // search + if parallel { + println!(" \n parallel search"); + knn_neighbours_for_tests = hnsw.parallel_search(&anndata.test_data, knbn, ef_search); + } else { + println!(" \n serial search"); + for i in 0..anndata.test_data.len() { + let knn_neighbours: Vec = + hnsw.search(&anndata.test_data[i], knbn, ef_search); + knn_neighbours_for_tests.push(knn_neighbours); + } + } + let cpu_time = start.elapsed(); + let search_cpu_time = cpu_time.as_micros() as f32; + let search_sys_time = now.elapsed().unwrap().as_micros() as f32; + println!( + "total cpu time for search requests {:?} , system time {:?} ", + search_cpu_time, + now.elapsed() + ); + // now compute recall rate + for i in 0..anndata.test_data.len() { + let max_dist = anndata.test_distances.row(i)[knbn - 1]; + let knn_neighbours_d: Vec = knn_neighbours_for_tests[i] + .iter() + .map(|p| p.distance) + .collect(); + nb_returned.push(knn_neighbours_d.len()); + let recall = knn_neighbours_d.iter().filter(|d| *d <= &max_dist).count(); + recalls.push(recall); + let mut ratio = 0.; + if !knn_neighbours_d.is_empty() { + ratio = knn_neighbours_d[knn_neighbours_d.len() - 1] / max_dist; + } + last_distances_ratio.push(ratio); + } + let mean_recall = (recalls.iter().sum::() as f32) / ((knbn * recalls.len()) as f32); + println!( + "\n mean fraction nb returned by search {:?} ", + (nb_returned.iter().sum::() as f32) / ((nb_returned.len() * knbn) as f32) + ); + println!( + "\n last distances ratio {:?} ", + last_distances_ratio.iter().sum::() / last_distances_ratio.len() as f32 + ); + println!( + "\n recall rate for {:?} is {:?} , nb req /s {:?}", + anndata.fname, + mean_recall, + (nb_search as f32) * 1.0e+6_f32 / search_sys_time + ); +} // end of search diff --git a/patches/hnsw_rs/examples/levensthein.rs b/patches/hnsw_rs/examples/levensthein.rs new file mode 100644 index 000000000..eb7b9ec08 --- /dev/null +++ b/patches/hnsw_rs/examples/levensthein.rs @@ -0,0 +1,63 @@ +use anndists::dist::*; + +use hnsw_rs::prelude::*; +use rand::Rng; +use std::iter; + +fn generate(len: usize) -> String { + const CHARSET: &[u8] = b"abcdefghij"; + let mut rng = rand::rng(); + let one_char = || CHARSET[rng.random_range(0..CHARSET.len())] as char; + iter::repeat_with(one_char).take(len).collect() +} + +fn main() { + let nb_elem = 500000; // number of possible words in the dictionary + let max_nb_connection = 15; + let nb_layer = 16.min((nb_elem as f32).ln().trunc() as usize); + let ef_c = 200; + let nb_words = 1000; + let hns = Hnsw::::new( + max_nb_connection, + nb_elem, + nb_layer, + ef_c, + DistLevenshtein {}, + ); + let mut words = vec![]; + for _n in 1..nb_words { + let tw = generate(5); + words.push(tw); + } + words.push(String::from("abcdj")); + // + for (i, w) in words.iter().enumerate() { + let vec: Vec = w.chars().map(|c| c as u16).collect(); + hns.insert((&vec, i)); + } + // create a filter + let mut filter: Vec = Vec::new(); + for i in 1..100 { + filter.push(i); + } + // + let ef_search: usize = 30; + let tosearch: Vec = "abcde".chars().map(|c| c as u16).collect(); + // + println!("========== search with filter "); + let res = hns.search_filter(&tosearch, 10, ef_search, Some(&filter)); + for r in res { + println!( + "Word: {:?} Id: {:?} Distance: {:?}", + words[r.d_id], r.d_id, r.distance + ); + } + println!("========== search without filter "); + let res3 = hns.search(&tosearch, 10, ef_search); + for r in res3 { + println!( + "Word: {:?} Id: {:?} Distance: {:?}", + words[r.d_id], r.d_id, r.distance + ); + } +} diff --git a/patches/hnsw_rs/examples/random.rs b/patches/hnsw_rs/examples/random.rs new file mode 100644 index 000000000..9abba0e28 --- /dev/null +++ b/patches/hnsw_rs/examples/random.rs @@ -0,0 +1,80 @@ +#![allow(clippy::needless_range_loop)] +#![allow(clippy::range_zip_with_len)] + +use cpu_time::ProcessTime; +use rand::distr::Uniform; +use rand::prelude::*; +use std::time::{Duration, SystemTime}; + +use anndists::dist::*; +use hnsw_rs::prelude::*; + +fn main() { + env_logger::Builder::from_default_env().init(); + // + let nb_elem = 500000; + let dim = 25; + // generate nb_elem colmuns vectors of dimension dim + let mut rng = rand::rng(); + let unif = rand::distr::StandardUniform; + let mut data = Vec::with_capacity(nb_elem); + for _ in 0..nb_elem { + let column = (0..dim).map(|_| rng.sample(unif)).collect::>(); + data.push(column); + } + // give an id to each data + let data_with_id = data.iter().zip(0..data.len()).collect::>(); + + let ef_c = 200; + let max_nb_connection = 15; + let nb_layer = 16.min((nb_elem as f32).ln().trunc() as usize); + let hns = Hnsw::::new(max_nb_connection, nb_elem, nb_layer, ef_c, DistL2 {}); + let mut start = ProcessTime::now(); + let mut begin_t = SystemTime::now(); + hns.parallel_insert(&data_with_id); + let mut cpu_time: Duration = start.elapsed(); + println!(" hnsw data insertion cpu time {:?}", cpu_time); + println!( + " hnsw data insertion parallel, system time {:?} \n", + begin_t.elapsed().unwrap() + ); + hns.dump_layer_info(); + println!( + " parallel hnsw data nb point inserted {:?}", + hns.get_nb_point() + ); + // + // serial insertion + // + let hns = Hnsw::::new(max_nb_connection, nb_elem, nb_layer, ef_c, DistL2 {}); + start = ProcessTime::now(); + begin_t = SystemTime::now(); + for _i in 0..data_with_id.len() { + hns.insert((data_with_id[_i].0.as_slice(), data_with_id[_i].1)) + } + cpu_time = start.elapsed(); + println!("\n\n serial hnsw data insertion {:?}", cpu_time); + println!( + " hnsw data insertion serial, system time {:?}", + begin_t.elapsed().unwrap() + ); + hns.dump_layer_info(); + println!( + " serial hnsw data nb point inserted {:?}", + hns.get_nb_point() + ); + + let ef_search = max_nb_connection * 2; + let knbn = 10; + // + for _iter in 0..100 { + let mut r_vec = Vec::::with_capacity(dim); + let mut rng = rand::rng(); + let unif = Uniform::::new(0., 1.).unwrap(); + for _ in 0..dim { + r_vec.push(rng.sample(unif)); + } + // + let _neighbours = hns.search(&r_vec, knbn, ef_search); + } +} diff --git a/patches/hnsw_rs/examples/utils/annhdf5.rs b/patches/hnsw_rs/examples/utils/annhdf5.rs new file mode 100644 index 000000000..a827479dd --- /dev/null +++ b/patches/hnsw_rs/examples/utils/annhdf5.rs @@ -0,0 +1,233 @@ +//! This file provides hdf5 utilities to load ann-benchmarks hdf5 data files +//! As the libray does not depend on hdf5 nor on ndarray, it is nearly the same for both +//! ann benchmarks. + +use ndarray::Array2; + +use ::hdf5::*; +use log::debug; + +// datasets +// . distances (nbojects, dim) f32 matrix for tests objects +// . neighbors (nbobjects, nbnearest) int32 matrix giving the num of nearest neighbors in train data +// . test (nbobjects, dim) f32 matrix test data +// . train (nbobjects, dim) f32 matrix train data + +/// a structure to load hdf5 data file benchmarks from https://github.com/erikbern/ann-benchmarks +pub struct AnnBenchmarkData { + pub fname: String, + /// distances from each test object to its nearest neighbours. + pub test_distances: Array2, + // for each test data , id of its nearest neighbours + #[allow(unused)] + pub test_neighbours: Array2, + /// list of vectors for which we will search ann. + pub test_data: Vec>, + /// list of data vectors and id + pub train_data: Vec<(Vec, usize)>, + /// searched results. first neighbours for each test data. + #[allow(unused)] + pub searched_neighbours: Vec>, + /// distances of neighbours obtained of each test + #[allow(unused)] + pub searched_distances: Vec>, +} + +impl AnnBenchmarkData { + pub fn new(fname: String) -> Result { + let res = hdf5::File::open(fname.clone()); + if res.is_err() { + println!("you must download file {:?}", fname); + panic!( + "download benchmark file some where and modify examples source file accordingly" + ); + } + let file = res.ok().unwrap(); + // + // get test distances + // + let res_distances = file.dataset("distances"); + if res_distances.is_err() { + // let reader = hdf5::Reader::::new(&test_distance); + panic!("error getting distances dataset"); + } + let distances = res_distances.unwrap(); + let shape = distances.shape(); + assert_eq!(shape.len(), 2); + let dataf32 = distances.dtype().unwrap().is::(); + if !dataf32 { + // error + panic!("error getting type distances dataset"); + } + // read really data + let res = distances.read_2d::(); + if res.is_err() { + // some error + panic!("error reading distances dataset"); + } + let test_distances = res.unwrap(); + // a check for row order + debug!( + "First 2 distances for first test {:?} {:?} ", + test_distances.get((0, 0)).unwrap(), + test_distances.get((0, 1)).unwrap() + ); + // + // read neighbours + // + let res_neighbours = file.dataset("neighbors"); + if res_neighbours.is_err() { + // let reader = hdf5::Reader::::new(&test_distance); + panic!("error getting neighbours"); + } + let neighbours = res_neighbours.unwrap(); + let shape = neighbours.shape(); + assert_eq!(shape.len(), 2); + println!("neighbours shape : {:?}", shape); + let datai32 = neighbours.dtype().unwrap().is::(); + if !datai32 { + // error + panic!("error getting type neighbours"); + } + // read really data + let res = neighbours.read_2d::(); + if res.is_err() { + // some error + panic!("error reading neighbours dataset"); + } + let test_neighbours = res.unwrap(); + debug!( + "First 2 neighbours for first test {:?} {:?} ", + test_neighbours.get((0, 0)).unwrap(), + test_neighbours.get((0, 1)).unwrap() + ); + println!("\n 10 first neighbours for first vector : "); + for i in 0..10 { + print!(" {:?} ", test_neighbours.get((0, i)).unwrap()); + } + println!("\n 10 first neighbours for second vector : "); + for i in 0..10 { + print!(" {:?} ", test_neighbours.get((1, i)).unwrap()); + } + // + // read test data + // =============== + // + let res_testdata = file.dataset("test"); + if res_testdata.is_err() { + panic!("error getting test de notataset"); + } + let test_data = res_testdata.unwrap(); + let shape = test_data.shape(); // nota shape returns a slice, dim returns a t-uple + assert_eq!(shape.len(), 2); + let dataf32 = test_data.dtype().unwrap().is::(); + if !dataf32 { + panic!("error getting type de notistances dataset"); + } + // read really datae not + let res = test_data.read_2d::(); + if res.is_err() { + // some error + panic!("error reading distances dataset"); + } + let test_data_2d = res.unwrap(); + let mut test_data = Vec::>::with_capacity(shape[1]); + let (nbrow, nbcolumn) = test_data_2d.dim(); + println!(" test data, nb element {:?}, dim : {:?}", nbrow, nbcolumn); + for i in 0..nbrow { + let mut vec = Vec::with_capacity(nbcolumn); + for j in 0..nbcolumn { + vec.push(*test_data_2d.get((i, j)).unwrap()); + } + test_data.push(vec); + } + // + // loaf train data + // + let res_traindata = file.dataset("train"); + if res_traindata.is_err() { + panic!("error getting distances dataset"); + } + let train_data = res_traindata.unwrap(); + let train_shape = train_data.shape(); + assert_eq!(shape.len(), 2); + if test_data_2d.dim().1 != train_shape[1] { + println!("test and train have not the same dimension"); + panic!(); + } + println!( + "\n train data shape : {:?}, nbvector {:?} ", + train_shape, train_shape[0] + ); + let dataf32 = train_data.dtype().unwrap().is::(); + if !dataf32 { + // error + panic!("error getting type distances dataset"); + } + // read really data + let res = train_data.read_2d::(); + if res.is_err() { + // some error + panic!("error reading distances dataset"); + } + let train_data_2d = res.unwrap(); + let mut train_data = Vec::<(Vec, usize)>::with_capacity(shape[1]); + let (nbrow, nbcolumn) = train_data_2d.dim(); + for i in 0..nbrow { + let mut vec = Vec::with_capacity(nbcolumn); + for j in 0..nbcolumn { + vec.push(*train_data_2d.get((i, j)).unwrap()); + } + train_data.push((vec, i)); + } + // + // now allocate array's for result + // + println!( + " allocating vector for search neighbours answer : {:?}", + test_data.len() + ); + let searched_neighbours = Vec::>::with_capacity(test_data.len()); + let searched_distances = Vec::>::with_capacity(test_data.len()); + // searched_distances + Ok(AnnBenchmarkData { + fname: fname.clone(), + test_distances, + test_neighbours, + test_data, + train_data, + searched_neighbours, + searched_distances, + }) + } // end new + + /// do l2 normalisation of test and train vector to use DistDot metrinc instead DistCosine to spare cpu + #[allow(unused)] + pub fn do_l2_normalization(&mut self) { + for i in 0..self.test_data.len() { + anndists::dist::l2_normalize(&mut self.test_data[i]); + } + for i in 0..self.train_data.len() { + anndists::dist::l2_normalize(&mut self.train_data[i].0); + } + } // end of do_l2_normalization +} // end of impl block + +#[cfg(test)] + +mod tests { + + use super::*; + + #[test] + + fn test_load_hdf5() { + env_logger::Builder::from_default_env().init(); + // + let fname = String::from("/home.2/Data/ANN/glove-25-angular.hdf5"); + println!("\n\n test_load_hdf5 {:?}", fname); + // now recall that data are stored in row order. + let _anndata = AnnBenchmarkData::new(fname).unwrap(); + // + } // end of test_load_hdf5 +} // end of module test diff --git a/patches/hnsw_rs/examples/utils/mod.rs b/patches/hnsw_rs/examples/utils/mod.rs new file mode 100644 index 000000000..a9c20d551 --- /dev/null +++ b/patches/hnsw_rs/examples/utils/mod.rs @@ -0,0 +1,3 @@ +//! hdf5 utilities for examples + +pub mod annhdf5; diff --git a/patches/hnsw_rs/src/api.rs b/patches/hnsw_rs/src/api.rs new file mode 100644 index 000000000..5d75fbc53 --- /dev/null +++ b/patches/hnsw_rs/src/api.rs @@ -0,0 +1,87 @@ +//! Api for external language. +//! This file provides a trait to be used as an opaque pointer for C or Julia calls used in file libext.rs + +use std::path::Path; + +use serde::{de::DeserializeOwned, Serialize}; + +use crate::hnsw::*; +use crate::hnswio::*; +use anndists::dist::distances::Distance; +use log::info; + +pub trait AnnT { + /// type of data vectors + type Val; + // + fn insert_data(&mut self, data: &[Self::Val], id: usize); + // + fn search_neighbours(&self, data: &[Self::Val], knbn: usize, ef_s: usize) -> Vec; + // + fn parallel_insert_data(&mut self, data: &[(&Vec, usize)]); + // + fn parallel_search_neighbours( + &self, + data: &[Vec], + knbn: usize, + ef_s: usize, + ) -> Vec>; + /// + /// dumps a data and graph in 2 files. + /// Datas are dumped in file filename.hnsw.data and graph in filename.hnsw.graph + /// + /// **We do not overwrite old files if they are currently in use by memory map** + /// If these files already exist , they are not overwritten and a unique filename is generated by concatenating a random number to filename. + /// The function returns the basename used for the dump + fn file_dump(&self, path: &Path, file_basename: &str) -> anyhow::Result; +} + +impl AnnT for Hnsw<'_, T, D> +where + T: Serialize + DeserializeOwned + Clone + Send + Sync, + D: Distance + Send + Sync, +{ + type Val = T; + // + fn insert_data(&mut self, data: &[Self::Val], id: usize) { + self.insert((data, id)); + } + // + fn search_neighbours(&self, data: &[T], knbn: usize, ef_s: usize) -> Vec { + self.search(data, knbn, ef_s) + } + fn parallel_insert_data(&mut self, data: &[(&Vec, usize)]) { + self.parallel_insert(data); + } + + fn parallel_search_neighbours( + &self, + data: &[Vec], + knbn: usize, + ef_s: usize, + ) -> Vec> { + self.parallel_search(data, knbn, ef_s) + } + + // The main entry point to do a dump. + // It will generate two files one for the graph part of the data. The other for the real data points of the structure. + // The names of file are $filename.hnsw.graph for the graph and $filename.hnsw.data. + fn file_dump(&self, path: &Path, file_basename: &str) -> anyhow::Result { + info!("In Hnsw::file_dump"); + // + // do not overwrite if mmap is active + let overwrite = !self.get_datamap_opt(); + let mut dumpinit = DumpInit::new(path, file_basename, overwrite); + let dumpname = dumpinit.get_basename().clone(); + // + let res = self.dump(DumpMode::Full, &mut dumpinit); + // + dumpinit.flush()?; + info!("\n End of dump, file basename : {}\n", &dumpname); + if res.is_ok() { + Ok(dumpname) + } else { + Err(anyhow::anyhow!("unexpected error")) + } + } // end of dump +} // end of impl block AnnT for Hnsw diff --git a/patches/hnsw_rs/src/datamap.rs b/patches/hnsw_rs/src/datamap.rs new file mode 100644 index 000000000..45945f154 --- /dev/null +++ b/patches/hnsw_rs/src/datamap.rs @@ -0,0 +1,457 @@ +//! This module provides a memory mapping of Data vectors filling the Hnsw structure. +//! It is used by the module [hnswio] and also gives access to an iterator over data without loading the graph. +//! +//! We mmap the file and provide +//! - a Hashmap from DataId to address +//! - an interface for retrieving just data vectors loaded in the hnsw structure. + +use std::io::BufReader; + +use std::fs::{File, OpenOptions}; +use std::path::{Path, PathBuf}; + +use indexmap::map::IndexMap; +use log::{debug, error, info, trace}; +use mmap_rs::{Mmap, MmapOptions}; + +use crate::hnsw::DataId; +use crate::hnswio; + +use crate::hnswio::MAGICDATAP; +/// This structure uses the data part of the dump of a Hnsw structure to retrieve the data. +/// The data is access via a mmap of the data file, so memory is spared at the expense of page loading. +// possibly to be used in graph to spare memory? +pub struct DataMap { + /// File containing Points data + _datapath: PathBuf, + /// The mmap structure + mmap: Mmap, + /// map a dataId to an address where we get a bson encoded vector of type T + hmap: IndexMap, + /// type name of Data + t_name: String, + /// dimension of data vector + dimension: usize, + // + distname: String, +} // end of DataMap + +impl DataMap { + // TODO: specifiy mmap option + /// The fname argument corresponds to the basename of the dump. + /// To reload from file fname.hnsw.data just pass fname as argument. + /// The dir argument is the directory where the fname.hnsw.data and fname.hnsw.graph reside. + pub fn from_hnswdump( + dir: &Path, + file_name: &str, + ) -> Result { + // reload description to have data type, and check for dump version + let mut graphpath = PathBuf::from(dir); + graphpath.push(dir); + let mut filename = file_name.to_string(); + filename.push_str(".hnsw.graph"); + graphpath.push(filename); + let graphfileres = OpenOptions::new().read(true).open(&graphpath); + if graphfileres.is_err() { + println!("DataMap: could not open file {:?}", graphpath.as_os_str()); + std::process::exit(1); + } + let graphfile = graphfileres.unwrap(); + let mut graph_in = BufReader::new(graphfile); + // we need to call load_description first to get distance name + let hnsw_description = hnswio::load_description(&mut graph_in).unwrap(); + if hnsw_description.format_version <= 2 { + let msg = String::from( + "from_hnsw::from_hnsw : data mapping is only possible for dumps with the version > 0.1.19 of this crate", + ); + error!( + "Data mapping is only possible for dumps with the version > 0.1.19 of this crate" + ); + return Err(msg); + } + let distname = hnsw_description.distname.clone(); + let t_name = hnsw_description.get_typename(); + // check typename coherence + info!("Got typename from reload : {:?}", t_name); + if std::any::type_name::() != t_name { + error!( + "Description has typename {:?}, function type argument is : {:?}", + t_name, + std::any::type_name::() + ); + return Err(String::from("type error")); + } + // get dimension as declared in description + let descr_dimension = hnsw_description.get_dimension(); + drop(graph_in); + // + // we know data filename is hnswdump.hnsw.data + // + let mut datapath = PathBuf::new(); + datapath.push(dir); + let mut filename = file_name.to_string(); + filename.push_str(".hnsw.data"); + datapath.push(filename); + // + let meta = std::fs::metadata(&datapath); + if meta.is_err() { + error!("Could not open file : {:?}", &datapath); + std::process::exit(1); + } + let fsize = meta.unwrap().len().try_into().unwrap(); + // + let file_res = File::open(&datapath); + if file_res.is_err() { + error!("Could not open file : {:?}", &datapath); + std::process::exit(1); + } + let file = file_res.unwrap(); + let offset = 0; + // + let mmap_opt = MmapOptions::new(fsize).unwrap(); + let mmap_opt = unsafe { mmap_opt.with_file(&file, offset) }; + let mapping_res = mmap_opt.map(); + if mapping_res.is_err() { + error!("Could not memory map : {:?}", &datapath); + std::process::exit(1); + } + let mmap = mapping_res.unwrap(); + // + info!("Mmap done on file : {:?}", &datapath); + // + // where are we in decoding mmap slice? at beginning + // + let mapped_slice = mmap.as_slice(); + // + // where are we in decoding mmap slice? + let mut current_mmap_addr = 0usize; + let mut usize_slice = [0u8; std::mem::size_of::()]; + // check magic + let mut u32_slice = [0u8; std::mem::size_of::()]; + u32_slice.copy_from_slice( + &mapped_slice[current_mmap_addr..current_mmap_addr + std::mem::size_of::()], + ); + current_mmap_addr += std::mem::size_of::(); + let magic = u32::from_ne_bytes(u32_slice); + assert_eq!(magic, MAGICDATAP, "magic not equal to MAGICDATAP in mmap"); + // get dimension + usize_slice.copy_from_slice( + &mapped_slice[current_mmap_addr..current_mmap_addr + std::mem::size_of::()], + ); + current_mmap_addr += std::mem::size_of::(); + let dimension = usize::from_ne_bytes(usize_slice); + if dimension != descr_dimension { + error!( + "Description and data do not agree on dimension, data got : {:?}, description got : {:?}", + dimension, descr_dimension + ); + return Err(String::from( + "description and data do not agree on dimension", + )); + } else { + info!("Got dimension : {:?}", dimension); + } + // + // now we know that each record consists in + // - MAGICDATAP (u32), DataId (u64), dimension (u64) and then (length of type in bytes * dimension) + // + let record_size = std::mem::size_of::() + + 2 * std::mem::size_of::() + + dimension * std::mem::size_of::(); + let residual = mmap.size() - current_mmap_addr; + info!( + "Mmap size {}, current_mmap_addr {}, residual : {}", + mmap.size(), + current_mmap_addr, + residual + ); + let nb_record = residual / record_size; + debug!("Record size : {}, nb_record : {}", record_size, nb_record); + // allocate hmap with correct capacity + let mut hmap = IndexMap::::with_capacity(nb_record); + // fill hmap to have address of each data point in file + let mut u64_slice = [0u8; std::mem::size_of::()]; + // + // now we loop on records + // + for i in 0..nb_record { + debug!("Record i : {}, addr : {}", i, current_mmap_addr); + // decode Magic + u32_slice.copy_from_slice( + &mapped_slice[current_mmap_addr..current_mmap_addr + std::mem::size_of::()], + ); + current_mmap_addr += std::mem::size_of::(); + let magic = u32::from_ne_bytes(u32_slice); + assert_eq!(magic, MAGICDATAP, "magic not equal to MAGICDATAP in mmap"); + // decode DataId + u64_slice.copy_from_slice( + &mapped_slice[current_mmap_addr..current_mmap_addr + std::mem::size_of::()], + ); + current_mmap_addr += std::mem::size_of::(); + let data_id = u64::from_ne_bytes(u64_slice) as usize; + debug!( + "Inserting in hmap : got dataid : {:?} current map address : {:?}", + data_id, current_mmap_addr + ); + // Note we store address where we have to decode dimension*size_of:: and full bson encoded vector + hmap.insert(data_id, current_mmap_addr); + // now read serialized length + u64_slice.copy_from_slice( + &mapped_slice[current_mmap_addr..current_mmap_addr + std::mem::size_of::()], + ); + current_mmap_addr += std::mem::size_of::(); + let serialized_len = u64::from_ne_bytes(u64_slice) as usize; + if i == 0 { + debug!("serialized bytes len to reload {:?}", serialized_len); + } + let mut v_serialized = vec![0; serialized_len]; + v_serialized.copy_from_slice( + &mapped_slice[current_mmap_addr..current_mmap_addr + serialized_len], + ); + current_mmap_addr += serialized_len; + let slice_t = + unsafe { std::slice::from_raw_parts(v_serialized.as_ptr() as *const T, dimension) }; + trace!( + "Deserialized v : {:?} address : {:?} ", + slice_t, + v_serialized.as_ptr() as *const T + ); + } // end of for on record + // + debug!("End of DataMap::from_hnsw."); + // + let datamap = DataMap { + _datapath: datapath, + mmap, + hmap, + t_name, + dimension: descr_dimension, + distname, + }; + // + Ok(datamap) + } // end of from_datas + + // + + /// returns true if type T corresponds to type as retrieved in DataMap. + /// This function can (should!) be used before calling [Self::get_data()] + pub fn check_data_type(&self) -> bool + where + T: 'static + Sized, + { + // we check last part of name of type + let tname_vec = self.t_name.rsplit_terminator("::").collect::>(); + + if tname_vec.last().is_none() { + let errmsg = "DataMap::check_data_type() cannot determine data type name "; + error!("DataMap::check_data_type() cannot determine data type name "); + std::panic!("DataMap::check_data_type(), {}", errmsg); + } + let tname_last = tname_vec.last().unwrap(); + // + let datat_name_arg = std::any::type_name::().to_string(); + let datat_name_vec = datat_name_arg + .rsplit_terminator("::") + .collect::>(); + + let datat_name_arg_last = datat_name_vec.last().unwrap(); + // + if datat_name_arg_last == tname_last { + true + } else { + info!( + "Data type in DataMap : {}, type arg = {}", + tname_last, datat_name_arg_last + ); + false + } + } // end of check_data_type + + // + + /// return the data corresponding to dataid. Access is done using mmap. + /// Function returns None if address is invalid + /// This function requires you know the type T. + /// **As mmap loading calls an unsafe function it is recommended to check the type name with [Self::check_data_type()]** + pub fn get_data<'a, T: Clone + std::fmt::Debug>(&'a self, dataid: &DataId) -> Option<&'a [T]> { + // + trace!("In DataMap::get_data, dataid : {:?}", dataid); + let address = self.hmap.get(dataid)?; + debug!("Address for id : {}, address : {:?}", dataid, address); + let mut current_mmap_addr = *address; + let mapped_slice = self.mmap.as_slice(); + let mut u64_slice = [0u8; std::mem::size_of::()]; + u64_slice.copy_from_slice( + &mapped_slice[current_mmap_addr..current_mmap_addr + std::mem::size_of::()], + ); + let serialized_len = u64::from_ne_bytes(u64_slice) as usize; + current_mmap_addr += std::mem::size_of::(); + trace!("Serialized bytes len to reload {:?}", serialized_len); + let slice_t = unsafe { + std::slice::from_raw_parts( + mapped_slice[current_mmap_addr..].as_ptr() as *const T, + self.dimension, + ) + }; + Some(slice_t) + } + + /// returns Keys in order they are in the file, thus optimizing file/memory access. + /// Note that in case of parallel insertion this can be different from insertion odrer. + pub fn get_dataid_iter(&self) -> indexmap::map::Keys<'_, DataId, usize> { + self.hmap.keys() + } + + /// returns full data type name + pub fn get_data_typename(&self) -> String { + self.t_name.clone() + } + + /// returns full data type name + pub fn get_distname(&self) -> String { + self.distname.clone() + } + + /// return the number of data in mmap + pub fn get_nb_data(&self) -> usize { + self.hmap.len() + } +} // end of impl DataMap + +//===================================================================================== + +#[cfg(test)] + +mod tests { + + use super::*; + + use crate::hnswio::HnswIo; + use anndists::dist::*; + + pub use crate::api::AnnT; + use crate::prelude::*; + + use rand::distr::{Distribution, Uniform}; + + fn log_init_test() { + let _ = env_logger::builder().is_test(true).try_init(); + } + + #[test] + fn test_file_mmap() { + println!("\n\n test_file_mmap"); + log_init_test(); + // generate a random test + let mut rng = rand::rng(); + let unif = Uniform::::new(0., 1.).unwrap(); + // 1000 vectors of size 10 f32 + let nbcolumn = 50; + let nbrow = 11; + let mut xsi; + let mut data = Vec::with_capacity(nbcolumn); + for j in 0..nbcolumn { + data.push(Vec::with_capacity(nbrow)); + for _ in 0..nbrow { + xsi = unif.sample(&mut rng); + data[j].push(xsi); + } + debug!("j : {:?}, data : {:?} ", j, &data[j]); + } + // define hnsw + let ef_construct = 25; + let nb_connection = 10; + let hnsw = Hnsw::::new(nb_connection, nbcolumn, 16, ef_construct, DistL1 {}); + for (i, d) in data.iter().enumerate() { + hnsw.insert((d, i)); + } + // some loggin info + hnsw.dump_layer_info(); + // dump in a file. Must take care of name as tests runs in // !!! + let fname = "mmap_test"; + let directory = tempfile::tempdir().unwrap(); + let _res = hnsw.file_dump(directory.path(), fname); + + let check_reload = false; + if check_reload { + // We check we can reload + debug!("HNSW reload."); + let directory = tempfile::tempdir().unwrap(); + let mut reloader = HnswIo::new(directory.path(), fname); + let hnsw_loaded: Hnsw = reloader.load_hnsw::().unwrap(); + check_graph_equality(&hnsw_loaded, &hnsw); + info!("========= reload success, going to mmap reloading ========="); + } + // + // now we have check that datamap seems ok, test reload of hnsw with mmap + let datamap: DataMap = DataMap::from_hnswdump::(directory.path(), fname).unwrap(); + let nb_test = 30; + info!("Checking random access of id , nb test : {}", nb_test); + for _ in 0..nb_test { + // sample an id in 0..nb_data + let unif = Uniform::::new(0, nbcolumn).unwrap(); + let id = unif.sample(&mut rng); + let d = datamap.get_data::(&id); + assert!(d.is_some()); + if d.is_some() { + debug!("id = {}, v = {:?}", id, d.as_ref().unwrap()); + assert_eq!(d.as_ref().unwrap(), &data[id]); + } + } + // test iterator from datamap + let keys = datamap.get_dataid_iter(); + for k in keys { + let _data = datamap.get_data::(k); + } + } // end of test_file_mmap + + #[test] + fn test_mmap_iter() { + log_init_test(); + // generate a random test + let mut rng = rand::rng(); + let unif = Uniform::::new(0, 10000).unwrap(); + // 1000 vectors of size 10 f32 + let nbcolumn = 50; + let nbrow = 11; + let mut xsi; + let mut data = Vec::with_capacity(nbcolumn); + for j in 0..nbcolumn { + data.push(Vec::with_capacity(nbrow)); + for _ in 0..nbrow { + xsi = unif.sample(&mut rng); + data[j].push(xsi); + } + debug!("j : {:?}, data : {:?} ", j, &data[j]); + } + // define hnsw + let ef_construct = 25; + let nb_connection = 10; + let hnsw = Hnsw::::new(nb_connection, nbcolumn, 16, ef_construct, DistL1 {}); + for (i, d) in data.iter().enumerate() { + hnsw.insert((d, i)); + } + // some loggin info + hnsw.dump_layer_info(); + // dump in a file. Must take care of name as tests runs in // !!! + let fname = "mmap_order_test"; + let directory = tempfile::tempdir().unwrap(); + let _res = hnsw.file_dump(directory.path(), fname); + // now we have check that datamap seems ok, test reload of hnsw with mmap + let datamap: DataMap = DataMap::from_hnswdump::(directory.path(), fname).unwrap(); + // testing type check + assert!(datamap.check_data_type::()); + assert!(!datamap.check_data_type::()); + info!("Datamap iteration order checking"); + let keys = datamap.get_dataid_iter(); + for (i, dataid) in keys.enumerate() { + let v = datamap.get_data::(dataid).unwrap(); + assert_eq!(v, &data[*dataid], "dataid = {}, ukey = {}", dataid, i); + } + // rm files generated! + let _ = std::fs::remove_file("mmap_order_test.hnsw.data"); + let _ = std::fs::remove_file("mmap_order_test.hnsw.graph"); + } + // +} // end of mod tests diff --git a/patches/hnsw_rs/src/filter.rs b/patches/hnsw_rs/src/filter.rs new file mode 100644 index 000000000..e891e4cd2 --- /dev/null +++ b/patches/hnsw_rs/src/filter.rs @@ -0,0 +1,24 @@ +//! defines a trait for filtering requests. +//! See examples in tests/filtertest.rs + +use crate::prelude::DataId; + +/// Only queries returning true are taken into account along the search +pub trait FilterT { + fn hnsw_filter(&self, id: &DataId) -> bool; +} + +impl FilterT for Vec { + fn hnsw_filter(&self, id: &DataId) -> bool { + self.binary_search(id).is_ok() + } +} + +impl FilterT for F +where + F: Fn(&DataId) -> bool, +{ + fn hnsw_filter(&self, id: &DataId) -> bool { + self(id) + } +} diff --git a/patches/hnsw_rs/src/flatten.rs b/patches/hnsw_rs/src/flatten.rs new file mode 100644 index 000000000..21462c7bf --- /dev/null +++ b/patches/hnsw_rs/src/flatten.rs @@ -0,0 +1,200 @@ +//! This module provides conversion of a Point structure to a FlatPoint containing just the Id of a point +//! and those of its neighbours. +//! The whole Hnsw structure is then flattened into a Hashtable associating the data ID of a point to +//! its corresponding FlatPoint. +//! It can be used, for example, when reloading only the graph part of the data to have knowledge +//! of relative proximity of points as described just by their DataId +//! + +use hashbrown::HashMap; +use std::cmp::Ordering; + +use crate::hnsw; +use anndists::dist::distances::Distance; +use hnsw::*; +use log::error; + +// an ordering of Neighbour of a Point + +impl PartialEq for Neighbour { + fn eq(&self, other: &Neighbour) -> bool { + self.distance == other.distance + } // end eq +} + +impl Eq for Neighbour {} + +// order points by distance to self. +#[allow(clippy::non_canonical_partial_ord_impl)] +impl PartialOrd for Neighbour { + fn partial_cmp(&self, other: &Neighbour) -> Option { + self.distance.partial_cmp(&other.distance) + } // end cmp +} // end impl PartialOrd + +impl Ord for Neighbour { + fn cmp(&self, other: &Neighbour) -> Ordering { + if !self.distance.is_nan() && !other.distance.is_nan() { + self.distance.partial_cmp(&other.distance).unwrap() + } else { + panic!("got a NaN in a distance"); + } + } // end cmp +} + +/// a reduced version of point inserted in the Hnsw structure. +/// It contains original id of point as submitted to the struct Hnsw +/// an ordered (by distance) list of neighbours to the point +/// and it position in layers. +#[derive(Clone)] +pub struct FlatPoint { + /// an id coming from client using hnsw, should identify point uniquely + origin_id: DataId, + /// a point id identifying point as stored in our structure + p_id: PointId, + /// neighbours info + neighbours: Vec, +} + +impl FlatPoint { + /// returns the neighbours orderded by distance. + pub fn get_neighbours(&self) -> &Vec { + &self.neighbours + } + /// returns the origin id of the point + pub fn get_id(&self) -> DataId { + self.origin_id + } + // + pub fn get_p_id(&self) -> PointId { + self.p_id + } +} // end impl block for FlatPoint + +fn flatten_point(point: &Point) -> FlatPoint { + let neighbours = point.get_neighborhood_id(); + // now we flatten neighbours + let mut flat_neighbours = Vec::::new(); + for layer in neighbours { + for neighbour in layer { + flat_neighbours.push(neighbour); + } + } + flat_neighbours.sort_unstable(); + FlatPoint { + origin_id: point.get_origin_id(), + p_id: point.get_point_id(), + neighbours: flat_neighbours, + } +} // end of flatten_point + +/// A structure providing neighbourhood information of a point stored in the Hnsw structure given its DataId. +/// The structure uses the [FlatPoint] structure. +/// This structure can be obtained by FlatNeighborhood::from<&Hnsw> +pub struct FlatNeighborhood { + hash_t: HashMap, +} + +impl FlatNeighborhood { + /// get neighbour of a point given its id. + /// The neighbours are sorted in increasing distance from data_id. + pub fn get_neighbours(&self, p_id: DataId) -> Option> { + self.hash_t + .get(&p_id) + .map(|point| point.get_neighbours().clone()) + } +} // end impl block for FlatNeighborhood + +impl + Send + Sync> From<&Hnsw<'_, T, D>> + for FlatNeighborhood +{ + /// extract from the Hnsw strucure a hashtable mapping original DataId into a FlatPoint structure gathering its neighbourhood information. + /// Useful after reloading from a dump with T=NoData and D = NoDist as points are then reloaded with neighbourhood information only. + fn from(hnsw: &Hnsw) -> Self { + let mut hash_t = HashMap::new(); + let pt_iter = hnsw.get_point_indexation().into_iter(); + // + for point in pt_iter { + // println!("point : {:?}", _point.p_id); + let res_insert = hash_t.insert(point.get_origin_id(), flatten_point(&point)); + if let Some(old_point) = res_insert { + error!("2 points with same origin id {:?}", old_point.origin_id); + } + } + FlatNeighborhood { hash_t } + } +} // e,d of Fom implementation + +#[cfg(test)] + +mod tests { + + use super::*; + use anndists::dist::distances::*; + use log::debug; + + use crate::api::AnnT; + use crate::hnswio::*; + + use rand::distr::{Distribution, Uniform}; + + fn log_init_test() { + let _ = env_logger::builder().is_test(true).try_init(); + } + + #[test] + fn test_dump_reload_graph_flatten() { + println!("\n\n test_dump_reload_graph_flatten"); + log_init_test(); + // generate a random test + let mut rng = rand::rng(); + let unif = Uniform::::new(0., 1.).unwrap(); + // 1000 vectors of size 10 f32 + let nbcolumn = 1000; + let nbrow = 10; + let mut xsi; + let mut data = Vec::with_capacity(nbcolumn); + for j in 0..nbcolumn { + data.push(Vec::with_capacity(nbrow)); + for _ in 0..nbrow { + xsi = unif.sample(&mut rng); + data[j].push(xsi); + } + } + // define hnsw + let ef_construct = 25; + let nb_connection = 10; + let hnsw = Hnsw::::new(nb_connection, nbcolumn, 16, ef_construct, DistL1 {}); + for (i, d) in data.iter().enumerate() { + hnsw.insert((d, i)); + } + // some loggin info + hnsw.dump_layer_info(); + // get flat neighbours of point 3 + let neighborhood_before_dump = FlatNeighborhood::from(&hnsw); + let nbg_2_before = neighborhood_before_dump.get_neighbours(2).unwrap(); + println!("voisins du point 2 {:?}", nbg_2_before); + // dump in a file. Must take care of name as tests runs in // !!! + let fname = "dumpreloadtestflat"; + let directory = tempfile::tempdir().unwrap(); + let _res = hnsw.file_dump(directory.path(), fname); + // This will dump in 2 files named dumpreloadtest.hnsw.graph and dumpreloadtest.hnsw.data + // + // reload + debug!("HNSW reload"); + // we will need a procedural macro to get from distance name to its instantiation. + // from now on we test with DistL1 + let mut reloader = HnswIo::new(directory.path(), fname); + let hnsw_loaded: Hnsw = reloader.load_hnsw().unwrap(); + let neighborhood_after_dump = FlatNeighborhood::from(&hnsw_loaded); + let nbg_2_after = neighborhood_after_dump.get_neighbours(2).unwrap(); + println!("Neighbors of point 2 {:?}", nbg_2_after); + // test equality of neighborhood + assert_eq!(nbg_2_after.len(), nbg_2_before.len()); + for i in 0..nbg_2_before.len() { + assert_eq!(nbg_2_before[i].p_id, nbg_2_after[i].p_id); + assert_eq!(nbg_2_before[i].distance, nbg_2_after[i].distance); + } + check_graph_equality(&hnsw_loaded, &hnsw); + } // end of test_dump_reload +} // end module test diff --git a/patches/hnsw_rs/src/hnsw.rs b/patches/hnsw_rs/src/hnsw.rs new file mode 100644 index 000000000..444f33744 --- /dev/null +++ b/patches/hnsw_rs/src/hnsw.rs @@ -0,0 +1,1872 @@ +//! A rust implementation of Approximate NN search from: +//! Efficient and robust approximate nearest neighbour search using Hierarchical Navigable +//! small World graphs. +//! Yu. A. Malkov, D.A Yashunin 2016, 2018 + +use serde::{Deserialize, Serialize}; + +use cpu_time::ProcessTime; +use std::time::SystemTime; + +use std::cmp::Ordering; + +use parking_lot::{Mutex, RwLock, RwLockReadGuard}; +use rayon::prelude::*; +use std::sync::Arc; +use std::sync::mpsc::channel; + +use std::any::type_name; + +use hashbrown::HashMap; +#[allow(unused)] +use std::collections::HashSet; +use std::collections::binary_heap::BinaryHeap; + +use log::trace; +use log::{debug, info}; + +pub use crate::filter::FilterT; +use anndists::dist::distances::Distance; + +// TODO +// Profiling. + +/// This unit structure provides the type to instanciate Hnsw with, +/// to get reload of graph only in the the structure. +/// It must be associated to the unit structure dist::NoDist for the distance type to provide. +#[derive(Default, Clone, Copy, Serialize, Deserialize, Debug)] +pub struct NoData; + +/// maximum number of layers +pub(crate) const NB_LAYER_MAX: u8 = 16; // so max layer is 15!! + +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +/// The 2-uple represent layer as u8 and rank in layer as a i32 as stored in our structure +pub struct PointId(pub u8, pub i32); + +/// this type is for an identificateur of each data vector, given by client. +/// Can be the rank of data in an array, a hash value or anything that permits +/// retrieving the data. +pub type DataId = usize; + +pub type PointDistance = Box>; + +/// A structure containing internal pointId with distance to this pointId. +/// The order is given by ordering the distance to the point it refers to. +/// So points ordering has a meaning only has points refers to the same point +#[derive(Debug, Clone, Copy)] +pub struct PointIdWithOrder { + /// the identificateur of the point for which we store a distance + pub point_id: PointId, + /// The distance to a reference point (not represented in the structure) + pub dist_to_ref: f32, +} + +impl PartialEq for PointIdWithOrder { + fn eq(&self, other: &PointIdWithOrder) -> bool { + self.dist_to_ref == other.dist_to_ref + } // end eq +} + +// order points by distance to self. +impl PartialOrd for PointIdWithOrder { + fn partial_cmp(&self, other: &PointIdWithOrder) -> Option { + self.dist_to_ref.partial_cmp(&other.dist_to_ref) + } // end cmp +} // end impl PartialOrd + +impl From<&PointWithOrder<'_, T>> for PointIdWithOrder { + fn from(point: &PointWithOrder) -> PointIdWithOrder { + PointIdWithOrder::new(point.point_ref.p_id, point.dist_to_ref) + } +} + +impl PointIdWithOrder { + pub fn new(point_id: PointId, dist_to_ref: f32) -> Self { + PointIdWithOrder { + point_id, + dist_to_ref, + } + } +} // end of impl block + +//======================================================================================= +/// The struct giving an answer point to a search request. +/// This structure is exported to other language API. +/// First field is origin id of the request point, second field is distance to request point +#[repr(C)] +#[derive(Debug, Copy, Clone, Default)] +pub struct Neighbour { + /// identification of data vector as given in initializing hnsw + pub d_id: DataId, + /// distance of neighbours + pub distance: f32, + /// point identification inside layers + pub p_id: PointId, +} + +impl Neighbour { + pub fn new(d_id: DataId, distance: f32, p_id: PointId) -> Neighbour { + Neighbour { + d_id, + distance, + p_id, + } + } + /// retrieves original id of neighbour as given in hnsw initialization + pub fn get_origin_id(&self) -> DataId { + self.d_id + } + /// return the distance + pub fn get_distance(&self) -> f32 { + self.distance + } +} + +//======================================================================================= + +#[derive(Debug, Clone)] +enum PointData<'b, T: Clone + Send + Sync + 'b> { + // full data + V(Vec), + // areference to a mmaped slice + S(&'b [T]), +} // end of enum PointData + +impl<'b, T: Clone + Send + Sync + 'b> PointData<'b, T> { + // allocate a point stored in structure + fn new_v(v: Vec) -> Self { + PointData::V(v) + } + + // allocate a point representation a memory mapped slice + fn new_s(s: &'b [T]) -> Self { + PointData::S(s) + } + + fn get_v(&self) -> &[T] { + match self { + PointData::V(v) => v.as_slice(), + PointData::S(s) => s, + } + } // end of get_v +} // end of impl block for PointData + +/// The basestructure representing a data point. +/// Its constains data as coming from the client, its client id, +/// and position in layer representation and neighbours. +/// +// neighbours table : one vector by layer so neighbours is allocated to NB_LAYER_MAX +// +#[derive(Debug, Clone)] +#[allow(clippy::type_complexity)] +pub struct Point<'b, T: Clone + Send + Sync> { + /// The data of this point, coming from hnsw client and associated to origin_id, + data: PointData<'b, T>, + /// an id coming from client using hnsw, should identify point uniquely + origin_id: DataId, + /// a point id identifying point as stored in our structure + p_id: PointId, + /// neighbours info + pub(crate) neighbours: Arc>>>>>, +} + +impl<'b, T: Clone + Send + Sync> Point<'b, T> { + pub fn new(v: Vec, origin_id: usize, p_id: PointId) -> Self { + let mut neighbours = Vec::with_capacity(NB_LAYER_MAX as usize); + // CAVEAT, perhaps pass nb layer as arg ? + for _ in 0..NB_LAYER_MAX { + neighbours.push(Vec::>>::new()); + } + Point { + data: PointData::new_v(v), + origin_id, + p_id, + neighbours: Arc::new(RwLock::new(neighbours)), + } + } + + pub fn new_from_mmap(s: &'b [T], origin_id: usize, p_id: PointId) -> Self { + let mut neighbours = Vec::with_capacity(NB_LAYER_MAX as usize); + // CAVEAT, perhaps pass nb layer as arg ? + for _ in 0..NB_LAYER_MAX { + neighbours.push(Vec::>>::new()); + } + Point { + data: PointData::new_s(s), + origin_id, + p_id, + neighbours: Arc::new(RwLock::new(neighbours)), + } + } + + /// get a reference to vector data + pub fn get_v(&self) -> &[T] { + self.data.get_v() + } + + /// return coordinates in indexation + pub fn get_point_id(&self) -> PointId { + self.p_id + } + + /// returns external (or client id) id of point + pub fn get_origin_id(&self) -> usize { + self.origin_id + } + + /// returns for each layer, a vector Neighbour of a point, one vector by layer + /// useful for extern crate only as it reallocates vectors + pub fn get_neighborhood_id(&self) -> Vec> { + let ref_neighbours = self.neighbours.read(); + let nb_layer = ref_neighbours.len(); + let mut neighborhood = Vec::>::with_capacity(nb_layer); + for i in 0..nb_layer { + let mut neighbours = Vec::::new(); + let nb_ngbh = ref_neighbours[i].len(); + if nb_ngbh > 0usize { + neighbours.reserve(nb_ngbh); + for pointwo in &ref_neighbours[i] { + neighbours.push(Neighbour::new( + pointwo.point_ref.get_origin_id(), + pointwo.dist_to_ref, + pointwo.point_ref.get_point_id(), + )); + } + } + neighborhood.push(neighbours); + } + neighborhood + } + + /// prints minimal information on neighbours of point. + pub fn debug_dump(&self) { + println!(" \n dump of point id : {:?}", self.p_id); + println!("\n origin id : {:?} ", self.origin_id); + println!(" neighbours : ..."); + let ref_neighbours = self.neighbours.read(); + for i in 0..ref_neighbours.len() { + if !ref_neighbours[i].is_empty() { + println!("neighbours at layer {:?}", i); + for n in &ref_neighbours[i] { + println!(" {:?}", n.point_ref.p_id); + } + } + } + println!(" neighbours dump : end"); + } +} // end of block + +//=========================================================================================== + +/// A structure to store neighbours for of a point. +#[derive(Debug, Clone)] +pub(crate) struct PointWithOrder<'b, T: Clone + Send + Sync> { + /// the identificateur of the point for which we store a distance to a point for which + /// we made a request. + point_ref: Arc>, + /// The distance to a point_ref to the request point (not represented in the structure) + dist_to_ref: f32, +} + +impl PartialEq for PointWithOrder<'_, T> { + fn eq(&self, other: &PointWithOrder) -> bool { + self.dist_to_ref == other.dist_to_ref + } // end eq +} + +impl Eq for PointWithOrder<'_, T> {} + +// order points by distance to self. +#[allow(clippy::non_canonical_partial_ord_impl)] +impl PartialOrd for PointWithOrder<'_, T> { + fn partial_cmp(&self, other: &PointWithOrder) -> Option { + self.dist_to_ref.partial_cmp(&other.dist_to_ref) + } // end cmp +} // end impl PartialOrd + +impl Ord for PointWithOrder<'_, T> { + fn cmp(&self, other: &PointWithOrder) -> Ordering { + if !self.dist_to_ref.is_nan() && !other.dist_to_ref.is_nan() { + self.dist_to_ref.partial_cmp(&other.dist_to_ref).unwrap() + } else { + panic!("got a NaN in a distance"); + } + } // end cmp +} + +impl<'b, T: Clone + Send + Sync> PointWithOrder<'b, T> { + pub fn new(point_ref: &Arc>, dist_to_ref: f32) -> Self { + PointWithOrder { + point_ref: Arc::clone(point_ref), + dist_to_ref, + } + } +} // end of impl block + +//============================================================================================ + +// LayerGenerator +use rand::distributions::Uniform; +use rand::prelude::*; + +/// a struct to randomly generate a level for an item according to an exponential law +/// of parameter given by scale. +/// The distribution is constrained to be in [0..maxlevel[ +pub struct LayerGenerator { + rng: Arc>, + unif: Uniform, + // drives number of levels generated ~ S + scale: f64, + maxlevel: usize, +} + +impl LayerGenerator { + pub fn new(max_nb_connection: usize, maxlevel: usize) -> Self { + let scale = 1. / (max_nb_connection as f64).ln(); + LayerGenerator { + rng: Arc::new(Mutex::new(StdRng::from_entropy())), + unif: Uniform::new(0., 1.), + scale, + maxlevel, + } + } + + // new when we know scale used. Should replace the one without scale + pub(crate) fn new_with_scale( + max_nb_connection: usize, + scale_factor: f64, + maxlevel: usize, + ) -> Self { + let scale_default = 1. / (max_nb_connection as f64).ln(); + LayerGenerator { + rng: Arc::new(Mutex::new(StdRng::from_entropy())), + unif: Uniform::new(0., 1.), + scale: scale_default * scale_factor, + maxlevel, + } + } + // + // l=0 most densely packed layer + // if S is scale we sample so that P(l=n) = exp(-n/S) - exp(- (n+1)/S) + // with S = 1./ln(max_nb_connection) P(l >= maxlevel) = exp(-maxlevel * ln(max_nb_connection)) + // for nb_conn = 10, even with maxlevel = 10, we get P(l >= maxlevel) = 1.E-13 + // In Malkov(2016) S = 1./log(max_nb_connection) + // + /// generate a layer with given maxlevel. upper layers (higher index) are of decreasing probabilities. + /// thread safe method. + fn generate(&self) -> usize { + let mut protected_rng = self.rng.lock(); + let xsi = protected_rng.sample(self.unif); + let level = -xsi.ln() * self.scale; + let mut ulevel = level.floor() as usize; + // we redispatch possibly sampled level >= maxlevel to required range + if ulevel >= self.maxlevel { + // This occurs with very low probability. Cf commentary above. + ulevel = protected_rng.sample(Uniform::new(0, self.maxlevel)); + } + ulevel + } + + /// just to try some variations on exponential level sampling. Unused. + fn set_scale_modification(&mut self, scale_modification: f64) { + self.scale *= scale_modification; + log::info!("using scale for sampling levels : {:.2e}", self.scale); + } + + // + fn get_level_scale(&self) -> f64 { + self.scale + } +} // end impl for LayerGenerator + +// ==================================================================== + +/// A short-hand for points in a layer +type Layer<'b, T> = Vec>>; + +/// a structure for indexation of points in layer +#[allow(unused)] +pub struct PointIndexation<'b, T: Clone + Send + Sync> { + /// max number of connection for a point at a layer + pub(crate) max_nb_connection: usize, + // + pub(crate) max_layer: usize, + /// needs at least one representation of points. points_by_layers\[i\] gives the points in layer i + pub(crate) points_by_layer: Arc>>>, + /// utility to generate a level + pub(crate) layer_g: LayerGenerator, + /// number of points in indexed structure + pub(crate) nb_point: Arc>, + /// curent enter_point: an Arc RwLock on a possible Arc Point + pub(crate) entry_point: Arc>>>>, +} + +// A point indexation may contain circular references. To deallocate these after a point indexation goes out of scope, +// implement the Drop trait. + +impl Drop for PointIndexation<'_, T> { + fn drop(&mut self) { + let cpu_start = ProcessTime::now(); + let sys_now = SystemTime::now(); + info!("entering PointIndexation drop"); + // clear_neighborhood. There are no point in neighborhoods that are not referenced directly in layers. + // so we cannot lose reference to a point by cleaning neighborhood + fn clear_neighborhoods(init: &Point) { + let mut neighbours = init.neighbours.write(); + let nb_layer = neighbours.len(); + for l in 0..nb_layer { + neighbours[l].clear(); + } + neighbours.clear(); + } + if let Some(i) = self.entry_point.write().as_ref() { + clear_neighborhoods(i.as_ref()); + } + // + let nb_level = self.get_max_level_observed(); + for l in 0..=nb_level { + trace!("clearing layer {}", l); + let layer = &mut self.points_by_layer.write()[l as usize]; + layer.into_par_iter().for_each(|p| clear_neighborhoods(p)); + layer.clear(); + } + // + debug!("clearing self.points_by_layer..."); + drop(self.points_by_layer.write()); + debug!("exiting PointIndexation drop"); + info!( + " drop sys time(s) {:?} cpu time {:?}", + sys_now.elapsed().unwrap().as_secs(), + cpu_start.elapsed().as_secs() + ); + } // end my drop +} // end implementation Drop + +impl<'b, T: Clone + Send + Sync> PointIndexation<'b, T> { + pub fn new(max_nb_connection: usize, max_layer: usize, max_elements: usize) -> Self { + let mut points_by_layer = Vec::with_capacity(max_layer); + for i in 0..max_layer { + // recall that range are right extremeity excluded + // compute fraction of points going into layer i and do expected memory reservation + let s = 1. / (max_nb_connection as f64).ln(); + let frac = (-(i as f64) / s).exp() - (-((i + 1) as f64) / s); + let expected_size = ((frac * max_elements as f64).round()) as usize; + points_by_layer.push(Vec::with_capacity(expected_size)); + } + let layer_g = LayerGenerator::new(max_nb_connection, max_layer); + PointIndexation { + max_nb_connection, + max_layer, + points_by_layer: Arc::new(RwLock::new(points_by_layer)), + layer_g, + nb_point: Arc::new(RwLock::new(0)), + entry_point: Arc::new(RwLock::new(None)), + } + } // end of new + + /// returns the maximum level of layer observed + pub fn get_max_level_observed(&self) -> u8 { + let opt = self.entry_point.read(); + match opt.as_ref() { + Some(arc_point) => arc_point.p_id.0, + None => 0, + } + } + + pub fn get_level_scale(&self) -> f64 { + self.layer_g.get_level_scale() + } + + fn debug_dump(&self) { + println!(" debug dump of PointIndexation"); + let max_level_observed = self.get_max_level_observed(); + // CAVEAT a lock once + for l in 0..=max_level_observed as usize { + println!( + " layer {} : length : {} ", + l, + self.points_by_layer.read()[l].len() + ); + } + println!(" debug dump of PointIndexation end"); + } + + /// real insertion of point in point indexation + // generate a new Point/ArcPoint (with neigbourhood info empty) and store it in global table + // The function is called by Hnsw insert method + fn generate_new_point(&self, data: &[T], origin_id: usize) -> (Arc>, usize) { + // get a write lock at the beginning of the function + let level = self.layer_g.generate(); + let new_point; + { + // open a write lock on points_by_layer + let mut points_by_layer_ref = self.points_by_layer.write(); + let mut p_id = PointId(level as u8, -1); + p_id.1 = points_by_layer_ref[p_id.0 as usize].len() as i32; + // make a Point and then an Arc + let point = Point::new(data.to_vec(), origin_id, p_id); + new_point = Arc::new(point); + trace!("definitive pushing of point {:?}", p_id); + points_by_layer_ref[p_id.0 as usize].push(Arc::clone(&new_point)); + } // close write lock on points_by_layer + // + let nb_point; + { + let mut lock_nb_point = self.nb_point.write(); + *lock_nb_point += 1; + nb_point = *lock_nb_point; + if nb_point % 50000 == 0 { + println!(" setting number of points {:?} ", nb_point); + } + } + trace!(" setting number of points {:?} ", *self.nb_point); + // Now possibly this is a point on a new layer that will have no neighbours in its layer + (Arc::clone(&new_point), nb_point) + } // end of insert + + /// check if entry_point is modified + fn check_entry_point(&self, new_point: &Arc>) { + // + // take directly a write lock so that we are sure nobody can change anything between read and write + // of entry_point_id + trace!("trying to get a lock on entry point"); + let mut entry_point_ref = self.entry_point.write(); + match entry_point_ref.as_ref() { + Some(arc_point) => { + if new_point.p_id.0 > arc_point.p_id.0 { + debug!("Hnsw , inserting entry point {:?} ", new_point.p_id); + debug!( + "PointIndexation insert setting max level from {:?} to {:?}", + arc_point.p_id.0, new_point.p_id.0 + ); + *entry_point_ref = Some(Arc::clone(new_point)); + } + } + None => { + trace!("initializing entry point"); + debug!("Hnsw , inserting entry point {:?} ", new_point.p_id); + *entry_point_ref = Some(Arc::clone(new_point)); + } + } + } // end of check_entry_point + + /// returns the number of points in layered structure + pub fn get_nb_point(&self) -> usize { + *self.nb_point.read() + } + + /// returns the number of points in a given layer, 0 on a bad layer num + pub fn get_layer_nb_point(&self, layer: usize) -> usize { + let nb_layer = self.points_by_layer.read().len(); + if layer < nb_layer { + self.points_by_layer.read()[layer].len() + } else { + 0 + } + } // end of get_layer_nb_point + + /// returns the size of data vector in graph if any, else return 0 + pub fn get_data_dimension(&self) -> usize { + let ep = self.entry_point.read(); + match ep.as_ref() { + Some(point) => point.get_v().len(), + None => 0, + } + } + + /// returns (**by cloning**) the data inside a point given it PointId, or None if PointId is not coherent. + /// Can be useful after reloading from a dump. + /// NOTE : This function should not be called during or before insertion in the structure is terminated as it + /// uses read locks to access the inside of Hnsw structure. + pub fn get_point_data(&self, p_id: &PointId) -> Option> { + if p_id.1 < 0 { + return None; + } + let p: usize = std::convert::TryFrom::try_from(p_id.1).unwrap(); + let l = p_id.0 as usize; + if p_id.0 <= self.get_max_level_observed() && p < self.get_layer_nb_point(l) { + Some(self.points_by_layer.read()[l][p].get_v().to_vec()) + } else { + None + } + } // end of get_point_data + + /// returns (**by Arc::clone**) the point given it PointId, or None if PointId is not coherent. + /// Can be useful after reloading from a dump. + /// NOTE : This function should not be called during or before insertion in the structure is terminated as it + /// uses read locks to access the inside of Hnsw structure. + #[allow(unused)] + pub(crate) fn get_point(&self, p_id: &PointId) -> Option>> { + if p_id.1 < 0 { + return None; + } + let p: usize = std::convert::TryFrom::try_from(p_id.1).unwrap(); + let l = p_id.0 as usize; + if p_id.0 <= self.get_max_level_observed() && p < self.get_layer_nb_point(l) { + Some(self.points_by_layer.read()[l][p].clone()) + } else { + None + } + } // end of get_point + + /// get an iterator on the points stored in a given layer + pub fn get_layer_iterator<'a>(&'a self, layer: usize) -> IterPointLayer<'a, 'b, T> { + IterPointLayer::new(self, layer) + } // end of get_layer_iterator +} // end of impl PointIndexation + +//============================================================================================ + +/// an iterator on points stored. +/// The iteration begins at level 0 (most populated level) and goes upward in levels. +/// The iterator takes a ReadGuard on the PointIndexation structure +pub struct IterPoint<'a, 'b, T: Clone + Send + Sync + 'b> { + point_indexation: &'a PointIndexation<'b, T>, + pi_guard: RwLockReadGuard<'a, Vec>>, + layer: i64, + slot_in_layer: i64, +} + +impl<'a, 'b, T: Clone + Send + Sync> IterPoint<'a, 'b, T> { + pub fn new(point_indexation: &'a PointIndexation<'b, T>) -> Self { + let pi_guard: RwLockReadGuard>> = point_indexation.points_by_layer.read(); + IterPoint { + point_indexation, + pi_guard, + layer: -1, + slot_in_layer: -1, + } + } +} // end of block impl IterPoint + +/// iterator for layer 0 to upper layer. +impl<'b, T: Clone + Send + Sync> Iterator for IterPoint<'_, 'b, T> { + type Item = Arc>; + // + fn next(&mut self) -> Option { + if self.layer == -1 { + self.layer = 0; + self.slot_in_layer = 0; + } + if (self.slot_in_layer as usize) < self.pi_guard[self.layer as usize].len() { + let slot = self.slot_in_layer as usize; + self.slot_in_layer += 1; + Some(self.pi_guard[self.layer as usize][slot].clone()) + } else { + self.slot_in_layer = 0; + self.layer += 1; + // must reach a non empty layer if possible + let entry_point_ref = self.point_indexation.entry_point.read(); + let points_by_layer = self.point_indexation.points_by_layer.read(); + let entry_point_level = entry_point_ref.as_ref().unwrap().p_id.0; + while (self.layer as u8) <= entry_point_level + && points_by_layer[self.layer as usize].is_empty() + { + self.layer += 1; + } + // now here either (self.layer as u8) > self.point_indexation.max_level_observed + // or self.point_indexation.points_by_layer[self.layer as usize ].len() > 0 + if (self.layer as u8) <= entry_point_level { + let slot = self.slot_in_layer as usize; + self.slot_in_layer += 1; + Some(points_by_layer[self.layer as usize][slot].clone()) + } else { + None + } + } + } // end of next +} // end of impl Iterator + +impl<'a, 'b, T: Clone + Send + Sync> IntoIterator for &'a PointIndexation<'b, T> { + type Item = Arc>; + type IntoIter = IterPoint<'a, 'b, T>; + // + fn into_iter(self) -> Self::IntoIter { + IterPoint::new(self) + } +} // end of IntoIterator for &'a PointIndexation + +/// An iterator on points stored in a given layer +/// The iterator stores a ReadGuard on the structure PointIndexation +pub struct IterPointLayer<'a, 'b, T: Clone + Send + Sync> { + _point_indexation: &'a PointIndexation<'b, T>, + pi_guard: RwLockReadGuard<'a, Vec>>, + layer: usize, + slot_in_layer: usize, +} + +impl<'a, 'b, T: Clone + Send + Sync> IterPointLayer<'a, 'b, T> { + pub fn new(point_indexation: &'a PointIndexation<'b, T>, layer: usize) -> Self { + let pi_guard: RwLockReadGuard>> = point_indexation.points_by_layer.read(); + IterPointLayer { + _point_indexation: point_indexation, + pi_guard, + layer, + slot_in_layer: 0, + } + } +} // end of block impl IterPointLayer + +/// iterator for layer 0 to upper layer. +impl<'b, T: Clone + Send + Sync + 'b> Iterator for IterPointLayer<'_, 'b, T> { + type Item = Arc>; + // + fn next(&mut self) -> Option { + if (self.slot_in_layer) < self.pi_guard[self.layer].len() { + let slot = self.slot_in_layer; + self.slot_in_layer += 1; + Some(self.pi_guard[self.layer][slot].clone()) + } else { + None + } + } // end of next +} // end of impl Iterator + +// ============================================================================================ + +// The fields are made pub(crate) to be able to initialize struct from hnswio +/// The Base structure for hnsw implementation. +/// The main useful functions are : new, insert, insert_parallel, search, parallel_search and file_dump +/// as described in trait AnnT. +/// +/// Other functions are mainly for others crate to get access to some fields. +pub struct Hnsw<'b, T: Clone + Send + Sync + 'b, D: Distance> { + /// asked number of candidates in search + pub(crate) ef_construction: usize, + /// maximum number of connection by layer for a point + pub(crate) max_nb_connection: usize, + /// flag to enforce that we have ef candidates as pruning strategy can discard some points + /// Can be set to true with method :set_extend_candidates + /// When set to true used only in base layer. + pub(crate) extend_candidates: bool, + /// defuault to false + pub(crate) keep_pruned: bool, + /// max layer , recall rust is in 0..maxlevel right bound excluded + pub(crate) max_layer: usize, + /// The global table containing points + pub(crate) layer_indexed_points: PointIndexation<'b, T>, + /// dimension data stored in points + #[allow(unused)] + pub(crate) data_dimension: usize, + /// distance between points. initialized at first insertion + pub(crate) dist_f: D, + /// insertion mode or searching mode. This flag prevents a internal thread to do a write when searching with other threads. + pub(crate) searching: bool, + /// set to true if some data come from a mmap + pub(crate) datamap_opt: bool, +} // end of Hnsw + +impl<'b, T: Clone + Send + Sync, D: Distance + Send + Sync> Hnsw<'b, T, D> { + /// allocation function + /// . max_nb_connection : number of neighbours stored, by layer, in tables. Must be less than 256. + /// . ef_construction : controls numbers of neighbours explored during construction. See README or paper. + /// . max_elements : hint to speed up allocation tables. number of elements expected. + /// . f : the distance function + pub fn new( + max_nb_connection: usize, + max_elements: usize, + max_layer: usize, + ef_construction: usize, + f: D, + ) -> Self { + let adjusted_max_layer = (NB_LAYER_MAX as usize).min(max_layer); + let layer_indexed_points = + PointIndexation::::new(max_nb_connection, adjusted_max_layer, max_elements); + let extend_candidates = false; + let keep_pruned = false; + // + if max_nb_connection > 256 { + println!("error max_nb_connection must be less equal than 256"); + std::process::exit(1); + } + // + info!("Hnsw max_nb_connection {:?}", max_nb_connection); + info!("Hnsw nb elements {:?}", max_elements); + info!("Hnsw ef_construction {:?}", ef_construction); + info!("Hnsw distance {:?}", type_name::()); + info!("Hnsw extend candidates {:?}", extend_candidates); + // + Hnsw { + max_nb_connection, + ef_construction, + extend_candidates, + keep_pruned, + max_layer: adjusted_max_layer, + layer_indexed_points, + data_dimension: 0, + dist_f: f, + searching: false, + datamap_opt: false, + } + } // end of new + + /// get ef_construction used in graph creation + pub fn get_ef_construction(&self) -> usize { + self.ef_construction + } + /// returns the maximum layer authorized in construction + pub fn get_max_level(&self) -> usize { + self.max_layer + } + + /// return the maximum level reached in the layers. + pub fn get_max_level_observed(&self) -> u8 { + self.layer_indexed_points.get_max_level_observed() + } + /// returns the maximum of links between a point and others points in each layer + pub fn get_max_nb_connection(&self) -> u8 { + self.max_nb_connection as u8 + } + /// returns number of points stored in hnsw structure + pub fn get_nb_point(&self) -> usize { + self.layer_indexed_points.get_nb_point() + } + /// set searching mode. + /// It is not possible to do parallel insertion and parallel searching simultaneously in different threads + /// so to enable searching after parallel insertion the flag must be set to true. + /// To resume parallel insertion reset the flag to false and so on. + pub fn set_searching_mode(&mut self, flag: bool) { + // must use an atomic! + self.searching = flag; + } + /// get name if distance + pub fn get_distance_name(&self) -> String { + type_name::().to_string() + } + /// set the flag asking to keep pruned vectors by Navarro's heuristic (see Paper). + /// It can be useful for small datasets where the pruning can make it difficult + /// to get the exact number of neighbours asked for. + pub fn set_keeping_pruned(&mut self, flag: bool) { + self.keep_pruned = flag; + } + + /// retrieves the distance used in Hnsw construction + pub fn get_distance(&self) -> &D { + &self.dist_f + } + + /// set extend_candidates to given flag. By default it is false. + /// Only used in the level 0 layer during insertion (see the paper) + /// flag to enforce that we have ef candidates neighbours examined as pruning strategy + /// can discard some points + pub fn set_extend_candidates(&mut self, flag: bool) { + self.extend_candidates = flag; + } + + // When dumping we need to know if some file is mmapped + pub(crate) fn get_datamap_opt(&self) -> bool { + self.datamap_opt + } + + /// By default the levels are sampled using an exponential law of parameter **ln(max_nb_conn)** + /// so the probability of having more than l levels decrease as **exp(-l * ln(max_nb_conn))**. + /// Reducing the scale change the parameter of the exponential to **ln(max_nb_conn)/scale**. + /// This reduce the number of levels generated and can provide better precision, reduce memory with marginally more cpu used. + /// The factor must between 0.2 and 1. + pub fn modify_level_scale(&mut self, scale_modification: f64) { + // + if self.get_nb_point() > 0 { + println!( + "using modify_level_scale is possible at creation of a Hnsw structure to ensure coherence between runs" + ) + } + // + let min_factor = 0.2; + println!( + "\n Current scale value : {:.2e}, Scale modification factor asked : {:.2e},(modification factor must be between {:.2e} and 1.)", + self.layer_indexed_points.layer_g.scale, scale_modification, min_factor + ); + // + if scale_modification > 1. { + println!( + "\n Scale modification not applied, modification arg {:.2e} not valid , factor must be less than 1.)", + scale_modification + ); + } else if scale_modification < min_factor { + println!( + "\n Scale modification arg {:.2e} not valid , factor must be greater than {:.2e}, using {:.2e})", + scale_modification, min_factor, min_factor + ); + } + // + self.layer_indexed_points + .layer_g + .set_scale_modification(scale_modification.max(min_factor).min(1.)); + } // end of set_scale_modification + + // here we could pass a point_id_with_order instead of entry_point_id: PointId + // The efficacity depends on greedy part depends on how near entry point is from point. + // ef is the number of points to return + // The method returns a BinaryHeap with positive distances. The caller must transforms it according its need + //** NOTE: the entry point is pushed into returned point at the beginning of the function, but in fact entry_point is in a layer + //** with higher (one more) index than the argument layer. If the greedy search matches a sufficiently large number of points + //** nearer to point searched (arg point) than entry_point it will finally get popped up from the heap of returned points + //** but otherwise it will stay in the binary heap and so we can have a point in neighbours that is in fact in a layer + //** above the one we search in. + //** The guarantee is that the binary heap will return points in layer + //** with a larger index, although we can expect that most often (at least in densely populated layers) the returned + //** points will be found in searched layer + /// + /// Greedy algorithm n° 2 in Malkov paper. + /// search in a layer (layer) for the ef points nearest a point to be inserted in hnsw. + fn search_layer( + &self, + point: &[T], + entry_point: Arc>, + ef: usize, + layer: u8, + filter: Option<&dyn FilterT>, + ) -> BinaryHeap>> { + // + trace!( + "entering search_layer with entry_point_id {:?} layer : {:?} ef {:?} ", + entry_point.p_id, layer, ef + ); + // + // here we allocate a binary_heap on values not on reference beccause we want to return + // log2(skiplist_size) must be greater than 1. + let skiplist_size = ef.max(2); + // we will store positive distances in this one + let mut return_points = BinaryHeap::>>::with_capacity(skiplist_size); + // + if self.layer_indexed_points.points_by_layer.read()[layer as usize].is_empty() { + // at the beginning we can have nothing in layer + trace!("search layer {:?}, empty layer", layer); + return return_points; + } + if entry_point.p_id.1 < 0 { + trace!("search layer negative point id : {:?}", entry_point.p_id); + return return_points; + } + // initialize visited points + let dist_to_entry_point = self.dist_f.eval(point, entry_point.data.get_v()); + trace!(" distance to entry point: {:?} ", dist_to_entry_point); + // keep a list of id visited + let mut visited_point_id = HashMap::>>::new(); + visited_point_id.insert(entry_point.p_id, Arc::clone(&entry_point)); + // + let mut candidate_points = + BinaryHeap::>>::with_capacity(skiplist_size); + candidate_points.push(Arc::new(PointWithOrder::new( + &entry_point, + -dist_to_entry_point, + ))); + return_points.push(Arc::new(PointWithOrder::new( + &entry_point, + dist_to_entry_point, + ))); + // at the beginning candidate_points contains point passed as arg in layer entry_point_id.0 + while !candidate_points.is_empty() { + // get nearest point in candidate_points + let c = candidate_points.pop().unwrap(); + // f farthest point to + let f = return_points.peek().unwrap(); + assert!(f.dist_to_ref >= 0.); + assert!(c.dist_to_ref <= 0.); + trace!( + "Comparaing c : {:?} f : {:?}", + -(c.dist_to_ref), + f.dist_to_ref + ); + if -(c.dist_to_ref) > f.dist_to_ref { + // this comparison requires that we are sure that distances compared are distances to the same point : + // This is the case we compare distance to point passed as arg. + trace!( + "Fast return from search_layer, nb points : {:?} \n \t c {:?} \n \t f {:?} dists: {:?} {:?}", + return_points.len(), + c.point_ref.p_id, + f.point_ref.p_id, + -(c.dist_to_ref), + f.dist_to_ref + ); + if filter.is_none() { + return return_points; + } else if return_points.len() >= ef { + return_points.retain(|p| { + filter + .as_ref() + .unwrap() + .hnsw_filter(&p.point_ref.get_origin_id()) + }); + } + } + // now we scan neighborhood of c in layer and increment visited_point, candidate_points + // and optimize candidate_points so that it contains points with lowest distances to point arg + // + let neighbours_c_l = &c.point_ref.neighbours.read()[layer as usize]; + let c_pid = c.point_ref.p_id; + trace!( + " search_layer, {:?} has nb neighbours : {:?} ", + c_pid, + neighbours_c_l.len() + ); + for e in neighbours_c_l { + // HERE WE sEE THAT neighbours should be stored as PointIdWithOrder !! + // CAVEAT what if several point_id with same distance to ref point? + if !visited_point_id.contains_key(&e.point_ref.p_id) { + visited_point_id.insert(e.point_ref.p_id, Arc::clone(&e.point_ref)); + trace!(" visited insertion {:?}", e.point_ref.p_id); + let f_opt = return_points.peek(); + if f_opt.is_none() { + // do some debug info, dumped distance is from e to c! as e is in c neighbours + debug!("return points empty when inserting {:?}", e.point_ref.p_id); + return return_points; + } + let f = f_opt.unwrap(); + let e_dist_to_p = self.dist_f.eval(point, e.point_ref.data.get_v()); + let f_dist_to_p = f.dist_to_ref; + if e_dist_to_p < f_dist_to_p || return_points.len() < ef { + let e_prime = Arc::new(PointWithOrder::new(&e.point_ref, e_dist_to_p)); + // a neighbour of neighbour is better, we insert it into candidate with the distance to point + trace!( + " inserting new candidate {:?}", + e_prime.point_ref.p_id + ); + candidate_points + .push(Arc::new(PointWithOrder::new(&e.point_ref, -e_dist_to_p))); + if filter.is_none() { + return_points.push(Arc::clone(&e_prime)); + } else { + let id: &usize = &e_prime.point_ref.get_origin_id(); + if filter.as_ref().unwrap().hnsw_filter(id) { + if return_points.len() == 1 { + let only_id = return_points.peek().unwrap().point_ref.origin_id; + if !filter.as_ref().unwrap().hnsw_filter(&only_id) { + return_points.clear() + } + } + return_points.push(Arc::clone(&e_prime)) + } + } + if return_points.len() > ef { + return_points.pop(); + } + } // end if e.dist_to_ref < f.dist_to_ref + } + } // end of for on neighbours_c + } // end of while in candidates + // + trace!( + "return from search_layer, nb points : {:?}", + return_points.len() + ); + return_points + } // end of search_layer + + /// insert a tuple (&Vec, usize) with its external id as given by the client. + /// The insertion method gives the point an internal id. + #[inline] + pub fn insert(&self, datav_with_id: (&[T], usize)) { + self.insert_slice((datav_with_id.0, datav_with_id.1)) + } + + // Hnsw insert. + /// Insert a data slice with its external id as given by the client. + /// The insertion method gives the point an internal id. + /// The slice insertion makes integration with ndarray crate easier than the vector insertion + pub fn insert_slice(&self, data_with_id: (&[T], usize)) { + // + let (data, origin_id) = data_with_id; + let keep_pruned = self.keep_pruned; + // insert in indexation and get point_id adn generate a new entry_point if necessary + let (new_point, point_rank) = self + .layer_indexed_points + .generate_new_point(data, origin_id); + trace!("Hnsw insert generated new point {:?} ", new_point.p_id); + // now real work begins + // allocate a binary heap + let level = new_point.p_id.0; + let mut enter_point_copy = None; + let mut max_level_observed = 0; + // entry point has been set in + { + // I open a read lock on an option + if let Some(arc_point) = self.layer_indexed_points.entry_point.read().as_ref() { + enter_point_copy = Some(Arc::clone(arc_point)); + if point_rank == 1 { + debug!( + "Hnsw stored first point , direct return {:?} ", + new_point.p_id + ); + return; + } + max_level_observed = enter_point_copy.as_ref().unwrap().p_id.0; + } + } + if enter_point_copy.is_none() { + self.layer_indexed_points.check_entry_point(&new_point); + return; + } + let mut dist_to_entry = self + .dist_f + .eval(data, enter_point_copy.as_ref().unwrap().data.get_v()); + // we go from self.max_level_observed to level+1 included + for l in ((level + 1)..(max_level_observed + 1)).rev() { + // CAVEAT could bypass when layer empty, avoid allocation.. + let mut sorted_points = self.search_layer( + data, + Arc::clone(enter_point_copy.as_ref().unwrap()), + 1, + l, + None, + ); + trace!( + "in insert :search_layer layer {:?}, returned {:?} points ", + l, + sorted_points.len() + ); + if sorted_points.len() > 1 { + panic!( + "in insert : search_layer layer {:?}, returned {:?} points ", + l, + sorted_points.len() + ); + } + // the heap conversion is useless beccause of the preceding test. + // sorted_points = from_positive_binaryheap_to_negative_binary_heap(&mut sorted_points); + // + if let Some(ep) = sorted_points.pop() { + // useful for projecting lower layer to upper layer. keep track of points encountered. + if new_point.neighbours.read()[l as usize].len() + < self.get_max_nb_connection() as usize + { + new_point.neighbours.write()[l as usize].push(Arc::clone(&ep)); + } + // get the lowest distance point + let tmp_dist = self.dist_f.eval(data, ep.point_ref.data.get_v()); + if tmp_dist < dist_to_entry { + enter_point_copy = Some(Arc::clone(&ep.point_ref)); + dist_to_entry = tmp_dist; + } + } else { + // this layer is not yet filled + trace!("layer still empty {} : got null list", l); + } + } + // now enter_point_id_copy contains id of nearest + // now loop down to 0 + for l in (0..level + 1).rev() { + let ef = self.ef_construction; + // when l == level, we cannot get new_point in sorted_points as it is seen only from declared neighbours + let mut sorted_points = self.search_layer( + data, + Arc::clone(enter_point_copy.as_ref().unwrap()), + ef, + l, + None, + ); + trace!( + "in insert :search_layer layer {:?}, returned {:?} points ", + l, + sorted_points.len() + ); + sorted_points = from_positive_binaryheap_to_negative_binary_heap(&mut sorted_points); + if !sorted_points.is_empty() { + let nb_conn; + let extend_c; + if l == 0 { + nb_conn = 2 * self.max_nb_connection; + extend_c = self.extend_candidates; + } else { + nb_conn = self.max_nb_connection; + extend_c = false; + } + let mut neighbours = Vec::>>::with_capacity(nb_conn); + self.select_neighbours( + data, + &mut sorted_points, + nb_conn, + extend_c, + l, + keep_pruned, + &mut neighbours, + ); + // sort neighbours + neighbours.sort_unstable(); + // we must add bidirecti*onal from data i.e new_point_id to neighbours + new_point.neighbours.write()[l as usize].clone_from(&neighbours); + // this reverse neighbour update could be done here but we put it at end to gather all code + // requiring a mutex guard for multi threading. + // update ep for loop iteration. As we sorted neighbours the nearest + if !neighbours.is_empty() { + enter_point_copy = Some(Arc::clone(&neighbours[0].point_ref)); + } + } + } // for l + // + // new_point has been inserted at the beginning in table + // so that we can call reverse_update_neighborhoodwe consitently + // now reverse update of neighbours. + self.reverse_update_neighborhood_simple(Arc::clone(&new_point)); + // + self.layer_indexed_points.check_entry_point(&new_point); + // + trace!("Hnsw exiting insert new point {:?} ", new_point.p_id); + } // end of insert + + /// Insert in parallel a slice of Vec\ each associated to its id. + /// It uses Rayon for threading so the number of insertions asked for must be large enough to be efficient. + /// Typically 1000 * the number of threads. + /// Many consecutive parallel_insert can be done, so the size of vector inserted in one insertion can be optimized. + pub fn parallel_insert(&self, datas: &[(&Vec, usize)]) { + debug!("entering parallel_insert"); + datas + .par_iter() + .for_each(|&(item, v)| self.insert((item.as_slice(), v))); + debug!("exiting parallel_insert"); + } // end of parallel_insert + + /// Insert in parallel slices of \[T\] each associated to its id. + /// It uses Rayon for threading so the number of insertions asked for must be large enough to be efficient. + /// Typically 1000 * the number of threads. + /// Facilitates the use with the ndarray crate as we can extract slices (for data in contiguous order) from Array. + pub fn parallel_insert_slice(&self, datas: &Vec<(&[T], usize)>) { + datas.par_iter().for_each(|&item| self.insert_slice(item)); + } // end of parallel_insert + + /// insert new_point in neighbourhood info of point + fn reverse_update_neighborhood_simple(&self, new_point: Arc>) { + // println!("reverse update neighbourhood for new point {:?} ", new_point.p_id); + trace!( + "reverse update neighbourhood for new point {:?} ", + new_point.p_id + ); + let level = new_point.p_id.0; + for l in (0..level + 1).rev() { + for q in &new_point.neighbours.read()[l as usize] { + if new_point.p_id != q.point_ref.p_id { + // as new point is in global table, do not loop and deadlock!! + let q_point = &q.point_ref; + let mut q_point_neighbours = q_point.neighbours.write(); + let n_to_add = PointWithOrder::::new(&Arc::clone(&new_point), q.dist_to_ref); + // must be sure that we add a point at the correct level. See the comment to search_layer! + // this ensures that reverse updating do not add problems. + let l_n = n_to_add.point_ref.p_id.0 as usize; + let already = q_point_neighbours[l_n] + .iter() + .position(|old| old.point_ref.p_id == new_point.p_id); + if already.is_some() { + // debug!(" new_point.p_id {:?} already in neighbourhood of q_point {:?} at index {:?}", new_point.p_id, q_point.p_id, already.unwrap()); + // q_point.debug_dump(); cannot be called as its neighbours are locked write by this method. + // new_point.debug_dump(); + // panic!(); + continue; + } + q_point_neighbours[l_n].push(Arc::new(n_to_add)); + let nbn_at_l = q_point_neighbours[l_n].len(); + // + // if l < level, update upward chaining, insert does a sort! t_q has a neighbour not yet in global table of points! + let threshold_shrinking = if l_n > 0 { + self.max_nb_connection + } else { + 2 * self.max_nb_connection + }; + let shrink = nbn_at_l > threshold_shrinking; + { + // sort and shring if necessary + q_point_neighbours[l_n].sort_unstable(); + if shrink { + q_point_neighbours[l_n].pop(); + } + } + } // end protection against point identity + } + } + // println!(" exitingreverse update neighbourhood for new point {:?} ", new_point.p_id); + } // end of reverse_update_neighborhood_simple + + pub fn get_point_indexation(&self) -> &PointIndexation<'b, T> { + &self.layer_indexed_points + } + + // This is best explained in : Navarro. Searching in metric spaces by spatial approximation. + /// simplest searh neighbours + // The binary heaps here is with negative distance sorted. + #[allow(clippy::too_many_arguments)] + fn select_neighbours( + &self, + data: &[T], + candidates: &mut BinaryHeap>>, + nb_neighbours_asked: usize, + extend_candidates_asked: bool, + layer: u8, + keep_pruned: bool, + neighbours_vec: &mut Vec>>, + ) { + // + trace!( + "entering select_neighbours : nb candidates: {}", + candidates.len() + ); + // + neighbours_vec.clear(); + // we will extend if we do not have enough candidates and it is explicitly asked in arg + let mut extend_candidates = false; + if candidates.len() <= nb_neighbours_asked { + if !extend_candidates_asked { + // just transfer taking care of signs + while !candidates.is_empty() { + let p = candidates.pop().unwrap(); + assert!(-p.dist_to_ref >= 0.); + neighbours_vec + .push(Arc::new(PointWithOrder::new(&p.point_ref, -p.dist_to_ref))); + } + return; + } else { + extend_candidates = true; + } + } + // + // + //extend_candidates = true; + // + if extend_candidates { + let mut candidates_set = HashMap::>>::new(); + for c in candidates.iter() { + candidates_set.insert(c.point_ref.p_id, Arc::clone(&c.point_ref)); + } + let mut new_candidates_set = HashMap::>>::new(); + // get a list of all neighbours of candidates + for (_p_id, p_point) in candidates_set.iter() { + let n_p_layer = &p_point.neighbours.read()[layer as usize]; + for q in n_p_layer { + if !candidates_set.contains_key(&q.point_ref.p_id) + && !new_candidates_set.contains_key(&q.point_ref.p_id) + { + new_candidates_set.insert(q.point_ref.p_id, Arc::clone(&q.point_ref)); + } + } + } // end of for p + trace!( + "select neighbours extend candidates from : {:?} adding : {:?}", + candidates.len(), + new_candidates_set.len() + ); + for (_p_id, p_point) in new_candidates_set.iter() { + let dist_topoint = self.dist_f.eval(data, p_point.data.get_v()); + candidates.push(Arc::new(PointWithOrder::new(p_point, -dist_topoint))); + } + } // end if extend_candidates + // + let mut discarded_points = BinaryHeap::>>::new(); + while !candidates.is_empty() && neighbours_vec.len() < nb_neighbours_asked { + // compare distances of e to data. we do not need to recompute dists! + if let Some(e_p) = candidates.pop() { + let mut e_to_insert = true; + let e_point_v = e_p.point_ref.data.get_v(); + assert!(e_p.dist_to_ref <= 0.); + // is e_p the nearest to reference? data than to previous neighbours + if !neighbours_vec.is_empty() { + e_to_insert = !neighbours_vec.iter().any(|d| { + self.dist_f.eval(e_point_v, d.point_ref.data.get_v()) <= -e_p.dist_to_ref + }); + } + if e_to_insert { + trace!("inserting neighbours : {:?} ", e_p.point_ref.p_id); + neighbours_vec.push(Arc::new(PointWithOrder::new( + &e_p.point_ref, + -e_p.dist_to_ref, + ))); + } else { + trace!("discarded neighbours : {:?} ", e_p.point_ref.p_id); + // ep is taken from a binary heap, so it has a negative sign, we keep its sign + // to store it in another binary heap will possibly need to retain the best ones from the discarde binaryHeap + if keep_pruned { + discarded_points.push(Arc::new(PointWithOrder::new( + &e_p.point_ref, + e_p.dist_to_ref, + ))); + } + } + } + } + // now this part of neighbours is the most interesting and is distance sorted. + + // not pruned are at the end of neighbours_vec which is not re-sorted , but discarded are sorted. + if keep_pruned { + while !discarded_points.is_empty() && neighbours_vec.len() < nb_neighbours_asked { + let best_point = discarded_points.pop().unwrap(); + // do not forget to reverse sign + assert!(best_point.dist_to_ref <= 0.); + neighbours_vec.push(Arc::new(PointWithOrder::new( + &best_point.point_ref, + -best_point.dist_to_ref, + ))); + } + }; + // + if log::log_enabled!(log::Level::Trace) { + trace!( + "exiting select_neighbours : nb candidates: {}", + neighbours_vec.len() + ); + for n in neighbours_vec { + trace!(" neighbours {:?} ", n.point_ref.p_id); + } + } + // + } // end of select_neighbours + + /// A utility to get printed info on how many points there are in each layer. + pub fn dump_layer_info(&self) { + self.layer_indexed_points.debug_dump(); + } + + // search the first knbn nearest neigbours of a data, but can modify ef for layer > 1 + // This function return Vec >> + // The parameter ef controls the width of the search in the lowest level, it must be greater + // than number of neighbours asked. A rule of thumb could be between knbn and max_nb_connection. + #[allow(unused)] + fn search_general(&self, data: &[T], knbn: usize, ef_arg: usize) -> Vec { + // + let mut entry_point; + { + // a lock on an option an a Arc + let entry_point_opt_ref = self.layer_indexed_points.entry_point.read(); + if entry_point_opt_ref.is_none() { + return Vec::::new(); + } else { + entry_point = Arc::clone((*entry_point_opt_ref).as_ref().unwrap()); + } + } + // + let mut dist_to_entry = self.dist_f.eval(data, entry_point.as_ref().data.get_v()); + for layer in (1..=entry_point.p_id.0).rev() { + let mut neighbours = self.search_layer(data, Arc::clone(&entry_point), 1, layer, None); + neighbours = from_positive_binaryheap_to_negative_binary_heap(&mut neighbours); + if let Some(entry_point_tmp) = neighbours.pop() { + // get the lowest distance point. + let tmp_dist = self + .dist_f + .eval(data, entry_point_tmp.point_ref.data.get_v()); + if tmp_dist < dist_to_entry { + entry_point = Arc::clone(&entry_point_tmp.point_ref); + dist_to_entry = tmp_dist; + } + } + } + // ef must be greater than knbn. Possibly it should be between knbn and self.max_nb_connection + let ef = ef_arg.max(knbn); + // now search with asked ef in layer 0 + let neighbours_heap = self.search_layer(data, entry_point, ef, 0, None); + // go from heap of points with negative dist to a sorted vec of increasing points with > 0 distances. + let neighbours = neighbours_heap.into_sorted_vec(); + // get the min of K and ef points into a vector. + // + let last = knbn.min(ef).min(neighbours.len()); + let knn_neighbours: Vec = neighbours[0..last] + .iter() + .map(|p| { + Neighbour::new( + p.as_ref().point_ref.origin_id, + p.as_ref().dist_to_ref, + p.as_ref().point_ref.p_id, + ) + }) + .collect(); + + knn_neighbours + } // end of knn_search + + /// a filtered version of [`Self::search`]. + /// A filter can be added to the search to get nodes with a particular property or id constraint. + /// See examples in tests/filtertest.rs + pub fn search_filter( + &self, + data: &[T], + knbn: usize, + ef_arg: usize, + filter: Option<&dyn FilterT>, + ) -> Vec { + // + let entry_point; + { + // a lock on an option an a Arc + let entry_point_opt_ref = self.layer_indexed_points.entry_point.read(); + if entry_point_opt_ref.is_none() { + return Vec::::new(); + } else { + entry_point = Arc::clone((*entry_point_opt_ref).as_ref().unwrap()); + } + } + // + let mut dist_to_entry = self.dist_f.eval(data, entry_point.as_ref().data.get_v()); + let mut pivot = Arc::clone(&entry_point); + let mut new_pivot = None; + + // + for layer in (1..=entry_point.p_id.0).rev() { + let mut has_changed = false; + // search in stored neighbours + { + let neighbours = &pivot.neighbours.read()[layer as usize]; + for n in neighbours { + // get the lowest distance point. + let tmp_dist = self.dist_f.eval(data, n.point_ref.data.get_v()); + if tmp_dist < dist_to_entry { + new_pivot = Some(Arc::clone(&n.point_ref)); + has_changed = true; + dist_to_entry = tmp_dist; + } + } // end of for on neighbours + } + if has_changed { + pivot = Arc::clone(new_pivot.as_ref().unwrap()); + } + } // end on for on layers + // ef must be greater than knbn. Possibly it should be between knbn and self.max_nb_connection + let ef = ef_arg.max(knbn); + log::debug!("pivot changed , current pivot {:?}", pivot.get_point_id()); + // search lowest non empty layer (in case of search with incomplete lower layer at beginning of hnsw filling) + let mut l = 0u8; + let layer_to_search = loop { + if self.get_point_indexation().get_layer_nb_point(l as usize) > 0 { + break l; + } + l += 1; + }; + // now search with asked ef in lower layer + let neighbours_heap = self.search_layer(data, pivot, ef, layer_to_search, filter); + // go from heap of points with negative dist to a sorted vec of increasing points with > 0 distances. + let neighbours = neighbours_heap.into_sorted_vec(); + // get the min of K and ef points into a vector. + // + let last = knbn.min(ef).min(neighbours.len()); + // + if let Some(filter_t) = filter { + let knn_neighbours: Vec = neighbours[0..last] + .iter() + .map(|p| { + if filter_t.hnsw_filter(&p.as_ref().point_ref.origin_id) { + Some(Neighbour::new( + p.as_ref().point_ref.origin_id, + p.as_ref().dist_to_ref, + p.as_ref().point_ref.p_id, + )) + } else { + None + } + }) + .filter(|x| x.is_some()) + .map(|x| x.unwrap()) + .collect(); + // + knn_neighbours + } else { + let knn_neighbours: Vec = neighbours[0..last] + .iter() + .map(|p| { + Neighbour::new( + p.as_ref().point_ref.origin_id, + p.as_ref().dist_to_ref, + p.as_ref().point_ref.p_id, + ) + }) + .collect(); + + knn_neighbours + } + } // end of search_filter + + #[inline] + pub fn search_possible_filter( + &self, + data: &[T], + knbn: usize, + ef_arg: usize, + filter: Option<&dyn FilterT>, + ) -> Vec { + self.search_filter(data, knbn, ef_arg, filter) + } + + /// search the first knbn nearest neigbours of a data and returns a Vector of Neighbour. + /// The parameter ef controls the width of the search in the lowest level, it must be greater + /// than number of neighbours asked. + /// A rule of thumb could be between knbn and max_nb_connection. + pub fn search(&self, data: &[T], knbn: usize, ef_arg: usize) -> Vec { + self.search_possible_filter(data, knbn, ef_arg, None) + } + + fn search_with_id( + &self, + request: (usize, &Vec), + knbn: usize, + ef: usize, + ) -> (usize, Vec) { + (request.0, self.search(request.1, knbn, ef)) + } + + /// knbn is the number of nearest neigbours asked for. Returns for each data vector + /// a Vector of Neighbour + pub fn parallel_search(&self, datas: &[Vec], knbn: usize, ef: usize) -> Vec> { + let (sender, receiver) = channel(); + // make up requests + let nb_request = datas.len(); + let requests: Vec<(usize, &Vec)> = (0..nb_request).zip(datas.iter()).collect(); + // + requests.par_iter().for_each_with(sender, |s, item| { + s.send(self.search_with_id(*item, knbn, ef)).unwrap() + }); + let req_res: Vec<(usize, Vec)> = receiver.iter().collect(); + // now sort to respect the key order of input + let mut answers = Vec::>::with_capacity(datas.len()); + // get a map from request id to rank + let mut req_hash = HashMap::::new(); + for (i, elt) in req_res.iter().enumerate() { + // the response of request req_res[i].0 is at rank i + req_hash.insert(elt.0, i); + } + for i in 0..datas.len() { + let answer_i = req_hash.get_key_value(&i).unwrap().1; + answers.push((req_res[*answer_i].1).clone()); + } + answers + } // end of insert_parallel +} // end of Hnsw + +// This function takes a binary heap with points declared with a negative distance +// and returns a vector of points with their correct positive distance to some reference distance +// The vector is sorted by construction +#[allow(unused)] +fn from_negative_binaryheap_to_sorted_vector<'b, T: Send + Sync + Copy>( + heap_points: &mut BinaryHeap>>, +) -> Vec>> { + let nb_points = heap_points.len(); + let mut vec_points = Vec::>>::with_capacity(nb_points); + // + for p in heap_points.iter() { + assert!(p.dist_to_ref <= 0.); + let reverse_p = Arc::new(PointWithOrder::new(&p.point_ref, -p.dist_to_ref)); + vec_points.push(reverse_p); + } + trace!( + "from_negative_binaryheap_to_sorted_vector nb points in out {:?} {:?} ", + nb_points, + vec_points.len() + ); + vec_points +} + +// This function takes a binary heap with points declared with a positive distance +// and returns a binary_heap of points with their correct negative distance to some reference distance +// +fn from_positive_binaryheap_to_negative_binary_heap<'b, T: Send + Sync + Clone>( + positive_heap: &mut BinaryHeap>>, +) -> BinaryHeap>> { + let nb_points = positive_heap.len(); + let mut negative_heap = BinaryHeap::>>::with_capacity(nb_points); + // + for p in positive_heap.iter() { + assert!(p.dist_to_ref >= 0.); + let reverse_p = Arc::new(PointWithOrder::new(&p.point_ref, -p.dist_to_ref)); + negative_heap.push(reverse_p); + } + trace!( + "from_positive_binaryheap_to_negative_binary_heap nb points in out {:?} {:?} ", + nb_points, + negative_heap.len() + ); + negative_heap +} + +// essentialy to check dump/reload conssistency +// in fact checks only equality of graph +#[allow(unused)] +pub(crate) fn check_graph_equality(hnsw1: &Hnsw, hnsw2: &Hnsw) +where + T1: Copy + Clone + Send + Sync, + D1: Distance + Default + Send + Sync, + T2: Copy + Clone + Send + Sync, + D2: Distance + Default + Send + Sync, +{ + // + debug!("In check_graph_equality"); + // + assert_eq!(hnsw1.get_nb_point(), hnsw2.get_nb_point()); + // check for entry point + assert!( + hnsw1.layer_indexed_points.entry_point.read().is_some() + || hnsw1.layer_indexed_points.entry_point.read().is_some(), + "one entry point is None" + ); + let ep1_read = hnsw1.layer_indexed_points.entry_point.read(); + let ep2_read = hnsw2.layer_indexed_points.entry_point.read(); + let ep1 = ep1_read.as_ref().unwrap(); + let ep2 = ep2_read.as_ref().unwrap(); + assert_eq!( + ep1.origin_id, ep2.origin_id, + "different entry points {:?} {:?}", + ep1.origin_id, ep2.origin_id + ); + assert_eq!(ep1.p_id, ep2.p_id, "origin id {:?} ", ep1.origin_id); + // check layers + let layers_1 = hnsw1.layer_indexed_points.points_by_layer.read(); + let layers_2 = hnsw2.layer_indexed_points.points_by_layer.read(); + let mut nb_point_checked = 0; + let mut nb_neighbours_checked = 0; + for i in 0..NB_LAYER_MAX as usize { + debug!("Checking layer {:?}", i); + assert_eq!(layers_1[i].len(), layers_2[i].len()); + for j in 0..layers_1[i].len() { + let p1 = &layers_1[i][j]; + let p2 = &layers_2[i][j]; + assert_eq!(p1.origin_id, p2.origin_id); + assert_eq!( + p1.p_id, p2.p_id, + "Checking origin_id point {:?} ", + p1.origin_id + ); + nb_point_checked += 1; + // check neighborhood + let nbgh1 = p1.neighbours.read(); + let nbgh2 = p2.neighbours.read(); + assert_eq!(nbgh1.len(), nbgh2.len()); + for k in 0..nbgh1.len() { + assert_eq!(nbgh1[k].len(), nbgh2[k].len()); + for l in 0..nbgh1[k].len() { + assert_eq!( + nbgh1[k][l].point_ref.origin_id, + nbgh2[k][l].point_ref.origin_id + ); + assert_eq!(nbgh1[k][l].point_ref.p_id, nbgh2[k][l].point_ref.p_id); + // CAVEAT for precision with f32 + assert_eq!(nbgh1[k][l].dist_to_ref, nbgh2[k][l].dist_to_ref); + nb_neighbours_checked += 1; + } + } + } // end of for j + } // end of for i + assert_eq!(nb_point_checked, hnsw1.get_nb_point()); + debug!("nb neighbours checked {:?}", nb_neighbours_checked); + debug!("exiting check_equality"); +} // end of check_reload + +#[cfg(test)] + +mod tests { + + use super::*; + use anndists::dist; + + fn log_init_test() { + let _ = env_logger::builder().is_test(true).try_init(); + } + + #[test] + fn test_iter_point() { + // + println!("\n\n test_iter_point"); + // + let mut rng = rand::rng(); + let unif = Uniform::::new(0., 1.).unwrap(); + let nbcolumn = 5000; + let nbrow = 10; + let mut xsi; + let mut data = Vec::with_capacity(nbcolumn); + for j in 0..nbcolumn { + data.push(Vec::with_capacity(nbrow)); + for _ in 0..nbrow { + xsi = rng.sample(unif); + data[j].push(xsi); + } + } + // + // check insertion + let ef_construct = 25; + let nb_connection = 10; + let start = ProcessTime::now(); + let hns = Hnsw::::new( + nb_connection, + nbcolumn, + 16, + ef_construct, + dist::DistL1 {}, + ); + for (i, d) in data.iter().enumerate() { + hns.insert((d, i)); + } + let cpu_time = start.elapsed(); + println!(" test_insert_iter_point time inserting {:?}", cpu_time); + + hns.dump_layer_info(); + // now check iteration + let ptiter = hns.get_point_indexation().into_iter(); + let mut nb_dumped = 0; + for _point in ptiter { + // println!("point : {:?}", _point.p_id); + nb_dumped += 1; + } + // + assert_eq!(nb_dumped, nbcolumn); + } // end of test_iter_point + + #[test] + fn test_iter_layerpoint() { + // + println!("\n\n test_iter_point"); + // + let mut rng = rand::rng(); + let unif = Uniform::::new(0., 1.).unwrap(); + let nbcolumn = 5000; + let nbrow = 10; + let mut xsi; + let mut data = Vec::with_capacity(nbcolumn); + for j in 0..nbcolumn { + data.push(Vec::with_capacity(nbrow)); + for _ in 0..nbrow { + xsi = rng.sample(unif); + data[j].push(xsi); + } + } + // + // check insertion + let ef_construct = 25; + let nb_connection = 10; + let start = ProcessTime::now(); + let hns = Hnsw::::new( + nb_connection, + nbcolumn, + 16, + ef_construct, + dist::DistL1 {}, + ); + for (i, d) in data.iter().enumerate() { + hns.insert((d, i)); + } + let cpu_time = start.elapsed(); + println!(" test_insert_iter_point time inserting {:?}", cpu_time); + + hns.dump_layer_info(); + // now check iteration + let layer_num = 0; + let nbpl = hns.get_point_indexation().get_layer_nb_point(layer_num); + let layer_iter = hns.get_point_indexation().get_layer_iterator(layer_num); + // + let mut nb_dumped = 0; + for _point in layer_iter { + // println!("point : {:?}", _point.p_id); + nb_dumped += 1; + } + println!( + "test_iter_layerpoint : nb point in layer {} , nb found {}", + nbpl, nb_dumped + ); + // + assert_eq!(nb_dumped, nbpl); + } // end of test_iter_layerpoint + + // we should find point even if it is in layer >= 1 + #[test] + fn test_sparse_search() { + log_init_test(); + // + for _ in 0..800 { + let hnsw: Hnsw = + Hnsw::new(15, 100_000, 20, 500_000, dist::DistL1 {}); + hnsw.insert((&[1.0, 0.0, 0.0, 0.0], 0)); + let result = hnsw.search(&[1.0, 0.0, 0.0, 0.0], 2, 10); + assert_eq!(result, vec![Neighbour::new(0, 0.0, PointId(0, 0))]); + } + } +} // end of module test diff --git a/patches/hnsw_rs/src/hnswio.rs b/patches/hnsw_rs/src/hnswio.rs new file mode 100644 index 000000000..b4e8a53df --- /dev/null +++ b/patches/hnsw_rs/src/hnswio.rs @@ -0,0 +1,1703 @@ +//! This module provides io dump/ reload of computed graph via the structure Hnswio. +//! This structure stores references to data points if memory map is used. +//! +//! A dump is constituted of 2 files. +//! One file stores just the graph (or topology) with id of points. +//! The other file stores the ids and vector in point and can be reloaded via a mmap scheme. +//! The graph file is suffixed by "hnsw.graph" the other is suffixed by "hnsw.data" +//! +//! Examples of dump and reload of structure Hnsw is given in the tests (see test_dump_reload, reload_with_mmap) +// datafile +// MAGICDATAP : u32 +// dimension : usize!! +// The for each point the triplet: (MAGICDATAP, origin_id , dimension , array of values bson encoded) ( u32, u64, ....) +// +// A point is dumped in graph file as given by its external id (type DataId i.e : a usize, possibly a hash value) +// and layer (u8) and rank_in_layer:i32. +// In the data file the point dump consist in the triplet: (MAGICDATAP, origin_id , array of values.) +// +use serde::{Serialize, de::DeserializeOwned}; +use std::sync::atomic::{AtomicUsize, Ordering}; +// +use std::time::SystemTime; + +// io +use std::fs::{File, OpenOptions}; +use std::io::{BufReader, BufWriter}; +use std::path::{Path, PathBuf}; + +// synchro +use parking_lot::RwLock; +use std::sync::Arc; + +use std::collections::HashMap; + +use rand::Rng; + +use anyhow::*; +use std::any::type_name; + +use anndists::dist::distances::*; + +use self::hnsw::*; +use crate::datamap::*; +use crate::hnsw; +use log::{debug, error, info, trace}; +use std::io::prelude::*; + +// magic before each graph point data for each point +const MAGICPOINT: u32 = 0x000a678f; +// magic at beginning of description format v2 of dump +const MAGICDESCR_2: u32 = 0x002a677f; + +// magic at beginning of description format v3 of dump +// format where we can use mmap to provide acces to data (not graph) via a memory mapping of file data , +// useful when data vector are large and data uses more space than graph. +// differ from v2 as we do not use bincode encoding for point. We dump pure binary +// This help use mmap as we can return directly a slice. +const MAGICDESCR_3: u32 = 0x002a6771; + +// magic for v4 +// we dump level scale modififcation factor +const MAGICDESCR_4: u32 = 0x002a6779; + +// magic at beginning of a layer dump +const MAGICLAYER: u32 = 0x000a676f; +// magic head of data file and before each data vector +pub(crate) const MAGICDATAP: u32 = 0xa67f0000; + +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum DumpMode { + Light, + Full, +} + +/// The main interface for dumping struct Hnsw. +pub(crate) trait HnswIoT { + fn dump(&self, mode: DumpMode, dumpinit: &mut DumpInit) -> anyhow::Result; +} + +/// Describe options accessible for reload +/// +/// - datamap : a bool for mmap usage. +/// The data point can be reloaded via mmap of data file dump. +/// This can be useful when data points consist in large vectors (as in genomic sketching) +/// as in this case data needs more space than the graph. +/// +/// - mmap_threshold : the number of itmes above which we use mmap. Default is 0, meaning always use mmap data +/// Can be useful for search speed in hnsw if we have part of data resident in memory. +#[derive(Copy, Clone)] +pub struct ReloadOptions { + datamap: bool, + /// number of data items above which we use mmap. + mmap_threshold: usize, +} + +impl Default for ReloadOptions { + /// default is no mmap + fn default() -> Self { + ReloadOptions { + datamap: false, + mmap_threshold: 0, + } + } +} + +impl ReloadOptions { + pub fn new(datamap: bool) -> Self { + ReloadOptions { + datamap, + mmap_threshold: 0, + } + } + + /// set mmap uasge to true + pub fn set_mmap(&mut self, val: bool) -> Self { + self.datamap = val; + *self + } + + /// set mmap threshold i.e : The maximum number of data that will be reloaded in memory by reading file dump, the other points will be mmapped. + /// As the upper layers are the most frequently used, these points will be loaded in memory during reading, the others will be mmaped. + /// See test *reload_with_mmap()* + pub fn set_mmap_threshold(&mut self, threshold: usize) -> Self { + if threshold > 0 { + self.datamap = true; + self.mmap_threshold = threshold; + } + *self + } + + /// return a 2-uple, (datamap, threshold) + pub fn use_mmap(&self) -> (bool, usize) { + (self.datamap, self.mmap_threshold) + } +} // end of ReloadOptions + +//=============================================================================================== + +// initialize datafile and graphfile for io ops +// This structure will check existence of dumps of same name and generate a unique filename if necessary according to overwrite flag +#[allow(unused)] +pub struct DumpInit { + // basename dump + basename: String, + // to dump data + pub(crate) data_out: BufWriter, + // to dump graph + pub(crate) graph_out: BufWriter, +} // end of + +impl DumpInit { + // This structure will check existence of dumps of same name and generate a unique filename if necessary according to overwrite flag + pub fn new(dir: &Path, basename_default: &str, overwrite: bool) -> Self { + // if we cannot overwrite data files (in case of mmap in particular) + // we will ensure we have a unique basename + let basename = match overwrite { + true => basename_default.to_string(), + false => { + // we check + let mut dataname = basename_default.to_string(); + dataname.push_str(".hnsw.data"); + let mut datapath = PathBuf::from(dir); + datapath.push(dataname); + let exist_res = std::fs::metadata(datapath.as_os_str()); + if exist_res.is_ok() { + let unique_basename = loop { + let mut unique_basename; + let mut dataname: String; + let id: usize = rand::thread_rng().gen_range(0..10000); + let strid: String = id.to_string(); + unique_basename = basename_default.to_string(); + unique_basename.push('-'); + unique_basename.push_str(&strid); + dataname = unique_basename.clone(); + dataname.push_str(".hnsw.data"); + let mut datapath = PathBuf::from(dir); + datapath.push(dataname); + let exist_res = std::fs::metadata(datapath.as_os_str()); + if exist_res.is_err() { + break unique_basename; + } + }; + unique_basename + } else { + basename_default.to_string() + } + } + }; + // + info!("Dumping with (unique) basename : {}", basename); + // + let mut graphname = basename.clone(); + graphname.push_str(".hnsw.graph"); + let mut graphpath = PathBuf::from(dir); + graphpath.push(graphname); + let graphfileres = OpenOptions::new() + .create(true) + .truncate(true) + .write(true) + .open(&graphpath); + if graphfileres.is_err() { + println!( + "HnswIo::reload_hnsw : could not open file {:?}", + graphpath.as_os_str() + ); + std::panic::panic_any("HnswIo::init : could not open file".to_string()); + } + let graphfile = graphfileres.unwrap(); + // same thing for data file + let mut dataname = basename.clone(); + dataname.push_str(".hnsw.data"); + let mut datapath = PathBuf::from(dir); + datapath.push(dataname); + let datafileres = OpenOptions::new() + .create(true) + .truncate(true) + .write(true) + .open(&datapath); + if datafileres.is_err() { + println!( + "HnswIo::init : could not open file {:?}", + datapath.as_os_str() + ); + std::panic::panic_any("HnswIo::init : could not open file".to_string()); + } + let datafile = datafileres.unwrap(); + // + let graph_out = BufWriter::new(graphfile); + let data_out = BufWriter::new(datafile); + // + DumpInit { + basename, + data_out, + graph_out, + } + } + + /// returns the basename used for the dump. May be it has been made unique to void overwriting a previous or mmapped dump + pub fn get_basename(&self) -> &String { + &self.basename + } + + pub fn flush(&mut self) -> Result<()> { + self.data_out.flush()?; + self.graph_out.flush()?; + Ok(()) + } +} // end impl for DumpInit + +//==================================================== +// basic block used to provide arguments to load_hnsw and load_hnsw_with_dist +struct LoadInit { + descr: Description, + // + graphfile: BufReader, + // + datafile: BufReader, +} // end of LoadInit + +/// a structure to provide simplified methods for reloading a previous dump. +/// +/// The data point can be reloaded via mmap of data file dump. +/// This can be useful when data points consist in large vectors (as in genomic sketching) +/// as in this case data needs more space than the graph. +/// Note : **As this structure potentially contains the mmap data used in hnsw after reload it must not be dropped +/// before the reloaded hnsw.** +/// Example: +/// +/// See example in tests::reload_with_mmap +/// ```text +/// let directory = Path::new("."); +/// let mut reloader = HnswIo::new(directory, "mmapreloadtest"); +/// let options = ReloadOptions::default().set_mmap(true); +/// reloader.set_options(options); +/// let hnsw_loaded : Hnsw= reloader.load_hnsw::().unwrap(); +/// ``` +/// +/// In some cases we need a hnsw variable that can come from a reload **OR** a direct initialization. +/// +/// Hnswio must be defined before Hnsw as drop is done in reverse order of definition, and the function [load_hnsw](Self::load_hnsw()) +/// borrows Hnswio. (Hnswio stores the mmap address Hnsw can refer to if mmap is used) +/// It is also possible to preinitialize a Hnswio with the default() function which leaves all the fields with blank values and use +/// the function [set_values](Self::set_values()) after. +/// We get something like: +/// +/// ```text +/// let need_reload : bool; +/// .................... +/// let mut hnswio : Hnswio::default(); +/// let hnsw : Hnsw<>; +/// if need_reload { +/// hnswio.set_values(...); +/// hnsw = hnswio.reload_hnsw(...) +/// } +/// else { +/// hnsw = Hnsw::new(...) +/// } +/// ```` +#[derive(Default)] +pub struct HnswIo { + dir: PathBuf, + /// basename is used to build $basename.hnsw.data and $basename.hnsw.graph + basename: String, + /// options + options: ReloadOptions, + datamap: Option, + /// for Hnswio to be async + nb_point_loaded: Arc, + initialized: bool, +} // end of struct ReloadOptions + +impl HnswIo { + /// - directory is directory containing the dumped files, + /// - basename is used to build $basename.hnsw.data and $basename.hnsw.graph + /// + /// default is to use default ReloadOptions. + pub fn new(directory: &Path, basename: &str) -> Self { + HnswIo { + dir: directory.to_path_buf(), + basename: basename.to_string(), + options: ReloadOptions::default(), + datamap: None, + nb_point_loaded: Arc::new(AtomicUsize::new(0)), + initialized: true, + } + } + + /// same as preceding, avoids the call to [set_options](Self::set_options()) + pub fn new_with_options(directory: &Path, basename: &str, options: ReloadOptions) -> Self { + HnswIo { + dir: directory.to_path_buf(), + basename: basename.to_string(), + options, + datamap: None, + nb_point_loaded: Arc::new(AtomicUsize::new(0)), + initialized: true, + } + } + + /// return basename of dump + pub fn get_basename(&self) -> &str { + &self.basename + } + /// this method enables effective initialization after default allocation. + /// It is an error to call set_values on an already defined Hswnio by any function other than [default](Self::default()) + pub fn set_values( + &mut self, + directory: &Path, + basename: String, + options: ReloadOptions, + ) -> Result<()> { + if self.initialized { + return Err(anyhow!("Hnswio already initialized")); + }; + // + self.dir = directory.to_path_buf(); + self.basename = basename; + self.options = options; + self.datamap = None; + // + self.initialized = true; + // + Ok(()) + } // end of set_values + + // + fn init(&self) -> Result { + // + info!("reloading from basename : {}", &self.basename); + // + let mut graphname = self.basename.clone(); + graphname.push_str(".hnsw.graph"); + let mut graphpath = self.dir.clone(); + graphpath.push(graphname); + let graphfileres = OpenOptions::new().read(true).open(&graphpath); + if graphfileres.is_err() { + println!( + "HnswIo::reload_hnsw : could not open file {:?}", + graphpath.as_os_str() + ); + error!( + "HnswIo::reload_hnsw : could not open file {:?}", + graphpath.as_os_str() + ); + return Err(anyhow!( + "HnswIo::reload_hnsw : could not open file {:?}", + graphpath.as_os_str() + )); + } + let graphfile = graphfileres.unwrap(); + // same thing for data file + let mut dataname = self.basename.clone(); + dataname.push_str(".hnsw.data"); + let mut datapath = self.dir.clone(); + datapath.push(dataname); + let datafileres = OpenOptions::new().read(true).open(&datapath); + if datafileres.is_err() { + println!( + "HnswIo::init : could not open file {:?}", + datapath.as_os_str() + ); + error!( + "HnswIo::init : could not open file {:?}", + datapath.as_os_str() + ); + return Err(anyhow!( + "HnswIo::reload_hnsw : could not open file {:?}", + datapath.as_os_str() + )); + } + let datafile = datafileres.unwrap(); + // + let mut graph_in = BufReader::new(graphfile); + let data_in = BufReader::new(datafile); + // we need to call load_description first to get distance name + let hnsw_description = load_description(&mut graph_in).unwrap(); + // + Ok(LoadInit { + descr: hnsw_description, + graphfile: graph_in, + datafile: data_in, + }) + } + + /// to set non default options, in particular to ask for mmap of data file + pub fn set_options(&mut self, options: ReloadOptions) { + self.options = options; + } + + /// reload a previously dumped hnsw structure + pub fn load_hnsw<'b, 'a, T, D>(&'a mut self) -> Result> + where + T: 'static + Serialize + DeserializeOwned + Clone + Sized + Send + Sync + std::fmt::Debug, + D: Distance + Default + Send + Sync, + 'a: 'b, + { + // + debug!("HnswIo::load_hnsw "); + let start_t = SystemTime::now(); + // + let init = self.init(); + if init.is_err() { + return Err(anyhow!("could not reload HNSW structure")); + } + let mut init = init.unwrap(); + let data_in = &mut init.datafile; + let graph_in = &mut init.graphfile; + let description = init.descr; + info!("format version : {}", description.format_version); + // In datafile , we must read MAGICDATAP and dimension and check + let mut it_slice = [0u8; std::mem::size_of::()]; + data_in.read_exact(&mut it_slice)?; + let magic = u32::from_ne_bytes(it_slice); + assert_eq!( + magic, MAGICDATAP, + "magic not equal to MAGICDATAP in load_point" + ); + // + let mut it_slice = [0u8; std::mem::size_of::()]; + data_in.read_exact(&mut it_slice)?; + let dimension = usize::from_ne_bytes(it_slice); + assert_eq!( + dimension, description.dimension, + "data dimension incoherent {:?} {:?} ", + dimension, description.dimension + ); + // + let _mode = description.dumpmode; + let distname = description.distname.clone(); + // We must ensure that the distance stored matches the one asked for in loading hnsw + // for that we check for short names equality stripping + debug!("distance in description = {:?}", distname); + let d_type_name = type_name::().to_string(); + let d_type_name_split: Vec<&str> = d_type_name.rsplit_terminator("::").collect(); + for s in &d_type_name_split { + info!(" distname in generic type argument {:?}", s); + } + let distname_split: Vec<&str> = distname.rsplit_terminator("::").collect(); + if (std::any::TypeId::of::() != std::any::TypeId::of::()) + && (d_type_name_split[0] != distname_split[0]) + { + // for all types except NoData , distance asked in reload declaration and distance in dump must be equal! + let mut errmsg = String::from("error in distances : dumped distance is : "); + errmsg.push_str(&distname); + errmsg.push_str(" asked distance in loading is : "); + errmsg.push_str(&d_type_name); + error!(" distance in type argument : {:?}", d_type_name); + error!("error , dump is for distance = {:?}", distname); + return Err(anyhow!(errmsg)); + } + let t_type = description.t_name.clone(); + debug!("T type name in dump = {:?}", t_type); + // Do we use mmap at reload + if self.options.use_mmap().0 { + let datamap_res = DataMap::from_hnswdump::(self.dir.as_path(), &self.basename); + if datamap_res.is_err() { + error!("load_hnsw could not initialize mmap") + } else { + info!("reload using mmap"); + self.datamap = Some(datamap_res.unwrap()); + } + } + // reloader can use datamap + let layer_point_indexation = self.load_point_indexation(graph_in, &description, data_in)?; + let data_dim = layer_point_indexation.get_data_dimension(); + // + let hnsw: Hnsw = Hnsw { + max_nb_connection: description.max_nb_connection as usize, + ef_construction: description.ef, + extend_candidates: true, + keep_pruned: false, + max_layer: description.nb_layer as usize, + layer_indexed_points: layer_point_indexation, + data_dimension: data_dim, + dist_f: D::default(), + searching: false, + datamap_opt: true, // set datamap_opt to true + }; + // + debug!("load_hnsw completed"); + let elapsed_t = start_t.elapsed().unwrap().as_secs() as f32; + info!("reload_hnsw : elapsed system time(s) {}", elapsed_t); + Ok(hnsw) + } // end of load_hnsw + + /// reload a previously dumped hnsw structure + /// This function makes reload of a Hnsw dump with a given Dist. + /// It is dedicated to distance of type DistPtr (see crate [anndist](https://crates.io/crates/anndists)) that cannot implement Default. + /// **It is the user responsability to reload with the same function as used in the dump** + /// + pub fn load_hnsw_with_dist<'b, 'a, T, D>(&'a self, f: D) -> anyhow::Result> + where + T: 'static + Serialize + DeserializeOwned + Clone + Sized + Send + Sync + std::fmt::Debug, + D: Distance + Send + Sync, + 'a: 'b, + { + // + debug!("HnswIo::load_hnsw_with_dist"); + // + let init = self.init(); + if init.is_err() { + return Err(anyhow!("Could not reload hnsw structure")); + } + let mut init = init.unwrap(); + // + let data_in = &mut init.datafile; + let graph_in = &mut init.graphfile; + let description = init.descr; + // In datafile , we must read MAGICDATAP and dimension and check + let mut it_slice = [0u8; std::mem::size_of::()]; + data_in.read_exact(&mut it_slice)?; + let magic = u32::from_ne_bytes(it_slice); + assert_eq!( + magic, MAGICDATAP, + "magic not equal to MAGICDATAP in load_point" + ); + // + let mut it_slice = [0u8; std::mem::size_of::()]; + data_in.read_exact(&mut it_slice)?; + let dimension = usize::from_ne_bytes(it_slice); + assert_eq!( + dimension, description.dimension, + "data dimension incoherent {:?} {:?} ", + dimension, description.dimension + ); + // + let _mode = description.dumpmode; + let distname = description.distname.clone(); + // We must ensure that the distance stored matches the one asked for in loading hnsw + // for that we check for short names equality stripping + info!("distance in description = {:?}", distname); + let d_type_name = type_name::().to_string(); + let v: Vec<&str> = d_type_name.rsplit_terminator("::").collect(); + for s in v { + info!(" distname in generic type argument {:?}", s); + } + if (std::any::TypeId::of::() != std::any::TypeId::of::()) + && (d_type_name != distname) + { + // for all types except NoData , distance asked in reload declaration and distance in dump must be equal! + let mut errmsg = String::from("error in distances : dumped distance is : "); + errmsg.push_str(&distname); + errmsg.push_str(" asked distance in loading is : "); + errmsg.push_str(&d_type_name); + error!(" distance in type argument : {:?}", d_type_name); + error!("error , dump is for distance = {:?}", distname); + return Err(anyhow!(errmsg)); + } + let t_type = description.t_name.clone(); + info!("T type name in dump = {:?}", t_type); + // + // + let layer_point_indexation = self.load_point_indexation(graph_in, &description, data_in)?; + let data_dim = layer_point_indexation.get_data_dimension(); + // + let hnsw: Hnsw = Hnsw { + max_nb_connection: description.max_nb_connection as usize, + ef_construction: description.ef, + extend_candidates: true, + keep_pruned: false, + max_layer: description.nb_layer as usize, + layer_indexed_points: layer_point_indexation, + data_dimension: data_dim, + dist_f: f, + searching: false, + datamap_opt: false, + }; + // + debug!("load_hnsw_with_dist completed"); + // We cannot check that the pointer function was the same as the dump + // + Ok(hnsw) + } // end of load_hnsw_with_dist + + fn load_point_indexation<'b, 'a, T>( + &'a self, + graph_in: &mut dyn Read, + descr: &Description, + data_in: &mut dyn Read, + ) -> anyhow::Result> + where + T: 'static + Serialize + DeserializeOwned + Clone + Sized + Send + Sync + std::fmt::Debug, + 'a: 'b, + { + // + debug!(" in load_point_indexation"); + // + // now we check that except for the case NoData, the typename are the sames. + if std::any::TypeId::of::() != std::any::TypeId::of::() + && std::any::type_name::() != descr.t_name + { + error!( + "typename loaded in description {:?} do not correspond to instanciation type {:?}", + descr.t_name, + std::any::type_name::() + ); + panic!("incohrent size of T in description"); + } + // + let mut points_by_layer: Vec>>> = + Vec::with_capacity(NB_LAYER_MAX as usize); + let mut neighbourhood_map: HashMap>> = HashMap::new(); + // load max layer + let mut it_slice = [0u8; ::std::mem::size_of::()]; + graph_in.read_exact(&mut it_slice)?; + let nb_layer = u8::from_ne_bytes(it_slice); + debug!("nb layer {:?}", nb_layer); + if nb_layer > NB_LAYER_MAX { + return Err(anyhow!("inconsistent number of layErrers")); + } + // + let mut nb_points_loaded: usize = 0; + let mut nb_still_to_load = descr.nb_point as i64; + let (use_mmap, max_nbpoint_in_memory) = self.options.use_mmap(); + // + for l in 0..nb_layer as usize { + // read and check magic + debug!("loading layer {:?}", l); + let mut it_slice = [0u8; ::std::mem::size_of::()]; + graph_in.read_exact(&mut it_slice)?; + let magic = u32::from_ne_bytes(it_slice); + if magic != MAGICLAYER { + return Err(anyhow!("bad magic at layer beginning")); + } + let mut it_slice = [0u8; ::std::mem::size_of::()]; + graph_in.read_exact(&mut it_slice)?; + let nbpoints = usize::from_ne_bytes(it_slice); + debug!(" layer {:?} , nb points {:?}", l, nbpoints); + let mut vlayer: Vec>> = Vec::with_capacity(nbpoints); + // load graph and data part of point. Points are dumped in the same order. + for r in 0..nbpoints { + // do we use mmap? for this point. We must load into memory up to threshold points, and we also want the most + // frequently accessed points, i.e those in upper layers! to be physically loaded. + // So we do use mmap from the moment the number of points yet to be loaded is less than threshold. + let point_use_mmap = match use_mmap { + false => false, + true => { + if nb_still_to_load <= max_nbpoint_in_memory as i64 { + if log::log_enabled!(log::Level::Info) + && nb_still_to_load == max_nbpoint_in_memory as i64 + { + info!( + "Switching to points in memory. nb points stiil to load {:?}", + nb_still_to_load + ); + } + false + } else { + true + } + } + }; + let load_point_res = self.load_point(graph_in, descr, data_in, point_use_mmap); + if let Err(other) = load_point_res { + error!("in load_point_indexation, loading of point {} failed", r); + return Err(anyhow!(other)); + } + + let load_point_res = load_point_res.unwrap(); + let point = load_point_res.0; + let p_id = point.get_point_id(); + // some checks + assert_eq!(l, p_id.0 as usize); + if r != p_id.1 as usize { + debug!("Origin= {:?}, p_id = {:?}", point.get_origin_id(), p_id); + debug!("Storing at l {:?}, r {:?}", l, r); + } + assert_eq!(r, p_id.1 as usize); + // store neoghbour info of this point + neighbourhood_map.insert(p_id, load_point_res.1); + vlayer.push(point); + nb_points_loaded += 1; + nb_still_to_load -= 1; + assert!(nb_still_to_load >= 0); + } + points_by_layer.push(vlayer); + } + // at this step all points are loaded , but without their neighbours fileds are not yet initialized + let mut nbp: usize = 0; + for (p_id, neighbours) in &neighbourhood_map { + let point = &points_by_layer[p_id.0 as usize][p_id.1 as usize]; + for (l, neighbours) in neighbours.iter().enumerate() { + for n in neighbours { + let n_point = &points_by_layer[n.p_id.0 as usize][n.p_id.1 as usize]; + // now n_point is the Arc corresponding to neighbour n of point, + // construct a corresponding PointWithOrder + let n_pwo = PointWithOrder::::new(n_point, n.distance); + point.neighbours.write()[l].push(Arc::new(n_pwo)); + } // end of for n + // must sort + point.neighbours.write()[l].sort_unstable(); + } // end of for l + nbp += 1; + if nbp % 500_000 == 0 { + debug!("reloading nb_points neighbourhood completed : {}", nbp); + } + } // end loop in neighbourhood_map + // + // get id of entry_point + // load entry point + info!( + "end of layer loading, allocating PointIndexation, nb points loaded {:?}", + nb_points_loaded + ); + // + let mut it_slice = [0u8; std::mem::size_of::()]; + graph_in.read_exact(&mut it_slice)?; + let origin_id = DataId::from_ne_bytes(it_slice); + // + let mut it_slice = [0u8; ::std::mem::size_of::()]; + graph_in.read_exact(&mut it_slice)?; + let layer = u8::from_ne_bytes(it_slice); + // + let mut it_slice = [0u8; std::mem::size_of::()]; + graph_in.read_exact(&mut it_slice)?; + let rank_in_l = i32::from_ne_bytes(it_slice); + // + info!( + "found entry point, origin_id {:?} , layer {:?}, rank in layer {:?} ", + origin_id, layer, rank_in_l + ); + let entry_point = Arc::clone(&points_by_layer[layer as usize][rank_in_l as usize]); + info!( + " loaded entry point, origin_id {:} p_id {:?}", + entry_point.get_origin_id(), + entry_point.get_point_id() + ); + // + let point_indexation = PointIndexation { + max_nb_connection: descr.max_nb_connection as usize, + max_layer: NB_LAYER_MAX as usize, + points_by_layer: Arc::new(RwLock::new(points_by_layer)), + layer_g: LayerGenerator::new_with_scale( + descr.max_nb_connection as usize, + descr.level_scale, + NB_LAYER_MAX as usize, + ), + nb_point: Arc::new(RwLock::new(nb_points_loaded)), // CAVEAT , we should increase , the whole thing is to be able to increment graph ? + entry_point: Arc::new(RwLock::new(Some(entry_point))), + }; + // + debug!("Exiting load_pointIndexation"); + Ok(point_indexation) + } // end of load_pointIndexation + + // + // Reload a point from a dump. + // + // The graph part is loaded from graph_in file + // the data vector itself is loaded from data_in + // + #[allow(clippy::type_complexity)] + fn load_point<'b, 'a, T>( + &'a self, + graph_in: &mut dyn Read, + descr: &Description, + data_in: &mut dyn Read, + point_use_mmap: bool, + ) -> Result<(Arc>, Vec>)> + where + T: 'static + DeserializeOwned + Clone + Sized + Send + Sync + std::fmt::Debug, + 'a: 'b, + { + // + // debug!(" point load {:?} {:?} ", p_id, origin_id); + // Now for each layer , read neighbours + let load_res = load_point_graph(graph_in, descr); + if load_res.is_err() { + error!("load_point error reading graph data for point p_id"); + return Err(anyhow!("error reading graph data for point")); + } + let (origin_id, p_id, neighborhood) = load_res.unwrap(); + // + let point = match point_use_mmap { + false => { + let v = load_point_data::(origin_id, data_in, descr); + if v.is_err() { + error!("loading point {:?}", origin_id); + std::process::exit(1); + } + Point::::new(v.unwrap(), origin_id, p_id) + } + true => { + skip_point_data(origin_id, data_in, descr)?; // keep cohrence between data file and graph file! + debug!("constructing point from datamap, dataid : {:?}", origin_id); + let s: Option<&'b [T]> = self.datamap.as_ref().unwrap().get_data::(&origin_id); + Point::::new_from_mmap(s.unwrap(), origin_id, p_id) + } + }; + self.nb_point_loaded.fetch_add(1, Ordering::Relaxed); + trace!( + "load_point origin {:?} allocated size {:?}, dim {:?}", + origin_id, + point.get_v().len(), + descr.dimension + ); + // + Ok((Arc::new(point), neighborhood)) + } // end of load_point +} // end of Hnswio + +/// structure describing main parameters for hnsnw data and written at the beginning of a dump file. +/// +/// Name of distance and type of data must be encoded in the dump file for a coherent reload. +#[repr(C)] +pub struct Description { + /// to keep track of format version + pub format_version: usize, + /// value is 1 for Full 0 for Light + pub dumpmode: u8, + /// max number of connections in layers != 0 + pub max_nb_connection: u8, + /// scale used in level sampling + pub level_scale: f64, + /// number of observed layers + pub nb_layer: u8, + /// search parameter + pub ef: usize, + /// total number of points + pub nb_point: usize, + /// data dimension + pub dimension: usize, + /// name of distance + pub distname: String, + /// T typename + pub t_name: String, +} + +impl Description { + /// The dump of Description consists in : + /// . The value MAGICDESCR_* as a u32 (4 u8) + /// . The type of dump as u8 + /// . max_nb_connection as u8 + /// . ef (search parameter used in construction) as usize + /// . nb_point (the number points dumped) as a usize + /// . the name of distance used. (nb byes as a usize then list of bytes) + /// + fn dump(&self, argmode: DumpMode, out: &mut BufWriter) -> Result { + info!("in dump of description"); + out.write_all(&MAGICDESCR_4.to_ne_bytes())?; + let mode: u8 = match argmode { + DumpMode::Full => 1, + _ => 0, + }; + // CAVEAT should check mode == self.mode + out.write_all(&mode.to_ne_bytes())?; + // dump of max_nb_connection as u8!! + out.write_all(&self.max_nb_connection.to_ne_bytes())?; + // with MAGICDESCR_4 we must dump self.level_scale + out.write_all(&self.level_scale.to_ne_bytes())?; + // + out.write_all(&self.nb_layer.to_ne_bytes())?; + if self.nb_layer != NB_LAYER_MAX { + println!("dump of Description, nb_layer != NB_MAX_LAYER"); + return Err(anyhow!("dump of Description, nb_layer != NB_MAX_LAYER")); + } + // + info!("dumping ef {:?}", self.ef); + out.write_all(&self.ef.to_ne_bytes())?; + // + info!("dumping nb point {:?}", self.nb_point); + out.write_all(&self.nb_point.to_ne_bytes())?; + // + info!("dumping dimension of data {:?}", self.dimension); + out.write_all(&self.dimension.to_ne_bytes())?; + + // dump of distance name + let namelen: usize = self.distname.len(); + info!("distance name {:?} ", self.distname); + out.write_all(&namelen.to_ne_bytes())?; + out.write_all(self.distname.as_bytes())?; + // dump of T value typename + let namelen: usize = self.t_name.len(); + info!("T name {:?} ", self.t_name); + out.write_all(&namelen.to_ne_bytes())?; + out.write_all(self.t_name.as_bytes())?; + // + Ok(1) + } // end fo dump + + /// return data typename + pub fn get_typename(&self) -> String { + self.t_name.clone() + } + + /// returns dimension of data + pub fn get_dimension(&self) -> usize { + self.dimension + } +} // end of HnswIO impl for Descr + +// + +/// This method is internally used by Hnswio. +/// It is make *pub* as it can be used to retrieve the description of a dump. +/// It takes as input the graph part of the dump. +pub fn load_description(io_in: &mut dyn Read) -> Result { + // + let mut descr = Description { + format_version: 0, + dumpmode: 0, + max_nb_connection: 0, + level_scale: 1.0f64, + nb_layer: 0, + ef: 0, + nb_point: 0, + dimension: 0, + distname: String::from(""), + t_name: String::from(""), + }; + // + let mut it_slice = [0u8; std::mem::size_of::()]; + io_in.read_exact(&mut it_slice)?; + let magic = u32::from_ne_bytes(it_slice); + debug!(" magic {:X} ", magic); + match magic { + MAGICDESCR_2 => { + descr.format_version = 2; + } + MAGICDESCR_3 => { + descr.format_version = 3; + } + MAGICDESCR_4 => { + descr.format_version = 4; + } + _ => { + error!("bad magic"); + return Err(anyhow!("bad magic at descr beginning")); + } + } + let mut it_slice = [0u8; std::mem::size_of::()]; + io_in.read_exact(&mut it_slice)?; + descr.dumpmode = u8::from_ne_bytes(it_slice); + info!(" dumpmode {:?} ", descr.dumpmode); + // + let mut it_slice = [0u8; std::mem::size_of::()]; + io_in.read_exact(&mut it_slice)?; + descr.max_nb_connection = u8::from_ne_bytes(it_slice); + info!(" max_nb_connection {:?} ", descr.max_nb_connection); + // + if descr.format_version == 4 { + // we read modification for level sampling + let mut it_slice = [0u8; std::mem::size_of::()]; + io_in.read_exact(&mut it_slice)?; + descr.level_scale = f64::from_ne_bytes(it_slice); + info!(" level scale : {:.2e}", descr.level_scale); + } + // + let mut it_slice = [0u8; std::mem::size_of::()]; + io_in.read_exact(&mut it_slice)?; + descr.nb_layer = u8::from_ne_bytes(it_slice); + info!("nb_layer {:?} ", descr.nb_layer); + // ef + let mut it_slice = [0u8; std::mem::size_of::()]; + io_in.read_exact(&mut it_slice)?; + descr.ef = usize::from_ne_bytes(it_slice); + info!("ef {:?} ", descr.ef); + // nb_point + let mut it_slice = [0u8; std::mem::size_of::()]; + io_in.read_exact(&mut it_slice)?; + descr.nb_point = usize::from_ne_bytes(it_slice); + // read dimension + let mut it_slice = [0u8; std::mem::size_of::()]; + io_in.read_exact(&mut it_slice)?; + descr.dimension = usize::from_ne_bytes(it_slice); + info!( + "nb_point {:?} dimension {:?} ", + descr.nb_point, descr.dimension + ); + // distance name + let mut it_slice = [0u8; std::mem::size_of::()]; + io_in.read_exact(&mut it_slice)?; + let len: usize = usize::from_ne_bytes(it_slice); + debug!("length of distance name {:?} ", len); + if len > 256 { + info!(" length of distance name > 256"); + println!(" length of distance name should not exceed 256"); + return Err(anyhow!("bad length for distance name")); + } + let mut distv = vec![0; len]; + io_in.read_exact(distv.as_mut_slice())?; + let distname = String::from_utf8(distv).unwrap(); + debug!("distance name {:?} ", distname); + descr.distname = distname; + // reload of type name + let mut it_slice = [0u8; std::mem::size_of::()]; + io_in.read_exact(&mut it_slice)?; + let len: usize = usize::from_ne_bytes(it_slice); + debug!("length of T name {:?} ", len); + if len > 256 { + println!(" length of T name should not exceed 256"); + return Err(anyhow!("bad lenght for T name")); + } + let mut tnamev = vec![0; len]; + io_in.read_exact(tnamev.as_mut_slice())?; + let t_name = String::from_utf8(tnamev).unwrap(); + debug!("T type name {:?} ", t_name); + descr.t_name = t_name; + debug!(" end of description load \n"); + // + Ok(descr) +} + +// +// dump and load of Point +// ========================== +// + +/// Graph part of point dump +/// dump of a point consist in +/// 1. The value MAGICPOINT +/// 2. its identity ( a usize rank in original data , hash value or else , and PointId) +/// 3. for each layer dump of the number of neighbours followed by : +/// for each neighbour dump of its identity (: usize) and then distance (): u32) to point dumped. +/// +/// identity of a point is in full mode the triplet origin_id (: usize), layer (: u8) rank_in_layer (: u32) +/// light mode only origin_id (: usize) +/// For data dump +/// 1. The value MAGICDATAP (u32) +/// 2. origin_id as a u64 +/// 3. The vector of data (the length is known from Description) +/// +fn dump_point( + point: &Point, + mode: DumpMode, + graphout: &mut BufWriter, + dataout: &mut BufWriter, +) -> Result { + // + graphout.write_all(&MAGICPOINT.to_ne_bytes())?; + // dump ext_id: usize , layer : u8 , rank in layer : i32 + graphout.write_all(&point.get_origin_id().to_ne_bytes())?; + let p_id = point.get_point_id(); + if mode == DumpMode::Full { + graphout.write_all(&p_id.0.to_ne_bytes())?; + graphout.write_all(&p_id.1.to_ne_bytes())?; + } + trace!(" point dump {:?} {:?} ", p_id, point.get_origin_id()); + // then dump neighborhood info : nb neighbours : u32 , then list of origin_id, layer, rank_in_layer + let neighborhood = point.get_neighborhood_id(); + // in any case nb_layers are dumped with possibly 0 neighbours at a layer, but this does not occur by construction + for (l, neighbours_at_l) in neighborhood.iter().enumerate() { + // Caution : we dump number of neighbours as a usize, even if it cannot be so large! + let nbg_l: usize = neighbours_at_l.len(); + trace!("\t dumping nbng : {} at l {}", nbg_l, l); + graphout.write_all(&nbg_l.to_ne_bytes())?; + for n in neighbours_at_l { + // dump d_id : uszie , distance : f32, layer : u8, rank in layer : i32 + graphout.write_all(&n.d_id.to_ne_bytes())?; + if mode == DumpMode::Full { + graphout.write_all(&n.p_id.0.to_ne_bytes())?; + graphout.write_all(&n.p_id.1.to_ne_bytes())?; + } + graphout.write_all(&n.distance.to_ne_bytes())?; + // debug!(" voisins {:?} {:?} {:?}", n.p_id, n.d_id , n.distance); + } + } + // now we dump data vector! + dataout.write_all(&MAGICDATAP.to_ne_bytes())?; + let origin_u64 = point.get_origin_id() as u64; + dataout.write_all(&origin_u64.to_ne_bytes())?; + // + let serialized = unsafe { + std::slice::from_raw_parts( + point.get_v().as_ptr() as *const u8, + std::mem::size_of_val(point.get_v()), + ) + }; + trace!("serializing len {:?}", serialized.len()); + let len_64 = serialized.len() as u64; + dataout.write_all(&len_64.to_ne_bytes())?; + dataout.write_all(serialized)?; + // + Ok(1) +} // end of dump for Point + +// just reload data vector for point from file where data were dumped +// used when we do not used memory map in reload +fn load_point_data( + origin_id: usize, + data_in: &mut dyn Read, + descr: &Description, +) -> Result> +where + T: 'static + DeserializeOwned + Clone + Sized + Send + Sync, +{ + // + trace!("load_point_data , origin id : {}", origin_id); + // + // construct a point from data_in + // + let mut it_slice = [0u8; std::mem::size_of::()]; + data_in.read_exact(&mut it_slice)?; + let magic = u32::from_ne_bytes(it_slice); + assert_eq!( + magic, MAGICDATAP, + "magic not equal to MAGICDATAP in load_point, point_id : {:?} ", + origin_id + ); + // read origin id + let mut it_slice = [0u8; std::mem::size_of::()]; + data_in.read_exact(&mut it_slice)?; + let origin_id_data = u64::from_ne_bytes(it_slice) as usize; + assert_eq!( + origin_id, origin_id_data, + "origin_id incoherent between graph and data" + ); + // now read data. we use size_t that is in description, to take care of the casewhere we reload + let mut it_slice = [0u8; std::mem::size_of::()]; + data_in.read_exact(&mut it_slice)?; + let serialized_len = u64::from_ne_bytes(it_slice); + trace!("serialized len to reload {:?}", serialized_len); + let mut v_serialized = vec![0; serialized_len as usize]; + data_in.read_exact(&mut v_serialized)?; + + let v: Vec = if std::any::TypeId::of::() != std::any::TypeId::of::() { + match descr.format_version { + 2 => bincode::deserialize(&v_serialized).unwrap(), + 3 | 4 => { + let slice_t = unsafe { + std::slice::from_raw_parts(v_serialized.as_ptr() as *const T, descr.dimension) + }; + slice_t.to_vec() + } + _ => { + error!( + "error in load_point, unknow format_version : {:?}", + descr.format_version + ); + std::process::exit(1); + } + } + } else { + Vec::new() + }; + // + Ok(v) +} // end of load_point_data + +// We need to maintain coherence in data and graph stream, so we read to keep in phase +fn skip_point_data(origin_id: usize, data_in: &mut dyn Read, _descr: &Description) -> Result<()> { + // + let mut it_slice = [0u8; std::mem::size_of::()]; + data_in.read_exact(&mut it_slice)?; + let magic = u32::from_ne_bytes(it_slice); + assert_eq!( + magic, MAGICDATAP, + "magic not equal to MAGICDATAP in load_point, point_id : {:?} ", + origin_id + ); + // read origin id + let mut it_slice = [0u8; std::mem::size_of::()]; + data_in.read_exact(&mut it_slice)?; + let origin_id_data = u64::from_ne_bytes(it_slice) as usize; + assert_eq!( + origin_id, origin_id_data, + "origin_id incoherent between graph and data" + ); + // + // now read data. we use size_t that is in description, to take care of the casewhere we reload + let mut it_slice = [0u8; std::mem::size_of::()]; + data_in.read_exact(&mut it_slice)?; + let serialized_len = u64::from_ne_bytes(it_slice); + trace!( + "skip_point_data : serialized len to reload {:?}", + serialized_len + ); + let mut v_serialized = vec![0; serialized_len as usize]; + data_in.read_exact(&mut v_serialized)?; + // + Ok(()) +} // end of skip_point_data + +//================================================================================== + +/// This structure gathers info loaded in dumped graph file for a point. +type PointGraphInfo = (usize, PointId, Vec>); + +// This function reads neighbourhood info and returns neighbourhood info. +// It suppose and requires that the file graph_in is just at beginning of info related to origin_id +fn load_point_graph(graph_in: &mut dyn Read, descr: &Description) -> Result { + // + trace!("in load_point_graph"); + // read and check magic + let mut it_slice = [0u8; std::mem::size_of::()]; + graph_in.read_exact(&mut it_slice).unwrap(); + let magic = u32::from_ne_bytes(it_slice); + if magic != MAGICPOINT { + error!("got instead of MAGICPOINT {:x}", magic); + return Err(anyhow!("bad magic at point beginning")); + } + let mut it_slice = [0u8; std::mem::size_of::()]; + graph_in.read_exact(&mut it_slice).unwrap(); + let origin_id = DataId::from_ne_bytes(it_slice); + // + // read point_id + let mut it_slice = [0u8; std::mem::size_of::()]; + graph_in.read_exact(&mut it_slice).unwrap(); + let layer = u8::from_ne_bytes(it_slice); + // + let mut it_slice = [0u8; std::mem::size_of::()]; + graph_in.read_exact(&mut it_slice).unwrap(); + let rank_in_l = i32::from_ne_bytes(it_slice); + let p_id = PointId(layer, rank_in_l); + debug!( + "in load_point_graph, got origin_id : {}, p_id : {:?}", + origin_id, p_id + ); + // + // Now for each layer , read neighbours + let nb_layer = descr.nb_layer; + let mut neighborhood = Vec::>::with_capacity(NB_LAYER_MAX as usize); + for _l in 0..nb_layer { + let mut neighbour: Neighbour = Default::default(); + // read nb_neighbour as usize!!! CAUTION, then nb_neighbours times identity(depends on Full or Light) distance : f32 + let mut it_slice = [0u8; std::mem::size_of::()]; + graph_in.read_exact(&mut it_slice).unwrap(); + let nb_neighbours = usize::from_ne_bytes(it_slice); + let mut neighborhood_l: Vec = Vec::with_capacity(nb_neighbours); + for _j in 0..nb_neighbours { + let mut it_slice = [0u8; std::mem::size_of::()]; + graph_in.read_exact(&mut it_slice).unwrap(); + neighbour.d_id = DataId::from_ne_bytes(it_slice); + if descr.dumpmode == 1 { + let mut it_slice = [0u8; std::mem::size_of::()]; + graph_in.read_exact(&mut it_slice).unwrap(); + neighbour.p_id.0 = u8::from_ne_bytes(it_slice); + // + let mut it_slice = [0u8; std::mem::size_of::()]; + graph_in.read_exact(&mut it_slice).unwrap(); + neighbour.p_id.1 = i32::from_ne_bytes(it_slice); + } + let mut it_slice = [0u8; std::mem::size_of::()]; + graph_in.read_exact(&mut it_slice).unwrap(); + neighbour.distance = f32::from_ne_bytes(it_slice); + // debug!(" voisins load {:?} {:?} {:?} ", neighbour.p_id, neighbour.d_id , neighbour.distance); + // now we have a new neighbour, we must really fill neighbourhood info, so it means going from Neighbour to PointWithOrder + neighborhood_l.push(neighbour); + } + neighborhood.push(neighborhood_l); + } + for _l in nb_layer..NB_LAYER_MAX { + neighborhood.push(Vec::::new()); + } + // + let point_grap_info = (origin_id, p_id, neighborhood); + // + Ok(point_grap_info) +} // end of load_point_graph + +// +// dump and load of PointIndexation +// =================================== +// +// +// nb_layer : 8 +// a magick at each Layer : u32 +// . number of points in layer (usize), +// . list of point of layer +// dump entry point +// +impl HnswIoT for PointIndexation<'_, T> { + fn dump(&self, mode: DumpMode, dumpinit: &mut DumpInit) -> Result { + let graphout = &mut dumpinit.graph_out; + let dataout = &mut dumpinit.data_out; + // dump max_layer + let layers = self.points_by_layer.read(); + let nb_layer = layers.len() as u8; + graphout.write_all(&nb_layer.to_ne_bytes())?; + // dump layers from lower (most populatated to higher level) + for i in 0..layers.len() { + let nb_point = layers[i].len(); + debug!("dumping layer {:?}, nb_point {:?}", i, nb_point); + graphout.write_all(&MAGICLAYER.to_ne_bytes())?; + graphout.write_all(&nb_point.to_ne_bytes())?; + for j in 0..layers[i].len() { + assert_eq!(layers[i][j].get_point_id(), PointId(i as u8, j as i32)); + dump_point(&layers[i][j], mode, graphout, dataout)?; + } + } + // dump id of entry point + let ep_read = self.entry_point.read(); + let ep = ep_read + .as_ref() + .ok_or(anyhow!("entry point not initialized"))?; + //let ep = ep_read.as_ref().unwrap(); + graphout.write_all(&ep.get_origin_id().to_ne_bytes())?; + let p_id = ep.get_point_id(); + if mode == DumpMode::Full { + graphout.write_all(&p_id.0.to_ne_bytes())?; + graphout.write_all(&p_id.1.to_ne_bytes())?; + } + info!( + "dumped entry_point origin_d {:?}, p_id {:?} ", + ep.get_origin_id(), + p_id + ); + // + Ok(1) + } // end of dump for PointIndexation +} // end of impl HnswIO + +// +// dump and load of Hnsw +// ========================= +// +// + +impl + Send + Sync> + HnswIoT for Hnsw<'_, T, D> +{ + /// The dump method for hnsw. + /// - graphout is a BufWriter dedicated to the dump of the graph part of Hnsw + /// - dataout is a bufWriter dedicated to the dump of the data stored in the Hnsw structure. + fn dump(&self, mode: DumpMode, dumpinit: &mut DumpInit) -> anyhow::Result { + // + let graphout = &mut dumpinit.graph_out; + let dataout = &mut dumpinit.data_out; + // dump description , then PointIndexation + let dumpmode: u8 = match mode { + DumpMode::Full => 1, + _ => 0, + }; + let datadim: usize = self.layer_indexed_points.get_data_dimension(); + let level_scale = self.layer_indexed_points.get_level_scale(); + let description = Description { + format_version: 3, + // value is 1 for Full 0 for Light + dumpmode, + max_nb_connection: self.get_max_nb_connection(), + level_scale, + nb_layer: self.get_max_level() as u8, + ef: self.get_ef_construction(), + nb_point: self.get_nb_point(), + dimension: datadim, + distname: self.get_distance_name(), + t_name: type_name::().to_string(), + }; + debug!("dump obtained typename {:?}", type_name::()); + description.dump(mode, graphout)?; + // We must dump a header for dataout. + dataout.write_all(&MAGICDATAP.to_ne_bytes())?; + dataout.write_all(&datadim.to_ne_bytes())?; + // + self.layer_indexed_points.dump(mode, dumpinit)?; + Ok(1) + } +} // end impl block for Hnsw + +//=============================================================================================================== + +#[cfg(test)] + +mod tests { + use super::*; + + pub use crate::api::AnnT; + use anndists::dist; + use log::error; + + use rand::distr::{Distribution, Uniform}; + + fn log_init_test() { + let _ = env_logger::builder().is_test(true).try_init(); + } + + fn my_fn(v1: &[f32], v2: &[f32]) -> f32 { + let norm_l1: f32 = v1.iter().zip(v2.iter()).map(|t| (*t.0 - *t.1).abs()).sum(); + norm_l1 + } + + #[test] + fn test_dump_reload_1() { + println!("\n\n test_dump_reload_1"); + log_init_test(); + // generate a random test + let mut rng = rand::rng(); + let unif = Uniform::::new(0., 1.).unwrap(); + // 1000 vectors of size 10 f32 + let nbcolumn = 1000; + let nbrow = 10; + let mut xsi; + let mut data = Vec::with_capacity(nbcolumn); + for j in 0..nbcolumn { + data.push(Vec::with_capacity(nbrow)); + for _ in 0..nbrow { + xsi = unif.sample(&mut rng); + data[j].push(xsi); + } + } + // define hnsw + let ef_construct = 25; + let nb_connection = 10; + let hnsw = Hnsw::::new( + nb_connection, + nbcolumn, + 16, + ef_construct, + dist::DistL1 {}, + ); + for (i, d) in data.iter().enumerate() { + hnsw.insert((d, i)); + } + // some loggin info + hnsw.dump_layer_info(); + // dump in a file. Must take care of name as tests runs in // !!! + let fname = "dumpreloadtest1"; + let directory = tempfile::tempdir().unwrap(); + let _res = hnsw.file_dump(directory.path(), fname); + // + // reload + debug!("\n\n test_dump_reload_1 hnsw reload"); + // we will need a procedural macro to get from distance name to its instanciation. + // from now on we test with DistL1 + let mut reloader = HnswIo::new(directory.path(), fname); + let hnsw_loaded: Hnsw = reloader.load_hnsw::().unwrap(); + // test equality + check_graph_equality(&hnsw_loaded, &hnsw); + } // end of test_dump_reload + + #[test] + fn test_dump_reload_myfn() { + println!("\n\n test_dump_reload_myfn"); + log_init_test(); + // generate a random test + let mut rng = rand::rng(); + let unif = Uniform::::new(0., 1.).unwrap(); + // 1000 vectors of size 10 f32 + let nbcolumn = 1000; + let nbrow = 10; + let mut xsi; + let mut data = Vec::with_capacity(nbcolumn); + for j in 0..nbcolumn { + data.push(Vec::with_capacity(nbrow)); + for _ in 0..nbrow { + xsi = unif.sample(&mut rng); + data[j].push(xsi); + } + } + // define hnsw + let ef_construct = 25; + let nb_connection = 10; + let mydist = dist::DistPtr::::new(my_fn); + let hnsw = Hnsw::>::new( + nb_connection, + nbcolumn, + 16, + ef_construct, + mydist, + ); + for (i, d) in data.iter().enumerate() { + hnsw.insert((d, i)); + } + // some loggin info + hnsw.dump_layer_info(); + let fname = "dumpreloadtest_myfn"; + let directory = tempfile::tempdir().unwrap(); + + let _res = hnsw.file_dump(directory.path(), fname); + // This will dump in 2 files named dumpreloadtest.hnsw.graph and dumpreloadtest.hnsw.data + // + // reload + debug!("HNSW reload"); + let reloader = HnswIo::new(directory.path(), fname); + let mydist = dist::DistPtr::::new(my_fn); + let _hnsw_loaded: Hnsw> = + reloader.load_hnsw_with_dist(mydist).unwrap(); + } // end of test_dump_reload_myfn + + #[test] + fn test_dump_reload_graph_only() { + println!("\n\n test_dump_reload_graph_only"); + log_init_test(); + // generate a random test + let mut rng = rand::rng(); + let unif = Uniform::::new(0., 1.).unwrap(); + // 1000 vectors of size 10 f32 + let nbcolumn = 1000; + let nbrow = 10; + let mut xsi; + let mut data = Vec::with_capacity(nbcolumn); + for j in 0..nbcolumn { + data.push(Vec::with_capacity(nbrow)); + for _ in 0..nbrow { + xsi = unif.sample(&mut rng); + data[j].push(xsi); + } + } + // define hnsw + let ef_construct = 25; + let nb_connection = 10; + let hnsw = Hnsw::::new( + nb_connection, + nbcolumn, + 16, + ef_construct, + dist::DistL1 {}, + ); + for (i, d) in data.iter().enumerate() { + hnsw.insert((d, i)); + } + // some loggin info + hnsw.dump_layer_info(); + // dump in a file. Must take care of name as tests runs in // !!! + let fname = "dumpreloadtestgraph"; + let directory = tempfile::tempdir().unwrap(); + let _res = hnsw.file_dump(directory.path(), fname); + // This will dump in 2 files named dumpreloadtest.hnsw.graph and dumpreloadtest.hnsw.data + // + // reload + debug!("\n\n hnsw reload"); + let mut reloader = HnswIo::new(directory.path(), fname); + let hnsw_loaded: Hnsw = reloader.load_hnsw().unwrap(); + // test equality + check_graph_equality(&hnsw_loaded, &hnsw); + } // end of test_dump_reload + + // this tests reloads a dump with memory mapping of data, inserts new data and redump + #[test] + fn reload_with_mmap() { + println!("\n\n hnswio tests : reload_with_mmap"); + log_init_test(); + // generate a random test + let mut rng = rand::rng(); + let unif = Uniform::::new(0., 1.).unwrap(); + // 100 vectors of size 10 f32 + let nbcolumn = 100; + let nbrow = 10; + let mut xsi; + let mut data = Vec::with_capacity(nbcolumn); + for j in 0..nbcolumn { + data.push(Vec::with_capacity(nbrow)); + for _ in 0..nbrow { + xsi = unif.sample(&mut rng); + data[j].push(xsi); + } + } + // + let first: Vec = data[0].clone(); + info!("data[0] = {:?}", first); + // define hnsw + let ef_construct = 25; + let nb_connection = 10; + let hnsw = Hnsw::::new( + nb_connection, + nbcolumn, + 16, + ef_construct, + dist::DistL1 {}, + ); + for (i, d) in data.iter().enumerate() { + hnsw.insert((d, i)); + } + // some loggin info + hnsw.dump_layer_info(); + // dump in a file. Must take care of name as tests runs in // !!! + let fname = "mmapreloadtest"; + let directory = tempfile::tempdir().unwrap(); + let dumpname = hnsw.file_dump(directory.path(), fname).unwrap(); + debug!("dump succeeded in file basename : {}", dumpname); + // + // reload reload_with_mmap + debug!("HNSW reload"); + let mut reloader = HnswIo::new(directory.path(), &dumpname); + // use mmap for points after half number of points + let options = ReloadOptions::default().set_mmap_threshold(nbcolumn / 2); + reloader.set_options(options); + let hnsw_loaded: Hnsw = reloader.load_hnsw::().unwrap(); + // test equality + check_graph_equality(&hnsw_loaded, &hnsw); + // We add nbcolumn new vectors + info!("adding points in hnsw reloaded"); + let nbcolumn = 5; + let nbrow = 10; + let mut xsi; + let mut data = Vec::with_capacity(nbcolumn); + for j in 0..nbcolumn { + data.push(Vec::with_capacity(nbrow)); + for _ in 0..nbrow { + xsi = unif.sample(&mut rng); + data[j].push(xsi); + } + } + let first_with_mmap: Vec = data[0].clone(); + info!( + "first added after reloading with mmap : data[0] = {:?}", + first_with_mmap + ); + let nb_in = hnsw.get_nb_point(); + for (i, d) in data.iter().enumerate() { + hnsw.insert((d, i + nb_in)); + } + // + let search_res = hnsw.search(&first, 5, ef_construct); + info!("neighbours od first point inserted"); + for n in &search_res { + info!("neighbour: {:?}", n); + } + assert_eq!(search_res[0].d_id, 0); + assert_eq!(search_res[0].distance, 0.); + let search_res = hnsw.search(&first_with_mmap, 5, ef_construct); + info!("neighbours of first point inserted after reload with mmap"); + for n in &search_res { + info!("neighbour {:?}", n); + } + if search_res[0].d_id != nb_in { + // with very low probability it could happen that we find a very near point! + // then distance should very small + info!( + "neighbour found for point id : {}, distance : {:.2e}, should have been id : {}, dist : {:.2e}", + search_res[0].d_id, search_res[0].distance, nb_in, 0. + ); + } + assert_eq!(search_res[0].d_id, nb_in); + assert_eq!(search_res[0].distance, 0.); + // + // TODO: redump and care about mmapped file, so we do not overwrite + // + let dump_init = DumpInit::new(directory.path(), fname, false); + info!("will use basename : {}", dump_init.get_basename()); + let res = hnsw.file_dump(directory.path(), dump_init.get_basename()); + if res.is_err() { + error!("hnsw.file_dump failed"); + std::panic!("hnsw.file_dump failed"); + } + } // end of reload_with_mmap + + #[test] + fn test_bincode() { + let mut rng = rand::rng(); + let unif = Uniform::::new(0., 1.).unwrap(); + let size = 10; + let mut xsi; + let mut data = Vec::with_capacity(size); + for _ in 0..size { + xsi = unif.sample(&mut rng); + println!("xsi = {:?}", xsi); + data.push(xsi); + } + println!("to serialized {:?}", data); + + let v_serialized: Vec = bincode::serialize(&data).unwrap(); + debug!("serializing len {:?}", v_serialized.len()); + let v_deserialized: Vec = bincode::deserialize(&v_serialized).unwrap(); + println!("deserialized {:?}", v_deserialized); + } + + #[test] + fn read_write_empty_db() -> Result<()> { + log_init_test(); + let ef_construct = 25; + let nb_connection = 10; + let hnsw = + Hnsw::::new(nb_connection, 0, 16, ef_construct, dist::DistL1 {}); + let fname = "empty_db"; + let directory = tempfile::tempdir()?; + let _res = hnsw.file_dump(directory.path(), fname); + let mut reloader = HnswIo::new(directory.path(), fname); + let hnsw_loaded_res = reloader.load_hnsw::(); + assert!(hnsw_loaded_res.is_err()); + Ok(()) + } +} // end module tests diff --git a/patches/hnsw_rs/src/lib.rs b/patches/hnsw_rs/src/lib.rs new file mode 100644 index 000000000..9c207134f --- /dev/null +++ b/patches/hnsw_rs/src/lib.rs @@ -0,0 +1,30 @@ +#![cfg_attr(feature = "stdsimd", feature(portable_simd))] +// +// for logging (debug mostly, switched at compile time in cargo.toml) +use env_logger::Builder; + +use lazy_static::lazy_static; + +pub mod api; +pub mod datamap; +pub mod filter; +pub mod flatten; +pub mod hnsw; +pub mod hnswio; +pub mod libext; +pub mod prelude; + +// we impose our version of anndists +pub use anndists; + +lazy_static! { + static ref LOG: u64 = init_log(); +} + +// install a logger facility +#[allow(unused)] +fn init_log() -> u64 { + Builder::from_default_env().init(); + println!("\n ************** initializing logger *****************\n"); + 1 +} diff --git a/patches/hnsw_rs/src/libext.rs b/patches/hnsw_rs/src/libext.rs new file mode 100644 index 000000000..8543a3dba --- /dev/null +++ b/patches/hnsw_rs/src/libext.rs @@ -0,0 +1,1240 @@ +//! This file contains lib to call hnsw from julia (or any language providing a C api) +//! The AnnT trait is implemented with macros for u32, u16, u8, f32, f64 and i32. +//! The macro declare_myapi_type! produces struct HnswApif32 and so on. +//! + +#![allow(non_camel_case_types)] + +use core::ffi::c_ulonglong; +use std::fs::OpenOptions; +use std::io::BufReader; +use std::path::PathBuf; +use std::ptr; + +use anndists::dist::distances::*; +use log::{debug, error, info, trace, warn}; + +use crate::api::*; +use crate::hnsw::*; +use crate::hnswio::*; + +//========== Hnswio + +/// returns a pointer to a Hnswio +/// args corresponds to string giving base filename of dump, supposed to be in current directory +/// # Safety +/// pointer must be char* pointer to the string +#[unsafe(no_mangle)] +pub unsafe extern "C" fn get_hnswio(flen: u64, name: *const u8) -> *const HnswIo { + let slice = unsafe { std::slice::from_raw_parts(name, flen as usize) }; + let filename = String::from_utf8_lossy(slice).into_owned(); + let hnswio = HnswIo::new(std::path::Path::new("."), &filename); + Box::into_raw(Box::new(hnswio)) +} + +//================= +// the export macro makes the macro global in crate and accecssible via crate::declare_myapi_type! +#[macro_export] +macro_rules! declare_myapi_type( + ($name:ident, $ty:ty) => ( + pub struct $name { +#[allow(dead_code)] + pub(crate) opaque: Box>, + } // end struct + impl $name { + pub fn new(arg: Box>) -> Self { + $name{ opaque:arg} + } // end new + } // end impl + ) +); + +declare_myapi_type!(HnswApiNodata, NoData); + +// declare_myapi_type!(HnswApif64, f64); +// declare_myapi_type!(HnswApif32, f32); + +/// to be able to return a vector from rust in a julia struct before converting to a julia Vector +#[repr(C)] +pub struct Vec_api { + len: i64, + ptr: *const T, +} // end struct + +#[repr(C)] +/// The basic Neighbour info returned by api +pub struct Neighbour_api { + /// id of neighbour + pub id: usize, + /// distance of data sent in request to this neighbour + pub d: f32, +} + +impl From<&Neighbour> for Neighbour_api { + fn from(neighbour: &Neighbour) -> Self { + Neighbour_api { + id: neighbour.d_id, + d: neighbour.distance, + } + } +} + +#[repr(C)] +/// The response to a neighbour search requests +pub struct Neighbourhood_api { + pub nbgh: i64, + pub neighbours: *const Neighbour_api, +} + +#[repr(C)] +pub struct Neighbour_api_parsearch_answer { + /// The number of answers (o request), i.e size of both vectors nbgh and neighbours + pub nb_answer: usize, + /// for each request, we get a Neighbourhood_api + pub neighbourhoods: *const Neighbourhood_api, +} + +//===================================== f32 type ===================================== + +// macros have been exported to the root of the crate so we do not refer to them via api:: +super::declare_myapi_type!(HnswApif32, f32); +super::declare_myapi_type!(HnswApif64, f64); + +//=================================================================================================== +// These are the macros to generate trait implementation for useful numeric types +#[allow(unused_macros)] +macro_rules! generate_insert( +($function_name:ident, $api_name:ty, $type_val:ty) => ( + /// # Safety + /// The function is unsafe because it dereferences a raw pointer + /// + #[unsafe(no_mangle)] + pub unsafe extern "C" fn $function_name(hnsw_api : *mut $api_name, len:usize, data : *const $type_val, id : usize) { + trace!("entering insert, type {:?} vec len is {:?}, id : {:?} ", stringify!($type_val), len, id); + // construct vector: Rust clones and take ownership. + let data_v : Vec<$type_val>; + unsafe { + let slice = std::slice::from_raw_parts(data, len); + data_v = Vec::from(slice); + trace!("calling insert data"); + (*hnsw_api).opaque.insert_data(&data_v, id); + } + trace!("exiting insert for type {:?}", stringify!($type_val)); + } // end of insert + ) +); + +macro_rules! generate_parallel_insert( +($function_name:ident, $api_name:ty, $type_val:ty) => ( + /// # Safety + /// The function is unsafe because it dereferences a raw pointer + /// + #[unsafe(no_mangle)] + pub unsafe extern "C" fn $function_name(hnsw_api : *mut $api_name, nb_vec: usize, vec_len : usize, + datas : *mut *const $type_val, ids : *const usize) { + // + trace!("entering parallel_insert type {:?} , vec len is {:?}, nb_vec : {:?}", stringify!($type_val), vec_len, nb_vec); + let data_ids : Vec; + let data_ptrs : Vec<*const $type_val>; + unsafe { + let slice = std::slice::from_raw_parts(ids, nb_vec); + data_ids = Vec::from(slice); + } + // debug!("got ids"); + unsafe { + let slice = std::slice::from_raw_parts(datas, nb_vec); + data_ptrs = Vec::from(slice); + } + // debug!("got data ptrs"); + let mut data_v = Vec::>::with_capacity(nb_vec); + for i in 0..nb_vec { + unsafe { + let slice = std::slice::from_raw_parts(data_ptrs[i], vec_len); + let v = Vec::from(slice); + data_v.push(v); + } + } + // debug!("sending request"); + let mut request : Vec<(&Vec<$type_val>, usize)> = Vec::with_capacity(nb_vec); + for i in 0..nb_vec { + request.push((&data_v[i], data_ids[i])); + } + // + unsafe { (*hnsw_api).opaque.parallel_insert_data(&request); }; + trace!("exiting parallel_insert"); + } // end of parallel_insert + ) +); + +macro_rules! generate_search_neighbours( +($function_name:ident, $api_name:ty, $type_val:ty) => ( + /// # Safety + /// The function is unsafe because it dereferences a raw pointer + /// + #[unsafe(no_mangle)] + pub unsafe extern "C" fn $function_name(hnsw_api : *const $api_name, len:usize, data : *const $type_val, + knbn : usize, ef_search : usize) -> *const Neighbourhood_api { + // + trace!("entering search_neighbours type {:?}, vec len is {:?}, id : {:?} ef_search {:?}", stringify!($type_val), len, knbn, ef_search); + let data_v : Vec<$type_val>; + let neighbours : Vec; + unsafe { + let slice = std::slice::from_raw_parts(data, len); + data_v = Vec::from(slice); + trace!("calling search neighbours {:?}", data_v); + neighbours = (*hnsw_api).opaque.search_neighbours(&data_v, knbn, ef_search); + } + let neighbours_api : Vec = neighbours.iter().map(|n| Neighbour_api::from(n)).collect(); + trace!(" got nb neighbours {:?}", neighbours_api.len()); + // for i in 0..neighbours_api.len() { + // println!(" id {:?} dist : {:?} ", neighbours_api[i].id, neighbours_api[i].d); + // } + let nbgh_i = neighbours.len() as i64; + let neighbours_ptr = neighbours_api.as_ptr(); + std::mem::forget(neighbours_api); + let answer = Neighbourhood_api { + nbgh : nbgh_i, + neighbours : neighbours_ptr, + }; + trace!("search_neighbours returning nb neighbours {:?} id ptr {:?} ", nbgh_i, neighbours_ptr); + Box::into_raw(Box::new(answer)) + } + ) +); + +macro_rules! generate_parallel_search_neighbours( +($function_name:ident, $api_name:ty, $type_val:ty) => ( + #[unsafe(no_mangle)] + /// search nb_vec of size vec_len. The the searches will be done in // as far as possible. + /// # Safety + /// The function is unsafe because it dereferences a raw pointer + /// + pub unsafe extern "C" fn $function_name(hnsw_api : *const $api_name, nb_vec : usize, vec_len :i64, + data : *mut *const $type_val, knbn : usize, ef_search : usize) -> *const Vec_api { + // + // must build a Vec to build request + trace!("recieving // search request for type: {:?} with {:?} vectors", stringify!($type_val), nb_vec); + let neighbours : Vec >; + let mut data_v = Vec::>::with_capacity(nb_vec); + unsafe { + let slice = std::slice::from_raw_parts(data, nb_vec); + let ptr_list : Vec<*const $type_val> = Vec::from(slice); + for i in 0..nb_vec { + let slice_i = std::slice::from_raw_parts(ptr_list[i], vec_len as usize); + let v = Vec::from(slice_i); + data_v.push(v); + } + // debug!(" reconstructed input vectors"); + neighbours = (*hnsw_api).opaque.parallel_search_neighbours(&data_v, knbn, ef_search); + } + // construct a vector of Neighbourhood_api + // reverse work, construct 2 arrays, one vector of Neighbours, and one vectors of number of returned neigbours by input a vector. + let mut neighbour_lists = Vec::::with_capacity(nb_vec); + for v in neighbours { + let neighbours_api : Vec = v.iter().map(|n| Neighbour_api::from(n)).collect(); + let nbgh = neighbours_api.len(); + let neighbours_api_ptr = neighbours_api.as_ptr(); + std::mem::forget(neighbours_api); + let v_answer = Neighbourhood_api { + nbgh : nbgh as i64, + neighbours: neighbours_api_ptr, + }; + neighbour_lists.push(v_answer); + } + trace!(" reconstructed output pointers to vectors"); + let neighbour_lists_ptr = neighbour_lists.as_ptr(); + std::mem::forget(neighbour_lists); + let answer = Vec_api:: { + len : nb_vec as i64, + ptr : neighbour_lists_ptr, + }; + Box::into_raw(Box::new(answer)) + } // end of parallel_search_neighbours_f32 for HnswApif32 + ) +); + +#[allow(unused_macros)] +macro_rules! generate_file_dump( + ($function_name:ident, $api_name:ty, $type_val:ty) => ( + /// dump the graph to a file + /// # Safety + /// The function is unsafe because it dereferences a raw pointer + /// + #[unsafe(no_mangle)] + pub unsafe extern "C" fn $function_name(hnsw_api : *const $api_name, namelen : usize, filename :*const u8) -> i64 { + log::info!("receiving request for file dump"); + let slice = unsafe { std::slice::from_raw_parts(filename, namelen) } ; + let fstring = String::from_utf8_lossy(slice).into_owned(); + let res = unsafe { (*hnsw_api).opaque.file_dump(&PathBuf::from("."), &fstring) } ; + if res.is_ok() { + return 1; + } + else { return -1; } + } // end of function_name + ) +); + +//======= Reload stuff + +#[allow(unused_macros)] +macro_rules! generate_loadhnsw( + ($function_name:ident, $api_name:ty, $type_val:ty, $type_dist : ty) => ( + /// function to reload from a previous dump (knowing data type and distance used). + /// This function takes as argument a pointer to Hnswio_api that drives the reloading. + /// The pointer is provided by the function [get_hnswio()](get_hnswio). + /// # Safety + /// The function is unsafe because it dereferences a raw pointer + /// + #[unsafe(no_mangle)] + pub unsafe extern "C" fn $function_name(hnswio_c : *mut HnswIo) -> *const $api_name { + // + unsafe { + let hnsw_loaded_res = (*hnswio_c).load_hnsw::<$type_val, $type_dist>(); + + if let Ok(hnsw_loaded) = hnsw_loaded_res { + let api = <$api_name>::new(Box::new(hnsw_loaded)); + return Box::into_raw(Box::new(api)); + } + else { + warn!("an error occured, could not reload data from {:?}", (*hnswio_c).get_basename()); + return ptr::null(); + } + } + } // end of load_hnswdump_ + ) +); + +// here we must generate as many function as there are couples (type, distance) to be accessed from our needs in Julia + +// f32 +generate_loadhnsw!( + load_hnswdump_f32_DistL1, + HnswApif32, + f32, + anndists::dist::distances::DistL1 +); +generate_loadhnsw!( + load_hnswdump_f32_DistL2, + HnswApif32, + f32, + anndists::dist::distances::DistL2 +); +generate_loadhnsw!( + load_hnswdump_f32_DistCosine, + HnswApif32, + f32, + anndists::dist::distances::DistCosine +); +generate_loadhnsw!( + load_hnswdump_f32_DistDot, + HnswApif32, + f32, + anndists::dist::distances::DistDot +); +generate_loadhnsw!( + load_hnswdump_f32_DistJensenShannon, + HnswApif32, + f32, + anndists::dist::distances::DistJensenShannon +); +generate_loadhnsw!( + load_hnswdump_f32_DistJeffreys, + HnswApif32, + f32, + anndists::dist::distances::DistJeffreys +); + +// i32 +generate_loadhnsw!( + load_hnswdump_i32_DistL1, + HnswApii32, + i32, + anndists::dist::distances::DistL1 +); +generate_loadhnsw!( + load_hnswdump_i32_DistL2, + HnswApii32, + i32, + anndists::dist::distances::DistL2 +); +generate_loadhnsw!( + load_hnswdump_i32_DistHamming, + HnswApii32, + i32, + anndists::dist::distances::DistHamming +); + +// u32 +generate_loadhnsw!( + load_hnswdump_u32_DistL1, + HnswApiu32, + u32, + anndists::dist::distances::DistL1 +); +generate_loadhnsw!( + load_hnswdump_u32_DistL2, + HnswApiu32, + u32, + anndists::dist::distances::DistL2 +); +generate_loadhnsw!( + load_hnswdump_u32_DistHamming, + HnswApiu32, + u32, + anndists::dist::distances::DistHamming +); +generate_loadhnsw!( + load_hnswdump_u32_DistJaccard, + HnswApiu32, + u32, + anndists::dist::distances::DistJaccard +); + +// u16 +generate_loadhnsw!( + load_hnswdump_u16_DistL1, + HnswApiu16, + u16, + anndists::dist::distances::DistL1 +); +generate_loadhnsw!( + load_hnswdump_u16_DistL2, + HnswApiu16, + u16, + anndists::dist::distances::DistL2 +); +generate_loadhnsw!( + load_hnswdump_u16_DistHamming, + HnswApiu16, + u16, + anndists::dist::distances::DistHamming +); +generate_loadhnsw!( + load_hnswdump_u16_DistLevenshtein, + HnswApiu16, + u16, + anndists::dist::distances::DistLevenshtein +); + +// u8 +generate_loadhnsw!( + load_hnswdump_u8_DistL1, + HnswApiu8, + u8, + anndists::dist::distances::DistL1 +); +generate_loadhnsw!( + load_hnswdump_u8_DistL2, + HnswApiu8, + u8, + anndists::dist::distances::DistL2 +); +generate_loadhnsw!( + load_hnswdump_u8_DistHamming, + HnswApiu8, + u8, + anndists::dist::distances::DistHamming +); +generate_loadhnsw!( + load_hnswdump_u8_DistJaccard, + HnswApiu8, + u8, + anndists::dist::distances::DistJaccard +); + +// Reload only graph +generate_loadhnsw!( + load_hnswdump_NoData_DistNoDist, + HnswApiNodata, + NoData, + anndists::dist::NoDist +); + +//=============== implementation for i32 +/// # Safety +/// The function is unsafe because it dereferences a raw pointer +/// +#[unsafe(no_mangle)] +pub unsafe extern "C" fn init_hnsw_f32( + max_nb_conn: usize, + ef_const: usize, + namelen: usize, + cdistname: *const u8, +) -> *const HnswApif32 { + info!("entering init_hnsw_f32"); + let slice = unsafe { std::slice::from_raw_parts(cdistname, namelen) }; + let dname = String::from_utf8_lossy(slice).into_owned(); + // map distname to sthg. This whole block will go to a macro + match dname.as_str() { + "DistL1" => { + info!(" received DistL1"); + let h = Hnsw::::new(max_nb_conn, 10000, 16, ef_const, DistL1 {}); + let api = HnswApif32 { + opaque: Box::new(h), + }; + Box::into_raw(Box::new(api)) + } + "DistL2" => { + let h = Hnsw::::new(max_nb_conn, 10000, 16, ef_const, DistL2 {}); + let api = HnswApif32 { + opaque: Box::new(h), + }; + Box::into_raw(Box::new(api)) + } + "DistDot" => { + let h = Hnsw::::new(max_nb_conn, 10000, 16, ef_const, DistDot {}); + let api = HnswApif32 { + opaque: Box::new(h), + }; + Box::into_raw(Box::new(api)) + } + "DistHellinger" => { + let h = + Hnsw::::new(max_nb_conn, 10000, 16, ef_const, DistHellinger {}); + let api = HnswApif32 { + opaque: Box::new(h), + }; + Box::into_raw(Box::new(api)) + } + "DistJeffreys" => { + let h = + Hnsw::::new(max_nb_conn, 10000, 16, ef_const, DistJeffreys {}); + let api = HnswApif32 { + opaque: Box::new(h), + }; + Box::into_raw(Box::new(api)) + } + "DistJensenShannon" => { + let h = Hnsw::::new( + max_nb_conn, + 10000, + 16, + ef_const, + DistJensenShannon {}, + ); + let api = HnswApif32 { + opaque: Box::new(h), + }; + Box::into_raw(Box::new(api)) + } + _ => { + warn!("init_hnsw_f32 received unknow distance {:?} ", dname); + ptr::null::() + } + } // znd match +} // end of init_hnsw_f32 + +/// same as max_layer with different arguments, we pass max_elements and max_layer +/// # Safety +/// This function is unsafe because it dereferences raw pointers. +/// +#[unsafe(no_mangle)] +pub unsafe extern "C" fn new_hnsw_f32( + max_nb_conn: usize, + ef_const: usize, + namelen: usize, + cdistname: *const u8, + max_elements: usize, + max_layer: usize, +) -> *const HnswApif32 { + debug!("entering new_hnsw_f32"); + let slice = unsafe { std::slice::from_raw_parts(cdistname, namelen) }; + let dname = String::from_utf8_lossy(slice); + // map distname to sthg. This whole block will go to a macro + match dname.as_ref() { + "DistL1" => { + info!(" received DistL1"); + let h = + Hnsw::::new(max_nb_conn, max_elements, max_layer, ef_const, DistL1 {}); + let api = HnswApif32 { + opaque: Box::new(h), + }; + Box::into_raw(Box::new(api)) + } + "DistL2" => { + let h = + Hnsw::::new(max_nb_conn, max_elements, max_layer, ef_const, DistL2 {}); + let api = HnswApif32 { + opaque: Box::new(h), + }; + Box::into_raw(Box::new(api)) + } + "DistDot" => { + let h = Hnsw::::new( + max_nb_conn, + max_elements, + max_layer, + ef_const, + DistDot {}, + ); + let api = HnswApif32 { + opaque: Box::new(h), + }; + Box::into_raw(Box::new(api)) + } + "DistHellinger" => { + let h = Hnsw::::new( + max_nb_conn, + max_elements, + max_layer, + ef_const, + DistHellinger {}, + ); + let api = HnswApif32 { + opaque: Box::new(h), + }; + Box::into_raw(Box::new(api)) + } + "DistJeffreys" => { + let h = Hnsw::::new( + max_nb_conn, + max_elements, + max_layer, + ef_const, + DistJeffreys {}, + ); + let api = HnswApif32 { + opaque: Box::new(h), + }; + Box::into_raw(Box::new(api)) + } + "DistJensenShannon" => { + let h = Hnsw::::new( + max_nb_conn, + max_elements, + max_layer, + ef_const, + DistJensenShannon {}, + ); + let api = HnswApif32 { + opaque: Box::new(h), + }; + Box::into_raw(Box::new(api)) + } + _ => { + warn!("init_hnsw_f32 received unknow distance {:?} ", dname); + ptr::null::() + } + } // znd match + // +} // end of new_hnsw_f32 + +/// # Safety +/// This function is unsafe because it dereferences raw pointers. +/// +#[unsafe(no_mangle)] +pub unsafe extern "C" fn drop_hnsw_f32(p: *const HnswApif32) { + unsafe { + let _raw = Box::from_raw(p as *mut HnswApif32); + } +} + +/// # Safety +/// This function is unsafe because it dereferences raw pointers. +/// +#[unsafe(no_mangle)] +pub unsafe extern "C" fn drop_hnsw_u16(p: *const HnswApiu16) { + unsafe { + let _raw = Box::from_raw(p as *mut HnswApiu16); + } +} + +#[unsafe(no_mangle)] +pub extern "C" fn init_hnsw_ptrdist_f32( + max_nb_conn: usize, + ef_const: usize, + c_func: extern "C" fn(*const f32, *const f32, c_ulonglong) -> f32, +) -> *const HnswApif32 { + info!("init_ hnsw_ptrdist: ptr {:?}", c_func); + let c_dist = DistCFFI::::new(c_func); + let h = Hnsw::>::new(max_nb_conn, 10000, 16, ef_const, c_dist); + let api = HnswApif32 { + opaque: Box::new(h), + }; + Box::into_raw(Box::new(api)) +} + +/// # Safety +/// This function is unsafe because it dereferences raw pointers. +/// +#[unsafe(no_mangle)] +pub unsafe extern "C" fn insert_f32( + hnsw_api: *mut HnswApif32, + len: usize, + data: *const f32, + id: usize, +) { + trace!("entering insert_f32, vec len is {:?}, id : {:?} ", len, id); + // construct vector: Rust clones and take ownership. + let data_v: Vec; + unsafe { + let slice = std::slice::from_raw_parts(data, len); + data_v = Vec::from(slice); + // debug!("calling insert data"); + (*hnsw_api).opaque.insert_data(&data_v, id); + } + // trace!("exiting insert_f32"); +} // end of insert_f32 + +/// # Safety +/// This function is unsafe because it dereferences raw pointers. +/// +#[unsafe(no_mangle)] +pub unsafe extern "C" fn parallel_insert_f32( + hnsw_api: *mut HnswApif32, + nb_vec: usize, + vec_len: usize, + datas: *mut *const f32, + ids: *const usize, +) { + // + // debug!("entering parallel_insert_f32 , vec len is {:?}, nb_vec : {:?}", vec_len, nb_vec); + let data_ids: Vec; + let data_ptrs: Vec<*const f32>; + unsafe { + let slice = std::slice::from_raw_parts(ids, nb_vec); + data_ids = Vec::from(slice); + } + unsafe { + let slice = std::slice::from_raw_parts(datas, nb_vec); + data_ptrs = Vec::from(slice); + } + // debug!("got data ptrs"); + let mut data_v = Vec::>::with_capacity(nb_vec); + #[allow(clippy::needless_range_loop)] + for i in 0..nb_vec { + unsafe { + let slice = std::slice::from_raw_parts(data_ptrs[i], vec_len); + let v = Vec::from(slice); + data_v.push(v); + } + } + // debug!("sending request"); + let mut request: Vec<(&Vec, usize)> = Vec::with_capacity(nb_vec); + for i in 0..nb_vec { + request.push((&data_v[i], data_ids[i])); + } + // + unsafe { + (*hnsw_api).opaque.parallel_insert_data(&request); + }; + trace!("exiting parallel_insert"); +} // end of parallel_insert_f32 + +/// # Safety +/// This function is unsafe because it dereferences raw pointers. +/// +#[unsafe(no_mangle)] +pub unsafe extern "C" fn search_neighbours_f32( + hnsw_api: *const HnswApif32, + len: usize, + data: *const f32, + knbn: usize, + ef_search: usize, +) -> *const Neighbourhood_api { + // + trace!( + "entering search_neighbours , vec len is {:?}, id : {:?} ef_search {:?}", + len, knbn, ef_search + ); + let data_v: Vec; + let neighbours: Vec; + unsafe { + let slice = std::slice::from_raw_parts(data, len); + data_v = Vec::from(slice); + trace!("calling search neighbours {:?}", data_v); + neighbours = (*hnsw_api) + .opaque + .search_neighbours(&data_v, knbn, ef_search); + } + let neighbours_api: Vec = neighbours.iter().map(Neighbour_api::from).collect(); + trace!(" got nb neighbours {:?}", neighbours_api.len()); + // for i in 0..neighbours_api.len() { + // println!(" id {:?} dist : {:?} ", neighbours_api[i].id, neighbours_api[i].d); + // } + let nbgh_i = neighbours.len() as i64; + let neighbours_ptr = neighbours_api.as_ptr(); + std::mem::forget(neighbours_api); + let answer = Neighbourhood_api { + nbgh: nbgh_i, + neighbours: neighbours_ptr, + }; + trace!( + "search_neighbours returning nb neighbours {:?} id ptr {:?} ", + nbgh_i, neighbours_ptr + ); + Box::into_raw(Box::new(answer)) +} +// end of search_neighbours for HnswApif32 + +generate_parallel_search_neighbours!(parallel_search_neighbours_f32, HnswApif32, f32); +generate_file_dump!(file_dump_f32, HnswApif32, f32); + +//=============== implementation for i32 + +/// # Safety +/// This function is unsafe because it dereferences raw pointers. +/// +#[unsafe(no_mangle)] +pub unsafe extern "C" fn init_hnsw_i32( + max_nb_conn: usize, + ef_const: usize, + namelen: usize, + cdistname: *const u8, +) -> *const HnswApii32 { + info!("entering init_hnsw_i32"); + let slice = unsafe { std::slice::from_raw_parts(cdistname, namelen) }; + let dname = String::from_utf8_lossy(slice); + // map distname to sthing. This whole block will go to a macro + if dname == "DistL1" { + info!(" received DistL1"); + let h = Hnsw::::new(max_nb_conn, 10000, 16, ef_const, DistL1 {}); + let api = HnswApii32 { + opaque: Box::new(h), + }; + return Box::into_raw(Box::new(api)); + } else if dname == "DistL2" { + let h = Hnsw::::new(max_nb_conn, 10000, 16, ef_const, DistL2 {}); + let api = HnswApii32 { + opaque: Box::new(h), + }; + return Box::into_raw(Box::new(api)); + } else if dname == "DistHamming" { + let h = Hnsw::::new(max_nb_conn, 10000, 16, ef_const, DistHamming {}); + let api = HnswApii32 { + opaque: Box::new(h), + }; + return Box::into_raw(Box::new(api)); + } + ptr::null::() +} // end of init_hnsw_i32 + +#[unsafe(no_mangle)] +pub extern "C" fn init_hnsw_ptrdist_i32( + max_nb_conn: usize, + ef_const: usize, + c_func: extern "C" fn(*const i32, *const i32, c_ulonglong) -> f32, +) -> *const HnswApii32 { + debug!("init_ hnsw_ptrdist: ptr {:?}", c_func); + let c_dist = DistCFFI::::new(c_func); + let h = Hnsw::>::new(max_nb_conn, 10000, 16, ef_const, c_dist); + let api = HnswApii32 { + opaque: Box::new(h), + }; + Box::into_raw(Box::new(api)) +} + +//==generation of function for i32 + +super::declare_myapi_type!(HnswApii32, i32); + +generate_insert!(insert_i32, HnswApii32, i32); +generate_parallel_insert!(parallel_insert_i32, HnswApii32, i32); +generate_search_neighbours!(search_neighbours_i32, HnswApii32, i32); +generate_parallel_search_neighbours!(parallel_search_neighbours_i32, HnswApii32, i32); +generate_file_dump!(file_dump_i32, HnswApii32, i32); + +//========== generation for u32 + +/// # Safety +/// This function is unsafe because it dereferences raw pointers. +/// +#[unsafe(no_mangle)] +pub unsafe extern "C" fn init_hnsw_u32( + max_nb_conn: usize, + ef_const: usize, + namelen: usize, + cdistname: *const u8, +) -> *const HnswApiu32 { + debug!("Entering init_hnsw_u32"); + let slice = unsafe { std::slice::from_raw_parts(cdistname, namelen) }; + let dname = String::from_utf8_lossy(slice); + // map distname to sthg. This whole block will go to a macro + if dname == "DistL1" { + debug!("Received DistL1"); + let h = Hnsw::::new(max_nb_conn, 10000, 16, ef_const, DistL1 {}); + let api = HnswApiu32 { + opaque: Box::new(h), + }; + return Box::into_raw(Box::new(api)); + } else if dname == "DistL2" { + let h = Hnsw::::new(max_nb_conn, 10000, 16, ef_const, DistL2 {}); + let api = HnswApiu32 { + opaque: Box::new(h), + }; + return Box::into_raw(Box::new(api)); + } else if dname == "DistJaccard" { + let h = Hnsw::::new(max_nb_conn, 10000, 16, ef_const, DistJaccard {}); + let api = HnswApiu32 { + opaque: Box::new(h), + }; + return Box::into_raw(Box::new(api)); + } else if dname == "DistHamming" { + let h = Hnsw::::new(max_nb_conn, 10000, 16, ef_const, DistHamming {}); + let api = HnswApiu32 { + opaque: Box::new(h), + }; + return Box::into_raw(Box::new(api)); + } + // + ptr::null::() +} // end of init_hnsw_i32 + +#[unsafe(no_mangle)] +pub extern "C" fn init_hnsw_ptrdist_u32( + max_nb_conn: usize, + ef_const: usize, + c_func: extern "C" fn(*const u32, *const u32, c_ulonglong) -> f32, +) -> *const HnswApiu32 { + info!("init_ hnsw_ptrdist: ptr {:?}", c_func); + let c_dist = DistCFFI::::new(c_func); + let h = Hnsw::>::new(max_nb_conn, 10000, 16, ef_const, c_dist); + let api = HnswApiu32 { + opaque: Box::new(h), + }; + Box::into_raw(Box::new(api)) +} + +super::declare_myapi_type!(HnswApiu32, u32); + +generate_insert!(insert_u32, HnswApiu32, u32); +generate_parallel_insert!(parallel_insert_u32, HnswApiu32, u32); +generate_search_neighbours!(search_neighbours_u32, HnswApiu32, u32); +generate_parallel_search_neighbours!(parallel_search_neighbours_u32, HnswApiu32, u32); +generate_file_dump!(file_dump_u32, HnswApiu32, u32); + +//============== generation of function for u16 ===================== + +super::declare_myapi_type!(HnswApiu16, u16); + +/// # Safety +/// This function is unsafe because it dereferences raw pointers. +/// +#[unsafe(no_mangle)] +pub unsafe extern "C" fn init_hnsw_u16( + max_nb_conn: usize, + ef_const: usize, + namelen: usize, + cdistname: *const u8, +) -> *const HnswApiu16 { + info!("entering init_hnsw_u16"); + let slice = unsafe { std::slice::from_raw_parts(cdistname, namelen) }; + let dname = String::from_utf8_lossy(slice); + // map distname to sthg. This whole block will go to a macro + if dname == "DistL1" { + info!(" received DistL1"); + let h = Hnsw::::new(max_nb_conn, 10000, 16, ef_const, DistL1 {}); + let api = HnswApiu16 { + opaque: Box::new(h), + }; + return Box::into_raw(Box::new(api)); + } else if dname == "DistL2" { + let h = Hnsw::::new(max_nb_conn, 10000, 16, ef_const, DistL2 {}); + let api = HnswApiu16 { + opaque: Box::new(h), + }; + return Box::into_raw(Box::new(api)); + } else if dname == "DistHamming" { + let h = Hnsw::::new(max_nb_conn, 10000, 16, ef_const, DistHamming {}); + let api = HnswApiu16 { + opaque: Box::new(h), + }; + return Box::into_raw(Box::new(api)); + } else if dname == "DistJaccard" { + let h = Hnsw::::new(max_nb_conn, 10000, 16, ef_const, DistJaccard {}); + let api = HnswApiu16 { + opaque: Box::new(h), + }; + return Box::into_raw(Box::new(api)); + } else if dname == "DistLevenshtein" { + let h = + Hnsw::::new(max_nb_conn, 10000, 16, ef_const, DistLevenshtein {}); + let api = HnswApiu16 { + opaque: Box::new(h), + }; + return Box::into_raw(Box::new(api)); + } + ptr::null::() +} // end of init_hnsw_u16 + +/// # Safety +/// This function is unsafe because it dereferences raw pointers. +/// +#[unsafe(no_mangle)] +pub unsafe extern "C" fn new_hnsw_u16( + max_nb_conn: usize, + ef_const: usize, + namelen: usize, + cdistname: *const u8, + max_elements: usize, + max_layer: usize, +) -> *const HnswApiu16 { + info!("entering init_hnsw_u16"); + let slice = unsafe { std::slice::from_raw_parts(cdistname, namelen) }; + let dname = String::from_utf8_lossy(slice); + // map distname to sthg. This whole block will go to a macro + if dname == "DistL1" { + info!(" received DistL1"); + let h = Hnsw::::new(max_nb_conn, max_elements, max_layer, ef_const, DistL1 {}); + let api = HnswApiu16 { + opaque: Box::new(h), + }; + return Box::into_raw(Box::new(api)); + } else if dname == "DistL2" { + let h = Hnsw::::new(max_nb_conn, max_elements, max_layer, ef_const, DistL2 {}); + let api = HnswApiu16 { + opaque: Box::new(h), + }; + return Box::into_raw(Box::new(api)); + } else if dname == "DistHamming" { + let h = Hnsw::::new( + max_nb_conn, + max_elements, + max_layer, + ef_const, + DistHamming {}, + ); + let api = HnswApiu16 { + opaque: Box::new(h), + }; + return Box::into_raw(Box::new(api)); + } else if dname == "DistJaccard" { + let h = Hnsw::::new( + max_nb_conn, + max_elements, + max_layer, + ef_const, + DistJaccard {}, + ); + let api = HnswApiu16 { + opaque: Box::new(h), + }; + return Box::into_raw(Box::new(api)); + } else if dname == "DistLevenshtein" { + let h = Hnsw::::new( + max_nb_conn, + max_elements, + max_layer, + ef_const, + DistLevenshtein {}, + ); + let api = HnswApiu16 { + opaque: Box::new(h), + }; + return Box::into_raw(Box::new(api)); + } + ptr::null::() +} // end of init_hnsw_u16 + +#[unsafe(no_mangle)] +pub extern "C" fn init_hnsw_ptrdist_u16( + max_nb_conn: usize, + ef_const: usize, + c_func: extern "C" fn(*const u16, *const u16, c_ulonglong) -> f32, +) -> *const HnswApiu16 { + info!("init_ hnsw_ptrdist: ptr {:?}", c_func); + let c_dist = DistCFFI::::new(c_func); + let h = Hnsw::>::new(max_nb_conn, 10000, 16, ef_const, c_dist); + let api = HnswApiu16 { + opaque: Box::new(h), + }; + Box::into_raw(Box::new(api)) +} + +generate_insert!(insert_u16, HnswApiu16, u16); +generate_parallel_insert!(parallel_insert_u16, HnswApiu16, u16); +generate_search_neighbours!(search_neighbours_u16, HnswApiu16, u16); +generate_parallel_search_neighbours!(parallel_search_neighbours_u16, HnswApiu16, u16); +generate_file_dump!(file_dump_u16, HnswApiu16, u16); + +//============== generation of function for u8 ===================== + +super::declare_myapi_type!(HnswApiu8, u8); + +/// # Safety +/// This function is unsafe because it dereferences raw pointers. +/// +#[unsafe(no_mangle)] +pub unsafe extern "C" fn init_hnsw_u8( + max_nb_conn: usize, + ef_const: usize, + namelen: usize, + cdistname: *const u8, +) -> *const HnswApiu8 { + debug!("entering init_hnsw_u8"); + let slice = unsafe { std::slice::from_raw_parts(cdistname, namelen) }; + let dname = String::from_utf8_lossy(slice); + // map distname to sthg. This whole block will go to a macro + if dname == "DistL1" { + info!(" received DistL1"); + let h = Hnsw::::new(max_nb_conn, 10000, 16, ef_const, DistL1 {}); + let api = HnswApiu8 { + opaque: Box::new(h), + }; + return Box::into_raw(Box::new(api)); + } else if dname == "DistL2" { + let h = Hnsw::::new(max_nb_conn, 10000, 16, ef_const, DistL2 {}); + let api = HnswApiu8 { + opaque: Box::new(h), + }; + return Box::into_raw(Box::new(api)); + } else if dname == "DistHamming" { + let h = Hnsw::::new(max_nb_conn, 10000, 16, ef_const, DistHamming {}); + let api = HnswApiu8 { + opaque: Box::new(h), + }; + return Box::into_raw(Box::new(api)); + } else if dname == "DistJaccard" { + let h = Hnsw::::new(max_nb_conn, 10000, 16, ef_const, DistJaccard {}); + let api = HnswApiu8 { + opaque: Box::new(h), + }; + return Box::into_raw(Box::new(api)); + } + ptr::null::() +} // end of init_hnsw_u16 + +#[unsafe(no_mangle)] +pub extern "C" fn init_hnsw_ptrdist_u8( + max_nb_conn: usize, + ef_const: usize, + c_func: extern "C" fn(*const u8, *const u8, c_ulonglong) -> f32, +) -> *const HnswApiu8 { + info!("init_ hnsw_ptrdist: ptr {:?}", c_func); + let c_dist = DistCFFI::::new(c_func); + let h = Hnsw::>::new(max_nb_conn, 10000, 16, ef_const, c_dist); + let api = HnswApiu8 { + opaque: Box::new(h), + }; + Box::into_raw(Box::new(api)) +} + +generate_insert!(insert_u8, HnswApiu8, u8); +generate_parallel_insert!(parallel_insert_u8, HnswApiu8, u8); +generate_search_neighbours!(search_neighbours_u8, HnswApiu8, u8); +generate_parallel_search_neighbours!(parallel_search_neighbours_u8, HnswApiu8, u8); +generate_file_dump!(file_dump_u8, HnswApiu8, u8); + +//=========================== dump restore functions + +/// This structure provides a light description of the graph to be passed to C compatible languages. +#[repr(C)] +pub struct DescriptionFFI { + /// value is 1 for Full 0 for Light + pub dumpmode: u8, + /// max number of connections in layers != 0 + pub max_nb_connection: u8, + /// number of observed layers + pub nb_layer: u8, + /// search parameter + pub ef: usize, + /// total number of points + pub nb_point: usize, + /// dimension of data vector + pub data_dimension: usize, + /// length and pointer on dist name + pub distname_len: usize, + pub distname: *const u8, + /// T typename + pub t_name_len: usize, + pub t_name: *const u8, +} + +impl Default for DescriptionFFI { + fn default() -> Self { + Self::new() + } +} + +impl DescriptionFFI { + pub fn new() -> Self { + DescriptionFFI { + dumpmode: 0, + max_nb_connection: 0, + nb_layer: 0, + ef: 0, + nb_point: 0, + data_dimension: 0, + distname_len: 0, + distname: ptr::null(), + t_name_len: 0, + t_name: ptr::null(), + } + } // end of new +} + +/// returns a const pointer to a DescriptionFFI from a dump file, given filename length and pointer (*const u8) +/// # Safety +/// This function is unsafe because it dereferences raw pointers. +/// +#[unsafe(no_mangle)] +pub unsafe extern "C" fn load_hnsw_description( + flen: usize, + name: *const u8, +) -> *const DescriptionFFI { + // opens file + let slice = unsafe { std::slice::from_raw_parts(name, flen) }; + let filename = String::from_utf8_lossy(slice).into_owned(); + let fpath = PathBuf::from(filename); + let fileres = OpenOptions::new().read(true).open(&fpath); + // + let mut ffi_description = DescriptionFFI::new(); + match fileres { + Ok(file) => { + // + let mut bufr = BufReader::with_capacity(10000000, file); + let res = load_description(&mut bufr); + if let Ok(description) = res { + let distname = String::clone(&description.distname); + let distname_ptr = distname.as_ptr(); + let distname_len = distname.len(); + std::mem::forget(distname); + + let t_name = String::clone(&description.t_name); + let t_name_ptr = t_name.as_ptr(); + let t_name_len = t_name.len(); + std::mem::forget(t_name); + + ffi_description.dumpmode = 1; // CAVEAT + ffi_description.max_nb_connection = description.max_nb_connection; + ffi_description.nb_layer = description.nb_layer; + ffi_description.ef = description.ef; + ffi_description.data_dimension = description.dimension; + ffi_description.distname_len = distname_len; + ffi_description.distname = distname_ptr; + ffi_description.t_name_len = t_name_len; + ffi_description.t_name = t_name_ptr; + Box::into_raw(Box::new(ffi_description)) + } else { + error!( + "could not get descrption of hnsw from file {:?}", + fpath.as_os_str() + ); + println!( + "could not get descrption of hnsw from file {:?} ", + fpath.as_os_str() + ); + ptr::null() + } + } + Err(_e) => { + error!( + "no such file, load_hnsw_description: could not open file {:?}", + fpath.as_os_str() + ); + println!( + "no such file, load_hnsw_description: could not open file {:?}", + fpath.as_os_str() + ); + ptr::null() + } + } +} // end of load_hnsw_description + +//============ log initialization ============// + +/// to initialize rust logging from Julia +#[unsafe(no_mangle)] +pub extern "C" fn init_rust_log() { + let _res = env_logger::Builder::from_default_env().try_init(); +} diff --git a/patches/hnsw_rs/src/prelude.rs b/patches/hnsw_rs/src/prelude.rs new file mode 100644 index 000000000..3ea93002e --- /dev/null +++ b/patches/hnsw_rs/src/prelude.rs @@ -0,0 +1,11 @@ +// gathers modules to include and re-exorts all of anndists! + +pub use crate::api::*; +pub use crate::hnsw::*; + +#[allow(unused)] +pub use crate::filter::*; + +pub use crate::hnswio::*; + +pub use anndists::dist::distances::*; diff --git a/patches/hnsw_rs/tests/deallocation_test.rs b/patches/hnsw_rs/tests/deallocation_test.rs new file mode 100644 index 000000000..8df3c8428 --- /dev/null +++ b/patches/hnsw_rs/tests/deallocation_test.rs @@ -0,0 +1,34 @@ +use env_logger::Builder; + +use anndists::dist::DistL1; +use hnsw_rs::hnsw::Hnsw; + +// A test program to see if memory from insertions gets deallocated. +// This program sets up a process that iteratively builds a new model and lets it go out of scope. +// Since the models go out of scope, the desired behavior is that memory consumption is constant while this program is running. +fn main() { + // + Builder::from_default_env().init(); + // + let mut counter: usize = 0; + loop { + let hnsw: Hnsw = Hnsw::new(15, 100_000, 20, 500_000, DistL1 {}); + let s1 = [1.0, 0.0, 0.0, 0.0]; + hnsw.insert_slice((&s1, 0)); + let s2 = [0.0, 1.0, 1.0]; + hnsw.insert_slice((&s2, 1)); + let s3 = [0.0, 0.0, 1.0]; + hnsw.insert_slice((&s3, 2)); + let s4 = [1.0, 0.0, 0.0, 1.0]; + hnsw.insert_slice((&s4, 3)); + let s5 = [1.0, 1.0, 1.0]; + hnsw.insert_slice((&s5, 4)); + let s6 = [1.0, -1.0, 1.0]; + hnsw.insert_slice((&s6, 5)); + + if counter % 1_000_000 == 0 { + println!("counter : {}", counter) + } + counter += 1; + } +} diff --git a/patches/hnsw_rs/tests/filtertest.rs b/patches/hnsw_rs/tests/filtertest.rs new file mode 100644 index 000000000..3885fc33b --- /dev/null +++ b/patches/hnsw_rs/tests/filtertest.rs @@ -0,0 +1,266 @@ +#![allow(clippy::needless_range_loop)] +#![allow(clippy::range_zip_with_len)] + +use anndists::dist::*; +use hnsw_rs::prelude::*; +use rand::{Rng, distr::Uniform}; +use std::iter; + +#[allow(unused)] +fn log_init_test() { + let _ = env_logger::builder().is_test(true).try_init(); +} + +// Shows two ways to do filtering, by a sorted vector or with a closure +// We define a hnsw-index with 500 entries +// Only ids within 300-400 should be in the result-set + +// Used to create a random string +fn generate_random_string(len: usize) -> String { + const CHARSET: &[u8] = b"abcdefghij"; + let mut rng = rand::rng(); + let one_char = || CHARSET[rng.random_range(0..CHARSET.len())] as char; + iter::repeat_with(one_char).take(len).collect() +} + +// this function uses a sorted vector as a filter +fn search_closure_filter( + word: &str, + hns: &Hnsw, + words: &[String], + filter_vector: &[usize], +) { + // transform string to u16 values + let vec: Vec = word.chars().map(|c| c as u16).collect(); + // now create a closure using this filter_vector + // here we can off course implement more advanced filter logic + let filter = |id: &usize| -> bool { filter_vector.binary_search(id).is_ok() }; + + // Now let us do the search by using the defined clojure, which in turn uses our vector + // ids not in the vector will not be indluced in the search results + println!("========== Search with closure filter"); + let ef_search = 30; + let res = hns.search_possible_filter(&vec, 10, ef_search, Some(&filter)); + for r in res { + println!( + "Word: {:?} Id: {:?} Distance: {:?}", + words[r.d_id], r.d_id, r.distance + ); + } +} + +#[test] +fn filter_levenstein() { + let nb_elem = 500000; // number of possible words in the dictionary + let max_nb_connection = 15; + let nb_layer = 16.min((nb_elem as f32).ln().trunc() as usize); + let ef_c = 200; + let hns = Hnsw::::new( + max_nb_connection, + nb_elem, + nb_layer, + ef_c, + DistLevenshtein {}, + ); + let mut words = vec![]; + for _n in 1..1000 { + let tw = generate_random_string(8); + words.push(tw); + } + + for (i, w) in words.iter().enumerate() { + let vec: Vec = w.chars().map(|c| c as u16).collect(); + hns.insert((&vec, i)); + if i % 1000 == 0 { + println!("Inserting: {:?}", i); + } + } + // Create a sorted vector of ids + // the ids in the vector will be used as a filter + let filtered_hns = Hnsw::::new( + max_nb_connection, + nb_elem, + nb_layer, + ef_c, + DistLevenshtein {}, + ); + let mut filter_vector: Vec = Vec::new(); + for i in 300..400 { + filter_vector.push(i); + let v: Vec = words[i].chars().map(|c| c as u16).collect(); + filtered_hns.insert((&v, i)); + } + // + let ef_search = 30; + let tosearch = "abcdefg"; + let knbn = 10; + let vec_tosearch: Vec = tosearch.chars().map(|c| c as u16).collect(); + // + println!("========== Search in full hns with filter"); + let vec_res = hns.search_filter(&vec_tosearch, knbn, ef_search, Some(&filter_vector)); + for r in &vec_res { + println!( + "Word: {:?} Id: {:?} Distance: {:?}", + words[r.d_id], r.d_id, r.distance + ); + } + // + println!("========== Search in restricted_hns but without filter"); + // + let vec: Vec = tosearch.chars().map(|c| c as u16).collect(); + let res: Vec = filtered_hns.search(&vec, knbn, ef_search); + for r in &res { + println!( + "Word: {:?} Id: {:?} Distance: {:?}", + words[r.d_id], r.d_id, r.distance + ); + } + // + // search with filter + // first with closure + println!("========== Search in full hns with closure filter"); + search_closure_filter(tosearch, &hns, &words, &filter_vector); + // + // now with vector filter and estimate recall + // + println!("========== Search in full hns with vector filter"); + let filter_vec_res = hns.search_filter(&vec_tosearch, knbn, ef_search, Some(&filter_vector)); + for r in &filter_vec_res { + println!( + "Word: {:?} Id: {:?} Distance: {:?}", + words[r.d_id], r.d_id, r.distance + ); + } + // how many neighbours in res are in filter_vec_res + let mut nb_found: usize = 0; + for n in &res { + let found = filter_vec_res.iter().find(|&&m| m.d_id == n.d_id); + if found.is_some() { + nb_found += 1; + assert_eq!(n.distance, found.unwrap().distance); + } + } + println!(" recall : {}", nb_found as f32 / res.len() as f32); + println!( + " last distances ratio : {} ", + res.last().unwrap().distance / filter_vec_res.last().unwrap().distance + ); +} + +// A test with random uniform data vectors and L2 distance +// We compare a search of a random vector in hnsw structure with a filter to a filtered_hnsw +// containing only the data fitting the filter +#[test] +fn filter_l2() { + let nb_elem = 5000; + let dim = 25; + // generate nb_elem colmuns vectors of dimension dim + let mut rng = rand::rng(); + let unif = Uniform::::new(0., 1.).unwrap(); + let mut data = Vec::with_capacity(nb_elem); + for _ in 0..nb_elem { + let column = (0..dim).map(|_| rng.sample(unif)).collect::>(); + data.push(column); + } + // give an id to each data + let data_with_id = data.iter().zip(0..data.len()).collect::>(); + + let ef_c = 200; + let max_nb_connection = 15; + let nb_layer = 16.min((nb_elem as f32).ln().trunc() as usize); + let hnsw = Hnsw::::new(max_nb_connection, nb_elem, nb_layer, ef_c, DistL2 {}); + hnsw.parallel_insert(&data_with_id); + + // + let ef_search = 30; + let knbn = 10; + let vec_tosearch = (0..dim).map(|_| rng.sample(unif)).collect::>(); + // + // Create a sorted vector of ids + // the ids in the vector will be used as a filter + let filtered_hns = + Hnsw::::new(max_nb_connection, nb_elem, nb_layer, ef_c, DistL2 {}); + let mut filter_vector: Vec = Vec::new(); + for i in 300..400 { + filter_vector.push(i); + filtered_hns.insert((&data[i], i)); + } + // + println!("========== Search in full hnsw with filter"); + let filter_vec_res = hnsw.search_filter(&vec_tosearch, knbn, ef_search, Some(&filter_vector)); + for r in &filter_vec_res { + println!("Id: {:?} Distance: {:?}", r.d_id, r.distance); + } + // + println!("========== Search in restricted_hns but without filter"); + let res: Vec = filtered_hns.search(&vec_tosearch, knbn, ef_search); + for r in &res { + println!("Id: {:?} Distance: {:?}", r.d_id, r.distance); + } + // how many neighbours in res are in filter_vec_res and what is the distance gap + let mut nb_found: usize = 0; + for n in &res { + let found = filter_vec_res.iter().find(|&&m| m.d_id == n.d_id); + if found.is_some() { + nb_found += 1; + assert!((1. - n.distance / found.unwrap().distance).abs() < 1.0e-5); + } + } + println!(" recall : {}", nb_found as f32 / res.len() as f32); + println!( + " last distances ratio : {} ", + res.last().unwrap().distance / filter_vec_res.last().unwrap().distance + ); +} // end of filter_l2 + +// + +use std::collections::HashMap; +#[test] +fn filter_villsnow() { + println!("\n\n in test villsnow"); + log_init_test(); + // + let grid_size = 100; + let mut hnsw = Hnsw::::new(4, grid_size * grid_size, 16, 100, DistL2::default()); + let mut points = HashMap::new(); + + { + for (id, (i, j)) in itertools::iproduct!(0..grid_size, 0..grid_size,).enumerate() { + let data = [ + (i as f64 + 0.5) / grid_size as f64, + (j as f64 + 0.5) / grid_size as f64, + ]; + hnsw.insert((&data, id)); + points.insert(id, data); + } + + hnsw.set_searching_mode(true); + } + { + println!("first case"); + // first case + let filter = |id: &usize| DistL2::default().eval(&points[id], &[1.0, 1.0]) < 1e-2; + dbg!(points.keys().filter(|x| filter(x)).count()); // -> 1 + + let hit = hnsw.search_filter(&[0.0, 0.0], 10, 4, Some(&filter)); + if !hit.is_empty() { + log::info!("got point : {:?}", points.get(&hit[0].d_id)); + log::info!("got {:?}, must be true", filter(&hit[0].d_id)); // -> sometimes false + } else { + log::info!("found no point"); + } + assert!(hit.len() <= 1); + } + { + println!("second case"); + // second case + let filter = |_id: &usize| false; + dbg!(points.keys().filter(|x| filter(x)).count()); // -> 0, obviously + + let hit = hnsw.search_filter(&[0.0, 0.0], 10, 64, Some(&filter)); + println!("villsnow , {:?}", hit.len()); + log::info!("got {:?}, must be 0", hit.len()); // -> 1 + assert_eq!(hit.len(), 0); + } +} diff --git a/patches/hnsw_rs/tests/serpar.rs b/patches/hnsw_rs/tests/serpar.rs new file mode 100644 index 000000000..917278b1e --- /dev/null +++ b/patches/hnsw_rs/tests/serpar.rs @@ -0,0 +1,328 @@ +#![allow(clippy::range_zip_with_len)] + +//! some testing utilities. +//! run with to get output statistics : cargo test --release -- --nocapture --test test_parallel. +//! serial test corresponds to random-10nn-euclidean(k=10) +//! parallel test corresponds to random data in 25 dimensions k = 10, dist Cosine + +use rand::distr::Uniform; +use rand::prelude::*; + +use skiplist::OrderedSkipList; + +use anndists::dist; +use hnsw_rs::prelude::*; +use serde::{de::DeserializeOwned, Serialize}; + +pub fn gen_random_vector_f32(nbrow: usize) -> Vec { + let mut rng = rand::rng(); + let unif = Uniform::::new(0., 1.).unwrap(); + (0..nbrow).map(|_| rng.sample(unif)).collect::>() +} + +/// return nbcolumn vectors of dimension nbrow +pub fn gen_random_matrix_f32(nbrow: usize, nbcolumn: usize) -> Vec> { + let mut rng = rand::rng(); + let unif = Uniform::::new(0., 1.).unwrap(); + let mut data = Vec::with_capacity(nbcolumn); + for _ in 0..nbcolumn { + let column = (0..nbrow).map(|_| rng.sample(unif)).collect::>(); + data.push(column); + } + data +} + +fn brute_force_neighbours( + nb_neighbours: usize, + refdata: &PointIndexation, + distance: PointDistance, + data: &[T], +) -> OrderedSkipList { + let mut neighbours = OrderedSkipList::::with_capacity(refdata.get_nb_point()); + + let mut ptiter = refdata.into_iter(); + let mut more = true; + while more { + if let Some(point) = ptiter.next() { + let dist_p = distance.eval(data, point.get_v()); + let ordered_point = PointIdWithOrder::new(point.get_point_id(), dist_p); + // log::debug!(" brute force inserting {:?}", ordered_point); + if neighbours.len() < nb_neighbours { + neighbours.insert(ordered_point); + } else { + neighbours.insert(ordered_point); + neighbours.pop_back(); + } + } else { + more = false; + } + } // end while + neighbours +} // end of brute_force_2 + +//================================================================================================ + +mod tests { + use cpu_time::ProcessTime; + use std::time::Duration; + + use super::*; + use dist::l2_normalize; + + #[test] + fn test_serial() { + // + // + let nb_elem = 1000; + let dim = 10; + let knbn = 10; + let ef = 20; + let parallel = true; + // + println!("\n\n test_serial nb_elem {:?}", nb_elem); + // + let data = gen_random_matrix_f32(dim, nb_elem); + let data_with_id = data.iter().zip(0..data.len()).collect::>(); + + let ef_c = 400; + let max_nb_connection = 32; + let nb_layer = 16.min((nb_elem as f32).ln().trunc() as usize); + let mut hns = Hnsw::::new( + max_nb_connection, + nb_elem, + nb_layer, + ef_c, + dist::DistL1 {}, + ); + hns.set_extend_candidates(true); + hns.set_keeping_pruned(true); + let mut start = ProcessTime::now(); + if parallel { + println!("parallel insertion"); + hns.parallel_insert(&data_with_id); + } else { + println!("serial insertion"); + for (i, d) in data.iter().enumerate() { + hns.insert((d, i)); + } + } + let mut cpu_time: Duration = start.elapsed(); + println!(" hnsw serial data insertion {:?}", cpu_time); + hns.dump_layer_info(); + println!(" hnsw data nb point inserted {:?}", hns.get_nb_point()); + // + + let nbtest = 300; + let mut recalls = Vec::::with_capacity(nbtest); + let mut nb_returned = Vec::::with_capacity(nb_elem); + let mut search_times = Vec::::with_capacity(nbtest); + for _itest in 0..nbtest { + // + let mut r_vec = Vec::::with_capacity(dim); + let mut rng = rand::rng(); + let unif = Uniform::::new(0., 1.).unwrap(); + for _ in 0..dim { + r_vec.push(rng.sample(unif)); + } + start = ProcessTime::now(); + let brute_neighbours = brute_force_neighbours( + knbn, + hns.get_point_indexation(), + Box::new(dist::DistL1 {}), + &r_vec, + ); + cpu_time = start.elapsed(); + if nbtest <= 100 { + println!("\n\n **************** test {:?}", _itest); + println!("\n brute force neighbours :"); + println!("======================"); + println!(" brute force computing {:?} \n ", cpu_time); + for i in 0..brute_neighbours.len() { + let p = brute_neighbours[i].point_id; + println!(" {:?} {:?} ", p, brute_neighbours[i].dist_to_ref); + } + } + // + hns.set_searching_mode(true); + start = ProcessTime::now(); + let knn_neighbours = hns.search(&r_vec, knbn, ef); + cpu_time = start.elapsed(); + search_times.push(cpu_time.as_micros() as f32); + if nbtest <= 100 { + println!("\n\n hnsw searching {:?} \n", cpu_time); + println!("\n knn neighbours"); + println!("======================"); + for n in &knn_neighbours { + println!(" {:?} {:?} {:?} ", n.d_id, n.p_id, n.distance); + } + } + // compute recall + let knn_neighbours_dist: Vec = knn_neighbours.iter().map(|p| p.distance).collect(); + let max_dist = brute_neighbours[knbn - 1].dist_to_ref; + let recall = knn_neighbours_dist + .iter() + .filter(|d| *d <= &max_dist) + .count(); + if nbtest <= 100 { + println!("recall {:?}", (recall as f32) / (knbn as f32)); + } + recalls.push(recall); + nb_returned.push(knn_neighbours.len()); + } // end on nbtest + // + // compute recall + // + + let mean_recall = (recalls.iter().sum::() as f32) / ((knbn * recalls.len()) as f32); + let mean_search_time = (search_times.iter().sum::()) / (search_times.len() as f32); + println!( + "\n mean fraction (of knbn) returned by search {:?} ", + (nb_returned.iter().sum::() as f32) / ((nb_returned.len() * knbn) as f32) + ); + println!( + "\n nb element {:?} nb search : {:?} recall rate is {:?} search time inverse {:?} ", + nb_elem, + nbtest, + mean_recall, + 1.0e+6_f32 / mean_search_time + ); + } // end test1 + + #[test] + fn test_parallel() { + // + let nb_elem = 1000; + let dim = 25; + let knbn = 10; + let ef_c = 800; + let max_nb_connection = 48; + let ef = 20; + // + // + let mut data = gen_random_matrix_f32(dim, nb_elem); + for v in &mut data { + l2_normalize(v); + } + let data_with_id = data.iter().zip(0..data.len()).collect::>(); + let nb_layer = 16.min((nb_elem as f32).ln().trunc() as usize); + let mut hns = Hnsw::::new( + max_nb_connection, + nb_elem, + nb_layer, + ef_c, + dist::DistDot {}, + ); + // ! + // hns.set_extend_candidates(true); + let mut start = ProcessTime::now(); + let now = std::time::SystemTime::now(); + // parallel insertion + hns.parallel_insert(&data_with_id); + let mut cpu_time: Duration = start.elapsed(); + println!( + "\n hnsw data parallel insertion cpu time {:?} , system time {:?}", + cpu_time, + now.elapsed() + ); + // one serial more to check + let mut v = gen_random_vector_f32(dim); + l2_normalize(&mut v); + hns.insert((&v, hns.get_nb_point() + 1)); + // + hns.dump_layer_info(); + println!(" hnsw data nb point inserted {:?}", hns.get_nb_point()); + // + println!("\n hnsw testing requests ..."); + let nbtest = 100; + let mut recalls = Vec::::with_capacity(nbtest); + let mut recalls_id = Vec::::with_capacity(nbtest); + + let mut search_times = Vec::::with_capacity(nbtest); + for _itest in 0..nbtest { + let mut r_vec = Vec::::with_capacity(dim); + let mut rng = rand::rng(); + let unif = Uniform::::new(0., 1.).unwrap(); + for _ in 0..dim { + r_vec.push(rng.sample(unif)); + } + l2_normalize(&mut r_vec); + + start = ProcessTime::now(); + let brute_neighbours = brute_force_neighbours( + knbn, + hns.get_point_indexation(), + Box::new(dist::DistDot), + &r_vec, + ); + cpu_time = start.elapsed(); + if nbtest <= 100 { + println!("\n\n test_par nb_elem {:?}", nb_elem); + println!("\n brute force neighbours :"); + println!("======================"); + println!(" brute force computing {:?} \n", cpu_time); + for i in 0..brute_neighbours.len() { + println!( + " {:?} {:?} ", + brute_neighbours[i].point_id, brute_neighbours[i].dist_to_ref + ); + } + } + // + let knbn = 10; + hns.set_searching_mode(true); + start = ProcessTime::now(); + let knn_neighbours = hns.search(&r_vec, knbn, ef); + cpu_time = start.elapsed(); + search_times.push(cpu_time.as_micros() as f32); + if nbtest <= 100 { + println!("\n knn neighbours"); + println!("======================"); + println!(" hnsw searching {:?} \n", cpu_time); + for n in &knn_neighbours { + println!(" {:?} \t {:?} \t {:?}", n.d_id, n.p_id, n.distance); + } + } + // compute recall with balls + let knn_neighbours_dist: Vec = knn_neighbours.iter().map(|p| p.distance).collect(); + let max_dist = brute_neighbours[knbn - 1].dist_to_ref; + let recall = knn_neighbours_dist + .iter() + .filter(|d| *d <= &max_dist) + .count(); + if nbtest <= 100 { + println!("recall {:?}", (recall as f32) / (knbn as f32)); + } + recalls.push(recall); + // compute recall with id + let mut recall_id = 0; + let mut knn_neighbours_id: Vec = + knn_neighbours.iter().map(|p| p.p_id).collect(); + knn_neighbours_id.sort_unstable(); + let snbn = knbn.min(brute_neighbours.len()); + for j in 0..snbn { + let to_search = brute_neighbours[j].point_id; + if knn_neighbours_id.binary_search(&to_search).is_ok() { + recall_id += 1; + } + } + recalls_id.push(recall_id); + } // end on nbtest + // + // compute recall + // + + let mean_recall = (recalls.iter().sum::() as f32) / ((knbn * recalls.len()) as f32); + let mean_search_time = (search_times.iter().sum::()) / (search_times.len() as f32); + println!( + "\n nb search {:?} recall rate is {:?} search time inverse {:?} ", + nbtest, + mean_recall, + 1.0e+6_f32 / mean_search_time + ); + let mean_recall_id = + (recalls.iter().sum::() as f32) / ((knbn * recalls.len()) as f32); + println!("mean recall rate with point ids {:?}", mean_recall_id); + // + // assert!(1==0); + } // end test_par +}