"""Temporary test endpoints for TTS and STT — browser-accessible."""
|
||
|
||
import asyncio
|
||
import io
|
||
import os
|
||
import struct
|
||
import tempfile
|
||
import threading
|
||
import numpy as np
|
||
from fastapi import APIRouter, Request, Query, UploadFile, File
|
||
from fastapi.responses import HTMLResponse, Response
|
||
|
||
router = APIRouter()
|
||
|
||
_loading_lock = threading.Lock()
|
||
|
||
|
||
def _ensure_local_tts(app) -> bool:
    """Lazy-load Kokoro TTS if not yet loaded. Returns True if available."""

    def _ready() -> bool:
        # Loaded means a service object exists AND its pipeline was built.
        service = getattr(app.state, "tts", None)
        return service is not None and service._pipeline is not None

    if _ready():
        return True
    with _loading_lock:
        # Double-checked locking: another request may have finished loading
        # while we were waiting on the lock.
        if _ready():
            return True
        try:
            from ..tts.kokoro_service import KokoroTTSService, _patch_misaki_compat
            from ..config.settings import settings
            # Compat patch must run before importing kokoro itself.
            _patch_misaki_compat()
            from kokoro import KPipeline
            service = KokoroTTSService(model=settings.kokoro_model, voice=settings.kokoro_voice)
            # NOTE(review): lang_code 'z' appears to select Chinese in Kokoro —
            # confirm this matches the configured voice.
            service._pipeline = KPipeline(lang_code='z')
            app.state.tts = service
            print(f"[lazy] Local TTS loaded: {settings.kokoro_model} voice={settings.kokoro_voice}", flush=True)
            return True
        except Exception as e:
            # Best-effort loader: report failure, let callers return 503.
            print(f"[lazy] Failed to load local TTS: {e}", flush=True)
            return False
def _ensure_local_stt(app) -> bool:
    """Lazy-load faster-whisper STT if not yet loaded. Returns True if available.

    Thread-safe: concurrent first calls are serialized by ``_loading_lock``
    (double-checked locking), so the model is loaded at most once.
    """
    stt = getattr(app.state, "stt", None)
    if stt is not None and stt._model is not None:
        return True
    with _loading_lock:
        # Double-check after acquiring lock: another request may have
        # finished loading while we waited.
        stt = getattr(app.state, "stt", None)
        if stt is not None and stt._model is not None:
            return True
        try:
            from ..stt.whisper_service import WhisperSTTService
            from faster_whisper import WhisperModel
            from ..config.settings import settings
            svc = WhisperSTTService(
                model=settings.whisper_model,
                device=settings.device,
                language=settings.whisper_language,
            )
            # float16 requires CUDA; int8 is the safe CPU default.
            compute_type = "float16" if settings.device == "cuda" else "int8"
            try:
                svc._model = WhisperModel(svc.model_name, device=svc.device, compute_type=compute_type)
            except Exception:
                # GPU load failed — retry on CPU. If we were already on CPU
                # there is no fallback: re-raise so the outer handler reports
                # failure. (Previously this path silently fell through and the
                # function returned True with svc._model still None, crashing
                # the first transcription request.)
                if svc.device == "cpu":
                    raise
                svc._model = WhisperModel(svc.model_name, device="cpu", compute_type="int8")
            app.state.stt = svc
            print(f"[lazy] Local STT loaded: {settings.whisper_model}", flush=True)
            return True
        except Exception as e:
            # Best-effort loader: report failure, let callers return an error.
            print(f"[lazy] Failed to load local STT: {e}", flush=True)
            return False
