fix: GraphQAchain Invalid prompts

DESKTOP-RTLN3BA\$punk 2024-08-13 19:56:54 -07:00
parent 6d8e1fe994
commit 4b3148fc3e
47 changed files with 62 additions and 5849 deletions


@@ -6,13 +6,14 @@ from langchain_core.documents import Document
 from langchain_openai import OpenAIEmbeddings
 from langchain_community.vectorstores import Neo4jVector
 from envs import ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, API_SECRET_KEY, SECRET_KEY
-from prompts import CYPHER_QA_PROMPT, DOC_DESCRIPTION_PROMPT, SIMILARITY_SEARCH_PROMPT , CYPHER_GENERATION_PROMPT, DOCUMENT_METADATA_EXTRACTION_PROMT
-from pydmodels import DescriptionResponse, UserQuery, DocMeta, RetrivedDocList, UserQueryResponse
+from prompts import CYPHER_QA_PROMPT, DOC_DESCRIPTION_PROMPT, GRAPH_QUERY_GEN_PROMPT, SIMILARITY_SEARCH_PROMPT , CYPHER_GENERATION_PROMPT, DOCUMENT_METADATA_EXTRACTION_PROMT
+from pydmodels import DescriptionResponse, UserQuery, DocMeta, RetrivedDocList, UserQueryResponse, VectorSearchQuery
 from langchain_experimental.text_splitter import SemanticChunker
 #Our Imps
 from LLMGraphTransformer import LLMGraphTransformer
 from langchain_openai import ChatOpenAI
+from DataExample import examples
 # Auth Libs
 from fastapi import FastAPI, Depends, HTTPException, Request, status
@@ -39,8 +40,6 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
         raise HTTPException(status_code=401, detail="Unauthorized")
     query = data.query
     graph = Neo4jGraph(url=data.neourl, username=data.neouser, password=data.neopass)
     llm = ChatOpenAI(
@@ -51,6 +50,13 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
         api_key=data.openaikey
     )
+    # Query Expansion
+    searchchain = GRAPH_QUERY_GEN_PROMPT | llm
+    qry = searchchain.invoke({"question": data.query, "context": examples})
+    query = qry.content
     embeddings = OpenAIEmbeddings(
         model="text-embedding-ada-002",
         api_key=data.openaikey,
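
The query-expansion step added above pipes GRAPH_QUERY_GEN_PROMPT into the chat model with LCEL. GRAPH_QUERY_GEN_PROMPT (prompts.py) and examples (DataExample.py) are not part of this diff, so the following is only a minimal, self-contained sketch of how that chain behaves, with an invented template, example data, and a placeholder model/key.

# Hypothetical sketch only: the real GRAPH_QUERY_GEN_PROMPT lives in prompts.py and
# the real examples in DataExample.py; neither appears in this diff.
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

GRAPH_QUERY_GEN_PROMPT = ChatPromptTemplate.from_template(
    "Rewrite the user's question so it matches the entities and relationships in the graph.\n"
    "Example graph data:\n{context}\n\nQuestion: {question}"
)
examples = ["(:Document {title: 'RAG notes'})-[:MENTIONS]->(:Topic {name: 'vector search'})"]

llm = ChatOpenAI(model="gpt-4o-mini", api_key="sk-...")       # placeholder model and key
searchchain = GRAPH_QUERY_GEN_PROMPT | llm                    # LCEL: prompt -> chat model
qry = searchchain.invoke({"question": "what do my notes say about vector search?",
                          "context": examples})
query = qry.content                                           # expanded query used downstream
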
@@ -96,19 +102,22 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
     )
     docstoreturn = [i for n, i in enumerate(docstoreturn) if i not in docstoreturn[n + 1:]]
     # responsegrp = chain.invoke({"query": query})
     try:
-        response = chain.invoke({"query": query})
-        if "don't know" in response["result"]:
+        responsegrp = chain.invoke({"query": query})
+        if "don't know" in responsegrp["result"]:
             raise Exception("No response from graph")
-        structured_llm = llm.with_structured_output(RetrivedDocList)
+        structured_llm = llm.with_structured_output(VectorSearchQuery)
         doc_extract_chain = DOCUMENT_METADATA_EXTRACTION_PROMT | structured_llm
-        query = doc_extract_chain.invoke(response["intermediate_steps"][1]["context"])
+        newquery = doc_extract_chain.invoke(responsegrp["intermediate_steps"][1]["context"])
-        docs = vector_index.similarity_search(query.searchquery,k=5)
+        docs = vector_index.similarity_search(newquery.searchquery,k=5)
         docstoreturn = []
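
The structured-output switch above depends on a VectorSearchQuery model from pydmodels exposing a searchquery field (that is what newquery.searchquery reads). pydmodels is not part of this diff, so the sketch below is an assumed minimal definition showing how with_structured_output uses it.

# Assumed shape of pydmodels.VectorSearchQuery; only the searchquery field is implied
# by the handler above, everything else here is illustrative.
from pydantic import BaseModel, Field
from langchain_openai import ChatOpenAI

class VectorSearchQuery(BaseModel):
    searchquery: str = Field(description="Search string for the Neo4j vector index")

llm = ChatOpenAI(model="gpt-4o-mini", api_key="sk-...")        # placeholder model and key
structured_llm = llm.with_structured_output(VectorSearchQuery)
newquery = structured_llm.invoke(
    "Graph context: Document 'RAG notes' mentions hybrid search and Neo4j."
)
print(newquery.searchquery)   # this string feeds vector_index.similarity_search(...) above
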
@@ -127,12 +136,12 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
         docstoreturn = [i for n, i in enumerate(docstoreturn) if i not in docstoreturn[n + 1:]]
-        return UserQueryResponse(relateddocs=docstoreturn,response=response["result"])
+        return UserQueryResponse(relateddocs=docstoreturn,response=responsegrp["result"])
     except:
         # Fallback to Similarity Search RAG
         searchchain = SIMILARITY_SEARCH_PROMPT | llm
-        response = searchchain.invoke({"question": query, "context": docs})
+        response = searchchain.invoke({"question": data.query, "context": docs})
         return UserQueryResponse(relateddocs=docstoreturn,response=response.content)
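
After this change the except branch answers from plain similarity-search RAG and now passes the original data.query rather than the graph-expanded query. A rough, self-contained sketch of that fallback, with an assumed SIMILARITY_SEARCH_PROMPT template and stubbed documents in place of the vector-store hits:

# Rough sketch of the fallback branch; the real SIMILARITY_SEARCH_PROMPT is defined in
# prompts.py and is not shown in this diff, so this template is an assumption.
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

SIMILARITY_SEARCH_PROMPT = ChatPromptTemplate.from_template(
    "Answer the question using only the documents below.\n\n"
    "Documents:\n{context}\n\nQuestion: {question}"
)

llm = ChatOpenAI(model="gpt-4o-mini", api_key="sk-...")            # placeholder model and key
docs = ["Doc 1: hybrid search combines BM25 and embeddings."]      # normally vector_index.similarity_search(...)
searchchain = SIMILARITY_SEARCH_PROMPT | llm
response = searchchain.invoke({"question": "original user question", "context": docs})
print(response.content)
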