SurfSense/backend/server.py
DESKTOP-RTLN3BA\$punk 63146aa9b7 init
2024-08-12 00:32:42 -07:00

322 lines
No EOL
11 KiB
Python

from __future__ import annotations
from langchain.chains import GraphCypherQAChain
from langchain_community.graphs import Neo4jGraph
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import Neo4jVector
from envs import ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, API_SECRET_KEY, SECRET_KEY
from prompts import CYPHER_QA_PROMPT, DOC_DESCRIPTION_PROMPT, SIMILARITY_SEARCH_PROMPT , CYPHER_GENERATION_PROMPT, DOCUMENT_METADATA_EXTRACTION_PROMT
from pydmodels import DescriptionResponse, UserQuery, DocMeta, RetrivedDocList, UserQueryResponse
from langchain_experimental.text_splitter import SemanticChunker
#Our Imps
from LLMGraphTransformer import LLMGraphTransformer
from langchain_openai import ChatOpenAI
# Auth Libs
from fastapi import FastAPI, Depends, HTTPException, Request, status
from sqlalchemy.orm import Session
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from datetime import datetime, timedelta
from passlib.context import CryptContext
from models import User
from database import SessionLocal, engine
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
# FastAPI application instance; the route decorators and the CORS middleware
# registered further down in this module all attach to it.
app = FastAPI()
def _docs_to_meta(docs):
    """Convert retrieved documents into a de-duplicated list of ``DocMeta``.

    Metadata keys missing from a document fall back to ``"NOT AVAILABLE"``
    (``None`` for the visit duration); empty page content is likewise
    reported as ``"NOT AVAILABLE"``.  When several entries compare equal,
    only the last one is kept; the relative order is otherwise preserved.
    """
    metas = [
        DocMeta(
            BrowsingSessionId=doc.metadata.get("BrowsingSessionId", "NOT AVAILABLE"),
            VisitedWebPageURL=doc.metadata.get("VisitedWebPageURL", "NOT AVAILABLE"),
            VisitedWebPageTitle=doc.metadata.get("VisitedWebPageTitle", "NOT AVAILABLE"),
            VisitedWebPageDateWithTimeInISOString=doc.metadata.get(
                "VisitedWebPageDateWithTimeInISOString", "NOT AVAILABLE"
            ),
            VisitedWebPageReffererURL=doc.metadata.get("VisitedWebPageReffererURL", "NOT AVAILABLE"),
            VisitedWebPageVisitDurationInMilliseconds=doc.metadata.get(
                "VisitedWebPageVisitDurationInMilliseconds"
            ),
            VisitedWebPageContent=doc.page_content or "NOT AVAILABLE",
        )
        for doc in docs
    ]
    # List-based de-duplication on purpose: it relies only on ``==`` and does
    # not require DocMeta to be hashable.
    return [meta for idx, meta in enumerate(metas) if meta not in metas[idx + 1:]]


