239 lines
7.2 KiB
Python
239 lines
7.2 KiB
Python
"""WebSocket voice endpoint for browser-based speech-to-text and text-to-speech.
|
|
|
|
Accepts binary PCM audio from browser, transcribes via Deepgram, sends to OpenClaw Gateway,
|
|
and streams TTS audio back to browser.
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import random
|
|
import string
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import numpy as np
|
|
from fastapi import WebSocket, WebSocketDisconnect
|
|
from pydantic import BaseModel
|
|
|
|
from server.stt import DeepgramSTT
|
|
from server.tts import VeniceKokoroTTS
|
|
from openclaw_client.client import OpenClawClient, OpenClawConfig
|
|
|
|
# Module-level logger, named after this module so log records can be filtered per module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class VoiceSession:
    """Manages a single voice session.

    Buffers raw PCM audio received from the browser, hands the buffered audio
    to the STT engine once enough has accumulated, forwards non-empty
    transcripts to the OpenClaw agent, and records every exchange both in
    memory and in a per-session JSONL file under ``logs/voice``.
    """

    # Seconds of buffered audio that trigger a transcription pass
    # (slightly less than one second).
    FLUSH_THRESHOLD_SECONDS = 0.8

    def __init__(self, session_id: str):
        """Create the session state and ensure the transcript directory exists.

        Args:
            session_id: Identifier of the session. A sanitized form of it is
                used as the transcript file name.
        """
        self.session_id = session_id
        # The ID may originate from the client, so it must never be able to
        # steer the transcript path outside of logs/voice.
        self.transcript_file = (
            Path("logs/voice") / f"{self._safe_filename(session_id)}.jsonl"
        )
        self.transcript_file.parent.mkdir(parents=True, exist_ok=True)

        # Audio buffering
        self.audio_buffer = bytearray()
        self.buffer_duration = 0.0  # Seconds
        self._buffer_lock = asyncio.Lock()

        # Audio processing
        self.sample_rate = 16000
        self.channel_count = 1
        self.bits_per_sample = 32

        # Engines (self-contained, don't share with run.py)
        self.stt = None
        self.tts = None
        self.openclaw = None

        # Session state
        self.connected = False
        self.transcript = []

        logger.info(f"Created voice session {session_id}")

    @staticmethod
    def _safe_filename(session_id: str) -> str:
        """Return ``session_id`` reduced to characters safe in a file name.

        Everything except alphanumerics, ``-`` and ``_`` becomes ``_``, so
        path separators and ``..`` cannot escape the log directory. An empty
        result is replaced by ``"_"``.
        """
        cleaned = "".join(
            ch if ch.isalnum() or ch in "-_" else "_" for ch in str(session_id)
        )
        return cleaned or "_"

    def _bytes_per_second(self) -> int:
        """Return the number of PCM bytes that make up one second of audio."""
        return self.sample_rate * self.channel_count * (self.bits_per_sample // 8)

    async def initialize(self):
        """Initialize STT, TTS, and OpenClaw client.

        Reads ``DEEPGRAM_API_KEY``, ``VENICE_API_KEY``, ``OPENCLAW_BASE_URL``
        and ``OPENCLAW_AUTH_TOKEN`` from the environment, builds the three
        engines and connects the OpenClaw client.

        Raises:
            ValueError: If the Deepgram or Venice API key is missing.
        """
        # Load env vars
        deepgram_key = os.getenv("DEEPGRAM_API_KEY")
        venice_key = os.getenv("VENICE_API_KEY")
        openclaw_url = os.getenv("OPENCLAW_BASE_URL", "ws://192.168.50.9:18789")
        openclaw_token = os.getenv("OPENCLAW_AUTH_TOKEN")

        if not deepgram_key or not venice_key:
            raise ValueError("Missing required API keys")

        # Initialize STT
        self.stt = DeepgramSTT(
            api_key=deepgram_key,
            model="nova-3",
            language="en",
            sample_rate=self.sample_rate,
        )

        # Initialize TTS
        self.tts = VeniceKokoroTTS(
            api_key=venice_key,
            voice="am_liam",
            base_url="https://api.venice.ai/api/v1",
        )

        # Initialize OpenClaw client
        self.openclaw = OpenClawClient(
            config=OpenClawConfig(
                base_url=openclaw_url,
                auth_token=openclaw_token,
                timeout=30.0,
                agent_id="main",
            )
        )

        await self.openclaw.connect()

        logger.info(f"Voice session {self.session_id} initialized")

    async def close(self):
        """Clean up resources: mark the session disconnected and drop the OpenClaw link."""
        self.connected = False

        if self.openclaw:
            await self.openclaw.disconnect()

        logger.info(f"Voice session {self.session_id} closed")

    def _new_id(self) -> str:
        """Generate a random 8-character alphanumeric session ID."""
        # SystemRandom draws from the OS entropy source, so IDs cannot be
        # predicted from (or reproduced by seeding) the default generator.
        return "".join(
            random.SystemRandom().choices(string.ascii_letters + string.digits, k=8)
        )

    async def process_audio_chunk(self, data: bytes):
        """Buffer an incoming audio chunk and flush once enough has accumulated.

        Args:
            data: Raw PCM bytes received from the client.
        """
        async with self._buffer_lock:
            self.audio_buffer.extend(data)
            self.buffer_duration += len(data) / self._bytes_per_second()
            flush_due = self.buffer_duration >= self.FLUSH_THRESHOLD_SECONDS

        # asyncio.Lock is not reentrant and the flush acquires it itself, so
        # the flush must run only after the lock above has been released.
        if flush_due:
            await self._transcribe_buffered_audio()

    async def _transcribe_buffered_audio(self):
        """Transcribe accumulated audio and send the transcript to OpenClaw.

        The buffered audio is taken out of the buffer before any network call
        is made, so a failing STT/OpenClaw request cannot make the buffer grow
        without bound or cause the same audio to be submitted again. Errors
        are logged, not raised.
        """
        import time  # local import: wall-clock timestamps for the transcript log

        frame_bytes = self.channel_count * (self.bits_per_sample // 8)

        async with self._buffer_lock:
            # Only whole frames can be decoded; trailing bytes of a partially
            # received sample stay in the buffer for the next pass.
            usable = len(self.audio_buffer) - len(self.audio_buffer) % frame_bytes
            if usable == 0:
                return

            pcm = bytes(self.audio_buffer[:usable])
            del self.audio_buffer[:usable]
            self.buffer_duration = len(self.audio_buffer) / self._bytes_per_second()

        # Samples are decoded as 32-bit floats, matching bits_per_sample = 32.
        audio_data = np.frombuffer(pcm, dtype=np.float32)

        try:
            result = await self.stt.transcribe_async(audio_data)

            if not result.text.strip():
                return

            # Send to OpenClaw
            response = await self.openclaw.send_message(
                agent="main",
                message=result.text,
                speaker="voice_user",
            )

            # Log transcript
            entry = {
                "timestamp": time.time(),
                "session_id": self.session_id,
                "transcript": result.text,
                "response": response,
            }

            self.transcript.append(entry)

            # Entries are written with ensure_ascii=False, so the file must be
            # opened as UTF-8 regardless of the platform's default encoding.
            with open(self.transcript_file, "a", encoding="utf-8") as f:
                f.write(json.dumps(entry, ensure_ascii=False) + "\n")

            logger.info(
                f"Session {self.session_id}: "
                f'"{result.text[:50]}..." -> "{str(response)[:50]}..."'
            )

        except Exception as e:
            logger.exception(f"Transcription error: {e}")

    async def synthesize_response(self, text: str):
        """Synthesize TTS audio from response text.

        Args:
            text: Text to synthesize.

        Returns:
            Whatever the TTS engine's ``generate_async`` returns, or ``None``
            if synthesis raised an exception (the error is logged).
        """
        try:
            audio = await self.tts.generate_async(
                text=text,
                voice_ref_path=None,
                emotion_exaggeration=0.8,
            )

            return audio

        except Exception as e:
            logger.error(f"TTS synthesis error: {e}")
            return None

    def get_transcript(self) -> list:
        """Get transcript history (the live list of logged exchange entries)."""
        return self.transcript
|
|
|
|
|
|
async def handle_voice_websocket(websocket: WebSocket, session_id: str):
    """Handle WebSocket connection for voice session.

    Accepts the connection, initializes a ``VoiceSession``, sends a welcome
    message and then feeds every received binary frame to the session until
    the client disconnects or an error occurs. The session is always closed
    on exit.

    Args:
        websocket: The incoming WebSocket connection.
        session_id: Identifier used for the session and its transcript file.
    """
    session = VoiceSession(session_id)

    await websocket.accept()
    session.connected = True

    logger.info(f"WebSocket connected for session {session_id}")

    try:
        # Initialize session
        await session.initialize()

        # Send welcome message
        await websocket.send_json({
            "type": "welcome",
            "message": "Connected to voice portal",
        })

        # Receive and process audio
        while session.connected:
            try:
                data = await websocket.receive_bytes()

                # Process audio chunk
                await session.process_audio_chunk(data)

            except WebSocketDisconnect:
                session.connected = False
                logger.info(f"WebSocket disconnected for session {session_id}")
                break

            except Exception as e:
                logger.error(f"WebSocket error: {e}")
                session.connected = False
                break

    except WebSocketDisconnect:
        # The client went away during setup; this is not an initialization
        # failure and there is no connection left to close.
        session.connected = False
        logger.info(f"WebSocket disconnected for session {session_id}")

    except Exception as e:
        logger.error(f"Session initialization error: {e}")

        # A close frame carries at most 123 bytes of UTF-8 reason text, so an
        # arbitrary exception message has to be trimmed (on a character
        # boundary) before it is sent.
        reason = str(e).encode("utf-8")[:123].decode("utf-8", errors="ignore")
        try:
            await websocket.close(code=1011, reason=reason)
        except Exception as close_error:
            # The socket may already be closed; never let that mask cleanup.
            logger.warning(f"Could not close WebSocket cleanly: {close_error}")

    finally:
        await session.close()
|
|
|
|
|
|
def create_session_id(length: int = 8) -> str:
    """Generate a random alphanumeric session ID.

    Args:
        length: Number of characters in the ID (default 8).

    Returns:
        A string of ``length`` ASCII letters and digits.
    """
    # Session IDs name sessions and their transcript files, so draw them from
    # the OS entropy source instead of the seedable default generator.
    alphabet = string.ascii_letters + string.digits
    return "".join(random.SystemRandom().choices(alphabet, k=length))
|