Mirror of https://github.com/MODSetter/SurfSense.git
chore: Added direct handling for markdown files.

- Fixed podcast imports.

commit 1586a0bd78, parent 704af3e4d1
10 changed files with 118 additions and 59 deletions

.gitignore (vendored): 2 changes

@@ -1,2 +1,2 @@
 .flashrank_cache*
-podcasts/*
+podcasts/
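
Note on the pattern change: podcasts/* ignores only the directory's contents, so the directory entry itself is still traversed and individual files could be re-included with a later negation rule; podcasts/ ignores the directory wholesale, the simpler form when nothing under it should ever be tracked.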

@@ -6,7 +6,7 @@ from app.db import get_async_session, User, SearchSpace, Document, DocumentType
 from app.schemas import DocumentsCreate, DocumentUpdate, DocumentRead
 from app.users import current_active_user
 from app.utils.check_ownership import check_ownership
-from app.tasks.background_tasks import add_extension_received_document, add_received_file_document, add_crawled_url_document, add_youtube_video_document
+from app.tasks.background_tasks import add_received_markdown_file_document, add_extension_received_document, add_received_file_document, add_crawled_url_document, add_youtube_video_document
 # Force asyncio to use standard event loop before unstructured imports
 import asyncio
 try:
@@ -15,9 +15,8 @@ except RuntimeError:
     pass
 import os
 os.environ["UNSTRUCTURED_HAS_PATCHED_LOOP"] = "1"
-from langchain_unstructured import UnstructuredLoader
-from app.config import config
-import json

 router = APIRouter()

@@ -132,11 +131,32 @@ async def process_file_in_background(
     session: AsyncSession
 ):
     try:
+        # Check if the file is a markdown file
+        if filename.lower().endswith(('.md', '.markdown')):
+            # For markdown files, read the content directly
+            with open(file_path, 'r', encoding='utf-8') as f:
+                markdown_content = f.read()
+
+            # Clean up the temp file
+            import os
+            try:
+                os.unlink(file_path)
+            except:
+                pass
+
+            # Process markdown directly through specialized function
+            await add_received_markdown_file_document(
+                session,
+                filename,
+                markdown_content,
+                search_space_id
+            )
+        else:
             # Use synchronous unstructured API to avoid event loop issues
-            from langchain_community.document_loaders import UnstructuredFileLoader
+            from langchain_unstructured import UnstructuredLoader

             # Process the file
-            loader = UnstructuredFileLoader(
+            loader = UnstructuredLoader(
                 file_path,
                 mode="elements",
                 post_processors=[],
@@ -146,7 +166,7 @@ async def process_file_in_background(
                 strategy="auto",
             )

-            docs = loader.load()
+            docs = await loader.aload()

             # Clean up the temp file
             import os
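
Two things change in this route: markdown uploads now bypass Unstructured entirely and are read as plain text, and the remaining formats move from langchain_community's UnstructuredFileLoader.load() to langchain_unstructured's UnstructuredLoader with the awaitable aload(), which keeps the event loop free. A minimal sketch of the dispatch idea; load_file_text is a hypothetical helper for illustration, not the repo's code:

    from langchain_unstructured import UnstructuredLoader

    async def load_file_text(file_path: str, filename: str) -> str:
        if filename.lower().endswith((".md", ".markdown")):
            # Markdown is already plain text: read it directly and skip
            # Unstructured's element partitioning entirely.
            with open(file_path, "r", encoding="utf-8") as f:
                return f.read()
        # Everything else goes through Unstructured; aload() is the async
        # counterpart of load(), so the server's event loop is not blocked.
        loader = UnstructuredLoader(file_path)
        docs = await loader.aload()
        return "\n".join(doc.page_content for doc in docs)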

@@ -1,7 +1,7 @@
 from typing import Any, Dict, List, Optional

 from app.db import ChatType
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict

 from .base import IDModel, TimestampModel

@@ -44,5 +44,4 @@ class ChatUpdate(ChatBase):
     pass

 class ChatRead(ChatBase, IDModel, TimestampModel):
-    class Config:
-        from_attributes = True
+    model_config = ConfigDict(from_attributes=True)
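
This and the following five schema files all make the same Pydantic v2 migration: the v1-style nested class Config is replaced by the model_config = ConfigDict(...) attribute. A minimal sketch of the idiom; UserRead is an illustrative model, not one of the repo's schemas:

    from pydantic import BaseModel, ConfigDict

    class UserRead(BaseModel):
        # Pydantic v2 keeps model configuration in this attribute; the
        # v1-style nested `class Config` still works but is deprecated.
        model_config = ConfigDict(from_attributes=True)

        id: int
        name: str

    # from_attributes=True (v1's orm_mode) lets the schema be populated via
    # attribute access, e.g. from a SQLAlchemy ORM row:
    #     UserRead.model_validate(orm_user)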

@@ -1,4 +1,4 @@
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
 from .base import IDModel, TimestampModel

 class ChunkBase(BaseModel):
@@ -12,5 +12,4 @@ class ChunkUpdate(ChunkBase):
     pass

 class ChunkRead(ChunkBase, IDModel, TimestampModel):
-    class Config:
-        from_attributes = True
+    model_config = ConfigDict(from_attributes=True)

@@ -1,5 +1,5 @@
 from typing import List, Any
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
 from sqlalchemy import JSON
 from .base import IDModel, TimestampModel
 from app.db import DocumentType
@@ -37,6 +37,5 @@ class DocumentRead(BaseModel):
     created_at: datetime
     search_space_id: int

-    class Config:
-        from_attributes = True
+    model_config = ConfigDict(from_attributes=True)

@@ -1,4 +1,4 @@
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
 from typing import Any, List, Literal
 from .base import IDModel, TimestampModel

@@ -15,8 +15,7 @@ class PodcastUpdate(PodcastBase):
     pass

 class PodcastRead(PodcastBase, IDModel, TimestampModel):
-    class Config:
-        from_attributes = True
+    model_config = ConfigDict(from_attributes=True)

 class PodcastGenerateRequest(BaseModel):
     type: Literal["DOCUMENT", "CHAT"]

@@ -1,7 +1,7 @@
 from datetime import datetime
 import uuid
 from typing import Dict, Any, Optional
-from pydantic import BaseModel, field_validator
+from pydantic import BaseModel, field_validator, ConfigDict
 from .base import IDModel, TimestampModel
 from app.db import SearchSourceConnectorType

@@ -106,5 +106,4 @@ class SearchSourceConnectorUpdate(BaseModel):
 class SearchSourceConnectorRead(SearchSourceConnectorBase, IDModel, TimestampModel):
     user_id: uuid.UUID

-    class Config:
-        from_attributes = True
+    model_config = ConfigDict(from_attributes=True)

@@ -1,7 +1,7 @@
 from datetime import datetime
 import uuid
 from typing import Optional
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
 from .base import IDModel, TimestampModel

 class SearchSpaceBase(BaseModel):
@@ -19,5 +19,4 @@ class SearchSpaceRead(SearchSpaceBase, IDModel, TimestampModel):
     created_at: datetime
     user_id: uuid.UUID

-    class Config:
-        from_attributes = True
+    model_config = ConfigDict(from_attributes=True)

@@ -194,6 +194,52 @@ async def add_extension_received_document(
         await session.rollback()
         raise RuntimeError(f"Failed to process extension document: {str(e)}")

+async def add_received_markdown_file_document(
+    session: AsyncSession,
+    file_name: str,
+    file_in_markdown: str,
+    search_space_id: int
+) -> Optional[Document]:
+    try:
+
+        # Generate summary
+        summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
+        summary_result = await summary_chain.ainvoke({"document": file_in_markdown})
+        summary_content = summary_result.content
+        summary_embedding = config.embedding_model_instance.embed(
+            summary_content)
+
+        # Process chunks
+        chunks = [
+            Chunk(content=chunk.text, embedding=config.embedding_model_instance.embed(chunk.text))
+            for chunk in config.chunker_instance.chunk(file_in_markdown)
+        ]
+
+        # Create and store document
+        document = Document(
+            search_space_id=search_space_id,
+            title=file_name,
+            document_type=DocumentType.FILE,
+            document_metadata={
+                "FILE_NAME": file_name,
+                "SAVED_AT": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+            },
+            content=summary_content,
+            embedding=summary_embedding,
+            chunks=chunks
+        )
+
+        session.add(document)
+        await session.commit()
+        await session.refresh(document)
+
+        return document
+    except SQLAlchemyError as db_error:
+        await session.rollback()
+        raise db_error
+    except Exception as e:
+        await session.rollback()
+        raise RuntimeError(f"Failed to process file document: {str(e)}")
+
 async def add_received_file_document(
     session: AsyncSession,
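
The new add_received_markdown_file_document follows the same shape as the other background tasks: summarize, embed the summary as the document-level vector, embed each chunk individually for passage-level retrieval, and persist one Document row with its Chunk children. The pipe in SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance is LangChain's LCEL composition; a minimal sketch of that piece, where the prompt text and the llm argument are stand-ins for the repo's template and configured model:

    from langchain_core.prompts import ChatPromptTemplate

    # Stand-in prompt; SUMMARY_PROMPT_TEMPLATE and
    # config.long_context_llm_instance fill these two roles in the repo.
    prompt = ChatPromptTemplate.from_template("Summarize:\n\n{document}")

    async def summarize(llm, text: str) -> str:
        chain = prompt | llm                     # LCEL: prompt output feeds the model
        result = await chain.ainvoke({"document": text})
        return result.content                    # chat models return a message object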

@@ -1,10 +1,9 @@
-from sqlalchemy.ext.asyncio import AsyncSession
-from app.schemas import PodcastGenerateRequest
-from typing import List
-from sqlalchemy import select
-from app.db import Chat, Podcast
 from app.agents.podcaster.graph import graph as podcaster_graph
-from surfsense_backend.app.agents.podcaster.state import State
+from app.agents.podcaster.state import State
+from app.db import Chat, Podcast
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession


 async def generate_document_podcast(
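
This is the "Fixed podcast imports" half of the commit message: the State import previously went through the surfsense_backend package root, inconsistent with the app-rooted imports used everywhere else in the file and broken unless the repository root happened to be on the import path; the apparently unused PodcastGenerateRequest and List imports are dropped and the survivors reordered.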