Fixed all ruff lint and formatting errors

Utkarsh-Patel-13 2025-07-24 14:43:48 -07:00
parent 0a03c42cc5
commit d359a59f6d
85 changed files with 5520 additions and 3870 deletions

View file

@ -1,8 +1,8 @@
import asyncio
from logging.config import fileConfig
import os
import sys
from logging.config import fileConfig
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import async_engine_from_config
@ -11,10 +11,10 @@ from alembic import context
# Ensure the app directory is in the Python path
# This allows Alembic to find your models
sys.path.insert(0, os.path.realpath(os.path.join(os.path.dirname(__file__), '..')))
sys.path.insert(0, os.path.realpath(os.path.join(os.path.dirname(__file__), "..")))
# Import your models base
from app.db import Base # Assuming your Base is defined in app.db
from app.db import Base # Assuming your Base is defined in app.db
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.

View file

@ -6,9 +6,9 @@ Revises: 9
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "10"

View file

@ -6,10 +6,10 @@ Revises: 10
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID, JSON
from sqlalchemy.dialects.postgresql import JSON, UUID
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "11"

View file

@ -6,10 +6,10 @@ Revises: 11
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import JSON
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "12"

View file

@ -6,8 +6,10 @@ Revises:
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from alembic import op
# Import pgvector if needed for other types, though not for this ENUM change
# import pgvector

View file

@ -6,9 +6,9 @@ Revises: e55302644c51
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '2'

View file

@ -6,9 +6,9 @@ Revises: 2
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '3'

View file

@ -6,9 +6,9 @@ Revises: 3
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '4'

View file

@ -6,9 +6,9 @@ Revises: 4
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '5'

View file

@ -6,10 +6,10 @@ Revises: 5
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import JSON
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '6'

View file

@ -6,9 +6,9 @@ Revises: 6
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '7'

View file

@ -5,9 +5,9 @@ Revises: 7
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '8'

View file

@ -6,9 +6,9 @@ Revises: 8
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "9"

View file

@ -6,9 +6,9 @@ Revises: 1
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = 'e55302644c51'

View file

@ -3,7 +3,6 @@
from __future__ import annotations
from dataclasses import dataclass, fields
from typing import Optional
from langchain_core.runnables import RunnableConfig
@ -17,11 +16,11 @@ class Configuration:
# create assistants (https://langchain-ai.github.io/langgraph/cloud/how-tos/configuration_cloud/)
# and when you invoke the graph
podcast_title: str
user_id: str
user_id: str
@classmethod
def from_runnable_config(
cls, config: Optional[RunnableConfig] = None
cls, config: RunnableConfig | None = None
) -> Configuration:
"""Create a Configuration instance from a RunnableConfig object."""
configurable = (config.get("configurable") or {}) if config else {}

View file

@ -1,14 +1,11 @@
from langgraph.graph import StateGraph
from .configuration import Configuration
from .nodes import create_merged_podcast_audio, create_podcast_transcript
from .state import State
from .nodes import create_merged_podcast_audio, create_podcast_transcript
def build_graph():
# Define a new graph
workflow = StateGraph(State, config_schema=Configuration)
@ -24,8 +21,9 @@ def build_graph():
# Compile the workflow into an executable graph
graph = workflow.compile()
graph.name = "Surfsense Podcaster" # This defines the custom name in LangSmith
return graph
# Compile the graph once when the module is loaded
graph = build_graph()
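
For context, a minimal invocation sketch of the compiled podcaster graph (not part of this commit; the import path and session handling are assumptions for illustration):

from app.agents.podcaster.graph import graph  # import path assumed for illustration


async def generate_podcast(session, source_text: str, user_id: str) -> dict:
    # The input keys mirror the State dataclass (db_session, source_content);
    # the configurable keys mirror Configuration (podcast_title, user_id).
    return await graph.ainvoke(
        {"db_session": session, "source_content": source_text},
        config={"configurable": {"podcast_title": "SurfSense", "user_id": user_id}},
    )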

View file

@ -1,148 +1,154 @@
from typing import Any, Dict
import asyncio
import json
import os
import uuid
from pathlib import Path
import asyncio
from typing import Any
from ffmpeg.asyncio import FFmpeg
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from litellm import aspeech
from ffmpeg.asyncio import FFmpeg
from .configuration import Configuration
from .state import PodcastTranscriptEntry, State, PodcastTranscripts
from .prompts import get_podcast_generation_prompt
from app.config import config as app_config
from app.services.llm_service import get_user_long_context_llm
from .configuration import Configuration
from .prompts import get_podcast_generation_prompt
from .state import PodcastTranscriptEntry, PodcastTranscripts, State
async def create_podcast_transcript(state: State, config: RunnableConfig) -> Dict[str, Any]:
async def create_podcast_transcript(
state: State, config: RunnableConfig
) -> dict[str, Any]:
"""Each node does work."""
# Get configuration from runnable config
configuration = Configuration.from_runnable_config(config)
user_id = configuration.user_id
# Get user's long context LLM
llm = await get_user_long_context_llm(state.db_session, user_id)
if not llm:
error_message = f"No long context LLM configured for user {user_id}"
print(error_message)
raise RuntimeError(error_message)
# Get the prompt
prompt = get_podcast_generation_prompt()
# Create the messages
messages = [
SystemMessage(content=prompt),
HumanMessage(content=f"<source_content>{state.source_content}</source_content>")
HumanMessage(
content=f"<source_content>{state.source_content}</source_content>"
),
]
# Generate the podcast transcript
llm_response = await llm.ainvoke(messages)
# First try the direct approach
try:
podcast_transcript = PodcastTranscripts.model_validate(json.loads(llm_response.content))
podcast_transcript = PodcastTranscripts.model_validate(
json.loads(llm_response.content)
)
except (json.JSONDecodeError, ValueError) as e:
print(f"Direct JSON parsing failed, trying fallback approach: {str(e)}")
print(f"Direct JSON parsing failed, trying fallback approach: {e!s}")
# Fallback: Parse the JSON response manually
try:
# Extract JSON content from the response
content = llm_response.content
# Find the JSON in the content (handle case where LLM might add additional text)
json_start = content.find('{')
json_end = content.rfind('}') + 1
json_start = content.find("{")
json_end = content.rfind("}") + 1
if json_start >= 0 and json_end > json_start:
json_str = content[json_start:json_end]
# Parse the JSON string
parsed_data = json.loads(json_str)
# Convert to Pydantic model
podcast_transcript = PodcastTranscripts.model_validate(parsed_data)
print(f"Successfully parsed podcast transcript using fallback approach")
print("Successfully parsed podcast transcript using fallback approach")
else:
# If JSON structure not found, raise a clear error
error_message = f"Could not find valid JSON in LLM response. Raw response: {content}"
print(error_message)
raise ValueError(error_message)
except (json.JSONDecodeError, ValueError) as e2:
# Log the error and re-raise it
error_message = f"Error parsing LLM response (fallback also failed): {str(e2)}"
print(f"Error parsing LLM response: {str(e2)}")
error_message = f"Error parsing LLM response (fallback also failed): {e2!s}"
print(f"Error parsing LLM response: {e2!s}")
print(f"Raw response: {llm_response.content}")
raise
return {
"podcast_transcript": podcast_transcript.podcast_transcripts
}
async def create_merged_podcast_audio(state: State, config: RunnableConfig) -> Dict[str, Any]:
return {"podcast_transcript": podcast_transcript.podcast_transcripts}
async def create_merged_podcast_audio(
state: State, config: RunnableConfig
) -> dict[str, Any]:
"""Generate audio for each transcript and merge them into a single podcast file."""
configuration = Configuration.from_runnable_config(config)
starting_transcript = PodcastTranscriptEntry(
speaker_id=1,
dialog=f"Welcome to {configuration.podcast_title} Podcast."
speaker_id=1, dialog=f"Welcome to {configuration.podcast_title} Podcast."
)
transcript = state.podcast_transcript
# Merge the starting transcript with the podcast transcript
# Check if transcript is a PodcastTranscripts object or already a list
if hasattr(transcript, 'podcast_transcripts'):
if hasattr(transcript, "podcast_transcripts"):
transcript_entries = transcript.podcast_transcripts
else:
transcript_entries = transcript
merged_transcript = [starting_transcript] + transcript_entries
merged_transcript = [starting_transcript, *transcript_entries]
# Create a temporary directory for audio files
temp_dir = Path("temp_audio")
temp_dir.mkdir(exist_ok=True)
# Generate a unique session ID for this podcast
session_id = str(uuid.uuid4())
output_path = f"podcasts/{session_id}_podcast.mp3"
os.makedirs("podcasts", exist_ok=True)
# Map of speaker_id to voice
voice_mapping = {
0: "alloy", # Default/intro voice
1: "echo", # First speaker
1: "echo", # First speaker
# 2: "fable", # Second speaker
# 3: "onyx", # Third speaker
# 4: "nova", # Fourth speaker
# 5: "shimmer" # Fifth speaker
}
# Generate audio for each transcript segment
audio_files = []
async def generate_speech_for_segment(segment, index):
# Handle both dictionary and PodcastTranscriptEntry objects
if hasattr(segment, 'speaker_id'):
if hasattr(segment, "speaker_id"):
speaker_id = segment.speaker_id
dialog = segment.dialog
else:
speaker_id = segment.get("speaker_id", 0)
dialog = segment.get("dialog", "")
# Select voice based on speaker_id
voice = voice_mapping.get(speaker_id, "alloy")
# Generate a unique filename for this segment
filename = f"{temp_dir}/{session_id}_{index}.mp3"
try:
if app_config.TTS_SERVICE_API_BASE:
response = await aspeech(
@ -163,55 +169,61 @@ async def create_merged_podcast_audio(state: State, config: RunnableConfig) -> D
max_retries=2,
timeout=600,
)
# Save the audio to a file - use proper streaming method
with open(filename, 'wb') as f:
with open(filename, "wb") as f:
f.write(response.content)
return filename
except Exception as e:
print(f"Error generating speech for segment {index}: {str(e)}")
print(f"Error generating speech for segment {index}: {e!s}")
raise
# Generate all audio files concurrently
tasks = [generate_speech_for_segment(segment, i) for i, segment in enumerate(merged_transcript)]
tasks = [
generate_speech_for_segment(segment, i)
for i, segment in enumerate(merged_transcript)
]
audio_files = await asyncio.gather(*tasks)
# Merge audio files using ffmpeg
try:
# Create FFmpeg instance with the first input
ffmpeg = FFmpeg().option("y")
# Add each audio file as input
for audio_file in audio_files:
ffmpeg = ffmpeg.input(audio_file)
# Configure the concatenation and output
filter_complex = []
for i in range(len(audio_files)):
filter_complex.append(f"[{i}:0]")
filter_complex_str = "".join(filter_complex) + f"concat=n={len(audio_files)}:v=0:a=1[outa]"
filter_complex_str = (
"".join(filter_complex) + f"concat=n={len(audio_files)}:v=0:a=1[outa]"
)
ffmpeg = ffmpeg.option("filter_complex", filter_complex_str)
ffmpeg = ffmpeg.output(output_path, map="[outa]")
# Execute FFmpeg
await ffmpeg.execute()
print(f"Successfully created podcast audio: {output_path}")
except Exception as e:
print(f"Error merging audio files: {str(e)}")
print(f"Error merging audio files: {e!s}")
raise
finally:
# Clean up temporary files
for audio_file in audio_files:
try:
os.remove(audio_file)
except:
except Exception as e:
print(f"Error removing audio file {audio_file}: {e!s}")
pass
return {
"podcast_transcript": merged_transcript,
"final_podcast_file_path": output_path
"final_podcast_file_path": output_path,
}
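
As a small worked example of the filter_complex string assembled above (file names are illustrative only), three generated segments produce a single concat filter whose labeled output is mapped to the final MP3:

audio_files = ["seg_0.mp3", "seg_1.mp3", "seg_2.mp3"]  # hypothetical segment files
parts = [f"[{i}:0]" for i in range(len(audio_files))]
filter_complex_str = "".join(parts) + f"concat=n={len(audio_files)}:v=0:a=1[outa]"
# -> "[0:0][1:0][2:0]concat=n=3:v=0:a=1[outa]"
# Each [i:0] selects the audio stream of input i; concat joins them and labels the
# result [outa], which the code above maps to the output file via map="[outa]".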

View file

@ -108,4 +108,4 @@ Output:
Transform the source material into a lively and engaging podcast conversation. Craft dialogue that showcases authentic host chemistry and natural interaction (including occasional disagreement, building on points, or asking follow-up questions). Use varied speech patterns reflecting real human conversation, ensuring the final script effectively educates *and* entertains the listener while keeping within a 5-minute audio duration.
</podcast_generation_system>
"""
"""

View file

@ -3,14 +3,16 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import List, Optional
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
class PodcastTranscriptEntry(BaseModel):
"""
Represents a single entry in a podcast transcript.
"""
speaker_id: int = Field(..., description="The ID of the speaker (0 or 1)")
dialog: str = Field(..., description="The dialog text spoken by the speaker")
@ -19,10 +21,11 @@ class PodcastTranscripts(BaseModel):
"""
Represents the full podcast transcript structure.
"""
podcast_transcripts: List[PodcastTranscriptEntry] = Field(
...,
description="List of transcript entries with alternating speakers"
)
podcast_transcripts: list[PodcastTranscriptEntry] = Field(
..., description="List of transcript entries with alternating speakers"
)
@dataclass
class State:
@ -32,8 +35,9 @@ class State:
See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state
for more information.
"""
# Runtime context
db_session: AsyncSession
source_content: str
podcast_transcript: Optional[List[PodcastTranscriptEntry]] = None
final_podcast_file_path: Optional[str] = None
podcast_transcript: list[PodcastTranscriptEntry] | None = None
final_podcast_file_path: str | None = None

View file

@ -4,17 +4,20 @@ from __future__ import annotations
from dataclasses import dataclass, fields
from enum import Enum
from typing import Optional, List, Any
from langchain_core.runnables import RunnableConfig
class SearchMode(Enum):
class SearchMode(Enum):
"""Enum defining the type of search mode."""
CHUNKS = "CHUNKS"
DOCUMENTS = "DOCUMENTS"
class ResearchMode(Enum):
"""Enum defining the type of research mode."""
QNA = "QNA"
REPORT_GENERAL = "REPORT_GENERAL"
REPORT_DEEP = "REPORT_DEEP"
@ -28,16 +31,16 @@ class Configuration:
# Input parameters provided at invocation
user_query: str
num_sections: int
connectors_to_search: List[str]
connectors_to_search: list[str]
user_id: str
search_space_id: int
search_mode: SearchMode
research_mode: ResearchMode
document_ids_to_add_in_context: List[int]
document_ids_to_add_in_context: list[int]
@classmethod
def from_runnable_config(
cls, config: Optional[RunnableConfig] = None
cls, config: RunnableConfig | None = None
) -> Configuration:
"""Create a Configuration instance from a RunnableConfig object."""
configurable = (config.get("configurable") or {}) if config else {}

View file

@ -1,31 +1,41 @@
from typing import Any, TypedDict
from langgraph.graph import StateGraph
from .state import State
from .nodes import reformulate_user_query, write_answer_outline, process_sections, handle_qna_workflow, generate_further_questions
from .configuration import Configuration, ResearchMode
from typing import TypedDict, List, Dict, Any, Optional
from .nodes import (
generate_further_questions,
handle_qna_workflow,
process_sections,
reformulate_user_query,
write_answer_outline,
)
from .state import State
# Define what keys are in our state dict
class GraphState(TypedDict):
# Intermediate data produced during workflow
answer_outline: Optional[Any]
answer_outline: Any | None
# Final output
final_written_report: Optional[str]
final_written_report: str | None
def build_graph():
"""
Build and return the LangGraph workflow.
This function constructs the researcher agent graph with conditional routing
based on research_mode - QNA mode uses a direct Q&A workflow while other modes
use the full report generation pipeline. Both paths generate follow-up questions
at the end using the reranked documents from the sub-agents.
Returns:
A compiled LangGraph workflow
"""
# Define a new graph with state class
workflow = StateGraph(State, config_schema=Configuration)
# Add nodes to the graph
workflow.add_node("reformulate_user_query", reformulate_user_query)
workflow.add_node("handle_qna_workflow", handle_qna_workflow)
@ -35,41 +45,42 @@ def build_graph():
# Define the edges
workflow.add_edge("__start__", "reformulate_user_query")
# Add conditional edges from reformulate_user_query based on research mode
def route_after_reformulate(state: State, config) -> str:
"""Route based on research_mode after reformulating the query."""
configuration = Configuration.from_runnable_config(config)
if configuration.research_mode == ResearchMode.QNA.value:
return "handle_qna_workflow"
else:
return "write_answer_outline"
workflow.add_conditional_edges(
"reformulate_user_query",
route_after_reformulate,
{
"handle_qna_workflow": "handle_qna_workflow",
"write_answer_outline": "write_answer_outline"
}
"write_answer_outline": "write_answer_outline",
},
)
# QNA workflow path: handle_qna_workflow -> generate_further_questions -> __end__
workflow.add_edge("handle_qna_workflow", "generate_further_questions")
# Report generation workflow path: write_answer_outline -> process_sections -> generate_further_questions -> __end__
workflow.add_edge("write_answer_outline", "process_sections")
workflow.add_edge("process_sections", "generate_further_questions")
# Both paths end after generating further questions
workflow.add_edge("generate_further_questions", "__end__")
# Compile the workflow into an executable graph
graph = workflow.compile()
graph.name = "Surfsense Researcher" # This defines the custom name in LangSmith
return graph
# Compile the graph once when the module is loaded
graph = build_graph()
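
To make the routing above concrete, a hedged sketch of the configurable payload that drives route_after_reformulate; the field names come from the Configuration dataclass earlier in this commit, and the values are placeholders:

qna_config = {
    "configurable": {
        "user_query": "How does reranking work?",
        "num_sections": 3,
        "connectors_to_search": ["TAVILY_API"],
        "user_id": "user-123",
        "search_space_id": 1,
        "search_mode": "CHUNKS",
        "research_mode": "QNA",  # equals ResearchMode.QNA.value -> handle_qna_workflow
        "document_ids_to_add_in_context": [],
    }
}
# Any other research_mode (e.g. "REPORT_GENERAL") takes the
# write_answer_outline -> process_sections path before generate_further_questions.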

File diff suppressed because it is too large

View file

@ -221,4 +221,4 @@ Output:
}}
</examples>
</further_questions_system>
"""
"""

View file

@ -1,5 +1,4 @@
"""QnA Agent.
"""
"""QnA Agent."""
from .graph import graph

View file

@ -3,7 +3,7 @@
from __future__ import annotations
from dataclasses import dataclass, fields
from typing import Optional, List, Any
from typing import Any
from langchain_core.runnables import RunnableConfig
@ -15,13 +15,15 @@ class Configuration:
# Configuration parameters for the Q&A agent
user_query: str # The user's question to answer
reformulated_query: str # The reformulated query
relevant_documents: List[Any] # Documents provided directly to the agent for answering
relevant_documents: list[
Any
] # Documents provided directly to the agent for answering
user_id: str # User identifier
search_space_id: int # Search space identifier
@classmethod
def from_runnable_config(
cls, config: Optional[RunnableConfig] = None
cls, config: RunnableConfig | None = None
) -> Configuration:
"""Create a Configuration instance from a RunnableConfig object."""
configurable = (config.get("configurable") or {}) if config else {}

View file

@ -1,7 +1,8 @@
from langgraph.graph import StateGraph
from .state import State
from .nodes import rerank_documents, answer_question
from .configuration import Configuration
from .nodes import answer_question, rerank_documents
from .state import State
# Define a new graph
workflow = StateGraph(State, config_schema=Configuration)

View file

@ -1,24 +1,28 @@
from app.services.reranker_service import RerankerService
from .configuration import Configuration
from langchain_core.runnables import RunnableConfig
from .state import State
from typing import Any, Dict
from .prompts import get_qna_citation_system_prompt, get_qna_no_documents_system_prompt
from langchain_core.messages import HumanMessage, SystemMessage
from ..utils import (
optimize_documents_for_token_limit,
calculate_token_count,
format_documents_section
)
from typing import Any
async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, Any]:
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from app.services.reranker_service import RerankerService
from ..utils import (
calculate_token_count,
format_documents_section,
optimize_documents_for_token_limit,
)
from .configuration import Configuration
from .prompts import get_qna_citation_system_prompt, get_qna_no_documents_system_prompt
from .state import State
async def rerank_documents(state: State, config: RunnableConfig) -> dict[str, Any]:
"""
Rerank the documents based on relevance to the user's question.
This node takes the relevant documents provided in the configuration,
reranks them using the reranker service based on the user's query,
and updates the state with the reranked documents.
Returns:
Dict containing the reranked documents.
"""
@ -30,16 +34,14 @@ async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, An
# If no documents were provided, return empty list
if not documents or len(documents) == 0:
return {
"reranked_documents": []
}
return {"reranked_documents": []}
# Get reranker service from app config
reranker_service = RerankerService.get_reranker_instance()
# Use documents as is if no reranker service is available
reranked_docs = documents
if reranker_service:
try:
# Convert documents to format expected by reranker if needed
@ -51,58 +53,64 @@ async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, An
"document": {
"id": doc.get("document", {}).get("id", ""),
"title": doc.get("document", {}).get("title", ""),
"document_type": doc.get("document", {}).get("document_type", ""),
"metadata": doc.get("document", {}).get("metadata", {})
}
} for i, doc in enumerate(documents)
"document_type": doc.get("document", {}).get(
"document_type", ""
),
"metadata": doc.get("document", {}).get("metadata", {}),
},
}
for i, doc in enumerate(documents)
]
# Rerank documents using the user's query
reranked_docs = reranker_service.rerank_documents(user_query + "\n" + reformulated_query, reranker_input_docs)
reranked_docs = reranker_service.rerank_documents(
user_query + "\n" + reformulated_query, reranker_input_docs
)
# Sort by score in descending order
reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True)
print(f"Reranked {len(reranked_docs)} documents for Q&A query: {user_query}")
except Exception as e:
print(f"Error during reranking: {str(e)}")
# Use original docs if reranking fails
return {
"reranked_documents": reranked_docs
}
async def answer_question(state: State, config: RunnableConfig) -> Dict[str, Any]:
print(
f"Reranked {len(reranked_docs)} documents for Q&A query: {user_query}"
)
except Exception as e:
print(f"Error during reranking: {e!s}")
# Use original docs if reranking fails
return {"reranked_documents": reranked_docs}
async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any]:
"""
Answer the user's question using the provided documents.
This node takes the relevant documents provided in the configuration and uses
an LLM to generate a comprehensive answer to the user's question with
proper citations. The citations follow IEEE format using source IDs from the
documents. If no documents are provided, it will use chat history to generate
an answer.
Returns:
Dict containing the final answer in the "final_answer" key.
"""
from app.services.llm_service import get_user_fast_llm
# Get configuration and relevant documents from configuration
configuration = Configuration.from_runnable_config(config)
documents = state.reranked_documents
user_query = configuration.user_query
user_id = configuration.user_id
# Get user's fast LLM
llm = await get_user_fast_llm(state.db_session, user_id)
if not llm:
error_message = f"No fast LLM configured for user {user_id}"
print(error_message)
raise RuntimeError(error_message)
# Determine if we have documents and optimize for token limits
has_documents_initially = documents and len(documents) > 0
if has_documents_initially:
# Create base message template for token calculation (without documents)
base_human_message_template = f"""
@ -114,41 +122,49 @@ async def answer_question(state: State, config: RunnableConfig) -> Dict[str, Any
Please provide a detailed, comprehensive answer to the user's question using the information from their personal knowledge sources. Make sure to cite all information appropriately and engage in a conversational manner.
"""
# Use initial system prompt for token calculation
initial_system_prompt = get_qna_citation_system_prompt()
base_messages = state.chat_history + [
base_messages = [
*state.chat_history,
SystemMessage(content=initial_system_prompt),
HumanMessage(content=base_human_message_template)
HumanMessage(content=base_human_message_template),
]
# Optimize documents to fit within token limits
optimized_documents, has_optimized_documents = optimize_documents_for_token_limit(
documents, base_messages, llm.model
optimized_documents, has_optimized_documents = (
optimize_documents_for_token_limit(documents, base_messages, llm.model)
)
# Update state based on optimization result
documents = optimized_documents
has_documents = has_optimized_documents
else:
has_documents = False
# Choose system prompt based on final document availability
system_prompt = get_qna_citation_system_prompt() if has_documents else get_qna_no_documents_system_prompt()
system_prompt = (
get_qna_citation_system_prompt()
if has_documents
else get_qna_no_documents_system_prompt()
)
# Generate documents section
documents_text = format_documents_section(
documents,
"Source material from your personal knowledge base"
) if has_documents else ""
documents_text = (
format_documents_section(
documents, "Source material from your personal knowledge base"
)
if has_documents
else ""
)
# Create final human message content
instruction_text = (
"Please provide a detailed, comprehensive answer to the user's question using the information from their personal knowledge sources. Make sure to cite all information appropriately and engage in a conversational manner."
if has_documents else
"Please provide a helpful answer to the user's question based on our conversation history and your general knowledge. Engage in a conversational manner."
if has_documents
else "Please provide a helpful answer to the user's question based on our conversation history and your general knowledge. Engage in a conversational manner."
)
human_message_content = f"""
{documents_text}
@ -159,22 +175,20 @@ async def answer_question(state: State, config: RunnableConfig) -> Dict[str, Any
{instruction_text}
"""
# Create final messages for the LLM
messages_with_chat_history = state.chat_history + [
messages_with_chat_history = [
*state.chat_history,
SystemMessage(content=system_prompt),
HumanMessage(content=human_message_content)
HumanMessage(content=human_message_content),
]
# Log final token count
total_tokens = calculate_token_count(messages_with_chat_history, llm.model)
print(f"Final token count: {total_tokens}")
# Call the LLM and get the response
response = await llm.ainvoke(messages_with_chat_history)
final_answer = response.content
return {
"final_answer": final_answer
}
return {"final_answer": final_answer}

View file

@ -3,14 +3,16 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import List, Optional, Any
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
@dataclass
class State:
"""Defines the dynamic state for the Q&A agent during execution.
This state tracks the database session, chat history, and the outputs
This state tracks the database session, chat history, and the outputs
generated by the agent's nodes during question answering.
See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state
for more information.
@ -18,8 +20,8 @@ class State:
# Runtime context
db_session: AsyncSession
chat_history: Optional[List[Any]] = field(default_factory=list)
chat_history: list[Any] | None = field(default_factory=list)
# OUTPUT: Populated by agent nodes
reranked_documents: Optional[List[Any]] = None
final_answer: Optional[str] = None
reranked_documents: list[Any] | None = None
final_answer: str | None = None

View file

@ -3,10 +3,13 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import List, Optional, Any
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.streaming_service import StreamingService
@dataclass
class State:
"""Defines the dynamic state for the agent during execution.
@ -15,23 +18,23 @@ class State:
See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state
for more information.
"""
# Runtime context (not part of actual graph state)
db_session: AsyncSession
# Streaming service
streaming_service: StreamingService
chat_history: Optional[List[Any]] = field(default_factory=list)
reformulated_query: Optional[str] = field(default=None)
chat_history: list[Any] | None = field(default_factory=list)
reformulated_query: str | None = field(default=None)
# Using field to explicitly mark as part of state
answer_outline: Optional[Any] = field(default=None)
further_questions: Optional[Any] = field(default=None)
answer_outline: Any | None = field(default=None)
further_questions: Any | None = field(default=None)
# Temporary field to hold reranked documents from sub-agents for further question generation
reranked_documents: Optional[List[Any]] = field(default=None)
reranked_documents: list[Any] | None = field(default=None)
# OUTPUT: Populated by agent nodes
# Using field to explicitly mark as part of state
final_written_report: Optional[str] = field(default=None)
final_written_report: str | None = field(default=None)

View file

@ -4,13 +4,14 @@ from __future__ import annotations
from dataclasses import dataclass, fields
from enum import Enum
from typing import Optional, List, Any
from typing import Any
from langchain_core.runnables import RunnableConfig
class SubSectionType(Enum):
"""Enum defining the type of sub-section."""
START = "START"
MIDDLE = "MIDDLE"
END = "END"
@ -22,17 +23,16 @@ class Configuration:
# Input parameters provided at invocation
sub_section_title: str
sub_section_questions: List[str]
sub_section_questions: list[str]
sub_section_type: SubSectionType
user_query: str
relevant_documents: List[Any] # Documents provided directly to the agent
relevant_documents: list[Any] # Documents provided directly to the agent
user_id: str
search_space_id: int
@classmethod
def from_runnable_config(
cls, config: Optional[RunnableConfig] = None
cls, config: RunnableConfig | None = None
) -> Configuration:
"""Create a Configuration instance from a RunnableConfig object."""
configurable = (config.get("configurable") or {}) if config else {}

View file

@ -1,7 +1,8 @@
from langgraph.graph import StateGraph
from .state import State
from .nodes import write_sub_section, rerank_documents
from .configuration import Configuration
from .nodes import rerank_documents, write_sub_section
from .state import State
# Define a new graph
workflow = StateGraph(State, config_schema=Configuration)

View file

@ -1,25 +1,28 @@
from .configuration import Configuration
from langchain_core.runnables import RunnableConfig
from .state import State
from typing import Any, Dict
from app.services.reranker_service import RerankerService
from .prompts import get_citation_system_prompt, get_no_documents_system_prompt
from langchain_core.messages import HumanMessage, SystemMessage
from .configuration import SubSectionType
from ..utils import (
optimize_documents_for_token_limit,
calculate_token_count,
format_documents_section
)
from typing import Any
async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, Any]:
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from app.services.reranker_service import RerankerService
from ..utils import (
calculate_token_count,
format_documents_section,
optimize_documents_for_token_limit,
)
from .configuration import Configuration, SubSectionType
from .prompts import get_citation_system_prompt, get_no_documents_system_prompt
from .state import State
async def rerank_documents(state: State, config: RunnableConfig) -> dict[str, Any]:
"""
Rerank the documents based on relevance to the sub-section title.
This node takes the relevant documents provided in the configuration,
reranks them using the reranker service based on the sub-section title,
and updates the state with the reranked documents.
Returns:
Dict containing the reranked documents.
"""
@ -30,23 +33,23 @@ async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, An
# If no documents were provided, return empty list
if not documents or len(documents) == 0:
return {
"reranked_documents": []
}
return {"reranked_documents": []}
# Get reranker service from app config
reranker_service = RerankerService.get_reranker_instance()
# Use documents as is if no reranker service is available
reranked_docs = documents
if reranker_service:
try:
# Use the sub-section questions for reranking context
# rerank_query = "\n".join(sub_section_questions)
# rerank_query = configuration.user_query
rerank_query = configuration.user_query + "\n" + "\n".join(sub_section_questions)
rerank_query = (
configuration.user_query + "\n" + "\n".join(sub_section_questions)
)
# Convert documents to format expected by reranker if needed
reranker_input_docs = [
@ -57,54 +60,60 @@ async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, An
"document": {
"id": doc.get("document", {}).get("id", ""),
"title": doc.get("document", {}).get("title", ""),
"document_type": doc.get("document", {}).get("document_type", ""),
"metadata": doc.get("document", {}).get("metadata", {})
}
} for i, doc in enumerate(documents)
"document_type": doc.get("document", {}).get(
"document_type", ""
),
"metadata": doc.get("document", {}).get("metadata", {}),
},
}
for i, doc in enumerate(documents)
]
# Rerank documents using the section title
reranked_docs = reranker_service.rerank_documents(rerank_query, reranker_input_docs)
reranked_docs = reranker_service.rerank_documents(
rerank_query, reranker_input_docs
)
# Sort by score in descending order
reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True)
print(f"Reranked {len(reranked_docs)} documents for section: {configuration.sub_section_title}")
except Exception as e:
print(f"Error during reranking: {str(e)}")
# Use original docs if reranking fails
return {
"reranked_documents": reranked_docs
}
async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, Any]:
print(
f"Reranked {len(reranked_docs)} documents for section: {configuration.sub_section_title}"
)
except Exception as e:
print(f"Error during reranking: {e!s}")
# Use original docs if reranking fails
return {"reranked_documents": reranked_docs}
async def write_sub_section(state: State, config: RunnableConfig) -> dict[str, Any]:
"""
Write the sub-section using the provided documents.
This node takes the relevant documents provided in the configuration and uses
an LLM to generate a comprehensive answer to the sub-section title with
proper citations. The citations follow IEEE format using source IDs from the
documents. If no documents are provided, it will use chat history to generate
content.
Returns:
Dict containing the final answer in the "final_answer" key.
"""
from app.services.llm_service import get_user_fast_llm
# Get configuration and relevant documents from configuration
configuration = Configuration.from_runnable_config(config)
documents = state.reranked_documents
user_id = configuration.user_id
# Get user's fast LLM
llm = await get_user_fast_llm(state.db_session, user_id)
if not llm:
error_message = f"No fast LLM configured for user {user_id}"
print(error_message)
raise RuntimeError(error_message)
# Extract configuration data
section_title = configuration.sub_section_title
sub_section_questions = configuration.sub_section_questions
@ -113,18 +122,18 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A
# Format the questions as bullet points for clarity
questions_text = "\n".join([f"- {question}" for question in sub_section_questions])
# Provide context based on the subsection type
section_position_context_map = {
SubSectionType.START: "This is the INTRODUCTION section.",
SubSectionType.MIDDLE: "This is a MIDDLE section. Ensure this content flows naturally from previous sections and into subsequent ones. This could be any middle section in the document, so maintain coherence with the overall structure while addressing the specific topic of this section. Do not provide any conclusions in this section, as conclusions should only appear in the final section.",
SubSectionType.END: "This is the CONCLUSION section. Focus on summarizing key points, providing closure."
SubSectionType.END: "This is the CONCLUSION section. Focus on summarizing key points, providing closure.",
}
section_position_context = section_position_context_map.get(sub_section_type, "")
# Determine if we have documents and optimize for token limits
has_documents_initially = documents and len(documents) > 0
if has_documents_initially:
# Create base message template for token calculation (without documents)
base_human_message_template = f"""
@ -149,38 +158,45 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A
Please write content for this sub-section using the provided source material and cite all information appropriately.
"""
# Use initial system prompt for token calculation
initial_system_prompt = get_citation_system_prompt()
base_messages = state.chat_history + [
base_messages = [
*state.chat_history,
SystemMessage(content=initial_system_prompt),
HumanMessage(content=base_human_message_template)
HumanMessage(content=base_human_message_template),
]
# Optimize documents to fit within token limits
optimized_documents, has_optimized_documents = optimize_documents_for_token_limit(
documents, base_messages, llm.model
optimized_documents, has_optimized_documents = (
optimize_documents_for_token_limit(documents, base_messages, llm.model)
)
# Update state based on optimization result
documents = optimized_documents
has_documents = has_optimized_documents
else:
has_documents = False
# Choose system prompt based on final document availability
system_prompt = get_citation_system_prompt() if has_documents else get_no_documents_system_prompt()
system_prompt = (
get_citation_system_prompt()
if has_documents
else get_no_documents_system_prompt()
)
# Generate documents section
documents_text = format_documents_section(documents, "Source material") if has_documents else ""
documents_text = (
format_documents_section(documents, "Source material") if has_documents else ""
)
# Create final human message content
instruction_text = (
"Please write content for this sub-section using the provided source material and cite all information appropriately."
if has_documents else
"Please write content for this sub-section based on our conversation history and your general knowledge."
if has_documents
else "Please write content for this sub-section based on our conversation history and your general knowledge."
)
human_message_content = f"""
{documents_text}
@ -204,22 +220,20 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A
{instruction_text}
"""
# Create final messages for the LLM
messages_with_chat_history = state.chat_history + [
messages_with_chat_history = [
*state.chat_history,
SystemMessage(content=system_prompt),
HumanMessage(content=human_message_content)
HumanMessage(content=human_message_content),
]
# Log final token count
total_tokens = calculate_token_count(messages_with_chat_history, llm.model)
print(f"Final token count: {total_tokens}")
# Call the LLM and get the response
response = await llm.ainvoke(messages_with_chat_history)
final_answer = response.content
return {
"final_answer": final_answer
}
return {"final_answer": final_answer}

View file

@ -182,4 +182,4 @@ When writing content for a sub-section without access to personal documents:
5. Address the guiding questions through natural content flow without explicitly listing them
6. Suggest how adding relevant sources to SurfSense could enhance future content when appropriate
</user_query_instructions>
"""
"""

View file

@ -3,9 +3,11 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import List, Optional, Any
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
@dataclass
class State:
"""Defines the dynamic state for the agent during execution.
@ -14,11 +16,11 @@ class State:
See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state
for more information.
"""
# Runtime context
db_session: AsyncSession
chat_history: Optional[List[Any]] = field(default_factory=list)
# OUTPUT: Populated by agent nodes
reranked_documents: Optional[List[Any]] = None
final_answer: Optional[str] = None
chat_history: list[Any] | None = field(default_factory=list)
# OUTPUT: Populated by agent nodes
reranked_documents: list[Any] | None = None
final_answer: str | None = None

View file

@ -1,27 +1,37 @@
from typing import List, Dict, Any, Tuple, NamedTuple
from typing import Any, NamedTuple
from langchain_core.messages import BaseMessage
from litellm import get_model_info, token_counter
from pydantic import BaseModel, Field
from litellm import token_counter, get_model_info
class Section(BaseModel):
"""A section in the answer outline."""
section_id: int = Field(..., description="The zero-based index of the section")
section_title: str = Field(..., description="The title of the section")
questions: List[str] = Field(..., description="Questions to research for this section")
questions: list[str] = Field(
..., description="Questions to research for this section"
)
class AnswerOutline(BaseModel):
"""The complete answer outline with all sections."""
answer_outline: List[Section] = Field(..., description="List of sections in the answer outline")
answer_outline: list[Section] = Field(
..., description="List of sections in the answer outline"
)
class DocumentTokenInfo(NamedTuple):
"""Information about a document and its token cost."""
index: int
document: Dict[str, Any]
document: dict[str, Any]
formatted_content: str
token_count: int
def get_connector_emoji(connector_name: str) -> str:
"""Get an appropriate emoji for a connector type."""
connector_emojis = {
@ -34,7 +44,7 @@ def get_connector_emoji(connector_name: str) -> str:
"GITHUB_CONNECTOR": "🐙",
"LINEAR_CONNECTOR": "📊",
"TAVILY_API": "🔍",
"LINKUP_API": "🔗"
"LINKUP_API": "🔗",
}
return connector_emojis.get(connector_name, "🔎")
@ -51,31 +61,26 @@ def get_connector_friendly_name(connector_name: str) -> str:
"GITHUB_CONNECTOR": "GitHub",
"LINEAR_CONNECTOR": "Linear",
"TAVILY_API": "Tavily Search",
"LINKUP_API": "Linkup Search"
"LINKUP_API": "Linkup Search",
}
return connector_friendly_names.get(connector_name, connector_name)
def convert_langchain_messages_to_dict(messages: List[BaseMessage]) -> List[Dict[str, str]]:
def convert_langchain_messages_to_dict(
messages: list[BaseMessage],
) -> list[dict[str, str]]:
"""Convert LangChain messages to format expected by token_counter."""
role_mapping = {
'system': 'system',
'human': 'user',
'ai': 'assistant'
}
role_mapping = {"system": "system", "human": "user", "ai": "assistant"}
converted_messages = []
for msg in messages:
role = role_mapping.get(getattr(msg, 'type', None), 'user')
converted_messages.append({
"role": role,
"content": str(msg.content)
})
role = role_mapping.get(getattr(msg, "type", None), "user")
converted_messages.append({"role": role, "content": str(msg.content)})
return converted_messages
def format_document_for_citation(document: Dict[str, Any]) -> str:
def format_document_for_citation(document: dict[str, Any]) -> str:
"""Format a single document for citation in the standard XML format."""
content = document.get("content", "")
doc_info = document.get("document", {})
@ -93,7 +98,9 @@ def format_document_for_citation(document: Dict[str, Any]) -> str:
</document>"""
def format_documents_section(documents: List[Dict[str, Any]], section_title: str = "Source material") -> str:
def format_documents_section(
documents: list[dict[str, Any]], section_title: str = "Source material"
) -> str:
"""Format multiple documents into a complete documents section."""
if not documents:
return ""
@ -106,7 +113,9 @@ def format_documents_section(documents: List[Dict[str, Any]], section_title: str
</documents>"""
def calculate_document_token_costs(documents: List[Dict[str, Any]], model: str) -> List[DocumentTokenInfo]:
def calculate_document_token_costs(
documents: list[dict[str, Any]], model: str
) -> list[DocumentTokenInfo]:
"""Pre-calculate token costs for each document."""
document_token_info = []
@ -115,24 +124,24 @@ def calculate_document_token_costs(documents: List[Dict[str, Any]], model: str)
# Calculate token count for this document
token_count = token_counter(
messages=[{"role": "user", "content": formatted_doc}],
model=model
messages=[{"role": "user", "content": formatted_doc}], model=model
)
document_token_info.append(DocumentTokenInfo(
index=i,
document=doc,
formatted_content=formatted_doc,
token_count=token_count
))
document_token_info.append(
DocumentTokenInfo(
index=i,
document=doc,
formatted_content=formatted_doc,
token_count=token_count,
)
)
return document_token_info
def find_optimal_documents_with_binary_search(
document_tokens: List[DocumentTokenInfo],
available_tokens: int
) -> List[DocumentTokenInfo]:
document_tokens: list[DocumentTokenInfo], available_tokens: int
) -> list[DocumentTokenInfo]:
"""Use binary search to find the maximum number of documents that fit within token limit."""
if not document_tokens or available_tokens <= 0:
return []
@ -143,8 +152,7 @@ def find_optimal_documents_with_binary_search(
while left <= right:
mid = (left + right) // 2
current_docs = document_tokens[:mid]
current_token_sum = sum(
doc_info.token_count for doc_info in current_docs)
current_token_sum = sum(doc_info.token_count for doc_info in current_docs)
if current_token_sum <= available_tokens:
optimal_docs = current_docs
@ -159,20 +167,18 @@ def get_model_context_window(model_name: str) -> int:
"""Get the total context window size for a model (input + output tokens)."""
try:
model_info = get_model_info(model_name)
context_window = model_info.get(
'max_input_tokens', 4096) # Default fallback
context_window = model_info.get("max_input_tokens", 4096) # Default fallback
return context_window
except Exception as e:
print(
f"Warning: Could not get model info for {model_name}, using default 4096 tokens. Error: {e}")
f"Warning: Could not get model info for {model_name}, using default 4096 tokens. Error: {e}"
)
return 4096 # Conservative fallback
def optimize_documents_for_token_limit(
documents: List[Dict[str, Any]],
base_messages: List[BaseMessage],
model_name: str
) -> Tuple[List[Dict[str, Any]], bool]:
documents: list[dict[str, Any]], base_messages: list[BaseMessage], model_name: str
) -> tuple[list[dict[str, Any]], bool]:
"""
Optimize documents to fit within token limits using binary search.
@ -197,7 +203,8 @@ def optimize_documents_for_token_limit(
available_tokens_for_docs = context_window - base_tokens
print(
f"Token optimization: Context window={context_window}, Base={base_tokens}, Available for docs={available_tokens_for_docs}")
f"Token optimization: Context window={context_window}, Base={base_tokens}, Available for docs={available_tokens_for_docs}"
)
if available_tokens_for_docs <= 0:
print("No tokens available for documents after base content and output buffer")
@ -208,8 +215,7 @@ def optimize_documents_for_token_limit(
# Find optimal number of documents using binary search
optimal_doc_info = find_optimal_documents_with_binary_search(
document_token_info,
available_tokens_for_docs
document_token_info, available_tokens_for_docs
)
# Extract the original document objects
@ -217,12 +223,13 @@ def optimize_documents_for_token_limit(
has_documents_remaining = len(optimized_documents) > 0
print(
f"Token optimization result: Using {len(optimized_documents)}/{len(documents)} documents")
f"Token optimization result: Using {len(optimized_documents)}/{len(documents)} documents"
)
return optimized_documents, has_documents_remaining
def calculate_token_count(messages: List[BaseMessage], model_name: str) -> int:
def calculate_token_count(messages: list[BaseMessage], model_name: str) -> int:
"""Calculate token count for a list of LangChain messages."""
model = model_name
messages_dict = convert_langchain_messages_to_dict(messages)
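
A quick worked example of the prefix search above (token counts are hypothetical): with per-document costs of 900, 700, and 600 tokens and 1,700 tokens left after the base messages, the largest prefix that fits is the first two documents (1,600 tokens), so the third is dropped.

costs = [900, 700, 600]  # hypothetical per-document token counts
available = 1700         # context window minus base messages and output buffer
prefix = [sum(costs[: i + 1]) for i in range(len(costs))]  # [900, 1600, 2200]
kept = max((i + 1 for i, total in enumerate(prefix) if total <= available), default=0)
# kept == 2: find_optimal_documents_with_binary_search lands on the same answer,
# keeping the first two documents and dropping the third.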

View file

@ -2,22 +2,13 @@ from contextlib import asynccontextmanager
from fastapi import Depends, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import User, create_db_and_tables, get_async_session
from app.schemas import UserCreate, UserRead, UserUpdate
from app.routes import router as crud_router
from app.config import config
from app.users import (
SECRET,
auth_backend,
fastapi_users,
current_active_user
)
from app.db import User, create_db_and_tables, get_async_session
from app.routes import router as crud_router
from app.schemas import UserCreate, UserRead, UserUpdate
from app.users import SECRET, auth_backend, current_active_user, fastapi_users
@asynccontextmanager
@ -64,12 +55,10 @@ app.include_router(
if config.AUTH_TYPE == "GOOGLE":
from app.users import google_oauth_client
app.include_router(
fastapi_users.get_oauth_router(
google_oauth_client,
auth_backend,
SECRET,
is_verified_by_default=True
google_oauth_client, auth_backend, SECRET, is_verified_by_default=True
),
prefix="/auth/google",
tags=["auth"],
@ -79,5 +68,8 @@ app.include_router(crud_router, prefix="/api/v1", tags=["crud"])
@app.get("/verify-token")
async def authenticated_route(user: User = Depends(current_active_user), session: AsyncSession = Depends(get_async_session)):
async def authenticated_route(
user: User = Depends(current_active_user),
session: AsyncSession = Depends(get_async_session),
):
return {"message": "Token is valid"}

View file

@ -1,13 +1,11 @@
import os
from pathlib import Path
import shutil
from pathlib import Path
from chonkie import AutoEmbeddings, CodeChunker, RecursiveChunker
from dotenv import load_dotenv
from rerankers import Reranker
# Get the base directory of the project
BASE_DIR = Path(__file__).resolve().parent.parent.parent
@ -18,37 +16,37 @@ load_dotenv(env_file)
def is_ffmpeg_installed():
"""
Check if ffmpeg is installed on the current system.
Returns:
bool: True if ffmpeg is installed, False otherwise.
"""
return shutil.which("ffmpeg") is not None
class Config:
# Check if ffmpeg is installed
if not is_ffmpeg_installed():
import static_ffmpeg
# ffmpeg installed on first call to add_paths(), threadsafe.
static_ffmpeg.add_paths()
# check if ffmpeg is installed again
if not is_ffmpeg_installed():
raise ValueError("FFmpeg is not installed on the system. Please install it to use the Surfsense Podcaster.")
raise ValueError(
"FFmpeg is not installed on the system. Please install it to use the Surfsense Podcaster."
)
# Database
DATABASE_URL = os.getenv("DATABASE_URL")
NEXT_FRONTEND_URL = os.getenv("NEXT_FRONTEND_URL")
# AUTH: Google OAuth
AUTH_TYPE = os.getenv("AUTH_TYPE")
if AUTH_TYPE == "GOOGLE":
GOOGLE_OAUTH_CLIENT_ID = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
GOOGLE_OAUTH_CLIENT_SECRET = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")
# LLM instances are now managed per-user through the LLMConfig system
# Legacy environment variables removed in favor of user-specific configurations
@ -56,12 +54,12 @@ class Config:
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
embedding_model_instance = AutoEmbeddings.get_embeddings(EMBEDDING_MODEL)
chunker_instance = RecursiveChunker(
chunk_size=getattr(embedding_model_instance, 'max_seq_length', 512)
chunk_size=getattr(embedding_model_instance, "max_seq_length", 512)
)
code_chunker_instance = CodeChunker(
chunk_size=getattr(embedding_model_instance, 'max_seq_length', 512)
chunk_size=getattr(embedding_model_instance, "max_seq_length", 512)
)
# Reranker's Configuration | Pinecone, Cohere etc. Read more at https://github.com/AnswerDotAI/rerankers?tab=readme-ov-file#usage
RERANKERS_MODEL_NAME = os.getenv("RERANKERS_MODEL_NAME")
RERANKERS_MODEL_TYPE = os.getenv("RERANKERS_MODEL_TYPE")
@ -69,45 +67,46 @@ class Config:
model_name=RERANKERS_MODEL_NAME,
model_type=RERANKERS_MODEL_TYPE,
)
# OAuth JWT
SECRET_KEY = os.getenv("SECRET_KEY")
# ETL Service
ETL_SERVICE = os.getenv("ETL_SERVICE")
if ETL_SERVICE == "UNSTRUCTURED":
# Unstructured API Key
UNSTRUCTURED_API_KEY = os.getenv("UNSTRUCTURED_API_KEY")
elif ETL_SERVICE == "LLAMACLOUD":
# LlamaCloud API Key
LLAMA_CLOUD_API_KEY = os.getenv("LLAMA_CLOUD_API_KEY")
# Firecrawl API Key
FIRECRAWL_API_KEY = os.getenv("FIRECRAWL_API_KEY", None)
FIRECRAWL_API_KEY = os.getenv("FIRECRAWL_API_KEY", None)
# Litellm TTS Configuration
TTS_SERVICE = os.getenv("TTS_SERVICE")
TTS_SERVICE_API_BASE = os.getenv("TTS_SERVICE_API_BASE")
TTS_SERVICE_API_KEY = os.getenv("TTS_SERVICE_API_KEY")
# Litellm STT Configuration
STT_SERVICE = os.getenv("STT_SERVICE")
STT_SERVICE_API_BASE = os.getenv("STT_SERVICE_API_BASE")
STT_SERVICE_API_KEY = os.getenv("STT_SERVICE_API_KEY")
# Validation Checks
# Check embedding dimension
if hasattr(embedding_model_instance, 'dimension') and embedding_model_instance.dimension > 2000:
if (
hasattr(embedding_model_instance, "dimension")
and embedding_model_instance.dimension > 2000
):
raise ValueError(
f"Embedding dimension for Model: {EMBEDDING_MODEL} "
f"has {embedding_model_instance.dimension} dimensions, which "
f"exceeds the maximum of 2000 allowed by PGVector."
)
@classmethod
def get_settings(cls):
"""Get all settings as a dictionary."""

View file

@ -1,26 +1,25 @@
import os
def _parse_bool(value):
"""Parse boolean value from string."""
return value.lower() == "true" if value else False
def _parse_int(value, var_name):
"""Parse integer value with error handling."""
try:
return int(value)
except ValueError:
raise ValueError(f"Invalid integer value for {var_name}: {value}")
raise ValueError(f"Invalid integer value for {var_name}: {value}") from None
def _parse_headers(value):
"""Parse headers from comma-separated string."""
try:
return [
tuple(h.split(":", 1))
for h in value.split(",")
if ":" in h
]
return [tuple(h.split(":", 1)) for h in value.split(",") if ":" in h]
except Exception:
raise ValueError(f"Invalid headers format: {value}")
raise ValueError(f"Invalid headers format: {value}") from None
def load_uvicorn_config(args=None):
@ -28,16 +27,16 @@ def load_uvicorn_config(args=None):
Load Uvicorn configuration from environment variables and CLI args.
Returns a dict suitable for passing to uvicorn.Config.
"""
config_kwargs = dict(
app="app.app:app",
host=os.getenv("UVICORN_HOST", "0.0.0.0"),
port=int(os.getenv("UVICORN_PORT", 8000)),
log_level=os.getenv("UVICORN_LOG_LEVEL", "info"),
reload=args.reload if args else False,
reload_dirs=["app"] if (args and args.reload) else None,
)
# Configuration mapping for advanced options
config_kwargs = {
"app": "app.app:app",
"host": os.getenv("UVICORN_HOST", "0.0.0.0"),
"port": int(os.getenv("UVICORN_PORT", 8000)),
"log_level": os.getenv("UVICORN_LOG_LEVEL", "info"),
"reload": args.reload if args else False,
"reload_dirs": ["app"] if (args and args.reload) else None,
}
# Configuration mapping for advanced options
config_mapping = {
"UVICORN_PROXY_HEADERS": ("proxy_headers", _parse_bool),
"UVICORN_FORWARDED_ALLOW_IPS": ("forwarded_allow_ips", str),
@ -51,15 +50,33 @@ def load_uvicorn_config(args=None):
"UVICORN_LOG_CONFIG": ("log_config", str),
"UVICORN_SERVER_HEADER": ("server_header", _parse_bool),
"UVICORN_DATE_HEADER": ("date_header", _parse_bool),
"UVICORN_LIMIT_CONCURRENCY": ("limit_concurrency", lambda x: _parse_int(x, "UVICORN_LIMIT_CONCURRENCY")),
"UVICORN_LIMIT_MAX_REQUESTS": ("limit_max_requests", lambda x: _parse_int(x, "UVICORN_LIMIT_MAX_REQUESTS")),
"UVICORN_TIMEOUT_KEEP_ALIVE": ("timeout_keep_alive", lambda x: _parse_int(x, "UVICORN_TIMEOUT_KEEP_ALIVE")),
"UVICORN_TIMEOUT_NOTIFY": ("timeout_notify", lambda x: _parse_int(x, "UVICORN_TIMEOUT_NOTIFY")),
"UVICORN_LIMIT_CONCURRENCY": (
"limit_concurrency",
lambda x: _parse_int(x, "UVICORN_LIMIT_CONCURRENCY"),
),
"UVICORN_LIMIT_MAX_REQUESTS": (
"limit_max_requests",
lambda x: _parse_int(x, "UVICORN_LIMIT_MAX_REQUESTS"),
),
"UVICORN_TIMEOUT_KEEP_ALIVE": (
"timeout_keep_alive",
lambda x: _parse_int(x, "UVICORN_TIMEOUT_KEEP_ALIVE"),
),
"UVICORN_TIMEOUT_NOTIFY": (
"timeout_notify",
lambda x: _parse_int(x, "UVICORN_TIMEOUT_NOTIFY"),
),
"UVICORN_SSL_KEYFILE": ("ssl_keyfile", str),
"UVICORN_SSL_CERTFILE": ("ssl_certfile", str),
"UVICORN_SSL_KEYFILE_PASSWORD": ("ssl_keyfile_password", str),
"UVICORN_SSL_VERSION": ("ssl_version", lambda x: _parse_int(x, "UVICORN_SSL_VERSION")),
"UVICORN_SSL_CERT_REQS": ("ssl_cert_reqs", lambda x: _parse_int(x, "UVICORN_SSL_CERT_REQS")),
"UVICORN_SSL_VERSION": (
"ssl_version",
lambda x: _parse_int(x, "UVICORN_SSL_VERSION"),
),
"UVICORN_SSL_CERT_REQS": (
"ssl_cert_reqs",
lambda x: _parse_int(x, "UVICORN_SSL_CERT_REQS"),
),
"UVICORN_SSL_CA_CERTS": ("ssl_ca_certs", str),
"UVICORN_SSL_CIPHERS": ("ssl_ciphers", str),
"UVICORN_HEADERS": ("headers", _parse_headers),
@@ -76,7 +93,6 @@ def load_uvicorn_config(args=None):
try:
config_kwargs[config_key] = parser(value)
except ValueError as e:
raise ValueError(f"Configuration error for {env_var}: {e}")
raise ValueError(f"Configuration error for {env_var}: {e}") from e
return config_kwargs
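A minimal usage sketch for the loader above, assuming the module is importable (the app.config.uvicorn path is a placeholder, not taken from this commit) and that the relevant UVICORN_* variables are set in the environment.

import uvicorn

from app.config.uvicorn import load_uvicorn_config  # placeholder import path

def main() -> None:
    config_kwargs = load_uvicorn_config()  # reads UVICORN_* environment variables
    server = uvicorn.Server(uvicorn.Config(**config_kwargs))
    server.run()

if __name__ == "__main__":
    main()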

View file

@@ -6,11 +6,12 @@ A module for interacting with Discord's HTTP API to retrieve guilds, channels, a
Requires a Discord bot token.
"""
import asyncio
import datetime
import logging
import discord
from discord.ext import commands
import datetime
import asyncio
logger = logging.getLogger(__name__)
@@ -18,7 +19,7 @@ logger = logging.getLogger(__name__)
class DiscordConnector(commands.Bot):
"""Class for retrieving guild, channel, and message history from Discord."""
def __init__(self, token: str = None):
def __init__(self, token: str | None = None):
"""
Initialize the DiscordConnector with a bot token.
@@ -30,7 +31,9 @@ class DiscordConnector(commands.Bot):
intents.messages = True # Required to fetch messages
intents.message_content = True # Required to read message content
intents.members = True # Required to fetch member information
super().__init__(command_prefix="!", intents=intents) # command_prefix is required but not strictly used here
super().__init__(
command_prefix="!", intents=intents
) # command_prefix is required but not strictly used here
self.token = token
self._bot_task = None # Holds the async bot task
self._is_running = False # Flag to track if the bot is running
@@ -48,7 +51,7 @@ class DiscordConnector(commands.Bot):
@self.event
async def on_disconnect():
logger.debug("Bot disconnected from Discord gateway.")
self._is_running = False # Reset flag on disconnect
self._is_running = False # Reset flag on disconnect
@self.event
async def on_resumed():
@@ -63,17 +66,23 @@ class DiscordConnector(commands.Bot):
try:
if self._is_running:
logger.warning("Bot is already running. Use close_bot() to stop it before starting again.")
logger.warning(
"Bot is already running. Use close_bot() to stop it before starting again."
)
return
await self.start(self.token)
logger.info("Discord bot started successfully.")
except discord.LoginFailure:
logger.error("Failed to log in: Invalid token was provided. Please check your bot token.")
logger.error(
"Failed to log in: Invalid token was provided. Please check your bot token."
)
self._is_running = False
raise
except discord.PrivilegedIntentsRequired as e:
logger.error(f"Privileged Intents Required: {e}. Make sure all required intents are enabled in your bot's application page.")
logger.error(
f"Privileged Intents Required: {e}. Make sure all required intents are enabled in your bot's application page."
)
self._is_running = False
raise
except discord.ConnectionClosed as e:
@@ -96,7 +105,6 @@ class DiscordConnector(commands.Bot):
else:
logger.info("Bot is not running or already disconnected.")
def set_token(self, token: str) -> None:
"""
Set the discord bot token.
@@ -106,8 +114,10 @@
"""
logger.info("Setting Discord bot token.")
self.token = token
logger.info("Token set successfully. You can now start the bot with start_bot().")
logger.info(
"Token set successfully. You can now start the bot with start_bot()."
)
async def _wait_until_ready(self):
"""Helper to wait until the bot is connected and ready."""
logger.info("Waiting for the bot to be ready...")
@@ -115,16 +125,20 @@
# Give the event loop a chance to switch to the bot's startup task.
# This allows self.start() to begin initializing the client.
# Terrible solution, but necessary to avoid blocking the event loop.
await asyncio.sleep(1) # Yield control to the event loop
await asyncio.sleep(1) # Yield control to the event loop
try:
await asyncio.wait_for(self.wait_until_ready(), timeout=60.0)
logger.info("Bot is ready.")
except asyncio.TimeoutError:
logger.error(f"Bot did not become ready within 60 seconds. Connection may have failed.")
except TimeoutError:
logger.error(
"Bot did not become ready within 60 seconds. Connection may have failed."
)
raise
except Exception as e:
logger.error(f"An unexpected error occurred while waiting for the bot to be ready: {e}")
logger.error(
f"An unexpected error occurred while waiting for the bot to be ready: {e}"
)
raise
async def get_guilds(self) -> list[dict]:
@@ -143,7 +157,9 @@ class DiscordConnector(commands.Bot):
guilds_data = []
for guild in self.guilds:
member_count = guild.member_count if guild.member_count is not None else "N/A"
member_count = (
guild.member_count if guild.member_count is not None else "N/A"
)
guilds_data.append(
{
"id": str(guild.id),
@@ -183,15 +199,17 @@ class DiscordConnector(commands.Bot):
channels_data.append(
{"id": str(channel.id), "name": channel.name, "type": "text"}
)
logger.info(f"Fetched {len(channels_data)} text channels from guild {guild_id}.")
logger.info(
f"Fetched {len(channels_data)} text channels from guild {guild_id}."
)
return channels_data
async def get_channel_history(
self,
channel_id: str,
start_date: str = None,
end_date: str = None,
start_date: str | None = None,
end_date: str | None = None,
) -> list[dict]:
"""
Fetch message history from a text channel.
@@ -227,20 +245,26 @@ class DiscordConnector(commands.Bot):
if start_date:
try:
start_datetime = datetime.datetime.fromisoformat(start_date).replace(tzinfo=datetime.timezone.utc)
start_datetime = datetime.datetime.fromisoformat(start_date).replace(
tzinfo=datetime.UTC
)
after = start_datetime
except ValueError:
logger.warning(f"Invalid start_date format: {start_date}. Ignoring.")
if end_date:
try:
end_datetime = datetime.datetime.fromisoformat(f"{end_date}").replace(tzinfo=datetime.timezone.utc)
end_datetime = datetime.datetime.fromisoformat(f"{end_date}").replace(
tzinfo=datetime.UTC
)
before = end_datetime
except ValueError:
logger.warning(f"Invalid end_date format: {end_date}. Ignoring.")
try:
async for message in channel.history(limit=None, before=before, after=after):
async for message in channel.history(
limit=None, before=before, after=after
):
messages_data.append(
{
"id": str(message.id),
@@ -251,12 +275,14 @@
}
)
except discord.Forbidden:
logger.error(f"Bot does not have permissions to read message history in channel {channel_id}.")
logger.error(
f"Bot does not have permissions to read message history in channel {channel_id}."
)
raise
except discord.HTTPException as e:
logger.error(f"Failed to fetch messages from channel {channel_id}: {e}")
return []
logger.info(f"Fetched {len(messages_data)} messages from channel {channel_id}.")
return messages_data
@@ -278,7 +304,9 @@ class DiscordConnector(commands.Bot):
permissions to view members.
"""
await self._wait_until_ready()
logger.info(f"Fetching user info for user ID: {user_id} in guild ID: {guild_id}")
logger.info(
f"Fetching user info for user ID: {user_id} in guild ID: {guild_id}"
)
guild = self.get_guild(int(guild_id))
if not guild:
@@ -294,7 +322,9 @@ class DiscordConnector(commands.Bot):
return {
"id": str(member.id),
"name": member.name,
"joined_at": member.joined_at.isoformat() if member.joined_at else None,
"joined_at": member.joined_at.isoformat()
if member.joined_at
else None,
"roles": roles,
}
logger.warning(f"User {user_id} not found in guild {guild_id}.")
@@ -303,8 +333,12 @@
logger.warning(f"User {user_id} not found in guild {guild_id}.")
return None
except discord.Forbidden:
logger.error(f"Bot does not have permissions to fetch members in guild {guild_id}. Ensure GUILD_MEMBERS intent is enabled.")
logger.error(
f"Bot does not have permissions to fetch members in guild {guild_id}. Ensure GUILD_MEMBERS intent is enabled."
)
raise
except discord.HTTPException as e:
logger.error(f"Failed to fetch user info for {user_id} in guild {guild_id}: {e}")
logger.error(
f"Failed to fetch user info for {user_id} in guild {guild_id}: {e}"
)
return None
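A minimal usage sketch for DiscordConnector; the import path is assumed by analogy with the other connectors, the token is a placeholder, and close_bot() is assumed to be a coroutine (the diff only shows it referenced in a log message).

import asyncio

from surfsense_backend.app.connectors.discord_connector import (  # assumed path
    DiscordConnector,
)

async def main() -> None:
    connector = DiscordConnector(token="YOUR_BOT_TOKEN")  # placeholder token
    bot_task = asyncio.create_task(connector.start_bot())  # run the gateway connection in the background
    try:
        guilds = await connector.get_guilds()  # waits internally until the bot is ready
        for guild in guilds:
            print(guild)
    finally:
        await connector.close_bot()  # assumed coroutine that disconnects the bot
        await bot_task

asyncio.run(main())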

View file

@@ -1,54 +1,91 @@
import base64
import logging
from typing import List, Optional, Dict, Any
from github3 import login as github_login, exceptions as github_exceptions
from github3.repos.contents import Contents
from typing import Any
from github3 import exceptions as github_exceptions, login as github_login
from github3.exceptions import ForbiddenError, NotFoundError
from github3.repos.contents import Contents
logger = logging.getLogger(__name__)
# List of common code file extensions to target
CODE_EXTENSIONS = {
'.py', '.js', '.jsx', '.ts', '.tsx', '.java', '.c', '.cpp', '.h', '.hpp',
'.cs', '.go', '.rb', '.php', '.swift', '.kt', '.scala', '.rs', '.m',
'.sh', '.bash', '.ps1', '.lua', '.pl', '.pm', '.r', '.dart', '.sql'
".py",
".js",
".jsx",
".ts",
".tsx",
".java",
".c",
".cpp",
".h",
".hpp",
".cs",
".go",
".rb",
".php",
".swift",
".kt",
".scala",
".rs",
".m",
".sh",
".bash",
".ps1",
".lua",
".pl",
".pm",
".r",
".dart",
".sql",
}
# List of common documentation/text file extensions
DOC_EXTENSIONS = {
'.md', '.txt', '.rst', '.adoc', '.html', '.htm', '.xml', '.json', '.yaml', '.yml', '.toml'
".md",
".txt",
".rst",
".adoc",
".html",
".htm",
".xml",
".json",
".yaml",
".yml",
".toml",
}
# Maximum file size in bytes (e.g., 1MB)
MAX_FILE_SIZE = 1 * 1024 * 1024
class GitHubConnector:
"""Connector for interacting with the GitHub API."""
# Directories to skip during file traversal
SKIPPED_DIRS = {
# Version control
'.git',
".git",
# Dependencies
'node_modules',
'vendor',
"node_modules",
"vendor",
# Build artifacts / Caches
'build',
'dist',
'target',
'__pycache__',
"build",
"dist",
"target",
"__pycache__",
# Virtual environments
'venv',
'.venv',
'env',
"venv",
".venv",
"env",
# IDE/Editor config
'.vscode',
'.idea',
'.project',
'.settings',
".vscode",
".idea",
".project",
".settings",
# Temporary / Logs
'tmp',
'logs',
"tmp",
"logs",
# Add other project-specific irrelevant directories if needed
}
@@ -68,35 +105,39 @@ class GitHubConnector:
logger.info("Successfully authenticated with GitHub API.")
except (github_exceptions.AuthenticationFailed, ForbiddenError) as e:
logger.error(f"GitHub authentication failed: {e}")
raise ValueError("Invalid GitHub token or insufficient permissions.")
raise ValueError("Invalid GitHub token or insufficient permissions.") from e
except Exception as e:
logger.error(f"Failed to initialize GitHub client: {e}")
raise
raise e
def get_user_repositories(self) -> List[Dict[str, Any]]:
def get_user_repositories(self) -> list[dict[str, Any]]:
"""Fetches repositories accessible by the authenticated user."""
repos_data = []
try:
# type='owner' fetches repos owned by the user
# type='member' fetches repos the user is a collaborator on (including orgs)
# type='all' fetches both
for repo in self.gh.repositories(type='all', sort='updated'):
repos_data.append({
"id": repo.id,
"name": repo.name,
"full_name": repo.full_name,
"private": repo.private,
"url": repo.html_url,
"description": repo.description or "",
"last_updated": repo.updated_at if repo.updated_at else None,
})
for repo in self.gh.repositories(type="all", sort="updated"):
repos_data.append(
{
"id": repo.id,
"name": repo.name,
"full_name": repo.full_name,
"private": repo.private,
"url": repo.html_url,
"description": repo.description or "",
"last_updated": repo.updated_at if repo.updated_at else None,
}
)
logger.info(f"Fetched {len(repos_data)} repositories.")
return repos_data
except Exception as e:
logger.error(f"Failed to fetch GitHub repositories: {e}")
return [] # Return empty list on error
return [] # Return empty list on error
def get_repository_files(self, repo_full_name: str, path: str = '') -> List[Dict[str, Any]]:
def get_repository_files(
self, repo_full_name: str, path: str = ""
) -> list[dict[str, Any]]:
"""
Recursively fetches details of relevant files (code, docs) within a repository path.
@@ -110,54 +151,72 @@
"""
files_list = []
try:
owner, repo_name = repo_full_name.split('/')
owner, repo_name = repo_full_name.split("/")
repo = self.gh.repository(owner, repo_name)
if not repo:
logger.warning(f"Repository '{repo_full_name}' not found.")
return []
contents = repo.directory_contents(directory_path=path) # Use directory_contents for clarity
contents = repo.directory_contents(
directory_path=path
) # Use directory_contents for clarity
# contents returns a list of tuples (name, content_obj)
for item_name, content_item in contents:
for _item_name, content_item in contents:
if not isinstance(content_item, Contents):
continue
if content_item.type == 'dir':
if content_item.type == "dir":
# Check if the directory name is in the skipped list
if content_item.name in self.SKIPPED_DIRS:
logger.debug(f"Skipping directory: {content_item.path}")
continue # Skip recursion for this directory
continue # Skip recursion for this directory
# Recursively fetch contents of subdirectory
files_list.extend(self.get_repository_files(repo_full_name, path=content_item.path))
elif content_item.type == 'file':
files_list.extend(
self.get_repository_files(
repo_full_name, path=content_item.path
)
)
elif content_item.type == "file":
# Check if the file extension is relevant and size is within limits
file_extension = '.' + content_item.name.split('.')[-1].lower() if '.' in content_item.name else ''
file_extension = (
"." + content_item.name.split(".")[-1].lower()
if "." in content_item.name
else ""
)
is_code = file_extension in CODE_EXTENSIONS
is_doc = file_extension in DOC_EXTENSIONS
if (is_code or is_doc) and content_item.size <= MAX_FILE_SIZE:
files_list.append({
"path": content_item.path,
"sha": content_item.sha,
"url": content_item.html_url,
"size": content_item.size,
"type": "code" if is_code else "doc"
})
files_list.append(
{
"path": content_item.path,
"sha": content_item.sha,
"url": content_item.html_url,
"size": content_item.size,
"type": "code" if is_code else "doc",
}
)
elif content_item.size > MAX_FILE_SIZE:
logger.debug(f"Skipping large file: {content_item.path} ({content_item.size} bytes)")
logger.debug(
f"Skipping large file: {content_item.path} ({content_item.size} bytes)"
)
else:
logger.debug(f"Skipping irrelevant file type: {content_item.path}")
logger.debug(
f"Skipping irrelevant file type: {content_item.path}"
)
except (NotFoundError, ForbiddenError) as e:
logger.warning(f"Cannot access path '{path}' in '{repo_full_name}': {e}")
logger.warning(f"Cannot access path '{path}' in '{repo_full_name}': {e}")
except Exception as e:
logger.error(f"Failed to get files for {repo_full_name} at path '{path}': {e}")
logger.error(
f"Failed to get files for {repo_full_name} at path '{path}': {e}"
)
# Return what we have collected so far in case of partial failure
return files_list
def get_file_content(self, repo_full_name: str, file_path: str) -> Optional[str]:
def get_file_content(self, repo_full_name: str, file_path: str) -> str | None:
"""
Fetches the decoded content of a specific file.
@@ -169,43 +228,69 @@
The decoded file content as a string, or None if fetching fails or file is too large.
"""
try:
owner, repo_name = repo_full_name.split('/')
owner, repo_name = repo_full_name.split("/")
repo = self.gh.repository(owner, repo_name)
if not repo:
logger.warning(f"Repository '{repo_full_name}' not found when fetching file '{file_path}'.")
logger.warning(
f"Repository '{repo_full_name}' not found when fetching file '{file_path}'."
)
return None
content_item = repo.file_contents(path=file_path) # Use file_contents for clarity
if not content_item or not isinstance(content_item, Contents) or content_item.type != 'file':
logger.warning(f"File '{file_path}' not found or is not a file in '{repo_full_name}'.")
content_item = repo.file_contents(
path=file_path
) # Use file_contents for clarity
if (
not content_item
or not isinstance(content_item, Contents)
or content_item.type != "file"
):
logger.warning(
f"File '{file_path}' not found or is not a file in '{repo_full_name}'."
)
return None
if content_item.size > MAX_FILE_SIZE:
logger.warning(f"File '{file_path}' in '{repo_full_name}' exceeds max size ({content_item.size} > {MAX_FILE_SIZE}). Skipping content fetch.")
logger.warning(
f"File '{file_path}' in '{repo_full_name}' exceeds max size ({content_item.size} > {MAX_FILE_SIZE}). Skipping content fetch."
)
return None
# Content is base64 encoded
if content_item.content:
try:
decoded_content = base64.b64decode(content_item.content).decode('utf-8')
decoded_content = base64.b64decode(content_item.content).decode(
"utf-8"
)
return decoded_content
except UnicodeDecodeError:
logger.warning(f"Could not decode file '{file_path}' in '{repo_full_name}' as UTF-8. Trying with 'latin-1'.")
logger.warning(
f"Could not decode file '{file_path}' in '{repo_full_name}' as UTF-8. Trying with 'latin-1'."
)
try:
# Try a fallback encoding
decoded_content = base64.b64decode(content_item.content).decode('latin-1')
decoded_content = base64.b64decode(content_item.content).decode(
"latin-1"
)
return decoded_content
except Exception as decode_err:
logger.error(f"Failed to decode file '{file_path}' with fallback encoding: {decode_err}")
return None # Give up if fallback fails
logger.error(
f"Failed to decode file '{file_path}' with fallback encoding: {decode_err}"
)
return None # Give up if fallback fails
else:
logger.warning(f"No content returned for file '{file_path}' in '{repo_full_name}'. It might be empty.")
return "" # Return empty string for empty files
logger.warning(
f"No content returned for file '{file_path}' in '{repo_full_name}'. It might be empty."
)
return "" # Return empty string for empty files
except (NotFoundError, ForbiddenError) as e:
logger.warning(f"Cannot access file '{file_path}' in '{repo_full_name}': {e}")
return None
logger.warning(
f"Cannot access file '{file_path}' in '{repo_full_name}': {e}"
)
return None
except Exception as e:
logger.error(f"Failed to get content for file '{file_path}' in '{repo_full_name}': {e}")
return None
logger.error(
f"Failed to get content for file '{file_path}' in '{repo_full_name}': {e}"
)
return None
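A minimal usage sketch for GitHubConnector, using the same import path as the test module further down; the token and the choice of repository are placeholders.

from surfsense_backend.app.connectors.github_connector import GitHubConnector

connector = GitHubConnector(token="ghp_placeholder_token")  # placeholder token
repos = connector.get_user_repositories()
if repos:
    full_name = repos[0]["full_name"]
    files = connector.get_repository_files(full_name)  # walks the tree, skipping SKIPPED_DIRS
    for entry in files[:5]:
        content = connector.get_file_content(full_name, entry["path"])
        print(entry["path"], len(content or ""))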

View file

@@ -5,96 +5,94 @@ A module for retrieving issues and comments from Linear.
Allows fetching issue lists and their comments with date range filtering.
"""
import requests
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Any, Union
from typing import Any
import requests
class LinearConnector:
"""Class for retrieving issues and comments from Linear."""
def __init__(self, token: str = None):
def __init__(self, token: str | None = None):
"""
Initialize the LinearConnector class.
Args:
token: Linear API token (optional, can be set later with set_token)
"""
self.token = token
self.api_url = "https://api.linear.app/graphql"
def set_token(self, token: str) -> None:
"""
Set the Linear API token.
Args:
token: Linear API token
"""
self.token = token
def get_headers(self) -> Dict[str, str]:
def get_headers(self) -> dict[str, str]:
"""
Get headers for Linear API requests.
Returns:
Dictionary of headers
Raises:
ValueError: If no Linear token has been set
"""
if not self.token:
raise ValueError("Linear token not initialized. Call set_token() first.")
return {
'Content-Type': 'application/json',
'Authorization': self.token
}
def execute_graphql_query(self, query: str, variables: Dict[str, Any] = None) -> Dict[str, Any]:
return {"Content-Type": "application/json", "Authorization": self.token}
def execute_graphql_query(
self, query: str, variables: dict[str, Any] | None = None
) -> dict[str, Any]:
"""
Execute a GraphQL query against the Linear API.
Args:
query: GraphQL query string
variables: Variables for the GraphQL query (optional)
Returns:
Response data from the API
Raises:
ValueError: If no Linear token has been set
Exception: If the API request fails
"""
if not self.token:
raise ValueError("Linear token not initialized. Call set_token() first.")
headers = self.get_headers()
payload = {'query': query}
payload = {"query": query}
if variables:
payload['variables'] = variables
response = requests.post(
self.api_url,
headers=headers,
json=payload
)
payload["variables"] = variables
response = requests.post(self.api_url, headers=headers, json=payload)
if response.status_code == 200:
return response.json()
else:
raise Exception(f"Query failed with status code {response.status_code}: {response.text}")
def get_all_issues(self, include_comments: bool = True) -> List[Dict[str, Any]]:
raise Exception(
f"Query failed with status code {response.status_code}: {response.text}"
)
def get_all_issues(self, include_comments: bool = True) -> list[dict[str, Any]]:
"""
Fetch all issues from Linear.
Args:
include_comments: Whether to include comments in the response
Returns:
List of issue objects
Raises:
ValueError: If no Linear token has been set
Exception: If the API request fails
@@ -116,7 +114,7 @@ class LinearConnector:
}
}
"""
query = f"""
query {{
issues {{
@ -147,29 +145,30 @@ class LinearConnector:
}}
}}
"""
result = self.execute_graphql_query(query)
# Extract issues from the response
if "data" in result and "issues" in result["data"] and "nodes" in result["data"]["issues"]:
if (
"data" in result
and "issues" in result["data"]
and "nodes" in result["data"]["issues"]
):
return result["data"]["issues"]["nodes"]
return []
def get_issues_by_date_range(
self,
start_date: str,
end_date: str,
include_comments: bool = True
) -> Tuple[List[Dict[str, Any]], Optional[str]]:
self, start_date: str, end_date: str, include_comments: bool = True
) -> tuple[list[dict[str, Any]], str | None]:
"""
Fetch issues within a date range.
Args:
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format (inclusive)
include_comments: Whether to include comments in the response
Returns:
Tuple containing (issues list, error message or None)
"""
@@ -194,7 +193,7 @@ class LinearConnector:
}
}
"""
# Query issues that were either created OR updated within the date range
# This ensures we catch both new issues and updated existing issues
query = f"""
@@ -250,58 +249,65 @@
}}
}}
"""
try:
all_issues = []
has_next_page = True
cursor = None
# Handle pagination to get all issues
while has_next_page:
variables = {"after": cursor} if cursor else {}
result = self.execute_graphql_query(query, variables)
# Check for errors
if "errors" in result:
error_message = "; ".join([error.get("message", "Unknown error") for error in result["errors"]])
error_message = "; ".join(
[
error.get("message", "Unknown error")
for error in result["errors"]
]
)
return [], f"GraphQL errors: {error_message}"
# Extract issues from the response
if "data" in result and "issues" in result["data"]:
issues_page = result["data"]["issues"]
# Add issues from this page
if "nodes" in issues_page:
all_issues.extend(issues_page["nodes"])
# Check if there are more pages
if "pageInfo" in issues_page:
page_info = issues_page["pageInfo"]
has_next_page = page_info.get("hasNextPage", False)
cursor = page_info.get("endCursor") if has_next_page else None
cursor = (
page_info.get("endCursor") if has_next_page else None
)
else:
has_next_page = False
else:
has_next_page = False
if not all_issues:
return [], "No issues found in the specified date range."
return all_issues, None
except Exception as e:
return [], f"Error fetching issues: {str(e)}"
return [], f"Error fetching issues: {e!s}"
except ValueError as e:
return [], f"Invalid date format: {str(e)}. Please use YYYY-MM-DD."
def format_issue(self, issue: Dict[str, Any]) -> Dict[str, Any]:
return [], f"Invalid date format: {e!s}. Please use YYYY-MM-DD."
def format_issue(self, issue: dict[str, Any]) -> dict[str, Any]:
"""
Format an issue for easier consumption.
Args:
issue: The issue object from Linear API
Returns:
Formatted issue dictionary
"""
@@ -311,23 +317,37 @@ class LinearConnector:
"identifier": issue.get("identifier", ""),
"title": issue.get("title", ""),
"description": issue.get("description", ""),
"state": issue.get("state", {}).get("name", "Unknown") if issue.get("state") else "Unknown",
"state_type": issue.get("state", {}).get("type", "Unknown") if issue.get("state") else "Unknown",
"state": issue.get("state", {}).get("name", "Unknown")
if issue.get("state")
else "Unknown",
"state_type": issue.get("state", {}).get("type", "Unknown")
if issue.get("state")
else "Unknown",
"created_at": issue.get("createdAt", ""),
"updated_at": issue.get("updatedAt", ""),
"creator": {
"id": issue.get("creator", {}).get("id", "") if issue.get("creator") else "",
"name": issue.get("creator", {}).get("name", "Unknown") if issue.get("creator") else "Unknown",
"email": issue.get("creator", {}).get("email", "") if issue.get("creator") else ""
} if issue.get("creator") else {"id": "", "name": "Unknown", "email": ""},
"id": issue.get("creator", {}).get("id", "")
if issue.get("creator")
else "",
"name": issue.get("creator", {}).get("name", "Unknown")
if issue.get("creator")
else "Unknown",
"email": issue.get("creator", {}).get("email", "")
if issue.get("creator")
else "",
}
if issue.get("creator")
else {"id": "", "name": "Unknown", "email": ""},
"assignee": {
"id": issue.get("assignee", {}).get("id", ""),
"name": issue.get("assignee", {}).get("name", "Unknown"),
"email": issue.get("assignee", {}).get("email", "")
} if issue.get("assignee") else None,
"comments": []
"email": issue.get("assignee", {}).get("email", ""),
}
if issue.get("assignee")
else None,
"comments": [],
}
# Extract comments if available
if "comments" in issue and "nodes" in issue["comments"]:
for comment in issue["comments"]["nodes"]:
@@ -337,85 +357,93 @@ class LinearConnector:
"created_at": comment.get("createdAt", ""),
"updated_at": comment.get("updatedAt", ""),
"user": {
"id": comment.get("user", {}).get("id", "") if comment.get("user") else "",
"name": comment.get("user", {}).get("name", "Unknown") if comment.get("user") else "Unknown",
"email": comment.get("user", {}).get("email", "") if comment.get("user") else ""
} if comment.get("user") else {"id": "", "name": "Unknown", "email": ""}
"id": comment.get("user", {}).get("id", "")
if comment.get("user")
else "",
"name": comment.get("user", {}).get("name", "Unknown")
if comment.get("user")
else "Unknown",
"email": comment.get("user", {}).get("email", "")
if comment.get("user")
else "",
}
if comment.get("user")
else {"id": "", "name": "Unknown", "email": ""},
}
formatted["comments"].append(formatted_comment)
return formatted
def format_issue_to_markdown(self, issue: Dict[str, Any]) -> str:
def format_issue_to_markdown(self, issue: dict[str, Any]) -> str:
"""
Convert an issue to markdown format.
Args:
issue: The issue object (either raw or formatted)
Returns:
Markdown string representation of the issue
"""
# Format the issue if it's not already formatted
if "identifier" not in issue:
issue = self.format_issue(issue)
# Build the markdown content
markdown = f"# {issue.get('identifier', 'No ID')}: {issue.get('title', 'No Title')}\n\n"
if issue.get('state'):
if issue.get("state"):
markdown += f"**Status:** {issue['state']}\n\n"
if issue.get('assignee') and issue['assignee'].get('name'):
if issue.get("assignee") and issue["assignee"].get("name"):
markdown += f"**Assignee:** {issue['assignee']['name']}\n"
if issue.get('creator') and issue['creator'].get('name'):
if issue.get("creator") and issue["creator"].get("name"):
markdown += f"**Created by:** {issue['creator']['name']}\n"
if issue.get('created_at'):
created_date = self.format_date(issue['created_at'])
if issue.get("created_at"):
created_date = self.format_date(issue["created_at"])
markdown += f"**Created:** {created_date}\n"
if issue.get('updated_at'):
updated_date = self.format_date(issue['updated_at'])
if issue.get("updated_at"):
updated_date = self.format_date(issue["updated_at"])
markdown += f"**Updated:** {updated_date}\n\n"
if issue.get('description'):
if issue.get("description"):
markdown += f"## Description\n\n{issue['description']}\n\n"
if issue.get('comments'):
if issue.get("comments"):
markdown += f"## Comments ({len(issue['comments'])})\n\n"
for comment in issue['comments']:
for comment in issue["comments"]:
user_name = "Unknown"
if comment.get('user') and comment['user'].get('name'):
user_name = comment['user']['name']
if comment.get("user") and comment["user"].get("name"):
user_name = comment["user"]["name"]
comment_date = "Unknown date"
if comment.get('created_at'):
comment_date = self.format_date(comment['created_at'])
if comment.get("created_at"):
comment_date = self.format_date(comment["created_at"])
markdown += f"### {user_name} ({comment_date})\n\n{comment.get('body', '')}\n\n---\n\n"
return markdown
@staticmethod
def format_date(iso_date: str) -> str:
"""
Format an ISO date string to a more readable format.
Args:
iso_date: ISO format date string
Returns:
Formatted date string
"""
if not iso_date or not isinstance(iso_date, str):
return "Unknown date"
try:
dt = datetime.fromisoformat(iso_date.replace('Z', '+00:00'))
return dt.strftime('%Y-%m-%d %H:%M:%S')
dt = datetime.fromisoformat(iso_date.replace("Z", "+00:00"))
return dt.strftime("%Y-%m-%d %H:%M:%S")
except ValueError:
return iso_date
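A minimal usage sketch for LinearConnector; the import path and API key are placeholders, and the date range is only an example.

from surfsense_backend.app.connectors.linear_connector import (  # assumed path
    LinearConnector,
)

connector = LinearConnector(token="lin_api_placeholder")  # placeholder API key
issues, error = connector.get_issues_by_date_range("2024-01-01", "2024-01-31")
if error:
    print(error)
else:
    for issue in issues:
        print(connector.format_issue_to_markdown(issue))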

View file

@@ -1,176 +1,182 @@
from notion_client import Client
class NotionHistoryConnector:
def __init__(self, token):
"""
Initialize the NotionPageFetcher with a token.
Args:
token (str): Notion integration token
"""
self.notion = Client(auth=token)
def get_all_pages(self, start_date=None, end_date=None):
"""
Fetches all pages shared with your integration and their content.
Args:
start_date (str, optional): ISO 8601 date string (e.g., "2023-01-01T00:00:00Z")
end_date (str, optional): ISO 8601 date string (e.g., "2023-12-31T23:59:59Z")
Returns:
list: List of dictionaries containing page data
"""
# Build the filter for the search
# Note: Notion API requires specific filter structure
search_params = {}
# Filter for pages only (not databases)
search_params["filter"] = {
"value": "page",
"property": "object"
}
search_params["filter"] = {"value": "page", "property": "object"}
# Add date filters if provided
if start_date or end_date:
date_filter = {}
if start_date:
date_filter["on_or_after"] = start_date
if end_date:
date_filter["on_or_before"] = end_date
# Add the date filter to the search params
if date_filter:
search_params["sort"] = {
"direction": "descending",
"timestamp": "last_edited_time"
"timestamp": "last_edited_time",
}
# First, get a list of all pages the integration has access to
search_results = self.notion.search(**search_params)
pages = search_results["results"]
all_page_data = []
for page in pages:
page_id = page["id"]
# Get detailed page information
page_content = self.get_page_content(page_id)
all_page_data.append({
"page_id": page_id,
"title": self.get_page_title(page),
"content": page_content
})
all_page_data.append(
{
"page_id": page_id,
"title": self.get_page_title(page),
"content": page_content,
}
)
return all_page_data
def get_page_title(self, page):
"""
Extracts the title from a page object.
Args:
page (dict): Notion page object
Returns:
str: Page title or a fallback string
"""
# Title can be in different properties depending on the page type
if "properties" in page:
# Try to find a title property
for prop_name, prop_data in page["properties"].items():
for _prop_name, prop_data in page["properties"].items():
if prop_data["type"] == "title" and len(prop_data["title"]) > 0:
return " ".join([text_obj["plain_text"] for text_obj in prop_data["title"]])
return " ".join(
[text_obj["plain_text"] for text_obj in prop_data["title"]]
)
# If no title found, return the page ID as fallback
return f"Untitled page ({page['id']})"
def get_page_content(self, page_id):
"""
Fetches the content (blocks) of a specific page.
Args:
page_id (str): The ID of the page to fetch
Returns:
list: List of processed blocks from the page
"""
blocks = []
has_more = True
cursor = None
# Paginate through all blocks
while has_more:
if cursor:
response = self.notion.blocks.children.list(block_id=page_id, start_cursor=cursor)
response = self.notion.blocks.children.list(
block_id=page_id, start_cursor=cursor
)
else:
response = self.notion.blocks.children.list(block_id=page_id)
blocks.extend(response["results"])
has_more = response["has_more"]
if has_more:
cursor = response["next_cursor"]
# Process nested blocks recursively
processed_blocks = []
for block in blocks:
processed_block = self.process_block(block)
processed_blocks.append(processed_block)
return processed_blocks
def process_block(self, block):
"""
Processes a block and recursively fetches any child blocks.
Args:
block (dict): The block to process
Returns:
dict: Processed block with content and children
"""
block_id = block["id"]
block_type = block["type"]
# Extract block content based on its type
content = self.extract_block_content(block)
# Check if block has children
has_children = block.get("has_children", False)
child_blocks = []
if has_children:
# Fetch and process child blocks
children_response = self.notion.blocks.children.list(block_id=block_id)
for child_block in children_response["results"]:
child_blocks.append(self.process_block(child_block))
return {
"id": block_id,
"type": block_type,
"content": content,
"children": child_blocks
"children": child_blocks,
}
def extract_block_content(self, block):
"""
Extracts the content from a block based on its type.
Args:
block (dict): The block to extract content from
Returns:
str: Extracted content as a string
"""
block_type = block["type"]
# Different block types have different structures
if block_type in block and "rich_text" in block[block_type]:
return "".join([text_obj["plain_text"] for text_obj in block[block_type]["rich_text"]])
return "".join(
[text_obj["plain_text"] for text_obj in block[block_type]["rich_text"]]
)
elif block_type == "image":
# Instead of returning the raw URL which may contain sensitive AWS credentials,
# return a placeholder or reference to the image
@@ -183,18 +189,21 @@ class NotionHistoryConnector:
# Only return the domain part of external URLs to avoid potential sensitive parameters
try:
from urllib.parse import urlparse
parsed_url = urlparse(url)
return f"[External Image from {parsed_url.netloc}]"
except:
except Exception:
return "[External Image]"
elif block_type == "code":
language = block["code"]["language"]
code_text = "".join([text_obj["plain_text"] for text_obj in block["code"]["rich_text"]])
code_text = "".join(
[text_obj["plain_text"] for text_obj in block["code"]["rich_text"]]
)
return f"```{language}\n{code_text}\n```"
elif block_type == "equation":
return block["equation"]["expression"]
# Add more block types as needed
# Return empty string for unsupported block types
return ""
@@ -203,23 +212,23 @@ class NotionHistoryConnector:
# if __name__ == "__main__":
# # Simple example of how to use this module
# import argparse
# parser = argparse.ArgumentParser(description="Fetch Notion pages using an integration token")
# parser.add_argument("--token", help="Your Notion integration token")
# parser.add_argument("--start-date", help="Start date in ISO format (e.g., 2023-01-01T00:00:00Z)")
# parser.add_argument("--end-date", help="End date in ISO format (e.g., 2023-12-31T23:59:59Z)")
# args = parser.parse_args()
# token = args.token
# if not token:
# token = input("Enter your Notion integration token: ")
# fetcher = NotionHistoryConnector(token)
# try:
# pages = fetcher.get_all_pages(args.start_date, args.end_date)
# print(f"Fetched {len(pages)} pages from Notion")
# for page in pages:
# print(f"- {page['title']}")
# except Exception as e:
# print(f"Error: {str(e)}")
# print(f"Error: {str(e)}")

View file

@@ -5,47 +5,48 @@ A module for retrieving conversation history from Slack channels.
Allows fetching channel lists and message history with date range filtering.
"""
import time # Added import
import logging # Added import
import logging # Added import
import time # Added import
from datetime import datetime
from typing import Any
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Any
logger = logging.getLogger(__name__) # Added logger
logger = logging.getLogger(__name__) # Added logger
class SlackHistory:
"""Class for retrieving conversation history from Slack channels."""
def __init__(self, token: str = None):
def __init__(self, token: str | None = None):
"""
Initialize the SlackHistory class.
Args:
token: Slack API token (optional, can be set later with set_token)
"""
self.client = WebClient(token=token) if token else None
def set_token(self, token: str) -> None:
"""
Set the Slack API token.
Args:
token: Slack API token
"""
self.client = WebClient(token=token)
def get_all_channels(self, include_private: bool = True) -> List[Dict[str, Any]]:
def get_all_channels(self, include_private: bool = True) -> list[dict[str, Any]]:
"""
Fetch all channels that the bot has access to, with rate limit handling.
Args:
include_private: Whether to include private channels
Returns:
List of dictionaries, each representing a channel with id, name, is_private, is_member.
Raises:
ValueError: If no Slack client has been initialized
SlackApiError: If there's an unrecoverable error calling the Slack API
@@ -53,8 +54,8 @@
"""
if not self.client:
raise ValueError("Slack client not initialized. Call set_token() first.")
channels_list = [] # Changed from dict to list
channels_list = [] # Changed from dict to list
types = "public_channel"
if include_private:
types += ",private_channel"
@@ -65,16 +66,16 @@
while is_first_request or next_cursor:
try:
if not is_first_request: # Add delay only for paginated requests
logger.info(f"Paginating for channels, waiting 3 seconds before next call. Cursor: {next_cursor}")
logger.info(
f"Paginating for channels, waiting 3 seconds before next call. Cursor: {next_cursor}"
)
time.sleep(3)
current_limit = 1000 # Max limit
api_result = self.client.conversations_list(
types=types,
cursor=next_cursor,
limit=current_limit
types=types, cursor=next_cursor, limit=current_limit
)
channels_on_page = api_result["channels"]
for channel in channels_on_page:
if "name" in channel and "id" in channel:
@@ -86,12 +87,13 @@
# It indicates if the authenticated user (bot) is a member.
# For public channels, this might be true or the API might not focus on it
# if the bot can read it anyway. For private, it's crucial.
"is_member": channel.get("is_member", False)
"is_member": channel.get("is_member", False),
}
channels_list.append(channel_data)
else:
logger.warning(f"Channel found with missing name or id. Data: {channel}")
logger.warning(
f"Channel found with missing name or id. Data: {channel}"
)
next_cursor = api_result.get("response_metadata", {}).get("next_cursor")
is_first_request = False # Subsequent requests are not the first
@@ -101,57 +103,65 @@
except SlackApiError as e:
if e.response is not None and e.response.status_code == 429:
retry_after_header = e.response.headers.get('Retry-After')
retry_after_header = e.response.headers.get("Retry-After")
wait_duration = 60 # Default wait time
if retry_after_header and retry_after_header.isdigit():
wait_duration = int(retry_after_header)
logger.warning(f"Slack API rate limit hit while fetching channels. Waiting for {wait_duration} seconds. Cursor: {next_cursor}")
logger.warning(
f"Slack API rate limit hit while fetching channels. Waiting for {wait_duration} seconds. Cursor: {next_cursor}"
)
time.sleep(wait_duration)
# The loop will continue, retrying with the same cursor
else:
# Not a 429 error, or no response object, re-raise
raise SlackApiError(f"Error retrieving channels: {e}", e.response)
raise SlackApiError(
f"Error retrieving channels: {e}", e.response
) from e
except Exception as general_error:
# Handle other potential errors like network issues if necessary, or re-raise
logger.error(f"An unexpected error occurred during channel fetching: {general_error}")
raise RuntimeError(f"An unexpected error occurred during channel fetching: {general_error}")
logger.error(
f"An unexpected error occurred during channel fetching: {general_error}"
)
raise RuntimeError(
f"An unexpected error occurred during channel fetching: {general_error}"
) from general_error
return channels_list
def get_conversation_history(
self,
channel_id: str,
limit: int = 1000,
oldest: Optional[int] = None,
latest: Optional[int] = None
) -> List[Dict[str, Any]]:
self,
channel_id: str,
limit: int = 1000,
oldest: int | None = None,
latest: int | None = None,
) -> list[dict[str, Any]]:
"""
Fetch conversation history for a channel.
Args:
channel_id: The ID of the channel to fetch history for
limit: Maximum number of messages to return per request (default 1000)
oldest: Start of time range (Unix timestamp)
latest: End of time range (Unix timestamp)
Returns:
List of message objects
Raises:
ValueError: If no Slack client has been initialized
SlackApiError: If there's an error calling the Slack API
"""
if not self.client:
raise ValueError("Slack client not initialized. Call set_token() first.")
messages = []
next_cursor = None
while True:
try:
# Proactive delay for conversations.history (Tier 3)
time.sleep(1.2) # Wait 1.2 seconds before each history call.
time.sleep(1.2) # Wait 1.2 seconds before each history call.
kwargs = {
"channel": channel_id,
@@ -163,16 +173,19 @@ class SlackHistory:
kwargs["latest"] = latest
if next_cursor:
kwargs["cursor"] = next_cursor
current_api_call_successful = False
result = None # Ensure result is defined
result = None # Ensure result is defined
try:
result = self.client.conversations_history(**kwargs)
current_api_call_successful = True
except SlackApiError as e_history:
if e_history.response is not None and e_history.response.status_code == 429:
retry_after_str = e_history.response.headers.get('Retry-After')
wait_time = 60 # Default
if (
e_history.response is not None
and e_history.response.status_code == 429
):
retry_after_str = e_history.response.headers.get("Retry-After")
wait_time = 60 # Default
if retry_after_str and retry_after_str.isdigit():
wait_time = int(retry_after_str)
logger.warning(
@@ -182,47 +195,54 @@ class SlackHistory:
time.sleep(wait_time)
# current_api_call_successful remains False, loop will retry this page
else:
raise # Re-raise to outer handler for not_in_channel or other SlackApiErrors
raise # Re-raise to outer handler for not_in_channel or other SlackApiErrors
if not current_api_call_successful:
continue # Retry the current page fetch due to handled rate limit
continue # Retry the current page fetch due to handled rate limit
# Process result if successful
batch = result["messages"]
messages.extend(batch)
if result.get("has_more", False) and len(messages) < limit:
next_cursor = result["response_metadata"]["next_cursor"]
else:
break # Exit pagination loop
except SlackApiError as e: # Outer catch for not_in_channel or unhandled SlackApiErrors from inner try
if (e.response is not None and
hasattr(e.response, 'data') and
isinstance(e.response.data, dict) and
e.response.data.get('error') == 'not_in_channel'):
break # Exit pagination loop
except SlackApiError as e: # Outer catch for not_in_channel or unhandled SlackApiErrors from inner try
if (
e.response is not None
and hasattr(e.response, "data")
and isinstance(e.response.data, dict)
and e.response.data.get("error") == "not_in_channel"
):
logger.warning(
f"Bot is not in channel '{channel_id}'. Cannot fetch history. "
"Please add the bot to this channel."
)
return []
return []
# For other SlackApiErrors from inner block or this level
raise SlackApiError(f"Error retrieving history for channel {channel_id}: {e}", e.response)
except Exception as general_error: # Catch any other unexpected errors
logger.error(f"Unexpected error in get_conversation_history for channel {channel_id}: {general_error}")
raise SlackApiError(
f"Error retrieving history for channel {channel_id}: {e}",
e.response,
) from e
except Exception as general_error: # Catch any other unexpected errors
logger.error(
f"Unexpected error in get_conversation_history for channel {channel_id}: {general_error}"
)
# Re-raise the general error to allow higher-level handling or visibility
raise
raise general_error from general_error
return messages[:limit]
@staticmethod
def convert_date_to_timestamp(date_str: str) -> Optional[int]:
def convert_date_to_timestamp(date_str: str) -> int | None:
"""
Convert a date string in format YYYY-MM-DD to Unix timestamp.
Args:
date_str: Date string in YYYY-MM-DD format
Returns:
Unix timestamp (seconds since epoch) or None if invalid format
"""
@@ -231,67 +251,63 @@ class SlackHistory:
return int(dt.timestamp())
except ValueError:
return None
def get_history_by_date_range(
self,
channel_id: str,
start_date: str,
end_date: str,
limit: int = 1000
) -> Tuple[List[Dict[str, Any]], Optional[str]]:
self, channel_id: str, start_date: str, end_date: str, limit: int = 1000
) -> tuple[list[dict[str, Any]], str | None]:
"""
Fetch conversation history within a date range.
Args:
channel_id: The ID of the channel to fetch history for
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format (inclusive)
limit: Maximum number of messages to return
Returns:
Tuple containing (messages list, error message or None)
"""
oldest = self.convert_date_to_timestamp(start_date)
if not oldest:
return [], f"Invalid start date format: {start_date}. Please use YYYY-MM-DD."
return (
[],
f"Invalid start date format: {start_date}. Please use YYYY-MM-DD.",
)
latest = self.convert_date_to_timestamp(end_date)
if not latest:
return [], f"Invalid end date format: {end_date}. Please use YYYY-MM-DD."
# Add one day to end date to make it inclusive
latest += 86400 # seconds in a day
try:
messages = self.get_conversation_history(
channel_id=channel_id,
limit=limit,
oldest=oldest,
latest=latest
channel_id=channel_id, limit=limit, oldest=oldest, latest=latest
)
return messages, None
except SlackApiError as e:
return [], f"Slack API error: {str(e)}"
return [], f"Slack API error: {e!s}"
except ValueError as e:
return [], str(e)
def get_user_info(self, user_id: str) -> Dict[str, Any]:
def get_user_info(self, user_id: str) -> dict[str, Any]:
"""
Get information about a user.
Args:
user_id: The ID of the user to get info for
Returns:
User information dictionary
Raises:
ValueError: If no Slack client has been initialized
SlackApiError: If there's an error calling the Slack API
"""
if not self.client:
raise ValueError("Slack client not initialized. Call set_token() first.")
while True:
try:
# Proactive delay for users.info (Tier 4) - generally not needed unless called extremely rapidly.
@@ -299,46 +315,60 @@ class SlackHistory:
# time.sleep(0.6) # Optional: ~100 req/min if ever needed.
result = self.client.users_info(user=user_id)
return result["user"] # Success, return and exit loop implicitly
return result["user"] # Success, return and exit loop implicitly
except SlackApiError as e_user_info:
if e_user_info.response is not None and e_user_info.response.status_code == 429:
retry_after_str = e_user_info.response.headers.get('Retry-After')
if (
e_user_info.response is not None
and e_user_info.response.status_code == 429
):
retry_after_str = e_user_info.response.headers.get("Retry-After")
wait_time = 30 # Default for Tier 4, can be adjusted
if retry_after_str and retry_after_str.isdigit():
wait_time = int(retry_after_str)
logger.warning(f"Rate limited by Slack on users.info for user {user_id}. Retrying after {wait_time} seconds.")
logger.warning(
f"Rate limited by Slack on users.info for user {user_id}. Retrying after {wait_time} seconds."
)
time.sleep(wait_time)
continue # Retry the API call
else:
# Not a 429 error, or no response object, re-raise
raise SlackApiError(f"Error retrieving user info for {user_id}: {e_user_info}", e_user_info.response)
except Exception as general_error: # Catch any other unexpected errors
logger.error(f"Unexpected error in get_user_info for user {user_id}: {general_error}")
raise # Re-raise unexpected errors
def format_message(self, msg: Dict[str, Any], include_user_info: bool = False) -> Dict[str, Any]:
raise SlackApiError(
f"Error retrieving user info for {user_id}: {e_user_info}",
e_user_info.response,
) from e_user_info
except Exception as general_error: # Catch any other unexpected errors
logger.error(
f"Unexpected error in get_user_info for user {user_id}: {general_error}"
)
raise general_error from general_error # Re-raise unexpected errors
def format_message(
self, msg: dict[str, Any], include_user_info: bool = False
) -> dict[str, Any]:
"""
Format a message for easier consumption.
Args:
msg: The message object from Slack API
include_user_info: Whether to fetch and include user info
Returns:
Formatted message dictionary
"""
formatted = {
"text": msg.get("text", ""),
"timestamp": msg.get("ts"),
"datetime": datetime.fromtimestamp(float(msg.get("ts", 0))).strftime('%Y-%m-%d %H:%M:%S'),
"datetime": datetime.fromtimestamp(float(msg.get("ts", 0))).strftime(
"%Y-%m-%d %H:%M:%S"
),
"user_id": msg.get("user", "UNKNOWN"),
"has_attachments": bool(msg.get("attachments")),
"has_files": bool(msg.get("files")),
"thread_ts": msg.get("thread_ts"),
"is_thread": "thread_ts" in msg,
}
if include_user_info and "user" in msg and self.client:
try:
user_info = self.get_user_info(msg["user"])
@@ -347,7 +377,7 @@ class SlackHistory:
except Exception:
# If we can't get user info, just continue without it
formatted["user_name"] = "Unknown"
return formatted
@@ -388,4 +418,4 @@ if __name__ == "__main__":
except Exception as e:
print(f"Error: {e}")
"""
"""

View file

@@ -1,23 +1,24 @@
import unittest
from unittest.mock import patch, Mock
from datetime import datetime
from unittest.mock import Mock, patch
from github3.exceptions import ForbiddenError # Import the specific exception
# Adjust the import path based on the actual location if test_github_connector.py
# is not in the same directory as github_connector.py or if paths are set up differently.
# Assuming surfsense_backend/app/connectors/test_github_connector.py
from surfsense_backend.app.connectors.github_connector import GitHubConnector
from github3.exceptions import ForbiddenError # Import the specific exception
class TestGitHubConnector(unittest.TestCase):
@patch('surfsense_backend.app.connectors.github_connector.github_login')
@patch("surfsense_backend.app.connectors.github_connector.github_login")
def test_get_user_repositories_uses_type_all(self, mock_github_login):
# Mock the GitHub client object and its methods
mock_gh_instance = Mock()
mock_github_login.return_value = mock_gh_instance
# Mock the self.gh.me() call in __init__ to prevent an actual API call
mock_gh_instance.me.return_value = Mock() # Simple mock to pass initialization
mock_gh_instance.me.return_value = Mock() # Simple mock to pass initialization
# Prepare mock repository data
mock_repo1_data = Mock()
@@ -27,7 +28,9 @@ class TestGitHubConnector(unittest.TestCase):
mock_repo1_data.private = False
mock_repo1_data.html_url = "http://example.com/user/repo1"
mock_repo1_data.description = "Test repo 1"
mock_repo1_data.updated_at = datetime(2023, 1, 1, 10, 30, 0) # Added time component
mock_repo1_data.updated_at = datetime(
2023, 1, 1, 10, 30, 0
) # Added time component
mock_repo2_data = Mock()
mock_repo2_data.id = 2
@@ -36,8 +39,10 @@ class TestGitHubConnector(unittest.TestCase):
mock_repo2_data.private = True
mock_repo2_data.html_url = "http://example.com/org/org-repo"
mock_repo2_data.description = "Org repo"
mock_repo2_data.updated_at = datetime(2023, 1, 2, 12, 0, 0) # Added time component
mock_repo2_data.updated_at = datetime(
2023, 1, 2, 12, 0, 0
) # Added time component
# Configure the mock for gh.repositories() call
# This method is an iterator, so it should return an iterable (e.g., a list)
mock_gh_instance.repositories.return_value = [mock_repo1_data, mock_repo2_data]
@@ -46,26 +51,38 @@ class TestGitHubConnector(unittest.TestCase):
repositories = connector.get_user_repositories()
# Assert that gh.repositories was called correctly
mock_gh_instance.repositories.assert_called_once_with(type='all', sort='updated')
mock_gh_instance.repositories.assert_called_once_with(
type="all", sort="updated"
)
# Assert the structure and content of the returned data
expected_repositories = [
{
"id": 1, "name": "repo1", "full_name": "user/repo1", "private": False,
"url": "http://example.com/user/repo1", "description": "Test repo 1",
"last_updated": datetime(2023, 1, 1, 10, 30, 0)
"id": 1,
"name": "repo1",
"full_name": "user/repo1",
"private": False,
"url": "http://example.com/user/repo1",
"description": "Test repo 1",
"last_updated": datetime(2023, 1, 1, 10, 30, 0),
},
{
"id": 2, "name": "org-repo", "full_name": "org/org-repo", "private": True,
"url": "http://example.com/org/org-repo", "description": "Org repo",
"last_updated": datetime(2023, 1, 2, 12, 0, 0)
}
"id": 2,
"name": "org-repo",
"full_name": "org/org-repo",
"private": True,
"url": "http://example.com/org/org-repo",
"description": "Org repo",
"last_updated": datetime(2023, 1, 2, 12, 0, 0),
},
]
self.assertEqual(repositories, expected_repositories)
self.assertEqual(len(repositories), 2)
@patch('surfsense_backend.app.connectors.github_connector.github_login')
def test_get_user_repositories_handles_empty_description_and_none_updated_at(self, mock_github_login):
@patch("surfsense_backend.app.connectors.github_connector.github_login")
def test_get_user_repositories_handles_empty_description_and_none_updated_at(
self, mock_github_login
):
# Mock the GitHub client object and its methods
mock_gh_instance = Mock()
mock_github_login.return_value = mock_gh_instance
@@ -77,61 +94,73 @@ class TestGitHubConnector(unittest.TestCase):
mock_repo_data.full_name = "user/repo_no_desc"
mock_repo_data.private = False
mock_repo_data.html_url = "http://example.com/user/repo_no_desc"
mock_repo_data.description = None # Test None description
mock_repo_data.updated_at = None # Test None updated_at
mock_repo_data.description = None # Test None description
mock_repo_data.updated_at = None # Test None updated_at
mock_gh_instance.repositories.return_value = [mock_repo_data]
connector = GitHubConnector(token="fake_token")
repositories = connector.get_user_repositories()
mock_gh_instance.repositories.assert_called_once_with(type='all', sort='updated')
mock_gh_instance.repositories.assert_called_once_with(
type="all", sort="updated"
)
expected_repositories = [
{
"id": 1, "name": "repo_no_desc", "full_name": "user/repo_no_desc", "private": False,
"url": "http://example.com/user/repo_no_desc", "description": "", # Expect empty string
"last_updated": None # Expect None
"id": 1,
"name": "repo_no_desc",
"full_name": "user/repo_no_desc",
"private": False,
"url": "http://example.com/user/repo_no_desc",
"description": "", # Expect empty string
"last_updated": None, # Expect None
}
]
self.assertEqual(repositories, expected_repositories)
@patch('surfsense_backend.app.connectors.github_connector.github_login')
@patch("surfsense_backend.app.connectors.github_connector.github_login")
def test_github_connector_initialization_failure_forbidden(self, mock_github_login):
# Test that __init__ raises ValueError on auth failure (ForbiddenError)
mock_gh_instance = Mock()
mock_github_login.return_value = mock_gh_instance
# Create a mock response object for the ForbiddenError
# The actual response structure might vary, but github3.py's ForbiddenError
# can be instantiated with just a response object that has a status_code.
mock_response = Mock()
mock_response.status_code = 403 # Typically Forbidden
mock_response.status_code = 403 # Typically Forbidden
# Setup the side_effect for self.gh.me()
mock_gh_instance.me.side_effect = ForbiddenError(mock_response)
with self.assertRaises(ValueError) as context:
GitHubConnector(token="invalid_token_forbidden")
self.assertIn("Invalid GitHub token or insufficient permissions.", str(context.exception))
self.assertIn(
"Invalid GitHub token or insufficient permissions.", str(context.exception)
)
@patch('surfsense_backend.app.connectors.github_connector.github_login')
def test_github_connector_initialization_failure_authentication_failed(self, mock_github_login):
@patch("surfsense_backend.app.connectors.github_connector.github_login")
def test_github_connector_initialization_failure_authentication_failed(
self, mock_github_login
):
# Test that __init__ raises ValueError on auth failure (AuthenticationFailed, which is a subclass of ForbiddenError)
# For github3.py, AuthenticationFailed is more specific for token issues.
from github3.exceptions import AuthenticationFailed
mock_gh_instance = Mock()
mock_github_login.return_value = mock_gh_instance
mock_response = Mock()
mock_response.status_code = 401 # Typically Unauthorized
mock_response.status_code = 401 # Typically Unauthorized
mock_gh_instance.me.side_effect = AuthenticationFailed(mock_response)
with self.assertRaises(ValueError) as context:
GitHubConnector(token="invalid_token_authfailed")
self.assertIn("Invalid GitHub token or insufficient permissions.", str(context.exception))
@patch('surfsense_backend.app.connectors.github_connector.github_login')
self.assertIn(
"Invalid GitHub token or insufficient permissions.", str(context.exception)
)
@patch("surfsense_backend.app.connectors.github_connector.github_login")
def test_get_user_repositories_handles_api_exception(self, mock_github_login):
mock_gh_instance = Mock()
mock_github_login.return_value = mock_gh_instance
@@ -142,13 +171,18 @@ class TestGitHubConnector(unittest.TestCase):
connector = GitHubConnector(token="fake_token")
# We expect it to log an error and return an empty list
with patch('surfsense_backend.app.connectors.github_connector.logger') as mock_logger:
with patch(
"surfsense_backend.app.connectors.github_connector.logger"
) as mock_logger:
repositories = connector.get_user_repositories()
self.assertEqual(repositories, [])
mock_logger.error.assert_called_once()
self.assertIn("Failed to fetch GitHub repositories: API Error", mock_logger.error.call_args[0][0])
self.assertIn(
"Failed to fetch GitHub repositories: API Error",
mock_logger.error.call_args[0][0],
)
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
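The assertions above pin down the connector's observable behaviour: github_login is called with the token, gh.me() is used as an auth probe, ForbiddenError (and its AuthenticationFailed subclass) becomes a ValueError, and API failures in get_user_repositories are logged and turned into an empty list. A minimal sketch consistent with those assertions follows; everything the tests do not assert (internal structure, attribute names) is an assumption, not the repository's actual implementation.

import logging

from github3 import login as github_login  # assumed import behind the patched name
from github3.exceptions import ForbiddenError

logger = logging.getLogger(__name__)


class GitHubConnector:
    """Sketch of the interface exercised by the tests above; not the actual source."""

    def __init__(self, token: str):
        self.gh = github_login(token=token)
        try:
            self.gh.me()  # cheap call used purely to validate the token
        except ForbiddenError as exc:  # AuthenticationFailed is a subclass
            raise ValueError(
                "Invalid GitHub token or insufficient permissions."
            ) from exc

    def get_user_repositories(self) -> list[dict]:
        try:
            return [
                {
                    "id": repo.id,
                    "name": repo.name,
                    "full_name": repo.full_name,
                    "private": repo.private,
                    "url": repo.html_url,
                    "description": repo.description or "",  # None becomes ""
                    "last_updated": repo.updated_at,  # may be None
                }
                for repo in self.gh.repositories(type="all", sort="updated")
            ]
        except Exception as exc:
            logger.error(f"Failed to fetch GitHub repositories: {exc}")
            return []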


@@ -1,373 +1,448 @@
import unittest
import time # Imported to be available for patching target module
from unittest.mock import patch, Mock, call
from unittest.mock import Mock, call, patch
from slack_sdk.errors import SlackApiError
# Since test_slack_history.py is in the same directory as slack_history.py
from .slack_history import SlackHistory
class TestSlackHistoryGetAllChannels(unittest.TestCase):
@patch('surfsense_backend.app.connectors.slack_history.logger')
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
@patch('slack_sdk.WebClient')
def test_get_all_channels_pagination_with_delay(self, MockWebClient, mock_sleep, mock_logger):
mock_client_instance = MockWebClient.return_value
class TestSlackHistoryGetAllChannels(unittest.TestCase):
@patch("surfsense_backend.app.connectors.slack_history.logger")
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
@patch("slack_sdk.WebClient")
def test_get_all_channels_pagination_with_delay(
self, mock_web_client, mock_sleep, mock_logger
):
mock_client_instance = mock_web_client.return_value
# Mock API responses now include is_private and is_member
page1_response = {
"channels": [
{"name": "general", "id": "C1", "is_private": False, "is_member": True},
{"name": "dev", "id": "C0", "is_private": False, "is_member": True}
{"name": "general", "id": "C1", "is_private": False, "is_member": True},
{"name": "dev", "id": "C0", "is_private": False, "is_member": True},
],
"response_metadata": {"next_cursor": "cursor123"}
"response_metadata": {"next_cursor": "cursor123"},
}
page2_response = {
"channels": [{"name": "random", "id": "C2", "is_private": True, "is_member": True}],
"response_metadata": {"next_cursor": ""}
"channels": [
{"name": "random", "id": "C2", "is_private": True, "is_member": True}
],
"response_metadata": {"next_cursor": ""},
}
mock_client_instance.conversations_list.side_effect = [
page1_response,
page2_response
page2_response,
]
slack_history = SlackHistory(token="fake_token")
channels_list = slack_history.get_all_channels(include_private=True)
expected_channels_list = [
{"id": "C1", "name": "general", "is_private": False, "is_member": True},
{"id": "C0", "name": "dev", "is_private": False, "is_member": True},
{"id": "C2", "name": "random", "is_private": True, "is_member": True}
{"id": "C2", "name": "random", "is_private": True, "is_member": True},
]
self.assertEqual(len(channels_list), 3)
self.assertListEqual(channels_list, expected_channels_list) # Assert list equality
self.assertListEqual(
channels_list, expected_channels_list
) # Assert list equality
expected_calls = [
call(types="public_channel,private_channel", cursor=None, limit=1000),
call(types="public_channel,private_channel", cursor="cursor123", limit=1000)
call(
types="public_channel,private_channel", cursor="cursor123", limit=1000
),
]
mock_client_instance.conversations_list.assert_has_calls(expected_calls)
self.assertEqual(mock_client_instance.conversations_list.call_count, 2)
mock_sleep.assert_called_once_with(3)
mock_logger.info.assert_called_once_with("Paginating for channels, waiting 3 seconds before next call. Cursor: cursor123")
@patch('surfsense_backend.app.connectors.slack_history.logger')
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
@patch('slack_sdk.WebClient')
def test_get_all_channels_rate_limit_with_retry_after(self, MockWebClient, mock_sleep, mock_logger):
mock_client_instance = MockWebClient.return_value
mock_sleep.assert_called_once_with(3)
mock_logger.info.assert_called_once_with(
"Paginating for channels, waiting 3 seconds before next call. Cursor: cursor123"
)
@patch("surfsense_backend.app.connectors.slack_history.logger")
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
@patch("slack_sdk.WebClient")
def test_get_all_channels_rate_limit_with_retry_after(
self, mock_web_client, mock_sleep, mock_logger
):
mock_client_instance = mock_web_client.return_value
mock_error_response = Mock()
mock_error_response.status_code = 429
mock_error_response.headers = {'Retry-After': '5'}
mock_error_response.headers = {"Retry-After": "5"}
successful_response = {
"channels": [{"name": "general", "id": "C1", "is_private": False, "is_member": True}],
"response_metadata": {"next_cursor": ""}
"channels": [
{"name": "general", "id": "C1", "is_private": False, "is_member": True}
],
"response_metadata": {"next_cursor": ""},
}
mock_client_instance.conversations_list.side_effect = [
SlackApiError(message="ratelimited", response=mock_error_response),
successful_response
successful_response,
]
slack_history = SlackHistory(token="fake_token")
channels_list = slack_history.get_all_channels(include_private=True)
expected_channels_list = [{"id": "C1", "name": "general", "is_private": False, "is_member": True}]
expected_channels_list = [
{"id": "C1", "name": "general", "is_private": False, "is_member": True}
]
self.assertEqual(len(channels_list), 1)
self.assertListEqual(channels_list, expected_channels_list)
mock_sleep.assert_called_once_with(5)
mock_logger.warning.assert_called_once_with("Slack API rate limit hit while fetching channels. Waiting for 5 seconds. Cursor: None")
mock_sleep.assert_called_once_with(5)
mock_logger.warning.assert_called_once_with(
"Slack API rate limit hit while fetching channels. Waiting for 5 seconds. Cursor: None"
)
expected_calls = [
call(types="public_channel,private_channel", cursor=None, limit=1000),
call(types="public_channel,private_channel", cursor=None, limit=1000)
call(types="public_channel,private_channel", cursor=None, limit=1000),
call(types="public_channel,private_channel", cursor=None, limit=1000),
]
mock_client_instance.conversations_list.assert_has_calls(expected_calls)
self.assertEqual(mock_client_instance.conversations_list.call_count, 2)
@patch('surfsense_backend.app.connectors.slack_history.logger')
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
@patch('slack_sdk.WebClient')
def test_get_all_channels_rate_limit_no_retry_after_valid_header(self, MockWebClient, mock_sleep, mock_logger):
mock_client_instance = MockWebClient.return_value
@patch("surfsense_backend.app.connectors.slack_history.logger")
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
@patch("slack_sdk.WebClient")
def test_get_all_channels_rate_limit_no_retry_after_valid_header(
self, mock_web_client, mock_sleep, mock_logger
):
mock_client_instance = mock_web_client.return_value
mock_error_response = Mock()
mock_error_response.status_code = 429
mock_error_response.headers = {'Retry-After': 'invalid_value'}
mock_error_response.headers = {"Retry-After": "invalid_value"}
successful_response = {
"channels": [{"name": "general", "id": "C1", "is_private": False, "is_member": True}],
"response_metadata": {"next_cursor": ""}
"channels": [
{"name": "general", "id": "C1", "is_private": False, "is_member": True}
],
"response_metadata": {"next_cursor": ""},
}
mock_client_instance.conversations_list.side_effect = [
SlackApiError(message="ratelimited", response=mock_error_response),
successful_response
successful_response,
]
slack_history = SlackHistory(token="fake_token")
channels_list = slack_history.get_all_channels(include_private=True)
expected_channels_list = [{"id": "C1", "name": "general", "is_private": False, "is_member": True}]
expected_channels_list = [
{"id": "C1", "name": "general", "is_private": False, "is_member": True}
]
self.assertListEqual(channels_list, expected_channels_list)
mock_sleep.assert_called_once_with(60) # Default fallback
mock_logger.warning.assert_called_once_with("Slack API rate limit hit while fetching channels. Waiting for 60 seconds. Cursor: None")
mock_sleep.assert_called_once_with(60) # Default fallback
mock_logger.warning.assert_called_once_with(
"Slack API rate limit hit while fetching channels. Waiting for 60 seconds. Cursor: None"
)
self.assertEqual(mock_client_instance.conversations_list.call_count, 2)
@patch('surfsense_backend.app.connectors.slack_history.logger')
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
@patch('slack_sdk.WebClient')
def test_get_all_channels_rate_limit_no_retry_after_header(self, MockWebClient, mock_sleep, mock_logger):
mock_client_instance = MockWebClient.return_value
@patch("surfsense_backend.app.connectors.slack_history.logger")
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
@patch("slack_sdk.WebClient")
def test_get_all_channels_rate_limit_no_retry_after_header(
self, mock_web_client, mock_sleep, mock_logger
):
mock_client_instance = mock_web_client.return_value
mock_error_response = Mock()
mock_error_response.status_code = 429
mock_error_response.headers = {}
successful_response = {
"channels": [{"name": "general", "id": "C1", "is_private": False, "is_member": True}],
"response_metadata": {"next_cursor": ""}
}
mock_client_instance.conversations_list.side_effect = [
SlackApiError(message="ratelimited", response=mock_error_response),
successful_response
]
slack_history = SlackHistory(token="fake_token")
channels_list = slack_history.get_all_channels(include_private=True)
expected_channels_list = [{"id": "C1", "name": "general", "is_private": False, "is_member": True}]
self.assertListEqual(channels_list, expected_channels_list)
mock_sleep.assert_called_once_with(60) # Default fallback
mock_logger.warning.assert_called_once_with("Slack API rate limit hit while fetching channels. Waiting for 60 seconds. Cursor: None")
self.assertEqual(mock_client_instance.conversations_list.call_count, 2)
@patch('surfsense_backend.app.connectors.slack_history.logger')
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
@patch('slack_sdk.WebClient')
def test_get_all_channels_other_slack_api_error(self, MockWebClient, mock_sleep, mock_logger):
mock_client_instance = MockWebClient.return_value
mock_error_response = Mock()
mock_error_response.status_code = 500
mock_error_response.headers = {}
mock_error_response.data = {"ok": False, "error": "internal_error"}
original_error = SlackApiError(message="server error", response=mock_error_response)
mock_client_instance.conversations_list.side_effect = original_error
successful_response = {
"channels": [
{"name": "general", "id": "C1", "is_private": False, "is_member": True}
],
"response_metadata": {"next_cursor": ""},
}
mock_client_instance.conversations_list.side_effect = [
SlackApiError(message="ratelimited", response=mock_error_response),
successful_response,
]
slack_history = SlackHistory(token="fake_token")
channels_list = slack_history.get_all_channels(include_private=True)
expected_channels_list = [
{"id": "C1", "name": "general", "is_private": False, "is_member": True}
]
self.assertListEqual(channels_list, expected_channels_list)
mock_sleep.assert_called_once_with(60) # Default fallback
mock_logger.warning.assert_called_once_with(
"Slack API rate limit hit while fetching channels. Waiting for 60 seconds. Cursor: None"
)
self.assertEqual(mock_client_instance.conversations_list.call_count, 2)
@patch("surfsense_backend.app.connectors.slack_history.logger")
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
@patch("slack_sdk.WebClient")
def test_get_all_channels_other_slack_api_error(
self, mock_web_client, mock_sleep, mock_logger
):
mock_client_instance = mock_web_client.return_value
mock_error_response = Mock()
mock_error_response.status_code = 500
mock_error_response.headers = {}
mock_error_response.data = {"ok": False, "error": "internal_error"}
original_error = SlackApiError(
message="server error", response=mock_error_response
)
mock_client_instance.conversations_list.side_effect = original_error
slack_history = SlackHistory(token="fake_token")
with self.assertRaises(SlackApiError) as context:
slack_history.get_all_channels(include_private=True)
self.assertEqual(context.exception.response.status_code, 500)
self.assertIn("server error", str(context.exception))
mock_sleep.assert_not_called()
mock_logger.warning.assert_not_called() # Ensure no rate limit log
mock_logger.warning.assert_not_called() # Ensure no rate limit log
mock_client_instance.conversations_list.assert_called_once_with(
types="public_channel,private_channel", cursor=None, limit=1000
)
@patch('surfsense_backend.app.connectors.slack_history.logger')
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
@patch('slack_sdk.WebClient')
def test_get_all_channels_handles_missing_name_id_gracefully(self, MockWebClient, mock_sleep, mock_logger):
mock_client_instance = MockWebClient.return_value
@patch("surfsense_backend.app.connectors.slack_history.logger")
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
@patch("slack_sdk.WebClient")
def test_get_all_channels_handles_missing_name_id_gracefully(
self, mock_web_client, mock_sleep, mock_logger
):
mock_client_instance = mock_web_client.return_value
response_with_malformed_data = {
"channels": [
{"id": "C1_missing_name", "is_private": False, "is_member": True},
{"id": "C1_missing_name", "is_private": False, "is_member": True},
{"name": "channel_missing_id", "is_private": False, "is_member": True},
{"name": "general", "id": "C2_valid", "is_private": False, "is_member": True}
{
"name": "general",
"id": "C2_valid",
"is_private": False,
"is_member": True,
},
],
"response_metadata": {"next_cursor": ""}
"response_metadata": {"next_cursor": ""},
}
mock_client_instance.conversations_list.return_value = response_with_malformed_data
mock_client_instance.conversations_list.return_value = (
response_with_malformed_data
)
slack_history = SlackHistory(token="fake_token")
channels_list = slack_history.get_all_channels(include_private=True)
expected_channels_list = [
{"id": "C2_valid", "name": "general", "is_private": False, "is_member": True}
]
self.assertEqual(len(channels_list), 1)
self.assertListEqual(channels_list, expected_channels_list)
self.assertEqual(mock_logger.warning.call_count, 2)
mock_logger.warning.assert_any_call("Channel found with missing name or id. Data: {'id': 'C1_missing_name', 'is_private': False, 'is_member': True}")
mock_logger.warning.assert_any_call("Channel found with missing name or id. Data: {'name': 'channel_missing_id', 'is_private': False, 'is_member': True}")
mock_sleep.assert_not_called()
expected_channels_list = [
{
"id": "C2_valid",
"name": "general",
"is_private": False,
"is_member": True,
}
]
self.assertEqual(len(channels_list), 1)
self.assertListEqual(channels_list, expected_channels_list)
self.assertEqual(mock_logger.warning.call_count, 2)
mock_logger.warning.assert_any_call(
"Channel found with missing name or id. Data: {'id': 'C1_missing_name', 'is_private': False, 'is_member': True}"
)
mock_logger.warning.assert_any_call(
"Channel found with missing name or id. Data: {'name': 'channel_missing_id', 'is_private': False, 'is_member': True}"
)
mock_sleep.assert_not_called()
mock_client_instance.conversations_list.assert_called_once_with(
types="public_channel,private_channel", cursor=None, limit=1000
)
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
class TestSlackHistoryGetConversationHistory(unittest.TestCase):
@patch('surfsense_backend.app.connectors.slack_history.logger')
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
@patch('slack_sdk.WebClient')
def test_proactive_delay_single_page(self, MockWebClient, mock_time_sleep, mock_logger):
mock_client_instance = MockWebClient.return_value
class TestSlackHistoryGetConversationHistory(unittest.TestCase):
@patch("surfsense_backend.app.connectors.slack_history.logger")
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
@patch("slack_sdk.WebClient")
def test_proactive_delay_single_page(
self, mock_web_client, mock_time_sleep, mock_logger
):
mock_client_instance = mock_web_client.return_value
mock_client_instance.conversations_history.return_value = {
"messages": [{"text": "msg1"}],
"has_more": False
"has_more": False,
}
slack_history = SlackHistory(token="fake_token")
slack_history.get_conversation_history(channel_id="C123")
mock_time_sleep.assert_called_once_with(1.2) # Proactive delay
@patch('surfsense_backend.app.connectors.slack_history.logger')
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
@patch('slack_sdk.WebClient')
def test_proactive_delay_multiple_pages(self, MockWebClient, mock_time_sleep, mock_logger):
mock_client_instance = MockWebClient.return_value
mock_time_sleep.assert_called_once_with(1.2) # Proactive delay
@patch("surfsense_backend.app.connectors.slack_history.logger")
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
@patch("slack_sdk.WebClient")
def test_proactive_delay_multiple_pages(
self, mock_web_client, mock_time_sleep, mock_logger
):
mock_client_instance = mock_web_client.return_value
mock_client_instance.conversations_history.side_effect = [
{
"messages": [{"text": "msg1"}],
"has_more": True,
"response_metadata": {"next_cursor": "cursor1"}
"response_metadata": {"next_cursor": "cursor1"},
},
{
"messages": [{"text": "msg2"}],
"has_more": False
}
{"messages": [{"text": "msg2"}], "has_more": False},
]
slack_history = SlackHistory(token="fake_token")
slack_history.get_conversation_history(channel_id="C123")
# Expected calls: 1.2 (page1), 1.2 (page2)
self.assertEqual(mock_time_sleep.call_count, 2)
mock_time_sleep.assert_has_calls([call(1.2), call(1.2)])
@patch('surfsense_backend.app.connectors.slack_history.logger')
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
@patch('slack_sdk.WebClient')
def test_retry_after_logic(self, MockWebClient, mock_time_sleep, mock_logger):
mock_client_instance = MockWebClient.return_value
@patch("surfsense_backend.app.connectors.slack_history.logger")
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
@patch("slack_sdk.WebClient")
def test_retry_after_logic(self, mock_web_client, mock_time_sleep, mock_logger):
mock_client_instance = mock_web_client.return_value
mock_error_response = Mock()
mock_error_response.status_code = 429
mock_error_response.headers = {'Retry-After': '5'}
mock_error_response.headers = {"Retry-After": "5"}
mock_client_instance.conversations_history.side_effect = [
SlackApiError(message="ratelimited", response=mock_error_response),
{"messages": [{"text": "msg1"}], "has_more": False}
{"messages": [{"text": "msg1"}], "has_more": False},
]
slack_history = SlackHistory(token="fake_token")
messages = slack_history.get_conversation_history(channel_id="C123")
self.assertEqual(len(messages), 1)
self.assertEqual(messages[0]["text"], "msg1")
# Expected sleep calls: 1.2 (proactive for 1st attempt), 5 (rate limit), 1.2 (proactive for 2nd attempt)
mock_time_sleep.assert_has_calls([call(1.2), call(5), call(1.2)], any_order=False)
mock_logger.warning.assert_called_once() # Check that a warning was logged for rate limiting
@patch('surfsense_backend.app.connectors.slack_history.logger')
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
@patch('slack_sdk.WebClient')
def test_not_in_channel_error(self, MockWebClient, mock_time_sleep, mock_logger):
mock_client_instance = MockWebClient.return_value
# Expected sleep calls: 1.2 (proactive for 1st attempt), 5 (rate limit), 1.2 (proactive for 2nd attempt)
mock_time_sleep.assert_has_calls(
[call(1.2), call(5), call(1.2)], any_order=False
)
mock_logger.warning.assert_called_once() # Check that a warning was logged for rate limiting
@patch("surfsense_backend.app.connectors.slack_history.logger")
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
@patch("slack_sdk.WebClient")
def test_not_in_channel_error(self, mock_web_client, mock_time_sleep, mock_logger):
mock_client_instance = mock_web_client.return_value
mock_error_response = Mock()
mock_error_response.status_code = 403 # Typical for not_in_channel, but data matters more
mock_error_response.data = {'ok': False, 'error': 'not_in_channel'}
mock_error_response.status_code = (
403 # Typical for not_in_channel, but data matters more
)
mock_error_response.data = {"ok": False, "error": "not_in_channel"}
# This error is now raised by the inner try-except, then caught by the outer one
mock_client_instance.conversations_history.side_effect = SlackApiError(
message="not_in_channel error",
response=mock_error_response
message="not_in_channel error", response=mock_error_response
)
slack_history = SlackHistory(token="fake_token")
messages = slack_history.get_conversation_history(channel_id="C123")
self.assertEqual(messages, [])
mock_logger.warning.assert_called_with(
"Bot is not in channel 'C123'. Cannot fetch history. Please add the bot to this channel."
)
mock_time_sleep.assert_called_once_with(1.2) # Proactive delay before the API call
mock_time_sleep.assert_called_once_with(
1.2
) # Proactive delay before the API call
@patch("surfsense_backend.app.connectors.slack_history.logger")
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
@patch("slack_sdk.WebClient")
def test_other_slack_api_error_propagates(
self, mock_web_client, mock_time_sleep, mock_logger
):
mock_client_instance = mock_web_client.return_value
@patch('surfsense_backend.app.connectors.slack_history.logger')
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
@patch('slack_sdk.WebClient')
def test_other_slack_api_error_propagates(self, MockWebClient, mock_time_sleep, mock_logger):
mock_client_instance = MockWebClient.return_value
mock_error_response = Mock()
mock_error_response.status_code = 500
mock_error_response.data = {'ok': False, 'error': 'internal_error'}
original_error = SlackApiError(message="server error", response=mock_error_response)
mock_error_response.data = {"ok": False, "error": "internal_error"}
original_error = SlackApiError(
message="server error", response=mock_error_response
)
mock_client_instance.conversations_history.side_effect = original_error
slack_history = SlackHistory(token="fake_token")
with self.assertRaises(SlackApiError) as context:
slack_history.get_conversation_history(channel_id="C123")
self.assertIn("Error retrieving history for channel C123", str(context.exception))
self.assertIs(context.exception.response, mock_error_response)
mock_time_sleep.assert_called_once_with(1.2) # Proactive delay
@patch('surfsense_backend.app.connectors.slack_history.logger')
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
@patch('slack_sdk.WebClient')
def test_general_exception_propagates(self, MockWebClient, mock_time_sleep, mock_logger):
mock_client_instance = MockWebClient.return_value
self.assertIn(
"Error retrieving history for channel C123", str(context.exception)
)
self.assertIs(context.exception.response, mock_error_response)
mock_time_sleep.assert_called_once_with(1.2) # Proactive delay
@patch("surfsense_backend.app.connectors.slack_history.logger")
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
@patch("slack_sdk.WebClient")
def test_general_exception_propagates(
self, mock_web_client, mock_time_sleep, mock_logger
):
mock_client_instance = mock_web_client.return_value
original_error = Exception("Something broke")
mock_client_instance.conversations_history.side_effect = original_error
slack_history = SlackHistory(token="fake_token")
with self.assertRaises(Exception) as context: # Check for generic Exception
with self.assertRaises(Exception) as context: # Check for generic Exception
slack_history.get_conversation_history(channel_id="C123")
self.assertIs(context.exception, original_error) # Should re-raise the original error
mock_logger.error.assert_called_once_with("Unexpected error in get_conversation_history for channel C123: Something broke")
mock_time_sleep.assert_called_once_with(1.2) # Proactive delay
self.assertIs(
context.exception, original_error
) # Should re-raise the original error
mock_logger.error.assert_called_once_with(
"Unexpected error in get_conversation_history for channel C123: Something broke"
)
mock_time_sleep.assert_called_once_with(1.2) # Proactive delay
class TestSlackHistoryGetUserInfo(unittest.TestCase):
@patch("surfsense_backend.app.connectors.slack_history.logger")
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
@patch("slack_sdk.WebClient")
def test_retry_after_logic(self, mock_web_client, mock_time_sleep, mock_logger):
mock_client_instance = mock_web_client.return_value
@patch('surfsense_backend.app.connectors.slack_history.logger')
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
@patch('slack_sdk.WebClient')
def test_retry_after_logic(self, MockWebClient, mock_time_sleep, mock_logger):
mock_client_instance = MockWebClient.return_value
mock_error_response = Mock()
mock_error_response.status_code = 429
mock_error_response.headers = {'Retry-After': '3'} # Using 3 seconds for test
mock_error_response.headers = {"Retry-After": "3"} # Using 3 seconds for test
successful_user_data = {"id": "U123", "name": "testuser"}
mock_client_instance.users_info.side_effect = [
SlackApiError(message="ratelimited_userinfo", response=mock_error_response),
{"user": successful_user_data}
{"user": successful_user_data},
]
slack_history = SlackHistory(token="fake_token")
user_info = slack_history.get_user_info(user_id="U123")
self.assertEqual(user_info, successful_user_data)
# Assert that time.sleep was called for the rate limit
mock_time_sleep.assert_called_once_with(3)
mock_logger.warning.assert_called_once_with(
@@ -375,46 +450,58 @@ class TestSlackHistoryGetUserInfo(unittest.TestCase):
)
# Assert users_info was called twice (original + retry)
self.assertEqual(mock_client_instance.users_info.call_count, 2)
mock_client_instance.users_info.assert_has_calls([call(user="U123"), call(user="U123")])
mock_client_instance.users_info.assert_has_calls(
[call(user="U123"), call(user="U123")]
)
@patch("surfsense_backend.app.connectors.slack_history.logger")
@patch(
"surfsense_backend.app.connectors.slack_history.time.sleep"
) # time.sleep might be called by other logic, but not expected here
@patch("slack_sdk.WebClient")
def test_other_slack_api_error_propagates(
self, mock_web_client, mock_time_sleep, mock_logger
):
mock_client_instance = mock_web_client.return_value
@patch('surfsense_backend.app.connectors.slack_history.logger')
@patch('surfsense_backend.app.connectors.slack_history.time.sleep') # time.sleep might be called by other logic, but not expected here
@patch('slack_sdk.WebClient')
def test_other_slack_api_error_propagates(self, MockWebClient, mock_time_sleep, mock_logger):
mock_client_instance = MockWebClient.return_value
mock_error_response = Mock()
mock_error_response.status_code = 500 # Some other error
mock_error_response.data = {'ok': False, 'error': 'internal_server_error'}
original_error = SlackApiError(message="internal server error", response=mock_error_response)
mock_error_response.status_code = 500 # Some other error
mock_error_response.data = {"ok": False, "error": "internal_server_error"}
original_error = SlackApiError(
message="internal server error", response=mock_error_response
)
mock_client_instance.users_info.side_effect = original_error
slack_history = SlackHistory(token="fake_token")
with self.assertRaises(SlackApiError) as context:
slack_history.get_user_info(user_id="U123")
# Check that the raised error is the one we expect
self.assertIn("Error retrieving user info for U123", str(context.exception))
self.assertIs(context.exception.response, mock_error_response)
mock_time_sleep.assert_not_called() # No rate limit sleep
mock_time_sleep.assert_not_called() # No rate limit sleep
@patch('surfsense_backend.app.connectors.slack_history.logger')
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
@patch('slack_sdk.WebClient')
def test_general_exception_propagates(self, MockWebClient, mock_time_sleep, mock_logger):
mock_client_instance = MockWebClient.return_value
@patch("surfsense_backend.app.connectors.slack_history.logger")
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
@patch("slack_sdk.WebClient")
def test_general_exception_propagates(
self, mock_web_client, mock_time_sleep, mock_logger
):
mock_client_instance = mock_web_client.return_value
original_error = Exception("A very generic problem")
mock_client_instance.users_info.side_effect = original_error
slack_history = SlackHistory(token="fake_token")
with self.assertRaises(Exception) as context:
slack_history.get_user_info(user_id="U123")
self.assertIs(context.exception, original_error) # Check it's the exact same exception
self.assertIs(
context.exception, original_error
) # Check it's the exact same exception
mock_logger.error.assert_called_once_with(
"Unexpected error in get_user_info for user U123: A very generic problem"
)
mock_time_sleep.assert_not_called() # No rate limit sleep
mock_time_sleep.assert_not_called() # No rate limit sleep
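Taken together, these tests describe SlackHistory's rate-limit strategy: a proactive 1.2 second pause before each conversations_history call, honouring the Retry-After header on HTTP 429 with a 60 second fallback when the header is missing or unparsable, and letting any other SlackApiError propagate. Below is a minimal sketch of that retry pattern; the helper name and structure are assumptions, only the observable behaviour comes from the tests.

import logging
import time

from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError

logger = logging.getLogger(__name__)


def fetch_history_page(client: WebClient, channel_id: str, cursor: str | None = None) -> dict:
    """Hypothetical helper mirroring the retry behaviour asserted above."""
    while True:
        time.sleep(1.2)  # proactive delay before every history call
        try:
            return client.conversations_history(channel=channel_id, cursor=cursor)
        except SlackApiError as e:
            if e.response.status_code == 429:
                try:
                    wait = int(e.response.headers.get("Retry-After", 60))
                except (TypeError, ValueError):
                    wait = 60  # default fallback for a missing or invalid header
                logger.warning(f"Slack API rate limit hit. Waiting for {wait} seconds.")
                time.sleep(wait)
                continue  # retry the same request
            raise  # any other Slack error propagates to the caller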


@@ -1,22 +1,21 @@
from collections.abc import AsyncGenerator
from datetime import datetime, timezone
from datetime import UTC, datetime
from enum import Enum
from fastapi import Depends
from pgvector.sqlalchemy import Vector
from sqlalchemy import (
ARRAY,
JSON,
TIMESTAMP,
Boolean,
Column,
Enum as SQLAlchemyEnum,
ForeignKey,
Integer,
JSON,
String,
Text,
text,
TIMESTAMP
)
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
@@ -33,10 +32,7 @@ if config.AUTH_TYPE == "GOOGLE":
SQLAlchemyUserDatabase,
)
else:
from fastapi_users.db import (
SQLAlchemyBaseUserTableUUID,
SQLAlchemyUserDatabase,
)
from fastapi_users.db import SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase
DATABASE_URL = config.DATABASE_URL
@@ -52,8 +48,9 @@ class DocumentType(str, Enum):
LINEAR_CONNECTOR = "LINEAR_CONNECTOR"
DISCORD_CONNECTOR = "DISCORD_CONNECTOR"
class SearchSourceConnectorType(str, Enum):
SERPER_API = "SERPER_API" # NOT IMPLEMENTED YET : DON'T REMEMBER WHY : MOST PROBABLY BECAUSE WE NEED TO CRAWL THE RESULTS RETURNED BY IT
SERPER_API = "SERPER_API" # NOT IMPLEMENTED YET : DON'T REMEMBER WHY : MOST PROBABLY BECAUSE WE NEED TO CRAWL THE RESULTS RETURNED BY IT
TAVILY_API = "TAVILY_API"
LINKUP_API = "LINKUP_API"
SLACK_CONNECTOR = "SLACK_CONNECTOR"
@@ -61,13 +58,15 @@ class SearchSourceConnectorType(str, Enum):
GITHUB_CONNECTOR = "GITHUB_CONNECTOR"
LINEAR_CONNECTOR = "LINEAR_CONNECTOR"
DISCORD_CONNECTOR = "DISCORD_CONNECTOR"
class ChatType(str, Enum):
QNA = "QNA"
REPORT_GENERAL = "REPORT_GENERAL"
REPORT_DEEP = "REPORT_DEEP"
REPORT_DEEPER = "REPORT_DEEPER"
class LiteLLMProvider(str, Enum):
OPENAI = "OPENAI"
ANTHROPIC = "ANTHROPIC"
@@ -92,6 +91,7 @@ class LiteLLMProvider(str, Enum):
PETALS = "PETALS"
CUSTOM = "CUSTOM"
class LogLevel(str, Enum):
DEBUG = "DEBUG"
INFO = "INFO"
@@ -99,18 +99,27 @@ class LogLevel(str, Enum):
ERROR = "ERROR"
CRITICAL = "CRITICAL"
class LogStatus(str, Enum):
IN_PROGRESS = "IN_PROGRESS"
SUCCESS = "SUCCESS"
FAILED = "FAILED"
class Base(DeclarativeBase):
pass
class TimestampMixin:
@declared_attr
def created_at(cls):
return Column(TIMESTAMP(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc), index=True)
def created_at(cls): # noqa: N805
return Column(
TIMESTAMP(timezone=True),
nullable=False,
default=lambda: datetime.now(UTC),
index=True,
)
class BaseModel(Base):
__abstract__ = True
@@ -118,6 +127,7 @@ class BaseModel(Base):
id = Column(Integer, primary_key=True, index=True)
class Chat(BaseModel, TimestampMixin):
__tablename__ = "chats"
@@ -125,73 +135,115 @@ class Chat(BaseModel, TimestampMixin):
title = Column(String, nullable=False, index=True)
initial_connectors = Column(ARRAY(String), nullable=True)
messages = Column(JSON, nullable=False)
search_space_id = Column(Integer, ForeignKey('searchspaces.id', ondelete='CASCADE'), nullable=False)
search_space = relationship('SearchSpace', back_populates='chats')
search_space_id = Column(
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
)
search_space = relationship("SearchSpace", back_populates="chats")
class Document(BaseModel, TimestampMixin):
__tablename__ = "documents"
title = Column(String, nullable=False, index=True)
document_type = Column(SQLAlchemyEnum(DocumentType), nullable=False)
document_metadata = Column(JSON, nullable=True)
content = Column(Text, nullable=False)
content_hash = Column(String, nullable=False, index=True, unique=True)
embedding = Column(Vector(config.embedding_model_instance.dimension))
search_space_id = Column(Integer, ForeignKey("searchspaces.id", ondelete='CASCADE'), nullable=False)
search_space_id = Column(
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
)
search_space = relationship("SearchSpace", back_populates="documents")
chunks = relationship("Chunk", back_populates="document", cascade="all, delete-orphan")
chunks = relationship(
"Chunk", back_populates="document", cascade="all, delete-orphan"
)
class Chunk(BaseModel, TimestampMixin):
__tablename__ = "chunks"
content = Column(Text, nullable=False)
embedding = Column(Vector(config.embedding_model_instance.dimension))
document_id = Column(Integer, ForeignKey("documents.id", ondelete='CASCADE'), nullable=False)
document_id = Column(
Integer, ForeignKey("documents.id", ondelete="CASCADE"), nullable=False
)
document = relationship("Document", back_populates="chunks")
class Podcast(BaseModel, TimestampMixin):
__tablename__ = "podcasts"
title = Column(String, nullable=False, index=True)
podcast_transcript = Column(JSON, nullable=False, default={})
file_location = Column(String(500), nullable=False, default="")
search_space_id = Column(Integer, ForeignKey("searchspaces.id", ondelete='CASCADE'), nullable=False)
search_space_id = Column(
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
)
search_space = relationship("SearchSpace", back_populates="podcasts")
class SearchSpace(BaseModel, TimestampMixin):
__tablename__ = "searchspaces"
name = Column(String(100), nullable=False, index=True)
description = Column(String(500), nullable=True)
user_id = Column(UUID(as_uuid=True), ForeignKey("user.id", ondelete='CASCADE'), nullable=False)
user_id = Column(
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
)
user = relationship("User", back_populates="search_spaces")
documents = relationship("Document", back_populates="search_space", order_by="Document.id", cascade="all, delete-orphan")
podcasts = relationship("Podcast", back_populates="search_space", order_by="Podcast.id", cascade="all, delete-orphan")
chats = relationship('Chat', back_populates='search_space', order_by='Chat.id', cascade="all, delete-orphan")
logs = relationship("Log", back_populates="search_space", order_by="Log.id", cascade="all, delete-orphan")
documents = relationship(
"Document",
back_populates="search_space",
order_by="Document.id",
cascade="all, delete-orphan",
)
podcasts = relationship(
"Podcast",
back_populates="search_space",
order_by="Podcast.id",
cascade="all, delete-orphan",
)
chats = relationship(
"Chat",
back_populates="search_space",
order_by="Chat.id",
cascade="all, delete-orphan",
)
logs = relationship(
"Log",
back_populates="search_space",
order_by="Log.id",
cascade="all, delete-orphan",
)
class SearchSourceConnector(BaseModel, TimestampMixin):
__tablename__ = "search_source_connectors"
name = Column(String(100), nullable=False, index=True)
connector_type = Column(SQLAlchemyEnum(SearchSourceConnectorType), nullable=False, unique=True)
connector_type = Column(
SQLAlchemyEnum(SearchSourceConnectorType), nullable=False, unique=True
)
is_indexable = Column(Boolean, nullable=False, default=False)
last_indexed_at = Column(TIMESTAMP(timezone=True), nullable=True)
config = Column(JSON, nullable=False)
user_id = Column(UUID(as_uuid=True), ForeignKey("user.id", ondelete='CASCADE'), nullable=False)
user_id = Column(
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
)
user = relationship("User", back_populates="search_source_connectors")
class LLMConfig(BaseModel, TimestampMixin):
__tablename__ = "llm_configs"
name = Column(String(100), nullable=False, index=True)
# Provider from the enum
provider = Column(SQLAlchemyEnum(LiteLLMProvider), nullable=False)
@@ -202,78 +254,141 @@ class LLMConfig(BaseModel, TimestampMixin):
# API Key should be encrypted before storing
api_key = Column(String, nullable=False)
api_base = Column(String(500), nullable=True)
# For any other parameters that litellm supports
litellm_params = Column(JSON, nullable=True, default={})
user_id = Column(UUID(as_uuid=True), ForeignKey("user.id", ondelete='CASCADE'), nullable=False)
user_id = Column(
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
)
user = relationship("User", back_populates="llm_configs", foreign_keys=[user_id])
class Log(BaseModel, TimestampMixin):
__tablename__ = "logs"
level = Column(SQLAlchemyEnum(LogLevel), nullable=False, index=True)
status = Column(SQLAlchemyEnum(LogStatus), nullable=False, index=True)
message = Column(Text, nullable=False)
source = Column(String(200), nullable=True, index=True) # Service/component that generated the log
source = Column(
String(200), nullable=True, index=True
) # Service/component that generated the log
log_metadata = Column(JSON, nullable=True, default={}) # Additional context data
search_space_id = Column(Integer, ForeignKey("searchspaces.id", ondelete='CASCADE'), nullable=False)
search_space_id = Column(
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
)
search_space = relationship("SearchSpace", back_populates="logs")
if config.AUTH_TYPE == "GOOGLE":
class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
pass
class User(SQLAlchemyBaseUserTableUUID, Base):
oauth_accounts: Mapped[list[OAuthAccount]] = relationship(
"OAuthAccount", lazy="joined"
)
search_spaces = relationship("SearchSpace", back_populates="user")
search_source_connectors = relationship("SearchSourceConnector", back_populates="user")
llm_configs = relationship("LLMConfig", back_populates="user", foreign_keys="LLMConfig.user_id", cascade="all, delete-orphan")
search_source_connectors = relationship(
"SearchSourceConnector", back_populates="user"
)
llm_configs = relationship(
"LLMConfig",
back_populates="user",
foreign_keys="LLMConfig.user_id",
cascade="all, delete-orphan",
)
long_context_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True)
fast_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True)
strategic_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True)
long_context_llm_id = Column(
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
)
fast_llm_id = Column(
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
)
strategic_llm_id = Column(
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
)
long_context_llm = relationship("LLMConfig", foreign_keys=[long_context_llm_id], post_update=True)
fast_llm = relationship("LLMConfig", foreign_keys=[fast_llm_id], post_update=True)
strategic_llm = relationship("LLMConfig", foreign_keys=[strategic_llm_id], post_update=True)
long_context_llm = relationship(
"LLMConfig", foreign_keys=[long_context_llm_id], post_update=True
)
fast_llm = relationship(
"LLMConfig", foreign_keys=[fast_llm_id], post_update=True
)
strategic_llm = relationship(
"LLMConfig", foreign_keys=[strategic_llm_id], post_update=True
)
else:
class User(SQLAlchemyBaseUserTableUUID, Base):
search_spaces = relationship("SearchSpace", back_populates="user")
search_source_connectors = relationship("SearchSourceConnector", back_populates="user")
llm_configs = relationship("LLMConfig", back_populates="user", foreign_keys="LLMConfig.user_id", cascade="all, delete-orphan")
search_source_connectors = relationship(
"SearchSourceConnector", back_populates="user"
)
llm_configs = relationship(
"LLMConfig",
back_populates="user",
foreign_keys="LLMConfig.user_id",
cascade="all, delete-orphan",
)
long_context_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True)
fast_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True)
strategic_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True)
long_context_llm_id = Column(
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
)
fast_llm_id = Column(
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
)
strategic_llm_id = Column(
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
)
long_context_llm = relationship("LLMConfig", foreign_keys=[long_context_llm_id], post_update=True)
fast_llm = relationship("LLMConfig", foreign_keys=[fast_llm_id], post_update=True)
strategic_llm = relationship("LLMConfig", foreign_keys=[strategic_llm_id], post_update=True)
long_context_llm = relationship(
"LLMConfig", foreign_keys=[long_context_llm_id], post_update=True
)
fast_llm = relationship(
"LLMConfig", foreign_keys=[fast_llm_id], post_update=True
)
strategic_llm = relationship(
"LLMConfig", foreign_keys=[strategic_llm_id], post_update=True
)
engine = create_async_engine(DATABASE_URL)
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)
async def setup_indexes():
async with engine.begin() as conn:
# Create indexes
# Create indexes
# Document Summary Indexes
await conn.execute(text('CREATE INDEX IF NOT EXISTS document_vector_index ON documents USING hnsw (embedding public.vector_cosine_ops)'))
await conn.execute(text('CREATE INDEX IF NOT EXISTS document_search_index ON documents USING gin (to_tsvector(\'english\', content))'))
await conn.execute(
text(
"CREATE INDEX IF NOT EXISTS document_vector_index ON documents USING hnsw (embedding public.vector_cosine_ops)"
)
)
await conn.execute(
text(
"CREATE INDEX IF NOT EXISTS document_search_index ON documents USING gin (to_tsvector('english', content))"
)
)
# Document Chuck Indexes
await conn.execute(text('CREATE INDEX IF NOT EXISTS chucks_vector_index ON chunks USING hnsw (embedding public.vector_cosine_ops)'))
await conn.execute(text('CREATE INDEX IF NOT EXISTS chucks_search_index ON chunks USING gin (to_tsvector(\'english\', content))'))
await conn.execute(
text(
"CREATE INDEX IF NOT EXISTS chucks_vector_index ON chunks USING hnsw (embedding public.vector_cosine_ops)"
)
)
await conn.execute(
text(
"CREATE INDEX IF NOT EXISTS chucks_search_index ON chunks USING gin (to_tsvector('english', content))"
)
)
async def create_db_and_tables():
async with engine.begin() as conn:
await conn.execute(text('CREATE EXTENSION IF NOT EXISTS vector'))
await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
await conn.run_sync(Base.metadata.create_all)
await setup_indexes()
@@ -284,14 +399,22 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
if config.AUTH_TYPE == "GOOGLE":
async def get_user_db(session: AsyncSession = Depends(get_async_session)):
yield SQLAlchemyUserDatabase(session, User, OAuthAccount)
else:
async def get_user_db(session: AsyncSession = Depends(get_async_session)):
yield SQLAlchemyUserDatabase(session, User)
async def get_chucks_hybrid_search_retriever(session: AsyncSession = Depends(get_async_session)):
async def get_chucks_hybrid_search_retriever(
session: AsyncSession = Depends(get_async_session),
):
return ChucksHybridSearchRetriever(session)
async def get_documents_hybrid_search_retriever(session: AsyncSession = Depends(get_async_session)):
async def get_documents_hybrid_search_retriever(
session: AsyncSession = Depends(get_async_session),
):
return DocumentHybridSearchRetriever(session)
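Both retriever providers above are ordinary FastAPI dependencies built on get_async_session. A hedged usage sketch follows; the route path, query parameters, and the hard-coded user id are illustrative assumptions rather than code from this repository.

from fastapi import APIRouter, Depends

from app.db import get_chucks_hybrid_search_retriever

router = APIRouter()


@router.get("/search/chunks")  # illustrative path, not an actual route in the repo
async def search_chunks(
    query: str,
    top_k: int = 10,
    retriever=Depends(get_chucks_hybrid_search_retriever),
):
    # In the real app the user id would come from the authenticated-user dependency.
    return await retriever.hybrid_search(
        query_text=query, top_k=top_k, user_id="00000000-0000-0000-0000-000000000000"
    )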


@@ -1,9 +1,12 @@
from datetime import UTC, datetime
from langchain_core.prompts.prompt import PromptTemplate
from datetime import datetime, timezone
DATE_TODAY = "Today's date is " + datetime.now(timezone.utc).astimezone().isoformat() + '\n'
DATE_TODAY = "Today's date is " + datetime.now(UTC).astimezone().isoformat() + "\n"
SUMMARY_PROMPT = DATE_TODAY + """
SUMMARY_PROMPT = (
DATE_TODAY
+ """
<INSTRUCTIONS>
<context>
You are an expert document analyst and summarization specialist tasked with distilling complex information into clear,
@@ -96,8 +99,8 @@ SUMMARY_PROMPT = DATE_TODAY + """
</document_to_summarize>
</INSTRUCTIONS>
"""
)
SUMMARY_PROMPT_TEMPLATE = PromptTemplate(
input_variables=["document"],
template=SUMMARY_PROMPT
)
input_variables=["document"], template=SUMMARY_PROMPT
)
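SUMMARY_PROMPT_TEMPLATE is a plain langchain PromptTemplate with a single "document" variable, so rendering it is one call. The import path in the sketch below is an assumption about where this module lives; adjust it to the actual package layout.

# Usage sketch; app.prompts is an assumed module path.
from app.prompts import SUMMARY_PROMPT_TEMPLATE

rendered = SUMMARY_PROMPT_TEMPLATE.format(
    document="Full text of the document to summarize..."
)
# `rendered` starts with the DATE_TODAY sentence followed by the <INSTRUCTIONS> block
# and is ready to be sent to an LLM.
print(rendered[:200])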


@@ -2,34 +2,41 @@ class ChucksHybridSearchRetriever:
def __init__(self, db_session):
"""
Initialize the hybrid search retriever with a database session.
Args:
db_session: SQLAlchemy AsyncSession from FastAPI dependency injection
"""
self.db_session = db_session
async def vector_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None) -> list:
async def vector_search(
self,
query_text: str,
top_k: int,
user_id: str,
search_space_id: int | None = None,
) -> list:
"""
Perform vector similarity search on chunks.
Args:
query_text: The search query text
top_k: Number of results to return
user_id: The ID of the user performing the search
search_space_id: Optional search space ID to filter results
Returns:
List of chunks sorted by vector similarity
"""
from sqlalchemy import select, func
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from app.db import Chunk, Document, SearchSpace
from app.config import config
from app.db import Chunk, Document, SearchSpace
# Get embedding for the query
embedding_model = config.embedding_model_instance
query_embedding = embedding_model.embed(query_text)
# Build the base query with user ownership check
query = (
select(Chunk)
@@ -38,45 +45,48 @@ class ChucksHybridSearchRetriever:
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
.where(SearchSpace.user_id == user_id)
)
# Add search space filter if provided
if search_space_id is not None:
query = query.where(Document.search_space_id == search_space_id)
# Add vector similarity ordering
query = (
query
.order_by(Chunk.embedding.op("<=>")(query_embedding))
.limit(top_k)
)
query = query.order_by(Chunk.embedding.op("<=>")(query_embedding)).limit(top_k)
# Execute the query
result = await self.db_session.execute(query)
chunks = result.scalars().all()
return chunks
async def full_text_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None) -> list:
async def full_text_search(
self,
query_text: str,
top_k: int,
user_id: str,
search_space_id: int | None = None,
) -> list:
"""
Perform full-text keyword search on chunks.
Args:
query_text: The search query text
top_k: Number of results to return
user_id: The ID of the user performing the search
search_space_id: Optional search space ID to filter results
Returns:
List of chunks sorted by text relevance
"""
from sqlalchemy import select, func, text
from sqlalchemy import func, select
from sqlalchemy.orm import joinedload
from app.db import Chunk, Document, SearchSpace
# Create tsvector and tsquery for PostgreSQL full-text search
tsvector = func.to_tsvector('english', Chunk.content)
tsquery = func.plainto_tsquery('english', query_text)
tsvector = func.to_tsvector("english", Chunk.content)
tsquery = func.plainto_tsquery("english", query_text)
# Build the base query with user ownership check
query = (
select(Chunk)
@@ -84,64 +94,70 @@ class ChucksHybridSearchRetriever:
.join(Document, Chunk.document_id == Document.id)
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
.where(SearchSpace.user_id == user_id)
.where(tsvector.op("@@")(tsquery)) # Only include results that match the query
.where(
tsvector.op("@@")(tsquery)
) # Only include results that match the query
)
# Add search space filter if provided
if search_space_id is not None:
query = query.where(Document.search_space_id == search_space_id)
# Add text search ranking
query = (
query
.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
.limit(top_k)
)
query = query.order_by(func.ts_rank_cd(tsvector, tsquery).desc()).limit(top_k)
# Execute the query
result = await self.db_session.execute(query)
chunks = result.scalars().all()
return chunks
async def hybrid_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None, document_type: str = None) -> list:
async def hybrid_search(
self,
query_text: str,
top_k: int,
user_id: str,
search_space_id: int | None = None,
document_type: str | None = None,
) -> list:
"""
Combine vector similarity and full-text search results using Reciprocal Rank Fusion.
Args:
query_text: The search query text
top_k: Number of results to return
user_id: The ID of the user performing the search
search_space_id: Optional search space ID to filter results
document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL")
Returns:
List of dictionaries containing chunk data and relevance scores
"""
from sqlalchemy import select, func, text
from sqlalchemy import func, select, text
from sqlalchemy.orm import joinedload
from app.db import Chunk, Document, SearchSpace, DocumentType
from app.config import config
from app.db import Chunk, Document, DocumentType, SearchSpace
# Get embedding for the query
embedding_model = config.embedding_model_instance
query_embedding = embedding_model.embed(query_text)
# Constants for RRF calculation
k = 60 # Constant for RRF calculation
n_results = top_k * 2 # Get more results for better fusion
# Create tsvector and tsquery for PostgreSQL full-text search
tsvector = func.to_tsvector('english', Chunk.content)
tsquery = func.plainto_tsquery('english', query_text)
tsvector = func.to_tsvector("english", Chunk.content)
tsquery = func.plainto_tsquery("english", query_text)
# Base conditions for document filtering
base_conditions = [SearchSpace.user_id == user_id]
# Add search space filter if provided
if search_space_id is not None:
base_conditions.append(Document.search_space_id == search_space_id)
# Add document type filter if provided
if document_type is not None:
# Convert string to enum value if needed
@@ -154,90 +170,97 @@ class ChucksHybridSearchRetriever:
return []
else:
base_conditions.append(Document.document_type == document_type)
# CTE for semantic search with user ownership check
semantic_search_cte = (
select(
Chunk.id,
func.rank().over(order_by=Chunk.embedding.op("<=>")(query_embedding)).label("rank")
func.rank()
.over(order_by=Chunk.embedding.op("<=>")(query_embedding))
.label("rank"),
)
.join(Document, Chunk.document_id == Document.id)
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
.where(*base_conditions)
)
semantic_search_cte = (
semantic_search_cte
.order_by(Chunk.embedding.op("<=>")(query_embedding))
semantic_search_cte.order_by(Chunk.embedding.op("<=>")(query_embedding))
.limit(n_results)
.cte("semantic_search")
)
# CTE for keyword search with user ownership check
keyword_search_cte = (
select(
Chunk.id,
func.rank().over(order_by=func.ts_rank_cd(tsvector, tsquery).desc()).label("rank")
func.rank()
.over(order_by=func.ts_rank_cd(tsvector, tsquery).desc())
.label("rank"),
)
.join(Document, Chunk.document_id == Document.id)
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
.where(*base_conditions)
.where(tsvector.op("@@")(tsquery))
)
keyword_search_cte = (
keyword_search_cte
.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
keyword_search_cte.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
.limit(n_results)
.cte("keyword_search")
)
# Final combined query using a FULL OUTER JOIN with RRF scoring
final_query = (
select(
Chunk,
(
func.coalesce(1.0 / (k + semantic_search_cte.c.rank), 0.0) +
func.coalesce(1.0 / (k + keyword_search_cte.c.rank), 0.0)
).label("score")
func.coalesce(1.0 / (k + semantic_search_cte.c.rank), 0.0)
+ func.coalesce(1.0 / (k + keyword_search_cte.c.rank), 0.0)
).label("score"),
)
.select_from(
semantic_search_cte.outerjoin(
keyword_search_cte,
keyword_search_cte,
semantic_search_cte.c.id == keyword_search_cte.c.id,
full=True
full=True,
)
)
.join(
Chunk,
Chunk.id == func.coalesce(semantic_search_cte.c.id, keyword_search_cte.c.id)
Chunk.id
== func.coalesce(semantic_search_cte.c.id, keyword_search_cte.c.id),
)
.options(joinedload(Chunk.document))
.order_by(text("score DESC"))
.limit(top_k)
)
# Execute the query
result = await self.db_session.execute(final_query)
chunks_with_scores = result.all()
# If no results were found, return an empty list
if not chunks_with_scores:
return []
# Convert to serializable dictionaries if no reranker is available or if reranking failed
serialized_results = []
for chunk, score in chunks_with_scores:
serialized_results.append({
"chunk_id": chunk.id,
"content": chunk.content,
"score": float(score), # Ensure score is a Python float
"document": {
"id": chunk.document.id,
"title": chunk.document.title,
"document_type": chunk.document.document_type.value if hasattr(chunk.document, 'document_type') else None,
"metadata": chunk.document.document_metadata
serialized_results.append(
{
"chunk_id": chunk.id,
"content": chunk.content,
"score": float(score), # Ensure score is a Python float
"document": {
"id": chunk.document.id,
"title": chunk.document.title,
"document_type": chunk.document.document_type.value
if hasattr(chunk.document, "document_type")
else None,
"metadata": chunk.document.document_metadata,
},
}
})
)
return serialized_results
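For reference, the Reciprocal Rank Fusion score computed in the final query above is simply 1/(k + rank) summed over the two result lists, with a missing rank contributing zero (that is what the coalesce calls do). A small worked example with k = 60, using purely illustrative ranks:

def rrf_score(semantic_rank: int | None, keyword_rank: int | None, k: int = 60) -> float:
    """Mirror of the score expression in hybrid_search above; ranks start at 1."""
    semantic = 1.0 / (k + semantic_rank) if semantic_rank is not None else 0.0
    keyword = 1.0 / (k + keyword_rank) if keyword_rank is not None else 0.0
    return semantic + keyword


print(rrf_score(1, 3))     # ~0.0323: ranked highly by both searches wins
print(rrf_score(1, None))  # ~0.0164: semantic-only match
print(rrf_score(None, 3))  # ~0.0159: keyword-only match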


@@ -2,34 +2,41 @@ class DocumentHybridSearchRetriever:
def __init__(self, db_session):
"""
Initialize the hybrid search retriever with a database session.
Args:
db_session: SQLAlchemy AsyncSession from FastAPI dependency injection
"""
self.db_session = db_session
async def vector_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None) -> list:
async def vector_search(
self,
query_text: str,
top_k: int,
user_id: str,
search_space_id: int | None = None,
) -> list:
"""
Perform vector similarity search on documents.
Args:
query_text: The search query text
top_k: Number of results to return
user_id: The ID of the user performing the search
search_space_id: Optional search space ID to filter results
Returns:
List of documents sorted by vector similarity
"""
from sqlalchemy import select, func
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from app.db import Document, SearchSpace
from app.config import config
from app.db import Document, SearchSpace
# Get embedding for the query
embedding_model = config.embedding_model_instance
query_embedding = embedding_model.embed(query_text)
# Build the base query with user ownership check
query = (
select(Document)
@@ -37,107 +44,118 @@ class DocumentHybridSearchRetriever:
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
.where(SearchSpace.user_id == user_id)
)
# Add search space filter if provided
if search_space_id is not None:
query = query.where(Document.search_space_id == search_space_id)
# Add vector similarity ordering
query = (
query
.order_by(Document.embedding.op("<=>")(query_embedding))
.limit(top_k)
query = query.order_by(Document.embedding.op("<=>")(query_embedding)).limit(
top_k
)
# Execute the query
result = await self.db_session.execute(query)
documents = result.scalars().all()
return documents
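A minimal usage sketch for vector_search, assuming the async_session_maker factory referenced later in this commit and that DocumentHybridSearchRetriever is importable from its module (ids below are placeholders):

import asyncio

from app.db import async_session_maker  # session factory used elsewhere in this codebase

async def demo_vector_search() -> None:
    async with async_session_maker() as session:
        retriever = DocumentHybridSearchRetriever(session)  # class defined in this module
        docs = await retriever.vector_search(
            query_text="personal AI research assistant",
            top_k=5,
            user_id="00000000-0000-0000-0000-000000000000",  # placeholder user id
            search_space_id=1,  # placeholder search space
        )
        for doc in docs:
            print(doc.id, doc.title)

# asyncio.run(demo_vector_search())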
async def full_text_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None) -> list:
async def full_text_search(
self,
query_text: str,
top_k: int,
user_id: str,
search_space_id: int | None = None,
) -> list:
"""
Perform full-text keyword search on documents.
Args:
query_text: The search query text
top_k: Number of results to return
user_id: The ID of the user performing the search
search_space_id: Optional search space ID to filter results
Returns:
List of documents sorted by text relevance
"""
from sqlalchemy import select, func, text
from sqlalchemy import func, select
from sqlalchemy.orm import joinedload
from app.db import Document, SearchSpace
# Create tsvector and tsquery for PostgreSQL full-text search
tsvector = func.to_tsvector('english', Document.content)
tsquery = func.plainto_tsquery('english', query_text)
tsvector = func.to_tsvector("english", Document.content)
tsquery = func.plainto_tsquery("english", query_text)
# Build the base query with user ownership check
query = (
select(Document)
.options(joinedload(Document.search_space))
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
.where(SearchSpace.user_id == user_id)
.where(tsvector.op("@@")(tsquery)) # Only include results that match the query
.where(
tsvector.op("@@")(tsquery)
) # Only include results that match the query
)
# Add search space filter if provided
if search_space_id is not None:
query = query.where(Document.search_space_id == search_space_id)
# Add text search ranking
query = (
query
.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
.limit(top_k)
)
query = query.order_by(func.ts_rank_cd(tsvector, tsquery).desc()).limit(top_k)
# Execute the query
result = await self.db_session.execute(query)
documents = result.scalars().all()
return documents
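The ranking above compiles down to PostgreSQL's full-text primitives. Roughly the following SQL is expected to be emitted (approximation only; the table names are guesses from the model names and the generated statement will differ in detail):

from sqlalchemy import text

APPROX_FULL_TEXT_SQL = text(
    """
    SELECT d.*
    FROM documents AS d
    JOIN searchspaces AS s ON d.search_space_id = s.id
    WHERE s.user_id = :user_id
      AND to_tsvector('english', d.content) @@ plainto_tsquery('english', :query)
    ORDER BY ts_rank_cd(
        to_tsvector('english', d.content),
        plainto_tsquery('english', :query)
    ) DESC
    LIMIT :top_k
    """
)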
async def hybrid_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None, document_type: str = None) -> list:
async def hybrid_search(
self,
query_text: str,
top_k: int,
user_id: str,
search_space_id: int | None = None,
document_type: str | None = None,
) -> list:
"""
Combine vector similarity and full-text search results using Reciprocal Rank Fusion.
Args:
query_text: The search query text
top_k: Number of results to return
user_id: The ID of the user performing the search
search_space_id: Optional search space ID to filter results
document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL")
"""
from sqlalchemy import select, func, text
from sqlalchemy import func, select, text
from sqlalchemy.orm import joinedload
from app.db import Document, SearchSpace, DocumentType
from app.config import config
from app.db import Document, DocumentType, SearchSpace
# Get embedding for the query
embedding_model = config.embedding_model_instance
query_embedding = embedding_model.embed(query_text)
# Constants for RRF calculation
k = 60 # Constant for RRF calculation
n_results = top_k * 2 # Get more results for better fusion
# Create tsvector and tsquery for PostgreSQL full-text search
tsvector = func.to_tsvector('english', Document.content)
tsquery = func.plainto_tsquery('english', query_text)
tsvector = func.to_tsvector("english", Document.content)
tsquery = func.plainto_tsquery("english", query_text)
# Base conditions for document filtering
base_conditions = [SearchSpace.user_id == user_id]
# Add search space filter if provided
if search_space_id is not None:
base_conditions.append(Document.search_space_id == search_space_id)
# Add document type filter if provided
if document_type is not None:
# Convert string to enum value if needed
@ -150,98 +168,112 @@ class DocumentHybridSearchRetriever:
return []
else:
base_conditions.append(Document.document_type == document_type)
# CTE for semantic search with user ownership check
semantic_search_cte = (
select(
Document.id,
func.rank().over(order_by=Document.embedding.op("<=>")(query_embedding)).label("rank")
func.rank()
.over(order_by=Document.embedding.op("<=>")(query_embedding))
.label("rank"),
)
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
.where(*base_conditions)
)
semantic_search_cte = (
semantic_search_cte
.order_by(Document.embedding.op("<=>")(query_embedding))
semantic_search_cte.order_by(Document.embedding.op("<=>")(query_embedding))
.limit(n_results)
.cte("semantic_search")
)
# CTE for keyword search with user ownership check
keyword_search_cte = (
select(
Document.id,
func.rank().over(order_by=func.ts_rank_cd(tsvector, tsquery).desc()).label("rank")
func.rank()
.over(order_by=func.ts_rank_cd(tsvector, tsquery).desc())
.label("rank"),
)
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
.where(*base_conditions)
.where(tsvector.op("@@")(tsquery))
)
keyword_search_cte = (
keyword_search_cte
.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
keyword_search_cte.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
.limit(n_results)
.cte("keyword_search")
)
# Final combined query using a FULL OUTER JOIN with RRF scoring
final_query = (
select(
Document,
(
func.coalesce(1.0 / (k + semantic_search_cte.c.rank), 0.0) +
func.coalesce(1.0 / (k + keyword_search_cte.c.rank), 0.0)
).label("score")
func.coalesce(1.0 / (k + semantic_search_cte.c.rank), 0.0)
+ func.coalesce(1.0 / (k + keyword_search_cte.c.rank), 0.0)
).label("score"),
)
.select_from(
semantic_search_cte.outerjoin(
keyword_search_cte,
semantic_search_cte.c.id == keyword_search_cte.c.id,
full=True
full=True,
)
)
.join(
Document,
Document.id == func.coalesce(semantic_search_cte.c.id, keyword_search_cte.c.id)
Document.id
== func.coalesce(semantic_search_cte.c.id, keyword_search_cte.c.id),
)
.options(joinedload(Document.search_space))
.order_by(text("score DESC"))
.limit(top_k)
)
# Execute the query
result = await self.db_session.execute(final_query)
documents_with_scores = result.all()
# If no results were found, return an empty list
if not documents_with_scores:
return []
# Convert to serializable dictionaries
serialized_results = []
for document, score in documents_with_scores:
# Fetch associated chunks for this document
from sqlalchemy import select
from app.db import Chunk
chunks_query = select(Chunk).where(Chunk.document_id == document.id).order_by(Chunk.id)
chunks_query = (
select(Chunk).where(Chunk.document_id == document.id).order_by(Chunk.id)
)
chunks_result = await self.db_session.execute(chunks_query)
chunks = chunks_result.scalars().all()
# Concatenate chunks content
concatenated_chunks_content = " ".join([chunk.content for chunk in chunks]) if chunks else document.content
serialized_results.append({
"document_id": document.id,
"title": document.title,
"content": document.content,
"chunks_content": concatenated_chunks_content,
"document_type": document.document_type.value if hasattr(document, 'document_type') else None,
"metadata": document.document_metadata,
"score": float(score), # Ensure score is a Python float
"search_space_id": document.search_space_id
})
return serialized_results
concatenated_chunks_content = (
" ".join([chunk.content for chunk in chunks])
if chunks
else document.content
)
serialized_results.append(
{
"document_id": document.id,
"title": document.title,
"content": document.content,
"chunks_content": concatenated_chunks_content,
"document_type": document.document_type.value
if hasattr(document, "document_type")
else None,
"metadata": document.document_metadata,
"score": float(score), # Ensure score is a Python float
"search_space_id": document.search_space_id,
}
)
return serialized_results
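A short sketch of consuming the serialized results above; the keys match the dictionaries built in this method, while the session and retriever setup are illustrative:

async def top_titles(session, query: str, user_id: str) -> list[str]:
    retriever = DocumentHybridSearchRetriever(session)
    results = await retriever.hybrid_search(
        query_text=query, top_k=10, user_id=user_id, document_type="FILE"
    )
    # Each entry carries the fused RRF score alongside document metadata
    return [f"{r['title']} ({r['score']:.4f})" for r in results]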

View file

@ -1,11 +1,12 @@
from fastapi import APIRouter
from .search_spaces_routes import router as search_spaces_router
from .documents_routes import router as documents_router
from .podcasts_routes import router as podcasts_router
from .chats_routes import router as chats_router
from .search_source_connectors_routes import router as search_source_connectors_router
from .documents_routes import router as documents_router
from .llm_config_routes import router as llm_config_router
from .logs_routes import router as logs_router
from .podcasts_routes import router as podcasts_router
from .search_source_connectors_routes import router as search_source_connectors_router
from .search_spaces_routes import router as search_spaces_router
router = APIRouter()
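The include calls that presumably follow are outside the displayed hunk; a hedged sketch of the usual FastAPI wiring for the routers imported above:

# Illustrative wiring only, the actual include_router calls are not shown in this diff.
router.include_router(search_spaces_router)
router.include_router(documents_router)
router.include_router(podcasts_router)
router.include_router(chats_router)
router.include_router(search_source_connectors_router)
router.include_router(llm_config_router)
router.include_router(logs_router)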

View file

@ -1,38 +1,40 @@
from typing import List
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from langchain.schema import AIMessage, HumanMessage
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import Chat, SearchSpace, User, get_async_session
from app.schemas import AISDKChatRequest, ChatCreate, ChatRead, ChatUpdate
from app.tasks.stream_connector_search_results import stream_connector_search_results
from app.users import current_active_user
from app.utils.check_ownership import check_ownership
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from langchain.schema import HumanMessage, AIMessage
router = APIRouter()
@router.post("/chat")
async def handle_chat_data(
request: AISDKChatRequest,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
messages = request.messages
if messages[-1]['role'] != "user":
if messages[-1]["role"] != "user":
raise HTTPException(
status_code=400, detail="Last message must be a user message")
status_code=400, detail="Last message must be a user message"
)
user_query = messages[-1]['content']
search_space_id = request.data.get('search_space_id')
research_mode: str = request.data.get('research_mode')
selected_connectors: List[str] = request.data.get('selected_connectors')
document_ids_to_add_in_context: List[int] = request.data.get('document_ids_to_add_in_context')
search_mode_str = request.data.get('search_mode', "CHUNKS")
user_query = messages[-1]["content"]
search_space_id = request.data.get("search_space_id")
research_mode: str = request.data.get("research_mode")
selected_connectors: list[str] = request.data.get("selected_connectors")
document_ids_to_add_in_context: list[int] = request.data.get(
"document_ids_to_add_in_context"
)
search_mode_str = request.data.get("search_mode", "CHUNKS")
# Convert search_space_id to integer if it's a string
if search_space_id and isinstance(search_space_id, str):
@ -40,21 +42,23 @@ async def handle_chat_data(
search_space_id = int(search_space_id)
except ValueError:
raise HTTPException(
status_code=400, detail="Invalid search_space_id format")
status_code=400, detail="Invalid search_space_id format"
) from None
# Check if the search space belongs to the current user
try:
await check_ownership(session, SearchSpace, search_space_id, user)
except HTTPException:
raise HTTPException(
status_code=403, detail="You don't have access to this search space")
status_code=403, detail="You don't have access to this search space"
) from None
langchain_chat_history = []
for message in messages[:-1]:
if message['role'] == "user":
langchain_chat_history.append(HumanMessage(content=message['content']))
elif message['role'] == "assistant":
langchain_chat_history.append(AIMessage(content=message['content']))
if message["role"] == "user":
langchain_chat_history.append(HumanMessage(content=message["content"]))
elif message["role"] == "assistant":
langchain_chat_history.append(AIMessage(content=message["content"]))
response = StreamingResponse(
stream_connector_search_results(
@ -69,7 +73,7 @@ async def handle_chat_data(
document_ids_to_add_in_context,
)
)
response.headers["x-vercel-ai-data-stream"] = "v1"
return response
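A hedged client-side sketch for reading this endpoint's Vercel AI data stream with httpx. The base URL, router prefix, auth headers, and the research_mode value are assumptions; the payload fields mirror what handle_chat_data reads from request.data:

import httpx

async def stream_chat_answer() -> None:
    payload = {
        "messages": [{"role": "user", "content": "Summarize my notes on pgvector"}],
        "data": {
            "search_space_id": 1,  # placeholder id
            "research_mode": "GENERAL",  # assumed value
            "selected_connectors": ["CRAWLED_URL"],
            "document_ids_to_add_in_context": [],
            "search_mode": "CHUNKS",
        },
    }
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:  # assumed host
        async with client.stream("POST", "/chat", json=payload) as response:
            async for line in response.aiter_lines():
                print(line)  # raw AI SDK data-stream frames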
@ -78,7 +82,7 @@ async def handle_chat_data(
async def create_chat(
chat: ChatCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
try:
await check_ownership(session, SearchSpace, chat.search_space_id, user)
@ -89,52 +93,57 @@ async def create_chat(
return db_chat
except HTTPException:
raise
except IntegrityError as e:
except IntegrityError:
await session.rollback()
raise HTTPException(
status_code=400, detail="Database constraint violation. Please check your input data.")
except OperationalError as e:
status_code=400,
detail="Database constraint violation. Please check your input data.",
) from None
except OperationalError:
await session.rollback()
raise HTTPException(
status_code=503, detail="Database operation failed. Please try again later.")
except Exception as e:
status_code=503, detail="Database operation failed. Please try again later."
) from None
except Exception:
await session.rollback()
raise HTTPException(
status_code=500, detail="An unexpected error occurred while creating the chat.")
status_code=500,
detail="An unexpected error occurred while creating the chat.",
) from None
@router.get("/chats/", response_model=List[ChatRead])
@router.get("/chats/", response_model=list[ChatRead])
async def read_chats(
skip: int = 0,
limit: int = 100,
search_space_id: int = None,
search_space_id: int | None = None,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
try:
query = select(Chat).join(SearchSpace).filter(SearchSpace.user_id == user.id)
# Filter by search_space_id if provided
if search_space_id is not None:
query = query.filter(Chat.search_space_id == search_space_id)
result = await session.execute(
query.offset(skip).limit(limit)
)
result = await session.execute(query.offset(skip).limit(limit))
return result.scalars().all()
except OperationalError:
raise HTTPException(
status_code=503, detail="Database operation failed. Please try again later.")
status_code=503, detail="Database operation failed. Please try again later."
) from None
except Exception:
raise HTTPException(
status_code=500, detail="An unexpected error occurred while fetching chats.")
status_code=500, detail="An unexpected error occurred while fetching chats."
) from None
@router.get("/chats/{chat_id}", response_model=ChatRead)
async def read_chat(
chat_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
try:
result = await session.execute(
@ -145,14 +154,19 @@ async def read_chat(
chat = result.scalars().first()
if not chat:
raise HTTPException(
status_code=404, detail="Chat not found or you don't have permission to access it")
status_code=404,
detail="Chat not found or you don't have permission to access it",
)
return chat
except OperationalError:
raise HTTPException(
status_code=503, detail="Database operation failed. Please try again later.")
status_code=503, detail="Database operation failed. Please try again later."
) from None
except Exception:
raise HTTPException(
status_code=500, detail="An unexpected error occurred while fetching the chat.")
status_code=500,
detail="An unexpected error occurred while fetching the chat.",
) from None
@router.put("/chats/{chat_id}", response_model=ChatRead)
@ -160,7 +174,7 @@ async def update_chat(
chat_id: int,
chat_update: ChatUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
try:
db_chat = await read_chat(chat_id, session, user)
@ -175,22 +189,27 @@ async def update_chat(
except IntegrityError:
await session.rollback()
raise HTTPException(
status_code=400, detail="Database constraint violation. Please check your input data.")
status_code=400,
detail="Database constraint violation. Please check your input data.",
) from None
except OperationalError:
await session.rollback()
raise HTTPException(
status_code=503, detail="Database operation failed. Please try again later.")
status_code=503, detail="Database operation failed. Please try again later."
) from None
except Exception:
await session.rollback()
raise HTTPException(
status_code=500, detail="An unexpected error occurred while updating the chat.")
status_code=500,
detail="An unexpected error occurred while updating the chat.",
) from None
@router.delete("/chats/{chat_id}", response_model=dict)
async def delete_chat(
chat_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
try:
db_chat = await read_chat(chat_id, session, user)
@ -202,81 +221,16 @@ async def delete_chat(
except IntegrityError:
await session.rollback()
raise HTTPException(
status_code=400, detail="Cannot delete chat due to existing dependencies.")
status_code=400, detail="Cannot delete chat due to existing dependencies."
) from None
except OperationalError:
await session.rollback()
raise HTTPException(
status_code=503, detail="Database operation failed. Please try again later.")
status_code=503, detail="Database operation failed. Please try again later."
) from None
except Exception:
await session.rollback()
raise HTTPException(
status_code=500, detail="An unexpected error occurred while deleting the chat.")
# test_data = [
# {
# "type": "TERMINAL_INFO",
# "content": [
# {
# "id": 1,
# "text": "Starting to search for crawled URLs...",
# "type": "info"
# },
# {
# "id": 2,
# "text": "Found 2 relevant crawled URLs",
# "type": "success"
# }
# ]
# },
# {
# "type": "SOURCES",
# "content": [
# {
# "id": 1,
# "name": "Crawled URLs",
# "type": "CRAWLED_URL",
# "sources": [
# {
# "id": 1,
# "title": "Webpage Title",
# "description": "Webpage Dec",
# "url": "https://jsoneditoronline.org/"
# },
# {
# "id": 2,
# "title": "Webpage Title",
# "description": "Webpage Dec",
# "url": "https://www.google.com/"
# }
# ]
# },
# {
# "id": 2,
# "name": "Files",
# "type": "FILE",
# "sources": [
# {
# "id": 3,
# "title": "Webpage Title",
# "description": "Webpage Dec",
# "url": "https://jsoneditoronline.org/"
# },
# {
# "id": 4,
# "title": "Webpage Title",
# "description": "Webpage Dec",
# "url": "https://www.google.com/"
# }
# ]
# }
# ]
# },
# {
# "type": "ANSWER",
# "content": [
# "## SurfSense Introduction",
# "Surfsense is A Personal NotebookLM and Perplexity-like AI Assistant for Everyone. Research and Never forget Anything. [1] [3]"
# ]
# }
# ]
status_code=500,
detail="An unexpected error occurred while deleting the chat.",
) from None
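The "from None" and "from e" suffixes added throughout this file control exception chaining (ruff rule B904). A small self-contained illustration of the difference:

def parse_age(raw: str) -> int:
    try:
        return int(raw)
    except ValueError as exc:
        # "from exc" keeps the original ValueError as __cause__ in the traceback;
        # "from None" would suppress it so callers only see the RuntimeError.
        raise RuntimeError(f"invalid age: {raw!s}") from exc

try:
    parse_age("ten")
except RuntimeError as err:
    print(err, "| caused by:", err.__cause__)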

View file

@ -1,23 +1,35 @@
from litellm import atranscription
from fastapi import APIRouter, Depends, BackgroundTasks, UploadFile, Form, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from typing import List
from app.db import Log, get_async_session, User, SearchSpace, Document, DocumentType
from app.schemas import DocumentsCreate, DocumentUpdate, DocumentRead
from app.users import current_active_user
from app.utils.check_ownership import check_ownership
from app.tasks.background_tasks import add_received_markdown_file_document, add_extension_received_document, add_received_file_document_using_unstructured, add_crawled_url_document, add_youtube_video_document, add_received_file_document_using_llamacloud, add_received_file_document_using_docling
from app.config import config as app_config
# Force asyncio to use standard event loop before unstructured imports
import asyncio
from fastapi import APIRouter, BackgroundTasks, Depends, Form, HTTPException, UploadFile
from litellm import atranscription
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.config import config as app_config
from app.db import Document, DocumentType, Log, SearchSpace, User, get_async_session
from app.schemas import DocumentRead, DocumentsCreate, DocumentUpdate
from app.services.task_logging_service import TaskLoggingService
from app.tasks.background_tasks import (
add_crawled_url_document,
add_extension_received_document,
add_received_file_document_using_docling,
add_received_file_document_using_llamacloud,
add_received_file_document_using_unstructured,
add_received_markdown_file_document,
add_youtube_video_document,
)
from app.users import current_active_user
from app.utils.check_ownership import check_ownership
try:
asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
except RuntimeError:
except RuntimeError as e:
print("Error setting event loop policy", e)
pass
import os
os.environ["UNSTRUCTURED_HAS_PATCHED_LOOP"] = "1"
@ -29,7 +41,7 @@ async def create_documents(
request: DocumentsCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
fastapi_background_tasks: BackgroundTasks = BackgroundTasks()
fastapi_background_tasks: BackgroundTasks = BackgroundTasks(),
):
try:
# Check if the user owns the search space
@ -41,7 +53,7 @@ async def create_documents(
process_extension_document_with_new_session,
individual_document,
request.search_space_id,
str(user.id)
str(user.id),
)
elif request.document_type == DocumentType.CRAWLED_URL:
for url in request.content:
@ -49,7 +61,7 @@ async def create_documents(
process_crawled_url_with_new_session,
url,
request.search_space_id,
str(user.id)
str(user.id),
)
elif request.document_type == DocumentType.YOUTUBE_VIDEO:
for url in request.content:
@ -57,13 +69,10 @@ async def create_documents(
process_youtube_video_with_new_session,
url,
request.search_space_id,
str(user.id)
str(user.id),
)
else:
raise HTTPException(
status_code=400,
detail="Invalid document type"
)
raise HTTPException(status_code=400, detail="Invalid document type")
await session.commit()
return {"message": "Documents processed successfully"}
@ -72,18 +81,17 @@ async def create_documents(
except Exception as e:
await session.rollback()
raise HTTPException(
status_code=500,
detail=f"Failed to process documents: {str(e)}"
)
status_code=500, detail=f"Failed to process documents: {e!s}"
) from e
@router.post("/documents/fileupload")
async def create_documents(
async def create_documents_file_upload(
files: list[UploadFile],
search_space_id: int = Form(...),
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
fastapi_background_tasks: BackgroundTasks = BackgroundTasks()
fastapi_background_tasks: BackgroundTasks = BackgroundTasks(),
):
try:
await check_ownership(session, SearchSpace, search_space_id, user)
@ -94,31 +102,32 @@ async def create_documents(
for file in files:
try:
# Save file to a temporary location to avoid stream issues
import tempfile
import aiofiles
import os
import tempfile
# Create temp file
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_file:
with tempfile.NamedTemporaryFile(
delete=False, suffix=os.path.splitext(file.filename)[1]
) as temp_file:
temp_path = temp_file.name
# Write uploaded file to temp file
content = await file.read()
with open(temp_path, "wb") as f:
f.write(content)
fastapi_background_tasks.add_task(
process_file_in_background_with_new_session,
temp_path,
file.filename,
search_space_id,
str(user.id)
str(user.id),
)
except Exception as e:
raise HTTPException(
status_code=422,
detail=f"Failed to process file {file.filename}: {str(e)}"
)
detail=f"Failed to process file {file.filename}: {e!s}",
) from e
await session.commit()
return {"message": "Files uploaded for processing"}
@ -127,9 +136,8 @@ async def create_documents(
except Exception as e:
await session.rollback()
raise HTTPException(
status_code=500,
detail=f"Failed to upload files: {str(e)}"
)
status_code=500, detail=f"Failed to upload files: {e!s}"
) from e
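The !s conversion used in the f-strings above is simply an explicit str() call (ruff rule RUF010 rewrites f"{str(e)}" as f"{e!s}"). A tiny check:

class BoomError(Exception):
    def __str__(self) -> str:
        return "boom"

err = BoomError()
assert f"Failed: {err!s}" == f"Failed: {str(err)}" == "Failed: boom"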
async def process_file_in_background(
@ -139,64 +147,71 @@ async def process_file_in_background(
user_id: str,
session: AsyncSession,
task_logger: TaskLoggingService,
log_entry: Log
log_entry: Log,
):
try:
# Check if the file is a markdown or text file
if filename.lower().endswith(('.md', '.markdown', '.txt')):
if filename.lower().endswith((".md", ".markdown", ".txt")):
await task_logger.log_task_progress(
log_entry,
f"Processing markdown/text file: {filename}",
{"file_type": "markdown", "processing_stage": "reading_file"}
{"file_type": "markdown", "processing_stage": "reading_file"},
)
# For markdown files, read the content directly
with open(file_path, 'r', encoding='utf-8') as f:
with open(file_path, encoding="utf-8") as f:
markdown_content = f.read()
# Clean up the temp file
import os
try:
os.unlink(file_path)
except:
except Exception as e:
print("Error deleting temp file", e)
pass
await task_logger.log_task_progress(
log_entry,
f"Creating document from markdown content: {filename}",
{"processing_stage": "creating_document", "content_length": len(markdown_content)}
{
"processing_stage": "creating_document",
"content_length": len(markdown_content),
},
)
# Process markdown directly through specialized function
result = await add_received_markdown_file_document(
session,
filename,
markdown_content,
search_space_id,
user_id
session, filename, markdown_content, search_space_id, user_id
)
if result:
await task_logger.log_task_success(
log_entry,
f"Successfully processed markdown file: {filename}",
{"document_id": result.id, "content_hash": result.content_hash, "file_type": "markdown"}
{
"document_id": result.id,
"content_hash": result.content_hash,
"file_type": "markdown",
},
)
else:
await task_logger.log_task_success(
log_entry,
f"Markdown file already exists (duplicate): {filename}",
{"duplicate_detected": True, "file_type": "markdown"}
{"duplicate_detected": True, "file_type": "markdown"},
)
# Check if the file is an audio file
elif filename.lower().endswith(('.mp3', '.mp4', '.mpeg', '.mpga', '.m4a', '.wav', '.webm')):
elif filename.lower().endswith(
(".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm")
):
await task_logger.log_task_progress(
log_entry,
f"Processing audio file for transcription: {filename}",
{"file_type": "audio", "processing_stage": "starting_transcription"}
{"file_type": "audio", "processing_stage": "starting_transcription"},
)
# Open the audio file for transcription
with open(file_path, "rb") as audio_file:
# Use LiteLLM for audio transcription
@ -205,65 +220,76 @@ async def process_file_in_background(
model=app_config.STT_SERVICE,
file=audio_file,
api_base=app_config.STT_SERVICE_API_BASE,
api_key=app_config.STT_SERVICE_API_KEY
api_key=app_config.STT_SERVICE_API_KEY,
)
else:
transcription_response = await atranscription(
model=app_config.STT_SERVICE,
api_key=app_config.STT_SERVICE_API_KEY,
file=audio_file
file=audio_file,
)
# Extract the transcribed text
transcribed_text = transcription_response.get("text", "")
# Add metadata about the transcription
transcribed_text = f"# Transcription of {filename}\n\n{transcribed_text}"
transcribed_text = (
f"# Transcription of {filename}\n\n{transcribed_text}"
)
await task_logger.log_task_progress(
log_entry,
f"Transcription completed, creating document: {filename}",
{"processing_stage": "transcription_complete", "transcript_length": len(transcribed_text)}
{
"processing_stage": "transcription_complete",
"transcript_length": len(transcribed_text),
},
)
# Clean up the temp file
try:
os.unlink(file_path)
except:
except Exception as e:
print("Error deleting temp file", e)
pass
# Process transcription as markdown document
result = await add_received_markdown_file_document(
session,
filename,
transcribed_text,
search_space_id,
user_id
session, filename, transcribed_text, search_space_id, user_id
)
if result:
await task_logger.log_task_success(
log_entry,
f"Successfully transcribed and processed audio file: {filename}",
{"document_id": result.id, "content_hash": result.content_hash, "file_type": "audio", "transcript_length": len(transcribed_text)}
{
"document_id": result.id,
"content_hash": result.content_hash,
"file_type": "audio",
"transcript_length": len(transcribed_text),
},
)
else:
await task_logger.log_task_success(
log_entry,
f"Audio file transcript already exists (duplicate): {filename}",
{"duplicate_detected": True, "file_type": "audio"}
{"duplicate_detected": True, "file_type": "audio"},
)
else:
if app_config.ETL_SERVICE == "UNSTRUCTURED":
await task_logger.log_task_progress(
log_entry,
f"Processing file with Unstructured ETL: {filename}",
{"file_type": "document", "etl_service": "UNSTRUCTURED", "processing_stage": "loading"}
{
"file_type": "document",
"etl_service": "UNSTRUCTURED",
"processing_stage": "loading",
},
)
from langchain_unstructured import UnstructuredLoader
# Process the file
loader = UnstructuredLoader(
file_path,
@ -280,212 +306,257 @@ async def process_file_in_background(
await task_logger.log_task_progress(
log_entry,
f"Unstructured ETL completed, creating document: {filename}",
{"processing_stage": "etl_complete", "elements_count": len(docs)}
{"processing_stage": "etl_complete", "elements_count": len(docs)},
)
# Clean up the temp file
import os
try:
os.unlink(file_path)
except:
except Exception as e:
print("Error deleting temp file", e)
pass
# Pass the documents to the existing background task
result = await add_received_file_document_using_unstructured(
session,
filename,
docs,
search_space_id,
user_id
session, filename, docs, search_space_id, user_id
)
if result:
await task_logger.log_task_success(
log_entry,
f"Successfully processed file with Unstructured: {filename}",
{"document_id": result.id, "content_hash": result.content_hash, "file_type": "document", "etl_service": "UNSTRUCTURED"}
{
"document_id": result.id,
"content_hash": result.content_hash,
"file_type": "document",
"etl_service": "UNSTRUCTURED",
},
)
else:
await task_logger.log_task_success(
log_entry,
f"Document already exists (duplicate): {filename}",
{"duplicate_detected": True, "file_type": "document", "etl_service": "UNSTRUCTURED"}
{
"duplicate_detected": True,
"file_type": "document",
"etl_service": "UNSTRUCTURED",
},
)
elif app_config.ETL_SERVICE == "LLAMACLOUD":
await task_logger.log_task_progress(
log_entry,
f"Processing file with LlamaCloud ETL: {filename}",
{"file_type": "document", "etl_service": "LLAMACLOUD", "processing_stage": "parsing"}
{
"file_type": "document",
"etl_service": "LLAMACLOUD",
"processing_stage": "parsing",
},
)
from llama_cloud_services import LlamaParse
from llama_cloud_services.parse.utils import ResultType
# Create LlamaParse parser instance
parser = LlamaParse(
api_key=app_config.LLAMA_CLOUD_API_KEY,
num_workers=1, # Use single worker for file processing
verbose=True,
language="en",
result_type=ResultType.MD
result_type=ResultType.MD,
)
# Parse the file asynchronously
result = await parser.aparse(file_path)
# Clean up the temp file
import os
try:
os.unlink(file_path)
except:
except Exception as e:
print("Error deleting temp file", e)
pass
# Get markdown documents from the result
markdown_documents = await result.aget_markdown_documents(split_by_page=False)
markdown_documents = await result.aget_markdown_documents(
split_by_page=False
)
await task_logger.log_task_progress(
log_entry,
f"LlamaCloud parsing completed, creating documents: {filename}",
{"processing_stage": "parsing_complete", "documents_count": len(markdown_documents)}
{
"processing_stage": "parsing_complete",
"documents_count": len(markdown_documents),
},
)
for doc in markdown_documents:
# Extract text content from the markdown documents
markdown_content = doc.text
# Process the documents using our LlamaCloud background task
doc_result = await add_received_file_document_using_llamacloud(
session,
filename,
llamacloud_markdown_document=markdown_content,
search_space_id=search_space_id,
user_id=user_id
user_id=user_id,
)
if doc_result:
await task_logger.log_task_success(
log_entry,
f"Successfully processed file with LlamaCloud: {filename}",
{"document_id": doc_result.id, "content_hash": doc_result.content_hash, "file_type": "document", "etl_service": "LLAMACLOUD"}
{
"document_id": doc_result.id,
"content_hash": doc_result.content_hash,
"file_type": "document",
"etl_service": "LLAMACLOUD",
},
)
else:
await task_logger.log_task_success(
log_entry,
f"Document already exists (duplicate): {filename}",
{"duplicate_detected": True, "file_type": "document", "etl_service": "LLAMACLOUD"}
{
"duplicate_detected": True,
"file_type": "document",
"etl_service": "LLAMACLOUD",
},
)
elif app_config.ETL_SERVICE == "DOCLING":
await task_logger.log_task_progress(
log_entry,
f"Processing file with Docling ETL: {filename}",
{"file_type": "document", "etl_service": "DOCLING", "processing_stage": "parsing"}
{
"file_type": "document",
"etl_service": "DOCLING",
"processing_stage": "parsing",
},
)
# Use Docling service for document processing
from app.services.docling_service import create_docling_service
# Create Docling service
docling_service = create_docling_service()
# Process the document
result = await docling_service.process_document(file_path, filename)
# Clean up the temp file
import os
try:
os.unlink(file_path)
except:
except Exception as e:
print("Error deleting temp file", e)
pass
await task_logger.log_task_progress(
log_entry,
f"Docling parsing completed, creating document: {filename}",
{"processing_stage": "parsing_complete", "content_length": len(result['content'])}
{
"processing_stage": "parsing_complete",
"content_length": len(result["content"]),
},
)
# Process the document using our Docling background task
doc_result = await add_received_file_document_using_docling(
session,
filename,
docling_markdown_document=result['content'],
docling_markdown_document=result["content"],
search_space_id=search_space_id,
user_id=user_id
user_id=user_id,
)
if doc_result:
await task_logger.log_task_success(
log_entry,
f"Successfully processed file with Docling: {filename}",
{"document_id": doc_result.id, "content_hash": doc_result.content_hash, "file_type": "document", "etl_service": "DOCLING"}
{
"document_id": doc_result.id,
"content_hash": doc_result.content_hash,
"file_type": "document",
"etl_service": "DOCLING",
},
)
else:
await task_logger.log_task_success(
log_entry,
f"Document already exists (duplicate): {filename}",
{"duplicate_detected": True, "file_type": "document", "etl_service": "DOCLING"}
{
"duplicate_detected": True,
"file_type": "document",
"etl_service": "DOCLING",
},
)
except Exception as e:
await task_logger.log_task_failure(
log_entry,
f"Failed to process file: {filename}",
str(e),
{"error_type": type(e).__name__, "filename": filename}
{"error_type": type(e).__name__, "filename": filename},
)
import logging
logging.error(f"Error processing file in background: {str(e)}")
logging.error(f"Error processing file in background: {e!s}")
raise # Re-raise so the wrapper can also handle it
@router.get("/documents/", response_model=List[DocumentRead])
@router.get("/documents/", response_model=list[DocumentRead])
async def read_documents(
skip: int = 0,
limit: int = 300,
search_space_id: int = None,
search_space_id: int | None = None,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
try:
query = select(Document).join(SearchSpace).filter(
SearchSpace.user_id == user.id)
query = (
select(Document).join(SearchSpace).filter(SearchSpace.user_id == user.id)
)
# Filter by search_space_id if provided
if search_space_id is not None:
query = query.filter(Document.search_space_id == search_space_id)
result = await session.execute(
query.offset(skip).limit(limit)
)
result = await session.execute(query.offset(skip).limit(limit))
db_documents = result.scalars().all()
# Convert database objects to API-friendly format
api_documents = []
for doc in db_documents:
api_documents.append(DocumentRead(
id=doc.id,
title=doc.title,
document_type=doc.document_type,
document_metadata=doc.document_metadata,
content=doc.content,
created_at=doc.created_at,
search_space_id=doc.search_space_id
))
api_documents.append(
DocumentRead(
id=doc.id,
title=doc.title,
document_type=doc.document_type,
document_metadata=doc.document_metadata,
content=doc.content,
created_at=doc.created_at,
search_space_id=doc.search_space_id,
)
)
return api_documents
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to fetch documents: {str(e)}"
)
status_code=500, detail=f"Failed to fetch documents: {e!s}"
) from e
@router.get("/documents/{document_id}", response_model=DocumentRead)
async def read_document(
document_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
try:
result = await session.execute(
@ -497,8 +568,7 @@ async def read_document(
if not document:
raise HTTPException(
status_code=404,
detail=f"Document with id {document_id} not found"
status_code=404, detail=f"Document with id {document_id} not found"
)
# Convert database object to API-friendly format
@ -509,13 +579,12 @@ async def read_document(
document_metadata=document.document_metadata,
content=document.content,
created_at=document.created_at,
search_space_id=document.search_space_id
search_space_id=document.search_space_id,
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to fetch document: {str(e)}"
)
status_code=500, detail=f"Failed to fetch document: {e!s}"
) from e
@router.put("/documents/{document_id}", response_model=DocumentRead)
@ -523,7 +592,7 @@ async def update_document(
document_id: int,
document_update: DocumentUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
try:
# Query the document directly instead of using read_document function
@ -536,8 +605,7 @@ async def update_document(
if not db_document:
raise HTTPException(
status_code=404,
detail=f"Document with id {document_id} not found"
status_code=404, detail=f"Document with id {document_id} not found"
)
update_data = document_update.model_dump(exclude_unset=True)
@ -554,23 +622,22 @@ async def update_document(
document_metadata=db_document.document_metadata,
content=db_document.content,
created_at=db_document.created_at,
search_space_id=db_document.search_space_id
search_space_id=db_document.search_space_id,
)
except HTTPException:
raise
except Exception as e:
await session.rollback()
raise HTTPException(
status_code=500,
detail=f"Failed to update document: {str(e)}"
)
status_code=500, detail=f"Failed to update document: {e!s}"
) from e
@router.delete("/documents/{document_id}", response_model=dict)
async def delete_document(
document_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
try:
# Query the document directly instead of using read_document function
@ -583,8 +650,7 @@ async def delete_document(
if not document:
raise HTTPException(
status_code=404,
detail=f"Document with id {document_id} not found"
status_code=404, detail=f"Document with id {document_id} not found"
)
await session.delete(document)
@ -595,15 +661,12 @@ async def delete_document(
except Exception as e:
await session.rollback()
raise HTTPException(
status_code=500,
detail=f"Failed to delete document: {str(e)}"
)
status_code=500, detail=f"Failed to delete document: {e!s}"
) from e
async def process_extension_document_with_new_session(
individual_document,
search_space_id: int,
user_id: str
individual_document, search_space_id: int, user_id: str
):
"""Create a new session and process extension document."""
from app.db import async_session_maker
@ -612,7 +675,7 @@ async def process_extension_document_with_new_session(
async with async_session_maker() as session:
# Initialize task logging service
task_logger = TaskLoggingService(session, search_space_id)
# Log task start
log_entry = await task_logger.log_task_start(
task_name="process_extension_document",
@ -622,40 +685,41 @@ async def process_extension_document_with_new_session(
"document_type": "EXTENSION",
"url": individual_document.metadata.VisitedWebPageURL,
"title": individual_document.metadata.VisitedWebPageTitle,
"user_id": user_id
}
"user_id": user_id,
},
)
try:
result = await add_extension_received_document(session, individual_document, search_space_id, user_id)
result = await add_extension_received_document(
session, individual_document, search_space_id, user_id
)
if result:
await task_logger.log_task_success(
log_entry,
f"Successfully processed extension document: {individual_document.metadata.VisitedWebPageTitle}",
{"document_id": result.id, "content_hash": result.content_hash}
{"document_id": result.id, "content_hash": result.content_hash},
)
else:
await task_logger.log_task_success(
log_entry,
f"Extension document already exists (duplicate): {individual_document.metadata.VisitedWebPageTitle}",
{"duplicate_detected": True}
{"duplicate_detected": True},
)
except Exception as e:
await task_logger.log_task_failure(
log_entry,
f"Failed to process extension document: {individual_document.metadata.VisitedWebPageTitle}",
str(e),
{"error_type": type(e).__name__}
{"error_type": type(e).__name__},
)
import logging
logging.error(f"Error processing extension document: {str(e)}")
logging.error(f"Error processing extension document: {e!s}")
async def process_crawled_url_with_new_session(
url: str,
search_space_id: int,
user_id: str
url: str, search_space_id: int, user_id: str
):
"""Create a new session and process crawled URL."""
from app.db import async_session_maker
@ -664,50 +728,50 @@ async def process_crawled_url_with_new_session(
async with async_session_maker() as session:
# Initialize task logging service
task_logger = TaskLoggingService(session, search_space_id)
# Log task start
log_entry = await task_logger.log_task_start(
task_name="process_crawled_url",
source="document_processor",
message=f"Starting URL crawling and processing for: {url}",
metadata={
"document_type": "CRAWLED_URL",
"url": url,
"user_id": user_id
}
metadata={"document_type": "CRAWLED_URL", "url": url, "user_id": user_id},
)
try:
result = await add_crawled_url_document(session, url, search_space_id, user_id)
result = await add_crawled_url_document(
session, url, search_space_id, user_id
)
if result:
await task_logger.log_task_success(
log_entry,
f"Successfully crawled and processed URL: {url}",
{"document_id": result.id, "title": result.title, "content_hash": result.content_hash}
{
"document_id": result.id,
"title": result.title,
"content_hash": result.content_hash,
},
)
else:
await task_logger.log_task_success(
log_entry,
f"URL document already exists (duplicate): {url}",
{"duplicate_detected": True}
{"duplicate_detected": True},
)
except Exception as e:
await task_logger.log_task_failure(
log_entry,
f"Failed to crawl URL: {url}",
str(e),
{"error_type": type(e).__name__}
{"error_type": type(e).__name__},
)
import logging
logging.error(f"Error processing crawled URL: {str(e)}")
logging.error(f"Error processing crawled URL: {e!s}")
async def process_file_in_background_with_new_session(
file_path: str,
filename: str,
search_space_id: int,
user_id: str
file_path: str, filename: str, search_space_id: int, user_id: str
):
"""Create a new session and process file."""
from app.db import async_session_maker
@ -716,7 +780,7 @@ async def process_file_in_background_with_new_session(
async with async_session_maker() as session:
# Initialize task logging service
task_logger = TaskLoggingService(session, search_space_id)
# Log task start
log_entry = await task_logger.log_task_start(
task_name="process_file_upload",
@ -726,29 +790,36 @@ async def process_file_in_background_with_new_session(
"document_type": "FILE",
"filename": filename,
"file_path": file_path,
"user_id": user_id
}
"user_id": user_id,
},
)
try:
await process_file_in_background(file_path, filename, search_space_id, user_id, session, task_logger, log_entry)
await process_file_in_background(
file_path,
filename,
search_space_id,
user_id,
session,
task_logger,
log_entry,
)
# Note: success/failure logging is handled within process_file_in_background
except Exception as e:
await task_logger.log_task_failure(
log_entry,
f"Failed to process file: {filename}",
str(e),
{"error_type": type(e).__name__}
{"error_type": type(e).__name__},
)
import logging
logging.error(f"Error processing file: {str(e)}")
logging.error(f"Error processing file: {e!s}")
async def process_youtube_video_with_new_session(
url: str,
search_space_id: int,
user_id: str
url: str, search_space_id: int, user_id: str
):
"""Create a new session and process YouTube video."""
from app.db import async_session_maker
@ -757,42 +828,43 @@ async def process_youtube_video_with_new_session(
async with async_session_maker() as session:
# Initialize task logging service
task_logger = TaskLoggingService(session, search_space_id)
# Log task start
log_entry = await task_logger.log_task_start(
task_name="process_youtube_video",
source="document_processor",
message=f"Starting YouTube video processing for: {url}",
metadata={
"document_type": "YOUTUBE_VIDEO",
"url": url,
"user_id": user_id
}
metadata={"document_type": "YOUTUBE_VIDEO", "url": url, "user_id": user_id},
)
try:
result = await add_youtube_video_document(session, url, search_space_id, user_id)
result = await add_youtube_video_document(
session, url, search_space_id, user_id
)
if result:
await task_logger.log_task_success(
log_entry,
f"Successfully processed YouTube video: {result.title}",
{"document_id": result.id, "video_id": result.document_metadata.get("video_id"), "content_hash": result.content_hash}
{
"document_id": result.id,
"video_id": result.document_metadata.get("video_id"),
"content_hash": result.content_hash,
},
)
else:
await task_logger.log_task_success(
log_entry,
f"YouTube video document already exists (duplicate): {url}",
{"duplicate_detected": True}
{"duplicate_detected": True},
)
except Exception as e:
await task_logger.log_task_failure(
log_entry,
f"Failed to process YouTube video: {url}",
str(e),
{"error_type": type(e).__name__}
{"error_type": type(e).__name__},
)
import logging
logging.error(f"Error processing YouTube video: {str(e)}")
logging.error(f"Error processing YouTube video: {e!s}")

View file

@ -1,35 +1,40 @@
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from typing import List, Optional
from pydantic import BaseModel
from app.db import get_async_session, User, LLMConfig
from app.schemas import LLMConfigCreate, LLMConfigUpdate, LLMConfigRead
from app.db import LLMConfig, User, get_async_session
from app.schemas import LLMConfigCreate, LLMConfigRead, LLMConfigUpdate
from app.users import current_active_user
from app.utils.check_ownership import check_ownership
router = APIRouter()
class LLMPreferencesUpdate(BaseModel):
"""Schema for updating user LLM preferences"""
long_context_llm_id: Optional[int] = None
fast_llm_id: Optional[int] = None
strategic_llm_id: Optional[int] = None
long_context_llm_id: int | None = None
fast_llm_id: int | None = None
strategic_llm_id: int | None = None
class LLMPreferencesRead(BaseModel):
"""Schema for reading user LLM preferences"""
long_context_llm_id: Optional[int] = None
fast_llm_id: Optional[int] = None
strategic_llm_id: Optional[int] = None
long_context_llm: Optional[LLMConfigRead] = None
fast_llm: Optional[LLMConfigRead] = None
strategic_llm: Optional[LLMConfigRead] = None
long_context_llm_id: int | None = None
fast_llm_id: int | None = None
strategic_llm_id: int | None = None
long_context_llm: LLMConfigRead | None = None
fast_llm: LLMConfigRead | None = None
strategic_llm: LLMConfigRead | None = None
@router.post("/llm-configs/", response_model=LLMConfigRead)
async def create_llm_config(
llm_config: LLMConfigCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
"""Create a new LLM configuration for the authenticated user"""
try:
@ -43,16 +48,16 @@ async def create_llm_config(
except Exception as e:
await session.rollback()
raise HTTPException(
status_code=500,
detail=f"Failed to create LLM configuration: {str(e)}"
)
status_code=500, detail=f"Failed to create LLM configuration: {e!s}"
) from e
@router.get("/llm-configs/", response_model=List[LLMConfigRead])
@router.get("/llm-configs/", response_model=list[LLMConfigRead])
async def read_llm_configs(
skip: int = 0,
limit: int = 200,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
"""Get all LLM configurations for the authenticated user"""
try:
@ -65,15 +70,15 @@ async def read_llm_configs(
return result.scalars().all()
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to fetch LLM configurations: {str(e)}"
)
status_code=500, detail=f"Failed to fetch LLM configurations: {e!s}"
) from e
@router.get("/llm-configs/{llm_config_id}", response_model=LLMConfigRead)
async def read_llm_config(
llm_config_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
"""Get a specific LLM configuration by ID"""
try:
@ -83,25 +88,25 @@ async def read_llm_config(
raise
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to fetch LLM configuration: {str(e)}"
)
status_code=500, detail=f"Failed to fetch LLM configuration: {e!s}"
) from e
@router.put("/llm-configs/{llm_config_id}", response_model=LLMConfigRead)
async def update_llm_config(
llm_config_id: int,
llm_config_update: LLMConfigUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
"""Update an existing LLM configuration"""
try:
db_llm_config = await check_ownership(session, LLMConfig, llm_config_id, user)
update_data = llm_config_update.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(db_llm_config, key, value)
await session.commit()
await session.refresh(db_llm_config)
return db_llm_config
@ -110,15 +115,15 @@ async def update_llm_config(
except Exception as e:
await session.rollback()
raise HTTPException(
status_code=500,
detail=f"Failed to update LLM configuration: {str(e)}"
)
status_code=500, detail=f"Failed to update LLM configuration: {e!s}"
) from e
@router.delete("/llm-configs/{llm_config_id}", response_model=dict)
async def delete_llm_config(
llm_config_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
"""Delete an LLM configuration"""
try:
@ -131,22 +136,23 @@ async def delete_llm_config(
except Exception as e:
await session.rollback()
raise HTTPException(
status_code=500,
detail=f"Failed to delete LLM configuration: {str(e)}"
)
status_code=500, detail=f"Failed to delete LLM configuration: {e!s}"
) from e
# User LLM Preferences endpoints
@router.get("/users/me/llm-preferences", response_model=LLMPreferencesRead)
async def get_user_llm_preferences(
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
"""Get the current user's LLM preferences"""
try:
# Refresh user to get latest relationships
await session.refresh(user)
result = {
"long_context_llm_id": user.long_context_llm_id,
"fast_llm_id": user.fast_llm_id,
@ -155,82 +161,79 @@ async def get_user_llm_preferences(
"fast_llm": None,
"strategic_llm": None,
}
# Fetch the actual LLM configs if they exist
if user.long_context_llm_id:
long_context_llm = await session.execute(
select(LLMConfig).filter(
LLMConfig.id == user.long_context_llm_id,
LLMConfig.user_id == user.id
LLMConfig.user_id == user.id,
)
)
llm_config = long_context_llm.scalars().first()
if llm_config:
result["long_context_llm"] = llm_config
if user.fast_llm_id:
fast_llm = await session.execute(
select(LLMConfig).filter(
LLMConfig.id == user.fast_llm_id,
LLMConfig.user_id == user.id
LLMConfig.id == user.fast_llm_id, LLMConfig.user_id == user.id
)
)
llm_config = fast_llm.scalars().first()
if llm_config:
result["fast_llm"] = llm_config
if user.strategic_llm_id:
strategic_llm = await session.execute(
select(LLMConfig).filter(
LLMConfig.id == user.strategic_llm_id,
LLMConfig.user_id == user.id
LLMConfig.id == user.strategic_llm_id, LLMConfig.user_id == user.id
)
)
llm_config = strategic_llm.scalars().first()
if llm_config:
result["strategic_llm"] = llm_config
return result
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to fetch LLM preferences: {str(e)}"
)
status_code=500, detail=f"Failed to fetch LLM preferences: {e!s}"
) from e
@router.put("/users/me/llm-preferences", response_model=LLMPreferencesRead)
async def update_user_llm_preferences(
preferences: LLMPreferencesUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
"""Update the current user's LLM preferences"""
try:
# Validate that all provided LLM config IDs belong to the user
update_data = preferences.model_dump(exclude_unset=True)
for key, llm_config_id in update_data.items():
for _key, llm_config_id in update_data.items():
if llm_config_id is not None:
# Verify ownership of the LLM config
result = await session.execute(
select(LLMConfig).filter(
LLMConfig.id == llm_config_id,
LLMConfig.user_id == user.id
LLMConfig.id == llm_config_id, LLMConfig.user_id == user.id
)
)
llm_config = result.scalars().first()
if not llm_config:
raise HTTPException(
status_code=404,
detail=f"LLM configuration {llm_config_id} not found or you don't have permission to access it"
detail=f"LLM configuration {llm_config_id} not found or you don't have permission to access it",
)
# Update user preferences
for key, value in update_data.items():
setattr(user, key, value)
await session.commit()
await session.refresh(user)
# Return updated preferences
return await get_user_llm_preferences(session, user)
except HTTPException:
@ -238,6 +241,5 @@ async def update_user_llm_preferences(
except Exception as e:
await session.rollback()
raise HTTPException(
status_code=500,
detail=f"Failed to update LLM preferences: {str(e)}"
)
status_code=500, detail=f"Failed to update LLM preferences: {e!s}"
) from e
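The move from Optional[int] to int | None in the schemas above is the PEP 604 union syntax preferred by ruff's pyupgrade rules; both spell the same type. A minimal check (requires Python 3.10+ for the runtime | syntax):

from typing import Optional

from pydantic import BaseModel

class PrefsOld(BaseModel):
    fast_llm_id: Optional[int] = None

class PrefsNew(BaseModel):
    fast_llm_id: int | None = None

# Both accept an int or allow the field to be omitted entirely
assert PrefsOld(fast_llm_id=3).fast_llm_id == PrefsNew(fast_llm_id=3).fast_llm_id == 3
assert PrefsNew().fast_llm_id is None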

View file

@ -1,22 +1,23 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import and_, desc
from typing import List, Optional
from datetime import datetime, timedelta
from app.db import get_async_session, User, SearchSpace, Log, LogLevel, LogStatus
from app.schemas import LogCreate, LogUpdate, LogRead, LogFilter
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import and_, desc
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import Log, LogLevel, LogStatus, SearchSpace, User, get_async_session
from app.schemas import LogCreate, LogRead, LogUpdate
from app.users import current_active_user
from app.utils.check_ownership import check_ownership
router = APIRouter()
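A hedged sketch of calling the log listing endpoint defined below with its optional filters. The host, router prefix, auth headers, and the "FAILED" status value are assumptions; the query parameter names match read_logs:

import httpx

def fetch_recent_failures(search_space_id: int) -> list[dict]:
    params = {
        "search_space_id": search_space_id,
        "status": "FAILED",  # assumed LogStatus value
        "limit": 50,
    }
    response = httpx.get("http://localhost:8000/logs/", params=params)  # assumed host, auth omitted
    response.raise_for_status()
    return response.json()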
@router.post("/logs/", response_model=LogRead)
async def create_log(
log: LogCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
"""Create a new log entry."""
try:
@ -33,22 +34,22 @@ async def create_log(
except Exception as e:
await session.rollback()
raise HTTPException(
status_code=500,
detail=f"Failed to create log: {str(e)}"
)
status_code=500, detail=f"Failed to create log: {e!s}"
) from e
@router.get("/logs/", response_model=List[LogRead])
@router.get("/logs/", response_model=list[LogRead])
async def read_logs(
skip: int = 0,
limit: int = 100,
search_space_id: Optional[int] = None,
level: Optional[LogLevel] = None,
status: Optional[LogStatus] = None,
source: Optional[str] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
search_space_id: int | None = None,
level: LogLevel | None = None,
status: LogStatus | None = None,
source: str | None = None,
start_date: datetime | None = None,
end_date: datetime | None = None,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
"""Get logs with optional filtering."""
try:
@ -62,23 +63,23 @@ async def read_logs(
# Apply filters
filters = []
if search_space_id is not None:
await check_ownership(session, SearchSpace, search_space_id, user)
filters.append(Log.search_space_id == search_space_id)
if level is not None:
filters.append(Log.level == level)
if status is not None:
filters.append(Log.status == status)
if source is not None:
filters.append(Log.source.ilike(f"%{source}%"))
if start_date is not None:
filters.append(Log.created_at >= start_date)
if end_date is not None:
filters.append(Log.created_at <= end_date)
@@ -93,15 +94,15 @@ async def read_logs(
raise
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to fetch logs: {str(e)}"
)
status_code=500, detail=f"Failed to fetch logs: {e!s}"
) from e
@router.get("/logs/{log_id}", response_model=LogRead)
async def read_log(
log_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
"""Get a specific log by ID."""
try:
@@ -112,25 +113,25 @@ async def read_log(
.filter(Log.id == log_id, SearchSpace.user_id == user.id)
)
log = result.scalars().first()
if not log:
raise HTTPException(status_code=404, detail="Log not found")
return log
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to fetch log: {str(e)}"
)
status_code=500, detail=f"Failed to fetch log: {e!s}"
) from e
@router.put("/logs/{log_id}", response_model=LogRead)
async def update_log(
log_id: int,
log_update: LogUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
"""Update a log entry."""
try:
@@ -141,7 +142,7 @@ async def update_log(
.filter(Log.id == log_id, SearchSpace.user_id == user.id)
)
db_log = result.scalars().first()
if not db_log:
raise HTTPException(status_code=404, detail="Log not found")
@@ -158,15 +159,15 @@ async def update_log(
except Exception as e:
await session.rollback()
raise HTTPException(
status_code=500,
detail=f"Failed to update log: {str(e)}"
)
status_code=500, detail=f"Failed to update log: {e!s}"
) from e
@router.delete("/logs/{log_id}")
async def delete_log(
log_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
"""Delete a log entry."""
try:
@@ -177,7 +178,7 @@ async def delete_log(
.filter(Log.id == log_id, SearchSpace.user_id == user.id)
)
db_log = result.scalars().first()
if not db_log:
raise HTTPException(status_code=404, detail="Log not found")
@@ -189,38 +190,35 @@ async def delete_log(
except Exception as e:
await session.rollback()
raise HTTPException(
status_code=500,
detail=f"Failed to delete log: {str(e)}"
)
status_code=500, detail=f"Failed to delete log: {e!s}"
) from e
@router.get("/logs/search-space/{search_space_id}/summary")
async def get_logs_summary(
search_space_id: int,
hours: int = 24,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
"""Get a summary of logs for a search space in the last X hours."""
try:
# Check ownership
await check_ownership(session, SearchSpace, search_space_id, user)
# Calculate time window
since = datetime.utcnow().replace(microsecond=0) - timedelta(hours=hours)
# Get logs from the time window
result = await session.execute(
select(Log)
.filter(
and_(
Log.search_space_id == search_space_id,
Log.created_at >= since
)
and_(Log.search_space_id == search_space_id, Log.created_at >= since)
)
.order_by(desc(Log.created_at))
)
logs = result.scalars().all()
# Create summary
summary = {
"total_logs": len(logs),
@@ -229,52 +227,69 @@ async def get_logs_summary(
"by_level": {},
"by_source": {},
"active_tasks": [],
"recent_failures": []
"recent_failures": [],
}
# Count by status and level
for log in logs:
# Status counts
status_str = log.status.value
summary["by_status"][status_str] = summary["by_status"].get(status_str, 0) + 1
summary["by_status"][status_str] = (
summary["by_status"].get(status_str, 0) + 1
)
# Level counts
level_str = log.level.value
summary["by_level"][level_str] = summary["by_level"].get(level_str, 0) + 1
# Source counts
if log.source:
summary["by_source"][log.source] = summary["by_source"].get(log.source, 0) + 1
summary["by_source"][log.source] = (
summary["by_source"].get(log.source, 0) + 1
)
# Active tasks (IN_PROGRESS)
if log.status == LogStatus.IN_PROGRESS:
task_name = log.log_metadata.get("task_name", "Unknown") if log.log_metadata else "Unknown"
summary["active_tasks"].append({
"id": log.id,
"task_name": task_name,
"message": log.message,
"started_at": log.created_at,
"source": log.source
})
task_name = (
log.log_metadata.get("task_name", "Unknown")
if log.log_metadata
else "Unknown"
)
summary["active_tasks"].append(
{
"id": log.id,
"task_name": task_name,
"message": log.message,
"started_at": log.created_at,
"source": log.source,
}
)
# Recent failures
if log.status == LogStatus.FAILED and len(summary["recent_failures"]) < 10:
task_name = log.log_metadata.get("task_name", "Unknown") if log.log_metadata else "Unknown"
summary["recent_failures"].append({
"id": log.id,
"task_name": task_name,
"message": log.message,
"failed_at": log.created_at,
"source": log.source,
"error_details": log.log_metadata.get("error_details") if log.log_metadata else None
})
task_name = (
log.log_metadata.get("task_name", "Unknown")
if log.log_metadata
else "Unknown"
)
summary["recent_failures"].append(
{
"id": log.id,
"task_name": task_name,
"message": log.message,
"failed_at": log.created_at,
"source": log.source,
"error_details": log.log_metadata.get("error_details")
if log.log_metadata
else None,
}
)
return summary
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to generate logs summary: {str(e)}"
)
status_code=500, detail=f"Failed to generate logs summary: {e!s}"
) from e
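This file also picks up the pyupgrade rewrites (UP006/UP007): typing.List and typing.Optional in the read_logs signature become the builtin list and the X | None union. A small sketch of the same modernisation, assuming Python 3.10+ so the union syntax is valid at runtime; summarise_filters is a made-up stand-in, not the endpoint itself:

# Minimal sketch (not code from this repo), assuming Python 3.10+.
from datetime import datetime


def summarise_filters(
    level: str | None = None,            # was Optional[str]
    start_date: datetime | None = None,  # was Optional[datetime]
    sources: list[str] | None = None,    # was Optional[List[str]]
) -> dict[str, object]:                  # was Dict[str, Any]
    # Stand-in body; the real endpoint builds SQLAlchemy filters instead.
    provided = {"level": level, "start_date": start_date, "sources": sources}
    return {name: value for name, value in provided.items() if value is not None}


if __name__ == "__main__":
    print(summarise_filters(level="ERROR", sources=["slack"]))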

View file

@@ -1,24 +1,31 @@
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from typing import List
from app.db import get_async_session, User, SearchSpace, Podcast, Chat
from app.schemas import PodcastCreate, PodcastUpdate, PodcastRead, PodcastGenerateRequest
from app.users import current_active_user
from app.utils.check_ownership import check_ownership
from app.tasks.podcast_tasks import generate_chat_podcast
from fastapi.responses import StreamingResponse
import os
from pathlib import Path
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from fastapi.responses import StreamingResponse
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import Chat, Podcast, SearchSpace, User, get_async_session
from app.schemas import (
PodcastCreate,
PodcastGenerateRequest,
PodcastRead,
PodcastUpdate,
)
from app.tasks.podcast_tasks import generate_chat_podcast
from app.users import current_active_user
from app.utils.check_ownership import check_ownership
router = APIRouter()
@router.post("/podcasts/", response_model=PodcastRead)
async def create_podcast(
podcast: PodcastCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
try:
await check_ownership(session, SearchSpace, podcast.search_space_id, user)
@@ -29,22 +36,30 @@ async def create_podcast(
return db_podcast
except HTTPException as he:
raise he
except IntegrityError as e:
except IntegrityError:
await session.rollback()
raise HTTPException(status_code=400, detail="Podcast creation failed due to constraint violation")
except SQLAlchemyError as e:
raise HTTPException(
status_code=400,
detail="Podcast creation failed due to constraint violation",
) from None
except SQLAlchemyError:
await session.rollback()
raise HTTPException(status_code=500, detail="Database error occurred while creating podcast")
except Exception as e:
raise HTTPException(
status_code=500, detail="Database error occurred while creating podcast"
) from None
except Exception:
await session.rollback()
raise HTTPException(status_code=500, detail="An unexpected error occurred")
raise HTTPException(
status_code=500, detail="An unexpected error occurred"
) from None
@router.get("/podcasts/", response_model=List[PodcastRead])
@router.get("/podcasts/", response_model=list[PodcastRead])
async def read_podcasts(
skip: int = 0,
limit: int = 100,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
if skip < 0 or limit < 1:
raise HTTPException(status_code=400, detail="Invalid pagination parameters")
@@ -58,13 +73,16 @@ async def read_podcasts(
)
return result.scalars().all()
except SQLAlchemyError:
raise HTTPException(status_code=500, detail="Database error occurred while fetching podcasts")
raise HTTPException(
status_code=500, detail="Database error occurred while fetching podcasts"
) from None
@router.get("/podcasts/{podcast_id}", response_model=PodcastRead)
async def read_podcast(
podcast_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
try:
result = await session.execute(
@@ -76,20 +94,23 @@ async def read_podcast(
if not podcast:
raise HTTPException(
status_code=404,
detail="Podcast not found or you don't have permission to access it"
detail="Podcast not found or you don't have permission to access it",
)
return podcast
except HTTPException as he:
raise he
except SQLAlchemyError:
raise HTTPException(status_code=500, detail="Database error occurred while fetching podcast")
raise HTTPException(
status_code=500, detail="Database error occurred while fetching podcast"
) from None
@router.put("/podcasts/{podcast_id}", response_model=PodcastRead)
async def update_podcast(
podcast_id: int,
podcast_update: PodcastUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
try:
db_podcast = await read_podcast(podcast_id, session, user)
@@ -103,16 +124,21 @@ async def update_podcast(
raise he
except IntegrityError:
await session.rollback()
raise HTTPException(status_code=400, detail="Update failed due to constraint violation")
raise HTTPException(
status_code=400, detail="Update failed due to constraint violation"
) from None
except SQLAlchemyError:
await session.rollback()
raise HTTPException(status_code=500, detail="Database error occurred while updating podcast")
raise HTTPException(
status_code=500, detail="Database error occurred while updating podcast"
) from None
@router.delete("/podcasts/{podcast_id}", response_model=dict)
async def delete_podcast(
podcast_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
try:
db_podcast = await read_podcast(podcast_id, session, user)
@@ -123,83 +149,100 @@ async def delete_podcast(
raise he
except SQLAlchemyError:
await session.rollback()
raise HTTPException(status_code=500, detail="Database error occurred while deleting podcast")
raise HTTPException(
status_code=500, detail="Database error occurred while deleting podcast"
) from None
async def generate_chat_podcast_with_new_session(
chat_id: int,
search_space_id: int,
podcast_title: str,
user_id: int
chat_id: int, search_space_id: int, podcast_title: str, user_id: int
):
"""Create a new session and process chat podcast generation."""
from app.db import async_session_maker
async with async_session_maker() as session:
try:
await generate_chat_podcast(session, chat_id, search_space_id, podcast_title, user_id)
await generate_chat_podcast(
session, chat_id, search_space_id, podcast_title, user_id
)
except Exception as e:
import logging
logging.error(f"Error generating podcast from chat: {str(e)}")
logging.error(f"Error generating podcast from chat: {e!s}")
@router.post("/podcasts/generate/")
async def generate_podcast(
request: PodcastGenerateRequest,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
fastapi_background_tasks: BackgroundTasks = BackgroundTasks()
fastapi_background_tasks: BackgroundTasks = BackgroundTasks(),
):
try:
# Check if the user owns the search space
await check_ownership(session, SearchSpace, request.search_space_id, user)
if request.type == "CHAT":
# Verify that all chat IDs belong to this user and search space
query = select(Chat).filter(
Chat.id.in_(request.ids),
Chat.search_space_id == request.search_space_id
).join(SearchSpace).filter(SearchSpace.user_id == user.id)
query = (
select(Chat)
.filter(
Chat.id.in_(request.ids),
Chat.search_space_id == request.search_space_id,
)
.join(SearchSpace)
.filter(SearchSpace.user_id == user.id)
)
result = await session.execute(query)
valid_chats = result.scalars().all()
valid_chat_ids = [chat.id for chat in valid_chats]
# If any requested ID is not in valid IDs, raise error immediately
if len(valid_chat_ids) != len(request.ids):
raise HTTPException(
status_code=403,
detail="One or more chat IDs do not belong to this user or search space"
status_code=403,
detail="One or more chat IDs do not belong to this user or search space",
)
# Only add a single task with the first chat ID
for chat_id in valid_chat_ids:
fastapi_background_tasks.add_task(
generate_chat_podcast_with_new_session,
chat_id,
generate_chat_podcast_with_new_session,
chat_id,
request.search_space_id,
request.podcast_title,
user.id
user.id,
)
return {
"message": "Podcast generation started",
}
except HTTPException as he:
raise he
except IntegrityError as e:
except IntegrityError:
await session.rollback()
raise HTTPException(status_code=400, detail="Podcast generation failed due to constraint violation")
except SQLAlchemyError as e:
raise HTTPException(
status_code=400,
detail="Podcast generation failed due to constraint violation",
) from None
except SQLAlchemyError:
await session.rollback()
raise HTTPException(status_code=500, detail="Database error occurred while generating podcast")
raise HTTPException(
status_code=500, detail="Database error occurred while generating podcast"
) from None
except Exception as e:
await session.rollback()
raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")
raise HTTPException(
status_code=500, detail=f"An unexpected error occurred: {e!s}"
) from e
@router.get("/podcasts/{podcast_id}/stream")
async def stream_podcast(
podcast_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
"""Stream a podcast audio file."""
try:
@@ -210,36 +253,38 @@ async def stream_podcast(
.filter(Podcast.id == podcast_id, SearchSpace.user_id == user.id)
)
podcast = result.scalars().first()
if not podcast:
raise HTTPException(
status_code=404,
detail="Podcast not found or you don't have permission to access it"
detail="Podcast not found or you don't have permission to access it",
)
# Get the file path
file_path = podcast.file_location
# Check if the file exists
if not os.path.isfile(file_path):
raise HTTPException(status_code=404, detail="Podcast audio file not found")
# Define a generator function to stream the file
def iterfile():
with open(file_path, mode="rb") as file_like:
yield from file_like
# Return a streaming response with appropriate headers
return StreamingResponse(
iterfile(),
media_type="audio/mpeg",
headers={
"Accept-Ranges": "bytes",
"Content-Disposition": f"inline; filename={Path(file_path).name}"
}
"Content-Disposition": f"inline; filename={Path(file_path).name}",
},
)
except HTTPException as he:
raise he
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error streaming podcast: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Error streaming podcast: {e!s}"
) from e
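The podcast routes translate IntegrityError and SQLAlchemyError into generic HTTPExceptions and now end those raises with "from None", which suppresses the implicit exception context so the low-level database error is not chained onto the client-facing one. A minimal sketch of the effect, with ServiceError as a hypothetical stand-in for HTTPException:

# Minimal sketch (not code from this repo): "raise ... from None" suppresses the
# implicit exception context; ServiceError is a hypothetical stand-in.
class ServiceError(Exception):
    pass


def fetch_row(rows: dict[str, str], key: str) -> str:
    try:
        return rows[key]
    except KeyError:
        # Callers see only ServiceError, not the underlying KeyError chained onto it.
        raise ServiceError(f"record {key!r} not found") from None


if __name__ == "__main__":
    try:
        fetch_row({"a": "1"}, "b")
    except ServiceError as err:
        print(err, "| cause:", err.__cause__, "| context suppressed:", err.__suppress_context__)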

View file

@@ -9,35 +9,58 @@ POST /search-source-connectors/{connector_id}/index - Index content from a conne
Note: Each user can have only one connector of each type (SERPER_API, TAVILY_API, SLACK_CONNECTOR, NOTION_CONNECTOR, GITHUB_CONNECTOR, LINEAR_CONNECTOR, DISCORD_CONNECTOR).
"""
from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks, Body
import logging
from datetime import datetime, timedelta
from typing import Any
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.exc import IntegrityError
from typing import List, Dict, Any
from app.db import get_async_session, User, SearchSourceConnector, SearchSourceConnectorType, SearchSpace, async_session_maker
from app.schemas import SearchSourceConnectorCreate, SearchSourceConnectorUpdate, SearchSourceConnectorRead, SearchSourceConnectorBase
from app.connectors.github_connector import GitHubConnector
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
SearchSpace,
User,
async_session_maker,
get_async_session,
)
from app.schemas import (
SearchSourceConnectorBase,
SearchSourceConnectorCreate,
SearchSourceConnectorRead,
SearchSourceConnectorUpdate,
)
from app.tasks.connectors_indexing_tasks import (
index_discord_messages,
index_github_repos,
index_linear_issues,
index_notion_pages,
index_slack_messages,
)
from app.users import current_active_user
from app.utils.check_ownership import check_ownership
from pydantic import BaseModel, Field, ValidationError
from app.tasks.connectors_indexing_tasks import index_slack_messages, index_notion_pages, index_github_repos, index_linear_issues, index_discord_messages
from app.connectors.github_connector import GitHubConnector
from datetime import datetime, timedelta
import logging
# Set up logging
logger = logging.getLogger(__name__)
router = APIRouter()
# Use Pydantic's BaseModel here
class GitHubPATRequest(BaseModel):
github_pat: str = Field(..., description="GitHub Personal Access Token")
# --- New Endpoint to list GitHub Repositories ---
@router.post("/github/repositories/", response_model=List[Dict[str, Any]])
@router.post("/github/repositories/", response_model=list[dict[str, Any]])
async def list_github_repositories(
pat_request: GitHubPATRequest,
user: User = Depends(current_active_user) # Ensure the user is logged in
user: User = Depends(current_active_user), # Ensure the user is logged in
):
"""
Fetches a list of repositories accessible by the provided GitHub PAT.
@@ -51,38 +74,40 @@ async def list_github_repositories(
return repositories
except ValueError as e:
# Handle invalid token error specifically
logger.error(f"GitHub PAT validation failed for user {user.id}: {str(e)}")
raise HTTPException(status_code=400, detail=f"Invalid GitHub PAT: {str(e)}")
logger.error(f"GitHub PAT validation failed for user {user.id}: {e!s}")
raise HTTPException(status_code=400, detail=f"Invalid GitHub PAT: {e!s}") from e
except Exception as e:
logger.error(f"Failed to fetch GitHub repositories for user {user.id}: {str(e)}")
raise HTTPException(status_code=500, detail="Failed to fetch GitHub repositories.")
logger.error(f"Failed to fetch GitHub repositories for user {user.id}: {e!s}")
raise HTTPException(
status_code=500, detail="Failed to fetch GitHub repositories."
) from e
@router.post("/search-source-connectors/", response_model=SearchSourceConnectorRead)
async def create_search_source_connector(
connector: SearchSourceConnectorCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
"""
Create a new search source connector.
Each user can have only one connector of each type (SERPER_API, TAVILY_API, SLACK_CONNECTOR, etc.).
The config must contain the appropriate keys for the connector type.
"""
try:
# Check if a connector with the same type already exists for this user
result = await session.execute(
select(SearchSourceConnector)
.filter(
select(SearchSourceConnector).filter(
SearchSourceConnector.user_id == user.id,
SearchSourceConnector.connector_type == connector.connector_type
SearchSourceConnector.connector_type == connector.connector_type,
)
)
existing_connector = result.scalars().first()
if existing_connector:
raise HTTPException(
status_code=409,
detail=f"A connector with type {connector.connector_type} already exists. Each user can have only one connector of each type."
detail=f"A connector with type {connector.connector_type} already exists. Each user can have only one connector of each type.",
)
db_connector = SearchSourceConnector(**connector.model_dump(), user_id=user.id)
session.add(db_connector)
@@ -91,56 +116,59 @@ async def create_search_source_connector(
return db_connector
except ValidationError as e:
await session.rollback()
raise HTTPException(
status_code=422,
detail=f"Validation error: {str(e)}"
)
raise HTTPException(status_code=422, detail=f"Validation error: {e!s}") from e
except IntegrityError as e:
await session.rollback()
raise HTTPException(
status_code=409,
detail=f"Integrity error: A connector with this type already exists. {str(e)}"
)
detail=f"Integrity error: A connector with this type already exists. {e!s}",
) from e
except HTTPException:
await session.rollback()
raise
except Exception as e:
logger.error(f"Failed to create search source connector: {str(e)}")
logger.error(f"Failed to create search source connector: {e!s}")
await session.rollback()
raise HTTPException(
status_code=500,
detail=f"Failed to create search source connector: {str(e)}"
)
detail=f"Failed to create search source connector: {e!s}",
) from e
@router.get("/search-source-connectors/", response_model=List[SearchSourceConnectorRead])
@router.get(
"/search-source-connectors/", response_model=list[SearchSourceConnectorRead]
)
async def read_search_source_connectors(
skip: int = 0,
limit: int = 100,
search_space_id: int = None,
search_space_id: int | None = None,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
"""List all search source connectors for the current user."""
try:
query = select(SearchSourceConnector).filter(SearchSourceConnector.user_id == user.id)
# No need to filter by search_space_id as connectors are user-owned, not search space specific
result = await session.execute(
query.offset(skip).limit(limit)
query = select(SearchSourceConnector).filter(
SearchSourceConnector.user_id == user.id
)
# No need to filter by search_space_id as connectors are user-owned, not search space specific
result = await session.execute(query.offset(skip).limit(limit))
return result.scalars().all()
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to fetch search source connectors: {str(e)}"
)
detail=f"Failed to fetch search source connectors: {e!s}",
) from e
@router.get("/search-source-connectors/{connector_id}", response_model=SearchSourceConnectorRead)
@router.get(
"/search-source-connectors/{connector_id}", response_model=SearchSourceConnectorRead
)
async def read_search_source_connector(
connector_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
"""Get a specific search source connector by ID."""
try:
@@ -149,31 +177,37 @@ async def read_search_source_connector(
raise
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to fetch search source connector: {str(e)}"
)
status_code=500, detail=f"Failed to fetch search source connector: {e!s}"
) from e
@router.put("/search-source-connectors/{connector_id}", response_model=SearchSourceConnectorRead)
@router.put(
"/search-source-connectors/{connector_id}", response_model=SearchSourceConnectorRead
)
async def update_search_source_connector(
connector_id: int,
connector_update: SearchSourceConnectorUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
"""
Update a search source connector.
Handles partial updates, including merging changes into the 'config' field.
"""
db_connector = await check_ownership(session, SearchSourceConnector, connector_id, user)
db_connector = await check_ownership(
session, SearchSourceConnector, connector_id, user
)
# Convert the sparse update data (only fields present in request) to a dict
update_data = connector_update.model_dump(exclude_unset=True)
# Special handling for 'config' field
if "config" in update_data:
incoming_config = update_data["config"] # Config data from the request
existing_config = db_connector.config if db_connector.config else {} # Current config from DB
incoming_config = update_data["config"] # Config data from the request
existing_config = (
db_connector.config if db_connector.config else {}
) # Current config from DB
# Merge incoming config into existing config
# This preserves existing keys (like GITHUB_PAT) if they are not in the incoming data
merged_config = existing_config.copy()
@@ -182,26 +216,29 @@ async def update_search_source_connector(
# -- Validation after merging --
# Validate the *merged* config based on the connector type
# We need the connector type - use the one from the update if provided, else the existing one
current_connector_type = connector_update.connector_type if connector_update.connector_type is not None else db_connector.connector_type
current_connector_type = (
connector_update.connector_type
if connector_update.connector_type is not None
else db_connector.connector_type
)
try:
# We can reuse the base validator by creating a temporary base model instance
# Note: This assumes 'name' and 'is_indexable' are not crucial for config validation itself
temp_data_for_validation = {
"name": db_connector.name, # Use existing name
"name": db_connector.name, # Use existing name
"connector_type": current_connector_type,
"is_indexable": db_connector.is_indexable, # Use existing value
"last_indexed_at": db_connector.last_indexed_at, # Not used by validator
"config": merged_config
"is_indexable": db_connector.is_indexable, # Use existing value
"last_indexed_at": db_connector.last_indexed_at, # Not used by validator
"config": merged_config,
}
SearchSourceConnectorBase.model_validate(temp_data_for_validation)
except ValidationError as e:
# Raise specific validation error for the merged config
raise HTTPException(
status_code=422,
detail=f"Validation error for merged config: {str(e)}"
)
status_code=422, detail=f"Validation error for merged config: {e!s}"
) from e
# If validation passes, update the main update_data dict with the merged config
update_data["config"] = merged_config
@@ -210,20 +247,19 @@ async def update_search_source_connector(
# Prevent changing connector_type if it causes a duplicate (check moved here)
if key == "connector_type" and value != db_connector.connector_type:
result = await session.execute(
select(SearchSourceConnector)
.filter(
select(SearchSourceConnector).filter(
SearchSourceConnector.user_id == user.id,
SearchSourceConnector.connector_type == value,
SearchSourceConnector.id != connector_id
SearchSourceConnector.id != connector_id,
)
)
existing_connector = result.scalars().first()
if existing_connector:
raise HTTPException(
status_code=409,
detail=f"A connector with type {value} already exists. Each user can have only one connector of each type."
detail=f"A connector with type {value} already exists. Each user can have only one connector of each type.",
)
setattr(db_connector, key, value)
try:
@@ -234,26 +270,31 @@ async def update_search_source_connector(
await session.rollback()
# This might occur if connector_type constraint is violated somehow after the check
raise HTTPException(
status_code=409,
detail=f"Database integrity error during update: {str(e)}"
)
status_code=409, detail=f"Database integrity error during update: {e!s}"
) from e
except Exception as e:
await session.rollback()
logger.error(f"Failed to update search source connector {connector_id}: {e}", exc_info=True)
logger.error(
f"Failed to update search source connector {connector_id}: {e}",
exc_info=True,
)
raise HTTPException(
status_code=500,
detail=f"Failed to update search source connector: {str(e)}"
)
detail=f"Failed to update search source connector: {e!s}",
) from e
@router.delete("/search-source-connectors/{connector_id}", response_model=dict)
async def delete_search_source_connector(
connector_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
"""Delete a search source connector."""
try:
db_connector = await check_ownership(session, SearchSourceConnector, connector_id, user)
db_connector = await check_ownership(
session, SearchSourceConnector, connector_id, user
)
await session.delete(db_connector)
await session.commit()
return {"message": "Search source connector deleted successfully"}
@@ -263,48 +304,61 @@ async def delete_search_source_connector(
await session.rollback()
raise HTTPException(
status_code=500,
detail=f"Failed to delete search source connector: {str(e)}"
)
detail=f"Failed to delete search source connector: {e!s}",
) from e
@router.post("/search-source-connectors/{connector_id}/index", response_model=Dict[str, Any])
@router.post(
"/search-source-connectors/{connector_id}/index", response_model=dict[str, Any]
)
async def index_connector_content(
connector_id: int,
search_space_id: int = Query(..., description="ID of the search space to store indexed content"),
start_date: str = Query(None, description="Start date for indexing (YYYY-MM-DD format). If not provided, uses last_indexed_at or defaults to 365 days ago"),
end_date: str = Query(None, description="End date for indexing (YYYY-MM-DD format). If not provided, uses today's date"),
search_space_id: int = Query(
..., description="ID of the search space to store indexed content"
),
start_date: str = Query(
None,
description="Start date for indexing (YYYY-MM-DD format). If not provided, uses last_indexed_at or defaults to 365 days ago",
),
end_date: str = Query(
None,
description="End date for indexing (YYYY-MM-DD format). If not provided, uses today's date",
),
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
background_tasks: BackgroundTasks = None
background_tasks: BackgroundTasks = None,
):
"""
Index content from a connector to a search space.
Currently supports:
- SLACK_CONNECTOR: Indexes messages from all accessible Slack channels
- NOTION_CONNECTOR: Indexes pages from all accessible Notion pages
- GITHUB_CONNECTOR: Indexes code and documentation from GitHub repositories
- LINEAR_CONNECTOR: Indexes issues and comments from Linear
- DISCORD_CONNECTOR: Indexes messages from all accessible Discord channels
Args:
connector_id: ID of the connector to use
search_space_id: ID of the search space to store indexed content
background_tasks: FastAPI background tasks
Returns:
Dictionary with indexing status
"""
try:
# Check if the connector belongs to the user
connector = await check_ownership(session, SearchSourceConnector, connector_id, user)
connector = await check_ownership(
session, SearchSourceConnector, connector_id, user
)
# Check if the search space belongs to the user
search_space = await check_ownership(session, SearchSpace, search_space_id, user)
await check_ownership(session, SearchSpace, search_space_id, user)
# Handle different connector types
response_message = ""
today_str = datetime.now().strftime("%Y-%m-%d")
# Determine the actual date range to use
if start_date is None:
# Use last_indexed_at or default to 365 days ago
@@ -316,37 +370,72 @@ async def index_connector_content(
else:
indexing_from = connector.last_indexed_at.strftime("%Y-%m-%d")
else:
indexing_from = (datetime.now() - timedelta(days=365)).strftime("%Y-%m-%d")
indexing_from = (datetime.now() - timedelta(days=365)).strftime(
"%Y-%m-%d"
)
else:
indexing_from = start_date
if end_date is None:
indexing_to = today_str
else:
indexing_to = end_date
indexing_to = end_date if end_date else today_str
if connector.connector_type == SearchSourceConnectorType.SLACK_CONNECTOR:
# Run indexing in background
logger.info(f"Triggering Slack indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}")
background_tasks.add_task(run_slack_indexing_with_new_session, connector_id, search_space_id, str(user.id), indexing_from, indexing_to)
logger.info(
f"Triggering Slack indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
background_tasks.add_task(
run_slack_indexing_with_new_session,
connector_id,
search_space_id,
str(user.id),
indexing_from,
indexing_to,
)
response_message = "Slack indexing started in the background."
elif connector.connector_type == SearchSourceConnectorType.NOTION_CONNECTOR:
# Run indexing in background
logger.info(f"Triggering Notion indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}")
background_tasks.add_task(run_notion_indexing_with_new_session, connector_id, search_space_id, str(user.id), indexing_from, indexing_to)
logger.info(
f"Triggering Notion indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
background_tasks.add_task(
run_notion_indexing_with_new_session,
connector_id,
search_space_id,
str(user.id),
indexing_from,
indexing_to,
)
response_message = "Notion indexing started in the background."
elif connector.connector_type == SearchSourceConnectorType.GITHUB_CONNECTOR:
# Run indexing in background
logger.info(f"Triggering GitHub indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}")
background_tasks.add_task(run_github_indexing_with_new_session, connector_id, search_space_id, str(user.id), indexing_from, indexing_to)
logger.info(
f"Triggering GitHub indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
background_tasks.add_task(
run_github_indexing_with_new_session,
connector_id,
search_space_id,
str(user.id),
indexing_from,
indexing_to,
)
response_message = "GitHub indexing started in the background."
elif connector.connector_type == SearchSourceConnectorType.LINEAR_CONNECTOR:
# Run indexing in background
logger.info(f"Triggering Linear indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}")
background_tasks.add_task(run_linear_indexing_with_new_session, connector_id, search_space_id, str(user.id), indexing_from, indexing_to)
logger.info(
f"Triggering Linear indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
background_tasks.add_task(
run_linear_indexing_with_new_session,
connector_id,
search_space_id,
str(user.id),
indexing_from,
indexing_to,
)
response_message = "Linear indexing started in the background."
elif connector.connector_type == SearchSourceConnectorType.DISCORD_CONNECTOR:
@@ -355,71 +444,83 @@ async def index_connector_content(
f"Triggering Discord indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
background_tasks.add_task(
run_discord_indexing_with_new_session, connector_id, search_space_id, str(user.id), indexing_from, indexing_to
run_discord_indexing_with_new_session,
connector_id,
search_space_id,
str(user.id),
indexing_from,
indexing_to,
)
response_message = "Discord indexing started in the background."
else:
raise HTTPException(
status_code=400,
detail=f"Indexing not supported for connector type: {connector.connector_type}"
detail=f"Indexing not supported for connector type: {connector.connector_type}",
)
return {
"message": response_message,
"connector_id": connector_id,
"message": response_message,
"connector_id": connector_id,
"search_space_id": search_space_id,
"indexing_from": indexing_from,
"indexing_to": indexing_to
"indexing_to": indexing_to,
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to initiate indexing for connector {connector_id}: {e}", exc_info=True)
raise HTTPException(
status_code=500,
detail=f"Failed to initiate indexing: {str(e)}"
logger.error(
f"Failed to initiate indexing for connector {connector_id}: {e}",
exc_info=True,
)
async def update_connector_last_indexed(
session: AsyncSession,
connector_id: int
):
raise HTTPException(
status_code=500, detail=f"Failed to initiate indexing: {e!s}"
) from e
async def update_connector_last_indexed(session: AsyncSession, connector_id: int):
"""
Update the last_indexed_at timestamp for a connector.
Args:
session: Database session
connector_id: ID of the connector to update
"""
try:
result = await session.execute(
select(SearchSourceConnector)
.filter(SearchSourceConnector.id == connector_id)
select(SearchSourceConnector).filter(
SearchSourceConnector.id == connector_id
)
)
connector = result.scalars().first()
if connector:
connector.last_indexed_at = datetime.now()
await session.commit()
logger.info(f"Updated last_indexed_at for connector {connector_id}")
except Exception as e:
logger.error(f"Failed to update last_indexed_at for connector {connector_id}: {str(e)}")
logger.error(
f"Failed to update last_indexed_at for connector {connector_id}: {e!s}"
)
await session.rollback()
async def run_slack_indexing_with_new_session(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str
end_date: str,
):
"""
Create a new session and run the Slack indexing task.
This prevents session leaks by creating a dedicated session for the background task.
"""
async with async_session_maker() as session:
await run_slack_indexing(session, connector_id, search_space_id, user_id, start_date, end_date)
await run_slack_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
async def run_slack_indexing(
session: AsyncSession,
@@ -427,11 +528,11 @@ async def run_slack_indexing(
search_space_id: int,
user_id: str,
start_date: str,
end_date: str
end_date: str,
):
"""
Background task to run Slack indexing.
Args:
session: Database session
connector_id: ID of the Slack connector
@ -449,31 +550,39 @@ async def run_slack_indexing(
user_id=user_id,
start_date=start_date,
end_date=end_date,
update_last_indexed=False # Don't update timestamp in the indexing function
update_last_indexed=False, # Don't update timestamp in the indexing function
)
# Only update last_indexed_at if indexing was successful (either new docs or updated docs)
if documents_processed > 0:
await update_connector_last_indexed(session, connector_id)
logger.info(f"Slack indexing completed successfully: {documents_processed} documents processed")
logger.info(
f"Slack indexing completed successfully: {documents_processed} documents processed"
)
else:
logger.error(f"Slack indexing failed or no documents processed: {error_or_warning}")
logger.error(
f"Slack indexing failed or no documents processed: {error_or_warning}"
)
except Exception as e:
logger.error(f"Error in background Slack indexing task: {str(e)}")
logger.error(f"Error in background Slack indexing task: {e!s}")
async def run_notion_indexing_with_new_session(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str
end_date: str,
):
"""
Create a new session and run the Notion indexing task.
This prevents session leaks by creating a dedicated session for the background task.
"""
async with async_session_maker() as session:
await run_notion_indexing(session, connector_id, search_space_id, user_id, start_date, end_date)
await run_notion_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
async def run_notion_indexing(
session: AsyncSession,
@@ -481,11 +590,11 @@ async def run_notion_indexing(
search_space_id: int,
user_id: str,
start_date: str,
end_date: str
end_date: str,
):
"""
Background task to run Notion indexing.
Args:
session: Database session
connector_id: ID of the Notion connector
@@ -503,17 +612,22 @@ async def run_notion_indexing(
user_id=user_id,
start_date=start_date,
end_date=end_date,
update_last_indexed=False # Don't update timestamp in the indexing function
update_last_indexed=False, # Don't update timestamp in the indexing function
)
# Only update last_indexed_at if indexing was successful (either new docs or updated docs)
if documents_processed > 0:
await update_connector_last_indexed(session, connector_id)
logger.info(f"Notion indexing completed successfully: {documents_processed} documents processed")
logger.info(
f"Notion indexing completed successfully: {documents_processed} documents processed"
)
else:
logger.error(f"Notion indexing failed or no documents processed: {error_or_warning}")
logger.error(
f"Notion indexing failed or no documents processed: {error_or_warning}"
)
except Exception as e:
logger.error(f"Error in background Notion indexing task: {str(e)}")
logger.error(f"Error in background Notion indexing task: {e!s}")
# Add new helper functions for GitHub indexing
async def run_github_indexing_with_new_session(
@@ -521,94 +635,135 @@ async def run_github_indexing_with_new_session(
search_space_id: int,
user_id: str,
start_date: str,
end_date: str
end_date: str,
):
"""Wrapper to run GitHub indexing with its own database session."""
logger.info(f"Background task started: Indexing GitHub connector {connector_id} into space {search_space_id} from {start_date} to {end_date}")
logger.info(
f"Background task started: Indexing GitHub connector {connector_id} into space {search_space_id} from {start_date} to {end_date}"
)
async with async_session_maker() as session:
await run_github_indexing(session, connector_id, search_space_id, user_id, start_date, end_date)
await run_github_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
logger.info(f"Background task finished: Indexing GitHub connector {connector_id}")
async def run_github_indexing(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str
end_date: str,
):
"""Runs the GitHub indexing task and updates the timestamp."""
try:
indexed_count, error_message = await index_github_repos(
session, connector_id, search_space_id, user_id, start_date, end_date, update_last_indexed=False
session,
connector_id,
search_space_id,
user_id,
start_date,
end_date,
update_last_indexed=False,
)
if error_message:
logger.error(f"GitHub indexing failed for connector {connector_id}: {error_message}")
logger.error(
f"GitHub indexing failed for connector {connector_id}: {error_message}"
)
# Optionally update status in DB to indicate failure
else:
logger.info(f"GitHub indexing successful for connector {connector_id}. Indexed {indexed_count} documents.")
logger.info(
f"GitHub indexing successful for connector {connector_id}. Indexed {indexed_count} documents."
)
# Update the last indexed timestamp only on success
await update_connector_last_indexed(session, connector_id)
await session.commit() # Commit timestamp update
await session.commit() # Commit timestamp update
except Exception as e:
await session.rollback()
logger.error(f"Critical error in run_github_indexing for connector {connector_id}: {e}", exc_info=True)
logger.error(
f"Critical error in run_github_indexing for connector {connector_id}: {e}",
exc_info=True,
)
# Optionally update status in DB to indicate failure
# Add new helper functions for Linear indexing
async def run_linear_indexing_with_new_session(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str
end_date: str,
):
"""Wrapper to run Linear indexing with its own database session."""
logger.info(f"Background task started: Indexing Linear connector {connector_id} into space {search_space_id} from {start_date} to {end_date}")
logger.info(
f"Background task started: Indexing Linear connector {connector_id} into space {search_space_id} from {start_date} to {end_date}"
)
async with async_session_maker() as session:
await run_linear_indexing(session, connector_id, search_space_id, user_id, start_date, end_date)
await run_linear_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
logger.info(f"Background task finished: Indexing Linear connector {connector_id}")
async def run_linear_indexing(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str
end_date: str,
):
"""Runs the Linear indexing task and updates the timestamp."""
try:
indexed_count, error_message = await index_linear_issues(
session, connector_id, search_space_id, user_id, start_date, end_date, update_last_indexed=False
session,
connector_id,
search_space_id,
user_id,
start_date,
end_date,
update_last_indexed=False,
)
if error_message:
logger.error(f"Linear indexing failed for connector {connector_id}: {error_message}")
logger.error(
f"Linear indexing failed for connector {connector_id}: {error_message}"
)
# Optionally update status in DB to indicate failure
else:
logger.info(f"Linear indexing successful for connector {connector_id}. Indexed {indexed_count} documents.")
logger.info(
f"Linear indexing successful for connector {connector_id}. Indexed {indexed_count} documents."
)
# Update the last indexed timestamp only on success
await update_connector_last_indexed(session, connector_id)
await session.commit() # Commit timestamp update
await session.commit() # Commit timestamp update
except Exception as e:
await session.rollback()
logger.error(f"Critical error in run_linear_indexing for connector {connector_id}: {e}", exc_info=True)
logger.error(
f"Critical error in run_linear_indexing for connector {connector_id}: {e}",
exc_info=True,
)
# Optionally update status in DB to indicate failure
# Add new helper functions for discord indexing
async def run_discord_indexing_with_new_session(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str
end_date: str,
):
"""
Create a new session and run the Discord indexing task.
This prevents session leaks by creating a dedicated session for the background task.
"""
async with async_session_maker() as session:
await run_discord_indexing(session, connector_id, search_space_id, user_id, start_date, end_date)
await run_discord_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
async def run_discord_indexing(
session: AsyncSession,
@@ -616,7 +771,7 @@ async def run_discord_indexing(
search_space_id: int,
user_id: str,
start_date: str,
end_date: str
end_date: str,
):
"""
Background task to run Discord indexing.
@@ -637,14 +792,18 @@ async def run_discord_indexing(
user_id=user_id,
start_date=start_date,
end_date=end_date,
update_last_indexed=False # Don't update timestamp in the indexing function
update_last_indexed=False, # Don't update timestamp in the indexing function
)
# Only update last_indexed_at if indexing was successful (either new docs or updated docs)
if documents_processed > 0:
await update_connector_last_indexed(session, connector_id)
logger.info(f"Discord indexing completed successfully: {documents_processed} documents processed")
logger.info(
f"Discord indexing completed successfully: {documents_processed} documents processed"
)
else:
logger.error(f"Discord indexing failed or no documents processed: {error_or_warning}")
logger.error(
f"Discord indexing failed or no documents processed: {error_or_warning}"
)
except Exception as e:
logger.error(f"Error in background Discord indexing task: {str(e)}")
logger.error(f"Error in background Discord indexing task: {e!s}")

View file

@@ -1,20 +1,20 @@
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from typing import List
from app.db import get_async_session, User, SearchSpace
from app.schemas import SearchSpaceCreate, SearchSpaceUpdate, SearchSpaceRead
from app.db import SearchSpace, User, get_async_session
from app.schemas import SearchSpaceCreate, SearchSpaceRead, SearchSpaceUpdate
from app.users import current_active_user
from app.utils.check_ownership import check_ownership
from fastapi import HTTPException
router = APIRouter()
@router.post("/searchspaces/", response_model=SearchSpaceRead)
async def create_search_space(
search_space: SearchSpaceCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
try:
db_search_space = SearchSpace(**search_space.model_dump(), user_id=user.id)
@@ -27,16 +27,16 @@ async def create_search_space(
except Exception as e:
await session.rollback()
raise HTTPException(
status_code=500,
detail=f"Failed to create search space: {str(e)}"
)
status_code=500, detail=f"Failed to create search space: {e!s}"
) from e
@router.get("/searchspaces/", response_model=List[SearchSpaceRead])
@router.get("/searchspaces/", response_model=list[SearchSpaceRead])
async def read_search_spaces(
skip: int = 0,
limit: int = 200,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
try:
result = await session.execute(
@@ -48,37 +48,41 @@ async def read_search_spaces(
return result.scalars().all()
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to fetch search spaces: {str(e)}"
)
status_code=500, detail=f"Failed to fetch search spaces: {e!s}"
) from e
@router.get("/searchspaces/{search_space_id}", response_model=SearchSpaceRead)
async def read_search_space(
search_space_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
try:
search_space = await check_ownership(session, SearchSpace, search_space_id, user)
search_space = await check_ownership(
session, SearchSpace, search_space_id, user
)
return search_space
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to fetch search space: {str(e)}"
)
status_code=500, detail=f"Failed to fetch search space: {e!s}"
) from e
@router.put("/searchspaces/{search_space_id}", response_model=SearchSpaceRead)
async def update_search_space(
search_space_id: int,
search_space_update: SearchSpaceUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
try:
db_search_space = await check_ownership(session, SearchSpace, search_space_id, user)
db_search_space = await check_ownership(
session, SearchSpace, search_space_id, user
)
update_data = search_space_update.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(db_search_space, key, value)
@@ -90,18 +94,20 @@ async def update_search_space(
except Exception as e:
await session.rollback()
raise HTTPException(
status_code=500,
detail=f"Failed to update search space: {str(e)}"
)
status_code=500, detail=f"Failed to update search space: {e!s}"
) from e
@router.delete("/searchspaces/{search_space_id}", response_model=dict)
async def delete_search_space(
search_space_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
try:
db_search_space = await check_ownership(session, SearchSpace, search_space_id, user)
db_search_space = await check_ownership(
session, SearchSpace, search_space_id, user
)
await session.delete(db_search_space)
await session.commit()
return {"message": "Search space deleted successfully"}
@@ -110,6 +116,5 @@ async def delete_search_space(
except Exception as e:
await session.rollback()
raise HTTPException(
status_code=500,
detail=f"Failed to delete search space: {str(e)}"
)
status_code=500, detail=f"Failed to delete search space: {e!s}"
) from e
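Every route in this file funnels authorisation through check_ownership before touching a SearchSpace. Purely as an illustration of that shape (the repo's app.utils.check_ownership may be implemented quite differently), an in-memory sketch with NotFound standing in for HTTPException(404):

# Hypothetical sketch, not the repo's check_ownership helper.
class NotFound(Exception):
    pass


def check_ownership(store: dict[int, dict], object_id: int, user_id: int) -> dict:
    obj = store.get(object_id)
    if obj is None or obj["user_id"] != user_id:
        # "Missing" and "not yours" are reported identically so the API does not
        # reveal which IDs exist for other users.
        raise NotFound(f"object {object_id} not found or not owned by user {user_id}")
    return obj


if __name__ == "__main__":
    spaces = {1: {"user_id": 7, "name": "research"}}
    print(check_ownership(spaces, 1, 7))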

View file

@@ -1,62 +1,78 @@
from .base import TimestampModel, IDModel
from .users import UserRead, UserCreate, UserUpdate
from .search_space import SearchSpaceBase, SearchSpaceCreate, SearchSpaceUpdate, SearchSpaceRead
from .base import IDModel, TimestampModel
from .chats import AISDKChatRequest, ChatBase, ChatCreate, ChatRead, ChatUpdate
from .chunks import ChunkBase, ChunkCreate, ChunkRead, ChunkUpdate
from .documents import (
ExtensionDocumentMetadata,
ExtensionDocumentContent,
DocumentBase,
DocumentRead,
DocumentsCreate,
DocumentUpdate,
DocumentRead,
ExtensionDocumentContent,
ExtensionDocumentMetadata,
)
from .chunks import ChunkBase, ChunkCreate, ChunkUpdate, ChunkRead
from .podcasts import PodcastBase, PodcastCreate, PodcastUpdate, PodcastRead, PodcastGenerateRequest
from .chats import ChatBase, ChatCreate, ChatUpdate, ChatRead, AISDKChatRequest
from .search_source_connector import SearchSourceConnectorBase, SearchSourceConnectorCreate, SearchSourceConnectorUpdate, SearchSourceConnectorRead
from .llm_config import LLMConfigBase, LLMConfigCreate, LLMConfigUpdate, LLMConfigRead
from .logs import LogBase, LogCreate, LogUpdate, LogRead, LogFilter
from .llm_config import LLMConfigBase, LLMConfigCreate, LLMConfigRead, LLMConfigUpdate
from .logs import LogBase, LogCreate, LogFilter, LogRead, LogUpdate
from .podcasts import (
PodcastBase,
PodcastCreate,
PodcastGenerateRequest,
PodcastRead,
PodcastUpdate,
)
from .search_source_connector import (
SearchSourceConnectorBase,
SearchSourceConnectorCreate,
SearchSourceConnectorRead,
SearchSourceConnectorUpdate,
)
from .search_space import (
SearchSpaceBase,
SearchSpaceCreate,
SearchSpaceRead,
SearchSpaceUpdate,
)
from .users import UserCreate, UserRead, UserUpdate
__all__ = [
"AISDKChatRequest",
"TimestampModel",
"IDModel",
"UserRead",
"UserCreate",
"UserUpdate",
"SearchSpaceBase",
"SearchSpaceCreate",
"SearchSpaceUpdate",
"SearchSpaceRead",
"ExtensionDocumentMetadata",
"ExtensionDocumentContent",
"DocumentBase",
"DocumentsCreate",
"DocumentUpdate",
"DocumentRead",
"ChunkBase",
"ChunkCreate",
"ChunkUpdate",
"ChunkRead",
"PodcastBase",
"PodcastCreate",
"PodcastUpdate",
"PodcastRead",
"PodcastGenerateRequest",
"ChatBase",
"ChatCreate",
"ChatUpdate",
"ChatRead",
"SearchSourceConnectorBase",
"SearchSourceConnectorCreate",
"SearchSourceConnectorUpdate",
"SearchSourceConnectorRead",
"ChatUpdate",
"ChunkBase",
"ChunkCreate",
"ChunkRead",
"ChunkUpdate",
"DocumentBase",
"DocumentRead",
"DocumentUpdate",
"DocumentsCreate",
"ExtensionDocumentContent",
"ExtensionDocumentMetadata",
"IDModel",
"LLMConfigBase",
"LLMConfigCreate",
"LLMConfigUpdate",
"LLMConfigRead",
"LLMConfigUpdate",
"LogBase",
"LogCreate",
"LogUpdate",
"LogRead",
"LogFilter",
]
"LogRead",
"LogUpdate",
"PodcastBase",
"PodcastCreate",
"PodcastGenerateRequest",
"PodcastRead",
"PodcastUpdate",
"SearchSourceConnectorBase",
"SearchSourceConnectorCreate",
"SearchSourceConnectorRead",
"SearchSourceConnectorUpdate",
"SearchSpaceBase",
"SearchSpaceCreate",
"SearchSpaceRead",
"SearchSpaceUpdate",
"TimestampModel",
"UserCreate",
"UserRead",
"UserUpdate",
]
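The schemas package init is pure reshuffling: ruff's import-sorting rule (I001) regroups and alphabetises the imports, and RUF022 sorts __all__. A toy module (hypothetical names, not from this repo) showing the target layout of plain stdlib imports first, from-imports alphabetised after them, and a sorted export list:

# Toy module illustrating the I001 / RUF022 layout; names are hypothetical.
import logging
from datetime import datetime
from typing import Any

logger = logging.getLogger(__name__)


def touch(payload: dict[str, Any]) -> dict[str, Any]:
    # Attach a timestamp; a stand-in body so the module does something runnable.
    return {**payload, "touched_at": datetime.now().isoformat()}


__all__ = [
    "logger",
    "touch",
]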

View file

@@ -1,10 +1,13 @@
from datetime import datetime
from pydantic import BaseModel, ConfigDict
class TimestampModel(BaseModel):
created_at: datetime
model_config = ConfigDict(from_attributes=True)
class IDModel(BaseModel):
id: int
model_config = ConfigDict(from_attributes=True)
model_config = ConfigDict(from_attributes=True)

View file

@@ -1,7 +1,8 @@
from typing import Any, Dict, List, Optional
from typing import Any
from pydantic import BaseModel, ConfigDict
from app.db import ChatType
from pydantic import BaseModel, ConfigDict
from .base import IDModel, TimestampModel
@@ -9,39 +10,43 @@ from .base import IDModel, TimestampModel
class ChatBase(BaseModel):
type: ChatType
title: str
initial_connectors: Optional[List[str]] = None
messages: List[Any]
initial_connectors: list[str] | None = None
messages: list[Any]
search_space_id: int
class ClientAttachment(BaseModel):
name: str
contentType: str
content_type: str
url: str
class ToolInvocation(BaseModel):
toolCallId: str
toolName: str
tool_call_id: str
tool_name: str
args: dict
result: dict
# class ClientMessage(BaseModel):
# role: str
# content: str
# experimental_attachments: Optional[List[ClientAttachment]] = None
# toolInvocations: Optional[List[ToolInvocation]] = None
class AISDKChatRequest(BaseModel):
messages: List[Any]
data: Optional[Dict[str, Any]] = None
messages: list[Any]
data: dict[str, Any] | None = None
class ChatCreate(ChatBase):
pass
class ChatUpdate(ChatBase):
pass
class ChatRead(ChatBase, IDModel, TimestampModel):
model_config = ConfigDict(from_attributes=True)
model_config = ConfigDict(from_attributes=True)
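Renaming contentType and toolCallId to snake_case satisfies the pep8-naming rules (N815), but if those camelCase keys are part of the JSON the AI SDK client actually sends, the rename changes what the model accepts unless aliases are added. A hedged sketch of how Pydantic v2 field aliases could keep both spellings; this is an option to consider, not something this commit does:

# Hypothetical sketch: snake_case attributes with camelCase wire aliases.
from pydantic import BaseModel, ConfigDict, Field


class ToolInvocation(BaseModel):
    model_config = ConfigDict(populate_by_name=True)

    tool_call_id: str = Field(alias="toolCallId")
    tool_name: str = Field(alias="toolName")


if __name__ == "__main__":
    inv = ToolInvocation.model_validate({"toolCallId": "abc", "toolName": "search"})
    print(inv.model_dump(by_alias=True))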

View file

@@ -1,15 +1,20 @@
from pydantic import BaseModel, ConfigDict
from .base import IDModel, TimestampModel
class ChunkBase(BaseModel):
content: str
document_id: int
class ChunkCreate(ChunkBase):
pass
class ChunkUpdate(ChunkBase):
pass
class ChunkRead(ChunkBase, IDModel, TimestampModel):
model_config = ConfigDict(from_attributes=True)
model_config = ConfigDict(from_attributes=True)

View file

@ -1,8 +1,10 @@
from typing import List
from pydantic import BaseModel, ConfigDict
from app.db import DocumentType
from datetime import datetime
from pydantic import BaseModel, ConfigDict
from app.db import DocumentType
class ExtensionDocumentMetadata(BaseModel):
BrowsingSessionId: str
VisitedWebPageURL: str
@ -11,21 +13,28 @@ class ExtensionDocumentMetadata(BaseModel):
VisitedWebPageReffererURL: str
VisitedWebPageVisitDurationInMilliseconds: str
class ExtensionDocumentContent(BaseModel):
metadata: ExtensionDocumentMetadata
pageContent: str
pageContent: str # noqa: N815
class DocumentBase(BaseModel):
document_type: DocumentType
content: List[ExtensionDocumentContent] | List[str] | str # Updated to allow string content
content: (
list[ExtensionDocumentContent] | list[str] | str
) # Updated to allow string content
search_space_id: int
class DocumentsCreate(DocumentBase):
pass
class DocumentUpdate(DocumentBase):
pass
class DocumentRead(BaseModel):
id: int
title: str
@ -34,6 +43,5 @@ class DocumentRead(BaseModel):
content: str # Changed to string to match frontend
created_at: datetime
search_space_id: int
model_config = ConfigDict(from_attributes=True)
model_config = ConfigDict(from_attributes=True)

View file

@ -1,34 +1,61 @@
from datetime import datetime
import uuid
from typing import Optional, Dict, Any
from datetime import datetime
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from .base import IDModel, TimestampModel
from app.db import LiteLLMProvider
from .base import IDModel, TimestampModel
class LLMConfigBase(BaseModel):
name: str = Field(..., max_length=100, description="User-friendly name for the LLM configuration")
name: str = Field(
..., max_length=100, description="User-friendly name for the LLM configuration"
)
provider: LiteLLMProvider = Field(..., description="LiteLLM provider type")
custom_provider: Optional[str] = Field(None, max_length=100, description="Custom provider name when provider is CUSTOM")
model_name: str = Field(..., max_length=100, description="Model name without provider prefix")
custom_provider: str | None = Field(
None, max_length=100, description="Custom provider name when provider is CUSTOM"
)
model_name: str = Field(
..., max_length=100, description="Model name without provider prefix"
)
api_key: str = Field(..., description="API key for the provider")
api_base: Optional[str] = Field(None, max_length=500, description="Optional API base URL")
litellm_params: Optional[Dict[str, Any]] = Field(default=None, description="Additional LiteLLM parameters")
api_base: str | None = Field(
None, max_length=500, description="Optional API base URL"
)
litellm_params: dict[str, Any] | None = Field(
default=None, description="Additional LiteLLM parameters"
)
class LLMConfigCreate(LLMConfigBase):
pass
class LLMConfigUpdate(BaseModel):
name: Optional[str] = Field(None, max_length=100, description="User-friendly name for the LLM configuration")
provider: Optional[LiteLLMProvider] = Field(None, description="LiteLLM provider type")
custom_provider: Optional[str] = Field(None, max_length=100, description="Custom provider name when provider is CUSTOM")
model_name: Optional[str] = Field(None, max_length=100, description="Model name without provider prefix")
api_key: Optional[str] = Field(None, description="API key for the provider")
api_base: Optional[str] = Field(None, max_length=500, description="Optional API base URL")
litellm_params: Optional[Dict[str, Any]] = Field(None, description="Additional LiteLLM parameters")
name: str | None = Field(
None, max_length=100, description="User-friendly name for the LLM configuration"
)
provider: LiteLLMProvider | None = Field(None, description="LiteLLM provider type")
custom_provider: str | None = Field(
None, max_length=100, description="Custom provider name when provider is CUSTOM"
)
model_name: str | None = Field(
None, max_length=100, description="Model name without provider prefix"
)
api_key: str | None = Field(None, description="API key for the provider")
api_base: str | None = Field(
None, max_length=500, description="Optional API base URL"
)
litellm_params: dict[str, Any] | None = Field(
None, description="Additional LiteLLM parameters"
)
class LLMConfigRead(LLMConfigBase, IDModel, TimestampModel):
id: int
created_at: datetime
user_id: uuid.UUID
model_config = ConfigDict(from_attributes=True)
model_config = ConfigDict(from_attributes=True)
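A quick hedged sketch (standalone model, not the app's LLMConfigBase) of how the Field(..., max_length=...) constraints above reject oversized values under pydantic v2:
from pydantic import BaseModel, Field, ValidationError


class ConfigSketch(BaseModel):  # hypothetical stand-in with the same constraint style
    name: str = Field(..., max_length=100, description="User-friendly name")
    api_base: str | None = Field(None, max_length=500, description="Optional API base URL")


try:
    ConfigSketch(name="x" * 200)  # exceeds max_length=100
except ValidationError as exc:
    print(exc.errors()[0]["type"])  # "string_too_long" on pydantic v2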

View file

@ -1,30 +1,37 @@
from datetime import datetime
from typing import Optional, Dict, Any
from typing import Any
from pydantic import BaseModel, ConfigDict
from .base import IDModel, TimestampModel
from app.db import LogLevel, LogStatus
from .base import IDModel, TimestampModel
class LogBase(BaseModel):
level: LogLevel
status: LogStatus
message: str
source: Optional[str] = None
log_metadata: Optional[Dict[str, Any]] = None
source: str | None = None
log_metadata: dict[str, Any] | None = None
class LogCreate(BaseModel):
level: LogLevel
status: LogStatus
message: str
source: Optional[str] = None
log_metadata: Optional[Dict[str, Any]] = None
source: str | None = None
log_metadata: dict[str, Any] | None = None
search_space_id: int
class LogUpdate(BaseModel):
level: Optional[LogLevel] = None
status: Optional[LogStatus] = None
message: Optional[str] = None
source: Optional[str] = None
log_metadata: Optional[Dict[str, Any]] = None
level: LogLevel | None = None
status: LogStatus | None = None
message: str | None = None
source: str | None = None
log_metadata: dict[str, Any] | None = None
class LogRead(LogBase, IDModel, TimestampModel):
id: int
@ -33,12 +40,13 @@ class LogRead(LogBase, IDModel, TimestampModel):
model_config = ConfigDict(from_attributes=True)
class LogFilter(BaseModel):
level: Optional[LogLevel] = None
status: Optional[LogStatus] = None
source: Optional[str] = None
search_space_id: Optional[int] = None
start_date: Optional[datetime] = None
end_date: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
class LogFilter(BaseModel):
level: LogLevel | None = None
status: LogStatus | None = None
source: str | None = None
search_space_id: int | None = None
start_date: datetime | None = None
end_date: datetime | None = None
model_config = ConfigDict(from_attributes=True)

View file

@ -1,24 +1,31 @@
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict
from typing import Any, List, Literal
from .base import IDModel, TimestampModel
class PodcastBase(BaseModel):
title: str
podcast_transcript: List[Any]
podcast_transcript: list[Any]
file_location: str = ""
search_space_id: int
class PodcastCreate(PodcastBase):
pass
class PodcastUpdate(PodcastBase):
pass
class PodcastRead(PodcastBase, IDModel, TimestampModel):
model_config = ConfigDict(from_attributes=True)
class PodcastGenerateRequest(BaseModel):
type: Literal["DOCUMENT", "CHAT"]
ids: List[int]
ids: list[int]
search_space_id: int
podcast_title: str = "SurfSense Podcast"
podcast_title: str = "SurfSense Podcast"
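For reference, a tiny standalone sketch of the Literal constraint PodcastGenerateRequest relies on; only "DOCUMENT" or "CHAT" pass validation:
from typing import Literal

from pydantic import BaseModel, ValidationError


class GenerateSketch(BaseModel):  # hypothetical stand-in for PodcastGenerateRequest
    type: Literal["DOCUMENT", "CHAT"]
    ids: list[int]


GenerateSketch(type="CHAT", ids=[1, 2])  # accepted
try:
    GenerateSketch(type="VIDEO", ids=[3])
except ValidationError:
    print("rejected: type must be DOCUMENT or CHAT")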

View file

@ -1,102 +1,124 @@
from datetime import datetime
import uuid
from typing import Dict, Any, Optional
from pydantic import BaseModel, field_validator, ConfigDict
from .base import IDModel, TimestampModel
from datetime import datetime
from typing import Any
from pydantic import BaseModel, ConfigDict, field_validator
from app.db import SearchSourceConnectorType
from .base import IDModel, TimestampModel
class SearchSourceConnectorBase(BaseModel):
name: str
connector_type: SearchSourceConnectorType
is_indexable: bool
last_indexed_at: Optional[datetime] = None
config: Dict[str, Any]
@field_validator('config')
last_indexed_at: datetime | None = None
config: dict[str, Any]
@field_validator("config")
@classmethod
def validate_config_for_connector_type(cls, config: Dict[str, Any], values: Dict[str, Any]) -> Dict[str, Any]:
connector_type = values.data.get('connector_type')
def validate_config_for_connector_type(
cls, config: dict[str, Any], values: dict[str, Any]
) -> dict[str, Any]:
connector_type = values.data.get("connector_type")
if connector_type == SearchSourceConnectorType.SERPER_API:
# For SERPER_API, only allow SERPER_API_KEY
allowed_keys = ["SERPER_API_KEY"]
if set(config.keys()) != set(allowed_keys):
raise ValueError(f"For SERPER_API connector type, config must only contain these keys: {allowed_keys}")
raise ValueError(
f"For SERPER_API connector type, config must only contain these keys: {allowed_keys}"
)
# Ensure the API key is not empty
if not config.get("SERPER_API_KEY"):
raise ValueError("SERPER_API_KEY cannot be empty")
elif connector_type == SearchSourceConnectorType.TAVILY_API:
# For TAVILY_API, only allow TAVILY_API_KEY
allowed_keys = ["TAVILY_API_KEY"]
if set(config.keys()) != set(allowed_keys):
raise ValueError(f"For TAVILY_API connector type, config must only contain these keys: {allowed_keys}")
raise ValueError(
f"For TAVILY_API connector type, config must only contain these keys: {allowed_keys}"
)
# Ensure the API key is not empty
if not config.get("TAVILY_API_KEY"):
raise ValueError("TAVILY_API_KEY cannot be empty")
elif connector_type == SearchSourceConnectorType.LINKUP_API:
# For LINKUP_API, only allow LINKUP_API_KEY
allowed_keys = ["LINKUP_API_KEY"]
if set(config.keys()) != set(allowed_keys):
raise ValueError(f"For LINKUP_API connector type, config must only contain these keys: {allowed_keys}")
raise ValueError(
f"For LINKUP_API connector type, config must only contain these keys: {allowed_keys}"
)
# Ensure the API key is not empty
if not config.get("LINKUP_API_KEY"):
raise ValueError("LINKUP_API_KEY cannot be empty")
elif connector_type == SearchSourceConnectorType.SLACK_CONNECTOR:
# For SLACK_CONNECTOR, only allow SLACK_BOT_TOKEN
allowed_keys = ["SLACK_BOT_TOKEN"]
if set(config.keys()) != set(allowed_keys):
raise ValueError(f"For SLACK_CONNECTOR connector type, config must only contain these keys: {allowed_keys}")
raise ValueError(
f"For SLACK_CONNECTOR connector type, config must only contain these keys: {allowed_keys}"
)
# Ensure the bot token is not empty
if not config.get("SLACK_BOT_TOKEN"):
raise ValueError("SLACK_BOT_TOKEN cannot be empty")
elif connector_type == SearchSourceConnectorType.NOTION_CONNECTOR:
# For NOTION_CONNECTOR, only allow NOTION_INTEGRATION_TOKEN
allowed_keys = ["NOTION_INTEGRATION_TOKEN"]
if set(config.keys()) != set(allowed_keys):
raise ValueError(f"For NOTION_CONNECTOR connector type, config must only contain these keys: {allowed_keys}")
raise ValueError(
f"For NOTION_CONNECTOR connector type, config must only contain these keys: {allowed_keys}"
)
# Ensure the integration token is not empty
if not config.get("NOTION_INTEGRATION_TOKEN"):
raise ValueError("NOTION_INTEGRATION_TOKEN cannot be empty")
elif connector_type == SearchSourceConnectorType.GITHUB_CONNECTOR:
# For GITHUB_CONNECTOR, only allow GITHUB_PAT and repo_full_names
allowed_keys = ["GITHUB_PAT", "repo_full_names"]
if set(config.keys()) != set(allowed_keys):
raise ValueError(f"For GITHUB_CONNECTOR connector type, config must only contain these keys: {allowed_keys}")
raise ValueError(
f"For GITHUB_CONNECTOR connector type, config must only contain these keys: {allowed_keys}"
)
# Ensure the token is not empty
if not config.get("GITHUB_PAT"):
raise ValueError("GITHUB_PAT cannot be empty")
# Ensure the repo_full_names is present and is a non-empty list
repo_full_names = config.get("repo_full_names")
if not isinstance(repo_full_names, list) or not repo_full_names:
raise ValueError("repo_full_names must be a non-empty list of strings")
elif connector_type == SearchSourceConnectorType.LINEAR_CONNECTOR:
# For LINEAR_CONNECTOR, only allow LINEAR_API_KEY
allowed_keys = ["LINEAR_API_KEY"]
if set(config.keys()) != set(allowed_keys):
raise ValueError(f"For LINEAR_CONNECTOR connector type, config must only contain these keys: {allowed_keys}")
raise ValueError(
f"For LINEAR_CONNECTOR connector type, config must only contain these keys: {allowed_keys}"
)
# Ensure the token is not empty
if not config.get("LINEAR_API_KEY"):
raise ValueError("LINEAR_API_KEY cannot be empty")
elif connector_type == SearchSourceConnectorType.DISCORD_CONNECTOR:
# For DISCORD_CONNECTOR, only allow DISCORD_BOT_TOKEN
allowed_keys = ["DISCORD_BOT_TOKEN"]
if set(config.keys()) != set(allowed_keys):
raise ValueError(f"For DISCORD_CONNECTOR connector type, config must only contain these keys: {allowed_keys}")
raise ValueError(
f"For DISCORD_CONNECTOR connector type, config must only contain these keys: {allowed_keys}"
)
# Ensure the bot token is not empty
if not config.get("DISCORD_BOT_TOKEN"):
@ -104,17 +126,20 @@ class SearchSourceConnectorBase(BaseModel):
return config
class SearchSourceConnectorCreate(SearchSourceConnectorBase):
pass
class SearchSourceConnectorUpdate(BaseModel):
name: Optional[str] = None
connector_type: Optional[SearchSourceConnectorType] = None
is_indexable: Optional[bool] = None
last_indexed_at: Optional[datetime] = None
config: Optional[Dict[str, Any]] = None
name: str | None = None
connector_type: SearchSourceConnectorType | None = None
is_indexable: bool | None = None
last_indexed_at: datetime | None = None
config: dict[str, Any] | None = None
class SearchSourceConnectorRead(SearchSourceConnectorBase, IDModel, TimestampModel):
user_id: uuid.UUID
model_config = ConfigDict(from_attributes=True)
model_config = ConfigDict(from_attributes=True)
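The validator above follows the standard pydantic v2 field_validator pattern: fields are validated in declaration order, so the already-validated connector_type is available through ValidationInfo.data when config is checked. A minimal standalone sketch (simplified types, not the app's schema):
from typing import Any

from pydantic import BaseModel, ValidationInfo, field_validator


class ConnectorSketch(BaseModel):
    connector_type: str  # declared first, so it is validated before config
    config: dict[str, Any]

    @field_validator("config")
    @classmethod
    def check_config(cls, config: dict[str, Any], info: ValidationInfo) -> dict[str, Any]:
        if info.data.get("connector_type") == "SERPER_API":
            if set(config) != {"SERPER_API_KEY"}:
                raise ValueError("SERPER_API connectors accept only SERPER_API_KEY")
            if not config["SERPER_API_KEY"]:
                raise ValueError("SERPER_API_KEY cannot be empty")
        return config


ConnectorSketch(connector_type="SERPER_API", config={"SERPER_API_KEY": "sk-..."})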

View file

@ -1,22 +1,27 @@
from datetime import datetime
import uuid
from typing import Optional
from datetime import datetime
from pydantic import BaseModel, ConfigDict
from .base import IDModel, TimestampModel
class SearchSpaceBase(BaseModel):
name: str
description: Optional[str] = None
description: str | None = None
class SearchSpaceCreate(SearchSpaceBase):
pass
class SearchSpaceUpdate(SearchSpaceBase):
pass
class SearchSpaceRead(SearchSpaceBase, IDModel, TimestampModel):
id: int
created_at: datetime
user_id: uuid.UUID
model_config = ConfigDict(from_attributes=True)
model_config = ConfigDict(from_attributes=True)

View file

@ -1,11 +1,15 @@
import uuid
from fastapi_users import schemas
class UserRead(schemas.BaseUser[uuid.UUID]):
pass
class UserCreate(schemas.BaseUserCreate):
pass
class UserUpdate(schemas.BaseUserUpdate):
pass
pass

View file

@ -1 +1 @@
# Services package
# Services package

File diff suppressed because it is too large

View file

@ -5,15 +5,16 @@ SSL-safe implementation with pre-downloaded models
"""
import logging
import ssl
import os
from typing import Dict, Any
import ssl
from typing import Any
logger = logging.getLogger(__name__)
class DoclingService:
"""Docling service for enhanced document processing with SSL fixes."""
def __init__(self):
"""Initialize Docling service with SSL, model fixes, and GPU acceleration."""
self.converter = None
@ -21,30 +22,32 @@ class DoclingService:
self._configure_ssl_environment()
self._check_wsl2_gpu_support()
self._initialize_docling()
def _configure_ssl_environment(self):
"""Configure SSL environment for secure model downloads."""
try:
# Set SSL context for downloads
ssl._create_default_https_context = ssl._create_unverified_context
# Set SSL environment variables if not already set
if not os.environ.get('SSL_CERT_FILE'):
if not os.environ.get("SSL_CERT_FILE"):
try:
import certifi
os.environ['SSL_CERT_FILE'] = certifi.where()
os.environ['REQUESTS_CA_BUNDLE'] = certifi.where()
os.environ["SSL_CERT_FILE"] = certifi.where()
os.environ["REQUESTS_CA_BUNDLE"] = certifi.where()
except ImportError:
pass
logger.info("🔐 SSL environment configured for model downloads")
except Exception as e:
logger.warning(f"⚠️ SSL configuration warning: {e}")
def _check_wsl2_gpu_support(self):
"""Check and configure GPU support for WSL2 environment."""
try:
import torch
if torch.cuda.is_available():
gpu_count = torch.cuda.device_count()
gpu_name = torch.cuda.get_device_name(0) if gpu_count > 0 else "Unknown"
@ -60,34 +63,34 @@ class DoclingService:
except Exception as e:
logger.warning(f"⚠️ GPU detection failed: {e}, falling back to CPU")
self.use_gpu = False
def _initialize_docling(self):
"""Initialize Docling with version-safe configuration."""
try:
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import PdfPipelineOptions
from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend
from docling.document_converter import DocumentConverter, PdfFormatOption
logger.info("🔧 Initializing Docling with version-safe configuration...")
# Create pipeline options with version-safe attribute checking
pipeline_options = PdfPipelineOptions()
# Disable OCR (user request)
if hasattr(pipeline_options, 'do_ocr'):
if hasattr(pipeline_options, "do_ocr"):
pipeline_options.do_ocr = False
logger.info("⚠️ OCR disabled by user request")
else:
logger.warning("⚠️ OCR attribute not available in this Docling version")
# Enable table structure if available
if hasattr(pipeline_options, 'do_table_structure'):
if hasattr(pipeline_options, "do_table_structure"):
pipeline_options.do_table_structure = True
logger.info("✅ Table structure detection enabled")
# Configure GPU acceleration for WSL2 if available
if hasattr(pipeline_options, 'accelerator_device'):
if hasattr(pipeline_options, "accelerator_device"):
if self.use_gpu:
try:
pipeline_options.accelerator_device = "cuda"
@ -99,164 +102,180 @@ class DoclingService:
pipeline_options.accelerator_device = "cpu"
logger.info("🖥️ Using CPU acceleration")
else:
logger.info(" Accelerator device attribute not available in this Docling version")
logger.info(
"⚠️ Accelerator device attribute not available in this Docling version"
)
# Create PDF format option with backend
pdf_format_option = PdfFormatOption(
pipeline_options=pipeline_options,
backend=PyPdfiumDocumentBackend
pipeline_options=pipeline_options, backend=PyPdfiumDocumentBackend
)
# Initialize DocumentConverter
self.converter = DocumentConverter(
format_options={
InputFormat.PDF: pdf_format_option
}
format_options={InputFormat.PDF: pdf_format_option}
)
acceleration_type = "GPU (WSL2)" if self.use_gpu else "CPU"
logger.info(f"✅ Docling initialized successfully with {acceleration_type} acceleration")
logger.info(
f"✅ Docling initialized successfully with {acceleration_type} acceleration"
)
except ImportError as e:
logger.error(f"❌ Docling not installed: {e}")
raise RuntimeError(f"Docling not available: {e}")
raise RuntimeError(f"Docling not available: {e}") from e
except Exception as e:
logger.error(f"❌ Docling initialization failed: {e}")
raise RuntimeError(f"Docling initialization failed: {e}")
raise RuntimeError(f"Docling initialization failed: {e}") from e
def _configure_easyocr_local_models(self):
"""Configure EasyOCR to use pre-downloaded local models."""
try:
import easyocr
import os
import easyocr
# Set SSL environment for EasyOCR downloads
os.environ['CURL_CA_BUNDLE'] = ''
os.environ['REQUESTS_CA_BUNDLE'] = ''
os.environ["CURL_CA_BUNDLE"] = ""
os.environ["REQUESTS_CA_BUNDLE"] = ""
# Try to use local models first, fallback to download if needed
try:
reader = easyocr.Reader(['en'],
download_enabled=False,
model_storage_directory="/root/.EasyOCR/model")
reader = easyocr.Reader(
["en"],
download_enabled=False,
model_storage_directory="/root/.EasyOCR/model",
)
logger.info("✅ EasyOCR configured for local models")
return reader
except:
except Exception:
# If local models fail, allow download with SSL bypass
logger.info("🔄 Local models failed, attempting download with SSL bypass...")
reader = easyocr.Reader(['en'],
download_enabled=True,
model_storage_directory="/root/.EasyOCR/model")
logger.info(
"🔄 Local models failed, attempting download with SSL bypass..."
)
reader = easyocr.Reader(
["en"],
download_enabled=True,
model_storage_directory="/root/.EasyOCR/model",
)
logger.info("✅ EasyOCR configured with downloaded models")
return reader
except Exception as e:
logger.warning(f"⚠️ EasyOCR configuration failed: {e}")
return None
async def process_document(self, file_path: str, filename: str = None) -> Dict[str, Any]:
async def process_document(
self, file_path: str, filename: str | None = None
) -> dict[str, Any]:
"""Process document with Docling using pre-downloaded models."""
if self.converter is None:
raise RuntimeError("Docling converter not initialized")
try:
logger.info(f"🔄 Processing {filename} with Docling (using local models)...")
logger.info(
f"🔄 Processing {filename} with Docling (using local models)..."
)
# Process document with local models
result = self.converter.convert(file_path)
# Extract content using version-safe methods
content = None
if hasattr(result, 'document') and result.document:
if hasattr(result, "document") and result.document:
# Try different export methods (version compatibility)
if hasattr(result.document, 'export_to_markdown'):
if hasattr(result.document, "export_to_markdown"):
content = result.document.export_to_markdown()
logger.info("📄 Used export_to_markdown method")
elif hasattr(result.document, 'to_markdown'):
elif hasattr(result.document, "to_markdown"):
content = result.document.to_markdown()
logger.info("📄 Used to_markdown method")
elif hasattr(result.document, 'text'):
elif hasattr(result.document, "text"):
content = result.document.text
logger.info("📄 Used text property")
elif hasattr(result.document, '__str__'):
elif hasattr(result.document, "__str__"):
content = str(result.document)
logger.info("📄 Used string conversion")
if content:
logger.info(f"✅ Docling SUCCESS - {filename}: {len(content)} chars (local models)")
logger.info(
f"✅ Docling SUCCESS - {filename}: {len(content)} chars (local models)"
)
return {
'content': content,
'full_text': content,
'service_used': 'docling',
'status': 'success',
'processing_notes': 'Processed with Docling using pre-downloaded models'
"content": content,
"full_text": content,
"service_used": "docling",
"status": "success",
"processing_notes": "Processed with Docling using pre-downloaded models",
}
else:
raise ValueError("No content could be extracted from document")
else:
raise ValueError("No document object returned by Docling")
except Exception as e:
logger.error(f"❌ Docling processing failed for {filename}: {e}")
# Log the full error for debugging
import traceback
logger.error(f"Full traceback: {traceback.format_exc()}")
raise RuntimeError(f"Docling processing failed: {e}")
raise RuntimeError(f"Docling processing failed: {e}") from e
async def process_large_document_summary(
self,
content: str,
llm,
document_title: str = "Document"
self, content: str, llm, document_title: str = "Document"
) -> str:
"""
Process large documents using chunked LLM summarization.
Args:
content: The full document content
llm: The language model to use for summarization
document_title: Title of the document for context
Returns:
Final summary of the document
"""
# Large document threshold (100K characters ≈ 25K tokens)
LARGE_DOCUMENT_THRESHOLD = 100_000
if len(content) <= LARGE_DOCUMENT_THRESHOLD:
large_document_threshold = 100_000
if len(content) <= large_document_threshold:
# For smaller documents, use direct processing
logger.info(f"📄 Document size: {len(content)} chars - using direct processing")
logger.info(
f"📄 Document size: {len(content)} chars - using direct processing"
)
from app.prompts import SUMMARY_PROMPT_TEMPLATE
summary_chain = SUMMARY_PROMPT_TEMPLATE | llm
result = await summary_chain.ainvoke({"document": content})
return result.content
logger.info(f"📚 Large document detected: {len(content)} chars - using chunked processing")
logger.info(
f"📚 Large document detected: {len(content)} chars - using chunked processing"
)
# Import chunker from config
from app.config import config
from langchain_core.prompts import PromptTemplate
# Create LLM-optimized chunks (8K tokens max for safety)
from chonkie import RecursiveChunker, OverlapRefinery
from chonkie import OverlapRefinery, RecursiveChunker
from langchain_core.prompts import PromptTemplate
llm_chunker = RecursiveChunker(
chunk_size=8000 # Conservative for most LLMs
)
# Apply overlap refinery for context preservation (10% overlap = 800 tokens)
overlap_refinery = OverlapRefinery(
context_size=0.1, # 10% overlap for context preservation
method="suffix" # Add next chunk context to current chunk
method="suffix", # Add next chunk context to current chunk
)
# First chunk the content, then apply overlap refinery
initial_chunks = llm_chunker.chunk(content)
chunks = overlap_refinery.refine(initial_chunks)
total_chunks = len(chunks)
logger.info(f"📄 Split into {total_chunks} chunks for LLM processing")
# Template for chunk processing
chunk_template = PromptTemplate(
input_variables=["chunk", "chunk_number", "total_chunks"],
@ -274,34 +293,38 @@ Chunk {chunk_number}/{total_chunks}:
<document_chunk>
{chunk}
</document_chunk>
</INSTRUCTIONS>"""
</INSTRUCTIONS>""",
)
# Process each chunk individually
chunk_summaries = []
for i, chunk in enumerate(chunks, 1):
try:
logger.info(f"🔄 Processing chunk {i}/{total_chunks} ({len(chunk.text)} chars)")
logger.info(
f"🔄 Processing chunk {i}/{total_chunks} ({len(chunk.text)} chars)"
)
chunk_chain = chunk_template | llm
chunk_result = await chunk_chain.ainvoke({
"chunk": chunk.text,
"chunk_number": i,
"total_chunks": total_chunks
})
chunk_result = await chunk_chain.ainvoke(
{
"chunk": chunk.text,
"chunk_number": i,
"total_chunks": total_chunks,
}
)
chunk_summary = chunk_result.content
chunk_summaries.append(f"=== Section {i} ===\n{chunk_summary}")
logger.info(f"✅ Completed chunk {i}/{total_chunks}")
except Exception as e:
logger.error(f"❌ Failed to process chunk {i}/{total_chunks}: {e}")
chunk_summaries.append(f"=== Section {i} ===\n[Processing failed]")
# Combine summaries into final document summary
logger.info(f"🔄 Combining {len(chunk_summaries)} chunk summaries")
try:
combine_template = PromptTemplate(
input_variables=["summaries", "document_title"],
@ -318,22 +341,23 @@ Ensure:
<section_summaries>
{summaries}
</section_summaries>
</INSTRUCTIONS>"""
</INSTRUCTIONS>""",
)
combined_summaries = "\n\n".join(chunk_summaries)
combine_chain = combine_template | llm
final_result = await combine_chain.ainvoke({
"summaries": combined_summaries,
"document_title": document_title
})
final_result = await combine_chain.ainvoke(
{"summaries": combined_summaries, "document_title": document_title}
)
final_summary = final_result.content
logger.info(f"✅ Large document processing complete: {len(final_summary)} chars summary")
logger.info(
f"✅ Large document processing complete: {len(final_summary)} chars summary"
)
return final_summary
except Exception as e:
logger.error(f"❌ Failed to combine summaries: {e}")
# Fallback: return concatenated chunk summaries
@ -341,6 +365,7 @@ Ensure:
logger.warning("⚠️ Using fallback combined summary")
return fallback_summary
def create_docling_service() -> DoclingService:
"""Create a Docling service instance."""
return DoclingService()
return DoclingService()
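A rough sketch of the chunk-then-refine flow process_large_document_summary uses for oversized documents, reusing only the chonkie calls that appear above (RecursiveChunker.chunk and OverlapRefinery.refine); the sizes mirror the values in the service:
from chonkie import OverlapRefinery, RecursiveChunker

LARGE_DOCUMENT_THRESHOLD = 100_000  # ~25K tokens, the same cut-off used by the service


def split_for_llm(content: str) -> list[str]:
    """Return chunk texts with ~10% suffix overlap for context preservation."""
    if len(content) <= LARGE_DOCUMENT_THRESHOLD:
        return [content]  # small documents are summarized in a single pass
    chunker = RecursiveChunker(chunk_size=8000)  # conservative limit for most LLMs
    refinery = OverlapRefinery(context_size=0.1, method="suffix")
    return [chunk.text for chunk in refinery.refine(chunker.chunk(content))]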

View file

@ -1,45 +1,43 @@
from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from langchain_community.chat_models import ChatLiteLLM
import logging
from app.db import User, LLMConfig
from langchain_community.chat_models import ChatLiteLLM
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import LLMConfig, User
logger = logging.getLogger(__name__)
class LLMRole:
LONG_CONTEXT = "long_context"
FAST = "fast"
STRATEGIC = "strategic"
async def get_user_llm_instance(
session: AsyncSession,
user_id: str,
role: str
) -> Optional[ChatLiteLLM]:
session: AsyncSession, user_id: str, role: str
) -> ChatLiteLLM | None:
"""
Get a ChatLiteLLM instance for a specific user and role.
Args:
session: Database session
user_id: User ID
role: LLM role ('long_context', 'fast', or 'strategic')
Returns:
ChatLiteLLM instance or None if not found
"""
try:
# Get user with their LLM preferences
result = await session.execute(
select(User).where(User.id == user_id)
)
result = await session.execute(select(User).where(User.id == user_id))
user = result.scalars().first()
if not user:
logger.error(f"User {user_id} not found")
return None
# Get the appropriate LLM config ID based on role
llm_config_id = None
if role == LLMRole.LONG_CONTEXT:
@ -51,24 +49,23 @@ async def get_user_llm_instance(
else:
logger.error(f"Invalid LLM role: {role}")
return None
if not llm_config_id:
logger.error(f"No {role} LLM configured for user {user_id}")
return None
# Get the LLM configuration
result = await session.execute(
select(LLMConfig).where(
LLMConfig.id == llm_config_id,
LLMConfig.user_id == user_id
LLMConfig.id == llm_config_id, LLMConfig.user_id == user_id
)
)
llm_config = result.scalars().first()
if not llm_config:
logger.error(f"LLM config {llm_config_id} not found for user {user_id}")
return None
# Build the model string for litellm
if llm_config.custom_provider:
model_string = f"{llm_config.custom_provider}/{llm_config.model_name}"
@ -76,7 +73,7 @@ async def get_user_llm_instance(
# Map provider enum to litellm format
provider_map = {
"OPENAI": "openai",
"ANTHROPIC": "anthropic",
"ANTHROPIC": "anthropic",
"GROQ": "groq",
"COHERE": "cohere",
"GOOGLE": "gemini",
@ -84,37 +81,48 @@ async def get_user_llm_instance(
"MISTRAL": "mistral",
# Add more mappings as needed
}
provider_prefix = provider_map.get(llm_config.provider.value, llm_config.provider.value.lower())
provider_prefix = provider_map.get(
llm_config.provider.value, llm_config.provider.value.lower()
)
model_string = f"{provider_prefix}/{llm_config.model_name}"
# Create ChatLiteLLM instance
litellm_kwargs = {
"model": model_string,
"api_key": llm_config.api_key,
}
# Add optional parameters
if llm_config.api_base:
litellm_kwargs["api_base"] = llm_config.api_base
# Add any additional litellm parameters
if llm_config.litellm_params:
litellm_kwargs.update(llm_config.litellm_params)
return ChatLiteLLM(**litellm_kwargs)
except Exception as e:
logger.error(f"Error getting LLM instance for user {user_id}, role {role}: {str(e)}")
logger.error(
f"Error getting LLM instance for user {user_id}, role {role}: {e!s}"
)
return None
async def get_user_long_context_llm(session: AsyncSession, user_id: str) -> Optional[ChatLiteLLM]:
async def get_user_long_context_llm(
session: AsyncSession, user_id: str
) -> ChatLiteLLM | None:
"""Get user's long context LLM instance."""
return await get_user_llm_instance(session, user_id, LLMRole.LONG_CONTEXT)
async def get_user_fast_llm(session: AsyncSession, user_id: str) -> Optional[ChatLiteLLM]:
async def get_user_fast_llm(session: AsyncSession, user_id: str) -> ChatLiteLLM | None:
"""Get user's fast LLM instance."""
return await get_user_llm_instance(session, user_id, LLMRole.FAST)
async def get_user_strategic_llm(session: AsyncSession, user_id: str) -> Optional[ChatLiteLLM]:
async def get_user_strategic_llm(
session: AsyncSession, user_id: str
) -> ChatLiteLLM | None:
"""Get user's strategic LLM instance."""
return await get_user_llm_instance(session, user_id, LLMRole.STRATEGIC)
return await get_user_llm_instance(session, user_id, LLMRole.STRATEGIC)
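A hedged sketch of how the helper above assembles its ChatLiteLLM instance: map the provider enum value to litellm's prefix, join it with the model name, and pass any extra litellm parameters through (placeholder values, not real credentials):
from langchain_community.chat_models import ChatLiteLLM

PROVIDER_MAP = {"OPENAI": "openai", "ANTHROPIC": "anthropic", "GOOGLE": "gemini"}


def build_llm(provider: str, model_name: str, api_key: str,
              api_base: str | None = None, **litellm_params) -> ChatLiteLLM:
    prefix = PROVIDER_MAP.get(provider, provider.lower())
    kwargs = {"model": f"{prefix}/{model_name}", "api_key": api_key}
    if api_base:
        kwargs["api_base"] = api_base
    kwargs.update(litellm_params)  # e.g. temperature, max_tokens
    return ChatLiteLLM(**kwargs)


# llm = build_llm("OPENAI", "gpt-4o-mini", "sk-placeholder")  # hypothetical model name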

View file

@ -1,9 +1,10 @@
import datetime
from langchain.schema import HumanMessage, SystemMessage, AIMessage
from app.config import config
from app.services.llm_service import get_user_strategic_llm
from typing import Any
from langchain.schema import AIMessage, HumanMessage, SystemMessage
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Any, List, Optional
from app.services.llm_service import get_user_strategic_llm
class QueryService:
@ -13,13 +14,13 @@ class QueryService:
@staticmethod
async def reformulate_query_with_chat_history(
user_query: str,
session: AsyncSession,
user_id: str,
chat_history_str: Optional[str] = None
user_query: str,
session: AsyncSession,
user_id: str,
chat_history_str: str | None = None,
) -> str:
"""
Reformulate the user query using the user's strategic LLM to make it more
Reformulate the user query using the user's strategic LLM to make it more
effective for information retrieval and research purposes.
Args:
@ -38,7 +39,9 @@ class QueryService:
# Get the user's strategic LLM instance
llm = await get_user_strategic_llm(session, user_id)
if not llm:
print(f"Warning: No strategic LLM configured for user {user_id}. Using original query.")
print(
f"Warning: No strategic LLM configured for user {user_id}. Using original query."
)
return user_query
# Create system message with instructions
@ -92,14 +95,13 @@ class QueryService:
print(f"Error reformulating query: {e}")
return user_query
@staticmethod
async def langchain_chat_history_to_str(chat_history: List[Any]) -> str:
async def langchain_chat_history_to_str(chat_history: list[Any]) -> str:
"""
Convert a list of chat history messages to a string.
"""
chat_history_str = "<chat_history>\n"
for chat_message in chat_history:
if isinstance(chat_message, HumanMessage):
chat_history_str += f"<user>{chat_message.content}</user>\n"
@ -107,6 +109,6 @@ class QueryService:
chat_history_str += f"<assistant>{chat_message.content}</assistant>\n"
elif isinstance(chat_message, SystemMessage):
chat_history_str += f"<system>{chat_message.content}</system>\n"
chat_history_str += "</chat_history>"
return chat_history_str
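A short usage sketch for the static helper above (the module path app.services.query_service is assumed, it is not shown in this diff); the output wraps each message in the tags the reformulation prompt consumes:
import asyncio

from langchain.schema import AIMessage, HumanMessage

from app.services.query_service import QueryService  # assumed module path

history = [
    HumanMessage(content="What is SurfSense?"),
    AIMessage(content="A personal research assistant."),
]
print(asyncio.run(QueryService.langchain_chat_history_to_str(history)))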

View file

@ -1,35 +1,39 @@
import logging
from typing import List, Dict, Any, Optional
from typing import Any, Optional
from rerankers import Document as RerankerDocument
class RerankerService:
"""
Service for reranking documents using a configured reranker
"""
def __init__(self, reranker_instance=None):
"""
Initialize the reranker service
Args:
reranker_instance: The reranker instance to use for reranking
"""
self.reranker_instance = reranker_instance
def rerank_documents(self, query_text: str, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
def rerank_documents(
self, query_text: str, documents: list[dict[str, Any]]
) -> list[dict[str, Any]]:
"""
Rerank documents using the configured reranker
Args:
query_text: The query text to use for reranking
documents: List of document dictionaries to rerank
Returns:
List[Dict[str, Any]]: Reranked documents
"""
if not self.reranker_instance or not documents:
return documents
try:
# Create Document objects for the rerankers library
reranker_docs = []
@ -38,58 +42,63 @@ class RerankerService:
content = doc.get("content", "")
score = doc.get("score", 0.0)
document_info = doc.get("document", {})
reranker_docs.append(
RerankerDocument(
text=content,
doc_id=chunk_id,
metadata={
'document_id': document_info.get("id", ""),
'document_title': document_info.get("title", ""),
'document_type': document_info.get("document_type", ""),
'rrf_score': score
}
"document_id": document_info.get("id", ""),
"document_title": document_info.get("title", ""),
"document_type": document_info.get("document_type", ""),
"rrf_score": score,
},
)
)
# Rerank using the configured reranker
reranking_results = self.reranker_instance.rank(
query=query_text,
docs=reranker_docs
query=query_text, docs=reranker_docs
)
# Process the results from the reranker
# Convert to serializable dictionaries
serialized_results = []
for result in reranking_results.results:
# Find the original document by id
original_doc = next((doc for doc in documents if doc.get("chunk_id") == result.document.doc_id), None)
original_doc = next(
(
doc
for doc in documents
if doc.get("chunk_id") == result.document.doc_id
),
None,
)
if original_doc:
# Create a new document with the reranked score
reranked_doc = original_doc.copy()
reranked_doc["score"] = float(result.score)
reranked_doc["rank"] = result.rank
serialized_results.append(reranked_doc)
return serialized_results
except Exception as e:
# Log the error
logging.error(f"Error during reranking: {str(e)}")
logging.error(f"Error during reranking: {e!s}")
# Fall back to original documents without reranking
return documents
@staticmethod
def get_reranker_instance() -> Optional['RerankerService']:
def get_reranker_instance() -> Optional["RerankerService"]:
"""
Get a reranker service instance from the global configuration.
Returns:
Optional[RerankerService]: A reranker service instance if configured, None otherwise
"""
from app.config import config
if hasattr(config, 'reranker_instance') and config.reranker_instance:
if hasattr(config, "reranker_instance") and config.reranker_instance:
return RerankerService(config.reranker_instance)
return None
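For illustration, the input shape rerank_documents expects (taken from the keys the method reads above); with no reranker configured the documents are returned unchanged (assumed module path):
from app.services.reranker_service import RerankerService  # assumed module path

docs = [
    {
        "chunk_id": 1,
        "content": "SurfSense indexes Slack, Notion and GitHub sources.",
        "score": 0.42,
        "document": {"id": 10, "title": "Connectors", "document_type": "CRAWLED_URL"},
    }
]
service = RerankerService(reranker_instance=None)  # falls back to pass-through
assert service.rerank_documents("connector types", docs) == docs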

View file

@ -1,27 +1,15 @@
import json
from typing import Any, Dict, List
from typing import Any
class StreamingService:
def __init__(self):
self.terminal_idx = 1
self.message_annotations = [
{
"type": "TERMINAL_INFO",
"content": []
},
{
"type": "SOURCES",
"content": []
},
{
"type": "ANSWER",
"content": []
},
{
"type": "FURTHER_QUESTIONS",
"content": []
}
{"type": "TERMINAL_INFO", "content": []},
{"type": "SOURCES", "content": []},
{"type": "ANSWER", "content": []},
{"type": "FURTHER_QUESTIONS", "content": []},
]
# DEPRECATED: This sends the full annotation array every time (inefficient)
@ -35,7 +23,7 @@ class StreamingService:
Returns:
str: The formatted annotations string
"""
return f'8:{json.dumps(self.message_annotations)}\n'
return f"8:{json.dumps(self.message_annotations)}\n"
def format_terminal_info_delta(self, text: str, message_type: str = "info") -> str:
"""
@ -58,7 +46,7 @@ class StreamingService:
annotation = {"type": "TERMINAL_INFO", "content": [message]}
return f"8:[{json.dumps(annotation)}]\n"
def format_sources_delta(self, sources: List[Dict[str, Any]]) -> str:
def format_sources_delta(self, sources: list[dict[str, Any]]) -> str:
"""
Format sources as a delta annotation
@ -95,7 +83,7 @@ class StreamingService:
annotation = {"type": "ANSWER", "content": [answer_chunk]}
return f"8:[{json.dumps(annotation)}]\n"
def format_answer_annotation(self, answer_lines: List[str]) -> str:
def format_answer_annotation(self, answer_lines: list[str]) -> str:
"""
Format the complete answer as a replacement annotation
@ -113,7 +101,7 @@ class StreamingService:
return f"8:[{json.dumps(annotation)}]\n"
def format_further_questions_delta(
self, further_questions: List[Dict[str, Any]]
self, further_questions: list[dict[str, Any]]
) -> str:
"""
Format further questions as a delta annotation
@ -155,14 +143,16 @@ class StreamingService:
"""
return f"3:{json.dumps(error_message)}\n"
def format_completion(self, prompt_tokens: int = 156, completion_tokens: int = 204) -> str:
def format_completion(
self, prompt_tokens: int = 156, completion_tokens: int = 204
) -> str:
"""
Format a completion message
Args:
prompt_tokens: Number of prompt tokens
completion_tokens: Number of completion tokens
Returns:
str: The formatted completion string
"""
@ -172,7 +162,7 @@ class StreamingService:
"usage": {
"promptTokens": prompt_tokens,
"completionTokens": completion_tokens,
"totalTokens": total_tokens
}
"totalTokens": total_tokens,
},
}
return f'd:{json.dumps(completion_data)}\n'
return f"d:{json.dumps(completion_data)}\n"

View file

@ -1,111 +1,116 @@
from typing import Optional, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Log, LogLevel, LogStatus
import logging
import json
from datetime import datetime
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Log, LogLevel, LogStatus
logger = logging.getLogger(__name__)
class TaskLoggingService:
"""Service for logging background tasks using the database Log model"""
def __init__(self, session: AsyncSession, search_space_id: int):
self.session = session
self.search_space_id = search_space_id
async def log_task_start(
self,
task_name: str,
source: str,
message: str,
metadata: Optional[Dict[str, Any]] = None
metadata: dict[str, Any] | None = None,
) -> Log:
"""
Log the start of a task with IN_PROGRESS status
Args:
task_name: Name/identifier of the task
source: Source service/component (e.g., 'document_processor', 'slack_indexer')
message: Human-readable message about the task
metadata: Additional context data
Returns:
Log: The created log entry
"""
log_metadata = metadata or {}
log_metadata.update({
"task_name": task_name,
"started_at": datetime.utcnow().isoformat()
})
log_metadata.update(
{"task_name": task_name, "started_at": datetime.utcnow().isoformat()}
)
log_entry = Log(
level=LogLevel.INFO,
status=LogStatus.IN_PROGRESS,
message=message,
source=source,
log_metadata=log_metadata,
search_space_id=self.search_space_id
search_space_id=self.search_space_id,
)
self.session.add(log_entry)
await self.session.commit()
await self.session.refresh(log_entry)
logger.info(f"Started task {task_name}: {message}")
return log_entry
async def log_task_success(
self,
log_entry: Log,
message: str,
additional_metadata: Optional[Dict[str, Any]] = None
additional_metadata: dict[str, Any] | None = None,
) -> Log:
"""
Update a log entry to SUCCESS status
Args:
log_entry: The original log entry to update
message: Success message
additional_metadata: Additional metadata to merge
Returns:
Log: The updated log entry
"""
# Update the existing log entry
log_entry.status = LogStatus.SUCCESS
log_entry.message = message
# Merge additional metadata
if additional_metadata:
if log_entry.log_metadata is None:
log_entry.log_metadata = {}
log_entry.log_metadata.update(additional_metadata)
log_entry.log_metadata["completed_at"] = datetime.utcnow().isoformat()
await self.session.commit()
await self.session.refresh(log_entry)
task_name = log_entry.log_metadata.get("task_name", "unknown") if log_entry.log_metadata else "unknown"
task_name = (
log_entry.log_metadata.get("task_name", "unknown")
if log_entry.log_metadata
else "unknown"
)
logger.info(f"Completed task {task_name}: {message}")
return log_entry
async def log_task_failure(
self,
log_entry: Log,
error_message: str,
error_details: Optional[str] = None,
additional_metadata: Optional[Dict[str, Any]] = None
error_details: str | None = None,
additional_metadata: dict[str, Any] | None = None,
) -> Log:
"""
Update a log entry to FAILED status
Args:
log_entry: The original log entry to update
error_message: Error message
error_details: Detailed error information
additional_metadata: Additional metadata to merge
Returns:
Log: The updated log entry
"""
@ -113,77 +118,86 @@ class TaskLoggingService:
log_entry.status = LogStatus.FAILED
log_entry.level = LogLevel.ERROR
log_entry.message = error_message
# Merge additional metadata
if log_entry.log_metadata is None:
log_entry.log_metadata = {}
log_entry.log_metadata.update({
"failed_at": datetime.utcnow().isoformat(),
"error_details": error_details
})
log_entry.log_metadata.update(
{"failed_at": datetime.utcnow().isoformat(), "error_details": error_details}
)
if additional_metadata:
log_entry.log_metadata.update(additional_metadata)
await self.session.commit()
await self.session.refresh(log_entry)
task_name = log_entry.log_metadata.get("task_name", "unknown") if log_entry.log_metadata else "unknown"
task_name = (
log_entry.log_metadata.get("task_name", "unknown")
if log_entry.log_metadata
else "unknown"
)
logger.error(f"Failed task {task_name}: {error_message}")
if error_details:
logger.error(f"Error details: {error_details}")
return log_entry
async def log_task_progress(
self,
log_entry: Log,
progress_message: str,
progress_metadata: Optional[Dict[str, Any]] = None
progress_metadata: dict[str, Any] | None = None,
) -> Log:
"""
Update a log entry with progress information while keeping IN_PROGRESS status
Args:
log_entry: The log entry to update
progress_message: Progress update message
progress_metadata: Additional progress metadata
Returns:
Log: The updated log entry
"""
log_entry.message = progress_message
if progress_metadata:
if log_entry.log_metadata is None:
log_entry.log_metadata = {}
log_entry.log_metadata.update(progress_metadata)
log_entry.log_metadata["last_progress_update"] = datetime.utcnow().isoformat()
log_entry.log_metadata["last_progress_update"] = (
datetime.utcnow().isoformat()
)
await self.session.commit()
await self.session.refresh(log_entry)
task_name = log_entry.log_metadata.get("task_name", "unknown") if log_entry.log_metadata else "unknown"
task_name = (
log_entry.log_metadata.get("task_name", "unknown")
if log_entry.log_metadata
else "unknown"
)
logger.info(f"Progress update for task {task_name}: {progress_message}")
return log_entry
async def log_simple_event(
self,
level: LogLevel,
source: str,
message: str,
metadata: Optional[Dict[str, Any]] = None
metadata: dict[str, Any] | None = None,
) -> Log:
"""
Log a simple event (not a long-running task)
Args:
level: Log level
source: Source service/component
message: Log message
metadata: Additional context data
Returns:
Log: The created log entry
"""
@ -193,12 +207,12 @@ class TaskLoggingService:
message=message,
source=source,
log_metadata=metadata or {},
search_space_id=self.search_space_id
search_space_id=self.search_space_id,
)
self.session.add(log_entry)
await self.session.commit()
await self.session.refresh(log_entry)
logger.info(f"Logged event from {source}: {message}")
return log_entry
return log_entry
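A condensed lifecycle sketch for the service above (hypothetical task name): open a log entry, report progress, then close it with success or failure metadata:
from app.services.task_logging_service import TaskLoggingService


async def run_with_logging(session, search_space_id: int) -> None:
    task_logger = TaskLoggingService(session, search_space_id)
    log_entry = await task_logger.log_task_start(
        task_name="demo_task",
        source="background_task",
        message="Starting demo task",
        metadata={"example": True},
    )
    try:
        await task_logger.log_task_progress(log_entry, "Halfway there", {"stage": "midpoint"})
        await task_logger.log_task_success(log_entry, "Demo task finished")
    except Exception as e:
        await task_logger.log_task_failure(log_entry, "Demo task failed", str(e))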

View file

@ -1,46 +1,49 @@
from typing import Optional, List
from sqlalchemy.ext.asyncio import AsyncSession
import logging
from urllib.parse import parse_qs, urlparse
import aiohttp
import validators
from langchain_community.document_loaders import AsyncChromiumLoader, FireCrawlLoader
from langchain_community.document_transformers import MarkdownifyTransformer
from langchain_core.documents import Document as LangChainDocument
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import Document, DocumentType, Chunk
from app.schemas import ExtensionDocumentContent
from youtube_transcript_api import YouTubeTranscriptApi
from app.config import config
from app.db import Chunk, Document, DocumentType
from app.prompts import SUMMARY_PROMPT_TEMPLATE
from app.utils.document_converters import convert_document_to_markdown, generate_content_hash
from app.schemas import ExtensionDocumentContent
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from langchain_core.documents import Document as LangChainDocument
from langchain_community.document_loaders import FireCrawlLoader, AsyncChromiumLoader
from langchain_community.document_transformers import MarkdownifyTransformer
import validators
from youtube_transcript_api import YouTubeTranscriptApi
from urllib.parse import urlparse, parse_qs
import aiohttp
import logging
from app.utils.document_converters import (
convert_document_to_markdown,
generate_content_hash,
)
md = MarkdownifyTransformer()
async def add_crawled_url_document(
session: AsyncSession, url: str, search_space_id: int, user_id: str
) -> Optional[Document]:
) -> Document | None:
task_logger = TaskLoggingService(session, search_space_id)
# Log task start
log_entry = await task_logger.log_task_start(
task_name="crawl_url_document",
source="background_task",
message=f"Starting URL crawling process for: {url}",
metadata={"url": url, "user_id": str(user_id)}
metadata={"url": url, "user_id": str(user_id)},
)
try:
# URL validation step
await task_logger.log_task_progress(
log_entry,
f"Validating URL: {url}",
{"stage": "validation"}
log_entry, f"Validating URL: {url}", {"stage": "validation"}
)
if not validators.url(url):
raise ValueError(f"Url {url} is not a valid URL address")
@ -48,7 +51,10 @@ async def add_crawled_url_document(
await task_logger.log_task_progress(
log_entry,
f"Setting up crawler for URL: {url}",
{"stage": "crawler_setup", "firecrawl_available": bool(config.FIRECRAWL_API_KEY)}
{
"stage": "crawler_setup",
"firecrawl_available": bool(config.FIRECRAWL_API_KEY),
},
)
if config.FIRECRAWL_API_KEY:
@ -68,21 +74,21 @@ async def add_crawled_url_document(
await task_logger.log_task_progress(
log_entry,
f"Crawling URL content: {url}",
{"stage": "crawling", "crawler_type": type(crawl_loader).__name__}
{"stage": "crawling", "crawler_type": type(crawl_loader).__name__},
)
url_crawled = await crawl_loader.aload()
if type(crawl_loader) == FireCrawlLoader:
if isinstance(crawl_loader, FireCrawlLoader):
content_in_markdown = url_crawled[0].page_content
elif type(crawl_loader) == AsyncChromiumLoader:
elif isinstance(crawl_loader, AsyncChromiumLoader):
content_in_markdown = md.transform_documents(url_crawled)[0].page_content
# Format document
await task_logger.log_task_progress(
log_entry,
f"Processing crawled content from: {url}",
{"stage": "content_processing", "content_length": len(content_in_markdown)}
{"stage": "content_processing", "content_length": len(content_in_markdown)},
)
# Format document metadata in a more maintainable way
@ -117,7 +123,7 @@ async def add_crawled_url_document(
await task_logger.log_task_progress(
log_entry,
f"Checking for duplicate content: {url}",
{"stage": "duplicate_check", "content_hash": content_hash}
{"stage": "duplicate_check", "content_hash": content_hash},
)
# Check if document with this content hash already exists
@ -125,21 +131,26 @@ async def add_crawled_url_document(
select(Document).where(Document.content_hash == content_hash)
)
existing_document = existing_doc_result.scalars().first()
if existing_document:
await task_logger.log_task_success(
log_entry,
f"Document already exists for URL: {url}",
{"duplicate_detected": True, "existing_document_id": existing_document.id}
{
"duplicate_detected": True,
"existing_document_id": existing_document.id,
},
)
logging.info(
f"Document with content hash {content_hash} already exists. Skipping processing."
)
logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
return existing_document
# Get LLM for summary generation
await task_logger.log_task_progress(
log_entry,
f"Preparing for summary generation: {url}",
{"stage": "llm_setup"}
{"stage": "llm_setup"},
)
# Get user's long context LLM
@ -151,7 +162,7 @@ async def add_crawled_url_document(
await task_logger.log_task_progress(
log_entry,
f"Generating summary for URL content: {url}",
{"stage": "summary_generation"}
{"stage": "summary_generation"},
)
summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm
@ -165,7 +176,7 @@ async def add_crawled_url_document(
await task_logger.log_task_progress(
log_entry,
f"Processing content chunks for URL: {url}",
{"stage": "chunk_processing"}
{"stage": "chunk_processing"},
)
chunks = [
@ -180,13 +191,13 @@ async def add_crawled_url_document(
await task_logger.log_task_progress(
log_entry,
f"Creating document in database for URL: {url}",
{"stage": "document_creation", "chunks_count": len(chunks)}
{"stage": "document_creation", "chunks_count": len(chunks)},
)
document = Document(
search_space_id=search_space_id,
title=url_crawled[0].metadata["title"]
if type(crawl_loader) == FireCrawlLoader
if isinstance(crawl_loader, FireCrawlLoader)
else url_crawled[0].metadata["source"],
document_type=DocumentType.CRAWLED_URL,
document_metadata=url_crawled[0].metadata,
@ -209,8 +220,8 @@ async def add_crawled_url_document(
"title": document.title,
"content_hash": content_hash,
"chunks_count": len(chunks),
"summary_length": len(summary_content)
}
"summary_length": len(summary_content),
},
)
return document
@ -221,7 +232,7 @@ async def add_crawled_url_document(
log_entry,
f"Database error while processing URL: {url}",
str(db_error),
{"error_type": "SQLAlchemyError"}
{"error_type": "SQLAlchemyError"},
)
raise db_error
except Exception as e:
@ -230,14 +241,17 @@ async def add_crawled_url_document(
log_entry,
f"Failed to crawl URL: {url}",
str(e),
{"error_type": type(e).__name__}
{"error_type": type(e).__name__},
)
raise RuntimeError(f"Failed to crawl URL: {str(e)}")
raise RuntimeError(f"Failed to crawl URL: {e!s}") from e
async def add_extension_received_document(
session: AsyncSession, content: ExtensionDocumentContent, search_space_id: int, user_id: str
) -> Optional[Document]:
session: AsyncSession,
content: ExtensionDocumentContent,
search_space_id: int,
user_id: str,
) -> Document | None:
"""
Process and store document content received from the SurfSense Extension.
@ -250,7 +264,7 @@ async def add_extension_received_document(
Document object if successful, None if failed
"""
task_logger = TaskLoggingService(session, search_space_id)
# Log task start
log_entry = await task_logger.log_task_start(
task_name="extension_document",
@ -259,10 +273,10 @@ async def add_extension_received_document(
metadata={
"url": content.metadata.VisitedWebPageURL,
"title": content.metadata.VisitedWebPageTitle,
"user_id": str(user_id)
}
"user_id": str(user_id),
},
)
try:
# Format document metadata in a more maintainable way
metadata_sections = [
@ -301,14 +315,19 @@ async def add_extension_received_document(
select(Document).where(Document.content_hash == content_hash)
)
existing_document = existing_doc_result.scalars().first()
if existing_document:
await task_logger.log_task_success(
log_entry,
f"Extension document already exists: {content.metadata.VisitedWebPageTitle}",
{"duplicate_detected": True, "existing_document_id": existing_document.id}
{
"duplicate_detected": True,
"existing_document_id": existing_document.id,
},
)
logging.info(
f"Document with content hash {content_hash} already exists. Skipping processing."
)
logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
return existing_document
# Get user's long context LLM
@ -356,8 +375,8 @@ async def add_extension_received_document(
{
"document_id": document.id,
"content_hash": content_hash,
"url": content.metadata.VisitedWebPageURL
}
"url": content.metadata.VisitedWebPageURL,
},
)
return document
@ -368,7 +387,7 @@ async def add_extension_received_document(
log_entry,
f"Database error processing extension document: {content.metadata.VisitedWebPageTitle}",
str(db_error),
{"error_type": "SQLAlchemyError"}
{"error_type": "SQLAlchemyError"},
)
raise db_error
except Exception as e:
@ -377,24 +396,32 @@ async def add_extension_received_document(
log_entry,
f"Failed to process extension document: {content.metadata.VisitedWebPageTitle}",
str(e),
{"error_type": type(e).__name__}
{"error_type": type(e).__name__},
)
raise RuntimeError(f"Failed to process extension document: {str(e)}")
raise RuntimeError(f"Failed to process extension document: {e!s}") from e
async def add_received_markdown_file_document(
session: AsyncSession, file_name: str, file_in_markdown: str, search_space_id: int, user_id: str
) -> Optional[Document]:
session: AsyncSession,
file_name: str,
file_in_markdown: str,
search_space_id: int,
user_id: str,
) -> Document | None:
task_logger = TaskLoggingService(session, search_space_id)
# Log task start
log_entry = await task_logger.log_task_start(
task_name="markdown_file_document",
source="background_task",
message=f"Processing markdown file: {file_name}",
metadata={"filename": file_name, "user_id": str(user_id), "content_length": len(file_in_markdown)}
metadata={
"filename": file_name,
"user_id": str(user_id),
"content_length": len(file_in_markdown),
},
)
try:
content_hash = generate_content_hash(file_in_markdown, search_space_id)
@ -403,14 +430,19 @@ async def add_received_markdown_file_document(
select(Document).where(Document.content_hash == content_hash)
)
existing_document = existing_doc_result.scalars().first()
if existing_document:
await task_logger.log_task_success(
log_entry,
f"Markdown file document already exists: {file_name}",
{"duplicate_detected": True, "existing_document_id": existing_document.id}
{
"duplicate_detected": True,
"existing_document_id": existing_document.id,
},
)
logging.info(
f"Document with content hash {content_hash} already exists. Skipping processing."
)
logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
return existing_document
# Get user's long context LLM
@ -459,8 +491,8 @@ async def add_received_markdown_file_document(
"document_id": document.id,
"content_hash": content_hash,
"chunks_count": len(chunks),
"summary_length": len(summary_content)
}
"summary_length": len(summary_content),
},
)
return document
@ -470,7 +502,7 @@ async def add_received_markdown_file_document(
log_entry,
f"Database error processing markdown file: {file_name}",
str(db_error),
{"error_type": "SQLAlchemyError"}
{"error_type": "SQLAlchemyError"},
)
raise db_error
except Exception as e:
@ -479,18 +511,18 @@ async def add_received_markdown_file_document(
log_entry,
f"Failed to process markdown file: {file_name}",
str(e),
{"error_type": type(e).__name__}
{"error_type": type(e).__name__},
)
raise RuntimeError(f"Failed to process file document: {str(e)}")
raise RuntimeError(f"Failed to process file document: {e!s}") from e
async def add_received_file_document_using_unstructured(
session: AsyncSession,
file_name: str,
unstructured_processed_elements: List[LangChainDocument],
unstructured_processed_elements: list[LangChainDocument],
search_space_id: int,
user_id: str,
) -> Optional[Document]:
) -> Document | None:
try:
file_in_markdown = await convert_document_to_markdown(
unstructured_processed_elements
@@ -503,9 +535,11 @@ async def add_received_file_document_using_unstructured(
select(Document).where(Document.content_hash == content_hash)
)
existing_document = existing_doc_result.scalars().first()
if existing_document:
logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
logging.info(
f"Document with content hash {content_hash} already exists. Skipping processing."
)
return existing_document
# TODO: Check if file_markdown exceeds token limit of embedding model
@@ -555,7 +589,7 @@ async def add_received_file_document_using_unstructured(
raise db_error
except Exception as e:
await session.rollback()
raise RuntimeError(f"Failed to process file document: {str(e)}")
raise RuntimeError(f"Failed to process file document: {e!s}") from e
async def add_received_file_document_using_llamacloud(
@@ -564,7 +598,7 @@ async def add_received_file_document_using_llamacloud(
llamacloud_markdown_document: str,
search_space_id: int,
user_id: str,
) -> Optional[Document]:
) -> Document | None:
"""
Process and store document content parsed by LlamaCloud.
@@ -588,9 +622,11 @@ async def add_received_file_document_using_llamacloud(
select(Document).where(Document.content_hash == content_hash)
)
existing_document = existing_doc_result.scalars().first()
if existing_document:
logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
logging.info(
f"Document with content hash {content_hash} already exists. Skipping processing."
)
return existing_document
# Get user's long context LLM
@@ -638,7 +674,9 @@ async def add_received_file_document_using_llamacloud(
raise db_error
except Exception as e:
await session.rollback()
raise RuntimeError(f"Failed to process file document using LlamaCloud: {str(e)}")
raise RuntimeError(
f"Failed to process file document using LlamaCloud: {e!s}"
) from e
async def add_received_file_document_using_docling(
@@ -647,7 +685,7 @@ async def add_received_file_document_using_docling(
docling_markdown_document: str,
search_space_id: int,
user_id: str,
) -> Optional[Document]:
) -> Document | None:
"""
Process and store document content parsed by Docling.
@@ -671,9 +709,11 @@ async def add_received_file_document_using_docling(
select(Document).where(Document.content_hash == content_hash)
)
existing_document = existing_doc_result.scalars().first()
if existing_document:
logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
logging.info(
f"Document with content hash {content_hash} already exists. Skipping processing."
)
return existing_document
# Get user's long context LLM
@@ -683,12 +723,11 @@ async def add_received_file_document_using_docling(
# Generate summary using chunked processing for large documents
from app.services.docling_service import create_docling_service
docling_service = create_docling_service()
summary_content = await docling_service.process_large_document_summary(
content=file_in_markdown,
llm=user_llm,
document_title=file_name
content=file_in_markdown, llm=user_llm, document_title=file_name
)
summary_embedding = config.embedding_model_instance.embed(summary_content)
@@ -726,7 +765,9 @@ async def add_received_file_document_using_docling(
raise db_error
except Exception as e:
await session.rollback()
raise RuntimeError(f"Failed to process file document using Docling: {str(e)}")
raise RuntimeError(
f"Failed to process file document using Docling: {e!s}"
) from e
async def add_youtube_video_document(
@@ -749,23 +790,23 @@ async def add_youtube_video_document(
RuntimeError: If the video processing fails
"""
task_logger = TaskLoggingService(session, search_space_id)
# Log task start
log_entry = await task_logger.log_task_start(
task_name="youtube_video_document",
source="background_task",
message=f"Starting YouTube video processing for: {url}",
metadata={"url": url, "user_id": str(user_id)}
metadata={"url": url, "user_id": str(user_id)},
)
try:
# Extract video ID from URL
await task_logger.log_task_progress(
log_entry,
f"Extracting video ID from URL: {url}",
{"stage": "video_id_extraction"}
{"stage": "video_id_extraction"},
)
def get_youtube_video_id(url: str):
parsed_url = urlparse(url)
hostname = parsed_url.hostname
@@ -790,14 +831,14 @@ async def add_youtube_video_document(
await task_logger.log_task_progress(
log_entry,
f"Video ID extracted: {video_id}",
{"stage": "video_id_extracted", "video_id": video_id}
{"stage": "video_id_extracted", "video_id": video_id},
)
# Get video metadata
await task_logger.log_task_progress(
log_entry,
f"Fetching video metadata for: {video_id}",
{"stage": "metadata_fetch"}
{"stage": "metadata_fetch"},
)
params = {
@@ -806,21 +847,27 @@ async def add_youtube_video_document(
}
oembed_url = "https://www.youtube.com/oembed"
async with aiohttp.ClientSession() as http_session:
async with http_session.get(oembed_url, params=params) as response:
video_data = await response.json()
async with (
aiohttp.ClientSession() as http_session,
http_session.get(oembed_url, params=params) as response,
):
video_data = await response.json()
await task_logger.log_task_progress(
log_entry,
f"Video metadata fetched: {video_data.get('title', 'Unknown')}",
{"stage": "metadata_fetched", "title": video_data.get('title'), "author": video_data.get('author_name')}
{
"stage": "metadata_fetched",
"title": video_data.get("title"),
"author": video_data.get("author_name"),
},
)
# Get video transcript
await task_logger.log_task_progress(
log_entry,
f"Fetching transcript for video: {video_id}",
{"stage": "transcript_fetch"}
{"stage": "transcript_fetch"},
)
try:
@@ -834,25 +881,29 @@ async def add_youtube_video_document(
timestamp = f"[{start_time:.2f}s-{start_time + duration:.2f}s]"
transcript_segments.append(f"{timestamp} {text}")
transcript_text = "\n".join(transcript_segments)
await task_logger.log_task_progress(
log_entry,
f"Transcript fetched successfully: {len(captions)} segments",
{"stage": "transcript_fetched", "segments_count": len(captions), "transcript_length": len(transcript_text)}
{
"stage": "transcript_fetched",
"segments_count": len(captions),
"transcript_length": len(transcript_text),
},
)
except Exception as e:
transcript_text = f"No captions available for this video. Error: {str(e)}"
transcript_text = f"No captions available for this video. Error: {e!s}"
await task_logger.log_task_progress(
log_entry,
f"No transcript available for video: {video_id}",
{"stage": "transcript_unavailable", "error": str(e)}
{"stage": "transcript_unavailable", "error": str(e)},
)
# Format document
await task_logger.log_task_progress(
log_entry,
f"Processing video content: {video_data.get('title', 'YouTube Video')}",
{"stage": "content_processing"}
{"stage": "content_processing"},
)
# Format document metadata in a more maintainable way
@@ -890,7 +941,7 @@ async def add_youtube_video_document(
await task_logger.log_task_progress(
log_entry,
f"Checking for duplicate video content: {video_id}",
{"stage": "duplicate_check", "content_hash": content_hash}
{"stage": "duplicate_check", "content_hash": content_hash},
)
# Check if document with this content hash already exists
@@ -898,21 +949,27 @@ async def add_youtube_video_document(
select(Document).where(Document.content_hash == content_hash)
)
existing_document = existing_doc_result.scalars().first()
if existing_document:
await task_logger.log_task_success(
log_entry,
f"YouTube video document already exists: {video_data.get('title', 'YouTube Video')}",
{"duplicate_detected": True, "existing_document_id": existing_document.id, "video_id": video_id}
{
"duplicate_detected": True,
"existing_document_id": existing_document.id,
"video_id": video_id,
},
)
logging.info(
f"Document with content hash {content_hash} already exists. Skipping processing."
)
logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
return existing_document
# Get LLM for summary generation
await task_logger.log_task_progress(
log_entry,
f"Preparing for summary generation: {video_data.get('title', 'YouTube Video')}",
{"stage": "llm_setup"}
{"stage": "llm_setup"},
)
# Get user's long context LLM
@@ -924,7 +981,7 @@ async def add_youtube_video_document(
await task_logger.log_task_progress(
log_entry,
f"Generating summary for video: {video_data.get('title', 'YouTube Video')}",
{"stage": "summary_generation"}
{"stage": "summary_generation"},
)
summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm
@@ -938,7 +995,7 @@ async def add_youtube_video_document(
await task_logger.log_task_progress(
log_entry,
f"Processing content chunks for video: {video_data.get('title', 'YouTube Video')}",
{"stage": "chunk_processing"}
{"stage": "chunk_processing"},
)
chunks = [
@@ -953,7 +1010,7 @@ async def add_youtube_video_document(
await task_logger.log_task_progress(
log_entry,
f"Creating YouTube video document in database: {video_data.get('title', 'YouTube Video')}",
{"stage": "document_creation", "chunks_count": len(chunks)}
{"stage": "document_creation", "chunks_count": len(chunks)},
)
document = Document(
@@ -988,8 +1045,8 @@ async def add_youtube_video_document(
"content_hash": content_hash,
"chunks_count": len(chunks),
"summary_length": len(summary_content),
"has_transcript": "No captions available" not in transcript_text
}
"has_transcript": "No captions available" not in transcript_text,
},
)
return document
@@ -999,7 +1056,10 @@ async def add_youtube_video_document(
log_entry,
f"Database error while processing YouTube video: {url}",
str(db_error),
{"error_type": "SQLAlchemyError", "video_id": video_id if 'video_id' in locals() else None}
{
"error_type": "SQLAlchemyError",
"video_id": video_id if "video_id" in locals() else None,
},
)
raise db_error
except Exception as e:
@@ -1008,7 +1068,10 @@ async def add_youtube_video_document(
log_entry,
f"Failed to process YouTube video: {url}",
str(e),
{"error_type": type(e).__name__, "video_id": video_id if 'video_id' in locals() else None}
{
"error_type": type(e).__name__,
"video_id": video_id if "video_id" in locals() else None,
},
)
logging.error(f"Failed to process YouTube video: {str(e)}")
logging.error(f"Failed to process YouTube video: {e!s}")
raise
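One of the hunks above also collapses the nested aiohttp context managers into a single parenthesized async with statement, the multi-line form that ruff prefers and that Python 3.10+ parses natively. A minimal, self-contained sketch of the same shape (the oEmbed endpoint is the one used above; everything else is illustrative, not code from this commit):

import asyncio

import aiohttp


async def fetch_oembed(video_url: str) -> dict:
    params = {"url": video_url, "format": "json"}
    # Both context managers are entered in one statement; the parentheses
    # allow the multi-line layout without backslashes or extra nesting.
    async with (
        aiohttp.ClientSession() as http_session,
        http_session.get("https://www.youtube.com/oembed", params=params) as response,
    ):
        return await response.json()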

File diff suppressed because it is too large

View file

@@ -1,33 +1,29 @@
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.podcaster.graph import graph as podcaster_graph
from app.agents.podcaster.state import State
from app.db import Chat, Podcast
from app.services.task_logging_service import TaskLoggingService
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import SQLAlchemyError
async def generate_document_podcast(
session: AsyncSession,
document_id: int,
search_space_id: int,
user_id: int
session: AsyncSession, document_id: int, search_space_id: int, user_id: int
):
# TODO: Need to fetch the document chunks, then concatenate them and pass them to the podcast generation model
pass
async def generate_chat_podcast(
session: AsyncSession,
chat_id: int,
search_space_id: int,
podcast_title: str,
user_id: int
user_id: int,
):
task_logger = TaskLoggingService(session, search_space_id)
# Log task start
log_entry = await task_logger.log_task_start(
task_name="generate_chat_podcast",
@@ -37,44 +33,43 @@ async def generate_chat_podcast(
"chat_id": chat_id,
"search_space_id": search_space_id,
"podcast_title": podcast_title,
"user_id": str(user_id)
}
"user_id": str(user_id),
},
)
try:
# Fetch the chat with the specified ID
await task_logger.log_task_progress(
log_entry,
f"Fetching chat {chat_id} from database",
{"stage": "fetch_chat"}
log_entry, f"Fetching chat {chat_id} from database", {"stage": "fetch_chat"}
)
query = select(Chat).filter(
Chat.id == chat_id,
Chat.search_space_id == search_space_id
Chat.id == chat_id, Chat.search_space_id == search_space_id
)
result = await session.execute(query)
chat = result.scalars().first()
if not chat:
await task_logger.log_task_failure(
log_entry,
f"Chat with id {chat_id} not found in search space {search_space_id}",
"Chat not found",
{"error_type": "ChatNotFound"}
{"error_type": "ChatNotFound"},
)
raise ValueError(f"Chat with id {chat_id} not found in search space {search_space_id}")
raise ValueError(
f"Chat with id {chat_id} not found in search space {search_space_id}"
)
# Create chat history structure
await task_logger.log_task_progress(
log_entry,
f"Processing chat history for chat {chat_id}",
{"stage": "process_chat_history", "message_count": len(chat.messages)}
{"stage": "process_chat_history", "message_count": len(chat.messages)},
)
chat_history_str = "<chat_history>"
processed_messages = 0
for message in chat.messages:
if message["role"] == "user":
@@ -89,18 +84,24 @@ async def generate_chat_podcast(
# If content is a list, join it into a single string
if isinstance(answer_text, list):
answer_text = "\n".join(answer_text)
chat_history_str += f"<assistant_message>{answer_text}</assistant_message>"
chat_history_str += (
f"<assistant_message>{answer_text}</assistant_message>"
)
processed_messages += 1
chat_history_str += "</chat_history>"
# Pass it to the SurfSense Podcaster
await task_logger.log_task_progress(
log_entry,
f"Initializing podcast generation for chat {chat_id}",
{"stage": "initialize_podcast_generation", "processed_messages": processed_messages, "content_length": len(chat_history_str)}
{
"stage": "initialize_podcast_generation",
"processed_messages": processed_messages,
"content_length": len(chat_history_str),
},
)
config = {
"configurable": {
"podcast_title": "SurfSense",
@@ -108,53 +109,55 @@ async def generate_chat_podcast(
}
}
# Initialize state with database session and streaming service
initial_state = State(
source_content=chat_history_str,
db_session=session
)
initial_state = State(source_content=chat_history_str, db_session=session)
# Run the graph directly
await task_logger.log_task_progress(
log_entry,
f"Running podcast generation graph for chat {chat_id}",
{"stage": "run_podcast_graph"}
{"stage": "run_podcast_graph"},
)
result = await podcaster_graph.ainvoke(initial_state, config=config)
# Convert podcast transcript entries to serializable format
await task_logger.log_task_progress(
log_entry,
f"Processing podcast transcript for chat {chat_id}",
{"stage": "process_transcript", "transcript_entries": len(result["podcast_transcript"])}
{
"stage": "process_transcript",
"transcript_entries": len(result["podcast_transcript"]),
},
)
serializable_transcript = []
for entry in result["podcast_transcript"]:
serializable_transcript.append({
"speaker_id": entry.speaker_id,
"dialog": entry.dialog
})
serializable_transcript.append(
{"speaker_id": entry.speaker_id, "dialog": entry.dialog}
)
# Create a new podcast entry
await task_logger.log_task_progress(
log_entry,
f"Creating podcast database entry for chat {chat_id}",
{"stage": "create_podcast_entry", "file_location": result.get("final_podcast_file_path")}
{
"stage": "create_podcast_entry",
"file_location": result.get("final_podcast_file_path"),
},
)
podcast = Podcast(
title=f"{podcast_title}",
podcast_transcript=serializable_transcript,
file_location=result["final_podcast_file_path"],
search_space_id=search_space_id
search_space_id=search_space_id,
)
# Add to session and commit
session.add(podcast)
await session.commit()
await session.refresh(podcast)
# Log success
await task_logger.log_task_success(
log_entry,
@@ -165,10 +168,10 @@ async def generate_chat_podcast(
"transcript_entries": len(serializable_transcript),
"file_location": result.get("final_podcast_file_path"),
"processed_messages": processed_messages,
"content_length": len(chat_history_str)
}
"content_length": len(chat_history_str),
},
)
return podcast
except ValueError as ve:
@@ -178,7 +181,7 @@ async def generate_chat_podcast(
log_entry,
f"Value error during podcast generation for chat {chat_id}",
str(ve),
{"error_type": "ValueError"}
{"error_type": "ValueError"},
)
raise ve
except SQLAlchemyError as db_error:
@@ -187,7 +190,7 @@ async def generate_chat_podcast(
log_entry,
f"Database error during podcast generation for chat {chat_id}",
str(db_error),
{"error_type": "SQLAlchemyError"}
{"error_type": "SQLAlchemyError"},
)
raise db_error
except Exception as e:
@@ -196,7 +199,8 @@ async def generate_chat_podcast(
log_entry,
f"Unexpected error during podcast generation for chat {chat_id}",
str(e),
{"error_type": type(e).__name__}
{"error_type": type(e).__name__},
)
raise RuntimeError(f"Failed to generate podcast for chat {chat_id}: {str(e)}")
raise RuntimeError(
f"Failed to generate podcast for chat {chat_id}: {e!s}"
) from e
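Across these hunks, bare raise RuntimeError(f"... {str(e)}") calls become raise ... from e with the {e!s} conversion: chaining keeps the original exception attached as __cause__, so the underlying traceback is not lost when the error is re-raised. A small sketch of the pattern, with a stand-in worker function in place of the real podcast graph:

import asyncio
import logging


async def generate_transcript(chat_id: int) -> str:
    # Stand-in for the real podcast generation work; fails for bad input.
    if chat_id < 0:
        raise ValueError("chat_id must be non-negative")
    return f"transcript for chat {chat_id}"


async def generate_podcast(chat_id: int) -> str:
    try:
        return await generate_transcript(chat_id)
    except Exception as e:
        # {e!s} is the f-string spelling of str(e); "from e" records the
        # original error as __cause__ instead of discarding it.
        logging.error(f"Failed to generate podcast for chat {chat_id}: {e!s}")
        raise RuntimeError(
            f"Failed to generate podcast for chat {chat_id}: {e!s}"
        ) from e


if __name__ == "__main__":
    print(asyncio.run(generate_podcast(7)))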

View file

@@ -1,28 +1,29 @@
from typing import Any, AsyncGenerator, List, Union
from collections.abc import AsyncGenerator
from typing import Any
from uuid import UUID
from app.agents.researcher.graph import graph as researcher_graph
from app.agents.researcher.state import State
from app.services.streaming_service import StreamingService
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.researcher.configuration import SearchMode
from app.agents.researcher.graph import graph as researcher_graph
from app.agents.researcher.state import State
from app.services.streaming_service import StreamingService
async def stream_connector_search_results(
user_query: str,
user_id: Union[str, UUID],
search_space_id: int,
session: AsyncSession,
research_mode: str,
selected_connectors: List[str],
langchain_chat_history: List[Any],
user_query: str,
user_id: str | UUID,
search_space_id: int,
session: AsyncSession,
research_mode: str,
selected_connectors: list[str],
langchain_chat_history: list[Any],
search_mode_str: str,
document_ids_to_add_in_context: List[int]
document_ids_to_add_in_context: list[int],
) -> AsyncGenerator[str, None]:
"""
Stream connector search results to the client
Args:
user_query: The user's query
user_id: The user's ID (can be UUID object or string)
@@ -30,61 +31,60 @@ async def stream_connector_search_results(
session: The database session
research_mode: The research mode
selected_connectors: List of selected connectors
Yields:
str: Formatted response strings
"""
streaming_service = StreamingService()
if research_mode == "REPORT_GENERAL":
NUM_SECTIONS = 1
num_sections = 1
elif research_mode == "REPORT_DEEP":
NUM_SECTIONS = 3
num_sections = 3
elif research_mode == "REPORT_DEEPER":
NUM_SECTIONS = 6
num_sections = 6
else:
# Default fallback
NUM_SECTIONS = 1
num_sections = 1
# Convert UUID to string if needed
user_id_str = str(user_id) if isinstance(user_id, UUID) else user_id
if search_mode_str == "CHUNKS":
search_mode = SearchMode.CHUNKS
elif search_mode_str == "DOCUMENTS":
search_mode = SearchMode.DOCUMENTS
# Sample configuration
config = {
"configurable": {
"user_query": user_query,
"num_sections": NUM_SECTIONS,
"num_sections": num_sections,
"connectors_to_search": selected_connectors,
"user_id": user_id_str,
"search_space_id": search_space_id,
"search_mode": search_mode,
"research_mode": research_mode,
"document_ids_to_add_in_context": document_ids_to_add_in_context
"document_ids_to_add_in_context": document_ids_to_add_in_context,
}
}
# Initialize state with database session and streaming service
initial_state = State(
db_session=session,
streaming_service=streaming_service,
chat_history=langchain_chat_history
chat_history=langchain_chat_history,
)
# Run the graph directly
print("\nRunning the complete researcher workflow...")
# Use streaming with config parameter
async for chunk in researcher_graph.astream(
initial_state,
config=config,
stream_mode="custom",
):
if isinstance(chunk, dict):
if "yield_value" in chunk:
yield chunk["yield_value"]
if isinstance(chunk, dict) and "yield_value" in chunk:
yield chunk["yield_value"]
yield streaming_service.format_completion()
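stream_connector_search_results yields pre-formatted strings; a hedged sketch of how a route could consume such a generator with FastAPI's StreamingResponse (the route path and the stand-in generator are illustrative, not taken from this repository):

from collections.abc import AsyncGenerator

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


async def fake_connector_stream() -> AsyncGenerator[str, None]:
    # Stand-in for stream_connector_search_results(...).
    for part in ("section 1\n", "section 2\n", "done\n"):
        yield part


@app.post("/chat/stream")
async def stream_chat() -> StreamingResponse:
    # Each yielded string becomes one chunk of the HTTP response body.
    return StreamingResponse(fake_connector_stream(), media_type="text/event-stream")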

View file

@@ -1,8 +1,7 @@
from typing import Optional
import uuid
from fastapi import Depends, Request, Response
from fastapi.responses import RedirectResponse
from fastapi.responses import JSONResponse, RedirectResponse
from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models
from fastapi_users.authentication import (
AuthenticationBackend,
@@ -10,21 +9,23 @@ from fastapi_users.authentication import (
JWTStrategy,
)
from fastapi_users.db import SQLAlchemyUserDatabase
from fastapi.responses import JSONResponse
from fastapi_users.schemas import model_dump
from pydantic import BaseModel
from app.config import config
from app.db import User, get_user_db
from pydantic import BaseModel
class BearerResponse(BaseModel):
access_token: str
token_type: str
SECRET = config.SECRET_KEY
if config.AUTH_TYPE == "GOOGLE":
from httpx_oauth.clients.google import GoogleOAuth2
google_oauth_client = GoogleOAuth2(
config.GOOGLE_OAUTH_CLIENT_ID,
config.GOOGLE_OAUTH_CLIENT_SECRET,
@@ -35,27 +36,26 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = SECRET
verification_token_secret = SECRET
async def on_after_register(self, user: User, request: Optional[Request] = None):
async def on_after_register(self, user: User, request: Request | None = None):
print(f"User {user.id} has registered.")
async def on_after_forgot_password(
self, user: User, token: str, request: Optional[Request] = None
self, user: User, token: str, request: Request | None = None
):
print(f"User {user.id} has forgot their password. Reset token: {token}")
async def on_after_request_verify(
self, user: User, token: str, request: Optional[Request] = None
self, user: User, token: str, request: Request | None = None
):
print(
f"Verification requested for user {user.id}. Verification token: {token}")
print(f"Verification requested for user {user.id}. Verification token: {token}")
async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)):
yield UserManager(user_db)
def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
return JWTStrategy(secret=SECRET, lifetime_seconds=3600*24)
return JWTStrategy(secret=SECRET, lifetime_seconds=3600 * 24)
# # COOKIE AUTH | Uncomment if you want to use cookie auth.
@@ -77,6 +77,7 @@ def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
# get_strategy=get_jwt_strategy,
# )
# BEARER AUTH CODE.
class CustomBearerTransport(BearerTransport):
async def get_login_response(self, token: str) -> Response:
@@ -87,6 +88,7 @@ class CustomBearerTransport(BearerTransport):
else:
return JSONResponse(model_dump(bearer_response))
bearer_transport = CustomBearerTransport(tokenUrl="auth/jwt/login")
@@ -98,4 +100,4 @@ auth_backend = AuthenticationBackend(
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
current_active_user = fastapi_users.current_user(active=True)
current_active_user = fastapi_users.current_user(active=True)
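current_active_user is the dependency the API routes are expected to use for authentication; a minimal usage sketch (the route and the app.users import path are assumptions, not code from this commit):

from fastapi import APIRouter, Depends

from app.users import current_active_user  # assumed module path for this file

router = APIRouter()


@router.get("/users/me")
async def read_me(user=Depends(current_active_user)):
    # fastapi-users resolves the bearer token and injects the active User row.
    return {"id": str(user.id), "email": user.email}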

View file

@@ -1,12 +1,19 @@
from fastapi import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import User
# Helper function to check user ownership
async def check_ownership(session: AsyncSession, model, item_id: int, user: User):
item = await session.execute(select(model).filter(model.id == item_id, model.user_id == user.id))
item = await session.execute(
select(model).filter(model.id == item_id, model.user_id == user.id)
)
item = item.scalars().first()
if not item:
raise HTTPException(status_code=404, detail="Item not found or you don't have permission to access it")
return item
raise HTTPException(
status_code=404,
detail="Item not found or you don't have permission to access it",
)
return item
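check_ownership is written to be awaited from route handlers; a hedged usage sketch (the route, the get_async_session dependency, and the import paths are assumptions, not code from this commit):

from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession

from app.db import Document, User, get_async_session  # assumed import paths
from app.users import current_active_user  # assumed import path
from app.utils.check_ownership import check_ownership  # assumed import path

router = APIRouter()


@router.get("/documents/{document_id}")
async def read_document(
    document_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    # Raises HTTP 404 when the document is missing or owned by another user.
    return await check_ownership(session, Document, document_id, user)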

View file

@@ -32,7 +32,7 @@ async def convert_element_to_markdown(element) -> str:
"Footer": lambda x: f"*{x}*\n\n",
"CodeSnippet": lambda x: f"```\n{x}\n```",
"PageNumber": lambda x: f"*Page {x}*\n\n",
"UncategorizedText": lambda x: f"{x}\n\n"
"UncategorizedText": lambda x: f"{x}\n\n",
}
converter = markdown_mapping.get(element_category, lambda x: x)
@@ -74,7 +74,7 @@ def convert_chunks_to_langchain_documents(chunks):
except ImportError:
raise ImportError(
"LangChain is not installed. Please install it with `pip install langchain langchain-core`"
)
) from None
langchain_docs = []
@@ -92,17 +92,20 @@ def convert_chunks_to_langchain_documents(chunks):
# Add document information to metadata
if "document" in chunk:
doc = chunk["document"]
metadata.update({
"document_id": doc.get("id"),
"document_title": doc.get("title"),
"document_type": doc.get("document_type"),
})
metadata.update(
{
"document_id": doc.get("id"),
"document_title": doc.get("title"),
"document_type": doc.get("document_type"),
}
)
# Add document metadata if available
if "metadata" in doc:
# Prefix document metadata keys to avoid conflicts
doc_metadata = {f"doc_meta_{k}": v for k,
v in doc.get("metadata", {}).items()}
doc_metadata = {
f"doc_meta_{k}": v for k, v in doc.get("metadata", {}).items()
}
metadata.update(doc_metadata)
# Add source URL if available in metadata
@@ -131,10 +134,7 @@ def convert_chunks_to_langchain_documents(chunks):
"""
# Create LangChain Document
langchain_doc = LangChainDocument(
page_content=new_content,
metadata=metadata
)
langchain_doc = LangChainDocument(page_content=new_content, metadata=metadata)
langchain_docs.append(langchain_doc)
@@ -144,4 +144,4 @@ def convert_chunks_to_langchain_documents(chunks):
def generate_content_hash(content: str, search_space_id: int) -> str:
"""Generate SHA-256 hash for the given content combined with search space ID."""
combined_data = f"{search_space_id}:{content}"
return hashlib.sha256(combined_data.encode('utf-8')).hexdigest()
return hashlib.sha256(combined_data.encode("utf-8")).hexdigest()
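generate_content_hash feeds the duplicate check that the ingestion tasks above run before creating a new Document; a condensed sketch of that pattern (assuming the Document model import path used elsewhere in this commit):

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.db import Document  # assumed import path, matching the tasks above


async def find_duplicate(
    session: AsyncSession, content: str, search_space_id: int
) -> Document | None:
    # generate_content_hash is the helper defined directly above.
    content_hash = generate_content_hash(content, search_space_id)
    result = await session.execute(
        select(Document).where(Document.content_hash == content_hash)
    )
    # An existing row means the caller can skip re-processing; None means new.
    return result.scalars().first()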

View file

@@ -1,7 +1,9 @@
import uvicorn
import argparse
import logging
import uvicorn
from dotenv import load_dotenv
from app.config.uvicorn import load_uvicorn_config
logging.basicConfig(

View file

@@ -103,6 +103,8 @@ ignore = [
"RUF012", # Mutable class attributes should be annotated with `typing.ClassVar`
]
extend-select = ["I"]
# Allow fix for all enabled rules (when `--fix`) is provided.
fixable = ["ALL"]
unfixable = []
@@ -123,6 +125,7 @@ skip-magic-trailing-comma = false
# Automatically detect the appropriate line ending.
line-ending = "auto"
[tool.ruff.lint.isort]
# Group imports by type
known-first-party = ["app"]
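With known-first-party = ["app"], ruff's isort rules produce the import grouping visible in the reordered hunks of this commit: plain imports ahead of from-imports within a section, standard library first, then third-party packages, then app.* modules. An illustrative (not exhaustive) example built from imports that appear in this commit:

# Standard library: plain imports sort ahead of from-imports
import logging
from collections.abc import AsyncGenerator

# Third-party packages
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

# First-party, recognized via known-first-party = ["app"]
from app.db import Document
from app.services.task_logging_service import TaskLoggingService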