taixf/backend/main/xiaozhi-server/core/connection.py

1507 lines
62 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import sys
import copy
import json
import uuid
import time
import queue
import asyncio
import threading
import traceback
import subprocess
import websockets
from core.utils.util import (
extract_json_from_string,
check_vad_update,
check_asr_update,
filter_sensitive_info,
)
from typing import Dict, Any
from collections import deque
from core.utils.modules_initialize import (
initialize_modules,
initialize_tts,
initialize_asr,
)
from core.handle.reportHandle import report, enqueue_tool_report
from core.providers.tts.default import DefaultTTS
from concurrent.futures import ThreadPoolExecutor
from core.utils.dialogue import Message, Dialogue
from core.providers.asr.dto.dto import InterfaceType
from core.handle.textHandle import handleTextMessage
from core.providers.tools.unified_tool_handler import UnifiedToolHandler
from plugins_func.loadplugins import auto_import_modules
from plugins_func.register import Action, ActionResponse
from core.auth import AuthenticationError
from config.config_loader import get_private_config_from_api
from core.providers.tts.dto.dto import ContentType, TTSMessageDTO, SentenceType
from config.logger import setup_logging, build_module_string, create_connection_logger
from config.manage_api_client import DeviceNotFoundException, DeviceBindException
from core.utils.prompt_manager import PromptManager
from core.utils.voiceprint_provider import VoiceprintProvider
from core.utils.util import get_system_error_response
from core.utils import textUtils
# Logger tag for this module; passed as `tag=` to the logger.bind() calls below.
TAG = __name__
# Tool-calling rules (prompt text). chat() prepends this block to the temporary
# reminder message it injects into longer conversations.
TOOL_CALLING_RULES = """
<tool_calling>
【核心原则】你是拥有工具能力的智能助手。当用户请求需要实时信息或执行操作时,调用相应工具获取数据,禁止凭空编造答案。
- **何时必须调用工具:**
1. 实时信息查询(新闻、非本地天气、股价、汇率等)
2. 执行操作(播放音乐、控制设备、拍照、设置闹钟等)
3. 知识库检索(当工具列表包含 search_from_ragflow 时,结合用户意图判断是否需要调用)
4. 查询非今天的农历信息(明天农历、某日宜忌、节气等)
5. 用户说"拍照"时调用 self_camera_take_photo默认 question 参数为"描述一下看到的物品"
- **何时无需调用工具:**
1. `<context>` 中已提供的信息(当前时间、今天日期、今天农历、本地天气等)
2. 普通对话、问候、闲聊、情感交流、讲故事
3. 通用知识问答(非实时信息)
- **调用规范:**
1. 每次请求独立判断,不复用历史工具结果,需重新获取最新数据
2. 多任务时依次调用所有需要的工具,并依次总结每个工具的结果,不得遗漏
3. 严格遵循工具的参数要求,提供所有必要参数
4. 不确定时引导用户澄清或告知能力限制,切勿猜测或编造
5. 不调用未提供的工具,对话中提及的旧工具若不可用则忽略或说明
- **反偷懒机制(最高优先级):**
1. **每次独立判断:** 无论对话历史中是否调用过工具,当前请求必须根据当前需求独立判断是否需要调用
2. **禁止模式模仿:** 即使之前的回复没有调用工具,也不代表本次可以不调用
3. **自我检查:** 回复前必须自问:"这个请求是否涉及实时信息或执行操作?如果是,我调用工具了吗?"
4. **历史不等于现在:** 对话历史中的行为模式不影响当前判断,每个用户请求都是全新的开始
</tool_calling>
"""
# Import every module under plugins_func.functions when this module is loaded
# (presumably so the plugins register themselves -- confirm in loadplugins).
auto_import_modules("plugins_func.functions")
class TTSException(RuntimeError):
    """RuntimeError subclass named for text-to-speech failures.

    Nothing in the visible part of this file raises it; it exists so that
    other modules can raise/catch a dedicated TTS error type.
    """
class ConnectionHandler:
def __init__(
    self,
    config: Dict[str, Any],
    _vad,
    _asr,
    _llm,
    _memory,
    _intent,
    server=None,
):
    """Create the per-connection state; no I/O or component start-up happens here.

    Args:
        config: Shared server configuration. Kept as-is in ``common_config``
            and deep-copied into ``config`` so this connection can override it.
        _vad: Shared VAD instance (adopted later by _initialize_components).
        _asr: Shared ASR instance (reused or replaced by _initialize_asr).
        _llm: LLM instance used for this connection.
        _memory: Memory provider, may be None.
        _intent: Intent provider, may be None.
        server: Optional reference to the owning server object.
    """
    self.common_config = config
    self.config = copy.deepcopy(config)
    self.session_id = str(uuid.uuid4())
    self.logger = setup_logging()
    self.server = server  # keep a reference to the server instance
    self.need_bind = False  # whether the device still has to be bound
    self.bind_completed_event = asyncio.Event()
    self.bind_code = None  # verification code used to bind the device
    self.last_bind_prompt_time = 0  # timestamp (seconds) of the last bind prompt
    self.bind_prompt_interval = 60  # minimum interval (seconds) between bind prompts
    self.read_config_from_api = self.config.get("read_config_from_api", False)
    self.websocket: websockets.ServerConnection | None = None
    self.headers = None
    self.device_id = None
    self.client_ip = None
    self.prompt = None
    self.welcome_msg = None
    self.max_output_size = 0
    self.chat_history_conf = 0
    self.audio_format = "opus"
    # Default sample rate. The original note says it is updated from the client's
    # hello message; handle_connection also overwrites it from the welcome config.
    self.sample_rate = 24000
    # Client state
    self.client_abort = False
    self.client_is_speaking = False
    self.client_listen_mode = "auto"
    # Threads / tasks
    self.loop = None  # the running event loop, captured in handle_connection
    self.stop_event = threading.Event()
    self.executor = ThreadPoolExecutor(max_workers=5)
    # Chat-record reporting queue and its worker thread
    self.report_queue = queue.Queue()
    self.report_thread = None
    # ASR and TTS reporting can be tuned here later; for now both simply
    # follow read_config_from_api.
    self.report_asr_enable = self.read_config_from_api
    self.report_tts_enable = self.read_config_from_api
    # Components this connection depends on
    self.vad = None
    self.asr = None
    self.tts = None
    self._asr = _asr
    self._vad = _vad
    self.llm = _llm
    self.memory = _memory
    self.intent = _intent
    self.is_exiting = False  # set while the exit flow is running
    # Voiceprint recognition is managed per connection
    self.voiceprint_provider = None
    # VAD state
    self.client_audio_buffer = bytearray()
    self.client_have_voice = False
    self.client_voice_window = deque(maxlen=5)
    self.first_activity_time = 0.0  # time of the first activity (milliseconds)
    self.last_activity_time = 0.0  # unified activity timestamp (milliseconds)
    self.client_voice_stop = False
    self.last_is_voice = False
    # ASR state.
    # A deployment may use one shared local ASR, so per-connection ASR variables
    # must live here as connection-private state rather than on the shared ASR.
    self.asr_audio = []
    self.asr_audio_queue = queue.Queue()
    self.current_speaker = None  # the current speaker
    # LLM state
    self.dialogue = Dialogue()
    # Tool-call statistics (used for monitoring and automatic recovery)
    self.tool_call_stats = {
        'last_call_turn': -1,  # turn number of the last tool call
        'consecutive_no_call': 0,  # consecutive turns without a tool call
    }
    # TTS state
    self.sentence_id = None
    # Used when a TTS response comes back without text
    self.tts_MessageText = ""
    # IoT state
    self.iot_descriptors = {}
    self.func_handler = None
    self.cmd_exit = self.config["exit_commands"]
    # Whether to close the connection once the chat has finished
    self.close_after_chat = False
    self.load_function_plugin = False
    self.intent_type = "nointent"
    self.timeout_seconds = (
        int(self.config.get("close_connection_no_voice_time", 120)) + 60
    )  # 60 s on top of the first-stage close timeout, acting as a second-stage close
    self.timeout_task = None
    # {"mcp":true} means the MCP feature is enabled
    self.features = None
    # Whether this connection comes from the MQTT gateway
    self.conn_from_mqtt_gateway = False
    # Prompt manager
    self.prompt_manager = PromptManager(self.config, self.logger)
