mirror of
https://github.com/MODSetter/SurfSense.git
synced 2025-09-02 10:39:13 +00:00
566 lines
No EOL
20 KiB
Python
566 lines
No EOL
20 KiB
Python
from __future__ import annotations
|
|
|
|
from langchain.chains import GraphCypherQAChain
|
|
from langchain_community.graphs import Neo4jGraph
|
|
from langchain_core.documents import Document
|
|
from langchain_openai import OpenAIEmbeddings
|
|
from langchain_community.vectorstores import Neo4jVector
|
|
from envs import ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, API_SECRET_KEY, SECRET_KEY
|
|
from prompts import CYPHER_QA_PROMPT, DATE_TODAY, DOC_DESCRIPTION_PROMPT, GRAPH_QUERY_GEN_PROMPT, NOTIFICATION_GENERATION_PROMT, SIMILARITY_SEARCH_PROMPT , CYPHER_GENERATION_PROMPT, DOCUMENT_METADATA_EXTRACTION_PROMT
|
|
from pydmodels import ChatToUpdate, DescriptionResponse, GraphDocs, NewUserChat, NewUserData, Notifications, PrecisionQuery, PrecisionResponse, UserQuery, DocMeta, RetrivedDocList, UserQueryResponse, UserQueryWithChatHistory, VectorSearchQuery
|
|
from langchain_experimental.text_splitter import SemanticChunker
|
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
|
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
|
|
|
|
#Our Imps
|
|
from LLMGraphTransformer import LLMGraphTransformer
|
|
from langchain_openai import ChatOpenAI
|
|
from DataExample import examples
|
|
# import nest_asyncio
|
|
# from langchain_community.chains.graph_qa.gremlin import GremlinQAChain
|
|
# from langchain_community.graphs import GremlinGraph
|
|
# from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
|
|
# from langchain_core.documents import Document
|
|
# from langchain_openai import AzureChatOpenAI
|
|
|
|
# Auth Libs
|
|
from fastapi import FastAPI, Depends, HTTPException, Request, status
|
|
from sqlalchemy.orm import Session
|
|
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
|
from jose import JWTError, jwt
|
|
from datetime import datetime, timedelta
|
|
from passlib.context import CryptContext
|
|
from models import Chat, Notification, User
|
|
from database import SessionLocal, engine
|
|
from pydantic import BaseModel
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from langchain_openai import AzureChatOpenAI
|
|
|
|
|
|
# Application instance; the route decorators and middleware below attach to it.
app = FastAPI()
|
|
|
|
# Dependency
|
|
def get_db():
    """Yield a database session and close it when the consumer is finished.

    Handlers in this file receive the session through ``Depends(get_db)``.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        # Runs whether the consumer completed normally or raised.
        session.close()
|
|
|
|
class UserCreate(BaseModel):
    """Request body accepted by the ``/add/`` and ``/register`` endpoints."""

    # Login name; matched against ``User.username``.
    username: str
    # Password exactly as submitted by the client (not yet hashed).
    password: str
    # Shared secret; ``/register`` compares it with ``API_SECRET_KEY``.
    apisecretkey: str
|
|
|
|
|
|
|
|
# General GraphCypherQAChain
|
|
def _collect_session_docs(graphdocs):
    """Merge vector-search hits that share a browsing session into ``DocMeta`` objects.

    Hits with the same ``BrowsingSessionId`` are collapsed into one entry whose
    ``VisitedWebPageContent`` is the concatenation of their page contents, in
    the order the hits were returned.
    """
    sessions = {}
    for hit in graphdocs:
        session_id = hit.metadata['BrowsingSessionId']
        if session_id in sessions:
            sessions[session_id]['VisitedWebPageContent'] += hit.page_content
        else:
            # Copy so the hit's own metadata dict is not mutated.
            merged = hit.metadata.copy()
            merged['VisitedWebPageContent'] = hit.page_content
            sessions[session_id] = merged

    return [
        DocMeta(
            BrowsingSessionId=entry['BrowsingSessionId'],
            VisitedWebPageURL=entry['VisitedWebPageURL'],
            VisitedWebPageVisitDurationInMilliseconds=entry['VisitedWebPageVisitDurationInMilliseconds'],
            VisitedWebPageTitle=entry['VisitedWebPageTitle'],
            VisitedWebPageReffererURL=entry['VisitedWebPageReffererURL'],
            VisitedWebPageDateWithTimeInISOString=entry['VisitedWebPageDateWithTimeInISOString'],
            VisitedWebPageContent=entry['VisitedWebPageContent'],
        )
        for entry in sessions.values()
    ]


@app.post("/")
def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
    """Answer a user query with the graph QA chain, falling back to similarity-search RAG.

    The query is first run through a hybrid vector search so related documents
    are always available. The Cypher QA chain is then tried; if it raises, or
    its answer contains "don't know", the answer is generated from the
    vector-search documents instead.

    NOTE(review): ``response_model`` is declared as a function parameter rather
    than on the decorator; presumably ``@app.post("/", response_model=...)`` was
    intended. Left as-is to keep the signature unchanged.

    Raises:
        HTTPException: 401 when ``data.apisecretkey`` does not match ``API_SECRET_KEY``.
    """
    if data.apisecretkey != API_SECRET_KEY:
        raise HTTPException(status_code=401, detail="Unauthorized")

    graph = Neo4jGraph(url=data.neourl, username=data.neouser, password=data.neopass)

    llm = ChatOpenAI(
        model="gpt-4o-mini",
        temperature=0,
        max_tokens=None,
        timeout=None,
        api_key=data.openaikey,
    )

    # Query Expansion (currently disabled)
    # searchchain = GRAPH_QUERY_GEN_PROMPT | llm
    # qry = searchchain.invoke({"question": data.query, "context": examples})
    query = data.query  # qry.content

    embeddings = OpenAIEmbeddings(
        model="text-embedding-ada-002",
        api_key=data.openaikey,
    )

    chain = GraphCypherQAChain.from_llm(
        graph=graph,
        cypher_prompt=CYPHER_GENERATION_PROMPT,
        cypher_llm=llm,
        verbose=True,
        validate_cypher=True,
        qa_prompt=CYPHER_QA_PROMPT,
        qa_llm=llm,
        return_intermediate_steps=True,
        top_k=5,
    )

    vector_index = Neo4jVector.from_existing_graph(
        embeddings,
        graph=graph,
        search_type="hybrid",
        node_label="Document",
        text_node_properties=["text"],
        embedding_node_property="embedding",
    )

    # Documents for the fallback path; replaced below if the graph chain succeeds.
    docstoreturn = _collect_session_docs(vector_index.similarity_search(data.query, k=15))

    try:
        responsegrp = chain.invoke({"query": query})

        if "don't know" in responsegrp["result"]:
            raise Exception("No response from graph")

        structured_llm = llm.with_structured_output(VectorSearchQuery)
        doc_extract_chain = DOCUMENT_METADATA_EXTRACTION_PROMT | structured_llm

        # Derive a new search query from the context the graph chain retrieved.
        newquery = doc_extract_chain.invoke(responsegrp["intermediate_steps"][1]["context"])

        graph_related_docs = _collect_session_docs(
            vector_index.similarity_search(newquery.searchquery, k=15)
        )

        return UserQueryResponse(relateddocs=graph_related_docs, response=responsegrp["result"])
    except Exception:
        # Fallback to Similarity Search RAG. ``Exception`` (not a bare except)
        # so KeyboardInterrupt/SystemExit still propagate.
        searchchain = SIMILARITY_SEARCH_PROMPT | llm

        response = searchchain.invoke({"question": data.query, "context": docstoreturn})

        return UserQueryResponse(relateddocs=docstoreturn, response=response.content)
|
|
|
|
|
|
# Precision Search
|
|
def _build_precision_query(data):
    """Return ``(cypher, params)`` for a precision search described by *data*.

    Every request-supplied value is passed as a Cypher parameter instead of
    being concatenated into the query text, so quotes in the input cannot
    alter the query.
    """
    clauses = [
        "d.VisitedWebPageDateWithTimeInISOString >= $date_from",
        "d.VisitedWebPageDateWithTimeInISOString <= $date_to",
    ]
    params = {"date_from": data.daterange[0], "date_to": data.daterange[1]}

    if data.timerange[0] >= data.timerange[1]:
        # An empty or inverted range means "no duration filter".
        clauses.append("d.VisitedWebPageVisitDurationInMilliseconds >= 0")
    else:
        # NOTE(review): assumes timerange holds numbers (the original inlined
        # them as unquoted literals) — confirm against PrecisionQuery.
        clauses.append("d.VisitedWebPageVisitDurationInMilliseconds >= $min_duration")
        clauses.append("d.VisitedWebPageVisitDurationInMilliseconds <= $max_duration")
        params["min_duration"] = data.timerange[0]
        params["max_duration"] = data.timerange[1]

    if data.webpageurl:
        clauses.append("d.VisitedWebPageURL CONTAINS $url_fragment")
        params["url_fragment"] = data.webpageurl.lower()

    if data.sessionid:
        # NOTE(review): the original compared against a quoted string literal;
        # presumably sessionid is a str — confirm against PrecisionQuery.
        clauses.append("d.BrowsingSessionId = $session_id")
        params["session_id"] = data.sessionid

    return "MATCH (d:Document) WHERE " + " AND ".join(clauses) + " RETURN d;", params


@app.post("/precision")
def get_precision_search_response(data: PrecisionQuery, response_model=PrecisionResponse):
    """Return documents matching the date, duration, URL and session filters.

    Matching ``Document`` nodes are grouped by ``BrowsingSessionId``; the
    ``text`` of nodes in the same session is concatenated in result order.

    Raises:
        HTTPException: 401 when ``data.apisecretkey`` does not match ``API_SECRET_KEY``.
    """
    if data.apisecretkey != API_SECRET_KEY:
        raise HTTPException(status_code=401, detail="Unauthorized")

    graph = Neo4jGraph(url=data.neourl, username=data.neouser, password=data.neopass)

    cypher, params = _build_precision_query(data)
    graphdocs = graph.query(cypher, params)

    sessions = {}
    for row in graphdocs:
        node = row['d']
        session_id = node['BrowsingSessionId']
        if session_id not in sessions:
            sessions[session_id] = node
        else:
            sessions[session_id]['text'] += node['text']

    docs = [
        DocMeta(
            BrowsingSessionId=entry['BrowsingSessionId'],
            VisitedWebPageURL=entry['VisitedWebPageURL'],
            VisitedWebPageVisitDurationInMilliseconds=entry['VisitedWebPageVisitDurationInMilliseconds'],
            VisitedWebPageTitle=entry['VisitedWebPageTitle'],
            VisitedWebPageReffererURL=entry['VisitedWebPageReffererURL'],
            VisitedWebPageDateWithTimeInISOString=entry['VisitedWebPageDateWithTimeInISOString'],
            VisitedWebPageContent=entry['text'],
        )
        for entry in sessions.values()
    ]

    return PrecisionResponse(documents=docs)
|
|
|
|
|
|
# Multi DOC Chat
|
|
@app.post("/chat/docs")
def doc_chat_with_history(data: UserQueryWithChatHistory, response_model=DescriptionResponse):
    """Continue a document chat: replay the stored history, then ask ``data.query``.

    Each entry of ``data.chat`` is converted to a message according to its
    ``type`` ('system', 'ai' or 'human'); entries of any other type are skipped.

    Raises:
        HTTPException: 401 when ``data.apisecretkey`` does not match ``API_SECRET_KEY``.
    """
    if data.apisecretkey != API_SECRET_KEY:
        raise HTTPException(status_code=401, detail="Unauthorized")

    llm = ChatOpenAI(
        model="gpt-4o-mini",
        temperature=0,
        max_tokens=None,
        timeout=None,
        api_key=data.openaikey,
    )

    system_preamble = (
        "You are an helpful assistant for question-answering tasks.\n"
        "Use the following pieces of retrieved context to answer the question.\n"
        "If you don't know the answer, just say that you don't know.\n"
        "Context:"
    )

    messages = []
    for entry in data.chat:
        if entry.type == 'system':
            # The system entry's content is the retrieved context.
            messages.append(SystemMessage(content=DATE_TODAY + system_preamble + str(entry.content)))
        elif entry.type == 'ai':
            messages.append(AIMessage(content=entry.content))
        elif entry.type == 'human':
            messages.append(HumanMessage(content=entry.content))

    # Template slot filled with the new question at invoke time.
    messages.append(("human", "{input}"))

    conversation = ChatPromptTemplate.from_messages(messages) | llm
    answer = conversation.invoke({"input": data.query})

    return DescriptionResponse(response=answer.content)
|
|
|
|
|
|
# DOC DESCRIPTION
|
|
@app.post("/kb/doc")
def get_doc_description(data: UserQuery, response_model=DescriptionResponse):
    """Run ``DOC_DESCRIPTION_PROMPT`` over the document text sent in ``data.query``.

    Raises:
        HTTPException: 401 when ``data.apisecretkey`` does not match ``API_SECRET_KEY``.
    """
    if data.apisecretkey != API_SECRET_KEY:
        raise HTTPException(status_code=401, detail="Unauthorized")

    describer = DOC_DESCRIPTION_PROMPT | ChatOpenAI(
        model="gpt-4o-mini",
        temperature=0,
        max_tokens=None,
        timeout=None,
        api_key=data.openaikey,
    )

    # The "query" field carries the document itself for this endpoint.
    answer = describer.invoke({"document": data.query})

    return DescriptionResponse(response=answer.content)
|
|
|
|
|
|
# SAVE DOCS TO GRAPH DB
|
|
@app.post("/kb/")
def populate_graph(apires: RetrivedDocList, db: Session = Depends(get_db)):
    """Store the submitted documents in the Neo4j graph and create notifications.

    The documents are split with ``SemanticChunker``, converted to graph
    documents by ``LLMGraphTransformer`` and written to the graph. An LLM then
    generates notifications from the raw documents, which are attached to the
    token's user and committed.

    Raises:
        HTTPException: 403 when the token is invalid, expired or has no
            subject; 404 when the token's user no longer exists.
    """
    try:
        payload = jwt.decode(apires.token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as exc:
        raise HTTPException(status_code=403, detail="Token is invalid or expired") from exc

    username = payload.get("sub")
    if username is None:
        raise HTTPException(status_code=403, detail="Token is invalid or expired")

    # Resolve the user before the expensive graph/LLM work so a stale token
    # fails fast instead of crashing on ``None`` after the graph was written.
    user = db.query(User).filter(User.username == username).first()
    if user is None:
        raise HTTPException(status_code=404, detail="User not found")

    print("STARTED")
    # print(apires)
    graph = Neo4jGraph(url=apires.neourl, username=apires.neouser, password=apires.neopass)

    llm = ChatOpenAI(
        model="gpt-4o-mini",
        temperature=0,
        max_tokens=None,
        timeout=None,
        api_key=apires.openaikey,
    )

    embeddings = OpenAIEmbeddings(
        model="text-embedding-ada-002",
        api_key=apires.openaikey,
    )

    llm_transformer = LLMGraphTransformer(llm=llm)

    raw_documents = [
        Document(page_content=doc.pageContent, metadata=doc.metadata)
        for doc in apires.documents
    ]

    text_splitter = SemanticChunker(embeddings=embeddings)

    documents = text_splitter.split_documents(raw_documents)
    graph_documents = llm_transformer.convert_to_graph_documents(documents)

    graph.add_graph_documents(
        graph_documents,
        baseEntityLabel=True,
        include_source=True,
    )

    structured_llm = llm.with_structured_output(Notifications)
    notifs_extraction_chain = NOTIFICATION_GENERATION_PROMT | structured_llm

    notifications = notifs_extraction_chain.invoke({"documents": raw_documents})

    user.notifications.extend(
        Notification(text=text) for text in notifications.notifications
    )

    db.commit()

    print("FINISHED")

    return {
        "success": "Graph Will be populated Shortly"
    }
|
|
|
|
|
|
# Function to populate db (comment out when running on server)
|
|
@app.post("/add/")
def add_user(user: UserCreate, db: Session = Depends(get_db)):
    """Create a user row directly (database seeding helper).

    The password is hashed before it is stored: ``authenticate_user`` checks
    logins with ``pwd_context.verify``, which cannot succeed against a
    plain-text value in ``hashed_password``. The API secret is checked the
    same way ``/register`` checks it, so the endpoint is not open to anyone.

    Raises:
        HTTPException: 401 when ``user.apisecretkey`` does not match ``API_SECRET_KEY``.
    """
    if user.apisecretkey != API_SECRET_KEY:
        raise HTTPException(status_code=401, detail="Unauthorized")

    db_user = User(
        username=user.username,
        hashed_password=pwd_context.hash(user.password),
        graph_config="",
        llm_config="",
    )
    db.add(db_user)
    db.commit()
    return "Success"
|
|
|
|
|
|
#AUTH CODE
|
|
# Bearer-token scheme used as the default dependency of ``verify_token``.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# Recommended for Local Setups
# origins = [
#     "http://localhost:3000", # Adjust the port if your frontend runs on a different one
#     "https://yourfrontenddomain.com",
# ]

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins (the ``origins`` list above is not used)
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

# Password hashing/verification context configured for bcrypt.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
|
|
|
def get_user_by_username(db: Session, username: str):
    """Return the first ``User`` whose username equals *username*, or ``None``."""
    matching_users = db.query(User).filter(User.username == username)
    return matching_users.first()
|
|
|
|
def create_user(db: Session, user: UserCreate):
    """Insert a new user with a hashed password and commit.

    Returns:
        The string ``"complete"``.
    """
    record = User(
        username=user.username,
        hashed_password=pwd_context.hash(user.password),
        graph_config="",
        llm_config="",
    )
    db.add(record)
    db.commit()
    return "complete"
|
|
|
|
@app.post("/register")
def register_user(user: UserCreate, db: Session = Depends(get_db)):
    """Register a new user after checking the API secret and username uniqueness.

    Raises:
        HTTPException: 401 when the API secret does not match; 400 when the
            username is already taken.
    """
    if user.apisecretkey != API_SECRET_KEY:
        raise HTTPException(status_code=401, detail="Unauthorized")

    if get_user_by_username(db, username=user.username):
        raise HTTPException(status_code=400, detail="Username already registered")

    # Drop the secret from the model before handing it on.
    del user.apisecretkey
    return create_user(db=db, user=user)
|
|
|
|
# Authenticate the user
|
|
def authenticate_user(username: str, password: str, db: Session):
    """Return the matching ``User`` when the credentials are valid, else ``False``."""
    account = db.query(User).filter(User.username == username).first()
    if account and pwd_context.verify(password, account.hashed_password):
        return account
    # Unknown username or wrong password.
    return False
|
|
|
|
# Create access token
|
|
def create_access_token(data: dict, expires_delta: timedelta | None = None):
    """Return a signed JWT containing *data* plus an ``exp`` claim.

    Args:
        data: Claims to embed; copied, so the caller's dict is not modified.
        expires_delta: Token lifetime; 15 minutes when omitted or falsy.
    """
    claims = data.copy()
    lifetime = expires_delta if expires_delta else timedelta(minutes=15)
    claims["exp"] = datetime.utcnow() + lifetime
    return jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)
|
|
|
|
@app.post("/token")
def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    """Exchange a username/password form for a bearer access token.

    Raises:
        HTTPException: 401 when the credentials are rejected by ``authenticate_user``.
    """
    account = authenticate_user(form_data.username, form_data.password, db)
    if account:
        lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
        token = create_access_token(
            data={"sub": account.username}, expires_delta=lifetime
        )
        return {"access_token": token, "token_type": "bearer"}

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Incorrect username or password",
        headers={"WWW-Authenticate": "Bearer"},
    )
|
|
|
|
def verify_token(token: str = Depends(oauth2_scheme)):
    """Decode *token* and return its payload.

    Raises:
        HTTPException: 403 when the token cannot be decoded or carries no
            ``sub`` claim.
    """
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise HTTPException(status_code=403, detail="Token is invalid or expired")

    if claims.get("sub") is None:
        raise HTTPException(status_code=403, detail="Token is invalid or expired")
    return claims
|
|
|
|
@app.get("/verify-token/{token}")
async def verify_user_token(token: str):
    """Report that the token in the URL path is valid.

    An invalid token never reaches the return statement: ``verify_token``
    raises ``HTTPException`` (403) for it.
    """
    verify_token(token=token)
    confirmation = {"message": "Token is valid"}
    return confirmation
|
|
|
|
|
|
|
|
|
|
@app.post("/user/chat/save")
def populate_user_chat(chat: NewUserChat, db: Session = Depends(get_db)):
    """Save a new chat for the user identified by ``chat.token``.

    NOTE: the ``/user/chat/update`` handler in this file reuses this function
    name; both routes are registered because the decorator runs at definition
    time.

    Raises:
        HTTPException: 403 when the token is invalid, expired or has no
            subject (raised by ``verify_token``); 404 when the token's user
            does not exist.
    """
    username = verify_token(token=chat.token)["sub"]

    user = db.query(User).filter(User.username == username).first()
    if user is None:
        # Previously this crashed with AttributeError on ``None.chats``.
        raise HTTPException(status_code=404, detail="User not found")

    newchat = Chat(type=chat.type, title=chat.title, chats_list=chat.chats_list)
    user.chats.append(newchat)
    db.commit()
    return {
        "message": "Chat Saved"
    }
|
|
|
|
@app.post("/user/chat/update")
def populate_user_chat(chat: ChatToUpdate, db: Session = Depends(get_db)):
    """Replace the message list of an existing chat owned by the token's user.

    NOTE: this handler reuses the name of the ``/user/chat/save`` handler
    earlier in the file; both routes are registered because the decorator
    runs at definition time.

    Raises:
        HTTPException: 403 when the token is invalid, expired or has no
            subject (raised by ``verify_token``); 404 when the chat does not
            exist or does not belong to the token's user.
    """
    username = verify_token(token=chat.token)["sub"]

    user = db.query(User).filter(User.username == username).first()
    chatindb = db.query(Chat).filter(Chat.id == chat.chatid).first()

    # Reject missing chats (previously an AttributeError) and chats that are
    # not in the caller's own list (previously any valid token could edit them).
    if user is None or chatindb is None or chatindb not in user.chats:
        raise HTTPException(status_code=404, detail="Chat not found")

    chatindb.chats_list = chat.chats_list
    db.commit()
    return {
        "message": "Chat Updated"
    }
|
|
|
|
@app.get("/user/chat/delete/{token}/{chatid}")
async def delete_chat_of_user(token: str, chatid: str, db: Session = Depends(get_db)):
    """Delete the chat *chatid* when it belongs to the token's user.

    Raises:
        HTTPException: 403 when the token is invalid, expired or has no
            subject (raised by ``verify_token``); 404 when the chat does not
            exist or does not belong to the token's user.
    """
    username = verify_token(token=token)["sub"]

    user = db.query(User).filter(User.username == username).first()
    chatindb = db.query(Chat).filter(Chat.id == chatid).first()

    # ``db.delete(None)`` raises, and without the ownership check any valid
    # token could delete another user's chat.
    if user is None or chatindb is None or chatindb not in user.chats:
        raise HTTPException(status_code=404, detail="Chat not found")

    db.delete(chatindb)
    db.commit()
    return {
        "message": "Chat Deleted"
    }
|
|
|
|
#Gets user id & name
|
|
@app.get("/user/{token}")
async def get_user_with_token(token: str, db: Session = Depends(get_db)):
    """Return the id, username, chats and notifications of the token's user.

    Raises:
        HTTPException: 403 when the token is invalid, expired or has no
            subject (raised by ``verify_token``); 404 when the token's user
            does not exist.
    """
    username = verify_token(token=token)["sub"]

    user = db.query(User).filter(User.username == username).first()
    if user is None:
        # Previously this crashed with AttributeError on ``None.id``.
        raise HTTPException(status_code=404, detail="User not found")

    return {
        "userid": user.id,
        "username": user.username,
        "chats": user.chats,
        "notifications": user.notifications
    }
|
|
|
|
|
|
@app.get("/user/notification/delete/{token}/{notificationid}")
async def delete_chat_of_user(token: str, notificationid: str, db: Session = Depends(get_db)):
    """Delete the notification *notificationid* when it belongs to the token's user.

    NOTE: this handler reuses the name of the chat-deletion handler earlier in
    the file; both routes are registered because the decorator runs at
    definition time.

    Raises:
        HTTPException: 403 when the token is invalid, expired or has no
            subject (raised by ``verify_token``); 404 when the notification
            does not exist or does not belong to the token's user.
    """
    username = verify_token(token=token)["sub"]

    user = db.query(User).filter(User.username == username).first()
    notificationindb = db.query(Notification).filter(Notification.id == notificationid).first()

    # ``db.delete(None)`` raises, and without the ownership check any valid
    # token could delete another user's notification.
    if (
        user is None
        or notificationindb is None
        or notificationindb not in user.notifications
    ):
        raise HTTPException(status_code=404, detail="Notification not found")

    db.delete(notificationindb)
    db.commit()
    return {
        "message": "Notification Deleted"
    }
|
|
|
|
|
|
|
|
|
|
|
|
# When executed as a script, serve the app with uvicorn on 127.0.0.1:8000.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="127.0.0.1", port=8000)