SurfSense/backend/server.py
2024-08-21 23:06:30 -07:00

566 lines
No EOL
20 KiB
Python

from __future__ import annotations
from langchain.chains import GraphCypherQAChain
from langchain_community.graphs import Neo4jGraph
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import Neo4jVector
from envs import ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, API_SECRET_KEY, SECRET_KEY
from prompts import CYPHER_QA_PROMPT, DATE_TODAY, DOC_DESCRIPTION_PROMPT, GRAPH_QUERY_GEN_PROMPT, NOTIFICATION_GENERATION_PROMT, SIMILARITY_SEARCH_PROMPT , CYPHER_GENERATION_PROMPT, DOCUMENT_METADATA_EXTRACTION_PROMT
from pydmodels import ChatToUpdate, DescriptionResponse, GraphDocs, NewUserChat, NewUserData, Notifications, PrecisionQuery, PrecisionResponse, UserQuery, DocMeta, RetrivedDocList, UserQueryResponse, UserQueryWithChatHistory, VectorSearchQuery
from langchain_experimental.text_splitter import SemanticChunker
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
#Our Imps
from LLMGraphTransformer import LLMGraphTransformer
from langchain_openai import ChatOpenAI
from DataExample import examples
# import nest_asyncio
# from langchain_community.chains.graph_qa.gremlin import GremlinQAChain
# from langchain_community.graphs import GremlinGraph
# from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
# from langchain_core.documents import Document
# from langchain_openai import AzureChatOpenAI
# Auth Libs
from fastapi import FastAPI, Depends, HTTPException, Request, status
from sqlalchemy.orm import Session
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from datetime import datetime, timedelta
from passlib.context import CryptContext
from models import Chat, Notification, User
from database import SessionLocal, engine
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from langchain_openai import AzureChatOpenAI
# Single FastAPI application instance; every route and middleware in this file is registered on it.
app = FastAPI()
# Dependency
def get_db():
    """Yield a fresh ``SessionLocal`` session and always close it afterwards.

    Written as a generator so it can be used with ``Depends(get_db)``: the
    code after ``yield`` runs once the consumer is finished, and the
    ``finally`` clause ensures ``close()`` is called even if the consumer raises.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
class UserCreate(BaseModel):
    # Request body accepted by the user-creation routes (/add/ and /register).
    username: str
    # Password as submitted by the client; create_user() hashes it before storing.
    password: str
    # Shared secret that /register compares against API_SECRET_KEY.
    apisecretkey: str
def _merge_docs_by_session(graphdocs):
    """Collapse vector-search hits into one ``DocMeta`` per ``BrowsingSessionId``.

    The first hit seen for a session contributes a copy of its metadata; the
    ``page_content`` of every later hit for the same session is appended to
    ``VisitedWebPageContent`` in the order the hits were returned.

    Args:
        graphdocs: iterable of documents exposing ``metadata`` (a dict) and
            ``page_content`` (a string), as returned by ``similarity_search``.

    Returns:
        A list of ``DocMeta`` objects, one per distinct session id.

    Raises:
        KeyError: if a hit's metadata lacks one of the fields copied below.
    """
    merged = {}
    for doc in graphdocs:
        session_id = doc.metadata['BrowsingSessionId']
        if session_id not in merged:
            entry = doc.metadata.copy()
            entry['VisitedWebPageContent'] = doc.page_content
            merged[session_id] = entry
        else:
            merged[session_id]['VisitedWebPageContent'] += doc.page_content

    return [
        DocMeta(
            BrowsingSessionId=entry['BrowsingSessionId'],
            VisitedWebPageURL=entry['VisitedWebPageURL'],
            VisitedWebPageVisitDurationInMilliseconds=entry['VisitedWebPageVisitDurationInMilliseconds'],
            VisitedWebPageTitle=entry['VisitedWebPageTitle'],
            VisitedWebPageReffererURL=entry['VisitedWebPageReffererURL'],
            VisitedWebPageDateWithTimeInISOString=entry['VisitedWebPageDateWithTimeInISOString'],
            VisitedWebPageContent=entry['VisitedWebPageContent'],
        )
        for entry in merged.values()
    ]


# General GraphCypherQAChain
@app.post("/")
def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
    """Answer a user query from the graph, falling back to similarity-search RAG.

    Flow:
      1. Reject the request with HTTP 401 unless ``data.apisecretkey`` matches
         ``API_SECRET_KEY``.
      2. Run a hybrid vector search for ``data.query`` (k=15) and group the
         hits per browsing session; these are the fallback documents.
      3. Ask ``GraphCypherQAChain`` the question. If it answers (no
         "don't know" in the result), derive a new search query from the
         chain's intermediate context, search again and return those
         documents together with the chain's answer.
      4. On any failure in step 3, answer with ``SIMILARITY_SEARCH_PROMPT``
         over the fallback documents instead.

    NOTE(review): ``response_model`` is declared as a function parameter here
    rather than on the ``@app.post`` decorator; it is kept to leave the
    signature untouched — confirm whether it was meant for the decorator.

    Returns:
        ``UserQueryResponse`` with the related documents and the answer text.

    Raises:
        HTTPException: 401 when the API secret key does not match.
    """
    if data.apisecretkey != API_SECRET_KEY:
        raise HTTPException(status_code=401, detail="Unauthorized")

    graph = Neo4jGraph(url=data.neourl, username=data.neouser, password=data.neopass)

    llm = ChatOpenAI(
        model="gpt-4o-mini",
        temperature=0,
        max_tokens=None,
        timeout=None,
        api_key=data.openaikey
    )

    # Query Expansion (currently disabled; the raw user query is used as-is)
    # searchchain = GRAPH_QUERY_GEN_PROMPT | llm
    # qry = searchchain.invoke({"question": data.query, "context": examples})
    query = data.query  # qry.content

    embeddings = OpenAIEmbeddings(
        model="text-embedding-ada-002",
        api_key=data.openaikey,
    )

    chain = GraphCypherQAChain.from_llm(
        graph=graph,
        cypher_prompt=CYPHER_GENERATION_PROMPT,
        cypher_llm=llm,
        verbose=True,
        validate_cypher=True,
        qa_prompt=CYPHER_QA_PROMPT,
        qa_llm=llm,
        return_intermediate_steps=True,
        top_k=5,
    )

    vector_index = Neo4jVector.from_existing_graph(
        embeddings,
        graph=graph,
        search_type="hybrid",
        node_label="Document",
        text_node_properties=["text"],
        embedding_node_property="embedding",
    )

    # Computed up front so the fallback path below always has a complete list,
    # even if the graph path fails halfway through.
    fallback_docs = _merge_docs_by_session(
        vector_index.similarity_search(data.query, k=15)
    )

    try:
        responsegrp = chain.invoke({"query": query})

        if "don't know" in responsegrp["result"]:
            raise Exception("No response from graph")

        structured_llm = llm.with_structured_output(VectorSearchQuery)
        doc_extract_chain = DOCUMENT_METADATA_EXTRACTION_PROMT | structured_llm

        # intermediate_steps[1]["context"] holds the rows the generated Cypher returned.
        newquery = doc_extract_chain.invoke(responsegrp["intermediate_steps"][1]["context"])

        related_docs = _merge_docs_by_session(
            vector_index.similarity_search(newquery.searchquery, k=15)
        )
        return UserQueryResponse(relateddocs=related_docs, response=responsegrp["result"])
    except Exception:
        # Deliberate best-effort: any graph-path failure falls back to
        # Similarity Search RAG. `Exception` (not a bare except) so that
        # KeyboardInterrupt / SystemExit still propagate.
        searchchain = SIMILARITY_SEARCH_PROMPT | llm
        response = searchchain.invoke({"question": data.query, "context": fallback_docs})
        return UserQueryResponse(relateddocs=fallback_docs, response=response.content)
# Precision Search
@app.post("/precision")
def get_precision_search_response(data: PrecisionQuery, response_model=PrecisionResponse):
    """Return stored documents matching explicit date/duration/URL/session filters.

    The Cypher query is assembled from fixed condition fragments only; every
    request-supplied value is passed as a query parameter (``$name``) rather
    than being concatenated into the query text, so request fields cannot
    alter the query structure.

    Filters applied:
      * ``daterange[0] <= VisitedWebPageDateWithTimeInISOString <= daterange[1]``
      * visit duration within ``timerange`` — or simply ``>= 0`` when
        ``timerange[0] >= timerange[1]``
      * optional: URL contains ``webpageurl`` (lower-cased)
      * optional: exact ``BrowsingSessionId``

    Rows sharing a ``BrowsingSessionId`` are merged by concatenating ``text``.

    NOTE(review): ``response_model`` is a function parameter here rather than
    a decorator argument; kept to leave the signature untouched — confirm.

    Returns:
        ``PrecisionResponse`` wrapping one ``DocMeta`` per browsing session.

    Raises:
        HTTPException: 401 when the API secret key does not match.
    """
    if data.apisecretkey != API_SECRET_KEY:
        raise HTTPException(status_code=401, detail="Unauthorized")

    graph = Neo4jGraph(url=data.neourl, username=data.neouser, password=data.neopass)

    conditions = [
        "d.VisitedWebPageDateWithTimeInISOString >= $date_from",
        "d.VisitedWebPageDateWithTimeInISOString <= $date_to",
    ]
    params = {"date_from": data.daterange[0], "date_to": data.daterange[1]}

    if data.timerange[0] >= data.timerange[1]:
        # Empty or inverted range: do not constrain the upper bound.
        conditions.append("d.VisitedWebPageVisitDurationInMilliseconds >= 0")
    else:
        conditions.append("d.VisitedWebPageVisitDurationInMilliseconds >= $duration_min")
        conditions.append("d.VisitedWebPageVisitDurationInMilliseconds <= $duration_max")
        # assumes timerange holds numbers (not numeric strings) — TODO confirm against PrecisionQuery
        params["duration_min"] = data.timerange[0]
        params["duration_max"] = data.timerange[1]

    if data.webpageurl:
        conditions.append("d.VisitedWebPageURL CONTAINS $webpage_url")
        params["webpage_url"] = data.webpageurl.lower()

    if data.sessionid:
        conditions.append("d.BrowsingSessionId = $session_id")
        params["session_id"] = data.sessionid

    graph_query = "MATCH (d:Document) WHERE " + " AND ".join(conditions) + " RETURN d;"
    rows = graph.query(graph_query, params=params)

    # Merge rows belonging to the same browsing session, concatenating their text.
    by_session = {}
    for row in rows:
        node = row['d']
        session_id = node['BrowsingSessionId']
        if session_id not in by_session:
            by_session[session_id] = node
        else:
            by_session[session_id]['text'] += node['text']

    docs = [
        DocMeta(
            BrowsingSessionId=node['BrowsingSessionId'],
            VisitedWebPageURL=node['VisitedWebPageURL'],
            VisitedWebPageVisitDurationInMilliseconds=node['VisitedWebPageVisitDurationInMilliseconds'],
            VisitedWebPageTitle=node['VisitedWebPageTitle'],
            VisitedWebPageReffererURL=node['VisitedWebPageReffererURL'],
            VisitedWebPageDateWithTimeInISOString=node['VisitedWebPageDateWithTimeInISOString'],
            VisitedWebPageContent=node['text'],
        )
        for node in by_session.values()
    ]

    return PrecisionResponse(documents=docs)
# Multi DOC Chat
@app.post("/chat/docs")
def doc_chat_with_history(data: UserQueryWithChatHistory, response_model=DescriptionResponse):
    """Continue a document chat: replay the stored history, then ask ``data.query``.

    Each entry of ``data.chat`` is converted by its ``type``:
      * ``'system'`` -> ``SystemMessage`` made of ``DATE_TODAY``, a fixed
        instruction text and ``str(chat.content)`` as the context
      * ``'ai'``     -> ``AIMessage``
      * ``'human'``  -> ``HumanMessage``
    Entries of any other type are skipped. A final ``("human", "{input}")``
    template slot is appended and filled with ``data.query``.

    Returns:
        ``DescriptionResponse`` carrying the model's reply text.

    Raises:
        HTTPException: 401 when the API secret key does not match.
    """
    if data.apisecretkey != API_SECRET_KEY:
        raise HTTPException(status_code=401, detail="Unauthorized")

    llm = ChatOpenAI(
        model="gpt-4o-mini",
        temperature=0,
        max_tokens=None,
        timeout=None,
        api_key=data.openaikey
    )

    history = []
    for entry in data.chat:
        if entry.type == 'system':
            system_text = DATE_TODAY + """You are an helpful assistant for question-answering tasks.