async def handle_connection(self, ws: websockets.ServerConnection):
    """Serve one websocket connection from handshake to cleanup.

    Captures the running loop, records headers / client IP / device id,
    starts the timeout watchdog and the background initialisation, then
    routes every incoming message through _route_message until the client
    disconnects. Memory saving and resource cleanup always run in ``finally``.

    Args:
        ws: The accepted websocket connection.
    """
    try:
        # The running loop can only be obtained from inside a coroutine.
        self.loop = asyncio.get_running_loop()
        # Read the request headers.
        self.headers = dict(ws.request.headers)
        real_ip = self.headers.get("x-real-ip") or self.headers.get(
            "x-forwarded-for"
        )
        if real_ip:
            # x-forwarded-for may hold a comma separated chain; take the first hop.
            self.client_ip = real_ip.split(",")[0].strip()
        else:
            self.client_ip = ws.remote_address[0]
        # NOTE(review): this logs all request headers at INFO level; if clients
        # send credentials in headers they end up in the log -- confirm.
        self.logger.bind(tag=TAG).info(
            f"{self.client_ip} conn - Headers: {self.headers}"
        )
        self.device_id = self.headers.get("device-id", None)
        self.websocket = ws
        # Detect connections that arrive through the MQTT gateway.
        request_path = ws.request.path
        self.conn_from_mqtt_gateway = request_path.endswith("?from=mqtt_gateway")
        if self.conn_from_mqtt_gateway:
            self.logger.bind(tag=TAG).info("连接来自:MQTT网关")
        # Initialise the activity timestamps (milliseconds).
        self.first_activity_time = time.time() * 1000
        self.last_activity_time = time.time() * 1000
        # Start the timeout watchdog.
        self.timeout_task = asyncio.create_task(self._check_timeout())
        self.welcome_msg = self.config["xiaozhi"]
        self.welcome_msg["session_id"] = self.session_id
        # Output sample rate comes from the welcome configuration.
        self.sample_rate = self.welcome_msg["audio_params"]["sample_rate"]
        self.logger.bind(tag=TAG).info(f"配置输出音频采样率为: {self.sample_rate}")
        # Initialise configuration and components in the background so the
        # receive loop is never blocked. The event loop only holds a weak
        # reference to tasks, so keep a strong one to stop the task from being
        # garbage-collected before it finishes.
        self._background_init_task = asyncio.create_task(
            self._background_initialize()
        )
        try:
            async for message in self.websocket:
                await self._route_message(message)
        except websockets.exceptions.ConnectionClosed:
            self.logger.bind(tag=TAG).info("客户端断开连接")
    except AuthenticationError as e:
        self.logger.bind(tag=TAG).error(f"Authentication failed: {str(e)}")
        return
    except Exception as e:
        stack_trace = traceback.format_exc()
        self.logger.bind(tag=TAG).error(f"Connection error: {str(e)}-{stack_trace}")
        return
    finally:
        try:
            await self._save_and_close(ws)
        except Exception as final_error:
            self.logger.bind(tag=TAG).error(f"最终清理时出错: {final_error}")
            # Even if saving memory failed, the connection must still be closed.
            try:
                await self.close(ws)
            except Exception as close_error:
                self.logger.bind(tag=TAG).error(
                    f"强制关闭连接时出错: {close_error}"
                )
