mirror of
https://github.com/MODSetter/SurfSense.git
synced 2025-09-02 10:39:13 +00:00
322 lines
No EOL
11 KiB
Python
322 lines
No EOL
11 KiB
Python
from __future__ import annotations
|
|
|
|
from langchain.chains import GraphCypherQAChain
|
|
from langchain_community.graphs import Neo4jGraph
|
|
from langchain_core.documents import Document
|
|
from langchain_openai import OpenAIEmbeddings
|
|
from langchain_community.vectorstores import Neo4jVector
|
|
from envs import ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, API_SECRET_KEY, SECRET_KEY
|
|
from prompts import CYPHER_QA_PROMPT, DOC_DESCRIPTION_PROMPT, SIMILARITY_SEARCH_PROMPT , CYPHER_GENERATION_PROMPT, DOCUMENT_METADATA_EXTRACTION_PROMT
|
|
from pydmodels import DescriptionResponse, UserQuery, DocMeta, RetrivedDocList, UserQueryResponse
|
|
from langchain_experimental.text_splitter import SemanticChunker
|
|
|
|
#Our Imps
|
|
from LLMGraphTransformer import LLMGraphTransformer
|
|
from langchain_openai import ChatOpenAI
|
|
|
|
# Auth Libs
|
|
from fastapi import FastAPI, Depends, HTTPException, Request, status
|
|
from sqlalchemy.orm import Session
|
|
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
|
from jose import JWTError, jwt
|
|
from datetime import datetime, timedelta
|
|
from passlib.context import CryptContext
|
|
from models import User
|
|
from database import SessionLocal, engine
|
|
from pydantic import BaseModel
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
# ASGI application object; the routes and the CORS middleware in this module
# are registered on it.
app = FastAPI()
|
|
|
|
|
|
|
|
|
|
|
|
def _to_doc_meta(doc):
    """Convert one retrieved document into a ``DocMeta`` record.

    Missing metadata keys fall back to ``"NOT AVAILABLE"`` (``None`` for the
    visit duration); empty page content also becomes ``"NOT AVAILABLE"``.
    """
    meta = doc.metadata
    return DocMeta(
        BrowsingSessionId=meta.get("BrowsingSessionId", "NOT AVAILABLE"),
        VisitedWebPageURL=meta.get("VisitedWebPageURL", "NOT AVAILABLE"),
        VisitedWebPageTitle=meta.get("VisitedWebPageTitle", "NOT AVAILABLE"),
        VisitedWebPageDateWithTimeInISOString=meta.get(
            "VisitedWebPageDateWithTimeInISOString", "NOT AVAILABLE"
        ),
        VisitedWebPageReffererURL=meta.get("VisitedWebPageReffererURL", "NOT AVAILABLE"),
        VisitedWebPageVisitDurationInMilliseconds=meta.get(
            "VisitedWebPageVisitDurationInMilliseconds"
        ),
        VisitedWebPageContent=doc.page_content if doc.page_content else "NOT AVAILABLE",
    )


def _unique_doc_metas(docs):
    """Build the ``DocMeta`` list for *docs*, dropping duplicate records.

    When two records compare equal, the later one is kept (same rule as the
    original inline list comprehension).
    """
    metas = [_to_doc_meta(doc) for doc in docs]
    return [meta for idx, meta in enumerate(metas) if meta not in metas[idx + 1:]]


# Answer a user query against the caller's Neo4j knowledge graph.
#
# Flow:
#   1. Reject the request with 401 unless data.apisecretkey == API_SECRET_KEY.
#   2. Run a hybrid similarity search (k=5) for the raw query; these documents
#      are the fallback context and the fallback "related docs".
#   3. Try the Cypher QA chain. If it produces an answer, derive a search query
#      from the chain's intermediate context and return the documents found for
#      that query together with the chain's answer.
#   4. If the chain answers "don't know" or anything in step 3 raises, answer
#      with SIMILARITY_SEARCH_PROMPT over the documents from step 2 instead.
#
# NOTE(review): `response_model` is declared as a function parameter rather
# than passed to the `@app.post` decorator, so FastAPI presumably does not use
# it as the route's response model -- confirm, and consider moving it.
@app.post("/")
def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
    if data.apisecretkey != API_SECRET_KEY:
        raise HTTPException(status_code=401, detail="Unauthorized")

    query = data.query

    graph = Neo4jGraph(url=data.neourl, username=data.neouser, password=data.neopass)

    llm = ChatOpenAI(
        model="gpt-4o-mini",
        temperature=0,
        max_tokens=None,
        timeout=None,
        api_key=data.openaikey,
    )

    embeddings = OpenAIEmbeddings(
        model="text-embedding-ada-002",
        api_key=data.openaikey,
    )

    chain = GraphCypherQAChain.from_llm(
        graph=graph,
        cypher_prompt=CYPHER_GENERATION_PROMPT,
        cypher_llm=llm,
        verbose=True,
        validate_cypher=True,
        qa_prompt=CYPHER_QA_PROMPT,
        qa_llm=llm,
        return_intermediate_steps=True,
        top_k=5,
    )

    vector_index = Neo4jVector.from_existing_graph(
        embeddings,
        graph=graph,
        search_type="hybrid",
        node_label="Document",
        text_node_properties=["text"],
        embedding_node_property="embedding",
    )

    # Baseline retrieval for the raw question; kept untouched so the fallback
    # path below always sees exactly these documents.
    docs = vector_index.similarity_search(query, k=5)
    docstoreturn = _unique_doc_metas(docs)

    try:
        response = chain.invoke({"query": query})
        if "don't know" in response["result"]:
            # Treat an "I don't know" style answer as a failure so the
            # similarity-search fallback below takes over.
            raise Exception("No response from graph")

        structured_llm = llm.with_structured_output(RetrivedDocList)
        doc_extract_chain = DOCUMENT_METADATA_EXTRACTION_PROMT | structured_llm

        # Use distinct names here: the fallback must still see the original
        # `query`, `docs` and `docstoreturn` if any of these steps raises.
        extracted = doc_extract_chain.invoke(response["intermediate_steps"][1]["context"])
        graph_docs = vector_index.similarity_search(extracted.searchquery, k=5)

        return UserQueryResponse(
            relateddocs=_unique_doc_metas(graph_docs),
            response=response["result"],
        )
    except Exception:
        # Fallback to Similarity Search RAG. `Exception` (not a bare except)
        # so KeyboardInterrupt / SystemExit are not swallowed.
        searchchain = SIMILARITY_SEARCH_PROMPT | llm

        response = searchchain.invoke({"question": query, "context": docs})

        return UserQueryResponse(relateddocs=docstoreturn, response=response.content)