Use the following pieces of retrieved context to answer the question.
If you don't know the answer, just say that you don't know.
Context:""" + str(entry.content)
            history.append(SystemMessage(content=system_text))
        elif entry.type == 'ai':
            history.append(AIMessage(content=entry.content))
        elif entry.type == 'human':
            history.append(HumanMessage(content=entry.content))

    # Placeholder for the new user question, substituted at invoke time.
    history.append(("human", "{input}"))

    qa_chain = ChatPromptTemplate.from_messages(history) | llm
    reply = qa_chain.invoke({"input": data.query})

    return DescriptionResponse(response=reply.content)
# DOC DESCRIPTION
@app.post("/kb/doc")
def get_doc_description(data: UserQuery, response_model=DescriptionResponse):
    """Generate a description for the document text supplied in ``data.query``.

    The text is passed as the ``document`` variable of
    ``DOC_DESCRIPTION_PROMPT`` to a ``gpt-4o-mini`` chat model created with
    the caller's ``openaikey``.

    Returns:
        ``DescriptionResponse`` carrying the model's reply text.

    Raises:
        HTTPException: 401 when the API secret key does not match.
    """
    if data.apisecretkey != API_SECRET_KEY:
        raise HTTPException(status_code=401, detail="Unauthorized")

    llm = ChatOpenAI(
        model="gpt-4o-mini",
        temperature=0,
        max_tokens=None,
        timeout=None,
        api_key=data.openaikey
    )

    description_chain = DOC_DESCRIPTION_PROMPT | llm
    reply = description_chain.invoke({"document": data.query})

    return DescriptionResponse(response=reply.content)
# SAVE DOCS TO GRAPH DB
@app.post("/kb/")
def populate_graph(apires: RetrivedDocList, db: Session = Depends(get_db)):
    """Ingest the submitted documents into Neo4j and store generated notifications.

    Steps:
      1. Validate ``apires.token`` via ``verify_token`` (HTTP 403 if the token
         cannot be decoded or carries no ``sub`` claim).
      2. Wrap each submitted document in a LangChain ``Document``, split them
         with ``SemanticChunker``, convert the chunks to graph documents with
         ``LLMGraphTransformer`` and add them to the graph.
      3. Ask the LLM for notifications about the raw documents and attach
         them to the authenticated user's ``notifications``.

    Returns:
        ``{"success": "Graph Will be populated Shortly"}``.

    Raises:
        HTTPException: 403 for an invalid token; 404 when the token's user no
            longer exists (previously an unhandled ``AttributeError``).
    """
    username = verify_token(token=apires.token).get("sub")

    print("STARTED")
    # print(apires)
    graph = Neo4jGraph(url=apires.neourl, username=apires.neouser, password=apires.neopass)

    llm = ChatOpenAI(
        model="gpt-4o-mini",
        temperature=0,
        max_tokens=None,
        timeout=None,
        api_key=apires.openaikey
    )

    embeddings = OpenAIEmbeddings(
        model="text-embedding-ada-002",
        api_key=apires.openaikey,
    )

    llm_transformer = LLMGraphTransformer(llm=llm)

    raw_documents = [
        Document(page_content=doc.pageContent, metadata=doc.metadata)
        for doc in apires.documents
    ]

    text_splitter = SemanticChunker(embeddings=embeddings)
    documents = text_splitter.split_documents(raw_documents)
    graph_documents = llm_transformer.convert_to_graph_documents(documents)

    graph.add_graph_documents(
        graph_documents,
        baseEntityLabel=True,
        include_source=True
    )

    structured_llm = llm.with_structured_output(Notifications)
    notifs_extraction_chain = NOTIFICATION_GENERATION_PROMT | structured_llm
    notifications = notifs_extraction_chain.invoke({"documents": raw_documents})

    notifsdb = [Notification(text=text) for text in notifications.notifications]

    user = db.query(User).filter(User.username == username).first()
    if user is None:
        # A valid token whose user row is gone: report it instead of crashing.
        raise HTTPException(status_code=404, detail="User not found")
    user.notifications.extend(notifsdb)
    db.commit()

    print("FINISHED")

    return {
        "success": "Graph Will be populated Shortly"
    }
# Function to populate db (comment out when running on server)
@app.post("/add/")
def add_user(user: UserCreate, db: Session = Depends(get_db)):
    """Insert a ``User`` row built directly from the request payload.

    NOTE(review): ``user.password`` is written to ``hashed_password`` verbatim
    and ``user.apisecretkey`` is not checked here, unlike ``register_user`` /
    ``create_user`` — confirm this route is only meant for seeding a local
    database.

    Returns:
        The string ``"Success"`` once the row is committed.
    """
    new_row = User(
        username=user.username,
        hashed_password=user.password,
        graph_config="",
        llm_config="",
    )
    db.add(new_row)
    db.commit()
    return "Success"
# AUTH CODE
# Bearer-token scheme for FastAPI dependencies; tokenUrl refers to the /token login route below.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# Recommended for Local Setups: restrict CORS to an explicit origin list such as
# origins = [
# "http://localhost:3000", # Adjust the port if your frontend runs on a different one
# "https://yourfrontenddomain.com",
# ]
# NOTE(review): the configuration below allows every origin together with
# credentials, which is very permissive — confirm before deploying publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], # Allows every origin (the explicit list above is commented out)
    allow_credentials=True,
    allow_methods=["*"], # Allows all methods
    allow_headers=["*"], # Allows all headers
)
# bcrypt hashing context used by create_user() to hash and authenticate_user() to verify passwords.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def get_user_by_username(db: Session, username: str):
    """Return the first ``User`` whose ``username`` matches, or ``None`` if there is none."""
    matching_users = db.query(User).filter(User.username == username)
    return matching_users.first()
def create_user(db: Session, user: UserCreate):
    """Persist a new ``User`` with a hashed password.

    The plain-text ``user.password`` is hashed with ``pwd_context`` before it
    is stored; ``graph_config`` and ``llm_config`` start out empty.

    Returns:
        The string ``"complete"`` once the row is committed.
    """
    new_row = User(
        username=user.username,
        hashed_password=pwd_context.hash(user.password),
        graph_config="",
        llm_config="",
    )
    db.add(new_row)
    db.commit()
    return "complete"
@app.post("/register")
def register_user(user: UserCreate, db: Session = Depends(get_db)):
    """Register a new user after checking the API secret key and name uniqueness.

    Returns:
        Whatever ``create_user`` returns.

    Raises:
        HTTPException: 401 when ``user.apisecretkey`` does not match
            ``API_SECRET_KEY``; 400 when the username is already taken.
    """
    if user.apisecretkey != API_SECRET_KEY:
        raise HTTPException(status_code=401, detail="Unauthorized")

    existing = get_user_by_username(db, username=user.username)
    if existing:
        raise HTTPException(status_code=400, detail="Username already registered")

    # The secret is dropped from the payload before it is handed on.
    del user.apisecretkey
    return create_user(db=db, user=user)
# Authenticate the user
def authenticate_user(username: str, password: str, db: Session):
    """Return the matching ``User`` when the credentials are valid, else ``False``.

    ``False`` is returned both when no user has that username and when
    ``pwd_context.verify`` rejects the password against the stored hash.
    """
    account = db.query(User).filter(User.username == username).first()
    if account and pwd_context.verify(password, account.hashed_password):
        return account
    return False
# Create access token
def create_access_token(data: dict, expires_delta: timedelta | None = None):
    """Return a signed JWT containing ``data`` plus an ``exp`` expiry claim.

    Args:
        data: claims to embed; the dict is copied, not mutated.
        expires_delta: token lifetime. When omitted (or falsy), the token
            expires 15 minutes from now.

    Returns:
        The token encoded with ``SECRET_KEY`` using ``ALGORITHM``.
    """
    # Local import: only `datetime` and `timedelta` are imported at file level.
    from datetime import timezone

    lifetime = expires_delta if expires_delta else timedelta(minutes=15)

    to_encode = data.copy()
    # Timezone-aware "now" replaces the deprecated, naive datetime.utcnow();
    # the resulting UTC instant is the same.
    to_encode["exp"] = datetime.now(timezone.utc) + lifetime
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
@app.post("/token")
def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    """Exchange a username/password form for a bearer access token.

    The token's ``sub`` claim is the username and it expires after
    ``ACCESS_TOKEN_EXPIRE_MINUTES`` minutes.

    Returns:
        ``{"access_token": <jwt>, "token_type": "bearer"}``.

    Raises:
        HTTPException: 401 (with a ``WWW-Authenticate: Bearer`` header) when
            ``authenticate_user`` rejects the credentials.
    """
    account = authenticate_user(form_data.username, form_data.password, db)
    if not account:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    token = create_access_token(
        data={"sub": account.username},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    return {"access_token": token, "token_type": "bearer"}
def verify_token(token: str = Depends(oauth2_scheme)):
    """Decode ``token`` and return its payload.

    Returns:
        The decoded claims dict.

    Raises:
        HTTPException: 403 when decoding raises ``JWTError`` or the payload
            has no ``sub`` claim.
    """
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise HTTPException(status_code=403, detail="Token is invalid or expired")

    if claims.get("sub") is None:
        raise HTTPException(status_code=403, detail="Token is invalid or expired")
    return claims
@app.get("/verify-token/{token}")
async def verify_user_token(token: str):
    """Report whether the token given in the URL path is valid.

    Delegates to ``verify_token``, which raises HTTP 403 for a bad token, so
    reaching the ``return`` means the token decoded and carries a ``sub``.
    """
    verify_token(token=token)
    confirmation = {"message": "Token is valid"}
    return confirmation
@app.post("/user/chat/save")
def populate_user_chat(chat: NewUserChat, db: Session = Depends(get_db)):
    """Save a new chat for the user identified by ``chat.token``.

    Returns:
        ``{"message": "Chat Saved"}``.

    Raises:
        HTTPException: 403 for an invalid token (via ``verify_token``); 404
            when the token's user does not exist (previously an unhandled
            ``AttributeError``).
    """
    username = verify_token(token=chat.token).get("sub")

    user = db.query(User).filter(User.username == username).first()
    if user is None:
        raise HTTPException(status_code=404, detail="User not found")

    newchat = Chat(type=chat.type, title=chat.title, chats_list=chat.chats_list)
    user.chats.append(newchat)
    db.commit()
    return {
        "message": "Chat Saved"
    }
@app.post("/user/chat/update")
def populate_user_chat(chat: ChatToUpdate, db: Session = Depends(get_db)):
    """Replace the ``chats_list`` of the chat with id ``chat.chatid``.

    NOTE(review): this handler shares its function name with the
    ``/user/chat/save`` handler defined earlier in the file; both routes are
    registered, but the module-level name refers to this one — consider
    renaming.
    NOTE(review): the chat is looked up by id only; nothing here checks that
    it belongs to the token's user — confirm whether that is intended.

    Returns:
        ``{"message": "Chat Updated"}``.

    Raises:
        HTTPException: 403 for an invalid token (via ``verify_token``); 404
            when no chat has that id (previously an unhandled
            ``AttributeError``).
    """
    verify_token(token=chat.token)

    chatindb = db.query(Chat).filter(Chat.id == chat.chatid).first()
    if chatindb is None:
        raise HTTPException(status_code=404, detail="Chat not found")

    chatindb.chats_list = chat.chats_list
    db.commit()
    return {
        "message": "Chat Updated"
    }
@app.get("/user/chat/delete/{token}/{chatid}")
async def delete_chat_of_user(token: str, chatid: str, db: Session = Depends(get_db)):
    """Delete the chat with id ``chatid`` after validating ``token``.

    NOTE(review): the chat is looked up by id only; nothing here checks that
    it belongs to the token's user — confirm whether that is intended.

    Returns:
        ``{"message": "Chat Deleted"}``.

    Raises:
        HTTPException: 403 for an invalid token (via ``verify_token``); 404
            when no chat has that id (previously ``db.delete(None)`` raised).
    """
    verify_token(token=token)

    chatindb = db.query(Chat).filter(Chat.id == chatid).first()
    if chatindb is None:
        raise HTTPException(status_code=404, detail="Chat not found")

    db.delete(chatindb)
    db.commit()
    return {
        "message": "Chat Deleted"
    }
# Gets user id & name
@app.get("/user/{token}")
async def get_user_with_token(token: str, db: Session = Depends(get_db)):
    """Return the id, name, chats and notifications of the token's user.

    Returns:
        A dict with keys ``userid``, ``username``, ``chats`` and
        ``notifications``.

    Raises:
        HTTPException: 403 for an invalid token (via ``verify_token``); 404
            when the token's user does not exist (previously an unhandled
            ``AttributeError``).
    """
    username = verify_token(token=token).get("sub")

    user = db.query(User).filter(User.username == username).first()
    if user is None:
        raise HTTPException(status_code=404, detail="User not found")

    return {
        "userid": user.id,
        "username": user.username,
        "chats": user.chats,
        "notifications": user.notifications
    }
@app.get("/user/notification/delete/{token}/{notificationid}")
async def delete_chat_of_user(token: str, notificationid: str, db: Session = Depends(get_db)):
    """Delete the notification with id ``notificationid`` after validating ``token``.

    NOTE(review): this handler reuses the function name of the chat-delete
    route although it deletes notifications; both routes are registered, but
    the module-level name refers to this one — consider renaming.
    NOTE(review): the notification is looked up by id only; nothing here
    checks that it belongs to the token's user — confirm whether that is
    intended.

    Returns:
        ``{"message": "Notification Deleted"}``.

    Raises:
        HTTPException: 403 for an invalid token (via ``verify_token``); 404
            when no notification has that id (previously ``db.delete(None)``
            raised).
    """
    verify_token(token=token)

    notificationindb = db.query(Notification).filter(Notification.id == notificationid).first()
    if notificationindb is None:
        raise HTTPException(status_code=404, detail="Notification not found")

    db.delete(notificationindb)
    db.commit()
    return {
        "message": "Notification Deleted"
    }
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on 127.0.0.1:8000.
    # uvicorn is imported here so that importing this module does not require it.
    import uvicorn

    uvicorn.run(app, host="127.0.0.1", port=8000)