@app.post("/")
def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
    """Answer a user query against the caller's Neo4j knowledge graph.

    The request is authorised by comparing ``data.apisecretkey`` with
    ``API_SECRET_KEY``.  A hybrid vector search over ``Document`` nodes
    first collects candidate documents (``k=5``).  The graph Cypher QA chain
    is then asked the question; when it yields an answer, the context of its
    second intermediate step is turned into a search query by a
    structured-output LLM call and the related documents are fetched again
    with that query.  If the chain's answer contains "don't know", or any
    step of the graph path fails, the answer is produced from the initial
    similarity-search documents instead.

    Args:
        data: Query text plus the Neo4j and OpenAI credentials to use.
        response_model: Unused by the body; kept for interface compatibility.

    Returns:
        A ``UserQueryResponse`` with the answer text and related documents.

    Raises:
        HTTPException: 401 when the API secret key does not match.
    """
    # NOTE(review): ``response_model`` is declared as a handler parameter
    # rather than passed to ``@app.post``; FastAPI presumably treats it as a
    # request parameter instead of a response schema -- confirm and move it
    # into the decorator.
    if data.apisecretkey != API_SECRET_KEY:
        raise HTTPException(status_code=401, detail="Unauthorized")

    query = data.query

    graph = Neo4jGraph(url=data.neourl, username=data.neouser, password=data.neopass)

    llm = ChatOpenAI(
        model="gpt-4o-mini",
        temperature=0,
        max_tokens=None,
        timeout=None,
        api_key=data.openaikey,
    )

    embeddings = OpenAIEmbeddings(
        model="text-embedding-ada-002",
        api_key=data.openaikey,
    )

    chain = GraphCypherQAChain.from_llm(
        graph=graph,
        cypher_prompt=CYPHER_GENERATION_PROMPT,
        cypher_llm=llm,
        verbose=True,
        validate_cypher=True,
        qa_prompt=CYPHER_QA_PROMPT,
        qa_llm=llm,
        return_intermediate_steps=True,
        top_k=5,
    )

    vector_index = Neo4jVector.from_existing_graph(
        embeddings,
        graph=graph,
        search_type="hybrid",
        node_label="Document",
        text_node_properties=["text"],
        embedding_node_property="embedding",
    )

    # Baseline retrieval; also the context for the fallback path below.
    docs = vector_index.similarity_search(query, k=5)
    docstoreturn = _docs_to_meta(docs)

    try:
        response = chain.invoke({"query": query})
        if "don't know" in response["result"]:
            # Deliberately routed to the fallback handler below.
            raise Exception("No response from graph")

        structured_llm = llm.with_structured_output(RetrivedDocList)
        doc_extract_chain = DOCUMENT_METADATA_EXTRACTION_PROMT | structured_llm

        # Separate names are used here so that a failure part-way through
        # leaves ``query``, ``docs`` and ``docstoreturn`` intact for the
        # fallback (previously ``query`` was overwritten with the structured
        # output and then sent to the fallback prompt as the question).
        extracted = doc_extract_chain.invoke(response["intermediate_steps"][1]["context"])
        graph_docs = vector_index.similarity_search(extracted.searchquery, k=5)

        return UserQueryResponse(
            relateddocs=_docs_to_meta(graph_docs),
            response=response["result"],
        )
    except Exception:
        # Fallback to Similarity Search RAG.  ``Exception`` rather than a
        # bare ``except`` so KeyboardInterrupt/SystemExit still propagate.
        searchchain = SIMILARITY_SEARCH_PROMPT | llm
        response = searchchain.invoke({"question": query, "context": docs})
        return UserQueryResponse(relateddocs=docstoreturn, response=response.content)
# DOC DESCRIPTION
@app.post("/kb/doc")
def get_doc_description(data: UserQuery, response_model=DescriptionResponse):
    """Generate a description of the document carried in ``data.query``.

    The document text is fed through ``DOC_DESCRIPTION_PROMPT`` into a
    ``gpt-4o-mini`` chat model created with the caller's OpenAI key.

    Returns:
        A ``DescriptionResponse`` wrapping the model's text output.

    Raises:
        HTTPException: 401 when ``data.apisecretkey`` does not match
            ``API_SECRET_KEY``.
    """
    # NOTE(review): ``response_model`` sits in the handler signature instead
    # of the ``@app.post`` call -- presumably meant for the decorator; confirm.
    if data.apisecretkey != API_SECRET_KEY:
        raise HTTPException(status_code=401, detail="Unauthorized")

    chat_model = ChatOpenAI(
        model="gpt-4o-mini",
        temperature=0,
        max_tokens=None,
        timeout=None,
        api_key=data.openaikey,
    )

    description = (DOC_DESCRIPTION_PROMPT | chat_model).invoke({"document": data.query})
    return DescriptionResponse(response=description.content)
