diff --git a/deploy/docker/docker-compose.yml b/deploy/docker/docker-compose.yml index 6841887..c56609e 100644 --- a/deploy/docker/docker-compose.yml +++ b/deploy/docker/docker-compose.yml @@ -327,6 +327,8 @@ services: - KOKORO_MODEL=${KOKORO_MODEL:-kokoro-82m} - KOKORO_VOICE=${KOKORO_VOICE:-zf_xiaoxiao} - DEVICE=${VOICE_DEVICE:-cpu} + - OPENAI_API_KEY=${OPENAI_API_KEY} + - OPENAI_BASE_URL=${OPENAI_BASE_URL} healthcheck: test: ["CMD-SHELL", "python3 -c \"import urllib.request; urllib.request.urlopen('http://localhost:3008/docs')\""] interval: 30s diff --git a/it0_app/lib/features/agent_call/presentation/pages/voice_test_page.dart b/it0_app/lib/features/agent_call/presentation/pages/voice_test_page.dart index 07724a7..0c8412e 100644 --- a/it0_app/lib/features/agent_call/presentation/pages/voice_test_page.dart +++ b/it0_app/lib/features/agent_call/presentation/pages/voice_test_page.dart @@ -10,8 +10,10 @@ import 'package:path/path.dart' as p; import 'package:permission_handler/permission_handler.dart'; import '../../../../core/network/dio_client.dart'; +enum _VoiceProvider { local, openai } + /// Temporary voice I/O test page — TTS + STT + Round-trip. -/// Uses flutter_sound for both recording and playback. +/// Supports switching between Local (Kokoro/faster-whisper) and OpenAI APIs. class VoiceTestPage extends ConsumerStatefulWidget { const VoiceTestPage({super.key}); @@ -28,6 +30,8 @@ class _VoiceTestPageState extends ConsumerState { bool _playerReady = false; bool _recorderReady = false; + _VoiceProvider _provider = _VoiceProvider.local; + String _ttsStatus = ''; String _sttStatus = ''; String _sttResult = ''; @@ -38,13 +42,22 @@ class _VoiceTestPageState extends ConsumerState { String _recordMode = ''; // 'stt' or 'rt' String? _recordingPath; + // Endpoint paths based on provider + String get _ttsEndpoint => _provider == _VoiceProvider.openai + ? 
'/api/v1/test/tts/synthesize-openai' + : '/api/v1/test/tts/synthesize'; + + String get _sttEndpoint => _provider == _VoiceProvider.openai + ? '/api/v1/test/stt/transcribe-openai' + : '/api/v1/test/stt/transcribe'; + Dio get _dioBinary { final base = ref.read(dioClientProvider); return Dio(BaseOptions( baseUrl: base.options.baseUrl, headers: Map.from(base.options.headers), connectTimeout: const Duration(seconds: 30), - receiveTimeout: const Duration(seconds: 60), + receiveTimeout: const Duration(seconds: 120), responseType: ResponseType.bytes, ))..interceptors.addAll(base.interceptors); } @@ -77,26 +90,27 @@ class _VoiceTestPageState extends ConsumerState { Future _doTTS() async { final text = _ttsController.text.trim(); if (text.isEmpty) return; + final label = _provider == _VoiceProvider.openai ? 'OpenAI' : 'Local'; setState(() { _isSynthesizing = true; - _ttsStatus = '合成中...'; + _ttsStatus = '[$label] 合成中...'; }); final sw = Stopwatch()..start(); try { final resp = await _dioBinary.get( - '/api/v1/test/tts/synthesize', + _ttsEndpoint, queryParameters: {'text': text}, ); sw.stop(); final bytes = resp.data as List; setState(() { _ttsStatus = - '完成!耗时 ${sw.elapsedMilliseconds}ms,大小 ${(bytes.length / 1024).toStringAsFixed(1)}KB'; + '[$label] 完成!耗时 ${sw.elapsedMilliseconds}ms,大小 ${(bytes.length / 1024).toStringAsFixed(1)}KB'; }); await _playWav(Uint8List.fromList(bytes)); } catch (e) { sw.stop(); - setState(() => _ttsStatus = '错误: $e'); + setState(() => _ttsStatus = '[$label] 错误: $e'); } finally { setState(() => _isSynthesizing = false); } @@ -161,26 +175,26 @@ class _VoiceTestPageState extends ConsumerState { // ---- STT ---- Future _doSTT(String audioPath) async { - setState(() => _sttStatus = '识别中...'); + final label = _provider == _VoiceProvider.openai ? 
'OpenAI' : 'Local'; + setState(() => _sttStatus = '[$label] 识别中...'); final sw = Stopwatch()..start(); try { final formData = FormData.fromMap({ 'audio': await MultipartFile.fromFile(audioPath, filename: 'recording.wav'), }); - final resp = - await _dioJson.post('/api/v1/test/stt/transcribe', data: formData); + final resp = await _dioJson.post(_sttEndpoint, data: formData); sw.stop(); final data = resp.data as Map; setState(() { _sttResult = data['text'] ?? '(empty)'; - _sttStatus = - '完成!耗时 ${sw.elapsedMilliseconds}ms,时长 ${data['duration'] ?? 0}s'; + final extra = data['duration'] != null ? ',时长 ${data['duration']}s' : ''; + _sttStatus = '[$label] 完成!耗时 ${sw.elapsedMilliseconds}ms$extra'; }); } catch (e) { sw.stop(); setState(() { - _sttStatus = '错误: $e'; + _sttStatus = '[$label] 错误: $e'; _sttResult = ''; }); } finally { @@ -190,7 +204,8 @@ class _VoiceTestPageState extends ConsumerState { // ---- Round-trip: STT → TTS ---- Future _doRoundTrip(String audioPath) async { - setState(() => _rtStatus = 'STT 识别中...'); + final label = _provider == _VoiceProvider.openai ? 'OpenAI' : 'Local'; + setState(() => _rtStatus = '[$label] STT 识别中...'); final totalSw = Stopwatch()..start(); try { // 1. STT @@ -199,24 +214,23 @@ class _VoiceTestPageState extends ConsumerState { 'audio': await MultipartFile.fromFile(audioPath, filename: 'recording.wav'), }); - final sttResp = - await _dioJson.post('/api/v1/test/stt/transcribe', data: formData); + final sttResp = await _dioJson.post(_sttEndpoint, data: formData); sttSw.stop(); final sttData = sttResp.data as Map; final text = sttData['text'] ?? ''; setState(() { - _rtResult = 'STT (${sttSw.elapsedMilliseconds}ms): $text'; - _rtStatus = 'TTS 合成中...'; + _rtResult = '[$label] STT (${sttSw.elapsedMilliseconds}ms): $text'; + _rtStatus = '[$label] TTS 合成中...'; }); if (text.isEmpty) { - setState(() => _rtStatus = 'STT 识别为空'); + setState(() => _rtStatus = '[$label] STT 识别为空'); return; } // 2. 
TTS final ttsSw = Stopwatch()..start(); final ttsResp = await _dioBinary.get( - '/api/v1/test/tts/synthesize', + _ttsEndpoint, queryParameters: {'text': text}, ); ttsSw.stop(); @@ -224,14 +238,14 @@ class _VoiceTestPageState extends ConsumerState { final audioBytes = ttsResp.data as List; setState(() { _rtResult += - '\nTTS (${ttsSw.elapsedMilliseconds}ms): ${(audioBytes.length / 1024).toStringAsFixed(1)}KB'; + '\n[$label] TTS (${ttsSw.elapsedMilliseconds}ms): ${(audioBytes.length / 1024).toStringAsFixed(1)}KB'; _rtStatus = - '完成!STT=${sttSw.elapsedMilliseconds}ms + TTS=${ttsSw.elapsedMilliseconds}ms = ${totalSw.elapsedMilliseconds}ms'; + '[$label] STT=${sttSw.elapsedMilliseconds}ms + TTS=${ttsSw.elapsedMilliseconds}ms = ${totalSw.elapsedMilliseconds}ms'; }); await _playWav(Uint8List.fromList(audioBytes)); } catch (e) { totalSw.stop(); - setState(() => _rtStatus = '错误: $e'); + setState(() => _rtStatus = '[$label] 错误: $e'); } finally { _cleanupFile(audioPath); } @@ -259,6 +273,39 @@ class _VoiceTestPageState extends ConsumerState { body: ListView( padding: const EdgeInsets.all(16), children: [ + // Provider toggle + Center( + child: SegmentedButton<_VoiceProvider>( + segments: const [ + ButtonSegment( + value: _VoiceProvider.local, + label: Text('Local'), + icon: Icon(Icons.computer, size: 18), + ), + ButtonSegment( + value: _VoiceProvider.openai, + label: Text('OpenAI'), + icon: Icon(Icons.cloud, size: 18), + ), + ], + selected: {_provider}, + onSelectionChanged: (Set<_VoiceProvider> sel) { + setState(() => _provider = sel.first); + }, + ), + ), + Padding( + padding: const EdgeInsets.only(top: 4, bottom: 12), + child: Center( + child: Text( + _provider == _VoiceProvider.openai + ? 
'STT: gpt-4o-transcribe | TTS: tts-1' + : 'STT: faster-whisper | TTS: Kokoro', + style: TextStyle(fontSize: 12, color: Colors.grey[500]), + ), + ), + ), + // TTS Section _buildSection( title: 'TTS (文本转语音)', child: Column( @@ -287,12 +334,14 @@ class _VoiceTestPageState extends ConsumerState { Padding( padding: const EdgeInsets.only(top: 8), child: Text(_ttsStatus, - style: TextStyle(color: Colors.grey[600], fontSize: 13)), + style: + TextStyle(color: Colors.grey[600], fontSize: 13)), ), ], ), ), const SizedBox(height: 16), + // STT Section _buildSection( title: 'STT (语音转文本)', child: Column( @@ -304,8 +353,9 @@ class _VoiceTestPageState extends ConsumerState { child: ElevatedButton.icon( onPressed: () {}, style: ElevatedButton.styleFrom( - backgroundColor: - _isRecording && _recordMode == 'stt' ? Colors.red : null, + backgroundColor: _isRecording && _recordMode == 'stt' + ? Colors.red + : null, ), icon: Icon(_isRecording && _recordMode == 'stt' ? Icons.mic @@ -319,7 +369,8 @@ class _VoiceTestPageState extends ConsumerState { Padding( padding: const EdgeInsets.only(top: 8), child: Text(_sttStatus, - style: TextStyle(color: Colors.grey[600], fontSize: 13)), + style: + TextStyle(color: Colors.grey[600], fontSize: 13)), ), if (_sttResult.isNotEmpty) Container( @@ -330,12 +381,14 @@ class _VoiceTestPageState extends ConsumerState { color: Colors.grey[100], borderRadius: BorderRadius.circular(8), ), - child: Text(_sttResult, style: const TextStyle(fontSize: 16)), + child: + Text(_sttResult, style: const TextStyle(fontSize: 16)), ), ], ), ), const SizedBox(height: 16), + // Round-trip Section _buildSection( title: 'Round-trip (STT + TTS)', subtitle: '录音 → 识别文本 → 合成语音播放', @@ -348,8 +401,9 @@ class _VoiceTestPageState extends ConsumerState { child: ElevatedButton.icon( onPressed: () {}, style: ElevatedButton.styleFrom( - backgroundColor: - _isRecording && _recordMode == 'rt' ? Colors.red : null, + backgroundColor: _isRecording && _recordMode == 'rt' + ? 
Colors.red + : null, ), icon: Icon(_isRecording && _recordMode == 'rt' ? Icons.mic @@ -363,7 +417,8 @@ class _VoiceTestPageState extends ConsumerState { Padding( padding: const EdgeInsets.only(top: 8), child: Text(_rtStatus, - style: TextStyle(color: Colors.grey[600], fontSize: 13)), + style: + TextStyle(color: Colors.grey[600], fontSize: 13)), ), if (_rtResult.isNotEmpty) Container( @@ -374,7 +429,8 @@ class _VoiceTestPageState extends ConsumerState { color: Colors.grey[100], borderRadius: BorderRadius.circular(8), ), - child: Text(_rtResult, style: const TextStyle(fontSize: 14)), + child: + Text(_rtResult, style: const TextStyle(fontSize: 14)), ), ], ), diff --git a/packages/services/voice-service/requirements.txt b/packages/services/voice-service/requirements.txt index 3822509..2bba1c9 100644 --- a/packages/services/voice-service/requirements.txt +++ b/packages/services/voice-service/requirements.txt @@ -6,6 +6,7 @@ misaki==0.7.17 silero-vad==5.1 twilio==9.0.0 anthropic>=0.32.0 +openai>=1.30.0 websockets==12.0 pydantic==2.6.0 pydantic-settings==2.2.0 diff --git a/packages/services/voice-service/src/api/test_tts.py b/packages/services/voice-service/src/api/test_tts.py index 3a105e1..36a2757 100644 --- a/packages/services/voice-service/src/api/test_tts.py +++ b/packages/services/voice-service/src/api/test_tts.py @@ -2,7 +2,9 @@ import asyncio import io +import os import struct +import tempfile import numpy as np from fastapi import APIRouter, Request, Query, UploadFile, File from fastapi.responses import HTMLResponse, Response @@ -245,9 +247,6 @@ async def stt_transcribe(request: Request, audio: UploadFile = File(...)): if stt is None or stt._model is None: return {"error": "STT model not loaded", "text": ""} - import tempfile - import os - # Save uploaded file to temp raw = await audio.read() suffix = os.path.splitext(audio.filename or "audio.webm")[1] or ".webm" @@ -276,3 +275,83 @@ async def stt_transcribe(request: Request, audio: UploadFile = File(...)): } 
     finally:
         os.unlink(tmp_path)
+
+
+# =====================================================================
+# OpenAI Voice API endpoints
+# =====================================================================
+
+def _get_openai_client():
+    """Lazy-init OpenAI client with proxy support."""
+    from openai import OpenAI
+    api_key = os.environ.get("OPENAI_API_KEY")
+    base_url = os.environ.get("OPENAI_BASE_URL")
+    if not api_key:
+        return None
+    kwargs = {"api_key": api_key}
+    if base_url:
+        kwargs["base_url"] = base_url
+    return OpenAI(**kwargs)
+
+
+@router.get("/tts/synthesize-openai")
+async def tts_synthesize_openai(
+    text: str = Query(..., min_length=1, max_length=500),
+    model: str = Query("tts-1", pattern="^(tts-1|tts-1-hd|gpt-4o-mini-tts)$"),
+    voice: str = Query("alloy", pattern="^(alloy|ash|ballad|coral|echo|fable|nova|onyx|sage|shimmer)$"),
+):
+    """Synthesize text to audio via OpenAI TTS API."""
+    client = _get_openai_client()
+    if client is None:
+        return Response(content="OPENAI_API_KEY not configured", status_code=503)
+
+    loop = asyncio.get_running_loop()
+    def _synth():
+        response = client.audio.speech.create(
+            model=model,
+            voice=voice,
+            input=text,
+            response_format="wav",
+        )
+        return response.content
+
+    try:
+        wav_bytes = await loop.run_in_executor(None, _synth)
+        return Response(content=wav_bytes, media_type="audio/wav")
+    except Exception as e:
+        return Response(content=f"OpenAI TTS error: {e}", status_code=500)
+
+
+@router.post("/stt/transcribe-openai")
+async def stt_transcribe_openai(
+    audio: UploadFile = File(...),
+    model: str = Query("gpt-4o-transcribe", pattern="^(whisper-1|gpt-4o-transcribe|gpt-4o-mini-transcribe)$"),
+):
+    """Transcribe uploaded audio via OpenAI STT API."""
+    client = _get_openai_client()
+    if client is None:
+        return {"error": "OPENAI_API_KEY not configured", "text": ""}
+
+    raw = await audio.read()
+    suffix = os.path.splitext(audio.filename or "audio.wav")[1] or ".wav"
+    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as f:
+        f.write(raw)
+        tmp_path = f.name
+
+    try:
+        loop = asyncio.get_running_loop()
+        def _transcribe():
+            with open(tmp_path, "rb") as af:
+                result = client.audio.transcriptions.create(
+                    model=model,
+                    file=af,
+                    language="zh",
+                )
+            return result.text
+
+        text = await loop.run_in_executor(None, _transcribe)
+        return {"text": text, "language": "zh", "model": model}
+    except Exception as e:
+        return {"error": f"OpenAI STT error: {e}", "text": ""}
+    finally:
+        os.unlink(tmp_path)