mirror of
https://github.com/MODSetter/SurfSense.git
synced 2025-09-07 13:09:11 +00:00
feat: Save Chats, Notifications + extension to v0.0.3
This commit is contained in:
parent
2f22acbfa0
commit
8ab9c26b4c
6 changed files with 273 additions and 67 deletions
|
@ -1,14 +1,41 @@
|
||||||
from sqlalchemy import Column, Integer, String
|
from typing import List
|
||||||
from database import Base
|
from database import Base, engine
|
||||||
from database import engine
|
from sqlalchemy import Column, ForeignKey, Integer, String, create_engine
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
class User(Base):
|
class BaseModel(Base):
|
||||||
__tablename__ = "users"
|
__abstract__ = True
|
||||||
|
__allow_unmapped__ = True
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, index=True)
|
id = Column(Integer, primary_key=True, index=True)
|
||||||
|
|
||||||
|
|
||||||
|
class Notification(BaseModel):
|
||||||
|
__tablename__ = "notifications"
|
||||||
|
|
||||||
|
text = Column(String)
|
||||||
|
user_id = Column(ForeignKey('users.id'))
|
||||||
|
user = relationship('User')
|
||||||
|
|
||||||
|
|
||||||
|
class Chat(BaseModel):
|
||||||
|
__tablename__ = "chats"
|
||||||
|
|
||||||
|
type = Column(String)
|
||||||
|
title = Column(String)
|
||||||
|
chats_list = Column(String)
|
||||||
|
user_id = Column(ForeignKey('users.id'))
|
||||||
|
user = relationship('User')
|
||||||
|
|
||||||
|
class User(BaseModel):
|
||||||
|
__tablename__ = "users"
|
||||||
|
|
||||||
username = Column(String, unique=True, index=True)
|
username = Column(String, unique=True, index=True)
|
||||||
hashed_password = Column(String)
|
hashed_password = Column(String)
|
||||||
|
graph_config = Column(String)
|
||||||
|
llm_config = Column(String)
|
||||||
|
chats = relationship(Chat)
|
||||||
|
notifications = relationship(Notification)
|
||||||
|
|
||||||
# Create the database tables if they don't exist
|
# Create the database tables if they don't exist
|
||||||
|
|
||||||
User.metadata.create_all(bind=engine)
|
User.metadata.create_all(bind=engine)
|
||||||
|
|
|
@ -120,6 +120,23 @@ VECTOR_QUERY_GENERATION_PROMT = PromptTemplate(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
NOTIFICATION_GENERATION_TEMPLATE = """You are a highly attentive assistant. You are provided with a collection of User Browsing History Events containing page content. Your task is to thoroughly analyze these events and generate a concise list of critical notifications that the User must be aware of.
|
||||||
|
|
||||||
|
User Browsing History Events Documents:
|
||||||
|
{documents}
|
||||||
|
|
||||||
|
Instructions:
|
||||||
|
Return only the notification text, and nothing else.
|
||||||
|
Exclude any notifications that are not essential.
|
||||||
|
|
||||||
|
Response:"""
|
||||||
|
|
||||||
|
NOTIFICATION_GENERATION_PROMT = PromptTemplate(
|
||||||
|
input_variables=["documents"], template=NOTIFICATION_GENERATION_TEMPLATE
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -62,7 +62,7 @@ class RetrivedDocList(BaseModel):
|
||||||
neouser: str
|
neouser: str
|
||||||
neopass: str
|
neopass: str
|
||||||
openaikey: str
|
openaikey: str
|
||||||
apisecretkey: str
|
token: str
|
||||||
|
|
||||||
|
|
||||||
class UserQueryResponse(BaseModel):
|
class UserQueryResponse(BaseModel):
|
||||||
|
@ -72,4 +72,33 @@ class UserQueryResponse(BaseModel):
|
||||||
|
|
||||||
class VectorSearchQuery(BaseModel):
|
class VectorSearchQuery(BaseModel):
|
||||||
searchquery: str
|
searchquery: str
|
||||||
|
|
||||||
|
|
||||||
|
class NewUserData(BaseModel):
|
||||||
|
token: str
|
||||||
|
userid: str
|
||||||
|
chats: str
|
||||||
|
notifications: str
|
||||||
|
|
||||||
|
class NewUserChat(BaseModel):
|
||||||
|
token: str
|
||||||
|
type: str
|
||||||
|
title: str
|
||||||
|
chats_list: str
|
||||||
|
|
||||||
|
|
||||||
|
class ChatToUpdate(BaseModel):
|
||||||
|
chatid: str
|
||||||
|
token: str
|
||||||
|
# type: str
|
||||||
|
# title: str
|
||||||
|
chats_list: str
|
||||||
|
|
||||||
|
class GraphDocs(BaseModel):
|
||||||
|
documents: List[RetrivedDocListItem]
|
||||||
|
token: str
|
||||||
|
|
||||||
|
|
||||||
|
class Notifications(BaseModel):
|
||||||
|
notifications: List[str]
|
||||||
|
|
|
@ -6,8 +6,8 @@ from langchain_core.documents import Document
|
||||||
from langchain_openai import OpenAIEmbeddings
|
from langchain_openai import OpenAIEmbeddings
|
||||||
from langchain_community.vectorstores import Neo4jVector
|
from langchain_community.vectorstores import Neo4jVector
|
||||||
from envs import ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, API_SECRET_KEY, SECRET_KEY
|
from envs import ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, API_SECRET_KEY, SECRET_KEY
|
||||||
from prompts import CYPHER_QA_PROMPT, DATE_TODAY, DOC_DESCRIPTION_PROMPT, GRAPH_QUERY_GEN_PROMPT, SIMILARITY_SEARCH_PROMPT , CYPHER_GENERATION_PROMPT, DOCUMENT_METADATA_EXTRACTION_PROMT
|
from prompts import CYPHER_QA_PROMPT, DATE_TODAY, DOC_DESCRIPTION_PROMPT, GRAPH_QUERY_GEN_PROMPT, NOTIFICATION_GENERATION_PROMT, SIMILARITY_SEARCH_PROMPT , CYPHER_GENERATION_PROMPT, DOCUMENT_METADATA_EXTRACTION_PROMT
|
||||||
from pydmodels import DescriptionResponse, PrecisionQuery, PrecisionResponse, UserQuery, DocMeta, RetrivedDocList, UserQueryResponse, UserQueryWithChatHistory, VectorSearchQuery
|
from pydmodels import ChatToUpdate, DescriptionResponse, GraphDocs, NewUserChat, NewUserData, Notifications, PrecisionQuery, PrecisionResponse, UserQuery, DocMeta, RetrivedDocList, UserQueryResponse, UserQueryWithChatHistory, VectorSearchQuery
|
||||||
from langchain_experimental.text_splitter import SemanticChunker
|
from langchain_experimental.text_splitter import SemanticChunker
|
||||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
|
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
|
||||||
|
@ -16,6 +16,12 @@ from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
|
||||||
from LLMGraphTransformer import LLMGraphTransformer
|
from LLMGraphTransformer import LLMGraphTransformer
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
from DataExample import examples
|
from DataExample import examples
|
||||||
|
# import nest_asyncio
|
||||||
|
# from langchain_community.chains.graph_qa.gremlin import GremlinQAChain
|
||||||
|
# from langchain_community.graphs import GremlinGraph
|
||||||
|
# from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
|
||||||
|
# from langchain_core.documents import Document
|
||||||
|
# from langchain_openai import AzureChatOpenAI
|
||||||
|
|
||||||
# Auth Libs
|
# Auth Libs
|
||||||
from fastapi import FastAPI, Depends, HTTPException, Request, status
|
from fastapi import FastAPI, Depends, HTTPException, Request, status
|
||||||
|
@ -24,13 +30,27 @@ from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||||
from jose import JWTError, jwt
|
from jose import JWTError, jwt
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from passlib.context import CryptContext
|
from passlib.context import CryptContext
|
||||||
from models import User
|
from models import Chat, Notification, User
|
||||||
from database import SessionLocal, engine
|
from database import SessionLocal, engine
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from langchain_openai import AzureChatOpenAI
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
# Dependency
|
||||||
|
def get_db():
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
yield db
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
class UserCreate(BaseModel):
|
||||||
|
username: str
|
||||||
|
password: str
|
||||||
|
apisecretkey: str
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -271,54 +291,85 @@ def get_doc_description(data: UserQuery, response_model=DescriptionResponse):
|
||||||
|
|
||||||
# SAVE DOCS TO GRAPH DB
|
# SAVE DOCS TO GRAPH DB
|
||||||
@app.post("/kb/")
|
@app.post("/kb/")
|
||||||
def populate_graph(apires: RetrivedDocList):
|
def populate_graph(apires: RetrivedDocList, db: Session = Depends(get_db)):
|
||||||
if(apires.apisecretkey != API_SECRET_KEY):
|
|
||||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
|
||||||
|
|
||||||
print("STARTED")
|
try:
|
||||||
# print(apires)
|
payload = jwt.decode(apires.token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
graph = Neo4jGraph(url=apires.neourl, username=apires.neouser, password=apires.neopass)
|
username: str = payload.get("sub")
|
||||||
|
if username is None:
|
||||||
llm = ChatOpenAI(
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||||
model="gpt-4o-mini",
|
|
||||||
temperature=0,
|
|
||||||
max_tokens=None,
|
print("STARTED")
|
||||||
timeout=None,
|
# print(apires)
|
||||||
api_key=apires.openaikey
|
graph = Neo4jGraph(url=apires.neourl, username=apires.neouser, password=apires.neopass)
|
||||||
)
|
|
||||||
|
llm = ChatOpenAI(
|
||||||
embeddings = OpenAIEmbeddings(
|
model="gpt-4o-mini",
|
||||||
model="text-embedding-ada-002",
|
temperature=0,
|
||||||
api_key=apires.openaikey,
|
max_tokens=None,
|
||||||
)
|
timeout=None,
|
||||||
|
api_key=apires.openaikey
|
||||||
llm_transformer = LLMGraphTransformer(llm=llm)
|
)
|
||||||
|
|
||||||
|
embeddings = OpenAIEmbeddings(
|
||||||
|
model="text-embedding-ada-002",
|
||||||
|
api_key=apires.openaikey,
|
||||||
|
)
|
||||||
|
|
||||||
|
llm_transformer = LLMGraphTransformer(llm=llm)
|
||||||
|
|
||||||
raw_documents = []
|
raw_documents = []
|
||||||
|
|
||||||
for doc in apires.documents:
|
for doc in apires.documents:
|
||||||
raw_documents.append(Document(page_content=doc.pageContent, metadata=doc.metadata))
|
raw_documents.append(Document(page_content=doc.pageContent, metadata=doc.metadata))
|
||||||
|
|
||||||
text_splitter = SemanticChunker(embeddings=embeddings)
|
text_splitter = SemanticChunker(embeddings=embeddings)
|
||||||
|
|
||||||
documents = text_splitter.split_documents(raw_documents)
|
documents = text_splitter.split_documents(raw_documents)
|
||||||
graph_documents = llm_transformer.convert_to_graph_documents(documents)
|
graph_documents = llm_transformer.convert_to_graph_documents(documents)
|
||||||
|
|
||||||
|
|
||||||
graph.add_graph_documents(
|
graph.add_graph_documents(
|
||||||
graph_documents,
|
graph_documents,
|
||||||
baseEntityLabel=True,
|
baseEntityLabel=True,
|
||||||
include_source=True
|
include_source=True
|
||||||
)
|
)
|
||||||
|
|
||||||
print("FINISHED")
|
|
||||||
|
structured_llm = llm.with_structured_output(Notifications)
|
||||||
return {
|
notifs_extraction_chain = NOTIFICATION_GENERATION_PROMT | structured_llm
|
||||||
"success": "Graph Will be populated Shortly"
|
|
||||||
}
|
notifications = notifs_extraction_chain.invoke({"documents": raw_documents})
|
||||||
|
|
||||||
|
notifsdb = []
|
||||||
|
|
||||||
|
for text in notifications.notifications:
|
||||||
|
notifsdb.append(Notification(text=text))
|
||||||
|
|
||||||
|
user = db.query(User).filter(User.username == username).first()
|
||||||
|
user.notifications.extend(notifsdb)
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
print("FINISHED")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": "Graph Will be populated Shortly"
|
||||||
|
}
|
||||||
|
|
||||||
|
except JWTError:
|
||||||
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||||
|
|
||||||
|
|
||||||
|
#Fuction to populate db ( Comment out when running on server )
|
||||||
|
@app.post("/add/")
|
||||||
|
def add_user(user: UserCreate, db: Session = Depends(get_db)):
|
||||||
|
db_user = User(username=user.username, hashed_password=user.password, graph_config="", llm_config="")
|
||||||
|
db.add(db_user)
|
||||||
|
db.commit()
|
||||||
|
return "Success"
|
||||||
|
|
||||||
|
|
||||||
#AUTH CODE
|
#AUTH CODE
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||||
|
@ -337,28 +388,14 @@ app.add_middleware(
|
||||||
allow_headers=["*"], # Allows all headers
|
allow_headers=["*"], # Allows all headers
|
||||||
)
|
)
|
||||||
|
|
||||||
# Dependency
|
|
||||||
def get_db():
|
|
||||||
db = SessionLocal()
|
|
||||||
try:
|
|
||||||
yield db
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||||
|
|
||||||
|
|
||||||
class UserCreate(BaseModel):
|
|
||||||
username: str
|
|
||||||
password: str
|
|
||||||
apisecretkey: str
|
|
||||||
|
|
||||||
def get_user_by_username(db: Session, username: str):
|
def get_user_by_username(db: Session, username: str):
|
||||||
return db.query(User).filter(User.username == username).first()
|
return db.query(User).filter(User.username == username).first()
|
||||||
|
|
||||||
def create_user(db: Session, user: UserCreate):
|
def create_user(db: Session, user: UserCreate):
|
||||||
hashed_password = pwd_context.hash(user.password)
|
hashed_password = pwd_context.hash(user.password)
|
||||||
db_user = User(username=user.username, hashed_password=hashed_password)
|
db_user = User(username=user.username, hashed_password=hashed_password, graph_config="", llm_config="")
|
||||||
db.add(db_user)
|
db.add(db_user)
|
||||||
db.commit()
|
db.commit()
|
||||||
return "complete"
|
return "complete"
|
||||||
|
@ -410,7 +447,6 @@ def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db:
|
||||||
)
|
)
|
||||||
return {"access_token": access_token, "token_type": "bearer"}
|
return {"access_token": access_token, "token_type": "bearer"}
|
||||||
|
|
||||||
|
|
||||||
def verify_token(token: str = Depends(oauth2_scheme)):
|
def verify_token(token: str = Depends(oauth2_scheme)):
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
|
@ -427,6 +463,103 @@ async def verify_user_token(token: str):
|
||||||
return {"message": "Token is valid"}
|
return {"message": "Token is valid"}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/user/chat/save")
|
||||||
|
def populate_user_chat(chat: NewUserChat, db: Session = Depends(get_db)):
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(chat.token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
|
username: str = payload.get("sub")
|
||||||
|
if username is None:
|
||||||
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||||
|
|
||||||
|
user = db.query(User).filter(User.username == username).first()
|
||||||
|
newchat = Chat(type=chat.type, title=chat.title, chats_list=chat.chats_list)
|
||||||
|
|
||||||
|
user.chats.append(newchat)
|
||||||
|
db.commit()
|
||||||
|
return {
|
||||||
|
"message": "Chat Saved"
|
||||||
|
}
|
||||||
|
except JWTError:
|
||||||
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||||
|
|
||||||
|
@app.post("/user/chat/update")
|
||||||
|
def populate_user_chat(chat: ChatToUpdate, db: Session = Depends(get_db)):
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(chat.token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
|
username: str = payload.get("sub")
|
||||||
|
if username is None:
|
||||||
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||||
|
|
||||||
|
chatindb = db.query(Chat).filter(Chat.id == chat.chatid).first()
|
||||||
|
chatindb.chats_list = chat.chats_list
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
return {
|
||||||
|
"message": "Chat Updated"
|
||||||
|
}
|
||||||
|
except JWTError:
|
||||||
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||||
|
|
||||||
|
@app.get("/user/chat/delete/{token}/{chatid}")
|
||||||
|
async def delete_chat_of_user(token: str, chatid: str, db: Session = Depends(get_db)):
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
|
username: str = payload.get("sub")
|
||||||
|
if username is None:
|
||||||
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||||
|
|
||||||
|
chatindb = db.query(Chat).filter(Chat.id == chatid).first()
|
||||||
|
db.delete(chatindb)
|
||||||
|
db.commit()
|
||||||
|
return {
|
||||||
|
"message": "Chat Deleted"
|
||||||
|
}
|
||||||
|
except JWTError:
|
||||||
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||||
|
|
||||||
|
#Gets user id & name
|
||||||
|
@app.get("/user/{token}")
|
||||||
|
async def get_user_with_token(token: str, db: Session = Depends(get_db)):
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
|
username: str = payload.get("sub")
|
||||||
|
if username is None:
|
||||||
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||||
|
|
||||||
|
user = db.query(User).filter(User.username == username).first()
|
||||||
|
return {
|
||||||
|
"userid": user.id,
|
||||||
|
"username": user.username,
|
||||||
|
"chats": user.chats,
|
||||||
|
"notifications": user.notifications
|
||||||
|
}
|
||||||
|
except JWTError:
|
||||||
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/user/notification/delete/{token}/{notificationid}")
|
||||||
|
async def delete_chat_of_user(token: str, notificationid: str, db: Session = Depends(get_db)):
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
|
username: str = payload.get("sub")
|
||||||
|
if username is None:
|
||||||
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||||
|
|
||||||
|
notificationindb = db.query(Notification).filter(Notification.id == notificationid).first()
|
||||||
|
db.delete(notificationindb)
|
||||||
|
db.commit()
|
||||||
|
return {
|
||||||
|
"message": "Notification Deleted"
|
||||||
|
}
|
||||||
|
except JWTError:
|
||||||
|
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
|
|
||||||
"name": "SurfSense",
|
"name": "SurfSense",
|
||||||
"description": "Extension to collect Browsing History for SurfSense.",
|
"description": "Extension to collect Browsing History for SurfSense.",
|
||||||
"version": "0.0.2",
|
"version": "0.0.3",
|
||||||
|
|
||||||
"action": {
|
"action": {
|
||||||
"default_popup": "popup.html"
|
"default_popup": "popup.html"
|
||||||
|
|
|
@ -216,7 +216,7 @@ export const Popup = () => {
|
||||||
neouser: localStorage.getItem('neouser'),
|
neouser: localStorage.getItem('neouser'),
|
||||||
neopass: localStorage.getItem('neopass'),
|
neopass: localStorage.getItem('neopass'),
|
||||||
openaikey: localStorage.getItem('openaikey'),
|
openaikey: localStorage.getItem('openaikey'),
|
||||||
apisecretkey: API_SECRET_KEY
|
token: localStorage.getItem('token')
|
||||||
}
|
}
|
||||||
|
|
||||||
// console.log("toSend",toSend)
|
// console.log("toSend",toSend)
|
||||||
|
|
Loading…
Add table
Reference in a new issue