- Fix app.py: @app.get -> @app.websocket for /ws/voice route (was returning 403) - Fix app.py: create static_dir before mounting it (AttributeError on startup) - Fix voice.html: AudioWorkletNode constructor (was AudioWorkletProcessor) - Fix voice.html: use ScriptProcessor directly (more reliable) - Fix voice.html: send Float32 directly (server expects float32, was sending Int16) - Fix voice.html: auto-detect ws/wss protocol from page URL - Add Caddy reverse proxy keepalive pings every 15s to prevent timeout - Add detailed message type logging in WebSocket receive loop - Strip Jarvis/Sage personas, rename bot to MoltMic - Add /moltmic voice slash command for portal URL - Update portal URL to https://voice.jezzahehn.com
457 lines
14 KiB
Python
457 lines
14 KiB
Python
"""FastAPI Server - OpenAI-compatible TTS/STT API.
|
|
|
|
Provides HTTP endpoints for:
|
|
- Text-to-Speech (OpenAI /v1/audio/speech compatible)
|
|
- Speech-to-Text (OpenAI /v1/audio/transcriptions compatible)
|
|
- Health checks and status
|
|
- WebSocket voice endpoint for browser-based speech
|
|
|
|
Shares STT and TTS engines with Discord bot for efficiency.
|
|
"""
|
|
|
|
import asyncio
|
|
import io
|
|
import tempfile
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Literal, Optional
|
|
|
|
import numpy as np
|
|
import soundfile as sf
|
|
from fastapi import FastAPI, File, Form, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import Response, StreamingResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
from pydantic import BaseModel, Field
|
|
|
|
from server.stt import FasterWhisperSTT, STTTranscriber
|
|
from server.tts import ChatterboxTTS, TTSSynthesizer
|
|
from server.voice_ws import handle_voice_websocket, create_session_id
|
|
from utils.logging import get_logger
|
|
|
|
# Module-level logger, obtained from the project's logging helper and keyed by
# this module's import name.
logger = get_logger(__name__)
|
|
|
|
# ============================================================================
|
|
# Request/Response Models
|
|
# ============================================================================
|
|
|
|
|
|
class TTSRequest(BaseModel):
    """OpenAI-compatible TTS request."""

    # Accepted for OpenAI client compatibility; its description states the
    # value is ignored in favour of the configured model.
    model: str = Field(
        description="TTS model to use (ignored, using configured model)",
        default="chatterbox",
    )

    # Required text to speak, capped at 4000 characters by validation.
    input: str = Field(..., max_length=4000, description="Text to synthesize")

    # Required voice name.
    voice: str = Field(
        ...,
        description="Voice to use (jarvis, sage, or configured voices)",
    )

    # Container/encoding of the returned audio; defaults to WAV.
    response_format: Literal["pcm", "wav", "mp3"] = Field(
        description="Audio format",
        default="wav",
    )

    # Validated to the range 0.25..4.0; the description marks it as
    # "not supported".
    speed: float = Field(
        description="Playback speed (not supported)",
        default=1.0,
        ge=0.25,
        le=4.0,
    )
|
|
|
|
|
|
class TranscriptionResponse(BaseModel):
    """OpenAI-compatible transcription response."""

    # The transcribed text; the only field in the response body.
    text: str
|
|
|
|
|
|
class HealthResponse(BaseModel):
    """Health check response."""

    # Overall state string (the server in this file emits "ok" or "degraded").
    status: str

    # Per-engine information, keyed by "tts" and "stt".
    models: dict

    # GPU summary, keyed by "available" and "memory_gb".
    gpu: dict

    # Seconds elapsed since the API server object was created.
    uptime: float
