openclaw-voice/server/app.py
Jezza Hehn 3450e57ca6 Fix voice portal: WebSocket routing, Caddy keepalive, audio pipeline
- Fix app.py: @app.get -> @app.websocket for /ws/voice route (was returning 403)
- Fix app.py: create static_dir before mounting it (AttributeError on startup)
- Fix voice.html: AudioWorkletNode constructor (was AudioWorkletProcessor)
- Fix voice.html: use ScriptProcessor directly (more reliable)
- Fix voice.html: send Float32 directly (server expects float32, was sending Int16)
- Fix voice.html: auto-detect ws/wss protocol from page URL
- Add Caddy reverse proxy keepalive pings every 15s to prevent timeout
- Add detailed message type logging in WebSocket receive loop
- Strip Jarvis/Sage personas, rename bot to MoltMic
- Add /moltmic voice slash command for portal URL
- Update portal URL to https://voice.jezzahehn.com
2026-04-10 04:47:31 +00:00

457 lines
14 KiB
Python

"""FastAPI Server - OpenAI-compatible TTS/STT API.
Provides HTTP endpoints for:
- Text-to-Speech (OpenAI /v1/audio/speech compatible)
- Speech-to-Text (OpenAI /v1/audio/transcriptions compatible)
- Health checks and status
- WebSocket voice endpoint for browser-based speech
Shares STT and TTS engines with Discord bot for efficiency.
"""
import asyncio
import io
import tempfile
import time
from pathlib import Path
from typing import Literal, Optional
import numpy as np
import soundfile as sf
from fastapi import FastAPI, File, Form, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response, StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field
from server.stt import FasterWhisperSTT, STTTranscriber
from server.tts import ChatterboxTTS, TTSSynthesizer
from server.voice_ws import handle_voice_websocket, create_session_id
from utils.logging import get_logger
logger = get_logger(__name__)
# ============================================================================
# Request/Response Models
# ============================================================================
class TTSRequest(BaseModel):
    """Request body for ``POST /v1/audio/speech`` (OpenAI-compatible).

    Several fields exist only so that OpenAI clients can send their usual
    payload; see the per-field comments for which ones actually take effect.
    """

    # Accepted for OpenAI compatibility only; the configured engine is used.
    model: str = Field(
        "chatterbox",
        description="TTS model to use (ignored, using configured model)",
    )
    # Required text to synthesize, limited to 4000 characters.
    input: str = Field(..., max_length=4000, description="Text to synthesize")
    # Required voice name. The speech handler lower-cases it and checks it
    # against the synthesizer's voice map.
    # NOTE(review): the description still mentions jarvis/sage although the
    # latest commit says those personas were stripped - confirm voice list.
    voice: str = Field(
        ...,
        description="Voice to use (jarvis, sage, or configured voices)",
    )
    # Requested output encoding.
    response_format: Literal["pcm", "wav", "mp3"] = Field(
        "wav",
        description="Audio format",
    )
    # Range-checked to [0.25, 4.0] but not passed to synthesis.
    speed: float = Field(
        1.0,
        description="Playback speed (not supported)",
        ge=0.25,
        le=4.0,
    )
class TranscriptionResponse(BaseModel):
    """OpenAI-compatible transcription response, serialized as ``{"text": ...}``."""

    # Text produced by the STT engine for the uploaded audio.
    text: str
class HealthResponse(BaseModel):
    """Health check response returned by ``GET /health``."""

    # "ok" when engine/GPU introspection succeeded, "degraded" otherwise.
    status: str
    # Engine device info keyed by "tts" and "stt" ("unknown" when degraded).
    models: dict
    # GPU info with keys "available" and "memory_gb".
    gpu: dict
    # Seconds elapsed since the server object was constructed.
    uptime: float
