104 lines
3.6 KiB
Python
104 lines
3.6 KiB
Python
"""
Qwen3-ASR local GPU speech-recognition provider for xiaozhi-server.

Based on the structure of fun_local.py.
"""
|
||
|
||
import os
|
||
import time
|
||
import torch
|
||
import asyncio
|
||
import numpy as np
|
||
|
||
from config.logger import setup_logging
|
||
from typing import Optional, Tuple, List
|
||
from core.providers.asr.base import ASRProviderBase
|
||
from core.providers.asr.dto.dto import InterfaceType
|
||
|
||
# Retry policy used by ASRProvider.speech_to_text:
#   MAX_RETRIES - total number of transcription attempts made when an
#                 OSError is raised (the first attempt included).
#   RETRY_DELAY - pause between attempts, in seconds.
MAX_RETRIES = 2
RETRY_DELAY = 1

# Log tag identifying this module, and the logger from the project's
# setup_logging() helper.
TAG = __name__
logger = setup_logging()
|
||
|
||
|
||
class ASRProvider(ASRProviderBase):
    """Local GPU speech-to-text provider backed by a Qwen3-ASR model.

    The model is loaded once in ``__init__`` and reused by every call to
    ``speech_to_text``. Inference runs in a worker thread via
    ``asyncio.to_thread`` so the event loop is not blocked.
    """

    def __init__(self, config: dict, delete_audio_file: bool):
        """Load the Qwen3-ASR model described by ``config``.

        Args:
            config: Provider configuration. Recognised optional keys:
                ``output_dir`` (default ``"tmp/"``, created if missing),
                ``model_path`` (default ``"Qwen/Qwen3-ASR-1.7B"``),
                ``device`` (default ``"cuda:1"``),
                ``dtype`` (name of a ``torch`` dtype, default ``"bfloat16"``),
                ``max_new_tokens`` (default ``256``) and
                ``language`` (default ``None``, i.e. auto-detect).
            delete_audio_file: Stored on the instance; it is not read in this
                class (presumably consulted by the base class or caller).
        """
        super().__init__()
        self.interface_type = InterfaceType.LOCAL
        self.output_dir = config.get("output_dir", "tmp/")
        self.delete_audio_file = delete_audio_file
        # Language hint forwarded to transcribe(); None lets the model auto-detect.
        self.language = config.get("language")

        model_path = config.get("model_path", "Qwen/Qwen3-ASR-1.7B")
        # NOTE(review): the default targets the *second* GPU; single-GPU hosts
        # need to set ``device`` explicitly (e.g. "cuda:0").
        device = config.get("device", "cuda:1")
        dtype_str = config.get("dtype", "bfloat16")
        dtype = self._resolve_dtype(dtype_str)
        max_new_tokens = config.get("max_new_tokens", 256)

        os.makedirs(self.output_dir, exist_ok=True)

        logger.bind(tag=TAG).info(
            f"Qwen3ASR loading: model={model_path} device={device} dtype={dtype_str}"
        )
        t0 = time.time()

        # Deferred import: qwen_asr is only needed once this provider is constructed.
        from qwen_asr import Qwen3ASRModel

        self.model = Qwen3ASRModel.from_pretrained(
            model_path,
            dtype=dtype,
            device_map=device,
            max_new_tokens=max_new_tokens,
        )
        logger.bind(tag=TAG).info(f"Qwen3ASR loaded in {time.time()-t0:.1f}s")

    @staticmethod
    def _resolve_dtype(dtype_name) -> torch.dtype:
        """Map a dtype name (e.g. ``"float16"`` or ``"torch.float16"``) to a ``torch.dtype``.

        A bare ``getattr(torch, name)`` would also accept non-dtype attributes
        (``"cuda"`` yields the ``torch.cuda`` module), so the lookup result is
        type-checked. Unknown names fall back to ``torch.bfloat16`` with a warning.
        """
        if isinstance(dtype_name, torch.dtype):
            return dtype_name
        candidate = getattr(torch, str(dtype_name).removeprefix("torch."), None)
        if isinstance(candidate, torch.dtype):
            return candidate
        logger.bind(tag=TAG).warning(
            f"Qwen3ASR: unknown dtype {dtype_name!r}, falling back to bfloat16"
        )
        return torch.bfloat16

    async def speech_to_text(
        self, opus_data: List[bytes], session_id: str, audio_format="opus", artifacts=None
    ) -> Tuple[Optional[str], Optional[str]]:
        """Speech to text using the local Qwen3-ASR model.

        The audio is taken from ``artifacts.pcm_bytes``. ``opus_data``,
        ``session_id`` and ``audio_format`` are accepted for interface
        compatibility but are not read here.

        Returns:
            ``(text, file_path)``:
            - ``("", None)`` when ``artifacts`` is None or recognition fails.
            - ``("", artifacts.file_path)`` when there is no audio or the model
              returns no result.
            - ``(text, artifacts.file_path)`` on success.

        OSErrors are retried, up to MAX_RETRIES attempts in total (at least one),
        with a non-blocking RETRY_DELAY-second pause between attempts. Any other
        exception is logged and reported as a failure.
        """
        max_attempts = max(1, MAX_RETRIES)

        for attempt in range(1, max_attempts + 1):
            try:
                if artifacts is None:
                    return "", None

                pcm_bytes = artifacts.pcm_bytes
                if not pcm_bytes:
                    return "", artifacts.file_path

                # PCM bytes -> numpy float32 in [-1, 1); the input is assumed to be
                # 16 kHz, 16-bit, mono. Read only whole int16 samples so a stray
                # trailing byte does not make np.frombuffer reject the buffer.
                sample_count = len(pcm_bytes) // 2
                if sample_count == 0:
                    return "", artifacts.file_path
                audio_np = (
                    np.frombuffer(pcm_bytes, dtype=np.int16, count=sample_count)
                    .astype(np.float32)
                    / 32768.0
                )

                # Run inference in a worker thread so the event loop is not blocked.
                start_time = time.perf_counter()
                results = await asyncio.to_thread(
                    self.model.transcribe,
                    audio=(audio_np, 16000),
                    language=self.language,
                )

                if not results:
                    return "", artifacts.file_path

                first = results[0]
                text = first.text
                lang = getattr(first, "language", "unknown")
                elapsed = time.perf_counter() - start_time
                logger.bind(tag=TAG).info(
                    f"语音识别耗时: {elapsed:.3f}s | 语言: {lang} | 结果: {text}"
                )
                return text, artifacts.file_path

            except OSError as e:
                retry_count = attempt
                if retry_count >= max_attempts:
                    logger.bind(tag=TAG).error(
                        f"语音识别失败(已重试{retry_count}次): {e}", exc_info=True
                    )
                    return "", None
                logger.bind(tag=TAG).warning(
                    f"语音识别失败,正在重试({retry_count}/{MAX_RETRIES}): {e}"
                )
                # asyncio.sleep, not time.sleep: a blocking sleep here would stall
                # every other coroutine on the event loop during back-off.
                await asyncio.sleep(RETRY_DELAY)

            except Exception as e:
                logger.bind(tag=TAG).error(f"语音识别失败: {e}", exc_info=True)
                return "", None

        # Defensive: every iteration above returns or retries, so this is only a
        # fallback to keep the declared tuple return type.
        return "", None
|