Mirror of https://github.com/MODSetter/SurfSense.git (synced 2025-09-09 13:54:40 +00:00)
feat: Added Speech to Text support.
- Supports audio & video files.
- Useful for YouTube videos that don't have transcripts.
This commit is contained in: parent 57987ecc76, commit a8080d2dc7
8 changed files with 172 additions and 73 deletions
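The transcription code added below reads two settings from app.config: STT_SERVICE (the LiteLLM model id) and STT_SERVICE_API_BASE (an optional endpoint override). Their definitions sit in one of the other changed files not shown on this page; purely as a hypothetical sketch of their shape:

# Hypothetical sketch only; the real definitions live in app/config.py,
# which is among the 8 changed files but not rendered here.
import os

class Config:
    # LiteLLM model id for speech-to-text, e.g. "whisper-1" (assumed default).
    STT_SERVICE = os.environ.get("STT_SERVICE", "whisper-1")
    # Optional base URL for self-hosted / OpenAI-compatible STT endpoints.
    STT_SERVICE_API_BASE = os.environ.get("STT_SERVICE_API_BASE")

config = Config()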
@@ -1,3 +1,4 @@
+from litellm import atranscription
 from fastapi import APIRouter, Depends, BackgroundTasks, UploadFile, Form, HTTPException
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.future import select
@@ -7,6 +8,7 @@ from app.schemas import DocumentsCreate, DocumentUpdate, DocumentRead
 from app.users import current_active_user
 from app.utils.check_ownership import check_ownership
 from app.tasks.background_tasks import add_received_markdown_file_document, add_extension_received_document, add_received_file_document, add_crawled_url_document, add_youtube_video_document
+from app.config import config as app_config
 # Force asyncio to use standard event loop before unstructured imports
 import asyncio
 try:
@@ -17,9 +19,9 @@ import os
 os.environ["UNSTRUCTURED_HAS_PATCHED_LOOP"] = "1"


 router = APIRouter()


 @router.post("/documents/")
 async def create_documents(
     request: DocumentsCreate,
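The hunk above ends at a truncated try: whose body falls outside the diff context. Purely as a generic illustration of the kind of guard the "force standard event loop" comment describes, and not the repository's actual code, such a guard often looks like:

# Generic illustration only; the real try-body is not shown in this diff.
import asyncio

try:
    # Revert to the standard loop policy so libraries that patch or inspect
    # the event loop (e.g. unstructured) see a vanilla loop instead of uvloop.
    asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
except Exception:
    pass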
@@ -30,19 +32,19 @@ async def create_documents(
     try:
         # Check if the user owns the search space
         await check_ownership(session, SearchSpace, request.search_space_id, user)

         if request.document_type == DocumentType.EXTENSION:
             for individual_document in request.content:
                 fastapi_background_tasks.add_task(
                     process_extension_document_with_new_session,
                     individual_document,
                     request.search_space_id
                 )
         elif request.document_type == DocumentType.CRAWLED_URL:
             for url in request.content:
                 fastapi_background_tasks.add_task(
                     process_crawled_url_with_new_session,
                     url,
                     request.search_space_id
                 )
         elif request.document_type == DocumentType.YOUTUBE_VIDEO:
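Every branch of this dispatcher has the same shape: validate ownership, then queue an async worker via FastAPI's BackgroundTasks so the request returns before processing starts. A condensed, self-contained sketch of that pattern, with simplified signatures rather than the full route:

from fastapi import BackgroundTasks, FastAPI

app = FastAPI()

async def process_crawled_url_with_new_session(url: str, search_space_id: int) -> None:
    ...  # runs only after the HTTP response has been sent

@app.post("/documents/")
async def create_documents(urls: list[str], background_tasks: BackgroundTasks):
    for url in urls:
        # add_task takes the callable followed by its positional arguments.
        background_tasks.add_task(process_crawled_url_with_new_session, url, 1)
    return {"message": "Documents processed successfully"}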
@@ -57,7 +59,7 @@ async def create_documents(
                 status_code=400,
                 detail="Invalid document type"
             )

         await session.commit()
         return {"message": "Documents processed successfully"}
     except HTTPException:
@@ -69,6 +71,7 @@ async def create_documents(
             detail=f"Failed to process documents: {str(e)}"
         )


 @router.post("/documents/fileupload")
 async def create_documents(
     files: list[UploadFile],
@@ -79,26 +82,26 @@ async def create_documents(
 ):
     try:
         await check_ownership(session, SearchSpace, search_space_id, user)

         if not files:
             raise HTTPException(status_code=400, detail="No files provided")

         for file in files:
             try:
                 # Save file to a temporary location to avoid stream issues
                 import tempfile
                 import aiofiles
                 import os

                 # Create temp file
                 with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_file:
                     temp_path = temp_file.name

                 # Write uploaded file to temp file
                 content = await file.read()
                 with open(temp_path, "wb") as f:
                     f.write(content)

                 # Process in background to avoid uvloop conflicts
                 fastapi_background_tasks.add_task(
                     process_file_in_background_with_new_session,
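One observation on this hunk: aiofiles is imported but the temp file is written with a synchronous open(), which blocks the event loop for large uploads. A hedged sketch of the non-blocking alternative, assuming only that aiofiles is installed:

import aiofiles

# Sketch, not the committed code: stream the upload to disk in chunks
# without blocking the event loop.
async def save_upload(file, temp_path: str) -> None:
    async with aiofiles.open(temp_path, "wb") as f:
        while chunk := await file.read(1024 * 1024):  # 1 MiB chunks
            await f.write(chunk)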
@@ -111,7 +114,7 @@ async def create_documents(
                     status_code=422,
                     detail=f"Failed to process file {file.filename}: {str(e)}"
                 )

         await session.commit()
         return {"message": "Files uploaded for processing"}
     except HTTPException:
@@ -136,14 +139,14 @@ async def process_file_in_background(
             # For markdown files, read the content directly
             with open(file_path, 'r', encoding='utf-8') as f:
                 markdown_content = f.read()

             # Clean up the temp file
             import os
             try:
                 os.unlink(file_path)
             except:
                 pass

             # Process markdown directly through specialized function
             await add_received_markdown_file_document(
                 session,
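A style note on the cleanup here (the pattern recurs in the hunks below): a bare except: pass swallows everything, even KeyboardInterrupt. An equivalent that suppresses only the filesystem errors unlink can raise:

import contextlib
import os

# Tighter variant of the bare-except cleanup above; a sketch, not the diff.
with contextlib.suppress(OSError):
    os.unlink(file_path)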
@@ -151,10 +154,46 @@ async def process_file_in_background(
                 markdown_content,
                 search_space_id
             )
+        # Check if the file is an audio file
+        elif filename.lower().endswith(('.mp3', '.mp4', '.mpeg', '.mpga', '.m4a', '.wav', '.webm')):
+            # Open the audio file for transcription
+            with open(file_path, "rb") as audio_file:
+                # Use LiteLLM for audio transcription
+                if app_config.STT_SERVICE_API_BASE:
+                    transcription_response = await atranscription(
+                        model=app_config.STT_SERVICE,
+                        file=audio_file,
+                        api_base=app_config.STT_SERVICE_API_BASE
+                    )
+                else:
+                    transcription_response = await atranscription(
+                        model=app_config.STT_SERVICE,
+                        file=audio_file
+                    )
+
+            # Extract the transcribed text
+            transcribed_text = transcription_response.get("text", "")
+
+            # Add metadata about the transcription
+            transcribed_text = f"# Transcription of {filename}\n\n{transcribed_text}"
+
+            # Clean up the temp file
+            try:
+                os.unlink(file_path)
+            except:
+                pass
+
+            # Process transcription as markdown document
+            await add_received_markdown_file_document(
+                session,
+                filename,
+                transcribed_text,
+                search_space_id
+            )
         else:
             # Use synchronous unstructured API to avoid event loop issues
             from langchain_unstructured import UnstructuredLoader

             # Process the file
             loader = UnstructuredLoader(
                 file_path,
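To exercise the new transcription call in isolation: a minimal standalone sketch, assuming litellm is installed and provider credentials are set in the environment. The model id "whisper-1" and the file name are illustrative; the route reads the real model from app_config.STT_SERVICE, and the response handling mirrors the diff.

import asyncio
from litellm import atranscription

async def transcribe(path: str) -> str:
    with open(path, "rb") as audio_file:
        # "whisper-1" is an assumed example; SurfSense passes
        # app_config.STT_SERVICE here instead.
        response = await atranscription(model="whisper-1", file=audio_file)
    return response.get("text", "")

if __name__ == "__main__":
    print(asyncio.run(transcribe("sample.mp3")))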
@@ -165,16 +204,16 @@ async def process_file_in_background(
                 include_metadata=False,
                 strategy="auto",
             )

             docs = await loader.aload()

             # Clean up the temp file
             import os
             try:
                 os.unlink(file_path)
             except:
                 pass

             # Pass the documents to the existing background task
             await add_received_file_document(
                 session,
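Everything that is neither markdown nor audio falls through to unstructured via LangChain. A self-contained sketch of that loader usage, assuming the langchain-unstructured package is installed; the file name is a placeholder and the parameters mirror the diff:

import asyncio
from langchain_unstructured import UnstructuredLoader

async def extract_documents(file_path: str):
    # Same knobs as the diff: let unstructured pick the parsing strategy
    # and skip per-element metadata.
    loader = UnstructuredLoader(file_path, include_metadata=False, strategy="auto")
    return await loader.aload()  # a list of LangChain Document objects

docs = asyncio.run(extract_documents("report.pdf"))  # placeholder file
print(docs[0].page_content[:200])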
@@ -186,6 +225,7 @@ async def process_file_in_background(
         import logging
         logging.error(f"Error processing file in background: {str(e)}")


 @router.get("/documents/", response_model=List[DocumentRead])
 async def read_documents(
     skip: int = 0,
@@ -195,17 +235,18 @@ async def read_documents(
     user: User = Depends(current_active_user)
 ):
     try:
-        query = select(Document).join(SearchSpace).filter(SearchSpace.user_id == user.id)
+        query = select(Document).join(SearchSpace).filter(
+            SearchSpace.user_id == user.id)

         # Filter by search_space_id if provided
         if search_space_id is not None:
             query = query.filter(Document.search_space_id == search_space_id)

         result = await session.execute(
             query.offset(skip).limit(limit)
         )
         db_documents = result.scalars().all()

         # Convert database objects to API-friendly format
         api_documents = []
         for doc in db_documents:
@@ -218,7 +259,7 @@ async def read_documents(
                 created_at=doc.created_at,
                 search_space_id=doc.search_space_id
             ))

         return api_documents
     except Exception as e:
         raise HTTPException(
@@ -226,6 +267,7 @@ async def read_documents(
             detail=f"Failed to fetch documents: {str(e)}"
         )


 @router.get("/documents/{document_id}", response_model=DocumentRead)
 async def read_document(
     document_id: int,
@@ -239,13 +281,13 @@ async def read_document(
             .filter(Document.id == document_id, SearchSpace.user_id == user.id)
         )
         document = result.scalars().first()

         if not document:
             raise HTTPException(
                 status_code=404,
                 detail=f"Document with id {document_id} not found"
             )

         # Convert database object to API-friendly format
         return DocumentRead(
             id=document.id,
@@ -262,6 +304,7 @@ async def read_document(
             detail=f"Failed to fetch document: {str(e)}"
         )


 @router.put("/documents/{document_id}", response_model=DocumentRead)
 async def update_document(
     document_id: int,
@@ -277,19 +320,19 @@ async def update_document(
             .filter(Document.id == document_id, SearchSpace.user_id == user.id)
         )
         db_document = result.scalars().first()

         if not db_document:
             raise HTTPException(
                 status_code=404,
                 detail=f"Document with id {document_id} not found"
             )

         update_data = document_update.model_dump(exclude_unset=True)
         for key, value in update_data.items():
             setattr(db_document, key, value)
         await session.commit()
         await session.refresh(db_document)

         # Convert to DocumentRead for response
         return DocumentRead(
             id=db_document.id,
@@ -309,6 +352,7 @@ async def update_document(
             detail=f"Failed to update document: {str(e)}"
         )


 @router.delete("/documents/{document_id}", response_model=dict)
 async def delete_document(
     document_id: int,
@@ -323,13 +367,13 @@ async def delete_document(
             .filter(Document.id == document_id, SearchSpace.user_id == user.id)
         )
         document = result.scalars().first()

         if not document:
             raise HTTPException(
                 status_code=404,
                 detail=f"Document with id {document_id} not found"
             )

         await session.delete(document)
         await session.commit()
         return {"message": "Document deleted successfully"}
@@ -340,16 +384,16 @@ async def delete_document(
         raise HTTPException(
             status_code=500,
             detail=f"Failed to delete document: {str(e)}"
         )


 async def process_extension_document_with_new_session(
     individual_document,
     search_space_id: int
 ):
     """Create a new session and process extension document."""
     from app.db import async_session_maker

     async with async_session_maker() as session:
         try:
             await add_extension_received_document(session, individual_document, search_space_id)
@@ -357,13 +401,14 @@ async def process_extension_document_with_new_session(
         import logging
         logging.error(f"Error processing extension document: {str(e)}")


 async def process_crawled_url_with_new_session(
     url: str,
     search_space_id: int
 ):
     """Create a new session and process crawled URL."""
     from app.db import async_session_maker

     async with async_session_maker() as session:
         try:
             await add_crawled_url_document(session, url, search_space_id)
@@ -371,6 +416,7 @@ async def process_crawled_url_with_new_session(
         import logging
         logging.error(f"Error processing crawled URL: {str(e)}")


 async def process_file_in_background_with_new_session(
     file_path: str,
     filename: str,
@@ -378,21 +424,21 @@ async def process_file_in_background_with_new_session(
 ):
     """Create a new session and process file."""
     from app.db import async_session_maker

     async with async_session_maker() as session:
         await process_file_in_background(file_path, filename, search_space_id, session)


 async def process_youtube_video_with_new_session(
     url: str,
     search_space_id: int
 ):
     """Create a new session and process YouTube video."""
     from app.db import async_session_maker

     async with async_session_maker() as session:
         try:
             await add_youtube_video_document(session, url, search_space_id)
         except Exception as e:
             import logging
             logging.error(f"Error processing YouTube video: {str(e)}")
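All of these *_with_new_session helpers exist for one reason: FastAPI background tasks run after the response is sent, when the request-scoped session is already closed, so each task opens its own. A condensed sketch of the pattern, assuming a standard SQLAlchemy async setup; the engine URL is a placeholder, and the real session maker lives in app.db:

import logging
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

# Placeholder URL; SurfSense constructs its maker in app/db.py.
engine = create_async_engine("postgresql+asyncpg://user:pass@localhost/surfsense")
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)

async def run_in_fresh_session(work, *args) -> None:
    # Each background task gets its own session so it never touches the
    # request's (already-closed) session.
    async with async_session_maker() as session:
        try:
            await work(session, *args)
        except Exception as e:
            logging.error(f"Background task failed: {str(e)}")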