# Kokoro emits 24 kHz mono audio; the endpoints below resample to 16 kHz WAV.
_SAMPLE_RATE = 24000  # Kokoro native output rate
def _make_wav(pcm_bytes: bytes, sample_rate: int = 16000) -> bytes:
|
||
"""Wrap raw 16-bit PCM into a WAV container."""
|
||
buf = io.BytesIO()
|
||
num_samples = len(pcm_bytes) // 2
|
||
data_size = num_samples * 2
|
||
# WAV header
|
||
buf.write(b"RIFF")
|
||
buf.write(struct.pack("<I", 36 + data_size))
|
||
buf.write(b"WAVE")
|
||
buf.write(b"fmt ")
|
||
buf.write(struct.pack("<I", 16)) # chunk size
|
||
buf.write(struct.pack("<H", 1)) # PCM
|
||
buf.write(struct.pack("<H", 1)) # mono
|
||
buf.write(struct.pack("<I", sample_rate)) # sample rate
|
||
buf.write(struct.pack("<I", sample_rate * 2)) # byte rate
|
||
buf.write(struct.pack("<H", 2)) # block align
|
||
buf.write(struct.pack("<H", 16)) # bits per sample
|
||
buf.write(b"data")
|
||
buf.write(struct.pack("<I", data_size))
|
||
buf.write(pcm_bytes)
|
||
return buf.getvalue()
|
||
|
||
|
||
@router.get("/tts", response_class=HTMLResponse)
|
||
async def tts_test_page():
|
||
"""Combined TTS + STT test page."""
|
||
return """<!DOCTYPE html>
|
||
<html><head><meta charset="utf-8"><title>Voice Test</title>
|
||
<style>
|
||
body { font-family: sans-serif; max-width: 700px; margin: 30px auto; padding: 0 20px; }
|
||
h2 { border-bottom: 2px solid #333; padding-bottom: 8px; }
|
||
textarea { width: 100%; height: 70px; font-size: 15px; }
|
||
button { font-size: 16px; padding: 8px 24px; margin: 8px 4px 8px 0; cursor: pointer; border-radius: 4px; border: 1px solid #999; }
|
||
button:hover { background: #e0e0e0; }
|
||
.recording { background: #ff4444 !important; color: white !important; }
|
||
.status { margin-top: 10px; color: #666; font-size: 14px; }
|
||
audio { margin-top: 10px; width: 100%; }
|
||
.section { background: #f8f8f8; padding: 20px; border-radius: 8px; margin-bottom: 20px; }
|
||
#stt-result { font-size: 18px; color: #333; margin-top: 10px; padding: 10px; background: white; border: 1px solid #ddd; border-radius: 4px; min-height: 40px; }
|
||
</style></head>
|
||
<body>
|
||
<h2>Voice I/O Test</h2>
|
||
|
||
<div class="section">
|
||
<h3>TTS (Text to Speech)</h3>
|
||
<textarea id="tts-text" placeholder="输入要合成的文本...">你好,我是IT0运维助手。很高兴为您服务!</textarea>
|
||
<br><button onclick="doTTS()">合成语音</button>
|
||
<div class="status" id="tts-status"></div>
|
||
<audio id="tts-player" controls style="display:none"></audio>
|
||
</div>
|
||
|
||
<div class="section">
|
||
<h3>STT (Speech to Text)</h3>
|
||
<p style="font-size:14px;color:#888;">点击录音按钮说话,松开后自动识别。或上传音频文件。</p>
|
||
<button id="rec-btn" onmousedown="startRec()" onmouseup="stopRec()" ontouchstart="startRec()" ontouchend="stopRec()">按住录音</button>
|
||
<label style="cursor:pointer;border:1px solid #999;padding:8px 24px;border-radius:4px;font-size:16px;">
|
||
上传音频 <input type="file" id="audio-file" accept="audio/*" style="display:none" onchange="uploadAudio(this)">
|
||
</label>
|
||
<div class="status" id="stt-status"></div>
|
||
<div id="stt-result"></div>
|
||
<audio id="stt-player" controls style="display:none"></audio>
|
||
</div>
|
||
|
||
<div class="section">
|
||
<h3>Round-trip (STT + TTS)</h3>
|
||
<p style="font-size:14px;color:#888;">录音 → 识别文本 → 再合成语音播放。测试全链路。</p>
|
||
<button id="rt-btn" onmousedown="startRoundTrip()" onmouseup="stopRoundTrip()" ontouchstart="startRoundTrip()" ontouchend="stopRoundTrip()">按住说话 (Round-trip)</button>
|
||
<div class="status" id="rt-status"></div>
|
||
<div id="rt-result"></div>
|
||
<audio id="rt-player" controls style="display:none"></audio>
|
||
</div>
|
||
|
||
<script>
|
||
let mediaRec, audioChunks, recMode;
|
||
|
||
async function doTTS() {
|
||
const text = document.getElementById('tts-text').value.trim();
|
||
if (!text) return;
|
||
const st = document.getElementById('tts-status');
|
||
const pl = document.getElementById('tts-player');
|
||
st.textContent = '合成中...'; pl.style.display = 'none';
|
||
const t0 = Date.now();
|
||
try {
|
||
const r = await fetch('/api/v1/test/tts/synthesize?text=' + encodeURIComponent(text));
|
||
if (!r.ok) { st.textContent = 'Error: ' + r.status + ' ' + await r.text(); return; }
|
||
const blob = await r.blob();
|
||
st.textContent = '完成!耗时 ' + (Date.now()-t0) + 'ms,大小 ' + (blob.size/1024).toFixed(1) + 'KB';
|
||
pl.src = URL.createObjectURL(blob); pl.style.display = 'block'; pl.play();
|
||
} catch(e) { st.textContent = 'Error: ' + e.message; }
|
||
}
|
||
|
||
function startRec() { _startRec('stt'); }
|
||
function stopRec() { _stopRec('stt'); }
|
||
function startRoundTrip() { _startRec('rt'); }
|
||
function stopRoundTrip() { _stopRec('rt'); }
|
||
|
||
async function _startRec(mode) {
|
||
recMode = mode;
|
||
const btn = document.getElementById(mode === 'rt' ? 'rt-btn' : 'rec-btn');
|
||
btn.classList.add('recording');
|
||
btn.textContent = '录音中...';
|
||
audioChunks = [];
|
||
try {
|
||
const stream = await navigator.mediaDevices.getUserMedia({ audio: { sampleRate: 16000, channelCount: 1 } });
|
||
mediaRec = new MediaRecorder(stream, { mimeType: 'audio/webm;codecs=opus' });
|
||
mediaRec.ondataavailable = e => { if (e.data.size > 0) audioChunks.push(e.data); };
|
||
mediaRec.onstop = () => {
|
||
stream.getTracks().forEach(t => t.stop());
|
||
const blob = new Blob(audioChunks, { type: 'audio/webm' });
|
||
if (mode === 'rt') doRoundTrip(blob);
|
||
else doSTT(blob);
|
||
};
|
||
mediaRec.start();
|
||
} catch(e) {
|
||
btn.classList.remove('recording');
|
||
btn.textContent = mode === 'rt' ? '按住说话 (Round-trip)' : '按住录音';
|
||
alert('麦克风权限被拒绝: ' + e.message);
|
||
}
|
||
}
|
||
|
||
function _stopRec(mode) {
|
||
const btn = document.getElementById(mode === 'rt' ? 'rt-btn' : 'rec-btn');
|
||
btn.classList.remove('recording');
|
||
btn.textContent = mode === 'rt' ? '按住说话 (Round-trip)' : '按住录音';
|
||
if (mediaRec && mediaRec.state === 'recording') mediaRec.stop();
|
||
}
|
||
|
||
async function doSTT(blob) {
|
||
const st = document.getElementById('stt-status');
|
||
const res = document.getElementById('stt-result');
|
||
const pl = document.getElementById('stt-player');
|
||
st.textContent = '识别中...'; res.textContent = '';
|
||
pl.src = URL.createObjectURL(blob); pl.style.display = 'block';
|
||
const t0 = Date.now();
|
||
try {
|
||
const fd = new FormData(); fd.append('audio', blob, 'recording.webm');
|
||
const r = await fetch('/api/v1/test/stt/transcribe', { method: 'POST', body: fd });
|
||
const data = await r.json();
|
||
st.textContent = '完成!耗时 ' + (Date.now()-t0) + 'ms';
|
||
res.textContent = data.text || '(empty)';
|
||
} catch(e) { st.textContent = 'Error: ' + e.message; }
|
||
}
|
||
|
||
async function uploadAudio(input) {
|
||
if (!input.files[0]) return;
|
||
const blob = input.files[0];
|
||
const st = document.getElementById('stt-status');
|
||
const res = document.getElementById('stt-result');
|
||
const pl = document.getElementById('stt-player');
|
||
st.textContent = '识别中...'; res.textContent = '';
|
||
pl.src = URL.createObjectURL(blob); pl.style.display = 'block';
|
||
const t0 = Date.now();
|
||
try {
|
||
const fd = new FormData(); fd.append('audio', blob, blob.name);
|
||
const r = await fetch('/api/v1/test/stt/transcribe', { method: 'POST', body: fd });
|
||
const data = await r.json();
|
||
st.textContent = '完成!耗时 ' + (Date.now()-t0) + 'ms';
|
||
res.textContent = data.text || '(empty)';
|
||
} catch(e) { st.textContent = 'Error: ' + e.message; }
|
||
input.value = '';
|
||
}
|
||
|
||
async function doRoundTrip(blob) {
|
||
const st = document.getElementById('rt-status');
|
||
const res = document.getElementById('rt-result');
|
||
const pl = document.getElementById('rt-player');
|
||
st.textContent = 'STT识别中...'; res.textContent = ''; pl.style.display = 'none';
|
||
const t0 = Date.now();
|
||
try {
|
||
// 1. STT
|
||
const fd = new FormData(); fd.append('audio', blob, 'recording.webm');
|
||
const r1 = await fetch('/api/v1/test/stt/transcribe', { method: 'POST', body: fd });
|
||
const sttData = await r1.json();
|
||
const text = sttData.text || '';
|
||
const sttMs = Date.now() - t0;
|
||
res.textContent = 'STT (' + sttMs + 'ms): ' + (text || '(empty)');
|
||
if (!text) { st.textContent = '识别为空'; return; }
|
||
// 2. TTS
|
||
st.textContent = 'TTS合成中...';
|
||
const t1 = Date.now();
|
||
const r2 = await fetch('/api/v1/test/tts/synthesize?text=' + encodeURIComponent(text));
|
||
if (!r2.ok) { st.textContent = 'TTS Error: ' + r2.status; return; }
|
||
const audioBlob = await r2.blob();
|
||
const ttsMs = Date.now() - t1;
|
||
const totalMs = Date.now() - t0;
|
||
st.textContent = '完成!STT=' + sttMs + 'ms + TTS=' + ttsMs + 'ms = 总计' + totalMs + 'ms';
|
||
res.textContent += '\\nTTS (' + ttsMs + 'ms): ' + (audioBlob.size/1024).toFixed(1) + 'KB';
|
||
pl.src = URL.createObjectURL(audioBlob); pl.style.display = 'block'; pl.play();
|
||
} catch(e) { st.textContent = 'Error: ' + e.message; }
|
||
}
|
||
</script>
|
||
</body></html>"""
|
||
|
||
|
||
@router.get("/tts/synthesize")
|
||
async def tts_synthesize(request: Request, text: str = Query(..., min_length=1, max_length=500)):
|
||
"""Synthesize text to WAV audio (lazy-loads local Kokoro model on first call)."""
|
||
loaded = await asyncio.get_event_loop().run_in_executor(None, _ensure_local_tts, request.app)
|
||
if not loaded:
|
||
return Response(content="Failed to load local TTS model", status_code=503)
|
||
tts = request.app.state.tts
|
||
|
||
loop = asyncio.get_event_loop()
|
||
def _synth():
|
||
samples = []
|
||
for _, _, audio in tts._pipeline(text, voice=tts.voice):
|
||
if hasattr(audio, "numpy"):
|
||
samples.append(audio.numpy())
|
||
else:
|
||
samples.append(audio)
|
||
if not samples:
|
||
return b""
|
||
audio_np = np.concatenate(samples)
|
||
# Resample 24kHz → 16kHz
|
||
target_len = int(len(audio_np) / _SAMPLE_RATE * 16000)
|
||
indices = np.linspace(0, len(audio_np) - 1, target_len)
|
||
resampled = np.interp(indices, np.arange(len(audio_np)), audio_np)
|
||
pcm = (resampled * 32768).clip(-32768, 32767).astype(np.int16).tobytes()
|
||
return _make_wav(pcm, 16000)
|
||
|
||
wav_bytes = await loop.run_in_executor(None, _synth)
|
||
if not wav_bytes:
|
||
return Response(content="TTS produced no audio", status_code=500)
|
||
|
||
return Response(content=wav_bytes, media_type="audio/wav")
|
||
|
||
|
||
@router.post("/stt/transcribe")
|
||
async def stt_transcribe(request: Request, audio: UploadFile = File(...)):
|
||
"""Transcribe uploaded audio to text via faster-whisper (lazy-loads on first call)."""
|
||
loaded = await asyncio.get_event_loop().run_in_executor(None, _ensure_local_stt, request.app)
|
||
if not loaded:
|
||
return {"error": "Failed to load local STT model", "text": ""}
|
||
stt = request.app.state.stt
|
||
|
||
# Save uploaded file to temp
|
||
raw = await audio.read()
|
||
suffix = os.path.splitext(audio.filename or "audio.webm")[1] or ".webm"
|
||
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as f:
|
||
f.write(raw)
|
||
tmp_path = f.name
|
||
|
||
try:
|
||
# faster-whisper can handle webm/mp3/wav etc. directly
|
||
loop = asyncio.get_event_loop()
|
||
def _transcribe():
|
||
segments, info = stt._model.transcribe(
|
||
tmp_path,
|
||
language=stt.language if hasattr(stt, 'language') and stt.language else None,
|
||
beam_size=5,
|
||
vad_filter=True,
|
||
)
|
||
text = "".join(seg.text for seg in segments).strip()
|
||
return text, info
|
||
|
||
text, info = await loop.run_in_executor(None, _transcribe)
|
||
return {
|
||
"text": text,
|
||
"language": getattr(info, "language", ""),
|
||
"duration": round(getattr(info, "duration", 0), 2),
|
||
}
|
||
finally:
|
||
os.unlink(tmp_path)
|
||
|
||
|
||
# =====================================================================
# OpenAI Voice API endpoints
# =====================================================================

