SurfSense v3 - Highlight: Local LLM Support

DESKTOP-RTLN3BA$punk 2024-09-19 22:50:16 -07:00
parent 04df919cf9
commit 7f38091d3d
13 changed files with 692 additions and 1345 deletions
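
The headline change in this commit is an IS_LOCAL_SETUP environment switch: every endpoint that previously hard-wired a ChatOpenAI client now chooses between a locally served Ollama model and OpenAI per request. A minimal sketch of the pattern, assuming the same models the diff uses (the get_llm helper is hypothetical; the commit writes the if/else out inline at each endpoint):

import os
from dotenv import load_dotenv
from langchain_ollama import OllamaLLM
from langchain_openai import ChatOpenAI

load_dotenv()
IS_LOCAL_SETUP = os.environ.get("IS_LOCAL_SETUP")

def get_llm(openai_key=None, temperature=0):
    if IS_LOCAL_SETUP == 'true':
        # Local setup: mistral-nemo served by Ollama; no API key required.
        return OllamaLLM(model="mistral-nemo", temperature=temperature)
    # Hosted setup: the user's OpenAI key is forwarded with each request.
    return ChatOpenAI(temperature=temperature, model_name="gpt-4o-mini", api_key=openai_key)

One asymmetry the diff repeatedly handles downstream: OllamaLLM.invoke returns a plain string, while ChatOpenAI returns an AIMessage, so only the hosted path reads .content off the result.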


@@ -1,42 +1,40 @@
from __future__ import annotations
from langchain.chains import GraphCypherQAChain
from langchain_community.graphs import Neo4jGraph
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import Neo4jVector
from envs import ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, API_SECRET_KEY, SECRET_KEY
from prompts import CYPHER_QA_PROMPT, DATE_TODAY, DOC_DESCRIPTION_PROMPT, GRAPH_QUERY_GEN_PROMPT, NOTIFICATION_GENERATION_PROMT, SIMILARITY_SEARCH_PROMPT , CYPHER_GENERATION_PROMPT, DOCUMENT_METADATA_EXTRACTION_PROMT
from pydmodels import ChatToUpdate, DescriptionResponse, GraphDocs, NewUserChat, NewUserData, Notifications, PrecisionQuery, PrecisionResponse, UserQuery, DocMeta, RetrivedDocList, UserQueryResponse, UserQueryWithChatHistory, VectorSearchQuery
from langchain_experimental.text_splitter import SemanticChunker
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_ollama import OllamaLLM
from langchain_openai import ChatOpenAI
from prompts import CONTEXT_ANSWER_PROMPT, DATE_TODAY, SUBQUERY_DECOMPOSITION_PROMT
from pydmodels import ChatToUpdate, DescriptionResponse, DocWithContent, DocumentsToDelete, NewUserChat, UserCreate, UserQuery, RetrivedDocList, UserQueryResponse, UserQueryWithChatHistory
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
# Our imports
from LLMGraphTransformer import LLMGraphTransformer
from langchain_openai import ChatOpenAI
from DataExample import examples
# import nest_asyncio
# from langchain_community.chains.graph_qa.gremlin import GremlinQAChain
# from langchain_community.graphs import GremlinGraph
# from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
# from langchain_core.documents import Document
# from langchain_openai import AzureChatOpenAI
# Hierarchical Indices class
from HIndices import HIndices
from Utils.stringify import stringify
# Auth Libs
from fastapi import FastAPI, Depends, HTTPException, Request, status
from fastapi import FastAPI, Depends, HTTPException, status
from sqlalchemy.orm import Session
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from datetime import datetime, timedelta
from passlib.context import CryptContext
from models import Chat, Notification, User
from database import SessionLocal, engine
from pydantic import BaseModel
from models import Chat, Documents, SearchSpace, User
from database import SessionLocal
from fastapi.middleware.cors import CORSMiddleware
from langchain_openai import AzureChatOpenAI
import os
from dotenv import load_dotenv
load_dotenv()
IS_LOCAL_SETUP = os.environ.get("IS_LOCAL_SETUP")
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.environ.get("ACCESS_TOKEN_EXPIRE_MINUTES"))
ALGORITHM = os.environ.get("ALGORITHM")
API_SECRET_KEY = os.environ.get("API_SECRET_KEY")
SECRET_KEY = os.environ.get("SECRET_KEY")
app = FastAPI()
# Dependency
@@ -47,251 +45,117 @@ def get_db():
finally:
db.close()
class UserCreate(BaseModel):
username: str
password: str
apisecretkey: str
# General GraphCypherQAChain
@app.post("/")
@app.post("/chat/")
def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
if(data.apisecretkey != API_SECRET_KEY):
raise HTTPException(status_code=401, detail="Unauthorized")
graph = Neo4jGraph(url=data.neourl, username=data.neouser, password=data.neopass)
llm = ChatOpenAI(
model="gpt-4o-mini",
temperature=0,
max_tokens=None,
timeout=None,
api_key=data.openaikey
)
# Query Expansion
# searchchain = GRAPH_QUERY_GEN_PROMPT | llm
# qry = searchchain.invoke({"question": data.query, "context": examples})
query = data.query #qry.content
embeddings = OpenAIEmbeddings(
model="text-embedding-ada-002",
api_key=data.openaikey,
)
chain = GraphCypherQAChain.from_llm(
graph=graph,
cypher_prompt=CYPHER_GENERATION_PROMPT,
cypher_llm=llm,
verbose=True,
validate_cypher=True,
qa_prompt=CYPHER_QA_PROMPT ,
qa_llm=llm,
return_intermediate_steps=True,
top_k=5,
)
vector_index = Neo4jVector.from_existing_graph(
embeddings,
graph=graph,
search_type="hybrid",
node_label="Document",
text_node_properties=["text"],
embedding_node_property="embedding",
)
graphdocs = vector_index.similarity_search(data.query,k=15)
docsDict = {}
for d in graphdocs:
if d.metadata['BrowsingSessionId'] not in docsDict:
newVal = d.metadata.copy()
newVal['VisitedWebPageContent'] = d.page_content
docsDict[d.metadata['BrowsingSessionId']] = newVal
else:
docsDict[d.metadata['BrowsingSessionId']]['VisitedWebPageContent'] += d.page_content
docstoreturn = []
for x in docsDict.values():
docstoreturn.append(DocMeta(
BrowsingSessionId=x['BrowsingSessionId'],
VisitedWebPageURL=x['VisitedWebPageURL'],
VisitedWebPageVisitDurationInMilliseconds=x['VisitedWebPageVisitDurationInMilliseconds'],
VisitedWebPageTitle=x['VisitedWebPageTitle'],
VisitedWebPageReffererURL=x['VisitedWebPageReffererURL'],
VisitedWebPageDateWithTimeInISOString=x['VisitedWebPageDateWithTimeInISOString'],
VisitedWebPageContent=x['VisitedWebPageContent']
))
try:
responsegrp = chain.invoke({"query": query})
if "don't know" in responsegrp["result"]:
raise Exception("No response from graph")
payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
structured_llm = llm.with_structured_output(VectorSearchQuery)
doc_extract_chain = DOCUMENT_METADATA_EXTRACTION_PROMT | structured_llm
query = data.query
search_space = data.search_space
newquery = doc_extract_chain.invoke(responsegrp["intermediate_steps"][1]["context"])
graphdocs = vector_index.similarity_search(newquery.searchquery,k=15)
docsDict = {}
for d in graphdocs:
if d.metadata['BrowsingSessionId'] not in docsDict:
newVal = d.metadata.copy()
newVal['VisitedWebPageContent'] = d.page_content
docsDict[d.metadata['BrowsingSessionId']] = newVal
else:
docsDict[d.metadata['BrowsingSessionId']]['VisitedWebPageContent'] += d.page_content
docstoreturn = []
for x in docsDict.values():
docstoreturn.append(DocMeta(
BrowsingSessionId=x['BrowsingSessionId'],
VisitedWebPageURL=x['VisitedWebPageURL'],
VisitedWebPageVisitDurationInMilliseconds=x['VisitedWebPageVisitDurationInMilliseconds'],
VisitedWebPageTitle=x['VisitedWebPageTitle'],
VisitedWebPageReffererURL=x['VisitedWebPageReffererURL'],
VisitedWebPageDateWithTimeInISOString=x['VisitedWebPageDateWithTimeInISOString'],
VisitedWebPageContent=x['VisitedWebPageContent']
))
return UserQueryResponse(relateddocs=docstoreturn,response=responsegrp["result"])
except:
# Fallback to Similarity Search RAG
searchchain = SIMILARITY_SEARCH_PROMPT | llm
response = searchchain.invoke({"question": data.query, "context": docstoreturn})
return UserQueryResponse(relateddocs=docstoreturn,response=response.content)
# Precision Search
@app.post("/precision")
def get_precision_search_response(data: PrecisionQuery, response_model=PrecisionResponse):
if(data.apisecretkey != API_SECRET_KEY):
raise HTTPException(status_code=401, detail="Unauthorized")
graph = Neo4jGraph(url=data.neourl, username=data.neouser, password=data.neopass)
GRAPH_QUERY = "MATCH (d:Document) WHERE d.VisitedWebPageDateWithTimeInISOString >= " + "'" + data.daterange[0] + "'" + " AND d.VisitedWebPageDateWithTimeInISOString <= " + "'" + data.daterange[1] + "'"
if(data.timerange[0] >= data.timerange[1]):
GRAPH_QUERY += " AND d.VisitedWebPageVisitDurationInMilliseconds >= 0"
else:
GRAPH_QUERY += " AND d.VisitedWebPageVisitDurationInMilliseconds >= "+ str(data.timerange[0]) + " AND d.VisitedWebPageVisitDurationInMilliseconds <= " + str(data.timerange[1])
if(data.webpageurl):
GRAPH_QUERY += " AND d.VisitedWebPageURL CONTAINS " + "'" + data.webpageurl.lower() + "'"
if(data.sessionid):
GRAPH_QUERY += " AND d.BrowsingSessionId = " + "'" + data.sessionid + "'"
GRAPH_QUERY += " RETURN d;"
graphdocs = graph.query(GRAPH_QUERY)
docsDict = {}
for d in graphdocs:
if d['d']['VisitedWebPageVisitDurationInMilliseconds'] not in docsDict:
docsDict[d['d']['VisitedWebPageVisitDurationInMilliseconds']] = d['d']
if(IS_LOCAL_SETUP == 'true'):
sub_query_llm = OllamaLLM(model="mistral-nemo",temperature=0)
qa_llm = OllamaLLM(model="mistral-nemo",temperature=0)
else:
docsDict[d['d']['VisitedWebPageVisitDurationInMilliseconds']]['text'] += d['d']['text']
docs = []
for x in docsDict.values():
docs.append(DocMeta(
BrowsingSessionId=x['BrowsingSessionId'],
VisitedWebPageURL=x['VisitedWebPageURL'],
VisitedWebPageVisitDurationInMilliseconds=x['VisitedWebPageVisitDurationInMilliseconds'],
VisitedWebPageTitle=x['VisitedWebPageTitle'],
VisitedWebPageReffererURL=x['VisitedWebPageReffererURL'],
VisitedWebPageDateWithTimeInISOString=x['VisitedWebPageDateWithTimeInISOString'],
VisitedWebPageContent=x['text']
))
return PrecisionResponse(documents=docs)
sub_query_llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", api_key=data.openaikey)
qa_llm = ChatOpenAI(temperature=0.5, model_name="gpt-4o-mini", api_key=data.openaikey)
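
(Aside on the precision endpoint above: GRAPH_QUERY is assembled by string concatenation, which is Cypher-injection-prone. A hypothetical parameterized variant, not part of this commit, using the params argument that Neo4jGraph.query accepts:)

# Hypothetical rewrite: pass user input as parameters instead of splicing it into the Cypher text.
PRECISION_QUERY = (
    "MATCH (d:Document) "
    "WHERE d.VisitedWebPageDateWithTimeInISOString >= $start "
    "AND d.VisitedWebPageDateWithTimeInISOString <= $end "
    "RETURN d;"
)
graphdocs = graph.query(PRECISION_QUERY, {"start": data.daterange[0], "end": data.daterange[1]})
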
# Multi DOC Chat
@app.post("/chat/docs")
def doc_chat_with_history(data: UserQueryWithChatHistory, response_model=DescriptionResponse):
if(data.apisecretkey != API_SECRET_KEY):
raise HTTPException(status_code=401, detail="Unauthorized")
llm = ChatOpenAI(
model="gpt-4o-mini",
temperature=0,
max_tokens=None,
timeout=None,
api_key=data.openaikey
)
chatHistory = []
for chat in data.chat:
if(chat.type == 'system'):
chatHistory.append(SystemMessage(content=DATE_TODAY + """You are a helpful assistant for question-answering tasks.
Use the following pieces of retrieved context to answer the question.
If you don't know the answer, just say that you don't know.
Context:""" + str(chat.content)))
if(chat.type == 'ai'):
chatHistory.append(AIMessage(content=chat.content))
if(chat.type == 'human'):
chatHistory.append(HumanMessage(content=chat.content))
chatHistory.append(("human", "{input}"))
qa_prompt = ChatPromptTemplate.from_messages(chatHistory)
# Create an LLMChain for sub-query decomposition
subquery_decomposer_chain = SUBQUERY_DECOMPOSITION_PROMT | sub_query_llm
descriptionchain = qa_prompt | llm
def decompose_query(original_query: str):
"""
Decompose the original query into simpler sub-queries.
Args:
original_query (str): The original complex query
Returns:
List[str]: A list of simpler sub-queries
"""
if(IS_LOCAL_SETUP == 'true'):
response = subquery_decomposer_chain.invoke(original_query)
else:
response = subquery_decomposer_chain.invoke(original_query).content
sub_queries = [q.strip() for q in response.split('\n') if q.strip() and not q.strip().startswith('Sub-queries:')]
return sub_queries
response = descriptionchain.invoke({"input": data.query})
return DescriptionResponse(response=response.content)
# Create Hierarchical Indices
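# In local mode the index is built without an OpenAI key; hosted mode forwards the user's key.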
if(IS_LOCAL_SETUP == 'true'):
index = HIndices(username=username)
else:
index = HIndices(username=username,api_key=data.openaikey)
# For those who want HyDE questions
# sub_queries = decompose_query(query)
sub_queries = []
sub_queries.append(query)
# DOC DESCRIPTION
@app.post("/kb/doc")
def get_doc_description(data: UserQuery, response_model=DescriptionResponse):
if(data.apisecretkey != API_SECRET_KEY):
raise HTTPException(status_code=401, detail="Unauthorized")
document = data.query
llm = ChatOpenAI(
model="gpt-4o-mini",
temperature=0,
max_tokens=None,
timeout=None,
api_key=data.openaikey
)
descriptionchain = DOC_DESCRIPTION_PROMPT | llm
response = descriptionchain.invoke({"document": document})
return DescriptionResponse(response=response.content)
duplicate_related_summary_docs = []
context_to_answer = ""
for sub_query in sub_queries:
localreturn = index.local_search(query=sub_query, search_space=search_space)
globalreturn, related_summary_docs = index.global_search(query=sub_query, search_space=search_space)
context_to_answer += localreturn + "\n\n" + globalreturn
# SAVE DOCS TO GRAPH DB
@app.post("/kb/")
def populate_graph(apires: RetrivedDocList, db: Session = Depends(get_db)):
duplicate_related_summary_docs.extend(related_summary_docs)
combined_docs_seen_metadata = set()
combined_docs_unique_documents = []
for doc in duplicate_related_summary_docs:
# Convert metadata to a tuple of its items (this allows it to be added to a set)
doc.metadata['relevance_score'] = 0.0
metadata_tuple = tuple(sorted(doc.metadata.items()))
if metadata_tuple not in combined_docs_seen_metadata:
combined_docs_seen_metadata.add(metadata_tuple)
combined_docs_unique_documents.append(doc)
returnDocs = []
for doc in combined_docs_unique_documents:
entry = DocWithContent(
BrowsingSessionId=doc.metadata['BrowsingSessionId'],
VisitedWebPageURL=doc.metadata['VisitedWebPageURL'],
VisitedWebPageContent=doc.page_content,
VisitedWebPageTitle=doc.metadata['VisitedWebPageTitle'],
VisitedWebPageDateWithTimeInISOString=doc.metadata['VisitedWebPageDateWithTimeInISOString'],
VisitedWebPageReffererURL=doc.metadata['VisitedWebPageReffererURL'],
VisitedWebPageVisitDurationInMilliseconds=doc.metadata['VisitedWebPageVisitDurationInMilliseconds'],
)
returnDocs.append(entry)
ans_chain = CONTEXT_ANSWER_PROMPT | qa_llm
finalans = ans_chain.invoke({"query": query, "context": context_to_answer})
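# Same string-vs-AIMessage asymmetry: the local OllamaLLM answer is already a plain string.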
if(IS_LOCAL_SETUP == 'true'):
return UserQueryResponse(response=finalans, relateddocs=returnDocs)
else:
return UserQueryResponse(response=finalans.content, relateddocs=returnDocs)
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
# SAVE DOCS
@app.post("/save/")
def save_data(apires: RetrivedDocList, db: Session = Depends(get_db)):
try:
payload = jwt.decode(apires.token, SECRET_KEY, algorithms=[ALGORITHM])
@@ -299,58 +163,49 @@ def populate_graph(apires: RetrivedDocList, db: Session = Depends(get_db)):
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
print("STARTED")
# print(apires)
graph = Neo4jGraph(url=apires.neourl, username=apires.neouser, password=apires.neopass)
llm = ChatOpenAI(
model="gpt-4o-mini",
temperature=0,
max_tokens=None,
timeout=None,
api_key=apires.openaikey
)
embeddings = OpenAIEmbeddings(
model="text-embedding-ada-002",
api_key=apires.openaikey,
)
llm_transformer = LLMGraphTransformer(llm=llm)
DocumentPgEntry = []
raw_documents = []
searchspace = db.query(SearchSpace).filter(SearchSpace.search_space == apires.search_space).first()
for doc in apires.documents:
raw_documents.append(Document(page_content=doc.pageContent, metadata=doc.metadata))
text_splitter = SemanticChunker(embeddings=embeddings)
documents = text_splitter.split_documents(raw_documents)
graph_documents = llm_transformer.convert_to_graph_documents(documents)
graph.add_graph_documents(
graph_documents,
baseEntityLabel=True,
include_source=True
)
structured_llm = llm.with_structured_output(Notifications)
notifs_extraction_chain = NOTIFICATION_GENERATION_PROMT | structured_llm
notifications = notifs_extraction_chain.invoke({"documents": raw_documents})
notifsdb = []
for text in notifications.notifications:
notifsdb.append(Notification(text=text))
content = f"USER BROWSING SESSION EVENT: \n"
content += f"=======================================METADATA==================================== \n"
content += f"User Browsing Session ID : {doc.metadata.BrowsingSessionId} \n"
content += f"User Visited website with url : {doc.metadata.VisitedWebPageURL} \n"
content += f"This visited website url had title : {doc.metadata.VisitedWebPageTitle} \n"
content += f"User Visited this website from reffering url : {doc.metadata.VisitedWebPageReffererURL} \n"
content += f"User Visited this website url at this Date and Time : {doc.metadata.VisitedWebPageDateWithTimeInISOString} \n"
content += f"User Visited this website for : {str(doc.metadata.VisitedWebPageVisitDurationInMilliseconds)} milliseconds. \n"
content += f"===================================================================================== \n"
content += f"Webpage Content of the visited webpage url in markdown format : \n\n {doc.pageContent} \n\n"
content += f"===================================================================================== \n"
raw_documents.append(Document(page_content=content,metadata=doc.metadata.__dict__))
pgdocmeta = stringify(doc.metadata.__dict__)
if(searchspace):
DocumentPgEntry.append(Documents(file_type='WEBPAGE',title=doc.metadata.VisitedWebPageTitle,search_space=searchspace, document_metadata=pgdocmeta, page_content=content))
else:
DocumentPgEntry.append(Documents(file_type='WEBPAGE',title=doc.metadata.VisitedWebPageTitle,search_space=SearchSpace(search_space=apires.search_space.upper()), document_metadata=pgdocmeta, page_content=content))
user = db.query(User).filter(User.username == username).first()
user.notifications.extend(notifsdb)
db.commit()
#Save docs in PG
user = db.query(User).filter(User.username == username).first()
user.documents.extend(DocumentPgEntry)
db.commit()
# Create Hierarchical Indices
if(IS_LOCAL_SETUP == 'true'):
index = HIndices(username=username)
else:
index = HIndices(username=username,api_key=apires.openaikey)
#Save Indices in vector Stores
index.encode_docs_hierarchical(documents=raw_documents, files_type='WEBPAGE',search_space=apires.search_space.upper(), db=db)
print("FINISHED")
@@ -360,21 +215,83 @@ def populate_graph(apires: RetrivedDocList, db: Session = Depends(get_db)):
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
# Function to populate db (comment out when running on server)
@app.post("/add/")
def add_user(user: UserCreate, db: Session = Depends(get_db)):
db_user = User(username=user.username, hashed_password=user.password, graph_config="", llm_config="")
db.add(db_user)
db.commit()
return "Success"
# Multi DOC Chat
@app.post("/chat/docs")
def doc_chat_with_history(data: UserQueryWithChatHistory, response_model=DescriptionResponse):
try:
payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
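# Pick the chat LLM: local Ollama (mistral-nemo) when IS_LOCAL_SETUP is true, else OpenAI gpt-4o-mini.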
if(IS_LOCAL_SETUP == 'true'):
llm = OllamaLLM(model="mistral-nemo",temperature=0)
else:
llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", api_key=data.openaikey)
chatHistory = []
for chat in data.chat:
if(chat.type == 'system'):
chatHistory.append(SystemMessage(content=DATE_TODAY + """You are a helpful assistant for question-answering tasks.
Use the following pieces of retrieved context to answer the question.
If you don't know the answer, just say that you don't know.
Context:""" + str(chat.content)))
if(chat.type == 'ai'):
chatHistory.append(AIMessage(content=chat.content))
if(chat.type == 'human'):
chatHistory.append(HumanMessage(content=chat.content))
chatHistory.append(("human", "{input}"))
qa_prompt = ChatPromptTemplate.from_messages(chatHistory)
descriptionchain = qa_prompt | llm
response = descriptionchain.invoke({"input": data.query})
if(IS_LOCAL_SETUP == 'true'):
return DescriptionResponse(response=response)
else:
return DescriptionResponse(response=response.content)
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
# Delete Docs
@app.post("/delete/docs")
def delete_all_related_data(data: DocumentsToDelete, db: Session = Depends(get_db)):
try:
payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
if(IS_LOCAL_SETUP == 'true'):
index = HIndices(username=username)
else:
index = HIndices(username=username,api_key=data.openaikey)
message = index.delete_vector_stores(summary_ids_to_delete=data.ids_to_delete,db=db )
return {
"message": message
}
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
#AUTH CODE
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# Recommended for Local Setups
# Manual Origins
# origins = [
# "http://localhost:3000", # Adjust the port if your frontend runs on a different one
# "https://yourfrontenddomain.com",
@@ -395,7 +312,7 @@ def get_user_by_username(db: Session, username: str):
def create_user(db: Session, user: UserCreate):
hashed_password = pwd_context.hash(user.password)
db_user = User(username=user.username, hashed_password=hashed_password, graph_config="", llm_config="")
db_user = User(username=user.username, hashed_password=hashed_password)
db.add(db_user)
db.commit()
return "complete"
@@ -462,9 +379,6 @@ async def verify_user_token(token: str):
verify_token(token=token)
return {"message": "Token is valid"}
@app.post("/user/chat/save")
def populate_user_chat(chat: NewUserChat, db: Session = Depends(get_db)):
try:
@@ -533,34 +447,32 @@ async def get_user_with_token(token: str, db: Session = Depends(get_db)):
"userid": user.id,
"username": user.username,
"chats": user.chats,
"notifications": user.notifications
"documents": user.documents
}
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
@app.get("/user/notification/delete/{token}/{notificationid}")
async def delete_chat_of_user(token: str, notificationid: str, db: Session = Depends(get_db)):
@app.get("/searchspaces/{token}")
async def get_user_with_token(token: str, db: Session = Depends(get_db)):
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
notificationindb = db.query(Notification).filter(Notification.id == notificationid).first()
db.delete(notificationindb)
db.commit()
search_spaces = db.query(SearchSpace).all()
return {
"message": "Notification Deleted"
"search_spaces": search_spaces
}
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="127.0.0.1", port=8000)
uvicorn.run(app, host="127.0.0.1", port=8000)