Docker Support

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2024-09-25 14:39:06 -07:00
parent 778c16f8bf
commit d247cacaa9
5 changed files with 267 additions and 79 deletions

@ -1 +1 @@
Subproject commit aeb87beee51bbf040307157d6c54b3a3a82ef620 Subproject commit 4d0d1414c4774704619834ec8e9280940976f500

View file

@ -1,7 +1,7 @@
#true if you wana run local setup with Ollama llama3.1 #true if you wana run local setup with Ollama llama3.1
IS_LOCAL_SETUP = 'false' IS_LOCAL_SETUP = 'false'
#POSTGRES DB TO TRACK USERS #POSTGRES DB TO TRACK USERS | replace localhost with db for docker setups eg "postgresql+psycopg2://postgres:postgres@db:5432/surfsense"
POSTGRES_DATABASE_URL = "postgresql+psycopg2://postgres:postgres@localhost:5432/surfsense" POSTGRES_DATABASE_URL = "postgresql+psycopg2://postgres:postgres@localhost:5432/surfsense"
# API KEY TO PREVENT USER REGISTRATION SPAM # API KEY TO PREVENT USER REGISTRATION SPAM

View file

@ -23,16 +23,26 @@ IS_LOCAL_SETUP = os.environ.get("IS_LOCAL_SETUP")
# Dependency # Dependency
def get_db(): def get_db():
"""
Dependency to get a database session.
Yields:
Session: A SQLAlchemy database session.
"""
db = SessionLocal() db = SessionLocal()
try: try:
yield db yield db
finally: finally:
db.close() db.close()
class HIndices: class HIndices:
def __init__(self, username, api_key='local'): def __init__(self, username, api_key='local'):
""" """
Initialize the HIndices object.
Args:
username (str): The username for the indices.
api_key (str, optional): API key for non-local setups. Defaults to 'local'.
""" """
self.username = username self.username = username
if(IS_LOCAL_SETUP == 'true'): if(IS_LOCAL_SETUP == 'true'):

View file

