diff --git a/SurfSense-Frontend b/SurfSense-Frontend index aeb87be..4d0d141 160000 --- a/SurfSense-Frontend +++ b/SurfSense-Frontend @@ -1 +1 @@ -Subproject commit aeb87beee51bbf040307157d6c54b3a3a82ef620 +Subproject commit 4d0d1414c4774704619834ec8e9280940976f500 diff --git a/backend/.env.example b/backend/.env.example index 948ba63..def5a22 100644 --- a/backend/.env.example +++ b/backend/.env.example @@ -1,7 +1,7 @@ #true if you wana run local setup with Ollama llama3.1 IS_LOCAL_SETUP = 'false' -#POSTGRES DB TO TRACK USERS +#POSTGRES DB TO TRACK USERS | replace localhost with db for docker setups eg "postgresql+psycopg2://postgres:postgres@db:5432/surfsense" POSTGRES_DATABASE_URL = "postgresql+psycopg2://postgres:postgres@localhost:5432/surfsense" # API KEY TO PREVENT USER REGISTRATION SPAM diff --git a/backend/HIndices.py b/backend/HIndices.py index 070d9b8..a8fe5fa 100644 --- a/backend/HIndices.py +++ b/backend/HIndices.py @@ -23,16 +23,26 @@ IS_LOCAL_SETUP = os.environ.get("IS_LOCAL_SETUP") # Dependency def get_db(): + """ + Dependency to get a database session. + + Yields: + Session: A SQLAlchemy database session. + """ db = SessionLocal() try: yield db finally: db.close() - class HIndices: def __init__(self, username, api_key='local'): """ + Initialize the HIndices object. + + Args: + username (str): The username for the indices. + api_key (str, optional): API key for non-local setups. Defaults to 'local'. """ self.username = username if(IS_LOCAL_SETUP == 'true'): diff --git a/backend/server.py b/backend/server.py index b4acb96..b6b0ada 100644 --- a/backend/server.py +++ b/backend/server.py @@ -8,7 +8,7 @@ from prompts import CONTEXT_ANSWER_PROMPT, DATE_TODAY, SUBQUERY_DECOMPOSITION_PR from pydmodels import ChatToUpdate, DescriptionResponse, DocWithContent, DocumentsToDelete, NewUserChat, UserCreate, UserQuery, RetrivedDocList, UserQueryResponse, UserQueryWithChatHistory from langchain_core.messages import HumanMessage, SystemMessage, AIMessage -#Heirerical Indices class +# Hierarchical Indices class from HIndices import HIndices from Utils.stringify import stringify @@ -24,11 +24,11 @@ from models import Chat, Documents, SearchSpace, User from database import SessionLocal from fastapi.middleware.cors import CORSMiddleware - import os from dotenv import load_dotenv load_dotenv() +# Environment variables IS_LOCAL_SETUP = os.environ.get("IS_LOCAL_SETUP") ACCESS_TOKEN_EXPIRE_MINUTES = int(os.environ.get("ACCESS_TOKEN_EXPIRE_MINUTES")) ALGORITHM = os.environ.get("ALGORITHM") @@ -37,17 +37,33 @@ SECRET_KEY = os.environ.get("SECRET_KEY") app = FastAPI() -# Dependency def get_db(): + """ + Dependency to get a database session. + + Yields: + Session: A SQLAlchemy database session. + """ db = SessionLocal() try: yield db finally: db.close() - @app.post("/chat/") def get_user_query_response(data: UserQuery, response_model=UserQueryResponse): + """ + Process a user query and return a response. + + Args: + data (UserQuery): The user query data. + + Returns: + UserQueryResponse: The response to the user query. + + Raises: + HTTPException: If the token is invalid or expired. + """ try: payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM]) username: str = payload.get("sub") @@ -57,15 +73,14 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse): query = data.query search_space = data.search_space - if(IS_LOCAL_SETUP == 'true'): - sub_query_llm = OllamaLLM(model="mistral-nemo",temperature=0) - qa_llm = OllamaLLM(model="mistral-nemo",temperature=0) + # Initialize LLM based on setup + if IS_LOCAL_SETUP == 'true': + sub_query_llm = OllamaLLM(model="mistral-nemo", temperature=0) + qa_llm = OllamaLLM(model="mistral-nemo", temperature=0) else: sub_query_llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", api_key=data.openaikey) qa_llm = ChatOpenAI(temperature=0.5, model_name="gpt-4o-mini", api_key=data.openaikey) - - # Create an LLMChain for sub-query decomposition subquery_decomposer_chain = SUBQUERY_DECOMPOSITION_PROMT | sub_query_llm @@ -74,12 +89,12 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse): Decompose the original query into simpler sub-queries. Args: - original_query (str): The original complex query + original_query (str): The original complex query Returns: - List[str]: A list of simpler sub-queries + List[str]: A list of simpler sub-queries """ - if(IS_LOCAL_SETUP == 'true'): + if IS_LOCAL_SETUP == 'true': response = subquery_decomposer_chain.invoke(original_query) else: response = subquery_decomposer_chain.invoke(original_query).content @@ -87,20 +102,16 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse): sub_queries = [q.strip() for q in response.split('\n') if q.strip() and not q.strip().startswith('Sub-queries:')] return sub_queries - - # Create Heirarical Indecices - if(IS_LOCAL_SETUP == 'true'): + # Create Hierarchical Indices + if IS_LOCAL_SETUP == 'true': index = HIndices(username=username) else: - index = HIndices(username=username,api_key=data.openaikey) - - + index = HIndices(username=username, api_key=data.openaikey) - # For Those Who Want HyDe Questions + # For those who want HyDE questions # sub_queries = decompose_query(query) - sub_queries = [] - sub_queries.append(query) + sub_queries = [query] duplicate_related_summary_docs = [] context_to_answer = "" @@ -112,12 +123,11 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse): duplicate_related_summary_docs.extend(related_summary_docs) - + # Remove duplicate documents combined_docs_seen_metadata = set() combined_docs_unique_documents = [] for doc in duplicate_related_summary_docs: - # Convert metadata to a tuple of its items (this allows it to be added to a set) doc.metadata['relevance_score'] = 0.0 metadata_tuple = tuple(sorted(doc.metadata.items())) if metadata_tuple not in combined_docs_seen_metadata: @@ -134,29 +144,37 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse): VisitedWebPageDateWithTimeInISOString=doc.metadata['VisitedWebPageDateWithTimeInISOString'], VisitedWebPageReffererURL=doc.metadata['VisitedWebPageReffererURL'], VisitedWebPageVisitDurationInMilliseconds=doc.metadata['VisitedWebPageVisitDurationInMilliseconds'], - ) + ) returnDocs.append(entry) - + # Generate final answer ans_chain = CONTEXT_ANSWER_PROMPT | qa_llm - finalans = ans_chain.invoke({"query": query, "context": context_to_answer}) - - if(IS_LOCAL_SETUP == 'true'): + if IS_LOCAL_SETUP == 'true': return UserQueryResponse(response=finalans, relateddocs=returnDocs) else: return UserQueryResponse(response=finalans.content, relateddocs=returnDocs) - except JWTError: raise HTTPException(status_code=403, detail="Token is invalid or expired") -# SAVE DOCS @app.post("/save/") def save_data(apires: RetrivedDocList, db: Session = Depends(get_db)): + """ + Save retrieved documents to the database and vector stores. + Args: + apires (RetrivedDocList): The list of documents to save. + db (Session): The database session. + + Returns: + dict: A message indicating the success of the operation. + + Raises: + HTTPException: If the token is invalid or expired. + """ try: payload = jwt.decode(apires.token, SECRET_KEY, algorithms=[ALGORITHM]) username: str = payload.get("sub") @@ -181,31 +199,29 @@ def save_data(apires: RetrivedDocList, db: Session = Depends(get_db)): content += f"===================================================================================== \n" content += f"Webpage Content of the visited webpage url in markdown format : \n\n {doc.pageContent} \n\n" content += f"===================================================================================== \n" - raw_documents.append(Document(page_content=content,metadata=doc.metadata.__dict__)) + raw_documents.append(Document(page_content=content, metadata=doc.metadata.__dict__)) pgdocmeta = stringify(doc.metadata.__dict__) - if(searchspace): - DocumentPgEntry.append(Documents(file_type='WEBPAGE',title=doc.metadata.VisitedWebPageTitle,search_space=searchspace, document_metadata=pgdocmeta, page_content=content)) + if searchspace: + DocumentPgEntry.append(Documents(file_type='WEBPAGE', title=doc.metadata.VisitedWebPageTitle, search_space=searchspace, document_metadata=pgdocmeta, page_content=content)) else: - DocumentPgEntry.append(Documents(file_type='WEBPAGE',title=doc.metadata.VisitedWebPageTitle,search_space=SearchSpace(search_space=apires.search_space.upper()), document_metadata=pgdocmeta, page_content=content)) + DocumentPgEntry.append(Documents(file_type='WEBPAGE', title=doc.metadata.VisitedWebPageTitle, search_space=SearchSpace(search_space=apires.search_space.upper()), document_metadata=pgdocmeta, page_content=content)) - - #Save docs in PG + # Save docs in PG user = db.query(User).filter(User.username == username).first() user.documents.extend(DocumentPgEntry) db.commit() - - # Create Heirarical Indecices - if(IS_LOCAL_SETUP == 'true'): + # Create Hierarchical Indices + if IS_LOCAL_SETUP == 'true': index = HIndices(username=username) else: - index = HIndices(username=username,api_key=apires.openaikey) + index = HIndices(username=username, api_key=apires.openaikey) - #Save Indices in vector Stores - index.encode_docs_hierarchical(documents=raw_documents, files_type='WEBPAGE',search_space=apires.search_space.upper(), db=db) + # Save Indices in vector stores + index.encode_docs_hierarchical(documents=raw_documents, files_type='WEBPAGE', search_space=apires.search_space.upper(), db=db) print("FINISHED") @@ -216,45 +232,55 @@ def save_data(apires: RetrivedDocList, db: Session = Depends(get_db)): except JWTError: raise HTTPException(status_code=403, detail="Token is invalid or expired") -# Multi DOC Chat @app.post("/chat/docs") def doc_chat_with_history(data: UserQueryWithChatHistory, response_model=DescriptionResponse): + """ + Process a user query with chat history and return a response. + + Args: + data (UserQueryWithChatHistory): The user query data with chat history. + + Returns: + DescriptionResponse: The response to the user query. + + Raises: + HTTPException: If the token is invalid or expired. + """ try: payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM]) username: str = payload.get("sub") if username is None: raise HTTPException(status_code=403, detail="Token is invalid or expired") - if(IS_LOCAL_SETUP == 'true'): - llm = OllamaLLM(model="mistral-nemo",temperature=0) + if IS_LOCAL_SETUP == 'true': + llm = OllamaLLM(model="mistral-nemo", temperature=0) else: llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", api_key=data.openaikey) chatHistory = [] for chat in data.chat: - if(chat.type == 'system'): - chatHistory.append(SystemMessage(content=DATE_TODAY + """You are an helpful assistant for question-answering tasks. + if chat.type == 'system': + chatHistory.append(SystemMessage(content=DATE_TODAY + """You are a helpful assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Context:""" + str(chat.content))) - if(chat.type == 'ai'): + if chat.type == 'ai': chatHistory.append(AIMessage(content=chat.content)) - if(chat.type == 'human'): + if chat.type == 'human': chatHistory.append(HumanMessage(content=chat.content)) - chatHistory.append(("human", "{input}")); + chatHistory.append(("human", "{input}")) - qa_prompt = ChatPromptTemplate.from_messages(chatHistory) descriptionchain = qa_prompt | llm response = descriptionchain.invoke({"input": data.query}) - if(IS_LOCAL_SETUP == 'true'): + if IS_LOCAL_SETUP == 'true': return DescriptionResponse(response=response) else: return DescriptionResponse(response=response.content) @@ -262,23 +288,33 @@ def doc_chat_with_history(data: UserQueryWithChatHistory, response_model=Descrip except JWTError: raise HTTPException(status_code=403, detail="Token is invalid or expired") - - # Multi DOC Chat - @app.post("/delete/docs") def delete_all_related_data(data: DocumentsToDelete, db: Session = Depends(get_db)): + """ + Delete documents and related data. + + Args: + data (DocumentsToDelete): The data containing documents to delete. + db (Session): The database session. + + Returns: + dict: A message indicating the result of the deletion. + + Raises: + HTTPException: If the token is invalid or expired. + """ try: payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM]) username: str = payload.get("sub") if username is None: raise HTTPException(status_code=403, detail="Token is invalid or expired") - if(IS_LOCAL_SETUP == 'true'): + if IS_LOCAL_SETUP == 'true': index = HIndices(username=username) else: - index = HIndices(username=username,api_key=data.openaikey) + index = HIndices(username=username, api_key=data.openaikey) - message = index.delete_vector_stores(summary_ids_to_delete=data.ids_to_delete,db=db ) + message = index.delete_vector_stores(summary_ids_to_delete=data.ids_to_delete, db=db) return { "message": message @@ -287,19 +323,12 @@ def delete_all_related_data(data: DocumentsToDelete, db: Session = Depends(get_d except JWTError: raise HTTPException(status_code=403, detail="Token is invalid or expired") - -#AUTH CODE +# AUTH CODE oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") -# Manual Origins -# origins = [ -# "http://localhost:3000", # Adjust the port if your frontend runs on a different one -# "https://yourfrontenddomain.com", -# ] - app.add_middleware( CORSMiddleware, - allow_origins=["*"], # Allows all origins from the list + allow_origins=["*"], # Allows all origins allow_credentials=True, allow_methods=["*"], # Allows all methods allow_headers=["*"], # Allows all headers @@ -308,9 +337,29 @@ app.add_middleware( pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") def get_user_by_username(db: Session, username: str): + """ + Retrieve a user by username from the database. + + Args: + db (Session): The database session. + username (str): The username to search for. + + Returns: + User: The user object if found, None otherwise. + """ return db.query(User).filter(User.username == username).first() def create_user(db: Session, user: UserCreate): + """ + Create a new user in the database. + + Args: + db (Session): The database session. + user (UserCreate): The user data to create. + + Returns: + str: A message indicating the completion of user creation. + """ hashed_password = pwd_context.hash(user.password) db_user = User(username=user.username, hashed_password=hashed_password) db.add(db_user) @@ -319,7 +368,20 @@ def create_user(db: Session, user: UserCreate): @app.post("/register") def register_user(user: UserCreate, db: Session = Depends(get_db)): - if(user.apisecretkey != API_SECRET_KEY): + """ + Register a new user. + + Args: + user (UserCreate): The user data for registration. + db (Session): The database session. + + Returns: + str: A message indicating the completion of user registration. + + Raises: + HTTPException: If the API secret key is invalid or the username is already registered. + """ + if user.apisecretkey != API_SECRET_KEY: raise HTTPException(status_code=401, detail="Unauthorized") db_user = get_user_by_username(db, username=user.username) @@ -329,8 +391,18 @@ def register_user(user: UserCreate, db: Session = Depends(get_db)): del user.apisecretkey return create_user(db=db, user=user) -# Authenticate the user def authenticate_user(username: str, password: str, db: Session): + """ + Authenticate a user. + + Args: + username (str): The username of the user. + password (str): The password of the user. + db (Session): The database session. + + Returns: + User: The authenticated user object if successful, False otherwise. + """ user = db.query(User).filter(User.username == username).first() if not user: return False @@ -338,8 +410,17 @@ def authenticate_user(username: str, password: str, db: Session): return False return user -# Create access token def create_access_token(data: dict, expires_delta: timedelta | None = None): + """ + Create an access token. + + Args: + data (dict): The data to encode in the token. + expires_delta (timedelta, optional): The expiration time for the token. + + Returns: + str: The encoded JWT token. + """ to_encode = data.copy() if expires_delta: expire = datetime.utcnow() + expires_delta @@ -351,6 +432,19 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None): @app.post("/token") def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)): + """ + Authenticate a user and return an access token. + + Args: + form_data (OAuth2PasswordRequestForm): The form data containing username and password. + db (Session): The database session. + + Returns: + dict: A dictionary containing the access token and token type. + + Raises: + HTTPException: If the username or password is incorrect. + """ user = authenticate_user(form_data.username, form_data.password, db) if not user: raise HTTPException( @@ -365,6 +459,18 @@ def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: return {"access_token": access_token, "token_type": "bearer"} def verify_token(token: str = Depends(oauth2_scheme)): + """ + Verify the validity of a token. + + Args: + token (str): The token to verify. + + Returns: + dict: The payload of the token if valid. + + Raises: + HTTPException: If the token is invalid or expired. + """ try: payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) username: str = payload.get("sub") @@ -376,11 +482,33 @@ def verify_token(token: str = Depends(oauth2_scheme)): @app.get("/verify-token/{token}") async def verify_user_token(token: str): + """ + Verify a user token. + + Args: + token (str): The token to verify. + + Returns: + dict: A message indicating the validity of the token. + """ verify_token(token=token) return {"message": "Token is valid"} @app.post("/user/chat/save") def populate_user_chat(chat: NewUserChat, db: Session = Depends(get_db)): + """ + Save a new chat for a user. + + Args: + chat (NewUserChat): The chat data to save. + db (Session): The database session. + + Returns: + dict: A message indicating the success of the operation. + + Raises: + HTTPException: If the token is invalid or expired. + """ try: payload = jwt.decode(chat.token, SECRET_KEY, algorithms=[ALGORITHM]) username: str = payload.get("sub") @@ -400,6 +528,19 @@ def populate_user_chat(chat: NewUserChat, db: Session = Depends(get_db)): @app.post("/user/chat/update") def populate_user_chat(chat: ChatToUpdate, db: Session = Depends(get_db)): + """ + Update an existing chat for a user. + + Args: + chat (ChatToUpdate): The chat data to update. + db (Session): The database session. + + Returns: + dict: A message indicating the success of the operation. + + Raises: + HTTPException: If the token is invalid or expired. + """ try: payload = jwt.decode(chat.token, SECRET_KEY, algorithms=[ALGORITHM]) username: str = payload.get("sub") @@ -418,6 +559,20 @@ def populate_user_chat(chat: ChatToUpdate, db: Session = Depends(get_db)): @app.get("/user/chat/delete/{token}/{chatid}") async def delete_chat_of_user(token: str, chatid: str, db: Session = Depends(get_db)): + """ + Delete a chat for a user. + + Args: + token (str): The user's authentication token. + chatid (str): The ID of the chat to delete. + db (Session): The database session. + + Returns: + dict: A message indicating the success of the operation. + + Raises: + HTTPException: If the token is invalid or expired. + """ try: payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) username: str = payload.get("sub") @@ -433,9 +588,21 @@ async def delete_chat_of_user(token: str, chatid: str, db: Session = Depends(get except JWTError: raise HTTPException(status_code=403, detail="Token is invalid or expired") -#Gets user id & name @app.get("/user/{token}") async def get_user_with_token(token: str, db: Session = Depends(get_db)): + """ + Get user information using a token. + + Args: + token (str): The user's authentication token. + db (Session): The database session. + + Returns: + dict: User information including ID, username, chats, and documents. + + Raises: + HTTPException: If the token is invalid or expired. + """ try: payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) username: str = payload.get("sub") @@ -454,6 +621,19 @@ async def get_user_with_token(token: str, db: Session = Depends(get_db)): @app.get("/searchspaces/{token}") async def get_user_with_token(token: str, db: Session = Depends(get_db)): + """ + Get all search spaces. + + Args: + token (str): The user's authentication token. + db (Session): The database session. + + Returns: + dict: A dictionary containing all search spaces. + + Raises: + HTTPException: If the token is invalid or expired. + """ try: payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) username: str = payload.get("sub") @@ -467,12 +647,7 @@ async def get_user_with_token(token: str, db: Session = Depends(get_db)): except JWTError: raise HTTPException(status_code=403, detail="Token is invalid or expired") - - if __name__ == "__main__": import uvicorn uvicorn.run(app, host="127.0.0.1", port=8000) - - - diff --git a/docker-compose.yml b/docker-compose.yml index 037cb6a..f4cc9b7 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -12,6 +12,8 @@ services: - postgres_data:/var/lib/postgresql/data networks: - surfsense-network + ports: + - "5432:5432" # Backend Service (FastAPI) backend: @@ -23,6 +25,7 @@ services: - ./backend/.env depends_on: - db + privileged: true networks: - surfsense-network