feat: File Uploader

DESKTOP-RTLN3BA\$punk 2024-10-07 22:08:53 -07:00
parent c3030f058e
commit 0435cdd177
4 changed files with 292 additions and 114 deletions


@@ -1,4 +1,4 @@
#true if you want to run local setup with Ollama llama3.1
#true if you want to run local setup with Ollama
IS_LOCAL_SETUP = 'false'
#POSTGRES DB TO TRACK USERS
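
Editorial note, not part of this commit: a minimal sketch of how a flag like IS_LOCAL_SETUP is typically read on the Python side. The use of python-dotenv and the variable handling here are assumptions; the project's actual loading code is not shown in this diff.

import os

from dotenv import load_dotenv  # assumption: python-dotenv (or similar) loads the .env file

load_dotenv()
IS_LOCAL_SETUP = os.getenv("IS_LOCAL_SETUP", "false")

if IS_LOCAL_SETUP == 'true':
    pass  # local path: OllamaLLM + OllamaEmbeddings
else:
    pass  # hosted path: ChatOpenAI + OpenAIEmbeddings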


@@ -1,6 +1,6 @@
from langchain_chroma import Chroma
from langchain_ollama import OllamaLLM, OllamaEmbeddings
from langchain_community.vectorstores.utils import filter_complex_metadata
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain.docstore.document import Document
from langchain_experimental.text_splitter import SemanticChunker
@@ -56,22 +56,73 @@ class HIndices:
# self.summary_store_size = len(self.summary_store.get()['documents'])
# self.detailed_store_size = len(self.detailed_store.get()['documents'])
def encode_docs_hierarchical(self, documents, files_type, search_space='GENERAL', db: Session = Depends(get_db)):
"""
Creates and Saves/Updates docs in hierarchical indices and postgres table
"""
def summarize_doc(page_no, doc):
"""
Summarizes a single document.
Args:
page_no: Page no in Summary Vector store
doc: The document to be summarized.
Returns:
A summarized Document object.
"""
def summarize_file_doc(self, page_no, doc, search_space):
report_template = """You are a forensic investigator expert in making detailed report of the document. You are given the document make a report of it.
DOCUMENT: {document}
Detailed Report:"""
report_prompt = PromptTemplate(
input_variables=["document"],
template=report_template
)
# Compose the report prompt with the LLM
report_chain = report_prompt | self.llm
if(IS_LOCAL_SETUP == 'true'):
# Local LLMS suck at summaries so need this slow and painful procedure
text_splitter = SemanticChunker(embeddings=self.embeddings)
chunks = text_splitter.split_documents([doc])
combined_summary = ""
for i, chunk in enumerate(chunks):
print("GENERATING SUMMARY FOR CHUNK "+ str(i))
chunk_summary = report_chain.invoke({"document": chunk})
combined_summary += "\n\n" + chunk_summary + "\n\n"
response = combined_summary
metadict = {
"page": page_no,
"summary": True,
"search_space": search_space,
}
# metadict['languages'] = metadict['languages'][0]
metadict.update(doc.metadata)
return Document(
id=str(page_no),
page_content=response,
metadata=metadict
)
else:
response = report_chain.invoke({"document": doc})
metadict = {
"page": page_no,
"summary": True,
"search_space": search_space,
}
metadict.update(doc.metadata)
# metadict['languages'] = metadict['languages'][0]
return Document(
id=str(page_no),
page_content=response.content,
metadata=metadict
)
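
Editorial aside, not part of this commit: the local branch returns response directly while the hosted branch uses response.content because OllamaLLM is a plain-text LLM (invoke returns a str) and ChatOpenAI is a chat model (invoke returns an AIMessage). A minimal sketch, with assumed model names:

from langchain_ollama import OllamaLLM
from langchain_openai import ChatOpenAI

local_llm = OllamaLLM(model="llama3.1")       # assumed model name
hosted_llm = ChatOpenAI(model="gpt-4o-mini")  # assumed model name

print(type(local_llm.invoke("Summarize: hello")).__name__)   # str       -> use response as-is
print(type(hosted_llm.invoke("Summarize: hello")).__name__)  # AIMessage -> use response.content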
def summarize_webpage_doc(self, page_no, doc, search_space):
report_template = """You are a forensic investigator expert in making detailed report of the document. You are given the document make a report of it.
@@ -106,7 +157,7 @@ class HIndices:
id=str(page_no),
page_content=response,
metadata={
"filetype": files_type,
"filetype": 'WEBPAGE',
"page": page_no,
"summary": True,
"search_space": search_space,
@@ -125,7 +176,7 @@ class HIndices:
id=str(page_no),
page_content=response.content,
metadata={
"filetype": files_type,
"filetype": 'WEBPAGE',
"page": page_no,
"summary": True,
"search_space": search_space,
@@ -138,6 +189,11 @@ class HIndices:
}
)
def encode_docs_hierarchical(self, documents, files_type, search_space='GENERAL', db: Session = Depends(get_db)):
"""
Creates and Saves/Updates docs in hierarchical indices and postgres table
"""
# DocumentPgEntry = []
# searchspace = db.query(SearchSpace).filter(SearchSpace.search_space == search_space).first()
@@ -165,7 +221,12 @@ class HIndices:
# Process documents
summaries = []
batch_summaries = [summarize_doc(i + summary_last_id, doc) for i, doc in enumerate(documents)]
if(files_type=='WEBPAGE'):
batch_summaries = [self.summarize_webpage_doc(page_no = i + summary_last_id, doc=doc, search_space=search_space) for i, doc in enumerate(documents)]
else:
batch_summaries = [self.summarize_file_doc(page_no = i + summary_last_id, doc=doc, search_space=search_space) for i, doc in enumerate(documents)]
# batch_summaries = [summarize_doc(i + summary_last_id, doc) for i, doc in enumerate(documents)]
summaries.extend(batch_summaries)
detailed_chunks = []
@@ -198,8 +259,8 @@ class HIndices:
detailed_chunks.extend(chunks)
#update vector stores
self.summary_store.add_documents(summaries)
self.detailed_store.add_documents(detailed_chunks)
self.summary_store.add_documents(filter_complex_metadata(summaries))
self.detailed_store.add_documents(filter_complex_metadata(detailed_chunks))
return self.summary_store, self.detailed_store
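
Editorial note on the filter_complex_metadata change above, not part of this commit: Chroma only stores str, int, float, and bool metadata values, and documents loaded via Unstructured can carry list- or dict-valued fields (the commented-out metadict['languages'] line hints at this). A small sketch with a hypothetical document:

from langchain.docstore.document import Document
from langchain_community.vectorstores.utils import filter_complex_metadata

doc = Document(
    page_content="example",
    metadata={"page": 1, "summary": True, "languages": ["eng"]},  # the list value would break Chroma
)
print(filter_complex_metadata([doc])[0].metadata)  # {'page': 1, 'summary': True}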


@@ -17,3 +17,5 @@ langchain_ollama
langchain_chroma
flashrank
psycopg2
unstructured-client
langchain-unstructured


@@ -4,9 +4,11 @@ from langchain_core.prompts import ChatPromptTemplate
from langchain_core.documents import Document
from langchain_ollama import OllamaLLM
from langchain_openai import ChatOpenAI
from sqlalchemy import insert
from prompts import CONTEXT_ANSWER_PROMPT, DATE_TODAY, SUBQUERY_DECOMPOSITION_PROMT
from pydmodels import ChatToUpdate, DescriptionResponse, DocWithContent, DocumentsToDelete, NewUserChat, UserCreate, UserQuery, RetrivedDocList, UserQueryResponse, UserQueryWithChatHistory
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from langchain_unstructured import UnstructuredLoader
# Hierarchical Indices class
from HIndices import HIndices
@@ -14,7 +16,7 @@ from HIndices import HIndices
from Utils.stringify import stringify
# Auth Libs
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi import FastAPI, Depends, Form, HTTPException, status, UploadFile
from sqlalchemy.orm import Session
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
@@ -46,6 +48,92 @@ def get_db():
db.close()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
@app.post("/uploadfiles/")
async def upload_files(files: list[UploadFile], token: str = Depends(oauth2_scheme), search_space: str = Form(...), api_key: str = Form(...), db: Session = Depends(get_db)):
try:
# Decode and verify the token
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
docs = []
for file in files:
loader = UnstructuredLoader(
file=file.file,
api_key="nWDzKHygqnQzbb1pBsxYnMoQ3nQBja",
partition_via_api=True,
chunking_strategy="basic",
max_characters=90000,
include_orig_elements=False,
strategy="fast",
)
filedocs = loader.load()
fileswithfilename = []
for f in filedocs:
temp = f
temp.metadata['filename'] = file.filename
fileswithfilename.append(temp)
docs.extend(fileswithfilename)
# Initialize containers for documents and entries
DocumentPgEntry = []
raw_documents = []
# Fetch the search space from the database or create it if it doesn't exist
searchspace = db.query(SearchSpace).filter(SearchSpace.search_space == search_space.upper()).first()
if not searchspace:
stmt = insert(SearchSpace).values(search_space=search_space.upper())
db.execute(stmt)
db.commit()
# Process each document in the retrieved document list
for doc in docs:
raw_documents.append(Document(page_content=doc.page_content, metadata=doc.metadata))
# Stringify the document metadata
pgdocmeta = stringify(doc.metadata)
DocumentPgEntry.append(Documents(
file_type=doc.metadata['filetype'],
title=doc.metadata['filename'],
search_space=db.query(SearchSpace).filter(SearchSpace.search_space == search_space.upper()).first(),
document_metadata=pgdocmeta,
page_content=doc.page_content
))
# Save documents in PostgreSQL
user = db.query(User).filter(User.username == username).first()
user.documents.extend(DocumentPgEntry)
db.commit()
# Create hierarchical indices
if IS_LOCAL_SETUP == 'true':
index = HIndices(username=username)
else:
index = HIndices(username=username, api_key=api_key)
# Save indices in vector stores
index.encode_docs_hierarchical(documents=raw_documents, files_type='OTHER', search_space=search_space.upper(), db=db)
print("FINISHED")
return {
"message": "Files Uploaded Successfully"
}
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
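
A hedged client-side sketch for the new /uploadfiles/ endpoint, not part of this commit: host, port, file name, and token value are placeholders, while the bearer token and the search_space / api_key form fields mirror the signature above.

import requests

token = "<jwt obtained from the /token endpoint>"
with open("report.pdf", "rb") as f:  # hypothetical file
    resp = requests.post(
        "http://localhost:8000/uploadfiles/",  # assumed host/port
        headers={"Authorization": f"Bearer {token}"},
        files=[("files", ("report.pdf", f, "application/pdf"))],
        data={"search_space": "general", "api_key": "<OpenAI key; unused when IS_LOCAL_SETUP is true>"},
    )
print(resp.json())  # {"message": "Files Uploaded Successfully"}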
@app.post("/chat/")
def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
try:
@@ -69,6 +157,7 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
# Create an LLMChain for sub-query decomposition
subquery_decomposer_chain = SUBQUERY_DECOMPOSITION_PROMT | sub_query_llm
#Experimental
def decompose_query(original_query: str):
"""
Decompose the original query into simpler sub-queries.
@@ -156,8 +245,23 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
# SAVE DOCS
@app.post("/save/")
def save_data(apires: RetrivedDocList, db: Session = Depends(get_db)):
"""
Save retrieved documents to the database and encode them for hierarchical indexing.
This endpoint processes the provided documents, saves related information
in the PostgreSQL database, and updates hierarchical indices for the user.
Args:
apires (RetrivedDocList): The list of retrieved documents with metadata.
db (Session, optional): Dependency-injected session for database operations.
Returns:
dict: A message indicating the success of the operation.
Raises:
HTTPException: If the token is invalid or expired.
"""
try:
# Decode token and extract username
payload = jwt.decode(apires.token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
@@ -165,47 +269,59 @@ def save_data(apires: RetrivedDocList, db: Session = Depends(get_db)):
print("STARTED")
# Initialize containers for documents and entries
DocumentPgEntry = []
raw_documents = []
searchspace = db.query(SearchSpace).filter(SearchSpace.search_space == apires.search_space).first()
for doc in apires.documents:
content = f"USER BROWSING SESSION EVENT: \n"
content += f"=======================================METADATA==================================== \n"
content += f"User Browsing Session ID : {doc.metadata.BrowsingSessionId} \n"
content += f"User Visited website with url : {doc.metadata.VisitedWebPageURL} \n"
content += f"This visited website url had title : {doc.metadata.VisitedWebPageTitle} \n"
content += f"User Visited this website from reffering url : {doc.metadata.VisitedWebPageReffererURL} \n"
content += f"User Visited this website url at this Date and Time : {doc.metadata.VisitedWebPageDateWithTimeInISOString} \n"
content += f"User Visited this website for : {str(doc.metadata.VisitedWebPageVisitDurationInMilliseconds)} milliseconds. \n"
content += f"===================================================================================== \n"
content += f"Webpage Content of the visited webpage url in markdown format : \n\n {doc.pageContent} \n\n"
content += f"===================================================================================== \n"
raw_documents.append(Document(page_content=content,metadata=doc.metadata.__dict__))
pgdocmeta = stringify(doc.metadata.__dict__)
if(searchspace):
DocumentPgEntry.append(Documents(file_type='WEBPAGE',title=doc.metadata.VisitedWebPageTitle,search_space=searchspace, document_metadata=pgdocmeta, page_content=content))
else:
DocumentPgEntry.append(Documents(file_type='WEBPAGE',title=doc.metadata.VisitedWebPageTitle,search_space=SearchSpace(search_space=apires.search_space.upper()), document_metadata=pgdocmeta, page_content=content))
#Save docs in PG
user = db.query(User).filter(User.username == username).first()
user.documents.extend(DocumentPgEntry)
# Fetch the search space from the database
searchspace = db.query(SearchSpace).filter(SearchSpace.search_space == apires.search_space.upper()).first()
if not searchspace:
stmt = insert(SearchSpace).values(search_space=apires.search_space.upper())
db.execute(stmt)
db.commit()
# Process each document in the retrieved document list
for doc in apires.documents:
# Construct document content
content = (
f"USER BROWSING SESSION EVENT: \n"
f"=======================================METADATA==================================== \n"
f"User Browsing Session ID : {doc.metadata.BrowsingSessionId} \n"
f"User Visited website with url : {doc.metadata.VisitedWebPageURL} \n"
f"This visited website url had title : {doc.metadata.VisitedWebPageTitle} \n"
f"User Visited this website from referring url : {doc.metadata.VisitedWebPageReffererURL} \n"
f"User Visited this website url at this Date and Time : {doc.metadata.VisitedWebPageDateWithTimeInISOString} \n"
f"User Visited this website for : {str(doc.metadata.VisitedWebPageVisitDurationInMilliseconds)} milliseconds. \n"
f"===================================================================================== \n"
f"Webpage Content of the visited webpage url in markdown format : \n\n{doc.pageContent}\n\n"
f"===================================================================================== \n"
)
raw_documents.append(Document(page_content=content, metadata=doc.metadata.__dict__))
# Create Heirarical Indecices
if(IS_LOCAL_SETUP == 'true'):
# Stringify the document metadata
pgdocmeta = stringify(doc.metadata.__dict__)
DocumentPgEntry.append(Documents(
file_type='WEBPAGE',
title=doc.metadata.VisitedWebPageTitle,
search_space=searchspace,
document_metadata=pgdocmeta,
page_content=content
))
# Save documents in PostgreSQL
user = db.query(User).filter(User.username == username).first()
user.documents.extend(DocumentPgEntry)
db.commit()
# Create hierarchical indices
if IS_LOCAL_SETUP == 'true':
index = HIndices(username=username)
else:
index = HIndices(username=username,api_key=apires.openaikey)
index = HIndices(username=username, api_key=apires.openaikey)
#Save Indices in vector Stores
index.encode_docs_hierarchical(documents=raw_documents, files_type='WEBPAGE',search_space=apires.search_space.upper(), db=db)
# Save indices in vector stores
index.encode_docs_hierarchical(documents=raw_documents, files_type='WEBPAGE', search_space=apires.search_space.upper(), db=db)
print("FINISHED")
@@ -288,8 +404,7 @@ def delete_all_related_data(data: DocumentsToDelete, db: Session = Depends(get_d
raise HTTPException(status_code=403, detail="Token is invalid or expired")
#AUTH CODE
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# Manual Origins
# origins = [