mirror of
https://github.com/MODSetter/SurfSense.git
synced 2025-09-04 11:39:19 +00:00
Docker Support
This commit is contained in:
parent
778c16f8bf
commit
d247cacaa9
5 changed files with 267 additions and 79 deletions
|
@ -1 +1 @@
|
||||||
Subproject commit aeb87beee51bbf040307157d6c54b3a3a82ef620
|
Subproject commit 4d0d1414c4774704619834ec8e9280940976f500
|
|
@ -1,7 +1,7 @@
|
||||||
#true if you wana run local setup with Ollama llama3.1
|
#true if you wana run local setup with Ollama llama3.1
|
||||||
IS_LOCAL_SETUP = 'false'
|
IS_LOCAL_SETUP = 'false'
|
||||||
|
|
||||||
#POSTGRES DB TO TRACK USERS
|
#POSTGRES DB TO TRACK USERS | replace localhost with db for docker setups eg "postgresql+psycopg2://postgres:postgres@db:5432/surfsense"
|
||||||
POSTGRES_DATABASE_URL = "postgresql+psycopg2://postgres:postgres@localhost:5432/surfsense"
|
POSTGRES_DATABASE_URL = "postgresql+psycopg2://postgres:postgres@localhost:5432/surfsense"
|
||||||
|
|
||||||
# API KEY TO PREVENT USER REGISTRATION SPAM
|
# API KEY TO PREVENT USER REGISTRATION SPAM
|
||||||
|
|
|
@ -23,16 +23,26 @@ IS_LOCAL_SETUP = os.environ.get("IS_LOCAL_SETUP")
|
||||||
|
|
||||||
# Dependency
|
# Dependency
|
||||||
def get_db():
|
def get_db():
|
||||||
|
"""
|
||||||
|
Dependency to get a database session.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Session: A SQLAlchemy database session.
|
||||||
|
"""
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
yield db
|
yield db
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
class HIndices:
|
class HIndices:
|
||||||
def __init__(self, username, api_key='local'):
|
def __init__(self, username, api_key='local'):
|
||||||
"""
|
"""
|
||||||
|
Initialize the HIndices object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
username (str): The username for the indices.
|
||||||
|
api_key (str, optional): API key for non-local setups. Defaults to 'local'.
|
||||||
"""
|
"""
|
||||||
self.username = username
|
self.username = username
|
||||||
if(IS_LOCAL_SETUP == 'true'):
|
if(IS_LOCAL_SETUP == 'true'):
|
||||||
|
|
|
@ -8,7 +8,7 @@ from prompts import CONTEXT_ANSWER_PROMPT, DATE_TODAY, SUBQUERY_DECOMPOSITION_PR
|
||||||
from pydmodels import ChatToUpdate, DescriptionResponse, DocWithContent, DocumentsToDelete, NewUserChat, UserCreate, UserQuery, RetrivedDocList, UserQueryResponse, UserQueryWithChatHistory
|
from pydmodels import ChatToUpdate, DescriptionResponse, DocWithContent, DocumentsToDelete, NewUserChat, UserCreate, UserQuery, RetrivedDocList, UserQueryResponse, UserQueryWithChatHistory
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
|
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
|
||||||
|
|
||||||
#Heirerical Indices class
|
# Hierarchical Indices class
|
||||||
from HIndices import HIndices
|
from HIndices import HIndices
|
||||||
|
|
||||||
from Utils.stringify import stringify
|
from Utils.stringify import stringify
|
||||||
|
@ -24,11 +24,11 @@ from models import Chat, Documents, SearchSpace, User
|
||||||
from database import SessionLocal
|
from database import SessionLocal
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
# Environment variables
|
||||||
IS_LOCAL_SETUP = os.environ.get("IS_LOCAL_SETUP")
|
IS_LOCAL_SETUP = os.environ.get("IS_LOCAL_SETUP")
|
||||||
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.environ.get("ACCESS_TOKEN_EXPIRE_MINUTES"))
|
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.environ.get("ACCESS_TOKEN_EXPIRE_MINUTES"))
|
||||||
ALGORITHM = os.environ.get("ALGORITHM")
|
ALGORITHM = os.environ.get("ALGORITHM")
|
||||||
|
@ -37,17 +37,33 @@ SECRET_KEY = os.environ.get("SECRET_KEY")
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
# Dependency
|
|
||||||
def get_db():
|
def get_db():
|
||||||
|
"""
|
||||||
|
Dependency to get a database session.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Session: A SQLAlchemy database session.
|
||||||
|
"""
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
yield db
|
yield db
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
@app.post("/chat/")
|
@app.post("/chat/")
|
||||||
def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
|
def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
|
||||||
|
"""
|
||||||
|
Process a user query and return a response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data (UserQuery): The user query data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserQueryResponse: The response to the user query.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If the token is invalid or expired.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
|
payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
username: str = payload.get("sub")
|
username: str = payload.get("sub")
|
||||||
|
@ -57,15 +73,14 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
|
||||||
query = data.query
|
query = data.query
|
||||||
search_space = data.search_space
|
search_space = data.search_space
|
||||||
|
|
||||||
if(IS_LOCAL_SETUP == 'true'):
|
# Initialize LLM based on setup
|
||||||
sub_query_llm = OllamaLLM(model="mistral-nemo",temperature=0)
|
if IS_LOCAL_SETUP == 'true':
|
||||||
qa_llm = OllamaLLM(model="mistral-nemo",temperature=0)
|
sub_query_llm = OllamaLLM(model="mistral-nemo", temperature=0)
|
||||||
|
qa_llm = OllamaLLM(model="mistral-nemo", temperature=0)
|
||||||
else:
|
else:
|
||||||
sub_query_llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", api_key=data.openaikey)
|
sub_query_llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", api_key=data.openaikey)
|
||||||
qa_llm = ChatOpenAI(temperature=0.5, model_name="gpt-4o-mini", api_key=data.openaikey)
|
qa_llm = ChatOpenAI(temperature=0.5, model_name="gpt-4o-mini", api_key=data.openaikey)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Create an LLMChain for sub-query decomposition
|
# Create an LLMChain for sub-query decomposition
|
||||||
subquery_decomposer_chain = SUBQUERY_DECOMPOSITION_PROMT | sub_query_llm
|
subquery_decomposer_chain = SUBQUERY_DECOMPOSITION_PROMT | sub_query_llm
|
||||||
|
|
||||||
|
@ -74,12 +89,12 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
|
||||||
Decompose the original query into simpler sub-queries.
|
Decompose the original query into simpler sub-queries.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
original_query (str): The original complex query
|
original_query (str): The original complex query
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[str]: A list of simpler sub-queries
|
List[str]: A list of simpler sub-queries
|
||||||
"""
|
"""
|
||||||
if(IS_LOCAL_SETUP == 'true'):
|
if IS_LOCAL_SETUP == 'true':
|
||||||
response = subquery_decomposer_chain.invoke(original_query)
|
response = subquery_decomposer_chain.invoke(original_query)
|
||||||
else:
|
else:
|
||||||
response = subquery_decomposer_chain.invoke(original_query).content
|
response = subquery_decomposer_chain.invoke(original_query).content
|
||||||
|
@ -87,20 +102,16 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
|
||||||
sub_queries = [q.strip() for q in response.split('\n') if q.strip() and not q.strip().startswith('Sub-queries:')]
|
sub_queries = [q.strip() for q in response.split('\n') if q.strip() and not q.strip().startswith('Sub-queries:')]
|
||||||
return sub_queries
|
return sub_queries
|
||||||
|
|
||||||
|
# Create Hierarchical Indices
|
||||||
# Create Heirarical Indecices
|
if IS_LOCAL_SETUP == 'true':
|
||||||
if(IS_LOCAL_SETUP == 'true'):
|
|
||||||
index = HIndices(username=username)
|
index = HIndices(username=username)
|
||||||
else:
|
else:
|
||||||
index = HIndices(username=username,api_key=data.openaikey)
|
index = HIndices(username=username, api_key=data.openaikey)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# For Those Who Want HyDe Questions
|
# For those who want HyDE questions
|
||||||
# sub_queries = decompose_query(query)
|
# sub_queries = decompose_query(query)
|
||||||
|
|
||||||
sub_queries = []
|
sub_queries = [query]
|
||||||
sub_queries.append(query)
|
|
||||||
|
|
||||||
duplicate_related_summary_docs = []
|
duplicate_related_summary_docs = []
|
||||||
context_to_answer = ""
|
context_to_answer = ""
|
||||||
|
@ -112,12 +123,11 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
|
||||||
|
|
||||||
duplicate_related_summary_docs.extend(related_summary_docs)
|
duplicate_related_summary_docs.extend(related_summary_docs)
|
||||||
|
|
||||||
|
# Remove duplicate documents
|
||||||
combined_docs_seen_metadata = set()
|
combined_docs_seen_metadata = set()
|
||||||
combined_docs_unique_documents = []
|
combined_docs_unique_documents = []
|
||||||
|
|
||||||
for doc in duplicate_related_summary_docs:
|
for doc in duplicate_related_summary_docs:
|
||||||
# Convert metadata to a tuple of its items (this allows it to be added to a set)
|
|
||||||
doc.metadata['relevance_score'] = 0.0
|
doc.metadata['relevance_score'] = 0.0
|
||||||
metadata_tuple = tuple(sorted(doc.metadata.items()))
|
metadata_tuple = tuple(sorted(doc.metadata.items()))
|
||||||
if metadata_tuple not in combined_docs_seen_metadata:
|
if metadata_tuple not in combined_docs_seen_metadata:
|
||||||
|
@ -134,29 +144,37 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
|
||||||
VisitedWebPageDateWithTimeInISOString=doc.metadata['VisitedWebPageDateWithTimeInISOString'],
|
VisitedWebPageDateWithTimeInISOString=doc.metadata['VisitedWebPageDateWithTimeInISOString'],
|
||||||
VisitedWebPageReffererURL=doc.metadata['VisitedWebPageReffererURL'],
|
VisitedWebPageReffererURL=doc.metadata['VisitedWebPageReffererURL'],
|
||||||
VisitedWebPageVisitDurationInMilliseconds=doc.metadata['VisitedWebPageVisitDurationInMilliseconds'],
|
VisitedWebPageVisitDurationInMilliseconds=doc.metadata['VisitedWebPageVisitDurationInMilliseconds'],
|
||||||
)
|
)
|
||||||
|
|
||||||
returnDocs.append(entry)
|
returnDocs.append(entry)
|
||||||
|
|
||||||
|
# Generate final answer
|
||||||
ans_chain = CONTEXT_ANSWER_PROMPT | qa_llm
|
ans_chain = CONTEXT_ANSWER_PROMPT | qa_llm
|
||||||
|
|
||||||
finalans = ans_chain.invoke({"query": query, "context": context_to_answer})
|
finalans = ans_chain.invoke({"query": query, "context": context_to_answer})
|
||||||
|
|
||||||
|
if IS_LOCAL_SETUP == 'true':
|
||||||
if(IS_LOCAL_SETUP == 'true'):
|
|
||||||
return UserQueryResponse(response=finalans, relateddocs=returnDocs)
|
return UserQueryResponse(response=finalans, relateddocs=returnDocs)
|
||||||
else:
|
else:
|
||||||
return UserQueryResponse(response=finalans.content, relateddocs=returnDocs)
|
return UserQueryResponse(response=finalans.content, relateddocs=returnDocs)
|
||||||
|
|
||||||
|
|
||||||
except JWTError:
|
except JWTError:
|
||||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||||
|
|
||||||
# SAVE DOCS
|
|
||||||
@app.post("/save/")
|
@app.post("/save/")
|
||||||
def save_data(apires: RetrivedDocList, db: Session = Depends(get_db)):
|
def save_data(apires: RetrivedDocList, db: Session = Depends(get_db)):
|
||||||
|
"""
|
||||||
|
Save retrieved documents to the database and vector stores.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
apires (RetrivedDocList): The list of documents to save.
|
||||||
|
db (Session): The database session.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A message indicating the success of the operation.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If the token is invalid or expired.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(apires.token, SECRET_KEY, algorithms=[ALGORITHM])
|
payload = jwt.decode(apires.token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
username: str = payload.get("sub")
|
username: str = payload.get("sub")
|
||||||
|
@ -181,31 +199,29 @@ def save_data(apires: RetrivedDocList, db: Session = Depends(get_db)):
|
||||||
content += f"===================================================================================== \n"
|
content += f"===================================================================================== \n"
|
||||||
content += f"Webpage Content of the visited webpage url in markdown format : \n\n {doc.pageContent} \n\n"
|
content += f"Webpage Content of the visited webpage url in markdown format : \n\n {doc.pageContent} \n\n"
|
||||||
content += f"===================================================================================== \n"
|
content += f"===================================================================================== \n"
|
||||||
raw_documents.append(Document(page_content=content,metadata=doc.metadata.__dict__))
|
raw_documents.append(Document(page_content=content, metadata=doc.metadata.__dict__))
|
||||||
|
|
||||||
pgdocmeta = stringify(doc.metadata.__dict__)
|
pgdocmeta = stringify(doc.metadata.__dict__)
|
||||||
|
|
||||||
if(searchspace):
|
if searchspace:
|
||||||
DocumentPgEntry.append(Documents(file_type='WEBPAGE',title=doc.metadata.VisitedWebPageTitle,search_space=searchspace, document_metadata=pgdocmeta, page_content=content))
|
DocumentPgEntry.append(Documents(file_type='WEBPAGE', title=doc.metadata.VisitedWebPageTitle, search_space=searchspace, document_metadata=pgdocmeta, page_content=content))
|
||||||
else:
|
else:
|
||||||
DocumentPgEntry.append(Documents(file_type='WEBPAGE',title=doc.metadata.VisitedWebPageTitle,search_space=SearchSpace(search_space=apires.search_space.upper()), document_metadata=pgdocmeta, page_content=content))
|
DocumentPgEntry.append(Documents(file_type='WEBPAGE', title=doc.metadata.VisitedWebPageTitle, search_space=SearchSpace(search_space=apires.search_space.upper()), document_metadata=pgdocmeta, page_content=content))
|
||||||
|
|
||||||
|
# Save docs in PG
|
||||||
#Save docs in PG
|
|
||||||
user = db.query(User).filter(User.username == username).first()
|
user = db.query(User).filter(User.username == username).first()
|
||||||
user.documents.extend(DocumentPgEntry)
|
user.documents.extend(DocumentPgEntry)
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
|
# Create Hierarchical Indices
|
||||||
# Create Heirarical Indecices
|
if IS_LOCAL_SETUP == 'true':
|
||||||
if(IS_LOCAL_SETUP == 'true'):
|
|
||||||
index = HIndices(username=username)
|
index = HIndices(username=username)
|
||||||
else:
|
else:
|
||||||
index = HIndices(username=username,api_key=apires.openaikey)
|
index = HIndices(username=username, api_key=apires.openaikey)
|
||||||
|
|
||||||
#Save Indices in vector Stores
|
# Save Indices in vector stores
|
||||||
index.encode_docs_hierarchical(documents=raw_documents, files_type='WEBPAGE',search_space=apires.search_space.upper(), db=db)
|
index.encode_docs_hierarchical(documents=raw_documents, files_type='WEBPAGE', search_space=apires.search_space.upper(), db=db)
|
||||||
|
|
||||||
print("FINISHED")
|
print("FINISHED")
|
||||||
|
|
||||||
|
@ -216,45 +232,55 @@ def save_data(apires: RetrivedDocList, db: Session = Depends(get_db)):
|
||||||
except JWTError:
|
except JWTError:
|
||||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||||
|
|
||||||
# Multi DOC Chat
|
|
||||||
@app.post("/chat/docs")
|
@app.post("/chat/docs")
|
||||||
def doc_chat_with_history(data: UserQueryWithChatHistory, response_model=DescriptionResponse):
|
def doc_chat_with_history(data: UserQueryWithChatHistory, response_model=DescriptionResponse):
|
||||||
|
"""
|
||||||
|
Process a user query with chat history and return a response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data (UserQueryWithChatHistory): The user query data with chat history.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DescriptionResponse: The response to the user query.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If the token is invalid or expired.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
|
payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
username: str = payload.get("sub")
|
username: str = payload.get("sub")
|
||||||
if username is None:
|
if username is None:
|
||||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||||
|
|
||||||
if(IS_LOCAL_SETUP == 'true'):
|
if IS_LOCAL_SETUP == 'true':
|
||||||
llm = OllamaLLM(model="mistral-nemo",temperature=0)
|
llm = OllamaLLM(model="mistral-nemo", temperature=0)
|
||||||
else:
|
else:
|
||||||
llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", api_key=data.openaikey)
|
llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", api_key=data.openaikey)
|
||||||
|
|
||||||
chatHistory = []
|
chatHistory = []
|
||||||
|
|
||||||
for chat in data.chat:
|
for chat in data.chat:
|
||||||
if(chat.type == 'system'):
|
if chat.type == 'system':
|
||||||
chatHistory.append(SystemMessage(content=DATE_TODAY + """You are an helpful assistant for question-answering tasks.
|
chatHistory.append(SystemMessage(content=DATE_TODAY + """You are a helpful assistant for question-answering tasks.
|
||||||
Use the following pieces of retrieved context to answer the question.
|
Use the following pieces of retrieved context to answer the question.
|
||||||
If you don't know the answer, just say that you don't know.
|
If you don't know the answer, just say that you don't know.
|
||||||
Context:""" + str(chat.content)))
|
Context:""" + str(chat.content)))
|
||||||
|
|
||||||
if(chat.type == 'ai'):
|
if chat.type == 'ai':
|
||||||
chatHistory.append(AIMessage(content=chat.content))
|
chatHistory.append(AIMessage(content=chat.content))
|
||||||
|
|
||||||
if(chat.type == 'human'):
|
if chat.type == 'human':
|
||||||
chatHistory.append(HumanMessage(content=chat.content))
|
chatHistory.append(HumanMessage(content=chat.content))
|
||||||
|
|
||||||
chatHistory.append(("human", "{input}"));
|
chatHistory.append(("human", "{input}"))
|
||||||
|
|
||||||
|
|
||||||
qa_prompt = ChatPromptTemplate.from_messages(chatHistory)
|
qa_prompt = ChatPromptTemplate.from_messages(chatHistory)
|
||||||
|
|
||||||
descriptionchain = qa_prompt | llm
|
descriptionchain = qa_prompt | llm
|
||||||
|
|
||||||
response = descriptionchain.invoke({"input": data.query})
|
response = descriptionchain.invoke({"input": data.query})
|
||||||
|
|
||||||
if(IS_LOCAL_SETUP == 'true'):
|
if IS_LOCAL_SETUP == 'true':
|
||||||
return DescriptionResponse(response=response)
|
return DescriptionResponse(response=response)
|
||||||
else:
|
else:
|
||||||
return DescriptionResponse(response=response.content)
|
return DescriptionResponse(response=response.content)
|
||||||
|
@ -262,23 +288,33 @@ def doc_chat_with_history(data: UserQueryWithChatHistory, response_model=Descrip
|
||||||
except JWTError:
|
except JWTError:
|
||||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||||
|
|
||||||
|
|
||||||
# Multi DOC Chat
|
|
||||||
|
|
||||||
@app.post("/delete/docs")
|
@app.post("/delete/docs")
|
||||||
def delete_all_related_data(data: DocumentsToDelete, db: Session = Depends(get_db)):
|
def delete_all_related_data(data: DocumentsToDelete, db: Session = Depends(get_db)):
|
||||||
|
"""
|
||||||
|
Delete documents and related data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data (DocumentsToDelete): The data containing documents to delete.
|
||||||
|
db (Session): The database session.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A message indicating the result of the deletion.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If the token is invalid or expired.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
|
payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
username: str = payload.get("sub")
|
username: str = payload.get("sub")
|
||||||
if username is None:
|
if username is None:
|
||||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||||
|
|
||||||
if(IS_LOCAL_SETUP == 'true'):
|
if IS_LOCAL_SETUP == 'true':
|
||||||
index = HIndices(username=username)
|
index = HIndices(username=username)
|
||||||
else:
|
else:
|
||||||
index = HIndices(username=username,api_key=data.openaikey)
|
index = HIndices(username=username, api_key=data.openaikey)
|
||||||
|
|
||||||
message = index.delete_vector_stores(summary_ids_to_delete=data.ids_to_delete,db=db )
|
message = index.delete_vector_stores(summary_ids_to_delete=data.ids_to_delete, db=db)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"message": message
|
"message": message
|
||||||
|
@ -287,19 +323,12 @@ def delete_all_related_data(data: DocumentsToDelete, db: Session = Depends(get_d
|
||||||
except JWTError:
|
except JWTError:
|
||||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||||
|
|
||||||
|
# AUTH CODE
|
||||||
#AUTH CODE
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||||
|
|
||||||
# Manual Origins
|
|
||||||
# origins = [
|
|
||||||
# "http://localhost:3000", # Adjust the port if your frontend runs on a different one
|
|
||||||
# "https://yourfrontenddomain.com",
|
|
||||||
# ]
|
|
||||||
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=["*"], # Allows all origins from the list
|
allow_origins=["*"], # Allows all origins
|
||||||
allow_credentials=True,
|
allow_credentials=True,
|
||||||
allow_methods=["*"], # Allows all methods
|
allow_methods=["*"], # Allows all methods
|
||||||
allow_headers=["*"], # Allows all headers
|
allow_headers=["*"], # Allows all headers
|
||||||
|
@ -308,9 +337,29 @@ app.add_middleware(
|
||||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||||
|
|
||||||
def get_user_by_username(db: Session, username: str):
|
def get_user_by_username(db: Session, username: str):
|
||||||
|
"""
|
||||||
|
Retrieve a user by username from the database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db (Session): The database session.
|
||||||
|
username (str): The username to search for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User: The user object if found, None otherwise.
|
||||||
|
"""
|
||||||
return db.query(User).filter(User.username == username).first()
|
return db.query(User).filter(User.username == username).first()
|
||||||
|
|
||||||
def create_user(db: Session, user: UserCreate):
|
def create_user(db: Session, user: UserCreate):
|
||||||
|
"""
|
||||||
|
Create a new user in the database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db (Session): The database session.
|
||||||
|
user (UserCreate): The user data to create.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: A message indicating the completion of user creation.
|
||||||
|
"""
|
||||||
hashed_password = pwd_context.hash(user.password)
|
hashed_password = pwd_context.hash(user.password)
|
||||||
db_user = User(username=user.username, hashed_password=hashed_password)
|
db_user = User(username=user.username, hashed_password=hashed_password)
|
||||||
db.add(db_user)
|
db.add(db_user)
|
||||||
|
@ -319,7 +368,20 @@ def create_user(db: Session, user: UserCreate):
|
||||||
|
|
||||||
@app.post("/register")
|
@app.post("/register")
|
||||||
def register_user(user: UserCreate, db: Session = Depends(get_db)):
|
def register_user(user: UserCreate, db: Session = Depends(get_db)):
|
||||||
if(user.apisecretkey != API_SECRET_KEY):
|
"""
|
||||||
|
Register a new user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user (UserCreate): The user data for registration.
|
||||||
|
db (Session): The database session.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: A message indicating the completion of user registration.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If the API secret key is invalid or the username is already registered.
|
||||||
|
"""
|
||||||
|
if user.apisecretkey != API_SECRET_KEY:
|
||||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||||
|
|
||||||
db_user = get_user_by_username(db, username=user.username)
|
db_user = get_user_by_username(db, username=user.username)
|
||||||
|
@ -329,8 +391,18 @@ def register_user(user: UserCreate, db: Session = Depends(get_db)):
|
||||||
del user.apisecretkey
|
del user.apisecretkey
|
||||||
return create_user(db=db, user=user)
|
return create_user(db=db, user=user)
|
||||||
|
|
||||||
# Authenticate the user
|
|
||||||
def authenticate_user(username: str, password: str, db: Session):
|
def authenticate_user(username: str, password: str, db: Session):
|
||||||
|
"""
|
||||||
|
Authenticate a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
username (str): The username of the user.
|
||||||
|
password (str): The password of the user.
|
||||||
|
db (Session): The database session.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User: The authenticated user object if successful, False otherwise.
|
||||||
|
"""
|
||||||
user = db.query(User).filter(User.username == username).first()
|
user = db.query(User).filter(User.username == username).first()
|
||||||
if not user:
|
if not user:
|
||||||
return False
|
return False
|
||||||
|
@ -338,8 +410,17 @@ def authenticate_user(username: str, password: str, db: Session):
|
||||||
return False
|
return False
|
||||||
return user
|
return user
|
||||||
|
|
||||||
# Create access token
|
|
||||||
def create_access_token(data: dict, expires_delta: timedelta | None = None):
|
def create_access_token(data: dict, expires_delta: timedelta | None = None):
|
||||||
|
"""
|
||||||
|
Create an access token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data (dict): The data to encode in the token.
|
||||||
|
expires_delta (timedelta, optional): The expiration time for the token.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The encoded JWT token.
|
||||||
|
"""
|
||||||
to_encode = data.copy()
|
to_encode = data.copy()
|
||||||
if expires_delta:
|
if expires_delta:
|
||||||
expire = datetime.utcnow() + expires_delta
|
expire = datetime.utcnow() + expires_delta
|
||||||
|
@ -351,6 +432,19 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None):
|
||||||
|
|
||||||
@app.post("/token")
|
@app.post("/token")
|
||||||
def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
|
def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
|
||||||
|
"""
|
||||||
|
Authenticate a user and return an access token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
form_data (OAuth2PasswordRequestForm): The form data containing username and password.
|
||||||
|
db (Session): The database session.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A dictionary containing the access token and token type.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If the username or password is incorrect.
|
||||||
|
"""
|
||||||
user = authenticate_user(form_data.username, form_data.password, db)
|
user = authenticate_user(form_data.username, form_data.password, db)
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
@ -365,6 +459,18 @@ def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db:
|
||||||
return {"access_token": access_token, "token_type": "bearer"}
|
return {"access_token": access_token, "token_type": "bearer"}
|
||||||
|
|
||||||
def verify_token(token: str = Depends(oauth2_scheme)):
|
def verify_token(token: str = Depends(oauth2_scheme)):
|
||||||
|
"""
|
||||||
|
Verify the validity of a token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token (str): The token to verify.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: The payload of the token if valid.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If the token is invalid or expired.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
username: str = payload.get("sub")
|
username: str = payload.get("sub")
|
||||||
|
@ -376,11 +482,33 @@ def verify_token(token: str = Depends(oauth2_scheme)):
|
||||||
|
|
||||||
@app.get("/verify-token/{token}")
|
@app.get("/verify-token/{token}")
|
||||||
async def verify_user_token(token: str):
|
async def verify_user_token(token: str):
|
||||||
|
"""
|
||||||
|
Verify a user token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token (str): The token to verify.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A message indicating the validity of the token.
|
||||||
|
"""
|
||||||
verify_token(token=token)
|
verify_token(token=token)
|
||||||
return {"message": "Token is valid"}
|
return {"message": "Token is valid"}
|
||||||
|
|
||||||
@app.post("/user/chat/save")
|
@app.post("/user/chat/save")
|
||||||
def populate_user_chat(chat: NewUserChat, db: Session = Depends(get_db)):
|
def populate_user_chat(chat: NewUserChat, db: Session = Depends(get_db)):
|
||||||
|
"""
|
||||||
|
Save a new chat for a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat (NewUserChat): The chat data to save.
|
||||||
|
db (Session): The database session.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A message indicating the success of the operation.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If the token is invalid or expired.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(chat.token, SECRET_KEY, algorithms=[ALGORITHM])
|
payload = jwt.decode(chat.token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
username: str = payload.get("sub")
|
username: str = payload.get("sub")
|
||||||
|
@ -400,6 +528,19 @@ def populate_user_chat(chat: NewUserChat, db: Session = Depends(get_db)):
|
||||||
|
|
||||||
@app.post("/user/chat/update")
|
@app.post("/user/chat/update")
|
||||||
def populate_user_chat(chat: ChatToUpdate, db: Session = Depends(get_db)):
|
def populate_user_chat(chat: ChatToUpdate, db: Session = Depends(get_db)):
|
||||||
|
"""
|
||||||
|
Update an existing chat for a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat (ChatToUpdate): The chat data to update.
|
||||||
|
db (Session): The database session.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A message indicating the success of the operation.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If the token is invalid or expired.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(chat.token, SECRET_KEY, algorithms=[ALGORITHM])
|
payload = jwt.decode(chat.token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
username: str = payload.get("sub")
|
username: str = payload.get("sub")
|
||||||
|
@ -418,6 +559,20 @@ def populate_user_chat(chat: ChatToUpdate, db: Session = Depends(get_db)):
|
||||||
|
|
||||||
@app.get("/user/chat/delete/{token}/{chatid}")
|
@app.get("/user/chat/delete/{token}/{chatid}")
|
||||||
async def delete_chat_of_user(token: str, chatid: str, db: Session = Depends(get_db)):
|
async def delete_chat_of_user(token: str, chatid: str, db: Session = Depends(get_db)):
|
||||||
|
"""
|
||||||
|
Delete a chat for a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token (str): The user's authentication token.
|
||||||
|
chatid (str): The ID of the chat to delete.
|
||||||
|
db (Session): The database session.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A message indicating the success of the operation.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If the token is invalid or expired.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
username: str = payload.get("sub")
|
username: str = payload.get("sub")
|
||||||
|
@ -433,9 +588,21 @@ async def delete_chat_of_user(token: str, chatid: str, db: Session = Depends(get
|
||||||
except JWTError:
|
except JWTError:
|
||||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||||
|
|
||||||
#Gets user id & name
|
|
||||||
@app.get("/user/{token}")
|
@app.get("/user/{token}")
|
||||||
async def get_user_with_token(token: str, db: Session = Depends(get_db)):
|
async def get_user_with_token(token: str, db: Session = Depends(get_db)):
|
||||||
|
"""
|
||||||
|
Get user information using a token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token (str): The user's authentication token.
|
||||||
|
db (Session): The database session.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: User information including ID, username, chats, and documents.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If the token is invalid or expired.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
username: str = payload.get("sub")
|
username: str = payload.get("sub")
|
||||||
|
@ -454,6 +621,19 @@ async def get_user_with_token(token: str, db: Session = Depends(get_db)):
|
||||||
|
|
||||||
@app.get("/searchspaces/{token}")
|
@app.get("/searchspaces/{token}")
|
||||||
async def get_user_with_token(token: str, db: Session = Depends(get_db)):
|
async def get_user_with_token(token: str, db: Session = Depends(get_db)):
|
||||||
|
"""
|
||||||
|
Get all search spaces.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token (str): The user's authentication token.
|
||||||
|
db (Session): The database session.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A dictionary containing all search spaces.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If the token is invalid or expired.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
username: str = payload.get("sub")
|
username: str = payload.get("sub")
|
||||||
|
@ -467,12 +647,7 @@ async def get_user_with_token(token: str, db: Session = Depends(get_db)):
|
||||||
except JWTError:
|
except JWTError:
|
||||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
uvicorn.run(app, host="127.0.0.1", port=8000)
|
uvicorn.run(app, host="127.0.0.1", port=8000)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,8 @@ services:
|
||||||
- postgres_data:/var/lib/postgresql/data
|
- postgres_data:/var/lib/postgresql/data
|
||||||
networks:
|
networks:
|
||||||
- surfsense-network
|
- surfsense-network
|
||||||
|
ports:
|
||||||
|
- "5432:5432"
|
||||||
|
|
||||||
# Backend Service (FastAPI)
|
# Backend Service (FastAPI)
|
||||||
backend:
|
backend:
|
||||||
|
@ -23,6 +25,7 @@ services:
|
||||||
- ./backend/.env
|
- ./backend/.env
|
||||||
depends_on:
|
depends_on:
|
||||||
- db
|
- db
|
||||||
|
privileged: true
|
||||||
networks:
|
networks:
|
||||||
- surfsense-network
|
- surfsense-network
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue