hybrid rag looks good now

This commit is contained in:
Dhravya Shah 2025-02-18 21:51:26 -07:00
parent 6cfc234cc0
commit d5477b4ef3
5 changed files with 2665 additions and 117 deletions

View file

@ -0,0 +1,7 @@
DROP INDEX IF EXISTS "documents_search_idx";--> statement-breakpoint
CREATE INDEX IF NOT EXISTS "documents_search_idx" ON "documents" USING gin ((
setweight(to_tsvector('english', coalesce("content", '')),'A') ||
setweight(to_tsvector('english', coalesce("title", '')),'B') ||
setweight(to_tsvector('english', coalesce("description", '')),'C') ||
setweight(to_tsvector('english', coalesce("url", '')),'D')
));

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -120,6 +120,20 @@
"when": 1739937938319,
"tag": "0016_good_deathbird",
"breakpoints": true
},
{
"idx": 17,
"version": "7",
"when": 1739939067023,
"tag": "0017_oval_misty_knight",
"breakpoints": true
},
{
"idx": 18,
"version": "7",
"when": 1739939254444,
"tag": "0018_past_inertia",
"breakpoints": true
}
]
}

View file

@ -89,7 +89,8 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>()
});
const googleClient = wrapAISDKModel(
openai(c.env).chat("gpt-4o-mini-2024-07-18"));
openai(c.env).chat("gpt-4o-mini-2024-07-18")
);
// Get last user message and generate embedding in parallel with thread creation
let lastUserMessage = coreMessages.findLast((i) => i.role === "user");
@ -124,96 +125,137 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>()
return c.json({ error: "Failed to generate embedding" }, 500);
}
// Pre-compute the vector similarity expression to avoid multiple calculations
const vectorSimilarity = sql<number>`1 - (embeddings <=> ${JSON.stringify(embedding[0])}::vector)`;
const textSearchRank = sql<number>`ts_rank_cd((
setweight(to_tsvector('english', coalesce(${documents.content}, '')),'A') ||
setweight(to_tsvector('english', coalesce(${documents.title}, '')),'B') ||
setweight(to_tsvector('english', coalesce(${documents.description}, '')),'C') ||
setweight(to_tsvector('english', coalesce(${documents.url}, '')),'D')
), plainto_tsquery('english', ${queryText}))`;
const finalResults = await db
.select({
id: documents.id,
content: documents.content,
type: documents.type,
url: documents.url,
title: documents.title,
createdAt: documents.createdAt,
updatedAt: documents.updatedAt,
userId: documents.userId,
description: documents.description,
ogImage: documents.ogImage,
similarity: vectorSimilarity,
textRank: textSearchRank,
})
.from(chunk)
.innerJoin(documents, eq(chunk.documentId, documents.id))
.where(
and(
eq(documents.userId, user.id),
sql`${vectorSimilarity} > 0.5`
)
)
.orderBy(
desc(sql<number>`(
0.6 * ${vectorSimilarity} +
0.25 * ${textSearchRank} +
0.15 * (1.0 / (1.0 + extract(epoch from age(${documents.updatedAt})) / (90 * 24 * 60 * 60)))
)::float`)
)
.limit(15);
const cleanDocumentsForContext = finalResults.map((d) => ({
title: d.title,
description: d.description,
url: d.url,
type: d.type,
content: d.content,
}));
if (lastUserMessage) {
lastUserMessage.content =
typeof lastUserMessage.content === "string"
? lastUserMessage.content +
`<context>${JSON.stringify(cleanDocumentsForContext)}</context>`
: [
...lastUserMessage.content,
{
type: "text",
text: `<context>${JSON.stringify(cleanDocumentsForContext)}</context>`,
},
];
coreMessages[coreMessages.length - 1] = lastUserMessage;
}
try {
const data = new StreamData();
// De-duplicate chunks by URL to avoid showing duplicate content
const uniqueResults = finalResults.reduce((acc, current) => {
const existingResult = acc.find(item => item.id === current.id);
if (!existingResult) {
acc.push(current);
}
return acc;
}, [] as typeof finalResults);
data.appendMessageAnnotation(
uniqueResults.map((r) => ({
id: r.id,
content: r.content,
type: r.type,
url: r.url,
title: r.title,
description: r.description,
ogImage: r.ogImage,
userId: r.userId,
createdAt: r.createdAt.toISOString(),
updatedAt: r.updatedAt?.toISOString() || null,
}))
// Pre-compute the vector similarity expression
const vectorSimilarity = sql<number>`1 - (embeddings <=> ${JSON.stringify(embedding[0])}::vector)`;
const textSearchRank = sql<number>`ts_rank_cd(
to_tsvector('english', coalesce(${chunk.textContent}, '')),
plainto_tsquery('english', ${queryText})
)`;
// Get matching chunks with document info
const matchingChunks = await db
.select({
chunkId: chunk.id,
documentId: chunk.documentId,
textContent: chunk.textContent,
orderInDocument: chunk.orderInDocument,
metadata: chunk.metadata,
similarity: vectorSimilarity,
textRank: textSearchRank,
// Document fields
docId: documents.id,
docUuid: documents.uuid,
docContent: documents.content,
docType: documents.type,
docUrl: documents.url,
docTitle: documents.title,
docDescription: documents.description,
docOgImage: documents.ogImage,
})
.from(chunk)
.innerJoin(documents, eq(chunk.documentId, documents.id))
.where(
and(eq(documents.userId, user.id), sql`${vectorSimilarity} > 0.5`)
)
.orderBy(
desc(sql<number>`(
0.6 * ${vectorSimilarity} +
0.25 * ${textSearchRank} +
0.15 * (1.0 / (1.0 + extract(epoch from age(${documents.updatedAt})) / (90 * 24 * 60 * 60)))
)::float`)
)
.limit(15);
// Get unique document IDs from matching chunks
const uniqueDocIds = [
...new Set(matchingChunks.map((c) => c.documentId)),
];
// Fetch all chunks for these documents to get context
const contextChunks = await db
.select({
id: chunk.id,
documentId: chunk.documentId,
textContent: chunk.textContent,
orderInDocument: chunk.orderInDocument,
metadata: chunk.metadata,
})
.from(chunk)
.where(inArray(chunk.documentId, uniqueDocIds))
.orderBy(chunk.documentId, chunk.orderInDocument);
// Group chunks by document
const chunksByDocument = new Map<number, typeof contextChunks>();
for (const chunk of contextChunks) {
const docChunks = chunksByDocument.get(chunk.documentId) || [];
docChunks.push(chunk);
chunksByDocument.set(chunk.documentId, docChunks);
}
// Create context with surrounding chunks
const contextualResults = matchingChunks.map((match) => {
const docChunks = chunksByDocument.get(match.documentId) || [];
const matchIndex = docChunks.findIndex((c) => c.id === match.chunkId);
// Get surrounding chunks (1 before and 1 after)
const start = Math.max(0, matchIndex - 1);
const end = Math.min(docChunks.length, matchIndex + 2);
const relevantChunks = docChunks.slice(start, end);
return {
id: match.docId,
title: match.docTitle,
description: match.docDescription,
url: match.docUrl,
type: match.docType,
content: relevantChunks.map((c) => c.textContent).join("\n"),
similarity: Number(match.similarity.toFixed(4)),
chunks: relevantChunks.map((c) => ({
id: c.id,
content: c.textContent,
orderInDocument: c.orderInDocument,
metadata: c.metadata,
isMatch: c.id === match.chunkId,
})),
};
});
// Remove duplicates based on document ID
const uniqueResults = contextualResults.reduce(
(acc, current) => {
const existingDoc = acc.find((doc) => doc.id === current.id);
if (!existingDoc) {
acc.push(current);
} else if (current.similarity > existingDoc.similarity) {
// Replace if current match is better
const index = acc.findIndex((doc) => doc.id === current.id);
acc[index] = current;
}
return acc;
},
[] as typeof contextualResults
);
data.appendMessageAnnotation(uniqueResults);
if (lastUserMessage) {
lastUserMessage.content =
typeof lastUserMessage.content === "string"
? lastUserMessage.content +
`<context>${JSON.stringify(uniqueResults)}</context>`
: [
...lastUserMessage.content,
{
type: "text",
text: `<context>${JSON.stringify(uniqueResults)}</context>`,
},
];
coreMessages[coreMessages.length - 1] = lastUserMessage;
}
const result = await streamText({
model: googleClient,
experimental_providerMetadata: {
@ -267,7 +309,7 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>()
role: "assistant",
content:
completion.text +
`<context>[${JSON.stringify(finalResults)}]</context>`,
`<context>[${JSON.stringify(uniqueResults)}]</context>`,
},
];
@ -279,6 +321,8 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>()
}
} catch (error) {
console.error("Failed to update thread:", error);
} finally {
await data.close();
}
},
});
@ -510,32 +554,38 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>()
.from(spaceInDb)
.where(eq(spaceInDb.uuid, spaceId))
.limit(1);
if (space.length === 0) return null;
return {
id: space[0].id,
ownerId: space[0].ownerId,
uuid: space[0].uuid
uuid: space[0].uuid,
};
})
);
// Filter out any null values and check permissions
const validSpaces = spaceDetails.filter((s): s is NonNullable<typeof s> => s !== null);
const unauthorized = validSpaces.filter(s => s.ownerId !== user.id);
const validSpaces = spaceDetails.filter(
(s): s is NonNullable<typeof s> => s !== null
);
const unauthorized = validSpaces.filter((s) => s.ownerId !== user.id);
if (unauthorized.length > 0) {
return c.json(
{
error: "Space permission denied",
details: unauthorized.map(s => s.uuid).join(", "),
details: unauthorized.map((s) => s.uuid).join(", "),
},
403
);
}
// Replace UUIDs with IDs for the database query
spaces.splice(0, spaces.length, ...validSpaces.map(s => s.id.toString()));
spaces.splice(
0,
spaces.length,
...validSpaces.map((s) => s.id.toString())
);
}
try {
@ -553,28 +603,32 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>()
// Pre-compute the vector similarity expression to avoid multiple calculations
const vectorSimilarity = sql<number>`1 - (embeddings <=> ${JSON.stringify(embeddings.data[0])}::vector)`;
const textSearchRank = sql<number>`ts_rank_cd((
setweight(to_tsvector('english', coalesce(${documents.content}, '')),'A') ||
setweight(to_tsvector('english', coalesce(${documents.title}, '')),'B') ||
setweight(to_tsvector('english', coalesce(${documents.description}, '')),'C') ||
setweight(to_tsvector('english', coalesce(${documents.url}, '')),'D')
), plainto_tsquery('english', ${query}))`;
const textSearchRank = sql<number>`ts_rank_cd(
to_tsvector('english', coalesce(${chunk.textContent}, '')),
plainto_tsquery('english', ${query})
)`;
// First get the top matching chunks
const results = await db
.select({
id: documents.id,
uuid: documents.uuid,
content: documents.content,
type: documents.type,
url: documents.url,
title: documents.title,
createdAt: documents.createdAt,
updatedAt: documents.updatedAt,
userId: documents.userId,
description: documents.description,
ogImage: documents.ogImage,
chunkId: chunk.id,
documentId: chunk.documentId,
textContent: chunk.textContent,
orderInDocument: chunk.orderInDocument,
metadata: chunk.metadata,
similarity: vectorSimilarity,
textRank: textSearchRank,
// Document fields
docUuid: documents.uuid,
docContent: documents.content,
docType: documents.type,
docUrl: documents.url,
docTitle: documents.title,
docCreatedAt: documents.createdAt,
docUpdatedAt: documents.updatedAt,
docUserId: documents.userId,
docDescription: documents.description,
docOgImage: documents.ogImage,
})
.from(chunk)
.innerJoin(documents, eq(chunk.documentId, documents.id))
@ -611,12 +665,41 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>()
)
.limit(limit);
return c.json({
results: results.map((r) => ({
...r,
similarity: Number(r.similarity.toFixed(4)),
})),
});
// Group results by document and take the best matching chunk
const documentResults = new Map<number, (typeof results)[0]>();
for (const result of results) {
const existingResult = documentResults.get(result.documentId);
if (
!existingResult ||
result.similarity > existingResult.similarity
) {
documentResults.set(result.documentId, result);
}
}
// Convert back to array and format response
const finalResults = Array.from(documentResults.values()).map((r) => ({
id: r.documentId,
uuid: r.docUuid,
content: r.docContent,
type: r.docType,
url: r.docUrl,
title: r.docTitle,
createdAt: r.docCreatedAt,
updatedAt: r.docUpdatedAt,
userId: r.docUserId,
description: r.docDescription,
ogImage: r.docOgImage,
similarity: Number(r.similarity.toFixed(4)),
matchingChunk: {
id: r.chunkId,
content: r.textContent,
orderInDocument: r.orderInDocument,
metadata: r.metadata,
},
}));
return c.json({ results: finalResults });
} catch (error) {
console.error("[Search Error]", error);
return c.json(