added gemini streaming in cf-ai-backend

This commit is contained in:
Dhravya 2024-03-31 22:44:22 -07:00
parent 1ae3ff1a5f
commit c40202fd04
3 changed files with 50 additions and 27 deletions

View file

@ -7,15 +7,15 @@ import type {
import {
CloudflareVectorizeStore,
} from "@langchain/cloudflare";
import { Ai } from '@cloudflare/ai';
import { OpenAIEmbeddings } from "./OpenAIEmbedder";
import { AiTextGenerationOutput } from "@cloudflare/ai/dist/ai/tasks/text-generation";
import { GoogleGenerativeAI } from "@google/generative-ai";
/**
 * Bindings and secrets made available to the worker's request handler.
 */
export interface Env {
// Vectorize index handed to CloudflareVectorizeStore as its `index` option.
VECTORIZE_INDEX: VectorizeIndex;
// Binding passed to the `Ai` constructor from '@cloudflare/ai'.
AI: Fetcher;
// NOTE(review): usage is not visible in this excerpt — presumably a shared secret for authenticating callers; confirm.
SECURITY_KEY: string;
// NOTE(review): presumably consumed by the OpenAIEmbeddings helper — confirm against its construction site.
OPENAI_API_KEY: string;
// API key passed to the GoogleGenerativeAI constructor.
GOOGLE_AI_API_KEY: string;
}
@ -38,7 +38,10 @@ export default {
const store = new CloudflareVectorizeStore(embeddings, {
index: env.VECTORIZE_INDEX,
});
const ai = new Ai(env.AI)
// const ai = new Ai(env.AI)
const genAI = new GoogleGenerativeAI(env.GOOGLE_AI_API_KEY);
const model = genAI.getGenerativeModel({ model: "gemini-pro" });
if (pathname === "/add" && request.method === "POST") {
@ -119,22 +122,27 @@ export default {
return new Response(JSON.stringify({ message: "No Results Found" }), { status: 400 });
}
const metadatas = vec.map(({ metadata }) => metadata)
const preparedContext = vec.slice(0, 3).map(({ metadata }) => `Website title: ${metadata!.title}\nDescription: ${metadata!.description}\nURL: ${metadata!.url}\nContent: ${metadata!.text}`).join("\n\n");
console.log(metadatas)
const prompt = `You are an agent that summarizes a page based on the query. Be direct and concise, don't say 'based on the context'.\n\n Context:\n${preparedContext} \nAnswer this question based on the context. Question: ${query}\nAnswer:`
const output = await model.generateContentStream(prompt);
// TODO: TAKE ALL THE HIGH SCORED IDS INTO CONSIDERATION
const output: AiTextGenerationOutput = await ai.run('@hf/thebloke/mistral-7b-instruct-v0.1-awq', {
prompt: `You are an agent that summarizes a page based on the query. Be direct and concise, don't say 'based on the context'.\n\n Context:\n${vec[0].metadata!.text} \nAnswer this question based on the context. Question: ${query}\nAnswer:`,
stream: true
}) as ReadableStream
return new Response(output, {
headers: {
"content-type": "text/event-stream",
},
});
// Re-emit the model's streamed chunks as Server-Sent Events. Each chunk becomes
// one `data: {"response": ...}` event, followed by a final `data: [DONE]` sentinel.
const response = new Response(
  new ReadableStream({
    async start(controller) {
      const converter = new TextEncoder();
      for await (const chunk of output.stream) {
        const chunkText = await chunk.text();
        const encodedChunk = converter.encode("data: " + JSON.stringify({ "response": chunkText }) + "\n\n");
        controller.enqueue(encodedChunk);
      }
      // An SSE event is only dispatched once a blank line follows it, so the
      // sentinel needs the same "\n\n" terminator as every other event.
      const doneChunk = converter.encode("data: [DONE]\n\n");
      controller.enqueue(doneChunk);
      controller.close();
    }
  }),
  {
    // Declare the body as an event stream so clients and intermediaries
    // treat it as SSE instead of guessing the content type.
    headers: {
      "content-type": "text/event-stream",
    },
  }
);
return response;
}
else if (pathname === "/ask" && request.method === "POST") {
@ -146,17 +154,26 @@ export default {
return new Response(JSON.stringify({ message: "Invalid Page Content" }), { status: 400 });
}
const output: AiTextGenerationOutput = await ai.run('@hf/thebloke/mistral-7b-instruct-v0.1-awq', {
prompt: `You are an agent that answers a question based on the query. Be direct and concise, don't say 'based on the context'.\n\n Context:\n${body.query} \nAnswer this question based on the context. Question: ${body.query}\nAnswer:`,
stream: true
}) as ReadableStream
const prompt = `You are an agent that answers a question based on the query. Be direct and concise, don't say 'based on the context'.\n\n Context:\n${body.query} \nAnswer this question based on the context. Question: ${body.query}\nAnswer:`
const output = await model.generateContentStream(prompt);
return new Response(output, {
headers: {
"content-type": "text/event-stream",
},
});
// Re-emit the model's streamed chunks as Server-Sent Events. Each chunk becomes
// one `data: {"response": ...}` event, followed by a final `data: [DONE]` sentinel.
const response = new Response(
  new ReadableStream({
    async start(controller) {
      const converter = new TextEncoder();
      for await (const chunk of output.stream) {
        const chunkText = await chunk.text();
        console.log(chunkText);
        const encodedChunk = converter.encode("data: " + JSON.stringify({ "response": chunkText }) + "\n\n");
        controller.enqueue(encodedChunk);
      }
      // No trailing space after the sentinel: the payload is exactly "[DONE]",
      // so a strict comparison on the client matches it.
      const doneChunk = converter.encode("data: [DONE]\n\n");
      controller.enqueue(doneChunk);
      controller.close();
    }
  }),
  {
    // Declare the body as an event stream so clients and intermediaries
    // treat it as SSE instead of guessing the content type.
    headers: {
      "content-type": "text/event-stream",
    },
  }
);
return response;
}
return new Response(JSON.stringify({ message: "Invalid Request" }), { status: 400 });

View file

@ -82,6 +82,11 @@ function QueryAI() {
const response = await fetch(`/api/query?q=${input}`);
if (response.status !== 200) {
setIsAiLoading(false);
return;
}
if (response.body) {
let reader = response.body.getReader();
let decoder = new TextDecoder('utf-8');

View file

@ -39,6 +39,7 @@
"@cloudflare/ai": "^1.0.52",
"@cloudflare/next-on-pages-next-dev": "^0.0.1",
"@crxjs/vite-plugin": "^1.0.14",
"@google/generative-ai": "^0.3.1",
"@heroicons/react": "^2.1.1",
"@langchain/cloudflare": "^0.0.3",
"@radix-ui/colors": "^3.0.0",