diff --git a/backend/models.py b/backend/models.py index 5de9154..0fd65a2 100644 --- a/backend/models.py +++ b/backend/models.py @@ -1,14 +1,41 @@ -from sqlalchemy import Column, Integer, String -from database import Base -from database import engine +from typing import List +from database import Base, engine +from sqlalchemy import Column, ForeignKey, Integer, String, create_engine +from sqlalchemy.orm import relationship -class User(Base): - __tablename__ = "users" +class BaseModel(Base): + __abstract__ = True + __allow_unmapped__ = True id = Column(Integer, primary_key=True, index=True) + + +class Notification(BaseModel): + __tablename__ = "notifications" + + text = Column(String) + user_id = Column(ForeignKey('users.id')) + user = relationship('User') + + +class Chat(BaseModel): + __tablename__ = "chats" + + type = Column(String) + title = Column(String) + chats_list = Column(String) + user_id = Column(ForeignKey('users.id')) + user = relationship('User') + +class User(BaseModel): + __tablename__ = "users" + username = Column(String, unique=True, index=True) hashed_password = Column(String) + graph_config = Column(String) + llm_config = Column(String) + chats = relationship(Chat) + notifications = relationship(Notification) # Create the database tables if they don't exist - User.metadata.create_all(bind=engine) diff --git a/backend/prompts.py b/backend/prompts.py index ad2962b..d45f42b 100644 --- a/backend/prompts.py +++ b/backend/prompts.py @@ -120,6 +120,23 @@ VECTOR_QUERY_GENERATION_PROMT = PromptTemplate( ) +NOTIFICATION_GENERATION_TEMPLATE = """You are a highly attentive assistant. You are provided with a collection of User Browsing History Events containing page content. Your task is to thoroughly analyze these events and generate a concise list of critical notifications that the User must be aware of. + +User Browsing History Events Documents: +{documents} + +Instructions: +Return only the notification text, and nothing else. +Exclude any notifications that are not essential. + +Response:""" + +NOTIFICATION_GENERATION_PROMT = PromptTemplate( + input_variables=["documents"], template=NOTIFICATION_GENERATION_TEMPLATE +) + + + diff --git a/backend/pydmodels.py b/backend/pydmodels.py index 8aa4959..759973d 100644 --- a/backend/pydmodels.py +++ b/backend/pydmodels.py @@ -62,7 +62,7 @@ class RetrivedDocList(BaseModel): neouser: str neopass: str openaikey: str - apisecretkey: str + token: str class UserQueryResponse(BaseModel): @@ -72,4 +72,33 @@ class UserQueryResponse(BaseModel): class VectorSearchQuery(BaseModel): searchquery: str + + +class NewUserData(BaseModel): + token: str + userid: str + chats: str + notifications: str + +class NewUserChat(BaseModel): + token: str + type: str + title: str + chats_list: str + + +class ChatToUpdate(BaseModel): + chatid: str + token: str + # type: str + # title: str + chats_list: str + +class GraphDocs(BaseModel): + documents: List[RetrivedDocListItem] + token: str + + +class Notifications(BaseModel): + notifications: List[str] \ No newline at end of file diff --git a/backend/server.py b/backend/server.py index 19e7843..f83fbe9 100644 --- a/backend/server.py +++ b/backend/server.py @@ -6,8 +6,8 @@ from langchain_core.documents import Document from langchain_openai import OpenAIEmbeddings from langchain_community.vectorstores import Neo4jVector from envs import ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, API_SECRET_KEY, SECRET_KEY -from prompts import CYPHER_QA_PROMPT, DATE_TODAY, DOC_DESCRIPTION_PROMPT, GRAPH_QUERY_GEN_PROMPT, SIMILARITY_SEARCH_PROMPT , CYPHER_GENERATION_PROMPT, DOCUMENT_METADATA_EXTRACTION_PROMT -from pydmodels import DescriptionResponse, PrecisionQuery, PrecisionResponse, UserQuery, DocMeta, RetrivedDocList, UserQueryResponse, UserQueryWithChatHistory, VectorSearchQuery +from prompts import CYPHER_QA_PROMPT, DATE_TODAY, DOC_DESCRIPTION_PROMPT, GRAPH_QUERY_GEN_PROMPT, NOTIFICATION_GENERATION_PROMT, SIMILARITY_SEARCH_PROMPT , CYPHER_GENERATION_PROMPT, DOCUMENT_METADATA_EXTRACTION_PROMT +from pydmodels import ChatToUpdate, DescriptionResponse, GraphDocs, NewUserChat, NewUserData, Notifications, PrecisionQuery, PrecisionResponse, UserQuery, DocMeta, RetrivedDocList, UserQueryResponse, UserQueryWithChatHistory, VectorSearchQuery from langchain_experimental.text_splitter import SemanticChunker from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.messages import HumanMessage, SystemMessage, AIMessage @@ -16,6 +16,12 @@ from langchain_core.messages import HumanMessage, SystemMessage, AIMessage from LLMGraphTransformer import LLMGraphTransformer from langchain_openai import ChatOpenAI from DataExample import examples +# import nest_asyncio +# from langchain_community.chains.graph_qa.gremlin import GremlinQAChain +# from langchain_community.graphs import GremlinGraph +# from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship +# from langchain_core.documents import Document +# from langchain_openai import AzureChatOpenAI # Auth Libs from fastapi import FastAPI, Depends, HTTPException, Request, status @@ -24,13 +30,27 @@ from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from jose import JWTError, jwt from datetime import datetime, timedelta from passlib.context import CryptContext -from models import User +from models import Chat, Notification, User from database import SessionLocal, engine from pydantic import BaseModel from fastapi.middleware.cors import CORSMiddleware +from langchain_openai import AzureChatOpenAI + app = FastAPI() +# Dependency +def get_db(): + db = SessionLocal() + try: + yield db + finally: + db.close() + +class UserCreate(BaseModel): + username: str + password: str + apisecretkey: str @@ -271,54 +291,85 @@ def get_doc_description(data: UserQuery, response_model=DescriptionResponse): # SAVE DOCS TO GRAPH DB @app.post("/kb/") -def populate_graph(apires: RetrivedDocList): - if(apires.apisecretkey != API_SECRET_KEY): - raise HTTPException(status_code=401, detail="Unauthorized") +def populate_graph(apires: RetrivedDocList, db: Session = Depends(get_db)): - print("STARTED") - # print(apires) - graph = Neo4jGraph(url=apires.neourl, username=apires.neouser, password=apires.neopass) - - llm = ChatOpenAI( - model="gpt-4o-mini", - temperature=0, - max_tokens=None, - timeout=None, - api_key=apires.openaikey - ) - - embeddings = OpenAIEmbeddings( - model="text-embedding-ada-002", - api_key=apires.openaikey, - ) - - llm_transformer = LLMGraphTransformer(llm=llm) + try: + payload = jwt.decode(apires.token, SECRET_KEY, algorithms=[ALGORITHM]) + username: str = payload.get("sub") + if username is None: + raise HTTPException(status_code=403, detail="Token is invalid or expired") + + + print("STARTED") + # print(apires) + graph = Neo4jGraph(url=apires.neourl, username=apires.neouser, password=apires.neopass) + + llm = ChatOpenAI( + model="gpt-4o-mini", + temperature=0, + max_tokens=None, + timeout=None, + api_key=apires.openaikey + ) + + embeddings = OpenAIEmbeddings( + model="text-embedding-ada-002", + api_key=apires.openaikey, + ) + + llm_transformer = LLMGraphTransformer(llm=llm) - raw_documents = [] + raw_documents = [] - for doc in apires.documents: - raw_documents.append(Document(page_content=doc.pageContent, metadata=doc.metadata)) + for doc in apires.documents: + raw_documents.append(Document(page_content=doc.pageContent, metadata=doc.metadata)) - text_splitter = SemanticChunker(embeddings=embeddings) + text_splitter = SemanticChunker(embeddings=embeddings) - documents = text_splitter.split_documents(raw_documents) - graph_documents = llm_transformer.convert_to_graph_documents(documents) + documents = text_splitter.split_documents(raw_documents) + graph_documents = llm_transformer.convert_to_graph_documents(documents) - graph.add_graph_documents( - graph_documents, - baseEntityLabel=True, - include_source=True - ) - - print("FINISHED") - - return { - "success": "Graph Will be populated Shortly" - } + graph.add_graph_documents( + graph_documents, + baseEntityLabel=True, + include_source=True + ) + + + structured_llm = llm.with_structured_output(Notifications) + notifs_extraction_chain = NOTIFICATION_GENERATION_PROMT | structured_llm + + notifications = notifs_extraction_chain.invoke({"documents": raw_documents}) + + notifsdb = [] + + for text in notifications.notifications: + notifsdb.append(Notification(text=text)) + + user = db.query(User).filter(User.username == username).first() + user.notifications.extend(notifsdb) + + db.commit() + + print("FINISHED") + + return { + "success": "Graph Will be populated Shortly" + } + + except JWTError: + raise HTTPException(status_code=403, detail="Token is invalid or expired") - +#Fuction to populate db ( Comment out when running on server ) +@app.post("/add/") +def add_user(user: UserCreate, db: Session = Depends(get_db)): + db_user = User(username=user.username, hashed_password=user.password, graph_config="", llm_config="") + db.add(db_user) + db.commit() + return "Success" + #AUTH CODE oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") @@ -337,28 +388,14 @@ app.add_middleware( allow_headers=["*"], # Allows all headers ) -# Dependency -def get_db(): - db = SessionLocal() - try: - yield db - finally: - db.close() - pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") - -class UserCreate(BaseModel): - username: str - password: str - apisecretkey: str - def get_user_by_username(db: Session, username: str): return db.query(User).filter(User.username == username).first() def create_user(db: Session, user: UserCreate): hashed_password = pwd_context.hash(user.password) - db_user = User(username=user.username, hashed_password=hashed_password) + db_user = User(username=user.username, hashed_password=hashed_password, graph_config="", llm_config="") db.add(db_user) db.commit() return "complete" @@ -410,7 +447,6 @@ def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: ) return {"access_token": access_token, "token_type": "bearer"} - def verify_token(token: str = Depends(oauth2_scheme)): try: payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) @@ -427,6 +463,103 @@ async def verify_user_token(token: str): return {"message": "Token is valid"} + + +@app.post("/user/chat/save") +def populate_user_chat(chat: NewUserChat, db: Session = Depends(get_db)): + try: + payload = jwt.decode(chat.token, SECRET_KEY, algorithms=[ALGORITHM]) + username: str = payload.get("sub") + if username is None: + raise HTTPException(status_code=403, detail="Token is invalid or expired") + + user = db.query(User).filter(User.username == username).first() + newchat = Chat(type=chat.type, title=chat.title, chats_list=chat.chats_list) + + user.chats.append(newchat) + db.commit() + return { + "message": "Chat Saved" + } + except JWTError: + raise HTTPException(status_code=403, detail="Token is invalid or expired") + +@app.post("/user/chat/update") +def populate_user_chat(chat: ChatToUpdate, db: Session = Depends(get_db)): + try: + payload = jwt.decode(chat.token, SECRET_KEY, algorithms=[ALGORITHM]) + username: str = payload.get("sub") + if username is None: + raise HTTPException(status_code=403, detail="Token is invalid or expired") + + chatindb = db.query(Chat).filter(Chat.id == chat.chatid).first() + chatindb.chats_list = chat.chats_list + + db.commit() + return { + "message": "Chat Updated" + } + except JWTError: + raise HTTPException(status_code=403, detail="Token is invalid or expired") + +@app.get("/user/chat/delete/{token}/{chatid}") +async def delete_chat_of_user(token: str, chatid: str, db: Session = Depends(get_db)): + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + username: str = payload.get("sub") + if username is None: + raise HTTPException(status_code=403, detail="Token is invalid or expired") + + chatindb = db.query(Chat).filter(Chat.id == chatid).first() + db.delete(chatindb) + db.commit() + return { + "message": "Chat Deleted" + } + except JWTError: + raise HTTPException(status_code=403, detail="Token is invalid or expired") + +#Gets user id & name +@app.get("/user/{token}") +async def get_user_with_token(token: str, db: Session = Depends(get_db)): + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + username: str = payload.get("sub") + if username is None: + raise HTTPException(status_code=403, detail="Token is invalid or expired") + + user = db.query(User).filter(User.username == username).first() + return { + "userid": user.id, + "username": user.username, + "chats": user.chats, + "notifications": user.notifications + } + except JWTError: + raise HTTPException(status_code=403, detail="Token is invalid or expired") + + +@app.get("/user/notification/delete/{token}/{notificationid}") +async def delete_chat_of_user(token: str, notificationid: str, db: Session = Depends(get_db)): + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + username: str = payload.get("sub") + if username is None: + raise HTTPException(status_code=403, detail="Token is invalid or expired") + + notificationindb = db.query(Notification).filter(Notification.id == notificationid).first() + db.delete(notificationindb) + db.commit() + return { + "message": "Notification Deleted" + } + except JWTError: + raise HTTPException(status_code=403, detail="Token is invalid or expired") + + + + + if __name__ == "__main__": import uvicorn diff --git a/extension/public/manifest.json b/extension/public/manifest.json index 1453c32..f98d36d 100644 --- a/extension/public/manifest.json +++ b/extension/public/manifest.json @@ -3,7 +3,7 @@ "name": "SurfSense", "description": "Extension to collect Browsing History for SurfSense.", - "version": "0.0.2", + "version": "0.0.3", "action": { "default_popup": "popup.html" diff --git a/extension/src/popup.tsx b/extension/src/popup.tsx index 52ac932..889f00c 100644 --- a/extension/src/popup.tsx +++ b/extension/src/popup.tsx @@ -216,7 +216,7 @@ export const Popup = () => { neouser: localStorage.getItem('neouser'), neopass: localStorage.getItem('neopass'), openaikey: localStorage.getItem('openaikey'), - apisecretkey: API_SECRET_KEY + token: localStorage.getItem('token') } // console.log("toSend",toSend)