# Additions to server/tts.py, reconstructed from the collapsed patch.
# These definitions rely on names the module already provides:
# io, time, np, Optional, Path, logger, ChatterboxTTS and TTSConfig.


class VeniceKokoroTTS:
    """
    Venice Kokoro TTS provider.

    Sends text to the Venice.ai speech endpoint and returns the decoded audio
    as a mono float32 array at 16 kHz. Any failure while requesting or
    decoding audio is logged and replaced by silence, so callers always
    receive an array.
    """

    # Sample rate (Hz) of every array this engine returns.
    OUTPUT_SAMPLE_RATE = 16000

    # Seconds to wait on the Venice API before a request is abandoned.
    REQUEST_TIMEOUT_S = 30.0

    def __init__(
        self,
        api_key: str,
        voice: str = "am_liam",
        base_url: str = "https://api.venice.ai/api/v1",
        model: str = "kokoro",
    ):
        """
        Initialize Venice Kokoro TTS engine.

        Args:
            api_key: Venice.ai API key
            voice: Voice name (default: "am_liam")
            base_url: Venice.ai API base URL (a trailing "/" is stripped)
            model: Model id sent with every request (default: "kokoro")
        """
        self.api_key = api_key
        self.voice = voice
        self.base_url = base_url.rstrip("/")
        # NOTE(review): Venice's API reference appears to list the Kokoro
        # model id as "tts-kokoro" - confirm and adjust the default if so.
        self.model = model

        if not api_key:
            logger.warning("Venice Kokoro TTS engine created without an API key")

        logger.info(f"Initialized Venice Kokoro TTS engine (voice: {voice})")

    @staticmethod
    def _silence(seconds: float, sample_rate: int = 16000) -> np.ndarray:
        """Return `seconds` of float32 silence at `sample_rate` Hz."""
        return np.zeros(int(seconds * sample_rate), dtype=np.float32)

    @staticmethod
    def _decode_wav(audio_bytes: bytes, target_sr: int = 16000) -> np.ndarray:
        """
        Decode WAV bytes into a mono float32 array at `target_sr` Hz.

        Args:
            audio_bytes: Complete WAV file contents
            target_sr: Sample rate of the returned array

        Returns:
            1-D float32 array; integer PCM input is scaled into [-1.0, 1.0)

        Raises:
            ValueError: If the bytes are not a WAV file scipy can read
        """
        from scipy import signal as scipy_signal
        from scipy.io import wavfile

        sr, audio = wavfile.read(io.BytesIO(audio_bytes))

        # Scale by the actual source dtype; a fixed 32768 divisor is only
        # correct for 16-bit PCM.
        if np.issubdtype(audio.dtype, np.floating):
            audio = audio.astype(np.float32)
        elif audio.dtype == np.uint8:
            # 8-bit WAV is unsigned with the zero level at 128.
            audio = (audio.astype(np.float32) - 128.0) / 128.0
        else:
            full_scale = float(np.iinfo(audio.dtype).max) + 1.0
            audio = (audio / full_scale).astype(np.float32)

        # Down-mix before resampling so only one channel is resampled.
        if audio.ndim > 1:
            audio = audio.mean(axis=1)

        if sr != target_sr:
            target_samples = int(len(audio) * target_sr / sr)
            if target_samples > 0:
                audio = scipy_signal.resample(audio, target_samples)
            else:
                audio = np.zeros(0, dtype=np.float32)

        return np.ascontiguousarray(audio, dtype=np.float32)

    async def generate_async(
        self,
        text: str,
        voice_ref_path: Optional[Path] = None,
        emotion_exaggeration: Optional[float] = None,
    ) -> np.ndarray:
        """
        Generate speech from text.

        Args:
            text: Text to synthesize
            voice_ref_path: Not used by Venice (reserved for interface compatibility)
            emotion_exaggeration: Not used by Venice (reserved for interface compatibility)

        Returns:
            Audio array (float32, 16kHz mono). Empty or missing text yields
            1 s of silence; a failed request or decode yields 2 s of silence.
        """
        start_time = time.time()

        # Check for empty input before slicing `text` for the log line, so
        # None or blank input returns silence instead of raising.
        if not text or not text.strip():
            logger.warning("Empty text, returning silence")
            return self._silence(1.0, self.OUTPUT_SAMPLE_RATE)

        logger.info(f"Generating TTS via Venice: '{text[:50]}...'")

        # Imported here so only the Venice provider needs httpx installed.
        # Kept outside the try block so a missing dependency is reported
        # instead of being turned into silence.
        import httpx

        try:
            payload = {
                "model": self.model,
                "input": text.strip(),
                "voice": self.voice,
                # The decoder only understands WAV, so request it explicitly
                # rather than relying on the service default.
                "response_format": "wav",
            }

            async with httpx.AsyncClient(timeout=self.REQUEST_TIMEOUT_S) as client:
                response = await client.post(
                    f"{self.base_url}/audio/speech",
                    json=payload,
                    headers={
                        "Authorization": f"Bearer {self.api_key}",
                        "Content-Type": "application/json",
                    },
                )
                response.raise_for_status()

            audio = self._decode_wav(response.content, self.OUTPUT_SAMPLE_RATE)
            if audio.size == 0:
                # Also keeps the RTF division below well defined.
                raise ValueError("Venice returned no audio samples")

            processing_time = time.time() - start_time
            duration = len(audio) / self.OUTPUT_SAMPLE_RATE
            logger.info(
                f"Generated {duration:.2f}s audio via Venice in {processing_time:.2f}s "
                f"(RTF: {processing_time / duration:.2f})"
            )

            return audio

        except Exception as e:
            # Deliberate best-effort: log the failure and hand back silence
            # (16kHz processing format) rather than propagating the error.
            logger.error(f"Venice TTS generation error: {e}")
            return self._silence(2.0, self.OUTPUT_SAMPLE_RATE)

    async def close(self):
        """Cleanup resources."""
        # Nothing to close: each request opens and closes its own HTTP client.
        pass


def create_tts_engine(provider: str, config: dict) -> "ChatterboxTTS | VeniceKokoroTTS":
    """
    Create TTS engine based on provider.

    Args:
        provider: Provider name ("chatterbox" or "venice"); matched
            case-insensitively, ignoring surrounding whitespace
        config: Configuration dictionary. Venice reads "api_key", "voice",
            "base_url" and "model"; Chatterbox reads "device" and
            "sample_rate". None is treated as an empty dictionary.

    Returns:
        TTS engine instance. Unknown providers fall back to Chatterbox
        after a warning is logged.
    """
    config = config or {}
    provider_key = (provider or "").strip().lower()

    if provider_key == "venice":
        return VeniceKokoroTTS(
            api_key=config.get("api_key", ""),
            voice=config.get("voice", "am_liam"),
            base_url=config.get("base_url", "https://api.venice.ai/api/v1"),
            model=config.get("model", "kokoro"),
        )

    if provider_key != "chatterbox":
        logger.warning(
            f"Unknown TTS provider {provider!r}; falling back to Chatterbox"
        )

    # Default to Chatterbox
    return ChatterboxTTS(
        config=TTSConfig(
            device=config.get("device", "cuda"),
            sample_rate=config.get("sample_rate", 24000),
        ),
        voice_references={},
    )