mirror of
https://github.com/MODSetter/SurfSense.git
synced 2025-09-01 10:09:08 +00:00
feat: Added Local TTS (Kokoro TTS) Support
This commit is contained in:
parent
994ebb4efd
commit
1b29310ae7
8 changed files with 3056 additions and 2293 deletions
|
@ -17,6 +17,7 @@ RERANKERS_MODEL_NAME=ms-marco-MiniLM-L-12-v2
|
||||||
RERANKERS_MODEL_TYPE=flashrank
|
RERANKERS_MODEL_TYPE=flashrank
|
||||||
|
|
||||||
|
|
||||||
|
# TTS_SERVICE=local/kokoro for local Kokoro TTS or
|
||||||
# LiteLLM TTS Provider: https://docs.litellm.ai/docs/text_to_speech#supported-providers
|
# LiteLLM TTS Provider: https://docs.litellm.ai/docs/text_to_speech#supported-providers
|
||||||
TTS_SERVICE=openai/tts-1
|
TTS_SERVICE=openai/tts-1
|
||||||
# Respective TTS Service API
|
# Respective TTS Service API
|
||||||
|
|
1
surfsense_backend/.gitignore
vendored
1
surfsense_backend/.gitignore
vendored
|
@ -6,3 +6,4 @@ __pycache__/
|
||||||
.flashrank_cache
|
.flashrank_cache
|
||||||
surf_new_backend.egg-info/
|
surf_new_backend.egg-info/
|
||||||
podcasts/
|
podcasts/
|
||||||
|
temp_audio/
|
|
@ -11,6 +11,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
wget \
|
wget \
|
||||||
unzip \
|
unzip \
|
||||||
gnupg2 \
|
gnupg2 \
|
||||||
|
espeak-ng \
|
||||||
|
libsndfile1 \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
# Update certificates and install SSL tools
|
# Update certificates and install SSL tools
|
||||||
|
@ -51,7 +53,7 @@ RUN python -c "try:\n from docling.document_converter import DocumentConverte
|
||||||
|
|
||||||
# Install Playwright browsers for web scraping if needed
|
# Install Playwright browsers for web scraping if needed
|
||||||
RUN pip install playwright && \
|
RUN pip install playwright && \
|
||||||
playwright install --with-deps chromium
|
playwright install chromium
|
||||||
|
|
||||||
# Copy source code
|
# Copy source code
|
||||||
COPY . .
|
COPY . .
|
||||||
|
|
|
@ -11,6 +11,7 @@ from langchain_core.runnables import RunnableConfig
|
||||||
from litellm import aspeech
|
from litellm import aspeech
|
||||||
|
|
||||||
from app.config import config as app_config
|
from app.config import config as app_config
|
||||||
|
from app.services.kokoro_tts_service import get_kokoro_tts_service
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
|
|
||||||
from .configuration import Configuration
|
from .configuration import Configuration
|
||||||
|
@ -138,34 +139,49 @@ async def create_merged_podcast_audio(
|
||||||
voice = get_voice_for_provider(app_config.TTS_SERVICE, speaker_id)
|
voice = get_voice_for_provider(app_config.TTS_SERVICE, speaker_id)
|
||||||
|
|
||||||
# Generate a unique filename for this segment
|
# Generate a unique filename for this segment
|
||||||
filename = f"{temp_dir}/{session_id}_{index}.mp3"
|
if app_config.TTS_SERVICE == "local/kokoro":
|
||||||
|
# Kokoro generates WAV files
|
||||||
|
filename = f"{temp_dir}/{session_id}_{index}.wav"
|
||||||
|
else:
|
||||||
|
# Other services generate MP3 files
|
||||||
|
filename = f"{temp_dir}/{session_id}_{index}.mp3"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if app_config.TTS_SERVICE_API_BASE:
|
if app_config.TTS_SERVICE == "local/kokoro":
|
||||||
response = await aspeech(
|
# Use Kokoro TTS service
|
||||||
model=app_config.TTS_SERVICE,
|
kokoro_service = await get_kokoro_tts_service(
|
||||||
api_base=app_config.TTS_SERVICE_API_BASE,
|
lang_code="a"
|
||||||
api_key=app_config.TTS_SERVICE_API_KEY,
|
) # American English
|
||||||
voice=voice,
|
audio_path = await kokoro_service.generate_speech(
|
||||||
input=dialog,
|
text=dialog, voice=voice, speed=1.0, output_path=filename
|
||||||
max_retries=2,
|
|
||||||
timeout=600,
|
|
||||||
)
|
)
|
||||||
|
return audio_path
|
||||||
else:
|
else:
|
||||||
response = await aspeech(
|
if app_config.TTS_SERVICE_API_BASE:
|
||||||
model=app_config.TTS_SERVICE,
|
response = await aspeech(
|
||||||
api_key=app_config.TTS_SERVICE_API_KEY,
|
model=app_config.TTS_SERVICE,
|
||||||
voice=voice,
|
api_base=app_config.TTS_SERVICE_API_BASE,
|
||||||
input=dialog,
|
api_key=app_config.TTS_SERVICE_API_KEY,
|
||||||
max_retries=2,
|
voice=voice,
|
||||||
timeout=600,
|
input=dialog,
|
||||||
)
|
max_retries=2,
|
||||||
|
timeout=600,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
response = await aspeech(
|
||||||
|
model=app_config.TTS_SERVICE,
|
||||||
|
api_key=app_config.TTS_SERVICE_API_KEY,
|
||||||
|
voice=voice,
|
||||||
|
input=dialog,
|
||||||
|
max_retries=2,
|
||||||
|
timeout=600,
|
||||||
|
)
|
||||||
|
|
||||||
# Save the audio to a file - use proper streaming method
|
# Save the audio to a file - use proper streaming method
|
||||||
with open(filename, "wb") as f:
|
with open(filename, "wb") as f:
|
||||||
f.write(response.content)
|
f.write(response.content)
|
||||||
|
|
||||||
return filename
|
return filename
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error generating speech for segment {index}: {e!s}")
|
print(f"Error generating speech for segment {index}: {e!s}")
|
||||||
raise
|
raise
|
||||||
|
|
|
@ -9,6 +9,14 @@ def get_voice_for_provider(provider: str, speaker_id: int) -> dict | str:
|
||||||
Returns:
|
Returns:
|
||||||
Voice configuration - string for OpenAI, dict for Vertex AI
|
Voice configuration - string for OpenAI, dict for Vertex AI
|
||||||
"""
|
"""
|
||||||
|
if provider == "local/kokoro":
|
||||||
|
# Kokoro voice mapping - https://huggingface.co/hexgrad/Kokoro-82M/tree/main/voices
|
||||||
|
kokoro_voices = {
|
||||||
|
0: "am_adam", # Default/intro voice
|
||||||
|
1: "af_bella", # First speaker
|
||||||
|
}
|
||||||
|
return kokoro_voices.get(speaker_id, "af_heart")
|
||||||
|
|
||||||
# Extract provider type from the model string
|
# Extract provider type from the model string
|
||||||
provider_type = (
|
provider_type = (
|
||||||
provider.split("/")[0].lower() if "/" in provider else provider.lower()
|
provider.split("/")[0].lower() if "/" in provider else provider.lower()
|
||||||
|
@ -59,11 +67,7 @@ def get_voice_for_provider(provider: str, speaker_id: int) -> dict | str:
|
||||||
else:
|
else:
|
||||||
# Default fallback to OpenAI format for unknown providers
|
# Default fallback to OpenAI format for unknown providers
|
||||||
default_voices = {
|
default_voices = {
|
||||||
0: "alloy",
|
0: {},
|
||||||
1: "echo",
|
1: {},
|
||||||
2: "fable",
|
|
||||||
3: "onyx",
|
|
||||||
4: "nova",
|
|
||||||
5: "shimmer",
|
|
||||||
}
|
}
|
||||||
return default_voices.get(speaker_id, "alloy")
|
return default_voices.get(speaker_id, default_voices[0])
|
||||||
|
|
138
surfsense_backend/app/services/kokoro_tts_service.py
Normal file
138
surfsense_backend/app/services/kokoro_tts_service.py
Normal file
|
@ -0,0 +1,138 @@
|
||||||
|
import asyncio
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import soundfile as sf
|
||||||
|
import torch
|
||||||
|
from kokoro import KPipeline
|
||||||
|
|
||||||
|
|
||||||
|
class KokoroTTSService:
|
||||||
|
"""Kokoro TTS service for generating speech from text."""
|
||||||
|
|
||||||
|
def __init__(self, lang_code: str = "a"):
|
||||||
|
"""
|
||||||
|
Initialize the Kokoro TTS service.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lang_code: Language code for TTS
|
||||||
|
'a' => American English
|
||||||
|
'b' => British English
|
||||||
|
'e' => Spanish
|
||||||
|
'f' => French
|
||||||
|
'h' => Hindi
|
||||||
|
'i' => Italian
|
||||||
|
'j' => Japanese
|
||||||
|
'p' => Brazilian Portuguese
|
||||||
|
'z' => Mandarin Chinese
|
||||||
|
"""
|
||||||
|
self.lang_code = lang_code
|
||||||
|
self.pipeline = None
|
||||||
|
self._initialize_pipeline()
|
||||||
|
|
||||||
|
def _initialize_pipeline(self):
|
||||||
|
"""Initialize the Kokoro pipeline."""
|
||||||
|
try:
|
||||||
|
self.pipeline = KPipeline(lang_code=self.lang_code)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error initializing Kokoro pipeline: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def generate_speech(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
voice: str = "af_heart",
|
||||||
|
speed: float = 1.0,
|
||||||
|
output_path: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Generate speech from text using Kokoro TTS.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text to convert to speech
|
||||||
|
voice: Voice to use (e.g., "af_heart")
|
||||||
|
speed: Speech speed (default: 1.0)
|
||||||
|
output_path: Path to save the audio file. If None, creates a temporary file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to the generated audio file
|
||||||
|
"""
|
||||||
|
if not self.pipeline:
|
||||||
|
raise RuntimeError("Kokoro pipeline not initialized")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# If no output path provided, create a temporary file
|
||||||
|
if output_path is None:
|
||||||
|
temp_dir = Path("temp_audio")
|
||||||
|
temp_dir.mkdir(exist_ok=True)
|
||||||
|
output_path = str(temp_dir / f"kokoro_output_{id(text)}.wav")
|
||||||
|
|
||||||
|
# Ensure output directory exists
|
||||||
|
output_file = Path(output_path)
|
||||||
|
output_file.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Handle voice tensor loading if it's a path to a .pt file
|
||||||
|
voice_param = voice
|
||||||
|
if isinstance(voice, str) and voice.endswith(".pt"):
|
||||||
|
try:
|
||||||
|
voice_param = torch.load(voice, weights_only=True)
|
||||||
|
except Exception as e:
|
||||||
|
print(
|
||||||
|
f"Warning: Could not load voice tensor from {voice}, using default: {e}"
|
||||||
|
)
|
||||||
|
voice_param = "af_heart"
|
||||||
|
|
||||||
|
# Generate audio using the pipeline
|
||||||
|
# Run in thread pool since Kokoro is synchronous
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
generator = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: self.pipeline(
|
||||||
|
text, voice=voice_param, speed=speed, split_pattern=r"\n+"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Collect all audio segments
|
||||||
|
audio_segments = []
|
||||||
|
for _i, (_gs, _ps, audio) in enumerate(generator):
|
||||||
|
audio_segments.append(audio)
|
||||||
|
|
||||||
|
# Concatenate all audio segments if there are multiple
|
||||||
|
if len(audio_segments) > 1:
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
final_audio = np.concatenate(audio_segments)
|
||||||
|
elif len(audio_segments) == 1:
|
||||||
|
final_audio = audio_segments[0]
|
||||||
|
else:
|
||||||
|
raise ValueError("No audio generated from text")
|
||||||
|
|
||||||
|
# Save the audio file
|
||||||
|
sf.write(output_path, final_audio, 24000) # Kokoro uses 24kHz sample rate
|
||||||
|
|
||||||
|
return output_path
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error generating speech with Kokoro: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# Global instance for reuse
|
||||||
|
_kokoro_service: KokoroTTSService | None = None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_kokoro_tts_service(lang_code: str = "a") -> KokoroTTSService:
    """
    Return the shared Kokoro TTS service for ``lang_code``.

    A single instance is kept in the module-level ``_kokoro_service``. It is
    reused while the requested language matches; asking for a different
    language builds a new service and replaces the cached one.

    Args:
        lang_code: Language code for TTS

    Returns:
        KokoroTTSService instance
    """
    global _kokoro_service

    cached = _kokoro_service
    if cached is not None and cached.lang_code == lang_code:
        return cached

    # Nothing cached yet, or cached for another language: rebuild and store.
    _kokoro_service = KokoroTTSService(lang_code=lang_code)
    return _kokoro_service
|
|
@ -16,6 +16,7 @@ dependencies = [
|
||||||
"github3.py==4.0.1",
|
"github3.py==4.0.1",
|
||||||
"google-api-python-client>=2.156.0",
|
"google-api-python-client>=2.156.0",
|
||||||
"google-auth-oauthlib>=1.2.1",
|
"google-auth-oauthlib>=1.2.1",
|
||||||
|
"kokoro>=0.9.4",
|
||||||
"langchain-community>=0.3.17",
|
"langchain-community>=0.3.17",
|
||||||
"langchain-unstructured>=0.1.6",
|
"langchain-unstructured>=0.1.6",
|
||||||
"langgraph>=0.3.29",
|
"langgraph>=0.3.29",
|
||||||
|
@ -24,12 +25,16 @@ dependencies = [
|
||||||
"llama-cloud-services>=0.6.25",
|
"llama-cloud-services>=0.6.25",
|
||||||
"markdownify>=0.14.1",
|
"markdownify>=0.14.1",
|
||||||
"notion-client>=2.3.0",
|
"notion-client>=2.3.0",
|
||||||
|
"numpy>=1.24.0",
|
||||||
"pgvector>=0.3.6",
|
"pgvector>=0.3.6",
|
||||||
"playwright>=1.50.0",
|
"playwright>=1.50.0",
|
||||||
"python-ffmpeg>=2.0.12",
|
"python-ffmpeg>=2.0.12",
|
||||||
"rerankers[flashrank]>=0.7.1",
|
"rerankers[flashrank]>=0.7.1",
|
||||||
"sentence-transformers>=3.4.1",
|
"sentence-transformers>=3.4.1",
|
||||||
"slack-sdk>=3.34.0",
|
"slack-sdk>=3.34.0",
|
||||||
|
"soundfile>=0.13.1",
|
||||||
|
"spacy>=3.8.7",
|
||||||
|
"en-core-web-sm@https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl",
|
||||||
"static-ffmpeg>=2.13",
|
"static-ffmpeg>=2.13",
|
||||||
"tavily-python>=0.3.2",
|
"tavily-python>=0.3.2",
|
||||||
"unstructured-client>=0.30.0",
|
"unstructured-client>=0.30.0",
|
||||||
|
|
5122
surfsense_backend/uv.lock
generated
5122
surfsense_backend/uv.lock
generated
File diff suppressed because it is too large
Load diff
Loading…
Add table
Reference in a new issue