# ============================================================================
# FastAPI Application
# ============================================================================
class VoiceAPIServer:
    """
    Voice API server.

    Provides OpenAI-compatible TTS and STT endpoints, a health endpoint, a
    browser voice portal page and a WebSocket voice endpoint. Shares engines
    with the Discord bot for efficiency.
    """

    def __init__(
        self,
        tts_synthesizer: TTSSynthesizer,
        stt_transcriber: STTTranscriber,
    ):
        """
        Initialize API server.

        Args:
            tts_synthesizer: TTS synthesizer instance
            stt_transcriber: STT transcriber instance
        """
        self.tts_synthesizer = tts_synthesizer
        self.stt_transcriber = stt_transcriber
        self.start_time = time.time()

        # Stats - set up before any route is registered so handlers can
        # always rely on these attributes existing.
        self.total_tts_requests = 0
        self.total_stt_requests = 0
        self.total_errors = 0

        # NOTE(review): title still says "Jarvis" although the bot was
        # renamed to MoltMic - confirm the intended public name.
        self.app = FastAPI(
            title="Jarvis Voice API",
            description="OpenAI-compatible TTS/STT API",
            version="1.0.0",
        )

        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],  # Configure based on security needs
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        # Resolve the static directory next to this module instead of
        # relative to the process working directory, so the portal page is
        # found regardless of where the server is launched from. Create it
        # before mounting so the mount never points at a missing directory.
        self.static_dir = Path(__file__).resolve().parent / "static"
        self.static_dir.mkdir(parents=True, exist_ok=True)
        self.app.mount(
            "/static", StaticFiles(directory=str(self.static_dir)), name="static"
        )

        self._register_routes()

        logger.info("Voice API server initialized")

    def _register_routes(self) -> None:
        """Register the HTTP and WebSocket routes on ``self.app``."""

        @self.app.get("/health", response_model=HealthResponse)
        async def health_check():
            """Health check endpoint."""
            return await self._health_check()

        @self.app.get("/voice")
        async def get_voice_page():
            """Serve the voice portal HTML page, or 404 if it is missing."""
            static_file = self.static_dir / "voice.html"
            if static_file.exists():
                # Decode explicitly as UTF-8: the platform default encoding
                # is not guaranteed to be UTF-8 and could mangle non-ASCII
                # characters in the page.
                return Response(
                    content=static_file.read_text(encoding="utf-8"),
                    media_type="text/html",
                )
            raise HTTPException(status_code=404, detail="Voice page not found")

        @self.app.websocket("/ws/voice/{session_id}")
        async def voice_websocket(session_id: str, websocket: WebSocket):
            """WebSocket endpoint for a voice session."""
            await handle_voice_websocket(websocket, session_id)

        @self.app.post("/v1/audio/speech")
        async def create_speech(request: TTSRequest):
            """
            OpenAI-compatible TTS endpoint.

            Generate speech from text.
            """
            return await self._create_speech(request)

        @self.app.post(
            "/v1/audio/transcriptions", response_model=TranscriptionResponse
        )
        async def create_transcription(
            file: UploadFile = File(...),
            model: str = Form(default="whisper-1"),
            language: Optional[str] = Form(default=None),
            prompt: Optional[str] = Form(default=None),
            response_format: str = Form(default="json"),
            temperature: float = Form(default=0.0),
        ):
            """
            OpenAI-compatible STT endpoint.

            Transcribe audio to text.
            """
            return await self._create_transcription(
                file=file,
                model=model,
                language=language,
                prompt=prompt,
                response_format=response_format,
                temperature=temperature,
            )

        @self.app.get("/")
        async def root():
            """Root endpoint listing the main API routes."""
            return {
                "name": "Jarvis Voice API",
                "version": "1.0.0",
                "endpoints": {
                    "health": "/health",
                    "tts": "/v1/audio/speech",
                    "stt": "/v1/audio/transcriptions",
                },
            }

    async def _health_check(self) -> HealthResponse:
        """
        Health check.

        Returns:
            Health status. Any failure while inspecting engines or the GPU
            yields a "degraded" response instead of an error.
        """
        try:
            # torch is imported lazily so a missing/broken install only
            # degrades the health report.
            import torch

            gpu_available = torch.cuda.is_available()
            gpu_memory = (
                torch.cuda.get_device_properties(0).total_memory / 1e9
                if gpu_available
                else 0
            )

            return HealthResponse(
                status="ok",
                models={
                    "tts": self.tts_synthesizer.engine.config.device,
                    "stt": self.stt_transcriber.engine.device,
                },
                gpu={
                    "available": gpu_available,
                    "memory_gb": round(gpu_memory, 2),
                },
                uptime=time.time() - self.start_time,
            )
        except Exception as e:
            logger.error(f"Health check failed: {e}")
            return HealthResponse(
                status="degraded",
                models={"tts": "unknown", "stt": "unknown"},
                gpu={"available": False, "memory_gb": 0},
                uptime=time.time() - self.start_time,
            )

    async def _create_speech(self, request: TTSRequest) -> Response:
        """
        Generate speech from text.

        Args:
            request: TTS request

        Returns:
            Audio response whose media type matches the bytes actually
            produced.

        Raises:
            HTTPException: 400 for an unknown voice, 500 on synthesis or
                conversion failure.
        """
        try:
            logger.info(
                f"TTS request: voice={request.voice}, "
                f"format={request.response_format}, "
                f"text='{request.input[:50]}...'"
            )

            # Voice names are matched case-insensitively.
            voice_lower = request.voice.lower()
            if voice_lower not in self.tts_synthesizer.voice_map:
                available_voices = ", ".join(
                    self.tts_synthesizer.voice_map.keys()
                )
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid voice '{request.voice}'. "
                    f"Available: {available_voices}",
                )

            audio = await self.tts_synthesizer.synthesize(
                agent=voice_lower, text=request.input
            )
            if audio is None:
                raise HTTPException(
                    status_code=500, detail="TTS generation failed"
                )

            audio_bytes = self._convert_audio(
                audio=audio,
                sample_rate=self.tts_synthesizer.engine.config.sample_rate,
                format=request.response_format,
            )

            content_type = {
                "pcm": "audio/pcm",
                "wav": "audio/wav",
                # _convert_audio has no MP3 encoder and returns WAV bytes for
                # "mp3"; labelling them audio/mpeg makes clients mis-decode.
                "mp3": "audio/wav",
            }[request.response_format]

            self.total_tts_requests += 1
            return Response(content=audio_bytes, media_type=content_type)

        except HTTPException:
            self.total_errors += 1
            raise
        except Exception as e:
            logger.error(f"TTS error: {e}", exc_info=True)
            self.total_errors += 1
            raise HTTPException(status_code=500, detail=str(e))

    async def _create_transcription(
        self,
        file: UploadFile,
        model: str,
        language: Optional[str],
        prompt: Optional[str],
        response_format: str,
        temperature: float,
    ) -> TranscriptionResponse:
        """
        Transcribe an uploaded audio file to text.

        Args:
            file: Uploaded audio file (any format soundfile can decode)
            model: Model name (ignored)
            language: Language hint (not currently forwarded to the engine)
            prompt: Prompt for context (not currently forwarded to the engine)
            response_format: Response format (json only supported)
            temperature: Temperature (ignored)

        Returns:
            Transcription response

        Raises:
            HTTPException: 400 for an empty, undecodable or sample-less
                upload; 500 when transcription fails or yields no text.
        """
        try:
            logger.info(
                f"STT request: filename={file.filename}, "
                f"content_type={file.content_type}"
            )

            audio_bytes = await file.read()
            if not audio_bytes:
                raise HTTPException(status_code=400, detail="Empty audio file")

            # A file soundfile cannot decode is a client error, not a server
            # fault; soundfile reports decode failures as RuntimeError
            # (or a subclass of it).
            try:
                audio, sample_rate = sf.read(io.BytesIO(audio_bytes))
            except RuntimeError as e:
                raise HTTPException(
                    status_code=400, detail=f"Could not decode audio file: {e}"
                ) from e

            # Downmix multi-channel audio to mono.
            if audio.ndim > 1:
                audio = audio.mean(axis=1)
            if audio.size == 0:
                raise HTTPException(
                    status_code=400, detail="Audio file contains no samples"
                )

            audio = audio.astype(np.float32)

            # Resample if needed (STT expects 16kHz). Cast back to float32
            # because the resampler's output dtype depends on the SciPy
            # version; keep at least one sample for very short clips.
            if sample_rate != 16000:
                from scipy import signal

                target_len = max(1, int(len(audio) * 16000 / sample_rate))
                audio = signal.resample(audio, target_len).astype(np.float32)

            result = await self.stt_transcriber.transcribe_async(audio)
            if not result or not result.text:
                raise HTTPException(
                    status_code=500, detail="Transcription failed"
                )

            self.total_stt_requests += 1
            return TranscriptionResponse(text=result.text)

        except HTTPException:
            self.total_errors += 1
            raise
        except Exception as e:
            logger.error(f"STT error: {e}", exc_info=True)
            self.total_errors += 1
            raise HTTPException(status_code=500, detail=str(e))

    def _convert_audio(
        self, audio: np.ndarray, sample_rate: int, format: str
    ) -> bytes:
        """
        Convert audio to the requested format.

        Args:
            audio: Audio array (float32, nominally in [-1.0, 1.0])
            sample_rate: Sample rate
            format: Target format (pcm, wav, mp3). "mp3" currently produces
                WAV bytes because no MP3 encoder is available.

        Returns:
            Audio bytes

        Raises:
            ValueError: If ``format`` is not one of the supported values.
        """
        if format == "pcm":
            # Clip before scaling: casting out-of-range floats to int16 is
            # undefined in NumPy and typically wraps, producing loud clicks.
            clipped = np.clip(audio, -1.0, 1.0)
            return (clipped * 32767).astype(np.int16).tobytes()

        if format in ("wav", "mp3"):
            if format == "mp3":
                # MP3 encoding requires an extra library (pydub/ffmpeg);
                # fall back to WAV until one is added.
                logger.warning("MP3 format not fully supported, returning WAV")
            buffer = io.BytesIO()
            sf.write(buffer, audio, sample_rate, format="WAV")
            return buffer.getvalue()

        raise ValueError(f"Unsupported format: {format}")

    def get_stats(self) -> dict:
        """
        Get API server statistics.

        Returns:
            Statistics dictionary including the engines' own stats.
        """
        return {
            "uptime": time.time() - self.start_time,
            "total_tts_requests": self.total_tts_requests,
            "total_stt_requests": self.total_stt_requests,
            "total_errors": self.total_errors,
            "tts_stats": self.tts_synthesizer.get_stats(),
            "stt_stats": self.stt_transcriber.get_stats(),
        }
# ============================================================================
# Factory Function
# ============================================================================
def create_api_server(
    tts_synthesizer: TTSSynthesizer,
    stt_transcriber: STTTranscriber,
) -> VoiceAPIServer:
    """
    Build a VoiceAPIServer around the given shared engines.

    Args:
        tts_synthesizer: TTS synthesizer instance
        stt_transcriber: STT transcriber instance

    Returns:
        A freshly constructed VoiceAPIServer instance
    """
    server = VoiceAPIServer(tts_synthesizer, stt_transcriber)
    return server