|
|
|
|
# DOC DESCRIPTION
|
|
# Generate a description for the document text carried in data.query.
#
# Rejects the request with 401 unless data.apisecretkey == API_SECRET_KEY,
# then pipes DOC_DESCRIPTION_PROMPT into a gpt-4o-mini chat model (using the
# caller-supplied OpenAI key) and returns the model's text content.
#
# NOTE(review): `response_model` is a function parameter here, not a decorator
# argument, so FastAPI presumably does not apply it to the route -- confirm.
@app.post("/kb/doc")
def get_doc_description(data: UserQuery, response_model=DescriptionResponse):
    if data.apisecretkey != API_SECRET_KEY:
        raise HTTPException(status_code=401, detail="Unauthorized")

    chat_model = ChatOpenAI(
        model="gpt-4o-mini",
        temperature=0,
        max_tokens=None,
        timeout=None,
        api_key=data.openaikey,
    )

    description = (DOC_DESCRIPTION_PROMPT | chat_model).invoke({"document": data.query})
    return DescriptionResponse(response=description.content)
|
|
|
|
|
|
# SAVE DOCS TO GRAPH DB
|
|
# Store the submitted documents in the caller's Neo4j graph.
#
# Rejects the request with 401 unless apires.apisecretkey == API_SECRET_KEY.
# Otherwise each submitted document is wrapped in a langchain Document, split
# with SemanticChunker, converted to graph documents by LLMGraphTransformer and
# written with add_graph_documents. All of this runs inside the request; the
# success payload is returned only after the write call has returned.
@app.post("/kb/")
def populate_graph(apires: RetrivedDocList):
    if apires.apisecretkey != API_SECRET_KEY:
        raise HTTPException(status_code=401, detail="Unauthorized")

    print("STARTED")

    neo4j_graph = Neo4jGraph(
        url=apires.neourl,
        username=apires.neouser,
        password=apires.neopass,
    )

    chat_model = ChatOpenAI(
        model="gpt-4o-mini",
        temperature=0,
        max_tokens=None,
        timeout=None,
        api_key=apires.openaikey,
    )

    embedder = OpenAIEmbeddings(
        model="text-embedding-ada-002",
        api_key=apires.openaikey,
    )

    transformer = LLMGraphTransformer(llm=chat_model)

    source_docs = [
        Document(page_content=item.pageContent, metadata=item.metadata)
        for item in apires.documents
    ]

    # Chunk first, then let the LLM extract graph structure per chunk.
    chunks = SemanticChunker(embeddings=embedder).split_documents(source_docs)
    graph_docs = transformer.convert_to_graph_documents(chunks)

    neo4j_graph.add_graph_documents(
        graph_docs,
        baseEntityLabel=True,
        include_source=True,
    )

    print("FINISHED")

    return {"success": "Graph Will be populated Shortly"}
