mirror of
https://github.com/MODSetter/SurfSense.git
synced 2025-09-04 19:49:09 +00:00
feat: Save Chats, Notifications + extension to v0.0.3
This commit is contained in:
parent
2f22acbfa0
commit
8ab9c26b4c
6 changed files with 273 additions and 67 deletions
|
@ -1,14 +1,41 @@
|
|||
from sqlalchemy import Column, Integer, String
|
||||
from database import Base
|
||||
from database import engine
|
||||
from typing import List
|
||||
from database import Base, engine
|
||||
from sqlalchemy import Column, ForeignKey, Integer, String, create_engine
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
class BaseModel(Base):
|
||||
__abstract__ = True
|
||||
__allow_unmapped__ = True
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
|
||||
|
||||
class Notification(BaseModel):
|
||||
__tablename__ = "notifications"
|
||||
|
||||
text = Column(String)
|
||||
user_id = Column(ForeignKey('users.id'))
|
||||
user = relationship('User')
|
||||
|
||||
|
||||
class Chat(BaseModel):
|
||||
__tablename__ = "chats"
|
||||
|
||||
type = Column(String)
|
||||
title = Column(String)
|
||||
chats_list = Column(String)
|
||||
user_id = Column(ForeignKey('users.id'))
|
||||
user = relationship('User')
|
||||
|
||||
class User(BaseModel):
|
||||
__tablename__ = "users"
|
||||
|
||||
username = Column(String, unique=True, index=True)
|
||||
hashed_password = Column(String)
|
||||
graph_config = Column(String)
|
||||
llm_config = Column(String)
|
||||
chats = relationship(Chat)
|
||||
notifications = relationship(Notification)
|
||||
|
||||
# Create the database tables if they don't exist
|
||||
|
||||
User.metadata.create_all(bind=engine)
|
||||
|
|
|
@ -120,6 +120,23 @@ VECTOR_QUERY_GENERATION_PROMT = PromptTemplate(
|
|||
)
|
||||
|
||||
|
||||
NOTIFICATION_GENERATION_TEMPLATE = """You are a highly attentive assistant. You are provided with a collection of User Browsing History Events containing page content. Your task is to thoroughly analyze these events and generate a concise list of critical notifications that the User must be aware of.
|
||||
|
||||
User Browsing History Events Documents:
|
||||
{documents}
|
||||
|
||||
Instructions:
|
||||
Return only the notification text, and nothing else.
|
||||
Exclude any notifications that are not essential.
|
||||
|
||||
Response:"""
|
||||
|
||||
NOTIFICATION_GENERATION_PROMT = PromptTemplate(
|
||||
input_variables=["documents"], template=NOTIFICATION_GENERATION_TEMPLATE
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -62,7 +62,7 @@ class RetrivedDocList(BaseModel):
|
|||
neouser: str
|
||||
neopass: str
|
||||
openaikey: str
|
||||
apisecretkey: str
|
||||
token: str
|
||||
|
||||
|
||||
class UserQueryResponse(BaseModel):
|
||||
|
@ -72,4 +72,33 @@ class UserQueryResponse(BaseModel):
|
|||
|
||||
class VectorSearchQuery(BaseModel):
|
||||
searchquery: str
|
||||
|
||||
|
||||
class NewUserData(BaseModel):
|
||||
token: str
|
||||
userid: str
|
||||
chats: str
|
||||
notifications: str
|
||||
|
||||
class NewUserChat(BaseModel):
|
||||
token: str
|
||||
type: str
|
||||
title: str
|
||||
chats_list: str
|
||||
|
||||
|
||||
class ChatToUpdate(BaseModel):
|
||||
chatid: str
|
||||
token: str
|
||||
# type: str
|
||||
# title: str
|
||||
chats_list: str
|
||||
|
||||
class GraphDocs(BaseModel):
|
||||
documents: List[RetrivedDocListItem]
|
||||
token: str
|
||||
|
||||
|
||||
class Notifications(BaseModel):
|
||||
notifications: List[str]
|
||||
|
|
@ -6,8 +6,8 @@ from langchain_core.documents import Document
|
|||
from langchain_openai import OpenAIEmbeddings
|
||||
from langchain_community.vectorstores import Neo4jVector
|
||||
from envs import ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, API_SECRET_KEY, SECRET_KEY
|
||||
from prompts import CYPHER_QA_PROMPT, DATE_TODAY, DOC_DESCRIPTION_PROMPT, GRAPH_QUERY_GEN_PROMPT, SIMILARITY_SEARCH_PROMPT , CYPHER_GENERATION_PROMPT, DOCUMENT_METADATA_EXTRACTION_PROMT
|
||||
from pydmodels import DescriptionResponse, PrecisionQuery, PrecisionResponse, UserQuery, DocMeta, RetrivedDocList, UserQueryResponse, UserQueryWithChatHistory, VectorSearchQuery
|
||||
from prompts import CYPHER_QA_PROMPT, DATE_TODAY, DOC_DESCRIPTION_PROMPT, GRAPH_QUERY_GEN_PROMPT, NOTIFICATION_GENERATION_PROMT, SIMILARITY_SEARCH_PROMPT , CYPHER_GENERATION_PROMPT, DOCUMENT_METADATA_EXTRACTION_PROMT
|
||||
from pydmodels import ChatToUpdate, DescriptionResponse, GraphDocs, NewUserChat, NewUserData, Notifications, PrecisionQuery, PrecisionResponse, UserQuery, DocMeta, RetrivedDocList, UserQueryResponse, UserQueryWithChatHistory, VectorSearchQuery
|
||||
from langchain_experimental.text_splitter import SemanticChunker
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
|
||||
|
@ -16,6 +16,12 @@ from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
|
|||
from LLMGraphTransformer import LLMGraphTransformer
|
||||
from langchain_openai import ChatOpenAI
|
||||
from DataExample import examples
|
||||
# import nest_asyncio
|
||||
# from langchain_community.chains.graph_qa.gremlin import GremlinQAChain
|
||||
# from langchain_community.graphs import GremlinGraph
|
||||
# from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
|
||||
# from langchain_core.documents import Document
|
||||
# from langchain_openai import AzureChatOpenAI
|
||||
|
||||
# Auth Libs
|
||||
from fastapi import FastAPI, Depends, HTTPException, Request, status
|
||||
|
@ -24,13 +30,27 @@ from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
|||
from jose import JWTError, jwt
|
||||
from datetime import datetime, timedelta
|
||||
from passlib.context import CryptContext
|
||||
from models import User
|
||||
from models import Chat, Notification, User
|
||||
from database import SessionLocal, engine
|
||||
from pydantic import BaseModel
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from langchain_openai import AzureChatOpenAI
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Dependency
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
class UserCreate(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
apisecretkey: str
|
||||
|
||||
|
||||
|
||||
|
@ -271,54 +291,85 @@ def get_doc_description(data: UserQuery, response_model=DescriptionResponse):
|
|||
|
||||
# SAVE DOCS TO GRAPH DB
|
||||
@app.post("/kb/")
|
||||
def populate_graph(apires: RetrivedDocList):
|
||||
if(apires.apisecretkey != API_SECRET_KEY):
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
def populate_graph(apires: RetrivedDocList, db: Session = Depends(get_db)):
|
||||
|
||||
print("STARTED")
|
||||
# print(apires)
|
||||
graph = Neo4jGraph(url=apires.neourl, username=apires.neouser, password=apires.neopass)
|
||||
|
||||
llm = ChatOpenAI(
|
||||
model="gpt-4o-mini",
|
||||
temperature=0,
|
||||
max_tokens=None,
|
||||
timeout=None,
|
||||
api_key=apires.openaikey
|
||||
)
|
||||
|
||||
embeddings = OpenAIEmbeddings(
|
||||
model="text-embedding-ada-002",
|
||||
api_key=apires.openaikey,
|
||||
)
|
||||
|
||||
llm_transformer = LLMGraphTransformer(llm=llm)
|
||||
try:
|
||||
payload = jwt.decode(apires.token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
|
||||
print("STARTED")
|
||||
# print(apires)
|
||||
graph = Neo4jGraph(url=apires.neourl, username=apires.neouser, password=apires.neopass)
|
||||
|
||||
llm = ChatOpenAI(
|
||||
model="gpt-4o-mini",
|
||||
temperature=0,
|
||||
max_tokens=None,
|
||||
timeout=None,
|
||||
api_key=apires.openaikey
|
||||
)
|
||||
|
||||
embeddings = OpenAIEmbeddings(
|
||||
model="text-embedding-ada-002",
|
||||
api_key=apires.openaikey,
|
||||
)
|
||||
|
||||
llm_transformer = LLMGraphTransformer(llm=llm)
|
||||
|
||||
raw_documents = []
|
||||
raw_documents = []
|
||||
|
||||
for doc in apires.documents:
|
||||
raw_documents.append(Document(page_content=doc.pageContent, metadata=doc.metadata))
|
||||
for doc in apires.documents:
|
||||
raw_documents.append(Document(page_content=doc.pageContent, metadata=doc.metadata))
|
||||
|
||||
text_splitter = SemanticChunker(embeddings=embeddings)
|
||||
text_splitter = SemanticChunker(embeddings=embeddings)
|
||||
|
||||
documents = text_splitter.split_documents(raw_documents)
|
||||
graph_documents = llm_transformer.convert_to_graph_documents(documents)
|
||||
documents = text_splitter.split_documents(raw_documents)
|
||||
graph_documents = llm_transformer.convert_to_graph_documents(documents)
|
||||
|
||||
|
||||
graph.add_graph_documents(
|
||||
graph_documents,
|
||||
baseEntityLabel=True,
|
||||
include_source=True
|
||||
)
|
||||
|
||||
print("FINISHED")
|
||||
|
||||
return {
|
||||
"success": "Graph Will be populated Shortly"
|
||||
}
|
||||
graph.add_graph_documents(
|
||||
graph_documents,
|
||||
baseEntityLabel=True,
|
||||
include_source=True
|
||||
)
|
||||
|
||||
|
||||
structured_llm = llm.with_structured_output(Notifications)
|
||||
notifs_extraction_chain = NOTIFICATION_GENERATION_PROMT | structured_llm
|
||||
|
||||
notifications = notifs_extraction_chain.invoke({"documents": raw_documents})
|
||||
|
||||
notifsdb = []
|
||||
|
||||
for text in notifications.notifications:
|
||||
notifsdb.append(Notification(text=text))
|
||||
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
user.notifications.extend(notifsdb)
|
||||
|
||||
db.commit()
|
||||
|
||||
print("FINISHED")
|
||||
|
||||
return {
|
||||
"success": "Graph Will be populated Shortly"
|
||||
}
|
||||
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
|
||||
|
||||
#Fuction to populate db ( Comment out when running on server )
|
||||
@app.post("/add/")
|
||||
def add_user(user: UserCreate, db: Session = Depends(get_db)):
|
||||
db_user = User(username=user.username, hashed_password=user.password, graph_config="", llm_config="")
|
||||
db.add(db_user)
|
||||
db.commit()
|
||||
return "Success"
|
||||
|
||||
|
||||
#AUTH CODE
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
@ -337,28 +388,14 @@ app.add_middleware(
|
|||
allow_headers=["*"], # Allows all headers
|
||||
)
|
||||
|
||||
# Dependency
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
class UserCreate(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
apisecretkey: str
|
||||
|
||||
def get_user_by_username(db: Session, username: str):
|
||||
return db.query(User).filter(User.username == username).first()
|
||||
|
||||
def create_user(db: Session, user: UserCreate):
|
||||
hashed_password = pwd_context.hash(user.password)
|
||||
db_user = User(username=user.username, hashed_password=hashed_password)
|
||||
db_user = User(username=user.username, hashed_password=hashed_password, graph_config="", llm_config="")
|
||||
db.add(db_user)
|
||||
db.commit()
|
||||
return "complete"
|
||||
|
@ -410,7 +447,6 @@ def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db:
|
|||
)
|
||||
return {"access_token": access_token, "token_type": "bearer"}
|
||||
|
||||
|
||||
def verify_token(token: str = Depends(oauth2_scheme)):
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
|
@ -427,6 +463,103 @@ async def verify_user_token(token: str):
|
|||
return {"message": "Token is valid"}
|
||||
|
||||
|
||||
|
||||
|
||||
@app.post("/user/chat/save")
|
||||
def populate_user_chat(chat: NewUserChat, db: Session = Depends(get_db)):
|
||||
try:
|
||||
payload = jwt.decode(chat.token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
newchat = Chat(type=chat.type, title=chat.title, chats_list=chat.chats_list)
|
||||
|
||||
user.chats.append(newchat)
|
||||
db.commit()
|
||||
return {
|
||||
"message": "Chat Saved"
|
||||
}
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
@app.post("/user/chat/update")
|
||||
def populate_user_chat(chat: ChatToUpdate, db: Session = Depends(get_db)):
|
||||
try:
|
||||
payload = jwt.decode(chat.token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
chatindb = db.query(Chat).filter(Chat.id == chat.chatid).first()
|
||||
chatindb.chats_list = chat.chats_list
|
||||
|
||||
db.commit()
|
||||
return {
|
||||
"message": "Chat Updated"
|
||||
}
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
@app.get("/user/chat/delete/{token}/{chatid}")
|
||||
async def delete_chat_of_user(token: str, chatid: str, db: Session = Depends(get_db)):
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
chatindb = db.query(Chat).filter(Chat.id == chatid).first()
|
||||
db.delete(chatindb)
|
||||
db.commit()
|
||||
return {
|
||||
"message": "Chat Deleted"
|
||||
}
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
#Gets user id & name
|
||||
@app.get("/user/{token}")
|
||||
async def get_user_with_token(token: str, db: Session = Depends(get_db)):
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
return {
|
||||
"userid": user.id,
|
||||
"username": user.username,
|
||||
"chats": user.chats,
|
||||
"notifications": user.notifications
|
||||
}
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
|
||||
@app.get("/user/notification/delete/{token}/{notificationid}")
|
||||
async def delete_chat_of_user(token: str, notificationid: str, db: Session = Depends(get_db)):
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
notificationindb = db.query(Notification).filter(Notification.id == notificationid).first()
|
||||
db.delete(notificationindb)
|
||||
db.commit()
|
||||
return {
|
||||
"message": "Notification Deleted"
|
||||
}
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
"name": "SurfSense",
|
||||
"description": "Extension to collect Browsing History for SurfSense.",
|
||||
"version": "0.0.2",
|
||||
"version": "0.0.3",
|
||||
|
||||
"action": {
|
||||
"default_popup": "popup.html"
|
||||
|
|
|
@ -216,7 +216,7 @@ export const Popup = () => {
|
|||
neouser: localStorage.getItem('neouser'),
|
||||
neopass: localStorage.getItem('neopass'),
|
||||
openaikey: localStorage.getItem('openaikey'),
|
||||
apisecretkey: API_SECRET_KEY
|
||||
token: localStorage.getItem('token')
|
||||
}
|
||||
|
||||
// console.log("toSend",toSend)
|
||||
|
|
Loading…
Add table
Reference in a new issue