Mirror of https://github.com/MODSetter/SurfSense.git (synced 2025-09-01 01:59:06 +00:00)

feat: SurfSense v0.0.6 init

Parent commit: 18fc19e8d9
This commit: da23012970
58 changed files with 8284 additions and 2076 deletions
13  .gitignore (vendored)

@@ -8,18 +8,7 @@ env.bak/
venv.bak/
data/
.data

__pycache__
__pycache__/
.__pycache__

backend/examples

backend/old
backend/RAGAgent
backend/testfiles
backend/.env

vectorstores/*
vectorstores/
.vectorstores
surfsense_backend/.env
6  .gitmodules (vendored)

@@ -1,6 +1,6 @@
[submodule "SurfSense-Frontend"]
	path = SurfSense-Frontend
	url = https://github.com/MODSetter/SurfSense-Frontend.git
[submodule "surfsense_frontend"]
	path = surfsense_frontend
	url = https://github.com/MODSetter/surfsense_frontend.git
[submodule "ss-cross-browser-extension"]
	path = ss-cross-browser-extension
	url = https://github.com/MODSetter/ss-cross-browser-extension.git

@@ -1 +0,0 @@
Subproject commit 53211d0b590ff5c7aaf721fbf0a39c21d7f0b823
@@ -1,35 +0,0 @@
# Your Unstructured IO API key. Use a random value if you are running Unstructured locally or don't want to upload files
UNSTRUCTURED_API_KEY = ""

# POSTGRES DB TO TRACK USERS
POSTGRES_DATABASE_URL = "postgresql+psycopg2://postgres:postgres@localhost:5432/surfsense"

# API KEY TO VERIFY
API_SECRET_KEY = "surfsense"

# Your JWT secret and algorithm
SECRET_KEY = "your_secret_key"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = "720"

# SEARCH ENGINES TO USE FOR WEB SEARCH
TAVILY_API_KEY=""


# UNCOMMENT THE RESPECTIVE LINES BELOW FOR LOCAL/OPENAI SETUP

# For OpenAI LLM setup
OPENAI_API_KEY="sk-proj-GHG....."
FAST_LLM="openai:gpt-4o-mini"
SMART_LLM="openai:gpt-4o-mini"
EMBEDDING="openai:text-embedding-3-large"

# For local setups
# OPENAI_API_KEY="123"
# OLLAMA_BASE_URL="http://localhost:11434"
# FAST_LLM="ollama:qwen2.5:7b"
# SMART_LLM="ollama:qwen2.5:7b"
# EMBEDDING="ollama:qwen2.5:7b"
# TEMPRATURE="0"
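
Note: the backend resolves these provider-prefixed model strings at startup. A minimal sketch of that flow, assuming only the variable names above and the "provider:model" convention used in this commit (it mirrors the extract_model_name helper that appears later in this diff; it is not additional project code):

    import os
    from dotenv import load_dotenv

    load_dotenv()

    SMART_LLM = os.environ.get("SMART_LLM")          # e.g. "openai:gpt-4o-mini" or "ollama:qwen2.5:7b"
    IS_LOCAL_SETUP = SMART_LLM.startswith("ollama")  # the provider prefix selects local vs. OpenAI

    def extract_model_name(model_string: str) -> str:
        # Split at the first colon only, so "ollama:qwen2.5:7b" keeps its ":7b" tag
        _provider, model = model_string.split(":", 1)
        return model

    MODEL_NAME = extract_model_name(SMART_LLM)       # "gpt-4o-mini" or "qwen2.5:7b"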
12  backend/.gitignore (vendored)

@@ -1,12 +0,0 @@
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

__pycache__
__pycache__/
.__pycache__

@@ -1,19 +0,0 @@
# Use official Python image
FROM python:3.12-slim

# Set working directory
WORKDIR /app

# Copy backend source code
COPY . /app

# Install dependencies
RUN apt-get update && apt-get install -y gcc libpq-dev
RUN pip install --upgrade pip
RUN pip install -r requirements.txt

# Expose backend port
EXPOSE 8000

# Run FastAPI using uvicorn
CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8000"]
|
@ -1,501 +0,0 @@
|
|||
from datetime import datetime
|
||||
import json
|
||||
from typing import List
|
||||
from gpt_researcher import GPTResearcher
|
||||
from langchain_chroma import Chroma
|
||||
from langchain_ollama import OllamaLLM, OllamaEmbeddings
|
||||
from langchain_community.vectorstores.utils import filter_complex_metadata
|
||||
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
|
||||
from langchain.docstore.document import Document
|
||||
from langchain_experimental.text_splitter import SemanticChunker
|
||||
from langchain.retrievers import ContextualCompressionRetriever
|
||||
from langchain.retrievers.document_compressors import FlashrankRerank
|
||||
import numpy as np
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import Depends, WebSocket
|
||||
|
||||
from prompts import report_prompt
|
||||
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from Utils.stringify import stringify
|
||||
from pydmodels import AIAnswer, Reference
|
||||
from database import SessionLocal
|
||||
from models import Documents
|
||||
load_dotenv()
|
||||
|
||||
SMART_LLM = os.environ.get("SMART_LLM")
|
||||
EMBEDDING = os.environ.get("EMBEDDING")
|
||||
IS_LOCAL_SETUP = True if SMART_LLM.startswith("ollama") else False
|
||||
|
||||
|
||||
def extract_model_name(model_string: str) -> str:
    # Split at the first colon and return only the model part, e.g. "openai:gpt-4o-mini" -> "gpt-4o-mini"
    part1, part2 = model_string.split(":", 1)
    return part2
|
||||
|
||||
MODEL_NAME = extract_model_name(SMART_LLM)
|
||||
EMBEDDING_MODEL = extract_model_name(EMBEDDING)
|
||||
|
||||
# Dependency
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
class NumpyEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if isinstance(obj, np.ndarray):
|
||||
return obj.tolist()
|
||||
return super().default(obj)
|
||||
|
||||
class ConnectionManager:
|
||||
def __init__(self):
|
||||
self.active_connections: List[WebSocket] = []
|
||||
|
||||
async def connect(self, websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
self.active_connections.append(websocket)
|
||||
|
||||
def disconnect(self, websocket: WebSocket):
|
||||
self.active_connections.remove(websocket)
|
||||
|
||||
async def send_personal_message(self, message: str, websocket: WebSocket):
|
||||
await websocket.send_text(message)
|
||||
|
||||
|
||||
class HIndices:
|
||||
def __init__(self, username, api_key='local'):
|
||||
"""
|
||||
"""
|
||||
self.username = username
|
||||
if(IS_LOCAL_SETUP == True):
|
||||
self.llm = OllamaLLM(model=MODEL_NAME,temperature=0)
|
||||
self.embeddings = OllamaEmbeddings(model=EMBEDDING_MODEL)
|
||||
else:
|
||||
self.llm = ChatOpenAI(temperature=0, model_name=MODEL_NAME, api_key=api_key)
|
||||
self.embeddings = OpenAIEmbeddings(api_key=api_key,model=EMBEDDING_MODEL)
|
||||
|
||||
self.summary_store = Chroma(
|
||||
collection_name="summary_store",
|
||||
embedding_function=self.embeddings,
|
||||
persist_directory="./vectorstores/" + username + "/summary_store_db", # Where to save data locally
|
||||
)
|
||||
self.detailed_store = Chroma(
|
||||
collection_name="detailed_store",
|
||||
embedding_function=self.embeddings,
|
||||
persist_directory="./vectorstores/" + username + "/detailed_store_db", # Where to save data locally
|
||||
)
|
||||
|
||||
# self.summary_store_size = len(self.summary_store.get()['documents'])
|
||||
# self.detailed_store_size = len(self.detailed_store.get()['documents'])
|
||||
|
||||
def summarize_file_doc(self, page_no, doc, search_space):
|
||||
|
||||
# Build the document summarization chain (report prompt piped into the LLM)
|
||||
report_chain = report_prompt | self.llm
|
||||
|
||||
|
||||
if(IS_LOCAL_SETUP == True):
|
||||
|
||||
response = report_chain.invoke({"document": doc})
|
||||
|
||||
metadict = {
|
||||
"page": page_no,
|
||||
"summary": True,
|
||||
"search_space": search_space,
|
||||
}
|
||||
|
||||
metadict.update(doc.metadata)
|
||||
|
||||
# metadict['languages'] = metadict['languages'][0]
|
||||
|
||||
return Document(
|
||||
page_content=response,
|
||||
metadata=metadict
|
||||
)
|
||||
|
||||
else:
|
||||
response = report_chain.invoke({"document": doc})
|
||||
|
||||
metadict = {
|
||||
"page": page_no,
|
||||
"summary": True,
|
||||
"search_space": search_space,
|
||||
}
|
||||
|
||||
metadict.update(doc.metadata)
|
||||
|
||||
# metadict['languages'] = metadict['languages'][0]
|
||||
|
||||
return Document(
|
||||
page_content=response.content,
|
||||
metadata=metadict
|
||||
)
|
||||
|
||||
def summarize_webpage_doc(self, page_no, doc, search_space):
|
||||
|
||||
# Build the document summarization chain (report prompt piped into the LLM)
|
||||
report_chain = report_prompt | self.llm
|
||||
|
||||
|
||||
if(IS_LOCAL_SETUP == True):
|
||||
response = report_chain.invoke({"document": doc})
|
||||
|
||||
return Document(
|
||||
page_content=response,
|
||||
metadata={
|
||||
"filetype": 'WEBPAGE',
|
||||
"page": page_no,
|
||||
"summary": True,
|
||||
"search_space": search_space,
|
||||
"BrowsingSessionId": doc.metadata['BrowsingSessionId'],
|
||||
"VisitedWebPageURL": doc.metadata['VisitedWebPageURL'],
|
||||
"VisitedWebPageTitle": doc.metadata['VisitedWebPageTitle'],
|
||||
"VisitedWebPageDateWithTimeInISOString": doc.metadata['VisitedWebPageDateWithTimeInISOString'],
|
||||
"VisitedWebPageReffererURL": doc.metadata['VisitedWebPageReffererURL'],
|
||||
"VisitedWebPageVisitDurationInMilliseconds": doc.metadata['VisitedWebPageVisitDurationInMilliseconds'],
|
||||
}
|
||||
)
|
||||
|
||||
else:
|
||||
response = report_chain.invoke({"document": doc})
|
||||
|
||||
return Document(
|
||||
page_content=response.content,
|
||||
metadata={
|
||||
"filetype": 'WEBPAGE',
|
||||
"page": page_no,
|
||||
"summary": True,
|
||||
"search_space": search_space,
|
||||
"BrowsingSessionId": doc.metadata['BrowsingSessionId'],
|
||||
"VisitedWebPageURL": doc.metadata['VisitedWebPageURL'],
|
||||
"VisitedWebPageTitle": doc.metadata['VisitedWebPageTitle'],
|
||||
"VisitedWebPageDateWithTimeInISOString": doc.metadata['VisitedWebPageDateWithTimeInISOString'],
|
||||
"VisitedWebPageReffererURL": doc.metadata['VisitedWebPageReffererURL'],
|
||||
"VisitedWebPageVisitDurationInMilliseconds": doc.metadata['VisitedWebPageVisitDurationInMilliseconds'],
|
||||
}
|
||||
)
|
||||
|
||||
def encode_docs_hierarchical(self, documents, search_space_instance, files_type, db: Session = Depends(get_db)):
|
||||
"""
|
||||
Creates or updates documents in the hierarchical indices and the Postgres documents table.
|
||||
"""
|
||||
page_no_offset = len(self.detailed_store.get()['documents'])
|
||||
# Process documents
|
||||
summaries = []
|
||||
if(files_type=='WEBPAGE'):
|
||||
batch_summaries = [self.summarize_webpage_doc(page_no = i + page_no_offset, doc=doc, search_space=search_space_instance.name) for i, doc in enumerate(documents)]
|
||||
else:
|
||||
batch_summaries = [self.summarize_file_doc(page_no = i + page_no_offset , doc=doc, search_space=search_space_instance.name) for i, doc in enumerate(documents)]
|
||||
|
||||
|
||||
summaries.extend(batch_summaries)
|
||||
|
||||
detailed_chunks = []
|
||||
|
||||
for i, summary in enumerate(summaries):
|
||||
|
||||
# Add single summary in vector store
|
||||
added_doc_id = self.summary_store.add_documents(filter_complex_metadata([summary]))
|
||||
|
||||
if(files_type=='WEBPAGE'):
|
||||
new_pg_doc = Documents(
|
||||
title=summary.metadata['VisitedWebPageTitle'],
|
||||
document_metadata=stringify(summary.metadata),
|
||||
page_content=documents[i].page_content,
|
||||
file_type='WEBPAGE',
|
||||
summary_vector_id=added_doc_id[0],
|
||||
)
|
||||
else:
|
||||
new_pg_doc = Documents(
|
||||
title=summary.metadata['filename'],
|
||||
document_metadata=stringify(summary.metadata),
|
||||
page_content=documents[i].page_content,
|
||||
file_type=summary.metadata['filetype'],
|
||||
summary_vector_id=added_doc_id[0],
|
||||
)
|
||||
|
||||
# Store it in PG
|
||||
search_space_instance.documents.append(new_pg_doc)
|
||||
db.commit()
|
||||
|
||||
# Semantic chunking for better contextual compression
|
||||
text_splitter = SemanticChunker(embeddings=self.embeddings)
|
||||
chunks = text_splitter.split_documents([documents[i]])
|
||||
|
||||
# Update metadata for detailed chunks
|
||||
for chunk in chunks:
|
||||
chunk.metadata.update({
|
||||
"summary": False,
|
||||
"page": summary.metadata['page'],
|
||||
})
|
||||
|
||||
if(files_type == 'WEBPAGE'):
|
||||
ieee_content = (
|
||||
f"=======================================DOCUMENT METADATA==================================== \n"
|
||||
f"Source: {chunk.metadata['VisitedWebPageURL']} \n"
|
||||
f"Title: {chunk.metadata['VisitedWebPageTitle']} \n"
|
||||
f"Visited Date and Time : {chunk.metadata['VisitedWebPageDateWithTimeInISOString']} \n"
|
||||
f"============================DOCUMENT PAGE CONTENT CHUNK===================================== \n"
|
||||
f"Page Content Chunk: \n\n{chunk.page_content}\n\n"
|
||||
f"===================================================================================== \n"
|
||||
)
|
||||
|
||||
else:
|
||||
ieee_content = (
|
||||
f"=======================================DOCUMENT METADATA==================================== \n"
|
||||
f"Source: {chunk.metadata['filename']} \n"
|
||||
f"Title: {chunk.metadata['filename']} \n"
|
||||
f"Visited Date and Time : {datetime.now()} \n"
|
||||
f"============================DOCUMENT PAGE CONTENT CHUNK===================================== \n"
|
||||
f"Page Content Chunk: \n\n{chunk.page_content}\n\n"
|
||||
f"===================================================================================== \n"
|
||||
)
|
||||
|
||||
chunk.page_content = ieee_content
|
||||
|
||||
detailed_chunks.extend(chunks)
|
||||
|
||||
#update vector stores
|
||||
self.detailed_store.add_documents(filter_complex_metadata(detailed_chunks))
|
||||
|
||||
return self.summary_store, self.detailed_store
|
||||
|
||||
def delete_vector_stores(self, summary_ids_to_delete: list[str], db: Session = Depends(get_db)):
|
||||
self.summary_store.delete(ids=summary_ids_to_delete)
|
||||
return "success"
|
||||
|
||||
def summary_vector_search(self,query, search_space='GENERAL'):
|
||||
top_summaries_compressor = FlashrankRerank(top_n=20)
|
||||
|
||||
top_summaries_retreiver = ContextualCompressionRetriever(
|
||||
base_compressor=top_summaries_compressor, base_retriever=self.summary_store.as_retriever(search_kwargs={'filter': {'search_space': search_space}})
|
||||
)
|
||||
|
||||
return top_summaries_retreiver.invoke(query)
|
||||
|
||||
def deduplicate_references_and_update_answer(self, answer: str, references: List[Reference]) -> tuple[str, List[Reference]]:
|
||||
"""
|
||||
Deduplicates references and updates the answer text to maintain correct reference numbering.
|
||||
|
||||
Args:
|
||||
answer: The text containing reference citations
|
||||
references: List of Reference objects
|
||||
|
||||
Returns:
|
||||
tuple: (updated_answer, deduplicated_references)
|
||||
"""
|
||||
# Track unique references and create ID mapping using a dictionary comprehension
|
||||
unique_refs = {}
|
||||
id_mapping = {
|
||||
ref.id: unique_refs.setdefault(
|
||||
ref.source, Reference(id=str(len(unique_refs) + 1), title=ref.title, source=ref.source)
|
||||
).id
|
||||
for ref in references
|
||||
}
|
||||
|
||||
# Apply new mappings to the answer text
|
||||
updated_answer = answer
|
||||
for old_id, new_id in sorted(id_mapping.items(), key=lambda x: len(x[0]), reverse=True):
|
||||
updated_answer = updated_answer.replace(f'[{old_id}]', f'[{new_id}]')
|
||||
|
||||
return updated_answer, list(unique_refs.values())
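    # Illustrative usage (not part of this commit; "indices" is a hypothetical HIndices
    # instance and the URLs are placeholders): two references that share a source collapse
    # into one entry and the in-text markers are renumbered to match.
    #
    #   refs = [Reference(id="1", title="A", source="https://a.example"),
    #           Reference(id="2", title="A", source="https://a.example"),
    #           Reference(id="3", title="B", source="https://b.example")]
    #   answer = "Claim one [1]. Claim two [2]. Claim three [3]."
    #   new_answer, new_refs = indices.deduplicate_references_and_update_answer(answer, refs)
    #   # new_answer -> "Claim one [1]. Claim two [1]. Claim three [2]."
    #   # new_refs   -> two Reference objects: id "1" (https://a.example) and id "2" (https://b.example)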
|
||||
|
||||
async def ws_get_vectorstore_report(self, query: str, report_type: str, report_source: str, documents: List[Document],websocket: WebSocket) -> str:
|
||||
researcher = GPTResearcher(query=query, report_type=report_type, report_source=report_source, documents=documents, report_format="APA",websocket=websocket)
|
||||
await researcher.conduct_research()
|
||||
report = await researcher.write_report()
|
||||
return report
|
||||
|
||||
async def ws_get_web_report(self, query: str, report_type: str, report_source: str, websocket: WebSocket) -> str:
|
||||
researcher = GPTResearcher(query=query, report_type=report_type, report_source=report_source, report_format="APA",websocket=websocket)
|
||||
await researcher.conduct_research()
|
||||
report = await researcher.write_report()
|
||||
return report
|
||||
|
||||
async def ws_experimental_search(self, websocket: WebSocket, manager: ConnectionManager , query, search_space='GENERAL', report_type = "custom_report", report_source = "langchain_documents"):
|
||||
custom_prompt = """
|
||||
Please answer the following user query using only the **Document Page Content** provided below, while citing sources exclusively from the **Document Metadata** section, in the format shown. **Do not add any external information.**
|
||||
|
||||
**USER QUERY:** """ + query + """
|
||||
|
||||
**Answer Requirements:**
|
||||
- Provide a detailed long response using IEEE-style in-text citations (e.g., [1], [2]) based solely on the **Document Page Content**.
|
||||
- Use **Document Metadata** only for citation details and format each reference exactly once, with no duplicates.
|
||||
- Structure references in this format at the end of your response, using this format: (Access Date and Time). [Title or Filename](Source)
|
||||
|
||||
FOR EXAMPLE:
|
||||
EXAMPLE User Query : Explain the impact of artificial intelligence on modern healthcare.
|
||||
|
||||
EXAMPLE Given Documents:
|
||||
=======================================DOCUMENT METADATA==================================== \n"
|
||||
Source: https://www.reddit.com/r/ChatGPT/comments/13na8yp/highly_effective_prompt_for_summarizing_gpt4/ \n
|
||||
Title: Artificial intelligence\n
|
||||
Visited Date and Time : 2024-10-23T22:44:03-07:00 \n
|
||||
============================DOCUMENT PAGE CONTENT CHUNK===================================== \n
|
||||
Page Content Chunk: \n\nArtificial intelligence (AI) has significantly transformed modern healthcare by enhancing diagnostic accuracy, personalizing patient care, and optimizing operational efficiency. AI algorithms can analyze vast datasets to identify patterns that may be missed by human practitioners, leading to improved diagnostic outcomes. \n\n
|
||||
===================================================================================== \n
|
||||
|
||||
|
||||
=======================================DOCUMENT METADATA==================================== \n"
|
||||
Source: https://github.com/MODSetter/SurfSense \n
|
||||
Title: MODSetter/SurfSense: Personal AI Assistant for Internet Surfers and Researchers. \n
|
||||
Visited Date and Time : 2024-10-23T22:44:03-07:00 \n
|
||||
============================DOCUMENT PAGE CONTENT CHUNK===================================== \n
|
||||
Page Content Chunk: \n\nAI systems have been deployed in radiology to detect anomalies in medical imaging with high precision, reducing the risk of misdiagnosis and improving patient outcomes. Additionally, AI-powered chatbots and virtual assistants are being used to provide 24/7 support, answer queries, and offer personalized health advice\n\n
|
||||
===================================================================================== \n
|
||||
|
||||
|
||||
=======================================DOCUMENT METADATA==================================== \n"
|
||||
Source: https://github.com/MODSetter/SurfSense \n
|
||||
Title: MODSetter/SurfSense: Personal AI Assistant for Internet Surfers and Researchers. \n
|
||||
Visited Date and Time : 2024-10-23T22:44:03-07:00 \n
|
||||
============================DOCUMENT PAGE CONTENT CHUNK===================================== \n
|
||||
Page Content Chunk: \n\nAI algorithms can analyze a patient's genetic information to predict their risk of certain diseases and recommend tailored treatment plans. \n\n
|
||||
===================================================================================== \n
|
||||
|
||||
|
||||
=======================================DOCUMENT METADATA==================================== \n"
|
||||
Source: filename.pdf \n
|
||||
============================DOCUMENT PAGE CONTENT CHUNK===================================== \n
|
||||
Page Content Chunk: \n\nApart from diagnostics, AI-driven tools facilitate personalized treatment plans by considering individual patient data, thereby improving patient outcomes\n\n
|
||||
===================================================================================== \n
|
||||
|
||||
|
||||
|
||||
Ensure your response is structured something like this:
|
||||
**OUTPUT FORMAT:**
|
||||
|
||||
---
|
||||
|
||||
**Answer:**
|
||||
Artificial intelligence (AI) has significantly transformed modern healthcare by enhancing diagnostic accuracy, personalizing patient care, and optimizing operational efficiency. AI algorithms can analyze vast datasets to identify patterns that may be missed by human practitioners, leading to improved diagnostic outcomes [1]. For instance, AI systems have been deployed in radiology to detect anomalies in medical imaging with high precision [2]. Moreover, AI-driven tools facilitate personalized treatment plans by considering individual patient data, thereby improving patient outcomes [3].
|
||||
|
||||
**References:**
|
||||
1. (2024, October 23). [Artificial intelligence — GPT-4 Optimized: r/ChatGPT](https://www.reddit.com/r/ChatGPT/comments/13na8yp/highly_effective_prompt_for_summarizing_gpt4)
|
||||
2. (2024, October 23). [MODSetter/SurfSense: Personal AI Assistant for Internet Surfers and Researchers](https://github.com/MODSetter/SurfSense)
|
||||
3. (2024, October 23). [filename.pdf](filename.pdf)
|
||||
|
||||
---
|
||||
|
||||
"""
|
||||
|
||||
structured_llm = self.llm.with_structured_output(AIAnswer)
|
||||
|
||||
if report_source == "web" :
|
||||
if report_type == "custom_report" :
|
||||
ret_report = await self.ws_get_web_report(query=custom_prompt, report_type=report_type, report_source="web", websocket=websocket)
|
||||
else:
|
||||
ret_report = await self.ws_get_web_report(
|
||||
query=query,
|
||||
report_type=report_type,
|
||||
report_source="web",
|
||||
websocket=websocket
|
||||
)
|
||||
await manager.send_personal_message(
|
||||
json.dumps({"type": "stream", "content": "Converting to IEEE format..."}),
|
||||
websocket
|
||||
)
|
||||
ret_report = self.llm.invoke("I have a report written in APA format. Please convert it to IEEE format, ensuring that all citations, references, headings, and overall formatting adhere to the IEEE style guidelines. Maintain the original content and structure while applying the correct IEEE formatting rules. Just return the converted report thats it. NOW MY REPORT : " + ret_report).content
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
for chuck in structured_llm.stream(
|
||||
"Please extract and separate the references from the main text. "
|
||||
"References are formatted as follows:"
|
||||
"[Reference Id]. (Access Date and Time). [Title or Filename](Source or URL). "
|
||||
"Provide the text and references as distinct outputs. "
|
||||
"IMPORTANT : Never hallucinate the references. If there is no reference just return nothing in the reference field."
|
||||
"Here is the content to process: \n\n\n" + ret_report):
|
||||
# ans, sources = self.deduplicate_references_and_update_answer(answer=chuck.answer, references=chuck.references)
|
||||
|
||||
await manager.send_personal_message(
|
||||
json.dumps({"type": "stream", "sources": [source.model_dump() for source in chuck.references]}),
|
||||
websocket
|
||||
)
|
||||
|
||||
await manager.send_personal_message(
|
||||
json.dumps({"type": "stream", "content": ret_report}),
|
||||
websocket
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
|
||||
contextdocs = []
|
||||
top_summaries_compressor = FlashrankRerank(top_n=5)
|
||||
details_compressor = FlashrankRerank(top_n=50)
|
||||
top_summaries_retreiver = ContextualCompressionRetriever(
|
||||
base_compressor=top_summaries_compressor, base_retriever=self.summary_store.as_retriever(search_kwargs={'filter': {'search_space': search_space}})#
|
||||
)
|
||||
|
||||
top_summaries_compressed_docs = top_summaries_retreiver.invoke(query)
|
||||
|
||||
rel_docs = filter_complex_metadata(top_summaries_compressed_docs)
|
||||
|
||||
await manager.send_personal_message(
|
||||
json.dumps({"type": "stream", "relateddocs": [relateddoc.model_dump() for relateddoc in rel_docs]}, cls=NumpyEncoder),
|
||||
websocket
|
||||
)
|
||||
|
||||
for summary in top_summaries_compressed_docs:
|
||||
# For each summary, retrieve relevant detailed chunks
|
||||
page_number = summary.metadata["page"]
|
||||
|
||||
detailed_compression_retriever = ContextualCompressionRetriever(
|
||||
base_compressor=details_compressor, base_retriever=self.detailed_store.as_retriever(search_kwargs={'filter': {'page': page_number}})
|
||||
)
|
||||
|
||||
detailed_compressed_docs = detailed_compression_retriever.invoke(
|
||||
query
|
||||
)
|
||||
|
||||
contextdocs.extend(detailed_compressed_docs)
|
||||
|
||||
|
||||
|
||||
|
||||
# local_report = asyncio.run(self.get_vectorstore_report(query=custom_prompt, report_type=report_type, report_source=report_source, documents=contextdocs))
|
||||
if report_source == "langchain_documents" :
|
||||
if report_type == "custom_report" :
|
||||
ret_report = await self.ws_get_vectorstore_report(query=custom_prompt, report_type=report_type, report_source=report_source, documents=contextdocs, websocket=websocket)
|
||||
else:
|
||||
ret_report = await self.ws_get_vectorstore_report(query=query, report_type=report_type, report_source=report_source, documents=contextdocs, websocket=websocket)
|
||||
await manager.send_personal_message(
|
||||
json.dumps({"type": "stream", "content": "Converting to IEEE format..."}),
|
||||
websocket
|
||||
)
|
||||
ret_report = self.llm.invoke("I have a report written in APA format. Please convert it to IEEE format, ensuring that all citations, references, headings, and overall formatting adhere to the IEEE style guidelines. Maintain the original content and structure while applying the correct IEEE formatting rules. Just return the converted report thats it. NOW MY REPORT : " + ret_report).content
|
||||
|
||||
|
||||
for chuck in structured_llm.stream(
|
||||
"Please extract and separate the references from the main text. "
|
||||
"References are formatted as follows:"
|
||||
"[Reference Id]. (Access Date and Time). [Title or Filename](Source or URL). "
|
||||
"Provide the text and references as distinct outputs. "
|
||||
"Ensure that in-text citation numbers such as [1], [2], (1), (2), etc., as well as in-text links or in-text citation links within the content, remain unaltered and are accurately extracted."
|
||||
"IMPORTANT : Never hallucinate the references. If there is no reference just return nothing in the reference field."
|
||||
"Here is the content to process: \n\n\n" + ret_report):
|
||||
ans, sources = self.deduplicate_references_and_update_answer(answer=chuck.answer, references=chuck.references)
|
||||
|
||||
await manager.send_personal_message(
|
||||
json.dumps({"type": "stream", "sources": [source.model_dump() for source in sources]}),
|
||||
websocket
|
||||
)
|
||||
|
||||
await manager.send_personal_message(
|
||||
json.dumps({"type": "stream", "content": ans}),
|
||||
websocket
|
||||
)
|
||||
|
||||
|
||||
|
||||
return
|
||||
|
||||
|
|
@@ -1,27 +0,0 @@
from collections import OrderedDict

try:
    from collections.abc import Mapping, Sequence
except ImportError:
    from collections import Mapping, Sequence

import json

COMPACT_SEPARATORS = (',', ':')

def order_by_key(kv):
    key, val = kv
    return key

def recursive_order(node):
    if isinstance(node, Mapping):
        ordered_mapping = OrderedDict(sorted(node.items(), key=order_by_key))
        for key, value in ordered_mapping.items():
            ordered_mapping[key] = recursive_order(value)
        return ordered_mapping
    elif isinstance(node, Sequence) and not isinstance(node, (str, bytes)):
        return [recursive_order(item) for item in node]
    return node

def stringify(node):
    return json.dumps(recursive_order(node), separators=COMPACT_SEPARATORS)
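
A quick usage sketch (illustrative, not part of the commit), assuming the module is importable as Utils.stringify as the other backend files in this diff do: the helper emits key-sorted, compact JSON, so the same metadata dict always serializes to the same string.

    from Utils.stringify import stringify

    meta = {"page": 3, "summary": True, "search_space": "GENERAL"}
    print(stringify(meta))
    # -> {"page":3,"search_space":"GENERAL","summary":true}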
|
@@ -1,16 +0,0 @@
from sqlalchemy import create_engine, Column, Integer, String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
import os
from dotenv import load_dotenv
load_dotenv()

POSTGRES_DATABASE_URL = os.environ.get("POSTGRES_DATABASE_URL")

engine = create_engine(
    POSTGRES_DATABASE_URL
)

SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

Base = declarative_base()
|
@ -1,80 +0,0 @@
|
|||
from datetime import datetime
|
||||
# from typing import List
|
||||
from database import Base, engine
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Boolean, create_engine
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
class BaseModel(Base):
|
||||
__abstract__ = True
|
||||
__allow_unmapped__ = True
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
|
||||
|
||||
class Chat(BaseModel):
|
||||
__tablename__ = "chats"
|
||||
|
||||
type = Column(String)
|
||||
title = Column(String)
|
||||
chats_list = Column(String)
|
||||
|
||||
search_space_id = Column(Integer, ForeignKey('searchspaces.id'))
|
||||
search_space = relationship('SearchSpace', back_populates='chats')
|
||||
|
||||
|
||||
class Documents(BaseModel):
|
||||
__tablename__ = "documents"
|
||||
|
||||
title = Column(String)
|
||||
|
||||
created_at = Column(DateTime, default=datetime.now)
|
||||
file_type = Column(String)
|
||||
document_metadata = Column(String)
|
||||
page_content = Column(String)
|
||||
|
||||
summary_vector_id = Column(String)
|
||||
|
||||
search_space_id = Column(Integer, ForeignKey("searchspaces.id"))
|
||||
search_space = relationship("SearchSpace", back_populates="documents")
|
||||
|
||||
|
||||
class Podcast(BaseModel):
|
||||
__tablename__ = "podcasts"
|
||||
|
||||
title = Column(String)
|
||||
created_at = Column(DateTime, default=datetime.now)
|
||||
is_generated = Column(Boolean, default=False)
|
||||
podcast_content = Column(String, default="")
|
||||
file_location = Column(String, default="")
|
||||
|
||||
search_space_id = Column(Integer, ForeignKey("searchspaces.id"))
|
||||
search_space = relationship("SearchSpace", back_populates="podcasts")
|
||||
|
||||
|
||||
|
||||
class SearchSpace(BaseModel):
|
||||
__tablename__ = "searchspaces"
|
||||
|
||||
name = Column(String, index=True)
|
||||
description = Column(String)
|
||||
created_at = Column(DateTime, default=datetime.now)
|
||||
|
||||
user_id = Column(Integer, ForeignKey("users.id"))
|
||||
user = relationship("User", back_populates="search_spaces")
|
||||
|
||||
documents = relationship("Documents", back_populates="search_space", order_by="Documents.id")
|
||||
podcasts = relationship("Podcast", back_populates="search_space", order_by="Podcast.id")
|
||||
|
||||
chats = relationship('Chat', back_populates='search_space', order_by='Chat.id')
|
||||
|
||||
class User(BaseModel):
|
||||
__tablename__ = "users"
|
||||
|
||||
username = Column(String, unique=True, index=True)
|
||||
hashed_password = Column(String)
|
||||
|
||||
search_spaces = relationship("SearchSpace", back_populates="user")
|
||||
|
||||
|
||||
# Create the database tables if they don't exist
|
||||
User.metadata.create_all(bind=engine)
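
A minimal sketch of how these models are wired together (illustrative only; it assumes SessionLocal from database.py, a reachable Postgres database at POSTGRES_DATABASE_URL, and placeholder values such as "alice" and "example.pdf"):

    from database import SessionLocal
    from models import User, SearchSpace, Documents

    db = SessionLocal()
    user = db.query(User).filter(User.username == "alice").first()   # "alice" is a placeholder

    space = SearchSpace(user_id=user.id, name="GENERAL", description="Default space")
    db.add(space)

    doc = Documents(title="example.pdf", file_type="PDF",
                    document_metadata="{}", page_content="...", summary_vector_id="0")
    space.documents.append(doc)   # relationship back-populates Documents.search_space

    db.commit()
    db.close()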
|
|
@ -1,70 +0,0 @@
|
|||
# TODO: move the new prompts here after some more testing
|
||||
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from datetime import datetime, timezone
|
||||
|
||||
DATE_TODAY = "Today's date is " + datetime.now(timezone.utc).astimezone().isoformat() + '\n'
|
||||
|
||||
|
||||
report_template = """
|
||||
You are an eagle-eyed researcher, skilled at summarizing lengthy documents with precision and clarity. Create a comprehensive summary of the provided document, capturing the main ideas, key details, and essential arguments presented.
|
||||
|
||||
Length and depth:
|
||||
- Produce a detailed summary that captures all essential content from the document. Adjust the length as needed to ensure no critical information is omitted.
|
||||
|
||||
Structure:
|
||||
- Organize the summary logically.
|
||||
- Use clear headings and subheadings for different sections or themes, to help convey the flow of ideas.
|
||||
|
||||
Content to Include:
|
||||
- Highlight the main arguments.
|
||||
- Identify and include key supporting details.
|
||||
- Incorporate relevant examples or data that strengthen the key points.
|
||||
|
||||
Tone:
|
||||
- Use an objective, neutral tone, delivering precise and insightful analysis without personal opinions or interpretations.
|
||||
|
||||
# Steps
|
||||
|
||||
1. **Thoroughly read the entire text** to grasp the author's perspective, main arguments, and overall structure.
|
||||
2. **Identify key sections or themes** and thematically group the information.
|
||||
3. **Extract main points** from each section to capture supporting details and relevant examples.
|
||||
4. **Use headings/subheadings** to provide a clear and logically organized structure.
|
||||
5. **Write a conclusion** that succinctly encapsulates the overarching message and significance of the text.
|
||||
|
||||
# Output Format
|
||||
|
||||
- Provide a summary in well-structured paragraphs.
|
||||
- Clearly delineate different themes or sections with suitable headings or sub-headings.
|
||||
- Adjust the length of the summary based on the content's complexity and depth.
|
||||
- Conclusions should be clearly marked.
|
||||
|
||||
# Example
|
||||
|
||||
**Heading 1: Introduction to Main Theme**
|
||||
The document begins by discussing [main idea], outlining [initial point] with supporting data like [example].
|
||||
|
||||
**Heading 2: Supporting Arguments**
|
||||
The text then presents several supporting arguments, such as [supporting detail]. Notably, [data or statistic] is used to reinforce the main concept.
|
||||
|
||||
**Heading 3: Conclusion**
|
||||
In summary, [document's conclusion statement], highlighting the broader implications like [significance].
|
||||
|
||||
(This is an example format; each section should be expanded comprehensively based on the provided document.)
|
||||
|
||||
# Notes
|
||||
- Ensure the summary is adequately comprehensive without omitting crucial parts.
|
||||
- Aim for readability by using formal yet accessible language, maintaining depth without unnecessary complexity.
|
||||
|
||||
Now, Please summarize the following document:
|
||||
|
||||
<document>
|
||||
{document}
|
||||
</document>
|
||||
"""
|
||||
|
||||
|
||||
report_prompt = PromptTemplate(
|
||||
input_variables=["document"],
|
||||
template=report_template
|
||||
)
|
|
@ -1,95 +0,0 @@
|
|||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional
|
||||
|
||||
class UserCreate(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
apisecretkey: str
|
||||
|
||||
class DocMeta(BaseModel):
|
||||
BrowsingSessionId: Optional[str] = Field(default=None, description="BrowsingSessionId of Document")
|
||||
VisitedWebPageURL: Optional[str] = Field(default=None, description="VisitedWebPageURL of Document")
|
||||
VisitedWebPageTitle: Optional[str] = Field(default=None, description="VisitedWebPageTitle of Document")
|
||||
VisitedWebPageDateWithTimeInISOString: Optional[str] = Field(default=None, description="VisitedWebPageDateWithTimeInISOString of Document")
|
||||
VisitedWebPageReffererURL: Optional[str] = Field(default=None, description="VisitedWebPageReffererURL of Document")
|
||||
VisitedWebPageVisitDurationInMilliseconds: Optional[int] = Field(default=None, description="VisitedWebPageVisitDurationInMilliseconds of Document")
|
||||
|
||||
class CreatePodcast(BaseModel):
|
||||
token: str
|
||||
search_space_id: int
|
||||
title: str
|
||||
wordcount: int
|
||||
podcast_content: str
|
||||
|
||||
class CreateStorageSpace(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
token : str
|
||||
|
||||
class Reference(BaseModel):
|
||||
id: str = Field(..., description="reference no")
|
||||
title: str = Field(..., description="reference title.")
|
||||
source: str = Field(..., description="reference Source or URL. Prefer URL only include file names if no URL available.")
|
||||
|
||||
class AIAnswer(BaseModel):
|
||||
answer: str = Field(..., description="The provided answer, excluding references, but including in-text citation numbers such as [1], [2], (1), (2), etc.")
|
||||
references: List[Reference] = Field(..., description="References")
|
||||
|
||||
class DocWithContent(BaseModel):
|
||||
DocMetadata: Optional[str] = Field(default=None, description="Document Metadata")
|
||||
Content: Optional[str] = Field(default=None, description="Document Page Content")
|
||||
|
||||
class DocumentsToDelete(BaseModel):
|
||||
ids_to_delete: List[str]
|
||||
token: str
|
||||
|
||||
class UserQuery(BaseModel):
|
||||
query: str
|
||||
search_space: str
|
||||
token: str
|
||||
|
||||
class MainUserQuery(BaseModel):
|
||||
query: str
|
||||
search_space: str
|
||||
token: str
|
||||
|
||||
class ChatHistory(BaseModel):
|
||||
type: str
|
||||
content: str | List[DocMeta] | List[str]
|
||||
|
||||
class UserQueryWithChatHistory(BaseModel):
|
||||
chat: List[ChatHistory]
|
||||
query: str
|
||||
token: str
|
||||
|
||||
class DescriptionResponse(BaseModel):
|
||||
response: str
|
||||
|
||||
class RetrivedDocListItem(BaseModel):
|
||||
metadata: DocMeta
|
||||
pageContent: str
|
||||
|
||||
class RetrivedDocList(BaseModel):
|
||||
documents: List[RetrivedDocListItem]
|
||||
search_space_id: int
|
||||
token: str
|
||||
|
||||
class UserQueryResponse(BaseModel):
|
||||
response: str
|
||||
relateddocs: List[DocWithContent]
|
||||
|
||||
class NewUserQueryResponse(BaseModel):
|
||||
response: str
|
||||
sources: List[Reference]
|
||||
relateddocs: List[DocWithContent]
|
||||
|
||||
class NewUserChat(BaseModel):
|
||||
token: str
|
||||
type: str
|
||||
title: str
|
||||
chats_list: str
|
||||
|
||||
class ChatToUpdate(BaseModel):
|
||||
chatid: str
|
||||
token: str
|
||||
chats_list: str
|
|
@ -1,275 +0,0 @@
|
|||
aiofiles==24.1.0
|
||||
aiohappyeyeballs==2.4.3
|
||||
aiohttp==3.10.10
|
||||
aiosignal==1.3.1
|
||||
alabaster==1.0.0
|
||||
annotated-types==0.7.0
|
||||
anyio==4.6.2.post1
|
||||
arxiv==2.1.3
|
||||
asgiref==3.8.1
|
||||
attrs==24.2.0
|
||||
babel==2.16.0
|
||||
backoff==2.2.1
|
||||
bcrypt==4.2.0
|
||||
beautifulsoup4==4.12.3
|
||||
bleach==6.2.0
|
||||
Brotli==1.1.0
|
||||
build==1.2.2.post1
|
||||
cachetools==5.5.0
|
||||
certifi==2024.8.30
|
||||
cffi==1.17.1
|
||||
chardet==5.2.0
|
||||
charset-normalizer==3.4.0
|
||||
chroma-hnswlib==0.7.6
|
||||
chromadb==0.5.18
|
||||
click==8.1.7
|
||||
colorama==0.4.6
|
||||
coloredlogs==15.0.1
|
||||
cryptography==43.0.3
|
||||
cssselect2==0.7.0
|
||||
Cython==3.0.11
|
||||
dataclasses-json==0.6.7
|
||||
deepdiff==8.0.1
|
||||
defusedxml==0.7.1
|
||||
Deprecated==1.2.14
|
||||
distro==1.9.0
|
||||
docopt==0.6.2
|
||||
docstring_parser==0.16
|
||||
docutils==0.21.2
|
||||
durationpy==0.9
|
||||
ecdsa==0.19.0
|
||||
edge-tts==6.1.18
|
||||
elevenlabs==1.12.1
|
||||
emoji==2.14.0
|
||||
execnet==2.1.1
|
||||
fastapi==0.115.5
|
||||
fastjsonschema==2.20.0
|
||||
feedparser==6.0.11
|
||||
ffmpeg==1.4
|
||||
filelock==3.16.1
|
||||
filetype==1.2.0
|
||||
FlashRank==0.2.9
|
||||
flatbuffers==24.3.25
|
||||
fonttools==4.54.1
|
||||
frozenlist==1.5.0
|
||||
fsspec==2024.10.0
|
||||
fuzzywuzzy==0.18.0
|
||||
google-ai-generativelanguage==0.6.10
|
||||
google-api-core==2.23.0
|
||||
google-api-python-client==2.152.0
|
||||
google-auth==2.36.0
|
||||
google-auth-httplib2==0.2.0
|
||||
google-cloud-aiplatform==1.72.0
|
||||
google-cloud-bigquery==3.27.0
|
||||
google-cloud-core==2.4.1
|
||||
google-cloud-resource-manager==1.13.0
|
||||
google-cloud-storage==2.18.2
|
||||
google-cloud-texttospeech==2.21.0
|
||||
google-crc32c==1.6.0
|
||||
google-generativeai==0.8.3
|
||||
google-resumable-media==2.7.2
|
||||
googleapis-common-protos==1.66.0
|
||||
gpt-researcher==0.10.3
|
||||
greenlet==3.1.1
|
||||
grpc-google-iam-v1==0.13.1
|
||||
grpcio==1.67.1
|
||||
grpcio-status==1.67.1
|
||||
h11==0.14.0
|
||||
html5lib==1.1
|
||||
htmldocx==0.0.6
|
||||
httpcore==1.0.6
|
||||
httplib2==0.22.0
|
||||
httptools==0.6.4
|
||||
httpx==0.27.2
|
||||
httpx-sse==0.4.0
|
||||
huggingface-hub==0.26.2
|
||||
humanfriendly==10.0
|
||||
idna==3.10
|
||||
imagesize==1.4.1
|
||||
importlib_metadata==8.5.0
|
||||
importlib_resources==6.4.5
|
||||
iniconfig==2.0.0
|
||||
Jinja2==3.1.4
|
||||
jiter==0.7.1
|
||||
joblib==1.4.2
|
||||
json5==0.9.28
|
||||
json_repair==0.30.1
|
||||
jsonpatch==1.33
|
||||
jsonpath-python==1.0.6
|
||||
jsonpointer==3.0.0
|
||||
jsonschema==4.23.0
|
||||
jsonschema-specifications==2024.10.1
|
||||
jupyter_client==8.6.3
|
||||
jupyter_core==5.7.2
|
||||
jupyterlab_pygments==0.3.0
|
||||
kubernetes==31.0.0
|
||||
langchain==0.3.7
|
||||
langchain-chroma==0.1.4
|
||||
langchain-community==0.3.7
|
||||
langchain-core==0.3.17
|
||||
langchain-experimental==0.3.3
|
||||
langchain-google-genai==2.0.4
|
||||
langchain-google-vertexai==2.0.7
|
||||
langchain-ollama==0.2.0
|
||||
langchain-openai==0.2.8
|
||||
langchain-text-splitters==0.3.2
|
||||
langchain-unstructured==0.1.5
|
||||
langdetect==1.0.9
|
||||
langgraph==0.2.46
|
||||
langgraph-checkpoint==2.0.3
|
||||
langgraph-cli==0.1.53
|
||||
langgraph-sdk==0.1.35
|
||||
langsmith==0.1.142
|
||||
Levenshtein==0.26.1
|
||||
litellm==1.52.6
|
||||
loguru==0.7.2
|
||||
lxml==5.3.0
|
||||
lxml_html_clean==0.3.1
|
||||
Markdown==3.7
|
||||
markdown-it-py==3.0.0
|
||||
markdown2==2.5.1
|
||||
MarkupSafe==3.0.2
|
||||
marshmallow==3.23.1
|
||||
md2pdf==1.0.1
|
||||
mdurl==0.1.2
|
||||
mistune==3.0.2
|
||||
mmh3==5.0.1
|
||||
monotonic==1.6
|
||||
mpmath==1.3.0
|
||||
msgpack==1.1.0
|
||||
multidict==6.1.0
|
||||
mypy-extensions==1.0.0
|
||||
nbclient==0.10.0
|
||||
nbconvert==7.16.4
|
||||
nbformat==5.10.4
|
||||
nbsphinx==0.9.5
|
||||
nest-asyncio==1.6.0
|
||||
nltk==3.9.1
|
||||
numpy==1.26.4
|
||||
oauthlib==3.2.2
|
||||
olefile==0.47
|
||||
ollama==0.3.3
|
||||
onnxruntime==1.20.0
|
||||
openai==1.54.4
|
||||
opentelemetry-api==1.28.1
|
||||
opentelemetry-exporter-otlp-proto-common==1.28.1
|
||||
opentelemetry-exporter-otlp-proto-grpc==1.28.1
|
||||
opentelemetry-instrumentation==0.49b1
|
||||
opentelemetry-instrumentation-asgi==0.49b1
|
||||
opentelemetry-instrumentation-fastapi==0.49b1
|
||||
opentelemetry-proto==1.28.1
|
||||
opentelemetry-sdk==1.28.1
|
||||
opentelemetry-semantic-conventions==0.49b1
|
||||
opentelemetry-util-http==0.49b1
|
||||
orderly-set==5.2.2
|
||||
orjson==3.10.11
|
||||
overrides==7.7.0
|
||||
packaging==24.2
|
||||
pandas==2.2.3
|
||||
pandoc==2.4
|
||||
pandocfilters==1.5.1
|
||||
passlib==1.7.4
|
||||
pillow==10.4.0
|
||||
platformdirs==4.3.6
|
||||
pluggy==1.5.0
|
||||
plumbum==1.9.0
|
||||
ply==3.11
|
||||
podcastfy==0.3.5
|
||||
posthog==3.7.0
|
||||
propcache==0.2.0
|
||||
proto-plus==1.25.0
|
||||
protobuf==5.28.3
|
||||
psutil==6.1.0
|
||||
psycopg2==2.9.10
|
||||
pyasn1==0.6.1
|
||||
pyasn1_modules==0.4.1
|
||||
pycparser==2.22
|
||||
pydantic==2.9.2
|
||||
pydantic-settings==2.6.1
|
||||
pydantic_core==2.23.4
|
||||
pydub==0.25.1
|
||||
pydyf==0.11.0
|
||||
Pygments==2.18.0
|
||||
PyMuPDF==1.24.13
|
||||
pyparsing==3.2.0
|
||||
pypdf==5.1.0
|
||||
pyphen==0.15.0
|
||||
PyPika==0.48.9
|
||||
pyproject_hooks==1.2.0
|
||||
pyreadline3==3.5.4
|
||||
pytest==8.3.3
|
||||
pytest-xdist==3.6.1
|
||||
python-dateutil==2.8.2
|
||||
python-docx==1.1.2
|
||||
python-dotenv==1.0.1
|
||||
python-iso639==2024.10.22
|
||||
python-jose==3.3.0
|
||||
python-Levenshtein==0.26.1
|
||||
python-magic==0.4.27
|
||||
python-multipart==0.0.17
|
||||
python-oxmsg==0.0.1
|
||||
pytz==2024.2
|
||||
pywin32==308
|
||||
PyYAML==6.0.2
|
||||
pyzmq==26.2.0
|
||||
RapidFuzz==3.10.1
|
||||
referencing==0.35.1
|
||||
regex==2024.11.6
|
||||
requests==2.32.3
|
||||
requests-oauthlib==2.0.0
|
||||
requests-toolbelt==1.0.0
|
||||
rich==13.9.4
|
||||
rpds-py==0.21.0
|
||||
rsa==4.9
|
||||
setuptools==75.4.0
|
||||
sgmllib3k==1.0.0
|
||||
shapely==2.0.6
|
||||
shellingham==1.5.4
|
||||
six==1.16.0
|
||||
sniffio==1.3.1
|
||||
snowballstemmer==2.2.0
|
||||
soupsieve==2.6
|
||||
Sphinx==8.1.3
|
||||
sphinx-autodoc-typehints==2.5.0
|
||||
sphinx-rtd-theme==3.0.1
|
||||
sphinxcontrib-applehelp==2.0.0
|
||||
sphinxcontrib-devhelp==2.0.0
|
||||
sphinxcontrib-htmlhelp==2.1.0
|
||||
sphinxcontrib-jquery==4.1
|
||||
sphinxcontrib-jsmath==1.0.1
|
||||
sphinxcontrib-qthelp==2.0.0
|
||||
sphinxcontrib-serializinghtml==2.0.0
|
||||
SQLAlchemy==2.0.35
|
||||
starlette==0.41.2
|
||||
sympy==1.13.3
|
||||
tenacity==9.0.0
|
||||
tiktoken==0.8.0
|
||||
tinycss2==1.4.0
|
||||
tinyhtml5==2.0.0
|
||||
tokenizers==0.20.3
|
||||
tornado==6.4.1
|
||||
tqdm==4.67.0
|
||||
traitlets==5.14.3
|
||||
typer==0.12.5
|
||||
types-PyYAML==6.0.12.20240917
|
||||
typing-inspect==0.9.0
|
||||
typing_extensions==4.12.2
|
||||
tzdata==2024.2
|
||||
unstructured==0.16.5
|
||||
unstructured-client==0.25.9
|
||||
uritemplate==4.1.1
|
||||
urllib3==2.2.3
|
||||
uvicorn==0.32.0
|
||||
watchfiles==0.24.0
|
||||
weasyprint==63.0
|
||||
webencodings==0.5.1
|
||||
websocket-client==1.8.0
|
||||
websockets==14.1
|
||||
wheel==0.44.0
|
||||
win32-setctime==1.1.0
|
||||
wrapt==1.16.0
|
||||
wsproto==1.2.0
|
||||
yarl==1.17.1
|
||||
youtube-transcript-api==0.6.2
|
||||
zipp==3.21.0
|
||||
zopfli==0.2.3.post1
|
|
@ -1,880 +0,0 @@
|
|||
from __future__ import annotations
|
||||
import json
|
||||
from typing import List
|
||||
|
||||
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.documents import Document
|
||||
from langchain_ollama import OllamaLLM
|
||||
from langchain_openai import ChatOpenAI
|
||||
from sqlalchemy import insert
|
||||
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
|
||||
from langchain_unstructured import UnstructuredLoader
|
||||
|
||||
# OUR LIBS
|
||||
from HIndices import ConnectionManager, HIndices
|
||||
from Utils.stringify import stringify
|
||||
from prompts import DATE_TODAY
|
||||
from pydmodels import ChatToUpdate, CreatePodcast, CreateStorageSpace, DescriptionResponse, DocWithContent, DocumentsToDelete, MainUserQuery, NewUserChat, NewUserQueryResponse, UserCreate, UserQuery, RetrivedDocList, UserQueryResponse, UserQueryWithChatHistory
|
||||
from podcastfy.client import generate_podcast
|
||||
|
||||
# Auth Libs
|
||||
from fastapi import FastAPI, Depends, Form, HTTPException, Response, WebSocket, status, UploadFile, BackgroundTasks
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||
from jose import JWTError, jwt
|
||||
from datetime import datetime, timedelta
|
||||
from passlib.context import CryptContext
|
||||
from models import Chat, Documents, Podcast, SearchSpace, User
|
||||
from database import SessionLocal
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
|
||||
SMART_LLM = os.environ.get("SMART_LLM")
|
||||
IS_LOCAL_SETUP = True if SMART_LLM.startswith("ollama") else False
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.environ.get("ACCESS_TOKEN_EXPIRE_MINUTES"))
|
||||
ALGORITHM = os.environ.get("ALGORITHM")
|
||||
API_SECRET_KEY = os.environ.get("API_SECRET_KEY")
|
||||
SECRET_KEY = os.environ.get("SECRET_KEY")
|
||||
UNSTRUCTURED_API_KEY = os.environ.get("UNSTRUCTURED_API_KEY")
|
||||
|
||||
def extract_model_name(model_string: str) -> str:
    # Split at the first colon and return only the model part, e.g. "openai:gpt-4o-mini" -> "gpt-4o-mini"
    part1, part2 = model_string.split(":", 1)
    return part2
|
||||
|
||||
MODEL_NAME = extract_model_name(SMART_LLM)
|
||||
|
||||
app = FastAPI()
|
||||
manager = ConnectionManager()
|
||||
|
||||
# Dependency
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
# Manual Origins
|
||||
# origins = [
|
||||
# "http://localhost:3000", # Adjust the port if your frontend runs on a different one
|
||||
# "https://yourfrontenddomain.com",
|
||||
# ]
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # Allows all origins
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"], # Allows all methods
|
||||
allow_headers=["*"], # Allows all headers
|
||||
)
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
def get_user_by_username(db: Session, username: str):
|
||||
return db.query(User).filter(User.username == username).first()
|
||||
|
||||
def create_user(db: Session, user: UserCreate):
|
||||
hashed_password = pwd_context.hash(user.password)
|
||||
db_user = User(username=user.username, hashed_password=hashed_password)
|
||||
db.add(db_user)
|
||||
db.commit()
|
||||
return "complete"
|
||||
|
||||
@app.post("/register")
|
||||
def register_user(user: UserCreate, db: Session = Depends(get_db)):
|
||||
if(user.apisecretkey != API_SECRET_KEY):
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
db_user = get_user_by_username(db, username=user.username)
|
||||
if db_user:
|
||||
raise HTTPException(status_code=400, detail="Username already registered")
|
||||
|
||||
del user.apisecretkey
|
||||
return create_user(db=db, user=user)
|
||||
|
||||
# Authenticate the user
|
||||
def authenticate_user(username: str, password: str, db: Session):
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if not user:
|
||||
return False
|
||||
if not pwd_context.verify(password, user.hashed_password):
|
||||
return False
|
||||
return user
|
||||
|
||||
# Create access token
|
||||
def create_access_token(data: dict, expires_delta: timedelta | None = None):
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(minutes=15)
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
@app.post("/token")
|
||||
def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
|
||||
user = authenticate_user(form_data.username, form_data.password, db)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect username or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
access_token = create_access_token(
|
||||
data={"sub": user.username}, expires_delta=access_token_expires
|
||||
)
|
||||
return {"access_token": access_token, "token_type": "bearer"}
|
||||
|
||||
def verify_token(token: str = Depends(oauth2_scheme)):
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
return payload
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
@app.get("/verify-token/{token}")
|
||||
async def verify_user_token(token: str):
|
||||
verify_token(token=token)
|
||||
return {"message": "Token is valid"}
|
||||
|
||||
|
||||
@app.post("/searchspace/{search_space_id}/chat/create")
|
||||
def create_chat_in_searchspace(chat: NewUserChat, search_space_id: int, db: Session = Depends(get_db)):
|
||||
try:
|
||||
payload = jwt.decode(chat.token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
search_space = db.query(SearchSpace).filter(
|
||||
SearchSpace.id == search_space_id,
|
||||
SearchSpace.user_id == db.query(User).filter(User.username == username).first().id
|
||||
).first()
|
||||
|
||||
if not search_space:
|
||||
raise HTTPException(status_code=404, detail="SearchSpace not found or does not belong to the user")
|
||||
|
||||
new_chat = Chat(type=chat.type, title=chat.title, chats_list=chat.chats_list)
|
||||
|
||||
search_space.chats.append(new_chat)
|
||||
|
||||
db.commit()
|
||||
db.refresh(new_chat)
|
||||
|
||||
return {"chat_id": new_chat.id}
|
||||
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
@app.post("/searchspace/{search_space_id}/chat/update")
|
||||
def update_chat_in_searchspace(chat: ChatToUpdate, search_space_id: int, db: Session = Depends(get_db)):
|
||||
try:
|
||||
payload = jwt.decode(chat.token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
chatindb = db.query(Chat).join(SearchSpace).filter(
|
||||
Chat.id == chat.chatid,
|
||||
SearchSpace.id == search_space_id,
|
||||
SearchSpace.user_id == db.query(User).filter(User.username == username).first().id
|
||||
).first()
|
||||
|
||||
if not chatindb:
|
||||
raise HTTPException(status_code=404, detail="Chat not found or does not belong to the searchspace owned by the user")
|
||||
|
||||
chatindb.chats_list = chat.chats_list
|
||||
db.commit()
|
||||
return {"message": "Chat Updated"}
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
@app.get("/searchspace/{search_space_id}/chat/delete/{token}/{chatid}")
|
||||
async def delete_chat_in_searchspace(token: str, search_space_id: int, chatid: str, db: Session = Depends(get_db)):
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
chatindb = db.query(Chat).join(SearchSpace).filter(
|
||||
Chat.id == chatid,
|
||||
SearchSpace.id == search_space_id,
|
||||
SearchSpace.user_id == db.query(User).filter(User.username == username).first().id
|
||||
).first()
|
||||
|
||||
if not chatindb:
|
||||
raise HTTPException(status_code=404, detail="Chat not found or does not belong to the searchspace owned by the user")
|
||||
|
||||
db.delete(chatindb)
|
||||
db.commit()
|
||||
return {"message": "Chat Deleted"}
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
@app.get("/searchspace/{search_space_id}/chat/{token}/{chatid}")
|
||||
def get_chat_by_id_in_searchspace(chatid: int, search_space_id: int, token: str, db: Session = Depends(get_db)):
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
chat = db.query(Chat).join(SearchSpace).filter(
|
||||
Chat.id == chatid,
|
||||
SearchSpace.id == search_space_id,
|
||||
SearchSpace.user_id == db.query(User).filter(User.username == username).first().id
|
||||
).first()
|
||||
|
||||
if not chat:
|
||||
raise HTTPException(status_code=404, detail="Chat not found or does not belong to the searchspace owned by the user")
|
||||
|
||||
return chat
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
@app.get("/searchspace/{search_space_id}/chats/{token}")
|
||||
def get_chats_in_searchspace(search_space_id: int, token: str, db: Session = Depends(get_db)):
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
# Filter chats that are specifically in the given search space
|
||||
chats = db.query(Chat).filter(
|
||||
Chat.search_space_id == search_space_id,
|
||||
SearchSpace.user_id == user.id
|
||||
).join(SearchSpace).all()
|
||||
|
||||
return chats
|
||||
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
|
||||
|
||||
|
||||
@app.get("/user/{token}/searchspace/{search_space_id}/documents/")
|
||||
def get_user_documents(search_space_id: int, token: str, db: Session = Depends(get_db)):
|
||||
try:
|
||||
# Decode the token to get the username
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
# Get the user by username and ensure they exist
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
# Verify the search space belongs to the user
|
||||
search_space = db.query(SearchSpace).filter(
|
||||
SearchSpace.id == search_space_id,
|
||||
SearchSpace.user_id == user.id
|
||||
).first()
|
||||
if not search_space:
|
||||
raise HTTPException(status_code=404, detail="Search space not found or does not belong to the user")
|
||||
|
||||
# Retrieve documents associated with the search space
|
||||
return db.query(Documents).filter(Documents.search_space_id == search_space_id).all()
|
||||
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
@app.get("/user/{token}/searchspace/{search_space_id}/")
|
||||
def get_user_search_space_by_id(search_space_id: int, token: str, db: Session = Depends(get_db)):
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
# Get the user by username
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
# Get the search space by ID and verify it belongs to this user
|
||||
search_space = db.query(SearchSpace).filter(
|
||||
SearchSpace.id == search_space_id,
|
||||
SearchSpace.user_id == user.id
|
||||
).first()
|
||||
if not search_space:
|
||||
raise HTTPException(status_code=404, detail="Search space not found or does not belong to the user")
|
||||
|
||||
return search_space
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
@app.get("/user/{token}/searchspaces/")
|
||||
def get_user_search_spaces(token: str, db: Session = Depends(get_db)):
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
|
||||
return db.query(SearchSpace).filter(SearchSpace.user_id == user.id).all()
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
@app.post("/user/create/searchspace/")
|
||||
def create_user_search_space(data: CreateStorageSpace, db: Session = Depends(get_db)):
|
||||
try:
|
||||
payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
|
||||
db_search_space = SearchSpace(user_id=user.id, name=data.name, description=data.description)
|
||||
db.add(db_search_space)
|
||||
db.commit()
|
||||
db.refresh(db_search_space)
|
||||
return db_search_space
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
@app.post("/user/save/")
|
||||
def save_user_extension_documents(data: RetrivedDocList, db: Session = Depends(get_db)):
|
||||
try:
|
||||
payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
# Get the user by username and ensure they exist
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
# Verify the search space belongs to the user
|
||||
search_space = db.query(SearchSpace).filter(
|
||||
SearchSpace.id == data.search_space_id,
|
||||
SearchSpace.user_id == user.id
|
||||
).first()
|
||||
if not search_space:
|
||||
raise HTTPException(status_code=404, detail="Search space not found or does not belong to the user")
|
||||
|
||||
|
||||
# all_search_space_docs = db.query(SearchSpace).filter(
|
||||
# SearchSpace.user_id == user.id
|
||||
# ).all()
|
||||
|
||||
# total_doc_count = 0
|
||||
# for search_space in all_search_space_docs:
|
||||
# total_doc_count += db.query(Documents).filter(Documents.search_space_id == search_space.id).count()
|
||||
|
||||
print(f"STARTED")
|
||||
|
||||
# Initialize containers for documents and entries
|
||||
# DocumentPgEntry = []
|
||||
raw_documents = []
|
||||
|
||||
# Process each document in the retrieved document list
|
||||
for doc in data.documents:
|
||||
# Construct document content
|
||||
content = (
|
||||
f"USER BROWSING SESSION EVENT: \n"
|
||||
f"=======================================METADATA==================================== \n"
|
||||
f"User Browsing Session ID : {doc.metadata.BrowsingSessionId} \n"
|
||||
f"User Visited website with url : {doc.metadata.VisitedWebPageURL} \n"
|
||||
f"This visited website url had title : {doc.metadata.VisitedWebPageTitle} \n"
|
||||
f"User Visited this website from referring url : {doc.metadata.VisitedWebPageReffererURL} \n"
|
||||
f"User Visited this website url at this Date and Time : {doc.metadata.VisitedWebPageDateWithTimeInISOString} \n"
|
||||
f"User Visited this website for : {str(doc.metadata.VisitedWebPageVisitDurationInMilliseconds)} milliseconds. \n"
|
||||
f"===================================================================================== \n"
|
||||
f"Webpage Content of the visited webpage url in markdown format : \n\n{doc.pageContent}\n\n"
|
||||
f"===================================================================================== \n"
|
||||
)
|
||||
raw_documents.append(Document(page_content=content, metadata=doc.metadata.__dict__))
|
||||
|
||||
|
||||
|
||||
|
||||
# pgdocmeta = stringify(doc.metadata.__dict__)
|
||||
|
||||
# DocumentPgEntry.append(Documents(
|
||||
# file_type='WEBPAGE',
|
||||
# title=doc.metadata.VisitedWebPageTitle,
|
||||
# search_space=search_space,
|
||||
# document_metadata=pgdocmeta,
|
||||
# page_content=content
|
||||
# ))
|
||||
|
||||
# # Save documents in PostgreSQL
|
||||
# search_space.documents.extend(DocumentPgEntry)
|
||||
# db.commit()
|
||||
|
||||
# Create hierarchical indices
|
||||
if IS_LOCAL_SETUP == True:
|
||||
index = HIndices(username=username)
|
||||
else:
|
||||
index = HIndices(username=username, api_key=OPENAI_API_KEY)
|
||||
|
||||
# Save indices in vector stores
|
||||
index.encode_docs_hierarchical(
|
||||
documents=raw_documents,
|
||||
search_space_instance=search_space,
|
||||
files_type='WEBPAGE',
|
||||
db=db
|
||||
)
|
||||
|
||||
print("FINISHED")
|
||||
|
||||
return {
|
||||
"success": "Save Job Completed Successfully"
|
||||
}
|
||||
|
||||
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
@app.post("/user/uploadfiles/")
|
||||
def save_user_documents(files: list[UploadFile], token: str = Depends(oauth2_scheme), search_space_id: int = Form(...), db: Session = Depends(get_db)):
|
||||
try:
|
||||
# Decode and verify the token
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
# Get the user by username and ensure they exist
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
# Verify the search space belongs to the user
|
||||
search_space = db.query(SearchSpace).filter(
|
||||
SearchSpace.id == search_space_id,
|
||||
SearchSpace.user_id == user.id
|
||||
).first()
|
||||
if not search_space:
|
||||
raise HTTPException(status_code=404, detail="Search space not found or does not belong to the user")
|
||||
|
||||
docs = []
|
||||
|
||||
for file in files:
|
||||
if file.content_type.startswith('image'):
|
||||
loader = UnstructuredLoader(
|
||||
file=file.file,
|
||||
api_key=UNSTRUCTURED_API_KEY,
|
||||
partition_via_api=True,
|
||||
chunking_strategy="basic",
|
||||
max_characters=90000,
|
||||
include_orig_elements=False,
|
||||
)
|
||||
else:
|
||||
loader = UnstructuredLoader(
|
||||
file=file.file,
|
||||
api_key=UNSTRUCTURED_API_KEY,
|
||||
partition_via_api=True,
|
||||
chunking_strategy="basic",
|
||||
max_characters=90000,
|
||||
include_orig_elements=False,
|
||||
strategy="fast"
|
||||
)
|
||||
|
||||
filedocs = loader.load()
|
||||
|
||||
|
||||
fileswithfilename = []
|
||||
for f in filedocs:
|
||||
temp = f
|
||||
temp.metadata['filename'] = file.filename
|
||||
fileswithfilename.append(temp)
|
||||
|
||||
docs.extend(fileswithfilename)
|
||||
|
||||
raw_documents = []
|
||||
|
||||
# Process each document in the retrieved document list
|
||||
for doc in docs:
|
||||
raw_documents.append(Document(page_content=doc.page_content, metadata=doc.metadata))
|
||||
|
||||
# Create hierarchical indices
|
||||
if IS_LOCAL_SETUP == True:
|
||||
index = HIndices(username=username)
|
||||
else:
|
||||
index = HIndices(username=username, api_key=OPENAI_API_KEY)
|
||||
|
||||
# Save indices in vector stores
|
||||
index.encode_docs_hierarchical(documents=raw_documents, search_space_instance=search_space, files_type='OTHER', db=db)
|
||||
|
||||
print("FINISHED")
|
||||
|
||||
return {
|
||||
"message": "Files Uploaded Successfully"
|
||||
}
|
||||
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
@app.websocket("/beta/chat/{search_space_id}/{token}")
|
||||
async def searchspace_chat_websocket_endpoint(websocket: WebSocket, search_space_id: int, token: str, db: Session = Depends(get_db)):
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
# Get the user by username and ensure they exist
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
# Verify the search space belongs to the user
|
||||
search_space = db.query(SearchSpace).filter(
|
||||
SearchSpace.id == search_space_id,
|
||||
SearchSpace.user_id == user.id
|
||||
).first()
|
||||
if not search_space:
|
||||
raise HTTPException(status_code=404, detail="Search space not found or does not belong to the user")
|
||||
|
||||
await manager.connect(websocket)
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_text()
|
||||
message = json.loads(data)
|
||||
# print(message)
|
||||
if message["type"] == "search_space_chat":
|
||||
query = message["content"]
|
||||
|
||||
if message["searchtype"] == "local" :
|
||||
report_source = "langchain_documents"
|
||||
else:
|
||||
report_source = message["searchtype"]
|
||||
|
||||
if message["answertype"] == "general_answer" :
|
||||
report_type = "custom_report"
|
||||
else:
|
||||
report_type = message["answertype"]
|
||||
|
||||
|
||||
# Create hierarchical indices
|
||||
if(IS_LOCAL_SETUP == True):
|
||||
index = HIndices(username=username)
|
||||
else:
|
||||
index = HIndices(username=username,api_key=OPENAI_API_KEY)
|
||||
|
||||
|
||||
await index.ws_experimental_search(websocket=websocket, manager=manager, query=query, search_space=search_space.name, report_type=report_type, report_source=report_source)
|
||||
|
||||
await manager.send_personal_message(
|
||||
json.dumps({"type": "end"}),
|
||||
websocket
|
||||
)
|
||||
|
||||
|
||||
|
||||
if message["type"] == "multiple_documents_chat":
|
||||
query = message["content"]
|
||||
received_chat_history = message["chat_history"]
|
||||
|
||||
chatHistory = []
|
||||
|
||||
chatHistory = [
|
||||
SystemMessage(
|
||||
content=DATE_TODAY + """You are a helpful assistant for question-answering tasks.
|
||||
Use the following pieces of retrieved context to answer the question.
|
||||
If you don't know the answer, just say that you don't know.
|
||||
Context:""" + str(received_chat_history[0]['relateddocs']))
|
||||
]
|
||||
|
||||
for data in received_chat_history[1:]:
|
||||
if data["role"] == "user":
|
||||
chatHistory.append(HumanMessage(content=data["content"]))
|
||||
|
||||
if data["role"] == "assistant":
|
||||
chatHistory.append(AIMessage(content=data["content"]))
|
||||
|
||||
|
||||
chatHistory.append(("human", "{input}"))
|
||||
|
||||
qa_prompt = ChatPromptTemplate.from_messages(chatHistory)
|
||||
|
||||
if(IS_LOCAL_SETUP == True):
|
||||
llm = OllamaLLM(model=MODEL_NAME,temperature=0)
|
||||
else:
|
||||
llm = ChatOpenAI(temperature=0, model_name=MODEL_NAME, api_key=OPENAI_API_KEY)
|
||||
|
||||
descriptionchain = qa_prompt | llm
|
||||
|
||||
streamingResponse = ""
|
||||
counter = 0
|
||||
for res in descriptionchain.stream({"input": query}):
|
||||
streamingResponse += res.content
|
||||
|
||||
if (counter < 20) :
|
||||
counter += 1
|
||||
else :
|
||||
await manager.send_personal_message(
|
||||
json.dumps({"type": "stream", "content": streamingResponse}),
|
||||
websocket
|
||||
)
|
||||
|
||||
counter = 0
|
||||
|
||||
await manager.send_personal_message(
|
||||
json.dumps({"type": "stream", "content": streamingResponse}),
|
||||
websocket
|
||||
)
|
||||
|
||||
await manager.send_personal_message(
|
||||
json.dumps({"type": "end"}),
|
||||
websocket
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
finally:
|
||||
manager.disconnect(websocket)
|
||||
except JWTError:
|
||||
await websocket.close(code=4003, reason="Invalid token")
|
||||
|
||||
@app.post("/user/searchspace/create-podcast")
|
||||
async def create_podcast(
|
||||
data: CreatePodcast,
|
||||
background_tasks: BackgroundTasks,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
try:
|
||||
# Verify token and get username
|
||||
payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
# Get user
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
# Verify search space belongs to user
|
||||
search_space = db.query(SearchSpace).filter(
|
||||
SearchSpace.id == data.search_space_id,
|
||||
SearchSpace.user_id == user.id
|
||||
).first()
|
||||
if not search_space:
|
||||
raise HTTPException(status_code=404, detail="Search space not found or does not belong to the user")
|
||||
|
||||
# Create new podcast entry
|
||||
new_podcast = Podcast(
|
||||
title=data.title,
|
||||
podcast_content=data.podcast_content,
|
||||
search_space_id=search_space.id
|
||||
)
|
||||
|
||||
db.add(new_podcast)
|
||||
db.commit()
|
||||
db.refresh(new_podcast)
|
||||
|
||||
podcast_config = {
|
||||
'word_count': data.wordcount,
|
||||
'podcast_name': 'SurfSense Podcast',
|
||||
'podcast_tagline': 'Your Own Personal Podcast.',
|
||||
'output_language': 'English',
|
||||
'user_instructions': 'Make it fun and engaging',
|
||||
'engagement_techniques': ['Rhetorical Questions', 'Personal Testimonials', 'Quotes', 'Anecdotes', 'Analogies', 'Humor'],
|
||||
}
|
||||
|
||||
try:
|
||||
background_tasks.add_task(
|
||||
generate_podcast_background,
|
||||
new_podcast.id,
|
||||
data.podcast_content,
|
||||
MODEL_NAME,
|
||||
"OPENAI_API_KEY",
|
||||
podcast_config,
|
||||
db
|
||||
)
|
||||
# # Check MODEL NAME behavior on Local Setups
|
||||
# saved_file_location = generate_podcast(
|
||||
# text=data.podcast_content,
|
||||
# llm_model_name=MODEL_NAME,
|
||||
# api_key_label="OPENAI_API_KEY",
|
||||
# conversation_config=podcast_config,
|
||||
# )
|
||||
|
||||
# new_podcast.file_location = saved_file_location
|
||||
# new_podcast.is_generated = True
|
||||
|
||||
# db.commit()
|
||||
# db.refresh(new_podcast)
|
||||
|
||||
|
||||
return {"message": "Podcast created successfully", "podcast_id": new_podcast.id}
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
async def generate_podcast_background(
|
||||
podcast_id: int,
|
||||
podcast_content: str,
|
||||
model_name: str,
|
||||
api_key_label: str,
|
||||
conversation_config: dict,
|
||||
db: Session
|
||||
):
|
||||
try:
|
||||
saved_file_location = generate_podcast(
|
||||
text=podcast_content,
|
||||
llm_model_name=model_name,
|
||||
api_key_label=api_key_label,
|
||||
conversation_config=conversation_config,
|
||||
)
|
||||
|
||||
# Update podcast in database
|
||||
podcast = db.query(Podcast).filter(Podcast.id == podcast_id).first()
|
||||
if podcast:
|
||||
podcast.file_location = saved_file_location
|
||||
podcast.is_generated = True
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
# Log the error or handle it appropriately
|
||||
print(f"Error generating podcast: {str(e)}")
|
||||
|
||||
|
||||
@app.get("/user/{token}/searchspace/{search_space_id}/download-podcast/{podcast_id}")
|
||||
async def download_podcast(search_space_id: int, podcast_id: int, token: str, db: Session = Depends(get_db)):
|
||||
try:
|
||||
# Verify the token and get the username
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
# Get the user by username
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
# Verify the search space belongs to the user
|
||||
search_space = db.query(SearchSpace).filter(
|
||||
SearchSpace.id == search_space_id,
|
||||
SearchSpace.user_id == user.id
|
||||
).first()
|
||||
if not search_space:
|
||||
raise HTTPException(status_code=404, detail="Search space not found or does not belong to the user")
|
||||
|
||||
# Retrieve the podcast file from the database
|
||||
podcast = db.query(Podcast).filter(
|
||||
Podcast.id == podcast_id,
|
||||
Podcast.search_space_id == search_space_id
|
||||
).first()
|
||||
if not podcast:
|
||||
raise HTTPException(status_code=404, detail="Podcast not found in the specified search space")
|
||||
|
||||
# Read the file content
|
||||
with open(podcast.file_location, "rb") as file:
|
||||
file_content = file.read()
|
||||
|
||||
# Create a response with the file content
|
||||
response = Response(content=file_content)
|
||||
response.headers["Content-Disposition"] = f"attachment; filename={podcast.title}.mp3"
|
||||
response.headers["Content-Type"] = "audio/mpeg"
|
||||
|
||||
return response
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/user/{token}/searchspace/{search_space_id}/podcasts")
|
||||
async def get_user_podcasts(token: str, search_space_id: int, db: Session = Depends(get_db)):
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
search_space = db.query(SearchSpace).filter(
|
||||
SearchSpace.id == search_space_id,
|
||||
SearchSpace.user_id == user.id
|
||||
).first()
|
||||
if not search_space:
|
||||
raise HTTPException(status_code=404, detail="Search space not found or does not belong to the user")
|
||||
|
||||
podcasts = db.query(Podcast).filter(Podcast.search_space_id == search_space_id).all()
|
||||
return podcasts
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
# Incomplete function, needs to be implemented based on the actual requirements and database structure
|
||||
@app.post("/searchspace/{search_space_id}/delete/docs")
|
||||
def delete_all_related_data(search_space_id: int, data: DocumentsToDelete, db: Session = Depends(get_db)):
|
||||
try:
|
||||
payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
# Get the user by username and ensure they exist
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
# Verify the search space belongs to the user
|
||||
search_space = db.query(SearchSpace).filter(
|
||||
SearchSpace.id == search_space_id,
|
||||
SearchSpace.user_id == user.id
|
||||
).first()
|
||||
if not search_space:
|
||||
raise HTTPException(status_code=404, detail="Search space not found or does not belong to the user")
|
||||
|
||||
if IS_LOCAL_SETUP:
|
||||
index = HIndices(username=username)
|
||||
else:
|
||||
index = HIndices(username=username, api_key=OPENAI_API_KEY)
|
||||
|
||||
message = index.delete_vector_stores(summary_ids_to_delete=data.ids_to_delete, db=db, search_space=search_space.name)
|
||||
|
||||
return {
|
||||
"message": message
|
||||
}
|
||||
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=403, detail="Token is invalid or expired")
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, ws="wsproto", host="127.0.0.1", port=8000)
|
|
@ -1,50 +0,0 @@
|
|||
version: '3.8'
|
||||
|
||||
services:
|
||||
# PostgreSQL Database
|
||||
db:
|
||||
image: postgres:13
|
||||
environment:
|
||||
POSTGRES_USER: your_postgres_user
|
||||
POSTGRES_PASSWORD: your_postgres_password
|
||||
POSTGRES_DB: surfsense_db
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
ports:
|
||||
- "5432:5432"
|
||||
networks:
|
||||
- surfsense-network
|
||||
|
||||
# Backend Service (FastAPI)
|
||||
backend:
|
||||
build:
|
||||
context: ./backend
|
||||
ports:
|
||||
- "8000:8000"
|
||||
env_file:
|
||||
- ./backend/.env
|
||||
depends_on:
|
||||
- db
|
||||
# privileged: true # when the backend points to the host machine via localhost or 127.0.0.1, add privileged: true; otherwise the container resolves those addresses to itself
|
||||
networks:
|
||||
- surfsense-network
|
||||
|
||||
# Frontend Service (Next.js)
|
||||
frontend:
|
||||
build:
|
||||
context: ./SurfSense-Frontend
|
||||
ports:
|
||||
- "3000:3000"
|
||||
env_file:
|
||||
- ./SurfSense-Frontend/.env
|
||||
networks:
|
||||
- surfsense-network
|
||||
|
||||
# Volumes for persistent storage
|
||||
volumes:
|
||||
postgres_data:
|
||||
|
||||
# Docker network
|
||||
networks:
|
||||
surfsense-network:
|
||||
driver: bridge
|
21
surfsense_backend/.env.example
Normal file
|
@ -0,0 +1,21 @@
|
|||
DATABASE_URL="postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense"
|
||||
|
||||
SECRET_KEY="SECRET"
|
||||
GOOGLE_OAUTH_CLIENT_ID="924507538m"
|
||||
GOOGLE_OAUTH_CLIENT_SECRET="GOCSV"
|
||||
NEXT_FRONTEND_URL="http://localhost:3000"
|
||||
EMBEDDING_MODEL="mixedbread-ai/mxbai-embed-large-v1"
|
||||
|
||||
RERANKERS_MODEL_NAME="ms-marco-MiniLM-L-12-v2"
|
||||
RERANKERS_MODEL_TYPE="flashrank"
|
||||
|
||||
FAST_LLM="litellm:openai/gpt-4o-mini"
|
||||
SMART_LLM="litellm:openai/gpt-4o-mini"
|
||||
STRATEGIC_LLM="litellm:openai/gpt-4o-mini"
|
||||
LONG_CONTEXT_LLM="litellm:gemini/gemini-2.0-flash-thinking-exp-01-21"
|
||||
|
||||
OPENAI_API_KEY="sk-proj-iA"
|
||||
GEMINI_API_KEY="AIzaSyB6-1641124124124124124124124124124"
|
||||
|
||||
UNSTRUCTURED_API_KEY="Tpu3P0U8iy"
|
||||
FIRECRAWL_API_KEY="fcr-01J0000000000000000000000"
|
6
surfsense_backend/.gitignore
vendored
Normal file
|
@ -0,0 +1,6 @@
|
|||
.env
|
||||
.venv
|
||||
venv/
|
||||
data/
|
||||
__pycache__/
|
||||
.flashrank_cache
|
1
surfsense_backend/.python-version
Normal file
|
@ -0,0 +1 @@
|
|||
3.12
|
16
surfsense_backend/.vscode/launch.json
vendored
Normal file
|
@ -0,0 +1,16 @@
|
|||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Python Debugger: main.py",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/main.py",
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": false
|
||||
}
|
||||
]
|
||||
}
|
119
surfsense_backend/README.md
Normal file
|
@ -0,0 +1,119 @@
|
|||
# SurfSense Backend
|
||||
|
||||
## Technology Stack Overview
|
||||
|
||||
This application is a modern AI-powered search and knowledge management platform built with the following technology stack:
|
||||
|
||||
### Core Framework and Environment
|
||||
- **Python 3.12+**: The application requires Python 3.12 or newer
|
||||
- **FastAPI**: Modern, fast web framework for building APIs with Python
|
||||
- **Uvicorn**: ASGI server implementation, running the FastAPI application
|
||||
- **PostgreSQL with pgvector**: Database with vector search capabilities for similarity searches
|
||||
- **SQLAlchemy**: SQL toolkit and ORM (Object-Relational Mapping) for database interactions
|
||||
- **FastAPI Users**: Authentication and user management with JWT and OAuth support
|
||||
|
||||
### Key Features and Components
|
||||
|
||||
#### Authentication and User Management
|
||||
- JWT-based authentication
|
||||
- OAuth integration (Google)
|
||||
- User registration, login, and password reset flows
|
||||
|
||||
#### Search and Retrieval System
|
||||
- **Hybrid Search**: Combines vector similarity and full-text search for optimal results using Reciprocal Rank Fusion (RRF); see the RRF sketch after this list
|
||||
- **Vector Embeddings**: Document and text embeddings for semantic search
|
||||
- **pgvector**: PostgreSQL extension for efficient vector similarity operations
|
||||
- **Chonkie**: Advanced document chunking and embedding library
|
||||
- Uses `AutoEmbeddings` for flexible embedding model selection
|
||||
- `LateChunker` for optimized document chunking based on the embedding model's max sequence length
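The hybrid-search bullet above refers to Reciprocal Rank Fusion (RRF). As a point of reference, here is a minimal, self-contained sketch of RRF; the function name and the constant `k = 60` are illustrative assumptions, not the project's actual implementation.

```python
# Minimal RRF sketch: fuse two ranked lists of document IDs into one ranking.
# k = 60 is the constant commonly used in the RRF literature; the real code
# in this repository may use different inputs and constants.
def reciprocal_rank_fusion(ranked_lists: list[list[int]], k: int = 60) -> list[int]:
    scores: dict[int, float] = {}
    for ranking in ranked_lists:
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank)
    return sorted(scores, key=scores.get, reverse=True)

vector_hits = [3, 1, 7, 2]    # IDs ordered by embedding similarity
fulltext_hits = [7, 3, 9, 1]  # IDs ordered by full-text rank
print(reciprocal_rank_fusion([vector_hits, fulltext_hits]))  # [3, 7, 1, 9, 2]
```

Documents that rank well in both lists float to the top of the fused ranking, which is why RRF is a convenient way to combine lexical and semantic retrieval without score normalization.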
|
||||
|
||||
#### AI and NLP Capabilities
|
||||
- **LangChain**: Framework for developing AI-powered applications
|
||||
- Used for document processing, research, and response generation
|
||||
- Integration with various LLM models through LiteLLM
|
||||
- Document conversion utilities for standardized processing
|
||||
- **GPT Integration**: Integration with LLM models through LiteLLM
|
||||
- Multiple LLM configurations for different use cases:
|
||||
- Fast LLM: Quick responses (default: gpt-4o-mini)
|
||||
- Smart LLM: More comprehensive analysis (default: gpt-4o-mini)
|
||||
- Strategic LLM: Complex reasoning (default: gpt-4o-mini)
|
||||
- Long Context LLM: For processing large documents (default: gemini-2.0-flash-thinking)
|
||||
- **Rerankers with FlashRank**: Advanced result ranking for improved search relevance (a short usage sketch follows this list)
|
||||
- Configurable reranking models (default: ms-marco-MiniLM-L-12-v2)
|
||||
- Supports multiple reranking backends (FlashRank, Cohere, etc.)
|
||||
- Improves search result quality by reordering based on semantic relevance
|
||||
- **GPT-Researcher**: Advanced research capabilities
|
||||
- Multiple research modes (GENERAL, DEEP, DEEPER)
|
||||
- Customizable report formats with proper citations
|
||||
- Streaming research results for real-time updates
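The reranking bullets above describe reordering retrieved passages with the `rerankers` package. Below is a small, hedged sketch of such a call, using the default model named above; the query and passages are placeholders, and the actual call sites in this codebase may look different.

```python
# Hedged sketch of second-stage reranking with the `rerankers` package,
# mirroring the defaults listed above (FlashRank backend, ms-marco-MiniLM-L-12-v2).
from rerankers import Reranker

ranker = Reranker(model_name="ms-marco-MiniLM-L-12-v2", model_type="flashrank")

query = "How are document embeddings stored?"
docs = [
    "Chunks are stored in PostgreSQL with pgvector embeddings.",
    "The browser extension captures visited pages.",
    "Chonkie's LateChunker splits documents before embedding.",
]

# rank() returns the passages reordered by semantic relevance to the query.
results = ranker.rank(query=query, docs=docs)
print(results.top_k(2))  # the two passages judged most relevant
```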
|
||||
|
||||
#### External Integrations
|
||||
- **Slack Connector**: Integration with Slack for data retrieval and notifications
|
||||
- **Notion Connector**: Integration with Notion for document retrieval
|
||||
- **Search APIs**: Integration with Tavily and Serper API for web search
|
||||
- **Firecrawl**: Web crawling and data extraction capabilities
|
||||
|
||||
#### Data Processing
|
||||
- **Unstructured**: Tools for processing unstructured data
|
||||
- **Markdownify**: Converting HTML to Markdown
|
||||
- **Playwright**: Web automation and scraping capabilities
|
||||
|
||||
#### Main Modules
|
||||
- **Search Spaces**: Isolated search environments for different contexts or projects
|
||||
- **Documents**: Storage and retrieval of various document types
|
||||
- **Chunks**: Document fragments for more precise retrieval
|
||||
- **Chats**: Conversation management with different depth levels (GENERAL, DEEP)
|
||||
- **Podcasts**: Audio content management with generation capabilities
|
||||
- **Search Source Connectors**: Integration with various data sources
|
||||
|
||||
### Development Tools
|
||||
- **Poetry**: Python dependency management (indicated by pyproject.toml)
|
||||
- **CORS support**: Cross-Origin Resource Sharing enabled for API access
|
||||
- **Environment Variables**: Configuration through .env files
|
||||
|
||||
## Database Schema
|
||||
|
||||
The application uses a relational database with the following main entities:
|
||||
- Users: Authentication and user management
|
||||
- SearchSpaces: Isolated search environments owned by users
|
||||
- Documents: Various document types with content and embeddings
|
||||
- Chunks: Smaller pieces of documents for granular retrieval
|
||||
- Chats: Conversation tracking with different depth levels
|
||||
- Podcasts: Audio content with generation capabilities
|
||||
- SearchSourceConnectors: External data source integrations
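To make the Documents/Chunks entities above concrete, here is a hedged sketch of a vector-similarity query with SQLAlchemy and pgvector. The `Chunk` model and its `embedding` column are defined in `surfsense_backend/app/db.py` later in this commit; the session and query embedding are placeholders, and the project's retrievers layer full-text search and RRF on top of this.

```python
# Hedged sketch: return the chunks closest to a query embedding using pgvector's
# cosine distance operator via SQLAlchemy.
from sqlalchemy import select

from app.db import Chunk  # model added in this commit


async def nearest_chunks(session, query_embedding, top_k: int = 5):
    stmt = (
        select(Chunk)
        .order_by(Chunk.embedding.cosine_distance(query_embedding))
        .limit(top_k)
    )
    result = await session.execute(stmt)
    return result.scalars().all()
```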
|
||||
|
||||
## API Endpoints
|
||||
|
||||
The API is structured with the following main route groups:
|
||||
- `/auth/*`: Authentication endpoints (JWT, OAuth)
|
||||
- `/users/*`: User management
|
||||
- `/api/v1/search-spaces/*`: Search space management
|
||||
- `/api/v1/documents/*`: Document management
|
||||
- `/api/v1/podcasts/*`: Podcast functionality
|
||||
- `/api/v1/chats/*`: Chat and conversation endpoints
|
||||
- `/api/v1/search-source-connectors/*`: External data source management
|
||||
|
||||
## Deployment
|
||||
|
||||
The application is configured to run with Uvicorn and can be deployed with:
|
||||
```
|
||||
python main.py
|
||||
```
|
||||
|
||||
This will start the server on all interfaces (0.0.0.0) with info-level logging.
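The exact `main.py` is not part of this diff, so the following is only a sketch of what an entry point matching that description could look like; the module path and port are assumptions.

```python
# Hypothetical main.py entry point matching the deployment description above.
# "app.app:app" and port 8000 are assumptions for illustration.
import uvicorn

if __name__ == "__main__":
    uvicorn.run("app.app:app", host="0.0.0.0", port=8000, log_level="info")
```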
|
||||
|
||||
## Requirements
|
||||
|
||||
See pyproject.toml for detailed dependency information. Key dependencies include:
|
||||
- asyncpg: Asynchronous PostgreSQL client
|
||||
- chonkie: Document chunking and embedding library
|
||||
- fastapi and related packages
|
||||
- fastapi-users: Authentication and user management
|
||||
- firecrawl-py: Web crawling capabilities
|
||||
- gpt-researcher: Advanced research capabilities
|
||||
- langchain components for AI workflows
|
||||
- litellm: LLM model integration
|
||||
- pgvector: Vector similarity search in PostgreSQL
|
||||
- rerankers with FlashRank: Advanced result ranking
|
||||
- Various AI and NLP libraries
|
||||
- Integration clients for Slack, Notion, etc.
|
0
surfsense_backend/app/__init__.py
Normal file
80
surfsense_backend/app/app.py
Normal file
|
@ -0,0 +1,80 @@
|
|||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import User, create_db_and_tables, get_async_session
|
||||
from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
|
||||
from app.schemas import UserCreate, UserRead, UserUpdate
|
||||
from app.users import (
|
||||
SECRET,
|
||||
auth_backend,
|
||||
fastapi_users,
|
||||
google_oauth_client,
|
||||
current_active_user,
|
||||
)
|
||||
from app.routes import router as crud_router
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Not needed if you set up a migration system like Alembic
|
||||
await create_db_and_tables()
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # Allows all origins
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"], # Allows all methods
|
||||
allow_headers=["*"], # Allows all headers
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
fastapi_users.get_auth_router(auth_backend), prefix="/auth/jwt", tags=["auth"]
|
||||
)
|
||||
app.include_router(
|
||||
fastapi_users.get_register_router(UserRead, UserCreate),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
app.include_router(
|
||||
fastapi_users.get_reset_password_router(),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
app.include_router(
|
||||
fastapi_users.get_verify_router(UserRead),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
app.include_router(
|
||||
fastapi_users.get_users_router(UserRead, UserUpdate),
|
||||
prefix="/users",
|
||||
tags=["users"],
|
||||
)
|
||||
app.include_router(
|
||||
fastapi_users.get_oauth_router(google_oauth_client, auth_backend, SECRET, is_verified_by_default=True),
|
||||
prefix="/auth/google",
|
||||
tags=["auth"],
|
||||
)
|
||||
app.include_router(crud_router, prefix="/api/v1", tags=["crud"])
|
||||
|
||||
|
||||
@app.get("/authenticated-route")
|
||||
async def authenticated_route(user: User = Depends(current_active_user), session: AsyncSession = Depends(get_async_session)):
|
||||
retriever = ChucksHybridSearchRetriever(session)
|
||||
results = await retriever.hybrid_search(
|
||||
query_text="SurfSense",
|
||||
top_k=1,
|
||||
user_id=user.id,
|
||||
search_space_id=1,
|
||||
document_type="CRAWLED_URL"
|
||||
)
|
||||
return results
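A hedged client-side sketch of exercising the `/authenticated-route` endpoint above, assuming the standard fastapi-users JWT login route (`/auth/jwt/login`, form-encoded credentials) mounted by this app; the credentials and base URL are placeholders.

```python
# Hedged sketch: obtain a JWT from the fastapi-users login route, then call the
# authenticated route defined above. Credentials and URL are placeholders.
import httpx

with httpx.Client(base_url="http://localhost:8000") as client:
    login = client.post(
        "/auth/jwt/login",
        data={"username": "user@example.com", "password": "changeme"},  # form-encoded
    )
    token = login.json()["access_token"]

    resp = client.get(
        "/authenticated-route",
        headers={"Authorization": f"Bearer {token}"},
    )
    print(resp.json())
```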
|
98
surfsense_backend/app/config/__init__.py
Normal file
|
@ -0,0 +1,98 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from chonkie import AutoEmbeddings, LateChunker
|
||||
from rerankers import Reranker
|
||||
from langchain_community.chat_models import ChatLiteLLM
|
||||
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Get the base directory of the project
|
||||
BASE_DIR = Path(__file__).resolve().parent.parent.parent
|
||||
|
||||
env_file = BASE_DIR / ".env"
|
||||
load_dotenv(env_file)
|
||||
|
||||
|
||||
def extract_model_name(llm_string: str) -> str:
|
||||
"""Extract the model name from an LLM string.
|
||||
Example: "litellm:openai/gpt-4o-mini" -> "openai/gpt-4o-mini"
|
||||
|
||||
Args:
|
||||
llm_string: The LLM string with optional prefix
|
||||
|
||||
Returns:
|
||||
str: The extracted model name
|
||||
"""
|
||||
return llm_string.split(":", 1)[1] if ":" in llm_string else llm_string
|
||||
|
||||
class Config:
|
||||
# Database
|
||||
DATABASE_URL = os.getenv("DATABASE_URL")
|
||||
|
||||
# Google OAuth
|
||||
GOOGLE_OAUTH_CLIENT_ID = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
|
||||
GOOGLE_OAUTH_CLIENT_SECRET = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")
|
||||
NEXT_FRONTEND_URL = os.getenv("NEXT_FRONTEND_URL")
|
||||
|
||||
# LONG-CONTEXT LLMS
|
||||
LONG_CONTEXT_LLM = os.getenv("LONG_CONTEXT_LLM")
|
||||
long_context_llm_instance = ChatLiteLLM(model=extract_model_name(LONG_CONTEXT_LLM))
|
||||
|
||||
# GPT Researcher
|
||||
FAST_LLM = os.getenv("FAST_LLM")
|
||||
SMART_LLM = os.getenv("SMART_LLM")
|
||||
STRATEGIC_LLM = os.getenv("STRATEGIC_LLM")
|
||||
fast_llm_instance = ChatLiteLLM(model=extract_model_name(FAST_LLM))
|
||||
smart_llm_instance = ChatLiteLLM(model=extract_model_name(SMART_LLM))
|
||||
strategic_llm_instance = ChatLiteLLM(model=extract_model_name(STRATEGIC_LLM))
|
||||
|
||||
|
||||
# Chonkie Configuration | Edit this to your needs
|
||||
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
|
||||
embedding_model_instance = AutoEmbeddings.get_embeddings(EMBEDDING_MODEL)
|
||||
chunker_instance = LateChunker(
|
||||
embedding_model=EMBEDDING_MODEL,
|
||||
chunk_size=embedding_model_instance.max_seq_length,
|
||||
)
|
||||
|
||||
# Reranker's Configuration | Pinecone, Cohere, etc. Read more at https://github.com/AnswerDotAI/rerankers?tab=readme-ov-file#usage
|
||||
RERANKERS_MODEL_NAME = os.getenv("RERANKERS_MODEL_NAME")
|
||||
RERANKERS_MODEL_TYPE = os.getenv("RERANKERS_MODEL_TYPE")
|
||||
reranker_instance = Reranker(
|
||||
model_name=RERANKERS_MODEL_NAME,
|
||||
model_type=RERANKERS_MODEL_TYPE,
|
||||
)
|
||||
|
||||
# OAuth JWT
|
||||
SECRET_KEY = os.getenv("SECRET_KEY")
|
||||
|
||||
# Unstructured API Key
|
||||
UNSTRUCTURED_API_KEY = os.getenv("UNSTRUCTURED_API_KEY")
|
||||
|
||||
# Firecrawl API Key
|
||||
FIRECRAWL_API_KEY = os.getenv("FIRECRAWL_API_KEY", None)
|
||||
|
||||
# Validation Checks
|
||||
# Check embedding dimension
|
||||
if hasattr(embedding_model_instance, 'dimension') and embedding_model_instance.dimension > 2000:
|
||||
raise ValueError(
|
||||
f"Embedding dimension for Model: {EMBEDDING_MODEL} "
|
||||
f"has {embedding_model_instance.dimension} dimensions, which "
|
||||
f"exceeds the maximum of 2000 allowed by PGVector."
|
||||
)
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_settings(cls):
|
||||
"""Get all settings as a dictionary."""
|
||||
return {
|
||||
key: value
|
||||
for key, value in cls.__dict__.items()
|
||||
if not key.startswith("_") and not callable(value)
|
||||
}
|
||||
|
||||
|
||||
# Create a config instance
|
||||
config = Config()
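A small, hedged sketch of how the config instances above might be used to chunk and embed text; attribute names (`chunk`, `text`, `token_count`, `embed`) follow Chonkie's documented interface, and the actual ingestion pipeline in this codebase may differ.

```python
# Hedged sketch of using the Chonkie instances configured above.
from app.config import config

document_text = "SurfSense stores crawled pages, files and connector data ..."  # placeholder

# LateChunker sizes chunks to the embedding model's max sequence length (see above).
chunks = config.chunker_instance.chunk(document_text)
for chunk in chunks:
    print(chunk.token_count, chunk.text[:60])

# The same embedding model backs the pgvector columns defined in app/db.py.
query_embedding = config.embedding_model_instance.embed("hybrid search")
print(len(query_embedding))  # should equal the model's embedding dimension
```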
|
225
surfsense_backend/app/connectors/notion_history.py
Normal file
|
@ -0,0 +1,225 @@
|
|||
from notion_client import Client
|
||||
|
||||
class NotionHistoryConnector:
|
||||
def __init__(self, token):
|
||||
"""
|
||||
Initialize the NotionHistoryConnector with a token.
|
||||
|
||||
Args:
|
||||
token (str): Notion integration token
|
||||
"""
|
||||
self.notion = Client(auth=token)
|
||||
|
||||
def get_all_pages(self, start_date=None, end_date=None):
|
||||
"""
|
||||
Fetches all pages shared with your integration and their content.
|
||||
|
||||
Args:
|
||||
start_date (str, optional): ISO 8601 date string (e.g., "2023-01-01T00:00:00Z")
|
||||
end_date (str, optional): ISO 8601 date string (e.g., "2023-12-31T23:59:59Z")
|
||||
|
||||
Returns:
|
||||
list: List of dictionaries containing page data
|
||||
"""
|
||||
# Build the filter for the search
|
||||
# Note: Notion API requires specific filter structure
|
||||
search_params = {}
|
||||
|
||||
# Filter for pages only (not databases)
|
||||
search_params["filter"] = {
|
||||
"value": "page",
|
||||
"property": "object"
|
||||
}
|
||||
|
||||
# Add date filters if provided
|
||||
if start_date or end_date:
|
||||
date_filter = {}
|
||||
|
||||
if start_date:
|
||||
date_filter["on_or_after"] = start_date
|
||||
|
||||
if end_date:
|
||||
date_filter["on_or_before"] = end_date
|
||||
|
||||
# Notion's search endpoint does not accept a date filter directly, so sort by last edited time instead
|
||||
if date_filter:
|
||||
search_params["sort"] = {
|
||||
"direction": "descending",
|
||||
"timestamp": "last_edited_time"
|
||||
}
|
||||
|
||||
# First, get a list of all pages the integration has access to
|
||||
search_results = self.notion.search(**search_params)
|
||||
|
||||
pages = search_results["results"]
|
||||
all_page_data = []
|
||||
|
||||
for page in pages:
|
||||
page_id = page["id"]
|
||||
|
||||
# Get detailed page information
|
||||
page_content = self.get_page_content(page_id)
|
||||
|
||||
all_page_data.append({
|
||||
"page_id": page_id,
|
||||
"title": self.get_page_title(page),
|
||||
"content": page_content
|
||||
})
|
||||
|
||||
return all_page_data
|
||||
|
||||
def get_page_title(self, page):
|
||||
"""
|
||||
Extracts the title from a page object.
|
||||
|
||||
Args:
|
||||
page (dict): Notion page object
|
||||
|
||||
Returns:
|
||||
str: Page title or a fallback string
|
||||
"""
|
||||
# Title can be in different properties depending on the page type
|
||||
if "properties" in page:
|
||||
# Try to find a title property
|
||||
for prop_name, prop_data in page["properties"].items():
|
||||
if prop_data["type"] == "title" and len(prop_data["title"]) > 0:
|
||||
return " ".join([text_obj["plain_text"] for text_obj in prop_data["title"]])
|
||||
|
||||
# If no title found, return the page ID as fallback
|
||||
return f"Untitled page ({page['id']})"
|
||||
|
||||
def get_page_content(self, page_id):
|
||||
"""
|
||||
Fetches the content (blocks) of a specific page.
|
||||
|
||||
Args:
|
||||
page_id (str): The ID of the page to fetch
|
||||
|
||||
Returns:
|
||||
list: List of processed blocks from the page
|
||||
"""
|
||||
blocks = []
|
||||
has_more = True
|
||||
cursor = None
|
||||
|
||||
# Paginate through all blocks
|
||||
while has_more:
|
||||
if cursor:
|
||||
response = self.notion.blocks.children.list(block_id=page_id, start_cursor=cursor)
|
||||
else:
|
||||
response = self.notion.blocks.children.list(block_id=page_id)
|
||||
|
||||
blocks.extend(response["results"])
|
||||
has_more = response["has_more"]
|
||||
|
||||
if has_more:
|
||||
cursor = response["next_cursor"]
|
||||
|
||||
# Process nested blocks recursively
|
||||
processed_blocks = []
|
||||
for block in blocks:
|
||||
processed_block = self.process_block(block)
|
||||
processed_blocks.append(processed_block)
|
||||
|
||||
return processed_blocks
|
||||
|
||||
def process_block(self, block):
|
||||
"""
|
||||
Processes a block and recursively fetches any child blocks.
|
||||
|
||||
Args:
|
||||
block (dict): The block to process
|
||||
|
||||
Returns:
|
||||
dict: Processed block with content and children
|
||||
"""
|
||||
block_id = block["id"]
|
||||
block_type = block["type"]
|
||||
|
||||
# Extract block content based on its type
|
||||
content = self.extract_block_content(block)
|
||||
|
||||
# Check if block has children
|
||||
has_children = block.get("has_children", False)
|
||||
child_blocks = []
|
||||
|
||||
if has_children:
|
||||
# Fetch and process child blocks
|
||||
children_response = self.notion.blocks.children.list(block_id=block_id)
|
||||
for child_block in children_response["results"]:
|
||||
child_blocks.append(self.process_block(child_block))
|
||||
|
||||
return {
|
||||
"id": block_id,
|
||||
"type": block_type,
|
||||
"content": content,
|
||||
"children": child_blocks
|
||||
}
|
||||
|
||||
def extract_block_content(self, block):
|
||||
"""
|
||||
Extracts the content from a block based on its type.
|
||||
|
||||
Args:
|
||||
block (dict): The block to extract content from
|
||||
|
||||
Returns:
|
||||
str: Extracted content as a string
|
||||
"""
|
||||
block_type = block["type"]
|
||||
|
||||
# Different block types have different structures
|
||||
if block_type in block and "rich_text" in block[block_type]:
|
||||
return "".join([text_obj["plain_text"] for text_obj in block[block_type]["rich_text"]])
|
||||
elif block_type == "image":
|
||||
# Instead of returning the raw URL which may contain sensitive AWS credentials,
|
||||
# return a placeholder or reference to the image
|
||||
if "file" in block["image"]:
|
||||
# For Notion-hosted images (which use AWS S3 pre-signed URLs)
|
||||
return "[Notion Image]"
|
||||
elif "external" in block["image"]:
|
||||
# For external images, we can return a sanitized reference
|
||||
url = block["image"]["external"]["url"]
|
||||
# Only return the domain part of external URLs to avoid potential sensitive parameters
|
||||
try:
|
||||
from urllib.parse import urlparse
|
||||
parsed_url = urlparse(url)
|
||||
return f"[External Image from {parsed_url.netloc}]"
|
||||
except Exception:
|
||||
return "[External Image]"
|
||||
elif block_type == "code":
|
||||
language = block["code"]["language"]
|
||||
code_text = "".join([text_obj["plain_text"] for text_obj in block["code"]["rich_text"]])
|
||||
return f"```{language}\n{code_text}\n```"
|
||||
elif block_type == "equation":
|
||||
return block["equation"]["expression"]
|
||||
# Add more block types as needed
|
||||
|
||||
# Return empty string for unsupported block types
|
||||
return ""
|
||||
|
||||
|
||||
# Example usage
|
||||
# if __name__ == "__main__":
|
||||
# # Simple example of how to use this module
|
||||
# import argparse
|
||||
|
||||
# parser = argparse.ArgumentParser(description="Fetch Notion pages using an integration token")
|
||||
# parser.add_argument("--token", help="Your Notion integration token")
|
||||
# parser.add_argument("--start-date", help="Start date in ISO format (e.g., 2023-01-01T00:00:00Z)")
|
||||
# parser.add_argument("--end-date", help="End date in ISO format (e.g., 2023-12-31T23:59:59Z)")
|
||||
# args = parser.parse_args()
|
||||
|
||||
# token = args.token
|
||||
# if not token:
|
||||
# token = input("Enter your Notion integration token: ")
|
||||
|
||||
# fetcher = NotionHistoryConnector(token)
|
||||
|
||||
# try:
|
||||
# pages = fetcher.get_all_pages(args.start_date, args.end_date)
|
||||
# print(f"Fetched {len(pages)} pages from Notion")
|
||||
# for page in pages:
|
||||
# print(f"- {page['title']}")
|
||||
# except Exception as e:
|
||||
# print(f"Error: {str(e)}")
|
301
surfsense_backend/app/connectors/slack_history.py
Normal file
|
@ -0,0 +1,301 @@
|
|||
"""
|
||||
Slack History Module
|
||||
|
||||
A module for retrieving conversation history from Slack channels.
|
||||
Allows fetching channel lists and message history with date range filtering.
|
||||
"""
|
||||
|
||||
import os
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Tuple, Any, Union
|
||||
|
||||
|
||||
class SlackHistory:
|
||||
"""Class for retrieving conversation history from Slack channels."""
|
||||
|
||||
def __init__(self, token: str = None):
|
||||
"""
|
||||
Initialize the SlackHistory class.
|
||||
|
||||
Args:
|
||||
token: Slack API token (optional, can be set later with set_token)
|
||||
"""
|
||||
self.client = WebClient(token=token) if token else None
|
||||
|
||||
def set_token(self, token: str) -> None:
|
||||
"""
|
||||
Set the Slack API token.
|
||||
|
||||
Args:
|
||||
token: Slack API token
|
||||
"""
|
||||
self.client = WebClient(token=token)
|
||||
|
||||
def get_all_channels(self, include_private: bool = True) -> Dict[str, str]:
|
||||
"""
|
||||
Fetch all channels that the bot has access to.
|
||||
|
||||
Args:
|
||||
include_private: Whether to include private channels
|
||||
|
||||
Returns:
|
||||
Dictionary mapping channel names to channel IDs
|
||||
|
||||
Raises:
|
||||
ValueError: If no Slack client has been initialized
|
||||
SlackApiError: If there's an error calling the Slack API
|
||||
"""
|
||||
if not self.client:
|
||||
raise ValueError("Slack client not initialized. Call set_token() first.")
|
||||
|
||||
channels_dict = {}
|
||||
types = "public_channel"
|
||||
if include_private:
|
||||
types += ",private_channel"
|
||||
|
||||
try:
|
||||
# Call the conversations.list method
|
||||
result = self.client.conversations_list(
|
||||
types=types,
|
||||
limit=1000 # Maximum allowed by API
|
||||
)
|
||||
channels = result["channels"]
|
||||
|
||||
# Handle pagination for workspaces with many channels
|
||||
while result.get("response_metadata", {}).get("next_cursor"):
|
||||
next_cursor = result["response_metadata"]["next_cursor"]
|
||||
|
||||
# Get the next batch of channels
|
||||
result = self.client.conversations_list(
|
||||
types=types,
|
||||
cursor=next_cursor,
|
||||
limit=1000
|
||||
)
|
||||
channels.extend(result["channels"])
|
||||
|
||||
# Create a dictionary mapping channel names to IDs
|
||||
for channel in channels:
|
||||
channels_dict[channel["name"]] = channel["id"]
|
||||
|
||||
return channels_dict
|
||||
|
||||
except SlackApiError as e:
|
||||
raise SlackApiError(f"Error retrieving channels: {e}", e.response)
|
||||
|
||||
def get_conversation_history(
|
||||
self,
|
||||
channel_id: str,
|
||||
limit: int = 1000,
|
||||
oldest: Optional[int] = None,
|
||||
latest: Optional[int] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Fetch conversation history for a channel.
|
||||
|
||||
Args:
|
||||
channel_id: The ID of the channel to fetch history for
|
||||
limit: Maximum number of messages to return per request (default 1000)
|
||||
oldest: Start of time range (Unix timestamp)
|
||||
latest: End of time range (Unix timestamp)
|
||||
|
||||
Returns:
|
||||
List of message objects
|
||||
|
||||
Raises:
|
||||
ValueError: If no Slack client has been initialized
|
||||
SlackApiError: If there's an error calling the Slack API
|
||||
"""
|
||||
if not self.client:
|
||||
raise ValueError("Slack client not initialized. Call set_token() first.")
|
||||
|
||||
try:
|
||||
# Call the conversations.history method
|
||||
messages = []
|
||||
next_cursor = None
|
||||
|
||||
while True:
|
||||
kwargs = {
|
||||
"channel": channel_id,
|
||||
"limit": min(limit, 1000), # API max is 1000
|
||||
}
|
||||
|
||||
if oldest:
|
||||
kwargs["oldest"] = oldest
|
||||
if latest:
|
||||
kwargs["latest"] = latest
|
||||
if next_cursor:
|
||||
kwargs["cursor"] = next_cursor
|
||||
|
||||
result = self.client.conversations_history(**kwargs)
|
||||
batch = result["messages"]
|
||||
messages.extend(batch)
|
||||
|
||||
# Check if we need to paginate
|
||||
if result.get("has_more", False) and len(messages) < limit:
|
||||
next_cursor = result["response_metadata"]["next_cursor"]
|
||||
else:
|
||||
break
|
||||
|
||||
# Respect the overall limit parameter
|
||||
return messages[:limit]
|
||||
|
||||
except SlackApiError as e:
|
||||
raise SlackApiError(f"Error retrieving history for channel {channel_id}: {e}", e.response)
|
||||
|
||||
@staticmethod
|
||||
def convert_date_to_timestamp(date_str: str) -> Optional[int]:
|
||||
"""
|
||||
Convert a date string in format YYYY-MM-DD to Unix timestamp.
|
||||
|
||||
Args:
|
||||
date_str: Date string in YYYY-MM-DD format
|
||||
|
||||
Returns:
|
||||
Unix timestamp (seconds since epoch) or None if invalid format
|
||||
"""
|
||||
try:
|
||||
dt = datetime.strptime(date_str, "%Y-%m-%d")
|
||||
return int(dt.timestamp())
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
def get_history_by_date_range(
|
||||
self,
|
||||
channel_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
limit: int = 1000
|
||||
) -> Tuple[List[Dict[str, Any]], Optional[str]]:
|
||||
"""
|
||||
Fetch conversation history within a date range.
|
||||
|
||||
Args:
|
||||
channel_id: The ID of the channel to fetch history for
|
||||
start_date: Start date in YYYY-MM-DD format
|
||||
end_date: End date in YYYY-MM-DD format (inclusive)
|
||||
limit: Maximum number of messages to return
|
||||
|
||||
Returns:
|
||||
Tuple containing (messages list, error message or None)
|
||||
"""
|
||||
oldest = self.convert_date_to_timestamp(start_date)
|
||||
if not oldest:
|
||||
return [], f"Invalid start date format: {start_date}. Please use YYYY-MM-DD."
|
||||
|
||||
latest = self.convert_date_to_timestamp(end_date)
|
||||
if not latest:
|
||||
return [], f"Invalid end date format: {end_date}. Please use YYYY-MM-DD."
|
||||
|
||||
# Add one day to end date to make it inclusive
|
||||
latest += 86400 # seconds in a day
|
||||
|
||||
try:
|
||||
messages = self.get_conversation_history(
|
||||
channel_id=channel_id,
|
||||
limit=limit,
|
||||
oldest=oldest,
|
||||
latest=latest
|
||||
)
|
||||
return messages, None
|
||||
except SlackApiError as e:
|
||||
return [], f"Slack API error: {str(e)}"
|
||||
except ValueError as e:
|
||||
return [], str(e)
|
||||
|
||||
def get_user_info(self, user_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get information about a user.
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user to get info for
|
||||
|
||||
Returns:
|
||||
User information dictionary
|
||||
|
||||
Raises:
|
||||
ValueError: If no Slack client has been initialized
|
||||
SlackApiError: If there's an error calling the Slack API
|
||||
"""
|
||||
if not self.client:
|
||||
raise ValueError("Slack client not initialized. Call set_token() first.")
|
||||
|
||||
try:
|
||||
result = self.client.users_info(user=user_id)
|
||||
return result["user"]
|
||||
except SlackApiError as e:
|
||||
raise SlackApiError(f"Error retrieving user info for {user_id}: {e}", e.response)
|
||||
|
||||
def format_message(self, msg: Dict[str, Any], include_user_info: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
Format a message for easier consumption.
|
||||
|
||||
Args:
|
||||
msg: The message object from Slack API
|
||||
include_user_info: Whether to fetch and include user info
|
||||
|
||||
Returns:
|
||||
Formatted message dictionary
|
||||
"""
|
||||
formatted = {
|
||||
"text": msg.get("text", ""),
|
||||
"timestamp": msg.get("ts"),
|
||||
"datetime": datetime.fromtimestamp(float(msg.get("ts", 0))).strftime('%Y-%m-%d %H:%M:%S'),
|
||||
"user_id": msg.get("user", "UNKNOWN"),
|
||||
"has_attachments": bool(msg.get("attachments")),
|
||||
"has_files": bool(msg.get("files")),
|
||||
"thread_ts": msg.get("thread_ts"),
|
||||
"is_thread": "thread_ts" in msg,
|
||||
}
|
||||
|
||||
if include_user_info and "user" in msg and self.client:
|
||||
try:
|
||||
user_info = self.get_user_info(msg["user"])
|
||||
formatted["user_name"] = user_info.get("real_name", "Unknown")
|
||||
formatted["user_email"] = user_info.get("profile", {}).get("email", "")
|
||||
except Exception:
|
||||
# If we can't get user info, just continue without it
|
||||
formatted["user_name"] = "Unknown"
|
||||
|
||||
return formatted
|
||||
|
||||
|
||||
# Example usage (uncomment to use):
|
||||
"""
|
||||
if __name__ == "__main__":
|
||||
# Set your token here or via environment variable
|
||||
token = os.environ.get("SLACK_API_TOKEN", "xoxb-your-token-here")
|
||||
|
||||
slack = SlackHistory(token)
|
||||
|
||||
# Get all channels
|
||||
try:
|
||||
channels = slack.get_all_channels()
|
||||
print("Available channels:")
|
||||
for name, channel_id in sorted(channels.items()):
|
||||
print(f"- {name}: {channel_id}")
|
||||
|
||||
# Example: Get history for a specific channel and date range
|
||||
channel_id = channels.get("general")
|
||||
if channel_id:
|
||||
messages, error = slack.get_history_by_date_range(
|
||||
channel_id=channel_id,
|
||||
start_date="2023-01-01",
|
||||
end_date="2023-01-31",
|
||||
limit=500
|
||||
)
|
||||
|
||||
if error:
|
||||
print(f"Error: {error}")
|
||||
else:
|
||||
print(f"\nRetrieved {len(messages)} messages from #general")
|
||||
|
||||
# Print formatted messages
|
||||
for msg in messages[:10]: # Show first 10 messages
|
||||
formatted = slack.format_message(msg, include_user_info=True)
|
||||
print(f"[{formatted['datetime']}] {formatted['user_name']}: {formatted['text']}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
"""
|
181
surfsense_backend/app/db.py
Normal file
|
@ -0,0 +1,181 @@
|
|||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi_users.db import (
|
||||
SQLAlchemyBaseOAuthAccountTableUUID,
|
||||
SQLAlchemyBaseUserTableUUID,
|
||||
SQLAlchemyUserDatabase,
|
||||
)
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from sqlalchemy import (
|
||||
ARRAY,
|
||||
Boolean,
|
||||
Column,
|
||||
Enum as SQLAlchemyEnum,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
JSON,
|
||||
String,
|
||||
Text,
|
||||
text,
|
||||
TIMESTAMP
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, declared_attr, relationship
|
||||
|
||||
from app.config import config
|
||||
from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
|
||||
|
||||
DATABASE_URL = config.DATABASE_URL
|
||||
|
||||
|
||||
class DocumentType(str, Enum):
|
||||
EXTENSION = "EXTENSION"
|
||||
CRAWLED_URL = "CRAWLED_URL"
|
||||
FILE = "FILE"
|
||||
SLACK_CONNECTOR = "SLACK_CONNECTOR"
|
||||
NOTION_CONNECTOR = "NOTION_CONNECTOR"
|
||||
|
||||
class SearchSourceConnectorType(str, Enum):
|
||||
SERPER_API = "SERPER_API"
|
||||
TAVILY_API = "TAVILY_API"
|
||||
SLACK_CONNECTOR = "SLACK_CONNECTOR"
|
||||
NOTION_CONNECTOR = "NOTION_CONNECTOR"
|
||||
|
||||
class ChatType(str, Enum):
|
||||
GENERAL = "GENERAL"
|
||||
DEEP = "DEEP"
|
||||
DEEPER = "DEEPER"
|
||||
DEEPEST = "DEEPEST"
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
class TimestampMixin:
|
||||
@declared_attr
|
||||
def created_at(cls):
|
||||
return Column(TIMESTAMP(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc), index=True)
|
||||
|
||||
class BaseModel(Base):
|
||||
__abstract__ = True
|
||||
__allow_unmapped__ = True
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
|
||||
class Chat(BaseModel, TimestampMixin):
|
||||
__tablename__ = "chats"
|
||||
|
||||
type = Column(SQLAlchemyEnum(ChatType), nullable=False)
|
||||
title = Column(String(200), nullable=False, index=True)
|
||||
initial_connectors = Column(ARRAY(String), nullable=True)
|
||||
messages = Column(JSON, nullable=False)
|
||||
|
||||
search_space_id = Column(Integer, ForeignKey('searchspaces.id', ondelete='CASCADE'), nullable=False)
|
||||
search_space = relationship('SearchSpace', back_populates='chats')
|
||||
|
||||
class Document(BaseModel, TimestampMixin):
|
||||
__tablename__ = "documents"
|
||||
|
||||
title = Column(String(200), nullable=False, index=True)
|
||||
document_type = Column(SQLAlchemyEnum(DocumentType), nullable=False)
|
||||
document_metadata = Column(JSON, nullable=True)
|
||||
|
||||
content = Column(Text, nullable=False)
|
||||
embedding = Column(Vector(config.embedding_model_instance.dimension))
|
||||
|
||||
search_space_id = Column(Integer, ForeignKey("searchspaces.id", ondelete='CASCADE'), nullable=False)
|
||||
search_space = relationship("SearchSpace", back_populates="documents")
|
||||
chunks = relationship("Chunk", back_populates="document", cascade="all, delete-orphan")
|
||||
|
||||
class Chunk(BaseModel, TimestampMixin):
|
||||
__tablename__ = "chunks"
|
||||
|
||||
content = Column(Text, nullable=False)
|
||||
embedding = Column(Vector(config.embedding_model_instance.dimension))
|
||||
|
||||
document_id = Column(Integer, ForeignKey("documents.id", ondelete='CASCADE'), nullable=False)
|
||||
document = relationship("Document", back_populates="chunks")
|
||||
|
||||
class Podcast(BaseModel, TimestampMixin):
|
||||
__tablename__ = "podcasts"
|
||||
|
||||
title = Column(String(200), nullable=False, index=True)
|
||||
is_generated = Column(Boolean, nullable=False, default=False)
|
||||
podcast_content = Column(Text, nullable=False, default="")
|
||||
file_location = Column(String(500), nullable=False, default="")
|
||||
|
||||
search_space_id = Column(Integer, ForeignKey("searchspaces.id", ondelete='CASCADE'), nullable=False)
|
||||
search_space = relationship("SearchSpace", back_populates="podcasts")
|
||||
|
||||
class SearchSpace(BaseModel, TimestampMixin):
|
||||
__tablename__ = "searchspaces"
|
||||
|
||||
name = Column(String(100), nullable=False, index=True)
|
||||
description = Column(String(500), nullable=True)
|
||||
|
||||
user_id = Column(UUID(as_uuid=True), ForeignKey("user.id", ondelete='CASCADE'), nullable=False)
|
||||
user = relationship("User", back_populates="search_spaces")
|
||||
|
||||
documents = relationship("Document", back_populates="search_space", order_by="Document.id", cascade="all, delete-orphan")
|
||||
podcasts = relationship("Podcast", back_populates="search_space", order_by="Podcast.id", cascade="all, delete-orphan")
|
||||
chats = relationship('Chat', back_populates='search_space', order_by='Chat.id', cascade="all, delete-orphan")
|
||||
|
||||
class SearchSourceConnector(BaseModel, TimestampMixin):
|
||||
__tablename__ = "search_source_connectors"
|
||||
|
||||
name = Column(String(100), nullable=False, index=True)
|
||||
connector_type = Column(SQLAlchemyEnum(SearchSourceConnectorType), nullable=False, unique=True)
|
||||
is_indexable = Column(Boolean, nullable=False, default=False)
|
||||
last_indexed_at = Column(TIMESTAMP(timezone=True), nullable=True)
|
||||
config = Column(JSON, nullable=False)
|
||||
|
||||
user_id = Column(UUID(as_uuid=True), ForeignKey("user.id", ondelete='CASCADE'), nullable=False)
|
||||
user = relationship("User", back_populates="search_source_connectors")
|
||||
|
||||
|
||||
class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
|
||||
pass
|
||||
|
||||
|
||||
class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
oauth_accounts: Mapped[list[OAuthAccount]] = relationship(
|
||||
"OAuthAccount", lazy="joined"
|
||||
)
|
||||
search_spaces = relationship("SearchSpace", back_populates="user")
|
||||
search_source_connectors = relationship("SearchSourceConnector", back_populates="user")
|
||||
|
||||
|
||||
engine = create_async_engine(DATABASE_URL)
|
||||
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
|
||||
async def setup_indexes():
|
||||
async with engine.begin() as conn:
|
||||
# Create indexes
|
||||
# Document Summary Indexes
|
||||
await conn.execute(text('CREATE INDEX IF NOT EXISTS document_vector_index ON documents USING hnsw (embedding public.vector_cosine_ops)'))
|
||||
await conn.execute(text('CREATE INDEX IF NOT EXISTS document_search_index ON documents USING gin (to_tsvector(\'english\', content))'))
|
||||
# Document Chunk Indexes
|
||||
await conn.execute(text('CREATE INDEX IF NOT EXISTS chucks_vector_index ON chunks USING hnsw (embedding public.vector_cosine_ops)'))
|
||||
await conn.execute(text('CREATE INDEX IF NOT EXISTS chucks_search_index ON chunks USING gin (to_tsvector(\'english\', content))'))
|
||||
|
||||
async def create_db_and_tables():
|
||||
async with engine.begin() as conn:
|
||||
await conn.execute(text('CREATE EXTENSION IF NOT EXISTS vector'))
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
await setup_indexes()
|
||||
|
||||
|
||||
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
async with async_session_maker() as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def get_user_db(session: AsyncSession = Depends(get_async_session)):
|
||||
yield SQLAlchemyUserDatabase(session, User, OAuthAccount)
|
||||
|
||||
async def get_chucks_hybrid_search_retriever(session: AsyncSession = Depends(get_async_session)):
|
||||
return ChucksHybridSearchRetriever(session)
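
A minimal sketch (not part of this commit) of how a route could consume the get_chucks_hybrid_search_retriever dependency above; the endpoint path and query parameters are illustrative assumptions.

# Illustrative only -- endpoint path and parameters are assumptions.
from fastapi import APIRouter, Depends

from app.db import User, get_chucks_hybrid_search_retriever
from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
from app.users import current_active_user

example_router = APIRouter()

@example_router.get("/example/chunks/search")
async def example_chunk_search(
    query: str,
    search_space_id: int,
    retriever: ChucksHybridSearchRetriever = Depends(get_chucks_hybrid_search_retriever),
    user: User = Depends(current_active_user),
):
    # Delegates ranking to the hybrid retriever defined in chunks_hybrid_search.py
    return await retriever.hybrid_search(
        query_text=query,
        top_k=10,
        user_id=user.id,
        search_space_id=search_space_id,
    )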

surfsense_backend/app/prompts/__init__.py (new file, 103 lines)
@@ -0,0 +1,103 @@
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from datetime import datetime, timezone
|
||||
|
||||
DATE_TODAY = "Today's date is " + datetime.now(timezone.utc).astimezone().isoformat() + '\n'
|
||||
|
||||
SUMMARY_PROMPT = DATE_TODAY + """
|
||||
<INSTRUCTIONS>
|
||||
<context>
|
||||
You are an expert document analyst and summarization specialist tasked with distilling complex information into clear,
|
||||
comprehensive summaries. Your role is to analyze documents thoroughly and create structured summaries that:
|
||||
1. Capture the complete essence and key insights of the source material
|
||||
2. Maintain perfect accuracy and factual precision
|
||||
3. Present information objectively without bias or interpretation
|
||||
4. Preserve critical context and logical relationships
|
||||
5. Structure content in a clear, hierarchical format
|
||||
</context>
|
||||
|
||||
<principles>
|
||||
<accuracy>
|
||||
- Maintain absolute factual accuracy and fidelity to source material
|
||||
- Avoid any subjective interpretation, inference or speculation
|
||||
- Preserve complete original meaning, nuance and contextual relationships
|
||||
- Report all quantitative data with precise values and appropriate units
|
||||
- Verify and cross-reference facts before inclusion
|
||||
- Flag any ambiguous or unclear information
|
||||
</accuracy>
|
||||
|
||||
<objectivity>
|
||||
- Present information with strict neutrality and impartiality
|
||||
- Exclude all forms of bias, personal opinions, and editorial commentary
|
||||
- Ensure balanced representation of all perspectives and viewpoints
|
||||
- Maintain objective professional distance from the content
|
||||
- Use precise, factual language free from emotional coloring
|
||||
- Focus solely on verifiable information and evidence
|
||||
</objectivity>
|
||||
|
||||
<comprehensiveness>
|
||||
- Capture all essential information, key themes, and central arguments
|
||||
- Preserve critical context and background necessary for understanding
|
||||
- Include relevant supporting details, examples, and evidence
|
||||
- Maintain logical flow and connections between concepts
|
||||
- Ensure hierarchical organization of information
|
||||
- Document relationships between different components
|
||||
- Highlight dependencies and causal links
|
||||
- Track chronological progression where relevant
|
||||
</comprehensiveness>
|
||||
</principles>
|
||||
|
||||
<output_format>
|
||||
<type>
|
||||
- Return summary in clean markdown format
|
||||
- Do not include markdown code block tags (```markdown ```)
|
||||
- Use standard markdown syntax for formatting (headers, lists, etc.)
|
||||
- Use # for main headings (e.g., # EXECUTIVE SUMMARY)
|
||||
- Use ## for subheadings where appropriate
|
||||
- Use bullet points (- item) for lists
|
||||
- Ensure proper indentation and spacing
|
||||
- Use appropriate emphasis (**bold**, *italic*) where needed
|
||||
</type>
|
||||
<style>
|
||||
- Use clear, concise language focused on key points
|
||||
- Maintain professional and objective tone throughout
|
||||
- Follow consistent formatting and style conventions
|
||||
- Provide descriptive section headings and subheadings
|
||||
- Utilize bullet points and lists for better readability
|
||||
- Structure content with clear hierarchy and organization
|
||||
- Avoid jargon and overly technical language
|
||||
- Include transition sentences between sections
|
||||
</style>
|
||||
</output_format>
|
||||
|
||||
<validation>
|
||||
<criteria>
|
||||
- Verify all facts and claims match source material exactly
|
||||
- Cross-reference and validate all numerical data points
|
||||
- Ensure logical flow and consistency throughout summary
|
||||
- Confirm comprehensive coverage of key information
|
||||
- Check for objective, unbiased language and tone
|
||||
- Validate accurate representation of source context
|
||||
- Review for proper attribution of ideas and quotes
|
||||
- Verify temporal accuracy and chronological order
|
||||
</criteria>
|
||||
</validation>
|
||||
|
||||
<length_guidelines>
|
||||
- Scale summary length proportionally to source document complexity and length
|
||||
- Minimum: 3-5 well-developed paragraphs per major section
|
||||
- Maximum: 8-10 paragraphs per section for highly complex documents
|
||||
- Adjust level of detail based on information density and importance
|
||||
- Ensure key concepts receive adequate coverage regardless of length
|
||||
</length_guidelines>
|
||||
|
||||
Now, create a summary of the following document:
|
||||
<document_to_summarize>
|
||||
{document}
|
||||
</document_to_summarize>
|
||||
</INSTRUCTIONS>
|
||||
"""
|
||||
|
||||
SUMMARY_PROMPT_TEMPLATE = PromptTemplate(
|
||||
input_variables=["document"],
|
||||
template=SUMMARY_PROMPT
|
||||
)
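
A hedged usage sketch for SUMMARY_PROMPT_TEMPLATE; the ChatOpenAI wiring below is an assumption for illustration and is not defined in this module.

# Hypothetical usage -- the model choice is an assumption, not part of this commit.
from langchain_openai import ChatOpenAI

from app.prompts import SUMMARY_PROMPT_TEMPLATE

llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

def summarize(document_text: str) -> str:
    # Fill the {document} placeholder declared in input_variables
    prompt = SUMMARY_PROMPT_TEMPLATE.format(document=document_text)
    return llm.invoke(prompt).content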

surfsense_backend/app/retriver/__init__.py (new file, empty)

surfsense_backend/app/retriver/chunks_hybrid_search.py (new file, 243 lines)
@@ -0,0 +1,243 @@
class ChucksHybridSearchRetriever:
|
||||
def __init__(self, db_session):
|
||||
"""
|
||||
Initialize the hybrid search retriever with a database session.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy AsyncSession from FastAPI dependency injection
|
||||
"""
|
||||
self.db_session = db_session
|
||||
|
||||
async def vector_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None) -> list:
|
||||
"""
|
||||
Perform vector similarity search on chunks.
|
||||
|
||||
Args:
|
||||
query_text: The search query text
|
||||
top_k: Number of results to return
|
||||
user_id: The ID of the user performing the search
|
||||
search_space_id: Optional search space ID to filter results
|
||||
|
||||
Returns:
|
||||
List of chunks sorted by vector similarity
|
||||
"""
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.orm import joinedload
|
||||
from app.db import Chunk, Document, SearchSpace
|
||||
from app.config import config
|
||||
|
||||
# Get embedding for the query
|
||||
embedding_model = config.embedding_model_instance
|
||||
query_embedding = embedding_model.embed(query_text)
|
||||
|
||||
# Build the base query with user ownership check
|
||||
query = (
|
||||
select(Chunk)
|
||||
.options(joinedload(Chunk.document).joinedload(Document.search_space))
|
||||
.join(Document, Chunk.document_id == Document.id)
|
||||
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
||||
.where(SearchSpace.user_id == user_id)
|
||||
)
|
||||
|
||||
# Add search space filter if provided
|
||||
if search_space_id is not None:
|
||||
query = query.where(Document.search_space_id == search_space_id)
|
||||
|
||||
# Add vector similarity ordering
|
||||
query = (
|
||||
query
|
||||
.order_by(Chunk.embedding.op("<=>")(query_embedding))
|
||||
.limit(top_k)
|
||||
)
|
||||
|
||||
# Execute the query
|
||||
result = await self.db_session.execute(query)
|
||||
chunks = result.scalars().all()
|
||||
|
||||
return chunks
|
||||
|
||||
async def full_text_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None) -> list:
|
||||
"""
|
||||
Perform full-text keyword search on chunks.
|
||||
|
||||
Args:
|
||||
query_text: The search query text
|
||||
top_k: Number of results to return
|
||||
user_id: The ID of the user performing the search
|
||||
search_space_id: Optional search space ID to filter results
|
||||
|
||||
Returns:
|
||||
List of chunks sorted by text relevance
|
||||
"""
|
||||
from sqlalchemy import select, func, text
|
||||
from sqlalchemy.orm import joinedload
|
||||
from app.db import Chunk, Document, SearchSpace
|
||||
|
||||
# Create tsvector and tsquery for PostgreSQL full-text search
|
||||
tsvector = func.to_tsvector('english', Chunk.content)
|
||||
tsquery = func.plainto_tsquery('english', query_text)
|
||||
|
||||
# Build the base query with user ownership check
|
||||
query = (
|
||||
select(Chunk)
|
||||
.options(joinedload(Chunk.document).joinedload(Document.search_space))
|
||||
.join(Document, Chunk.document_id == Document.id)
|
||||
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
||||
.where(SearchSpace.user_id == user_id)
|
||||
.where(tsvector.op("@@")(tsquery)) # Only include results that match the query
|
||||
)
|
||||
|
||||
# Add search space filter if provided
|
||||
if search_space_id is not None:
|
||||
query = query.where(Document.search_space_id == search_space_id)
|
||||
|
||||
# Add text search ranking
|
||||
query = (
|
||||
query
|
||||
.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
|
||||
.limit(top_k)
|
||||
)
|
||||
|
||||
# Execute the query
|
||||
result = await self.db_session.execute(query)
|
||||
chunks = result.scalars().all()
|
||||
|
||||
return chunks
|
||||
|
||||
async def hybrid_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None, document_type: str = None) -> list:
|
||||
"""
|
||||
Combine vector similarity and full-text search results using Reciprocal Rank Fusion.
|
||||
|
||||
Args:
|
||||
query_text: The search query text
|
||||
top_k: Number of results to return
|
||||
user_id: The ID of the user performing the search
|
||||
search_space_id: Optional search space ID to filter results
|
||||
document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL")
|
||||
|
||||
Returns:
|
||||
List of dictionaries containing chunk data and relevance scores
|
||||
"""
|
||||
from sqlalchemy import select, func, text
|
||||
from sqlalchemy.orm import joinedload
|
||||
from app.db import Chunk, Document, SearchSpace, DocumentType
|
||||
from app.config import config
|
||||
|
||||
# Get embedding for the query
|
||||
embedding_model = config.embedding_model_instance
|
||||
query_embedding = embedding_model.embed(query_text)
|
||||
|
||||
# Constants for RRF calculation
|
||||
k = 60 # Constant for RRF calculation
|
||||
n_results = top_k * 2 # Get more results for better fusion
|
||||
|
||||
# Create tsvector and tsquery for PostgreSQL full-text search
|
||||
tsvector = func.to_tsvector('english', Chunk.content)
|
||||
tsquery = func.plainto_tsquery('english', query_text)
|
||||
|
||||
# Base conditions for document filtering
|
||||
base_conditions = [SearchSpace.user_id == user_id]
|
||||
|
||||
# Add search space filter if provided
|
||||
if search_space_id is not None:
|
||||
base_conditions.append(Document.search_space_id == search_space_id)
|
||||
|
||||
# Add document type filter if provided
|
||||
if document_type is not None:
|
||||
# Convert string to enum value if needed
|
||||
if isinstance(document_type, str):
|
||||
try:
|
||||
doc_type_enum = DocumentType[document_type]
|
||||
base_conditions.append(Document.document_type == doc_type_enum)
|
||||
except KeyError:
|
||||
# If the document type doesn't exist in the enum, return empty results
|
||||
return []
|
||||
else:
|
||||
base_conditions.append(Document.document_type == document_type)
|
||||
|
||||
# CTE for semantic search with user ownership check
|
||||
semantic_search_cte = (
|
||||
select(
|
||||
Chunk.id,
|
||||
func.rank().over(order_by=Chunk.embedding.op("<=>")(query_embedding)).label("rank")
|
||||
)
|
||||
.join(Document, Chunk.document_id == Document.id)
|
||||
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
||||
.where(*base_conditions)
|
||||
)
|
||||
|
||||
semantic_search_cte = (
|
||||
semantic_search_cte
|
||||
.order_by(Chunk.embedding.op("<=>")(query_embedding))
|
||||
.limit(n_results)
|
||||
.cte("semantic_search")
|
||||
)
|
||||
|
||||
# CTE for keyword search with user ownership check
|
||||
keyword_search_cte = (
|
||||
select(
|
||||
Chunk.id,
|
||||
func.rank().over(order_by=func.ts_rank_cd(tsvector, tsquery).desc()).label("rank")
|
||||
)
|
||||
.join(Document, Chunk.document_id == Document.id)
|
||||
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
||||
.where(*base_conditions)
|
||||
.where(tsvector.op("@@")(tsquery))
|
||||
)
|
||||
|
||||
keyword_search_cte = (
|
||||
keyword_search_cte
|
||||
.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
|
||||
.limit(n_results)
|
||||
.cte("keyword_search")
|
||||
)
|
||||
|
||||
# Final combined query using a FULL OUTER JOIN with RRF scoring
|
||||
final_query = (
|
||||
select(
|
||||
Chunk,
|
||||
(
|
||||
func.coalesce(1.0 / (k + semantic_search_cte.c.rank), 0.0) +
|
||||
func.coalesce(1.0 / (k + keyword_search_cte.c.rank), 0.0)
|
||||
).label("score")
|
||||
)
|
||||
.select_from(
|
||||
semantic_search_cte.outerjoin(
|
||||
keyword_search_cte,
|
||||
semantic_search_cte.c.id == keyword_search_cte.c.id,
|
||||
full=True
|
||||
)
|
||||
)
|
||||
.join(
|
||||
Chunk,
|
||||
Chunk.id == func.coalesce(semantic_search_cte.c.id, keyword_search_cte.c.id)
|
||||
)
|
||||
.options(joinedload(Chunk.document))
|
||||
.order_by(text("score DESC"))
|
||||
.limit(top_k)
|
||||
)
|
||||
|
||||
# Execute the query
|
||||
result = await self.db_session.execute(final_query)
|
||||
chunks_with_scores = result.all()
|
||||
|
||||
# If no results were found, return an empty list
|
||||
if not chunks_with_scores:
|
||||
return []
|
||||
|
||||
# Convert the fused results to serializable dictionaries with chunk and document metadata
|
||||
serialized_results = []
|
||||
for chunk, score in chunks_with_scores:
|
||||
serialized_results.append({
|
||||
"chunk_id": chunk.id,
|
||||
"content": chunk.content,
|
||||
"score": float(score), # Ensure score is a Python float
|
||||
"document": {
|
||||
"id": chunk.document.id,
|
||||
"title": chunk.document.title,
|
||||
"document_type": chunk.document.document_type.value if hasattr(chunk.document, 'document_type') else None,
|
||||
"metadata": chunk.document.document_metadata
|
||||
}
|
||||
})
|
||||
|
||||
return serialized_results
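
The score computed in the final query above is standard Reciprocal Rank Fusion. A standalone illustration of the same formula (using the k = 60 constant from hybrid_search):

# Toy illustration of the RRF scoring used in hybrid_search:
# score = 1/(k + semantic_rank) + 1/(k + keyword_rank); a missing rank contributes 0.
def rrf_score(semantic_rank=None, keyword_rank=None, k=60):
    score = 0.0
    if semantic_rank is not None:
        score += 1.0 / (k + semantic_rank)
    if keyword_rank is not None:
        score += 1.0 / (k + keyword_rank)
    return score

# A chunk ranked 1st semantically and 3rd by keywords outranks one found
# only by the semantic search, mirroring the coalesce() sum in the SQL.
print(rrf_score(1, 3))     # ~0.0323
print(rrf_score(1, None))  # ~0.0164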

surfsense_backend/app/routes/__init__.py (new file, 14 lines)
@@ -0,0 +1,14 @@
from fastapi import APIRouter
from .search_spaces_routes import router as search_spaces_router
from .documents_routes import router as documents_router
from .podcasts_routes import router as podcasts_router
from .chats_routes import router as chats_router
from .search_source_connectors_routes import router as search_source_connectors_router

router = APIRouter()

router.include_router(search_spaces_router)
router.include_router(documents_router)
router.include_router(podcasts_router)
router.include_router(chats_router)
router.include_router(search_source_connectors_router)
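
A sketch of how the aggregated router might be mounted on the application; the app module and URL prefix are assumptions, since the server entry point is not shown in this diff.

# Sketch only -- app wiring and the "/api/v1" prefix are assumptions.
from fastapi import FastAPI

from app.routes import router as api_router

app = FastAPI(title="SurfSense Backend")
app.include_router(api_router, prefix="/api/v1")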

surfsense_backend/app/routes/chats_routes.py (new file, 260 lines)
@@ -0,0 +1,260 @@
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
from typing import List
|
||||
from app.db import get_async_session, User, SearchSpace, Chat
|
||||
from app.schemas import ChatCreate, ChatUpdate, ChatRead, AISDKChatRequest
|
||||
from app.tasks.stream_connector_search_results import stream_connector_search_results
|
||||
from app.users import current_active_user
|
||||
from app.utils.check_ownership import check_ownership
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/chat")
|
||||
async def handle_chat_data(
|
||||
request: AISDKChatRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
):
|
||||
messages = request.messages
|
||||
if messages[-1].role != "user":
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Last message must be a user message")
|
||||
|
||||
user_query = messages[-1].content
|
||||
search_space_id = request.data.get('search_space_id')
|
||||
research_mode: str = request.data.get('research_mode')
|
||||
selected_connectors: List[str] = request.data.get('selected_connectors')
|
||||
|
||||
# Convert search_space_id to integer if it's a string
|
||||
if search_space_id and isinstance(search_space_id, str):
|
||||
try:
|
||||
search_space_id = int(search_space_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Invalid search_space_id format")
|
||||
|
||||
# Check if the search space belongs to the current user
|
||||
try:
|
||||
await check_ownership(session, SearchSpace, search_space_id, user)
|
||||
except HTTPException:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You don't have access to this search space")
|
||||
|
||||
response = StreamingResponse(stream_connector_search_results(
|
||||
user_query,
|
||||
user.id,
|
||||
search_space_id,
|
||||
session,
|
||||
research_mode,
|
||||
selected_connectors
|
||||
))
|
||||
response.headers['x-vercel-ai-data-stream'] = 'v1'
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/chats/", response_model=ChatRead)
|
||||
async def create_chat(
|
||||
chat: ChatCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
):
|
||||
try:
|
||||
await check_ownership(session, SearchSpace, chat.search_space_id, user)
|
||||
db_chat = Chat(**chat.model_dump())
|
||||
session.add(db_chat)
|
||||
await session.commit()
|
||||
await session.refresh(db_chat)
|
||||
return db_chat
|
||||
except HTTPException:
|
||||
raise
|
||||
except IntegrityError as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Database constraint violation. Please check your input data.")
|
||||
except OperationalError as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Database operation failed. Please try again later.")
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500, detail="An unexpected error occurred while creating the chat.")
|
||||
|
||||
|
||||
@router.get("/chats/", response_model=List[ChatRead])
|
||||
async def read_chats(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
):
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(Chat)
|
||||
.join(SearchSpace)
|
||||
.filter(SearchSpace.user_id == user.id)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
return result.scalars().all()
|
||||
except OperationalError:
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Database operation failed. Please try again later.")
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="An unexpected error occurred while fetching chats.")
|
||||
|
||||
|
||||
@router.get("/chats/{chat_id}", response_model=ChatRead)
|
||||
async def read_chat(
|
||||
chat_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
):
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(Chat)
|
||||
.join(SearchSpace)
|
||||
.filter(Chat.id == chat_id, SearchSpace.user_id == user.id)
|
||||
)
|
||||
chat = result.scalars().first()
|
||||
if not chat:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Chat not found or you don't have permission to access it")
|
||||
return chat
|
||||
except OperationalError:
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Database operation failed. Please try again later.")
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="An unexpected error occurred while fetching the chat.")
|
||||
|
||||
|
||||
@router.put("/chats/{chat_id}", response_model=ChatRead)
|
||||
async def update_chat(
|
||||
chat_id: int,
|
||||
chat_update: ChatUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
):
|
||||
try:
|
||||
db_chat = await read_chat(chat_id, session, user)
|
||||
update_data = chat_update.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(db_chat, key, value)
|
||||
await session.commit()
|
||||
await session.refresh(db_chat)
|
||||
return db_chat
|
||||
except HTTPException:
|
||||
raise
|
||||
except IntegrityError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Database constraint violation. Please check your input data.")
|
||||
except OperationalError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Database operation failed. Please try again later.")
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500, detail="An unexpected error occurred while updating the chat.")
|
||||
|
||||
|
||||
@router.delete("/chats/{chat_id}", response_model=dict)
|
||||
async def delete_chat(
|
||||
chat_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
):
|
||||
try:
|
||||
db_chat = await read_chat(chat_id, session, user)
|
||||
await session.delete(db_chat)
|
||||
await session.commit()
|
||||
return {"message": "Chat deleted successfully"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except IntegrityError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Cannot delete chat due to existing dependencies.")
|
||||
except OperationalError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Database operation failed. Please try again later.")
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500, detail="An unexpected error occurred while deleting the chat.")
|
||||
|
||||
|
||||
# test_data = [
|
||||
# {
|
||||
# "type": "TERMINAL_INFO",
|
||||
# "content": [
|
||||
# {
|
||||
# "id": 1,
|
||||
# "text": "Starting to search for crawled URLs...",
|
||||
# "type": "info"
|
||||
# },
|
||||
# {
|
||||
# "id": 2,
|
||||
# "text": "Found 2 relevant crawled URLs",
|
||||
# "type": "success"
|
||||
# }
|
||||
# ]
|
||||
# },
|
||||
# {
|
||||
# "type": "SOURCES",
|
||||
# "content": [
|
||||
# {
|
||||
# "id": 1,
|
||||
# "name": "Crawled URLs",
|
||||
# "type": "CRAWLED_URL",
|
||||
# "sources": [
|
||||
# {
|
||||
# "id": 1,
|
||||
# "title": "Webpage Title",
|
||||
# "description": "Webpage Dec",
|
||||
# "url": "https://jsoneditoronline.org/"
|
||||
# },
|
||||
# {
|
||||
# "id": 2,
|
||||
# "title": "Webpage Title",
|
||||
# "description": "Webpage Dec",
|
||||
# "url": "https://www.google.com/"
|
||||
# }
|
||||
# ]
|
||||
# },
|
||||
# {
|
||||
# "id": 2,
|
||||
# "name": "Files",
|
||||
# "type": "FILE",
|
||||
# "sources": [
|
||||
# {
|
||||
# "id": 3,
|
||||
# "title": "Webpage Title",
|
||||
# "description": "Webpage Dec",
|
||||
# "url": "https://jsoneditoronline.org/"
|
||||
# },
|
||||
# {
|
||||
# "id": 4,
|
||||
# "title": "Webpage Title",
|
||||
# "description": "Webpage Dec",
|
||||
# "url": "https://www.google.com/"
|
||||
# }
|
||||
# ]
|
||||
# }
|
||||
# ]
|
||||
# },
|
||||
# {
|
||||
# "type": "ANSWER",
|
||||
# "content": [
|
||||
# "## SurfSense Introduction",
|
||||
# "Surfsense is A Personal NotebookLM and Perplexity-like AI Assistant for Everyone. Research and Never forget Anything. [1] [3]"
|
||||
# ]
|
||||
# }
|
||||
# ]
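
An illustrative request payload for POST /chat, inferred from how handle_chat_data reads messages and data above; the field values are examples only.

# Example payload only -- values are illustrative, not canonical.
example_chat_request = {
    "messages": [
        {"role": "user", "content": "Summarize my notes on pgvector."}
    ],
    "data": {
        "search_space_id": "1",          # strings are coerced to int by the handler
        "research_mode": "GENERAL",      # assumed value; the handler treats it as a plain string
        "selected_connectors": ["CRAWLED_URL", "FILE"],  # assumed connector names
    },
}
# The response streams and carries the x-vercel-ai-data-stream: v1 header.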

surfsense_backend/app/routes/documents_routes.py (new file, 262 lines)
@@ -0,0 +1,262 @@
from fastapi import APIRouter, Depends, BackgroundTasks, UploadFile, Form, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from typing import List
|
||||
from app.db import get_async_session, User, SearchSpace, Document, DocumentType
|
||||
from app.schemas import DocumentsCreate, DocumentUpdate, DocumentRead
|
||||
from app.users import current_active_user
|
||||
from app.utils.check_ownership import check_ownership
|
||||
from app.tasks.background_tasks import add_extension_received_document, add_received_file_document, add_crawled_url_document
|
||||
from langchain_unstructured import UnstructuredLoader
|
||||
from app.config import config
|
||||
import json
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/documents/")
|
||||
async def create_documents(
|
||||
request: DocumentsCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
fastapi_background_tasks: BackgroundTasks = BackgroundTasks()
|
||||
):
|
||||
try:
|
||||
# Check if the user owns the search space
|
||||
await check_ownership(session, SearchSpace, request.search_space_id, user)
|
||||
|
||||
if request.document_type == DocumentType.EXTENSION:
|
||||
for individual_document in request.content:
|
||||
fastapi_background_tasks.add_task(
|
||||
add_extension_received_document,
|
||||
session,
|
||||
individual_document,
|
||||
request.search_space_id
|
||||
)
|
||||
elif request.document_type == DocumentType.CRAWLED_URL:
|
||||
for url in request.content:
|
||||
fastapi_background_tasks.add_task(
|
||||
add_crawled_url_document,
|
||||
session,
|
||||
url,
|
||||
request.search_space_id
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid document type"
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
return {"message": "Documents processed successfully"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to process documents: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/documents/fileupload")
|
||||
async def create_documents_file_upload(
|
||||
files: list[UploadFile],
|
||||
search_space_id: int = Form(...),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
fastapi_background_tasks: BackgroundTasks = BackgroundTasks()
|
||||
):
|
||||
try:
|
||||
await check_ownership(session, SearchSpace, search_space_id, user)
|
||||
|
||||
if not files:
|
||||
raise HTTPException(status_code=400, detail="No files provided")
|
||||
|
||||
for file in files:
|
||||
try:
|
||||
unstructured_loader = UnstructuredLoader(
|
||||
file=file.file,
|
||||
api_key=config.UNSTRUCTURED_API_KEY,
|
||||
partition_via_api=True,
|
||||
languages=["eng"],
|
||||
include_orig_elements=False,
|
||||
strategy="fast",
|
||||
)
|
||||
|
||||
unstructured_processed_elements = await unstructured_loader.aload()
|
||||
|
||||
fastapi_background_tasks.add_task(
|
||||
add_received_file_document,
|
||||
session,
|
||||
file.filename,
|
||||
unstructured_processed_elements,
|
||||
search_space_id
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Failed to process file {file.filename}: {str(e)}"
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
return {"message": "Files added for processing successfully"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to process documents: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/documents/", response_model=List[DocumentRead])
|
||||
async def read_documents(
|
||||
skip: int = 0,
|
||||
limit: int = 300,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
):
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(Document)
|
||||
.join(SearchSpace)
|
||||
.filter(SearchSpace.user_id == user.id)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
db_documents = result.scalars().all()
|
||||
|
||||
# Convert database objects to API-friendly format
|
||||
api_documents = []
|
||||
for doc in db_documents:
|
||||
api_documents.append(DocumentRead(
|
||||
id=doc.id,
|
||||
title=doc.title,
|
||||
document_type=doc.document_type,
|
||||
document_metadata=doc.document_metadata,
|
||||
content=doc.content,
|
||||
created_at=doc.created_at,
|
||||
search_space_id=doc.search_space_id
|
||||
))
|
||||
|
||||
return api_documents
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to fetch documents: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/documents/{document_id}", response_model=DocumentRead)
|
||||
async def read_document(
|
||||
document_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
):
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(Document)
|
||||
.join(SearchSpace)
|
||||
.filter(Document.id == document_id, SearchSpace.user_id == user.id)
|
||||
)
|
||||
document = result.scalars().first()
|
||||
|
||||
if not document:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Document with id {document_id} not found"
|
||||
)
|
||||
|
||||
# Convert database object to API-friendly format
|
||||
return DocumentRead(
|
||||
id=document.id,
|
||||
title=document.title,
|
||||
document_type=document.document_type,
|
||||
document_metadata=document.document_metadata,
|
||||
content=document.content,
|
||||
created_at=document.created_at,
|
||||
search_space_id=document.search_space_id
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to fetch document: {str(e)}"
|
||||
)
|
||||
|
||||
@router.put("/documents/{document_id}", response_model=DocumentRead)
|
||||
async def update_document(
|
||||
document_id: int,
|
||||
document_update: DocumentUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
):
|
||||
try:
|
||||
# Query the document directly instead of using read_document function
|
||||
result = await session.execute(
|
||||
select(Document)
|
||||
.join(SearchSpace)
|
||||
.filter(Document.id == document_id, SearchSpace.user_id == user.id)
|
||||
)
|
||||
db_document = result.scalars().first()
|
||||
|
||||
if not db_document:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Document with id {document_id} not found"
|
||||
)
|
||||
|
||||
update_data = document_update.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(db_document, key, value)
|
||||
await session.commit()
|
||||
await session.refresh(db_document)
|
||||
|
||||
# Convert to DocumentRead for response
|
||||
return DocumentRead(
|
||||
id=db_document.id,
|
||||
title=db_document.title,
|
||||
document_type=db_document.document_type,
|
||||
document_metadata=db_document.document_metadata,
|
||||
content=db_document.content,
|
||||
created_at=db_document.created_at,
|
||||
search_space_id=db_document.search_space_id
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to update document: {str(e)}"
|
||||
)
|
||||
|
||||
@router.delete("/documents/{document_id}", response_model=dict)
|
||||
async def delete_document(
|
||||
document_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
):
|
||||
try:
|
||||
# Query the document directly instead of using read_document function
|
||||
result = await session.execute(
|
||||
select(Document)
|
||||
.join(SearchSpace)
|
||||
.filter(Document.id == document_id, SearchSpace.user_id == user.id)
|
||||
)
|
||||
document = result.scalars().first()
|
||||
|
||||
if not document:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Document with id {document_id} not found"
|
||||
)
|
||||
|
||||
await session.delete(document)
|
||||
await session.commit()
|
||||
return {"message": "Document deleted successfully"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to delete document: {str(e)}"
|
||||
)
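
A client-side sketch for the /documents/fileupload endpoint above, using httpx; the base URL and bearer-token auth are assumptions, not defined by this commit.

# Client sketch only -- base URL and auth header are assumptions.
import httpx

def upload_files(paths: list[str], search_space_id: int, token: str) -> dict:
    files = [("files", (path.split("/")[-1], open(path, "rb"))) for path in paths]
    try:
        response = httpx.post(
            "http://localhost:8000/documents/fileupload",
            data={"search_space_id": str(search_space_id)},
            files=files,
            headers={"Authorization": f"Bearer {token}"},
            timeout=120,
        )
        response.raise_for_status()
        return response.json()  # {"message": "Files added for processing successfully"}
    finally:
        for _, (_, fh) in files:
            fh.close()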

surfsense_backend/app/routes/podcasts_routes.py (new file, 122 lines)
@@ -0,0 +1,122 @@
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
|
||||
from typing import List
|
||||
from app.db import get_async_session, User, SearchSpace, Podcast
|
||||
from app.schemas import PodcastCreate, PodcastUpdate, PodcastRead
|
||||
from app.users import current_active_user
|
||||
from app.utils.check_ownership import check_ownership
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/podcasts/", response_model=PodcastRead)
|
||||
async def create_podcast(
|
||||
podcast: PodcastCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
):
|
||||
try:
|
||||
await check_ownership(session, SearchSpace, podcast.search_space_id, user)
|
||||
db_podcast = Podcast(**podcast.model_dump())
|
||||
session.add(db_podcast)
|
||||
await session.commit()
|
||||
await session.refresh(db_podcast)
|
||||
return db_podcast
|
||||
except HTTPException as he:
|
||||
raise he
|
||||
except IntegrityError as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(status_code=400, detail="Podcast creation failed due to constraint violation")
|
||||
except SQLAlchemyError as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(status_code=500, detail="Database error occurred while creating podcast")
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(status_code=500, detail="An unexpected error occurred")
|
||||
|
||||
@router.get("/podcasts/", response_model=List[PodcastRead])
|
||||
async def read_podcasts(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
):
|
||||
if skip < 0 or limit < 1:
|
||||
raise HTTPException(status_code=400, detail="Invalid pagination parameters")
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(Podcast)
|
||||
.join(SearchSpace)
|
||||
.filter(SearchSpace.user_id == user.id)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
return result.scalars().all()
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(status_code=500, detail="Database error occurred while fetching podcasts")
|
||||
|
||||
@router.get("/podcasts/{podcast_id}", response_model=PodcastRead)
|
||||
async def read_podcast(
|
||||
podcast_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
):
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(Podcast)
|
||||
.join(SearchSpace)
|
||||
.filter(Podcast.id == podcast_id, SearchSpace.user_id == user.id)
|
||||
)
|
||||
podcast = result.scalars().first()
|
||||
if not podcast:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Podcast not found or you don't have permission to access it"
|
||||
)
|
||||
return podcast
|
||||
except HTTPException as he:
|
||||
raise he
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(status_code=500, detail="Database error occurred while fetching podcast")
|
||||
|
||||
@router.put("/podcasts/{podcast_id}", response_model=PodcastRead)
|
||||
async def update_podcast(
|
||||
podcast_id: int,
|
||||
podcast_update: PodcastUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
):
|
||||
try:
|
||||
db_podcast = await read_podcast(podcast_id, session, user)
|
||||
update_data = podcast_update.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(db_podcast, key, value)
|
||||
await session.commit()
|
||||
await session.refresh(db_podcast)
|
||||
return db_podcast
|
||||
except HTTPException as he:
|
||||
raise he
|
||||
except IntegrityError:
|
||||
await session.rollback()
|
||||
raise HTTPException(status_code=400, detail="Update failed due to constraint violation")
|
||||
except SQLAlchemyError:
|
||||
await session.rollback()
|
||||
raise HTTPException(status_code=500, detail="Database error occurred while updating podcast")
|
||||
|
||||
@router.delete("/podcasts/{podcast_id}", response_model=dict)
|
||||
async def delete_podcast(
|
||||
podcast_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
):
|
||||
try:
|
||||
db_podcast = await read_podcast(podcast_id, session, user)
|
||||
await session.delete(db_podcast)
|
||||
await session.commit()
|
||||
return {"message": "Podcast deleted successfully"}
|
||||
except HTTPException as he:
|
||||
raise he
|
||||
except SQLAlchemyError:
|
||||
await session.rollback()
|
||||
raise HTTPException(status_code=500, detail="Database error occurred while deleting podcast")

surfsense_backend/app/routes/search_source_connectors_routes.py (new file, 418 lines)
@@ -0,0 +1,418 @@
"""
|
||||
SearchSourceConnector routes for CRUD operations:
|
||||
POST /search-source-connectors/ - Create a new connector
|
||||
GET /search-source-connectors/ - List all connectors for the current user
|
||||
GET /search-source-connectors/{connector_id} - Get a specific connector
|
||||
PUT /search-source-connectors/{connector_id} - Update a specific connector
|
||||
DELETE /search-source-connectors/{connector_id} - Delete a specific connector
|
||||
POST /search-source-connectors/{connector_id}/index - Index content from a connector to a search space
|
||||
|
||||
Note: Each user can have only one connector of each type (SERPER_API, TAVILY_API, SLACK_CONNECTOR, NOTION_CONNECTOR).
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from typing import List, Dict, Any
|
||||
from app.db import get_async_session, User, SearchSourceConnector, SearchSourceConnectorType, SearchSpace
|
||||
from app.schemas import SearchSourceConnectorCreate, SearchSourceConnectorUpdate, SearchSourceConnectorRead
|
||||
from app.users import current_active_user
|
||||
from app.utils.check_ownership import check_ownership
|
||||
from pydantic import ValidationError
|
||||
from app.tasks.connectors_indexing_tasks import index_slack_messages, index_notion_pages
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
|
||||
# Set up logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/search-source-connectors/", response_model=SearchSourceConnectorRead)
|
||||
async def create_search_source_connector(
|
||||
connector: SearchSourceConnectorCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
):
|
||||
"""
|
||||
Create a new search source connector.
|
||||
|
||||
Each user can have only one connector of each type (SERPER_API, TAVILY_API, SLACK_CONNECTOR, NOTION_CONNECTOR).
|
||||
The config must contain the appropriate keys for the connector type.
|
||||
"""
|
||||
try:
|
||||
# Check if a connector with the same type already exists for this user
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector)
|
||||
.filter(
|
||||
SearchSourceConnector.user_id == user.id,
|
||||
SearchSourceConnector.connector_type == connector.connector_type
|
||||
)
|
||||
)
|
||||
existing_connector = result.scalars().first()
|
||||
|
||||
if existing_connector:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"A connector with type {connector.connector_type} already exists. Each user can have only one connector of each type."
|
||||
)
|
||||
|
||||
db_connector = SearchSourceConnector(**connector.model_dump(), user_id=user.id)
|
||||
session.add(db_connector)
|
||||
await session.commit()
|
||||
await session.refresh(db_connector)
|
||||
return db_connector
|
||||
except ValidationError as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Validation error: {str(e)}"
|
||||
)
|
||||
except IntegrityError as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"Integrity error: A connector with this type already exists. {str(e)}"
|
||||
)
|
||||
except HTTPException:
|
||||
await session.rollback()
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to create search source connector: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/search-source-connectors/", response_model=List[SearchSourceConnectorRead])
|
||||
async def read_search_source_connectors(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
):
|
||||
"""List all search source connectors for the current user."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector)
|
||||
.filter(SearchSourceConnector.user_id == user.id)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
return result.scalars().all()
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to fetch search source connectors: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/search-source-connectors/{connector_id}", response_model=SearchSourceConnectorRead)
|
||||
async def read_search_source_connector(
|
||||
connector_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
):
|
||||
"""Get a specific search source connector by ID."""
|
||||
try:
|
||||
return await check_ownership(session, SearchSourceConnector, connector_id, user)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to fetch search source connector: {str(e)}"
|
||||
)
|
||||
|
||||
@router.put("/search-source-connectors/{connector_id}", response_model=SearchSourceConnectorRead)
|
||||
async def update_search_source_connector(
|
||||
connector_id: int,
|
||||
connector_update: SearchSourceConnectorUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
):
|
||||
"""
|
||||
Update a search source connector.
|
||||
|
||||
Each user can have only one connector of each type (SERPER_API, TAVILY_API, SLACK_CONNECTOR, NOTION_CONNECTOR).
|
||||
The config must contain the appropriate keys for the connector type.
|
||||
"""
|
||||
try:
|
||||
db_connector = await check_ownership(session, SearchSourceConnector, connector_id, user)
|
||||
|
||||
# If connector type is being changed, check if one of that type already exists
|
||||
if connector_update.connector_type != db_connector.connector_type:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector)
|
||||
.filter(
|
||||
SearchSourceConnector.user_id == user.id,
|
||||
SearchSourceConnector.connector_type == connector_update.connector_type,
|
||||
SearchSourceConnector.id != connector_id
|
||||
)
|
||||
)
|
||||
existing_connector = result.scalars().first()
|
||||
|
||||
if existing_connector:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"A connector with type {connector_update.connector_type} already exists. Each user can have only one connector of each type."
|
||||
)
|
||||
|
||||
update_data = connector_update.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(db_connector, key, value)
|
||||
await session.commit()
|
||||
await session.refresh(db_connector)
|
||||
return db_connector
|
||||
except ValidationError as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Validation error: {str(e)}"
|
||||
)
|
||||
except IntegrityError as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"Integrity error: A connector with this type already exists. {str(e)}"
|
||||
)
|
||||
except HTTPException:
|
||||
await session.rollback()
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to update search source connector: {str(e)}"
|
||||
)
|
||||
|
||||
@router.delete("/search-source-connectors/{connector_id}", response_model=dict)
|
||||
async def delete_search_source_connector(
|
||||
connector_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
):
|
||||
"""Delete a search source connector."""
|
||||
try:
|
||||
db_connector = await check_ownership(session, SearchSourceConnector, connector_id, user)
|
||||
await session.delete(db_connector)
|
||||
await session.commit()
|
||||
return {"message": "Search source connector deleted successfully"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to delete search source connector: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/search-source-connectors/{connector_id}/index", response_model=Dict[str, Any])
|
||||
async def index_connector_content(
|
||||
connector_id: int,
|
||||
search_space_id: int = Query(..., description="ID of the search space to store indexed content"),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
background_tasks: BackgroundTasks = None
|
||||
):
|
||||
"""
|
||||
Index content from a connector to a search space.
|
||||
|
||||
Currently supports:
|
||||
- SLACK_CONNECTOR: Indexes messages from all accessible Slack channels since the last indexing
|
||||
(or the last 365 days if never indexed before)
|
||||
- NOTION_CONNECTOR: Indexes content from all accessible Notion pages since the last indexing
|
||||
(or the last 365 days if never indexed before)
|
||||
|
||||
Args:
|
||||
connector_id: ID of the connector to use
|
||||
search_space_id: ID of the search space to store indexed content
|
||||
background_tasks: FastAPI background tasks
|
||||
|
||||
Returns:
|
||||
Dictionary with indexing status
|
||||
"""
|
||||
try:
|
||||
# Check if the connector belongs to the user
|
||||
connector = await check_ownership(session, SearchSourceConnector, connector_id, user)
|
||||
|
||||
# Check if the search space belongs to the user
|
||||
search_space = await check_ownership(session, SearchSpace, search_space_id, user)
|
||||
|
||||
# Handle different connector types
|
||||
if connector.connector_type == SearchSourceConnectorType.SLACK_CONNECTOR:
|
||||
# Determine the time range that will be indexed
|
||||
if not connector.last_indexed_at:
|
||||
start_date = "365 days ago"
|
||||
else:
|
||||
# Check if last_indexed_at is today
|
||||
today = datetime.now().date()
|
||||
if connector.last_indexed_at.date() == today:
|
||||
# If last indexed today, go back 1 day to ensure we don't miss anything
|
||||
start_date = (today - timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
else:
|
||||
start_date = connector.last_indexed_at.strftime("%Y-%m-%d")
|
||||
|
||||
# Add the indexing task to background tasks
|
||||
if background_tasks:
|
||||
background_tasks.add_task(
|
||||
run_slack_indexing,
|
||||
session,
|
||||
connector_id,
|
||||
search_space_id
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Slack indexing started in the background",
|
||||
"connector_type": connector.connector_type,
|
||||
"search_space": search_space.name,
|
||||
"indexing_from": start_date,
|
||||
"indexing_to": datetime.now().strftime("%Y-%m-%d")
|
||||
}
|
||||
else:
|
||||
# For testing or if background tasks are not available
|
||||
return {
|
||||
"success": False,
|
||||
"message": "Background tasks not available",
|
||||
"connector_type": connector.connector_type
|
||||
}
|
||||
elif connector.connector_type == SearchSourceConnectorType.NOTION_CONNECTOR:
|
||||
# Determine the time range that will be indexed
|
||||
if not connector.last_indexed_at:
|
||||
start_date = "365 days ago"
|
||||
else:
|
||||
# Check if last_indexed_at is today
|
||||
today = datetime.now().date()
|
||||
if connector.last_indexed_at.date() == today:
|
||||
# If last indexed today, go back 1 day to ensure we don't miss anything
|
||||
start_date = (today - timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
else:
|
||||
start_date = connector.last_indexed_at.strftime("%Y-%m-%d")
|
||||
|
||||
# Add the indexing task to background tasks
|
||||
if background_tasks:
|
||||
background_tasks.add_task(
|
||||
run_notion_indexing,
|
||||
session,
|
||||
connector_id,
|
||||
search_space_id
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Notion indexing started in the background",
|
||||
"connector_type": connector.connector_type,
|
||||
"search_space": search_space.name,
|
||||
"indexing_from": start_date,
|
||||
"indexing_to": datetime.now().strftime("%Y-%m-%d")
|
||||
}
|
||||
else:
|
||||
# For testing or if background tasks are not available
|
||||
return {
|
||||
"success": False,
|
||||
"message": "Background tasks not available",
|
||||
"connector_type": connector.connector_type
|
||||
}
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Indexing not supported for connector type: {connector.connector_type}"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start indexing: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to start indexing: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
async def update_connector_last_indexed(
|
||||
session: AsyncSession,
|
||||
connector_id: int
|
||||
):
|
||||
"""
|
||||
Update the last_indexed_at timestamp for a connector.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
connector_id: ID of the connector to update
|
||||
"""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector)
|
||||
.filter(SearchSourceConnector.id == connector_id)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
|
||||
if connector:
|
||||
connector.last_indexed_at = datetime.now()
|
||||
await session.commit()
|
||||
logger.info(f"Updated last_indexed_at for connector {connector_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update last_indexed_at for connector {connector_id}: {str(e)}")
|
||||
await session.rollback()
|
||||
|
||||
async def run_slack_indexing(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
search_space_id: int
|
||||
):
|
||||
"""
|
||||
Background task to run Slack indexing.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
connector_id: ID of the Slack connector
|
||||
search_space_id: ID of the search space
|
||||
"""
|
||||
try:
|
||||
# Index Slack messages without updating last_indexed_at (we'll do it separately)
|
||||
documents_indexed, error_or_warning = await index_slack_messages(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
update_last_indexed=False # Don't update timestamp in the indexing function
|
||||
)
|
||||
|
||||
# Only update last_indexed_at if indexing was successful
|
||||
if documents_indexed > 0 and (error_or_warning is None or "Indexed" in error_or_warning):
|
||||
await update_connector_last_indexed(session, connector_id)
|
||||
logger.info(f"Slack indexing completed successfully: {documents_indexed} documents indexed")
|
||||
else:
|
||||
logger.error(f"Slack indexing failed or no documents indexed: {error_or_warning}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error in background Slack indexing task: {str(e)}")
|
||||
|
||||
async def run_notion_indexing(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
search_space_id: int
|
||||
):
|
||||
"""
|
||||
Background task to run Notion indexing.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
connector_id: ID of the Notion connector
|
||||
search_space_id: ID of the search space
|
||||
"""
|
||||
try:
|
||||
# Index Notion pages without updating last_indexed_at (we'll do it separately)
|
||||
documents_indexed, error_or_warning = await index_notion_pages(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
update_last_indexed=False # Don't update timestamp in the indexing function
|
||||
)
|
||||
|
||||
# Only update last_indexed_at if indexing was successful
|
||||
if documents_indexed > 0 and (error_or_warning is None or "Indexed" in error_or_warning):
|
||||
await update_connector_last_indexed(session, connector_id)
|
||||
logger.info(f"Notion indexing completed successfully: {documents_indexed} documents indexed")
|
||||
else:
|
||||
logger.error(f"Notion indexing failed or no documents indexed: {error_or_warning}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error in background Notion indexing task: {str(e)}")
|
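For orientation, a minimal client-side sketch of triggering this indexing endpoint. The route path, port, and bearer-token auth are assumptions for illustration; only the search_space_id query parameter and the response shape come from the code above.

import httpx

API_BASE = "http://localhost:8000"   # assumption: local dev server
TOKEN = "<jwt access token>"         # assumption: bearer auth as configured elsewhere in the backend

# Hypothetical path; the actual router prefix/decorator is defined earlier in this file.
resp = httpx.post(
    f"{API_BASE}/searchsourceconnectors/1/index",
    params={"search_space_id": 1},
    headers={"Authorization": f"Bearer {TOKEN}"},
)
print(resp.json())  # e.g. {"success": True, "message": "Slack indexing started in the background", ...}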
115  surfsense_backend/app/routes/search_spaces_routes.py  Normal file
@@ -0,0 +1,115 @@
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from typing import List
|
||||
from app.db import get_async_session, User, SearchSpace
|
||||
from app.schemas import SearchSpaceCreate, SearchSpaceUpdate, SearchSpaceRead
|
||||
from app.users import current_active_user
|
||||
from app.utils.check_ownership import check_ownership
|
||||
from fastapi import HTTPException
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/searchspaces/", response_model=SearchSpaceRead)
|
||||
async def create_search_space(
|
||||
search_space: SearchSpaceCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
):
|
||||
try:
|
||||
db_search_space = SearchSpace(**search_space.model_dump(), user_id=user.id)
|
||||
session.add(db_search_space)
|
||||
await session.commit()
|
||||
await session.refresh(db_search_space)
|
||||
return db_search_space
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to create search space: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/searchspaces/", response_model=List[SearchSpaceRead])
|
||||
async def read_search_spaces(
|
||||
skip: int = 0,
|
||||
limit: int = 200,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
):
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSpace)
|
||||
.filter(SearchSpace.user_id == user.id)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
return result.scalars().all()
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to fetch search spaces: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/searchspaces/{search_space_id}", response_model=SearchSpaceRead)
|
||||
async def read_search_space(
|
||||
search_space_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
):
|
||||
try:
|
||||
search_space = await check_ownership(session, SearchSpace, search_space_id, user)
|
||||
return search_space
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to fetch search space: {str(e)}"
|
||||
)
|
||||
|
||||
@router.put("/searchspaces/{search_space_id}", response_model=SearchSpaceRead)
|
||||
async def update_search_space(
|
||||
search_space_id: int,
|
||||
search_space_update: SearchSpaceUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
):
|
||||
try:
|
||||
db_search_space = await check_ownership(session, SearchSpace, search_space_id, user)
|
||||
update_data = search_space_update.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(db_search_space, key, value)
|
||||
await session.commit()
|
||||
await session.refresh(db_search_space)
|
||||
return db_search_space
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to update search space: {str(e)}"
|
||||
)
|
||||
|
||||
@router.delete("/searchspaces/{search_space_id}", response_model=dict)
|
||||
async def delete_search_space(
|
||||
search_space_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
):
|
||||
try:
|
||||
db_search_space = await check_ownership(session, SearchSpace, search_space_id, user)
|
||||
await session.delete(db_search_space)
|
||||
await session.commit()
|
||||
return {"message": "Search space deleted successfully"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to delete search space: {str(e)}"
|
||||
)
|
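A hedged usage sketch for the CRUD routes above; the base URL and auth header are assumptions, while the paths and payload fields mirror the route definitions.

import httpx

base = "http://localhost:8000"                  # assumption: local dev server
headers = {"Authorization": "Bearer <token>"}   # assumption: JWT bearer auth

created = httpx.post(f"{base}/searchspaces/", json={"name": "Research", "description": "Demo space"}, headers=headers).json()
spaces = httpx.get(f"{base}/searchspaces/", params={"skip": 0, "limit": 10}, headers=headers).json()
updated = httpx.put(f"{base}/searchspaces/{created['id']}", json={"name": "Research v2"}, headers=headers).json()
httpx.delete(f"{base}/searchspaces/{created['id']}", headers=headers)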
50  surfsense_backend/app/schemas/__init__.py  Normal file
@@ -0,0 +1,50 @@
from .base import TimestampModel, IDModel
|
||||
from .users import UserRead, UserCreate, UserUpdate
|
||||
from .search_space import SearchSpaceBase, SearchSpaceCreate, SearchSpaceUpdate, SearchSpaceRead
|
||||
from .documents import (
|
||||
ExtensionDocumentMetadata,
|
||||
ExtensionDocumentContent,
|
||||
DocumentBase,
|
||||
DocumentsCreate,
|
||||
DocumentUpdate,
|
||||
DocumentRead,
|
||||
)
|
||||
from .chunks import ChunkBase, ChunkCreate, ChunkUpdate, ChunkRead
|
||||
from .podcasts import PodcastBase, PodcastCreate, PodcastUpdate, PodcastRead
|
||||
from .chats import ChatBase, ChatCreate, ChatUpdate, ChatRead, AISDKChatRequest
|
||||
from .search_source_connector import SearchSourceConnectorBase, SearchSourceConnectorCreate, SearchSourceConnectorUpdate, SearchSourceConnectorRead
|
||||
|
||||
__all__ = [
|
||||
"AISDKChatRequest",
|
||||
"TimestampModel",
|
||||
"IDModel",
|
||||
"UserRead",
|
||||
"UserCreate",
|
||||
"UserUpdate",
|
||||
"SearchSpaceBase",
|
||||
"SearchSpaceCreate",
|
||||
"SearchSpaceUpdate",
|
||||
"SearchSpaceRead",
|
||||
"ExtensionDocumentMetadata",
|
||||
"ExtensionDocumentContent",
|
||||
"DocumentBase",
|
||||
"DocumentsCreate",
|
||||
"DocumentUpdate",
|
||||
"DocumentRead",
|
||||
"ChunkBase",
|
||||
"ChunkCreate",
|
||||
"ChunkUpdate",
|
||||
"ChunkRead",
|
||||
"PodcastBase",
|
||||
"PodcastCreate",
|
||||
"PodcastUpdate",
|
||||
"PodcastRead",
|
||||
"ChatBase",
|
||||
"ChatCreate",
|
||||
"ChatUpdate",
|
||||
"ChatRead",
|
||||
"SearchSourceConnectorBase",
|
||||
"SearchSourceConnectorCreate",
|
||||
"SearchSourceConnectorUpdate",
|
||||
"SearchSourceConnectorRead",
|
||||
]
|
8  surfsense_backend/app/schemas/base.py  Normal file
@@ -0,0 +1,8 @@
from datetime import datetime
|
||||
from pydantic import BaseModel
|
||||
|
||||
class TimestampModel(BaseModel):
|
||||
created_at: datetime
|
||||
|
||||
class IDModel(BaseModel):
|
||||
id: int
|
46  surfsense_backend/app/schemas/chats.py  Normal file
@@ -0,0 +1,46 @@
from typing import Any, Dict, List, Optional
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import JSON
|
||||
from .base import IDModel, TimestampModel
|
||||
from app.db import ChatType
|
||||
|
||||
class ChatBase(BaseModel):
|
||||
type: ChatType
|
||||
title: str
|
||||
initial_connectors: Optional[List[str]] = None
|
||||
messages: List[Any]
|
||||
search_space_id: int
|
||||
|
||||
|
||||
class ClientAttachment(BaseModel):
|
||||
name: str
|
||||
contentType: str
|
||||
url: str
|
||||
|
||||
|
||||
class ToolInvocation(BaseModel):
|
||||
toolCallId: str
|
||||
toolName: str
|
||||
args: dict
|
||||
result: dict
|
||||
|
||||
|
||||
class ClientMessage(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
experimental_attachments: Optional[List[ClientAttachment]] = None
|
||||
toolInvocations: Optional[List[ToolInvocation]] = None
|
||||
|
||||
class AISDKChatRequest(BaseModel):
|
||||
messages: List[ClientMessage]
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
|
||||
class ChatCreate(ChatBase):
|
||||
pass
|
||||
|
||||
class ChatUpdate(ChatBase):
|
||||
pass
|
||||
|
||||
class ChatRead(ChatBase, IDModel, TimestampModel):
|
||||
class Config:
|
||||
from_attributes = True
|
16  surfsense_backend/app/schemas/chunks.py  Normal file
@@ -0,0 +1,16 @@
from pydantic import BaseModel
|
||||
from .base import IDModel, TimestampModel
|
||||
|
||||
class ChunkBase(BaseModel):
|
||||
content: str
|
||||
document_id: int
|
||||
|
||||
class ChunkCreate(ChunkBase):
|
||||
pass
|
||||
|
||||
class ChunkUpdate(ChunkBase):
|
||||
pass
|
||||
|
||||
class ChunkRead(ChunkBase, IDModel, TimestampModel):
|
||||
class Config:
|
||||
from_attributes = True
|
42  surfsense_backend/app/schemas/documents.py  Normal file
@@ -0,0 +1,42 @@
from typing import List, Any
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import JSON
|
||||
from .base import IDModel, TimestampModel
|
||||
from app.db import DocumentType
|
||||
from datetime import datetime
|
||||
|
||||
class ExtensionDocumentMetadata(BaseModel):
|
||||
BrowsingSessionId: str
|
||||
VisitedWebPageURL: str
|
||||
VisitedWebPageTitle: str
|
||||
VisitedWebPageDateWithTimeInISOString: str
|
||||
VisitedWebPageReffererURL: str
|
||||
VisitedWebPageVisitDurationInMilliseconds: str
|
||||
|
||||
class ExtensionDocumentContent(BaseModel):
|
||||
metadata: ExtensionDocumentMetadata
|
||||
pageContent: str
|
||||
|
||||
class DocumentBase(BaseModel):
|
||||
document_type: DocumentType
|
||||
content: List[ExtensionDocumentContent] | List[str] | str # Updated to allow string content
|
||||
search_space_id: int
|
||||
|
||||
class DocumentsCreate(DocumentBase):
|
||||
pass
|
||||
|
||||
class DocumentUpdate(DocumentBase):
|
||||
pass
|
||||
|
||||
class DocumentRead(BaseModel):
|
||||
id: int
|
||||
title: str
|
||||
document_type: DocumentType
|
||||
document_metadata: dict
|
||||
content: str # Changed to string to match frontend
|
||||
created_at: datetime
|
||||
search_space_id: int
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
19  surfsense_backend/app/schemas/podcasts.py  Normal file
@@ -0,0 +1,19 @@
from pydantic import BaseModel
|
||||
from .base import IDModel, TimestampModel
|
||||
|
||||
class PodcastBase(BaseModel):
|
||||
title: str
|
||||
is_generated: bool = False
|
||||
podcast_content: str = ""
|
||||
file_location: str = ""
|
||||
search_space_id: int
|
||||
|
||||
class PodcastCreate(PodcastBase):
|
||||
pass
|
||||
|
||||
class PodcastUpdate(PodcastBase):
|
||||
pass
|
||||
|
||||
class PodcastRead(PodcastBase, IDModel, TimestampModel):
|
||||
class Config:
|
||||
from_attributes = True
|
73  surfsense_backend/app/schemas/search_source_connector.py  Normal file
@@ -0,0 +1,73 @@
from datetime import datetime
|
||||
import uuid
|
||||
from typing import Dict, Any
|
||||
from pydantic import BaseModel, field_validator, ValidationInfo
|
||||
from .base import IDModel, TimestampModel
|
||||
from app.db import SearchSourceConnectorType
|
||||
from fastapi import HTTPException
|
||||
|
||||
class SearchSourceConnectorBase(BaseModel):
|
||||
name: str
|
||||
connector_type: SearchSourceConnectorType
|
||||
is_indexable: bool
|
||||
last_indexed_at: datetime | None
|
||||
config: Dict[str, Any]
|
||||
|
||||
@field_validator('config')
|
||||
@classmethod
|
||||
def validate_config_for_connector_type(cls, config: Dict[str, Any], values: ValidationInfo) -> Dict[str, Any]:
|
||||
connector_type = values.data.get('connector_type')
|
||||
|
||||
if connector_type == SearchSourceConnectorType.SERPER_API:
|
||||
# For SERPER_API, only allow SERPER_API_KEY
|
||||
allowed_keys = ["SERPER_API_KEY"]
|
||||
if set(config.keys()) != set(allowed_keys):
|
||||
raise ValueError(f"For SERPER_API connector type, config must only contain these keys: {allowed_keys}")
|
||||
|
||||
# Ensure the API key is not empty
|
||||
if not config.get("SERPER_API_KEY"):
|
||||
raise ValueError("SERPER_API_KEY cannot be empty")
|
||||
|
||||
elif connector_type == SearchSourceConnectorType.TAVILY_API:
|
||||
# For TAVILY_API, only allow TAVILY_API_KEY
|
||||
allowed_keys = ["TAVILY_API_KEY"]
|
||||
if set(config.keys()) != set(allowed_keys):
|
||||
raise ValueError(f"For TAVILY_API connector type, config must only contain these keys: {allowed_keys}")
|
||||
|
||||
# Ensure the API key is not empty
|
||||
if not config.get("TAVILY_API_KEY"):
|
||||
raise ValueError("TAVILY_API_KEY cannot be empty")
|
||||
|
||||
elif connector_type == SearchSourceConnectorType.SLACK_CONNECTOR:
|
||||
# For SLACK_CONNECTOR, only allow SLACK_BOT_TOKEN
|
||||
allowed_keys = ["SLACK_BOT_TOKEN"]
|
||||
if set(config.keys()) != set(allowed_keys):
|
||||
raise ValueError(f"For SLACK_CONNECTOR connector type, config must only contain these keys: {allowed_keys}")
|
||||
|
||||
# Ensure the bot token is not empty
|
||||
if not config.get("SLACK_BOT_TOKEN"):
|
||||
raise ValueError("SLACK_BOT_TOKEN cannot be empty")
|
||||
|
||||
elif connector_type == SearchSourceConnectorType.NOTION_CONNECTOR:
|
||||
# For NOTION_CONNECTOR, only allow NOTION_INTEGRATION_TOKEN
|
||||
allowed_keys = ["NOTION_INTEGRATION_TOKEN"]
|
||||
if set(config.keys()) != set(allowed_keys):
|
||||
raise ValueError(f"For NOTION_CONNECTOR connector type, config must only contain these keys: {allowed_keys}")
|
||||
|
||||
# Ensure the integration token is not empty
|
||||
if not config.get("NOTION_INTEGRATION_TOKEN"):
|
||||
raise ValueError("NOTION_INTEGRATION_TOKEN cannot be empty")
|
||||
|
||||
return config
|
||||
|
||||
class SearchSourceConnectorCreate(SearchSourceConnectorBase):
|
||||
pass
|
||||
|
||||
class SearchSourceConnectorUpdate(SearchSourceConnectorBase):
|
||||
pass
|
||||
|
||||
class SearchSourceConnectorRead(SearchSourceConnectorBase, IDModel, TimestampModel):
|
||||
user_id: uuid.UUID
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
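A small behavioural sketch of the config validator above, assuming the schema and enum are importable from the backend package; the token values are placeholders.

from pydantic import ValidationError
from app.schemas import SearchSourceConnectorCreate      # assumption: run inside the backend package
from app.db import SearchSourceConnectorType

# Exactly the allowed key for a Slack connector: accepted.
ok = SearchSourceConnectorCreate(
    name="Team Slack",
    connector_type=SearchSourceConnectorType.SLACK_CONNECTOR,
    is_indexable=True,
    last_indexed_at=None,
    config={"SLACK_BOT_TOKEN": "xoxb-placeholder"},
)

# An extra key in config is rejected by the validator.
try:
    SearchSourceConnectorCreate(
        name="Bad Slack",
        connector_type=SearchSourceConnectorType.SLACK_CONNECTOR,
        is_indexable=True,
        last_indexed_at=None,
        config={"SLACK_BOT_TOKEN": "xoxb-placeholder", "EXTRA_KEY": "x"},
    )
except ValidationError as exc:
    print(exc)  # "For SLACK_CONNECTOR connector type, config must only contain these keys: ['SLACK_BOT_TOKEN']"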
23  surfsense_backend/app/schemas/search_space.py  Normal file
@@ -0,0 +1,23 @@
from datetime import datetime
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
from .base import IDModel, TimestampModel
|
||||
|
||||
class SearchSpaceBase(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
|
||||
class SearchSpaceCreate(SearchSpaceBase):
|
||||
pass
|
||||
|
||||
class SearchSpaceUpdate(SearchSpaceBase):
|
||||
pass
|
||||
|
||||
class SearchSpaceRead(SearchSpaceBase, IDModel, TimestampModel):
|
||||
id: int
|
||||
created_at: datetime
|
||||
user_id: uuid.UUID
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
11  surfsense_backend/app/schemas/users.py  Normal file
@@ -0,0 +1,11 @@
import uuid
|
||||
from fastapi_users import schemas
|
||||
|
||||
class UserRead(schemas.BaseUser[uuid.UUID]):
|
||||
pass
|
||||
|
||||
class UserCreate(schemas.BaseUserCreate):
|
||||
pass
|
||||
|
||||
class UserUpdate(schemas.BaseUserUpdate):
|
||||
pass
|
0  surfsense_backend/app/tasks/__init__.py  Normal file
246  surfsense_backend/app/tasks/background_tasks.py  Normal file
@@ -0,0 +1,246 @@
from typing import Optional, List
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from app.db import Document, DocumentType, Chunk
|
||||
from app.schemas import ExtensionDocumentContent
|
||||
from app.config import config
|
||||
from app.prompts import SUMMARY_PROMPT_TEMPLATE
|
||||
from datetime import datetime
|
||||
from app.utils.document_converters import convert_document_to_markdown
|
||||
from langchain_core.documents import Document as LangChainDocument
|
||||
from langchain_community.document_loaders import FireCrawlLoader, AsyncChromiumLoader
|
||||
from langchain_community.document_transformers import MarkdownifyTransformer
|
||||
import validators
|
||||
|
||||
md = MarkdownifyTransformer()
|
||||
|
||||
|
||||
async def add_crawled_url_document(
|
||||
session: AsyncSession,
|
||||
url: str,
|
||||
search_space_id: int
|
||||
) -> Optional[Document]:
|
||||
try:
|
||||
|
||||
if not validators.url(url):
|
||||
raise ValueError(f"Url {url} is not a valid URL address")
|
||||
|
||||
if config.FIRECRAWL_API_KEY:
|
||||
crawl_loader = FireCrawlLoader(
|
||||
url=url,
|
||||
api_key=config.FIRECRAWL_API_KEY,
|
||||
mode="scrape",
|
||||
params={
|
||||
"formats": ["markdown"],
|
||||
"excludeTags": ["a"],
|
||||
}
|
||||
)
|
||||
else:
|
||||
crawl_loader = AsyncChromiumLoader(urls=[url], headless=True)
|
||||
|
||||
url_crawled = await crawl_loader.aload()
|
||||
|
||||
if type(crawl_loader) == FireCrawlLoader:
|
||||
content_in_markdown = url_crawled[0].page_content
|
||||
elif type(crawl_loader) == AsyncChromiumLoader:
|
||||
content_in_markdown = md.transform_documents(url_crawled)[0].page_content
|
||||
|
||||
# Format document metadata in a more maintainable way
|
||||
metadata_sections = [
|
||||
("METADATA", [
|
||||
f"{key.upper()}: {value}" for key, value in url_crawled[0].metadata.items()
|
||||
]),
|
||||
("CONTENT", [
|
||||
"FORMAT: markdown",
|
||||
"TEXT_START",
|
||||
content_in_markdown,
|
||||
"TEXT_END"
|
||||
])
|
||||
]
|
||||
|
||||
# Build the document string more efficiently
|
||||
document_parts = []
|
||||
document_parts.append("<DOCUMENT>")
|
||||
|
||||
for section_title, section_content in metadata_sections:
|
||||
document_parts.append(f"<{section_title}>")
|
||||
document_parts.extend(section_content)
|
||||
document_parts.append(f"</{section_title}>")
|
||||
|
||||
document_parts.append("</DOCUMENT>")
|
||||
combined_document_string = '\n'.join(document_parts)
|
||||
|
||||
# Generate summary
|
||||
summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
|
||||
summary_result = await summary_chain.ainvoke({"document": combined_document_string})
|
||||
summary_content = summary_result.content
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
summary_content)
|
||||
|
||||
# Process chunks
|
||||
chunks = [
|
||||
Chunk(content=chunk.text, embedding=chunk.embedding)
|
||||
for chunk in config.chunker_instance.chunk(content_in_markdown)
|
||||
]
|
||||
|
||||
# Create and store document
|
||||
document = Document(
|
||||
search_space_id=search_space_id,
|
||||
title=url_crawled[0].metadata['title'] if type(crawl_loader) == FireCrawlLoader else url_crawled[0].metadata['source'],
|
||||
document_type=DocumentType.CRAWLED_URL,
|
||||
document_metadata=url_crawled[0].metadata,
|
||||
content=summary_content,
|
||||
embedding=summary_embedding,
|
||||
chunks=chunks
|
||||
)
|
||||
|
||||
session.add(document)
|
||||
await session.commit()
|
||||
await session.refresh(document)
|
||||
|
||||
return document
|
||||
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
raise db_error
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise RuntimeError(f"Failed to crawl URL: {str(e)}")
|
||||
|
||||
|
||||
async def add_extension_received_document(
|
||||
session: AsyncSession,
|
||||
content: ExtensionDocumentContent,
|
||||
search_space_id: int
|
||||
) -> Optional[Document]:
|
||||
"""
|
||||
Process and store document content received from the SurfSense Extension.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
content: Document content from extension
|
||||
search_space_id: ID of the search space
|
||||
|
||||
Returns:
|
||||
Document object if successful, None if failed
|
||||
"""
|
||||
try:
|
||||
# Format document metadata in a more maintainable way
|
||||
metadata_sections = [
|
||||
("METADATA", [
|
||||
f"SESSION_ID: {content.metadata.BrowsingSessionId}",
|
||||
f"URL: {content.metadata.VisitedWebPageURL}",
|
||||
f"TITLE: {content.metadata.VisitedWebPageTitle}",
|
||||
f"REFERRER: {content.metadata.VisitedWebPageReffererURL}",
|
||||
f"TIMESTAMP: {content.metadata.VisitedWebPageDateWithTimeInISOString}",
|
||||
f"DURATION_MS: {content.metadata.VisitedWebPageVisitDurationInMilliseconds}"
|
||||
]),
|
||||
("CONTENT", [
|
||||
"FORMAT: markdown",
|
||||
"TEXT_START",
|
||||
content.pageContent,
|
||||
"TEXT_END"
|
||||
])
|
||||
]
|
||||
|
||||
# Build the document string more efficiently
|
||||
document_parts = []
|
||||
document_parts.append("<DOCUMENT>")
|
||||
|
||||
for section_title, section_content in metadata_sections:
|
||||
document_parts.append(f"<{section_title}>")
|
||||
document_parts.extend(section_content)
|
||||
document_parts.append(f"</{section_title}>")
|
||||
|
||||
document_parts.append("</DOCUMENT>")
|
||||
combined_document_string = '\n'.join(document_parts)
|
||||
|
||||
# Generate summary
|
||||
summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
|
||||
summary_result = await summary_chain.ainvoke({"document": combined_document_string})
|
||||
summary_content = summary_result.content
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
summary_content)
|
||||
|
||||
# Process chunks
|
||||
chunks = [
|
||||
Chunk(content=chunk.text, embedding=chunk.embedding)
|
||||
for chunk in config.chunker_instance.chunk(content.pageContent)
|
||||
]
|
||||
|
||||
# Create and store document
|
||||
document = Document(
|
||||
search_space_id=search_space_id,
|
||||
title=content.metadata.VisitedWebPageTitle,
|
||||
document_type=DocumentType.EXTENSION,
|
||||
document_metadata=content.metadata.model_dump(),
|
||||
content=summary_content,
|
||||
embedding=summary_embedding,
|
||||
chunks=chunks
|
||||
)
|
||||
|
||||
session.add(document)
|
||||
await session.commit()
|
||||
await session.refresh(document)
|
||||
|
||||
return document
|
||||
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
raise db_error
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise RuntimeError(f"Failed to process extension document: {str(e)}")
|
||||
|
||||
|
||||
async def add_received_file_document(
|
||||
session: AsyncSession,
|
||||
file_name: str,
|
||||
unstructured_processed_elements: List[LangChainDocument],
|
||||
search_space_id: int
|
||||
) -> Optional[Document]:
|
||||
try:
|
||||
file_in_markdown = await convert_document_to_markdown(unstructured_processed_elements)
|
||||
|
||||
# TODO: Check if file_markdown exceeds token limit of embedding model
|
||||
|
||||
# Generate summary
|
||||
summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
|
||||
summary_result = await summary_chain.ainvoke({"document": file_in_markdown})
|
||||
summary_content = summary_result.content
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
summary_content)
|
||||
|
||||
# Process chunks
|
||||
chunks = [
|
||||
Chunk(content=chunk.text, embedding=chunk.embedding)
|
||||
for chunk in config.chunker_instance.chunk(file_in_markdown)
|
||||
]
|
||||
|
||||
# Create and store document
|
||||
document = Document(
|
||||
search_space_id=search_space_id,
|
||||
title=file_name,
|
||||
document_type=DocumentType.FILE,
|
||||
document_metadata={
|
||||
"FILE_NAME": file_name,
|
||||
"SAVED_AT": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
},
|
||||
content=summary_content,
|
||||
embedding=summary_embedding,
|
||||
chunks=chunks
|
||||
)
|
||||
|
||||
session.add(document)
|
||||
await session.commit()
|
||||
await session.refresh(document)
|
||||
|
||||
return document
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
raise db_error
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise RuntimeError(f"Failed to process file document: {str(e)}")
|
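To make the <DOCUMENT> assembly above easier to follow, here is a standalone illustration with made-up metadata; the structure mirrors the metadata_sections/document_parts logic used in these tasks.

# Illustrative values only; the section/part assembly matches the code above.
metadata_sections = [
    ("METADATA", ["URL: https://example.com", "TITLE: Example page"]),
    ("CONTENT", ["FORMAT: markdown", "TEXT_START", "# Example page body", "TEXT_END"]),
]

document_parts = ["<DOCUMENT>"]
for section_title, section_content in metadata_sections:
    document_parts.append(f"<{section_title}>")
    document_parts.extend(section_content)
    document_parts.append(f"</{section_title}>")
document_parts.append("</DOCUMENT>")

print("\n".join(document_parts))
# <DOCUMENT>
# <METADATA>
# URL: https://example.com
# TITLE: Example page
# </METADATA>
# <CONTENT>
# ...
# </CONTENT>
# </DOCUMENT>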
486  surfsense_backend/app/tasks/connectors_indexing_tasks.py  Normal file
@@ -0,0 +1,486 @@
from typing import Optional, List, Dict, Any, Tuple
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.future import select
|
||||
from datetime import datetime, timedelta
|
||||
from app.db import Document, DocumentType, Chunk, SearchSourceConnector, SearchSourceConnectorType
|
||||
from app.config import config
|
||||
from app.prompts import SUMMARY_PROMPT_TEMPLATE
|
||||
from app.connectors.slack_history import SlackHistory
|
||||
from app.connectors.notion_history import NotionHistoryConnector
|
||||
from slack_sdk.errors import SlackApiError
|
||||
import logging
|
||||
|
||||
# Set up logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def index_slack_messages(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
update_last_indexed: bool = True
|
||||
) -> Tuple[int, Optional[str]]:
|
||||
"""
|
||||
Index Slack messages from all accessible channels.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
connector_id: ID of the Slack connector
|
||||
search_space_id: ID of the search space to store documents in
|
||||
update_last_indexed: Whether to update the last_indexed_at timestamp (default: True)
|
||||
|
||||
Returns:
|
||||
Tuple containing (number of documents indexed, error message or None)
|
||||
"""
|
||||
try:
|
||||
# Get the connector
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector)
|
||||
.filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.connector_type == SearchSourceConnectorType.SLACK_CONNECTOR
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
|
||||
if not connector:
|
||||
return 0, f"Connector with ID {connector_id} not found or is not a Slack connector"
|
||||
|
||||
# Get the Slack token from the connector config
|
||||
slack_token = connector.config.get("SLACK_BOT_TOKEN")
|
||||
if not slack_token:
|
||||
return 0, "Slack token not found in connector config"
|
||||
|
||||
# Initialize Slack client
|
||||
slack_client = SlackHistory(token=slack_token)
|
||||
|
||||
# Calculate date range
|
||||
end_date = datetime.now()
|
||||
|
||||
# Use last_indexed_at as start date if available, otherwise use 365 days ago
|
||||
if connector.last_indexed_at:
|
||||
# Check if last_indexed_at is today
|
||||
today = datetime.now().date()
|
||||
if connector.last_indexed_at.date() == today:
|
||||
# If last indexed today, go back 7 days to be safe and avoid missing recent messages
start_date = end_date - timedelta(days=7)
|
||||
else:
|
||||
start_date = connector.last_indexed_at
|
||||
else:
|
||||
start_date = end_date - timedelta(days=365)
|
||||
|
||||
# Format dates for Slack API
|
||||
start_date_str = start_date.strftime("%Y-%m-%d")
|
||||
end_date_str = end_date.strftime("%Y-%m-%d")
|
||||
|
||||
# Get all channels
|
||||
try:
|
||||
channels = slack_client.get_all_channels()
|
||||
except Exception as e:
|
||||
return 0, f"Failed to get Slack channels: {str(e)}"
|
||||
|
||||
if not channels:
|
||||
return 0, "No Slack channels found"
|
||||
|
||||
# Track the number of documents indexed
|
||||
documents_indexed = 0
|
||||
skipped_channels = []
|
||||
|
||||
# Process each channel
|
||||
for channel_name, channel_id in channels.items():
|
||||
try:
|
||||
# Check if the bot is a member of the channel
|
||||
try:
|
||||
# First try to get channel info to check if bot is a member
|
||||
channel_info = slack_client.client.conversations_info(channel=channel_id)
|
||||
|
||||
# For private channels, the bot needs to be a member
|
||||
if channel_info.get("channel", {}).get("is_private", False):
|
||||
# Check if bot is a member
|
||||
is_member = channel_info.get("channel", {}).get("is_member", False)
|
||||
if not is_member:
|
||||
logger.warning(f"Bot is not a member of private channel {channel_name} ({channel_id}). Skipping.")
|
||||
skipped_channels.append(f"{channel_name} (private, bot not a member)")
|
||||
continue
|
||||
except SlackApiError as e:
|
||||
if "not_in_channel" in str(e) or "channel_not_found" in str(e):
|
||||
logger.warning(f"Bot cannot access channel {channel_name} ({channel_id}). Skipping.")
|
||||
skipped_channels.append(f"{channel_name} (access error)")
|
||||
continue
|
||||
else:
|
||||
# Re-raise if it's a different error
|
||||
raise
|
||||
|
||||
# Get messages for this channel
|
||||
messages, error = slack_client.get_history_by_date_range(
|
||||
channel_id=channel_id,
|
||||
start_date=start_date_str,
|
||||
end_date=end_date_str,
|
||||
limit=1000 # Limit to 1000 messages per channel
|
||||
)
|
||||
|
||||
if error:
|
||||
logger.warning(f"Error getting messages from channel {channel_name}: {error}")
|
||||
skipped_channels.append(f"{channel_name} (error: {error})")
|
||||
continue # Skip this channel if there's an error
|
||||
|
||||
if not messages:
|
||||
logger.info(f"No messages found in channel {channel_name} for the specified date range.")
|
||||
continue # Skip if no messages
|
||||
|
||||
# Format messages with user info
|
||||
formatted_messages = []
|
||||
for msg in messages:
|
||||
# Skip bot messages and system messages
|
||||
if msg.get("subtype") in ["bot_message", "channel_join", "channel_leave"]:
|
||||
continue
|
||||
|
||||
formatted_msg = slack_client.format_message(msg, include_user_info=True)
|
||||
formatted_messages.append(formatted_msg)
|
||||
|
||||
if not formatted_messages:
|
||||
logger.info(f"No valid messages found in channel {channel_name} after filtering.")
|
||||
continue # Skip if no valid messages after filtering
|
||||
|
||||
# Convert messages to markdown format
|
||||
channel_content = f"# Slack Channel: {channel_name}\n\n"
|
||||
|
||||
for msg in formatted_messages:
|
||||
user_name = msg.get("user_name", "Unknown User")
|
||||
timestamp = msg.get("datetime", "Unknown Time")
|
||||
text = msg.get("text", "")
|
||||
|
||||
channel_content += f"## {user_name} ({timestamp})\n\n{text}\n\n---\n\n"
|
||||
|
||||
# Format document metadata
|
||||
metadata_sections = [
|
||||
("METADATA", [
|
||||
f"CHANNEL_NAME: {channel_name}",
|
||||
f"CHANNEL_ID: {channel_id}",
|
||||
f"START_DATE: {start_date_str}",
|
||||
f"END_DATE: {end_date_str}",
|
||||
f"MESSAGE_COUNT: {len(formatted_messages)}",
|
||||
f"INDEXED_AT: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
]),
|
||||
("CONTENT", [
|
||||
"FORMAT: markdown",
|
||||
"TEXT_START",
|
||||
channel_content,
|
||||
"TEXT_END"
|
||||
])
|
||||
]
|
||||
|
||||
# Build the document string
|
||||
document_parts = []
|
||||
document_parts.append("<DOCUMENT>")
|
||||
|
||||
for section_title, section_content in metadata_sections:
|
||||
document_parts.append(f"<{section_title}>")
|
||||
document_parts.extend(section_content)
|
||||
document_parts.append(f"</{section_title}>")
|
||||
|
||||
document_parts.append("</DOCUMENT>")
|
||||
combined_document_string = '\n'.join(document_parts)
|
||||
|
||||
# Generate summary
|
||||
summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
|
||||
summary_result = await summary_chain.ainvoke({"document": combined_document_string})
|
||||
summary_content = summary_result.content
|
||||
summary_embedding = config.embedding_model_instance.embed(summary_content)
|
||||
|
||||
# Process chunks
|
||||
chunks = [
|
||||
Chunk(content=chunk.text, embedding=chunk.embedding)
|
||||
for chunk in config.chunker_instance.chunk(channel_content)
|
||||
]
|
||||
|
||||
# Create and store document
|
||||
document = Document(
|
||||
search_space_id=search_space_id,
|
||||
title=f"Slack - {channel_name}",
|
||||
document_type=DocumentType.SLACK_CONNECTOR,
|
||||
document_metadata={
|
||||
"channel_name": channel_name,
|
||||
"channel_id": channel_id,
|
||||
"start_date": start_date_str,
|
||||
"end_date": end_date_str,
|
||||
"message_count": len(formatted_messages),
|
||||
"indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
},
|
||||
content=summary_content,
|
||||
embedding=summary_embedding,
|
||||
chunks=chunks
|
||||
)
|
||||
|
||||
session.add(document)
|
||||
documents_indexed += 1
|
||||
logger.info(f"Successfully indexed channel {channel_name} with {len(formatted_messages)} messages")
|
||||
|
||||
except SlackApiError as slack_error:
|
||||
logger.error(f"Slack API error for channel {channel_name}: {str(slack_error)}")
|
||||
skipped_channels.append(f"{channel_name} (Slack API error)")
|
||||
continue # Skip this channel and continue with others
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing channel {channel_name}: {str(e)}")
|
||||
skipped_channels.append(f"{channel_name} (processing error)")
|
||||
continue # Skip this channel and continue with others
|
||||
|
||||
# Update the last_indexed_at timestamp for the connector only if requested
|
||||
# and if we successfully indexed at least one channel
|
||||
if update_last_indexed and documents_indexed > 0:
|
||||
connector.last_indexed_at = datetime.now()
|
||||
|
||||
# Commit all changes
|
||||
await session.commit()
|
||||
|
||||
# Prepare result message
|
||||
result_message = None
|
||||
if skipped_channels:
|
||||
result_message = f"Indexed {documents_indexed} channels. Skipped {len(skipped_channels)} channels: {', '.join(skipped_channels)}"
|
||||
|
||||
return documents_indexed, result_message
|
||||
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
logger.error(f"Database error: {str(db_error)}")
|
||||
return 0, f"Database error: {str(db_error)}"
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error(f"Failed to index Slack messages: {str(e)}")
|
||||
return 0, f"Failed to index Slack messages: {str(e)}"
|
||||
|
||||
async def index_notion_pages(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
update_last_indexed: bool = True
|
||||
) -> Tuple[int, Optional[str]]:
|
||||
"""
|
||||
Index all accessible Notion pages.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
connector_id: ID of the Notion connector
|
||||
search_space_id: ID of the search space to store documents in
|
||||
update_last_indexed: Whether to update the last_indexed_at timestamp (default: True)
|
||||
|
||||
Returns:
|
||||
Tuple containing (number of documents indexed, error message or None)
|
||||
"""
|
||||
try:
|
||||
# Get the connector
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector)
|
||||
.filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.connector_type == SearchSourceConnectorType.NOTION_CONNECTOR
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
|
||||
if not connector:
|
||||
return 0, f"Connector with ID {connector_id} not found or is not a Notion connector"
|
||||
|
||||
# Get the Notion token from the connector config
|
||||
notion_token = connector.config.get("NOTION_INTEGRATION_TOKEN")
|
||||
if not notion_token:
|
||||
return 0, "Notion integration token not found in connector config"
|
||||
|
||||
# Initialize Notion client
|
||||
logger.info(f"Initializing Notion client for connector {connector_id}")
|
||||
notion_client = NotionHistoryConnector(token=notion_token)
|
||||
|
||||
# Calculate date range
|
||||
end_date = datetime.now()
|
||||
|
||||
# Use last_indexed_at as start date if available, otherwise use 365 days ago
|
||||
if connector.last_indexed_at:
|
||||
# Check if last_indexed_at is today
|
||||
today = datetime.now().date()
|
||||
if connector.last_indexed_at.date() == today:
|
||||
# If last indexed today, go back 1 day to ensure we don't miss anything
|
||||
start_date = end_date - timedelta(days=1)
|
||||
else:
|
||||
start_date = connector.last_indexed_at
|
||||
else:
|
||||
start_date = end_date - timedelta(days=365)
|
||||
|
||||
# Format dates for Notion API (ISO format)
|
||||
start_date_str = start_date.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
end_date_str = end_date.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
|
||||
logger.info(f"Fetching Notion pages from {start_date_str} to {end_date_str}")
|
||||
|
||||
# Get all pages
|
||||
try:
|
||||
pages = notion_client.get_all_pages(start_date=start_date_str, end_date=end_date_str)
|
||||
logger.info(f"Found {len(pages)} Notion pages")
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Notion pages: {str(e)}", exc_info=True)
|
||||
return 0, f"Failed to get Notion pages: {str(e)}"
|
||||
|
||||
if not pages:
|
||||
logger.info("No Notion pages found to index")
|
||||
return 0, "No Notion pages found"
|
||||
|
||||
# Track the number of documents indexed
|
||||
documents_indexed = 0
|
||||
skipped_pages = []
|
||||
|
||||
# Process each page
|
||||
for page in pages:
|
||||
try:
|
||||
page_id = page.get("page_id")
|
||||
page_title = page.get("title", f"Untitled page ({page_id})")
|
||||
page_content = page.get("content", [])
|
||||
|
||||
logger.info(f"Processing Notion page: {page_title} ({page_id})")
|
||||
|
||||
if not page_content:
|
||||
logger.info(f"No content found in page {page_title}. Skipping.")
|
||||
skipped_pages.append(f"{page_title} (no content)")
|
||||
continue
|
||||
|
||||
# Convert page content to markdown format
|
||||
markdown_content = f"# Notion Page: {page_title}\n\n"
|
||||
|
||||
# Process blocks recursively
|
||||
def process_blocks(blocks, level=0):
|
||||
result = ""
|
||||
for block in blocks:
|
||||
block_type = block.get("type")
|
||||
block_content = block.get("content", "")
|
||||
children = block.get("children", [])
|
||||
|
||||
# Add indentation based on level
|
||||
indent = " " * level
|
||||
|
||||
# Format based on block type
|
||||
if block_type in ["paragraph", "text"]:
|
||||
result += f"{indent}{block_content}\n\n"
|
||||
elif block_type in ["heading_1", "header"]:
|
||||
result += f"{indent}# {block_content}\n\n"
|
||||
elif block_type == "heading_2":
|
||||
result += f"{indent}## {block_content}\n\n"
|
||||
elif block_type == "heading_3":
|
||||
result += f"{indent}### {block_content}\n\n"
|
||||
elif block_type == "bulleted_list_item":
|
||||
result += f"{indent}* {block_content}\n"
|
||||
elif block_type == "numbered_list_item":
|
||||
result += f"{indent}1. {block_content}\n"
|
||||
elif block_type == "to_do":
|
||||
result += f"{indent}- [ ] {block_content}\n"
|
||||
elif block_type == "toggle":
|
||||
result += f"{indent}> {block_content}\n"
|
||||
elif block_type == "code":
|
||||
result += f"{indent}```\n{block_content}\n```\n\n"
|
||||
elif block_type == "quote":
|
||||
result += f"{indent}> {block_content}\n\n"
|
||||
elif block_type == "callout":
|
||||
result += f"{indent}> **Note:** {block_content}\n\n"
|
||||
elif block_type == "image":
|
||||
result += f"{indent}\n\n"
|
||||
else:
|
||||
# Default for other block types
|
||||
if block_content:
|
||||
result += f"{indent}{block_content}\n\n"
|
||||
|
||||
# Process children recursively
|
||||
if children:
|
||||
result += process_blocks(children, level + 1)
|
||||
|
||||
return result
|
||||
|
||||
logger.debug(f"Converting {len(page_content)} blocks to markdown for page {page_title}")
|
||||
markdown_content += process_blocks(page_content)
|
||||
|
||||
# Format document metadata
|
||||
metadata_sections = [
|
||||
("METADATA", [
|
||||
f"PAGE_TITLE: {page_title}",
|
||||
f"PAGE_ID: {page_id}",
|
||||
f"INDEXED_AT: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
]),
|
||||
("CONTENT", [
|
||||
"FORMAT: markdown",
|
||||
"TEXT_START",
|
||||
markdown_content,
|
||||
"TEXT_END"
|
||||
])
|
||||
]
|
||||
|
||||
# Build the document string
|
||||
document_parts = []
|
||||
document_parts.append("<DOCUMENT>")
|
||||
|
||||
for section_title, section_content in metadata_sections:
|
||||
document_parts.append(f"<{section_title}>")
|
||||
document_parts.extend(section_content)
|
||||
document_parts.append(f"</{section_title}>")
|
||||
|
||||
document_parts.append("</DOCUMENT>")
|
||||
combined_document_string = '\n'.join(document_parts)
|
||||
|
||||
# Generate summary
|
||||
logger.debug(f"Generating summary for page {page_title}")
|
||||
summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
|
||||
summary_result = await summary_chain.ainvoke({"document": combined_document_string})
|
||||
summary_content = summary_result.content
|
||||
summary_embedding = config.embedding_model_instance.embed(summary_content)
|
||||
|
||||
# Process chunks
|
||||
logger.debug(f"Chunking content for page {page_title}")
|
||||
chunks = [
|
||||
Chunk(content=chunk.text, embedding=chunk.embedding)
|
||||
for chunk in config.chunker_instance.chunk(markdown_content)
|
||||
]
|
||||
|
||||
# Create and store document
|
||||
document = Document(
|
||||
search_space_id=search_space_id,
|
||||
title=f"Notion - {page_title}",
|
||||
document_type=DocumentType.NOTION_CONNECTOR,
|
||||
document_metadata={
|
||||
"page_title": page_title,
|
||||
"page_id": page_id,
|
||||
"indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
},
|
||||
content=summary_content,
|
||||
embedding=summary_embedding,
|
||||
chunks=chunks
|
||||
)
|
||||
|
||||
session.add(document)
|
||||
documents_indexed += 1
|
||||
logger.info(f"Successfully indexed Notion page: {page_title}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing Notion page {page.get('title', 'Unknown')}: {str(e)}", exc_info=True)
|
||||
skipped_pages.append(f"{page.get('title', 'Unknown')} (processing error)")
|
||||
continue # Skip this page and continue with others
|
||||
|
||||
# Update the last_indexed_at timestamp for the connector only if requested
|
||||
# and if we successfully indexed at least one page
|
||||
if update_last_indexed and documents_indexed > 0:
|
||||
connector.last_indexed_at = datetime.now()
|
||||
logger.info(f"Updated last_indexed_at for connector {connector_id}")
|
||||
|
||||
# Commit all changes
|
||||
await session.commit()
|
||||
|
||||
# Prepare result message
|
||||
result_message = None
|
||||
if skipped_pages:
|
||||
result_message = f"Indexed {documents_indexed} pages. Skipped {len(skipped_pages)} pages: {', '.join(skipped_pages)}"
|
||||
|
||||
logger.info(f"Notion indexing completed: {documents_indexed} pages indexed, {len(skipped_pages)} pages skipped")
|
||||
return documents_indexed, result_message
|
||||
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
logger.error(f"Database error during Notion indexing: {str(db_error)}", exc_info=True)
|
||||
return 0, f"Database error: {str(db_error)}"
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error(f"Failed to index Notion pages: {str(e)}", exc_info=True)
|
||||
return 0, f"Failed to index Notion pages: {str(e)}"
|
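An illustrative input for the nested process_blocks conversion above; the block shape (type/content/children keys) follows how the code reads the blocks, while the values are invented.

# Hypothetical Notion blocks as returned by NotionHistoryConnector (values invented).
blocks = [
    {"type": "heading_1", "content": "Roadmap", "children": []},
    {"type": "bulleted_list_item", "content": "Ship v0.0.6", "children": [
        {"type": "to_do", "content": "Write docs", "children": []},
    ]},
]
# process_blocks(blocks) would produce roughly:
# # Roadmap
#
# * Ship v0.0.6
#   - [ ] Write docs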
340  surfsense_backend/app/tasks/stream_connector_search_results.py  Normal file
@@ -0,0 +1,340 @@
import json
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import List, AsyncGenerator, Dict, Any
|
||||
import asyncio
|
||||
import re
|
||||
|
||||
from app.utils.connector_service import ConnectorService
|
||||
from app.utils.research_service import ResearchService
|
||||
from app.utils.streaming_service import StreamingService
|
||||
from app.utils.reranker_service import RerankerService
|
||||
from app.config import config
|
||||
from app.utils.document_converters import convert_chunks_to_langchain_documents
|
||||
|
||||
async def stream_connector_search_results(
|
||||
user_query: str,
|
||||
user_id: int,
|
||||
search_space_id: int,
|
||||
session: AsyncSession,
|
||||
research_mode: str,
|
||||
selected_connectors: List[str]
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Stream connector search results to the client
|
||||
|
||||
Args:
|
||||
user_query: The user's query
|
||||
user_id: The user's ID
|
||||
search_space_id: The search space ID
|
||||
session: The database session
|
||||
research_mode: The research mode
|
||||
selected_connectors: List of selected connectors
|
||||
|
||||
Yields:
|
||||
str: Formatted response strings
|
||||
"""
|
||||
# Initialize services
|
||||
connector_service = ConnectorService(session)
|
||||
streaming_service = StreamingService()
|
||||
|
||||
|
||||
reranker_service = RerankerService.get_reranker_instance(config)
|
||||
|
||||
all_raw_documents = [] # Store all raw documents before reranking
|
||||
all_sources = []
|
||||
TOP_K = 20
|
||||
|
||||
if research_mode == "GENERAL":
|
||||
TOP_K = 20
|
||||
elif research_mode == "DEEP":
|
||||
TOP_K = 40
|
||||
elif research_mode == "DEEPER":
|
||||
TOP_K = 60
|
||||
|
||||
|
||||
# Process each selected connector
|
||||
for connector in selected_connectors:
|
||||
# Crawled URLs
|
||||
if connector == "CRAWLED_URL":
|
||||
# Send terminal message about starting search
|
||||
yield streaming_service.add_terminal_message("Starting to search for crawled URLs...")
|
||||
|
||||
# Search for crawled URLs
|
||||
result_object, crawled_urls_chunks = await connector_service.search_crawled_urls(
|
||||
user_query=user_query,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
top_k=TOP_K
|
||||
)
|
||||
|
||||
# Send terminal message about search results
|
||||
yield streaming_service.add_terminal_message(
|
||||
f"Found {len(result_object['sources'])} relevant crawled URLs",
|
||||
"success"
|
||||
)
|
||||
|
||||
# Update sources
|
||||
all_sources.append(result_object)
|
||||
yield streaming_service.update_sources(all_sources)
|
||||
|
||||
# Add documents to collection
|
||||
all_raw_documents.extend(crawled_urls_chunks)
|
||||
|
||||
|
||||
# Files
|
||||
if connector == "FILE":
|
||||
# Send terminal message about starting search
|
||||
yield streaming_service.add_terminal_message("Starting to search for files...")
|
||||
|
||||
# Search for files
|
||||
result_object, files_chunks = await connector_service.search_files(
|
||||
user_query=user_query,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
top_k=TOP_K
|
||||
)
|
||||
|
||||
# Send terminal message about search results
|
||||
yield streaming_service.add_terminal_message(
|
||||
f"Found {len(result_object['sources'])} relevant files",
|
||||
"success"
|
||||
)
|
||||
|
||||
# Update sources
|
||||
all_sources.append(result_object)
|
||||
yield streaming_service.update_sources(all_sources)
|
||||
|
||||
# Add documents to collection
|
||||
all_raw_documents.extend(files_chunks)
|
||||
|
||||
# Tavily Connector
|
||||
if connector == "TAVILY_API":
|
||||
# Send terminal message about starting search
|
||||
yield streaming_service.add_terminal_message("Starting to search with Tavily API...")
|
||||
|
||||
# Search using Tavily API
|
||||
result_object, tavily_chunks = await connector_service.search_tavily(
|
||||
user_query=user_query,
|
||||
user_id=user_id,
|
||||
top_k=TOP_K
|
||||
)
|
||||
|
||||
# Send terminal message about search results
|
||||
yield streaming_service.add_terminal_message(
|
||||
f"Found {len(result_object['sources'])} relevant results from Tavily",
|
||||
"success"
|
||||
)
|
||||
|
||||
# Update sources
|
||||
all_sources.append(result_object)
|
||||
yield streaming_service.update_sources(all_sources)
|
||||
|
||||
# Add documents to collection
|
||||
all_raw_documents.extend(tavily_chunks)
|
||||
|
||||
# Slack Connector
|
||||
if connector == "SLACK_CONNECTOR":
|
||||
# Send terminal message about starting search
|
||||
yield streaming_service.add_terminal_message("Starting to search for slack connector...")
|
||||
|
||||
# Search using Slack API
|
||||
result_object, slack_chunks = await connector_service.search_slack(
|
||||
user_query=user_query,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
top_k=TOP_K
|
||||
)
|
||||
|
||||
# Send terminal message about search results
|
||||
yield streaming_service.add_terminal_message(
|
||||
f"Found {len(result_object['sources'])} relevant results from Slack",
|
||||
"success"
|
||||
)
|
||||
|
||||
# Update sources
|
||||
all_sources.append(result_object)
|
||||
yield streaming_service.update_sources(all_sources)
|
||||
|
||||
# Add documents to collection
|
||||
all_raw_documents.extend(slack_chunks)
|
||||
|
||||
|
||||
# Notion Connector
|
||||
if connector == "NOTION_CONNECTOR":
|
||||
# Send terminal message about starting search
|
||||
yield streaming_service.add_terminal_message("Starting to search for notion connector...")
|
||||
|
||||
# Search using Notion API
|
||||
result_object, notion_chunks = await connector_service.search_notion(
|
||||
user_query=user_query,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
top_k=TOP_K
|
||||
)
|
||||
|
||||
# Send terminal message about search results
|
||||
yield streaming_service.add_terminal_message(
|
||||
f"Found {len(result_object['sources'])} relevant results from Notion",
|
||||
"success"
|
||||
)
|
||||
|
||||
# Update sources
|
||||
all_sources.append(result_object)
|
||||
yield streaming_service.update_sources(all_sources)
|
||||
|
||||
# Add documents to collection
|
||||
all_raw_documents.extend(notion_chunks)
|
||||
|
||||
|
||||
|
||||
|
||||
# If we have documents to research
|
||||
if all_raw_documents:
|
||||
# Rerank all documents if reranker is available
|
||||
if reranker_service:
    yield streaming_service.add_terminal_message("Reranking documents for better relevance...", "info")

    # Convert documents to format expected by reranker
    reranker_input_docs = [
        {
            "chunk_id": doc.get("chunk_id", f"chunk_{i}"),
            "content": doc.get("content", ""),
            "score": doc.get("score", 0.0),
            "document": {
                "id": doc.get("document", {}).get("id", ""),
                "title": doc.get("document", {}).get("title", ""),
                "document_type": doc.get("document", {}).get("document_type", ""),
                "metadata": doc.get("document", {}).get("metadata", {})
            }
        } for i, doc in enumerate(all_raw_documents)
    ]

    # Rerank documents
    reranked_docs = reranker_service.rerank_documents(user_query, reranker_input_docs)

    # Sort by score in descending order
    reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True)

    # Convert back to langchain documents format
    from langchain.schema import Document as LangchainDocument
    all_langchain_documents_to_research = [
        LangchainDocument(
            page_content=f"""<document><metadata><source_id>{doc.get("document", {}).get("id", "")}</source_id></metadata><content>{doc.get("content", "")}</content></document>""",
            metadata={
                # **doc.get("document", {}).get("metadata", {}),
                # "score": doc.get("score", 0.0),
                # "rank": doc.get("rank", 0),
                # "document_id": doc.get("document", {}).get("id", ""),
                # "document_title": doc.get("document", {}).get("title", ""),
                # "document_type": doc.get("document", {}).get("document_type", ""),
                # Explicitly set source_id for citation purposes
                "source_id": str(doc.get("document", {}).get("id", ""))
            }
        ) for doc in reranked_docs
    ]

    yield streaming_service.add_terminal_message(f"Reranked {len(all_langchain_documents_to_research)} documents", "success")
else:
    # Use raw documents if no reranker is available
    all_langchain_documents_to_research = convert_chunks_to_langchain_documents(all_raw_documents)

# Send terminal message about starting research
yield streaming_service.add_terminal_message("Starting to research...", "info")

# Create a buffer to collect report content
report_buffer = []

# Use the streaming research method
yield streaming_service.add_terminal_message("Generating report...", "info")

# Create a wrapper to handle the streaming
class StreamHandler:
    def __init__(self):
        self.queue = asyncio.Queue()

    async def handle_progress(self, data):
        result = None
        if data.get("type") == "logs":
            # Handle log messages
            result = streaming_service.add_terminal_message(data.get("output", ""), "info")
        elif data.get("type") == "report":
            # Handle report content
            content = data.get("output", "")

            # Fix incorrect citation formats using regex

            # More specific pattern to match only numeric citations in markdown-style links
            # This matches patterns like ([1](https://github.com/...)) but not general links like ([Click here](https://...))
            pattern = r'\(\[(\d+)\]\((https?://[^\)]+)\)\)'

            # Replace with just [X] where X is the number
            content = re.sub(pattern, r'[\1]', content)

            # Also match other incorrect formats like ([1]) and convert to [1]
            # Only match if the content inside brackets is a number
            content = re.sub(r'\(\[(\d+)\]\)', r'[\1]', content)

            report_buffer.append(content)
            # Update the answer with the accumulated content
            result = streaming_service.update_answer(report_buffer)

        if result:
            await self.queue.put(result)
        return result

    async def get_next(self):
        try:
            return await self.queue.get()
        except Exception as e:
            print(f"Error getting next item from queue: {e}")
            return None

    def task_done(self):
        self.queue.task_done()

# Create the stream handler
stream_handler = StreamHandler()

# Start the research process in a separate task
research_task = asyncio.create_task(
    ResearchService.stream_research(
        user_query=user_query,
        documents=all_langchain_documents_to_research,
        on_progress=stream_handler.handle_progress,
        research_mode=research_mode
    )
)

# Stream results as they become available
while not research_task.done() or not stream_handler.queue.empty():
    try:
        # Get the next result with a timeout
        result = await asyncio.wait_for(stream_handler.get_next(), timeout=0.1)
        stream_handler.task_done()
        yield result
    except asyncio.TimeoutError:
        # No result available yet, check if the research task is done
        if research_task.done():
            # If the queue is empty and the task is done, we're finished
            if stream_handler.queue.empty():
                break

# Get the final report
try:
    final_report = await research_task

    # Send terminal message about research completion
    yield streaming_service.add_terminal_message("Research completed", "success")

    # Update the answer with the final report
    final_report_lines = final_report.split('\n')
    yield streaming_service.update_answer(final_report_lines)
except Exception as e:
    # Handle any exceptions
    yield streaming_service.add_terminal_message(f"Error during research: {str(e)}", "error")

# Send completion message
yield streaming_service.format_completion()
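The handler above pairs an asyncio.Queue producer (the research task pushing progress frames) with a consumer loop that polls the queue under a short timeout so it can also notice when the task finishes. A minimal, self-contained sketch of that same pattern, using illustrative names that are not part of this commit:

import asyncio

async def producer(queue: asyncio.Queue):
    # Stands in for ResearchService.stream_research pushing progress updates.
    for i in range(3):
        await asyncio.sleep(0.05)
        await queue.put(f"chunk {i}")

async def consume():
    queue = asyncio.Queue()
    task = asyncio.create_task(producer(queue))
    # Poll with a short timeout so the loop can exit once the task is done
    # and the queue has been drained.
    while not task.done() or not queue.empty():
        try:
            item = await asyncio.wait_for(queue.get(), timeout=0.1)
            queue.task_done()
            print(item)
        except asyncio.TimeoutError:
            if task.done() and queue.empty():
                break

asyncio.run(consume())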
95
surfsense_backend/app/users.py
Normal file
@ -0,0 +1,95 @@
from typing import Optional
import uuid

from fastapi import Depends, Request, Response
from fastapi.responses import RedirectResponse
from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models
from fastapi_users.authentication import (
    AuthenticationBackend,
    BearerTransport,
    JWTStrategy,
)
from fastapi_users.db import SQLAlchemyUserDatabase
from httpx_oauth.clients.google import GoogleOAuth2

from app.config import config
from app.db import User, get_user_db
from pydantic import BaseModel


class BearerResponse(BaseModel):
    access_token: str
    token_type: str


SECRET = config.SECRET_KEY

google_oauth_client = GoogleOAuth2(
    config.GOOGLE_OAUTH_CLIENT_ID,
    config.GOOGLE_OAUTH_CLIENT_SECRET,
)


class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
    reset_password_token_secret = SECRET
    verification_token_secret = SECRET

    async def on_after_register(self, user: User, request: Optional[Request] = None):
        print(f"User {user.id} has registered.")

    async def on_after_forgot_password(
        self, user: User, token: str, request: Optional[Request] = None
    ):
        print(f"User {user.id} has forgot their password. Reset token: {token}")

    async def on_after_request_verify(
        self, user: User, token: str, request: Optional[Request] = None
    ):
        print(
            f"Verification requested for user {user.id}. Verification token: {token}")


async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)):
    yield UserManager(user_db)


def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
    return JWTStrategy(secret=SECRET, lifetime_seconds=3600*24)


# # COOKIE AUTH | Uncomment if you want to use cookie auth.
# from fastapi_users.authentication import (
#     CookieTransport,
# )
# class CustomCookieTransport(CookieTransport):
#     async def get_login_response(self, token: str) -> Response:
#         response = RedirectResponse(config.OAUTH_REDIRECT_URL, status_code=302)
#         return self._set_login_cookie(response, token)

# cookie_transport = CustomCookieTransport(
#     cookie_max_age=3600,
# )

# auth_backend = AuthenticationBackend(
#     name="jwt",
#     transport=cookie_transport,
#     get_strategy=get_jwt_strategy,
# )

# BEARER AUTH CODE.
class CustomBearerTransport(BearerTransport):
    async def get_login_response(self, token: str) -> Response:
        bearer_response = BearerResponse(access_token=token, token_type="bearer")
        redirect_url = f"{config.NEXT_FRONTEND_URL}/auth/callback?token={bearer_response.access_token}"
        return RedirectResponse(redirect_url, status_code=302)


bearer_transport = CustomBearerTransport(tokenUrl="auth/jwt/login")

auth_backend = AuthenticationBackend(
    name="jwt",
    transport=bearer_transport,
    get_strategy=get_jwt_strategy,
)

fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])

current_active_user = fastapi_users.current_user(active=True)
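How these pieces are mounted on the FastAPI app is not shown in this file; the following is only a hedged sketch of typical fastapi-users wiring, and the app module, router prefixes, and tags are assumptions rather than part of this commit:

# Hypothetical wiring sketch; the real app module is defined elsewhere in the backend.
from fastapi import FastAPI

from app.users import SECRET, auth_backend, fastapi_users, google_oauth_client

app = FastAPI()

# JWT login/logout endpoints under /auth/jwt (matches tokenUrl="auth/jwt/login").
app.include_router(
    fastapi_users.get_auth_router(auth_backend),
    prefix="/auth/jwt",
    tags=["auth"],
)

# Google OAuth flow; on success CustomBearerTransport redirects the browser to
# NEXT_FRONTEND_URL/auth/callback?token=... with the issued JWT.
app.include_router(
    fastapi_users.get_oauth_router(google_oauth_client, auth_backend, SECRET),
    prefix="/auth/google",
    tags=["auth"],
)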
12
surfsense_backend/app/utils/check_ownership.py
Normal file
@ -0,0 +1,12 @@
from fastapi import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import User

# Helper function to check user ownership
async def check_ownership(session: AsyncSession, model, item_id: int, user: User):
    item = await session.execute(select(model).filter(model.id == item_id, model.user_id == user.id))
    item = item.scalars().first()
    if not item:
        raise HTTPException(status_code=404, detail="Item not found or you don't have permission to access it")
    return item
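A hedged usage sketch for this helper: the Chat model, the get_async_session dependency, and the route path below are illustrative assumptions, not part of this commit.

# Hypothetical route showing how check_ownership is intended to be used.
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession

from app.db import Chat, User, get_async_session  # Chat/get_async_session assumed for illustration
from app.users import current_active_user
from app.utils.check_ownership import check_ownership

router = APIRouter()

@router.get("/chats/{chat_id}")
async def read_chat(
    chat_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    # Raises 404 if the chat does not exist or belongs to another user.
    chat = await check_ownership(session, Chat, chat_id, user)
    return chat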
385
surfsense_backend/app/utils/connector_service.py
Normal file
@ -0,0 +1,385 @@
import json
from typing import List, Dict, Any, Optional, Tuple
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
from app.db import SearchSourceConnector, SearchSourceConnectorType
from tavily import TavilyClient


class ConnectorService:
    def __init__(self, session: AsyncSession):
        self.session = session
        self.retriever = ChucksHybridSearchRetriever(session)
        self.source_id_counter = 1

    async def search_crawled_urls(self, user_query: str, user_id: int, search_space_id: int, top_k: int = 20) -> tuple:
        """
        Search for crawled URLs and return both the source information and langchain documents

        Returns:
            tuple: (sources_info, langchain_documents)
        """
        crawled_urls_chunks = await self.retriever.hybrid_search(
            query_text=user_query,
            top_k=top_k,
            user_id=user_id,
            search_space_id=search_space_id,
            document_type="CRAWLED_URL"
        )

        # Map crawled_urls_chunks to the required format
        mapped_sources = {}
        for i, chunk in enumerate(crawled_urls_chunks):
            # Fix for UI
            crawled_urls_chunks[i]['document']['id'] = self.source_id_counter
            # Extract document metadata
            document = chunk.get('document', {})
            metadata = document.get('metadata', {})

            # Create a mapped source entry
            source = {
                "id": self.source_id_counter,
                "title": document.get('title', 'Untitled Document'),
                "description": metadata.get('og:description', metadata.get('ogDescription', chunk.get('content', '')[:100])),
                "url": metadata.get('url', '')
            }

            self.source_id_counter += 1

            # Use a unique identifier for tracking unique sources
            source_key = source.get("url") or source.get("title")
            if source_key and source_key not in mapped_sources:
                mapped_sources[source_key] = source

        # Convert to list of sources
        sources_list = list(mapped_sources.values())

        # Create result object
        result_object = {
            "id": 1,
            "name": "Crawled URLs",
            "type": "CRAWLED_URL",
            "sources": sources_list,
        }

        return result_object, crawled_urls_chunks

    async def search_files(self, user_query: str, user_id: int, search_space_id: int, top_k: int = 20) -> tuple:
        """
        Search for files and return both the source information and langchain documents

        Returns:
            tuple: (sources_info, langchain_documents)
        """
        files_chunks = await self.retriever.hybrid_search(
            query_text=user_query,
            top_k=top_k,
            user_id=user_id,
            search_space_id=search_space_id,
            document_type="FILE"
        )

        # Map files_chunks to the required format
        mapped_sources = {}
        for i, chunk in enumerate(files_chunks):
            # Fix for UI
            files_chunks[i]['document']['id'] = self.source_id_counter
            # Extract document metadata
            document = chunk.get('document', {})
            metadata = document.get('metadata', {})

            # Create a mapped source entry
            source = {
                "id": self.source_id_counter,
                "title": document.get('title', 'Untitled Document'),
                "description": metadata.get('og:description', metadata.get('ogDescription', chunk.get('content', '')[:100])),
                "url": metadata.get('url', '')
            }

            self.source_id_counter += 1

            # Use a unique identifier for tracking unique sources
            source_key = source.get("url") or source.get("title")
            if source_key and source_key not in mapped_sources:
                mapped_sources[source_key] = source

        # Convert to list of sources
        sources_list = list(mapped_sources.values())

        # Create result object
        result_object = {
            "id": 2,
            "name": "Files",
            "type": "FILE",
            "sources": sources_list,
        }

        return result_object, files_chunks

    async def get_connector_by_type(self, user_id: int, connector_type: SearchSourceConnectorType) -> Optional[SearchSourceConnector]:
        """
        Get a connector by type for a specific user

        Args:
            user_id: The user's ID
            connector_type: The connector type to retrieve

        Returns:
            Optional[SearchSourceConnector]: The connector if found, None otherwise
        """
        result = await self.session.execute(
            select(SearchSourceConnector)
            .filter(
                SearchSourceConnector.user_id == user_id,
                SearchSourceConnector.connector_type == connector_type
            )
        )
        return result.scalars().first()

    async def search_tavily(self, user_query: str, user_id: int, top_k: int = 20) -> tuple:
        """
        Search using Tavily API and return both the source information and documents

        Args:
            user_query: The user's query
            user_id: The user's ID
            top_k: Maximum number of results to return

        Returns:
            tuple: (sources_info, documents)
        """
        # Get Tavily connector configuration
        tavily_connector = await self.get_connector_by_type(user_id, SearchSourceConnectorType.TAVILY_API)

        if not tavily_connector:
            # Return empty results if no Tavily connector is configured
            return {
                "id": 3,
                "name": "Tavily Search",
                "type": "TAVILY_API",
                "sources": [],
            }, []

        # Initialize Tavily client with API key from connector config
        tavily_api_key = tavily_connector.config.get("TAVILY_API_KEY")
        tavily_client = TavilyClient(api_key=tavily_api_key)

        # Perform search with Tavily
        try:
            response = tavily_client.search(
                query=user_query,
                max_results=top_k,
                search_depth="advanced"  # Use advanced search for better results
            )

            # Extract results from Tavily response
            tavily_results = response.get("results", [])

            # Map Tavily results to the required format
            sources_list = []
            documents = []

            # Offset reserved to avoid ID conflicts with other connectors (currently unused)
            base_id = 100

            for i, result in enumerate(tavily_results):

                # Create a source entry
                source = {
                    "id": self.source_id_counter,
                    "title": result.get("title", "Tavily Result"),
                    "description": result.get("content", "")[:100],
                    "url": result.get("url", "")
                }
                sources_list.append(source)

                # Create a document entry
                document = {
                    "chunk_id": f"tavily_chunk_{i}",
                    "content": result.get("content", ""),
                    "score": result.get("score", 0.0),
                    "document": {
                        "id": self.source_id_counter,
                        "title": result.get("title", "Tavily Result"),
                        "document_type": "TAVILY_API",
                        "metadata": {
                            "url": result.get("url", ""),
                            "published_date": result.get("published_date", ""),
                            "source": "TAVILY_API"
                        }
                    }
                }
                documents.append(document)
                self.source_id_counter += 1

            # Create result object
            result_object = {
                "id": 3,
                "name": "Tavily Search",
                "type": "TAVILY_API",
                "sources": sources_list,
            }

            return result_object, documents

        except Exception as e:
            # Log the error and return empty results
            print(f"Error searching with Tavily: {str(e)}")
            return {
                "id": 3,
                "name": "Tavily Search",
                "type": "TAVILY_API",
                "sources": [],
            }, []

    async def search_slack(self, user_query: str, user_id: int, search_space_id: int, top_k: int = 20) -> tuple:
        """
        Search for Slack messages and return both the source information and langchain documents

        Returns:
            tuple: (sources_info, langchain_documents)
        """
        slack_chunks = await self.retriever.hybrid_search(
            query_text=user_query,
            top_k=top_k,
            user_id=user_id,
            search_space_id=search_space_id,
            document_type="SLACK_CONNECTOR"
        )

        # Map slack_chunks to the required format
        mapped_sources = {}
        for i, chunk in enumerate(slack_chunks):
            # Fix for UI
            slack_chunks[i]['document']['id'] = self.source_id_counter
            # Extract document metadata
            document = chunk.get('document', {})
            metadata = document.get('metadata', {})

            # Create a mapped source entry with Slack-specific metadata
            channel_name = metadata.get('channel_name', 'Unknown Channel')
            channel_id = metadata.get('channel_id', '')
            message_date = metadata.get('start_date', '')

            # Create a more descriptive title for Slack messages
            title = f"Slack: {channel_name}"
            if message_date:
                title += f" ({message_date})"

            # Create a more descriptive description for Slack messages
            description = chunk.get('content', '')[:100]
            if len(description) == 100:
                description += "..."

            # For URL, we can use a placeholder or construct a URL to the Slack channel if available
            url = ""
            if channel_id:
                url = f"https://slack.com/app_redirect?channel={channel_id}"

            source = {
                "id": self.source_id_counter,
                "title": title,
                "description": description,
                "url": url,
            }

            self.source_id_counter += 1

            # Use channel_id and content as a unique identifier for tracking unique sources
            source_key = f"{channel_id}_{chunk.get('chunk_id', i)}"
            if source_key and source_key not in mapped_sources:
                mapped_sources[source_key] = source

        # Convert to list of sources
        sources_list = list(mapped_sources.values())

        # Create result object
        result_object = {
            "id": 4,
            "name": "Slack",
            "type": "SLACK_CONNECTOR",
            "sources": sources_list,
        }

        return result_object, slack_chunks

    async def search_notion(self, user_query: str, user_id: int, search_space_id: int, top_k: int = 20) -> tuple:
        """
        Search for Notion pages and return both the source information and langchain documents

        Args:
            user_query: The user's query
            user_id: The user's ID
            search_space_id: The search space ID to search in
            top_k: Maximum number of results to return

        Returns:
            tuple: (sources_info, langchain_documents)
        """
        notion_chunks = await self.retriever.hybrid_search(
            query_text=user_query,
            top_k=top_k,
            user_id=user_id,
            search_space_id=search_space_id,
            document_type="NOTION_CONNECTOR"
        )

        # Map notion_chunks to the required format
        mapped_sources = {}
        for i, chunk in enumerate(notion_chunks):
            # Fix for UI
            notion_chunks[i]['document']['id'] = self.source_id_counter

            # Extract document metadata
            document = chunk.get('document', {})
            metadata = document.get('metadata', {})

            # Create a mapped source entry with Notion-specific metadata
            page_title = metadata.get('page_title', 'Untitled Page')
            page_id = metadata.get('page_id', '')
            indexed_at = metadata.get('indexed_at', '')

            # Create a more descriptive title for Notion pages
            title = f"Notion: {page_title}"
            if indexed_at:
                title += f" (indexed: {indexed_at})"

            # Create a more descriptive description for Notion pages
            description = chunk.get('content', '')[:100]
            if len(description) == 100:
                description += "..."

            # For URL, we can use a placeholder or construct a URL to the Notion page if available
            url = ""
            if page_id:
                # Notion page URLs follow this format
                url = f"https://notion.so/{page_id.replace('-', '')}"

            source = {
                "id": self.source_id_counter,
                "title": title,
                "description": description,
                "url": url,
            }

            self.source_id_counter += 1

            # Use page_id and content as a unique identifier for tracking unique sources
            source_key = f"{page_id}_{chunk.get('chunk_id', i)}"
            if source_key and source_key not in mapped_sources:
                mapped_sources[source_key] = source

        # Convert to list of sources
        sources_list = list(mapped_sources.values())

        # Create result object
        result_object = {
            "id": 5,
            "name": "Notion",
            "type": "NOTION_CONNECTOR",
            "sources": sources_list,
        }

        return result_object, notion_chunks
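A rough sketch of how a caller might fan out over these connector searches and feed the combined chunks into the reranking and research steps shown earlier; the function name and wiring below are illustrative assumptions, not code from this commit.

# Illustrative only: combining connector results into the sources list and the
# raw document pool consumed by the reranker/research pipeline.
async def gather_sources(session, user_query: str, user_id: int, search_space_id: int):
    connector_service = ConnectorService(session)

    crawled_info, crawled_chunks = await connector_service.search_crawled_urls(
        user_query, user_id, search_space_id, top_k=20
    )
    files_info, files_chunks = await connector_service.search_files(
        user_query, user_id, search_space_id, top_k=20
    )
    tavily_info, tavily_docs = await connector_service.search_tavily(user_query, user_id, top_k=20)

    sources = [crawled_info, files_info, tavily_info]
    all_raw_documents = crawled_chunks + files_chunks + tavily_docs
    return sources, all_raw_documents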
136
surfsense_backend/app/utils/document_converters.py
Normal file
@ -0,0 +1,136 @@
async def convert_element_to_markdown(element) -> str:
    """
    Convert an Unstructured element to markdown format based on its category.

    Args:
        element: The Unstructured API element object

    Returns:
        str: Markdown formatted string
    """
    element_category = element.metadata["category"]
    content = element.page_content

    if not content:
        return ""

    markdown_mapping = {
        "Formula": lambda x: f"```math\n{x}\n```",
        "FigureCaption": lambda x: f"*Figure: {x}*",
        "NarrativeText": lambda x: f"{x}\n\n",
        "ListItem": lambda x: f"- {x}\n",
        "Title": lambda x: f"# {x}\n\n",
        "Address": lambda x: f"> {x}\n\n",
        "EmailAddress": lambda x: f"`{x}`",
        "Image": lambda x: f"",
        "PageBreak": lambda x: "\n---\n",
        "Table": lambda x: f"```html\n{element.metadata['text_as_html']}\n```",
        "Header": lambda x: f"## {x}\n\n",
        "Footer": lambda x: f"*{x}*\n\n",
        "CodeSnippet": lambda x: f"```\n{x}\n```",
        "PageNumber": lambda x: f"*Page {x}*\n\n",
        "UncategorizedText": lambda x: f"{x}\n\n"
    }

    converter = markdown_mapping.get(element_category, lambda x: x)
    return converter(content)


async def convert_document_to_markdown(elements):
    """
    Convert all document elements to markdown.

    Args:
        elements: List of Unstructured API elements

    Returns:
        str: Complete markdown document
    """
    markdown_parts = []

    for element in elements:
        markdown_text = await convert_element_to_markdown(element)
        if markdown_text:
            markdown_parts.append(markdown_text)

    return "".join(markdown_parts)


def convert_chunks_to_langchain_documents(chunks):
    """
    Convert chunks from hybrid search results to LangChain Document objects.

    Args:
        chunks: List of chunk dictionaries from hybrid search results

    Returns:
        List of LangChain Document objects
    """
    try:
        from langchain_core.documents import Document as LangChainDocument
    except ImportError:
        raise ImportError(
            "LangChain is not installed. Please install it with `pip install langchain langchain-core`"
        )

    langchain_docs = []

    for chunk in chunks:
        # Extract content from the chunk
        content = chunk.get("content", "")

        # Create metadata dictionary
        metadata = {
            "chunk_id": chunk.get("chunk_id"),
            "score": chunk.get("score"),
            "rank": chunk.get("rank") if "rank" in chunk else None,
        }

        # Add document information to metadata
        if "document" in chunk:
            doc = chunk["document"]
            metadata.update({
                "document_id": doc.get("id"),
                "document_title": doc.get("title"),
                "document_type": doc.get("document_type"),
            })

            # Add document metadata if available
            if "metadata" in doc:
                # Prefix document metadata keys to avoid conflicts
                doc_metadata = {f"doc_meta_{k}": v for k, v in doc.get("metadata", {}).items()}
                metadata.update(doc_metadata)

            # Add source URL if available in metadata
            if "url" in doc.get("metadata", {}):
                metadata["source"] = doc["metadata"]["url"]
            elif "sourceURL" in doc.get("metadata", {}):
                metadata["source"] = doc["metadata"]["sourceURL"]

        # Ensure source_id is set for citation purposes
        # Use document_id as the source_id if available
        if "document_id" in metadata:
            metadata["source_id"] = metadata["document_id"]

        # Update content for citation mode - format as XML with explicit source_id
        new_content = f"""
        <document>
            <metadata>
                <source_id>{metadata.get("source_id", metadata.get("document_id", "unknown"))}</source_id>
            </metadata>
            <content>
                <text>
                    {content}
                </text>
            </content>
        </document>
        """

        # Create LangChain Document
        langchain_doc = LangChainDocument(
            page_content=new_content,
            metadata=metadata
        )

        langchain_docs.append(langchain_doc)

    return langchain_docs
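For reference, here is a made-up chunk in the shape this converter expects, and the fields the resulting LangChain document exposes (values are purely illustrative):

sample_chunks = [
    {
        "chunk_id": "chunk_1",
        "content": "SurfSense is a personal knowledge base.",
        "score": 0.42,
        "document": {
            "id": 7,
            "title": "SurfSense README",
            "document_type": "CRAWLED_URL",
            "metadata": {"url": "https://github.com/MODSetter/SurfSense"},
        },
    }
]

docs = convert_chunks_to_langchain_documents(sample_chunks)
print(docs[0].metadata["source_id"])   # 7, reused as the citation number
print(docs[0].metadata["source"])      # https://github.com/MODSetter/SurfSense
print("<source_id>7</source_id>" in docs[0].page_content)  # True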
95
surfsense_backend/app/utils/reranker_service.py
Normal file
@ -0,0 +1,95 @@
import logging
from typing import List, Dict, Any, Optional
from rerankers import Document as RerankerDocument

class RerankerService:
    """
    Service for reranking documents using a configured reranker
    """

    def __init__(self, reranker_instance=None):
        """
        Initialize the reranker service

        Args:
            reranker_instance: The reranker instance to use for reranking
        """
        self.reranker_instance = reranker_instance

    def rerank_documents(self, query_text: str, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Rerank documents using the configured reranker

        Args:
            query_text: The query text to use for reranking
            documents: List of document dictionaries to rerank

        Returns:
            List[Dict[str, Any]]: Reranked documents
        """
        if not self.reranker_instance or not documents:
            return documents

        try:
            # Create Document objects for the rerankers library
            reranker_docs = []
            for i, doc in enumerate(documents):
                chunk_id = doc.get("chunk_id", f"chunk_{i}")
                content = doc.get("content", "")
                score = doc.get("score", 0.0)
                document_info = doc.get("document", {})

                reranker_docs.append(
                    RerankerDocument(
                        text=content,
                        doc_id=chunk_id,
                        metadata={
                            'document_id': document_info.get("id", ""),
                            'document_title': document_info.get("title", ""),
                            'document_type': document_info.get("document_type", ""),
                            'rrf_score': score
                        }
                    )
                )

            # Rerank using the configured reranker
            reranking_results = self.reranker_instance.rank(
                query=query_text,
                docs=reranker_docs
            )

            # Process the results from the reranker
            # Convert to serializable dictionaries
            serialized_results = []
            for result in reranking_results.results:
                # Find the original document by id
                original_doc = next((doc for doc in documents if doc.get("chunk_id") == result.document.doc_id), None)
                if original_doc:
                    # Create a new document with the reranked score
                    reranked_doc = original_doc.copy()
                    reranked_doc["score"] = float(result.score)
                    reranked_doc["rank"] = result.rank
                    serialized_results.append(reranked_doc)

            return serialized_results

        except Exception as e:
            # Log the error
            logging.error(f"Error during reranking: {str(e)}")
            # Fall back to original documents without reranking
            return documents

    @staticmethod
    def get_reranker_instance(config=None) -> Optional['RerankerService']:
        """
        Get a reranker service instance based on configuration

        Args:
            config: Configuration object that may contain a reranker_instance

        Returns:
            Optional[RerankerService]: A reranker service instance or None
        """
        if config and hasattr(config, 'reranker_instance') and config.reranker_instance:
            return RerankerService(config.reranker_instance)
        return None
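One plausible way to construct the wrapped reranker with the rerankers package's FlashRank backend (listed as rerankers[flashrank] in pyproject.toml). The model name and sample chunks are examples, and the instance the app actually uses is built from its config, which this diff does not show.

# Hedged construction sketch using the rerankers library.
from rerankers import Reranker

ranker = Reranker("ms-marco-MiniLM-L-12-v2", model_type="flashrank")  # example model
reranker_service = RerankerService(ranker)

reranked = reranker_service.rerank_documents(
    "what is surfsense?",
    [
        {"chunk_id": "chunk_0", "content": "SurfSense is a personal AI research agent.", "score": 0.5,
         "document": {"id": 1, "title": "README", "document_type": "CRAWLED_URL"}},
        {"chunk_id": "chunk_1", "content": "Unrelated text about cooking pasta.", "score": 0.4,
         "document": {"id": 2, "title": "Recipe", "document_type": "FILE"}},
    ],
)
print([d["chunk_id"] for d in reranked])  # expected: most relevant chunk first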
211
surfsense_backend/app/utils/research_service.py
Normal file
@ -0,0 +1,211 @@
import asyncio
import re
from typing import List, Dict, Any, AsyncGenerator, Callable, Optional
from langchain.schema import Document
from gpt_researcher.agent import GPTResearcher
from gpt_researcher.utils.enum import ReportType, Tone, ReportSource
from dotenv import load_dotenv

load_dotenv()

class ResearchService:
    @staticmethod
    async def create_custom_prompt(user_query: str) -> str:
        citation_prompt = f"""
        You are a research assistant tasked with analyzing documents and providing comprehensive answers with proper citations in IEEE format.

        <instructions>
        1. Carefully analyze all provided documents in the <document> sections.
        2. Extract relevant information that addresses the user's query.
        3. Synthesize a comprehensive, well-structured answer using information from these documents.
        4. For EVERY piece of information you include from the documents, add an IEEE-style citation in square brackets [X] where X is the source_id from the document's metadata.
        5. Make sure ALL factual statements from the documents have proper citations.
        6. If multiple documents support the same point, include all relevant citations [X], [Y].
        7. Present information in a logical, coherent flow.
        8. Use your own words to connect ideas, but cite ALL information from the documents.
        9. If documents contain conflicting information, acknowledge this and present both perspectives with appropriate citations.
        10. Do not make up or include information not found in the provided documents.
        11. CRITICAL: You MUST use the exact source_id value from each document's metadata for citations. Do not create your own citation numbers.
        12. CRITICAL: Every citation MUST be in the IEEE format [X] where X is the exact source_id value.
        13. CRITICAL: Never renumber or reorder citations - always use the original source_id values.
        14. CRITICAL: Do not return citations as clickable links.
        15. CRITICAL: Never format citations as markdown links like "([1](https://example.com))". Always use plain square brackets only.
        16. CRITICAL: Citations must ONLY appear as [X] or [X], [Y], [Z] format - never with parentheses, hyperlinks, or other formatting.
        17. CRITICAL: Never make up citation numbers. Only use source_id values that are explicitly provided in the document metadata.
        18. CRITICAL: If you are unsure about a source_id, do not include a citation rather than guessing or making one up.
        </instructions>

        <format>
        - Write in clear, professional language suitable for academic or technical audiences
        - Organize your response with appropriate paragraphs, headings, and structure
        - Every fact from the documents must have an IEEE-style citation in square brackets [X] where X is the EXACT source_id from the document's metadata
        - Citations should appear at the end of the sentence containing the information they support
        - Multiple citations should be separated by commas: [X], [Y], [Z]
        - No need to return references section. Just citation numbers in answer.
        - NEVER create your own citation numbering system - use the exact source_id values from the documents.
        - NEVER format citations as clickable links or as markdown links like "([1](https://example.com))". Always use plain square brackets only.
        - NEVER make up citation numbers if you are unsure about the source_id. It is better to omit the citation than to guess.
        </format>

        <input_example>
        <document>
            <metadata>
                <source_id>1</source_id>
            </metadata>
            <content>
                <text>
                The Great Barrier Reef is the world's largest coral reef system, stretching over 2,300 kilometers along the coast of Queensland, Australia. It comprises over 2,900 individual reefs and 900 islands.
                </text>
            </content>
        </document>

        <document>
            <metadata>
                <source_id>13</source_id>
            </metadata>
            <content>
                <text>
                Climate change poses a significant threat to coral reefs worldwide. Rising ocean temperatures have led to mass coral bleaching events in the Great Barrier Reef in 2016, 2017, and 2020.
                </text>
            </content>
        </document>

        <document>
            <metadata>
                <source_id>21</source_id>
            </metadata>
            <content>
                <text>
                The Great Barrier Reef was designated a UNESCO World Heritage Site in 1981 due to its outstanding universal value and biological diversity. It is home to over 1,500 species of fish and 400 types of coral.
                </text>
            </content>
        </document>
        </input_example>

        <output_example>
        The Great Barrier Reef is the world's largest coral reef system, stretching over 2,300 kilometers along the coast of Queensland, Australia [1]. It was designated a UNESCO World Heritage Site in 1981 due to its outstanding universal value and biological diversity [21]. The reef is home to over 1,500 species of fish and 400 types of coral [21]. Unfortunately, climate change poses a significant threat to coral reefs worldwide, with rising ocean temperatures leading to mass coral bleaching events in the Great Barrier Reef in 2016, 2017, and 2020 [13]. The reef system comprises over 2,900 individual reefs and 900 islands [1], making it an ecological treasure that requires protection from multiple threats [1], [13].
        </output_example>

        <incorrect_citation_formats>
        DO NOT use any of these incorrect citation formats:
        - Using parentheses and markdown links: ([1](https://github.com/MODSetter/SurfSense))
        - Using parentheses around brackets: ([1])
        - Using hyperlinked text: [link to source 1](https://example.com)
        - Using footnote style: ... reef system¹
        - Making up citation numbers when source_id is unknown

        ONLY use plain square brackets [1] or multiple citations [1], [2], [3]
        </incorrect_citation_formats>

        Note that the citation numbers match exactly with the source_id values (1, 13, and 21) and are not renumbered sequentially. Citations follow IEEE style with square brackets and appear at the end of sentences.

        Now, please research the following query:

        <user_query_to_research>
        {user_query}
        </user_query_to_research>
        """

        return citation_prompt

    @staticmethod
    async def stream_research(
        user_query: str,
        documents: List[Document] = None,
        on_progress: Optional[Callable] = None,
        research_mode: str = "GENERAL"
    ) -> str:
        """
        Stream the research process using GPTResearcher

        Args:
            user_query: The user's query
            documents: List of Document objects to use for research
            on_progress: Optional callback for progress updates
            research_mode: Research mode to use

        Returns:
            str: The final research report
        """
        # Create a custom websocket-like object to capture streaming output
        class StreamingWebsocket:
            async def send_json(self, data):
                if on_progress:
                    try:
                        # Filter out excessive logging of the prompt
                        if data.get("type") == "logs":
                            output = data.get("output", "")
                            # Check if this is a verbose prompt log
                            if "You are a research assistant tasked with analyzing documents" in output and len(output) > 500:
                                # Replace with a shorter message
                                data["output"] = f"Processing research for query: {user_query}"

                        result = await on_progress(data)
                        return result
                    except Exception as e:
                        print(f"Error in on_progress callback: {e}")
                        return None

        streaming_websocket = StreamingWebsocket()

        custom_prompt_for_ieee_citations = await ResearchService.create_custom_prompt(user_query)

        if research_mode == "GENERAL":
            research_report_type = ReportType.CustomReport.value
        elif research_mode == "DEEP":
            research_report_type = ReportType.ResearchReport.value
        elif research_mode == "DEEPER":
            research_report_type = ReportType.DetailedReport.value
        # elif research_mode == "DEEPEST":
        #     research_report_type = ReportType.DeepResearch.value

        # Initialize GPTResearcher with the streaming websocket
        researcher = GPTResearcher(
            query=custom_prompt_for_ieee_citations,
            report_type=research_report_type,
            report_format="IEEE",
            report_source=ReportSource.LangChainDocuments.value,
            tone=Tone.Formal,
            documents=documents,
            verbose=True,
            websocket=streaming_websocket
        )

        # Conduct research
        await researcher.conduct_research()

        # Generate report with streaming
        report = await researcher.write_report()

        # Fix citation format
        report = ResearchService.fix_citation_format(report)

        return report

    @staticmethod
    def fix_citation_format(text: str) -> str:
        """
        Fix any incorrectly formatted citations in the text.

        Args:
            text: The text to fix

        Returns:
            str: The text with fixed citations
        """
        if not text:
            return text

        # More specific pattern to match only numeric citations in markdown-style links
        # This matches patterns like ([1](https://github.com/...)) but not general links like ([Click here](https://...))
        pattern = r'\(\[(\d+)\]\((https?://[^\)]+)\)\)'

        # Replace with just [X] where X is the number
        text = re.sub(pattern, r'[\1]', text)

        # Also match other incorrect formats like ([1]) and convert to [1]
        # Only match if the content inside brackets is a number
        text = re.sub(r'\(\[(\d+)\]\)', r'[\1]', text)

        return text
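A quick check of what fix_citation_format does to the two malformed citation styles it targets (the sentence is invented for illustration):

raw = (
    "SurfSense stores chunks in pgvector ([1](https://github.com/MODSetter/SurfSense)) "
    "and reranks them before reporting ([2])."
)
print(ResearchService.fix_citation_format(raw))
# -> "SurfSense stores chunks in pgvector [1] and reranks them before reporting [2]."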
99
surfsense_backend/app/utils/streaming_service.py
Normal file
@ -0,0 +1,99 @@
import json
from typing import List, Dict, Any, Generator

class StreamingService:
    def __init__(self):
        self.terminal_idx = 1
        self.message_annotations = [
            {
                "type": "TERMINAL_INFO",
                "content": []
            },
            {
                "type": "SOURCES",
                "content": []
            },
            {
                "type": "ANSWER",
                "content": []
            }
        ]

    def add_terminal_message(self, text: str, message_type: str = "info") -> str:
        """
        Add a terminal message to the annotations and return the formatted response

        Args:
            text: The message text
            message_type: The message type (info, success, error)

        Returns:
            str: The formatted response string
        """
        self.message_annotations[0]["content"].append({
            "id": self.terminal_idx,
            "text": text,
            "type": message_type
        })
        self.terminal_idx += 1
        return self._format_annotations()

    def update_sources(self, sources: List[Dict[str, Any]]) -> str:
        """
        Update the sources in the annotations and return the formatted response

        Args:
            sources: List of source objects

        Returns:
            str: The formatted response string
        """
        self.message_annotations[1]["content"] = sources
        return self._format_annotations()

    def update_answer(self, answer_content: List[str]) -> str:
        """
        Update the answer in the annotations and return the formatted response

        Args:
            answer_content: The answer content as a list of strings

        Returns:
            str: The formatted response string
        """
        self.message_annotations[2] = {
            "type": "ANSWER",
            "content": answer_content
        }
        return self._format_annotations()

    def _format_annotations(self) -> str:
        """
        Format the annotations as a string

        Returns:
            str: The formatted annotations string
        """
        return f'8:{json.dumps(self.message_annotations)}\n'

    def format_completion(self, prompt_tokens: int = 156, completion_tokens: int = 204) -> str:
        """
        Format a completion message

        Args:
            prompt_tokens: Number of prompt tokens
            completion_tokens: Number of completion tokens

        Returns:
            str: The formatted completion string
        """
        total_tokens = prompt_tokens + completion_tokens
        completion_data = {
            "finishReason": "stop",
            "usage": {
                "promptTokens": prompt_tokens,
                "completionTokens": completion_tokens,
                "totalTokens": total_tokens
            }
        }
        return f'd:{json.dumps(completion_data)}\n'
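The "8:" and "d:" prefixes appear to follow the Vercel AI SDK data-stream conventions consumed by the frontend (an assumption not stated in this file). A small example of the frames this service emits:

svc = StreamingService()
frame = svc.add_terminal_message("Starting to research...", "info")
print(frame)
# 8:[{"type": "TERMINAL_INFO", "content": [{"id": 1, "text": "Starting to research...", "type": "info"}]}, ...]\n

print(svc.format_completion(prompt_tokens=10, completion_tokens=5))
# d:{"finishReason": "stop", "usage": {"promptTokens": 10, "completionTokens": 5, "totalTokens": 15}}\n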
4
surfsense_backend/main.py
Normal file
@ -0,0 +1,4 @@
import uvicorn

if __name__ == "__main__":
    uvicorn.run("app.app:app", host="0.0.0.0", log_level="info")
27
surfsense_backend/pyproject.toml
Normal file
@ -0,0 +1,27 @@
[project]
name = "surf-new-backend"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "asyncpg>=0.30.0",
    "chonkie[all]>=0.4.1",
    "fastapi>=0.115.8",
    "fastapi-users[oauth,sqlalchemy]>=14.0.1",
    "firecrawl-py>=1.12.0",
    "gpt-researcher>=0.12.12",
    "langchain-community>=0.3.17",
    "langchain-unstructured>=0.1.6",
    "litellm>=1.61.4",
    "markdownify>=0.14.1",
    "notion-client>=2.3.0",
    "pgvector>=0.3.6",
    "playwright>=1.50.0",
    "rerankers[flashrank]>=0.7.1",
    "slack-sdk>=3.34.0",
    "tavily-python>=0.3.2",
    "unstructured-client>=0.30.0",
    "uvicorn[standard]>=0.34.0",
    "validators>=0.34.0",
]
3271
surfsense_backend/uv.lock
generated
Normal file
File diff suppressed because it is too large