async def _save_and_close(self, ws):
"""保存记忆并关闭连接"""
try:
if self.memory:
# 使用线程池异步保存记忆
def save_memory_task():
try:
# 创建新事件循环(避免与主循环冲突)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(
self.memory.save_memory(
self.dialogue.dialogue, self.session_id
)
)
except Exception as e:
self.logger.bind(tag=TAG).error(f"保存记忆失败: {e}")
finally:
try:
loop.close()
except Exception:
pass
# 启动线程保存记忆,不等待完成
threading.Thread(target=save_memory_task, daemon=True).start()
except Exception as e:
self.logger.bind(tag=TAG).error(f"保存记忆失败: {e}")
finally:
# 立即关闭连接,不等待记忆保存完成
try:
await self.close(ws)
except Exception as close_error:
self.logger.bind(tag=TAG).error(
f"保存记忆后关闭连接失败: {close_error}"
)
async def _discard_message_with_bind_prompt(self):
"""丢弃消息并检查是否需要播放绑定提示"""
current_time = time.time()
# 检查是否需要播放绑定提示
if current_time - self.last_bind_prompt_time >= self.bind_prompt_interval:
self.last_bind_prompt_time = current_time
# 复用现有的绑定提示逻辑
from core.handle.receiveAudioHandle import check_bind_device
asyncio.create_task(check_bind_device(self))
async def _route_message(self, message):
"""消息路由"""
# 退出状态丢弃所有消息
if self.is_exiting:
return
# 检查是否已经获取到真实的绑定状态
if not self.bind_completed_event.is_set():
# 还没有获取到真实状态,等待直到获取到真实状态或超时
try:
await asyncio.wait_for(self.bind_completed_event.wait(), timeout=1)
except asyncio.TimeoutError:
# 超时仍未获取到真实状态,丢弃消息
await self._discard_message_with_bind_prompt()
return
# 已经获取到真实状态,检查是否需要绑定
if self.need_bind:
# 需要绑定,丢弃消息
await self._discard_message_with_bind_prompt()
return
# 不需要绑定,继续处理消息
if isinstance(message, str):
await handleTextMessage(self, message)
elif isinstance(message, bytes):
if self.vad is None or self.asr is None:
return
# 处理来自MQTT网关的音频包
if self.conn_from_mqtt_gateway and len(message) >= 16:
handled = await self._process_mqtt_audio_message(message)
if handled:
return
# 不需要头部处理或没有头部时,直接处理原始消息
self.asr_audio_queue.put(message)
async def _process_mqtt_audio_message(self, message):
"""
处理来自MQTT网关的音频消息解析16字节头部并提取音频数据
Args:
message: 包含头部的音频消息
Returns:
bool: 是否成功处理了消息
"""
try:
# 提取头部信息
timestamp = int.from_bytes(message[8:12], "big")
audio_length = int.from_bytes(message[12:16], "big")
# 提取音频数据
if audio_length > 0 and len(message) >= 16 + audio_length:
# 有指定长度,提取精确的音频数据
audio_data = message[16 : 16 + audio_length]
# 基于时间戳进行排序处理
self._process_websocket_audio(audio_data, timestamp)
return True
elif len(message) > 16:
# 没有指定长度或长度无效,去掉头部后处理剩余数据
audio_data = message[16:]
self.asr_audio_queue.put(audio_data)
return True
except Exception as e:
self.logger.bind(tag=TAG).error(f"解析WebSocket音频包失败: {e}")
# 处理失败返回False表示需要继续处理
return False
def _process_websocket_audio(self, audio_data, timestamp):
"""处理WebSocket格式的音频包"""
# 初始化时间戳序列管理
if not hasattr(self, "audio_timestamp_buffer"):
self.audio_timestamp_buffer = {}
self.last_processed_timestamp = 0
self.max_timestamp_buffer_size = 20
# 如果时间戳是递增的,直接处理
if timestamp >= self.last_processed_timestamp:
self.asr_audio_queue.put(audio_data)
self.last_processed_timestamp = timestamp
# 处理缓冲区中的后续包
processed_any = True
while processed_any:
processed_any = False
for ts in sorted(self.audio_timestamp_buffer.keys()):
if ts > self.last_processed_timestamp:
buffered_audio = self.audio_timestamp_buffer.pop(ts)
self.asr_audio_queue.put(buffered_audio)
self.last_processed_timestamp = ts
processed_any = True
break
else:
# 乱序包,暂存
if len(self.audio_timestamp_buffer) < self.max_timestamp_buffer_size:
self.audio_timestamp_buffer[timestamp] = audio_data
else:
self.asr_audio_queue.put(audio_data)
async def handle_restart(self, message):
    """Acknowledge a restart request, then relaunch ``app.py`` and exit this process.

    The restart itself runs in a daemon thread so the event loop is not
    blocked. If anything fails, an error reply is sent to the client.

    Args:
        message: The restart request (not inspected here).
    """

    def _server_reply(status, text):
        # JSON envelope shared by the success and error replies.
        return json.dumps(
            {
                "type": "server",
                "status": status,
                "message": text,
                "content": {"action": "restart"},
            }
        )

    def _restart_process():
        # Give the confirmation a moment to go out, spawn the new server
        # detached from this session, then terminate immediately.
        time.sleep(1)
        self.logger.bind(tag=TAG).info("执行服务器重启...")
        subprocess.Popen(
            [sys.executable, "app.py"],
            stdin=sys.stdin,
            stdout=sys.stdout,
            stderr=sys.stderr,
            start_new_session=True,
        )
        os._exit(0)

    try:
        self.logger.bind(tag=TAG).info("收到服务器重启指令,准备执行...")
        # Confirm to the client first.
        await self.websocket.send(_server_reply("success", "服务器重启中..."))
        # Run the restart off the event loop.
        threading.Thread(target=_restart_process, daemon=True).start()
    except Exception as e:
        self.logger.bind(tag=TAG).error(f"重启失败: {str(e)}")
        await self.websocket.send(
            _server_reply("error", f"Restart failed: {str(e)}")
        )
def _initialize_components(self):
try:
if self.tts is None:
self.tts = self._initialize_tts()
# 打开语音合成通道
asyncio.run_coroutine_threadsafe(
self.tts.open_audio_channels(self), self.loop
)
if self.need_bind:
self.bind_completed_event.set()
return
self.selected_module_str = build_module_string(
self.config.get("selected_module", {})
)
self.logger = create_connection_logger(self.selected_module_str)
"""初始化组件"""
if self.config.get("prompt") is not None:
user_prompt = self.config["prompt"]
# 使用快速提示词进行初始化
prompt = self.prompt_manager.get_quick_prompt(user_prompt)
self.change_system_prompt(prompt)
self.logger.bind(tag=TAG).info(
f"快速初始化组件: prompt成功 {prompt[:50]}..."
)
"""初始化本地组件"""
if self.vad is None:
self.vad = self._vad
if self.asr is None:
self.asr = self._initialize_asr()
# 初始化声纹识别
self._initialize_voiceprint()
# 打开语音识别通道
asyncio.run_coroutine_threadsafe(
self.asr.open_audio_channels(self), self.loop
)
"""加载记忆"""
self._initialize_memory()
"""加载意图识别"""
self._initialize_intent()
"""初始化上报线程"""
self._init_report_threads()
"""更新系统提示词"""
self._init_prompt_enhancement()
except Exception as e:
self.logger.bind(tag=TAG).error(f"实例化组件失败: {e}")
def _init_prompt_enhancement(self):
# 更新上下文信息
self.prompt_manager.update_context_info(self, self.client_ip)
enhanced_prompt = self.prompt_manager.build_enhanced_prompt(
self.config["prompt"], self.device_id, self.client_ip
)
if enhanced_prompt:
self.change_system_prompt(enhanced_prompt)
self.logger.bind(tag=TAG).debug("系统提示词已增强更新")
def _init_report_threads(self):
"""初始化ASR和TTS上报线程"""
if not self.read_config_from_api or self.need_bind:
return
if self.chat_history_conf == 0:
return
if self.report_thread is None or not self.report_thread.is_alive():
self.report_thread = threading.Thread(
target=self._report_worker, daemon=True
)
self.report_thread.start()
self.logger.bind(tag=TAG).info("TTS上报线程已启动")
def _initialize_tts(self):
    """Return the TTS provider for this connection.

    The configured provider is only created when the device does not need
    binding; whenever no provider results, ``DefaultTTS`` is used instead.
    """
    provider = initialize_tts(self.config) if not self.need_bind else None
    if provider is not None:
        return provider
    return DefaultTTS(self.config, delete_audio_file=True)
def _initialize_asr(self):
    """Return the ASR instance this connection should use.

    A shared ASR whose ``interface_type`` is ``InterfaceType.LOCAL`` is
    reused, because one local instance can serve several connections.
    Otherwise a new instance is created: per the original note, remote ASR
    involves a websocket connection and receive thread per connection.
    """
    shared = self._asr
    reuse_shared = (
        shared is not None
        and hasattr(shared, "interface_type")
        and shared.interface_type == InterfaceType.LOCAL
    )
    if reuse_shared:
        return shared
    return initialize_asr(self.config)
