Merge pull request #90 from MODSetter/dev

feat: Introduce RAPTOR Search Mode
This commit is contained in:
Rohan Verma 2025-05-11 23:15:39 -07:00 committed by GitHub
commit 29f4d90a0a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 318 additions and 127 deletions

View file

@ -3,10 +3,16 @@
from __future__ import annotations
from dataclasses import dataclass, fields
from enum import Enum
from typing import Optional, List, Any
from langchain_core.runnables import RunnableConfig
class SearchMode(Enum):
    """Granularity used when retrieving source material.

    CHUNKS  -- search over individual document chunks.
    DOCUMENTS -- search at whole-document granularity (the retriever
        returns each document with its chunks' content concatenated).
    """
    CHUNKS = "CHUNKS"
    DOCUMENTS = "DOCUMENTS"
@dataclass(kw_only=True)
class Configuration:
@ -18,6 +24,7 @@ class Configuration:
connectors_to_search: List[str]
user_id: str
search_space_id: int
search_mode: SearchMode
@classmethod

View file

@ -10,7 +10,7 @@ from langchain_core.runnables import RunnableConfig
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from .configuration import Configuration
from .configuration import Configuration, SearchMode
from .prompts import get_answer_outline_system_prompt
from .state import State
from .sub_section_writer.graph import graph as sub_section_writer_graph
@ -149,7 +149,8 @@ async def fetch_relevant_documents(
writer: StreamWriter = None,
state: State = None,
top_k: int = 10,
connector_service: ConnectorService = None
connector_service: ConnectorService = None,
search_mode: SearchMode = SearchMode.CHUNKS
) -> List[Dict[str, Any]]:
"""
Fetch relevant documents for research questions using the provided connectors.
@ -213,7 +214,8 @@ async def fetch_relevant_documents(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=top_k
top_k=top_k,
search_mode=search_mode
)
# Add to sources and raw documents
@ -231,7 +233,8 @@ async def fetch_relevant_documents(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=top_k
top_k=top_k,
search_mode=search_mode
)
# Add to sources and raw documents
@ -249,7 +252,8 @@ async def fetch_relevant_documents(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=top_k
top_k=top_k,
search_mode=search_mode
)
# Add to sources and raw documents
@ -267,7 +271,8 @@ async def fetch_relevant_documents(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=top_k
top_k=top_k,
search_mode=search_mode
)
# Add to sources and raw documents
@ -286,7 +291,8 @@ async def fetch_relevant_documents(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=top_k
top_k=top_k,
search_mode=search_mode
)
# Add to sources and raw documents
@ -304,7 +310,8 @@ async def fetch_relevant_documents(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=top_k
top_k=top_k,
search_mode=search_mode
)
# Add to sources and raw documents
@ -322,7 +329,8 @@ async def fetch_relevant_documents(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=top_k
top_k=top_k,
search_mode=search_mode
)
# Add to sources and raw documents
@ -340,7 +348,8 @@ async def fetch_relevant_documents(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=top_k
top_k=top_k,
search_mode=search_mode
)
# Add to sources and raw documents
@ -558,7 +567,8 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW
writer=writer,
state=state,
top_k=TOP_K,
connector_service=connector_service
connector_service=connector_service,
search_mode=configuration.search_mode
)
except Exception as e:
error_message = f"Error fetching relevant documents: {str(e)}"

View file

@ -141,6 +141,11 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A
# Construct a clear, structured query for the LLM
human_message_content = f"""
Source material:
<documents>
{documents_text}
</documents>
Now user's query is:
<user_query>
{user_query}
@ -158,11 +163,6 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A
<guiding_questions>
{questions_text}
</guiding_questions>
Use the provided documents as your source material and cite them properly using the IEEE citation format [X] where X is the source_id.
<documents>
{documents_text}
</documents>
"""
# Create messages for the LLM

View file

@ -25,6 +25,8 @@ You are a research assistant tasked with analyzing documents and providing compr
16. CRITICAL: Citations must ONLY appear as [X] or [X], [Y], [Z] format - never with parentheses, hyperlinks, or other formatting.
17. CRITICAL: Never make up citation numbers. Only use source_id values that are explicitly provided in the document metadata.
18. CRITICAL: If you are unsure about a source_id, do not include a citation rather than guessing or making one up.
19. CRITICAL: Focus only on answering the user's query. Any guiding questions provided are for your thinking process only and should not be mentioned in your response.
20. CRITICAL: Ensure your response aligns with the provided sub-section title and section position.
</instructions>
<format>
@ -37,6 +39,8 @@ You are a research assistant tasked with analyzing documents and providing compr
- NEVER create your own citation numbering system - use the exact source_id values from the documents.
- NEVER format citations as clickable links or as markdown links like "([1](https://example.com))". Always use plain square brackets only.
- NEVER make up citation numbers if you are unsure about the source_id. It is better to omit the citation than to guess.
- NEVER include or mention the guiding questions in your response. They are only to help guide your thinking.
- ALWAYS focus on answering the user's query directly from the information in the documents.
</format>
<input_example>
@ -84,4 +88,21 @@ ONLY use plain square brackets [1] or multiple citations [1], [2], [3]
</incorrect_citation_formats>
Note that the citation numbers match exactly with the source_id values (1, 13, and 21) and are not renumbered sequentially. Citations follow IEEE style with square brackets and appear at the end of sentences.
<user_query_instructions>
When you see a user query like:
<user_query>
Give all linear issues.
</user_query>
Focus exclusively on answering this query using information from the provided documents.
If guiding questions are provided in a <guiding_questions> section, use them only to guide your thinking process. Do not mention or list these questions in your response.
Make sure your response:
1. Directly answers the user's query
2. Fits the provided sub-section title and section position
3. Uses proper citations for all information from documents
4. Is well-structured and professional in tone
</user_query_instructions>
"""

