Mirror of https://github.com/MODSetter/SurfSense.git (synced 2025-09-11 06:44:40 +00:00)

SurfSense v3 - Highlight: Local LLM Support

commit 7f38091d3d (parent 04df919cf9)
13 changed files with 692 additions and 1345 deletions
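The recurring pattern in this diff: every LLM call site now branches on the IS_LOCAL_SETUP environment flag, constructing an Ollama-served model locally and ChatOpenAI otherwise. A minimal sketch of that pattern (model names match the diff; the helper function itself is illustrative, not part of the commit):

import os
from langchain_ollama import OllamaLLM
from langchain_openai import ChatOpenAI

IS_LOCAL_SETUP = os.environ.get("IS_LOCAL_SETUP")

def make_llm(openai_key: str, temperature: float = 0):
    if IS_LOCAL_SETUP == 'true':
        # Plain-text LLM served by a local Ollama daemon.
        return OllamaLLM(model="mistral-nemo", temperature=temperature)
    # Chat model; note its .invoke() returns a message, not a str.
    return ChatOpenAI(temperature=temperature, model_name="gpt-4o-mini", api_key=openai_key)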
@@ -1,42 +1,40 @@
from __future__ import annotations

from langchain.chains import GraphCypherQAChain
from langchain_community.graphs import Neo4jGraph
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import Neo4jVector
from envs import ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, API_SECRET_KEY, SECRET_KEY
from prompts import CYPHER_QA_PROMPT, DATE_TODAY, DOC_DESCRIPTION_PROMPT, GRAPH_QUERY_GEN_PROMPT, NOTIFICATION_GENERATION_PROMT, SIMILARITY_SEARCH_PROMPT, CYPHER_GENERATION_PROMPT, DOCUMENT_METADATA_EXTRACTION_PROMT
from pydmodels import ChatToUpdate, DescriptionResponse, GraphDocs, NewUserChat, NewUserData, Notifications, PrecisionQuery, PrecisionResponse, UserQuery, DocMeta, RetrivedDocList, UserQueryResponse, UserQueryWithChatHistory, VectorSearchQuery
from langchain_experimental.text_splitter import SemanticChunker
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_ollama import OllamaLLM
from langchain_openai import ChatOpenAI
from prompts import CONTEXT_ANSWER_PROMPT, DATE_TODAY, SUBQUERY_DECOMPOSITION_PROMT
from pydmodels import ChatToUpdate, DescriptionResponse, DocWithContent, DocumentsToDelete, NewUserChat, UserCreate, UserQuery, RetrivedDocList, UserQueryResponse, UserQueryWithChatHistory
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage

# Our imports
from LLMGraphTransformer import LLMGraphTransformer
from langchain_openai import ChatOpenAI
from DataExample import examples
# import nest_asyncio
# from langchain_community.chains.graph_qa.gremlin import GremlinQAChain
# from langchain_community.graphs import GremlinGraph
# from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
# from langchain_core.documents import Document
# from langchain_openai import AzureChatOpenAI

# Hierarchical Indices class
from HIndices import HIndices

from Utils.stringify import stringify

# Auth libs
from fastapi import FastAPI, Depends, HTTPException, Request, status
from fastapi import FastAPI, Depends, HTTPException, status
from sqlalchemy.orm import Session
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from datetime import datetime, timedelta
from passlib.context import CryptContext
from models import Chat, Notification, User
from database import SessionLocal, engine
from pydantic import BaseModel
from models import Chat, Documents, SearchSpace, User
from database import SessionLocal
from fastapi.middleware.cors import CORSMiddleware
from langchain_openai import AzureChatOpenAI


import os
from dotenv import load_dotenv

load_dotenv()

IS_LOCAL_SETUP = os.environ.get("IS_LOCAL_SETUP")
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.environ.get("ACCESS_TOKEN_EXPIRE_MINUTES"))
ALGORITHM = os.environ.get("ALGORITHM")
API_SECRET_KEY = os.environ.get("API_SECRET_KEY")
SECRET_KEY = os.environ.get("SECRET_KEY")

app = FastAPI()

# Dependency
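The environment block above raises a TypeError at import time if ACCESS_TOKEN_EXPIRE_MINUTES is unset, since int(None) fails. A more defensive variant for reference (the default values here are illustrative assumptions, not values from the commit):

import os

# Supplying string defaults keeps startup deterministic when a
# variable is missing; int(os.environ.get(...)) alone would crash.
IS_LOCAL_SETUP = os.environ.get("IS_LOCAL_SETUP", "false")
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.environ.get("ACCESS_TOKEN_EXPIRE_MINUTES", "30"))
ALGORITHM = os.environ.get("ALGORITHM", "HS256")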
@@ -47,251 +45,117 @@ def get_db():
    finally:
        db.close()

class UserCreate(BaseModel):
    username: str
    password: str
    apisecretkey: str


# General GraphCypherQAChain
@app.post("/")
@app.post("/chat/")
def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):

    if(data.apisecretkey != API_SECRET_KEY):
        raise HTTPException(status_code=401, detail="Unauthorized")

    graph = Neo4jGraph(url=data.neourl, username=data.neouser, password=data.neopass)

    llm = ChatOpenAI(
        model="gpt-4o-mini",
        temperature=0,
        max_tokens=None,
        timeout=None,
        api_key=data.openaikey
    )

    # Query Expansion
    # searchchain = GRAPH_QUERY_GEN_PROMPT | llm
    # qry = searchchain.invoke({"question": data.query, "context": examples})

    query = data.query  # qry.content

    embeddings = OpenAIEmbeddings(
        model="text-embedding-ada-002",
        api_key=data.openaikey,
    )

    chain = GraphCypherQAChain.from_llm(
        graph=graph,
        cypher_prompt=CYPHER_GENERATION_PROMPT,
        cypher_llm=llm,
        verbose=True,
        validate_cypher=True,
        qa_prompt=CYPHER_QA_PROMPT,
        qa_llm=llm,
        return_intermediate_steps=True,
        top_k=5,
    )

    vector_index = Neo4jVector.from_existing_graph(
        embeddings,
        graph=graph,
        search_type="hybrid",
        node_label="Document",
        text_node_properties=["text"],
        embedding_node_property="embedding",
    )

    graphdocs = vector_index.similarity_search(data.query, k=15)
    docsDict = {}

    for d in graphdocs:
        if d.metadata['BrowsingSessionId'] not in docsDict:
            newVal = d.metadata.copy()
            newVal['VisitedWebPageContent'] = d.page_content
            docsDict[d.metadata['BrowsingSessionId']] = newVal
        else:
            docsDict[d.metadata['BrowsingSessionId']]['VisitedWebPageContent'] += d.page_content

    docstoreturn = []

    for x in docsDict.values():
        docstoreturn.append(DocMeta(
            BrowsingSessionId=x['BrowsingSessionId'],
            VisitedWebPageURL=x['VisitedWebPageURL'],
            VisitedWebPageVisitDurationInMilliseconds=x['VisitedWebPageVisitDurationInMilliseconds'],
            VisitedWebPageTitle=x['VisitedWebPageTitle'],
            VisitedWebPageReffererURL=x['VisitedWebPageReffererURL'],
            VisitedWebPageDateWithTimeInISOString=x['VisitedWebPageDateWithTimeInISOString'],
            VisitedWebPageContent=x['VisitedWebPageContent']
        ))

    try:
        responsegrp = chain.invoke({"query": query})

        if "don't know" in responsegrp["result"]:
            raise Exception("No response from graph")
        payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise HTTPException(status_code=403, detail="Token is invalid or expired")

        structured_llm = llm.with_structured_output(VectorSearchQuery)
        doc_extract_chain = DOCUMENT_METADATA_EXTRACTION_PROMT | structured_llm
        query = data.query
        search_space = data.search_space

        newquery = doc_extract_chain.invoke(responsegrp["intermediate_steps"][1]["context"])

        graphdocs = vector_index.similarity_search(newquery.searchquery, k=15)

        docsDict = {}

        for d in graphdocs:
            if d.metadata['BrowsingSessionId'] not in docsDict:
                newVal = d.metadata.copy()
                newVal['VisitedWebPageContent'] = d.page_content
                docsDict[d.metadata['BrowsingSessionId']] = newVal
            else:
                docsDict[d.metadata['BrowsingSessionId']]['VisitedWebPageContent'] += d.page_content

        docstoreturn = []

        for x in docsDict.values():
            docstoreturn.append(DocMeta(
                BrowsingSessionId=x['BrowsingSessionId'],
                VisitedWebPageURL=x['VisitedWebPageURL'],
                VisitedWebPageVisitDurationInMilliseconds=x['VisitedWebPageVisitDurationInMilliseconds'],
                VisitedWebPageTitle=x['VisitedWebPageTitle'],
                VisitedWebPageReffererURL=x['VisitedWebPageReffererURL'],
                VisitedWebPageDateWithTimeInISOString=x['VisitedWebPageDateWithTimeInISOString'],
                VisitedWebPageContent=x['VisitedWebPageContent']
            ))

        return UserQueryResponse(relateddocs=docstoreturn, response=responsegrp["result"])
    except:
        # Fallback to Similarity Search RAG
        searchchain = SIMILARITY_SEARCH_PROMPT | llm

        response = searchchain.invoke({"question": data.query, "context": docstoreturn})

        return UserQueryResponse(relateddocs=docstoreturn, response=response.content)
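The handler above attempts GraphCypherQAChain first, treats a "don't know" result as failure, and falls back to plain similarity-search RAG over the documents it prefetched. The shape of that pattern as a standalone hedged sketch (names are illustrative):

def graph_qa_with_fallback(query, chain, searchchain, fallback_docs):
    """Try graph QA; on any failure, answer from similarity-search context."""
    try:
        out = chain.invoke({"query": query})
        if "don't know" in out["result"]:
            raise ValueError("No usable answer from the graph")
        return out["result"]
    except Exception:
        # Fallback to Similarity Search RAG, as in the handler above.
        return searchchain.invoke({"question": query, "context": fallback_docs}).content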
# Precision Search
@app.post("/precision")
def get_precision_search_response(data: PrecisionQuery, response_model=PrecisionResponse):
    if(data.apisecretkey != API_SECRET_KEY):
        raise HTTPException(status_code=401, detail="Unauthorized")

    graph = Neo4jGraph(url=data.neourl, username=data.neouser, password=data.neopass)

    GRAPH_QUERY = "MATCH (d:Document) WHERE d.VisitedWebPageDateWithTimeInISOString >= " + "'" + data.daterange[0] + "'" + " AND d.VisitedWebPageDateWithTimeInISOString <= " + "'" + data.daterange[1] + "'"

    if(data.timerange[0] >= data.timerange[1]):
        GRAPH_QUERY += " AND d.VisitedWebPageVisitDurationInMilliseconds >= 0"
    else:
        GRAPH_QUERY += " AND d.VisitedWebPageVisitDurationInMilliseconds >= " + str(data.timerange[0]) + " AND d.VisitedWebPageVisitDurationInMilliseconds <= " + str(data.timerange[1])

    if(data.webpageurl):
        GRAPH_QUERY += " AND d.VisitedWebPageURL CONTAINS " + "'" + data.webpageurl.lower() + "'"

    if(data.sessionid):
        GRAPH_QUERY += " AND d.BrowsingSessionId = " + "'" + data.sessionid + "'"

    GRAPH_QUERY += " RETURN d;"

    graphdocs = graph.query(GRAPH_QUERY)

    docsDict = {}

    for d in graphdocs:
        if d['d']['VisitedWebPageVisitDurationInMilliseconds'] not in docsDict:
            docsDict[d['d']['VisitedWebPageVisitDurationInMilliseconds']] = d['d']
    if(IS_LOCAL_SETUP == 'true'):
        sub_query_llm = OllamaLLM(model="mistral-nemo", temperature=0)
        qa_llm = OllamaLLM(model="mistral-nemo", temperature=0)
    else:
        docsDict[d['d']['VisitedWebPageVisitDurationInMilliseconds']]['text'] += d['d']['text']

    docs = []

    for x in docsDict.values():
        docs.append(DocMeta(
            BrowsingSessionId=x['BrowsingSessionId'],
            VisitedWebPageURL=x['VisitedWebPageURL'],
            VisitedWebPageVisitDurationInMilliseconds=x['VisitedWebPageVisitDurationInMilliseconds'],
            VisitedWebPageTitle=x['VisitedWebPageTitle'],
            VisitedWebPageReffererURL=x['VisitedWebPageReffererURL'],
            VisitedWebPageDateWithTimeInISOString=x['VisitedWebPageDateWithTimeInISOString'],
            VisitedWebPageContent=x['text']
        ))

    return PrecisionResponse(documents=docs)
    sub_query_llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", api_key=data.openaikey)
    qa_llm = ChatOpenAI(temperature=0.5, model_name="gpt-4o-mini", api_key=data.openaikey)
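GRAPH_QUERY above is assembled by string concatenation, so a quote in webpageurl or sessionid breaks the query and invites Cypher injection. Neo4jGraph.query accepts bound parameters; a hedged rework of the date filter (parameter names are illustrative):

PRECISION_QUERY = (
    "MATCH (d:Document) "
    "WHERE d.VisitedWebPageDateWithTimeInISOString >= $start "
    "  AND d.VisitedWebPageDateWithTimeInISOString <= $end "
    "RETURN d"
)
# Values travel as bound parameters instead of being spliced into the string.
graphdocs = graph.query(PRECISION_QUERY, params={"start": data.daterange[0], "end": data.daterange[1]})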
# Multi DOC Chat
@app.post("/chat/docs")
def doc_chat_with_history(data: UserQueryWithChatHistory, response_model=DescriptionResponse):
    if(data.apisecretkey != API_SECRET_KEY):
        raise HTTPException(status_code=401, detail="Unauthorized")

    llm = ChatOpenAI(
        model="gpt-4o-mini",
        temperature=0,
        max_tokens=None,
        timeout=None,
        api_key=data.openaikey
    )

    chatHistory = []

    for chat in data.chat:
        if(chat.type == 'system'):
            chatHistory.append(SystemMessage(content=DATE_TODAY + """You are an helpful assistant for question-answering tasks.
            Use the following pieces of retrieved context to answer the question.
            If you don't know the answer, just say that you don't know.
            Context:""" + str(chat.content)))

        if(chat.type == 'ai'):
            chatHistory.append(AIMessage(content=chat.content))

        if(chat.type == 'human'):
            chatHistory.append(HumanMessage(content=chat.content))

    chatHistory.append(("human", "{input}"))

    qa_prompt = ChatPromptTemplate.from_messages(chatHistory)
    # Create an LLMChain for sub-query decomposition
    subquery_decomposer_chain = SUBQUERY_DECOMPOSITION_PROMT | sub_query_llm

    descriptionchain = qa_prompt | llm
    def decompose_query(original_query: str):
        """
        Decompose the original query into simpler sub-queries.

        Args:
            original_query (str): The original complex query

        Returns:
            List[str]: A list of simpler sub-queries
        """
        if(IS_LOCAL_SETUP == 'true'):
            response = subquery_decomposer_chain.invoke(original_query)
        else:
            response = subquery_decomposer_chain.invoke(original_query).content

        sub_queries = [q.strip() for q in response.split('\n') if q.strip() and not q.strip().startswith('Sub-queries:')]
        return sub_queries

    response = descriptionchain.invoke({"input": data.query})

    return DescriptionResponse(response=response.content)

# Create Hierarchical Indices
if(IS_LOCAL_SETUP == 'true'):
    index = HIndices(username=username)
else:
    index = HIndices(username=username, api_key=data.openaikey)


# For Those Who Want HyDe Questions
# sub_queries = decompose_query(query)

sub_queries = []
sub_queries.append(query)
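decompose_query returns one sub-query per non-empty line of the model's reply, dropping a leading "Sub-queries:" label; in local mode the raw string is used directly, since OllamaLLM returns text rather than a message. By default the endpoint skips decomposition and searches with the original query only. A hypothetical expanded call:

# Hypothetical HyDE-style expansion (disabled by default above):
sub_queries = decompose_query("Compare my reading on vector DBs with what I read about Neo4j")
# e.g. ["What did I read about vector databases?",
#       "What did I read about Neo4j?"]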
# DOC DESCRIPTION
@app.post("/kb/doc")
def get_doc_description(data: UserQuery, response_model=DescriptionResponse):
    if(data.apisecretkey != API_SECRET_KEY):
        raise HTTPException(status_code=401, detail="Unauthorized")

    document = data.query
    llm = ChatOpenAI(
        model="gpt-4o-mini",
        temperature=0,
        max_tokens=None,
        timeout=None,
        api_key=data.openaikey
    )

    descriptionchain = DOC_DESCRIPTION_PROMPT | llm

    response = descriptionchain.invoke({"document": document})

    return DescriptionResponse(response=response.content)

duplicate_related_summary_docs = []
context_to_answer = ""
for sub_query in sub_queries:
    localreturn = index.local_search(query=sub_query, search_space=search_space)
    globalreturn, related_summary_docs = index.global_search(query=sub_query, search_space=search_space)

    context_to_answer += localreturn + "\n\n" + globalreturn

# SAVE DOCS TO GRAPH DB
@app.post("/kb/")
def populate_graph(apires: RetrivedDocList, db: Session = Depends(get_db)):
    duplicate_related_summary_docs.extend(related_summary_docs)

combined_docs_seen_metadata = set()
combined_docs_unique_documents = []

for doc in duplicate_related_summary_docs:
    # Convert metadata to a tuple of its items (this allows it to be added to a set)
    doc.metadata['relevance_score'] = 0.0
    metadata_tuple = tuple(sorted(doc.metadata.items()))
    if metadata_tuple not in combined_docs_seen_metadata:
        combined_docs_seen_metadata.add(metadata_tuple)
        combined_docs_unique_documents.append(doc)

returnDocs = []
for doc in combined_docs_unique_documents:
    entry = DocWithContent(
        BrowsingSessionId=doc.metadata['BrowsingSessionId'],
        VisitedWebPageURL=doc.metadata['VisitedWebPageURL'],
        VisitedWebPageContent=doc.page_content,
        VisitedWebPageTitle=doc.metadata['VisitedWebPageTitle'],
        VisitedWebPageDateWithTimeInISOString=doc.metadata['VisitedWebPageDateWithTimeInISOString'],
        VisitedWebPageReffererURL=doc.metadata['VisitedWebPageReffererURL'],
        VisitedWebPageVisitDurationInMilliseconds=doc.metadata['VisitedWebPageVisitDurationInMilliseconds'],
    )

    returnDocs.append(entry)


ans_chain = CONTEXT_ANSWER_PROMPT | qa_llm

finalans = ans_chain.invoke({"query": query, "context": context_to_answer})

if(IS_LOCAL_SETUP == 'true'):
    return UserQueryResponse(response=finalans, relateddocs=returnDocs)
else:
    return UserQueryResponse(response=finalans.content, relateddocs=returnDocs)

except JWTError:
    raise HTTPException(status_code=403, detail="Token is invalid or expired")

# SAVE DOCS
@app.post("/save/")
def save_data(apires: RetrivedDocList, db: Session = Depends(get_db)):

    try:
        payload = jwt.decode(apires.token, SECRET_KEY, algorithms=[ALGORITHM])
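The dedup step above makes each metadata dict hashable by sorting its items into a tuple, so documents whose metadata match exactly are kept once. As a standalone sketch of that technique:

def dedupe_by_metadata(docs):
    # tuple(sorted(d.items())) is an order-independent, hashable fingerprint
    # of a flat metadata dict (values must themselves be hashable).
    seen, unique = set(), []
    for doc in docs:
        fingerprint = tuple(sorted(doc.metadata.items()))
        if fingerprint not in seen:
            seen.add(fingerprint)
            unique.append(doc)
    return unique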
@@ -299,58 +163,49 @@ def populate_graph(apires: RetrivedDocList, db: Session = Depends(get_db)):
        if username is None:
            raise HTTPException(status_code=403, detail="Token is invalid or expired")

        print("STARTED")
        # print(apires)
        graph = Neo4jGraph(url=apires.neourl, username=apires.neouser, password=apires.neopass)

        llm = ChatOpenAI(
            model="gpt-4o-mini",
            temperature=0,
            max_tokens=None,
            timeout=None,
            api_key=apires.openaikey
        )

        embeddings = OpenAIEmbeddings(
            model="text-embedding-ada-002",
            api_key=apires.openaikey,
        )

        llm_transformer = LLMGraphTransformer(llm=llm)

        DocumentPgEntry = []
        raw_documents = []

        searchspace = db.query(SearchSpace).filter(SearchSpace.search_space == apires.search_space).first()

        for doc in apires.documents:
            raw_documents.append(Document(page_content=doc.pageContent, metadata=doc.metadata))

        text_splitter = SemanticChunker(embeddings=embeddings)

        documents = text_splitter.split_documents(raw_documents)
        graph_documents = llm_transformer.convert_to_graph_documents(documents)

        graph.add_graph_documents(
            graph_documents,
            baseEntityLabel=True,
            include_source=True
        )

        structured_llm = llm.with_structured_output(Notifications)
        notifs_extraction_chain = NOTIFICATION_GENERATION_PROMT | structured_llm

        notifications = notifs_extraction_chain.invoke({"documents": raw_documents})

        notifsdb = []

        for text in notifications.notifications:
            notifsdb.append(Notification(text=text))
        content = f"USER BROWSING SESSION EVENT: \n"
        content += f"=======================================METADATA==================================== \n"
        content += f"User Browsing Session ID : {doc.metadata.BrowsingSessionId} \n"
        content += f"User Visited website with url : {doc.metadata.VisitedWebPageURL} \n"
        content += f"This visited website url had title : {doc.metadata.VisitedWebPageTitle} \n"
        content += f"User Visited this website from reffering url : {doc.metadata.VisitedWebPageReffererURL} \n"
        content += f"User Visited this website url at this Date and Time : {doc.metadata.VisitedWebPageDateWithTimeInISOString} \n"
        content += f"User Visited this website for : {str(doc.metadata.VisitedWebPageVisitDurationInMilliseconds)} milliseconds. \n"
        content += f"===================================================================================== \n"
        content += f"Webpage Content of the visited webpage url in markdown format : \n\n {doc.pageContent} \n\n"
        content += f"===================================================================================== \n"
        raw_documents.append(Document(page_content=content, metadata=doc.metadata.__dict__))

        pgdocmeta = stringify(doc.metadata.__dict__)

        if(searchspace):
            DocumentPgEntry.append(Documents(file_type='WEBPAGE', title=doc.metadata.VisitedWebPageTitle, search_space=searchspace, document_metadata=pgdocmeta, page_content=content))
        else:
            DocumentPgEntry.append(Documents(file_type='WEBPAGE', title=doc.metadata.VisitedWebPageTitle, search_space=SearchSpace(search_space=apires.search_space.upper()), document_metadata=pgdocmeta, page_content=content))

        user = db.query(User).filter(User.username == username).first()
        user.notifications.extend(notifsdb)

        db.commit()
        # Save docs in PG
        user = db.query(User).filter(User.username == username).first()
        user.documents.extend(DocumentPgEntry)

        db.commit()

        # Create Hierarchical Indices
        if(IS_LOCAL_SETUP == 'true'):
            index = HIndices(username=username)
        else:
            index = HIndices(username=username, api_key=apires.openaikey)

        # Save indices in vector stores
        index.encode_docs_hierarchical(documents=raw_documents, files_type='WEBPAGE', search_space=apires.search_space.upper(), db=db)

        print("FINISHED")
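with_structured_output needs a schema to coerce the model's reply; judging from the loop over notifications.notifications above, the Notifications model in pydmodels is presumably shaped like this (a hedged reconstruction, not the actual source):

from pydantic import BaseModel

class Notifications(BaseModel):
    # Inferred from `for text in notifications.notifications:` above;
    # the real definition lives in pydmodels.
    notifications: list[str]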
@@ -360,21 +215,83 @@ def populate_graph(apires: RetrivedDocList, db: Session = Depends(get_db)):

    except JWTError:
        raise HTTPException(status_code=403, detail="Token is invalid or expired")


# Function to populate db (comment out when running on server)
@app.post("/add/")
def add_user(user: UserCreate, db: Session = Depends(get_db)):
    db_user = User(username=user.username, hashed_password=user.password, graph_config="", llm_config="")
    db.add(db_user)
    db.commit()
    return "Success"

# Multi DOC Chat
@app.post("/chat/docs")
def doc_chat_with_history(data: UserQueryWithChatHistory, response_model=DescriptionResponse):
    try:
        payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise HTTPException(status_code=403, detail="Token is invalid or expired")

        if(IS_LOCAL_SETUP == 'true'):
            llm = OllamaLLM(model="mistral-nemo", temperature=0)
        else:
            llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", api_key=data.openaikey)

        chatHistory = []

        for chat in data.chat:
            if(chat.type == 'system'):
                chatHistory.append(SystemMessage(content=DATE_TODAY + """You are an helpful assistant for question-answering tasks.
                Use the following pieces of retrieved context to answer the question.
                If you don't know the answer, just say that you don't know.
                Context:""" + str(chat.content)))

            if(chat.type == 'ai'):
                chatHistory.append(AIMessage(content=chat.content))

            if(chat.type == 'human'):
                chatHistory.append(HumanMessage(content=chat.content))

        chatHistory.append(("human", "{input}"))

        qa_prompt = ChatPromptTemplate.from_messages(chatHistory)

        descriptionchain = qa_prompt | llm

        response = descriptionchain.invoke({"input": data.query})

        if(IS_LOCAL_SETUP == 'true'):
            return DescriptionResponse(response=response)
        else:
            return DescriptionResponse(response=response.content)

    except JWTError:
        raise HTTPException(status_code=403, detail="Token is invalid or expired")
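This endpoint shows the cost of mixing model types: OllamaLLM's invoke returns a plain str, while ChatOpenAI returns a message whose text sits in .content, so every return site branches on IS_LOCAL_SETUP. A small helper could collapse those branches (hypothetical; not part of the commit):

def llm_text(result) -> str:
    # OllamaLLM yields str; chat models yield an AIMessage with .content.
    return result if isinstance(result, str) else result.content

# e.g. return DescriptionResponse(response=llm_text(response))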
# Multi DOC Chat

@app.post("/delete/docs")
def delete_all_related_data(data: DocumentsToDelete, db: Session = Depends(get_db)):
    try:
        payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise HTTPException(status_code=403, detail="Token is invalid or expired")

        if(IS_LOCAL_SETUP == 'true'):
            index = HIndices(username=username)
        else:
            index = HIndices(username=username, api_key=data.openaikey)

        message = index.delete_vector_stores(summary_ids_to_delete=data.ids_to_delete, db=db)

        return {
            "message": message
        }

    except JWTError:
        raise HTTPException(status_code=403, detail="Token is invalid or expired")


# AUTH CODE
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# Recommended for Local Setups
# Manual Origins
# origins = [
#     "http://localhost:3000", # Adjust the port if your frontend runs on a different one
#     "https://yourfrontenddomain.com",
@@ -395,7 +312,7 @@ def get_user_by_username(db: Session, username: str):

def create_user(db: Session, user: UserCreate):
    hashed_password = pwd_context.hash(user.password)
    db_user = User(username=user.username, hashed_password=hashed_password, graph_config="", llm_config="")
    db_user = User(username=user.username, hashed_password=hashed_password)
    db.add(db_user)
    db.commit()
    return "complete"
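create_user hashes with a pwd_context defined outside the hunks shown here; a typical passlib setup consistent with this code (assumed, not taken from the commit):

from passlib.context import CryptContext

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

hashed = pwd_context.hash("example-password")
assert pwd_context.verify("example-password", hashed)  # round-trip check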
@@ -462,9 +379,6 @@ async def verify_user_token(token: str):
    verify_token(token=token)
    return {"message": "Token is valid"}


@app.post("/user/chat/save")
def populate_user_chat(chat: NewUserChat, db: Session = Depends(get_db)):
    try:
@@ -533,34 +447,32 @@ async def get_user_with_token(token: str, db: Session = Depends(get_db)):
            "userid": user.id,
            "username": user.username,
            "chats": user.chats,
            "notifications": user.notifications
            "documents": user.documents
        }
    except JWTError:
        raise HTTPException(status_code=403, detail="Token is invalid or expired")


@app.get("/user/notification/delete/{token}/{notificationid}")
async def delete_chat_of_user(token: str, notificationid: str, db: Session = Depends(get_db)):
@app.get("/searchspaces/{token}")
async def get_user_with_token(token: str, db: Session = Depends(get_db)):
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise HTTPException(status_code=403, detail="Token is invalid or expired")

        notificationindb = db.query(Notification).filter(Notification.id == notificationid).first()
        db.delete(notificationindb)
        db.commit()
        search_spaces = db.query(SearchSpace).all()
        return {
            "message": "Notification Deleted"
            "search_spaces": search_spaces
        }
    except JWTError:
        raise HTTPException(status_code=403, detail="Token is invalid or expired")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="127.0.0.1", port=8000)
    uvicorn.run(app, host="127.0.0.1", port=8000)
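With IS_LOCAL_SETUP=true, the server expects an Ollama daemon serving mistral-nemo alongside it. A plausible local bring-up, assuming this file is named server.py (the module name and exact commands are assumptions, not from this commit):

# $ ollama pull mistral-nemo     # model referenced by OllamaLLM above
# $ ollama serve                 # listens on http://localhost:11434 by default
# $ IS_LOCAL_SETUP=true uvicorn server:app --host 127.0.0.1 --port 8000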