feat: Save Chats, Notifications + extension to v0.0.3

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2024-08-21 23:06:30 -07:00
parent 2f22acbfa0
commit 8ab9c26b4c
6 changed files with 273 additions and 67 deletions

View file

@ -1,14 +1,41 @@
from sqlalchemy import Column, Integer, String
from database import Base
from database import engine
from typing import List
from database import Base, engine
from sqlalchemy import Column, ForeignKey, Integer, String, create_engine
from sqlalchemy.orm import relationship
class User(Base):
__tablename__ = "users"
class BaseModel(Base):
__abstract__ = True
__allow_unmapped__ = True
id = Column(Integer, primary_key=True, index=True)
class Notification(BaseModel):
__tablename__ = "notifications"
text = Column(String)
user_id = Column(ForeignKey('users.id'))
user = relationship('User')
class Chat(BaseModel):
__tablename__ = "chats"
type = Column(String)
title = Column(String)
chats_list = Column(String)
user_id = Column(ForeignKey('users.id'))
user = relationship('User')
class User(BaseModel):
__tablename__ = "users"
username = Column(String, unique=True, index=True)
hashed_password = Column(String)
graph_config = Column(String)
llm_config = Column(String)
chats = relationship(Chat)
notifications = relationship(Notification)
# Create the database tables if they don't exist
User.metadata.create_all(bind=engine)

View file

@ -120,6 +120,23 @@ VECTOR_QUERY_GENERATION_PROMT = PromptTemplate(
)
NOTIFICATION_GENERATION_TEMPLATE = """You are a highly attentive assistant. You are provided with a collection of User Browsing History Events containing page content. Your task is to thoroughly analyze these events and generate a concise list of critical notifications that the User must be aware of.
User Browsing History Events Documents:
{documents}
Instructions:
Return only the notification text, and nothing else.
Exclude any notifications that are not essential.
Response:"""
NOTIFICATION_GENERATION_PROMT = PromptTemplate(
input_variables=["documents"], template=NOTIFICATION_GENERATION_TEMPLATE
)

View file

@ -62,7 +62,7 @@ class RetrivedDocList(BaseModel):
neouser: str
neopass: str
openaikey: str
apisecretkey: str
token: str
class UserQueryResponse(BaseModel):
@ -72,4 +72,33 @@ class UserQueryResponse(BaseModel):
class VectorSearchQuery(BaseModel):
searchquery: str
class NewUserData(BaseModel):
token: str
userid: str
chats: str
notifications: str
class NewUserChat(BaseModel):
token: str
type: str
title: str
chats_list: str
class ChatToUpdate(BaseModel):
chatid: str
token: str
# type: str
# title: str
chats_list: str
class GraphDocs(BaseModel):
documents: List[RetrivedDocListItem]
token: str
class Notifications(BaseModel):
notifications: List[str]

View file