# SAVE DOCS TO GRAPH DB
@app.post("/kb/")
def populate_graph(apires: RetrivedDocList):
    """Store the submitted documents in the caller's Neo4j graph.

    Each submitted document is wrapped in a LangChain ``Document``, split
    with ``SemanticChunker``, converted to graph documents by
    ``LLMGraphTransformer`` and added to the graph together with its source.
    All of these calls are made before the response is returned.

    Returns:
        A dict with a single ``"success"`` message.

    Raises:
        HTTPException: 401 when ``apires.apisecretkey`` does not match
            ``API_SECRET_KEY``.
    """
    if apires.apisecretkey != API_SECRET_KEY:
        raise HTTPException(status_code=401, detail="Unauthorized")

    print("STARTED")

    graph = Neo4jGraph(url=apires.neourl, username=apires.neouser, password=apires.neopass)

    llm = ChatOpenAI(
        model="gpt-4o-mini",
        temperature=0,
        max_tokens=None,
        timeout=None,
        api_key=apires.openaikey,
    )

    embeddings = OpenAIEmbeddings(
        model="text-embedding-ada-002",
        api_key=apires.openaikey,
    )

    llm_transformer = LLMGraphTransformer(llm=llm)

    raw_documents = [
        Document(page_content=item.pageContent, metadata=item.metadata)
        for item in apires.documents
    ]

    chunks = SemanticChunker(embeddings=embeddings).split_documents(raw_documents)
    graph_documents = llm_transformer.convert_to_graph_documents(chunks)

    graph.add_graph_documents(graph_documents, baseEntityLabel=True, include_source=True)

    print("FINISHED")

    return {"success": "Graph Will be populated Shortly"}
