from __future__ import annotations from langchain.chains import GraphCypherQAChain from langchain_community.graphs import Neo4jGraph from langchain_core.documents import Document from langchain_openai import OpenAIEmbeddings from langchain_community.vectorstores import Neo4jVector from envs import ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, API_SECRET_KEY, SECRET_KEY from prompts import CYPHER_QA_PROMPT, DATE_TODAY, DOC_DESCRIPTION_PROMPT, GRAPH_QUERY_GEN_PROMPT, SIMILARITY_SEARCH_PROMPT , CYPHER_GENERATION_PROMPT, DOCUMENT_METADATA_EXTRACTION_PROMT from pydmodels import DescriptionResponse, PrecisionQuery, PrecisionResponse, UserQuery, DocMeta, RetrivedDocList, UserQueryResponse, UserQueryWithChatHistory, VectorSearchQuery from langchain_experimental.text_splitter import SemanticChunker from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.messages import HumanMessage, SystemMessage, AIMessage #Our Imps from LLMGraphTransformer import LLMGraphTransformer from langchain_openai import ChatOpenAI from DataExample import examples # Auth Libs from fastapi import FastAPI, Depends, HTTPException, Request, status from sqlalchemy.orm import Session from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from jose import JWTError, jwt from datetime import datetime, timedelta from passlib.context import CryptContext from models import User from database import SessionLocal, engine from pydantic import BaseModel from fastapi.middleware.cors import CORSMiddleware app = FastAPI() # GraphCypherQAChain @app.post("/") def get_user_query_response(data: UserQuery, response_model=UserQueryResponse): if(data.apisecretkey != API_SECRET_KEY): raise HTTPException(status_code=401, detail="Unauthorized") graph = Neo4jGraph(url=data.neourl, username=data.neouser, password=data.neopass) llm = ChatOpenAI( model="gpt-4o-mini", temperature=0, max_tokens=None, timeout=None, api_key=data.openaikey ) # Query Expansion searchchain = GRAPH_QUERY_GEN_PROMPT | llm # qry = searchchain.invoke({"question": data.query, "context": examples}) query = data.query #qry.content embeddings = OpenAIEmbeddings( model="text-embedding-ada-002", api_key=data.openaikey, ) chain = GraphCypherQAChain.from_llm( graph=graph, cypher_prompt=CYPHER_GENERATION_PROMPT, cypher_llm=llm, verbose=True, validate_cypher=True, qa_prompt=CYPHER_QA_PROMPT , qa_llm=llm, return_intermediate_steps=True, top_k=5, ) vector_index = Neo4jVector.from_existing_graph( embeddings, graph=graph, search_type="hybrid", node_label="Document", text_node_properties=["text"], embedding_node_property="embedding", ) docs = vector_index.similarity_search(data.query,k=5) docstoreturn = [] for doc in docs: docstoreturn.append( DocMeta( BrowsingSessionId=doc.metadata["BrowsingSessionId"] if "BrowsingSessionId" in doc.metadata.keys() else "NOT AVAILABLE", VisitedWebPageURL=doc.metadata["VisitedWebPageURL"] if "VisitedWebPageURL" in doc.metadata.keys()else "NOT AVAILABLE", VisitedWebPageTitle=doc.metadata["VisitedWebPageTitle"] if "VisitedWebPageTitle" in doc.metadata.keys() else "NOT AVAILABLE", VisitedWebPageDateWithTimeInISOString= doc.metadata["VisitedWebPageDateWithTimeInISOString"] if "VisitedWebPageDateWithTimeInISOString" in doc.metadata.keys() else "NOT AVAILABLE", VisitedWebPageReffererURL= doc.metadata["VisitedWebPageReffererURL"] if "VisitedWebPageReffererURL" in doc.metadata.keys() else "NOT AVAILABLE", VisitedWebPageVisitDurationInMilliseconds= doc.metadata["VisitedWebPageVisitDurationInMilliseconds"] if "VisitedWebPageVisitDurationInMilliseconds" in doc.metadata.keys() else None, VisitedWebPageContent= doc.page_content if doc.page_content else "NOT AVAILABLE" ) ) docstoreturn = [i for n, i in enumerate(docstoreturn) if i not in docstoreturn[n + 1:]] # responsegrp = chain.invoke({"query": query}) try: responsegrp = chain.invoke({"query": query}) if "don't know" in responsegrp["result"]: raise Exception("No response from graph") structured_llm = llm.with_structured_output(VectorSearchQuery) doc_extract_chain = DOCUMENT_METADATA_EXTRACTION_PROMT | structured_llm newquery = doc_extract_chain.invoke(responsegrp["intermediate_steps"][1]["context"]) docs = vector_index.similarity_search(newquery.searchquery,k=5) docstoreturn = [] for doc in docs: docstoreturn.append( DocMeta( BrowsingSessionId=doc.metadata["BrowsingSessionId"] if "BrowsingSessionId" in doc.metadata.keys() else "NOT AVAILABLE", VisitedWebPageURL=doc.metadata["VisitedWebPageURL"] if "VisitedWebPageURL" in doc.metadata.keys()else "NOT AVAILABLE", VisitedWebPageTitle=doc.metadata["VisitedWebPageTitle"] if "VisitedWebPageTitle" in doc.metadata.keys() else "NOT AVAILABLE", VisitedWebPageDateWithTimeInISOString= doc.metadata["VisitedWebPageDateWithTimeInISOString"] if "VisitedWebPageDateWithTimeInISOString" in doc.metadata.keys() else "NOT AVAILABLE", VisitedWebPageReffererURL= doc.metadata["VisitedWebPageReffererURL"] if "VisitedWebPageReffererURL" in doc.metadata.keys() else "NOT AVAILABLE", VisitedWebPageVisitDurationInMilliseconds= doc.metadata["VisitedWebPageVisitDurationInMilliseconds"] if "VisitedWebPageVisitDurationInMilliseconds" in doc.metadata.keys() else None, VisitedWebPageContent= doc.page_content if doc.page_content else "NOT AVAILABLE" ) ) docstoreturn = [i for n, i in enumerate(docstoreturn) if i not in docstoreturn[n + 1:]] return UserQueryResponse(relateddocs=docstoreturn,response=responsegrp["result"]) except: # Fallback to Similarity Search RAG searchchain = SIMILARITY_SEARCH_PROMPT | llm response = searchchain.invoke({"question": data.query, "context": docs}) return UserQueryResponse(relateddocs=docstoreturn,response=response.content) #RETURN n LIMIT 25; @app.post("/precision") def get_precision_search_response(data: PrecisionQuery, response_model=PrecisionResponse): if(data.apisecretkey != API_SECRET_KEY): raise HTTPException(status_code=401, detail="Unauthorized") graph = Neo4jGraph(url=data.neourl, username=data.neouser, password=data.neopass) GRAPH_QUERY = "MATCH (d:Document) WHERE d.VisitedWebPageDateWithTimeInISOString >= " + "'" + data.daterange[0] + "'" + " AND d.VisitedWebPageDateWithTimeInISOString <= " + "'" + data.daterange[1] + "'" if(data.timerange[0] >= data.timerange[1]): GRAPH_QUERY += " AND d.VisitedWebPageVisitDurationInMilliseconds >= 0" else: GRAPH_QUERY += " AND d.VisitedWebPageVisitDurationInMilliseconds >= "+ str(data.timerange[0]) + " AND d.VisitedWebPageVisitDurationInMilliseconds <= " + str(data.timerange[1]) if(data.webpageurl): GRAPH_QUERY += " AND d.VisitedWebPageURL CONTAINS " + "'" + data.webpageurl.lower() + "'" if(data.sessionid): GRAPH_QUERY += " AND d.BrowsingSessionId = " + "'" + data.sessionid + "'" GRAPH_QUERY += " RETURN d;" graphdocs = graph.query(GRAPH_QUERY) docsDict = {} for d in graphdocs: if d['d']['BrowsingSessionId'] not in docsDict: docsDict[d['d']['BrowsingSessionId']] = d['d'] else: docsDict[d['d']['BrowsingSessionId']]['text'] += d['d']['text'] docs = [] for x in docsDict.values(): docs.append(DocMeta( BrowsingSessionId=x['BrowsingSessionId'], VisitedWebPageURL=x['VisitedWebPageURL'], VisitedWebPageVisitDurationInMilliseconds=x['VisitedWebPageVisitDurationInMilliseconds'], VisitedWebPageTitle=x['VisitedWebPageTitle'], VisitedWebPageReffererURL=x['VisitedWebPageReffererURL'], VisitedWebPageDateWithTimeInISOString=x['VisitedWebPageDateWithTimeInISOString'], VisitedWebPageContent=x['text'] )) return PrecisionResponse(documents=docs) # Multi DOC Chat @app.post("/chat/docs") def doc_chat_with_history(data: UserQueryWithChatHistory, response_model=DescriptionResponse): if(data.apisecretkey != API_SECRET_KEY): raise HTTPException(status_code=401, detail="Unauthorized") llm = ChatOpenAI( model="gpt-4o-mini", temperature=0, max_tokens=None, timeout=None, api_key=data.openaikey ) chatHistory = [] for chat in data.chat: if(chat.type == 'system'): chatHistory.append(SystemMessage(content=DATE_TODAY + """You are an helpful assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Context:""" + str(chat.content))) if(chat.type == 'ai'): chatHistory.append(AIMessage(content=chat.content)) if(chat.type == 'human'): chatHistory.append(HumanMessage(content=chat.content)) chatHistory.append(("human", "{input}")); qa_prompt = ChatPromptTemplate.from_messages(chatHistory) descriptionchain = qa_prompt | llm response = descriptionchain.invoke({"input": data.query}) return DescriptionResponse(response=response.content) # DOC DESCRIPTION @app.post("/kb/doc") def get_doc_description(data: UserQuery, response_model=DescriptionResponse): if(data.apisecretkey != API_SECRET_KEY): raise HTTPException(status_code=401, detail="Unauthorized") document = data.query llm = ChatOpenAI( model="gpt-4o-mini", temperature=0, max_tokens=None, timeout=None, api_key=data.openaikey ) descriptionchain = DOC_DESCRIPTION_PROMPT | llm response = descriptionchain.invoke({"document": document}) return DescriptionResponse(response=response.content) # SAVE DOCS TO GRAPH DB @app.post("/kb/") def populate_graph(apires: RetrivedDocList): if(apires.apisecretkey != API_SECRET_KEY): raise HTTPException(status_code=401, detail="Unauthorized") print("STARTED") # print(apires) graph = Neo4jGraph(url=apires.neourl, username=apires.neouser, password=apires.neopass) llm = ChatOpenAI( model="gpt-4o-mini", temperature=0, max_tokens=None, timeout=None, api_key=apires.openaikey ) embeddings = OpenAIEmbeddings( model="text-embedding-ada-002", api_key=apires.openaikey, ) llm_transformer = LLMGraphTransformer(llm=llm) raw_documents = [] for doc in apires.documents: raw_documents.append(Document(page_content=doc.pageContent, metadata=doc.metadata)) text_splitter = SemanticChunker(embeddings=embeddings) documents = text_splitter.split_documents(raw_documents) graph_documents = llm_transformer.convert_to_graph_documents(documents) graph.add_graph_documents( graph_documents, baseEntityLabel=True, include_source=True ) print("FINISHED") return { "success": "Graph Will be populated Shortly" } #AUTH CODE oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") # Recommended for Local Setups # origins = [ # "http://localhost:3000", # Adjust the port if your frontend runs on a different one # "https://yourfrontenddomain.com", # ] app.add_middleware( CORSMiddleware, allow_origins=["*"], # Allows all origins from the list allow_credentials=True, allow_methods=["*"], # Allows all methods allow_headers=["*"], # Allows all headers ) # Dependency def get_db(): db = SessionLocal() try: yield db finally: db.close() pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") class UserCreate(BaseModel): username: str password: str apisecretkey: str def get_user_by_username(db: Session, username: str): return db.query(User).filter(User.username == username).first() def create_user(db: Session, user: UserCreate): hashed_password = pwd_context.hash(user.password) db_user = User(username=user.username, hashed_password=hashed_password) db.add(db_user) db.commit() return "complete" @app.post("/register") def register_user(user: UserCreate, db: Session = Depends(get_db)): if(user.apisecretkey != API_SECRET_KEY): raise HTTPException(status_code=401, detail="Unauthorized") db_user = get_user_by_username(db, username=user.username) if db_user: raise HTTPException(status_code=400, detail="Username already registered") del user.apisecretkey return create_user(db=db, user=user) # Authenticate the user def authenticate_user(username: str, password: str, db: Session): user = db.query(User).filter(User.username == username).first() if not user: return False if not pwd_context.verify(password, user.hashed_password): return False return user # Create access token def create_access_token(data: dict, expires_delta: timedelta | None = None): to_encode = data.copy() if expires_delta: expire = datetime.utcnow() + expires_delta else: expire = datetime.utcnow() + timedelta(minutes=15) to_encode.update({"exp": expire}) encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) return encoded_jwt @app.post("/token") def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)): user = authenticate_user(form_data.username, form_data.password, db) if not user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect username or password", headers={"WWW-Authenticate": "Bearer"}, ) access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) access_token = create_access_token( data={"sub": user.username}, expires_delta=access_token_expires ) return {"access_token": access_token, "token_type": "bearer"} def verify_token(token: str = Depends(oauth2_scheme)): try: payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) username: str = payload.get("sub") if username is None: raise HTTPException(status_code=403, detail="Token is invalid or expired") return payload except JWTError: raise HTTPException(status_code=403, detail="Token is invalid or expired") @app.get("/verify-token/{token}") async def verify_user_token(token: str): verify_token(token=token) return {"message": "Token is valid"} if __name__ == "__main__": import uvicorn uvicorn.run(app, host="127.0.0.1", port=8000)