feat: Introduce unified tooling system for plugins and MCP
- Added a new `tooling` module to define a unified model for tool declarations, invocations, and execution results, facilitating compatibility between plugins, legacy actions, and MCP tools. - Implemented `ToolProvider` interface for various tool providers including built-in tools, MCP tools, and plugin runtime tools. - Enhanced `MCPManager` and `MCPConnection` to support unified tool invocation and execution results. - Updated `ComponentRegistry` and related classes to accommodate the new tool specifications and descriptions. - Refactored existing components to utilize the new tooling system, ensuring backward compatibility with legacy actions. - Improved error handling and logging for tool invocations across different providers.
This commit is contained in:
@@ -1,501 +0,0 @@
|
||||
import time
|
||||
from typing import Tuple, Optional # 增加了 Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.config.config import global_config
|
||||
import random
|
||||
from .chat_observer import ChatObserver
|
||||
from .pfc_utils import get_items_from_json
|
||||
from src.services.message_service import build_readable_messages
|
||||
|
||||
from .observation_info import ObservationInfo, dict_to_session_message
|
||||
from .conversation_info import ConversationInfo
|
||||
|
||||
|
||||
logger = get_logger("pfc_action_planner")
|
||||
|
||||
|
||||
# --- 定义 Prompt 模板 ---
|
||||
|
||||
# Prompt(1): 首次回复或非连续回复时的决策 Prompt
|
||||
PROMPT_INITIAL_REPLY = """{persona_text}。现在你在参与一场QQ私聊,请根据以下【所有信息】审慎且灵活的决策下一步行动,可以回复,可以倾听,可以调取知识,甚至可以屏蔽对方:
|
||||
|
||||
【当前对话目标】
|
||||
{goals_str}
|
||||
{knowledge_info_str}
|
||||
|
||||
【最近行动历史概要】
|
||||
{action_history_summary}
|
||||
【上一次行动的详细情况和结果】
|
||||
{last_action_context}
|
||||
【时间和超时提示】
|
||||
{time_since_last_bot_message_info}{timeout_context}
|
||||
【最近的对话记录】(包括你已成功发送的消息 和 新收到的消息)
|
||||
{chat_history_text}
|
||||
|
||||
------
|
||||
可选行动类型以及解释:
|
||||
fetch_knowledge: 需要调取知识或记忆,当需要专业知识或特定信息时选择,对方若提到你不太认识的人名或实体也可以尝试选择
|
||||
listening: 倾听对方发言,当你认为对方话才说到一半,发言明显未结束时选择
|
||||
direct_reply: 直接回复对方
|
||||
rethink_goal: 思考一个对话目标,当你觉得目前对话需要目标,或当前目标不再适用,或话题卡住时选择。注意私聊的环境是灵活的,有可能需要经常选择
|
||||
end_conversation: 结束对话,对方长时间没回复或者当你觉得对话告一段落时可以选择
|
||||
block_and_ignore: 更加极端的结束对话方式,直接结束对话并在一段时间内无视对方所有发言(屏蔽),当对话让你感到十分不适,或你遭到各类骚扰时选择
|
||||
|
||||
请以JSON格式输出你的决策:
|
||||
{{
|
||||
"action": "选择的行动类型 (必须是上面列表中的一个)",
|
||||
"reason": "选择该行动的详细原因 (必须有解释你是如何根据“上一次行动结果”、“对话记录”和自身设定人设做出合理判断的)"
|
||||
}}
|
||||
|
||||
注意:请严格按照JSON格式输出,不要包含任何其他内容。"""
|
||||
|
||||
# Prompt(2): 上一次成功回复后,决定继续发言时的决策 Prompt
|
||||
PROMPT_FOLLOW_UP = """{persona_text}。现在你在参与一场QQ私聊,刚刚你已经回复了对方,请根据以下【所有信息】审慎且灵活的决策下一步行动,可以继续发送新消息,可以等待,可以倾听,可以调取知识,甚至可以屏蔽对方:
|
||||
|
||||
【当前对话目标】
|
||||
{goals_str}
|
||||
{knowledge_info_str}
|
||||
|
||||
【最近行动历史概要】
|
||||
{action_history_summary}
|
||||
【上一次行动的详细情况和结果】
|
||||
{last_action_context}
|
||||
【时间和超时提示】
|
||||
{time_since_last_bot_message_info}{timeout_context}
|
||||
【最近的对话记录】(包括你已成功发送的消息 和 新收到的消息)
|
||||
{chat_history_text}
|
||||
|
||||
------
|
||||
可选行动类型以及解释:
|
||||
fetch_knowledge: 需要调取知识,当需要专业知识或特定信息时选择,对方若提到你不太认识的人名或实体也可以尝试选择
|
||||
wait: 暂时不说话,留给对方交互空间,等待对方回复(尤其是在你刚发言后、或上次发言因重复、发言过多被拒时、或不确定做什么时,这是不错的选择)
|
||||
listening: 倾听对方发言(虽然你刚发过言,但如果对方立刻回复且明显话没说完,可以选择这个)
|
||||
send_new_message: 发送一条新消息继续对话,允许适当的追问、补充、深入话题,或开启相关新话题。**但是避免在因重复被拒后立即使用,也不要在对方没有回复的情况下过多的“消息轰炸”或重复发言**
|
||||
rethink_goal: 思考一个对话目标,当你觉得目前对话需要目标,或当前目标不再适用,或话题卡住时选择。注意私聊的环境是灵活的,有可能需要经常选择
|
||||
end_conversation: 结束对话,对方长时间没回复或者当你觉得对话告一段落时可以选择
|
||||
block_and_ignore: 更加极端的结束对话方式,直接结束对话并在一段时间内无视对方所有发言(屏蔽),当对话让你感到十分不适,或你遭到各类骚扰时选择
|
||||
|
||||
请以JSON格式输出你的决策:
|
||||
{{
|
||||
"action": "选择的行动类型 (必须是上面列表中的一个)",
|
||||
"reason": "选择该行动的详细原因 (必须有解释你是如何根据“上一次行动结果”、“对话记录”和自身设定人设做出合理判断的。请说明你为什么选择继续发言而不是等待,以及打算发送什么类型的新消息连续发言,必须记录已经发言了几次)"
|
||||
}}
|
||||
|
||||
注意:请严格按照JSON格式输出,不要包含任何其他内容。"""
|
||||
|
||||
# 新增:Prompt(3): 决定是否在结束对话前发送告别语
|
||||
PROMPT_END_DECISION = """{persona_text}。刚刚你决定结束一场 QQ 私聊。
|
||||
|
||||
【你们之前的聊天记录】
|
||||
{chat_history_text}
|
||||
|
||||
你觉得你们的对话已经完整结束了吗?有时候,在对话自然结束后再说点什么可能会有点奇怪,但有时也可能需要一条简短的消息来圆满结束。
|
||||
如果觉得确实有必要再发一条简短、自然、符合你人设的告别消息(比如 "好,下次再聊~" 或 "嗯,先这样吧"),就输出 "yes"。
|
||||
如果觉得当前状态下直接结束对话更好,没有必要再发消息,就输出 "no"。
|
||||
|
||||
请以 JSON 格式输出你的选择:
|
||||
{{
|
||||
"say_bye": "yes/no",
|
||||
"reason": "选择 yes 或 no 的原因和内心想法 (简要说明)"
|
||||
}}
|
||||
|
||||
注意:请严格按照 JSON 格式输出,不要包含任何其他内容。"""
|
||||
|
||||
|
||||
# ActionPlanner 类定义,顶格
|
||||
class ActionPlanner:
|
||||
"""行动规划器"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
self.llm = LLMServiceClient(
|
||||
task_name="planner",
|
||||
request_type="action_planning",
|
||||
)
|
||||
self.personality_info = self._get_personality_prompt()
|
||||
self.name = global_config.bot.nickname
|
||||
self.private_name = private_name
|
||||
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
|
||||
# self.action_planner_info = ActionPlannerInfo() # 移除未使用的变量
|
||||
|
||||
def _get_personality_prompt(self) -> str:
|
||||
"""获取个性提示信息"""
|
||||
prompt_personality = global_config.personality.personality
|
||||
|
||||
# 检查是否需要随机替换为状态
|
||||
if (
|
||||
global_config.personality.states
|
||||
and global_config.personality.state_probability > 0
|
||||
and random.random() < global_config.personality.state_probability
|
||||
):
|
||||
prompt_personality = random.choice(global_config.personality.states)
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
return f"你的名字是{bot_name},你{prompt_personality};"
|
||||
|
||||
# 修改 plan 方法签名,增加 last_successful_reply_action 参数
|
||||
async def plan(
|
||||
self,
|
||||
observation_info: ObservationInfo,
|
||||
conversation_info: ConversationInfo,
|
||||
last_successful_reply_action: Optional[str],
|
||||
) -> Tuple[str, str]:
|
||||
"""规划下一步行动
|
||||
|
||||
Args:
|
||||
observation_info: 决策信息
|
||||
conversation_info: 对话信息
|
||||
last_successful_reply_action: 上一次成功的回复动作类型 ('direct_reply' 或 'send_new_message' 或 None)
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: (行动类型, 行动原因)
|
||||
"""
|
||||
# --- 获取 Bot 上次发言时间信息 ---
|
||||
time_since_last_bot_message_info = ""
|
||||
try:
|
||||
bot_id = str(global_config.bot.qq_account)
|
||||
chat_history = getattr(observation_info, "chat_history", None)
|
||||
if chat_history and len(chat_history) > 0:
|
||||
for i in range(len(chat_history) - 1, -1, -1):
|
||||
msg = chat_history[i]
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
sender_info = msg.get("user_info", {})
|
||||
sender_id = str(sender_info.get("user_id")) if isinstance(sender_info, dict) else None
|
||||
msg_time = msg.get("time")
|
||||
if sender_id == bot_id and msg_time:
|
||||
time_diff = time.time() - msg_time
|
||||
if time_diff < 60.0:
|
||||
time_since_last_bot_message_info = (
|
||||
f"提示:你上一条成功发送的消息是在 {time_diff:.1f} 秒前。\n"
|
||||
)
|
||||
break
|
||||
else:
|
||||
logger.debug(f"[私聊][{self.private_name}]聊天历史为空或尚未加载,跳过 Bot 发言时间检查。")
|
||||
except Exception as e:
|
||||
logger.debug(f"[私聊][{self.private_name}]获取 Bot 上次发言时间时出错: {e}")
|
||||
|
||||
# --- 获取超时提示信息 ---
|
||||
# (这部分逻辑不变)
|
||||
timeout_context = ""
|
||||
try:
|
||||
if hasattr(conversation_info, "goal_list") and conversation_info.goal_list:
|
||||
last_goal_dict = conversation_info.goal_list[-1]
|
||||
if isinstance(last_goal_dict, dict) and "goal" in last_goal_dict:
|
||||
last_goal_text = last_goal_dict["goal"]
|
||||
if isinstance(last_goal_text, str) and "分钟,思考接下来要做什么" in last_goal_text:
|
||||
try:
|
||||
timeout_minutes_text = last_goal_text.split(",")[0].replace("你等待了", "")
|
||||
timeout_context = f"重要提示:对方已经长时间({timeout_minutes_text})没有回复你的消息了(这可能代表对方繁忙/不想回复/没注意到你的消息等情况,或在对方看来本次聊天已告一段落),请基于此情况规划下一步。\n"
|
||||
except Exception:
|
||||
timeout_context = "重要提示:对方已经长时间没有回复你的消息了(这可能代表对方繁忙/不想回复/没注意到你的消息等情况,或在对方看来本次聊天已告一段落),请基于此情况规划下一步。\n"
|
||||
else:
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]Conversation info goal_list is empty or not available for timeout check."
|
||||
)
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ConversationInfo object might not have goal_list attribute yet for timeout check."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[私聊][{self.private_name}]检查超时目标时出错: {e}")
|
||||
|
||||
# --- 构建通用 Prompt 参数 ---
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]开始规划行动:当前目标: {getattr(conversation_info, 'goal_list', '不可用')}"
|
||||
)
|
||||
|
||||
# 构建对话目标 (goals_str)
|
||||
goals_str = ""
|
||||
try:
|
||||
if hasattr(conversation_info, "goal_list") and conversation_info.goal_list:
|
||||
for goal_reason in conversation_info.goal_list:
|
||||
if isinstance(goal_reason, dict):
|
||||
goal = goal_reason.get("goal", "目标内容缺失")
|
||||
reasoning = goal_reason.get("reasoning", "没有明确原因")
|
||||
else:
|
||||
goal = str(goal_reason)
|
||||
reasoning = "没有明确原因"
|
||||
|
||||
goal = str(goal) if goal is not None else "目标内容缺失"
|
||||
reasoning = str(reasoning) if reasoning is not None else "没有明确原因"
|
||||
goals_str += f"- 目标:{goal}\n 原因:{reasoning}\n"
|
||||
|
||||
if not goals_str:
|
||||
goals_str = "- 目前没有明确对话目标,请考虑设定一个。\n"
|
||||
else:
|
||||
goals_str = "- 目前没有明确对话目标,请考虑设定一个。\n"
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ConversationInfo object might not have goal_list attribute yet."
|
||||
)
|
||||
goals_str = "- 获取对话目标时出错。\n"
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]构建对话目标字符串时出错: {e}")
|
||||
goals_str = "- 构建对话目标时出错。\n"
|
||||
|
||||
# --- 知识信息字符串构建开始 ---
|
||||
knowledge_info_str = "【已获取的相关知识和记忆】\n"
|
||||
try:
|
||||
# 检查 conversation_info 是否有 knowledge_list 并且不为空
|
||||
if hasattr(conversation_info, "knowledge_list") and conversation_info.knowledge_list:
|
||||
# 最多只显示最近的 5 条知识,防止 Prompt 过长
|
||||
recent_knowledge = conversation_info.knowledge_list[-5:]
|
||||
for i, knowledge_item in enumerate(recent_knowledge):
|
||||
if isinstance(knowledge_item, dict):
|
||||
query = knowledge_item.get("query", "未知查询")
|
||||
knowledge = knowledge_item.get("knowledge", "无知识内容")
|
||||
source = knowledge_item.get("source", "未知来源")
|
||||
# 只取知识内容的前 2000 个字,避免太长
|
||||
knowledge_snippet = knowledge[:2000] + "..." if len(knowledge) > 2000 else knowledge
|
||||
knowledge_info_str += (
|
||||
f"{i + 1}. 关于 '{query}' 的知识 (来源: {source}):\n {knowledge_snippet}\n"
|
||||
)
|
||||
else:
|
||||
# 处理列表里不是字典的异常情况
|
||||
knowledge_info_str += f"{i + 1}. 发现一条格式不正确的知识记录。\n"
|
||||
|
||||
if not recent_knowledge: # 如果 knowledge_list 存在但为空
|
||||
knowledge_info_str += "- 暂无相关知识和记忆。\n"
|
||||
|
||||
else:
|
||||
# 如果 conversation_info 没有 knowledge_list 属性,或者列表为空
|
||||
knowledge_info_str += "- 暂无相关知识记忆。\n"
|
||||
except AttributeError:
|
||||
logger.warning(f"[私聊][{self.private_name}]ConversationInfo 对象可能缺少 knowledge_list 属性。")
|
||||
knowledge_info_str += "- 获取知识列表时出错。\n"
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]构建知识信息字符串时出错: {e}")
|
||||
knowledge_info_str += "- 处理知识列表时出错。\n"
|
||||
# --- 知识信息字符串构建结束 ---
|
||||
|
||||
# 获取聊天历史记录 (chat_history_text)
|
||||
try:
|
||||
if hasattr(observation_info, "chat_history") and observation_info.chat_history:
|
||||
chat_history_text = observation_info.chat_history_str or "还没有聊天记录。\n"
|
||||
else:
|
||||
chat_history_text = "还没有聊天记录。\n"
|
||||
|
||||
if hasattr(observation_info, "new_messages_count") and observation_info.new_messages_count > 0:
|
||||
if hasattr(observation_info, "unprocessed_messages") and observation_info.unprocessed_messages:
|
||||
new_messages_list = observation_info.unprocessed_messages
|
||||
# Convert dict format to SessionMessage objects.
|
||||
session_messages = [dict_to_session_message(m) for m in new_messages_list]
|
||||
new_messages_str = build_readable_messages(
|
||||
session_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
)
|
||||
chat_history_text += (
|
||||
f"\n--- 以下是 {observation_info.new_messages_count} 条新消息 ---\n{new_messages_str}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ObservationInfo has new_messages_count > 0 but unprocessed_messages is empty or missing."
|
||||
)
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ObservationInfo object might be missing expected attributes for chat history."
|
||||
)
|
||||
chat_history_text = "获取聊天记录时出错。\n"
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]处理聊天记录时发生未知错误: {e}")
|
||||
chat_history_text = "处理聊天记录时出错。\n"
|
||||
|
||||
# 构建 Persona 文本 (persona_text)
|
||||
persona_text = f"你的名字是{self.name},{self.personality_info}。"
|
||||
|
||||
# 构建行动历史和上一次行动结果 (action_history_summary, last_action_context)
|
||||
# (这部分逻辑不变)
|
||||
action_history_summary = "你最近执行的行动历史:\n"
|
||||
last_action_context = "关于你【上一次尝试】的行动:\n"
|
||||
action_history_list = []
|
||||
try:
|
||||
if hasattr(conversation_info, "done_action") and conversation_info.done_action:
|
||||
action_history_list = conversation_info.done_action[-5:]
|
||||
else:
|
||||
logger.debug(f"[私聊][{self.private_name}]Conversation info done_action is empty or not available.")
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ConversationInfo object might not have done_action attribute yet."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]访问行动历史时出错: {e}")
|
||||
|
||||
if not action_history_list:
|
||||
action_history_summary += "- 还没有执行过行动。\n"
|
||||
last_action_context += "- 这是你规划的第一个行动。\n"
|
||||
else:
|
||||
for i, action_data in enumerate(action_history_list):
|
||||
action_type = "未知"
|
||||
plan_reason = "未知"
|
||||
status = "未知"
|
||||
final_reason = ""
|
||||
action_time = ""
|
||||
|
||||
if isinstance(action_data, dict):
|
||||
action_type = action_data.get("action", "未知")
|
||||
plan_reason = action_data.get("plan_reason", "未知规划原因")
|
||||
status = action_data.get("status", "未知")
|
||||
final_reason = action_data.get("final_reason", "")
|
||||
action_time = action_data.get("time", "")
|
||||
elif isinstance(action_data, tuple):
|
||||
# 假设旧格式兼容
|
||||
if len(action_data) > 0:
|
||||
action_type = action_data[0]
|
||||
if len(action_data) > 1:
|
||||
plan_reason = action_data[1] # 可能是规划原因或最终原因
|
||||
if len(action_data) > 2:
|
||||
status = action_data[2]
|
||||
if status == "recall" and len(action_data) > 3:
|
||||
final_reason = action_data[3]
|
||||
elif status == "done" and action_type in ["direct_reply", "send_new_message"]:
|
||||
plan_reason = "成功发送" # 简化显示
|
||||
|
||||
reason_text = f", 失败/取消原因: {final_reason}" if final_reason else ""
|
||||
summary_line = f"- 时间:{action_time}, 尝试行动:'{action_type}', 状态:{status}{reason_text}"
|
||||
action_history_summary += summary_line + "\n"
|
||||
|
||||
if i == len(action_history_list) - 1:
|
||||
last_action_context += f"- 上次【规划】的行动是: '{action_type}'\n"
|
||||
last_action_context += f"- 当时规划的【原因】是: {plan_reason}\n"
|
||||
if status == "done":
|
||||
last_action_context += "- 该行动已【成功执行】。\n"
|
||||
# 记录这次成功的行动类型,供下次决策
|
||||
# self.last_successful_action_type = action_type # 不在这里记录,由 conversation 控制
|
||||
elif status == "recall":
|
||||
last_action_context += "- 但该行动最终【未能执行/被取消】。\n"
|
||||
if final_reason:
|
||||
last_action_context += f"- 【重要】失败/取消的具体原因是: “{final_reason}”\n"
|
||||
else:
|
||||
last_action_context += "- 【重要】失败/取消原因未明确记录。\n"
|
||||
# self.last_successful_action_type = None # 行动失败,清除记录
|
||||
else:
|
||||
last_action_context += f"- 该行动当前状态: {status}\n"
|
||||
# self.last_successful_action_type = None # 非完成状态,清除记录
|
||||
|
||||
# --- 选择 Prompt ---
|
||||
if last_successful_reply_action in ["direct_reply", "send_new_message"]:
|
||||
prompt_template = PROMPT_FOLLOW_UP
|
||||
logger.debug(f"[私聊][{self.private_name}]使用 PROMPT_FOLLOW_UP (追问决策)")
|
||||
else:
|
||||
prompt_template = PROMPT_INITIAL_REPLY
|
||||
logger.debug(f"[私聊][{self.private_name}]使用 PROMPT_INITIAL_REPLY (首次/非连续回复决策)")
|
||||
|
||||
# --- 格式化最终的 Prompt ---
|
||||
prompt = prompt_template.format(
|
||||
persona_text=persona_text,
|
||||
goals_str=goals_str if goals_str.strip() else "- 目前没有明确对话目标,请考虑设定一个。",
|
||||
action_history_summary=action_history_summary,
|
||||
last_action_context=last_action_context,
|
||||
time_since_last_bot_message_info=time_since_last_bot_message_info,
|
||||
timeout_context=timeout_context,
|
||||
chat_history_text=chat_history_text if chat_history_text.strip() else "还没有聊天记录。",
|
||||
knowledge_info_str=knowledge_info_str,
|
||||
)
|
||||
|
||||
logger.debug(f"[私聊][{self.private_name}]发送到LLM的最终提示词:\n------\n{prompt}\n------")
|
||||
try:
|
||||
generation_result = await self.llm.generate_response(prompt)
|
||||
content = generation_result.response
|
||||
logger.debug(f"[私聊][{self.private_name}]LLM (行动规划) 原始返回内容: {content}")
|
||||
|
||||
# --- 初始行动规划解析 ---
|
||||
success, initial_result = get_items_from_json(
|
||||
content,
|
||||
self.private_name,
|
||||
"action",
|
||||
"reason",
|
||||
default_values={"action": "wait", "reason": "LLM返回格式错误或未提供原因,默认等待"},
|
||||
)
|
||||
|
||||
initial_action = initial_result.get("action", "wait")
|
||||
initial_reason = initial_result.get("reason", "LLM未提供原因,默认等待")
|
||||
|
||||
# 检查是否需要进行结束对话决策 ---
|
||||
if initial_action == "end_conversation":
|
||||
logger.info(f"[私聊][{self.private_name}]初步规划结束对话,进入告别决策...")
|
||||
|
||||
# 使用新的 PROMPT_END_DECISION
|
||||
end_decision_prompt = PROMPT_END_DECISION.format(
|
||||
persona_text=persona_text, # 复用之前的 persona_text
|
||||
chat_history_text=chat_history_text, # 复用之前的 chat_history_text
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]发送到LLM的结束决策提示词:\n------\n{end_decision_prompt}\n------"
|
||||
)
|
||||
try:
|
||||
end_generation_result = await self.llm.generate_response(end_decision_prompt)
|
||||
end_content = end_generation_result.response # 再次调用LLM
|
||||
logger.debug(f"[私聊][{self.private_name}]LLM (结束决策) 原始返回内容: {end_content}")
|
||||
|
||||
# 解析结束决策的JSON
|
||||
end_success, end_result = get_items_from_json(
|
||||
end_content,
|
||||
self.private_name,
|
||||
"say_bye",
|
||||
"reason",
|
||||
default_values={"say_bye": "no", "reason": "结束决策LLM返回格式错误,默认不告别"},
|
||||
required_types={"say_bye": str, "reason": str}, # 明确类型
|
||||
)
|
||||
|
||||
say_bye_decision = end_result.get("say_bye", "no").lower() # 转小写方便比较
|
||||
end_decision_reason = end_result.get("reason", "未提供原因")
|
||||
|
||||
if end_success and say_bye_decision == "yes":
|
||||
# 决定要告别,返回新的 'say_goodbye' 动作
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]结束决策: yes, 准备生成告别语. 原因: {end_decision_reason}"
|
||||
)
|
||||
# 注意:这里的 reason 可以考虑拼接初始原因和结束决策原因,或者只用结束决策原因
|
||||
final_action = "say_goodbye"
|
||||
final_reason = f"决定发送告别语。决策原因: {end_decision_reason} (原结束理由: {initial_reason})"
|
||||
return final_action, final_reason
|
||||
else:
|
||||
# 决定不告别 (包括解析失败或明确说no)
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]结束决策: no, 直接结束对话. 原因: {end_decision_reason}"
|
||||
)
|
||||
# 返回原始的 'end_conversation' 动作
|
||||
final_action = "end_conversation"
|
||||
final_reason = initial_reason # 保持原始的结束理由
|
||||
return final_action, final_reason
|
||||
|
||||
except Exception as end_e:
|
||||
logger.error(f"[私聊][{self.private_name}]调用结束决策LLM或处理结果时出错: {str(end_e)}")
|
||||
# 出错时,默认执行原始的结束对话
|
||||
logger.warning(f"[私聊][{self.private_name}]结束决策出错,将按原计划执行 end_conversation")
|
||||
return "end_conversation", initial_reason # 返回原始动作和原因
|
||||
|
||||
else:
|
||||
action = initial_action
|
||||
reason = initial_reason
|
||||
|
||||
# 验证action类型 (保持不变)
|
||||
valid_actions = [
|
||||
"direct_reply",
|
||||
"send_new_message",
|
||||
"fetch_knowledge",
|
||||
"wait",
|
||||
"listening",
|
||||
"rethink_goal",
|
||||
"end_conversation", # 仍然需要验证,因为可能从上面决策后返回
|
||||
"block_and_ignore",
|
||||
"say_goodbye", # 也要验证这个新动作
|
||||
]
|
||||
if action not in valid_actions:
|
||||
logger.warning(f"[私聊][{self.private_name}]LLM返回了未知的行动类型: '{action}',强制改为 wait")
|
||||
reason = f"(原始行动'{action}'无效,已强制改为wait) {reason}"
|
||||
action = "wait"
|
||||
|
||||
logger.info(f"[私聊][{self.private_name}]规划的行动: {action}")
|
||||
logger.info(f"[私聊][{self.private_name}]行动原因: {reason}")
|
||||
return action, reason
|
||||
|
||||
except Exception as e:
|
||||
# 外层异常处理保持不变
|
||||
logger.error(f"[私聊][{self.private_name}]规划行动时调用 LLM 或处理结果出错: {str(e)}")
|
||||
return "wait", f"行动规划处理中发生错误,暂时等待: {str(e)}"
|
||||
@@ -1,394 +0,0 @@
|
||||
import time
|
||||
import asyncio
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
from src.common.logger import get_logger
|
||||
from sqlmodel import select, col
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Messages
|
||||
from maim_message import UserInfo
|
||||
from src.config.config import global_config
|
||||
from .chat_states import NotificationManager, create_new_message_notification, create_cold_chat_notification
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("chat_observer")
|
||||
|
||||
|
||||
def _message_to_dict(message: Messages) -> Dict[str, Any]:
|
||||
"""Convert Peewee Message model to dict for PFC compatibility
|
||||
|
||||
Args:
|
||||
message: Peewee Messages model instance
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Message dictionary
|
||||
"""
|
||||
message_timestamp = message.timestamp.timestamp() if isinstance(message.timestamp, datetime) else message.timestamp
|
||||
return {
|
||||
"message_id": message.message_id,
|
||||
"time": message_timestamp,
|
||||
"chat_id": message.session_id,
|
||||
"user_id": message.user_id,
|
||||
"user_nickname": message.user_nickname,
|
||||
"processed_plain_text": message.processed_plain_text,
|
||||
"display_message": message.display_message,
|
||||
"is_mentioned": message.is_mentioned,
|
||||
"is_command": message.is_command,
|
||||
# Add user_info dict for compatibility with existing code
|
||||
"user_info": {
|
||||
"user_id": message.user_id,
|
||||
"user_nickname": message.user_nickname,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class ChatObserver:
|
||||
"""聊天状态观察器"""
|
||||
|
||||
# 类级别的实例管理
|
||||
_instances: Dict[str, "ChatObserver"] = {}
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls, stream_id: str, private_name: str) -> "ChatObserver":
|
||||
"""获取或创建观察器实例
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
private_name: 私聊名称
|
||||
|
||||
Returns:
|
||||
ChatObserver: 观察器实例
|
||||
"""
|
||||
if stream_id not in cls._instances:
|
||||
cls._instances[stream_id] = cls(stream_id, private_name)
|
||||
return cls._instances[stream_id]
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
"""初始化观察器
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
"""
|
||||
self.last_check_time = None
|
||||
self.last_bot_speak_time = None
|
||||
self.last_user_speak_time = None
|
||||
if stream_id in self._instances:
|
||||
raise RuntimeError(f"ChatObserver for {stream_id} already exists. Use get_instance() instead.")
|
||||
|
||||
self.stream_id = stream_id
|
||||
self.private_name = private_name
|
||||
|
||||
self.last_message_read: Optional[Dict[str, Any]] = None
|
||||
self.last_message_time: float = time.time()
|
||||
|
||||
self.waiting_start_time: float = time.time() # 等待开始时间,初始化为当前时间
|
||||
|
||||
# 运行状态
|
||||
self._running: bool = False
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._update_event = asyncio.Event() # 触发更新的事件
|
||||
self._update_complete = asyncio.Event() # 更新完成的事件
|
||||
|
||||
# 通知管理器
|
||||
self.notification_manager = NotificationManager()
|
||||
|
||||
# 冷场检查配置
|
||||
self.cold_chat_threshold: float = 60.0 # 60秒无消息判定为冷场
|
||||
self.last_cold_chat_check: float = time.time()
|
||||
self.is_cold_chat_state: bool = False
|
||||
|
||||
self.update_event = asyncio.Event()
|
||||
self.update_interval = 2 # 更新间隔(秒)
|
||||
self.message_cache = []
|
||||
self.update_running = False
|
||||
|
||||
async def check(self) -> bool:
|
||||
"""检查距离上一次观察之后是否有了新消息
|
||||
|
||||
Returns:
|
||||
bool: 是否有新消息
|
||||
"""
|
||||
logger.debug(f"[私聊][{self.private_name}]检查距离上一次观察之后是否有了新消息: {self.last_check_time}")
|
||||
|
||||
last_check_time = self.last_check_time or 0.0
|
||||
last_check_dt = datetime.fromtimestamp(last_check_time)
|
||||
with get_db_session() as session:
|
||||
statement = select(Messages).where(
|
||||
(col(Messages.session_id) == self.stream_id) & (col(Messages.timestamp) > last_check_dt)
|
||||
)
|
||||
new_message_exists = session.exec(statement).first() is not None
|
||||
|
||||
if new_message_exists:
|
||||
logger.debug(f"[私聊][{self.private_name}]发现新消息")
|
||||
self.last_check_time = time.time()
|
||||
|
||||
return new_message_exists
|
||||
|
||||
async def _add_message_to_history(self, message: Dict[str, Any]):
|
||||
"""添加消息到历史记录并发送通知
|
||||
|
||||
Args:
|
||||
message: 消息数据
|
||||
"""
|
||||
try:
|
||||
# 发送新消息通知
|
||||
notification = create_new_message_notification(
|
||||
sender="chat_observer", target="observation_info", message=message
|
||||
)
|
||||
# print(self.notification_manager)
|
||||
await self.notification_manager.send_notification(notification)
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]添加消息到历史记录时出错: {e}")
|
||||
print(traceback.format_exc())
|
||||
|
||||
# 检查并更新冷场状态
|
||||
await self._check_cold_chat()
|
||||
|
||||
async def _check_cold_chat(self):
|
||||
"""检查是否处于冷场状态并发送通知"""
|
||||
current_time = time.time()
|
||||
|
||||
# 每10秒检查一次冷场状态
|
||||
if current_time - self.last_cold_chat_check < 10:
|
||||
return
|
||||
|
||||
self.last_cold_chat_check = current_time
|
||||
|
||||
# 判断是否冷场
|
||||
is_cold = (
|
||||
True
|
||||
if self.last_message_time is None
|
||||
else (current_time - self.last_message_time) > self.cold_chat_threshold
|
||||
)
|
||||
|
||||
# 如果冷场状态发生变化,发送通知
|
||||
if is_cold != self.is_cold_chat_state:
|
||||
self.is_cold_chat_state = is_cold
|
||||
notification = create_cold_chat_notification(sender="chat_observer", target="pfc", is_cold=is_cold)
|
||||
await self.notification_manager.send_notification(notification)
|
||||
|
||||
def new_message_after(self, time_point: float) -> bool:
|
||||
"""判断是否在指定时间点后有新消息
|
||||
|
||||
Args:
|
||||
time_point: 时间戳
|
||||
|
||||
Returns:
|
||||
bool: 是否有新消息
|
||||
"""
|
||||
|
||||
if self.last_message_time is None:
|
||||
logger.debug(f"[私聊][{self.private_name}]没有最后消息时间,返回 False")
|
||||
return False
|
||||
|
||||
has_new = self.last_message_time > time_point
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]判断是否在指定时间点后有新消息: {self.last_message_time} > {time_point} = {has_new}"
|
||||
)
|
||||
return has_new
|
||||
|
||||
async def _fetch_new_messages(self) -> List[Dict[str, Any]]:
|
||||
"""获取新消息
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 新消息列表
|
||||
"""
|
||||
last_message_time = self.last_message_time or 0.0
|
||||
last_message_dt = datetime.fromtimestamp(last_message_time)
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(Messages)
|
||||
.where((col(Messages.session_id) == self.stream_id) & (col(Messages.timestamp) > last_message_dt))
|
||||
.order_by(col(Messages.timestamp))
|
||||
)
|
||||
new_messages = [_message_to_dict(msg) for msg in session.exec(statement).all()]
|
||||
|
||||
if new_messages:
|
||||
self.last_message_read = new_messages[-1]
|
||||
self.last_message_time = new_messages[-1]["time"]
|
||||
|
||||
# print(f"获取数据库中找到的新消息: {new_messages}")
|
||||
|
||||
return new_messages
|
||||
|
||||
async def _fetch_new_messages_before(self, time_point: float) -> List[Dict[str, Any]]:
|
||||
"""获取指定时间点之前的消息
|
||||
|
||||
Args:
|
||||
time_point: 时间戳
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 最多5条消息
|
||||
"""
|
||||
time_point_dt = datetime.fromtimestamp(time_point)
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(Messages)
|
||||
.where((col(Messages.session_id) == self.stream_id) & (col(Messages.timestamp) < time_point_dt))
|
||||
.order_by(col(Messages.timestamp))
|
||||
.limit(5)
|
||||
)
|
||||
messages = list(session.exec(statement).all())
|
||||
messages.reverse()
|
||||
new_messages = [_message_to_dict(msg) for msg in messages]
|
||||
|
||||
if new_messages:
|
||||
self.last_message_read = new_messages[-1]["message_id"]
|
||||
|
||||
logger.debug(f"[私聊][{self.private_name}]获取指定时间点111之前的消息: {new_messages}")
|
||||
|
||||
return new_messages
|
||||
|
||||
"""主要观察循环"""
|
||||
|
||||
async def _update_loop(self):
|
||||
"""更新循环"""
|
||||
# try:
|
||||
# start_time = time.time()
|
||||
# messages = await self._fetch_new_messages_before(start_time)
|
||||
# for message in messages:
|
||||
# await self._add_message_to_history(message)
|
||||
# logger.debug(f"[私聊][{self.private_name}]缓冲消息: {messages}")
|
||||
# except Exception as e:
|
||||
# logger.error(f"[私聊][{self.private_name}]缓冲消息出错: {e}")
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
# 等待事件或超时(1秒)
|
||||
try:
|
||||
# print("等待事件")
|
||||
await asyncio.wait_for(self._update_event.wait(), timeout=1)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# print("超时")
|
||||
pass # 超时后也执行一次检查
|
||||
|
||||
self._update_event.clear() # 重置触发事件
|
||||
self._update_complete.clear() # 重置完成事件
|
||||
|
||||
# 获取新消息
|
||||
new_messages = await self._fetch_new_messages()
|
||||
|
||||
if new_messages:
|
||||
# 处理新消息
|
||||
for message in new_messages:
|
||||
await self._add_message_to_history(message)
|
||||
|
||||
# 设置完成事件
|
||||
self._update_complete.set()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]更新循环出错: {e}")
|
||||
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
|
||||
self._update_complete.set() # 即使出错也要设置完成事件
|
||||
|
||||
def trigger_update(self):
|
||||
"""触发一次立即更新"""
|
||||
self._update_event.set()
|
||||
|
||||
async def wait_for_update(self, timeout: float = 5.0) -> bool:
|
||||
"""等待更新完成
|
||||
|
||||
Args:
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
bool: 是否成功完成更新(False表示超时)
|
||||
"""
|
||||
try:
|
||||
await asyncio.wait_for(self._update_complete.wait(), timeout=timeout)
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"[私聊][{self.private_name}]等待更新完成超时({timeout}秒)")
|
||||
return False
|
||||
|
||||
def start(self):
|
||||
"""启动观察器"""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._task = asyncio.create_task(self._update_loop())
|
||||
logger.debug(f"[私聊][{self.private_name}]ChatObserver for {self.stream_id} started")
|
||||
|
||||
def stop(self):
|
||||
"""停止观察器"""
|
||||
self._running = False
|
||||
self._update_event.set() # 设置事件以解除等待
|
||||
self._update_complete.set() # 设置完成事件以解除等待
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
logger.debug(f"[私聊][{self.private_name}]ChatObserver for {self.stream_id} stopped")
|
||||
|
||||
async def process_chat_history(self, messages: list):
|
||||
"""处理聊天历史
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
"""
|
||||
self.update_check_time()
|
||||
|
||||
for msg in messages:
|
||||
try:
|
||||
user_info = UserInfo.from_dict(msg.get("user_info", {}))
|
||||
if user_info.user_id == global_config.bot.qq_account:
|
||||
self.update_bot_speak_time(msg["time"])
|
||||
else:
|
||||
self.update_user_speak_time(msg["time"])
|
||||
except Exception as e:
|
||||
logger.warning(f"[私聊][{self.private_name}]处理消息时间时出错: {e}")
|
||||
continue
|
||||
|
||||
def update_check_time(self):
|
||||
"""更新查看时间"""
|
||||
self.last_check_time = time.time()
|
||||
|
||||
def update_bot_speak_time(self, speak_time: Optional[float] = None):
|
||||
"""更新机器人说话时间"""
|
||||
self.last_bot_speak_time = speak_time or time.time()
|
||||
|
||||
def update_user_speak_time(self, speak_time: Optional[float] = None):
|
||||
"""更新用户说话时间"""
|
||||
self.last_user_speak_time = speak_time or time.time()
|
||||
|
||||
def get_time_info(self) -> str:
|
||||
"""获取时间信息文本"""
|
||||
current_time = time.time()
|
||||
time_info = ""
|
||||
|
||||
if self.last_bot_speak_time:
|
||||
bot_speak_ago = current_time - self.last_bot_speak_time
|
||||
time_info += f"\n距离你上次发言已经过去了{int(bot_speak_ago)}秒"
|
||||
|
||||
if self.last_user_speak_time:
|
||||
user_speak_ago = current_time - self.last_user_speak_time
|
||||
time_info += f"\n距离对方上次发言已经过去了{int(user_speak_ago)}秒"
|
||||
|
||||
return time_info
|
||||
|
||||
def get_cached_messages(self, limit: int = 50) -> List[Dict[str, Any]]:
|
||||
"""获取缓存的消息历史
|
||||
|
||||
Args:
|
||||
limit: 获取的最大消息数量,默认50
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 缓存的消息历史列表
|
||||
"""
|
||||
return self.message_cache[-limit:]
|
||||
|
||||
def get_last_message(self) -> Optional[Dict[str, Any]]:
|
||||
"""获取最后一条消息
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 最后一条消息,如果没有则返回None
|
||||
"""
|
||||
if not self.message_cache:
|
||||
return None
|
||||
return self.message_cache[-1]
|
||||
|
||||
def __str__(self):
|
||||
return f"ChatObserver for {self.stream_id}"
|
||||
@@ -1,290 +0,0 @@
|
||||
from enum import Enum, auto
|
||||
from typing import Optional, Dict, Any, List, Set
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class ChatState(Enum):
|
||||
"""聊天状态枚举"""
|
||||
|
||||
NORMAL = auto() # 正常状态
|
||||
NEW_MESSAGE = auto() # 有新消息
|
||||
COLD_CHAT = auto() # 冷场状态
|
||||
ACTIVE_CHAT = auto() # 活跃状态
|
||||
BOT_SPEAKING = auto() # 机器人正在说话
|
||||
USER_SPEAKING = auto() # 用户正在说话
|
||||
SILENT = auto() # 沉默状态
|
||||
ERROR = auto() # 错误状态
|
||||
|
||||
|
||||
class NotificationType(Enum):
|
||||
"""通知类型枚举"""
|
||||
|
||||
NEW_MESSAGE = auto() # 新消息通知
|
||||
COLD_CHAT = auto() # 冷场通知
|
||||
ACTIVE_CHAT = auto() # 活跃通知
|
||||
BOT_SPEAKING = auto() # 机器人说话通知
|
||||
USER_SPEAKING = auto() # 用户说话通知
|
||||
MESSAGE_DELETED = auto() # 消息删除通知
|
||||
USER_JOINED = auto() # 用户加入通知
|
||||
USER_LEFT = auto() # 用户离开通知
|
||||
ERROR = auto() # 错误通知
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatStateInfo:
|
||||
"""聊天状态信息"""
|
||||
|
||||
state: ChatState
|
||||
last_message_time: Optional[float] = None
|
||||
last_message_content: Optional[str] = None
|
||||
last_speaker: Optional[str] = None
|
||||
message_count: int = 0
|
||||
cold_duration: float = 0.0 # 冷场持续时间(秒)
|
||||
active_duration: float = 0.0 # 活跃持续时间(秒)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Notification:
|
||||
"""通知基类"""
|
||||
|
||||
type: NotificationType
|
||||
timestamp: float
|
||||
sender: str # 发送者标识
|
||||
target: str # 接收者标识
|
||||
data: Dict[str, Any]
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典格式"""
|
||||
return {"type": self.type.name, "timestamp": self.timestamp, "data": self.data}
|
||||
|
||||
|
||||
@dataclass
|
||||
class StateNotification(Notification):
|
||||
"""持续状态通知"""
|
||||
|
||||
is_active: bool = True
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
base_dict = super().to_dict()
|
||||
base_dict["is_active"] = self.is_active
|
||||
return base_dict
|
||||
|
||||
|
||||
class NotificationHandler(ABC):
|
||||
"""通知处理器接口"""
|
||||
|
||||
@abstractmethod
|
||||
async def handle_notification(self, notification: Notification):
|
||||
"""处理通知"""
|
||||
pass
|
||||
|
||||
|
||||
class NotificationManager:
|
||||
"""通知管理器"""
|
||||
|
||||
def __init__(self):
|
||||
# 按接收者和通知类型存储处理器
|
||||
self._handlers: Dict[str, Dict[NotificationType, List[NotificationHandler]]] = {}
|
||||
self._active_states: Set[NotificationType] = set()
|
||||
self._notification_history: List[Notification] = []
|
||||
|
||||
def register_handler(self, target: str, notification_type: NotificationType, handler: NotificationHandler):
|
||||
"""注册通知处理器
|
||||
|
||||
Args:
|
||||
target: 接收者标识(例如:"pfc")
|
||||
notification_type: 要处理的通知类型
|
||||
handler: 处理器实例
|
||||
"""
|
||||
if target not in self._handlers:
|
||||
self._handlers[target] = {}
|
||||
if notification_type not in self._handlers[target]:
|
||||
self._handlers[target][notification_type] = []
|
||||
# print(self._handlers[target][notification_type])
|
||||
self._handlers[target][notification_type].append(handler)
|
||||
# print(self._handlers[target][notification_type])
|
||||
|
||||
def unregister_handler(self, target: str, notification_type: NotificationType, handler: NotificationHandler):
|
||||
"""注销通知处理器
|
||||
|
||||
Args:
|
||||
target: 接收者标识
|
||||
notification_type: 通知类型
|
||||
handler: 要注销的处理器实例
|
||||
"""
|
||||
if target in self._handlers and notification_type in self._handlers[target]:
|
||||
handlers = self._handlers[target][notification_type]
|
||||
if handler in handlers:
|
||||
handlers.remove(handler)
|
||||
# 如果该类型的处理器列表为空,删除该类型
|
||||
if not handlers:
|
||||
del self._handlers[target][notification_type]
|
||||
# 如果该目标没有任何处理器,删除该目标
|
||||
if not self._handlers[target]:
|
||||
del self._handlers[target]
|
||||
|
||||
async def send_notification(self, notification: Notification):
|
||||
"""发送通知"""
|
||||
self._notification_history.append(notification)
|
||||
|
||||
# 如果是状态通知,更新活跃状态
|
||||
if isinstance(notification, StateNotification):
|
||||
if notification.is_active:
|
||||
self._active_states.add(notification.type)
|
||||
else:
|
||||
self._active_states.discard(notification.type)
|
||||
|
||||
# 调用目标接收者的处理器
|
||||
target = notification.target
|
||||
if target in self._handlers:
|
||||
handlers = self._handlers[target].get(notification.type, [])
|
||||
# print(handlers)
|
||||
for handler in handlers:
|
||||
# print(f"调用处理器: {handler}")
|
||||
await handler.handle_notification(notification)
|
||||
|
||||
def get_active_states(self) -> Set[NotificationType]:
|
||||
"""获取当前活跃的状态"""
|
||||
return self._active_states.copy()
|
||||
|
||||
def is_state_active(self, state_type: NotificationType) -> bool:
|
||||
"""检查特定状态是否活跃"""
|
||||
return state_type in self._active_states
|
||||
|
||||
def get_notification_history(
|
||||
self, sender: Optional[str] = None, target: Optional[str] = None, limit: Optional[int] = None
|
||||
) -> List[Notification]:
|
||||
"""获取通知历史
|
||||
|
||||
Args:
|
||||
sender: 过滤特定发送者的通知
|
||||
target: 过滤特定接收者的通知
|
||||
limit: 限制返回数量
|
||||
"""
|
||||
history = self._notification_history
|
||||
|
||||
if sender:
|
||||
history = [n for n in history if n.sender == sender]
|
||||
if target:
|
||||
history = [n for n in history if n.target == target]
|
||||
|
||||
if limit is not None:
|
||||
history = history[-limit:]
|
||||
|
||||
return history
|
||||
|
||||
def __str__(self):
|
||||
str = ""
|
||||
for target, handlers in self._handlers.items():
|
||||
for notification_type, handler_list in handlers.items():
|
||||
str += f"NotificationManager for {target} {notification_type} {handler_list}"
|
||||
return str
|
||||
|
||||
|
||||
# 一些常用的通知创建函数
|
||||
def create_new_message_notification(sender: str, target: str, message: Dict[str, Any]) -> Notification:
|
||||
"""创建新消息通知"""
|
||||
return Notification(
|
||||
type=NotificationType.NEW_MESSAGE,
|
||||
timestamp=datetime.now().timestamp(),
|
||||
sender=sender,
|
||||
target=target,
|
||||
data={
|
||||
"message_id": message.get("message_id"),
|
||||
"processed_plain_text": message.get("processed_plain_text"),
|
||||
"detailed_plain_text": message.get("detailed_plain_text"),
|
||||
"user_info": message.get("user_info"),
|
||||
"time": message.get("time"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def create_cold_chat_notification(sender: str, target: str, is_cold: bool) -> StateNotification:
|
||||
"""创建冷场状态通知"""
|
||||
return StateNotification(
|
||||
type=NotificationType.COLD_CHAT,
|
||||
timestamp=datetime.now().timestamp(),
|
||||
sender=sender,
|
||||
target=target,
|
||||
data={"is_cold": is_cold},
|
||||
is_active=is_cold,
|
||||
)
|
||||
|
||||
|
||||
def create_active_chat_notification(sender: str, target: str, is_active: bool) -> StateNotification:
|
||||
"""创建活跃状态通知"""
|
||||
return StateNotification(
|
||||
type=NotificationType.ACTIVE_CHAT,
|
||||
timestamp=datetime.now().timestamp(),
|
||||
sender=sender,
|
||||
target=target,
|
||||
data={"is_active": is_active},
|
||||
is_active=is_active,
|
||||
)
|
||||
|
||||
|
||||
class ChatStateManager:
|
||||
"""聊天状态管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.current_state = ChatState.NORMAL
|
||||
self.state_info = ChatStateInfo(state=ChatState.NORMAL)
|
||||
self.state_history: list[ChatStateInfo] = []
|
||||
|
||||
def update_state(self, new_state: ChatState, **kwargs):
|
||||
"""更新聊天状态
|
||||
|
||||
Args:
|
||||
new_state: 新的状态
|
||||
**kwargs: 其他状态信息
|
||||
"""
|
||||
self.current_state = new_state
|
||||
self.state_info.state = new_state
|
||||
|
||||
# 更新其他状态信息
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(self.state_info, key):
|
||||
setattr(self.state_info, key, value)
|
||||
|
||||
# 记录状态历史
|
||||
self.state_history.append(self.state_info)
|
||||
|
||||
def get_current_state_info(self) -> ChatStateInfo:
|
||||
"""获取当前状态信息"""
|
||||
return self.state_info
|
||||
|
||||
def get_state_history(self) -> list[ChatStateInfo]:
|
||||
"""获取状态历史"""
|
||||
return self.state_history
|
||||
|
||||
def is_cold_chat(self, threshold: float = 60.0) -> bool:
|
||||
"""判断是否处于冷场状态
|
||||
|
||||
Args:
|
||||
threshold: 冷场阈值(秒)
|
||||
|
||||
Returns:
|
||||
bool: 是否冷场
|
||||
"""
|
||||
if not self.state_info.last_message_time:
|
||||
return True
|
||||
|
||||
current_time = datetime.now().timestamp()
|
||||
return (current_time - self.state_info.last_message_time) > threshold
|
||||
|
||||
def is_active_chat(self, threshold: float = 5.0) -> bool:
|
||||
"""判断是否处于活跃状态
|
||||
|
||||
Args:
|
||||
threshold: 活跃阈值(秒)
|
||||
|
||||
Returns:
|
||||
bool: 是否活跃
|
||||
"""
|
||||
if not self.state_info.last_message_time:
|
||||
return False
|
||||
|
||||
current_time = datetime.now().timestamp()
|
||||
return (current_time - self.state_info.last_message_time) <= threshold
|
||||
@@ -1,722 +0,0 @@
|
||||
import asyncio
|
||||
import datetime
|
||||
import time
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage
|
||||
from src.services.message_service import build_readable_messages, get_messages_before_time_in_chat
|
||||
|
||||
# from .message_storage import MongoDBMessageStorage
|
||||
# from src.config.config import global_config
|
||||
from .pfc_types import ConversationState
|
||||
from .pfc import ChatObserver, GoalAnalyzer
|
||||
from .message_sender import DirectMessageSender
|
||||
from src.common.logger import get_logger
|
||||
from .action_planner import ActionPlanner
|
||||
from .observation_info import ObservationInfo
|
||||
from .conversation_info import ConversationInfo # 确保导入 ConversationInfo
|
||||
from .reply_generator import ReplyGenerator
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from maim_message import UserInfo
|
||||
from .pfc_KnowledgeFetcher import KnowledgeFetcher
|
||||
from .waiter import Waiter
|
||||
|
||||
import traceback
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("pfc")
|
||||
|
||||
|
||||
class Conversation:
|
||||
"""对话类,负责管理单个对话的状态和行为"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
"""初始化对话实例
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
"""
|
||||
self.stream_id = stream_id
|
||||
self.private_name = private_name
|
||||
self.state = ConversationState.INIT
|
||||
self.should_continue = False
|
||||
self.ignore_until_timestamp: Optional[float] = None
|
||||
|
||||
# 回复相关
|
||||
self.generated_reply = ""
|
||||
|
||||
async def _initialize(self):
|
||||
"""初始化实例,注册所有组件"""
|
||||
|
||||
try:
|
||||
self.action_planner = ActionPlanner(self.stream_id, self.private_name)
|
||||
self.goal_analyzer = GoalAnalyzer(self.stream_id, self.private_name)
|
||||
self.reply_generator = ReplyGenerator(self.stream_id, self.private_name)
|
||||
self.knowledge_fetcher = KnowledgeFetcher(self.private_name)
|
||||
self.waiter = Waiter(self.stream_id, self.private_name)
|
||||
self.direct_sender = DirectMessageSender(self.private_name)
|
||||
|
||||
# 获取聊天流信息
|
||||
self.chat_stream = _chat_manager.get_session_by_session_id(self.stream_id)
|
||||
|
||||
self.stop_action_planner = False
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]初始化对话实例:注册运行组件失败: {e}")
|
||||
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
|
||||
raise
|
||||
|
||||
try:
|
||||
# 决策所需要的信息,包括自身自信和观察信息两部分
|
||||
# 注册观察器和观测信息
|
||||
self.chat_observer = ChatObserver.get_instance(self.stream_id, self.private_name)
|
||||
self.chat_observer.start()
|
||||
self.observation_info = ObservationInfo(self.private_name)
|
||||
self.observation_info.bind_to_chat_observer(self.chat_observer)
|
||||
# print(self.chat_observer.get_cached_messages(limit=)
|
||||
|
||||
self.conversation_info = ConversationInfo()
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]初始化对话实例:注册信息组件失败: {e}")
|
||||
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
|
||||
raise
|
||||
try:
|
||||
logger.info(f"[私聊][{self.private_name}]为 {self.stream_id} 加载初始聊天记录...")
|
||||
initial_messages = get_messages_before_time_in_chat(
|
||||
chat_id=self.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=30, # 加载最近30条作为初始上下文,可以调整
|
||||
)
|
||||
chat_talking_prompt = build_readable_messages(
|
||||
initial_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
)
|
||||
if initial_messages:
|
||||
# 将 SessionMessage 列表转换为 PFC 期望的 dict 格式(保持嵌套结构)
|
||||
initial_messages_dict: list[dict] = []
|
||||
for msg in initial_messages:
|
||||
user_info = msg.message_info.user_info
|
||||
msg_dict = {
|
||||
"message_id": msg.message_id,
|
||||
"time": msg.timestamp.timestamp(),
|
||||
"chat_id": msg.session_id,
|
||||
"processed_plain_text": msg.processed_plain_text,
|
||||
"display_message": msg.display_message,
|
||||
"is_mentioned": msg.is_mentioned,
|
||||
"is_command": msg.is_command,
|
||||
"user_info": {
|
||||
"user_id": user_info.user_id,
|
||||
"user_nickname": user_info.user_nickname,
|
||||
"user_cardname": user_info.user_cardname,
|
||||
"platform": msg.platform,
|
||||
},
|
||||
}
|
||||
initial_messages_dict.append(msg_dict)
|
||||
|
||||
# 将加载的消息填充到 ObservationInfo 的 chat_history
|
||||
self.observation_info.chat_history = initial_messages_dict
|
||||
self.observation_info.chat_history_str = chat_talking_prompt + "\n"
|
||||
self.observation_info.chat_history_count = len(initial_messages_dict)
|
||||
|
||||
# 更新 ObservationInfo 中的时间戳等信息
|
||||
last_msg_dict: dict = initial_messages_dict[-1]
|
||||
self.observation_info.last_message_time = last_msg_dict.get("time")
|
||||
last_user_info = UserInfo.from_dict(last_msg_dict.get("user_info", {}))
|
||||
self.observation_info.last_message_sender = last_user_info.user_id
|
||||
self.observation_info.last_message_content = last_msg_dict.get("processed_plain_text", "")
|
||||
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]成功加载 {len(initial_messages_dict)} 条初始聊天记录。最后一条消息时间: {self.observation_info.last_message_time}"
|
||||
)
|
||||
|
||||
# 让 ChatObserver 从加载的最后一条消息之后开始同步
|
||||
if self.observation_info.last_message_time:
|
||||
self.chat_observer.last_message_time = self.observation_info.last_message_time
|
||||
self.chat_observer.last_message_read = last_msg_dict # 更新 observer 的最后读取记录
|
||||
else:
|
||||
logger.info(f"[私聊][{self.private_name}]没有找到初始聊天记录。")
|
||||
|
||||
except Exception as load_err:
|
||||
logger.error(f"[私聊][{self.private_name}]加载初始聊天记录时出错: {load_err}")
|
||||
# 出错也要继续,只是没有历史记录而已
|
||||
# 组件准备完成,启动该论对话
|
||||
self.should_continue = True
|
||||
asyncio.create_task(self.start())
|
||||
|
||||
async def start(self):
|
||||
"""开始对话流程"""
|
||||
try:
|
||||
logger.info(f"[私聊][{self.private_name}]对话系统启动中...")
|
||||
asyncio.create_task(self._plan_and_action_loop())
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]启动对话系统失败: {e}")
|
||||
raise
|
||||
|
||||
async def _plan_and_action_loop(self):
|
||||
"""思考步,PFC核心循环模块"""
|
||||
while self.should_continue:
|
||||
# 忽略逻辑
|
||||
if self.ignore_until_timestamp and time.time() < self.ignore_until_timestamp:
|
||||
await asyncio.sleep(30)
|
||||
continue
|
||||
elif self.ignore_until_timestamp and time.time() >= self.ignore_until_timestamp:
|
||||
logger.info(f"[私聊][{self.private_name}]忽略时间已到 {self.stream_id},准备结束对话。")
|
||||
self.ignore_until_timestamp = None
|
||||
self.should_continue = False
|
||||
continue
|
||||
try:
|
||||
# --- 在规划前记录当前新消息数量 ---
|
||||
initial_new_message_count = 0
|
||||
if hasattr(self.observation_info, "new_messages_count"):
|
||||
initial_new_message_count = self.observation_info.new_messages_count + 1 # 算上麦麦自己发的那一条
|
||||
else:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ObservationInfo missing 'new_messages_count' before planning."
|
||||
)
|
||||
|
||||
# --- 调用 Action Planner ---
|
||||
# 传递 self.conversation_info.last_successful_reply_action
|
||||
action, reason = await self.action_planner.plan(
|
||||
self.observation_info, self.conversation_info, self.conversation_info.last_successful_reply_action
|
||||
)
|
||||
|
||||
# --- 规划后检查是否有 *更多* 新消息到达 ---
|
||||
current_new_message_count = 0
|
||||
if hasattr(self.observation_info, "new_messages_count"):
|
||||
current_new_message_count = self.observation_info.new_messages_count
|
||||
else:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ObservationInfo missing 'new_messages_count' after planning."
|
||||
)
|
||||
|
||||
if current_new_message_count > initial_new_message_count + 2:
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]规划期间发现新增消息 ({initial_new_message_count} -> {current_new_message_count}),跳过本次行动,重新规划"
|
||||
)
|
||||
# 如果规划期间有新消息,也应该重置上次回复状态,因为现在要响应新消息了
|
||||
self.conversation_info.last_successful_reply_action = None
|
||||
await asyncio.sleep(0.1)
|
||||
continue
|
||||
|
||||
# 包含 send_new_message
|
||||
if initial_new_message_count > 0 and action in ["direct_reply", "send_new_message"]:
|
||||
if hasattr(self.observation_info, "clear_unprocessed_messages"):
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]准备执行 {action},清理 {initial_new_message_count} 条规划时已知的新消息。"
|
||||
)
|
||||
await self.observation_info.clear_unprocessed_messages()
|
||||
if hasattr(self.observation_info, "new_messages_count"):
|
||||
self.observation_info.new_messages_count = 0
|
||||
else:
|
||||
logger.error(
|
||||
f"[私聊][{self.private_name}]无法清理未处理消息: ObservationInfo 缺少 clear_unprocessed_messages 方法!"
|
||||
)
|
||||
|
||||
await self._handle_action(action, reason, self.observation_info, self.conversation_info)
|
||||
|
||||
# 检查是否需要结束对话 (逻辑不变)
|
||||
goal_ended = False
|
||||
if hasattr(self.conversation_info, "goal_list") and self.conversation_info.goal_list:
|
||||
for goal_item in self.conversation_info.goal_list:
|
||||
if isinstance(goal_item, dict):
|
||||
current_goal = goal_item.get("goal")
|
||||
|
||||
if current_goal == "结束对话":
|
||||
goal_ended = True
|
||||
break
|
||||
|
||||
if goal_ended:
|
||||
self.should_continue = False
|
||||
logger.info(f"[私聊][{self.private_name}]检测到'结束对话'目标,停止循环。")
|
||||
|
||||
except Exception as loop_err:
|
||||
logger.error(f"[私聊][{self.private_name}]PFC主循环出错: {loop_err}")
|
||||
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
|
||||
await asyncio.sleep(1)
|
||||
|
||||
if self.should_continue:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
logger.info(f"[私聊][{self.private_name}]PFC 循环结束 for stream_id: {self.stream_id}")
|
||||
|
||||
def _check_new_messages_after_planning(self):
|
||||
"""检查在规划后是否有新消息"""
|
||||
# 检查 ObservationInfo 是否已初始化并且有 new_messages_count 属性
|
||||
if not hasattr(self, "observation_info") or not hasattr(self.observation_info, "new_messages_count"):
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ObservationInfo 未初始化或缺少 'new_messages_count' 属性,无法检查新消息。"
|
||||
)
|
||||
return False # 或者根据需要抛出错误
|
||||
|
||||
if self.observation_info.new_messages_count > 2:
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]生成/执行动作期间收到 {self.observation_info.new_messages_count} 条新消息,取消当前动作并重新规划"
|
||||
)
|
||||
# 如果有新消息,也应该重置上次回复状态
|
||||
if hasattr(self, "conversation_info"): # 确保 conversation_info 已初始化
|
||||
self.conversation_info.last_successful_reply_action = None
|
||||
else:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ConversationInfo 未初始化,无法重置 last_successful_reply_action。"
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
def _convert_to_message(self, msg_dict: Dict[str, Any]) -> MaiMessage:
|
||||
"""将消息字典转换为MaiMessage对象"""
|
||||
from datetime import datetime as dt
|
||||
from src.common.data_models.mai_message_data_model import UserInfo as MaiUserInfo, MessageInfo
|
||||
from src.common.data_models.message_component_data_model import MessageSequence
|
||||
|
||||
try:
|
||||
user_info_dict = msg_dict.get("user_info", {})
|
||||
user_info = MaiUserInfo(
|
||||
user_id=user_info_dict.get("user_id", ""),
|
||||
user_nickname=user_info_dict.get("user_nickname", ""),
|
||||
user_cardname=user_info_dict.get("user_cardname"),
|
||||
)
|
||||
|
||||
msg = MaiMessage(
|
||||
message_id=msg_dict.get("message_id", f"gen_{time.time()}"),
|
||||
timestamp=dt.fromtimestamp(msg_dict.get("time", time.time())),
|
||||
)
|
||||
msg.message_info = MessageInfo(user_info=user_info)
|
||||
msg.platform = user_info_dict.get("platform", "")
|
||||
msg.session_id = self.stream_id
|
||||
msg.processed_plain_text = msg_dict.get("processed_plain_text", "")
|
||||
msg.raw_message = MessageSequence(components=[])
|
||||
msg.initialized = True
|
||||
return msg
|
||||
except Exception as e:
|
||||
logger.warning(f"[私聊][{self.private_name}]转换消息时出错: {e}")
|
||||
raise ValueError(f"无法将字典转换为 MaiMessage 对象: {e}") from e
|
||||
|
||||
async def _handle_action(
|
||||
self, action: str, reason: str, observation_info: ObservationInfo, conversation_info: ConversationInfo
|
||||
):
|
||||
"""处理规划的行动"""
|
||||
|
||||
logger.debug(f"[私聊][{self.private_name}]执行行动: {action}, 原因: {reason}")
|
||||
|
||||
# 记录action历史 (逻辑不变)
|
||||
current_action_record = {
|
||||
"action": action,
|
||||
"plan_reason": reason,
|
||||
"status": "start",
|
||||
"time": datetime.datetime.now().strftime("%H:%M:%S"),
|
||||
"final_reason": None,
|
||||
}
|
||||
# 确保 done_action 列表存在
|
||||
if not hasattr(conversation_info, "done_action"):
|
||||
conversation_info.done_action = []
|
||||
conversation_info.done_action.append(current_action_record)
|
||||
action_index = len(conversation_info.done_action) - 1
|
||||
|
||||
action_successful = False # 用于标记动作是否成功完成
|
||||
|
||||
# --- 根据不同的 action 执行 ---
|
||||
|
||||
# send_new_message 失败后执行 wait
|
||||
if action == "send_new_message":
|
||||
max_reply_attempts = 3
|
||||
reply_attempt_count = 0
|
||||
is_suitable = False
|
||||
need_replan = False
|
||||
check_reason = "未进行尝试"
|
||||
final_reply_to_send = ""
|
||||
|
||||
while reply_attempt_count < max_reply_attempts and not is_suitable:
|
||||
reply_attempt_count += 1
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]尝试生成追问回复 (第 {reply_attempt_count}/{max_reply_attempts} 次)..."
|
||||
)
|
||||
self.state = ConversationState.GENERATING
|
||||
|
||||
# 1. 生成回复 (调用 generate 时传入 action_type)
|
||||
self.generated_reply = await self.reply_generator.generate(
|
||||
observation_info, conversation_info, action_type="send_new_message"
|
||||
)
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次生成的追问回复: {self.generated_reply}"
|
||||
)
|
||||
|
||||
# 2. 检查回复 (逻辑不变)
|
||||
self.state = ConversationState.CHECKING
|
||||
try:
|
||||
current_goal_str = conversation_info.goal_list[0]["goal"] if conversation_info.goal_list else ""
|
||||
is_suitable, check_reason, need_replan = await self.reply_generator.check_reply(
|
||||
reply=self.generated_reply,
|
||||
goal=current_goal_str,
|
||||
chat_history=observation_info.chat_history,
|
||||
chat_history_str=observation_info.chat_history_str,
|
||||
retry_count=reply_attempt_count - 1,
|
||||
)
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次追问检查结果: 合适={is_suitable}, 原因='{check_reason}', 需重新规划={need_replan}"
|
||||
)
|
||||
if is_suitable:
|
||||
final_reply_to_send = self.generated_reply
|
||||
break
|
||||
elif need_replan:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次追问检查建议重新规划,停止尝试。原因: {check_reason}"
|
||||
)
|
||||
break
|
||||
except Exception as check_err:
|
||||
logger.error(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次调用 ReplyChecker (追问) 时出错: {check_err}"
|
||||
)
|
||||
check_reason = f"第 {reply_attempt_count} 次检查过程出错: {check_err}"
|
||||
break
|
||||
|
||||
# 循环结束,处理最终结果
|
||||
if is_suitable:
|
||||
# 检查是否有新消息
|
||||
if self._check_new_messages_after_planning():
|
||||
logger.info(f"[私聊][{self.private_name}]生成追问回复期间收到新消息,取消发送,重新规划行动")
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"有新消息,取消发送追问: {final_reply_to_send}"}
|
||||
)
|
||||
return # 直接返回,重新规划
|
||||
|
||||
# 发送合适的回复
|
||||
self.generated_reply = final_reply_to_send
|
||||
# --- 在这里调用 _send_reply ---
|
||||
await self._send_reply() # <--- 调用恢复后的函数
|
||||
|
||||
# 更新状态: 标记上次成功是 send_new_message
|
||||
self.conversation_info.last_successful_reply_action = "send_new_message"
|
||||
action_successful = True # 标记动作成功
|
||||
|
||||
elif need_replan:
|
||||
# 打回动作决策
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]经过 {reply_attempt_count} 次尝试,追问回复决定打回动作决策。打回原因: {check_reason}"
|
||||
)
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"追问尝试{reply_attempt_count}次后打回: {check_reason}"}
|
||||
)
|
||||
|
||||
else:
|
||||
# 追问失败
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]经过 {reply_attempt_count} 次尝试,未能生成合适的追问回复。最终原因: {check_reason}"
|
||||
)
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"追问尝试{reply_attempt_count}次后失败: {check_reason}"}
|
||||
)
|
||||
# 重置状态: 追问失败,下次用初始 prompt
|
||||
self.conversation_info.last_successful_reply_action = None
|
||||
|
||||
# 执行 Wait 操作
|
||||
logger.info(f"[私聊][{self.private_name}]由于无法生成合适追问回复,执行 'wait' 操作...")
|
||||
self.state = ConversationState.WAITING
|
||||
await self.waiter.wait(self.conversation_info)
|
||||
wait_action_record = {
|
||||
"action": "wait",
|
||||
"plan_reason": "因 send_new_message 多次尝试失败而执行的后备等待",
|
||||
"status": "done",
|
||||
"time": datetime.datetime.now().strftime("%H:%M:%S"),
|
||||
"final_reason": None,
|
||||
}
|
||||
conversation_info.done_action.append(wait_action_record)
|
||||
|
||||
elif action == "direct_reply":
|
||||
max_reply_attempts = 3
|
||||
reply_attempt_count = 0
|
||||
is_suitable = False
|
||||
need_replan = False
|
||||
check_reason = "未进行尝试"
|
||||
final_reply_to_send = ""
|
||||
|
||||
while reply_attempt_count < max_reply_attempts and not is_suitable:
|
||||
reply_attempt_count += 1
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]尝试生成首次回复 (第 {reply_attempt_count}/{max_reply_attempts} 次)..."
|
||||
)
|
||||
self.state = ConversationState.GENERATING
|
||||
|
||||
# 1. 生成回复
|
||||
self.generated_reply = await self.reply_generator.generate(
|
||||
observation_info, conversation_info, action_type="direct_reply"
|
||||
)
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次生成的首次回复: {self.generated_reply}"
|
||||
)
|
||||
|
||||
# 2. 检查回复
|
||||
self.state = ConversationState.CHECKING
|
||||
try:
|
||||
current_goal_str = conversation_info.goal_list[0]["goal"] if conversation_info.goal_list else ""
|
||||
is_suitable, check_reason, need_replan = await self.reply_generator.check_reply(
|
||||
reply=self.generated_reply,
|
||||
goal=current_goal_str,
|
||||
chat_history=observation_info.chat_history,
|
||||
chat_history_str=observation_info.chat_history_str,
|
||||
retry_count=reply_attempt_count - 1,
|
||||
)
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次首次回复检查结果: 合适={is_suitable}, 原因='{check_reason}', 需重新规划={need_replan}"
|
||||
)
|
||||
if is_suitable:
|
||||
final_reply_to_send = self.generated_reply
|
||||
break
|
||||
elif need_replan:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次首次回复检查建议重新规划,停止尝试。原因: {check_reason}"
|
||||
)
|
||||
break
|
||||
except Exception as check_err:
|
||||
logger.error(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次调用 ReplyChecker (首次回复) 时出错: {check_err}"
|
||||
)
|
||||
check_reason = f"第 {reply_attempt_count} 次检查过程出错: {check_err}"
|
||||
break
|
||||
|
||||
# 循环结束,处理最终结果
|
||||
if is_suitable:
|
||||
# 检查是否有新消息
|
||||
if self._check_new_messages_after_planning():
|
||||
logger.info(f"[私聊][{self.private_name}]生成首次回复期间收到新消息,取消发送,重新规划行动")
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"有新消息,取消发送首次回复: {final_reply_to_send}"}
|
||||
)
|
||||
return # 直接返回,重新规划
|
||||
|
||||
# 发送合适的回复
|
||||
self.generated_reply = final_reply_to_send
|
||||
# --- 在这里调用 _send_reply ---
|
||||
await self._send_reply() # <--- 调用恢复后的函数
|
||||
|
||||
# 更新状态: 标记上次成功是 direct_reply
|
||||
self.conversation_info.last_successful_reply_action = "direct_reply"
|
||||
action_successful = True # 标记动作成功
|
||||
|
||||
elif need_replan:
|
||||
# 打回动作决策
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]经过 {reply_attempt_count} 次尝试,首次回复决定打回动作决策。打回原因: {check_reason}"
|
||||
)
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"首次回复尝试{reply_attempt_count}次后打回: {check_reason}"}
|
||||
)
|
||||
|
||||
else:
|
||||
# 首次回复失败
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]经过 {reply_attempt_count} 次尝试,未能生成合适的首次回复。最终原因: {check_reason}"
|
||||
)
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"首次回复尝试{reply_attempt_count}次后失败: {check_reason}"}
|
||||
)
|
||||
# 重置状态: 首次回复失败,下次还是用初始 prompt
|
||||
self.conversation_info.last_successful_reply_action = None
|
||||
|
||||
# 执行 Wait 操作 (保持原有逻辑)
|
||||
logger.info(f"[私聊][{self.private_name}]由于无法生成合适首次回复,执行 'wait' 操作...")
|
||||
self.state = ConversationState.WAITING
|
||||
await self.waiter.wait(self.conversation_info)
|
||||
wait_action_record = {
|
||||
"action": "wait",
|
||||
"plan_reason": "因 direct_reply 多次尝试失败而执行的后备等待",
|
||||
"status": "done",
|
||||
"time": datetime.datetime.now().strftime("%H:%M:%S"),
|
||||
"final_reason": None,
|
||||
}
|
||||
conversation_info.done_action.append(wait_action_record)
|
||||
|
||||
elif action == "fetch_knowledge":
|
||||
self.state = ConversationState.FETCHING
|
||||
knowledge_query = reason
|
||||
try:
|
||||
# 检查 knowledge_fetcher 是否存在
|
||||
if not hasattr(self, "knowledge_fetcher"):
|
||||
logger.error(f"[私聊][{self.private_name}]KnowledgeFetcher 未初始化,无法获取知识。")
|
||||
raise AttributeError("KnowledgeFetcher not initialized")
|
||||
|
||||
knowledge, source = await self.knowledge_fetcher.fetch(knowledge_query, observation_info.chat_history)
|
||||
logger.info(f"[私聊][{self.private_name}]获取到知识: {knowledge[:100]}..., 来源: {source}")
|
||||
if knowledge:
|
||||
# 确保 knowledge_list 存在
|
||||
if not hasattr(conversation_info, "knowledge_list"):
|
||||
conversation_info.knowledge_list = []
|
||||
conversation_info.knowledge_list.append(
|
||||
{"query": knowledge_query, "knowledge": knowledge, "source": source}
|
||||
)
|
||||
action_successful = True
|
||||
except Exception as fetch_err:
|
||||
logger.error(f"[私聊][{self.private_name}]获取知识时出错: {str(fetch_err)}")
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"获取知识失败: {str(fetch_err)}"}
|
||||
)
|
||||
self.conversation_info.last_successful_reply_action = None # 重置状态
|
||||
|
||||
elif action == "rethink_goal":
|
||||
self.state = ConversationState.RETHINKING
|
||||
try:
|
||||
# 检查 goal_analyzer 是否存在
|
||||
if not hasattr(self, "goal_analyzer"):
|
||||
logger.error(f"[私聊][{self.private_name}]GoalAnalyzer 未初始化,无法重新思考目标。")
|
||||
raise AttributeError("GoalAnalyzer not initialized")
|
||||
await self.goal_analyzer.analyze_goal(conversation_info, observation_info)
|
||||
action_successful = True
|
||||
except Exception as rethink_err:
|
||||
logger.error(f"[私聊][{self.private_name}]重新思考目标时出错: {rethink_err}")
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"重新思考目标失败: {rethink_err}"}
|
||||
)
|
||||
self.conversation_info.last_successful_reply_action = None # 重置状态
|
||||
|
||||
elif action == "listening":
|
||||
self.state = ConversationState.LISTENING
|
||||
logger.info(f"[私聊][{self.private_name}]倾听对方发言...")
|
||||
try:
|
||||
# 检查 waiter 是否存在
|
||||
if not hasattr(self, "waiter"):
|
||||
logger.error(f"[私聊][{self.private_name}]Waiter 未初始化,无法倾听。")
|
||||
raise AttributeError("Waiter not initialized")
|
||||
await self.waiter.wait_listening(conversation_info)
|
||||
action_successful = True # Listening 完成就算成功
|
||||
except Exception as listen_err:
|
||||
logger.error(f"[私聊][{self.private_name}]倾听时出错: {listen_err}")
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"倾听失败: {listen_err}"}
|
||||
)
|
||||
self.conversation_info.last_successful_reply_action = None # 重置状态
|
||||
|
||||
elif action == "say_goodbye":
|
||||
self.state = ConversationState.GENERATING # 也可以定义一个新的状态,如 ENDING
|
||||
logger.info(f"[私聊][{self.private_name}]执行行动: 生成并发送告别语...")
|
||||
try:
|
||||
# 1. 生成告别语 (使用 'say_goodbye' action_type)
|
||||
self.generated_reply = await self.reply_generator.generate(
|
||||
observation_info, conversation_info, action_type="say_goodbye"
|
||||
)
|
||||
logger.info(f"[私聊][{self.private_name}]生成的告别语: {self.generated_reply}")
|
||||
|
||||
# 2. 直接发送告别语 (不经过检查)
|
||||
if self.generated_reply: # 确保生成了内容
|
||||
await self._send_reply() # 调用发送方法
|
||||
# 发送成功后,标记动作成功
|
||||
action_successful = True
|
||||
logger.info(f"[私聊][{self.private_name}]告别语已发送。")
|
||||
else:
|
||||
logger.warning(f"[私聊][{self.private_name}]未能生成告别语内容,无法发送。")
|
||||
action_successful = False # 标记动作失败
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": "未能生成告别语内容"}
|
||||
)
|
||||
|
||||
# 3. 无论是否发送成功,都准备结束对话
|
||||
self.should_continue = False
|
||||
logger.info(f"[私聊][{self.private_name}]发送告别语流程结束,即将停止对话实例。")
|
||||
|
||||
except Exception as goodbye_err:
|
||||
logger.error(f"[私聊][{self.private_name}]生成或发送告别语时出错: {goodbye_err}")
|
||||
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
|
||||
# 即使出错,也结束对话
|
||||
self.should_continue = False
|
||||
action_successful = False # 标记动作失败
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"生成或发送告别语时出错: {goodbye_err}"}
|
||||
)
|
||||
|
||||
elif action == "end_conversation":
|
||||
# 这个分支现在只会在 action_planner 最终决定不告别时被调用
|
||||
self.should_continue = False
|
||||
logger.info(f"[私聊][{self.private_name}]收到最终结束指令,停止对话...")
|
||||
action_successful = True # 标记这个指令本身是成功的
|
||||
|
||||
elif action == "block_and_ignore":
|
||||
logger.info(f"[私聊][{self.private_name}]不想再理你了...")
|
||||
ignore_duration_seconds = 10 * 60
|
||||
self.ignore_until_timestamp = time.time() + ignore_duration_seconds
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]将忽略此对话直到: {datetime.datetime.fromtimestamp(self.ignore_until_timestamp)}"
|
||||
)
|
||||
self.state = ConversationState.IGNORED
|
||||
action_successful = True # 标记动作成功
|
||||
|
||||
else: # 对应 'wait' 动作
|
||||
self.state = ConversationState.WAITING
|
||||
logger.info(f"[私聊][{self.private_name}]等待更多信息...")
|
||||
try:
|
||||
# 检查 waiter 是否存在
|
||||
if not hasattr(self, "waiter"):
|
||||
logger.error(f"[私聊][{self.private_name}]Waiter 未初始化,无法等待。")
|
||||
raise AttributeError("Waiter not initialized")
|
||||
_timeout_occurred = await self.waiter.wait(self.conversation_info)
|
||||
action_successful = True # Wait 完成就算成功
|
||||
except Exception as wait_err:
|
||||
logger.error(f"[私聊][{self.private_name}]等待时出错: {wait_err}")
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"等待失败: {wait_err}"}
|
||||
)
|
||||
self.conversation_info.last_successful_reply_action = None # 重置状态
|
||||
|
||||
# --- 更新 Action History 状态 ---
|
||||
# 只有当动作本身成功时,才更新状态为 done
|
||||
if action_successful:
|
||||
conversation_info.done_action[action_index].update(
|
||||
{
|
||||
"status": "done",
|
||||
"time": datetime.datetime.now().strftime("%H:%M:%S"),
|
||||
}
|
||||
)
|
||||
# 重置状态: 对于非回复类动作的成功,清除上次回复状态
|
||||
if action not in ["direct_reply", "send_new_message"]:
|
||||
self.conversation_info.last_successful_reply_action = None
|
||||
logger.debug(f"[私聊][{self.private_name}]动作 {action} 成功完成,重置 last_successful_reply_action")
|
||||
# 如果动作是 recall 状态,在各自的处理逻辑中已经更新了 done_action
|
||||
|
||||
async def _send_reply(self):
|
||||
"""发送回复"""
|
||||
if not self.generated_reply:
|
||||
logger.warning(f"[私聊][{self.private_name}]没有生成回复内容,无法发送。")
|
||||
return
|
||||
|
||||
try:
|
||||
_current_time = time.time()
|
||||
reply_content = self.generated_reply
|
||||
|
||||
# 发送消息 (确保 direct_sender 和 chat_stream 有效)
|
||||
if not hasattr(self, "direct_sender") or not self.direct_sender:
|
||||
logger.error(f"[私聊][{self.private_name}]DirectMessageSender 未初始化,无法发送回复。")
|
||||
return
|
||||
if not self.chat_stream:
|
||||
logger.error(f"[私聊][{self.private_name}]会话未初始化,无法发送回复。")
|
||||
return
|
||||
|
||||
await self.direct_sender.send_message(chat_stream=self.chat_stream, content=reply_content)
|
||||
|
||||
# 发送成功后,手动触发 observer 更新可能导致重复处理自己发送的消息
|
||||
# 更好的做法是依赖 observer 的自动轮询或数据库触发器(如果支持)
|
||||
# 暂时注释掉,观察是否影响 ObservationInfo 的更新
|
||||
# self.chat_observer.trigger_update()
|
||||
# if not await self.chat_observer.wait_for_update():
|
||||
# logger.warning(f"[私聊][{self.private_name}]等待 ChatObserver 更新完成超时")
|
||||
|
||||
self.state = ConversationState.ANALYZING # 更新状态
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]发送消息或更新状态时失败: {str(e)}")
|
||||
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
|
||||
self.state = ConversationState.ANALYZING
|
||||
|
||||
async def _send_timeout_message(self):
|
||||
"""发送超时结束消息"""
|
||||
try:
|
||||
messages = self.chat_observer.get_cached_messages(limit=1)
|
||||
if not messages:
|
||||
return
|
||||
|
||||
latest_message = self._convert_to_message(messages[0])
|
||||
await self.direct_sender.send_message(
|
||||
chat_stream=self.chat_stream, content="TODO:超时消息", reply_to_message=latest_message
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]发送超时消息失败: {str(e)}")
|
||||
@@ -1,10 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class ConversationInfo:
|
||||
def __init__(self):
|
||||
self.done_action: list = []
|
||||
self.goal_list: list = []
|
||||
self.knowledge_list: list = []
|
||||
self.memory_list: list = []
|
||||
self.last_successful_reply_action: Optional[str] = None
|
||||
@@ -1,61 +0,0 @@
|
||||
"""PFC 侧消息发送封装。"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from rich.traceback import install
|
||||
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage
|
||||
from src.common.logger import get_logger
|
||||
from src.services import send_service as send_api
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("message_sender")
|
||||
|
||||
|
||||
class DirectMessageSender:
|
||||
"""直接消息发送器。"""
|
||||
|
||||
def __init__(self, private_name: str) -> None:
|
||||
"""初始化直接消息发送器。
|
||||
|
||||
Args:
|
||||
private_name: 当前私聊实例的名称。
|
||||
"""
|
||||
self.private_name = private_name
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
chat_stream: BotChatSession,
|
||||
content: str,
|
||||
reply_to_message: Optional[MaiMessage] = None,
|
||||
) -> None:
|
||||
"""发送文本消息到聊天流。
|
||||
|
||||
Args:
|
||||
chat_stream: 目标聊天会话。
|
||||
content: 待发送的文本内容。
|
||||
reply_to_message: 可选的引用回复锚点消息。
|
||||
|
||||
Raises:
|
||||
RuntimeError: 当消息发送失败时抛出。
|
||||
"""
|
||||
try:
|
||||
sent = await send_api.text_to_stream(
|
||||
text=content,
|
||||
stream_id=chat_stream.session_id,
|
||||
set_reply=reply_to_message is not None,
|
||||
reply_message=reply_to_message,
|
||||
storage_message=True,
|
||||
)
|
||||
|
||||
if sent:
|
||||
logger.info(f"[私聊][{self.private_name}]PFC消息已发送: {content}")
|
||||
return
|
||||
|
||||
logger.error(f"[私聊][{self.private_name}]PFC消息发送失败")
|
||||
raise RuntimeError("消息发送失败")
|
||||
except Exception as exc:
|
||||
logger.error(f"[私聊][{self.private_name}]PFC消息发送失败: {exc}")
|
||||
raise
|
||||
@@ -1,429 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from maim_message import UserInfo
|
||||
import time
|
||||
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.mai_message_data_model import MessageInfo, UserInfo as MaiUserInfo
|
||||
from src.services.message_service import build_readable_messages
|
||||
|
||||
from .chat_observer import ChatObserver
|
||||
from .chat_states import NotificationHandler, NotificationType, Notification
|
||||
import traceback # 导入 traceback 用于调试
|
||||
|
||||
logger = get_logger("observation_info")
|
||||
|
||||
|
||||
def dict_to_session_message(msg_dict: Dict[str, Any]) -> SessionMessage:
|
||||
"""Convert PFC dict format to SessionMessage object.
|
||||
|
||||
Args:
|
||||
msg_dict: Message in PFC dict format with nested user_info
|
||||
|
||||
Returns:
|
||||
SessionMessage object compatible with build_readable_messages()
|
||||
"""
|
||||
user_info_dict: Dict[str, Any] = msg_dict.get("user_info", {})
|
||||
timestamp = msg_dict.get("time", 0.0)
|
||||
platform = user_info_dict.get("platform", "")
|
||||
message = SessionMessage(
|
||||
message_id=msg_dict.get("message_id", ""),
|
||||
timestamp=datetime.fromtimestamp(timestamp),
|
||||
platform=platform,
|
||||
)
|
||||
message.message_info = MessageInfo(
|
||||
user_info=MaiUserInfo(
|
||||
user_id=user_info_dict.get("user_id", ""),
|
||||
user_nickname=user_info_dict.get("user_nickname", ""),
|
||||
user_cardname=user_info_dict.get("user_cardname"),
|
||||
)
|
||||
)
|
||||
message.session_id = msg_dict.get("chat_id", "")
|
||||
message.processed_plain_text = msg_dict.get("processed_plain_text", "")
|
||||
message.display_message = msg_dict.get("display_message", "")
|
||||
message.is_mentioned = msg_dict.get("is_mentioned", False)
|
||||
message.is_command = msg_dict.get("is_command", False)
|
||||
message.initialized = True
|
||||
return message
|
||||
|
||||
|
||||
class ObservationInfoHandler(NotificationHandler):
|
||||
"""ObservationInfo的通知处理器"""
|
||||
|
||||
def __init__(self, observation_info: "ObservationInfo", private_name: str):
|
||||
"""初始化处理器
|
||||
|
||||
Args:
|
||||
observation_info: 要更新的ObservationInfo实例
|
||||
private_name: 私聊对象的名称,用于日志记录
|
||||
"""
|
||||
self.observation_info = observation_info
|
||||
# 将 private_name 存储在 handler 实例中
|
||||
self.private_name = private_name
|
||||
|
||||
async def handle_notification(self, notification: Notification): # 添加类型提示
|
||||
# 获取通知类型和数据
|
||||
notification_type = notification.type
|
||||
data = notification.data
|
||||
|
||||
try: # 添加错误处理块
|
||||
if notification_type == NotificationType.NEW_MESSAGE:
|
||||
# 处理新消息通知
|
||||
# logger.debug(f"[私聊][{self.private_name}]收到新消息通知data: {data}") # 可以在需要时取消注释
|
||||
message_id = data.get("message_id")
|
||||
processed_plain_text = data.get("processed_plain_text")
|
||||
detailed_plain_text = data.get("detailed_plain_text")
|
||||
user_info_dict = data.get("user_info") # 先获取字典
|
||||
time_value = data.get("time")
|
||||
|
||||
# 确保 user_info 是字典类型再创建 UserInfo 对象
|
||||
user_info = None
|
||||
if isinstance(user_info_dict, dict):
|
||||
try:
|
||||
user_info = UserInfo.from_dict(user_info_dict)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[私聊][{self.private_name}]从字典创建 UserInfo 时出错: {e}, 字典内容: {user_info_dict}"
|
||||
)
|
||||
# 可以选择在这里返回或记录错误,避免后续代码出错
|
||||
return
|
||||
elif user_info_dict is not None:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]收到的 user_info 不是预期的字典类型: {type(user_info_dict)}"
|
||||
)
|
||||
# 根据需要处理非字典情况,这里暂时返回
|
||||
return
|
||||
|
||||
message = {
|
||||
"message_id": message_id,
|
||||
"processed_plain_text": processed_plain_text,
|
||||
"detailed_plain_text": detailed_plain_text,
|
||||
"user_info": user_info_dict, # 存储原始字典或 UserInfo 对象,取决于你的 update_from_message 如何处理
|
||||
"time": time_value,
|
||||
}
|
||||
# 传递 UserInfo 对象(如果成功创建)或原始字典
|
||||
await self.observation_info.update_from_message(message, user_info) # 修改:传递 user_info 对象
|
||||
|
||||
elif notification_type == NotificationType.COLD_CHAT:
|
||||
# 处理冷场通知
|
||||
is_cold = data.get("is_cold", False)
|
||||
await self.observation_info.update_cold_chat_status(is_cold, time.time()) # 修改:改为 await 调用
|
||||
|
||||
elif notification_type == NotificationType.ACTIVE_CHAT:
|
||||
# 处理活跃通知 (通常由 COLD_CHAT 的反向状态处理)
|
||||
is_active = data.get("is_active", False)
|
||||
self.observation_info.is_cold = not is_active
|
||||
|
||||
elif notification_type == NotificationType.BOT_SPEAKING:
|
||||
# 处理机器人说话通知 (按需实现)
|
||||
self.observation_info.is_typing = False
|
||||
self.observation_info.last_bot_speak_time = time.time()
|
||||
|
||||
elif notification_type == NotificationType.USER_SPEAKING:
|
||||
# 处理用户说话通知
|
||||
self.observation_info.is_typing = False
|
||||
self.observation_info.last_user_speak_time = time.time()
|
||||
|
||||
elif notification_type == NotificationType.MESSAGE_DELETED:
|
||||
# 处理消息删除通知
|
||||
message_id = data.get("message_id")
|
||||
# 从 unprocessed_messages 中移除被删除的消息
|
||||
original_count = len(self.observation_info.unprocessed_messages)
|
||||
self.observation_info.unprocessed_messages = [
|
||||
msg for msg in self.observation_info.unprocessed_messages if msg.get("message_id") != message_id
|
||||
]
|
||||
if len(self.observation_info.unprocessed_messages) < original_count:
|
||||
logger.info(f"[私聊][{self.private_name}]移除了未处理的消息 (ID: {message_id})")
|
||||
|
||||
elif notification_type == NotificationType.USER_JOINED:
|
||||
# 处理用户加入通知 (如果适用私聊场景)
|
||||
user_id = data.get("user_id")
|
||||
if user_id:
|
||||
self.observation_info.active_users.add(str(user_id)) # 确保是字符串
|
||||
|
||||
elif notification_type == NotificationType.USER_LEFT:
|
||||
# 处理用户离开通知 (如果适用私聊场景)
|
||||
user_id = data.get("user_id")
|
||||
if user_id:
|
||||
self.observation_info.active_users.discard(str(user_id)) # 确保是字符串
|
||||
|
||||
elif notification_type == NotificationType.ERROR:
|
||||
# 处理错误通知
|
||||
error_msg = data.get("error", "未提供错误信息")
|
||||
logger.error(f"[私聊][{self.private_name}]收到错误通知: {error_msg}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]处理通知时发生错误: {e}")
|
||||
logger.error(traceback.format_exc()) # 打印详细堆栈信息
|
||||
|
||||
|
||||
# @dataclass <-- 这个,不需要了(递黄瓜)
|
||||
class ObservationInfo:
|
||||
"""决策信息类,用于收集和管理来自chat_observer的通知信息 (手动实现 __init__)"""
|
||||
|
||||
# 类型提示保留,可用于文档和静态分析
|
||||
private_name: str
|
||||
chat_history: List[Dict[str, Any]]
|
||||
chat_history_str: str
|
||||
unprocessed_messages: List[Dict[str, Any]]
|
||||
active_users: Set[str]
|
||||
last_bot_speak_time: Optional[float]
|
||||
last_user_speak_time: Optional[float]
|
||||
last_message_time: Optional[float]
|
||||
last_message_id: Optional[str]
|
||||
last_message_content: str
|
||||
last_message_sender: Optional[str]
|
||||
bot_id: Optional[str]
|
||||
chat_history_count: int
|
||||
new_messages_count: int
|
||||
cold_chat_start_time: Optional[float]
|
||||
cold_chat_duration: float
|
||||
is_typing: bool
|
||||
is_cold_chat: bool
|
||||
changed: bool
|
||||
chat_observer: Optional[ChatObserver]
|
||||
handler: Optional[ObservationInfoHandler]
|
||||
|
||||
def __init__(self, private_name: str):
|
||||
"""
|
||||
手动初始化 ObservationInfo 的所有实例变量。
|
||||
"""
|
||||
|
||||
# 接收的参数
|
||||
self.private_name: str = private_name
|
||||
|
||||
# data_list
|
||||
self.chat_history: List[Dict[str, Any]] = []
|
||||
self.chat_history_str: str = ""
|
||||
self.unprocessed_messages: List[Dict[str, Any]] = []
|
||||
self.active_users: Set[str] = set()
|
||||
|
||||
# data
|
||||
self.last_bot_speak_time: Optional[float] = None
|
||||
self.last_user_speak_time: Optional[float] = None
|
||||
self.last_message_time: Optional[float] = None
|
||||
self.last_message_id: Optional[str] = None
|
||||
self.last_message_content: str = ""
|
||||
self.last_message_sender: Optional[str] = None
|
||||
self.bot_id: Optional[str] = None
|
||||
self.chat_history_count: int = 0
|
||||
self.new_messages_count: int = 0
|
||||
self.cold_chat_start_time: Optional[float] = None
|
||||
self.cold_chat_duration: float = 0.0
|
||||
|
||||
# state
|
||||
self.is_typing: bool = False
|
||||
self.is_cold_chat: bool = False
|
||||
self.changed: bool = False
|
||||
|
||||
# 关联对象
|
||||
self.chat_observer: Optional[ChatObserver] = None
|
||||
|
||||
self.handler: ObservationInfoHandler = ObservationInfoHandler(self, self.private_name)
|
||||
|
||||
def bind_to_chat_observer(self, chat_observer: ChatObserver):
|
||||
"""绑定到指定的chat_observer
|
||||
|
||||
Args:
|
||||
chat_observer: 要绑定的 ChatObserver 实例
|
||||
"""
|
||||
if self.chat_observer:
|
||||
logger.warning(f"[私聊][{self.private_name}]尝试重复绑定 ChatObserver")
|
||||
return
|
||||
|
||||
self.chat_observer = chat_observer
|
||||
try:
|
||||
if not self.handler: # 确保 handler 已经被创建
|
||||
logger.error(f"[私聊][{self.private_name}] 尝试绑定时 handler 未初始化!")
|
||||
self.chat_observer = None # 重置,防止后续错误
|
||||
return
|
||||
|
||||
# 注册关心的通知类型
|
||||
self.chat_observer.notification_manager.register_handler(
|
||||
target="observation_info", notification_type=NotificationType.NEW_MESSAGE, handler=self.handler
|
||||
)
|
||||
self.chat_observer.notification_manager.register_handler(
|
||||
target="observation_info", notification_type=NotificationType.COLD_CHAT, handler=self.handler
|
||||
)
|
||||
# 可以根据需要注册更多通知类型
|
||||
# self.chat_observer.notification_manager.register_handler(
|
||||
# target="observation_info", notification_type=NotificationType.MESSAGE_DELETED, handler=self.handler
|
||||
# )
|
||||
logger.info(f"[私聊][{self.private_name}]成功绑定到 ChatObserver")
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]绑定到 ChatObserver 时出错: {e}")
|
||||
self.chat_observer = None # 绑定失败,重置
|
||||
|
||||
def unbind_from_chat_observer(self):
|
||||
"""解除与chat_observer的绑定"""
|
||||
if (
|
||||
self.chat_observer and hasattr(self.chat_observer, "notification_manager") and self.handler
|
||||
): # 增加 handler 检查
|
||||
try:
|
||||
self.chat_observer.notification_manager.unregister_handler(
|
||||
target="observation_info", notification_type=NotificationType.NEW_MESSAGE, handler=self.handler
|
||||
)
|
||||
self.chat_observer.notification_manager.unregister_handler(
|
||||
target="observation_info", notification_type=NotificationType.COLD_CHAT, handler=self.handler
|
||||
)
|
||||
# 如果注册了其他类型,也要在这里注销
|
||||
# self.chat_observer.notification_manager.unregister_handler(
|
||||
# target="observation_info", notification_type=NotificationType.MESSAGE_DELETED, handler=self.handler
|
||||
# )
|
||||
logger.info(f"[私聊][{self.private_name}]成功从 ChatObserver 解绑")
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]从 ChatObserver 解绑时出错: {e}")
|
||||
finally: # 确保 chat_observer 被重置
|
||||
self.chat_observer = None
|
||||
else:
|
||||
logger.warning(f"[私聊][{self.private_name}]尝试解绑时 ChatObserver 不存在、无效或 handler 未设置")
|
||||
|
||||
# 修改:update_from_message 接收 UserInfo 对象
|
||||
async def update_from_message(self, message: Dict[str, Any], user_info: Optional[UserInfo]):
|
||||
"""从消息更新信息
|
||||
|
||||
Args:
|
||||
message: 消息数据字典
|
||||
user_info: 解析后的 UserInfo 对象 (可能为 None)
|
||||
"""
|
||||
message_time = message.get("time")
|
||||
message_id = message.get("message_id")
|
||||
processed_text = message.get("processed_plain_text", "")
|
||||
|
||||
# 只有在新消息到达时才更新 last_message 相关信息
|
||||
if message_time and message_time > (self.last_message_time or 0):
|
||||
self.last_message_time = message_time
|
||||
self.last_message_id = message_id
|
||||
self.last_message_content = processed_text
|
||||
# 重置冷场计时器
|
||||
self.is_cold_chat = False
|
||||
self.cold_chat_start_time = None
|
||||
self.cold_chat_duration = 0.0
|
||||
|
||||
if user_info:
|
||||
sender_id = str(user_info.user_id) # 确保是字符串
|
||||
self.last_message_sender = sender_id
|
||||
# 更新发言时间
|
||||
if sender_id == self.bot_id:
|
||||
self.last_bot_speak_time = message_time
|
||||
else:
|
||||
self.last_user_speak_time = message_time
|
||||
self.active_users.add(sender_id) # 用户发言则认为其活跃
|
||||
else:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]处理消息更新时缺少有效的 UserInfo 对象, message_id: {message_id}"
|
||||
)
|
||||
self.last_message_sender = None # 发送者未知
|
||||
|
||||
# 将原始消息字典添加到未处理列表
|
||||
self.unprocessed_messages.append(message)
|
||||
self.new_messages_count = len(self.unprocessed_messages) # 直接用列表长度
|
||||
|
||||
# logger.debug(f"[私聊][{self.private_name}]消息更新: last_time={self.last_message_time}, new_count={self.new_messages_count}")
|
||||
self.update_changed() # 标记状态已改变
|
||||
else:
|
||||
# 如果消息时间戳不是最新的,可能不需要处理,或者记录一个警告
|
||||
pass
|
||||
# logger.warning(f"[私聊][{self.private_name}]收到过时或无效时间戳的消息: ID={message_id}, time={message_time}")
|
||||
|
||||
def update_changed(self):
|
||||
"""标记状态已改变,并重置标记"""
|
||||
# logger.debug(f"[私聊][{self.private_name}]状态标记为已改变 (changed=True)")
|
||||
self.changed = True
|
||||
|
||||
async def update_cold_chat_status(self, is_cold: bool, current_time: float):
|
||||
"""更新冷场状态
|
||||
|
||||
Args:
|
||||
is_cold: 是否处于冷场状态
|
||||
current_time: 当前时间戳
|
||||
"""
|
||||
if is_cold != self.is_cold_chat: # 仅在状态变化时更新
|
||||
self.is_cold_chat = is_cold
|
||||
if is_cold:
|
||||
# 进入冷场状态
|
||||
self.cold_chat_start_time = (
|
||||
self.last_message_time or current_time
|
||||
) # 从最后消息时间开始算,或从当前时间开始
|
||||
logger.info(f"[私聊][{self.private_name}]进入冷场状态,开始时间: {self.cold_chat_start_time}")
|
||||
else:
|
||||
# 结束冷场状态
|
||||
if self.cold_chat_start_time:
|
||||
self.cold_chat_duration = current_time - self.cold_chat_start_time
|
||||
logger.info(f"[私聊][{self.private_name}]结束冷场状态,持续时间: {self.cold_chat_duration:.2f} 秒")
|
||||
self.cold_chat_start_time = None # 重置开始时间
|
||||
self.update_changed() # 状态变化,标记改变
|
||||
|
||||
# 即使状态没变,如果是冷场状态,也更新持续时间
|
||||
if self.is_cold_chat and self.cold_chat_start_time:
|
||||
self.cold_chat_duration = current_time - self.cold_chat_start_time
|
||||
|
||||
def get_active_duration(self) -> float:
|
||||
"""获取当前活跃时长 (距离最后一条消息的时间)
|
||||
|
||||
Returns:
|
||||
float: 最后一条消息到现在的时长(秒)
|
||||
"""
|
||||
if not self.last_message_time:
|
||||
return 0.0
|
||||
return time.time() - self.last_message_time
|
||||
|
||||
def get_user_response_time(self) -> Optional[float]:
|
||||
"""获取用户最后响应时间 (距离用户最后发言的时间)
|
||||
|
||||
Returns:
|
||||
Optional[float]: 用户最后发言到现在的时长(秒),如果没有用户发言则返回None
|
||||
"""
|
||||
if not self.last_user_speak_time:
|
||||
return None
|
||||
return time.time() - self.last_user_speak_time
|
||||
|
||||
def get_bot_response_time(self) -> Optional[float]:
|
||||
"""获取机器人最后响应时间 (距离机器人最后发言的时间)
|
||||
|
||||
Returns:
|
||||
Optional[float]: 机器人最后发言到现在的时长(秒),如果没有机器人发言则返回None
|
||||
"""
|
||||
if not self.last_bot_speak_time:
|
||||
return None
|
||||
return time.time() - self.last_bot_speak_time
|
||||
|
||||
async def clear_unprocessed_messages(self):
|
||||
"""将未处理消息移入历史记录,并更新相关状态"""
|
||||
if not self.unprocessed_messages:
|
||||
return # 没有未处理消息,直接返回
|
||||
|
||||
# logger.debug(f"[私聊][{self.private_name}]处理 {len(self.unprocessed_messages)} 条未处理消息...")
|
||||
# 将未处理消息添加到历史记录中 (确保历史记录有长度限制,避免无限增长)
|
||||
max_history_len = 100 # 示例:最多保留100条历史记录
|
||||
self.chat_history.extend(self.unprocessed_messages)
|
||||
if len(self.chat_history) > max_history_len:
|
||||
self.chat_history = self.chat_history[-max_history_len:]
|
||||
|
||||
# 更新历史记录字符串 (只使用最近一部分生成,例如20条)
|
||||
history_slice_for_str = self.chat_history[-20:]
|
||||
try:
|
||||
# Convert dict format to SessionMessage objects.
|
||||
session_messages = [dict_to_session_message(m) for m in history_slice_for_str]
|
||||
self.chat_history_str = build_readable_messages(
|
||||
session_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0, # read_mark 可能需要根据逻辑调整
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]构建聊天记录字符串时出错: {e}")
|
||||
self.chat_history_str = "[构建聊天记录出错]" # 提供错误提示
|
||||
|
||||
# 清空未处理消息列表和计数
|
||||
# cleared_count = len(self.unprocessed_messages)
|
||||
self.unprocessed_messages.clear()
|
||||
self.new_messages_count = 0
|
||||
# self.has_unread_messages = False # 这个状态可以通过 new_messages_count 判断
|
||||
|
||||
self.chat_history_count = len(self.chat_history) # 更新历史记录总数
|
||||
# logger.debug(f"[私聊][{self.private_name}]已处理 {cleared_count} 条消息,当前历史记录 {self.chat_history_count} 条。")
|
||||
|
||||
self.update_changed() # 状态改变
|
||||
@@ -1,365 +0,0 @@
|
||||
from typing import List, Tuple, TYPE_CHECKING
|
||||
from src.common.logger import get_logger
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.config.config import global_config
|
||||
import random
|
||||
from .chat_observer import ChatObserver
|
||||
from .pfc_utils import get_items_from_json
|
||||
from .conversation_info import ConversationInfo
|
||||
from src.services.message_service import build_readable_messages
|
||||
|
||||
from .observation_info import ObservationInfo, dict_to_session_message
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = get_logger("pfc")
|
||||
|
||||
|
||||
def _calculate_similarity(goal1: str, goal2: str) -> float:
|
||||
"""简单计算两个目标之间的相似度
|
||||
|
||||
这里使用一个简单的实现,实际可以使用更复杂的文本相似度算法
|
||||
|
||||
Args:
|
||||
goal1: 第一个目标
|
||||
goal2: 第二个目标
|
||||
|
||||
Returns:
|
||||
float: 相似度得分 (0-1)
|
||||
"""
|
||||
# 简单实现:检查重叠字数比例
|
||||
words1 = set(goal1)
|
||||
words2 = set(goal2)
|
||||
overlap = len(words1.intersection(words2))
|
||||
total = len(words1.union(words2))
|
||||
return overlap / total if total > 0 else 0
|
||||
|
||||
|
||||
class GoalAnalyzer:
|
||||
"""对话目标分析器"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
self.llm = LLMServiceClient(
|
||||
task_name="planner", request_type="conversation_goal"
|
||||
)
|
||||
|
||||
self.personality_info = self._get_personality_prompt()
|
||||
self.name = global_config.bot.nickname
|
||||
self.nick_name = global_config.bot.alias_names
|
||||
self.private_name = private_name
|
||||
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
|
||||
|
||||
# 多目标存储结构
|
||||
self.goals = [] # 存储多个目标
|
||||
self.max_goals = 3 # 同时保持的最大目标数量
|
||||
self.current_goal_and_reason = None
|
||||
|
||||
def _get_personality_prompt(self) -> str:
|
||||
"""获取个性提示信息"""
|
||||
prompt_personality = global_config.personality.personality
|
||||
|
||||
# 检查是否需要随机替换为状态
|
||||
if (
|
||||
global_config.personality.states
|
||||
and global_config.personality.state_probability > 0
|
||||
and random.random() < global_config.personality.state_probability
|
||||
):
|
||||
prompt_personality = random.choice(global_config.personality.states)
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
return f"你的名字是{bot_name},你{prompt_personality};"
|
||||
|
||||
async def analyze_goal(self, conversation_info: ConversationInfo, observation_info: ObservationInfo):
|
||||
"""分析对话历史并设定目标
|
||||
|
||||
Args:
|
||||
conversation_info: 对话信息
|
||||
observation_info: 观察信息
|
||||
|
||||
Returns:
|
||||
Tuple[str, str, str]: (目标, 方法, 原因)
|
||||
"""
|
||||
# 构建对话目标
|
||||
goals_str = ""
|
||||
if conversation_info.goal_list:
|
||||
for goal_reason in conversation_info.goal_list:
|
||||
if isinstance(goal_reason, dict):
|
||||
goal = goal_reason.get("goal", "目标内容缺失")
|
||||
reasoning = goal_reason.get("reasoning", "没有明确原因")
|
||||
else:
|
||||
goal = str(goal_reason)
|
||||
reasoning = "没有明确原因"
|
||||
|
||||
goal_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n"
|
||||
goals_str += goal_str
|
||||
else:
|
||||
goal = "目前没有明确对话目标"
|
||||
reasoning = "目前没有明确对话目标,最好思考一个对话目标"
|
||||
goals_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n"
|
||||
|
||||
# 获取聊天历史记录
|
||||
chat_history_text = observation_info.chat_history_str
|
||||
|
||||
if observation_info.new_messages_count > 0:
|
||||
new_messages_list = observation_info.unprocessed_messages
|
||||
session_messages = [dict_to_session_message(m) for m in new_messages_list]
|
||||
new_messages_str = build_readable_messages(
|
||||
session_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
)
|
||||
chat_history_text += f"\n--- 以下是 {observation_info.new_messages_count} 条新消息 ---\n{new_messages_str}"
|
||||
|
||||
# await observation_info.clear_unprocessed_messages()
|
||||
|
||||
persona_text = f"你的名字是{self.name},{self.personality_info}。"
|
||||
# 构建action历史文本
|
||||
action_history_list = conversation_info.done_action
|
||||
action_history_text = "你之前做的事情是:"
|
||||
for action in action_history_list:
|
||||
action_history_text += f"{action}\n"
|
||||
|
||||
prompt = f"""{persona_text}。现在你在参与一场QQ聊天,请分析以下聊天记录,并根据你的性格特征确定多个明确的对话目标。
|
||||
这些目标应该反映出对话的不同方面和意图。
|
||||
|
||||
{action_history_text}
|
||||
当前对话目标:
|
||||
{goals_str}
|
||||
|
||||
聊天记录:
|
||||
{chat_history_text}
|
||||
|
||||
请分析当前对话并确定最适合的对话目标。你可以:
|
||||
1. 保持现有目标不变
|
||||
2. 修改现有目标
|
||||
3. 添加新目标
|
||||
4. 删除不再相关的目标
|
||||
5. 如果你想结束对话,请设置一个目标,目标goal为"结束对话",原因reasoning为你希望结束对话
|
||||
|
||||
请以JSON数组格式输出当前的所有对话目标,每个目标包含以下字段:
|
||||
1. goal: 对话目标(简短的一句话)
|
||||
2. reasoning: 对话原因,为什么设定这个目标(简要解释)
|
||||
|
||||
输出格式示例:
|
||||
[
|
||||
{{
|
||||
"goal": "回答用户关于Python编程的具体问题",
|
||||
"reasoning": "用户提出了关于Python的技术问题,需要专业且准确的解答"
|
||||
}},
|
||||
{{
|
||||
"goal": "回答用户关于python安装的具体问题",
|
||||
"reasoning": "用户提出了关于Python的技术问题,需要专业且准确的解答"
|
||||
}}
|
||||
]"""
|
||||
|
||||
logger.debug(f"[私聊][{self.private_name}]发送到LLM的提示词: {prompt}")
|
||||
try:
|
||||
generation_result = await self.llm.generate_response(prompt)
|
||||
content = generation_result.response
|
||||
logger.debug(f"[私聊][{self.private_name}]LLM原始返回内容: {content}")
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]分析对话目标时出错: {str(e)}")
|
||||
content = ""
|
||||
|
||||
# 使用改进后的get_items_from_json函数处理JSON数组
|
||||
success, result = get_items_from_json(
|
||||
content,
|
||||
self.private_name,
|
||||
"goal",
|
||||
"reasoning",
|
||||
required_types={"goal": str, "reasoning": str},
|
||||
allow_array=True,
|
||||
)
|
||||
|
||||
if success:
|
||||
# 判断结果是单个字典还是字典列表
|
||||
if isinstance(result, list):
|
||||
# 清空现有目标列表并添加新目标
|
||||
conversation_info.goal_list = []
|
||||
for item in result:
|
||||
conversation_info.goal_list.append(item)
|
||||
|
||||
# 返回第一个目标作为当前主要目标(如果有)
|
||||
if result:
|
||||
first_goal = result[0]
|
||||
return first_goal.get("goal", ""), "", first_goal.get("reasoning", "")
|
||||
else:
|
||||
# 单个目标的情况
|
||||
conversation_info.goal_list.append(result)
|
||||
goal_value = result.get("goal", "")
|
||||
reasoning_value = result.get("reasoning", "")
|
||||
return goal_value, "", reasoning_value
|
||||
|
||||
# 如果解析失败,返回默认值
|
||||
return "", "", ""
|
||||
|
||||
async def _update_goals(self, new_goal: str, method: str, reasoning: str):
|
||||
"""更新目标列表
|
||||
|
||||
Args:
|
||||
new_goal: 新的目标
|
||||
method: 实现目标的方法
|
||||
reasoning: 目标的原因
|
||||
"""
|
||||
# 检查新目标是否与现有目标相似
|
||||
for i, (existing_goal, _, _) in enumerate(self.goals):
|
||||
if _calculate_similarity(new_goal, existing_goal) > 0.7: # 相似度阈值
|
||||
# 更新现有目标
|
||||
self.goals[i] = (new_goal, method, reasoning)
|
||||
# 将此目标移到列表前面(最主要的位置)
|
||||
self.goals.insert(0, self.goals.pop(i))
|
||||
return
|
||||
|
||||
# 添加新目标到列表前面
|
||||
self.goals.insert(0, (new_goal, method, reasoning))
|
||||
|
||||
# 限制目标数量
|
||||
if len(self.goals) > self.max_goals:
|
||||
self.goals.pop() # 移除最老的目标
|
||||
|
||||
async def get_all_goals(self) -> List[Tuple[str, str, str]]:
|
||||
"""获取所有当前目标
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str, str]]: 目标列表,每项为(目标, 方法, 原因)
|
||||
"""
|
||||
return self.goals.copy()
|
||||
|
||||
async def get_alternative_goals(self) -> List[Tuple[str, str, str]]:
|
||||
"""获取除了当前主要目标外的其他备选目标
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str, str]]: 备选目标列表
|
||||
"""
|
||||
if len(self.goals) <= 1:
|
||||
return []
|
||||
return self.goals[1:].copy()
|
||||
|
||||
async def analyze_conversation(self, goal, reasoning):
|
||||
messages = self.chat_observer.get_cached_messages()
|
||||
session_messages = [dict_to_session_message(m) for m in messages]
|
||||
chat_history_text = build_readable_messages(
|
||||
session_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
)
|
||||
|
||||
persona_text = f"你的名字是{self.name},{self.personality_info}。"
|
||||
# ===> Persona 文本构建结束 <===
|
||||
|
||||
# --- 修改 Prompt 字符串,使用 persona_text ---
|
||||
prompt = f"""{persona_text}。现在你在参与一场QQ聊天,
|
||||
当前对话目标:{goal}
|
||||
产生该对话目标的原因:{reasoning}
|
||||
|
||||
请分析以下聊天记录,并根据你的性格特征评估该目标是否已经达到,或者你是否希望停止该次对话。
|
||||
聊天记录:
|
||||
{chat_history_text}
|
||||
请以JSON格式输出,包含以下字段:
|
||||
1. goal_achieved: 对话目标是否已经达到(true/false)
|
||||
2. stop_conversation: 是否希望停止该次对话(true/false)
|
||||
3. reason: 为什么希望停止该次对话(简要解释)
|
||||
|
||||
输出格式示例:
|
||||
{{
|
||||
"goal_achieved": true,
|
||||
"stop_conversation": false,
|
||||
"reason": "虽然目标已达成,但对话仍然有继续的价值"
|
||||
}}"""
|
||||
|
||||
try:
|
||||
generation_result = await self.llm.generate_response(prompt)
|
||||
content = generation_result.response
|
||||
logger.debug(f"[私聊][{self.private_name}]LLM原始返回内容: {content}")
|
||||
|
||||
# 尝试解析JSON
|
||||
success, result = get_items_from_json(
|
||||
content,
|
||||
self.private_name,
|
||||
"goal_achieved",
|
||||
"stop_conversation",
|
||||
"reason",
|
||||
required_types={"goal_achieved": bool, "stop_conversation": bool, "reason": str},
|
||||
)
|
||||
|
||||
if not success:
|
||||
logger.error(f"[私聊][{self.private_name}]无法解析对话分析结果JSON")
|
||||
return False, False, "解析结果失败"
|
||||
|
||||
goal_achieved = result["goal_achieved"]
|
||||
stop_conversation = result["stop_conversation"]
|
||||
reason = result["reason"]
|
||||
|
||||
return goal_achieved, stop_conversation, reason
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]分析对话状态时出错: {str(e)}")
|
||||
return False, False, f"分析出错: {str(e)}"
|
||||
|
||||
|
||||
# 先注释掉,万一以后出问题了还能开回来(((
|
||||
# class DirectMessageSender:
|
||||
# """直接发送消息到平台的发送器"""
|
||||
|
||||
# def __init__(self, private_name: str):
|
||||
# self.logger = get_module_logger("direct_sender")
|
||||
# self.storage = MessageStorage()
|
||||
# self.private_name = private_name
|
||||
|
||||
# async def send_via_ws(self, message: MessageSending) -> None:
|
||||
# try:
|
||||
# await global_api.send_message(message)
|
||||
# except Exception as e:
|
||||
# raise ValueError(f"未找到平台:{message.message_info.platform} 的url配置,请检查配置文件") from e
|
||||
|
||||
# async def send_message(
|
||||
# self,
|
||||
# chat_stream: ChatStream,
|
||||
# content: str,
|
||||
# reply_to_message: Optional[Message] = None,
|
||||
# ) -> None:
|
||||
# """直接发送消息到平台
|
||||
|
||||
# Args:
|
||||
# chat_stream: 聊天流
|
||||
# content: 消息内容
|
||||
# reply_to_message: 要回复的消息
|
||||
# """
|
||||
# # 构建消息对象
|
||||
# message_segment = Seg(type="text", data=content)
|
||||
# bot_user_info = UserInfo(
|
||||
# user_id=global_config.BOT_QQ,
|
||||
# user_nickname=global_config.BOT_NICKNAME,
|
||||
# platform=chat_stream.platform,
|
||||
# )
|
||||
|
||||
# message = MessageSending(
|
||||
# message_id=f"dm{round(time.time(), 2)}",
|
||||
# chat_stream=chat_stream,
|
||||
# bot_user_info=bot_user_info,
|
||||
# sender_info=reply_to_message.message_info.user_info if reply_to_message else None,
|
||||
# message_segment=message_segment,
|
||||
# reply=reply_to_message,
|
||||
# is_head=True,
|
||||
# is_emoji=False,
|
||||
# thinking_start_time=time.time(),
|
||||
# )
|
||||
|
||||
# # 处理消息
|
||||
# await message.process()
|
||||
|
||||
# _message_json = message.to_dict()
|
||||
|
||||
# # 发送消息
|
||||
# try:
|
||||
# await self.send_via_ws(message)
|
||||
# await self.storage.store_message(message, chat_stream)
|
||||
# logger.success(f"[私聊][{self.private_name}]PFC消息已发送: {content}")
|
||||
# except Exception as e:
|
||||
# logger.error(f"[私聊][{self.private_name}]PFC消息发送失败: {str(e)}")
|
||||
@@ -1,77 +0,0 @@
|
||||
from typing import List, Tuple, Dict, Any
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# NOTE: HippocampusManager doesn't exist in v0.12.2 - memory system was redesigned
|
||||
# from src.plugins.memory_system.Hippocampus import HippocampusManager
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.chat.knowledge import qa_manager
|
||||
|
||||
logger = get_logger("knowledge_fetcher")
|
||||
|
||||
|
||||
class KnowledgeFetcher:
|
||||
"""知识调取器"""
|
||||
|
||||
def __init__(self, private_name: str):
|
||||
self.llm = LLMServiceClient(task_name="utils")
|
||||
self.private_name = private_name
|
||||
|
||||
def _lpmm_get_knowledge(self, query: str) -> str:
|
||||
"""获取相关知识
|
||||
|
||||
Args:
|
||||
query: 查询内容
|
||||
|
||||
Returns:
|
||||
str: 构造好的,带相关度的知识
|
||||
"""
|
||||
|
||||
logger.debug(f"[私聊][{self.private_name}]正在从LPMM知识库中获取知识")
|
||||
try:
|
||||
knowledge_info = qa_manager.get_knowledge(query)
|
||||
logger.debug(f"[私聊][{self.private_name}]LPMM知识库查询结果: {knowledge_info:150}")
|
||||
return knowledge_info
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]LPMM知识库搜索工具执行失败: {str(e)}")
|
||||
return "未找到匹配的知识"
|
||||
|
||||
async def fetch(self, query: str, chat_history: List[Dict[str, Any]]) -> Tuple[str, str]:
|
||||
"""获取相关知识
|
||||
|
||||
Args:
|
||||
query: 查询内容
|
||||
chat_history: 聊天历史 (PFC dict format)
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: (获取的知识, 知识来源)
|
||||
"""
|
||||
_ = chat_history
|
||||
|
||||
# NOTE: Hippocampus memory system was redesigned in v0.12.2
|
||||
# The old get_memory_from_text API no longer exists
|
||||
# For now, we'll skip the memory retrieval part and only use LPMM knowledge
|
||||
# TODO: Integrate with new memory system if needed
|
||||
knowledge_text = ""
|
||||
sources_text = "无记忆匹配" # 默认值
|
||||
|
||||
# # 从记忆中获取相关知识 (DISABLED - old Hippocampus API)
|
||||
# related_memory = await HippocampusManager.get_instance().get_memory_from_text(
|
||||
# text=f"{query}\n{chat_history_text}",
|
||||
# max_memory_num=3,
|
||||
# max_memory_length=2,
|
||||
# max_depth=3,
|
||||
# fast_retrieval=False,
|
||||
# )
|
||||
# if related_memory:
|
||||
# sources = []
|
||||
# for memory in related_memory:
|
||||
# knowledge_text += memory[1] + "\n"
|
||||
# sources.append(f"记忆片段{memory[0]}")
|
||||
# knowledge_text = knowledge_text.strip()
|
||||
# sources_text = ",".join(sources)
|
||||
|
||||
knowledge_text += "\n现在有以下**知识**可供参考:\n "
|
||||
knowledge_text += self._lpmm_get_knowledge(query)
|
||||
knowledge_text += "\n请记住这些**知识**,并根据**知识**回答问题。\n"
|
||||
|
||||
return knowledge_text or "未找到相关知识", sources_text or "无记忆匹配"
|
||||
@@ -1,115 +0,0 @@
|
||||
import time
|
||||
from typing import Dict, Optional
|
||||
from src.common.logger import get_logger
|
||||
from .conversation import Conversation
|
||||
import traceback
|
||||
|
||||
logger = get_logger("pfc_manager")
|
||||
|
||||
|
||||
class PFCManager:
|
||||
"""PFC对话管理器,负责管理所有对话实例"""
|
||||
|
||||
# 单例模式
|
||||
_instance = None
|
||||
|
||||
# 会话实例管理
|
||||
_instances: Dict[str, Conversation] = {}
|
||||
_initializing: Dict[str, bool] = {}
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "PFCManager":
|
||||
"""获取管理器单例
|
||||
|
||||
Returns:
|
||||
PFCManager: 管理器实例
|
||||
"""
|
||||
if cls._instance is None:
|
||||
cls._instance = PFCManager()
|
||||
return cls._instance
|
||||
|
||||
async def get_or_create_conversation(self, stream_id: str, private_name: str) -> Optional[Conversation]:
|
||||
"""获取或创建对话实例
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
private_name: 私聊名称
|
||||
|
||||
Returns:
|
||||
Optional[Conversation]: 对话实例,创建失败则返回None
|
||||
"""
|
||||
# 检查是否已经有实例
|
||||
if stream_id in self._initializing and self._initializing[stream_id]:
|
||||
logger.debug(f"[私聊][{private_name}]会话实例正在初始化中: {stream_id}")
|
||||
return None
|
||||
|
||||
if stream_id in self._instances and self._instances[stream_id].should_continue:
|
||||
logger.debug(f"[私聊][{private_name}]使用现有会话实例: {stream_id}")
|
||||
return self._instances[stream_id]
|
||||
if stream_id in self._instances:
|
||||
instance = self._instances[stream_id]
|
||||
if (
|
||||
hasattr(instance, "ignore_until_timestamp")
|
||||
and instance.ignore_until_timestamp
|
||||
and time.time() < instance.ignore_until_timestamp
|
||||
):
|
||||
logger.debug(f"[私聊][{private_name}]会话实例当前处于忽略状态: {stream_id}")
|
||||
# 返回 None 阻止交互。或者可以返回实例但标记它被忽略了喵?
|
||||
# 还是返回 None 吧喵。
|
||||
return None
|
||||
|
||||
# 检查 should_continue 状态
|
||||
if instance.should_continue:
|
||||
logger.debug(f"[私聊][{private_name}]使用现有会话实例: {stream_id}")
|
||||
return instance
|
||||
# else: 实例存在但不应继续
|
||||
try:
|
||||
# 创建新实例
|
||||
logger.info(f"[私聊][{private_name}]创建新的对话实例: {stream_id}")
|
||||
self._initializing[stream_id] = True
|
||||
# 创建实例
|
||||
conversation_instance = Conversation(stream_id, private_name)
|
||||
self._instances[stream_id] = conversation_instance
|
||||
|
||||
# 启动实例初始化
|
||||
await self._initialize_conversation(conversation_instance)
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{private_name}]创建会话实例失败: {stream_id}, 错误: {e}")
|
||||
return None
|
||||
|
||||
return conversation_instance
|
||||
|
||||
async def _initialize_conversation(self, conversation: Conversation):
|
||||
"""初始化会话实例
|
||||
|
||||
Args:
|
||||
conversation: 要初始化的会话实例
|
||||
"""
|
||||
stream_id = conversation.stream_id
|
||||
private_name = conversation.private_name
|
||||
|
||||
try:
|
||||
logger.info(f"[私聊][{private_name}]开始初始化会话实例: {stream_id}")
|
||||
# 启动初始化流程
|
||||
await conversation._initialize()
|
||||
|
||||
# 标记初始化完成
|
||||
self._initializing[stream_id] = False
|
||||
|
||||
logger.info(f"[私聊][{private_name}]会话实例 {stream_id} 初始化完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{private_name}]管理器初始化会话实例失败: {stream_id}, 错误: {e}")
|
||||
logger.error(f"[私聊][{private_name}]{traceback.format_exc()}")
|
||||
# 清理失败的初始化
|
||||
|
||||
async def get_conversation(self, stream_id: str) -> Optional[Conversation]:
|
||||
"""获取已存在的会话实例
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
|
||||
Returns:
|
||||
Optional[Conversation]: 会话实例,不存在则返回None
|
||||
"""
|
||||
return self._instances.get(stream_id)
|
||||
@@ -1,23 +0,0 @@
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
|
||||
class ConversationState(Enum):
|
||||
"""对话状态"""
|
||||
|
||||
INIT = "初始化"
|
||||
RETHINKING = "重新思考"
|
||||
ANALYZING = "分析历史"
|
||||
PLANNING = "规划目标"
|
||||
GENERATING = "生成回复"
|
||||
CHECKING = "检查回复"
|
||||
SENDING = "发送消息"
|
||||
FETCHING = "获取知识"
|
||||
WAITING = "等待"
|
||||
LISTENING = "倾听"
|
||||
ENDED = "结束"
|
||||
JUDGING = "判断"
|
||||
IGNORED = "屏蔽"
|
||||
|
||||
|
||||
ActionType = Literal["direct_reply", "fetch_knowledge", "wait"]
|
||||
@@ -1,127 +0,0 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Dict, Any, Optional, Tuple, List, Union
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("pfc_utils")
|
||||
|
||||
|
||||
def get_items_from_json(
|
||||
content: str,
|
||||
private_name: str,
|
||||
*items: str,
|
||||
default_values: Optional[Dict[str, Any]] = None,
|
||||
required_types: Optional[Dict[str, type]] = None,
|
||||
allow_array: bool = True,
|
||||
) -> Tuple[bool, Union[Dict[str, Any], List[Dict[str, Any]]]]:
|
||||
"""从文本中提取JSON内容并获取指定字段
|
||||
|
||||
Args:
|
||||
content: 包含JSON的文本
|
||||
private_name: 私聊名称
|
||||
*items: 要提取的字段名
|
||||
default_values: 字段的默认值,格式为 {字段名: 默认值}
|
||||
required_types: 字段的必需类型,格式为 {字段名: 类型}
|
||||
allow_array: 是否允许解析JSON数组
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Union[Dict[str, Any], List[Dict[str, Any]]]]: (是否成功, 提取的字段字典或字典列表)
|
||||
"""
|
||||
content = content.strip()
|
||||
result = {}
|
||||
|
||||
# 设置默认值
|
||||
if default_values:
|
||||
result.update(default_values)
|
||||
|
||||
# 首先尝试解析为JSON数组
|
||||
if allow_array:
|
||||
try:
|
||||
# 尝试找到文本中的JSON数组
|
||||
array_pattern = r"\[[\s\S]*\]"
|
||||
array_match = re.search(array_pattern, content)
|
||||
if array_match:
|
||||
array_content = array_match.group()
|
||||
json_array = json.loads(array_content)
|
||||
|
||||
# 确认是数组类型
|
||||
if isinstance(json_array, list):
|
||||
# 验证数组中的每个项目是否包含所有必需字段
|
||||
valid_items = []
|
||||
for item in json_array:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
# 检查是否有所有必需字段
|
||||
if all(field in item for field in items):
|
||||
# 验证字段类型
|
||||
if required_types:
|
||||
type_valid = True
|
||||
for field, expected_type in required_types.items():
|
||||
if field in item and not isinstance(item[field], expected_type):
|
||||
type_valid = False
|
||||
break
|
||||
|
||||
if not type_valid:
|
||||
continue
|
||||
|
||||
# 验证字符串字段不为空
|
||||
string_valid = True
|
||||
for field in items:
|
||||
if isinstance(item[field], str) and not item[field].strip():
|
||||
string_valid = False
|
||||
break
|
||||
|
||||
if not string_valid:
|
||||
continue
|
||||
|
||||
valid_items.append(item)
|
||||
|
||||
if valid_items:
|
||||
return True, valid_items
|
||||
except json.JSONDecodeError:
|
||||
logger.debug(f"[私聊][{private_name}]JSON数组解析失败,尝试解析单个JSON对象")
|
||||
except Exception as e:
|
||||
logger.debug(f"[私聊][{private_name}]尝试解析JSON数组时出错: {str(e)}")
|
||||
|
||||
# 尝试解析JSON对象
|
||||
try:
|
||||
json_data = json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
# 如果直接解析失败,尝试查找和提取JSON部分
|
||||
json_pattern = r"\{[^{}]*\}"
|
||||
json_match = re.search(json_pattern, content)
|
||||
if json_match:
|
||||
try:
|
||||
json_data = json.loads(json_match.group())
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"[私聊][{private_name}]提取的JSON内容解析失败")
|
||||
return False, result
|
||||
else:
|
||||
logger.error(f"[私聊][{private_name}]无法在返回内容中找到有效的JSON")
|
||||
return False, result
|
||||
|
||||
# 提取字段
|
||||
for item in items:
|
||||
if item in json_data:
|
||||
result[item] = json_data[item]
|
||||
|
||||
# 验证必需字段
|
||||
if not all(item in result for item in items):
|
||||
logger.error(f"[私聊][{private_name}]JSON缺少必要字段,实际内容: {json_data}")
|
||||
return False, result
|
||||
|
||||
# 验证字段类型
|
||||
if required_types:
|
||||
for field, expected_type in required_types.items():
|
||||
if field in result and not isinstance(result[field], expected_type):
|
||||
logger.error(f"[私聊][{private_name}]{field} 必须是 {expected_type.__name__} 类型")
|
||||
return False, result
|
||||
|
||||
# 验证字符串字段不为空
|
||||
for field in items:
|
||||
if isinstance(result[field], str) and not result[field].strip():
|
||||
logger.error(f"[私聊][{private_name}]{field} 不能为空")
|
||||
return False, result
|
||||
|
||||
return True, result
|
||||
@@ -1,199 +0,0 @@
|
||||
import json
|
||||
import random
|
||||
from typing import Tuple, List, Dict, Any
|
||||
from src.common.logger import get_logger
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.config.config import global_config
|
||||
from .chat_observer import ChatObserver
|
||||
from maim_message import UserInfo
|
||||
|
||||
logger = get_logger("reply_checker")
|
||||
|
||||
|
||||
class ReplyChecker:
|
||||
"""回复检查器"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
self.llm = LLMServiceClient(task_name="utils", request_type="reply_check")
|
||||
self.personality_info = self._get_personality_prompt()
|
||||
self.name = global_config.bot.nickname
|
||||
self.private_name = private_name
|
||||
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
|
||||
self.max_retries = 3 # 最大重试次数
|
||||
|
||||
def _get_personality_prompt(self) -> str:
|
||||
"""获取个性提示信息"""
|
||||
prompt_personality = global_config.personality.personality
|
||||
|
||||
# 检查是否需要随机替换为状态
|
||||
if (
|
||||
global_config.personality.states
|
||||
and global_config.personality.state_probability > 0
|
||||
and random.random() < global_config.personality.state_probability
|
||||
):
|
||||
prompt_personality = random.choice(global_config.personality.states)
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
return f"你的名字是{bot_name},你{prompt_personality};"
|
||||
|
||||
async def check(
|
||||
self, reply: str, goal: str, chat_history: List[Dict[str, Any]], chat_history_text: str, retry_count: int = 0
|
||||
) -> Tuple[bool, str, bool]:
|
||||
"""检查生成的回复是否合适
|
||||
|
||||
Args:
|
||||
reply: 生成的回复
|
||||
goal: 对话目标
|
||||
chat_history: 对话历史记录
|
||||
chat_history_text: 对话历史记录文本
|
||||
retry_count: 当前重试次数
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str, bool]: (是否合适, 原因, 是否需要重新规划)
|
||||
"""
|
||||
# 不再从 observer 获取,直接使用传入的 chat_history
|
||||
# messages = self.chat_observer.get_cached_messages(limit=20)
|
||||
try:
|
||||
# 筛选出最近由 Bot 自己发送的消息
|
||||
bot_messages = []
|
||||
for msg in reversed(chat_history):
|
||||
user_info = UserInfo.from_dict(msg.get("user_info", {}))
|
||||
if str(user_info.user_id) == str(global_config.bot.qq_account):
|
||||
bot_messages.append(msg.get("processed_plain_text", ""))
|
||||
if len(bot_messages) >= 2: # 只和最近的两条比较
|
||||
break
|
||||
# 进行比较
|
||||
if bot_messages:
|
||||
# 可以用简单比较,或者更复杂的相似度库 (如 difflib)
|
||||
# 简单比较:是否完全相同
|
||||
if reply == bot_messages[0]: # 和最近一条完全一样
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ReplyChecker 检测到回复与上一条 Bot 消息完全相同: '{reply}'"
|
||||
)
|
||||
return (
|
||||
False,
|
||||
"被逻辑检查拒绝:回复内容与你上一条发言完全相同,可以选择深入话题或寻找其它话题或等待",
|
||||
True,
|
||||
) # 不合适,需要返回至决策层
|
||||
# 2. 相似度检查 (如果精确匹配未通过)
|
||||
import difflib # 导入 difflib 库
|
||||
|
||||
# 计算编辑距离相似度,ratio() 返回 0 到 1 之间的浮点数
|
||||
similarity_ratio = difflib.SequenceMatcher(None, reply, bot_messages[0]).ratio()
|
||||
logger.debug(f"[私聊][{self.private_name}]ReplyChecker - 相似度: {similarity_ratio:.2f}")
|
||||
|
||||
# 设置一个相似度阈值
|
||||
similarity_threshold = 0.9
|
||||
if similarity_ratio > similarity_threshold:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ReplyChecker 检测到回复与上一条 Bot 消息高度相似 (相似度 {similarity_ratio:.2f}): '{reply}'"
|
||||
)
|
||||
return (
|
||||
False,
|
||||
f"被逻辑检查拒绝:回复内容与你上一条发言高度相似 (相似度 {similarity_ratio:.2f}),可以选择深入话题或寻找其它话题或等待。",
|
||||
True,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
logger.error(f"[私聊][{self.private_name}]检查回复时出错: 类型={type(e)}, 值={e}")
|
||||
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}") # 打印详细的回溯信息
|
||||
|
||||
prompt = f"""你是一个聊天逻辑检查器,请检查以下回复或消息是否合适:
|
||||
|
||||
当前对话目标:{goal}
|
||||
最新的对话记录:
|
||||
{chat_history_text}
|
||||
|
||||
待检查的消息:
|
||||
{reply}
|
||||
|
||||
请结合聊天记录检查以下几点:
|
||||
1. 这条消息是否依然符合当前对话目标和实现方式
|
||||
2. 这条消息是否与最新的对话记录保持一致性
|
||||
3. 是否存在重复发言,或重复表达同质内容(尤其是只是换一种方式表达了相同的含义)
|
||||
4. 这条消息是否包含违规内容(例如血腥暴力,政治敏感等)
|
||||
5. 这条消息是否以发送者的角度发言(不要让发送者自己回复自己的消息)
|
||||
6. 这条消息是否通俗易懂
|
||||
7. 这条消息是否有些多余,例如在对方没有回复的情况下,依然连续多次“消息轰炸”(尤其是已经连续发送3条信息的情况,这很可能不合理,需要着重判断)
|
||||
8. 这条消息是否使用了完全没必要的修辞
|
||||
9. 这条消息是否逻辑通顺
|
||||
10. 这条消息是否太过冗长了(通常私聊的每条消息长度在20字以内,除非特殊情况)
|
||||
11. 在连续多次发送消息的情况下,这条消息是否衔接自然,会不会显得奇怪(例如连续两条消息中部分内容重叠)
|
||||
|
||||
请以JSON格式输出,包含以下字段:
|
||||
1. suitable: 是否合适 (true/false)
|
||||
2. reason: 原因说明
|
||||
3. need_replan: 是否需要重新决策 (true/false),当你认为此时已经不适合发消息,需要规划其它行动时,设为true
|
||||
|
||||
输出格式示例:
|
||||
{{
|
||||
"suitable": true,
|
||||
"reason": "回复符合要求,虽然有可能略微偏离目标,但是整体内容流畅得体",
|
||||
"need_replan": false
|
||||
}}
|
||||
|
||||
注意:请严格按照JSON格式输出,不要包含任何其他内容。"""
|
||||
|
||||
try:
|
||||
generation_result = await self.llm.generate_response(prompt)
|
||||
content = generation_result.response
|
||||
logger.debug(f"[私聊][{self.private_name}]检查回复的原始返回: {content}")
|
||||
|
||||
# 清理内容,尝试提取JSON部分
|
||||
content = content.strip()
|
||||
try:
|
||||
# 尝试直接解析
|
||||
result: dict = json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
# 如果直接解析失败,尝试查找和提取JSON部分
|
||||
import re
|
||||
|
||||
json_pattern = r"\{[^{}]*\}"
|
||||
json_match = re.search(json_pattern, content)
|
||||
if json_match:
|
||||
try:
|
||||
result: dict = json.loads(json_match.group())
|
||||
except json.JSONDecodeError:
|
||||
# 如果JSON解析失败,尝试从文本中提取结果
|
||||
is_suitable = "不合适" not in content.lower() and "违规" not in content.lower()
|
||||
reason = content[:100] if content else "无法解析响应"
|
||||
need_replan = "重新规划" in content.lower() or "目标不适合" in content.lower()
|
||||
return is_suitable, reason, need_replan
|
||||
else:
|
||||
# 如果找不到JSON,从文本中判断
|
||||
is_suitable = "不合适" not in content.lower() and "违规" not in content.lower()
|
||||
reason = content[:100] if content else "无法解析响应"
|
||||
need_replan = "重新规划" in content.lower() or "目标不适合" in content.lower()
|
||||
return is_suitable, reason, need_replan
|
||||
|
||||
# 验证JSON字段
|
||||
suitable = result.get("suitable", None)
|
||||
reason = result.get("reason", "未提供原因")
|
||||
need_replan = result.get("need_replan", False)
|
||||
|
||||
# 如果suitable字段是字符串,转换为布尔值
|
||||
if isinstance(suitable, str):
|
||||
suitable = suitable.lower() == "true"
|
||||
|
||||
# 如果suitable字段不存在或不是布尔值,从reason中判断
|
||||
if suitable is None:
|
||||
suitable = "不合适" not in reason.lower() and "违规" not in reason.lower()
|
||||
|
||||
# 如果不合适且未达到最大重试次数,返回需要重试
|
||||
if not suitable and retry_count < self.max_retries:
|
||||
return False, reason, False
|
||||
|
||||
# 如果不合适且已达到最大重试次数,返回需要重新规划
|
||||
if not suitable and retry_count >= self.max_retries:
|
||||
return False, f"多次重试后仍不合适: {reason}", True
|
||||
|
||||
return suitable, reason, need_replan
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]检查回复时出错: {e}")
|
||||
# 如果出错且已达到最大重试次数,建议重新规划
|
||||
if retry_count >= self.max_retries:
|
||||
return False, "多次检查失败,建议重新规划", True
|
||||
return False, f"检查过程出错,建议重试: {str(e)}", False
|
||||
@@ -1,243 +0,0 @@
|
||||
from typing import Tuple, List, Dict, Any
|
||||
from src.common.logger import get_logger
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.config.config import global_config
|
||||
import random
|
||||
from .chat_observer import ChatObserver
|
||||
from .reply_checker import ReplyChecker
|
||||
from src.services.message_service import build_readable_messages
|
||||
|
||||
from .observation_info import ObservationInfo, dict_to_session_message
|
||||
from .conversation_info import ConversationInfo
|
||||
|
||||
logger = get_logger("reply_generator")
|
||||
|
||||
# --- 定义 Prompt 模板 ---
|
||||
|
||||
# Prompt for direct_reply (首次回复)
|
||||
PROMPT_DIRECT_REPLY = """{persona_text}。现在你在参与一场QQ私聊,请根据以下信息生成一条回复:
|
||||
|
||||
当前对话目标:{goals_str}
|
||||
|
||||
{knowledge_info_str}
|
||||
|
||||
最近的聊天记录:
|
||||
{chat_history_text}
|
||||
|
||||
|
||||
请根据上述信息,结合聊天记录,回复对方。该回复应该:
|
||||
1. 符合对话目标,以"你"的角度发言(不要自己与自己对话!)
|
||||
2. 符合你的性格特征和身份细节
|
||||
3. 通俗易懂,自然流畅,像正常聊天一样,简短(通常20字以内,除非特殊情况)
|
||||
4. 可以适当利用相关知识,但不要生硬引用
|
||||
5. 自然、得体,结合聊天记录逻辑合理,且没有重复表达同质内容
|
||||
|
||||
请注意把握聊天内容,不要回复的太有条理,可以有个性。请分清"你"和对方说的话,不要把"你"说的话当做对方说的话,这是你自己说的话。
|
||||
可以回复得自然随意自然一些,就像真人一样,注意把握聊天内容,整体风格可以平和、简短,不要刻意突出自身学科背景,不要说你说过的话,可以简短,多简短都可以,但是避免冗长。
|
||||
请你注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
|
||||
不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。
|
||||
|
||||
请直接输出回复内容,不需要任何额外格式。"""
|
||||
|
||||
# Prompt for send_new_message (追问/补充)
|
||||
PROMPT_SEND_NEW_MESSAGE = """{persona_text}。现在你在参与一场QQ私聊,**刚刚你已经发送了一条或多条消息**,现在请根据以下信息再发一条新消息:
|
||||
|
||||
当前对话目标:{goals_str}
|
||||
|
||||
{knowledge_info_str}
|
||||
|
||||
最近的聊天记录:
|
||||
{chat_history_text}
|
||||
|
||||
|
||||
请根据上述信息,结合聊天记录,继续发一条新消息(例如对之前消息的补充,深入话题,或追问等等)。该消息应该:
|
||||
1. 符合对话目标,以"你"的角度发言(不要自己与自己对话!)
|
||||
2. 符合你的性格特征和身份细节
|
||||
3. 通俗易懂,自然流畅,像正常聊天一样,简短(通常20字以内,除非特殊情况)
|
||||
4. 可以适当利用相关知识,但不要生硬引用
|
||||
5. 跟之前你发的消息自然的衔接,逻辑合理,且没有重复表达同质内容或部分重叠内容
|
||||
|
||||
请注意把握聊天内容,不用太有条理,可以有个性。请分清"你"和对方说的话,不要把"你"说的话当做对方说的话,这是你自己说的话。
|
||||
这条消息可以自然随意自然一些,就像真人一样,注意把握聊天内容,整体风格可以平和、简短,不要刻意突出自身学科背景,不要说你说过的话,可以简短,多简短都可以,但是避免冗长。
|
||||
请你注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出消息内容。
|
||||
不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。
|
||||
|
||||
请直接输出回复内容,不需要任何额外格式。"""
|
||||
|
||||
# Prompt for say_goodbye (告别语生成)
|
||||
PROMPT_FAREWELL = """{persona_text}。你在参与一场 QQ 私聊,现在对话似乎已经结束,你决定再发一条最后的消息来圆满结束。
|
||||
|
||||
最近的聊天记录:
|
||||
{chat_history_text}
|
||||
|
||||
请根据上述信息,结合聊天记录,构思一条**简短、自然、符合你人设**的最后的消息。
|
||||
这条消息应该:
|
||||
1. 从你自己的角度发言。
|
||||
2. 符合你的性格特征和身份细节。
|
||||
3. 通俗易懂,自然流畅,通常很简短。
|
||||
4. 自然地为这场对话画上句号,避免开启新话题或显得冗长、刻意。
|
||||
|
||||
请像真人一样随意自然,**简洁是关键**。
|
||||
不要输出多余内容(包括前后缀、冒号、引号、括号、表情包、at或@等)。
|
||||
|
||||
请直接输出最终的告别消息内容,不需要任何额外格式。"""
|
||||
|
||||
|
||||
class ReplyGenerator:
|
||||
"""回复生成器"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
self.llm = LLMServiceClient(
|
||||
task_name="replyer",
|
||||
request_type="reply_generation",
|
||||
)
|
||||
self.personality_info = self._get_personality_prompt()
|
||||
self.name = global_config.bot.nickname
|
||||
self.private_name = private_name
|
||||
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
|
||||
self.reply_checker = ReplyChecker(stream_id, private_name)
|
||||
|
||||
def _get_personality_prompt(self) -> str:
|
||||
"""获取个性提示信息"""
|
||||
prompt_personality = global_config.personality.personality
|
||||
|
||||
# 检查是否需要随机替换为状态
|
||||
if (
|
||||
global_config.personality.states
|
||||
and global_config.personality.state_probability > 0
|
||||
and random.random() < global_config.personality.state_probability
|
||||
):
|
||||
prompt_personality = random.choice(global_config.personality.states)
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
return f"你的名字是{bot_name},你{prompt_personality};"
|
||||
|
||||
# 修改 generate 方法签名,增加 action_type 参数
|
||||
async def generate(
|
||||
self, observation_info: ObservationInfo, conversation_info: ConversationInfo, action_type: str
|
||||
) -> str:
|
||||
"""生成回复
|
||||
|
||||
Args:
|
||||
observation_info: 观察信息
|
||||
conversation_info: 对话信息
|
||||
action_type: 当前执行的动作类型 ('direct_reply' 或 'send_new_message')
|
||||
|
||||
Returns:
|
||||
str: 生成的回复
|
||||
"""
|
||||
# 构建提示词
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]开始生成回复 (动作类型: {action_type}):当前目标: {conversation_info.goal_list}"
|
||||
)
|
||||
|
||||
# --- 构建通用 Prompt 参数 ---
|
||||
# (这部分逻辑基本不变)
|
||||
|
||||
# 构建对话目标 (goals_str)
|
||||
goals_str = ""
|
||||
if conversation_info.goal_list:
|
||||
for goal_reason in conversation_info.goal_list:
|
||||
if isinstance(goal_reason, dict):
|
||||
goal = goal_reason.get("goal", "目标内容缺失")
|
||||
reasoning = goal_reason.get("reasoning", "没有明确原因")
|
||||
else:
|
||||
goal = str(goal_reason)
|
||||
reasoning = "没有明确原因"
|
||||
|
||||
goal = str(goal) if goal is not None else "目标内容缺失"
|
||||
reasoning = str(reasoning) if reasoning is not None else "没有明确原因"
|
||||
goals_str += f"- 目标:{goal}\n 原因:{reasoning}\n"
|
||||
else:
|
||||
goals_str = "- 目前没有明确对话目标\n" # 简化无目标情况
|
||||
|
||||
# --- 新增:构建知识信息字符串 ---
|
||||
knowledge_info_str = "【供参考的相关知识和记忆】\n" # 稍微改下标题,表明是供参考
|
||||
try:
|
||||
# 检查 conversation_info 是否有 knowledge_list 并且不为空
|
||||
if hasattr(conversation_info, "knowledge_list") and conversation_info.knowledge_list:
|
||||
# 最多只显示最近的 5 条知识
|
||||
recent_knowledge = conversation_info.knowledge_list[-5:]
|
||||
for i, knowledge_item in enumerate(recent_knowledge):
|
||||
if isinstance(knowledge_item, dict):
|
||||
query = knowledge_item.get("query", "未知查询")
|
||||
knowledge = knowledge_item.get("knowledge", "无知识内容")
|
||||
source = knowledge_item.get("source", "未知来源")
|
||||
# 只取知识内容的前 2000 个字
|
||||
knowledge_snippet = f"{knowledge[:2000]}..." if len(knowledge) > 2000 else knowledge
|
||||
knowledge_info_str += (
|
||||
f"{i + 1}. 关于 '{query}' (来源: {source}): {knowledge_snippet}\n" # 格式微调,更简洁
|
||||
)
|
||||
else:
|
||||
knowledge_info_str += f"{i + 1}. 发现一条格式不正确的知识记录。\n"
|
||||
|
||||
if not recent_knowledge:
|
||||
knowledge_info_str += "- 暂无。\n" # 更简洁的提示
|
||||
|
||||
else:
|
||||
knowledge_info_str += "- 暂无。\n"
|
||||
except AttributeError:
|
||||
logger.warning(f"[私聊][{self.private_name}]ConversationInfo 对象可能缺少 knowledge_list 属性。")
|
||||
knowledge_info_str += "- 获取知识列表时出错。\n"
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]构建知识信息字符串时出错: {e}")
|
||||
knowledge_info_str += "- 处理知识列表时出错。\n"
|
||||
|
||||
# 获取聊天历史记录 (chat_history_text)
|
||||
chat_history_text = observation_info.chat_history_str
|
||||
if observation_info.new_messages_count > 0 and observation_info.unprocessed_messages:
|
||||
new_messages_list = observation_info.unprocessed_messages
|
||||
session_messages = [dict_to_session_message(m) for m in new_messages_list]
|
||||
new_messages_str = build_readable_messages(
|
||||
session_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
)
|
||||
chat_history_text += f"\n--- 以下是 {observation_info.new_messages_count} 条新消息 ---\n{new_messages_str}"
|
||||
elif not chat_history_text:
|
||||
chat_history_text = "还没有聊天记录。"
|
||||
|
||||
# 构建 Persona 文本 (persona_text)
|
||||
persona_text = f"你的名字是{self.name},{self.personality_info}。"
|
||||
|
||||
# --- 选择 Prompt ---
|
||||
if action_type == "send_new_message":
|
||||
prompt_template = PROMPT_SEND_NEW_MESSAGE
|
||||
logger.info(f"[私聊][{self.private_name}]使用 PROMPT_SEND_NEW_MESSAGE (追问生成)")
|
||||
elif action_type == "say_goodbye": # 处理告别动作
|
||||
prompt_template = PROMPT_FAREWELL
|
||||
logger.info(f"[私聊][{self.private_name}]使用 PROMPT_FAREWELL (告别语生成)")
|
||||
else: # 默认使用 direct_reply 的 prompt (包括 'direct_reply' 或其他未明确处理的类型)
|
||||
prompt_template = PROMPT_DIRECT_REPLY
|
||||
logger.info(f"[私聊][{self.private_name}]使用 PROMPT_DIRECT_REPLY (首次/非连续回复生成)")
|
||||
|
||||
# --- 格式化最终的 Prompt ---
|
||||
prompt = prompt_template.format(
|
||||
persona_text=persona_text,
|
||||
goals_str=goals_str,
|
||||
chat_history_text=chat_history_text,
|
||||
knowledge_info_str=knowledge_info_str,
|
||||
)
|
||||
|
||||
# --- 调用 LLM 生成 ---
|
||||
logger.debug(f"[私聊][{self.private_name}]发送到LLM的生成提示词:\n------\n{prompt}\n------")
|
||||
try:
|
||||
generation_result = await self.llm.generate_response(prompt)
|
||||
content = generation_result.response
|
||||
logger.debug(f"[私聊][{self.private_name}]生成的回复: {content}")
|
||||
# 移除旧的检查新消息逻辑,这应该由 conversation 控制流处理
|
||||
return content
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]生成回复时出错: {e}")
|
||||
return "抱歉,我现在有点混乱,让我重新思考一下..."
|
||||
|
||||
# check_reply 方法保持不变
|
||||
async def check_reply(
|
||||
self, reply: str, goal: str, chat_history: List[Dict[str, Any]], chat_history_str: str, retry_count: int = 0
|
||||
) -> Tuple[bool, str, bool]:
|
||||
"""检查回复是否合适
|
||||
(此方法逻辑保持不变)
|
||||
"""
|
||||
return await self.reply_checker.check(reply, goal, chat_history, chat_history_str, retry_count)
|
||||
@@ -1,79 +0,0 @@
|
||||
from src.common.logger import get_logger
|
||||
from .chat_observer import ChatObserver
|
||||
from .conversation_info import ConversationInfo
|
||||
|
||||
# from src.individuality.individuality import Individuality # 不再需要
|
||||
from src.config.config import global_config
|
||||
import time
|
||||
import asyncio
|
||||
|
||||
logger = get_logger("waiter")
|
||||
|
||||
# --- 在这里设定你想要的超时时间(秒) ---
|
||||
# 例如: 120 秒 = 2 分钟
|
||||
DESIRED_TIMEOUT_SECONDS = 300
|
||||
|
||||
|
||||
class Waiter:
|
||||
"""等待处理类"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
|
||||
self.name = global_config.bot.nickname
|
||||
self.private_name = private_name
|
||||
# self.wait_accumulated_time = 0 # 不再需要累加计时
|
||||
|
||||
async def wait(self, conversation_info: ConversationInfo) -> bool:
|
||||
"""等待用户新消息或超时"""
|
||||
wait_start_time = time.time()
|
||||
logger.info(f"[私聊][{self.private_name}]进入常规等待状态 (超时: {DESIRED_TIMEOUT_SECONDS} 秒)...")
|
||||
|
||||
while True:
|
||||
# 检查是否有新消息
|
||||
if self.chat_observer.new_message_after(wait_start_time):
|
||||
logger.info(f"[私聊][{self.private_name}]等待结束,收到新消息")
|
||||
return False # 返回 False 表示不是超时
|
||||
|
||||
# 检查是否超时
|
||||
elapsed_time = time.time() - wait_start_time
|
||||
if elapsed_time > DESIRED_TIMEOUT_SECONDS:
|
||||
logger.info(f"[私聊][{self.private_name}]等待超过 {DESIRED_TIMEOUT_SECONDS} 秒...添加思考目标。")
|
||||
wait_goal = {
|
||||
"goal": f"你等待了{elapsed_time / 60:.1f}分钟,注意可能在对方看来聊天已经结束,思考接下来要做什么",
|
||||
"reasoning": "对方很久没有回复你的消息了",
|
||||
}
|
||||
conversation_info.goal_list.append(wait_goal)
|
||||
logger.info(f"[私聊][{self.private_name}]添加目标: {wait_goal}")
|
||||
return True # 返回 True 表示超时
|
||||
|
||||
await asyncio.sleep(5) # 每 5 秒检查一次
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]等待中..."
|
||||
) # 可以考虑把这个频繁日志注释掉,只在超时或收到消息时输出
|
||||
|
||||
async def wait_listening(self, conversation_info: ConversationInfo) -> bool:
|
||||
"""倾听用户发言或超时"""
|
||||
wait_start_time = time.time()
|
||||
logger.info(f"[私聊][{self.private_name}]进入倾听等待状态 (超时: {DESIRED_TIMEOUT_SECONDS} 秒)...")
|
||||
|
||||
while True:
|
||||
# 检查是否有新消息
|
||||
if self.chat_observer.new_message_after(wait_start_time):
|
||||
logger.info(f"[私聊][{self.private_name}]倾听等待结束,收到新消息")
|
||||
return False # 返回 False 表示不是超时
|
||||
|
||||
# 检查是否超时
|
||||
elapsed_time = time.time() - wait_start_time
|
||||
if elapsed_time > DESIRED_TIMEOUT_SECONDS:
|
||||
logger.info(f"[私聊][{self.private_name}]倾听等待超过 {DESIRED_TIMEOUT_SECONDS} 秒...添加思考目标。")
|
||||
wait_goal = {
|
||||
# 保持 goal 文本一致
|
||||
"goal": f"你等待了{elapsed_time / 60:.1f}分钟,对方似乎话说一半突然消失了,可能忙去了?也可能忘记了回复?要问问吗?还是结束对话?或继续等待?思考接下来要做什么",
|
||||
"reasoning": "对方话说一半消失了,很久没有回复",
|
||||
}
|
||||
conversation_info.goal_list.append(wait_goal)
|
||||
logger.info(f"[私聊][{self.private_name}]添加目标: {wait_goal}")
|
||||
return True # 返回 True 表示超时
|
||||
|
||||
await asyncio.sleep(5) # 每 5 秒检查一次
|
||||
logger.debug(f"[私聊][{self.private_name}]倾听等待中...") # 同上,可以考虑注释掉
|
||||
@@ -1,800 +0,0 @@
|
||||
import asyncio
|
||||
import random
|
||||
import time
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
from rich.traceback import install
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_config import ExpressionConfigUtils
|
||||
from src.learners.expression_learner import ExpressionLearner
|
||||
from src.learners.jargon_miner import JargonMiner
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.brain_chat.brain_planner import BrainPlanner
|
||||
from src.chat.planner_actions.action_modifier import ActionModifier
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.heart_flow.heartFC_utils import CycleDetail
|
||||
from src.person_info.person_info import Person
|
||||
from src.core.types import ActionInfo, EventType
|
||||
from src.core.event_bus import event_bus
|
||||
from src.chat.event_helpers import build_event_message
|
||||
from src.services import (
|
||||
generator_service as generator_api,
|
||||
send_service as send_api,
|
||||
message_service as message_api,
|
||||
database_service as database_api,
|
||||
)
|
||||
from src.services.message_service import build_readable_messages_with_id, get_messages_before_time_in_chat
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
|
||||
|
||||
ERROR_LOOP_INFO = {
|
||||
"loop_plan_info": {
|
||||
"action_result": {
|
||||
"action_type": "error",
|
||||
"action_data": {},
|
||||
"reasoning": "循环处理失败",
|
||||
},
|
||||
},
|
||||
"loop_action_info": {
|
||||
"action_taken": False,
|
||||
"reply_text": "",
|
||||
"command": "",
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
# 注释:原来的动作修改超时常量已移除,因为改为顺序执行
|
||||
|
||||
logger = get_logger("bc") # Logger Name Changed
|
||||
|
||||
|
||||
class BrainChatting:
|
||||
"""
|
||||
管理一个连续的私聊Brain Chat循环
|
||||
用于在特定聊天流中生成回复。
|
||||
"""
|
||||
|
||||
def __init__(self, session_id: str):
|
||||
"""
|
||||
BrainChatting 初始化函数
|
||||
|
||||
参数:
|
||||
chat_id: 聊天流唯一标识符(如stream_id)
|
||||
on_stop_focus_chat: 当收到stop_focus_chat命令时调用的回调函数
|
||||
performance_version: 性能记录版本号,用于区分不同启动版本
|
||||
"""
|
||||
# 基础属性
|
||||
self.stream_id: str = session_id # 聊天流ID
|
||||
self.chat_stream: BotChatSession = _chat_manager.get_session_by_session_id(self.stream_id) # type: ignore[assignment]
|
||||
if not self.chat_stream:
|
||||
raise ValueError(f"无法找到聊天流: {self.stream_id}")
|
||||
self.log_prefix = f"[{_chat_manager.get_session_name(self.stream_id) or self.stream_id}]"
|
||||
|
||||
expr_use, jargon_learn, expr_learn = ExpressionConfigUtils.get_expression_config_for_chat(self.stream_id)
|
||||
self._enable_expression_use = expr_use
|
||||
self._enable_expression_learning = expr_learn
|
||||
self._enable_jargon_learning = jargon_learn
|
||||
self._expression_learner = ExpressionLearner(self.stream_id)
|
||||
self._jargon_miner = JargonMiner(self.stream_id, session_name=self.log_prefix.strip("[]"))
|
||||
self._min_messages_for_extraction = 30
|
||||
self._min_extraction_interval = 60
|
||||
self._last_extraction_time = 0.0
|
||||
|
||||
self.action_manager = ActionManager()
|
||||
self.action_planner = BrainPlanner(chat_id=self.stream_id, action_manager=self.action_manager)
|
||||
self.action_modifier = ActionModifier(action_manager=self.action_manager, chat_id=self.stream_id)
|
||||
|
||||
# 循环控制内部状态
|
||||
self.running: bool = False
|
||||
self._loop_task: Optional[asyncio.Task] = None # 主循环任务
|
||||
self._new_message_event = asyncio.Event() # 新消息事件,用于打断 wait
|
||||
|
||||
# 添加循环信息管理相关的属性
|
||||
self.history_loop: List[CycleDetail] = []
|
||||
self._cycle_counter = 0
|
||||
self._current_cycle_detail: CycleDetail = None # type: ignore
|
||||
|
||||
self.last_read_time = time.time() - 2
|
||||
|
||||
self.more_plan = False
|
||||
|
||||
# 最近一次是否成功进行了 reply,用于选择 BrainPlanner 的 Prompt
|
||||
self._last_successful_reply: bool = False
|
||||
|
||||
async def start(self):
|
||||
"""检查是否需要启动主循环,如果未激活则启动。"""
|
||||
|
||||
# 如果循环已经激活,直接返回
|
||||
if self.running:
|
||||
logger.debug(f"{self.log_prefix} BrainChatting 已激活,无需重复启动")
|
||||
return
|
||||
|
||||
try:
|
||||
# 标记为活动状态,防止重复启动
|
||||
self.running = True
|
||||
|
||||
self._loop_task = asyncio.create_task(self._main_chat_loop())
|
||||
self._loop_task.add_done_callback(self._handle_loop_completion)
|
||||
logger.info(f"{self.log_prefix} BrainChatting 启动完成")
|
||||
|
||||
except Exception as e:
|
||||
# 启动失败时重置状态
|
||||
self.running = False
|
||||
self._loop_task = None
|
||||
logger.error(f"{self.log_prefix} BrainChatting 启动失败: {e}")
|
||||
raise
|
||||
|
||||
def _handle_loop_completion(self, task: asyncio.Task):
|
||||
"""当 _hfc_loop 任务完成时执行的回调。"""
|
||||
try:
|
||||
if exception := task.exception():
|
||||
logger.error(f"{self.log_prefix} BrainChatting: 脱离了聊天(异常): {exception}")
|
||||
logger.error(traceback.format_exc()) # Log full traceback for exceptions
|
||||
else:
|
||||
logger.info(f"{self.log_prefix} BrainChatting: 脱离了聊天 (外部停止)")
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.log_prefix} BrainChatting: 结束了聊天")
|
||||
|
||||
def start_cycle(self) -> Tuple[Dict[str, float], str]:
|
||||
self._cycle_counter += 1
|
||||
self._current_cycle_detail = CycleDetail(self._cycle_counter)
|
||||
self._current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}"
|
||||
cycle_timers = {}
|
||||
return cycle_timers, self._current_cycle_detail.thinking_id
|
||||
|
||||
def end_cycle(self, loop_info, cycle_timers):
|
||||
self._current_cycle_detail.set_loop_info(loop_info)
|
||||
self.history_loop.append(self._current_cycle_detail)
|
||||
self._current_cycle_detail.timers = cycle_timers
|
||||
self._current_cycle_detail.end_time = time.time()
|
||||
|
||||
def print_cycle_info(self, cycle_timers):
|
||||
# 记录循环信息和计时器结果
|
||||
timer_strings = []
|
||||
for name, elapsed in cycle_timers.items():
|
||||
formatted_time = f"{elapsed * 1000:.2f}毫秒" if elapsed < 1 else f"{elapsed:.2f}秒"
|
||||
timer_strings.append(f"{name}: {formatted_time}")
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 第{self._current_cycle_detail.cycle_id}次思考,"
|
||||
f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒" # type: ignore
|
||||
+ (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
|
||||
)
|
||||
|
||||
async def _trigger_expression_learning(self, messages: List[SessionMessage]) -> None:
|
||||
if not messages:
|
||||
return
|
||||
|
||||
self._expression_learner.add_messages(messages)
|
||||
if time.time() - self._last_extraction_time < self._min_extraction_interval:
|
||||
return
|
||||
if self._expression_learner.get_cache_size() < self._min_messages_for_extraction:
|
||||
return
|
||||
if not self._enable_expression_learning:
|
||||
return
|
||||
|
||||
self._last_extraction_time = time.time()
|
||||
try:
|
||||
jargon_miner = self._jargon_miner if self._enable_jargon_learning else None
|
||||
await self._expression_learner.learn(jargon_miner)
|
||||
except Exception as exc:
|
||||
logger.error(f"{self.log_prefix} 表达学习失败: {exc}", exc_info=True)
|
||||
|
||||
async def _loopbody(self): # sourcery skip: hoist-if-from-if
|
||||
# 获取最新消息(用于上下文,但不影响是否调用 observe)
|
||||
recent_messages_list = message_api.get_messages_by_time_in_chat(
|
||||
chat_id=self.stream_id,
|
||||
start_time=self.last_read_time,
|
||||
end_time=time.time(),
|
||||
limit=20,
|
||||
limit_mode="latest",
|
||||
filter_mai=True,
|
||||
filter_command=False,
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
|
||||
# 如果有新消息,更新 last_read_time 并触发事件以打断正在进行的 wait
|
||||
if len(recent_messages_list) >= 1:
|
||||
self.last_read_time = time.time()
|
||||
self._new_message_event.set() # 触发新消息事件,打断 wait
|
||||
|
||||
# 总是执行一次思考迭代(不管有没有新消息)
|
||||
# wait 动作会在其内部等待,不需要在这里处理
|
||||
should_continue = await self._observe(recent_messages_list=recent_messages_list)
|
||||
|
||||
if not should_continue:
|
||||
# 选择了 complete_talk,返回 False 表示需要等待新消息
|
||||
return False
|
||||
|
||||
# 继续下一次迭代(除非选择了 complete_talk)
|
||||
# 短暂等待后再继续,避免过于频繁的循环
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
return True
|
||||
|
||||
async def _send_and_store_reply(
|
||||
self,
|
||||
response_set: MessageSequence,
|
||||
action_message: SessionMessage,
|
||||
cycle_timers: Dict[str, float],
|
||||
thinking_id,
|
||||
actions,
|
||||
selected_expressions: Optional[List[int]] = None,
|
||||
) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
|
||||
with Timer("回复发送", cycle_timers):
|
||||
reply_text = await self._send_response(
|
||||
reply_set=response_set,
|
||||
message_data=action_message,
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
|
||||
# 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
|
||||
platform = action_message.platform
|
||||
if platform is None:
|
||||
platform = getattr(self.chat_stream, "platform", "unknown")
|
||||
|
||||
person = Person(platform=platform, user_id=action_message.message_info.user_info.user_id)
|
||||
person_name = person.person_name
|
||||
action_prompt_display = f"你对{person_name}进行了回复:{reply_text}"
|
||||
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
display_prompt=action_prompt_display,
|
||||
thinking_id=thinking_id,
|
||||
action_data={"reply_text": reply_text},
|
||||
action_name="reply",
|
||||
)
|
||||
|
||||
# 构建循环信息
|
||||
loop_info: Dict[str, Any] = {
|
||||
"loop_plan_info": {
|
||||
"action_result": actions,
|
||||
},
|
||||
"loop_action_info": {
|
||||
"action_taken": True,
|
||||
"reply_text": reply_text,
|
||||
"command": "",
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
|
||||
return loop_info, reply_text, cycle_timers
|
||||
|
||||
async def _observe(
|
||||
self, # interest_value: float = 0.0,
|
||||
recent_messages_list: Optional[List[SessionMessage]] = None,
|
||||
) -> bool: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
||||
if recent_messages_list is None:
|
||||
recent_messages_list = []
|
||||
_reply_text = "" # 初始化reply_text变量,避免UnboundLocalError
|
||||
|
||||
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
|
||||
if recent_messages_list:
|
||||
asyncio.create_task(self._trigger_expression_learning(recent_messages_list))
|
||||
|
||||
cycle_timers, thinking_id = self.start_cycle()
|
||||
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
|
||||
|
||||
# 第一步:动作检查
|
||||
available_actions: Dict[str, ActionInfo] = {}
|
||||
try:
|
||||
await self.action_modifier.modify_actions()
|
||||
available_actions = self.action_manager.get_using_actions()
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 动作修改失败: {e}")
|
||||
|
||||
# 获取必要信息
|
||||
is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info()
|
||||
|
||||
# 一次思考迭代:Think - Act - Observe
|
||||
# 获取聊天上下文
|
||||
message_list_before_now = get_messages_before_time_in_chat(
|
||||
chat_id=self.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.6),
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||
messages=message_list_before_now,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=self.action_planner.last_obs_time_mark,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
prompt_info = await self.action_planner.build_planner_prompt(
|
||||
chat_target_info=chat_target_info,
|
||||
current_available_actions=available_actions,
|
||||
chat_content_block=chat_content_block,
|
||||
message_id_list=message_id_list,
|
||||
prompt_key="brain_planner",
|
||||
)
|
||||
_event_msg = build_event_message(
|
||||
EventType.ON_PLAN, llm_prompt=prompt_info[0], stream_id=self.chat_stream.session_id
|
||||
)
|
||||
continue_flag, modified_message = await event_bus.emit(EventType.ON_PLAN, _event_msg)
|
||||
if not continue_flag:
|
||||
return False
|
||||
if modified_message and modified_message._modify_flags.modify_llm_prompt:
|
||||
prompt_info = (modified_message.llm_prompt, prompt_info[1])
|
||||
|
||||
with Timer("规划器", cycle_timers):
|
||||
action_to_use_info = await self.action_planner.plan(
|
||||
loop_start_time=self.last_read_time,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
|
||||
# 检查是否有 complete_talk 动作(会停止后续迭代)
|
||||
has_complete_talk = any(action.action_type == "complete_talk" for action in action_to_use_info)
|
||||
|
||||
# 并行执行所有动作
|
||||
action_tasks = [
|
||||
asyncio.create_task(
|
||||
self._execute_action(action, action_to_use_info, thinking_id, available_actions, cycle_timers)
|
||||
)
|
||||
for action in action_to_use_info
|
||||
]
|
||||
|
||||
# 并行执行所有任务
|
||||
results = await asyncio.gather(*action_tasks, return_exceptions=True)
|
||||
|
||||
# 处理执行结果
|
||||
reply_loop_info = None
|
||||
reply_text_from_reply = ""
|
||||
action_success = False
|
||||
action_reply_text = ""
|
||||
|
||||
for result in results:
|
||||
if isinstance(result, BaseException):
|
||||
logger.error(f"{self.log_prefix} 动作执行异常: {result}")
|
||||
continue
|
||||
|
||||
if result["action_type"] != "reply":
|
||||
action_success = result["success"]
|
||||
action_reply_text = result["reply_text"]
|
||||
elif result["action_type"] == "reply":
|
||||
if result["success"]:
|
||||
reply_loop_info = result["loop_info"]
|
||||
reply_text_from_reply = result["reply_text"]
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 回复动作执行失败")
|
||||
|
||||
# 更新观察时间标记
|
||||
self.action_planner.last_obs_time_mark = time.time()
|
||||
|
||||
# 如果选择了 complete_talk,标记为完成,不再继续迭代
|
||||
if has_complete_talk:
|
||||
logger.info(f"{self.log_prefix} 检测到 complete_talk 动作,本次思考完成")
|
||||
|
||||
# 构建循环信息
|
||||
if reply_loop_info:
|
||||
# 如果有回复信息,使用回复的loop_info作为基础
|
||||
loop_info = reply_loop_info
|
||||
# 更新动作执行信息
|
||||
loop_info["loop_action_info"].update(
|
||||
{
|
||||
"action_taken": action_success,
|
||||
"taken_time": time.time(),
|
||||
}
|
||||
)
|
||||
_reply_text = reply_text_from_reply
|
||||
else:
|
||||
# 没有回复信息,构建纯动作的loop_info
|
||||
loop_info = {
|
||||
"loop_plan_info": {
|
||||
"action_result": action_to_use_info,
|
||||
},
|
||||
"loop_action_info": {
|
||||
"action_taken": action_success,
|
||||
"reply_text": action_reply_text,
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
_reply_text = action_reply_text
|
||||
|
||||
# 如果选择了 complete_talk,返回 False 以停止 _loopbody 的循环
|
||||
# 否则返回 True,让 _loopbody 继续下一次迭代
|
||||
should_continue = not has_complete_talk
|
||||
|
||||
self.end_cycle(loop_info, cycle_timers)
|
||||
self.print_cycle_info(cycle_timers)
|
||||
|
||||
# 如果选择了 complete_talk,返回 False 停止循环
|
||||
# 否则返回 True,继续下一次思考迭代
|
||||
return should_continue
|
||||
|
||||
async def _main_chat_loop(self):
|
||||
"""主循环,持续进行计划并可能回复消息,直到被外部取消。"""
|
||||
try:
|
||||
while self.running:
|
||||
# 主循环
|
||||
success = await self._loopbody()
|
||||
if not success:
|
||||
# 选择了 complete,等待新消息
|
||||
logger.info(f"{self.log_prefix} 选择了 complete,等待新消息...")
|
||||
await self._wait_for_new_message()
|
||||
# 有新消息后继续循环
|
||||
continue
|
||||
await asyncio.sleep(0.1)
|
||||
except asyncio.CancelledError:
|
||||
# 设置了关闭标志位后被取消是正常流程
|
||||
logger.info(f"{self.log_prefix} 麦麦已关闭聊天")
|
||||
except Exception:
|
||||
logger.error(f"{self.log_prefix} 麦麦聊天意外错误,将于3s后尝试重新启动")
|
||||
print(traceback.format_exc())
|
||||
await asyncio.sleep(3)
|
||||
self._loop_task = asyncio.create_task(self._main_chat_loop())
|
||||
logger.error(f"{self.log_prefix} 结束了当前聊天循环")
|
||||
|
||||
async def _wait_for_new_message(self):
|
||||
"""等待新消息到达"""
|
||||
last_check_time = self.last_read_time
|
||||
check_interval = 1.0 # 每秒检查一次
|
||||
|
||||
# 清除事件状态,准备等待新消息
|
||||
self._new_message_event.clear()
|
||||
|
||||
while self.running:
|
||||
# 检查是否有新消息
|
||||
recent_messages_list = message_api.get_messages_by_time_in_chat(
|
||||
chat_id=self.stream_id,
|
||||
start_time=last_check_time,
|
||||
end_time=time.time(),
|
||||
limit=20,
|
||||
limit_mode="latest",
|
||||
filter_mai=True,
|
||||
filter_command=False,
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
|
||||
# 如果有新消息,更新 last_read_time 并返回
|
||||
if len(recent_messages_list) >= 1:
|
||||
self.last_read_time = time.time()
|
||||
logger.info(f"{self.log_prefix} 检测到新消息,恢复循环")
|
||||
return
|
||||
|
||||
# 等待新消息事件或超时后再次检查
|
||||
try:
|
||||
await asyncio.wait_for(self._new_message_event.wait(), timeout=check_interval)
|
||||
# 事件被触发,说明有新消息
|
||||
logger.info(f"{self.log_prefix} 检测到新消息事件,恢复循环")
|
||||
return
|
||||
except asyncio.TimeoutError:
|
||||
# 超时后继续检查
|
||||
continue
|
||||
|
||||
async def _handle_action(
|
||||
self,
|
||||
action: str,
|
||||
reasoning: str,
|
||||
action_data: dict,
|
||||
cycle_timers: Dict[str, float],
|
||||
thinking_id: str,
|
||||
action_message: Optional[SessionMessage] = None,
|
||||
) -> tuple[bool, str, str]:
|
||||
"""
|
||||
处理规划动作,使用动作工厂创建相应的动作处理器
|
||||
|
||||
参数:
|
||||
action: 动作类型
|
||||
reasoning: 决策理由
|
||||
action_data: 动作数据,包含不同动作需要的参数
|
||||
cycle_timers: 计时器字典
|
||||
thinking_id: 思考ID
|
||||
|
||||
返回:
|
||||
tuple[bool, str, str]: (是否执行了动作, 思考消息ID, 命令)
|
||||
"""
|
||||
try:
|
||||
# 使用工厂创建动作处理器实例
|
||||
try:
|
||||
action_handler = self.action_manager.create_action(
|
||||
action_name=action,
|
||||
action_data=action_data,
|
||||
action_reasoning=reasoning,
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
chat_stream=self.chat_stream,
|
||||
log_prefix=self.log_prefix,
|
||||
action_message=action_message,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 创建动作处理器时出错: {e}")
|
||||
traceback.print_exc()
|
||||
return False, "", ""
|
||||
|
||||
if not action_handler:
|
||||
logger.warning(f"{self.log_prefix} 未能创建动作处理器: {action}")
|
||||
return False, "", ""
|
||||
|
||||
# 处理动作并获取结果(固定记录一次动作信息)
|
||||
# BaseAction 定义了异步方法 execute() 作为统一执行入口
|
||||
# 这里调用 execute() 以兼容所有 Action 实现
|
||||
result = await action_handler.execute()
|
||||
success, action_text = result
|
||||
command = ""
|
||||
|
||||
return success, action_text, command
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 处理{action}时出错: {e}")
|
||||
traceback.print_exc()
|
||||
return False, "", ""
|
||||
|
||||
async def _send_response(
|
||||
self,
|
||||
reply_set: MessageSequence,
|
||||
message_data: SessionMessage,
|
||||
selected_expressions: Optional[List[int]] = None,
|
||||
) -> str:
|
||||
new_message_count = message_api.count_new_messages(
|
||||
chat_id=self.chat_stream.session_id, start_time=self.last_read_time, end_time=time.time()
|
||||
)
|
||||
|
||||
need_reply = new_message_count >= random.randint(2, 4)
|
||||
|
||||
if need_reply:
|
||||
logger.info(f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,使用引用回复")
|
||||
|
||||
reply_text = ""
|
||||
first_replied = False
|
||||
for component in reply_set.components:
|
||||
if not isinstance(component, TextComponent):
|
||||
continue
|
||||
data = component.text
|
||||
if not first_replied:
|
||||
await send_api.text_to_stream(
|
||||
text=data,
|
||||
stream_id=self.chat_stream.session_id,
|
||||
reply_message=message_data,
|
||||
set_reply=need_reply,
|
||||
typing=False,
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
first_replied = True
|
||||
else:
|
||||
await send_api.text_to_stream(
|
||||
text=data,
|
||||
stream_id=self.chat_stream.session_id,
|
||||
reply_message=message_data,
|
||||
set_reply=False,
|
||||
typing=True,
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
reply_text += data
|
||||
|
||||
return reply_text
|
||||
|
||||
async def _execute_action(
|
||||
self,
|
||||
action_planner_info: ActionPlannerInfo,
|
||||
chosen_action_plan_infos: List[ActionPlannerInfo],
|
||||
thinking_id: str,
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
cycle_timers: Dict[str, float],
|
||||
):
|
||||
"""执行单个动作的通用函数"""
|
||||
try:
|
||||
with Timer(f"动作{action_planner_info.action_type}", cycle_timers):
|
||||
if action_planner_info.action_type == "complete_talk":
|
||||
# 直接处理complete_talk逻辑,不再通过动作系统
|
||||
reason = action_planner_info.reasoning or "选择完成对话"
|
||||
logger.info(f"{self.log_prefix} 选择完成对话,原因: {reason}")
|
||||
|
||||
# 存储complete_talk信息到数据库
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
display_prompt=reason,
|
||||
thinking_id=thinking_id,
|
||||
action_data={"reason": reason},
|
||||
action_name="complete_talk",
|
||||
)
|
||||
return {"action_type": "complete_talk", "success": True, "reply_text": "", "command": ""}
|
||||
|
||||
elif action_planner_info.action_type == "reply":
|
||||
try:
|
||||
# 从 Planner 的 action_data 中提取未知词语列表(仅在 reply 时使用)
|
||||
unknown_words = None
|
||||
if isinstance(action_planner_info.action_data, dict):
|
||||
uw = action_planner_info.action_data.get("unknown_words")
|
||||
if isinstance(uw, list):
|
||||
cleaned_uw: List[str] = []
|
||||
for item in uw:
|
||||
if isinstance(item, str):
|
||||
if stripped_item := item.strip():
|
||||
cleaned_uw.append(stripped_item)
|
||||
if cleaned_uw:
|
||||
unknown_words = cleaned_uw
|
||||
|
||||
success, llm_response = await generator_api.generate_reply(
|
||||
chat_stream=self.chat_stream,
|
||||
reply_message=action_planner_info.action_message,
|
||||
available_actions=available_actions,
|
||||
chosen_actions=chosen_action_plan_infos,
|
||||
reply_reason=action_planner_info.reasoning or "",
|
||||
unknown_words=unknown_words,
|
||||
enable_tool=global_config.tool.enable_tool,
|
||||
request_type="replyer",
|
||||
from_plugin=False,
|
||||
)
|
||||
|
||||
if not success or not llm_response or not llm_response.reply_set:
|
||||
if action_planner_info.action_message:
|
||||
logger.info(
|
||||
f"对 {action_planner_info.action_message.processed_plain_text} 的回复生成失败"
|
||||
)
|
||||
else:
|
||||
logger.info("回复生成失败")
|
||||
return {
|
||||
"action_type": "reply",
|
||||
"success": False,
|
||||
"reply_text": "",
|
||||
"loop_info": None,
|
||||
}
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消")
|
||||
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
|
||||
|
||||
response_set = llm_response.reply_set
|
||||
selected_expressions = llm_response.selected_expressions
|
||||
loop_info, reply_text, _ = await self._send_and_store_reply(
|
||||
response_set=response_set,
|
||||
action_message=action_planner_info.action_message, # type: ignore
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
actions=chosen_action_plan_infos,
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
# 标记这次循环已经成功进行了回复
|
||||
self._last_successful_reply = True
|
||||
return {
|
||||
"action_type": "reply",
|
||||
"success": True,
|
||||
"reply_text": reply_text,
|
||||
"loop_info": loop_info,
|
||||
}
|
||||
|
||||
# 其他动作
|
||||
else:
|
||||
# 内建 wait / listening:不通过插件系统,直接在这里处理
|
||||
if action_planner_info.action_type in ["wait", "listening"]:
|
||||
reason = action_planner_info.reasoning or ""
|
||||
action_data = action_planner_info.action_data or {}
|
||||
|
||||
if action_planner_info.action_type == "wait":
|
||||
# 获取等待时间(必填)
|
||||
wait_seconds = action_data.get("wait_seconds")
|
||||
if wait_seconds is None:
|
||||
logger.warning(f"{self.log_prefix} wait 动作缺少 wait_seconds 参数,使用默认值 5 秒")
|
||||
wait_seconds = 5
|
||||
else:
|
||||
try:
|
||||
wait_seconds = float(wait_seconds)
|
||||
if wait_seconds < 0:
|
||||
logger.warning(f"{self.log_prefix} wait_seconds 不能为负数,使用默认值 5 秒")
|
||||
wait_seconds = 5
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"{self.log_prefix} wait_seconds 参数格式错误,使用默认值 5 秒")
|
||||
wait_seconds = 5
|
||||
|
||||
logger.info(f"{self.log_prefix} 执行 wait 动作,等待 {wait_seconds} 秒(可被新消息打断)")
|
||||
|
||||
# 清除事件状态,准备等待新消息
|
||||
self._new_message_event.clear()
|
||||
|
||||
# 记录动作信息
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
display_prompt=reason or f"等待 {wait_seconds} 秒",
|
||||
thinking_id=thinking_id,
|
||||
action_data={"reason": reason, "wait_seconds": wait_seconds},
|
||||
action_name="wait",
|
||||
)
|
||||
|
||||
# 等待指定时间,但可被新消息打断
|
||||
try:
|
||||
await asyncio.wait_for(self._new_message_event.wait(), timeout=wait_seconds)
|
||||
# 如果事件被触发,说明有新消息到达
|
||||
logger.info(f"{self.log_prefix} wait 动作被新消息打断,提前结束等待")
|
||||
except asyncio.TimeoutError:
|
||||
# 超时正常完成
|
||||
pass
|
||||
|
||||
logger.info(f"{self.log_prefix} wait 动作完成,继续下一次思考")
|
||||
|
||||
# 这些动作本身不产生文本回复
|
||||
self._last_successful_reply = False
|
||||
return {
|
||||
"action_type": "wait",
|
||||
"success": True,
|
||||
"reply_text": "",
|
||||
"command": "",
|
||||
}
|
||||
|
||||
# listening 已合并到 wait,如果遇到则转换为 wait(向后兼容)
|
||||
elif action_planner_info.action_type == "listening":
|
||||
logger.debug(f"{self.log_prefix} 检测到 listening 动作,已合并到 wait,自动转换")
|
||||
# 使用默认等待时间
|
||||
wait_seconds = 3
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 执行 listening(转换为 wait)动作,等待 {wait_seconds} 秒(可被新消息打断)"
|
||||
)
|
||||
|
||||
# 清除事件状态,准备等待新消息
|
||||
self._new_message_event.clear()
|
||||
|
||||
# 记录动作信息
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
display_prompt=reason or f"倾听并等待 {wait_seconds} 秒",
|
||||
thinking_id=thinking_id,
|
||||
action_data={"reason": reason, "wait_seconds": wait_seconds},
|
||||
action_name="listening",
|
||||
)
|
||||
|
||||
# 等待指定时间,但可被新消息打断
|
||||
try:
|
||||
await asyncio.wait_for(self._new_message_event.wait(), timeout=wait_seconds)
|
||||
# 如果事件被触发,说明有新消息到达
|
||||
logger.info(f"{self.log_prefix} listening 动作被新消息打断,提前结束等待")
|
||||
except asyncio.TimeoutError:
|
||||
# 超时正常完成
|
||||
pass
|
||||
|
||||
logger.info(f"{self.log_prefix} listening 动作完成,继续下一次思考")
|
||||
|
||||
# 这些动作本身不产生文本回复
|
||||
self._last_successful_reply = False
|
||||
return {
|
||||
"action_type": "listening",
|
||||
"success": True,
|
||||
"reply_text": "",
|
||||
"command": "",
|
||||
}
|
||||
|
||||
# 其余动作:走原有插件 Action 体系
|
||||
with Timer("动作执行", cycle_timers):
|
||||
success, reply_text, command = await self._handle_action(
|
||||
action_planner_info.action_type,
|
||||
action_planner_info.reasoning or "",
|
||||
action_planner_info.action_data or {},
|
||||
cycle_timers,
|
||||
thinking_id,
|
||||
action_planner_info.action_message,
|
||||
)
|
||||
# 非 reply 类动作执行成功时,清空最近成功回复标记,让下一轮回到 initial Prompt
|
||||
if success and action_planner_info.action_type != "reply":
|
||||
self._last_successful_reply = False
|
||||
|
||||
return {
|
||||
"action_type": action_planner_info.action_type,
|
||||
"success": success,
|
||||
"reply_text": reply_text,
|
||||
"command": command,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 执行动作时出错: {e}")
|
||||
logger.error(f"{self.log_prefix} 错误信息: {traceback.format_exc()}")
|
||||
return {
|
||||
"action_type": action_planner_info.action_type,
|
||||
"success": False,
|
||||
"reply_text": "",
|
||||
"loop_info": None,
|
||||
"error": str(e),
|
||||
}
|
||||
@@ -1,622 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
|
||||
from json_repair import repair_json
|
||||
from rich.traceback import install
|
||||
|
||||
from src.chat.logger.plan_reply_logger import PlanReplyLogger
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_action import ActionUtils
|
||||
from src.config.config import global_config
|
||||
from src.core.types import ActionActivationType, ActionInfo, ComponentType
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services.message_service import (
|
||||
build_readable_messages_with_id,
|
||||
get_actions_by_timestamp_with_chat,
|
||||
get_messages_before_time_in_chat,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.info_data_model import TargetPersonInfo
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
|
||||
logger = get_logger("planner")
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
||||
class BrainPlanner:
|
||||
def __init__(self, chat_id: str, action_manager: ActionManager):
|
||||
self.chat_id = chat_id
|
||||
self.log_prefix = f"[{_chat_manager.get_session_name(chat_id) or chat_id}]"
|
||||
self.action_manager = action_manager
|
||||
# LLM规划器配置
|
||||
self.planner_llm = LLMServiceClient(
|
||||
task_name="planner", request_type="planner"
|
||||
) # 用于动作规划
|
||||
|
||||
self.last_obs_time_mark = 0.0
|
||||
|
||||
# 计划日志记录
|
||||
self.plan_log: List[Tuple[str, float, List[ActionPlannerInfo]]] = []
|
||||
|
||||
def find_message_by_id(
|
||||
self, message_id: str, message_id_list: List[Tuple[str, "SessionMessage"]]
|
||||
) -> Optional["SessionMessage"]:
|
||||
# sourcery skip: use-next
|
||||
"""
|
||||
根据message_id从message_id_list中查找对应的原始消息
|
||||
|
||||
Args:
|
||||
message_id: 要查找的消息ID
|
||||
message_id_list: 消息ID列表,格式为[{'id': str, 'message': dict}, ...]
|
||||
|
||||
Returns:
|
||||
找到的原始消息字典,如果未找到则返回None
|
||||
"""
|
||||
for item in message_id_list:
|
||||
if item[0] == message_id:
|
||||
return item[1]
|
||||
return None
|
||||
|
||||
def _parse_single_action(
|
||||
self,
|
||||
action_json: dict,
|
||||
message_id_list: List[Tuple[str, "SessionMessage"]],
|
||||
current_available_actions: List[Tuple[str, ActionInfo]],
|
||||
) -> List[ActionPlannerInfo]:
|
||||
"""解析单个action JSON并返回ActionPlannerInfo列表"""
|
||||
action_planner_infos = []
|
||||
|
||||
try:
|
||||
action = action_json.get("action", "complete_talk")
|
||||
logger.debug(f"{self.log_prefix}解析动作JSON: action={action}, json={action_json}")
|
||||
reasoning = action_json.get("reason", "未提供原因")
|
||||
action_data = {key: value for key, value in action_json.items() if key not in ["action", "reason"]}
|
||||
# 非complete_talk动作需要target_message_id
|
||||
target_message = None
|
||||
|
||||
if target_message_id := action_json.get("target_message_id"):
|
||||
# 根据target_message_id查找原始消息
|
||||
target_message = self.find_message_by_id(target_message_id, message_id_list)
|
||||
if target_message is None:
|
||||
logger.warning(f"{self.log_prefix}无法找到target_message_id '{target_message_id}' 对应的消息")
|
||||
# 选择最新消息作为target_message
|
||||
target_message = message_id_list[-1][1]
|
||||
else:
|
||||
target_message = message_id_list[-1][1]
|
||||
logger.debug(f"{self.log_prefix}动作'{action}'缺少target_message_id,使用最新消息作为target_message")
|
||||
|
||||
# 验证action是否可用
|
||||
available_action_names = [action_name for action_name, _ in current_available_actions]
|
||||
# 内部保留动作(不依赖插件系统)
|
||||
# 注意:listening 已合并到 wait 中,如果遇到 listening 则转换为 wait
|
||||
internal_action_names = ["complete_talk", "reply", "wait_time", "wait", "listening"]
|
||||
|
||||
logger.debug(
|
||||
f"{self.log_prefix}动作验证: action={action}, internal={internal_action_names}, available={available_action_names}"
|
||||
)
|
||||
|
||||
# 将 listening 转换为 wait(向后兼容)
|
||||
if action == "listening":
|
||||
logger.debug(f"{self.log_prefix}检测到 listening 动作,已合并到 wait,自动转换")
|
||||
action = "wait"
|
||||
|
||||
if action not in internal_action_names and action not in available_action_names:
|
||||
logger.warning(
|
||||
f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (内部动作: {internal_action_names}, 可用插件动作: {available_action_names}),将强制使用 'complete_talk'"
|
||||
)
|
||||
reasoning = (
|
||||
f"LLM 返回了当前不可用的动作 '{action}' (可用: {available_action_names})。原始理由: {reasoning}"
|
||||
)
|
||||
action = "complete_talk"
|
||||
logger.warning(f"{self.log_prefix}动作已转换为 complete_talk")
|
||||
|
||||
# 创建ActionPlannerInfo对象
|
||||
# 将列表转换为字典格式
|
||||
available_actions_dict = dict(current_available_actions)
|
||||
action_planner_infos.append(
|
||||
ActionPlannerInfo(
|
||||
action_type=action,
|
||||
reasoning=reasoning,
|
||||
action_data=action_data,
|
||||
action_message=target_message,
|
||||
available_actions=available_actions_dict,
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}解析单个action时出错: {e}")
|
||||
# 将列表转换为字典格式
|
||||
available_actions_dict = dict(current_available_actions)
|
||||
action_planner_infos.append(
|
||||
ActionPlannerInfo(
|
||||
action_type="complete_talk",
|
||||
reasoning=f"解析单个action时出错: {e}",
|
||||
action_data={},
|
||||
action_message=None,
|
||||
available_actions=available_actions_dict,
|
||||
)
|
||||
)
|
||||
|
||||
return action_planner_infos
|
||||
|
||||
async def plan(
|
||||
self,
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
loop_start_time: float = 0.0,
|
||||
) -> List[ActionPlannerInfo]:
|
||||
# sourcery skip: use-named-expression
|
||||
"""
|
||||
规划器 (Planner): 使用LLM根据上下文决定做出什么动作(ReAct模式)。
|
||||
"""
|
||||
plan_start = time.perf_counter()
|
||||
|
||||
# 获取聊天上下文
|
||||
message_list_before_now = get_messages_before_time_in_chat(
|
||||
chat_id=self.chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.6),
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
message_id_list: list[Tuple[str, "SessionMessage"]] = []
|
||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||
messages=message_list_before_now,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=self.last_obs_time_mark,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
message_list_before_now_short = message_list_before_now[-int(global_config.chat.max_context_size * 0.3) :]
|
||||
chat_content_block_short, message_id_list_short = build_readable_messages_with_id(
|
||||
messages=message_list_before_now_short,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
truncate=False,
|
||||
show_actions=False,
|
||||
)
|
||||
|
||||
self.last_obs_time_mark = time.time()
|
||||
|
||||
# 获取必要信息
|
||||
is_group_chat, chat_target_info, current_available_actions = self.get_necessary_info()
|
||||
|
||||
# 提及/被@ 的处理由心流或统一判定模块驱动;Planner 不再做硬编码强制回复
|
||||
|
||||
# 应用激活类型过滤
|
||||
filtered_actions = self._filter_actions_by_activation_type(available_actions, chat_content_block_short)
|
||||
|
||||
logger.debug(f"{self.log_prefix}过滤后有{len(filtered_actions)}个可用动作")
|
||||
|
||||
prompt_build_start = time.perf_counter()
|
||||
# 构建包含所有动作的提示词:使用统一的 ReAct Prompt
|
||||
prompt_key = "brain_planner"
|
||||
# 这里不记录日志,避免重复打印,由调用方按需控制 log_prompt
|
||||
prompt, message_id_list = await self.build_planner_prompt(
|
||||
chat_target_info=chat_target_info,
|
||||
current_available_actions=filtered_actions,
|
||||
chat_content_block=chat_content_block,
|
||||
message_id_list=message_id_list,
|
||||
prompt_key=prompt_key,
|
||||
)
|
||||
prompt_build_ms = (time.perf_counter() - prompt_build_start) * 1000
|
||||
|
||||
# 调用LLM获取决策
|
||||
reasoning, actions, llm_raw_output, llm_reasoning, llm_duration_ms = await self._execute_main_planner(
|
||||
prompt=prompt,
|
||||
message_id_list=message_id_list,
|
||||
filtered_actions=filtered_actions,
|
||||
available_actions=available_actions,
|
||||
loop_start_time=loop_start_time,
|
||||
)
|
||||
|
||||
# 记录和展示计划日志
|
||||
logger.info(
|
||||
f"{self.log_prefix}Planner: {reasoning}。选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}"
|
||||
)
|
||||
self.add_plan_log(reasoning, actions)
|
||||
|
||||
try:
|
||||
PlanReplyLogger.log_plan(
|
||||
chat_id=self.chat_id,
|
||||
prompt=prompt,
|
||||
reasoning=reasoning,
|
||||
raw_output=llm_raw_output,
|
||||
raw_reasoning=llm_reasoning,
|
||||
actions=actions,
|
||||
timing={
|
||||
"prompt_build_ms": round(prompt_build_ms, 2),
|
||||
"llm_duration_ms": round(llm_duration_ms, 2) if llm_duration_ms is not None else None,
|
||||
"total_plan_ms": round((time.perf_counter() - plan_start) * 1000, 2),
|
||||
"loop_start_time": loop_start_time,
|
||||
},
|
||||
extra=None,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"{self.log_prefix}记录plan日志失败")
|
||||
|
||||
return actions
|
||||
|
||||
async def build_planner_prompt(
|
||||
self,
|
||||
chat_target_info: Optional["TargetPersonInfo"],
|
||||
current_available_actions: Dict[str, ActionInfo],
|
||||
message_id_list: List[Tuple[str, "SessionMessage"]],
|
||||
chat_content_block: str = "",
|
||||
interest: str = "",
|
||||
prompt_key: str = "brain_planner",
|
||||
) -> tuple[str, List[Tuple[str, "SessionMessage"]]]:
|
||||
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
|
||||
try:
|
||||
# 获取最近执行过的动作
|
||||
actions_before_now = get_actions_by_timestamp_with_chat(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=time.time() - 600,
|
||||
timestamp_end=time.time(),
|
||||
limit=6,
|
||||
)
|
||||
actions_before_now_block = ActionUtils.build_readable_action_records(actions_before_now)
|
||||
if actions_before_now_block:
|
||||
actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}"
|
||||
else:
|
||||
actions_before_now_block = ""
|
||||
|
||||
chat_context_description: str = ""
|
||||
if chat_target_info:
|
||||
# 构建聊天上下文描述
|
||||
chat_context_description = (
|
||||
f"你正在和 {chat_target_info.person_name or chat_target_info.user_nickname or '对方'} 聊天中"
|
||||
)
|
||||
|
||||
# 构建动作选项块
|
||||
action_options_block = await self._build_action_options_block(current_available_actions)
|
||||
|
||||
# 其他信息
|
||||
moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"
|
||||
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
bot_name = global_config.bot.nickname
|
||||
bot_nickname = (
|
||||
f",也可以叫你{','.join(global_config.bot.alias_names)}" if global_config.bot.alias_names else ""
|
||||
)
|
||||
name_block = f"你的名字是{bot_name}{bot_nickname},请注意哪些是你自己的发言。"
|
||||
|
||||
# 获取主规划器模板并填充
|
||||
planner_prompt_template = prompt_manager.get_prompt(prompt_key)
|
||||
planner_prompt_template.add_context("time_block", time_block)
|
||||
planner_prompt_template.add_context("chat_context_description", chat_context_description)
|
||||
planner_prompt_template.add_context("chat_content_block", chat_content_block)
|
||||
planner_prompt_template.add_context("actions_before_now_block", actions_before_now_block)
|
||||
planner_prompt_template.add_context("action_options_text", action_options_block)
|
||||
planner_prompt_template.add_context("moderation_prompt", moderation_prompt_block)
|
||||
planner_prompt_template.add_context("name_block", name_block)
|
||||
planner_prompt_template.add_context("interest", interest)
|
||||
planner_prompt_template.add_context("plan_style", global_config.experimental.private_plan_style)
|
||||
prompt = await prompt_manager.render_prompt(planner_prompt_template)
|
||||
|
||||
return prompt, message_id_list
|
||||
except Exception as e:
|
||||
logger.error(f"构建 Planner 提示词时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return "构建 Planner Prompt 时出错", []
|
||||
|
||||
def get_necessary_info(self) -> Tuple[bool, Optional["TargetPersonInfo"], Dict[str, ActionInfo]]:
|
||||
"""
|
||||
获取 Planner 需要的必要信息
|
||||
"""
|
||||
is_group_chat = True
|
||||
is_group_chat, chat_target_info = get_chat_type_and_target_info(self.chat_id)
|
||||
logger.debug(f"{self.log_prefix}获取到聊天信息 - 群聊: {is_group_chat}, 目标信息: {chat_target_info}")
|
||||
|
||||
current_available_actions_dict = self.action_manager.get_using_actions()
|
||||
|
||||
# 获取完整的动作信息
|
||||
all_registered_actions: Dict[str, ActionInfo] = component_query_service.get_components_by_type( # type: ignore
|
||||
ComponentType.ACTION
|
||||
)
|
||||
current_available_actions = {}
|
||||
for action_name in current_available_actions_dict:
|
||||
if action_name in all_registered_actions:
|
||||
current_available_actions[action_name] = all_registered_actions[action_name]
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix}使用中的动作 {action_name} 未在已注册动作中找到")
|
||||
|
||||
return is_group_chat, chat_target_info, current_available_actions
|
||||
|
||||
def _filter_actions_by_activation_type(
|
||||
self, available_actions: Dict[str, ActionInfo], chat_content_block: str
|
||||
) -> Dict[str, ActionInfo]:
|
||||
"""根据激活类型过滤动作"""
|
||||
filtered_actions = {}
|
||||
|
||||
for action_name, action_info in available_actions.items():
|
||||
if action_info.activation_type == ActionActivationType.NEVER:
|
||||
logger.debug(f"{self.log_prefix}动作 {action_name} 设置为 NEVER 激活类型,跳过")
|
||||
continue
|
||||
elif action_info.activation_type == ActionActivationType.ALWAYS:
|
||||
filtered_actions[action_name] = action_info
|
||||
elif action_info.activation_type == ActionActivationType.RANDOM:
|
||||
if random.random() < action_info.random_activation_probability:
|
||||
filtered_actions[action_name] = action_info
|
||||
elif action_info.activation_type == ActionActivationType.KEYWORD:
|
||||
if action_info.activation_keywords:
|
||||
for keyword in action_info.activation_keywords:
|
||||
if keyword in chat_content_block:
|
||||
filtered_actions[action_name] = action_info
|
||||
break
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix}未知的激活类型: {action_info.activation_type},跳过处理")
|
||||
|
||||
return filtered_actions
|
||||
|
||||
async def _build_action_options_block(self, current_available_actions: Dict[str, ActionInfo]) -> str:
|
||||
# sourcery skip: use-join
|
||||
"""构建动作选项块"""
|
||||
if not current_available_actions:
|
||||
return ""
|
||||
|
||||
action_options_block = ""
|
||||
for action_name, action_info in current_available_actions.items():
|
||||
# 构建参数文本
|
||||
param_text = ""
|
||||
if action_info.action_parameters:
|
||||
param_text = "\n"
|
||||
for param_name, param_description in action_info.action_parameters.items():
|
||||
param_text += f' "{param_name}":"{param_description}"\n'
|
||||
param_text = param_text.rstrip("\n")
|
||||
|
||||
# 构建要求文本
|
||||
require_text = ""
|
||||
for require_item in action_info.action_require:
|
||||
require_text += f"- {require_item}\n"
|
||||
require_text = require_text.rstrip("\n")
|
||||
|
||||
# 获取动作提示模板并填充
|
||||
using_action_prompt_template = prompt_manager.get_prompt("brain_action")
|
||||
using_action_prompt_template.add_context("action_name", action_name)
|
||||
using_action_prompt_template.add_context("action_description", action_info.description)
|
||||
using_action_prompt_template.add_context("action_parameters", param_text)
|
||||
using_action_prompt_template.add_context("action_require", require_text)
|
||||
using_action_prompt = await prompt_manager.render_prompt(using_action_prompt_template)
|
||||
|
||||
action_options_block += using_action_prompt
|
||||
|
||||
return action_options_block
|
||||
|
||||
async def _execute_main_planner(
|
||||
self,
|
||||
prompt: str,
|
||||
message_id_list: List[Tuple[str, "SessionMessage"]],
|
||||
filtered_actions: Dict[str, ActionInfo],
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
loop_start_time: float,
|
||||
) -> Tuple[str, List[ActionPlannerInfo], Optional[str], Optional[str], Optional[float]]:
|
||||
"""执行主规划器"""
|
||||
llm_content = None
|
||||
actions: List[ActionPlannerInfo] = []
|
||||
extracted_reasoning = ""
|
||||
llm_reasoning = None
|
||||
llm_duration_ms = None
|
||||
|
||||
try:
|
||||
# 调用LLM
|
||||
llm_start = time.perf_counter()
|
||||
generation_result = await self.planner_llm.generate_response(prompt=prompt)
|
||||
llm_content = generation_result.response
|
||||
reasoning_content = generation_result.reasoning
|
||||
llm_duration_ms = (time.perf_counter() - llm_start) * 1000
|
||||
llm_reasoning = reasoning_content
|
||||
|
||||
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
|
||||
|
||||
if global_config.debug.show_planner_prompt:
|
||||
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
|
||||
if reasoning_content:
|
||||
logger.info(f"{self.log_prefix}规划器推理: {reasoning_content}")
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
logger.debug(f"{self.log_prefix}规划器原始响应: {llm_content}")
|
||||
if reasoning_content:
|
||||
logger.debug(f"{self.log_prefix}规划器推理: {reasoning_content}")
|
||||
|
||||
except Exception as req_e:
|
||||
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
|
||||
extracted_reasoning = f"LLM 请求失败,模型出现问题: {req_e}"
|
||||
return (
|
||||
extracted_reasoning,
|
||||
[
|
||||
ActionPlannerInfo(
|
||||
action_type="complete_talk",
|
||||
reasoning=extracted_reasoning,
|
||||
action_data={},
|
||||
action_message=None,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
],
|
||||
llm_content,
|
||||
llm_reasoning,
|
||||
llm_duration_ms,
|
||||
)
|
||||
|
||||
# 解析LLM响应
|
||||
if llm_content:
|
||||
try:
|
||||
json_objects, extracted_reasoning = self._extract_json_from_markdown(llm_content)
|
||||
if json_objects:
|
||||
logger.info(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
|
||||
for i, json_obj in enumerate(json_objects):
|
||||
logger.info(f"{self.log_prefix}解析第{i + 1}个JSON对象: {json_obj}")
|
||||
filtered_actions_list = list(filtered_actions.items())
|
||||
for json_obj in json_objects:
|
||||
parsed_actions = self._parse_single_action(json_obj, message_id_list, filtered_actions_list)
|
||||
logger.info(f"{self.log_prefix}解析后的动作: {[a.action_type for a in parsed_actions]}")
|
||||
actions.extend(parsed_actions)
|
||||
else:
|
||||
# 尝试解析为直接的JSON
|
||||
logger.warning(f"{self.log_prefix}LLM没有返回可用动作: {llm_content}")
|
||||
extracted_reasoning = extracted_reasoning or "LLM没有返回可用动作"
|
||||
actions = self._create_complete_talk(extracted_reasoning, available_actions)
|
||||
|
||||
except Exception as json_e:
|
||||
logger.warning(f"{self.log_prefix}解析LLM响应JSON失败 {json_e}. LLM原始输出: '{llm_content}'")
|
||||
extracted_reasoning = f"解析LLM响应JSON失败: {json_e}"
|
||||
actions = self._create_complete_talk(extracted_reasoning, available_actions)
|
||||
traceback.print_exc()
|
||||
else:
|
||||
extracted_reasoning = "规划器没有获得LLM响应"
|
||||
actions = self._create_complete_talk(extracted_reasoning, available_actions)
|
||||
|
||||
# 添加循环开始时间到所有动作
|
||||
for action in actions:
|
||||
action.action_data = action.action_data or {}
|
||||
action.action_data["loop_start_time"] = loop_start_time
|
||||
|
||||
logger.debug(
|
||||
f"{self.log_prefix}规划器决定执行{len(actions)}个动作: {' '.join([a.action_type for a in actions])}"
|
||||
)
|
||||
|
||||
return extracted_reasoning, actions, llm_content, llm_reasoning, llm_duration_ms
|
||||
|
||||
def _create_complete_talk(
|
||||
self, reasoning: str, available_actions: Dict[str, ActionInfo]
|
||||
) -> List[ActionPlannerInfo]:
|
||||
"""创建complete_talk"""
|
||||
return [
|
||||
ActionPlannerInfo(
|
||||
action_type="complete_talk",
|
||||
reasoning=reasoning,
|
||||
action_data={},
|
||||
action_message=None,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
]
|
||||
|
||||
def add_plan_log(self, reasoning: str, actions: List[ActionPlannerInfo]):
|
||||
"""添加计划日志"""
|
||||
self.plan_log.append((reasoning, time.time(), actions))
|
||||
if len(self.plan_log) > 20:
|
||||
self.plan_log.pop(0)
|
||||
|
||||
def _extract_json_from_markdown(self, content: str) -> Tuple[List[dict], str]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""从Markdown格式的内容中提取JSON对象和推理内容"""
|
||||
json_objects = []
|
||||
reasoning_content = ""
|
||||
|
||||
# 使用正则表达式查找```json包裹的JSON内容
|
||||
json_pattern = r"```json\s*(.*?)\s*```"
|
||||
markdown_matches = re.findall(json_pattern, content, re.DOTALL)
|
||||
|
||||
# 提取JSON之前的内容作为推理文本
|
||||
first_json_pos = len(content)
|
||||
if markdown_matches:
|
||||
# 找到第一个```json的位置
|
||||
first_json_pos = content.find("```json")
|
||||
if first_json_pos > 0:
|
||||
reasoning_content = content[:first_json_pos].strip()
|
||||
# 清理推理内容中的注释标记
|
||||
reasoning_content = re.sub(r"^//\s*", "", reasoning_content, flags=re.MULTILINE)
|
||||
reasoning_content = reasoning_content.strip()
|
||||
|
||||
# 处理```json包裹的JSON
|
||||
for match in markdown_matches:
|
||||
try:
|
||||
# 清理可能的注释和格式问题
|
||||
json_str = re.sub(r"//.*?\n", "\n", match) # 移除单行注释
|
||||
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL) # 移除多行注释
|
||||
if json_str := json_str.strip():
|
||||
# 先尝试将整个块作为一个JSON对象或数组(适用于多行JSON)
|
||||
try:
|
||||
json_obj = json.loads(repair_json(json_str))
|
||||
if isinstance(json_obj, dict):
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict):
|
||||
json_objects.append(item)
|
||||
except json.JSONDecodeError:
|
||||
# 如果整个块解析失败,尝试按行分割(适用于多个单行JSON对象)
|
||||
lines = [line.strip() for line in json_str.split("\n") if line.strip()]
|
||||
for line in lines:
|
||||
try:
|
||||
# 尝试解析每一行作为独立的JSON对象
|
||||
json_obj = json.loads(repair_json(line))
|
||||
if isinstance(json_obj, dict):
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict):
|
||||
json_objects.append(item)
|
||||
except json.JSONDecodeError:
|
||||
# 单行解析失败,继续下一行
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning(f"{self.log_prefix}解析JSON块失败: {e}, 块内容: {match[:100]}...")
|
||||
continue
|
||||
|
||||
# 如果没有找到完整的```json```块,尝试查找不完整的代码块(缺少结尾```)
|
||||
if not json_objects:
|
||||
json_start_pos = content.find("```json")
|
||||
if json_start_pos != -1:
|
||||
# 找到```json之后的内容
|
||||
json_content_start = json_start_pos + 7 # ```json的长度
|
||||
# 提取从```json之后到内容结尾的所有内容
|
||||
incomplete_json_str = content[json_content_start:].strip()
|
||||
|
||||
# 提取JSON之前的内容作为推理文本
|
||||
if json_start_pos > 0:
|
||||
reasoning_content = content[:json_start_pos].strip()
|
||||
reasoning_content = re.sub(r"^//\s*", "", reasoning_content, flags=re.MULTILINE)
|
||||
reasoning_content = reasoning_content.strip()
|
||||
|
||||
if incomplete_json_str:
|
||||
try:
|
||||
# 清理可能的注释和格式问题
|
||||
json_str = re.sub(r"//.*?\n", "\n", incomplete_json_str)
|
||||
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL)
|
||||
json_str = json_str.strip()
|
||||
|
||||
if json_str:
|
||||
# 尝试按行分割,每行可能是一个JSON对象
|
||||
lines = [line.strip() for line in json_str.split("\n") if line.strip()]
|
||||
for line in lines:
|
||||
try:
|
||||
json_obj = json.loads(repair_json(line))
|
||||
if isinstance(json_obj, dict):
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict):
|
||||
json_objects.append(item)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 如果按行解析没有成功,尝试将整个块作为一个JSON对象或数组
|
||||
if not json_objects:
|
||||
try:
|
||||
json_obj = json.loads(repair_json(json_str))
|
||||
if isinstance(json_obj, dict):
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict):
|
||||
json_objects.append(item)
|
||||
except Exception as e:
|
||||
logger.debug(f"尝试解析不完整的JSON代码块失败: {e}")
|
||||
except Exception as e:
|
||||
logger.debug(f"处理不完整的JSON代码块时出错: {e}")
|
||||
|
||||
return json_objects, reasoning_content
|
||||
@@ -14,7 +14,6 @@ from src.common.utils.utils_session import SessionUtils
|
||||
from src.config.config import global_config
|
||||
from src.platform_io.route_key_factory import RouteKeyFactory
|
||||
|
||||
# from src.chat.brain_chat.PFC.pfc_manager import PFCManager
|
||||
from src.core.announcement_manager import global_announcement_manager
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
|
||||
@@ -31,36 +30,20 @@ logger = get_logger("chat")
|
||||
|
||||
|
||||
class ChatBot:
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
"""初始化聊天机器人入口。"""
|
||||
|
||||
self.bot = None # bot 实例引用
|
||||
self._started = False
|
||||
self.heartflow_message_receiver = HeartFCMessageReceiver() # 新增
|
||||
# self.pfc_manager = PFCManager.get_instance() # PFC管理器 # TODO: PFC恢复
|
||||
self.heartflow_message_receiver = HeartFCMessageReceiver()
|
||||
|
||||
async def _ensure_started(self):
|
||||
"""确保所有任务已启动"""
|
||||
async def _ensure_started(self) -> None:
|
||||
"""确保所有后台任务已启动。"""
|
||||
if not self._started:
|
||||
logger.debug("确保ChatBot所有任务已启动")
|
||||
|
||||
self._started = True
|
||||
|
||||
async def _create_pfc_chat(self, message: SessionMessage):
|
||||
"""创建或获取PFC对话实例
|
||||
|
||||
Args:
|
||||
message: 消息对象
|
||||
"""
|
||||
try:
|
||||
chat_id = message.session_id
|
||||
private_name = str(message.message_info.user_info.user_nickname)
|
||||
|
||||
logger.debug(f"[私聊][{private_name}]创建或获取PFC对话: {chat_id}")
|
||||
await self.pfc_manager.get_or_create_conversation(chat_id, private_name)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建PFC聊天失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def _process_commands(self, message: SessionMessage) -> tuple[bool, Optional[str], bool]:
|
||||
"""使用统一组件注册表处理命令。
|
||||
|
||||
@@ -177,11 +160,12 @@ class ChatBot:
|
||||
recalled: Dict[str, Any] = {}
|
||||
recalled_id = None
|
||||
|
||||
if getattr(seg, "type", None) == "notify" and isinstance(getattr(seg, "data", None), dict):
|
||||
sub_type = seg.data.get("sub_type")
|
||||
scene = seg.data.get("scene")
|
||||
msg_id = seg.data.get("message_id")
|
||||
recalled = seg.data.get("recalled_user_info") or {}
|
||||
seg_data = getattr(seg, "data", None)
|
||||
if getattr(seg, "type", None) == "notify" and isinstance(seg_data, dict):
|
||||
sub_type = seg_data.get("sub_type")
|
||||
scene = seg_data.get("scene")
|
||||
msg_id = seg_data.get("message_id")
|
||||
recalled = seg_data.get("recalled_user_info") or {}
|
||||
if isinstance(recalled, dict):
|
||||
recalled_id = recalled.get("user_id")
|
||||
|
||||
@@ -369,23 +353,6 @@ class ChatBot:
|
||||
# else:
|
||||
# template_group_name = None
|
||||
|
||||
# async def preprocess():
|
||||
# # 根据聊天类型路由消息
|
||||
# if group_info is None:
|
||||
# # 私聊消息 -> PFC系统
|
||||
# logger.debug("[私聊]检测到私聊消息,路由到PFC系统")
|
||||
# await MessageStorage.store_message(message, chat)
|
||||
# await self._create_pfc_chat(message)
|
||||
# else:
|
||||
# # 群聊消息 -> HeartFlow系统
|
||||
# logger.debug("[群聊]检测到群聊消息,路由到HeartFlow系统")
|
||||
# await self.heartflow_message_receiver.process_message(message)
|
||||
|
||||
# if template_group_name:
|
||||
# async with global_prompt_manager.async_message_scope(template_group_name):
|
||||
# await preprocess()
|
||||
# else:
|
||||
# await preprocess()
|
||||
async def preprocess():
|
||||
if group_info is None:
|
||||
logger.debug("[私聊]检测到私聊消息,路由到 Maisaka")
|
||||
|
||||
@@ -1,137 +0,0 @@
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.logger import get_logger
|
||||
from src.core.types import ActionInfo
|
||||
from src.plugin_runtime.component_query import ActionExecutor, component_query_service
|
||||
|
||||
logger = get_logger("action_manager")
|
||||
|
||||
|
||||
class ActionHandle:
|
||||
"""Action 执行句柄
|
||||
|
||||
不依赖任何插件基类,内部持有 executor (async callable) 和绑定参数。
|
||||
brain_chat 调用 ``await handle.execute()`` 即可。
|
||||
"""
|
||||
|
||||
def __init__(self, executor: ActionExecutor, **kwargs):
|
||||
self._executor = executor
|
||||
self._kwargs = kwargs
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
return await self._executor(**self._kwargs)
|
||||
|
||||
|
||||
class ActionManager:
|
||||
"""
|
||||
动作管理器,用于管理各种类型的动作
|
||||
|
||||
使用插件运行时统一查询服务的 executor-based 模式。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化动作管理器"""
|
||||
|
||||
# 当前正在使用的动作集合,默认加载默认动作
|
||||
self._using_actions: Dict[str, ActionInfo] = {}
|
||||
|
||||
# 初始化时将默认动作加载到使用中的动作
|
||||
self._using_actions = component_query_service.get_default_actions()
|
||||
|
||||
# === 执行Action方法 ===
|
||||
|
||||
def create_action(
|
||||
self,
|
||||
action_name: str,
|
||||
action_data: dict,
|
||||
action_reasoning: str,
|
||||
cycle_timers: dict,
|
||||
thinking_id: str,
|
||||
chat_stream: BotChatSession,
|
||||
log_prefix: str,
|
||||
shutting_down: bool = False,
|
||||
action_message: Optional[SessionMessage] = None,
|
||||
) -> Optional[ActionHandle]:
|
||||
"""
|
||||
创建动作执行句柄
|
||||
|
||||
Args:
|
||||
action_name: 动作名称
|
||||
action_data: 动作数据
|
||||
action_reasoning: 执行理由
|
||||
cycle_timers: 计时器字典
|
||||
thinking_id: 思考ID
|
||||
chat_stream: 聊天流
|
||||
log_prefix: 日志前缀
|
||||
shutting_down: 是否正在关闭
|
||||
action_message: 动作消息记录
|
||||
|
||||
Returns:
|
||||
Optional[ActionHandle]: 执行句柄,如果动作未注册则返回 None
|
||||
"""
|
||||
try:
|
||||
executor = component_query_service.get_action_executor(action_name)
|
||||
if not executor:
|
||||
logger.warning(f"{log_prefix} 未找到Action组件: {action_name}")
|
||||
return None
|
||||
|
||||
info = component_query_service.get_action_info(action_name)
|
||||
if not info:
|
||||
logger.warning(f"{log_prefix} 未找到Action组件信息: {action_name}")
|
||||
return None
|
||||
|
||||
plugin_config = component_query_service.get_plugin_config(info.plugin_name) or {}
|
||||
|
||||
handle = ActionHandle(
|
||||
executor,
|
||||
action_data=action_data,
|
||||
action_reasoning=action_reasoning,
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
chat_stream=chat_stream,
|
||||
log_prefix=log_prefix,
|
||||
shutting_down=shutting_down,
|
||||
plugin_config=plugin_config,
|
||||
action_message=action_message,
|
||||
)
|
||||
|
||||
logger.debug(f"创建Action执行句柄成功: {action_name}")
|
||||
return handle
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建Action执行句柄失败 {action_name}: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
return None
|
||||
|
||||
def get_using_actions(self) -> Dict[str, ActionInfo]:
|
||||
"""获取当前正在使用的动作集合"""
|
||||
return self._using_actions.copy()
|
||||
|
||||
# === Modify相关方法 ===
|
||||
def remove_action_from_using(self, action_name: str) -> bool:
|
||||
"""
|
||||
从当前使用的动作集中移除指定动作
|
||||
|
||||
Args:
|
||||
action_name: 动作名称
|
||||
|
||||
Returns:
|
||||
bool: 移除是否成功
|
||||
"""
|
||||
if action_name not in self._using_actions:
|
||||
logger.warning(f"移除失败: 动作 {action_name} 不在当前使用的动作集中")
|
||||
return False
|
||||
|
||||
del self._using_actions[action_name]
|
||||
logger.debug(f"已从使用集中移除动作 {action_name}")
|
||||
return True
|
||||
|
||||
def restore_actions(self) -> None:
|
||||
"""恢复到默认动作集"""
|
||||
actions_to_restore = list(self._using_actions.keys())
|
||||
self._using_actions = component_query_service.get_default_actions()
|
||||
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")
|
||||
@@ -1,233 +0,0 @@
|
||||
import random
|
||||
import time
|
||||
from typing import List, Dict, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.services.message_service import build_readable_messages, get_messages_before_time_in_chat
|
||||
from src.core.types import ActionActivationType, ActionInfo
|
||||
from src.core.announcement_manager import global_announcement_manager
|
||||
|
||||
logger = get_logger("action_manager")
|
||||
|
||||
|
||||
class ActionModifier:
|
||||
"""动作处理器
|
||||
|
||||
用于处理Observation对象和根据激活类型处理actions。
|
||||
集成了原有的modify_actions功能和新的激活类型处理功能。
|
||||
支持并行判定和智能缓存优化。
|
||||
"""
|
||||
|
||||
def __init__(self, action_manager: ActionManager, chat_id: str):
|
||||
"""初始化动作处理器"""
|
||||
self.chat_id = chat_id
|
||||
self.chat_stream: BotChatSession = _chat_manager.get_session_by_session_id(self.chat_id) # type: ignore
|
||||
self.log_prefix = f"[{_chat_manager.get_session_name(self.chat_id) or self.chat_id}]"
|
||||
|
||||
self.action_manager = action_manager
|
||||
|
||||
async def modify_actions(
|
||||
self,
|
||||
message_content: str = "",
|
||||
): # sourcery skip: use-named-expression
|
||||
"""
|
||||
动作修改流程,整合传统观察处理和新的激活类型判定
|
||||
|
||||
这个方法处理完整的动作管理流程:
|
||||
1. 基于观察的传统动作修改(循环历史分析、类型匹配等)
|
||||
2. 基于激活类型的智能动作判定,最终确定可用动作集
|
||||
|
||||
处理后,ActionManager 将包含最终的可用动作集,供规划器直接使用
|
||||
"""
|
||||
logger.debug(f"{self.log_prefix}开始完整动作修改流程")
|
||||
|
||||
removals_s1: List[Tuple[str, str]] = []
|
||||
removals_s2: List[Tuple[str, str]] = []
|
||||
# removals_s3: List[Tuple[str, str]] = []
|
||||
|
||||
self.action_manager.restore_actions()
|
||||
all_actions = self.action_manager.get_using_actions()
|
||||
|
||||
message_list_before_now_half = get_messages_before_time_in_chat(
|
||||
chat_id=self.chat_stream.session_id,
|
||||
timestamp=time.time(),
|
||||
limit=min(int(global_config.chat.max_context_size * 0.33), 10),
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
|
||||
chat_content = build_readable_messages(
|
||||
message_list_before_now_half,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
if message_content:
|
||||
chat_content = chat_content + "\n" + f"现在,最新的消息是:{message_content}"
|
||||
|
||||
# === 第一阶段:去除用户自行禁用的 ===
|
||||
disabled_actions = global_announcement_manager.get_disabled_chat_actions(self.chat_id)
|
||||
if disabled_actions:
|
||||
for disabled_action_name in disabled_actions:
|
||||
if disabled_action_name in all_actions:
|
||||
removals_s1.append((disabled_action_name, "用户自行禁用"))
|
||||
self.action_manager.remove_action_from_using(disabled_action_name)
|
||||
logger.debug(f"{self.log_prefix}阶段一移除动作: {disabled_action_name},原因: 用户自行禁用")
|
||||
|
||||
# === 第二阶段:检查动作的关联类型 ===
|
||||
chat_context = self.chat_stream.context
|
||||
type_mismatched_actions = self._check_action_associated_types(all_actions, chat_context)
|
||||
|
||||
if type_mismatched_actions:
|
||||
removals_s2.extend(type_mismatched_actions)
|
||||
|
||||
# 应用第二阶段的移除
|
||||
for action_name, reason in removals_s2:
|
||||
self.action_manager.remove_action_from_using(action_name)
|
||||
logger.debug(f"{self.log_prefix}阶段二移除动作: {action_name},原因: {reason}")
|
||||
|
||||
# === 第三阶段:激活类型判定 ===
|
||||
# if chat_content is not None:
|
||||
# logger.debug(f"{self.log_prefix}开始激活类型判定阶段")
|
||||
|
||||
# 获取当前使用的动作集(经过第一阶段处理)
|
||||
# current_using_actions = self.action_manager.get_using_actions()
|
||||
|
||||
# 获取因激活类型判定而需要移除的动作
|
||||
# removals_s3 = await self._get_deactivated_actions_by_type(
|
||||
# current_using_actions,
|
||||
# chat_content,
|
||||
# )
|
||||
|
||||
# 应用第三阶段的移除
|
||||
# for action_name, reason in removals_s3:
|
||||
# self.action_manager.remove_action_from_using(action_name)
|
||||
# logger.debug(f"{self.log_prefix}阶段三移除动作: {action_name},原因: {reason}")
|
||||
|
||||
# === 统一日志记录 ===
|
||||
all_removals = removals_s1 + removals_s2
|
||||
removals_summary: str = ""
|
||||
if all_removals:
|
||||
removals_summary = " | ".join([f"{name}({reason})" for name, reason in all_removals])
|
||||
|
||||
available_actions = list(self.action_manager.get_using_actions().keys())
|
||||
available_actions_text = "、".join(available_actions) if available_actions else "无"
|
||||
logger.debug(f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}")
|
||||
|
||||
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: BotChatSession):
|
||||
type_mismatched_actions: List[Tuple[str, str]] = []
|
||||
for action_name, action_info in all_actions.items():
|
||||
if action_info.associated_types and not chat_context.check_types(action_info.associated_types):
|
||||
associated_types_str = ", ".join(action_info.associated_types)
|
||||
reason = f"适配器不支持(需要: {associated_types_str})"
|
||||
type_mismatched_actions.append((action_name, reason))
|
||||
logger.debug(f"{self.log_prefix}决定移除动作: {action_name},原因: {reason}")
|
||||
return type_mismatched_actions
|
||||
|
||||
async def _get_deactivated_actions_by_type(
|
||||
self,
|
||||
actions_with_info: Dict[str, ActionInfo],
|
||||
chat_content: str = "",
|
||||
) -> List[tuple[str, str]]:
|
||||
"""
|
||||
根据激活类型过滤,返回需要停用的动作列表及原因
|
||||
|
||||
Args:
|
||||
actions_with_info: 带完整信息的动作字典
|
||||
chat_content: 聊天内容
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str]]: 需要停用的 (action_name, reason) 元组列表
|
||||
"""
|
||||
deactivated_actions = []
|
||||
|
||||
actions_to_check = list(actions_with_info.items())
|
||||
random.shuffle(actions_to_check)
|
||||
|
||||
for action_name, action_info in actions_to_check:
|
||||
activation_type = action_info.activation_type
|
||||
|
||||
if activation_type == ActionActivationType.ALWAYS:
|
||||
continue # 总是激活,无需处理
|
||||
|
||||
elif activation_type == ActionActivationType.RANDOM:
|
||||
probability = action_info.random_activation_probability
|
||||
if random.random() >= probability:
|
||||
reason = f"RANDOM类型未触发(概率{probability})"
|
||||
deactivated_actions.append((action_name, reason))
|
||||
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: {reason}")
|
||||
|
||||
elif activation_type == ActionActivationType.KEYWORD:
|
||||
if not self._check_keyword_activation(action_name, action_info, chat_content):
|
||||
keywords = action_info.activation_keywords
|
||||
reason = f"关键词未匹配(关键词: {keywords})"
|
||||
deactivated_actions.append((action_name, reason))
|
||||
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: {reason}")
|
||||
|
||||
elif activation_type == ActionActivationType.NEVER:
|
||||
reason = "激活类型为never"
|
||||
deactivated_actions.append((action_name, reason))
|
||||
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: 激活类型为never")
|
||||
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix}未知的激活类型: {activation_type},跳过处理")
|
||||
|
||||
return deactivated_actions
|
||||
|
||||
def _check_keyword_activation(
|
||||
self,
|
||||
action_name: str,
|
||||
action_info: ActionInfo,
|
||||
chat_content: str = "",
|
||||
) -> bool:
|
||||
"""
|
||||
检查是否匹配关键词触发条件
|
||||
|
||||
Args:
|
||||
action_name: 动作名称
|
||||
action_info: 动作信息
|
||||
observed_messages_str: 观察到的聊天消息
|
||||
chat_context: 聊天上下文
|
||||
extra_context: 额外上下文
|
||||
|
||||
Returns:
|
||||
bool: 是否应该激活此action
|
||||
"""
|
||||
|
||||
activation_keywords = action_info.activation_keywords
|
||||
case_sensitive = action_info.keyword_case_sensitive
|
||||
|
||||
if not activation_keywords:
|
||||
logger.warning(f"{self.log_prefix}动作 {action_name} 设置为关键词触发但未配置关键词")
|
||||
return False
|
||||
|
||||
# 构建检索文本
|
||||
search_text = ""
|
||||
if chat_content:
|
||||
search_text += chat_content
|
||||
# if chat_context:
|
||||
# search_text += f" {chat_context}"
|
||||
# if extra_context:
|
||||
# search_text += f" {extra_context}"
|
||||
|
||||
# 如果不区分大小写,转换为小写
|
||||
if not case_sensitive:
|
||||
search_text = search_text.lower()
|
||||
|
||||
# 检查每个关键词
|
||||
matched_keywords = []
|
||||
for keyword in activation_keywords:
|
||||
check_keyword = keyword if case_sensitive else keyword.lower()
|
||||
if check_keyword in search_text:
|
||||
matched_keywords.append(keyword)
|
||||
|
||||
if matched_keywords:
|
||||
logger.debug(f"{self.log_prefix}动作 {action_name} 匹配到关键词: {matched_keywords}")
|
||||
return True
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix}动作 {action_name} 未匹配到任何关键词: {activation_keywords}")
|
||||
return False
|
||||
@@ -1,935 +0,0 @@
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import contextlib
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
|
||||
from json_repair import repair_json
|
||||
from rich.traceback import install
|
||||
|
||||
from src.chat.logger.plan_reply_logger import PlanReplyLogger
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.core.types import ActionActivationType, ActionInfo, ComponentType
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.person_info.person_info import Person
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services.message_service import (
|
||||
build_readable_messages_with_id,
|
||||
get_messages_before_time_in_chat,
|
||||
replace_user_references,
|
||||
translate_pid_to_description,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.info_data_model import TargetPersonInfo
|
||||
|
||||
logger = get_logger("planner")
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
||||
class ActionPlanner:
|
||||
def __init__(self, chat_id: str, action_manager: ActionManager):
|
||||
self.chat_id = chat_id
|
||||
self.log_prefix = f"[{_chat_manager.get_session_name(chat_id) or chat_id}]"
|
||||
self.action_manager = action_manager
|
||||
# LLM规划器配置
|
||||
self.planner_llm = LLMServiceClient(
|
||||
task_name="planner", request_type="planner"
|
||||
) # 用于动作规划
|
||||
|
||||
self.last_obs_time_mark = 0.0
|
||||
|
||||
self.plan_log: List[Tuple[str, float, Union[List[ActionPlannerInfo], str]]] = []
|
||||
|
||||
# 黑话缓存:使用 OrderedDict 实现 LRU,最多缓存10个
|
||||
self.unknown_words_cache: OrderedDict[str, None] = OrderedDict()
|
||||
self.unknown_words_cache_limit = 10
|
||||
|
||||
def find_message_by_id(
|
||||
self, message_id: str, message_id_list: List[Tuple[str, "SessionMessage"]]
|
||||
) -> Optional["SessionMessage"]:
|
||||
# sourcery skip: use-next
|
||||
"""
|
||||
根据message_id从message_id_list中查找对应的原始消息
|
||||
|
||||
Args:
|
||||
message_id: 要查找的消息ID
|
||||
message_id_list: 消息ID列表,格式为[{'id': str, 'message': dict}, ...]
|
||||
|
||||
Returns:
|
||||
找到的原始消息字典,如果未找到则返回None
|
||||
"""
|
||||
for item in message_id_list:
|
||||
if item[0] == message_id:
|
||||
return item[1]
|
||||
return None
|
||||
|
||||
def _replace_message_ids_with_text(
|
||||
self, text: Optional[str], message_id_list: List[Tuple[str, "SessionMessage"]]
|
||||
) -> Optional[str]:
|
||||
"""将文本中的 m+数字 消息ID替换为原消息内容,并添加双引号"""
|
||||
if not text:
|
||||
return text
|
||||
|
||||
id_to_message = dict(message_id_list)
|
||||
|
||||
# 匹配m后带2-4位数字,前后不是字母数字下划线
|
||||
pattern = r"(?<![A-Za-z0-9_])m\d{2,4}(?![A-Za-z0-9_])"
|
||||
|
||||
matches = re.findall(pattern, text)
|
||||
if matches:
|
||||
available_ids = set(id_to_message.keys())
|
||||
found_ids = set(matches)
|
||||
missing_ids = found_ids - available_ids
|
||||
if missing_ids:
|
||||
logger.info(
|
||||
f"{self.log_prefix}planner理由中引用的消息ID不在当前上下文中: {missing_ids}, 可用ID: {list(available_ids)[:10]}..."
|
||||
)
|
||||
logger.info(
|
||||
f"{self.log_prefix}planner理由替换: 找到{len(matches)}个消息ID引用,其中{len(found_ids & available_ids)}个在上下文中"
|
||||
)
|
||||
|
||||
def _replace(match: re.Match[str]) -> str:
|
||||
msg_id = match.group(0)
|
||||
message = id_to_message.get(msg_id)
|
||||
if not message:
|
||||
logger.warning(f"{self.log_prefix}planner理由引用 {msg_id} 未找到对应消息,保持原样")
|
||||
return msg_id
|
||||
|
||||
msg_text = (message.processed_plain_text or "").strip()
|
||||
if not msg_text:
|
||||
logger.warning(f"{self.log_prefix}planner理由引用 {msg_id} 的消息内容为空,保持原样")
|
||||
return msg_id
|
||||
|
||||
# 替换 [picid:xxx] 为 [图片:描述]
|
||||
pic_pattern = r"\[picid:([^\]]+)\]"
|
||||
|
||||
def replace_pic_id(pic_match: re.Match) -> str:
|
||||
pic_id = pic_match.group(1)
|
||||
description = translate_pid_to_description(pic_id)
|
||||
return f"[图片:{description}]"
|
||||
|
||||
msg_text = re.sub(pic_pattern, replace_pic_id, msg_text)
|
||||
|
||||
# 替换用户引用格式:回复<aaa:bbb> 和 @<aaa:bbb>
|
||||
platform = message.platform or ""
|
||||
if not platform:
|
||||
logger.warning(
|
||||
f"{self.log_prefix}planner: message {message.message_id} has no platform set, bot-self detection will be skipped"
|
||||
)
|
||||
msg_text = replace_user_references(msg_text, platform, replace_bot_name=True)
|
||||
|
||||
# 替换单独的 <用户名:用户ID> 格式(replace_user_references 已处理回复<和@<格式)
|
||||
# 匹配所有 <aaa:bbb> 格式,由于 replace_user_references 已经替换了回复<和@<格式,
|
||||
# 这里匹配到的应该都是单独的格式
|
||||
user_ref_pattern = r"<([^:<>]+):([^:<>]+)>"
|
||||
|
||||
def replace_user_ref(user_match: re.Match) -> str:
|
||||
user_name = user_match.group(1)
|
||||
user_id = user_match.group(2)
|
||||
try:
|
||||
# 检查是否是机器人自己
|
||||
if is_bot_self(platform, str(user_id)):
|
||||
return f"{global_config.bot.nickname}(你)"
|
||||
person = Person(platform=platform, user_id=user_id)
|
||||
return person.person_name or user_name
|
||||
except Exception:
|
||||
# 如果解析失败,使用原始昵称
|
||||
return user_name
|
||||
|
||||
msg_text = re.sub(user_ref_pattern, replace_user_ref, msg_text)
|
||||
|
||||
preview = msg_text if len(msg_text) <= 100 else f"{msg_text[:97]}..."
|
||||
logger.info(f"{self.log_prefix}planner理由引用 {msg_id} -> 消息({preview})")
|
||||
return f"消息({msg_text})"
|
||||
|
||||
return re.sub(pattern, _replace, text)
|
||||
|
||||
def _parse_single_action(
|
||||
self,
|
||||
action_json: dict,
|
||||
message_id_list: List[Tuple[str, "SessionMessage"]],
|
||||
current_available_actions: List[Tuple[str, ActionInfo]],
|
||||
extracted_reasoning: str = "",
|
||||
) -> List[ActionPlannerInfo]:
|
||||
"""解析单个action JSON并返回ActionPlannerInfo列表"""
|
||||
action_planner_infos = []
|
||||
|
||||
try:
|
||||
action = action_json.get("action", "no_reply")
|
||||
# 使用 extracted_reasoning(整体推理文本)作为 reasoning
|
||||
if extracted_reasoning:
|
||||
reasoning = self._replace_message_ids_with_text(extracted_reasoning, message_id_list)
|
||||
if reasoning is None:
|
||||
reasoning = extracted_reasoning
|
||||
else:
|
||||
reasoning = "未提供原因"
|
||||
action_data = {key: value for key, value in action_json.items() if key not in ["action"]}
|
||||
|
||||
# 非no_reply动作需要target_message_id
|
||||
target_message = None
|
||||
|
||||
target_message_id = action_json.get("target_message_id")
|
||||
if target_message_id:
|
||||
# 根据target_message_id查找原始消息
|
||||
target_message = self.find_message_by_id(target_message_id, message_id_list)
|
||||
if target_message is None:
|
||||
logger.warning(f"{self.log_prefix}无法找到target_message_id '{target_message_id}' 对应的消息")
|
||||
# 选择最新消息作为target_message
|
||||
target_message = message_id_list[-1][1]
|
||||
else:
|
||||
target_message = message_id_list[-1][1]
|
||||
logger.debug(f"{self.log_prefix}动作'{action}'缺少target_message_id,使用最新消息作为target_message")
|
||||
|
||||
if action != "no_reply" and target_message is not None and self._is_message_from_self(target_message):
|
||||
logger.info(
|
||||
f"{self.log_prefix}Planner选择了自己的消息 {target_message_id or target_message.message_id} 作为目标,强制使用 no_reply"
|
||||
)
|
||||
reasoning = f"目标消息 {target_message_id or target_message.message_id} 来自机器人自身,违反不回复自身消息规则。原始理由: {reasoning}"
|
||||
action = "no_reply"
|
||||
target_message = None
|
||||
|
||||
# 验证action是否可用
|
||||
available_action_names = [action_name for action_name, _ in current_available_actions]
|
||||
internal_action_names = ["no_reply", "reply", "wait_time"]
|
||||
|
||||
if action not in internal_action_names and action not in available_action_names:
|
||||
logger.warning(
|
||||
f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (可用: {available_action_names}),将强制使用 'no_reply'"
|
||||
)
|
||||
reasoning = (
|
||||
f"LLM 返回了当前不可用的动作 '{action}' (可用: {available_action_names})。原始理由: {reasoning}"
|
||||
)
|
||||
action = "no_reply"
|
||||
|
||||
# 创建ActionPlannerInfo对象
|
||||
# 将列表转换为字典格式
|
||||
available_actions_dict = dict(current_available_actions)
|
||||
action_planner_infos.append(
|
||||
ActionPlannerInfo(
|
||||
action_type=action,
|
||||
reasoning=reasoning,
|
||||
action_data=action_data,
|
||||
action_message=target_message,
|
||||
available_actions=available_actions_dict,
|
||||
action_reasoning=extracted_reasoning or None,
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}解析单个action时出错: {e}")
|
||||
# 将列表转换为字典格式
|
||||
available_actions_dict = dict(current_available_actions)
|
||||
action_planner_infos.append(
|
||||
ActionPlannerInfo(
|
||||
action_type="no_reply",
|
||||
reasoning=f"解析单个action时出错: {e}",
|
||||
action_data={},
|
||||
action_message=None,
|
||||
available_actions=available_actions_dict,
|
||||
action_reasoning=extracted_reasoning or None,
|
||||
)
|
||||
)
|
||||
|
||||
return action_planner_infos
|
||||
|
||||
def _is_message_from_self(self, message: "SessionMessage") -> bool:
|
||||
"""判断消息是否由机器人自身发送(支持多平台,包括 WebUI)"""
|
||||
try:
|
||||
return is_bot_self(message.platform or "", str(message.message_info.user_info.user_id))
|
||||
except AttributeError:
|
||||
logger.warning(f"{self.log_prefix}检测消息发送者失败,缺少必要字段")
|
||||
return False
|
||||
|
||||
def _update_unknown_words_cache(self, new_words: List[str]) -> None:
|
||||
"""
|
||||
更新黑话缓存,将新的黑话加入缓存
|
||||
|
||||
Args:
|
||||
new_words: 新提取的黑话列表
|
||||
"""
|
||||
for word in new_words:
|
||||
if not isinstance(word, str):
|
||||
continue
|
||||
word = word.strip()
|
||||
if not word:
|
||||
continue
|
||||
|
||||
# 如果已存在,移到末尾(LRU)
|
||||
if word in self.unknown_words_cache:
|
||||
self.unknown_words_cache.move_to_end(word)
|
||||
else:
|
||||
# 添加新词
|
||||
self.unknown_words_cache[word] = None
|
||||
# 如果超过限制,移除最老的
|
||||
if len(self.unknown_words_cache) > self.unknown_words_cache_limit:
|
||||
self.unknown_words_cache.popitem(last=False)
|
||||
logger.debug(f"{self.log_prefix}黑话缓存已满,移除最老的黑话")
|
||||
|
||||
def _merge_unknown_words_with_cache(self, new_words: Optional[List[str]]) -> List[str]:
|
||||
"""
|
||||
合并新提取的黑话和缓存中的黑话
|
||||
|
||||
Args:
|
||||
new_words: 新提取的黑话列表(可能为None)
|
||||
|
||||
Returns:
|
||||
合并后的黑话列表(去重)
|
||||
"""
|
||||
# 清理新提取的黑话
|
||||
cleaned_new_words: List[str] = []
|
||||
if new_words:
|
||||
for word in new_words:
|
||||
if isinstance(word, str):
|
||||
if word := word.strip():
|
||||
cleaned_new_words.append(word)
|
||||
|
||||
# 获取缓存中的黑话列表
|
||||
cached_words = list(self.unknown_words_cache.keys())
|
||||
|
||||
# 合并并去重(保留顺序:新提取的在前,缓存的在后)
|
||||
merged_words: List[str] = []
|
||||
seen = set()
|
||||
|
||||
# 先添加新提取的
|
||||
for word in cleaned_new_words:
|
||||
if word not in seen:
|
||||
merged_words.append(word)
|
||||
seen.add(word)
|
||||
|
||||
# 再添加缓存的(如果不在新提取的列表中)
|
||||
for word in cached_words:
|
||||
if word not in seen:
|
||||
merged_words.append(word)
|
||||
seen.add(word)
|
||||
|
||||
return merged_words
|
||||
|
||||
def _process_unknown_words_cache(self, actions: List[ActionPlannerInfo]) -> None:
|
||||
"""
|
||||
处理黑话缓存逻辑:
|
||||
1. 检查是否有 reply action 提取了 unknown_words
|
||||
2. 如果没有提取,移除最老的1个
|
||||
3. 如果缓存数量大于5,移除最老的2个
|
||||
4. 对于每个 reply action,合并缓存和新提取的黑话
|
||||
5. 更新缓存
|
||||
|
||||
Args:
|
||||
actions: 解析后的动作列表
|
||||
"""
|
||||
# 先检查缓存数量,如果大于5,移除最老的2个
|
||||
if len(self.unknown_words_cache) > 5:
|
||||
# 移除最老的2个
|
||||
removed_count = 0
|
||||
for _ in range(2):
|
||||
if len(self.unknown_words_cache) > 0:
|
||||
self.unknown_words_cache.popitem(last=False)
|
||||
removed_count += 1
|
||||
if removed_count > 0:
|
||||
logger.debug(f"{self.log_prefix}缓存数量大于5,移除最老的{removed_count}个缓存")
|
||||
|
||||
# 检查是否有 reply action 提取了 unknown_words
|
||||
has_extracted_unknown_words = False
|
||||
for action in actions:
|
||||
if action.action_type == "reply":
|
||||
action_data = action.action_data or {}
|
||||
unknown_words = action_data.get("unknown_words")
|
||||
if unknown_words and isinstance(unknown_words, list) and len(unknown_words) > 0:
|
||||
has_extracted_unknown_words = True
|
||||
break
|
||||
|
||||
# 如果当前 plan 的 reply 没有提取,移除最老的1个
|
||||
if not has_extracted_unknown_words and len(self.unknown_words_cache) > 0:
|
||||
self.unknown_words_cache.popitem(last=False)
|
||||
logger.debug(f"{self.log_prefix}当前 plan 的 reply 没有提取黑话,移除最老的1个缓存")
|
||||
|
||||
# 对于每个 reply action,合并缓存和新提取的黑话
|
||||
for action in actions:
|
||||
if action.action_type == "reply":
|
||||
action_data = action.action_data or {}
|
||||
new_words = action_data.get("unknown_words")
|
||||
|
||||
# 合并新提取的和缓存的黑话列表
|
||||
if merged_words := self._merge_unknown_words_with_cache(new_words):
|
||||
action_data["unknown_words"] = merged_words
|
||||
logger.debug(
|
||||
f"{self.log_prefix}合并黑话:新提取 {len(new_words) if new_words else 0} 个,"
|
||||
f"缓存 {len(self.unknown_words_cache)} 个,合并后 {len(merged_words)} 个"
|
||||
)
|
||||
else:
|
||||
# 如果没有合并后的黑话,移除 unknown_words 字段
|
||||
action_data.pop("unknown_words", None)
|
||||
|
||||
# 更新缓存(将新提取的黑话加入缓存)
|
||||
if new_words:
|
||||
self._update_unknown_words_cache(new_words)
|
||||
|
||||
async def plan(
|
||||
self,
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
loop_start_time: float = 0.0,
|
||||
force_reply_message: Optional["SessionMessage"] = None,
|
||||
) -> List[ActionPlannerInfo]:
|
||||
# sourcery skip: use-named-expression
|
||||
"""
|
||||
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
|
||||
"""
|
||||
plan_start = time.perf_counter()
|
||||
|
||||
# 获取聊天上下文
|
||||
message_list_before_now = get_messages_before_time_in_chat(
|
||||
chat_id=self.chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.6),
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
message_id_list: list[Tuple[str, "SessionMessage"]] = []
|
||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||
messages=message_list_before_now,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=self.last_obs_time_mark,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
message_list_before_now_short = message_list_before_now[-int(global_config.chat.max_context_size * 0.3) :]
|
||||
chat_content_block_short, message_id_list_short = build_readable_messages_with_id(
|
||||
messages=message_list_before_now_short,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
truncate=False,
|
||||
show_actions=False,
|
||||
)
|
||||
|
||||
self.last_obs_time_mark = time.time()
|
||||
|
||||
# 获取必要信息
|
||||
is_group_chat, chat_target_info, current_available_actions = self.get_necessary_info()
|
||||
|
||||
# 应用激活类型过滤
|
||||
filtered_actions = self._filter_actions_by_activation_type(available_actions, chat_content_block_short)
|
||||
|
||||
logger.debug(f"{self.log_prefix}过滤后有{len(filtered_actions)}个可用动作")
|
||||
|
||||
prompt_build_start = time.perf_counter()
|
||||
# 构建包含所有动作的提示词
|
||||
prompt, message_id_list = await self.build_planner_prompt(
|
||||
is_group_chat=is_group_chat,
|
||||
chat_target_info=chat_target_info,
|
||||
current_available_actions=filtered_actions,
|
||||
chat_content_block=chat_content_block,
|
||||
message_id_list=message_id_list,
|
||||
)
|
||||
prompt_build_ms = (time.perf_counter() - prompt_build_start) * 1000
|
||||
|
||||
# 调用LLM获取决策
|
||||
reasoning, actions, llm_raw_output, llm_reasoning, llm_duration_ms = await self._execute_main_planner(
|
||||
prompt=prompt,
|
||||
message_id_list=message_id_list,
|
||||
filtered_actions=filtered_actions,
|
||||
available_actions=available_actions,
|
||||
loop_start_time=loop_start_time,
|
||||
)
|
||||
|
||||
# 如果有强制回复消息,确保回复该消息
|
||||
if force_reply_message:
|
||||
# 检查是否已经有回复该消息的 action
|
||||
has_reply_to_force_message = any(
|
||||
action.action_type == "reply"
|
||||
and action.action_message
|
||||
and action.action_message.message_id == force_reply_message.message_id
|
||||
for action in actions
|
||||
)
|
||||
|
||||
# 如果没有回复该消息,强制添加回复 action
|
||||
if not has_reply_to_force_message:
|
||||
# 移除所有 no_reply action(如果有)
|
||||
actions = [a for a in actions if a.action_type != "no_reply"]
|
||||
|
||||
# 创建强制回复 action
|
||||
available_actions_dict = dict(current_available_actions)
|
||||
force_reply_action = ActionPlannerInfo(
|
||||
action_type="reply",
|
||||
reasoning="用户提及了我,必须回复该消息",
|
||||
action_data={"loop_start_time": loop_start_time},
|
||||
action_message=force_reply_message,
|
||||
available_actions=available_actions_dict,
|
||||
action_reasoning=None,
|
||||
)
|
||||
# 将强制回复 action 放在最前面
|
||||
actions.insert(0, force_reply_action)
|
||||
logger.info(f"{self.log_prefix} 检测到强制回复消息,已添加回复动作")
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix}Planner:{reasoning}。选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}"
|
||||
)
|
||||
|
||||
self.add_plan_log(reasoning, actions)
|
||||
|
||||
try:
|
||||
PlanReplyLogger.log_plan(
|
||||
chat_id=self.chat_id,
|
||||
prompt=prompt,
|
||||
reasoning=reasoning,
|
||||
raw_output=llm_raw_output,
|
||||
raw_reasoning=llm_reasoning,
|
||||
actions=actions,
|
||||
timing={
|
||||
"prompt_build_ms": round(prompt_build_ms, 2),
|
||||
"llm_duration_ms": round(llm_duration_ms, 2) if llm_duration_ms is not None else None,
|
||||
"total_plan_ms": round((time.perf_counter() - plan_start) * 1000, 2),
|
||||
"loop_start_time": loop_start_time,
|
||||
},
|
||||
extra=None,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"{self.log_prefix}记录plan日志失败")
|
||||
|
||||
return actions
|
||||
|
||||
def add_plan_log(self, reasoning: str, actions: List[ActionPlannerInfo]):
|
||||
self.plan_log.append((reasoning, time.time(), actions))
|
||||
if len(self.plan_log) > 20:
|
||||
self.plan_log.pop(0)
|
||||
|
||||
def add_plan_excute_log(self, result: str):
|
||||
self.plan_log.append(("", time.time(), result))
|
||||
if len(self.plan_log) > 20:
|
||||
self.plan_log.pop(0)
|
||||
|
||||
def get_plan_log_str(self, max_action_records: int = 2, max_execution_records: int = 5) -> str:
|
||||
"""
|
||||
获取计划日志字符串
|
||||
|
||||
Args:
|
||||
max_action_records: 显示多少条最新的action记录,默认2
|
||||
max_execution_records: 显示多少条最新执行结果记录,默认8
|
||||
|
||||
Returns:
|
||||
格式化的日志字符串
|
||||
"""
|
||||
action_records = []
|
||||
execution_records = []
|
||||
|
||||
# 从后往前遍历,收集最新的记录
|
||||
for reasoning, timestamp, content in reversed(self.plan_log):
|
||||
if isinstance(content, list) and all(isinstance(action, ActionPlannerInfo) for action in content):
|
||||
if len(action_records) < max_action_records:
|
||||
action_records.append((reasoning, timestamp, content, "action"))
|
||||
elif len(execution_records) < max_execution_records:
|
||||
execution_records.append((reasoning, timestamp, content, "execution"))
|
||||
|
||||
# 合并所有记录并按时间戳排序
|
||||
all_records = action_records + execution_records
|
||||
all_records.sort(key=lambda x: x[1]) # 按时间戳排序
|
||||
|
||||
plan_log_str = ""
|
||||
|
||||
# 按时间顺序添加所有记录
|
||||
for reasoning, timestamp, content, record_type in all_records:
|
||||
time_str = datetime.fromtimestamp(timestamp).strftime("%H:%M:%S")
|
||||
if record_type == "action":
|
||||
# plan_log_str += f"{time_str}:{reasoning}|你使用了{','.join([action.action_type for action in content])}\n"
|
||||
plan_log_str += f"{time_str}:{reasoning}\n"
|
||||
else:
|
||||
plan_log_str += f"{time_str}:你执行了action:{content}\n"
|
||||
|
||||
return plan_log_str
|
||||
|
||||
async def build_planner_prompt(
|
||||
self,
|
||||
is_group_chat: bool,
|
||||
chat_target_info: Optional["TargetPersonInfo"],
|
||||
current_available_actions: Dict[str, ActionInfo],
|
||||
message_id_list: List[Tuple[str, "SessionMessage"]],
|
||||
chat_content_block: str = "",
|
||||
interest: str = "",
|
||||
) -> tuple[str, List[Tuple[str, "SessionMessage"]]]:
|
||||
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
|
||||
try:
|
||||
actions_before_now_block = self.get_plan_log_str()
|
||||
|
||||
# 构建聊天上下文描述
|
||||
chat_context_description = "你现在正在一个群聊中"
|
||||
|
||||
# 构建动作选项块
|
||||
action_options_block = await self._build_action_options_block(current_available_actions)
|
||||
|
||||
# 其他信息
|
||||
moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"
|
||||
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
bot_name = global_config.bot.nickname
|
||||
bot_nickname = (
|
||||
f",也有人叫你{','.join(global_config.bot.alias_names)}" if global_config.bot.alias_names else ""
|
||||
)
|
||||
name_block = f"你的名字是{bot_name}{bot_nickname},请注意哪些是你自己的发言。"
|
||||
|
||||
# 根据 think_mode 配置决定 reply action 的示例 JSON
|
||||
# 在 JSON 中直接作为 action 参数携带 unknown_words
|
||||
if global_config.chat.think_mode == "classic":
|
||||
reply_action_example = ""
|
||||
if global_config.chat.llm_quote:
|
||||
reply_action_example += (
|
||||
"5.如果要明确回复消息,使用quote,如果消息不多不需要明确回复,设置quote为false\n"
|
||||
)
|
||||
reply_action_example += (
|
||||
'{{"action":"reply", "target_message_id":"消息id(m+数字)", "unknown_words":["词语1","词语2"]'
|
||||
)
|
||||
if global_config.chat.llm_quote:
|
||||
reply_action_example += ', "quote":"如果需要引用该message,设置为true"'
|
||||
reply_action_example += "}"
|
||||
else:
|
||||
reply_action_example = (
|
||||
"5.think_level表示思考深度,0表示该回复不需要思考和回忆,1表示该回复需要进行回忆和思考\n"
|
||||
)
|
||||
if global_config.chat.llm_quote:
|
||||
reply_action_example += (
|
||||
"6.如果要明确回复消息,使用quote,如果消息不多不需要明确回复,设置quote为false\n"
|
||||
)
|
||||
reply_action_example += (
|
||||
'{{"action":"reply", "think_level":数值等级(0或1), '
|
||||
'"target_message_id":"消息id(m+数字)", '
|
||||
'"unknown_words":["词语1","词语2"]'
|
||||
)
|
||||
if global_config.chat.llm_quote:
|
||||
reply_action_example += ', "quote":"如果需要引用该message,设置为true"'
|
||||
reply_action_example += "}"
|
||||
|
||||
planner_prompt_template = prompt_manager.get_prompt("planner")
|
||||
planner_prompt_template.add_context("time_block", time_block)
|
||||
planner_prompt_template.add_context("chat_context_description", chat_context_description)
|
||||
planner_prompt_template.add_context("chat_content_block", chat_content_block)
|
||||
planner_prompt_template.add_context("actions_before_now_block", actions_before_now_block)
|
||||
planner_prompt_template.add_context("action_options_text", action_options_block)
|
||||
planner_prompt_template.add_context("moderation_prompt", moderation_prompt_block)
|
||||
planner_prompt_template.add_context("name_block", name_block)
|
||||
planner_prompt_template.add_context("interest", interest)
|
||||
planner_prompt_template.add_context("plan_style", global_config.personality.plan_style)
|
||||
planner_prompt_template.add_context("reply_action_example", reply_action_example)
|
||||
prompt = await prompt_manager.render_prompt(planner_prompt_template)
|
||||
|
||||
return prompt, message_id_list
|
||||
except Exception as e:
|
||||
logger.error(f"构建 Planner 提示词时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return "构建 Planner Prompt 时出错", []
|
||||
|
||||
def get_necessary_info(self) -> Tuple[bool, Optional["TargetPersonInfo"], Dict[str, ActionInfo]]:
|
||||
"""
|
||||
获取 Planner 需要的必要信息
|
||||
"""
|
||||
is_group_chat = True
|
||||
is_group_chat, chat_target_info = get_chat_type_and_target_info(self.chat_id)
|
||||
logger.debug(f"{self.log_prefix}获取到聊天信息 - 群聊: {is_group_chat}, 目标信息: {chat_target_info}")
|
||||
|
||||
current_available_actions_dict = self.action_manager.get_using_actions()
|
||||
|
||||
# 获取完整的动作信息
|
||||
all_registered_actions: Dict[str, ActionInfo] = component_query_service.get_components_by_type( # type: ignore
|
||||
ComponentType.ACTION
|
||||
)
|
||||
current_available_actions = {}
|
||||
for action_name in current_available_actions_dict:
|
||||
if action_name in all_registered_actions:
|
||||
current_available_actions[action_name] = all_registered_actions[action_name]
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix}使用中的动作 {action_name} 未在已注册动作中找到")
|
||||
|
||||
return is_group_chat, chat_target_info, current_available_actions
|
||||
|
||||
def _filter_actions_by_activation_type(
|
||||
self, available_actions: Dict[str, ActionInfo], chat_content_block: str
|
||||
) -> Dict[str, ActionInfo]:
|
||||
"""根据激活类型过滤动作"""
|
||||
filtered_actions = {}
|
||||
|
||||
for action_name, action_info in available_actions.items():
|
||||
if action_info.activation_type == ActionActivationType.NEVER:
|
||||
logger.debug(f"{self.log_prefix}动作 {action_name} 设置为 NEVER 激活类型,跳过")
|
||||
continue
|
||||
elif action_info.activation_type == ActionActivationType.ALWAYS:
|
||||
filtered_actions[action_name] = action_info
|
||||
elif action_info.activation_type == ActionActivationType.RANDOM:
|
||||
if random.random() < action_info.random_activation_probability:
|
||||
filtered_actions[action_name] = action_info
|
||||
elif action_info.activation_type == ActionActivationType.KEYWORD:
|
||||
if action_info.activation_keywords:
|
||||
for keyword in action_info.activation_keywords:
|
||||
if keyword in chat_content_block:
|
||||
filtered_actions[action_name] = action_info
|
||||
break
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix}未知的激活类型: {action_info.activation_type},跳过处理")
|
||||
|
||||
return filtered_actions
|
||||
|
||||
async def _build_action_options_block(self, current_available_actions: Dict[str, ActionInfo]) -> str:
|
||||
"""构建动作选项块"""
|
||||
if not current_available_actions:
|
||||
return ""
|
||||
|
||||
action_options_block = ""
|
||||
for action_name, action_info in current_available_actions.items():
|
||||
# 构建参数文本
|
||||
param_text = ""
|
||||
if action_info.action_parameters:
|
||||
param_text = "\n"
|
||||
for param_name, param_description in action_info.action_parameters.items():
|
||||
param_text += f' "{param_name}":"{param_description}"\n'
|
||||
param_text = param_text.rstrip("\n")
|
||||
|
||||
# 构建要求文本
|
||||
require_text = "\n".join(f"- {require_item}" for require_item in action_info.action_require)
|
||||
|
||||
parallel_text = "" if action_info.parallel_action else "(当选择这个动作时,请不要选择其他动作)"
|
||||
|
||||
# 获取动作提示模板并填充
|
||||
using_action_prompt = prompt_manager.get_prompt("action")
|
||||
using_action_prompt.add_context("action_name", action_name)
|
||||
using_action_prompt.add_context("action_description", action_info.description)
|
||||
using_action_prompt.add_context("action_parameters", param_text)
|
||||
using_action_prompt.add_context("action_require", require_text)
|
||||
using_action_prompt.add_context("parallel_text", parallel_text)
|
||||
using_action_rendered_prompt = await prompt_manager.render_prompt(using_action_prompt)
|
||||
|
||||
action_options_block += using_action_rendered_prompt
|
||||
|
||||
return action_options_block
|
||||
|
||||
async def _execute_main_planner(
|
||||
self,
|
||||
prompt: str,
|
||||
message_id_list: List[Tuple[str, "SessionMessage"]],
|
||||
filtered_actions: Dict[str, ActionInfo],
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
loop_start_time: float,
|
||||
) -> Tuple[str, List[ActionPlannerInfo], Optional[str], Optional[str], Optional[float]]:
|
||||
"""执行主规划器"""
|
||||
llm_content = None
|
||||
actions: List[ActionPlannerInfo] = []
|
||||
llm_reasoning = None
|
||||
llm_duration_ms = None
|
||||
|
||||
try:
|
||||
# 调用LLM
|
||||
llm_start = time.perf_counter()
|
||||
generation_result = await self.planner_llm.generate_response(prompt=prompt)
|
||||
llm_content = generation_result.response
|
||||
reasoning_content = generation_result.reasoning
|
||||
llm_duration_ms = (time.perf_counter() - llm_start) * 1000
|
||||
llm_reasoning = reasoning_content
|
||||
|
||||
if global_config.debug.show_planner_prompt:
|
||||
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
|
||||
if reasoning_content:
|
||||
logger.info(f"{self.log_prefix}规划器推理: {reasoning_content}")
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
logger.debug(f"{self.log_prefix}规划器原始响应: {llm_content}")
|
||||
if reasoning_content:
|
||||
logger.debug(f"{self.log_prefix}规划器推理: {reasoning_content}")
|
||||
|
||||
except Exception as req_e:
|
||||
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
|
||||
return (
|
||||
f"LLM 请求失败,模型出现问题: {req_e}",
|
||||
[
|
||||
ActionPlannerInfo(
|
||||
action_type="no_reply",
|
||||
reasoning=f"LLM 请求失败,模型出现问题: {req_e}",
|
||||
action_data={},
|
||||
action_message=None,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
],
|
||||
llm_content,
|
||||
llm_reasoning,
|
||||
llm_duration_ms,
|
||||
)
|
||||
|
||||
# 解析LLM响应
|
||||
extracted_reasoning = ""
|
||||
if llm_content:
|
||||
try:
|
||||
json_objects, extracted_reasoning = self._extract_json_from_markdown(llm_content)
|
||||
extracted_reasoning = self._replace_message_ids_with_text(extracted_reasoning, message_id_list) or ""
|
||||
if json_objects:
|
||||
logger.debug(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
|
||||
filtered_actions_list = list(filtered_actions.items())
|
||||
for json_obj in json_objects:
|
||||
actions.extend(
|
||||
self._parse_single_action(
|
||||
json_obj, message_id_list, filtered_actions_list, extracted_reasoning
|
||||
)
|
||||
)
|
||||
else:
|
||||
# 尝试解析为直接的JSON
|
||||
logger.warning(f"{self.log_prefix}LLM没有返回可用动作: {llm_content}")
|
||||
extracted_reasoning = "LLM没有返回可用动作"
|
||||
actions = self._create_no_reply("LLM没有返回可用动作", available_actions)
|
||||
|
||||
except Exception as json_e:
|
||||
logger.warning(f"{self.log_prefix}解析LLM响应JSON失败 {json_e}. LLM原始输出: '{llm_content}'")
|
||||
extracted_reasoning = f"解析LLM响应JSON失败: {json_e}"
|
||||
actions = self._create_no_reply(f"解析LLM响应JSON失败: {json_e}", available_actions)
|
||||
traceback.print_exc()
|
||||
else:
|
||||
extracted_reasoning = "规划器没有获得LLM响应"
|
||||
actions = self._create_no_reply("规划器没有获得LLM响应", available_actions)
|
||||
|
||||
# 添加循环开始时间到所有非no_reply动作
|
||||
for action in actions:
|
||||
action.action_data = action.action_data or {}
|
||||
action.action_data["loop_start_time"] = loop_start_time
|
||||
|
||||
# 去重:如果同一个动作被选择了多次,随机选择其中一个
|
||||
if actions:
|
||||
shuffled = actions.copy()
|
||||
random.shuffle(shuffled)
|
||||
actions = list({a.action_type: a for a in shuffled}.values())
|
||||
|
||||
# 处理黑话缓存逻辑
|
||||
self._process_unknown_words_cache(actions)
|
||||
|
||||
logger.debug(f"{self.log_prefix}规划器选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}")
|
||||
|
||||
return extracted_reasoning, actions, llm_content, llm_reasoning, llm_duration_ms
|
||||
|
||||
def _create_no_reply(self, reasoning: str, available_actions: Dict[str, ActionInfo]) -> List[ActionPlannerInfo]:
|
||||
"""创建no_reply"""
|
||||
return [
|
||||
ActionPlannerInfo(
|
||||
action_type="no_reply",
|
||||
reasoning=reasoning,
|
||||
action_data={},
|
||||
action_message=None,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
]
|
||||
|
||||
def _extract_json_from_markdown(self, content: str) -> Tuple[List[dict], str]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""从Markdown格式的内容中提取JSON对象和推理内容"""
|
||||
json_objects = []
|
||||
reasoning_content = ""
|
||||
|
||||
# 使用正则表达式查找```json包裹的JSON内容
|
||||
json_pattern = r"```json\s*(.*?)\s*```"
|
||||
markdown_matches = re.findall(json_pattern, content, re.DOTALL)
|
||||
|
||||
# 提取JSON之前的内容作为推理文本
|
||||
first_json_pos = len(content)
|
||||
if markdown_matches:
|
||||
# 找到第一个```json的位置
|
||||
first_json_pos = content.find("```json")
|
||||
if first_json_pos > 0:
|
||||
reasoning_content = content[:first_json_pos].strip()
|
||||
# 清理推理内容中的注释标记
|
||||
reasoning_content = re.sub(r"^//\s*", "", reasoning_content, flags=re.MULTILINE)
|
||||
reasoning_content = reasoning_content.strip()
|
||||
|
||||
# 处理```json包裹的JSON
|
||||
for match in markdown_matches:
|
||||
try:
|
||||
# 清理可能的注释和格式问题
|
||||
json_str = re.sub(r"//.*?\n", "\n", match) # 移除单行注释
|
||||
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL) # 移除多行注释
|
||||
if json_str := json_str.strip():
|
||||
# 尝试按行分割,每行可能是一个JSON对象
|
||||
lines = [line.strip() for line in json_str.split("\n") if line.strip()]
|
||||
for line in lines:
|
||||
with contextlib.suppress(json.JSONDecodeError):
|
||||
json_obj = json.loads(repair_json(line))
|
||||
if isinstance(json_obj, dict):
|
||||
if json_obj:
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict) and item:
|
||||
json_objects.append(item)
|
||||
|
||||
# 如果按行解析没有成功(或只得到空字典),尝试将整个块作为一个JSON对象或数组
|
||||
if not json_objects:
|
||||
json_obj = json.loads(repair_json(json_str))
|
||||
if isinstance(json_obj, dict):
|
||||
# 过滤掉空字典
|
||||
if json_obj:
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict) and item:
|
||||
json_objects.append(item)
|
||||
except Exception as e:
|
||||
logger.warning(f"解析JSON块失败: {e}, 块内容: {match[:100]}...")
|
||||
continue
|
||||
|
||||
# 如果没有找到完整的```json```块,尝试查找不完整的代码块(缺少结尾```)
|
||||
if not json_objects:
|
||||
json_start_pos = content.find("```json")
|
||||
if json_start_pos != -1:
|
||||
# 找到```json之后的内容
|
||||
json_content_start = json_start_pos + 7 # ```json的长度
|
||||
# 提取从```json之后到内容结尾的所有内容
|
||||
incomplete_json_str = content[json_content_start:].strip()
|
||||
|
||||
# 提取JSON之前的内容作为推理文本
|
||||
if json_start_pos > 0:
|
||||
reasoning_content = content[:json_start_pos].strip()
|
||||
reasoning_content = re.sub(r"^//\s*", "", reasoning_content, flags=re.MULTILINE)
|
||||
reasoning_content = reasoning_content.strip()
|
||||
|
||||
if incomplete_json_str:
|
||||
try:
|
||||
# 清理可能的注释和格式问题
|
||||
json_str = re.sub(r"//.*?\n", "\n", incomplete_json_str)
|
||||
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL)
|
||||
json_str = json_str.strip()
|
||||
|
||||
if json_str:
|
||||
# 尝试按行分割,每行可能是一个JSON对象
|
||||
lines = [line.strip() for line in json_str.split("\n") if line.strip()]
|
||||
for line in lines:
|
||||
try:
|
||||
json_obj = json.loads(repair_json(line))
|
||||
if isinstance(json_obj, dict):
|
||||
# 过滤掉空字典,避免单个 { 字符被错误修复为 {} 的情况
|
||||
if json_obj:
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict) and item:
|
||||
json_objects.append(item)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 如果按行解析没有成功(或只得到空字典),尝试将整个块作为一个JSON对象或数组
|
||||
if not json_objects:
|
||||
try:
|
||||
json_obj = json.loads(repair_json(json_str))
|
||||
if isinstance(json_obj, dict):
|
||||
# 过滤掉空字典
|
||||
if json_obj:
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict) and item:
|
||||
json_objects.append(item)
|
||||
except Exception as e:
|
||||
logger.debug(f"尝试解析不完整的JSON代码块失败: {e}")
|
||||
except Exception as e:
|
||||
logger.debug(f"处理不完整的JSON代码块时出错: {e}")
|
||||
|
||||
return json_objects, reasoning_content
|
||||
335
src/core/tooling.py
Normal file
335
src/core/tooling.py
Normal file
@@ -0,0 +1,335 @@
|
||||
"""统一工具抽象。
|
||||
|
||||
该模块定义主程序内部统一使用的工具声明、调用与执行结果模型,
|
||||
用于收敛插件 Tool、兼容旧 Action、MaiSaka 内置 Tool 与 MCP Tool。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional, Protocol, runtime_checkable
|
||||
import json
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.payload_content.tool_option import ToolDefinitionInput
|
||||
|
||||
logger = get_logger("core.tooling")
|
||||
|
||||
|
||||
def _normalize_schema_type(raw_type: Any) -> str:
|
||||
"""将原始 Schema 类型值规范化为可读字符串。
|
||||
|
||||
Args:
|
||||
raw_type: 原始类型值。
|
||||
|
||||
Returns:
|
||||
str: 规范化后的类型名称。
|
||||
"""
|
||||
|
||||
normalized_type = str(raw_type or "").strip().lower()
|
||||
if not normalized_type:
|
||||
return "string"
|
||||
if normalized_type == "number":
|
||||
return "number"
|
||||
if normalized_type == "integer":
|
||||
return "integer"
|
||||
if normalized_type == "boolean":
|
||||
return "boolean"
|
||||
if normalized_type == "array":
|
||||
return "array"
|
||||
if normalized_type == "object":
|
||||
return "object"
|
||||
return normalized_type
|
||||
|
||||
|
||||
def build_tool_detailed_description(
|
||||
parameters_schema: Optional[Dict[str, Any]],
|
||||
fallback_description: str = "",
|
||||
) -> str:
|
||||
"""根据参数 Schema 构建工具详细描述。
|
||||
|
||||
Args:
|
||||
parameters_schema: 工具参数对象级 Schema。
|
||||
fallback_description: 无法从 Schema 解析时使用的兜底说明。
|
||||
|
||||
Returns:
|
||||
str: 生成后的详细描述文本。
|
||||
"""
|
||||
|
||||
if not parameters_schema:
|
||||
return fallback_description.strip()
|
||||
|
||||
properties = parameters_schema.get("properties")
|
||||
if not isinstance(properties, dict) or not properties:
|
||||
return fallback_description.strip()
|
||||
|
||||
required_names = {
|
||||
str(name).strip()
|
||||
for name in parameters_schema.get("required", [])
|
||||
if str(name).strip()
|
||||
}
|
||||
|
||||
lines = ["参数说明:"]
|
||||
for parameter_name, parameter_schema in properties.items():
|
||||
if not isinstance(parameter_schema, dict):
|
||||
continue
|
||||
|
||||
normalized_name = str(parameter_name).strip()
|
||||
parameter_type = _normalize_schema_type(parameter_schema.get("type"))
|
||||
required_text = "必填" if normalized_name in required_names else "可选"
|
||||
parameter_description = str(parameter_schema.get("description", "") or "").strip() or "无额外说明"
|
||||
line = f"- {normalized_name}:{parameter_type},{required_text}。{parameter_description}"
|
||||
|
||||
if isinstance(parameter_schema.get("enum"), list) and parameter_schema["enum"]:
|
||||
enum_values = "、".join(str(item) for item in parameter_schema["enum"])
|
||||
line += f" 可选值:{enum_values}。"
|
||||
|
||||
if "default" in parameter_schema:
|
||||
line += f" 默认值:{parameter_schema['default']}。"
|
||||
|
||||
lines.append(line)
|
||||
|
||||
if len(lines) == 1:
|
||||
return fallback_description.strip()
|
||||
|
||||
if fallback_description.strip():
|
||||
lines.append("")
|
||||
lines.append(fallback_description.strip())
|
||||
return "\n".join(lines).strip()
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ToolSpec:
|
||||
"""统一工具声明。"""
|
||||
|
||||
name: str
|
||||
brief_description: str
|
||||
detailed_description: str = ""
|
||||
parameters_schema: Dict[str, Any] | None = None
|
||||
provider_name: str = ""
|
||||
provider_type: str = ""
|
||||
enabled: bool = True
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def build_llm_description(self) -> str:
|
||||
"""构建供 LLM 使用的描述文本。
|
||||
|
||||
Returns:
|
||||
str: 合并后的单段工具描述。
|
||||
"""
|
||||
|
||||
parts = [self.brief_description.strip()]
|
||||
if self.detailed_description.strip():
|
||||
parts.append(self.detailed_description.strip())
|
||||
return "\n\n".join(part for part in parts if part).strip()
|
||||
|
||||
def to_llm_definition(self) -> ToolDefinitionInput:
|
||||
"""转换为统一的 LLM 工具定义。
|
||||
|
||||
Returns:
|
||||
ToolDefinitionInput: 可直接交给模型层的工具定义。
|
||||
"""
|
||||
|
||||
definition: Dict[str, Any] = {
|
||||
"name": self.name,
|
||||
"description": self.build_llm_description(),
|
||||
}
|
||||
if self.parameters_schema is not None:
|
||||
definition["parameters_schema"] = deepcopy(self.parameters_schema)
|
||||
return definition
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ToolInvocation:
|
||||
"""统一工具调用请求。"""
|
||||
|
||||
tool_name: str
|
||||
arguments: Dict[str, Any] = field(default_factory=dict)
|
||||
call_id: str = ""
|
||||
session_id: str = ""
|
||||
stream_id: str = ""
|
||||
reasoning: str = ""
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ToolExecutionContext:
|
||||
"""统一工具执行上下文。"""
|
||||
|
||||
session_id: str = ""
|
||||
stream_id: str = ""
|
||||
reasoning: str = ""
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ToolExecutionResult:
|
||||
"""统一工具执行结果。"""
|
||||
|
||||
tool_name: str
|
||||
success: bool
|
||||
content: str = ""
|
||||
error_message: str = ""
|
||||
structured_content: Any = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def get_history_content(self) -> str:
|
||||
"""获取适合写入对话历史的结果文本。
|
||||
|
||||
Returns:
|
||||
str: 优先使用文本内容,其次使用错误信息。
|
||||
"""
|
||||
|
||||
if self.content.strip():
|
||||
return self.content.strip()
|
||||
if self.structured_content is not None:
|
||||
if isinstance(self.structured_content, str):
|
||||
return self.structured_content.strip()
|
||||
try:
|
||||
return json.dumps(self.structured_content, ensure_ascii=False)
|
||||
except (TypeError, ValueError):
|
||||
return str(self.structured_content).strip()
|
||||
return self.error_message.strip()
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ToolProvider(Protocol):
|
||||
"""统一工具提供者协议。"""
|
||||
|
||||
provider_name: str
|
||||
provider_type: str
|
||||
|
||||
async def list_tools(self) -> list[ToolSpec]:
|
||||
"""列出当前 Provider 暴露的全部工具。"""
|
||||
...
|
||||
|
||||
async def invoke(
|
||||
self,
|
||||
invocation: ToolInvocation,
|
||||
context: Optional[ToolExecutionContext] = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""执行指定工具调用。"""
|
||||
...
|
||||
|
||||
async def close(self) -> None:
|
||||
"""释放 Provider 资源。"""
|
||||
...
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
"""统一工具注册表。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._providers: list[ToolProvider] = []
|
||||
|
||||
def register_provider(self, provider: ToolProvider) -> None:
|
||||
"""注册一个工具提供者。
|
||||
|
||||
Args:
|
||||
provider: 待注册的工具提供者。
|
||||
"""
|
||||
|
||||
self._providers = [item for item in self._providers if item.provider_name != provider.provider_name]
|
||||
self._providers.append(provider)
|
||||
|
||||
def unregister_provider(self, provider_name: str) -> None:
|
||||
"""注销指定名称的工具提供者。
|
||||
|
||||
Args:
|
||||
provider_name: 待移除的 Provider 名称。
|
||||
"""
|
||||
|
||||
self._providers = [item for item in self._providers if item.provider_name != provider_name]
|
||||
|
||||
async def list_tools(self) -> list[ToolSpec]:
|
||||
"""按 Provider 顺序列出全部去重后的工具。
|
||||
|
||||
Returns:
|
||||
list[ToolSpec]: 去重后的工具列表。
|
||||
"""
|
||||
|
||||
collected_specs: list[ToolSpec] = []
|
||||
seen_names: set[str] = set()
|
||||
|
||||
for provider in self._providers:
|
||||
provider_specs = await provider.list_tools()
|
||||
for spec in provider_specs:
|
||||
if not spec.enabled:
|
||||
continue
|
||||
if spec.name in seen_names:
|
||||
logger.warning(
|
||||
f"检测到重复工具名 {spec.name},保留先注册的工具,跳过 provider={provider.provider_name}"
|
||||
)
|
||||
continue
|
||||
seen_names.add(spec.name)
|
||||
collected_specs.append(spec)
|
||||
return collected_specs
|
||||
|
||||
async def get_tool_spec(self, tool_name: str) -> Optional[ToolSpec]:
|
||||
"""查询指定工具声明。
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称。
|
||||
|
||||
Returns:
|
||||
Optional[ToolSpec]: 匹配到的工具声明。
|
||||
"""
|
||||
|
||||
for spec in await self.list_tools():
|
||||
if spec.name == tool_name:
|
||||
return spec
|
||||
return None
|
||||
|
||||
async def has_tool(self, tool_name: str) -> bool:
|
||||
"""判断指定工具是否存在。
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称。
|
||||
|
||||
Returns:
|
||||
bool: 是否存在。
|
||||
"""
|
||||
|
||||
return await self.get_tool_spec(tool_name) is not None
|
||||
|
||||
async def get_llm_definitions(self) -> list[ToolDefinitionInput]:
|
||||
"""获取供 LLM 使用的工具定义列表。
|
||||
|
||||
Returns:
|
||||
list[ToolDefinitionInput]: 统一工具定义列表。
|
||||
"""
|
||||
|
||||
return [spec.to_llm_definition() for spec in await self.list_tools()]
|
||||
|
||||
async def invoke(
|
||||
self,
|
||||
invocation: ToolInvocation,
|
||||
context: Optional[ToolExecutionContext] = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""执行一次工具调用。
|
||||
|
||||
Args:
|
||||
invocation: 工具调用请求。
|
||||
context: 执行上下文。
|
||||
|
||||
Returns:
|
||||
ToolExecutionResult: 工具执行结果。
|
||||
"""
|
||||
|
||||
for provider in self._providers:
|
||||
provider_specs = await provider.list_tools()
|
||||
if any(spec.name == invocation.tool_name and spec.enabled for spec in provider_specs):
|
||||
return await provider.invoke(invocation, context)
|
||||
|
||||
return ToolExecutionResult(
|
||||
tool_name=invocation.tool_name,
|
||||
success=False,
|
||||
error_message=f"未找到工具:{invocation.tool_name}",
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭全部 Provider。"""
|
||||
|
||||
for provider in self._providers:
|
||||
await provider.close()
|
||||
@@ -1,124 +1,163 @@
|
||||
"""
|
||||
MaiSaka built-in tool definitions.
|
||||
"""
|
||||
"""Maisaka 内置工具声明。"""
|
||||
|
||||
from typing import List
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from src.llm_models.payload_content.tool_option import ToolOption, ToolParamType
|
||||
from src.core.tooling import ToolSpec, build_tool_detailed_description
|
||||
from src.llm_models.payload_content.tool_option import ToolDefinitionInput
|
||||
|
||||
|
||||
def create_builtin_tools() -> List[ToolOption]:
|
||||
"""Create built-in tools exposed to the main chat-loop model."""
|
||||
from src.llm_models.payload_content.tool_option import ToolOptionBuilder
|
||||
def _build_tool_spec(
|
||||
name: str,
|
||||
brief_description: str,
|
||||
parameters_schema: Dict[str, Any] | None = None,
|
||||
detailed_description: str = "",
|
||||
) -> ToolSpec:
|
||||
"""构建单个内置工具声明。
|
||||
|
||||
tools: List[ToolOption] = []
|
||||
Args:
|
||||
name: 工具名称。
|
||||
brief_description: 简要描述。
|
||||
parameters_schema: 参数 Schema。
|
||||
detailed_description: 详细描述;为空时自动根据参数生成。
|
||||
|
||||
wait_builder = ToolOptionBuilder()
|
||||
wait_builder.set_name("wait")
|
||||
wait_builder.set_description("Pause speaking and wait for the user to provide more input.")
|
||||
wait_builder.add_param(
|
||||
name="seconds",
|
||||
param_type=ToolParamType.INTEGER,
|
||||
description="How many seconds to wait before timing out.",
|
||||
required=True,
|
||||
enum_values=None,
|
||||
Returns:
|
||||
ToolSpec: 构建完成的工具声明。
|
||||
"""
|
||||
|
||||
normalized_schema = deepcopy(parameters_schema) if parameters_schema is not None else None
|
||||
return ToolSpec(
|
||||
name=name,
|
||||
brief_description=brief_description,
|
||||
detailed_description=(
|
||||
detailed_description.strip()
|
||||
or build_tool_detailed_description(normalized_schema)
|
||||
),
|
||||
parameters_schema=normalized_schema,
|
||||
provider_name="maisaka_builtin",
|
||||
provider_type="builtin",
|
||||
)
|
||||
tools.append(wait_builder.build())
|
||||
|
||||
reply_builder = ToolOptionBuilder()
|
||||
reply_builder.set_name("reply")
|
||||
reply_builder.set_description(
|
||||
"Generate and emit a visible reply based on the current thought. "
|
||||
"You must specify the target user msg_id to reply to."
|
||||
)
|
||||
reply_builder.add_param(
|
||||
name="msg_id",
|
||||
param_type=ToolParamType.STRING,
|
||||
description="The msg_id of the specific user message that this reply should target.",
|
||||
required=True,
|
||||
enum_values=None,
|
||||
)
|
||||
reply_builder.add_param(
|
||||
name="quote",
|
||||
param_type=ToolParamType.BOOLEAN,
|
||||
description="Whether the visible reply should be sent as a quoted reply to the target msg_id.",
|
||||
required=False,
|
||||
enum_values=None,
|
||||
)
|
||||
reply_builder.add_param(
|
||||
name="unknown_words",
|
||||
param_type=ToolParamType.ARRAY,
|
||||
description="Optional list of words or phrases that may need jargon lookup before replying.",
|
||||
required=False,
|
||||
enum_values=None,
|
||||
items_schema={"type": "string"},
|
||||
)
|
||||
tools.append(reply_builder.build())
|
||||
|
||||
query_jargon_builder = ToolOptionBuilder()
|
||||
query_jargon_builder.set_name("query_jargon")
|
||||
query_jargon_builder.set_description(
|
||||
"Query the meanings of one or more jargon words in the current chat context."
|
||||
)
|
||||
query_jargon_builder.add_param(
|
||||
name="words",
|
||||
param_type=ToolParamType.ARRAY,
|
||||
description="A list of words or phrases to query from the jargon store.",
|
||||
required=True,
|
||||
enum_values=None,
|
||||
items_schema={"type": "string"},
|
||||
)
|
||||
tools.append(query_jargon_builder.build())
|
||||
|
||||
query_person_info_builder = ToolOptionBuilder()
|
||||
query_person_info_builder.set_name("query_person_info")
|
||||
query_person_info_builder.set_description(
|
||||
"Query profile and memory information about a specific person by person name, nickname, or user ID."
|
||||
)
|
||||
query_person_info_builder.add_param(
|
||||
name="person_name",
|
||||
param_type=ToolParamType.STRING,
|
||||
description="The person's name, nickname, or user ID to search for.",
|
||||
required=True,
|
||||
enum_values=None,
|
||||
)
|
||||
query_person_info_builder.add_param(
|
||||
name="limit",
|
||||
param_type=ToolParamType.INTEGER,
|
||||
description="Maximum number of matched person records to return. Defaults to 3.",
|
||||
required=False,
|
||||
enum_values=None,
|
||||
)
|
||||
tools.append(query_person_info_builder.build())
|
||||
|
||||
no_reply_builder = ToolOptionBuilder()
|
||||
no_reply_builder.set_name("no_reply")
|
||||
no_reply_builder.set_description("Do not emit a visible reply this round and continue thinking.")
|
||||
tools.append(no_reply_builder.build())
|
||||
|
||||
stop_builder = ToolOptionBuilder()
|
||||
stop_builder.set_name("stop")
|
||||
stop_builder.set_description("Stop the current inner loop and return control to the outer chat flow.")
|
||||
tools.append(stop_builder.build())
|
||||
|
||||
send_emoji_builder = ToolOptionBuilder()
|
||||
send_emoji_builder.set_name("send_emoji")
|
||||
send_emoji_builder.set_description(
|
||||
"Send an emoji sticker to help express emotions. "
|
||||
"You should specify the emotion type to select an appropriate emoji."
|
||||
)
|
||||
send_emoji_builder.add_param(
|
||||
name="emotion",
|
||||
param_type=ToolParamType.STRING,
|
||||
description="The emotion type for selecting an appropriate emoji (e.g., 'happy', 'sad', 'angry', 'surprised', etc.).",
|
||||
required=False,
|
||||
enum_values=None,
|
||||
)
|
||||
tools.append(send_emoji_builder.build())
|
||||
|
||||
return tools
|
||||
|
||||
|
||||
def get_builtin_tools() -> List[ToolOption]:
|
||||
"""Return built-in tools."""
|
||||
return create_builtin_tools()
|
||||
def create_builtin_tool_specs() -> List[ToolSpec]:
|
||||
"""创建 Maisaka 内置工具声明列表。
|
||||
|
||||
Returns:
|
||||
List[ToolSpec]: 内置工具声明列表。
|
||||
"""
|
||||
|
||||
return [
|
||||
_build_tool_spec(
|
||||
name="wait",
|
||||
brief_description="暂停当前对话并等待用户新的输入。",
|
||||
parameters_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"seconds": {
|
||||
"type": "integer",
|
||||
"description": "等待的秒数。",
|
||||
},
|
||||
},
|
||||
"required": ["seconds"],
|
||||
},
|
||||
),
|
||||
_build_tool_spec(
|
||||
name="reply",
|
||||
brief_description="根据当前思考生成并发送一条可见回复。",
|
||||
parameters_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"msg_id": {
|
||||
"type": "string",
|
||||
"description": "要回复的目标用户消息编号。",
|
||||
},
|
||||
"quote": {
|
||||
"type": "boolean",
|
||||
"description": "是否以引用回复的方式发送。",
|
||||
"default": True,
|
||||
},
|
||||
"unknown_words": {
|
||||
"type": "array",
|
||||
"description": "回复前可能需要查询的黑话或词条列表。",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
},
|
||||
"required": ["msg_id"],
|
||||
},
|
||||
),
|
||||
_build_tool_spec(
|
||||
name="query_jargon",
|
||||
brief_description="查询当前聊天上下文中的黑话或词条含义。",
|
||||
parameters_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"words": {
|
||||
"type": "array",
|
||||
"description": "要查询的词条列表。",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
},
|
||||
"required": ["words"],
|
||||
},
|
||||
),
|
||||
_build_tool_spec(
|
||||
name="query_person_info",
|
||||
brief_description="查询某个人的档案和相关记忆信息。",
|
||||
parameters_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"person_name": {
|
||||
"type": "string",
|
||||
"description": "人物名称、昵称或用户 ID。",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "最多返回多少条匹配记录。",
|
||||
"default": 3,
|
||||
},
|
||||
},
|
||||
"required": ["person_name"],
|
||||
},
|
||||
),
|
||||
_build_tool_spec(
|
||||
name="no_reply",
|
||||
brief_description="本轮不发送可见回复,继续下一步思考。",
|
||||
),
|
||||
_build_tool_spec(
|
||||
name="stop",
|
||||
brief_description="暂停当前内部循环,等待新的外部消息。",
|
||||
),
|
||||
_build_tool_spec(
|
||||
name="send_emoji",
|
||||
brief_description="发送一个合适的表情包来辅助表达情绪。",
|
||||
parameters_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"emotion": {
|
||||
"type": "string",
|
||||
"description": "希望表达的情绪,例如 happy、sad、angry 等。",
|
||||
},
|
||||
},
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_builtin_tool_specs() -> List[ToolSpec]:
|
||||
"""获取 Maisaka 内置工具声明。
|
||||
|
||||
Returns:
|
||||
List[ToolSpec]: 内置工具声明列表。
|
||||
"""
|
||||
|
||||
return create_builtin_tool_specs()
|
||||
|
||||
|
||||
def get_builtin_tools() -> List[ToolDefinitionInput]:
|
||||
"""获取兼容旧模型层的内置工具定义。
|
||||
|
||||
Returns:
|
||||
List[ToolDefinitionInput]: 可直接传给模型层的工具定义。
|
||||
"""
|
||||
|
||||
return [tool_spec.to_llm_definition() for tool_spec in create_builtin_tool_specs()]
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from dataclasses import dataclass
|
||||
"""Maisaka 对话循环服务。"""
|
||||
|
||||
from base64 import b64decode
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
from time import perf_counter
|
||||
@@ -20,6 +22,7 @@ from src.common.data_models.message_component_data_model import MessageSequence,
|
||||
from src.common.logger import get_logger
|
||||
from src.common.prompt_i18n import load_prompt
|
||||
from src.config.config import global_config
|
||||
from src.core.tooling import ToolRegistry
|
||||
from src.know_u.knowledge import extract_category_ids_from_result
|
||||
from src.llm_models.model_client.base_client import BaseClient
|
||||
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
|
||||
@@ -52,10 +55,19 @@ class MaisakaChatLoopService:
|
||||
temperature: float = 0.5,
|
||||
max_tokens: int = 2048,
|
||||
) -> None:
|
||||
"""初始化 Maisaka 对话循环服务。
|
||||
|
||||
Args:
|
||||
chat_system_prompt: 可选的系统提示词。
|
||||
temperature: 规划器温度参数。
|
||||
max_tokens: 规划器最大输出长度。
|
||||
"""
|
||||
|
||||
self._temperature = temperature
|
||||
self._max_tokens = max_tokens
|
||||
self._extra_tools: List[ToolOption] = []
|
||||
self._interrupt_flag: asyncio.Event | None = None
|
||||
self._tool_registry: ToolRegistry | None = None
|
||||
self._prompts_loaded = False
|
||||
self._prompt_load_lock = asyncio.Lock()
|
||||
self._personality_prompt = self._build_personality_prompt()
|
||||
@@ -67,9 +79,13 @@ class MaisakaChatLoopService:
|
||||
|
||||
@property
|
||||
def personality_prompt(self) -> str:
|
||||
"""返回当前人格提示词。"""
|
||||
|
||||
return self._personality_prompt
|
||||
|
||||
def _build_personality_prompt(self) -> str:
|
||||
"""构造人格提示词。"""
|
||||
|
||||
try:
|
||||
bot_name = global_config.bot.nickname
|
||||
if global_config.bot.alias_names:
|
||||
@@ -92,6 +108,12 @@ class MaisakaChatLoopService:
|
||||
return "Your name is MaiMai; persona: lively and cute AI assistant."
|
||||
|
||||
async def ensure_chat_prompt_loaded(self, tools_section: str = "") -> None:
|
||||
"""确保主聊天提示词已经加载完成。
|
||||
|
||||
Args:
|
||||
tools_section: 额外注入到提示词中的工具说明片段。
|
||||
"""
|
||||
|
||||
if self._prompts_loaded:
|
||||
return
|
||||
|
||||
@@ -112,8 +134,23 @@ class MaisakaChatLoopService:
|
||||
self._prompts_loaded = True
|
||||
|
||||
def set_extra_tools(self, tools: List[ToolDefinitionInput]) -> None:
|
||||
"""设置额外工具定义。
|
||||
|
||||
Args:
|
||||
tools: 兼容旧接口的额外工具定义列表。
|
||||
"""
|
||||
|
||||
self._extra_tools = normalize_tool_options(tools) or []
|
||||
|
||||
def set_tool_registry(self, tool_registry: ToolRegistry | None) -> None:
|
||||
"""设置统一工具注册表。
|
||||
|
||||
Args:
|
||||
tool_registry: 统一工具注册表;传入 ``None`` 时退回旧工具列表模式。
|
||||
"""
|
||||
|
||||
self._tool_registry = tool_registry
|
||||
|
||||
def set_interrupt_flag(self, interrupt_flag: asyncio.Event | None) -> None:
|
||||
"""设置当前 planner 请求使用的中断标记。"""
|
||||
self._interrupt_flag = interrupt_flag
|
||||
@@ -329,6 +366,15 @@ class MaisakaChatLoopService:
|
||||
)
|
||||
|
||||
async def chat_loop_step(self, chat_history: List[LLMContextMessage]) -> ChatResponse:
|
||||
"""执行一轮 Maisaka 规划器请求。
|
||||
|
||||
Args:
|
||||
chat_history: 当前对话历史。
|
||||
|
||||
Returns:
|
||||
ChatResponse: 本轮规划器返回结果。
|
||||
"""
|
||||
|
||||
await self.ensure_chat_prompt_loaded()
|
||||
selected_history, selection_reason = self._select_llm_context_messages(chat_history)
|
||||
|
||||
@@ -336,7 +382,11 @@ class MaisakaChatLoopService:
|
||||
del _client
|
||||
return self._build_request_messages(selected_history)
|
||||
|
||||
all_tools: List[ToolDefinitionInput] = [*get_builtin_tools(), *self._extra_tools]
|
||||
all_tools: List[ToolDefinitionInput]
|
||||
if self._tool_registry is not None:
|
||||
all_tools = await self._tool_registry.get_llm_definitions()
|
||||
else:
|
||||
all_tools = [*get_builtin_tools(), *self._extra_tools]
|
||||
built_messages = self._build_request_messages(selected_history)
|
||||
|
||||
ordered_panels: List[Panel] = []
|
||||
|
||||
@@ -20,6 +20,7 @@ from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import PersonInfo
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation
|
||||
from src.know_u.knowledge_store import get_knowledge_store
|
||||
from src.learners.jargon_explainer import search_jargon
|
||||
from src.llm_models.exceptions import ReqAbortException
|
||||
@@ -37,13 +38,10 @@ from .message_adapter import (
|
||||
clone_message_sequence,
|
||||
format_speaker_content,
|
||||
)
|
||||
from .tool_handlers import (
|
||||
handle_mcp_tool,
|
||||
handle_unknown_tool,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .runtime import MaisakaHeartFlowChatting
|
||||
from .tool_provider import BuiltinToolHandler
|
||||
|
||||
logger = get_logger("maisaka_reasoning_engine")
|
||||
|
||||
@@ -55,6 +53,23 @@ class MaisakaReasoningEngine:
|
||||
self._runtime = runtime
|
||||
self._last_reasoning_content: str = ""
|
||||
|
||||
def build_builtin_tool_handlers(self) -> dict[str, "BuiltinToolHandler"]:
|
||||
"""构造 Maisaka 内置工具处理器映射。
|
||||
|
||||
Returns:
|
||||
dict[str, BuiltinToolHandler]: 工具名到处理器的映射。
|
||||
"""
|
||||
|
||||
return {
|
||||
"reply": self._invoke_reply_tool,
|
||||
"no_reply": self._invoke_no_reply_tool,
|
||||
"query_jargon": self._invoke_query_jargon_tool,
|
||||
"query_person_info": self._invoke_query_person_info_tool,
|
||||
"wait": self._invoke_wait_tool,
|
||||
"stop": self._invoke_stop_tool,
|
||||
"send_emoji": self._invoke_send_emoji_tool,
|
||||
}
|
||||
|
||||
async def run_loop(self) -> None:
|
||||
"""独立消费消息批次,并执行对应的内部思考轮次。"""
|
||||
try:
|
||||
@@ -360,79 +375,287 @@ class MaisakaReasoningEngine:
|
||||
return processed_segments
|
||||
return [reply_text.strip()]
|
||||
|
||||
def _build_tool_invocation(self, tool_call: ToolCall, latest_thought: str) -> ToolInvocation:
|
||||
"""将模型输出的工具调用转换为统一调用对象。
|
||||
|
||||
Args:
|
||||
tool_call: 模型返回的工具调用。
|
||||
latest_thought: 当前轮的最新思考文本。
|
||||
|
||||
Returns:
|
||||
ToolInvocation: 统一工具调用对象。
|
||||
"""
|
||||
|
||||
return ToolInvocation(
|
||||
tool_name=tool_call.func_name,
|
||||
arguments=dict(tool_call.args or {}),
|
||||
call_id=tool_call.call_id,
|
||||
session_id=self._runtime.session_id,
|
||||
stream_id=self._runtime.session_id,
|
||||
reasoning=latest_thought,
|
||||
)
|
||||
|
||||
def _build_tool_execution_context(
|
||||
self,
|
||||
latest_thought: str,
|
||||
anchor_message: SessionMessage,
|
||||
) -> ToolExecutionContext:
|
||||
"""构造统一工具执行上下文。
|
||||
|
||||
Args:
|
||||
latest_thought: 当前轮的最新思考文本。
|
||||
anchor_message: 当前轮的锚点消息。
|
||||
|
||||
Returns:
|
||||
ToolExecutionContext: 统一工具执行上下文。
|
||||
"""
|
||||
|
||||
return ToolExecutionContext(
|
||||
session_id=self._runtime.session_id,
|
||||
stream_id=self._runtime.session_id,
|
||||
reasoning=latest_thought,
|
||||
metadata={"anchor_message": anchor_message},
|
||||
)
|
||||
|
||||
def _append_tool_execution_result(self, tool_call: ToolCall, result: ToolExecutionResult) -> None:
|
||||
"""将统一工具执行结果写回 Maisaka 历史。
|
||||
|
||||
Args:
|
||||
tool_call: 原始工具调用对象。
|
||||
result: 统一工具执行结果。
|
||||
"""
|
||||
|
||||
history_content = result.get_history_content()
|
||||
if not history_content:
|
||||
history_content = "工具执行成功。" if result.success else f"工具 {tool_call.func_name} 执行失败。"
|
||||
|
||||
self._runtime._chat_history.append(
|
||||
ToolResultMessage(
|
||||
content=history_content,
|
||||
timestamp=datetime.now(),
|
||||
tool_call_id=tool_call.call_id,
|
||||
tool_name=tool_call.func_name,
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_tool_call_from_invocation(invocation: ToolInvocation) -> ToolCall:
|
||||
"""将统一工具调用对象恢复为 `ToolCall` 兼容对象。
|
||||
|
||||
Args:
|
||||
invocation: 统一工具调用对象。
|
||||
|
||||
Returns:
|
||||
ToolCall: 兼容旧内部逻辑的工具调用对象。
|
||||
"""
|
||||
|
||||
return ToolCall(
|
||||
call_id=invocation.call_id or f"{invocation.tool_name}_call",
|
||||
func_name=invocation.tool_name,
|
||||
args=dict(invocation.arguments),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_tool_success_result(
|
||||
tool_name: str,
|
||||
content: str = "",
|
||||
structured_content: Any = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""构造统一工具成功结果。
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称。
|
||||
content: 结果文本。
|
||||
structured_content: 结构化结果。
|
||||
metadata: 附加元数据。
|
||||
|
||||
Returns:
|
||||
ToolExecutionResult: 统一工具成功结果。
|
||||
"""
|
||||
|
||||
return ToolExecutionResult(
|
||||
tool_name=tool_name,
|
||||
success=True,
|
||||
content=content,
|
||||
structured_content=structured_content,
|
||||
metadata=dict(metadata or {}),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_tool_failure_result(
|
||||
tool_name: str,
|
||||
error_message: str,
|
||||
structured_content: Any = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""构造统一工具失败结果。
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称。
|
||||
error_message: 错误信息。
|
||||
structured_content: 结构化结果。
|
||||
metadata: 附加元数据。
|
||||
|
||||
Returns:
|
||||
ToolExecutionResult: 统一工具失败结果。
|
||||
"""
|
||||
|
||||
return ToolExecutionResult(
|
||||
tool_name=tool_name,
|
||||
success=False,
|
||||
error_message=error_message,
|
||||
structured_content=structured_content,
|
||||
metadata=dict(metadata or {}),
|
||||
)
|
||||
|
||||
async def _invoke_reply_tool(
|
||||
self,
|
||||
invocation: ToolInvocation,
|
||||
context: Optional[ToolExecutionContext] = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""执行 reply 内置工具。"""
|
||||
|
||||
latest_thought = context.reasoning if context is not None else invocation.reasoning
|
||||
return await self._handle_reply(self._build_tool_call_from_invocation(invocation), latest_thought)
|
||||
|
||||
async def _invoke_no_reply_tool(
|
||||
self,
|
||||
invocation: ToolInvocation,
|
||||
context: Optional[ToolExecutionContext] = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""执行 no_reply 内置工具。"""
|
||||
|
||||
del context
|
||||
return self._build_tool_success_result(invocation.tool_name, "本轮未发送可见回复。")
|
||||
|
||||
async def _invoke_query_jargon_tool(
|
||||
self,
|
||||
invocation: ToolInvocation,
|
||||
context: Optional[ToolExecutionContext] = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""执行 query_jargon 内置工具。"""
|
||||
|
||||
del context
|
||||
return await self._handle_query_jargon(self._build_tool_call_from_invocation(invocation))
|
||||
|
||||
async def _invoke_query_person_info_tool(
|
||||
self,
|
||||
invocation: ToolInvocation,
|
||||
context: Optional[ToolExecutionContext] = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""执行 query_person_info 内置工具。"""
|
||||
|
||||
del context
|
||||
return await self._handle_query_person_info(self._build_tool_call_from_invocation(invocation))
|
||||
|
||||
async def _invoke_wait_tool(
|
||||
self,
|
||||
invocation: ToolInvocation,
|
||||
context: Optional[ToolExecutionContext] = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""执行 wait 内置工具。"""
|
||||
|
||||
del context
|
||||
seconds = invocation.arguments.get("seconds", 30)
|
||||
try:
|
||||
wait_seconds = int(seconds)
|
||||
except (TypeError, ValueError):
|
||||
wait_seconds = 30
|
||||
wait_seconds = max(0, wait_seconds)
|
||||
self._runtime._enter_wait_state(seconds=wait_seconds, tool_call_id=invocation.call_id)
|
||||
return self._build_tool_success_result(
|
||||
invocation.tool_name,
|
||||
f"当前对话循环进入等待状态,最长等待 {wait_seconds} 秒。",
|
||||
metadata={"pause_execution": True},
|
||||
)
|
||||
|
||||
async def _invoke_stop_tool(
|
||||
self,
|
||||
invocation: ToolInvocation,
|
||||
context: Optional[ToolExecutionContext] = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""执行 stop 内置工具。"""
|
||||
|
||||
del context
|
||||
self._runtime._enter_stop_state()
|
||||
return self._build_tool_success_result(
|
||||
invocation.tool_name,
|
||||
"当前对话循环已暂停,等待新消息到来。",
|
||||
metadata={"pause_execution": True},
|
||||
)
|
||||
|
||||
async def _invoke_send_emoji_tool(
|
||||
self,
|
||||
invocation: ToolInvocation,
|
||||
context: Optional[ToolExecutionContext] = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""执行 send_emoji 内置工具。"""
|
||||
|
||||
del context
|
||||
return await self._handle_send_emoji(self._build_tool_call_from_invocation(invocation))
|
||||
|
||||
async def _handle_tool_calls(
|
||||
self,
|
||||
tool_calls: list[ToolCall],
|
||||
latest_thought: str,
|
||||
anchor_message: SessionMessage,
|
||||
) -> bool:
|
||||
"""执行一批统一工具调用。
|
||||
|
||||
Args:
|
||||
tool_calls: 模型返回的工具调用列表。
|
||||
latest_thought: 当前轮的最新思考文本。
|
||||
anchor_message: 当前轮的锚点消息。
|
||||
|
||||
Returns:
|
||||
bool: 是否需要暂停当前思考循环。
|
||||
"""
|
||||
|
||||
if self._runtime._tool_registry is None:
|
||||
for tool_call in tool_calls:
|
||||
self._append_tool_execution_result(
|
||||
tool_call,
|
||||
ToolExecutionResult(
|
||||
tool_name=tool_call.func_name,
|
||||
success=False,
|
||||
error_message="统一工具注册表尚未初始化。",
|
||||
),
|
||||
)
|
||||
return False
|
||||
|
||||
execution_context = self._build_tool_execution_context(latest_thought, anchor_message)
|
||||
for tool_call in tool_calls:
|
||||
if tool_call.func_name == "reply":
|
||||
reply_sent = await self._handle_reply(tool_call, latest_thought, anchor_message)
|
||||
if not reply_sent:
|
||||
logger.warning(
|
||||
f"{self._runtime.log_prefix} 回复工具未生成可见消息,将继续下一轮循环"
|
||||
)
|
||||
continue
|
||||
invocation = self._build_tool_invocation(tool_call, latest_thought)
|
||||
result = await self._runtime._tool_registry.invoke(invocation, execution_context)
|
||||
self._append_tool_execution_result(tool_call, result)
|
||||
|
||||
if tool_call.func_name == "no_reply":
|
||||
self._runtime._chat_history.append(
|
||||
self._build_tool_message(
|
||||
tool_call,
|
||||
"本轮未发送可见回复。",
|
||||
)
|
||||
)
|
||||
continue
|
||||
if not result.success and tool_call.func_name == "reply":
|
||||
logger.warning(f"{self._runtime.log_prefix} 回复工具未生成可见消息,将继续下一轮循环")
|
||||
|
||||
if tool_call.func_name == "query_jargon":
|
||||
await self._handle_query_jargon(tool_call)
|
||||
continue
|
||||
|
||||
if tool_call.func_name == "query_person_info":
|
||||
await self._handle_query_person_info(tool_call)
|
||||
continue
|
||||
|
||||
if tool_call.func_name == "wait":
|
||||
seconds = (tool_call.args or {}).get("seconds", 30)
|
||||
try:
|
||||
wait_seconds = int(seconds)
|
||||
except (TypeError, ValueError):
|
||||
wait_seconds = 30
|
||||
wait_seconds = max(0, wait_seconds)
|
||||
self._runtime._enter_wait_state(seconds=wait_seconds, tool_call_id=tool_call.call_id)
|
||||
if bool(result.metadata.get("pause_execution", False)):
|
||||
return True
|
||||
|
||||
if tool_call.func_name == "stop":
|
||||
self._runtime._chat_history.append(
|
||||
self._build_tool_message(
|
||||
tool_call,
|
||||
"当前对话循环已暂停,等待新消息到来。",
|
||||
)
|
||||
)
|
||||
self._runtime._enter_stop_state()
|
||||
return True
|
||||
|
||||
if tool_call.func_name == "send_emoji":
|
||||
await self._handle_send_emoji(tool_call, anchor_message)
|
||||
continue
|
||||
|
||||
if self._runtime._mcp_manager and self._runtime._mcp_manager.is_mcp_tool(tool_call.func_name):
|
||||
await handle_mcp_tool(tool_call, self._runtime._chat_history, self._runtime._mcp_manager)
|
||||
continue
|
||||
|
||||
await handle_unknown_tool(tool_call, self._runtime._chat_history)
|
||||
|
||||
return False
|
||||
|
||||
async def _handle_query_jargon(self, tool_call: ToolCall) -> None:
|
||||
async def _handle_query_jargon(self, tool_call: ToolCall) -> ToolExecutionResult:
|
||||
"""查询黑话解释并返回统一工具结果。
|
||||
|
||||
Args:
|
||||
tool_call: 当前工具调用。
|
||||
|
||||
Returns:
|
||||
ToolExecutionResult: 统一工具执行结果。
|
||||
"""
|
||||
|
||||
tool_args = tool_call.args or {}
|
||||
raw_words = tool_args.get("words")
|
||||
|
||||
if not isinstance(raw_words, list):
|
||||
self._runtime._chat_history.append(
|
||||
self._build_tool_message(tool_call, "查询黑话工具需要提供 `words` 数组参数。")
|
||||
return self._build_tool_failure_result(
|
||||
tool_call.func_name,
|
||||
"查询黑话工具需要提供 `words` 数组参数。",
|
||||
)
|
||||
return
|
||||
|
||||
words: list[str] = []
|
||||
seen_words: set[str] = set()
|
||||
@@ -446,10 +669,10 @@ class MaisakaReasoningEngine:
|
||||
words.append(word)
|
||||
|
||||
if not words:
|
||||
self._runtime._chat_history.append(
|
||||
self._build_tool_message(tool_call, "查询黑话工具至少需要一个非空词条。")
|
||||
return self._build_tool_failure_result(
|
||||
tool_call.func_name,
|
||||
"查询黑话工具至少需要一个非空词条。",
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(f"{self._runtime.log_prefix} 已触发黑话查询: 词条={words!r}")
|
||||
|
||||
@@ -479,31 +702,38 @@ class MaisakaReasoningEngine:
|
||||
)
|
||||
|
||||
logger.info(f"{self._runtime.log_prefix} 黑话查询完成: 结果={results!r}")
|
||||
self._runtime._chat_history.append(
|
||||
self._build_tool_message(
|
||||
tool_call,
|
||||
json.dumps({"results": results}, ensure_ascii=False),
|
||||
)
|
||||
return self._build_tool_success_result(
|
||||
tool_call.func_name,
|
||||
json.dumps({"results": results}, ensure_ascii=False),
|
||||
structured_content={"results": results},
|
||||
)
|
||||
|
||||
async def _handle_query_person_info(self, tool_call: ToolCall) -> None:
|
||||
"""查询指定人物的档案和相关知识。"""
|
||||
async def _handle_query_person_info(self, tool_call: ToolCall) -> ToolExecutionResult:
|
||||
"""查询指定人物的档案和相关知识。
|
||||
|
||||
Args:
|
||||
tool_call: 当前工具调用。
|
||||
|
||||
Returns:
|
||||
ToolExecutionResult: 统一工具执行结果。
|
||||
"""
|
||||
|
||||
tool_args = tool_call.args or {}
|
||||
raw_person_name = tool_args.get("person_name")
|
||||
raw_limit = tool_args.get("limit", 3)
|
||||
|
||||
if not isinstance(raw_person_name, str):
|
||||
self._runtime._chat_history.append(
|
||||
self._build_tool_message(tool_call, "查询人物信息工具需要提供字符串类型的 `person_name` 参数。")
|
||||
return self._build_tool_failure_result(
|
||||
tool_call.func_name,
|
||||
"查询人物信息工具需要提供字符串类型的 `person_name` 参数。",
|
||||
)
|
||||
return
|
||||
|
||||
person_name = raw_person_name.strip()
|
||||
if not person_name:
|
||||
self._runtime._chat_history.append(
|
||||
self._build_tool_message(tool_call, "查询人物信息工具需要提供非空的 `person_name` 参数。")
|
||||
return self._build_tool_failure_result(
|
||||
tool_call.func_name,
|
||||
"查询人物信息工具需要提供非空的 `person_name` 参数。",
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
limit = max(1, min(int(raw_limit), 10))
|
||||
@@ -526,11 +756,10 @@ class MaisakaReasoningEngine:
|
||||
f"{self._runtime.log_prefix} 人物信息查询完成: "
|
||||
f"人物记录数={len(result['persons'])} 相关知识数={len(result['related_knowledge'])}"
|
||||
)
|
||||
self._runtime._chat_history.append(
|
||||
self._build_tool_message(
|
||||
tool_call,
|
||||
json.dumps(result, ensure_ascii=False),
|
||||
)
|
||||
return self._build_tool_success_result(
|
||||
tool_call.func_name,
|
||||
json.dumps(result, ensure_ascii=False),
|
||||
structured_content=result,
|
||||
)
|
||||
|
||||
def _query_person_records(self, person_name: str, limit: int) -> list[dict[str, Any]]:
|
||||
@@ -632,25 +861,34 @@ class MaisakaReasoningEngine:
|
||||
self,
|
||||
tool_call: ToolCall,
|
||||
latest_thought: str,
|
||||
anchor_message: SessionMessage,
|
||||
) -> bool:
|
||||
) -> ToolExecutionResult:
|
||||
"""执行 reply 工具并生成可见回复。
|
||||
|
||||
Args:
|
||||
tool_call: 当前工具调用。
|
||||
latest_thought: 当前轮的最新思考文本。
|
||||
|
||||
Returns:
|
||||
ToolExecutionResult: 统一工具执行结果。
|
||||
"""
|
||||
|
||||
tool_args = tool_call.args or {}
|
||||
target_message_id = str(tool_args.get("msg_id") or "").strip()
|
||||
quote_reply = bool(tool_args.get("quote", True))
|
||||
raw_unknown_words = tool_args.get("unknown_words")
|
||||
unknown_words = raw_unknown_words if isinstance(raw_unknown_words, list) else None
|
||||
if not target_message_id:
|
||||
self._runtime._chat_history.append(
|
||||
self._build_tool_message(tool_call, "回复工具需要提供有效的 `msg_id` 参数。")
|
||||
return self._build_tool_failure_result(
|
||||
tool_call.func_name,
|
||||
"回复工具需要提供有效的 `msg_id` 参数。",
|
||||
)
|
||||
return False
|
||||
|
||||
target_message = self._runtime._source_messages_by_id.get(target_message_id)
|
||||
if target_message is None:
|
||||
self._runtime._chat_history.append(
|
||||
self._build_tool_message(tool_call, f"未找到要回复的目标消息,msg_id={target_message_id}")
|
||||
return self._build_tool_failure_result(
|
||||
tool_call.func_name,
|
||||
f"未找到要回复的目标消息,msg_id={target_message_id}",
|
||||
)
|
||||
return False
|
||||
|
||||
logger.info(
|
||||
f"{self._runtime.log_prefix} 已触发回复工具: "
|
||||
@@ -668,17 +906,17 @@ class MaisakaReasoningEngine:
|
||||
f"{self._runtime.log_prefix} 获取回复生成器时发生异常: "
|
||||
f"目标消息编号={target_message_id}"
|
||||
)
|
||||
self._runtime._chat_history.append(
|
||||
self._build_tool_message(tool_call, "获取 Maisaka 回复生成器时发生异常。")
|
||||
return self._build_tool_failure_result(
|
||||
tool_call.func_name,
|
||||
"获取 Maisaka 回复生成器时发生异常。",
|
||||
)
|
||||
return False
|
||||
|
||||
if replyer is None:
|
||||
logger.error(f"{self._runtime.log_prefix} 获取 Maisaka 回复生成器失败")
|
||||
self._runtime._chat_history.append(
|
||||
self._build_tool_message(tool_call, "Maisaka 回复生成器当前不可用。")
|
||||
return self._build_tool_failure_result(
|
||||
tool_call.func_name,
|
||||
"Maisaka 回复生成器当前不可用。",
|
||||
)
|
||||
return False
|
||||
|
||||
from src.chat.replyer.maisaka_generator import MaisakaReplyGenerator
|
||||
|
||||
@@ -701,10 +939,10 @@ class MaisakaReasoningEngine:
|
||||
f"{self._runtime.log_prefix} 回复生成器执行异常: 目标消息编号={target_message_id} "
|
||||
f"异常类型={type(exc).__name__} 异常信息={str(exc)}\n{traceback.format_exc()}"
|
||||
)
|
||||
self._runtime._chat_history.append(
|
||||
self._build_tool_message(tool_call, "生成可见回复时发生异常。")
|
||||
return self._build_tool_failure_result(
|
||||
tool_call.func_name,
|
||||
"生成可见回复时发生异常。",
|
||||
)
|
||||
return False
|
||||
|
||||
logger.info(
|
||||
f"{self._runtime.log_prefix} 回复生成完成: "
|
||||
@@ -717,10 +955,10 @@ class MaisakaReasoningEngine:
|
||||
f"{self._runtime.log_prefix} 回复生成器返回空文本: "
|
||||
f"目标消息编号={target_message_id} 错误信息={reply_result.error_message!r}"
|
||||
)
|
||||
self._runtime._chat_history.append(
|
||||
self._build_tool_message(tool_call, "生成可见回复失败。")
|
||||
return self._build_tool_failure_result(
|
||||
tool_call.func_name,
|
||||
"生成可见回复失败。",
|
||||
)
|
||||
return False
|
||||
|
||||
reply_segments = self._post_process_reply_text(reply_text)
|
||||
combined_reply_text = "".join(reply_segments)
|
||||
@@ -751,19 +989,25 @@ class MaisakaReasoningEngine:
|
||||
logger.exception(
|
||||
f"{self._runtime.log_prefix} 发送文字消息时发生异常,目标消息编号={target_message_id}"
|
||||
)
|
||||
self._runtime._chat_history.append(
|
||||
self._build_tool_message(tool_call, "发送可见回复时发生异常。")
|
||||
return self._build_tool_failure_result(
|
||||
tool_call.func_name,
|
||||
"发送可见回复时发生异常。",
|
||||
)
|
||||
return False
|
||||
|
||||
logger.info(
|
||||
f"{self._runtime.log_prefix} 引导回复发送结果: "
|
||||
f"目标消息编号={target_message_id} 发送成功={sent}"
|
||||
)
|
||||
tool_result = "可见回复已生成并发送。" if sent else "可见回复生成成功,但发送失败。"
|
||||
self._runtime._chat_history.append(self._build_tool_message(tool_call, tool_result))
|
||||
if not sent:
|
||||
return False
|
||||
return self._build_tool_failure_result(
|
||||
tool_call.func_name,
|
||||
"可见回复生成成功,但发送失败。",
|
||||
structured_content={
|
||||
"msg_id": target_message_id,
|
||||
"quote": quote_reply,
|
||||
"reply_segments": reply_segments,
|
||||
},
|
||||
)
|
||||
|
||||
target_user_info = target_message.message_info.user_info
|
||||
target_user_name = (
|
||||
@@ -807,14 +1051,26 @@ class MaisakaReasoningEngine:
|
||||
)
|
||||
history_message.visible_text = visible_reply_text
|
||||
self._runtime._chat_history.append(history_message)
|
||||
return True
|
||||
return self._build_tool_success_result(
|
||||
tool_call.func_name,
|
||||
"可见回复已生成并发送。",
|
||||
structured_content={
|
||||
"msg_id": target_message_id,
|
||||
"quote": quote_reply,
|
||||
"reply_text": combined_reply_text,
|
||||
"reply_segments": reply_segments,
|
||||
"target_user_name": target_user_name,
|
||||
},
|
||||
)
|
||||
|
||||
async def _handle_send_emoji(self, tool_call: ToolCall, anchor_message: SessionMessage) -> None:
|
||||
async def _handle_send_emoji(self, tool_call: ToolCall) -> ToolExecutionResult:
|
||||
"""处理发送表情包的工具调用。
|
||||
|
||||
Args:
|
||||
tool_call: 工具调用对象
|
||||
anchor_message: 锚点消息
|
||||
tool_call: 工具调用对象。
|
||||
|
||||
Returns:
|
||||
ToolExecutionResult: 统一工具执行结果。
|
||||
"""
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||
from src.common.utils.utils_image import ImageUtils
|
||||
@@ -827,10 +1083,10 @@ class MaisakaReasoningEngine:
|
||||
|
||||
# 获取表情包列表
|
||||
if not emoji_manager.emojis:
|
||||
self._runtime._chat_history.append(
|
||||
self._build_tool_message(tool_call, "当前表情包库中没有可用表情。")
|
||||
return self._build_tool_failure_result(
|
||||
tool_call.func_name,
|
||||
"当前表情包库中没有可用表情。",
|
||||
)
|
||||
return
|
||||
|
||||
# 根据情感选择表情包
|
||||
selected_emoji = None
|
||||
@@ -867,10 +1123,10 @@ class MaisakaReasoningEngine:
|
||||
logger.error(
|
||||
f"{self._runtime.log_prefix} 表情图片转换为 base64 失败: {exc}"
|
||||
)
|
||||
self._runtime._chat_history.append(
|
||||
self._build_tool_message(tool_call, f"发送表情包失败:{exc}")
|
||||
return self._build_tool_failure_result(
|
||||
tool_call.func_name,
|
||||
f"发送表情包失败:{exc}",
|
||||
)
|
||||
return
|
||||
|
||||
# 发送表情包
|
||||
try:
|
||||
@@ -885,32 +1141,26 @@ class MaisakaReasoningEngine:
|
||||
logger.exception(
|
||||
f"{self._runtime.log_prefix} 发送表情包时发生异常: {exc}"
|
||||
)
|
||||
self._runtime._chat_history.append(
|
||||
self._build_tool_message(tool_call, f"发送表情包时发生异常:{exc}")
|
||||
return self._build_tool_failure_result(
|
||||
tool_call.func_name,
|
||||
f"发送表情包时发生异常:{exc}",
|
||||
)
|
||||
return
|
||||
|
||||
if sent:
|
||||
logger.info(
|
||||
f"{self._runtime.log_prefix} 表情包发送成功: "
|
||||
f"描述={selected_emoji.description!r} 情绪标签={selected_emoji.emotion}"
|
||||
)
|
||||
self._runtime._chat_history.append(
|
||||
self._build_tool_message(
|
||||
tool_call,
|
||||
f"已发送表情包:{selected_emoji.description}(情绪:{', '.join(selected_emoji.emotion)})"
|
||||
)
|
||||
return self._build_tool_success_result(
|
||||
tool_call.func_name,
|
||||
f"已发送表情包:{selected_emoji.description}(情绪:{', '.join(selected_emoji.emotion)})",
|
||||
structured_content={
|
||||
"description": selected_emoji.description,
|
||||
"emotion": list(selected_emoji.emotion),
|
||||
},
|
||||
)
|
||||
else:
|
||||
logger.warning(f"{self._runtime.log_prefix} 表情包发送失败")
|
||||
self._runtime._chat_history.append(
|
||||
self._build_tool_message(tool_call, "发送表情包失败。")
|
||||
)
|
||||
|
||||
def _build_tool_message(self, tool_call: ToolCall, content: str) -> ToolResultMessage:
|
||||
return ToolResultMessage(
|
||||
content=content,
|
||||
timestamp=datetime.now(),
|
||||
tool_call_id=tool_call.call_id,
|
||||
tool_name=tool_call.func_name,
|
||||
logger.warning(f"{self._runtime.log_prefix} 表情包发送失败")
|
||||
return self._build_tool_failure_result(
|
||||
tool_call.func_name,
|
||||
"发送表情包失败。",
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Maisaka runtime for non-CLI integrations."""
|
||||
"""Maisaka 非 CLI 运行时。"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional, cast
|
||||
from typing import Literal, Optional
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
@@ -13,21 +13,24 @@ from src.common.data_models.mai_message_data_model import GroupInfo, UserInfo
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_config import ExpressionConfigUtils
|
||||
from src.config.config import global_config
|
||||
from src.core.tooling import ToolRegistry
|
||||
from src.know_u.knowledge import KnowledgeLearner
|
||||
from src.learners.expression_learner import ExpressionLearner
|
||||
from src.learners.jargon_miner import JargonMiner
|
||||
from src.llm_models.payload_content.tool_option import ToolDefinitionInput
|
||||
from src.mcp_module import MCPManager
|
||||
from src.mcp_module.provider import MCPToolProvider
|
||||
from src.plugin_runtime.tool_provider import PluginToolProvider
|
||||
|
||||
from .chat_loop_service import MaisakaChatLoopService
|
||||
from .context_messages import LLMContextMessage
|
||||
from .reasoning_engine import MaisakaReasoningEngine
|
||||
from .tool_provider import MaisakaBuiltinToolProvider
|
||||
|
||||
logger = get_logger("maisaka_runtime")
|
||||
|
||||
|
||||
class MaisakaHeartFlowChatting:
|
||||
"""Session-scoped Maisaka runtime."""
|
||||
"""会话级别的 Maisaka 运行时。"""
|
||||
|
||||
_STATE_RUNNING: Literal["running"] = "running"
|
||||
_STATE_WAIT: Literal["wait"] = "wait"
|
||||
@@ -79,9 +82,11 @@ class MaisakaHeartFlowChatting:
|
||||
self._knowledge_learner = KnowledgeLearner(session_id)
|
||||
|
||||
self._reasoning_engine = MaisakaReasoningEngine(self)
|
||||
self._tool_registry = ToolRegistry()
|
||||
self._register_tool_providers()
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the runtime loop."""
|
||||
"""启动运行时主循环。"""
|
||||
if self._running:
|
||||
self._ensure_background_tasks_running()
|
||||
return
|
||||
@@ -94,7 +99,7 @@ class MaisakaHeartFlowChatting:
|
||||
logger.info(f"{self.log_prefix} Maisaka 运行时已启动")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the runtime loop."""
|
||||
"""停止运行时主循环。"""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
@@ -121,18 +126,17 @@ class MaisakaHeartFlowChatting:
|
||||
finally:
|
||||
self._internal_loop_task = None
|
||||
|
||||
if self._mcp_manager is not None:
|
||||
await self._mcp_manager.close()
|
||||
self._mcp_manager = None
|
||||
await self._tool_registry.close()
|
||||
self._mcp_manager = None
|
||||
|
||||
logger.info(f"{self.log_prefix} Maisaka 运行时已停止")
|
||||
|
||||
def adjust_talk_frequency(self, frequency: float) -> None:
|
||||
"""Compatibility shim for the existing manager API."""
|
||||
"""兼容现有管理器接口的占位方法。"""
|
||||
_ = frequency
|
||||
|
||||
async def register_message(self, message: SessionMessage) -> None:
|
||||
"""Cache a new message and wake the main loop."""
|
||||
"""缓存一条新消息并唤醒主循环。"""
|
||||
if self._running:
|
||||
self._ensure_background_tasks_running()
|
||||
self.message_cache.append(message)
|
||||
@@ -175,6 +179,15 @@ class MaisakaHeartFlowChatting:
|
||||
self._loop_task = asyncio.create_task(self._main_loop())
|
||||
logger.warning(f"{self.log_prefix} 已重新拉起 Maisaka 主循环任务")
|
||||
|
||||
def _register_tool_providers(self) -> None:
|
||||
"""注册 Maisaka 运行时默认启用的工具 Provider。"""
|
||||
|
||||
self._tool_registry.register_provider(
|
||||
MaisakaBuiltinToolProvider(self._reasoning_engine.build_builtin_tool_handlers())
|
||||
)
|
||||
self._tool_registry.register_provider(PluginToolProvider())
|
||||
self._chat_loop_service.set_tool_registry(self._tool_registry)
|
||||
|
||||
async def _main_loop(self) -> None:
|
||||
try:
|
||||
while self._running:
|
||||
@@ -215,7 +228,7 @@ class MaisakaHeartFlowChatting:
|
||||
return self._last_processed_index < len(self.message_cache)
|
||||
|
||||
def _collect_pending_messages(self) -> list[SessionMessage]:
|
||||
"""Collect one batch of unprocessed messages from message_cache."""
|
||||
"""从消息缓存中收集一批尚未处理的消息。"""
|
||||
start_index = self._last_processed_index
|
||||
pending_messages = self.message_cache[start_index:]
|
||||
if not pending_messages:
|
||||
@@ -264,13 +277,13 @@ class MaisakaHeartFlowChatting:
|
||||
return "timeout"
|
||||
|
||||
def _enter_wait_state(self, seconds: Optional[float] = None, tool_call_id: Optional[str] = None) -> None:
|
||||
"""Enter wait state."""
|
||||
"""切换到等待状态。"""
|
||||
self._agent_state = self._STATE_WAIT
|
||||
self._wait_until = None if seconds is None else time.time() + seconds
|
||||
self._pending_wait_tool_call_id = tool_call_id
|
||||
|
||||
def _enter_stop_state(self) -> None:
|
||||
"""Enter stop state."""
|
||||
"""切换到停止状态。"""
|
||||
self._agent_state = self._STATE_STOP
|
||||
self._wait_until = None
|
||||
self._pending_wait_tool_call_id = None
|
||||
@@ -288,7 +301,7 @@ class MaisakaHeartFlowChatting:
|
||||
logger.error(f"{self.log_prefix} 知识学习任务异常退出: {knowledge_result}")
|
||||
|
||||
async def _trigger_expression_learning(self, messages: list[SessionMessage]) -> None:
|
||||
"""Trigger expression learning from the newly collected batch."""
|
||||
"""基于新收集的一批消息触发表达学习。"""
|
||||
self._expression_learner.add_messages(messages)
|
||||
|
||||
if not self._enable_expression_learning:
|
||||
@@ -331,7 +344,7 @@ class MaisakaHeartFlowChatting:
|
||||
logger.exception(f"{self.log_prefix} 表达学习失败")
|
||||
|
||||
async def _trigger_knowledge_learning(self, messages: list[SessionMessage]) -> None:
|
||||
"""Trigger knowledge learning from the newly collected batch."""
|
||||
"""基于新收集的一批消息触发知识学习。"""
|
||||
self._knowledge_learner.add_messages(messages)
|
||||
|
||||
if not global_config.maisaka.enable_knowledge_module:
|
||||
@@ -372,22 +385,21 @@ class MaisakaHeartFlowChatting:
|
||||
logger.exception(f"{self.log_prefix} 知识学习失败")
|
||||
|
||||
async def _init_mcp(self) -> None:
|
||||
"""Initialize MCP tools and inject them into the planner."""
|
||||
"""初始化 MCP 工具并注册到统一工具层。"""
|
||||
config_path = Path(__file__).resolve().parents[2] / "config" / "mcp_config.json"
|
||||
self._mcp_manager = await MCPManager.from_config(str(config_path))
|
||||
if self._mcp_manager is None:
|
||||
logger.info(f"{self.log_prefix} MCP 管理器不可用")
|
||||
return
|
||||
|
||||
mcp_tools = self._mcp_manager.get_openai_tools()
|
||||
if not mcp_tools:
|
||||
mcp_tool_specs = self._mcp_manager.get_tool_specs()
|
||||
if not mcp_tool_specs:
|
||||
logger.info(f"{self.log_prefix} 没有可供 Maisaka 使用的 MCP 工具")
|
||||
return
|
||||
|
||||
mcp_tool_definitions = [cast(ToolDefinitionInput, tool) for tool in mcp_tools]
|
||||
self._chat_loop_service.set_extra_tools(mcp_tool_definitions)
|
||||
self._tool_registry.register_provider(MCPToolProvider(self._mcp_manager))
|
||||
logger.info(
|
||||
f"{self.log_prefix} 已向 Maisaka 加载 {len(mcp_tools)} 个 MCP 工具:\n"
|
||||
f"{self.log_prefix} 已向 Maisaka 加载 {len(mcp_tool_specs)} 个 MCP 工具:\n"
|
||||
f"{self._mcp_manager.get_tool_summary()}"
|
||||
)
|
||||
|
||||
|
||||
64
src/maisaka/tool_provider.py
Normal file
64
src/maisaka/tool_provider.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Maisaka 内置工具 Provider。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Dict, Optional
|
||||
|
||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolProvider, ToolSpec
|
||||
|
||||
from .builtin_tools import get_builtin_tool_specs
|
||||
|
||||
BuiltinToolHandler = Callable[[ToolInvocation, Optional[ToolExecutionContext]], Awaitable[ToolExecutionResult]]
|
||||
|
||||
|
||||
class MaisakaBuiltinToolProvider(ToolProvider):
|
||||
"""Maisaka 内置工具提供者。"""
|
||||
|
||||
provider_name = "maisaka_builtin"
|
||||
provider_type = "builtin"
|
||||
|
||||
def __init__(self, handlers: Optional[Dict[str, BuiltinToolHandler]] = None) -> None:
|
||||
"""初始化内置工具 Provider。
|
||||
|
||||
Args:
|
||||
handlers: 工具名到异步处理器的映射。
|
||||
"""
|
||||
|
||||
self._handlers = dict(handlers or {})
|
||||
|
||||
async def list_tools(self) -> list[ToolSpec]:
|
||||
"""列出全部内置工具。"""
|
||||
|
||||
return list(get_builtin_tool_specs())
|
||||
|
||||
async def invoke(
|
||||
self,
|
||||
invocation: ToolInvocation,
|
||||
context: Optional[ToolExecutionContext] = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""执行指定内置工具。
|
||||
|
||||
Args:
|
||||
invocation: 工具调用请求。
|
||||
context: 执行上下文。
|
||||
|
||||
Returns:
|
||||
ToolExecutionResult: 工具执行结果。
|
||||
"""
|
||||
|
||||
handler = self._handlers.get(invocation.tool_name)
|
||||
if handler is None:
|
||||
return ToolExecutionResult(
|
||||
tool_name=invocation.tool_name,
|
||||
success=False,
|
||||
error_message=f"未找到内置工具处理器:{invocation.tool_name}",
|
||||
)
|
||||
return await handler(invocation, context)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭 Provider。
|
||||
|
||||
内置 Provider 无需释放额外资源。
|
||||
"""
|
||||
|
||||
@@ -6,6 +6,8 @@ MaiSaka - 单个 MCP 服务器连接管理
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import Any, Optional
|
||||
|
||||
from src.core.tooling import ToolExecutionResult
|
||||
|
||||
from src.cli.console import console
|
||||
|
||||
from .config import MCPServerConfig
|
||||
@@ -79,6 +81,8 @@ class MCPConnection:
|
||||
return False
|
||||
|
||||
# 创建并初始化 MCP 会话
|
||||
if ClientSession is None:
|
||||
raise RuntimeError("当前环境未安装可用的 MCP ClientSession")
|
||||
self.session = await self._exit_stack.enter_async_context(ClientSession(read_stream, write_stream))
|
||||
await self.session.initialize()
|
||||
|
||||
@@ -95,6 +99,11 @@ class MCPConnection:
|
||||
|
||||
async def _connect_stdio(self):
|
||||
"""建立 Stdio 传输连接。"""
|
||||
if StdioServerParameters is None or stdio_client is None:
|
||||
raise RuntimeError("当前环境未安装可用的 MCP stdio 客户端")
|
||||
if not self.config.command:
|
||||
raise ValueError(f"MCP 服务器 '{self.config.name}' 缺少 stdio command 配置")
|
||||
|
||||
params = StdioServerParameters(
|
||||
command=self.config.command,
|
||||
args=self.config.args,
|
||||
@@ -106,39 +115,63 @@ class MCPConnection:
|
||||
"""建立 SSE 传输连接。"""
|
||||
if not SSE_AVAILABLE:
|
||||
raise ImportError("SSE 传输需要额外依赖,请运行: pip install mcp[sse]")
|
||||
if sse_client is None:
|
||||
raise RuntimeError("当前环境未安装可用的 MCP SSE 客户端")
|
||||
if not self.config.url:
|
||||
raise ValueError(f"MCP 服务器 '{self.config.name}' 缺少 SSE url 配置")
|
||||
return await self._exit_stack.enter_async_context(sse_client(url=self.config.url, headers=self.config.headers))
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: dict) -> str:
|
||||
"""
|
||||
调用 MCP 工具并返回结果文本。
|
||||
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> ToolExecutionResult:
|
||||
"""调用 MCP 工具并返回统一执行结果。
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
arguments: 工具参数字典
|
||||
tool_name: 工具名称。
|
||||
arguments: 工具参数字典。
|
||||
|
||||
Returns:
|
||||
工具执行结果的文本表示。
|
||||
ToolExecutionResult: 统一执行结果。
|
||||
"""
|
||||
|
||||
if not self.session:
|
||||
return f"MCP 服务器 '{self.config.name}' 未连接"
|
||||
return ToolExecutionResult(
|
||||
tool_name=tool_name,
|
||||
success=False,
|
||||
error_message=f"MCP 服务器 '{self.config.name}' 未连接",
|
||||
)
|
||||
|
||||
result = await self.session.call_tool(tool_name, arguments=arguments)
|
||||
try:
|
||||
result = await self.session.call_tool(tool_name, arguments=arguments)
|
||||
except Exception as exc:
|
||||
return ToolExecutionResult(
|
||||
tool_name=tool_name,
|
||||
success=False,
|
||||
error_message=f"MCP 工具 '{tool_name}' 执行失败: {exc}",
|
||||
metadata={"server_name": self.config.name},
|
||||
)
|
||||
|
||||
# 将结果内容转换为文本
|
||||
parts: list[str] = []
|
||||
text_parts: list[str] = []
|
||||
binary_parts: list[dict[str, Any]] = []
|
||||
for content in result.content:
|
||||
if hasattr(content, "text"):
|
||||
parts.append(content.text)
|
||||
text_parts.append(str(content.text))
|
||||
elif hasattr(content, "data"):
|
||||
# 二进制/图片内容,展示类型信息
|
||||
content_type = getattr(content, "mimeType", "unknown")
|
||||
parts.append(f"[{content_type} 二进制内容]")
|
||||
binary_parts.append({"mime_type": content_type, "type": "binary"})
|
||||
text_parts.append(f"[{content_type} 二进制内容]")
|
||||
elif hasattr(content, "type"):
|
||||
parts.append(f"[{content.type} 内容]")
|
||||
text_parts.append(f"[{content.type} 内容]")
|
||||
|
||||
return "\n".join(parts) if parts else "工具执行成功(无输出)"
|
||||
return ToolExecutionResult(
|
||||
tool_name=tool_name,
|
||||
success=True,
|
||||
content="\n".join(text_parts) if text_parts else "工具执行成功(无输出)",
|
||||
metadata={
|
||||
"server_name": self.config.name,
|
||||
"binary_parts": binary_parts,
|
||||
},
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
async def close(self) -> None:
|
||||
"""关闭连接并释放资源。"""
|
||||
try:
|
||||
await self._exit_stack.aclose()
|
||||
|
||||
@@ -3,9 +3,15 @@ MaiSaka - MCP 管理器
|
||||
管理所有 MCP 服务器连接,提供统一的工具发现与调用接口。
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from src.cli.console import console
|
||||
from src.core.tooling import (
|
||||
ToolExecutionResult,
|
||||
ToolInvocation,
|
||||
ToolSpec,
|
||||
build_tool_detailed_description,
|
||||
)
|
||||
|
||||
from .config import DEFAULT_MCP_CONFIG_PATH, MCPServerConfig, load_mcp_config
|
||||
from .connection import MCPConnection, MCP_AVAILABLE
|
||||
@@ -74,7 +80,7 @@ class MCPManager:
|
||||
|
||||
# ──────── 连接管理 ────────
|
||||
|
||||
async def _connect_all(self, configs: list[MCPServerConfig]):
|
||||
async def _connect_all(self, configs: list[MCPServerConfig]) -> None:
|
||||
"""连接所有配置的 MCP 服务器,跳过失败的连接。"""
|
||||
for cfg in configs:
|
||||
conn = MCPConnection(cfg)
|
||||
@@ -112,44 +118,75 @@ class MCPManager:
|
||||
|
||||
# ──────── 工具发现 ────────
|
||||
|
||||
def get_openai_tools(self) -> list[dict]:
|
||||
"""
|
||||
将所有已注册的 MCP 工具转换为 OpenAI function calling 格式。
|
||||
def _build_tool_parameters_schema(self, tool: Any) -> dict[str, Any] | None:
|
||||
"""构造单个 MCP 工具的对象级参数 Schema。
|
||||
|
||||
Args:
|
||||
tool: MCP SDK 返回的原始工具对象。
|
||||
|
||||
Returns:
|
||||
OpenAI tools 格式的工具定义列表。
|
||||
dict[str, Any] | None: 参数 Schema。
|
||||
"""
|
||||
tools: list[dict] = []
|
||||
|
||||
parameters_schema = (
|
||||
dict(tool.inputSchema)
|
||||
if hasattr(tool, "inputSchema") and tool.inputSchema
|
||||
else {"type": "object", "properties": {}}
|
||||
)
|
||||
parameters_schema.pop("$schema", None)
|
||||
return parameters_schema
|
||||
|
||||
def get_tool_specs(self) -> list[ToolSpec]:
|
||||
"""获取全部已注册 MCP 工具的统一声明。
|
||||
|
||||
Returns:
|
||||
list[ToolSpec]: MCP 工具声明列表。
|
||||
"""
|
||||
|
||||
tool_specs: list[ToolSpec] = []
|
||||
for server_name, conn in self._connections.items():
|
||||
for tool in conn.tools:
|
||||
# 只包含成功注册的工具
|
||||
if tool.name not in self._tool_to_server:
|
||||
continue
|
||||
if self._tool_to_server[tool.name] != server_name:
|
||||
continue
|
||||
|
||||
# MCP inputSchema → OpenAI parameters
|
||||
parameters = (
|
||||
dict(tool.inputSchema)
|
||||
if hasattr(tool, "inputSchema") and tool.inputSchema
|
||||
else {"type": "object", "properties": {}}
|
||||
parameters_schema = self._build_tool_parameters_schema(tool)
|
||||
brief_description = str(tool.description or f"来自 {server_name} 的 MCP 工具").strip()
|
||||
tool_specs.append(
|
||||
ToolSpec(
|
||||
name=str(tool.name),
|
||||
brief_description=brief_description,
|
||||
detailed_description=build_tool_detailed_description(
|
||||
parameters_schema,
|
||||
fallback_description=f"工具来源:MCP 服务 {server_name}。",
|
||||
),
|
||||
parameters_schema=parameters_schema,
|
||||
provider_name="mcp",
|
||||
provider_type="mcp",
|
||||
metadata={"server_name": server_name},
|
||||
)
|
||||
)
|
||||
# 移除 $schema 字段(部分 MCP 服务器会带上,OpenAI 不接受)
|
||||
parameters.pop("$schema", None)
|
||||
return tool_specs
|
||||
|
||||
tools.append(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": (tool.description or f"MCP tool from {server_name}"),
|
||||
"parameters": parameters,
|
||||
},
|
||||
}
|
||||
)
|
||||
def get_openai_tools(self) -> list[dict[str, Any]]:
|
||||
"""获取兼容旧模型层的 MCP 工具定义。
|
||||
|
||||
return tools
|
||||
Returns:
|
||||
list[dict[str, Any]]: OpenAI function tool 格式列表。
|
||||
"""
|
||||
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_spec.name,
|
||||
"description": tool_spec.build_llm_description(),
|
||||
"parameters": tool_spec.parameters_schema or {"type": "object", "properties": {}},
|
||||
},
|
||||
}
|
||||
for tool_spec in self.get_tool_specs()
|
||||
]
|
||||
|
||||
# ──────── 工具调用 ────────
|
||||
|
||||
@@ -157,28 +194,46 @@ class MCPManager:
|
||||
"""判断工具名是否为已注册的 MCP 工具。"""
|
||||
return tool_name in self._tool_to_server
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: dict) -> str:
|
||||
"""
|
||||
调用指定的 MCP 工具。
|
||||
|
||||
自动路由到正确的 MCP 服务器。
|
||||
async def call_tool_invocation(self, invocation: ToolInvocation) -> ToolExecutionResult:
|
||||
"""执行统一的 MCP 工具调用。
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
arguments: 工具参数
|
||||
invocation: 统一工具调用请求。
|
||||
|
||||
Returns:
|
||||
工具执行结果文本。
|
||||
ToolExecutionResult: 统一工具执行结果。
|
||||
"""
|
||||
|
||||
tool_name = invocation.tool_name
|
||||
server_name = self._tool_to_server.get(tool_name)
|
||||
if not server_name or server_name not in self._connections:
|
||||
return f"MCP 工具 '{tool_name}' 未找到"
|
||||
return ToolExecutionResult(
|
||||
tool_name=tool_name,
|
||||
success=False,
|
||||
error_message=f"MCP 工具 '{tool_name}' 未找到",
|
||||
)
|
||||
|
||||
conn = self._connections[server_name]
|
||||
try:
|
||||
return await conn.call_tool(tool_name, arguments)
|
||||
except Exception as e:
|
||||
return f"MCP 工具 '{tool_name}' 执行失败: {e}"
|
||||
return await conn.call_tool(tool_name, invocation.arguments)
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> str:
|
||||
"""兼容旧接口,返回 MCP 工具的文本结果。
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称。
|
||||
arguments: 工具参数。
|
||||
|
||||
Returns:
|
||||
str: 工具结果文本。
|
||||
"""
|
||||
|
||||
result = await self.call_tool_invocation(
|
||||
ToolInvocation(
|
||||
tool_name=tool_name,
|
||||
arguments=arguments,
|
||||
)
|
||||
)
|
||||
return result.get_history_content()
|
||||
|
||||
# ──────── 信息展示 ────────
|
||||
|
||||
@@ -207,7 +262,7 @@ class MCPManager:
|
||||
|
||||
# ──────── 生命周期 ────────
|
||||
|
||||
async def close(self):
|
||||
async def close(self) -> None:
|
||||
"""关闭所有 MCP 服务器连接。"""
|
||||
for conn in self._connections.values():
|
||||
await conn.close()
|
||||
|
||||
54
src/mcp_module/provider.py
Normal file
54
src/mcp_module/provider.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""MCP 工具 Provider。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolProvider, ToolSpec
|
||||
|
||||
from .manager import MCPManager
|
||||
|
||||
|
||||
class MCPToolProvider(ToolProvider):
|
||||
"""基于 MCPManager 的工具 Provider。"""
|
||||
|
||||
provider_name = "mcp"
|
||||
provider_type = "mcp"
|
||||
|
||||
def __init__(self, manager: MCPManager) -> None:
|
||||
"""初始化 MCP 工具 Provider。
|
||||
|
||||
Args:
|
||||
manager: MCP 管理器实例。
|
||||
"""
|
||||
|
||||
self._manager = manager
|
||||
|
||||
async def list_tools(self) -> list[ToolSpec]:
|
||||
"""列出全部 MCP 工具。"""
|
||||
|
||||
return self._manager.get_tool_specs()
|
||||
|
||||
async def invoke(
|
||||
self,
|
||||
invocation: ToolInvocation,
|
||||
context: Optional[ToolExecutionContext] = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""执行指定 MCP 工具。
|
||||
|
||||
Args:
|
||||
invocation: 工具调用请求。
|
||||
context: 执行上下文。
|
||||
|
||||
Returns:
|
||||
ToolExecutionResult: 工具执行结果。
|
||||
"""
|
||||
|
||||
del context
|
||||
return await self._manager.call_tool_invocation(invocation)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭 Provider 并释放 MCP 连接。"""
|
||||
|
||||
await self._manager.close()
|
||||
|
||||
@@ -29,6 +29,19 @@ class _RuntimeComponentManagerProtocol(Protocol):
|
||||
|
||||
def _build_api_unavailable_error(self, entry: "APIEntry") -> str: ...
|
||||
|
||||
def _collect_api_reference_matches(
|
||||
self,
|
||||
caller_plugin_id: str,
|
||||
normalized_api_name: str,
|
||||
normalized_version: str,
|
||||
) -> tuple[List[tuple["PluginSupervisor", "APIEntry"]], List[tuple["PluginSupervisor", "APIEntry"]], bool]: ...
|
||||
|
||||
def _collect_api_toggle_reference_matches(
|
||||
self,
|
||||
normalized_name: str,
|
||||
normalized_version: str,
|
||||
) -> List[tuple["PluginSupervisor", "APIEntry"]]: ...
|
||||
|
||||
def _get_supervisor_for_plugin(self, plugin_id: str) -> Optional["PluginSupervisor"]: ...
|
||||
|
||||
def _resolve_api_target(
|
||||
@@ -136,7 +149,10 @@ class RuntimeComponentCapabilityMixin:
|
||||
str: 统一转为大写后的组件类型名。
|
||||
"""
|
||||
|
||||
return str(component_type or "").strip().upper()
|
||||
normalized_component_type = str(component_type or "").strip().upper()
|
||||
if normalized_component_type == "ACTION":
|
||||
return "TOOL"
|
||||
return normalized_component_type
|
||||
|
||||
@classmethod
|
||||
def _is_api_component_type(cls, component_type: str) -> bool:
|
||||
|
||||
@@ -1,14 +1,21 @@
|
||||
"""插件运行时统一组件查询服务。
|
||||
|
||||
该模块统一从插件运行时的 Host ComponentRegistry 中聚合只读视图,
|
||||
供 HFC/PFC、Planner、ToolExecutor 和运行时能力层查询与调用。
|
||||
供 HFC、ToolExecutor 和运行时能力层查询与调用。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tuple, cast
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.core.tooling import (
|
||||
ToolExecutionContext,
|
||||
ToolExecutionResult,
|
||||
ToolInvocation,
|
||||
ToolSpec,
|
||||
build_tool_detailed_description,
|
||||
)
|
||||
from src.core.types import ActionActivationType, ActionInfo, CommandInfo, ComponentInfo, ComponentType, ToolInfo
|
||||
from src.llm_models.payload_content.tool_option import normalize_tool_option
|
||||
|
||||
@@ -240,12 +247,38 @@ class ComponentQueryService:
|
||||
|
||||
return ToolInfo(
|
||||
name=entry.name,
|
||||
description=entry.description,
|
||||
description=entry.brief_description or entry.description,
|
||||
enabled=bool(entry.enabled),
|
||||
plugin_name=entry.plugin_id,
|
||||
parameters_schema=ComponentQueryService._build_tool_parameters_schema(entry),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_tool_spec(entry: "ToolEntry") -> ToolSpec:
|
||||
"""将运行时 Tool 条目转换为统一工具声明。
|
||||
|
||||
Args:
|
||||
entry: 插件运行时中的 Tool 条目。
|
||||
|
||||
Returns:
|
||||
ToolSpec: 统一工具声明。
|
||||
"""
|
||||
|
||||
parameters_schema = ComponentQueryService._build_tool_parameters_schema(entry)
|
||||
return ToolSpec(
|
||||
name=entry.name,
|
||||
brief_description=entry.brief_description or entry.description or f"工具 {entry.name}",
|
||||
detailed_description=entry.detailed_description or build_tool_detailed_description(parameters_schema),
|
||||
parameters_schema=parameters_schema,
|
||||
provider_name=entry.plugin_id,
|
||||
provider_type="plugin",
|
||||
metadata={
|
||||
"plugin_id": entry.plugin_id,
|
||||
"invoke_method": entry.invoke_method,
|
||||
"legacy_component_type": entry.legacy_component_type,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _log_duplicate_component(component_type: ComponentType, component_name: str) -> None:
|
||||
"""记录重复组件名称冲突。
|
||||
@@ -475,7 +508,12 @@ class ComponentQueryService:
|
||||
return _executor
|
||||
|
||||
@staticmethod
|
||||
def _build_tool_executor(supervisor: "PluginSupervisor", plugin_id: str, component_name: str) -> ToolExecutor:
|
||||
def _build_tool_executor(
|
||||
supervisor: "PluginSupervisor",
|
||||
plugin_id: str,
|
||||
component_name: str,
|
||||
invoke_method: str = "plugin.invoke_tool",
|
||||
) -> ToolExecutor:
|
||||
"""构造工具执行 RPC 闭包。
|
||||
|
||||
Args:
|
||||
@@ -499,7 +537,7 @@ class ComponentQueryService:
|
||||
|
||||
try:
|
||||
response = await supervisor.invoke_plugin(
|
||||
method="plugin.invoke_tool",
|
||||
method=invoke_method,
|
||||
plugin_id=plugin_id,
|
||||
component_name=component_name,
|
||||
args=function_args,
|
||||
@@ -615,7 +653,162 @@ class ComponentQueryService:
|
||||
if matched_entry is None:
|
||||
return None
|
||||
supervisor, entry = matched_entry
|
||||
return self._build_tool_executor(supervisor, entry.plugin_id, entry.name)
|
||||
tool_entry = cast("ToolEntry", entry)
|
||||
return self._build_tool_executor(supervisor, tool_entry.plugin_id, tool_entry.name, tool_entry.invoke_method)
|
||||
|
||||
def get_llm_available_tool_specs(self) -> Dict[str, ToolSpec]:
|
||||
"""获取当前可供 LLM 使用的统一工具声明集合。
|
||||
|
||||
Returns:
|
||||
Dict[str, ToolSpec]: 工具名到工具声明的映射。
|
||||
"""
|
||||
|
||||
collected_specs: Dict[str, ToolSpec] = {}
|
||||
for _supervisor, entry in self._iter_component_entries(ComponentType.TOOL):
|
||||
if entry.name in collected_specs:
|
||||
self._log_duplicate_component(ComponentType.TOOL, entry.name)
|
||||
continue
|
||||
collected_specs[entry.name] = self._build_tool_spec(entry) # type: ignore[arg-type]
|
||||
return collected_specs
|
||||
|
||||
@staticmethod
|
||||
def _build_tool_invocation_payload(
|
||||
entry: "ToolEntry",
|
||||
invocation: ToolInvocation,
|
||||
context: Optional[ToolExecutionContext],
|
||||
) -> Dict[str, Any]:
|
||||
"""构造插件工具执行时发送给 Runner 的参数。
|
||||
|
||||
Args:
|
||||
entry: 目标工具条目。
|
||||
invocation: 统一工具调用请求。
|
||||
context: 统一工具执行上下文。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 发往 Runner 的参数字典。
|
||||
"""
|
||||
|
||||
payload = dict(invocation.arguments)
|
||||
if entry.invoke_method == "plugin.invoke_action":
|
||||
stream_id = context.stream_id if context is not None else invocation.stream_id
|
||||
reasoning = context.reasoning if context is not None else invocation.reasoning
|
||||
payload = {
|
||||
**payload,
|
||||
"stream_id": stream_id,
|
||||
"chat_id": stream_id,
|
||||
"reasoning": reasoning,
|
||||
"action_data": dict(invocation.arguments),
|
||||
}
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def _parse_tool_invoke_result(
|
||||
entry: "ToolEntry",
|
||||
result: Any,
|
||||
) -> ToolExecutionResult:
|
||||
"""将插件组件返回值转换为统一工具执行结果。
|
||||
|
||||
Args:
|
||||
entry: 目标工具条目。
|
||||
result: 插件组件原始返回值。
|
||||
|
||||
Returns:
|
||||
ToolExecutionResult: 统一执行结果。
|
||||
"""
|
||||
|
||||
if isinstance(result, dict):
|
||||
success = bool(result.get("success", True))
|
||||
content = str(result.get("content", result.get("message", "")) or "").strip()
|
||||
error_message = ""
|
||||
if not success:
|
||||
error_message = str(result.get("error", result.get("message", "插件工具执行失败")) or "").strip()
|
||||
return ToolExecutionResult(
|
||||
tool_name=entry.name,
|
||||
success=success,
|
||||
content=content,
|
||||
error_message=error_message,
|
||||
structured_content=result,
|
||||
metadata={"plugin_id": entry.plugin_id},
|
||||
)
|
||||
|
||||
if isinstance(result, (list, tuple)) and result:
|
||||
if isinstance(result[0], bool):
|
||||
success = bool(result[0])
|
||||
message = "" if len(result) < 2 or result[1] is None else str(result[1]).strip()
|
||||
return ToolExecutionResult(
|
||||
tool_name=entry.name,
|
||||
success=success,
|
||||
content=message if success else "",
|
||||
error_message="" if success else message,
|
||||
structured_content=list(result),
|
||||
metadata={"plugin_id": entry.plugin_id},
|
||||
)
|
||||
|
||||
normalized_content = "" if result is None else str(result).strip()
|
||||
return ToolExecutionResult(
|
||||
tool_name=entry.name,
|
||||
success=True,
|
||||
content=normalized_content,
|
||||
structured_content=result,
|
||||
metadata={"plugin_id": entry.plugin_id},
|
||||
)
|
||||
|
||||
async def invoke_tool_as_tool(
|
||||
self,
|
||||
invocation: ToolInvocation,
|
||||
context: Optional[ToolExecutionContext] = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""按统一工具语义执行插件工具。
|
||||
|
||||
Args:
|
||||
invocation: 统一工具调用请求。
|
||||
context: 执行上下文。
|
||||
|
||||
Returns:
|
||||
ToolExecutionResult: 统一工具执行结果。
|
||||
"""
|
||||
|
||||
matched_entry = self._get_unique_component_entry(ComponentType.TOOL, invocation.tool_name)
|
||||
if matched_entry is None:
|
||||
return ToolExecutionResult(
|
||||
tool_name=invocation.tool_name,
|
||||
success=False,
|
||||
error_message=f"未找到插件工具:{invocation.tool_name}",
|
||||
)
|
||||
|
||||
supervisor, entry = matched_entry
|
||||
tool_entry = cast("ToolEntry", entry)
|
||||
invoke_payload = self._build_tool_invocation_payload(tool_entry, invocation, context)
|
||||
|
||||
try:
|
||||
response = await supervisor.invoke_plugin(
|
||||
method=tool_entry.invoke_method,
|
||||
plugin_id=tool_entry.plugin_id,
|
||||
component_name=tool_entry.name,
|
||||
args=invoke_payload,
|
||||
timeout_ms=30000,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"运行时工具 {tool_entry.plugin_id}.{tool_entry.name} 执行失败: {exc}", exc_info=True)
|
||||
return ToolExecutionResult(
|
||||
tool_name=tool_entry.name,
|
||||
success=False,
|
||||
error_message=str(exc),
|
||||
metadata={"plugin_id": tool_entry.plugin_id},
|
||||
)
|
||||
|
||||
payload = response.payload if isinstance(response.payload, dict) else {}
|
||||
transport_success = bool(payload.get("success", False))
|
||||
result = payload.get("result")
|
||||
if not transport_success:
|
||||
return ToolExecutionResult(
|
||||
tool_name=tool_entry.name,
|
||||
success=False,
|
||||
error_message="" if result is None else str(result),
|
||||
structured_content=result,
|
||||
metadata={"plugin_id": tool_entry.plugin_id},
|
||||
)
|
||||
return self._parse_tool_invoke_result(tool_entry, result)
|
||||
|
||||
def get_llm_available_tools(self) -> Dict[str, ToolInfo]:
|
||||
"""获取当前可供 LLM 选择的工具集合。
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Host-side ComponentRegistry
|
||||
"""Host 侧组件注册表。
|
||||
|
||||
对齐旧系统 component_registry.py 的核心能力:
|
||||
- 按类型注册组件(action / command / tool / event_handler / hook_handler / message_gateway)
|
||||
@@ -16,6 +16,7 @@ import contextlib
|
||||
import re
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.core.tooling import build_tool_detailed_description
|
||||
|
||||
logger = get_logger("plugin_runtime.host.component_registry")
|
||||
|
||||
@@ -89,11 +90,81 @@ class ToolEntry(ComponentEntry):
|
||||
"""Tool 组件条目"""
|
||||
|
||||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||||
self.description: str = metadata.get("description", "")
|
||||
self.description: str = str(metadata.get("description", "") or "").strip()
|
||||
self.brief_description: str = str(
|
||||
metadata.get("brief_description", self.description) or self.description or f"工具 {name}"
|
||||
).strip()
|
||||
self.parameters: List[Dict[str, Any]] = metadata.get("parameters", [])
|
||||
self.parameters_raw: Dict[str, Any] | List[Dict[str, Any]] = metadata.get("parameters_raw", {})
|
||||
detailed_description = str(metadata.get("detailed_description", "") or "").strip()
|
||||
self.detailed_description: str = detailed_description
|
||||
self.invoke_method: str = str(metadata.get("invoke_method", "plugin.invoke_tool") or "plugin.invoke_tool").strip()
|
||||
self.legacy_component_type: str = str(metadata.get("legacy_component_type", "") or "").strip()
|
||||
super().__init__(name, component_type, plugin_id, metadata)
|
||||
|
||||
if not self.detailed_description:
|
||||
parameters_schema = self._get_parameters_schema()
|
||||
self.detailed_description = build_tool_detailed_description(parameters_schema)
|
||||
|
||||
def _get_parameters_schema(self) -> Dict[str, Any] | None:
|
||||
"""获取当前工具条目的对象级参数 Schema。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any] | None: 归一化后的参数 Schema。
|
||||
"""
|
||||
|
||||
if isinstance(self.parameters_raw, dict) and self.parameters_raw:
|
||||
if self.parameters_raw.get("type") == "object" or "properties" in self.parameters_raw:
|
||||
return dict(self.parameters_raw)
|
||||
|
||||
required_names: List[str] = []
|
||||
normalized_properties: Dict[str, Any] = {}
|
||||
for property_name, property_schema in self.parameters_raw.items():
|
||||
if not isinstance(property_schema, dict):
|
||||
continue
|
||||
property_schema_copy = dict(property_schema)
|
||||
if bool(property_schema_copy.pop("required", False)):
|
||||
required_names.append(str(property_name))
|
||||
normalized_properties[str(property_name)] = property_schema_copy
|
||||
|
||||
schema: Dict[str, Any] = {
|
||||
"type": "object",
|
||||
"properties": normalized_properties,
|
||||
}
|
||||
if required_names:
|
||||
schema["required"] = required_names
|
||||
return schema
|
||||
|
||||
if isinstance(self.parameters, list) and self.parameters:
|
||||
properties: Dict[str, Any] = {}
|
||||
required_names: List[str] = []
|
||||
for parameter in self.parameters:
|
||||
if not isinstance(parameter, dict):
|
||||
continue
|
||||
parameter_name = str(parameter.get("name", "") or "").strip()
|
||||
if not parameter_name:
|
||||
continue
|
||||
if bool(parameter.get("required", False)):
|
||||
required_names.append(parameter_name)
|
||||
properties[parameter_name] = {
|
||||
key: value
|
||||
for key, value in parameter.items()
|
||||
if key not in {"name", "required", "param_type"}
|
||||
}
|
||||
properties[parameter_name]["type"] = str(
|
||||
parameter.get("type", parameter.get("param_type", "string")) or "string"
|
||||
)
|
||||
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
}
|
||||
if required_names:
|
||||
schema["required"] = required_names
|
||||
return schema
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class EventHandlerEntry(ComponentEntry):
|
||||
"""EventHandler 组件条目"""
|
||||
@@ -282,7 +353,7 @@ class MessageGatewayEntry(ComponentEntry):
|
||||
|
||||
|
||||
class ComponentRegistry:
|
||||
"""Host-side 组件注册表
|
||||
"""Host 侧组件注册表。
|
||||
|
||||
由 Supervisor 在收到 plugin.register_components 时调用。
|
||||
供业务层查询可用组件、匹配命令、调度 action/event 等。
|
||||
@@ -300,6 +371,86 @@ class ComponentRegistry:
|
||||
# 按插件索引
|
||||
self._by_plugin: Dict[str, List[ComponentEntry]] = {}
|
||||
|
||||
@staticmethod
|
||||
def _convert_action_metadata_to_tool_metadata(
|
||||
name: str,
|
||||
metadata: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""将旧 Action 元数据转换为统一 Tool 元数据。
|
||||
|
||||
Args:
|
||||
name: 组件名称。
|
||||
metadata: Action 原始元数据。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 转换后的 Tool 元数据。
|
||||
"""
|
||||
|
||||
action_parameters = metadata.get("action_parameters")
|
||||
parameters_schema: Dict[str, Any] | None = None
|
||||
if isinstance(action_parameters, dict) and action_parameters:
|
||||
properties: Dict[str, Any] = {}
|
||||
for parameter_name, parameter_description in action_parameters.items():
|
||||
normalized_name = str(parameter_name or "").strip()
|
||||
if not normalized_name:
|
||||
continue
|
||||
properties[normalized_name] = {
|
||||
"type": "string",
|
||||
"description": str(parameter_description or "").strip() or "兼容旧 Action 参数",
|
||||
}
|
||||
if properties:
|
||||
parameters_schema = {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
}
|
||||
|
||||
detailed_parts: List[str] = []
|
||||
if parameters_schema is not None:
|
||||
parameter_description = build_tool_detailed_description(parameters_schema)
|
||||
if parameter_description:
|
||||
detailed_parts.append(parameter_description)
|
||||
|
||||
action_require = [
|
||||
str(item).strip()
|
||||
for item in (metadata.get("action_require") or [])
|
||||
if str(item).strip()
|
||||
]
|
||||
if action_require:
|
||||
detailed_parts.append("使用建议:\n" + "\n".join(f"- {item}" for item in action_require))
|
||||
|
||||
associated_types = [
|
||||
str(item).strip()
|
||||
for item in (metadata.get("associated_types") or [])
|
||||
if str(item).strip()
|
||||
]
|
||||
if associated_types:
|
||||
detailed_parts.append(f"适用消息类型:{'、'.join(associated_types)}。")
|
||||
|
||||
activation_type = str(metadata.get("activation_type", "always") or "always").strip()
|
||||
activation_keywords = [
|
||||
str(item).strip()
|
||||
for item in (metadata.get("activation_keywords") or [])
|
||||
if str(item).strip()
|
||||
]
|
||||
activation_lines = [f"兼容旧 Action 激活方式:{activation_type}。"]
|
||||
if activation_keywords:
|
||||
activation_lines.append(f"激活关键词:{'、'.join(activation_keywords)}。")
|
||||
if str(metadata.get("action_prompt", "") or "").strip():
|
||||
activation_lines.append(f"原始 Action 提示语:{str(metadata['action_prompt']).strip()}。")
|
||||
detailed_parts.append("\n".join(activation_lines))
|
||||
|
||||
brief_description = str(metadata.get("brief_description", metadata.get("description", "") or f"工具 {name}")).strip()
|
||||
return {
|
||||
**metadata,
|
||||
"description": brief_description,
|
||||
"brief_description": brief_description,
|
||||
"detailed_description": "\n\n".join(part for part in detailed_parts if part).strip(),
|
||||
"parameters_raw": parameters_schema or {},
|
||||
"invoke_method": "plugin.invoke_action",
|
||||
"legacy_action": True,
|
||||
"legacy_component_type": "ACTION",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _normalize_component_type(component_type: str) -> ComponentTypes:
|
||||
"""规范化组件类型输入。
|
||||
@@ -338,18 +489,20 @@ class ComponentRegistry:
|
||||
"""
|
||||
try:
|
||||
normalized_type = self._normalize_component_type(component_type)
|
||||
normalized_metadata = dict(metadata)
|
||||
if normalized_type == ComponentTypes.ACTION:
|
||||
comp = ActionEntry(name, normalized_type.value, plugin_id, metadata)
|
||||
normalized_metadata = self._convert_action_metadata_to_tool_metadata(name, normalized_metadata)
|
||||
comp = ToolEntry(name, ComponentTypes.TOOL.value, plugin_id, normalized_metadata)
|
||||
elif normalized_type == ComponentTypes.COMMAND:
|
||||
comp = CommandEntry(name, normalized_type.value, plugin_id, metadata)
|
||||
comp = CommandEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
elif normalized_type == ComponentTypes.TOOL:
|
||||
comp = ToolEntry(name, normalized_type.value, plugin_id, metadata)
|
||||
comp = ToolEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
elif normalized_type == ComponentTypes.EVENT_HANDLER:
|
||||
comp = EventHandlerEntry(name, normalized_type.value, plugin_id, metadata)
|
||||
comp = EventHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
elif normalized_type == ComponentTypes.HOOK_HANDLER:
|
||||
comp = HookHandlerEntry(name, normalized_type.value, plugin_id, metadata)
|
||||
comp = HookHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
elif normalized_type == ComponentTypes.MESSAGE_GATEWAY:
|
||||
comp = MessageGatewayEntry(name, normalized_type.value, plugin_id, metadata)
|
||||
comp = MessageGatewayEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
else:
|
||||
raise ValueError(f"组件类型 {component_type} 不存在")
|
||||
except ValueError:
|
||||
|
||||
48
src/plugin_runtime/tool_provider.py
Normal file
48
src/plugin_runtime/tool_provider.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""插件运行时工具 Provider。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolProvider, ToolSpec
|
||||
|
||||
from .component_query import component_query_service
|
||||
|
||||
|
||||
class PluginToolProvider(ToolProvider):
|
||||
"""将插件 Tool 与兼容旧 Action 暴露为统一工具 Provider。"""
|
||||
|
||||
provider_name = "plugin_runtime"
|
||||
provider_type = "plugin"
|
||||
|
||||
async def list_tools(self) -> list[ToolSpec]:
|
||||
"""列出插件运行时当前可用的工具声明。"""
|
||||
|
||||
return list(component_query_service.get_llm_available_tool_specs().values())
|
||||
|
||||
async def invoke(
|
||||
self,
|
||||
invocation: ToolInvocation,
|
||||
context: Optional[ToolExecutionContext] = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""执行插件工具或兼容旧 Action 的工具调用。
|
||||
|
||||
Args:
|
||||
invocation: 工具调用请求。
|
||||
context: 执行上下文。
|
||||
|
||||
Returns:
|
||||
ToolExecutionResult: 工具执行结果。
|
||||
"""
|
||||
|
||||
return await component_query_service.invoke_tool_as_tool(
|
||||
invocation=invocation,
|
||||
context=context,
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭 Provider。
|
||||
|
||||
插件运行时工具 Provider 不持有独立资源。
|
||||
"""
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
from maibot_sdk import Command, MaiBotPlugin
|
||||
|
||||
|
||||
_VALID_COMPONENT_TYPES = ("action", "command", "event_handler")
|
||||
_VALID_COMPONENT_TYPES = ("tool", "command", "event_handler")
|
||||
|
||||
HELP_ALL = (
|
||||
"管理命令帮助\n"
|
||||
@@ -37,7 +37,7 @@ HELP_COMPONENT = (
|
||||
"/pm component enable local <component_name> <component_type> 本聊天启用组件\n"
|
||||
"/pm component disable global <component_name> <component_type> 全局禁用组件\n"
|
||||
"/pm component disable local <component_name> <component_type> 本聊天禁用组件\n"
|
||||
" - <component_type> 可选项: action, command, event_handler\n"
|
||||
" - <component_type> 可选项: tool, command, event_handler\n"
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user