def _initialize_voiceprint(self):
    """Create the voiceprint provider for this connection when it is configured.

    ``self.voiceprint_provider`` is only assigned when the provider reports
    itself as enabled. Failures are logged as warnings and never raised.
    """
    try:
        voiceprint_config = self.config.get("voiceprint", {})
        if not voiceprint_config:
            self.logger.bind(tag=TAG).info("声纹识别功能未启用")
            return
        provider = VoiceprintProvider(voiceprint_config)
        if provider is None or not provider.enabled:
            self.logger.bind(tag=TAG).warning("声纹识别功能启用但配置不完整")
            return
        self.voiceprint_provider = provider
        self.logger.bind(tag=TAG).info("声纹识别功能已在连接时动态启用")
    except Exception as e:
        self.logger.bind(tag=TAG).warning(f"声纹识别初始化失败: {str(e)}")
async def _background_initialize(self):
"""在后台初始化配置和组件(完全不阻塞主循环)"""
try:
# 异步获取差异化配置
await self._initialize_private_config_async()
# 在线程池中初始化组件
self.executor.submit(self._initialize_components)
except Exception as e:
self.logger.bind(tag=TAG).error(f"后台初始化失败: {e}")
async def _initialize_private_config_async(self):
"""从接口异步获取差异化配置(异步版本,不阻塞主循环)"""
if not self.read_config_from_api:
self.need_bind = False
self.bind_completed_event.set()
return
try:
begin_time = time.time()
private_config = await get_private_config_from_api(
self.config,
self.headers.get("device-id"),
self.headers.get("client-id", self.headers.get("device-id")),
)
private_config["delete_audio"] = bool(self.config.get("delete_audio", True))
self.logger.bind(tag=TAG).info(
f"{time.time() - begin_time} 秒,异步获取差异化配置成功: {json.dumps(filter_sensitive_info(private_config), ensure_ascii=False)}"
)
self.need_bind = False
self.bind_completed_event.set()
except DeviceNotFoundException as e:
self.need_bind = True
private_config = {}
except DeviceBindException as e:
self.need_bind = True
self.bind_code = e.bind_code
private_config = {}
except Exception as e:
self.need_bind = True
self.logger.bind(tag=TAG).error(f"异步获取差异化配置失败: {e}")
private_config = {}
init_llm, init_tts, init_memory, init_intent = (
False,
False,
False,
False,
)
init_vad = check_vad_update(self.common_config, private_config)
init_asr = check_asr_update(self.common_config, private_config)
if init_vad:
self.config["VAD"] = private_config["VAD"]
self.config["selected_module"]["VAD"] = private_config["selected_module"][
"VAD"
]
if init_asr:
self.config["ASR"] = private_config["ASR"]
self.config["selected_module"]["ASR"] = private_config["selected_module"][
"ASR"
]
if private_config.get("TTS", None) is not None:
init_tts = True
self.config["TTS"] = private_config["TTS"]
self.config["selected_module"]["TTS"] = private_config["selected_module"][
"TTS"
]
if private_config.get("LLM", None) is not None:
init_llm = True
self.config["LLM"] = private_config["LLM"]
self.config["selected_module"]["LLM"] = private_config["selected_module"][
"LLM"
]
if private_config.get("VLLM", None) is not None:
self.config["VLLM"] = private_config["VLLM"]
self.config["selected_module"]["VLLM"] = private_config["selected_module"][
"VLLM"
]
if private_config.get("Memory", None) is not None:
init_memory = True
self.config["Memory"] = private_config["Memory"]
self.config["selected_module"]["Memory"] = private_config[
"selected_module"
]["Memory"]
if private_config.get("Intent", None) is not None:
init_intent = True
self.config["Intent"] = private_config["Intent"]
model_intent = private_config.get("selected_module", {}).get("Intent", {})
self.config["selected_module"]["Intent"] = model_intent
# 加载插件配置
if model_intent != "Intent_nointent":
plugin_from_server = private_config.get("plugins", {})
for plugin, config_str in plugin_from_server.items():
plugin_from_server[plugin] = json.loads(config_str)
self.config["plugins"] = plugin_from_server
self.config["Intent"][self.config["selected_module"]["Intent"]][
"functions"
] = plugin_from_server.keys()
if private_config.get("prompt", None) is not None:
self.config["prompt"] = private_config["prompt"]
# 获取声纹信息
if private_config.get("voiceprint", None) is not None:
self.config["voiceprint"] = private_config["voiceprint"]
if private_config.get("summaryMemory", None) is not None:
self.config["summaryMemory"] = private_config["summaryMemory"]
if private_config.get("device_max_output_size", None) is not None:
self.max_output_size = int(private_config["device_max_output_size"])
if private_config.get("chat_history_conf", None) is not None:
self.chat_history_conf = int(private_config["chat_history_conf"])
if private_config.get("mcp_endpoint", None) is not None:
self.config["mcp_endpoint"] = private_config["mcp_endpoint"]
if private_config.get("context_providers", None) is not None:
self.config["context_providers"] = private_config["context_providers"]
# 使用 run_in_executor 在线程池中执行 initialize_modules避免阻塞主循环
try:
modules = await self.loop.run_in_executor(
None, # 使用默认线程池
initialize_modules,
self.logger,
private_config,
init_vad,
init_asr,
init_llm,
init_tts,
init_memory,
init_intent,
)
except Exception as e:
self.logger.bind(tag=TAG).error(f"初始化组件失败: {e}")
modules = {}
if modules.get("tts", None) is not None:
self.tts = modules["tts"]
if modules.get("vad", None) is not None:
self.vad = modules["vad"]
if modules.get("asr", None) is not None:
self.asr = modules["asr"]
if modules.get("llm", None) is not None:
self.llm = modules["llm"]
if modules.get("intent", None) is not None:
self.intent = modules["intent"]
if modules.get("memory", None) is not None:
self.memory = modules["memory"]
def _initialize_memory(self):
if self.memory is None:
return
"""初始化记忆模块"""
self.memory.init_memory(
role_id=self.device_id,
llm=self.llm,
summary_memory=self.config.get("summaryMemory", None),
save_to_file=not self.read_config_from_api,
)
# 获取记忆总结配置
memory_config = self.config["Memory"]
memory_type = self.config["Memory"][self.config["selected_module"]["Memory"]][
"type"
]
# 如果使用 nomen 或 mem_report_only直接返回
if memory_type == "nomem" or memory_type == "mem_report_only":
return
# 使用 mem_local_short 模式
elif memory_type == "mem_local_short":
memory_llm_name = memory_config[self.config["selected_module"]["Memory"]][
"llm"
]
if memory_llm_name and memory_llm_name in self.config["LLM"]:
# 如果配置了专用LLM则创建独立的LLM实例
from core.utils import llm as llm_utils
memory_llm_config = self.config["LLM"][memory_llm_name]
memory_llm_type = memory_llm_config.get("type", memory_llm_name)
memory_llm = llm_utils.create_instance(
memory_llm_type, memory_llm_config
)
self.logger.bind(tag=TAG).info(
f"为记忆总结创建了专用LLM: {memory_llm_name}, 类型: {memory_llm_type}"
)
self.memory.set_llm(memory_llm)
else:
# 否则使用主LLM
self.memory.set_llm(self.llm)
self.logger.bind(tag=TAG).info("使用主LLM作为意图识别模型")
def _initialize_intent(self):
if self.intent is None:
return
self.intent_type = self.config["Intent"][
self.config["selected_module"]["Intent"]
]["type"]
if self.intent_type == "function_call" or self.intent_type == "intent_llm":
self.load_function_plugin = True
"""初始化意图识别模块"""
# 获取意图识别配置
intent_config = self.config["Intent"]
intent_type = self.config["Intent"][self.config["selected_module"]["Intent"]][
"type"
]
# 如果使用 nointent直接返回
if intent_type == "nointent":
return
# 使用 intent_llm 模式
elif intent_type == "intent_llm":
intent_llm_name = intent_config[self.config["selected_module"]["Intent"]][
"llm"
]
if intent_llm_name and intent_llm_name in self.config["LLM"]:
# 如果配置了专用LLM则创建独立的LLM实例
from core.utils import llm as llm_utils
intent_llm_config = self.config["LLM"][intent_llm_name]
intent_llm_type = intent_llm_config.get("type", intent_llm_name)
intent_llm = llm_utils.create_instance(
intent_llm_type, intent_llm_config
)
self.logger.bind(tag=TAG).info(
f"为意图识别创建了专用LLM: {intent_llm_name}, 类型: {intent_llm_type}"
)
self.intent.set_llm(intent_llm)
else:
# 否则使用主LLM
self.intent.set_llm(self.llm)
self.logger.bind(tag=TAG).info("使用主LLM作为意图识别模型")
"""加载统一工具处理器"""
self.func_handler = UnifiedToolHandler(self)
# 异步初始化工具处理器
if hasattr(self, "loop") and self.loop:
asyncio.run_coroutine_threadsafe(self.func_handler._initialize(), self.loop)
def change_system_prompt(self, prompt):
    """Store ``prompt`` and push it into the dialogue as the system message."""
    self.prompt = prompt
    # Keep the dialogue context in sync with the new system prompt.
    self.dialogue.update_system_message(prompt)
