Merge pull request #262 from MODSetter/dev

feat: Added Local TTS (Kokoro TTS) Support
This commit is contained in:
Rohan Verma 2025-08-13 17:33:33 -07:00 committed by GitHub
commit 939b365176
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 3056 additions and 2293 deletions

View file

@ -17,6 +17,7 @@ RERANKERS_MODEL_NAME=ms-marco-MiniLM-L-12-v2
RERANKERS_MODEL_TYPE=flashrank RERANKERS_MODEL_TYPE=flashrank
# TTS_SERVICE=local/kokoro for local Kokoro TTS or
# LiteLLM TTS Provider: https://docs.litellm.ai/docs/text_to_speech#supported-providers # LiteLLM TTS Provider: https://docs.litellm.ai/docs/text_to_speech#supported-providers
TTS_SERVICE=openai/tts-1 TTS_SERVICE=openai/tts-1
# Respective TTS Service API # Respective TTS Service API

View file

@ -6,3 +6,4 @@ __pycache__/
.flashrank_cache .flashrank_cache
surf_new_backend.egg-info/ surf_new_backend.egg-info/
podcasts/ podcasts/
temp_audio/

View file

@ -11,6 +11,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
wget \ wget \
unzip \ unzip \
gnupg2 \ gnupg2 \
espeak-ng \
libsndfile1 \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
# Update certificates and install SSL tools # Update certificates and install SSL tools
@ -51,7 +53,7 @@ RUN python -c "try:\n from docling.document_converter import DocumentConverte
# Install Playwright browsers for web scraping if needed # Install Playwright browsers for web scraping if needed
RUN pip install playwright && \ RUN pip install playwright && \
playwright install --with-deps chromium playwright install chromium
# Copy source code # Copy source code
COPY . . COPY . .

View file