|
|
|
|
|
|
# ============================================================================
|
|
# FastAPI Application
|
|
# ============================================================================
|
|
|
|
|
|
class VoiceAPIServer:
    """
    Voice API server.

    Wraps a FastAPI application that exposes OpenAI-compatible TTS and STT
    endpoints, a health check, a static voice portal page and a WebSocket
    voice endpoint. The TTS synthesizer and STT transcriber are injected by
    the caller (they are not created here), which lets them be shared with
    other consumers such as the Discord bot.

    Attributes:
        tts_synthesizer: Injected TTS synthesizer.
        stt_transcriber: Injected STT transcriber.
        start_time: ``time.time()`` captured at construction; basis for uptime.
        app: The FastAPI application with middleware, static mount and routes.
        static_dir: Directory served under ``/static`` and searched for
            ``voice.html``.
        total_tts_requests: Count of successful TTS requests.
        total_stt_requests: Count of successful STT requests.
        total_errors: Count of TTS/STT requests that ended in an HTTP error.
    """

    def __init__(
        self,
        tts_synthesizer: TTSSynthesizer,
        stt_transcriber: STTTranscriber,
    ):
        """
        Initialize API server.

        Creates the FastAPI app, installs CORS middleware, ensures the static
        directory exists, mounts it and registers all routes.

        Args:
            tts_synthesizer: TTS synthesizer instance
            stt_transcriber: STT transcriber instance
        """
        self.tts_synthesizer = tts_synthesizer
        self.stt_transcriber = stt_transcriber
        self.start_time = time.time()

        # Stats. Initialised before any route is registered so that no handler
        # can ever observe an instance without its counters.
        self.total_tts_requests = 0
        self.total_stt_requests = 0
        self.total_errors = 0

        # Create FastAPI app
        self.app = FastAPI(
            title="Jarvis Voice API",
            description="OpenAI-compatible TTS/STT API",
            version="1.0.0",
        )

        # Add CORS middleware.
        # NOTE(review): wildcard origins together with allow_credentials=True
        # is very permissive; tighten allow_origins before exposing this
        # server beyond a trusted network.
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],  # Configure based on security needs
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        # The directory must exist before StaticFiles is mounted on it.
        # NOTE(review): the path is relative to the process working directory.
        self.static_dir = Path("server/static")
        self.static_dir.mkdir(parents=True, exist_ok=True)

        # Mount static files
        self.app.mount(
            "/static",
            StaticFiles(directory=str(self.static_dir)),
            name="static",
        )

        # Register routes
        self._register_routes()

        logger.info("Voice API server initialized")

    def _register_routes(self) -> None:
        """
        Register API routes on ``self.app``.

        Each handler is a thin closure that delegates to a method on this
        instance (or to ``handle_voice_websocket`` for the WebSocket route).
        """

        @self.app.get("/health", response_model=HealthResponse)
        async def health_check():
            """Health check endpoint."""
            return await self._health_check()

        @self.app.get("/voice")
        async def get_voice_page():
            """Serve voice portal HTML page."""
            static_file = self.static_dir / "voice.html"
            # is_file() (rather than exists()) so a directory with that name
            # yields a 404 instead of an unhandled read error.
            if not static_file.is_file():
                raise HTTPException(status_code=404, detail="Voice page not found")
            # Decode explicitly as UTF-8: the default encoding of read_text()
            # depends on the host locale.
            return Response(
                content=static_file.read_text(encoding="utf-8"),
                media_type="text/html",
            )

        @self.app.websocket("/ws/voice/{session_id}")
        async def voice_websocket(session_id: str, websocket: WebSocket):
            """WebSocket endpoint for voice session."""
            await handle_voice_websocket(websocket, session_id)

        @self.app.post("/v1/audio/speech")
        async def create_speech(request: TTSRequest):
            """
            OpenAI-compatible TTS endpoint.

            Generate speech from text.
            """
            return await self._create_speech(request)

        @self.app.post(
            "/v1/audio/transcriptions", response_model=TranscriptionResponse
        )
        async def create_transcription(
            file: UploadFile = File(...),
            model: str = Form(default="whisper-1"),
            language: Optional[str] = Form(default=None),
            prompt: Optional[str] = Form(default=None),
            response_format: str = Form(default="json"),
            temperature: float = Form(default=0.0),
        ):
            """
            OpenAI-compatible STT endpoint.

            Transcribe audio to text.
            """
            return await self._create_transcription(
                file=file,
                model=model,
                language=language,
                prompt=prompt,
                response_format=response_format,
                temperature=temperature,
            )

        @self.app.get("/")
        async def root():
            """Root endpoint."""
            return {
                "name": "Jarvis Voice API",
                "version": "1.0.0",
                "endpoints": {
                    "health": "/health",
                    "tts": "/v1/audio/speech",
                    "stt": "/v1/audio/transcriptions",
                },
            }

    async def _health_check(self) -> HealthResponse:
        """
        Health check.

        Reports the device each engine is configured for, whether CUDA is
        available and the total memory of GPU 0 in GB. Any exception while
        gathering this (including a failed ``torch`` import) is logged and
        turned into a "degraded" response rather than an HTTP error.

        Returns:
            Health status
        """
        try:
            # Imported lazily so the module can be loaded without torch.
            import torch

            gpu_available = torch.cuda.is_available()
            gpu_memory = (
                torch.cuda.get_device_properties(0).total_memory / 1e9
                if gpu_available
                else 0
            )

            return HealthResponse(
                status="ok",
                models={
                    "tts": self.tts_synthesizer.engine.config.device,
                    "stt": self.stt_transcriber.engine.device,
                },
                gpu={
                    "available": gpu_available,
                    "memory_gb": round(gpu_memory, 2),
                },
                uptime=time.time() - self.start_time,
            )

        except Exception as e:
            logger.error(f"Health check failed: {e}")
            return HealthResponse(
                status="degraded",
                models={"tts": "unknown", "stt": "unknown"},
                gpu={"available": False, "memory_gb": 0},
                uptime=time.time() - self.start_time,
            )

    async def _create_speech(self, request: TTSRequest) -> Response:
        """
        Generate speech from text.

        Args:
            request: TTS request

        Returns:
            Audio response whose media type matches the bytes actually sent.

        Raises:
            HTTPException: 400 when the requested voice is not in the
                synthesizer's voice map; 500 when synthesis returns ``None``
                or any unexpected error occurs.
        """
        try:
            logger.info(
                f"TTS request: voice={request.voice}, "
                f"format={request.response_format}, "
                f"text='{request.input[:50]}...'"
            )

            # Validate voice (lookup is case-insensitive on the request side)
            voice_lower = request.voice.lower()
            if voice_lower not in self.tts_synthesizer.voice_map:
                available_voices = ", ".join(
                    self.tts_synthesizer.voice_map.keys()
                )
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid voice '{request.voice}'. "
                    f"Available: {available_voices}",
                )

            # Generate audio
            audio = await self.tts_synthesizer.synthesize(
                agent=voice_lower, text=request.input
            )

            if audio is None:
                raise HTTPException(
                    status_code=500, detail="TTS generation failed"
                )

            # Convert to requested format
            audio_bytes = self._convert_audio(
                audio=audio,
                sample_rate=self.tts_synthesizer.engine.config.sample_rate,
                format=request.response_format,
            )

            # Determine content type
            content_type = {
                "pcm": "audio/pcm",
                "wav": "audio/wav",
                "mp3": "audio/mpeg",
            }[request.response_format]

            # The converter falls back to WAV when it cannot encode MP3.
            # Label the payload by what it really is so clients do not try to
            # decode RIFF/WAV data as MPEG audio.
            if request.response_format == "mp3" and audio_bytes[:4] == b"RIFF":
                content_type = "audio/wav"

            self.total_tts_requests += 1

            return Response(content=audio_bytes, media_type=content_type)

        except HTTPException:
            self.total_errors += 1
            raise

        except Exception as e:
            logger.error(f"TTS error: {e}", exc_info=True)
            self.total_errors += 1
            raise HTTPException(status_code=500, detail=str(e)) from e

    async def _create_transcription(
        self,
        file: UploadFile,
        model: str,
        language: Optional[str],
        prompt: Optional[str],
        response_format: str,
        temperature: float,
    ) -> TranscriptionResponse:
        """
        Transcribe audio to text.

        The upload is decoded with soundfile, down-mixed to mono, converted to
        float32 and resampled to 16 kHz before being handed to the transcriber.

        Args:
            file: Audio file
            model: Model name (ignored)
            language: Language hint (currently not forwarded to the transcriber)
            prompt: Prompt for context (currently not forwarded)
            response_format: Response format (json only supported)
            temperature: Temperature (ignored)

        Returns:
            Transcription response

        Raises:
            HTTPException: 400 when the upload is empty, cannot be decoded or
                contains no samples; 500 when the transcriber yields no text
                or any unexpected error occurs.
        """
        try:
            logger.info(
                f"STT request: filename={file.filename}, "
                f"content_type={file.content_type}"
            )

            # Read audio file
            audio_bytes = await file.read()
            if not audio_bytes:
                raise HTTPException(
                    status_code=400, detail="Uploaded audio file is empty"
                )

            # Load audio with soundfile. A payload it cannot parse is a client
            # error, not a server fault.
            try:
                audio, sample_rate = sf.read(io.BytesIO(audio_bytes))
            except (RuntimeError, ValueError) as e:
                raise HTTPException(
                    status_code=400,
                    detail=f"Could not decode audio file: {e}",
                ) from e

            # Convert to mono if stereo
            if len(audio.shape) > 1:
                audio = audio.mean(axis=1)

            # Nothing to resample or transcribe.
            if len(audio) == 0:
                raise HTTPException(
                    status_code=400, detail="Audio file contains no samples"
                )

            # Resample if needed (STT expects 16kHz)
            if sample_rate != 16000:
                from scipy import signal

                audio = signal.resample(
                    audio, int(len(audio) * 16000 / sample_rate)
                )

            # Convert to float32 last, so the dtype handed to the transcriber
            # does not depend on what the resampler returned.
            audio = np.asarray(audio, dtype=np.float32)

            # Transcribe
            result = await self.stt_transcriber.transcribe_async(audio)

            if not result or not result.text:
                raise HTTPException(
                    status_code=500, detail="Transcription failed"
                )

            self.total_stt_requests += 1

            return TranscriptionResponse(text=result.text)

        except HTTPException:
            self.total_errors += 1
            raise

        except Exception as e:
            logger.error(f"STT error: {e}", exc_info=True)
            self.total_errors += 1
            raise HTTPException(status_code=500, detail=str(e)) from e

    def _convert_audio(
        self, audio: np.ndarray, sample_rate: int, format: str
    ) -> bytes:
        """
        Convert audio to requested format.

        Args:
            audio: Audio array (float32), nominally in the range [-1.0, 1.0]
            sample_rate: Sample rate
            format: Target format (pcm, wav, mp3)

        Returns:
            Audio bytes. "pcm" is headerless 16-bit signed samples; "wav" is a
            complete WAV file; "mp3" currently also yields a WAV file because
            no MP3 encoder is wired in.

        Raises:
            ValueError: If ``format`` is not one of the supported names.
        """
        if format == "pcm":
            # Clip before scaling: without it, samples beyond +/-1.0 overflow
            # int16 and wrap around, producing loud clicks.
            clipped = np.clip(audio, -1.0, 1.0)
            return (clipped * 32767).astype(np.int16).tobytes()

        if format in ("wav", "mp3"):
            if format == "mp3":
                # MP3 encoding requires an additional library (pydub, ffmpeg).
                # For now, return WAV and document MP3 needs ffmpeg.
                logger.warning("MP3 format not fully supported, returning WAV")

            buffer = io.BytesIO()
            sf.write(buffer, audio, sample_rate, format="WAV")
            return buffer.getvalue()

        raise ValueError(f"Unsupported format: {format}")

    def get_stats(self) -> dict:
        """
        Get API server statistics.

        Returns:
            Statistics dictionary with uptime in seconds, the request/error
            counters and the ``get_stats()`` output of both engines.
        """
        return {
            "uptime": time.time() - self.start_time,
            "total_tts_requests": self.total_tts_requests,
            "total_stt_requests": self.total_stt_requests,
            "total_errors": self.total_errors,
            "tts_stats": self.tts_synthesizer.get_stats(),
            "stt_stats": self.stt_transcriber.get_stats(),
        }
|
|
|
|
|
|
# ============================================================================
|
|
# Factory Function
|
|
# ============================================================================
|
|
|
|
|
|
def create_api_server(
    tts_synthesizer: TTSSynthesizer,
    stt_transcriber: STTTranscriber,
) -> VoiceAPIServer:
    """
    Build a ``VoiceAPIServer`` around the given engines.

    This is a thin convenience factory: it forwards both arguments to the
    ``VoiceAPIServer`` constructor unchanged and adds no configuration.

    Args:
        tts_synthesizer: TTS synthesizer instance
        stt_transcriber: STT transcriber instance

    Returns:
        VoiceAPIServer instance
    """
    server = VoiceAPIServer(tts_synthesizer, stt_transcriber)
    return server
|