diff --git a/config/.config.yaml b/config/.config.yaml index fb9746f..4eef51c 100644 --- a/config/.config.yaml +++ b/config/.config.yaml @@ -9,7 +9,7 @@ log: log_level: INFO prompt: | - 你是阿福,一位经验丰富、亲切温暖的家庭医生。你服务的对象主要是中老年人。 + 你是小虎,一位经验丰富、亲切温暖的家庭医生。你服务的对象主要是中老年人。 你说话语速适中,用词通俗易懂,像一位值得信赖的老朋友。 [核心能力] - 健康咨询:解答日常健康问题,提供科学、实用的建议 @@ -27,7 +27,7 @@ prompt: | - 给出确定性诊断 - 推荐具体药品品牌 -system_error_response: "抱歉,阿福现在有点忙,咱们稍后再聊。" +system_error_response: "抱歉,小虎现在有点忙,咱们稍后再聊。" end_prompt: enable: true @@ -37,8 +37,8 @@ end_prompt: wakeup_words: - "你好小智" - "你好小志" - - "阿福阿福" - - "你好阿福" + - "小虎小虎" + - "你好小虎" selected_module: LLM: Qwen3Local diff --git a/modules/asr/qwen3_asr_local.py b/modules/asr/qwen3_asr_local.py new file mode 100644 index 0000000..2074f3d --- /dev/null +++ b/modules/asr/qwen3_asr_local.py @@ -0,0 +1,103 @@ +""" +Qwen3-ASR Local GPU Provider for xiaozhi-server +Based on fun_local.py structure. +""" + +import os +import time +import torch +import asyncio +import numpy as np + +from config.logger import setup_logging +from typing import Optional, Tuple, List +from core.providers.asr.base import ASRProviderBase +from core.providers.asr.dto.dto import InterfaceType + +TAG = __name__ +logger = setup_logging() + +MAX_RETRIES = 2 +RETRY_DELAY = 1 + + +class ASRProvider(ASRProviderBase): + def __init__(self, config: dict, delete_audio_file: bool): + super().__init__() + self.interface_type = InterfaceType.LOCAL + self.output_dir = config.get("output_dir", "tmp/") + self.delete_audio_file = delete_audio_file + + model_path = config.get("model_path", "Qwen/Qwen3-ASR-1.7B") + device = config.get("device", "cuda:1") + dtype_str = config.get("dtype", "bfloat16") + dtype = getattr(torch, dtype_str, torch.bfloat16) + + os.makedirs(self.output_dir, exist_ok=True) + + logger.bind(tag=TAG).info( + f"Qwen3ASR loading: model={model_path} device={device} dtype={dtype_str}" + ) + t0 = time.time() + + from qwen_asr import Qwen3ASRModel + self.model = Qwen3ASRModel.from_pretrained( + model_path, + dtype=dtype, + device_map=device, + max_new_tokens=256, + ) + logger.bind(tag=TAG).info(f"Qwen3ASR loaded in {time.time()-t0:.1f}s") + + async def speech_to_text( + self, opus_data: List[bytes], session_id: str, audio_format="opus", artifacts=None + ) -> Tuple[Optional[str], Optional[str]]: + """语音转文本 - 使用本地 Qwen3-ASR 模型""" + retry_count = 0 + + while retry_count < MAX_RETRIES: + try: + if artifacts is None: + return "", None + + pcm_bytes = artifacts.pcm_bytes + if not pcm_bytes or len(pcm_bytes) == 0: + return "", artifacts.file_path + + # PCM bytes -> numpy float32 (16kHz, 16-bit, mono) + audio_np = np.frombuffer(pcm_bytes, dtype=np.int16).astype(np.float32) / 32768.0 + + # 使用线程池避免阻塞事件循环 + start_time = time.time() + results = await asyncio.to_thread( + self.model.transcribe, + audio=(audio_np, 16000), + language=None, # auto-detect + ) + + if results and len(results) > 0: + text = results[0].text + lang = getattr(results[0], 'language', 'unknown') + elapsed = time.time() - start_time + logger.bind(tag=TAG).info( + f"语音识别耗时: {elapsed:.3f}s | 语言: {lang} | 结果: {text}" + ) + return text, artifacts.file_path + else: + return "", artifacts.file_path + + except OSError as e: + retry_count += 1 + if retry_count >= MAX_RETRIES: + logger.bind(tag=TAG).error( + f"语音识别失败(已重试{retry_count}次): {e}", exc_info=True + ) + return "", None + logger.bind(tag=TAG).warning( + f"语音识别失败,正在重试({retry_count}/{MAX_RETRIES}): {e}" + ) + time.sleep(RETRY_DELAY) + + except Exception as e: + logger.bind(tag=TAG).error(f"语音识别失败: {e}", exc_info=True) + return "", None diff --git a/modules/tts/sherpa_tts.py b/modules/tts/sherpa_tts.py index de0efc8..6bd04f3 100644 --- a/modules/tts/sherpa_tts.py +++ b/modules/tts/sherpa_tts.py @@ -46,20 +46,27 @@ class TTSProvider(TTSProviderBase): ) def _generate_wav(self, text): - """同步合成,在线程池中调用""" + """同步合成""" + import time from scipy.signal import resample_poly from math import gcd + logger.bind(tag=TAG).info(f"TTS 收到文字: [{text}]") + t0 = time.time() + audio = self.tts.generate(text, sid=self.sid, speed=self.speed) samples = np.array(audio.samples, dtype=np.float32) + t1 = time.time() # 重采样到目标采样率(设备要求 24000Hz,模型输出 44100Hz) target_sr = 24000 if self.sample_rate != target_sr: g = gcd(self.sample_rate, target_sr) samples = resample_poly(samples, target_sr // g, self.sample_rate // g) + t2 = time.time() pcm = (samples * 32767).astype(np.int16) + audio_duration = len(pcm) / target_sr wav_io = io.BytesIO() with wave.open(wav_io, "wb") as wf: @@ -67,7 +74,12 @@ class TTSProvider(TTSProviderBase): wf.setsampwidth(2) wf.setframerate(target_sr) wf.writeframes(pcm.tobytes()) - return wav_io.getvalue() + wav_data = wav_io.getvalue() + + logger.bind(tag=TAG).info( + f"TTS 完成: 合成={t1-t0:.2f}s 重采样={t2-t1:.2f}s 音频时长={audio_duration:.1f}s 大小={len(wav_data)}B [{text[:20]}]" + ) + return wav_data async def text_to_speak(self, text, output_file): wav_data = self._generate_wav(text)