mirror of
https://github.com/MODSetter/SurfSense.git
synced 2025-09-10 14:28:57 +00:00
init
This commit is contained in:
parent
24f641347d
commit
63146aa9b7
86 changed files with 18766 additions and 0 deletions
322
backend/server.py
Normal file
322
backend/server.py
Normal file
|
@ -0,0 +1,322 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from langchain.chains import GraphCypherQAChain
|
||||
from langchain_community.graphs import Neo4jGraph
|
||||
from langchain_core.documents import Document
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from langchain_community.vectorstores import Neo4jVector
|
||||
from envs import ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, API_SECRET_KEY, SECRET_KEY
|
||||
from prompts import CYPHER_QA_PROMPT, DOC_DESCRIPTION_PROMPT, SIMILARITY_SEARCH_PROMPT , CYPHER_GENERATION_PROMPT, DOCUMENT_METADATA_EXTRACTION_PROMT
|
||||
from pydmodels import DescriptionResponse, UserQuery, DocMeta, RetrivedDocList, UserQueryResponse
|
||||
from langchain_experimental.text_splitter import SemanticChunker
|
||||
|
||||
#Our Imps
|
||||
from LLMGraphTransformer import LLMGraphTransformer
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
# Auth Libs
|
||||
from fastapi import FastAPI, Depends, HTTPException, Request, status
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||
from jose import JWTError, jwt
|
||||
from datetime import datetime, timedelta
|
||||
from passlib.context import CryptContext
|
||||
from models import User
|
||||
from database import SessionLocal, engine
|
||||
from pydantic import BaseModel
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@app.post("/")
|
||||
def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
|
||||
|
||||
if(data.apisecretkey != API_SECRET_KEY):
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
|
||||
query = data.query
|
||||
|
||||
graph = Neo4jGraph(url=data.neourl, username=data.neouser, password=data.neopass)
|
||||
|
||||
llm = ChatOpenAI(
|
||||
model="gpt-4o-mini",
|
||||
temperature=0,
|
||||
max_tokens=None,
|
||||
timeout=None,
|
||||
api_key=data.openaikey
|
||||
)
|
||||
|
||||
embeddings = OpenAIEmbeddings(
|
||||
model="text-embedding-ada-002",
|
||||
api_key=data.openaikey,
|
||||
)
|
||||
|
||||
|
||||
chain = GraphCypherQAChain.from_llm(
|
||||
graph=graph,
|
||||
cypher_prompt=CYPHER_GENERATION_PROMPT,
|
||||
cypher_llm=llm,
|
||||
verbose=True,
|
||||
validate_cypher=True,
|
||||
qa_prompt=CYPHER_QA_PROMPT ,
|
||||
qa_llm=llm,
|
||||
return_intermediate_steps=True,
|
||||
top_k=5,
|
||||
)
|
||||
|
||||
vector_index = Neo4jVector.from_existing_graph(
|
||||
embeddings,
|
||||
graph=graph,
|
||||
search_type="hybrid",
|
||||
node_label="Document",
|
||||
text_node_properties=["text"],
|
||||
embedding_node_property="embedding",
|
||||
)
|
||||
|
||||
docs = vector_index.similarity_search(query,k=5)
|
||||
|
||||
docstoreturn = []
|
||||
|
||||
for doc in docs:
|
||||
docstoreturn.append(
|
||||
DocMeta(
|
||||
BrowsingSessionId=doc.metadata["BrowsingSessionId"] if "BrowsingSessionId" in doc.metadata.keys() else "NOT AVAILABLE",
|
||||
VisitedWebPageURL=doc.metadata["VisitedWebPageURL"] if "VisitedWebPageURL" in doc.metadata.keys()else "NOT AVAILABLE",
|
||||
VisitedWebPageTitle=doc.metadata["VisitedWebPageTitle"] if "VisitedWebPageTitle" in doc.metadata.keys() else "NOT AVAILABLE",
|
||||
VisitedWebPageDateWithTimeInISOString= doc.metadata["VisitedWebPageDateWithTimeInISOString"] if "VisitedWebPageDateWithTimeInISOString" in doc.metadata.keys() else "NOT AVAILABLE",
|
||||
VisitedWebPageReffererURL= doc.metadata["VisitedWebPageReffererURL"] if "VisitedWebPageReffererURL" in doc.metadata.keys() else "NOT AVAILABLE",
|
||||
VisitedWebPageVisitDurationInMilliseconds= doc.metadata["VisitedWebPageVisitDurationInMilliseconds"] if "VisitedWebPageVisitDurationInMilliseconds" in doc.metadata.keys() else None,
|
||||
VisitedWebPageContent= doc.page_content if doc.page_content else "NOT AVAILABLE"
|
||||
)
|
||||
)
|
||||
|
||||
docstoreturn = [i for n, i in enumerate(docstoreturn) if i not in docstoreturn[n + 1:]]
|
||||
|
||||
|
||||
try:
|
||||
response = chain.invoke({"query": query})
|
||||
if "don't know" in response["result"]:
|
||||
raise Exception("No response from graph")
|
||||
|
||||
structured_llm = llm.with_structured_output(RetrivedDocList)
|
||||
doc_extract_chain = DOCUMENT_METADATA_EXTRACTION_PROMT | structured_llm
|
||||
|
||||
query = doc_extract_chain.invoke(response["intermediate_steps"][1]["context"])
|
||||
|
||||
docs = vector_index.similarity_search(query.searchquery,k=5)
|
||||
|
||||
docstoreturn = []
|
||||
|
||||
for doc in docs:
|
||||
docstoreturn.append(
|
||||
DocMeta(
|
||||
BrowsingSessionId=doc.metadata["BrowsingSessionId"] if "BrowsingSessionId" in doc.metadata.keys() else "NOT AVAILABLE",
|
||||
VisitedWebPageURL=doc.metadata["VisitedWebPageURL"] if "VisitedWebPageURL" in doc.metadata.keys()else "NOT AVAILABLE",
|
||||
VisitedWebPageTitle=doc.metadata["VisitedWebPageTitle"] if "VisitedWebPageTitle" in doc.metadata.keys() else "NOT AVAILABLE",
|
||||
VisitedWebPageDateWithTimeInISOString= doc.metadata["VisitedWebPageDateWithTimeInISOString"] if "VisitedWebPageDateWithTimeInISOString" in doc.metadata.keys() else "NOT AVAILABLE",
|
||||
VisitedWebPageReffererURL= doc.metadata["VisitedWebPageReffererURL"] if "VisitedWebPageReffererURL" in doc.metadata.keys() else "NOT AVAILABLE",
|
||||
VisitedWebPageVisitDurationInMilliseconds= doc.metadata["VisitedWebPageVisitDurationInMilliseconds"] if "VisitedWebPageVisitDurationInMilliseconds" in doc.metadata.keys() else None,
|
||||
VisitedWebPageContent= doc.page_content if doc.page_content else "NOT AVAILABLE"
|
||||
)
|
||||
)
|
||||
|
||||
docstoreturn = [i for n, i in enumerate(docstoreturn) if i not in docstoreturn[n + 1:]]
|
||||
|
||||
return UserQueryResponse(relateddocs=docstoreturn,response=response["result"])
|
||||
except:
|
||||
# Fallback to Similarity Search RAG
|
||||
searchchain = SIMILARITY_SEARCH_PROMPT | llm
|
||||
|
||||
response = searchchain.invoke({"question": query, "context": docs})
|
||||
|
||||
return UserQueryResponse(relateddocs=docstoreturn,response=response.content)
|
||||
|
||||
# DOC DESCRIPTION
|
||||
@app.post("/kb/doc")
|
||||
def get_doc_description(data: UserQuery, response_model=DescriptionResponse):
|
||||
if(data.apisecretkey != API_SECRET_KEY):
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
document = data.query
|
||||
llm = ChatOpenAI(
|
||||
model="gpt-4o-mini",
|
||||
temperature=0,
|
||||
max_tokens=None,
|
||||
timeout=None,
|
||||
api_key=data.openaikey
|
||||
)
|
||||
|
||||
descriptionchain = DOC_DESCRIPTION_PROMPT | llm
|
||||
|
||||
response = descriptionchain.invoke({"document": document})
|
||||
|
||||
return DescriptionResponse(response=response.content)
|
||||
|
||||
|
||||
# SAVE DOCS TO GRAPH DB
|
||||
@app.post("/kb/")
|
||||
def populate_graph(apires: RetrivedDocList):
|
||||
if(apires.apisecretkey != API_SECRET_KEY):
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
print("STARTED")
|
||||
# print(apires)
|
||||
graph = Neo4jGraph(url=apires.neourl, username=apires.neouser, password=apires.neopass)
|
||||
|
||||
llm = ChatOpenAI(
|
||||
model="gpt-4o-mini",
|
||||
temperature=0,
|
||||
max_tokens=None,
|
||||
timeout=None,
|
||||
api_key=apires.openaikey
|
||||
)
|
||||
|
||||
embeddings = OpenAIEmbeddings(
|
||||
model="text-embedding-ada-002",
|
||||
api_key=apires.openaikey,
|
||||
)
|
||||
|
||||
llm_transformer = LLMGraphTransformer(llm=llm)
|
||||
|
||||
raw_documents = []
|
||||
|
||||
for doc in apires.documents:
|
||||
raw_documents.append(Document(page_content=doc.pageContent, metadata=doc.metadata))
|
||||
|
||||
text_splitter = SemanticChunker(embeddings=embeddings)
|
||||
|
||||
documents = text_splitter.split_documents(raw_documents)
|
||||
graph_documents = llm_transformer.convert_to_graph_documents(documents)
|
||||
|
||||
|
||||
graph.add_graph_documents(
|
||||
graph_documents,
|
||||
baseEntityLabel=True,
|
||||
include_source=True
|
||||
)
|
||||
|
||||
print("FINISHED")
|
||||
|
||||
return {
|
||||
"success": "Graph Will be populated Shortly"
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
#AUTH CODE
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
# Recommended for Local Setups
|
||||
# origins = [
|
||||
# "http://localhost:3000", # Adjust the port if your frontend runs on a different one
|
||||
# "https://yourfrontenddomain.com",
|
||||
# ]
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # Allows all origins from the list
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"], # Allows all methods
|
||||
allow_headers=["*"], # Allows all headers
|
||||
)
|
||||
|
||||
# Dependency
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
class UserCreate(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
apisecretkey: str
|
||||
|
||||
def get_user_by_username(db: Session, username: str):
|
||||
return db.query(User).filter(User.username == username).first()
|
||||
|
||||
def create_user(db: Session, user: UserCreate):
|
||||
hashed_password = pwd_context.hash(user.password)
|
||||
db_user = User(username=user.username, hashed_password=hashed_password)
|
||||
db.add(db_user)
|
||||
db.commit()
|
||||
return "complete"
|
||||
|
||||
@app.post("/register")
|
||||
def register_user(user: UserCreate, db: Session = Depends(get_db)):
|
||||
if(user.apisecretkey != API_SECRET_KEY):
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
db_user = get_user_by_username(db, username=user.username)
|
||||
if db_user:
|
||||
raise HTTPException(status_code=400, detail="Username already registered")
|
||||
|
||||
del user.apisecretkey
|
||||
return create_user(db=db, user=user)
|
||||
|
||||
# Authenticate the user
|
||||
def authenticate_user(username: str, password: str, db: Session):
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if not user:
|
||||
return False
|
||||
if not pwd_context.verify(password, user.hashed_password):
|
||||
return False
|
||||
return user
|
||||
|
||||
# Create access token
|
||||
def create_access_token(data: dict, expires_delta: timedelta | None = None):
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(minutes=15)
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
@app.post("/token")
|
||||
def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
|
||||
user = authenticate_user(form_data.username, form_data.password, db)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect username or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
access_token = create_access_token(
|
||||
data={"sub": user.username}, expires_delta=access_token_expires
|
||||
)
|
||||
return {"access_token": access_token, "token_type": "bearer"}
|
||||
|
||||
|
||||
def verify_token(token: str = Depends(oauth2_scheme)):
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
return payload
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
@app.get("/verify-token/{token}")
|
||||
async def verify_user_token(token: str):
|
||||
verify_token(token=token)
|
||||
return {"message": "Token is valid"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="127.0.0.1", port=8000)
|
Loading…
Add table
Add a link
Reference in a new issue