|
|
|
|
|
|
|
|
|
|
#AUTH CODE
|
|
# Bearer-token dependency; "token" is the URL where clients obtain a token
# (served by the POST /token route in this module).
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# Recommended for local setups: restrict CORS to an explicit origin list, e.g.
# origins = [
#     "http://localhost:3000",  # Adjust the port if your frontend runs on a different one
#     "https://yourfrontenddomain.com",
# ]
#
# NOTE(review): the configuration below does not use such a list; it accepts
# every origin while also allowing credentials -- confirm this is intended
# outside local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Wildcard: every origin is allowed
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)
|
|
|
|
# Dependency
|
|
def get_db():
    """Yield a fresh ``SessionLocal`` session and close it afterwards.

    Written as a generator so it can be used via ``Depends(get_db)``; the
    ``finally`` clause closes the session even if the consumer raises.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
|
|
|
|
# Shared bcrypt hashing context; used by create_user to hash passwords and by
# authenticate_user to verify them.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
|
|
|
|
|
# Request body accepted by POST /register.
class UserCreate(BaseModel):
    # Login name; register_user rejects names that already exist.
    username: str
    # Plain-text password as submitted; create_user hashes it before storing.
    password: str
    # Must equal API_SECRET_KEY, otherwise register_user answers 401.
    apisecretkey: str
|
|
|
|
def get_user_by_username(db: Session, username: str):
    """Return the first ``User`` whose username equals *username*.

    The result of ``Query.first()`` is returned as-is, so callers get a falsy
    value when no row matches.
    """
    matching_users = db.query(User).filter(User.username == username)
    return matching_users.first()
|
|
|
|
def create_user(db: Session, user: UserCreate):
    """Persist a new ``User`` built from *user* and return ``"complete"``.

    Only the bcrypt hash of the password is stored, never the plain text.
    The row is added and committed on *db* before returning.
    """
    new_user = User(
        username=user.username,
        hashed_password=pwd_context.hash(user.password),
    )
    db.add(new_user)
    db.commit()
    return "complete"
|
|
|
|
# Register a new user.
#
# Responds 401 when the supplied apisecretkey differs from API_SECRET_KEY and
# 400 when the username is already taken; otherwise delegates to create_user
# and returns its result.
@app.post("/register")
def register_user(user: UserCreate, db: Session = Depends(get_db)):
    if user.apisecretkey != API_SECRET_KEY:
        raise HTTPException(status_code=401, detail="Unauthorized")

    if get_user_by_username(db, username=user.username):
        raise HTTPException(status_code=400, detail="Username already registered")

    # Drop the shared secret from the payload before handing it on.
    del user.apisecretkey
    return create_user(db=db, user=user)
|
|
|
|
# Authenticate the user
|
|
def authenticate_user(username: str, password: str, db: Session):
    """Check *username* / *password* against the stored bcrypt hash.

    Returns the matching ``User`` on success and ``False`` when the user does
    not exist or the password does not verify. The lookup goes through
    ``get_user_by_username`` so both call sites share one query definition.
    """
    user = get_user_by_username(db, username)
    if not user:
        return False
    if not pwd_context.verify(password, user.hashed_password):
        return False
    return user
|
|
|
|
# Create access token
|
|
def create_access_token(data: dict, expires_delta: timedelta | None = None):
    """Return a signed JWT containing *data* plus an ``exp`` claim.

    Args:
        data: Claims to embed; copied, so the caller's dict is not modified.
        expires_delta: Token lifetime. ``None`` means the default of
            15 minutes; an explicit zero-length delta is honoured as given.

    The token is signed with ``SECRET_KEY`` using ``ALGORITHM``.
    """
    # Local import: the module header only imports datetime and timedelta.
    from datetime import timezone

    # `is None` rather than truthiness, so timedelta(0) is not silently
    # replaced by the 15-minute default.
    if expires_delta is None:
        expires_delta = timedelta(minutes=15)

    to_encode = data.copy()
    # Timezone-aware "now" replaces the deprecated, naive datetime.utcnow().
    to_encode.update({"exp": datetime.now(timezone.utc) + expires_delta})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
|
|
|
# Exchange form-encoded username/password credentials for a bearer token.
#
# Responds 401 (with a WWW-Authenticate: Bearer header) when authentication
# fails; otherwise returns a token whose "sub" claim is the username and whose
# lifetime is ACCESS_TOKEN_EXPIRE_MINUTES minutes.
@app.post("/token")
def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    account = authenticate_user(form_data.username, form_data.password, db)
    if not account:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    token = create_access_token(
        data={"sub": account.username},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    return {"access_token": token, "token_type": "bearer"}
|
|
|
|
|
|
def verify_token(token: str = Depends(oauth2_scheme)):
    """Decode *token* and return its payload.

    Raises:
        HTTPException: 403 when decoding raises ``JWTError`` or when the
            payload has no ``sub`` claim.
    """
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise HTTPException(status_code=403, detail="Token is invalid or expired")

    # A token without a subject is treated the same as an undecodable one.
    if payload.get("sub") is None:
        raise HTTPException(status_code=403, detail="Token is invalid or expired")
    return payload
|
|
|
|
# Validate the token given in the URL path.
#
# verify_token raises a 403 HTTPException for a bad token, so reaching the
# return statement means the token decoded and carries a "sub" claim.
@app.get("/verify-token/{token}")
async def verify_user_token(token: str):
    verify_token(token=token)
    return {"message": "Token is valid"}
|
|
|
|
|
|
# Script entry point: serve the app with uvicorn on 127.0.0.1:8000 when this
# file is executed directly.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="127.0.0.1", port=8000)