def chat(self, query, depth=0):
    """Answer ``query`` with the LLM, streaming text to TTS and running tool calls.

    At depth 0 a new sentence id is created, the user message is recorded
    and a FIRST marker is sent to TTS. When tools return data for the LLM,
    _handle_function_result calls this method again with ``query=None`` and
    ``depth + 1``; at ``MAX_DEPTH`` tools are disabled and the model is told
    to answer with what it has.

    In longer conversations a temporary reminder about tool usage is added
    to the dialogue; it is removed again on every exit path.

    Args:
        query: The user's text, or None for follow-up rounds after tool calls.
        depth: Current tool-call recursion depth (0 for the user-facing call).

    Returns:
        True when the round completed, None when the LLM call or the stream failed.
    """
    if query is not None:
        self.logger.bind(tag=TAG).info(f"大模型收到用户消息: {query}")

    # Top level: new sentence id, record the user message, send FIRST.
    if depth == 0:
        self.sentence_id = str(uuid.uuid4().hex)
        self.dialogue.put(Message(role="user", content=query))
        self.tts.tts_text_queue.put(
            TTSMessageDTO(
                sentence_id=self.sentence_id,
                sentence_type=SentenceType.FIRST,
                content_type=ContentType.ACTION,
            )
        )

    # Maximum recursion depth; prevents endless tool-call loops.
    MAX_DEPTH = 5
    force_final_answer = depth >= MAX_DEPTH
    if force_final_answer:
        self.logger.bind(tag=TAG).debug(
            f"已达到最大工具调用深度 {MAX_DEPTH},将强制基于现有信息回答"
        )
        # Instruct the LLM to answer from what it already has.
        self.dialogue.put(
            Message(
                role="user",
                content="[系统提示] 已达到最大工具调用次数限制,请你基于目前已经获取的所有信息,直接给出最终答案。不要再尝试调用任何工具。",
            )
        )

    # Detect a long run of turns without any tool call.
    force_reminder = False
    if depth == 0 and query is not None:
        current_turn = len(self.dialogue.dialogue) // 2
        last_call_turn = self.tool_call_stats['last_call_turn']
        if last_call_turn >= 0:
            turns_since_last = current_turn - last_call_turn
            if turns_since_last > 3:  # more than 3 turns without a tool call
                self.logger.bind(tag=TAG).warning(
                    f"检测到{turns_since_last}轮未调用工具,可能进入偷懒模式,将强制注入提醒"
                )
                force_reminder = True

    # Tools are offered only in function_call mode, once the tool handler
    # exists (it is created asynchronously and starts out as None), and
    # while the depth limit has not been reached.
    functions = None
    if (
        self.intent_type == "function_call"
        and getattr(self, "func_handler", None) is not None
        and not force_final_answer
    ):
        functions = self.func_handler.get_functions()

    # Build the tool-usage reminder for longer conversations.
    tool_call_reminder = None
    if depth == 0 and query is not None and functions is not None:
        dialogue_length = len(self.dialogue.dialogue)
        # Inject the rules once the history holds more than 4 messages.
        if dialogue_length > 4:
            tool_summary = self._get_tool_summary(functions)
            if tool_summary:
                reminder_level = ""
                if force_reminder:
                    # Strong reminder.
                    tool_call_reminder = (
                        TOOL_CALLING_RULES +
                        f"[重要提醒] 多轮未使用工具,检查回复是否遗漏了必要的工具调用!上一轮未使用工具,本轮必须重新判断是否需要工具。"
                        f"当前可用工具: {tool_summary}"
                    )
                else:
                    # Regular reminder.
                    tool_call_reminder = (
                        TOOL_CALLING_RULES +
                        f"当前可用工具: {tool_summary}"
                        f"仅当用户请求涉及实时信息查询或执行操作时调用,日常对话无需调用。"
                    )
                self.logger.bind(tag=TAG).debug(
                    f"对话历史较长({dialogue_length}条),已注入{reminder_level}等级工具调用规则强化,当前可用工具:{tool_summary}"
                )

    # The reminder lives in the dialogue only for the duration of this turn.
    if tool_call_reminder:
        self.dialogue.put(
            Message(role="user", content=tool_call_reminder, is_temporary=True)
        )
    try:
        return self._chat_round(query, depth, functions)
    finally:
        # Remove the temporary reminder on every exit path, errors included,
        # so it can never leak into later turns.
        if tool_call_reminder and len(self.dialogue.dialogue) > 0:
            original_length = len(self.dialogue.dialogue)
            self.dialogue.dialogue = [
                msg for msg in self.dialogue.dialogue
                if not getattr(msg, 'is_temporary', False)
            ]
            if len(self.dialogue.dialogue) < original_length:
                self.logger.bind(tag=TAG).debug("已清理临时的工具调用提醒消息")


def _chat_round(self, query, depth, functions):
    """Run one LLM request/stream/tool cycle for chat(); returns True, or None on failure."""
    response_message = []
    try:
        # Memory is only queried for real user input (non-empty query).
        memory_str = None
        if self.memory is not None and query:
            future = asyncio.run_coroutine_threadsafe(
                self.memory.query_memory(query), self.loop
            )
            memory_str = future.result()
        if self.intent_type == "function_call" and functions is not None:
            # Streaming interface with function support.
            llm_responses = self.llm.response_with_functions(
                self.session_id,
                self.dialogue.get_llm_dialogue_with_memory(
                    memory_str, self.config.get("voiceprint", {})
                ),
                functions=functions,
            )
        else:
            llm_responses = self.llm.response(
                self.session_id,
                self.dialogue.get_llm_dialogue_with_memory(
                    memory_str, self.config.get("voiceprint", {})
                ),
            )
    except Exception as e:
        self.logger.bind(tag=TAG).error(f"LLM 处理出错 {query}: {e}")
        if depth == 0:
            # FIRST was already sent for this turn; close it with LAST, as
            # the stream-error path below does.
            self.tts.tts_text_queue.put(
                TTSMessageDTO(
                    sentence_id=self.sentence_id,
                    sentence_type=SentenceType.LAST,
                    content_type=ContentType.ACTION,
                )
            )
        return None

    # Consume the streamed response.
    tool_call_flag = False
    # Several parallel tool calls are supported:
    # [{"id": "", "name": "", "arguments": ""}]
    tool_calls_list = []
    content_arguments = ""
    self.client_abort = False
    emotion_flag = True
    try:
        for response in llm_responses:
            if self.client_abort:
                break
            if self.intent_type == "function_call" and functions is not None:
                content, tools_call = response
                if "content" in response:
                    content = response["content"]
                    tools_call = None
                if content is not None and len(content) > 0:
                    content_arguments += content
                # Some models emit tool calls as text starting with <tool_call>.
                if not tool_call_flag and content_arguments.startswith("<tool_call>"):
                    tool_call_flag = True
                if tools_call is not None and len(tools_call) > 0:
                    tool_call_flag = True
                    self._merge_tool_calls(tool_calls_list, tools_call)
            else:
                content = response
            # Derive the emotion once per turn, from the first non-blank chunk.
            if emotion_flag and content is not None and content.strip():
                asyncio.run_coroutine_threadsafe(
                    textUtils.get_emotion(self, content),
                    self.loop,
                )
                emotion_flag = False
            if content is not None and len(content) > 0:
                if not tool_call_flag:
                    response_message.append(content)
                    self.tts.tts_text_queue.put(
                        TTSMessageDTO(
                            sentence_id=self.sentence_id,
                            sentence_type=SentenceType.MIDDLE,
                            content_type=ContentType.TEXT,
                            content_detail=content,
                        )
                    )
    except Exception as e:
        self.logger.bind(tag=TAG).error(f"LLM stream processing error: {e}")
        self.tts.tts_text_queue.put(
            TTSMessageDTO(
                sentence_id=self.sentence_id,
                sentence_type=SentenceType.MIDDLE,
                content_type=ContentType.TEXT,
                content_detail=get_system_error_response(self.config),
            )
        )
        if depth == 0:
            self.tts.tts_text_queue.put(
                TTSMessageDTO(
                    sentence_id=self.sentence_id,
                    sentence_type=SentenceType.LAST,
                    content_type=ContentType.ACTION,
                )
            )
        return

    # Handle function calls.
    if tool_call_flag:
        bHasError = False
        # Text-based tool-call format: parse the JSON out of the content.
        if len(tool_calls_list) == 0 and content_arguments:
            a = extract_json_from_string(content_arguments)
            if a is not None:
                try:
                    content_arguments_json = json.loads(a)
                    tool_calls_list.append(
                        {
                            "id": str(uuid.uuid4().hex),
                            "name": content_arguments_json["name"],
                            "arguments": json.dumps(
                                content_arguments_json["arguments"],
                                ensure_ascii=False,
                            ),
                        }
                    )
                except Exception as e:
                    bHasError = True
                    response_message.append(a)
            else:
                bHasError = True
                response_message.append(content_arguments)
            if bHasError:
                self.logger.bind(tag=TAG).error(
                    f"function call error: {content_arguments}"
                )
        if not bHasError and len(tool_calls_list) > 0:
            self.logger.bind(tag=TAG).debug(
                f"检测到 {len(tool_calls_list)} 个工具调用"
            )
            # Update the tool-call statistics.
            if depth == 0:
                current_turn = len(self.dialogue.dialogue) // 2
                self.tool_call_stats['last_call_turn'] = current_turn
                self.tool_call_stats['consecutive_no_call'] = 0
                self.logger.bind(tag=TAG).debug(
                    f"工具调用统计更新: 当前轮次={current_turn}"
                )
            # Record any text the model produced before calling tools.
            if len(response_message) > 0:
                text_buff = "".join(response_message)
                self.tts_MessageText = text_buff
                self.dialogue.put(Message(role="assistant", content=text_buff))
            response_message.clear()
            # Start every tool call and collect the futures.
            futures_with_data = []
            for tool_call_data in tool_calls_list:
                self.logger.bind(tag=TAG).debug(
                    f"function_name={tool_call_data['name']}, function_id={tool_call_data['id']}, function_arguments={tool_call_data['arguments']}"
                )
                # Report the tool call.
                tool_input = json.loads(tool_call_data.get("arguments") or "{}")
                enqueue_tool_report(self, tool_call_data['name'], tool_input)
                future = asyncio.run_coroutine_threadsafe(
                    self.func_handler.handle_llm_function_call(
                        self, tool_call_data
                    ),
                    self.loop,
                )
                futures_with_data.append((future, tool_call_data, tool_input))
            # Tool-call timeout is configurable, 30 seconds by default.
            tool_call_timeout = int(self.config.get("tool_call_timeout", 30))
            # Wait for the calls (overall wait is bounded by the slowest one).
            tool_results = []
            for future, tool_call_data, tool_input in futures_with_data:
                try:
                    result = future.result(timeout=tool_call_timeout)
                    tool_results.append((result, tool_call_data))
                    # Report the tool result.
                    enqueue_tool_report(self, tool_call_data['name'], tool_input, str(result.result) if result.result else None, report_tool_call=False)
                except Exception as e:
                    self.logger.bind(tag=TAG).error(
                        f"工具调用超时或异常: {tool_call_data['name']}, 错误: {e}"
                    )
                    # Answer with an error response so the flow cannot hang.
                    tool_results.append((
                        ActionResponse(action=Action.ERROR, result="哎呀,网络遇到点问题,请稍后再试下!"),
                        tool_call_data
                    ))
                    # Report the tool error.
                    enqueue_tool_report(self, tool_call_data['name'], tool_input, str(e), report_tool_call=False)
            # Process all tool results together.
            if tool_results:
                self._handle_function_result(tool_results, depth=depth)

    # Store the assistant text of this round.
    if len(response_message) > 0:
        text_buff = "".join(response_message)
        self.tts_MessageText = text_buff
        self.dialogue.put(Message(role="assistant", content=text_buff))
    # Tool-call statistics: count turns that used no tool.
    if depth == 0 and not tool_call_flag:
        self.tool_call_stats['consecutive_no_call'] += 1
    if depth == 0:
        # The LLM may signal that an idle message should follow the TTS.
        if hasattr(self.llm, 'should_idle') and self.llm.should_idle:
            self.send_idle_after_tts = True
            self.llm.should_idle = False
        self.tts.tts_text_queue.put(
            TTSMessageDTO(
                sentence_id=self.sentence_id,
                sentence_type=SentenceType.LAST,
                content_type=ContentType.ACTION,
            )
        )
        # A lambda is passed so that get_llm_dialogue() is evaluated lazily.
        self.logger.bind(tag=TAG).debug(
            lambda: json.dumps(
                self.dialogue.get_llm_dialogue(), indent=4, ensure_ascii=False
            )
        )
    return True
