mirror of
https://github.com/MODSetter/SurfSense.git
synced 2025-09-01 18:19:08 +00:00
Merge pull request #90 from MODSetter/dev
feat: Introduce RAPTOR Search Mode
This commit is contained in:
commit
29f4d90a0a
11 changed files with 318 additions and 127 deletions
|
@ -3,10 +3,16 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, fields
|
||||
from enum import Enum
|
||||
from typing import Optional, List, Any
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
class SearchMode(Enum):
|
||||
"""Granularity at which connector searches retrieve content.

The member values are the literal strings carried in the chat request's
``search_mode`` field.
"""
|
||||
CHUNKS = "CHUNKS"  # search individual chunks via the chunk hybrid-search retriever
|
||||
DOCUMENTS = "DOCUMENTS"  # search whole documents via the document hybrid-search retriever
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class Configuration:
|
||||
|
@ -18,6 +24,7 @@ class Configuration:
|
|||
connectors_to_search: List[str]
|
||||
user_id: str
|
||||
search_space_id: int
|
||||
search_mode: SearchMode
|
||||
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -10,7 +10,7 @@ from langchain_core.runnables import RunnableConfig
|
|||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from .configuration import Configuration
|
||||
from .configuration import Configuration, SearchMode
|
||||
from .prompts import get_answer_outline_system_prompt
|
||||
from .state import State
|
||||
from .sub_section_writer.graph import graph as sub_section_writer_graph
|
||||
|
@ -149,7 +149,8 @@ async def fetch_relevant_documents(
|
|||
writer: StreamWriter = None,
|
||||
state: State = None,
|
||||
top_k: int = 10,
|
||||
connector_service: ConnectorService = None
|
||||
connector_service: ConnectorService = None,
|
||||
search_mode: SearchMode = SearchMode.CHUNKS
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Fetch relevant documents for research questions using the provided connectors.
|
||||
|
@ -213,7 +214,8 @@ async def fetch_relevant_documents(
|
|||
user_query=reformulated_query,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k
|
||||
top_k=top_k,
|
||||
search_mode=search_mode
|
||||
)
|
||||
|
||||
# Add to sources and raw documents
|
||||
|
@ -231,7 +233,8 @@ async def fetch_relevant_documents(
|
|||
user_query=reformulated_query,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k
|
||||
top_k=top_k,
|
||||
search_mode=search_mode
|
||||
)
|
||||
|
||||
# Add to sources and raw documents
|
||||
|
@ -249,7 +252,8 @@ async def fetch_relevant_documents(
|
|||
user_query=reformulated_query,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k
|
||||
top_k=top_k,
|
||||
search_mode=search_mode
|
||||
)
|
||||
|
||||
# Add to sources and raw documents
|
||||
|
@ -267,7 +271,8 @@ async def fetch_relevant_documents(
|
|||
user_query=reformulated_query,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k
|
||||
top_k=top_k,
|
||||
search_mode=search_mode
|
||||
)
|
||||
|
||||
# Add to sources and raw documents
|
||||
|
@ -286,7 +291,8 @@ async def fetch_relevant_documents(
|
|||
user_query=reformulated_query,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k
|
||||
top_k=top_k,
|
||||
search_mode=search_mode
|
||||
)
|
||||
|
||||
# Add to sources and raw documents
|
||||
|
@ -304,7 +310,8 @@ async def fetch_relevant_documents(
|
|||
user_query=reformulated_query,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k
|
||||
top_k=top_k,
|
||||
search_mode=search_mode
|
||||
)
|
||||
|
||||
# Add to sources and raw documents
|
||||
|
@ -322,7 +329,8 @@ async def fetch_relevant_documents(
|
|||
user_query=reformulated_query,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k
|
||||
top_k=top_k,
|
||||
search_mode=search_mode
|
||||
)
|
||||
|
||||
# Add to sources and raw documents
|
||||
|
@ -340,7 +348,8 @@ async def fetch_relevant_documents(
|
|||
user_query=reformulated_query,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
top_k=top_k
|
||||
top_k=top_k,
|
||||
search_mode=search_mode
|
||||
)
|
||||
|
||||
# Add to sources and raw documents
|
||||
|
@ -558,7 +567,8 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW
|
|||
writer=writer,
|
||||
state=state,
|
||||
top_k=TOP_K,
|
||||
connector_service=connector_service
|
||||
connector_service=connector_service,
|
||||
search_mode=configuration.search_mode
|
||||
)
|
||||
except Exception as e:
|
||||
error_message = f"Error fetching relevant documents: {str(e)}"
|
||||
|
|
|
@ -141,6 +141,11 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A
|
|||
|
||||
# Construct a clear, structured query for the LLM
|
||||
human_message_content = f"""
|
||||
Source material:
|
||||
<documents>
|
||||
{documents_text}
|
||||
</documents>
|
||||
|
||||
Now user's query is:
|
||||
<user_query>
|
||||
{user_query}
|
||||
|
@ -158,11 +163,6 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A
|
|||
<guiding_questions>
|
||||
{questions_text}
|
||||
</guiding_questions>
|
||||
|
||||
Use the provided documents as your source material and cite them properly using the IEEE citation format [X] where X is the source_id.
|
||||
<documents>
|
||||
{documents_text}
|
||||
</documents>
|
||||
"""
|
||||
|
||||
# Create messages for the LLM
|
||||
|
|
|
@ -25,6 +25,8 @@ You are a research assistant tasked with analyzing documents and providing compr
|
|||
16. CRITICAL: Citations must ONLY appear as [X] or [X], [Y], [Z] format - never with parentheses, hyperlinks, or other formatting.
|
||||
17. CRITICAL: Never make up citation numbers. Only use source_id values that are explicitly provided in the document metadata.
|
||||
18. CRITICAL: If you are unsure about a source_id, do not include a citation rather than guessing or making one up.
|
||||
19. CRITICAL: Focus only on answering the user's query. Any guiding questions provided are for your thinking process only and should not be mentioned in your response.
|
||||
20. CRITICAL: Ensure your response aligns with the provided sub-section title and section position.
|
||||
</instructions>
|
||||
|
||||
<format>
|
||||
|
@ -37,6 +39,8 @@ You are a research assistant tasked with analyzing documents and providing compr
|
|||
- NEVER create your own citation numbering system - use the exact source_id values from the documents.
|
||||
- NEVER format citations as clickable links or as markdown links like "([1](https://example.com))". Always use plain square brackets only.
|
||||
- NEVER make up citation numbers if you are unsure about the source_id. It is better to omit the citation than to guess.
|
||||
- NEVER include or mention the guiding questions in your response. They are only to help guide your thinking.
|
||||
- ALWAYS focus on answering the user's query directly from the information in the documents.
|
||||
</format>
|
||||
|
||||
<input_example>
|
||||
|
@ -84,4 +88,21 @@ ONLY use plain square brackets [1] or multiple citations [1], [2], [3]
|
|||
</incorrect_citation_formats>
|
||||
|
||||
Note that the citation numbers match exactly with the source_id values (1, 13, and 21) and are not renumbered sequentially. Citations follow IEEE style with square brackets and appear at the end of sentences.
|
||||
|
||||
<user_query_instructions>
|
||||
When you see a user query like:
|
||||
<user_query>
|
||||
Give all linear issues.
|
||||
</user_query>
|
||||
|
||||
Focus exclusively on answering this query using information from the provided documents.
|
||||
|
||||
If guiding questions are provided in a <guiding_questions> section, use them only to guide your thinking process. Do not mention or list these questions in your response.
|
||||
|
||||
Make sure your response:
|
||||
1. Directly answers the user's query
|
||||
2. Fits the provided sub-section title and section position
|
||||
3. Uses proper citations for all information from documents
|
||||
4. Is well-structured and professional in tone
|
||||
</user_query_instructions>
|
||||
"""
|
|
@ -113,8 +113,6 @@ class DocumentHybridSearchRetriever:
|
|||
search_space_id: Optional search space ID to filter results
|
||||
document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL")
|
||||
|
||||
Returns:
|
||||
List of dictionaries containing document data and relevance scores
|
||||
"""
|
||||
from sqlalchemy import select, func, text
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
@ -224,10 +222,22 @@ class DocumentHybridSearchRetriever:
|
|||
# Convert to serializable dictionaries
|
||||
serialized_results = []
|
||||
for document, score in documents_with_scores:
|
||||
# Fetch associated chunks for this document
|
||||
from sqlalchemy import select
|
||||
from app.db import Chunk
|
||||
|
||||
chunks_query = select(Chunk).where(Chunk.document_id == document.id).order_by(Chunk.id)
|
||||
chunks_result = await self.db_session.execute(chunks_query)
|
||||
chunks = chunks_result.scalars().all()
|
||||
|
||||
# Concatenate chunks content
|
||||
concatenated_chunks_content = " ".join([chunk.content for chunk in chunks]) if chunks else document.content
|
||||
|
||||
serialized_results.append({
|
||||
"document_id": document.id,
|
||||
"title": document.title,
|
||||
"content": document.content,
|
||||
"chunks_content": concatenated_chunks_content,
|
||||
"document_type": document.document_type.value if hasattr(document, 'document_type') else None,
|
||||
"metadata": document.document_metadata,
|
||||
"score": float(score), # Ensure score is a Python float
|
||||
|
|
|
@ -11,6 +11,8 @@ from sqlalchemy.exc import IntegrityError, OperationalError
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from langchain.schema import HumanMessage, AIMessage
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/chat")
|
||||
|
@ -28,6 +30,8 @@ async def handle_chat_data(
|
|||
search_space_id = request.data.get('search_space_id')
|
||||
research_mode: str = request.data.get('research_mode')
|
||||
selected_connectors: List[str] = request.data.get('selected_connectors')
|
||||
|
||||
search_mode_str = request.data.get('search_mode', "CHUNKS")
|
||||
|
||||
# Convert search_space_id to integer if it's a string
|
||||
if search_space_id and isinstance(search_space_id, str):
|
||||
|
@ -66,7 +70,8 @@ async def handle_chat_data(
|
|||
session,
|
||||
research_mode,
|
||||
selected_connectors,
|
||||
langchain_chat_history
|
||||
langchain_chat_history,
|
||||
search_mode_str
|
||||
))
|
||||
response.headers['x-vercel-ai-data-stream'] = 'v1'
|
||||
return response
|
||||
|
|
|
@ -6,6 +6,8 @@ from app.agents.researcher.state import State
|
|||
from app.utils.streaming_service import StreamingService
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.researcher.configuration import SearchMode
|
||||
|
||||
|
||||
async def stream_connector_search_results(
|
||||
user_query: str,
|
||||
|
@ -14,7 +16,8 @@ async def stream_connector_search_results(
|
|||
session: AsyncSession,
|
||||
research_mode: str,
|
||||
selected_connectors: List[str],
|
||||
langchain_chat_history: List[Any]
|
||||
langchain_chat_history: List[Any],
|
||||
search_mode_str: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Stream connector search results to the client
|
||||
|
@ -41,6 +44,11 @@ async def stream_connector_search_results(
|
|||
# Convert UUID to string if needed
|
||||
user_id_str = str(user_id) if isinstance(user_id, UUID) else user_id
|
||||
|
||||
if search_mode_str == "CHUNKS":
|
||||
search_mode = SearchMode.CHUNKS
|
||||
elif search_mode_str == "DOCUMENTS":
|
||||
search_mode = SearchMode.DOCUMENTS
|
||||
|
||||
# Sample configuration
|
||||
config = {
|
||||
"configurable": {
|
||||
|
@ -48,7 +56,8 @@ async def stream_connector_search_results(
|
|||
"num_sections": NUM_SECTIONS,
|
||||
"connectors_to_search": selected_connectors,
|
||||
"user_id": user_id_str,
|
||||
"search_space_id": search_space_id
|
||||
"search_space_id": search_space_id,
|
||||
"search_mode": search_mode
|
||||
}
|
||||
}
|
||||
# Initialize state with database session and streaming service
|
||||
|
|
|
@ -4,32 +4,47 @@ import asyncio
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
|
||||
from app.retriver.documents_hybrid_search import DocumentHybridSearchRetriever
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
from tavily import TavilyClient
|
||||
from linkup import LinkupClient
|
||||
|
||||
from app.agents.researcher.configuration import SearchMode
|
||||
|
||||
|
||||
class ConnectorService:
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
self.retriever = ChucksHybridSearchRetriever(session)
|
||||
self.chunk_retriever = ChucksHybridSearchRetriever(session)
|
||||
self.document_retriever = DocumentHybridSearchRetriever(session)
|
||||
self.source_id_counter = 1
|
||||
self.counter_lock = asyncio.Lock() # Guards the counter across concurrent coroutines; asyncio.Lock is NOT thread-safe, so it gives no protection between OS threads
|
||||
|
||||
async def search_crawled_urls(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20) -> tuple:
|
||||
async def search_crawled_urls(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
|
||||
"""
|
||||
Search for crawled URLs and return both the source information and langchain documents
|
||||
|
||||
Returns:
|
||||
tuple: (sources_info, langchain_documents)
|
||||
"""
|
||||
crawled_urls_chunks = await self.retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="CRAWLED_URL"
|
||||
)
|
||||
if search_mode == SearchMode.CHUNKS:
|
||||
crawled_urls_chunks = await self.chunk_retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="CRAWLED_URL"
|
||||
)
|
||||
elif search_mode == SearchMode.DOCUMENTS:
|
||||
crawled_urls_chunks = await self.document_retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="CRAWLED_URL"
|
||||
)
|
||||
# Transform document retriever results to match expected format
|
||||
crawled_urls_chunks = self._transform_document_results(crawled_urls_chunks)
|
||||
|
||||
# Early return if no results
|
||||
if not crawled_urls_chunks:
|
||||
|
@ -71,20 +86,31 @@ class ConnectorService:
|
|||
|
||||
return result_object, crawled_urls_chunks
|
||||
|
||||
async def search_files(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20) -> tuple:
|
||||
async def search_files(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
|
||||
"""
|
||||
Search for files and return both the source information and langchain documents
|
||||
|
||||
Returns:
|
||||
tuple: (sources_info, langchain_documents)
|
||||
"""
|
||||
files_chunks = await self.retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="FILE"
|
||||
)
|
||||
if search_mode == SearchMode.CHUNKS:
|
||||
files_chunks = await self.chunk_retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="FILE"
|
||||
)
|
||||
elif search_mode == SearchMode.DOCUMENTS:
|
||||
files_chunks = await self.document_retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="FILE"
|
||||
)
|
||||
# Transform document retriever results to match expected format
|
||||
files_chunks = self._transform_document_results(files_chunks)
|
||||
|
||||
# Early return if no results
|
||||
if not files_chunks:
|
||||
|
@ -126,6 +152,31 @@ class ConnectorService:
|
|||
|
||||
return result_object, files_chunks
|
||||
|
||||
def _transform_document_results(self, document_results: List[Dict]) -> List[Dict]:
|
||||
"""
|
||||
Transform results from document_retriever.hybrid_search() to match the format
|
||||
expected by the processing code.
|
||||
|
||||
Args:
|
||||
document_results: Results from document_retriever.hybrid_search()
|
||||
|
||||
Returns:
|
||||
List of transformed results in the format expected by the processing code
|
||||
"""
|
||||
transformed_results = []
|
||||
for doc in document_results:
|
||||
transformed_results.append({
|
||||
'document': {
|
||||
'id': doc.get('document_id'),
|
||||
'title': doc.get('title', 'Untitled Document'),
|
||||
'document_type': doc.get('document_type'),
|
||||
'metadata': doc.get('metadata', {}),
|
||||
},
|
||||
'content': doc.get('chunks_content', doc.get('content', '')),
|
||||
'score': doc.get('score', 0.0)
|
||||
})
|
||||
return transformed_results
|
||||
|
||||
async def get_connector_by_type(self, user_id: str, connector_type: SearchSourceConnectorType) -> Optional[SearchSourceConnector]:
|
||||
"""
|
||||
Get a connector by type for a specific user
|
||||
|
@ -249,20 +300,31 @@ class ConnectorService:
|
|||
"sources": [],
|
||||
}, []
|
||||
|
||||
async def search_slack(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20) -> tuple:
|
||||
async def search_slack(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
|
||||
"""
|
||||
Search for slack and return both the source information and langchain documents
|
||||
|
||||
Returns:
|
||||
tuple: (sources_info, langchain_documents)
|
||||
"""
|
||||
slack_chunks = await self.retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="SLACK_CONNECTOR"
|
||||
)
|
||||
if search_mode == SearchMode.CHUNKS:
|
||||
slack_chunks = await self.chunk_retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="SLACK_CONNECTOR"
|
||||
)
|
||||
elif search_mode == SearchMode.DOCUMENTS:
|
||||
slack_chunks = await self.document_retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="SLACK_CONNECTOR"
|
||||
)
|
||||
# Transform document retriever results to match expected format
|
||||
slack_chunks = self._transform_document_results(slack_chunks)
|
||||
|
||||
# Early return if no results
|
||||
if not slack_chunks:
|
||||
|
@ -323,7 +385,7 @@ class ConnectorService:
|
|||
|
||||
return result_object, slack_chunks
|
||||
|
||||
async def search_notion(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20) -> tuple:
|
||||
async def search_notion(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
|
||||
"""
|
||||
Search for Notion pages and return both the source information and langchain documents
|
||||
|
||||
|
@ -336,14 +398,25 @@ class ConnectorService:
|
|||
Returns:
|
||||
tuple: (sources_info, langchain_documents)
|
||||
"""
|
||||
notion_chunks = await self.retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="NOTION_CONNECTOR"
|
||||
)
|
||||
|
||||
if search_mode == SearchMode.CHUNKS:
|
||||
notion_chunks = await self.chunk_retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="NOTION_CONNECTOR"
|
||||
)
|
||||
elif search_mode == SearchMode.DOCUMENTS:
|
||||
notion_chunks = await self.document_retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="NOTION_CONNECTOR"
|
||||
)
|
||||
# Transform document retriever results to match expected format
|
||||
notion_chunks = self._transform_document_results(notion_chunks)
|
||||
|
||||
# Early return if no results
|
||||
if not notion_chunks:
|
||||
return {
|
||||
|
@ -405,7 +478,7 @@ class ConnectorService:
|
|||
|
||||
return result_object, notion_chunks
|
||||
|
||||
async def search_extension(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20) -> tuple:
|
||||
async def search_extension(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
|
||||
"""
|
||||
Search for extension data and return both the source information and langchain documents
|
||||
|
||||
|
@ -418,14 +491,25 @@ class ConnectorService:
|
|||
Returns:
|
||||
tuple: (sources_info, langchain_documents)
|
||||
"""
|
||||
extension_chunks = await self.retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="EXTENSION"
|
||||
)
|
||||
|
||||
if search_mode == SearchMode.CHUNKS:
|
||||
extension_chunks = await self.chunk_retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="EXTENSION"
|
||||
)
|
||||
elif search_mode == SearchMode.DOCUMENTS:
|
||||
extension_chunks = await self.document_retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="EXTENSION"
|
||||
)
|
||||
# Transform document retriever results to match expected format
|
||||
extension_chunks = self._transform_document_results(extension_chunks)
|
||||
|
||||
# Early return if no results
|
||||
if not extension_chunks:
|
||||
return {
|
||||
|
@ -505,7 +589,7 @@ class ConnectorService:
|
|||
|
||||
return result_object, extension_chunks
|
||||
|
||||
async def search_youtube(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20) -> tuple:
|
||||
async def search_youtube(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
|
||||
"""
|
||||
Search for YouTube videos and return both the source information and langchain documents
|
||||
|
||||
|
@ -518,13 +602,24 @@ class ConnectorService:
|
|||
Returns:
|
||||
tuple: (sources_info, langchain_documents)
|
||||
"""
|
||||
youtube_chunks = await self.retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="YOUTUBE_VIDEO"
|
||||
)
|
||||
if search_mode == SearchMode.CHUNKS:
|
||||
youtube_chunks = await self.chunk_retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="YOUTUBE_VIDEO"
|
||||
)
|
||||
elif search_mode == SearchMode.DOCUMENTS:
|
||||
youtube_chunks = await self.document_retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="YOUTUBE_VIDEO"
|
||||
)
|
||||
# Transform document retriever results to match expected format
|
||||
youtube_chunks = self._transform_document_results(youtube_chunks)
|
||||
|
||||
# Early return if no results
|
||||
if not youtube_chunks:
|
||||
|
@ -587,20 +682,31 @@ class ConnectorService:
|
|||
|
||||
return result_object, youtube_chunks
|
||||
|
||||
async def search_github(self, user_query: str, user_id: int, search_space_id: int, top_k: int = 20) -> tuple:
|
||||
async def search_github(self, user_query: str, user_id: int, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
|
||||
"""
|
||||
Search for GitHub documents and return both the source information and langchain documents
|
||||
|
||||
Returns:
|
||||
tuple: (sources_info, langchain_documents)
|
||||
"""
|
||||
github_chunks = await self.retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="GITHUB_CONNECTOR"
|
||||
)
|
||||
if search_mode == SearchMode.CHUNKS:
|
||||
github_chunks = await self.chunk_retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="GITHUB_CONNECTOR"
|
||||
)
|
||||
elif search_mode == SearchMode.DOCUMENTS:
|
||||
github_chunks = await self.document_retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="GITHUB_CONNECTOR"
|
||||
)
|
||||
# Transform document retriever results to match expected format
|
||||
github_chunks = self._transform_document_results(github_chunks)
|
||||
|
||||
# Early return if no results
|
||||
if not github_chunks:
|
||||
|
@ -643,7 +749,7 @@ class ConnectorService:
|
|||
|
||||
return result_object, github_chunks
|
||||
|
||||
async def search_linear(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20) -> tuple:
|
||||
async def search_linear(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
|
||||
"""
|
||||
Search for Linear issues and comments and return both the source information and langchain documents
|
||||
|
||||
|
@ -656,14 +762,25 @@ class ConnectorService:
|
|||
Returns:
|
||||
tuple: (sources_info, langchain_documents)
|
||||
"""
|
||||
linear_chunks = await self.retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="LINEAR_CONNECTOR"
|
||||
)
|
||||
|
||||
if search_mode == SearchMode.CHUNKS:
|
||||
linear_chunks = await self.chunk_retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="LINEAR_CONNECTOR"
|
||||
)
|
||||
elif search_mode == SearchMode.DOCUMENTS:
|
||||
linear_chunks = await self.document_retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="LINEAR_CONNECTOR"
|
||||
)
|
||||
# Transform document retriever results to match expected format
|
||||
linear_chunks = self._transform_document_results(linear_chunks)
|
||||
|
||||
# Early return if no results
|
||||
if not linear_chunks:
|
||||
return {
|
||||
|
|
|
@ -13,7 +13,9 @@ import {
|
|||
ArrowDown,
|
||||
CircleUser,
|
||||
Database,
|
||||
SendHorizontal
|
||||
SendHorizontal,
|
||||
FileText,
|
||||
Grid3x3
|
||||
} from 'lucide-react';
|
||||
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
|
||||
import { Button } from '@/components/ui/button';
|
||||
|
@ -248,6 +250,7 @@ const ChatPage = () => {
|
|||
const tabsListRef = useRef<HTMLDivElement>(null);
|
||||
const [terminalExpanded, setTerminalExpanded] = useState(false);
|
||||
const [selectedConnectors, setSelectedConnectors] = useState<string[]>(["CRAWLED_URL"]);
|
||||
const [searchMode, setSearchMode] = useState<'DOCUMENTS' | 'CHUNKS'>('DOCUMENTS');
|
||||
const [researchMode, setResearchMode] = useState<ResearchMode>("GENERAL");
|
||||
const [currentTime, setCurrentTime] = useState<string>('');
|
||||
const [currentDate, setCurrentDate] = useState<string>('');
|
||||
|
@ -362,7 +365,8 @@ const ChatPage = () => {
|
|||
data: {
|
||||
search_space_id: search_space_id,
|
||||
selected_connectors: selectedConnectors,
|
||||
research_mode: researchMode
|
||||
research_mode: researchMode,
|
||||
search_mode: searchMode
|
||||
}
|
||||
},
|
||||
onError: (error) => {
|
||||
|
@ -557,11 +561,6 @@ const ChatPage = () => {
|
|||
}
|
||||
}, [terminalExpanded]);
|
||||
|
||||
// Get total sources count for a connector type
|
||||
const getSourcesCount = (connectorType: string) => {
|
||||
return getSourcesCountUtil(getMessageConnectorSources(messages[messages.length - 1]), connectorType);
|
||||
};
|
||||
|
||||
// Function to check scroll position and update indicators
|
||||
const updateScrollIndicators = () => {
|
||||
updateScrollIndicatorsUtil(tabsListRef as React.RefObject<HTMLDivElement>, setCanScrollLeft, setCanScrollRight);
|
||||
|
@ -587,23 +586,6 @@ const ChatPage = () => {
|
|||
// Use the scroll to bottom hook
|
||||
useScrollToBottom(messagesEndRef as React.RefObject<HTMLDivElement>, [messages]);
|
||||
|
||||
// Function to get sources for the main view
|
||||
const getMainViewSources = (connector: any) => {
|
||||
return getMainViewSourcesUtil(connector, INITIAL_SOURCES_DISPLAY);
|
||||
};
|
||||
|
||||
// Function to get filtered sources for the dialog with null check
|
||||
const getFilteredSourcesWithCheck = (connector: any, sourceFilter: string) => {
|
||||
if (!connector?.sources) return [];
|
||||
return getFilteredSourcesUtil(connector, sourceFilter);
|
||||
};
|
||||
|
||||
// Function to get paginated dialog sources with null check
|
||||
const getPaginatedDialogSourcesWithCheck = (connector: any, sourceFilter: string, expandedSources: boolean, sourcesPage: number, sourcesPerPage: number) => {
|
||||
if (!connector?.sources) return [];
|
||||
return getPaginatedDialogSourcesUtil(connector, sourceFilter, expandedSources, sourcesPage, sourcesPerPage);
|
||||
};
|
||||
|
||||
// Function to get a citation source by ID
|
||||
const getCitationSource = React.useCallback((citationId: number, messageIndex?: number): Source | null => {
|
||||
if (!messages || messages.length === 0) return null;
|
||||
|
@ -995,15 +977,17 @@ const ChatPage = () => {
|
|||
<span className="sr-only">Send</span>
|
||||
</Button>
|
||||
</form>
|
||||
<div className="flex items-center justify-between px-2 py-1 mt-8">
|
||||
<div className="flex items-center gap-4">
|
||||
<div className="flex items-center justify-between px-2 py-2 mt-3">
|
||||
<div className="flex items-center space-x-3">
|
||||
{/* Connector Selection Dialog */}
|
||||
<Dialog>
|
||||
<DialogTrigger asChild>
|
||||
<ConnectorButton
|
||||
selectedConnectors={selectedConnectors}
|
||||
onClick={() => { }}
|
||||
/>
|
||||
<div className="h-8">
|
||||
<ConnectorButton
|
||||
selectedConnectors={selectedConnectors}
|
||||
onClick={() => { }}
|
||||
/>
|
||||
</div>
|
||||
</DialogTrigger>
|
||||
<DialogContent className="sm:max-w-md">
|
||||
<DialogHeader>
|
||||
|
@ -1070,12 +1054,40 @@ const ChatPage = () => {
|
|||
</DialogContent>
|
||||
</Dialog>
|
||||
|
||||
{/* Search Mode Control */}
|
||||
<div className="flex items-center p-0.5 rounded-md border border-border bg-muted/20 h-8">
|
||||
<button
|
||||
onClick={() => setSearchMode('DOCUMENTS')}
|
||||
className={`flex h-full items-center justify-center gap-1 px-2 rounded text-xs font-medium transition-colors flex-1 whitespace-nowrap overflow-hidden ${
|
||||
searchMode === 'DOCUMENTS'
|
||||
? 'bg-primary text-primary-foreground shadow-sm'
|
||||
: 'text-muted-foreground hover:text-foreground hover:bg-muted/50'
|
||||
}`}
|
||||
>
|
||||
<FileText className="h-3 w-3 flex-shrink-0 mr-1" />
|
||||
<span>Full Document</span>
|
||||
</button>
|
||||
<button
|
||||
onClick={() => setSearchMode('CHUNKS')}
|
||||
className={`flex h-full items-center justify-center gap-1 px-2 rounded text-xs font-medium transition-colors flex-1 whitespace-nowrap overflow-hidden ${
|
||||
searchMode === 'CHUNKS'
|
||||
? 'bg-primary text-primary-foreground shadow-sm'
|
||||
: 'text-muted-foreground hover:text-foreground hover:bg-muted/50'
|
||||
}`}
|
||||
>
|
||||
<Grid3x3 className="h-3 w-3 flex-shrink-0 mr-1" />
|
||||
<span>Document Chunks</span>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Research Mode Segmented Control */}
|
||||
<SegmentedControl<ResearchMode>
|
||||
value={researchMode}
|
||||
onChange={setResearchMode}
|
||||
options={researcherOptions}
|
||||
/>
|
||||
<div className="h-8">
|
||||
<SegmentedControl<ResearchMode>
|
||||
value={researchMode}
|
||||
onChange={setResearchMode}
|
||||
options={researcherOptions}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
|
|
@ -147,7 +147,7 @@ export const ConnectorButton = ({ selectedConnectors, onClick, connectorSources
|
|||
return (
|
||||
<Button
|
||||
variant="outline"
|
||||
className="h-7 px-2 text-xs font-medium rounded-md border-border relative overflow-hidden group scale-90 origin-left"
|
||||
className="h-8 px-2 text-xs font-medium rounded-md border-border relative overflow-hidden group"
|
||||
onClick={onClick}
|
||||
aria-label={selectedCount === 0 ? "Select Connectors" : `${selectedCount} connectors selected`}
|
||||
>
|
||||
|
|
|
@ -15,11 +15,11 @@ type SegmentedControlProps<T extends string> = {
|
|||
*/
|
||||
function SegmentedControl<T extends string>({ value, onChange, options }: SegmentedControlProps<T>) {
|
||||
return (
|
||||
<div className="flex rounded-md border border-border overflow-hidden scale-90 origin-left">
|
||||
<div className="flex h-7 rounded-md border border-border overflow-hidden">
|
||||
{options.map((option) => (
|
||||
<button
|
||||
key={option.value}
|
||||
className={`flex items-center gap-1 px-2 py-1 text-xs transition-colors ${
|
||||
className={`flex h-full items-center gap-1 px-2 text-xs transition-colors ${
|
||||
value === option.value
|
||||
? 'bg-primary text-primary-foreground'
|
||||
: 'hover:bg-muted'
|
||||
|
|
Loading…
Add table
Reference in a new issue