feat: SurfSense v0.0.6 init

Author: DESKTOP-RTLN3BA\$punk   Date: 2025-03-14 18:53:14 -07:00
parent 18fc19e8d9
commit da23012970
58 changed files with 8284 additions and 2076 deletions

.gitignore vendored (13 changed lines)

@@ -8,18 +8,7 @@ env.bak/
venv.bak/
data/
.data
__pycache__
__pycache__/
.__pycache__
+surfsense_backend/.env
-backend/examples
-backend/old
-backend/RAGAgent
-backend/testfiles
-backend/.env
-vectorstores/*
-vectorstores/
-.vectorstores

.gitmodules vendored (6 changed lines)

@@ -1,6 +1,6 @@
-[submodule "SurfSense-Frontend"]
-path = SurfSense-Frontend
-url = https://github.com/MODSetter/SurfSense-Frontend.git
+[submodule "surfsense_frontend"]
+path = surfsense_frontend
+url = https://github.com/MODSetter/surfsense_frontend.git
[submodule "ss-cross-browser-extension"]
path = ss-cross-browser-extension
url = https://github.com/MODSetter/ss-cross-browser-extension.git

SurfSense-Frontend
@@ -1 +0,0 @@
Subproject commit 53211d0b590ff5c7aaf721fbf0a39c21d7f0b823


@@ -1,35 +0,0 @@
# Your Unstructured.io API key. Use any random value if you are running Unstructured locally or don't want to upload files
UNSTRUCTURED_API_KEY = ""
#POSTGRES DB TO TRACK USERS
POSTGRES_DATABASE_URL = "postgresql+psycopg2://postgres:postgres@localhost:5432/surfsense"
# API SECRET KEY USED TO VERIFY REGISTRATION REQUESTS
API_SECRET_KEY = "surfsense"
# Your JWT secret and algorithm
SECRET_KEY = "your_secret_key"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = "720"
# SEARCH ENGINES TO USE FOR WEB SEARCH
TAVILY_API_KEY=""
# UNCOMMENT THE RESPECTIVE LINES BELOW FOR A LOCAL/OPENAI SETUP
# For OpenAI LLM SETUP
OPENAI_API_KEY="sk-proj-GHG....."
FAST_LLM="openai:gpt-4o-mini"
SMART_LLM="openai:gpt-4o-mini"
EMBEDDING="openai:text-embedding-3-large"
# For Local Setups
# OPENAI_API_KEY="123"
# OLLAMA_BASE_URL="http://localhost:11434"
# FAST_LLM="ollama:qwen2.5:7b"
# SMART_LLM="ollama:qwen2.5:7b"
# EMBEDDING="ollama:qwen2.5:7b"
# TEMPERATURE="0"
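# For orientation (illustrative sketch, not part of this file): the backend code below consumes these
# values roughly as follows, using python-dotenv and the "provider:model" convention shown above:
#   load_dotenv(); smart_llm = os.environ.get("SMART_LLM")
#   IS_LOCAL_SETUP = smart_llm.startswith("ollama")    # an "ollama:" prefix selects the local setup
#   MODEL_NAME = smart_llm.split(":", 1)[1]            # split once, so "ollama:qwen2.5:7b" yields "qwen2.5:7b"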

backend/.gitignore vendored (12 changed lines)

@@ -1,12 +0,0 @@
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
__pycache__
__pycache__/
.__pycache__

backend/Dockerfile
@@ -1,19 +0,0 @@
# Use official Python image
FROM python:3.12-slim
# Set working directory
WORKDIR /app
# Copy backend source code
COPY . /app
# Install dependencies
RUN apt-get update && apt-get install -y gcc libpq-dev
RUN pip install --upgrade pip
RUN pip install -r requirements.txt
# Expose backend port
EXPOSE 8000
# Run FastAPI using uvicorn
CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8000"]

backend/HIndices.py
@@ -1,501 +0,0 @@
from datetime import datetime
import json
from typing import List
from gpt_researcher import GPTResearcher
from langchain_chroma import Chroma
from langchain_ollama import OllamaLLM, OllamaEmbeddings
from langchain_community.vectorstores.utils import filter_complex_metadata
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain.docstore.document import Document
from langchain_experimental.text_splitter import SemanticChunker
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import FlashrankRerank
import numpy as np
from sqlalchemy.orm import Session
from fastapi import Depends, WebSocket
from prompts import report_prompt
import os
from dotenv import load_dotenv
from Utils.stringify import stringify
from pydmodels import AIAnswer, Reference
from database import SessionLocal
from models import Documents
load_dotenv()
SMART_LLM = os.environ.get("SMART_LLM")
EMBEDDING = os.environ.get("EMBEDDING")
IS_LOCAL_SETUP = SMART_LLM.startswith("ollama")
def extract_model_name(model_string: str) -> str:
part1, part2 = model_string.split(":", 1) # Split into two parts at the first colon
return part2
MODEL_NAME = extract_model_name(SMART_LLM)
EMBEDDING_MODEL = extract_model_name(EMBEDDING)
# Dependency
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
class NumpyEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.ndarray):
return obj.tolist()
return super().default(obj)
class ConnectionManager:
def __init__(self):
self.active_connections: List[WebSocket] = []
async def connect(self, websocket: WebSocket):
await websocket.accept()
self.active_connections.append(websocket)
def disconnect(self, websocket: WebSocket):
self.active_connections.remove(websocket)
async def send_personal_message(self, message: str, websocket: WebSocket):
await websocket.send_text(message)
class HIndices:
def __init__(self, username, api_key='local'):
"""
"""
self.username = username
if(IS_LOCAL_SETUP == True):
self.llm = OllamaLLM(model=MODEL_NAME,temperature=0)
self.embeddings = OllamaEmbeddings(model=EMBEDDING_MODEL)
else:
self.llm = ChatOpenAI(temperature=0, model_name=MODEL_NAME, api_key=api_key)
self.embeddings = OpenAIEmbeddings(api_key=api_key,model=EMBEDDING_MODEL)
self.summary_store = Chroma(
collection_name="summary_store",
embedding_function=self.embeddings,
persist_directory="./vectorstores/" + username + "/summary_store_db", # Where to save data locally
)
self.detailed_store = Chroma(
collection_name="detailed_store",
embedding_function=self.embeddings,
persist_directory="./vectorstores/" + username + "/detailed_store_db", # Where to save data locally
)
# self.summary_store_size = len(self.summary_store.get()['documents'])
# self.detailed_store_size = len(self.detailed_store.get()['documents'])
def summarize_file_doc(self, page_no, doc, search_space):
# Build the summarization chain (report prompt piped into the LLM)
report_chain = report_prompt | self.llm
if(IS_LOCAL_SETUP == True):
response = report_chain.invoke({"document": doc})
metadict = {
"page": page_no,
"summary": True,
"search_space": search_space,
}
metadict.update(doc.metadata)
# metadict['languages'] = metadict['languages'][0]
return Document(
page_content=response,
metadata=metadict
)
else:
response = report_chain.invoke({"document": doc})
metadict = {
"page": page_no,
"summary": True,
"search_space": search_space,
}
metadict.update(doc.metadata)
# metadict['languages'] = metadict['languages'][0]
return Document(
page_content=response.content,
metadata=metadict
)
def summarize_webpage_doc(self, page_no, doc, search_space):
# Build the summarization chain (report prompt piped into the LLM)
report_chain = report_prompt | self.llm
if(IS_LOCAL_SETUP == True):
response = report_chain.invoke({"document": doc})
return Document(
page_content=response,
metadata={
"filetype": 'WEBPAGE',
"page": page_no,
"summary": True,
"search_space": search_space,
"BrowsingSessionId": doc.metadata['BrowsingSessionId'],
"VisitedWebPageURL": doc.metadata['VisitedWebPageURL'],
"VisitedWebPageTitle": doc.metadata['VisitedWebPageTitle'],
"VisitedWebPageDateWithTimeInISOString": doc.metadata['VisitedWebPageDateWithTimeInISOString'],
"VisitedWebPageReffererURL": doc.metadata['VisitedWebPageReffererURL'],
"VisitedWebPageVisitDurationInMilliseconds": doc.metadata['VisitedWebPageVisitDurationInMilliseconds'],
}
)
else:
response = report_chain.invoke({"document": doc})
return Document(
page_content=response.content,
metadata={
"filetype": 'WEBPAGE',
"page": page_no,
"summary": True,
"search_space": search_space,
"BrowsingSessionId": doc.metadata['BrowsingSessionId'],
"VisitedWebPageURL": doc.metadata['VisitedWebPageURL'],
"VisitedWebPageTitle": doc.metadata['VisitedWebPageTitle'],
"VisitedWebPageDateWithTimeInISOString": doc.metadata['VisitedWebPageDateWithTimeInISOString'],
"VisitedWebPageReffererURL": doc.metadata['VisitedWebPageReffererURL'],
"VisitedWebPageVisitDurationInMilliseconds": doc.metadata['VisitedWebPageVisitDurationInMilliseconds'],
}
)
def encode_docs_hierarchical(self, documents, search_space_instance, files_type, db: Session = Depends(get_db)):
"""
Creates and Saves/Updates docs in hierarchical indices and postgres table
"""
page_no_offset = len(self.detailed_store.get()['documents'])
# Process documents
summaries = []
if(files_type=='WEBPAGE'):
batch_summaries = [self.summarize_webpage_doc(page_no = i + page_no_offset, doc=doc, search_space=search_space_instance.name) for i, doc in enumerate(documents)]
else:
batch_summaries = [self.summarize_file_doc(page_no = i + page_no_offset , doc=doc, search_space=search_space_instance.name) for i, doc in enumerate(documents)]
summaries.extend(batch_summaries)
detailed_chunks = []
for i, summary in enumerate(summaries):
# Add single summary in vector store
added_doc_id = self.summary_store.add_documents(filter_complex_metadata([summary]))
if(files_type=='WEBPAGE'):
new_pg_doc = Documents(
title=summary.metadata['VisitedWebPageTitle'],
document_metadata=stringify(summary.metadata),
page_content=documents[i].page_content,
file_type='WEBPAGE',
summary_vector_id=added_doc_id[0],
)
else:
new_pg_doc = Documents(
title=summary.metadata['filename'],
document_metadata=stringify(summary.metadata),
page_content=documents[i].page_content,
file_type=summary.metadata['filetype'],
summary_vector_id=added_doc_id[0],
)
# Store it in PG
search_space_instance.documents.append(new_pg_doc)
db.commit()
# Semantic chunking for better contextual compression
text_splitter = SemanticChunker(embeddings=self.embeddings)
chunks = text_splitter.split_documents([documents[i]])
# Update metadata for detailed chunks
for i, chunk in enumerate(chunks):
chunk.metadata.update({
"summary": False,
"page": summary.metadata['page'],
})
if(files_type == 'WEBPAGE'):
ieee_content = (
f"=======================================DOCUMENT METADATA==================================== \n"
f"Source: {chunk.metadata['VisitedWebPageURL']} \n"
f"Title: {chunk.metadata['VisitedWebPageTitle']} \n"
f"Visited Date and Time : {chunk.metadata['VisitedWebPageDateWithTimeInISOString']} \n"
f"============================DOCUMENT PAGE CONTENT CHUNK===================================== \n"
f"Page Content Chunk: \n\n{chunk.page_content}\n\n"
f"===================================================================================== \n"
)
else:
ieee_content = (
f"=======================================DOCUMENT METADATA==================================== \n"
f"Source: {chunk.metadata['filename']} \n"
f"Title: {chunk.metadata['filename']} \n"
f"Visited Date and Time : {datetime.now()} \n"
f"============================DOCUMENT PAGE CONTENT CHUNK===================================== \n"
f"Page Content Chunk: \n\n{chunk.page_content}\n\n"
f"===================================================================================== \n"
)
chunk.page_content = ieee_content
detailed_chunks.extend(chunks)
#update vector stores
self.detailed_store.add_documents(filter_complex_metadata(detailed_chunks))
return self.summary_store, self.detailed_store
def delete_vector_stores(self, summary_ids_to_delete: list[str], db: Session = Depends(get_db)):
self.summary_store.delete(ids=summary_ids_to_delete)
return "success"
def summary_vector_search(self,query, search_space='GENERAL'):
top_summaries_compressor = FlashrankRerank(top_n=20)
top_summaries_retreiver = ContextualCompressionRetriever(
base_compressor=top_summaries_compressor, base_retriever=self.summary_store.as_retriever(search_kwargs={'filter': {'search_space': search_space}})
)
return top_summaries_retreiver.invoke(query)
def deduplicate_references_and_update_answer(self, answer: str, references: List[Reference]) -> tuple[str, List[Reference]]:
"""
Deduplicates references and updates the answer text to maintain correct reference numbering.
Args:
answer: The text containing reference citations
references: List of Reference objects
Returns:
tuple: (updated_answer, deduplicated_references)
"""
# Track unique references and create ID mapping using a dictionary comprehension
unique_refs = {}
id_mapping = {
ref.id: unique_refs.setdefault(
ref.source, Reference(id=str(len(unique_refs) + 1), title=ref.title, source=ref.source)
).id
for ref in references
}
# Apply new mappings to the answer text
updated_answer = answer
for old_id, new_id in sorted(id_mapping.items(), key=lambda x: len(x[0]), reverse=True):
updated_answer = updated_answer.replace(f'[{old_id}]', f'[{new_id}]')
return updated_answer, list(unique_refs.values())
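# Worked example (illustrative values): two citations that share a source collapse into a single
# reference, and the in-text markers are renumbered to match:
#   answer     = "AI aids diagnostics [1] and imaging [2]."
#   references = [Reference(id="1", title="SurfSense", source="https://github.com/MODSetter/SurfSense"),
#                 Reference(id="2", title="SurfSense", source="https://github.com/MODSetter/SurfSense")]
#   deduplicate_references_and_update_answer(answer, references)
#   -> ("AI aids diagnostics [1] and imaging [1].", [<single Reference with id="1">])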
async def ws_get_vectorstore_report(self, query: str, report_type: str, report_source: str, documents: List[Document],websocket: WebSocket) -> str:
researcher = GPTResearcher(query=query, report_type=report_type, report_source=report_source, documents=documents, report_format="APA",websocket=websocket)
await researcher.conduct_research()
report = await researcher.write_report()
return report
async def ws_get_web_report(self, query: str, report_type: str, report_source: str, websocket: WebSocket) -> str:
researcher = GPTResearcher(query=query, report_type=report_type, report_source=report_source, report_format="APA",websocket=websocket)
await researcher.conduct_research()
report = await researcher.write_report()
return report
async def ws_experimental_search(self, websocket: WebSocket, manager: ConnectionManager , query, search_space='GENERAL', report_type = "custom_report", report_source = "langchain_documents"):
custom_prompt = """
Please answer the following user query using only the **Document Page Content** provided below, while citing sources exclusively from the **Document Metadata** section, in the format shown. **Do not add any external information.**
**USER QUERY:** """ + query + """
**Answer Requirements:**
- Provide a detailed long response using IEEE-style in-text citations (e.g., [1], [2]) based solely on the **Document Page Content**.
- Use **Document Metadata** only for citation details and format each reference exactly once, with no duplicates.
- Structure references in this format at the end of your response, using this format: (Access Date and Time). [Title or Filename](Source)
FOR EXAMPLE:
EXAMPLE User Query : Explain the impact of artificial intelligence on modern healthcare.
EXAMPLE Given Documents:
=======================================DOCUMENT METADATA==================================== \n"
Source: https://www.reddit.com/r/ChatGPT/comments/13na8yp/highly_effective_prompt_for_summarizing_gpt4/ \n
Title: Artificial intelligence\n
Visited Date and Time : 2024-10-23T22:44:03-07:00 \n
============================DOCUMENT PAGE CONTENT CHUNK===================================== \n
Page Content Chunk: \n\nArtificial intelligence (AI) has significantly transformed modern healthcare by enhancing diagnostic accuracy, personalizing patient care, and optimizing operational efficiency. AI algorithms can analyze vast datasets to identify patterns that may be missed by human practitioners, leading to improved diagnostic outcomes. \n\n
===================================================================================== \n
=======================================DOCUMENT METADATA==================================== \n"
Source: https://github.com/MODSetter/SurfSense \n
Title: MODSetter/SurfSense: Personal AI Assistant for Internet Surfers and Researchers. \n
Visited Date and Time : 2024-10-23T22:44:03-07:00 \n
============================DOCUMENT PAGE CONTENT CHUNK===================================== \n
Page Content Chunk: \n\nAI systems have been deployed in radiology to detect anomalies in medical imaging with high precision, reducing the risk of misdiagnosis and improving patient outcomes. Additionally, AI-powered chatbots and virtual assistants are being used to provide 24/7 support, answer queries, and offer personalized health advice\n\n
===================================================================================== \n
=======================================DOCUMENT METADATA==================================== \n"
Source: https://github.com/MODSetter/SurfSense \n
Title: MODSetter/SurfSense: Personal AI Assistant for Internet Surfers and Researchers. \n
Visited Date and Time : 2024-10-23T22:44:03-07:00 \n
============================DOCUMENT PAGE CONTENT CHUNK===================================== \n
Page Content Chunk: \n\nAI algorithms can analyze a patient's genetic information to predict their risk of certain diseases and recommend tailored treatment plans. \n\n
===================================================================================== \n
=======================================DOCUMENT METADATA==================================== \n"
Source: filename.pdf \n
============================DOCUMENT PAGE CONTENT CHUNK===================================== \n
Page Content Chunk: \n\nApart from diagnostics, AI-driven tools facilitate personalized treatment plans by considering individual patient data, thereby improving patient outcomes\n\n
===================================================================================== \n
Ensure your response is structured something like this:
**OUTPUT FORMAT:**
---
**Answer:**
Artificial intelligence (AI) has significantly transformed modern healthcare by enhancing diagnostic accuracy, personalizing patient care, and optimizing operational efficiency. AI algorithms can analyze vast datasets to identify patterns that may be missed by human practitioners, leading to improved diagnostic outcomes [1]. For instance, AI systems have been deployed in radiology to detect anomalies in medical imaging with high precision [2]. Moreover, AI-driven tools facilitate personalized treatment plans by considering individual patient data, thereby improving patient outcomes [3].
**References:**
1. (2024, October 23). [Artificial intelligence GPT-4 Optimized: r/ChatGPT](https://www.reddit.com/r/ChatGPT/comments/13na8yp/highly_effective_prompt_for_summarizing_gpt4)
2. (2024, October 23). [MODSetter/SurfSense: Personal AI Assistant for Internet Surfers and Researchers](https://github.com/MODSetter/SurfSense)
3. (2024, October 23). [filename.pdf](filename.pdf)
---
"""
structured_llm = self.llm.with_structured_output(AIAnswer)
if report_source == "web" :
if report_type == "custom_report" :
ret_report = await self.ws_get_web_report(query=custom_prompt, report_type=report_type, report_source="web", websocket=websocket)
else:
ret_report = await self.ws_get_web_report(
query=query,
report_type=report_type,
report_source="web",
websocket=websocket
)
await manager.send_personal_message(
json.dumps({"type": "stream", "content": "Converting to IEEE format..."}),
websocket
)
ret_report = self.llm.invoke("I have a report written in APA format. Please convert it to IEEE format, ensuring that all citations, references, headings, and overall formatting adhere to the IEEE style guidelines. Maintain the original content and structure while applying the correct IEEE formatting rules. Just return the converted report, that's it. NOW MY REPORT : " + ret_report).content
for chuck in structured_llm.stream(
"Please extract and separate the references from the main text. "
"References are formatted as follows:"
"[Reference Id]. (Access Date and Time). [Title or Filename](Source or URL). "
"Provide the text and references as distinct outputs. "
"IMPORTANT : Never hallucinate the references. If there is no reference just return nothing in the reference field."
"Here is the content to process: \n\n\n" + ret_report):
# ans, sources = self.deduplicate_references_and_update_answer(answer=chuck.answer, references=chuck.references)
await manager.send_personal_message(
json.dumps({"type": "stream", "sources": [source.model_dump() for source in chuck.references]}),
websocket
)
await manager.send_personal_message(
json.dumps({"type": "stream", "content": ret_report}),
websocket
)
return
contextdocs = []
top_summaries_compressor = FlashrankRerank(top_n=5)
details_compressor = FlashrankRerank(top_n=50)
top_summaries_retreiver = ContextualCompressionRetriever(
base_compressor=top_summaries_compressor, base_retriever=self.summary_store.as_retriever(search_kwargs={'filter': {'search_space': search_space}})#
)
top_summaries_compressed_docs = top_summaries_retreiver.invoke(query)
rel_docs = filter_complex_metadata(top_summaries_compressed_docs)
await manager.send_personal_message(
json.dumps({"type": "stream", "relateddocs": [relateddoc.model_dump() for relateddoc in rel_docs]}, cls=NumpyEncoder),
websocket
)
for summary in top_summaries_compressed_docs:
# For each summary, retrieve relevant detailed chunks
page_number = summary.metadata["page"]
detailed_compression_retriever = ContextualCompressionRetriever(
base_compressor=details_compressor, base_retriever=self.detailed_store.as_retriever(search_kwargs={'filter': {'page': page_number}})
)
detailed_compressed_docs = detailed_compression_retriever.invoke(
query
)
contextdocs.extend(detailed_compressed_docs)
# local_report = asyncio.run(self.get_vectorstore_report(query=custom_prompt, report_type=report_type, report_source=report_source, documents=contextdocs))
if report_source == "langchain_documents" :
if report_type == "custom_report" :
ret_report = await self.ws_get_vectorstore_report(query=custom_prompt, report_type=report_type, report_source=report_source, documents=contextdocs, websocket=websocket)
else:
ret_report = await self.ws_get_vectorstore_report(query=query, report_type=report_type, report_source=report_source, documents=contextdocs, websocket=websocket)
await manager.send_personal_message(
json.dumps({"type": "stream", "content": "Converting to IEEE format..."}),
websocket
)
ret_report = self.llm.invoke("I have a report written in APA format. Please convert it to IEEE format, ensuring that all citations, references, headings, and overall formatting adhere to the IEEE style guidelines. Maintain the original content and structure while applying the correct IEEE formatting rules. Just return the converted report, that's it. NOW MY REPORT : " + ret_report).content
for chuck in structured_llm.stream(
"Please extract and separate the references from the main text. "
"References are formatted as follows:"
"[Reference Id]. (Access Date and Time). [Title or Filename](Source or URL). "
"Provide the text and references as distinct outputs. "
"Ensure that in-text citation numbers such as [1], [2], (1), (2), etc., as well as in-text links or in-text citation links within the content, remain unaltered and are accurately extracted."
"IMPORTANT : Never hallucinate the references. If there is no reference just return nothing in the reference field."
"Here is the content to process: \n\n\n" + ret_report):
ans, sources = self.deduplicate_references_and_update_answer(answer=chuck.answer, references=chuck.references)
await manager.send_personal_message(
json.dumps({"type": "stream", "sources": [source.model_dump() for source in sources]}),
websocket
)
await manager.send_personal_message(
json.dumps({"type": "stream", "content": ans}),
websocket
)
return

backend/Utils/stringify.py
@@ -1,27 +0,0 @@
from collections import OrderedDict
try:
from collections.abc import Mapping, Sequence
except ImportError:
from collections import Mapping, Sequence
import json
COMPACT_SEPARATORS = (',', ':')
def order_by_key(kv):
key, val = kv
return key
def recursive_order(node):
if isinstance(node, Mapping):
ordered_mapping = OrderedDict(sorted(node.items(), key=order_by_key))
for key, value in ordered_mapping.items():
ordered_mapping[key] = recursive_order(value)
return ordered_mapping
elif isinstance(node, Sequence) and not isinstance(node, (str, bytes)):
return [recursive_order(item) for item in node]
return node
def stringify(node):
return json.dumps(recursive_order(node), separators=COMPACT_SEPARATORS)
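# Quick usage sketch (illustrative input): keys are ordered recursively and separators are compact,
# so equal metadata dicts always stringify to the same string:
#   stringify({"b": 1, "a": {"d": [3, 2], "c": None}})
#   -> '{"a":{"c":null,"d":[3,2]},"b":1}'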

backend/database.py
@@ -1,16 +0,0 @@
from sqlalchemy import create_engine, Column, Integer, String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
import os
from dotenv import load_dotenv
load_dotenv()
POSTGRES_DATABASE_URL = os.environ.get("POSTGRES_DATABASE_URL")
engine = create_engine(
POSTGRES_DATABASE_URL
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()

backend/models.py
@@ -1,80 +0,0 @@
from datetime import datetime
# from typing import List
from database import Base, engine
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Boolean, create_engine
from sqlalchemy.orm import relationship
class BaseModel(Base):
__abstract__ = True
__allow_unmapped__ = True
id = Column(Integer, primary_key=True, index=True)
class Chat(BaseModel):
__tablename__ = "chats"
type = Column(String)
title = Column(String)
chats_list = Column(String)
search_space_id = Column(Integer, ForeignKey('searchspaces.id'))
search_space = relationship('SearchSpace', back_populates='chats')
class Documents(BaseModel):
__tablename__ = "documents"
title = Column(String)
created_at = Column(DateTime, default=datetime.now)
file_type = Column(String)
document_metadata = Column(String)
page_content = Column(String)
summary_vector_id = Column(String)
search_space_id = Column(Integer, ForeignKey("searchspaces.id"))
search_space = relationship("SearchSpace", back_populates="documents")
class Podcast(BaseModel):
__tablename__ = "podcasts"
title = Column(String)
created_at = Column(DateTime, default=datetime.now)
is_generated = Column(Boolean, default=False)
podcast_content = Column(String, default="")
file_location = Column(String, default="")
search_space_id = Column(Integer, ForeignKey("searchspaces.id"))
search_space = relationship("SearchSpace", back_populates="podcasts")
class SearchSpace(BaseModel):
__tablename__ = "searchspaces"
name = Column(String, index=True)
description = Column(String)
created_at = Column(DateTime, default=datetime.now)
user_id = Column(Integer, ForeignKey("users.id"))
user = relationship("User", back_populates="search_spaces")
documents = relationship("Documents", back_populates="search_space", order_by="Documents.id")
podcasts = relationship("Podcast", back_populates="search_space", order_by="Podcast.id")
chats = relationship('Chat', back_populates='search_space', order_by='Chat.id')
class User(BaseModel):
__tablename__ = "users"
username = Column(String, unique=True, index=True)
hashed_password = Column(String)
search_spaces = relationship("SearchSpace", back_populates="user")
# Create the database tables if they don't exist
User.metadata.create_all(bind=engine)

backend/prompts.py
@@ -1,70 +0,0 @@
# TODO: move the newer prompts here after some more testing
from langchain_core.prompts.prompt import PromptTemplate
from datetime import datetime, timezone
DATE_TODAY = "Today's date is " + datetime.now(timezone.utc).astimezone().isoformat() + '\n'
report_template = """
You are an eagle-eyed researcher, skilled at summarizing lengthy documents with precision and clarity. Create a comprehensive summary of the provided document, capturing the main ideas, key details, and essential arguments presented.
Length and depth:
- Produce a detailed summary that captures all essential content from the document. Adjust the length as needed to ensure no critical information is omitted.
Structure:
- Organize the summary logically.
- Use clear headings and subheadings for different sections or themes, to help convey the flow of ideas.
Content to Include:
- Highlight the main arguments.
- Identify and include key supporting details.
- Incorporate relevant examples or data that strengthen the key points.
Tone:
- Use an objective, neutral tone, delivering precise and insightful analysis without personal opinions or interpretations.
# Steps
1. **Thoroughly read the entire text** to grasp the author's perspective, main arguments, and overall structure.
2. **Identify key sections or themes** and thematically group the information.
3. **Extract main points** from each section to capture supporting details and relevant examples.
4. **Use headings/subheadings** to provide a clear and logically organized structure.
5. **Write a conclusion** that succinctly encapsulates the overarching message and significance of the text.
# Output Format
- Provide a summary in well-structured paragraphs.
- Clearly delineate different themes or sections with suitable headings or sub-headings.
- Adjust the length of the summary based on the content's complexity and depth.
- Conclusions should be clearly marked.
# Example
**Heading 1: Introduction to Main Theme**
The document begins by discussing [main idea], outlining [initial point] with supporting data like [example].
**Heading 2: Supporting Arguments**
The text then presents several supporting arguments, such as [supporting detail]. Notably, [data or statistic] is used to reinforce the main concept.
**Heading 3: Conclusion**
In summary, [document's conclusion statement], highlighting the broader implications like [significance].
(This is an example format; each section should be expanded comprehensively based on the provided document.)
# Notes
- Ensure the summary is adequately comprehensive without omitting crucial parts.
- Aim for readability by using formal yet accessible language, maintaining depth without unnecessary complexity.
Now, Please summarize the following document:
<document>
{document}
</document>
"""
report_prompt = PromptTemplate(
input_variables=["document"],
template=report_template
)
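# Illustrative usage, mirroring how HIndices.py consumes this prompt:
#   report_chain = report_prompt | llm          # llm = OllamaLLM(...) or ChatOpenAI(...)
#   summary_text = report_chain.invoke({"document": page_text})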

backend/pydmodels.py
@@ -1,95 +0,0 @@
from pydantic import BaseModel, Field
from typing import List, Optional
class UserCreate(BaseModel):
username: str
password: str
apisecretkey: str
class DocMeta(BaseModel):
BrowsingSessionId: Optional[str] = Field(default=None, description="BrowsingSessionId of Document")
VisitedWebPageURL: Optional[str] = Field(default=None, description="VisitedWebPageURL of Document")
VisitedWebPageTitle: Optional[str] = Field(default=None, description="VisitedWebPageTitle of Document")
VisitedWebPageDateWithTimeInISOString: Optional[str] = Field(default=None, description="VisitedWebPageDateWithTimeInISOString of Document")
VisitedWebPageReffererURL: Optional[str] = Field(default=None, description="VisitedWebPageReffererURL of Document")
VisitedWebPageVisitDurationInMilliseconds: Optional[int] = Field(default=None, description="VisitedWebPageVisitDurationInMilliseconds of Document")
class CreatePodcast(BaseModel):
token: str
search_space_id: int
title: str
wordcount: int
podcast_content: str
class CreateStorageSpace(BaseModel):
name: str
description: str
token : str
class Reference(BaseModel):
id: str = Field(..., description="reference no")
title: str = Field(..., description="reference title.")
source: str = Field(..., description="reference Source or URL. Prefer URL only include file names if no URL available.")
class AIAnswer(BaseModel):
answer: str = Field(..., description="The provided answer, excluding references, but including in-text citation numbers such as [1], [2], (1), (2), etc.")
references: List[Reference] = Field(..., description="References")
class DocWithContent(BaseModel):
DocMetadata: Optional[str] = Field(default=None, description="Document Metadata")
Content: Optional[str] = Field(default=None, description="Document Page Content")
class DocumentsToDelete(BaseModel):
ids_to_delete: List[str]
token: str
class UserQuery(BaseModel):
query: str
search_space: str
token: str
class MainUserQuery(BaseModel):
query: str
search_space: str
token: str
class ChatHistory(BaseModel):
type: str
content: str | List[DocMeta] | List[str]
class UserQueryWithChatHistory(BaseModel):
chat: List[ChatHistory]
query: str
token: str
class DescriptionResponse(BaseModel):
response: str
class RetrivedDocListItem(BaseModel):
metadata: DocMeta
pageContent: str
class RetrivedDocList(BaseModel):
documents: List[RetrivedDocListItem]
search_space_id: int
token: str
class UserQueryResponse(BaseModel):
response: str
relateddocs: List[DocWithContent]
class NewUserQueryResponse(BaseModel):
response: str
sources: List[Reference]
relateddocs: List[DocWithContent]
class NewUserChat(BaseModel):
token: str
type: str
title: str
chats_list: str
class ChatToUpdate(BaseModel):
chatid: str
token: str
chats_list: str
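# Illustrative only: the shape of the structured output that HIndices.py expects back from
# structured_llm.stream(...) when it binds these models:
#   AIAnswer(
#       answer="AI improves diagnostic accuracy [1].",
#       references=[Reference(id="1", title="MODSetter/SurfSense", source="https://github.com/MODSetter/SurfSense")],
#   )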

backend/requirements.txt
@@ -1,275 +0,0 @@
aiofiles==24.1.0
aiohappyeyeballs==2.4.3
aiohttp==3.10.10
aiosignal==1.3.1
alabaster==1.0.0
annotated-types==0.7.0
anyio==4.6.2.post1
arxiv==2.1.3
asgiref==3.8.1
attrs==24.2.0
babel==2.16.0
backoff==2.2.1
bcrypt==4.2.0
beautifulsoup4==4.12.3
bleach==6.2.0
Brotli==1.1.0
build==1.2.2.post1
cachetools==5.5.0
certifi==2024.8.30
cffi==1.17.1
chardet==5.2.0
charset-normalizer==3.4.0
chroma-hnswlib==0.7.6
chromadb==0.5.18
click==8.1.7
colorama==0.4.6
coloredlogs==15.0.1
cryptography==43.0.3
cssselect2==0.7.0
Cython==3.0.11
dataclasses-json==0.6.7
deepdiff==8.0.1
defusedxml==0.7.1
Deprecated==1.2.14
distro==1.9.0
docopt==0.6.2
docstring_parser==0.16
docutils==0.21.2
durationpy==0.9
ecdsa==0.19.0
edge-tts==6.1.18
elevenlabs==1.12.1
emoji==2.14.0
execnet==2.1.1
fastapi==0.115.5
fastjsonschema==2.20.0
feedparser==6.0.11
ffmpeg==1.4
filelock==3.16.1
filetype==1.2.0
FlashRank==0.2.9
flatbuffers==24.3.25
fonttools==4.54.1
frozenlist==1.5.0
fsspec==2024.10.0
fuzzywuzzy==0.18.0
google-ai-generativelanguage==0.6.10
google-api-core==2.23.0
google-api-python-client==2.152.0
google-auth==2.36.0
google-auth-httplib2==0.2.0
google-cloud-aiplatform==1.72.0
google-cloud-bigquery==3.27.0
google-cloud-core==2.4.1
google-cloud-resource-manager==1.13.0
google-cloud-storage==2.18.2
google-cloud-texttospeech==2.21.0
google-crc32c==1.6.0
google-generativeai==0.8.3
google-resumable-media==2.7.2
googleapis-common-protos==1.66.0
gpt-researcher==0.10.3
greenlet==3.1.1
grpc-google-iam-v1==0.13.1
grpcio==1.67.1
grpcio-status==1.67.1
h11==0.14.0
html5lib==1.1
htmldocx==0.0.6
httpcore==1.0.6
httplib2==0.22.0
httptools==0.6.4
httpx==0.27.2
httpx-sse==0.4.0
huggingface-hub==0.26.2
humanfriendly==10.0
idna==3.10
imagesize==1.4.1
importlib_metadata==8.5.0
importlib_resources==6.4.5
iniconfig==2.0.0
Jinja2==3.1.4
jiter==0.7.1
joblib==1.4.2
json5==0.9.28
json_repair==0.30.1
jsonpatch==1.33
jsonpath-python==1.0.6
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
jupyter_client==8.6.3
jupyter_core==5.7.2
jupyterlab_pygments==0.3.0
kubernetes==31.0.0
langchain==0.3.7
langchain-chroma==0.1.4
langchain-community==0.3.7
langchain-core==0.3.17
langchain-experimental==0.3.3
langchain-google-genai==2.0.4
langchain-google-vertexai==2.0.7
langchain-ollama==0.2.0
langchain-openai==0.2.8
langchain-text-splitters==0.3.2
langchain-unstructured==0.1.5
langdetect==1.0.9
langgraph==0.2.46
langgraph-checkpoint==2.0.3
langgraph-cli==0.1.53
langgraph-sdk==0.1.35
langsmith==0.1.142
Levenshtein==0.26.1
litellm==1.52.6
loguru==0.7.2
lxml==5.3.0
lxml_html_clean==0.3.1
Markdown==3.7
markdown-it-py==3.0.0
markdown2==2.5.1
MarkupSafe==3.0.2
marshmallow==3.23.1
md2pdf==1.0.1
mdurl==0.1.2
mistune==3.0.2
mmh3==5.0.1
monotonic==1.6
mpmath==1.3.0
msgpack==1.1.0
multidict==6.1.0
mypy-extensions==1.0.0
nbclient==0.10.0
nbconvert==7.16.4
nbformat==5.10.4
nbsphinx==0.9.5
nest-asyncio==1.6.0
nltk==3.9.1
numpy==1.26.4
oauthlib==3.2.2
olefile==0.47
ollama==0.3.3
onnxruntime==1.20.0
openai==1.54.4
opentelemetry-api==1.28.1
opentelemetry-exporter-otlp-proto-common==1.28.1
opentelemetry-exporter-otlp-proto-grpc==1.28.1
opentelemetry-instrumentation==0.49b1
opentelemetry-instrumentation-asgi==0.49b1
opentelemetry-instrumentation-fastapi==0.49b1
opentelemetry-proto==1.28.1
opentelemetry-sdk==1.28.1
opentelemetry-semantic-conventions==0.49b1
opentelemetry-util-http==0.49b1
orderly-set==5.2.2
orjson==3.10.11
overrides==7.7.0
packaging==24.2
pandas==2.2.3
pandoc==2.4
pandocfilters==1.5.1
passlib==1.7.4
pillow==10.4.0
platformdirs==4.3.6
pluggy==1.5.0
plumbum==1.9.0
ply==3.11
podcastfy==0.3.5
posthog==3.7.0
propcache==0.2.0
proto-plus==1.25.0
protobuf==5.28.3
psutil==6.1.0
psycopg2==2.9.10
pyasn1==0.6.1
pyasn1_modules==0.4.1
pycparser==2.22
pydantic==2.9.2
pydantic-settings==2.6.1
pydantic_core==2.23.4
pydub==0.25.1
pydyf==0.11.0
Pygments==2.18.0
PyMuPDF==1.24.13
pyparsing==3.2.0
pypdf==5.1.0
pyphen==0.15.0
PyPika==0.48.9
pyproject_hooks==1.2.0
pyreadline3==3.5.4
pytest==8.3.3
pytest-xdist==3.6.1
python-dateutil==2.8.2
python-docx==1.1.2
python-dotenv==1.0.1
python-iso639==2024.10.22
python-jose==3.3.0
python-Levenshtein==0.26.1
python-magic==0.4.27
python-multipart==0.0.17
python-oxmsg==0.0.1
pytz==2024.2
pywin32==308
PyYAML==6.0.2
pyzmq==26.2.0
RapidFuzz==3.10.1
referencing==0.35.1
regex==2024.11.6
requests==2.32.3
requests-oauthlib==2.0.0
requests-toolbelt==1.0.0
rich==13.9.4
rpds-py==0.21.0
rsa==4.9
setuptools==75.4.0
sgmllib3k==1.0.0
shapely==2.0.6
shellingham==1.5.4
six==1.16.0
sniffio==1.3.1
snowballstemmer==2.2.0
soupsieve==2.6
Sphinx==8.1.3
sphinx-autodoc-typehints==2.5.0
sphinx-rtd-theme==3.0.1
sphinxcontrib-applehelp==2.0.0
sphinxcontrib-devhelp==2.0.0
sphinxcontrib-htmlhelp==2.1.0
sphinxcontrib-jquery==4.1
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==2.0.0
sphinxcontrib-serializinghtml==2.0.0
SQLAlchemy==2.0.35
starlette==0.41.2
sympy==1.13.3
tenacity==9.0.0
tiktoken==0.8.0
tinycss2==1.4.0
tinyhtml5==2.0.0
tokenizers==0.20.3
tornado==6.4.1
tqdm==4.67.0
traitlets==5.14.3
typer==0.12.5
types-PyYAML==6.0.12.20240917
typing-inspect==0.9.0
typing_extensions==4.12.2
tzdata==2024.2
unstructured==0.16.5
unstructured-client==0.25.9
uritemplate==4.1.1
urllib3==2.2.3
uvicorn==0.32.0
watchfiles==0.24.0
weasyprint==63.0
webencodings==0.5.1
websocket-client==1.8.0
websockets==14.1
wheel==0.44.0
win32-setctime==1.1.0
wrapt==1.16.0
wsproto==1.2.0
yarl==1.17.1
youtube-transcript-api==0.6.2
zipp==3.21.0
zopfli==0.2.3.post1

backend/server.py
@@ -1,880 +0,0 @@
from __future__ import annotations
import json
from typing import List
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.documents import Document
from langchain_ollama import OllamaLLM
from langchain_openai import ChatOpenAI
from sqlalchemy import insert
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from langchain_unstructured import UnstructuredLoader
# OUR LIBS
from HIndices import ConnectionManager, HIndices
from Utils.stringify import stringify
from prompts import DATE_TODAY
from pydmodels import ChatToUpdate, CreatePodcast, CreateStorageSpace, DescriptionResponse, DocWithContent, DocumentsToDelete, MainUserQuery, NewUserChat, NewUserQueryResponse, UserCreate, UserQuery, RetrivedDocList, UserQueryResponse, UserQueryWithChatHistory
from podcastfy.client import generate_podcast
# Auth Libs
from fastapi import FastAPI, Depends, Form, HTTPException, Response, WebSocket, status, UploadFile, BackgroundTasks
from sqlalchemy.orm import Session
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from datetime import datetime, timedelta
from passlib.context import CryptContext
from models import Chat, Documents, Podcast, SearchSpace, User
from database import SessionLocal
from fastapi.middleware.cors import CORSMiddleware
import os
from dotenv import load_dotenv
load_dotenv()
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
SMART_LLM = os.environ.get("SMART_LLM")
IS_LOCAL_SETUP = SMART_LLM.startswith("ollama")
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.environ.get("ACCESS_TOKEN_EXPIRE_MINUTES"))
ALGORITHM = os.environ.get("ALGORITHM")
API_SECRET_KEY = os.environ.get("API_SECRET_KEY")
SECRET_KEY = os.environ.get("SECRET_KEY")
UNSTRUCTURED_API_KEY = os.environ.get("UNSTRUCTURED_API_KEY")
def extract_model_name(model_string: str) -> str:
part1, part2 = model_string.split(":", 1) # Split into two parts at the first colon
return part2
MODEL_NAME = extract_model_name(SMART_LLM)
app = FastAPI()
manager = ConnectionManager()
# Dependency
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# Manual Origins
# origins = [
# "http://localhost:3000", # Adjust the port if your frontend runs on a different one
# "https://yourfrontenddomain.com",
# ]
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Allows all origins
allow_credentials=True,
allow_methods=["*"], # Allows all methods
allow_headers=["*"], # Allows all headers
)
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def get_user_by_username(db: Session, username: str):
return db.query(User).filter(User.username == username).first()
def create_user(db: Session, user: UserCreate):
hashed_password = pwd_context.hash(user.password)
db_user = User(username=user.username, hashed_password=hashed_password)
db.add(db_user)
db.commit()
return "complete"
@app.post("/register")
def register_user(user: UserCreate, db: Session = Depends(get_db)):
if(user.apisecretkey != API_SECRET_KEY):
raise HTTPException(status_code=401, detail="Unauthorized")
db_user = get_user_by_username(db, username=user.username)
if db_user:
raise HTTPException(status_code=400, detail="Username already registered")
del user.apisecretkey
return create_user(db=db, user=user)
# Authenticate the user
def authenticate_user(username: str, password: str, db: Session):
user = db.query(User).filter(User.username == username).first()
if not user:
return False
if not pwd_context.verify(password, user.hashed_password):
return False
return user
# Create access token
def create_access_token(data: dict, expires_delta: timedelta | None = None):
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=15)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
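# Example (illustrative): create_access_token({"sub": "alice"}, timedelta(minutes=720)) signs
# {"sub": "alice", "exp": <expiry>} with SECRET_KEY/ALGORITHM; verify_token() below decodes the
# token and reads the "sub" claim back.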
@app.post("/token")
def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
user = authenticate_user(form_data.username, form_data.password, db)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token(
data={"sub": user.username}, expires_delta=access_token_expires
)
return {"access_token": access_token, "token_type": "bearer"}
def verify_token(token: str = Depends(oauth2_scheme)):
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
return payload
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
@app.get("/verify-token/{token}")
async def verify_user_token(token: str):
verify_token(token=token)
return {"message": "Token is valid"}
@app.post("/searchspace/{search_space_id}/chat/create")
def create_chat_in_searchspace(chat: NewUserChat, search_space_id: int, db: Session = Depends(get_db)):
try:
payload = jwt.decode(chat.token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
search_space = db.query(SearchSpace).filter(
SearchSpace.id == search_space_id,
SearchSpace.user_id == db.query(User).filter(User.username == username).first().id
).first()
if not search_space:
raise HTTPException(status_code=404, detail="SearchSpace not found or does not belong to the user")
new_chat = Chat(type=chat.type, title=chat.title, chats_list=chat.chats_list)
search_space.chats.append(new_chat)
db.commit()
db.refresh(new_chat)
return {"chat_id": new_chat.id}
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
@app.post("/searchspace/{search_space_id}/chat/update")
def update_chat_in_searchspace(chat: ChatToUpdate, search_space_id: int, db: Session = Depends(get_db)):
try:
payload = jwt.decode(chat.token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
chatindb = db.query(Chat).join(SearchSpace).filter(
Chat.id == chat.chatid,
SearchSpace.id == search_space_id,
SearchSpace.user_id == db.query(User).filter(User.username == username).first().id
).first()
if not chatindb:
raise HTTPException(status_code=404, detail="Chat not found or does not belong to the searchspace owned by the user")
chatindb.chats_list = chat.chats_list
db.commit()
return {"message": "Chat Updated"}
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
@app.get("/searchspace/{search_space_id}/chat/delete/{token}/{chatid}")
async def delete_chat_in_searchspace(token: str, search_space_id: int, chatid: str, db: Session = Depends(get_db)):
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
chatindb = db.query(Chat).join(SearchSpace).filter(
Chat.id == chatid,
SearchSpace.id == search_space_id,
SearchSpace.user_id == db.query(User).filter(User.username == username).first().id
).first()
if not chatindb:
raise HTTPException(status_code=404, detail="Chat not found or does not belong to the searchspace owned by the user")
db.delete(chatindb)
db.commit()
return {"message": "Chat Deleted"}
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
@app.get("/searchspace/{search_space_id}/chat/{token}/{chatid}")
def get_chat_by_id_in_searchspace(chatid: int, search_space_id: int, token: str, db: Session = Depends(get_db)):
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
chat = db.query(Chat).join(SearchSpace).filter(
Chat.id == chatid,
SearchSpace.id == search_space_id,
SearchSpace.user_id == db.query(User).filter(User.username == username).first().id
).first()
if not chat:
raise HTTPException(status_code=404, detail="Chat not found or does not belong to the searchspace owned by the user")
return chat
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
@app.get("/searchspace/{search_space_id}/chats/{token}")
def get_chats_in_searchspace(search_space_id: int, token: str, db: Session = Depends(get_db)):
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
user = db.query(User).filter(User.username == username).first()
if not user:
raise HTTPException(status_code=404, detail="User not found")
# Filter chats that are specifically in the given search space
chats = db.query(Chat).filter(
Chat.search_space_id == search_space_id,
SearchSpace.user_id == user.id
).join(SearchSpace).all()
return chats
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
@app.get("/user/{token}/searchspace/{search_space_id}/documents/")
def get_user_documents(search_space_id: int, token: str, db: Session = Depends(get_db)):
try:
# Decode the token to get the username
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
# Get the user by username and ensure they exist
user = db.query(User).filter(User.username == username).first()
if not user:
raise HTTPException(status_code=404, detail="User not found")
# Verify the search space belongs to the user
search_space = db.query(SearchSpace).filter(
SearchSpace.id == search_space_id,
SearchSpace.user_id == user.id
).first()
if not search_space:
raise HTTPException(status_code=404, detail="Search space not found or does not belong to the user")
# Retrieve documents associated with the search space
return db.query(Documents).filter(Documents.search_space_id == search_space_id).all()
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
@app.get("/user/{token}/searchspace/{search_space_id}/")
def get_user_search_space_by_id(search_space_id: int, token: str, db: Session = Depends(get_db)):
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
# Get the user by username
user = db.query(User).filter(User.username == username).first()
if not user:
raise HTTPException(status_code=404, detail="User not found")
# Get the search space by ID and verify it belongs to this user
search_space = db.query(SearchSpace).filter(
SearchSpace.id == search_space_id,
SearchSpace.user_id == user.id
).first()
if not search_space:
raise HTTPException(status_code=404, detail="Search space not found or does not belong to the user")
return search_space
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
@app.get("/user/{token}/searchspaces/")
def get_user_search_spaces(token: str, db: Session = Depends(get_db)):
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
user = db.query(User).filter(User.username == username).first()
return db.query(SearchSpace).filter(SearchSpace.user_id == user.id).all()
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
@app.post("/user/create/searchspace/")
def create_user_search_space(data: CreateStorageSpace, db: Session = Depends(get_db)):
try:
payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
user = db.query(User).filter(User.username == username).first()
db_search_space = SearchSpace(user_id=user.id, name=data.name, description=data.description)
db.add(db_search_space)
db.commit()
db.refresh(db_search_space)
return db_search_space
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
@app.post("/user/save/")
def save_user_extension_documents(data: RetrivedDocList, db: Session = Depends(get_db)):
try:
payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
# Get the user by username and ensure they exist
user = db.query(User).filter(User.username == username).first()
if not user:
raise HTTPException(status_code=404, detail="User not found")
# Verify the search space belongs to the user
search_space = db.query(SearchSpace).filter(
SearchSpace.id == data.search_space_id,
SearchSpace.user_id == user.id
).first()
if not search_space:
raise HTTPException(status_code=404, detail="Search space not found or does not belong to the user")
# all_search_space_docs = db.query(SearchSpace).filter(
# SearchSpace.user_id == user.id
# ).all()
# total_doc_count = 0
# for search_space in all_search_space_docs:
# total_doc_count += db.query(Documents).filter(Documents.search_space_id == search_space.id).count()
print(f"STARTED")
# Initialize containers for documents and entries
# DocumentPgEntry = []
raw_documents = []
# Process each document in the retrieved document list
for doc in data.documents:
# Construct document content
content = (
f"USER BROWSING SESSION EVENT: \n"
f"=======================================METADATA==================================== \n"
f"User Browsing Session ID : {doc.metadata.BrowsingSessionId} \n"
f"User Visited website with url : {doc.metadata.VisitedWebPageURL} \n"
f"This visited website url had title : {doc.metadata.VisitedWebPageTitle} \n"
f"User Visited this website from referring url : {doc.metadata.VisitedWebPageReffererURL} \n"
f"User Visited this website url at this Date and Time : {doc.metadata.VisitedWebPageDateWithTimeInISOString} \n"
f"User Visited this website for : {str(doc.metadata.VisitedWebPageVisitDurationInMilliseconds)} milliseconds. \n"
f"===================================================================================== \n"
f"Webpage Content of the visited webpage url in markdown format : \n\n{doc.pageContent}\n\n"
f"===================================================================================== \n"
)
raw_documents.append(Document(page_content=content, metadata=doc.metadata.__dict__))
# pgdocmeta = stringify(doc.metadata.__dict__)
# DocumentPgEntry.append(Documents(
# file_type='WEBPAGE',
# title=doc.metadata.VisitedWebPageTitle,
# search_space=search_space,
# document_metadata=pgdocmeta,
# page_content=content
# ))
# # Save documents in PostgreSQL
# search_space.documents.extend(DocumentPgEntry)
# db.commit()
# Create hierarchical indices
if IS_LOCAL_SETUP == True:
index = HIndices(username=username)
else:
index = HIndices(username=username, api_key=OPENAI_API_KEY)
# Save indices in vector stores
index.encode_docs_hierarchical(
documents=raw_documents,
search_space_instance=search_space,
files_type='WEBPAGE',
db=db
)
print("FINISHED")
return {
"success": "Save Job Completed Successfully"
}
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
@app.post("/user/uploadfiles/")
def save_user_documents(files: list[UploadFile], token: str = Depends(oauth2_scheme), search_space_id: int = Form(...), db: Session = Depends(get_db)):
try:
# Decode and verify the token
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
# Get the user by username and ensure they exist
user = db.query(User).filter(User.username == username).first()
if not user:
raise HTTPException(status_code=404, detail="User not found")
# Verify the search space belongs to the user
search_space = db.query(SearchSpace).filter(
SearchSpace.id == search_space_id,
SearchSpace.user_id == user.id
).first()
if not search_space:
raise HTTPException(status_code=404, detail="Search space not found or does not belong to the user")
docs = []
for file in files:
if file.content_type.startswith('image'):
loader = UnstructuredLoader(
file=file.file,
api_key=UNSTRUCTURED_API_KEY,
partition_via_api=True,
chunking_strategy="basic",
max_characters=90000,
include_orig_elements=False,
)
else:
loader = UnstructuredLoader(
file=file.file,
api_key=UNSTRUCTURED_API_KEY,
partition_via_api=True,
chunking_strategy="basic",
max_characters=90000,
include_orig_elements=False,
strategy="fast"
)
filedocs = loader.load()
fileswithfilename = []
for f in filedocs:
temp = f
temp.metadata['filename'] = file.filename
fileswithfilename.append(temp)
docs.extend(fileswithfilename)
raw_documents = []
# Process each document in the retrieved document list
for doc in docs:
raw_documents.append(Document(page_content=doc.page_content, metadata=doc.metadata))
# Create hierarchical indices
if IS_LOCAL_SETUP == True:
index = HIndices(username=username)
else:
index = HIndices(username=username, api_key=OPENAI_API_KEY)
# Save indices in vector stores
index.encode_docs_hierarchical(documents=raw_documents, search_space_instance=search_space, files_type='OTHER', db=db)
print("FINISHED")
return {
"message": "Files Uploaded Successfully"
}
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
@app.websocket("/beta/chat/{search_space_id}/{token}")
async def searchspace_chat_websocket_endpoint(websocket: WebSocket, search_space_id: int, token: str, db: Session = Depends(get_db)):
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
# Get the user by username and ensure they exist
user = db.query(User).filter(User.username == username).first()
if not user:
raise HTTPException(status_code=404, detail="User not found")
# Verify the search space belongs to the user
search_space = db.query(SearchSpace).filter(
SearchSpace.id == search_space_id,
SearchSpace.user_id == user.id
).first()
if not search_space:
raise HTTPException(status_code=404, detail="Search space not found or does not belong to the user")
await manager.connect(websocket)
try:
while True:
data = await websocket.receive_text()
message = json.loads(data)
# print(message)
if message["type"] == "search_space_chat":
query = message["content"]
if message["searchtype"] == "local" :
report_source = "langchain_documents"
else:
report_source = message["searchtype"]
if message["answertype"] == "general_answer" :
report_type = "custom_report"
else:
report_type = message["answertype"]
# Create hierarchical indices
if(IS_LOCAL_SETUP == True):
index = HIndices(username=username)
else:
index = HIndices(username=username,api_key=OPENAI_API_KEY)
await index.ws_experimental_search(websocket=websocket, manager=manager, query=query, search_space=search_space.name, report_type=report_type, report_source=report_source)
await manager.send_personal_message(
json.dumps({"type": "end"}),
websocket
)
if message["type"] == "multiple_documents_chat":
query = message["content"]
received_chat_history = message["chat_history"]
chatHistory = []
chatHistory = [
SystemMessage(
content=DATE_TODAY + """You are a helpful assistant for question-answering tasks.
Use the following pieces of retrieved context to answer the question.
If you don't know the answer, just say that you don't know.
Context:""" + str(received_chat_history[0]['relateddocs']))
]
for data in received_chat_history[1:]:
if data["role"] == "user":
chatHistory.append(HumanMessage(content=data["content"]))
if data["role"] == "assistant":
chatHistory.append(AIMessage(content=data["content"]))
chatHistory.append(("human", "{input}"))
qa_prompt = ChatPromptTemplate.from_messages(chatHistory)
if(IS_LOCAL_SETUP == True):
llm = OllamaLLM(model=MODEL_NAME,temperature=0)
else:
llm = ChatOpenAI(temperature=0, model_name=MODEL_NAME, api_key=OPENAI_API_KEY)
descriptionchain = qa_prompt | llm
streamingResponse = ""
counter = 0
for res in descriptionchain.stream({"input": query}):
streamingResponse += res.content
if (counter < 20) :
counter += 1
else :
await manager.send_personal_message(
json.dumps({"type": "stream", "content": streamingResponse}),
websocket
)
counter = 0
await manager.send_personal_message(
json.dumps({"type": "stream", "content": streamingResponse}),
websocket
)
await manager.send_personal_message(
json.dumps({"type": "end"}),
websocket
)
except Exception as e:
print(f"Error: {e}")
finally:
manager.disconnect(websocket)
except JWTError:
await websocket.close(code=4003, reason="Invalid token")
@app.post("/user/searchspace/create-podcast")
async def create_podcast(
data: CreatePodcast,
background_tasks: BackgroundTasks,
db: Session = Depends(get_db)
):
try:
# Verify token and get username
payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
# Get user
user = db.query(User).filter(User.username == username).first()
if not user:
raise HTTPException(status_code=404, detail="User not found")
# Verify search space belongs to user
search_space = db.query(SearchSpace).filter(
SearchSpace.id == data.search_space_id,
SearchSpace.user_id == user.id
).first()
if not search_space:
raise HTTPException(status_code=404, detail="Search space not found or does not belong to the user")
# Create new podcast entry
new_podcast = Podcast(
title=data.title,
podcast_content=data.podcast_content,
search_space_id=search_space.id
)
db.add(new_podcast)
db.commit()
db.refresh(new_podcast)
podcast_config = {
'word_count': data.wordcount,
'podcast_name': 'SurfSense Podcast',
'podcast_tagline': 'Your Own Personal Podcast.',
'output_language': 'English',
'user_instructions': 'Make it fun and engaging',
'engagement_techniques': ['Rhetorical Questions', 'Personal Testimonials', 'Quotes', 'Anecdotes', 'Analogies', 'Humor'],
}
try:
background_tasks.add_task(
generate_podcast_background,
new_podcast.id,
data.podcast_content,
MODEL_NAME,
"OPENAI_API_KEY",
podcast_config,
db
)
# # Check MODEL NAME behavior on Local Setups
# saved_file_location = generate_podcast(
# text=data.podcast_content,
# llm_model_name=MODEL_NAME,
# api_key_label="OPENAI_API_KEY",
# conversation_config=podcast_config,
# )
# new_podcast.file_location = saved_file_location
# new_podcast.is_generated = True
# db.commit()
# db.refresh(new_podcast)
return {"message": "Podcast created successfully", "podcast_id": new_podcast.id}
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
async def generate_podcast_background(
podcast_id: int,
podcast_content: str,
model_name: str,
api_key_label: str,
conversation_config: dict,
db: Session
):
try:
saved_file_location = generate_podcast(
text=podcast_content,
llm_model_name=model_name,
api_key_label=api_key_label,
conversation_config=conversation_config,
)
# Update podcast in database
podcast = db.query(Podcast).filter(Podcast.id == podcast_id).first()
if podcast:
podcast.file_location = saved_file_location
podcast.is_generated = True
db.commit()
except Exception as e:
# Log the error or handle it appropriately
print(f"Error generating podcast: {str(e)}")
@app.get("/user/{token}/searchspace/{search_space_id}/download-podcast/{podcast_id}")
async def download_podcast(search_space_id: int, podcast_id: int, token: str, db: Session = Depends(get_db)):
try:
# Verify the token and get the username
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
# Get the user by username
user = db.query(User).filter(User.username == username).first()
if not user:
raise HTTPException(status_code=404, detail="User not found")
# Verify the search space belongs to the user
search_space = db.query(SearchSpace).filter(
SearchSpace.id == search_space_id,
SearchSpace.user_id == user.id
).first()
if not search_space:
raise HTTPException(status_code=404, detail="Search space not found or does not belong to the user")
# Retrieve the podcast file from the database
podcast = db.query(Podcast).filter(
Podcast.id == podcast_id,
Podcast.search_space_id == search_space_id
).first()
if not podcast:
raise HTTPException(status_code=404, detail="Podcast not found in the specified search space")
# Read the file content
with open(podcast.file_location, "rb") as file:
file_content = file.read()
# Create a response with the file content
response = Response(content=file_content)
response.headers["Content-Disposition"] = f"attachment; filename={podcast.title}.mp3"
response.headers["Content-Type"] = "audio/mpeg"
return response
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/user/{token}/searchspace/{search_space_id}/podcasts")
async def get_user_podcasts(token: str, search_space_id: int, db: Session = Depends(get_db)):
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
user = db.query(User).filter(User.username == username).first()
if not user:
raise HTTPException(status_code=404, detail="User not found")
search_space = db.query(SearchSpace).filter(
SearchSpace.id == search_space_id,
SearchSpace.user_id == user.id
).first()
if not search_space:
raise HTTPException(status_code=404, detail="Search space not found or does not belong to the user")
podcasts = db.query(Podcast).filter(Podcast.search_space_id == search_space_id).all()
return podcasts
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
# Incomplete function, needs to be implemented based on the actual requirements and database structure
@app.post("/searchspace/{search_space_id}/delete/docs")
def delete_all_related_data(search_space_id: int, data: DocumentsToDelete, db: Session = Depends(get_db)):
try:
payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
# Get the user by username and ensure they exist
user = db.query(User).filter(User.username == username).first()
if not user:
raise HTTPException(status_code=404, detail="User not found")
# Verify the search space belongs to the user
search_space = db.query(SearchSpace).filter(
SearchSpace.id == search_space_id,
SearchSpace.user_id == user.id
).first()
if not search_space:
raise HTTPException(status_code=404, detail="Search space not found or does not belong to the user")
if IS_LOCAL_SETUP:
index = HIndices(username=username)
else:
index = HIndices(username=username, api_key=OPENAI_API_KEY)
message = index.delete_vector_stores(summary_ids_to_delete=data.ids_to_delete, db=db, search_space=search_space.name)
return {
"message": message
}
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, ws="wsproto", host="127.0.0.1", port=8000)

View file

@ -1,50 +0,0 @@
version: '3.8'
services:
# PostgreSQL Database
db:
image: postgres:13
environment:
POSTGRES_USER: your_postgres_user
POSTGRES_PASSWORD: your_postgres_password
POSTGRES_DB: surfsense_db
volumes:
- postgres_data:/var/lib/postgresql/data
ports:
- "5432:5432"
networks:
- surfsense-network
# Backend Service (FastAPI)
backend:
build:
context: ./backend
ports:
- "8000:8000"
env_file:
- ./backend/.env
depends_on:
- db
# privileged: true  # Needed when the backend points at the host machine via localhost or 127.0.0.1; otherwise the container resolves localhost to itself
networks:
- surfsense-network
# Frontend Service (Next.js)
frontend:
build:
context: ./SurfSense-Frontend
ports:
- "3000:3000"
env_file:
- ./SurfSense-Frontend/.env
networks:
- surfsense-network
# Volumes for persistent storage
volumes:
postgres_data:
# Docker network
networks:
surfsense-network:
driver: bridge

View file

@ -0,0 +1,21 @@
DATABASE_URL="postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense"
SECRET_KEY="SECRET"
GOOGLE_OAUTH_CLIENT_ID="924507538m"
GOOGLE_OAUTH_CLIENT_SECRET="GOCSV"
NEXT_FRONTEND_URL="http://localhost:3000"
EMBEDDING_MODEL="mixedbread-ai/mxbai-embed-large-v1"
RERANKERS_MODEL_NAME="ms-marco-MiniLM-L-12-v2"
RERANKERS_MODEL_TYPE="flashrank"
FAST_LLM="litellm:openai/gpt-4o-mini"
SMART_LLM="litellm:openai/gpt-4o-mini"
STRATEGIC_LLM="litellm:openai/gpt-4o-mini"
LONG_CONTEXT_LLM="litellm:gemini/gemini-2.0-flash-thinking-exp-01-21"
OPENAI_API_KEY="sk-proj-iA"
GEMINI_API_KEY="AIzaSyB6-1641124124124124124124124124124"
UNSTRUCTURED_API_KEY="Tpu3P0U8iy"
FIRECRAWL_API_KEY="fcr-01J0000000000000000000000"

6
surfsense_backend/.gitignore vendored Normal file
View file

@ -0,0 +1,6 @@
.env
.venv
venv/
data/
__pycache__/
.flashrank_cache

View file

@ -0,0 +1 @@
3.12

16
surfsense_backend/.vscode/launch.json vendored Normal file
View file

@ -0,0 +1,16 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python Debugger: main.py",
"type": "debugpy",
"request": "launch",
"program": "${workspaceFolder}/main.py",
"console": "integratedTerminal",
"justMyCode": false
}
]
}

119
surfsense_backend/README.md Normal file
View file

@ -0,0 +1,119 @@
# SurfSense Backend
## Technology Stack Overview
This application is a modern AI-powered search and knowledge management platform built with the following technology stack:
### Core Framework and Environment
- **Python 3.12+**: The application requires Python 3.12 or newer
- **FastAPI**: Modern, fast web framework for building APIs with Python
- **Uvicorn**: ASGI server implementation, running the FastAPI application
- **PostgreSQL with pgvector**: Database with vector search capabilities for similarity searches
- **SQLAlchemy**: SQL toolkit and ORM (Object-Relational Mapping) for database interactions
- **FastAPI Users**: Authentication and user management with JWT and OAuth support
### Key Features and Components
#### Authentication and User Management
- JWT-based authentication
- OAuth integration (Google)
- User registration, login, and password reset flows
#### Search and Retrieval System
- **Hybrid Search**: Combines vector similarity and full-text search for optimal results using Reciprocal Rank Fusion (RRF); a small illustrative sketch of the fusion step follows this list
- **Vector Embeddings**: Document and text embeddings for semantic search
- **pgvector**: PostgreSQL extension for efficient vector similarity operations
- **Chonkie**: Advanced document chunking and embedding library
- Uses `AutoEmbeddings` for flexible embedding model selection
- `LateChunker` for optimized document chunking based on embedding model's max sequence length
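To make the fusion step concrete, here is a minimal, illustrative Python sketch of Reciprocal Rank Fusion. It assumes two already-ranked lists of chunk IDs; the shipped retriever in `app/retriver/chunks_hybrid_search.py` computes the same `1 / (k + rank)` score inside PostgreSQL CTEs rather than in Python.

```python
# Minimal RRF sketch (illustration only): fuse two ranked lists of chunk IDs.
def reciprocal_rank_fusion(semantic_ids, keyword_ids, k=60):
    scores = {}
    for ranking in (semantic_ids, keyword_ids):
        for rank, chunk_id in enumerate(ranking, start=1):
            # Each appearance contributes 1 / (k + rank); absent items contribute nothing.
            scores[chunk_id] = scores.get(chunk_id, 0.0) + 1.0 / (k + rank)
    return sorted(scores, key=scores.get, reverse=True)

# Chunk 7 ranks well in both lists, so it comes out on top.
print(reciprocal_rank_fusion([7, 3, 9], [2, 7, 3]))  # -> [7, 3, 2, 9]
```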
#### AI and NLP Capabilities
- **LangChain**: Framework for developing AI-powered applications
- Used for document processing, research, and response generation
- Integration with various LLM models through LiteLLM
- Document conversion utilities for standardized processing
- **GPT Integration**: Integration with LLM models through LiteLLM
- Multiple LLM configurations for different use cases:
- Fast LLM: Quick responses (default: gpt-4o-mini)
- Smart LLM: More comprehensive analysis (default: gpt-4o-mini)
- Strategic LLM: Complex reasoning (default: gpt-4o-mini)
- Long Context LLM: For processing large documents (default: gemini-2.0-flash-thinking)
- **Rerankers with FlashRank**: Advanced result ranking for improved search relevance
- Configurable reranking models (default: ms-marco-MiniLM-L-12-v2)
- Supports multiple reranking backends (FlashRank, Cohere, etc.)
- Improves search result quality by reordering based on semantic relevance (a short configuration sketch follows this list)
- **GPT-Researcher**: Advanced research capabilities
- Multiple research modes (GENERAL, DEEP, DEEPER)
- Customizable report formats with proper citations
- Streaming research results for real-time updates
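The sketch below shows how the `provider:model` strings from `.env` become LLM and reranker instances. It mirrors `app/config.py` in this repository; the model names are placeholders, and real calls additionally require the API keys configured in `.env`.

```python
# Illustrative configuration sketch; model names are examples, not requirements.
from langchain_community.chat_models import ChatLiteLLM
from rerankers import Reranker

fast_llm = ChatLiteLLM(model="openai/gpt-4o-mini")  # FAST_LLM with the "litellm:" prefix stripped

reranker = Reranker(
    model_name="ms-marco-MiniLM-L-12-v2",  # RERANKERS_MODEL_NAME
    model_type="flashrank",                # RERANKERS_MODEL_TYPE
)

# Rerank a few candidate chunks for a query and keep the best match.
ranked = reranker.rank(
    query="What is SurfSense?",
    docs=[
        "SurfSense is a personal AI research assistant.",
        "Unrelated text about cooking.",
    ],
)
print(ranked.top_k(1))
```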
#### External Integrations
- **Slack Connector**: Integration with Slack for data retrieval and notifications
- **Notion Connector**: Integration with Notion for document retrieval
- **Search APIs**: Integration with Tavily and Serper API for web search
- **Firecrawl**: Web crawling and data extraction capabilities
#### Data Processing
- **Unstructured**: Tools for processing unstructured data
- **Markdownify**: Converting HTML to Markdown
- **Playwright**: Web automation and scraping capabilities
#### Main Modules
- **Search Spaces**: Isolated search environments for different contexts or projects
- **Documents**: Storage and retrieval of various document types
- **Chunks**: Document fragments for more precise retrieval
- **Chats**: Conversation management with different depth levels (GENERAL, DEEP)
- **Podcasts**: Audio content management with generation capabilities
- **Search Source Connectors**: Integration with various data sources
### Development Tools
- **Poetry**: Python dependency management (indicated by pyproject.toml)
- **CORS support**: Cross-Origin Resource Sharing enabled for API access
- **Environment Variables**: Configuration through .env files
## Database Schema
The application uses a relational database with the following main entities:
- Users: Authentication and user management
- SearchSpaces: Isolated search environments owned by users
- Documents: Various document types with content and embeddings
- Chunks: Smaller pieces of documents for granular retrieval
- Chats: Conversation tracking with different depth levels
- Podcasts: Audio content with generation capabilities
- SearchSourceConnectors: External data source integrations
## API Endpoints
The API is structured with the following main route groups (a minimal client sketch follows the list):
- `/auth/*`: Authentication endpoints (JWT, OAuth)
- `/users/*`: User management
- `/api/v1/search-spaces/*`: Search space management
- `/api/v1/documents/*`: Document management
- `/api/v1/podcasts/*`: Podcast functionality
- `/api/v1/chats/*`: Chat and conversation endpoints
- `/api/v1/search-source-connectors/*`: External data source management
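As a quick orientation, the following client sketch logs in and lists chats. It assumes a locally running server, an existing user, and the bearer-token JWT transport wired up by FastAPI Users in `main.py`; the base URL and credentials are placeholders.

```python
# Minimal client sketch; base URL and credentials are placeholders.
import httpx

BASE = "http://localhost:8000"

# FastAPI Users' JWT login expects form-encoded credentials.
login = httpx.post(
    f"{BASE}/auth/jwt/login",
    data={"username": "user@example.com", "password": "secret"},
)
token = login.json()["access_token"]

# Authenticated call to one of the CRUD routes mounted under /api/v1.
chats = httpx.get(
    f"{BASE}/api/v1/chats/",
    headers={"Authorization": f"Bearer {token}"},
)
print(chats.json())
```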
## Deployment
The application is configured to run with Uvicorn and can be deployed with:
```
python main.py
```
This will start the server on all interfaces (0.0.0.0) with info-level logging.
## Requirements
See pyproject.toml for detailed dependency information. Key dependencies include:
- asyncpg: Asynchronous PostgreSQL client
- chonkie: Document chunking and embedding library
- fastapi and related packages
- fastapi-users: Authentication and user management
- firecrawl-py: Web crawling capabilities
- gpt-researcher: Advanced research capabilities
- langchain components for AI workflows
- litellm: LLM model integration
- pgvector: Vector similarity search in PostgreSQL
- rerankers with FlashRank: Advanced result ranking
- Various AI and NLP libraries
- Integration clients for Slack, Notion, etc.

View file

View file

@ -0,0 +1,80 @@
from contextlib import asynccontextmanager
from fastapi import Depends, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import User, create_db_and_tables, get_async_session
from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
from app.schemas import UserCreate, UserRead, UserUpdate
from app.users import (
SECRET,
auth_backend,
fastapi_users,
google_oauth_client,
current_active_user,
)
from app.routes import router as crud_router
@asynccontextmanager
async def lifespan(app: FastAPI):
# Not needed if you setup a migration system like Alembic
await create_db_and_tables()
yield
app = FastAPI(lifespan=lifespan)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Allows all origins
allow_credentials=True,
allow_methods=["*"], # Allows all methods
allow_headers=["*"], # Allows all headers
)
app.include_router(
fastapi_users.get_auth_router(auth_backend), prefix="/auth/jwt", tags=["auth"]
)
app.include_router(
fastapi_users.get_register_router(UserRead, UserCreate),
prefix="/auth",
tags=["auth"],
)
app.include_router(
fastapi_users.get_reset_password_router(),
prefix="/auth",
tags=["auth"],
)
app.include_router(
fastapi_users.get_verify_router(UserRead),
prefix="/auth",
tags=["auth"],
)
app.include_router(
fastapi_users.get_users_router(UserRead, UserUpdate),
prefix="/users",
tags=["users"],
)
app.include_router(
fastapi_users.get_oauth_router(google_oauth_client, auth_backend, SECRET, is_verified_by_default=True),
prefix="/auth/google",
tags=["auth"],
)
app.include_router(crud_router, prefix="/api/v1", tags=["crud"])
@app.get("/authenticated-route")
async def authenticated_route(user: User = Depends(current_active_user), session: AsyncSession = Depends(get_async_session)):
retriever = ChucksHybridSearchRetriever(session)
results = await retriever.hybrid_search(
query_text="SurfSense",
top_k=1,
user_id=user.id,
search_space_id=1,
document_type="CRAWLED_URL"
)
return results

View file

@ -0,0 +1,98 @@
import os
from pathlib import Path
from chonkie import AutoEmbeddings, LateChunker
from rerankers import Reranker
from langchain_community.chat_models import ChatLiteLLM
from dotenv import load_dotenv
# Get the base directory of the project
BASE_DIR = Path(__file__).resolve().parent.parent.parent
env_file = BASE_DIR / ".env"
load_dotenv(env_file)
def extract_model_name(llm_string: str) -> str:
"""Extract the model name from an LLM string.
Example: "litellm:openai/gpt-4o-mini" -> "openai/gpt-4o-mini"
Args:
llm_string: The LLM string with optional prefix
Returns:
str: The extracted model name
"""
return llm_string.split(":", 1)[1] if ":" in llm_string else llm_string
class Config:
# Database
DATABASE_URL = os.getenv("DATABASE_URL")
# Google OAuth
GOOGLE_OAUTH_CLIENT_ID = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
GOOGLE_OAUTH_CLIENT_SECRET = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")
NEXT_FRONTEND_URL = os.getenv("NEXT_FRONTEND_URL")
# LONG-CONTEXT LLMS
LONG_CONTEXT_LLM = os.getenv("LONG_CONTEXT_LLM")
long_context_llm_instance = ChatLiteLLM(model=extract_model_name(LONG_CONTEXT_LLM))
# GPT Researcher
FAST_LLM = os.getenv("FAST_LLM")
SMART_LLM = os.getenv("SMART_LLM")
STRATEGIC_LLM = os.getenv("STRATEGIC_LLM")
fast_llm_instance = ChatLiteLLM(model=extract_model_name(FAST_LLM))
smart_llm_instance = ChatLiteLLM(model=extract_model_name(SMART_LLM))
strategic_llm_instance = ChatLiteLLM(model=extract_model_name(STRATEGIC_LLM))
# Chonkie Configuration | Edit this to your needs
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
embedding_model_instance = AutoEmbeddings.get_embeddings(EMBEDDING_MODEL)
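# Late chunking ties the chunk size to the embedding model's max sequence length
# so that every chunk fits in a single embedding pass.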
chunker_instance = LateChunker(
embedding_model=EMBEDDING_MODEL,
chunk_size=embedding_model_instance.max_seq_length,
)
# Reranker configuration | Pinecone, Cohere, etc. Read more at https://github.com/AnswerDotAI/rerankers?tab=readme-ov-file#usage
RERANKERS_MODEL_NAME = os.getenv("RERANKERS_MODEL_NAME")
RERANKERS_MODEL_TYPE = os.getenv("RERANKERS_MODEL_TYPE")
reranker_instance = Reranker(
model_name=RERANKERS_MODEL_NAME,
model_type=RERANKERS_MODEL_TYPE,
)
# OAuth JWT
SECRET_KEY = os.getenv("SECRET_KEY")
# Unstructured API Key
UNSTRUCTURED_API_KEY = os.getenv("UNSTRUCTURED_API_KEY")
# Firecrawl API Key
FIRECRAWL_API_KEY = os.getenv("FIRECRAWL_API_KEY", None)
# Validation Checks
# Check embedding dimension
if hasattr(embedding_model_instance, 'dimension') and embedding_model_instance.dimension > 2000:
raise ValueError(
f"Embedding dimension for Model: {EMBEDDING_MODEL} "
f"has {embedding_model_instance.dimension} dimensions, which "
f"exceeds the maximum of 2000 allowed by PGVector."
)
@classmethod
def get_settings(cls):
"""Get all settings as a dictionary."""
return {
key: value
for key, value in cls.__dict__.items()
if not key.startswith("_") and not callable(value)
}
# Create a config instance
config = Config()

View file

@ -0,0 +1,225 @@
from notion_client import Client
class NotionHistoryConnector:
def __init__(self, token):
"""
Initialize the NotionPageFetcher with a token.
Args:
token (str): Notion integration token
"""
self.notion = Client(auth=token)
def get_all_pages(self, start_date=None, end_date=None):
"""
Fetches all pages shared with your integration and their content.
Args:
start_date (str, optional): ISO 8601 date string (e.g., "2023-01-01T00:00:00Z")
end_date (str, optional): ISO 8601 date string (e.g., "2023-12-31T23:59:59Z")
Returns:
list: List of dictionaries containing page data
"""
# Build the filter for the search
# Note: Notion API requires specific filter structure
search_params = {}
# Filter for pages only (not databases)
search_params["filter"] = {
"value": "page",
"property": "object"
}
# Add date filters if provided
if start_date or end_date:
date_filter = {}
if start_date:
date_filter["on_or_after"] = start_date
if end_date:
date_filter["on_or_before"] = end_date
# The Notion search endpoint has no date-range filter, so when dates are given we sort by last edited time instead
if date_filter:
search_params["sort"] = {
"direction": "descending",
"timestamp": "last_edited_time"
}
# First, get a list of all pages the integration has access to
search_results = self.notion.search(**search_params)
pages = search_results["results"]
all_page_data = []
for page in pages:
page_id = page["id"]
# Get detailed page information
page_content = self.get_page_content(page_id)
all_page_data.append({
"page_id": page_id,
"title": self.get_page_title(page),
"content": page_content
})
return all_page_data
def get_page_title(self, page):
"""
Extracts the title from a page object.
Args:
page (dict): Notion page object
Returns:
str: Page title or a fallback string
"""
# Title can be in different properties depending on the page type
if "properties" in page:
# Try to find a title property
for prop_name, prop_data in page["properties"].items():
if prop_data["type"] == "title" and len(prop_data["title"]) > 0:
return " ".join([text_obj["plain_text"] for text_obj in prop_data["title"]])
# If no title found, return the page ID as fallback
return f"Untitled page ({page['id']})"
def get_page_content(self, page_id):
"""
Fetches the content (blocks) of a specific page.
Args:
page_id (str): The ID of the page to fetch
Returns:
list: List of processed blocks from the page
"""
blocks = []
has_more = True
cursor = None
# Paginate through all blocks
while has_more:
if cursor:
response = self.notion.blocks.children.list(block_id=page_id, start_cursor=cursor)
else:
response = self.notion.blocks.children.list(block_id=page_id)
blocks.extend(response["results"])
has_more = response["has_more"]
if has_more:
cursor = response["next_cursor"]
# Process nested blocks recursively
processed_blocks = []
for block in blocks:
processed_block = self.process_block(block)
processed_blocks.append(processed_block)
return processed_blocks
def process_block(self, block):
"""
Processes a block and recursively fetches any child blocks.
Args:
block (dict): The block to process
Returns:
dict: Processed block with content and children
"""
block_id = block["id"]
block_type = block["type"]
# Extract block content based on its type
content = self.extract_block_content(block)
# Check if block has children
has_children = block.get("has_children", False)
child_blocks = []
if has_children:
# Fetch and process child blocks
children_response = self.notion.blocks.children.list(block_id=block_id)
for child_block in children_response["results"]:
child_blocks.append(self.process_block(child_block))
return {
"id": block_id,
"type": block_type,
"content": content,
"children": child_blocks
}
def extract_block_content(self, block):
"""
Extracts the content from a block based on its type.
Args:
block (dict): The block to extract content from
Returns:
str: Extracted content as a string
"""
block_type = block["type"]
# Different block types have different structures
if block_type in block and "rich_text" in block[block_type]:
return "".join([text_obj["plain_text"] for text_obj in block[block_type]["rich_text"]])
elif block_type == "image":
# Instead of returning the raw URL which may contain sensitive AWS credentials,
# return a placeholder or reference to the image
if "file" in block["image"]:
# For Notion-hosted images (which use AWS S3 pre-signed URLs)
return "[Notion Image]"
elif "external" in block["image"]:
# For external images, we can return a sanitized reference
url = block["image"]["external"]["url"]
# Only return the domain part of external URLs to avoid potential sensitive parameters
try:
from urllib.parse import urlparse
parsed_url = urlparse(url)
return f"[External Image from {parsed_url.netloc}]"
except Exception:
return "[External Image]"
elif block_type == "code":
language = block["code"]["language"]
code_text = "".join([text_obj["plain_text"] for text_obj in block["code"]["rich_text"]])
return f"```{language}\n{code_text}\n```"
elif block_type == "equation":
return block["equation"]["expression"]
# Add more block types as needed
# Return empty string for unsupported block types
return ""
# Example usage
# if __name__ == "__main__":
# # Simple example of how to use this module
# import argparse
# parser = argparse.ArgumentParser(description="Fetch Notion pages using an integration token")
# parser.add_argument("--token", help="Your Notion integration token")
# parser.add_argument("--start-date", help="Start date in ISO format (e.g., 2023-01-01T00:00:00Z)")
# parser.add_argument("--end-date", help="End date in ISO format (e.g., 2023-12-31T23:59:59Z)")
# args = parser.parse_args()
# token = args.token
# if not token:
# token = input("Enter your Notion integration token: ")
# fetcher = NotionHistoryConnector(token)
# try:
# pages = fetcher.get_all_pages(args.start_date, args.end_date)
# print(f"Fetched {len(pages)} pages from Notion")
# for page in pages:
# print(f"- {page['title']}")
# except Exception as e:
# print(f"Error: {str(e)}")

View file

@ -0,0 +1,301 @@
"""
Slack History Module
A module for retrieving conversation history from Slack channels.
Allows fetching channel lists and message history with date range filtering.
"""
import os
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any, Union
class SlackHistory:
"""Class for retrieving conversation history from Slack channels."""
def __init__(self, token: str = None):
"""
Initialize the SlackHistory class.
Args:
token: Slack API token (optional, can be set later with set_token)
"""
self.client = WebClient(token=token) if token else None
def set_token(self, token: str) -> None:
"""
Set the Slack API token.
Args:
token: Slack API token
"""
self.client = WebClient(token=token)
def get_all_channels(self, include_private: bool = True) -> Dict[str, str]:
"""
Fetch all channels that the bot has access to.
Args:
include_private: Whether to include private channels
Returns:
Dictionary mapping channel names to channel IDs
Raises:
ValueError: If no Slack client has been initialized
SlackApiError: If there's an error calling the Slack API
"""
if not self.client:
raise ValueError("Slack client not initialized. Call set_token() first.")
channels_dict = {}
types = "public_channel"
if include_private:
types += ",private_channel"
try:
# Call the conversations.list method
result = self.client.conversations_list(
types=types,
limit=1000 # Maximum allowed by API
)
channels = result["channels"]
# Handle pagination for workspaces with many channels
while result.get("response_metadata", {}).get("next_cursor"):
next_cursor = result["response_metadata"]["next_cursor"]
# Get the next batch of channels
result = self.client.conversations_list(
types=types,
cursor=next_cursor,
limit=1000
)
channels.extend(result["channels"])
# Create a dictionary mapping channel names to IDs
for channel in channels:
channels_dict[channel["name"]] = channel["id"]
return channels_dict
except SlackApiError as e:
raise SlackApiError(f"Error retrieving channels: {e}", e.response)
def get_conversation_history(
self,
channel_id: str,
limit: int = 1000,
oldest: Optional[int] = None,
latest: Optional[int] = None
) -> List[Dict[str, Any]]:
"""
Fetch conversation history for a channel.
Args:
channel_id: The ID of the channel to fetch history for
limit: Maximum number of messages to return per request (default 1000)
oldest: Start of time range (Unix timestamp)
latest: End of time range (Unix timestamp)
Returns:
List of message objects
Raises:
ValueError: If no Slack client has been initialized
SlackApiError: If there's an error calling the Slack API
"""
if not self.client:
raise ValueError("Slack client not initialized. Call set_token() first.")
try:
# Call the conversations.history method
messages = []
next_cursor = None
while True:
kwargs = {
"channel": channel_id,
"limit": min(limit, 1000), # API max is 1000
}
if oldest:
kwargs["oldest"] = oldest
if latest:
kwargs["latest"] = latest
if next_cursor:
kwargs["cursor"] = next_cursor
result = self.client.conversations_history(**kwargs)
batch = result["messages"]
messages.extend(batch)
# Check if we need to paginate
if result.get("has_more", False) and len(messages) < limit:
next_cursor = result["response_metadata"]["next_cursor"]
else:
break
# Respect the overall limit parameter
return messages[:limit]
except SlackApiError as e:
raise SlackApiError(f"Error retrieving history for channel {channel_id}: {e}", e.response)
@staticmethod
def convert_date_to_timestamp(date_str: str) -> Optional[int]:
"""
Convert a date string in format YYYY-MM-DD to Unix timestamp.
Args:
date_str: Date string in YYYY-MM-DD format
Returns:
Unix timestamp (seconds since epoch) or None if invalid format
"""
try:
dt = datetime.strptime(date_str, "%Y-%m-%d")
return int(dt.timestamp())
except ValueError:
return None
def get_history_by_date_range(
self,
channel_id: str,
start_date: str,
end_date: str,
limit: int = 1000
) -> Tuple[List[Dict[str, Any]], Optional[str]]:
"""
Fetch conversation history within a date range.
Args:
channel_id: The ID of the channel to fetch history for
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format (inclusive)
limit: Maximum number of messages to return
Returns:
Tuple containing (messages list, error message or None)
"""
oldest = self.convert_date_to_timestamp(start_date)
if not oldest:
return [], f"Invalid start date format: {start_date}. Please use YYYY-MM-DD."
latest = self.convert_date_to_timestamp(end_date)
if not latest:
return [], f"Invalid end date format: {end_date}. Please use YYYY-MM-DD."
# Add one day to end date to make it inclusive
latest += 86400 # seconds in a day
try:
messages = self.get_conversation_history(
channel_id=channel_id,
limit=limit,
oldest=oldest,
latest=latest
)
return messages, None
except SlackApiError as e:
return [], f"Slack API error: {str(e)}"
except ValueError as e:
return [], str(e)
def get_user_info(self, user_id: str) -> Dict[str, Any]:
"""
Get information about a user.
Args:
user_id: The ID of the user to get info for
Returns:
User information dictionary
Raises:
ValueError: If no Slack client has been initialized
SlackApiError: If there's an error calling the Slack API
"""
if not self.client:
raise ValueError("Slack client not initialized. Call set_token() first.")
try:
result = self.client.users_info(user=user_id)
return result["user"]
except SlackApiError as e:
raise SlackApiError(f"Error retrieving user info for {user_id}: {e}", e.response)
def format_message(self, msg: Dict[str, Any], include_user_info: bool = False) -> Dict[str, Any]:
"""
Format a message for easier consumption.
Args:
msg: The message object from Slack API
include_user_info: Whether to fetch and include user info
Returns:
Formatted message dictionary
"""
formatted = {
"text": msg.get("text", ""),
"timestamp": msg.get("ts"),
"datetime": datetime.fromtimestamp(float(msg.get("ts", 0))).strftime('%Y-%m-%d %H:%M:%S'),
"user_id": msg.get("user", "UNKNOWN"),
"has_attachments": bool(msg.get("attachments")),
"has_files": bool(msg.get("files")),
"thread_ts": msg.get("thread_ts"),
"is_thread": "thread_ts" in msg,
}
if include_user_info and "user" in msg and self.client:
try:
user_info = self.get_user_info(msg["user"])
formatted["user_name"] = user_info.get("real_name", "Unknown")
formatted["user_email"] = user_info.get("profile", {}).get("email", "")
except Exception:
# If we can't get user info, just continue without it
formatted["user_name"] = "Unknown"
return formatted
# Example usage (uncomment to use):
"""
if __name__ == "__main__":
# Set your token here or via environment variable
token = os.environ.get("SLACK_API_TOKEN", "xoxb-your-token-here")
slack = SlackHistory(token)
# Get all channels
try:
channels = slack.get_all_channels()
print("Available channels:")
for name, channel_id in sorted(channels.items()):
print(f"- {name}: {channel_id}")
# Example: Get history for a specific channel and date range
channel_id = channels.get("general")
if channel_id:
messages, error = slack.get_history_by_date_range(
channel_id=channel_id,
start_date="2023-01-01",
end_date="2023-01-31",
limit=500
)
if error:
print(f"Error: {error}")
else:
print(f"\nRetrieved {len(messages)} messages from #general")
# Print formatted messages
for msg in messages[:10]: # Show first 10 messages
formatted = slack.format_message(msg, include_user_info=True)
print(f"[{formatted['datetime']}] {formatted['user_name']}: {formatted['text']}")
except Exception as e:
print(f"Error: {e}")
"""

181
surfsense_backend/app/db.py Normal file
View file

@ -0,0 +1,181 @@
from collections.abc import AsyncGenerator
from datetime import datetime, timezone
from enum import Enum
from fastapi import Depends
from fastapi_users.db import (
SQLAlchemyBaseOAuthAccountTableUUID,
SQLAlchemyBaseUserTableUUID,
SQLAlchemyUserDatabase,
)
from pgvector.sqlalchemy import Vector
from sqlalchemy import (
ARRAY,
Boolean,
Column,
Enum as SQLAlchemyEnum,
ForeignKey,
Integer,
JSON,
String,
Text,
text,
TIMESTAMP
)
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, declared_attr, relationship
from app.config import config
from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
DATABASE_URL = config.DATABASE_URL
class DocumentType(str, Enum):
EXTENSION = "EXTENSION"
CRAWLED_URL = "CRAWLED_URL"
FILE = "FILE"
SLACK_CONNECTOR = "SLACK_CONNECTOR"
NOTION_CONNECTOR = "NOTION_CONNECTOR"
class SearchSourceConnectorType(str, Enum):
SERPER_API = "SERPER_API"
TAVILY_API = "TAVILY_API"
SLACK_CONNECTOR = "SLACK_CONNECTOR"
NOTION_CONNECTOR = "NOTION_CONNECTOR"
class ChatType(str, Enum):
GENERAL = "GENERAL"
DEEP = "DEEP"
DEEPER = "DEEPER"
DEEPEST = "DEEPEST"
class Base(DeclarativeBase):
pass
class TimestampMixin:
@declared_attr
def created_at(cls):
return Column(TIMESTAMP(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc), index=True)
class BaseModel(Base):
__abstract__ = True
__allow_unmapped__ = True
id = Column(Integer, primary_key=True, index=True)
class Chat(BaseModel, TimestampMixin):
__tablename__ = "chats"
type = Column(SQLAlchemyEnum(ChatType), nullable=False)
title = Column(String(200), nullable=False, index=True)
initial_connectors = Column(ARRAY(String), nullable=True)
messages = Column(JSON, nullable=False)
search_space_id = Column(Integer, ForeignKey('searchspaces.id', ondelete='CASCADE'), nullable=False)
search_space = relationship('SearchSpace', back_populates='chats')
class Document(BaseModel, TimestampMixin):
__tablename__ = "documents"
title = Column(String(200), nullable=False, index=True)
document_type = Column(SQLAlchemyEnum(DocumentType), nullable=False)
document_metadata = Column(JSON, nullable=True)
content = Column(Text, nullable=False)
embedding = Column(Vector(config.embedding_model_instance.dimension))
search_space_id = Column(Integer, ForeignKey("searchspaces.id", ondelete='CASCADE'), nullable=False)
search_space = relationship("SearchSpace", back_populates="documents")
chunks = relationship("Chunk", back_populates="document", cascade="all, delete-orphan")
class Chunk(BaseModel, TimestampMixin):
__tablename__ = "chunks"
content = Column(Text, nullable=False)
embedding = Column(Vector(config.embedding_model_instance.dimension))
document_id = Column(Integer, ForeignKey("documents.id", ondelete='CASCADE'), nullable=False)
document = relationship("Document", back_populates="chunks")
class Podcast(BaseModel, TimestampMixin):
__tablename__ = "podcasts"
title = Column(String(200), nullable=False, index=True)
is_generated = Column(Boolean, nullable=False, default=False)
podcast_content = Column(Text, nullable=False, default="")
file_location = Column(String(500), nullable=False, default="")
search_space_id = Column(Integer, ForeignKey("searchspaces.id", ondelete='CASCADE'), nullable=False)
search_space = relationship("SearchSpace", back_populates="podcasts")
class SearchSpace(BaseModel, TimestampMixin):
__tablename__ = "searchspaces"
name = Column(String(100), nullable=False, index=True)
description = Column(String(500), nullable=True)
user_id = Column(UUID(as_uuid=True), ForeignKey("user.id", ondelete='CASCADE'), nullable=False)
user = relationship("User", back_populates="search_spaces")
documents = relationship("Document", back_populates="search_space", order_by="Document.id", cascade="all, delete-orphan")
podcasts = relationship("Podcast", back_populates="search_space", order_by="Podcast.id", cascade="all, delete-orphan")
chats = relationship('Chat', back_populates='search_space', order_by='Chat.id', cascade="all, delete-orphan")
class SearchSourceConnector(BaseModel, TimestampMixin):
__tablename__ = "search_source_connectors"
name = Column(String(100), nullable=False, index=True)
connector_type = Column(SQLAlchemyEnum(SearchSourceConnectorType), nullable=False, unique=True)
is_indexable = Column(Boolean, nullable=False, default=False)
last_indexed_at = Column(TIMESTAMP(timezone=True), nullable=True)
config = Column(JSON, nullable=False)
user_id = Column(UUID(as_uuid=True), ForeignKey("user.id", ondelete='CASCADE'), nullable=False)
user = relationship("User", back_populates="search_source_connectors")
class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
pass
class User(SQLAlchemyBaseUserTableUUID, Base):
oauth_accounts: Mapped[list[OAuthAccount]] = relationship(
"OAuthAccount", lazy="joined"
)
search_spaces = relationship("SearchSpace", back_populates="user")
search_source_connectors = relationship("SearchSourceConnector", back_populates="user")
engine = create_async_engine(DATABASE_URL)
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)
async def setup_indexes():
async with engine.begin() as conn:
# Create indexes
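# HNSW indexes speed up cosine-distance vector search; GIN indexes back the full-text tsvector queries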
# Document Summary Indexes
await conn.execute(text('CREATE INDEX IF NOT EXISTS document_vector_index ON documents USING hnsw (embedding public.vector_cosine_ops)'))
await conn.execute(text('CREATE INDEX IF NOT EXISTS document_search_index ON documents USING gin (to_tsvector(\'english\', content))'))
# Document Chunk Indexes
await conn.execute(text('CREATE INDEX IF NOT EXISTS chucks_vector_index ON chunks USING hnsw (embedding public.vector_cosine_ops)'))
await conn.execute(text('CREATE INDEX IF NOT EXISTS chucks_search_index ON chunks USING gin (to_tsvector(\'english\', content))'))
async def create_db_and_tables():
async with engine.begin() as conn:
await conn.execute(text('CREATE EXTENSION IF NOT EXISTS vector'))
await conn.run_sync(Base.metadata.create_all)
await setup_indexes()
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
async with async_session_maker() as session:
yield session
async def get_user_db(session: AsyncSession = Depends(get_async_session)):
yield SQLAlchemyUserDatabase(session, User, OAuthAccount)
async def get_chucks_hybrid_search_retriever(session: AsyncSession = Depends(get_async_session)):
return ChucksHybridSearchRetriever(session)

View file

@ -0,0 +1,103 @@
from langchain_core.prompts.prompt import PromptTemplate
from datetime import datetime, timezone
DATE_TODAY = "Today's date is " + datetime.now(timezone.utc).astimezone().isoformat() + '\n'
SUMMARY_PROMPT = DATE_TODAY + """
<INSTRUCTIONS>
<context>
You are an expert document analyst and summarization specialist tasked with distilling complex information into clear,
comprehensive summaries. Your role is to analyze documents thoroughly and create structured summaries that:
1. Capture the complete essence and key insights of the source material
2. Maintain perfect accuracy and factual precision
3. Present information objectively without bias or interpretation
4. Preserve critical context and logical relationships
5. Structure content in a clear, hierarchical format
</context>
<principles>
<accuracy>
- Maintain absolute factual accuracy and fidelity to source material
- Avoid any subjective interpretation, inference or speculation
- Preserve complete original meaning, nuance and contextual relationships
- Report all quantitative data with precise values and appropriate units
- Verify and cross-reference facts before inclusion
- Flag any ambiguous or unclear information
</accuracy>
<objectivity>
- Present information with strict neutrality and impartiality
- Exclude all forms of bias, personal opinions, and editorial commentary
- Ensure balanced representation of all perspectives and viewpoints
- Maintain objective professional distance from the content
- Use precise, factual language free from emotional coloring
- Focus solely on verifiable information and evidence
</objectivity>
<comprehensiveness>
- Capture all essential information, key themes, and central arguments
- Preserve critical context and background necessary for understanding
- Include relevant supporting details, examples, and evidence
- Maintain logical flow and connections between concepts
- Ensure hierarchical organization of information
- Document relationships between different components
- Highlight dependencies and causal links
- Track chronological progression where relevant
</comprehensiveness>
</principles>
<output_format>
<type>
- Return summary in clean markdown format
- Do not include markdown code block tags (```markdown ```)
- Use standard markdown syntax for formatting (headers, lists, etc.)
- Use # for main headings (e.g., # EXECUTIVE SUMMARY)
- Use ## for subheadings where appropriate
- Use bullet points (- item) for lists
- Ensure proper indentation and spacing
- Use appropriate emphasis (**bold**, *italic*) where needed
</type>
<style>
- Use clear, concise language focused on key points
- Maintain professional and objective tone throughout
- Follow consistent formatting and style conventions
- Provide descriptive section headings and subheadings
- Utilize bullet points and lists for better readability
- Structure content with clear hierarchy and organization
- Avoid jargon and overly technical language
- Include transition sentences between sections
</style>
</output_format>
<validation>
<criteria>
- Verify all facts and claims match source material exactly
- Cross-reference and validate all numerical data points
- Ensure logical flow and consistency throughout summary
- Confirm comprehensive coverage of key information
- Check for objective, unbiased language and tone
- Validate accurate representation of source context
- Review for proper attribution of ideas and quotes
- Verify temporal accuracy and chronological order
</criteria>
</validation>
<length_guidelines>
- Scale summary length proportionally to source document complexity and length
- Minimum: 3-5 well-developed paragraphs per major section
- Maximum: 8-10 paragraphs per section for highly complex documents
- Adjust level of detail based on information density and importance
- Ensure key concepts receive adequate coverage regardless of length
</length_guidelines>
Now, create a summary of the following document:
<document_to_summarize>
{document}
</document_to_summarize>
</INSTRUCTIONS>
"""
SUMMARY_PROMPT_TEMPLATE = PromptTemplate(
input_variables=["document"],
template=SUMMARY_PROMPT
)

View file

@ -0,0 +1,243 @@
class ChucksHybridSearchRetriever:
def __init__(self, db_session):
"""
Initialize the hybrid search retriever with a database session.
Args:
db_session: SQLAlchemy AsyncSession from FastAPI dependency injection
"""
self.db_session = db_session
async def vector_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None) -> list:
"""
Perform vector similarity search on chunks.
Args:
query_text: The search query text
top_k: Number of results to return
user_id: The ID of the user performing the search
search_space_id: Optional search space ID to filter results
Returns:
List of chunks sorted by vector similarity
"""
from sqlalchemy import select, func
from sqlalchemy.orm import joinedload
from app.db import Chunk, Document, SearchSpace
from app.config import config
# Get embedding for the query
embedding_model = config.embedding_model_instance
query_embedding = embedding_model.embed(query_text)
# Build the base query with user ownership check
query = (
select(Chunk)
.options(joinedload(Chunk.document).joinedload(Document.search_space))
.join(Document, Chunk.document_id == Document.id)
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
.where(SearchSpace.user_id == user_id)
)
# Add search space filter if provided
if search_space_id is not None:
query = query.where(Document.search_space_id == search_space_id)
# Add vector similarity ordering
query = (
query
.order_by(Chunk.embedding.op("<=>")(query_embedding))
.limit(top_k)
)
# Execute the query
result = await self.db_session.execute(query)
chunks = result.scalars().all()
return chunks
async def full_text_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None) -> list:
"""
Perform full-text keyword search on chunks.
Args:
query_text: The search query text
top_k: Number of results to return
user_id: The ID of the user performing the search
search_space_id: Optional search space ID to filter results
Returns:
List of chunks sorted by text relevance
"""
from sqlalchemy import select, func, text
from sqlalchemy.orm import joinedload
from app.db import Chunk, Document, SearchSpace
# Create tsvector and tsquery for PostgreSQL full-text search
tsvector = func.to_tsvector('english', Chunk.content)
tsquery = func.plainto_tsquery('english', query_text)
# Build the base query with user ownership check
query = (
select(Chunk)
.options(joinedload(Chunk.document).joinedload(Document.search_space))
.join(Document, Chunk.document_id == Document.id)
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
.where(SearchSpace.user_id == user_id)
.where(tsvector.op("@@")(tsquery)) # Only include results that match the query
)
# Add search space filter if provided
if search_space_id is not None:
query = query.where(Document.search_space_id == search_space_id)
# Add text search ranking
query = (
query
.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
.limit(top_k)
)
# Execute the query
result = await self.db_session.execute(query)
chunks = result.scalars().all()
return chunks
async def hybrid_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None, document_type: str = None) -> list:
"""
Combine vector similarity and full-text search results using Reciprocal Rank Fusion.
Args:
query_text: The search query text
top_k: Number of results to return
user_id: The ID of the user performing the search
search_space_id: Optional search space ID to filter results
document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL")
Returns:
List of dictionaries containing chunk data and relevance scores
"""
from sqlalchemy import select, func, text
from sqlalchemy.orm import joinedload
from app.db import Chunk, Document, SearchSpace, DocumentType
from app.config import config
# Get embedding for the query
embedding_model = config.embedding_model_instance
query_embedding = embedding_model.embed(query_text)
# Constants for RRF calculation
k = 60 # Constant for RRF calculation
n_results = top_k * 2 # Get more results for better fusion
# Create tsvector and tsquery for PostgreSQL full-text search
tsvector = func.to_tsvector('english', Chunk.content)
tsquery = func.plainto_tsquery('english', query_text)
# Base conditions for document filtering
base_conditions = [SearchSpace.user_id == user_id]
# Add search space filter if provided
if search_space_id is not None:
base_conditions.append(Document.search_space_id == search_space_id)
# Add document type filter if provided
if document_type is not None:
# Convert string to enum value if needed
if isinstance(document_type, str):
try:
doc_type_enum = DocumentType[document_type]
base_conditions.append(Document.document_type == doc_type_enum)
except KeyError:
# If the document type doesn't exist in the enum, return empty results
return []
else:
base_conditions.append(Document.document_type == document_type)
# CTE for semantic search with user ownership check
semantic_search_cte = (
select(
Chunk.id,
func.rank().over(order_by=Chunk.embedding.op("<=>")(query_embedding)).label("rank")
)
.join(Document, Chunk.document_id == Document.id)
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
.where(*base_conditions)
)
semantic_search_cte = (
semantic_search_cte
.order_by(Chunk.embedding.op("<=>")(query_embedding))
.limit(n_results)
.cte("semantic_search")
)
# CTE for keyword search with user ownership check
keyword_search_cte = (
select(
Chunk.id,
func.rank().over(order_by=func.ts_rank_cd(tsvector, tsquery).desc()).label("rank")
)
.join(Document, Chunk.document_id == Document.id)
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
.where(*base_conditions)
.where(tsvector.op("@@")(tsquery))
)
keyword_search_cte = (
keyword_search_cte
.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
.limit(n_results)
.cte("keyword_search")
)
# Final combined query using a FULL OUTER JOIN with RRF scoring
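# RRF score = sum over rankings of 1 / (k + rank); coalesce() treats a chunk
# missing from one ranking as contributing 0 from that ranking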
final_query = (
select(
Chunk,
(
func.coalesce(1.0 / (k + semantic_search_cte.c.rank), 0.0) +
func.coalesce(1.0 / (k + keyword_search_cte.c.rank), 0.0)
).label("score")
)
.select_from(
semantic_search_cte.outerjoin(
keyword_search_cte,
semantic_search_cte.c.id == keyword_search_cte.c.id,
full=True
)
)
.join(
Chunk,
Chunk.id == func.coalesce(semantic_search_cte.c.id, keyword_search_cte.c.id)
)
.options(joinedload(Chunk.document))
.order_by(text("score DESC"))
.limit(top_k)
)
# Execute the query
result = await self.db_session.execute(final_query)
chunks_with_scores = result.all()
# If no results were found, return an empty list
if not chunks_with_scores:
return []
# Convert the (chunk, score) rows into serializable dictionaries
serialized_results = []
for chunk, score in chunks_with_scores:
serialized_results.append({
"chunk_id": chunk.id,
"content": chunk.content,
"score": float(score), # Ensure score is a Python float
"document": {
"id": chunk.document.id,
"title": chunk.document.title,
"document_type": chunk.document.document_type.value if hasattr(chunk.document, 'document_type') else None,
"metadata": chunk.document.document_metadata
}
})
return serialized_results

View file

@ -0,0 +1,14 @@
from fastapi import APIRouter
from .search_spaces_routes import router as search_spaces_router
from .documents_routes import router as documents_router
from .podcasts_routes import router as podcasts_router
from .chats_routes import router as chats_router
from .search_source_connectors_routes import router as search_source_connectors_router
router = APIRouter()
router.include_router(search_spaces_router)
router.include_router(documents_router)
router.include_router(podcasts_router)
router.include_router(chats_router)
router.include_router(search_source_connectors_router)

View file

@ -0,0 +1,260 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.exc import IntegrityError, OperationalError
from typing import List
from app.db import get_async_session, User, SearchSpace, Chat
from app.schemas import ChatCreate, ChatUpdate, ChatRead, AISDKChatRequest
from app.tasks.stream_connector_search_results import stream_connector_search_results
from app.users import current_active_user
from app.utils.check_ownership import check_ownership
router = APIRouter()
@router.post("/chat")
async def handle_chat_data(
request: AISDKChatRequest,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
):
messages = request.messages
if messages[-1].role != "user":
raise HTTPException(
status_code=400, detail="Last message must be a user message")
user_query = messages[-1].content
search_space_id = request.data.get('search_space_id')
research_mode: str = request.data.get('research_mode')
selected_connectors: List[str] = request.data.get('selected_connectors')
# Convert search_space_id to integer if it's a string
if search_space_id and isinstance(search_space_id, str):
try:
search_space_id = int(search_space_id)
except ValueError:
raise HTTPException(
status_code=400, detail="Invalid search_space_id format")
# Check if the search space belongs to the current user
try:
await check_ownership(session, SearchSpace, search_space_id, user)
except HTTPException:
raise HTTPException(
status_code=403, detail="You don't have access to this search space")
response = StreamingResponse(stream_connector_search_results(
user_query,
user.id,
search_space_id,
session,
research_mode,
selected_connectors
))
response.headers['x-vercel-ai-data-stream'] = 'v1'
return response
@router.post("/chats/", response_model=ChatRead)
async def create_chat(
chat: ChatCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
):
try:
await check_ownership(session, SearchSpace, chat.search_space_id, user)
db_chat = Chat(**chat.model_dump())
session.add(db_chat)
await session.commit()
await session.refresh(db_chat)
return db_chat
except HTTPException:
raise
except IntegrityError as e:
await session.rollback()
raise HTTPException(
status_code=400, detail="Database constraint violation. Please check your input data.")
except OperationalError as e:
await session.rollback()
raise HTTPException(
status_code=503, detail="Database operation failed. Please try again later.")
except Exception as e:
await session.rollback()
raise HTTPException(
status_code=500, detail="An unexpected error occurred while creating the chat.")
@router.get("/chats/", response_model=List[ChatRead])
async def read_chats(
skip: int = 0,
limit: int = 100,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
):
try:
result = await session.execute(
select(Chat)
.join(SearchSpace)
.filter(SearchSpace.user_id == user.id)
.offset(skip)
.limit(limit)
)
return result.scalars().all()
except OperationalError:
raise HTTPException(
status_code=503, detail="Database operation failed. Please try again later.")
except Exception:
raise HTTPException(
status_code=500, detail="An unexpected error occurred while fetching chats.")
@router.get("/chats/{chat_id}", response_model=ChatRead)
async def read_chat(
chat_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
):
try:
result = await session.execute(
select(Chat)
.join(SearchSpace)
.filter(Chat.id == chat_id, SearchSpace.user_id == user.id)
)
chat = result.scalars().first()
if not chat:
raise HTTPException(
status_code=404, detail="Chat not found or you don't have permission to access it")
return chat
except OperationalError:
raise HTTPException(
status_code=503, detail="Database operation failed. Please try again later.")
except Exception:
raise HTTPException(
status_code=500, detail="An unexpected error occurred while fetching the chat.")
@router.put("/chats/{chat_id}", response_model=ChatRead)
async def update_chat(
chat_id: int,
chat_update: ChatUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
):
try:
db_chat = await read_chat(chat_id, session, user)
update_data = chat_update.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(db_chat, key, value)
await session.commit()
await session.refresh(db_chat)
return db_chat
except HTTPException:
raise
except IntegrityError:
await session.rollback()
raise HTTPException(
status_code=400, detail="Database constraint violation. Please check your input data.")
except OperationalError:
await session.rollback()
raise HTTPException(
status_code=503, detail="Database operation failed. Please try again later.")
except Exception:
await session.rollback()
raise HTTPException(
status_code=500, detail="An unexpected error occurred while updating the chat.")
@router.delete("/chats/{chat_id}", response_model=dict)
async def delete_chat(
chat_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
):
try:
db_chat = await read_chat(chat_id, session, user)
await session.delete(db_chat)
await session.commit()
return {"message": "Chat deleted successfully"}
except HTTPException:
raise
except IntegrityError:
await session.rollback()
raise HTTPException(
status_code=400, detail="Cannot delete chat due to existing dependencies.")
except OperationalError:
await session.rollback()
raise HTTPException(
status_code=503, detail="Database operation failed. Please try again later.")
except Exception:
await session.rollback()
raise HTTPException(
status_code=500, detail="An unexpected error occurred while deleting the chat.")
# test_data = [
# {
# "type": "TERMINAL_INFO",
# "content": [
# {
# "id": 1,
# "text": "Starting to search for crawled URLs...",
# "type": "info"
# },
# {
# "id": 2,
# "text": "Found 2 relevant crawled URLs",
# "type": "success"
# }
# ]
# },
# {
# "type": "SOURCES",
# "content": [
# {
# "id": 1,
# "name": "Crawled URLs",
# "type": "CRAWLED_URL",
# "sources": [
# {
# "id": 1,
# "title": "Webpage Title",
# "description": "Webpage Dec",
# "url": "https://jsoneditoronline.org/"
# },
# {
# "id": 2,
# "title": "Webpage Title",
# "description": "Webpage Dec",
# "url": "https://www.google.com/"
# }
# ]
# },
# {
# "id": 2,
# "name": "Files",
# "type": "FILE",
# "sources": [
# {
# "id": 3,
# "title": "Webpage Title",
# "description": "Webpage Dec",
# "url": "https://jsoneditoronline.org/"
# },
# {
# "id": 4,
# "title": "Webpage Title",
# "description": "Webpage Dec",
# "url": "https://www.google.com/"
# }
# ]
# }
# ]
# },
# {
# "type": "ANSWER",
# "content": [
# "## SurfSense Introduction",
# "Surfsense is A Personal NotebookLM and Perplexity-like AI Assistant for Everyone. Research and Never forget Anything. [1] [3]"
# ]
# }
# ]
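# Hedged usage sketch (illustrative only, not wired into the app): how a client
# might call the streaming POST /chat endpoint above. The host, bearer token,
# search space id, research mode label, and the `requests` dependency are
# assumptions, not values defined in this module.
def _example_stream_chat_request():  # pragma: no cover - documentation sketch
    import requests

    payload = {
        "messages": [{"role": "user", "content": "Summarize my saved pages"}],
        "data": {
            "search_space_id": "1",                  # assumed existing search space
            "research_mode": "GENERAL",              # assumed mode label
            "selected_connectors": ["CRAWLED_URL"],  # assumed connector selection
        },
    }
    with requests.post(
        "http://localhost:8000/chat",                 # assumed host and prefix
        json=payload,
        headers={"Authorization": "Bearer <token>"},  # assumed fastapi-users JWT
        stream=True,
    ) as response:
        # The response is a Vercel AI SDK data stream (x-vercel-ai-data-stream: v1).
        for line in response.iter_lines():
            if line:
                print(line.decode())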

View file

@ -0,0 +1,262 @@
from fastapi import APIRouter, Depends, BackgroundTasks, UploadFile, Form, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from typing import List
from app.db import get_async_session, User, SearchSpace, Document, DocumentType
from app.schemas import DocumentsCreate, DocumentUpdate, DocumentRead
from app.users import current_active_user
from app.utils.check_ownership import check_ownership
from app.tasks.background_tasks import add_extension_received_document, add_received_file_document, add_crawled_url_document
from langchain_unstructured import UnstructuredLoader
from app.config import config
import json
router = APIRouter()
@router.post("/documents/")
async def create_documents(
request: DocumentsCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
fastapi_background_tasks: BackgroundTasks = BackgroundTasks()
):
try:
# Check if the user owns the search space
await check_ownership(session, SearchSpace, request.search_space_id, user)
if request.document_type == DocumentType.EXTENSION:
for individual_document in request.content:
fastapi_background_tasks.add_task(
add_extension_received_document,
session,
individual_document,
request.search_space_id
)
elif request.document_type == DocumentType.CRAWLED_URL:
for url in request.content:
fastapi_background_tasks.add_task(
add_crawled_url_document,
session,
url,
request.search_space_id
)
else:
raise HTTPException(
status_code=400,
detail="Invalid document type"
)
await session.commit()
return {"message": "Documents processed successfully"}
except HTTPException:
raise
except Exception as e:
await session.rollback()
raise HTTPException(
status_code=500,
detail=f"Failed to process documents: {str(e)}"
)
@router.post("/documents/fileupload")
async def create_documents_file_upload(
files: list[UploadFile],
search_space_id: int = Form(...),
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
fastapi_background_tasks: BackgroundTasks = BackgroundTasks()
):
try:
await check_ownership(session, SearchSpace, search_space_id, user)
if not files:
raise HTTPException(status_code=400, detail="No files provided")
for file in files:
try:
unstructured_loader = UnstructuredLoader(
file=file.file,
api_key=config.UNSTRUCTURED_API_KEY,
partition_via_api=True,
languages=["eng"],
include_orig_elements=False,
strategy="fast",
)
unstructured_processed_elements = await unstructured_loader.aload()
fastapi_background_tasks.add_task(
add_received_file_document,
session,
file.filename,
unstructured_processed_elements,
search_space_id
)
except Exception as e:
raise HTTPException(
status_code=422,
detail=f"Failed to process file {file.filename}: {str(e)}"
)
await session.commit()
return {"message": "Files added for processing successfully"}
except HTTPException:
raise
except Exception as e:
await session.rollback()
raise HTTPException(
status_code=500,
detail=f"Failed to process documents: {str(e)}"
)
@router.get("/documents/", response_model=List[DocumentRead])
async def read_documents(
skip: int = 0,
limit: int = 300,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
):
try:
result = await session.execute(
select(Document)
.join(SearchSpace)
.filter(SearchSpace.user_id == user.id)
.offset(skip)
.limit(limit)
)
db_documents = result.scalars().all()
# Convert database objects to API-friendly format
api_documents = []
for doc in db_documents:
api_documents.append(DocumentRead(
id=doc.id,
title=doc.title,
document_type=doc.document_type,
document_metadata=doc.document_metadata,
content=doc.content,
created_at=doc.created_at,
search_space_id=doc.search_space_id
))
return api_documents
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to fetch documents: {str(e)}"
)
@router.get("/documents/{document_id}", response_model=DocumentRead)
async def read_document(
document_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
):
try:
result = await session.execute(
select(Document)
.join(SearchSpace)
.filter(Document.id == document_id, SearchSpace.user_id == user.id)
)
document = result.scalars().first()
if not document:
raise HTTPException(
status_code=404,
detail=f"Document with id {document_id} not found"
)
# Convert database object to API-friendly format
return DocumentRead(
id=document.id,
title=document.title,
document_type=document.document_type,
document_metadata=document.document_metadata,
content=document.content,
created_at=document.created_at,
search_space_id=document.search_space_id
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to fetch document: {str(e)}"
)
@router.put("/documents/{document_id}", response_model=DocumentRead)
async def update_document(
document_id: int,
document_update: DocumentUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
):
try:
# Query the document directly instead of using read_document function
result = await session.execute(
select(Document)
.join(SearchSpace)
.filter(Document.id == document_id, SearchSpace.user_id == user.id)
)
db_document = result.scalars().first()
if not db_document:
raise HTTPException(
status_code=404,
detail=f"Document with id {document_id} not found"
)
update_data = document_update.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(db_document, key, value)
await session.commit()
await session.refresh(db_document)
# Convert to DocumentRead for response
return DocumentRead(
id=db_document.id,
title=db_document.title,
document_type=db_document.document_type,
document_metadata=db_document.document_metadata,
content=db_document.content,
created_at=db_document.created_at,
search_space_id=db_document.search_space_id
)
except HTTPException:
raise
except Exception as e:
await session.rollback()
raise HTTPException(
status_code=500,
detail=f"Failed to update document: {str(e)}"
)
@router.delete("/documents/{document_id}", response_model=dict)
async def delete_document(
document_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
):
try:
# Query the document directly instead of using read_document function
result = await session.execute(
select(Document)
.join(SearchSpace)
.filter(Document.id == document_id, SearchSpace.user_id == user.id)
)
document = result.scalars().first()
if not document:
raise HTTPException(
status_code=404,
detail=f"Document with id {document_id} not found"
)
await session.delete(document)
await session.commit()
return {"message": "Document deleted successfully"}
except HTTPException:
raise
except Exception as e:
await session.rollback()
raise HTTPException(
status_code=500,
detail=f"Failed to delete document: {str(e)}"
)
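# Hedged usage sketch (illustrative only): how a client might submit a URL for
# crawling and upload a file to the endpoints above. The host, bearer token,
# search space id, the serialized DocumentType value, and the `requests`
# dependency are assumptions.
def _example_document_client_calls():  # pragma: no cover - documentation sketch
    import requests

    base = "http://localhost:8000"                    # assumed host and prefix
    headers = {"Authorization": "Bearer <token>"}     # assumed fastapi-users JWT

    # Queue a URL for crawling and summarization via POST /documents/
    requests.post(
        f"{base}/documents/",
        json={
            "document_type": "CRAWLED_URL",           # assumed enum serialization
            "content": ["https://example.com/article"],
            "search_space_id": 1,
        },
        headers=headers,
    )

    # Upload a file for Unstructured processing via POST /documents/fileupload
    with open("notes.pdf", "rb") as file_handle:
        requests.post(
            f"{base}/documents/fileupload",
            files=[("files", ("notes.pdf", file_handle, "application/pdf"))],
            data={"search_space_id": "1"},
            headers=headers,
        )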

View file

@ -0,0 +1,122 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from typing import List
from app.db import get_async_session, User, SearchSpace, Podcast
from app.schemas import PodcastCreate, PodcastUpdate, PodcastRead
from app.users import current_active_user
from app.utils.check_ownership import check_ownership
router = APIRouter()
@router.post("/podcasts/", response_model=PodcastRead)
async def create_podcast(
podcast: PodcastCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
):
try:
await check_ownership(session, SearchSpace, podcast.search_space_id, user)
db_podcast = Podcast(**podcast.model_dump())
session.add(db_podcast)
await session.commit()
await session.refresh(db_podcast)
return db_podcast
except HTTPException as he:
raise he
except IntegrityError as e:
await session.rollback()
raise HTTPException(status_code=400, detail="Podcast creation failed due to constraint violation")
except SQLAlchemyError as e:
await session.rollback()
raise HTTPException(status_code=500, detail="Database error occurred while creating podcast")
except Exception as e:
await session.rollback()
raise HTTPException(status_code=500, detail="An unexpected error occurred")
@router.get("/podcasts/", response_model=List[PodcastRead])
async def read_podcasts(
skip: int = 0,
limit: int = 100,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
):
if skip < 0 or limit < 1:
raise HTTPException(status_code=400, detail="Invalid pagination parameters")
try:
result = await session.execute(
select(Podcast)
.join(SearchSpace)
.filter(SearchSpace.user_id == user.id)
.offset(skip)
.limit(limit)
)
return result.scalars().all()
except SQLAlchemyError:
raise HTTPException(status_code=500, detail="Database error occurred while fetching podcasts")
@router.get("/podcasts/{podcast_id}", response_model=PodcastRead)
async def read_podcast(
podcast_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
):
try:
result = await session.execute(
select(Podcast)
.join(SearchSpace)
.filter(Podcast.id == podcast_id, SearchSpace.user_id == user.id)
)
podcast = result.scalars().first()
if not podcast:
raise HTTPException(
status_code=404,
detail="Podcast not found or you don't have permission to access it"
)
return podcast
except HTTPException as he:
raise he
except SQLAlchemyError:
raise HTTPException(status_code=500, detail="Database error occurred while fetching podcast")
@router.put("/podcasts/{podcast_id}", response_model=PodcastRead)
async def update_podcast(
podcast_id: int,
podcast_update: PodcastUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
):
try:
db_podcast = await read_podcast(podcast_id, session, user)
update_data = podcast_update.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(db_podcast, key, value)
await session.commit()
await session.refresh(db_podcast)
return db_podcast
except HTTPException as he:
raise he
except IntegrityError:
await session.rollback()
raise HTTPException(status_code=400, detail="Update failed due to constraint violation")
except SQLAlchemyError:
await session.rollback()
raise HTTPException(status_code=500, detail="Database error occurred while updating podcast")
@router.delete("/podcasts/{podcast_id}", response_model=dict)
async def delete_podcast(
podcast_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
):
try:
db_podcast = await read_podcast(podcast_id, session, user)
await session.delete(db_podcast)
await session.commit()
return {"message": "Podcast deleted successfully"}
except HTTPException as he:
raise he
except SQLAlchemyError:
await session.rollback()
raise HTTPException(status_code=500, detail="Database error occurred while deleting podcast")

View file

@ -0,0 +1,418 @@
"""
SearchSourceConnector routes for CRUD operations:
POST /search-source-connectors/ - Create a new connector
GET /search-source-connectors/ - List all connectors for the current user
GET /search-source-connectors/{connector_id} - Get a specific connector
PUT /search-source-connectors/{connector_id} - Update a specific connector
DELETE /search-source-connectors/{connector_id} - Delete a specific connector
POST /search-source-connectors/{connector_id}/index - Index content from a connector to a search space
Note: Each user can have only one connector of each type (SERPER_API, TAVILY_API, SLACK_CONNECTOR, NOTION_CONNECTOR).
"""
from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.exc import IntegrityError
from typing import List, Dict, Any
from app.db import get_async_session, User, SearchSourceConnector, SearchSourceConnectorType, SearchSpace
from app.schemas import SearchSourceConnectorCreate, SearchSourceConnectorUpdate, SearchSourceConnectorRead
from app.users import current_active_user
from app.utils.check_ownership import check_ownership
from pydantic import ValidationError
from app.tasks.connectors_indexing_tasks import index_slack_messages, index_notion_pages
from datetime import datetime, timedelta
import logging
# Set up logging
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/search-source-connectors/", response_model=SearchSourceConnectorRead)
async def create_search_source_connector(
connector: SearchSourceConnectorCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
):
"""
Create a new search source connector.
Each user can have only one connector of each type (SERPER_API, TAVILY_API, SLACK_CONNECTOR, NOTION_CONNECTOR).
The config must contain the appropriate keys for the connector type.
"""
try:
# Check if a connector with the same type already exists for this user
result = await session.execute(
select(SearchSourceConnector)
.filter(
SearchSourceConnector.user_id == user.id,
SearchSourceConnector.connector_type == connector.connector_type
)
)
existing_connector = result.scalars().first()
if existing_connector:
raise HTTPException(
status_code=409,
detail=f"A connector with type {connector.connector_type} already exists. Each user can have only one connector of each type."
)
db_connector = SearchSourceConnector(**connector.model_dump(), user_id=user.id)
session.add(db_connector)
await session.commit()
await session.refresh(db_connector)
return db_connector
except ValidationError as e:
await session.rollback()
raise HTTPException(
status_code=422,
detail=f"Validation error: {str(e)}"
)
except IntegrityError as e:
await session.rollback()
raise HTTPException(
status_code=409,
detail=f"Integrity error: A connector with this type already exists. {str(e)}"
)
except HTTPException:
await session.rollback()
raise
except Exception as e:
await session.rollback()
raise HTTPException(
status_code=500,
detail=f"Failed to create search source connector: {str(e)}"
)
@router.get("/search-source-connectors/", response_model=List[SearchSourceConnectorRead])
async def read_search_source_connectors(
skip: int = 0,
limit: int = 100,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
):
"""List all search source connectors for the current user."""
try:
result = await session.execute(
select(SearchSourceConnector)
.filter(SearchSourceConnector.user_id == user.id)
.offset(skip)
.limit(limit)
)
return result.scalars().all()
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to fetch search source connectors: {str(e)}"
)
@router.get("/search-source-connectors/{connector_id}", response_model=SearchSourceConnectorRead)
async def read_search_source_connector(
connector_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
):
"""Get a specific search source connector by ID."""
try:
return await check_ownership(session, SearchSourceConnector, connector_id, user)
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to fetch search source connector: {str(e)}"
)
@router.put("/search-source-connectors/{connector_id}", response_model=SearchSourceConnectorRead)
async def update_search_source_connector(
connector_id: int,
connector_update: SearchSourceConnectorUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
):
"""
Update a search source connector.
Each user can have only one connector of each type (SERPER_API, TAVILY_API, SLACK_CONNECTOR, NOTION_CONNECTOR).
The config must contain the appropriate keys for the connector type.
"""
try:
db_connector = await check_ownership(session, SearchSourceConnector, connector_id, user)
# If connector type is being changed, check if one of that type already exists
if connector_update.connector_type != db_connector.connector_type:
result = await session.execute(
select(SearchSourceConnector)
.filter(
SearchSourceConnector.user_id == user.id,
SearchSourceConnector.connector_type == connector_update.connector_type,
SearchSourceConnector.id != connector_id
)
)
existing_connector = result.scalars().first()
if existing_connector:
raise HTTPException(
status_code=409,
detail=f"A connector with type {connector_update.connector_type} already exists. Each user can have only one connector of each type."
)
update_data = connector_update.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(db_connector, key, value)
await session.commit()
await session.refresh(db_connector)
return db_connector
except ValidationError as e:
await session.rollback()
raise HTTPException(
status_code=422,
detail=f"Validation error: {str(e)}"
)
except IntegrityError as e:
await session.rollback()
raise HTTPException(
status_code=409,
detail=f"Integrity error: A connector with this type already exists. {str(e)}"
)
except HTTPException:
await session.rollback()
raise
except Exception as e:
await session.rollback()
raise HTTPException(
status_code=500,
detail=f"Failed to update search source connector: {str(e)}"
)
@router.delete("/search-source-connectors/{connector_id}", response_model=dict)
async def delete_search_source_connector(
connector_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
):
"""Delete a search source connector."""
try:
db_connector = await check_ownership(session, SearchSourceConnector, connector_id, user)
await session.delete(db_connector)
await session.commit()
return {"message": "Search source connector deleted successfully"}
except HTTPException:
raise
except Exception as e:
await session.rollback()
raise HTTPException(
status_code=500,
detail=f"Failed to delete search source connector: {str(e)}"
)
@router.post("/search-source-connectors/{connector_id}/index", response_model=Dict[str, Any])
async def index_connector_content(
connector_id: int,
search_space_id: int = Query(..., description="ID of the search space to store indexed content"),
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
background_tasks: BackgroundTasks = None
):
"""
Index content from a connector to a search space.
Currently supports:
- SLACK_CONNECTOR: Indexes messages from all accessible Slack channels since the last indexing
(or the last 365 days if never indexed before)
- NOTION_CONNECTOR: Indexes pages from all accessible Notion pages since the last indexing
(or the last 365 days if never indexed before)
Args:
connector_id: ID of the connector to use
search_space_id: ID of the search space to store indexed content
background_tasks: FastAPI background tasks
Returns:
Dictionary with indexing status
"""
try:
# Check if the connector belongs to the user
connector = await check_ownership(session, SearchSourceConnector, connector_id, user)
# Check if the search space belongs to the user
search_space = await check_ownership(session, SearchSpace, search_space_id, user)
# Handle different connector types
if connector.connector_type == SearchSourceConnectorType.SLACK_CONNECTOR:
# Determine the time range that will be indexed
if not connector.last_indexed_at:
start_date = "365 days ago"
else:
# Check if last_indexed_at is today
today = datetime.now().date()
if connector.last_indexed_at.date() == today:
# If last indexed today, go back 1 day to ensure we don't miss anything
start_date = (today - timedelta(days=1)).strftime("%Y-%m-%d")
else:
start_date = connector.last_indexed_at.strftime("%Y-%m-%d")
# Add the indexing task to background tasks
if background_tasks:
background_tasks.add_task(
run_slack_indexing,
session,
connector_id,
search_space_id
)
return {
"success": True,
"message": "Slack indexing started in the background",
"connector_type": connector.connector_type,
"search_space": search_space.name,
"indexing_from": start_date,
"indexing_to": datetime.now().strftime("%Y-%m-%d")
}
else:
# For testing or if background tasks are not available
return {
"success": False,
"message": "Background tasks not available",
"connector_type": connector.connector_type
}
elif connector.connector_type == SearchSourceConnectorType.NOTION_CONNECTOR:
# Determine the time range that will be indexed
if not connector.last_indexed_at:
start_date = "365 days ago"
else:
# Check if last_indexed_at is today
today = datetime.now().date()
if connector.last_indexed_at.date() == today:
# If last indexed today, go back 1 day to ensure we don't miss anything
start_date = (today - timedelta(days=1)).strftime("%Y-%m-%d")
else:
start_date = connector.last_indexed_at.strftime("%Y-%m-%d")
# Add the indexing task to background tasks
if background_tasks:
background_tasks.add_task(
run_notion_indexing,
session,
connector_id,
search_space_id
)
return {
"success": True,
"message": "Notion indexing started in the background",
"connector_type": connector.connector_type,
"search_space": search_space.name,
"indexing_from": start_date,
"indexing_to": datetime.now().strftime("%Y-%m-%d")
}
else:
# For testing or if background tasks are not available
return {
"success": False,
"message": "Background tasks not available",
"connector_type": connector.connector_type
}
else:
raise HTTPException(
status_code=400,
detail=f"Indexing not supported for connector type: {connector.connector_type}"
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to start indexing: {str(e)}")
raise HTTPException(
status_code=500,
detail=f"Failed to start indexing: {str(e)}"
)
async def update_connector_last_indexed(
session: AsyncSession,
connector_id: int
):
"""
Update the last_indexed_at timestamp for a connector.
Args:
session: Database session
connector_id: ID of the connector to update
"""
try:
result = await session.execute(
select(SearchSourceConnector)
.filter(SearchSourceConnector.id == connector_id)
)
connector = result.scalars().first()
if connector:
connector.last_indexed_at = datetime.now()
await session.commit()
logger.info(f"Updated last_indexed_at for connector {connector_id}")
except Exception as e:
logger.error(f"Failed to update last_indexed_at for connector {connector_id}: {str(e)}")
await session.rollback()
async def run_slack_indexing(
session: AsyncSession,
connector_id: int,
search_space_id: int
):
"""
Background task to run Slack indexing.
Args:
session: Database session
connector_id: ID of the Slack connector
search_space_id: ID of the search space
"""
try:
# Index Slack messages without updating last_indexed_at (we'll do it separately)
documents_indexed, error_or_warning = await index_slack_messages(
session=session,
connector_id=connector_id,
search_space_id=search_space_id,
update_last_indexed=False # Don't update timestamp in the indexing function
)
# Only update last_indexed_at if indexing was successful
if documents_indexed > 0 and (error_or_warning is None or "Indexed" in error_or_warning):
await update_connector_last_indexed(session, connector_id)
logger.info(f"Slack indexing completed successfully: {documents_indexed} documents indexed")
else:
logger.error(f"Slack indexing failed or no documents indexed: {error_or_warning}")
except Exception as e:
logger.error(f"Error in background Slack indexing task: {str(e)}")
async def run_notion_indexing(
session: AsyncSession,
connector_id: int,
search_space_id: int
):
"""
Background task to run Notion indexing.
Args:
session: Database session
connector_id: ID of the Notion connector
search_space_id: ID of the search space
"""
try:
# Index Notion pages without updating last_indexed_at (we'll do it separately)
documents_indexed, error_or_warning = await index_notion_pages(
session=session,
connector_id=connector_id,
search_space_id=search_space_id,
update_last_indexed=False # Don't update timestamp in the indexing function
)
# Only update last_indexed_at if indexing was successful
if documents_indexed > 0 and (error_or_warning is None or "Indexed" in error_or_warning):
await update_connector_last_indexed(session, connector_id)
logger.info(f"Notion indexing completed successfully: {documents_indexed} documents indexed")
else:
logger.error(f"Notion indexing failed or no documents indexed: {error_or_warning}")
except Exception as e:
logger.error(f"Error in background Notion indexing task: {str(e)}")

View file

@ -0,0 +1,115 @@
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from typing import List
from app.db import get_async_session, User, SearchSpace
from app.schemas import SearchSpaceCreate, SearchSpaceUpdate, SearchSpaceRead
from app.users import current_active_user
from app.utils.check_ownership import check_ownership
from fastapi import HTTPException
router = APIRouter()
@router.post("/searchspaces/", response_model=SearchSpaceRead)
async def create_search_space(
search_space: SearchSpaceCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
):
try:
db_search_space = SearchSpace(**search_space.model_dump(), user_id=user.id)
session.add(db_search_space)
await session.commit()
await session.refresh(db_search_space)
return db_search_space
except HTTPException:
raise
except Exception as e:
await session.rollback()
raise HTTPException(
status_code=500,
detail=f"Failed to create search space: {str(e)}"
)
@router.get("/searchspaces/", response_model=List[SearchSpaceRead])
async def read_search_spaces(
skip: int = 0,
limit: int = 200,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
):
try:
result = await session.execute(
select(SearchSpace)
.filter(SearchSpace.user_id == user.id)
.offset(skip)
.limit(limit)
)
return result.scalars().all()
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to fetch search spaces: {str(e)}"
)
@router.get("/searchspaces/{search_space_id}", response_model=SearchSpaceRead)
async def read_search_space(
search_space_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
):
try:
search_space = await check_ownership(session, SearchSpace, search_space_id, user)
return search_space
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to fetch search space: {str(e)}"
)
@router.put("/searchspaces/{search_space_id}", response_model=SearchSpaceRead)
async def update_search_space(
search_space_id: int,
search_space_update: SearchSpaceUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
):
try:
db_search_space = await check_ownership(session, SearchSpace, search_space_id, user)
update_data = search_space_update.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(db_search_space, key, value)
await session.commit()
await session.refresh(db_search_space)
return db_search_space
except HTTPException:
raise
except Exception as e:
await session.rollback()
raise HTTPException(
status_code=500,
detail=f"Failed to update search space: {str(e)}"
)
@router.delete("/searchspaces/{search_space_id}", response_model=dict)
async def delete_search_space(
search_space_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
):
try:
db_search_space = await check_ownership(session, SearchSpace, search_space_id, user)
await session.delete(db_search_space)
await session.commit()
return {"message": "Search space deleted successfully"}
except HTTPException:
raise
except Exception as e:
await session.rollback()
raise HTTPException(
status_code=500,
detail=f"Failed to delete search space: {str(e)}"
)
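# Hedged usage sketch (illustrative only): creating and listing search spaces
# through the routes above. The host, bearer token, and the `requests`
# dependency are assumptions.
def _example_search_space_client_calls():  # pragma: no cover - documentation sketch
    import requests

    base = "http://localhost:8000"                    # assumed host and prefix
    headers = {"Authorization": "Bearer <token>"}     # assumed fastapi-users JWT

    requests.post(
        f"{base}/searchspaces/",
        json={"name": "Research", "description": "Papers and saved pages"},
        headers=headers,
    )
    return requests.get(f"{base}/searchspaces/", headers=headers).json()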

View file

@ -0,0 +1,50 @@
from .base import TimestampModel, IDModel
from .users import UserRead, UserCreate, UserUpdate
from .search_space import SearchSpaceBase, SearchSpaceCreate, SearchSpaceUpdate, SearchSpaceRead
from .documents import (
ExtensionDocumentMetadata,
ExtensionDocumentContent,
DocumentBase,
DocumentsCreate,
DocumentUpdate,
DocumentRead,
)
from .chunks import ChunkBase, ChunkCreate, ChunkUpdate, ChunkRead
from .podcasts import PodcastBase, PodcastCreate, PodcastUpdate, PodcastRead
from .chats import ChatBase, ChatCreate, ChatUpdate, ChatRead, AISDKChatRequest
from .search_source_connector import SearchSourceConnectorBase, SearchSourceConnectorCreate, SearchSourceConnectorUpdate, SearchSourceConnectorRead
__all__ = [
"AISDKChatRequest",
"TimestampModel",
"IDModel",
"UserRead",
"UserCreate",
"UserUpdate",
"SearchSpaceBase",
"SearchSpaceCreate",
"SearchSpaceUpdate",
"SearchSpaceRead",
"ExtensionDocumentMetadata",
"ExtensionDocumentContent",
"DocumentBase",
"DocumentsCreate",
"DocumentUpdate",
"DocumentRead",
"ChunkBase",
"ChunkCreate",
"ChunkUpdate",
"ChunkRead",
"PodcastBase",
"PodcastCreate",
"PodcastUpdate",
"PodcastRead",
"ChatBase",
"ChatCreate",
"ChatUpdate",
"ChatRead",
"SearchSourceConnectorBase",
"SearchSourceConnectorCreate",
"SearchSourceConnectorUpdate",
"SearchSourceConnectorRead",
]

View file

@ -0,0 +1,8 @@
from datetime import datetime
from pydantic import BaseModel
class TimestampModel(BaseModel):
created_at: datetime
class IDModel(BaseModel):
id: int

View file

@ -0,0 +1,46 @@
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
from sqlalchemy import JSON
from .base import IDModel, TimestampModel
from app.db import ChatType
class ChatBase(BaseModel):
type: ChatType
title: str
initial_connectors: Optional[List[str]] = None
messages: List[Any]
search_space_id: int
class ClientAttachment(BaseModel):
name: str
contentType: str
url: str
class ToolInvocation(BaseModel):
toolCallId: str
toolName: str
args: dict
result: dict
class ClientMessage(BaseModel):
role: str
content: str
experimental_attachments: Optional[List[ClientAttachment]] = None
toolInvocations: Optional[List[ToolInvocation]] = None
class AISDKChatRequest(BaseModel):
messages: List[ClientMessage]
data: Optional[Dict[str, Any]] = None
class ChatCreate(ChatBase):
pass
class ChatUpdate(ChatBase):
pass
class ChatRead(ChatBase, IDModel, TimestampModel):
class Config:
from_attributes = True

View file

@ -0,0 +1,16 @@
from pydantic import BaseModel
from .base import IDModel, TimestampModel
class ChunkBase(BaseModel):
content: str
document_id: int
class ChunkCreate(ChunkBase):
pass
class ChunkUpdate(ChunkBase):
pass
class ChunkRead(ChunkBase, IDModel, TimestampModel):
class Config:
from_attributes = True

View file

@ -0,0 +1,42 @@
from typing import List, Any
from pydantic import BaseModel
from sqlalchemy import JSON
from .base import IDModel, TimestampModel
from app.db import DocumentType
from datetime import datetime
class ExtensionDocumentMetadata(BaseModel):
BrowsingSessionId: str
VisitedWebPageURL: str
VisitedWebPageTitle: str
VisitedWebPageDateWithTimeInISOString: str
VisitedWebPageReffererURL: str
VisitedWebPageVisitDurationInMilliseconds: str
class ExtensionDocumentContent(BaseModel):
metadata: ExtensionDocumentMetadata
pageContent: str
class DocumentBase(BaseModel):
document_type: DocumentType
content: List[ExtensionDocumentContent] | List[str] | str # Updated to allow string content
search_space_id: int
class DocumentsCreate(DocumentBase):
pass
class DocumentUpdate(DocumentBase):
pass
class DocumentRead(BaseModel):
id: int
title: str
document_type: DocumentType
document_metadata: dict
content: str # Changed to string to match frontend
created_at: datetime
search_space_id: int
class Config:
from_attributes = True
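# Hedged construction sketch (illustrative only): a DocumentsCreate payload as
# the browser extension might produce it, validated with the models above. The
# concrete values are assumptions.
def _example_extension_documents_create() -> DocumentsCreate:  # pragma: no cover
    return DocumentsCreate(
        document_type=DocumentType.EXTENSION,
        search_space_id=1,
        content=[
            ExtensionDocumentContent(
                metadata=ExtensionDocumentMetadata(
                    BrowsingSessionId="session-123",
                    VisitedWebPageURL="https://example.com/article",
                    VisitedWebPageTitle="Example Article",
                    VisitedWebPageDateWithTimeInISOString="2025-03-14T18:53:14Z",
                    VisitedWebPageReffererURL="https://example.com/",
                    VisitedWebPageVisitDurationInMilliseconds="42000",
                ),
                pageContent="# Example Article\n\nPage text captured by the extension.",
            )
        ],
    )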

View file

@ -0,0 +1,19 @@
from pydantic import BaseModel
from .base import IDModel, TimestampModel
class PodcastBase(BaseModel):
title: str
is_generated: bool = False
podcast_content: str = ""
file_location: str = ""
search_space_id: int
class PodcastCreate(PodcastBase):
pass
class PodcastUpdate(PodcastBase):
pass
class PodcastRead(PodcastBase, IDModel, TimestampModel):
class Config:
from_attributes = True

View file

@ -0,0 +1,73 @@
from datetime import datetime
import uuid
from typing import Dict, Any
from pydantic import BaseModel, field_validator, ValidationInfo
from .base import IDModel, TimestampModel
from app.db import SearchSourceConnectorType
from fastapi import HTTPException
class SearchSourceConnectorBase(BaseModel):
name: str
connector_type: SearchSourceConnectorType
is_indexable: bool
last_indexed_at: datetime | None
config: Dict[str, Any]
@field_validator('config')
@classmethod
def validate_config_for_connector_type(cls, config: Dict[str, Any], info: ValidationInfo) -> Dict[str, Any]:
connector_type = info.data.get('connector_type')
if connector_type == SearchSourceConnectorType.SERPER_API:
# For SERPER_API, only allow SERPER_API_KEY
allowed_keys = ["SERPER_API_KEY"]
if set(config.keys()) != set(allowed_keys):
raise ValueError(f"For SERPER_API connector type, config must only contain these keys: {allowed_keys}")
# Ensure the API key is not empty
if not config.get("SERPER_API_KEY"):
raise ValueError("SERPER_API_KEY cannot be empty")
elif connector_type == SearchSourceConnectorType.TAVILY_API:
# For TAVILY_API, only allow TAVILY_API_KEY
allowed_keys = ["TAVILY_API_KEY"]
if set(config.keys()) != set(allowed_keys):
raise ValueError(f"For TAVILY_API connector type, config must only contain these keys: {allowed_keys}")
# Ensure the API key is not empty
if not config.get("TAVILY_API_KEY"):
raise ValueError("TAVILY_API_KEY cannot be empty")
elif connector_type == SearchSourceConnectorType.SLACK_CONNECTOR:
# For SLACK_CONNECTOR, only allow SLACK_BOT_TOKEN
allowed_keys = ["SLACK_BOT_TOKEN"]
if set(config.keys()) != set(allowed_keys):
raise ValueError(f"For SLACK_CONNECTOR connector type, config must only contain these keys: {allowed_keys}")
# Ensure the bot token is not empty
if not config.get("SLACK_BOT_TOKEN"):
raise ValueError("SLACK_BOT_TOKEN cannot be empty")
elif connector_type == SearchSourceConnectorType.NOTION_CONNECTOR:
# For NOTION_CONNECTOR, only allow NOTION_INTEGRATION_TOKEN
allowed_keys = ["NOTION_INTEGRATION_TOKEN"]
if set(config.keys()) != set(allowed_keys):
raise ValueError(f"For NOTION_CONNECTOR connector type, config must only contain these keys: {allowed_keys}")
# Ensure the integration token is not empty
if not config.get("NOTION_INTEGRATION_TOKEN"):
raise ValueError("NOTION_INTEGRATION_TOKEN cannot be empty")
return config
class SearchSourceConnectorCreate(SearchSourceConnectorBase):
pass
class SearchSourceConnectorUpdate(SearchSourceConnectorBase):
pass
class SearchSourceConnectorRead(SearchSourceConnectorBase, IDModel, TimestampModel):
user_id: uuid.UUID
class Config:
from_attributes = True
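# Hedged construction sketch (illustrative only): connector payloads that the
# config validator above accepts. The token strings are placeholders, not real
# credentials.
def _example_connector_payloads():  # pragma: no cover - documentation sketch
    slack = SearchSourceConnectorCreate(
        name="Team Slack",
        connector_type=SearchSourceConnectorType.SLACK_CONNECTOR,
        is_indexable=True,
        last_indexed_at=None,
        config={"SLACK_BOT_TOKEN": "xoxb-placeholder"},
    )
    tavily = SearchSourceConnectorCreate(
        name="Tavily Search",
        connector_type=SearchSourceConnectorType.TAVILY_API,
        is_indexable=False,
        last_indexed_at=None,
        config={"TAVILY_API_KEY": "tvly-placeholder"},
    )
    # A config with missing or extra keys (for example an empty dict) would make
    # validate_config_for_connector_type raise a ValueError during validation.
    return slack, tavily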

View file

@ -0,0 +1,23 @@
from datetime import datetime
import uuid
from typing import Optional
from pydantic import BaseModel
from .base import IDModel, TimestampModel
class SearchSpaceBase(BaseModel):
name: str
description: Optional[str] = None
class SearchSpaceCreate(SearchSpaceBase):
pass
class SearchSpaceUpdate(SearchSpaceBase):
pass
class SearchSpaceRead(SearchSpaceBase, IDModel, TimestampModel):
id: int
created_at: datetime
user_id: uuid.UUID
class Config:
from_attributes = True

View file

@ -0,0 +1,11 @@
import uuid
from fastapi_users import schemas
class UserRead(schemas.BaseUser[uuid.UUID]):
pass
class UserCreate(schemas.BaseUserCreate):
pass
class UserUpdate(schemas.BaseUserUpdate):
pass

View file

View file

@ -0,0 +1,246 @@
from typing import Optional, List
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import SQLAlchemyError
from app.db import Document, DocumentType, Chunk
from app.schemas import ExtensionDocumentContent
from app.config import config
from app.prompts import SUMMARY_PROMPT_TEMPLATE
from datetime import datetime
from app.utils.document_converters import convert_document_to_markdown
from langchain_core.documents import Document as LangChainDocument
from langchain_community.document_loaders import FireCrawlLoader, AsyncChromiumLoader
from langchain_community.document_transformers import MarkdownifyTransformer
import validators
md = MarkdownifyTransformer()
async def add_crawled_url_document(
session: AsyncSession,
url: str,
search_space_id: int
) -> Optional[Document]:
try:
if not validators.url(url):
raise ValueError(f"Url {url} is not a valid URL address")
if config.FIRECRAWL_API_KEY:
crawl_loader = FireCrawlLoader(
url=url,
api_key=config.FIRECRAWL_API_KEY,
mode="scrape",
params={
"formats": ["markdown"],
"excludeTags": ["a"],
}
)
else:
crawl_loader = AsyncChromiumLoader(urls=[url], headless=True)
url_crawled = await crawl_loader.aload()
if type(crawl_loader) == FireCrawlLoader:
content_in_markdown = url_crawled[0].page_content
elif type(crawl_loader) == AsyncChromiumLoader:
content_in_markdown = md.transform_documents(url_crawled)[
0].page_content
# Format document metadata in a more maintainable way
metadata_sections = [
("METADATA", [
f"{key.upper()}: {value}" for key, value in url_crawled[0].metadata.items()
]),
("CONTENT", [
"FORMAT: markdown",
"TEXT_START",
content_in_markdown,
"TEXT_END"
])
]
# Build the document string more efficiently
document_parts = []
document_parts.append("<DOCUMENT>")
for section_title, section_content in metadata_sections:
document_parts.append(f"<{section_title}>")
document_parts.extend(section_content)
document_parts.append(f"</{section_title}>")
document_parts.append("</DOCUMENT>")
combined_document_string = '\n'.join(document_parts)
# Generate summary
summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
summary_result = await summary_chain.ainvoke({"document": combined_document_string})
summary_content = summary_result.content
summary_embedding = config.embedding_model_instance.embed(
summary_content)
# Process chunks
chunks = [
Chunk(content=chunk.text, embedding=chunk.embedding)
for chunk in config.chunker_instance.chunk(content_in_markdown)
]
# Create and store document
document = Document(
search_space_id=search_space_id,
title=url_crawled[0].metadata['title'] if type(
crawl_loader) == FireCrawlLoader else url_crawled[0].metadata['source'],
document_type=DocumentType.CRAWLED_URL,
document_metadata=url_crawled[0].metadata,
content=summary_content,
embedding=summary_embedding,
chunks=chunks
)
session.add(document)
await session.commit()
await session.refresh(document)
return document
except SQLAlchemyError as db_error:
await session.rollback()
raise db_error
except Exception as e:
await session.rollback()
raise RuntimeError(f"Failed to crawl URL: {str(e)}")
async def add_extension_received_document(
session: AsyncSession,
content: ExtensionDocumentContent,
search_space_id: int
) -> Optional[Document]:
"""
Process and store document content received from the SurfSense Extension.
Args:
session: Database session
content: Document content from extension
search_space_id: ID of the search space
Returns:
Document object if successful, None if failed
"""
try:
# Format document metadata in a more maintainable way
metadata_sections = [
("METADATA", [
f"SESSION_ID: {content.metadata.BrowsingSessionId}",
f"URL: {content.metadata.VisitedWebPageURL}",
f"TITLE: {content.metadata.VisitedWebPageTitle}",
f"REFERRER: {content.metadata.VisitedWebPageReffererURL}",
f"TIMESTAMP: {content.metadata.VisitedWebPageDateWithTimeInISOString}",
f"DURATION_MS: {content.metadata.VisitedWebPageVisitDurationInMilliseconds}"
]),
("CONTENT", [
"FORMAT: markdown",
"TEXT_START",
content.pageContent,
"TEXT_END"
])
]
# Build the document string more efficiently
document_parts = []
document_parts.append("<DOCUMENT>")
for section_title, section_content in metadata_sections:
document_parts.append(f"<{section_title}>")
document_parts.extend(section_content)
document_parts.append(f"</{section_title}>")
document_parts.append("</DOCUMENT>")
combined_document_string = '\n'.join(document_parts)
# Generate summary
summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
summary_result = await summary_chain.ainvoke({"document": combined_document_string})
summary_content = summary_result.content
summary_embedding = config.embedding_model_instance.embed(
summary_content)
# Process chunks
chunks = [
Chunk(content=chunk.text, embedding=chunk.embedding)
for chunk in config.chunker_instance.chunk(content.pageContent)
]
# Create and store document
document = Document(
search_space_id=search_space_id,
title=content.metadata.VisitedWebPageTitle,
document_type=DocumentType.EXTENSION,
document_metadata=content.metadata.model_dump(),
content=summary_content,
embedding=summary_embedding,
chunks=chunks
)
session.add(document)
await session.commit()
await session.refresh(document)
return document
except SQLAlchemyError as db_error:
await session.rollback()
raise db_error
except Exception as e:
await session.rollback()
raise RuntimeError(f"Failed to process extension document: {str(e)}")
async def add_received_file_document(
session: AsyncSession,
file_name: str,
unstructured_processed_elements: List[LangChainDocument],
search_space_id: int
) -> Optional[Document]:
try:
file_in_markdown = await convert_document_to_markdown(unstructured_processed_elements)
# TODO: Check if file_markdown exceeds token limit of embedding model
# Generate summary
summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
summary_result = await summary_chain.ainvoke({"document": file_in_markdown})
summary_content = summary_result.content
summary_embedding = config.embedding_model_instance.embed(
summary_content)
# Process chunks
chunks = [
Chunk(content=chunk.text, embedding=chunk.embedding)
for chunk in config.chunker_instance.chunk(file_in_markdown)
]
# Create and store document
document = Document(
search_space_id=search_space_id,
title=file_name,
document_type=DocumentType.FILE,
document_metadata={
"FILE_NAME": file_name,
"SAVED_AT": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
},
content=summary_content,
embedding=summary_embedding,
chunks=chunks
)
session.add(document)
await session.commit()
await session.refresh(document)
return document
except SQLAlchemyError as db_error:
await session.rollback()
raise db_error
except Exception as e:
await session.rollback()
raise RuntimeError(f"Failed to process file document: {str(e)}")

View file

@ -0,0 +1,486 @@
from typing import Optional, List, Dict, Any, Tuple
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.future import select
from datetime import datetime, timedelta
from app.db import Document, DocumentType, Chunk, SearchSourceConnector, SearchSourceConnectorType
from app.config import config
from app.prompts import SUMMARY_PROMPT_TEMPLATE
from app.connectors.slack_history import SlackHistory
from app.connectors.notion_history import NotionHistoryConnector
from slack_sdk.errors import SlackApiError
import logging
# Set up logging
logger = logging.getLogger(__name__)
async def index_slack_messages(
session: AsyncSession,
connector_id: int,
search_space_id: int,
update_last_indexed: bool = True
) -> Tuple[int, Optional[str]]:
"""
Index Slack messages from all accessible channels.
Args:
session: Database session
connector_id: ID of the Slack connector
search_space_id: ID of the search space to store documents in
update_last_indexed: Whether to update the last_indexed_at timestamp (default: True)
Returns:
Tuple containing (number of documents indexed, error message or None)
"""
try:
# Get the connector
result = await session.execute(
select(SearchSourceConnector)
.filter(
SearchSourceConnector.id == connector_id,
SearchSourceConnector.connector_type == SearchSourceConnectorType.SLACK_CONNECTOR
)
)
connector = result.scalars().first()
if not connector:
return 0, f"Connector with ID {connector_id} not found or is not a Slack connector"
# Get the Slack token from the connector config
slack_token = connector.config.get("SLACK_BOT_TOKEN")
if not slack_token:
return 0, "Slack token not found in connector config"
# Initialize Slack client
slack_client = SlackHistory(token=slack_token)
# Calculate date range
end_date = datetime.now()
# Use last_indexed_at as start date if available, otherwise use 365 days ago
if connector.last_indexed_at:
# Check if last_indexed_at is today
today = datetime.now().date()
if connector.last_indexed_at.date() == today:
# If last indexed today, go back 1 day to ensure we don't miss anything
start_date = end_date - timedelta(days=1)
else:
start_date = connector.last_indexed_at
else:
start_date = end_date - timedelta(days=365)
# Format dates for Slack API
start_date_str = start_date.strftime("%Y-%m-%d")
end_date_str = end_date.strftime("%Y-%m-%d")
# Get all channels
try:
channels = slack_client.get_all_channels()
except Exception as e:
return 0, f"Failed to get Slack channels: {str(e)}"
if not channels:
return 0, "No Slack channels found"
# Track the number of documents indexed
documents_indexed = 0
skipped_channels = []
# Process each channel
for channel_name, channel_id in channels.items():
try:
# Check if the bot is a member of the channel
try:
# First try to get channel info to check if bot is a member
channel_info = slack_client.client.conversations_info(channel=channel_id)
# For private channels, the bot needs to be a member
if channel_info.get("channel", {}).get("is_private", False):
# Check if bot is a member
is_member = channel_info.get("channel", {}).get("is_member", False)
if not is_member:
logger.warning(f"Bot is not a member of private channel {channel_name} ({channel_id}). Skipping.")
skipped_channels.append(f"{channel_name} (private, bot not a member)")
continue
except SlackApiError as e:
if "not_in_channel" in str(e) or "channel_not_found" in str(e):
logger.warning(f"Bot cannot access channel {channel_name} ({channel_id}). Skipping.")
skipped_channels.append(f"{channel_name} (access error)")
continue
else:
# Re-raise if it's a different error
raise
# Get messages for this channel
messages, error = slack_client.get_history_by_date_range(
channel_id=channel_id,
start_date=start_date_str,
end_date=end_date_str,
limit=1000 # Limit to 1000 messages per channel
)
if error:
logger.warning(f"Error getting messages from channel {channel_name}: {error}")
skipped_channels.append(f"{channel_name} (error: {error})")
continue # Skip this channel if there's an error
if not messages:
logger.info(f"No messages found in channel {channel_name} for the specified date range.")
continue # Skip if no messages
# Format messages with user info
formatted_messages = []
for msg in messages:
# Skip bot messages and system messages
if msg.get("subtype") in ["bot_message", "channel_join", "channel_leave"]:
continue
formatted_msg = slack_client.format_message(msg, include_user_info=True)
formatted_messages.append(formatted_msg)
if not formatted_messages:
logger.info(f"No valid messages found in channel {channel_name} after filtering.")
continue # Skip if no valid messages after filtering
# Convert messages to markdown format
channel_content = f"# Slack Channel: {channel_name}\n\n"
for msg in formatted_messages:
user_name = msg.get("user_name", "Unknown User")
timestamp = msg.get("datetime", "Unknown Time")
text = msg.get("text", "")
channel_content += f"## {user_name} ({timestamp})\n\n{text}\n\n---\n\n"
# Format document metadata
metadata_sections = [
("METADATA", [
f"CHANNEL_NAME: {channel_name}",
f"CHANNEL_ID: {channel_id}",
f"START_DATE: {start_date_str}",
f"END_DATE: {end_date_str}",
f"MESSAGE_COUNT: {len(formatted_messages)}",
f"INDEXED_AT: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
]),
("CONTENT", [
"FORMAT: markdown",
"TEXT_START",
channel_content,
"TEXT_END"
])
]
# Build the document string
document_parts = []
document_parts.append("<DOCUMENT>")
for section_title, section_content in metadata_sections:
document_parts.append(f"<{section_title}>")
document_parts.extend(section_content)
document_parts.append(f"</{section_title}>")
document_parts.append("</DOCUMENT>")
combined_document_string = '\n'.join(document_parts)
# Generate summary
summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
summary_result = await summary_chain.ainvoke({"document": combined_document_string})
summary_content = summary_result.content
summary_embedding = config.embedding_model_instance.embed(summary_content)
# Process chunks
chunks = [
Chunk(content=chunk.text, embedding=chunk.embedding)
for chunk in config.chunker_instance.chunk(channel_content)
]
# Create and store document
document = Document(
search_space_id=search_space_id,
title=f"Slack - {channel_name}",
document_type=DocumentType.SLACK_CONNECTOR,
document_metadata={
"channel_name": channel_name,
"channel_id": channel_id,
"start_date": start_date_str,
"end_date": end_date_str,
"message_count": len(formatted_messages),
"indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
},
content=summary_content,
embedding=summary_embedding,
chunks=chunks
)
session.add(document)
documents_indexed += 1
logger.info(f"Successfully indexed channel {channel_name} with {len(formatted_messages)} messages")
except SlackApiError as slack_error:
logger.error(f"Slack API error for channel {channel_name}: {str(slack_error)}")
skipped_channels.append(f"{channel_name} (Slack API error)")
continue # Skip this channel and continue with others
except Exception as e:
logger.error(f"Error processing channel {channel_name}: {str(e)}")
skipped_channels.append(f"{channel_name} (processing error)")
continue # Skip this channel and continue with others
# Update the last_indexed_at timestamp for the connector only if requested
# and if we successfully indexed at least one channel
if update_last_indexed and documents_indexed > 0:
connector.last_indexed_at = datetime.now()
# Commit all changes
await session.commit()
# Prepare result message
result_message = None
if skipped_channels:
result_message = f"Indexed {documents_indexed} channels. Skipped {len(skipped_channels)} channels: {', '.join(skipped_channels)}"
return documents_indexed, result_message
except SQLAlchemyError as db_error:
await session.rollback()
logger.error(f"Database error: {str(db_error)}")
return 0, f"Database error: {str(db_error)}"
except Exception as e:
await session.rollback()
logger.error(f"Failed to index Slack messages: {str(e)}")
return 0, f"Failed to index Slack messages: {str(e)}"
async def index_notion_pages(
session: AsyncSession,
connector_id: int,
search_space_id: int,
update_last_indexed: bool = True
) -> Tuple[int, Optional[str]]:
"""
Index Notion pages from all accessible pages.
Args:
session: Database session
connector_id: ID of the Notion connector
search_space_id: ID of the search space to store documents in
update_last_indexed: Whether to update the last_indexed_at timestamp (default: True)
Returns:
Tuple containing (number of documents indexed, error message or None)
"""
try:
# Get the connector
result = await session.execute(
select(SearchSourceConnector)
.filter(
SearchSourceConnector.id == connector_id,
SearchSourceConnector.connector_type == SearchSourceConnectorType.NOTION_CONNECTOR
)
)
connector = result.scalars().first()
if not connector:
return 0, f"Connector with ID {connector_id} not found or is not a Notion connector"
# Get the Notion token from the connector config
notion_token = connector.config.get("NOTION_INTEGRATION_TOKEN")
if not notion_token:
return 0, "Notion integration token not found in connector config"
# Initialize Notion client
logger.info(f"Initializing Notion client for connector {connector_id}")
notion_client = NotionHistoryConnector(token=notion_token)
# Calculate date range
end_date = datetime.now()
# Use last_indexed_at as start date if available, otherwise use 365 days ago
if connector.last_indexed_at:
# Check if last_indexed_at is today
today = datetime.now().date()
if connector.last_indexed_at.date() == today:
# If last indexed today, go back 1 day to ensure we don't miss anything
start_date = end_date - timedelta(days=1)
else:
start_date = connector.last_indexed_at
else:
start_date = end_date - timedelta(days=365)
# Format dates for Notion API (ISO format)
start_date_str = start_date.strftime("%Y-%m-%dT%H:%M:%SZ")
end_date_str = end_date.strftime("%Y-%m-%dT%H:%M:%SZ")
logger.info(f"Fetching Notion pages from {start_date_str} to {end_date_str}")
# Get all pages
try:
pages = notion_client.get_all_pages(start_date=start_date_str, end_date=end_date_str)
logger.info(f"Found {len(pages)} Notion pages")
except Exception as e:
logger.error(f"Error fetching Notion pages: {str(e)}", exc_info=True)
return 0, f"Failed to get Notion pages: {str(e)}"
if not pages:
logger.info("No Notion pages found to index")
return 0, "No Notion pages found"
# Track the number of documents indexed
documents_indexed = 0
skipped_pages = []
# Process each page
for page in pages:
try:
page_id = page.get("page_id")
page_title = page.get("title", f"Untitled page ({page_id})")
page_content = page.get("content", [])
logger.info(f"Processing Notion page: {page_title} ({page_id})")
if not page_content:
logger.info(f"No content found in page {page_title}. Skipping.")
skipped_pages.append(f"{page_title} (no content)")
continue
# Convert page content to markdown format
markdown_content = f"# Notion Page: {page_title}\n\n"
# Process blocks recursively
def process_blocks(blocks, level=0):
result = ""
for block in blocks:
block_type = block.get("type")
block_content = block.get("content", "")
children = block.get("children", [])
# Add indentation based on level
indent = " " * level
# Format based on block type
if block_type in ["paragraph", "text"]:
result += f"{indent}{block_content}\n\n"
elif block_type in ["heading_1", "header"]:
result += f"{indent}# {block_content}\n\n"
elif block_type == "heading_2":
result += f"{indent}## {block_content}\n\n"
elif block_type == "heading_3":
result += f"{indent}### {block_content}\n\n"
elif block_type == "bulleted_list_item":
result += f"{indent}* {block_content}\n"
elif block_type == "numbered_list_item":
result += f"{indent}1. {block_content}\n"
elif block_type == "to_do":
result += f"{indent}- [ ] {block_content}\n"
elif block_type == "toggle":
result += f"{indent}> {block_content}\n"
elif block_type == "code":
result += f"{indent}```\n{block_content}\n```\n\n"
elif block_type == "quote":
result += f"{indent}> {block_content}\n\n"
elif block_type == "callout":
result += f"{indent}> **Note:** {block_content}\n\n"
elif block_type == "image":
result += f"{indent}![Image]({block_content})\n\n"
else:
# Default for other block types
if block_content:
result += f"{indent}{block_content}\n\n"
# Process children recursively
if children:
result += process_blocks(children, level + 1)
return result
logger.debug(f"Converting {len(page_content)} blocks to markdown for page {page_title}")
markdown_content += process_blocks(page_content)
# Format document metadata
metadata_sections = [
("METADATA", [
f"PAGE_TITLE: {page_title}",
f"PAGE_ID: {page_id}",
f"INDEXED_AT: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
]),
("CONTENT", [
"FORMAT: markdown",
"TEXT_START",
markdown_content,
"TEXT_END"
])
]
# Build the document string
document_parts = []
document_parts.append("<DOCUMENT>")
for section_title, section_content in metadata_sections:
document_parts.append(f"<{section_title}>")
document_parts.extend(section_content)
document_parts.append(f"</{section_title}>")
document_parts.append("</DOCUMENT>")
combined_document_string = '\n'.join(document_parts)
# Generate summary
logger.debug(f"Generating summary for page {page_title}")
summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
summary_result = await summary_chain.ainvoke({"document": combined_document_string})
summary_content = summary_result.content
summary_embedding = config.embedding_model_instance.embed(summary_content)
# Process chunks
logger.debug(f"Chunking content for page {page_title}")
chunks = [
Chunk(content=chunk.text, embedding=chunk.embedding)
for chunk in config.chunker_instance.chunk(markdown_content)
]
# Create and store document
document = Document(
search_space_id=search_space_id,
title=f"Notion - {page_title}",
document_type=DocumentType.NOTION_CONNECTOR,
document_metadata={
"page_title": page_title,
"page_id": page_id,
"indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
},
content=summary_content,
embedding=summary_embedding,
chunks=chunks
)
session.add(document)
documents_indexed += 1
logger.info(f"Successfully indexed Notion page: {page_title}")
except Exception as e:
logger.error(f"Error processing Notion page {page.get('title', 'Unknown')}: {str(e)}", exc_info=True)
skipped_pages.append(f"{page.get('title', 'Unknown')} (processing error)")
continue # Skip this page and continue with others
# Update the last_indexed_at timestamp for the connector only if requested
# and if we successfully indexed at least one page
if update_last_indexed and documents_indexed > 0:
connector.last_indexed_at = datetime.now()
logger.info(f"Updated last_indexed_at for connector {connector_id}")
# Commit all changes
await session.commit()
# Prepare result message
result_message = None
if skipped_pages:
result_message = f"Indexed {documents_indexed} pages. Skipped {len(skipped_pages)} pages: {', '.join(skipped_pages)}"
logger.info(f"Notion indexing completed: {documents_indexed} pages indexed, {len(skipped_pages)} pages skipped")
return documents_indexed, result_message
except SQLAlchemyError as db_error:
await session.rollback()
logger.error(f"Database error during Notion indexing: {str(db_error)}", exc_info=True)
return 0, f"Database error: {str(db_error)}"
except Exception as e:
await session.rollback()
logger.error(f"Failed to index Notion pages: {str(e)}", exc_info=True)
return 0, f"Failed to index Notion pages: {str(e)}"

@@ -0,0 +1,340 @@
import json
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, AsyncGenerator, Dict, Any
import asyncio
import re
from app.utils.connector_service import ConnectorService
from app.utils.research_service import ResearchService
from app.utils.streaming_service import StreamingService
from app.utils.reranker_service import RerankerService
from app.config import config
from app.utils.document_converters import convert_chunks_to_langchain_documents
async def stream_connector_search_results(
user_query: str,
user_id: int,
search_space_id: int,
session: AsyncSession,
research_mode: str,
selected_connectors: List[str]
) -> AsyncGenerator[str, None]:
"""
Stream connector search results to the client
Args:
user_query: The user's query
user_id: The user's ID
search_space_id: The search space ID
session: The database session
research_mode: The research mode
selected_connectors: List of selected connectors
Yields:
str: Formatted response strings
"""
# Initialize services
connector_service = ConnectorService(session)
streaming_service = StreamingService()
reranker_service = RerankerService.get_reranker_instance(config)
all_raw_documents = [] # Store all raw documents before reranking
all_sources = []
TOP_K = 20
if research_mode == "GENERAL":
TOP_K = 20
elif research_mode == "DEEP":
TOP_K = 40
elif research_mode == "DEEPER":
TOP_K = 60
# Process each selected connector
for connector in selected_connectors:
# Crawled URLs
if connector == "CRAWLED_URL":
# Send terminal message about starting search
yield streaming_service.add_terminal_message("Starting to search for crawled URLs...")
# Search for crawled URLs
result_object, crawled_urls_chunks = await connector_service.search_crawled_urls(
user_query=user_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=TOP_K
)
# Send terminal message about search results
yield streaming_service.add_terminal_message(
f"Found {len(result_object['sources'])} relevant crawled URLs",
"success"
)
# Update sources
all_sources.append(result_object)
yield streaming_service.update_sources(all_sources)
# Add documents to collection
all_raw_documents.extend(crawled_urls_chunks)
# Files
if connector == "FILE":
# Send terminal message about starting search
yield streaming_service.add_terminal_message("Starting to search for files...")
# Search for files
result_object, files_chunks = await connector_service.search_files(
user_query=user_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=TOP_K
)
# Send terminal message about search results
yield streaming_service.add_terminal_message(
f"Found {len(result_object['sources'])} relevant files",
"success"
)
# Update sources
all_sources.append(result_object)
yield streaming_service.update_sources(all_sources)
# Add documents to collection
all_raw_documents.extend(files_chunks)
# Tavily Connector
if connector == "TAVILY_API":
# Send terminal message about starting search
yield streaming_service.add_terminal_message("Starting to search with Tavily API...")
# Search using Tavily API
result_object, tavily_chunks = await connector_service.search_tavily(
user_query=user_query,
user_id=user_id,
top_k=TOP_K
)
# Send terminal message about search results
yield streaming_service.add_terminal_message(
f"Found {len(result_object['sources'])} relevant results from Tavily",
"success"
)
# Update sources
all_sources.append(result_object)
yield streaming_service.update_sources(all_sources)
# Add documents to collection
all_raw_documents.extend(tavily_chunks)
# Slack Connector
if connector == "SLACK_CONNECTOR":
# Send terminal message about starting search
yield streaming_service.add_terminal_message("Starting to search for slack connector...")
# Search using Slack API
result_object, slack_chunks = await connector_service.search_slack(
user_query=user_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=TOP_K
)
# Send terminal message about search results
yield streaming_service.add_terminal_message(
f"Found {len(result_object['sources'])} relevant results from Slack",
"success"
)
# Update sources
all_sources.append(result_object)
yield streaming_service.update_sources(all_sources)
# Add documents to collection
all_raw_documents.extend(slack_chunks)
# Notion Connector
if connector == "NOTION_CONNECTOR":
# Send terminal message about starting search
yield streaming_service.add_terminal_message("Starting to search for notion connector...")
# Search using Notion API
result_object, notion_chunks = await connector_service.search_notion(
user_query=user_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=TOP_K
)
# Send terminal message about search results
yield streaming_service.add_terminal_message(
f"Found {len(result_object['sources'])} relevant results from Notion",
"success"
)
# Update sources
all_sources.append(result_object)
yield streaming_service.update_sources(all_sources)
# Add documents to collection
all_raw_documents.extend(notion_chunks)
# If we have documents to research
if all_raw_documents:
# Rerank all documents if reranker is available
if reranker_service:
yield streaming_service.add_terminal_message("Reranking documents for better relevance...", "info")
# Convert documents to format expected by reranker
reranker_input_docs = [
{
"chunk_id": doc.get("chunk_id", f"chunk_{i}"),
"content": doc.get("content", ""),
"score": doc.get("score", 0.0),
"document": {
"id": doc.get("document", {}).get("id", ""),
"title": doc.get("document", {}).get("title", ""),
"document_type": doc.get("document", {}).get("document_type", ""),
"metadata": doc.get("document", {}).get("metadata", {})
}
} for i, doc in enumerate(all_raw_documents)
]
# Rerank documents
reranked_docs = reranker_service.rerank_documents(user_query, reranker_input_docs)
# Sort by score in descending order
reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True)
# Convert back to langchain documents format
from langchain.schema import Document as LangchainDocument
all_langchain_documents_to_research = [
LangchainDocument(
page_content= f"""<document><metadata><source_id>{doc.get("document", {}).get("id", "")}</source_id></metadata><content>{doc.get("content", "")}</content></document>""",
metadata={
# **doc.get("document", {}).get("metadata", {}),
# "score": doc.get("score", 0.0),
# "rank": doc.get("rank", 0),
# "document_id": doc.get("document", {}).get("id", ""),
# "document_title": doc.get("document", {}).get("title", ""),
# "document_type": doc.get("document", {}).get("document_type", ""),
# # Explicitly set source_id for citation purposes
"source_id": str(doc.get("document", {}).get("id", ""))
}
) for doc in reranked_docs
]
yield streaming_service.add_terminal_message(f"Reranked {len(all_langchain_documents_to_research)} documents", "success")
else:
# Use raw documents if no reranker is available
all_langchain_documents_to_research = convert_chunks_to_langchain_documents(all_raw_documents)
# Send terminal message about starting research
yield streaming_service.add_terminal_message("Starting to research...", "info")
# Create a buffer to collect report content
report_buffer = []
# Use the streaming research method
yield streaming_service.add_terminal_message("Generating report...", "info")
# Create a wrapper to handle the streaming
class StreamHandler:
def __init__(self):
self.queue = asyncio.Queue()
async def handle_progress(self, data):
result = None
if data.get("type") == "logs":
# Handle log messages
result = streaming_service.add_terminal_message(data.get("output", ""), "info")
elif data.get("type") == "report":
# Handle report content
content = data.get("output", "")
# Fix incorrect citation formats using regex
# More specific pattern to match only numeric citations in markdown-style links
# This matches patterns like ([1](https://github.com/...)) but not general links like ([Click here](https://...))
pattern = r'\(\[(\d+)\]\((https?://[^\)]+)\)\)'
# Replace with just [X] where X is the number
content = re.sub(pattern, r'[\1]', content)
# Also match other incorrect formats like ([1]) and convert to [1]
# Only match if the content inside brackets is a number
content = re.sub(r'\(\[(\d+)\]\)', r'[\1]', content)
report_buffer.append(content)
# Update the answer with the accumulated content
result = streaming_service.update_answer(report_buffer)
if result:
await self.queue.put(result)
return result
async def get_next(self):
try:
return await self.queue.get()
except Exception as e:
print(f"Error getting next item from queue: {e}")
return None
def task_done(self):
self.queue.task_done()
# Create the stream handler
stream_handler = StreamHandler()
# Start the research process in a separate task
research_task = asyncio.create_task(
ResearchService.stream_research(
user_query=user_query,
documents=all_langchain_documents_to_research,
on_progress=stream_handler.handle_progress,
research_mode=research_mode
)
)
# Stream results as they become available
while not research_task.done() or not stream_handler.queue.empty():
try:
# Get the next result with a timeout
result = await asyncio.wait_for(stream_handler.get_next(), timeout=0.1)
stream_handler.task_done()
yield result
except asyncio.TimeoutError:
# No result available yet, check if the research task is done
if research_task.done():
# If the queue is empty and the task is done, we're finished
if stream_handler.queue.empty():
break
# Get the final report
try:
final_report = await research_task
# Send terminal message about research completion
yield streaming_service.add_terminal_message("Research completed", "success")
# Update the answer with the final report
final_report_lines = final_report.split('\n')
yield streaming_service.update_answer(final_report_lines)
except Exception as e:
# Handle any exceptions
yield streaming_service.add_terminal_message(f"Error during research: {str(e)}", "error")
# Send completion message
yield streaming_service.format_completion()
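The StreamHandler above bridges GPT Researcher's callback-style progress reporting into this async generator through an asyncio.Queue. A stripped-down sketch of that bridge pattern, with placeholder names and a fake producer, looks like this:

```python
# Minimal sketch of the callback-to-generator bridge: a producer reports progress
# through a callback, a queue buffers the messages, and an async generator drains
# the queue until the producer task finishes. All names here are illustrative.
import asyncio


async def producer(on_progress):
    for i in range(3):
        await asyncio.sleep(0.05)          # simulate research work
        await on_progress(f"step {i}")
    return "final report"


async def stream():
    queue: asyncio.Queue[str] = asyncio.Queue()

    async def handle_progress(msg: str):
        await queue.put(msg)

    task = asyncio.create_task(producer(handle_progress))
    while not task.done() or not queue.empty():
        try:
            yield await asyncio.wait_for(queue.get(), timeout=0.1)
        except asyncio.TimeoutError:
            continue                        # nothing queued yet; re-check the task
    yield await task                        # the producer's return value


async def main():
    async for chunk in stream():
        print(chunk)


asyncio.run(main())
```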

@@ -0,0 +1,95 @@
from typing import Optional
import uuid
from fastapi import Depends, Request, Response
from fastapi.responses import RedirectResponse
from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models
from fastapi_users.authentication import (
AuthenticationBackend,
BearerTransport,
JWTStrategy,
)
from fastapi_users.db import SQLAlchemyUserDatabase
from httpx_oauth.clients.google import GoogleOAuth2
from app.config import config
from app.db import User, get_user_db
from pydantic import BaseModel
class BearerResponse(BaseModel):
access_token: str
token_type: str
SECRET = config.SECRET_KEY
google_oauth_client = GoogleOAuth2(
config.GOOGLE_OAUTH_CLIENT_ID,
config.GOOGLE_OAUTH_CLIENT_SECRET,
)
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = SECRET
verification_token_secret = SECRET
async def on_after_register(self, user: User, request: Optional[Request] = None):
print(f"User {user.id} has registered.")
async def on_after_forgot_password(
self, user: User, token: str, request: Optional[Request] = None
):
print(f"User {user.id} has forgot their password. Reset token: {token}")
async def on_after_request_verify(
self, user: User, token: str, request: Optional[Request] = None
):
print(
f"Verification requested for user {user.id}. Verification token: {token}")
async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)):
yield UserManager(user_db)
def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
return JWTStrategy(secret=SECRET, lifetime_seconds=3600*24)
# # COOKIE AUTH | Uncomment if you want to use cookie auth.
# from fastapi_users.authentication import (
# CookieTransport,
# )
# class CustomCookieTransport(CookieTransport):
# async def get_login_response(self, token: str) -> Response:
# response = RedirectResponse(config.OAUTH_REDIRECT_URL, status_code=302)
# return self._set_login_cookie(response, token)
# cookie_transport = CustomCookieTransport(
# cookie_max_age=3600,
# )
# auth_backend = AuthenticationBackend(
# name="jwt",
# transport=cookie_transport,
# get_strategy=get_jwt_strategy,
# )
# BEARER AUTH CODE.
class CustomBearerTransport(BearerTransport):
async def get_login_response(self, token: str) -> Response:
bearer_response = BearerResponse(access_token=token, token_type="bearer")
redirect_url = f"{config.NEXT_FRONTEND_URL}/auth/callback?token={bearer_response.access_token}"
return RedirectResponse(redirect_url, status_code=302)
bearer_transport = CustomBearerTransport(tokenUrl="auth/jwt/login")
auth_backend = AuthenticationBackend(
name="jwt",
transport=bearer_transport,
get_strategy=get_jwt_strategy,
)
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
current_active_user = fastapi_users.current_user(active=True)
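For context, the objects defined above are normally mounted on the FastAPI app through the routers that fastapi-users generates. The snippet below is a hedged sketch of that wiring built on the names defined above; the route prefixes and the /whoami endpoint are illustrative assumptions, not necessarily the paths SurfSense registers.

```python
# Sketch only: typical fastapi-users wiring for the backend auth objects above.
from fastapi import Depends, FastAPI

app = FastAPI()

app.include_router(
    fastapi_users.get_auth_router(auth_backend), prefix="/auth/jwt", tags=["auth"]
)
app.include_router(
    fastapi_users.get_oauth_router(google_oauth_client, auth_backend, SECRET),
    prefix="/auth/google",
    tags=["auth"],
)


@app.get("/whoami")  # assumed endpoint, shown only to demonstrate the dependency
async def whoami(user: User = Depends(current_active_user)):
    return {"id": str(user.id), "email": user.email}
```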

@@ -0,0 +1,12 @@
from fastapi import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import User
# Helper function to check user ownership
async def check_ownership(session: AsyncSession, model, item_id: int, user: User):
item = await session.execute(select(model).filter(model.id == item_id, model.user_id == user.id))
item = item.scalars().first()
if not item:
raise HTTPException(status_code=404, detail="Item not found or you don't have permission to access it")
return item
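A typical call site for check_ownership is a route that loads a user-owned row by id. The example below is only illustrative: the Document model, the get_async_session dependency, and the import paths are assumptions.

```python
# Illustrative route using check_ownership; models, dependencies and module
# paths are assumed for the sake of the example.
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession

from app.db import Document, User, get_async_session      # assumed helpers
from app.users import current_active_user                 # assumed import path
from app.utils.check_ownership import check_ownership     # assumed module name

router = APIRouter()


@router.get("/documents/{document_id}")
async def get_document(
    document_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    # Raises 404 when the row is missing or owned by a different user.
    return await check_ownership(session, Document, document_id, user)
```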

@@ -0,0 +1,385 @@
import json
from typing import List, Dict, Any, Optional, Tuple
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
from app.db import SearchSourceConnector, SearchSourceConnectorType
from tavily import TavilyClient
class ConnectorService:
def __init__(self, session: AsyncSession):
self.session = session
self.retriever = ChucksHybridSearchRetriever(session)
self.source_id_counter = 1
async def search_crawled_urls(self, user_query: str, user_id: int, search_space_id: int, top_k: int = 20) -> tuple:
"""
Search for crawled URLs and return both the source information and langchain documents
Returns:
tuple: (sources_info, langchain_documents)
"""
crawled_urls_chunks = await self.retriever.hybrid_search(
query_text=user_query,
top_k=top_k,
user_id=user_id,
search_space_id=search_space_id,
document_type="CRAWLED_URL"
)
# Map crawled_urls_chunks to the required format
mapped_sources = {}
for i, chunk in enumerate(crawled_urls_chunks):
#Fix for UI
crawled_urls_chunks[i]['document']['id'] = self.source_id_counter
# Extract document metadata
document = chunk.get('document', {})
metadata = document.get('metadata', {})
# Create a mapped source entry
source = {
"id": self.source_id_counter,
"title": document.get('title', 'Untitled Document'),
"description": metadata.get('og:description', metadata.get('ogDescription', chunk.get('content', '')[:100])),
"url": metadata.get('url', '')
}
self.source_id_counter += 1
# Use a unique identifier for tracking unique sources
source_key = source.get("url") or source.get("title")
if source_key and source_key not in mapped_sources:
mapped_sources[source_key] = source
# Convert to list of sources
sources_list = list(mapped_sources.values())
# Create result object
result_object = {
"id": 1,
"name": "Crawled URLs",
"type": "CRAWLED_URL",
"sources": sources_list,
}
return result_object, crawled_urls_chunks
async def search_files(self, user_query: str, user_id: int, search_space_id: int, top_k: int = 20) -> tuple:
"""
Search for files and return both the source information and langchain documents
Returns:
tuple: (sources_info, langchain_documents)
"""
files_chunks = await self.retriever.hybrid_search(
query_text=user_query,
top_k=top_k,
user_id=user_id,
search_space_id=search_space_id,
document_type="FILE"
)
# Map crawled_urls_chunks to the required format
mapped_sources = {}
for i, chunk in enumerate(files_chunks):
#Fix for UI
files_chunks[i]['document']['id'] = self.source_id_counter
# Extract document metadata
document = chunk.get('document', {})
metadata = document.get('metadata', {})
# Create a mapped source entry
source = {
"id": self.source_id_counter,
"title": document.get('title', 'Untitled Document'),
"description": metadata.get('og:description', metadata.get('ogDescription', chunk.get('content', '')[:100])),
"url": metadata.get('url', '')
}
self.source_id_counter += 1
# Use a unique identifier for tracking unique sources
source_key = source.get("url") or source.get("title")
if source_key and source_key not in mapped_sources:
mapped_sources[source_key] = source
# Convert to list of sources
sources_list = list(mapped_sources.values())
# Create result object
result_object = {
"id": 2,
"name": "Files",
"type": "FILE",
"sources": sources_list,
}
return result_object, files_chunks
async def get_connector_by_type(self, user_id: int, connector_type: SearchSourceConnectorType) -> Optional[SearchSourceConnector]:
"""
Get a connector by type for a specific user
Args:
user_id: The user's ID
connector_type: The connector type to retrieve
Returns:
Optional[SearchSourceConnector]: The connector if found, None otherwise
"""
result = await self.session.execute(
select(SearchSourceConnector)
.filter(
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type == connector_type
)
)
return result.scalars().first()
async def search_tavily(self, user_query: str, user_id: int, top_k: int = 20) -> tuple:
"""
Search using Tavily API and return both the source information and documents
Args:
user_query: The user's query
user_id: The user's ID
top_k: Maximum number of results to return
Returns:
tuple: (sources_info, documents)
"""
# Get Tavily connector configuration
tavily_connector = await self.get_connector_by_type(user_id, SearchSourceConnectorType.TAVILY_API)
if not tavily_connector:
# Return empty results if no Tavily connector is configured
return {
"id": 3,
"name": "Tavily Search",
"type": "TAVILY_API",
"sources": [],
}, []
# Initialize Tavily client with API key from connector config
tavily_api_key = tavily_connector.config.get("TAVILY_API_KEY")
tavily_client = TavilyClient(api_key=tavily_api_key)
# Perform search with Tavily
try:
response = tavily_client.search(
query=user_query,
max_results=top_k,
search_depth="advanced" # Use advanced search for better results
)
# Extract results from Tavily response
tavily_results = response.get("results", [])
# Map Tavily results to the required format
sources_list = []
documents = []
# IDs continue from the shared source_id_counter so they never collide with other connectors
for i, result in enumerate(tavily_results):
# Create a source entry
source = {
"id": self.source_id_counter,
"title": result.get("title", "Tavily Result"),
"description": result.get("content", "")[:100],
"url": result.get("url", "")
}
sources_list.append(source)
# Create a document entry
document = {
"chunk_id": f"tavily_chunk_{i}",
"content": result.get("content", ""),
"score": result.get("score", 0.0),
"document": {
"id": self.source_id_counter,
"title": result.get("title", "Tavily Result"),
"document_type": "TAVILY_API",
"metadata": {
"url": result.get("url", ""),
"published_date": result.get("published_date", ""),
"source": "TAVILY_API"
}
}
}
documents.append(document)
self.source_id_counter += 1
# Create result object
result_object = {
"id": 3,
"name": "Tavily Search",
"type": "TAVILY_API",
"sources": sources_list,
}
return result_object, documents
except Exception as e:
# Log the error and return empty results
print(f"Error searching with Tavily: {str(e)}")
return {
"id": 3,
"name": "Tavily Search",
"type": "TAVILY_API",
"sources": [],
}, []
async def search_slack(self, user_query: str, user_id: int, search_space_id: int, top_k: int = 20) -> tuple:
"""
Search Slack messages and return both the source information and langchain documents
Returns:
tuple: (sources_info, langchain_documents)
"""
slack_chunks = await self.retriever.hybrid_search(
query_text=user_query,
top_k=top_k,
user_id=user_id,
search_space_id=search_space_id,
document_type="SLACK_CONNECTOR"
)
# Map slack_chunks to the required format
mapped_sources = {}
for i, chunk in enumerate(slack_chunks):
#Fix for UI
slack_chunks[i]['document']['id'] = self.source_id_counter
# Extract document metadata
document = chunk.get('document', {})
metadata = document.get('metadata', {})
# Create a mapped source entry with Slack-specific metadata
channel_name = metadata.get('channel_name', 'Unknown Channel')
channel_id = metadata.get('channel_id', '')
message_date = metadata.get('start_date', '')
# Create a more descriptive title for Slack messages
title = f"Slack: {channel_name}"
if message_date:
title += f" ({message_date})"
# Create a more descriptive description for Slack messages
description = chunk.get('content', '')[:100]
if len(description) == 100:
description += "..."
# For URL, we can use a placeholder or construct a URL to the Slack channel if available
url = ""
if channel_id:
url = f"https://slack.com/app_redirect?channel={channel_id}"
source = {
"id": self.source_id_counter,
"title": title,
"description": description,
"url": url,
}
self.source_id_counter += 1
# Use channel_id and content as a unique identifier for tracking unique sources
source_key = f"{channel_id}_{chunk.get('chunk_id', i)}"
if source_key and source_key not in mapped_sources:
mapped_sources[source_key] = source
# Convert to list of sources
sources_list = list(mapped_sources.values())
# Create result object
result_object = {
"id": 4,
"name": "Slack",
"type": "SLACK_CONNECTOR",
"sources": sources_list,
}
return result_object, slack_chunks
async def search_notion(self, user_query: str, user_id: int, search_space_id: int, top_k: int = 20) -> tuple:
"""
Search for Notion pages and return both the source information and langchain documents
Args:
user_query: The user's query
user_id: The user's ID
search_space_id: The search space ID to search in
top_k: Maximum number of results to return
Returns:
tuple: (sources_info, langchain_documents)
"""
notion_chunks = await self.retriever.hybrid_search(
query_text=user_query,
top_k=top_k,
user_id=user_id,
search_space_id=search_space_id,
document_type="NOTION_CONNECTOR"
)
# Map notion_chunks to the required format
mapped_sources = {}
for i, chunk in enumerate(notion_chunks):
# Fix for UI
notion_chunks[i]['document']['id'] = self.source_id_counter
# Extract document metadata
document = chunk.get('document', {})
metadata = document.get('metadata', {})
# Create a mapped source entry with Notion-specific metadata
page_title = metadata.get('page_title', 'Untitled Page')
page_id = metadata.get('page_id', '')
indexed_at = metadata.get('indexed_at', '')
# Create a more descriptive title for Notion pages
title = f"Notion: {page_title}"
if indexed_at:
title += f" (indexed: {indexed_at})"
# Create a more descriptive description for Notion pages
description = chunk.get('content', '')[:100]
if len(description) == 100:
description += "..."
# For URL, we can use a placeholder or construct a URL to the Notion page if available
url = ""
if page_id:
# Notion page URLs follow this format
url = f"https://notion.so/{page_id.replace('-', '')}"
source = {
"id": self.source_id_counter,
"title": title,
"description": description,
"url": url,
}
self.source_id_counter += 1
# Use page_id and content as a unique identifier for tracking unique sources
source_key = f"{page_id}_{chunk.get('chunk_id', i)}"
if source_key and source_key not in mapped_sources:
mapped_sources[source_key] = source
# Convert to list of sources
sources_list = list(mapped_sources.values())
# Create result object
result_object = {
"id": 5,
"name": "Notion",
"type": "NOTION_CONNECTOR",
"sources": sources_list,
}
return result_object, notion_chunks
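A caller typically fans out over several of these search methods and pools the chunks before reranking, much as stream_connector_search_results does. A hedged sketch follows; the ids are placeholders and `session` is assumed to be an open AsyncSession.

```python
# Usage sketch only: pool chunks from two connector searches for later reranking.
async def gather_chunks(session, query: str):
    service = ConnectorService(session)
    all_sources, all_chunks = [], []

    for search in (service.search_crawled_urls, service.search_files):
        result_object, chunks = await search(
            user_query=query, user_id=1, search_space_id=1, top_k=20
        )
        all_sources.append(result_object)   # feeds the UI source panel
        all_chunks.extend(chunks)           # feeds reranking / research

    return all_sources, all_chunks
```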

@@ -0,0 +1,136 @@
async def convert_element_to_markdown(element) -> str:
"""
Convert an Unstructured element to markdown format based on its category.
Args:
element: The Unstructured API element object
Returns:
str: Markdown formatted string
"""
element_category = element.metadata["category"]
content = element.page_content
if not content:
return ""
markdown_mapping = {
"Formula": lambda x: f"```math\n{x}\n```",
"FigureCaption": lambda x: f"*Figure: {x}*",
"NarrativeText": lambda x: f"{x}\n\n",
"ListItem": lambda x: f"- {x}\n",
"Title": lambda x: f"# {x}\n\n",
"Address": lambda x: f"> {x}\n\n",
"EmailAddress": lambda x: f"`{x}`",
"Image": lambda x: f"![{x}]({x})",
"PageBreak": lambda x: "\n---\n",
"Table": lambda x: f"```html\n{element.metadata['text_as_html']}\n```",
"Header": lambda x: f"## {x}\n\n",
"Footer": lambda x: f"*{x}*\n\n",
"CodeSnippet": lambda x: f"```\n{x}\n```",
"PageNumber": lambda x: f"*Page {x}*\n\n",
"UncategorizedText": lambda x: f"{x}\n\n"
}
converter = markdown_mapping.get(element_category, lambda x: x)
return converter(content)
async def convert_document_to_markdown(elements):
"""
Convert all document elements to markdown.
Args:
elements: List of Unstructured API elements
Returns:
str: Complete markdown document
"""
markdown_parts = []
for element in elements:
markdown_text = await convert_element_to_markdown(element)
if markdown_text:
markdown_parts.append(markdown_text)
return "".join(markdown_parts)
def convert_chunks_to_langchain_documents(chunks):
"""
Convert chunks from hybrid search results to LangChain Document objects.
Args:
chunks: List of chunk dictionaries from hybrid search results
Returns:
List of LangChain Document objects
"""
try:
from langchain_core.documents import Document as LangChainDocument
except ImportError:
raise ImportError(
"LangChain is not installed. Please install it with `pip install langchain langchain-core`"
)
langchain_docs = []
for chunk in chunks:
# Extract content from the chunk
content = chunk.get("content", "")
# Create metadata dictionary
metadata = {
"chunk_id": chunk.get("chunk_id"),
"score": chunk.get("score"),
"rank": chunk.get("rank") if "rank" in chunk else None,
}
# Add document information to metadata
if "document" in chunk:
doc = chunk["document"]
metadata.update({
"document_id": doc.get("id"),
"document_title": doc.get("title"),
"document_type": doc.get("document_type"),
})
# Add document metadata if available
if "metadata" in doc:
# Prefix document metadata keys to avoid conflicts
doc_metadata = {f"doc_meta_{k}": v for k, v in doc.get("metadata", {}).items()}
metadata.update(doc_metadata)
# Add source URL if available in metadata
if "url" in doc.get("metadata", {}):
metadata["source"] = doc["metadata"]["url"]
elif "sourceURL" in doc.get("metadata", {}):
metadata["source"] = doc["metadata"]["sourceURL"]
# Ensure source_id is set for citation purposes
# Use document_id as the source_id if available
if "document_id" in metadata:
metadata["source_id"] = metadata["document_id"]
# Update content for citation mode - format as XML with explicit source_id
new_content = f"""
<document>
<metadata>
<source_id>{metadata.get("source_id", metadata.get("document_id", "unknown"))}</source_id>
</metadata>
<content>
<text>
{content}
</text>
</content>
</document>
"""
# Create LangChain Document
langchain_doc = LangChainDocument(
page_content=new_content,
metadata=metadata
)
langchain_docs.append(langchain_doc)
return langchain_docs
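The converter only relies on a handful of keys in each chunk dictionary. A small, made-up example of the expected shape and the resulting metadata:

```python
# The values below are invented; only the keys matter to the converter above.
sample_chunks = [
    {
        "chunk_id": "chunk_1",
        "content": "SurfSense indexes Slack, Notion and crawled pages.",
        "score": 0.42,
        "document": {
            "id": 7,
            "title": "SurfSense README",
            "document_type": "CRAWLED_URL",
            "metadata": {"url": "https://github.com/MODSetter/SurfSense"},
        },
    }
]

docs = convert_chunks_to_langchain_documents(sample_chunks)
print(docs[0].metadata["source_id"])  # 7 — later used for IEEE-style citations
print(docs[0].metadata["source"])     # the crawled URL
```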

@@ -0,0 +1,95 @@
import logging
from typing import List, Dict, Any, Optional
from rerankers import Document as RerankerDocument
class RerankerService:
"""
Service for reranking documents using a configured reranker
"""
def __init__(self, reranker_instance=None):
"""
Initialize the reranker service
Args:
reranker_instance: The reranker instance to use for reranking
"""
self.reranker_instance = reranker_instance
def rerank_documents(self, query_text: str, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Rerank documents using the configured reranker
Args:
query_text: The query text to use for reranking
documents: List of document dictionaries to rerank
Returns:
List[Dict[str, Any]]: Reranked documents
"""
if not self.reranker_instance or not documents:
return documents
try:
# Create Document objects for the rerankers library
reranker_docs = []
for i, doc in enumerate(documents):
chunk_id = doc.get("chunk_id", f"chunk_{i}")
content = doc.get("content", "")
score = doc.get("score", 0.0)
document_info = doc.get("document", {})
reranker_docs.append(
RerankerDocument(
text=content,
doc_id=chunk_id,
metadata={
'document_id': document_info.get("id", ""),
'document_title': document_info.get("title", ""),
'document_type': document_info.get("document_type", ""),
'rrf_score': score
}
)
)
# Rerank using the configured reranker
reranking_results = self.reranker_instance.rank(
query=query_text,
docs=reranker_docs
)
# Process the results from the reranker
# Convert to serializable dictionaries
serialized_results = []
for result in reranking_results.results:
# Find the original document by id
original_doc = next((doc for doc in documents if doc.get("chunk_id") == result.document.doc_id), None)
if original_doc:
# Create a new document with the reranked score
reranked_doc = original_doc.copy()
reranked_doc["score"] = float(result.score)
reranked_doc["rank"] = result.rank
serialized_results.append(reranked_doc)
return serialized_results
except Exception as e:
# Log the error
logging.error(f"Error during reranking: {str(e)}")
# Fall back to original documents without reranking
return documents
@staticmethod
def get_reranker_instance(config=None) -> Optional['RerankerService']:
"""
Get a reranker service instance based on configuration
Args:
config: Configuration object that may contain a reranker_instance
Returns:
Optional[RerankerService]: A reranker service instance or None
"""
if config and hasattr(config, 'reranker_instance') and config.reranker_instance:
return RerankerService(config.reranker_instance)
return None
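Wiring an actual reranker into this service goes through the rerankers package. The sketch below assumes a FlashRank cross-encoder; the model name is an example, not necessarily the one configured in app.config.

```python
# Hedged sketch: plug a FlashRank reranker into RerankerService.
from rerankers import Reranker

ranker = Reranker("ms-marco-MiniLM-L-12-v2", model_type="flashrank")  # example model
service = RerankerService(reranker_instance=ranker)

chunks = [
    {"chunk_id": "a", "content": "Notion connector indexing", "score": 0.3, "document": {"id": 1}},
    {"chunk_id": "b", "content": "Slack channel history sync", "score": 0.6, "document": {"id": 2}},
]
for doc in service.rerank_documents("how are Notion pages indexed?", chunks):
    print(doc["chunk_id"], round(doc["score"], 3), doc["rank"])
```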

@@ -0,0 +1,211 @@
import asyncio
import re
from typing import List, Dict, Any, AsyncGenerator, Callable, Optional
from langchain.schema import Document
from gpt_researcher.agent import GPTResearcher
from gpt_researcher.utils.enum import ReportType, Tone, ReportSource
from dotenv import load_dotenv
load_dotenv()
class ResearchService:
@staticmethod
async def create_custom_prompt(user_query: str) -> str:
citation_prompt = f"""
You are a research assistant tasked with analyzing documents and providing comprehensive answers with proper citations in IEEE format.
<instructions>
1. Carefully analyze all provided documents in the <document> section's.
2. Extract relevant information that addresses the user's query.
3. Synthesize a comprehensive, well-structured answer using information from these documents.
4. For EVERY piece of information you include from the documents, add an IEEE-style citation in square brackets [X] where X is the source_id from the document's metadata.
5. Make sure ALL factual statements from the documents have proper citations.
6. If multiple documents support the same point, include all relevant citations [X], [Y].
7. Present information in a logical, coherent flow.
8. Use your own words to connect ideas, but cite ALL information from the documents.
9. If documents contain conflicting information, acknowledge this and present both perspectives with appropriate citations.
10. Do not make up or include information not found in the provided documents.
11. CRITICAL: You MUST use the exact source_id value from each document's metadata for citations. Do not create your own citation numbers.
12. CRITICAL: Every citation MUST be in the IEEE format [X] where X is the exact source_id value.
13. CRITICAL: Never renumber or reorder citations - always use the original source_id values.
14. CRITICAL: Do not return citations as clickable links.
15. CRITICAL: Never format citations as markdown links like "([1](https://example.com))". Always use plain square brackets only.
16. CRITICAL: Citations must ONLY appear as [X] or [X], [Y], [Z] format - never with parentheses, hyperlinks, or other formatting.
17. CRITICAL: Never make up citation numbers. Only use source_id values that are explicitly provided in the document metadata.
18. CRITICAL: If you are unsure about a source_id, do not include a citation rather than guessing or making one up.
</instructions>
<format>
- Write in clear, professional language suitable for academic or technical audiences
- Organize your response with appropriate paragraphs, headings, and structure
- Every fact from the documents must have an IEEE-style citation in square brackets [X] where X is the EXACT source_id from the document's metadata
- Citations should appear at the end of the sentence containing the information they support
- Multiple citations should be separated by commas: [X], [Y], [Z]
- No need to return references section. Just citation numbers in answer.
- NEVER create your own citation numbering system - use the exact source_id values from the documents.
- NEVER format citations as clickable links or as markdown links like "([1](https://example.com))". Always use plain square brackets only.
- NEVER make up citation numbers if you are unsure about the source_id. It is better to omit the citation than to guess.
</format>
<input_example>
<document>
<metadata>
<source_id>1</source_id>
</metadata>
<content>
<text>
The Great Barrier Reef is the world's largest coral reef system, stretching over 2,300 kilometers along the coast of Queensland, Australia. It comprises over 2,900 individual reefs and 900 islands.
</text>
</content>
</document>
<document>
<metadata>
<source_id>13</source_id>
</metadata>
<content>
<text>
Climate change poses a significant threat to coral reefs worldwide. Rising ocean temperatures have led to mass coral bleaching events in the Great Barrier Reef in 2016, 2017, and 2020.
</text>
</content>
</document>
<document>
<metadata>
<source_id>21</source_id>
</metadata>
<content>
<text>
The Great Barrier Reef was designated a UNESCO World Heritage Site in 1981 due to its outstanding universal value and biological diversity. It is home to over 1,500 species of fish and 400 types of coral.
</text>
</content>
</document>
</input_example>
<output_example>
The Great Barrier Reef is the world's largest coral reef system, stretching over 2,300 kilometers along the coast of Queensland, Australia [1]. It was designated a UNESCO World Heritage Site in 1981 due to its outstanding universal value and biological diversity [21]. The reef is home to over 1,500 species of fish and 400 types of coral [21]. Unfortunately, climate change poses a significant threat to coral reefs worldwide, with rising ocean temperatures leading to mass coral bleaching events in the Great Barrier Reef in 2016, 2017, and 2020 [13]. The reef system comprises over 2,900 individual reefs and 900 islands [1], making it an ecological treasure that requires protection from multiple threats [1], [13].
</output_example>
<incorrect_citation_formats>
DO NOT use any of these incorrect citation formats:
- Using parentheses and markdown links: ([1](https://github.com/MODSetter/SurfSense))
- Using parentheses around brackets: ([1])
- Using hyperlinked text: [link to source 1](https://example.com)
- Using footnote style: ... reef system¹
- Making up citation numbers when source_id is unknown
ONLY use plain square brackets [1] or multiple citations [1], [2], [3]
</incorrect_citation_formats>
Note that the citation numbers match exactly with the source_id values (1, 13, and 21) and are not renumbered sequentially. Citations follow IEEE style with square brackets and appear at the end of sentences.
Now, please research the following query:
<user_query_to_research>
{user_query}
</user_query_to_research>
"""
return citation_prompt
@staticmethod
async def stream_research(
user_query: str,
documents: List[Document] = None,
on_progress: Optional[Callable] = None,
research_mode: str = "GENERAL"
) -> str:
"""
Stream the research process using GPTResearcher
Args:
user_query: The user's query
documents: List of Document objects to use for research
on_progress: Optional callback for progress updates
research_mode: Research mode to use
Returns:
str: The final research report
"""
# Create a custom websocket-like object to capture streaming output
class StreamingWebsocket:
async def send_json(self, data):
if on_progress:
try:
# Filter out excessive logging of the prompt
if data.get("type") == "logs":
output = data.get("output", "")
# Check if this is a verbose prompt log
if "You are a research assistant tasked with analyzing documents" in output and len(output) > 500:
# Replace with a shorter message
data["output"] = f"Processing research for query: {user_query}"
result = await on_progress(data)
return result
except Exception as e:
print(f"Error in on_progress callback: {e}")
return None
streaming_websocket = StreamingWebsocket()
custom_prompt_for_ieee_citations = await ResearchService.create_custom_prompt(user_query)
if(research_mode == "GENERAL"):
research_report_type = ReportType.CustomReport.value
elif(research_mode == "DEEP"):
research_report_type = ReportType.ResearchReport.value
elif(research_mode == "DEEPER"):
research_report_type = ReportType.DetailedReport.value
# elif(research_mode == "DEEPEST"):
# research_report_type = ReportType.DeepResearch.value
# Initialize GPTResearcher with the streaming websocket
researcher = GPTResearcher(
query=custom_prompt_for_ieee_citations,
report_type=research_report_type,
report_format="IEEE",
report_source=ReportSource.LangChainDocuments.value,
tone=Tone.Formal,
documents=documents,
verbose=True,
websocket=streaming_websocket
)
# Conduct research
await researcher.conduct_research()
# Generate report with streaming
report = await researcher.write_report()
# Fix citation format
report = ResearchService.fix_citation_format(report)
return report
@staticmethod
def fix_citation_format(text: str) -> str:
"""
Fix any incorrectly formatted citations in the text.
Args:
text: The text to fix
Returns:
str: The text with fixed citations
"""
if not text:
return text
# More specific pattern to match only numeric citations in markdown-style links
# This matches patterns like ([1](https://github.com/...)) but not general links like ([Click here](https://...))
pattern = r'\(\[(\d+)\]\((https?://[^\)]+)\)\)'
# Replace with just [X] where X is the number
text = re.sub(pattern, r'[\1]', text)
# Also match other incorrect formats like ([1]) and convert to [1]
# Only match if the content inside brackets is a number
text = re.sub(r'\(\[(\d+)\]\)', r'[\1]', text)
return text
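A quick illustration of what fix_citation_format rewrites, using the citation styles the prompt explicitly forbids:

```python
# Quick check of fix_citation_format against forbidden citation formats.
raw = (
    "SurfSense supports hybrid search ([1](https://github.com/MODSetter/SurfSense)) "
    "and reranking ([2]), as discussed earlier [3]."
)
print(ResearchService.fix_citation_format(raw))
# SurfSense supports hybrid search [1] and reranking [2], as discussed earlier [3].
```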

@@ -0,0 +1,99 @@
import json
from typing import List, Dict, Any, Generator
class StreamingService:
def __init__(self):
self.terminal_idx = 1
self.message_annotations = [
{
"type": "TERMINAL_INFO",
"content": []
},
{
"type": "SOURCES",
"content": []
},
{
"type": "ANSWER",
"content": []
}
]
def add_terminal_message(self, text: str, message_type: str = "info") -> str:
"""
Add a terminal message to the annotations and return the formatted response
Args:
text: The message text
message_type: The message type (info, success, error)
Returns:
str: The formatted response string
"""
self.message_annotations[0]["content"].append({
"id": self.terminal_idx,
"text": text,
"type": message_type
})
self.terminal_idx += 1
return self._format_annotations()
def update_sources(self, sources: List[Dict[str, Any]]) -> str:
"""
Update the sources in the annotations and return the formatted response
Args:
sources: List of source objects
Returns:
str: The formatted response string
"""
self.message_annotations[1]["content"] = sources
return self._format_annotations()
def update_answer(self, answer_content: List[str]) -> str:
"""
Update the answer in the annotations and return the formatted response
Args:
answer_content: The answer content as a list of strings
Returns:
str: The formatted response string
"""
self.message_annotations[2] = {
"type": "ANSWER",
"content": answer_content
}
return self._format_annotations()
def _format_annotations(self) -> str:
"""
Format the annotations as a string
Returns:
str: The formatted annotations string
"""
return f'8:{json.dumps(self.message_annotations)}\n'
def format_completion(self, prompt_tokens: int = 156, completion_tokens: int = 204) -> str:
"""
Format a completion message
Args:
prompt_tokens: Number of prompt tokens
completion_tokens: Number of completion tokens
Returns:
str: The formatted completion string
"""
total_tokens = prompt_tokens + completion_tokens
completion_data = {
"finishReason": "stop",
"usage": {
"promptTokens": prompt_tokens,
"completionTokens": completion_tokens,
"totalTokens": total_tokens
}
}
return f'd:{json.dumps(completion_data)}\n'
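Put together, the service emits 8:-prefixed annotation frames followed by a d: completion frame, which appears to match the Vercel AI SDK data-stream framing the frontend consumes. A short, self-contained demonstration:

```python
# Demonstration of the wire format StreamingService produces.
svc = StreamingService()
print(svc.add_terminal_message("Starting to search for files...", "info"), end="")
print(svc.update_sources([{"id": 2, "name": "Files", "type": "FILE", "sources": []}]), end="")
print(svc.update_answer(["# Report", "First paragraph."]), end="")
print(svc.format_completion(), end="")
```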

@@ -0,0 +1,4 @@
import uvicorn
if __name__ == "__main__":
uvicorn.run("app.app:app", host="0.0.0.0", log_level="info")

@@ -0,0 +1,27 @@
[project]
name = "surf-new-backend"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"asyncpg>=0.30.0",
"chonkie[all]>=0.4.1",
"fastapi>=0.115.8",
"fastapi-users[oauth,sqlalchemy]>=14.0.1",
"firecrawl-py>=1.12.0",
"gpt-researcher>=0.12.12",
"langchain-community>=0.3.17",
"langchain-unstructured>=0.1.6",
"litellm>=1.61.4",
"markdownify>=0.14.1",
"notion-client>=2.3.0",
"pgvector>=0.3.6",
"playwright>=1.50.0",
"rerankers[flashrank]>=0.7.1",
"slack-sdk>=3.34.0",
"tavily-python>=0.3.2",
"unstructured-client>=0.30.0",
"uvicorn[standard]>=0.34.0",
"validators>=0.34.0",
]

surfsense_backend/uv.lock generated Normal file
File diff suppressed because it is too large