""" Qwen3-ASR Local GPU Provider for xiaozhi-server Based on fun_local.py structure. """ import os import time import torch import asyncio import numpy as np from config.logger import setup_logging from typing import Optional, Tuple, List from core.providers.asr.base import ASRProviderBase from core.providers.asr.dto.dto import InterfaceType TAG = __name__ logger = setup_logging() MAX_RETRIES = 2 RETRY_DELAY = 1 class ASRProvider(ASRProviderBase): def __init__(self, config: dict, delete_audio_file: bool): super().__init__() self.interface_type = InterfaceType.LOCAL self.output_dir = config.get("output_dir", "tmp/") self.delete_audio_file = delete_audio_file model_path = config.get("model_path", "Qwen/Qwen3-ASR-1.7B") device = config.get("device", "cuda:1") dtype_str = config.get("dtype", "bfloat16") dtype = getattr(torch, dtype_str, torch.bfloat16) os.makedirs(self.output_dir, exist_ok=True) logger.bind(tag=TAG).info( f"Qwen3ASR loading: model={model_path} device={device} dtype={dtype_str}" ) t0 = time.time() from qwen_asr import Qwen3ASRModel self.model = Qwen3ASRModel.from_pretrained( model_path, dtype=dtype, device_map=device, max_new_tokens=256, ) logger.bind(tag=TAG).info(f"Qwen3ASR loaded in {time.time()-t0:.1f}s") async def speech_to_text( self, opus_data: List[bytes], session_id: str, audio_format="opus", artifacts=None ) -> Tuple[Optional[str], Optional[str]]: """语音转文本 - 使用本地 Qwen3-ASR 模型""" retry_count = 0 while retry_count < MAX_RETRIES: try: if artifacts is None: return "", None pcm_bytes = artifacts.pcm_bytes if not pcm_bytes or len(pcm_bytes) == 0: return "", artifacts.file_path # PCM bytes -> numpy float32 (16kHz, 16-bit, mono) audio_np = np.frombuffer(pcm_bytes, dtype=np.int16).astype(np.float32) / 32768.0 # 使用线程池避免阻塞事件循环 start_time = time.time() results = await asyncio.to_thread( self.model.transcribe, audio=(audio_np, 16000), language=None, # auto-detect ) if results and len(results) > 0: text = results[0].text lang = getattr(results[0], 'language', 'unknown') elapsed = time.time() - start_time logger.bind(tag=TAG).info( f"语音识别耗时: {elapsed:.3f}s | 语言: {lang} | 结果: {text}" ) return text, artifacts.file_path else: return "", artifacts.file_path except OSError as e: retry_count += 1 if retry_count >= MAX_RETRIES: logger.bind(tag=TAG).error( f"语音识别失败(已重试{retry_count}次): {e}", exc_info=True ) return "", None logger.bind(tag=TAG).warning( f"语音识别失败,正在重试({retry_count}/{MAX_RETRIES}): {e}" ) time.sleep(RETRY_DELAY) except Exception as e: logger.bind(tag=TAG).error(f"语音识别失败: {e}", exc_info=True) return "", None