mirror of
https://github.com/MODSetter/SurfSense.git
synced 2025-09-01 10:09:08 +00:00
Docker Support
This commit is contained in:
parent
778c16f8bf
commit
d247cacaa9
5 changed files with 267 additions and 79 deletions
|
@ -1 +1 @@
|
|||
Subproject commit aeb87beee51bbf040307157d6c54b3a3a82ef620
|
||||
Subproject commit 4d0d1414c4774704619834ec8e9280940976f500
|
|
@ -1,7 +1,7 @@
|
|||
#true if you wana run local setup with Ollama llama3.1
|
||||
IS_LOCAL_SETUP = 'false'
|
||||
|
||||
#POSTGRES DB TO TRACK USERS
|
||||
#POSTGRES DB TO TRACK USERS | replace localhost with db for docker setups eg "postgresql+psycopg2://postgres:postgres@db:5432/surfsense"
|
||||
POSTGRES_DATABASE_URL = "postgresql+psycopg2://postgres:postgres@localhost:5432/surfsense"
|
||||
|
||||
# API KEY TO PREVENT USER REGISTRATION SPAM
|
||||
|
|
|
@ -23,16 +23,26 @@ IS_LOCAL_SETUP = os.environ.get("IS_LOCAL_SETUP")
|
|||
|
||||
# Dependency
|
||||
def get_db():
|
||||
"""
|
||||
Dependency to get a database session.
|
||||
|
||||
Yields:
|
||||
Session: A SQLAlchemy database session.
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
class HIndices:
|
||||
def __init__(self, username, api_key='local'):
|
||||
"""
|
||||
Initialize the HIndices object.
|
||||
|
||||
Args:
|
||||
username (str): The username for the indices.
|
||||
api_key (str, optional): API key for non-local setups. Defaults to 'local'.
|
||||
"""
|
||||
self.username = username
|
||||
if(IS_LOCAL_SETUP == 'true'):
|
||||
|
|
|
@ -8,7 +8,7 @@ from prompts import CONTEXT_ANSWER_PROMPT, DATE_TODAY, SUBQUERY_DECOMPOSITION_PR
|
|||
from pydmodels import ChatToUpdate, DescriptionResponse, DocWithContent, DocumentsToDelete, NewUserChat, UserCreate, UserQuery, RetrivedDocList, UserQueryResponse, UserQueryWithChatHistory
|
||||
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
|
||||
|
||||
#Heirerical Indices class
|
||||
# Hierarchical Indices class
|
||||
from HIndices import HIndices
|
||||
|
||||
from Utils.stringify import stringify
|
||||
|
@ -24,11 +24,11 @@ from models import Chat, Documents, SearchSpace, User
|
|||
from database import SessionLocal
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
# Environment variables
|
||||
IS_LOCAL_SETUP = os.environ.get("IS_LOCAL_SETUP")
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.environ.get("ACCESS_TOKEN_EXPIRE_MINUTES"))
|
||||
ALGORITHM = os.environ.get("ALGORITHM")
|
||||
|
@ -37,17 +37,33 @@ SECRET_KEY = os.environ.get("SECRET_KEY")
|
|||
|
||||
app = FastAPI()
|
||||
|
||||
# Dependency
|
||||
def get_db():
|
||||
"""
|
||||
Dependency to get a database session.
|
||||
|
||||
Yields:
|
||||
Session: A SQLAlchemy database session.
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@app.post("/chat/")
|
||||
def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
|
||||
"""
|
||||
Process a user query and return a response.
|
||||
|
||||
Args:
|
||||
data (UserQuery): The user query data.
|
||||
|
||||
Returns:
|
||||
UserQueryResponse: The response to the user query.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the token is invalid or expired.
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
|
@ -57,15 +73,14 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
|
|||
query = data.query
|
||||
search_space = data.search_space
|
||||
|
||||
if(IS_LOCAL_SETUP == 'true'):
|
||||
sub_query_llm = OllamaLLM(model="mistral-nemo",temperature=0)
|
||||
qa_llm = OllamaLLM(model="mistral-nemo",temperature=0)
|
||||
# Initialize LLM based on setup
|
||||
if IS_LOCAL_SETUP == 'true':
|
||||
sub_query_llm = OllamaLLM(model="mistral-nemo", temperature=0)
|
||||
qa_llm = OllamaLLM(model="mistral-nemo", temperature=0)
|
||||
else:
|
||||
sub_query_llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", api_key=data.openaikey)
|
||||
qa_llm = ChatOpenAI(temperature=0.5, model_name="gpt-4o-mini", api_key=data.openaikey)
|
||||
|
||||
|
||||
|
||||
# Create an LLMChain for sub-query decomposition
|
||||
subquery_decomposer_chain = SUBQUERY_DECOMPOSITION_PROMT | sub_query_llm
|
||||
|
||||
|
@ -74,12 +89,12 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
|
|||
Decompose the original query into simpler sub-queries.
|
||||
|
||||
Args:
|
||||
original_query (str): The original complex query
|
||||
original_query (str): The original complex query
|
||||
|
||||
Returns:
|
||||
List[str]: A list of simpler sub-queries
|
||||
List[str]: A list of simpler sub-queries
|
||||
"""
|
||||
if(IS_LOCAL_SETUP == 'true'):
|
||||
if IS_LOCAL_SETUP == 'true':
|
||||
response = subquery_decomposer_chain.invoke(original_query)
|
||||
else:
|
||||
response = subquery_decomposer_chain.invoke(original_query).content
|
||||
|
@ -87,20 +102,16 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
|
|||
sub_queries = [q.strip() for q in response.split('\n') if q.strip() and not q.strip().startswith('Sub-queries:')]
|
||||
return sub_queries
|
||||
|
||||
|
||||
# Create Heirarical Indecices
|
||||
if(IS_LOCAL_SETUP == 'true'):
|
||||
# Create Hierarchical Indices
|
||||
if IS_LOCAL_SETUP == 'true':
|
||||
index = HIndices(username=username)
|
||||
else:
|
||||
index = HIndices(username=username,api_key=data.openaikey)
|
||||
|
||||
|
||||
index = HIndices(username=username, api_key=data.openaikey)
|
||||
|
||||
# For Those Who Want HyDe Questions
|
||||
# For those who want HyDE questions
|
||||
# sub_queries = decompose_query(query)
|
||||
|
||||
sub_queries = []
|
||||
sub_queries.append(query)
|
||||
sub_queries = [query]
|
||||
|
||||
duplicate_related_summary_docs = []
|
||||
context_to_answer = ""
|
||||
|
@ -112,12 +123,11 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
|
|||
|
||||
duplicate_related_summary_docs.extend(related_summary_docs)
|
||||
|
||||
|
||||
# Remove duplicate documents
|
||||
combined_docs_seen_metadata = set()
|
||||
combined_docs_unique_documents = []
|
||||
|
||||
for doc in duplicate_related_summary_docs:
|
||||
# Convert metadata to a tuple of its items (this allows it to be added to a set)
|
||||
doc.metadata['relevance_score'] = 0.0
|
||||
metadata_tuple = tuple(sorted(doc.metadata.items()))
|
||||
if metadata_tuple not in combined_docs_seen_metadata:
|
||||
|
@ -134,29 +144,37 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
|
|||
VisitedWebPageDateWithTimeInISOString=doc.metadata['VisitedWebPageDateWithTimeInISOString'],
|
||||
VisitedWebPageReffererURL=doc.metadata['VisitedWebPageReffererURL'],
|
||||
VisitedWebPageVisitDurationInMilliseconds=doc.metadata['VisitedWebPageVisitDurationInMilliseconds'],
|
||||
)
|
||||
)
|
||||
|
||||
returnDocs.append(entry)
|
||||
|
||||
|
||||
# Generate final answer
|
||||
ans_chain = CONTEXT_ANSWER_PROMPT | qa_llm
|
||||
|
||||
finalans = ans_chain.invoke({"query": query, "context": context_to_answer})
|
||||
|
||||
|
||||
if(IS_LOCAL_SETUP == 'true'):
|
||||
if IS_LOCAL_SETUP == 'true':
|
||||
return UserQueryResponse(response=finalans, relateddocs=returnDocs)
|
||||
else:
|
||||
return UserQueryResponse(response=finalans.content, relateddocs=returnDocs)
|
||||
|
||||
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
# SAVE DOCS
|
||||
@app.post("/save/")
|
||||
def save_data(apires: RetrivedDocList, db: Session = Depends(get_db)):
|
||||
"""
|
||||
Save retrieved documents to the database and vector stores.
|
||||
|
||||
Args:
|
||||
apires (RetrivedDocList): The list of documents to save.
|
||||
db (Session): The database session.
|
||||
|
||||
Returns:
|
||||
dict: A message indicating the success of the operation.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the token is invalid or expired.
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(apires.token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
|
@ -181,31 +199,29 @@ def save_data(apires: RetrivedDocList, db: Session = Depends(get_db)):
|
|||
content += f"===================================================================================== \n"
|
||||
content += f"Webpage Content of the visited webpage url in markdown format : \n\n {doc.pageContent} \n\n"
|
||||
content += f"===================================================================================== \n"
|
||||
raw_documents.append(Document(page_content=content,metadata=doc.metadata.__dict__))
|
||||
raw_documents.append(Document(page_content=content, metadata=doc.metadata.__dict__))
|
||||
|
||||
pgdocmeta = stringify(doc.metadata.__dict__)
|
||||
|
||||
if(searchspace):
|
||||
DocumentPgEntry.append(Documents(file_type='WEBPAGE',title=doc.metadata.VisitedWebPageTitle,search_space=searchspace, document_metadata=pgdocmeta, page_content=content))
|
||||
if searchspace:
|
||||
DocumentPgEntry.append(Documents(file_type='WEBPAGE', title=doc.metadata.VisitedWebPageTitle, search_space=searchspace, document_metadata=pgdocmeta, page_content=content))
|
||||
else:
|
||||
DocumentPgEntry.append(Documents(file_type='WEBPAGE',title=doc.metadata.VisitedWebPageTitle,search_space=SearchSpace(search_space=apires.search_space.upper()), document_metadata=pgdocmeta, page_content=content))
|
||||
DocumentPgEntry.append(Documents(file_type='WEBPAGE', title=doc.metadata.VisitedWebPageTitle, search_space=SearchSpace(search_space=apires.search_space.upper()), document_metadata=pgdocmeta, page_content=content))
|
||||
|
||||
|
||||
#Save docs in PG
|
||||
# Save docs in PG
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
user.documents.extend(DocumentPgEntry)
|
||||
|
||||
db.commit()
|
||||
|
||||
|
||||
# Create Heirarical Indecices
|
||||
if(IS_LOCAL_SETUP == 'true'):
|
||||
# Create Hierarchical Indices
|
||||
if IS_LOCAL_SETUP == 'true':
|
||||
index = HIndices(username=username)
|
||||
else:
|
||||
index = HIndices(username=username,api_key=apires.openaikey)
|
||||
index = HIndices(username=username, api_key=apires.openaikey)
|
||||
|
||||
#Save Indices in vector Stores
|
||||
index.encode_docs_hierarchical(documents=raw_documents, files_type='WEBPAGE',search_space=apires.search_space.upper(), db=db)
|
||||
# Save Indices in vector stores
|
||||
index.encode_docs_hierarchical(documents=raw_documents, files_type='WEBPAGE', search_space=apires.search_space.upper(), db=db)
|
||||
|
||||
print("FINISHED")
|
||||
|
||||
|
@ -216,45 +232,55 @@ def save_data(apires: RetrivedDocList, db: Session = Depends(get_db)):
|
|||
except JWTError:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
# Multi DOC Chat
|
||||
@app.post("/chat/docs")
|
||||
def doc_chat_with_history(data: UserQueryWithChatHistory, response_model=DescriptionResponse):
|
||||
"""
|
||||
Process a user query with chat history and return a response.
|
||||
|
||||
Args:
|
||||
data (UserQueryWithChatHistory): The user query data with chat history.
|
||||
|
||||
Returns:
|
||||
DescriptionResponse: The response to the user query.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the token is invalid or expired.
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
if(IS_LOCAL_SETUP == 'true'):
|
||||
llm = OllamaLLM(model="mistral-nemo",temperature=0)
|
||||
if IS_LOCAL_SETUP == 'true':
|
||||
llm = OllamaLLM(model="mistral-nemo", temperature=0)
|
||||
else:
|
||||
llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", api_key=data.openaikey)
|
||||
|
||||
chatHistory = []
|
||||
|
||||
for chat in data.chat:
|
||||
if(chat.type == 'system'):
|
||||
chatHistory.append(SystemMessage(content=DATE_TODAY + """You are an helpful assistant for question-answering tasks.
|
||||
if chat.type == 'system':
|
||||
chatHistory.append(SystemMessage(content=DATE_TODAY + """You are a helpful assistant for question-answering tasks.
|
||||
Use the following pieces of retrieved context to answer the question.
|
||||
If you don't know the answer, just say that you don't know.
|
||||
Context:""" + str(chat.content)))
|
||||
|
||||
if(chat.type == 'ai'):
|
||||
if chat.type == 'ai':
|
||||
chatHistory.append(AIMessage(content=chat.content))
|
||||
|
||||
if(chat.type == 'human'):
|
||||
if chat.type == 'human':
|
||||
chatHistory.append(HumanMessage(content=chat.content))
|
||||
|
||||
chatHistory.append(("human", "{input}"));
|
||||
chatHistory.append(("human", "{input}"))
|
||||
|
||||
|
||||
qa_prompt = ChatPromptTemplate.from_messages(chatHistory)
|
||||
|
||||
descriptionchain = qa_prompt | llm
|
||||
|
||||
response = descriptionchain.invoke({"input": data.query})
|
||||
|
||||
if(IS_LOCAL_SETUP == 'true'):
|
||||
if IS_LOCAL_SETUP == 'true':
|
||||
return DescriptionResponse(response=response)
|
||||
else:
|
||||
return DescriptionResponse(response=response.content)
|
||||
|
@ -262,23 +288,33 @@ def doc_chat_with_history(data: UserQueryWithChatHistory, response_model=Descrip
|
|||
except JWTError:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
|
||||
# Multi DOC Chat
|
||||
|
||||
@app.post("/delete/docs")
|
||||
def delete_all_related_data(data: DocumentsToDelete, db: Session = Depends(get_db)):
|
||||
"""
|
||||
Delete documents and related data.
|
||||
|
||||
Args:
|
||||
data (DocumentsToDelete): The data containing documents to delete.
|
||||
db (Session): The database session.
|
||||
|
||||
Returns:
|
||||
dict: A message indicating the result of the deletion.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the token is invalid or expired.
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
if(IS_LOCAL_SETUP == 'true'):
|
||||
if IS_LOCAL_SETUP == 'true':
|
||||
index = HIndices(username=username)
|
||||
else:
|
||||
index = HIndices(username=username,api_key=data.openaikey)
|
||||
index = HIndices(username=username, api_key=data.openaikey)
|
||||
|
||||
message = index.delete_vector_stores(summary_ids_to_delete=data.ids_to_delete,db=db )
|
||||
message = index.delete_vector_stores(summary_ids_to_delete=data.ids_to_delete, db=db)
|
||||
|
||||
return {
|
||||
"message": message
|
||||
|
@ -287,19 +323,12 @@ def delete_all_related_data(data: DocumentsToDelete, db: Session = Depends(get_d
|
|||
except JWTError:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
|
||||
#AUTH CODE
|
||||
# AUTH CODE
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
# Manual Origins
|
||||
# origins = [
|
||||
# "http://localhost:3000", # Adjust the port if your frontend runs on a different one
|
||||
# "https://yourfrontenddomain.com",
|
||||
# ]
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # Allows all origins from the list
|
||||
allow_origins=["*"], # Allows all origins
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"], # Allows all methods
|
||||
allow_headers=["*"], # Allows all headers
|
||||
|
@ -308,9 +337,29 @@ app.add_middleware(
|
|||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
def get_user_by_username(db: Session, username: str):
|
||||
"""
|
||||
Retrieve a user by username from the database.
|
||||
|
||||
Args:
|
||||
db (Session): The database session.
|
||||
username (str): The username to search for.
|
||||
|
||||
Returns:
|
||||
User: The user object if found, None otherwise.
|
||||
"""
|
||||
return db.query(User).filter(User.username == username).first()
|
||||
|
||||
def create_user(db: Session, user: UserCreate):
|
||||
"""
|
||||
Create a new user in the database.
|
||||
|
||||
Args:
|
||||
db (Session): The database session.
|
||||
user (UserCreate): The user data to create.
|
||||
|
||||
Returns:
|
||||
str: A message indicating the completion of user creation.
|
||||
"""
|
||||
hashed_password = pwd_context.hash(user.password)
|
||||
db_user = User(username=user.username, hashed_password=hashed_password)
|
||||
db.add(db_user)
|
||||
|
@ -319,7 +368,20 @@ def create_user(db: Session, user: UserCreate):
|
|||
|
||||
@app.post("/register")
|
||||
def register_user(user: UserCreate, db: Session = Depends(get_db)):
|
||||
if(user.apisecretkey != API_SECRET_KEY):
|
||||
"""
|
||||
Register a new user.
|
||||
|
||||
Args:
|
||||
user (UserCreate): The user data for registration.
|
||||
db (Session): The database session.
|
||||
|
||||
Returns:
|
||||
str: A message indicating the completion of user registration.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the API secret key is invalid or the username is already registered.
|
||||
"""
|
||||
if user.apisecretkey != API_SECRET_KEY:
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
db_user = get_user_by_username(db, username=user.username)
|
||||
|
@ -329,8 +391,18 @@ def register_user(user: UserCreate, db: Session = Depends(get_db)):
|
|||
del user.apisecretkey
|
||||
return create_user(db=db, user=user)
|
||||
|
||||
# Authenticate the user
|
||||
def authenticate_user(username: str, password: str, db: Session):
|
||||
"""
|
||||
Authenticate a user.
|
||||
|
||||
Args:
|
||||
username (str): The username of the user.
|
||||
password (str): The password of the user.
|
||||
db (Session): The database session.
|
||||
|
||||
Returns:
|
||||
User: The authenticated user object if successful, False otherwise.
|
||||
"""
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if not user:
|
||||
return False
|
||||
|
@ -338,8 +410,17 @@ def authenticate_user(username: str, password: str, db: Session):
|
|||
return False
|
||||
return user
|
||||
|
||||
# Create access token
|
||||
def create_access_token(data: dict, expires_delta: timedelta | None = None):
|
||||
"""
|
||||
Create an access token.
|
||||
|
||||
Args:
|
||||
data (dict): The data to encode in the token.
|
||||
expires_delta (timedelta, optional): The expiration time for the token.
|
||||
|
||||
Returns:
|
||||
str: The encoded JWT token.
|
||||
"""
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
|
@ -351,6 +432,19 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None):
|
|||
|
||||
@app.post("/token")
|
||||
def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
|
||||
"""
|
||||
Authenticate a user and return an access token.
|
||||
|
||||
Args:
|
||||
form_data (OAuth2PasswordRequestForm): The form data containing username and password.
|
||||
db (Session): The database session.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the access token and token type.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the username or password is incorrect.
|
||||
"""
|
||||
user = authenticate_user(form_data.username, form_data.password, db)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
|
@ -365,6 +459,18 @@ def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db:
|
|||
return {"access_token": access_token, "token_type": "bearer"}
|
||||
|
||||
def verify_token(token: str = Depends(oauth2_scheme)):
|
||||
"""
|
||||
Verify the validity of a token.
|
||||
|
||||
Args:
|
||||
token (str): The token to verify.
|
||||
|
||||
Returns:
|
||||
dict: The payload of the token if valid.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the token is invalid or expired.
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
|
@ -376,11 +482,33 @@ def verify_token(token: str = Depends(oauth2_scheme)):
|
|||
|
||||
@app.get("/verify-token/{token}")
|
||||
async def verify_user_token(token: str):
|
||||
"""
|
||||
Verify a user token.
|
||||
|
||||
Args:
|
||||
token (str): The token to verify.
|
||||
|
||||
Returns:
|
||||
dict: A message indicating the validity of the token.
|
||||
"""
|
||||
verify_token(token=token)
|
||||
return {"message": "Token is valid"}
|
||||
|
||||
@app.post("/user/chat/save")
|
||||
def populate_user_chat(chat: NewUserChat, db: Session = Depends(get_db)):
|
||||
"""
|
||||
Save a new chat for a user.
|
||||
|
||||
Args:
|
||||
chat (NewUserChat): The chat data to save.
|
||||
db (Session): The database session.
|
||||
|
||||
Returns:
|
||||
dict: A message indicating the success of the operation.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the token is invalid or expired.
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(chat.token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
|
@ -400,6 +528,19 @@ def populate_user_chat(chat: NewUserChat, db: Session = Depends(get_db)):
|
|||
|
||||
@app.post("/user/chat/update")
|
||||
def populate_user_chat(chat: ChatToUpdate, db: Session = Depends(get_db)):
|
||||
"""
|
||||
Update an existing chat for a user.
|
||||
|
||||
Args:
|
||||
chat (ChatToUpdate): The chat data to update.
|
||||
db (Session): The database session.
|
||||
|
||||
Returns:
|
||||
dict: A message indicating the success of the operation.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the token is invalid or expired.
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(chat.token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
|
@ -418,6 +559,20 @@ def populate_user_chat(chat: ChatToUpdate, db: Session = Depends(get_db)):
|
|||
|
||||
@app.get("/user/chat/delete/{token}/{chatid}")
|
||||
async def delete_chat_of_user(token: str, chatid: str, db: Session = Depends(get_db)):
|
||||
"""
|
||||
Delete a chat for a user.
|
||||
|
||||
Args:
|
||||
token (str): The user's authentication token.
|
||||
chatid (str): The ID of the chat to delete.
|
||||
db (Session): The database session.
|
||||
|
||||
Returns:
|
||||
dict: A message indicating the success of the operation.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the token is invalid or expired.
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
|
@ -433,9 +588,21 @@ async def delete_chat_of_user(token: str, chatid: str, db: Session = Depends(get
|
|||
except JWTError:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
#Gets user id & name
|
||||
@app.get("/user/{token}")
|
||||
async def get_user_with_token(token: str, db: Session = Depends(get_db)):
|
||||
"""
|
||||
Get user information using a token.
|
||||
|
||||
Args:
|
||||
token (str): The user's authentication token.
|
||||
db (Session): The database session.
|
||||
|
||||
Returns:
|
||||
dict: User information including ID, username, chats, and documents.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the token is invalid or expired.
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
|
@ -454,6 +621,19 @@ async def get_user_with_token(token: str, db: Session = Depends(get_db)):
|
|||
|
||||
@app.get("/searchspaces/{token}")
|
||||
async def get_user_with_token(token: str, db: Session = Depends(get_db)):
|
||||
"""
|
||||
Get all search spaces.
|
||||
|
||||
Args:
|
||||
token (str): The user's authentication token.
|
||||
db (Session): The database session.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing all search spaces.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the token is invalid or expired.
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
|
@ -467,12 +647,7 @@ async def get_user_with_token(token: str, db: Session = Depends(get_db)):
|
|||
except JWTError:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="127.0.0.1", port=8000)
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -12,6 +12,8 @@ services:
|
|||
- postgres_data:/var/lib/postgresql/data
|
||||
networks:
|
||||
- surfsense-network
|
||||
ports:
|
||||
- "5432:5432"
|
||||
|
||||
# Backend Service (FastAPI)
|
||||
backend:
|
||||
|
@ -23,6 +25,7 @@ services:
|
|||
- ./backend/.env
|
||||
depends_on:
|
||||
- db
|
||||
privileged: true
|
||||
networks:
|
||||
- surfsense-network
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue