Preloading audio for Edge TTS and other fixes (#138)

* Blacklist some low-quality web tts voices

* Dropdown layout tweaks

* Handle no audio data received in Edge TTS

* Preloading audio for Edge TTS
This commit is contained in:
Huang Xin 2025-01-10 14:11:09 +01:00 committed by GitHub
parent 7402141237
commit 00003a9415
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 210 additions and 31 deletions

View file

@ -62,7 +62,7 @@ const FontDropdown: React.FC<DropdownProps> = ({
>
{moreOptions.map((option) => (
<li key={option} onClick={() => onSelect(option)}>
<div className='flex items-center px-0'>
<div className='flex items-center px-2'>
<span style={{ minWidth: '20px' }}>
{selected === option && <MdCheck size={20} className='text-base-content' />}
</span>

View file

@ -54,6 +54,11 @@ const TTSControl = () => {
setBookKey(bookKey);
if (ttsControllerRef.current) {
ttsControllerRef.current.stop();
ttsControllerRef.current = null;
}
try {
const ttsController = new TTSController(view);
await ttsController.init();
@ -94,6 +99,8 @@ const TTSControl = () => {
setIsPlaying(false);
setIsPaused(true);
} else if (isPaused) {
// resume playback for the forward/backward/setvoice-paused states
// (changing the rate does not pause the TTS)
if (ttsController.state === 'paused') {
ttsController.resume();
} else {
@ -163,6 +170,14 @@ const TTSControl = () => {
return [];
};
const handleGetVoiceId = () => {
const ttsController = ttsControllerRef.current;
if (ttsController) {
return ttsController.getVoiceId();
}
return '';
};
const updatePanelPosition = () => {
if (iconRef.current) {
const rect = iconRef.current.getBoundingClientRect();
@ -228,6 +243,7 @@ const TTSControl = () => {
onSetRate={handleSetRate}
onGetVoices={handleGetVoices}
onSetVoice={handleSetVoice}
onGetVoiceId={handleGetVoiceId}
/>
</Popup>
)}

View file

@ -18,6 +18,7 @@ type TTSPanelProps = {
onSetRate: (rate: number) => void;
onGetVoices: (lang: string) => Promise<TTSVoice[]>;
onSetVoice: (voice: string) => void;
onGetVoiceId: () => string;
};
const TTSPanel = ({
@ -31,6 +32,7 @@ const TTSPanel = ({
onSetRate,
onGetVoices,
onSetVoice,
onGetVoiceId,
}: TTSPanelProps) => {
const _ = useTranslation();
const { getViewSettings, setViewSettings } = useReaderStore();
@ -58,6 +60,12 @@ const TTSPanel = ({
setViewSettings(bookKey, viewSettings);
};
useEffect(() => {
const voiceId = onGetVoiceId();
setSelectedVoice(voiceId);
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);
useEffect(() => {
const fetchVoices = async () => {
const voices = await onGetVoices(ttsLang);
@ -131,7 +139,7 @@ const TTSPanel = ({
key={`${index}-${voice.id}`}
onClick={() => !voice.disabled && handleSelectVoice(voice.id)}
>
<div className='flex items-center px-0'>
<div className='flex items-center px-2'>
<span style={{ minWidth: '20px' }}>
{selectedVoice === voice.id && (
<MdCheck size={20} className='text-base-content' />

View file

@ -1,4 +1,6 @@
import { md5 } from 'js-md5';
import { randomMd5 } from '@/utils/misc';
import { LRUCache } from '@/utils/lru';
const EDGE_SPEECH_URL =
'wss://speech.platform.bing.com/consumer/speech/synthesize/readaloud/edge/v1';
@ -46,18 +48,25 @@ const genVoiceList = (voices: Record<string, string[]>) => {
};
export interface EdgeTTSPayload {
lang: string;
text: string;
voice: string;
rate: number;
pitch: number;
}
// Stable cache key for a synthesis request: md5 over the serialized payload.
const hashPayload = (payload: EdgeTTSPayload): string => md5(JSON.stringify(payload));
export class EdgeSpeechTTS {
static voices = genVoiceList(EDGE_TTS_VOICES);
private static audioCache = new LRUCache<string, AudioBuffer>(200);
constructor() {}
async #fetchEdgeSpeechWs({ text, voice, rate }: EdgeTTSPayload): Promise<Response> {
async #fetchEdgeSpeechWs({ lang, text, voice, rate }: EdgeTTSPayload): Promise<Response> {
const connectId = randomMd5();
const url = `${EDGE_SPEECH_URL}?ConnectionId=${connectId}&TrustedClientToken=${EDGE_API_TOKEN}`;
const date = new Date().toString();
@ -83,9 +92,9 @@ export class EdgeSpeechTTS {
},
});
const genSSML = (text: string, voice: string, rate: number) => {
const genSSML = (lang: string, text: string, voice: string, rate: number) => {
return `
<speak version="1.0" xml:lang="en-US">
<speak version="1.0" xml:lang="${lang}">
<voice name="${voice}">
<prosody rate="${rate}">
${text}
@ -126,7 +135,7 @@ export class EdgeSpeechTTS {
return { headers, body };
};
const ssml = genSSML(text, voice, rate);
const ssml = genSSML(lang, text, voice, rate);
const content = genSendContent(contentHeaders, ssml);
const config = genSendContent(configHeaders, configContent);
@ -177,9 +186,19 @@ export class EdgeSpeechTTS {
}
async createAudio(payload: EdgeTTSPayload): Promise<AudioBuffer> {
const res = await this.create(payload);
const arrayBuffer = await res.arrayBuffer();
const audioContext = new AudioContext();
return await audioContext.decodeAudioData(arrayBuffer.slice(0));
const cacheKey = hashPayload(payload);
if (EdgeSpeechTTS.audioCache.has(cacheKey)) {
return EdgeSpeechTTS.audioCache.get(cacheKey)!;
}
try {
const res = await this.create(payload);
const arrayBuffer = await res.arrayBuffer();
const audioContext = new AudioContext();
const audioBuffer = await audioContext.decodeAudioData(arrayBuffer.slice(0));
EdgeSpeechTTS.audioCache.set(cacheKey, audioBuffer);
return audioBuffer;
} catch (error) {
throw error;
}
}
}

View file

@ -1,6 +1,6 @@
import { getUserLocale } from '@/utils/misc';
import { TTSClient, TTSMessageEvent, TTSVoice } from './TTSClient';
import { EdgeSpeechTTS } from '@/libs/edgeTTS';
import { EdgeSpeechTTS, EdgeTTSPayload } from '@/libs/edgeTTS';
import { parseSSMLLang, parseSSMLMarks } from '@/utils/ssml';
import { TTSGranularity } from '@/types/view';
@ -27,6 +27,7 @@ export class EdgeTTSClient implements TTSClient {
this.#voices = EdgeSpeechTTS.voices;
try {
await this.#edgeTTS.create({
lang: 'en',
text: 'test',
voice: 'en-US-AriaNeural',
rate: 1.0,
@ -39,29 +40,39 @@ export class EdgeTTSClient implements TTSClient {
return this.available;
}
// Build the Edge TTS synthesis payload from the current rate/pitch settings.
// Typed return value instead of an `as EdgeTTSPayload` assertion so the
// compiler actually checks the object shape.
getPayload = (lang: string, text: string, voiceId: string): EdgeTTSPayload => ({
  lang,
  text,
  voice: voiceId,
  rate: this.#rate,
  pitch: this.#pitch,
});
async *speak(ssml: string): AsyncGenerator<TTSMessageEvent> {
const { marks } = parseSSMLMarks(ssml);
const lang = parseSSMLLang(ssml) || 'en';
let voiceId = 'en-US-AriaNeural';
if (!this.#voice) {
const voices = await this.getVoices(lang);
this.#voice = voices[0] ? voices[0] : this.#voices.find((v) => v.id === voiceId) || null;
}
if (this.#voice) {
voiceId = this.#voice.id;
} else {
const voices = await this.getVoices(lang);
voiceId = voices[0]?.id || voiceId;
}
this.stopInternal();
// Preload audio for the remaining marks so longer SSML plays without gaps
if (marks.length > 1) {
for (const mark of marks.slice(1)) {
this.#edgeTTS.createAudio(this.getPayload(lang, mark.text, voiceId)).catch((error) => {
console.warn('Error preloading mark:', mark, error);
});
}
}
for (const mark of marks) {
try {
this.#audioBuffer = await this.#edgeTTS.createAudio({
text: mark.text.replace(/\r?\n/g, ''),
voice: voiceId,
rate: this.#rate,
pitch: this.#pitch,
});
this.#audioBuffer = await this.#edgeTTS.createAudio(
this.getPayload(lang, mark.text, voiceId),
);
this.#audioContext = new AudioContext();
this.#sourceNode = this.#audioContext.createBufferSource();
this.#sourceNode.buffer = this.#audioBuffer;
@ -89,10 +100,16 @@ export class EdgeTTSClient implements TTSClient {
this.#startedAt = this.#audioContext.currentTime;
});
yield result;
if (result.code === 'error') {
break;
}
} catch (error) {
if (error instanceof Error && error.message === 'No audio data received.') {
console.warn('No audio data received for:', mark.text);
yield {
code: 'end',
message: `Chunk finished: ${mark.name}`,
};
continue;
}
console.log('Error:', error);
yield {
code: 'error',
message: error instanceof Error ? error.message : String(error),
@ -175,4 +192,8 @@ export class EdgeTTSClient implements TTSClient {
getGranularities(): TTSGranularity[] {
return ['sentence'];
}
// Id of the active Edge voice; empty string when no voice is selected yet.
getVoiceId(): string {
  return this.#voice ? this.#voice.id : '';
}
}

View file

@ -27,4 +27,5 @@ export interface TTSClient {
getAllVoices(): Promise<TTSVoice[]>;
getVoices(lang: string): Promise<TTSVoice[]>;
getGranularities(): TTSGranularity[];
getVoiceId(): string;
}

View file

@ -3,7 +3,13 @@ import { TTSClient, TTSMessageCode, TTSVoice } from './TTSClient';
import { WebSpeechClient } from './WebSpeechClient';
import { EdgeTTSClient } from './EdgeTTSClient';
type TTSState = 'stopped' | 'playing' | 'paused' | 'backward-paused' | 'forward-paused';
type TTSState =
| 'stopped'
| 'playing'
| 'paused'
| 'backward-paused'
| 'forward-paused'
| 'setvoice-paused';
export class TTSController extends EventTarget {
state: TTSState = 'stopped';
@ -53,7 +59,7 @@ export class TTSController extends EventTarget {
if (!ssml) {
this.#nossmlCnt++;
// FIXME: in case we are at the end of the book, need a better way to handle this
if (this.#nossmlCnt < 10) {
if (this.#nossmlCnt < 10 && this.state === 'playing') {
await this.view.next(1);
this.forward();
}
@ -71,7 +77,7 @@ export class TTSController extends EventTarget {
lastCode = code;
}
if (lastCode === 'end') {
if (lastCode === 'end' && this.state === 'playing') {
this.forward();
}
}
@ -147,6 +153,7 @@ export class TTSController extends EventTarget {
}
async setVoice(voiceId: string) {
this.state = 'setvoice-paused';
this.ttsClient.stop();
if (this.ttsEdgeVoices.find((voice) => voice.id === voiceId && !voice.disabled)) {
this.ttsClient = this.ttsEdgeClient;
@ -156,6 +163,10 @@ export class TTSController extends EventTarget {
await this.ttsClient.setVoice(voiceId);
}
// Delegate to whichever TTS client is currently active (edge or web-speech);
// returns '' when that client has no voice selected.
getVoiceId() {
return this.ttsClient.getVoiceId();
}
error(e: unknown) {
console.error(e);
this.state = 'stopped';

View file

@ -4,6 +4,36 @@ import { AsyncQueue } from '@/utils/queue';
import { findSSMLMark, parseSSMLLang, parseSSMLMarks } from '@/utils/ssml';
import { TTSGranularity } from '@/types/view';
const BLACKLISTED_VOICES = [
'Albert',
'Bad News',
'Bahh',
'Bells',
'Boing',
'Bubbles',
'Cellos',
'Eddy',
'Flo',
'Fred',
'Good News',
'Grandma',
'Grandpa',
'Jester',
'Junior',
'Kathy',
'Organ',
'Ralph',
'Reed',
'Rocko',
'Sandy',
'Shelley',
'Superstar',
'Trinoids',
'Whisper',
'Wobble',
'Zarvox',
];
interface TTSBoundaryEvent {
type: 'boundary' | 'end' | 'error';
speaking: boolean;
@ -176,6 +206,12 @@ export class WebSpeechClient implements TTSClient {
}
async *speak(ssml: string): AsyncGenerator<TTSMessageEvent> {
const lang = parseSSMLLang(ssml) || 'en';
if (!this.#voice) {
const voices = await this.getVoices(lang);
const voiceId = voices[0]?.id ?? '';
this.#voice = this.#voices.find((v) => v.voiceURI === voiceId) || null;
}
for await (const ev of speakWithMarks(
ssml,
() => this.#rate,
@ -240,11 +276,15 @@ export class WebSpeechClient implements TTSClient {
async getVoices(lang: string) {
const locale = lang === 'en' ? getUserLocale(lang) || lang : lang;
const isValidVoice = (id: string) => {
return !id.includes('com.apple') || id.includes('com.apple.voice');
return !id.includes('com.apple') || id.includes('com.apple.voice.compact');
};
// Exclude novelty/low-quality system voices (name matches any blacklist entry).
// `every(not ...)` instead of the `some(...) === false` anti-idiom.
const isNotBlacklisted = (voice: SpeechSynthesisVoice) => {
  return BLACKLISTED_VOICES.every((name) => !voice.name.includes(name));
};
const filteredVoices = this.#voices
.filter((voice) => voice.lang.startsWith(locale))
.filter((voice) => isValidVoice(voice.voiceURI || ''));
.filter((voice) => isValidVoice(voice.voiceURI || ''))
.filter(isNotBlacklisted);
const voices = filteredVoices.map((voice) => {
return { id: voice.voiceURI, name: voice.name, lang: voice.lang } as TTSVoice;
});
@ -259,4 +299,8 @@ export class WebSpeechClient implements TTSClient {
// in the middle of speech is not possible for different granularities
return ['sentence'];
}
// URI of the active web-speech voice; empty string when none is selected.
getVoiceId(): string {
  const voice = this.#voice;
  return voice ? voice.voiceURI : '';
}
}

View file

@ -0,0 +1,54 @@
/**
 * Least-recently-used cache backed by a Map, relying on Map's
 * insertion-order iteration: the first key is always the oldest.
 */
export class LRUCache<K, V> {
  private capacity: number;
  private map: Map<K, V>;

  /**
   * @param capacity maximum number of entries; must be greater than 0
   * @throws Error when capacity is not positive
   */
  constructor(capacity: number) {
    if (capacity <= 0) {
      throw new Error('LRUCache capacity must be greater than 0');
    }
    this.capacity = capacity;
    this.map = new Map();
  }

  /** Return the value for `key` (marking it most-recently-used), or undefined. */
  get(key: K): V | undefined {
    if (!this.map.has(key)) {
      return undefined;
    }
    const value = this.map.get(key)!;
    // Re-insert to move the key to the back (most-recently-used) position.
    this.map.delete(key);
    this.map.set(key, value);
    return value;
  }

  /** Insert or refresh `key`, evicting the least-recently-used entry when full. */
  set(key: K, value: V): void {
    if (this.map.has(key)) {
      this.map.delete(key);
    } else if (this.map.size >= this.capacity) {
      // BUGFIX: check `done` on the iterator result instead of the key's
      // truthiness — falsy keys (0, '', false, null) are valid entries and
      // must still be evicted, otherwise the cache grows past its capacity.
      const oldest = this.map.keys().next();
      if (!oldest.done) {
        this.map.delete(oldest.value);
      }
    }
    this.map.set(key, value);
  }

  has(key: K): boolean {
    return this.map.has(key);
  }

  delete(key: K): boolean {
    return this.map.delete(key);
  }

  clear(): void {
    this.map.clear();
  }

  /** Number of entries currently cached. */
  size(): number {
    return this.map.size;
  }

  /** Entries ordered most-recently-used first. */
  entries(): Array<[K, V]> {
    return Array.from(this.map).reverse();
  }
}

View file

@ -20,7 +20,12 @@ export const parseSSMLMarks = (ssml: string) => {
markTagEndIndex,
nextMarkIndex !== -1 ? nextMarkIndex : ssml.length,
);
const cleanedChunk = nextChunk.replace(/<[^>]+>/g, '').trimStart();
const cleanedChunk = nextChunk
.replace(/<[^>]+>/g, '')
.replace(/\r\n/g, ' ')
.replace(/\r/g, ' ')
.replace(/\n/g, ' ')
.trimStart();
plainText += cleanedChunk;
const offset = plainText.length - cleanedChunk.length;