def _get_tool_summary(self, functions: list) -> str:
"""
从工具定义中提取摘要,用于规则强化注入
Args:
functions: 工具列表
Returns:
str: 工具名称字符串
"""
if not functions:
return ""
datas = []
for func in functions:
func_info = func.get("function", {})
name = func_info.get("name", "")
datas.append(name)
result = "".join(datas)
return result
def _handle_function_result(self, tool_results, depth):
    """Act on tool results: speak direct replies, feed the rest back to the LLM.

    RESPONSE / NOTFOUND / ERROR results are spoken and stored as assistant
    messages. REQLLM results are appended to the dialogue as one assistant
    ``tool_calls`` message plus one ``tool`` message per non-empty result,
    after which chat() is called again one level deeper.

    Args:
        tool_results: Iterable of ``(result, tool_call_data)`` pairs.
        depth: Recursion depth of the chat() call that ran the tools.
    """
    need_llm_tools = []
    for result, tool_call_data in tool_results:
        if result.action in [
            Action.RESPONSE,
            Action.NOTFOUND,
            Action.ERROR,
        ]:
            # Reply to the client directly.
            text = result.response if result.response else result.result
            self.tts.tts_one_sentence(self, ContentType.TEXT, content_detail=text)
            self.dialogue.put(Message(role="assistant", content=text))
        elif result.action == Action.REQLLM:
            # Collect the tools whose output must go through the LLM.
            need_llm_tools.append((result, tool_call_data))

    if not need_llm_tools:
        return

    all_tool_calls = []
    for idx, (_, call) in enumerate(need_llm_tools):
        arguments = call["arguments"]
        all_tool_calls.append(
            {
                "id": call["id"],
                "function": {
                    "arguments": "{}" if arguments == "" else arguments,
                    "name": call["name"],
                },
                "type": "function",
                "index": idx,
            }
        )
    self.dialogue.put(Message(role="assistant", tool_calls=all_tool_calls))

    # NOTE(review): a tool with an empty result gets no "tool" message although
    # its call is listed above -- confirm the LLM provider accepts that.
    for result, call in need_llm_tools:
        text = result.result
        if text is not None and len(text) > 0:
            call_id = call["id"]
            self.dialogue.put(
                Message(
                    role="tool",
                    tool_call_id=str(uuid.uuid4()) if call_id is None else call_id,
                    content=text,
                )
            )
    self.chat(None, depth=depth + 1)
def _report_worker(self):
    """Worker-thread loop that forwards queued chat records to the executor.

    Runs until ``self.stop_event`` is set or a ``None`` sentinel is pulled
    from ``self.report_queue``. Each queued item is unpacked as the
    positional arguments of ``self._process_report``.
    """
    while not self.stop_event.is_set():
        try:
            # The 1s timeout lets the loop re-check the stop event regularly.
            task = self.report_queue.get(timeout=1)
            if task is None:
                # Poison pill: leave the loop.
                break
            try:
                # Only hand work over while a thread pool is still available.
                if self.executor is not None:
                    self.executor.submit(self._process_report, *task)
            except Exception as submit_error:
                self.logger.bind(tag=TAG).error(
                    f"聊天记录上报线程异常: {submit_error}"
                )
        except queue.Empty:
            continue
        except Exception as worker_error:
            self.logger.bind(tag=TAG).error(
                f"聊天记录上报工作线程异常: {worker_error}"
            )

    self.logger.bind(tag=TAG).info("聊天记录上报线程已退出")
def _process_report(self, type, text, audio_data, report_time):
    """Run a single report to completion and mark its queue item as done.

    Args:
        type: Forwarded unchanged to ``report``.
        text: Forwarded unchanged to ``report``.
        audio_data: Forwarded unchanged to ``report``.
        report_time: Forwarded unchanged to ``report``.
    """
    try:
        # Drive the asynchronous report on an event loop owned by this call.
        upload = report(self, type, text, audio_data, report_time)
        asyncio.run(upload)
    except Exception as report_error:
        self.logger.bind(tag=TAG).error(f"上报处理异常: {report_error}")
    finally:
        # Always balance the get() performed by the report worker.
        self.report_queue.task_done()
def clearSpeakStatus(self):
    """Reset ``client_is_speaking`` to False and log it at debug level."""
    self.client_is_speaking = False
    debug_log = self.logger.bind(tag=TAG).debug
    debug_log("清除服务端讲话状态")
