From ebfc2cb6798f87d98c779793673aaf9f2ad0f9aa Mon Sep 17 00:00:00 2001 From: Muhamad Aji Wibisono Date: Mon, 2 Jun 2025 21:04:13 +0700 Subject: [PATCH] feat: optimized discord indexing by matching the document pattern --- .../app/tasks/connectors_indexing_tasks.py | 199 +++++++++++------- 1 file changed, 126 insertions(+), 73 deletions(-) diff --git a/surfsense_backend/app/tasks/connectors_indexing_tasks.py b/surfsense_backend/app/tasks/connectors_indexing_tasks.py index 9904573..0331725 100644 --- a/surfsense_backend/app/tasks/connectors_indexing_tasks.py +++ b/surfsense_backend/app/tasks/connectors_indexing_tasks.py @@ -12,7 +12,6 @@ from app.connectors.notion_history import NotionHistoryConnector from app.connectors.github_connector import GitHubConnector from app.connectors.linear_connector import LinearConnector from app.connectors.discord_connector import DiscordConnector -from discord import DiscordException from slack_sdk.errors import SlackApiError import logging import asyncio @@ -924,13 +923,13 @@ async def index_discord_messages( ) -> Tuple[int, Optional[str]]: """ Index Discord messages from all accessible channels. - + Args: session: Database session connector_id: ID of the Discord connector search_space_id: ID of the search space to store documents in update_last_indexed: Whether to update the last_indexed_at timestamp (default: True) - + Returns: Tuple containing (number of documents indexed, error message or None) """ @@ -944,39 +943,39 @@ async def index_discord_messages( ) ) connector = result.scalars().first() - + if not connector: return 0, f"Connector with ID {connector_id} not found or is not a Discord connector" - + # Get the Discord token from the connector config discord_token = connector.config.get("DISCORD_BOT_TOKEN") if not discord_token: return 0, "Discord token not found in connector config" - + logger.info(f"Starting Discord indexing for connector {connector_id}") # Initialize Discord client discord_client = DiscordConnector(token=discord_token) - + # Calculate date range end_date = datetime.now(timezone.utc) - + # Use last_indexed_at as start date if available, otherwise use 365 days ago if connector.last_indexed_at: start_date = connector.last_indexed_at.replace(tzinfo=timezone.utc) logger.info(f"Using last_indexed_at ({start_date.strftime('%Y-%m-%d')}) as start date") else: - start_date = end_date - timedelta(days=365) # Use 365 days as default + start_date = end_date - timedelta(days=365) logger.info(f"No last_indexed_at found, using {start_date.strftime('%Y-%m-%d')} (365 days ago) as start date") - + # Format dates for Discord API start_date_str = start_date.isoformat() end_date_str = end_date.isoformat() - + documents_indexed = 0 documents_skipped = 0 - skipped_guilds = [] - + skipped_channels = [] + try: logger.info("Starting Discord bot to fetch guilds") discord_client._bot_task = asyncio.create_task(discord_client.start_bot()) @@ -987,33 +986,30 @@ async def index_discord_messages( logger.info(f"Found {len(guilds)} guilds") except Exception as e: logger.error(f"Failed to get Discord guilds: {str(e)}", exc_info=True) - await discord_client.close_bot() return 0, f"Failed to get Discord guilds: {str(e)}" if not guilds: logger.info("No Discord guilds found to index") - await discord_client.close_bot() return 0, "No Discord guilds found" - - # Process each guild + + # Process each guild and channel for guild in guilds: guild_id = guild["id"] guild_name = guild["name"] logger.info(f"Processing guild: {guild_name} ({guild_id})") try: channels = await 
discord_client.get_text_channels(guild_id) - if not channels: logger.info(f"No channels found in guild {guild_name}. Skipping.") - skipped_guilds.append(f"{guild_name} (no channels)") + skipped_channels.append(f"{guild_name} (no channels)") documents_skipped += 1 continue for channel in channels: channel_id = channel["id"] channel_name = channel["name"] - + try: messages = await discord_client.get_channel_history( channel_id=channel_id, @@ -1022,66 +1018,115 @@ async def index_discord_messages( ) except Exception as e: logger.error(f"Failed to get messages for channel {channel_name}: {str(e)}") + skipped_channels.append(f"{guild_name}#{channel_name} (fetch error)") documents_skipped += 1 continue if not messages: + logger.info(f"No messages found in channel {channel_name} for the specified date range.") + documents_skipped += 1 continue - - for message in messages: - try: - content = message.get("content", "") - if not content: - continue - content_hash = generate_content_hash(content) - existing_doc_by_hash_result = await session.execute( - select(Document).where(Document.content_hash == content_hash) - ) - existing_document_by_hash = existing_doc_by_hash_result.scalars().first() - - if existing_document_by_hash: - documents_skipped += 1 - continue - - summary_content = f"Discord message by {message.get('author_name', 'Unknown')} in {channel_name} ({guild_name})\n\n{content}" - summary_embedding = config.embedding_model_instance.embed(summary_content) - chunks = [ - Chunk(content=chunk.text, embedding=config.embedding_model_instance.embed(chunk.text)) - for chunk in config.chunker_instance.chunk(content) - ] - document = Document( - search_space_id=search_space_id, - title=f"Discord - {guild_name}#{channel_name}", - document_type=DocumentType.DISCORD_CONNECTOR, - document_metadata={ - "guild_id": guild_id, - "guild_name": guild_name, - "channel_id": channel_id, - "channel_name": channel_name, - "message_id": message.get("id"), - "author_id": message.get("author_id"), - "author_name": message.get("author_name"), - "created_at": message.get("created_at"), - "indexed_at": datetime.now(timezone.utc).isoformat() - }, - content=summary_content, - content_hash=content_hash, - embedding=summary_embedding, - chunks=chunks - ) - - session.add(document) - documents_indexed += 1 - - except Exception as e: - logger.error(f"Error processing Discord message: {str(e)}", exc_info=True) - documents_skipped += 1 + # Format messages + formatted_messages = [] + for msg in messages: + # Skip system messages if needed (Discord has some types) + if msg.get("type") in ["system"]: continue + formatted_messages.append(msg) + + if not formatted_messages: + logger.info(f"No valid messages found in channel {channel_name} after filtering.") + documents_skipped += 1 + continue + + # Convert messages to markdown format + channel_content = f"# Discord Channel: {guild_name} / {channel_name}\n\n" + for msg in formatted_messages: + user_name = msg.get("author_name", "Unknown User") + timestamp = msg.get("created_at", "Unknown Time") + text = msg.get("content", "") + channel_content += f"## {user_name} ({timestamp})\n\n{text}\n\n---\n\n" + + # Format document metadata + metadata_sections = [ + ("METADATA", [ + f"GUILD_NAME: {guild_name}", + f"GUILD_ID: {guild_id}", + f"CHANNEL_NAME: {channel_name}", + f"CHANNEL_ID: {channel_id}", + f"MESSAGE_COUNT: {len(formatted_messages)}" + ]), + ("CONTENT", [ + "FORMAT: markdown", + "TEXT_START", + channel_content, + "TEXT_END" + ]) + ] + + # Build the document string + document_parts = 
[]
+                    document_parts.append("<DOCUMENT>")
+                    for section_title, section_content in metadata_sections:
+                        document_parts.append(f"<{section_title}>")
+                        document_parts.extend(section_content)
+                        document_parts.append(f"</{section_title}>")
+                    document_parts.append("</DOCUMENT>")
+                    combined_document_string = '\n'.join(document_parts)
+                    content_hash = generate_content_hash(combined_document_string)
+
+                    # Check if document with this content hash already exists
+                    existing_doc_by_hash_result = await session.execute(
+                        select(Document).where(Document.content_hash == content_hash)
+                    )
+                    existing_document_by_hash = existing_doc_by_hash_result.scalars().first()
+
+                    if existing_document_by_hash:
+                        logger.info(f"Document with content hash {content_hash} already exists for channel {guild_name}#{channel_name}. Skipping processing.")
+                        documents_skipped += 1
+                        continue
+
+                    # Generate summary using summary_chain
+                    summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
+                    summary_result = await summary_chain.ainvoke({"document": combined_document_string})
+                    summary_content = summary_result.content
+                    summary_embedding = config.embedding_model_instance.embed(summary_content)
+
+                    # Process chunks
+                    chunks = [
+                        Chunk(content=chunk.text, embedding=config.embedding_model_instance.embed(chunk.text))
+                        for chunk in config.chunker_instance.chunk(channel_content)
+                    ]
+
+                    # Create and store new document
+                    document = Document(
+                        search_space_id=search_space_id,
+                        title=f"Discord - {guild_name}#{channel_name}",
+                        document_type=DocumentType.DISCORD_CONNECTOR,
+                        document_metadata={
+                            "guild_name": guild_name,
+                            "guild_id": guild_id,
+                            "channel_name": channel_name,
+                            "channel_id": channel_id,
+                            "message_count": len(formatted_messages),
+                            "start_date": start_date_str,
+                            "end_date": end_date_str,
+                            "indexed_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
+                        },
+                        content=summary_content,
+                        content_hash=content_hash,
+                        embedding=summary_embedding,
+                        chunks=chunks
+                    )
+
+                    session.add(document)
+                    documents_indexed += 1
+                    logger.info(f"Successfully indexed new channel {guild_name}#{channel_name} with {len(formatted_messages)} messages")
             except Exception as e:
                 logger.error(f"Error processing guild {guild_name}: {str(e)}", exc_info=True)
-                skipped_guilds.append(f"{guild_name} (processing error)")
+                skipped_channels.append(f"{guild_name} (processing error)")
                 documents_skipped += 1
                 continue
@@ -1091,9 +1136,17 @@
         await session.commit()
         await discord_client.close_bot()
-        logger.info(f"Discord indexing completed: {documents_indexed} new messages, {documents_skipped} skipped")
-        return documents_indexed, None
-
+
+        # Prepare result message
+        result_message = None
+        if skipped_channels:
+            result_message = f"Processed {documents_indexed} channels. Skipped {len(skipped_channels)} channels: {', '.join(skipped_channels)}"
+        else:
+            result_message = f"Processed {documents_indexed} channels."
+
+        logger.info(f"Discord indexing completed: {documents_indexed} new channels, {documents_skipped} skipped")
+        return documents_indexed, result_message
+
     except SQLAlchemyError as db_error:
         await session.rollback()
         logger.error(f"Database error during Discord indexing: {str(db_error)}", exc_info=True)
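
For illustration only: a minimal, self-contained sketch of the tagged per-channel document string this hunk assembles, using made-up guild, channel, and message values (none of these names come from real data; the section layout simply mirrors the metadata_sections list in the diff above).

# Illustrative sketch of the per-channel document pattern used in this patch.
# All guild/channel/message values below are invented sample data.
guild_name = "Acme Community"
channel_name = "general"
channel_content = (
    f"# Discord Channel: {guild_name} / {channel_name}\n\n"
    "## alice (2025-05-01T12:00:00+00:00)\n\nHello world\n\n---\n\n"
)

metadata_sections = [
    ("METADATA", [
        f"GUILD_NAME: {guild_name}",
        "GUILD_ID: 123",
        f"CHANNEL_NAME: {channel_name}",
        "CHANNEL_ID: 456",
        "MESSAGE_COUNT: 1",
    ]),
    ("CONTENT", [
        "FORMAT: markdown",
        "TEXT_START",
        channel_content,
        "TEXT_END",
    ]),
]

# Same assembly loop as the diff above: wrap the sections in tags.
document_parts = ["<DOCUMENT>"]
for section_title, section_content in metadata_sections:
    document_parts.append(f"<{section_title}>")
    document_parts.extend(section_content)
    document_parts.append(f"</{section_title}>")
document_parts.append("</DOCUMENT>")

print("\n".join(document_parts))

Running this prints the tagged METADATA and CONTENT sections in order, which appears to be the same document shape the other connectors in connectors_indexing_tasks.py build before hashing, summarizing, and chunking.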
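
A note on the dedup introduced above, with a rough sketch: the content hash is now computed over the whole per-channel document rather than per message, so a channel is skipped only when its combined content is unchanged. generate_content_hash is defined outside this diff, so the SHA-256 helper and the in-memory set below are only stand-ins for it and for the Document.content_hash lookup.

import hashlib

def fake_content_hash(text: str) -> str:
    # Stand-in for generate_content_hash (defined elsewhere in the app).
    return hashlib.sha256(text.encode("utf-8")).hexdigest()

# Plays the role of the existing-document lookup on Document.content_hash.
seen_hashes: set[str] = set()

def should_index(combined_document_string: str) -> bool:
    digest = fake_content_hash(combined_document_string)
    if digest in seen_hashes:
        # Same combined channel content as a previous run: skip re-indexing.
        return False
    seen_hashes.add(digest)
    return True

print(should_index("<DOCUMENT>example</DOCUMENT>"))  # True: first time this content is seen
print(should_index("<DOCUMENT>example</DOCUMENT>"))  # False: unchanged content is skipped

As written in the hunk above, a changed hash produces a brand-new Document row; the earlier snapshot of the same channel is left in place rather than updated.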