#AUTH CODE
# Extracts the bearer token from incoming requests; "token" is the URL of the
# login endpoint that issues those tokens.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# CORS is configured wide open: every origin, method and header is accepted.
# Recommended for local setups: restrict ``allow_origins`` to the frontend, e.g.
# origins = [
#     "http://localhost:3000",  # adjust the port if your frontend runs elsewhere
#     "https://yourfrontenddomain.com",
# ]
# NOTE(review): wildcard origins combined with allow_credentials=True is
# usually discouraged -- confirm this is intended outside local development.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# Dependency
def get_db():
    """Yield a database session and close it once the consumer is finished.

    Intended for use as a FastAPI dependency (``Depends(get_db)``); the
    ``finally`` clause closes the session even if the consumer raises.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
# Shared passlib context configured for bcrypt; used in this module to hash
# passwords on registration and to verify them on login.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
class UserCreate(BaseModel):
    """Request payload accepted by the ``/register`` endpoint."""

    # Login name of the account to create.
    username: str
    # Plain-text password; it is hashed before being stored.
    password: str
    # Must equal API_SECRET_KEY for the registration to be accepted.
    apisecretkey: str
def get_user_by_username(db: Session, username: str):
    """Return the first ``User`` row whose username equals *username*.

    The result is whatever ``Query.first()`` yields when nothing matches
    (``None`` per SQLAlchemy's documentation).
    """
    matching_users = db.query(User).filter(User.username == username)
    return matching_users.first()
def create_user(db: Session, user: UserCreate):
    """Insert a new ``User`` with a hashed password and commit.

    Only ``user.username`` and ``user.password`` are read; the password is
    hashed through ``pwd_context`` before it is stored.

    Returns:
        The literal string ``"complete"``.
    """
    password_hash = pwd_context.hash(user.password)
    new_user = User(
        username=user.username,
        hashed_password=password_hash,
    )
    db.add(new_user)
    db.commit()
    return "complete"
@app.post("/register")
def register_user(user: UserCreate, db: Session = Depends(get_db)):
    """Create a new account after checking the API secret and the username.

    Returns:
        The value returned by ``create_user``.

    Raises:
        HTTPException: 401 when ``user.apisecretkey`` does not match
            ``API_SECRET_KEY``; 400 when the username is already taken.
    """
    if user.apisecretkey != API_SECRET_KEY:
        raise HTTPException(status_code=401, detail="Unauthorized")

    if get_user_by_username(db, username=user.username):
        raise HTTPException(status_code=400, detail="Username already registered")

    # The secret was only needed for the check above; remove it from the
    # payload before passing the payload on.
    del user.apisecretkey

    return create_user(db=db, user=user)
# Authenticate the user
def authenticate_user(username: str, password: str, db: Session):
    """Check a username/password pair against the stored credentials.

    The lookup goes through ``get_user_by_username`` so that the user query
    lives in one place instead of being repeated here.

    Args:
        username: Login name to look up.
        password: Plain-text password to verify against the stored hash.
        db: Database session used for the lookup.

    Returns:
        The matching ``User`` when the password verifies, otherwise
        ``False`` (unknown user or wrong password).
    """
    user = get_user_by_username(db, username=username)
    if not user:
        return False
    if not pwd_context.verify(password, user.hashed_password):
        return False
    return user
# Create access token
def create_access_token(data: dict, expires_delta: timedelta | None = None):
    """Build a signed JWT carrying *data* plus an ``exp`` claim.

    Args:
        data: Claims to embed; the dict is copied, not modified.
        expires_delta: Token lifetime.  ``None`` selects the default of
            15 minutes.  An explicit zero-length delta is honoured (the
            token expires immediately) rather than being mistaken for
            "not supplied", which a plain truthiness test would do.

    Returns:
        The encoded token produced by ``jwt.encode`` with ``SECRET_KEY``
        and ``ALGORITHM``.
    """
    # Function-scope import: the module header only imports ``datetime`` and
    # ``timedelta`` from the ``datetime`` package.
    from datetime import timezone

    if expires_delta is None:
        expires_delta = timedelta(minutes=15)

    to_encode = data.copy()
    # Timezone-aware UTC "now" replaces the deprecated, naive
    # ``datetime.utcnow()``; both denote the same instant.
    to_encode["exp"] = datetime.now(timezone.utc) + expires_delta
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
@app.post("/token")
def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    """Exchange form-submitted credentials for a bearer access token.

    The token's ``sub`` claim is the username and its lifetime is
    ``ACCESS_TOKEN_EXPIRE_MINUTES`` minutes.

    Returns:
        A dict with ``access_token`` and ``token_type`` (``"bearer"``).

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when
            the credentials are rejected by ``authenticate_user``.
    """
    account = authenticate_user(form_data.username, form_data.password, db)
    if not account:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    token = create_access_token(data={"sub": account.username}, expires_delta=lifetime)
    return {"access_token": token, "token_type": "bearer"}
def verify_token(token: str = Depends(oauth2_scheme)):
    """Decode *token* and return its payload.

    Usable both as a FastAPI dependency (the default pulls the token via
    ``oauth2_scheme``) and as a plain function with an explicit token.

    Returns:
        The decoded JWT payload.

    Raises:
        HTTPException: 403 when decoding raises ``JWTError`` or when the
            payload carries no ``sub`` claim.
    """
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise HTTPException(status_code=403, detail="Token is invalid or expired")

    # A token without a subject is rejected the same way as an undecodable one.
    if claims.get("sub") is None:
        raise HTTPException(status_code=403, detail="Token is invalid or expired")
    return claims
@app.get("/verify-token/{token}")
async def verify_user_token(token: str):
    """Report whether the token given in the URL path is valid.

    Delegates to ``verify_token``, whose ``HTTPException`` (403) propagates
    to the client when the token is rejected.

    Returns:
        ``{"message": "Token is valid"}`` when verification succeeds.
    """
    # NOTE(review): the token travels in the URL path and may therefore end
    # up in access logs -- confirm this is acceptable.
    verify_token(token)
    return {"message": "Token is valid"}
if __name__ == "__main__":
    # Serve the app on 127.0.0.1:8000 when this file is executed directly.
    import uvicorn

    uvicorn.run(app, host="127.0.0.1", port=8000)