async def close(self, ws=None):
    """Release every resource held by this connection.

    Steps, in order: release VAD resources, clear the audio buffer, cancel
    the timeout task, clean up the tool handler, set the stop event, drain
    the queues, close the WebSocket, close TTS and ASR, and shut down the
    thread pool. A failure in one step is logged and does not prevent the
    later steps from running.

    Args:
        ws: WebSocket to close. When falsy, ``self.websocket`` is used.
    """
    try:
        # Release the VAD resources bound to this connection.
        if (
            hasattr(self, "vad")
            and self.vad
            and hasattr(self.vad, "release_conn_resources")
        ):
            self.vad.release_conn_resources(self)

        # Drop buffered audio.
        if hasattr(self, "audio_buffer"):
            self.audio_buffer.clear()

        # Cancel the timeout task. close() can be reached from that very
        # task (``_check_timeout`` awaits ``self.close(...)``; presumably
        # ``self.timeout_task`` is the task running it). A task must not
        # cancel and await itself, so it is left to finish on its own then.
        timeout_task = self.timeout_task
        if timeout_task and not timeout_task.done():
            if timeout_task is not asyncio.current_task():
                timeout_task.cancel()
                try:
                    await timeout_task
                except asyncio.CancelledError:
                    pass
            self.timeout_task = None

        # Clean up the tool handler.
        if hasattr(self, "func_handler") and self.func_handler:
            try:
                await self.func_handler.cleanup()
            except Exception as cleanup_error:
                self.logger.bind(tag=TAG).error(
                    f"清理工具处理器时出错: {cleanup_error}"
                )

        # Signal the worker threads to stop.
        if self.stop_event:
            self.stop_event.set()

        # Drain the pending task queues.
        self.clear_queues()

        # Close the WebSocket: the explicit argument wins over the stored one.
        try:
            target = ws if ws else self.websocket
            if target:
                await self._close_ws_quietly(target)
        except Exception as ws_error:
            self.logger.bind(tag=TAG).error(f"关闭WebSocket连接时出错: {ws_error}")

        # Close TTS and ASR independently so that a failure in one does not
        # skip the other or the executor shutdown below.
        if self.tts:
            try:
                await self.tts.close()
            except Exception as tts_error:
                self.logger.bind(tag=TAG).error(f"关闭TTS时出错: {tts_error}")
        if self.asr:
            try:
                await self.asr.close()
            except Exception as asr_error:
                self.logger.bind(tag=TAG).error(f"关闭ASR时出错: {asr_error}")

        # Shut the thread pool down last, without waiting, to avoid blocking.
        if self.executor:
            try:
                self.executor.shutdown(wait=False)
            except Exception as executor_error:
                self.logger.bind(tag=TAG).error(
                    f"关闭线程池时出错: {executor_error}"
                )
            self.executor = None

        self.logger.bind(tag=TAG).info("连接资源已释放")
    except Exception as e:
        self.logger.bind(tag=TAG).error(f"关闭连接时出错: {e}")
    finally:
        # Make sure the stop event is set even if something above failed.
        if self.stop_event:
            self.stop_event.set()

async def _close_ws_quietly(self, ws):
    """Best-effort close of ``ws``; errors are logged at debug level only."""
    try:
        # Every branch of the previous closed/state inspection ended in
        # ``ws.close()``, so the call is made unconditionally here.
        await ws.close()
    except Exception as close_error:
        self.logger.bind(tag=TAG).debug(
            f"关闭WebSocket时忽略异常: {close_error}"
        )
def clear_queues(self):
    """Drain the TTS and report queues and reset the audio rate controller.

    Nothing is done when ``self.tts`` is falsy.
    """
    tts = self.tts
    if not tts:
        return

    self.logger.bind(tag=TAG).debug(
        f"开始清理: TTS队列大小={tts.tts_text_queue.qsize()}, 音频队列大小={tts.tts_audio_queue.qsize()}"
    )

    # Empty each queue without blocking; falsy entries are skipped.
    pending_queues = (tts.tts_text_queue, tts.tts_audio_queue, self.report_queue)
    for pending in pending_queues:
        if not pending:
            continue
        try:
            while True:
                pending.get_nowait()
        except queue.Empty:
            pass

    # Reset the audio rate controller when one is attached.
    rate_controller = getattr(self, "audio_rate_controller", None)
    if rate_controller:
        rate_controller.reset()
        self.logger.bind(tag=TAG).debug("已重置音频流控器")

    self.logger.bind(tag=TAG).debug(
        f"清理结束: TTS队列大小={self.tts.tts_text_queue.qsize()}, 音频队列大小={self.tts.tts_audio_queue.qsize()}"
    )
def reset_audio_states(self):
    """Reset all audio-related state on this connection (VAD and ASR)."""
    # VAD: drop buffered client audio and clear the voice flags and window.
    self.client_audio_buffer.clear()
    self.client_have_voice = self.client_voice_stop = False
    self.client_voice_window.clear()
    self.last_is_voice = False

    # ASR: drop the collected audio.
    self.asr_audio.clear()

    debug_log = self.logger.bind(tag=TAG).debug
    debug_log("All audio states reset.")
def chat_and_close(self, text):
    """Run one chat turn for ``text`` and then flag the connection to close.

    ``close_after_chat`` is set only when ``self.chat`` returns normally;
    any exception is logged and swallowed.
    """
    try:
        # Delegate to the regular chat flow first.
        self.chat(text)
        # Reached only if chat() did not raise.
        self.close_after_chat = True
    except Exception as chat_error:
        error_log = self.logger.bind(tag=TAG).error
        error_log(f"Chat and close error: {str(chat_error)}")
async def _check_timeout(self):
    """Poll for inactivity and close the connection once it times out.

    Every 10 seconds the reference timestamp is compared with the current
    time in milliseconds: ``first_activity_time`` while ``need_bind`` is
    truthy, otherwise ``last_activity_time``. A timestamp that is not
    greater than 0 is treated as uninitialised and skipped.
    """
    try:
        while not self.stop_event.is_set():
            if self.need_bind:
                reference_time = self.first_activity_time
            else:
                reference_time = self.last_activity_time

            # Only evaluate the timeout once the timestamp is initialised.
            if reference_time > 0.0:
                now_ms = time.time() * 1000
                limit_ms = self.timeout_seconds * 1000
                if now_ms - reference_time > limit_ms:
                    if self.stop_event.is_set():
                        # Someone else is already shutting down.
                        break
                    self.logger.bind(tag=TAG).info("连接超时,准备关闭")
                    # Set the stop event first to prevent duplicate handling.
                    self.stop_event.set()
                    # A failing close must not escape from this task.
                    try:
                        await self.close(self.websocket)
                    except Exception as close_error:
                        self.logger.bind(tag=TAG).error(
                            f"超时关闭连接时出错: {close_error}"
                        )
                    break

            # Check every 10 seconds to avoid polling too often.
            await asyncio.sleep(10)
    except Exception as check_error:
        self.logger.bind(tag=TAG).error(f"超时检查任务出错: {check_error}")
    finally:
        self.logger.bind(tag=TAG).info("超时检查任务已退出")
def _merge_tool_calls(self, tool_calls_list, tools_call):
"""合并工具调用列表
Args:
tool_calls_list: 已收集的工具调用列表
tools_call: 新的工具调用
"""
for tool_call in tools_call:
tool_index = getattr(tool_call, "index", None)
if tool_index is None:
if tool_call.function.name:
# 有 function_name说明是新的工具调用
tool_index = len(tool_calls_list)
else:
tool_index = len(tool_calls_list) - 1 if tool_calls_list else 0
# 确保列表有足够的位置
if tool_index >= len(tool_calls_list):
tool_calls_list.append({"id": "", "name": "", "arguments": ""})
# 更新工具调用信息
if tool_call.id:
tool_calls_list[tool_index]["id"] = tool_call.id
if tool_call.function.name:
tool_calls_list[tool_index]["name"] = tool_call.function.name
if tool_call.function.arguments:
tool_calls_list[tool_index]["arguments"] += tool_call.function.arguments