mirror of
https://github.com/MODSetter/SurfSense.git
synced 2025-09-01 18:19:08 +00:00
fix: Fix rerank_documents node in sub_section_writer & qna_agent
This commit is contained in:
parent
d005f810f1
commit
f99878c07c
6 changed files with 2708 additions and 1923 deletions
|
@ -2,11 +2,10 @@ import asyncio
|
|||
import json
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from app.db import async_session_maker
|
||||
from app.utils.connector_service import ConnectorService
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from .configuration import Configuration, SearchMode
|
||||
|
@ -15,6 +14,7 @@ from .state import State
|
|||
from .sub_section_writer.graph import graph as sub_section_writer_graph
|
||||
from .sub_section_writer.configuration import SubSectionType
|
||||
from .qna_agent.graph import graph as qna_agent_graph
|
||||
from .utils import AnswerOutline, get_connector_emoji, get_connector_friendly_name
|
||||
|
||||
from app.utils.query_service import QueryService
|
||||
|
||||
|
@ -24,7 +24,6 @@ from langgraph.types import StreamWriter
|
|||
from sqlalchemy.future import select
|
||||
from app.db import Document, SearchSpace
|
||||
|
||||
|
||||
async def fetch_documents_by_ids(
|
||||
document_ids: List[int],
|
||||
user_id: str,
|
||||
|
@ -252,16 +251,6 @@ async def fetch_documents_by_ids(
|
|||
return [], []
|
||||
|
||||
|
||||
class Section(BaseModel):
|
||||
"""A section in the answer outline."""
|
||||
section_id: int = Field(..., description="The zero-based index of the section")
|
||||
section_title: str = Field(..., description="The title of the section")
|
||||
questions: List[str] = Field(..., description="Questions to research for this section")
|
||||
|
||||
class AnswerOutline(BaseModel):
|
||||
"""The complete answer outline with all sections."""
|
||||
answer_outline: List[Section] = Field(..., description="List of sections in the answer outline")
|
||||
|
||||
async def write_answer_outline(state: State, config: RunnableConfig, writer: StreamWriter) -> Dict[str, Any]:
|
||||
"""
|
||||
Create a structured answer outline based on the user query.
|
||||
|
@ -379,6 +368,7 @@ async def write_answer_outline(state: State, config: RunnableConfig, writer: Str
|
|||
print(f"Raw response: {response.content}")
|
||||
raise
|
||||
|
||||
|
||||
async def fetch_relevant_documents(
|
||||
research_questions: List[str],
|
||||
user_id: str,
|
||||
|
@ -746,37 +736,6 @@ async def fetch_relevant_documents(
|
|||
# Return deduplicated documents
|
||||
return deduplicated_docs
|
||||
|
||||
def get_connector_emoji(connector_name: str) -> str:
|
||||
"""Get an appropriate emoji for a connector type."""
|
||||
connector_emojis = {
|
||||
"YOUTUBE_VIDEO": "📹",
|
||||
"EXTENSION": "🧩",
|
||||
"CRAWLED_URL": "🌐",
|
||||
"FILE": "📄",
|
||||
"SLACK_CONNECTOR": "💬",
|
||||
"NOTION_CONNECTOR": "📘",
|
||||
"GITHUB_CONNECTOR": "🐙",
|
||||
"LINEAR_CONNECTOR": "📊",
|
||||
"TAVILY_API": "🔍",
|
||||
"LINKUP_API": "🔗"
|
||||
}
|
||||
return connector_emojis.get(connector_name, "🔎")
|
||||
|
||||
def get_connector_friendly_name(connector_name: str) -> str:
|
||||
"""Convert technical connector IDs to user-friendly names."""
|
||||
connector_friendly_names = {
|
||||
"YOUTUBE_VIDEO": "YouTube",
|
||||
"EXTENSION": "Browser Extension",
|
||||
"CRAWLED_URL": "Web Pages",
|
||||
"FILE": "Files",
|
||||
"SLACK_CONNECTOR": "Slack",
|
||||
"NOTION_CONNECTOR": "Notion",
|
||||
"GITHUB_CONNECTOR": "GitHub",
|
||||
"LINEAR_CONNECTOR": "Linear",
|
||||
"TAVILY_API": "Tavily Search",
|
||||
"LINKUP_API": "Linkup Search"
|
||||
}
|
||||
return connector_friendly_names.get(connector_name, connector_name)
|
||||
|
||||
async def process_sections(state: State, config: RunnableConfig, writer: StreamWriter) -> Dict[str, Any]:
|
||||
"""
|
||||
|
@ -969,6 +928,7 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW
|
|||
"final_written_report": final_written_report
|
||||
}
|
||||
|
||||
|
||||
async def process_section_with_documents(
|
||||
section_id: int,
|
||||
section_title: str,
|
||||
|
@ -1106,7 +1066,6 @@ async def process_section_with_documents(
|
|||
return f"Error processing section: {section_title}. Details: {str(e)}"
|
||||
|
||||
|
||||
|
||||
async def reformulate_user_query(state: State, config: RunnableConfig, writer: StreamWriter) -> Dict[str, Any]:
|
||||
"""
|
||||
Reforms the user query based on the chat history.
|
||||
|
@ -1124,6 +1083,7 @@ async def reformulate_user_query(state: State, config: RunnableConfig, writer: S
|
|||
"reformulated_query": reformulated_query
|
||||
}
|
||||
|
||||
|
||||
async def handle_qna_workflow(state: State, config: RunnableConfig, writer: StreamWriter) -> Dict[str, Any]:
|
||||
"""
|
||||
Handle the QNA research workflow.
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
from app.utils.reranker_service import RerankerService
|
||||
from .configuration import Configuration
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from .state import State
|
||||
from typing import Any, Dict
|
||||
from app.config import config as app_config
|
||||
from .prompts import get_qna_citation_system_prompt, get_qna_no_documents_system_prompt
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from ..utils import (
|
||||
|
@ -35,7 +35,7 @@ async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, An
|
|||
}
|
||||
|
||||
# Get reranker service from app config
|
||||
reranker_service = getattr(app_config, "reranker_service", None)
|
||||
reranker_service = RerankerService.get_reranker_instance()
|
||||
|
||||
# Use documents as is if no reranker service is available
|
||||
reranked_docs = documents
|
||||
|
|
|
@ -2,7 +2,7 @@ from .configuration import Configuration
|
|||
from langchain_core.runnables import RunnableConfig
|
||||
from .state import State
|
||||
from typing import Any, Dict
|
||||
from app.config import config as app_config
|
||||
from app.utils.reranker_service import RerankerService
|
||||
from .prompts import get_citation_system_prompt, get_no_documents_system_prompt
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from .configuration import SubSectionType
|
||||
|
@ -35,7 +35,7 @@ async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, An
|
|||
}
|
||||
|
||||
# Get reranker service from app config
|
||||
reranker_service = getattr(app_config, "reranker_service", None)
|
||||
reranker_service = RerankerService.get_reranker_instance()
|
||||
|
||||
# Use documents as is if no reranker service is available
|
||||
reranked_docs = documents
|
||||
|
|
|
@ -1,7 +1,18 @@
|
|||
from typing import List, Dict, Any, Tuple, NamedTuple
|
||||
from langchain_core.messages import BaseMessage
|
||||
from pydantic import BaseModel, Field
|
||||
from litellm import token_counter, get_model_info
|
||||
|
||||
class Section(BaseModel):
|
||||
"""A section in the answer outline."""
|
||||
section_id: int = Field(..., description="The zero-based index of the section")
|
||||
section_title: str = Field(..., description="The title of the section")
|
||||
questions: List[str] = Field(..., description="Questions to research for this section")
|
||||
|
||||
class AnswerOutline(BaseModel):
|
||||
"""The complete answer outline with all sections."""
|
||||
answer_outline: List[Section] = Field(..., description="List of sections in the answer outline")
|
||||
|
||||
|
||||
class DocumentTokenInfo(NamedTuple):
|
||||
"""Information about a document and its token cost."""
|
||||
|
@ -9,6 +20,40 @@ class DocumentTokenInfo(NamedTuple):
|
|||
document: Dict[str, Any]
|
||||
formatted_content: str
|
||||
token_count: int
|
||||
|
||||
|
||||
def get_connector_emoji(connector_name: str) -> str:
|
||||
"""Get an appropriate emoji for a connector type."""
|
||||
connector_emojis = {
|
||||
"YOUTUBE_VIDEO": "📹",
|
||||
"EXTENSION": "🧩",
|
||||
"CRAWLED_URL": "🌐",
|
||||
"FILE": "📄",
|
||||
"SLACK_CONNECTOR": "💬",
|
||||
"NOTION_CONNECTOR": "📘",
|
||||
"GITHUB_CONNECTOR": "🐙",
|
||||
"LINEAR_CONNECTOR": "📊",
|
||||
"TAVILY_API": "🔍",
|
||||
"LINKUP_API": "🔗"
|
||||
}
|
||||
return connector_emojis.get(connector_name, "🔎")
|
||||
|
||||
|
||||
def get_connector_friendly_name(connector_name: str) -> str:
|
||||
"""Convert technical connector IDs to user-friendly names."""
|
||||
connector_friendly_names = {
|
||||
"YOUTUBE_VIDEO": "YouTube",
|
||||
"EXTENSION": "Browser Extension",
|
||||
"CRAWLED_URL": "Web Pages",
|
||||
"FILE": "Files",
|
||||
"SLACK_CONNECTOR": "Slack",
|
||||
"NOTION_CONNECTOR": "Notion",
|
||||
"GITHUB_CONNECTOR": "GitHub",
|
||||
"LINEAR_CONNECTOR": "Linear",
|
||||
"TAVILY_API": "Tavily Search",
|
||||
"LINKUP_API": "Linkup Search"
|
||||
}
|
||||
return connector_friendly_names.get(connector_name, connector_name)
|
||||
|
||||
|
||||
def convert_langchain_messages_to_dict(messages: List[BaseMessage]) -> List[Dict[str, str]]:
|
||||
|
|
|
@ -80,16 +80,16 @@ class RerankerService:
|
|||
return documents
|
||||
|
||||
@staticmethod
|
||||
def get_reranker_instance(config=None) -> Optional['RerankerService']:
|
||||
def get_reranker_instance() -> Optional['RerankerService']:
|
||||
"""
|
||||
Get a reranker service instance based on configuration
|
||||
Get a reranker service instance from the global configuration.
|
||||
|
||||
Args:
|
||||
config: Configuration object that may contain a reranker_instance
|
||||
|
||||
Returns:
|
||||
Optional[RerankerService]: A reranker service instance or None
|
||||
Optional[RerankerService]: A reranker service instance if configured, None otherwise
|
||||
"""
|
||||
if config and hasattr(config, 'reranker_instance') and config.reranker_instance:
|
||||
from app.config import config
|
||||
|
||||
if hasattr(config, 'reranker_instance') and config.reranker_instance:
|
||||
return RerankerService(config.reranker_instance)
|
||||
return None
|
||||
return None
|
||||
|
4512
surfsense_backend/uv.lock
generated
4512
surfsense_backend/uv.lock
generated
File diff suppressed because it is too large
Load diff
Loading…
Add table
Reference in a new issue