@ -8,7 +8,7 @@ from prompts import CONTEXT_ANSWER_PROMPT, DATE_TODAY, SUBQUERY_DECOMPOSITION_PR
from pydmodels import ChatToUpdate, DescriptionResponse, DocWithContent, DocumentsToDelete, NewUserChat, UserCreate, UserQuery, RetrivedDocList, UserQueryResponse, UserQueryWithChatHistory from pydmodels import ChatToUpdate, DescriptionResponse, DocWithContent, DocumentsToDelete, NewUserChat, UserCreate, UserQuery, RetrivedDocList, UserQueryResponse, UserQueryWithChatHistory
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
#Heirerical Indices class # Hierarchical Indices class
from HIndices import HIndices from HIndices import HIndices
from Utils.stringify import stringify from Utils.stringify import stringify
@ -24,11 +24,11 @@ from models import Chat, Documents, SearchSpace, User
from database import SessionLocal from database import SessionLocal
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
import os import os
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()
# Environment variables
IS_LOCAL_SETUP = os.environ.get("IS_LOCAL_SETUP") IS_LOCAL_SETUP = os.environ.get("IS_LOCAL_SETUP")
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.environ.get("ACCESS_TOKEN_EXPIRE_MINUTES")) ACCESS_TOKEN_EXPIRE_MINUTES = int(os.environ.get("ACCESS_TOKEN_EXPIRE_MINUTES"))
ALGORITHM = os.environ.get("ALGORITHM") ALGORITHM = os.environ.get("ALGORITHM")
@ -37,17 +37,33 @@ SECRET_KEY = os.environ.get("SECRET_KEY")
app = FastAPI() app = FastAPI()
# Dependency
def get_db(): def get_db():
"""
Dependency to get a database session.
Yields:
Session: A SQLAlchemy database session.
"""
db = SessionLocal() db = SessionLocal()
try: try:
yield db yield db
finally: finally:
db.close() db.close()
@app.post("/chat/") @app.post("/chat/")
def get_user_query_response(data: UserQuery, response_model=UserQueryResponse): def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
"""
Process a user query and return a response.
Args:
data (UserQuery): The user query data.
Returns:
UserQueryResponse: The response to the user query.
Raises:
HTTPException: If the token is invalid or expired.
"""
try: try:
payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM]) payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub") username: str = payload.get("sub")
@ -57,15 +73,14 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
query = data.query query = data.query
search_space = data.search_space search_space = data.search_space
if(IS_LOCAL_SETUP == 'true'): # Initialize LLM based on setup
sub_query_llm = OllamaLLM(model="mistral-nemo",temperature=0) if IS_LOCAL_SETUP == 'true':
qa_llm = OllamaLLM(model="mistral-nemo",temperature=0) sub_query_llm = OllamaLLM(model="mistral-nemo", temperature=0)
qa_llm = OllamaLLM(model="mistral-nemo", temperature=0)
else: else:
sub_query_llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", api_key=data.openaikey) sub_query_llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", api_key=data.openaikey)
qa_llm = ChatOpenAI(temperature=0.5, model_name="gpt-4o-mini", api_key=data.openaikey) qa_llm = ChatOpenAI(temperature=0.5, model_name="gpt-4o-mini", api_key=data.openaikey)
# Create an LLMChain for sub-query decomposition # Create an LLMChain for sub-query decomposition
subquery_decomposer_chain = SUBQUERY_DECOMPOSITION_PROMT | sub_query_llm subquery_decomposer_chain = SUBQUERY_DECOMPOSITION_PROMT | sub_query_llm
@ -74,12 +89,12 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
Decompose the original query into simpler sub-queries. Decompose the original query into simpler sub-queries.
Args: Args:
original_query (str): The original complex query original_query (str): The original complex query
Returns: Returns:
List[str]: A list of simpler sub-queries List[str]: A list of simpler sub-queries
""" """
if(IS_LOCAL_SETUP == 'true'): if IS_LOCAL_SETUP == 'true':
response = subquery_decomposer_chain.invoke(original_query) response = subquery_decomposer_chain.invoke(original_query)
else: else:
response = subquery_decomposer_chain.invoke(original_query).content response = subquery_decomposer_chain.invoke(original_query).content
@ -87,20 +102,16 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
sub_queries = [q.strip() for q in response.split('\n') if q.strip() and not q.strip().startswith('Sub-queries:')] sub_queries = [q.strip() for q in response.split('\n') if q.strip() and not q.strip().startswith('Sub-queries:')]
return sub_queries return sub_queries
# Create Hierarchical Indices
# Create Heirarical Indecices if IS_LOCAL_SETUP == 'true':
if(IS_LOCAL_SETUP == 'true'):
index = HIndices(username=username) index = HIndices(username=username)
else: else:
index = HIndices(username=username,api_key=data.openaikey) index = HIndices(username=username, api_key=data.openaikey)
# For Those Who Want HyDe Questions # For those who want HyDE questions
# sub_queries = decompose_query(query) # sub_queries = decompose_query(query)
sub_queries = [] sub_queries = [query]
sub_queries.append(query)
duplicate_related_summary_docs = [] duplicate_related_summary_docs = []
context_to_answer = "" context_to_answer = ""
@ -112,12 +123,11 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
duplicate_related_summary_docs.extend(related_summary_docs) duplicate_related_summary_docs.extend(related_summary_docs)
# Remove duplicate documents
combined_docs_seen_metadata = set() combined_docs_seen_metadata = set()
combined_docs_unique_documents = [] combined_docs_unique_documents = []
for doc in duplicate_related_summary_docs: for doc in duplicate_related_summary_docs:
# Convert metadata to a tuple of its items (this allows it to be added to a set)
doc.metadata['relevance_score'] = 0.0 doc.metadata['relevance_score'] = 0.0
metadata_tuple = tuple(sorted(doc.metadata.items())) metadata_tuple = tuple(sorted(doc.metadata.items()))
if metadata_tuple not in combined_docs_seen_metadata: if metadata_tuple not in combined_docs_seen_metadata:
@ -134,29 +144,37 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
VisitedWebPageDateWithTimeInISOString=doc.metadata['VisitedWebPageDateWithTimeInISOString'], VisitedWebPageDateWithTimeInISOString=doc.metadata['VisitedWebPageDateWithTimeInISOString'],
VisitedWebPageReffererURL=doc.metadata['VisitedWebPageReffererURL'], VisitedWebPageReffererURL=doc.metadata['VisitedWebPageReffererURL'],
VisitedWebPageVisitDurationInMilliseconds=doc.metadata['VisitedWebPageVisitDurationInMilliseconds'], VisitedWebPageVisitDurationInMilliseconds=doc.metadata['VisitedWebPageVisitDurationInMilliseconds'],
) )
returnDocs.append(entry) returnDocs.append(entry)
# Generate final answer
ans_chain = CONTEXT_ANSWER_PROMPT | qa_llm ans_chain = CONTEXT_ANSWER_PROMPT | qa_llm
finalans = ans_chain.invoke({"query": query, "context": context_to_answer}) finalans = ans_chain.invoke({"query": query, "context": context_to_answer})
if IS_LOCAL_SETUP == 'true':
if(IS_LOCAL_SETUP == 'true'):
return UserQueryResponse(response=finalans, relateddocs=returnDocs) return UserQueryResponse(response=finalans, relateddocs=returnDocs)
else: else:
return UserQueryResponse(response=finalans.content, relateddocs=returnDocs) return UserQueryResponse(response=finalans.content, relateddocs=returnDocs)
except JWTError: except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired") raise HTTPException(status_code=403, detail="Token is invalid or expired")
# SAVE DOCS
@app.post("/save/") @app.post("/save/")
def save_data(apires: RetrivedDocList, db: Session = Depends(get_db)): def save_data(apires: RetrivedDocList, db: Session = Depends(get_db)):
"""
Save retrieved documents to the database and vector stores.
Args:
apires (RetrivedDocList): The list of documents to save.
db (Session): The database session.
Returns:
dict: A message indicating the success of the operation.
Raises:
HTTPException: If the token is invalid or expired.
"""
try: try:
payload = jwt.decode(apires.token, SECRET_KEY, algorithms=[ALGORITHM]) payload = jwt.decode(apires.token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub") username: str = payload.get("sub")
@ -181,31 +199,29 @@ def save_data(apires: RetrivedDocList, db: Session = Depends(get_db)):
content += f"===================================================================================== \n" content += f"===================================================================================== \n"
content += f"Webpage Content of the visited webpage url in markdown format : \n\n {doc.pageContent} \n\n" content += f"Webpage Content of the visited webpage url in markdown format : \n\n {doc.pageContent} \n\n"
content += f"===================================================================================== \n" content += f"===================================================================================== \n"
raw_documents.append(Document(page_content=content,metadata=doc.metadata.__dict__)) raw_documents.append(Document(page_content=content, metadata=doc.metadata.__dict__))
pgdocmeta = stringify(doc.metadata.__dict__) pgdocmeta = stringify(doc.metadata.__dict__)
if(searchspace): if searchspace:
DocumentPgEntry.append(Documents(file_type='WEBPAGE',title=doc.metadata.VisitedWebPageTitle,search_space=searchspace, document_metadata=pgdocmeta, page_content=content)) DocumentPgEntry.append(Documents(file_type='WEBPAGE', title=doc.metadata.VisitedWebPageTitle, search_space=searchspace, document_metadata=pgdocmeta, page_content=content))
else: else:
DocumentPgEntry.append(Documents(file_type='WEBPAGE',title=doc.metadata.VisitedWebPageTitle,search_space=SearchSpace(search_space=apires.search_space.upper()), document_metadata=pgdocmeta, page_content=content)) DocumentPgEntry.append(Documents(file_type='WEBPAGE', title=doc.metadata.VisitedWebPageTitle, search_space=SearchSpace(search_space=apires.search_space.upper()), document_metadata=pgdocmeta, page_content=content))
# Save docs in PG
#Save docs in PG
user = db.query(User).filter(User.username == username).first() user = db.query(User).filter(User.username == username).first()
user.documents.extend(DocumentPgEntry) user.documents.extend(DocumentPgEntry)
db.commit() db.commit()
# Create Hierarchical Indices
# Create Heirarical Indecices if IS_LOCAL_SETUP == 'true':
if(IS_LOCAL_SETUP == 'true'):
index = HIndices(username=username) index = HIndices(username=username)
else: else:
index = HIndices(username=username,api_key=apires.openaikey) index = HIndices(username=username, api_key=apires.openaikey)
#Save Indices in vector Stores # Save Indices in vector stores
index.encode_docs_hierarchical(documents=raw_documents, files_type='WEBPAGE',search_space=apires.search_space.upper(), db=db) index.encode_docs_hierarchical(documents=raw_documents, files_type='WEBPAGE', search_space=apires.search_space.upper(), db=db)
print("FINISHED") print("FINISHED")
@ -216,45 +232,55 @@ def save_data(apires: RetrivedDocList, db: Session = Depends(get_db)):
except JWTError: except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired") raise HTTPException(status_code=403, detail="Token is invalid or expired")
# Multi DOC Chat
@app.post("/chat/docs") @app.post("/chat/docs")
def doc_chat_with_history(data: UserQueryWithChatHistory, response_model=DescriptionResponse): def doc_chat_with_history(data: UserQueryWithChatHistory, response_model=DescriptionResponse):
"""
Process a user query with chat history and return a response.
Args:
data (UserQueryWithChatHistory): The user query data with chat history.
Returns:
DescriptionResponse: The response to the user query.
Raises:
HTTPException: If the token is invalid or expired.
"""
try: try:
payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM]) payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub") username: str = payload.get("sub")
if username is None: if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired") raise HTTPException(status_code=403, detail="Token is invalid or expired")
if(IS_LOCAL_SETUP == 'true'): if IS_LOCAL_SETUP == 'true':
llm = OllamaLLM(model="mistral-nemo",temperature=0) llm = OllamaLLM(model="mistral-nemo", temperature=0)
else: else:
llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", api_key=data.openaikey) llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", api_key=data.openaikey)
chatHistory = [] chatHistory = []
for chat in data.chat: for chat in data.chat:
if(chat.type == 'system'): if chat.type == 'system':
chatHistory.append(SystemMessage(content=DATE_TODAY + """You are an helpful assistant for question-answering tasks. chatHistory.append(SystemMessage(content=DATE_TODAY + """You are a helpful assistant for question-answering tasks.
Use the following pieces of retrieved context to answer the question. Use the following pieces of retrieved context to answer the question.
If you don't know the answer, just say that you don't know. If you don't know the answer, just say that you don't know.
Context:""" + str(chat.content))) Context:""" + str(chat.content)))
if(chat.type == 'ai'): if chat.type == 'ai':
chatHistory.append(AIMessage(content=chat.content)) chatHistory.append(AIMessage(content=chat.content))
if(chat.type == 'human'): if chat.type == 'human':
chatHistory.append(HumanMessage(content=chat.content)) chatHistory.append(HumanMessage(content=chat.content))
chatHistory.append(("human", "{input}")); chatHistory.append(("human", "{input}"))
qa_prompt = ChatPromptTemplate.from_messages(chatHistory) qa_prompt = ChatPromptTemplate.from_messages(chatHistory)
descriptionchain = qa_prompt | llm descriptionchain = qa_prompt | llm
response = descriptionchain.invoke({"input": data.query}) response = descriptionchain.invoke({"input": data.query})
if(IS_LOCAL_SETUP == 'true'): if IS_LOCAL_SETUP == 'true':
return DescriptionResponse(response=response) return DescriptionResponse(response=response)
else: else:
return DescriptionResponse(response=response.content) return DescriptionResponse(response=response.content)
@ -262,23 +288,33 @@ def doc_chat_with_history(data: UserQueryWithChatHistory, response_model=Descrip
except JWTError: except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired") raise HTTPException(status_code=403, detail="Token is invalid or expired")
# Multi DOC Chat
@app.post("/delete/docs") @app.post("/delete/docs")
def delete_all_related_data(data: DocumentsToDelete, db: Session = Depends(get_db)): def delete_all_related_data(data: DocumentsToDelete, db: Session = Depends(get_db)):
"""
Delete documents and related data.
Args:
data (DocumentsToDelete): The data containing documents to delete.
db (Session): The database session.
Returns:
dict: A message indicating the result of the deletion.
Raises:
HTTPException: If the token is invalid or expired.
"""
try: try:
payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM]) payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub") username: str = payload.get("sub")
if username is None: if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired") raise HTTPException(status_code=403, detail="Token is invalid or expired")
if(IS_LOCAL_SETUP == 'true'): if IS_LOCAL_SETUP == 'true':
index = HIndices(username=username) index = HIndices(username=username)
else: else:
index = HIndices(username=username,api_key=data.openaikey) index = HIndices(username=username, api_key=data.openaikey)
message = index.delete_vector_stores(summary_ids_to_delete=data.ids_to_delete,db=db ) message = index.delete_vector_stores(summary_ids_to_delete=data.ids_to_delete, db=db)
return { return {
"message": message "message": message
@ -287,19 +323,12 @@ def delete_all_related_data(data: DocumentsToDelete, db: Session = Depends(get_d
except JWTError: except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired") raise HTTPException(status_code=403, detail="Token is invalid or expired")
# AUTH CODE
#AUTH CODE
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# Manual Origins
# origins = [
# "http://localhost:3000", # Adjust the port if your frontend runs on a different one
# "https://yourfrontenddomain.com",
# ]
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=["*"], # Allows all origins from the list allow_origins=["*"], # Allows all origins
allow_credentials=True, allow_credentials=True,
allow_methods=["*"], # Allows all methods allow_methods=["*"], # Allows all methods
allow_headers=["*"], # Allows all headers allow_headers=["*"], # Allows all headers
@ -308,9 +337,29 @@ app.add_middleware(
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def get_user_by_username(db: Session, username: str): def get_user_by_username(db: Session, username: str):
"""
Retrieve a user by username from the database.
Args:
db (Session): The database session.
username (str): The username to search for.
Returns:
User: The user object if found, None otherwise.
"""
return db.query(User).filter(User.username == username).first() return db.query(User).filter(User.username == username).first()
def create_user(db: Session, user: UserCreate): def create_user(db: Session, user: UserCreate):
"""
Create a new user in the database.
Args:
db (Session): The database session.
user (UserCreate): The user data to create.
Returns:
str: A message indicating the completion of user creation.
"""
hashed_password = pwd_context.hash(user.password) hashed_password = pwd_context.hash(user.password)
db_user = User(username=user.username, hashed_password=hashed_password) db_user = User(username=user.username, hashed_password=hashed_password)
db.add(db_user) db.add(db_user)
@ -319,7 +368,20 @@ def create_user(db: Session, user: UserCreate):
@app.post("/register") @app.post("/register")
def register_user(user: UserCreate, db: Session = Depends(get_db)): def register_user(user: UserCreate, db: Session = Depends(get_db)):
if(user.apisecretkey != API_SECRET_KEY): """
Register a new user.
Args:
user (UserCreate): The user data for registration.
db (Session): The database session.
Returns:
str: A message indicating the completion of user registration.
Raises:
HTTPException: If the API secret key is invalid or the username is already registered.
"""
if user.apisecretkey != API_SECRET_KEY:
raise HTTPException(status_code=401, detail="Unauthorized") raise HTTPException(status_code=401, detail="Unauthorized")
db_user = get_user_by_username(db, username=user.username) db_user = get_user_by_username(db, username=user.username)
@ -329,8 +391,18 @@ def register_user(user: UserCreate, db: Session = Depends(get_db)):
del user.apisecretkey del user.apisecretkey
return create_user(db=db, user=user) return create_user(db=db, user=user)
# Authenticate the user
def authenticate_user(username: str, password: str, db: Session): def authenticate_user(username: str, password: str, db: Session):
"""
Authenticate a user.
Args:
username (str): The username of the user.
password (str): The password of the user.
db (Session): The database session.
Returns:
User: The authenticated user object if successful, False otherwise.
"""
user = db.query(User).filter(User.username == username).first() user = db.query(User).filter(User.username == username).first()
if not user: if not user:
return False return False
@ -338,8 +410,17 @@ def authenticate_user(username: str, password: str, db: Session):
return False return False
return user return user
# Create access token
def create_access_token(data: dict, expires_delta: timedelta | None = None): def create_access_token(data: dict, expires_delta: timedelta | None = None):
"""
Create an access token.
Args:
data (dict): The data to encode in the token.
expires_delta (timedelta, optional): The expiration time for the token.
Returns:
str: The encoded JWT token.
"""
to_encode = data.copy() to_encode = data.copy()
if expires_delta: if expires_delta:
expire = datetime.utcnow() + expires_delta expire = datetime.utcnow() + expires_delta
@ -351,6 +432,19 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None):
@app.post("/token") @app.post("/token")
def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)): def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
"""
Authenticate a user and return an access token.
Args:
form_data (OAuth2PasswordRequestForm): The form data containing username and password.
db (Session): The database session.
Returns:
dict: A dictionary containing the access token and token type.
Raises:
HTTPException: If the username or password is incorrect.
"""
user = authenticate_user(form_data.username, form_data.password, db) user = authenticate_user(form_data.username, form_data.password, db)
if not user: if not user:
raise HTTPException( raise HTTPException(
@ -365,6 +459,18 @@ def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db:
return {"access_token": access_token, "token_type": "bearer"} return {"access_token": access_token, "token_type": "bearer"}
def verify_token(token: str = Depends(oauth2_scheme)): def verify_token(token: str = Depends(oauth2_scheme)):
"""
Verify the validity of a token.
Args:
token (str): The token to verify.
Returns:
dict: The payload of the token if valid.
Raises:
HTTPException: If the token is invalid or expired.
"""
try: try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub") username: str = payload.get("sub")
@ -376,11 +482,33 @@ def verify_token(token: str = Depends(oauth2_scheme)):
@app.get("/verify-token/{token}") @app.get("/verify-token/{token}")
async def verify_user_token(token: str): async def verify_user_token(token: str):
"""
Verify a user token.
Args:
token (str): The token to verify.
Returns:
dict: A message indicating the validity of the token.
"""
verify_token(token=token) verify_token(token=token)
return {"message": "Token is valid"} return {"message": "Token is valid"}
@app.post("/user/chat/save") @app.post("/user/chat/save")
def populate_user_chat(chat: NewUserChat, db: Session = Depends(get_db)): def populate_user_chat(chat: NewUserChat, db: Session = Depends(get_db)):
"""
Save a new chat for a user.
Args:
chat (NewUserChat): The chat data to save.
db (Session): The database session.
Returns:
dict: A message indicating the success of the operation.
Raises:
HTTPException: If the token is invalid or expired.
"""
try: try:
payload = jwt.decode(chat.token, SECRET_KEY, algorithms=[ALGORITHM]) payload = jwt.decode(chat.token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub") username: str = payload.get("sub")
@ -400,6 +528,19 @@ def populate_user_chat(chat: NewUserChat, db: Session = Depends(get_db)):
@app.post("/user/chat/update") @app.post("/user/chat/update")
def populate_user_chat(chat: ChatToUpdate, db: Session = Depends(get_db)): def populate_user_chat(chat: ChatToUpdate, db: Session = Depends(get_db)):
"""
Update an existing chat for a user.
Args:
chat (ChatToUpdate): The chat data to update.
db (Session): The database session.
Returns:
dict: A message indicating the success of the operation.
Raises:
HTTPException: If the token is invalid or expired.
"""
try: try:
payload = jwt.decode(chat.token, SECRET_KEY, algorithms=[ALGORITHM]) payload = jwt.decode(chat.token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub") username: str = payload.get("sub")
@ -418,6 +559,20 @@ def populate_user_chat(chat: ChatToUpdate, db: Session = Depends(get_db)):
@app.get("/user/chat/delete/{token}/{chatid}") @app.get("/user/chat/delete/{token}/{chatid}")
async def delete_chat_of_user(token: str, chatid: str, db: Session = Depends(get_db)): async def delete_chat_of_user(token: str, chatid: str, db: Session = Depends(get_db)):
"""
Delete a chat for a user.
Args:
token (str): The user's authentication token.
chatid (str): The ID of the chat to delete.
db (Session): The database session.
Returns:
dict: A message indicating the success of the operation.
Raises:
HTTPException: If the token is invalid or expired.
"""
try: try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub") username: str = payload.get("sub")
@ -433,9 +588,21 @@ async def delete_chat_of_user(token: str, chatid: str, db: Session = Depends(get
except JWTError: except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired") raise HTTPException(status_code=403, detail="Token is invalid or expired")
#Gets user id & name
@app.get("/user/{token}") @app.get("/user/{token}")
async def get_user_with_token(token: str, db: Session = Depends(get_db)): async def get_user_with_token(token: str, db: Session = Depends(get_db)):
"""
Get user information using a token.
Args:
token (str): The user's authentication token.
db (Session): The database session.
Returns:
dict: User information including ID, username, chats, and documents.
Raises:
HTTPException: If the token is invalid or expired.
"""
try: try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub") username: str = payload.get("sub")
@ -454,6 +621,19 @@ async def get_user_with_token(token: str, db: Session = Depends(get_db)):
@app.get("/searchspaces/{token}") @app.get("/searchspaces/{token}")
async def get_user_with_token(token: str, db: Session = Depends(get_db)): async def get_user_with_token(token: str, db: Session = Depends(get_db)):
"""
Get all search spaces.
Args:
token (str): The user's authentication token.
db (Session): The database session.
Returns:
dict: A dictionary containing all search spaces.
Raises:
HTTPException: If the token is invalid or expired.
"""
try: try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub") username: str = payload.get("sub")
@ -467,12 +647,7 @@ async def get_user_with_token(token: str, db: Session = Depends(get_db)):
except JWTError: except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired") raise HTTPException(status_code=403, detail="Token is invalid or expired")
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn import uvicorn
uvicorn.run(app, host="127.0.0.1", port=8000) uvicorn.run(app, host="127.0.0.1", port=8000)

View file

@ -12,6 +12,8 @@ services:
- postgres_data:/var/lib/postgresql/data - postgres_data:/var/lib/postgresql/data
networks: networks:
- surfsense-network - surfsense-network
ports:
- "5432:5432"
# Backend Service (FastAPI) # Backend Service (FastAPI)
backend: backend:
@ -23,6 +25,7 @@ services:
- ./backend/.env - ./backend/.env
depends_on: depends_on:
- db - db
privileged: true
networks: networks:
- surfsense-network - surfsense-network