Docker Support

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2024-09-25 14:39:06 -07:00
parent 778c16f8bf
commit d247cacaa9
5 changed files with 267 additions and 79 deletions

@ -1 +1 @@
Subproject commit aeb87beee51bbf040307157d6c54b3a3a82ef620
Subproject commit 4d0d1414c4774704619834ec8e9280940976f500

View file

@ -1,7 +1,7 @@
#true if you wana run local setup with Ollama llama3.1
IS_LOCAL_SETUP = 'false'
#POSTGRES DB TO TRACK USERS
#POSTGRES DB TO TRACK USERS | replace localhost with db for docker setups eg "postgresql+psycopg2://postgres:postgres@db:5432/surfsense"
POSTGRES_DATABASE_URL = "postgresql+psycopg2://postgres:postgres@localhost:5432/surfsense"
# API KEY TO PREVENT USER REGISTRATION SPAM

View file

@ -23,16 +23,26 @@ IS_LOCAL_SETUP = os.environ.get("IS_LOCAL_SETUP")
# Dependency
def get_db():
"""
Dependency to get a database session.
Yields:
Session: A SQLAlchemy database session.
"""
db = SessionLocal()
try:
yield db
finally:
db.close()
class HIndices:
def __init__(self, username, api_key='local'):
"""
Initialize the HIndices object.
Args:
username (str): The username for the indices.
api_key (str, optional): API key for non-local setups. Defaults to 'local'.
"""
self.username = username
if(IS_LOCAL_SETUP == 'true'):

View file

