Merge pull request #334 from supermemoryai/hybrid-rag

Hybrid rag
This commit is contained in:
Dhravya Shah 2025-02-18 21:20:52 -07:00 committed by GitHub
commit 0c6db45d32
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 1317 additions and 30 deletions

View file

@ -0,0 +1,7 @@
ALTER TABLE "chunks" ALTER COLUMN "embeddings" SET DATA TYPE vector(768);--> statement-breakpoint
CREATE INDEX IF NOT EXISTS "documents_search_idx" ON "documents" USING gin ((
setweight(to_tsvector('english', coalesce("content", '')),'A') ||
setweight(to_tsvector('english', coalesce("title", '')),'B') ||
setweight(to_tsvector('english', coalesce("description", '')),'C') ||
setweight(to_tsvector('english', coalesce("url", '')),'D')
));

File diff suppressed because it is too large Load diff

View file

@ -113,6 +113,13 @@
"when": 1737920848112,
"tag": "0015_perpetual_mauler",
"breakpoints": true
},
{
"idx": 16,
"version": "7",
"when": 1739937938319,
"tag": "0016_good_deathbird",
"breakpoints": true
}
]
}

View file

@ -88,7 +88,8 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>()
apiKey: c.env.BRAINTRUST_API_KEY,
});
const googleClient = wrapAISDKModel(openai(c.env).chat("gpt-4o-mini-2024-07-18"));
const googleClient = wrapAISDKModel(
openai(c.env).chat("gpt-4o-mini-2024-07-18"));
// Get last user message and generate embedding in parallel with thread creation
let lastUserMessage = coreMessages.findLast((i) => i.role === "user");
@ -123,9 +124,15 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>()
return c.json({ error: "Failed to generate embedding" }, 500);
}
// Perform semantic search
const similarity = sql<number>`1 - (${cosineDistance(chunk.embeddings, embedding[0])})`;
// Pre-compute the vector similarity expression to avoid multiple calculations
const vectorSimilarity = sql<number>`1 - (embeddings <=> ${JSON.stringify(embedding[0])}::vector)`;
const textSearchRank = sql<number>`ts_rank_cd((
setweight(to_tsvector('english', coalesce(${documents.content}, '')),'A') ||
setweight(to_tsvector('english', coalesce(${documents.title}, '')),'B') ||
setweight(to_tsvector('english', coalesce(${documents.description}, '')),'C') ||
setweight(to_tsvector('english', coalesce(${documents.url}, '')),'D')
), plainto_tsquery('english', ${queryText}))`;
const finalResults = await db
.select({
id: documents.id,
@ -138,12 +145,25 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>()
userId: documents.userId,
description: documents.description,
ogImage: documents.ogImage,
similarity: vectorSimilarity,
textRank: textSearchRank,
})
.from(chunk)
.innerJoin(documents, eq(chunk.documentId, documents.id))
.where(and(eq(documents.userId, user.id), sql`${similarity} > 0.4`))
.orderBy(desc(similarity))
.limit(5);
.where(
and(
eq(documents.userId, user.id),
sql`${vectorSimilarity} > 0.5`
)
)
.orderBy(
desc(sql<number>`(
0.6 * ${vectorSimilarity} +
0.25 * ${textSearchRank} +
0.15 * (1.0 / (1.0 + extract(epoch from age(${documents.updatedAt})) / (90 * 24 * 60 * 60)))
)::float`)
)
.limit(15);
const cleanDocumentsForContext = finalResults.map((d) => ({
title: d.title,
@ -531,24 +551,37 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>()
);
}
// Perform semantic search using cosine similarity
const results = await database(c.env.HYPERDRIVE.connectionString)
// Pre-compute the vector similarity expression to avoid multiple calculations
const vectorSimilarity = sql<number>`1 - (embeddings <=> ${JSON.stringify(embeddings.data[0])}::vector)`;
const textSearchRank = sql<number>`ts_rank_cd((
setweight(to_tsvector('english', coalesce(${documents.content}, '')),'A') ||
setweight(to_tsvector('english', coalesce(${documents.title}, '')),'B') ||
setweight(to_tsvector('english', coalesce(${documents.description}, '')),'C') ||
setweight(to_tsvector('english', coalesce(${documents.url}, '')),'D')
), plainto_tsquery('english', ${query}))`;
const results = await db
.select({
id: documents.id,
uuid: documents.uuid,
content: documents.content,
type: documents.type,
url: documents.url,
title: documents.title,
createdAt: documents.createdAt,
chunkContent: chunk.textContent,
similarity: sql<number>`1 - (embeddings <=> ${JSON.stringify(
embeddings.data[0]
)}::vector)`,
updatedAt: documents.updatedAt,
userId: documents.userId,
description: documents.description,
ogImage: documents.ogImage,
similarity: vectorSimilarity,
textRank: textSearchRank,
})
.from(chunk)
.innerJoin(documents, eq(chunk.documentId, documents.id))
.where(
and(
eq(documents.userId, user.id),
sql`1 - (embeddings <=> ${JSON.stringify(embeddings.data[0])}::vector) >= ${threshold}`,
sql`${vectorSimilarity} > ${threshold}`,
...(spaces && spaces.length > 0
? [
exists(
@ -570,7 +603,11 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>()
)
)
.orderBy(
sql`1 - (embeddings <=> ${JSON.stringify(embeddings.data[0])}::vector) desc`
desc(sql<number>`(
0.6 * ${vectorSimilarity} +
0.25 * ${textSearchRank} +
0.15 * (1.0 / (1.0 + extract(epoch from age(${documents.updatedAt})) / (90 * 24 * 60 * 60)))
)::float`)
)
.limit(limit);

View file

@ -24,7 +24,9 @@ export class ContentWorkflow extends WorkflowEntrypoint<Env, WorkflowParams> {
async run(event: WorkflowEvent<WorkflowParams>, step: WorkflowStep) {
// Step 0: Check if user has reached memory limit
await step.do("check memory limit", async () => {
const existingMemories = await database(this.env.HYPERDRIVE.connectionString)
const existingMemories = await database(
this.env.HYPERDRIVE.connectionString
)
.select()
.from(documents)
.where(eq(documents.userId, event.payload.userId));
@ -33,7 +35,9 @@ export class ContentWorkflow extends WorkflowEntrypoint<Env, WorkflowParams> {
await database(this.env.HYPERDRIVE.connectionString)
.delete(documents)
.where(eq(documents.uuid, event.payload.uuid));
throw new NonRetryableError("You have reached the maximum limit of 2000 memories");
throw new NonRetryableError(
"You have reached the maximum limit of 2000 memories"
);
}
});
@ -142,12 +146,14 @@ export class ContentWorkflow extends WorkflowEntrypoint<Env, WorkflowParams> {
);
}
// Step 3: Generate embeddings
const { data: embeddings } = await this.env.AI.run(
"@cf/baai/bge-base-en-v1.5",
{
text: chunked,
}
);
const {data: embeddings} = await this.env.AI.run("@cf/baai/bge-base-en-v1.5", {
text: chunked,
});
// Step 4: Prepare chunk data
const chunkInsertData: ChunkInsert[] = await step.do(
"prepare chunk data",
@ -160,8 +166,6 @@ export class ContentWorkflow extends WorkflowEntrypoint<Env, WorkflowParams> {
}))
);
console.log(chunkInsertData);
// Step 5: Insert chunks
if (chunkInsertData.length > 0) {
await step.do("insert chunks", async () =>

View file

@ -13,6 +13,7 @@ import {
jsonb,
date,
} from "drizzle-orm/pg-core";
import { sql } from "drizzle-orm";
import { Metadata } from "../../apps/backend/src/types";
export const users = pgTable(
@ -173,13 +174,22 @@ export const documents = pgTable(
errorMessage: text("error_message"),
contentHash: text("content_hash"),
},
(document) => ({
documentsIdIdx: uniqueIndex("document_id_idx").on(document.id),
documentsUuidIdx: uniqueIndex("document_uuid_idx").on(document.uuid),
documentsTypdIdx: index("document_type_idx").on(document.type),
(table) => ({
documentsIdIdx: uniqueIndex("document_id_idx").on(table.id),
documentsUuidIdx: uniqueIndex("document_uuid_idx").on(table.uuid),
documentsTypdIdx: index("document_type_idx").on(table.type),
documentRawUserIdx: uniqueIndex("document_raw_user_idx").on(
document.raw,
document.userId
table.raw,
table.userId
),
searchIndex: index("documents_search_idx").using(
"gin",
sql`(
setweight(to_tsvector('english', coalesce(${table.content}, '')),'A') ||
setweight(to_tsvector('english', coalesce(${table.title}, '')),'B') ||
setweight(to_tsvector('english', coalesce(${table.description}, '')),'C') ||
setweight(to_tsvector('english', coalesce(${table.url}, '')),'D')
)`
),
})
);