feat: add OpenAI TTS/STT API endpoints for comparison testing
- Add openai package to voice-service requirements - Add /api/v1/test/tts/synthesize-openai (tts-1/tts-1-hd/gpt-4o-mini-tts) - Add /api/v1/test/stt/transcribe-openai (gpt-4o-transcribe/whisper-1) - Add OPENAI_API_KEY and OPENAI_BASE_URL env vars to voice-service - Flutter test page: SegmentedButton to toggle Local/OpenAI provider - All endpoints maintain same response format for easy comparison Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
ac0b8ee1c6
commit
d43baed3a5
|
|
@ -327,6 +327,8 @@ services:
|
||||||
- KOKORO_MODEL=${KOKORO_MODEL:-kokoro-82m}
|
- KOKORO_MODEL=${KOKORO_MODEL:-kokoro-82m}
|
||||||
- KOKORO_VOICE=${KOKORO_VOICE:-zf_xiaoxiao}
|
- KOKORO_VOICE=${KOKORO_VOICE:-zf_xiaoxiao}
|
||||||
- DEVICE=${VOICE_DEVICE:-cpu}
|
- DEVICE=${VOICE_DEVICE:-cpu}
|
||||||
|
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
||||||
|
- OPENAI_BASE_URL=${OPENAI_BASE_URL}
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD-SHELL", "python3 -c \"import urllib.request; urllib.request.urlopen('http://localhost:3008/docs')\""]
|
test: ["CMD-SHELL", "python3 -c \"import urllib.request; urllib.request.urlopen('http://localhost:3008/docs')\""]
|
||||||
interval: 30s
|
interval: 30s
|
||||||
|
|
|
||||||
|
|
@ -10,8 +10,10 @@ import 'package:path/path.dart' as p;
|
||||||
import 'package:permission_handler/permission_handler.dart';
|
import 'package:permission_handler/permission_handler.dart';
|
||||||
import '../../../../core/network/dio_client.dart';
|
import '../../../../core/network/dio_client.dart';
|
||||||
|
|
||||||
|
enum _VoiceProvider { local, openai }
|
||||||
|
|
||||||
/// Temporary voice I/O test page — TTS + STT + Round-trip.
|
/// Temporary voice I/O test page — TTS + STT + Round-trip.
|
||||||
/// Uses flutter_sound for both recording and playback.
|
/// Supports switching between Local (Kokoro/faster-whisper) and OpenAI APIs.
|
||||||
class VoiceTestPage extends ConsumerStatefulWidget {
|
class VoiceTestPage extends ConsumerStatefulWidget {
|
||||||
const VoiceTestPage({super.key});
|
const VoiceTestPage({super.key});
|
||||||
|
|
||||||
|
|
@ -28,6 +30,8 @@ class _VoiceTestPageState extends ConsumerState<VoiceTestPage> {
|
||||||
bool _playerReady = false;
|
bool _playerReady = false;
|
||||||
bool _recorderReady = false;
|
bool _recorderReady = false;
|
||||||
|
|
||||||
|
_VoiceProvider _provider = _VoiceProvider.local;
|
||||||
|
|
||||||
String _ttsStatus = '';
|
String _ttsStatus = '';
|
||||||
String _sttStatus = '';
|
String _sttStatus = '';
|
||||||
String _sttResult = '';
|
String _sttResult = '';
|
||||||
|
|
@ -38,13 +42,22 @@ class _VoiceTestPageState extends ConsumerState<VoiceTestPage> {
|
||||||
String _recordMode = ''; // 'stt' or 'rt'
|
String _recordMode = ''; // 'stt' or 'rt'
|
||||||
String? _recordingPath;
|
String? _recordingPath;
|
||||||
|
|
||||||
|
// Endpoint paths based on provider
|
||||||
|
String get _ttsEndpoint => _provider == _VoiceProvider.openai
|
||||||
|
? '/api/v1/test/tts/synthesize-openai'
|
||||||
|
: '/api/v1/test/tts/synthesize';
|
||||||
|
|
||||||
|
String get _sttEndpoint => _provider == _VoiceProvider.openai
|
||||||
|
? '/api/v1/test/stt/transcribe-openai'
|
||||||
|
: '/api/v1/test/stt/transcribe';
|
||||||
|
|
||||||
Dio get _dioBinary {
|
Dio get _dioBinary {
|
||||||
final base = ref.read(dioClientProvider);
|
final base = ref.read(dioClientProvider);
|
||||||
return Dio(BaseOptions(
|
return Dio(BaseOptions(
|
||||||
baseUrl: base.options.baseUrl,
|
baseUrl: base.options.baseUrl,
|
||||||
headers: Map.from(base.options.headers),
|
headers: Map.from(base.options.headers),
|
||||||
connectTimeout: const Duration(seconds: 30),
|
connectTimeout: const Duration(seconds: 30),
|
||||||
receiveTimeout: const Duration(seconds: 60),
|
receiveTimeout: const Duration(seconds: 120),
|
||||||
responseType: ResponseType.bytes,
|
responseType: ResponseType.bytes,
|
||||||
))..interceptors.addAll(base.interceptors);
|
))..interceptors.addAll(base.interceptors);
|
||||||
}
|
}
|
||||||
|
|
@ -77,26 +90,27 @@ class _VoiceTestPageState extends ConsumerState<VoiceTestPage> {
|
||||||
Future<void> _doTTS() async {
|
Future<void> _doTTS() async {
|
||||||
final text = _ttsController.text.trim();
|
final text = _ttsController.text.trim();
|
||||||
if (text.isEmpty) return;
|
if (text.isEmpty) return;
|
||||||
|
final label = _provider == _VoiceProvider.openai ? 'OpenAI' : 'Local';
|
||||||
setState(() {
|
setState(() {
|
||||||
_isSynthesizing = true;
|
_isSynthesizing = true;
|
||||||
_ttsStatus = '合成中...';
|
_ttsStatus = '[$label] 合成中...';
|
||||||
});
|
});
|
||||||
final sw = Stopwatch()..start();
|
final sw = Stopwatch()..start();
|
||||||
try {
|
try {
|
||||||
final resp = await _dioBinary.get(
|
final resp = await _dioBinary.get(
|
||||||
'/api/v1/test/tts/synthesize',
|
_ttsEndpoint,
|
||||||
queryParameters: {'text': text},
|
queryParameters: {'text': text},
|
||||||
);
|
);
|
||||||
sw.stop();
|
sw.stop();
|
||||||
final bytes = resp.data as List<int>;
|
final bytes = resp.data as List<int>;
|
||||||
setState(() {
|
setState(() {
|
||||||
_ttsStatus =
|
_ttsStatus =
|
||||||
'完成!耗时 ${sw.elapsedMilliseconds}ms,大小 ${(bytes.length / 1024).toStringAsFixed(1)}KB';
|
'[$label] 完成!耗时 ${sw.elapsedMilliseconds}ms,大小 ${(bytes.length / 1024).toStringAsFixed(1)}KB';
|
||||||
});
|
});
|
||||||
await _playWav(Uint8List.fromList(bytes));
|
await _playWav(Uint8List.fromList(bytes));
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
sw.stop();
|
sw.stop();
|
||||||
setState(() => _ttsStatus = '错误: $e');
|
setState(() => _ttsStatus = '[$label] 错误: $e');
|
||||||
} finally {
|
} finally {
|
||||||
setState(() => _isSynthesizing = false);
|
setState(() => _isSynthesizing = false);
|
||||||
}
|
}
|
||||||
|
|
@ -161,26 +175,26 @@ class _VoiceTestPageState extends ConsumerState<VoiceTestPage> {
|
||||||
|
|
||||||
// ---- STT ----
|
// ---- STT ----
|
||||||
Future<void> _doSTT(String audioPath) async {
|
Future<void> _doSTT(String audioPath) async {
|
||||||
setState(() => _sttStatus = '识别中...');
|
final label = _provider == _VoiceProvider.openai ? 'OpenAI' : 'Local';
|
||||||
|
setState(() => _sttStatus = '[$label] 识别中...');
|
||||||
final sw = Stopwatch()..start();
|
final sw = Stopwatch()..start();
|
||||||
try {
|
try {
|
||||||
final formData = FormData.fromMap({
|
final formData = FormData.fromMap({
|
||||||
'audio':
|
'audio':
|
||||||
await MultipartFile.fromFile(audioPath, filename: 'recording.wav'),
|
await MultipartFile.fromFile(audioPath, filename: 'recording.wav'),
|
||||||
});
|
});
|
||||||
final resp =
|
final resp = await _dioJson.post(_sttEndpoint, data: formData);
|
||||||
await _dioJson.post('/api/v1/test/stt/transcribe', data: formData);
|
|
||||||
sw.stop();
|
sw.stop();
|
||||||
final data = resp.data as Map<String, dynamic>;
|
final data = resp.data as Map<String, dynamic>;
|
||||||
setState(() {
|
setState(() {
|
||||||
_sttResult = data['text'] ?? '(empty)';
|
_sttResult = data['text'] ?? '(empty)';
|
||||||
_sttStatus =
|
final extra = data['duration'] != null ? ',时长 ${data['duration']}s' : '';
|
||||||
'完成!耗时 ${sw.elapsedMilliseconds}ms,时长 ${data['duration'] ?? 0}s';
|
_sttStatus = '[$label] 完成!耗时 ${sw.elapsedMilliseconds}ms$extra';
|
||||||
});
|
});
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
sw.stop();
|
sw.stop();
|
||||||
setState(() {
|
setState(() {
|
||||||
_sttStatus = '错误: $e';
|
_sttStatus = '[$label] 错误: $e';
|
||||||
_sttResult = '';
|
_sttResult = '';
|
||||||
});
|
});
|
||||||
} finally {
|
} finally {
|
||||||
|
|
@ -190,7 +204,8 @@ class _VoiceTestPageState extends ConsumerState<VoiceTestPage> {
|
||||||
|
|
||||||
// ---- Round-trip: STT → TTS ----
|
// ---- Round-trip: STT → TTS ----
|
||||||
Future<void> _doRoundTrip(String audioPath) async {
|
Future<void> _doRoundTrip(String audioPath) async {
|
||||||
setState(() => _rtStatus = 'STT 识别中...');
|
final label = _provider == _VoiceProvider.openai ? 'OpenAI' : 'Local';
|
||||||
|
setState(() => _rtStatus = '[$label] STT 识别中...');
|
||||||
final totalSw = Stopwatch()..start();
|
final totalSw = Stopwatch()..start();
|
||||||
try {
|
try {
|
||||||
// 1. STT
|
// 1. STT
|
||||||
|
|
@ -199,24 +214,23 @@ class _VoiceTestPageState extends ConsumerState<VoiceTestPage> {
|
||||||
'audio':
|
'audio':
|
||||||
await MultipartFile.fromFile(audioPath, filename: 'recording.wav'),
|
await MultipartFile.fromFile(audioPath, filename: 'recording.wav'),
|
||||||
});
|
});
|
||||||
final sttResp =
|
final sttResp = await _dioJson.post(_sttEndpoint, data: formData);
|
||||||
await _dioJson.post('/api/v1/test/stt/transcribe', data: formData);
|
|
||||||
sttSw.stop();
|
sttSw.stop();
|
||||||
final sttData = sttResp.data as Map<String, dynamic>;
|
final sttData = sttResp.data as Map<String, dynamic>;
|
||||||
final text = sttData['text'] ?? '';
|
final text = sttData['text'] ?? '';
|
||||||
setState(() {
|
setState(() {
|
||||||
_rtResult = 'STT (${sttSw.elapsedMilliseconds}ms): $text';
|
_rtResult = '[$label] STT (${sttSw.elapsedMilliseconds}ms): $text';
|
||||||
_rtStatus = 'TTS 合成中...';
|
_rtStatus = '[$label] TTS 合成中...';
|
||||||
});
|
});
|
||||||
if (text.isEmpty) {
|
if (text.isEmpty) {
|
||||||
setState(() => _rtStatus = 'STT 识别为空');
|
setState(() => _rtStatus = '[$label] STT 识别为空');
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. TTS
|
// 2. TTS
|
||||||
final ttsSw = Stopwatch()..start();
|
final ttsSw = Stopwatch()..start();
|
||||||
final ttsResp = await _dioBinary.get(
|
final ttsResp = await _dioBinary.get(
|
||||||
'/api/v1/test/tts/synthesize',
|
_ttsEndpoint,
|
||||||
queryParameters: {'text': text},
|
queryParameters: {'text': text},
|
||||||
);
|
);
|
||||||
ttsSw.stop();
|
ttsSw.stop();
|
||||||
|
|
@ -224,14 +238,14 @@ class _VoiceTestPageState extends ConsumerState<VoiceTestPage> {
|
||||||
final audioBytes = ttsResp.data as List<int>;
|
final audioBytes = ttsResp.data as List<int>;
|
||||||
setState(() {
|
setState(() {
|
||||||
_rtResult +=
|
_rtResult +=
|
||||||
'\nTTS (${ttsSw.elapsedMilliseconds}ms): ${(audioBytes.length / 1024).toStringAsFixed(1)}KB';
|
'\n[$label] TTS (${ttsSw.elapsedMilliseconds}ms): ${(audioBytes.length / 1024).toStringAsFixed(1)}KB';
|
||||||
_rtStatus =
|
_rtStatus =
|
||||||
'完成!STT=${sttSw.elapsedMilliseconds}ms + TTS=${ttsSw.elapsedMilliseconds}ms = ${totalSw.elapsedMilliseconds}ms';
|
'[$label] STT=${sttSw.elapsedMilliseconds}ms + TTS=${ttsSw.elapsedMilliseconds}ms = ${totalSw.elapsedMilliseconds}ms';
|
||||||
});
|
});
|
||||||
await _playWav(Uint8List.fromList(audioBytes));
|
await _playWav(Uint8List.fromList(audioBytes));
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
totalSw.stop();
|
totalSw.stop();
|
||||||
setState(() => _rtStatus = '错误: $e');
|
setState(() => _rtStatus = '[$label] 错误: $e');
|
||||||
} finally {
|
} finally {
|
||||||
_cleanupFile(audioPath);
|
_cleanupFile(audioPath);
|
||||||
}
|
}
|
||||||
|
|
@ -259,6 +273,39 @@ class _VoiceTestPageState extends ConsumerState<VoiceTestPage> {
|
||||||
body: ListView(
|
body: ListView(
|
||||||
padding: const EdgeInsets.all(16),
|
padding: const EdgeInsets.all(16),
|
||||||
children: [
|
children: [
|
||||||
|
// Provider toggle
|
||||||
|
Center(
|
||||||
|
child: SegmentedButton<_VoiceProvider>(
|
||||||
|
segments: const [
|
||||||
|
ButtonSegment(
|
||||||
|
value: _VoiceProvider.local,
|
||||||
|
label: Text('Local'),
|
||||||
|
icon: Icon(Icons.computer, size: 18),
|
||||||
|
),
|
||||||
|
ButtonSegment(
|
||||||
|
value: _VoiceProvider.openai,
|
||||||
|
label: Text('OpenAI'),
|
||||||
|
icon: Icon(Icons.cloud, size: 18),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
selected: {_provider},
|
||||||
|
onSelectionChanged: (Set<_VoiceProvider> sel) {
|
||||||
|
setState(() => _provider = sel.first);
|
||||||
|
},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
Padding(
|
||||||
|
padding: const EdgeInsets.only(top: 4, bottom: 12),
|
||||||
|
child: Center(
|
||||||
|
child: Text(
|
||||||
|
_provider == _VoiceProvider.openai
|
||||||
|
? 'STT: gpt-4o-transcribe | TTS: tts-1'
|
||||||
|
: 'STT: faster-whisper | TTS: Kokoro',
|
||||||
|
style: TextStyle(fontSize: 12, color: Colors.grey[500]),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
// TTS Section
|
||||||
_buildSection(
|
_buildSection(
|
||||||
title: 'TTS (文本转语音)',
|
title: 'TTS (文本转语音)',
|
||||||
child: Column(
|
child: Column(
|
||||||
|
|
@ -287,12 +334,14 @@ class _VoiceTestPageState extends ConsumerState<VoiceTestPage> {
|
||||||
Padding(
|
Padding(
|
||||||
padding: const EdgeInsets.only(top: 8),
|
padding: const EdgeInsets.only(top: 8),
|
||||||
child: Text(_ttsStatus,
|
child: Text(_ttsStatus,
|
||||||
style: TextStyle(color: Colors.grey[600], fontSize: 13)),
|
style:
|
||||||
|
TextStyle(color: Colors.grey[600], fontSize: 13)),
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
const SizedBox(height: 16),
|
const SizedBox(height: 16),
|
||||||
|
// STT Section
|
||||||
_buildSection(
|
_buildSection(
|
||||||
title: 'STT (语音转文本)',
|
title: 'STT (语音转文本)',
|
||||||
child: Column(
|
child: Column(
|
||||||
|
|
@ -304,8 +353,9 @@ class _VoiceTestPageState extends ConsumerState<VoiceTestPage> {
|
||||||
child: ElevatedButton.icon(
|
child: ElevatedButton.icon(
|
||||||
onPressed: () {},
|
onPressed: () {},
|
||||||
style: ElevatedButton.styleFrom(
|
style: ElevatedButton.styleFrom(
|
||||||
backgroundColor:
|
backgroundColor: _isRecording && _recordMode == 'stt'
|
||||||
_isRecording && _recordMode == 'stt' ? Colors.red : null,
|
? Colors.red
|
||||||
|
: null,
|
||||||
),
|
),
|
||||||
icon: Icon(_isRecording && _recordMode == 'stt'
|
icon: Icon(_isRecording && _recordMode == 'stt'
|
||||||
? Icons.mic
|
? Icons.mic
|
||||||
|
|
@ -319,7 +369,8 @@ class _VoiceTestPageState extends ConsumerState<VoiceTestPage> {
|
||||||
Padding(
|
Padding(
|
||||||
padding: const EdgeInsets.only(top: 8),
|
padding: const EdgeInsets.only(top: 8),
|
||||||
child: Text(_sttStatus,
|
child: Text(_sttStatus,
|
||||||
style: TextStyle(color: Colors.grey[600], fontSize: 13)),
|
style:
|
||||||
|
TextStyle(color: Colors.grey[600], fontSize: 13)),
|
||||||
),
|
),
|
||||||
if (_sttResult.isNotEmpty)
|
if (_sttResult.isNotEmpty)
|
||||||
Container(
|
Container(
|
||||||
|
|
@ -330,12 +381,14 @@ class _VoiceTestPageState extends ConsumerState<VoiceTestPage> {
|
||||||
color: Colors.grey[100],
|
color: Colors.grey[100],
|
||||||
borderRadius: BorderRadius.circular(8),
|
borderRadius: BorderRadius.circular(8),
|
||||||
),
|
),
|
||||||
child: Text(_sttResult, style: const TextStyle(fontSize: 16)),
|
child:
|
||||||
|
Text(_sttResult, style: const TextStyle(fontSize: 16)),
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
const SizedBox(height: 16),
|
const SizedBox(height: 16),
|
||||||
|
// Round-trip Section
|
||||||
_buildSection(
|
_buildSection(
|
||||||
title: 'Round-trip (STT + TTS)',
|
title: 'Round-trip (STT + TTS)',
|
||||||
subtitle: '录音 → 识别文本 → 合成语音播放',
|
subtitle: '录音 → 识别文本 → 合成语音播放',
|
||||||
|
|
@ -348,8 +401,9 @@ class _VoiceTestPageState extends ConsumerState<VoiceTestPage> {
|
||||||
child: ElevatedButton.icon(
|
child: ElevatedButton.icon(
|
||||||
onPressed: () {},
|
onPressed: () {},
|
||||||
style: ElevatedButton.styleFrom(
|
style: ElevatedButton.styleFrom(
|
||||||
backgroundColor:
|
backgroundColor: _isRecording && _recordMode == 'rt'
|
||||||
_isRecording && _recordMode == 'rt' ? Colors.red : null,
|
? Colors.red
|
||||||
|
: null,
|
||||||
),
|
),
|
||||||
icon: Icon(_isRecording && _recordMode == 'rt'
|
icon: Icon(_isRecording && _recordMode == 'rt'
|
||||||
? Icons.mic
|
? Icons.mic
|
||||||
|
|
@ -363,7 +417,8 @@ class _VoiceTestPageState extends ConsumerState<VoiceTestPage> {
|
||||||
Padding(
|
Padding(
|
||||||
padding: const EdgeInsets.only(top: 8),
|
padding: const EdgeInsets.only(top: 8),
|
||||||
child: Text(_rtStatus,
|
child: Text(_rtStatus,
|
||||||
style: TextStyle(color: Colors.grey[600], fontSize: 13)),
|
style:
|
||||||
|
TextStyle(color: Colors.grey[600], fontSize: 13)),
|
||||||
),
|
),
|
||||||
if (_rtResult.isNotEmpty)
|
if (_rtResult.isNotEmpty)
|
||||||
Container(
|
Container(
|
||||||
|
|
@ -374,7 +429,8 @@ class _VoiceTestPageState extends ConsumerState<VoiceTestPage> {
|
||||||
color: Colors.grey[100],
|
color: Colors.grey[100],
|
||||||
borderRadius: BorderRadius.circular(8),
|
borderRadius: BorderRadius.circular(8),
|
||||||
),
|
),
|
||||||
child: Text(_rtResult, style: const TextStyle(fontSize: 14)),
|
child:
|
||||||
|
Text(_rtResult, style: const TextStyle(fontSize: 14)),
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ misaki==0.7.17
|
||||||
silero-vad==5.1
|
silero-vad==5.1
|
||||||
twilio==9.0.0
|
twilio==9.0.0
|
||||||
anthropic>=0.32.0
|
anthropic>=0.32.0
|
||||||
|
openai>=1.30.0
|
||||||
websockets==12.0
|
websockets==12.0
|
||||||
pydantic==2.6.0
|
pydantic==2.6.0
|
||||||
pydantic-settings==2.2.0
|
pydantic-settings==2.2.0
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,9 @@
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import io
|
import io
|
||||||
|
import os
|
||||||
import struct
|
import struct
|
||||||
|
import tempfile
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from fastapi import APIRouter, Request, Query, UploadFile, File
|
from fastapi import APIRouter, Request, Query, UploadFile, File
|
||||||
from fastapi.responses import HTMLResponse, Response
|
from fastapi.responses import HTMLResponse, Response
|
||||||
|
|
@ -245,9 +247,6 @@ async def stt_transcribe(request: Request, audio: UploadFile = File(...)):
|
||||||
if stt is None or stt._model is None:
|
if stt is None or stt._model is None:
|
||||||
return {"error": "STT model not loaded", "text": ""}
|
return {"error": "STT model not loaded", "text": ""}
|
||||||
|
|
||||||
import tempfile
|
|
||||||
import os
|
|
||||||
|
|
||||||
# Save uploaded file to temp
|
# Save uploaded file to temp
|
||||||
raw = await audio.read()
|
raw = await audio.read()
|
||||||
suffix = os.path.splitext(audio.filename or "audio.webm")[1] or ".webm"
|
suffix = os.path.splitext(audio.filename or "audio.webm")[1] or ".webm"
|
||||||
|
|
@ -276,3 +275,83 @@ async def stt_transcribe(request: Request, audio: UploadFile = File(...)):
|
||||||
}
|
}
|
||||||
finally:
|
finally:
|
||||||
os.unlink(tmp_path)
|
os.unlink(tmp_path)
|
||||||
|
|
||||||
|
|
||||||
|
# =====================================================================
|
||||||
|
# OpenAI Voice API endpoints
|
||||||
|
# =====================================================================
|
||||||
|
|
||||||
|
def _get_openai_client():
|
||||||
|
"""Lazy-init OpenAI client with proxy support."""
|
||||||
|
from openai import OpenAI
|
||||||
|
api_key = os.environ.get("OPENAI_API_KEY")
|
||||||
|
base_url = os.environ.get("OPENAI_BASE_URL")
|
||||||
|
if not api_key:
|
||||||
|
return None
|
||||||
|
kwargs = {"api_key": api_key}
|
||||||
|
if base_url:
|
||||||
|
kwargs["base_url"] = base_url
|
||||||
|
return OpenAI(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/tts/synthesize-openai")
|
||||||
|
async def tts_synthesize_openai(
|
||||||
|
text: str = Query(..., min_length=1, max_length=500),
|
||||||
|
model: str = Query("tts-1", regex="^(tts-1|tts-1-hd|gpt-4o-mini-tts)$"),
|
||||||
|
voice: str = Query("alloy", regex="^(alloy|ash|ballad|coral|echo|fable|nova|onyx|sage|shimmer)$"),
|
||||||
|
):
|
||||||
|
"""Synthesize text to audio via OpenAI TTS API."""
|
||||||
|
client = _get_openai_client()
|
||||||
|
if client is None:
|
||||||
|
return Response(content="OPENAI_API_KEY not configured", status_code=503)
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
def _synth():
|
||||||
|
response = client.audio.speech.create(
|
||||||
|
model=model,
|
||||||
|
voice=voice,
|
||||||
|
input=text,
|
||||||
|
response_format="wav",
|
||||||
|
)
|
||||||
|
return response.content
|
||||||
|
|
||||||
|
try:
|
||||||
|
wav_bytes = await loop.run_in_executor(None, _synth)
|
||||||
|
return Response(content=wav_bytes, media_type="audio/wav")
|
||||||
|
except Exception as e:
|
||||||
|
return Response(content=f"OpenAI TTS error: {e}", status_code=500)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/stt/transcribe-openai")
|
||||||
|
async def stt_transcribe_openai(
|
||||||
|
audio: UploadFile = File(...),
|
||||||
|
model: str = Query("gpt-4o-transcribe", regex="^(whisper-1|gpt-4o-transcribe|gpt-4o-mini-transcribe)$"),
|
||||||
|
):
|
||||||
|
"""Transcribe uploaded audio via OpenAI STT API."""
|
||||||
|
client = _get_openai_client()
|
||||||
|
if client is None:
|
||||||
|
return {"error": "OPENAI_API_KEY not configured", "text": ""}
|
||||||
|
|
||||||
|
raw = await audio.read()
|
||||||
|
suffix = os.path.splitext(audio.filename or "audio.wav")[1] or ".wav"
|
||||||
|
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as f:
|
||||||
|
f.write(raw)
|
||||||
|
tmp_path = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
def _transcribe():
|
||||||
|
with open(tmp_path, "rb") as af:
|
||||||
|
result = client.audio.transcriptions.create(
|
||||||
|
model=model,
|
||||||
|
file=af,
|
||||||
|
language="zh",
|
||||||
|
)
|
||||||
|
return result.text
|
||||||
|
|
||||||
|
text = await loop.run_in_executor(None, _transcribe)
|
||||||
|
return {"text": text, "language": "zh", "model": model}
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": f"OpenAI STT error: {e}", "text": ""}
|
||||||
|
finally:
|
||||||
|
os.unlink(tmp_path)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue