feat: Introduce the RAPTOR Search.

DESKTOP-RTLN3BA\$punk 2025-05-11 23:04:48 -07:00
parent d3540d8cc5
commit fbbb3294f4
11 changed files with 318 additions and 127 deletions

View file

@@ -3,10 +3,16 @@
 from __future__ import annotations
 from dataclasses import dataclass, fields
+from enum import Enum
 from typing import Optional, List, Any
 from langchain_core.runnables import RunnableConfig
+
+class SearchMode(Enum):
+    """Enum defining the type of search mode."""
+    CHUNKS = "CHUNKS"
+    DOCUMENTS = "DOCUMENTS"

 @dataclass(kw_only=True)
 class Configuration:
@@ -18,6 +24,7 @@ class Configuration:
     connectors_to_search: List[str]
     user_id: str
     search_space_id: int
+    search_mode: SearchMode

     @classmethod
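For orientation, a minimal sketch of resolving the new enum from the string a client sends; the helper name and the fallback to CHUNKS for unrecognized values are assumptions (the chat route below only defaults to "CHUNKS" when the field is missing):

from app.agents.researcher.configuration import SearchMode

def resolve_search_mode(search_mode_str: str) -> SearchMode:
    # Member values equal their names ("CHUNKS", "DOCUMENTS"), so a direct
    # Enum lookup works; defaulting to CHUNKS here is an assumption, not commit behavior.
    try:
        return SearchMode(search_mode_str)
    except ValueError:
        return SearchMode.CHUNKS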

View file

@@ -10,7 +10,7 @@ from langchain_core.runnables import RunnableConfig
 from pydantic import BaseModel, Field
 from sqlalchemy.ext.asyncio import AsyncSession
-from .configuration import Configuration
+from .configuration import Configuration, SearchMode
 from .prompts import get_answer_outline_system_prompt
 from .state import State
 from .sub_section_writer.graph import graph as sub_section_writer_graph
@@ -149,7 +149,8 @@ async def fetch_relevant_documents(
     writer: StreamWriter = None,
     state: State = None,
     top_k: int = 10,
-    connector_service: ConnectorService = None
+    connector_service: ConnectorService = None,
+    search_mode: SearchMode = SearchMode.CHUNKS
 ) -> List[Dict[str, Any]]:
     """
     Fetch relevant documents for research questions using the provided connectors.
@@ -213,7 +214,8 @@ async def fetch_relevant_documents(
             user_query=reformulated_query,
             user_id=user_id,
             search_space_id=search_space_id,
-            top_k=top_k
+            top_k=top_k,
+            search_mode=search_mode
         )

         # Add to sources and raw documents
@@ -231,7 +233,8 @@ async def fetch_relevant_documents(
             user_query=reformulated_query,
             user_id=user_id,
             search_space_id=search_space_id,
-            top_k=top_k
+            top_k=top_k,
+            search_mode=search_mode
         )

         # Add to sources and raw documents
@@ -249,7 +252,8 @@ async def fetch_relevant_documents(
             user_query=reformulated_query,
             user_id=user_id,
             search_space_id=search_space_id,
-            top_k=top_k
+            top_k=top_k,
+            search_mode=search_mode
         )

         # Add to sources and raw documents
@@ -267,7 +271,8 @@ async def fetch_relevant_documents(
             user_query=reformulated_query,
             user_id=user_id,
             search_space_id=search_space_id,
-            top_k=top_k
+            top_k=top_k,
+            search_mode=search_mode
         )

         # Add to sources and raw documents
@@ -286,7 +291,8 @@ async def fetch_relevant_documents(
             user_query=reformulated_query,
             user_id=user_id,
             search_space_id=search_space_id,
-            top_k=top_k
+            top_k=top_k,
+            search_mode=search_mode
         )

         # Add to sources and raw documents
@@ -304,7 +310,8 @@ async def fetch_relevant_documents(
             user_query=reformulated_query,
             user_id=user_id,
             search_space_id=search_space_id,
-            top_k=top_k
+            top_k=top_k,
+            search_mode=search_mode
         )

         # Add to sources and raw documents
@@ -322,7 +329,8 @@ async def fetch_relevant_documents(
             user_query=reformulated_query,
             user_id=user_id,
             search_space_id=search_space_id,
-            top_k=top_k
+            top_k=top_k,
+            search_mode=search_mode
         )

         # Add to sources and raw documents
@@ -340,7 +348,8 @@ async def fetch_relevant_documents(
             user_query=reformulated_query,
             user_id=user_id,
             search_space_id=search_space_id,
-            top_k=top_k
+            top_k=top_k,
+            search_mode=search_mode
         )

         # Add to sources and raw documents
@@ -558,7 +567,8 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamWriter
             writer=writer,
             state=state,
             top_k=TOP_K,
-            connector_service=connector_service
+            connector_service=connector_service,
+            search_mode=configuration.search_mode
         )
     except Exception as e:
         error_message = f"Error fetching relevant documents: {str(e)}"

View file

@@ -141,6 +141,11 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, Any]:
     # Construct a clear, structured query for the LLM
     human_message_content = f"""
+    Source material:
+    <documents>
+    {documents_text}
+    </documents>
+
     Now user's query is:
     <user_query>
     {user_query}
@@ -158,11 +163,6 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, Any]:
     <guiding_questions>
     {questions_text}
     </guiding_questions>
-
-    Use the provided documents as your source material and cite them properly using the IEEE citation format [X] where X is the source_id.
-    <documents>
-    {documents_text}
-    </documents>
     """

     # Create messages for the LLM

View file

@@ -25,6 +25,8 @@ You are a research assistant tasked with analyzing documents and providing compr
 16. CRITICAL: Citations must ONLY appear as [X] or [X], [Y], [Z] format - never with parentheses, hyperlinks, or other formatting.
 17. CRITICAL: Never make up citation numbers. Only use source_id values that are explicitly provided in the document metadata.
 18. CRITICAL: If you are unsure about a source_id, do not include a citation rather than guessing or making one up.
+19. CRITICAL: Focus only on answering the user's query. Any guiding questions provided are for your thinking process only and should not be mentioned in your response.
+20. CRITICAL: Ensure your response aligns with the provided sub-section title and section position.
 </instructions>

 <format>
@@ -37,6 +39,8 @@ You are a research assistant tasked with analyzing documents and providing compr
 - NEVER create your own citation numbering system - use the exact source_id values from the documents.
 - NEVER format citations as clickable links or as markdown links like "([1](https://example.com))". Always use plain square brackets only.
 - NEVER make up citation numbers if you are unsure about the source_id. It is better to omit the citation than to guess.
+- NEVER include or mention the guiding questions in your response. They are only to help guide your thinking.
+- ALWAYS focus on answering the user's query directly from the information in the documents.
 </format>

 <input_example>
@@ -84,4 +88,21 @@ ONLY use plain square brackets [1] or multiple citations [1], [2], [3]
 </incorrect_citation_formats>

 Note that the citation numbers match exactly with the source_id values (1, 13, and 21) and are not renumbered sequentially. Citations follow IEEE style with square brackets and appear at the end of sentences.

+<user_query_instructions>
+When you see a user query like:
+<user_query>
+Give all linear issues.
+</user_query>
+
+Focus exclusively on answering this query using information from the provided documents.
+If guiding questions are provided in a <guiding_questions> section, use them only to guide your thinking process. Do not mention or list these questions in your response.
+
+Make sure your response:
+1. Directly answers the user's query
+2. Fits the provided sub-section title and section position
+3. Uses proper citations for all information from documents
+4. Is well-structured and professional in tone
+</user_query_instructions>
 """

View file

@@ -113,8 +113,6 @@ class DocumentHybridSearchRetriever:
             search_space_id: Optional search space ID to filter results
             document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL")
-        Returns:
-            List of dictionaries containing document data and relevance scores
         """
         from sqlalchemy import select, func, text
         from sqlalchemy.orm import joinedload
@@ -224,10 +222,22 @@ class DocumentHybridSearchRetriever:
         # Convert to serializable dictionaries
         serialized_results = []
         for document, score in documents_with_scores:
+            # Fetch associated chunks for this document
+            from sqlalchemy import select
+            from app.db import Chunk
+            chunks_query = select(Chunk).where(Chunk.document_id == document.id).order_by(Chunk.id)
+            chunks_result = await self.db_session.execute(chunks_query)
+            chunks = chunks_result.scalars().all()
+
+            # Concatenate chunks content
+            concatenated_chunks_content = " ".join([chunk.content for chunk in chunks]) if chunks else document.content
+
             serialized_results.append({
                 "document_id": document.id,
                 "title": document.title,
                 "content": document.content,
+                "chunks_content": concatenated_chunks_content,
                 "document_type": document.document_type.value if hasattr(document, 'document_type') else None,
                 "metadata": document.document_metadata,
                 "score": float(score),  # Ensure score is a Python float

View file

@@ -11,6 +11,8 @@ from sqlalchemy.exc import IntegrityError, OperationalError
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.future import select
 from langchain.schema import HumanMessage, AIMessage

 router = APIRouter()

 @router.post("/chat")
@@ -29,6 +31,8 @@ async def handle_chat_data(
     research_mode: str = request.data.get('research_mode')
     selected_connectors: List[str] = request.data.get('selected_connectors')
+    search_mode_str = request.data.get('search_mode', "CHUNKS")
+
     # Convert search_space_id to integer if it's a string
     if search_space_id and isinstance(search_space_id, str):
         try:
@@ -66,7 +70,8 @@ async def handle_chat_data(
             session,
             research_mode,
             selected_connectors,
-            langchain_chat_history
+            langchain_chat_history,
+            search_mode_str
         ))
     response.headers['x-vercel-ai-data-stream'] = 'v1'
     return response
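For reference, a sketch of the request fields this route now reads; the keys come from this diff and the frontend change below, while the concrete values are only illustrative:

# Illustrative request.data payload for POST /chat (values are examples, not fixtures).
example_request_data = {
    "search_space_id": "1",                  # may arrive as a string; the route coerces it to int
    "selected_connectors": ["CRAWLED_URL"],
    "research_mode": "GENERAL",
    "search_mode": "DOCUMENTS",              # optional; the route defaults to "CHUNKS"
}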

View file

@@ -6,6 +6,8 @@ from app.agents.researcher.state import State
 from app.utils.streaming_service import StreamingService
 from sqlalchemy.ext.asyncio import AsyncSession
+from app.agents.researcher.configuration import SearchMode
+
 async def stream_connector_search_results(
     user_query: str,
@@ -14,7 +16,8 @@ async def stream_connector_search_results(
     session: AsyncSession,
     research_mode: str,
     selected_connectors: List[str],
-    langchain_chat_history: List[Any]
+    langchain_chat_history: List[Any],
+    search_mode_str: str
 ) -> AsyncGenerator[str, None]:
     """
     Stream connector search results to the client
@@ -41,6 +44,11 @@ async def stream_connector_search_results(
     # Convert UUID to string if needed
     user_id_str = str(user_id) if isinstance(user_id, UUID) else user_id

+    if search_mode_str == "CHUNKS":
+        search_mode = SearchMode.CHUNKS
+    elif search_mode_str == "DOCUMENTS":
+        search_mode = SearchMode.DOCUMENTS
+
     # Sample configuration
     config = {
         "configurable": {
@@ -48,7 +56,8 @@ async def stream_connector_search_results(
             "num_sections": NUM_SECTIONS,
             "connectors_to_search": selected_connectors,
             "user_id": user_id_str,
-            "search_space_id": search_space_id
+            "search_space_id": search_space_id,
+            "search_mode": search_mode
         }
     }

     # Initialize state with database session and streaming service

View file

@@ -4,32 +4,47 @@ import asyncio
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.future import select
 from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
+from app.retriver.documents_hybrid_search import DocumentHybridSearchRetriever
 from app.db import SearchSourceConnector, SearchSourceConnectorType
 from tavily import TavilyClient
 from linkup import LinkupClient
+from app.agents.researcher.configuration import SearchMode

 class ConnectorService:
     def __init__(self, session: AsyncSession):
         self.session = session
-        self.retriever = ChucksHybridSearchRetriever(session)
+        self.chunk_retriever = ChucksHybridSearchRetriever(session)
+        self.document_retriever = DocumentHybridSearchRetriever(session)
         self.source_id_counter = 1
         self.counter_lock = asyncio.Lock()  # Lock to protect counter in multithreaded environments

-    async def search_crawled_urls(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20) -> tuple:
+    async def search_crawled_urls(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
         """
         Search for crawled URLs and return both the source information and langchain documents

         Returns:
             tuple: (sources_info, langchain_documents)
         """
-        crawled_urls_chunks = await self.retriever.hybrid_search(
-            query_text=user_query,
-            top_k=top_k,
-            user_id=user_id,
-            search_space_id=search_space_id,
-            document_type="CRAWLED_URL"
-        )
+        if search_mode == SearchMode.CHUNKS:
+            crawled_urls_chunks = await self.chunk_retriever.hybrid_search(
+                query_text=user_query,
+                top_k=top_k,
+                user_id=user_id,
+                search_space_id=search_space_id,
+                document_type="CRAWLED_URL"
+            )
+        elif search_mode == SearchMode.DOCUMENTS:
+            crawled_urls_chunks = await self.document_retriever.hybrid_search(
+                query_text=user_query,
+                top_k=top_k,
+                user_id=user_id,
+                search_space_id=search_space_id,
+                document_type="CRAWLED_URL"
+            )
+            # Transform document retriever results to match expected format
+            crawled_urls_chunks = self._transform_document_results(crawled_urls_chunks)

         # Early return if no results
         if not crawled_urls_chunks:
@@ -71,20 +86,31 @@ class ConnectorService:
         return result_object, crawled_urls_chunks

-    async def search_files(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20) -> tuple:
+    async def search_files(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
         """
         Search for files and return both the source information and langchain documents

         Returns:
             tuple: (sources_info, langchain_documents)
         """
-        files_chunks = await self.retriever.hybrid_search(
-            query_text=user_query,
-            top_k=top_k,
-            user_id=user_id,
-            search_space_id=search_space_id,
-            document_type="FILE"
-        )
+        if search_mode == SearchMode.CHUNKS:
+            files_chunks = await self.chunk_retriever.hybrid_search(
+                query_text=user_query,
+                top_k=top_k,
+                user_id=user_id,
+                search_space_id=search_space_id,
+                document_type="FILE"
+            )
+        elif search_mode == SearchMode.DOCUMENTS:
+            files_chunks = await self.document_retriever.hybrid_search(
+                query_text=user_query,
+                top_k=top_k,
+                user_id=user_id,
+                search_space_id=search_space_id,
+                document_type="FILE"
+            )
+            # Transform document retriever results to match expected format
+            files_chunks = self._transform_document_results(files_chunks)

         # Early return if no results
         if not files_chunks:
@@ -126,6 +152,31 @@ class ConnectorService:
         return result_object, files_chunks

+    def _transform_document_results(self, document_results: List[Dict]) -> List[Dict]:
+        """
+        Transform results from document_retriever.hybrid_search() to match the format
+        expected by the processing code.
+
+        Args:
+            document_results: Results from document_retriever.hybrid_search()
+
+        Returns:
+            List of transformed results in the format expected by the processing code
+        """
+        transformed_results = []
+        for doc in document_results:
+            transformed_results.append({
+                'document': {
+                    'id': doc.get('document_id'),
+                    'title': doc.get('title', 'Untitled Document'),
+                    'document_type': doc.get('document_type'),
+                    'metadata': doc.get('metadata', {}),
+                },
+                'content': doc.get('chunks_content', doc.get('content', '')),
+                'score': doc.get('score', 0.0)
+            })
+        return transformed_results
+
     async def get_connector_by_type(self, user_id: str, connector_type: SearchSourceConnectorType) -> Optional[SearchSourceConnector]:
         """
         Get a connector by type for a specific user
@@ -249,20 +300,31 @@ class ConnectorService:
             "sources": [],
         }, []

-    async def search_slack(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20) -> tuple:
+    async def search_slack(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
         """
         Search for slack and return both the source information and langchain documents

         Returns:
             tuple: (sources_info, langchain_documents)
         """
-        slack_chunks = await self.retriever.hybrid_search(
-            query_text=user_query,
-            top_k=top_k,
-            user_id=user_id,
-            search_space_id=search_space_id,
-            document_type="SLACK_CONNECTOR"
-        )
+        if search_mode == SearchMode.CHUNKS:
+            slack_chunks = await self.chunk_retriever.hybrid_search(
+                query_text=user_query,
+                top_k=top_k,
+                user_id=user_id,
+                search_space_id=search_space_id,
+                document_type="SLACK_CONNECTOR"
+            )
+        elif search_mode == SearchMode.DOCUMENTS:
+            slack_chunks = await self.document_retriever.hybrid_search(
+                query_text=user_query,
+                top_k=top_k,
+                user_id=user_id,
+                search_space_id=search_space_id,
+                document_type="SLACK_CONNECTOR"
+            )
+            # Transform document retriever results to match expected format
+            slack_chunks = self._transform_document_results(slack_chunks)

         # Early return if no results
         if not slack_chunks:
@@ -323,7 +385,7 @@ class ConnectorService:
         return result_object, slack_chunks

-    async def search_notion(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20) -> tuple:
+    async def search_notion(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
         """
         Search for Notion pages and return both the source information and langchain documents
@@ -336,13 +398,24 @@ class ConnectorService:
         Returns:
             tuple: (sources_info, langchain_documents)
         """
-        notion_chunks = await self.retriever.hybrid_search(
-            query_text=user_query,
-            top_k=top_k,
-            user_id=user_id,
-            search_space_id=search_space_id,
-            document_type="NOTION_CONNECTOR"
-        )
+        if search_mode == SearchMode.CHUNKS:
+            notion_chunks = await self.chunk_retriever.hybrid_search(
+                query_text=user_query,
+                top_k=top_k,
+                user_id=user_id,
+                search_space_id=search_space_id,
+                document_type="NOTION_CONNECTOR"
+            )
+        elif search_mode == SearchMode.DOCUMENTS:
+            notion_chunks = await self.document_retriever.hybrid_search(
+                query_text=user_query,
+                top_k=top_k,
+                user_id=user_id,
+                search_space_id=search_space_id,
+                document_type="NOTION_CONNECTOR"
+            )
+            # Transform document retriever results to match expected format
+            notion_chunks = self._transform_document_results(notion_chunks)

         # Early return if no results
         if not notion_chunks:
@@ -405,7 +478,7 @@ class ConnectorService:
         return result_object, notion_chunks

-    async def search_extension(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20) -> tuple:
+    async def search_extension(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
         """
         Search for extension data and return both the source information and langchain documents
@@ -418,13 +491,24 @@ class ConnectorService:
         Returns:
             tuple: (sources_info, langchain_documents)
         """
-        extension_chunks = await self.retriever.hybrid_search(
-            query_text=user_query,
-            top_k=top_k,
-            user_id=user_id,
-            search_space_id=search_space_id,
-            document_type="EXTENSION"
-        )
+        if search_mode == SearchMode.CHUNKS:
+            extension_chunks = await self.chunk_retriever.hybrid_search(
+                query_text=user_query,
+                top_k=top_k,
+                user_id=user_id,
+                search_space_id=search_space_id,
+                document_type="EXTENSION"
+            )
+        elif search_mode == SearchMode.DOCUMENTS:
+            extension_chunks = await self.document_retriever.hybrid_search(
+                query_text=user_query,
+                top_k=top_k,
+                user_id=user_id,
+                search_space_id=search_space_id,
+                document_type="EXTENSION"
+            )
+            # Transform document retriever results to match expected format
+            extension_chunks = self._transform_document_results(extension_chunks)

         # Early return if no results
         if not extension_chunks:
@@ -505,7 +589,7 @@ class ConnectorService:
         return result_object, extension_chunks

-    async def search_youtube(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20) -> tuple:
+    async def search_youtube(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
         """
         Search for YouTube videos and return both the source information and langchain documents
@@ -518,13 +602,24 @@ class ConnectorService:
         Returns:
             tuple: (sources_info, langchain_documents)
         """
-        youtube_chunks = await self.retriever.hybrid_search(
-            query_text=user_query,
-            top_k=top_k,
-            user_id=user_id,
-            search_space_id=search_space_id,
-            document_type="YOUTUBE_VIDEO"
-        )
+        if search_mode == SearchMode.CHUNKS:
+            youtube_chunks = await self.chunk_retriever.hybrid_search(
+                query_text=user_query,
+                top_k=top_k,
+                user_id=user_id,
+                search_space_id=search_space_id,
+                document_type="YOUTUBE_VIDEO"
+            )
+        elif search_mode == SearchMode.DOCUMENTS:
+            youtube_chunks = await self.document_retriever.hybrid_search(
+                query_text=user_query,
+                top_k=top_k,
+                user_id=user_id,
+                search_space_id=search_space_id,
+                document_type="YOUTUBE_VIDEO"
+            )
+            # Transform document retriever results to match expected format
+            youtube_chunks = self._transform_document_results(youtube_chunks)

         # Early return if no results
         if not youtube_chunks:
@@ -587,20 +682,31 @@ class ConnectorService:
         return result_object, youtube_chunks

-    async def search_github(self, user_query: str, user_id: int, search_space_id: int, top_k: int = 20) -> tuple:
+    async def search_github(self, user_query: str, user_id: int, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
         """
         Search for GitHub documents and return both the source information and langchain documents

         Returns:
             tuple: (sources_info, langchain_documents)
         """
-        github_chunks = await self.retriever.hybrid_search(
-            query_text=user_query,
-            top_k=top_k,
-            user_id=user_id,
-            search_space_id=search_space_id,
-            document_type="GITHUB_CONNECTOR"
-        )
+        if search_mode == SearchMode.CHUNKS:
+            github_chunks = await self.chunk_retriever.hybrid_search(
+                query_text=user_query,
+                top_k=top_k,
+                user_id=user_id,
+                search_space_id=search_space_id,
+                document_type="GITHUB_CONNECTOR"
+            )
+        elif search_mode == SearchMode.DOCUMENTS:
+            github_chunks = await self.document_retriever.hybrid_search(
+                query_text=user_query,
+                top_k=top_k,
+                user_id=user_id,
+                search_space_id=search_space_id,
+                document_type="GITHUB_CONNECTOR"
+            )
+            # Transform document retriever results to match expected format
+            github_chunks = self._transform_document_results(github_chunks)

         # Early return if no results
         if not github_chunks:
@@ -643,7 +749,7 @@ class ConnectorService:
         return result_object, github_chunks

-    async def search_linear(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20) -> tuple:
+    async def search_linear(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
         """
         Search for Linear issues and comments and return both the source information and langchain documents
@@ -656,13 +762,24 @@ class ConnectorService:
         Returns:
             tuple: (sources_info, langchain_documents)
         """
-        linear_chunks = await self.retriever.hybrid_search(
-            query_text=user_query,
-            top_k=top_k,
-            user_id=user_id,
-            search_space_id=search_space_id,
-            document_type="LINEAR_CONNECTOR"
-        )
+        if search_mode == SearchMode.CHUNKS:
+            linear_chunks = await self.chunk_retriever.hybrid_search(
+                query_text=user_query,
+                top_k=top_k,
+                user_id=user_id,
+                search_space_id=search_space_id,
+                document_type="LINEAR_CONNECTOR"
+            )
+        elif search_mode == SearchMode.DOCUMENTS:
+            linear_chunks = await self.document_retriever.hybrid_search(
+                query_text=user_query,
+                top_k=top_k,
+                user_id=user_id,
+                search_space_id=search_space_id,
+                document_type="LINEAR_CONNECTOR"
+            )
+            # Transform document retriever results to match expected format
+            linear_chunks = self._transform_document_results(linear_chunks)

         # Early return if no results
         if not linear_chunks:
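To make the new _transform_document_results helper concrete, here is an illustrative before/after for a single result; the dictionary shapes are taken from the diffs in this commit, while the values themselves are invented:

# One result as returned by DocumentHybridSearchRetriever.hybrid_search(...)
document_result = {
    "document_id": 42,
    "title": "Q3 Planning Notes",
    "content": "Full document text...",
    "chunks_content": "First chunk text. Second chunk text.",
    "document_type": "NOTION_CONNECTOR",
    "metadata": {"source": "notion"},     # illustrative metadata
    "score": 0.83,
}

# The same result after ConnectorService._transform_document_results([document_result])[0]
transformed_result = {
    "document": {
        "id": 42,
        "title": "Q3 Planning Notes",
        "document_type": "NOTION_CONNECTOR",
        "metadata": {"source": "notion"},
    },
    "content": "First chunk text. Second chunk text.",  # chunks_content is preferred over content
    "score": 0.83,
}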

View file

@@ -13,7 +13,9 @@ import {
   ArrowDown,
   CircleUser,
   Database,
-  SendHorizontal
+  SendHorizontal,
+  FileText,
+  Grid3x3
 } from 'lucide-react';
 import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
 import { Button } from '@/components/ui/button';
@@ -248,6 +250,7 @@ const ChatPage = () => {
   const tabsListRef = useRef<HTMLDivElement>(null);
   const [terminalExpanded, setTerminalExpanded] = useState(false);
   const [selectedConnectors, setSelectedConnectors] = useState<string[]>(["CRAWLED_URL"]);
+  const [searchMode, setSearchMode] = useState<'DOCUMENTS' | 'CHUNKS'>('DOCUMENTS');
   const [researchMode, setResearchMode] = useState<ResearchMode>("GENERAL");
   const [currentTime, setCurrentTime] = useState<string>('');
   const [currentDate, setCurrentDate] = useState<string>('');
@@ -362,7 +365,8 @@ const ChatPage = () => {
       data: {
         search_space_id: search_space_id,
         selected_connectors: selectedConnectors,
-        research_mode: researchMode
+        research_mode: researchMode,
+        search_mode: searchMode
       }
     },
     onError: (error) => {
@@ -557,11 +561,6 @@ const ChatPage = () => {
     }
   }, [terminalExpanded]);

-  // Get total sources count for a connector type
-  const getSourcesCount = (connectorType: string) => {
-    return getSourcesCountUtil(getMessageConnectorSources(messages[messages.length - 1]), connectorType);
-  };

   // Function to check scroll position and update indicators
   const updateScrollIndicators = () => {
     updateScrollIndicatorsUtil(tabsListRef as React.RefObject<HTMLDivElement>, setCanScrollLeft, setCanScrollRight);
@@ -587,23 +586,6 @@ const ChatPage = () => {
   // Use the scroll to bottom hook
   useScrollToBottom(messagesEndRef as React.RefObject<HTMLDivElement>, [messages]);

-  // Function to get sources for the main view
-  const getMainViewSources = (connector: any) => {
-    return getMainViewSourcesUtil(connector, INITIAL_SOURCES_DISPLAY);
-  };
-
-  // Function to get filtered sources for the dialog with null check
-  const getFilteredSourcesWithCheck = (connector: any, sourceFilter: string) => {
-    if (!connector?.sources) return [];
-    return getFilteredSourcesUtil(connector, sourceFilter);
-  };
-
-  // Function to get paginated dialog sources with null check
-  const getPaginatedDialogSourcesWithCheck = (connector: any, sourceFilter: string, expandedSources: boolean, sourcesPage: number, sourcesPerPage: number) => {
-    if (!connector?.sources) return [];
-    return getPaginatedDialogSourcesUtil(connector, sourceFilter, expandedSources, sourcesPage, sourcesPerPage);
-  };

   // Function to get a citation source by ID
   const getCitationSource = React.useCallback((citationId: number, messageIndex?: number): Source | null => {
     if (!messages || messages.length === 0) return null;
@@ -995,15 +977,17 @@ const ChatPage = () => {
             <span className="sr-only">Send</span>
           </Button>
         </form>
-        <div className="flex items-center justify-between px-2 py-1 mt-8">
-          <div className="flex items-center gap-4">
+        <div className="flex items-center justify-between px-2 py-2 mt-3">
+          <div className="flex items-center space-x-3">
             {/* Connector Selection Dialog */}
             <Dialog>
               <DialogTrigger asChild>
-                <ConnectorButton
-                  selectedConnectors={selectedConnectors}
-                  onClick={() => { }}
-                />
+                <div className="h-8">
+                  <ConnectorButton
+                    selectedConnectors={selectedConnectors}
+                    onClick={() => { }}
+                  />
+                </div>
               </DialogTrigger>
               <DialogContent className="sm:max-w-md">
                 <DialogHeader>
@@ -1070,12 +1054,40 @@ const ChatPage = () => {
               </DialogContent>
             </Dialog>

+            {/* Search Mode Control */}
+            <div className="flex items-center p-0.5 rounded-md border border-border bg-muted/20 h-8">
+              <button
+                onClick={() => setSearchMode('DOCUMENTS')}
+                className={`flex h-full items-center justify-center gap-1 px-2 rounded text-xs font-medium transition-colors flex-1 whitespace-nowrap overflow-hidden ${
+                  searchMode === 'DOCUMENTS'
+                    ? 'bg-primary text-primary-foreground shadow-sm'
+                    : 'text-muted-foreground hover:text-foreground hover:bg-muted/50'
+                }`}
+              >
+                <FileText className="h-3 w-3 flex-shrink-0 mr-1" />
+                <span>Full Document</span>
+              </button>
+              <button
+                onClick={() => setSearchMode('CHUNKS')}
+                className={`flex h-full items-center justify-center gap-1 px-2 rounded text-xs font-medium transition-colors flex-1 whitespace-nowrap overflow-hidden ${
+                  searchMode === 'CHUNKS'
+                    ? 'bg-primary text-primary-foreground shadow-sm'
+                    : 'text-muted-foreground hover:text-foreground hover:bg-muted/50'
+                }`}
+              >
+                <Grid3x3 className="h-3 w-3 flex-shrink-0 mr-1" />
+                <span>Document Chunks</span>
+              </button>
+            </div>
+
             {/* Research Mode Segmented Control */}
-            <SegmentedControl<ResearchMode>
-              value={researchMode}
-              onChange={setResearchMode}
-              options={researcherOptions}
-            />
+            <div className="h-8">
+              <SegmentedControl<ResearchMode>
+                value={researchMode}
+                onChange={setResearchMode}
+                options={researcherOptions}
+              />
+            </div>
           </div>
         </div>
       </div>

View file

@@ -147,7 +147,7 @@ export const ConnectorButton = ({ selectedConnectors, onClick, connectorSources
   return (
     <Button
       variant="outline"
-      className="h-7 px-2 text-xs font-medium rounded-md border-border relative overflow-hidden group scale-90 origin-left"
+      className="h-8 px-2 text-xs font-medium rounded-md border-border relative overflow-hidden group"
       onClick={onClick}
       aria-label={selectedCount === 0 ? "Select Connectors" : `${selectedCount} connectors selected`}
     >

View file

@@ -15,11 +15,11 @@ type SegmentedControlProps<T extends string> = {
  */
 function SegmentedControl<T extends string>({ value, onChange, options }: SegmentedControlProps<T>) {
   return (
-    <div className="flex rounded-md border border-border overflow-hidden scale-90 origin-left">
+    <div className="flex h-7 rounded-md border border-border overflow-hidden">
       {options.map((option) => (
         <button
           key={option.value}
-          className={`flex items-center gap-1 px-2 py-1 text-xs transition-colors ${
+          className={`flex h-full items-center gap-1 px-2 text-xs transition-colors ${
             value === option.value
               ? 'bg-primary text-primary-foreground'
               : 'hover:bg-muted'