feat: add OpenAI TTS/STT API endpoints for comparison testing

- Add openai package to voice-service requirements
- Add /api/v1/test/tts/synthesize-openai (tts-1/tts-1-hd/gpt-4o-mini-tts)
- Add /api/v1/test/stt/transcribe-openai (gpt-4o-transcribe/whisper-1)
- Add OPENAI_API_KEY and OPENAI_BASE_URL env vars to voice-service
- Flutter test page: SegmentedButton to toggle Local/OpenAI provider
- All endpoints maintain same response format for easy comparison

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
hailin 2026-02-24 07:20:03 -08:00
parent ac0b8ee1c6
commit d43baed3a5
4 changed files with 172 additions and 34 deletions

View File

@ -327,6 +327,8 @@ services:
- KOKORO_MODEL=${KOKORO_MODEL:-kokoro-82m} - KOKORO_MODEL=${KOKORO_MODEL:-kokoro-82m}
- KOKORO_VOICE=${KOKORO_VOICE:-zf_xiaoxiao} - KOKORO_VOICE=${KOKORO_VOICE:-zf_xiaoxiao}
- DEVICE=${VOICE_DEVICE:-cpu} - DEVICE=${VOICE_DEVICE:-cpu}
- OPENAI_API_KEY=${OPENAI_API_KEY}
- OPENAI_BASE_URL=${OPENAI_BASE_URL}
healthcheck: healthcheck:
test: ["CMD-SHELL", "python3 -c \"import urllib.request; urllib.request.urlopen('http://localhost:3008/docs')\""] test: ["CMD-SHELL", "python3 -c \"import urllib.request; urllib.request.urlopen('http://localhost:3008/docs')\""]
interval: 30s interval: 30s

View File

