mirror of
https://github.com/MODSetter/SurfSense.git
synced 2025-09-02 10:39:13 +00:00
588 lines
22 KiB
Python
588 lines
22 KiB
Python
from __future__ import annotations
|
|
|
|
from langchain_core.prompts import ChatPromptTemplate
|
|
from langchain_core.documents import Document
|
|
from langchain_ollama import OllamaLLM
|
|
from langchain_openai import ChatOpenAI
|
|
from sqlalchemy import insert
|
|
from prompts import CONTEXT_ANSWER_PROMPT, DATE_TODAY, SUBQUERY_DECOMPOSITION_PROMT
|
|
from pydmodels import ChatToUpdate, DescriptionResponse, DocWithContent, DocumentsToDelete, NewUserChat, UserCreate, UserQuery, RetrivedDocList, UserQueryResponse, UserQueryWithChatHistory
|
|
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
|
|
from langchain_unstructured import UnstructuredLoader
|
|
|
|
#Heirerical Indices class
|
|
from HIndices import HIndices
|
|
|
|
from Utils.stringify import stringify
|
|
|
|
# Auth Libs
|
|
from fastapi import FastAPI, Depends, Form, HTTPException, status, UploadFile
|
|
from sqlalchemy.orm import Session
|
|
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
|
from jose import JWTError, jwt
|
|
from datetime import datetime, timedelta
|
|
from passlib.context import CryptContext
|
|
from models import Chat, Documents, SearchSpace, User
|
|
from database import SessionLocal
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
|
|
import os
|
|
from dotenv import load_dotenv
|
|
load_dotenv()
|
|
|
|
IS_LOCAL_SETUP = os.environ.get("IS_LOCAL_SETUP")
|
|
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.environ.get("ACCESS_TOKEN_EXPIRE_MINUTES"))
|
|
ALGORITHM = os.environ.get("ALGORITHM")
|
|
API_SECRET_KEY = os.environ.get("API_SECRET_KEY")
|
|
SECRET_KEY = os.environ.get("SECRET_KEY")
|
|
UNSTRUCTURED_API_KEY = os.environ.get("UNSTRUCTURED_API_KEY")
|
|
|
|
app = FastAPI()
|
|
|
|
# Dependency
|
|
def get_db():
|
|
db = SessionLocal()
|
|
try:
|
|
yield db
|
|
finally:
|
|
db.close()
|
|
|
|
|
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
|
|
|
@app.post("/uploadfiles/")
|
|
async def upload_files(files: list[UploadFile], token: str = Depends(oauth2_scheme), search_space: str = Form(...), api_key: str = Form(...), db: Session = Depends(get_db)):
|
|
try:
|
|
# Decode and verify the token
|
|
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
|
username: str = payload.get("sub")
|
|
if username is None:
|
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
|
|
|
docs = []
|
|
|
|
for file in files:
|
|
|
|
loader = UnstructuredLoader(
|
|
file=file.file,
|
|
api_key=UNSTRUCTURED_API_KEY,
|
|
partition_via_api=True,
|
|
chunking_strategy="basic",
|
|
max_characters=90000,
|
|
include_orig_elements=False,
|
|
)
|
|
|
|
filedocs = loader.load()
|
|
|
|
fileswithfilename = []
|
|
for f in filedocs:
|
|
temp = f
|
|
temp.metadata['filename'] = file.filename
|
|
fileswithfilename.append(temp)
|
|
|
|
docs.extend(fileswithfilename)
|
|
|
|
# Initialize containers for documents and entries
|
|
DocumentPgEntry = []
|
|
raw_documents = []
|
|
|
|
# Fetch the search space from the database or create it if it doesn't exist
|
|
searchspace = db.query(SearchSpace).filter(SearchSpace.search_space == search_space.upper()).first()
|
|
if not searchspace:
|
|
stmt = insert(SearchSpace).values(search_space=search_space.upper())
|
|
db.execute(stmt)
|
|
db.commit()
|
|
|
|
# Process each document in the retrieved document list
|
|
for doc in docs:
|
|
|
|
raw_documents.append(Document(page_content=doc.page_content, metadata=doc.metadata))
|
|
|
|
# Stringify the document metadata
|
|
pgdocmeta = stringify(doc.metadata)
|
|
|
|
DocumentPgEntry.append(Documents(
|
|
file_type=doc.metadata['filetype'],
|
|
title=doc.metadata['filename'],
|
|
search_space=db.query(SearchSpace).filter(SearchSpace.search_space == search_space.upper()).first(),
|
|
document_metadata=pgdocmeta,
|
|
page_content=doc.page_content
|
|
))
|
|
|
|
|
|
# Save documents in PostgreSQL
|
|
user = db.query(User).filter(User.username == username).first()
|
|
user.documents.extend(DocumentPgEntry)
|
|
db.commit()
|
|
|
|
# Create hierarchical indices
|
|
if IS_LOCAL_SETUP == 'true':
|
|
index = HIndices(username=username)
|
|
else:
|
|
index = HIndices(username=username, api_key=api_key)
|
|
|
|
# Save indices in vector stores
|
|
index.encode_docs_hierarchical(documents=raw_documents, files_type='OTHER', search_space=search_space.upper(), db=db)
|
|
|
|
print("FINISHED")
|
|
|
|
return {
|
|
"message": "Files Uploaded Successfully"
|
|
}
|
|
|
|
except JWTError:
|
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
|
|
|
@app.post("/chat/")
|
|
def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
|
|
try:
|
|
payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
|
|
username: str = payload.get("sub")
|
|
if username is None:
|
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
|
|
|
query = data.query
|
|
search_space = data.search_space
|
|
|
|
if(IS_LOCAL_SETUP == 'true'):
|
|
sub_query_llm = OllamaLLM(model="mistral-nemo",temperature=0)
|
|
qa_llm = OllamaLLM(model="mistral-nemo",temperature=0)
|
|
else:
|
|
sub_query_llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", api_key=data.openaikey)
|
|
qa_llm = ChatOpenAI(temperature=0.5, model_name="gpt-4o-mini", api_key=data.openaikey)
|
|
|
|
|
|
|
|
# Create an LLMChain for sub-query decomposition
|
|
subquery_decomposer_chain = SUBQUERY_DECOMPOSITION_PROMT | sub_query_llm
|
|
|
|
#Experimental
|
|
def decompose_query(original_query: str):
|
|
"""
|
|
Decompose the original query into simpler sub-queries.
|
|
|
|
Args:
|
|
original_query (str): The original complex query
|
|
|
|
Returns:
|
|
List[str]: A list of simpler sub-queries
|
|
"""
|
|
if(IS_LOCAL_SETUP == 'true'):
|
|
response = subquery_decomposer_chain.invoke(original_query)
|
|
else:
|
|
response = subquery_decomposer_chain.invoke(original_query).content
|
|
|
|
sub_queries = [q.strip() for q in response.split('\n') if q.strip() and not q.strip().startswith('Sub-queries:')]
|
|
return sub_queries
|
|
|
|
|
|
# Create Heirarical Indecices
|
|
if(IS_LOCAL_SETUP == 'true'):
|
|
index = HIndices(username=username)
|
|
else:
|
|
index = HIndices(username=username,api_key=data.openaikey)
|
|
|
|
|
|
|
|
# For Those Who Want HyDe Questions
|
|
# sub_queries = decompose_query(query)
|
|
|
|
sub_queries = []
|
|
sub_queries.append(query)
|
|
|
|
duplicate_related_summary_docs = []
|
|
context_to_answer = ""
|
|
for sub_query in sub_queries:
|
|
localreturn = index.local_search(query=sub_query, search_space=search_space)
|
|
globalreturn, related_summary_docs = index.global_search(query=sub_query, search_space=search_space)
|
|
|
|
context_to_answer += localreturn + "\n\n" + globalreturn
|
|
|
|
duplicate_related_summary_docs.extend(related_summary_docs)
|
|
|
|
|
|
combined_docs_seen_metadata = set()
|
|
combined_docs_unique_documents = []
|
|
|
|
for doc in duplicate_related_summary_docs:
|
|
# Convert metadata to a tuple of its items (this allows it to be added to a set)
|
|
doc.metadata['relevance_score'] = 0.0
|
|
metadata_tuple = tuple(sorted(doc.metadata.items()))
|
|
if metadata_tuple not in combined_docs_seen_metadata:
|
|
combined_docs_seen_metadata.add(metadata_tuple)
|
|
combined_docs_unique_documents.append(doc)
|
|
|
|
returnDocs = []
|
|
for doc in combined_docs_unique_documents:
|
|
entry = DocWithContent(
|
|
DocMetadata=stringify(doc.metadata),
|
|
Content=doc.page_content
|
|
)
|
|
|
|
returnDocs.append(entry)
|
|
|
|
|
|
ans_chain = CONTEXT_ANSWER_PROMPT | qa_llm
|
|
|
|
finalans = ans_chain.invoke({"query": query, "context": context_to_answer})
|
|
|
|
|
|
if(IS_LOCAL_SETUP == 'true'):
|
|
return UserQueryResponse(response=finalans, relateddocs=returnDocs)
|
|
else:
|
|
return UserQueryResponse(response=finalans.content, relateddocs=returnDocs)
|
|
|
|
|
|
except JWTError:
|
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
|
|
|
# SAVE DOCS
|
|
@app.post("/save/")
|
|
def save_data(apires: RetrivedDocList, db: Session = Depends(get_db)):
|
|
"""
|
|
Save retrieved documents to the database and encode them for hierarchical indexing.
|
|
|
|
This endpoint processes the provided documents, saves related information
|
|
in the PostgreSQL database, and updates hierarchical indices for the user.
|
|
Args:
|
|
apires (RetrivedDocList): The list of retrieved documents with metadata.
|
|
db (Session, optional): Dependency-injected session for database operations.
|
|
|
|
Returns:
|
|
dict: A message indicating the success of the operation.
|
|
|
|
Raises:
|
|
HTTPException: If the token is invalid or expired.
|
|
"""
|
|
try:
|
|
# Decode token and extract username
|
|
payload = jwt.decode(apires.token, SECRET_KEY, algorithms=[ALGORITHM])
|
|
username: str = payload.get("sub")
|
|
if username is None:
|
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
|
|
|
print("STARTED")
|
|
|
|
# Initialize containers for documents and entries
|
|
DocumentPgEntry = []
|
|
raw_documents = []
|
|
|
|
# Fetch the search space from the database
|
|
searchspace = db.query(SearchSpace).filter(SearchSpace.search_space == apires.search_space.upper()).first()
|
|
if not searchspace:
|
|
stmt = insert(SearchSpace).values(search_space=apires.search_space.upper())
|
|
db.execute(stmt)
|
|
db.commit()
|
|
|
|
# Process each document in the retrieved document list
|
|
for doc in apires.documents:
|
|
# Construct document content
|
|
content = (
|
|
f"USER BROWSING SESSION EVENT: \n"
|
|
f"=======================================METADATA==================================== \n"
|
|
f"User Browsing Session ID : {doc.metadata.BrowsingSessionId} \n"
|
|
f"User Visited website with url : {doc.metadata.VisitedWebPageURL} \n"
|
|
f"This visited website url had title : {doc.metadata.VisitedWebPageTitle} \n"
|
|
f"User Visited this website from referring url : {doc.metadata.VisitedWebPageReffererURL} \n"
|
|
f"User Visited this website url at this Date and Time : {doc.metadata.VisitedWebPageDateWithTimeInISOString} \n"
|
|
f"User Visited this website for : {str(doc.metadata.VisitedWebPageVisitDurationInMilliseconds)} milliseconds. \n"
|
|
f"===================================================================================== \n"
|
|
f"Webpage Content of the visited webpage url in markdown format : \n\n{doc.pageContent}\n\n"
|
|
f"===================================================================================== \n"
|
|
)
|
|
raw_documents.append(Document(page_content=content, metadata=doc.metadata.__dict__))
|
|
|
|
# Stringify the document metadata
|
|
pgdocmeta = stringify(doc.metadata.__dict__)
|
|
|
|
DocumentPgEntry.append(Documents(
|
|
file_type='WEBPAGE',
|
|
title=doc.metadata.VisitedWebPageTitle,
|
|
search_space=searchspace,
|
|
document_metadata=pgdocmeta,
|
|
page_content=content
|
|
))
|
|
|
|
# Save documents in PostgreSQL
|
|
user = db.query(User).filter(User.username == username).first()
|
|
user.documents.extend(DocumentPgEntry)
|
|
db.commit()
|
|
|
|
# Create hierarchical indices
|
|
if IS_LOCAL_SETUP == 'true':
|
|
index = HIndices(username=username)
|
|
else:
|
|
index = HIndices(username=username, api_key=apires.openaikey)
|
|
|
|
# Save indices in vector stores
|
|
index.encode_docs_hierarchical(documents=raw_documents, files_type='WEBPAGE', search_space=apires.search_space.upper(), db=db)
|
|
|
|
print("FINISHED")
|
|
|
|
return {
|
|
"success": "Graph Will be populated Shortly"
|
|
}
|
|
|
|
except JWTError:
|
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
|
|
|
# Multi DOC Chat
|
|
@app.post("/chat/docs")
|
|
def doc_chat_with_history(data: UserQueryWithChatHistory, response_model=DescriptionResponse):
|
|
try:
|
|
payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
|
|
username: str = payload.get("sub")
|
|
if username is None:
|
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
|
|
|
if(IS_LOCAL_SETUP == 'true'):
|
|
llm = OllamaLLM(model="mistral-nemo",temperature=0)
|
|
else:
|
|
llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", api_key=data.openaikey)
|
|
|
|
chatHistory = []
|
|
|
|
for chat in data.chat:
|
|
if(chat.type == 'system'):
|
|
chatHistory.append(SystemMessage(content=DATE_TODAY + """You are an helpful assistant for question-answering tasks.
|
|
Use the following pieces of retrieved context to answer the question.
|
|
If you don't know the answer, just say that you don't know.
|
|
Context:""" + str(chat.content)))
|
|
|
|
if(chat.type == 'ai'):
|
|
chatHistory.append(AIMessage(content=chat.content))
|
|
|
|
if(chat.type == 'human'):
|
|
chatHistory.append(HumanMessage(content=chat.content))
|
|
|
|
chatHistory.append(("human", "{input}"));
|
|
|
|
|
|
qa_prompt = ChatPromptTemplate.from_messages(chatHistory)
|
|
|
|
descriptionchain = qa_prompt | llm
|
|
|
|
response = descriptionchain.invoke({"input": data.query})
|
|
|
|
if(IS_LOCAL_SETUP == 'true'):
|
|
return DescriptionResponse(response=response)
|
|
else:
|
|
return DescriptionResponse(response=response.content)
|
|
|
|
except JWTError:
|
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
|
|
|
|
|
# Multi DOC Chat
|
|
|
|
@app.post("/delete/docs")
|
|
def delete_all_related_data(data: DocumentsToDelete, db: Session = Depends(get_db)):
|
|
try:
|
|
payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
|
|
username: str = payload.get("sub")
|
|
if username is None:
|
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
|
|
|
if(IS_LOCAL_SETUP == 'true'):
|
|
index = HIndices(username=username)
|
|
else:
|
|
index = HIndices(username=username,api_key=data.openaikey)
|
|
|
|
message = index.delete_vector_stores(summary_ids_to_delete=data.ids_to_delete,db=db )
|
|
|
|
return {
|
|
"message": message
|
|
}
|
|
|
|
except JWTError:
|
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
|
|
|
|
|
|
|
|
|
# Manual Origins
|
|
# origins = [
|
|
# "http://localhost:3000", # Adjust the port if your frontend runs on a different one
|
|
# "https://yourfrontenddomain.com",
|
|
# ]
|
|
|
|
app.add_middleware(
|
|
CORSMiddleware,
|
|
allow_origins=["*"], # Allows all origins from the list
|
|
allow_credentials=True,
|
|
allow_methods=["*"], # Allows all methods
|
|
allow_headers=["*"], # Allows all headers
|
|
)
|
|
|
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
|
|
|
def get_user_by_username(db: Session, username: str):
|
|
return db.query(User).filter(User.username == username).first()
|
|
|
|
def create_user(db: Session, user: UserCreate):
|
|
hashed_password = pwd_context.hash(user.password)
|
|
db_user = User(username=user.username, hashed_password=hashed_password)
|
|
db.add(db_user)
|
|
db.commit()
|
|
return "complete"
|
|
|
|
@app.post("/register")
|
|
def register_user(user: UserCreate, db: Session = Depends(get_db)):
|
|
if(user.apisecretkey != API_SECRET_KEY):
|
|
raise HTTPException(status_code=401, detail="Unauthorized")
|
|
|
|
db_user = get_user_by_username(db, username=user.username)
|
|
if db_user:
|
|
raise HTTPException(status_code=400, detail="Username already registered")
|
|
|
|
del user.apisecretkey
|
|
return create_user(db=db, user=user)
|
|
|
|
# Authenticate the user
|
|
def authenticate_user(username: str, password: str, db: Session):
|
|
user = db.query(User).filter(User.username == username).first()
|
|
if not user:
|
|
return False
|
|
if not pwd_context.verify(password, user.hashed_password):
|
|
return False
|
|
return user
|
|
|
|
# Create access token
|
|
def create_access_token(data: dict, expires_delta: timedelta | None = None):
|
|
to_encode = data.copy()
|
|
if expires_delta:
|
|
expire = datetime.utcnow() + expires_delta
|
|
else:
|
|
expire = datetime.utcnow() + timedelta(minutes=15)
|
|
to_encode.update({"exp": expire})
|
|
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
|
return encoded_jwt
|
|
|
|
@app.post("/token")
|
|
def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
|
|
user = authenticate_user(form_data.username, form_data.password, db)
|
|
if not user:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
detail="Incorrect username or password",
|
|
headers={"WWW-Authenticate": "Bearer"},
|
|
)
|
|
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
|
access_token = create_access_token(
|
|
data={"sub": user.username}, expires_delta=access_token_expires
|
|
)
|
|
return {"access_token": access_token, "token_type": "bearer"}
|
|
|
|
def verify_token(token: str = Depends(oauth2_scheme)):
|
|
try:
|
|
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
|
username: str = payload.get("sub")
|
|
if username is None:
|
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
|
return payload
|
|
except JWTError:
|
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
|
|
|
@app.get("/verify-token/{token}")
|
|
async def verify_user_token(token: str):
|
|
verify_token(token=token)
|
|
return {"message": "Token is valid"}
|
|
|
|
@app.post("/user/chat/save")
|
|
def populate_user_chat(chat: NewUserChat, db: Session = Depends(get_db)):
|
|
try:
|
|
payload = jwt.decode(chat.token, SECRET_KEY, algorithms=[ALGORITHM])
|
|
username: str = payload.get("sub")
|
|
if username is None:
|
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
|
|
|
user = db.query(User).filter(User.username == username).first()
|
|
newchat = Chat(type=chat.type, title=chat.title, chats_list=chat.chats_list)
|
|
|
|
user.chats.append(newchat)
|
|
db.commit()
|
|
return {
|
|
"message": "Chat Saved"
|
|
}
|
|
except JWTError:
|
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
|
|
|
@app.post("/user/chat/update")
|
|
def populate_user_chat(chat: ChatToUpdate, db: Session = Depends(get_db)):
|
|
try:
|
|
payload = jwt.decode(chat.token, SECRET_KEY, algorithms=[ALGORITHM])
|
|
username: str = payload.get("sub")
|
|
if username is None:
|
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
|
|
|
chatindb = db.query(Chat).filter(Chat.id == chat.chatid).first()
|
|
chatindb.chats_list = chat.chats_list
|
|
|
|
db.commit()
|
|
return {
|
|
"message": "Chat Updated"
|
|
}
|
|
except JWTError:
|
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
|
|
|
@app.get("/user/chat/delete/{token}/{chatid}")
|
|
async def delete_chat_of_user(token: str, chatid: str, db: Session = Depends(get_db)):
|
|
try:
|
|
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
|
username: str = payload.get("sub")
|
|
if username is None:
|
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
|
|
|
chatindb = db.query(Chat).filter(Chat.id == chatid).first()
|
|
db.delete(chatindb)
|
|
db.commit()
|
|
return {
|
|
"message": "Chat Deleted"
|
|
}
|
|
except JWTError:
|
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
|
|
|
#Gets user id & name
|
|
@app.get("/user/{token}")
|
|
async def get_user_with_token(token: str, db: Session = Depends(get_db)):
|
|
try:
|
|
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
|
username: str = payload.get("sub")
|
|
if username is None:
|
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
|
|
|
user = db.query(User).filter(User.username == username).first()
|
|
return {
|
|
"userid": user.id,
|
|
"username": user.username,
|
|
"chats": user.chats,
|
|
"documents": user.documents
|
|
}
|
|
except JWTError:
|
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
|
|
|
@app.get("/searchspaces/{token}")
|
|
async def get_user_with_token(token: str, db: Session = Depends(get_db)):
|
|
try:
|
|
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
|
username: str = payload.get("sub")
|
|
if username is None:
|
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
|
|
|
search_spaces = db.query(SearchSpace).all()
|
|
return {
|
|
"search_spaces": search_spaces
|
|
}
|
|
except JWTError:
|
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
import uvicorn
|
|
|
|
uvicorn.run(app, host="127.0.0.1", port=8000)
|
|
|
|
|
|
|