- Fix app.py: @app.get -> @app.websocket for /ws/voice route (was returning 403) - Fix app.py: create static_dir before mounting it (AttributeError on startup) - Fix voice.html: AudioWorkletNode constructor (was AudioWorkletProcessor) - Fix voice.html: use ScriptProcessor directly (more reliable) - Fix voice.html: send Float32 directly (server expects float32, was sending Int16) - Fix voice.html: auto-detect ws/wss protocol from page URL - Add Caddy reverse proxy keepalive pings every 15s to prevent timeout - Add detailed message type logging in WebSocket receive loop - Strip Jarvis/Sage personas, rename bot to MoltMic - Add /moltmic voice slash command for portal URL - Update portal URL to https://voice.jezzahehn.com
274 lines
8.6 KiB
Python
274 lines
8.6 KiB
Python
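"""WebSocket voice endpoint for browser-based speech-to-text and text-to-speech.

Accepts binary float32 PCM audio from the browser, transcribes it via Deepgram,
and forwards the transcript to the OpenClaw Gateway. A text-to-speech helper
(``VoiceSession.synthesize_response``) is provided, but the WebSocket handler in
this module does not currently stream TTS audio back to the browser.
"""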
|
|
|
|
import asyncio
import json
import logging
import os
import random
import secrets
import string
import time
from pathlib import Path
from typing import Optional

import numpy as np
from fastapi import WebSocket, WebSocketDisconnect
from pydantic import BaseModel

from server.stt import DeepgramSTT
from server.tts import VeniceKokoroTTS
from openclaw_client.client import OpenClawClient, OpenClawConfig
|
|
|
|
# Module-level logger, named after this module's dotted import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class VoiceSession:
    """Manages a single voice session.

    Binary audio frames from the browser are appended to an in-memory buffer.
    Once about 0.8 s of audio has accumulated, the buffer is drained, decoded
    as float32 PCM, transcribed with Deepgram, and any non-empty transcript is
    sent to the OpenClaw gateway. Each exchange is kept in ``self.transcript``
    and appended as one JSON line to ``logs/voice/<session_id>.jsonl``.

    NOTE(review): audio is cut at fixed ~0.8 s boundaries with no voice
    activity detection, so a word spanning a boundary may be split across two
    transcriptions / gateway messages.
    """

    # Seconds of buffered audio that trigger a transcription (slightly less
    # than 1 second).
    TRANSCRIBE_THRESHOLD_SECONDS = 0.8

    def __init__(self, session_id: str):
        """Create the session and make sure its transcript directory exists.

        Args:
            session_id: Identifier used in log messages and as the transcript
                file name. NOTE(review): it is interpolated into a file path
                unchanged; confirm callers never pass client-controlled values
                containing path separators.
        """
        self.session_id = session_id
        self.transcript_file = Path("logs/voice") / f"{session_id}.jsonl"
        self.transcript_file.parent.mkdir(parents=True, exist_ok=True)

        # Audio buffering: raw bytes received but not yet transcribed.
        self.audio_buffer = bytearray()
        self.buffer_duration = 0.0  # Seconds of audio currently buffered
        self._buffer_lock = asyncio.Lock()

        # Expected input format. NOTE(review): assumes the browser sends mono
        # float32 samples at 16 kHz; confirm against the client capture code.
        self.sample_rate = 16000
        self.channel_count = 1
        self.bits_per_sample = 32

        # Engines (self-contained, don't share with run.py); set in initialize().
        self.stt = None
        self.tts = None
        self.openclaw = None

        # Session state
        self.connected = False
        self.transcript = []

        logger.info(f"Created voice session {session_id}")

    async def initialize(self):
        """Create the STT, TTS and OpenClaw clients and connect to OpenClaw.

        Reads DEEPGRAM_API_KEY, VENICE_API_KEY, OPENCLAW_BASE_URL (defaulting
        to a LAN address) and OPENCLAW_AUTH_TOKEN from the environment.

        Raises:
            ValueError: If the Deepgram or Venice API key is missing.
        """
        deepgram_key = os.getenv("DEEPGRAM_API_KEY")
        venice_key = os.getenv("VENICE_API_KEY")
        openclaw_url = os.getenv("OPENCLAW_BASE_URL", "ws://192.168.50.9:18789")
        openclaw_token = os.getenv("OPENCLAW_AUTH_TOKEN")

        if not deepgram_key or not venice_key:
            raise ValueError("Missing required API keys")

        self.stt = DeepgramSTT(
            api_key=deepgram_key,
            model="nova-3",
            language="en",
            sample_rate=self.sample_rate,
        )

        self.tts = VeniceKokoroTTS(
            api_key=venice_key,
            voice="am_liam",
            base_url="https://api.venice.ai/api/v1",
        )

        self.openclaw = OpenClawClient(
            config=OpenClawConfig(
                base_url=openclaw_url,
                auth_token=openclaw_token,
                timeout=30.0,
                agent_id="main",
            )
        )

        await self.openclaw.connect()

        logger.info(f"Voice session {self.session_id} initialized")

    async def close(self):
        """Mark the session disconnected and disconnect from OpenClaw.

        Best-effort: a failing disconnect is logged rather than raised, because
        this runs in the WebSocket handler's ``finally`` block and must not mask
        whatever error ended the session.
        """
        self.connected = False

        if self.openclaw:
            try:
                await self.openclaw.disconnect()
            except Exception as e:
                logger.error(
                    f"Error disconnecting OpenClaw for session {self.session_id}: {e}",
                    exc_info=True,
                )

        logger.info(f"Voice session {self.session_id} closed")

    def _new_id(self) -> str:
        """Generate a random 8-character alphanumeric session ID."""
        alphabet = string.ascii_letters + string.digits
        return "".join(secrets.choice(alphabet) for _ in range(8))

    def _bytes_per_sample(self) -> int:
        """Return the size in bytes of one sample, derived from bits_per_sample."""
        return self.bits_per_sample // 8

    def _bytes_per_second(self) -> int:
        """Return how many input bytes make up one second of audio."""
        return self.sample_rate * self.channel_count * self._bytes_per_sample()

    async def process_audio_chunk(self, data: bytes):
        """Buffer an incoming audio chunk; transcribe once enough has accumulated.

        Args:
            data: Payload of one binary WebSocket frame, assumed to be raw
                float32 PCM in the format described by sample_rate /
                channel_count / bits_per_sample.
        """
        async with self._buffer_lock:
            self.audio_buffer.extend(data)
            self.buffer_duration += len(data) / self._bytes_per_second()
            ready = self.buffer_duration >= self.TRANSCRIBE_THRESHOLD_SECONDS

        # Transcribe outside the lock: asyncio.Lock is not re-entrant and
        # _transcribe_buffered_audio acquires it itself.
        if ready:
            await self._transcribe_buffered_audio()

    async def _transcribe_buffered_audio(self):
        """Transcribe the buffered audio and forward any speech to OpenClaw.

        The buffer is drained before transcription starts, so a failed or
        silent transcription never leaves stale audio to be re-sent with the
        next chunk. A trailing partial sample (fewer bytes than one float32) is
        kept in the buffer for the next call.
        """
        sample_width = self._bytes_per_sample()
        async with self._buffer_lock:
            usable = len(self.audio_buffer) - len(self.audio_buffer) % sample_width
            if usable == 0:
                return
            pcm = bytes(self.audio_buffer[:usable])
            del self.audio_buffer[:usable]
            self.buffer_duration = len(self.audio_buffer) / self._bytes_per_second()

        audio_data = np.frombuffer(pcm, dtype=np.float32)

        try:
            result = await self.stt.transcribe_async(audio_data)
            if not result.text.strip():
                return

            response = await self.openclaw.send_message(
                agent="main",
                message=result.text,
                speaker="voice_user",
            )

            entry = {
                # Wall-clock seconds since the epoch, so log lines can be
                # correlated across sessions and restarts.
                "timestamp": time.time(),
                "session_id": self.session_id,
                "transcript": result.text,
                "response": response,
            }
            self.transcript.append(entry)

            # ensure_ascii=False emits non-ASCII characters, so the file
            # encoding must be explicit rather than the platform default.
            with open(self.transcript_file, "a", encoding="utf-8") as f:
                f.write(json.dumps(entry, ensure_ascii=False) + "\n")

            logger.info(
                f"Session {self.session_id}: "
                f'"{result.text[:50]}..." -> "{str(response)[:50]}..."'
            )

        except Exception as e:
            logger.error(f"Transcription error: {e}", exc_info=True)

    async def synthesize_response(self, text: str):
        """Synthesize TTS audio for ``text``.

        Returns whatever ``VeniceKokoroTTS.generate_async`` returns, or None if
        synthesis raises (the error is logged).

        NOTE(review): this method is not called anywhere in this module, so
        gateway responses are currently not spoken back to the browser.
        """
        try:
            return await self.tts.generate_async(
                text=text,
                voice_ref_path=None,
                emotion_exaggeration=0.8,
            )
        except Exception as e:
            logger.error(f"TTS synthesis error: {e}", exc_info=True)
            return None

    def get_transcript(self) -> list:
        """Return the session's transcript entries (the live list, not a copy)."""
        return self.transcript
|
|
|
|
|
|
async def handle_voice_websocket(websocket: WebSocket, session_id: str):
    """Serve one browser voice connection until it disconnects.

    Accepts the socket, initializes a VoiceSession, sends a ``welcome`` JSON
    message, then feeds every binary frame to ``session.process_audio_chunk``.
    Text frames are ignored. A background task sends ``{"type": "ping"}`` every
    15 s so the Caddy reverse proxy does not time the connection out.

    If initialization (or anything outside the receive loop) fails, the socket
    is closed with code 1011. The keepalive task is cancelled and the session
    is closed in every case.
    """
    session = VoiceSession(session_id)

    await websocket.accept()
    session.connected = True

    logger.info(f"WebSocket connected for session {session_id}")

    keepalive_task: Optional[asyncio.Task] = None

    async def keepalive():
        """Send a JSON ping every 15 s while the session is connected."""
        while session.connected:
            try:
                await asyncio.sleep(15)
                if session.connected:
                    await websocket.send_json({"type": "ping"})
            except Exception:
                break

    try:
        await session.initialize()

        await websocket.send_json({
            "type": "welcome",
            "message": "Connected to voice portal",
        })

        keepalive_task = asyncio.create_task(keepalive())

        # Receive and process audio
        chunk_count = 0
        while session.connected:
            try:
                msg = await websocket.receive()
                msg_type = msg.get("type", "unknown")

                if msg_type == "websocket.disconnect":
                    session.connected = False
                    logger.info(f"WebSocket disconnected for session {session_id}")
                    break

                elif msg_type == "websocket.receive":
                    # ASGI lets servers include both "bytes" and "text" with
                    # the unused one set to None, so test values, not keys.
                    audio = msg.get("bytes")
                    if audio is not None:
                        chunk_count += 1
                        if chunk_count <= 5 or chunk_count % 100 == 0:
                            logger.info(f"Audio chunk #{chunk_count}: {len(audio)} bytes")
                        await session.process_audio_chunk(audio)
                    elif msg.get("text") is not None:
                        pass  # Text frames are currently ignored
                    else:
                        logger.warning(f"Unknown receive msg: {msg}")

                else:
                    logger.warning(f"Unknown WebSocket msg type: {msg_type}: {msg}")

            except WebSocketDisconnect:
                session.connected = False
                logger.info(f"WebSocket disconnected for session {session_id}")
                break

            except Exception as e:
                logger.error(f"WebSocket error in receive loop: {e}", exc_info=True)
                session.connected = False
                break

    except Exception as e:
        logger.error(f"Session error: {e}", exc_info=True)
        # RFC 6455 caps a close frame payload at 125 bytes, 2 of which hold the
        # status code; a longer reason would make close() itself fail.
        reason = str(e).encode("utf-8")[:123].decode("utf-8", errors="ignore")
        try:
            await websocket.close(code=1011, reason=reason)
        except Exception:
            pass

    finally:
        if keepalive_task is not None:
            keepalive_task.cancel()
        await session.close()
|
|
|
|
|
|
def create_session_id(length: int = 8) -> str:
    """Generate a random alphanumeric session ID.

    Uses the ``secrets`` module rather than ``random`` so IDs cannot be
    predicted from previously issued ones.

    Args:
        length: Number of characters; defaults to 8 (the historical length).

    Returns:
        A string of ``length`` ASCII letters and digits.
    """
    alphabet = string.ascii_letters + string.digits
    return "".join(secrets.choice(alphabet) for _ in range(length))
|