def _get_openai_client():
    """Lazy-init OpenAI client with proxy support.

    Returns ``None`` when OPENAI_API_KEY is unset; otherwise an ``OpenAI``
    client, optionally pointed at OPENAI_BASE_URL.
    """
    from openai import OpenAI
    import httpx

    api_key = os.environ.get("OPENAI_API_KEY")
    base_url = os.environ.get("OPENAI_BASE_URL")
    if not api_key:
        return None

    client_kwargs = {"api_key": api_key}
    if base_url:
        client_kwargs["base_url"] = base_url
        # SECURITY: SSL verification is disabled deliberately to tolerate
        # self-signed proxy certificates on the custom base URL — do not
        # copy this pattern into production code paths.
        client_kwargs["http_client"] = httpx.Client(verify=False)
    return OpenAI(**client_kwargs)
@router.get("/tts/synthesize-openai")
|
||
async def tts_synthesize_openai(
|
||
text: str = Query(..., min_length=1, max_length=500),
|
||
model: str = Query("tts-1", regex="^(tts-1|tts-1-hd|gpt-4o-mini-tts)$"),
|
||
voice: str = Query("alloy", regex="^(alloy|ash|ballad|coral|echo|fable|nova|onyx|sage|shimmer)$"),
|
||
):
|
||
"""Synthesize text to audio via OpenAI TTS API, resampled to 16kHz WAV."""
|
||
client = _get_openai_client()
|
||
if client is None:
|
||
return Response(content="OPENAI_API_KEY not configured", status_code=503)
|
||
|
||
loop = asyncio.get_event_loop()
|
||
def _synth():
|
||
response = client.audio.speech.create(
|
||
model=model,
|
||
voice=voice,
|
||
input=text,
|
||
response_format="pcm", # raw 24kHz 16-bit mono PCM (no header)
|
||
)
|
||
raw_pcm = response.content
|
||
# Resample 24kHz → 16kHz to match Flutter player expectations
|
||
audio_np = np.frombuffer(raw_pcm, dtype=np.int16).astype(np.float32)
|
||
target_len = int(len(audio_np) / 24000 * 16000)
|
||
indices = np.linspace(0, len(audio_np) - 1, target_len)
|
||
resampled = np.interp(indices, np.arange(len(audio_np)), audio_np)
|
||
pcm_16k = resampled.clip(-32768, 32767).astype(np.int16).tobytes()
|
||
return _make_wav(pcm_16k, 16000)
|
||
|
||
try:
|
||
wav_bytes = await loop.run_in_executor(None, _synth)
|
||
return Response(content=wav_bytes, media_type="audio/wav")
|
||
except Exception as e:
|
||
return Response(content=f"OpenAI TTS error: {e}", status_code=500)
|
||
|
||
|
||
@router.post("/stt/transcribe-openai")
|
||
async def stt_transcribe_openai(
|
||
audio: UploadFile = File(...),
|
||
model: str = Query("gpt-4o-transcribe", regex="^(whisper-1|gpt-4o-transcribe|gpt-4o-mini-transcribe)$"),
|
||
):
|
||
"""Transcribe uploaded audio via OpenAI STT API."""
|
||
client = _get_openai_client()
|
||
if client is None:
|
||
return {"error": "OPENAI_API_KEY not configured", "text": ""}
|
||
|
||
raw = await audio.read()
|
||
suffix = os.path.splitext(audio.filename or "audio.wav")[1] or ".wav"
|
||
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as f:
|
||
f.write(raw)
|
||
tmp_path = f.name
|
||
|
||
try:
|
||
loop = asyncio.get_event_loop()
|
||
def _transcribe():
|
||
with open(tmp_path, "rb") as af:
|
||
result = client.audio.transcriptions.create(
|
||
model=model,
|
||
file=af,
|
||
language="zh",
|
||
)
|
||
return result.text
|
||
|
||
text = await loop.run_in_executor(None, _transcribe)
|
||
return {"text": text, "language": "zh", "model": model}
|
||
except Exception as e:
|
||
return {"error": f"OpenAI STT error: {e}", "text": ""}
|
||
finally:
|
||
os.unlink(tmp_path)
|