View file

@ -113,8 +113,6 @@ class DocumentHybridSearchRetriever:
search_space_id: Optional search space ID to filter results
document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL")
Returns:
List of dictionaries containing document data and relevance scores
"""
from sqlalchemy import select, func, text
from sqlalchemy.orm import joinedload
@ -224,10 +222,22 @@ class DocumentHybridSearchRetriever:
# Convert to serializable dictionaries
serialized_results = []
for document, score in documents_with_scores:
# Fetch associated chunks for this document
from sqlalchemy import select
from app.db import Chunk
chunks_query = select(Chunk).where(Chunk.document_id == document.id).order_by(Chunk.id)
chunks_result = await self.db_session.execute(chunks_query)
chunks = chunks_result.scalars().all()
# Concatenate chunks content
concatenated_chunks_content = " ".join([chunk.content for chunk in chunks]) if chunks else document.content
serialized_results.append({
"document_id": document.id,
"title": document.title,
"content": document.content,
"chunks_content": concatenated_chunks_content,
"document_type": document.document_type.value if hasattr(document, 'document_type') else None,
"metadata": document.document_metadata,
"score": float(score), # Ensure score is a Python float

View file

@ -11,6 +11,8 @@ from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from langchain.schema import HumanMessage, AIMessage
router = APIRouter()
@router.post("/chat")
@ -28,6 +30,8 @@ async def handle_chat_data(
search_space_id = request.data.get('search_space_id')
research_mode: str = request.data.get('research_mode')
selected_connectors: List[str] = request.data.get('selected_connectors')
search_mode_str = request.data.get('search_mode', "CHUNKS")
# Convert search_space_id to integer if it's a string
if search_space_id and isinstance(search_space_id, str):
@ -66,7 +70,8 @@ async def handle_chat_data(
session,
research_mode,
selected_connectors,
langchain_chat_history
langchain_chat_history,
search_mode_str
))
response.headers['x-vercel-ai-data-stream'] = 'v1'
return response

View file

@ -6,6 +6,8 @@ from app.agents.researcher.state import State
from app.utils.streaming_service import StreamingService
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.researcher.configuration import SearchMode
async def stream_connector_search_results(
user_query: str,
@ -14,7 +16,8 @@ async def stream_connector_search_results(
session: AsyncSession,
research_mode: str,
selected_connectors: List[str],
langchain_chat_history: List[Any]
langchain_chat_history: List[Any],
search_mode_str: str
) -> AsyncGenerator[str, None]:
"""
Stream connector search results to the client
@ -41,6 +44,11 @@ async def stream_connector_search_results(
# Convert UUID to string if needed
user_id_str = str(user_id) if isinstance(user_id, UUID) else user_id
if search_mode_str == "CHUNKS":
search_mode = SearchMode.CHUNKS
elif search_mode_str == "DOCUMENTS":
search_mode = SearchMode.DOCUMENTS
# Sample configuration
config = {
"configurable": {
@ -48,7 +56,8 @@ async def stream_connector_search_results(
"num_sections": NUM_SECTIONS,
"connectors_to_search": selected_connectors,
"user_id": user_id_str,
"search_space_id": search_space_id
"search_space_id": search_space_id,
"search_mode": search_mode
}
}
# Initialize state with database session and streaming service

View file

@ -4,32 +4,47 @@ import asyncio
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
from app.retriver.documents_hybrid_search import DocumentHybridSearchRetriever
from app.db import SearchSourceConnector, SearchSourceConnectorType
from tavily import TavilyClient
from linkup import LinkupClient
from app.agents.researcher.configuration import SearchMode
class ConnectorService:
def __init__(self, session: AsyncSession):
self.session = session
self.retriever = ChucksHybridSearchRetriever(session)
self.chunk_retriever = ChucksHybridSearchRetriever(session)
self.document_retriever = DocumentHybridSearchRetriever(session)
self.source_id_counter = 1
self.counter_lock = asyncio.Lock() # Lock to protect counter in multithreaded environments
async def search_crawled_urls(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20) -> tuple:
async def search_crawled_urls(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
"""
Search for crawled URLs and return both the source information and langchain documents
Returns:
tuple: (sources_info, langchain_documents)
"""
crawled_urls_chunks = await self.retriever.hybrid_search(
query_text=user_query,
top_k=top_k,
user_id=user_id,
search_space_id=search_space_id,
document_type="CRAWLED_URL"
)
if search_mode == SearchMode.CHUNKS:
crawled_urls_chunks = await self.chunk_retriever.hybrid_search(
query_text=user_query,
top_k=top_k,
user_id=user_id,
search_space_id=search_space_id,
document_type="CRAWLED_URL"
)
elif search_mode == SearchMode.DOCUMENTS:
crawled_urls_chunks = await self.document_retriever.hybrid_search(
query_text=user_query,
top_k=top_k,
user_id=user_id,
search_space_id=search_space_id,
document_type="CRAWLED_URL"
)
# Transform document retriever results to match expected format
crawled_urls_chunks = self._transform_document_results(crawled_urls_chunks)
# Early return if no results
if not crawled_urls_chunks:
@ -71,20 +86,31 @@ class ConnectorService:
return result_object, crawled_urls_chunks
async def search_files(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20) -> tuple:
async def search_files(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
"""
Search for files and return both the source information and langchain documents
Returns:
tuple: (sources_info, langchain_documents)
"""
files_chunks = await self.retriever.hybrid_search(
query_text=user_query,
top_k=top_k,
user_id=user_id,
search_space_id=search_space_id,
document_type="FILE"
)
if search_mode == SearchMode.CHUNKS:
files_chunks = await self.chunk_retriever.hybrid_search(
query_text=user_query,
top_k=top_k,
user_id=user_id,
search_space_id=search_space_id,
document_type="FILE"
)
elif search_mode == SearchMode.DOCUMENTS:
files_chunks = await self.document_retriever.hybrid_search(
query_text=user_query,
top_k=top_k,
user_id=user_id,
search_space_id=search_space_id,
document_type="FILE"
)
# Transform document retriever results to match expected format
files_chunks = self._transform_document_results(files_chunks)
# Early return if no results
if not files_chunks:
@ -126,6 +152,31 @@ class ConnectorService:
return result_object, files_chunks
def _transform_document_results(self, document_results: List[Dict]) -> List[Dict]:
"""
Transform results from document_retriever.hybrid_search() to match the format
expected by the processing code.
Args:
document_results: Results from document_retriever.hybrid_search()
Returns:
List of transformed results in the format expected by the processing code
"""
transformed_results = []
for doc in document_results:
transformed_results.append({
'document': {
'id': doc.get('document_id'),
'title': doc.get('title', 'Untitled Document'),
'document_type': doc.get('document_type'),
'metadata': doc.get('metadata', {}),
},
'content': doc.get('chunks_content', doc.get('content', '')),
'score': doc.get('score', 0.0)
})
return transformed_results
async def get_connector_by_type(self, user_id: str, connector_type: SearchSourceConnectorType) -> Optional[SearchSourceConnector]:
"""
Get a connector by type for a specific user
@ -249,20 +300,31 @@ class ConnectorService:
"sources": [],
}, []
async def search_slack(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20) -> tuple:
async def search_slack(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
"""
Search for slack and return both the source information and langchain documents
Returns:
tuple: (sources_info, langchain_documents)
"""
slack_chunks = await self.retriever.hybrid_search(
query_text=user_query,
top_k=top_k,
user_id=user_id,
search_space_id=search_space_id,
document_type="SLACK_CONNECTOR"
)
if search_mode == SearchMode.CHUNKS:
slack_chunks = await self.chunk_retriever.hybrid_search(
query_text=user_query,
top_k=top_k,
user_id=user_id,
search_space_id=search_space_id,
document_type="SLACK_CONNECTOR"
)
elif search_mode == SearchMode.DOCUMENTS:
slack_chunks = await self.document_retriever.hybrid_search(
query_text=user_query,
top_k=top_k,
user_id=user_id,
search_space_id=search_space_id,
document_type="SLACK_CONNECTOR"
)
# Transform document retriever results to match expected format
slack_chunks = self._transform_document_results(slack_chunks)
# Early return if no results
if not slack_chunks:
@ -323,7 +385,7 @@ class ConnectorService:
return result_object, slack_chunks
async def search_notion(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20) -> tuple:
async def search_notion(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
"""
Search for Notion pages and return both the source information and langchain documents
@ -336,14 +398,25 @@ class ConnectorService:
Returns:
tuple: (sources_info, langchain_documents)
"""
notion_chunks = await self.retriever.hybrid_search(
query_text=user_query,
top_k=top_k,
user_id=user_id,
search_space_id=search_space_id,
document_type="NOTION_CONNECTOR"
)
if search_mode == SearchMode.CHUNKS:
notion_chunks = await self.chunk_retriever.hybrid_search(
query_text=user_query,
top_k=top_k,
user_id=user_id,
search_space_id=search_space_id,
document_type="NOTION_CONNECTOR"
)
elif search_mode == SearchMode.DOCUMENTS:
notion_chunks = await self.document_retriever.hybrid_search(
query_text=user_query,
top_k=top_k,
user_id=user_id,
search_space_id=search_space_id,
document_type="NOTION_CONNECTOR"
)
# Transform document retriever results to match expected format
notion_chunks = self._transform_document_results(notion_chunks)
# Early return if no results
if not notion_chunks:
return {
@ -405,7 +478,7 @@ class ConnectorService:
return result_object, notion_chunks
async def search_extension(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20) -> tuple:
async def search_extension(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
"""
Search for extension data and return both the source information and langchain documents
@ -418,14 +491,25 @@ class ConnectorService:
Returns:
tuple: (sources_info, langchain_documents)
"""
extension_chunks = await self.retriever.hybrid_search(
query_text=user_query,
top_k=top_k,
user_id=user_id,
search_space_id=search_space_id,
document_type="EXTENSION"
)
if search_mode == SearchMode.CHUNKS:
extension_chunks = await self.chunk_retriever.hybrid_search(
query_text=user_query,
top_k=top_k,
user_id=user_id,
search_space_id=search_space_id,
document_type="EXTENSION"
)
elif search_mode == SearchMode.DOCUMENTS:
extension_chunks = await self.document_retriever.hybrid_search(
query_text=user_query,
top_k=top_k,
user_id=user_id,
search_space_id=search_space_id,
document_type="EXTENSION"
)
# Transform document retriever results to match expected format
extension_chunks = self._transform_document_results(extension_chunks)
# Early return if no results
if not extension_chunks:
return {
@ -505,7 +589,7 @@ class ConnectorService:
return result_object, extension_chunks
async def search_youtube(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20) -> tuple:
async def search_youtube(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
"""
Search for YouTube videos and return both the source information and langchain documents
@ -518,13 +602,24 @@ class ConnectorService:
Returns:
tuple: (sources_info, langchain_documents)
"""
youtube_chunks = await self.retriever.hybrid_search(
query_text=user_query,
top_k=top_k,
user_id=user_id,
search_space_id=search_space_id,
document_type="YOUTUBE_VIDEO"
)
if search_mode == SearchMode.CHUNKS:
youtube_chunks = await self.chunk_retriever.hybrid_search(
query_text=user_query,
top_k=top_k,
user_id=user_id,
search_space_id=search_space_id,
document_type="YOUTUBE_VIDEO"
)
elif search_mode == SearchMode.DOCUMENTS:
youtube_chunks = await self.document_retriever.hybrid_search(
query_text=user_query,
top_k=top_k,
user_id=user_id,
search_space_id=search_space_id,
document_type="YOUTUBE_VIDEO"
)
# Transform document retriever results to match expected format
youtube_chunks = self._transform_document_results(youtube_chunks)
# Early return if no results
if not youtube_chunks:
@ -587,20 +682,31 @@ class ConnectorService:
return result_object, youtube_chunks
async def search_github(self, user_query: str, user_id: int, search_space_id: int, top_k: int = 20) -> tuple:
async def search_github(self, user_query: str, user_id: int, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
"""
Search for GitHub documents and return both the source information and langchain documents
Returns:
tuple: (sources_info, langchain_documents)
"""
github_chunks = await self.retriever.hybrid_search(
query_text=user_query,
top_k=top_k,
user_id=user_id,
search_space_id=search_space_id,
document_type="GITHUB_CONNECTOR"
)
if search_mode == SearchMode.CHUNKS:
github_chunks = await self.chunk_retriever.hybrid_search(
query_text=user_query,
top_k=top_k,
user_id=user_id,
search_space_id=search_space_id,
document_type="GITHUB_CONNECTOR"
)
elif search_mode == SearchMode.DOCUMENTS:
github_chunks = await self.document_retriever.hybrid_search(
query_text=user_query,
top_k=top_k,
user_id=user_id,
search_space_id=search_space_id,
document_type="GITHUB_CONNECTOR"
)
# Transform document retriever results to match expected format
github_chunks = self._transform_document_results(github_chunks)
# Early return if no results
if not github_chunks:
@ -643,7 +749,7 @@ class ConnectorService:
return result_object, github_chunks
async def search_linear(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20) -> tuple:
async def search_linear(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
"""
Search for Linear issues and comments and return both the source information and langchain documents
@ -656,14 +762,25 @@ class ConnectorService:
Returns:
tuple: (sources_info, langchain_documents)
"""
linear_chunks = await self.retriever.hybrid_search(
query_text=user_query,
top_k=top_k,
user_id=user_id,
search_space_id=search_space_id,
document_type="LINEAR_CONNECTOR"
)
if search_mode == SearchMode.CHUNKS:
linear_chunks = await self.chunk_retriever.hybrid_search(
query_text=user_query,
top_k=top_k,
user_id=user_id,
search_space_id=search_space_id,
document_type="LINEAR_CONNECTOR"
)
elif search_mode == SearchMode.DOCUMENTS:
linear_chunks = await self.document_retriever.hybrid_search(
query_text=user_query,
top_k=top_k,
user_id=user_id,
search_space_id=search_space_id,
document_type="LINEAR_CONNECTOR"
)
# Transform document retriever results to match expected format
linear_chunks = self._transform_document_results(linear_chunks)
# Early return if no results
if not linear_chunks:
return {

View file

@ -13,7 +13,9 @@ import {
ArrowDown,
CircleUser,
Database,
SendHorizontal
SendHorizontal,
FileText,
Grid3x3
} from 'lucide-react';
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
import { Button } from '@/components/ui/button';
@ -248,6 +250,7 @@ const ChatPage = () => {
const tabsListRef = useRef<HTMLDivElement>(null);
const [terminalExpanded, setTerminalExpanded] = useState(false);
const [selectedConnectors, setSelectedConnectors] = useState<string[]>(["CRAWLED_URL"]);
const [searchMode, setSearchMode] = useState<'DOCUMENTS' | 'CHUNKS'>('DOCUMENTS');
const [researchMode, setResearchMode] = useState<ResearchMode>("GENERAL");
const [currentTime, setCurrentTime] = useState<string>('');
const [currentDate, setCurrentDate] = useState<string>('');
@ -362,7 +365,8 @@ const ChatPage = () => {
data: {
search_space_id: search_space_id,
selected_connectors: selectedConnectors,
research_mode: researchMode
research_mode: researchMode,
search_mode: searchMode
}
},
onError: (error) => {
@ -557,11 +561,6 @@ const ChatPage = () => {
}
}, [terminalExpanded]);
// Get total sources count for a connector type
const getSourcesCount = (connectorType: string) => {
return getSourcesCountUtil(getMessageConnectorSources(messages[messages.length - 1]), connectorType);
};
// Function to check scroll position and update indicators
const updateScrollIndicators = () => {
updateScrollIndicatorsUtil(tabsListRef as React.RefObject<HTMLDivElement>, setCanScrollLeft, setCanScrollRight);
@ -587,23 +586,6 @@ const ChatPage = () => {
// Use the scroll to bottom hook
useScrollToBottom(messagesEndRef as React.RefObject<HTMLDivElement>, [messages]);
// Function to get sources for the main view
const getMainViewSources = (connector: any) => {
return getMainViewSourcesUtil(connector, INITIAL_SOURCES_DISPLAY);
};
// Function to get filtered sources for the dialog with null check
const getFilteredSourcesWithCheck = (connector: any, sourceFilter: string) => {
if (!connector?.sources) return [];
return getFilteredSourcesUtil(connector, sourceFilter);
};
// Function to get paginated dialog sources with null check
const getPaginatedDialogSourcesWithCheck = (connector: any, sourceFilter: string, expandedSources: boolean, sourcesPage: number, sourcesPerPage: number) => {
if (!connector?.sources) return [];
return getPaginatedDialogSourcesUtil(connector, sourceFilter, expandedSources, sourcesPage, sourcesPerPage);
};
// Function to get a citation source by ID
const getCitationSource = React.useCallback((citationId: number, messageIndex?: number): Source | null => {
if (!messages || messages.length === 0) return null;
@ -995,15 +977,17 @@ const ChatPage = () => {
<span className="sr-only">Send</span>
</Button>
</form>
<div className="flex items-center justify-between px-2 py-1 mt-8">
<div className="flex items-center gap-4">
<div className="flex items-center justify-between px-2 py-2 mt-3">
<div className="flex items-center space-x-3">
{/* Connector Selection Dialog */}
<Dialog>
<DialogTrigger asChild>
<ConnectorButton
selectedConnectors={selectedConnectors}
onClick={() => { }}
/>
<div className="h-8">
<ConnectorButton
selectedConnectors={selectedConnectors}
onClick={() => { }}
/>
</div>
</DialogTrigger>
<DialogContent className="sm:max-w-md">
<DialogHeader>
@ -1070,12 +1054,40 @@ const ChatPage = () => {
</DialogContent>
</Dialog>
{/* Search Mode Control */}
<div className="flex items-center p-0.5 rounded-md border border-border bg-muted/20 h-8">
<button
onClick={() => setSearchMode('DOCUMENTS')}
className={`flex h-full items-center justify-center gap-1 px-2 rounded text-xs font-medium transition-colors flex-1 whitespace-nowrap overflow-hidden ${
searchMode === 'DOCUMENTS'
? 'bg-primary text-primary-foreground shadow-sm'
: 'text-muted-foreground hover:text-foreground hover:bg-muted/50'
}`}
>
<FileText className="h-3 w-3 flex-shrink-0 mr-1" />
<span>Full Document</span>
</button>
<button
onClick={() => setSearchMode('CHUNKS')}
className={`flex h-full items-center justify-center gap-1 px-2 rounded text-xs font-medium transition-colors flex-1 whitespace-nowrap overflow-hidden ${
searchMode === 'CHUNKS'
? 'bg-primary text-primary-foreground shadow-sm'
: 'text-muted-foreground hover:text-foreground hover:bg-muted/50'
}`}
>
<Grid3x3 className="h-3 w-3 flex-shrink-0 mr-1" />
<span>Document Chunks</span>
</button>
</div>
{/* Research Mode Segmented Control */}
<SegmentedControl<ResearchMode>
value={researchMode}
onChange={setResearchMode}
options={researcherOptions}
/>
<div className="h-8">
<SegmentedControl<ResearchMode>
value={researchMode}
onChange={setResearchMode}
options={researcherOptions}
/>
</div>
</div>
</div>
</div>

View file

@ -147,7 +147,7 @@ export const ConnectorButton = ({ selectedConnectors, onClick, connectorSources
return (
<Button
variant="outline"
className="h-7 px-2 text-xs font-medium rounded-md border-border relative overflow-hidden group scale-90 origin-left"
className="h-8 px-2 text-xs font-medium rounded-md border-border relative overflow-hidden group"
onClick={onClick}
aria-label={selectedCount === 0 ? "Select Connectors" : `${selectedCount} connectors selected`}
>

View file

@ -15,11 +15,11 @@ type SegmentedControlProps<T extends string> = {
*/
function SegmentedControl<T extends string>({ value, onChange, options }: SegmentedControlProps<T>) {
return (
<div className="flex rounded-md border border-border overflow-hidden scale-90 origin-left">
<div className="flex h-7 rounded-md border border-border overflow-hidden">
{options.map((option) => (
<button
key={option.value}
className={`flex items-center gap-1 px-2 py-1 text-xs transition-colors ${
className={`flex h-full items-center gap-1 px-2 text-xs transition-colors ${
value === option.value
? 'bg-primary text-primary-foreground'
: 'hover:bg-muted'