@ -8,7 +8,7 @@ from prompts import CONTEXT_ANSWER_PROMPT, DATE_TODAY, SUBQUERY_DECOMPOSITION_PR
from pydmodels import ChatToUpdate, DescriptionResponse, DocWithContent, DocumentsToDelete, NewUserChat, UserCreate, UserQuery, RetrivedDocList, UserQueryResponse, UserQueryWithChatHistory
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
#Heirerical Indices class
# Hierarchical Indices class
from HIndices import HIndices
from Utils.stringify import stringify
@ -24,11 +24,11 @@ from models import Chat, Documents, SearchSpace, User
from database import SessionLocal
from fastapi.middleware.cors import CORSMiddleware
import os
from dotenv import load_dotenv
load_dotenv()
# Environment variables
IS_LOCAL_SETUP = os.environ.get("IS_LOCAL_SETUP")
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.environ.get("ACCESS_TOKEN_EXPIRE_MINUTES"))
ALGORITHM = os.environ.get("ALGORITHM")
@ -37,17 +37,33 @@ SECRET_KEY = os.environ.get("SECRET_KEY")
app = FastAPI()
# Dependency
def get_db():
"""
Dependency to get a database session.
Yields:
Session: A SQLAlchemy database session.
"""
db = SessionLocal()
try:
yield db
finally:
db.close()
@app.post("/chat/")
def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
"""
Process a user query and return a response.
Args:
data (UserQuery): The user query data.
Returns:
UserQueryResponse: The response to the user query.
Raises:
HTTPException: If the token is invalid or expired.
"""
try:
payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
@ -57,15 +73,14 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
query = data.query
search_space = data.search_space
if(IS_LOCAL_SETUP == 'true'):
sub_query_llm = OllamaLLM(model="mistral-nemo",temperature=0)
qa_llm = OllamaLLM(model="mistral-nemo",temperature=0)
# Initialize LLM based on setup
if IS_LOCAL_SETUP == 'true':
sub_query_llm = OllamaLLM(model="mistral-nemo", temperature=0)
qa_llm = OllamaLLM(model="mistral-nemo", temperature=0)
else:
sub_query_llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", api_key=data.openaikey)
qa_llm = ChatOpenAI(temperature=0.5, model_name="gpt-4o-mini", api_key=data.openaikey)
# Create an LLMChain for sub-query decomposition
subquery_decomposer_chain = SUBQUERY_DECOMPOSITION_PROMT | sub_query_llm
@ -74,12 +89,12 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
Decompose the original query into simpler sub-queries.
Args:
original_query (str): The original complex query
original_query (str): The original complex query
Returns:
List[str]: A list of simpler sub-queries
List[str]: A list of simpler sub-queries
"""
if(IS_LOCAL_SETUP == 'true'):
if IS_LOCAL_SETUP == 'true':
response = subquery_decomposer_chain.invoke(original_query)
else:
response = subquery_decomposer_chain.invoke(original_query).content
@ -87,20 +102,16 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
sub_queries = [q.strip() for q in response.split('\n') if q.strip() and not q.strip().startswith('Sub-queries:')]
return sub_queries
# Create Heirarical Indecices
if(IS_LOCAL_SETUP == 'true'):
# Create Hierarchical Indices
if IS_LOCAL_SETUP == 'true':
index = HIndices(username=username)
else:
index = HIndices(username=username,api_key=data.openaikey)
index = HIndices(username=username, api_key=data.openaikey)
# For Those Who Want HyDe Questions
# For those who want HyDE questions
# sub_queries = decompose_query(query)
sub_queries = []
sub_queries.append(query)
sub_queries = [query]
duplicate_related_summary_docs = []
context_to_answer = ""
@ -112,12 +123,11 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
duplicate_related_summary_docs.extend(related_summary_docs)
# Remove duplicate documents
combined_docs_seen_metadata = set()
combined_docs_unique_documents = []
for doc in duplicate_related_summary_docs:
# Convert metadata to a tuple of its items (this allows it to be added to a set)
doc.metadata['relevance_score'] = 0.0
metadata_tuple = tuple(sorted(doc.metadata.items()))
if metadata_tuple not in combined_docs_seen_metadata:
@ -134,29 +144,37 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
VisitedWebPageDateWithTimeInISOString=doc.metadata['VisitedWebPageDateWithTimeInISOString'],
VisitedWebPageReffererURL=doc.metadata['VisitedWebPageReffererURL'],
VisitedWebPageVisitDurationInMilliseconds=doc.metadata['VisitedWebPageVisitDurationInMilliseconds'],
)
)
returnDocs.append(entry)
# Generate final answer
ans_chain = CONTEXT_ANSWER_PROMPT | qa_llm
finalans = ans_chain.invoke({"query": query, "context": context_to_answer})
if(IS_LOCAL_SETUP == 'true'):
if IS_LOCAL_SETUP == 'true':
return UserQueryResponse(response=finalans, relateddocs=returnDocs)
else:
return UserQueryResponse(response=finalans.content, relateddocs=returnDocs)
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
# SAVE DOCS
@app.post("/save/")
def save_data(apires: RetrivedDocList, db: Session = Depends(get_db)):
"""
Save retrieved documents to the database and vector stores.
Args:
apires (RetrivedDocList): The list of documents to save.
db (Session): The database session.
Returns:
dict: A message indicating the success of the operation.
Raises:
HTTPException: If the token is invalid or expired.
"""
try:
payload = jwt.decode(apires.token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
@ -181,31 +199,29 @@ def save_data(apires: RetrivedDocList, db: Session = Depends(get_db)):
content += f"===================================================================================== \n"
content += f"Webpage Content of the visited webpage url in markdown format : \n\n {doc.pageContent} \n\n"
content += f"===================================================================================== \n"
raw_documents.append(Document(page_content=content,metadata=doc.metadata.__dict__))
raw_documents.append(Document(page_content=content, metadata=doc.metadata.__dict__))
pgdocmeta = stringify(doc.metadata.__dict__)
if(searchspace):
DocumentPgEntry.append(Documents(file_type='WEBPAGE',title=doc.metadata.VisitedWebPageTitle,search_space=searchspace, document_metadata=pgdocmeta, page_content=content))
if searchspace:
DocumentPgEntry.append(Documents(file_type='WEBPAGE', title=doc.metadata.VisitedWebPageTitle, search_space=searchspace, document_metadata=pgdocmeta, page_content=content))
else:
DocumentPgEntry.append(Documents(file_type='WEBPAGE',title=doc.metadata.VisitedWebPageTitle,search_space=SearchSpace(search_space=apires.search_space.upper()), document_metadata=pgdocmeta, page_content=content))
DocumentPgEntry.append(Documents(file_type='WEBPAGE', title=doc.metadata.VisitedWebPageTitle, search_space=SearchSpace(search_space=apires.search_space.upper()), document_metadata=pgdocmeta, page_content=content))
#Save docs in PG
# Save docs in PG
user = db.query(User).filter(User.username == username).first()
user.documents.extend(DocumentPgEntry)
db.commit()
# Create Heirarical Indecices
if(IS_LOCAL_SETUP == 'true'):
# Create Hierarchical Indices
if IS_LOCAL_SETUP == 'true':
index = HIndices(username=username)
else:
index = HIndices(username=username,api_key=apires.openaikey)
index = HIndices(username=username, api_key=apires.openaikey)
#Save Indices in vector Stores
index.encode_docs_hierarchical(documents=raw_documents, files_type='WEBPAGE',search_space=apires.search_space.upper(), db=db)
# Save Indices in vector stores
index.encode_docs_hierarchical(documents=raw_documents, files_type='WEBPAGE', search_space=apires.search_space.upper(), db=db)
print("FINISHED")
@ -216,45 +232,55 @@ def save_data(apires: RetrivedDocList, db: Session = Depends(get_db)):
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
# Multi DOC Chat
@app.post("/chat/docs")
def doc_chat_with_history(data: UserQueryWithChatHistory, response_model=DescriptionResponse):
"""
Process a user query with chat history and return a response.
Args:
data (UserQueryWithChatHistory): The user query data with chat history.
Returns:
DescriptionResponse: The response to the user query.
Raises:
HTTPException: If the token is invalid or expired.
"""
try:
payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
if(IS_LOCAL_SETUP == 'true'):
llm = OllamaLLM(model="mistral-nemo",temperature=0)
if IS_LOCAL_SETUP == 'true':
llm = OllamaLLM(model="mistral-nemo", temperature=0)
else:
llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", api_key=data.openaikey)
chatHistory = []
for chat in data.chat:
if(chat.type == 'system'):
chatHistory.append(SystemMessage(content=DATE_TODAY + """You are an helpful assistant for question-answering tasks.
if chat.type == 'system':
chatHistory.append(SystemMessage(content=DATE_TODAY + """You are a helpful assistant for question-answering tasks.
Use the following pieces of retrieved context to answer the question.
If you don't know the answer, just say that you don't know.
Context:""" + str(chat.content)))
if(chat.type == 'ai'):
if chat.type == 'ai':
chatHistory.append(AIMessage(content=chat.content))
if(chat.type == 'human'):
if chat.type == 'human':
chatHistory.append(HumanMessage(content=chat.content))
chatHistory.append(("human", "{input}"));
chatHistory.append(("human", "{input}"))
qa_prompt = ChatPromptTemplate.from_messages(chatHistory)
descriptionchain = qa_prompt | llm
response = descriptionchain.invoke({"input": data.query})
if(IS_LOCAL_SETUP == 'true'):
if IS_LOCAL_SETUP == 'true':
return DescriptionResponse(response=response)
else:
return DescriptionResponse(response=response.content)
@ -262,23 +288,33 @@ def doc_chat_with_history(data: UserQueryWithChatHistory, response_model=Descrip
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
# Multi DOC Chat
@app.post("/delete/docs")
def delete_all_related_data(data: DocumentsToDelete, db: Session = Depends(get_db)):
"""
Delete documents and related data.
Args:
data (DocumentsToDelete): The data containing documents to delete.
db (Session): The database session.
Returns:
dict: A message indicating the result of the deletion.
Raises:
HTTPException: If the token is invalid or expired.
"""
try:
payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
if(IS_LOCAL_SETUP == 'true'):
if IS_LOCAL_SETUP == 'true':
index = HIndices(username=username)
else:
index = HIndices(username=username,api_key=data.openaikey)
index = HIndices(username=username, api_key=data.openaikey)
message = index.delete_vector_stores(summary_ids_to_delete=data.ids_to_delete,db=db )
message = index.delete_vector_stores(summary_ids_to_delete=data.ids_to_delete, db=db)
return {
"message": message
@ -287,19 +323,12 @@ def delete_all_related_data(data: DocumentsToDelete, db: Session = Depends(get_d
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
#AUTH CODE
# AUTH CODE
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# Manual Origins
# origins = [
# "http://localhost:3000", # Adjust the port if your frontend runs on a different one
# "https://yourfrontenddomain.com",
# ]
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Allows all origins from the list
allow_origins=["*"], # Allows all origins
allow_credentials=True,
allow_methods=["*"], # Allows all methods
allow_headers=["*"], # Allows all headers
@ -308,9 +337,29 @@ app.add_middleware(
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def get_user_by_username(db: Session, username: str):
"""
Retrieve a user by username from the database.
Args:
db (Session): The database session.
username (str): The username to search for.
Returns:
User: The user object if found, None otherwise.
"""
return db.query(User).filter(User.username == username).first()
def create_user(db: Session, user: UserCreate):
"""
Create a new user in the database.
Args:
db (Session): The database session.
user (UserCreate): The user data to create.
Returns:
str: A message indicating the completion of user creation.
"""
hashed_password = pwd_context.hash(user.password)
db_user = User(username=user.username, hashed_password=hashed_password)
db.add(db_user)
@ -319,7 +368,20 @@ def create_user(db: Session, user: UserCreate):
@app.post("/register")
def register_user(user: UserCreate, db: Session = Depends(get_db)):
if(user.apisecretkey != API_SECRET_KEY):
"""
Register a new user.
Args:
user (UserCreate): The user data for registration.
db (Session): The database session.
Returns:
str: A message indicating the completion of user registration.
Raises:
HTTPException: If the API secret key is invalid or the username is already registered.
"""
if user.apisecretkey != API_SECRET_KEY:
raise HTTPException(status_code=401, detail="Unauthorized")
db_user = get_user_by_username(db, username=user.username)
@ -329,8 +391,18 @@ def register_user(user: UserCreate, db: Session = Depends(get_db)):
del user.apisecretkey
return create_user(db=db, user=user)
# Authenticate the user
def authenticate_user(username: str, password: str, db: Session):
"""
Authenticate a user.
Args:
username (str): The username of the user.
password (str): The password of the user.
db (Session): The database session.
Returns:
User: The authenticated user object if successful, False otherwise.
"""
user = db.query(User).filter(User.username == username).first()
if not user:
return False
@ -338,8 +410,17 @@ def authenticate_user(username: str, password: str, db: Session):
return False
return user
# Create access token
def create_access_token(data: dict, expires_delta: timedelta | None = None):
"""
Create an access token.
Args:
data (dict): The data to encode in the token.
expires_delta (timedelta, optional): The expiration time for the token.
Returns:
str: The encoded JWT token.
"""
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
@ -351,6 +432,19 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None):
@app.post("/token")
def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
"""
Authenticate a user and return an access token.
Args:
form_data (OAuth2PasswordRequestForm): The form data containing username and password.
db (Session): The database session.
Returns:
dict: A dictionary containing the access token and token type.
Raises:
HTTPException: If the username or password is incorrect.
"""
user = authenticate_user(form_data.username, form_data.password, db)
if not user:
raise HTTPException(
@ -365,6 +459,18 @@ def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db:
return {"access_token": access_token, "token_type": "bearer"}
def verify_token(token: str = Depends(oauth2_scheme)):
"""
Verify the validity of a token.
Args:
token (str): The token to verify.
Returns:
dict: The payload of the token if valid.
Raises:
HTTPException: If the token is invalid or expired.
"""
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
@ -376,11 +482,33 @@ def verify_token(token: str = Depends(oauth2_scheme)):
@app.get("/verify-token/{token}")
async def verify_user_token(token: str):
"""
Verify a user token.
Args:
token (str): The token to verify.
Returns:
dict: A message indicating the validity of the token.
"""
verify_token(token=token)
return {"message": "Token is valid"}
@app.post("/user/chat/save")
def populate_user_chat(chat: NewUserChat, db: Session = Depends(get_db)):
"""
Save a new chat for a user.
Args:
chat (NewUserChat): The chat data to save.
db (Session): The database session.
Returns:
dict: A message indicating the success of the operation.
Raises:
HTTPException: If the token is invalid or expired.
"""
try:
payload = jwt.decode(chat.token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
@ -400,6 +528,19 @@ def populate_user_chat(chat: NewUserChat, db: Session = Depends(get_db)):
@app.post("/user/chat/update")
def populate_user_chat(chat: ChatToUpdate, db: Session = Depends(get_db)):
"""
Update an existing chat for a user.
Args:
chat (ChatToUpdate): The chat data to update.
db (Session): The database session.
Returns:
dict: A message indicating the success of the operation.
Raises:
HTTPException: If the token is invalid or expired.
"""
try:
payload = jwt.decode(chat.token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
@ -418,6 +559,20 @@ def populate_user_chat(chat: ChatToUpdate, db: Session = Depends(get_db)):
@app.get("/user/chat/delete/{token}/{chatid}")
async def delete_chat_of_user(token: str, chatid: str, db: Session = Depends(get_db)):
"""
Delete a chat for a user.
Args:
token (str): The user's authentication token.
chatid (str): The ID of the chat to delete.
db (Session): The database session.
Returns:
dict: A message indicating the success of the operation.
Raises:
HTTPException: If the token is invalid or expired.
"""
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
@ -433,9 +588,21 @@ async def delete_chat_of_user(token: str, chatid: str, db: Session = Depends(get
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
#Gets user id & name
@app.get("/user/{token}")
async def get_user_with_token(token: str, db: Session = Depends(get_db)):
"""
Get user information using a token.
Args:
token (str): The user's authentication token.
db (Session): The database session.
Returns:
dict: User information including ID, username, chats, and documents.
Raises:
HTTPException: If the token is invalid or expired.
"""
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
@ -454,6 +621,19 @@ async def get_user_with_token(token: str, db: Session = Depends(get_db)):
@app.get("/searchspaces/{token}")
async def get_user_with_token(token: str, db: Session = Depends(get_db)):
"""
Get all search spaces.
Args:
token (str): The user's authentication token.
db (Session): The database session.
Returns:
dict: A dictionary containing all search spaces.
Raises:
HTTPException: If the token is invalid or expired.
"""
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
@ -467,12 +647,7 @@ async def get_user_with_token(token: str, db: Session = Depends(get_db)):
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="127.0.0.1", port=8000)

View file

@ -12,6 +12,8 @@ services:
- postgres_data:/var/lib/postgresql/data
networks:
- surfsense-network
ports:
- "5432:5432"
# Backend Service (FastAPI)
backend:
@ -23,6 +25,7 @@ services:
- ./backend/.env
depends_on:
- db
privileged: true
networks:
- surfsense-network