From d359a59f6d5ffdaf6d818983675ce988f68ab068 Mon Sep 17 00:00:00 2001 From: Utkarsh-Patel-13 Date: Thu, 24 Jul 2025 14:43:48 -0700 Subject: [PATCH] Fixed all ruff lint and formatting errors --- surfsense_backend/alembic/env.py | 8 +- ...e_chattype_enum_to_qna_report_structure.py | 2 +- ..._add_llm_config_table_and_relationships.py | 4 +- .../alembic/versions/12_add_logs_table.py | 2 +- .../versions/1_add_github_connector_enum.py | 4 +- .../versions/2_add_linear_connector_enum.py | 2 +- ...3_add_linear_connector_to_documenttype_.py | 2 +- .../alembic/versions/4_add_linkup_api_enum.py | 2 +- .../versions/5_remove_title_char_limit.py | 2 +- .../6_change_podcast_content_to_transcript.py | 2 +- .../versions/7_remove_is_generated_column.py | 2 +- .../8_add_content_hash_to_documents.py | 2 +- ...discord_connector_enum_and_documenttype.py | 2 +- ...1_add_github_connector_to_documenttype_.py | 2 +- .../app/agents/podcaster/configuration.py | 5 +- .../app/agents/podcaster/graph.py | 8 +- .../app/agents/podcaster/nodes.py | 154 +- .../app/agents/podcaster/prompts.py | 2 +- .../app/agents/podcaster/state.py | 18 +- .../app/agents/researcher/configuration.py | 13 +- .../app/agents/researcher/graph.py | 45 +- .../app/agents/researcher/nodes.py | 869 ++++++----- .../app/agents/researcher/prompts.py | 2 +- .../agents/researcher/qna_agent/__init__.py | 3 +- .../researcher/qna_agent/configuration.py | 8 +- .../app/agents/researcher/qna_agent/graph.py | 5 +- .../app/agents/researcher/qna_agent/nodes.py | 158 +- .../app/agents/researcher/qna_agent/state.py | 14 +- .../app/agents/researcher/state.py | 29 +- .../sub_section_writer/configuration.py | 10 +- .../researcher/sub_section_writer/graph.py | 5 +- .../researcher/sub_section_writer/nodes.py | 162 ++- .../researcher/sub_section_writer/prompts.py | 2 +- .../researcher/sub_section_writer/state.py | 14 +- .../app/agents/researcher/utils.py | 103 +- surfsense_backend/app/app.py | 28 +- surfsense_backend/app/config/__init__.py | 53 +- surfsense_backend/app/config/uvicorn.py | 66 +- .../app/connectors/discord_connector.py | 94 +- .../app/connectors/github_connector.py | 247 ++-- .../app/connectors/linear_connector.py | 270 ++-- .../app/connectors/notion_history.py | 131 +- .../app/connectors/slack_history.py | 250 ++-- .../app/connectors/test_github_connector.py | 116 +- .../app/connectors/test_slack_history.py | 577 ++++---- surfsense_backend/app/db.py | 277 +++- surfsense_backend/app/prompts/__init__.py | 15 +- .../app/retriver/chunks_hybrid_search.py | 181 ++- .../app/retriver/documents_hybrid_search.py | 200 +-- surfsense_backend/app/routes/__init__.py | 9 +- surfsense_backend/app/routes/chats_routes.py | 202 +-- .../app/routes/documents_routes.py | 500 ++++--- .../app/routes/llm_config_routes.py | 126 +- surfsense_backend/app/routes/logs_routes.py | 179 +-- .../app/routes/podcasts_routes.py | 179 ++- .../routes/search_source_connectors_routes.py | 499 ++++--- .../app/routes/search_spaces_routes.py | 65 +- surfsense_backend/app/schemas/__init__.py | 106 +- surfsense_backend/app/schemas/base.py | 5 +- surfsense_backend/app/schemas/chats.py | 33 +- surfsense_backend/app/schemas/chunks.py | 7 +- surfsense_backend/app/schemas/documents.py | 22 +- surfsense_backend/app/schemas/llm_config.py | 59 +- surfsense_backend/app/schemas/logs.py | 46 +- surfsense_backend/app/schemas/podcasts.py | 15 +- .../app/schemas/search_source_connector.py | 103 +- surfsense_backend/app/schemas/search_space.py | 13 +- surfsense_backend/app/schemas/users.py | 6 +- 
surfsense_backend/app/services/__init__.py | 2 +- .../app/services/connector_service.py | 608 ++++---- .../app/services/docling_service.py | 263 ++-- surfsense_backend/app/services/llm_service.py | 76 +- .../app/services/query_service.py | 30 +- .../app/services/reranker_service.py | 65 +- .../app/services/streaming_service.py | 44 +- .../app/services/task_logging_service.py | 130 +- .../app/tasks/background_tasks.py | 291 ++-- .../app/tasks/connectors_indexing_tasks.py | 1273 +++++++++++------ surfsense_backend/app/tasks/podcast_tasks.py | 124 +- .../tasks/stream_connector_search_results.py | 60 +- surfsense_backend/app/users.py | 30 +- .../app/utils/check_ownership.py | 13 +- .../app/utils/document_converters.py | 28 +- surfsense_backend/main.py | 4 +- surfsense_backend/pyproject.toml | 3 + 85 files changed, 5520 insertions(+), 3870 deletions(-) diff --git a/surfsense_backend/alembic/env.py b/surfsense_backend/alembic/env.py index d6e7104..fd9740e 100644 --- a/surfsense_backend/alembic/env.py +++ b/surfsense_backend/alembic/env.py @@ -1,8 +1,8 @@ import asyncio -from logging.config import fileConfig - import os import sys +from logging.config import fileConfig + from sqlalchemy import pool from sqlalchemy.engine import Connection from sqlalchemy.ext.asyncio import async_engine_from_config @@ -11,10 +11,10 @@ from alembic import context # Ensure the app directory is in the Python path # This allows Alembic to find your models -sys.path.insert(0, os.path.realpath(os.path.join(os.path.dirname(__file__), '..'))) +sys.path.insert(0, os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))) # Import your models base -from app.db import Base # Assuming your Base is defined in app.db +from app.db import Base # Assuming your Base is defined in app.db # this is the Alembic Config object, which provides # access to the values within the .ini file in use. diff --git a/surfsense_backend/alembic/versions/10_update_chattype_enum_to_qna_report_structure.py b/surfsense_backend/alembic/versions/10_update_chattype_enum_to_qna_report_structure.py index 35902c7..29a00e6 100644 --- a/surfsense_backend/alembic/versions/10_update_chattype_enum_to_qna_report_structure.py +++ b/surfsense_backend/alembic/versions/10_update_chattype_enum_to_qna_report_structure.py @@ -6,9 +6,9 @@ Revises: 9 from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "10" diff --git a/surfsense_backend/alembic/versions/11_add_llm_config_table_and_relationships.py b/surfsense_backend/alembic/versions/11_add_llm_config_table_and_relationships.py index 83fdef1..028d147 100644 --- a/surfsense_backend/alembic/versions/11_add_llm_config_table_and_relationships.py +++ b/surfsense_backend/alembic/versions/11_add_llm_config_table_and_relationships.py @@ -6,10 +6,10 @@ Revises: 10 from typing import Sequence, Union -from alembic import op import sqlalchemy as sa -from sqlalchemy.dialects.postgresql import UUID, JSON +from sqlalchemy.dialects.postgresql import JSON, UUID +from alembic import op # revision identifiers, used by Alembic. 
revision: str = "11" diff --git a/surfsense_backend/alembic/versions/12_add_logs_table.py b/surfsense_backend/alembic/versions/12_add_logs_table.py index 0b2cc13..5033d2a 100644 --- a/surfsense_backend/alembic/versions/12_add_logs_table.py +++ b/surfsense_backend/alembic/versions/12_add_logs_table.py @@ -6,10 +6,10 @@ Revises: 11 from typing import Sequence, Union -from alembic import op import sqlalchemy as sa from sqlalchemy.dialects.postgresql import JSON +from alembic import op # revision identifiers, used by Alembic. revision: str = "12" diff --git a/surfsense_backend/alembic/versions/1_add_github_connector_enum.py b/surfsense_backend/alembic/versions/1_add_github_connector_enum.py index 1902777..2224ab6 100644 --- a/surfsense_backend/alembic/versions/1_add_github_connector_enum.py +++ b/surfsense_backend/alembic/versions/1_add_github_connector_enum.py @@ -6,8 +6,10 @@ Revises: """ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa + +from alembic import op + # Import pgvector if needed for other types, though not for this ENUM change # import pgvector diff --git a/surfsense_backend/alembic/versions/2_add_linear_connector_enum.py b/surfsense_backend/alembic/versions/2_add_linear_connector_enum.py index 526c7c3..dd1ccca 100644 --- a/surfsense_backend/alembic/versions/2_add_linear_connector_enum.py +++ b/surfsense_backend/alembic/versions/2_add_linear_connector_enum.py @@ -6,9 +6,9 @@ Revises: e55302644c51 """ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = '2' diff --git a/surfsense_backend/alembic/versions/3_add_linear_connector_to_documenttype_.py b/surfsense_backend/alembic/versions/3_add_linear_connector_to_documenttype_.py index e71ee2e..0a3c41e 100644 --- a/surfsense_backend/alembic/versions/3_add_linear_connector_to_documenttype_.py +++ b/surfsense_backend/alembic/versions/3_add_linear_connector_to_documenttype_.py @@ -6,9 +6,9 @@ Revises: 2 """ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = '3' diff --git a/surfsense_backend/alembic/versions/4_add_linkup_api_enum.py b/surfsense_backend/alembic/versions/4_add_linkup_api_enum.py index 093bdf0..35562a8 100644 --- a/surfsense_backend/alembic/versions/4_add_linkup_api_enum.py +++ b/surfsense_backend/alembic/versions/4_add_linkup_api_enum.py @@ -6,9 +6,9 @@ Revises: 3 """ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = '4' diff --git a/surfsense_backend/alembic/versions/5_remove_title_char_limit.py b/surfsense_backend/alembic/versions/5_remove_title_char_limit.py index 62fe019..db36965 100644 --- a/surfsense_backend/alembic/versions/5_remove_title_char_limit.py +++ b/surfsense_backend/alembic/versions/5_remove_title_char_limit.py @@ -6,9 +6,9 @@ Revises: 4 """ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. 
revision: str = '5' diff --git a/surfsense_backend/alembic/versions/6_change_podcast_content_to_transcript.py b/surfsense_backend/alembic/versions/6_change_podcast_content_to_transcript.py index fa7a0f8..411761f 100644 --- a/surfsense_backend/alembic/versions/6_change_podcast_content_to_transcript.py +++ b/surfsense_backend/alembic/versions/6_change_podcast_content_to_transcript.py @@ -6,10 +6,10 @@ Revises: 5 """ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa from sqlalchemy.dialects.postgresql import JSON +from alembic import op # revision identifiers, used by Alembic. revision: str = '6' diff --git a/surfsense_backend/alembic/versions/7_remove_is_generated_column.py b/surfsense_backend/alembic/versions/7_remove_is_generated_column.py index 03048a1..8acc224 100644 --- a/surfsense_backend/alembic/versions/7_remove_is_generated_column.py +++ b/surfsense_backend/alembic/versions/7_remove_is_generated_column.py @@ -6,9 +6,9 @@ Revises: 6 """ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = '7' diff --git a/surfsense_backend/alembic/versions/8_add_content_hash_to_documents.py b/surfsense_backend/alembic/versions/8_add_content_hash_to_documents.py index 64982fc..908b956 100644 --- a/surfsense_backend/alembic/versions/8_add_content_hash_to_documents.py +++ b/surfsense_backend/alembic/versions/8_add_content_hash_to_documents.py @@ -5,9 +5,9 @@ Revises: 7 """ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = '8' diff --git a/surfsense_backend/alembic/versions/9_add_discord_connector_enum_and_documenttype.py b/surfsense_backend/alembic/versions/9_add_discord_connector_enum_and_documenttype.py index fbf748a..45e7f81 100644 --- a/surfsense_backend/alembic/versions/9_add_discord_connector_enum_and_documenttype.py +++ b/surfsense_backend/alembic/versions/9_add_discord_connector_enum_and_documenttype.py @@ -6,9 +6,9 @@ Revises: 8 from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "9" diff --git a/surfsense_backend/alembic/versions/e55302644c51_add_github_connector_to_documenttype_.py b/surfsense_backend/alembic/versions/e55302644c51_add_github_connector_to_documenttype_.py index 12d6537..1e00936 100644 --- a/surfsense_backend/alembic/versions/e55302644c51_add_github_connector_to_documenttype_.py +++ b/surfsense_backend/alembic/versions/e55302644c51_add_github_connector_to_documenttype_.py @@ -6,9 +6,9 @@ Revises: 1 """ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. 
revision: str = 'e55302644c51' diff --git a/surfsense_backend/app/agents/podcaster/configuration.py b/surfsense_backend/app/agents/podcaster/configuration.py index 062b1ee..c4c5f9e 100644 --- a/surfsense_backend/app/agents/podcaster/configuration.py +++ b/surfsense_backend/app/agents/podcaster/configuration.py @@ -3,7 +3,6 @@ from __future__ import annotations from dataclasses import dataclass, fields -from typing import Optional from langchain_core.runnables import RunnableConfig @@ -17,11 +16,11 @@ class Configuration: # create assistants (https://langchain-ai.github.io/langgraph/cloud/how-tos/configuration_cloud/) # and when you invoke the graph podcast_title: str - user_id: str + user_id: str @classmethod def from_runnable_config( - cls, config: Optional[RunnableConfig] = None + cls, config: RunnableConfig | None = None ) -> Configuration: """Create a Configuration instance from a RunnableConfig object.""" configurable = (config.get("configurable") or {}) if config else {} diff --git a/surfsense_backend/app/agents/podcaster/graph.py b/surfsense_backend/app/agents/podcaster/graph.py index d102432..9404556 100644 --- a/surfsense_backend/app/agents/podcaster/graph.py +++ b/surfsense_backend/app/agents/podcaster/graph.py @@ -1,14 +1,11 @@ from langgraph.graph import StateGraph from .configuration import Configuration +from .nodes import create_merged_podcast_audio, create_podcast_transcript from .state import State -from .nodes import create_merged_podcast_audio, create_podcast_transcript - - def build_graph(): - # Define a new graph workflow = StateGraph(State, config_schema=Configuration) @@ -24,8 +21,9 @@ def build_graph(): # Compile the workflow into an executable graph graph = workflow.compile() graph.name = "Surfsense Podcaster" # This defines the custom name in LangSmith - + return graph + # Compile the graph once when the module is loaded graph = build_graph() diff --git a/surfsense_backend/app/agents/podcaster/nodes.py b/surfsense_backend/app/agents/podcaster/nodes.py index 63373b5..2309a29 100644 --- a/surfsense_backend/app/agents/podcaster/nodes.py +++ b/surfsense_backend/app/agents/podcaster/nodes.py @@ -1,148 +1,154 @@ -from typing import Any, Dict +import asyncio import json import os import uuid from pathlib import Path -import asyncio +from typing import Any +from ffmpeg.asyncio import FFmpeg from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.runnables import RunnableConfig from litellm import aspeech -from ffmpeg.asyncio import FFmpeg -from .configuration import Configuration -from .state import PodcastTranscriptEntry, State, PodcastTranscripts -from .prompts import get_podcast_generation_prompt from app.config import config as app_config from app.services.llm_service import get_user_long_context_llm +from .configuration import Configuration +from .prompts import get_podcast_generation_prompt +from .state import PodcastTranscriptEntry, PodcastTranscripts, State -async def create_podcast_transcript(state: State, config: RunnableConfig) -> Dict[str, Any]: + +async def create_podcast_transcript( + state: State, config: RunnableConfig +) -> dict[str, Any]: """Each node does work.""" - + # Get configuration from runnable config configuration = Configuration.from_runnable_config(config) user_id = configuration.user_id - + # Get user's long context LLM llm = await get_user_long_context_llm(state.db_session, user_id) if not llm: error_message = f"No long context LLM configured for user {user_id}" print(error_message) raise 
RuntimeError(error_message) - + # Get the prompt prompt = get_podcast_generation_prompt() - + # Create the messages messages = [ SystemMessage(content=prompt), - HumanMessage(content=f"{state.source_content}") + HumanMessage( + content=f"{state.source_content}" + ), ] - + # Generate the podcast transcript llm_response = await llm.ainvoke(messages) - + # First try the direct approach try: - podcast_transcript = PodcastTranscripts.model_validate(json.loads(llm_response.content)) + podcast_transcript = PodcastTranscripts.model_validate( + json.loads(llm_response.content) + ) except (json.JSONDecodeError, ValueError) as e: - print(f"Direct JSON parsing failed, trying fallback approach: {str(e)}") - + print(f"Direct JSON parsing failed, trying fallback approach: {e!s}") + # Fallback: Parse the JSON response manually try: # Extract JSON content from the response content = llm_response.content - + # Find the JSON in the content (handle case where LLM might add additional text) - json_start = content.find('{') - json_end = content.rfind('}') + 1 + json_start = content.find("{") + json_end = content.rfind("}") + 1 if json_start >= 0 and json_end > json_start: json_str = content[json_start:json_end] - + # Parse the JSON string parsed_data = json.loads(json_str) - + # Convert to Pydantic model podcast_transcript = PodcastTranscripts.model_validate(parsed_data) - - print(f"Successfully parsed podcast transcript using fallback approach") + + print("Successfully parsed podcast transcript using fallback approach") else: # If JSON structure not found, raise a clear error error_message = f"Could not find valid JSON in LLM response. Raw response: {content}" print(error_message) raise ValueError(error_message) - + except (json.JSONDecodeError, ValueError) as e2: # Log the error and re-raise it - error_message = f"Error parsing LLM response (fallback also failed): {str(e2)}" - print(f"Error parsing LLM response: {str(e2)}") + error_message = f"Error parsing LLM response (fallback also failed): {e2!s}" + print(f"Error parsing LLM response: {e2!s}") print(f"Raw response: {llm_response.content}") raise - - return { - "podcast_transcript": podcast_transcript.podcast_transcripts - } - - -async def create_merged_podcast_audio(state: State, config: RunnableConfig) -> Dict[str, Any]: + + return {"podcast_transcript": podcast_transcript.podcast_transcripts} + + +async def create_merged_podcast_audio( + state: State, config: RunnableConfig +) -> dict[str, Any]: """Generate audio for each transcript and merge them into a single podcast file.""" - + configuration = Configuration.from_runnable_config(config) - + starting_transcript = PodcastTranscriptEntry( - speaker_id=1, - dialog=f"Welcome to {configuration.podcast_title} Podcast." + speaker_id=1, dialog=f"Welcome to {configuration.podcast_title} Podcast." 
) - + transcript = state.podcast_transcript - + # Merge the starting transcript with the podcast transcript # Check if transcript is a PodcastTranscripts object or already a list - if hasattr(transcript, 'podcast_transcripts'): + if hasattr(transcript, "podcast_transcripts"): transcript_entries = transcript.podcast_transcripts else: transcript_entries = transcript - - merged_transcript = [starting_transcript] + transcript_entries - + + merged_transcript = [starting_transcript, *transcript_entries] + # Create a temporary directory for audio files temp_dir = Path("temp_audio") temp_dir.mkdir(exist_ok=True) - + # Generate a unique session ID for this podcast session_id = str(uuid.uuid4()) output_path = f"podcasts/{session_id}_podcast.mp3" os.makedirs("podcasts", exist_ok=True) - + # Map of speaker_id to voice voice_mapping = { 0: "alloy", # Default/intro voice - 1: "echo", # First speaker + 1: "echo", # First speaker # 2: "fable", # Second speaker # 3: "onyx", # Third speaker # 4: "nova", # Fourth speaker # 5: "shimmer" # Fifth speaker } - + # Generate audio for each transcript segment audio_files = [] - + async def generate_speech_for_segment(segment, index): # Handle both dictionary and PodcastTranscriptEntry objects - if hasattr(segment, 'speaker_id'): + if hasattr(segment, "speaker_id"): speaker_id = segment.speaker_id dialog = segment.dialog else: speaker_id = segment.get("speaker_id", 0) dialog = segment.get("dialog", "") - + # Select voice based on speaker_id voice = voice_mapping.get(speaker_id, "alloy") - + # Generate a unique filename for this segment filename = f"{temp_dir}/{session_id}_{index}.mp3" - + try: if app_config.TTS_SERVICE_API_BASE: response = await aspeech( @@ -163,55 +169,61 @@ async def create_merged_podcast_audio(state: State, config: RunnableConfig) -> D max_retries=2, timeout=600, ) - + # Save the audio to a file - use proper streaming method - with open(filename, 'wb') as f: + with open(filename, "wb") as f: f.write(response.content) - + return filename except Exception as e: - print(f"Error generating speech for segment {index}: {str(e)}") + print(f"Error generating speech for segment {index}: {e!s}") raise - + # Generate all audio files concurrently - tasks = [generate_speech_for_segment(segment, i) for i, segment in enumerate(merged_transcript)] + tasks = [ + generate_speech_for_segment(segment, i) + for i, segment in enumerate(merged_transcript) + ] audio_files = await asyncio.gather(*tasks) - + # Merge audio files using ffmpeg try: # Create FFmpeg instance with the first input ffmpeg = FFmpeg().option("y") - + # Add each audio file as input for audio_file in audio_files: ffmpeg = ffmpeg.input(audio_file) - + # Configure the concatenation and output filter_complex = [] for i in range(len(audio_files)): filter_complex.append(f"[{i}:0]") - - filter_complex_str = "".join(filter_complex) + f"concat=n={len(audio_files)}:v=0:a=1[outa]" + + filter_complex_str = ( + "".join(filter_complex) + f"concat=n={len(audio_files)}:v=0:a=1[outa]" + ) ffmpeg = ffmpeg.option("filter_complex", filter_complex_str) ffmpeg = ffmpeg.output(output_path, map="[outa]") - + # Execute FFmpeg await ffmpeg.execute() - + print(f"Successfully created podcast audio: {output_path}") - + except Exception as e: - print(f"Error merging audio files: {str(e)}") + print(f"Error merging audio files: {e!s}") raise finally: # Clean up temporary files for audio_file in audio_files: try: os.remove(audio_file) - except: + except Exception as e: + print(f"Error removing audio file {audio_file}: {e!s}") pass - 
+ return { "podcast_transcript": merged_transcript, - "final_podcast_file_path": output_path + "final_podcast_file_path": output_path, } diff --git a/surfsense_backend/app/agents/podcaster/prompts.py b/surfsense_backend/app/agents/podcaster/prompts.py index c08d38e..a3d6c31 100644 --- a/surfsense_backend/app/agents/podcaster/prompts.py +++ b/surfsense_backend/app/agents/podcaster/prompts.py @@ -108,4 +108,4 @@ Output: Transform the source material into a lively and engaging podcast conversation. Craft dialogue that showcases authentic host chemistry and natural interaction (including occasional disagreement, building on points, or asking follow-up questions). Use varied speech patterns reflecting real human conversation, ensuring the final script effectively educates *and* entertains the listener while keeping within a 5-minute audio duration. -""" \ No newline at end of file +""" diff --git a/surfsense_backend/app/agents/podcaster/state.py b/surfsense_backend/app/agents/podcaster/state.py index 79fccef..62eb053 100644 --- a/surfsense_backend/app/agents/podcaster/state.py +++ b/surfsense_backend/app/agents/podcaster/state.py @@ -3,14 +3,16 @@ from __future__ import annotations from dataclasses import dataclass -from typing import List, Optional + from pydantic import BaseModel, Field from sqlalchemy.ext.asyncio import AsyncSession + class PodcastTranscriptEntry(BaseModel): """ Represents a single entry in a podcast transcript. """ + speaker_id: int = Field(..., description="The ID of the speaker (0 or 1)") dialog: str = Field(..., description="The dialog text spoken by the speaker") @@ -19,10 +21,11 @@ class PodcastTranscripts(BaseModel): """ Represents the full podcast transcript structure. """ - podcast_transcripts: List[PodcastTranscriptEntry] = Field( - ..., - description="List of transcript entries with alternating speakers" - ) + + podcast_transcripts: list[PodcastTranscriptEntry] = Field( + ..., description="List of transcript entries with alternating speakers" + ) + @dataclass class State: @@ -32,8 +35,9 @@ class State: See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state for more information. 
""" + # Runtime context db_session: AsyncSession source_content: str - podcast_transcript: Optional[List[PodcastTranscriptEntry]] = None - final_podcast_file_path: Optional[str] = None + podcast_transcript: list[PodcastTranscriptEntry] | None = None + final_podcast_file_path: str | None = None diff --git a/surfsense_backend/app/agents/researcher/configuration.py b/surfsense_backend/app/agents/researcher/configuration.py index a45aa9b..3e81a59 100644 --- a/surfsense_backend/app/agents/researcher/configuration.py +++ b/surfsense_backend/app/agents/researcher/configuration.py @@ -4,17 +4,20 @@ from __future__ import annotations from dataclasses import dataclass, fields from enum import Enum -from typing import Optional, List, Any from langchain_core.runnables import RunnableConfig -class SearchMode(Enum): + +class SearchMode(Enum): """Enum defining the type of search mode.""" + CHUNKS = "CHUNKS" DOCUMENTS = "DOCUMENTS" + class ResearchMode(Enum): """Enum defining the type of research mode.""" + QNA = "QNA" REPORT_GENERAL = "REPORT_GENERAL" REPORT_DEEP = "REPORT_DEEP" @@ -28,16 +31,16 @@ class Configuration: # Input parameters provided at invocation user_query: str num_sections: int - connectors_to_search: List[str] + connectors_to_search: list[str] user_id: str search_space_id: int search_mode: SearchMode research_mode: ResearchMode - document_ids_to_add_in_context: List[int] + document_ids_to_add_in_context: list[int] @classmethod def from_runnable_config( - cls, config: Optional[RunnableConfig] = None + cls, config: RunnableConfig | None = None ) -> Configuration: """Create a Configuration instance from a RunnableConfig object.""" configurable = (config.get("configurable") or {}) if config else {} diff --git a/surfsense_backend/app/agents/researcher/graph.py b/surfsense_backend/app/agents/researcher/graph.py index ed378ca..b3ffadd 100644 --- a/surfsense_backend/app/agents/researcher/graph.py +++ b/surfsense_backend/app/agents/researcher/graph.py @@ -1,31 +1,41 @@ +from typing import Any, TypedDict + from langgraph.graph import StateGraph -from .state import State -from .nodes import reformulate_user_query, write_answer_outline, process_sections, handle_qna_workflow, generate_further_questions + from .configuration import Configuration, ResearchMode -from typing import TypedDict, List, Dict, Any, Optional +from .nodes import ( + generate_further_questions, + handle_qna_workflow, + process_sections, + reformulate_user_query, + write_answer_outline, +) +from .state import State + # Define what keys are in our state dict class GraphState(TypedDict): # Intermediate data produced during workflow - answer_outline: Optional[Any] + answer_outline: Any | None # Final output - final_written_report: Optional[str] + final_written_report: str | None + def build_graph(): """ Build and return the LangGraph workflow. - + This function constructs the researcher agent graph with conditional routing based on research_mode - QNA mode uses a direct Q&A workflow while other modes use the full report generation pipeline. Both paths generate follow-up questions at the end using the reranked documents from the sub-agents. 
- + Returns: A compiled LangGraph workflow """ # Define a new graph with state class workflow = StateGraph(State, config_schema=Configuration) - + # Add nodes to the graph workflow.add_node("reformulate_user_query", reformulate_user_query) workflow.add_node("handle_qna_workflow", handle_qna_workflow) @@ -35,41 +45,42 @@ def build_graph(): # Define the edges workflow.add_edge("__start__", "reformulate_user_query") - + # Add conditional edges from reformulate_user_query based on research mode def route_after_reformulate(state: State, config) -> str: """Route based on research_mode after reformulating the query.""" configuration = Configuration.from_runnable_config(config) - + if configuration.research_mode == ResearchMode.QNA.value: return "handle_qna_workflow" else: return "write_answer_outline" - + workflow.add_conditional_edges( "reformulate_user_query", route_after_reformulate, { "handle_qna_workflow": "handle_qna_workflow", - "write_answer_outline": "write_answer_outline" - } + "write_answer_outline": "write_answer_outline", + }, ) - + # QNA workflow path: handle_qna_workflow -> generate_further_questions -> __end__ workflow.add_edge("handle_qna_workflow", "generate_further_questions") - + # Report generation workflow path: write_answer_outline -> process_sections -> generate_further_questions -> __end__ workflow.add_edge("write_answer_outline", "process_sections") workflow.add_edge("process_sections", "generate_further_questions") - + # Both paths end after generating further questions workflow.add_edge("generate_further_questions", "__end__") # Compile the workflow into an executable graph graph = workflow.compile() graph.name = "Surfsense Researcher" # This defines the custom name in LangSmith - + return graph + # Compile the graph once when the module is loaded graph = build_graph() diff --git a/surfsense_backend/app/agents/researcher/nodes.py b/surfsense_backend/app/agents/researcher/nodes.py index 30d572a..67550d6 100644 --- a/surfsense_backend/app/agents/researcher/nodes.py +++ b/surfsense_backend/app/agents/researcher/nodes.py @@ -1,79 +1,81 @@ import asyncio import json -from typing import Any, Dict, List +from typing import Any -from app.services.connector_service import ConnectorService from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.runnables import RunnableConfig - -from sqlalchemy.ext.asyncio import AsyncSession - -from .configuration import Configuration, SearchMode -from .prompts import get_answer_outline_system_prompt, get_further_questions_system_prompt -from .state import State -from .sub_section_writer.graph import graph as sub_section_writer_graph -from .sub_section_writer.configuration import SubSectionType -from .qna_agent.graph import graph as qna_agent_graph -from .utils import AnswerOutline, get_connector_emoji, get_connector_friendly_name - -from app.services.query_service import QueryService - from langgraph.types import StreamWriter +from sqlalchemy.ext.asyncio import AsyncSession # Additional imports for document fetching from sqlalchemy.future import select + from app.db import Document, SearchSpace +from app.services.connector_service import ConnectorService +from app.services.query_service import QueryService + +from .configuration import Configuration, SearchMode +from .prompts import ( + get_answer_outline_system_prompt, + get_further_questions_system_prompt, +) +from .qna_agent.graph import graph as qna_agent_graph +from .state import State +from .sub_section_writer.configuration import SubSectionType +from 
.sub_section_writer.graph import graph as sub_section_writer_graph +from .utils import AnswerOutline, get_connector_emoji, get_connector_friendly_name + async def fetch_documents_by_ids( - document_ids: List[int], - user_id: str, - db_session: AsyncSession -) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + document_ids: list[int], user_id: str, db_session: AsyncSession +) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: """ Fetch documents by their IDs with ownership check using DOCUMENTS mode approach. - + This function ensures that only documents belonging to the user are fetched, providing security by checking ownership through SearchSpace association. Similar to SearchMode.DOCUMENTS, it fetches full documents and concatenates their chunks. Also creates source objects for UI display, grouped by document type. - + Args: document_ids: List of document IDs to fetch user_id: The user ID to check ownership db_session: The database session - + Returns: Tuple of (source_objects, document_chunks) - similar to ConnectorService pattern """ if not document_ids: return [], [] - + try: # Query documents with ownership check result = await db_session.execute( select(Document) .join(SearchSpace) - .filter( - Document.id.in_(document_ids), - SearchSpace.user_id == user_id - ) + .filter(Document.id.in_(document_ids), SearchSpace.user_id == user_id) ) documents = result.scalars().all() - + # Group documents by type for source object creation documents_by_type = {} formatted_documents = [] - + for doc in documents: # Fetch associated chunks for this document (similar to DocumentHybridSearchRetriever) from app.db import Chunk - chunks_query = select(Chunk).where(Chunk.document_id == doc.id).order_by(Chunk.id) + + chunks_query = ( + select(Chunk).where(Chunk.document_id == doc.id).order_by(Chunk.id) + ) chunks_result = await db_session.execute(chunks_query) chunks = chunks_result.scalars().all() - + # Concatenate chunks content (similar to SearchMode.DOCUMENTS approach) - concatenated_chunks_content = " ".join([chunk.content for chunk in chunks]) if chunks else doc.content - + concatenated_chunks_content = ( + " ".join([chunk.content for chunk in chunks]) if chunks else doc.content + ) + # Format to match connector service return format formatted_doc = { "chunk_id": f"user_doc_{doc.id}", @@ -82,143 +84,215 @@ async def fetch_documents_by_ids( "document": { "id": doc.id, "title": doc.title, - "document_type": doc.document_type.value if doc.document_type else "UNKNOWN", + "document_type": doc.document_type.value + if doc.document_type + else "UNKNOWN", "metadata": doc.document_metadata or {}, }, - "source": doc.document_type.value if doc.document_type else "UNKNOWN" + "source": doc.document_type.value if doc.document_type else "UNKNOWN", } formatted_documents.append(formatted_doc) - + # Group by document type for source objects doc_type = doc.document_type.value if doc.document_type else "UNKNOWN" if doc_type not in documents_by_type: documents_by_type[doc_type] = [] documents_by_type[doc_type].append(doc) - + # Create source objects for each document type (similar to ConnectorService) source_objects = [] - connector_id_counter = 100 # Start from 100 to avoid conflicts with regular connectors - + connector_id_counter = ( + 100 # Start from 100 to avoid conflicts with regular connectors + ) + for doc_type, docs in documents_by_type.items(): sources_list = [] - + for doc in docs: metadata = doc.document_metadata or {} - + # Create type-specific source formatting (similar to ConnectorService) if 
doc_type == "LINEAR_CONNECTOR": # Extract Linear-specific metadata - issue_identifier = metadata.get('issue_identifier', '') - issue_title = metadata.get('issue_title', doc.title) - issue_state = metadata.get('state', '') - comment_count = metadata.get('comment_count', 0) - + issue_identifier = metadata.get("issue_identifier", "") + issue_title = metadata.get("issue_title", doc.title) + issue_state = metadata.get("state", "") + comment_count = metadata.get("comment_count", 0) + # Create a more descriptive title for Linear issues - title = f"Linear: {issue_identifier} - {issue_title}" if issue_identifier else f"Linear: {issue_title}" + title = ( + f"Linear: {issue_identifier} - {issue_title}" + if issue_identifier + else f"Linear: {issue_title}" + ) if issue_state: title += f" ({issue_state})" - + # Create description - description = doc.content[:100] + "..." if len(doc.content) > 100 else doc.content + description = ( + doc.content[:100] + "..." + if len(doc.content) > 100 + else doc.content + ) if comment_count: description += f" | Comments: {comment_count}" - + # Create URL - url = f"https://linear.app/issue/{issue_identifier}" if issue_identifier else "" - + url = ( + f"https://linear.app/issue/{issue_identifier}" + if issue_identifier + else "" + ) + elif doc_type == "SLACK_CONNECTOR": # Extract Slack-specific metadata - channel_name = metadata.get('channel_name', 'Unknown Channel') - channel_id = metadata.get('channel_id', '') - message_date = metadata.get('start_date', '') - + channel_name = metadata.get("channel_name", "Unknown Channel") + channel_id = metadata.get("channel_id", "") + message_date = metadata.get("start_date", "") + title = f"Slack: {channel_name}" if message_date: title += f" ({message_date})" - - description = doc.content[:100] + "..." if len(doc.content) > 100 else doc.content - url = f"https://slack.com/app_redirect?channel={channel_id}" if channel_id else "" - + + description = ( + doc.content[:100] + "..." + if len(doc.content) > 100 + else doc.content + ) + url = ( + f"https://slack.com/app_redirect?channel={channel_id}" + if channel_id + else "" + ) + elif doc_type == "NOTION_CONNECTOR": # Extract Notion-specific metadata - page_title = metadata.get('page_title', doc.title) - page_id = metadata.get('page_id', '') - + page_title = metadata.get("page_title", doc.title) + page_id = metadata.get("page_id", "") + title = f"Notion: {page_title}" - description = doc.content[:100] + "..." if len(doc.content) > 100 else doc.content - url = f"https://notion.so/{page_id.replace('-', '')}" if page_id else "" - + description = ( + doc.content[:100] + "..." + if len(doc.content) > 100 + else doc.content + ) + url = ( + f"https://notion.so/{page_id.replace('-', '')}" + if page_id + else "" + ) + elif doc_type == "GITHUB_CONNECTOR": title = f"GitHub: {doc.title}" - description = metadata.get('description', doc.content[:100] + "..." if len(doc.content) > 100 else doc.content) - url = metadata.get('url', '') - + description = metadata.get( + "description", + doc.content[:100] + "..." 
+ if len(doc.content) > 100 + else doc.content, + ) + url = metadata.get("url", "") + elif doc_type == "YOUTUBE_VIDEO": # Extract YouTube-specific metadata - video_title = metadata.get('video_title', doc.title) - video_id = metadata.get('video_id', '') - channel_name = metadata.get('channel_name', '') - + video_title = metadata.get("video_title", doc.title) + video_id = metadata.get("video_id", "") + channel_name = metadata.get("channel_name", "") + title = video_title if channel_name: title += f" - {channel_name}" - - description = metadata.get('description', doc.content[:100] + "..." if len(doc.content) > 100 else doc.content) - url = f"https://www.youtube.com/watch?v={video_id}" if video_id else "" - + + description = metadata.get( + "description", + doc.content[:100] + "..." + if len(doc.content) > 100 + else doc.content, + ) + url = ( + f"https://www.youtube.com/watch?v={video_id}" + if video_id + else "" + ) + elif doc_type == "DISCORD_CONNECTOR": # Extract Discord-specific metadata - channel_name = metadata.get('channel_name', 'Unknown Channel') - channel_id = metadata.get('channel_id', '') - guild_id = metadata.get('guild_id', '') - message_date = metadata.get('start_date', '') - + channel_name = metadata.get("channel_name", "Unknown Channel") + channel_id = metadata.get("channel_id", "") + guild_id = metadata.get("guild_id", "") + message_date = metadata.get("start_date", "") + title = f"Discord: {channel_name}" if message_date: title += f" ({message_date})" - - description = doc.content[:100] + "..." if len(doc.content) > 100 else doc.content - + + description = ( + doc.content[:100] + "..." + if len(doc.content) > 100 + else doc.content + ) + if guild_id and channel_id: url = f"https://discord.com/channels/{guild_id}/{channel_id}" elif channel_id: url = f"https://discord.com/channels/@me/{channel_id}" else: url = "" - + elif doc_type == "EXTENSION": # Extract Extension-specific metadata - webpage_title = metadata.get('VisitedWebPageTitle', doc.title) - webpage_url = metadata.get('VisitedWebPageURL', '') - visit_date = metadata.get('VisitedWebPageDateWithTimeInISOString', '') - + webpage_title = metadata.get("VisitedWebPageTitle", doc.title) + webpage_url = metadata.get("VisitedWebPageURL", "") + visit_date = metadata.get( + "VisitedWebPageDateWithTimeInISOString", "" + ) + title = webpage_title if visit_date: - formatted_date = visit_date.split('T')[0] if 'T' in visit_date else visit_date + formatted_date = ( + visit_date.split("T")[0] + if "T" in visit_date + else visit_date + ) title += f" (visited: {formatted_date})" - - description = doc.content[:100] + "..." if len(doc.content) > 100 else doc.content + + description = ( + doc.content[:100] + "..." + if len(doc.content) > 100 + else doc.content + ) url = webpage_url - + elif doc_type == "CRAWLED_URL": title = doc.title - description = metadata.get('og:description', metadata.get('ogDescription', doc.content[:100] + "..." if len(doc.content) > 100 else doc.content)) - url = metadata.get('url', '') - + description = metadata.get( + "og:description", + metadata.get( + "ogDescription", + doc.content[:100] + "..." + if len(doc.content) > 100 + else doc.content, + ), + ) + url = metadata.get("url", "") + else: # FILE and other types title = doc.title - description = doc.content[:100] + "..." if len(doc.content) > 100 else doc.content - url = metadata.get('url', '') - + description = ( + doc.content[:100] + "..." 
+ if len(doc.content) > 100 + else doc.content + ) + url = metadata.get("url", "") + # Create source entry source = { "id": doc.id, "title": title, "description": description, - "url": url + "url": url, } sources_list.append(source) - + # Create source object for this document type friendly_type_names = { "LINEAR_CONNECTOR": "Linear Issues (Selected)", @@ -229,9 +303,9 @@ async def fetch_documents_by_ids( "DISCORD_CONNECTOR": "Discord (Selected)", "EXTENSION": "Browser Extension (Selected)", "CRAWLED_URL": "Web Pages (Selected)", - "FILE": "Files (Selected)" + "FILE": "Files (Selected)", } - + source_object = { "id": connector_id_counter, "name": friendly_type_names.get(doc_type, f"{doc_type} (Selected)"), @@ -240,31 +314,34 @@ async def fetch_documents_by_ids( } source_objects.append(source_object) connector_id_counter += 1 - - print(f"Fetched {len(formatted_documents)} user-selected documents (with concatenated chunks) from {len(document_ids)} requested IDs") + + print( + f"Fetched {len(formatted_documents)} user-selected documents (with concatenated chunks) from {len(document_ids)} requested IDs" + ) print(f"Created {len(source_objects)} source objects for UI display") - + return source_objects, formatted_documents - + except Exception as e: - print(f"Error fetching documents by IDs: {str(e)}") + print(f"Error fetching documents by IDs: {e!s}") return [], [] -async def write_answer_outline(state: State, config: RunnableConfig, writer: StreamWriter) -> Dict[str, Any]: +async def write_answer_outline( + state: State, config: RunnableConfig, writer: StreamWriter +) -> dict[str, Any]: """ Create a structured answer outline based on the user query. - + This node takes the user query and number of sections from the configuration and uses an LLM to generate a comprehensive outline with logical sections and research questions for each section. - + Returns: Dict containing the answer outline in the "answer_outline" key for state update. 
""" from app.services.llm_service import get_user_strategic_llm - from app.db import get_async_session - + streaming_service = state.streaming_service writer( @@ -331,9 +408,9 @@ async def write_answer_outline(state: State, config: RunnableConfig, writer: Str # Create messages for the LLM messages = [ SystemMessage(content=get_answer_outline_system_prompt()), - HumanMessage(content=human_message_content) + HumanMessage(content=human_message_content), ] - + # Call the LLM directly without using structured output writer( { @@ -344,26 +421,28 @@ async def write_answer_outline(state: State, config: RunnableConfig, writer: Str ) response = await llm.ainvoke(messages) - + # Parse the JSON response manually try: # Extract JSON content from the response content = response.content - + # Find the JSON in the content (handle case where LLM might add additional text) - json_start = content.find('{') - json_end = content.rfind('}') + 1 + json_start = content.find("{") + json_end = content.rfind("}") + 1 if json_start >= 0 and json_end > json_start: json_str = content[json_start:json_end] - + # Parse the JSON string parsed_data = json.loads(json_str) - + # Convert to Pydantic model answer_outline = AnswerOutline(**parsed_data) - - total_questions = sum(len(section.questions) for section in answer_outline.answer_outline) - + + total_questions = sum( + len(section.questions) for section in answer_outline.answer_outline + ) + writer( { "yield_value": streaming_service.format_terminal_info_delta( @@ -388,35 +467,35 @@ async def write_answer_outline(state: State, config: RunnableConfig, writer: Str except (json.JSONDecodeError, ValueError) as e: # Log the error and re-raise it - error_message = f"Error parsing LLM response: {str(e)}" + error_message = f"Error parsing LLM response: {e!s}" writer({"yield_value": streaming_service.format_error(error_message)}) - print(f"Error parsing LLM response: {str(e)}") + print(f"Error parsing LLM response: {e!s}") print(f"Raw response: {response.content}") raise async def fetch_relevant_documents( - research_questions: List[str], + research_questions: list[str], user_id: str, search_space_id: int, db_session: AsyncSession, - connectors_to_search: List[str], + connectors_to_search: list[str], writer: StreamWriter = None, state: State = None, top_k: int = 10, connector_service: ConnectorService = None, search_mode: SearchMode = SearchMode.CHUNKS, - user_selected_sources: List[Dict[str, Any]] = None -) -> List[Dict[str, Any]]: + user_selected_sources: list[dict[str, Any]] | None = None, +) -> list[dict[str, Any]]: """ Fetch relevant documents for research questions using the provided connectors. - + This function searches across multiple data sources for information related to the research questions. It provides user-friendly feedback during the search process by displaying connector names (like "Web Search" instead of "TAVILY_API") and adding relevant emojis to indicate the type of source being searched. 
- + Args: research_questions: List of research questions to find documents for user_id: The user ID @@ -427,19 +506,21 @@ async def fetch_relevant_documents( state: The current state containing the streaming service top_k: Number of top results to retrieve per connector per question connector_service: An initialized connector service to use for searching - + Returns: List of relevant documents """ # Initialize services # connector_service = ConnectorService(db_session) - + # Only use streaming if both writer and state are provided streaming_service = state.streaming_service if state is not None else None # Stream initial status update if streaming_service and writer: - connector_names = [get_connector_friendly_name(connector) for connector in connectors_to_search] + connector_names = [ + get_connector_friendly_name(connector) for connector in connectors_to_search + ] connector_names_str = ", ".join(connector_names) writer( { @@ -451,7 +532,7 @@ async def fetch_relevant_documents( all_raw_documents = [] # Store all raw documents all_sources = [] # Store all sources - + for i, user_query in enumerate(research_questions): # Stream question being researched if streaming_service and writer: @@ -465,7 +546,7 @@ async def fetch_relevant_documents( # Use original research question as the query reformulated_query = user_query - + # Process each selected connector for connector in connectors_to_search: # Stream connector being searched @@ -482,19 +563,22 @@ async def fetch_relevant_documents( try: if connector == "YOUTUBE_VIDEO": - source_object, youtube_chunks = await connector_service.search_youtube( + ( + source_object, + youtube_chunks, + ) = await connector_service.search_youtube( user_query=reformulated_query, user_id=user_id, search_space_id=search_space_id, top_k=top_k, - search_mode=search_mode + search_mode=search_mode, ) - + # Add to sources and raw documents if source_object: all_sources.append(source_object) all_raw_documents.extend(youtube_chunks) - + # Stream found document count if streaming_service and writer: writer( @@ -506,19 +590,22 @@ async def fetch_relevant_documents( ) elif connector == "EXTENSION": - source_object, extension_chunks = await connector_service.search_extension( + ( + source_object, + extension_chunks, + ) = await connector_service.search_extension( user_query=reformulated_query, user_id=user_id, search_space_id=search_space_id, top_k=top_k, - search_mode=search_mode + search_mode=search_mode, ) - + # Add to sources and raw documents if source_object: all_sources.append(source_object) all_raw_documents.extend(extension_chunks) - + # Stream found document count if streaming_service and writer: writer( @@ -530,19 +617,22 @@ async def fetch_relevant_documents( ) elif connector == "CRAWLED_URL": - source_object, crawled_urls_chunks = await connector_service.search_crawled_urls( + ( + source_object, + crawled_urls_chunks, + ) = await connector_service.search_crawled_urls( user_query=reformulated_query, user_id=user_id, search_space_id=search_space_id, top_k=top_k, - search_mode=search_mode + search_mode=search_mode, ) - + # Add to sources and raw documents if source_object: all_sources.append(source_object) all_raw_documents.extend(crawled_urls_chunks) - + # Stream found document count if streaming_service and writer: writer( @@ -559,14 +649,14 @@ async def fetch_relevant_documents( user_id=user_id, search_space_id=search_space_id, top_k=top_k, - search_mode=search_mode + search_mode=search_mode, ) - + # Add to sources and raw documents if source_object: 
all_sources.append(source_object) all_raw_documents.extend(files_chunks) - + # Stream found document count if streaming_service and writer: writer( @@ -583,14 +673,14 @@ async def fetch_relevant_documents( user_id=user_id, search_space_id=search_space_id, top_k=top_k, - search_mode=search_mode + search_mode=search_mode, ) - + # Add to sources and raw documents if source_object: all_sources.append(source_object) all_raw_documents.extend(slack_chunks) - + # Stream found document count if streaming_service and writer: writer( @@ -602,19 +692,22 @@ async def fetch_relevant_documents( ) elif connector == "NOTION_CONNECTOR": - source_object, notion_chunks = await connector_service.search_notion( + ( + source_object, + notion_chunks, + ) = await connector_service.search_notion( user_query=reformulated_query, user_id=user_id, search_space_id=search_space_id, top_k=top_k, - search_mode=search_mode + search_mode=search_mode, ) - + # Add to sources and raw documents if source_object: all_sources.append(source_object) all_raw_documents.extend(notion_chunks) - + # Stream found document count if streaming_service and writer: writer( @@ -626,19 +719,22 @@ async def fetch_relevant_documents( ) elif connector == "GITHUB_CONNECTOR": - source_object, github_chunks = await connector_service.search_github( + ( + source_object, + github_chunks, + ) = await connector_service.search_github( user_query=reformulated_query, user_id=user_id, search_space_id=search_space_id, top_k=top_k, - search_mode=search_mode + search_mode=search_mode, ) - + # Add to sources and raw documents if source_object: all_sources.append(source_object) all_raw_documents.extend(github_chunks) - + # Stream found document count if streaming_service and writer: writer( @@ -650,19 +746,22 @@ async def fetch_relevant_documents( ) elif connector == "LINEAR_CONNECTOR": - source_object, linear_chunks = await connector_service.search_linear( + ( + source_object, + linear_chunks, + ) = await connector_service.search_linear( user_query=reformulated_query, user_id=user_id, search_space_id=search_space_id, top_k=top_k, - search_mode=search_mode + search_mode=search_mode, ) - + # Add to sources and raw documents if source_object: all_sources.append(source_object) all_raw_documents.extend(linear_chunks) - + # Stream found document count if streaming_service and writer: writer( @@ -674,17 +773,18 @@ async def fetch_relevant_documents( ) elif connector == "TAVILY_API": - source_object, tavily_chunks = await connector_service.search_tavily( - user_query=reformulated_query, - user_id=user_id, - top_k=top_k + ( + source_object, + tavily_chunks, + ) = await connector_service.search_tavily( + user_query=reformulated_query, user_id=user_id, top_k=top_k ) - + # Add to sources and raw documents if source_object: all_sources.append(source_object) all_raw_documents.extend(tavily_chunks) - + # Stream found document count if streaming_service and writer: writer( @@ -697,18 +797,19 @@ async def fetch_relevant_documents( elif connector == "LINKUP_API": linkup_mode = "standard" - - source_object, linkup_chunks = await connector_service.search_linkup( - user_query=reformulated_query, - user_id=user_id, - mode=linkup_mode - ) - + + ( + source_object, + linkup_chunks, + ) = await connector_service.search_linkup( + user_query=reformulated_query, user_id=user_id, mode=linkup_mode + ) + # Add to sources and raw documents if source_object: all_sources.append(source_object) - all_raw_documents.extend(linkup_chunks) - + all_raw_documents.extend(linkup_chunks) + # Stream found 
document count if streaming_service and writer: writer( @@ -720,12 +821,15 @@ async def fetch_relevant_documents( ) elif connector == "DISCORD_CONNECTOR": - source_object, discord_chunks = await connector_service.search_discord( + ( + source_object, + discord_chunks, + ) = await connector_service.search_discord( user_query=reformulated_query, user_id=user_id, search_space_id=search_space_id, top_k=top_k, - search_mode=search_mode + search_mode=search_mode, ) # Add to sources and raw documents if source_object: @@ -742,33 +846,33 @@ async def fetch_relevant_documents( ) except Exception as e: - error_message = f"Error searching connector {connector}: {str(e)}" + error_message = f"Error searching connector {connector}: {e!s}" print(error_message) - + # Stream error message if streaming_service and writer: friendly_name = get_connector_friendly_name(connector) writer( { "yield_value": streaming_service.format_error( - f"Error searching {friendly_name}: {str(e)}" + f"Error searching {friendly_name}: {e!s}" ) } ) # Continue with other connectors on error continue - + # Deduplicate source objects by ID before streaming deduplicated_sources = [] seen_source_keys = set() - + # First add user-selected sources (if any) if user_selected_sources: for source_obj in user_selected_sources: - source_id = source_obj.get('id') - source_type = source_obj.get('type') - + source_id = source_obj.get("id") + source_type = source_obj.get("type") + if source_id and source_type: source_key = f"{source_type}_{source_id}" if source_key not in seen_source_keys: @@ -776,47 +880,59 @@ async def fetch_relevant_documents( deduplicated_sources.append(source_obj) else: deduplicated_sources.append(source_obj) - + # Then add connector sources for source_obj in all_sources: # Use combination of source ID and type as a unique identifier # This ensures we don't accidentally deduplicate sources from different connectors - source_id = source_obj.get('id') - source_type = source_obj.get('type') - + source_id = source_obj.get("id") + source_type = source_obj.get("type") + if source_id and source_type: source_key = f"{source_type}_{source_id}" - current_sources_count = len(source_obj.get('sources', [])) - + current_sources_count = len(source_obj.get("sources", [])) + if source_key not in seen_source_keys: seen_source_keys.add(source_key) deduplicated_sources.append(source_obj) - print(f"Debug: Added source - ID: {source_id}, Type: {source_type}, Key: {source_key}, Sources count: {current_sources_count}") + print( + f"Debug: Added source - ID: {source_id}, Type: {source_type}, Key: {source_key}, Sources count: {current_sources_count}" + ) else: # Check if this source object has more sources than the existing one existing_index = None for i, existing_source in enumerate(deduplicated_sources): - existing_id = existing_source.get('id') - existing_type = existing_source.get('type') + existing_id = existing_source.get("id") + existing_type = existing_source.get("type") if existing_id == source_id and existing_type == source_type: existing_index = i break - + if existing_index is not None: - existing_sources_count = len(deduplicated_sources[existing_index].get('sources', [])) + existing_sources_count = len( + deduplicated_sources[existing_index].get("sources", []) + ) if current_sources_count > existing_sources_count: # Replace the existing source object with the new one that has more sources deduplicated_sources[existing_index] = source_obj - print(f"Debug: Replaced source - ID: {source_id}, Type: {source_type}, Key: {source_key}, Sources 
count: {existing_sources_count} -> {current_sources_count}") + print( + f"Debug: Replaced source - ID: {source_id}, Type: {source_type}, Key: {source_key}, Sources count: {existing_sources_count} -> {current_sources_count}" + ) else: - print(f"Debug: Skipped duplicate source - ID: {source_id}, Type: {source_type}, Key: {source_key}, Sources count: {current_sources_count} <= {existing_sources_count}") + print( + f"Debug: Skipped duplicate source - ID: {source_id}, Type: {source_type}, Key: {source_key}, Sources count: {current_sources_count} <= {existing_sources_count}" + ) else: - print(f"Debug: Skipped duplicate source - ID: {source_id}, Type: {source_type}, Key: {source_key} (couldn't find existing)") + print( + f"Debug: Skipped duplicate source - ID: {source_id}, Type: {source_type}, Key: {source_key} (couldn't find existing)" + ) else: # If there's no ID or type, just add it to be safe deduplicated_sources.append(source_obj) - print(f"Debug: Added source without ID/type - {source_obj.get('name', 'UNKNOWN')}") - + print( + f"Debug: Added source without ID/type - {source_obj.get('name', 'UNKNOWN')}" + ) + # Stream info about deduplicated sources if streaming_service and writer: user_source_count = len(user_selected_sources) if user_selected_sources else 0 @@ -831,28 +947,36 @@ async def fetch_relevant_documents( # After all sources are collected and deduplicated, stream them if streaming_service and writer: - writer({"yield_value": streaming_service.format_sources_delta(deduplicated_sources)}) + writer( + { + "yield_value": streaming_service.format_sources_delta( + deduplicated_sources + ) + } + ) # Deduplicate raw documents based on chunk_id or content seen_chunk_ids = set() seen_content_hashes = set() deduplicated_docs = [] - + for doc in all_raw_documents: chunk_id = doc.get("chunk_id") content = doc.get("content", "") content_hash = hash(content) - + # Skip if we've seen this chunk_id or content before - if (chunk_id and chunk_id in seen_chunk_ids) or content_hash in seen_content_hashes: + if ( + chunk_id and chunk_id in seen_chunk_ids + ) or content_hash in seen_content_hashes: continue - + # Add to our tracking sets and keep this document if chunk_id: seen_chunk_ids.add(chunk_id) seen_content_hashes.add(content_hash) deduplicated_docs.append(doc) - + # Stream info about deduplicated documents if streaming_service and writer: writer( @@ -867,14 +991,16 @@ async def fetch_relevant_documents( return deduplicated_docs -async def process_sections(state: State, config: RunnableConfig, writer: StreamWriter) -> Dict[str, Any]: +async def process_sections( + state: State, config: RunnableConfig, writer: StreamWriter +) -> dict[str, Any]: """ Process all sections in parallel and combine the results. - - This node takes the answer outline from the previous step, fetches relevant documents - for all questions across all sections once, and then processes each section in parallel + + This node takes the answer outline from the previous step, fetches relevant documents + for all questions across all sections once, and then processes each section in parallel using the sub_section_writer graph with the shared document pool. - + Returns: Dict containing the final written report in the "final_written_report" key. 
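The docstring above describes the fan-out strategy: fetch one shared document pool, then draft every section concurrently. Below is a minimal stand-alone sketch of that pattern (illustrative, not code from this patch); write_section and the section titles are stand-ins, and asyncio.gather(..., return_exceptions=True) is what keeps one failing section from aborting the rest.

import asyncio


async def write_section(title: str, shared_docs: list[dict]) -> str:
    # Stand-in for the real sub-section writer invocation.
    await asyncio.sleep(0)
    return f"# {title}\n(drafted from {len(shared_docs)} shared documents)"


async def process_all_sections(titles: list[str], shared_docs: list[dict]) -> list[str]:
    # Fan out one task per section; return_exceptions=True means a single
    # failure is reported per section instead of cancelling the whole batch.
    tasks = [write_section(t, shared_docs) for t in titles]
    results = await asyncio.gather(*tasks, return_exceptions=True)

    processed: list[str] = []
    for title, result in zip(titles, results, strict=False):
        if isinstance(result, Exception):
            processed.append(f"Error processing section '{title}': {result!s}")
        else:
            processed.append(result)
    return processed


if __name__ == "__main__":
    sections = ["Introduction", "Findings", "Conclusion"]
    print(asyncio.run(process_all_sections(sections, shared_docs=[{"content": "..."}])))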
""" @@ -882,7 +1008,7 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW configuration = Configuration.from_runnable_config(config) answer_outline = state.answer_outline streaming_service = state.streaming_service - + # Initialize a dictionary to track content for all sections # This is used to maintain section content while streaming multiple sections section_contents = {} @@ -896,19 +1022,19 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW ) print(f"Processing sections from outline: {answer_outline is not None}") - + if not answer_outline: error_message = "No answer outline was provided. Cannot generate report." writer({"yield_value": streaming_service.format_error(error_message)}) return { "final_written_report": "No answer outline was provided. Cannot generate final report." } - + # Collect all questions from all sections all_questions = [] for section in answer_outline.answer_outline: all_questions.extend(section.questions) - + print(f"Collected {len(all_questions)} questions from all sections") writer( { @@ -928,18 +1054,18 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW ) if configuration.num_sections == 1: - TOP_K = 10 + top_k = 10 elif configuration.num_sections == 3: - TOP_K = 20 + top_k = 20 elif configuration.num_sections == 6: - TOP_K = 30 + top_k = 30 else: - TOP_K = 10 - + top_k = 10 + relevant_documents = [] user_selected_documents = [] user_selected_sources = [] - + try: # First, fetch user-selected documents if any if configuration.document_ids_to_add_in_context: @@ -951,12 +1077,15 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW } ) - user_selected_sources, user_selected_documents = await fetch_documents_by_ids( + ( + user_selected_sources, + user_selected_documents, + ) = await fetch_documents_by_ids( document_ids=configuration.document_ids_to_add_in_context, user_id=configuration.user_id, - db_session=state.db_session + db_session=state.db_session, ) - + if user_selected_documents: writer( { @@ -967,9 +1096,11 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW ) # Create connector service using state db_session - connector_service = ConnectorService(state.db_session, user_id=configuration.user_id) + connector_service = ConnectorService( + state.db_session, user_id=configuration.user_id + ) await connector_service.initialize_counter() - + relevant_documents = await fetch_relevant_documents( research_questions=all_questions, user_id=configuration.user_id, @@ -978,24 +1109,26 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW connectors_to_search=configuration.connectors_to_search, writer=writer, state=state, - top_k=TOP_K, + top_k=top_k, connector_service=connector_service, search_mode=configuration.search_mode, - user_selected_sources=user_selected_sources + user_selected_sources=user_selected_sources, ) except Exception as e: - error_message = f"Error fetching relevant documents: {str(e)}" + error_message = f"Error fetching relevant documents: {e!s}" print(error_message) writer({"yield_value": streaming_service.format_error(error_message)}) # Log the error and continue with an empty list of documents # This allows the process to continue, but the report might lack information relevant_documents = [] - + # Combine user-selected documents with connector-fetched documents all_documents = user_selected_documents + relevant_documents - + print(f"Fetched 
{len(relevant_documents)} relevant documents for all sections") - print(f"Added {len(user_selected_documents)} user-selected documents for all sections") + print( + f"Added {len(user_selected_documents)} user-selected documents for all sections" + ) print(f"Total documents for sections: {len(all_documents)}") writer( @@ -1023,14 +1156,14 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW sub_section_type = SubSectionType.END else: sub_section_type = SubSectionType.MIDDLE - + # Initialize the section_contents entry for this section section_contents[i] = { "title": section.section_title, "content": "", - "index": i + "index": i, } - + section_tasks.append( process_section_with_documents( section_id=i, @@ -1043,10 +1176,10 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW state=state, writer=writer, sub_section_type=sub_section_type, - section_contents=section_contents + section_contents=section_contents, ) ) - + # Run all section processing tasks in parallel print(f"Running {len(section_tasks)} section processing tasks in parallel") writer( @@ -1058,7 +1191,7 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW ) section_results = await asyncio.gather(*section_tasks, return_exceptions=True) - + # Handle any exceptions in the results writer( { @@ -1072,28 +1205,31 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW for i, result in enumerate(section_results): if isinstance(result, Exception): section_title = answer_outline.answer_outline[i].section_title - error_message = f"Error processing section '{section_title}': {str(result)}" + error_message = f"Error processing section '{section_title}': {result!s}" print(error_message) writer({"yield_value": streaming_service.format_error(error_message)}) processed_results.append(error_message) else: processed_results.append(result) - + # Combine the results into a final report with section titles final_report = [] - for i, (section, content) in enumerate(zip(answer_outline.answer_outline, processed_results)): + for _, (section, content) in enumerate( + zip(answer_outline.answer_outline, processed_results, strict=False) + ): # Skip adding the section header since the content already contains the title final_report.append(content) - final_report.append("\n") - + final_report.append("\n") + # Stream each section with its title writer( { - "yield_value": state.streaming_service.format_text_chunk(f"# {section.section_title}\n\n{content}") + "yield_value": state.streaming_service.format_text_chunk( + f"# {section.section_title}\n\n{content}" + ) } ) - # Join all sections with newlines final_written_report = "\n".join(final_report) print(f"Generated final report with {len(final_report)} parts") @@ -1110,26 +1246,26 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW # Since all sections used the same document pool, we can use it directly return { "final_written_report": final_written_report, - "reranked_documents": all_documents + "reranked_documents": all_documents, } async def process_section_with_documents( section_id: int, - section_title: str, - section_questions: List[str], - user_id: str, - search_space_id: int, - relevant_documents: List[Dict[str, Any]], + section_title: str, + section_questions: list[str], + user_id: str, + search_space_id: int, + relevant_documents: list[dict[str, Any]], user_query: str, state: State = None, writer: StreamWriter = None, sub_section_type: SubSectionType = 
SubSectionType.MIDDLE, - section_contents: Dict[int, Dict[str, Any]] = None + section_contents: dict[int, dict[str, Any]] | None = None, ) -> str: """ Process a single section using pre-fetched documents. - + Args: section_id: The ID of the section section_title: The title of the section @@ -1141,14 +1277,14 @@ async def process_section_with_documents( writer: StreamWriter for sending progress updates sub_section_type: The type of section (start, middle, end) section_contents: Dictionary to track content across multiple sections - + Returns: The written section content """ try: # Use the provided documents documents_to_use = relevant_documents - + # Send status update via streaming if available if state and state.streaming_service and writer: writer( @@ -1175,7 +1311,7 @@ async def process_section_with_documents( {"content": f"No specific information was found for: {question}"} for question in section_questions ] - + # Call the sub_section_writer graph with the appropriate config config = { "configurable": { @@ -1188,13 +1324,10 @@ async def process_section_with_documents( "search_space_id": search_space_id, } } - + # Create the initial state with db_session and chat_history - sub_state = { - "db_session": state.db_session, - "chat_history": state.chat_history - } - + sub_state = {"db_session": state.db_session, "chat_history": state.chat_history} + # Invoke the sub-section writer graph with streaming print(f"Invoking sub_section_writer for: {section_title}") if state and state.streaming_service and writer: @@ -1208,17 +1341,19 @@ async def process_section_with_documents( # Variables to track streaming state complete_content = "" # Tracks the complete content received so far - - async for chunk_type, chunk in sub_section_writer_graph.astream(sub_state, config, stream_mode=["values"]): + + async for _chunk_type, chunk in sub_section_writer_graph.astream( + sub_state, config, stream_mode=["values"] + ): if "final_answer" in chunk: new_content = chunk["final_answer"] if new_content and new_content != complete_content: # Extract only the new content (delta) - delta = new_content[len(complete_content):] - + delta = new_content[len(complete_content) :] + # Update what we've processed so far complete_content = new_content - + # Only stream if there's actual new content if delta and state and state.streaming_service and writer: # Update terminal with real-time progress indicator @@ -1232,26 +1367,29 @@ async def process_section_with_documents( # Update section_contents with just the new delta section_contents[section_id]["content"] += delta - + # Build UI-friendly content for all sections complete_answer = [] for i in range(len(section_contents)): if i in section_contents and section_contents[i]["content"]: # Add section header - complete_answer.append(f"# {section_contents[i]['title']}") + complete_answer.append( + f"# {section_contents[i]['title']}" + ) complete_answer.append("") # Empty line after title - + # Add section content - content_lines = section_contents[i]["content"].split("\n") + content_lines = section_contents[i]["content"].split( + "\n" + ) complete_answer.extend(content_lines) complete_answer.append("") # Empty line after content - # Set default if no content was received if not complete_content: complete_content = "No content was generated for this section." 
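The section writer above consumes a shared, pre-fetched document pool. The sketch below mirrors the de-duplication rule used by fetch_relevant_documents earlier in this file: a document is dropped when its chunk_id or its exact content has already been seen, and user-selected documents are merged in first so they win ties. This is an illustration, not code from this patch, and the dict shapes are simplified assumptions.

def merge_document_pool(user_selected: list[dict], retrieved: list[dict]) -> list[dict]:
    """Combine user-selected and connector-retrieved documents, dropping duplicates."""
    seen_chunk_ids = set()
    seen_content_hashes = set()
    merged: list[dict] = []

    for doc in [*user_selected, *retrieved]:
        chunk_id = doc.get("chunk_id")
        content_hash = hash(doc.get("content", ""))
        # Skip anything whose chunk id or exact content was already kept.
        if (chunk_id and chunk_id in seen_chunk_ids) or content_hash in seen_content_hashes:
            continue
        if chunk_id:
            seen_chunk_ids.add(chunk_id)
        seen_content_hashes.add(content_hash)
        merged.append(doc)
    return merged


docs = merge_document_pool(
    user_selected=[{"chunk_id": 1, "content": "alpha"}],
    retrieved=[{"chunk_id": 1, "content": "alpha"}, {"chunk_id": 2, "content": "beta"}],
)
assert [d["chunk_id"] for d in docs] == [1, 2]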
section_contents[section_id]["content"] = complete_content - + # Final terminal update if state and state.streaming_service and writer: writer( @@ -1264,52 +1402,61 @@ async def process_section_with_documents( return complete_content except Exception as e: - print(f"Error processing section '{section_title}': {str(e)}") - + print(f"Error processing section '{section_title}': {e!s}") + # Send error update via streaming if available if state and state.streaming_service and writer: writer( { "yield_value": state.streaming_service.format_error( - f'Error processing section "{section_title}": {str(e)}' + f'Error processing section "{section_title}": {e!s}' ) } ) - return f"Error processing section: {section_title}. Details: {str(e)}" + return f"Error processing section: {section_title}. Details: {e!s}" -async def reformulate_user_query(state: State, config: RunnableConfig, writer: StreamWriter) -> Dict[str, Any]: +async def reformulate_user_query( + state: State, config: RunnableConfig, writer: StreamWriter +) -> dict[str, Any]: """ Reforms the user query based on the chat history. """ - + configuration = Configuration.from_runnable_config(config) user_query = configuration.user_query - chat_history_str = await QueryService.langchain_chat_history_to_str(state.chat_history) - if len(state.chat_history) == 0: + chat_history_str = await QueryService.langchain_chat_history_to_str( + state.chat_history + ) + if len(state.chat_history) == 0: reformulated_query = user_query else: - reformulated_query = await QueryService.reformulate_query_with_chat_history(user_query=user_query, session=state.db_session, user_id=configuration.user_id, chat_history_str=chat_history_str) - - return { - "reformulated_query": reformulated_query - } + reformulated_query = await QueryService.reformulate_query_with_chat_history( + user_query=user_query, + session=state.db_session, + user_id=configuration.user_id, + chat_history_str=chat_history_str, + ) + + return {"reformulated_query": reformulated_query} -async def handle_qna_workflow(state: State, config: RunnableConfig, writer: StreamWriter) -> Dict[str, Any]: +async def handle_qna_workflow( + state: State, config: RunnableConfig, writer: StreamWriter +) -> dict[str, Any]: """ Handle the QNA research workflow. - + This node fetches relevant documents for the user query and then uses the QNA agent to generate a comprehensive answer with proper citations. - + Returns: Dict containing the final answer in the "final_written_report" key for consistency. 
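Both the section writer above and this QnA workflow stream cumulative answers and forward only the newly added suffix to the UI. A dependency-free sketch of that delta computation, assuming each streamed value is the cumulative text so far (illustrative, not code from this patch):

from collections.abc import Iterable, Iterator


def extract_deltas(cumulative_updates: Iterable[str]) -> Iterator[str]:
    """Yield only the new suffix from a stream of cumulative strings."""
    complete_content = ""
    for new_content in cumulative_updates:
        if new_content and new_content != complete_content:
            delta = new_content[len(complete_content) :]
            complete_content = new_content
            if delta:
                yield delta


assert list(extract_deltas(["Hel", "Hello", "Hello, world"])) == ["Hel", "lo", ", world"]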
""" streaming_service = state.streaming_service configuration = Configuration.from_runnable_config(config) - + reformulated_query = state.reformulated_query user_query = configuration.user_query @@ -1339,12 +1486,12 @@ async def handle_qna_workflow(state: State, config: RunnableConfig, writer: Stre ) # Use a reasonable top_k for QNA - not too many documents to avoid overwhelming the LLM - TOP_K = 15 - + top_k = 15 + relevant_documents = [] user_selected_documents = [] user_selected_sources = [] - + try: # First, fetch user-selected documents if any if configuration.document_ids_to_add_in_context: @@ -1356,12 +1503,15 @@ async def handle_qna_workflow(state: State, config: RunnableConfig, writer: Stre } ) - user_selected_sources, user_selected_documents = await fetch_documents_by_ids( + ( + user_selected_sources, + user_selected_documents, + ) = await fetch_documents_by_ids( document_ids=configuration.document_ids_to_add_in_context, user_id=configuration.user_id, - db_session=state.db_session + db_session=state.db_session, ) - + if user_selected_documents: writer( { @@ -1372,12 +1522,14 @@ async def handle_qna_workflow(state: State, config: RunnableConfig, writer: Stre ) # Create connector service using state db_session - connector_service = ConnectorService(state.db_session, user_id=configuration.user_id) + connector_service = ConnectorService( + state.db_session, user_id=configuration.user_id + ) await connector_service.initialize_counter() - + # Use the reformulated query as a single research question research_questions = [reformulated_query, user_query] - + relevant_documents = await fetch_relevant_documents( research_questions=research_questions, user_id=configuration.user_id, @@ -1386,21 +1538,21 @@ async def handle_qna_workflow(state: State, config: RunnableConfig, writer: Stre connectors_to_search=configuration.connectors_to_search, writer=writer, state=state, - top_k=TOP_K, + top_k=top_k, connector_service=connector_service, search_mode=configuration.search_mode, - user_selected_sources=user_selected_sources + user_selected_sources=user_selected_sources, ) except Exception as e: - error_message = f"Error fetching relevant documents for QNA: {str(e)}" + error_message = f"Error fetching relevant documents for QNA: {e!s}" print(error_message) writer({"yield_value": streaming_service.format_error(error_message)}) # Continue with empty documents - the QNA agent will handle this gracefully relevant_documents = [] - + # Combine user-selected documents with connector-fetched documents all_documents = user_selected_documents + relevant_documents - + print(f"Fetched {len(relevant_documents)} relevant documents for QNA") print(f"Added {len(user_selected_documents)} user-selected documents for QNA") print(f"Total documents for QNA: {len(all_documents)}") @@ -1420,16 +1572,13 @@ async def handle_qna_workflow(state: State, config: RunnableConfig, writer: Stre "reformulated_query": reformulated_query, "relevant_documents": all_documents, # Use combined documents "user_id": configuration.user_id, - "search_space_id": configuration.search_space_id + "search_space_id": configuration.search_space_id, } } - + # Create the state for the QNA agent (it has a different state structure) - qna_state = { - "db_session": state.db_session, - "chat_history": state.chat_history - } - + qna_state = {"db_session": state.db_session, "chat_history": state.chat_history} + try: writer( { @@ -1442,16 +1591,18 @@ async def handle_qna_workflow(state: State, config: RunnableConfig, writer: Stre # Track streaming content 
for real-time updates complete_content = "" captured_reranked_documents = [] - + # Call the QNA agent with streaming - async for _chunk_type, chunk in qna_agent_graph.astream(qna_state, qna_config, stream_mode=["values"]): + async for _chunk_type, chunk in qna_agent_graph.astream( + qna_state, qna_config, stream_mode=["values"] + ): if "final_answer" in chunk: new_content = chunk["final_answer"] if new_content and new_content != complete_content: # Extract only the new content (delta) - delta = new_content[len(complete_content):] + delta = new_content[len(complete_content) :] complete_content = new_content - + # Stream the real-time answer if there's new content if delta: # Update terminal with progress @@ -1471,7 +1622,7 @@ async def handle_qna_workflow(state: State, config: RunnableConfig, writer: Stre # Capture reranked documents from QNA agent for further question generation if "reranked_documents" in chunk: captured_reranked_documents = chunk["reranked_documents"] - + # Set default if no content was received if not complete_content: complete_content = "I couldn't find relevant information in your knowledge base to answer this question." @@ -1487,38 +1638,40 @@ async def handle_qna_workflow(state: State, config: RunnableConfig, writer: Stre # Return the final answer and captured reranked documents for further question generation return { "final_written_report": complete_content, - "reranked_documents": captured_reranked_documents + "reranked_documents": captured_reranked_documents, } - + except Exception as e: - error_message = f"Error generating QNA answer: {str(e)}" + error_message = f"Error generating QNA answer: {e!s}" print(error_message) writer({"yield_value": streaming_service.format_error(error_message)}) - return {"final_written_report": f"Error generating answer: {str(e)}"} + return {"final_written_report": f"Error generating answer: {e!s}"} -async def generate_further_questions(state: State, config: RunnableConfig, writer: StreamWriter) -> Dict[str, Any]: +async def generate_further_questions( + state: State, config: RunnableConfig, writer: StreamWriter +) -> dict[str, Any]: """ Generate contextually relevant follow-up questions based on chat history and available documents. - + This node takes the chat history and reranked documents from sub-agents (qna_agent or sub_section_writer) and uses an LLM to generate follow-up questions that would naturally extend the conversation and provide additional value to the user. - + Returns: Dict containing the further questions in the "further_questions" key for state update. 
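This function asks the LLM for a JSON object and then has to locate and parse it inside free-form model output. A small defensive-parsing sketch in the same spirit (illustrative only; the key name further_questions matches the code that follows, everything else is a stand-in):

import json


def parse_further_questions(raw_response: str) -> list[dict]:
    """Best-effort extraction of {"further_questions": [...]} from model output."""
    json_start = raw_response.find("{")
    json_end = raw_response.rfind("}") + 1
    if json_start < 0 or json_end <= json_start:
        return []
    try:
        parsed = json.loads(raw_response[json_start:json_end])
    except (json.JSONDecodeError, ValueError):
        # Malformed JSON: fall back to "no suggestions" rather than failing the request.
        return []
    return parsed.get("further_questions", [])


reply = 'Sure! {"further_questions": [{"id": 0, "question": "What changed?"}]} Hope that helps.'
assert parse_further_questions(reply)[0]["question"] == "What changed?"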
""" from app.services.llm_service import get_user_fast_llm - + # Get configuration and state data configuration = Configuration.from_runnable_config(config) chat_history = state.chat_history user_id = configuration.user_id streaming_service = state.streaming_service - + # Get reranked documents from the state (will be populated by sub-agents) - reranked_documents = getattr(state, 'reranked_documents', None) or [] + reranked_documents = getattr(state, "reranked_documents", None) or [] writer( { @@ -1538,20 +1691,20 @@ async def generate_further_questions(state: State, config: RunnableConfig, write # Stream empty further questions to UI writer({"yield_value": streaming_service.format_further_questions_delta([])}) return {"further_questions": []} - + # Format chat history for the prompt chat_history_xml = "\n" for message in chat_history: - if hasattr(message, 'type'): + if hasattr(message, "type"): if message.type == "human": chat_history_xml += f"{message.content}\n" elif message.type == "ai": chat_history_xml += f"{message.content}\n" else: # Handle other message types if needed - chat_history_xml += f"{str(message)}\n" + chat_history_xml += f"{message!s}\n" chat_history_xml += "" - + # Format available documents for the prompt documents_xml = "\n" for i, doc in enumerate(reranked_documents): @@ -1559,16 +1712,16 @@ async def generate_further_questions(state: State, config: RunnableConfig, write source_id = document_info.get("id", f"doc_{i}") source_type = document_info.get("document_type", "UNKNOWN") content = doc.get("content", "") - - documents_xml += f"\n" - documents_xml += f"\n" + + documents_xml += "\n" + documents_xml += "\n" documents_xml += f"{source_id}\n" documents_xml += f"{source_type}\n" - documents_xml += f"\n" + documents_xml += "\n" documents_xml += f"\n{content}\n" - documents_xml += f"\n" + documents_xml += "\n" documents_xml += "" - + # Create the human message content human_message_content = f""" {chat_history_xml} @@ -1605,25 +1758,25 @@ async def generate_further_questions(state: State, config: RunnableConfig, write # Create messages for the LLM messages = [ SystemMessage(content=get_further_questions_system_prompt()), - HumanMessage(content=human_message_content) + HumanMessage(content=human_message_content), ] - + try: # Call the LLM response = await llm.ainvoke(messages) - + # Parse the JSON response content = response.content - + # Find the JSON in the content - json_start = content.find('{') - json_end = content.rfind('}') + 1 + json_start = content.find("{") + json_end = content.rfind("}") + 1 if json_start >= 0 and json_end > json_start: json_str = content[json_start:json_end] - + # Parse the JSON string parsed_data = json.loads(json_str) - + # Extract the further_questions array further_questions = parsed_data.get("further_questions", []) @@ -1645,7 +1798,7 @@ async def generate_further_questions(state: State, config: RunnableConfig, write ) print(f"Successfully generated {len(further_questions)} further questions") - + return {"further_questions": further_questions} else: # If JSON structure not found, return empty list @@ -1666,10 +1819,10 @@ async def generate_further_questions(state: State, config: RunnableConfig, write {"yield_value": streaming_service.format_further_questions_delta([])} ) return {"further_questions": []} - + except (json.JSONDecodeError, ValueError) as e: # Log the error and return empty list - error_message = f"Error parsing further questions response: {str(e)}" + error_message = f"Error parsing further questions response: {e!s}" 
print(error_message) writer( {"yield_value": streaming_service.format_error(f"Warning: {error_message}")} @@ -1678,10 +1831,10 @@ async def generate_further_questions(state: State, config: RunnableConfig, write # Stream empty further questions to UI writer({"yield_value": streaming_service.format_further_questions_delta([])}) return {"further_questions": []} - + except Exception as e: # Handle any other errors - error_message = f"Error generating further questions: {str(e)}" + error_message = f"Error generating further questions: {e!s}" print(error_message) writer( {"yield_value": streaming_service.format_error(f"Warning: {error_message}")} diff --git a/surfsense_backend/app/agents/researcher/prompts.py b/surfsense_backend/app/agents/researcher/prompts.py index 4270e3f..44b2189 100644 --- a/surfsense_backend/app/agents/researcher/prompts.py +++ b/surfsense_backend/app/agents/researcher/prompts.py @@ -221,4 +221,4 @@ Output: }} -""" \ No newline at end of file +""" diff --git a/surfsense_backend/app/agents/researcher/qna_agent/__init__.py b/surfsense_backend/app/agents/researcher/qna_agent/__init__.py index 33fe6bf..163b8bf 100644 --- a/surfsense_backend/app/agents/researcher/qna_agent/__init__.py +++ b/surfsense_backend/app/agents/researcher/qna_agent/__init__.py @@ -1,5 +1,4 @@ -"""QnA Agent. -""" +"""QnA Agent.""" from .graph import graph diff --git a/surfsense_backend/app/agents/researcher/qna_agent/configuration.py b/surfsense_backend/app/agents/researcher/qna_agent/configuration.py index 0f3c74d..5a4529e 100644 --- a/surfsense_backend/app/agents/researcher/qna_agent/configuration.py +++ b/surfsense_backend/app/agents/researcher/qna_agent/configuration.py @@ -3,7 +3,7 @@ from __future__ import annotations from dataclasses import dataclass, fields -from typing import Optional, List, Any +from typing import Any from langchain_core.runnables import RunnableConfig @@ -15,13 +15,15 @@ class Configuration: # Configuration parameters for the Q&A agent user_query: str # The user's question to answer reformulated_query: str # The reformulated query - relevant_documents: List[Any] # Documents provided directly to the agent for answering + relevant_documents: list[ + Any + ] # Documents provided directly to the agent for answering user_id: str # User identifier search_space_id: int # Search space identifier @classmethod def from_runnable_config( - cls, config: Optional[RunnableConfig] = None + cls, config: RunnableConfig | None = None ) -> Configuration: """Create a Configuration instance from a RunnableConfig object.""" configurable = (config.get("configurable") or {}) if config else {} diff --git a/surfsense_backend/app/agents/researcher/qna_agent/graph.py b/surfsense_backend/app/agents/researcher/qna_agent/graph.py index 788ec4a..0d9c8ba 100644 --- a/surfsense_backend/app/agents/researcher/qna_agent/graph.py +++ b/surfsense_backend/app/agents/researcher/qna_agent/graph.py @@ -1,7 +1,8 @@ from langgraph.graph import StateGraph -from .state import State -from .nodes import rerank_documents, answer_question + from .configuration import Configuration +from .nodes import answer_question, rerank_documents +from .state import State # Define a new graph workflow = StateGraph(State, config_schema=Configuration) diff --git a/surfsense_backend/app/agents/researcher/qna_agent/nodes.py b/surfsense_backend/app/agents/researcher/qna_agent/nodes.py index 910f55b..4bcc042 100644 --- a/surfsense_backend/app/agents/researcher/qna_agent/nodes.py +++ b/surfsense_backend/app/agents/researcher/qna_agent/nodes.py @@ 
-1,24 +1,28 @@ -from app.services.reranker_service import RerankerService -from .configuration import Configuration -from langchain_core.runnables import RunnableConfig -from .state import State -from typing import Any, Dict -from .prompts import get_qna_citation_system_prompt, get_qna_no_documents_system_prompt -from langchain_core.messages import HumanMessage, SystemMessage -from ..utils import ( - optimize_documents_for_token_limit, - calculate_token_count, - format_documents_section -) +from typing import Any -async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, Any]: +from langchain_core.messages import HumanMessage, SystemMessage +from langchain_core.runnables import RunnableConfig + +from app.services.reranker_service import RerankerService + +from ..utils import ( + calculate_token_count, + format_documents_section, + optimize_documents_for_token_limit, +) +from .configuration import Configuration +from .prompts import get_qna_citation_system_prompt, get_qna_no_documents_system_prompt +from .state import State + + +async def rerank_documents(state: State, config: RunnableConfig) -> dict[str, Any]: """ Rerank the documents based on relevance to the user's question. - + This node takes the relevant documents provided in the configuration, reranks them using the reranker service based on the user's query, and updates the state with the reranked documents. - + Returns: Dict containing the reranked documents. """ @@ -30,16 +34,14 @@ async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, An # If no documents were provided, return empty list if not documents or len(documents) == 0: - return { - "reranked_documents": [] - } - + return {"reranked_documents": []} + # Get reranker service from app config reranker_service = RerankerService.get_reranker_instance() - + # Use documents as is if no reranker service is available reranked_docs = documents - + if reranker_service: try: # Convert documents to format expected by reranker if needed @@ -51,58 +53,64 @@ async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, An "document": { "id": doc.get("document", {}).get("id", ""), "title": doc.get("document", {}).get("title", ""), - "document_type": doc.get("document", {}).get("document_type", ""), - "metadata": doc.get("document", {}).get("metadata", {}) - } - } for i, doc in enumerate(documents) + "document_type": doc.get("document", {}).get( + "document_type", "" + ), + "metadata": doc.get("document", {}).get("metadata", {}), + }, + } + for i, doc in enumerate(documents) ] - + # Rerank documents using the user's query - reranked_docs = reranker_service.rerank_documents(user_query + "\n" + reformulated_query, reranker_input_docs) - + reranked_docs = reranker_service.rerank_documents( + user_query + "\n" + reformulated_query, reranker_input_docs + ) + # Sort by score in descending order reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True) - - print(f"Reranked {len(reranked_docs)} documents for Q&A query: {user_query}") - except Exception as e: - print(f"Error during reranking: {str(e)}") - # Use original docs if reranking fails - - return { - "reranked_documents": reranked_docs - } -async def answer_question(state: State, config: RunnableConfig) -> Dict[str, Any]: + print( + f"Reranked {len(reranked_docs)} documents for Q&A query: {user_query}" + ) + except Exception as e: + print(f"Error during reranking: {e!s}") + # Use original docs if reranking fails + + return {"reranked_documents": reranked_docs} + + +async def 
answer_question(state: State, config: RunnableConfig) -> dict[str, Any]: """ Answer the user's question using the provided documents. - + This node takes the relevant documents provided in the configuration and uses an LLM to generate a comprehensive answer to the user's question with proper citations. The citations follow IEEE format using source IDs from the documents. If no documents are provided, it will use chat history to generate an answer. - + Returns: Dict containing the final answer in the "final_answer" key. """ from app.services.llm_service import get_user_fast_llm - + # Get configuration and relevant documents from configuration configuration = Configuration.from_runnable_config(config) documents = state.reranked_documents user_query = configuration.user_query user_id = configuration.user_id - + # Get user's fast LLM llm = await get_user_fast_llm(state.db_session, user_id) if not llm: error_message = f"No fast LLM configured for user {user_id}" print(error_message) raise RuntimeError(error_message) - + # Determine if we have documents and optimize for token limits has_documents_initially = documents and len(documents) > 0 - + if has_documents_initially: # Create base message template for token calculation (without documents) base_human_message_template = f""" @@ -114,41 +122,49 @@ async def answer_question(state: State, config: RunnableConfig) -> Dict[str, Any Please provide a detailed, comprehensive answer to the user's question using the information from their personal knowledge sources. Make sure to cite all information appropriately and engage in a conversational manner. """ - + # Use initial system prompt for token calculation initial_system_prompt = get_qna_citation_system_prompt() - base_messages = state.chat_history + [ + base_messages = [ + *state.chat_history, SystemMessage(content=initial_system_prompt), - HumanMessage(content=base_human_message_template) + HumanMessage(content=base_human_message_template), ] - + # Optimize documents to fit within token limits - optimized_documents, has_optimized_documents = optimize_documents_for_token_limit( - documents, base_messages, llm.model + optimized_documents, has_optimized_documents = ( + optimize_documents_for_token_limit(documents, base_messages, llm.model) ) - + # Update state based on optimization result documents = optimized_documents has_documents = has_optimized_documents else: has_documents = False - + # Choose system prompt based on final document availability - system_prompt = get_qna_citation_system_prompt() if has_documents else get_qna_no_documents_system_prompt() - + system_prompt = ( + get_qna_citation_system_prompt() + if has_documents + else get_qna_no_documents_system_prompt() + ) + # Generate documents section - documents_text = format_documents_section( - documents, - "Source material from your personal knowledge base" - ) if has_documents else "" - + documents_text = ( + format_documents_section( + documents, "Source material from your personal knowledge base" + ) + if has_documents + else "" + ) + # Create final human message content instruction_text = ( "Please provide a detailed, comprehensive answer to the user's question using the information from their personal knowledge sources. Make sure to cite all information appropriately and engage in a conversational manner." - if has_documents else - "Please provide a helpful answer to the user's question based on our conversation history and your general knowledge. Engage in a conversational manner." 
+ if has_documents + else "Please provide a helpful answer to the user's question based on our conversation history and your general knowledge. Engage in a conversational manner." ) - + human_message_content = f""" {documents_text} @@ -159,22 +175,20 @@ async def answer_question(state: State, config: RunnableConfig) -> Dict[str, Any {instruction_text} """ - + # Create final messages for the LLM - messages_with_chat_history = state.chat_history + [ + messages_with_chat_history = [ + *state.chat_history, SystemMessage(content=system_prompt), - HumanMessage(content=human_message_content) + HumanMessage(content=human_message_content), ] - + # Log final token count total_tokens = calculate_token_count(messages_with_chat_history, llm.model) print(f"Final token count: {total_tokens}") - - + # Call the LLM and get the response response = await llm.ainvoke(messages_with_chat_history) final_answer = response.content - - return { - "final_answer": final_answer - } + + return {"final_answer": final_answer} diff --git a/surfsense_backend/app/agents/researcher/qna_agent/state.py b/surfsense_backend/app/agents/researcher/qna_agent/state.py index 69bd843..f6cc7b1 100644 --- a/surfsense_backend/app/agents/researcher/qna_agent/state.py +++ b/surfsense_backend/app/agents/researcher/qna_agent/state.py @@ -3,14 +3,16 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import List, Optional, Any +from typing import Any + from sqlalchemy.ext.asyncio import AsyncSession + @dataclass class State: """Defines the dynamic state for the Q&A agent during execution. - This state tracks the database session, chat history, and the outputs + This state tracks the database session, chat history, and the outputs generated by the agent's nodes during question answering. See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state for more information. @@ -18,8 +20,8 @@ class State: # Runtime context db_session: AsyncSession - - chat_history: Optional[List[Any]] = field(default_factory=list) + + chat_history: list[Any] | None = field(default_factory=list) # OUTPUT: Populated by agent nodes - reranked_documents: Optional[List[Any]] = None - final_answer: Optional[str] = None + reranked_documents: list[Any] | None = None + final_answer: str | None = None diff --git a/surfsense_backend/app/agents/researcher/state.py b/surfsense_backend/app/agents/researcher/state.py index 8f50e30..0e10dfa 100644 --- a/surfsense_backend/app/agents/researcher/state.py +++ b/surfsense_backend/app/agents/researcher/state.py @@ -3,10 +3,13 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import List, Optional, Any +from typing import Any + from sqlalchemy.ext.asyncio import AsyncSession + from app.services.streaming_service import StreamingService + @dataclass class State: """Defines the dynamic state for the agent during execution. @@ -15,23 +18,23 @@ class State: See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state for more information. 
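The State classes in this patch are dataclasses that mix runtime context with optional outputs. A stand-alone illustration (not the project's actual class) of why the list-valued fields use field(default_factory=list) while scalar outputs simply default to None:

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any


@dataclass
class SketchState:
    # Mutable defaults must go through default_factory so each instance
    # gets its own list instead of sharing one class-level object.
    chat_history: list[Any] | None = field(default_factory=list)

    # Outputs start empty and are filled in by later nodes.
    reranked_documents: list[Any] | None = None
    final_answer: str | None = None


a, b = SketchState(), SketchState()
a.chat_history.append("first message")
assert b.chat_history == []  # b is unaffected because the lists are independent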
""" + # Runtime context (not part of actual graph state) db_session: AsyncSession - + # Streaming service streaming_service: StreamingService - - chat_history: Optional[List[Any]] = field(default_factory=list) - - reformulated_query: Optional[str] = field(default=None) + + chat_history: list[Any] | None = field(default_factory=list) + + reformulated_query: str | None = field(default=None) # Using field to explicitly mark as part of state - answer_outline: Optional[Any] = field(default=None) - further_questions: Optional[Any] = field(default=None) - + answer_outline: Any | None = field(default=None) + further_questions: Any | None = field(default=None) + # Temporary field to hold reranked documents from sub-agents for further question generation - reranked_documents: Optional[List[Any]] = field(default=None) - + reranked_documents: list[Any] | None = field(default=None) + # OUTPUT: Populated by agent nodes # Using field to explicitly mark as part of state - final_written_report: Optional[str] = field(default=None) - + final_written_report: str | None = field(default=None) diff --git a/surfsense_backend/app/agents/researcher/sub_section_writer/configuration.py b/surfsense_backend/app/agents/researcher/sub_section_writer/configuration.py index b7acf8b..29cbf45 100644 --- a/surfsense_backend/app/agents/researcher/sub_section_writer/configuration.py +++ b/surfsense_backend/app/agents/researcher/sub_section_writer/configuration.py @@ -4,13 +4,14 @@ from __future__ import annotations from dataclasses import dataclass, fields from enum import Enum -from typing import Optional, List, Any +from typing import Any from langchain_core.runnables import RunnableConfig class SubSectionType(Enum): """Enum defining the type of sub-section.""" + START = "START" MIDDLE = "MIDDLE" END = "END" @@ -22,17 +23,16 @@ class Configuration: # Input parameters provided at invocation sub_section_title: str - sub_section_questions: List[str] + sub_section_questions: list[str] sub_section_type: SubSectionType user_query: str - relevant_documents: List[Any] # Documents provided directly to the agent + relevant_documents: list[Any] # Documents provided directly to the agent user_id: str search_space_id: int - @classmethod def from_runnable_config( - cls, config: Optional[RunnableConfig] = None + cls, config: RunnableConfig | None = None ) -> Configuration: """Create a Configuration instance from a RunnableConfig object.""" configurable = (config.get("configurable") or {}) if config else {} diff --git a/surfsense_backend/app/agents/researcher/sub_section_writer/graph.py b/surfsense_backend/app/agents/researcher/sub_section_writer/graph.py index 5a5a5ba..35ebc4e 100644 --- a/surfsense_backend/app/agents/researcher/sub_section_writer/graph.py +++ b/surfsense_backend/app/agents/researcher/sub_section_writer/graph.py @@ -1,7 +1,8 @@ from langgraph.graph import StateGraph -from .state import State -from .nodes import write_sub_section, rerank_documents + from .configuration import Configuration +from .nodes import rerank_documents, write_sub_section +from .state import State # Define a new graph workflow = StateGraph(State, config_schema=Configuration) diff --git a/surfsense_backend/app/agents/researcher/sub_section_writer/nodes.py b/surfsense_backend/app/agents/researcher/sub_section_writer/nodes.py index 75d3f35..2dee978 100644 --- a/surfsense_backend/app/agents/researcher/sub_section_writer/nodes.py +++ b/surfsense_backend/app/agents/researcher/sub_section_writer/nodes.py @@ -1,25 +1,28 @@ -from .configuration import 
Configuration -from langchain_core.runnables import RunnableConfig -from .state import State -from typing import Any, Dict -from app.services.reranker_service import RerankerService -from .prompts import get_citation_system_prompt, get_no_documents_system_prompt -from langchain_core.messages import HumanMessage, SystemMessage -from .configuration import SubSectionType -from ..utils import ( - optimize_documents_for_token_limit, - calculate_token_count, - format_documents_section -) +from typing import Any -async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, Any]: +from langchain_core.messages import HumanMessage, SystemMessage +from langchain_core.runnables import RunnableConfig + +from app.services.reranker_service import RerankerService + +from ..utils import ( + calculate_token_count, + format_documents_section, + optimize_documents_for_token_limit, +) +from .configuration import Configuration, SubSectionType +from .prompts import get_citation_system_prompt, get_no_documents_system_prompt +from .state import State + + +async def rerank_documents(state: State, config: RunnableConfig) -> dict[str, Any]: """ Rerank the documents based on relevance to the sub-section title. - + This node takes the relevant documents provided in the configuration, reranks them using the reranker service based on the sub-section title, and updates the state with the reranked documents. - + Returns: Dict containing the reranked documents. """ @@ -30,23 +33,23 @@ async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, An # If no documents were provided, return empty list if not documents or len(documents) == 0: - return { - "reranked_documents": [] - } - + return {"reranked_documents": []} + # Get reranker service from app config reranker_service = RerankerService.get_reranker_instance() - + # Use documents as is if no reranker service is available reranked_docs = documents - + if reranker_service: try: # Use the sub-section questions for reranking context # rerank_query = "\n".join(sub_section_questions) # rerank_query = configuration.user_query - - rerank_query = configuration.user_query + "\n" + "\n".join(sub_section_questions) + + rerank_query = ( + configuration.user_query + "\n" + "\n".join(sub_section_questions) + ) # Convert documents to format expected by reranker if needed reranker_input_docs = [ @@ -57,54 +60,60 @@ async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, An "document": { "id": doc.get("document", {}).get("id", ""), "title": doc.get("document", {}).get("title", ""), - "document_type": doc.get("document", {}).get("document_type", ""), - "metadata": doc.get("document", {}).get("metadata", {}) - } - } for i, doc in enumerate(documents) + "document_type": doc.get("document", {}).get( + "document_type", "" + ), + "metadata": doc.get("document", {}).get("metadata", {}), + }, + } + for i, doc in enumerate(documents) ] - + # Rerank documents using the section title - reranked_docs = reranker_service.rerank_documents(rerank_query, reranker_input_docs) - + reranked_docs = reranker_service.rerank_documents( + rerank_query, reranker_input_docs + ) + # Sort by score in descending order reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True) - - print(f"Reranked {len(reranked_docs)} documents for section: {configuration.sub_section_title}") - except Exception as e: - print(f"Error during reranking: {str(e)}") - # Use original docs if reranking fails - - return { - "reranked_documents": reranked_docs - } -async def 
write_sub_section(state: State, config: RunnableConfig) -> Dict[str, Any]: + print( + f"Reranked {len(reranked_docs)} documents for section: {configuration.sub_section_title}" + ) + except Exception as e: + print(f"Error during reranking: {e!s}") + # Use original docs if reranking fails + + return {"reranked_documents": reranked_docs} + + +async def write_sub_section(state: State, config: RunnableConfig) -> dict[str, Any]: """ Write the sub-section using the provided documents. - + This node takes the relevant documents provided in the configuration and uses an LLM to generate a comprehensive answer to the sub-section title with proper citations. The citations follow IEEE format using source IDs from the documents. If no documents are provided, it will use chat history to generate content. - + Returns: Dict containing the final answer in the "final_answer" key. """ from app.services.llm_service import get_user_fast_llm - + # Get configuration and relevant documents from configuration configuration = Configuration.from_runnable_config(config) documents = state.reranked_documents user_id = configuration.user_id - + # Get user's fast LLM llm = await get_user_fast_llm(state.db_session, user_id) if not llm: error_message = f"No fast LLM configured for user {user_id}" print(error_message) raise RuntimeError(error_message) - + # Extract configuration data section_title = configuration.sub_section_title sub_section_questions = configuration.sub_section_questions @@ -113,18 +122,18 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A # Format the questions as bullet points for clarity questions_text = "\n".join([f"- {question}" for question in sub_section_questions]) - + # Provide context based on the subsection type section_position_context_map = { SubSectionType.START: "This is the INTRODUCTION section.", SubSectionType.MIDDLE: "This is a MIDDLE section. Ensure this content flows naturally from previous sections and into subsequent ones. This could be any middle section in the document, so maintain coherence with the overall structure while addressing the specific topic of this section. Do not provide any conclusions in this section, as conclusions should only appear in the final section.", - SubSectionType.END: "This is the CONCLUSION section. Focus on summarizing key points, providing closure." + SubSectionType.END: "This is the CONCLUSION section. Focus on summarizing key points, providing closure.", } section_position_context = section_position_context_map.get(sub_section_type, "") - + # Determine if we have documents and optimize for token limits has_documents_initially = documents and len(documents) > 0 - + if has_documents_initially: # Create base message template for token calculation (without documents) base_human_message_template = f""" @@ -149,38 +158,45 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A Please write content for this sub-section using the provided source material and cite all information appropriately. 
""" - + # Use initial system prompt for token calculation initial_system_prompt = get_citation_system_prompt() - base_messages = state.chat_history + [ + base_messages = [ + *state.chat_history, SystemMessage(content=initial_system_prompt), - HumanMessage(content=base_human_message_template) + HumanMessage(content=base_human_message_template), ] - + # Optimize documents to fit within token limits - optimized_documents, has_optimized_documents = optimize_documents_for_token_limit( - documents, base_messages, llm.model + optimized_documents, has_optimized_documents = ( + optimize_documents_for_token_limit(documents, base_messages, llm.model) ) - + # Update state based on optimization result documents = optimized_documents has_documents = has_optimized_documents else: has_documents = False - + # Choose system prompt based on final document availability - system_prompt = get_citation_system_prompt() if has_documents else get_no_documents_system_prompt() - + system_prompt = ( + get_citation_system_prompt() + if has_documents + else get_no_documents_system_prompt() + ) + # Generate documents section - documents_text = format_documents_section(documents, "Source material") if has_documents else "" - + documents_text = ( + format_documents_section(documents, "Source material") if has_documents else "" + ) + # Create final human message content instruction_text = ( "Please write content for this sub-section using the provided source material and cite all information appropriately." - if has_documents else - "Please write content for this sub-section based on our conversation history and your general knowledge." + if has_documents + else "Please write content for this sub-section based on our conversation history and your general knowledge." ) - + human_message_content = f""" {documents_text} @@ -204,22 +220,20 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A {instruction_text} """ - + # Create final messages for the LLM - messages_with_chat_history = state.chat_history + [ + messages_with_chat_history = [ + *state.chat_history, SystemMessage(content=system_prompt), - HumanMessage(content=human_message_content) + HumanMessage(content=human_message_content), ] - + # Log final token count total_tokens = calculate_token_count(messages_with_chat_history, llm.model) print(f"Final token count: {total_tokens}") - + # Call the LLM and get the response response = await llm.ainvoke(messages_with_chat_history) final_answer = response.content - - return { - "final_answer": final_answer - } + return {"final_answer": final_answer} diff --git a/surfsense_backend/app/agents/researcher/sub_section_writer/prompts.py b/surfsense_backend/app/agents/researcher/sub_section_writer/prompts.py index 48b0c66..02036fc 100644 --- a/surfsense_backend/app/agents/researcher/sub_section_writer/prompts.py +++ b/surfsense_backend/app/agents/researcher/sub_section_writer/prompts.py @@ -182,4 +182,4 @@ When writing content for a sub-section without access to personal documents: 5. Address the guiding questions through natural content flow without explicitly listing them 6. 
Suggest how adding relevant sources to SurfSense could enhance future content when appropriate -""" \ No newline at end of file +""" diff --git a/surfsense_backend/app/agents/researcher/sub_section_writer/state.py b/surfsense_backend/app/agents/researcher/sub_section_writer/state.py index 7998279..6fb5434 100644 --- a/surfsense_backend/app/agents/researcher/sub_section_writer/state.py +++ b/surfsense_backend/app/agents/researcher/sub_section_writer/state.py @@ -3,9 +3,11 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import List, Optional, Any +from typing import Any + from sqlalchemy.ext.asyncio import AsyncSession + @dataclass class State: """Defines the dynamic state for the agent during execution. @@ -14,11 +16,11 @@ class State: See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state for more information. """ + # Runtime context db_session: AsyncSession - - chat_history: Optional[List[Any]] = field(default_factory=list) - # OUTPUT: Populated by agent nodes - reranked_documents: Optional[List[Any]] = None - final_answer: Optional[str] = None + chat_history: list[Any] | None = field(default_factory=list) + # OUTPUT: Populated by agent nodes + reranked_documents: list[Any] | None = None + final_answer: str | None = None diff --git a/surfsense_backend/app/agents/researcher/utils.py b/surfsense_backend/app/agents/researcher/utils.py index c4991cc..908b3ab 100644 --- a/surfsense_backend/app/agents/researcher/utils.py +++ b/surfsense_backend/app/agents/researcher/utils.py @@ -1,27 +1,37 @@ -from typing import List, Dict, Any, Tuple, NamedTuple +from typing import Any, NamedTuple + from langchain_core.messages import BaseMessage +from litellm import get_model_info, token_counter from pydantic import BaseModel, Field -from litellm import token_counter, get_model_info + class Section(BaseModel): """A section in the answer outline.""" + section_id: int = Field(..., description="The zero-based index of the section") section_title: str = Field(..., description="The title of the section") - questions: List[str] = Field(..., description="Questions to research for this section") + questions: list[str] = Field( + ..., description="Questions to research for this section" + ) + class AnswerOutline(BaseModel): """The complete answer outline with all sections.""" - answer_outline: List[Section] = Field(..., description="List of sections in the answer outline") + + answer_outline: list[Section] = Field( + ..., description="List of sections in the answer outline" + ) class DocumentTokenInfo(NamedTuple): """Information about a document and its token cost.""" + index: int - document: Dict[str, Any] + document: dict[str, Any] formatted_content: str token_count: int - - + + def get_connector_emoji(connector_name: str) -> str: """Get an appropriate emoji for a connector type.""" connector_emojis = { @@ -34,7 +44,7 @@ def get_connector_emoji(connector_name: str) -> str: "GITHUB_CONNECTOR": "🐙", "LINEAR_CONNECTOR": "📊", "TAVILY_API": "🔍", - "LINKUP_API": "🔗" + "LINKUP_API": "🔗", } return connector_emojis.get(connector_name, "🔎") @@ -51,31 +61,26 @@ def get_connector_friendly_name(connector_name: str) -> str: "GITHUB_CONNECTOR": "GitHub", "LINEAR_CONNECTOR": "Linear", "TAVILY_API": "Tavily Search", - "LINKUP_API": "Linkup Search" + "LINKUP_API": "Linkup Search", } return connector_friendly_names.get(connector_name, connector_name) -def convert_langchain_messages_to_dict(messages: List[BaseMessage]) -> List[Dict[str, str]]: +def 
convert_langchain_messages_to_dict( + messages: list[BaseMessage], +) -> list[dict[str, str]]: """Convert LangChain messages to format expected by token_counter.""" - role_mapping = { - 'system': 'system', - 'human': 'user', - 'ai': 'assistant' - } + role_mapping = {"system": "system", "human": "user", "ai": "assistant"} converted_messages = [] for msg in messages: - role = role_mapping.get(getattr(msg, 'type', None), 'user') - converted_messages.append({ - "role": role, - "content": str(msg.content) - }) + role = role_mapping.get(getattr(msg, "type", None), "user") + converted_messages.append({"role": role, "content": str(msg.content)}) return converted_messages -def format_document_for_citation(document: Dict[str, Any]) -> str: +def format_document_for_citation(document: dict[str, Any]) -> str: """Format a single document for citation in the standard XML format.""" content = document.get("content", "") doc_info = document.get("document", {}) @@ -93,7 +98,9 @@ def format_document_for_citation(document: Dict[str, Any]) -> str: """ -def format_documents_section(documents: List[Dict[str, Any]], section_title: str = "Source material") -> str: +def format_documents_section( + documents: list[dict[str, Any]], section_title: str = "Source material" +) -> str: """Format multiple documents into a complete documents section.""" if not documents: return "" @@ -106,7 +113,9 @@ def format_documents_section(documents: List[Dict[str, Any]], section_title: str """ -def calculate_document_token_costs(documents: List[Dict[str, Any]], model: str) -> List[DocumentTokenInfo]: +def calculate_document_token_costs( + documents: list[dict[str, Any]], model: str +) -> list[DocumentTokenInfo]: """Pre-calculate token costs for each document.""" document_token_info = [] @@ -115,24 +124,24 @@ def calculate_document_token_costs(documents: List[Dict[str, Any]], model: str) # Calculate token count for this document token_count = token_counter( - messages=[{"role": "user", "content": formatted_doc}], - model=model + messages=[{"role": "user", "content": formatted_doc}], model=model ) - document_token_info.append(DocumentTokenInfo( - index=i, - document=doc, - formatted_content=formatted_doc, - token_count=token_count - )) + document_token_info.append( + DocumentTokenInfo( + index=i, + document=doc, + formatted_content=formatted_doc, + token_count=token_count, + ) + ) return document_token_info def find_optimal_documents_with_binary_search( - document_tokens: List[DocumentTokenInfo], - available_tokens: int -) -> List[DocumentTokenInfo]: + document_tokens: list[DocumentTokenInfo], available_tokens: int +) -> list[DocumentTokenInfo]: """Use binary search to find the maximum number of documents that fit within token limit.""" if not document_tokens or available_tokens <= 0: return [] @@ -143,8 +152,7 @@ def find_optimal_documents_with_binary_search( while left <= right: mid = (left + right) // 2 current_docs = document_tokens[:mid] - current_token_sum = sum( - doc_info.token_count for doc_info in current_docs) + current_token_sum = sum(doc_info.token_count for doc_info in current_docs) if current_token_sum <= available_tokens: optimal_docs = current_docs @@ -159,20 +167,18 @@ def get_model_context_window(model_name: str) -> int: """Get the total context window size for a model (input + output tokens).""" try: model_info = get_model_info(model_name) - context_window = model_info.get( - 'max_input_tokens', 4096) # Default fallback + context_window = model_info.get("max_input_tokens", 4096) # Default fallback return 
context_window except Exception as e: print( - f"Warning: Could not get model info for {model_name}, using default 4096 tokens. Error: {e}") + f"Warning: Could not get model info for {model_name}, using default 4096 tokens. Error: {e}" + ) return 4096 # Conservative fallback def optimize_documents_for_token_limit( - documents: List[Dict[str, Any]], - base_messages: List[BaseMessage], - model_name: str -) -> Tuple[List[Dict[str, Any]], bool]: + documents: list[dict[str, Any]], base_messages: list[BaseMessage], model_name: str +) -> tuple[list[dict[str, Any]], bool]: """ Optimize documents to fit within token limits using binary search. @@ -197,7 +203,8 @@ def optimize_documents_for_token_limit( available_tokens_for_docs = context_window - base_tokens print( - f"Token optimization: Context window={context_window}, Base={base_tokens}, Available for docs={available_tokens_for_docs}") + f"Token optimization: Context window={context_window}, Base={base_tokens}, Available for docs={available_tokens_for_docs}" + ) if available_tokens_for_docs <= 0: print("No tokens available for documents after base content and output buffer") @@ -208,8 +215,7 @@ def optimize_documents_for_token_limit( # Find optimal number of documents using binary search optimal_doc_info = find_optimal_documents_with_binary_search( - document_token_info, - available_tokens_for_docs + document_token_info, available_tokens_for_docs ) # Extract the original document objects @@ -217,12 +223,13 @@ def optimize_documents_for_token_limit( has_documents_remaining = len(optimized_documents) > 0 print( - f"Token optimization result: Using {len(optimized_documents)}/{len(documents)} documents") + f"Token optimization result: Using {len(optimized_documents)}/{len(documents)} documents" + ) return optimized_documents, has_documents_remaining -def calculate_token_count(messages: List[BaseMessage], model_name: str) -> int: +def calculate_token_count(messages: list[BaseMessage], model_name: str) -> int: """Calculate token count for a list of LangChain messages.""" model = model_name messages_dict = convert_langchain_messages_to_dict(messages) diff --git a/surfsense_backend/app/app.py b/surfsense_backend/app/app.py index 956740f..17f9082 100644 --- a/surfsense_backend/app/app.py +++ b/surfsense_backend/app/app.py @@ -2,22 +2,13 @@ from contextlib import asynccontextmanager from fastapi import Depends, FastAPI from fastapi.middleware.cors import CORSMiddleware - from sqlalchemy.ext.asyncio import AsyncSession -from app.db import User, create_db_and_tables, get_async_session -from app.schemas import UserCreate, UserRead, UserUpdate - - -from app.routes import router as crud_router from app.config import config - -from app.users import ( - SECRET, - auth_backend, - fastapi_users, - current_active_user -) +from app.db import User, create_db_and_tables, get_async_session +from app.routes import router as crud_router +from app.schemas import UserCreate, UserRead, UserUpdate +from app.users import SECRET, auth_backend, current_active_user, fastapi_users @asynccontextmanager @@ -64,12 +55,10 @@ app.include_router( if config.AUTH_TYPE == "GOOGLE": from app.users import google_oauth_client + app.include_router( fastapi_users.get_oauth_router( - google_oauth_client, - auth_backend, - SECRET, - is_verified_by_default=True + google_oauth_client, auth_backend, SECRET, is_verified_by_default=True ), prefix="/auth/google", tags=["auth"], @@ -79,5 +68,8 @@ app.include_router(crud_router, prefix="/api/v1", tags=["crud"]) @app.get("/verify-token") -async def 
authenticated_route(user: User = Depends(current_active_user), session: AsyncSession = Depends(get_async_session)): +async def authenticated_route( + user: User = Depends(current_active_user), + session: AsyncSession = Depends(get_async_session), +): return {"message": "Token is valid"} diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index 06c89ae..48de86d 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -1,13 +1,11 @@ import os -from pathlib import Path import shutil +from pathlib import Path from chonkie import AutoEmbeddings, CodeChunker, RecursiveChunker from dotenv import load_dotenv - from rerankers import Reranker - # Get the base directory of the project BASE_DIR = Path(__file__).resolve().parent.parent.parent @@ -18,37 +16,37 @@ load_dotenv(env_file) def is_ffmpeg_installed(): """ Check if ffmpeg is installed on the current system. - + Returns: bool: True if ffmpeg is installed, False otherwise. """ return shutil.which("ffmpeg") is not None - class Config: # Check if ffmpeg is installed if not is_ffmpeg_installed(): import static_ffmpeg + # ffmpeg installed on first call to add_paths(), threadsafe. static_ffmpeg.add_paths() # check if ffmpeg is installed again if not is_ffmpeg_installed(): - raise ValueError("FFmpeg is not installed on the system. Please install it to use the Surfsense Podcaster.") - + raise ValueError( + "FFmpeg is not installed on the system. Please install it to use the Surfsense Podcaster." + ) + # Database DATABASE_URL = os.getenv("DATABASE_URL") - + NEXT_FRONTEND_URL = os.getenv("NEXT_FRONTEND_URL") - - + # AUTH: Google OAuth AUTH_TYPE = os.getenv("AUTH_TYPE") if AUTH_TYPE == "GOOGLE": GOOGLE_OAUTH_CLIENT_ID = os.getenv("GOOGLE_OAUTH_CLIENT_ID") GOOGLE_OAUTH_CLIENT_SECRET = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET") - - + # LLM instances are now managed per-user through the LLMConfig system # Legacy environment variables removed in favor of user-specific configurations @@ -56,12 +54,12 @@ class Config: EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL") embedding_model_instance = AutoEmbeddings.get_embeddings(EMBEDDING_MODEL) chunker_instance = RecursiveChunker( - chunk_size=getattr(embedding_model_instance, 'max_seq_length', 512) + chunk_size=getattr(embedding_model_instance, "max_seq_length", 512) ) code_chunker_instance = CodeChunker( - chunk_size=getattr(embedding_model_instance, 'max_seq_length', 512) + chunk_size=getattr(embedding_model_instance, "max_seq_length", 512) ) - + # Reranker's Configuration | Pinecode, Cohere etc. 
Read more at https://github.com/AnswerDotAI/rerankers?tab=readme-ov-file#usage RERANKERS_MODEL_NAME = os.getenv("RERANKERS_MODEL_NAME") RERANKERS_MODEL_TYPE = os.getenv("RERANKERS_MODEL_TYPE") @@ -69,45 +67,46 @@ class Config: model_name=RERANKERS_MODEL_NAME, model_type=RERANKERS_MODEL_TYPE, ) - + # OAuth JWT SECRET_KEY = os.getenv("SECRET_KEY") - + # ETL Service ETL_SERVICE = os.getenv("ETL_SERVICE") - + if ETL_SERVICE == "UNSTRUCTURED": # Unstructured API Key UNSTRUCTURED_API_KEY = os.getenv("UNSTRUCTURED_API_KEY") - + elif ETL_SERVICE == "LLAMACLOUD": # LlamaCloud API Key LLAMA_CLOUD_API_KEY = os.getenv("LLAMA_CLOUD_API_KEY") - + # Firecrawl API Key - FIRECRAWL_API_KEY = os.getenv("FIRECRAWL_API_KEY", None) - + FIRECRAWL_API_KEY = os.getenv("FIRECRAWL_API_KEY", None) + # Litellm TTS Configuration TTS_SERVICE = os.getenv("TTS_SERVICE") TTS_SERVICE_API_BASE = os.getenv("TTS_SERVICE_API_BASE") TTS_SERVICE_API_KEY = os.getenv("TTS_SERVICE_API_KEY") - + # Litellm STT Configuration STT_SERVICE = os.getenv("STT_SERVICE") STT_SERVICE_API_BASE = os.getenv("STT_SERVICE_API_BASE") STT_SERVICE_API_KEY = os.getenv("STT_SERVICE_API_KEY") - - + # Validation Checks # Check embedding dimension - if hasattr(embedding_model_instance, 'dimension') and embedding_model_instance.dimension > 2000: + if ( + hasattr(embedding_model_instance, "dimension") + and embedding_model_instance.dimension > 2000 + ): raise ValueError( f"Embedding dimension for Model: {EMBEDDING_MODEL} " f"has {embedding_model_instance.dimension} dimensions, which " f"exceeds the maximum of 2000 allowed by PGVector." ) - @classmethod def get_settings(cls): """Get all settings as a dictionary.""" diff --git a/surfsense_backend/app/config/uvicorn.py b/surfsense_backend/app/config/uvicorn.py index f7086e1..7f28333 100644 --- a/surfsense_backend/app/config/uvicorn.py +++ b/surfsense_backend/app/config/uvicorn.py @@ -1,26 +1,25 @@ import os + def _parse_bool(value): """Parse boolean value from string.""" return value.lower() == "true" if value else False + def _parse_int(value, var_name): """Parse integer value with error handling.""" try: return int(value) except ValueError: - raise ValueError(f"Invalid integer value for {var_name}: {value}") + raise ValueError(f"Invalid integer value for {var_name}: {value}") from None + def _parse_headers(value): """Parse headers from comma-separated string.""" try: - return [ - tuple(h.split(":", 1)) - for h in value.split(",") - if ":" in h - ] + return [tuple(h.split(":", 1)) for h in value.split(",") if ":" in h] except Exception: - raise ValueError(f"Invalid headers format: {value}") + raise ValueError(f"Invalid headers format: {value}") from None def load_uvicorn_config(args=None): @@ -28,16 +27,16 @@ def load_uvicorn_config(args=None): Load Uvicorn configuration from environment variables and CLI args. Returns a dict suitable for passing to uvicorn.Config. 
""" - config_kwargs = dict( - app="app.app:app", - host=os.getenv("UVICORN_HOST", "0.0.0.0"), - port=int(os.getenv("UVICORN_PORT", 8000)), - log_level=os.getenv("UVICORN_LOG_LEVEL", "info"), - reload=args.reload if args else False, - reload_dirs=["app"] if (args and args.reload) else None, - ) - - # Configuration mapping for advanced options + config_kwargs = { + "app": "app.app:app", + "host": os.getenv("UVICORN_HOST", "0.0.0.0"), + "port": int(os.getenv("UVICORN_PORT", 8000)), + "log_level": os.getenv("UVICORN_LOG_LEVEL", "info"), + "reload": args.reload if args else False, + "reload_dirs": ["app"] if (args and args.reload) else None, + } + + # Configuration mapping for advanced options config_mapping = { "UVICORN_PROXY_HEADERS": ("proxy_headers", _parse_bool), "UVICORN_FORWARDED_ALLOW_IPS": ("forwarded_allow_ips", str), @@ -51,15 +50,33 @@ def load_uvicorn_config(args=None): "UVICORN_LOG_CONFIG": ("log_config", str), "UVICORN_SERVER_HEADER": ("server_header", _parse_bool), "UVICORN_DATE_HEADER": ("date_header", _parse_bool), - "UVICORN_LIMIT_CONCURRENCY": ("limit_concurrency", lambda x: _parse_int(x, "UVICORN_LIMIT_CONCURRENCY")), - "UVICORN_LIMIT_MAX_REQUESTS": ("limit_max_requests", lambda x: _parse_int(x, "UVICORN_LIMIT_MAX_REQUESTS")), - "UVICORN_TIMEOUT_KEEP_ALIVE": ("timeout_keep_alive", lambda x: _parse_int(x, "UVICORN_TIMEOUT_KEEP_ALIVE")), - "UVICORN_TIMEOUT_NOTIFY": ("timeout_notify", lambda x: _parse_int(x, "UVICORN_TIMEOUT_NOTIFY")), + "UVICORN_LIMIT_CONCURRENCY": ( + "limit_concurrency", + lambda x: _parse_int(x, "UVICORN_LIMIT_CONCURRENCY"), + ), + "UVICORN_LIMIT_MAX_REQUESTS": ( + "limit_max_requests", + lambda x: _parse_int(x, "UVICORN_LIMIT_MAX_REQUESTS"), + ), + "UVICORN_TIMEOUT_KEEP_ALIVE": ( + "timeout_keep_alive", + lambda x: _parse_int(x, "UVICORN_TIMEOUT_KEEP_ALIVE"), + ), + "UVICORN_TIMEOUT_NOTIFY": ( + "timeout_notify", + lambda x: _parse_int(x, "UVICORN_TIMEOUT_NOTIFY"), + ), "UVICORN_SSL_KEYFILE": ("ssl_keyfile", str), "UVICORN_SSL_CERTFILE": ("ssl_certfile", str), "UVICORN_SSL_KEYFILE_PASSWORD": ("ssl_keyfile_password", str), - "UVICORN_SSL_VERSION": ("ssl_version", lambda x: _parse_int(x, "UVICORN_SSL_VERSION")), - "UVICORN_SSL_CERT_REQS": ("ssl_cert_reqs", lambda x: _parse_int(x, "UVICORN_SSL_CERT_REQS")), + "UVICORN_SSL_VERSION": ( + "ssl_version", + lambda x: _parse_int(x, "UVICORN_SSL_VERSION"), + ), + "UVICORN_SSL_CERT_REQS": ( + "ssl_cert_reqs", + lambda x: _parse_int(x, "UVICORN_SSL_CERT_REQS"), + ), "UVICORN_SSL_CA_CERTS": ("ssl_ca_certs", str), "UVICORN_SSL_CIPHERS": ("ssl_ciphers", str), "UVICORN_HEADERS": ("headers", _parse_headers), @@ -76,7 +93,6 @@ def load_uvicorn_config(args=None): try: config_kwargs[config_key] = parser(value) except ValueError as e: - raise ValueError(f"Configuration error for {env_var}: {e}") - + raise ValueError(f"Configuration error for {env_var}: {e}") from e return config_kwargs diff --git a/surfsense_backend/app/connectors/discord_connector.py b/surfsense_backend/app/connectors/discord_connector.py index 1d5c1fb..506b463 100644 --- a/surfsense_backend/app/connectors/discord_connector.py +++ b/surfsense_backend/app/connectors/discord_connector.py @@ -6,11 +6,12 @@ A module for interacting with Discord's HTTP API to retrieve guilds, channels, a Requires a Discord bot token. 
""" +import asyncio +import datetime import logging + import discord from discord.ext import commands -import datetime -import asyncio logger = logging.getLogger(__name__) @@ -18,7 +19,7 @@ logger = logging.getLogger(__name__) class DiscordConnector(commands.Bot): """Class for retrieving guild, channel, and message history from Discord.""" - def __init__(self, token: str = None): + def __init__(self, token: str | None = None): """ Initialize the DiscordConnector with a bot token. @@ -30,7 +31,9 @@ class DiscordConnector(commands.Bot): intents.messages = True # Required to fetch messages intents.message_content = True # Required to read message content intents.members = True # Required to fetch member information - super().__init__(command_prefix="!", intents=intents) # command_prefix is required but not strictly used here + super().__init__( + command_prefix="!", intents=intents + ) # command_prefix is required but not strictly used here self.token = token self._bot_task = None # Holds the async bot task self._is_running = False # Flag to track if the bot is running @@ -48,7 +51,7 @@ class DiscordConnector(commands.Bot): @self.event async def on_disconnect(): logger.debug("Bot disconnected from Discord gateway.") - self._is_running = False # Reset flag on disconnect + self._is_running = False # Reset flag on disconnect @self.event async def on_resumed(): @@ -63,17 +66,23 @@ class DiscordConnector(commands.Bot): try: if self._is_running: - logger.warning("Bot is already running. Use close_bot() to stop it before starting again.") + logger.warning( + "Bot is already running. Use close_bot() to stop it before starting again." + ) return await self.start(self.token) logger.info("Discord bot started successfully.") except discord.LoginFailure: - logger.error("Failed to log in: Invalid token was provided. Please check your bot token.") + logger.error( + "Failed to log in: Invalid token was provided. Please check your bot token." + ) self._is_running = False raise except discord.PrivilegedIntentsRequired as e: - logger.error(f"Privileged Intents Required: {e}. Make sure all required intents are enabled in your bot's application page.") + logger.error( + f"Privileged Intents Required: {e}. Make sure all required intents are enabled in your bot's application page." + ) self._is_running = False raise except discord.ConnectionClosed as e: @@ -96,7 +105,6 @@ class DiscordConnector(commands.Bot): else: logger.info("Bot is not running or already disconnected.") - def set_token(self, token: str) -> None: """ Set the discord bot token. @@ -106,8 +114,10 @@ class DiscordConnector(commands.Bot): """ logger.info("Setting Discord bot token.") self.token = token - logger.info("Token set successfully. You can now start the bot with start_bot().") - + logger.info( + "Token set successfully. You can now start the bot with start_bot()." + ) + async def _wait_until_ready(self): """Helper to wait until the bot is connected and ready.""" logger.info("Waiting for the bot to be ready...") @@ -115,16 +125,20 @@ class DiscordConnector(commands.Bot): # Give the event loop a chance to switch to the bot's startup task. # This allows self.start() to begin initializing the client. # Terrible solution, but necessary to avoid blocking the event loop. 
- await asyncio.sleep(1) # Yield control to the event loop - + await asyncio.sleep(1) # Yield control to the event loop + try: await asyncio.wait_for(self.wait_until_ready(), timeout=60.0) logger.info("Bot is ready.") - except asyncio.TimeoutError: - logger.error(f"Bot did not become ready within 60 seconds. Connection may have failed.") + except TimeoutError: + logger.error( + "Bot did not become ready within 60 seconds. Connection may have failed." + ) raise except Exception as e: - logger.error(f"An unexpected error occurred while waiting for the bot to be ready: {e}") + logger.error( + f"An unexpected error occurred while waiting for the bot to be ready: {e}" + ) raise async def get_guilds(self) -> list[dict]: @@ -143,7 +157,9 @@ class DiscordConnector(commands.Bot): guilds_data = [] for guild in self.guilds: - member_count = guild.member_count if guild.member_count is not None else "N/A" + member_count = ( + guild.member_count if guild.member_count is not None else "N/A" + ) guilds_data.append( { "id": str(guild.id), @@ -183,15 +199,17 @@ class DiscordConnector(commands.Bot): channels_data.append( {"id": str(channel.id), "name": channel.name, "type": "text"} ) - - logger.info(f"Fetched {len(channels_data)} text channels from guild {guild_id}.") + + logger.info( + f"Fetched {len(channels_data)} text channels from guild {guild_id}." + ) return channels_data async def get_channel_history( self, channel_id: str, - start_date: str = None, - end_date: str = None, + start_date: str | None = None, + end_date: str | None = None, ) -> list[dict]: """ Fetch message history from a text channel. @@ -227,20 +245,26 @@ class DiscordConnector(commands.Bot): if start_date: try: - start_datetime = datetime.datetime.fromisoformat(start_date).replace(tzinfo=datetime.timezone.utc) + start_datetime = datetime.datetime.fromisoformat(start_date).replace( + tzinfo=datetime.UTC + ) after = start_datetime except ValueError: logger.warning(f"Invalid start_date format: {start_date}. Ignoring.") if end_date: try: - end_datetime = datetime.datetime.fromisoformat(f"{end_date}").replace(tzinfo=datetime.timezone.utc) + end_datetime = datetime.datetime.fromisoformat(f"{end_date}").replace( + tzinfo=datetime.UTC + ) before = end_datetime except ValueError: logger.warning(f"Invalid end_date format: {end_date}. Ignoring.") try: - async for message in channel.history(limit=None, before=before, after=after): + async for message in channel.history( + limit=None, before=before, after=after + ): messages_data.append( { "id": str(message.id), @@ -251,12 +275,14 @@ class DiscordConnector(commands.Bot): } ) except discord.Forbidden: - logger.error(f"Bot does not have permissions to read message history in channel {channel_id}.") + logger.error( + f"Bot does not have permissions to read message history in channel {channel_id}." + ) raise except discord.HTTPException as e: logger.error(f"Failed to fetch messages from channel {channel_id}: {e}") return [] - + logger.info(f"Fetched {len(messages_data)} messages from channel {channel_id}.") return messages_data @@ -278,7 +304,9 @@ class DiscordConnector(commands.Bot): permissions to view members. 
""" await self._wait_until_ready() - logger.info(f"Fetching user info for user ID: {user_id} in guild ID: {guild_id}") + logger.info( + f"Fetching user info for user ID: {user_id} in guild ID: {guild_id}" + ) guild = self.get_guild(int(guild_id)) if not guild: @@ -294,7 +322,9 @@ class DiscordConnector(commands.Bot): return { "id": str(member.id), "name": member.name, - "joined_at": member.joined_at.isoformat() if member.joined_at else None, + "joined_at": member.joined_at.isoformat() + if member.joined_at + else None, "roles": roles, } logger.warning(f"User {user_id} not found in guild {guild_id}.") @@ -303,8 +333,12 @@ class DiscordConnector(commands.Bot): logger.warning(f"User {user_id} not found in guild {guild_id}.") return None except discord.Forbidden: - logger.error(f"Bot does not have permissions to fetch members in guild {guild_id}. Ensure GUILD_MEMBERS intent is enabled.") + logger.error( + f"Bot does not have permissions to fetch members in guild {guild_id}. Ensure GUILD_MEMBERS intent is enabled." + ) raise except discord.HTTPException as e: - logger.error(f"Failed to fetch user info for {user_id} in guild {guild_id}: {e}") + logger.error( + f"Failed to fetch user info for {user_id} in guild {guild_id}: {e}" + ) return None diff --git a/surfsense_backend/app/connectors/github_connector.py b/surfsense_backend/app/connectors/github_connector.py index 6434f1e..647856c 100644 --- a/surfsense_backend/app/connectors/github_connector.py +++ b/surfsense_backend/app/connectors/github_connector.py @@ -1,54 +1,91 @@ import base64 import logging -from typing import List, Optional, Dict, Any -from github3 import login as github_login, exceptions as github_exceptions -from github3.repos.contents import Contents +from typing import Any + +from github3 import exceptions as github_exceptions, login as github_login from github3.exceptions import ForbiddenError, NotFoundError +from github3.repos.contents import Contents logger = logging.getLogger(__name__) # List of common code file extensions to target CODE_EXTENSIONS = { - '.py', '.js', '.jsx', '.ts', '.tsx', '.java', '.c', '.cpp', '.h', '.hpp', - '.cs', '.go', '.rb', '.php', '.swift', '.kt', '.scala', '.rs', '.m', - '.sh', '.bash', '.ps1', '.lua', '.pl', '.pm', '.r', '.dart', '.sql' + ".py", + ".js", + ".jsx", + ".ts", + ".tsx", + ".java", + ".c", + ".cpp", + ".h", + ".hpp", + ".cs", + ".go", + ".rb", + ".php", + ".swift", + ".kt", + ".scala", + ".rs", + ".m", + ".sh", + ".bash", + ".ps1", + ".lua", + ".pl", + ".pm", + ".r", + ".dart", + ".sql", } # List of common documentation/text file extensions DOC_EXTENSIONS = { - '.md', '.txt', '.rst', '.adoc', '.html', '.htm', '.xml', '.json', '.yaml', '.yml', '.toml' + ".md", + ".txt", + ".rst", + ".adoc", + ".html", + ".htm", + ".xml", + ".json", + ".yaml", + ".yml", + ".toml", } # Maximum file size in bytes (e.g., 1MB) MAX_FILE_SIZE = 1 * 1024 * 1024 + class GitHubConnector: """Connector for interacting with the GitHub API.""" # Directories to skip during file traversal SKIPPED_DIRS = { # Version control - '.git', + ".git", # Dependencies - 'node_modules', - 'vendor', + "node_modules", + "vendor", # Build artifacts / Caches - 'build', - 'dist', - 'target', - '__pycache__', + "build", + "dist", + "target", + "__pycache__", # Virtual environments - 'venv', - '.venv', - 'env', + "venv", + ".venv", + "env", # IDE/Editor config - '.vscode', - '.idea', - '.project', - '.settings', + ".vscode", + ".idea", + ".project", + ".settings", # Temporary / Logs - 'tmp', - 'logs', + "tmp", + "logs", # Add other 
project-specific irrelevant directories if needed } @@ -68,35 +105,39 @@ class GitHubConnector: logger.info("Successfully authenticated with GitHub API.") except (github_exceptions.AuthenticationFailed, ForbiddenError) as e: logger.error(f"GitHub authentication failed: {e}") - raise ValueError("Invalid GitHub token or insufficient permissions.") + raise ValueError("Invalid GitHub token or insufficient permissions.") from e except Exception as e: logger.error(f"Failed to initialize GitHub client: {e}") - raise + raise e - def get_user_repositories(self) -> List[Dict[str, Any]]: + def get_user_repositories(self) -> list[dict[str, Any]]: """Fetches repositories accessible by the authenticated user.""" repos_data = [] try: # type='owner' fetches repos owned by the user # type='member' fetches repos the user is a collaborator on (including orgs) # type='all' fetches both - for repo in self.gh.repositories(type='all', sort='updated'): - repos_data.append({ - "id": repo.id, - "name": repo.name, - "full_name": repo.full_name, - "private": repo.private, - "url": repo.html_url, - "description": repo.description or "", - "last_updated": repo.updated_at if repo.updated_at else None, - }) + for repo in self.gh.repositories(type="all", sort="updated"): + repos_data.append( + { + "id": repo.id, + "name": repo.name, + "full_name": repo.full_name, + "private": repo.private, + "url": repo.html_url, + "description": repo.description or "", + "last_updated": repo.updated_at if repo.updated_at else None, + } + ) logger.info(f"Fetched {len(repos_data)} repositories.") return repos_data except Exception as e: logger.error(f"Failed to fetch GitHub repositories: {e}") - return [] # Return empty list on error + return [] # Return empty list on error - def get_repository_files(self, repo_full_name: str, path: str = '') -> List[Dict[str, Any]]: + def get_repository_files( + self, repo_full_name: str, path: str = "" + ) -> list[dict[str, Any]]: """ Recursively fetches details of relevant files (code, docs) within a repository path. @@ -110,54 +151,72 @@ class GitHubConnector: """ files_list = [] try: - owner, repo_name = repo_full_name.split('/') + owner, repo_name = repo_full_name.split("/") repo = self.gh.repository(owner, repo_name) if not repo: logger.warning(f"Repository '{repo_full_name}' not found.") return [] - contents = repo.directory_contents(directory_path=path) # Use directory_contents for clarity - + contents = repo.directory_contents( + directory_path=path + ) # Use directory_contents for clarity + # contents returns a list of tuples (name, content_obj) - for item_name, content_item in contents: + for _item_name, content_item in contents: if not isinstance(content_item, Contents): continue - if content_item.type == 'dir': + if content_item.type == "dir": # Check if the directory name is in the skipped list if content_item.name in self.SKIPPED_DIRS: logger.debug(f"Skipping directory: {content_item.path}") - continue # Skip recursion for this directory - + continue # Skip recursion for this directory + # Recursively fetch contents of subdirectory - files_list.extend(self.get_repository_files(repo_full_name, path=content_item.path)) - elif content_item.type == 'file': + files_list.extend( + self.get_repository_files( + repo_full_name, path=content_item.path + ) + ) + elif content_item.type == "file": # Check if the file extension is relevant and size is within limits - file_extension = '.' + content_item.name.split('.')[-1].lower() if '.' in content_item.name else '' + file_extension = ( + "." 
+ content_item.name.split(".")[-1].lower() + if "." in content_item.name + else "" + ) is_code = file_extension in CODE_EXTENSIONS is_doc = file_extension in DOC_EXTENSIONS - + if (is_code or is_doc) and content_item.size <= MAX_FILE_SIZE: - files_list.append({ - "path": content_item.path, - "sha": content_item.sha, - "url": content_item.html_url, - "size": content_item.size, - "type": "code" if is_code else "doc" - }) + files_list.append( + { + "path": content_item.path, + "sha": content_item.sha, + "url": content_item.html_url, + "size": content_item.size, + "type": "code" if is_code else "doc", + } + ) elif content_item.size > MAX_FILE_SIZE: - logger.debug(f"Skipping large file: {content_item.path} ({content_item.size} bytes)") + logger.debug( + f"Skipping large file: {content_item.path} ({content_item.size} bytes)" + ) else: - logger.debug(f"Skipping irrelevant file type: {content_item.path}") + logger.debug( + f"Skipping irrelevant file type: {content_item.path}" + ) except (NotFoundError, ForbiddenError) as e: - logger.warning(f"Cannot access path '{path}' in '{repo_full_name}': {e}") + logger.warning(f"Cannot access path '{path}' in '{repo_full_name}': {e}") except Exception as e: - logger.error(f"Failed to get files for {repo_full_name} at path '{path}': {e}") + logger.error( + f"Failed to get files for {repo_full_name} at path '{path}': {e}" + ) # Return what we have collected so far in case of partial failure - + return files_list - def get_file_content(self, repo_full_name: str, file_path: str) -> Optional[str]: + def get_file_content(self, repo_full_name: str, file_path: str) -> str | None: """ Fetches the decoded content of a specific file. @@ -169,43 +228,69 @@ class GitHubConnector: The decoded file content as a string, or None if fetching fails or file is too large. """ try: - owner, repo_name = repo_full_name.split('/') + owner, repo_name = repo_full_name.split("/") repo = self.gh.repository(owner, repo_name) if not repo: - logger.warning(f"Repository '{repo_full_name}' not found when fetching file '{file_path}'.") + logger.warning( + f"Repository '{repo_full_name}' not found when fetching file '{file_path}'." + ) return None - - content_item = repo.file_contents(path=file_path) # Use file_contents for clarity - if not content_item or not isinstance(content_item, Contents) or content_item.type != 'file': - logger.warning(f"File '{file_path}' not found or is not a file in '{repo_full_name}'.") + content_item = repo.file_contents( + path=file_path + ) # Use file_contents for clarity + + if ( + not content_item + or not isinstance(content_item, Contents) + or content_item.type != "file" + ): + logger.warning( + f"File '{file_path}' not found or is not a file in '{repo_full_name}'." + ) return None - + if content_item.size > MAX_FILE_SIZE: - logger.warning(f"File '{file_path}' in '{repo_full_name}' exceeds max size ({content_item.size} > {MAX_FILE_SIZE}). Skipping content fetch.") + logger.warning( + f"File '{file_path}' in '{repo_full_name}' exceeds max size ({content_item.size} > {MAX_FILE_SIZE}). Skipping content fetch." + ) return None # Content is base64 encoded if content_item.content: try: - decoded_content = base64.b64decode(content_item.content).decode('utf-8') + decoded_content = base64.b64decode(content_item.content).decode( + "utf-8" + ) return decoded_content except UnicodeDecodeError: - logger.warning(f"Could not decode file '{file_path}' in '{repo_full_name}' as UTF-8. 
Trying with 'latin-1'.") + logger.warning( + f"Could not decode file '{file_path}' in '{repo_full_name}' as UTF-8. Trying with 'latin-1'." + ) try: # Try a fallback encoding - decoded_content = base64.b64decode(content_item.content).decode('latin-1') + decoded_content = base64.b64decode(content_item.content).decode( + "latin-1" + ) return decoded_content except Exception as decode_err: - logger.error(f"Failed to decode file '{file_path}' with fallback encoding: {decode_err}") - return None # Give up if fallback fails + logger.error( + f"Failed to decode file '{file_path}' with fallback encoding: {decode_err}" + ) + return None # Give up if fallback fails else: - logger.warning(f"No content returned for file '{file_path}' in '{repo_full_name}'. It might be empty.") - return "" # Return empty string for empty files + logger.warning( + f"No content returned for file '{file_path}' in '{repo_full_name}'. It might be empty." + ) + return "" # Return empty string for empty files except (NotFoundError, ForbiddenError) as e: - logger.warning(f"Cannot access file '{file_path}' in '{repo_full_name}': {e}") - return None + logger.warning( + f"Cannot access file '{file_path}' in '{repo_full_name}': {e}" + ) + return None except Exception as e: - logger.error(f"Failed to get content for file '{file_path}' in '{repo_full_name}': {e}") - return None + logger.error( + f"Failed to get content for file '{file_path}' in '{repo_full_name}': {e}" + ) + return None diff --git a/surfsense_backend/app/connectors/linear_connector.py b/surfsense_backend/app/connectors/linear_connector.py index 52b7704..b4c54fd 100644 --- a/surfsense_backend/app/connectors/linear_connector.py +++ b/surfsense_backend/app/connectors/linear_connector.py @@ -5,96 +5,94 @@ A module for retrieving issues and comments from Linear. Allows fetching issue lists and their comments with date range filtering. """ -import requests from datetime import datetime -from typing import Dict, List, Optional, Tuple, Any, Union +from typing import Any + +import requests class LinearConnector: """Class for retrieving issues and comments from Linear.""" - - def __init__(self, token: str = None): + + def __init__(self, token: str | None = None): """ Initialize the LinearConnector class. - + Args: token: Linear API token (optional, can be set later with set_token) """ self.token = token self.api_url = "https://api.linear.app/graphql" - + def set_token(self, token: str) -> None: """ Set the Linear API token. - + Args: token: Linear API token """ self.token = token - - def get_headers(self) -> Dict[str, str]: + + def get_headers(self) -> dict[str, str]: """ Get headers for Linear API requests. - + Returns: Dictionary of headers - + Raises: ValueError: If no Linear token has been set """ if not self.token: raise ValueError("Linear token not initialized. Call set_token() first.") - - return { - 'Content-Type': 'application/json', - 'Authorization': self.token - } - - def execute_graphql_query(self, query: str, variables: Dict[str, Any] = None) -> Dict[str, Any]: + + return {"Content-Type": "application/json", "Authorization": self.token} + + def execute_graphql_query( + self, query: str, variables: dict[str, Any] | None = None + ) -> dict[str, Any]: """ Execute a GraphQL query against the Linear API. 
- + Args: query: GraphQL query string variables: Variables for the GraphQL query (optional) - + Returns: Response data from the API - + Raises: ValueError: If no Linear token has been set Exception: If the API request fails """ if not self.token: raise ValueError("Linear token not initialized. Call set_token() first.") - + headers = self.get_headers() - payload = {'query': query} - + payload = {"query": query} + if variables: - payload['variables'] = variables - - response = requests.post( - self.api_url, - headers=headers, - json=payload - ) - + payload["variables"] = variables + + response = requests.post(self.api_url, headers=headers, json=payload) + if response.status_code == 200: return response.json() else: - raise Exception(f"Query failed with status code {response.status_code}: {response.text}") - - def get_all_issues(self, include_comments: bool = True) -> List[Dict[str, Any]]: + raise Exception( + f"Query failed with status code {response.status_code}: {response.text}" + ) + + def get_all_issues(self, include_comments: bool = True) -> list[dict[str, Any]]: """ Fetch all issues from Linear. - + Args: include_comments: Whether to include comments in the response - + Returns: List of issue objects - + Raises: ValueError: If no Linear token has been set Exception: If the API request fails @@ -116,7 +114,7 @@ class LinearConnector: } } """ - + query = f""" query {{ issues {{ @@ -147,29 +145,30 @@ class LinearConnector: }} }} """ - + result = self.execute_graphql_query(query) - + # Extract issues from the response - if "data" in result and "issues" in result["data"] and "nodes" in result["data"]["issues"]: + if ( + "data" in result + and "issues" in result["data"] + and "nodes" in result["data"]["issues"] + ): return result["data"]["issues"]["nodes"] - + return [] - + def get_issues_by_date_range( - self, - start_date: str, - end_date: str, - include_comments: bool = True - ) -> Tuple[List[Dict[str, Any]], Optional[str]]: + self, start_date: str, end_date: str, include_comments: bool = True + ) -> tuple[list[dict[str, Any]], str | None]: """ Fetch issues within a date range. 
- + Args: start_date: Start date in YYYY-MM-DD format end_date: End date in YYYY-MM-DD format (inclusive) include_comments: Whether to include comments in the response - + Returns: Tuple containing (issues list, error message or None) """ @@ -194,7 +193,7 @@ class LinearConnector: } } """ - + # Query issues that were either created OR updated within the date range # This ensures we catch both new issues and updated existing issues query = f""" @@ -250,58 +249,65 @@ class LinearConnector: }} }} """ - + try: all_issues = [] has_next_page = True cursor = None - + # Handle pagination to get all issues while has_next_page: variables = {"after": cursor} if cursor else {} result = self.execute_graphql_query(query, variables) - + # Check for errors if "errors" in result: - error_message = "; ".join([error.get("message", "Unknown error") for error in result["errors"]]) + error_message = "; ".join( + [ + error.get("message", "Unknown error") + for error in result["errors"] + ] + ) return [], f"GraphQL errors: {error_message}" - + # Extract issues from the response if "data" in result and "issues" in result["data"]: issues_page = result["data"]["issues"] - + # Add issues from this page if "nodes" in issues_page: all_issues.extend(issues_page["nodes"]) - + # Check if there are more pages if "pageInfo" in issues_page: page_info = issues_page["pageInfo"] has_next_page = page_info.get("hasNextPage", False) - cursor = page_info.get("endCursor") if has_next_page else None + cursor = ( + page_info.get("endCursor") if has_next_page else None + ) else: has_next_page = False else: has_next_page = False - + if not all_issues: return [], "No issues found in the specified date range." - + return all_issues, None - + except Exception as e: - return [], f"Error fetching issues: {str(e)}" - + return [], f"Error fetching issues: {e!s}" + except ValueError as e: - return [], f"Invalid date format: {str(e)}. Please use YYYY-MM-DD." - - def format_issue(self, issue: Dict[str, Any]) -> Dict[str, Any]: + return [], f"Invalid date format: {e!s}. Please use YYYY-MM-DD." + + def format_issue(self, issue: dict[str, Any]) -> dict[str, Any]: """ Format an issue for easier consumption. 
- + Args: issue: The issue object from Linear API - + Returns: Formatted issue dictionary """ @@ -311,23 +317,37 @@ class LinearConnector: "identifier": issue.get("identifier", ""), "title": issue.get("title", ""), "description": issue.get("description", ""), - "state": issue.get("state", {}).get("name", "Unknown") if issue.get("state") else "Unknown", - "state_type": issue.get("state", {}).get("type", "Unknown") if issue.get("state") else "Unknown", + "state": issue.get("state", {}).get("name", "Unknown") + if issue.get("state") + else "Unknown", + "state_type": issue.get("state", {}).get("type", "Unknown") + if issue.get("state") + else "Unknown", "created_at": issue.get("createdAt", ""), "updated_at": issue.get("updatedAt", ""), "creator": { - "id": issue.get("creator", {}).get("id", "") if issue.get("creator") else "", - "name": issue.get("creator", {}).get("name", "Unknown") if issue.get("creator") else "Unknown", - "email": issue.get("creator", {}).get("email", "") if issue.get("creator") else "" - } if issue.get("creator") else {"id": "", "name": "Unknown", "email": ""}, + "id": issue.get("creator", {}).get("id", "") + if issue.get("creator") + else "", + "name": issue.get("creator", {}).get("name", "Unknown") + if issue.get("creator") + else "Unknown", + "email": issue.get("creator", {}).get("email", "") + if issue.get("creator") + else "", + } + if issue.get("creator") + else {"id": "", "name": "Unknown", "email": ""}, "assignee": { "id": issue.get("assignee", {}).get("id", ""), "name": issue.get("assignee", {}).get("name", "Unknown"), - "email": issue.get("assignee", {}).get("email", "") - } if issue.get("assignee") else None, - "comments": [] + "email": issue.get("assignee", {}).get("email", ""), + } + if issue.get("assignee") + else None, + "comments": [], } - + # Extract comments if available if "comments" in issue and "nodes" in issue["comments"]: for comment in issue["comments"]["nodes"]: @@ -337,85 +357,93 @@ class LinearConnector: "created_at": comment.get("createdAt", ""), "updated_at": comment.get("updatedAt", ""), "user": { - "id": comment.get("user", {}).get("id", "") if comment.get("user") else "", - "name": comment.get("user", {}).get("name", "Unknown") if comment.get("user") else "Unknown", - "email": comment.get("user", {}).get("email", "") if comment.get("user") else "" - } if comment.get("user") else {"id": "", "name": "Unknown", "email": ""} + "id": comment.get("user", {}).get("id", "") + if comment.get("user") + else "", + "name": comment.get("user", {}).get("name", "Unknown") + if comment.get("user") + else "Unknown", + "email": comment.get("user", {}).get("email", "") + if comment.get("user") + else "", + } + if comment.get("user") + else {"id": "", "name": "Unknown", "email": ""}, } formatted["comments"].append(formatted_comment) - + return formatted - - def format_issue_to_markdown(self, issue: Dict[str, Any]) -> str: + + def format_issue_to_markdown(self, issue: dict[str, Any]) -> str: """ Convert an issue to markdown format. 
- + Args: issue: The issue object (either raw or formatted) - + Returns: Markdown string representation of the issue """ # Format the issue if it's not already formatted if "identifier" not in issue: issue = self.format_issue(issue) - + # Build the markdown content markdown = f"# {issue.get('identifier', 'No ID')}: {issue.get('title', 'No Title')}\n\n" - - if issue.get('state'): + + if issue.get("state"): markdown += f"**Status:** {issue['state']}\n\n" - - if issue.get('assignee') and issue['assignee'].get('name'): + + if issue.get("assignee") and issue["assignee"].get("name"): markdown += f"**Assignee:** {issue['assignee']['name']}\n" - - if issue.get('creator') and issue['creator'].get('name'): + + if issue.get("creator") and issue["creator"].get("name"): markdown += f"**Created by:** {issue['creator']['name']}\n" - - if issue.get('created_at'): - created_date = self.format_date(issue['created_at']) + + if issue.get("created_at"): + created_date = self.format_date(issue["created_at"]) markdown += f"**Created:** {created_date}\n" - - if issue.get('updated_at'): - updated_date = self.format_date(issue['updated_at']) + + if issue.get("updated_at"): + updated_date = self.format_date(issue["updated_at"]) markdown += f"**Updated:** {updated_date}\n\n" - - if issue.get('description'): + + if issue.get("description"): markdown += f"## Description\n\n{issue['description']}\n\n" - - if issue.get('comments'): + + if issue.get("comments"): markdown += f"## Comments ({len(issue['comments'])})\n\n" - - for comment in issue['comments']: + + for comment in issue["comments"]: user_name = "Unknown" - if comment.get('user') and comment['user'].get('name'): - user_name = comment['user']['name'] - + if comment.get("user") and comment["user"].get("name"): + user_name = comment["user"]["name"] + comment_date = "Unknown date" - if comment.get('created_at'): - comment_date = self.format_date(comment['created_at']) - + if comment.get("created_at"): + comment_date = self.format_date(comment["created_at"]) + markdown += f"### {user_name} ({comment_date})\n\n{comment.get('body', '')}\n\n---\n\n" - + return markdown - + @staticmethod def format_date(iso_date: str) -> str: """ Format an ISO date string to a more readable format. - + Args: iso_date: ISO format date string - + Returns: Formatted date string """ if not iso_date or not isinstance(iso_date, str): return "Unknown date" - + try: - dt = datetime.fromisoformat(iso_date.replace('Z', '+00:00')) - return dt.strftime('%Y-%m-%d %H:%M:%S') + dt = datetime.fromisoformat(iso_date.replace("Z", "+00:00")) + return dt.strftime("%Y-%m-%d %H:%M:%S") except ValueError: return iso_date diff --git a/surfsense_backend/app/connectors/notion_history.py b/surfsense_backend/app/connectors/notion_history.py index da719c0..87948c6 100644 --- a/surfsense_backend/app/connectors/notion_history.py +++ b/surfsense_backend/app/connectors/notion_history.py @@ -1,176 +1,182 @@ from notion_client import Client + class NotionHistoryConnector: def __init__(self, token): """ Initialize the NotionPageFetcher with a token. - + Args: token (str): Notion integration token """ self.notion = Client(auth=token) - + def get_all_pages(self, start_date=None, end_date=None): """ Fetches all pages shared with your integration and their content. 
- + Args: start_date (str, optional): ISO 8601 date string (e.g., "2023-01-01T00:00:00Z") end_date (str, optional): ISO 8601 date string (e.g., "2023-12-31T23:59:59Z") - + Returns: list: List of dictionaries containing page data """ # Build the filter for the search # Note: Notion API requires specific filter structure search_params = {} - + # Filter for pages only (not databases) - search_params["filter"] = { - "value": "page", - "property": "object" - } - + search_params["filter"] = {"value": "page", "property": "object"} + # Add date filters if provided if start_date or end_date: date_filter = {} - + if start_date: date_filter["on_or_after"] = start_date - + if end_date: date_filter["on_or_before"] = end_date - + # Add the date filter to the search params if date_filter: search_params["sort"] = { "direction": "descending", - "timestamp": "last_edited_time" + "timestamp": "last_edited_time", } - + # First, get a list of all pages the integration has access to search_results = self.notion.search(**search_params) - + pages = search_results["results"] all_page_data = [] - + for page in pages: page_id = page["id"] - + # Get detailed page information page_content = self.get_page_content(page_id) - - all_page_data.append({ - "page_id": page_id, - "title": self.get_page_title(page), - "content": page_content - }) - + + all_page_data.append( + { + "page_id": page_id, + "title": self.get_page_title(page), + "content": page_content, + } + ) + return all_page_data - + def get_page_title(self, page): """ Extracts the title from a page object. - + Args: page (dict): Notion page object - + Returns: str: Page title or a fallback string """ # Title can be in different properties depending on the page type if "properties" in page: # Try to find a title property - for prop_name, prop_data in page["properties"].items(): + for _prop_name, prop_data in page["properties"].items(): if prop_data["type"] == "title" and len(prop_data["title"]) > 0: - return " ".join([text_obj["plain_text"] for text_obj in prop_data["title"]]) - + return " ".join( + [text_obj["plain_text"] for text_obj in prop_data["title"]] + ) + # If no title found, return the page ID as fallback return f"Untitled page ({page['id']})" - + def get_page_content(self, page_id): """ Fetches the content (blocks) of a specific page. - + Args: page_id (str): The ID of the page to fetch - + Returns: list: List of processed blocks from the page """ blocks = [] has_more = True cursor = None - + # Paginate through all blocks while has_more: if cursor: - response = self.notion.blocks.children.list(block_id=page_id, start_cursor=cursor) + response = self.notion.blocks.children.list( + block_id=page_id, start_cursor=cursor + ) else: response = self.notion.blocks.children.list(block_id=page_id) - + blocks.extend(response["results"]) has_more = response["has_more"] - + if has_more: cursor = response["next_cursor"] - + # Process nested blocks recursively processed_blocks = [] for block in blocks: processed_block = self.process_block(block) processed_blocks.append(processed_block) - + return processed_blocks - + def process_block(self, block): """ Processes a block and recursively fetches any child blocks. 
- + Args: block (dict): The block to process - + Returns: dict: Processed block with content and children """ block_id = block["id"] block_type = block["type"] - + # Extract block content based on its type content = self.extract_block_content(block) - + # Check if block has children has_children = block.get("has_children", False) child_blocks = [] - + if has_children: # Fetch and process child blocks children_response = self.notion.blocks.children.list(block_id=block_id) for child_block in children_response["results"]: child_blocks.append(self.process_block(child_block)) - + return { "id": block_id, "type": block_type, "content": content, - "children": child_blocks + "children": child_blocks, } - + def extract_block_content(self, block): """ Extracts the content from a block based on its type. - + Args: block (dict): The block to extract content from - + Returns: str: Extracted content as a string """ block_type = block["type"] - + # Different block types have different structures if block_type in block and "rich_text" in block[block_type]: - return "".join([text_obj["plain_text"] for text_obj in block[block_type]["rich_text"]]) + return "".join( + [text_obj["plain_text"] for text_obj in block[block_type]["rich_text"]] + ) elif block_type == "image": # Instead of returning the raw URL which may contain sensitive AWS credentials, # return a placeholder or reference to the image @@ -183,18 +189,21 @@ class NotionHistoryConnector: # Only return the domain part of external URLs to avoid potential sensitive parameters try: from urllib.parse import urlparse + parsed_url = urlparse(url) return f"[External Image from {parsed_url.netloc}]" - except: + except Exception: return "[External Image]" elif block_type == "code": language = block["code"]["language"] - code_text = "".join([text_obj["plain_text"] for text_obj in block["code"]["rich_text"]]) + code_text = "".join( + [text_obj["plain_text"] for text_obj in block["code"]["rich_text"]] + ) return f"```{language}\n{code_text}\n```" elif block_type == "equation": return block["equation"]["expression"] # Add more block types as needed - + # Return empty string for unsupported block types return "" @@ -203,23 +212,23 @@ class NotionHistoryConnector: # if __name__ == "__main__": # # Simple example of how to use this module # import argparse - + # parser = argparse.ArgumentParser(description="Fetch Notion pages using an integration token") # parser.add_argument("--token", help="Your Notion integration token") # parser.add_argument("--start-date", help="Start date in ISO format (e.g., 2023-01-01T00:00:00Z)") # parser.add_argument("--end-date", help="End date in ISO format (e.g., 2023-12-31T23:59:59Z)") # args = parser.parse_args() - + # token = args.token # if not token: # token = input("Enter your Notion integration token: ") - + # fetcher = NotionPageFetcher(token) - + # try: # pages = fetcher.get_all_pages(args.start_date, args.end_date) # print(f"Fetched {len(pages)} pages from Notion") # for page in pages: # print(f"- {page['title']}") # except Exception as e: -# print(f"Error: {str(e)}") \ No newline at end of file +# print(f"Error: {str(e)}") diff --git a/surfsense_backend/app/connectors/slack_history.py b/surfsense_backend/app/connectors/slack_history.py index 13e366c..982dc8a 100644 --- a/surfsense_backend/app/connectors/slack_history.py +++ b/surfsense_backend/app/connectors/slack_history.py @@ -5,47 +5,48 @@ A module for retrieving conversation history from Slack channels. 
Allows fetching channel lists and message history with date range filtering. """ -import time # Added import -import logging # Added import +import logging # Added import +import time # Added import +from datetime import datetime +from typing import Any + from slack_sdk import WebClient from slack_sdk.errors import SlackApiError -from datetime import datetime -from typing import Dict, List, Optional, Tuple, Any -logger = logging.getLogger(__name__) # Added logger +logger = logging.getLogger(__name__) # Added logger class SlackHistory: """Class for retrieving conversation history from Slack channels.""" - - def __init__(self, token: str = None): + + def __init__(self, token: str | None = None): """ Initialize the SlackHistory class. - + Args: token: Slack API token (optional, can be set later with set_token) """ self.client = WebClient(token=token) if token else None - + def set_token(self, token: str) -> None: """ Set the Slack API token. - + Args: token: Slack API token """ self.client = WebClient(token=token) - - def get_all_channels(self, include_private: bool = True) -> List[Dict[str, Any]]: + + def get_all_channels(self, include_private: bool = True) -> list[dict[str, Any]]: """ Fetch all channels that the bot has access to, with rate limit handling. - + Args: include_private: Whether to include private channels - + Returns: List of dictionaries, each representing a channel with id, name, is_private, is_member. - + Raises: ValueError: If no Slack client has been initialized SlackApiError: If there's an unrecoverable error calling the Slack API @@ -53,8 +54,8 @@ class SlackHistory: """ if not self.client: raise ValueError("Slack client not initialized. Call set_token() first.") - - channels_list = [] # Changed from dict to list + + channels_list = [] # Changed from dict to list types = "public_channel" if include_private: types += ",private_channel" @@ -65,16 +66,16 @@ class SlackHistory: while is_first_request or next_cursor: try: if not is_first_request: # Add delay only for paginated requests - logger.info(f"Paginating for channels, waiting 3 seconds before next call. Cursor: {next_cursor}") + logger.info( + f"Paginating for channels, waiting 3 seconds before next call. Cursor: {next_cursor}" + ) time.sleep(3) current_limit = 1000 # Max limit api_result = self.client.conversations_list( - types=types, - cursor=next_cursor, - limit=current_limit + types=types, cursor=next_cursor, limit=current_limit ) - + channels_on_page = api_result["channels"] for channel in channels_on_page: if "name" in channel and "id" in channel: @@ -86,12 +87,13 @@ class SlackHistory: # It indicates if the authenticated user (bot) is a member. # For public channels, this might be true or the API might not focus on it # if the bot can read it anyway. For private, it's crucial. - "is_member": channel.get("is_member", False) + "is_member": channel.get("is_member", False), } channels_list.append(channel_data) else: - logger.warning(f"Channel found with missing name or id. Data: {channel}") - + logger.warning( + f"Channel found with missing name or id. 
Data: {channel}" + ) next_cursor = api_result.get("response_metadata", {}).get("next_cursor") is_first_request = False # Subsequent requests are not the first @@ -101,57 +103,65 @@ class SlackHistory: except SlackApiError as e: if e.response is not None and e.response.status_code == 429: - retry_after_header = e.response.headers.get('Retry-After') + retry_after_header = e.response.headers.get("Retry-After") wait_duration = 60 # Default wait time if retry_after_header and retry_after_header.isdigit(): wait_duration = int(retry_after_header) - - logger.warning(f"Slack API rate limit hit while fetching channels. Waiting for {wait_duration} seconds. Cursor: {next_cursor}") + + logger.warning( + f"Slack API rate limit hit while fetching channels. Waiting for {wait_duration} seconds. Cursor: {next_cursor}" + ) time.sleep(wait_duration) # The loop will continue, retrying with the same cursor else: # Not a 429 error, or no response object, re-raise - raise SlackApiError(f"Error retrieving channels: {e}", e.response) + raise SlackApiError( + f"Error retrieving channels: {e}", e.response + ) from e except Exception as general_error: # Handle other potential errors like network issues if necessary, or re-raise - logger.error(f"An unexpected error occurred during channel fetching: {general_error}") - raise RuntimeError(f"An unexpected error occurred during channel fetching: {general_error}") - + logger.error( + f"An unexpected error occurred during channel fetching: {general_error}" + ) + raise RuntimeError( + f"An unexpected error occurred during channel fetching: {general_error}" + ) from general_error + return channels_list - + def get_conversation_history( - self, - channel_id: str, - limit: int = 1000, - oldest: Optional[int] = None, - latest: Optional[int] = None - ) -> List[Dict[str, Any]]: + self, + channel_id: str, + limit: int = 1000, + oldest: int | None = None, + latest: int | None = None, + ) -> list[dict[str, Any]]: """ Fetch conversation history for a channel. - + Args: channel_id: The ID of the channel to fetch history for limit: Maximum number of messages to return per request (default 1000) oldest: Start of time range (Unix timestamp) latest: End of time range (Unix timestamp) - + Returns: List of message objects - + Raises: ValueError: If no Slack client has been initialized SlackApiError: If there's an error calling the Slack API """ if not self.client: raise ValueError("Slack client not initialized. Call set_token() first.") - + messages = [] next_cursor = None - + while True: try: # Proactive delay for conversations.history (Tier 3) - time.sleep(1.2) # Wait 1.2 seconds before each history call. + time.sleep(1.2) # Wait 1.2 seconds before each history call. 
kwargs = { "channel": channel_id, @@ -163,16 +173,19 @@ class SlackHistory: kwargs["latest"] = latest if next_cursor: kwargs["cursor"] = next_cursor - + current_api_call_successful = False - result = None # Ensure result is defined + result = None # Ensure result is defined try: result = self.client.conversations_history(**kwargs) current_api_call_successful = True except SlackApiError as e_history: - if e_history.response is not None and e_history.response.status_code == 429: - retry_after_str = e_history.response.headers.get('Retry-After') - wait_time = 60 # Default + if ( + e_history.response is not None + and e_history.response.status_code == 429 + ): + retry_after_str = e_history.response.headers.get("Retry-After") + wait_time = 60 # Default if retry_after_str and retry_after_str.isdigit(): wait_time = int(retry_after_str) logger.warning( @@ -182,47 +195,54 @@ class SlackHistory: time.sleep(wait_time) # current_api_call_successful remains False, loop will retry this page else: - raise # Re-raise to outer handler for not_in_channel or other SlackApiErrors - + raise # Re-raise to outer handler for not_in_channel or other SlackApiErrors + if not current_api_call_successful: - continue # Retry the current page fetch due to handled rate limit + continue # Retry the current page fetch due to handled rate limit # Process result if successful batch = result["messages"] messages.extend(batch) - + if result.get("has_more", False) and len(messages) < limit: next_cursor = result["response_metadata"]["next_cursor"] else: - break # Exit pagination loop - - except SlackApiError as e: # Outer catch for not_in_channel or unhandled SlackApiErrors from inner try - if (e.response is not None and - hasattr(e.response, 'data') and - isinstance(e.response.data, dict) and - e.response.data.get('error') == 'not_in_channel'): + break # Exit pagination loop + + except SlackApiError as e: # Outer catch for not_in_channel or unhandled SlackApiErrors from inner try + if ( + e.response is not None + and hasattr(e.response, "data") + and isinstance(e.response.data, dict) + and e.response.data.get("error") == "not_in_channel" + ): logger.warning( f"Bot is not in channel '{channel_id}'. Cannot fetch history. " "Please add the bot to this channel." ) - return [] + return [] # For other SlackApiErrors from inner block or this level - raise SlackApiError(f"Error retrieving history for channel {channel_id}: {e}", e.response) - except Exception as general_error: # Catch any other unexpected errors - logger.error(f"Unexpected error in get_conversation_history for channel {channel_id}: {general_error}") + raise SlackApiError( + f"Error retrieving history for channel {channel_id}: {e}", + e.response, + ) from e + except Exception as general_error: # Catch any other unexpected errors + logger.error( + f"Unexpected error in get_conversation_history for channel {channel_id}: {general_error}" + ) # Re-raise the general error to allow higher-level handling or visibility - raise - + raise general_error from general_error + return messages[:limit] @staticmethod - def convert_date_to_timestamp(date_str: str) -> Optional[int]: + def convert_date_to_timestamp(date_str: str) -> int | None: """ Convert a date string in format YYYY-MM-DD to Unix timestamp. 
- + Args: date_str: Date string in YYYY-MM-DD format - + Returns: Unix timestamp (seconds since epoch) or None if invalid format """ @@ -231,67 +251,63 @@ class SlackHistory: return int(dt.timestamp()) except ValueError: return None - + def get_history_by_date_range( - self, - channel_id: str, - start_date: str, - end_date: str, - limit: int = 1000 - ) -> Tuple[List[Dict[str, Any]], Optional[str]]: + self, channel_id: str, start_date: str, end_date: str, limit: int = 1000 + ) -> tuple[list[dict[str, Any]], str | None]: """ Fetch conversation history within a date range. - + Args: channel_id: The ID of the channel to fetch history for start_date: Start date in YYYY-MM-DD format end_date: End date in YYYY-MM-DD format (inclusive) limit: Maximum number of messages to return - + Returns: Tuple containing (messages list, error message or None) """ oldest = self.convert_date_to_timestamp(start_date) if not oldest: - return [], f"Invalid start date format: {start_date}. Please use YYYY-MM-DD." - + return ( + [], + f"Invalid start date format: {start_date}. Please use YYYY-MM-DD.", + ) + latest = self.convert_date_to_timestamp(end_date) if not latest: return [], f"Invalid end date format: {end_date}. Please use YYYY-MM-DD." - + # Add one day to end date to make it inclusive latest += 86400 # seconds in a day - + try: messages = self.get_conversation_history( - channel_id=channel_id, - limit=limit, - oldest=oldest, - latest=latest + channel_id=channel_id, limit=limit, oldest=oldest, latest=latest ) return messages, None except SlackApiError as e: - return [], f"Slack API error: {str(e)}" + return [], f"Slack API error: {e!s}" except ValueError as e: return [], str(e) - - def get_user_info(self, user_id: str) -> Dict[str, Any]: + + def get_user_info(self, user_id: str) -> dict[str, Any]: """ Get information about a user. - + Args: user_id: The ID of the user to get info for - + Returns: User information dictionary - + Raises: ValueError: If no Slack client has been initialized SlackApiError: If there's an error calling the Slack API """ if not self.client: raise ValueError("Slack client not initialized. Call set_token() first.") - + while True: try: # Proactive delay for users.info (Tier 4) - generally not needed unless called extremely rapidly. @@ -299,46 +315,60 @@ class SlackHistory: # time.sleep(0.6) # Optional: ~100 req/min if ever needed. result = self.client.users_info(user=user_id) - return result["user"] # Success, return and exit loop implicitly + return result["user"] # Success, return and exit loop implicitly except SlackApiError as e_user_info: - if e_user_info.response is not None and e_user_info.response.status_code == 429: - retry_after_str = e_user_info.response.headers.get('Retry-After') + if ( + e_user_info.response is not None + and e_user_info.response.status_code == 429 + ): + retry_after_str = e_user_info.response.headers.get("Retry-After") wait_time = 30 # Default for Tier 4, can be adjusted if retry_after_str and retry_after_str.isdigit(): wait_time = int(retry_after_str) - logger.warning(f"Rate limited by Slack on users.info for user {user_id}. Retrying after {wait_time} seconds.") + logger.warning( + f"Rate limited by Slack on users.info for user {user_id}. Retrying after {wait_time} seconds." 
+ ) time.sleep(wait_time) continue # Retry the API call else: # Not a 429 error, or no response object, re-raise - raise SlackApiError(f"Error retrieving user info for {user_id}: {e_user_info}", e_user_info.response) - except Exception as general_error: # Catch any other unexpected errors - logger.error(f"Unexpected error in get_user_info for user {user_id}: {general_error}") - raise # Re-raise unexpected errors - - def format_message(self, msg: Dict[str, Any], include_user_info: bool = False) -> Dict[str, Any]: + raise SlackApiError( + f"Error retrieving user info for {user_id}: {e_user_info}", + e_user_info.response, + ) from e_user_info + except Exception as general_error: # Catch any other unexpected errors + logger.error( + f"Unexpected error in get_user_info for user {user_id}: {general_error}" + ) + raise general_error from general_error # Re-raise unexpected errors + + def format_message( + self, msg: dict[str, Any], include_user_info: bool = False + ) -> dict[str, Any]: """ Format a message for easier consumption. - + Args: msg: The message object from Slack API include_user_info: Whether to fetch and include user info - + Returns: Formatted message dictionary """ formatted = { "text": msg.get("text", ""), "timestamp": msg.get("ts"), - "datetime": datetime.fromtimestamp(float(msg.get("ts", 0))).strftime('%Y-%m-%d %H:%M:%S'), + "datetime": datetime.fromtimestamp(float(msg.get("ts", 0))).strftime( + "%Y-%m-%d %H:%M:%S" + ), "user_id": msg.get("user", "UNKNOWN"), "has_attachments": bool(msg.get("attachments")), "has_files": bool(msg.get("files")), "thread_ts": msg.get("thread_ts"), "is_thread": "thread_ts" in msg, } - + if include_user_info and "user" in msg and self.client: try: user_info = self.get_user_info(msg["user"]) @@ -347,7 +377,7 @@ class SlackHistory: except Exception: # If we can't get user info, just continue without it formatted["user_name"] = "Unknown" - + return formatted @@ -388,4 +418,4 @@ if __name__ == "__main__": except Exception as e: print(f"Error: {e}") -""" \ No newline at end of file +""" diff --git a/surfsense_backend/app/connectors/test_github_connector.py b/surfsense_backend/app/connectors/test_github_connector.py index d55ebf3..6ed9ffa 100644 --- a/surfsense_backend/app/connectors/test_github_connector.py +++ b/surfsense_backend/app/connectors/test_github_connector.py @@ -1,23 +1,24 @@ import unittest -from unittest.mock import patch, Mock from datetime import datetime +from unittest.mock import Mock, patch + +from github3.exceptions import ForbiddenError # Import the specific exception # Adjust the import path based on the actual location if test_github_connector.py # is not in the same directory as github_connector.py or if paths are set up differently. 
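The recurring `raise ... from e` edits in these hunks (ruff rule B904) chain the original exception as `__cause__`, so tracebacks show both the wrapper and the root cause. A small illustrative sketch, not taken from the patch:

    def parse_retry_after(header_value: str) -> int:
        """Parse a Retry-After header, chaining the original error on bad input."""
        try:
            return int(header_value)
        except ValueError as err:
            # __cause__ is set to err, so the traceback reports both exceptions.
            raise RuntimeError(f"Invalid Retry-After header: {header_value!r}") from err
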
# Assuming surfsend_backend/app/connectors/test_github_connector.py from surfsense_backend.app.connectors.github_connector import GitHubConnector -from github3.exceptions import ForbiddenError # Import the specific exception + class TestGitHubConnector(unittest.TestCase): - - @patch('surfsense_backend.app.connectors.github_connector.github_login') + @patch("surfsense_backend.app.connectors.github_connector.github_login") def test_get_user_repositories_uses_type_all(self, mock_github_login): # Mock the GitHub client object and its methods mock_gh_instance = Mock() mock_github_login.return_value = mock_gh_instance # Mock the self.gh.me() call in __init__ to prevent an actual API call - mock_gh_instance.me.return_value = Mock() # Simple mock to pass initialization + mock_gh_instance.me.return_value = Mock() # Simple mock to pass initialization # Prepare mock repository data mock_repo1_data = Mock() @@ -27,7 +28,9 @@ class TestGitHubConnector(unittest.TestCase): mock_repo1_data.private = False mock_repo1_data.html_url = "http://example.com/user/repo1" mock_repo1_data.description = "Test repo 1" - mock_repo1_data.updated_at = datetime(2023, 1, 1, 10, 30, 0) # Added time component + mock_repo1_data.updated_at = datetime( + 2023, 1, 1, 10, 30, 0 + ) # Added time component mock_repo2_data = Mock() mock_repo2_data.id = 2 @@ -36,8 +39,10 @@ class TestGitHubConnector(unittest.TestCase): mock_repo2_data.private = True mock_repo2_data.html_url = "http://example.com/org/org-repo" mock_repo2_data.description = "Org repo" - mock_repo2_data.updated_at = datetime(2023, 1, 2, 12, 0, 0) # Added time component - + mock_repo2_data.updated_at = datetime( + 2023, 1, 2, 12, 0, 0 + ) # Added time component + # Configure the mock for gh.repositories() call # This method is an iterator, so it should return an iterable (e.g., a list) mock_gh_instance.repositories.return_value = [mock_repo1_data, mock_repo2_data] @@ -46,26 +51,38 @@ class TestGitHubConnector(unittest.TestCase): repositories = connector.get_user_repositories() # Assert that gh.repositories was called correctly - mock_gh_instance.repositories.assert_called_once_with(type='all', sort='updated') + mock_gh_instance.repositories.assert_called_once_with( + type="all", sort="updated" + ) # Assert the structure and content of the returned data expected_repositories = [ { - "id": 1, "name": "repo1", "full_name": "user/repo1", "private": False, - "url": "http://example.com/user/repo1", "description": "Test repo 1", - "last_updated": datetime(2023, 1, 1, 10, 30, 0) + "id": 1, + "name": "repo1", + "full_name": "user/repo1", + "private": False, + "url": "http://example.com/user/repo1", + "description": "Test repo 1", + "last_updated": datetime(2023, 1, 1, 10, 30, 0), }, { - "id": 2, "name": "org-repo", "full_name": "org/org-repo", "private": True, - "url": "http://example.com/org/org-repo", "description": "Org repo", - "last_updated": datetime(2023, 1, 2, 12, 0, 0) - } + "id": 2, + "name": "org-repo", + "full_name": "org/org-repo", + "private": True, + "url": "http://example.com/org/org-repo", + "description": "Org repo", + "last_updated": datetime(2023, 1, 2, 12, 0, 0), + }, ] self.assertEqual(repositories, expected_repositories) self.assertEqual(len(repositories), 2) - @patch('surfsense_backend.app.connectors.github_connector.github_login') - def test_get_user_repositories_handles_empty_description_and_none_updated_at(self, mock_github_login): + @patch("surfsense_backend.app.connectors.github_connector.github_login") + def 
test_get_user_repositories_handles_empty_description_and_none_updated_at( + self, mock_github_login + ): # Mock the GitHub client object and its methods mock_gh_instance = Mock() mock_github_login.return_value = mock_gh_instance @@ -77,61 +94,73 @@ class TestGitHubConnector(unittest.TestCase): mock_repo_data.full_name = "user/repo_no_desc" mock_repo_data.private = False mock_repo_data.html_url = "http://example.com/user/repo_no_desc" - mock_repo_data.description = None # Test None description - mock_repo_data.updated_at = None # Test None updated_at + mock_repo_data.description = None # Test None description + mock_repo_data.updated_at = None # Test None updated_at mock_gh_instance.repositories.return_value = [mock_repo_data] connector = GitHubConnector(token="fake_token") repositories = connector.get_user_repositories() - mock_gh_instance.repositories.assert_called_once_with(type='all', sort='updated') + mock_gh_instance.repositories.assert_called_once_with( + type="all", sort="updated" + ) expected_repositories = [ { - "id": 1, "name": "repo_no_desc", "full_name": "user/repo_no_desc", "private": False, - "url": "http://example.com/user/repo_no_desc", "description": "", # Expect empty string - "last_updated": None # Expect None + "id": 1, + "name": "repo_no_desc", + "full_name": "user/repo_no_desc", + "private": False, + "url": "http://example.com/user/repo_no_desc", + "description": "", # Expect empty string + "last_updated": None, # Expect None } ] self.assertEqual(repositories, expected_repositories) - @patch('surfsense_backend.app.connectors.github_connector.github_login') + @patch("surfsense_backend.app.connectors.github_connector.github_login") def test_github_connector_initialization_failure_forbidden(self, mock_github_login): # Test that __init__ raises ValueError on auth failure (ForbiddenError) mock_gh_instance = Mock() mock_github_login.return_value = mock_gh_instance - + # Create a mock response object for the ForbiddenError # The actual response structure might vary, but github3.py's ForbiddenError # can be instantiated with just a response object that has a status_code. mock_response = Mock() - mock_response.status_code = 403 # Typically Forbidden - + mock_response.status_code = 403 # Typically Forbidden + # Setup the side_effect for self.gh.me() mock_gh_instance.me.side_effect = ForbiddenError(mock_response) with self.assertRaises(ValueError) as context: GitHubConnector(token="invalid_token_forbidden") - self.assertIn("Invalid GitHub token or insufficient permissions.", str(context.exception)) + self.assertIn( + "Invalid GitHub token or insufficient permissions.", str(context.exception) + ) - @patch('surfsense_backend.app.connectors.github_connector.github_login') - def test_github_connector_initialization_failure_authentication_failed(self, mock_github_login): + @patch("surfsense_backend.app.connectors.github_connector.github_login") + def test_github_connector_initialization_failure_authentication_failed( + self, mock_github_login + ): # Test that __init__ raises ValueError on auth failure (AuthenticationFailed, which is a subclass of ForbiddenError) # For github3.py, AuthenticationFailed is more specific for token issues. 
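The two initialization-failure tests above pin down one contract: a github3 auth error raised by the initial `me()` probe is surfaced as a ValueError with the message they assert. A hedged sketch of that contract using github3's public API; the class name and probe logic here are illustrative, not the app's actual connector implementation:

    from github3 import login as github_login
    from github3.exceptions import ForbiddenError


    class GitHubConnectorSketch:
        def __init__(self, token: str):
            self.gh = github_login(token=token)
            try:
                # Cheap authenticated call to validate the token up front.
                self.gh.me()
            except ForbiddenError as err:  # AuthenticationFailed is a subclass
                raise ValueError(
                    "Invalid GitHub token or insufficient permissions."
                ) from err
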
from github3.exceptions import AuthenticationFailed mock_gh_instance = Mock() mock_github_login.return_value = mock_gh_instance - + mock_response = Mock() - mock_response.status_code = 401 # Typically Unauthorized - + mock_response.status_code = 401 # Typically Unauthorized + mock_gh_instance.me.side_effect = AuthenticationFailed(mock_response) with self.assertRaises(ValueError) as context: GitHubConnector(token="invalid_token_authfailed") - self.assertIn("Invalid GitHub token or insufficient permissions.", str(context.exception)) - - @patch('surfsense_backend.app.connectors.github_connector.github_login') + self.assertIn( + "Invalid GitHub token or insufficient permissions.", str(context.exception) + ) + + @patch("surfsense_backend.app.connectors.github_connector.github_login") def test_get_user_repositories_handles_api_exception(self, mock_github_login): mock_gh_instance = Mock() mock_github_login.return_value = mock_gh_instance @@ -142,13 +171,18 @@ class TestGitHubConnector(unittest.TestCase): connector = GitHubConnector(token="fake_token") # We expect it to log an error and return an empty list - with patch('surfsense_backend.app.connectors.github_connector.logger') as mock_logger: + with patch( + "surfsense_backend.app.connectors.github_connector.logger" + ) as mock_logger: repositories = connector.get_user_repositories() - + self.assertEqual(repositories, []) mock_logger.error.assert_called_once() - self.assertIn("Failed to fetch GitHub repositories: API Error", mock_logger.error.call_args[0][0]) + self.assertIn( + "Failed to fetch GitHub repositories: API Error", + mock_logger.error.call_args[0][0], + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/surfsense_backend/app/connectors/test_slack_history.py b/surfsense_backend/app/connectors/test_slack_history.py index ecff2c5..67677df 100644 --- a/surfsense_backend/app/connectors/test_slack_history.py +++ b/surfsense_backend/app/connectors/test_slack_history.py @@ -1,373 +1,448 @@ import unittest -import time # Imported to be available for patching target module -from unittest.mock import patch, Mock, call +from unittest.mock import Mock, call, patch + from slack_sdk.errors import SlackApiError # Since test_slack_history.py is in the same directory as slack_history.py from .slack_history import SlackHistory -class TestSlackHistoryGetAllChannels(unittest.TestCase): - @patch('surfsense_backend.app.connectors.slack_history.logger') - @patch('surfsense_backend.app.connectors.slack_history.time.sleep') - @patch('slack_sdk.WebClient') - def test_get_all_channels_pagination_with_delay(self, MockWebClient, mock_sleep, mock_logger): - mock_client_instance = MockWebClient.return_value - +class TestSlackHistoryGetAllChannels(unittest.TestCase): + @patch("surfsense_backend.app.connectors.slack_history.logger") + @patch("surfsense_backend.app.connectors.slack_history.time.sleep") + @patch("slack_sdk.WebClient") + def test_get_all_channels_pagination_with_delay( + self, mock_web_client, mock_sleep, mock_logger + ): + mock_client_instance = mock_web_client.return_value + # Mock API responses now include is_private and is_member page1_response = { "channels": [ - {"name": "general", "id": "C1", "is_private": False, "is_member": True}, - {"name": "dev", "id": "C0", "is_private": False, "is_member": True} + {"name": "general", "id": "C1", "is_private": False, "is_member": True}, + {"name": "dev", "id": "C0", "is_private": False, "is_member": True}, ], - "response_metadata": {"next_cursor": "cursor123"} + 
"response_metadata": {"next_cursor": "cursor123"}, } page2_response = { - "channels": [{"name": "random", "id": "C2", "is_private": True, "is_member": True}], - "response_metadata": {"next_cursor": ""} + "channels": [ + {"name": "random", "id": "C2", "is_private": True, "is_member": True} + ], + "response_metadata": {"next_cursor": ""}, } - + mock_client_instance.conversations_list.side_effect = [ page1_response, - page2_response + page2_response, ] - + slack_history = SlackHistory(token="fake_token") channels_list = slack_history.get_all_channels(include_private=True) - + expected_channels_list = [ {"id": "C1", "name": "general", "is_private": False, "is_member": True}, {"id": "C0", "name": "dev", "is_private": False, "is_member": True}, - {"id": "C2", "name": "random", "is_private": True, "is_member": True} + {"id": "C2", "name": "random", "is_private": True, "is_member": True}, ] - + self.assertEqual(len(channels_list), 3) - self.assertListEqual(channels_list, expected_channels_list) # Assert list equality - + self.assertListEqual( + channels_list, expected_channels_list + ) # Assert list equality + expected_calls = [ call(types="public_channel,private_channel", cursor=None, limit=1000), - call(types="public_channel,private_channel", cursor="cursor123", limit=1000) + call( + types="public_channel,private_channel", cursor="cursor123", limit=1000 + ), ] mock_client_instance.conversations_list.assert_has_calls(expected_calls) self.assertEqual(mock_client_instance.conversations_list.call_count, 2) - - mock_sleep.assert_called_once_with(3) - mock_logger.info.assert_called_once_with("Paginating for channels, waiting 3 seconds before next call. Cursor: cursor123") - @patch('surfsense_backend.app.connectors.slack_history.logger') - @patch('surfsense_backend.app.connectors.slack_history.time.sleep') - @patch('slack_sdk.WebClient') - def test_get_all_channels_rate_limit_with_retry_after(self, MockWebClient, mock_sleep, mock_logger): - mock_client_instance = MockWebClient.return_value - + mock_sleep.assert_called_once_with(3) + mock_logger.info.assert_called_once_with( + "Paginating for channels, waiting 3 seconds before next call. 
Cursor: cursor123" + ) + + @patch("surfsense_backend.app.connectors.slack_history.logger") + @patch("surfsense_backend.app.connectors.slack_history.time.sleep") + @patch("slack_sdk.WebClient") + def test_get_all_channels_rate_limit_with_retry_after( + self, mock_web_client, mock_sleep, mock_logger + ): + mock_client_instance = mock_web_client.return_value + mock_error_response = Mock() mock_error_response.status_code = 429 - mock_error_response.headers = {'Retry-After': '5'} - + mock_error_response.headers = {"Retry-After": "5"} + successful_response = { - "channels": [{"name": "general", "id": "C1", "is_private": False, "is_member": True}], - "response_metadata": {"next_cursor": ""} + "channels": [ + {"name": "general", "id": "C1", "is_private": False, "is_member": True} + ], + "response_metadata": {"next_cursor": ""}, } - + mock_client_instance.conversations_list.side_effect = [ SlackApiError(message="ratelimited", response=mock_error_response), - successful_response + successful_response, ] - + slack_history = SlackHistory(token="fake_token") channels_list = slack_history.get_all_channels(include_private=True) - - expected_channels_list = [{"id": "C1", "name": "general", "is_private": False, "is_member": True}] + + expected_channels_list = [ + {"id": "C1", "name": "general", "is_private": False, "is_member": True} + ] self.assertEqual(len(channels_list), 1) self.assertListEqual(channels_list, expected_channels_list) - - mock_sleep.assert_called_once_with(5) - mock_logger.warning.assert_called_once_with("Slack API rate limit hit while fetching channels. Waiting for 5 seconds. Cursor: None") - + + mock_sleep.assert_called_once_with(5) + mock_logger.warning.assert_called_once_with( + "Slack API rate limit hit while fetching channels. Waiting for 5 seconds. 
Cursor: None" + ) + expected_calls = [ - call(types="public_channel,private_channel", cursor=None, limit=1000), - call(types="public_channel,private_channel", cursor=None, limit=1000) + call(types="public_channel,private_channel", cursor=None, limit=1000), + call(types="public_channel,private_channel", cursor=None, limit=1000), ] mock_client_instance.conversations_list.assert_has_calls(expected_calls) self.assertEqual(mock_client_instance.conversations_list.call_count, 2) - @patch('surfsense_backend.app.connectors.slack_history.logger') - @patch('surfsense_backend.app.connectors.slack_history.time.sleep') - @patch('slack_sdk.WebClient') - def test_get_all_channels_rate_limit_no_retry_after_valid_header(self, MockWebClient, mock_sleep, mock_logger): - mock_client_instance = MockWebClient.return_value - + @patch("surfsense_backend.app.connectors.slack_history.logger") + @patch("surfsense_backend.app.connectors.slack_history.time.sleep") + @patch("slack_sdk.WebClient") + def test_get_all_channels_rate_limit_no_retry_after_valid_header( + self, mock_web_client, mock_sleep, mock_logger + ): + mock_client_instance = mock_web_client.return_value + mock_error_response = Mock() mock_error_response.status_code = 429 - mock_error_response.headers = {'Retry-After': 'invalid_value'} - + mock_error_response.headers = {"Retry-After": "invalid_value"} + successful_response = { - "channels": [{"name": "general", "id": "C1", "is_private": False, "is_member": True}], - "response_metadata": {"next_cursor": ""} + "channels": [ + {"name": "general", "id": "C1", "is_private": False, "is_member": True} + ], + "response_metadata": {"next_cursor": ""}, } - + mock_client_instance.conversations_list.side_effect = [ SlackApiError(message="ratelimited", response=mock_error_response), - successful_response + successful_response, ] - + slack_history = SlackHistory(token="fake_token") channels_list = slack_history.get_all_channels(include_private=True) - - expected_channels_list = [{"id": "C1", "name": "general", "is_private": False, "is_member": True}] + + expected_channels_list = [ + {"id": "C1", "name": "general", "is_private": False, "is_member": True} + ] self.assertListEqual(channels_list, expected_channels_list) - mock_sleep.assert_called_once_with(60) # Default fallback - mock_logger.warning.assert_called_once_with("Slack API rate limit hit while fetching channels. Waiting for 60 seconds. Cursor: None") + mock_sleep.assert_called_once_with(60) # Default fallback + mock_logger.warning.assert_called_once_with( + "Slack API rate limit hit while fetching channels. Waiting for 60 seconds. 
Cursor: None" + ) self.assertEqual(mock_client_instance.conversations_list.call_count, 2) - @patch('surfsense_backend.app.connectors.slack_history.logger') - @patch('surfsense_backend.app.connectors.slack_history.time.sleep') - @patch('slack_sdk.WebClient') - def test_get_all_channels_rate_limit_no_retry_after_header(self, MockWebClient, mock_sleep, mock_logger): - mock_client_instance = MockWebClient.return_value - + @patch("surfsense_backend.app.connectors.slack_history.logger") + @patch("surfsense_backend.app.connectors.slack_history.time.sleep") + @patch("slack_sdk.WebClient") + def test_get_all_channels_rate_limit_no_retry_after_header( + self, mock_web_client, mock_sleep, mock_logger + ): + mock_client_instance = mock_web_client.return_value + mock_error_response = Mock() mock_error_response.status_code = 429 - mock_error_response.headers = {} - - successful_response = { - "channels": [{"name": "general", "id": "C1", "is_private": False, "is_member": True}], - "response_metadata": {"next_cursor": ""} - } - - mock_client_instance.conversations_list.side_effect = [ - SlackApiError(message="ratelimited", response=mock_error_response), - successful_response - ] - - slack_history = SlackHistory(token="fake_token") - channels_list = slack_history.get_all_channels(include_private=True) - - expected_channels_list = [{"id": "C1", "name": "general", "is_private": False, "is_member": True}] - self.assertListEqual(channels_list, expected_channels_list) - mock_sleep.assert_called_once_with(60) # Default fallback - mock_logger.warning.assert_called_once_with("Slack API rate limit hit while fetching channels. Waiting for 60 seconds. Cursor: None") - self.assertEqual(mock_client_instance.conversations_list.call_count, 2) - - @patch('surfsense_backend.app.connectors.slack_history.logger') - @patch('surfsense_backend.app.connectors.slack_history.time.sleep') - @patch('slack_sdk.WebClient') - def test_get_all_channels_other_slack_api_error(self, MockWebClient, mock_sleep, mock_logger): - mock_client_instance = MockWebClient.return_value - - mock_error_response = Mock() - mock_error_response.status_code = 500 mock_error_response.headers = {} - mock_error_response.data = {"ok": False, "error": "internal_error"} - - original_error = SlackApiError(message="server error", response=mock_error_response) - mock_client_instance.conversations_list.side_effect = original_error - + + successful_response = { + "channels": [ + {"name": "general", "id": "C1", "is_private": False, "is_member": True} + ], + "response_metadata": {"next_cursor": ""}, + } + + mock_client_instance.conversations_list.side_effect = [ + SlackApiError(message="ratelimited", response=mock_error_response), + successful_response, + ] + slack_history = SlackHistory(token="fake_token") - + channels_list = slack_history.get_all_channels(include_private=True) + + expected_channels_list = [ + {"id": "C1", "name": "general", "is_private": False, "is_member": True} + ] + self.assertListEqual(channels_list, expected_channels_list) + mock_sleep.assert_called_once_with(60) # Default fallback + mock_logger.warning.assert_called_once_with( + "Slack API rate limit hit while fetching channels. Waiting for 60 seconds. 
Cursor: None" + ) + self.assertEqual(mock_client_instance.conversations_list.call_count, 2) + + @patch("surfsense_backend.app.connectors.slack_history.logger") + @patch("surfsense_backend.app.connectors.slack_history.time.sleep") + @patch("slack_sdk.WebClient") + def test_get_all_channels_other_slack_api_error( + self, mock_web_client, mock_sleep, mock_logger + ): + mock_client_instance = mock_web_client.return_value + + mock_error_response = Mock() + mock_error_response.status_code = 500 + mock_error_response.headers = {} + mock_error_response.data = {"ok": False, "error": "internal_error"} + + original_error = SlackApiError( + message="server error", response=mock_error_response + ) + mock_client_instance.conversations_list.side_effect = original_error + + slack_history = SlackHistory(token="fake_token") + with self.assertRaises(SlackApiError) as context: slack_history.get_all_channels(include_private=True) - + self.assertEqual(context.exception.response.status_code, 500) self.assertIn("server error", str(context.exception)) mock_sleep.assert_not_called() - mock_logger.warning.assert_not_called() # Ensure no rate limit log + mock_logger.warning.assert_not_called() # Ensure no rate limit log mock_client_instance.conversations_list.assert_called_once_with( types="public_channel,private_channel", cursor=None, limit=1000 ) - @patch('surfsense_backend.app.connectors.slack_history.logger') - @patch('surfsense_backend.app.connectors.slack_history.time.sleep') - @patch('slack_sdk.WebClient') - def test_get_all_channels_handles_missing_name_id_gracefully(self, MockWebClient, mock_sleep, mock_logger): - mock_client_instance = MockWebClient.return_value - + @patch("surfsense_backend.app.connectors.slack_history.logger") + @patch("surfsense_backend.app.connectors.slack_history.time.sleep") + @patch("slack_sdk.WebClient") + def test_get_all_channels_handles_missing_name_id_gracefully( + self, mock_web_client, mock_sleep, mock_logger + ): + mock_client_instance = mock_web_client.return_value + response_with_malformed_data = { "channels": [ - {"id": "C1_missing_name", "is_private": False, "is_member": True}, + {"id": "C1_missing_name", "is_private": False, "is_member": True}, {"name": "channel_missing_id", "is_private": False, "is_member": True}, - {"name": "general", "id": "C2_valid", "is_private": False, "is_member": True} + { + "name": "general", + "id": "C2_valid", + "is_private": False, + "is_member": True, + }, ], - "response_metadata": {"next_cursor": ""} + "response_metadata": {"next_cursor": ""}, } - - mock_client_instance.conversations_list.return_value = response_with_malformed_data - + + mock_client_instance.conversations_list.return_value = ( + response_with_malformed_data + ) + slack_history = SlackHistory(token="fake_token") channels_list = slack_history.get_all_channels(include_private=True) - - expected_channels_list = [ - {"id": "C2_valid", "name": "general", "is_private": False, "is_member": True} - ] - self.assertEqual(len(channels_list), 1) - self.assertListEqual(channels_list, expected_channels_list) - - self.assertEqual(mock_logger.warning.call_count, 2) - mock_logger.warning.assert_any_call("Channel found with missing name or id. Data: {'id': 'C1_missing_name', 'is_private': False, 'is_member': True}") - mock_logger.warning.assert_any_call("Channel found with missing name or id. 
Data: {'name': 'channel_missing_id', 'is_private': False, 'is_member': True}") - mock_sleep.assert_not_called() + expected_channels_list = [ + { + "id": "C2_valid", + "name": "general", + "is_private": False, + "is_member": True, + } + ] + self.assertEqual(len(channels_list), 1) + self.assertListEqual(channels_list, expected_channels_list) + + self.assertEqual(mock_logger.warning.call_count, 2) + mock_logger.warning.assert_any_call( + "Channel found with missing name or id. Data: {'id': 'C1_missing_name', 'is_private': False, 'is_member': True}" + ) + mock_logger.warning.assert_any_call( + "Channel found with missing name or id. Data: {'name': 'channel_missing_id', 'is_private': False, 'is_member': True}" + ) + + mock_sleep.assert_not_called() mock_client_instance.conversations_list.assert_called_once_with( types="public_channel,private_channel", cursor=None, limit=1000 ) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() -class TestSlackHistoryGetConversationHistory(unittest.TestCase): - @patch('surfsense_backend.app.connectors.slack_history.logger') - @patch('surfsense_backend.app.connectors.slack_history.time.sleep') - @patch('slack_sdk.WebClient') - def test_proactive_delay_single_page(self, MockWebClient, mock_time_sleep, mock_logger): - mock_client_instance = MockWebClient.return_value +class TestSlackHistoryGetConversationHistory(unittest.TestCase): + @patch("surfsense_backend.app.connectors.slack_history.logger") + @patch("surfsense_backend.app.connectors.slack_history.time.sleep") + @patch("slack_sdk.WebClient") + def test_proactive_delay_single_page( + self, mock_web_client, mock_time_sleep, mock_logger + ): + mock_client_instance = mock_web_client.return_value mock_client_instance.conversations_history.return_value = { "messages": [{"text": "msg1"}], - "has_more": False + "has_more": False, } - + slack_history = SlackHistory(token="fake_token") slack_history.get_conversation_history(channel_id="C123") - - mock_time_sleep.assert_called_once_with(1.2) # Proactive delay - @patch('surfsense_backend.app.connectors.slack_history.logger') - @patch('surfsense_backend.app.connectors.slack_history.time.sleep') - @patch('slack_sdk.WebClient') - def test_proactive_delay_multiple_pages(self, MockWebClient, mock_time_sleep, mock_logger): - mock_client_instance = MockWebClient.return_value + mock_time_sleep.assert_called_once_with(1.2) # Proactive delay + + @patch("surfsense_backend.app.connectors.slack_history.logger") + @patch("surfsense_backend.app.connectors.slack_history.time.sleep") + @patch("slack_sdk.WebClient") + def test_proactive_delay_multiple_pages( + self, mock_web_client, mock_time_sleep, mock_logger + ): + mock_client_instance = mock_web_client.return_value mock_client_instance.conversations_history.side_effect = [ { "messages": [{"text": "msg1"}], "has_more": True, - "response_metadata": {"next_cursor": "cursor1"} + "response_metadata": {"next_cursor": "cursor1"}, }, - { - "messages": [{"text": "msg2"}], - "has_more": False - } + {"messages": [{"text": "msg2"}], "has_more": False}, ] - + slack_history = SlackHistory(token="fake_token") slack_history.get_conversation_history(channel_id="C123") - + # Expected calls: 1.2 (page1), 1.2 (page2) self.assertEqual(mock_time_sleep.call_count, 2) mock_time_sleep.assert_has_calls([call(1.2), call(1.2)]) - @patch('surfsense_backend.app.connectors.slack_history.logger') - @patch('surfsense_backend.app.connectors.slack_history.time.sleep') - @patch('slack_sdk.WebClient') - def test_retry_after_logic(self, 
MockWebClient, mock_time_sleep, mock_logger): - mock_client_instance = MockWebClient.return_value - + @patch("surfsense_backend.app.connectors.slack_history.logger") + @patch("surfsense_backend.app.connectors.slack_history.time.sleep") + @patch("slack_sdk.WebClient") + def test_retry_after_logic(self, mock_web_client, mock_time_sleep, mock_logger): + mock_client_instance = mock_web_client.return_value + mock_error_response = Mock() mock_error_response.status_code = 429 - mock_error_response.headers = {'Retry-After': '5'} - + mock_error_response.headers = {"Retry-After": "5"} + mock_client_instance.conversations_history.side_effect = [ SlackApiError(message="ratelimited", response=mock_error_response), - {"messages": [{"text": "msg1"}], "has_more": False} + {"messages": [{"text": "msg1"}], "has_more": False}, ] - + slack_history = SlackHistory(token="fake_token") messages = slack_history.get_conversation_history(channel_id="C123") - + self.assertEqual(len(messages), 1) self.assertEqual(messages[0]["text"], "msg1") - - # Expected sleep calls: 1.2 (proactive for 1st attempt), 5 (rate limit), 1.2 (proactive for 2nd attempt) - mock_time_sleep.assert_has_calls([call(1.2), call(5), call(1.2)], any_order=False) - mock_logger.warning.assert_called_once() # Check that a warning was logged for rate limiting - @patch('surfsense_backend.app.connectors.slack_history.logger') - @patch('surfsense_backend.app.connectors.slack_history.time.sleep') - @patch('slack_sdk.WebClient') - def test_not_in_channel_error(self, MockWebClient, mock_time_sleep, mock_logger): - mock_client_instance = MockWebClient.return_value - + # Expected sleep calls: 1.2 (proactive for 1st attempt), 5 (rate limit), 1.2 (proactive for 2nd attempt) + mock_time_sleep.assert_has_calls( + [call(1.2), call(5), call(1.2)], any_order=False + ) + mock_logger.warning.assert_called_once() # Check that a warning was logged for rate limiting + + @patch("surfsense_backend.app.connectors.slack_history.logger") + @patch("surfsense_backend.app.connectors.slack_history.time.sleep") + @patch("slack_sdk.WebClient") + def test_not_in_channel_error(self, mock_web_client, mock_time_sleep, mock_logger): + mock_client_instance = mock_web_client.return_value + mock_error_response = Mock() - mock_error_response.status_code = 403 # Typical for not_in_channel, but data matters more - mock_error_response.data = {'ok': False, 'error': 'not_in_channel'} - + mock_error_response.status_code = ( + 403 # Typical for not_in_channel, but data matters more + ) + mock_error_response.data = {"ok": False, "error": "not_in_channel"} + # This error is now raised by the inner try-except, then caught by the outer one mock_client_instance.conversations_history.side_effect = SlackApiError( - message="not_in_channel error", - response=mock_error_response + message="not_in_channel error", response=mock_error_response ) - + slack_history = SlackHistory(token="fake_token") messages = slack_history.get_conversation_history(channel_id="C123") - + self.assertEqual(messages, []) mock_logger.warning.assert_called_with( "Bot is not in channel 'C123'. Cannot fetch history. Please add the bot to this channel." 
) - mock_time_sleep.assert_called_once_with(1.2) # Proactive delay before the API call + mock_time_sleep.assert_called_once_with( + 1.2 + ) # Proactive delay before the API call + + @patch("surfsense_backend.app.connectors.slack_history.logger") + @patch("surfsense_backend.app.connectors.slack_history.time.sleep") + @patch("slack_sdk.WebClient") + def test_other_slack_api_error_propagates( + self, mock_web_client, mock_time_sleep, mock_logger + ): + mock_client_instance = mock_web_client.return_value - @patch('surfsense_backend.app.connectors.slack_history.logger') - @patch('surfsense_backend.app.connectors.slack_history.time.sleep') - @patch('slack_sdk.WebClient') - def test_other_slack_api_error_propagates(self, MockWebClient, mock_time_sleep, mock_logger): - mock_client_instance = MockWebClient.return_value - mock_error_response = Mock() mock_error_response.status_code = 500 - mock_error_response.data = {'ok': False, 'error': 'internal_error'} - original_error = SlackApiError(message="server error", response=mock_error_response) + mock_error_response.data = {"ok": False, "error": "internal_error"} + original_error = SlackApiError( + message="server error", response=mock_error_response + ) mock_client_instance.conversations_history.side_effect = original_error - + slack_history = SlackHistory(token="fake_token") - + with self.assertRaises(SlackApiError) as context: slack_history.get_conversation_history(channel_id="C123") - - self.assertIn("Error retrieving history for channel C123", str(context.exception)) - self.assertIs(context.exception.response, mock_error_response) - mock_time_sleep.assert_called_once_with(1.2) # Proactive delay - @patch('surfsense_backend.app.connectors.slack_history.logger') - @patch('surfsense_backend.app.connectors.slack_history.time.sleep') - @patch('slack_sdk.WebClient') - def test_general_exception_propagates(self, MockWebClient, mock_time_sleep, mock_logger): - mock_client_instance = MockWebClient.return_value + self.assertIn( + "Error retrieving history for channel C123", str(context.exception) + ) + self.assertIs(context.exception.response, mock_error_response) + mock_time_sleep.assert_called_once_with(1.2) # Proactive delay + + @patch("surfsense_backend.app.connectors.slack_history.logger") + @patch("surfsense_backend.app.connectors.slack_history.time.sleep") + @patch("slack_sdk.WebClient") + def test_general_exception_propagates( + self, mock_web_client, mock_time_sleep, mock_logger + ): + mock_client_instance = mock_web_client.return_value original_error = Exception("Something broke") mock_client_instance.conversations_history.side_effect = original_error - + slack_history = SlackHistory(token="fake_token") - - with self.assertRaises(Exception) as context: # Check for generic Exception + + with self.assertRaises(Exception) as context: # Check for generic Exception slack_history.get_conversation_history(channel_id="C123") - - self.assertIs(context.exception, original_error) # Should re-raise the original error - mock_logger.error.assert_called_once_with("Unexpected error in get_conversation_history for channel C123: Something broke") - mock_time_sleep.assert_called_once_with(1.2) # Proactive delay + + self.assertIs( + context.exception, original_error + ) # Should re-raise the original error + mock_logger.error.assert_called_once_with( + "Unexpected error in get_conversation_history for channel C123: Something broke" + ) + mock_time_sleep.assert_called_once_with(1.2) # Proactive delay + class TestSlackHistoryGetUserInfo(unittest.TestCase): + 
@patch("surfsense_backend.app.connectors.slack_history.logger") + @patch("surfsense_backend.app.connectors.slack_history.time.sleep") + @patch("slack_sdk.WebClient") + def test_retry_after_logic(self, mock_web_client, mock_time_sleep, mock_logger): + mock_client_instance = mock_web_client.return_value - @patch('surfsense_backend.app.connectors.slack_history.logger') - @patch('surfsense_backend.app.connectors.slack_history.time.sleep') - @patch('slack_sdk.WebClient') - def test_retry_after_logic(self, MockWebClient, mock_time_sleep, mock_logger): - mock_client_instance = MockWebClient.return_value - mock_error_response = Mock() mock_error_response.status_code = 429 - mock_error_response.headers = {'Retry-After': '3'} # Using 3 seconds for test - + mock_error_response.headers = {"Retry-After": "3"} # Using 3 seconds for test + successful_user_data = {"id": "U123", "name": "testuser"} - + mock_client_instance.users_info.side_effect = [ SlackApiError(message="ratelimited_userinfo", response=mock_error_response), - {"user": successful_user_data} + {"user": successful_user_data}, ] - + slack_history = SlackHistory(token="fake_token") user_info = slack_history.get_user_info(user_id="U123") - + self.assertEqual(user_info, successful_user_data) - + # Assert that time.sleep was called for the rate limit mock_time_sleep.assert_called_once_with(3) mock_logger.warning.assert_called_once_with( @@ -375,46 +450,58 @@ class TestSlackHistoryGetUserInfo(unittest.TestCase): ) # Assert users_info was called twice (original + retry) self.assertEqual(mock_client_instance.users_info.call_count, 2) - mock_client_instance.users_info.assert_has_calls([call(user="U123"), call(user="U123")]) + mock_client_instance.users_info.assert_has_calls( + [call(user="U123"), call(user="U123")] + ) + + @patch("surfsense_backend.app.connectors.slack_history.logger") + @patch( + "surfsense_backend.app.connectors.slack_history.time.sleep" + ) # time.sleep might be called by other logic, but not expected here + @patch("slack_sdk.WebClient") + def test_other_slack_api_error_propagates( + self, mock_web_client, mock_time_sleep, mock_logger + ): + mock_client_instance = mock_web_client.return_value - @patch('surfsense_backend.app.connectors.slack_history.logger') - @patch('surfsense_backend.app.connectors.slack_history.time.sleep') # time.sleep might be called by other logic, but not expected here - @patch('slack_sdk.WebClient') - def test_other_slack_api_error_propagates(self, MockWebClient, mock_time_sleep, mock_logger): - mock_client_instance = MockWebClient.return_value - mock_error_response = Mock() - mock_error_response.status_code = 500 # Some other error - mock_error_response.data = {'ok': False, 'error': 'internal_server_error'} - original_error = SlackApiError(message="internal server error", response=mock_error_response) + mock_error_response.status_code = 500 # Some other error + mock_error_response.data = {"ok": False, "error": "internal_server_error"} + original_error = SlackApiError( + message="internal server error", response=mock_error_response + ) mock_client_instance.users_info.side_effect = original_error - + slack_history = SlackHistory(token="fake_token") - + with self.assertRaises(SlackApiError) as context: slack_history.get_user_info(user_id="U123") - + # Check that the raised error is the one we expect self.assertIn("Error retrieving user info for U123", str(context.exception)) self.assertIs(context.exception.response, mock_error_response) - mock_time_sleep.assert_not_called() # No rate limit sleep + 
mock_time_sleep.assert_not_called() # No rate limit sleep - @patch('surfsense_backend.app.connectors.slack_history.logger') - @patch('surfsense_backend.app.connectors.slack_history.time.sleep') - @patch('slack_sdk.WebClient') - def test_general_exception_propagates(self, MockWebClient, mock_time_sleep, mock_logger): - mock_client_instance = MockWebClient.return_value + @patch("surfsense_backend.app.connectors.slack_history.logger") + @patch("surfsense_backend.app.connectors.slack_history.time.sleep") + @patch("slack_sdk.WebClient") + def test_general_exception_propagates( + self, mock_web_client, mock_time_sleep, mock_logger + ): + mock_client_instance = mock_web_client.return_value original_error = Exception("A very generic problem") mock_client_instance.users_info.side_effect = original_error - + slack_history = SlackHistory(token="fake_token") - + with self.assertRaises(Exception) as context: slack_history.get_user_info(user_id="U123") - - self.assertIs(context.exception, original_error) # Check it's the exact same exception + + self.assertIs( + context.exception, original_error + ) # Check it's the exact same exception mock_logger.error.assert_called_once_with( "Unexpected error in get_user_info for user U123: A very generic problem" ) - mock_time_sleep.assert_not_called() # No rate limit sleep + mock_time_sleep.assert_not_called() # No rate limit sleep diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 7caf365..3d235d0 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -1,22 +1,21 @@ from collections.abc import AsyncGenerator -from datetime import datetime, timezone +from datetime import UTC, datetime from enum import Enum from fastapi import Depends - from pgvector.sqlalchemy import Vector from sqlalchemy import ( ARRAY, + JSON, + TIMESTAMP, Boolean, Column, Enum as SQLAlchemyEnum, ForeignKey, Integer, - JSON, String, Text, text, - TIMESTAMP ) from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine @@ -33,10 +32,7 @@ if config.AUTH_TYPE == "GOOGLE": SQLAlchemyUserDatabase, ) else: - from fastapi_users.db import ( - SQLAlchemyBaseUserTableUUID, - SQLAlchemyUserDatabase, - ) + from fastapi_users.db import SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase DATABASE_URL = config.DATABASE_URL @@ -52,8 +48,9 @@ class DocumentType(str, Enum): LINEAR_CONNECTOR = "LINEAR_CONNECTOR" DISCORD_CONNECTOR = "DISCORD_CONNECTOR" + class SearchSourceConnectorType(str, Enum): - SERPER_API = "SERPER_API" # NOT IMPLEMENTED YET : DON'T REMEMBER WHY : MOST PROBABLY BECAUSE WE NEED TO CRAWL THE RESULTS RETURNED BY IT + SERPER_API = "SERPER_API" # NOT IMPLEMENTED YET : DON'T REMEMBER WHY : MOST PROBABLY BECAUSE WE NEED TO CRAWL THE RESULTS RETURNED BY IT TAVILY_API = "TAVILY_API" LINKUP_API = "LINKUP_API" SLACK_CONNECTOR = "SLACK_CONNECTOR" @@ -61,13 +58,15 @@ class SearchSourceConnectorType(str, Enum): GITHUB_CONNECTOR = "GITHUB_CONNECTOR" LINEAR_CONNECTOR = "LINEAR_CONNECTOR" DISCORD_CONNECTOR = "DISCORD_CONNECTOR" - + + class ChatType(str, Enum): QNA = "QNA" REPORT_GENERAL = "REPORT_GENERAL" REPORT_DEEP = "REPORT_DEEP" REPORT_DEEPER = "REPORT_DEEPER" + class LiteLLMProvider(str, Enum): OPENAI = "OPENAI" ANTHROPIC = "ANTHROPIC" @@ -92,6 +91,7 @@ class LiteLLMProvider(str, Enum): PETALS = "PETALS" CUSTOM = "CUSTOM" + class LogLevel(str, Enum): DEBUG = "DEBUG" INFO = "INFO" @@ -99,18 +99,27 @@ class LogLevel(str, Enum): ERROR = "ERROR" CRITICAL = "CRITICAL" + class 
LogStatus(str, Enum): IN_PROGRESS = "IN_PROGRESS" SUCCESS = "SUCCESS" FAILED = "FAILED" - + + class Base(DeclarativeBase): pass + class TimestampMixin: @declared_attr - def created_at(cls): - return Column(TIMESTAMP(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc), index=True) + def created_at(cls): # noqa: N805 + return Column( + TIMESTAMP(timezone=True), + nullable=False, + default=lambda: datetime.now(UTC), + index=True, + ) + class BaseModel(Base): __abstract__ = True @@ -118,6 +127,7 @@ class BaseModel(Base): id = Column(Integer, primary_key=True, index=True) + class Chat(BaseModel, TimestampMixin): __tablename__ = "chats" @@ -125,73 +135,115 @@ class Chat(BaseModel, TimestampMixin): title = Column(String, nullable=False, index=True) initial_connectors = Column(ARRAY(String), nullable=True) messages = Column(JSON, nullable=False) - - search_space_id = Column(Integer, ForeignKey('searchspaces.id', ondelete='CASCADE'), nullable=False) - search_space = relationship('SearchSpace', back_populates='chats') + + search_space_id = Column( + Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False + ) + search_space = relationship("SearchSpace", back_populates="chats") + class Document(BaseModel, TimestampMixin): __tablename__ = "documents" - + title = Column(String, nullable=False, index=True) document_type = Column(SQLAlchemyEnum(DocumentType), nullable=False) document_metadata = Column(JSON, nullable=True) - + content = Column(Text, nullable=False) content_hash = Column(String, nullable=False, index=True, unique=True) embedding = Column(Vector(config.embedding_model_instance.dimension)) - - search_space_id = Column(Integer, ForeignKey("searchspaces.id", ondelete='CASCADE'), nullable=False) + + search_space_id = Column( + Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False + ) search_space = relationship("SearchSpace", back_populates="documents") - chunks = relationship("Chunk", back_populates="document", cascade="all, delete-orphan") + chunks = relationship( + "Chunk", back_populates="document", cascade="all, delete-orphan" + ) + class Chunk(BaseModel, TimestampMixin): __tablename__ = "chunks" - + content = Column(Text, nullable=False) embedding = Column(Vector(config.embedding_model_instance.dimension)) - - document_id = Column(Integer, ForeignKey("documents.id", ondelete='CASCADE'), nullable=False) + + document_id = Column( + Integer, ForeignKey("documents.id", ondelete="CASCADE"), nullable=False + ) document = relationship("Document", back_populates="chunks") + class Podcast(BaseModel, TimestampMixin): __tablename__ = "podcasts" - + title = Column(String, nullable=False, index=True) podcast_transcript = Column(JSON, nullable=False, default={}) file_location = Column(String(500), nullable=False, default="") - - search_space_id = Column(Integer, ForeignKey("searchspaces.id", ondelete='CASCADE'), nullable=False) + + search_space_id = Column( + Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False + ) search_space = relationship("SearchSpace", back_populates="podcasts") - + + class SearchSpace(BaseModel, TimestampMixin): __tablename__ = "searchspaces" - + name = Column(String(100), nullable=False, index=True) description = Column(String(500), nullable=True) - - user_id = Column(UUID(as_uuid=True), ForeignKey("user.id", ondelete='CASCADE'), nullable=False) + + user_id = Column( + UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False + ) user = relationship("User", 
back_populates="search_spaces") - - documents = relationship("Document", back_populates="search_space", order_by="Document.id", cascade="all, delete-orphan") - podcasts = relationship("Podcast", back_populates="search_space", order_by="Podcast.id", cascade="all, delete-orphan") - chats = relationship('Chat', back_populates='search_space', order_by='Chat.id', cascade="all, delete-orphan") - logs = relationship("Log", back_populates="search_space", order_by="Log.id", cascade="all, delete-orphan") - + + documents = relationship( + "Document", + back_populates="search_space", + order_by="Document.id", + cascade="all, delete-orphan", + ) + podcasts = relationship( + "Podcast", + back_populates="search_space", + order_by="Podcast.id", + cascade="all, delete-orphan", + ) + chats = relationship( + "Chat", + back_populates="search_space", + order_by="Chat.id", + cascade="all, delete-orphan", + ) + logs = relationship( + "Log", + back_populates="search_space", + order_by="Log.id", + cascade="all, delete-orphan", + ) + + class SearchSourceConnector(BaseModel, TimestampMixin): __tablename__ = "search_source_connectors" - + name = Column(String(100), nullable=False, index=True) - connector_type = Column(SQLAlchemyEnum(SearchSourceConnectorType), nullable=False, unique=True) + connector_type = Column( + SQLAlchemyEnum(SearchSourceConnectorType), nullable=False, unique=True + ) is_indexable = Column(Boolean, nullable=False, default=False) last_indexed_at = Column(TIMESTAMP(timezone=True), nullable=True) config = Column(JSON, nullable=False) - - user_id = Column(UUID(as_uuid=True), ForeignKey("user.id", ondelete='CASCADE'), nullable=False) + + user_id = Column( + UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False + ) user = relationship("User", back_populates="search_source_connectors") + class LLMConfig(BaseModel, TimestampMixin): __tablename__ = "llm_configs" - + name = Column(String(100), nullable=False, index=True) # Provider from the enum provider = Column(SQLAlchemyEnum(LiteLLMProvider), nullable=False) @@ -202,78 +254,141 @@ class LLMConfig(BaseModel, TimestampMixin): # API Key should be encrypted before storing api_key = Column(String, nullable=False) api_base = Column(String(500), nullable=True) - + # For any other parameters that litellm supports litellm_params = Column(JSON, nullable=True, default={}) - - user_id = Column(UUID(as_uuid=True), ForeignKey("user.id", ondelete='CASCADE'), nullable=False) + + user_id = Column( + UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False + ) user = relationship("User", back_populates="llm_configs", foreign_keys=[user_id]) + class Log(BaseModel, TimestampMixin): __tablename__ = "logs" - + level = Column(SQLAlchemyEnum(LogLevel), nullable=False, index=True) status = Column(SQLAlchemyEnum(LogStatus), nullable=False, index=True) message = Column(Text, nullable=False) - source = Column(String(200), nullable=True, index=True) # Service/component that generated the log + source = Column( + String(200), nullable=True, index=True + ) # Service/component that generated the log log_metadata = Column(JSON, nullable=True, default={}) # Additional context data - - search_space_id = Column(Integer, ForeignKey("searchspaces.id", ondelete='CASCADE'), nullable=False) + + search_space_id = Column( + Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False + ) search_space = relationship("SearchSpace", back_populates="logs") + if config.AUTH_TYPE == "GOOGLE": + class 
OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base): pass - class User(SQLAlchemyBaseUserTableUUID, Base): oauth_accounts: Mapped[list[OAuthAccount]] = relationship( "OAuthAccount", lazy="joined" ) search_spaces = relationship("SearchSpace", back_populates="user") - search_source_connectors = relationship("SearchSourceConnector", back_populates="user") - llm_configs = relationship("LLMConfig", back_populates="user", foreign_keys="LLMConfig.user_id", cascade="all, delete-orphan") + search_source_connectors = relationship( + "SearchSourceConnector", back_populates="user" + ) + llm_configs = relationship( + "LLMConfig", + back_populates="user", + foreign_keys="LLMConfig.user_id", + cascade="all, delete-orphan", + ) - long_context_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True) - fast_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True) - strategic_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True) + long_context_llm_id = Column( + Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True + ) + fast_llm_id = Column( + Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True + ) + strategic_llm_id = Column( + Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True + ) - long_context_llm = relationship("LLMConfig", foreign_keys=[long_context_llm_id], post_update=True) - fast_llm = relationship("LLMConfig", foreign_keys=[fast_llm_id], post_update=True) - strategic_llm = relationship("LLMConfig", foreign_keys=[strategic_llm_id], post_update=True) + long_context_llm = relationship( + "LLMConfig", foreign_keys=[long_context_llm_id], post_update=True + ) + fast_llm = relationship( + "LLMConfig", foreign_keys=[fast_llm_id], post_update=True + ) + strategic_llm = relationship( + "LLMConfig", foreign_keys=[strategic_llm_id], post_update=True + ) else: + class User(SQLAlchemyBaseUserTableUUID, Base): - search_spaces = relationship("SearchSpace", back_populates="user") - search_source_connectors = relationship("SearchSourceConnector", back_populates="user") - llm_configs = relationship("LLMConfig", back_populates="user", foreign_keys="LLMConfig.user_id", cascade="all, delete-orphan") + search_source_connectors = relationship( + "SearchSourceConnector", back_populates="user" + ) + llm_configs = relationship( + "LLMConfig", + back_populates="user", + foreign_keys="LLMConfig.user_id", + cascade="all, delete-orphan", + ) - long_context_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True) - fast_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True) - strategic_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True) + long_context_llm_id = Column( + Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True + ) + fast_llm_id = Column( + Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True + ) + strategic_llm_id = Column( + Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True + ) - long_context_llm = relationship("LLMConfig", foreign_keys=[long_context_llm_id], post_update=True) - fast_llm = relationship("LLMConfig", foreign_keys=[fast_llm_id], post_update=True) - strategic_llm = relationship("LLMConfig", foreign_keys=[strategic_llm_id], post_update=True) + long_context_llm = relationship( + "LLMConfig", foreign_keys=[long_context_llm_id], post_update=True + ) + fast_llm = 
relationship( + "LLMConfig", foreign_keys=[fast_llm_id], post_update=True + ) + strategic_llm = relationship( + "LLMConfig", foreign_keys=[strategic_llm_id], post_update=True + ) engine = create_async_engine(DATABASE_URL) async_session_maker = async_sessionmaker(engine, expire_on_commit=False) - + async def setup_indexes(): async with engine.begin() as conn: - # Create indexes + # Create indexes # Document Summary Indexes - await conn.execute(text('CREATE INDEX IF NOT EXISTS document_vector_index ON documents USING hnsw (embedding public.vector_cosine_ops)')) - await conn.execute(text('CREATE INDEX IF NOT EXISTS document_search_index ON documents USING gin (to_tsvector(\'english\', content))')) + await conn.execute( + text( + "CREATE INDEX IF NOT EXISTS document_vector_index ON documents USING hnsw (embedding public.vector_cosine_ops)" + ) + ) + await conn.execute( + text( + "CREATE INDEX IF NOT EXISTS document_search_index ON documents USING gin (to_tsvector('english', content))" + ) + ) # Document Chuck Indexes - await conn.execute(text('CREATE INDEX IF NOT EXISTS chucks_vector_index ON chunks USING hnsw (embedding public.vector_cosine_ops)')) - await conn.execute(text('CREATE INDEX IF NOT EXISTS chucks_search_index ON chunks USING gin (to_tsvector(\'english\', content))')) + await conn.execute( + text( + "CREATE INDEX IF NOT EXISTS chucks_vector_index ON chunks USING hnsw (embedding public.vector_cosine_ops)" + ) + ) + await conn.execute( + text( + "CREATE INDEX IF NOT EXISTS chucks_search_index ON chunks USING gin (to_tsvector('english', content))" + ) + ) + async def create_db_and_tables(): async with engine.begin() as conn: - await conn.execute(text('CREATE EXTENSION IF NOT EXISTS vector')) + await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector")) await conn.run_sync(Base.metadata.create_all) await setup_indexes() @@ -284,14 +399,22 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]: if config.AUTH_TYPE == "GOOGLE": + async def get_user_db(session: AsyncSession = Depends(get_async_session)): yield SQLAlchemyUserDatabase(session, User, OAuthAccount) else: + async def get_user_db(session: AsyncSession = Depends(get_async_session)): yield SQLAlchemyUserDatabase(session, User) - -async def get_chucks_hybrid_search_retriever(session: AsyncSession = Depends(get_async_session)): + + +async def get_chucks_hybrid_search_retriever( + session: AsyncSession = Depends(get_async_session), +): return ChucksHybridSearchRetriever(session) -async def get_documents_hybrid_search_retriever(session: AsyncSession = Depends(get_async_session)): + +async def get_documents_hybrid_search_retriever( + session: AsyncSession = Depends(get_async_session), +): return DocumentHybridSearchRetriever(session) diff --git a/surfsense_backend/app/prompts/__init__.py b/surfsense_backend/app/prompts/__init__.py index 6239d48..3b21cb9 100644 --- a/surfsense_backend/app/prompts/__init__.py +++ b/surfsense_backend/app/prompts/__init__.py @@ -1,9 +1,12 @@ +from datetime import UTC, datetime + from langchain_core.prompts.prompt import PromptTemplate -from datetime import datetime, timezone -DATE_TODAY = "Today's date is " + datetime.now(timezone.utc).astimezone().isoformat() + '\n' +DATE_TODAY = "Today's date is " + datetime.now(UTC).astimezone().isoformat() + "\n" -SUMMARY_PROMPT = DATE_TODAY + """ +SUMMARY_PROMPT = ( + DATE_TODAY + + """ You are an expert document analyst and summarization specialist tasked with distilling complex information into clear, @@ -96,8 +99,8 @@ SUMMARY_PROMPT = 
DATE_TODAY + """ """ +) SUMMARY_PROMPT_TEMPLATE = PromptTemplate( - input_variables=["document"], - template=SUMMARY_PROMPT -) \ No newline at end of file + input_variables=["document"], template=SUMMARY_PROMPT +) diff --git a/surfsense_backend/app/retriver/chunks_hybrid_search.py b/surfsense_backend/app/retriver/chunks_hybrid_search.py index b3e75e4..cb96ac6 100644 --- a/surfsense_backend/app/retriver/chunks_hybrid_search.py +++ b/surfsense_backend/app/retriver/chunks_hybrid_search.py @@ -2,34 +2,41 @@ class ChucksHybridSearchRetriever: def __init__(self, db_session): """ Initialize the hybrid search retriever with a database session. - + Args: db_session: SQLAlchemy AsyncSession from FastAPI dependency injection """ self.db_session = db_session - async def vector_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None) -> list: + async def vector_search( + self, + query_text: str, + top_k: int, + user_id: str, + search_space_id: int | None = None, + ) -> list: """ Perform vector similarity search on chunks. - + Args: query_text: The search query text top_k: Number of results to return user_id: The ID of the user performing the search search_space_id: Optional search space ID to filter results - + Returns: List of chunks sorted by vector similarity """ - from sqlalchemy import select, func + from sqlalchemy import select from sqlalchemy.orm import joinedload - from app.db import Chunk, Document, SearchSpace + from app.config import config - + from app.db import Chunk, Document, SearchSpace + # Get embedding for the query embedding_model = config.embedding_model_instance query_embedding = embedding_model.embed(query_text) - + # Build the base query with user ownership check query = ( select(Chunk) @@ -38,45 +45,48 @@ class ChucksHybridSearchRetriever: .join(SearchSpace, Document.search_space_id == SearchSpace.id) .where(SearchSpace.user_id == user_id) ) - + # Add search space filter if provided if search_space_id is not None: query = query.where(Document.search_space_id == search_space_id) - + # Add vector similarity ordering - query = ( - query - .order_by(Chunk.embedding.op("<=>")(query_embedding)) - .limit(top_k) - ) - + query = query.order_by(Chunk.embedding.op("<=>")(query_embedding)).limit(top_k) + # Execute the query result = await self.db_session.execute(query) chunks = result.scalars().all() - + return chunks - async def full_text_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None) -> list: + async def full_text_search( + self, + query_text: str, + top_k: int, + user_id: str, + search_space_id: int | None = None, + ) -> list: """ Perform full-text keyword search on chunks. 
- + Args: query_text: The search query text top_k: Number of results to return user_id: The ID of the user performing the search search_space_id: Optional search space ID to filter results - + Returns: List of chunks sorted by text relevance """ - from sqlalchemy import select, func, text + from sqlalchemy import func, select from sqlalchemy.orm import joinedload + from app.db import Chunk, Document, SearchSpace - + # Create tsvector and tsquery for PostgreSQL full-text search - tsvector = func.to_tsvector('english', Chunk.content) - tsquery = func.plainto_tsquery('english', query_text) - + tsvector = func.to_tsvector("english", Chunk.content) + tsquery = func.plainto_tsquery("english", query_text) + # Build the base query with user ownership check query = ( select(Chunk) @@ -84,64 +94,70 @@ class ChucksHybridSearchRetriever: .join(Document, Chunk.document_id == Document.id) .join(SearchSpace, Document.search_space_id == SearchSpace.id) .where(SearchSpace.user_id == user_id) - .where(tsvector.op("@@")(tsquery)) # Only include results that match the query + .where( + tsvector.op("@@")(tsquery) + ) # Only include results that match the query ) - + # Add search space filter if provided if search_space_id is not None: query = query.where(Document.search_space_id == search_space_id) - + # Add text search ranking - query = ( - query - .order_by(func.ts_rank_cd(tsvector, tsquery).desc()) - .limit(top_k) - ) - + query = query.order_by(func.ts_rank_cd(tsvector, tsquery).desc()).limit(top_k) + # Execute the query result = await self.db_session.execute(query) chunks = result.scalars().all() - + return chunks - async def hybrid_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None, document_type: str = None) -> list: + async def hybrid_search( + self, + query_text: str, + top_k: int, + user_id: str, + search_space_id: int | None = None, + document_type: str | None = None, + ) -> list: """ Combine vector similarity and full-text search results using Reciprocal Rank Fusion. 
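The full_text_search() rewrite above builds a to_tsvector/plainto_tsquery pair, keeps only rows matching via the "@@" operator, and ranks them with ts_rank_cd(). A minimal sketch of the same Postgres full-text pattern against a hypothetical table (names are placeholders):

# Sketch of the tsvector/tsquery ranking pattern; table and columns are hypothetical.
from sqlalchemy import Column, Integer, MetaData, Table, Text, func, select

metadata = MetaData()
chunks_demo = Table(
    "chunks_demo",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("content", Text),
)

def keyword_search_stmt(query_text: str, top_k: int = 10):
    tsvector = func.to_tsvector("english", chunks_demo.c.content)
    tsquery = func.plainto_tsquery("english", query_text)
    return (
        select(chunks_demo)
        .where(tsvector.op("@@")(tsquery))  # keep only rows that match the query
        .order_by(func.ts_rank_cd(tsvector, tsquery).desc())  # best matches first
        .limit(top_k)
    )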
- + Args: query_text: The search query text top_k: Number of results to return user_id: The ID of the user performing the search search_space_id: Optional search space ID to filter results document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL") - + Returns: List of dictionaries containing chunk data and relevance scores """ - from sqlalchemy import select, func, text + from sqlalchemy import func, select, text from sqlalchemy.orm import joinedload - from app.db import Chunk, Document, SearchSpace, DocumentType + from app.config import config - + from app.db import Chunk, Document, DocumentType, SearchSpace + # Get embedding for the query embedding_model = config.embedding_model_instance query_embedding = embedding_model.embed(query_text) - + # Constants for RRF calculation k = 60 # Constant for RRF calculation n_results = top_k * 2 # Get more results for better fusion - + # Create tsvector and tsquery for PostgreSQL full-text search - tsvector = func.to_tsvector('english', Chunk.content) - tsquery = func.plainto_tsquery('english', query_text) - + tsvector = func.to_tsvector("english", Chunk.content) + tsquery = func.plainto_tsquery("english", query_text) + # Base conditions for document filtering base_conditions = [SearchSpace.user_id == user_id] - + # Add search space filter if provided if search_space_id is not None: base_conditions.append(Document.search_space_id == search_space_id) - + # Add document type filter if provided if document_type is not None: # Convert string to enum value if needed @@ -154,90 +170,97 @@ class ChucksHybridSearchRetriever: return [] else: base_conditions.append(Document.document_type == document_type) - + # CTE for semantic search with user ownership check semantic_search_cte = ( select( Chunk.id, - func.rank().over(order_by=Chunk.embedding.op("<=>")(query_embedding)).label("rank") + func.rank() + .over(order_by=Chunk.embedding.op("<=>")(query_embedding)) + .label("rank"), ) .join(Document, Chunk.document_id == Document.id) .join(SearchSpace, Document.search_space_id == SearchSpace.id) .where(*base_conditions) ) - + semantic_search_cte = ( - semantic_search_cte - .order_by(Chunk.embedding.op("<=>")(query_embedding)) + semantic_search_cte.order_by(Chunk.embedding.op("<=>")(query_embedding)) .limit(n_results) .cte("semantic_search") ) - + # CTE for keyword search with user ownership check keyword_search_cte = ( select( Chunk.id, - func.rank().over(order_by=func.ts_rank_cd(tsvector, tsquery).desc()).label("rank") + func.rank() + .over(order_by=func.ts_rank_cd(tsvector, tsquery).desc()) + .label("rank"), ) .join(Document, Chunk.document_id == Document.id) .join(SearchSpace, Document.search_space_id == SearchSpace.id) .where(*base_conditions) .where(tsvector.op("@@")(tsquery)) ) - + keyword_search_cte = ( - keyword_search_cte - .order_by(func.ts_rank_cd(tsvector, tsquery).desc()) + keyword_search_cte.order_by(func.ts_rank_cd(tsvector, tsquery).desc()) .limit(n_results) .cte("keyword_search") ) - + # Final combined query using a FULL OUTER JOIN with RRF scoring final_query = ( select( Chunk, ( - func.coalesce(1.0 / (k + semantic_search_cte.c.rank), 0.0) + - func.coalesce(1.0 / (k + keyword_search_cte.c.rank), 0.0) - ).label("score") + func.coalesce(1.0 / (k + semantic_search_cte.c.rank), 0.0) + + func.coalesce(1.0 / (k + keyword_search_cte.c.rank), 0.0) + ).label("score"), ) .select_from( semantic_search_cte.outerjoin( - keyword_search_cte, + keyword_search_cte, semantic_search_cte.c.id == keyword_search_cte.c.id, - full=True + 
full=True, ) ) .join( Chunk, - Chunk.id == func.coalesce(semantic_search_cte.c.id, keyword_search_cte.c.id) + Chunk.id + == func.coalesce(semantic_search_cte.c.id, keyword_search_cte.c.id), ) .options(joinedload(Chunk.document)) .order_by(text("score DESC")) .limit(top_k) ) - + # Execute the query result = await self.db_session.execute(final_query) chunks_with_scores = result.all() - + # If no results were found, return an empty list if not chunks_with_scores: return [] - + # Convert to serializable dictionaries if no reranker is available or if reranking failed serialized_results = [] for chunk, score in chunks_with_scores: - serialized_results.append({ - "chunk_id": chunk.id, - "content": chunk.content, - "score": float(score), # Ensure score is a Python float - "document": { - "id": chunk.document.id, - "title": chunk.document.title, - "document_type": chunk.document.document_type.value if hasattr(chunk.document, 'document_type') else None, - "metadata": chunk.document.document_metadata + serialized_results.append( + { + "chunk_id": chunk.id, + "content": chunk.content, + "score": float(score), # Ensure score is a Python float + "document": { + "id": chunk.document.id, + "title": chunk.document.title, + "document_type": chunk.document.document_type.value + if hasattr(chunk.document, "document_type") + else None, + "metadata": chunk.document.document_metadata, + }, } - }) - + ) + return serialized_results diff --git a/surfsense_backend/app/retriver/documents_hybrid_search.py b/surfsense_backend/app/retriver/documents_hybrid_search.py index 2163635..a9bdb29 100644 --- a/surfsense_backend/app/retriver/documents_hybrid_search.py +++ b/surfsense_backend/app/retriver/documents_hybrid_search.py @@ -2,34 +2,41 @@ class DocumentHybridSearchRetriever: def __init__(self, db_session): """ Initialize the hybrid search retriever with a database session. - + Args: db_session: SQLAlchemy AsyncSession from FastAPI dependency injection """ self.db_session = db_session - async def vector_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None) -> list: + async def vector_search( + self, + query_text: str, + top_k: int, + user_id: str, + search_space_id: int | None = None, + ) -> list: """ Perform vector similarity search on documents. 
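The chunk-level hybrid_search() above fuses the semantic and keyword CTE rankings with Reciprocal Rank Fusion: each list contributes 1 / (k + rank) per result, with k = 60, and COALESCE supplies 0 when a row appears in only one list. The same arithmetic in plain Python, with hypothetical chunk ids, may make the SQL easier to follow:

# Sketch of the RRF scoring that the FULL OUTER JOIN query computes in SQL.
def rrf_fuse(
    semantic_ids: list[int], keyword_ids: list[int], k: int = 60
) -> list[tuple[int, float]]:
    scores: dict[int, float] = {}
    for rank, chunk_id in enumerate(semantic_ids, start=1):
        scores[chunk_id] = scores.get(chunk_id, 0.0) + 1.0 / (k + rank)
    for rank, chunk_id in enumerate(keyword_ids, start=1):
        scores[chunk_id] = scores.get(chunk_id, 0.0) + 1.0 / (k + rank)
    # A result missing from one list simply gets no contribution from it,
    # mirroring COALESCE(..., 0.0) in the SQL.
    return sorted(scores.items(), key=lambda item: item[1], reverse=True)

# Example: id 7 ranked 1st semantically and 3rd by keyword scores
# 1/61 + 1/63 ≈ 0.0323, so it tops the fused list.
print(rrf_fuse([7, 2, 9], [4, 2, 7])[:3])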
- + Args: query_text: The search query text top_k: Number of results to return user_id: The ID of the user performing the search search_space_id: Optional search space ID to filter results - + Returns: List of documents sorted by vector similarity """ - from sqlalchemy import select, func + from sqlalchemy import select from sqlalchemy.orm import joinedload - from app.db import Document, SearchSpace + from app.config import config - + from app.db import Document, SearchSpace + # Get embedding for the query embedding_model = config.embedding_model_instance query_embedding = embedding_model.embed(query_text) - + # Build the base query with user ownership check query = ( select(Document) @@ -37,107 +44,118 @@ class DocumentHybridSearchRetriever: .join(SearchSpace, Document.search_space_id == SearchSpace.id) .where(SearchSpace.user_id == user_id) ) - + # Add search space filter if provided if search_space_id is not None: query = query.where(Document.search_space_id == search_space_id) - + # Add vector similarity ordering - query = ( - query - .order_by(Document.embedding.op("<=>")(query_embedding)) - .limit(top_k) + query = query.order_by(Document.embedding.op("<=>")(query_embedding)).limit( + top_k ) - + # Execute the query result = await self.db_session.execute(query) documents = result.scalars().all() - + return documents - async def full_text_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None) -> list: + async def full_text_search( + self, + query_text: str, + top_k: int, + user_id: str, + search_space_id: int | None = None, + ) -> list: """ Perform full-text keyword search on documents. - + Args: query_text: The search query text top_k: Number of results to return user_id: The ID of the user performing the search search_space_id: Optional search space ID to filter results - + Returns: List of documents sorted by text relevance """ - from sqlalchemy import select, func, text + from sqlalchemy import func, select from sqlalchemy.orm import joinedload + from app.db import Document, SearchSpace - + # Create tsvector and tsquery for PostgreSQL full-text search - tsvector = func.to_tsvector('english', Document.content) - tsquery = func.plainto_tsquery('english', query_text) - + tsvector = func.to_tsvector("english", Document.content) + tsquery = func.plainto_tsquery("english", query_text) + # Build the base query with user ownership check query = ( select(Document) .options(joinedload(Document.search_space)) .join(SearchSpace, Document.search_space_id == SearchSpace.id) .where(SearchSpace.user_id == user_id) - .where(tsvector.op("@@")(tsquery)) # Only include results that match the query + .where( + tsvector.op("@@")(tsquery) + ) # Only include results that match the query ) - + # Add search space filter if provided if search_space_id is not None: query = query.where(Document.search_space_id == search_space_id) - + # Add text search ranking - query = ( - query - .order_by(func.ts_rank_cd(tsvector, tsquery).desc()) - .limit(top_k) - ) - + query = query.order_by(func.ts_rank_cd(tsvector, tsquery).desc()).limit(top_k) + # Execute the query result = await self.db_session.execute(query) documents = result.scalars().all() - + return documents - async def hybrid_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None, document_type: str = None) -> list: + async def hybrid_search( + self, + query_text: str, + top_k: int, + user_id: str, + search_space_id: int | None = None, + document_type: str | None = None, + ) -> list: """ Combine vector 
similarity and full-text search results using Reciprocal Rank Fusion. - + Args: query_text: The search query text top_k: Number of results to return user_id: The ID of the user performing the search search_space_id: Optional search space ID to filter results document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL") - + """ - from sqlalchemy import select, func, text + from sqlalchemy import func, select, text from sqlalchemy.orm import joinedload - from app.db import Document, SearchSpace, DocumentType + from app.config import config - + from app.db import Document, DocumentType, SearchSpace + # Get embedding for the query embedding_model = config.embedding_model_instance query_embedding = embedding_model.embed(query_text) - + # Constants for RRF calculation k = 60 # Constant for RRF calculation n_results = top_k * 2 # Get more results for better fusion - + # Create tsvector and tsquery for PostgreSQL full-text search - tsvector = func.to_tsvector('english', Document.content) - tsquery = func.plainto_tsquery('english', query_text) - + tsvector = func.to_tsvector("english", Document.content) + tsquery = func.plainto_tsquery("english", query_text) + # Base conditions for document filtering base_conditions = [SearchSpace.user_id == user_id] - + # Add search space filter if provided if search_space_id is not None: base_conditions.append(Document.search_space_id == search_space_id) - + # Add document type filter if provided if document_type is not None: # Convert string to enum value if needed @@ -150,98 +168,112 @@ class DocumentHybridSearchRetriever: return [] else: base_conditions.append(Document.document_type == document_type) - + # CTE for semantic search with user ownership check semantic_search_cte = ( select( Document.id, - func.rank().over(order_by=Document.embedding.op("<=>")(query_embedding)).label("rank") + func.rank() + .over(order_by=Document.embedding.op("<=>")(query_embedding)) + .label("rank"), ) .join(SearchSpace, Document.search_space_id == SearchSpace.id) .where(*base_conditions) ) - + semantic_search_cte = ( - semantic_search_cte - .order_by(Document.embedding.op("<=>")(query_embedding)) + semantic_search_cte.order_by(Document.embedding.op("<=>")(query_embedding)) .limit(n_results) .cte("semantic_search") ) - + # CTE for keyword search with user ownership check keyword_search_cte = ( select( Document.id, - func.rank().over(order_by=func.ts_rank_cd(tsvector, tsquery).desc()).label("rank") + func.rank() + .over(order_by=func.ts_rank_cd(tsvector, tsquery).desc()) + .label("rank"), ) .join(SearchSpace, Document.search_space_id == SearchSpace.id) .where(*base_conditions) .where(tsvector.op("@@")(tsquery)) ) - + keyword_search_cte = ( - keyword_search_cte - .order_by(func.ts_rank_cd(tsvector, tsquery).desc()) + keyword_search_cte.order_by(func.ts_rank_cd(tsvector, tsquery).desc()) .limit(n_results) .cte("keyword_search") ) - + # Final combined query using a FULL OUTER JOIN with RRF scoring final_query = ( select( Document, ( - func.coalesce(1.0 / (k + semantic_search_cte.c.rank), 0.0) + - func.coalesce(1.0 / (k + keyword_search_cte.c.rank), 0.0) - ).label("score") + func.coalesce(1.0 / (k + semantic_search_cte.c.rank), 0.0) + + func.coalesce(1.0 / (k + keyword_search_cte.c.rank), 0.0) + ).label("score"), ) .select_from( semantic_search_cte.outerjoin( - keyword_search_cte, + keyword_search_cte, semantic_search_cte.c.id == keyword_search_cte.c.id, - full=True + full=True, ) ) .join( Document, - Document.id == func.coalesce(semantic_search_cte.c.id, 
keyword_search_cte.c.id) + Document.id + == func.coalesce(semantic_search_cte.c.id, keyword_search_cte.c.id), ) .options(joinedload(Document.search_space)) .order_by(text("score DESC")) .limit(top_k) ) - + # Execute the query result = await self.db_session.execute(final_query) documents_with_scores = result.all() - + # If no results were found, return an empty list if not documents_with_scores: return [] - + # Convert to serializable dictionaries serialized_results = [] for document, score in documents_with_scores: # Fetch associated chunks for this document from sqlalchemy import select + from app.db import Chunk - - chunks_query = select(Chunk).where(Chunk.document_id == document.id).order_by(Chunk.id) + + chunks_query = ( + select(Chunk).where(Chunk.document_id == document.id).order_by(Chunk.id) + ) chunks_result = await self.db_session.execute(chunks_query) chunks = chunks_result.scalars().all() - + # Concatenate chunks content - concatenated_chunks_content = " ".join([chunk.content for chunk in chunks]) if chunks else document.content - - serialized_results.append({ - "document_id": document.id, - "title": document.title, - "content": document.content, - "chunks_content": concatenated_chunks_content, - "document_type": document.document_type.value if hasattr(document, 'document_type') else None, - "metadata": document.document_metadata, - "score": float(score), # Ensure score is a Python float - "search_space_id": document.search_space_id - }) - - return serialized_results \ No newline at end of file + concatenated_chunks_content = ( + " ".join([chunk.content for chunk in chunks]) + if chunks + else document.content + ) + + serialized_results.append( + { + "document_id": document.id, + "title": document.title, + "content": document.content, + "chunks_content": concatenated_chunks_content, + "document_type": document.document_type.value + if hasattr(document, "document_type") + else None, + "metadata": document.document_metadata, + "score": float(score), # Ensure score is a Python float + "search_space_id": document.search_space_id, + } + ) + + return serialized_results diff --git a/surfsense_backend/app/routes/__init__.py b/surfsense_backend/app/routes/__init__.py index dd6be9b..91c41ee 100644 --- a/surfsense_backend/app/routes/__init__.py +++ b/surfsense_backend/app/routes/__init__.py @@ -1,11 +1,12 @@ from fastapi import APIRouter -from .search_spaces_routes import router as search_spaces_router -from .documents_routes import router as documents_router -from .podcasts_routes import router as podcasts_router + from .chats_routes import router as chats_router -from .search_source_connectors_routes import router as search_source_connectors_router +from .documents_routes import router as documents_router from .llm_config_routes import router as llm_config_router from .logs_routes import router as logs_router +from .podcasts_routes import router as podcasts_router +from .search_source_connectors_routes import router as search_source_connectors_router +from .search_spaces_routes import router as search_spaces_router router = APIRouter() diff --git a/surfsense_backend/app/routes/chats_routes.py b/surfsense_backend/app/routes/chats_routes.py index dc7c126..e01b857 100644 --- a/surfsense_backend/app/routes/chats_routes.py +++ b/surfsense_backend/app/routes/chats_routes.py @@ -1,38 +1,40 @@ -from typing import List +from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import StreamingResponse +from langchain.schema import AIMessage, HumanMessage +from sqlalchemy.exc 
import IntegrityError, OperationalError +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select from app.db import Chat, SearchSpace, User, get_async_session from app.schemas import AISDKChatRequest, ChatCreate, ChatRead, ChatUpdate from app.tasks.stream_connector_search_results import stream_connector_search_results from app.users import current_active_user from app.utils.check_ownership import check_ownership -from fastapi import APIRouter, Depends, HTTPException -from fastapi.responses import StreamingResponse -from sqlalchemy.exc import IntegrityError, OperationalError -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.future import select -from langchain.schema import HumanMessage, AIMessage - router = APIRouter() + @router.post("/chat") async def handle_chat_data( request: AISDKChatRequest, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): messages = request.messages - if messages[-1]['role'] != "user": + if messages[-1]["role"] != "user": raise HTTPException( - status_code=400, detail="Last message must be a user message") + status_code=400, detail="Last message must be a user message" + ) - user_query = messages[-1]['content'] - search_space_id = request.data.get('search_space_id') - research_mode: str = request.data.get('research_mode') - selected_connectors: List[str] = request.data.get('selected_connectors') - document_ids_to_add_in_context: List[int] = request.data.get('document_ids_to_add_in_context') - - search_mode_str = request.data.get('search_mode', "CHUNKS") + user_query = messages[-1]["content"] + search_space_id = request.data.get("search_space_id") + research_mode: str = request.data.get("research_mode") + selected_connectors: list[str] = request.data.get("selected_connectors") + document_ids_to_add_in_context: list[int] = request.data.get( + "document_ids_to_add_in_context" + ) + + search_mode_str = request.data.get("search_mode", "CHUNKS") # Convert search_space_id to integer if it's a string if search_space_id and isinstance(search_space_id, str): @@ -40,21 +42,23 @@ async def handle_chat_data( search_space_id = int(search_space_id) except ValueError: raise HTTPException( - status_code=400, detail="Invalid search_space_id format") + status_code=400, detail="Invalid search_space_id format" + ) from None # Check if the search space belongs to the current user try: await check_ownership(session, SearchSpace, search_space_id, user) except HTTPException: raise HTTPException( - status_code=403, detail="You don't have access to this search space") - + status_code=403, detail="You don't have access to this search space" + ) from None + langchain_chat_history = [] for message in messages[:-1]: - if message['role'] == "user": - langchain_chat_history.append(HumanMessage(content=message['content'])) - elif message['role'] == "assistant": - langchain_chat_history.append(AIMessage(content=message['content'])) + if message["role"] == "user": + langchain_chat_history.append(HumanMessage(content=message["content"])) + elif message["role"] == "assistant": + langchain_chat_history.append(AIMessage(content=message["content"])) response = StreamingResponse( stream_connector_search_results( @@ -69,7 +73,7 @@ async def handle_chat_data( document_ids_to_add_in_context, ) ) - + response.headers["x-vercel-ai-data-stream"] = "v1" return response @@ -78,7 +82,7 @@ async def handle_chat_data( async def create_chat( chat: ChatCreate, session: AsyncSession 
= Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): try: await check_ownership(session, SearchSpace, chat.search_space_id, user) @@ -89,52 +93,57 @@ async def create_chat( return db_chat except HTTPException: raise - except IntegrityError as e: + except IntegrityError: await session.rollback() raise HTTPException( - status_code=400, detail="Database constraint violation. Please check your input data.") - except OperationalError as e: + status_code=400, + detail="Database constraint violation. Please check your input data.", + ) from None + except OperationalError: await session.rollback() raise HTTPException( - status_code=503, detail="Database operation failed. Please try again later.") - except Exception as e: + status_code=503, detail="Database operation failed. Please try again later." + ) from None + except Exception: await session.rollback() raise HTTPException( - status_code=500, detail="An unexpected error occurred while creating the chat.") + status_code=500, + detail="An unexpected error occurred while creating the chat.", + ) from None -@router.get("/chats/", response_model=List[ChatRead]) +@router.get("/chats/", response_model=list[ChatRead]) async def read_chats( skip: int = 0, limit: int = 100, - search_space_id: int = None, + search_space_id: int | None = None, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): try: query = select(Chat).join(SearchSpace).filter(SearchSpace.user_id == user.id) - + # Filter by search_space_id if provided if search_space_id is not None: query = query.filter(Chat.search_space_id == search_space_id) - - result = await session.execute( - query.offset(skip).limit(limit) - ) + + result = await session.execute(query.offset(skip).limit(limit)) return result.scalars().all() except OperationalError: raise HTTPException( - status_code=503, detail="Database operation failed. Please try again later.") + status_code=503, detail="Database operation failed. Please try again later." + ) from None except Exception: raise HTTPException( - status_code=500, detail="An unexpected error occurred while fetching chats.") + status_code=500, detail="An unexpected error occurred while fetching chats." + ) from None @router.get("/chats/{chat_id}", response_model=ChatRead) async def read_chat( chat_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): try: result = await session.execute( @@ -145,14 +154,19 @@ async def read_chat( chat = result.scalars().first() if not chat: raise HTTPException( - status_code=404, detail="Chat not found or you don't have permission to access it") + status_code=404, + detail="Chat not found or you don't have permission to access it", + ) return chat except OperationalError: raise HTTPException( - status_code=503, detail="Database operation failed. Please try again later.") + status_code=503, detail="Database operation failed. Please try again later." 
+ ) from None except Exception: raise HTTPException( - status_code=500, detail="An unexpected error occurred while fetching the chat.") + status_code=500, + detail="An unexpected error occurred while fetching the chat.", + ) from None @router.put("/chats/{chat_id}", response_model=ChatRead) @@ -160,7 +174,7 @@ async def update_chat( chat_id: int, chat_update: ChatUpdate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): try: db_chat = await read_chat(chat_id, session, user) @@ -175,22 +189,27 @@ async def update_chat( except IntegrityError: await session.rollback() raise HTTPException( - status_code=400, detail="Database constraint violation. Please check your input data.") + status_code=400, + detail="Database constraint violation. Please check your input data.", + ) from None except OperationalError: await session.rollback() raise HTTPException( - status_code=503, detail="Database operation failed. Please try again later.") + status_code=503, detail="Database operation failed. Please try again later." + ) from None except Exception: await session.rollback() raise HTTPException( - status_code=500, detail="An unexpected error occurred while updating the chat.") + status_code=500, + detail="An unexpected error occurred while updating the chat.", + ) from None @router.delete("/chats/{chat_id}", response_model=dict) async def delete_chat( chat_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): try: db_chat = await read_chat(chat_id, session, user) @@ -202,81 +221,16 @@ async def delete_chat( except IntegrityError: await session.rollback() raise HTTPException( - status_code=400, detail="Cannot delete chat due to existing dependencies.") + status_code=400, detail="Cannot delete chat due to existing dependencies." + ) from None except OperationalError: await session.rollback() raise HTTPException( - status_code=503, detail="Database operation failed. Please try again later.") + status_code=503, detail="Database operation failed. Please try again later." + ) from None except Exception: await session.rollback() raise HTTPException( - status_code=500, detail="An unexpected error occurred while deleting the chat.") - - -# test_data = [ -# { -# "type": "TERMINAL_INFO", -# "content": [ -# { -# "id": 1, -# "text": "Starting to search for crawled URLs...", -# "type": "info" -# }, -# { -# "id": 2, -# "text": "Found 2 relevant crawled URLs", -# "type": "success" -# } -# ] -# }, -# { -# "type": "SOURCES", -# "content": [ -# { -# "id": 1, -# "name": "Crawled URLs", -# "type": "CRAWLED_URL", -# "sources": [ -# { -# "id": 1, -# "title": "Webpage Title", -# "description": "Webpage Dec", -# "url": "https://jsoneditoronline.org/" -# }, -# { -# "id": 2, -# "title": "Webpage Title", -# "description": "Webpage Dec", -# "url": "https://www.google.com/" -# } -# ] -# }, -# { -# "id": 2, -# "name": "Files", -# "type": "FILE", -# "sources": [ -# { -# "id": 3, -# "title": "Webpage Title", -# "description": "Webpage Dec", -# "url": "https://jsoneditoronline.org/" -# }, -# { -# "id": 4, -# "title": "Webpage Title", -# "description": "Webpage Dec", -# "url": "https://www.google.com/" -# } -# ] -# } -# ] -# }, -# { -# "type": "ANSWER", -# "content": [ -# "## SurfSense Introduction", -# "Surfsense is A Personal NotebookLM and Perplexity-like AI Assistant for Everyone. Research and Never forget Anything. 
[1] [3]" -# ] -# } -# ] + status_code=500, + detail="An unexpected error occurred while deleting the chat.", + ) from None diff --git a/surfsense_backend/app/routes/documents_routes.py b/surfsense_backend/app/routes/documents_routes.py index 190a7f1..2c21fa3 100644 --- a/surfsense_backend/app/routes/documents_routes.py +++ b/surfsense_backend/app/routes/documents_routes.py @@ -1,23 +1,35 @@ -from litellm import atranscription -from fastapi import APIRouter, Depends, BackgroundTasks, UploadFile, Form, HTTPException -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.future import select -from typing import List -from app.db import Log, get_async_session, User, SearchSpace, Document, DocumentType -from app.schemas import DocumentsCreate, DocumentUpdate, DocumentRead -from app.users import current_active_user -from app.utils.check_ownership import check_ownership -from app.tasks.background_tasks import add_received_markdown_file_document, add_extension_received_document, add_received_file_document_using_unstructured, add_crawled_url_document, add_youtube_video_document, add_received_file_document_using_llamacloud, add_received_file_document_using_docling -from app.config import config as app_config # Force asyncio to use standard event loop before unstructured imports import asyncio +from fastapi import APIRouter, BackgroundTasks, Depends, Form, HTTPException, UploadFile +from litellm import atranscription +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.config import config as app_config +from app.db import Document, DocumentType, Log, SearchSpace, User, get_async_session +from app.schemas import DocumentRead, DocumentsCreate, DocumentUpdate from app.services.task_logging_service import TaskLoggingService +from app.tasks.background_tasks import ( + add_crawled_url_document, + add_extension_received_document, + add_received_file_document_using_docling, + add_received_file_document_using_llamacloud, + add_received_file_document_using_unstructured, + add_received_markdown_file_document, + add_youtube_video_document, +) +from app.users import current_active_user +from app.utils.check_ownership import check_ownership + try: asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy()) -except RuntimeError: +except RuntimeError as e: + print("Error setting event loop policy", e) pass + import os + os.environ["UNSTRUCTURED_HAS_PATCHED_LOOP"] = "1" @@ -29,7 +41,7 @@ async def create_documents( request: DocumentsCreate, session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), - fastapi_background_tasks: BackgroundTasks = BackgroundTasks() + fastapi_background_tasks: BackgroundTasks = BackgroundTasks(), ): try: # Check if the user owns the search space @@ -41,7 +53,7 @@ async def create_documents( process_extension_document_with_new_session, individual_document, request.search_space_id, - str(user.id) + str(user.id), ) elif request.document_type == DocumentType.CRAWLED_URL: for url in request.content: @@ -49,7 +61,7 @@ async def create_documents( process_crawled_url_with_new_session, url, request.search_space_id, - str(user.id) + str(user.id), ) elif request.document_type == DocumentType.YOUTUBE_VIDEO: for url in request.content: @@ -57,13 +69,10 @@ async def create_documents( process_youtube_video_with_new_session, url, request.search_space_id, - str(user.id) + str(user.id), ) else: - raise HTTPException( - status_code=400, - detail="Invalid document type" - ) + raise HTTPException(status_code=400, 
detail="Invalid document type") await session.commit() return {"message": "Documents processed successfully"} @@ -72,18 +81,17 @@ async def create_documents( except Exception as e: await session.rollback() raise HTTPException( - status_code=500, - detail=f"Failed to process documents: {str(e)}" - ) + status_code=500, detail=f"Failed to process documents: {e!s}" + ) from e @router.post("/documents/fileupload") -async def create_documents( +async def create_documents_file_upload( files: list[UploadFile], search_space_id: int = Form(...), session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), - fastapi_background_tasks: BackgroundTasks = BackgroundTasks() + fastapi_background_tasks: BackgroundTasks = BackgroundTasks(), ): try: await check_ownership(session, SearchSpace, search_space_id, user) @@ -94,31 +102,32 @@ async def create_documents( for file in files: try: # Save file to a temporary location to avoid stream issues - import tempfile - import aiofiles import os + import tempfile # Create temp file - with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_file: + with tempfile.NamedTemporaryFile( + delete=False, suffix=os.path.splitext(file.filename)[1] + ) as temp_file: temp_path = temp_file.name # Write uploaded file to temp file content = await file.read() with open(temp_path, "wb") as f: f.write(content) - + fastapi_background_tasks.add_task( process_file_in_background_with_new_session, temp_path, file.filename, search_space_id, - str(user.id) + str(user.id), ) except Exception as e: raise HTTPException( status_code=422, - detail=f"Failed to process file {file.filename}: {str(e)}" - ) + detail=f"Failed to process file {file.filename}: {e!s}", + ) from e await session.commit() return {"message": "Files uploaded for processing"} @@ -127,9 +136,8 @@ async def create_documents( except Exception as e: await session.rollback() raise HTTPException( - status_code=500, - detail=f"Failed to upload files: {str(e)}" - ) + status_code=500, detail=f"Failed to upload files: {e!s}" + ) from e async def process_file_in_background( @@ -139,64 +147,71 @@ async def process_file_in_background( user_id: str, session: AsyncSession, task_logger: TaskLoggingService, - log_entry: Log + log_entry: Log, ): try: # Check if the file is a markdown or text file - if filename.lower().endswith(('.md', '.markdown', '.txt')): + if filename.lower().endswith((".md", ".markdown", ".txt")): await task_logger.log_task_progress( log_entry, f"Processing markdown/text file: {filename}", - {"file_type": "markdown", "processing_stage": "reading_file"} + {"file_type": "markdown", "processing_stage": "reading_file"}, ) - + # For markdown files, read the content directly - with open(file_path, 'r', encoding='utf-8') as f: + with open(file_path, encoding="utf-8") as f: markdown_content = f.read() # Clean up the temp file import os + try: os.unlink(file_path) - except: + except Exception as e: + print("Error deleting temp file", e) pass await task_logger.log_task_progress( log_entry, f"Creating document from markdown content: {filename}", - {"processing_stage": "creating_document", "content_length": len(markdown_content)} + { + "processing_stage": "creating_document", + "content_length": len(markdown_content), + }, ) # Process markdown directly through specialized function result = await add_received_markdown_file_document( - session, - filename, - markdown_content, - search_space_id, - user_id + session, filename, markdown_content, search_space_id, 
user_id ) - + if result: await task_logger.log_task_success( log_entry, f"Successfully processed markdown file: {filename}", - {"document_id": result.id, "content_hash": result.content_hash, "file_type": "markdown"} + { + "document_id": result.id, + "content_hash": result.content_hash, + "file_type": "markdown", + }, ) else: await task_logger.log_task_success( log_entry, f"Markdown file already exists (duplicate): {filename}", - {"duplicate_detected": True, "file_type": "markdown"} + {"duplicate_detected": True, "file_type": "markdown"}, ) - + # Check if the file is an audio file - elif filename.lower().endswith(('.mp3', '.mp4', '.mpeg', '.mpga', '.m4a', '.wav', '.webm')): + elif filename.lower().endswith( + (".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm") + ): await task_logger.log_task_progress( log_entry, f"Processing audio file for transcription: {filename}", - {"file_type": "audio", "processing_stage": "starting_transcription"} + {"file_type": "audio", "processing_stage": "starting_transcription"}, ) - + # Open the audio file for transcription with open(file_path, "rb") as audio_file: # Use LiteLLM for audio transcription @@ -205,65 +220,76 @@ async def process_file_in_background( model=app_config.STT_SERVICE, file=audio_file, api_base=app_config.STT_SERVICE_API_BASE, - api_key=app_config.STT_SERVICE_API_KEY + api_key=app_config.STT_SERVICE_API_KEY, ) else: transcription_response = await atranscription( model=app_config.STT_SERVICE, api_key=app_config.STT_SERVICE_API_KEY, - file=audio_file + file=audio_file, ) # Extract the transcribed text transcribed_text = transcription_response.get("text", "") # Add metadata about the transcription - transcribed_text = f"# Transcription of {filename}\n\n{transcribed_text}" + transcribed_text = ( + f"# Transcription of {filename}\n\n{transcribed_text}" + ) await task_logger.log_task_progress( log_entry, f"Transcription completed, creating document: {filename}", - {"processing_stage": "transcription_complete", "transcript_length": len(transcribed_text)} + { + "processing_stage": "transcription_complete", + "transcript_length": len(transcribed_text), + }, ) # Clean up the temp file try: os.unlink(file_path) - except: + except Exception as e: + print("Error deleting temp file", e) pass # Process transcription as markdown document result = await add_received_markdown_file_document( - session, - filename, - transcribed_text, - search_space_id, - user_id + session, filename, transcribed_text, search_space_id, user_id ) - + if result: await task_logger.log_task_success( log_entry, f"Successfully transcribed and processed audio file: {filename}", - {"document_id": result.id, "content_hash": result.content_hash, "file_type": "audio", "transcript_length": len(transcribed_text)} + { + "document_id": result.id, + "content_hash": result.content_hash, + "file_type": "audio", + "transcript_length": len(transcribed_text), + }, ) else: await task_logger.log_task_success( log_entry, f"Audio file transcript already exists (duplicate): {filename}", - {"duplicate_detected": True, "file_type": "audio"} + {"duplicate_detected": True, "file_type": "audio"}, ) - + else: if app_config.ETL_SERVICE == "UNSTRUCTURED": await task_logger.log_task_progress( log_entry, f"Processing file with Unstructured ETL: {filename}", - {"file_type": "document", "etl_service": "UNSTRUCTURED", "processing_stage": "loading"} + { + "file_type": "document", + "etl_service": "UNSTRUCTURED", + "processing_stage": "loading", + }, ) - + from langchain_unstructured import UnstructuredLoader 
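The audio branch a few hunks above transcribes uploads with litellm's async atranscription() and wraps the result as a markdown document before reusing the markdown ingestion path. A minimal sketch of that step in isolation, with a placeholder model id and API key (both assumptions, not the app's configured STT service):

# Sketch of the transcription step; model id, key, and file path are placeholders.
from litellm import atranscription

async def transcribe_to_markdown(path: str, filename: str) -> str:
    with open(path, "rb") as audio_file:
        response = await atranscription(
            model="whisper-1",  # placeholder STT model id
            api_key="sk-placeholder",  # placeholder credential
            file=audio_file,
        )
    text = response.get("text", "")
    # Prefix with a heading so the transcript flows through the markdown pipeline.
    return f"# Transcription of {filename}\n\n{text}"

# From synchronous code: asyncio.run(transcribe_to_markdown("talk.m4a", "talk.m4a"))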
- + # Process the file loader = UnstructuredLoader( file_path, @@ -280,212 +306,257 @@ async def process_file_in_background( await task_logger.log_task_progress( log_entry, f"Unstructured ETL completed, creating document: {filename}", - {"processing_stage": "etl_complete", "elements_count": len(docs)} + {"processing_stage": "etl_complete", "elements_count": len(docs)}, ) # Clean up the temp file import os + try: os.unlink(file_path) - except: + except Exception as e: + print("Error deleting temp file", e) pass # Pass the documents to the existing background task result = await add_received_file_document_using_unstructured( - session, - filename, - docs, - search_space_id, - user_id + session, filename, docs, search_space_id, user_id ) - + if result: await task_logger.log_task_success( log_entry, f"Successfully processed file with Unstructured: {filename}", - {"document_id": result.id, "content_hash": result.content_hash, "file_type": "document", "etl_service": "UNSTRUCTURED"} + { + "document_id": result.id, + "content_hash": result.content_hash, + "file_type": "document", + "etl_service": "UNSTRUCTURED", + }, ) else: await task_logger.log_task_success( log_entry, f"Document already exists (duplicate): {filename}", - {"duplicate_detected": True, "file_type": "document", "etl_service": "UNSTRUCTURED"} + { + "duplicate_detected": True, + "file_type": "document", + "etl_service": "UNSTRUCTURED", + }, ) - + elif app_config.ETL_SERVICE == "LLAMACLOUD": await task_logger.log_task_progress( log_entry, f"Processing file with LlamaCloud ETL: {filename}", - {"file_type": "document", "etl_service": "LLAMACLOUD", "processing_stage": "parsing"} + { + "file_type": "document", + "etl_service": "LLAMACLOUD", + "processing_stage": "parsing", + }, ) - + from llama_cloud_services import LlamaParse from llama_cloud_services.parse.utils import ResultType - # Create LlamaParse parser instance parser = LlamaParse( api_key=app_config.LLAMA_CLOUD_API_KEY, num_workers=1, # Use single worker for file processing verbose=True, language="en", - result_type=ResultType.MD + result_type=ResultType.MD, ) - + # Parse the file asynchronously result = await parser.aparse(file_path) - + # Clean up the temp file import os + try: os.unlink(file_path) - except: + except Exception as e: + print("Error deleting temp file", e) pass - + # Get markdown documents from the result - markdown_documents = await result.aget_markdown_documents(split_by_page=False) - + markdown_documents = await result.aget_markdown_documents( + split_by_page=False + ) + await task_logger.log_task_progress( log_entry, f"LlamaCloud parsing completed, creating documents: {filename}", - {"processing_stage": "parsing_complete", "documents_count": len(markdown_documents)} + { + "processing_stage": "parsing_complete", + "documents_count": len(markdown_documents), + }, ) - + for doc in markdown_documents: # Extract text content from the markdown documents markdown_content = doc.text - + # Process the documents using our LlamaCloud background task doc_result = await add_received_file_document_using_llamacloud( session, filename, llamacloud_markdown_document=markdown_content, search_space_id=search_space_id, - user_id=user_id + user_id=user_id, ) - + if doc_result: await task_logger.log_task_success( log_entry, f"Successfully processed file with LlamaCloud: {filename}", - {"document_id": doc_result.id, "content_hash": doc_result.content_hash, "file_type": "document", "etl_service": "LLAMACLOUD"} + { + "document_id": doc_result.id, + "content_hash": 
doc_result.content_hash, + "file_type": "document", + "etl_service": "LLAMACLOUD", + }, ) else: await task_logger.log_task_success( log_entry, f"Document already exists (duplicate): {filename}", - {"duplicate_detected": True, "file_type": "document", "etl_service": "LLAMACLOUD"} + { + "duplicate_detected": True, + "file_type": "document", + "etl_service": "LLAMACLOUD", + }, ) - + elif app_config.ETL_SERVICE == "DOCLING": await task_logger.log_task_progress( log_entry, f"Processing file with Docling ETL: {filename}", - {"file_type": "document", "etl_service": "DOCLING", "processing_stage": "parsing"} + { + "file_type": "document", + "etl_service": "DOCLING", + "processing_stage": "parsing", + }, ) - + # Use Docling service for document processing from app.services.docling_service import create_docling_service - + # Create Docling service docling_service = create_docling_service() - + # Process the document result = await docling_service.process_document(file_path, filename) - + # Clean up the temp file import os + try: os.unlink(file_path) - except: + except Exception as e: + print("Error deleting temp file", e) pass - + await task_logger.log_task_progress( log_entry, f"Docling parsing completed, creating document: {filename}", - {"processing_stage": "parsing_complete", "content_length": len(result['content'])} + { + "processing_stage": "parsing_complete", + "content_length": len(result["content"]), + }, ) - + # Process the document using our Docling background task doc_result = await add_received_file_document_using_docling( session, filename, - docling_markdown_document=result['content'], + docling_markdown_document=result["content"], search_space_id=search_space_id, - user_id=user_id + user_id=user_id, ) - + if doc_result: await task_logger.log_task_success( log_entry, f"Successfully processed file with Docling: {filename}", - {"document_id": doc_result.id, "content_hash": doc_result.content_hash, "file_type": "document", "etl_service": "DOCLING"} + { + "document_id": doc_result.id, + "content_hash": doc_result.content_hash, + "file_type": "document", + "etl_service": "DOCLING", + }, ) else: await task_logger.log_task_success( log_entry, f"Document already exists (duplicate): {filename}", - {"duplicate_detected": True, "file_type": "document", "etl_service": "DOCLING"} + { + "duplicate_detected": True, + "file_type": "document", + "etl_service": "DOCLING", + }, ) except Exception as e: await task_logger.log_task_failure( log_entry, f"Failed to process file: {filename}", str(e), - {"error_type": type(e).__name__, "filename": filename} + {"error_type": type(e).__name__, "filename": filename}, ) import logging - logging.error(f"Error processing file in background: {str(e)}") + + logging.error(f"Error processing file in background: {e!s}") raise # Re-raise so the wrapper can also handle it -@router.get("/documents/", response_model=List[DocumentRead]) +@router.get("/documents/", response_model=list[DocumentRead]) async def read_documents( skip: int = 0, limit: int = 300, - search_space_id: int = None, + search_space_id: int | None = None, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): try: - query = select(Document).join(SearchSpace).filter( - SearchSpace.user_id == user.id) + query = ( + select(Document).join(SearchSpace).filter(SearchSpace.user_id == user.id) + ) # Filter by search_space_id if provided if search_space_id is not None: query = query.filter(Document.search_space_id == 
search_space_id) - result = await session.execute( - query.offset(skip).limit(limit) - ) + result = await session.execute(query.offset(skip).limit(limit)) db_documents = result.scalars().all() # Convert database objects to API-friendly format api_documents = [] for doc in db_documents: - api_documents.append(DocumentRead( - id=doc.id, - title=doc.title, - document_type=doc.document_type, - document_metadata=doc.document_metadata, - content=doc.content, - created_at=doc.created_at, - search_space_id=doc.search_space_id - )) + api_documents.append( + DocumentRead( + id=doc.id, + title=doc.title, + document_type=doc.document_type, + document_metadata=doc.document_metadata, + content=doc.content, + created_at=doc.created_at, + search_space_id=doc.search_space_id, + ) + ) return api_documents except Exception as e: raise HTTPException( - status_code=500, - detail=f"Failed to fetch documents: {str(e)}" - ) + status_code=500, detail=f"Failed to fetch documents: {e!s}" + ) from e @router.get("/documents/{document_id}", response_model=DocumentRead) async def read_document( document_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): try: result = await session.execute( @@ -497,8 +568,7 @@ async def read_document( if not document: raise HTTPException( - status_code=404, - detail=f"Document with id {document_id} not found" + status_code=404, detail=f"Document with id {document_id} not found" ) # Convert database object to API-friendly format @@ -509,13 +579,12 @@ async def read_document( document_metadata=document.document_metadata, content=document.content, created_at=document.created_at, - search_space_id=document.search_space_id + search_space_id=document.search_space_id, ) except Exception as e: raise HTTPException( - status_code=500, - detail=f"Failed to fetch document: {str(e)}" - ) + status_code=500, detail=f"Failed to fetch document: {e!s}" + ) from e @router.put("/documents/{document_id}", response_model=DocumentRead) @@ -523,7 +592,7 @@ async def update_document( document_id: int, document_update: DocumentUpdate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): try: # Query the document directly instead of using read_document function @@ -536,8 +605,7 @@ async def update_document( if not db_document: raise HTTPException( - status_code=404, - detail=f"Document with id {document_id} not found" + status_code=404, detail=f"Document with id {document_id} not found" ) update_data = document_update.model_dump(exclude_unset=True) @@ -554,23 +622,22 @@ async def update_document( document_metadata=db_document.document_metadata, content=db_document.content, created_at=db_document.created_at, - search_space_id=db_document.search_space_id + search_space_id=db_document.search_space_id, ) except HTTPException: raise except Exception as e: await session.rollback() raise HTTPException( - status_code=500, - detail=f"Failed to update document: {str(e)}" - ) + status_code=500, detail=f"Failed to update document: {e!s}" + ) from e @router.delete("/documents/{document_id}", response_model=dict) async def delete_document( document_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): try: # Query the document directly instead of using read_document function @@ -583,8 +650,7 @@ async def delete_document( if not document: 
raise HTTPException( - status_code=404, - detail=f"Document with id {document_id} not found" + status_code=404, detail=f"Document with id {document_id} not found" ) await session.delete(document) @@ -595,15 +661,12 @@ async def delete_document( except Exception as e: await session.rollback() raise HTTPException( - status_code=500, - detail=f"Failed to delete document: {str(e)}" - ) + status_code=500, detail=f"Failed to delete document: {e!s}" + ) from e async def process_extension_document_with_new_session( - individual_document, - search_space_id: int, - user_id: str + individual_document, search_space_id: int, user_id: str ): """Create a new session and process extension document.""" from app.db import async_session_maker @@ -612,7 +675,7 @@ async def process_extension_document_with_new_session( async with async_session_maker() as session: # Initialize task logging service task_logger = TaskLoggingService(session, search_space_id) - + # Log task start log_entry = await task_logger.log_task_start( task_name="process_extension_document", @@ -622,40 +685,41 @@ async def process_extension_document_with_new_session( "document_type": "EXTENSION", "url": individual_document.metadata.VisitedWebPageURL, "title": individual_document.metadata.VisitedWebPageTitle, - "user_id": user_id - } + "user_id": user_id, + }, ) - + try: - result = await add_extension_received_document(session, individual_document, search_space_id, user_id) - + result = await add_extension_received_document( + session, individual_document, search_space_id, user_id + ) + if result: await task_logger.log_task_success( log_entry, f"Successfully processed extension document: {individual_document.metadata.VisitedWebPageTitle}", - {"document_id": result.id, "content_hash": result.content_hash} + {"document_id": result.id, "content_hash": result.content_hash}, ) else: await task_logger.log_task_success( log_entry, f"Extension document already exists (duplicate): {individual_document.metadata.VisitedWebPageTitle}", - {"duplicate_detected": True} + {"duplicate_detected": True}, ) except Exception as e: await task_logger.log_task_failure( log_entry, f"Failed to process extension document: {individual_document.metadata.VisitedWebPageTitle}", str(e), - {"error_type": type(e).__name__} + {"error_type": type(e).__name__}, ) import logging - logging.error(f"Error processing extension document: {str(e)}") + + logging.error(f"Error processing extension document: {e!s}") async def process_crawled_url_with_new_session( - url: str, - search_space_id: int, - user_id: str + url: str, search_space_id: int, user_id: str ): """Create a new session and process crawled URL.""" from app.db import async_session_maker @@ -664,50 +728,50 @@ async def process_crawled_url_with_new_session( async with async_session_maker() as session: # Initialize task logging service task_logger = TaskLoggingService(session, search_space_id) - + # Log task start log_entry = await task_logger.log_task_start( task_name="process_crawled_url", source="document_processor", message=f"Starting URL crawling and processing for: {url}", - metadata={ - "document_type": "CRAWLED_URL", - "url": url, - "user_id": user_id - } + metadata={"document_type": "CRAWLED_URL", "url": url, "user_id": user_id}, ) - + try: - result = await add_crawled_url_document(session, url, search_space_id, user_id) - + result = await add_crawled_url_document( + session, url, search_space_id, user_id + ) + if result: await task_logger.log_task_success( log_entry, f"Successfully crawled and processed URL: {url}", - 
{"document_id": result.id, "title": result.title, "content_hash": result.content_hash} + { + "document_id": result.id, + "title": result.title, + "content_hash": result.content_hash, + }, ) else: await task_logger.log_task_success( log_entry, f"URL document already exists (duplicate): {url}", - {"duplicate_detected": True} + {"duplicate_detected": True}, ) except Exception as e: await task_logger.log_task_failure( log_entry, f"Failed to crawl URL: {url}", str(e), - {"error_type": type(e).__name__} + {"error_type": type(e).__name__}, ) import logging - logging.error(f"Error processing crawled URL: {str(e)}") + + logging.error(f"Error processing crawled URL: {e!s}") async def process_file_in_background_with_new_session( - file_path: str, - filename: str, - search_space_id: int, - user_id: str + file_path: str, filename: str, search_space_id: int, user_id: str ): """Create a new session and process file.""" from app.db import async_session_maker @@ -716,7 +780,7 @@ async def process_file_in_background_with_new_session( async with async_session_maker() as session: # Initialize task logging service task_logger = TaskLoggingService(session, search_space_id) - + # Log task start log_entry = await task_logger.log_task_start( task_name="process_file_upload", @@ -726,29 +790,36 @@ async def process_file_in_background_with_new_session( "document_type": "FILE", "filename": filename, "file_path": file_path, - "user_id": user_id - } + "user_id": user_id, + }, ) - + try: - await process_file_in_background(file_path, filename, search_space_id, user_id, session, task_logger, log_entry) - + await process_file_in_background( + file_path, + filename, + search_space_id, + user_id, + session, + task_logger, + log_entry, + ) + # Note: success/failure logging is handled within process_file_in_background except Exception as e: await task_logger.log_task_failure( log_entry, f"Failed to process file: {filename}", str(e), - {"error_type": type(e).__name__} + {"error_type": type(e).__name__}, ) import logging - logging.error(f"Error processing file: {str(e)}") + + logging.error(f"Error processing file: {e!s}") async def process_youtube_video_with_new_session( - url: str, - search_space_id: int, - user_id: str + url: str, search_space_id: int, user_id: str ): """Create a new session and process YouTube video.""" from app.db import async_session_maker @@ -757,42 +828,43 @@ async def process_youtube_video_with_new_session( async with async_session_maker() as session: # Initialize task logging service task_logger = TaskLoggingService(session, search_space_id) - + # Log task start log_entry = await task_logger.log_task_start( task_name="process_youtube_video", source="document_processor", message=f"Starting YouTube video processing for: {url}", - metadata={ - "document_type": "YOUTUBE_VIDEO", - "url": url, - "user_id": user_id - } + metadata={"document_type": "YOUTUBE_VIDEO", "url": url, "user_id": user_id}, ) - + try: - result = await add_youtube_video_document(session, url, search_space_id, user_id) - + result = await add_youtube_video_document( + session, url, search_space_id, user_id + ) + if result: await task_logger.log_task_success( log_entry, f"Successfully processed YouTube video: {result.title}", - {"document_id": result.id, "video_id": result.document_metadata.get("video_id"), "content_hash": result.content_hash} + { + "document_id": result.id, + "video_id": result.document_metadata.get("video_id"), + "content_hash": result.content_hash, + }, ) else: await task_logger.log_task_success( log_entry, f"YouTube video 
document already exists (duplicate): {url}", - {"duplicate_detected": True} + {"duplicate_detected": True}, ) except Exception as e: await task_logger.log_task_failure( log_entry, f"Failed to process YouTube video: {url}", str(e), - {"error_type": type(e).__name__} + {"error_type": type(e).__name__}, ) import logging - logging.error(f"Error processing YouTube video: {str(e)}") - + logging.error(f"Error processing YouTube video: {e!s}") diff --git a/surfsense_backend/app/routes/llm_config_routes.py b/surfsense_backend/app/routes/llm_config_routes.py index 644503f..ce76dc9 100644 --- a/surfsense_backend/app/routes/llm_config_routes.py +++ b/surfsense_backend/app/routes/llm_config_routes.py @@ -1,35 +1,40 @@ from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from typing import List, Optional -from pydantic import BaseModel -from app.db import get_async_session, User, LLMConfig -from app.schemas import LLMConfigCreate, LLMConfigUpdate, LLMConfigRead + +from app.db import LLMConfig, User, get_async_session +from app.schemas import LLMConfigCreate, LLMConfigRead, LLMConfigUpdate from app.users import current_active_user from app.utils.check_ownership import check_ownership router = APIRouter() + class LLMPreferencesUpdate(BaseModel): """Schema for updating user LLM preferences""" - long_context_llm_id: Optional[int] = None - fast_llm_id: Optional[int] = None - strategic_llm_id: Optional[int] = None + + long_context_llm_id: int | None = None + fast_llm_id: int | None = None + strategic_llm_id: int | None = None + class LLMPreferencesRead(BaseModel): """Schema for reading user LLM preferences""" - long_context_llm_id: Optional[int] = None - fast_llm_id: Optional[int] = None - strategic_llm_id: Optional[int] = None - long_context_llm: Optional[LLMConfigRead] = None - fast_llm: Optional[LLMConfigRead] = None - strategic_llm: Optional[LLMConfigRead] = None + + long_context_llm_id: int | None = None + fast_llm_id: int | None = None + strategic_llm_id: int | None = None + long_context_llm: LLMConfigRead | None = None + fast_llm: LLMConfigRead | None = None + strategic_llm: LLMConfigRead | None = None + @router.post("/llm-configs/", response_model=LLMConfigRead) async def create_llm_config( llm_config: LLMConfigCreate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): """Create a new LLM configuration for the authenticated user""" try: @@ -43,16 +48,16 @@ async def create_llm_config( except Exception as e: await session.rollback() raise HTTPException( - status_code=500, - detail=f"Failed to create LLM configuration: {str(e)}" - ) + status_code=500, detail=f"Failed to create LLM configuration: {e!s}" + ) from e -@router.get("/llm-configs/", response_model=List[LLMConfigRead]) + +@router.get("/llm-configs/", response_model=list[LLMConfigRead]) async def read_llm_configs( skip: int = 0, limit: int = 200, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): """Get all LLM configurations for the authenticated user""" try: @@ -65,15 +70,15 @@ async def read_llm_configs( return result.scalars().all() except Exception as e: raise HTTPException( - status_code=500, - detail=f"Failed to fetch LLM configurations: {str(e)}" - ) + status_code=500, detail=f"Failed to fetch LLM configurations: {e!s}" + ) from e + 
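Across these route modules the patch applies the same two lint fixes to every error handler: each `raise HTTPException(...)` inside an `except` block gains explicit exception chaining (`from e`, or `from None` where the original cause is deliberately hidden from the client, as in the podcast routes further down), and `{str(e)}` inside f-strings becomes the `{e!s}` conversion flag. A minimal sketch of the resulting handler shape, using a hypothetical `/examples/` route and `load_examples()` helper purely for illustration:

from fastapi import APIRouter, HTTPException

router = APIRouter()


async def load_examples() -> list[dict]:
    # Hypothetical data source, standing in for the SQLAlchemy queries in these routes.
    return [{"id": 1, "name": "example"}]


@router.get("/examples/")
async def read_examples() -> list[dict]:
    try:
        return await load_examples()
    except HTTPException:
        # Deliberate HTTP errors (404, 409, ...) propagate unchanged.
        raise
    except Exception as e:
        # Chain the original exception so its traceback is preserved, and use the
        # `!s` conversion flag rather than calling str() inside the f-string.
        raise HTTPException(
            status_code=500, detail=f"Failed to fetch examples: {e!s}"
        ) from e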
@router.get("/llm-configs/{llm_config_id}", response_model=LLMConfigRead) async def read_llm_config( llm_config_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): """Get a specific LLM configuration by ID""" try: @@ -83,25 +88,25 @@ async def read_llm_config( raise except Exception as e: raise HTTPException( - status_code=500, - detail=f"Failed to fetch LLM configuration: {str(e)}" - ) + status_code=500, detail=f"Failed to fetch LLM configuration: {e!s}" + ) from e + @router.put("/llm-configs/{llm_config_id}", response_model=LLMConfigRead) async def update_llm_config( llm_config_id: int, llm_config_update: LLMConfigUpdate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): """Update an existing LLM configuration""" try: db_llm_config = await check_ownership(session, LLMConfig, llm_config_id, user) update_data = llm_config_update.model_dump(exclude_unset=True) - + for key, value in update_data.items(): setattr(db_llm_config, key, value) - + await session.commit() await session.refresh(db_llm_config) return db_llm_config @@ -110,15 +115,15 @@ async def update_llm_config( except Exception as e: await session.rollback() raise HTTPException( - status_code=500, - detail=f"Failed to update LLM configuration: {str(e)}" - ) + status_code=500, detail=f"Failed to update LLM configuration: {e!s}" + ) from e + @router.delete("/llm-configs/{llm_config_id}", response_model=dict) async def delete_llm_config( llm_config_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): """Delete an LLM configuration""" try: @@ -131,22 +136,23 @@ async def delete_llm_config( except Exception as e: await session.rollback() raise HTTPException( - status_code=500, - detail=f"Failed to delete LLM configuration: {str(e)}" - ) + status_code=500, detail=f"Failed to delete LLM configuration: {e!s}" + ) from e + # User LLM Preferences endpoints + @router.get("/users/me/llm-preferences", response_model=LLMPreferencesRead) async def get_user_llm_preferences( session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): """Get the current user's LLM preferences""" try: # Refresh user to get latest relationships await session.refresh(user) - + result = { "long_context_llm_id": user.long_context_llm_id, "fast_llm_id": user.fast_llm_id, @@ -155,82 +161,79 @@ async def get_user_llm_preferences( "fast_llm": None, "strategic_llm": None, } - + # Fetch the actual LLM configs if they exist if user.long_context_llm_id: long_context_llm = await session.execute( select(LLMConfig).filter( LLMConfig.id == user.long_context_llm_id, - LLMConfig.user_id == user.id + LLMConfig.user_id == user.id, ) ) llm_config = long_context_llm.scalars().first() if llm_config: result["long_context_llm"] = llm_config - + if user.fast_llm_id: fast_llm = await session.execute( select(LLMConfig).filter( - LLMConfig.id == user.fast_llm_id, - LLMConfig.user_id == user.id + LLMConfig.id == user.fast_llm_id, LLMConfig.user_id == user.id ) ) llm_config = fast_llm.scalars().first() if llm_config: result["fast_llm"] = llm_config - + if user.strategic_llm_id: strategic_llm = await session.execute( select(LLMConfig).filter( - LLMConfig.id == user.strategic_llm_id, - LLMConfig.user_id == user.id + 
LLMConfig.id == user.strategic_llm_id, LLMConfig.user_id == user.id ) ) llm_config = strategic_llm.scalars().first() if llm_config: result["strategic_llm"] = llm_config - + return result except Exception as e: raise HTTPException( - status_code=500, - detail=f"Failed to fetch LLM preferences: {str(e)}" - ) + status_code=500, detail=f"Failed to fetch LLM preferences: {e!s}" + ) from e + @router.put("/users/me/llm-preferences", response_model=LLMPreferencesRead) async def update_user_llm_preferences( preferences: LLMPreferencesUpdate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): """Update the current user's LLM preferences""" try: # Validate that all provided LLM config IDs belong to the user update_data = preferences.model_dump(exclude_unset=True) - - for key, llm_config_id in update_data.items(): + + for _key, llm_config_id in update_data.items(): if llm_config_id is not None: # Verify ownership of the LLM config result = await session.execute( select(LLMConfig).filter( - LLMConfig.id == llm_config_id, - LLMConfig.user_id == user.id + LLMConfig.id == llm_config_id, LLMConfig.user_id == user.id ) ) llm_config = result.scalars().first() if not llm_config: raise HTTPException( status_code=404, - detail=f"LLM configuration {llm_config_id} not found or you don't have permission to access it" + detail=f"LLM configuration {llm_config_id} not found or you don't have permission to access it", ) - + # Update user preferences for key, value in update_data.items(): setattr(user, key, value) - + await session.commit() await session.refresh(user) - + # Return updated preferences return await get_user_llm_preferences(session, user) except HTTPException: @@ -238,6 +241,5 @@ async def update_user_llm_preferences( except Exception as e: await session.rollback() raise HTTPException( - status_code=500, - detail=f"Failed to update LLM preferences: {str(e)}" - ) \ No newline at end of file + status_code=500, detail=f"Failed to update LLM preferences: {e!s}" + ) from e diff --git a/surfsense_backend/app/routes/logs_routes.py b/surfsense_backend/app/routes/logs_routes.py index 65e33ec..cdcb034 100644 --- a/surfsense_backend/app/routes/logs_routes.py +++ b/surfsense_backend/app/routes/logs_routes.py @@ -1,22 +1,23 @@ -from fastapi import APIRouter, Depends, HTTPException, Query -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.future import select -from sqlalchemy import and_, desc -from typing import List, Optional from datetime import datetime, timedelta -from app.db import get_async_session, User, SearchSpace, Log, LogLevel, LogStatus -from app.schemas import LogCreate, LogUpdate, LogRead, LogFilter +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import and_, desc +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.db import Log, LogLevel, LogStatus, SearchSpace, User, get_async_session +from app.schemas import LogCreate, LogRead, LogUpdate from app.users import current_active_user from app.utils.check_ownership import check_ownership router = APIRouter() + @router.post("/logs/", response_model=LogRead) async def create_log( log: LogCreate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): """Create a new log entry.""" try: @@ -33,22 +34,22 @@ async def create_log( except Exception as e: await session.rollback() raise HTTPException( - 
status_code=500, - detail=f"Failed to create log: {str(e)}" - ) + status_code=500, detail=f"Failed to create log: {e!s}" + ) from e -@router.get("/logs/", response_model=List[LogRead]) + +@router.get("/logs/", response_model=list[LogRead]) async def read_logs( skip: int = 0, limit: int = 100, - search_space_id: Optional[int] = None, - level: Optional[LogLevel] = None, - status: Optional[LogStatus] = None, - source: Optional[str] = None, - start_date: Optional[datetime] = None, - end_date: Optional[datetime] = None, + search_space_id: int | None = None, + level: LogLevel | None = None, + status: LogStatus | None = None, + source: str | None = None, + start_date: datetime | None = None, + end_date: datetime | None = None, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): """Get logs with optional filtering.""" try: @@ -62,23 +63,23 @@ async def read_logs( # Apply filters filters = [] - + if search_space_id is not None: await check_ownership(session, SearchSpace, search_space_id, user) filters.append(Log.search_space_id == search_space_id) - + if level is not None: filters.append(Log.level == level) - + if status is not None: filters.append(Log.status == status) - + if source is not None: filters.append(Log.source.ilike(f"%{source}%")) - + if start_date is not None: filters.append(Log.created_at >= start_date) - + if end_date is not None: filters.append(Log.created_at <= end_date) @@ -93,15 +94,15 @@ async def read_logs( raise except Exception as e: raise HTTPException( - status_code=500, - detail=f"Failed to fetch logs: {str(e)}" - ) + status_code=500, detail=f"Failed to fetch logs: {e!s}" + ) from e + @router.get("/logs/{log_id}", response_model=LogRead) async def read_log( log_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): """Get a specific log by ID.""" try: @@ -112,25 +113,25 @@ async def read_log( .filter(Log.id == log_id, SearchSpace.user_id == user.id) ) log = result.scalars().first() - + if not log: raise HTTPException(status_code=404, detail="Log not found") - + return log except HTTPException: raise except Exception as e: raise HTTPException( - status_code=500, - detail=f"Failed to fetch log: {str(e)}" - ) + status_code=500, detail=f"Failed to fetch log: {e!s}" + ) from e + @router.put("/logs/{log_id}", response_model=LogRead) async def update_log( log_id: int, log_update: LogUpdate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): """Update a log entry.""" try: @@ -141,7 +142,7 @@ async def update_log( .filter(Log.id == log_id, SearchSpace.user_id == user.id) ) db_log = result.scalars().first() - + if not db_log: raise HTTPException(status_code=404, detail="Log not found") @@ -158,15 +159,15 @@ async def update_log( except Exception as e: await session.rollback() raise HTTPException( - status_code=500, - detail=f"Failed to update log: {str(e)}" - ) + status_code=500, detail=f"Failed to update log: {e!s}" + ) from e + @router.delete("/logs/{log_id}") async def delete_log( log_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): """Delete a log entry.""" try: @@ -177,7 +178,7 @@ async def delete_log( .filter(Log.id == log_id, SearchSpace.user_id == user.id) ) db_log = result.scalars().first() - 
+ if not db_log: raise HTTPException(status_code=404, detail="Log not found") @@ -189,38 +190,35 @@ async def delete_log( except Exception as e: await session.rollback() raise HTTPException( - status_code=500, - detail=f"Failed to delete log: {str(e)}" - ) + status_code=500, detail=f"Failed to delete log: {e!s}" + ) from e + @router.get("/logs/search-space/{search_space_id}/summary") async def get_logs_summary( search_space_id: int, hours: int = 24, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): """Get a summary of logs for a search space in the last X hours.""" try: # Check ownership await check_ownership(session, SearchSpace, search_space_id, user) - + # Calculate time window since = datetime.utcnow().replace(microsecond=0) - timedelta(hours=hours) - + # Get logs from the time window result = await session.execute( select(Log) .filter( - and_( - Log.search_space_id == search_space_id, - Log.created_at >= since - ) + and_(Log.search_space_id == search_space_id, Log.created_at >= since) ) .order_by(desc(Log.created_at)) ) logs = result.scalars().all() - + # Create summary summary = { "total_logs": len(logs), @@ -229,52 +227,69 @@ async def get_logs_summary( "by_level": {}, "by_source": {}, "active_tasks": [], - "recent_failures": [] + "recent_failures": [], } - + # Count by status and level for log in logs: # Status counts status_str = log.status.value - summary["by_status"][status_str] = summary["by_status"].get(status_str, 0) + 1 - + summary["by_status"][status_str] = ( + summary["by_status"].get(status_str, 0) + 1 + ) + # Level counts level_str = log.level.value summary["by_level"][level_str] = summary["by_level"].get(level_str, 0) + 1 - + # Source counts if log.source: - summary["by_source"][log.source] = summary["by_source"].get(log.source, 0) + 1 - + summary["by_source"][log.source] = ( + summary["by_source"].get(log.source, 0) + 1 + ) + # Active tasks (IN_PROGRESS) if log.status == LogStatus.IN_PROGRESS: - task_name = log.log_metadata.get("task_name", "Unknown") if log.log_metadata else "Unknown" - summary["active_tasks"].append({ - "id": log.id, - "task_name": task_name, - "message": log.message, - "started_at": log.created_at, - "source": log.source - }) - + task_name = ( + log.log_metadata.get("task_name", "Unknown") + if log.log_metadata + else "Unknown" + ) + summary["active_tasks"].append( + { + "id": log.id, + "task_name": task_name, + "message": log.message, + "started_at": log.created_at, + "source": log.source, + } + ) + # Recent failures if log.status == LogStatus.FAILED and len(summary["recent_failures"]) < 10: - task_name = log.log_metadata.get("task_name", "Unknown") if log.log_metadata else "Unknown" - summary["recent_failures"].append({ - "id": log.id, - "task_name": task_name, - "message": log.message, - "failed_at": log.created_at, - "source": log.source, - "error_details": log.log_metadata.get("error_details") if log.log_metadata else None - }) - + task_name = ( + log.log_metadata.get("task_name", "Unknown") + if log.log_metadata + else "Unknown" + ) + summary["recent_failures"].append( + { + "id": log.id, + "task_name": task_name, + "message": log.message, + "failed_at": log.created_at, + "source": log.source, + "error_details": log.log_metadata.get("error_details") + if log.log_metadata + else None, + } + ) + return summary - + except HTTPException: raise except Exception as e: raise HTTPException( - status_code=500, - detail=f"Failed to generate logs summary: 
{str(e)}" - ) \ No newline at end of file + status_code=500, detail=f"Failed to generate logs summary: {e!s}" + ) from e diff --git a/surfsense_backend/app/routes/podcasts_routes.py b/surfsense_backend/app/routes/podcasts_routes.py index 507c15e..644d2ad 100644 --- a/surfsense_backend/app/routes/podcasts_routes.py +++ b/surfsense_backend/app/routes/podcasts_routes.py @@ -1,24 +1,31 @@ -from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.future import select -from sqlalchemy.exc import IntegrityError, SQLAlchemyError -from typing import List -from app.db import get_async_session, User, SearchSpace, Podcast, Chat -from app.schemas import PodcastCreate, PodcastUpdate, PodcastRead, PodcastGenerateRequest -from app.users import current_active_user -from app.utils.check_ownership import check_ownership -from app.tasks.podcast_tasks import generate_chat_podcast -from fastapi.responses import StreamingResponse import os from pathlib import Path +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException +from fastapi.responses import StreamingResponse +from sqlalchemy.exc import IntegrityError, SQLAlchemyError +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.db import Chat, Podcast, SearchSpace, User, get_async_session +from app.schemas import ( + PodcastCreate, + PodcastGenerateRequest, + PodcastRead, + PodcastUpdate, +) +from app.tasks.podcast_tasks import generate_chat_podcast +from app.users import current_active_user +from app.utils.check_ownership import check_ownership + router = APIRouter() + @router.post("/podcasts/", response_model=PodcastRead) async def create_podcast( podcast: PodcastCreate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): try: await check_ownership(session, SearchSpace, podcast.search_space_id, user) @@ -29,22 +36,30 @@ async def create_podcast( return db_podcast except HTTPException as he: raise he - except IntegrityError as e: + except IntegrityError: await session.rollback() - raise HTTPException(status_code=400, detail="Podcast creation failed due to constraint violation") - except SQLAlchemyError as e: + raise HTTPException( + status_code=400, + detail="Podcast creation failed due to constraint violation", + ) from None + except SQLAlchemyError: await session.rollback() - raise HTTPException(status_code=500, detail="Database error occurred while creating podcast") - except Exception as e: + raise HTTPException( + status_code=500, detail="Database error occurred while creating podcast" + ) from None + except Exception: await session.rollback() - raise HTTPException(status_code=500, detail="An unexpected error occurred") + raise HTTPException( + status_code=500, detail="An unexpected error occurred" + ) from None -@router.get("/podcasts/", response_model=List[PodcastRead]) + +@router.get("/podcasts/", response_model=list[PodcastRead]) async def read_podcasts( skip: int = 0, limit: int = 100, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): if skip < 0 or limit < 1: raise HTTPException(status_code=400, detail="Invalid pagination parameters") @@ -58,13 +73,16 @@ async def read_podcasts( ) return result.scalars().all() except SQLAlchemyError: - raise HTTPException(status_code=500, detail="Database error occurred while fetching podcasts") 
+ raise HTTPException( + status_code=500, detail="Database error occurred while fetching podcasts" + ) from None + @router.get("/podcasts/{podcast_id}", response_model=PodcastRead) async def read_podcast( podcast_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): try: result = await session.execute( @@ -76,20 +94,23 @@ async def read_podcast( if not podcast: raise HTTPException( status_code=404, - detail="Podcast not found or you don't have permission to access it" + detail="Podcast not found or you don't have permission to access it", ) return podcast except HTTPException as he: raise he except SQLAlchemyError: - raise HTTPException(status_code=500, detail="Database error occurred while fetching podcast") + raise HTTPException( + status_code=500, detail="Database error occurred while fetching podcast" + ) from None + @router.put("/podcasts/{podcast_id}", response_model=PodcastRead) async def update_podcast( podcast_id: int, podcast_update: PodcastUpdate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): try: db_podcast = await read_podcast(podcast_id, session, user) @@ -103,16 +124,21 @@ async def update_podcast( raise he except IntegrityError: await session.rollback() - raise HTTPException(status_code=400, detail="Update failed due to constraint violation") + raise HTTPException( + status_code=400, detail="Update failed due to constraint violation" + ) from None except SQLAlchemyError: await session.rollback() - raise HTTPException(status_code=500, detail="Database error occurred while updating podcast") + raise HTTPException( + status_code=500, detail="Database error occurred while updating podcast" + ) from None + @router.delete("/podcasts/{podcast_id}", response_model=dict) async def delete_podcast( podcast_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): try: db_podcast = await read_podcast(podcast_id, session, user) @@ -123,83 +149,100 @@ async def delete_podcast( raise he except SQLAlchemyError: await session.rollback() - raise HTTPException(status_code=500, detail="Database error occurred while deleting podcast") + raise HTTPException( + status_code=500, detail="Database error occurred while deleting podcast" + ) from None + async def generate_chat_podcast_with_new_session( - chat_id: int, - search_space_id: int, - podcast_title: str, - user_id: int + chat_id: int, search_space_id: int, podcast_title: str, user_id: int ): """Create a new session and process chat podcast generation.""" from app.db import async_session_maker - + async with async_session_maker() as session: try: - await generate_chat_podcast(session, chat_id, search_space_id, podcast_title, user_id) + await generate_chat_podcast( + session, chat_id, search_space_id, podcast_title, user_id + ) except Exception as e: import logging - logging.error(f"Error generating podcast from chat: {str(e)}") + + logging.error(f"Error generating podcast from chat: {e!s}") + @router.post("/podcasts/generate/") async def generate_podcast( request: PodcastGenerateRequest, session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), - fastapi_background_tasks: BackgroundTasks = BackgroundTasks() + fastapi_background_tasks: BackgroundTasks = BackgroundTasks(), ): try: # Check if the user owns the search space 
await check_ownership(session, SearchSpace, request.search_space_id, user) - + if request.type == "CHAT": # Verify that all chat IDs belong to this user and search space - query = select(Chat).filter( - Chat.id.in_(request.ids), - Chat.search_space_id == request.search_space_id - ).join(SearchSpace).filter(SearchSpace.user_id == user.id) - + query = ( + select(Chat) + .filter( + Chat.id.in_(request.ids), + Chat.search_space_id == request.search_space_id, + ) + .join(SearchSpace) + .filter(SearchSpace.user_id == user.id) + ) + result = await session.execute(query) valid_chats = result.scalars().all() valid_chat_ids = [chat.id for chat in valid_chats] - + # If any requested ID is not in valid IDs, raise error immediately if len(valid_chat_ids) != len(request.ids): raise HTTPException( - status_code=403, - detail="One or more chat IDs do not belong to this user or search space" + status_code=403, + detail="One or more chat IDs do not belong to this user or search space", ) - + # Only add a single task with the first chat ID for chat_id in valid_chat_ids: fastapi_background_tasks.add_task( - generate_chat_podcast_with_new_session, - chat_id, + generate_chat_podcast_with_new_session, + chat_id, request.search_space_id, request.podcast_title, - user.id + user.id, ) - + return { "message": "Podcast generation started", } except HTTPException as he: raise he - except IntegrityError as e: + except IntegrityError: await session.rollback() - raise HTTPException(status_code=400, detail="Podcast generation failed due to constraint violation") - except SQLAlchemyError as e: + raise HTTPException( + status_code=400, + detail="Podcast generation failed due to constraint violation", + ) from None + except SQLAlchemyError: await session.rollback() - raise HTTPException(status_code=500, detail="Database error occurred while generating podcast") + raise HTTPException( + status_code=500, detail="Database error occurred while generating podcast" + ) from None except Exception as e: await session.rollback() - raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}") + raise HTTPException( + status_code=500, detail=f"An unexpected error occurred: {e!s}" + ) from e + @router.get("/podcasts/{podcast_id}/stream") async def stream_podcast( podcast_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): """Stream a podcast audio file.""" try: @@ -210,36 +253,38 @@ async def stream_podcast( .filter(Podcast.id == podcast_id, SearchSpace.user_id == user.id) ) podcast = result.scalars().first() - + if not podcast: raise HTTPException( status_code=404, - detail="Podcast not found or you don't have permission to access it" + detail="Podcast not found or you don't have permission to access it", ) - + # Get the file path file_path = podcast.file_location - + # Check if the file exists if not os.path.isfile(file_path): raise HTTPException(status_code=404, detail="Podcast audio file not found") - + # Define a generator function to stream the file def iterfile(): with open(file_path, mode="rb") as file_like: yield from file_like - + # Return a streaming response with appropriate headers return StreamingResponse( iterfile(), media_type="audio/mpeg", headers={ "Accept-Ranges": "bytes", - "Content-Disposition": f"inline; filename={Path(file_path).name}" - } + "Content-Disposition": f"inline; filename={Path(file_path).name}", + }, ) - + except HTTPException as he: raise he except Exception as e: - raise 
HTTPException(status_code=500, detail=f"Error streaming podcast: {str(e)}") \ No newline at end of file + raise HTTPException( + status_code=500, detail=f"Error streaming podcast: {e!s}" + ) from e diff --git a/surfsense_backend/app/routes/search_source_connectors_routes.py b/surfsense_backend/app/routes/search_source_connectors_routes.py index 54f97d6..47caa97 100644 --- a/surfsense_backend/app/routes/search_source_connectors_routes.py +++ b/surfsense_backend/app/routes/search_source_connectors_routes.py @@ -9,35 +9,58 @@ POST /search-source-connectors/{connector_id}/index - Index content from a conne Note: Each user can have only one connector of each type (SERPER_API, TAVILY_API, SLACK_CONNECTOR, NOTION_CONNECTOR, GITHUB_CONNECTOR, LINEAR_CONNECTOR, DISCORD_CONNECTOR). """ -from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks, Body + +import logging +from datetime import datetime, timedelta +from typing import Any + +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query +from pydantic import BaseModel, Field, ValidationError +from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from sqlalchemy.exc import IntegrityError -from typing import List, Dict, Any -from app.db import get_async_session, User, SearchSourceConnector, SearchSourceConnectorType, SearchSpace, async_session_maker -from app.schemas import SearchSourceConnectorCreate, SearchSourceConnectorUpdate, SearchSourceConnectorRead, SearchSourceConnectorBase + +from app.connectors.github_connector import GitHubConnector +from app.db import ( + SearchSourceConnector, + SearchSourceConnectorType, + SearchSpace, + User, + async_session_maker, + get_async_session, +) +from app.schemas import ( + SearchSourceConnectorBase, + SearchSourceConnectorCreate, + SearchSourceConnectorRead, + SearchSourceConnectorUpdate, +) +from app.tasks.connectors_indexing_tasks import ( + index_discord_messages, + index_github_repos, + index_linear_issues, + index_notion_pages, + index_slack_messages, +) from app.users import current_active_user from app.utils.check_ownership import check_ownership -from pydantic import BaseModel, Field, ValidationError -from app.tasks.connectors_indexing_tasks import index_slack_messages, index_notion_pages, index_github_repos, index_linear_issues, index_discord_messages -from app.connectors.github_connector import GitHubConnector -from datetime import datetime, timedelta -import logging # Set up logging logger = logging.getLogger(__name__) router = APIRouter() + # Use Pydantic's BaseModel here class GitHubPATRequest(BaseModel): github_pat: str = Field(..., description="GitHub Personal Access Token") + # --- New Endpoint to list GitHub Repositories --- -@router.post("/github/repositories/", response_model=List[Dict[str, Any]]) +@router.post("/github/repositories/", response_model=list[dict[str, Any]]) async def list_github_repositories( pat_request: GitHubPATRequest, - user: User = Depends(current_active_user) # Ensure the user is logged in + user: User = Depends(current_active_user), # Ensure the user is logged in ): """ Fetches a list of repositories accessible by the provided GitHub PAT. 
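The connector-routes diff that follows also modernizes the type annotations: `typing.List`, `Dict`, and `Optional` are dropped in favor of builtin generics and PEP 604 unions, so only `Any` is still imported from `typing` (for example `response_model=list[dict[str, Any]]` and `search_space_id: int | None = None`). A small, self-contained sketch of that annotation style; the function and field names here are illustrative only, not part of the codebase:

from typing import Any


def summarize_connectors(
    connectors: list[dict[str, Any]],   # was: List[Dict[str, Any]]
    only_type: str | None = None,       # was: Optional[str] = None
) -> dict[str, int]:
    """Count connectors by type; a toy stand-in for the route handlers below."""
    counts: dict[str, int] = {}
    for connector in connectors:
        connector_type = str(connector.get("connector_type", "UNKNOWN"))
        if only_type is not None and connector_type != only_type:
            continue
        counts[connector_type] = counts.get(connector_type, 0) + 1
    return counts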
@@ -51,38 +74,40 @@ async def list_github_repositories( return repositories except ValueError as e: # Handle invalid token error specifically - logger.error(f"GitHub PAT validation failed for user {user.id}: {str(e)}") - raise HTTPException(status_code=400, detail=f"Invalid GitHub PAT: {str(e)}") + logger.error(f"GitHub PAT validation failed for user {user.id}: {e!s}") + raise HTTPException(status_code=400, detail=f"Invalid GitHub PAT: {e!s}") from e except Exception as e: - logger.error(f"Failed to fetch GitHub repositories for user {user.id}: {str(e)}") - raise HTTPException(status_code=500, detail="Failed to fetch GitHub repositories.") + logger.error(f"Failed to fetch GitHub repositories for user {user.id}: {e!s}") + raise HTTPException( + status_code=500, detail="Failed to fetch GitHub repositories." + ) from e + @router.post("/search-source-connectors/", response_model=SearchSourceConnectorRead) async def create_search_source_connector( connector: SearchSourceConnectorCreate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): """ Create a new search source connector. - + Each user can have only one connector of each type (SERPER_API, TAVILY_API, SLACK_CONNECTOR, etc.). The config must contain the appropriate keys for the connector type. """ try: # Check if a connector with the same type already exists for this user result = await session.execute( - select(SearchSourceConnector) - .filter( + select(SearchSourceConnector).filter( SearchSourceConnector.user_id == user.id, - SearchSourceConnector.connector_type == connector.connector_type + SearchSourceConnector.connector_type == connector.connector_type, ) ) existing_connector = result.scalars().first() if existing_connector: raise HTTPException( status_code=409, - detail=f"A connector with type {connector.connector_type} already exists. Each user can have only one connector of each type." + detail=f"A connector with type {connector.connector_type} already exists. Each user can have only one connector of each type.", ) db_connector = SearchSourceConnector(**connector.model_dump(), user_id=user.id) session.add(db_connector) @@ -91,56 +116,59 @@ async def create_search_source_connector( return db_connector except ValidationError as e: await session.rollback() - raise HTTPException( - status_code=422, - detail=f"Validation error: {str(e)}" - ) + raise HTTPException(status_code=422, detail=f"Validation error: {e!s}") from e except IntegrityError as e: await session.rollback() raise HTTPException( status_code=409, - detail=f"Integrity error: A connector with this type already exists. {str(e)}" - ) + detail=f"Integrity error: A connector with this type already exists. 
{e!s}", + ) from e except HTTPException: await session.rollback() raise except Exception as e: - logger.error(f"Failed to create search source connector: {str(e)}") + logger.error(f"Failed to create search source connector: {e!s}") await session.rollback() raise HTTPException( status_code=500, - detail=f"Failed to create search source connector: {str(e)}" - ) + detail=f"Failed to create search source connector: {e!s}", + ) from e -@router.get("/search-source-connectors/", response_model=List[SearchSourceConnectorRead]) + +@router.get( + "/search-source-connectors/", response_model=list[SearchSourceConnectorRead] +) async def read_search_source_connectors( skip: int = 0, limit: int = 100, - search_space_id: int = None, + search_space_id: int | None = None, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): """List all search source connectors for the current user.""" try: - query = select(SearchSourceConnector).filter(SearchSourceConnector.user_id == user.id) - - # No need to filter by search_space_id as connectors are user-owned, not search space specific - - result = await session.execute( - query.offset(skip).limit(limit) + query = select(SearchSourceConnector).filter( + SearchSourceConnector.user_id == user.id ) + + # No need to filter by search_space_id as connectors are user-owned, not search space specific + + result = await session.execute(query.offset(skip).limit(limit)) return result.scalars().all() except Exception as e: raise HTTPException( status_code=500, - detail=f"Failed to fetch search source connectors: {str(e)}" - ) + detail=f"Failed to fetch search source connectors: {e!s}", + ) from e -@router.get("/search-source-connectors/{connector_id}", response_model=SearchSourceConnectorRead) + +@router.get( + "/search-source-connectors/{connector_id}", response_model=SearchSourceConnectorRead +) async def read_search_source_connector( connector_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): """Get a specific search source connector by ID.""" try: @@ -149,31 +177,37 @@ async def read_search_source_connector( raise except Exception as e: raise HTTPException( - status_code=500, - detail=f"Failed to fetch search source connector: {str(e)}" - ) + status_code=500, detail=f"Failed to fetch search source connector: {e!s}" + ) from e -@router.put("/search-source-connectors/{connector_id}", response_model=SearchSourceConnectorRead) + +@router.put( + "/search-source-connectors/{connector_id}", response_model=SearchSourceConnectorRead +) async def update_search_source_connector( connector_id: int, connector_update: SearchSourceConnectorUpdate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): """ Update a search source connector. Handles partial updates, including merging changes into the 'config' field. 
""" - db_connector = await check_ownership(session, SearchSourceConnector, connector_id, user) - + db_connector = await check_ownership( + session, SearchSourceConnector, connector_id, user + ) + # Convert the sparse update data (only fields present in request) to a dict update_data = connector_update.model_dump(exclude_unset=True) # Special handling for 'config' field if "config" in update_data: - incoming_config = update_data["config"] # Config data from the request - existing_config = db_connector.config if db_connector.config else {} # Current config from DB - + incoming_config = update_data["config"] # Config data from the request + existing_config = ( + db_connector.config if db_connector.config else {} + ) # Current config from DB + # Merge incoming config into existing config # This preserves existing keys (like GITHUB_PAT) if they are not in the incoming data merged_config = existing_config.copy() @@ -182,26 +216,29 @@ async def update_search_source_connector( # -- Validation after merging -- # Validate the *merged* config based on the connector type # We need the connector type - use the one from the update if provided, else the existing one - current_connector_type = connector_update.connector_type if connector_update.connector_type is not None else db_connector.connector_type - + current_connector_type = ( + connector_update.connector_type + if connector_update.connector_type is not None + else db_connector.connector_type + ) + try: # We can reuse the base validator by creating a temporary base model instance # Note: This assumes 'name' and 'is_indexable' are not crucial for config validation itself temp_data_for_validation = { - "name": db_connector.name, # Use existing name + "name": db_connector.name, # Use existing name "connector_type": current_connector_type, - "is_indexable": db_connector.is_indexable, # Use existing value - "last_indexed_at": db_connector.last_indexed_at, # Not used by validator - "config": merged_config + "is_indexable": db_connector.is_indexable, # Use existing value + "last_indexed_at": db_connector.last_indexed_at, # Not used by validator + "config": merged_config, } SearchSourceConnectorBase.model_validate(temp_data_for_validation) except ValidationError as e: # Raise specific validation error for the merged config raise HTTPException( - status_code=422, - detail=f"Validation error for merged config: {str(e)}" - ) - + status_code=422, detail=f"Validation error for merged config: {e!s}" + ) from e + # If validation passes, update the main update_data dict with the merged config update_data["config"] = merged_config @@ -210,20 +247,19 @@ async def update_search_source_connector( # Prevent changing connector_type if it causes a duplicate (check moved here) if key == "connector_type" and value != db_connector.connector_type: result = await session.execute( - select(SearchSourceConnector) - .filter( + select(SearchSourceConnector).filter( SearchSourceConnector.user_id == user.id, SearchSourceConnector.connector_type == value, - SearchSourceConnector.id != connector_id + SearchSourceConnector.id != connector_id, ) ) existing_connector = result.scalars().first() if existing_connector: raise HTTPException( status_code=409, - detail=f"A connector with type {value} already exists. Each user can have only one connector of each type." + detail=f"A connector with type {value} already exists. 
Each user can have only one connector of each type.", ) - + setattr(db_connector, key, value) try: @@ -234,26 +270,31 @@ async def update_search_source_connector( await session.rollback() # This might occur if connector_type constraint is violated somehow after the check raise HTTPException( - status_code=409, - detail=f"Database integrity error during update: {str(e)}" - ) + status_code=409, detail=f"Database integrity error during update: {e!s}" + ) from e except Exception as e: await session.rollback() - logger.error(f"Failed to update search source connector {connector_id}: {e}", exc_info=True) + logger.error( + f"Failed to update search source connector {connector_id}: {e}", + exc_info=True, + ) raise HTTPException( status_code=500, - detail=f"Failed to update search source connector: {str(e)}" - ) + detail=f"Failed to update search source connector: {e!s}", + ) from e + @router.delete("/search-source-connectors/{connector_id}", response_model=dict) async def delete_search_source_connector( connector_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): """Delete a search source connector.""" try: - db_connector = await check_ownership(session, SearchSourceConnector, connector_id, user) + db_connector = await check_ownership( + session, SearchSourceConnector, connector_id, user + ) await session.delete(db_connector) await session.commit() return {"message": "Search source connector deleted successfully"} @@ -263,48 +304,61 @@ async def delete_search_source_connector( await session.rollback() raise HTTPException( status_code=500, - detail=f"Failed to delete search source connector: {str(e)}" - ) + detail=f"Failed to delete search source connector: {e!s}", + ) from e -@router.post("/search-source-connectors/{connector_id}/index", response_model=Dict[str, Any]) + +@router.post( + "/search-source-connectors/{connector_id}/index", response_model=dict[str, Any] +) async def index_connector_content( connector_id: int, - search_space_id: int = Query(..., description="ID of the search space to store indexed content"), - start_date: str = Query(None, description="Start date for indexing (YYYY-MM-DD format). If not provided, uses last_indexed_at or defaults to 365 days ago"), - end_date: str = Query(None, description="End date for indexing (YYYY-MM-DD format). If not provided, uses today's date"), + search_space_id: int = Query( + ..., description="ID of the search space to store indexed content" + ), + start_date: str = Query( + None, + description="Start date for indexing (YYYY-MM-DD format). If not provided, uses last_indexed_at or defaults to 365 days ago", + ), + end_date: str = Query( + None, + description="End date for indexing (YYYY-MM-DD format). If not provided, uses today's date", + ), session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), - background_tasks: BackgroundTasks = None + background_tasks: BackgroundTasks = None, ): """ Index content from a connector to a search space. 
- + Currently supports: - SLACK_CONNECTOR: Indexes messages from all accessible Slack channels - NOTION_CONNECTOR: Indexes pages from all accessible Notion pages - GITHUB_CONNECTOR: Indexes code and documentation from GitHub repositories - LINEAR_CONNECTOR: Indexes issues and comments from Linear - DISCORD_CONNECTOR: Indexes messages from all accessible Discord channels - + Args: connector_id: ID of the connector to use search_space_id: ID of the search space to store indexed content background_tasks: FastAPI background tasks - + Returns: Dictionary with indexing status """ try: # Check if the connector belongs to the user - connector = await check_ownership(session, SearchSourceConnector, connector_id, user) - + connector = await check_ownership( + session, SearchSourceConnector, connector_id, user + ) + # Check if the search space belongs to the user - search_space = await check_ownership(session, SearchSpace, search_space_id, user) - + await check_ownership(session, SearchSpace, search_space_id, user) + # Handle different connector types response_message = "" today_str = datetime.now().strftime("%Y-%m-%d") - + # Determine the actual date range to use if start_date is None: # Use last_indexed_at or default to 365 days ago @@ -316,37 +370,72 @@ async def index_connector_content( else: indexing_from = connector.last_indexed_at.strftime("%Y-%m-%d") else: - indexing_from = (datetime.now() - timedelta(days=365)).strftime("%Y-%m-%d") + indexing_from = (datetime.now() - timedelta(days=365)).strftime( + "%Y-%m-%d" + ) else: indexing_from = start_date - - if end_date is None: - indexing_to = today_str - else: - indexing_to = end_date + + indexing_to = end_date if end_date else today_str if connector.connector_type == SearchSourceConnectorType.SLACK_CONNECTOR: # Run indexing in background - logger.info(f"Triggering Slack indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}") - background_tasks.add_task(run_slack_indexing_with_new_session, connector_id, search_space_id, str(user.id), indexing_from, indexing_to) + logger.info( + f"Triggering Slack indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}" + ) + background_tasks.add_task( + run_slack_indexing_with_new_session, + connector_id, + search_space_id, + str(user.id), + indexing_from, + indexing_to, + ) response_message = "Slack indexing started in the background." elif connector.connector_type == SearchSourceConnectorType.NOTION_CONNECTOR: # Run indexing in background - logger.info(f"Triggering Notion indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}") - background_tasks.add_task(run_notion_indexing_with_new_session, connector_id, search_space_id, str(user.id), indexing_from, indexing_to) + logger.info( + f"Triggering Notion indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}" + ) + background_tasks.add_task( + run_notion_indexing_with_new_session, + connector_id, + search_space_id, + str(user.id), + indexing_from, + indexing_to, + ) response_message = "Notion indexing started in the background." 
- + elif connector.connector_type == SearchSourceConnectorType.GITHUB_CONNECTOR: # Run indexing in background - logger.info(f"Triggering GitHub indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}") - background_tasks.add_task(run_github_indexing_with_new_session, connector_id, search_space_id, str(user.id), indexing_from, indexing_to) + logger.info( + f"Triggering GitHub indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}" + ) + background_tasks.add_task( + run_github_indexing_with_new_session, + connector_id, + search_space_id, + str(user.id), + indexing_from, + indexing_to, + ) response_message = "GitHub indexing started in the background." - + elif connector.connector_type == SearchSourceConnectorType.LINEAR_CONNECTOR: # Run indexing in background - logger.info(f"Triggering Linear indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}") - background_tasks.add_task(run_linear_indexing_with_new_session, connector_id, search_space_id, str(user.id), indexing_from, indexing_to) + logger.info( + f"Triggering Linear indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}" + ) + background_tasks.add_task( + run_linear_indexing_with_new_session, + connector_id, + search_space_id, + str(user.id), + indexing_from, + indexing_to, + ) response_message = "Linear indexing started in the background." elif connector.connector_type == SearchSourceConnectorType.DISCORD_CONNECTOR: @@ -355,71 +444,83 @@ async def index_connector_content( f"Triggering Discord indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}" ) background_tasks.add_task( - run_discord_indexing_with_new_session, connector_id, search_space_id, str(user.id), indexing_from, indexing_to + run_discord_indexing_with_new_session, + connector_id, + search_space_id, + str(user.id), + indexing_from, + indexing_to, ) response_message = "Discord indexing started in the background." else: raise HTTPException( status_code=400, - detail=f"Indexing not supported for connector type: {connector.connector_type}" + detail=f"Indexing not supported for connector type: {connector.connector_type}", ) return { - "message": response_message, - "connector_id": connector_id, + "message": response_message, + "connector_id": connector_id, "search_space_id": search_space_id, "indexing_from": indexing_from, - "indexing_to": indexing_to + "indexing_to": indexing_to, } except HTTPException: raise except Exception as e: - logger.error(f"Failed to initiate indexing for connector {connector_id}: {e}", exc_info=True) - raise HTTPException( - status_code=500, - detail=f"Failed to initiate indexing: {str(e)}" + logger.error( + f"Failed to initiate indexing for connector {connector_id}: {e}", + exc_info=True, ) - -async def update_connector_last_indexed( - session: AsyncSession, - connector_id: int -): + raise HTTPException( + status_code=500, detail=f"Failed to initiate indexing: {e!s}" + ) from e + + +async def update_connector_last_indexed(session: AsyncSession, connector_id: int): """ Update the last_indexed_at timestamp for a connector. 
- + Args: session: Database session connector_id: ID of the connector to update """ try: result = await session.execute( - select(SearchSourceConnector) - .filter(SearchSourceConnector.id == connector_id) + select(SearchSourceConnector).filter( + SearchSourceConnector.id == connector_id + ) ) connector = result.scalars().first() - + if connector: connector.last_indexed_at = datetime.now() await session.commit() logger.info(f"Updated last_indexed_at for connector {connector_id}") except Exception as e: - logger.error(f"Failed to update last_indexed_at for connector {connector_id}: {str(e)}") + logger.error( + f"Failed to update last_indexed_at for connector {connector_id}: {e!s}" + ) await session.rollback() + async def run_slack_indexing_with_new_session( connector_id: int, search_space_id: int, user_id: str, start_date: str, - end_date: str + end_date: str, ): """ Create a new session and run the Slack indexing task. This prevents session leaks by creating a dedicated session for the background task. """ async with async_session_maker() as session: - await run_slack_indexing(session, connector_id, search_space_id, user_id, start_date, end_date) + await run_slack_indexing( + session, connector_id, search_space_id, user_id, start_date, end_date + ) + async def run_slack_indexing( session: AsyncSession, @@ -427,11 +528,11 @@ async def run_slack_indexing( search_space_id: int, user_id: str, start_date: str, - end_date: str + end_date: str, ): """ Background task to run Slack indexing. - + Args: session: Database session connector_id: ID of the Slack connector @@ -449,31 +550,39 @@ async def run_slack_indexing( user_id=user_id, start_date=start_date, end_date=end_date, - update_last_indexed=False # Don't update timestamp in the indexing function + update_last_indexed=False, # Don't update timestamp in the indexing function ) - + # Only update last_indexed_at if indexing was successful (either new docs or updated docs) if documents_processed > 0: await update_connector_last_indexed(session, connector_id) - logger.info(f"Slack indexing completed successfully: {documents_processed} documents processed") + logger.info( + f"Slack indexing completed successfully: {documents_processed} documents processed" + ) else: - logger.error(f"Slack indexing failed or no documents processed: {error_or_warning}") + logger.error( + f"Slack indexing failed or no documents processed: {error_or_warning}" + ) except Exception as e: - logger.error(f"Error in background Slack indexing task: {str(e)}") + logger.error(f"Error in background Slack indexing task: {e!s}") + async def run_notion_indexing_with_new_session( connector_id: int, search_space_id: int, user_id: str, start_date: str, - end_date: str + end_date: str, ): """ Create a new session and run the Notion indexing task. This prevents session leaks by creating a dedicated session for the background task. """ async with async_session_maker() as session: - await run_notion_indexing(session, connector_id, search_space_id, user_id, start_date, end_date) + await run_notion_indexing( + session, connector_id, search_space_id, user_id, start_date, end_date + ) + async def run_notion_indexing( session: AsyncSession, @@ -481,11 +590,11 @@ async def run_notion_indexing( search_space_id: int, user_id: str, start_date: str, - end_date: str + end_date: str, ): """ Background task to run Notion indexing. 
- + Args: session: Database session connector_id: ID of the Notion connector @@ -503,17 +612,22 @@ async def run_notion_indexing( user_id=user_id, start_date=start_date, end_date=end_date, - update_last_indexed=False # Don't update timestamp in the indexing function + update_last_indexed=False, # Don't update timestamp in the indexing function ) - + # Only update last_indexed_at if indexing was successful (either new docs or updated docs) if documents_processed > 0: await update_connector_last_indexed(session, connector_id) - logger.info(f"Notion indexing completed successfully: {documents_processed} documents processed") + logger.info( + f"Notion indexing completed successfully: {documents_processed} documents processed" + ) else: - logger.error(f"Notion indexing failed or no documents processed: {error_or_warning}") + logger.error( + f"Notion indexing failed or no documents processed: {error_or_warning}" + ) except Exception as e: - logger.error(f"Error in background Notion indexing task: {str(e)}") + logger.error(f"Error in background Notion indexing task: {e!s}") + # Add new helper functions for GitHub indexing async def run_github_indexing_with_new_session( @@ -521,94 +635,135 @@ async def run_github_indexing_with_new_session( search_space_id: int, user_id: str, start_date: str, - end_date: str + end_date: str, ): """Wrapper to run GitHub indexing with its own database session.""" - logger.info(f"Background task started: Indexing GitHub connector {connector_id} into space {search_space_id} from {start_date} to {end_date}") + logger.info( + f"Background task started: Indexing GitHub connector {connector_id} into space {search_space_id} from {start_date} to {end_date}" + ) async with async_session_maker() as session: - await run_github_indexing(session, connector_id, search_space_id, user_id, start_date, end_date) + await run_github_indexing( + session, connector_id, search_space_id, user_id, start_date, end_date + ) logger.info(f"Background task finished: Indexing GitHub connector {connector_id}") + async def run_github_indexing( session: AsyncSession, connector_id: int, search_space_id: int, user_id: str, start_date: str, - end_date: str + end_date: str, ): """Runs the GitHub indexing task and updates the timestamp.""" try: indexed_count, error_message = await index_github_repos( - session, connector_id, search_space_id, user_id, start_date, end_date, update_last_indexed=False + session, + connector_id, + search_space_id, + user_id, + start_date, + end_date, + update_last_indexed=False, ) if error_message: - logger.error(f"GitHub indexing failed for connector {connector_id}: {error_message}") + logger.error( + f"GitHub indexing failed for connector {connector_id}: {error_message}" + ) # Optionally update status in DB to indicate failure else: - logger.info(f"GitHub indexing successful for connector {connector_id}. Indexed {indexed_count} documents.") + logger.info( + f"GitHub indexing successful for connector {connector_id}. Indexed {indexed_count} documents." 
+ ) # Update the last indexed timestamp only on success await update_connector_last_indexed(session, connector_id) - await session.commit() # Commit timestamp update + await session.commit() # Commit timestamp update except Exception as e: await session.rollback() - logger.error(f"Critical error in run_github_indexing for connector {connector_id}: {e}", exc_info=True) + logger.error( + f"Critical error in run_github_indexing for connector {connector_id}: {e}", + exc_info=True, + ) # Optionally update status in DB to indicate failure + # Add new helper functions for Linear indexing async def run_linear_indexing_with_new_session( connector_id: int, search_space_id: int, user_id: str, start_date: str, - end_date: str + end_date: str, ): """Wrapper to run Linear indexing with its own database session.""" - logger.info(f"Background task started: Indexing Linear connector {connector_id} into space {search_space_id} from {start_date} to {end_date}") + logger.info( + f"Background task started: Indexing Linear connector {connector_id} into space {search_space_id} from {start_date} to {end_date}" + ) async with async_session_maker() as session: - await run_linear_indexing(session, connector_id, search_space_id, user_id, start_date, end_date) + await run_linear_indexing( + session, connector_id, search_space_id, user_id, start_date, end_date + ) logger.info(f"Background task finished: Indexing Linear connector {connector_id}") + async def run_linear_indexing( session: AsyncSession, connector_id: int, search_space_id: int, user_id: str, start_date: str, - end_date: str + end_date: str, ): """Runs the Linear indexing task and updates the timestamp.""" try: indexed_count, error_message = await index_linear_issues( - session, connector_id, search_space_id, user_id, start_date, end_date, update_last_indexed=False + session, + connector_id, + search_space_id, + user_id, + start_date, + end_date, + update_last_indexed=False, ) if error_message: - logger.error(f"Linear indexing failed for connector {connector_id}: {error_message}") + logger.error( + f"Linear indexing failed for connector {connector_id}: {error_message}" + ) # Optionally update status in DB to indicate failure else: - logger.info(f"Linear indexing successful for connector {connector_id}. Indexed {indexed_count} documents.") + logger.info( + f"Linear indexing successful for connector {connector_id}. Indexed {indexed_count} documents." + ) # Update the last indexed timestamp only on success await update_connector_last_indexed(session, connector_id) - await session.commit() # Commit timestamp update + await session.commit() # Commit timestamp update except Exception as e: await session.rollback() - logger.error(f"Critical error in run_linear_indexing for connector {connector_id}: {e}", exc_info=True) + logger.error( + f"Critical error in run_linear_indexing for connector {connector_id}: {e}", + exc_info=True, + ) # Optionally update status in DB to indicate failure + # Add new helper functions for discord indexing async def run_discord_indexing_with_new_session( connector_id: int, search_space_id: int, user_id: str, start_date: str, - end_date: str + end_date: str, ): """ Create a new session and run the Discord indexing task. This prevents session leaks by creating a dedicated session for the background task. 
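# A minimal editorial sketch of the "dedicated session per background task"
# pattern used by the run_*_indexing_with_new_session wrappers: the session is
# opened inside the task so its lifetime is scoped to that task and nothing
# leaks after the scheduling request returns. The DSN and do_indexing() are
# hypothetical placeholders; asyncpg/SQLAlchemy 2.x async support is assumed.
from sqlalchemy.ext.asyncio import (
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)

engine = create_async_engine("postgresql+asyncpg://user:pass@localhost/surfsense")
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)


async def do_indexing(session: AsyncSession, connector_id: int) -> None:
    ...  # placeholder for the actual indexing logic


async def run_indexing_with_new_session(connector_id: int) -> None:
    # Each background task gets its own session, committed/closed here.
    async with async_session_maker() as session:
        await do_indexing(session, connector_id)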
""" async with async_session_maker() as session: - await run_discord_indexing(session, connector_id, search_space_id, user_id, start_date, end_date) + await run_discord_indexing( + session, connector_id, search_space_id, user_id, start_date, end_date + ) + async def run_discord_indexing( session: AsyncSession, @@ -616,7 +771,7 @@ async def run_discord_indexing( search_space_id: int, user_id: str, start_date: str, - end_date: str + end_date: str, ): """ Background task to run Discord indexing. @@ -637,14 +792,18 @@ async def run_discord_indexing( user_id=user_id, start_date=start_date, end_date=end_date, - update_last_indexed=False # Don't update timestamp in the indexing function + update_last_indexed=False, # Don't update timestamp in the indexing function ) # Only update last_indexed_at if indexing was successful (either new docs or updated docs) if documents_processed > 0: await update_connector_last_indexed(session, connector_id) - logger.info(f"Discord indexing completed successfully: {documents_processed} documents processed") + logger.info( + f"Discord indexing completed successfully: {documents_processed} documents processed" + ) else: - logger.error(f"Discord indexing failed or no documents processed: {error_or_warning}") + logger.error( + f"Discord indexing failed or no documents processed: {error_or_warning}" + ) except Exception as e: - logger.error(f"Error in background Discord indexing task: {str(e)}") \ No newline at end of file + logger.error(f"Error in background Discord indexing task: {e!s}") diff --git a/surfsense_backend/app/routes/search_spaces_routes.py b/surfsense_backend/app/routes/search_spaces_routes.py index 0f07f71..dc7f69a 100644 --- a/surfsense_backend/app/routes/search_spaces_routes.py +++ b/surfsense_backend/app/routes/search_spaces_routes.py @@ -1,20 +1,20 @@ -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from typing import List -from app.db import get_async_session, User, SearchSpace -from app.schemas import SearchSpaceCreate, SearchSpaceUpdate, SearchSpaceRead + +from app.db import SearchSpace, User, get_async_session +from app.schemas import SearchSpaceCreate, SearchSpaceRead, SearchSpaceUpdate from app.users import current_active_user from app.utils.check_ownership import check_ownership -from fastapi import HTTPException router = APIRouter() + @router.post("/searchspaces/", response_model=SearchSpaceRead) async def create_search_space( search_space: SearchSpaceCreate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): try: db_search_space = SearchSpace(**search_space.model_dump(), user_id=user.id) @@ -27,16 +27,16 @@ async def create_search_space( except Exception as e: await session.rollback() raise HTTPException( - status_code=500, - detail=f"Failed to create search space: {str(e)}" - ) + status_code=500, detail=f"Failed to create search space: {e!s}" + ) from e -@router.get("/searchspaces/", response_model=List[SearchSpaceRead]) + +@router.get("/searchspaces/", response_model=list[SearchSpaceRead]) async def read_search_spaces( skip: int = 0, limit: int = 200, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): try: result = await session.execute( @@ -48,37 +48,41 @@ async def read_search_spaces( return result.scalars().all() except 
Exception as e: raise HTTPException( - status_code=500, - detail=f"Failed to fetch search spaces: {str(e)}" - ) + status_code=500, detail=f"Failed to fetch search spaces: {e!s}" + ) from e + @router.get("/searchspaces/{search_space_id}", response_model=SearchSpaceRead) async def read_search_space( search_space_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): try: - search_space = await check_ownership(session, SearchSpace, search_space_id, user) + search_space = await check_ownership( + session, SearchSpace, search_space_id, user + ) return search_space - + except HTTPException: raise except Exception as e: raise HTTPException( - status_code=500, - detail=f"Failed to fetch search space: {str(e)}" - ) + status_code=500, detail=f"Failed to fetch search space: {e!s}" + ) from e + @router.put("/searchspaces/{search_space_id}", response_model=SearchSpaceRead) async def update_search_space( search_space_id: int, search_space_update: SearchSpaceUpdate, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): try: - db_search_space = await check_ownership(session, SearchSpace, search_space_id, user) + db_search_space = await check_ownership( + session, SearchSpace, search_space_id, user + ) update_data = search_space_update.model_dump(exclude_unset=True) for key, value in update_data.items(): setattr(db_search_space, key, value) @@ -90,18 +94,20 @@ async def update_search_space( except Exception as e: await session.rollback() raise HTTPException( - status_code=500, - detail=f"Failed to update search space: {str(e)}" - ) + status_code=500, detail=f"Failed to update search space: {e!s}" + ) from e + @router.delete("/searchspaces/{search_space_id}", response_model=dict) async def delete_search_space( search_space_id: int, session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user) + user: User = Depends(current_active_user), ): try: - db_search_space = await check_ownership(session, SearchSpace, search_space_id, user) + db_search_space = await check_ownership( + session, SearchSpace, search_space_id, user + ) await session.delete(db_search_space) await session.commit() return {"message": "Search space deleted successfully"} @@ -110,6 +116,5 @@ async def delete_search_space( except Exception as e: await session.rollback() raise HTTPException( - status_code=500, - detail=f"Failed to delete search space: {str(e)}" - ) \ No newline at end of file + status_code=500, detail=f"Failed to delete search space: {e!s}" + ) from e diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index 89525c9..e38d534 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -1,62 +1,78 @@ -from .base import TimestampModel, IDModel -from .users import UserRead, UserCreate, UserUpdate -from .search_space import SearchSpaceBase, SearchSpaceCreate, SearchSpaceUpdate, SearchSpaceRead +from .base import IDModel, TimestampModel +from .chats import AISDKChatRequest, ChatBase, ChatCreate, ChatRead, ChatUpdate +from .chunks import ChunkBase, ChunkCreate, ChunkRead, ChunkUpdate from .documents import ( - ExtensionDocumentMetadata, - ExtensionDocumentContent, DocumentBase, + DocumentRead, DocumentsCreate, DocumentUpdate, - DocumentRead, + ExtensionDocumentContent, + ExtensionDocumentMetadata, ) -from .chunks import 
ChunkBase, ChunkCreate, ChunkUpdate, ChunkRead -from .podcasts import PodcastBase, PodcastCreate, PodcastUpdate, PodcastRead, PodcastGenerateRequest -from .chats import ChatBase, ChatCreate, ChatUpdate, ChatRead, AISDKChatRequest -from .search_source_connector import SearchSourceConnectorBase, SearchSourceConnectorCreate, SearchSourceConnectorUpdate, SearchSourceConnectorRead -from .llm_config import LLMConfigBase, LLMConfigCreate, LLMConfigUpdate, LLMConfigRead -from .logs import LogBase, LogCreate, LogUpdate, LogRead, LogFilter +from .llm_config import LLMConfigBase, LLMConfigCreate, LLMConfigRead, LLMConfigUpdate +from .logs import LogBase, LogCreate, LogFilter, LogRead, LogUpdate +from .podcasts import ( + PodcastBase, + PodcastCreate, + PodcastGenerateRequest, + PodcastRead, + PodcastUpdate, +) +from .search_source_connector import ( + SearchSourceConnectorBase, + SearchSourceConnectorCreate, + SearchSourceConnectorRead, + SearchSourceConnectorUpdate, +) +from .search_space import ( + SearchSpaceBase, + SearchSpaceCreate, + SearchSpaceRead, + SearchSpaceUpdate, +) +from .users import UserCreate, UserRead, UserUpdate __all__ = [ "AISDKChatRequest", - "TimestampModel", - "IDModel", - "UserRead", - "UserCreate", - "UserUpdate", - "SearchSpaceBase", - "SearchSpaceCreate", - "SearchSpaceUpdate", - "SearchSpaceRead", - "ExtensionDocumentMetadata", - "ExtensionDocumentContent", - "DocumentBase", - "DocumentsCreate", - "DocumentUpdate", - "DocumentRead", - "ChunkBase", - "ChunkCreate", - "ChunkUpdate", - "ChunkRead", - "PodcastBase", - "PodcastCreate", - "PodcastUpdate", - "PodcastRead", - "PodcastGenerateRequest", "ChatBase", "ChatCreate", - "ChatUpdate", "ChatRead", - "SearchSourceConnectorBase", - "SearchSourceConnectorCreate", - "SearchSourceConnectorUpdate", - "SearchSourceConnectorRead", + "ChatUpdate", + "ChunkBase", + "ChunkCreate", + "ChunkRead", + "ChunkUpdate", + "DocumentBase", + "DocumentRead", + "DocumentUpdate", + "DocumentsCreate", + "ExtensionDocumentContent", + "ExtensionDocumentMetadata", + "IDModel", "LLMConfigBase", "LLMConfigCreate", - "LLMConfigUpdate", "LLMConfigRead", + "LLMConfigUpdate", "LogBase", "LogCreate", - "LogUpdate", - "LogRead", "LogFilter", -] \ No newline at end of file + "LogRead", + "LogUpdate", + "PodcastBase", + "PodcastCreate", + "PodcastGenerateRequest", + "PodcastRead", + "PodcastUpdate", + "SearchSourceConnectorBase", + "SearchSourceConnectorCreate", + "SearchSourceConnectorRead", + "SearchSourceConnectorUpdate", + "SearchSpaceBase", + "SearchSpaceCreate", + "SearchSpaceRead", + "SearchSpaceUpdate", + "TimestampModel", + "UserCreate", + "UserRead", + "UserUpdate", +] diff --git a/surfsense_backend/app/schemas/base.py b/surfsense_backend/app/schemas/base.py index d357aab..a5b4f5e 100644 --- a/surfsense_backend/app/schemas/base.py +++ b/surfsense_backend/app/schemas/base.py @@ -1,10 +1,13 @@ from datetime import datetime + from pydantic import BaseModel, ConfigDict + class TimestampModel(BaseModel): created_at: datetime model_config = ConfigDict(from_attributes=True) + class IDModel(BaseModel): id: int - model_config = ConfigDict(from_attributes=True) \ No newline at end of file + model_config = ConfigDict(from_attributes=True) diff --git a/surfsense_backend/app/schemas/chats.py b/surfsense_backend/app/schemas/chats.py index 82191fb..1dfedef 100644 --- a/surfsense_backend/app/schemas/chats.py +++ b/surfsense_backend/app/schemas/chats.py @@ -1,7 +1,8 @@ -from typing import Any, Dict, List, Optional +from typing import Any + +from pydantic import 
BaseModel, ConfigDict from app.db import ChatType -from pydantic import BaseModel, ConfigDict from .base import IDModel, TimestampModel @@ -9,39 +10,43 @@ from .base import IDModel, TimestampModel class ChatBase(BaseModel): type: ChatType title: str - initial_connectors: Optional[List[str]] = None - messages: List[Any] + initial_connectors: list[str] | None = None + messages: list[Any] search_space_id: int - + class ClientAttachment(BaseModel): name: str - contentType: str + content_type: str url: str class ToolInvocation(BaseModel): - toolCallId: str - toolName: str + tool_call_id: str + tool_name: str args: dict result: dict - - + + # class ClientMessage(BaseModel): # role: str # content: str # experimental_attachments: Optional[List[ClientAttachment]] = None # toolInvocations: Optional[List[ToolInvocation]] = None - + + class AISDKChatRequest(BaseModel): - messages: List[Any] - data: Optional[Dict[str, Any]] = None + messages: list[Any] + data: dict[str, Any] | None = None + class ChatCreate(ChatBase): pass + class ChatUpdate(ChatBase): pass + class ChatRead(ChatBase, IDModel, TimestampModel): - model_config = ConfigDict(from_attributes=True) \ No newline at end of file + model_config = ConfigDict(from_attributes=True) diff --git a/surfsense_backend/app/schemas/chunks.py b/surfsense_backend/app/schemas/chunks.py index de0764f..7fec0d4 100644 --- a/surfsense_backend/app/schemas/chunks.py +++ b/surfsense_backend/app/schemas/chunks.py @@ -1,15 +1,20 @@ from pydantic import BaseModel, ConfigDict + from .base import IDModel, TimestampModel + class ChunkBase(BaseModel): content: str document_id: int + class ChunkCreate(ChunkBase): pass + class ChunkUpdate(ChunkBase): pass + class ChunkRead(ChunkBase, IDModel, TimestampModel): - model_config = ConfigDict(from_attributes=True) \ No newline at end of file + model_config = ConfigDict(from_attributes=True) diff --git a/surfsense_backend/app/schemas/documents.py b/surfsense_backend/app/schemas/documents.py index b83a749..b98ccfd 100644 --- a/surfsense_backend/app/schemas/documents.py +++ b/surfsense_backend/app/schemas/documents.py @@ -1,8 +1,10 @@ -from typing import List -from pydantic import BaseModel, ConfigDict -from app.db import DocumentType from datetime import datetime +from pydantic import BaseModel, ConfigDict + +from app.db import DocumentType + + class ExtensionDocumentMetadata(BaseModel): BrowsingSessionId: str VisitedWebPageURL: str @@ -11,21 +13,28 @@ class ExtensionDocumentMetadata(BaseModel): VisitedWebPageReffererURL: str VisitedWebPageVisitDurationInMilliseconds: str + class ExtensionDocumentContent(BaseModel): metadata: ExtensionDocumentMetadata - pageContent: str + pageContent: str # noqa: N815 + class DocumentBase(BaseModel): document_type: DocumentType - content: List[ExtensionDocumentContent] | List[str] | str # Updated to allow string content + content: ( + list[ExtensionDocumentContent] | list[str] | str + ) # Updated to allow string content search_space_id: int + class DocumentsCreate(DocumentBase): pass + class DocumentUpdate(DocumentBase): pass + class DocumentRead(BaseModel): id: int title: str @@ -34,6 +43,5 @@ class DocumentRead(BaseModel): content: str # Changed to string to match frontend created_at: datetime search_space_id: int - - model_config = ConfigDict(from_attributes=True) + model_config = ConfigDict(from_attributes=True) diff --git a/surfsense_backend/app/schemas/llm_config.py b/surfsense_backend/app/schemas/llm_config.py index f4032cb..c3c0033 100644 --- a/surfsense_backend/app/schemas/llm_config.py +++ 
b/surfsense_backend/app/schemas/llm_config.py @@ -1,34 +1,61 @@ -from datetime import datetime import uuid -from typing import Optional, Dict, Any +from datetime import datetime +from typing import Any + from pydantic import BaseModel, ConfigDict, Field -from .base import IDModel, TimestampModel + from app.db import LiteLLMProvider +from .base import IDModel, TimestampModel + + class LLMConfigBase(BaseModel): - name: str = Field(..., max_length=100, description="User-friendly name for the LLM configuration") + name: str = Field( + ..., max_length=100, description="User-friendly name for the LLM configuration" + ) provider: LiteLLMProvider = Field(..., description="LiteLLM provider type") - custom_provider: Optional[str] = Field(None, max_length=100, description="Custom provider name when provider is CUSTOM") - model_name: str = Field(..., max_length=100, description="Model name without provider prefix") + custom_provider: str | None = Field( + None, max_length=100, description="Custom provider name when provider is CUSTOM" + ) + model_name: str = Field( + ..., max_length=100, description="Model name without provider prefix" + ) api_key: str = Field(..., description="API key for the provider") - api_base: Optional[str] = Field(None, max_length=500, description="Optional API base URL") - litellm_params: Optional[Dict[str, Any]] = Field(default=None, description="Additional LiteLLM parameters") + api_base: str | None = Field( + None, max_length=500, description="Optional API base URL" + ) + litellm_params: dict[str, Any] | None = Field( + default=None, description="Additional LiteLLM parameters" + ) + class LLMConfigCreate(LLMConfigBase): pass + class LLMConfigUpdate(BaseModel): - name: Optional[str] = Field(None, max_length=100, description="User-friendly name for the LLM configuration") - provider: Optional[LiteLLMProvider] = Field(None, description="LiteLLM provider type") - custom_provider: Optional[str] = Field(None, max_length=100, description="Custom provider name when provider is CUSTOM") - model_name: Optional[str] = Field(None, max_length=100, description="Model name without provider prefix") - api_key: Optional[str] = Field(None, description="API key for the provider") - api_base: Optional[str] = Field(None, max_length=500, description="Optional API base URL") - litellm_params: Optional[Dict[str, Any]] = Field(None, description="Additional LiteLLM parameters") + name: str | None = Field( + None, max_length=100, description="User-friendly name for the LLM configuration" + ) + provider: LiteLLMProvider | None = Field(None, description="LiteLLM provider type") + custom_provider: str | None = Field( + None, max_length=100, description="Custom provider name when provider is CUSTOM" + ) + model_name: str | None = Field( + None, max_length=100, description="Model name without provider prefix" + ) + api_key: str | None = Field(None, description="API key for the provider") + api_base: str | None = Field( + None, max_length=500, description="Optional API base URL" + ) + litellm_params: dict[str, Any] | None = Field( + None, description="Additional LiteLLM parameters" + ) + class LLMConfigRead(LLMConfigBase, IDModel, TimestampModel): id: int created_at: datetime user_id: uuid.UUID - model_config = ConfigDict(from_attributes=True) \ No newline at end of file + model_config = ConfigDict(from_attributes=True) diff --git a/surfsense_backend/app/schemas/logs.py b/surfsense_backend/app/schemas/logs.py index 1d9d7e7..a47d5db 100644 --- a/surfsense_backend/app/schemas/logs.py +++ 
b/surfsense_backend/app/schemas/logs.py @@ -1,30 +1,37 @@ from datetime import datetime -from typing import Optional, Dict, Any +from typing import Any + from pydantic import BaseModel, ConfigDict -from .base import IDModel, TimestampModel + from app.db import LogLevel, LogStatus +from .base import IDModel, TimestampModel + + class LogBase(BaseModel): level: LogLevel status: LogStatus message: str - source: Optional[str] = None - log_metadata: Optional[Dict[str, Any]] = None + source: str | None = None + log_metadata: dict[str, Any] | None = None + class LogCreate(BaseModel): level: LogLevel status: LogStatus message: str - source: Optional[str] = None - log_metadata: Optional[Dict[str, Any]] = None + source: str | None = None + log_metadata: dict[str, Any] | None = None search_space_id: int + class LogUpdate(BaseModel): - level: Optional[LogLevel] = None - status: Optional[LogStatus] = None - message: Optional[str] = None - source: Optional[str] = None - log_metadata: Optional[Dict[str, Any]] = None + level: LogLevel | None = None + status: LogStatus | None = None + message: str | None = None + source: str | None = None + log_metadata: dict[str, Any] | None = None + class LogRead(LogBase, IDModel, TimestampModel): id: int @@ -33,12 +40,13 @@ class LogRead(LogBase, IDModel, TimestampModel): model_config = ConfigDict(from_attributes=True) -class LogFilter(BaseModel): - level: Optional[LogLevel] = None - status: Optional[LogStatus] = None - source: Optional[str] = None - search_space_id: Optional[int] = None - start_date: Optional[datetime] = None - end_date: Optional[datetime] = None - model_config = ConfigDict(from_attributes=True) \ No newline at end of file +class LogFilter(BaseModel): + level: LogLevel | None = None + status: LogStatus | None = None + source: str | None = None + search_space_id: int | None = None + start_date: datetime | None = None + end_date: datetime | None = None + + model_config = ConfigDict(from_attributes=True) diff --git a/surfsense_backend/app/schemas/podcasts.py b/surfsense_backend/app/schemas/podcasts.py index 0356dd0..d86b315 100644 --- a/surfsense_backend/app/schemas/podcasts.py +++ b/surfsense_backend/app/schemas/podcasts.py @@ -1,24 +1,31 @@ +from typing import Any, Literal + from pydantic import BaseModel, ConfigDict -from typing import Any, List, Literal + from .base import IDModel, TimestampModel + class PodcastBase(BaseModel): title: str - podcast_transcript: List[Any] + podcast_transcript: list[Any] file_location: str = "" search_space_id: int + class PodcastCreate(PodcastBase): pass + class PodcastUpdate(PodcastBase): pass + class PodcastRead(PodcastBase, IDModel, TimestampModel): model_config = ConfigDict(from_attributes=True) + class PodcastGenerateRequest(BaseModel): type: Literal["DOCUMENT", "CHAT"] - ids: List[int] + ids: list[int] search_space_id: int - podcast_title: str = "SurfSense Podcast" \ No newline at end of file + podcast_title: str = "SurfSense Podcast" diff --git a/surfsense_backend/app/schemas/search_source_connector.py b/surfsense_backend/app/schemas/search_source_connector.py index 1225d54..719a9f9 100644 --- a/surfsense_backend/app/schemas/search_source_connector.py +++ b/surfsense_backend/app/schemas/search_source_connector.py @@ -1,102 +1,124 @@ -from datetime import datetime import uuid -from typing import Dict, Any, Optional -from pydantic import BaseModel, field_validator, ConfigDict -from .base import IDModel, TimestampModel +from datetime import datetime +from typing import Any + +from pydantic import BaseModel, 
ConfigDict, field_validator + from app.db import SearchSourceConnectorType +from .base import IDModel, TimestampModel + + class SearchSourceConnectorBase(BaseModel): name: str connector_type: SearchSourceConnectorType is_indexable: bool - last_indexed_at: Optional[datetime] = None - config: Dict[str, Any] - - @field_validator('config') + last_indexed_at: datetime | None = None + config: dict[str, Any] + + @field_validator("config") @classmethod - def validate_config_for_connector_type(cls, config: Dict[str, Any], values: Dict[str, Any]) -> Dict[str, Any]: - connector_type = values.data.get('connector_type') - + def validate_config_for_connector_type( + cls, config: dict[str, Any], values: dict[str, Any] + ) -> dict[str, Any]: + connector_type = values.data.get("connector_type") + if connector_type == SearchSourceConnectorType.SERPER_API: # For SERPER_API, only allow SERPER_API_KEY allowed_keys = ["SERPER_API_KEY"] if set(config.keys()) != set(allowed_keys): - raise ValueError(f"For SERPER_API connector type, config must only contain these keys: {allowed_keys}") - + raise ValueError( + f"For SERPER_API connector type, config must only contain these keys: {allowed_keys}" + ) + # Ensure the API key is not empty if not config.get("SERPER_API_KEY"): raise ValueError("SERPER_API_KEY cannot be empty") - + elif connector_type == SearchSourceConnectorType.TAVILY_API: # For TAVILY_API, only allow TAVILY_API_KEY allowed_keys = ["TAVILY_API_KEY"] if set(config.keys()) != set(allowed_keys): - raise ValueError(f"For TAVILY_API connector type, config must only contain these keys: {allowed_keys}") - + raise ValueError( + f"For TAVILY_API connector type, config must only contain these keys: {allowed_keys}" + ) + # Ensure the API key is not empty if not config.get("TAVILY_API_KEY"): raise ValueError("TAVILY_API_KEY cannot be empty") - + elif connector_type == SearchSourceConnectorType.LINKUP_API: # For LINKUP_API, only allow LINKUP_API_KEY allowed_keys = ["LINKUP_API_KEY"] if set(config.keys()) != set(allowed_keys): - raise ValueError(f"For LINKUP_API connector type, config must only contain these keys: {allowed_keys}") - + raise ValueError( + f"For LINKUP_API connector type, config must only contain these keys: {allowed_keys}" + ) + # Ensure the API key is not empty if not config.get("LINKUP_API_KEY"): raise ValueError("LINKUP_API_KEY cannot be empty") - + elif connector_type == SearchSourceConnectorType.SLACK_CONNECTOR: # For SLACK_CONNECTOR, only allow SLACK_BOT_TOKEN allowed_keys = ["SLACK_BOT_TOKEN"] if set(config.keys()) != set(allowed_keys): - raise ValueError(f"For SLACK_CONNECTOR connector type, config must only contain these keys: {allowed_keys}") + raise ValueError( + f"For SLACK_CONNECTOR connector type, config must only contain these keys: {allowed_keys}" + ) # Ensure the bot token is not empty if not config.get("SLACK_BOT_TOKEN"): raise ValueError("SLACK_BOT_TOKEN cannot be empty") - + elif connector_type == SearchSourceConnectorType.NOTION_CONNECTOR: # For NOTION_CONNECTOR, only allow NOTION_INTEGRATION_TOKEN allowed_keys = ["NOTION_INTEGRATION_TOKEN"] if set(config.keys()) != set(allowed_keys): - raise ValueError(f"For NOTION_CONNECTOR connector type, config must only contain these keys: {allowed_keys}") - + raise ValueError( + f"For NOTION_CONNECTOR connector type, config must only contain these keys: {allowed_keys}" + ) + # Ensure the integration token is not empty if not config.get("NOTION_INTEGRATION_TOKEN"): raise ValueError("NOTION_INTEGRATION_TOKEN cannot be empty") - + elif 
connector_type == SearchSourceConnectorType.GITHUB_CONNECTOR: # For GITHUB_CONNECTOR, only allow GITHUB_PAT and repo_full_names allowed_keys = ["GITHUB_PAT", "repo_full_names"] if set(config.keys()) != set(allowed_keys): - raise ValueError(f"For GITHUB_CONNECTOR connector type, config must only contain these keys: {allowed_keys}") - + raise ValueError( + f"For GITHUB_CONNECTOR connector type, config must only contain these keys: {allowed_keys}" + ) + # Ensure the token is not empty if not config.get("GITHUB_PAT"): raise ValueError("GITHUB_PAT cannot be empty") - + # Ensure the repo_full_names is present and is a non-empty list repo_full_names = config.get("repo_full_names") if not isinstance(repo_full_names, list) or not repo_full_names: raise ValueError("repo_full_names must be a non-empty list of strings") - + elif connector_type == SearchSourceConnectorType.LINEAR_CONNECTOR: # For LINEAR_CONNECTOR, only allow LINEAR_API_KEY allowed_keys = ["LINEAR_API_KEY"] if set(config.keys()) != set(allowed_keys): - raise ValueError(f"For LINEAR_CONNECTOR connector type, config must only contain these keys: {allowed_keys}") - + raise ValueError( + f"For LINEAR_CONNECTOR connector type, config must only contain these keys: {allowed_keys}" + ) + # Ensure the token is not empty if not config.get("LINEAR_API_KEY"): raise ValueError("LINEAR_API_KEY cannot be empty") - + elif connector_type == SearchSourceConnectorType.DISCORD_CONNECTOR: # For DISCORD_CONNECTOR, only allow DISCORD_BOT_TOKEN allowed_keys = ["DISCORD_BOT_TOKEN"] if set(config.keys()) != set(allowed_keys): - raise ValueError(f"For DISCORD_CONNECTOR connector type, config must only contain these keys: {allowed_keys}") + raise ValueError( + f"For DISCORD_CONNECTOR connector type, config must only contain these keys: {allowed_keys}" + ) # Ensure the bot token is not empty if not config.get("DISCORD_BOT_TOKEN"): @@ -104,17 +126,20 @@ class SearchSourceConnectorBase(BaseModel): return config + class SearchSourceConnectorCreate(SearchSourceConnectorBase): pass + class SearchSourceConnectorUpdate(BaseModel): - name: Optional[str] = None - connector_type: Optional[SearchSourceConnectorType] = None - is_indexable: Optional[bool] = None - last_indexed_at: Optional[datetime] = None - config: Optional[Dict[str, Any]] = None + name: str | None = None + connector_type: SearchSourceConnectorType | None = None + is_indexable: bool | None = None + last_indexed_at: datetime | None = None + config: dict[str, Any] | None = None + class SearchSourceConnectorRead(SearchSourceConnectorBase, IDModel, TimestampModel): user_id: uuid.UUID - model_config = ConfigDict(from_attributes=True) + model_config = ConfigDict(from_attributes=True) diff --git a/surfsense_backend/app/schemas/search_space.py b/surfsense_backend/app/schemas/search_space.py index 2c99c45..00bfdc0 100644 --- a/surfsense_backend/app/schemas/search_space.py +++ b/surfsense_backend/app/schemas/search_space.py @@ -1,22 +1,27 @@ -from datetime import datetime import uuid -from typing import Optional +from datetime import datetime + from pydantic import BaseModel, ConfigDict + from .base import IDModel, TimestampModel + class SearchSpaceBase(BaseModel): name: str - description: Optional[str] = None + description: str | None = None + class SearchSpaceCreate(SearchSpaceBase): pass + class SearchSpaceUpdate(SearchSpaceBase): pass + class SearchSpaceRead(SearchSpaceBase, IDModel, TimestampModel): id: int created_at: datetime user_id: uuid.UUID - model_config = ConfigDict(from_attributes=True) \ No newline at end 
of file + model_config = ConfigDict(from_attributes=True) diff --git a/surfsense_backend/app/schemas/users.py b/surfsense_backend/app/schemas/users.py index 180faf2..de1169e 100644 --- a/surfsense_backend/app/schemas/users.py +++ b/surfsense_backend/app/schemas/users.py @@ -1,11 +1,15 @@ import uuid + from fastapi_users import schemas + class UserRead(schemas.BaseUser[uuid.UUID]): pass + class UserCreate(schemas.BaseUserCreate): pass + class UserUpdate(schemas.BaseUserUpdate): - pass \ No newline at end of file + pass diff --git a/surfsense_backend/app/services/__init__.py b/surfsense_backend/app/services/__init__.py index 9983534..a70b302 100644 --- a/surfsense_backend/app/services/__init__.py +++ b/surfsense_backend/app/services/__init__.py @@ -1 +1 @@ -# Services package \ No newline at end of file +# Services package diff --git a/surfsense_backend/app/services/connector_service.py b/surfsense_backend/app/services/connector_service.py index f53fd4d..33001e2 100644 --- a/surfsense_backend/app/services/connector_service.py +++ b/surfsense_backend/app/services/connector_service.py @@ -1,26 +1,36 @@ -from typing import List, Dict, Optional import asyncio -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.future import select -from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever -from app.retriver.documents_hybrid_search import DocumentHybridSearchRetriever -from app.db import SearchSourceConnector, SearchSourceConnectorType, Chunk, Document, SearchSpace -from tavily import TavilyClient + from linkup import LinkupClient from sqlalchemy import func +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +from tavily import TavilyClient from app.agents.researcher.configuration import SearchMode +from app.db import ( + Chunk, + Document, + SearchSourceConnector, + SearchSourceConnectorType, + SearchSpace, +) +from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever +from app.retriver.documents_hybrid_search import DocumentHybridSearchRetriever class ConnectorService: - def __init__(self, session: AsyncSession, user_id: str = None): + def __init__(self, session: AsyncSession, user_id: str | None = None): self.session = session self.chunk_retriever = ChucksHybridSearchRetriever(session) self.document_retriever = DocumentHybridSearchRetriever(session) self.user_id = user_id - self.source_id_counter = 100000 # High starting value to avoid collisions with existing IDs - self.counter_lock = asyncio.Lock() # Lock to protect counter in multithreaded environments - + self.source_id_counter = ( + 100000 # High starting value to avoid collisions with existing IDs + ) + self.counter_lock = ( + asyncio.Lock() + ) # Lock to protect counter in multithreaded environments + async def initialize_counter(self): """ Initialize the source_id_counter based on the total number of chunks for the user. 
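# A minimal editorial sketch of the counter-plus-lock idea behind
# ConnectorService.source_id_counter: IDs for synthesized sources start high
# (or from a per-user chunk count) so they do not collide with real document
# IDs, and an asyncio.Lock serializes increments when several searches run
# concurrently. The class and method names here are hypothetical.
import asyncio


class SourceIdAllocator:
    def __init__(self, start: int = 100_000) -> None:
        self._next = start
        self._lock = asyncio.Lock()

    async def allocate(self, count: int) -> list[int]:
        # Hold the lock so concurrent callers never receive the same ID.
        async with self._lock:
            ids = list(range(self._next, self._next + count))
            self._next += count
            return ids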
@@ -38,16 +48,25 @@ class ConnectorService: ) chunk_count = result.scalar() or 0 self.source_id_counter = chunk_count + 1 - print(f"Initialized source_id_counter to {self.source_id_counter} for user {self.user_id}") + print( + f"Initialized source_id_counter to {self.source_id_counter} for user {self.user_id}" + ) except Exception as e: - print(f"Error initializing source_id_counter: {str(e)}") + print(f"Error initializing source_id_counter: {e!s}") # Fallback to default value self.source_id_counter = 1 - - async def search_crawled_urls(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple: + + async def search_crawled_urls( + self, + user_query: str, + user_id: str, + search_space_id: int, + top_k: int = 20, + search_mode: SearchMode = SearchMode.CHUNKS, + ) -> tuple: """ Search for crawled URLs and return both the source information and langchain documents - + Returns: tuple: (sources_info, langchain_documents) """ @@ -57,7 +76,7 @@ class ConnectorService: top_k=top_k, user_id=user_id, search_space_id=search_space_id, - document_type="CRAWLED_URL" + document_type="CRAWLED_URL", ) elif search_mode == SearchMode.DOCUMENTS: crawled_urls_chunks = await self.document_retriever.hybrid_search( @@ -65,7 +84,7 @@ class ConnectorService: top_k=top_k, user_id=user_id, search_space_id=search_space_id, - document_type="CRAWLED_URL" + document_type="CRAWLED_URL", ) # Transform document retriever results to match expected format crawled_urls_chunks = self._transform_document_results(crawled_urls_chunks) @@ -84,20 +103,23 @@ class ConnectorService: async with self.counter_lock: for _i, chunk in enumerate(crawled_urls_chunks): # Extract document metadata - document = chunk.get('document', {}) - metadata = document.get('metadata', {}) + document = chunk.get("document", {}) + metadata = document.get("metadata", {}) # Create a source entry source = { - "id": document.get('id', self.source_id_counter), - "title": document.get('title', 'Untitled Document'), - "description": metadata.get('og:description', metadata.get('ogDescription', chunk.get('content', '')[:100])), - "url": metadata.get('url', '') + "id": document.get("id", self.source_id_counter), + "title": document.get("title", "Untitled Document"), + "description": metadata.get( + "og:description", + metadata.get("ogDescription", chunk.get("content", "")[:100]), + ), + "url": metadata.get("url", ""), } self.source_id_counter += 1 sources_list.append(source) - + # Create result object result_object = { "id": 1, @@ -105,13 +127,20 @@ class ConnectorService: "type": "CRAWLED_URL", "sources": sources_list, } - + return result_object, crawled_urls_chunks - - async def search_files(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple: + + async def search_files( + self, + user_query: str, + user_id: str, + search_space_id: int, + top_k: int = 20, + search_mode: SearchMode = SearchMode.CHUNKS, + ) -> tuple: """ Search for files and return both the source information and langchain documents - + Returns: tuple: (sources_info, langchain_documents) """ @@ -121,7 +150,7 @@ class ConnectorService: top_k=top_k, user_id=user_id, search_space_id=search_space_id, - document_type="FILE" + document_type="FILE", ) elif search_mode == SearchMode.DOCUMENTS: files_chunks = await self.document_retriever.hybrid_search( @@ -129,11 +158,11 @@ class ConnectorService: top_k=top_k, user_id=user_id, search_space_id=search_space_id, - 
document_type="FILE" + document_type="FILE", ) # Transform document retriever results to match expected format files_chunks = self._transform_document_results(files_chunks) - + # Early return if no results if not files_chunks: return { @@ -148,20 +177,23 @@ class ConnectorService: async with self.counter_lock: for _i, chunk in enumerate(files_chunks): # Extract document metadata - document = chunk.get('document', {}) - metadata = document.get('metadata', {}) + document = chunk.get("document", {}) + metadata = document.get("metadata", {}) # Create a source entry source = { - "id": document.get('id', self.source_id_counter), - "title": document.get('title', 'Untitled Document'), - "description": metadata.get('og:description', metadata.get('ogDescription', chunk.get('content', '')[:100])), - "url": metadata.get('url', '') + "id": document.get("id", self.source_id_counter), + "title": document.get("title", "Untitled Document"), + "description": metadata.get( + "og:description", + metadata.get("ogDescription", chunk.get("content", "")[:100]), + ), + "url": metadata.get("url", ""), } self.source_id_counter += 1 sources_list.append(source) - + # Create result object result_object = { "id": 2, @@ -169,69 +201,76 @@ class ConnectorService: "type": "FILE", "sources": sources_list, } - + return result_object, files_chunks - - def _transform_document_results(self, document_results: List[Dict]) -> List[Dict]: + + def _transform_document_results(self, document_results: list[dict]) -> list[dict]: """ Transform results from document_retriever.hybrid_search() to match the format expected by the processing code. - + Args: document_results: Results from document_retriever.hybrid_search() - + Returns: List of transformed results in the format expected by the processing code """ transformed_results = [] for doc in document_results: - transformed_results.append({ - 'document': { - 'id': doc.get('document_id'), - 'title': doc.get('title', 'Untitled Document'), - 'document_type': doc.get('document_type'), - 'metadata': doc.get('metadata', {}), - }, - 'content': doc.get('chunks_content', doc.get('content', '')), - 'score': doc.get('score', 0.0) - }) + transformed_results.append( + { + "document": { + "id": doc.get("document_id"), + "title": doc.get("title", "Untitled Document"), + "document_type": doc.get("document_type"), + "metadata": doc.get("metadata", {}), + }, + "content": doc.get("chunks_content", doc.get("content", "")), + "score": doc.get("score", 0.0), + } + ) return transformed_results - - async def get_connector_by_type(self, user_id: str, connector_type: SearchSourceConnectorType) -> Optional[SearchSourceConnector]: + + async def get_connector_by_type( + self, user_id: str, connector_type: SearchSourceConnectorType + ) -> SearchSourceConnector | None: """ Get a connector by type for a specific user - + Args: user_id: The user's ID connector_type: The connector type to retrieve - + Returns: Optional[SearchSourceConnector]: The connector if found, None otherwise """ result = await self.session.execute( - select(SearchSourceConnector) - .filter( + select(SearchSourceConnector).filter( SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type == connector_type + SearchSourceConnector.connector_type == connector_type, ) ) return result.scalars().first() - - async def search_tavily(self, user_query: str, user_id: str, top_k: int = 20) -> tuple: + + async def search_tavily( + self, user_query: str, user_id: str, top_k: int = 20 + ) -> tuple: """ Search using Tavily API and return both 
the source information and documents - + Args: user_query: The user's query user_id: The user's ID top_k: Maximum number of results to return - + Returns: tuple: (sources_info, documents) """ # Get Tavily connector configuration - tavily_connector = await self.get_connector_by_type(user_id, SearchSourceConnectorType.TAVILY_API) - + tavily_connector = await self.get_connector_by_type( + user_id, SearchSourceConnectorType.TAVILY_API + ) + if not tavily_connector: # Return empty results if no Tavily connector is configured return { @@ -240,22 +279,22 @@ class ConnectorService: "type": "TAVILY_API", "sources": [], }, [] - + # Initialize Tavily client with API key from connector config tavily_api_key = tavily_connector.config.get("TAVILY_API_KEY") tavily_client = TavilyClient(api_key=tavily_api_key) - + # Perform search with Tavily try: response = tavily_client.search( query=user_query, max_results=top_k, - search_depth="advanced" # Use advanced search for better results + search_depth="advanced", # Use advanced search for better results ) - + # Extract results from Tavily response tavily_results = response.get("results", []) - + # Early return if no results if not tavily_results: return { @@ -264,23 +303,22 @@ class ConnectorService: "type": "TAVILY_API", "sources": [], }, [] - + # Process each result and create sources directly without deduplication sources_list = [] documents = [] - + async with self.counter_lock: for i, result in enumerate(tavily_results): - # Create a source entry source = { "id": self.source_id_counter, "title": result.get("title", "Tavily Result"), "description": result.get("content", "")[:100], - "url": result.get("url", "") + "url": result.get("url", ""), } sources_list.append(source) - + # Create a document entry document = { "chunk_id": f"tavily_chunk_{i}", @@ -293,9 +331,9 @@ class ConnectorService: "metadata": { "url": result.get("url", ""), "published_date": result.get("published_date", ""), - "source": "TAVILY_API" - } - } + "source": "TAVILY_API", + }, + }, } documents.append(document) self.source_id_counter += 1 @@ -307,23 +345,30 @@ class ConnectorService: "type": "TAVILY_API", "sources": sources_list, } - + return result_object, documents - + except Exception as e: # Log the error and return empty results - print(f"Error searching with Tavily: {str(e)}") + print(f"Error searching with Tavily: {e!s}") return { "id": 3, "name": "Tavily Search", "type": "TAVILY_API", "sources": [], }, [] - - async def search_slack(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple: + + async def search_slack( + self, + user_query: str, + user_id: str, + search_space_id: int, + top_k: int = 20, + search_mode: SearchMode = SearchMode.CHUNKS, + ) -> tuple: """ Search for slack and return both the source information and langchain documents - + Returns: tuple: (sources_info, langchain_documents) """ @@ -333,7 +378,7 @@ class ConnectorService: top_k=top_k, user_id=user_id, search_space_id=search_space_id, - document_type="SLACK_CONNECTOR" + document_type="SLACK_CONNECTOR", ) elif search_mode == SearchMode.DOCUMENTS: slack_chunks = await self.document_retriever.hybrid_search( @@ -341,11 +386,11 @@ class ConnectorService: top_k=top_k, user_id=user_id, search_space_id=search_space_id, - document_type="SLACK_CONNECTOR" + document_type="SLACK_CONNECTOR", ) # Transform document retriever results to match expected format slack_chunks = self._transform_document_results(slack_chunks) - + # Early return if no results 
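# A minimal editorial sketch of the SearchMode dispatch shared by the search_*
# methods above: chunk-level retrieval is returned as-is, while document-level
# retrieval is normalized through the transform step so downstream code always
# sees the same result shape. The function name and the retriever/transform
# parameters are hypothetical stand-ins for the service's own attributes.
from collections.abc import Awaitable, Callable
from enum import Enum


class SearchMode(str, Enum):  # local stand-in for the app's SearchMode enum
    CHUNKS = "CHUNKS"
    DOCUMENTS = "DOCUMENTS"


async def search_by_type(
    chunk_search: Callable[..., Awaitable[list[dict]]],
    document_search: Callable[..., Awaitable[list[dict]]],
    transform: Callable[[list[dict]], list[dict]],
    search_mode: SearchMode,
    **query_kwargs,
) -> list[dict]:
    if search_mode == SearchMode.CHUNKS:
        return await chunk_search(**query_kwargs)
    # DOCUMENTS mode: normalize to the chunk-style result shape.
    results = await document_search(**query_kwargs)
    return transform(results)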
if not slack_chunks: return { @@ -360,31 +405,31 @@ class ConnectorService: async with self.counter_lock: for _i, chunk in enumerate(slack_chunks): # Extract document metadata - document = chunk.get('document', {}) - metadata = document.get('metadata', {}) + document = chunk.get("document", {}) + metadata = document.get("metadata", {}) # Create a mapped source entry with Slack-specific metadata - channel_name = metadata.get('channel_name', 'Unknown Channel') - channel_id = metadata.get('channel_id', '') - message_date = metadata.get('start_date', '') - + channel_name = metadata.get("channel_name", "Unknown Channel") + channel_id = metadata.get("channel_id", "") + message_date = metadata.get("start_date", "") + # Create a more descriptive title for Slack messages title = f"Slack: {channel_name}" if message_date: title += f" ({message_date})" - + # Create a more descriptive description for Slack messages - description = chunk.get('content', '')[:100] + description = chunk.get("content", "")[:100] if len(description) == 100: description += "..." - + # For URL, we can use a placeholder or construct a URL to the Slack channel if available url = "" if channel_id: url = f"https://slack.com/app_redirect?channel={channel_id}" source = { - "id": document.get('id', self.source_id_counter), + "id": document.get("id", self.source_id_counter), "title": title, "description": description, "url": url, @@ -392,7 +437,7 @@ class ConnectorService: self.source_id_counter += 1 sources_list.append(source) - + # Create result object result_object = { "id": 4, @@ -400,19 +445,26 @@ class ConnectorService: "type": "SLACK_CONNECTOR", "sources": sources_list, } - + return result_object, slack_chunks - - async def search_notion(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple: + + async def search_notion( + self, + user_query: str, + user_id: str, + search_space_id: int, + top_k: int = 20, + search_mode: SearchMode = SearchMode.CHUNKS, + ) -> tuple: """ Search for Notion pages and return both the source information and langchain documents - + Args: user_query: The user's query user_id: The user's ID search_space_id: The search space ID to search in top_k: Maximum number of results to return - + Returns: tuple: (sources_info, langchain_documents) """ @@ -422,7 +474,7 @@ class ConnectorService: top_k=top_k, user_id=user_id, search_space_id=search_space_id, - document_type="NOTION_CONNECTOR" + document_type="NOTION_CONNECTOR", ) elif search_mode == SearchMode.DOCUMENTS: notion_chunks = await self.document_retriever.hybrid_search( @@ -430,11 +482,11 @@ class ConnectorService: top_k=top_k, user_id=user_id, search_space_id=search_space_id, - document_type="NOTION_CONNECTOR" + document_type="NOTION_CONNECTOR", ) # Transform document retriever results to match expected format notion_chunks = self._transform_document_results(notion_chunks) - + # Early return if no results if not notion_chunks: return { @@ -449,24 +501,24 @@ class ConnectorService: async with self.counter_lock: for _i, chunk in enumerate(notion_chunks): # Extract document metadata - document = chunk.get('document', {}) - metadata = document.get('metadata', {}) + document = chunk.get("document", {}) + metadata = document.get("metadata", {}) # Create a mapped source entry with Notion-specific metadata - page_title = metadata.get('page_title', 'Untitled Page') - page_id = metadata.get('page_id', '') - indexed_at = metadata.get('indexed_at', '') - + page_title = 
metadata.get("page_title", "Untitled Page") + page_id = metadata.get("page_id", "") + indexed_at = metadata.get("indexed_at", "") + # Create a more descriptive title for Notion pages title = f"Notion: {page_title}" if indexed_at: title += f" (indexed: {indexed_at})" - + # Create a more descriptive description for Notion pages - description = chunk.get('content', '')[:100] + description = chunk.get("content", "")[:100] if len(description) == 100: description += "..." - + # For URL, we can use a placeholder or construct a URL to the Notion page if available url = "" if page_id: @@ -474,7 +526,7 @@ class ConnectorService: url = f"https://notion.so/{page_id.replace('-', '')}" source = { - "id": document.get('id', self.source_id_counter), + "id": document.get("id", self.source_id_counter), "title": title, "description": description, "url": url, @@ -482,7 +534,7 @@ class ConnectorService: self.source_id_counter += 1 sources_list.append(source) - + # Create result object result_object = { "id": 5, @@ -490,19 +542,26 @@ class ConnectorService: "type": "NOTION_CONNECTOR", "sources": sources_list, } - + return result_object, notion_chunks - - async def search_extension(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple: + + async def search_extension( + self, + user_query: str, + user_id: str, + search_space_id: int, + top_k: int = 20, + search_mode: SearchMode = SearchMode.CHUNKS, + ) -> tuple: """ Search for extension data and return both the source information and langchain documents - + Args: user_query: The user's query user_id: The user's ID search_space_id: The search space ID to search in top_k: Maximum number of results to return - + Returns: tuple: (sources_info, langchain_documents) """ @@ -512,7 +571,7 @@ class ConnectorService: top_k=top_k, user_id=user_id, search_space_id=search_space_id, - document_type="EXTENSION" + document_type="EXTENSION", ) elif search_mode == SearchMode.DOCUMENTS: extension_chunks = await self.document_retriever.hybrid_search( @@ -520,7 +579,7 @@ class ConnectorService: top_k=top_k, user_id=user_id, search_space_id=search_space_id, - document_type="EXTENSION" + document_type="EXTENSION", ) # Transform document retriever results to match expected format extension_chunks = self._transform_document_results(extension_chunks) @@ -537,35 +596,40 @@ class ConnectorService: # Process each chunk and create sources directly without deduplication sources_list = [] async with self.counter_lock: - for i, chunk in enumerate(extension_chunks): + for _, chunk in enumerate(extension_chunks): # Extract document metadata - document = chunk.get('document', {}) - metadata = document.get('metadata', {}) + document = chunk.get("document", {}) + metadata = document.get("metadata", {}) # Extract extension-specific metadata - webpage_title = metadata.get('VisitedWebPageTitle', 'Untitled Page') - webpage_url = metadata.get('VisitedWebPageURL', '') - visit_date = metadata.get('VisitedWebPageDateWithTimeInISOString', '') - visit_duration = metadata.get('VisitedWebPageVisitDurationInMilliseconds', '') - browsing_session_id = metadata.get('BrowsingSessionId', '') - + webpage_title = metadata.get("VisitedWebPageTitle", "Untitled Page") + webpage_url = metadata.get("VisitedWebPageURL", "") + visit_date = metadata.get("VisitedWebPageDateWithTimeInISOString", "") + visit_duration = metadata.get( + "VisitedWebPageVisitDurationInMilliseconds", "" + ) + # Create a more descriptive title for extension data title = 
webpage_title if visit_date: # Format the date for display (simplified) try: # Just extract the date part for display - formatted_date = visit_date.split('T')[0] if 'T' in visit_date else visit_date + formatted_date = ( + visit_date.split("T")[0] + if "T" in visit_date + else visit_date + ) title += f" (visited: {formatted_date})" - except: + except Exception: # Fallback if date parsing fails title += f" (visited: {visit_date})" - + # Create a more descriptive description for extension data - description = chunk.get('content', '')[:100] + description = chunk.get("content", "")[:100] if len(description) == 100: description += "..." - + # Add visit duration if available if visit_duration: try: @@ -573,24 +637,24 @@ class ConnectorService: if duration_seconds < 60: duration_text = f"{duration_seconds:.1f} seconds" else: - duration_text = f"{duration_seconds/60:.1f} minutes" - + duration_text = f"{duration_seconds / 60:.1f} minutes" + if description: description += f" | Duration: {duration_text}" - except: + except Exception: # Fallback if duration parsing fails pass source = { - "id": document.get('id', self.source_id_counter), + "id": document.get("id", self.source_id_counter), "title": title, "description": description, - "url": webpage_url + "url": webpage_url, } self.source_id_counter += 1 sources_list.append(source) - + # Create result object result_object = { "id": 6, @@ -598,19 +662,26 @@ class ConnectorService: "type": "EXTENSION", "sources": sources_list, } - + return result_object, extension_chunks - - async def search_youtube(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple: + + async def search_youtube( + self, + user_query: str, + user_id: str, + search_space_id: int, + top_k: int = 20, + search_mode: SearchMode = SearchMode.CHUNKS, + ) -> tuple: """ Search for YouTube videos and return both the source information and langchain documents - + Args: user_query: The user's query user_id: The user's ID search_space_id: The search space ID to search in top_k: Maximum number of results to return - + Returns: tuple: (sources_info, langchain_documents) """ @@ -620,7 +691,7 @@ class ConnectorService: top_k=top_k, user_id=user_id, search_space_id=search_space_id, - document_type="YOUTUBE_VIDEO" + document_type="YOUTUBE_VIDEO", ) elif search_mode == SearchMode.DOCUMENTS: youtube_chunks = await self.document_retriever.hybrid_search( @@ -628,11 +699,11 @@ class ConnectorService: top_k=top_k, user_id=user_id, search_space_id=search_space_id, - document_type="YOUTUBE_VIDEO" + document_type="YOUTUBE_VIDEO", ) # Transform document retriever results to match expected format youtube_chunks = self._transform_document_results(youtube_chunks) - + # Early return if no results if not youtube_chunks: return { @@ -647,40 +718,42 @@ class ConnectorService: async with self.counter_lock: for _i, chunk in enumerate(youtube_chunks): # Extract document metadata - document = chunk.get('document', {}) - metadata = document.get('metadata', {}) + document = chunk.get("document", {}) + metadata = document.get("metadata", {}) # Extract YouTube-specific metadata - video_title = metadata.get('video_title', 'Untitled Video') - video_id = metadata.get('video_id', '') - channel_name = metadata.get('channel_name', '') + video_title = metadata.get("video_title", "Untitled Video") + video_id = metadata.get("video_id", "") + channel_name = metadata.get("channel_name", "") # published_date = metadata.get('published_date', '') - + # Create a more 
descriptive title for YouTube videos title = video_title if channel_name: title += f" - {channel_name}" - + # Create a more descriptive description for YouTube videos - description = metadata.get('description', chunk.get('content', '')[:100]) + description = metadata.get( + "description", chunk.get("content", "")[:100] + ) if len(description) == 100: description += "..." - + # For URL, construct a URL to the YouTube video url = f"https://www.youtube.com/watch?v={video_id}" if video_id else "" source = { - "id": document.get('id', self.source_id_counter), + "id": document.get("id", self.source_id_counter), "title": title, "description": description, "url": url, "video_id": video_id, # Additional field for YouTube videos - "channel_name": channel_name # Additional field for YouTube videos + "channel_name": channel_name, # Additional field for YouTube videos } self.source_id_counter += 1 sources_list.append(source) - + # Create result object result_object = { "id": 7, # Assign a unique ID for the YouTube connector @@ -688,13 +761,20 @@ class ConnectorService: "type": "YOUTUBE_VIDEO", "sources": sources_list, } - + return result_object, youtube_chunks - async def search_github(self, user_query: str, user_id: int, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple: + async def search_github( + self, + user_query: str, + user_id: int, + search_space_id: int, + top_k: int = 20, + search_mode: SearchMode = SearchMode.CHUNKS, + ) -> tuple: """ Search for GitHub documents and return both the source information and langchain documents - + Returns: tuple: (sources_info, langchain_documents) """ @@ -704,7 +784,7 @@ class ConnectorService: top_k=top_k, user_id=user_id, search_space_id=search_space_id, - document_type="GITHUB_CONNECTOR" + document_type="GITHUB_CONNECTOR", ) elif search_mode == SearchMode.DOCUMENTS: github_chunks = await self.document_retriever.hybrid_search( @@ -712,11 +792,11 @@ class ConnectorService: top_k=top_k, user_id=user_id, search_space_id=search_space_id, - document_type="GITHUB_CONNECTOR" + document_type="GITHUB_CONNECTOR", ) # Transform document retriever results to match expected format github_chunks = self._transform_document_results(github_chunks) - + # Early return if no results if not github_chunks: return { @@ -731,20 +811,24 @@ class ConnectorService: async with self.counter_lock: for _i, chunk in enumerate(github_chunks): # Extract document metadata - document = chunk.get('document', {}) - metadata = document.get('metadata', {}) + document = chunk.get("document", {}) + metadata = document.get("metadata", {}) # Create a source entry source = { - "id": document.get('id', self.source_id_counter), - "title": document.get('title', 'GitHub Document'), # Use specific title if available - "description": metadata.get('description', chunk.get('content', '')[:100]), # Use description or content preview - "url": metadata.get('url', '') # Use URL if available in metadata + "id": document.get("id", self.source_id_counter), + "title": document.get( + "title", "GitHub Document" + ), # Use specific title if available + "description": metadata.get( + "description", chunk.get("content", "")[:100] + ), # Use description or content preview + "url": metadata.get("url", ""), # Use URL if available in metadata } self.source_id_counter += 1 sources_list.append(source) - + # Create result object result_object = { "id": 8, @@ -752,19 +836,26 @@ class ConnectorService: "type": "GITHUB_CONNECTOR", "sources": sources_list, } - + return result_object, 
github_chunks - async def search_linear(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple: + async def search_linear( + self, + user_query: str, + user_id: str, + search_space_id: int, + top_k: int = 20, + search_mode: SearchMode = SearchMode.CHUNKS, + ) -> tuple: """ Search for Linear issues and comments and return both the source information and langchain documents - + Args: user_query: The user's query user_id: The user's ID search_space_id: The search space ID to search in top_k: Maximum number of results to return - + Returns: tuple: (sources_info, langchain_documents) """ @@ -774,7 +865,7 @@ class ConnectorService: top_k=top_k, user_id=user_id, search_space_id=search_space_id, - document_type="LINEAR_CONNECTOR" + document_type="LINEAR_CONNECTOR", ) elif search_mode == SearchMode.DOCUMENTS: linear_chunks = await self.document_retriever.hybrid_search( @@ -782,7 +873,7 @@ class ConnectorService: top_k=top_k, user_id=user_id, search_space_id=search_space_id, - document_type="LINEAR_CONNECTOR" + document_type="LINEAR_CONNECTOR", ) # Transform document retriever results to match expected format linear_chunks = self._transform_document_results(linear_chunks) @@ -801,32 +892,32 @@ class ConnectorService: async with self.counter_lock: for _i, chunk in enumerate(linear_chunks): # Extract document metadata - document = chunk.get('document', {}) - metadata = document.get('metadata', {}) + document = chunk.get("document", {}) + metadata = document.get("metadata", {}) # Extract Linear-specific metadata - issue_identifier = metadata.get('issue_identifier', '') - issue_title = metadata.get('issue_title', 'Untitled Issue') - issue_state = metadata.get('state', '') - comment_count = metadata.get('comment_count', 0) - + issue_identifier = metadata.get("issue_identifier", "") + issue_title = metadata.get("issue_title", "Untitled Issue") + issue_state = metadata.get("state", "") + comment_count = metadata.get("comment_count", 0) + # Create a more descriptive title for Linear issues title = f"Linear: {issue_identifier} - {issue_title}" if issue_state: title += f" ({issue_state})" - + # Create a more descriptive description for Linear issues - description = chunk.get('content', '')[:100] + description = chunk.get("content", "")[:100] if len(description) == 100: description += "..." 
- + # Add comment count info to description if comment_count: if description: description += f" | Comments: {comment_count}" else: description = f"Comments: {comment_count}" - + # For URL, we could construct a URL to the Linear issue if we have the workspace info # For now, use a generic placeholder url = "" @@ -835,18 +926,18 @@ class ConnectorService: url = f"https://linear.app/issue/{issue_identifier}" source = { - "id": document.get('id', self.source_id_counter), + "id": document.get("id", self.source_id_counter), "title": title, "description": description, "url": url, "issue_identifier": issue_identifier, "state": issue_state, - "comment_count": comment_count + "comment_count": comment_count, } self.source_id_counter += 1 sources_list.append(source) - + # Create result object result_object = { "id": 9, # Assign a unique ID for the Linear connector @@ -854,24 +945,28 @@ class ConnectorService: "type": "LINEAR_CONNECTOR", "sources": sources_list, } - + return result_object, linear_chunks - async def search_linkup(self, user_query: str, user_id: str, mode: str = "standard") -> tuple: + async def search_linkup( + self, user_query: str, user_id: str, mode: str = "standard" + ) -> tuple: """ Search using Linkup API and return both the source information and documents - + Args: user_query: The user's query user_id: The user's ID mode: Search depth mode, can be "standard" or "deep" - + Returns: tuple: (sources_info, documents) """ # Get Linkup connector configuration - linkup_connector = await self.get_connector_by_type(user_id, SearchSourceConnectorType.LINKUP_API) - + linkup_connector = await self.get_connector_by_type( + user_id, SearchSourceConnectorType.LINKUP_API + ) + if not linkup_connector: # Return empty results if no Linkup connector is configured return { @@ -880,11 +975,11 @@ class ConnectorService: "type": "LINKUP_API", "sources": [], }, [] - + # Initialize Linkup client with API key from connector config linkup_api_key = linkup_connector.config.get("LINKUP_API_KEY") linkup_client = LinkupClient(api_key=linkup_api_key) - + # Perform search with Linkup try: response = linkup_client.search( @@ -892,10 +987,10 @@ class ConnectorService: depth=mode, # Use the provided mode ("standard" or "deep") output_type="searchResults", # Default to search results ) - + # Extract results from Linkup response - access as attribute instead of using .get() - linkup_results = response.results if hasattr(response, 'results') else [] - + linkup_results = response.results if hasattr(response, "results") else [] + # Only proceed if we have results if not linkup_results: return { @@ -904,41 +999,47 @@ class ConnectorService: "type": "LINKUP_API", "sources": [], }, [] - + # Process each result and create sources directly without deduplication sources_list = [] documents = [] - + async with self.counter_lock: for i, result in enumerate(linkup_results): # Only process results that have content - if not hasattr(result, 'content') or not result.content: + if not hasattr(result, "content") or not result.content: continue - + # Create a source entry source = { "id": self.source_id_counter, - "title": result.name if hasattr(result, 'name') else "Linkup Result", - "description": result.content[:100] if hasattr(result, 'content') else "", - "url": result.url if hasattr(result, 'url') else "" + "title": result.name + if hasattr(result, "name") + else "Linkup Result", + "description": result.content[:100] + if hasattr(result, "content") + else "", + "url": result.url if hasattr(result, "url") else "", } 
sources_list.append(source) - + # Create a document entry document = { "chunk_id": f"linkup_chunk_{i}", - "content": result.content if hasattr(result, 'content') else "", + "content": result.content if hasattr(result, "content") else "", "score": 1.0, # Default score since not provided by Linkup "document": { "id": self.source_id_counter, - "title": result.name if hasattr(result, 'name') else "Linkup Result", + "title": result.name + if hasattr(result, "name") + else "Linkup Result", "document_type": "LINKUP_API", "metadata": { - "url": result.url if hasattr(result, 'url') else "", - "type": result.type if hasattr(result, 'type') else "", - "source": "LINKUP_API" - } - } + "url": result.url if hasattr(result, "url") else "", + "type": result.type if hasattr(result, "type") else "", + "source": "LINKUP_API", + }, + }, } documents.append(document) self.source_id_counter += 1 @@ -950,29 +1051,36 @@ class ConnectorService: "type": "LINKUP_API", "sources": sources_list, } - + return result_object, documents - + except Exception as e: # Log the error and return empty results - print(f"Error searching with Linkup: {str(e)}") + print(f"Error searching with Linkup: {e!s}") return { "id": 10, "name": "Linkup Search", "type": "LINKUP_API", "sources": [], }, [] - - async def search_discord(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple: + + async def search_discord( + self, + user_query: str, + user_id: str, + search_space_id: int, + top_k: int = 20, + search_mode: SearchMode = SearchMode.CHUNKS, + ) -> tuple: """ Search for Discord messages and return both the source information and langchain documents - + Args: user_query: The user's query user_id: The user's ID search_space_id: The search space ID to search in top_k: Maximum number of results to return - + Returns: tuple: (sources_info, langchain_documents) """ @@ -982,7 +1090,7 @@ class ConnectorService: top_k=top_k, user_id=user_id, search_space_id=search_space_id, - document_type="DISCORD_CONNECTOR" + document_type="DISCORD_CONNECTOR", ) elif search_mode == SearchMode.DOCUMENTS: discord_chunks = await self.document_retriever.hybrid_search( @@ -990,11 +1098,11 @@ class ConnectorService: top_k=top_k, user_id=user_id, search_space_id=search_space_id, - document_type="DISCORD_CONNECTOR" + document_type="DISCORD_CONNECTOR", ) # Transform document retriever results to match expected format discord_chunks = self._transform_document_results(discord_chunks) - + # Early return if no results if not discord_chunks: return { @@ -1007,28 +1115,28 @@ class ConnectorService: # Process each chunk and create sources directly without deduplication sources_list = [] async with self.counter_lock: - for i, chunk in enumerate(discord_chunks): + for _, chunk in enumerate(discord_chunks): # Extract document metadata - document = chunk.get('document', {}) - metadata = document.get('metadata', {}) + document = chunk.get("document", {}) + metadata = document.get("metadata", {}) # Create a mapped source entry with Discord-specific metadata - channel_name = metadata.get('channel_name', 'Unknown Channel') - channel_id = metadata.get('channel_id', '') - message_date = metadata.get('start_date', '') - + channel_name = metadata.get("channel_name", "Unknown Channel") + channel_id = metadata.get("channel_id", "") + message_date = metadata.get("start_date", "") + # Create a more descriptive title for Discord messages title = f"Discord: {channel_name}" if message_date: title += f" ({message_date})" - 
+ # Create a more descriptive description for Discord messages - description = chunk.get('content', '')[:100] + description = chunk.get("content", "")[:100] if len(description) == 100: description += "..." - + url = "" - guild_id = metadata.get('guild_id', '') + guild_id = metadata.get("guild_id", "") if guild_id and channel_id: url = f"https://discord.com/channels/{guild_id}/{channel_id}" elif channel_id: @@ -1036,7 +1144,7 @@ class ConnectorService: url = f"https://discord.com/channels/@me/{channel_id}" source = { - "id": document.get('id', self.source_id_counter), + "id": document.get("id", self.source_id_counter), "title": title, "description": description, "url": url, @@ -1044,7 +1152,7 @@ class ConnectorService: self.source_id_counter += 1 sources_list.append(source) - + # Create result object result_object = { "id": 11, @@ -1052,7 +1160,5 @@ class ConnectorService: "type": "DISCORD_CONNECTOR", "sources": sources_list, } - + return result_object, discord_chunks - - diff --git a/surfsense_backend/app/services/docling_service.py b/surfsense_backend/app/services/docling_service.py index 6552681..a61148c 100644 --- a/surfsense_backend/app/services/docling_service.py +++ b/surfsense_backend/app/services/docling_service.py @@ -5,15 +5,16 @@ SSL-safe implementation with pre-downloaded models """ import logging -import ssl import os -from typing import Dict, Any +import ssl +from typing import Any logger = logging.getLogger(__name__) + class DoclingService: """Docling service for enhanced document processing with SSL fixes.""" - + def __init__(self): """Initialize Docling service with SSL, model fixes, and GPU acceleration.""" self.converter = None @@ -21,30 +22,32 @@ class DoclingService: self._configure_ssl_environment() self._check_wsl2_gpu_support() self._initialize_docling() - + def _configure_ssl_environment(self): """Configure SSL environment for secure model downloads.""" try: # Set SSL context for downloads ssl._create_default_https_context = ssl._create_unverified_context - + # Set SSL environment variables if not already set - if not os.environ.get('SSL_CERT_FILE'): + if not os.environ.get("SSL_CERT_FILE"): try: import certifi - os.environ['SSL_CERT_FILE'] = certifi.where() - os.environ['REQUESTS_CA_BUNDLE'] = certifi.where() + + os.environ["SSL_CERT_FILE"] = certifi.where() + os.environ["REQUESTS_CA_BUNDLE"] = certifi.where() except ImportError: pass - + logger.info("🔐 SSL environment configured for model downloads") except Exception as e: logger.warning(f"⚠️ SSL configuration warning: {e}") - + def _check_wsl2_gpu_support(self): """Check and configure GPU support for WSL2 environment.""" try: import torch + if torch.cuda.is_available(): gpu_count = torch.cuda.device_count() gpu_name = torch.cuda.get_device_name(0) if gpu_count > 0 else "Unknown" @@ -60,34 +63,34 @@ class DoclingService: except Exception as e: logger.warning(f"⚠️ GPU detection failed: {e}, falling back to CPU") self.use_gpu = False - + def _initialize_docling(self): """Initialize Docling with version-safe configuration.""" try: - from docling.document_converter import DocumentConverter, PdfFormatOption + from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend from docling.datamodel.base_models import InputFormat from docling.datamodel.pipeline_options import PdfPipelineOptions - from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend - + from docling.document_converter import DocumentConverter, PdfFormatOption + logger.info("🔧 Initializing Docling with version-safe configuration...") - 
+ # Create pipeline options with version-safe attribute checking pipeline_options = PdfPipelineOptions() - + # Disable OCR (user request) - if hasattr(pipeline_options, 'do_ocr'): + if hasattr(pipeline_options, "do_ocr"): pipeline_options.do_ocr = False logger.info("⚠️ OCR disabled by user request") else: logger.warning("⚠️ OCR attribute not available in this Docling version") - + # Enable table structure if available - if hasattr(pipeline_options, 'do_table_structure'): + if hasattr(pipeline_options, "do_table_structure"): pipeline_options.do_table_structure = True logger.info("✅ Table structure detection enabled") - + # Configure GPU acceleration for WSL2 if available - if hasattr(pipeline_options, 'accelerator_device'): + if hasattr(pipeline_options, "accelerator_device"): if self.use_gpu: try: pipeline_options.accelerator_device = "cuda" @@ -99,164 +102,180 @@ class DoclingService: pipeline_options.accelerator_device = "cpu" logger.info("🖥️ Using CPU acceleration") else: - logger.info("ℹ️ Accelerator device attribute not available in this Docling version") - + logger.info( + "⚠️ Accelerator device attribute not available in this Docling version" + ) + # Create PDF format option with backend pdf_format_option = PdfFormatOption( - pipeline_options=pipeline_options, - backend=PyPdfiumDocumentBackend + pipeline_options=pipeline_options, backend=PyPdfiumDocumentBackend ) - + # Initialize DocumentConverter self.converter = DocumentConverter( - format_options={ - InputFormat.PDF: pdf_format_option - } + format_options={InputFormat.PDF: pdf_format_option} ) - + acceleration_type = "GPU (WSL2)" if self.use_gpu else "CPU" - logger.info(f"✅ Docling initialized successfully with {acceleration_type} acceleration") - + logger.info( + f"✅ Docling initialized successfully with {acceleration_type} acceleration" + ) + except ImportError as e: logger.error(f"❌ Docling not installed: {e}") - raise RuntimeError(f"Docling not available: {e}") + raise RuntimeError(f"Docling not available: {e}") from e except Exception as e: logger.error(f"❌ Docling initialization failed: {e}") - raise RuntimeError(f"Docling initialization failed: {e}") - + raise RuntimeError(f"Docling initialization failed: {e}") from e + def _configure_easyocr_local_models(self): """Configure EasyOCR to use pre-downloaded local models.""" try: - import easyocr import os - + + import easyocr + # Set SSL environment for EasyOCR downloads - os.environ['CURL_CA_BUNDLE'] = '' - os.environ['REQUESTS_CA_BUNDLE'] = '' - + os.environ["CURL_CA_BUNDLE"] = "" + os.environ["REQUESTS_CA_BUNDLE"] = "" + # Try to use local models first, fallback to download if needed try: - reader = easyocr.Reader(['en'], - download_enabled=False, - model_storage_directory="/root/.EasyOCR/model") + reader = easyocr.Reader( + ["en"], + download_enabled=False, + model_storage_directory="/root/.EasyOCR/model", + ) logger.info("✅ EasyOCR configured for local models") return reader - except: + except Exception: # If local models fail, allow download with SSL bypass - logger.info("🔄 Local models failed, attempting download with SSL bypass...") - reader = easyocr.Reader(['en'], - download_enabled=True, - model_storage_directory="/root/.EasyOCR/model") + logger.info( + "🔄 Local models failed, attempting download with SSL bypass..." 
+ ) + reader = easyocr.Reader( + ["en"], + download_enabled=True, + model_storage_directory="/root/.EasyOCR/model", + ) logger.info("✅ EasyOCR configured with downloaded models") return reader except Exception as e: logger.warning(f"⚠️ EasyOCR configuration failed: {e}") return None - - async def process_document(self, file_path: str, filename: str = None) -> Dict[str, Any]: + + async def process_document( + self, file_path: str, filename: str | None = None + ) -> dict[str, Any]: """Process document with Docling using pre-downloaded models.""" - + if self.converter is None: raise RuntimeError("Docling converter not initialized") - + try: - logger.info(f"🔄 Processing {filename} with Docling (using local models)...") - + logger.info( + f"🔄 Processing {filename} with Docling (using local models)..." + ) + # Process document with local models result = self.converter.convert(file_path) - + # Extract content using version-safe methods content = None - if hasattr(result, 'document') and result.document: + if hasattr(result, "document") and result.document: # Try different export methods (version compatibility) - if hasattr(result.document, 'export_to_markdown'): + if hasattr(result.document, "export_to_markdown"): content = result.document.export_to_markdown() logger.info("📄 Used export_to_markdown method") - elif hasattr(result.document, 'to_markdown'): + elif hasattr(result.document, "to_markdown"): content = result.document.to_markdown() logger.info("📄 Used to_markdown method") - elif hasattr(result.document, 'text'): + elif hasattr(result.document, "text"): content = result.document.text logger.info("📄 Used text property") - elif hasattr(result.document, '__str__'): + elif hasattr(result.document, "__str__"): content = str(result.document) logger.info("📄 Used string conversion") - + if content: - logger.info(f"✅ Docling SUCCESS - {filename}: {len(content)} chars (local models)") - + logger.info( + f"✅ Docling SUCCESS - {filename}: {len(content)} chars (local models)" + ) + return { - 'content': content, - 'full_text': content, - 'service_used': 'docling', - 'status': 'success', - 'processing_notes': 'Processed with Docling using pre-downloaded models' + "content": content, + "full_text": content, + "service_used": "docling", + "status": "success", + "processing_notes": "Processed with Docling using pre-downloaded models", } else: raise ValueError("No content could be extracted from document") else: raise ValueError("No document object returned by Docling") - + except Exception as e: logger.error(f"❌ Docling processing failed for {filename}: {e}") # Log the full error for debugging import traceback + logger.error(f"Full traceback: {traceback.format_exc()}") - raise RuntimeError(f"Docling processing failed: {e}") - + raise RuntimeError(f"Docling processing failed: {e}") from e + async def process_large_document_summary( - self, - content: str, - llm, - document_title: str = "Document" + self, content: str, llm, document_title: str = "Document" ) -> str: """ Process large documents using chunked LLM summarization. 
- + Args: content: The full document content llm: The language model to use for summarization document_title: Title of the document for context - + Returns: Final summary of the document """ # Large document threshold (100K characters ≈ 25K tokens) - LARGE_DOCUMENT_THRESHOLD = 100_000 - - if len(content) <= LARGE_DOCUMENT_THRESHOLD: + large_document_threshold = 100_000 + + if len(content) <= large_document_threshold: # For smaller documents, use direct processing - logger.info(f"📄 Document size: {len(content)} chars - using direct processing") + logger.info( + f"📄 Document size: {len(content)} chars - using direct processing" + ) from app.prompts import SUMMARY_PROMPT_TEMPLATE + summary_chain = SUMMARY_PROMPT_TEMPLATE | llm result = await summary_chain.ainvoke({"document": content}) return result.content - - logger.info(f"📚 Large document detected: {len(content)} chars - using chunked processing") - + + logger.info( + f"📚 Large document detected: {len(content)} chars - using chunked processing" + ) + # Import chunker from config - from app.config import config - from langchain_core.prompts import PromptTemplate - # Create LLM-optimized chunks (8K tokens max for safety) - from chonkie import RecursiveChunker, OverlapRefinery + from chonkie import OverlapRefinery, RecursiveChunker + from langchain_core.prompts import PromptTemplate + llm_chunker = RecursiveChunker( chunk_size=8000 # Conservative for most LLMs ) - + # Apply overlap refinery for context preservation (10% overlap = 800 tokens) overlap_refinery = OverlapRefinery( context_size=0.1, # 10% overlap for context preservation - method="suffix" # Add next chunk context to current chunk + method="suffix", # Add next chunk context to current chunk ) - + # First chunk the content, then apply overlap refinery initial_chunks = llm_chunker.chunk(content) chunks = overlap_refinery.refine(initial_chunks) total_chunks = len(chunks) - + logger.info(f"📄 Split into {total_chunks} chunks for LLM processing") - + # Template for chunk processing chunk_template = PromptTemplate( input_variables=["chunk", "chunk_number", "total_chunks"], @@ -274,34 +293,38 @@ Chunk {chunk_number}/{total_chunks}: {chunk} -""" +""", ) - + # Process each chunk individually chunk_summaries = [] for i, chunk in enumerate(chunks, 1): try: - logger.info(f"🔄 Processing chunk {i}/{total_chunks} ({len(chunk.text)} chars)") - + logger.info( + f"🔄 Processing chunk {i}/{total_chunks} ({len(chunk.text)} chars)" + ) + chunk_chain = chunk_template | llm - chunk_result = await chunk_chain.ainvoke({ - "chunk": chunk.text, - "chunk_number": i, - "total_chunks": total_chunks - }) - + chunk_result = await chunk_chain.ainvoke( + { + "chunk": chunk.text, + "chunk_number": i, + "total_chunks": total_chunks, + } + ) + chunk_summary = chunk_result.content chunk_summaries.append(f"=== Section {i} ===\n{chunk_summary}") - + logger.info(f"✅ Completed chunk {i}/{total_chunks}") - + except Exception as e: logger.error(f"❌ Failed to process chunk {i}/{total_chunks}: {e}") chunk_summaries.append(f"=== Section {i} ===\n[Processing failed]") - + # Combine summaries into final document summary logger.info(f"🔄 Combining {len(chunk_summaries)} chunk summaries") - + try: combine_template = PromptTemplate( input_variables=["summaries", "document_title"], @@ -318,22 +341,23 @@ Ensure: {summaries} -""" +""", ) - + combined_summaries = "\n\n".join(chunk_summaries) combine_chain = combine_template | llm - - final_result = await combine_chain.ainvoke({ - "summaries": combined_summaries, - "document_title": 
document_title - }) - + + final_result = await combine_chain.ainvoke( + {"summaries": combined_summaries, "document_title": document_title} + ) + final_summary = final_result.content - logger.info(f"✅ Large document processing complete: {len(final_summary)} chars summary") - + logger.info( + f"✅ Large document processing complete: {len(final_summary)} chars summary" + ) + return final_summary - + except Exception as e: logger.error(f"❌ Failed to combine summaries: {e}") # Fallback: return concatenated chunk summaries @@ -341,6 +365,7 @@ Ensure: logger.warning("⚠️ Using fallback combined summary") return fallback_summary + def create_docling_service() -> DoclingService: """Create a Docling service instance.""" - return DoclingService() \ No newline at end of file + return DoclingService() diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py index 7867d09..335a645 100644 --- a/surfsense_backend/app/services/llm_service.py +++ b/surfsense_backend/app/services/llm_service.py @@ -1,45 +1,43 @@ -from typing import Optional -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.future import select -from langchain_community.chat_models import ChatLiteLLM import logging -from app.db import User, LLMConfig +from langchain_community.chat_models import ChatLiteLLM +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.db import LLMConfig, User logger = logging.getLogger(__name__) + class LLMRole: LONG_CONTEXT = "long_context" FAST = "fast" STRATEGIC = "strategic" + async def get_user_llm_instance( - session: AsyncSession, - user_id: str, - role: str -) -> Optional[ChatLiteLLM]: + session: AsyncSession, user_id: str, role: str +) -> ChatLiteLLM | None: """ Get a ChatLiteLLM instance for a specific user and role. 
- + Args: session: Database session user_id: User ID role: LLM role ('long_context', 'fast', or 'strategic') - + Returns: ChatLiteLLM instance or None if not found """ try: # Get user with their LLM preferences - result = await session.execute( - select(User).where(User.id == user_id) - ) + result = await session.execute(select(User).where(User.id == user_id)) user = result.scalars().first() - + if not user: logger.error(f"User {user_id} not found") return None - + # Get the appropriate LLM config ID based on role llm_config_id = None if role == LLMRole.LONG_CONTEXT: @@ -51,24 +49,23 @@ async def get_user_llm_instance( else: logger.error(f"Invalid LLM role: {role}") return None - + if not llm_config_id: logger.error(f"No {role} LLM configured for user {user_id}") return None - + # Get the LLM configuration result = await session.execute( select(LLMConfig).where( - LLMConfig.id == llm_config_id, - LLMConfig.user_id == user_id + LLMConfig.id == llm_config_id, LLMConfig.user_id == user_id ) ) llm_config = result.scalars().first() - + if not llm_config: logger.error(f"LLM config {llm_config_id} not found for user {user_id}") return None - + # Build the model string for litellm if llm_config.custom_provider: model_string = f"{llm_config.custom_provider}/{llm_config.model_name}" @@ -76,7 +73,7 @@ async def get_user_llm_instance( # Map provider enum to litellm format provider_map = { "OPENAI": "openai", - "ANTHROPIC": "anthropic", + "ANTHROPIC": "anthropic", "GROQ": "groq", "COHERE": "cohere", "GOOGLE": "gemini", @@ -84,37 +81,48 @@ async def get_user_llm_instance( "MISTRAL": "mistral", # Add more mappings as needed } - provider_prefix = provider_map.get(llm_config.provider.value, llm_config.provider.value.lower()) + provider_prefix = provider_map.get( + llm_config.provider.value, llm_config.provider.value.lower() + ) model_string = f"{provider_prefix}/{llm_config.model_name}" - + # Create ChatLiteLLM instance litellm_kwargs = { "model": model_string, "api_key": llm_config.api_key, } - + # Add optional parameters if llm_config.api_base: litellm_kwargs["api_base"] = llm_config.api_base - + # Add any additional litellm parameters if llm_config.litellm_params: litellm_kwargs.update(llm_config.litellm_params) - + return ChatLiteLLM(**litellm_kwargs) - + except Exception as e: - logger.error(f"Error getting LLM instance for user {user_id}, role {role}: {str(e)}") + logger.error( + f"Error getting LLM instance for user {user_id}, role {role}: {e!s}" + ) return None -async def get_user_long_context_llm(session: AsyncSession, user_id: str) -> Optional[ChatLiteLLM]: + +async def get_user_long_context_llm( + session: AsyncSession, user_id: str +) -> ChatLiteLLM | None: """Get user's long context LLM instance.""" return await get_user_llm_instance(session, user_id, LLMRole.LONG_CONTEXT) -async def get_user_fast_llm(session: AsyncSession, user_id: str) -> Optional[ChatLiteLLM]: + +async def get_user_fast_llm(session: AsyncSession, user_id: str) -> ChatLiteLLM | None: """Get user's fast LLM instance.""" return await get_user_llm_instance(session, user_id, LLMRole.FAST) -async def get_user_strategic_llm(session: AsyncSession, user_id: str) -> Optional[ChatLiteLLM]: + +async def get_user_strategic_llm( + session: AsyncSession, user_id: str +) -> ChatLiteLLM | None: """Get user's strategic LLM instance.""" - return await get_user_llm_instance(session, user_id, LLMRole.STRATEGIC) \ No newline at end of file + return await get_user_llm_instance(session, user_id, LLMRole.STRATEGIC) diff --git 
a/surfsense_backend/app/services/query_service.py b/surfsense_backend/app/services/query_service.py index c26cd0a..4a4bc59 100644 --- a/surfsense_backend/app/services/query_service.py +++ b/surfsense_backend/app/services/query_service.py @@ -1,9 +1,10 @@ import datetime -from langchain.schema import HumanMessage, SystemMessage, AIMessage -from app.config import config -from app.services.llm_service import get_user_strategic_llm +from typing import Any + +from langchain.schema import AIMessage, HumanMessage, SystemMessage from sqlalchemy.ext.asyncio import AsyncSession -from typing import Any, List, Optional + +from app.services.llm_service import get_user_strategic_llm class QueryService: @@ -13,13 +14,13 @@ class QueryService: @staticmethod async def reformulate_query_with_chat_history( - user_query: str, - session: AsyncSession, - user_id: str, - chat_history_str: Optional[str] = None + user_query: str, + session: AsyncSession, + user_id: str, + chat_history_str: str | None = None, ) -> str: """ - Reformulate the user query using the user's strategic LLM to make it more + Reformulate the user query using the user's strategic LLM to make it more effective for information retrieval and research purposes. Args: @@ -38,7 +39,9 @@ class QueryService: # Get the user's strategic LLM instance llm = await get_user_strategic_llm(session, user_id) if not llm: - print(f"Warning: No strategic LLM configured for user {user_id}. Using original query.") + print( + f"Warning: No strategic LLM configured for user {user_id}. Using original query." + ) return user_query # Create system message with instructions @@ -92,14 +95,13 @@ class QueryService: print(f"Error reformulating query: {e}") return user_query - @staticmethod - async def langchain_chat_history_to_str(chat_history: List[Any]) -> str: + async def langchain_chat_history_to_str(chat_history: list[Any]) -> str: """ Convert a list of chat history messages to a string. 
""" chat_history_str = "\n" - + for chat_message in chat_history: if isinstance(chat_message, HumanMessage): chat_history_str += f"{chat_message.content}\n" @@ -107,6 +109,6 @@ class QueryService: chat_history_str += f"{chat_message.content}\n" elif isinstance(chat_message, SystemMessage): chat_history_str += f"{chat_message.content}\n" - + chat_history_str += "" return chat_history_str diff --git a/surfsense_backend/app/services/reranker_service.py b/surfsense_backend/app/services/reranker_service.py index cf83df2..bea74e3 100644 --- a/surfsense_backend/app/services/reranker_service.py +++ b/surfsense_backend/app/services/reranker_service.py @@ -1,35 +1,39 @@ import logging -from typing import List, Dict, Any, Optional +from typing import Any, Optional + from rerankers import Document as RerankerDocument + class RerankerService: """ Service for reranking documents using a configured reranker """ - + def __init__(self, reranker_instance=None): """ Initialize the reranker service - + Args: reranker_instance: The reranker instance to use for reranking """ self.reranker_instance = reranker_instance - - def rerank_documents(self, query_text: str, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + + def rerank_documents( + self, query_text: str, documents: list[dict[str, Any]] + ) -> list[dict[str, Any]]: """ Rerank documents using the configured reranker - + Args: query_text: The query text to use for reranking documents: List of document dictionaries to rerank - + Returns: List[Dict[str, Any]]: Reranked documents """ if not self.reranker_instance or not documents: return documents - + try: # Create Document objects for the rerankers library reranker_docs = [] @@ -38,58 +42,63 @@ class RerankerService: content = doc.get("content", "") score = doc.get("score", 0.0) document_info = doc.get("document", {}) - + reranker_docs.append( RerankerDocument( text=content, doc_id=chunk_id, metadata={ - 'document_id': document_info.get("id", ""), - 'document_title': document_info.get("title", ""), - 'document_type': document_info.get("document_type", ""), - 'rrf_score': score - } + "document_id": document_info.get("id", ""), + "document_title": document_info.get("title", ""), + "document_type": document_info.get("document_type", ""), + "rrf_score": score, + }, ) ) - + # Rerank using the configured reranker reranking_results = self.reranker_instance.rank( - query=query_text, - docs=reranker_docs + query=query_text, docs=reranker_docs ) - + # Process the results from the reranker # Convert to serializable dictionaries serialized_results = [] for result in reranking_results.results: # Find the original document by id - original_doc = next((doc for doc in documents if doc.get("chunk_id") == result.document.doc_id), None) + original_doc = next( + ( + doc + for doc in documents + if doc.get("chunk_id") == result.document.doc_id + ), + None, + ) if original_doc: # Create a new document with the reranked score reranked_doc = original_doc.copy() reranked_doc["score"] = float(result.score) reranked_doc["rank"] = result.rank serialized_results.append(reranked_doc) - + return serialized_results - + except Exception as e: # Log the error - logging.error(f"Error during reranking: {str(e)}") + logging.error(f"Error during reranking: {e!s}") # Fall back to original documents without reranking return documents - + @staticmethod - def get_reranker_instance() -> Optional['RerankerService']: + def get_reranker_instance() -> Optional["RerankerService"]: """ Get a reranker service instance from the global configuration. 
- + Returns: Optional[RerankerService]: A reranker service instance if configured, None otherwise """ from app.config import config - - if hasattr(config, 'reranker_instance') and config.reranker_instance: + + if hasattr(config, "reranker_instance") and config.reranker_instance: return RerankerService(config.reranker_instance) return None - \ No newline at end of file diff --git a/surfsense_backend/app/services/streaming_service.py b/surfsense_backend/app/services/streaming_service.py index ab42c3b..dde792c 100644 --- a/surfsense_backend/app/services/streaming_service.py +++ b/surfsense_backend/app/services/streaming_service.py @@ -1,27 +1,15 @@ import json -from typing import Any, Dict, List +from typing import Any class StreamingService: def __init__(self): self.terminal_idx = 1 self.message_annotations = [ - { - "type": "TERMINAL_INFO", - "content": [] - }, - { - "type": "SOURCES", - "content": [] - }, - { - "type": "ANSWER", - "content": [] - }, - { - "type": "FURTHER_QUESTIONS", - "content": [] - } + {"type": "TERMINAL_INFO", "content": []}, + {"type": "SOURCES", "content": []}, + {"type": "ANSWER", "content": []}, + {"type": "FURTHER_QUESTIONS", "content": []}, ] # DEPRECATED: This sends the full annotation array every time (inefficient) @@ -35,7 +23,7 @@ class StreamingService: Returns: str: The formatted annotations string """ - return f'8:{json.dumps(self.message_annotations)}\n' + return f"8:{json.dumps(self.message_annotations)}\n" def format_terminal_info_delta(self, text: str, message_type: str = "info") -> str: """ @@ -58,7 +46,7 @@ class StreamingService: annotation = {"type": "TERMINAL_INFO", "content": [message]} return f"8:[{json.dumps(annotation)}]\n" - def format_sources_delta(self, sources: List[Dict[str, Any]]) -> str: + def format_sources_delta(self, sources: list[dict[str, Any]]) -> str: """ Format sources as a delta annotation @@ -95,7 +83,7 @@ class StreamingService: annotation = {"type": "ANSWER", "content": [answer_chunk]} return f"8:[{json.dumps(annotation)}]\n" - def format_answer_annotation(self, answer_lines: List[str]) -> str: + def format_answer_annotation(self, answer_lines: list[str]) -> str: """ Format the complete answer as a replacement annotation @@ -113,7 +101,7 @@ class StreamingService: return f"8:[{json.dumps(annotation)}]\n" def format_further_questions_delta( - self, further_questions: List[Dict[str, Any]] + self, further_questions: list[dict[str, Any]] ) -> str: """ Format further questions as a delta annotation @@ -155,14 +143,16 @@ class StreamingService: """ return f"3:{json.dumps(error_message)}\n" - def format_completion(self, prompt_tokens: int = 156, completion_tokens: int = 204) -> str: + def format_completion( + self, prompt_tokens: int = 156, completion_tokens: int = 204 + ) -> str: """ Format a completion message - + Args: prompt_tokens: Number of prompt tokens completion_tokens: Number of completion tokens - + Returns: str: The formatted completion string """ @@ -172,7 +162,7 @@ class StreamingService: "usage": { "promptTokens": prompt_tokens, "completionTokens": completion_tokens, - "totalTokens": total_tokens - } + "totalTokens": total_tokens, + }, } - return f'd:{json.dumps(completion_data)}\n' \ No newline at end of file + return f"d:{json.dumps(completion_data)}\n" diff --git a/surfsense_backend/app/services/task_logging_service.py b/surfsense_backend/app/services/task_logging_service.py index c50e420..39316b7 100644 --- a/surfsense_backend/app/services/task_logging_service.py +++ 
b/surfsense_backend/app/services/task_logging_service.py @@ -1,111 +1,116 @@ -from typing import Optional, Dict, Any -from sqlalchemy.ext.asyncio import AsyncSession -from app.db import Log, LogLevel, LogStatus import logging -import json from datetime import datetime +from typing import Any + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import Log, LogLevel, LogStatus logger = logging.getLogger(__name__) + class TaskLoggingService: """Service for logging background tasks using the database Log model""" - + def __init__(self, session: AsyncSession, search_space_id: int): self.session = session self.search_space_id = search_space_id - + async def log_task_start( self, task_name: str, source: str, message: str, - metadata: Optional[Dict[str, Any]] = None + metadata: dict[str, Any] | None = None, ) -> Log: """ Log the start of a task with IN_PROGRESS status - + Args: task_name: Name/identifier of the task source: Source service/component (e.g., 'document_processor', 'slack_indexer') message: Human-readable message about the task metadata: Additional context data - + Returns: Log: The created log entry """ log_metadata = metadata or {} - log_metadata.update({ - "task_name": task_name, - "started_at": datetime.utcnow().isoformat() - }) - + log_metadata.update( + {"task_name": task_name, "started_at": datetime.utcnow().isoformat()} + ) + log_entry = Log( level=LogLevel.INFO, status=LogStatus.IN_PROGRESS, message=message, source=source, log_metadata=log_metadata, - search_space_id=self.search_space_id + search_space_id=self.search_space_id, ) - + self.session.add(log_entry) await self.session.commit() await self.session.refresh(log_entry) - + logger.info(f"Started task {task_name}: {message}") return log_entry - + async def log_task_success( self, log_entry: Log, message: str, - additional_metadata: Optional[Dict[str, Any]] = None + additional_metadata: dict[str, Any] | None = None, ) -> Log: """ Update a log entry to SUCCESS status - + Args: log_entry: The original log entry to update message: Success message additional_metadata: Additional metadata to merge - + Returns: Log: The updated log entry """ # Update the existing log entry log_entry.status = LogStatus.SUCCESS log_entry.message = message - + # Merge additional metadata if additional_metadata: if log_entry.log_metadata is None: log_entry.log_metadata = {} log_entry.log_metadata.update(additional_metadata) log_entry.log_metadata["completed_at"] = datetime.utcnow().isoformat() - + await self.session.commit() await self.session.refresh(log_entry) - - task_name = log_entry.log_metadata.get("task_name", "unknown") if log_entry.log_metadata else "unknown" + + task_name = ( + log_entry.log_metadata.get("task_name", "unknown") + if log_entry.log_metadata + else "unknown" + ) logger.info(f"Completed task {task_name}: {message}") return log_entry - + async def log_task_failure( self, log_entry: Log, error_message: str, - error_details: Optional[str] = None, - additional_metadata: Optional[Dict[str, Any]] = None + error_details: str | None = None, + additional_metadata: dict[str, Any] | None = None, ) -> Log: """ Update a log entry to FAILED status - + Args: log_entry: The original log entry to update error_message: Error message error_details: Detailed error information additional_metadata: Additional metadata to merge - + Returns: Log: The updated log entry """ @@ -113,77 +118,86 @@ class TaskLoggingService: log_entry.status = LogStatus.FAILED log_entry.level = LogLevel.ERROR log_entry.message = error_message - + # Merge 
additional metadata if log_entry.log_metadata is None: log_entry.log_metadata = {} - - log_entry.log_metadata.update({ - "failed_at": datetime.utcnow().isoformat(), - "error_details": error_details - }) - + + log_entry.log_metadata.update( + {"failed_at": datetime.utcnow().isoformat(), "error_details": error_details} + ) + if additional_metadata: log_entry.log_metadata.update(additional_metadata) - + await self.session.commit() await self.session.refresh(log_entry) - - task_name = log_entry.log_metadata.get("task_name", "unknown") if log_entry.log_metadata else "unknown" + + task_name = ( + log_entry.log_metadata.get("task_name", "unknown") + if log_entry.log_metadata + else "unknown" + ) logger.error(f"Failed task {task_name}: {error_message}") if error_details: logger.error(f"Error details: {error_details}") - + return log_entry - + async def log_task_progress( self, log_entry: Log, progress_message: str, - progress_metadata: Optional[Dict[str, Any]] = None + progress_metadata: dict[str, Any] | None = None, ) -> Log: """ Update a log entry with progress information while keeping IN_PROGRESS status - + Args: log_entry: The log entry to update progress_message: Progress update message progress_metadata: Additional progress metadata - + Returns: Log: The updated log entry """ log_entry.message = progress_message - + if progress_metadata: if log_entry.log_metadata is None: log_entry.log_metadata = {} log_entry.log_metadata.update(progress_metadata) - log_entry.log_metadata["last_progress_update"] = datetime.utcnow().isoformat() - + log_entry.log_metadata["last_progress_update"] = ( + datetime.utcnow().isoformat() + ) + await self.session.commit() await self.session.refresh(log_entry) - - task_name = log_entry.log_metadata.get("task_name", "unknown") if log_entry.log_metadata else "unknown" + + task_name = ( + log_entry.log_metadata.get("task_name", "unknown") + if log_entry.log_metadata + else "unknown" + ) logger.info(f"Progress update for task {task_name}: {progress_message}") return log_entry - + async def log_simple_event( self, level: LogLevel, source: str, message: str, - metadata: Optional[Dict[str, Any]] = None + metadata: dict[str, Any] | None = None, ) -> Log: """ Log a simple event (not a long-running task) - + Args: level: Log level source: Source service/component message: Log message metadata: Additional context data - + Returns: Log: The created log entry """ @@ -193,12 +207,12 @@ class TaskLoggingService: message=message, source=source, log_metadata=metadata or {}, - search_space_id=self.search_space_id + search_space_id=self.search_space_id, ) - + self.session.add(log_entry) await self.session.commit() await self.session.refresh(log_entry) - + logger.info(f"Logged event from {source}: {message}") - return log_entry \ No newline at end of file + return log_entry diff --git a/surfsense_backend/app/tasks/background_tasks.py b/surfsense_backend/app/tasks/background_tasks.py index 9599619..06304b4 100644 --- a/surfsense_backend/app/tasks/background_tasks.py +++ b/surfsense_backend/app/tasks/background_tasks.py @@ -1,46 +1,49 @@ -from typing import Optional, List -from sqlalchemy.ext.asyncio import AsyncSession +import logging +from urllib.parse import parse_qs, urlparse + +import aiohttp +import validators +from langchain_community.document_loaders import AsyncChromiumLoader, FireCrawlLoader +from langchain_community.document_transformers import MarkdownifyTransformer +from langchain_core.documents import Document as LangChainDocument from sqlalchemy.exc import SQLAlchemyError 
+from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from app.db import Document, DocumentType, Chunk -from app.schemas import ExtensionDocumentContent +from youtube_transcript_api import YouTubeTranscriptApi + from app.config import config +from app.db import Chunk, Document, DocumentType from app.prompts import SUMMARY_PROMPT_TEMPLATE -from app.utils.document_converters import convert_document_to_markdown, generate_content_hash +from app.schemas import ExtensionDocumentContent from app.services.llm_service import get_user_long_context_llm from app.services.task_logging_service import TaskLoggingService -from langchain_core.documents import Document as LangChainDocument -from langchain_community.document_loaders import FireCrawlLoader, AsyncChromiumLoader -from langchain_community.document_transformers import MarkdownifyTransformer -import validators -from youtube_transcript_api import YouTubeTranscriptApi -from urllib.parse import urlparse, parse_qs -import aiohttp -import logging +from app.utils.document_converters import ( + convert_document_to_markdown, + generate_content_hash, +) md = MarkdownifyTransformer() + async def add_crawled_url_document( session: AsyncSession, url: str, search_space_id: int, user_id: str -) -> Optional[Document]: +) -> Document | None: task_logger = TaskLoggingService(session, search_space_id) - + # Log task start log_entry = await task_logger.log_task_start( task_name="crawl_url_document", source="background_task", message=f"Starting URL crawling process for: {url}", - metadata={"url": url, "user_id": str(user_id)} + metadata={"url": url, "user_id": str(user_id)}, ) - + try: # URL validation step await task_logger.log_task_progress( - log_entry, - f"Validating URL: {url}", - {"stage": "validation"} + log_entry, f"Validating URL: {url}", {"stage": "validation"} ) - + if not validators.url(url): raise ValueError(f"Url {url} is not a valid URL address") @@ -48,7 +51,10 @@ async def add_crawled_url_document( await task_logger.log_task_progress( log_entry, f"Setting up crawler for URL: {url}", - {"stage": "crawler_setup", "firecrawl_available": bool(config.FIRECRAWL_API_KEY)} + { + "stage": "crawler_setup", + "firecrawl_available": bool(config.FIRECRAWL_API_KEY), + }, ) if config.FIRECRAWL_API_KEY: @@ -68,21 +74,21 @@ async def add_crawled_url_document( await task_logger.log_task_progress( log_entry, f"Crawling URL content: {url}", - {"stage": "crawling", "crawler_type": type(crawl_loader).__name__} + {"stage": "crawling", "crawler_type": type(crawl_loader).__name__}, ) url_crawled = await crawl_loader.aload() - if type(crawl_loader) == FireCrawlLoader: + if isinstance(crawl_loader, FireCrawlLoader): content_in_markdown = url_crawled[0].page_content - elif type(crawl_loader) == AsyncChromiumLoader: + elif isinstance(crawl_loader, AsyncChromiumLoader): content_in_markdown = md.transform_documents(url_crawled)[0].page_content # Format document await task_logger.log_task_progress( log_entry, f"Processing crawled content from: {url}", - {"stage": "content_processing", "content_length": len(content_in_markdown)} + {"stage": "content_processing", "content_length": len(content_in_markdown)}, ) # Format document metadata in a more maintainable way @@ -117,7 +123,7 @@ async def add_crawled_url_document( await task_logger.log_task_progress( log_entry, f"Checking for duplicate content: {url}", - {"stage": "duplicate_check", "content_hash": content_hash} + {"stage": "duplicate_check", "content_hash": content_hash}, ) # Check if document 
with this content hash already exists @@ -125,21 +131,26 @@ async def add_crawled_url_document( select(Document).where(Document.content_hash == content_hash) ) existing_document = existing_doc_result.scalars().first() - + if existing_document: await task_logger.log_task_success( log_entry, f"Document already exists for URL: {url}", - {"duplicate_detected": True, "existing_document_id": existing_document.id} + { + "duplicate_detected": True, + "existing_document_id": existing_document.id, + }, + ) + logging.info( + f"Document with content hash {content_hash} already exists. Skipping processing." ) - logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.") return existing_document # Get LLM for summary generation await task_logger.log_task_progress( log_entry, f"Preparing for summary generation: {url}", - {"stage": "llm_setup"} + {"stage": "llm_setup"}, ) # Get user's long context LLM @@ -151,7 +162,7 @@ async def add_crawled_url_document( await task_logger.log_task_progress( log_entry, f"Generating summary for URL content: {url}", - {"stage": "summary_generation"} + {"stage": "summary_generation"}, ) summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm @@ -165,7 +176,7 @@ async def add_crawled_url_document( await task_logger.log_task_progress( log_entry, f"Processing content chunks for URL: {url}", - {"stage": "chunk_processing"} + {"stage": "chunk_processing"}, ) chunks = [ @@ -180,13 +191,13 @@ async def add_crawled_url_document( await task_logger.log_task_progress( log_entry, f"Creating document in database for URL: {url}", - {"stage": "document_creation", "chunks_count": len(chunks)} + {"stage": "document_creation", "chunks_count": len(chunks)}, ) document = Document( search_space_id=search_space_id, title=url_crawled[0].metadata["title"] - if type(crawl_loader) == FireCrawlLoader + if isinstance(crawl_loader, FireCrawlLoader) else url_crawled[0].metadata["source"], document_type=DocumentType.CRAWLED_URL, document_metadata=url_crawled[0].metadata, @@ -209,8 +220,8 @@ async def add_crawled_url_document( "title": document.title, "content_hash": content_hash, "chunks_count": len(chunks), - "summary_length": len(summary_content) - } + "summary_length": len(summary_content), + }, ) return document @@ -221,7 +232,7 @@ async def add_crawled_url_document( log_entry, f"Database error while processing URL: {url}", str(db_error), - {"error_type": "SQLAlchemyError"} + {"error_type": "SQLAlchemyError"}, ) raise db_error except Exception as e: @@ -230,14 +241,17 @@ async def add_crawled_url_document( log_entry, f"Failed to crawl URL: {url}", str(e), - {"error_type": type(e).__name__} + {"error_type": type(e).__name__}, ) - raise RuntimeError(f"Failed to crawl URL: {str(e)}") + raise RuntimeError(f"Failed to crawl URL: {e!s}") from e async def add_extension_received_document( - session: AsyncSession, content: ExtensionDocumentContent, search_space_id: int, user_id: str -) -> Optional[Document]: + session: AsyncSession, + content: ExtensionDocumentContent, + search_space_id: int, + user_id: str, +) -> Document | None: """ Process and store document content received from the SurfSense Extension. 
@@ -250,7 +264,7 @@ async def add_extension_received_document( Document object if successful, None if failed """ task_logger = TaskLoggingService(session, search_space_id) - + # Log task start log_entry = await task_logger.log_task_start( task_name="extension_document", @@ -259,10 +273,10 @@ async def add_extension_received_document( metadata={ "url": content.metadata.VisitedWebPageURL, "title": content.metadata.VisitedWebPageTitle, - "user_id": str(user_id) - } + "user_id": str(user_id), + }, ) - + try: # Format document metadata in a more maintainable way metadata_sections = [ @@ -301,14 +315,19 @@ async def add_extension_received_document( select(Document).where(Document.content_hash == content_hash) ) existing_document = existing_doc_result.scalars().first() - + if existing_document: await task_logger.log_task_success( log_entry, f"Extension document already exists: {content.metadata.VisitedWebPageTitle}", - {"duplicate_detected": True, "existing_document_id": existing_document.id} + { + "duplicate_detected": True, + "existing_document_id": existing_document.id, + }, + ) + logging.info( + f"Document with content hash {content_hash} already exists. Skipping processing." ) - logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.") return existing_document # Get user's long context LLM @@ -356,8 +375,8 @@ async def add_extension_received_document( { "document_id": document.id, "content_hash": content_hash, - "url": content.metadata.VisitedWebPageURL - } + "url": content.metadata.VisitedWebPageURL, + }, ) return document @@ -368,7 +387,7 @@ async def add_extension_received_document( log_entry, f"Database error processing extension document: {content.metadata.VisitedWebPageTitle}", str(db_error), - {"error_type": "SQLAlchemyError"} + {"error_type": "SQLAlchemyError"}, ) raise db_error except Exception as e: @@ -377,24 +396,32 @@ async def add_extension_received_document( log_entry, f"Failed to process extension document: {content.metadata.VisitedWebPageTitle}", str(e), - {"error_type": type(e).__name__} + {"error_type": type(e).__name__}, ) - raise RuntimeError(f"Failed to process extension document: {str(e)}") + raise RuntimeError(f"Failed to process extension document: {e!s}") from e async def add_received_markdown_file_document( - session: AsyncSession, file_name: str, file_in_markdown: str, search_space_id: int, user_id: str -) -> Optional[Document]: + session: AsyncSession, + file_name: str, + file_in_markdown: str, + search_space_id: int, + user_id: str, +) -> Document | None: task_logger = TaskLoggingService(session, search_space_id) - + # Log task start log_entry = await task_logger.log_task_start( task_name="markdown_file_document", source="background_task", message=f"Processing markdown file: {file_name}", - metadata={"filename": file_name, "user_id": str(user_id), "content_length": len(file_in_markdown)} + metadata={ + "filename": file_name, + "user_id": str(user_id), + "content_length": len(file_in_markdown), + }, ) - + try: content_hash = generate_content_hash(file_in_markdown, search_space_id) @@ -403,14 +430,19 @@ async def add_received_markdown_file_document( select(Document).where(Document.content_hash == content_hash) ) existing_document = existing_doc_result.scalars().first() - + if existing_document: await task_logger.log_task_success( log_entry, f"Markdown file document already exists: {file_name}", - {"duplicate_detected": True, "existing_document_id": existing_document.id} + { + "duplicate_detected": True, + 
"existing_document_id": existing_document.id, + }, + ) + logging.info( + f"Document with content hash {content_hash} already exists. Skipping processing." ) - logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.") return existing_document # Get user's long context LLM @@ -459,8 +491,8 @@ async def add_received_markdown_file_document( "document_id": document.id, "content_hash": content_hash, "chunks_count": len(chunks), - "summary_length": len(summary_content) - } + "summary_length": len(summary_content), + }, ) return document @@ -470,7 +502,7 @@ async def add_received_markdown_file_document( log_entry, f"Database error processing markdown file: {file_name}", str(db_error), - {"error_type": "SQLAlchemyError"} + {"error_type": "SQLAlchemyError"}, ) raise db_error except Exception as e: @@ -479,18 +511,18 @@ async def add_received_markdown_file_document( log_entry, f"Failed to process markdown file: {file_name}", str(e), - {"error_type": type(e).__name__} + {"error_type": type(e).__name__}, ) - raise RuntimeError(f"Failed to process file document: {str(e)}") + raise RuntimeError(f"Failed to process file document: {e!s}") from e async def add_received_file_document_using_unstructured( session: AsyncSession, file_name: str, - unstructured_processed_elements: List[LangChainDocument], + unstructured_processed_elements: list[LangChainDocument], search_space_id: int, user_id: str, -) -> Optional[Document]: +) -> Document | None: try: file_in_markdown = await convert_document_to_markdown( unstructured_processed_elements @@ -503,9 +535,11 @@ async def add_received_file_document_using_unstructured( select(Document).where(Document.content_hash == content_hash) ) existing_document = existing_doc_result.scalars().first() - + if existing_document: - logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.") + logging.info( + f"Document with content hash {content_hash} already exists. Skipping processing." + ) return existing_document # TODO: Check if file_markdown exceeds token limit of embedding model @@ -555,7 +589,7 @@ async def add_received_file_document_using_unstructured( raise db_error except Exception as e: await session.rollback() - raise RuntimeError(f"Failed to process file document: {str(e)}") + raise RuntimeError(f"Failed to process file document: {e!s}") from e async def add_received_file_document_using_llamacloud( @@ -564,7 +598,7 @@ async def add_received_file_document_using_llamacloud( llamacloud_markdown_document: str, search_space_id: int, user_id: str, -) -> Optional[Document]: +) -> Document | None: """ Process and store document content parsed by LlamaCloud. @@ -588,9 +622,11 @@ async def add_received_file_document_using_llamacloud( select(Document).where(Document.content_hash == content_hash) ) existing_document = existing_doc_result.scalars().first() - + if existing_document: - logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.") + logging.info( + f"Document with content hash {content_hash} already exists. Skipping processing." 
+ ) return existing_document # Get user's long context LLM @@ -638,7 +674,9 @@ async def add_received_file_document_using_llamacloud( raise db_error except Exception as e: await session.rollback() - raise RuntimeError(f"Failed to process file document using LlamaCloud: {str(e)}") + raise RuntimeError( + f"Failed to process file document using LlamaCloud: {e!s}" + ) from e async def add_received_file_document_using_docling( @@ -647,7 +685,7 @@ async def add_received_file_document_using_docling( docling_markdown_document: str, search_space_id: int, user_id: str, -) -> Optional[Document]: +) -> Document | None: """ Process and store document content parsed by Docling. @@ -671,9 +709,11 @@ async def add_received_file_document_using_docling( select(Document).where(Document.content_hash == content_hash) ) existing_document = existing_doc_result.scalars().first() - + if existing_document: - logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.") + logging.info( + f"Document with content hash {content_hash} already exists. Skipping processing." + ) return existing_document # Get user's long context LLM @@ -683,12 +723,11 @@ async def add_received_file_document_using_docling( # Generate summary using chunked processing for large documents from app.services.docling_service import create_docling_service + docling_service = create_docling_service() - + summary_content = await docling_service.process_large_document_summary( - content=file_in_markdown, - llm=user_llm, - document_title=file_name + content=file_in_markdown, llm=user_llm, document_title=file_name ) summary_embedding = config.embedding_model_instance.embed(summary_content) @@ -726,7 +765,9 @@ async def add_received_file_document_using_docling( raise db_error except Exception as e: await session.rollback() - raise RuntimeError(f"Failed to process file document using Docling: {str(e)}") + raise RuntimeError( + f"Failed to process file document using Docling: {e!s}" + ) from e async def add_youtube_video_document( @@ -749,23 +790,23 @@ async def add_youtube_video_document( RuntimeError: If the video processing fails """ task_logger = TaskLoggingService(session, search_space_id) - + # Log task start log_entry = await task_logger.log_task_start( task_name="youtube_video_document", source="background_task", message=f"Starting YouTube video processing for: {url}", - metadata={"url": url, "user_id": str(user_id)} + metadata={"url": url, "user_id": str(user_id)}, ) - + try: # Extract video ID from URL await task_logger.log_task_progress( log_entry, f"Extracting video ID from URL: {url}", - {"stage": "video_id_extraction"} + {"stage": "video_id_extraction"}, ) - + def get_youtube_video_id(url: str): parsed_url = urlparse(url) hostname = parsed_url.hostname @@ -790,14 +831,14 @@ async def add_youtube_video_document( await task_logger.log_task_progress( log_entry, f"Video ID extracted: {video_id}", - {"stage": "video_id_extracted", "video_id": video_id} + {"stage": "video_id_extracted", "video_id": video_id}, ) # Get video metadata await task_logger.log_task_progress( log_entry, f"Fetching video metadata for: {video_id}", - {"stage": "metadata_fetch"} + {"stage": "metadata_fetch"}, ) params = { @@ -806,21 +847,27 @@ async def add_youtube_video_document( } oembed_url = "https://www.youtube.com/oembed" - async with aiohttp.ClientSession() as http_session: - async with http_session.get(oembed_url, params=params) as response: - video_data = await response.json() + async with ( + aiohttp.ClientSession() as 
http_session, + http_session.get(oembed_url, params=params) as response, + ): + video_data = await response.json() await task_logger.log_task_progress( log_entry, f"Video metadata fetched: {video_data.get('title', 'Unknown')}", - {"stage": "metadata_fetched", "title": video_data.get('title'), "author": video_data.get('author_name')} + { + "stage": "metadata_fetched", + "title": video_data.get("title"), + "author": video_data.get("author_name"), + }, ) # Get video transcript await task_logger.log_task_progress( log_entry, f"Fetching transcript for video: {video_id}", - {"stage": "transcript_fetch"} + {"stage": "transcript_fetch"}, ) try: @@ -834,25 +881,29 @@ async def add_youtube_video_document( timestamp = f"[{start_time:.2f}s-{start_time + duration:.2f}s]" transcript_segments.append(f"{timestamp} {text}") transcript_text = "\n".join(transcript_segments) - + await task_logger.log_task_progress( log_entry, f"Transcript fetched successfully: {len(captions)} segments", - {"stage": "transcript_fetched", "segments_count": len(captions), "transcript_length": len(transcript_text)} + { + "stage": "transcript_fetched", + "segments_count": len(captions), + "transcript_length": len(transcript_text), + }, ) except Exception as e: - transcript_text = f"No captions available for this video. Error: {str(e)}" + transcript_text = f"No captions available for this video. Error: {e!s}" await task_logger.log_task_progress( log_entry, f"No transcript available for video: {video_id}", - {"stage": "transcript_unavailable", "error": str(e)} + {"stage": "transcript_unavailable", "error": str(e)}, ) # Format document await task_logger.log_task_progress( log_entry, f"Processing video content: {video_data.get('title', 'YouTube Video')}", - {"stage": "content_processing"} + {"stage": "content_processing"}, ) # Format document metadata in a more maintainable way @@ -890,7 +941,7 @@ async def add_youtube_video_document( await task_logger.log_task_progress( log_entry, f"Checking for duplicate video content: {video_id}", - {"stage": "duplicate_check", "content_hash": content_hash} + {"stage": "duplicate_check", "content_hash": content_hash}, ) # Check if document with this content hash already exists @@ -898,21 +949,27 @@ async def add_youtube_video_document( select(Document).where(Document.content_hash == content_hash) ) existing_document = existing_doc_result.scalars().first() - + if existing_document: await task_logger.log_task_success( log_entry, f"YouTube video document already exists: {video_data.get('title', 'YouTube Video')}", - {"duplicate_detected": True, "existing_document_id": existing_document.id, "video_id": video_id} + { + "duplicate_detected": True, + "existing_document_id": existing_document.id, + "video_id": video_id, + }, + ) + logging.info( + f"Document with content hash {content_hash} already exists. Skipping processing." ) - logging.info(f"Document with content hash {content_hash} already exists. 
Skipping processing.") return existing_document # Get LLM for summary generation await task_logger.log_task_progress( log_entry, f"Preparing for summary generation: {video_data.get('title', 'YouTube Video')}", - {"stage": "llm_setup"} + {"stage": "llm_setup"}, ) # Get user's long context LLM @@ -924,7 +981,7 @@ async def add_youtube_video_document( await task_logger.log_task_progress( log_entry, f"Generating summary for video: {video_data.get('title', 'YouTube Video')}", - {"stage": "summary_generation"} + {"stage": "summary_generation"}, ) summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm @@ -938,7 +995,7 @@ async def add_youtube_video_document( await task_logger.log_task_progress( log_entry, f"Processing content chunks for video: {video_data.get('title', 'YouTube Video')}", - {"stage": "chunk_processing"} + {"stage": "chunk_processing"}, ) chunks = [ @@ -953,7 +1010,7 @@ async def add_youtube_video_document( await task_logger.log_task_progress( log_entry, f"Creating YouTube video document in database: {video_data.get('title', 'YouTube Video')}", - {"stage": "document_creation", "chunks_count": len(chunks)} + {"stage": "document_creation", "chunks_count": len(chunks)}, ) document = Document( @@ -988,8 +1045,8 @@ async def add_youtube_video_document( "content_hash": content_hash, "chunks_count": len(chunks), "summary_length": len(summary_content), - "has_transcript": "No captions available" not in transcript_text - } + "has_transcript": "No captions available" not in transcript_text, + }, ) return document @@ -999,7 +1056,10 @@ async def add_youtube_video_document( log_entry, f"Database error while processing YouTube video: {url}", str(db_error), - {"error_type": "SQLAlchemyError", "video_id": video_id if 'video_id' in locals() else None} + { + "error_type": "SQLAlchemyError", + "video_id": video_id if "video_id" in locals() else None, + }, ) raise db_error except Exception as e: @@ -1008,7 +1068,10 @@ async def add_youtube_video_document( log_entry, f"Failed to process YouTube video: {url}", str(e), - {"error_type": type(e).__name__, "video_id": video_id if 'video_id' in locals() else None} + { + "error_type": type(e).__name__, + "video_id": video_id if "video_id" in locals() else None, + }, ) - logging.error(f"Failed to process YouTube video: {str(e)}") + logging.error(f"Failed to process YouTube video: {e!s}") raise diff --git a/surfsense_backend/app/tasks/connectors_indexing_tasks.py b/surfsense_backend/app/tasks/connectors_indexing_tasks.py index e0b3cd1..053b8ba 100644 --- a/surfsense_backend/app/tasks/connectors_indexing_tasks.py +++ b/surfsense_backend/app/tasks/connectors_indexing_tasks.py @@ -1,84 +1,99 @@ -from typing import Optional, Tuple -from sqlalchemy.ext.asyncio import AsyncSession +import asyncio +import logging +from datetime import UTC, datetime, timedelta + +from slack_sdk.errors import SlackApiError from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from datetime import datetime, timedelta, timezone -from app.db import Document, DocumentType, Chunk, SearchSourceConnector, SearchSourceConnectorType, SearchSpace + from app.config import config +from app.connectors.discord_connector import DiscordConnector +from app.connectors.github_connector import GitHubConnector +from app.connectors.linear_connector import LinearConnector +from app.connectors.notion_history import NotionHistoryConnector +from app.connectors.slack_history import SlackHistory +from app.db import ( + Chunk, + Document, + 
DocumentType, + SearchSourceConnector, + SearchSourceConnectorType, +) from app.prompts import SUMMARY_PROMPT_TEMPLATE from app.services.llm_service import get_user_long_context_llm from app.services.task_logging_service import TaskLoggingService -from app.connectors.slack_history import SlackHistory -from app.connectors.notion_history import NotionHistoryConnector -from app.connectors.github_connector import GitHubConnector -from app.connectors.linear_connector import LinearConnector -from app.connectors.discord_connector import DiscordConnector -from slack_sdk.errors import SlackApiError -import logging -import asyncio - from app.utils.document_converters import generate_content_hash # Set up logging logger = logging.getLogger(__name__) + async def index_slack_messages( session: AsyncSession, connector_id: int, search_space_id: int, user_id: str, - start_date: str = None, - end_date: str = None, - update_last_indexed: bool = True -) -> Tuple[int, Optional[str]]: + start_date: str | None = None, + end_date: str | None = None, + update_last_indexed: bool = True, +) -> tuple[int, str | None]: """ Index Slack messages from all accessible channels. - + Args: session: Database session connector_id: ID of the Slack connector search_space_id: ID of the search space to store documents in update_last_indexed: Whether to update the last_indexed_at timestamp (default: True) - + Returns: Tuple containing (number of documents indexed, error message or None) """ task_logger = TaskLoggingService(session, search_space_id) - + # Log task start log_entry = await task_logger.log_task_start( task_name="slack_messages_indexing", source="connector_indexing_task", message=f"Starting Slack messages indexing for connector {connector_id}", - metadata={"connector_id": connector_id, "user_id": str(user_id), "start_date": start_date, "end_date": end_date} + metadata={ + "connector_id": connector_id, + "user_id": str(user_id), + "start_date": start_date, + "end_date": end_date, + }, ) - + try: # Get the connector await task_logger.log_task_progress( log_entry, f"Retrieving Slack connector {connector_id} from database", - {"stage": "connector_retrieval"} + {"stage": "connector_retrieval"}, ) - + result = await session.execute( - select(SearchSourceConnector) - .filter( + select(SearchSourceConnector).filter( SearchSourceConnector.id == connector_id, - SearchSourceConnector.connector_type == SearchSourceConnectorType.SLACK_CONNECTOR + SearchSourceConnector.connector_type + == SearchSourceConnectorType.SLACK_CONNECTOR, ) ) connector = result.scalars().first() - + if not connector: await task_logger.log_task_failure( log_entry, f"Connector with ID {connector_id} not found or is not a Slack connector", "Connector not found", - {"error_type": "ConnectorNotFound"} + {"error_type": "ConnectorNotFound"}, ) - return 0, f"Connector with ID {connector_id} not found or is not a Slack connector" - + return ( + 0, + f"Connector with ID {connector_id} not found or is not a Slack connector", + ) + # Get the Slack token from the connector config slack_token = connector.config.get("SLACK_BOT_TOKEN") if not slack_token: @@ -86,62 +101,86 @@ async def index_slack_messages( log_entry, f"Slack token not found in connector config for connector {connector_id}", "Missing Slack token", - {"error_type": "MissingToken"} + {"error_type": "MissingToken"}, ) return 0, "Slack token not found in connector config" - + # Initialize Slack client await task_logger.log_task_progress( log_entry, f"Initializing Slack client for connector {connector_id}", - 
{"stage": "client_initialization"} + {"stage": "client_initialization"}, ) - + slack_client = SlackHistory(token=slack_token) - + # Calculate date range await task_logger.log_task_progress( log_entry, - f"Calculating date range for Slack indexing", - {"stage": "date_calculation", "provided_start_date": start_date, "provided_end_date": end_date} + "Calculating date range for Slack indexing", + { + "stage": "date_calculation", + "provided_start_date": start_date, + "provided_end_date": end_date, + }, ) - + if start_date is None or end_date is None: # Fall back to calculating dates based on last_indexed_at calculated_end_date = datetime.now() - + # Use last_indexed_at as start date if available, otherwise use 365 days ago if connector.last_indexed_at: # Convert dates to be comparable (both timezone-naive) - last_indexed_naive = connector.last_indexed_at.replace(tzinfo=None) if connector.last_indexed_at.tzinfo else connector.last_indexed_at - + last_indexed_naive = ( + connector.last_indexed_at.replace(tzinfo=None) + if connector.last_indexed_at.tzinfo + else connector.last_indexed_at + ) + # Check if last_indexed_at is in the future or after end_date if last_indexed_naive > calculated_end_date: - logger.warning(f"Last indexed date ({last_indexed_naive.strftime('%Y-%m-%d')}) is in the future. Using 365 days ago instead.") + logger.warning( + f"Last indexed date ({last_indexed_naive.strftime('%Y-%m-%d')}) is in the future. Using 365 days ago instead." + ) calculated_start_date = calculated_end_date - timedelta(days=365) else: calculated_start_date = last_indexed_naive - logger.info(f"Using last_indexed_at ({calculated_start_date.strftime('%Y-%m-%d')}) as start date") + logger.info( + f"Using last_indexed_at ({calculated_start_date.strftime('%Y-%m-%d')}) as start date" + ) else: - calculated_start_date = calculated_end_date - timedelta(days=365) # Use 365 days as default - logger.info(f"No last_indexed_at found, using {calculated_start_date.strftime('%Y-%m-%d')} (365 days ago) as start date") - + calculated_start_date = calculated_end_date - timedelta( + days=365 + ) # Use 365 days as default + logger.info( + f"No last_indexed_at found, using {calculated_start_date.strftime('%Y-%m-%d')} (365 days ago) as start date" + ) + # Use calculated dates if not provided - start_date_str = start_date if start_date else calculated_start_date.strftime("%Y-%m-%d") - end_date_str = end_date if end_date else calculated_end_date.strftime("%Y-%m-%d") + start_date_str = ( + start_date if start_date else calculated_start_date.strftime("%Y-%m-%d") + ) + end_date_str = ( + end_date if end_date else calculated_end_date.strftime("%Y-%m-%d") + ) else: # Use provided dates start_date_str = start_date end_date_str = end_date - + logger.info(f"Indexing Slack messages from {start_date_str} to {end_date_str}") - + await task_logger.log_task_progress( log_entry, f"Fetching Slack channels from {start_date_str} to {end_date_str}", - {"stage": "fetch_channels", "start_date": start_date_str, "end_date": end_date_str} + { + "stage": "fetch_channels", + "start_date": start_date_str, + "end_date": end_date_str, + }, ) - + # Get all channels try: channels = slack_client.get_all_channels() @@ -150,133 +189,162 @@ async def index_slack_messages( log_entry, f"Failed to get Slack channels for connector {connector_id}", str(e), - {"error_type": "ChannelFetchError"} + {"error_type": "ChannelFetchError"}, ) - return 0, f"Failed to get Slack channels: {str(e)}" - + return 0, f"Failed to get Slack channels: {e!s}" + if not channels: await 
task_logger.log_task_success( log_entry, f"No Slack channels found for connector {connector_id}", - {"channels_found": 0} + {"channels_found": 0}, ) return 0, "No Slack channels found" - + # Track the number of documents indexed documents_indexed = 0 documents_skipped = 0 skipped_channels = [] - + await task_logger.log_task_progress( log_entry, f"Starting to process {len(channels)} Slack channels", - {"stage": "process_channels", "total_channels": len(channels)} + {"stage": "process_channels", "total_channels": len(channels)}, ) - + # Process each channel - for channel_obj in channels: # Modified loop to iterate over list of channel objects + for ( + channel_obj + ) in channels: # Modified loop to iterate over list of channel objects channel_id = channel_obj["id"] channel_name = channel_obj["name"] is_private = channel_obj["is_private"] - is_member = channel_obj["is_member"] # This might be False for public channels too + is_member = channel_obj[ + "is_member" + ] # This might be False for public channels too try: # If it's a private channel and the bot is not a member, skip. # For public channels, if they are listed by conversations.list, the bot can typically read history. # The `not_in_channel` error in get_conversation_history will be the ultimate gatekeeper if history is inaccessible. if is_private and not is_member: - logger.warning(f"Bot is not a member of private channel {channel_name} ({channel_id}). Skipping.") - skipped_channels.append(f"{channel_name} (private, bot not a member)") + logger.warning( + f"Bot is not a member of private channel {channel_name} ({channel_id}). Skipping." + ) + skipped_channels.append( + f"{channel_name} (private, bot not a member)" + ) documents_skipped += 1 continue - + # Get messages for this channel - # The get_history_by_date_range now uses get_conversation_history, + # The get_history_by_date_range now uses get_conversation_history, # which handles 'not_in_channel' by returning [] and logging. messages, error = slack_client.get_history_by_date_range( channel_id=channel_id, start_date=start_date_str, end_date=end_date_str, - limit=1000 # Limit to 1000 messages per channel + limit=1000, # Limit to 1000 messages per channel ) - + if error: - logger.warning(f"Error getting messages from channel {channel_name}: {error}") + logger.warning( + f"Error getting messages from channel {channel_name}: {error}" + ) skipped_channels.append(f"{channel_name} (error: {error})") documents_skipped += 1 continue # Skip this channel if there's an error - + if not messages: - logger.info(f"No messages found in channel {channel_name} for the specified date range.") + logger.info( + f"No messages found in channel {channel_name} for the specified date range." + ) documents_skipped += 1 continue # Skip if no messages - + # Format messages with user info formatted_messages = [] for msg in messages: # Skip bot messages and system messages - if msg.get("subtype") in ["bot_message", "channel_join", "channel_leave"]: + if msg.get("subtype") in [ + "bot_message", + "channel_join", + "channel_leave", + ]: continue - - formatted_msg = slack_client.format_message(msg, include_user_info=True) + + formatted_msg = slack_client.format_message( + msg, include_user_info=True + ) formatted_messages.append(formatted_msg) - + if not formatted_messages: - logger.info(f"No valid messages found in channel {channel_name} after filtering.") + logger.info( + f"No valid messages found in channel {channel_name} after filtering." 
+ ) documents_skipped += 1 continue # Skip if no valid messages after filtering - + # Convert messages to markdown format channel_content = f"# Slack Channel: {channel_name}\n\n" - + for msg in formatted_messages: user_name = msg.get("user_name", "Unknown User") timestamp = msg.get("datetime", "Unknown Time") text = msg.get("text", "") - - channel_content += f"## {user_name} ({timestamp})\n\n{text}\n\n---\n\n" - + + channel_content += ( + f"## {user_name} ({timestamp})\n\n{text}\n\n---\n\n" + ) + # Format document metadata metadata_sections = [ - ("METADATA", [ - f"CHANNEL_NAME: {channel_name}", - f"CHANNEL_ID: {channel_id}", - # f"START_DATE: {start_date_str}", - # f"END_DATE: {end_date_str}", - f"MESSAGE_COUNT: {len(formatted_messages)}" - ]), - ("CONTENT", [ - "FORMAT: markdown", - "TEXT_START", - channel_content, - "TEXT_END" - ]) + ( + "METADATA", + [ + f"CHANNEL_NAME: {channel_name}", + f"CHANNEL_ID: {channel_id}", + # f"START_DATE: {start_date_str}", + # f"END_DATE: {end_date_str}", + f"MESSAGE_COUNT: {len(formatted_messages)}", + ], + ), + ( + "CONTENT", + ["FORMAT: markdown", "TEXT_START", channel_content, "TEXT_END"], + ), ] - + # Build the document string document_parts = [] document_parts.append("") - + for section_title, section_content in metadata_sections: document_parts.append(f"<{section_title}>") document_parts.extend(section_content) document_parts.append(f"") - + document_parts.append("") - combined_document_string = '\n'.join(document_parts) - content_hash = generate_content_hash(combined_document_string, search_space_id) + combined_document_string = "\n".join(document_parts) + content_hash = generate_content_hash( + combined_document_string, search_space_id + ) # Check if document with this content hash already exists existing_doc_by_hash_result = await session.execute( select(Document).where(Document.content_hash == content_hash) ) - existing_document_by_hash = existing_doc_by_hash_result.scalars().first() - + existing_document_by_hash = ( + existing_doc_by_hash_result.scalars().first() + ) + if existing_document_by_hash: - logger.info(f"Document with content hash {content_hash} already exists for channel {channel_name}. Skipping processing.") + logger.info( + f"Document with content hash {content_hash} already exists for channel {channel_name}. Skipping processing." 
+ ) documents_skipped += 1 continue - + # Get user's long context LLM user_llm = await get_user_long_context_llm(session, user_id) if not user_llm: @@ -284,19 +352,26 @@ async def index_slack_messages( skipped_channels.append(f"{channel_name} (no LLM configured)") documents_skipped += 1 continue - + # Generate summary summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm - summary_result = await summary_chain.ainvoke({"document": combined_document_string}) + summary_result = await summary_chain.ainvoke( + {"document": combined_document_string} + ) summary_content = summary_result.content - summary_embedding = config.embedding_model_instance.embed(summary_content) - + summary_embedding = config.embedding_model_instance.embed( + summary_content + ) + # Process chunks chunks = [ - Chunk(content=chunk.text, embedding=config.embedding_model_instance.embed(chunk.text)) + Chunk( + content=chunk.text, + embedding=config.embedding_model_instance.embed(chunk.text), + ) for chunk in config.chunker_instance.chunk(channel_content) ] - + # Create and store new document document = Document( search_space_id=search_space_id, @@ -308,45 +383,49 @@ async def index_slack_messages( "start_date": start_date_str, "end_date": end_date_str, "message_count": len(formatted_messages), - "indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), }, content=summary_content, embedding=summary_embedding, chunks=chunks, content_hash=content_hash, ) - + session.add(document) documents_indexed += 1 - logger.info(f"Successfully indexed new channel {channel_name} with {len(formatted_messages)} messages") - + logger.info( + f"Successfully indexed new channel {channel_name} with {len(formatted_messages)} messages" + ) + except SlackApiError as slack_error: - logger.error(f"Slack API error for channel {channel_name}: {str(slack_error)}") + logger.error( + f"Slack API error for channel {channel_name}: {slack_error!s}" + ) skipped_channels.append(f"{channel_name} (Slack API error)") documents_skipped += 1 continue # Skip this channel and continue with others except Exception as e: - logger.error(f"Error processing channel {channel_name}: {str(e)}") + logger.error(f"Error processing channel {channel_name}: {e!s}") skipped_channels.append(f"{channel_name} (processing error)") documents_skipped += 1 continue # Skip this channel and continue with others - + # Update the last_indexed_at timestamp for the connector only if requested # and if we successfully indexed at least one channel total_processed = documents_indexed if update_last_indexed and total_processed > 0: connector.last_indexed_at = datetime.now() - + # Commit all changes await session.commit() - + # Prepare result message result_message = None if skipped_channels: result_message = f"Processed {total_processed} channels. Skipped {len(skipped_channels)} channels: {', '.join(skipped_channels)}" else: result_message = f"Processed {total_processed} channels." 
- + # Log success await task_logger.log_task_success( log_entry, @@ -356,91 +435,102 @@ async def index_slack_messages( "documents_indexed": documents_indexed, "documents_skipped": documents_skipped, "skipped_channels_count": len(skipped_channels), - "result_message": result_message - } + "result_message": result_message, + }, + ) + + logger.info( + f"Slack indexing completed: {documents_indexed} new channels, {documents_skipped} skipped" ) - - logger.info(f"Slack indexing completed: {documents_indexed} new channels, {documents_skipped} skipped") return total_processed, result_message - + except SQLAlchemyError as db_error: await session.rollback() await task_logger.log_task_failure( log_entry, f"Database error during Slack indexing for connector {connector_id}", str(db_error), - {"error_type": "SQLAlchemyError"} + {"error_type": "SQLAlchemyError"}, ) - logger.error(f"Database error: {str(db_error)}") - return 0, f"Database error: {str(db_error)}" + logger.error(f"Database error: {db_error!s}") + return 0, f"Database error: {db_error!s}" except Exception as e: await session.rollback() await task_logger.log_task_failure( log_entry, f"Failed to index Slack messages for connector {connector_id}", str(e), - {"error_type": type(e).__name__} + {"error_type": type(e).__name__}, ) - logger.error(f"Failed to index Slack messages: {str(e)}") - return 0, f"Failed to index Slack messages: {str(e)}" + logger.error(f"Failed to index Slack messages: {e!s}") + return 0, f"Failed to index Slack messages: {e!s}" + async def index_notion_pages( session: AsyncSession, connector_id: int, search_space_id: int, user_id: str, - start_date: str = None, - end_date: str = None, - update_last_indexed: bool = True -) -> Tuple[int, Optional[str]]: + start_date: str | None = None, + end_date: str | None = None, + update_last_indexed: bool = True, +) -> tuple[int, str | None]: """ Index Notion pages from all accessible pages. 
- + Args: session: Database session connector_id: ID of the Notion connector search_space_id: ID of the search space to store documents in update_last_indexed: Whether to update the last_indexed_at timestamp (default: True) - + Returns: Tuple containing (number of documents indexed, error message or None) """ task_logger = TaskLoggingService(session, search_space_id) - + # Log task start log_entry = await task_logger.log_task_start( task_name="notion_pages_indexing", source="connector_indexing_task", message=f"Starting Notion pages indexing for connector {connector_id}", - metadata={"connector_id": connector_id, "user_id": str(user_id), "start_date": start_date, "end_date": end_date} + metadata={ + "connector_id": connector_id, + "user_id": str(user_id), + "start_date": start_date, + "end_date": end_date, + }, ) - + try: # Get the connector await task_logger.log_task_progress( log_entry, f"Retrieving Notion connector {connector_id} from database", - {"stage": "connector_retrieval"} + {"stage": "connector_retrieval"}, ) - + result = await session.execute( - select(SearchSourceConnector) - .filter( + select(SearchSourceConnector).filter( SearchSourceConnector.id == connector_id, - SearchSourceConnector.connector_type == SearchSourceConnectorType.NOTION_CONNECTOR + SearchSourceConnector.connector_type + == SearchSourceConnectorType.NOTION_CONNECTOR, ) ) connector = result.scalars().first() - + if not connector: await task_logger.log_task_failure( log_entry, f"Connector with ID {connector_id} not found or is not a Notion connector", "Connector not found", - {"error_type": "ConnectorNotFound"} + {"error_type": "ConnectorNotFound"}, ) - return 0, f"Connector with ID {connector_id} not found or is not a Notion connector" - + return ( + 0, + f"Connector with ID {connector_id} not found or is not a Notion connector", + ) + # Get the Notion token from the connector config notion_token = connector.config.get("NOTION_INTEGRATION_TOKEN") if not notion_token: @@ -448,103 +538,119 @@ async def index_notion_pages( log_entry, f"Notion integration token not found in connector config for connector {connector_id}", "Missing Notion token", - {"error_type": "MissingToken"} + {"error_type": "MissingToken"}, ) return 0, "Notion integration token not found in connector config" - + # Initialize Notion client await task_logger.log_task_progress( log_entry, f"Initializing Notion client for connector {connector_id}", - {"stage": "client_initialization"} + {"stage": "client_initialization"}, ) - + logger.info(f"Initializing Notion client for connector {connector_id}") notion_client = NotionHistoryConnector(token=notion_token) - + # Calculate date range if start_date is None or end_date is None: # Fall back to calculating dates calculated_end_date = datetime.now() - calculated_start_date = calculated_end_date - timedelta(days=365) # Check for last 1 year of pages - + calculated_start_date = calculated_end_date - timedelta( + days=365 + ) # Check for last 1 year of pages + # Use calculated dates if not provided if start_date is None: start_date_iso = calculated_start_date.strftime("%Y-%m-%dT%H:%M:%SZ") else: # Convert YYYY-MM-DD to ISO format - start_date_iso = datetime.strptime(start_date, "%Y-%m-%d").strftime("%Y-%m-%dT%H:%M:%SZ") - + start_date_iso = datetime.strptime(start_date, "%Y-%m-%d").strftime( + "%Y-%m-%dT%H:%M:%SZ" + ) + if end_date is None: end_date_iso = calculated_end_date.strftime("%Y-%m-%dT%H:%M:%SZ") else: # Convert YYYY-MM-DD to ISO format - end_date_iso = datetime.strptime(end_date, 
"%Y-%m-%d").strftime("%Y-%m-%dT%H:%M:%SZ") + end_date_iso = datetime.strptime(end_date, "%Y-%m-%d").strftime( + "%Y-%m-%dT%H:%M:%SZ" + ) else: # Convert provided dates to ISO format for Notion API - start_date_iso = datetime.strptime(start_date, "%Y-%m-%d").strftime("%Y-%m-%dT%H:%M:%SZ") - end_date_iso = datetime.strptime(end_date, "%Y-%m-%d").strftime("%Y-%m-%dT%H:%M:%SZ") - + start_date_iso = datetime.strptime(start_date, "%Y-%m-%d").strftime( + "%Y-%m-%dT%H:%M:%SZ" + ) + end_date_iso = datetime.strptime(end_date, "%Y-%m-%d").strftime( + "%Y-%m-%dT%H:%M:%SZ" + ) + logger.info(f"Fetching Notion pages from {start_date_iso} to {end_date_iso}") - + await task_logger.log_task_progress( log_entry, f"Fetching Notion pages from {start_date_iso} to {end_date_iso}", - {"stage": "fetch_pages", "start_date": start_date_iso, "end_date": end_date_iso} + { + "stage": "fetch_pages", + "start_date": start_date_iso, + "end_date": end_date_iso, + }, ) - + # Get all pages try: - pages = notion_client.get_all_pages(start_date=start_date_iso, end_date=end_date_iso) + pages = notion_client.get_all_pages( + start_date=start_date_iso, end_date=end_date_iso + ) logger.info(f"Found {len(pages)} Notion pages") except Exception as e: await task_logger.log_task_failure( log_entry, f"Failed to get Notion pages for connector {connector_id}", str(e), - {"error_type": "PageFetchError"} + {"error_type": "PageFetchError"}, ) - logger.error(f"Error fetching Notion pages: {str(e)}", exc_info=True) - return 0, f"Failed to get Notion pages: {str(e)}" - + logger.error(f"Error fetching Notion pages: {e!s}", exc_info=True) + return 0, f"Failed to get Notion pages: {e!s}" + if not pages: await task_logger.log_task_success( log_entry, f"No Notion pages found for connector {connector_id}", - {"pages_found": 0} + {"pages_found": 0}, ) logger.info("No Notion pages found to index") return 0, "No Notion pages found" - + # Track the number of documents indexed documents_indexed = 0 documents_skipped = 0 skipped_pages = [] - + await task_logger.log_task_progress( log_entry, f"Starting to process {len(pages)} Notion pages", - {"stage": "process_pages", "total_pages": len(pages)} + {"stage": "process_pages", "total_pages": len(pages)}, ) - + # Process each page for page in pages: try: page_id = page.get("page_id") page_title = page.get("title", f"Untitled page ({page_id})") page_content = page.get("content", []) - + logger.info(f"Processing Notion page: {page_title} ({page_id})") - + if not page_content: logger.info(f"No content found in page {page_title}. 
Skipping.") skipped_pages.append(f"{page_title} (no content)") documents_skipped += 1 continue - + # Convert page content to markdown format markdown_content = f"# Notion Page: {page_title}\n\n" - + # Process blocks recursively def process_blocks(blocks, level=0): result = "" @@ -552,10 +658,10 @@ async def index_notion_pages( block_type = block.get("type") block_content = block.get("content", "") children = block.get("children", []) - + # Add indentation based on level indent = " " * level - + # Format based on block type if block_type in ["paragraph", "text"]: result += f"{indent}{block_content}\n\n" @@ -585,54 +691,62 @@ async def index_notion_pages( # Default for other block types if block_content: result += f"{indent}{block_content}\n\n" - + # Process children recursively if children: result += process_blocks(children, level + 1) - + return result - - logger.debug(f"Converting {len(page_content)} blocks to markdown for page {page_title}") + + logger.debug( + f"Converting {len(page_content)} blocks to markdown for page {page_title}" + ) markdown_content += process_blocks(page_content) - + # Format document metadata metadata_sections = [ - ("METADATA", [ - f"PAGE_TITLE: {page_title}", - f"PAGE_ID: {page_id}" - ]), - ("CONTENT", [ - "FORMAT: markdown", - "TEXT_START", - markdown_content, - "TEXT_END" - ]) + ("METADATA", [f"PAGE_TITLE: {page_title}", f"PAGE_ID: {page_id}"]), + ( + "CONTENT", + [ + "FORMAT: markdown", + "TEXT_START", + markdown_content, + "TEXT_END", + ], + ), ] - + # Build the document string document_parts = [] document_parts.append("") - + for section_title, section_content in metadata_sections: document_parts.append(f"<{section_title}>") document_parts.extend(section_content) document_parts.append(f"") - + document_parts.append("") - combined_document_string = '\n'.join(document_parts) - content_hash = generate_content_hash(combined_document_string, search_space_id) + combined_document_string = "\n".join(document_parts) + content_hash = generate_content_hash( + combined_document_string, search_space_id + ) # Check if document with this content hash already exists existing_doc_by_hash_result = await session.execute( select(Document).where(Document.content_hash == content_hash) ) - existing_document_by_hash = existing_doc_by_hash_result.scalars().first() - + existing_document_by_hash = ( + existing_doc_by_hash_result.scalars().first() + ) + if existing_document_by_hash: - logger.info(f"Document with content hash {content_hash} already exists for page {page_title}. Skipping processing.") + logger.info( + f"Document with content hash {content_hash} already exists for page {page_title}. Skipping processing." 
+ ) documents_skipped += 1 continue - + # Get user's long context LLM user_llm = await get_user_long_context_llm(session, user_id) if not user_llm: @@ -640,21 +754,28 @@ async def index_notion_pages( skipped_pages.append(f"{page_title} (no LLM configured)") documents_skipped += 1 continue - + # Generate summary logger.debug(f"Generating summary for page {page_title}") summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm - summary_result = await summary_chain.ainvoke({"document": combined_document_string}) + summary_result = await summary_chain.ainvoke( + {"document": combined_document_string} + ) summary_content = summary_result.content - summary_embedding = config.embedding_model_instance.embed(summary_content) - + summary_embedding = config.embedding_model_instance.embed( + summary_content + ) + # Process chunks logger.debug(f"Chunking content for page {page_title}") chunks = [ - Chunk(content=chunk.text, embedding=config.embedding_model_instance.embed(chunk.text)) + Chunk( + content=chunk.text, + embedding=config.embedding_model_instance.embed(chunk.text), + ) for chunk in config.chunker_instance.chunk(markdown_content) ] - + # Create and store new document document = Document( search_space_id=search_space_id, @@ -663,41 +784,46 @@ async def index_notion_pages( document_metadata={ "page_title": page_title, "page_id": page_id, - "indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), }, content=summary_content, content_hash=content_hash, embedding=summary_embedding, - chunks=chunks + chunks=chunks, ) - + session.add(document) documents_indexed += 1 logger.info(f"Successfully indexed new Notion page: {page_title}") - + except Exception as e: - logger.error(f"Error processing Notion page {page.get('title', 'Unknown')}: {str(e)}", exc_info=True) - skipped_pages.append(f"{page.get('title', 'Unknown')} (processing error)") + logger.error( + f"Error processing Notion page {page.get('title', 'Unknown')}: {e!s}", + exc_info=True, + ) + skipped_pages.append( + f"{page.get('title', 'Unknown')} (processing error)" + ) documents_skipped += 1 continue # Skip this page and continue with others - + # Update the last_indexed_at timestamp for the connector only if requested # and if we successfully indexed at least one page total_processed = documents_indexed if update_last_indexed and total_processed > 0: connector.last_indexed_at = datetime.now() logger.info(f"Updated last_indexed_at for connector {connector_id}") - + # Commit all changes await session.commit() - + # Prepare result message result_message = None if skipped_pages: result_message = f"Processed {total_processed} pages. Skipped {len(skipped_pages)} pages: {', '.join(skipped_pages)}" else: result_message = f"Processed {total_processed} pages." 
- + # Log success await task_logger.log_task_success( log_entry, @@ -707,43 +833,48 @@ async def index_notion_pages( "documents_indexed": documents_indexed, "documents_skipped": documents_skipped, "skipped_pages_count": len(skipped_pages), - "result_message": result_message - } + "result_message": result_message, + }, + ) + + logger.info( + f"Notion indexing completed: {documents_indexed} new pages, {documents_skipped} skipped" ) - - logger.info(f"Notion indexing completed: {documents_indexed} new pages, {documents_skipped} skipped") return total_processed, result_message - + except SQLAlchemyError as db_error: await session.rollback() await task_logger.log_task_failure( log_entry, f"Database error during Notion indexing for connector {connector_id}", str(db_error), - {"error_type": "SQLAlchemyError"} + {"error_type": "SQLAlchemyError"}, ) - logger.error(f"Database error during Notion indexing: {str(db_error)}", exc_info=True) - return 0, f"Database error: {str(db_error)}" + logger.error( + f"Database error during Notion indexing: {db_error!s}", exc_info=True + ) + return 0, f"Database error: {db_error!s}" except Exception as e: await session.rollback() await task_logger.log_task_failure( log_entry, f"Failed to index Notion pages for connector {connector_id}", str(e), - {"error_type": type(e).__name__} + {"error_type": type(e).__name__}, ) - logger.error(f"Failed to index Notion pages: {str(e)}", exc_info=True) - return 0, f"Failed to index Notion pages: {str(e)}" + logger.error(f"Failed to index Notion pages: {e!s}", exc_info=True) + return 0, f"Failed to index Notion pages: {e!s}" + async def index_github_repos( session: AsyncSession, connector_id: int, search_space_id: int, user_id: str, - start_date: str = None, - end_date: str = None, - update_last_indexed: bool = True -) -> Tuple[int, Optional[str]]: + start_date: str | None = None, + end_date: str | None = None, + update_last_indexed: bool = True, +) -> tuple[int, str | None]: """ Index code and documentation files from accessible GitHub repositories. 
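Every connector indexer in connectors_indexing_tasks.py gets the same signature rewrite: typing.Tuple and Optional give way to built-in generics and | unions, and the implicitly optional str = None defaults become explicit str | None = None. An illustrative stub of the convention (the body is a hypothetical placeholder; only the annotations mirror the patch):

from sqlalchemy.ext.asyncio import AsyncSession


async def index_example_connector(
    session: AsyncSession,
    connector_id: int,
    search_space_id: int,
    user_id: str,
    start_date: str | None = None,  # was: start_date: str = None
    end_date: str | None = None,  # was: end_date: str = None
    update_last_indexed: bool = True,
) -> tuple[int, str | None]:  # was: Tuple[int, Optional[str]]
    """Return (documents indexed, error message or None), as the real indexers do."""
    return 0, None

This appears to line up with ruff's pyupgrade rules (UP006/UP007) and the implicit-Optional fix (RUF013) on a Python 3.10+ target.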
@@ -757,15 +888,20 @@ async def index_github_repos( Tuple containing (number of documents indexed, error message or None) """ task_logger = TaskLoggingService(session, search_space_id) - + # Log task start log_entry = await task_logger.log_task_start( task_name="github_repos_indexing", source="connector_indexing_task", message=f"Starting GitHub repositories indexing for connector {connector_id}", - metadata={"connector_id": connector_id, "user_id": str(user_id), "start_date": start_date, "end_date": end_date} + metadata={ + "connector_id": connector_id, + "user_id": str(user_id), + "start_date": start_date, + "end_date": end_date, + }, ) - + documents_processed = 0 errors = [] @@ -774,14 +910,14 @@ async def index_github_repos( await task_logger.log_task_progress( log_entry, f"Retrieving GitHub connector {connector_id} from database", - {"stage": "connector_retrieval"} + {"stage": "connector_retrieval"}, ) - + result = await session.execute( - select(SearchSourceConnector) - .filter( + select(SearchSourceConnector).filter( SearchSourceConnector.id == connector_id, - SearchSourceConnector.connector_type == SearchSourceConnectorType.GITHUB_CONNECTOR + SearchSourceConnector.connector_type + == SearchSourceConnectorType.GITHUB_CONNECTOR, ) ) connector = result.scalars().first() @@ -791,9 +927,12 @@ async def index_github_repos( log_entry, f"Connector with ID {connector_id} not found or is not a GitHub connector", "Connector not found", - {"error_type": "ConnectorNotFound"} + {"error_type": "ConnectorNotFound"}, + ) + return ( + 0, + f"Connector with ID {connector_id} not found or is not a GitHub connector", ) - return 0, f"Connector with ID {connector_id} not found or is not a GitHub connector" # 2. Get the GitHub PAT and selected repositories from the connector config github_pat = connector.config.get("GITHUB_PAT") @@ -804,16 +943,18 @@ async def index_github_repos( log_entry, f"GitHub Personal Access Token (PAT) not found in connector config for connector {connector_id}", "Missing GitHub PAT", - {"error_type": "MissingToken"} + {"error_type": "MissingToken"}, ) return 0, "GitHub Personal Access Token (PAT) not found in connector config" - - if not repo_full_names_to_index or not isinstance(repo_full_names_to_index, list): + + if not repo_full_names_to_index or not isinstance( + repo_full_names_to_index, list + ): await task_logger.log_task_failure( log_entry, f"'repo_full_names' not found or is not a list in connector config for connector {connector_id}", "Invalid repo configuration", - {"error_type": "InvalidConfiguration"} + {"error_type": "InvalidConfiguration"}, ) return 0, "'repo_full_names' not found or is not a list in connector config" @@ -821,9 +962,12 @@ async def index_github_repos( await task_logger.log_task_progress( log_entry, f"Initializing GitHub client for connector {connector_id}", - {"stage": "client_initialization", "repo_count": len(repo_full_names_to_index)} + { + "stage": "client_initialization", + "repo_count": len(repo_full_names_to_index), + }, ) - + try: github_client = GitHubConnector(token=github_pat) except ValueError as e: @@ -831,9 +975,9 @@ async def index_github_repos( log_entry, f"Failed to initialize GitHub client for connector {connector_id}", str(e), - {"error_type": "ClientInitializationError"} + {"error_type": "ClientInitializationError"}, ) - return 0, f"Failed to initialize GitHub client: {str(e)}" + return 0, f"Failed to initialize GitHub client: {e!s}" # 4. Validate selected repositories # For simplicity, we'll proceed with the list provided. 
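The per-file loop in the hunks that follow pairs each chunk of file content with its embedding before attaching the rows to a new Document. A condensed sketch of that step, assuming nothing about the config singletons named in the diff beyond the .chunk() and .embed() calls shown there:

from app.config import config
from app.db import Chunk


def build_code_chunks(file_content: str) -> list[Chunk]:
    # One Chunk row per chunker segment; embeddings are computed eagerly,
    # exactly as the GitHub indexing loop below does for each repository file.
    return [
        Chunk(
            content=chunk.text,
            embedding=config.embedding_model_instance.embed(chunk.text),
        )
        for chunk in config.code_chunker_instance.chunk(file_content)
    ]

The stored timestamp also moves from datetime.now(timezone.utc) to datetime.now(UTC), using the datetime.UTC alias imported at the top of the file (available from Python 3.11).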
@@ -841,12 +985,21 @@ async def index_github_repos( await task_logger.log_task_progress( log_entry, f"Starting indexing for {len(repo_full_names_to_index)} selected repositories", - {"stage": "repo_processing", "repo_count": len(repo_full_names_to_index), "start_date": start_date, "end_date": end_date} + { + "stage": "repo_processing", + "repo_count": len(repo_full_names_to_index), + "start_date": start_date, + "end_date": end_date, + }, + ) + + logger.info( + f"Starting indexing for {len(repo_full_names_to_index)} selected repositories." ) - - logger.info(f"Starting indexing for {len(repo_full_names_to_index)} selected repositories.") if start_date and end_date: - logger.info(f"Date range requested: {start_date} to {end_date} (Note: GitHub indexing processes all files regardless of dates)") + logger.info( + f"Date range requested: {start_date} to {end_date} (Note: GitHub indexing processes all files regardless of dates)" + ) # 6. Iterate through selected repositories and index files for repo_full_name in repo_full_names_to_index: @@ -858,65 +1011,92 @@ async def index_github_repos( try: files_to_index = github_client.get_repository_files(repo_full_name) if not files_to_index: - logger.info(f"No indexable files found in repository: {repo_full_name}") + logger.info( + f"No indexable files found in repository: {repo_full_name}" + ) continue - logger.info(f"Found {len(files_to_index)} files to process in {repo_full_name}") + logger.info( + f"Found {len(files_to_index)} files to process in {repo_full_name}" + ) for file_info in files_to_index: file_path = file_info.get("path") file_url = file_info.get("url") file_sha = file_info.get("sha") - file_type = file_info.get("type") # 'code' or 'doc' + file_type = file_info.get("type") # 'code' or 'doc' full_path_key = f"{repo_full_name}/{file_path}" if not file_path or not file_url or not file_sha: - logger.warning(f"Skipping file with missing info in {repo_full_name}: {file_info}") + logger.warning( + f"Skipping file with missing info in {repo_full_name}: {file_info}" + ) continue # Get file content - file_content = github_client.get_file_content(repo_full_name, file_path) + file_content = github_client.get_file_content( + repo_full_name, file_path + ) if file_content is None: - logger.warning(f"Could not retrieve content for {full_path_key}. Skipping.") - continue # Skip if content fetch failed - + logger.warning( + f"Could not retrieve content for {full_path_key}. Skipping." + ) + continue # Skip if content fetch failed + content_hash = generate_content_hash(file_content, search_space_id) # Check if document with this content hash already exists existing_doc_by_hash_result = await session.execute( select(Document).where(Document.content_hash == content_hash) ) - existing_document_by_hash = existing_doc_by_hash_result.scalars().first() - + existing_document_by_hash = ( + existing_doc_by_hash_result.scalars().first() + ) + if existing_document_by_hash: - logger.info(f"Document with content hash {content_hash} already exists for file {full_path_key}. Skipping processing.") + logger.info( + f"Document with content hash {content_hash} already exists for file {full_path_key}. Skipping processing." + ) continue - + # Use file_content directly for chunking, maybe summary for main content? # For now, let's use the full content for both, might need refinement - summary_content = f"GitHub file: {full_path_key}\n\n{file_content[:1000]}..." 
# Simple summary - summary_embedding = config.embedding_model_instance.embed(summary_content) + summary_content = f"GitHub file: {full_path_key}\n\n{file_content[:1000]}..." # Simple summary + summary_embedding = config.embedding_model_instance.embed( + summary_content + ) # Chunk the content try: chunks_data = [ - Chunk(content=chunk.text, embedding=config.embedding_model_instance.embed(chunk.text)) - for chunk in config.code_chunker_instance.chunk(file_content) + Chunk( + content=chunk.text, + embedding=config.embedding_model_instance.embed( + chunk.text + ), + ) + for chunk in config.code_chunker_instance.chunk( + file_content + ) ] except Exception as chunk_err: - logger.error(f"Failed to chunk file {full_path_key}: {chunk_err}") - errors.append(f"Chunking failed for {full_path_key}: {chunk_err}") - continue # Skip this file if chunking fails + logger.error( + f"Failed to chunk file {full_path_key}: {chunk_err}" + ) + errors.append( + f"Chunking failed for {full_path_key}: {chunk_err}" + ) + continue # Skip this file if chunking fails doc_metadata = { "repository_full_name": repo_full_name, "file_path": file_path, - "full_path": full_path_key, # For easier lookup + "full_path": full_path_key, # For easier lookup "url": file_url, "sha": file_sha, "type": file_type, - "indexed_at": datetime.now(timezone.utc).isoformat() + "indexed_at": datetime.now(UTC).isoformat(), } # Create new document @@ -925,22 +1105,26 @@ async def index_github_repos( title=f"GitHub - {file_path}", document_type=DocumentType.GITHUB_CONNECTOR, document_metadata=doc_metadata, - content=summary_content, # Store summary + content=summary_content, # Store summary content_hash=content_hash, embedding=summary_embedding, search_space_id=search_space_id, - chunks=chunks_data # Associate chunks directly + chunks=chunks_data, # Associate chunks directly ) session.add(document) documents_processed += 1 except Exception as repo_err: - logger.error(f"Failed to process repository {repo_full_name}: {repo_err}") + logger.error( + f"Failed to process repository {repo_full_name}: {repo_err}" + ) errors.append(f"Failed processing {repo_full_name}: {repo_err}") - + # Commit all changes at the end await session.commit() - logger.info(f"Finished GitHub indexing for connector {connector_id}. Processed {documents_processed} files.") + logger.info( + f"Finished GitHub indexing for connector {connector_id}. Processed {documents_processed} files." 
+ ) # Log success await task_logger.log_task_success( @@ -949,8 +1133,8 @@ async def index_github_repos( { "documents_processed": documents_processed, "errors_count": len(errors), - "repo_count": len(repo_full_names_to_index) - } + "repo_count": len(repo_full_names_to_index), + }, ) except SQLAlchemyError as db_err: @@ -959,9 +1143,11 @@ async def index_github_repos( log_entry, f"Database error during GitHub indexing for connector {connector_id}", str(db_err), - {"error_type": "SQLAlchemyError"} + {"error_type": "SQLAlchemyError"}, + ) + logger.error( + f"Database error during GitHub indexing for connector {connector_id}: {db_err}" ) - logger.error(f"Database error during GitHub indexing for connector {connector_id}: {db_err}") errors.append(f"Database error: {db_err}") return documents_processed, "; ".join(errors) if errors else str(db_err) except Exception as e: @@ -970,72 +1156,84 @@ async def index_github_repos( log_entry, f"Unexpected error during GitHub indexing for connector {connector_id}", str(e), - {"error_type": type(e).__name__} + {"error_type": type(e).__name__}, + ) + logger.error( + f"Unexpected error during GitHub indexing for connector {connector_id}: {e}", + exc_info=True, ) - logger.error(f"Unexpected error during GitHub indexing for connector {connector_id}: {e}", exc_info=True) errors.append(f"Unexpected error: {e}") return documents_processed, "; ".join(errors) if errors else str(e) error_message = "; ".join(errors) if errors else None return documents_processed, error_message + async def index_linear_issues( session: AsyncSession, connector_id: int, search_space_id: int, user_id: str, - start_date: str = None, - end_date: str = None, - update_last_indexed: bool = True -) -> Tuple[int, Optional[str]]: + start_date: str | None = None, + end_date: str | None = None, + update_last_indexed: bool = True, +) -> tuple[int, str | None]: """ Index Linear issues and comments. 
- + Args: session: Database session connector_id: ID of the Linear connector search_space_id: ID of the search space to store documents in update_last_indexed: Whether to update the last_indexed_at timestamp (default: True) - + Returns: Tuple containing (number of documents indexed, error message or None) """ task_logger = TaskLoggingService(session, search_space_id) - + # Log task start log_entry = await task_logger.log_task_start( task_name="linear_issues_indexing", source="connector_indexing_task", message=f"Starting Linear issues indexing for connector {connector_id}", - metadata={"connector_id": connector_id, "user_id": str(user_id), "start_date": start_date, "end_date": end_date} + metadata={ + "connector_id": connector_id, + "user_id": str(user_id), + "start_date": start_date, + "end_date": end_date, + }, ) - + try: # Get the connector await task_logger.log_task_progress( log_entry, f"Retrieving Linear connector {connector_id} from database", - {"stage": "connector_retrieval"} + {"stage": "connector_retrieval"}, ) - + result = await session.execute( - select(SearchSourceConnector) - .filter( + select(SearchSourceConnector).filter( SearchSourceConnector.id == connector_id, - SearchSourceConnector.connector_type == SearchSourceConnectorType.LINEAR_CONNECTOR + SearchSourceConnector.connector_type + == SearchSourceConnectorType.LINEAR_CONNECTOR, ) ) connector = result.scalars().first() - + if not connector: await task_logger.log_task_failure( log_entry, f"Connector with ID {connector_id} not found or is not a Linear connector", "Connector not found", - {"error_type": "ConnectorNotFound"} + {"error_type": "ConnectorNotFound"}, ) - return 0, f"Connector with ID {connector_id} not found or is not a Linear connector" - + return ( + 0, + f"Connector with ID {connector_id} not found or is not a Linear connector", + ) + # Get the Linear token from the connector config linear_token = connector.config.get("LINEAR_API_KEY") if not linear_token: @@ -1043,135 +1241,167 @@ async def index_linear_issues( log_entry, f"Linear API token not found in connector config for connector {connector_id}", "Missing Linear token", - {"error_type": "MissingToken"} + {"error_type": "MissingToken"}, ) return 0, "Linear API token not found in connector config" - + # Initialize Linear client await task_logger.log_task_progress( log_entry, f"Initializing Linear client for connector {connector_id}", - {"stage": "client_initialization"} + {"stage": "client_initialization"}, ) - + linear_client = LinearConnector(token=linear_token) - + # Calculate date range if start_date is None or end_date is None: # Fall back to calculating dates based on last_indexed_at calculated_end_date = datetime.now() - + # Use last_indexed_at as start date if available, otherwise use 365 days ago if connector.last_indexed_at: # Convert dates to be comparable (both timezone-naive) - last_indexed_naive = connector.last_indexed_at.replace(tzinfo=None) if connector.last_indexed_at.tzinfo else connector.last_indexed_at - + last_indexed_naive = ( + connector.last_indexed_at.replace(tzinfo=None) + if connector.last_indexed_at.tzinfo + else connector.last_indexed_at + ) + # Check if last_indexed_at is in the future or after end_date if last_indexed_naive > calculated_end_date: - logger.warning(f"Last indexed date ({last_indexed_naive.strftime('%Y-%m-%d')}) is in the future. Using 365 days ago instead.") + logger.warning( + f"Last indexed date ({last_indexed_naive.strftime('%Y-%m-%d')}) is in the future. Using 365 days ago instead." 
+ ) calculated_start_date = calculated_end_date - timedelta(days=365) else: calculated_start_date = last_indexed_naive - logger.info(f"Using last_indexed_at ({calculated_start_date.strftime('%Y-%m-%d')}) as start date") + logger.info( + f"Using last_indexed_at ({calculated_start_date.strftime('%Y-%m-%d')}) as start date" + ) else: - calculated_start_date = calculated_end_date - timedelta(days=365) # Use 365 days as default - logger.info(f"No last_indexed_at found, using {calculated_start_date.strftime('%Y-%m-%d')} (365 days ago) as start date") - + calculated_start_date = calculated_end_date - timedelta( + days=365 + ) # Use 365 days as default + logger.info( + f"No last_indexed_at found, using {calculated_start_date.strftime('%Y-%m-%d')} (365 days ago) as start date" + ) + # Use calculated dates if not provided - start_date_str = start_date if start_date else calculated_start_date.strftime("%Y-%m-%d") - end_date_str = end_date if end_date else calculated_end_date.strftime("%Y-%m-%d") + start_date_str = ( + start_date if start_date else calculated_start_date.strftime("%Y-%m-%d") + ) + end_date_str = ( + end_date if end_date else calculated_end_date.strftime("%Y-%m-%d") + ) else: # Use provided dates start_date_str = start_date end_date_str = end_date - + logger.info(f"Fetching Linear issues from {start_date_str} to {end_date_str}") - + await task_logger.log_task_progress( log_entry, f"Fetching Linear issues from {start_date_str} to {end_date_str}", - {"stage": "fetch_issues", "start_date": start_date_str, "end_date": end_date_str} + { + "stage": "fetch_issues", + "start_date": start_date_str, + "end_date": end_date_str, + }, ) - + # Get issues within date range try: issues, error = linear_client.get_issues_by_date_range( - start_date=start_date_str, - end_date=end_date_str, - include_comments=True + start_date=start_date_str, end_date=end_date_str, include_comments=True ) - + if error: logger.error(f"Failed to get Linear issues: {error}") - + # Don't treat "No issues found" as an error that should stop indexing if "No issues found" in error: - logger.info("No issues found is not a critical error, continuing with update") + logger.info( + "No issues found is not a critical error, continuing with update" + ) if update_last_indexed: connector.last_indexed_at = datetime.now() await session.commit() - logger.info(f"Updated last_indexed_at to {connector.last_indexed_at} despite no issues found") + logger.info( + f"Updated last_indexed_at to {connector.last_indexed_at} despite no issues found" + ) return 0, None else: return 0, f"Failed to get Linear issues: {error}" - + logger.info(f"Retrieved {len(issues)} issues from Linear API") - + except Exception as e: - logger.error(f"Exception when calling Linear API: {str(e)}", exc_info=True) - return 0, f"Failed to get Linear issues: {str(e)}" - + logger.error(f"Exception when calling Linear API: {e!s}", exc_info=True) + return 0, f"Failed to get Linear issues: {e!s}" + if not issues: logger.info("No Linear issues found for the specified date range") if update_last_indexed: connector.last_indexed_at = datetime.now() await session.commit() - logger.info(f"Updated last_indexed_at to {connector.last_indexed_at} despite no issues found") + logger.info( + f"Updated last_indexed_at to {connector.last_indexed_at} despite no issues found" + ) return 0, None # Return None instead of error message when no issues found - + # Log issue IDs and titles for debugging logger.info("Issues retrieved from Linear API:") for idx, issue in enumerate(issues[:10]): # Log 
first 10 issues - logger.info(f" {idx+1}. {issue.get('identifier', 'Unknown')} - {issue.get('title', 'Unknown')} - Created: {issue.get('createdAt', 'Unknown')} - Updated: {issue.get('updatedAt', 'Unknown')}") + logger.info( + f" {idx + 1}. {issue.get('identifier', 'Unknown')} - {issue.get('title', 'Unknown')} - Created: {issue.get('createdAt', 'Unknown')} - Updated: {issue.get('updatedAt', 'Unknown')}" + ) if len(issues) > 10: logger.info(f" ...and {len(issues) - 10} more issues") - + # Track the number of documents indexed documents_indexed = 0 documents_skipped = 0 skipped_issues = [] - + await task_logger.log_task_progress( log_entry, f"Starting to process {len(issues)} Linear issues", - {"stage": "process_issues", "total_issues": len(issues)} + {"stage": "process_issues", "total_issues": len(issues)}, ) - + # Process each issue for issue in issues: try: issue_id = issue.get("id") issue_identifier = issue.get("identifier", "") issue_title = issue.get("title", "") - + if not issue_id or not issue_title: - logger.warning(f"Skipping issue with missing ID or title: {issue_id or 'Unknown'}") - skipped_issues.append(f"{issue_identifier or 'Unknown'} (missing data)") + logger.warning( + f"Skipping issue with missing ID or title: {issue_id or 'Unknown'}" + ) + skipped_issues.append( + f"{issue_identifier or 'Unknown'} (missing data)" + ) documents_skipped += 1 continue - + # Format the issue first to get well-structured data formatted_issue = linear_client.format_issue(issue) - + # Convert issue to markdown format issue_content = linear_client.format_issue_to_markdown(formatted_issue) - + if not issue_content: - logger.warning(f"Skipping issue with no content: {issue_identifier} - {issue_title}") + logger.warning( + f"Skipping issue with no content: {issue_identifier} - {issue_title}" + ) skipped_issues.append(f"{issue_identifier} (no content)") documents_skipped += 1 continue - + # Create a short summary for the embedding # This avoids using the LLM and just uses the issue data directly state = formatted_issue.get("state", "Unknown") @@ -1179,40 +1409,51 @@ async def index_linear_issues( # Truncate description if it's too long for the summary if description and len(description) > 500: description = description[:497] + "..." - + # Create a simple summary from the issue data summary_content = f"Linear Issue {issue_identifier}: {issue_title}\n\nStatus: {state}\n\n" if description: summary_content += f"Description: {description}\n\n" - + # Add comment count comment_count = len(formatted_issue.get("comments", [])) summary_content += f"Comments: {comment_count}" - + content_hash = generate_content_hash(issue_content, search_space_id) # Check if document with this content hash already exists existing_doc_by_hash_result = await session.execute( select(Document).where(Document.content_hash == content_hash) ) - existing_document_by_hash = existing_doc_by_hash_result.scalars().first() - + existing_document_by_hash = ( + existing_doc_by_hash_result.scalars().first() + ) + if existing_document_by_hash: - logger.info(f"Document with content hash {content_hash} already exists for issue {issue_identifier}. Skipping processing.") + logger.info( + f"Document with content hash {content_hash} already exists for issue {issue_identifier}. Skipping processing." 
+ ) documents_skipped += 1 continue - + # Generate embedding for the summary - summary_embedding = config.embedding_model_instance.embed(summary_content) - + summary_embedding = config.embedding_model_instance.embed( + summary_content + ) + # Process chunks - using the full issue content with comments chunks = [ - Chunk(content=chunk.text, embedding=config.embedding_model_instance.embed(chunk.text)) + Chunk( + content=chunk.text, + embedding=config.embedding_model_instance.embed(chunk.text), + ) for chunk in config.chunker_instance.chunk(issue_content) ] - + # Create and store new document - logger.info(f"Creating new document for issue {issue_identifier} - {issue_title}") + logger.info( + f"Creating new document for issue {issue_identifier} - {issue_title}" + ) document = Document( search_space_id=search_space_id, title=f"Linear - {issue_identifier}: {issue_title}", @@ -1223,34 +1464,41 @@ async def index_linear_issues( "issue_title": issue_title, "state": state, "comment_count": comment_count, - "indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), }, content=summary_content, content_hash=content_hash, embedding=summary_embedding, - chunks=chunks + chunks=chunks, ) - + session.add(document) documents_indexed += 1 - logger.info(f"Successfully indexed new issue {issue_identifier} - {issue_title}") - + logger.info( + f"Successfully indexed new issue {issue_identifier} - {issue_title}" + ) + except Exception as e: - logger.error(f"Error processing issue {issue.get('identifier', 'Unknown')}: {str(e)}", exc_info=True) - skipped_issues.append(f"{issue.get('identifier', 'Unknown')} (processing error)") + logger.error( + f"Error processing issue {issue.get('identifier', 'Unknown')}: {e!s}", + exc_info=True, + ) + skipped_issues.append( + f"{issue.get('identifier', 'Unknown')} (processing error)" + ) documents_skipped += 1 continue # Skip this issue and continue with others - + # Update the last_indexed_at timestamp for the connector only if requested total_processed = documents_indexed if update_last_indexed: connector.last_indexed_at = datetime.now() logger.info(f"Updated last_indexed_at to {connector.last_indexed_at}") - + # Commit all changes await session.commit() - logger.info(f"Successfully committed all Linear document changes to database") - + logger.info("Successfully committed all Linear document changes to database") + # Log success await task_logger.log_task_success( log_entry, @@ -1259,43 +1507,49 @@ async def index_linear_issues( "issues_processed": total_processed, "documents_indexed": documents_indexed, "documents_skipped": documents_skipped, - "skipped_issues_count": len(skipped_issues) - } + "skipped_issues_count": len(skipped_issues), + }, ) - - logger.info(f"Linear indexing completed: {documents_indexed} new issues, {documents_skipped} skipped") - return total_processed, None # Return None as the error message to indicate success - + + logger.info( + f"Linear indexing completed: {documents_indexed} new issues, {documents_skipped} skipped" + ) + return ( + total_processed, + None, + ) # Return None as the error message to indicate success + except SQLAlchemyError as db_error: await session.rollback() await task_logger.log_task_failure( log_entry, f"Database error during Linear indexing for connector {connector_id}", str(db_error), - {"error_type": "SQLAlchemyError"} + {"error_type": "SQLAlchemyError"}, ) - logger.error(f"Database error: {str(db_error)}", exc_info=True) - return 0, f"Database error: 
{str(db_error)}" + logger.error(f"Database error: {db_error!s}", exc_info=True) + return 0, f"Database error: {db_error!s}" except Exception as e: await session.rollback() await task_logger.log_task_failure( log_entry, f"Failed to index Linear issues for connector {connector_id}", str(e), - {"error_type": type(e).__name__} + {"error_type": type(e).__name__}, ) - logger.error(f"Failed to index Linear issues: {str(e)}", exc_info=True) - return 0, f"Failed to index Linear issues: {str(e)}" + logger.error(f"Failed to index Linear issues: {e!s}", exc_info=True) + return 0, f"Failed to index Linear issues: {e!s}" + async def index_discord_messages( session: AsyncSession, connector_id: int, search_space_id: int, user_id: str, - start_date: str = None, - end_date: str = None, - update_last_indexed: bool = True -) -> Tuple[int, Optional[str]]: + start_date: str | None = None, + end_date: str | None = None, + update_last_indexed: bool = True, +) -> tuple[int, str | None]: """ Index Discord messages from all accessible channels. @@ -1309,28 +1563,33 @@ async def index_discord_messages( Tuple containing (number of documents indexed, error message or None) """ task_logger = TaskLoggingService(session, search_space_id) - + # Log task start log_entry = await task_logger.log_task_start( task_name="discord_messages_indexing", source="connector_indexing_task", message=f"Starting Discord messages indexing for connector {connector_id}", - metadata={"connector_id": connector_id, "user_id": str(user_id), "start_date": start_date, "end_date": end_date} + metadata={ + "connector_id": connector_id, + "user_id": str(user_id), + "start_date": start_date, + "end_date": end_date, + }, ) - + try: # Get the connector await task_logger.log_task_progress( log_entry, f"Retrieving Discord connector {connector_id} from database", - {"stage": "connector_retrieval"} + {"stage": "connector_retrieval"}, ) - + result = await session.execute( - select(SearchSourceConnector) - .filter( + select(SearchSourceConnector).filter( SearchSourceConnector.id == connector_id, - SearchSourceConnector.connector_type == SearchSourceConnectorType.DISCORD_CONNECTOR + SearchSourceConnector.connector_type + == SearchSourceConnectorType.DISCORD_CONNECTOR, ) ) connector = result.scalars().first() @@ -1340,9 +1599,12 @@ async def index_discord_messages( log_entry, f"Connector with ID {connector_id} not found or is not a Discord connector", "Connector not found", - {"error_type": "ConnectorNotFound"} + {"error_type": "ConnectorNotFound"}, + ) + return ( + 0, + f"Connector with ID {connector_id} not found or is not a Discord connector", ) - return 0, f"Connector with ID {connector_id} not found or is not a Discord connector" # Get the Discord token from the connector config discord_token = connector.config.get("DISCORD_BOT_TOKEN") @@ -1351,7 +1613,7 @@ async def index_discord_messages( log_entry, f"Discord token not found in connector config for connector {connector_id}", "Missing Discord token", - {"error_type": "MissingToken"} + {"error_type": "MissingToken"}, ) return 0, "Discord token not found in connector config" @@ -1361,42 +1623,62 @@ async def index_discord_messages( await task_logger.log_task_progress( log_entry, f"Initializing Discord client for connector {connector_id}", - {"stage": "client_initialization"} + {"stage": "client_initialization"}, ) - + discord_client = DiscordConnector(token=discord_token) # Calculate date range if start_date is None or end_date is None: # Fall back to calculating dates based on last_indexed_at - 
calculated_end_date = datetime.now(timezone.utc) + calculated_end_date = datetime.now(UTC) # Use last_indexed_at as start date if available, otherwise use 365 days ago if connector.last_indexed_at: - calculated_start_date = connector.last_indexed_at.replace(tzinfo=timezone.utc) - logger.info(f"Using last_indexed_at ({calculated_start_date.strftime('%Y-%m-%d')}) as start date") + calculated_start_date = connector.last_indexed_at.replace(tzinfo=UTC) + logger.info( + f"Using last_indexed_at ({calculated_start_date.strftime('%Y-%m-%d')}) as start date" + ) else: calculated_start_date = calculated_end_date - timedelta(days=365) - logger.info(f"No last_indexed_at found, using {calculated_start_date.strftime('%Y-%m-%d')} (365 days ago) as start date") + logger.info( + f"No last_indexed_at found, using {calculated_start_date.strftime('%Y-%m-%d')} (365 days ago) as start date" + ) # Use calculated dates if not provided, convert to ISO format for Discord API if start_date is None: start_date_iso = calculated_start_date.isoformat() else: # Convert YYYY-MM-DD to ISO format - start_date_iso = datetime.strptime(start_date, "%Y-%m-%d").replace(tzinfo=timezone.utc).isoformat() - + start_date_iso = ( + datetime.strptime(start_date, "%Y-%m-%d") + .replace(tzinfo=UTC) + .isoformat() + ) + if end_date is None: end_date_iso = calculated_end_date.isoformat() else: - # Convert YYYY-MM-DD to ISO format - end_date_iso = datetime.strptime(end_date, "%Y-%m-%d").replace(tzinfo=timezone.utc).isoformat() + # Convert YYYY-MM-DD to ISO format + end_date_iso = ( + datetime.strptime(end_date, "%Y-%m-%d") + .replace(tzinfo=UTC) + .isoformat() + ) else: # Convert provided dates to ISO format for Discord API - start_date_iso = datetime.strptime(start_date, "%Y-%m-%d").replace(tzinfo=timezone.utc).isoformat() - end_date_iso = datetime.strptime(end_date, "%Y-%m-%d").replace(tzinfo=timezone.utc).isoformat() - - logger.info(f"Indexing Discord messages from {start_date_iso} to {end_date_iso}") + start_date_iso = ( + datetime.strptime(start_date, "%Y-%m-%d") + .replace(tzinfo=UTC) + .isoformat() + ) + end_date_iso = ( + datetime.strptime(end_date, "%Y-%m-%d").replace(tzinfo=UTC).isoformat() + ) + + logger.info( + f"Indexing Discord messages from {start_date_iso} to {end_date_iso}" + ) documents_indexed = 0 documents_skipped = 0 @@ -1406,9 +1688,9 @@ async def index_discord_messages( await task_logger.log_task_progress( log_entry, f"Starting Discord bot and fetching guilds for connector {connector_id}", - {"stage": "fetch_guilds"} + {"stage": "fetch_guilds"}, ) - + logger.info("Starting Discord bot to fetch guilds") discord_client._bot_task = asyncio.create_task(discord_client.start_bot()) await discord_client._wait_until_ready() @@ -1421,16 +1703,16 @@ async def index_discord_messages( log_entry, f"Failed to get Discord guilds for connector {connector_id}", str(e), - {"error_type": "GuildFetchError"} + {"error_type": "GuildFetchError"}, ) - logger.error(f"Failed to get Discord guilds: {str(e)}", exc_info=True) + logger.error(f"Failed to get Discord guilds: {e!s}", exc_info=True) await discord_client.close_bot() - return 0, f"Failed to get Discord guilds: {str(e)}" + return 0, f"Failed to get Discord guilds: {e!s}" if not guilds: await task_logger.log_task_success( log_entry, f"No Discord guilds found for connector {connector_id}", - {"guilds_found": 0} + {"guilds_found": 0}, ) logger.info("No Discord guilds found to index") await discord_client.close_bot() @@ -1440,9 +1722,9 @@ async def index_discord_messages( await 
task_logger.log_task_progress( log_entry, f"Starting to process {len(guilds)} Discord guilds", - {"stage": "process_guilds", "total_guilds": len(guilds)} + {"stage": "process_guilds", "total_guilds": len(guilds)}, ) - + for guild in guilds: guild_id = guild["id"] guild_name = guild["name"] @@ -1466,13 +1748,19 @@ async def index_discord_messages( end_date=end_date_iso, ) except Exception as e: - logger.error(f"Failed to get messages for channel {channel_name}: {str(e)}") - skipped_channels.append(f"{guild_name}#{channel_name} (fetch error)") + logger.error( + f"Failed to get messages for channel {channel_name}: {e!s}" + ) + skipped_channels.append( + f"{guild_name}#{channel_name} (fetch error)" + ) documents_skipped += 1 continue if not messages: - logger.info(f"No messages found in channel {channel_name} for the specified date range.") + logger.info( + f"No messages found in channel {channel_name} for the specified date range." + ) documents_skipped += 1 continue @@ -1485,33 +1773,45 @@ async def index_discord_messages( formatted_messages.append(msg) if not formatted_messages: - logger.info(f"No valid messages found in channel {channel_name} after filtering.") + logger.info( + f"No valid messages found in channel {channel_name} after filtering." + ) documents_skipped += 1 continue # Convert messages to markdown format - channel_content = f"# Discord Channel: {guild_name} / {channel_name}\n\n" + channel_content = ( + f"# Discord Channel: {guild_name} / {channel_name}\n\n" + ) for msg in formatted_messages: user_name = msg.get("author_name", "Unknown User") timestamp = msg.get("created_at", "Unknown Time") text = msg.get("content", "") - channel_content += f"## {user_name} ({timestamp})\n\n{text}\n\n---\n\n" + channel_content += ( + f"## {user_name} ({timestamp})\n\n{text}\n\n---\n\n" + ) # Format document metadata metadata_sections = [ - ("METADATA", [ - f"GUILD_NAME: {guild_name}", - f"GUILD_ID: {guild_id}", - f"CHANNEL_NAME: {channel_name}", - f"CHANNEL_ID: {channel_id}", - f"MESSAGE_COUNT: {len(formatted_messages)}" - ]), - ("CONTENT", [ - "FORMAT: markdown", - "TEXT_START", - channel_content, - "TEXT_END" - ]) + ( + "METADATA", + [ + f"GUILD_NAME: {guild_name}", + f"GUILD_ID: {guild_id}", + f"CHANNEL_NAME: {channel_name}", + f"CHANNEL_ID: {channel_id}", + f"MESSAGE_COUNT: {len(formatted_messages)}", + ], + ), + ( + "CONTENT", + [ + "FORMAT: markdown", + "TEXT_START", + channel_content, + "TEXT_END", + ], + ), ] # Build the document string @@ -1522,31 +1822,43 @@ async def index_discord_messages( document_parts.extend(section_content) document_parts.append(f"") document_parts.append("") - combined_document_string = '\n'.join(document_parts) - content_hash = generate_content_hash(combined_document_string, search_space_id) + combined_document_string = "\n".join(document_parts) + content_hash = generate_content_hash( + combined_document_string, search_space_id + ) # Check if document with this content hash already exists existing_doc_by_hash_result = await session.execute( select(Document).where(Document.content_hash == content_hash) ) - existing_document_by_hash = existing_doc_by_hash_result.scalars().first() + existing_document_by_hash = ( + existing_doc_by_hash_result.scalars().first() + ) if existing_document_by_hash: - logger.info(f"Document with content hash {content_hash} already exists for channel {guild_name}#{channel_name}. Skipping processing.") + logger.info( + f"Document with content hash {content_hash} already exists for channel {guild_name}#{channel_name}. 
Skipping processing." + ) documents_skipped += 1 continue # Get user's long context LLM user_llm = await get_user_long_context_llm(session, user_id) if not user_llm: - logger.error(f"No long context LLM configured for user {user_id}") - skipped_channels.append(f"{guild_name}#{channel_name} (no LLM configured)") + logger.error( + f"No long context LLM configured for user {user_id}" + ) + skipped_channels.append( + f"{guild_name}#{channel_name} (no LLM configured)" + ) documents_skipped += 1 continue # Generate summary using summary_chain summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm - summary_result = await summary_chain.ainvoke({"document": combined_document_string}) + summary_result = await summary_chain.ainvoke( + {"document": combined_document_string} + ) summary_content = summary_result.content summary_embedding = await asyncio.to_thread( config.embedding_model_instance.embed, summary_content @@ -1554,19 +1866,24 @@ async def index_discord_messages( # Process chunks raw_chunks = await asyncio.to_thread( - config.chunker_instance.chunk, - channel_content + config.chunker_instance.chunk, channel_content ) - chunk_texts = [chunk.text for chunk in raw_chunks if chunk.text.strip()] + chunk_texts = [ + chunk.text for chunk in raw_chunks if chunk.text.strip() + ] chunk_embeddings = await asyncio.to_thread( - lambda texts: [config.embedding_model_instance.embed(t) for t in texts], - chunk_texts + lambda texts: [ + config.embedding_model_instance.embed(t) for t in texts + ], + chunk_texts, ) chunks = [ Chunk(content=raw_chunk.text, embedding=embedding) - for raw_chunk, embedding in zip(raw_chunks, chunk_embeddings) + for raw_chunk, embedding in zip( + raw_chunks, chunk_embeddings, strict=False + ) ] # Create and store new document @@ -1582,26 +1899,32 @@ async def index_discord_messages( "message_count": len(formatted_messages), "start_date": start_date_iso, "end_date": end_date_iso, - "indexed_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S") + "indexed_at": datetime.now(UTC).strftime( + "%Y-%m-%d %H:%M:%S" + ), }, content=summary_content, content_hash=content_hash, embedding=summary_embedding, - chunks=chunks + chunks=chunks, ) session.add(document) documents_indexed += 1 - logger.info(f"Successfully indexed new channel {guild_name}#{channel_name} with {len(formatted_messages)} messages") + logger.info( + f"Successfully indexed new channel {guild_name}#{channel_name} with {len(formatted_messages)} messages" + ) except Exception as e: - logger.error(f"Error processing guild {guild_name}: {str(e)}", exc_info=True) + logger.error( + f"Error processing guild {guild_name}: {e!s}", exc_info=True + ) skipped_channels.append(f"{guild_name} (processing error)") documents_skipped += 1 continue if update_last_indexed and documents_indexed > 0: - connector.last_indexed_at = datetime.now(timezone.utc) + connector.last_indexed_at = datetime.now(UTC) logger.info(f"Updated last_indexed_at to {connector.last_indexed_at}") await session.commit() @@ -1624,11 +1947,13 @@ async def index_discord_messages( "documents_skipped": documents_skipped, "skipped_channels_count": len(skipped_channels), "guilds_processed": len(guilds), - "result_message": result_message - } + "result_message": result_message, + }, ) - logger.info(f"Discord indexing completed: {documents_indexed} new channels, {documents_skipped} skipped") + logger.info( + f"Discord indexing completed: {documents_indexed} new channels, {documents_skipped} skipped" + ) return documents_indexed, result_message except SQLAlchemyError as 
db_error: @@ -1637,17 +1962,19 @@ async def index_discord_messages( log_entry, f"Database error during Discord indexing for connector {connector_id}", str(db_error), - {"error_type": "SQLAlchemyError"} + {"error_type": "SQLAlchemyError"}, ) - logger.error(f"Database error during Discord indexing: {str(db_error)}", exc_info=True) - return 0, f"Database error: {str(db_error)}" + logger.error( + f"Database error during Discord indexing: {db_error!s}", exc_info=True + ) + return 0, f"Database error: {db_error!s}" except Exception as e: await session.rollback() await task_logger.log_task_failure( log_entry, f"Failed to index Discord messages for connector {connector_id}", str(e), - {"error_type": type(e).__name__} + {"error_type": type(e).__name__}, ) - logger.error(f"Failed to index Discord messages: {str(e)}", exc_info=True) - return 0, f"Failed to index Discord messages: {str(e)}" + logger.error(f"Failed to index Discord messages: {e!s}", exc_info=True) + return 0, f"Failed to index Discord messages: {e!s}" diff --git a/surfsense_backend/app/tasks/podcast_tasks.py b/surfsense_backend/app/tasks/podcast_tasks.py index f4907af..6bc3510 100644 --- a/surfsense_backend/app/tasks/podcast_tasks.py +++ b/surfsense_backend/app/tasks/podcast_tasks.py @@ -1,33 +1,29 @@ +from sqlalchemy import select +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.ext.asyncio import AsyncSession from app.agents.podcaster.graph import graph as podcaster_graph from app.agents.podcaster.state import State from app.db import Chat, Podcast from app.services.task_logging_service import TaskLoggingService -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.exc import SQLAlchemyError async def generate_document_podcast( - session: AsyncSession, - document_id: int, - search_space_id: int, - user_id: int + session: AsyncSession, document_id: int, search_space_id: int, user_id: int ): # TODO: Need to fetch the document chunks, then concatenate them and pass them to the podcast generation model pass - async def generate_chat_podcast( session: AsyncSession, chat_id: int, search_space_id: int, podcast_title: str, - user_id: int + user_id: int, ): task_logger = TaskLoggingService(session, search_space_id) - + # Log task start log_entry = await task_logger.log_task_start( task_name="generate_chat_podcast", @@ -37,44 +33,43 @@ async def generate_chat_podcast( "chat_id": chat_id, "search_space_id": search_space_id, "podcast_title": podcast_title, - "user_id": str(user_id) - } + "user_id": str(user_id), + }, ) - + try: # Fetch the chat with the specified ID await task_logger.log_task_progress( - log_entry, - f"Fetching chat {chat_id} from database", - {"stage": "fetch_chat"} + log_entry, f"Fetching chat {chat_id} from database", {"stage": "fetch_chat"} ) - + query = select(Chat).filter( - Chat.id == chat_id, - Chat.search_space_id == search_space_id + Chat.id == chat_id, Chat.search_space_id == search_space_id ) - + result = await session.execute(query) chat = result.scalars().first() - + if not chat: await task_logger.log_task_failure( log_entry, f"Chat with id {chat_id} not found in search space {search_space_id}", "Chat not found", - {"error_type": "ChatNotFound"} + {"error_type": "ChatNotFound"}, ) - raise ValueError(f"Chat with id {chat_id} not found in search space {search_space_id}") - + raise ValueError( + f"Chat with id {chat_id} not found in search space {search_space_id}" + ) + # Create chat history structure await task_logger.log_task_progress( log_entry, f"Processing 
chat history for chat {chat_id}", - {"stage": "process_chat_history", "message_count": len(chat.messages)} + {"stage": "process_chat_history", "message_count": len(chat.messages)}, ) - + chat_history_str = "" - + processed_messages = 0 for message in chat.messages: if message["role"] == "user": @@ -89,18 +84,24 @@ async def generate_chat_podcast( # If content is a list, join it into a single string if isinstance(answer_text, list): answer_text = "\n".join(answer_text) - chat_history_str += f"{answer_text}" + chat_history_str += ( + f"{answer_text}" + ) processed_messages += 1 - + chat_history_str += "" - + # Pass it to the SurfSense Podcaster await task_logger.log_task_progress( log_entry, f"Initializing podcast generation for chat {chat_id}", - {"stage": "initialize_podcast_generation", "processed_messages": processed_messages, "content_length": len(chat_history_str)} + { + "stage": "initialize_podcast_generation", + "processed_messages": processed_messages, + "content_length": len(chat_history_str), + }, ) - + config = { "configurable": { "podcast_title": "SurfSense", @@ -108,53 +109,55 @@ async def generate_chat_podcast( } } # Initialize state with database session and streaming service - initial_state = State( - source_content=chat_history_str, - db_session=session - ) - + initial_state = State(source_content=chat_history_str, db_session=session) + # Run the graph directly await task_logger.log_task_progress( log_entry, f"Running podcast generation graph for chat {chat_id}", - {"stage": "run_podcast_graph"} + {"stage": "run_podcast_graph"}, ) - + result = await podcaster_graph.ainvoke(initial_state, config=config) - + # Convert podcast transcript entries to serializable format await task_logger.log_task_progress( log_entry, f"Processing podcast transcript for chat {chat_id}", - {"stage": "process_transcript", "transcript_entries": len(result["podcast_transcript"])} + { + "stage": "process_transcript", + "transcript_entries": len(result["podcast_transcript"]), + }, ) - + serializable_transcript = [] for entry in result["podcast_transcript"]: - serializable_transcript.append({ - "speaker_id": entry.speaker_id, - "dialog": entry.dialog - }) - + serializable_transcript.append( + {"speaker_id": entry.speaker_id, "dialog": entry.dialog} + ) + # Create a new podcast entry await task_logger.log_task_progress( log_entry, f"Creating podcast database entry for chat {chat_id}", - {"stage": "create_podcast_entry", "file_location": result.get("final_podcast_file_path")} + { + "stage": "create_podcast_entry", + "file_location": result.get("final_podcast_file_path"), + }, ) - + podcast = Podcast( title=f"{podcast_title}", podcast_transcript=serializable_transcript, file_location=result["final_podcast_file_path"], - search_space_id=search_space_id + search_space_id=search_space_id, ) - + # Add to session and commit session.add(podcast) await session.commit() await session.refresh(podcast) - + # Log success await task_logger.log_task_success( log_entry, @@ -165,10 +168,10 @@ async def generate_chat_podcast( "transcript_entries": len(serializable_transcript), "file_location": result.get("final_podcast_file_path"), "processed_messages": processed_messages, - "content_length": len(chat_history_str) - } + "content_length": len(chat_history_str), + }, ) - + return podcast except ValueError as ve: @@ -178,7 +181,7 @@ async def generate_chat_podcast( log_entry, f"Value error during podcast generation for chat {chat_id}", str(ve), - {"error_type": "ValueError"} + {"error_type": "ValueError"}, ) raise ve except 
SQLAlchemyError as db_error: @@ -187,7 +190,7 @@ async def generate_chat_podcast( log_entry, f"Database error during podcast generation for chat {chat_id}", str(db_error), - {"error_type": "SQLAlchemyError"} + {"error_type": "SQLAlchemyError"}, ) raise db_error except Exception as e: @@ -196,7 +199,8 @@ async def generate_chat_podcast( log_entry, f"Unexpected error during podcast generation for chat {chat_id}", str(e), - {"error_type": type(e).__name__} + {"error_type": type(e).__name__}, ) - raise RuntimeError(f"Failed to generate podcast for chat {chat_id}: {str(e)}") - + raise RuntimeError( + f"Failed to generate podcast for chat {chat_id}: {e!s}" + ) from e diff --git a/surfsense_backend/app/tasks/stream_connector_search_results.py b/surfsense_backend/app/tasks/stream_connector_search_results.py index f66bf1a..ead6a89 100644 --- a/surfsense_backend/app/tasks/stream_connector_search_results.py +++ b/surfsense_backend/app/tasks/stream_connector_search_results.py @@ -1,28 +1,29 @@ -from typing import Any, AsyncGenerator, List, Union +from collections.abc import AsyncGenerator +from typing import Any from uuid import UUID -from app.agents.researcher.graph import graph as researcher_graph -from app.agents.researcher.state import State -from app.services.streaming_service import StreamingService from sqlalchemy.ext.asyncio import AsyncSession from app.agents.researcher.configuration import SearchMode +from app.agents.researcher.graph import graph as researcher_graph +from app.agents.researcher.state import State +from app.services.streaming_service import StreamingService async def stream_connector_search_results( - user_query: str, - user_id: Union[str, UUID], - search_space_id: int, - session: AsyncSession, - research_mode: str, - selected_connectors: List[str], - langchain_chat_history: List[Any], + user_query: str, + user_id: str | UUID, + search_space_id: int, + session: AsyncSession, + research_mode: str, + selected_connectors: list[str], + langchain_chat_history: list[Any], search_mode_str: str, - document_ids_to_add_in_context: List[int] + document_ids_to_add_in_context: list[int], ) -> AsyncGenerator[str, None]: """ Stream connector search results to the client - + Args: user_query: The user's query user_id: The user's ID (can be UUID object or string) @@ -30,61 +31,60 @@ async def stream_connector_search_results( session: The database session research_mode: The research mode selected_connectors: List of selected connectors - + Yields: str: Formatted response strings """ streaming_service = StreamingService() - + if research_mode == "REPORT_GENERAL": - NUM_SECTIONS = 1 + num_sections = 1 elif research_mode == "REPORT_DEEP": - NUM_SECTIONS = 3 + num_sections = 3 elif research_mode == "REPORT_DEEPER": - NUM_SECTIONS = 6 + num_sections = 6 else: # Default fallback - NUM_SECTIONS = 1 - + num_sections = 1 + # Convert UUID to string if needed user_id_str = str(user_id) if isinstance(user_id, UUID) else user_id - + if search_mode_str == "CHUNKS": search_mode = SearchMode.CHUNKS elif search_mode_str == "DOCUMENTS": search_mode = SearchMode.DOCUMENTS - + # Sample configuration config = { "configurable": { "user_query": user_query, - "num_sections": NUM_SECTIONS, + "num_sections": num_sections, "connectors_to_search": selected_connectors, "user_id": user_id_str, "search_space_id": search_space_id, "search_mode": search_mode, "research_mode": research_mode, - "document_ids_to_add_in_context": document_ids_to_add_in_context + "document_ids_to_add_in_context": document_ids_to_add_in_context, } } 
# Initialize state with database session and streaming service initial_state = State( db_session=session, streaming_service=streaming_service, - chat_history=langchain_chat_history + chat_history=langchain_chat_history, ) - + # Run the graph directly print("\nRunning the complete researcher workflow...") - + # Use streaming with config parameter async for chunk in researcher_graph.astream( initial_state, config=config, stream_mode="custom", ): - if isinstance(chunk, dict): - if "yield_value" in chunk: - yield chunk["yield_value"] + if isinstance(chunk, dict) and "yield_value" in chunk: + yield chunk["yield_value"] yield streaming_service.format_completion() diff --git a/surfsense_backend/app/users.py b/surfsense_backend/app/users.py index d73baae..54db064 100644 --- a/surfsense_backend/app/users.py +++ b/surfsense_backend/app/users.py @@ -1,8 +1,7 @@ -from typing import Optional import uuid from fastapi import Depends, Request, Response -from fastapi.responses import RedirectResponse +from fastapi.responses import JSONResponse, RedirectResponse from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models from fastapi_users.authentication import ( AuthenticationBackend, @@ -10,21 +9,23 @@ from fastapi_users.authentication import ( JWTStrategy, ) from fastapi_users.db import SQLAlchemyUserDatabase -from fastapi.responses import JSONResponse from fastapi_users.schemas import model_dump +from pydantic import BaseModel + from app.config import config from app.db import User, get_user_db -from pydantic import BaseModel + class BearerResponse(BaseModel): access_token: str token_type: str + SECRET = config.SECRET_KEY if config.AUTH_TYPE == "GOOGLE": from httpx_oauth.clients.google import GoogleOAuth2 - + google_oauth_client = GoogleOAuth2( config.GOOGLE_OAUTH_CLIENT_ID, config.GOOGLE_OAUTH_CLIENT_SECRET, @@ -35,27 +36,26 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): reset_password_token_secret = SECRET verification_token_secret = SECRET - async def on_after_register(self, user: User, request: Optional[Request] = None): + async def on_after_register(self, user: User, request: Request | None = None): print(f"User {user.id} has registered.") async def on_after_forgot_password( - self, user: User, token: str, request: Optional[Request] = None + self, user: User, token: str, request: Request | None = None ): print(f"User {user.id} has forgot their password. Reset token: {token}") async def on_after_request_verify( - self, user: User, token: str, request: Optional[Request] = None + self, user: User, token: str, request: Request | None = None ): - print( - f"Verification requested for user {user.id}. Verification token: {token}") + print(f"Verification requested for user {user.id}. Verification token: {token}") async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)): yield UserManager(user_db) - - + + def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]: - return JWTStrategy(secret=SECRET, lifetime_seconds=3600*24) + return JWTStrategy(secret=SECRET, lifetime_seconds=3600 * 24) # # COOKIE AUTH | Uncomment if you want to use cookie auth. @@ -77,6 +77,7 @@ def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]: # get_strategy=get_jwt_strategy, # ) + # BEARER AUTH CODE. 
class CustomBearerTransport(BearerTransport): async def get_login_response(self, token: str) -> Response: @@ -87,6 +88,7 @@ class CustomBearerTransport(BearerTransport): else: return JSONResponse(model_dump(bearer_response)) + bearer_transport = CustomBearerTransport(tokenUrl="auth/jwt/login") @@ -98,4 +100,4 @@ auth_backend = AuthenticationBackend( fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend]) -current_active_user = fastapi_users.current_user(active=True) \ No newline at end of file +current_active_user = fastapi_users.current_user(active=True) diff --git a/surfsense_backend/app/utils/check_ownership.py b/surfsense_backend/app/utils/check_ownership.py index 3ea21c2..0bd290f 100644 --- a/surfsense_backend/app/utils/check_ownership.py +++ b/surfsense_backend/app/utils/check_ownership.py @@ -1,12 +1,19 @@ from fastapi import HTTPException from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select + from app.db import User + # Helper function to check user ownership async def check_ownership(session: AsyncSession, model, item_id: int, user: User): - item = await session.execute(select(model).filter(model.id == item_id, model.user_id == user.id)) + item = await session.execute( + select(model).filter(model.id == item_id, model.user_id == user.id) + ) item = item.scalars().first() if not item: - raise HTTPException(status_code=404, detail="Item not found or you don't have permission to access it") - return item \ No newline at end of file + raise HTTPException( + status_code=404, + detail="Item not found or you don't have permission to access it", + ) + return item diff --git a/surfsense_backend/app/utils/document_converters.py b/surfsense_backend/app/utils/document_converters.py index ab8ba4d..3b23f54 100644 --- a/surfsense_backend/app/utils/document_converters.py +++ b/surfsense_backend/app/utils/document_converters.py @@ -32,7 +32,7 @@ async def convert_element_to_markdown(element) -> str: "Footer": lambda x: f"*{x}*\n\n", "CodeSnippet": lambda x: f"```\n{x}\n```", "PageNumber": lambda x: f"*Page {x}*\n\n", - "UncategorizedText": lambda x: f"{x}\n\n" + "UncategorizedText": lambda x: f"{x}\n\n", } converter = markdown_mapping.get(element_category, lambda x: x) @@ -74,7 +74,7 @@ def convert_chunks_to_langchain_documents(chunks): except ImportError: raise ImportError( "LangChain is not installed. 
Please install it with `pip install langchain langchain-core`" - ) + ) from None langchain_docs = [] @@ -92,17 +92,20 @@ def convert_chunks_to_langchain_documents(chunks): # Add document information to metadata if "document" in chunk: doc = chunk["document"] - metadata.update({ - "document_id": doc.get("id"), - "document_title": doc.get("title"), - "document_type": doc.get("document_type"), - }) + metadata.update( + { + "document_id": doc.get("id"), + "document_title": doc.get("title"), + "document_type": doc.get("document_type"), + } + ) # Add document metadata if available if "metadata" in doc: # Prefix document metadata keys to avoid conflicts - doc_metadata = {f"doc_meta_{k}": v for k, - v in doc.get("metadata", {}).items()} + doc_metadata = { + f"doc_meta_{k}": v for k, v in doc.get("metadata", {}).items() + } metadata.update(doc_metadata) # Add source URL if available in metadata @@ -131,10 +134,7 @@ def convert_chunks_to_langchain_documents(chunks): """ # Create LangChain Document - langchain_doc = LangChainDocument( - page_content=new_content, - metadata=metadata - ) + langchain_doc = LangChainDocument(page_content=new_content, metadata=metadata) langchain_docs.append(langchain_doc) @@ -144,4 +144,4 @@ def convert_chunks_to_langchain_documents(chunks): def generate_content_hash(content: str, search_space_id: int) -> str: """Generate SHA-256 hash for the given content combined with search space ID.""" combined_data = f"{search_space_id}:{content}" - return hashlib.sha256(combined_data.encode('utf-8')).hexdigest() + return hashlib.sha256(combined_data.encode("utf-8")).hexdigest() diff --git a/surfsense_backend/main.py b/surfsense_backend/main.py index f44b7d6..6a86bbd 100644 --- a/surfsense_backend/main.py +++ b/surfsense_backend/main.py @@ -1,7 +1,9 @@ -import uvicorn import argparse import logging + +import uvicorn from dotenv import load_dotenv + from app.config.uvicorn import load_uvicorn_config logging.basicConfig( diff --git a/surfsense_backend/pyproject.toml b/surfsense_backend/pyproject.toml index c37eda1..21e86da 100644 --- a/surfsense_backend/pyproject.toml +++ b/surfsense_backend/pyproject.toml @@ -103,6 +103,8 @@ ignore = [ "RUF012", # Mutable class attributes should be annotated with `typing.ClassVar` ] +extend-select = ["I"] + # Allow fix for all enabled rules (when `--fix`) is provided. fixable = ["ALL"] unfixable = [] @@ -123,6 +125,7 @@ skip-magic-trailing-comma = false # Automatically detect the appropriate line ending. line-ending = "auto" + [tool.ruff.lint.isort] # Group imports by type known-first-party = ["app"]