# taixf/modules/asr/qwen3_asr_local.py
#
# NOTE: this file contains non-ASCII text (its log messages are in Chinese);
# the repository viewer flags some of its characters as ambiguous Unicode.

"""
Qwen3-ASR Local GPU Provider for xiaozhi-server
Based on fun_local.py structure.
"""
import os
import time
import torch
import asyncio
import numpy as np
from config.logger import setup_logging
from typing import Optional, Tuple, List
from core.providers.asr.base import ASRProviderBase
from core.providers.asr.dto.dto import InterfaceType
# Tag bound (via logger.bind) to every log record this module emits.
TAG: str = __name__
logger = setup_logging()

# Total transcription attempts when the model raises OSError
# (i.e. MAX_RETRIES - 1 actual retries).
MAX_RETRIES: int = 2
# Pause between those attempts, in seconds.
RETRY_DELAY: int = 1
class ASRProvider(ASRProviderBase):
    """Local-GPU speech-to-text provider backed by Qwen3-ASR.

    The model is loaded once in ``__init__``. Each ``speech_to_text`` call runs
    the blocking ``transcribe`` in a worker thread via ``asyncio.to_thread`` so
    the event loop stays responsive while inference runs.
    """

    # Sample rate handed to the model together with the PCM samples; incoming
    # audio is assumed to be 16 kHz, 16-bit, mono.
    SAMPLE_RATE = 16000

    def __init__(self, config: dict, delete_audio_file: bool):
        """Load the Qwen3-ASR model described by *config*.

        Args:
            config: Provider settings. Keys read here: ``output_dir``
                (default ``"tmp/"``, created if missing), ``model_path``
                (default ``"Qwen/Qwen3-ASR-1.7B"``), ``device`` (default
                ``"cuda:1"``), ``dtype`` (name of a ``torch`` dtype, default
                ``"bfloat16"``) and ``max_new_tokens`` (default 256).
            delete_audio_file: Stored on the instance; not read in this module
                (presumably consulted by the base class or caller).

        Raises:
            ImportError: if the ``qwen_asr`` package is not installed.
        """
        super().__init__()
        self.interface_type = InterfaceType.LOCAL
        self.output_dir = config.get("output_dir", "tmp/")
        self.delete_audio_file = delete_audio_file
        model_path = config.get("model_path", "Qwen/Qwen3-ASR-1.7B")
        # NOTE(review): the default targets the *second* GPU; single-GPU hosts
        # must set "device" explicitly (e.g. "cuda:0").
        device = config.get("device", "cuda:1")
        dtype_str = config.get("dtype", "bfloat16")
        dtype = self._resolve_dtype(dtype_str)
        max_new_tokens = int(config.get("max_new_tokens", 256))
        os.makedirs(self.output_dir, exist_ok=True)
        logger.bind(tag=TAG).info(
            f"Qwen3ASR loading: model={model_path} device={device} dtype={dtype_str}"
        )
        t0 = time.perf_counter()
        # Imported lazily so the module can be imported without qwen_asr
        # installed; only constructing this provider requires it.
        from qwen_asr import Qwen3ASRModel

        self.model = Qwen3ASRModel.from_pretrained(
            model_path,
            dtype=dtype,
            device_map=device,
            max_new_tokens=max_new_tokens,
        )
        logger.bind(tag=TAG).info(
            f"Qwen3ASR loaded in {time.perf_counter() - t0:.1f}s"
        )

    @staticmethod
    def _resolve_dtype(name) -> torch.dtype:
        """Map a dtype name such as ``"float16"`` to ``torch.float16``.

        Unknown names, and names of ``torch`` attributes that are not dtypes
        (e.g. ``"cuda"``), fall back to ``torch.bfloat16`` with a warning
        instead of being passed to the model as-is.
        """
        dtype = getattr(torch, str(name), None)
        if isinstance(dtype, torch.dtype):
            return dtype
        logger.bind(tag=TAG).warning(
            f"Qwen3ASR: unsupported dtype {name!r}, falling back to bfloat16"
        )
        return torch.bfloat16

    @staticmethod
    def _pcm16_to_float32(pcm_bytes) -> np.ndarray:
        """Convert signed 16-bit PCM bytes to float32 samples in [-1.0, 1.0).

        Samples are read in native byte order (``np.int16``); the source is
        assumed to be little-endian PCM on a little-endian host. A trailing
        odd byte (incomplete sample) is ignored rather than making
        ``np.frombuffer`` raise. Empty or ``None`` input yields an empty array.
        """
        sample_count = len(pcm_bytes) // 2 if pcm_bytes else 0
        if sample_count == 0:
            return np.zeros(0, dtype=np.float32)
        samples = np.frombuffer(pcm_bytes, dtype=np.int16, count=sample_count)
        return samples.astype(np.float32) / 32768.0

    async def _transcribe_once(self, audio_np: np.ndarray, file_path):
        """Run one off-thread transcription and return ``(text, file_path)``.

        Returns ``("", file_path)`` when the model yields no result. Exceptions
        from the model propagate to the caller, which decides on retries.
        """
        start_time = time.perf_counter()
        # Inference is blocking; run it in a worker thread so the event loop
        # keeps serving other sessions.
        results = await asyncio.to_thread(
            self.model.transcribe,
            audio=(audio_np, self.SAMPLE_RATE),
            language=None,  # auto-detect
        )
        if not results:
            return "", file_path
        first = results[0]
        text = first.text
        lang = getattr(first, "language", "unknown")
        elapsed = time.perf_counter() - start_time
        logger.bind(tag=TAG).info(
            f"语音识别耗时: {elapsed:.3f}s | 语言: {lang} | 结果: {text}"
        )
        return text, file_path

    async def speech_to_text(
        self, opus_data: List[bytes], session_id: str, audio_format="opus", artifacts=None
    ) -> Tuple[Optional[str], Optional[str]]:
        """Transcribe the decoded PCM in *artifacts* with the local Qwen3-ASR model.

        Args:
            opus_data: Not used by this provider; audio is read from ``artifacts``.
            session_id: Not used by this provider.
            audio_format: Not used by this provider.
            artifacts: Object exposing ``pcm_bytes`` (16-bit PCM, assumed 16 kHz
                mono) and ``file_path``. ``None`` means there is nothing to do.

        Returns:
            ``(text, artifacts.file_path)`` on success; ``("", file_path)`` when
            there is no audio or the model returns no result; ``("", None)``
            when ``artifacts`` is ``None`` or recognition fails. ``OSError``
            from the model is retried, for ``MAX_RETRIES`` attempts in total,
            ``RETRY_DELAY`` seconds apart.
        """
        if artifacts is None:
            return "", None
        # Preparation is done once, outside the retry loop: only the model
        # call can fail transiently.
        try:
            file_path = artifacts.file_path
            audio_np = self._pcm16_to_float32(artifacts.pcm_bytes)
        except Exception as e:
            # .exception() attaches the traceback; a loguru-style logger
            # (which .bind() suggests) ignores exc_info=.
            logger.bind(tag=TAG).exception(f"语音识别失败: {e}")
            return "", None
        if audio_np.size == 0:
            return "", file_path

        for attempt in range(1, MAX_RETRIES + 1):
            try:
                return await self._transcribe_once(audio_np, file_path)
            except OSError as e:
                if attempt >= MAX_RETRIES:
                    logger.bind(tag=TAG).exception(
                        f"语音识别失败(已重试{attempt}次): {e}"
                    )
                    return "", None
                logger.bind(tag=TAG).warning(
                    f"语音识别失败,正在重试({attempt}/{MAX_RETRIES}): {e}"
                )
                # Non-blocking back-off: time.sleep here would freeze the
                # event loop and every other session running on it.
                await asyncio.sleep(RETRY_DELAY)
            except Exception as e:
                logger.bind(tag=TAG).exception(f"语音识别失败: {e}")
                return "", None
        # Only reachable when MAX_RETRIES < 1 (no attempt was made).
        return "", None