@ -11,6 +11,7 @@ from langchain_core.runnables import RunnableConfig
from litellm import aspeech from litellm import aspeech
from app.config import config as app_config from app.config import config as app_config
from app.services.kokoro_tts_service import get_kokoro_tts_service
from app.services.llm_service import get_user_long_context_llm from app.services.llm_service import get_user_long_context_llm
from .configuration import Configuration from .configuration import Configuration
@ -138,34 +139,49 @@ async def create_merged_podcast_audio(
voice = get_voice_for_provider(app_config.TTS_SERVICE, speaker_id) voice = get_voice_for_provider(app_config.TTS_SERVICE, speaker_id)
# Generate a unique filename for this segment # Generate a unique filename for this segment
filename = f"{temp_dir}/{session_id}_{index}.mp3" if app_config.TTS_SERVICE == "local/kokoro":
# Kokoro generates WAV files
filename = f"{temp_dir}/{session_id}_{index}.wav"
else:
# Other services generate MP3 files
filename = f"{temp_dir}/{session_id}_{index}.mp3"
try: try:
if app_config.TTS_SERVICE_API_BASE: if app_config.TTS_SERVICE == "local/kokoro":
response = await aspeech( # Use Kokoro TTS service
model=app_config.TTS_SERVICE, kokoro_service = await get_kokoro_tts_service(
api_base=app_config.TTS_SERVICE_API_BASE, lang_code="a"
api_key=app_config.TTS_SERVICE_API_KEY, ) # American English
voice=voice, audio_path = await kokoro_service.generate_speech(
input=dialog, text=dialog, voice=voice, speed=1.0, output_path=filename
max_retries=2,
timeout=600,
) )
return audio_path
else: else:
response = await aspeech( if app_config.TTS_SERVICE_API_BASE:
model=app_config.TTS_SERVICE, response = await aspeech(
api_key=app_config.TTS_SERVICE_API_KEY, model=app_config.TTS_SERVICE,
voice=voice, api_base=app_config.TTS_SERVICE_API_BASE,
input=dialog, api_key=app_config.TTS_SERVICE_API_KEY,
max_retries=2, voice=voice,
timeout=600, input=dialog,
) max_retries=2,
timeout=600,
)
else:
response = await aspeech(
model=app_config.TTS_SERVICE,
api_key=app_config.TTS_SERVICE_API_KEY,
voice=voice,
input=dialog,
max_retries=2,
timeout=600,
)
# Save the audio to a file - use proper streaming method # Save the audio to a file - use proper streaming method
with open(filename, "wb") as f: with open(filename, "wb") as f:
f.write(response.content) f.write(response.content)
return filename return filename
except Exception as e: except Exception as e:
print(f"Error generating speech for segment {index}: {e!s}") print(f"Error generating speech for segment {index}: {e!s}")
raise raise

View file

@ -9,6 +9,14 @@ def get_voice_for_provider(provider: str, speaker_id: int) -> dict | str:
Returns: Returns:
Voice configuration - string for OpenAI, dict for Vertex AI Voice configuration - string for OpenAI, dict for Vertex AI
""" """
if provider == "local/kokoro":
# Kokoro voice mapping - https://huggingface.co/hexgrad/Kokoro-82M/tree/main/voices
kokoro_voices = {
0: "am_adam", # Default/intro voice
1: "af_bella", # First speaker
}
return kokoro_voices.get(speaker_id, "af_heart")
# Extract provider type from the model string # Extract provider type from the model string
provider_type = ( provider_type = (
provider.split("/")[0].lower() if "/" in provider else provider.lower() provider.split("/")[0].lower() if "/" in provider else provider.lower()
@ -59,11 +67,7 @@ def get_voice_for_provider(provider: str, speaker_id: int) -> dict | str:
else: else:
# Default fallback to OpenAI format for unknown providers # Default fallback to OpenAI format for unknown providers
default_voices = { default_voices = {
0: "alloy", 0: {},
1: "echo", 1: {},
2: "fable",
3: "onyx",
4: "nova",
5: "shimmer",
} }
return default_voices.get(speaker_id, "alloy") return default_voices.get(speaker_id, default_voices[0])

View file

@ -0,0 +1,138 @@
import asyncio
import threading
import uuid
from pathlib import Path

import numpy as np
import soundfile as sf
import torch
from kokoro import KPipeline
class KokoroTTSService:
"""Kokoro TTS service for generating speech from text."""
def __init__(self, lang_code: str = "a"):
"""
Initialize the Kokoro TTS service.
Args:
lang_code: Language code for TTS
'a' => American English
'b' => British English
'e' => Spanish
'f' => French
'h' => Hindi
'i' => Italian
'j' => Japanese
'p' => Brazilian Portuguese
'z' => Mandarin Chinese
"""
self.lang_code = lang_code
self.pipeline = None
self._initialize_pipeline()
def _initialize_pipeline(self):
"""Initialize the Kokoro pipeline."""
try:
self.pipeline = KPipeline(lang_code=self.lang_code)
except Exception as e:
print(f"Error initializing Kokoro pipeline: {e}")
raise
async def generate_speech(
self,
text: str,
voice: str = "af_heart",
speed: float = 1.0,
output_path: str | None = None,
) -> str:
"""
Generate speech from text using Kokoro TTS.
Args:
text: Text to convert to speech
voice: Voice to use (e.g., "af_heart")
speed: Speech speed (default: 1.0)
output_path: Path to save the audio file. If None, creates a temporary file.
Returns:
Path to the generated audio file
"""
if not self.pipeline:
raise RuntimeError("Kokoro pipeline not initialized")
try:
# If no output path provided, create a temporary file
if output_path is None:
temp_dir = Path("temp_audio")
temp_dir.mkdir(exist_ok=True)
output_path = str(temp_dir / f"kokoro_output_{id(text)}.wav")
# Ensure output directory exists
output_file = Path(output_path)
output_file.parent.mkdir(parents=True, exist_ok=True)
# Handle voice tensor loading if it's a path to a .pt file
voice_param = voice
if isinstance(voice, str) and voice.endswith(".pt"):
try:
voice_param = torch.load(voice, weights_only=True)
except Exception as e:
print(
f"Warning: Could not load voice tensor from {voice}, using default: {e}"
)
voice_param = "af_heart"
# Generate audio using the pipeline
# Run in thread pool since Kokoro is synchronous
loop = asyncio.get_event_loop()
generator = await loop.run_in_executor(
None,
lambda: self.pipeline(
text, voice=voice_param, speed=speed, split_pattern=r"\n+"
),
)
# Collect all audio segments
audio_segments = []
for _i, (_gs, _ps, audio) in enumerate(generator):
audio_segments.append(audio)
# Concatenate all audio segments if there are multiple
if len(audio_segments) > 1:
import numpy as np
final_audio = np.concatenate(audio_segments)
elif len(audio_segments) == 1:
final_audio = audio_segments[0]
else:
raise ValueError("No audio generated from text")
# Save the audio file
sf.write(output_path, final_audio, 24000) # Kokoro uses 24kHz sample rate
return output_path
except Exception as e:
print(f"Error generating speech with Kokoro: {e}")
raise
# Module-level cache holding the most recently created service
_kokoro_service: KokoroTTSService | None = None


async def get_kokoro_tts_service(lang_code: str = "a") -> KokoroTTSService:
    """
    Return the shared Kokoro TTS service for the requested language.

    The cached instance is reused when its language code matches; otherwise
    a new service is built and replaces the cached one.

    Args:
        lang_code: Language code for TTS

    Returns:
        KokoroTTSService instance
    """
    global _kokoro_service

    cached = _kokoro_service
    if cached is not None and cached.lang_code == lang_code:
        return cached

    _kokoro_service = KokoroTTSService(lang_code=lang_code)
    return _kokoro_service

View file

@ -16,6 +16,7 @@ dependencies = [
"github3.py==4.0.1", "github3.py==4.0.1",
"google-api-python-client>=2.156.0", "google-api-python-client>=2.156.0",
"google-auth-oauthlib>=1.2.1", "google-auth-oauthlib>=1.2.1",
"kokoro>=0.9.4",
"langchain-community>=0.3.17", "langchain-community>=0.3.17",
"langchain-unstructured>=0.1.6", "langchain-unstructured>=0.1.6",
"langgraph>=0.3.29", "langgraph>=0.3.29",
@ -24,12 +25,16 @@ dependencies = [
"llama-cloud-services>=0.6.25", "llama-cloud-services>=0.6.25",
"markdownify>=0.14.1", "markdownify>=0.14.1",
"notion-client>=2.3.0", "notion-client>=2.3.0",
"numpy>=1.24.0",
"pgvector>=0.3.6", "pgvector>=0.3.6",
"playwright>=1.50.0", "playwright>=1.50.0",
"python-ffmpeg>=2.0.12", "python-ffmpeg>=2.0.12",
"rerankers[flashrank]>=0.7.1", "rerankers[flashrank]>=0.7.1",
"sentence-transformers>=3.4.1", "sentence-transformers>=3.4.1",
"slack-sdk>=3.34.0", "slack-sdk>=3.34.0",
"soundfile>=0.13.1",
"spacy>=3.8.7",
"en-core-web-sm@https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl",
"static-ffmpeg>=2.13", "static-ffmpeg>=2.13",
"tavily-python>=0.3.2", "tavily-python>=0.3.2",
"unstructured-client>=0.30.0", "unstructured-client>=0.30.0",

5122
surfsense_backend/uv.lock generated

File diff suppressed because it is too large Load diff