@ -6,8 +6,8 @@ from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import Neo4jVector
from envs import ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, API_SECRET_KEY, SECRET_KEY
from prompts import CYPHER_QA_PROMPT, DATE_TODAY, DOC_DESCRIPTION_PROMPT, GRAPH_QUERY_GEN_PROMPT, SIMILARITY_SEARCH_PROMPT , CYPHER_GENERATION_PROMPT, DOCUMENT_METADATA_EXTRACTION_PROMT
from pydmodels import DescriptionResponse, PrecisionQuery, PrecisionResponse, UserQuery, DocMeta, RetrivedDocList, UserQueryResponse, UserQueryWithChatHistory, VectorSearchQuery
from prompts import CYPHER_QA_PROMPT, DATE_TODAY, DOC_DESCRIPTION_PROMPT, GRAPH_QUERY_GEN_PROMPT, NOTIFICATION_GENERATION_PROMT, SIMILARITY_SEARCH_PROMPT , CYPHER_GENERATION_PROMPT, DOCUMENT_METADATA_EXTRACTION_PROMT
from pydmodels import ChatToUpdate, DescriptionResponse, GraphDocs, NewUserChat, NewUserData, Notifications, PrecisionQuery, PrecisionResponse, UserQuery, DocMeta, RetrivedDocList, UserQueryResponse, UserQueryWithChatHistory, VectorSearchQuery
from langchain_experimental.text_splitter import SemanticChunker
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
@ -16,6 +16,12 @@ from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from LLMGraphTransformer import LLMGraphTransformer
from langchain_openai import ChatOpenAI
from DataExample import examples
# import nest_asyncio
# from langchain_community.chains.graph_qa.gremlin import GremlinQAChain
# from langchain_community.graphs import GremlinGraph
# from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
# from langchain_core.documents import Document
# from langchain_openai import AzureChatOpenAI
# Auth Libs
from fastapi import FastAPI, Depends, HTTPException, Request, status
@ -24,13 +30,27 @@ from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from datetime import datetime, timedelta
from passlib.context import CryptContext
from models import User
from models import Chat, Notification, User
from database import SessionLocal, engine
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from langchain_openai import AzureChatOpenAI
app = FastAPI()
# Dependency
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
class UserCreate(BaseModel):
username: str
password: str
apisecretkey: str
@ -271,54 +291,85 @@ def get_doc_description(data: UserQuery, response_model=DescriptionResponse):
# SAVE DOCS TO GRAPH DB
@app.post("/kb/")
def populate_graph(apires: RetrivedDocList):
if(apires.apisecretkey != API_SECRET_KEY):
raise HTTPException(status_code=401, detail="Unauthorized")
def populate_graph(apires: RetrivedDocList, db: Session = Depends(get_db)):
print("STARTED")
# print(apires)
graph = Neo4jGraph(url=apires.neourl, username=apires.neouser, password=apires.neopass)
llm = ChatOpenAI(
model="gpt-4o-mini",
temperature=0,
max_tokens=None,
timeout=None,
api_key=apires.openaikey
)
embeddings = OpenAIEmbeddings(
model="text-embedding-ada-002",
api_key=apires.openaikey,
)
llm_transformer = LLMGraphTransformer(llm=llm)
try:
payload = jwt.decode(apires.token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
print("STARTED")
# print(apires)
graph = Neo4jGraph(url=apires.neourl, username=apires.neouser, password=apires.neopass)
llm = ChatOpenAI(
model="gpt-4o-mini",
temperature=0,
max_tokens=None,
timeout=None,
api_key=apires.openaikey
)
embeddings = OpenAIEmbeddings(
model="text-embedding-ada-002",
api_key=apires.openaikey,
)
llm_transformer = LLMGraphTransformer(llm=llm)
raw_documents = []
raw_documents = []
for doc in apires.documents:
raw_documents.append(Document(page_content=doc.pageContent, metadata=doc.metadata))
for doc in apires.documents:
raw_documents.append(Document(page_content=doc.pageContent, metadata=doc.metadata))
text_splitter = SemanticChunker(embeddings=embeddings)
text_splitter = SemanticChunker(embeddings=embeddings)
documents = text_splitter.split_documents(raw_documents)
graph_documents = llm_transformer.convert_to_graph_documents(documents)
documents = text_splitter.split_documents(raw_documents)
graph_documents = llm_transformer.convert_to_graph_documents(documents)
graph.add_graph_documents(
graph_documents,
baseEntityLabel=True,
include_source=True
)
print("FINISHED")
return {
"success": "Graph Will be populated Shortly"
}
graph.add_graph_documents(
graph_documents,
baseEntityLabel=True,
include_source=True
)
structured_llm = llm.with_structured_output(Notifications)
notifs_extraction_chain = NOTIFICATION_GENERATION_PROMT | structured_llm
notifications = notifs_extraction_chain.invoke({"documents": raw_documents})
notifsdb = []
for text in notifications.notifications:
notifsdb.append(Notification(text=text))
user = db.query(User).filter(User.username == username).first()
user.notifications.extend(notifsdb)
db.commit()
print("FINISHED")
return {
"success": "Graph Will be populated Shortly"
}
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
#Fuction to populate db ( Comment out when running on server )
@app.post("/add/")
def add_user(user: UserCreate, db: Session = Depends(get_db)):
db_user = User(username=user.username, hashed_password=user.password, graph_config="", llm_config="")
db.add(db_user)
db.commit()
return "Success"
#AUTH CODE
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
@ -337,28 +388,14 @@ app.add_middleware(
allow_headers=["*"], # Allows all headers
)
# Dependency
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
class UserCreate(BaseModel):
username: str
password: str
apisecretkey: str
def get_user_by_username(db: Session, username: str):
return db.query(User).filter(User.username == username).first()
def create_user(db: Session, user: UserCreate):
hashed_password = pwd_context.hash(user.password)
db_user = User(username=user.username, hashed_password=hashed_password)
db_user = User(username=user.username, hashed_password=hashed_password, graph_config="", llm_config="")
db.add(db_user)
db.commit()
return "complete"
@ -410,7 +447,6 @@ def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db:
)
return {"access_token": access_token, "token_type": "bearer"}
def verify_token(token: str = Depends(oauth2_scheme)):
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
@ -427,6 +463,103 @@ async def verify_user_token(token: str):
return {"message": "Token is valid"}
@app.post("/user/chat/save")
def populate_user_chat(chat: NewUserChat, db: Session = Depends(get_db)):
try:
payload = jwt.decode(chat.token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
user = db.query(User).filter(User.username == username).first()
newchat = Chat(type=chat.type, title=chat.title, chats_list=chat.chats_list)
user.chats.append(newchat)
db.commit()
return {
"message": "Chat Saved"
}
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
@app.post("/user/chat/update")
def populate_user_chat(chat: ChatToUpdate, db: Session = Depends(get_db)):
try:
payload = jwt.decode(chat.token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
chatindb = db.query(Chat).filter(Chat.id == chat.chatid).first()
chatindb.chats_list = chat.chats_list
db.commit()
return {
"message": "Chat Updated"
}
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
@app.get("/user/chat/delete/{token}/{chatid}")
async def delete_chat_of_user(token: str, chatid: str, db: Session = Depends(get_db)):
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
chatindb = db.query(Chat).filter(Chat.id == chatid).first()
db.delete(chatindb)
db.commit()
return {
"message": "Chat Deleted"
}
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
#Gets user id & name
@app.get("/user/{token}")
async def get_user_with_token(token: str, db: Session = Depends(get_db)):
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
user = db.query(User).filter(User.username == username).first()
return {
"userid": user.id,
"username": user.username,
"chats": user.chats,
"notifications": user.notifications
}
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
@app.get("/user/notification/delete/{token}/{notificationid}")
async def delete_chat_of_user(token: str, notificationid: str, db: Session = Depends(get_db)):
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
notificationindb = db.query(Notification).filter(Notification.id == notificationid).first()
db.delete(notificationindb)
db.commit()
return {
"message": "Notification Deleted"
}
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
if __name__ == "__main__":
import uvicorn

View file

@ -3,7 +3,7 @@
"name": "SurfSense",
"description": "Extension to collect Browsing History for SurfSense.",
"version": "0.0.2",
"version": "0.0.3",
"action": {
"default_popup": "popup.html"

View file

@ -216,7 +216,7 @@ export const Popup = () => {
neouser: localStorage.getItem('neouser'),
neopass: localStorage.getItem('neopass'),
openaikey: localStorage.getItem('openaikey'),
apisecretkey: API_SECRET_KEY
token: localStorage.getItem('token')
}
// console.log("toSend",toSend)