it0/packages/services/voice-service/src/api/test_tts.py

431 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Temporary test endpoints for TTS and STT — browser-accessible."""
import asyncio
import io
import os
import struct
import tempfile
import threading
import numpy as np
from fastapi import APIRouter, Request, Query, UploadFile, File
from fastapi.responses import HTMLResponse, Response
# Router for the temporary, browser-accessible TTS/STT test endpoints.
router = APIRouter()
# Serializes lazy model loading so concurrent first requests load each model only once.
_loading_lock = threading.Lock()
def _ensure_local_tts(app) -> bool:
    """Lazy-load Kokoro TTS if not yet loaded. Returns True if available."""

    def ready() -> bool:
        # A usable service is one that exists and has its pipeline constructed.
        existing = getattr(app.state, "tts", None)
        return existing is not None and existing._pipeline is not None

    if ready():
        return True
    with _loading_lock:
        # Another request may have finished loading while we waited on the lock.
        if ready():
            return True
        try:
            from ..tts.kokoro_service import KokoroTTSService, _patch_misaki_compat
            from ..config.settings import settings
            _patch_misaki_compat()
            from kokoro import KPipeline
            service = KokoroTTSService(model=settings.kokoro_model, voice=settings.kokoro_voice)
            service._pipeline = KPipeline(lang_code='z')
            app.state.tts = service
            print(f"[lazy] Local TTS loaded: {settings.kokoro_model} voice={settings.kokoro_voice}", flush=True)
        except Exception as e:
            print(f"[lazy] Failed to load local TTS: {e}", flush=True)
            return False
        return True
def _ensure_local_stt(app) -> bool:
    """Lazy-load faster-whisper STT if not yet loaded.

    Thread-safe via double-checked locking on ``_loading_lock`` so concurrent
    first requests trigger exactly one model load.

    Returns:
        True if a usable model is installed on ``app.state.stt``, else False.
    """
    stt = getattr(app.state, "stt", None)
    if stt is not None and stt._model is not None:
        return True
    with _loading_lock:
        # Double-check after acquiring the lock: another request may have
        # finished loading while this one was waiting.
        stt = getattr(app.state, "stt", None)
        if stt is not None and stt._model is not None:
            return True
        try:
            from ..stt.whisper_service import WhisperSTTService
            from faster_whisper import WhisperModel
            from ..config.settings import settings
            svc = WhisperSTTService(
                model=settings.whisper_model,
                device=settings.device,
                language=settings.whisper_language,
            )
            # float16 requires CUDA; int8 is the practical CPU compute type.
            compute_type = "float16" if settings.device == "cuda" else "int8"
            try:
                svc._model = WhisperModel(svc.model_name, device=svc.device, compute_type=compute_type)
            except Exception:
                # Fall back to CPU once. If we were already on CPU, re-raise so
                # the outer handler reports failure — previously this path
                # returned True with svc._model still None, and the first
                # transcription crashed with AttributeError.
                if svc.device == "cpu":
                    raise
                svc._model = WhisperModel(svc.model_name, device="cpu", compute_type="int8")
            app.state.stt = svc
            print(f"[lazy] Local STT loaded: {settings.whisper_model}", flush=True)
            return True
        except Exception as e:
            print(f"[lazy] Failed to load local STT: {e}", flush=True)
            return False
_SAMPLE_RATE = 24000  # Kokoro native output rate (Hz); synthesis output is resampled to 16 kHz
def _make_wav(pcm_bytes: bytes, sample_rate: int = 16000) -> bytes:
"""Wrap raw 16-bit PCM into a WAV container."""
buf = io.BytesIO()
num_samples = len(pcm_bytes) // 2
data_size = num_samples * 2
# WAV header
buf.write(b"RIFF")
buf.write(struct.pack("<I", 36 + data_size))
buf.write(b"WAVE")
buf.write(b"fmt ")
buf.write(struct.pack("<I", 16)) # chunk size
buf.write(struct.pack("<H", 1)) # PCM
buf.write(struct.pack("<H", 1)) # mono
buf.write(struct.pack("<I", sample_rate)) # sample rate
buf.write(struct.pack("<I", sample_rate * 2)) # byte rate
buf.write(struct.pack("<H", 2)) # block align
buf.write(struct.pack("<H", 16)) # bits per sample
buf.write(b"data")
buf.write(struct.pack("<I", data_size))
buf.write(pcm_bytes)
return buf.getvalue()
@router.get("/tts", response_class=HTMLResponse)
async def tts_test_page():
"""Combined TTS + STT test page."""
return """<!DOCTYPE html>
<html><head><meta charset="utf-8"><title>Voice Test</title>
<style>
body { font-family: sans-serif; max-width: 700px; margin: 30px auto; padding: 0 20px; }
h2 { border-bottom: 2px solid #333; padding-bottom: 8px; }
textarea { width: 100%; height: 70px; font-size: 15px; }
button { font-size: 16px; padding: 8px 24px; margin: 8px 4px 8px 0; cursor: pointer; border-radius: 4px; border: 1px solid #999; }
button:hover { background: #e0e0e0; }
.recording { background: #ff4444 !important; color: white !important; }
.status { margin-top: 10px; color: #666; font-size: 14px; }
audio { margin-top: 10px; width: 100%; }
.section { background: #f8f8f8; padding: 20px; border-radius: 8px; margin-bottom: 20px; }
#stt-result { font-size: 18px; color: #333; margin-top: 10px; padding: 10px; background: white; border: 1px solid #ddd; border-radius: 4px; min-height: 40px; }
</style></head>
<body>
<h2>Voice I/O Test</h2>
<div class="section">
<h3>TTS (Text to Speech)</h3>
<textarea id="tts-text" placeholder="输入要合成的文本...">你好我是IT0运维助手。很高兴为您服务</textarea>
<br><button onclick="doTTS()">合成语音</button>
<div class="status" id="tts-status"></div>
<audio id="tts-player" controls style="display:none"></audio>
</div>
<div class="section">
<h3>STT (Speech to Text)</h3>
<p style="font-size:14px;color:#888;">点击录音按钮说话,松开后自动识别。或上传音频文件。</p>
<button id="rec-btn" onmousedown="startRec()" onmouseup="stopRec()" ontouchstart="startRec()" ontouchend="stopRec()">按住录音</button>
<label style="cursor:pointer;border:1px solid #999;padding:8px 24px;border-radius:4px;font-size:16px;">
上传音频 <input type="file" id="audio-file" accept="audio/*" style="display:none" onchange="uploadAudio(this)">
</label>
<div class="status" id="stt-status"></div>
<div id="stt-result"></div>
<audio id="stt-player" controls style="display:none"></audio>
</div>
<div class="section">
<h3>Round-trip (STT + TTS)</h3>
<p style="font-size:14px;color:#888;">录音 → 识别文本 → 再合成语音播放。测试全链路。</p>
<button id="rt-btn" onmousedown="startRoundTrip()" onmouseup="stopRoundTrip()" ontouchstart="startRoundTrip()" ontouchend="stopRoundTrip()">按住说话 (Round-trip)</button>
<div class="status" id="rt-status"></div>
<div id="rt-result"></div>
<audio id="rt-player" controls style="display:none"></audio>
</div>
<script>
let mediaRec, audioChunks, recMode;
async function doTTS() {
const text = document.getElementById('tts-text').value.trim();
if (!text) return;
const st = document.getElementById('tts-status');
const pl = document.getElementById('tts-player');
st.textContent = '合成中...'; pl.style.display = 'none';
const t0 = Date.now();
try {
const r = await fetch('/api/v1/test/tts/synthesize?text=' + encodeURIComponent(text));
if (!r.ok) { st.textContent = 'Error: ' + r.status + ' ' + await r.text(); return; }
const blob = await r.blob();
st.textContent = '完成!耗时 ' + (Date.now()-t0) + 'ms大小 ' + (blob.size/1024).toFixed(1) + 'KB';
pl.src = URL.createObjectURL(blob); pl.style.display = 'block'; pl.play();
} catch(e) { st.textContent = 'Error: ' + e.message; }
}
function startRec() { _startRec('stt'); }
function stopRec() { _stopRec('stt'); }
function startRoundTrip() { _startRec('rt'); }
function stopRoundTrip() { _stopRec('rt'); }
async function _startRec(mode) {
recMode = mode;
const btn = document.getElementById(mode === 'rt' ? 'rt-btn' : 'rec-btn');
btn.classList.add('recording');
btn.textContent = '录音中...';
audioChunks = [];
try {
const stream = await navigator.mediaDevices.getUserMedia({ audio: { sampleRate: 16000, channelCount: 1 } });
mediaRec = new MediaRecorder(stream, { mimeType: 'audio/webm;codecs=opus' });
mediaRec.ondataavailable = e => { if (e.data.size > 0) audioChunks.push(e.data); };
mediaRec.onstop = () => {
stream.getTracks().forEach(t => t.stop());
const blob = new Blob(audioChunks, { type: 'audio/webm' });
if (mode === 'rt') doRoundTrip(blob);
else doSTT(blob);
};
mediaRec.start();
} catch(e) {
btn.classList.remove('recording');
btn.textContent = mode === 'rt' ? '按住说话 (Round-trip)' : '按住录音';
alert('麦克风权限被拒绝: ' + e.message);
}
}
function _stopRec(mode) {
const btn = document.getElementById(mode === 'rt' ? 'rt-btn' : 'rec-btn');
btn.classList.remove('recording');
btn.textContent = mode === 'rt' ? '按住说话 (Round-trip)' : '按住录音';
if (mediaRec && mediaRec.state === 'recording') mediaRec.stop();
}
async function doSTT(blob) {
const st = document.getElementById('stt-status');
const res = document.getElementById('stt-result');
const pl = document.getElementById('stt-player');
st.textContent = '识别中...'; res.textContent = '';
pl.src = URL.createObjectURL(blob); pl.style.display = 'block';
const t0 = Date.now();
try {
const fd = new FormData(); fd.append('audio', blob, 'recording.webm');
const r = await fetch('/api/v1/test/stt/transcribe', { method: 'POST', body: fd });
const data = await r.json();
st.textContent = '完成!耗时 ' + (Date.now()-t0) + 'ms';
res.textContent = data.text || '(empty)';
} catch(e) { st.textContent = 'Error: ' + e.message; }
}
async function uploadAudio(input) {
if (!input.files[0]) return;
const blob = input.files[0];
const st = document.getElementById('stt-status');
const res = document.getElementById('stt-result');
const pl = document.getElementById('stt-player');
st.textContent = '识别中...'; res.textContent = '';
pl.src = URL.createObjectURL(blob); pl.style.display = 'block';
const t0 = Date.now();
try {
const fd = new FormData(); fd.append('audio', blob, blob.name);
const r = await fetch('/api/v1/test/stt/transcribe', { method: 'POST', body: fd });
const data = await r.json();
st.textContent = '完成!耗时 ' + (Date.now()-t0) + 'ms';
res.textContent = data.text || '(empty)';
} catch(e) { st.textContent = 'Error: ' + e.message; }
input.value = '';
}
async function doRoundTrip(blob) {
const st = document.getElementById('rt-status');
const res = document.getElementById('rt-result');
const pl = document.getElementById('rt-player');
st.textContent = 'STT识别中...'; res.textContent = ''; pl.style.display = 'none';
const t0 = Date.now();
try {
// 1. STT
const fd = new FormData(); fd.append('audio', blob, 'recording.webm');
const r1 = await fetch('/api/v1/test/stt/transcribe', { method: 'POST', body: fd });
const sttData = await r1.json();
const text = sttData.text || '';
const sttMs = Date.now() - t0;
res.textContent = 'STT (' + sttMs + 'ms): ' + (text || '(empty)');
if (!text) { st.textContent = '识别为空'; return; }
// 2. TTS
st.textContent = 'TTS合成中...';
const t1 = Date.now();
const r2 = await fetch('/api/v1/test/tts/synthesize?text=' + encodeURIComponent(text));
if (!r2.ok) { st.textContent = 'TTS Error: ' + r2.status; return; }
const audioBlob = await r2.blob();
const ttsMs = Date.now() - t1;
const totalMs = Date.now() - t0;
st.textContent = '完成STT=' + sttMs + 'ms + TTS=' + ttsMs + 'ms = 总计' + totalMs + 'ms';
res.textContent += '\\nTTS (' + ttsMs + 'ms): ' + (audioBlob.size/1024).toFixed(1) + 'KB';
pl.src = URL.createObjectURL(audioBlob); pl.style.display = 'block'; pl.play();
} catch(e) { st.textContent = 'Error: ' + e.message; }
}
</script>
</body></html>"""
@router.get("/tts/synthesize")
async def tts_synthesize(request: Request, text: str = Query(..., min_length=1, max_length=500)):
"""Synthesize text to WAV audio (lazy-loads local Kokoro model on first call)."""
loaded = await asyncio.get_event_loop().run_in_executor(None, _ensure_local_tts, request.app)
if not loaded:
return Response(content="Failed to load local TTS model", status_code=503)
tts = request.app.state.tts
loop = asyncio.get_event_loop()
def _synth():
samples = []
for _, _, audio in tts._pipeline(text, voice=tts.voice):
if hasattr(audio, "numpy"):
samples.append(audio.numpy())
else:
samples.append(audio)
if not samples:
return b""
audio_np = np.concatenate(samples)
# Resample 24kHz → 16kHz
target_len = int(len(audio_np) / _SAMPLE_RATE * 16000)
indices = np.linspace(0, len(audio_np) - 1, target_len)
resampled = np.interp(indices, np.arange(len(audio_np)), audio_np)
pcm = (resampled * 32768).clip(-32768, 32767).astype(np.int16).tobytes()
return _make_wav(pcm, 16000)
wav_bytes = await loop.run_in_executor(None, _synth)
if not wav_bytes:
return Response(content="TTS produced no audio", status_code=500)
return Response(content=wav_bytes, media_type="audio/wav")
@router.post("/stt/transcribe")
async def stt_transcribe(request: Request, audio: UploadFile = File(...)):
"""Transcribe uploaded audio to text via faster-whisper (lazy-loads on first call)."""
loaded = await asyncio.get_event_loop().run_in_executor(None, _ensure_local_stt, request.app)
if not loaded:
return {"error": "Failed to load local STT model", "text": ""}
stt = request.app.state.stt
# Save uploaded file to temp
raw = await audio.read()
suffix = os.path.splitext(audio.filename or "audio.webm")[1] or ".webm"
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as f:
f.write(raw)
tmp_path = f.name
try:
# faster-whisper can handle webm/mp3/wav etc. directly
loop = asyncio.get_event_loop()
def _transcribe():
segments, info = stt._model.transcribe(
tmp_path,
language=stt.language if hasattr(stt, 'language') and stt.language else None,
beam_size=5,
vad_filter=True,
)
text = "".join(seg.text for seg in segments).strip()
return text, info
text, info = await loop.run_in_executor(None, _transcribe)
return {
"text": text,
"language": getattr(info, "language", ""),
"duration": round(getattr(info, "duration", 0), 2),
}
finally:
os.unlink(tmp_path)
# =====================================================================
# OpenAI Voice API endpoints
# =====================================================================
def _get_openai_client():
"""Lazy-init OpenAI client with proxy support."""
from openai import OpenAI
import httpx
api_key = os.environ.get("OPENAI_API_KEY")
base_url = os.environ.get("OPENAI_BASE_URL")
if not api_key:
return None
kwargs = {"api_key": api_key}
if base_url:
kwargs["base_url"] = base_url
# Disable SSL verification for self-signed proxy certs
kwargs["http_client"] = httpx.Client(verify=False)
return OpenAI(**kwargs)
@router.get("/tts/synthesize-openai")
async def tts_synthesize_openai(
text: str = Query(..., min_length=1, max_length=500),
model: str = Query("tts-1", regex="^(tts-1|tts-1-hd|gpt-4o-mini-tts)$"),
voice: str = Query("alloy", regex="^(alloy|ash|ballad|coral|echo|fable|nova|onyx|sage|shimmer)$"),
):
"""Synthesize text to audio via OpenAI TTS API, resampled to 16kHz WAV."""
client = _get_openai_client()
if client is None:
return Response(content="OPENAI_API_KEY not configured", status_code=503)
loop = asyncio.get_event_loop()
def _synth():
response = client.audio.speech.create(
model=model,
voice=voice,
input=text,
response_format="pcm", # raw 24kHz 16-bit mono PCM (no header)
)
raw_pcm = response.content
# Resample 24kHz → 16kHz to match Flutter player expectations
audio_np = np.frombuffer(raw_pcm, dtype=np.int16).astype(np.float32)
target_len = int(len(audio_np) / 24000 * 16000)
indices = np.linspace(0, len(audio_np) - 1, target_len)
resampled = np.interp(indices, np.arange(len(audio_np)), audio_np)
pcm_16k = resampled.clip(-32768, 32767).astype(np.int16).tobytes()
return _make_wav(pcm_16k, 16000)
try:
wav_bytes = await loop.run_in_executor(None, _synth)
return Response(content=wav_bytes, media_type="audio/wav")
except Exception as e:
return Response(content=f"OpenAI TTS error: {e}", status_code=500)
@router.post("/stt/transcribe-openai")
async def stt_transcribe_openai(
audio: UploadFile = File(...),
model: str = Query("gpt-4o-transcribe", regex="^(whisper-1|gpt-4o-transcribe|gpt-4o-mini-transcribe)$"),
):
"""Transcribe uploaded audio via OpenAI STT API."""
client = _get_openai_client()
if client is None:
return {"error": "OPENAI_API_KEY not configured", "text": ""}
raw = await audio.read()
suffix = os.path.splitext(audio.filename or "audio.wav")[1] or ".wav"
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as f:
f.write(raw)
tmp_path = f.name
try:
loop = asyncio.get_event_loop()
def _transcribe():
with open(tmp_path, "rb") as af:
result = client.audio.transcriptions.create(
model=model,
file=af,
language="zh",
)
return result.text
text = await loop.run_in_executor(None, _transcribe)
return {"text": text, "language": "zh", "model": model}
except Exception as e:
return {"error": f"OpenAI STT error: {e}", "text": ""}
finally:
os.unlink(tmp_path)