@ -10,8 +10,10 @@ import 'package:path/path.dart' as p;
import 'package:permission_handler/permission_handler.dart'; import 'package:permission_handler/permission_handler.dart';
import '../../../../core/network/dio_client.dart'; import '../../../../core/network/dio_client.dart';
/// Which backend serves TTS/STT requests on the voice test page:
/// `local` uses the on-device endpoints, `openai` the OpenAI proxy endpoints.
enum _VoiceProvider { local, openai }
/// Temporary voice I/O test page TTS + STT + Round-trip. /// Temporary voice I/O test page TTS + STT + Round-trip.
/// Uses flutter_sound for both recording and playback. /// Supports switching between Local (Kokoro/faster-whisper) and OpenAI APIs.
class VoiceTestPage extends ConsumerStatefulWidget { class VoiceTestPage extends ConsumerStatefulWidget {
const VoiceTestPage({super.key}); const VoiceTestPage({super.key});
@ -28,6 +30,8 @@ class _VoiceTestPageState extends ConsumerState<VoiceTestPage> {
bool _playerReady = false; bool _playerReady = false;
bool _recorderReady = false; bool _recorderReady = false;
_VoiceProvider _provider = _VoiceProvider.local;
String _ttsStatus = ''; String _ttsStatus = '';
String _sttStatus = ''; String _sttStatus = '';
String _sttResult = ''; String _sttResult = '';
@ -38,13 +42,22 @@ class _VoiceTestPageState extends ConsumerState<VoiceTestPage> {
String _recordMode = ''; // 'stt' or 'rt' String _recordMode = ''; // 'stt' or 'rt'
String? _recordingPath; String? _recordingPath;
// Provider-specific REST paths; the `-openai` variants proxy to OpenAI.
String get _ttsEndpoint => switch (_provider) {
      _VoiceProvider.openai => '/api/v1/test/tts/synthesize-openai',
      _VoiceProvider.local => '/api/v1/test/tts/synthesize',
    };

String get _sttEndpoint => switch (_provider) {
      _VoiceProvider.openai => '/api/v1/test/stt/transcribe-openai',
      _VoiceProvider.local => '/api/v1/test/stt/transcribe',
    };
Dio get _dioBinary { Dio get _dioBinary {
final base = ref.read(dioClientProvider); final base = ref.read(dioClientProvider);
return Dio(BaseOptions( return Dio(BaseOptions(
baseUrl: base.options.baseUrl, baseUrl: base.options.baseUrl,
headers: Map.from(base.options.headers), headers: Map.from(base.options.headers),
connectTimeout: const Duration(seconds: 30), connectTimeout: const Duration(seconds: 30),
receiveTimeout: const Duration(seconds: 60), receiveTimeout: const Duration(seconds: 120),
responseType: ResponseType.bytes, responseType: ResponseType.bytes,
))..interceptors.addAll(base.interceptors); ))..interceptors.addAll(base.interceptors);
} }
@ -77,26 +90,27 @@ class _VoiceTestPageState extends ConsumerState<VoiceTestPage> {
Future<void> _doTTS() async { Future<void> _doTTS() async {
final text = _ttsController.text.trim(); final text = _ttsController.text.trim();
if (text.isEmpty) return; if (text.isEmpty) return;
final label = _provider == _VoiceProvider.openai ? 'OpenAI' : 'Local';
setState(() { setState(() {
_isSynthesizing = true; _isSynthesizing = true;
_ttsStatus = '合成中...'; _ttsStatus = '[$label] 合成中...';
}); });
final sw = Stopwatch()..start(); final sw = Stopwatch()..start();
try { try {
final resp = await _dioBinary.get( final resp = await _dioBinary.get(
'/api/v1/test/tts/synthesize', _ttsEndpoint,
queryParameters: {'text': text}, queryParameters: {'text': text},
); );
sw.stop(); sw.stop();
final bytes = resp.data as List<int>; final bytes = resp.data as List<int>;
setState(() { setState(() {
_ttsStatus = _ttsStatus =
'完成!耗时 ${sw.elapsedMilliseconds}ms大小 ${(bytes.length / 1024).toStringAsFixed(1)}KB'; '[$label] 完成!耗时 ${sw.elapsedMilliseconds}ms大小 ${(bytes.length / 1024).toStringAsFixed(1)}KB';
}); });
await _playWav(Uint8List.fromList(bytes)); await _playWav(Uint8List.fromList(bytes));
} catch (e) { } catch (e) {
sw.stop(); sw.stop();
setState(() => _ttsStatus = '错误: $e'); setState(() => _ttsStatus = '[$label] 错误: $e');
} finally { } finally {
setState(() => _isSynthesizing = false); setState(() => _isSynthesizing = false);
} }
@ -161,26 +175,26 @@ class _VoiceTestPageState extends ConsumerState<VoiceTestPage> {
// ---- STT ---- // ---- STT ----
Future<void> _doSTT(String audioPath) async { Future<void> _doSTT(String audioPath) async {
setState(() => _sttStatus = '识别中...'); final label = _provider == _VoiceProvider.openai ? 'OpenAI' : 'Local';
setState(() => _sttStatus = '[$label] 识别中...');
final sw = Stopwatch()..start(); final sw = Stopwatch()..start();
try { try {
final formData = FormData.fromMap({ final formData = FormData.fromMap({
'audio': 'audio':
await MultipartFile.fromFile(audioPath, filename: 'recording.wav'), await MultipartFile.fromFile(audioPath, filename: 'recording.wav'),
}); });
final resp = final resp = await _dioJson.post(_sttEndpoint, data: formData);
await _dioJson.post('/api/v1/test/stt/transcribe', data: formData);
sw.stop(); sw.stop();
final data = resp.data as Map<String, dynamic>; final data = resp.data as Map<String, dynamic>;
setState(() { setState(() {
_sttResult = data['text'] ?? '(empty)'; _sttResult = data['text'] ?? '(empty)';
_sttStatus = final extra = data['duration'] != null ? ',时长 ${data['duration']}s' : '';
'完成!耗时 ${sw.elapsedMilliseconds}ms时长 ${data['duration'] ?? 0}s'; _sttStatus = '[$label] 完成!耗时 ${sw.elapsedMilliseconds}ms$extra';
}); });
} catch (e) { } catch (e) {
sw.stop(); sw.stop();
setState(() { setState(() {
_sttStatus = '错误: $e'; _sttStatus = '[$label] 错误: $e';
_sttResult = ''; _sttResult = '';
}); });
} finally { } finally {
@ -190,7 +204,8 @@ class _VoiceTestPageState extends ConsumerState<VoiceTestPage> {
// ---- Round-trip: STT TTS ---- // ---- Round-trip: STT TTS ----
Future<void> _doRoundTrip(String audioPath) async { Future<void> _doRoundTrip(String audioPath) async {
setState(() => _rtStatus = 'STT 识别中...'); final label = _provider == _VoiceProvider.openai ? 'OpenAI' : 'Local';
setState(() => _rtStatus = '[$label] STT 识别中...');
final totalSw = Stopwatch()..start(); final totalSw = Stopwatch()..start();
try { try {
// 1. STT // 1. STT
@ -199,24 +214,23 @@ class _VoiceTestPageState extends ConsumerState<VoiceTestPage> {
'audio': 'audio':
await MultipartFile.fromFile(audioPath, filename: 'recording.wav'), await MultipartFile.fromFile(audioPath, filename: 'recording.wav'),
}); });
final sttResp = final sttResp = await _dioJson.post(_sttEndpoint, data: formData);
await _dioJson.post('/api/v1/test/stt/transcribe', data: formData);
sttSw.stop(); sttSw.stop();
final sttData = sttResp.data as Map<String, dynamic>; final sttData = sttResp.data as Map<String, dynamic>;
final text = sttData['text'] ?? ''; final text = sttData['text'] ?? '';
setState(() { setState(() {
_rtResult = 'STT (${sttSw.elapsedMilliseconds}ms): $text'; _rtResult = '[$label] STT (${sttSw.elapsedMilliseconds}ms): $text';
_rtStatus = 'TTS 合成中...'; _rtStatus = '[$label] TTS 合成中...';
}); });
if (text.isEmpty) { if (text.isEmpty) {
setState(() => _rtStatus = 'STT 识别为空'); setState(() => _rtStatus = '[$label] STT 识别为空');
return; return;
} }
// 2. TTS // 2. TTS
final ttsSw = Stopwatch()..start(); final ttsSw = Stopwatch()..start();
final ttsResp = await _dioBinary.get( final ttsResp = await _dioBinary.get(
'/api/v1/test/tts/synthesize', _ttsEndpoint,
queryParameters: {'text': text}, queryParameters: {'text': text},
); );
ttsSw.stop(); ttsSw.stop();
@ -224,14 +238,14 @@ class _VoiceTestPageState extends ConsumerState<VoiceTestPage> {
final audioBytes = ttsResp.data as List<int>; final audioBytes = ttsResp.data as List<int>;
setState(() { setState(() {
_rtResult += _rtResult +=
'\nTTS (${ttsSw.elapsedMilliseconds}ms): ${(audioBytes.length / 1024).toStringAsFixed(1)}KB'; '\n[$label] TTS (${ttsSw.elapsedMilliseconds}ms): ${(audioBytes.length / 1024).toStringAsFixed(1)}KB';
_rtStatus = _rtStatus =
'完成!STT=${sttSw.elapsedMilliseconds}ms + TTS=${ttsSw.elapsedMilliseconds}ms = ${totalSw.elapsedMilliseconds}ms'; '[$label] STT=${sttSw.elapsedMilliseconds}ms + TTS=${ttsSw.elapsedMilliseconds}ms = ${totalSw.elapsedMilliseconds}ms';
}); });
await _playWav(Uint8List.fromList(audioBytes)); await _playWav(Uint8List.fromList(audioBytes));
} catch (e) { } catch (e) {
totalSw.stop(); totalSw.stop();
setState(() => _rtStatus = '错误: $e'); setState(() => _rtStatus = '[$label] 错误: $e');
} finally { } finally {
_cleanupFile(audioPath); _cleanupFile(audioPath);
} }
@ -259,6 +273,39 @@ class _VoiceTestPageState extends ConsumerState<VoiceTestPage> {
body: ListView( body: ListView(
padding: const EdgeInsets.all(16), padding: const EdgeInsets.all(16),
children: [ children: [
// Provider toggle
Center(
child: SegmentedButton<_VoiceProvider>(
segments: const [
ButtonSegment(
value: _VoiceProvider.local,
label: Text('Local'),
icon: Icon(Icons.computer, size: 18),
),
ButtonSegment(
value: _VoiceProvider.openai,
label: Text('OpenAI'),
icon: Icon(Icons.cloud, size: 18),
),
],
selected: {_provider},
onSelectionChanged: (Set<_VoiceProvider> sel) {
setState(() => _provider = sel.first);
},
),
),
Padding(
padding: const EdgeInsets.only(top: 4, bottom: 12),
child: Center(
child: Text(
_provider == _VoiceProvider.openai
? 'STT: gpt-4o-transcribe | TTS: tts-1'
: 'STT: faster-whisper | TTS: Kokoro',
style: TextStyle(fontSize: 12, color: Colors.grey[500]),
),
),
),
// TTS Section
_buildSection( _buildSection(
title: 'TTS (文本转语音)', title: 'TTS (文本转语音)',
child: Column( child: Column(
@ -287,12 +334,14 @@ class _VoiceTestPageState extends ConsumerState<VoiceTestPage> {
Padding( Padding(
padding: const EdgeInsets.only(top: 8), padding: const EdgeInsets.only(top: 8),
child: Text(_ttsStatus, child: Text(_ttsStatus,
style: TextStyle(color: Colors.grey[600], fontSize: 13)), style:
TextStyle(color: Colors.grey[600], fontSize: 13)),
), ),
], ],
), ),
), ),
const SizedBox(height: 16), const SizedBox(height: 16),
// STT Section
_buildSection( _buildSection(
title: 'STT (语音转文本)', title: 'STT (语音转文本)',
child: Column( child: Column(
@ -304,8 +353,9 @@ class _VoiceTestPageState extends ConsumerState<VoiceTestPage> {
child: ElevatedButton.icon( child: ElevatedButton.icon(
onPressed: () {}, onPressed: () {},
style: ElevatedButton.styleFrom( style: ElevatedButton.styleFrom(
backgroundColor: backgroundColor: _isRecording && _recordMode == 'stt'
_isRecording && _recordMode == 'stt' ? Colors.red : null, ? Colors.red
: null,
), ),
icon: Icon(_isRecording && _recordMode == 'stt' icon: Icon(_isRecording && _recordMode == 'stt'
? Icons.mic ? Icons.mic
@ -319,7 +369,8 @@ class _VoiceTestPageState extends ConsumerState<VoiceTestPage> {
Padding( Padding(
padding: const EdgeInsets.only(top: 8), padding: const EdgeInsets.only(top: 8),
child: Text(_sttStatus, child: Text(_sttStatus,
style: TextStyle(color: Colors.grey[600], fontSize: 13)), style:
TextStyle(color: Colors.grey[600], fontSize: 13)),
), ),
if (_sttResult.isNotEmpty) if (_sttResult.isNotEmpty)
Container( Container(
@ -330,12 +381,14 @@ class _VoiceTestPageState extends ConsumerState<VoiceTestPage> {
color: Colors.grey[100], color: Colors.grey[100],
borderRadius: BorderRadius.circular(8), borderRadius: BorderRadius.circular(8),
), ),
child: Text(_sttResult, style: const TextStyle(fontSize: 16)), child:
Text(_sttResult, style: const TextStyle(fontSize: 16)),
), ),
], ],
), ),
), ),
const SizedBox(height: 16), const SizedBox(height: 16),
// Round-trip Section
_buildSection( _buildSection(
title: 'Round-trip (STT + TTS)', title: 'Round-trip (STT + TTS)',
subtitle: '录音 → 识别文本 → 合成语音播放', subtitle: '录音 → 识别文本 → 合成语音播放',
@ -348,8 +401,9 @@ class _VoiceTestPageState extends ConsumerState<VoiceTestPage> {
child: ElevatedButton.icon( child: ElevatedButton.icon(
onPressed: () {}, onPressed: () {},
style: ElevatedButton.styleFrom( style: ElevatedButton.styleFrom(
backgroundColor: backgroundColor: _isRecording && _recordMode == 'rt'
_isRecording && _recordMode == 'rt' ? Colors.red : null, ? Colors.red
: null,
), ),
icon: Icon(_isRecording && _recordMode == 'rt' icon: Icon(_isRecording && _recordMode == 'rt'
? Icons.mic ? Icons.mic
@ -363,7 +417,8 @@ class _VoiceTestPageState extends ConsumerState<VoiceTestPage> {
Padding( Padding(
padding: const EdgeInsets.only(top: 8), padding: const EdgeInsets.only(top: 8),
child: Text(_rtStatus, child: Text(_rtStatus,
style: TextStyle(color: Colors.grey[600], fontSize: 13)), style:
TextStyle(color: Colors.grey[600], fontSize: 13)),
), ),
if (_rtResult.isNotEmpty) if (_rtResult.isNotEmpty)
Container( Container(
@ -374,7 +429,8 @@ class _VoiceTestPageState extends ConsumerState<VoiceTestPage> {
color: Colors.grey[100], color: Colors.grey[100],
borderRadius: BorderRadius.circular(8), borderRadius: BorderRadius.circular(8),
), ),
child: Text(_rtResult, style: const TextStyle(fontSize: 14)), child:
Text(_rtResult, style: const TextStyle(fontSize: 14)),
), ),
], ],
), ),

View File

@ -6,6 +6,7 @@ misaki==0.7.17
silero-vad==5.1 silero-vad==5.1
twilio==9.0.0 twilio==9.0.0
anthropic>=0.32.0 anthropic>=0.32.0
openai>=1.30.0
websockets==12.0 websockets==12.0
pydantic==2.6.0 pydantic==2.6.0
pydantic-settings==2.2.0 pydantic-settings==2.2.0

View File

@ -2,7 +2,9 @@
import asyncio import asyncio
import io import io
import os
import struct import struct
import tempfile
import numpy as np import numpy as np
from fastapi import APIRouter, Request, Query, UploadFile, File from fastapi import APIRouter, Request, Query, UploadFile, File
from fastapi.responses import HTMLResponse, Response from fastapi.responses import HTMLResponse, Response
@ -245,9 +247,6 @@ async def stt_transcribe(request: Request, audio: UploadFile = File(...)):
if stt is None or stt._model is None: if stt is None or stt._model is None:
return {"error": "STT model not loaded", "text": ""} return {"error": "STT model not loaded", "text": ""}
import tempfile
import os
# Save uploaded file to temp # Save uploaded file to temp
raw = await audio.read() raw = await audio.read()
suffix = os.path.splitext(audio.filename or "audio.webm")[1] or ".webm" suffix = os.path.splitext(audio.filename or "audio.webm")[1] or ".webm"
@ -276,3 +275,83 @@ async def stt_transcribe(request: Request, audio: UploadFile = File(...)):
} }
finally: finally:
os.unlink(tmp_path) os.unlink(tmp_path)
# =====================================================================
# OpenAI Voice API endpoints
# =====================================================================
def _get_openai_client():
"""Lazy-init OpenAI client with proxy support."""
from openai import OpenAI
api_key = os.environ.get("OPENAI_API_KEY")
base_url = os.environ.get("OPENAI_BASE_URL")
if not api_key:
return None
kwargs = {"api_key": api_key}
if base_url:
kwargs["base_url"] = base_url
return OpenAI(**kwargs)
@router.get("/tts/synthesize-openai")
async def tts_synthesize_openai(
    text: str = Query(..., min_length=1, max_length=500),
    model: str = Query("tts-1", regex="^(tts-1|tts-1-hd|gpt-4o-mini-tts)$"),
    voice: str = Query("alloy", regex="^(alloy|ash|ballad|coral|echo|fable|nova|onyx|sage|shimmer)$"),
):
    """Synthesize text to WAV audio via the OpenAI TTS API.

    Mirrors the local ``/tts/synthesize`` endpoint's response shape so the
    Flutter test page can switch providers without special-casing.

    Returns:
        200 with ``audio/wav`` bytes on success, 503 when
        ``OPENAI_API_KEY`` is not configured, 500 (error text in the body,
        never a raised exception) on any upstream API failure.
    """
    client = _get_openai_client()
    if client is None:
        return Response(content="OPENAI_API_KEY not configured", status_code=503)

    def _synth():
        # Blocking SDK call; executed in a worker thread below so the
        # event loop stays responsive.
        response = client.audio.speech.create(
            model=model,
            voice=voice,
            input=text,
            response_format="wav",
        )
        return response.content

    try:
        # asyncio.to_thread replaces the deprecated
        # get_event_loop()/run_in_executor pattern inside coroutines.
        wav_bytes = await asyncio.to_thread(_synth)
        return Response(content=wav_bytes, media_type="audio/wav")
    except Exception as e:
        return Response(content=f"OpenAI TTS error: {e}", status_code=500)
@router.post("/stt/transcribe-openai")
async def stt_transcribe_openai(
    audio: UploadFile = File(...),
    model: str = Query("gpt-4o-transcribe", regex="^(whisper-1|gpt-4o-transcribe|gpt-4o-mini-transcribe)$"),
):
    """Transcribe uploaded audio via the OpenAI STT API.

    The upload is spooled to a temp file (the SDK wants a real file handle
    with a meaningful extension) and removed afterwards. Language is pinned
    to ``zh`` to match the local test endpoint. The response shape matches
    the local ``/stt/transcribe`` endpoint for side-by-side comparison.

    Returns:
        ``{"text": ..., "language": "zh", "model": ...}`` on success, or
        ``{"error": ..., "text": ""}`` when unconfigured or on API failure.
    """
    client = _get_openai_client()
    if client is None:
        return {"error": "OPENAI_API_KEY not configured", "text": ""}
    raw = await audio.read()
    # Keep the original extension so the API can sniff the container format.
    suffix = os.path.splitext(audio.filename or "audio.wav")[1] or ".wav"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as f:
        f.write(raw)
        tmp_path = f.name
    try:
        def _transcribe():
            # Blocking SDK call; executed in a worker thread below.
            with open(tmp_path, "rb") as af:
                result = client.audio.transcriptions.create(
                    model=model,
                    file=af,
                    language="zh",
                )
            return result.text

        # asyncio.to_thread replaces the deprecated
        # get_event_loop()/run_in_executor pattern inside coroutines.
        text = await asyncio.to_thread(_transcribe)
        return {"text": text, "language": "zh", "model": model}
    except Exception as e:
        return {"error": f"OpenAI STT error: {e}", "text": ""}
    finally:
        os.unlink(tmp_path)