mirror of
https://github.com/MODSetter/SurfSense.git
synced 2025-09-03 11:09:16 +00:00
Merge pull request #174 from MODSetter/dev
fix(backend): Fix rerank_documents node in sub_section_writer & qna_agent
This commit is contained in:
commit
15a019ef76
12 changed files with 2719 additions and 2120 deletions
|
@ -68,27 +68,3 @@ def build_graph():
|
||||||
|
|
||||||
# Compile the graph once when the module is loaded
|
# Compile the graph once when the module is loaded
|
||||||
graph = build_graph()
|
graph = build_graph()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# L0o55JzTBlCYJNCRYbbxt8mxqRs5kPm6QO8NzVqEZtzqWtG0EklbHuQ3I5ZBdSy8n+EqrdQxcp+R3Yc57NIm79iNS2sxt4tVMSTLeAT6qpMS2SbBER4hRiLaH5BKpXBJoCRPoFMYpDf6pdIokZyJz/EQWQZj531TfLcBfFkxJuWEqvinKhvWJPjApBd1RldixOj57mNXybHN8WFe+FnayhYQhptesoFAVXAk1WuV2URSqXxs5/00Eo8osC55gsye6LXTYzieyUKxurLKw+uy3g==
|
|
||||||
|
|
|
@ -2,11 +2,10 @@ import asyncio
|
||||||
import json
|
import json
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from app.db import async_session_maker
|
|
||||||
from app.utils.connector_service import ConnectorService
|
from app.utils.connector_service import ConnectorService
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
from langchain_core.runnables import RunnableConfig
|
from langchain_core.runnables import RunnableConfig
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from .configuration import Configuration, SearchMode
|
from .configuration import Configuration, SearchMode
|
||||||
|
@ -15,6 +14,7 @@ from .state import State
|
||||||
from .sub_section_writer.graph import graph as sub_section_writer_graph
|
from .sub_section_writer.graph import graph as sub_section_writer_graph
|
||||||
from .sub_section_writer.configuration import SubSectionType
|
from .sub_section_writer.configuration import SubSectionType
|
||||||
from .qna_agent.graph import graph as qna_agent_graph
|
from .qna_agent.graph import graph as qna_agent_graph
|
||||||
|
from .utils import AnswerOutline, get_connector_emoji, get_connector_friendly_name
|
||||||
|
|
||||||
from app.utils.query_service import QueryService
|
from app.utils.query_service import QueryService
|
||||||
|
|
||||||
|
@ -24,7 +24,6 @@ from langgraph.types import StreamWriter
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
from app.db import Document, SearchSpace
|
from app.db import Document, SearchSpace
|
||||||
|
|
||||||
|
|
||||||
async def fetch_documents_by_ids(
|
async def fetch_documents_by_ids(
|
||||||
document_ids: List[int],
|
document_ids: List[int],
|
||||||
user_id: str,
|
user_id: str,
|
||||||
|
@ -41,7 +40,7 @@ async def fetch_documents_by_ids(
|
||||||
Args:
|
Args:
|
||||||
document_ids: List of document IDs to fetch
|
document_ids: List of document IDs to fetch
|
||||||
user_id: The user ID to check ownership
|
user_id: The user ID to check ownership
|
||||||
db_session: The database session eg L0o55JzTBlCYJNCRYbbxt8mxqRs5kPm6QO8NzVqEZtzqWtG0EklbHuQ3I5ZBdSy8n+EqrdQxcp+R3Yc57NIm79iNS2sxt4tVMSTLeAT6qpMS2SbBER4hRiLaH5BKpXBJoCRPoFMYpDf6pdIokZyJz/EQWQZj531TfLcBfFkxJuWEqvinKhvWJPjApBd1RldixOj57mNXybHN8WFe+FnayhYQhptesoFAVXAk1WuV2URSqXxs5/00Eo8osC55gsye6LXTYzieyUKxurLKw+uy3g==
|
db_session: The database session
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (source_objects, document_chunks) - similar to ConnectorService pattern
|
Tuple of (source_objects, document_chunks) - similar to ConnectorService pattern
|
||||||
|
@ -168,7 +167,7 @@ async def fetch_documents_by_ids(
|
||||||
url = f"https://www.youtube.com/watch?v={video_id}" if video_id else ""
|
url = f"https://www.youtube.com/watch?v={video_id}" if video_id else ""
|
||||||
|
|
||||||
elif doc_type == "DISCORD_CONNECTOR":
|
elif doc_type == "DISCORD_CONNECTOR":
|
||||||
# Extract Discord-specific metadata L0o55JzTBlCYJNCRYbbxt8mxqRs5kPm6QO8NzVqEZtzqWtG0EklbHuQ3I5ZBdSy8n+EqrdQxcp+R3Yc57NIm79iNS2sxt4tVMSTLeAT6qpMS2SbBER4hRiLaH5BKpXBJoCRPoFMYpDf6pdIokZyJz/EQWQZj531TfLcBfFkxJuWEqvinKhvWJPjApBd1RldixOj57mNXybHN8WFe+FnayhYQhptesoFAVXAk1WuV2URSqXxs5/00Eo8osC55gsye6LXTYzieyUKxurLKw+uy3g==
|
# Extract Discord-specific metadata
|
||||||
channel_name = metadata.get('channel_name', 'Unknown Channel')
|
channel_name = metadata.get('channel_name', 'Unknown Channel')
|
||||||
channel_id = metadata.get('channel_id', '')
|
channel_id = metadata.get('channel_id', '')
|
||||||
guild_id = metadata.get('guild_id', '')
|
guild_id = metadata.get('guild_id', '')
|
||||||
|
@ -252,16 +251,6 @@ async def fetch_documents_by_ids(
|
||||||
return [], []
|
return [], []
|
||||||
|
|
||||||
|
|
||||||
class Section(BaseModel):
|
|
||||||
"""A section in the answer outline."""
|
|
||||||
section_id: int = Field(..., description="The zero-based index of the section")
|
|
||||||
section_title: str = Field(..., description="The title of the section")
|
|
||||||
questions: List[str] = Field(..., description="Questions to research for this section")
|
|
||||||
|
|
||||||
class AnswerOutline(BaseModel):
|
|
||||||
"""The complete answer outline with all sections."""
|
|
||||||
answer_outline: List[Section] = Field(..., description="List of sections in the answer outline")
|
|
||||||
|
|
||||||
async def write_answer_outline(state: State, config: RunnableConfig, writer: StreamWriter) -> Dict[str, Any]:
|
async def write_answer_outline(state: State, config: RunnableConfig, writer: StreamWriter) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Create a structured answer outline based on the user query.
|
Create a structured answer outline based on the user query.
|
||||||
|
@ -379,6 +368,7 @@ async def write_answer_outline(state: State, config: RunnableConfig, writer: Str
|
||||||
print(f"Raw response: {response.content}")
|
print(f"Raw response: {response.content}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
async def fetch_relevant_documents(
|
async def fetch_relevant_documents(
|
||||||
research_questions: List[str],
|
research_questions: List[str],
|
||||||
user_id: str,
|
user_id: str,
|
||||||
|
@ -535,7 +525,7 @@ async def fetch_relevant_documents(
|
||||||
search_mode=search_mode
|
search_mode=search_mode
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add to sources and raw documents L0o55JzTBlCYJNCRYbbxt8mxqRs5kPm6QO8NzVqEZtzqWtG0EklbHuQ3I5ZBdSy8n+EqrdQxcp+R3Yc57NIm79iNS2sxt4tVMSTLeAT6qpMS2SbBER4hRiLaH5BKpXBJoCRPoFMYpDf6pdIokZyJz/EQWQZj531TfLcBfFkxJuWEqvinKhvWJPjApBd1RldixOj57mNXybHN8WFe+FnayhYQhptesoFAVXAk1WuV2URSqXxs5/00Eo8osC55gsye6LXTYzieyUKxurLKw+uy3g==
|
# Add to sources and raw documents
|
||||||
if source_object:
|
if source_object:
|
||||||
all_sources.append(source_object)
|
all_sources.append(source_object)
|
||||||
all_raw_documents.extend(slack_chunks)
|
all_raw_documents.extend(slack_chunks)
|
||||||
|
@ -746,37 +736,6 @@ async def fetch_relevant_documents(
|
||||||
# Return deduplicated documents
|
# Return deduplicated documents
|
||||||
return deduplicated_docs
|
return deduplicated_docs
|
||||||
|
|
||||||
def get_connector_emoji(connector_name: str) -> str:
|
|
||||||
"""Get an appropriate emoji for a connector type."""
|
|
||||||
connector_emojis = {
|
|
||||||
"YOUTUBE_VIDEO": "📹",
|
|
||||||
"EXTENSION": "🧩",
|
|
||||||
"CRAWLED_URL": "🌐",
|
|
||||||
"FILE": "📄",
|
|
||||||
"SLACK_CONNECTOR": "💬",
|
|
||||||
"NOTION_CONNECTOR": "📘",
|
|
||||||
"GITHUB_CONNECTOR": "🐙",
|
|
||||||
"LINEAR_CONNECTOR": "📊",
|
|
||||||
"TAVILY_API": "🔍",
|
|
||||||
"LINKUP_API": "🔗"
|
|
||||||
}
|
|
||||||
return connector_emojis.get(connector_name, "🔎")
|
|
||||||
|
|
||||||
def get_connector_friendly_name(connector_name: str) -> str:
|
|
||||||
"""Convert technical connector IDs to user-friendly names."""
|
|
||||||
connector_friendly_names = {
|
|
||||||
"YOUTUBE_VIDEO": "YouTube",
|
|
||||||
"EXTENSION": "Browser Extension",
|
|
||||||
"CRAWLED_URL": "Web Pages",
|
|
||||||
"FILE": "Files",
|
|
||||||
"SLACK_CONNECTOR": "Slack",
|
|
||||||
"NOTION_CONNECTOR": "Notion",
|
|
||||||
"GITHUB_CONNECTOR": "GitHub",
|
|
||||||
"LINEAR_CONNECTOR": "Linear",
|
|
||||||
"TAVILY_API": "Tavily Search",
|
|
||||||
"LINKUP_API": "Linkup Search"
|
|
||||||
}
|
|
||||||
return connector_friendly_names.get(connector_name, connector_name)
|
|
||||||
|
|
||||||
async def process_sections(state: State, config: RunnableConfig, writer: StreamWriter) -> Dict[str, Any]:
|
async def process_sections(state: State, config: RunnableConfig, writer: StreamWriter) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
|
@ -787,7 +746,7 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW
|
||||||
using the sub_section_writer graph with the shared document pool.
|
using the sub_section_writer graph with the shared document pool.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict containing the final written report in the "final_written_report" key L0o55JzTBlCYJNCRYbbxt8mxqRs5kPm6QO8NzVqEZtzqWtG0EklbHuQ3I5ZBdSy8n+EqrdQxcp+R3Yc57NIm79iNS2sxt4tVMSTLeAT6qpMS2SbBER4hRiLaH5BKpXBJoCRPoFMYpDf6pdIokZyJz/EQWQZj531TfLcBfFkxJuWEqvinKhvWJPjApBd1RldixOj57mNXybHN8WFe+FnayhYQhptesoFAVXAk1WuV2URSqXxs5/00Eo8osC55gsye6LXTYzieyUKxurLKw+uy3g==s.
|
Dict containing the final written report in the "final_written_report" key.
|
||||||
"""
|
"""
|
||||||
# Get configuration and answer outline from state
|
# Get configuration and answer outline from state
|
||||||
configuration = Configuration.from_runnable_config(config)
|
configuration = Configuration.from_runnable_config(config)
|
||||||
|
@ -969,6 +928,7 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW
|
||||||
"final_written_report": final_written_report
|
"final_written_report": final_written_report
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def process_section_with_documents(
|
async def process_section_with_documents(
|
||||||
section_id: int,
|
section_id: int,
|
||||||
section_title: str,
|
section_title: str,
|
||||||
|
@ -1080,7 +1040,7 @@ async def process_section_with_documents(
|
||||||
complete_answer.extend(content_lines)
|
complete_answer.extend(content_lines)
|
||||||
complete_answer.append("") # Empty line after content
|
complete_answer.append("") # Empty line after content
|
||||||
|
|
||||||
# Update answer in UI in real-time L0o55JzTBlCYJNCRYbbxt8mxqRs5kPm6QO8NzVqEZtzqWtG0EklbHuQ3I5ZBdSy8n+EqrdQxcp+R3Yc57NIm79iNS2sxt4tVMSTLeAT6qpMS2SbBER4hRiLaH5BKpXBJoCRPoFMYpDf6pdIokZyJz/EQWQZj531TfLcBfFkxJuWEqvinKhvWJPjApBd1RldixOj57mNXybHN8WFe+FnayhYQhptesoFAVXAk1WuV2URSqXxs5/00Eo8osC55gsye6LXTYzieyUKxurLKw+uy3g==
|
# Update answer in UI in real-time
|
||||||
state.streaming_service.only_update_answer(complete_answer)
|
state.streaming_service.only_update_answer(complete_answer)
|
||||||
writer({"yeild_value": state.streaming_service._format_annotations()})
|
writer({"yeild_value": state.streaming_service._format_annotations()})
|
||||||
|
|
||||||
|
@ -1106,7 +1066,6 @@ async def process_section_with_documents(
|
||||||
return f"Error processing section: {section_title}. Details: {str(e)}"
|
return f"Error processing section: {section_title}. Details: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def reformulate_user_query(state: State, config: RunnableConfig, writer: StreamWriter) -> Dict[str, Any]:
|
async def reformulate_user_query(state: State, config: RunnableConfig, writer: StreamWriter) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Reforms the user query based on the chat history.
|
Reforms the user query based on the chat history.
|
||||||
|
@ -1124,6 +1083,7 @@ async def reformulate_user_query(state: State, config: RunnableConfig, writer: S
|
||||||
"reformulated_query": reformulated_query
|
"reformulated_query": reformulated_query
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def handle_qna_workflow(state: State, config: RunnableConfig, writer: StreamWriter) -> Dict[str, Any]:
|
async def handle_qna_workflow(state: State, config: RunnableConfig, writer: StreamWriter) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Handle the QNA research workflow.
|
Handle the QNA research workflow.
|
||||||
|
@ -1201,7 +1161,7 @@ async def handle_qna_workflow(state: State, config: RunnableConfig, writer: Stre
|
||||||
# Continue with empty documents - the QNA agent will handle this gracefully
|
# Continue with empty documents - the QNA agent will handle this gracefully
|
||||||
relevant_documents = []
|
relevant_documents = []
|
||||||
|
|
||||||
# Combine user-selected documents with connector-fetched documents L0o55JzTBlCYJNCRYbbxt8mxqRs5kPm6QO8NzVqEZtzqWtG0EklbHuQ3I5ZBdSy8n+EqrdQxcp+R3Yc57NIm79iNS2sxt4tVMSTLeAT6qpMS2SbBER4hRiLaH5BKpXBJoCRPoFMYpDf6pdIokZyJz/EQWQZj531TfLcBfFkxJuWEqvinKhvWJPjApBd1RldixOj57mNXybHN8WFe+FnayhYQhptesoFAVXAk1WuV2URSqXxs5/00Eo8osC55gsye6LXTYzieyUKxurLKw+uy3g==
|
# Combine user-selected documents with connector-fetched documents
|
||||||
all_documents = user_selected_documents + relevant_documents
|
all_documents = user_selected_documents + relevant_documents
|
||||||
|
|
||||||
print(f"Fetched {len(relevant_documents)} relevant documents for QNA")
|
print(f"Fetched {len(relevant_documents)} relevant documents for QNA")
|
||||||
|
|
|
@ -1,6 +1,3 @@
|
||||||
"""
|
|
||||||
File Hash: L0o55JzTBlCYJNCRYbbxt8mxqRs5kPm6QO8NzVqEZtzqWtG0EklbHuQ3I5ZBdSy8n+EqrdQxcp+R3Yc57NIm79iNS2sxt4tVMSTLeAT6qpMS2SbBER4hRiLaH5BKpXBJoCRPoFMYpDf6pdIokZyJz/EQWQZj531TfLcBfFkxJuWEqvinKhvWJPjApBd1RldixOj57mNXybHN8WFe+FnayhYQhptesoFAVXAk1WuV2URSqXxs5/00Eo8osC55gsye6LXTYzieyUKxurLKw+uy3g==
|
|
||||||
"""
|
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -18,105 +18,3 @@ workflow.add_edge("answer_question", "__end__")
|
||||||
# Compile the workflow into an executable graph
|
# Compile the workflow into an executable graph
|
||||||
graph = workflow.compile()
|
graph = workflow.compile()
|
||||||
graph.name = "SurfSense QnA Agent" # This defines the custom name in LangSmith
|
graph.name = "SurfSense QnA Agent" # This defines the custom name in LangSmith
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# L0o55JzTBlCYJNCRYbbxt8mxqRs5kPm6QO8NzVqEZtzqWtG0EklbHuQ3I5ZBdSy8n+EqrdQxcp+R3Yc57NIm79iNS2sxt4tVMSTLeAT6qpMS2SbBER4hRiLaH5BKpXBJoCRPoFMYpDf6pdIokZyJz/EQWQZj531TfLcBfFkxJuWEqvinKhvWJPjApBd1RldixOj57mNXybHN8WFe+FnayhYQhptesoFAVXAk1WuV2URSqXxs5/00Eo8osC55gsye6LXTYzieyUKxurLKw+uy3g==
|
|
|
@ -1,8 +1,8 @@
|
||||||
|
from app.utils.reranker_service import RerankerService
|
||||||
from .configuration import Configuration
|
from .configuration import Configuration
|
||||||
from langchain_core.runnables import RunnableConfig
|
from langchain_core.runnables import RunnableConfig
|
||||||
from .state import State
|
from .state import State
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
from app.config import config as app_config
|
|
||||||
from .prompts import get_qna_citation_system_prompt, get_qna_no_documents_system_prompt
|
from .prompts import get_qna_citation_system_prompt, get_qna_no_documents_system_prompt
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
|
@ -35,7 +35,7 @@ async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, An
|
||||||
}
|
}
|
||||||
|
|
||||||
# Get reranker service from app config
|
# Get reranker service from app config
|
||||||
reranker_service = getattr(app_config, "reranker_service", None)
|
reranker_service = RerankerService.get_reranker_instance()
|
||||||
|
|
||||||
# Use documents as is if no reranker service is available
|
# Use documents as is if no reranker service is available
|
||||||
reranked_docs = documents
|
reranked_docs = documents
|
||||||
|
|
|
@ -1,6 +1,4 @@
|
||||||
"""Define the state structures for the agent.
|
"""Define the state structures for the agent."""
|
||||||
File Hash: L0o55JzTBlCYJNCRYbbxt8mxqRs5kPm6QO8NzVqEZtzqWtG0EklbHuQ3I5ZBdSy8n+EqrdQxcp+R3Yc57NIm79iNS2sxt4tVMSTLeAT6qpMS2SbBER4hRiLaH5BKpXBJoCRPoFMYpDf6pdIokZyJz/EQWQZj531TfLcBfFkxJuWEqvinKhvWJPjApBd1RldixOj57mNXybHN8WFe+FnayhYQhptesoFAVXAk1WuV2URSqXxs5/00Eo8osC55gsye6LXTYzieyUKxurLKw+uy3g==
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,4 @@
|
||||||
"""Define the state structures for the agent.
|
"""Define the state structures for the agent."""
|
||||||
File Hash: L0o55JzTBlCYJNCRYbbxt8mxqRs5kPm6QO8NzVqEZtzqWtG0EklbHuQ3I5ZBdSy8n+EqrdQxcp+R3Yc57NIm79iNS2sxt4tVMSTLeAT6qpMS2SbBER4hRiLaH5BKpXBJoCRPoFMYpDf6pdIokZyJz/EQWQZj531TfLcBfFkxJuWEqvinKhvWJPjApBd1RldixOj57mNXybHN8WFe+FnayhYQhptesoFAVXAk1WuV2URSqXxs5/00Eo8osC55gsye6LXTYzieyUKxurLKw+uy3g==
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@ from .configuration import Configuration
|
||||||
from langchain_core.runnables import RunnableConfig
|
from langchain_core.runnables import RunnableConfig
|
||||||
from .state import State
|
from .state import State
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
from app.config import config as app_config
|
from app.utils.reranker_service import RerankerService
|
||||||
from .prompts import get_citation_system_prompt, get_no_documents_system_prompt
|
from .prompts import get_citation_system_prompt, get_no_documents_system_prompt
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
from .configuration import SubSectionType
|
from .configuration import SubSectionType
|
||||||
|
@ -35,7 +35,7 @@ async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, An
|
||||||
}
|
}
|
||||||
|
|
||||||
# Get reranker service from app config
|
# Get reranker service from app config
|
||||||
reranker_service = getattr(app_config, "reranker_service", None)
|
reranker_service = RerankerService.get_reranker_instance()
|
||||||
|
|
||||||
# Use documents as is if no reranker service is available
|
# Use documents as is if no reranker service is available
|
||||||
reranked_docs = documents
|
reranked_docs = documents
|
||||||
|
@ -211,7 +211,7 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A
|
||||||
HumanMessage(content=human_message_content)
|
HumanMessage(content=human_message_content)
|
||||||
]
|
]
|
||||||
|
|
||||||
# Log final token count L0o55JzTBlCYJNCRYbbxt8mxqRs5kPm6QO8NzVqEZtzqWtG0EklbHuQ3I5ZBdSy8n+EqrdQxcp+R3Yc57NIm79iNS2sxt4tVMSTLeAT6qpMS2SbBER4hRiLaH5BKpXBJoCRPoFMYpDf6pdIokZyJz/EQWQZj531TfLcBfFkxJuWEqvinKhvWJPjApBd1RldixOj57mNXybHN8WFe+FnayhYQhptesoFAVXAk1WuV2URSqXxs5/00Eo8osC55gsye6LXTYzieyUKxurLKw+uy3g==
|
# Log final token count
|
||||||
total_tokens = calculate_token_count(messages_with_chat_history, llm.model)
|
total_tokens = calculate_token_count(messages_with_chat_history, llm.model)
|
||||||
print(f"Final token count: {total_tokens}")
|
print(f"Final token count: {total_tokens}")
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,4 @@
|
||||||
"""Define the state structures for the agent.
|
"""Define the state structures for the agent."""
|
||||||
File Hash: L0o55JzTBlCYJNCRYbbxt8mxqRs5kPm6QO8NzVqEZtzqWtG0EklbHuQ3I5ZBdSy8n+EqrdQxcp+R3Yc57NIm79iNS2sxt4tVMSTLeAT6qpMS2SbBER4hRiLaH5BKpXBJoCRPoFMYpDf6pdIokZyJz/EQWQZj531TfLcBfFkxJuWEqvinKhvWJPjApBd1RldixOj57mNXybHN8WFe+FnayhYQhptesoFAVXAk1WuV2URSqXxs5/00Eo8osC55gsye6LXTYzieyUKxurLKw+uy3g==
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,18 @@
|
||||||
from typing import List, Dict, Any, Tuple, NamedTuple
|
from typing import List, Dict, Any, Tuple, NamedTuple
|
||||||
from langchain_core.messages import BaseMessage
|
from langchain_core.messages import BaseMessage
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from litellm import token_counter, get_model_info
|
from litellm import token_counter, get_model_info
|
||||||
|
|
||||||
|
class Section(BaseModel):
|
||||||
|
"""A section in the answer outline."""
|
||||||
|
section_id: int = Field(..., description="The zero-based index of the section")
|
||||||
|
section_title: str = Field(..., description="The title of the section")
|
||||||
|
questions: List[str] = Field(..., description="Questions to research for this section")
|
||||||
|
|
||||||
|
class AnswerOutline(BaseModel):
|
||||||
|
"""The complete answer outline with all sections."""
|
||||||
|
answer_outline: List[Section] = Field(..., description="List of sections in the answer outline")
|
||||||
|
|
||||||
|
|
||||||
class DocumentTokenInfo(NamedTuple):
|
class DocumentTokenInfo(NamedTuple):
|
||||||
"""Information about a document and its token cost."""
|
"""Information about a document and its token cost."""
|
||||||
|
@ -11,6 +22,40 @@ class DocumentTokenInfo(NamedTuple):
|
||||||
token_count: int
|
token_count: int
|
||||||
|
|
||||||
|
|
||||||
|
def get_connector_emoji(connector_name: str) -> str:
|
||||||
|
"""Get an appropriate emoji for a connector type."""
|
||||||
|
connector_emojis = {
|
||||||
|
"YOUTUBE_VIDEO": "📹",
|
||||||
|
"EXTENSION": "🧩",
|
||||||
|
"CRAWLED_URL": "🌐",
|
||||||
|
"FILE": "📄",
|
||||||
|
"SLACK_CONNECTOR": "💬",
|
||||||
|
"NOTION_CONNECTOR": "📘",
|
||||||
|
"GITHUB_CONNECTOR": "🐙",
|
||||||
|
"LINEAR_CONNECTOR": "📊",
|
||||||
|
"TAVILY_API": "🔍",
|
||||||
|
"LINKUP_API": "🔗"
|
||||||
|
}
|
||||||
|
return connector_emojis.get(connector_name, "🔎")
|
||||||
|
|
||||||
|
|
||||||
|
def get_connector_friendly_name(connector_name: str) -> str:
|
||||||
|
"""Convert technical connector IDs to user-friendly names."""
|
||||||
|
connector_friendly_names = {
|
||||||
|
"YOUTUBE_VIDEO": "YouTube",
|
||||||
|
"EXTENSION": "Browser Extension",
|
||||||
|
"CRAWLED_URL": "Web Pages",
|
||||||
|
"FILE": "Files",
|
||||||
|
"SLACK_CONNECTOR": "Slack",
|
||||||
|
"NOTION_CONNECTOR": "Notion",
|
||||||
|
"GITHUB_CONNECTOR": "GitHub",
|
||||||
|
"LINEAR_CONNECTOR": "Linear",
|
||||||
|
"TAVILY_API": "Tavily Search",
|
||||||
|
"LINKUP_API": "Linkup Search"
|
||||||
|
}
|
||||||
|
return connector_friendly_names.get(connector_name, connector_name)
|
||||||
|
|
||||||
|
|
||||||
def convert_langchain_messages_to_dict(messages: List[BaseMessage]) -> List[Dict[str, str]]:
|
def convert_langchain_messages_to_dict(messages: List[BaseMessage]) -> List[Dict[str, str]]:
|
||||||
"""Convert LangChain messages to format expected by token_counter."""
|
"""Convert LangChain messages to format expected by token_counter."""
|
||||||
role_mapping = {
|
role_mapping = {
|
||||||
|
@ -82,9 +127,6 @@ def calculate_document_token_costs(documents: List[Dict[str, Any]], model: str)
|
||||||
))
|
))
|
||||||
|
|
||||||
return document_token_info
|
return document_token_info
|
||||||
"""
|
|
||||||
File Hash: L0o55JzTBlCYJNCRYbbxt8mxqRs5kPm6QO8NzVqEZtzqWtG0EklbHuQ3I5ZBdSy8n+EqrdQxcp+R3Yc57NIm79iNS2sxt4tVMSTLeAT6qpMS2SbBER4hRiLaH5BKpXBJoCRPoFMYpDf6pdIokZyJz/EQWQZj531TfLcBfFkxJuWEqvinKhvWJPjApBd1RldixOj57mNXybHN8WFe+FnayhYQhptesoFAVXAk1WuV2URSqXxs5/00Eo8osC55gsye6LXTYzieyUKxurLKw+uy3g==
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def find_optimal_documents_with_binary_search(
|
def find_optimal_documents_with_binary_search(
|
||||||
|
@ -185,51 +227,3 @@ def calculate_token_count(messages: List[BaseMessage], model_name: str) -> int:
|
||||||
model = model_name
|
model = model_name
|
||||||
messages_dict = convert_langchain_messages_to_dict(messages)
|
messages_dict = convert_langchain_messages_to_dict(messages)
|
||||||
return token_counter(messages=messages_dict, model=model)
|
return token_counter(messages=messages_dict, model=model)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
File Hash: L0o55JzTBlCYJNCRYbbxt8mxqRs5kPm6QO8NzVqEZtzqWtG0EklbHuQ3I5ZBdSy8n+EqrdQxcp+R3Yc57NIm79iNS2sxt4tVMSTLeAT6qpMS2SbBER4hRiLaH5BKpXBJoCRPoFMYpDf6pdIokZyJz/EQWQZj531TfLcBfFkxJuWEqvinKhvWJPjApBd1RldixOj57mNXybHN8WFe+FnayhYQhptesoFAVXAk1WuV2URSqXxs5/00Eo8osC55gsye6LXTYzieyUKxurLKw+uy3g==
|
|
||||||
"""
|
|
|
@ -80,16 +80,16 @@ class RerankerService:
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_reranker_instance(config=None) -> Optional['RerankerService']:
|
def get_reranker_instance() -> Optional['RerankerService']:
|
||||||
"""
|
"""
|
||||||
Get a reranker service instance based on configuration
|
Get a reranker service instance from the global configuration.
|
||||||
|
|
||||||
Args:
|
|
||||||
config: Configuration object that may contain a reranker_instance
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional[RerankerService]: A reranker service instance or None
|
Optional[RerankerService]: A reranker service instance if configured, None otherwise
|
||||||
"""
|
"""
|
||||||
if config and hasattr(config, 'reranker_instance') and config.reranker_instance:
|
from app.config import config
|
||||||
|
|
||||||
|
if hasattr(config, 'reranker_instance') and config.reranker_instance:
|
||||||
return RerankerService(config.reranker_instance)
|
return RerankerService(config.reranker_instance)
|
||||||
return None
|
return None
|
||||||
|
|
4512
surfsense_backend/uv.lock
generated
4512
surfsense_backend/uv.lock
generated
File diff suppressed because it is too large
Load diff
Loading…
Add table
Reference in a new issue