Merge branch 'r-dev' of https://github.com/Mai-with-u/MaiBot into r-dev
This commit is contained in:
@@ -1,27 +1,28 @@
|
||||
import time
|
||||
"""PFC 侧消息发送封装。"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from maim_message import Seg
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.message import MessageSending
|
||||
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
|
||||
from src.chat.utils.utils import get_bot_account
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.services import send_service as send_api
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
||||
logger = get_logger("message_sender")
|
||||
|
||||
|
||||
class DirectMessageSender:
|
||||
"""直接消息发送器"""
|
||||
"""直接消息发送器。"""
|
||||
|
||||
def __init__(self, private_name: str):
|
||||
def __init__(self, private_name: str) -> None:
|
||||
"""初始化直接消息发送器。
|
||||
|
||||
Args:
|
||||
private_name: 当前私聊实例的名称。
|
||||
"""
|
||||
self.private_name = private_name
|
||||
|
||||
async def send_message(
|
||||
@@ -30,58 +31,31 @@ class DirectMessageSender:
|
||||
content: str,
|
||||
reply_to_message: Optional[MaiMessage] = None,
|
||||
) -> None:
|
||||
"""发送消息到聊天流
|
||||
"""发送文本消息到聊天流。
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天会话
|
||||
content: 消息内容
|
||||
reply_to_message: 要回复的消息(可选)
|
||||
chat_stream: 目标聊天会话。
|
||||
content: 待发送的文本内容。
|
||||
reply_to_message: 可选的引用回复锚点消息。
|
||||
|
||||
Raises:
|
||||
RuntimeError: 当消息发送失败时抛出。
|
||||
"""
|
||||
try:
|
||||
# 创建消息内容
|
||||
segments = Seg(type="seglist", data=[Seg(type="text", data=content)])
|
||||
|
||||
# 获取麦麦的信息
|
||||
bot_user_id = get_bot_account(chat_stream.platform)
|
||||
if not bot_user_id:
|
||||
logger.error(f"[私聊][{self.private_name}]平台 {chat_stream.platform} 未配置机器人账号,无法发送消息")
|
||||
raise RuntimeError(f"平台 {chat_stream.platform} 未配置机器人账号")
|
||||
bot_user_info = UserInfo(
|
||||
user_id=bot_user_id,
|
||||
user_nickname=global_config.bot.nickname,
|
||||
sent = await send_api.text_to_stream(
|
||||
text=content,
|
||||
stream_id=chat_stream.session_id,
|
||||
set_reply=reply_to_message is not None,
|
||||
reply_message=reply_to_message,
|
||||
storage_message=True,
|
||||
)
|
||||
|
||||
# 用当前时间作为message_id,和之前那套sender一样
|
||||
message_id = f"dm{round(time.time(), 2)}"
|
||||
|
||||
# 构建发送者信息(私聊时为接收者)
|
||||
sender_info = None
|
||||
if reply_to_message and reply_to_message.message_info and reply_to_message.message_info.user_info:
|
||||
sender_info = reply_to_message.message_info.user_info
|
||||
|
||||
# 构建消息对象
|
||||
message = MessageSending(
|
||||
message_id=message_id,
|
||||
session=chat_stream,
|
||||
bot_user_info=bot_user_info,
|
||||
sender_info=sender_info,
|
||||
message_segment=segments,
|
||||
reply=reply_to_message,
|
||||
is_head=True,
|
||||
is_emoji=False,
|
||||
thinking_start_time=time.time(),
|
||||
)
|
||||
|
||||
# 发送消息
|
||||
message_sender = UniversalMessageSender()
|
||||
sent = await message_sender.send_message(message, typing=False, set_reply=False, storage_message=True)
|
||||
|
||||
if sent:
|
||||
logger.info(f"[私聊][{self.private_name}]PFC消息已发送: {content}")
|
||||
else:
|
||||
logger.error(f"[私聊][{self.private_name}]PFC消息发送失败")
|
||||
raise RuntimeError("消息发送失败")
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]PFC消息发送失败: {str(e)}")
|
||||
logger.error(f"[私聊][{self.private_name}]PFC消息发送失败")
|
||||
raise RuntimeError("消息发送失败")
|
||||
except Exception as exc:
|
||||
logger.error(f"[私聊][{self.private_name}]PFC消息发送失败: {exc}")
|
||||
raise
|
||||
|
||||
@@ -1,30 +1,32 @@
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
import json
|
||||
import time
|
||||
import traceback
|
||||
import random
|
||||
import re
|
||||
from typing import Dict, Optional, Tuple, List, TYPE_CHECKING
|
||||
from rich.traceback import install
|
||||
from datetime import datetime
|
||||
from json_repair import repair_json
|
||||
import time
|
||||
import traceback
|
||||
|
||||
from json_repair import repair_json
|
||||
from rich.traceback import install
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.logger.plan_reply_logger import PlanReplyLogger
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_action import ActionUtils
|
||||
from src.config.config import global_config, model_config
|
||||
from src.core.types import ActionActivationType, ActionInfo, ComponentType
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services.message_service import (
|
||||
build_readable_messages_with_id,
|
||||
get_actions_by_timestamp_with_chat,
|
||||
get_messages_before_time_in_chat,
|
||||
)
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.core.types import ActionActivationType, ActionInfo, ComponentType
|
||||
from src.core.component_registry import component_registry
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.info_data_model import TargetPersonInfo
|
||||
@@ -320,7 +322,7 @@ class BrainPlanner:
|
||||
current_available_actions_dict = self.action_manager.get_using_actions()
|
||||
|
||||
# 获取完整的动作信息
|
||||
all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore
|
||||
all_registered_actions: Dict[str, ActionInfo] = component_query_service.get_components_by_type( # type: ignore
|
||||
ComponentType.ACTION
|
||||
)
|
||||
current_available_actions = {}
|
||||
|
||||
@@ -1,734 +0,0 @@
|
||||
import asyncio
|
||||
import time
|
||||
import traceback
|
||||
import random
|
||||
from typing import List, Optional, Dict, Any, Tuple, TYPE_CHECKING
|
||||
from rich.traceback import install
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.common.data_models.message_data_model import ReplyContentType
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.planner_actions.planner import ActionPlanner
|
||||
from src.chat.planner_actions.action_modifier import ActionModifier
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.heart_flow.hfc_utils import CycleDetail
|
||||
from src.learners.expression_learner import expression_learner_manager
|
||||
from src.chat.heart_flow.frequency_control import frequency_control_manager
|
||||
from src.learners.message_recorder import extract_and_distribute_messages
|
||||
from src.person_info.person_info import Person
|
||||
from src.plugin_system.base.component_types import EventType, ActionInfo
|
||||
from src.plugin_system.core import events_manager
|
||||
from src.plugin_system.apis import generator_api, send_api, message_api, database_api
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages_with_id,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
)
|
||||
from src.chat.utils.utils import record_replyer_action_temp
|
||||
from src.memory_system.chat_history_summarizer import ChatHistorySummarizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.message_data_model import ReplySetModel
|
||||
|
||||
|
||||
ERROR_LOOP_INFO = {
|
||||
"loop_plan_info": {
|
||||
"action_result": {
|
||||
"action_type": "error",
|
||||
"action_data": {},
|
||||
"reasoning": "循环处理失败",
|
||||
},
|
||||
},
|
||||
"loop_action_info": {
|
||||
"action_taken": False,
|
||||
"reply_text": "",
|
||||
"command": "",
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
# 注释:原来的动作修改超时常量已移除,因为改为顺序执行
|
||||
|
||||
logger = get_logger("hfc") # Logger Name Changed
|
||||
|
||||
|
||||
class HeartFChatting:
|
||||
"""
|
||||
管理一个连续的Focus Chat循环
|
||||
用于在特定聊天流中生成回复。
|
||||
其生命周期现在由其关联的 SubHeartflow 的 FOCUSED 状态控制。
|
||||
"""
|
||||
|
||||
def __init__(self, chat_id: str):
|
||||
"""
|
||||
HeartFChatting 初始化函数
|
||||
|
||||
参数:
|
||||
chat_id: 聊天流唯一标识符(如stream_id)
|
||||
on_stop_focus_chat: 当收到stop_focus_chat命令时调用的回调函数
|
||||
performance_version: 性能记录版本号,用于区分不同启动版本
|
||||
"""
|
||||
# 基础属性
|
||||
self.stream_id: str = chat_id # 聊天流ID
|
||||
self.chat_stream: ChatStream = get_chat_manager().get_stream(self.stream_id) # type: ignore
|
||||
if not self.chat_stream:
|
||||
raise ValueError(f"无法找到聊天流: {self.stream_id}")
|
||||
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.stream_id) or self.stream_id}]"
|
||||
|
||||
self.expression_learner = expression_learner_manager.get_expression_learner(self.stream_id)
|
||||
|
||||
self.action_manager = ActionManager()
|
||||
self.action_planner = ActionPlanner(chat_id=self.stream_id, action_manager=self.action_manager)
|
||||
self.action_modifier = ActionModifier(action_manager=self.action_manager, chat_id=self.stream_id)
|
||||
|
||||
# 循环控制内部状态
|
||||
self.running: bool = False
|
||||
self._loop_task: Optional[asyncio.Task] = None # 主循环任务
|
||||
|
||||
# 添加循环信息管理相关的属性
|
||||
self.history_loop: List[CycleDetail] = []
|
||||
self._cycle_counter = 0
|
||||
self._current_cycle_detail: CycleDetail = None # type: ignore
|
||||
|
||||
self.last_read_time = time.time() - 2
|
||||
|
||||
self.is_mute = False
|
||||
|
||||
self.last_active_time = time.time() # 记录上一次非noreply时间
|
||||
|
||||
self.question_probability_multiplier = 1
|
||||
self.questioned = False
|
||||
|
||||
# 跟踪连续 no_reply 次数,用于动态调整阈值
|
||||
self.consecutive_no_reply_count = 0
|
||||
|
||||
# 聊天内容概括器
|
||||
self.chat_history_summarizer = ChatHistorySummarizer(chat_id=self.stream_id)
|
||||
|
||||
async def start(self):
|
||||
"""检查是否需要启动主循环,如果未激活则启动。"""
|
||||
|
||||
# 如果循环已经激活,直接返回
|
||||
if self.running:
|
||||
logger.debug(f"{self.log_prefix} HeartFChatting 已激活,无需重复启动")
|
||||
return
|
||||
|
||||
try:
|
||||
# 标记为活动状态,防止重复启动
|
||||
self.running = True
|
||||
|
||||
self._loop_task = asyncio.create_task(self._main_chat_loop())
|
||||
self._loop_task.add_done_callback(self._handle_loop_completion)
|
||||
|
||||
# 启动聊天内容概括器的后台定期检查循环
|
||||
await self.chat_history_summarizer.start()
|
||||
|
||||
logger.info(f"{self.log_prefix} HeartFChatting 启动完成")
|
||||
|
||||
except Exception as e:
|
||||
# 启动失败时重置状态
|
||||
self.running = False
|
||||
self._loop_task = None
|
||||
logger.error(f"{self.log_prefix} HeartFChatting 启动失败: {e}")
|
||||
raise
|
||||
|
||||
def _handle_loop_completion(self, task: asyncio.Task):
|
||||
"""当 _hfc_loop 任务完成时执行的回调。"""
|
||||
try:
|
||||
if exception := task.exception():
|
||||
logger.error(f"{self.log_prefix} HeartFChatting: 脱离了聊天(异常): {exception}")
|
||||
logger.error(traceback.format_exc()) # Log full traceback for exceptions
|
||||
else:
|
||||
logger.info(f"{self.log_prefix} HeartFChatting: 脱离了聊天 (外部停止)")
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.log_prefix} HeartFChatting: 结束了聊天")
|
||||
|
||||
def start_cycle(self) -> Tuple[Dict[str, float], str]:
|
||||
self._cycle_counter += 1
|
||||
self._current_cycle_detail = CycleDetail(self._cycle_counter)
|
||||
self._current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}"
|
||||
cycle_timers = {}
|
||||
return cycle_timers, self._current_cycle_detail.thinking_id
|
||||
|
||||
def end_cycle(self, loop_info, cycle_timers):
|
||||
self._current_cycle_detail.set_loop_info(loop_info)
|
||||
self.history_loop.append(self._current_cycle_detail)
|
||||
self._current_cycle_detail.timers = cycle_timers
|
||||
self._current_cycle_detail.end_time = time.time()
|
||||
|
||||
def print_cycle_info(self, cycle_timers):
|
||||
# 记录循环信息和计时器结果
|
||||
timer_strings = []
|
||||
for name, elapsed in cycle_timers.items():
|
||||
if elapsed < 0.1:
|
||||
# 不显示小于0.1秒的计时器
|
||||
continue
|
||||
formatted_time = f"{elapsed:.2f}秒"
|
||||
timer_strings.append(f"{name}: {formatted_time}")
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 第{self._current_cycle_detail.cycle_id}次思考,"
|
||||
f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒;" # type: ignore
|
||||
+ (f"详情: {'; '.join(timer_strings)}" if timer_strings else "")
|
||||
)
|
||||
|
||||
async def _loopbody(self):
|
||||
recent_messages_list = message_api.get_messages_by_time_in_chat(
|
||||
chat_id=self.stream_id,
|
||||
start_time=self.last_read_time,
|
||||
end_time=time.time(),
|
||||
limit=20,
|
||||
limit_mode="latest",
|
||||
filter_mai=True,
|
||||
filter_command=False,
|
||||
filter_intercept_message_level=0,
|
||||
)
|
||||
|
||||
# 根据连续 no_reply 次数动态调整阈值
|
||||
# 3次 no_reply 时,阈值调高到 1.5(50%概率为1,50%概率为2)
|
||||
# 5次 no_reply 时,提高到 2(大于等于两条消息的阈值)
|
||||
if self.consecutive_no_reply_count >= 5:
|
||||
threshold = 2
|
||||
elif self.consecutive_no_reply_count >= 3:
|
||||
# 1.5 的含义:50%概率为1,50%概率为2
|
||||
threshold = 2 if random.random() < 0.5 else 1
|
||||
else:
|
||||
threshold = 1
|
||||
|
||||
if len(recent_messages_list) >= threshold:
|
||||
# for message in recent_messages_list:
|
||||
# print(message.processed_plain_text)
|
||||
|
||||
self.last_read_time = time.time()
|
||||
|
||||
# !此处使at或者提及必定回复
|
||||
mentioned_message = None
|
||||
for message in recent_messages_list:
|
||||
if (message.is_mentioned or message.is_at) and global_config.chat.mentioned_bot_reply:
|
||||
mentioned_message = message
|
||||
|
||||
# logger.info(f"{self.log_prefix} 当前talk_value: {global_config.chat.get_talk_value(self.stream_id)}")
|
||||
|
||||
# *控制频率用
|
||||
if mentioned_message:
|
||||
await self._observe(recent_messages_list=recent_messages_list, force_reply_message=mentioned_message)
|
||||
elif (
|
||||
random.random()
|
||||
< global_config.chat.get_talk_value(self.stream_id)
|
||||
* frequency_control_manager.get_or_create_frequency_control(self.stream_id).get_talk_frequency_adjust()
|
||||
):
|
||||
await self._observe(recent_messages_list=recent_messages_list)
|
||||
else:
|
||||
# 没有提到,继续保持沉默,等待5秒防止频繁触发
|
||||
await asyncio.sleep(10)
|
||||
return True
|
||||
else:
|
||||
await asyncio.sleep(0.2)
|
||||
return True
|
||||
return True
|
||||
|
||||
async def _send_and_store_reply(
|
||||
self,
|
||||
response_set: "ReplySetModel",
|
||||
action_message: "DatabaseMessages",
|
||||
cycle_timers: Dict[str, float],
|
||||
thinking_id,
|
||||
actions,
|
||||
selected_expressions: Optional[List[int]] = None,
|
||||
quote_message: Optional[bool] = None,
|
||||
) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
|
||||
with Timer("回复发送", cycle_timers):
|
||||
reply_text = await self._send_response(
|
||||
reply_set=response_set,
|
||||
message_data=action_message,
|
||||
selected_expressions=selected_expressions,
|
||||
quote_message=quote_message,
|
||||
)
|
||||
|
||||
# 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
|
||||
platform = action_message.chat_info.platform
|
||||
if platform is None:
|
||||
platform = getattr(self.chat_stream, "platform", "unknown")
|
||||
|
||||
person = Person(platform=platform, user_id=action_message.user_info.user_id)
|
||||
person_name = person.person_name
|
||||
action_prompt_display = f"你对{person_name}进行了回复:{reply_text}"
|
||||
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=action_prompt_display,
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={"reply_text": reply_text},
|
||||
action_name="reply",
|
||||
)
|
||||
|
||||
# 构建循环信息
|
||||
loop_info: Dict[str, Any] = {
|
||||
"loop_plan_info": {
|
||||
"action_result": actions,
|
||||
},
|
||||
"loop_action_info": {
|
||||
"action_taken": True,
|
||||
"reply_text": reply_text,
|
||||
"command": "",
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
|
||||
return loop_info, reply_text, cycle_timers
|
||||
|
||||
async def _observe(
|
||||
self, # interest_value: float = 0.0,
|
||||
recent_messages_list: Optional[List["DatabaseMessages"]] = None,
|
||||
force_reply_message: Optional["DatabaseMessages"] = None,
|
||||
) -> bool: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
||||
if recent_messages_list is None:
|
||||
recent_messages_list = []
|
||||
_reply_text = "" # 初始化reply_text变量,避免UnboundLocalError
|
||||
|
||||
start_time = time.time()
|
||||
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
|
||||
# 通过 MessageRecorder 统一提取消息并分发给 expression_learner 和 jargon_miner
|
||||
# 在 replyer 执行时触发,统一管理时间窗口,避免重复获取消息
|
||||
asyncio.create_task(extract_and_distribute_messages(self.stream_id))
|
||||
|
||||
# 添加curious检测任务 - 检测聊天记录中的矛盾、冲突或需要提问的内容
|
||||
# asyncio.create_task(check_and_make_question(self.stream_id))
|
||||
# 添加聊天内容概括任务 - 累积、打包和压缩聊天记录
|
||||
# 注意:后台循环已在start()中启动,这里作为额外触发点,在有思考时立即处理
|
||||
# asyncio.create_task(self.chat_history_summarizer.process())
|
||||
|
||||
cycle_timers, thinking_id = self.start_cycle()
|
||||
logger.info(
|
||||
f"{self.log_prefix} 开始第{self._cycle_counter}次思考(频率: {global_config.chat.get_talk_value(self.stream_id)})"
|
||||
)
|
||||
|
||||
# 第一步:动作检查
|
||||
available_actions: Dict[str, ActionInfo] = {}
|
||||
try:
|
||||
await self.action_modifier.modify_actions()
|
||||
available_actions = self.action_manager.get_using_actions()
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 动作修改失败: {e}")
|
||||
|
||||
# 执行planner
|
||||
is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info()
|
||||
|
||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.6),
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||
messages=message_list_before_now,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=self.action_planner.last_obs_time_mark,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
prompt_info = await self.action_planner.build_planner_prompt(
|
||||
is_group_chat=is_group_chat,
|
||||
chat_target_info=chat_target_info,
|
||||
current_available_actions=available_actions,
|
||||
chat_content_block=chat_content_block,
|
||||
message_id_list=message_id_list,
|
||||
)
|
||||
continue_flag, modified_message = await events_manager.handle_mai_events(
|
||||
EventType.ON_PLAN, None, prompt_info[0], None, self.chat_stream.stream_id
|
||||
)
|
||||
if not continue_flag:
|
||||
return False
|
||||
if modified_message and modified_message._modify_flags.modify_llm_prompt:
|
||||
prompt_info = (modified_message.llm_prompt, prompt_info[1])
|
||||
|
||||
with Timer("规划器", cycle_timers):
|
||||
action_to_use_info = await self.action_planner.plan(
|
||||
loop_start_time=self.last_read_time,
|
||||
available_actions=available_actions,
|
||||
force_reply_message=force_reply_message,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 决定执行{len(action_to_use_info)}个动作: {' '.join([a.action_type for a in action_to_use_info])}"
|
||||
)
|
||||
|
||||
# 3. 并行执行所有动作
|
||||
action_tasks = [
|
||||
asyncio.create_task(
|
||||
self._execute_action(action, action_to_use_info, thinking_id, available_actions, cycle_timers)
|
||||
)
|
||||
for action in action_to_use_info
|
||||
]
|
||||
|
||||
# 并行执行所有任务
|
||||
results = await asyncio.gather(*action_tasks, return_exceptions=True)
|
||||
|
||||
# 处理执行结果
|
||||
reply_loop_info = None
|
||||
reply_text_from_reply = ""
|
||||
action_success = False
|
||||
action_reply_text = ""
|
||||
|
||||
excute_result_str = ""
|
||||
for result in results:
|
||||
excute_result_str += f"{result['action_type']} 执行结果:{result['result']}\n"
|
||||
|
||||
if isinstance(result, BaseException):
|
||||
logger.error(f"{self.log_prefix} 动作执行异常: {result}")
|
||||
continue
|
||||
|
||||
if result["action_type"] != "reply":
|
||||
action_success = result["success"]
|
||||
action_reply_text = result["result"]
|
||||
elif result["action_type"] == "reply":
|
||||
if result["success"]:
|
||||
reply_loop_info = result["loop_info"]
|
||||
reply_text_from_reply = result["result"]
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 回复动作执行失败")
|
||||
|
||||
self.action_planner.add_plan_excute_log(result=excute_result_str)
|
||||
|
||||
# 构建最终的循环信息
|
||||
if reply_loop_info:
|
||||
# 如果有回复信息,使用回复的loop_info作为基础
|
||||
loop_info = reply_loop_info
|
||||
# 更新动作执行信息
|
||||
loop_info["loop_action_info"].update(
|
||||
{
|
||||
"action_taken": action_success,
|
||||
"taken_time": time.time(),
|
||||
}
|
||||
)
|
||||
_reply_text = reply_text_from_reply
|
||||
else:
|
||||
# 没有回复信息,构建纯动作的loop_info
|
||||
loop_info = {
|
||||
"loop_plan_info": {
|
||||
"action_result": action_to_use_info,
|
||||
},
|
||||
"loop_action_info": {
|
||||
"action_taken": action_success,
|
||||
"reply_text": action_reply_text,
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
_reply_text = action_reply_text
|
||||
|
||||
self.end_cycle(loop_info, cycle_timers)
|
||||
self.print_cycle_info(cycle_timers)
|
||||
|
||||
end_time = time.time()
|
||||
if end_time - start_time < global_config.chat.planner_smooth:
|
||||
wait_time = global_config.chat.planner_smooth - (end_time - start_time)
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
await asyncio.sleep(0.1)
|
||||
return True
|
||||
|
||||
async def _main_chat_loop(self):
|
||||
"""主循环,持续进行计划并可能回复消息,直到被外部取消。"""
|
||||
try:
|
||||
while self.running:
|
||||
# 主循环
|
||||
success = await self._loopbody()
|
||||
await asyncio.sleep(0.1)
|
||||
if not success:
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
# 设置了关闭标志位后被取消是正常流程
|
||||
logger.info(f"{self.log_prefix} 麦麦已关闭聊天")
|
||||
except Exception:
|
||||
logger.error(f"{self.log_prefix} 麦麦聊天意外错误,将于3s后尝试重新启动")
|
||||
print(traceback.format_exc())
|
||||
await asyncio.sleep(3)
|
||||
self._loop_task = asyncio.create_task(self._main_chat_loop())
|
||||
logger.error(f"{self.log_prefix} 结束了当前聊天循环")
|
||||
|
||||
async def _handle_action(
|
||||
self,
|
||||
action: str,
|
||||
action_reasoning: str,
|
||||
action_data: dict,
|
||||
cycle_timers: Dict[str, float],
|
||||
thinking_id: str,
|
||||
action_message: Optional["DatabaseMessages"] = None,
|
||||
) -> tuple[bool, str, str]:
|
||||
"""
|
||||
处理规划动作,使用动作工厂创建相应的动作处理器
|
||||
|
||||
参数:
|
||||
action: 动作类型
|
||||
action_reasoning: 决策理由
|
||||
action_data: 动作数据,包含不同动作需要的参数
|
||||
cycle_timers: 计时器字典
|
||||
thinking_id: 思考ID
|
||||
action_message: 消息数据
|
||||
返回:
|
||||
tuple[bool, str, str]: (是否执行了动作, 思考消息ID, 命令)
|
||||
"""
|
||||
try:
|
||||
# 使用工厂创建动作处理器实例
|
||||
try:
|
||||
action_handler = self.action_manager.create_action(
|
||||
action_name=action,
|
||||
action_data=action_data,
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
chat_stream=self.chat_stream,
|
||||
log_prefix=self.log_prefix,
|
||||
action_reasoning=action_reasoning,
|
||||
action_message=action_message,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 创建动作处理器时出错: {e}")
|
||||
traceback.print_exc()
|
||||
return False, ""
|
||||
|
||||
# 处理动作并获取结果(固定记录一次动作信息)
|
||||
result = await action_handler.execute()
|
||||
success, action_text = result
|
||||
|
||||
return success, action_text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 处理{action}时出错: {e}")
|
||||
traceback.print_exc()
|
||||
return False, ""
|
||||
|
||||
async def _send_response(
|
||||
self,
|
||||
reply_set: "ReplySetModel",
|
||||
message_data: "DatabaseMessages",
|
||||
selected_expressions: Optional[List[int]] = None,
|
||||
quote_message: Optional[bool] = None,
|
||||
) -> str:
|
||||
# 根据 llm_quote 配置决定是否使用 quote_message 参数
|
||||
if global_config.chat.llm_quote:
|
||||
# 如果配置为 true,使用 llm_quote 参数决定是否引用回复
|
||||
if quote_message is None:
|
||||
logger.warning(f"{self.log_prefix} quote_message 参数为空,不引用")
|
||||
need_reply = False
|
||||
else:
|
||||
need_reply = quote_message
|
||||
if need_reply:
|
||||
logger.info(f"{self.log_prefix} LLM 决定使用引用回复")
|
||||
else:
|
||||
# 如果配置为 false,使用原来的模式
|
||||
new_message_count = message_api.count_new_messages(
|
||||
chat_id=self.chat_stream.stream_id, start_time=self.last_read_time, end_time=time.time()
|
||||
)
|
||||
need_reply = new_message_count >= random.randint(2, 3) or time.time() - self.last_read_time > 90
|
||||
if need_reply:
|
||||
logger.info(f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,使用引用回复,或者上次回复时间超过90秒")
|
||||
|
||||
reply_text = ""
|
||||
first_replied = False
|
||||
for reply_content in reply_set.reply_data:
|
||||
if reply_content.content_type != ReplyContentType.TEXT:
|
||||
continue
|
||||
data: str = reply_content.content # type: ignore
|
||||
if not first_replied:
|
||||
await send_api.text_to_stream(
|
||||
text=data,
|
||||
stream_id=self.chat_stream.stream_id,
|
||||
reply_message=message_data,
|
||||
set_reply=need_reply,
|
||||
typing=False,
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
first_replied = True
|
||||
else:
|
||||
await send_api.text_to_stream(
|
||||
text=data,
|
||||
stream_id=self.chat_stream.stream_id,
|
||||
reply_message=message_data,
|
||||
set_reply=False,
|
||||
typing=True,
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
reply_text += data
|
||||
|
||||
return reply_text
|
||||
|
||||
async def _execute_action(
|
||||
self,
|
||||
action_planner_info: ActionPlannerInfo,
|
||||
chosen_action_plan_infos: List[ActionPlannerInfo],
|
||||
thinking_id: str,
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
cycle_timers: Dict[str, float],
|
||||
):
|
||||
"""执行单个动作的通用函数"""
|
||||
try:
|
||||
with Timer(f"动作{action_planner_info.action_type}", cycle_timers):
|
||||
# 直接当场执行no_reply逻辑
|
||||
if action_planner_info.action_type == "no_reply":
|
||||
# 直接处理no_reply逻辑,不再通过动作系统
|
||||
reason = action_planner_info.reasoning or "选择不回复"
|
||||
# logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}")
|
||||
|
||||
# 增加连续 no_reply 计数
|
||||
self.consecutive_no_reply_count += 1
|
||||
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=reason,
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={},
|
||||
action_name="no_reply",
|
||||
action_reasoning=reason,
|
||||
)
|
||||
|
||||
return {"action_type": "no_reply", "success": True, "result": "选择不回复", "command": ""}
|
||||
|
||||
elif action_planner_info.action_type == "reply":
|
||||
# 直接当场执行reply逻辑
|
||||
self.questioned = False
|
||||
# 刷新主动发言状态
|
||||
# 重置连续 no_reply 计数
|
||||
self.consecutive_no_reply_count = 0
|
||||
|
||||
reason = action_planner_info.reasoning or ""
|
||||
# 根据 think_mode 配置决定 think_level 的值
|
||||
think_mode = global_config.chat.think_mode
|
||||
if think_mode == "default":
|
||||
think_level = 0
|
||||
elif think_mode == "deep":
|
||||
think_level = 1
|
||||
elif think_mode == "dynamic":
|
||||
# dynamic 模式:从 planner 返回的 action_data 中获取
|
||||
think_level = action_planner_info.action_data.get("think_level", 1)
|
||||
else:
|
||||
# 默认使用 default 模式
|
||||
think_level = 0
|
||||
# 使用 action_reasoning(planner 的整体思考理由)作为 reply_reason
|
||||
planner_reasoning = action_planner_info.action_reasoning or reason
|
||||
|
||||
record_replyer_action_temp(
|
||||
chat_id=self.stream_id,
|
||||
reason=reason,
|
||||
think_level=think_level,
|
||||
)
|
||||
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=reason,
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={},
|
||||
action_name="reply",
|
||||
action_reasoning=reason,
|
||||
)
|
||||
|
||||
# 从 Planner 的 action_data 中提取未知词语列表(仅在 reply 时使用)
|
||||
unknown_words = None
|
||||
quote_message = None
|
||||
if isinstance(action_planner_info.action_data, dict):
|
||||
uw = action_planner_info.action_data.get("unknown_words")
|
||||
if isinstance(uw, list):
|
||||
cleaned_uw: List[str] = []
|
||||
for item in uw:
|
||||
if isinstance(item, str):
|
||||
s = item.strip()
|
||||
if s:
|
||||
cleaned_uw.append(s)
|
||||
if cleaned_uw:
|
||||
unknown_words = cleaned_uw
|
||||
|
||||
# 从 Planner 的 action_data 中提取 quote_message 参数
|
||||
qm = action_planner_info.action_data.get("quote")
|
||||
if qm is not None:
|
||||
# 支持多种格式:true/false, "true"/"false", 1/0
|
||||
if isinstance(qm, bool):
|
||||
quote_message = qm
|
||||
elif isinstance(qm, str):
|
||||
quote_message = qm.lower() in ("true", "1", "yes")
|
||||
elif isinstance(qm, (int, float)):
|
||||
quote_message = bool(qm)
|
||||
|
||||
logger.info(f"{self.log_prefix} {qm}引用回复设置: {quote_message}")
|
||||
|
||||
success, llm_response = await generator_api.generate_reply(
|
||||
chat_stream=self.chat_stream,
|
||||
reply_message=action_planner_info.action_message,
|
||||
available_actions=available_actions,
|
||||
chosen_actions=chosen_action_plan_infos,
|
||||
reply_reason=planner_reasoning,
|
||||
unknown_words=unknown_words,
|
||||
enable_tool=global_config.tool.enable_tool,
|
||||
request_type="replyer",
|
||||
from_plugin=False,
|
||||
reply_time_point=action_planner_info.action_data.get("loop_start_time", time.time()),
|
||||
think_level=think_level,
|
||||
)
|
||||
|
||||
if not success or not llm_response or not llm_response.reply_set:
|
||||
if action_planner_info.action_message:
|
||||
logger.info(f"对 {action_planner_info.action_message.processed_plain_text} 的回复生成失败")
|
||||
else:
|
||||
logger.info("回复生成失败")
|
||||
return {"action_type": "reply", "success": False, "result": "回复生成失败", "loop_info": None}
|
||||
|
||||
response_set = llm_response.reply_set
|
||||
selected_expressions = llm_response.selected_expressions
|
||||
loop_info, reply_text, _ = await self._send_and_store_reply(
|
||||
response_set=response_set,
|
||||
action_message=action_planner_info.action_message, # type: ignore
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
actions=chosen_action_plan_infos,
|
||||
selected_expressions=selected_expressions,
|
||||
quote_message=quote_message,
|
||||
)
|
||||
self.last_active_time = time.time()
|
||||
return {
|
||||
"action_type": "reply",
|
||||
"success": True,
|
||||
"result": f"你使用reply动作,对' {action_planner_info.action_message.processed_plain_text} '这句话进行了回复,回复内容为: '{reply_text}'",
|
||||
"loop_info": loop_info,
|
||||
}
|
||||
|
||||
else:
|
||||
# 执行普通动作
|
||||
with Timer("动作执行", cycle_timers):
|
||||
success, result = await self._handle_action(
|
||||
action=action_planner_info.action_type,
|
||||
action_reasoning=action_planner_info.action_reasoning or "",
|
||||
action_data=action_planner_info.action_data or {},
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
action_message=action_planner_info.action_message,
|
||||
)
|
||||
|
||||
self.last_active_time = time.time()
|
||||
return {
|
||||
"action_type": action_planner_info.action_type,
|
||||
"success": success,
|
||||
"result": result,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 执行动作时出错: {e}")
|
||||
logger.error(f"{self.log_prefix} 错误信息: {traceback.format_exc()}")
|
||||
return {
|
||||
"action_type": action_planner_info.action_type,
|
||||
"success": False,
|
||||
"result": "",
|
||||
"loop_info": None,
|
||||
"error": str(e),
|
||||
}
|
||||
@@ -1,377 +1,231 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
from rich.traceback import install
|
||||
from typing import List, Optional, TYPE_CHECKING
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
import time
|
||||
import traceback
|
||||
|
||||
from rich.traceback import install
|
||||
|
||||
from src.learners.expression_learner import ExpressionLearner
|
||||
from src.learners.jargon_miner import JargonMiner
|
||||
from src.chat.event_helpers import build_event_message
|
||||
from src.chat.logger.plan_reply_logger import PlanReplyLogger
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.planner_actions.action_modifier import ActionModifier
|
||||
from src.chat.planner_actions.planner import ActionPlanner
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.utils.utils import record_replyer_action_temp
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
|
||||
from src.chat.message_receive.chat_manager import chat_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_config import ChatConfigUtils, ExpressionConfigUtils
|
||||
from src.config.config import global_config
|
||||
from src.config.file_watcher import FileChange
|
||||
from src.core.event_bus import event_bus
|
||||
from src.core.types import ActionInfo, EventType
|
||||
from src.person_info.person_info import Person
|
||||
from src.services import (
|
||||
database_service as database_api,
|
||||
generator_service as generator_api,
|
||||
message_service as message_api,
|
||||
send_service as send_api,
|
||||
)
|
||||
from src.services.message_service import build_readable_messages_with_id, get_messages_before_time_in_chat
|
||||
from src.learners.expression_learner import ExpressionLearner
|
||||
from src.learners.jargon_miner import JargonMiner
|
||||
|
||||
from .heartFC_utils import CycleDetail
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
|
||||
|
||||
install(extra_lines=5)
|
||||
|
||||
logger = get_logger("heartFC_chat")
|
||||
|
||||
|
||||
class HeartFChatting:
|
||||
"""管理一个持续运行的 Focus Chat 会话。"""
|
||||
"""
|
||||
管理一个连续的Focus Chat聊天会话
|
||||
用于在特定的聊天会话里面生成回复
|
||||
"""
|
||||
|
||||
def __init__(self, session_id: str):
|
||||
self.session_id = session_id
|
||||
self.chat_stream: BotChatSession = _chat_manager.get_session_by_session_id(self.session_id) # type: ignore[assignment]
|
||||
if not self.chat_stream:
|
||||
raise ValueError(f"无法找到聊天会话 {self.session_id}")
|
||||
"""
|
||||
初始化 HeartFChatting 实例
|
||||
|
||||
session_name = _chat_manager.get_session_name(session_id) or session_id
|
||||
Args:
|
||||
session_id: 聊天会话ID
|
||||
"""
|
||||
# 基础属性
|
||||
self.session_id = session_id
|
||||
session_name = chat_manager.get_session_name(session_id) or session_id
|
||||
self.log_prefix = f"[{session_name}]"
|
||||
self.session_name = session_name
|
||||
|
||||
self.action_manager = ActionManager()
|
||||
self.action_planner = ActionPlanner(chat_id=self.session_id, action_manager=self.action_manager)
|
||||
self.action_modifier = ActionModifier(action_manager=self.action_manager, chat_id=self.session_id)
|
||||
|
||||
# 系统运行状态
|
||||
self._running: bool = False
|
||||
self._loop_task: Optional[asyncio.Task] = None
|
||||
self._cycle_counter: int = 0
|
||||
self._hfc_lock: asyncio.Lock = asyncio.Lock() # 用于保护 _hfc_func 的并发访问
|
||||
# 聊天频率相关
|
||||
self._consecutive_no_reply_count = 0 # 跟踪连续 no_reply 次数,用于动态调整阈值
|
||||
self._talk_frequency_adjust: float = 1.0 # 发言频率修正值,默认为1.0,可以根据需要调整
|
||||
|
||||
# HFC内消息缓存
|
||||
self.message_cache: List[SessionMessage] = []
|
||||
|
||||
# Asyncio Event 用于控制循环的开始和结束
|
||||
self._cycle_event = asyncio.Event()
|
||||
self._hfc_lock = asyncio.Lock()
|
||||
|
||||
self._cycle_counter = 0
|
||||
self._current_cycle_detail: Optional[CycleDetail] = None
|
||||
self.history_loop: List[CycleDetail] = []
|
||||
|
||||
self.last_read_time = time.time() - 2
|
||||
self.last_active_time = time.time()
|
||||
self._talk_frequency_adjust = 1.0
|
||||
self._consecutive_no_reply_count = 0
|
||||
|
||||
self.message_cache: List["SessionMessage"] = []
|
||||
|
||||
self._min_messages_for_extraction = 30
|
||||
self._min_extraction_interval = 60
|
||||
self._last_extraction_time = 0.0
|
||||
|
||||
# 表达方式相关内容
|
||||
self._min_messages_for_extraction = 30 # 最少提取消息数
|
||||
self._min_extraction_interval = 60 # 最小提取时间间隔,单位为秒
|
||||
self._last_extraction_time: float = 0.0 # 上次提取的时间戳
|
||||
expr_use, jargon_learn, expr_learn = ExpressionConfigUtils.get_expression_config_for_chat(session_id)
|
||||
self._enable_expression_use = expr_use
|
||||
self._enable_expression_learning = expr_learn
|
||||
self._enable_jargon_learning = jargon_learn
|
||||
self._expression_learner = ExpressionLearner(session_id)
|
||||
self._jargon_miner = JargonMiner(session_id, session_name=session_name)
|
||||
self._enable_expression_use = expr_use # 允许使用表达方式,但不一定启用学习
|
||||
self._enable_expression_learning = expr_learn # 允许学习表达方式
|
||||
self._enable_jargon_learning = jargon_learn # 允许学习黑话
|
||||
# 表达学习器
|
||||
self._expression_learner: ExpressionLearner = ExpressionLearner(session_id)
|
||||
# 黑话挖掘器
|
||||
self._jargon_miner: JargonMiner = JargonMiner(session_id, session_name=session_name)
|
||||
|
||||
# TODO: ChatSummarizer 聊天总结器重构
|
||||
|
||||
# ====== 公开方法 ======
|
||||
|
||||
async def start(self):
|
||||
"""启动 HeartFChatting 的主循环"""
|
||||
# 先检查是否已经启动运行
|
||||
if self._running:
|
||||
logger.debug(f"{self.log_prefix} HeartFChatting 已在运行中")
|
||||
logger.debug(f"{self.log_prefix} 已经在运行中,无需重复启动")
|
||||
return
|
||||
|
||||
try:
|
||||
self._running = True
|
||||
self._cycle_event.clear()
|
||||
self._cycle_event.clear() # 确保事件初始状态为未设置
|
||||
|
||||
self._loop_task = asyncio.create_task(self.main_loop())
|
||||
self._loop_task.add_done_callback(self._handle_loop_completion)
|
||||
|
||||
logger.info(f"{self.log_prefix} HeartFChatting 启动完成")
|
||||
except Exception as exc:
|
||||
logger.error(f"{self.log_prefix} HeartFChatting 启动失败: {exc}", exc_info=True)
|
||||
self._running = False
|
||||
self._cycle_event.set()
|
||||
self._loop_task = None
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 启动 HeartFChatting 失败: {e}", exc_info=True)
|
||||
self._running = False # 确保状态正确
|
||||
self._cycle_event.set() # 确保事件被设置,避免死锁
|
||||
self._loop_task = None # 确保任务引用被清理
|
||||
raise
|
||||
|
||||
async def stop(self):
|
||||
"""停止 HeartFChatting 的主循环"""
|
||||
if not self._running:
|
||||
logger.debug(f"{self.log_prefix} HeartFChatting 已停止")
|
||||
logger.debug(f"{self.log_prefix} HeartFChatting 已经停止,无需重复停止")
|
||||
return
|
||||
|
||||
self._running = False
|
||||
self._cycle_event.set()
|
||||
self._cycle_event.set() # 触发事件,通知循环结束
|
||||
|
||||
if self._loop_task:
|
||||
self._loop_task.cancel()
|
||||
self._loop_task.cancel() # 取消主循环任务
|
||||
try:
|
||||
await self._loop_task
|
||||
await self._loop_task # 等待任务完成
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.log_prefix} HeartFChatting 主循环已取消")
|
||||
except Exception as exc:
|
||||
logger.error(f"{self.log_prefix} 停止 HeartFChatting 时发生错误: {exc}", exc_info=True)
|
||||
logger.info(f"{self.log_prefix} HeartFChatting 主循环已成功取消")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 停止 HeartFChatting 时发生错误: {e}", exc_info=True)
|
||||
finally:
|
||||
self._loop_task = None
|
||||
self._loop_task = None # 确保任务引用被清理
|
||||
|
||||
logger.info(f"{self.log_prefix} HeartFChatting 已停止")
|
||||
|
||||
def adjust_talk_frequency(self, new_value: float):
|
||||
"""调整发言频率的调整值
|
||||
|
||||
Args:
|
||||
new_value: 新的修正值,必须为非负数。值越大,修正发言频率越高;值越小,修正发言频率越低。
|
||||
"""
|
||||
self._talk_frequency_adjust = max(0.0, new_value)
|
||||
|
||||
async def register_message(self, message: "SessionMessage"):
|
||||
"""注册一条消息到 HeartFChatting 的缓存中,并检测其是否产生提及,决定是否唤醒聊天
|
||||
|
||||
Args:
|
||||
message: 待注册的消息对象
|
||||
"""
|
||||
self.message_cache.append(message)
|
||||
|
||||
# 先检查at必回复
|
||||
if global_config.chat.inevitable_at_reply and message.is_at:
|
||||
self.last_read_time = time.time()
|
||||
async with self._hfc_lock:
|
||||
await self._judge_and_response(mentioned_message=message, recent_messages_list=[message])
|
||||
return
|
||||
|
||||
async with self._hfc_lock: # 确保与主循环逻辑的互斥访问
|
||||
await self._judge_and_response(message)
|
||||
return # 直接返回,避免同一条消息被主循环再次处理
|
||||
# 再检查提及必回复
|
||||
if global_config.chat.mentioned_bot_reply and message.is_mentioned:
|
||||
self.last_read_time = time.time()
|
||||
async with self._hfc_lock:
|
||||
await self._judge_and_response(mentioned_message=message, recent_messages_list=[message])
|
||||
# 直接获取锁,确保一定一定触发回复逻辑,不受当前是否正在执行主循环的影响
|
||||
async with self._hfc_lock: # 确保与主循环逻辑的互斥访问
|
||||
await self._judge_and_response(message)
|
||||
return
|
||||
|
||||
async def main_loop(self):
|
||||
try:
|
||||
while self._running and not self._cycle_event.is_set():
|
||||
if not self._hfc_lock.locked():
|
||||
async with self._hfc_lock:
|
||||
async with self._hfc_lock: # 确保主循环逻辑的互斥访问
|
||||
await self._hfc_func()
|
||||
await asyncio.sleep(0.1)
|
||||
await asyncio.sleep(5)
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.log_prefix} HeartFChatting: 主循环被取消")
|
||||
except Exception as exc:
|
||||
logger.error(f"{self.log_prefix} HeartFChatting: 主循环异常: {exc}", exc_info=True)
|
||||
await self.stop()
|
||||
logger.info(f"{self.log_prefix} HeartFChatting: 主循环被取消,正在关闭")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 麦麦聊天意外错误: {e},将于3s后尝试重新启动")
|
||||
await self.stop() # 确保状态正确
|
||||
await asyncio.sleep(3)
|
||||
await self.start()
|
||||
await self.start() # 尝试重新启动
|
||||
|
||||
async def _config_callback(self, file_change: Optional[FileChange] = None):
|
||||
del file_change
|
||||
expr_use, jargon_learn, expr_learn = ExpressionConfigUtils.get_expression_config_for_chat(self.session_id)
|
||||
self._enable_expression_use = expr_use
|
||||
self._enable_expression_learning = expr_learn
|
||||
self._enable_jargon_learning = jargon_learn
|
||||
"""配置文件变更回调函数"""
|
||||
# TODO: 根据配置文件变动重新计算相关参数:
|
||||
"""
|
||||
需要计算的参数:
|
||||
self._enable_expression_use = expr_use # 允许使用表达方式,但不一定启用学习
|
||||
self._enable_expression_learning = expr_learn # 允许学习表达方式
|
||||
self._enable_jargon_learning = jargon_learn # 允许学习黑话
|
||||
"""
|
||||
|
||||
async def _hfc_func(self):
|
||||
recent_messages_list = message_api.get_messages_by_time_in_chat(
|
||||
chat_id=self.session_id,
|
||||
start_time=self.last_read_time,
|
||||
end_time=time.time(),
|
||||
limit=20,
|
||||
limit_mode="latest",
|
||||
filter_mai=True,
|
||||
filter_command=False,
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
# ====== 心流聊天核心逻辑 ======
|
||||
async def _hfc_func(self, mentioned_message: Optional["SessionMessage"] = None):
|
||||
"""心流聊天的主循环逻辑"""
|
||||
if self._consecutive_no_reply_count >= 5:
|
||||
threshold = 2
|
||||
elif self._consecutive_no_reply_count >= 3:
|
||||
threshold = 2 if random.random() < 0.5 else 1
|
||||
else:
|
||||
threshold = 1
|
||||
|
||||
if len(recent_messages_list) < 1:
|
||||
if len(self.message_cache) < threshold:
|
||||
await asyncio.sleep(0.2)
|
||||
return True
|
||||
|
||||
self.last_read_time = time.time()
|
||||
|
||||
mentioned_message: Optional["SessionMessage"] = None
|
||||
for message in recent_messages_list:
|
||||
if global_config.chat.inevitable_at_reply and message.is_at:
|
||||
mentioned_message = message
|
||||
elif global_config.chat.mentioned_bot_reply and message.is_mentioned:
|
||||
mentioned_message = message
|
||||
|
||||
talk_value = ChatConfigUtils.get_talk_value(self.session_id) * self._talk_frequency_adjust
|
||||
if mentioned_message:
|
||||
await self._judge_and_response(mentioned_message=mentioned_message, recent_messages_list=recent_messages_list)
|
||||
elif random.random() < talk_value:
|
||||
await self._judge_and_response(recent_messages_list=recent_messages_list)
|
||||
talk_value_threshold = (
|
||||
random.random() * ChatConfigUtils.get_talk_value(self.session_id) * self._talk_frequency_adjust
|
||||
)
|
||||
if mentioned_message and global_config.chat.mentioned_bot_reply:
|
||||
await self._judge_and_response(mentioned_message)
|
||||
elif random.random() < talk_value_threshold:
|
||||
await self._judge_and_response()
|
||||
return True
|
||||
|
||||
async def _judge_and_response(
|
||||
self,
|
||||
mentioned_message: Optional["SessionMessage"] = None,
|
||||
recent_messages_list: Optional[List["SessionMessage"]] = None,
|
||||
):
|
||||
recent_messages = list(recent_messages_list or self.message_cache[-20:])
|
||||
if recent_messages:
|
||||
asyncio.create_task(self._trigger_expression_learning(recent_messages))
|
||||
|
||||
cycle_timers, thinking_id = self._start_cycle()
|
||||
async def _judge_and_response(self, mentioned_message: Optional["SessionMessage"] = None):
|
||||
"""判定和生成回复"""
|
||||
asyncio.create_task(self._trigger_expression_learning(self.message_cache))
|
||||
# TODO: 完成反思器之后的逻辑
|
||||
start_time = time.time()
|
||||
current_cycle_detail = self._start_cycle()
|
||||
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
|
||||
|
||||
try:
|
||||
async with global_prompt_manager.async_message_scope(self._get_template_name()):
|
||||
available_actions: Dict[str, ActionInfo] = {}
|
||||
try:
|
||||
await self.action_modifier.modify_actions()
|
||||
available_actions = self.action_manager.get_using_actions()
|
||||
except Exception as exc:
|
||||
logger.error(f"{self.log_prefix} 动作修改失败: {exc}", exc_info=True)
|
||||
# TODO: 动作检查逻辑
|
||||
# TODO: Planner逻辑
|
||||
# TODO: 动作执行逻辑
|
||||
|
||||
is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info()
|
||||
message_list_before_now = get_messages_before_time_in_chat(
|
||||
chat_id=self.session_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.6),
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||
messages=message_list_before_now,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=self.action_planner.last_obs_time_mark,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
prompt, filtered_actions = await self._build_planner_prompt_with_event(
|
||||
available_actions=available_actions,
|
||||
is_group_chat=is_group_chat,
|
||||
chat_target_info=chat_target_info,
|
||||
chat_content_block=chat_content_block,
|
||||
message_id_list=message_id_list,
|
||||
)
|
||||
if prompt is None:
|
||||
return False
|
||||
|
||||
with Timer("规划器", cycle_timers):
|
||||
reasoning, action_to_use_info, llm_raw_output, llm_reasoning, llm_duration_ms = (
|
||||
await self.action_planner._execute_main_planner(
|
||||
prompt=prompt,
|
||||
message_id_list=message_id_list,
|
||||
filtered_actions=filtered_actions,
|
||||
available_actions=available_actions,
|
||||
loop_start_time=self.last_read_time,
|
||||
)
|
||||
)
|
||||
|
||||
action_to_use_info = self._ensure_force_reply_action(
|
||||
actions=action_to_use_info,
|
||||
force_reply_message=mentioned_message,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
self.action_planner.add_plan_log(reasoning, action_to_use_info)
|
||||
self.action_planner.last_obs_time_mark = time.time()
|
||||
self._log_plan(
|
||||
prompt=prompt,
|
||||
reasoning=reasoning,
|
||||
llm_raw_output=llm_raw_output,
|
||||
llm_reasoning=llm_reasoning,
|
||||
llm_duration_ms=llm_duration_ms,
|
||||
actions=action_to_use_info,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 决定执行{len(action_to_use_info)}个动作: {' '.join([a.action_type for a in action_to_use_info])}"
|
||||
)
|
||||
|
||||
action_tasks = [
|
||||
asyncio.create_task(
|
||||
self._execute_action(
|
||||
action,
|
||||
action_to_use_info,
|
||||
thinking_id,
|
||||
available_actions,
|
||||
cycle_timers,
|
||||
)
|
||||
)
|
||||
for action in action_to_use_info
|
||||
]
|
||||
results = await asyncio.gather(*action_tasks, return_exceptions=True)
|
||||
|
||||
reply_loop_info = None
|
||||
reply_text_from_reply = ""
|
||||
action_success = False
|
||||
action_reply_text = ""
|
||||
execute_result_str = ""
|
||||
|
||||
for result in results:
|
||||
if isinstance(result, BaseException):
|
||||
logger.error(f"{self.log_prefix} 动作执行异常: {result}", exc_info=True)
|
||||
continue
|
||||
|
||||
execute_result_str += f"{result['action_type']} 执行结果:{result['result']}\n"
|
||||
if result["action_type"] == "reply":
|
||||
if result["success"]:
|
||||
reply_loop_info = result["loop_info"]
|
||||
reply_text_from_reply = result["result"]
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} reply 动作执行失败")
|
||||
else:
|
||||
action_success = result["success"]
|
||||
action_reply_text = result["result"]
|
||||
|
||||
self.action_planner.add_plan_excute_log(result=execute_result_str)
|
||||
|
||||
if reply_loop_info:
|
||||
loop_info = reply_loop_info
|
||||
loop_info["loop_action_info"].update(
|
||||
{
|
||||
"action_taken": action_success,
|
||||
"taken_time": time.time(),
|
||||
}
|
||||
)
|
||||
else:
|
||||
loop_info = {
|
||||
"loop_plan_info": {
|
||||
"action_result": action_to_use_info,
|
||||
},
|
||||
"loop_action_info": {
|
||||
"action_taken": action_success,
|
||||
"reply_text": action_reply_text,
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
reply_text_from_reply = action_reply_text
|
||||
|
||||
current_cycle_detail = self._end_cycle(self._current_cycle_detail, loop_info)
|
||||
logger.debug(f"{self.log_prefix} 本轮最终输出: {reply_text_from_reply}")
|
||||
return current_cycle_detail is not None
|
||||
except Exception as exc:
|
||||
logger.error(f"{self.log_prefix} 判定与回复流程失败: {exc}", exc_info=True)
|
||||
if self._current_cycle_detail:
|
||||
self._end_cycle(
|
||||
self._current_cycle_detail,
|
||||
{
|
||||
"loop_plan_info": {"action_result": []},
|
||||
"loop_action_info": {
|
||||
"action_taken": False,
|
||||
"reply_text": "",
|
||||
"taken_time": time.time(),
|
||||
"error": str(exc),
|
||||
},
|
||||
},
|
||||
)
|
||||
return False
|
||||
cycle_detail = self._end_cycle(current_cycle_detail)
|
||||
if wait_time := global_config.chat.planner_smooth - (time.time() - start_time) > 0:
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
await asyncio.sleep(0.1) # 最小等待时间,避免过快循环
|
||||
return True
|
||||
|
||||
def _handle_loop_completion(self, task: asyncio.Task):
|
||||
"""当 _hfc_func 任务完成时执行的回调。"""
|
||||
try:
|
||||
if exception := task.exception():
|
||||
logger.error(f"{self.log_prefix} HeartFChatting: 主循环异常退出: {exception}")
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"{self.log_prefix} HeartFChatting: 脱离了聊天(异常): {exception}")
|
||||
logger.error(traceback.format_exc()) # Log full traceback for exceptions
|
||||
else:
|
||||
logger.info(f"{self.log_prefix} HeartFChatting: 主循环已退出")
|
||||
logger.info(f"{self.log_prefix} HeartFChatting: 脱离了聊天 (外部停止)")
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.log_prefix} HeartFChatting: 聊天已结束")
|
||||
logger.info(f"{self.log_prefix} HeartFChatting: 结束了聊天")
|
||||
|
||||
# ====== 学习器触发逻辑 ======
|
||||
async def _trigger_expression_learning(self, messages: List["SessionMessage"]):
|
||||
if not messages:
|
||||
return
|
||||
|
||||
self._expression_learner.add_messages(messages)
|
||||
if time.time() - self._last_extraction_time < self._min_extraction_interval:
|
||||
return
|
||||
@@ -379,14 +233,12 @@ class HeartFChatting:
|
||||
return
|
||||
if not self._enable_expression_learning:
|
||||
return
|
||||
|
||||
extraction_end_time = time.time()
|
||||
logger.info(
|
||||
f"聊天流 {self.session_name} 提取到 {len(messages)} 条消息,"
|
||||
f"时间窗口: {self._last_extraction_time:.2f} - {extraction_end_time:.2f}"
|
||||
)
|
||||
self._last_extraction_time = extraction_end_time
|
||||
|
||||
try:
|
||||
jargon_miner = self._jargon_miner if self._enable_jargon_learning else None
|
||||
learnt_style = await self._expression_learner.learn(jargon_miner)
|
||||
@@ -394,398 +246,43 @@ class HeartFChatting:
|
||||
logger.info(f"{self.log_prefix} 表达学习完成")
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix} 表达学习未获得有效结果")
|
||||
except Exception as exc:
|
||||
logger.error(f"{self.log_prefix} 表达学习失败: {exc}", exc_info=True)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 表达学习失败: {e}", exc_info=True)
|
||||
|
||||
def _start_cycle(self) -> Tuple[Dict[str, float], str]:
|
||||
# ====== 记录循环执行信息相关逻辑 ======
|
||||
def _start_cycle(self) -> CycleDetail:
|
||||
self._cycle_counter += 1
|
||||
self._current_cycle_detail = CycleDetail(cycle_id=self._cycle_counter)
|
||||
self._current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}"
|
||||
return self._current_cycle_detail.time_records, self._current_cycle_detail.thinking_id
|
||||
current_cycle_detail = CycleDetail(cycle_id=self._cycle_counter)
|
||||
current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}"
|
||||
return current_cycle_detail
|
||||
|
||||
def _end_cycle(self, cycle_detail: Optional[CycleDetail], loop_info: Optional[Dict[str, Any]] = None):
|
||||
if cycle_detail is None:
|
||||
return None
|
||||
|
||||
cycle_detail.loop_plan_info = (loop_info or {}).get("loop_plan_info")
|
||||
cycle_detail.loop_action_info = (loop_info or {}).get("loop_action_info")
|
||||
def _end_cycle(self, cycle_detail: CycleDetail, only_long_execution: bool = True):
|
||||
cycle_detail.end_time = time.time()
|
||||
self.history_loop.append(cycle_detail)
|
||||
|
||||
timer_strings = [
|
||||
timer_strings: List[str] = [
|
||||
f"{name}: {duration:.2f}s"
|
||||
for name, duration in cycle_detail.time_records.items()
|
||||
if duration >= 0.1
|
||||
if not only_long_execution or duration >= 0.1
|
||||
]
|
||||
logger.info(
|
||||
f"{self.log_prefix} 第{cycle_detail.cycle_id} 个心流循环完成,"
|
||||
f"耗时: {cycle_detail.end_time - cycle_detail.start_time:.2f}s;"
|
||||
f"{self.log_prefix} 第 {cycle_detail.cycle_id} 个心流循环完成"
|
||||
f"耗时: {cycle_detail.end_time - cycle_detail.start_time:.2f}秒\n"
|
||||
f"详细计时: {', '.join(timer_strings) if timer_strings else '无'}"
|
||||
)
|
||||
|
||||
return cycle_detail
|
||||
|
||||
async def _execute_action(
|
||||
self,
|
||||
action_planner_info: ActionPlannerInfo,
|
||||
chosen_action_plan_infos: List[ActionPlannerInfo],
|
||||
thinking_id: str,
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
cycle_timers: Dict[str, float],
|
||||
):
|
||||
try:
|
||||
with Timer(f"动作{action_planner_info.action_type}", cycle_timers):
|
||||
if action_planner_info.action_type == "no_reply":
|
||||
reason = action_planner_info.reasoning or "选择不回复"
|
||||
self._consecutive_no_reply_count += 1
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
display_prompt=reason,
|
||||
thinking_id=thinking_id,
|
||||
action_data={},
|
||||
action_name="no_reply",
|
||||
action_reasoning=reason,
|
||||
)
|
||||
return {
|
||||
"action_type": "no_reply",
|
||||
"success": True,
|
||||
"result": "选择不回复",
|
||||
"loop_info": None,
|
||||
}
|
||||
# ====== Action相关逻辑 ======
|
||||
async def _execute_action(self, *args, **kwargs):
|
||||
"""原ExecuteAction"""
|
||||
raise NotImplementedError("执行动作的逻辑尚未实现") # TODO: 实现动作执行的逻辑,替换掉*args, **kwargs*占位符
|
||||
|
||||
if action_planner_info.action_type == "reply":
|
||||
self._consecutive_no_reply_count = 0
|
||||
reason = action_planner_info.reasoning or ""
|
||||
think_level = self._get_think_level(action_planner_info)
|
||||
planner_reasoning = action_planner_info.action_reasoning or reason
|
||||
async def _execute_other_actions(self, *args, **kwargs):
|
||||
"""原HandleAction"""
|
||||
raise NotImplementedError(
|
||||
"执行其他动作的逻辑尚未实现"
|
||||
) # TODO: 实现其他动作执行的逻辑, 替换掉*args, **kwargs*占位符
|
||||
|
||||
record_replyer_action_temp(
|
||||
chat_id=self.session_id,
|
||||
reason=reason,
|
||||
think_level=think_level,
|
||||
)
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
display_prompt=reason,
|
||||
thinking_id=thinking_id,
|
||||
action_data={},
|
||||
action_name="reply",
|
||||
action_reasoning=reason,
|
||||
)
|
||||
|
||||
unknown_words, quote_message = self._extract_reply_metadata(action_planner_info)
|
||||
success, llm_response = await generator_api.generate_reply(
|
||||
chat_stream=self.chat_stream,
|
||||
reply_message=action_planner_info.action_message,
|
||||
available_actions=available_actions,
|
||||
chosen_actions=chosen_action_plan_infos,
|
||||
reply_reason=planner_reasoning,
|
||||
unknown_words=unknown_words,
|
||||
enable_tool=global_config.tool.enable_tool,
|
||||
request_type="replyer",
|
||||
from_plugin=False,
|
||||
reply_time_point=action_planner_info.action_data.get("loop_start_time", time.time())
|
||||
if action_planner_info.action_data
|
||||
else time.time(),
|
||||
think_level=think_level,
|
||||
)
|
||||
if not success or not llm_response or not llm_response.reply_set:
|
||||
if action_planner_info.action_message:
|
||||
logger.info(
|
||||
f"对 {action_planner_info.action_message.processed_plain_text} 的回复生成失败"
|
||||
)
|
||||
else:
|
||||
logger.info(f"{self.log_prefix} 回复生成失败")
|
||||
return {
|
||||
"action_type": "reply",
|
||||
"success": False,
|
||||
"result": "回复生成失败",
|
||||
"loop_info": None,
|
||||
}
|
||||
|
||||
loop_info, reply_text, _ = await self._send_and_store_reply(
|
||||
response_set=llm_response.reply_set,
|
||||
action_message=action_planner_info.action_message, # type: ignore[arg-type]
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
actions=chosen_action_plan_infos,
|
||||
selected_expressions=llm_response.selected_expressions,
|
||||
quote_message=quote_message,
|
||||
)
|
||||
self.last_active_time = time.time()
|
||||
return {
|
||||
"action_type": "reply",
|
||||
"success": True,
|
||||
"result": reply_text,
|
||||
"loop_info": loop_info,
|
||||
}
|
||||
|
||||
with Timer("动作执行", cycle_timers):
|
||||
success, result = await self._handle_action(
|
||||
action=action_planner_info.action_type,
|
||||
action_reasoning=action_planner_info.action_reasoning or "",
|
||||
action_data=action_planner_info.action_data or {},
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
action_message=action_planner_info.action_message,
|
||||
)
|
||||
if success:
|
||||
self.last_active_time = time.time()
|
||||
return {
|
||||
"action_type": action_planner_info.action_type,
|
||||
"success": success,
|
||||
"result": result,
|
||||
"loop_info": None,
|
||||
}
|
||||
except Exception as exc:
|
||||
logger.error(f"{self.log_prefix} 执行动作时出错: {exc}", exc_info=True)
|
||||
return {
|
||||
"action_type": action_planner_info.action_type,
|
||||
"success": False,
|
||||
"result": "",
|
||||
"loop_info": None,
|
||||
"error": str(exc),
|
||||
}
|
||||
|
||||
async def _handle_action(
|
||||
self,
|
||||
action: str,
|
||||
action_reasoning: str,
|
||||
action_data: dict,
|
||||
cycle_timers: Dict[str, float],
|
||||
thinking_id: str,
|
||||
action_message: Optional["SessionMessage"] = None,
|
||||
) -> Tuple[bool, str]:
|
||||
try:
|
||||
action_handler = self.action_manager.create_action(
|
||||
action_name=action,
|
||||
action_data=action_data,
|
||||
action_reasoning=action_reasoning,
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
chat_stream=self.chat_stream,
|
||||
log_prefix=self.log_prefix,
|
||||
action_message=action_message,
|
||||
)
|
||||
if not action_handler:
|
||||
logger.warning(f"{self.log_prefix} 未能创建动作处理器: {action}")
|
||||
return False, ""
|
||||
|
||||
success, action_text = await action_handler.execute()
|
||||
return success, action_text
|
||||
except Exception as exc:
|
||||
logger.error(f"{self.log_prefix} 处理动作 {action} 时出错: {exc}", exc_info=True)
|
||||
return False, ""
|
||||
|
||||
async def _send_and_store_reply(
|
||||
self,
|
||||
response_set: MessageSequence,
|
||||
action_message: "SessionMessage",
|
||||
cycle_timers: Dict[str, float],
|
||||
thinking_id: str,
|
||||
actions: List[ActionPlannerInfo],
|
||||
selected_expressions: Optional[List[int]] = None,
|
||||
quote_message: Optional[bool] = None,
|
||||
) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
|
||||
with Timer("回复发送", cycle_timers):
|
||||
reply_text = await self._send_response(
|
||||
reply_set=response_set,
|
||||
message_data=action_message,
|
||||
selected_expressions=selected_expressions,
|
||||
quote_message=quote_message,
|
||||
)
|
||||
|
||||
platform = action_message.platform or getattr(self.chat_stream, "platform", "unknown")
|
||||
person = Person(platform=platform, user_id=action_message.message_info.user_info.user_id)
|
||||
action_prompt_display = f"你对{person.person_name}进行了回复:{reply_text}"
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
display_prompt=action_prompt_display,
|
||||
thinking_id=thinking_id,
|
||||
action_data={"reply_text": reply_text},
|
||||
action_name="reply",
|
||||
)
|
||||
|
||||
loop_info: Dict[str, Any] = {
|
||||
"loop_plan_info": {
|
||||
"action_result": actions,
|
||||
},
|
||||
"loop_action_info": {
|
||||
"action_taken": True,
|
||||
"reply_text": reply_text,
|
||||
"command": "",
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
return loop_info, reply_text, cycle_timers
|
||||
|
||||
async def _send_response(
|
||||
self,
|
||||
reply_set: MessageSequence,
|
||||
message_data: "SessionMessage",
|
||||
selected_expressions: Optional[List[int]] = None,
|
||||
quote_message: Optional[bool] = None,
|
||||
) -> str:
|
||||
if global_config.chat.llm_quote:
|
||||
need_reply = bool(quote_message)
|
||||
else:
|
||||
new_message_count = message_api.count_new_messages(
|
||||
chat_id=self.session_id,
|
||||
start_time=self.last_read_time,
|
||||
end_time=time.time(),
|
||||
)
|
||||
need_reply = new_message_count >= random.randint(2, 3) or time.time() - self.last_read_time > 90
|
||||
|
||||
reply_text = ""
|
||||
first_replied = False
|
||||
for component in reply_set.components:
|
||||
if not isinstance(component, TextComponent):
|
||||
continue
|
||||
data = component.text
|
||||
if not first_replied:
|
||||
await send_api.text_to_stream(
|
||||
text=data,
|
||||
stream_id=self.session_id,
|
||||
reply_message=message_data,
|
||||
set_reply=need_reply,
|
||||
typing=False,
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
first_replied = True
|
||||
else:
|
||||
await send_api.text_to_stream(
|
||||
text=data,
|
||||
stream_id=self.session_id,
|
||||
reply_message=message_data,
|
||||
set_reply=False,
|
||||
typing=True,
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
reply_text += data
|
||||
return reply_text
|
||||
|
||||
async def _build_planner_prompt_with_event(
|
||||
self,
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
is_group_chat: bool,
|
||||
chat_target_info: Any,
|
||||
chat_content_block: str,
|
||||
message_id_list: List[Tuple[str, "SessionMessage"]],
|
||||
) -> Tuple[Optional[str], Dict[str, ActionInfo]]:
|
||||
filtered_actions = self.action_planner._filter_actions_by_activation_type(available_actions, chat_content_block)
|
||||
prompt, _ = await self.action_planner.build_planner_prompt(
|
||||
is_group_chat=is_group_chat,
|
||||
chat_target_info=chat_target_info,
|
||||
current_available_actions=filtered_actions,
|
||||
chat_content_block=chat_content_block,
|
||||
message_id_list=message_id_list,
|
||||
)
|
||||
event_message = build_event_message(EventType.ON_PLAN, llm_prompt=prompt, stream_id=self.session_id)
|
||||
continue_flag, modified_message = await event_bus.emit(EventType.ON_PLAN, event_message)
|
||||
if not continue_flag:
|
||||
logger.info(f"{self.log_prefix} ON_PLAN 事件中止了本轮 HFC")
|
||||
return None, filtered_actions
|
||||
if modified_message and modified_message._modify_flags.modify_llm_prompt and modified_message.llm_prompt:
|
||||
prompt = modified_message.llm_prompt
|
||||
return prompt, filtered_actions
|
||||
|
||||
def _ensure_force_reply_action(
|
||||
self,
|
||||
actions: List[ActionPlannerInfo],
|
||||
force_reply_message: Optional["SessionMessage"],
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
) -> List[ActionPlannerInfo]:
|
||||
if not force_reply_message:
|
||||
return actions
|
||||
|
||||
has_reply_to_force_message = any(
|
||||
action.action_type == "reply"
|
||||
and action.action_message
|
||||
and action.action_message.message_id == force_reply_message.message_id
|
||||
for action in actions
|
||||
)
|
||||
if has_reply_to_force_message:
|
||||
return actions
|
||||
|
||||
actions = [action for action in actions if action.action_type != "no_reply"]
|
||||
actions.insert(
|
||||
0,
|
||||
ActionPlannerInfo(
|
||||
action_type="reply",
|
||||
reasoning="用户提及了我,必须回复该消息",
|
||||
action_data={"loop_start_time": self.last_read_time},
|
||||
action_message=force_reply_message,
|
||||
available_actions=available_actions,
|
||||
action_reasoning=None,
|
||||
),
|
||||
)
|
||||
logger.info(f"{self.log_prefix} 检测到强制回复消息,已补充 reply 动作")
|
||||
return actions
|
||||
|
||||
def _log_plan(
|
||||
self,
|
||||
prompt: str,
|
||||
reasoning: str,
|
||||
llm_raw_output: Optional[str],
|
||||
llm_reasoning: Optional[str],
|
||||
llm_duration_ms: Optional[float],
|
||||
actions: List[ActionPlannerInfo],
|
||||
) -> None:
|
||||
try:
|
||||
PlanReplyLogger.log_plan(
|
||||
chat_id=self.session_id,
|
||||
prompt=prompt,
|
||||
reasoning=reasoning,
|
||||
raw_output=llm_raw_output,
|
||||
raw_reasoning=llm_reasoning,
|
||||
actions=actions,
|
||||
timing={
|
||||
"llm_duration_ms": round(llm_duration_ms, 2) if llm_duration_ms is not None else None,
|
||||
"loop_start_time": self.last_read_time,
|
||||
},
|
||||
extra=None,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"{self.log_prefix} 记录 plan 日志失败")
|
||||
|
||||
def _extract_reply_metadata(
|
||||
self,
|
||||
action_planner_info: ActionPlannerInfo,
|
||||
) -> Tuple[Optional[List[str]], Optional[bool]]:
|
||||
unknown_words: Optional[List[str]] = None
|
||||
quote_message: Optional[bool] = None
|
||||
action_data = action_planner_info.action_data or {}
|
||||
|
||||
raw_unknown_words = action_data.get("unknown_words")
|
||||
if isinstance(raw_unknown_words, list):
|
||||
cleaned_unknown_words = []
|
||||
for item in raw_unknown_words:
|
||||
if isinstance(item, str) and (cleaned_item := item.strip()):
|
||||
cleaned_unknown_words.append(cleaned_item)
|
||||
if cleaned_unknown_words:
|
||||
unknown_words = cleaned_unknown_words
|
||||
|
||||
raw_quote = action_data.get("quote")
|
||||
if isinstance(raw_quote, bool):
|
||||
quote_message = raw_quote
|
||||
elif isinstance(raw_quote, str):
|
||||
quote_message = raw_quote.lower() in {"true", "1", "yes"}
|
||||
elif isinstance(raw_quote, (int, float)):
|
||||
quote_message = bool(raw_quote)
|
||||
|
||||
return unknown_words, quote_message
|
||||
|
||||
def _get_think_level(self, action_planner_info: ActionPlannerInfo) -> int:
|
||||
think_mode = global_config.chat.think_mode
|
||||
if think_mode == "default":
|
||||
return 0
|
||||
if think_mode == "deep":
|
||||
return 1
|
||||
if think_mode == "dynamic":
|
||||
action_data = action_planner_info.action_data or {}
|
||||
return int(action_data.get("think_level", 1))
|
||||
return 0
|
||||
|
||||
def _get_template_name(self) -> Optional[str]:
|
||||
if self.chat_stream.context:
|
||||
return self.chat_stream.context.template_name
|
||||
return None
|
||||
# ====== 响应发送相关方法 ======
|
||||
async def _send_response(self, *args, **kwargs):
|
||||
raise NotImplementedError("发送回复的逻辑尚未实现") # TODO: 实现发送回复的逻辑,替换掉*args, **kwargs*占位符
|
||||
# 传入的消息至少应该是个MessageSequence实例,最好是SessionMessage实例,随后可直接转化为MessageSending实例
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
import traceback
|
||||
from typing import Any, Optional, Dict
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.heart_flow.heartFC_chat import HeartFChatting
|
||||
from src.chat.brain_chat.brain_chat import BrainChatting
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
|
||||
logger = get_logger("heartflow")
|
||||
|
||||
|
||||
class Heartflow:
|
||||
"""主心流协调器,负责初始化并协调聊天"""
|
||||
|
||||
def __init__(self):
|
||||
self.heartflow_chat_list: Dict[Any, HeartFChatting | BrainChatting] = {}
|
||||
|
||||
async def get_or_create_heartflow_chat(self, chat_id: Any) -> Optional[HeartFChatting | BrainChatting]:
|
||||
"""获取或创建一个新的HeartFChatting实例"""
|
||||
try:
|
||||
if chat_id in self.heartflow_chat_list:
|
||||
if chat := self.heartflow_chat_list.get(chat_id):
|
||||
return chat
|
||||
else:
|
||||
chat_stream: ChatStream | None = get_chat_manager().get_stream(chat_id)
|
||||
if not chat_stream:
|
||||
raise ValueError(f"未找到 chat_id={chat_id} 的聊天流")
|
||||
if chat_stream.group_info:
|
||||
new_chat = HeartFChatting(chat_id=chat_id)
|
||||
else:
|
||||
new_chat = BrainChatting(chat_id=chat_id)
|
||||
await new_chat.start()
|
||||
self.heartflow_chat_list[chat_id] = new_chat
|
||||
return new_chat
|
||||
except Exception as e:
|
||||
logger.error(f"创建心流聊天 {chat_id} 失败: {e}", exc_info=True)
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
heartflow = Heartflow()
|
||||
@@ -1,19 +1,20 @@
|
||||
from contextlib import suppress
|
||||
import traceback
|
||||
import os
|
||||
|
||||
from maim_message import MessageBase
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import os
|
||||
import traceback
|
||||
|
||||
from maim_message import MessageBase
|
||||
|
||||
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_message import MessageUtils
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
|
||||
from src.platform_io.route_key_factory import RouteKeyFactory
|
||||
|
||||
# from src.chat.brain_chat.PFC.pfc_manager import PFCManager
|
||||
from src.core.announcement_manager import global_announcement_manager
|
||||
from src.core.component_registry import component_registry
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
|
||||
from .message import SessionMessage
|
||||
from .chat_manager import chat_manager
|
||||
@@ -58,16 +59,22 @@ class ChatBot:
|
||||
logger.error(f"创建PFC聊天失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def _process_commands(self, message: SessionMessage):
|
||||
# sourcery skip: use-named-expression
|
||||
"""使用新插件系统处理命令"""
|
||||
async def _process_commands(self, message: SessionMessage) -> tuple[bool, Optional[str], bool]:
|
||||
"""使用统一组件注册表处理命令。
|
||||
|
||||
Args:
|
||||
message: 当前待处理的会话消息。
|
||||
|
||||
Returns:
|
||||
tuple[bool, Optional[str], bool]: ``(是否命中命令, 命令响应文本, 是否继续后续处理)``。
|
||||
"""
|
||||
if not message.processed_plain_text:
|
||||
return False, None, True # 没有文本内容,继续处理消息
|
||||
try:
|
||||
text = message.processed_plain_text
|
||||
|
||||
# 使用核心组件注册表查找命令
|
||||
command_result = component_registry.find_command_by_text(text)
|
||||
# 使用插件运行时统一查询服务查找命令
|
||||
command_result = component_query_service.find_command_by_text(text)
|
||||
if command_result:
|
||||
command_executor, matched_groups, command_info = command_result
|
||||
plugin_name = command_info.plugin_name
|
||||
@@ -81,7 +88,7 @@ class ChatBot:
|
||||
message.is_command = True
|
||||
|
||||
# 获取插件配置
|
||||
plugin_config = component_registry.get_plugin_config(plugin_name)
|
||||
plugin_config = component_query_service.get_plugin_config(plugin_name)
|
||||
|
||||
try:
|
||||
# 调用命令执行器
|
||||
@@ -112,88 +119,32 @@ class ChatBot:
|
||||
# 命令出错时,根据命令的拦截设置决定是否继续处理消息
|
||||
return True, str(e), False # 出错时继续处理消息
|
||||
|
||||
# 没有找到旧系统命令,尝试新版本插件运行时
|
||||
new_cmd_result = await self._process_new_runtime_command(message)
|
||||
return new_cmd_result if new_cmd_result is not None else (False, None, True)
|
||||
return False, None, True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理命令时出错: {e}")
|
||||
return False, None, True # 出错时继续处理消息
|
||||
|
||||
async def _process_new_runtime_command(self, message: SessionMessage):
|
||||
"""尝试在新版本插件运行时中查找并执行命令
|
||||
|
||||
Returns:
|
||||
(found, response, continue_processing) 三元组,
|
||||
或 None 表示新运行时中也未找到匹配命令。
|
||||
"""
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
prm = get_plugin_runtime_manager()
|
||||
if not prm.is_running:
|
||||
return None
|
||||
|
||||
matched = prm.find_command_by_text(message.processed_plain_text)
|
||||
if matched is None:
|
||||
return None
|
||||
|
||||
command_name = matched["name"]
|
||||
if message.session_id and command_name in global_announcement_manager.get_disabled_chat_commands(
|
||||
message.session_id
|
||||
):
|
||||
logger.info(f"[新运行时] 用户禁用的命令,跳过处理: {matched['full_name']}")
|
||||
return False, None, True
|
||||
|
||||
message.is_command = True
|
||||
logger.info(f"[新运行时] 匹配命令: {matched['full_name']}")
|
||||
|
||||
try:
|
||||
resp = await prm.invoke_plugin(
|
||||
method="plugin.invoke_command",
|
||||
plugin_id=matched["plugin_id"],
|
||||
component_name=matched["name"],
|
||||
args={
|
||||
"text": message.processed_plain_text,
|
||||
"stream_id": message.session_id or "",
|
||||
"matched_groups": matched.get("matched_groups") or {},
|
||||
},
|
||||
timeout_ms=30000,
|
||||
)
|
||||
|
||||
payload = resp.payload
|
||||
success = payload.get("success", False)
|
||||
cmd_result = payload.get("result")
|
||||
|
||||
# 拦截位优先从命令返回值中获取(支持运行时动态决定),
|
||||
# 回退到组件 metadata 中的静态声明
|
||||
if isinstance(cmd_result, (list, tuple)) and len(cmd_result) >= 3:
|
||||
# 命令返回 (found, response_text, intercept_bool) 三元组
|
||||
response_text = cmd_result[1] if cmd_result[1] is not None else ""
|
||||
intercept = bool(cmd_result[2])
|
||||
else:
|
||||
response_text = cmd_result if cmd_result is not None else ""
|
||||
intercept = bool(matched["metadata"].get("intercept_message_level", 0))
|
||||
|
||||
self._mark_command_message(message, int(intercept))
|
||||
|
||||
if success:
|
||||
logger.info(f"[新运行时] 命令执行成功: {matched['full_name']}")
|
||||
else:
|
||||
logger.warning(f"[新运行时] 命令执行失败: {matched['full_name']} - {response_text}")
|
||||
|
||||
return True, response_text, not intercept
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[新运行时] 执行命令 {matched['full_name']} 异常: {e}", exc_info=True)
|
||||
return True, str(e), True
|
||||
|
||||
@staticmethod
|
||||
def _mark_command_message(message: SessionMessage, intercept_message_level: int) -> None:
|
||||
"""标记消息已经被命令链消费。
|
||||
|
||||
Args:
|
||||
message: 待标记的会话消息。
|
||||
intercept_message_level: 命令设置的拦截级别。
|
||||
"""
|
||||
|
||||
message.is_command = True
|
||||
message.message_info.additional_config["intercept_message_level"] = intercept_message_level
|
||||
|
||||
@staticmethod
|
||||
def _store_intercepted_command_message(message: SessionMessage) -> None:
|
||||
"""将被命令链拦截的消息写入数据库。
|
||||
|
||||
Args:
|
||||
message: 已完成命令处理的会话消息。
|
||||
"""
|
||||
|
||||
MessageUtils.store_message_to_db(message)
|
||||
|
||||
async def _handle_command_processing_result(
|
||||
@@ -310,13 +261,28 @@ class ChatBot:
|
||||
# logger.debug(str(message_data))
|
||||
maim_raw_message = MessageBase.from_dict(message_data)
|
||||
message = SessionMessage.from_maim_message(maim_raw_message)
|
||||
await self.receive_message(message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"预处理消息失败: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
async def receive_message(self, message: SessionMessage):
|
||||
try:
|
||||
group_info = message.message_info.group_info
|
||||
user_info = message.message_info.user_info
|
||||
account_id = None
|
||||
scope = None
|
||||
additional_config = message.message_info.additional_config
|
||||
if isinstance(additional_config, dict):
|
||||
account_id, scope = RouteKeyFactory.extract_components(additional_config)
|
||||
|
||||
session_id = SessionUtils.calculate_session_id(
|
||||
message.platform,
|
||||
user_id=message.message_info.user_info.user_id,
|
||||
group_id=group_info.group_id if group_info else None,
|
||||
account_id=account_id,
|
||||
scope=scope,
|
||||
)
|
||||
|
||||
message.session_id = session_id # 正确初始化session_id
|
||||
@@ -359,18 +325,24 @@ class ChatBot:
|
||||
platform = message.platform
|
||||
user_id = user_info.user_id
|
||||
group_id = group_info.group_id if group_info else None
|
||||
_ = await chat_manager.get_or_create_session(platform, user_id, group_id) # 确保会话存在
|
||||
_ = await chat_manager.get_or_create_session(
|
||||
platform,
|
||||
user_id,
|
||||
group_id,
|
||||
account_id=account_id,
|
||||
scope=scope,
|
||||
) # 确保会话存在
|
||||
|
||||
# message.update_chat_stream(chat)
|
||||
|
||||
# 命令处理 - 使用新插件系统检查并处理命令
|
||||
# 注意:命令返回的 response 当前只用于日志记录和流程判断,
|
||||
# 不会在这里自动作为回复消息发送回会话。
|
||||
is_command, cmd_result, continue_process = await self._process_commands(message)
|
||||
# is_command, cmd_result, continue_process = await self._process_commands(message)
|
||||
|
||||
# 如果是命令且不需要继续处理,则直接返回
|
||||
if is_command and await self._handle_command_processing_result(message, cmd_result, continue_process):
|
||||
return
|
||||
# # 如果是命令且不需要继续处理,则直接返回
|
||||
# if is_command and await self._handle_command_processing_result(message, cmd_result, continue_process):
|
||||
# return
|
||||
|
||||
# continue_flag, modified_message = await events_manager.handle_mai_events(EventType.ON_MESSAGE, message)
|
||||
# if not continue_flag:
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
|
||||
from rich.traceback import install
|
||||
from sqlmodel import select
|
||||
from typing import Optional, TYPE_CHECKING, List, Dict
|
||||
|
||||
import asyncio
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.chat_session_data_model import MaiChatSession
|
||||
from src.common.database.database_model import ChatSession
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import ChatSession
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.platform_io.route_key_factory import RouteKeyFactory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .message import SessionMessage
|
||||
@@ -82,7 +83,12 @@ class ChatManager:
|
||||
logger.error(f"初始化聊天管理器出现错误: {e}")
|
||||
|
||||
async def get_or_create_session(
|
||||
self, platform: str, user_id: str, group_id: Optional[str] = None
|
||||
self,
|
||||
platform: str,
|
||||
user_id: str,
|
||||
group_id: Optional[str] = None,
|
||||
account_id: Optional[str] = None,
|
||||
scope: Optional[str] = None,
|
||||
) -> BotChatSession:
|
||||
"""获取会话,如果不存在则创建一个新会话;一个封装方法。
|
||||
|
||||
@@ -90,12 +96,20 @@ class ChatManager:
|
||||
platform: 平台
|
||||
user_id: 用户ID
|
||||
group_id: 群ID(如果是群聊)
|
||||
account_id: 平台账号 ID
|
||||
scope: 路由作用域
|
||||
Returns:
|
||||
return (BotChatSession) 会话对象
|
||||
Raises:
|
||||
Exception: 获取或创建会话时发生错误
|
||||
"""
|
||||
session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
|
||||
session_id = SessionUtils.calculate_session_id(
|
||||
platform,
|
||||
user_id=user_id,
|
||||
group_id=group_id,
|
||||
account_id=account_id,
|
||||
scope=scope,
|
||||
)
|
||||
if session := self.get_session_by_session_id(session_id):
|
||||
session.update_active_time()
|
||||
return session
|
||||
@@ -131,7 +145,18 @@ class ChatManager:
|
||||
raise ValueError("消息缺少平台信息")
|
||||
user_id = message.message_info.user_info.user_id
|
||||
group_id = message.message_info.group_info.group_id if message.message_info.group_info else None
|
||||
session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
|
||||
account_id = None
|
||||
scope = None
|
||||
additional_config = message.message_info.additional_config
|
||||
if isinstance(additional_config, dict):
|
||||
account_id, scope = RouteKeyFactory.extract_components(additional_config)
|
||||
session_id = SessionUtils.calculate_session_id(
|
||||
platform,
|
||||
user_id=user_id,
|
||||
group_id=group_id,
|
||||
account_id=account_id,
|
||||
scope=scope,
|
||||
)
|
||||
message.session_id = session_id # 确保消息的session_id正确设置
|
||||
self.last_messages[session_id] = message
|
||||
|
||||
@@ -188,7 +213,12 @@ class ChatManager:
|
||||
return None
|
||||
|
||||
def get_session_by_info(
|
||||
self, platform: str, user_id: Optional[str] = None, group_id: Optional[str] = None
|
||||
self,
|
||||
platform: str,
|
||||
user_id: Optional[str] = None,
|
||||
group_id: Optional[str] = None,
|
||||
account_id: Optional[str] = None,
|
||||
scope: Optional[str] = None,
|
||||
) -> Optional[BotChatSession]:
|
||||
"""根据平台、用户ID和群ID获取对应的会话
|
||||
|
||||
@@ -196,10 +226,18 @@ class ChatManager:
|
||||
platform: 平台
|
||||
user_id: 用户ID
|
||||
group_id: 群ID(如果是群聊)
|
||||
account_id: 平台账号 ID
|
||||
scope: 路由作用域
|
||||
Returns:
|
||||
return (Optional[BotChatSession]): 会话对象,如果不存在则返回None
|
||||
"""
|
||||
session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
|
||||
session_id = SessionUtils.calculate_session_id(
|
||||
platform,
|
||||
user_id=user_id,
|
||||
group_id=group_id,
|
||||
account_id=account_id,
|
||||
scope=scope,
|
||||
)
|
||||
return self.get_session_by_session_id(session_id)
|
||||
|
||||
def get_session_by_session_id(self, session_id: str) -> Optional[BotChatSession]:
|
||||
|
||||
@@ -1,31 +1,37 @@
|
||||
from rich.traceback import install
|
||||
from typing import Optional
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import asyncio
|
||||
import traceback
|
||||
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.message_server.api import get_global_api
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import get_db_session
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.chat.utils.utils import calculate_typing_time, truncate_message
|
||||
from src.common.data_models.message_component_data_model import ReplyComponent
|
||||
from src.chat.utils.utils import truncate_message
|
||||
from src.chat.utils.utils import calculate_typing_time
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message_server.api import get_global_api
|
||||
from src.webui.routers.chat.serializers import serialize_message_sequence
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("sender")
|
||||
|
||||
# WebUI 聊天室的消息广播器(延迟导入避免循环依赖)
|
||||
_webui_chat_broadcaster = None
|
||||
_webui_chat_broadcaster: Optional[Tuple[Any, Optional[str]]] = None
|
||||
|
||||
# 虚拟群 ID 前缀(与 chat_routes.py 保持一致)
|
||||
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
|
||||
|
||||
|
||||
# TODO: 重构完成后完成webui相关
|
||||
def get_webui_chat_broadcaster():
|
||||
"""获取 WebUI 聊天室广播器"""
|
||||
def get_webui_chat_broadcaster() -> Tuple[Any, Optional[str]]:
|
||||
"""获取 WebUI 聊天室广播器。
|
||||
|
||||
Returns:
|
||||
Tuple[Any, Optional[str]]: ``(chat_manager, platform_name)`` 二元组;
|
||||
若 WebUI 相关模块不可用,则元素会退化为 ``None``。
|
||||
"""
|
||||
global _webui_chat_broadcaster
|
||||
if _webui_chat_broadcaster is None:
|
||||
try:
|
||||
@@ -38,102 +44,35 @@ def get_webui_chat_broadcaster():
|
||||
|
||||
|
||||
def is_webui_virtual_group(group_id: str) -> bool:
|
||||
"""检查是否是 WebUI 虚拟群"""
|
||||
return group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX)
|
||||
|
||||
|
||||
def parse_message_segments(segment) -> list:
|
||||
"""解析消息段,转换为 WebUI 可用的格式
|
||||
|
||||
参考 NapCat 适配器的消息解析逻辑
|
||||
"""检查是否是 WebUI 虚拟群。
|
||||
|
||||
Args:
|
||||
segment: Seg 消息段对象
|
||||
group_id: 待判断的群 ID。
|
||||
|
||||
Returns:
|
||||
list: 消息段列表,每个元素为 {"type": "...", "data": ...}
|
||||
bool: 若群 ID 属于 WebUI 虚拟群则返回 ``True``。
|
||||
"""
|
||||
|
||||
result = []
|
||||
|
||||
if segment is None:
|
||||
return result
|
||||
|
||||
if segment.type == "seglist":
|
||||
# 处理消息段列表
|
||||
if segment.data:
|
||||
for seg in segment.data:
|
||||
result.extend(parse_message_segments(seg))
|
||||
elif segment.type == "text":
|
||||
# 文本消息
|
||||
if segment.data:
|
||||
result.append({"type": "text", "data": segment.data})
|
||||
elif segment.type == "image":
|
||||
# 图片消息(base64)
|
||||
if segment.data:
|
||||
result.append({"type": "image", "data": f"data:image/png;base64,{segment.data}"})
|
||||
elif segment.type == "emoji":
|
||||
# 表情包消息(base64)
|
||||
if segment.data:
|
||||
result.append({"type": "emoji", "data": f"data:image/gif;base64,{segment.data}"})
|
||||
elif segment.type == "imageurl":
|
||||
# 图片链接消息
|
||||
if segment.data:
|
||||
result.append({"type": "image", "data": segment.data})
|
||||
elif segment.type == "face":
|
||||
# 原生表情
|
||||
result.append({"type": "face", "data": segment.data})
|
||||
elif segment.type == "voice":
|
||||
# 语音消息(base64)
|
||||
if segment.data:
|
||||
result.append({"type": "voice", "data": f"data:audio/wav;base64,{segment.data}"})
|
||||
elif segment.type == "voiceurl":
|
||||
# 语音链接
|
||||
if segment.data:
|
||||
result.append({"type": "voice", "data": segment.data})
|
||||
elif segment.type == "video":
|
||||
# 视频消息(base64)
|
||||
if segment.data:
|
||||
result.append({"type": "video", "data": f"data:video/mp4;base64,{segment.data}"})
|
||||
elif segment.type == "videourl":
|
||||
# 视频链接
|
||||
if segment.data:
|
||||
result.append({"type": "video", "data": segment.data})
|
||||
elif segment.type == "music":
|
||||
# 音乐消息
|
||||
result.append({"type": "music", "data": segment.data})
|
||||
elif segment.type == "file":
|
||||
# 文件消息
|
||||
result.append({"type": "file", "data": segment.data})
|
||||
elif segment.type == "reply":
|
||||
# 回复消息
|
||||
result.append({"type": "reply", "data": segment.data})
|
||||
elif segment.type == "forward":
|
||||
# 转发消息
|
||||
forward_items = []
|
||||
if segment.data:
|
||||
for item in segment.data:
|
||||
forward_items.append(
|
||||
{
|
||||
"content": parse_message_segments(item.get("message_segment", {}))
|
||||
if isinstance(item, dict)
|
||||
else []
|
||||
}
|
||||
)
|
||||
result.append({"type": "forward", "data": forward_items})
|
||||
else:
|
||||
# 未知类型,尝试作为文本处理
|
||||
if segment.data:
|
||||
result.append({"type": "unknown", "original_type": segment.type, "data": str(segment.data)})
|
||||
|
||||
return result
|
||||
return bool(group_id) and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX)
|
||||
|
||||
|
||||
async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||
"""合并后的消息发送函数,包含WS发送和日志记录"""
|
||||
async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
|
||||
"""执行统一的消息发送流程。
|
||||
|
||||
发送顺序为:
|
||||
1. WebUI 特殊链路
|
||||
2. 旧版 ``maim_message`` / API Server 链路
|
||||
|
||||
Args:
|
||||
message: 待发送的内部会话消息。
|
||||
show_log: 是否输出发送成功日志。
|
||||
|
||||
Returns:
|
||||
bool: 是否最终发送成功。
|
||||
"""
|
||||
message_preview = truncate_message(message.processed_plain_text, max_length=200)
|
||||
platform = message.platform
|
||||
group_id = message.session.group_id
|
||||
group_info = message.message_info.group_info
|
||||
group_id = group_info.group_id if group_info is not None else ""
|
||||
|
||||
try:
|
||||
# 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息
|
||||
@@ -146,7 +85,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||
from src.config.config import global_config
|
||||
|
||||
# 解析消息段,获取富文本内容
|
||||
message_segments = parse_message_segments(message.message_segment)
|
||||
message_segments = serialize_message_sequence(message.raw_message)
|
||||
|
||||
# 判断消息类型
|
||||
# 如果只有一个文本段,使用简单的 text 类型
|
||||
@@ -185,7 +124,15 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||
return True
|
||||
|
||||
# Fallback 逻辑: 尝试通过 API Server 发送
|
||||
async def send_with_new_api(legacy_exception=None):
|
||||
async def send_with_new_api(legacy_exception: Optional[Exception] = None) -> bool:
|
||||
"""通过 API Server 回退链路发送消息。
|
||||
|
||||
Args:
|
||||
legacy_exception: 旧发送链已经抛出的异常;若回退也失败,则重新抛出。
|
||||
|
||||
Returns:
|
||||
bool: 回退链路是否发送成功。
|
||||
"""
|
||||
try:
|
||||
from src.config.config import global_config
|
||||
|
||||
@@ -286,10 +233,24 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||
raise e # 重新抛出其他异常
|
||||
|
||||
|
||||
class UniversalMessageSender:
|
||||
"""管理消息的注册、即时处理、发送和存储,并跟踪思考状态。"""
|
||||
async def send_prepared_message_to_platform(message: SessionMessage, show_log: bool = True) -> bool:
|
||||
"""发送一条已完成预处理的消息到底层平台。
|
||||
|
||||
def __init__(self):
|
||||
Args:
|
||||
message: 已经完成回复组件注入、文本处理等预处理的消息对象。
|
||||
show_log: 是否输出发送成功日志。
|
||||
|
||||
Returns:
|
||||
bool: 发送成功时返回 ``True``。
|
||||
"""
|
||||
return await _send_message(message, show_log=show_log)
|
||||
|
||||
|
||||
class UniversalMessageSender:
|
||||
"""旧链与 WebUI 的底层发送器。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化统一消息发送器。"""
|
||||
pass
|
||||
|
||||
async def send_message(
|
||||
@@ -300,18 +261,19 @@ class UniversalMessageSender:
|
||||
reply_message_id: Optional[str] = None,
|
||||
storage_message: bool = True,
|
||||
show_log: bool = True,
|
||||
):
|
||||
"""
|
||||
处理、发送并存储一条消息。
|
||||
) -> bool:
|
||||
"""通过旧链或 WebUI 发送并存储一条消息。
|
||||
|
||||
参数:
|
||||
message: MessageSession 对象,待发送的消息。
|
||||
Args:
|
||||
message: 待发送的内部消息对象。
|
||||
typing: 是否模拟打字等待。
|
||||
set_reply: 是否构建回复引用消息。
|
||||
set_reply: 是否构建引用回复消息。
|
||||
reply_message_id: 被引用消息的 ID。
|
||||
storage_message: 是否在发送成功后写入数据库。
|
||||
show_log: 是否输出发送日志。
|
||||
|
||||
|
||||
用法:
|
||||
- typing=True 时,发送前会有打字等待。
|
||||
Returns:
|
||||
bool: 发送成功时返回 ``True``。
|
||||
"""
|
||||
if not message.message_id:
|
||||
logger.error("消息缺少 message_id,无法发送")
|
||||
@@ -364,7 +326,7 @@ class UniversalMessageSender:
|
||||
)
|
||||
await asyncio.sleep(typing_time)
|
||||
|
||||
sent_msg = await _send_message(message, show_log=show_log)
|
||||
sent_msg = await send_prepared_message_to_platform(message, show_log=show_log)
|
||||
if not sent_msg:
|
||||
return False
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ from typing import Dict, Optional, Tuple
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.logger import get_logger
|
||||
from src.core.component_registry import component_registry, ActionExecutor
|
||||
from src.core.types import ActionInfo
|
||||
from src.plugin_runtime.component_query import ActionExecutor, component_query_service
|
||||
|
||||
logger = get_logger("action_manager")
|
||||
|
||||
@@ -28,7 +28,7 @@ class ActionManager:
|
||||
"""
|
||||
动作管理器,用于管理各种类型的动作
|
||||
|
||||
使用核心组件注册表的 executor-based 模式。
|
||||
使用插件运行时统一查询服务的 executor-based 模式。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
@@ -38,7 +38,7 @@ class ActionManager:
|
||||
self._using_actions: Dict[str, ActionInfo] = {}
|
||||
|
||||
# 初始化时将默认动作加载到使用中的动作
|
||||
self._using_actions = component_registry.get_default_actions()
|
||||
self._using_actions = component_query_service.get_default_actions()
|
||||
|
||||
# === 执行Action方法 ===
|
||||
|
||||
@@ -72,17 +72,17 @@ class ActionManager:
|
||||
Optional[ActionHandle]: 执行句柄,如果动作未注册则返回 None
|
||||
"""
|
||||
try:
|
||||
executor = component_registry.get_action_executor(action_name)
|
||||
executor = component_query_service.get_action_executor(action_name)
|
||||
if not executor:
|
||||
logger.warning(f"{log_prefix} 未找到Action组件: {action_name}")
|
||||
return None
|
||||
|
||||
info = component_registry.get_action_info(action_name)
|
||||
info = component_query_service.get_action_info(action_name)
|
||||
if not info:
|
||||
logger.warning(f"{log_prefix} 未找到Action组件信息: {action_name}")
|
||||
return None
|
||||
|
||||
plugin_config = component_registry.get_plugin_config(info.plugin_name) or {}
|
||||
plugin_config = component_query_service.get_plugin_config(info.plugin_name) or {}
|
||||
|
||||
handle = ActionHandle(
|
||||
executor,
|
||||
@@ -133,5 +133,5 @@ class ActionManager:
|
||||
def restore_actions(self) -> None:
|
||||
"""恢复到默认动作集"""
|
||||
actions_to_restore = list(self._using_actions.keys())
|
||||
self._using_actions = component_registry.get_default_actions()
|
||||
self._using_actions = component_query_service.get_default_actions()
|
||||
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")
|
||||
|
||||
@@ -1,33 +1,36 @@
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import contextlib
|
||||
import json
|
||||
import time
|
||||
import traceback
|
||||
import random
|
||||
import re
|
||||
import contextlib
|
||||
from typing import Dict, Optional, Tuple, List, TYPE_CHECKING, Union
|
||||
from collections import OrderedDict
|
||||
from rich.traceback import install
|
||||
from datetime import datetime
|
||||
import time
|
||||
import traceback
|
||||
|
||||
from json_repair import repair_json
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from rich.traceback import install
|
||||
|
||||
from src.chat.logger.plan_reply_logger import PlanReplyLogger
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.core.types import ActionActivationType, ActionInfo, ComponentType
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.person_info.person_info import Person
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services.message_service import (
|
||||
build_readable_messages_with_id,
|
||||
replace_user_references,
|
||||
get_messages_before_time_in_chat,
|
||||
replace_user_references,
|
||||
translate_pid_to_description,
|
||||
)
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.core.types import ActionActivationType, ActionInfo, ComponentType
|
||||
from src.core.component_registry import component_registry
|
||||
from src.person_info.person_info import Person
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.info_data_model import TargetPersonInfo
|
||||
@@ -634,7 +637,7 @@ class ActionPlanner:
|
||||
current_available_actions_dict = self.action_manager.get_using_actions()
|
||||
|
||||
# 获取完整的动作信息
|
||||
all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore
|
||||
all_registered_actions: Dict[str, ActionInfo] = component_query_service.get_components_by_type( # type: ignore
|
||||
ComponentType.ACTION
|
||||
)
|
||||
current_available_actions = {}
|
||||
|
||||
@@ -17,7 +17,6 @@ from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo as MaimUser
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
|
||||
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
|
||||
from src.chat.utils.utils import get_bot_account, get_chat_type_and_target_info, is_bot_self
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
@@ -51,10 +50,15 @@ class DefaultReplyer:
|
||||
chat_stream: BotChatSession,
|
||||
request_type: str = "replyer",
|
||||
):
|
||||
"""初始化群聊回复器。
|
||||
|
||||
Args:
|
||||
chat_stream: 当前绑定的聊天会话。
|
||||
request_type: LLM 请求类型标识。
|
||||
"""
|
||||
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
|
||||
self.chat_stream = chat_stream
|
||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id)
|
||||
self.heart_fc_sender = UniversalMessageSender()
|
||||
|
||||
from src.chat.tool_executor import ToolExecutor
|
||||
|
||||
@@ -1129,7 +1133,10 @@ class DefaultReplyer:
|
||||
user_id=bot_user_id,
|
||||
user_nickname=global_config.bot.nickname,
|
||||
),
|
||||
additional_config={},
|
||||
additional_config={
|
||||
"platform_io_target_group_id": self.chat_stream.group_id,
|
||||
"platform_io_target_user_id": self.chat_stream.user_id,
|
||||
},
|
||||
),
|
||||
message_segment=message_segment,
|
||||
)
|
||||
|
||||
@@ -16,7 +16,6 @@ from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo as MaimUser
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.utils.utils import get_bot_account, get_chat_type_and_target_info, is_bot_self
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
@@ -47,10 +46,15 @@ class PrivateReplyer:
|
||||
chat_stream: BotChatSession,
|
||||
request_type: str = "replyer",
|
||||
):
|
||||
"""初始化私聊回复器。
|
||||
|
||||
Args:
|
||||
chat_stream: 当前绑定的聊天会话。
|
||||
request_type: LLM 请求类型标识。
|
||||
"""
|
||||
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
|
||||
self.chat_stream = chat_stream
|
||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id)
|
||||
self.heart_fc_sender = UniversalMessageSender()
|
||||
# self.memory_activator = MemoryActivator()
|
||||
|
||||
from src.chat.tool_executor import ToolExecutor
|
||||
@@ -970,7 +974,9 @@ class PrivateReplyer:
|
||||
user_nickname=global_config.bot.nickname,
|
||||
),
|
||||
group_info=None,
|
||||
additional_config={},
|
||||
additional_config={
|
||||
"platform_io_target_user_id": self.chat_stream.user_id,
|
||||
},
|
||||
),
|
||||
message_segment=message_segment,
|
||||
)
|
||||
|
||||
@@ -1,22 +1,20 @@
|
||||
"""
|
||||
工具执行器
|
||||
"""工具执行器。
|
||||
|
||||
独立的工具执行组件,可以直接输入聊天消息内容,
|
||||
自动判断并执行相应的工具,返回结构化的工具执行结果。
|
||||
|
||||
从 src.plugin_system.core.tool_use 迁移,使用新的核心组件注册表。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.core.announcement_manager import global_announcement_manager
|
||||
from src.core.component_registry import component_registry
|
||||
from src.llm_models.payload_content import ToolCall
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
|
||||
logger = get_logger("tool_use")
|
||||
@@ -89,7 +87,7 @@ class ToolExecutor:
|
||||
|
||||
def _get_tool_definitions(self) -> List[Dict[str, Any]]:
|
||||
"""获取 LLM 可用的工具定义列表"""
|
||||
all_tools = component_registry.get_llm_available_tools()
|
||||
all_tools = component_query_service.get_llm_available_tools()
|
||||
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
|
||||
return [info.get_llm_definition() for name, info in all_tools.items() if name not in user_disabled_tools]
|
||||
|
||||
@@ -152,7 +150,7 @@ class ToolExecutor:
|
||||
function_args = tool_call.args or {}
|
||||
function_args["llm_called"] = True
|
||||
|
||||
executor = component_registry.get_tool_executor(function_name)
|
||||
executor = component_query_service.get_tool_executor(function_name)
|
||||
if not executor:
|
||||
logger.warning(f"未知工具名称: {function_name}")
|
||||
return None
|
||||
|
||||
@@ -436,14 +436,14 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
@staticmethod
|
||||
def _fetch_online_time_since(query_start_time: datetime) -> list[tuple[datetime, datetime]]:
|
||||
with get_db_session() as session:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
statement = select(OnlineTime).where(col(OnlineTime.end_timestamp) >= query_start_time)
|
||||
records = session.exec(statement).all()
|
||||
return [(record.start_timestamp, record.end_timestamp) for record in records]
|
||||
|
||||
@staticmethod
|
||||
def _fetch_model_usage_since(query_start_time: datetime) -> list[dict[str, object]]:
|
||||
with get_db_session() as session:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
statement = select(ModelUsage).where(col(ModelUsage.timestamp) >= query_start_time)
|
||||
records = session.exec(statement).all()
|
||||
return [
|
||||
@@ -664,7 +664,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
}
|
||||
|
||||
query_start_timestamp = collect_period[-1][1]
|
||||
with get_db_session() as session:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
statement = select(Messages).where(col(Messages.timestamp) >= query_start_timestamp)
|
||||
messages = session.exec(statement).all()
|
||||
for message in messages:
|
||||
@@ -713,7 +713,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
# 使用 ActionRecords 中的 reply 动作次数作为回复数基准
|
||||
try:
|
||||
action_query_start_timestamp = collect_period[-1][1]
|
||||
with get_db_session() as session:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
statement = select(ActionRecord).where(col(ActionRecord.timestamp) >= action_query_start_timestamp)
|
||||
actions = session.exec(statement).all()
|
||||
for action in actions:
|
||||
@@ -1750,7 +1750,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
# 查询消息记录
|
||||
query_start_timestamp = start_time.timestamp()
|
||||
with get_db_session() as session:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
statement = select(Messages).where(col(Messages.timestamp) >= start_time)
|
||||
messages = session.exec(statement).all()
|
||||
for message in messages:
|
||||
@@ -2131,7 +2131,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
# 查询消息记录
|
||||
query_start_timestamp = start_time.timestamp()
|
||||
with get_db_session() as session:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
statement = select(Messages).where(col(Messages.timestamp) >= start_time)
|
||||
messages = session.exec(statement).all()
|
||||
for message in messages:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from typing import Any, List, Mapping, Optional, Sequence
|
||||
|
||||
import json
|
||||
|
||||
@@ -15,6 +15,76 @@ class GroupCardnameInfo:
|
||||
group_cardname: str
|
||||
|
||||
|
||||
def _normalize_group_cardname_item(raw_item: Mapping[str, Any]) -> Optional[GroupCardnameInfo]:
|
||||
"""将单条群名片数据规范化为统一结构。
|
||||
|
||||
Args:
|
||||
raw_item: 原始群名片字典,必须包含 `group_id` 和 `group_cardname`。
|
||||
|
||||
Returns:
|
||||
Optional[GroupCardnameInfo]: 规范化后的群名片信息;若数据不完整则返回 ``None``。
|
||||
"""
|
||||
group_id = str(raw_item.get("group_id") or "").strip()
|
||||
group_cardname = str(raw_item.get("group_cardname") or "").strip()
|
||||
if not group_id or not group_cardname:
|
||||
return None
|
||||
return GroupCardnameInfo(group_id=group_id, group_cardname=group_cardname)
|
||||
|
||||
|
||||
def parse_group_cardname_json(group_cardname_json: Optional[str]) -> Optional[List[GroupCardnameInfo]]:
|
||||
"""解析数据库中的群名片 JSON 字段。
|
||||
|
||||
Args:
|
||||
group_cardname_json: 数据库存储的群名片 JSON 字符串。
|
||||
|
||||
Returns:
|
||||
Optional[List[GroupCardnameInfo]]: 解析并规范化后的群名片列表;若字段为空或无有效项则返回 ``None``。
|
||||
|
||||
Raises:
|
||||
json.JSONDecodeError: 当 JSON 文本格式非法时抛出。
|
||||
TypeError: 当输入值类型不符合 `json.loads()` 要求时抛出。
|
||||
"""
|
||||
if not group_cardname_json:
|
||||
return None
|
||||
|
||||
raw_items = json.loads(group_cardname_json)
|
||||
if not isinstance(raw_items, list):
|
||||
return None
|
||||
|
||||
normalized_items: List[GroupCardnameInfo] = []
|
||||
for raw_item in raw_items:
|
||||
if not isinstance(raw_item, Mapping):
|
||||
continue
|
||||
if normalized_item := _normalize_group_cardname_item(raw_item):
|
||||
normalized_items.append(normalized_item)
|
||||
|
||||
return normalized_items or None
|
||||
|
||||
|
||||
def dump_group_cardname_records(
|
||||
group_cardname_records: Optional[Sequence[GroupCardnameInfo | Mapping[str, Any]]],
|
||||
) -> str:
|
||||
"""将群名片列表序列化为数据库使用的标准 JSON 字符串。
|
||||
|
||||
Args:
|
||||
group_cardname_records: 待序列化的群名片列表,支持 `GroupCardnameInfo`
|
||||
对象和包含 `group_id` / `group_cardname` 的字典。
|
||||
|
||||
Returns:
|
||||
str: 统一使用 `group_cardname` 键名的 JSON 字符串。
|
||||
"""
|
||||
normalized_items: List[GroupCardnameInfo] = []
|
||||
for raw_item in group_cardname_records or []:
|
||||
if isinstance(raw_item, GroupCardnameInfo):
|
||||
normalized_items.append(raw_item)
|
||||
continue
|
||||
if isinstance(raw_item, Mapping):
|
||||
if normalized_item := _normalize_group_cardname_item(raw_item):
|
||||
normalized_items.append(normalized_item)
|
||||
|
||||
return json.dumps([asdict(item) for item in normalized_items], ensure_ascii=False)
|
||||
|
||||
|
||||
class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -58,9 +128,16 @@ class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]):
|
||||
"""最后一次被认识的时间"""
|
||||
|
||||
@classmethod
|
||||
def from_db_instance(cls, db_record: "PersonInfo"):
|
||||
nickname_json = json.loads(db_record.group_cardname) if db_record.group_cardname else None
|
||||
group_cardname_list = [GroupCardnameInfo(**item) for item in nickname_json] if nickname_json else None
|
||||
def from_db_instance(cls, db_record: "PersonInfo") -> "MaiPersonInfo":
|
||||
"""从数据库记录构造人物信息数据模型。
|
||||
|
||||
Args:
|
||||
db_record: 数据库中的人物信息记录。
|
||||
|
||||
Returns:
|
||||
MaiPersonInfo: 转换后的数据模型对象。
|
||||
"""
|
||||
group_cardname_list = parse_group_cardname_json(db_record.group_cardname)
|
||||
memory_points = json.loads(db_record.memory_points) if db_record.memory_points else None
|
||||
return cls(
|
||||
is_known=db_record.is_known,
|
||||
@@ -78,9 +155,12 @@ class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]):
|
||||
)
|
||||
|
||||
def to_db_instance(self) -> "PersonInfo":
|
||||
group_cardname = (
|
||||
json.dumps([gc.__dict__ for gc in self.group_cardname_list]) if self.group_cardname_list else None
|
||||
)
|
||||
"""将当前数据模型转换为数据库记录对象。
|
||||
|
||||
Returns:
|
||||
PersonInfo: 可直接写入数据库的模型实例。
|
||||
"""
|
||||
group_cardname = dump_group_cardname_records(self.group_cardname_list)
|
||||
return PersonInfo(
|
||||
is_known=self.is_known,
|
||||
person_id=self.person_id,
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from typing import Optional
|
||||
from sqlalchemy import Column, Float, Enum as SQLEnum, DateTime
|
||||
from sqlmodel import SQLModel, Field, LargeBinary
|
||||
from enum import Enum
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Column, DateTime, Enum as SQLEnum, Float
|
||||
from sqlmodel import Field, LargeBinary, SQLModel
|
||||
|
||||
|
||||
class ModelUser(str, Enum):
|
||||
@@ -172,8 +173,8 @@ class Expression(SQLModel, table=True):
|
||||
|
||||
id: Optional[int] = Field(default=None, primary_key=True) # 自增主键
|
||||
|
||||
situation: str = Field(index=True, max_length=255, primary_key=True) # 情景
|
||||
style: str = Field(index=True, max_length=255, primary_key=True) # 风格
|
||||
situation: str = Field(index=True, max_length=255) # 情景
|
||||
style: str = Field(index=True, max_length=255) # 风格
|
||||
|
||||
# context: str # 上下文
|
||||
# up_content: str
|
||||
@@ -200,7 +201,7 @@ class Jargon(SQLModel, table=True):
|
||||
|
||||
id: Optional[int] = Field(default=None, primary_key=True) # 自增主键
|
||||
|
||||
content: str = Field(index=True, max_length=255, primary_key=True) # 黑话内容
|
||||
content: str = Field(index=True, max_length=255) # 黑话内容
|
||||
raw_content: Optional[str] = Field(default=None, nullable=True) # 原始内容,未处理的黑话内容,为List[str]
|
||||
|
||||
meaning: str # 黑话含义
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
# 定义模块颜色映射
|
||||
from typing import Optional, Tuple, Dict
|
||||
|
||||
import itertools
|
||||
import os
|
||||
import sys
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
|
||||
MODULE_COLORS: Dict[str, Tuple[str, Optional[str], bool]] = {
|
||||
@@ -54,15 +53,19 @@ MODULE_COLORS: Dict[str, Tuple[str, Optional[str], bool]] = {
|
||||
"component_registry": ("#ffaf00", None, False),
|
||||
"plugin_runtime.integration": ("#d75f00", None, False),
|
||||
"plugin_runtime.host.supervisor": ("#ff5f00", None, False),
|
||||
"plugin_runtime.host.runner_manager": ("#ff5f00", None, False),
|
||||
"plugin_runtime.host.rpc_server": ("#ff8700", None, False),
|
||||
"plugin_runtime.host.component_registry": ("#ffaf00", None, False),
|
||||
"plugin_runtime.host.capability_service": ("#ffd700", None, False),
|
||||
"plugin_runtime.host.event_dispatcher": ("#87d700", None, False),
|
||||
"plugin_runtime.host.workflow_executor": ("#5fd7af", None, False),
|
||||
"plugin_runtime.host.hook_dispatcher": ("#5fd7af", None, False),
|
||||
"plugin_runtime.host.message_gateway": ("#5fd7d7", None, False),
|
||||
"plugin_runtime.host.message_utils": ("#5faf87", None, False),
|
||||
"plugin_runtime.runner.main": ("#d787ff", None, False),
|
||||
"plugin_runtime.runner.rpc_client": ("#8787ff", None, False),
|
||||
"plugin_runtime.runner.manifest_validator": ("#5fafff", None, False),
|
||||
"plugin_runtime.runner.plugin_loader": ("#00afaf", None, False),
|
||||
"plugin.maibot-team.napcat-adapter": ("#00af87", None, False),
|
||||
"webui": ("#5f87ff", None, False),
|
||||
"webui.app": ("#5f87d7", None, False),
|
||||
"webui.api": ("#5fafff", None, False),
|
||||
@@ -157,15 +160,20 @@ MODULE_ALIASES = {
|
||||
"chat_history_summarizer": "聊天概括器",
|
||||
"plugin_runtime.integration": "IPC插件系统",
|
||||
"plugin_runtime.host.supervisor": "插件监督器",
|
||||
"plugin_runtime.host.runner_manager": "插件监督器",
|
||||
"plugin_runtime.host.rpc_server": "插件RPC服务",
|
||||
"plugin_runtime.host.component_registry": "插件组件注册",
|
||||
"plugin_runtime.host.capability_service": "插件能力服务",
|
||||
"plugin_runtime.host.event_dispatcher": "插件事件分发",
|
||||
"plugin_runtime.host.hook_dispatcher": "插件Hook分发",
|
||||
"plugin_runtime.host.message_gateway": "插件消息网关",
|
||||
"plugin_runtime.host.message_utils": "插件消息工具",
|
||||
"plugin_runtime.host.workflow_executor": "插件工作流",
|
||||
"plugin_runtime.runner.main": "插件运行器",
|
||||
"plugin_runtime.runner.rpc_client": "插件RPC客户端",
|
||||
"plugin_runtime.runner.manifest_validator": "插件清单校验",
|
||||
"plugin_runtime.runner.plugin_loader": "插件加载器",
|
||||
"plugin.maibot-team.napcat-adapter": "NapCat内置适配器",
|
||||
"webui": "WebUI",
|
||||
"webui.app": "WebUI应用",
|
||||
"webui.api": "WebUI接口",
|
||||
|
||||
@@ -21,7 +21,7 @@ class Server:
|
||||
self._server: Optional[UvicornServer] = None
|
||||
self.set_address(host, port)
|
||||
|
||||
def register_router(self, router: APIRouter, prefix: str = ""):
|
||||
def register_router(self, router: APIRouter, prefix: str = ""):
|
||||
"""注册路由
|
||||
|
||||
APIRouter 用于对相关的路由端点进行分组和模块化管理:
|
||||
|
||||
@@ -5,13 +5,22 @@ import hashlib
|
||||
|
||||
class SessionUtils:
|
||||
@staticmethod
|
||||
def calculate_session_id(platform: str, *, user_id: Optional[str] = None, group_id: Optional[str] = None) -> str:
|
||||
def calculate_session_id(
|
||||
platform: str,
|
||||
*,
|
||||
user_id: Optional[str] = None,
|
||||
group_id: Optional[str] = None,
|
||||
account_id: Optional[str] = None,
|
||||
scope: Optional[str] = None,
|
||||
) -> str:
|
||||
"""计算session_id
|
||||
|
||||
Args:
|
||||
platform: 平台名称
|
||||
user_id: 用户ID(如果是私聊)
|
||||
group_id: 群ID(如果是群聊)
|
||||
account_id: 当前平台账号 ID,可选
|
||||
scope: 当前路由作用域,可选
|
||||
Returns:
|
||||
str: 计算得到的会话ID
|
||||
Raises:
|
||||
@@ -19,8 +28,15 @@ class SessionUtils:
|
||||
"""
|
||||
if not user_id and not group_id:
|
||||
raise ValueError("UserID 或 GroupID 必须提供其一")
|
||||
|
||||
route_components = []
|
||||
if account_id:
|
||||
route_components.append(f"account:{account_id}")
|
||||
if scope:
|
||||
route_components.append(f"scope:{scope}")
|
||||
|
||||
if group_id:
|
||||
components = [platform, group_id]
|
||||
components = [platform, *route_components, group_id]
|
||||
else:
|
||||
components = [platform, user_id, "private"]
|
||||
components = [platform, *route_components, user_id, "private"]
|
||||
return hashlib.md5("_".join(components).encode()).hexdigest()
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any, Callable, Mapping, Sequence, TypeVar
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import inspect
|
||||
import sys
|
||||
|
||||
import tomlkit
|
||||
@@ -61,6 +62,7 @@ MODEL_CONFIG_VERSION: str = "1.12.0"
|
||||
logger = get_logger("config")
|
||||
|
||||
T = TypeVar("T", bound="ConfigBase")
|
||||
ConfigReloadCallback = Callable[[Sequence[str]], object] | Callable[[], object]
|
||||
|
||||
|
||||
class Config(ConfigBase):
|
||||
@@ -190,7 +192,7 @@ class ConfigManager:
|
||||
self.global_config: Config | None = None
|
||||
self.model_config: ModelConfig | None = None
|
||||
self._reload_lock: asyncio.Lock = asyncio.Lock()
|
||||
self._reload_callbacks: list[Callable[[], object]] = []
|
||||
self._reload_callbacks: list[ConfigReloadCallback] = []
|
||||
self._file_watcher: FileWatcher | None = None
|
||||
self._file_watcher_subscription_id: str | None = None
|
||||
self._hot_reload_min_interval_s: float = 1.0
|
||||
@@ -226,16 +228,125 @@ class ConfigManager:
|
||||
raise RuntimeError(t("config.model_not_initialized"))
|
||||
return self.model_config
|
||||
|
||||
def register_reload_callback(self, callback: Callable[[], object]) -> None:
|
||||
def register_reload_callback(self, callback: ConfigReloadCallback) -> None:
|
||||
"""注册配置热重载回调。
|
||||
|
||||
Args:
|
||||
callback: 配置热重载回调。允许无参回调,也允许接收
|
||||
``Sequence[str]`` 类型的变更范围列表。
|
||||
"""
|
||||
|
||||
self._reload_callbacks.append(callback)
|
||||
|
||||
def unregister_reload_callback(self, callback: Callable[[], object]) -> None:
|
||||
def unregister_reload_callback(self, callback: ConfigReloadCallback) -> None:
|
||||
"""注销配置热重载回调。
|
||||
|
||||
Args:
|
||||
callback: 先前注册过的回调对象。
|
||||
"""
|
||||
|
||||
try:
|
||||
self._reload_callbacks.remove(callback)
|
||||
except ValueError:
|
||||
return
|
||||
|
||||
async def reload_config(self) -> bool:
|
||||
@staticmethod
|
||||
def _normalize_changed_scopes(changed_scopes: Sequence[str] | None) -> tuple[str, ...]:
|
||||
"""规范化配置变更范围列表。
|
||||
|
||||
Args:
|
||||
changed_scopes: 原始配置变更范围。
|
||||
|
||||
Returns:
|
||||
tuple[str, ...]: 去重后的配置变更范围元组。
|
||||
"""
|
||||
|
||||
if not changed_scopes:
|
||||
return ("bot", "model")
|
||||
|
||||
normalized_scopes: list[str] = []
|
||||
for scope in changed_scopes:
|
||||
normalized_scope = str(scope or "").strip().lower()
|
||||
if normalized_scope not in {"bot", "model"}:
|
||||
continue
|
||||
if normalized_scope not in normalized_scopes:
|
||||
normalized_scopes.append(normalized_scope)
|
||||
return tuple(normalized_scopes)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_changed_scopes(changes: Sequence[FileChange]) -> tuple[str, ...]:
|
||||
"""根据文件变更列表推断配置变更范围。
|
||||
|
||||
Args:
|
||||
changes: 文件监听器返回的变更列表。
|
||||
|
||||
Returns:
|
||||
tuple[str, ...]: 命中的配置变更范围元组。
|
||||
"""
|
||||
|
||||
changed_scopes: list[str] = []
|
||||
for change in changes:
|
||||
file_name = change.path.name
|
||||
if file_name == "bot_config.toml" and "bot" not in changed_scopes:
|
||||
changed_scopes.append("bot")
|
||||
if file_name == "model_config.toml" and "model" not in changed_scopes:
|
||||
changed_scopes.append("model")
|
||||
return tuple(changed_scopes)
|
||||
|
||||
@staticmethod
|
||||
def _callback_accepts_scopes(callback: ConfigReloadCallback) -> bool:
|
||||
"""判断回调是否接收配置变更范围参数。
|
||||
|
||||
Args:
|
||||
callback: 待检测的回调对象。
|
||||
|
||||
Returns:
|
||||
bool: 若回调可接收一个位置参数或可变位置参数,则返回 ``True``。
|
||||
"""
|
||||
|
||||
try:
|
||||
parameters = inspect.signature(callback).parameters.values()
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
|
||||
positional_params = {
|
||||
inspect.Parameter.POSITIONAL_ONLY,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
}
|
||||
for parameter in parameters:
|
||||
if parameter.kind == inspect.Parameter.VAR_POSITIONAL:
|
||||
return True
|
||||
if parameter.kind in positional_params:
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _invoke_reload_callback(
|
||||
self,
|
||||
callback: ConfigReloadCallback,
|
||||
changed_scopes: Sequence[str],
|
||||
) -> None:
|
||||
"""执行单个配置热重载回调。
|
||||
|
||||
Args:
|
||||
callback: 要执行的回调对象。
|
||||
changed_scopes: 本次热重载命中的配置范围。
|
||||
"""
|
||||
|
||||
result = callback(changed_scopes) if self._callback_accepts_scopes(callback) else callback()
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
|
||||
async def reload_config(self, changed_scopes: Sequence[str] | None = None) -> bool:
|
||||
"""重新加载主配置和模型配置。
|
||||
|
||||
Args:
|
||||
changed_scopes: 本次触发热重载的配置范围。
|
||||
|
||||
Returns:
|
||||
bool: 是否重载成功。
|
||||
"""
|
||||
|
||||
normalized_scopes = self._normalize_changed_scopes(changed_scopes)
|
||||
async with self._reload_lock:
|
||||
try:
|
||||
global_config_new, global_updated = load_config_from_file(
|
||||
@@ -265,9 +376,7 @@ class ConfigManager:
|
||||
|
||||
for callback in list(self._reload_callbacks):
|
||||
try:
|
||||
result = callback()
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
await self._invoke_reload_callback(callback, normalized_scopes)
|
||||
except Exception as exc:
|
||||
logger.warning(t("config.reload_callback_failed", error=exc))
|
||||
return True
|
||||
@@ -312,6 +421,12 @@ class ConfigManager:
|
||||
self._file_watcher = None
|
||||
|
||||
async def _handle_file_changes(self, changes: Sequence[FileChange]) -> None:
|
||||
"""处理主配置与模型配置文件变更。
|
||||
|
||||
Args:
|
||||
changes: 当前批次收集到的文件变更列表。
|
||||
"""
|
||||
|
||||
if not changes:
|
||||
return
|
||||
now_monotonic = asyncio.get_running_loop().time()
|
||||
@@ -321,7 +436,11 @@ class ConfigManager:
|
||||
self._last_hot_reload_monotonic = now_monotonic
|
||||
logger.info(t("config.file_change_detected"))
|
||||
try:
|
||||
await asyncio.wait_for(self.reload_config(), timeout=self._hot_reload_timeout_s)
|
||||
changed_scopes = self._resolve_changed_scopes(changes)
|
||||
await asyncio.wait_for(
|
||||
self.reload_config(changed_scopes=changed_scopes),
|
||||
timeout=self._hot_reload_timeout_s,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(t("config.reload_timeout", timeout_seconds=self._hot_reload_timeout_s))
|
||||
|
||||
|
||||
@@ -1633,24 +1633,6 @@ class PluginRuntimeConfig(ConfigBase):
|
||||
)
|
||||
"""启用插件系统"""
|
||||
|
||||
builtin_plugin_dir: str = Field(
|
||||
default="src/plugins/built_in",
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "folder",
|
||||
},
|
||||
)
|
||||
"""内置插件目录(相对于项目根目录)"""
|
||||
|
||||
thirdparty_plugin_dir: str = Field(
|
||||
default="plugins",
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "folder-open",
|
||||
},
|
||||
)
|
||||
"""第三方插件目录(相对于项目根目录)"""
|
||||
|
||||
health_check_interval_sec: float = Field(
|
||||
default=30.0,
|
||||
json_schema_extra={
|
||||
@@ -1678,14 +1660,14 @@ class PluginRuntimeConfig(ConfigBase):
|
||||
)
|
||||
"""等待 Runner 子进程启动并注册的超时时间(秒)"""
|
||||
|
||||
workflow_blocking_timeout_sec: float = Field(
|
||||
default=120.0,
|
||||
hook_blocking_timeout_sec: float = Field(
|
||||
default=30,
|
||||
json_schema_extra={
|
||||
"x-widget": "number",
|
||||
"x-icon": "timer",
|
||||
},
|
||||
)
|
||||
"""Workflow 阻塞步骤的全局超时上限(秒)"""
|
||||
"""Hook 阻塞步骤的全局超时上限(秒)"""
|
||||
|
||||
ipc_socket_path: str = Field(
|
||||
default="",
|
||||
@@ -1694,4 +1676,7 @@ class PluginRuntimeConfig(ConfigBase):
|
||||
"x-icon": "link",
|
||||
},
|
||||
)
|
||||
"""_wrap_\n 自定义 IPC Socket 路径(仅 Linux/macOS 生效)\n 留空则自动生成临时路径"""
|
||||
"""
|
||||
自定义 IPC Socket 路径(仅 Linux/macOS 生效)
|
||||
留空则自动生成临时路径
|
||||
"""
|
||||
|
||||
@@ -1,239 +0,0 @@
|
||||
"""
|
||||
核心组件注册表
|
||||
|
||||
面向最终架构的组件管理:
|
||||
- Action:注册 ActionInfo + 执行器(本地 callable 或 IPC 路由)
|
||||
- Command:注册正则模式 + 执行器
|
||||
- Tool:注册工具定义 + 执行器
|
||||
|
||||
不依赖任何插件基类,组件执行器是纯 async callable。
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Any, Awaitable, Callable, Dict, Optional, Pattern, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.core.types import (
|
||||
ActionInfo,
|
||||
CommandInfo,
|
||||
ComponentInfo,
|
||||
ComponentType,
|
||||
ToolInfo,
|
||||
)
|
||||
|
||||
logger = get_logger("component_registry")
|
||||
|
||||
# 执行器类型
|
||||
ActionExecutor = Callable[..., Awaitable[Any]]
|
||||
CommandExecutor = Callable[..., Awaitable[Tuple[bool, Optional[str], bool]]]
|
||||
ToolExecutor = Callable[..., Awaitable[Any]]
|
||||
|
||||
|
||||
class ComponentRegistry:
|
||||
"""核心组件注册表
|
||||
|
||||
管理 action、command、tool 三类组件。
|
||||
每个组件由「元信息 + 执行器」构成,执行器是 async callable,
|
||||
不需要继承任何基类。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# Action 注册
|
||||
self._actions: Dict[str, ActionInfo] = {}
|
||||
self._action_executors: Dict[str, ActionExecutor] = {}
|
||||
self._default_actions: Dict[str, ActionInfo] = {}
|
||||
|
||||
# Command 注册
|
||||
self._commands: Dict[str, CommandInfo] = {}
|
||||
self._command_executors: Dict[str, CommandExecutor] = {}
|
||||
self._command_patterns: Dict[Pattern, str] = {}
|
||||
|
||||
# Tool 注册
|
||||
self._tools: Dict[str, ToolInfo] = {}
|
||||
self._tool_executors: Dict[str, ToolExecutor] = {}
|
||||
self._llm_available_tools: Dict[str, ToolInfo] = {}
|
||||
|
||||
# 插件配置(plugin_name -> config dict)
|
||||
self._plugin_configs: Dict[str, dict] = {}
|
||||
|
||||
logger.info("核心组件注册表初始化完成")
|
||||
|
||||
# ========== Action ==========
|
||||
|
||||
def register_action(
|
||||
self,
|
||||
info: ActionInfo,
|
||||
executor: ActionExecutor,
|
||||
) -> bool:
|
||||
"""注册 action
|
||||
|
||||
Args:
|
||||
info: action 元信息
|
||||
executor: 执行器,async callable
|
||||
"""
|
||||
name = info.name
|
||||
if name in self._actions:
|
||||
logger.warning(f"Action {name} 已存在,跳过注册")
|
||||
return False
|
||||
|
||||
self._actions[name] = info
|
||||
self._action_executors[name] = executor
|
||||
|
||||
if info.enabled:
|
||||
self._default_actions[name] = info
|
||||
|
||||
logger.debug(f"注册 Action: {name}")
|
||||
return True
|
||||
|
||||
def get_action_info(self, name: str) -> Optional[ActionInfo]:
|
||||
return self._actions.get(name)
|
||||
|
||||
def get_action_executor(self, name: str) -> Optional[ActionExecutor]:
|
||||
return self._action_executors.get(name)
|
||||
|
||||
def get_default_actions(self) -> Dict[str, ActionInfo]:
|
||||
return self._default_actions.copy()
|
||||
|
||||
def get_all_actions(self) -> Dict[str, ActionInfo]:
|
||||
return self._actions.copy()
|
||||
|
||||
def remove_action(self, name: str) -> bool:
|
||||
if name not in self._actions:
|
||||
return False
|
||||
del self._actions[name]
|
||||
self._action_executors.pop(name, None)
|
||||
self._default_actions.pop(name, None)
|
||||
logger.debug(f"移除 Action: {name}")
|
||||
return True
|
||||
|
||||
# ========== Command ==========
|
||||
|
||||
def register_command(
|
||||
self,
|
||||
info: CommandInfo,
|
||||
executor: CommandExecutor,
|
||||
) -> bool:
|
||||
"""注册 command"""
|
||||
name = info.name
|
||||
if name in self._commands:
|
||||
logger.warning(f"Command {name} 已存在,跳过注册")
|
||||
return False
|
||||
|
||||
self._commands[name] = info
|
||||
self._command_executors[name] = executor
|
||||
|
||||
if info.enabled and info.command_pattern:
|
||||
pattern = re.compile(info.command_pattern, re.IGNORECASE | re.DOTALL)
|
||||
self._command_patterns[pattern] = name
|
||||
|
||||
logger.debug(f"注册 Command: {name}")
|
||||
return True
|
||||
|
||||
def find_command_by_text(self, text: str) -> Optional[Tuple[CommandExecutor, dict, CommandInfo]]:
|
||||
"""根据文本查找匹配的命令
|
||||
|
||||
Returns:
|
||||
(executor, matched_groups, command_info) 或 None
|
||||
"""
|
||||
candidates = [p for p in self._command_patterns if p.match(text)]
|
||||
if not candidates:
|
||||
return None
|
||||
if len(candidates) > 1:
|
||||
logger.warning(f"文本 '{text[:50]}' 匹配到多个命令模式,使用第一个")
|
||||
pattern = candidates[0]
|
||||
name = self._command_patterns[pattern]
|
||||
return (
|
||||
self._command_executors[name],
|
||||
pattern.match(text).groupdict(), # type: ignore
|
||||
self._commands[name],
|
||||
)
|
||||
|
||||
def remove_command(self, name: str) -> bool:
|
||||
if name not in self._commands:
|
||||
return False
|
||||
del self._commands[name]
|
||||
self._command_executors.pop(name, None)
|
||||
self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != name}
|
||||
logger.debug(f"移除 Command: {name}")
|
||||
return True
|
||||
|
||||
# ========== Tool ==========
|
||||
|
||||
def register_tool(
|
||||
self,
|
||||
info: ToolInfo,
|
||||
executor: ToolExecutor,
|
||||
) -> bool:
|
||||
"""注册 tool"""
|
||||
name = info.name
|
||||
if name in self._tools:
|
||||
logger.warning(f"Tool {name} 已存在,跳过注册")
|
||||
return False
|
||||
|
||||
self._tools[name] = info
|
||||
self._tool_executors[name] = executor
|
||||
|
||||
if info.enabled:
|
||||
self._llm_available_tools[name] = info
|
||||
|
||||
logger.debug(f"注册 Tool: {name}")
|
||||
return True
|
||||
|
||||
def get_tool_info(self, name: str) -> Optional[ToolInfo]:
|
||||
return self._tools.get(name)
|
||||
|
||||
def get_tool_executor(self, name: str) -> Optional[ToolExecutor]:
|
||||
return self._tool_executors.get(name)
|
||||
|
||||
def get_llm_available_tools(self) -> Dict[str, ToolInfo]:
|
||||
return self._llm_available_tools.copy()
|
||||
|
||||
def get_all_tools(self) -> Dict[str, ToolInfo]:
|
||||
return self._tools.copy()
|
||||
|
||||
def remove_tool(self, name: str) -> bool:
|
||||
if name not in self._tools:
|
||||
return False
|
||||
del self._tools[name]
|
||||
self._tool_executors.pop(name, None)
|
||||
self._llm_available_tools.pop(name, None)
|
||||
logger.debug(f"移除 Tool: {name}")
|
||||
return True
|
||||
|
||||
# ========== 通用查询 ==========
|
||||
|
||||
def get_component_info(self, name: str, component_type: ComponentType) -> Optional[ComponentInfo]:
|
||||
"""获取组件元信息"""
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
return self._actions.get(name)
|
||||
case ComponentType.COMMAND:
|
||||
return self._commands.get(name)
|
||||
case ComponentType.TOOL:
|
||||
return self._tools.get(name)
|
||||
case _:
|
||||
return None
|
||||
|
||||
def get_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
|
||||
"""获取某类型的所有组件"""
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
return dict(self._actions)
|
||||
case ComponentType.COMMAND:
|
||||
return dict(self._commands)
|
||||
case ComponentType.TOOL:
|
||||
return dict(self._tools)
|
||||
case _:
|
||||
return {}
|
||||
|
||||
# ========== 插件配置 ==========
|
||||
|
||||
def set_plugin_config(self, plugin_name: str, config: dict) -> None:
|
||||
self._plugin_configs[plugin_name] = config
|
||||
|
||||
def get_plugin_config(self, plugin_name: str) -> Optional[dict]:
|
||||
return self._plugin_configs.get(plugin_name)
|
||||
|
||||
|
||||
# 全局单例
|
||||
component_registry = ComponentRegistry()
|
||||
@@ -3,15 +3,15 @@
|
||||
|
||||
功能:
|
||||
1. 定期随机选取指定数量的表达方式
|
||||
2. 使用LLM进行评估
|
||||
2. 使用 LLM 进行评估
|
||||
3. 通过评估的:rejected=0, checked=1
|
||||
4. 未通过评估的:rejected=1, checked=1
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
import asyncio
|
||||
import json
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
@@ -146,7 +146,8 @@ class ExpressionAutoCheckTask(AsyncTask):
|
||||
选中的表达方式列表
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
# 这里只做查询,避免退出上下文时自动提交导致 ORM 实例过期。
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
statement = select(Expression)
|
||||
all_expressions = session.exec(statement).all()
|
||||
|
||||
|
||||
@@ -329,7 +329,13 @@ class ExpressionLearner:
|
||||
return filtered_expressions
|
||||
|
||||
# ====== DB 操作相关 ======
|
||||
async def _upsert_expression_to_db(self, situation: str, style: str):
|
||||
async def _upsert_expression_to_db(self, situation: str, style: str) -> None:
|
||||
"""将表达方式写入数据库,存在时更新,不存在时新增。
|
||||
|
||||
Args:
|
||||
situation: 表达方式对应的使用情景。
|
||||
style: 表达方式风格。
|
||||
"""
|
||||
expr, similarity = self._find_similar_expression(situation) or (None, 0)
|
||||
if expr:
|
||||
# 根据相似度决定是否使用 LLM 总结
|
||||
@@ -340,7 +346,13 @@ class ExpressionLearner:
|
||||
# 没有找到匹配的记录,创建新记录
|
||||
self._create_expression(situation, style)
|
||||
|
||||
def _create_expression(self, situation: str, style: str):
|
||||
def _create_expression(self, situation: str, style: str) -> None:
|
||||
"""创建新的表达方式记录。
|
||||
|
||||
Args:
|
||||
situation: 表达方式对应的使用情景。
|
||||
style: 表达方式风格。
|
||||
"""
|
||||
content_list = [situation]
|
||||
try:
|
||||
with get_db_session() as db:
|
||||
@@ -353,6 +365,7 @@ class ExpressionLearner:
|
||||
last_active_time=datetime.now(),
|
||||
)
|
||||
db.add(new_expr)
|
||||
db.flush()
|
||||
except Exception as e:
|
||||
logger.error(f"创建表达方式失败: {e}")
|
||||
|
||||
@@ -448,25 +461,43 @@ class ExpressionLearner:
|
||||
def _find_similar_expression(
|
||||
self, situation: str, similarity_threshold: float = 0.75
|
||||
) -> Optional[Tuple[MaiExpression, float]]:
|
||||
"""在数据库中查找相似的表达方式"""
|
||||
"""在数据库中查找相似的表达方式。
|
||||
|
||||
Args:
|
||||
situation: 当前待匹配的情景描述。
|
||||
similarity_threshold: 认定为相似表达方式的最低相似度阈值。
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[MaiExpression, float]]: 若找到最相似的表达方式,则返回
|
||||
``(表达方式对象, 相似度)``;否则返回 ``None``。
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
statement = select(Expression).filter_by(session_id=self.session_id)
|
||||
expressions = session.exec(statement).all()
|
||||
|
||||
best_match: Optional[Expression] = None
|
||||
best_similarity = 0.0
|
||||
best_match: Optional[MaiExpression] = None
|
||||
best_similarity = 0.0
|
||||
|
||||
for db_expression in expressions:
|
||||
expression = MaiExpression.from_db_instance(db_expression)
|
||||
candidate_situations = [expression.situation, *expression.content]
|
||||
for candidate_situation in candidate_situations:
|
||||
normalized_candidate_situation = candidate_situation.strip()
|
||||
if not normalized_candidate_situation:
|
||||
continue
|
||||
similarity = difflib.SequenceMatcher(
|
||||
None,
|
||||
situation,
|
||||
normalized_candidate_situation,
|
||||
).ratio()
|
||||
if similarity > similarity_threshold and similarity > best_similarity:
|
||||
best_similarity = similarity
|
||||
best_match = expression
|
||||
|
||||
for expr in expressions:
|
||||
content_list = json.loads(expr.content_list)
|
||||
for situation in content_list:
|
||||
similarity = difflib.SequenceMatcher(None, situation, expr.situation).ratio()
|
||||
if similarity > similarity_threshold and similarity > best_similarity:
|
||||
best_similarity = similarity
|
||||
best_match = expr
|
||||
if best_match:
|
||||
logger.debug(f"找到相似表达方式情景 [ID: {best_match.id}],相似度: {best_similarity:.2f}")
|
||||
return MaiExpression.from_db_instance(best_match), best_similarity
|
||||
logger.debug(f"找到相似表达方式情景 [ID: {best_match.item_id}],相似度: {best_similarity:.2f}")
|
||||
return best_match, best_similarity
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查找相似表达方式失败: {e}")
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
from collections import OrderedDict
|
||||
from json_repair import repair_json
|
||||
from sqlmodel import select
|
||||
from typing import List, Optional, Dict, Callable, TypedDict, Set
|
||||
from typing import Callable, Dict, List, Optional, Set, TypedDict
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import random
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from json_repair import repair_json
|
||||
from sqlmodel import select
|
||||
|
||||
from src.common.data_models.jargon_data_model import MaiJargon
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Jargon
|
||||
from src.common.data_models.jargon_data_model import MaiJargon
|
||||
from src.config.config import model_config, global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
|
||||
@@ -198,7 +199,7 @@ class JargonMiner:
|
||||
|
||||
async def process_extracted_entries(
|
||||
self, entries: List[JargonEntry], person_name_filter: Optional[Callable[[str], bool]] = None
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
处理已提取的黑话条目(从 expression_learner 路由过来的)
|
||||
|
||||
@@ -229,7 +230,7 @@ class JargonMiner:
|
||||
content = entry["content"]
|
||||
raw_content_set = entry["raw_content"]
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
jargon_items = session.exec(select(Jargon).filter_by(content=content)).all()
|
||||
except Exception as e:
|
||||
logger.error(f"查询黑话 '{content}' 失败: {e}")
|
||||
@@ -273,11 +274,12 @@ class JargonMiner:
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
session.add(new_jargon)
|
||||
session.flush()
|
||||
saved += 1
|
||||
self._add_to_cache(content)
|
||||
except Exception as e:
|
||||
logger.error(f"保存新黑话 '{content}' 失败: {e}")
|
||||
continue
|
||||
finally:
|
||||
self._add_to_cache(content)
|
||||
# 固定输出提取的jargon结果,格式化为可读形式(只要有提取结果就输出)
|
||||
if uniq_entries:
|
||||
# 收集所有提取的jargon内容
|
||||
@@ -304,7 +306,13 @@ class JargonMiner:
|
||||
removed_content, _ = self.cache.popitem(last=False)
|
||||
logger.debug(f"缓存已满,移除最旧的黑话: {removed_content}")
|
||||
|
||||
def _update_jargon(self, db_jargon: Jargon, raw_content_set: Set[str]):
|
||||
def _update_jargon(self, db_jargon: Jargon, raw_content_set: Set[str]) -> None:
|
||||
"""更新已有黑话记录并写回数据库。
|
||||
|
||||
Args:
|
||||
db_jargon: 已命中的黑话 ORM 对象。
|
||||
raw_content_set: 本次新增的原始上下文集合。
|
||||
"""
|
||||
db_jargon.count += 1
|
||||
existing_raw_content: List[str] = []
|
||||
if db_jargon.raw_content:
|
||||
@@ -326,7 +334,17 @@ class JargonMiner:
|
||||
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
session.add(db_jargon)
|
||||
if db_jargon.id is None:
|
||||
raise ValueError("黑话记录缺少 id,无法更新数据库")
|
||||
statement = select(Jargon).filter_by(id=db_jargon.id).limit(1)
|
||||
if persisted_jargon := session.exec(statement).first():
|
||||
persisted_jargon.count = db_jargon.count
|
||||
persisted_jargon.raw_content = db_jargon.raw_content
|
||||
persisted_jargon.session_id_dict = db_jargon.session_id_dict
|
||||
persisted_jargon.is_global = db_jargon.is_global
|
||||
session.add(persisted_jargon)
|
||||
else:
|
||||
logger.warning(f"黑话 ID {db_jargon.id} 在数据库中未找到,无法更新")
|
||||
except Exception as e:
|
||||
logger.error(f"更新黑话 '{db_jargon.content}' 失败: {e}")
|
||||
|
||||
|
||||
@@ -1,22 +1,24 @@
|
||||
import hashlib
|
||||
from datetime import datetime
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
import random
|
||||
import math
|
||||
import random
|
||||
import time
|
||||
|
||||
from json_repair import repair_json
|
||||
from typing import Union, Optional, Dict
|
||||
from datetime import datetime
|
||||
|
||||
from sqlmodel import col, select
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.common.data_models.person_info_data_model import dump_group_cardname_records, parse_group_cardname_json
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import PersonInfo
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
|
||||
logger = get_logger("person_info")
|
||||
@@ -26,6 +28,32 @@ relation_selection_model = LLMRequest(
|
||||
)
|
||||
|
||||
|
||||
def _to_group_cardname_records(group_cardname_json: Optional[str]) -> list[dict[str, str]]:
|
||||
"""将数据库中的群名片 JSON 转换为 `Person` 内部使用的结构。
|
||||
|
||||
Args:
|
||||
group_cardname_json: 数据库存储的群名片 JSON 字符串。
|
||||
|
||||
Returns:
|
||||
list[dict[str, str]]: 统一使用 `group_cardname` 键名的群名片列表。
|
||||
|
||||
Raises:
|
||||
json.JSONDecodeError: 当 JSON 文本格式非法时抛出。
|
||||
TypeError: 当输入值类型不符合 `json.loads()` 要求时抛出。
|
||||
"""
|
||||
group_cardname_list = parse_group_cardname_json(group_cardname_json)
|
||||
if not group_cardname_list:
|
||||
return []
|
||||
|
||||
return [
|
||||
{
|
||||
"group_id": group_cardname.group_id,
|
||||
"group_cardname": group_cardname.group_cardname,
|
||||
}
|
||||
for group_cardname in group_cardname_list
|
||||
]
|
||||
|
||||
|
||||
def get_person_id(platform: str, user_id: Union[int, str]) -> str:
|
||||
"""获取唯一id"""
|
||||
if "-" in platform:
|
||||
@@ -231,7 +259,7 @@ class Person:
|
||||
person.know_since = time.time()
|
||||
person.last_know = time.time()
|
||||
person.memory_points = []
|
||||
person.group_nick_name = [] # 初始化群昵称列表
|
||||
person.group_cardname_list = [] # 初始化群名片列表
|
||||
|
||||
# 如果是群聊,添加群昵称
|
||||
if group_id and group_nick_name:
|
||||
@@ -269,7 +297,7 @@ class Person:
|
||||
self.platform = platform
|
||||
self.nickname = global_config.bot.nickname
|
||||
self.person_name = global_config.bot.nickname
|
||||
self.group_nick_name: list[dict[str, str]] = []
|
||||
self.group_cardname_list: list[dict[str, str]] = []
|
||||
return
|
||||
|
||||
self.user_id = ""
|
||||
@@ -308,7 +336,7 @@ class Person:
|
||||
self.know_since = None
|
||||
self.last_know: Optional[float] = None
|
||||
self.memory_points = []
|
||||
self.group_nick_name: list[dict[str, str]] = [] # 群昵称列表,存储 {"group_id": str, "group_nick_name": str}
|
||||
self.group_cardname_list: list[dict[str, str]] = [] # 群名片列表,存储 {"group_id": str, "group_cardname": str}
|
||||
|
||||
# 从数据库加载数据
|
||||
self.load_from_database()
|
||||
@@ -408,16 +436,16 @@ class Person:
|
||||
return
|
||||
|
||||
# 检查是否已存在该群号的记录
|
||||
for item in self.group_nick_name:
|
||||
for item in self.group_cardname_list:
|
||||
if item.get("group_id") == group_id:
|
||||
# 更新现有记录
|
||||
item["group_nick_name"] = group_nick_name
|
||||
item["group_cardname"] = group_nick_name
|
||||
self.sync_to_database()
|
||||
logger.debug(f"更新用户 {self.person_id} 在群 {group_id} 的群昵称为 {group_nick_name}")
|
||||
return
|
||||
|
||||
# 添加新记录
|
||||
self.group_nick_name.append({"group_id": group_id, "group_nick_name": group_nick_name})
|
||||
self.group_cardname_list.append({"group_id": group_id, "group_cardname": group_nick_name})
|
||||
self.sync_to_database()
|
||||
logger.debug(f"添加用户 {self.person_id} 在群 {group_id} 的群昵称 {group_nick_name}")
|
||||
|
||||
@@ -452,20 +480,15 @@ class Person:
|
||||
else:
|
||||
self.memory_points = []
|
||||
|
||||
# 处理group_nick_name字段(JSON格式的列表)
|
||||
# 处理 group_cardname 字段(JSON 格式的列表)
|
||||
if record.group_cardname:
|
||||
try:
|
||||
loaded_group_nick_names = json.loads(record.group_cardname)
|
||||
# 确保是列表格式
|
||||
if isinstance(loaded_group_nick_names, list):
|
||||
self.group_nick_name = loaded_group_nick_names
|
||||
else:
|
||||
self.group_nick_name = []
|
||||
self.group_cardname_list = _to_group_cardname_records(record.group_cardname)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
logger.warning(f"解析用户 {self.person_id} 的group_cardname字段失败,使用默认值")
|
||||
self.group_nick_name = []
|
||||
self.group_cardname_list = []
|
||||
else:
|
||||
self.group_nick_name = []
|
||||
self.group_cardname_list = []
|
||||
|
||||
logger.debug(f"已从数据库加载用户 {self.person_id} 的信息")
|
||||
else:
|
||||
@@ -486,11 +509,7 @@ class Person:
|
||||
if self.memory_points
|
||||
else json.dumps([], ensure_ascii=False)
|
||||
)
|
||||
group_nickname_value = (
|
||||
json.dumps(self.group_nick_name, ensure_ascii=False)
|
||||
if self.group_nick_name
|
||||
else json.dumps([], ensure_ascii=False)
|
||||
)
|
||||
group_cardname_value = dump_group_cardname_records(self.group_cardname_list)
|
||||
first_known_time = datetime.fromtimestamp(self.know_since) if self.know_since else None
|
||||
last_known_time = datetime.fromtimestamp(self.last_know) if self.last_know else None
|
||||
|
||||
@@ -510,7 +529,7 @@ class Person:
|
||||
record.first_known_time = first_known_time
|
||||
record.last_known_time = last_known_time
|
||||
record.memory_points = memory_points_value
|
||||
record.group_nickname = group_nickname_value
|
||||
record.group_cardname = group_cardname_value
|
||||
session.add(record)
|
||||
logger.debug(f"已同步用户 {self.person_id} 的信息到数据库")
|
||||
else:
|
||||
@@ -526,7 +545,7 @@ class Person:
|
||||
first_known_time=first_known_time,
|
||||
last_known_time=last_known_time,
|
||||
memory_points=memory_points_value,
|
||||
group_nickname=group_nickname_value,
|
||||
group_cardname=group_cardname_value,
|
||||
)
|
||||
session.add(record)
|
||||
logger.debug(f"已创建用户 {self.person_id} 的信息到数据库")
|
||||
|
||||
34
src/platform_io/__init__.py
Normal file
34
src/platform_io/__init__.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""导出 Platform IO 层的公开入口。
|
||||
|
||||
当前仍处于地基阶段,调用方应优先从这里导入共享类型和全局管理器,
|
||||
而不是直接依赖更底层的私有子模块。
|
||||
"""
|
||||
|
||||
from .manager import PlatformIOManager, get_platform_io_manager
|
||||
from .route_key_factory import RouteKeyFactory
|
||||
from .routing import RouteTable
|
||||
from .types import (
|
||||
DeliveryBatch,
|
||||
DeliveryReceipt,
|
||||
DeliveryStatus,
|
||||
DriverDescriptor,
|
||||
DriverKind,
|
||||
InboundMessageEnvelope,
|
||||
RouteBinding,
|
||||
RouteKey,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DeliveryBatch",
|
||||
"DeliveryReceipt",
|
||||
"DeliveryStatus",
|
||||
"DriverDescriptor",
|
||||
"DriverKind",
|
||||
"InboundMessageEnvelope",
|
||||
"PlatformIOManager",
|
||||
"RouteKeyFactory",
|
||||
"RouteBinding",
|
||||
"RouteKey",
|
||||
"RouteTable",
|
||||
"get_platform_io_manager",
|
||||
]
|
||||
133
src/platform_io/dedupe.py
Normal file
133
src/platform_io/dedupe.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""提供 Platform IO 的轻量入站消息去重能力。
|
||||
|
||||
当前实现基于 ``dict + heapq``:
|
||||
- ``dict`` 保存去重键到过期时间的映射
|
||||
- ``heapq`` 维护按过期时间排序的小顶堆
|
||||
|
||||
这样就不需要在每次检查时全表扫描,而是通过懒清理逐步弹出已经过期
|
||||
或已经失效的堆节点。
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import heapq
|
||||
import time
|
||||
|
||||
|
||||
class MessageDeduplicator:
|
||||
"""使用基于 TTL 的内存缓存进行入站消息去重。
|
||||
|
||||
主要用于解决同一条外部消息被重复送入 Core 的问题,例如双路径并存、
|
||||
适配器重试、重连或重复回调等场景。Broker 可以借助这个组件在进入
|
||||
Core 前先拦住重复投递,避免重复处理、重复回复和重复入库。
|
||||
|
||||
当前实现使用 ``dict + heapq`` 维护过期时间:
|
||||
- ``dict`` 负责 ``O(1)`` 级别的去重键查找
|
||||
- ``heapq`` 负责按过期时间顺序做懒清理
|
||||
|
||||
这比“每次调用都全表扫描过期项”的实现更适合高吞吐消息场景。
|
||||
|
||||
Notes:
|
||||
复杂度说明如下,设 ``n`` 为当前缓存中的有效去重键数量:
|
||||
|
||||
- 单次 ``mark_seen()`` 在常见路径下的时间复杂度接近 ``O(log n)``
|
||||
- 从长期摊还角度看,``mark_seen()`` 的时间复杂度也接近 ``O(log n)``
|
||||
- 如果某次调用恰好触发一批过期键的集中清理,则该次调用的最坏时间复杂度
|
||||
可达到 ``O(k log n)``,其中 ``k`` 为本次被弹出或清理的键数量
|
||||
- 空间复杂度为 ``O(n)``
|
||||
"""
|
||||
|
||||
def __init__(self, ttl_seconds: float = 300.0, max_entries: int = 10000) -> None:
|
||||
"""初始化去重器。
|
||||
|
||||
Args:
|
||||
ttl_seconds: 每个去重键在缓存中的保留时长,单位为秒。
|
||||
max_entries: 缓存允许保留的最大有效键数量,超出后会触发
|
||||
机会性淘汰。
|
||||
|
||||
Raises:
|
||||
ValueError: 当 ``ttl_seconds`` 或 ``max_entries`` 非正数时抛出。
|
||||
"""
|
||||
if ttl_seconds <= 0:
|
||||
raise ValueError("ttl_seconds 必须大于 0")
|
||||
if max_entries <= 0:
|
||||
raise ValueError("max_entries 必须大于 0")
|
||||
|
||||
self._ttl_seconds = ttl_seconds
|
||||
self._max_entries = max_entries
|
||||
self._expire_heap: List[Tuple[float, str]] = []
|
||||
self._seen: Dict[str, float] = {}
|
||||
|
||||
def mark_seen(self, dedupe_key: str) -> bool:
|
||||
"""标记一条去重键已经出现过。
|
||||
|
||||
Args:
|
||||
dedupe_key: 能稳定标识一条外部入站消息的去重键。
|
||||
|
||||
Returns:
|
||||
bool: 若该键在当前 TTL 窗口内首次出现则返回 ``True``,
|
||||
否则返回 ``False``。
|
||||
|
||||
Notes:
|
||||
方法会先基于小顶堆做一次懒清理,再判断当前键是否仍在有效期内。
|
||||
如果缓存已达到上限,则会优先淘汰“最早过期的仍然有效的键”。
|
||||
|
||||
复杂度方面,常见路径下该方法接近 ``O(log n)``;如果恰好需要
|
||||
集中清理一批过期键,则单次调用最坏可达到 ``O(k log n)``。
|
||||
"""
|
||||
now = time.monotonic()
|
||||
self._purge_expired(now)
|
||||
|
||||
expires_at = self._seen.get(dedupe_key)
|
||||
if expires_at is not None and expires_at > now:
|
||||
return False
|
||||
|
||||
if len(self._seen) >= self._max_entries:
|
||||
self._evict_earliest_live()
|
||||
|
||||
expires_at = now + self._ttl_seconds
|
||||
self._seen[dedupe_key] = expires_at
|
||||
heapq.heappush(self._expire_heap, (expires_at, dedupe_key))
|
||||
return True
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清空全部去重缓存。"""
|
||||
self._expire_heap.clear()
|
||||
self._seen.clear()
|
||||
|
||||
def _purge_expired(self, now: float) -> None:
|
||||
"""从缓存中清理已经过期的去重键。
|
||||
|
||||
Args:
|
||||
now: 当前单调时钟时间戳。
|
||||
|
||||
Notes:
|
||||
堆中可能存在旧版本节点。例如同一个 ``dedupe_key`` 被重新写入后,
|
||||
旧的过期时间节点仍会留在堆里。这里会通过和 ``dict`` 中当前值比对,
|
||||
跳过这类失效节点。
|
||||
"""
|
||||
while self._expire_heap and self._expire_heap[0][0] <= now:
|
||||
expires_at, dedupe_key = heapq.heappop(self._expire_heap)
|
||||
current_expires_at = self._seen.get(dedupe_key)
|
||||
if current_expires_at is None:
|
||||
continue
|
||||
if current_expires_at != expires_at:
|
||||
continue
|
||||
self._seen.pop(dedupe_key, None)
|
||||
|
||||
def _evict_earliest_live(self) -> None:
|
||||
"""当缓存达到容量上限时,淘汰一条最早过期的有效键。
|
||||
|
||||
Notes:
|
||||
堆顶可能是已经过期或已失效的旧节点,因此这里同样需要循环弹出,
|
||||
直到找到一条当前仍然在 ``dict`` 中生效的键。
|
||||
"""
|
||||
while self._expire_heap:
|
||||
expires_at, dedupe_key = heapq.heappop(self._expire_heap)
|
||||
current_expires_at = self._seen.get(dedupe_key)
|
||||
if current_expires_at is None:
|
||||
continue
|
||||
if current_expires_at != expires_at:
|
||||
continue
|
||||
self._seen.pop(dedupe_key, None)
|
||||
return
|
||||
11
src/platform_io/drivers/__init__.py
Normal file
11
src/platform_io/drivers/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""导出 Platform IO 层的公开驱动类型。"""
|
||||
|
||||
from .base import PlatformIODriver
|
||||
from .legacy_driver import LegacyPlatformDriver
|
||||
from .plugin_driver import PluginPlatformDriver
|
||||
|
||||
__all__ = [
|
||||
"LegacyPlatformDriver",
|
||||
"PlatformIODriver",
|
||||
"PluginPlatformDriver",
|
||||
]
|
||||
104
src/platform_io/drivers/base.py
Normal file
104
src/platform_io/drivers/base.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""定义 Platform IO 传输驱动的基础抽象协议。"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional
|
||||
|
||||
from src.platform_io.types import DeliveryReceipt, DriverDescriptor, InboundMessageEnvelope, RouteKey
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
|
||||
InboundHandler = Callable[[InboundMessageEnvelope], Awaitable[bool]]
|
||||
|
||||
|
||||
class PlatformIODriver(ABC):
|
||||
"""定义所有 Platform IO 驱动都必须实现的最小契约。
|
||||
|
||||
当前实现故意保持接口很小,让中间层可以先落地,再逐步把 legacy
|
||||
与 plugin 路径的真实收发能力迁入这套协议之下。
|
||||
"""
|
||||
|
||||
def __init__(self, descriptor: DriverDescriptor) -> None:
|
||||
"""使用驱动描述对象初始化驱动。
|
||||
|
||||
Args:
|
||||
descriptor: 注册到 Broker 中的静态驱动元数据。
|
||||
"""
|
||||
self._descriptor = descriptor
|
||||
self._inbound_handler: Optional[InboundHandler] = None
|
||||
|
||||
@property
|
||||
def descriptor(self) -> DriverDescriptor:
|
||||
"""返回当前驱动的描述对象。
|
||||
|
||||
Returns:
|
||||
DriverDescriptor: 当前驱动实例对应的描述对象。
|
||||
"""
|
||||
return self._descriptor
|
||||
|
||||
@property
|
||||
def driver_id(self) -> str:
|
||||
"""返回驱动标识。
|
||||
|
||||
Returns:
|
||||
str: 当前驱动的唯一 ID。
|
||||
"""
|
||||
return self._descriptor.driver_id
|
||||
|
||||
def set_inbound_handler(self, handler: InboundHandler) -> None:
|
||||
"""注册入站消息交回 Broker 的回调函数。
|
||||
|
||||
Args:
|
||||
handler: 将规范化入站封装继续转发给 Broker 的异步回调。
|
||||
"""
|
||||
self._inbound_handler = handler
|
||||
|
||||
def clear_inbound_handler(self) -> None:
|
||||
"""清除当前注册的入站回调函数。"""
|
||||
self._inbound_handler = None
|
||||
|
||||
async def emit_inbound(self, envelope: InboundMessageEnvelope) -> bool:
|
||||
"""将一条入站封装转交给 Broker 回调。
|
||||
|
||||
Args:
|
||||
envelope: 由驱动产出的规范化入站封装。
|
||||
|
||||
Returns:
|
||||
bool: 若 Broker 接受该入站消息则返回 ``True``,否则返回 ``False``。
|
||||
"""
|
||||
|
||||
if self._inbound_handler is None:
|
||||
return False
|
||||
return await self._inbound_handler(envelope)
|
||||
|
||||
async def start(self) -> None:
|
||||
"""启动驱动生命周期。
|
||||
|
||||
子类后续若需要初始化逻辑,可以覆盖这个钩子。
|
||||
"""
|
||||
return None
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""停止驱动生命周期。
|
||||
|
||||
子类后续若需要清理逻辑,可以覆盖这个钩子。
|
||||
"""
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
async def send_message(
|
||||
self,
|
||||
message: "SessionMessage",
|
||||
route_key: RouteKey,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> DeliveryReceipt:
|
||||
"""通过具体驱动发送一条消息。
|
||||
|
||||
Args:
|
||||
message: 要投递的内部会话消息。
|
||||
route_key: Broker 为本次投递选中的路由键。
|
||||
metadata: 本次出站投递可选的 Broker 侧元数据。
|
||||
|
||||
Returns:
|
||||
DeliveryReceipt: 规范化后的投递结果。
|
||||
"""
|
||||
92
src/platform_io/drivers/legacy_driver.py
Normal file
92
src/platform_io/drivers/legacy_driver.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""提供 Platform IO 的 legacy 传输驱动实现。"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from src.platform_io.drivers.base import PlatformIODriver
|
||||
from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, RouteKey
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
|
||||
|
||||
class LegacyPlatformDriver(PlatformIODriver):
|
||||
"""面向 ``UniversalMessageSender`` 旧链的 Platform IO 驱动。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
driver_id: str,
|
||||
platform: str,
|
||||
account_id: Optional[str] = None,
|
||||
scope: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""初始化一个 legacy 驱动描述对象。
|
||||
|
||||
Args:
|
||||
driver_id: Broker 内的唯一驱动 ID。
|
||||
platform: 该 legacy 适配器链路负责的平台。
|
||||
account_id: 可选的账号 ID。
|
||||
scope: 可选的额外路由作用域。
|
||||
metadata: 可选的额外驱动元数据。
|
||||
"""
|
||||
descriptor = DriverDescriptor(
|
||||
driver_id=driver_id,
|
||||
kind=DriverKind.LEGACY,
|
||||
platform=platform,
|
||||
account_id=account_id,
|
||||
scope=scope,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
super().__init__(descriptor)
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
message: "SessionMessage",
|
||||
route_key: RouteKey,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> DeliveryReceipt:
|
||||
"""通过旧链发送一条已经过预处理的消息。
|
||||
|
||||
Args:
|
||||
message: 要投递的内部会话消息。
|
||||
route_key: Broker 为本次投递选择的路由键。
|
||||
metadata: 本次出站投递可选的 Broker 侧元数据。
|
||||
|
||||
Returns:
|
||||
DeliveryReceipt: 规范化后的发送回执。
|
||||
"""
|
||||
from src.chat.message_receive.uni_message_sender import send_prepared_message_to_platform
|
||||
|
||||
show_log = False
|
||||
if isinstance(metadata, dict):
|
||||
show_log = bool(metadata.get("show_log", False))
|
||||
|
||||
try:
|
||||
sent = await send_prepared_message_to_platform(message, show_log=show_log)
|
||||
except Exception as exc:
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=message.message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.FAILED,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
if not sent:
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=message.message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.FAILED,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
error="旧链发送失败",
|
||||
)
|
||||
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=message.message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.SENT,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
)
|
||||
211
src/platform_io/drivers/plugin_driver.py
Normal file
211
src/platform_io/drivers/plugin_driver.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""提供 Platform IO 的插件消息网关驱动实现。"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Protocol
|
||||
|
||||
from src.platform_io.drivers.base import PlatformIODriver
|
||||
from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, RouteKey
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
|
||||
|
||||
class _GatewaySupervisorProtocol(Protocol):
|
||||
"""消息网关驱动依赖的 Supervisor 最小协议。"""
|
||||
|
||||
async def invoke_message_gateway(
|
||||
self,
|
||||
plugin_id: str,
|
||||
component_name: str,
|
||||
args: Optional[Dict[str, Any]] = None,
|
||||
timeout_ms: int = 30000,
|
||||
) -> Any:
|
||||
"""调用插件声明的消息网关方法。"""
|
||||
|
||||
|
||||
class PluginPlatformDriver(PlatformIODriver):
|
||||
"""面向插件消息网关链路的 Platform IO 驱动。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
driver_id: str,
|
||||
platform: str,
|
||||
supervisor: _GatewaySupervisorProtocol,
|
||||
component_name: str,
|
||||
*,
|
||||
supports_send: bool,
|
||||
account_id: Optional[str] = None,
|
||||
scope: Optional[str] = None,
|
||||
plugin_id: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""初始化一个插件消息网关驱动。
|
||||
|
||||
Args:
|
||||
driver_id: Broker 内的唯一驱动 ID。
|
||||
platform: 该消息网关负责的平台名称。
|
||||
supervisor: 持有该插件的 Supervisor。
|
||||
component_name: 出站时要调用的网关组件名称。
|
||||
supports_send: 当前驱动是否具备出站能力。
|
||||
account_id: 可选的账号 ID 或 self ID。
|
||||
scope: 可选的额外路由作用域。
|
||||
plugin_id: 拥有该实现的插件 ID。
|
||||
metadata: 可选的额外驱动元数据。
|
||||
"""
|
||||
|
||||
descriptor = DriverDescriptor(
|
||||
driver_id=driver_id,
|
||||
kind=DriverKind.PLUGIN,
|
||||
platform=platform,
|
||||
account_id=account_id,
|
||||
scope=scope,
|
||||
plugin_id=plugin_id,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
super().__init__(descriptor)
|
||||
self._supervisor = supervisor
|
||||
self._component_name = component_name
|
||||
self._supports_send = supports_send
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
message: "SessionMessage",
|
||||
route_key: RouteKey,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> DeliveryReceipt:
|
||||
"""通过插件消息网关发送消息。
|
||||
|
||||
Args:
|
||||
message: 要投递的内部会话消息。
|
||||
route_key: Broker 为本次投递选择的路由键。
|
||||
metadata: 可选的发送元数据。
|
||||
|
||||
Returns:
|
||||
DeliveryReceipt: 规范化后的发送回执。
|
||||
"""
|
||||
|
||||
if not self._supports_send:
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=message.message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.FAILED,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
error="当前消息网关仅支持接收,不支持发送",
|
||||
)
|
||||
|
||||
from src.plugin_runtime.host.message_utils import PluginMessageUtils
|
||||
|
||||
plugin_id = self.descriptor.plugin_id or ""
|
||||
if not plugin_id:
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=message.message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.FAILED,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
error="插件消息网关驱动缺少 plugin_id",
|
||||
)
|
||||
|
||||
try:
|
||||
message_dict = PluginMessageUtils._session_message_to_dict(message)
|
||||
response = await self._supervisor.invoke_message_gateway(
|
||||
plugin_id=plugin_id,
|
||||
component_name=self._component_name,
|
||||
args={
|
||||
"message": message_dict,
|
||||
"route": {
|
||||
"platform": route_key.platform,
|
||||
"account_id": route_key.account_id,
|
||||
"scope": route_key.scope,
|
||||
},
|
||||
"metadata": metadata or {},
|
||||
},
|
||||
timeout_ms=30000,
|
||||
)
|
||||
except Exception as exc:
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=message.message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.FAILED,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
return self._build_receipt(message.message_id, route_key, response)
|
||||
|
||||
def _build_receipt(self, internal_message_id: str, route_key: RouteKey, response: Any) -> DeliveryReceipt:
|
||||
"""将网关调用响应归一化为出站回执。
|
||||
|
||||
Args:
|
||||
internal_message_id: 内部消息 ID。
|
||||
route_key: 本次投递的路由键。
|
||||
response: Supervisor 返回的 RPC 响应对象。
|
||||
|
||||
Returns:
|
||||
DeliveryReceipt: 标准化后的出站回执。
|
||||
"""
|
||||
|
||||
if getattr(response, "error", None):
|
||||
error = response.error.get("message", "消息网关发送失败")
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=internal_message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.FAILED,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
error=error,
|
||||
)
|
||||
|
||||
payload = getattr(response, "payload", {})
|
||||
invoke_success = bool(payload.get("success", False)) if isinstance(payload, dict) else False
|
||||
if not invoke_success:
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=internal_message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.FAILED,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
error=str(payload.get("result", "消息网关发送失败")) if isinstance(payload, dict) else "消息网关发送失败",
|
||||
)
|
||||
|
||||
result = payload.get("result") if isinstance(payload, dict) else None
|
||||
if isinstance(result, dict):
|
||||
if result.get("success") is False:
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=internal_message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.FAILED,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
error=str(result.get("error", "消息网关发送失败")),
|
||||
metadata=result.get("metadata", {}) if isinstance(result.get("metadata"), dict) else {},
|
||||
)
|
||||
external_message_id = str(result.get("external_message_id") or result.get("message_id") or "") or None
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=internal_message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.SENT,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
external_message_id=external_message_id,
|
||||
metadata=result.get("metadata", {}) if isinstance(result.get("metadata"), dict) else {},
|
||||
)
|
||||
|
||||
if isinstance(result, str) and result.strip():
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=internal_message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.SENT,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
external_message_id=result.strip(),
|
||||
)
|
||||
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=internal_message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.SENT,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
)
|
||||
611
src/platform_io/manager.py
Normal file
611
src/platform_io/manager.py
Normal file
@@ -0,0 +1,611 @@
|
||||
"""提供 Platform IO 层的中心 Broker 管理器。"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.platform_io.drivers.base import PlatformIODriver
|
||||
|
||||
from .dedupe import MessageDeduplicator
|
||||
from .outbound_tracker import OutboundTracker
|
||||
from .route_key_factory import RouteKeyFactory
|
||||
from .registry import DriverRegistry
|
||||
from .routing import RouteTable
|
||||
from .types import DeliveryBatch, DeliveryReceipt, DeliveryStatus, InboundMessageEnvelope, RouteBinding, RouteKey
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
|
||||
logger = get_logger("platform_io.manager")
|
||||
|
||||
InboundDispatcher = Callable[[InboundMessageEnvelope], Awaitable[None]]
|
||||
|
||||
|
||||
class PlatformIOManager:
|
||||
"""统一协调平台消息 IO 的路由、去重与状态跟踪。
|
||||
|
||||
与旧实现不同,这个管理器不再负责“多条链路谁该接管平台”的裁决,
|
||||
只维护发送表和接收表两张轻量路由表:
|
||||
|
||||
- 发送时:解析所有命中的发送绑定并全部投递。
|
||||
- 接收时:只校验当前驱动是否已登记为可接收链路,然后全部放行给上层。
|
||||
- 去重时:仅对单条链路做技术性重放抑制,不做跨链路语义去重。
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化 Broker 管理器及其内存状态。"""
|
||||
self._driver_registry = DriverRegistry()
|
||||
self._send_route_table = RouteTable()
|
||||
self._receive_route_table = RouteTable()
|
||||
self._legacy_send_drivers: Dict[str, PlatformIODriver] = {}
|
||||
self._deduplicator = MessageDeduplicator()
|
||||
self._outbound_tracker = OutboundTracker()
|
||||
self._inbound_dispatcher: Optional[InboundDispatcher] = None
|
||||
self._started = False
|
||||
|
||||
@property
|
||||
def is_started(self) -> bool:
|
||||
"""返回 Broker 当前是否已进入运行态。
|
||||
|
||||
Returns:
|
||||
bool: 若 Broker 已启动则返回 ``True``。
|
||||
"""
|
||||
return self._started
|
||||
|
||||
async def start(self) -> None:
|
||||
"""启动 Broker,并依次启动当前已注册的全部驱动。
|
||||
|
||||
Raises:
|
||||
Exception: 当某个驱动启动失败时,异常会继续上抛;已成功启动的驱动
|
||||
会被自动回滚停止。
|
||||
"""
|
||||
if self._started:
|
||||
return
|
||||
|
||||
started_drivers: List[PlatformIODriver] = []
|
||||
try:
|
||||
for driver in self._driver_registry.list():
|
||||
await driver.start()
|
||||
started_drivers.append(driver)
|
||||
except Exception:
|
||||
for driver in reversed(started_drivers):
|
||||
try:
|
||||
await driver.stop()
|
||||
except Exception:
|
||||
logger.exception(f"回滚驱动停止失败: driver_id={driver.driver_id}")
|
||||
raise
|
||||
|
||||
self._started = True
|
||||
|
||||
async def ensure_send_pipeline_ready(self) -> None:
|
||||
"""确保出站发送管线已准备就绪。
|
||||
|
||||
该方法会先同步 legacy fallback driver,再在需要时启动 Broker。
|
||||
send service 应只调用这一层准备入口,而不是自行判断旧链或插件链。
|
||||
"""
|
||||
await self._sync_legacy_send_drivers()
|
||||
if not self._started:
|
||||
await self.start()
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""停止 Broker,并按逆序停止全部已注册驱动。
|
||||
|
||||
停止完成后,会同步清空仅对当前运行周期有效的去重缓存和出站跟踪状态,
|
||||
避免下一次启动时继续沿用上一个运行周期的瞬时内存数据。
|
||||
|
||||
Raises:
|
||||
RuntimeError: 当一个或多个驱动停止失败时抛出汇总异常。
|
||||
"""
|
||||
if not self._started:
|
||||
return
|
||||
|
||||
stop_errors: List[str] = []
|
||||
for driver in reversed(self._driver_registry.list()):
|
||||
try:
|
||||
await driver.stop()
|
||||
except Exception as exc:
|
||||
stop_errors.append(f"{driver.driver_id}: {exc}")
|
||||
logger.exception(f"驱动停止失败: driver_id={driver.driver_id}")
|
||||
|
||||
self._started = False
|
||||
self._deduplicator.clear()
|
||||
self._outbound_tracker.clear()
|
||||
if stop_errors:
|
||||
raise RuntimeError(f"部分驱动停止失败: {'; '.join(stop_errors)}")
|
||||
|
||||
async def add_driver(self, driver: PlatformIODriver) -> None:
|
||||
"""向运行中的 Broker 注册并启动一个驱动。
|
||||
|
||||
如果 Broker 尚未启动,则该方法等价于 ``register_driver()``。
|
||||
|
||||
Args:
|
||||
driver: 要添加的驱动实例。
|
||||
|
||||
Raises:
|
||||
Exception: 当驱动启动失败时,注册会自动回滚,异常继续上抛。
|
||||
"""
|
||||
self._register_driver_internal(driver)
|
||||
if not self._started:
|
||||
return
|
||||
|
||||
try:
|
||||
await driver.start()
|
||||
except Exception:
|
||||
self._unregister_driver_internal(driver.driver_id)
|
||||
raise
|
||||
|
||||
async def remove_driver(self, driver_id: str) -> Optional[PlatformIODriver]:
|
||||
"""从运行中的 Broker 停止并移除一个驱动。
|
||||
|
||||
如果 Broker 尚未启动,则该方法等价于 ``unregister_driver()``。
|
||||
|
||||
Args:
|
||||
driver_id: 要移除的驱动 ID。
|
||||
|
||||
Returns:
|
||||
Optional[PlatformIODriver]: 若驱动存在,则返回被移除的驱动实例。
|
||||
|
||||
Raises:
|
||||
Exception: 当 Broker 运行中且驱动停止失败时,异常会继续上抛。
|
||||
"""
|
||||
if not self._started:
|
||||
return self.unregister_driver(driver_id)
|
||||
|
||||
driver = self._driver_registry.get(driver_id)
|
||||
if driver is None:
|
||||
return None
|
||||
|
||||
await driver.stop()
|
||||
return self._unregister_driver_internal(driver_id)
|
||||
|
||||
@property
|
||||
def driver_registry(self) -> DriverRegistry:
|
||||
"""返回管理器持有的驱动注册表。
|
||||
|
||||
Returns:
|
||||
DriverRegistry: 用于保存全部已注册驱动的注册表。
|
||||
"""
|
||||
return self._driver_registry
|
||||
|
||||
@property
|
||||
def send_route_table(self) -> RouteTable:
|
||||
"""返回发送路由表。"""
|
||||
|
||||
return self._send_route_table
|
||||
|
||||
@property
|
||||
def receive_route_table(self) -> RouteTable:
|
||||
"""返回接收路由表。"""
|
||||
|
||||
return self._receive_route_table
|
||||
|
||||
@property
|
||||
def deduplicator(self) -> MessageDeduplicator:
|
||||
"""返回管理器持有的入站去重器。
|
||||
|
||||
Returns:
|
||||
MessageDeduplicator: 用于抑制重复入站的去重器。
|
||||
"""
|
||||
return self._deduplicator
|
||||
|
||||
@property
|
||||
def outbound_tracker(self) -> OutboundTracker:
|
||||
"""返回管理器持有的出站跟踪器。
|
||||
|
||||
Returns:
|
||||
OutboundTracker: 用于记录出站 pending 状态与回执的跟踪器。
|
||||
"""
|
||||
return self._outbound_tracker
|
||||
|
||||
def set_inbound_dispatcher(self, dispatcher: InboundDispatcher) -> None:
|
||||
"""设置统一的入站分发回调。
|
||||
|
||||
Args:
|
||||
dispatcher: 接收已通过 Broker 审核的入站封装,并继续送入
|
||||
Core 下一处理阶段的异步回调。
|
||||
"""
|
||||
|
||||
self._inbound_dispatcher = dispatcher
|
||||
|
||||
def clear_inbound_dispatcher(self) -> None:
|
||||
"""清除当前的入站分发回调。"""
|
||||
self._inbound_dispatcher = None
|
||||
|
||||
@property
|
||||
def has_inbound_dispatcher(self) -> bool:
|
||||
"""返回当前是否已经配置入站分发回调。
|
||||
|
||||
Returns:
|
||||
bool: 若已经配置入站分发回调则返回 ``True``。
|
||||
"""
|
||||
return self._inbound_dispatcher is not None
|
||||
|
||||
def register_driver(self, driver: PlatformIODriver) -> None:
|
||||
"""注册驱动,并把它的入站回调挂到 Broker。
|
||||
|
||||
Args:
|
||||
driver: 要注册的驱动实例。
|
||||
|
||||
Raises:
|
||||
RuntimeError: 当 Broker 已经处于运行态时抛出。此时应改用
|
||||
``add_driver()`` 以保证驱动生命周期和注册状态一致。
|
||||
"""
|
||||
if self._started:
|
||||
raise RuntimeError("Broker 运行中不允许直接 register_driver,请改用 add_driver()")
|
||||
|
||||
self._register_driver_internal(driver)
|
||||
|
||||
def _register_driver_internal(self, driver: PlatformIODriver) -> None:
|
||||
"""执行不带运行态限制的内部驱动注册。
|
||||
|
||||
Args:
|
||||
driver: 要注册的驱动实例。
|
||||
"""
|
||||
driver.set_inbound_handler(self.accept_inbound)
|
||||
self._driver_registry.register(driver)
|
||||
|
||||
def unregister_driver(self, driver_id: str) -> Optional[PlatformIODriver]:
|
||||
"""从 Broker 注销一个驱动。
|
||||
|
||||
Args:
|
||||
driver_id: 要移除的驱动 ID。
|
||||
|
||||
Returns:
|
||||
Optional[PlatformIODriver]: 若驱动存在,则返回被移除的驱动实例。
|
||||
|
||||
Raises:
|
||||
RuntimeError: 当 Broker 已经处于运行态时抛出。此时应改用
|
||||
``remove_driver()``,避免驱动停止与路由解绑脱节。
|
||||
"""
|
||||
if self._started:
|
||||
raise RuntimeError("Broker 运行中不允许直接 unregister_driver,请改用 remove_driver()")
|
||||
|
||||
return self._unregister_driver_internal(driver_id)
|
||||
|
||||
def _unregister_driver_internal(self, driver_id: str) -> Optional[PlatformIODriver]:
|
||||
"""执行不带运行态限制的内部驱动注销。
|
||||
|
||||
Args:
|
||||
driver_id: 要移除的驱动 ID。
|
||||
|
||||
Returns:
|
||||
Optional[PlatformIODriver]: 若驱动存在,则返回被移除的驱动实例。
|
||||
"""
|
||||
removed_driver = self._driver_registry.unregister(driver_id)
|
||||
if removed_driver is None:
|
||||
return None
|
||||
|
||||
removed_driver.clear_inbound_handler()
|
||||
self._send_route_table.remove_bindings_by_driver(driver_id)
|
||||
self._receive_route_table.remove_bindings_by_driver(driver_id)
|
||||
self._legacy_send_drivers = {
|
||||
platform: driver
|
||||
for platform, driver in self._legacy_send_drivers.items()
|
||||
if driver.driver_id != driver_id
|
||||
}
|
||||
return removed_driver
|
||||
|
||||
async def _sync_legacy_send_drivers(self) -> None:
|
||||
"""根据当前配置同步 legacy fallback driver。"""
|
||||
from src.chat.utils.utils import get_all_bot_accounts
|
||||
from src.platform_io.drivers.legacy_driver import LegacyPlatformDriver
|
||||
|
||||
desired_accounts = get_all_bot_accounts()
|
||||
desired_platforms = set(desired_accounts.keys())
|
||||
current_platforms = set(self._legacy_send_drivers.keys())
|
||||
|
||||
for platform in sorted(current_platforms - desired_platforms):
|
||||
await self._remove_legacy_send_driver(platform)
|
||||
|
||||
for platform, account_id in desired_accounts.items():
|
||||
existing_driver = self._legacy_send_drivers.get(platform)
|
||||
if existing_driver is not None and existing_driver.descriptor.account_id == account_id:
|
||||
continue
|
||||
|
||||
if existing_driver is not None:
|
||||
await self._remove_legacy_send_driver(platform)
|
||||
|
||||
driver = LegacyPlatformDriver(
|
||||
driver_id=f"legacy.send.{platform}",
|
||||
platform=platform,
|
||||
account_id=account_id,
|
||||
)
|
||||
if self._started:
|
||||
await self.add_driver(driver)
|
||||
else:
|
||||
self.register_driver(driver)
|
||||
self._legacy_send_drivers[platform] = driver
|
||||
|
||||
async def _remove_legacy_send_driver(self, platform: str) -> None:
|
||||
"""移除指定平台的 legacy fallback driver。
|
||||
|
||||
Args:
|
||||
platform: 要移除的目标平台。
|
||||
"""
|
||||
driver = self._legacy_send_drivers.get(platform)
|
||||
if driver is None:
|
||||
return
|
||||
|
||||
if self._started:
|
||||
await self.remove_driver(driver.driver_id)
|
||||
else:
|
||||
self.unregister_driver(driver.driver_id)
|
||||
self._legacy_send_drivers.pop(platform, None)
|
||||
|
||||
def bind_send_route(self, binding: RouteBinding) -> None:
|
||||
"""为某个路由键绑定发送驱动。
|
||||
|
||||
Args:
|
||||
binding: 要保存的路由绑定。
|
||||
|
||||
Raises:
|
||||
ValueError: 当绑定引用了不存在的驱动,或者绑定与驱动描述不一致时抛出。
|
||||
"""
|
||||
driver = self._driver_registry.get(binding.driver_id)
|
||||
if driver is None:
|
||||
raise ValueError(f"驱动 {binding.driver_id} 未注册,无法绑定路由")
|
||||
|
||||
self._validate_binding_against_driver(binding, driver)
|
||||
self._send_route_table.bind(binding)
|
||||
|
||||
def bind_receive_route(self, binding: RouteBinding) -> None:
|
||||
"""为某个路由键绑定接收驱动。
|
||||
|
||||
Args:
|
||||
binding: 要保存的路由绑定。
|
||||
|
||||
Raises:
|
||||
ValueError: 当绑定引用了不存在的驱动,或者绑定与驱动描述不一致时抛出。
|
||||
"""
|
||||
driver = self._driver_registry.get(binding.driver_id)
|
||||
if driver is None:
|
||||
raise ValueError(f"驱动 {binding.driver_id} 未注册,无法绑定路由")
|
||||
|
||||
self._validate_binding_against_driver(binding, driver)
|
||||
self._receive_route_table.bind(binding)
|
||||
|
||||
def unbind_send_route(self, route_key: RouteKey, driver_id: Optional[str] = None) -> None:
|
||||
"""移除发送路由绑定。
|
||||
|
||||
Args:
|
||||
route_key: 要移除绑定的路由键。
|
||||
driver_id: 可选的特定驱动 ID。
|
||||
"""
|
||||
|
||||
self._send_route_table.unbind(route_key, driver_id)
|
||||
|
||||
def unbind_receive_route(self, route_key: RouteKey, driver_id: Optional[str] = None) -> None:
|
||||
"""移除接收路由绑定。
|
||||
|
||||
Args:
|
||||
route_key: 要移除绑定的路由键。
|
||||
driver_id: 可选的特定驱动 ID。
|
||||
"""
|
||||
|
||||
self._receive_route_table.unbind(route_key, driver_id)
|
||||
|
||||
def resolve_drivers(self, route_key: RouteKey) -> List[PlatformIODriver]:
|
||||
"""解析某个路由键当前命中的全部发送驱动。
|
||||
|
||||
Args:
|
||||
route_key: 要解析的路由键。
|
||||
|
||||
Returns:
|
||||
List[PlatformIODriver]: 当前命中的全部发送驱动。
|
||||
"""
|
||||
|
||||
drivers: List[PlatformIODriver] = []
|
||||
seen_driver_ids: set[str] = set()
|
||||
for binding in self._send_route_table.resolve_bindings(route_key):
|
||||
driver = self._driver_registry.get(binding.driver_id)
|
||||
if driver is not None and driver.driver_id not in seen_driver_ids:
|
||||
drivers.append(driver)
|
||||
seen_driver_ids.add(driver.driver_id)
|
||||
|
||||
fallback_driver = self._legacy_send_drivers.get(route_key.platform)
|
||||
if fallback_driver is not None:
|
||||
descriptor = fallback_driver.descriptor
|
||||
account_matches = descriptor.account_id is None or route_key.account_id in (None, descriptor.account_id)
|
||||
scope_matches = descriptor.scope is None or route_key.scope in (None, descriptor.scope)
|
||||
if account_matches and scope_matches and fallback_driver.driver_id not in seen_driver_ids:
|
||||
drivers.append(fallback_driver)
|
||||
|
||||
return drivers
|
||||
|
||||
@staticmethod
|
||||
def build_route_key_from_message(message: "SessionMessage") -> RouteKey:
|
||||
"""根据 ``SessionMessage`` 构造路由键。
|
||||
|
||||
Args:
|
||||
message: 内部会话消息对象。
|
||||
|
||||
Returns:
|
||||
RouteKey: 由消息内容提取出的规范化路由键。
|
||||
"""
|
||||
return RouteKeyFactory.from_session_message(message)
|
||||
|
||||
@staticmethod
|
||||
def build_route_key_from_message_dict(message_dict: Dict[str, Any]) -> RouteKey:
|
||||
"""根据消息字典构造路由键。
|
||||
|
||||
Args:
|
||||
message_dict: Host 与插件之间传输的消息字典。
|
||||
|
||||
Returns:
|
||||
RouteKey: 由消息字典提取出的规范化路由键。
|
||||
"""
|
||||
return RouteKeyFactory.from_message_dict(message_dict)
|
||||
|
||||
async def accept_inbound(self, envelope: InboundMessageEnvelope) -> bool:
|
||||
"""处理一条由驱动上报的入站封装。
|
||||
|
||||
Args:
|
||||
envelope: 由传输驱动产出的入站封装。
|
||||
|
||||
Returns:
|
||||
bool: 若消息被接受并继续转发给入站分发器,则返回 ``True``,
|
||||
否则返回 ``False``。
|
||||
"""
|
||||
|
||||
if not self._receive_route_table.has_binding_for_driver(envelope.route_key, envelope.driver_id):
|
||||
logger.info(
|
||||
f"忽略未登记到接收路由表的入站消息: route={envelope.route_key} "
|
||||
f"driver={envelope.driver_id}"
|
||||
)
|
||||
return False
|
||||
|
||||
if self._inbound_dispatcher is None:
|
||||
logger.debug("PlatformIOManager 尚未配置 inbound dispatcher,暂不继续分发")
|
||||
return False
|
||||
|
||||
dedupe_key = self._build_inbound_dedupe_key(envelope)
|
||||
if dedupe_key is not None:
|
||||
if not self._deduplicator.mark_seen(dedupe_key):
|
||||
logger.info(f"忽略重复入站消息: dedupe_key={dedupe_key}")
|
||||
return False
|
||||
|
||||
await self._inbound_dispatcher(envelope)
|
||||
return True
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
message: "SessionMessage",
|
||||
route_key: RouteKey,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> DeliveryBatch:
|
||||
"""通过 Broker 选中的全部发送驱动广播一条消息。
|
||||
|
||||
Args:
|
||||
message: 要投递的内部会话消息。
|
||||
route_key: 本次出站投递选择的路由键。
|
||||
metadata: 可选的额外 Broker 侧元数据。
|
||||
|
||||
Returns:
|
||||
DeliveryBatch: 规范化后的批量出站回执。
|
||||
"""
|
||||
drivers = self.resolve_drivers(route_key)
|
||||
if not drivers:
|
||||
return DeliveryBatch(internal_message_id=message.message_id, route_key=route_key)
|
||||
|
||||
receipts: List[DeliveryReceipt] = []
|
||||
for driver in drivers:
|
||||
try:
|
||||
self._outbound_tracker.begin_tracking(
|
||||
internal_message_id=message.message_id,
|
||||
route_key=route_key,
|
||||
driver_id=driver.driver_id,
|
||||
metadata=metadata,
|
||||
)
|
||||
except ValueError as exc:
|
||||
receipts.append(
|
||||
DeliveryReceipt(
|
||||
internal_message_id=message.message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.FAILED,
|
||||
driver_id=driver.driver_id,
|
||||
driver_kind=driver.descriptor.kind,
|
||||
error=str(exc),
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
receipt = await driver.send_message(message=message, route_key=route_key, metadata=metadata)
|
||||
except Exception as exc:
|
||||
receipt = DeliveryReceipt(
|
||||
internal_message_id=message.message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.FAILED,
|
||||
driver_id=driver.driver_id,
|
||||
driver_kind=driver.descriptor.kind,
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
self._outbound_tracker.finish_tracking(receipt)
|
||||
receipts.append(receipt)
|
||||
|
||||
return DeliveryBatch(
|
||||
internal_message_id=message.message_id,
|
||||
route_key=route_key,
|
||||
receipts=receipts,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_inbound_dedupe_key(envelope: InboundMessageEnvelope) -> Optional[str]:
|
||||
"""构造用于入站抑制的去重键。
|
||||
|
||||
Args:
|
||||
envelope: 当前正在处理的入站封装。
|
||||
|
||||
Returns:
|
||||
Optional[str]: 若可以构造稳定去重键则返回该键,否则返回 ``None``。
|
||||
|
||||
Notes:
|
||||
这里仅接受上游显式提供的稳定消息身份,例如 ``dedupe_key``、
|
||||
平台侧 ``external_message_id`` 或已经完成规范化的
|
||||
``session_message.message_id``。Broker 不再根据 ``payload`` 内容
|
||||
猜测语义去重键,避免把“短时间内两条内容刚好完全相同”的合法消息
|
||||
误判为重复入站。
|
||||
"""
|
||||
raw_dedupe_key = envelope.dedupe_key or envelope.external_message_id
|
||||
if raw_dedupe_key is None and envelope.session_message is not None:
|
||||
raw_dedupe_key = envelope.session_message.message_id
|
||||
if raw_dedupe_key is None:
|
||||
return None
|
||||
|
||||
normalized_dedupe_key = str(raw_dedupe_key).strip()
|
||||
if not normalized_dedupe_key:
|
||||
return None
|
||||
|
||||
return f"{envelope.driver_id}:{normalized_dedupe_key}"
|
||||
|
||||
@staticmethod
|
||||
def _validate_binding_against_driver(binding: RouteBinding, driver: PlatformIODriver) -> None:
|
||||
"""校验路由绑定与驱动描述是否一致。
|
||||
|
||||
Args:
|
||||
binding: 待校验的路由绑定。
|
||||
driver: 被绑定的驱动实例。
|
||||
|
||||
Raises:
|
||||
ValueError: 当绑定类型、平台或更细粒度路由维度与驱动描述冲突时抛出。
|
||||
"""
|
||||
descriptor = driver.descriptor
|
||||
if binding.driver_kind != descriptor.kind:
|
||||
raise ValueError(
|
||||
f"路由绑定的 driver_kind={binding.driver_kind} 与驱动 {driver.driver_id} 的类型 "
|
||||
f"{descriptor.kind} 不一致"
|
||||
)
|
||||
|
||||
if binding.route_key.platform != descriptor.platform:
|
||||
raise ValueError(
|
||||
f"路由绑定的平台 {binding.route_key.platform} 与驱动 {driver.driver_id} 的平台 "
|
||||
f"{descriptor.platform} 不一致"
|
||||
)
|
||||
|
||||
if descriptor.account_id is not None and binding.route_key.account_id not in (None, descriptor.account_id):
|
||||
raise ValueError(
|
||||
f"路由绑定的 account_id={binding.route_key.account_id} 与驱动 {driver.driver_id} 的 "
|
||||
f"account_id={descriptor.account_id} 冲突"
|
||||
)
|
||||
|
||||
if descriptor.scope is not None and binding.route_key.scope not in (None, descriptor.scope):
|
||||
raise ValueError(
|
||||
f"路由绑定的 scope={binding.route_key.scope} 与驱动 {driver.driver_id} 的 "
|
||||
f"scope={descriptor.scope} 冲突"
|
||||
)
|
||||
|
||||
|
||||
_platform_io_manager: Optional[PlatformIOManager] = None
|
||||
|
||||
|
||||
def get_platform_io_manager() -> PlatformIOManager:
|
||||
"""返回全局 ``PlatformIOManager`` 单例。
|
||||
|
||||
Returns:
|
||||
PlatformIOManager: 进程级共享的 Broker 管理器实例。
|
||||
"""
|
||||
|
||||
global _platform_io_manager
|
||||
if _platform_io_manager is None:
|
||||
_platform_io_manager = PlatformIOManager()
|
||||
return _platform_io_manager
|
||||
286
src/platform_io/outbound_tracker.py
Normal file
286
src/platform_io/outbound_tracker.py
Normal file
@@ -0,0 +1,286 @@
|
||||
"""跟踪 Platform IO 层的出站投递状态。
|
||||
|
||||
当前实现基于两组 ``dict + heapq``:
|
||||
- ``_pending`` 和 ``_pending_expire_heap`` 负责管理待完成的出站记录
|
||||
- ``_receipts_by_external_id`` 和 ``_receipt_expire_heap`` 负责管理已完成回执索引
|
||||
|
||||
这样就不需要在每次读写时全表扫描过期项,而是通过懒清理逐步弹出已经过期
|
||||
或已经失效的堆节点。
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import heapq
|
||||
import time
|
||||
|
||||
from .types import DeliveryReceipt, RouteKey
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class PendingOutboundRecord:
|
||||
"""表示一条仍在等待完成的出站投递记录。
|
||||
|
||||
Attributes:
|
||||
internal_message_id: 正在跟踪的内部 ``SessionMessage.message_id``。
|
||||
route_key: 该出站投递开始时使用的路由键。
|
||||
driver_id: 负责这次出站投递的驱动 ID。
|
||||
created_at: 开始跟踪时记录的单调时钟时间戳。
|
||||
expires_at: 该待完成记录预计过期的单调时钟时间戳。
|
||||
metadata: 与待完成记录一同保留的额外 Broker 侧元数据。
|
||||
"""
|
||||
|
||||
internal_message_id: str
|
||||
route_key: RouteKey
|
||||
driver_id: str
|
||||
created_at: float = field(default_factory=time.monotonic)
|
||||
expires_at: float = 0.0
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class StoredDeliveryReceipt:
|
||||
"""表示一条已完成并暂存的出站回执。
|
||||
|
||||
Attributes:
|
||||
receipt: 规范化后的出站投递回执。
|
||||
stored_at: 回执被写入索引时记录的单调时钟时间戳。
|
||||
expires_at: 该回执索引预计过期的单调时钟时间戳。
|
||||
"""
|
||||
|
||||
receipt: DeliveryReceipt
|
||||
stored_at: float = field(default_factory=time.monotonic)
|
||||
expires_at: float = 0.0
|
||||
|
||||
|
||||
class OutboundTracker:
|
||||
"""统一跟踪出站消息的 pending 状态与最终回执。
|
||||
|
||||
主要用于解决出站消息在发送过程中“状态散落在不同路径里”的问题:
|
||||
- 发送开始后,需要在最终回执返回前保留一份 pending 状态
|
||||
- 平台返回 ``external_message_id`` 后,需要保留一段时间的回执索引
|
||||
|
||||
当前实现使用 ``dict + heapq`` 做 TTL 管理:
|
||||
- ``dict`` 提供 ``O(1)`` 级别的主键查询
|
||||
- ``heapq`` 提供按过期时间排序的懒清理能力
|
||||
|
||||
这比“每次 begin/finish/get 都全表扫描”的实现更适合高吞吐出站场景。
|
||||
|
||||
Notes:
|
||||
复杂度说明如下,设 ``p`` 为当前有效 pending 数量,``r`` 为当前有效回执数量:
|
||||
|
||||
- ``begin_tracking()``、``finish_tracking()`` 的常见路径时间复杂度接近
|
||||
``O(log p)`` 或 ``O(log r)``
|
||||
- ``get_pending()``、``get_receipt_by_external_id()`` 的查询本身是 ``O(1)``
|
||||
,连同懒清理一起看,长期摊还复杂度接近 ``O(log n)``
|
||||
- 如果某次调用恰好触发一批过期节点的集中清理,则该次调用的最坏时间复杂度
|
||||
可达到 ``O(k log n)``,其中 ``k`` 为本次被弹出的节点数量
|
||||
- 空间复杂度为 ``O(p + r)``
|
||||
"""
|
||||
|
||||
def __init__(self, ttl_seconds: float = 1800.0) -> None:
|
||||
"""初始化出站跟踪器。
|
||||
|
||||
Args:
|
||||
ttl_seconds: 待完成记录与按外部消息 ID 建立的回执索引保留时长,
|
||||
单位为秒。
|
||||
|
||||
Raises:
|
||||
ValueError: 当 ``ttl_seconds`` 非正数时抛出。
|
||||
"""
|
||||
if ttl_seconds <= 0:
|
||||
raise ValueError("ttl_seconds 必须大于 0")
|
||||
|
||||
self._ttl_seconds = ttl_seconds
|
||||
self._pending: Dict[Tuple[str, str], PendingOutboundRecord] = {}
|
||||
self._pending_expire_heap: List[Tuple[float, str, str]] = []
|
||||
self._receipts_by_external_id: Dict[str, StoredDeliveryReceipt] = {}
|
||||
self._receipt_expire_heap: List[Tuple[float, str]] = []
|
||||
|
||||
@staticmethod
|
||||
def _build_pending_key(internal_message_id: str, driver_id: str) -> Tuple[str, str]:
|
||||
"""构造单条出站跟踪记录的唯一键。
|
||||
|
||||
Args:
|
||||
internal_message_id: 内部消息 ID。
|
||||
driver_id: 负责当前投递的驱动 ID。
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: ``(internal_message_id, driver_id)`` 组合键。
|
||||
"""
|
||||
return internal_message_id, driver_id
|
||||
|
||||
def begin_tracking(
|
||||
self,
|
||||
internal_message_id: str,
|
||||
route_key: RouteKey,
|
||||
driver_id: str,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> PendingOutboundRecord:
|
||||
"""开始跟踪一次出站投递。
|
||||
|
||||
Args:
|
||||
internal_message_id: 正在投递的内部消息 ID。
|
||||
route_key: 这次出站投递选择的路由键。
|
||||
driver_id: 负责本次投递的驱动 ID。
|
||||
metadata: 可选的额外元数据,会一并保存在待完成记录中。
|
||||
|
||||
Returns:
|
||||
PendingOutboundRecord: 新创建的待完成记录。
|
||||
|
||||
Raises:
|
||||
ValueError: 当同一个 ``internal_message_id`` 与 ``driver_id`` 组合已经存在
|
||||
未完成记录时抛出。
|
||||
"""
|
||||
now = time.monotonic()
|
||||
self._cleanup_expired(now)
|
||||
pending_key = self._build_pending_key(internal_message_id, driver_id)
|
||||
|
||||
if pending_key in self._pending:
|
||||
raise ValueError(f"消息 {internal_message_id} 在驱动 {driver_id} 上已存在未完成的出站跟踪记录")
|
||||
|
||||
expires_at = now + self._ttl_seconds
|
||||
record = PendingOutboundRecord(
|
||||
internal_message_id=internal_message_id,
|
||||
route_key=route_key,
|
||||
driver_id=driver_id,
|
||||
created_at=now,
|
||||
expires_at=expires_at,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
self._pending[pending_key] = record
|
||||
heapq.heappush(self._pending_expire_heap, (expires_at, internal_message_id, driver_id))
|
||||
return record
|
||||
|
||||
def finish_tracking(self, receipt: DeliveryReceipt) -> Optional[PendingOutboundRecord]:
|
||||
"""使用最终回执结束一条出站跟踪。
|
||||
|
||||
Args:
|
||||
receipt: 规范化后的最终投递回执。
|
||||
|
||||
Returns:
|
||||
Optional[PendingOutboundRecord]: 若此前存在待完成记录,则返回该记录。
|
||||
"""
|
||||
now = time.monotonic()
|
||||
self._cleanup_expired(now)
|
||||
|
||||
pending_record: Optional[PendingOutboundRecord] = None
|
||||
if receipt.driver_id:
|
||||
pending_key = self._build_pending_key(receipt.internal_message_id, receipt.driver_id)
|
||||
pending_record = self._pending.pop(pending_key, None)
|
||||
else:
|
||||
matched_records = [
|
||||
key
|
||||
for key, record in self._pending.items()
|
||||
if record.internal_message_id == receipt.internal_message_id
|
||||
]
|
||||
if len(matched_records) == 1:
|
||||
pending_record = self._pending.pop(matched_records[0], None)
|
||||
|
||||
if receipt.external_message_id:
|
||||
expires_at = now + self._ttl_seconds
|
||||
self._receipts_by_external_id[receipt.external_message_id] = StoredDeliveryReceipt(
|
||||
receipt=receipt,
|
||||
stored_at=now,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
heapq.heappush(self._receipt_expire_heap, (expires_at, receipt.external_message_id))
|
||||
return pending_record
|
||||
|
||||
def get_pending(
|
||||
self,
|
||||
internal_message_id: str,
|
||||
driver_id: Optional[str] = None,
|
||||
) -> Optional[PendingOutboundRecord]:
|
||||
"""根据内部消息 ID 查询待完成记录。
|
||||
|
||||
Args:
|
||||
internal_message_id: 要查询的内部消息 ID。
|
||||
driver_id: 可选的驱动 ID;提供后仅返回该驱动上的待完成记录。
|
||||
|
||||
Returns:
|
||||
Optional[PendingOutboundRecord]: 若记录仍存在,则返回对应待完成记录。
|
||||
"""
|
||||
self._cleanup_expired(time.monotonic())
|
||||
|
||||
if driver_id:
|
||||
return self._pending.get(self._build_pending_key(internal_message_id, driver_id))
|
||||
|
||||
matched_records = [
|
||||
record
|
||||
for record in self._pending.values()
|
||||
if record.internal_message_id == internal_message_id
|
||||
]
|
||||
if len(matched_records) == 1:
|
||||
return matched_records[0]
|
||||
return None
|
||||
|
||||
def get_receipt_by_external_id(self, external_message_id: str) -> Optional[DeliveryReceipt]:
|
||||
"""根据外部平台消息 ID 查询已完成回执。
|
||||
|
||||
Args:
|
||||
external_message_id: 要查询的平台侧消息 ID。
|
||||
|
||||
Returns:
|
||||
Optional[DeliveryReceipt]: 若存在对应回执,则返回该回执。
|
||||
"""
|
||||
self._cleanup_expired(time.monotonic())
|
||||
stored_receipt = self._receipts_by_external_id.get(external_message_id)
|
||||
return stored_receipt.receipt if stored_receipt else None
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清空全部待完成记录与已保存回执。"""
|
||||
self._pending.clear()
|
||||
self._pending_expire_heap.clear()
|
||||
self._receipts_by_external_id.clear()
|
||||
self._receipt_expire_heap.clear()
|
||||
|
||||
def _cleanup_expired(self, now: float) -> None:
|
||||
"""清理内存中已经过期的待完成记录与已保存回执。
|
||||
|
||||
Args:
|
||||
now: 当前单调时钟时间戳。
|
||||
"""
|
||||
self._cleanup_expired_pending(now)
|
||||
self._cleanup_expired_receipts(now)
|
||||
|
||||
def _cleanup_expired_pending(self, now: float) -> None:
|
||||
"""清理已经过期的待完成记录。
|
||||
|
||||
Args:
|
||||
now: 当前单调时钟时间戳。
|
||||
|
||||
Notes:
|
||||
堆中可能存在已经失效的旧节点。例如某条记录提前 ``finish`` 后,
|
||||
它原本的过期节点仍可能留在堆里。这里会通过和 ``dict`` 中当前记录的
|
||||
``expires_at`` 对比,跳过这类旧节点。
|
||||
"""
|
||||
while self._pending_expire_heap and self._pending_expire_heap[0][0] <= now:
|
||||
expires_at, internal_message_id, driver_id = heapq.heappop(self._pending_expire_heap)
|
||||
pending_key = self._build_pending_key(internal_message_id, driver_id)
|
||||
current_record = self._pending.get(pending_key)
|
||||
if current_record is None:
|
||||
continue
|
||||
if current_record.expires_at != expires_at:
|
||||
continue
|
||||
self._pending.pop(pending_key, None)
|
||||
|
||||
def _cleanup_expired_receipts(self, now: float) -> None:
|
||||
"""清理已经过期的回执索引。
|
||||
|
||||
Args:
|
||||
now: 当前单调时钟时间戳。
|
||||
|
||||
Notes:
|
||||
同一个 ``external_message_id`` 在极端情况下可能被重复写入索引,
|
||||
因此这里同样需要通过 ``expires_at`` 和当前 ``dict`` 中的值比对,
|
||||
跳过已经失效的旧堆节点。
|
||||
"""
|
||||
while self._receipt_expire_heap and self._receipt_expire_heap[0][0] <= now:
|
||||
expires_at, external_message_id = heapq.heappop(self._receipt_expire_heap)
|
||||
current_receipt = self._receipts_by_external_id.get(external_message_id)
|
||||
if current_receipt is None:
|
||||
continue
|
||||
if current_receipt.expires_at != expires_at:
|
||||
continue
|
||||
self._receipts_by_external_id.pop(external_message_id, None)
|
||||
70
src/platform_io/registry.py
Normal file
70
src/platform_io/registry.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""提供 Platform IO 的驱动注册与查询能力。"""
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from src.platform_io.drivers.base import PlatformIODriver
|
||||
from src.platform_io.types import DriverKind
|
||||
|
||||
|
||||
class DriverRegistry:
|
||||
"""集中保存已注册的 Platform IO 驱动,并提供基础查询接口。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化一个空的驱动注册表。"""
|
||||
self._drivers: Dict[str, PlatformIODriver] = {}
|
||||
|
||||
def register(self, driver: PlatformIODriver) -> None:
|
||||
"""注册一个驱动实例。
|
||||
|
||||
Args:
|
||||
driver: 要注册的驱动实例。
|
||||
|
||||
Raises:
|
||||
ValueError: 当驱动 ID 已经存在时抛出。
|
||||
"""
|
||||
if driver.driver_id in self._drivers:
|
||||
raise ValueError(f"驱动 {driver.driver_id} 已注册")
|
||||
self._drivers[driver.driver_id] = driver
|
||||
|
||||
def unregister(self, driver_id: str) -> Optional[PlatformIODriver]:
|
||||
"""按驱动 ID 注销一个驱动。
|
||||
|
||||
Args:
|
||||
driver_id: 要移除的驱动 ID。
|
||||
|
||||
Returns:
|
||||
Optional[PlatformIODriver]: 若驱动存在,则返回被移除的驱动实例。
|
||||
"""
|
||||
return self._drivers.pop(driver_id, None)
|
||||
|
||||
def get(self, driver_id: str) -> Optional[PlatformIODriver]:
|
||||
"""按驱动 ID 获取驱动实例。
|
||||
|
||||
Args:
|
||||
driver_id: 要查询的驱动 ID。
|
||||
|
||||
Returns:
|
||||
Optional[PlatformIODriver]: 若存在匹配驱动,则返回该驱动实例。
|
||||
"""
|
||||
return self._drivers.get(driver_id)
|
||||
|
||||
def list(self, *, kind: Optional[DriverKind] = None, platform: Optional[str] = None) -> List[PlatformIODriver]:
|
||||
"""列出已注册驱动,并支持可选过滤。
|
||||
|
||||
Args:
|
||||
kind: 可选的驱动类型过滤条件。
|
||||
platform: 可选的平台名称过滤条件。
|
||||
|
||||
Returns:
|
||||
List[PlatformIODriver]: 符合过滤条件的驱动列表。
|
||||
"""
|
||||
drivers = list(self._drivers.values())
|
||||
if kind is not None:
|
||||
drivers = [driver for driver in drivers if driver.descriptor.kind == kind]
|
||||
if platform is not None:
|
||||
drivers = [driver for driver in drivers if driver.descriptor.platform == platform]
|
||||
return drivers
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清空全部已注册驱动。"""
|
||||
self._drivers.clear()
|
||||
150
src/platform_io/route_key_factory.py
Normal file
150
src/platform_io/route_key_factory.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""提供 Platform IO 路由键的统一提取与构造能力。
|
||||
|
||||
这层的目标不是直接接入具体消息链,而是先把“未来接线时用什么字段构造
|
||||
RouteKey”约定下来,避免 legacy 和 plugin 两条链路各自发明一套隐式规则。
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
|
||||
|
||||
from .types import RouteKey
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
|
||||
|
||||
class RouteKeyFactory:
|
||||
"""统一构造 ``RouteKey`` 的工厂。
|
||||
|
||||
当前约定会优先从消息字典顶层、``message_info``、``additional_config`` 或传入 metadata 中提取
|
||||
以下字段:
|
||||
|
||||
- account_id: ``platform_io_account_id`` / ``account_id`` / ``self_id`` / ``bot_account``
|
||||
- scope: ``platform_io_scope`` / ``route_scope`` / ``adapter_scope`` / ``connection_id``
|
||||
|
||||
这样即使上游主链暂时还没有正式的 ``self_id`` 字段,中间层也能先统一
|
||||
约定提取口径,等具体消息链接入时直接复用。
|
||||
"""
|
||||
|
||||
ACCOUNT_ID_KEYS = (
|
||||
"platform_io_account_id",
|
||||
"account_id",
|
||||
"self_id",
|
||||
"bot_account",
|
||||
)
|
||||
SCOPE_KEYS = (
|
||||
"platform_io_scope",
|
||||
"route_scope",
|
||||
"adapter_scope",
|
||||
"connection_id",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_platform(
|
||||
cls,
|
||||
platform: str,
|
||||
*,
|
||||
account_id: Optional[str] = None,
|
||||
scope: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> RouteKey:
|
||||
"""根据平台名和可选 metadata 构造 ``RouteKey``。
|
||||
|
||||
Args:
|
||||
platform: 平台名称。
|
||||
account_id: 显式传入的账号 ID;若为空,则尝试从 metadata 提取。
|
||||
scope: 显式传入的路由作用域;若为空,则尝试从 metadata 提取。
|
||||
metadata: 可选的元数据字典。
|
||||
|
||||
Returns:
|
||||
RouteKey: 构造出的规范化路由键。
|
||||
"""
|
||||
extracted_account_id, extracted_scope = cls.extract_components(metadata)
|
||||
return RouteKey(
|
||||
platform=platform,
|
||||
account_id=account_id or extracted_account_id,
|
||||
scope=scope or extracted_scope,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_message_dict(cls, message_dict: Dict[str, Any]) -> RouteKey:
|
||||
"""从消息字典中提取 ``RouteKey``。
|
||||
|
||||
Args:
|
||||
message_dict: Host 与插件之间传输的消息字典。
|
||||
|
||||
Returns:
|
||||
RouteKey: 构造出的规范化路由键。
|
||||
|
||||
Raises:
|
||||
ValueError: 当消息字典缺少有效 ``platform`` 字段时抛出。
|
||||
"""
|
||||
platform = str(message_dict.get("platform") or "").strip()
|
||||
if not platform:
|
||||
raise ValueError("消息字典缺少有效的 platform 字段,无法构造 RouteKey")
|
||||
|
||||
message_info = message_dict.get("message_info", {})
|
||||
additional_config = {}
|
||||
if isinstance(message_info, dict):
|
||||
raw_additional_config = message_info.get("additional_config", {})
|
||||
if isinstance(raw_additional_config, dict):
|
||||
additional_config = raw_additional_config
|
||||
|
||||
explicit_account_id, explicit_scope = cls.extract_components(message_dict)
|
||||
message_info_account_id, message_info_scope = cls.extract_components(message_info)
|
||||
metadata_account_id, metadata_scope = cls.extract_components(additional_config)
|
||||
return RouteKey(
|
||||
platform=platform,
|
||||
account_id=explicit_account_id or message_info_account_id or metadata_account_id,
|
||||
scope=explicit_scope or message_info_scope or metadata_scope,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_session_message(cls, message: "SessionMessage") -> RouteKey:
|
||||
"""从 ``SessionMessage`` 中提取 ``RouteKey``。
|
||||
|
||||
Args:
|
||||
message: 内部会话消息对象。
|
||||
|
||||
Returns:
|
||||
RouteKey: 构造出的规范化路由键。
|
||||
"""
|
||||
additional_config = message.message_info.additional_config or {}
|
||||
metadata = additional_config if isinstance(additional_config, dict) else {}
|
||||
return cls.from_platform(message.platform, metadata=metadata)
|
||||
|
||||
@classmethod
|
||||
def extract_components(cls, mapping: Optional[Dict[str, Any]]) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""从任意字典中提取 ``account_id`` 与 ``scope``。
|
||||
|
||||
Args:
|
||||
mapping: 待提取的字典;若为空或不是字典,则返回空结果。
|
||||
|
||||
Returns:
|
||||
Tuple[Optional[str], Optional[str]]: ``(account_id, scope)``。
|
||||
"""
|
||||
if not mapping or not isinstance(mapping, dict):
|
||||
return None, None
|
||||
|
||||
account_id = cls._pick_string(mapping, cls.ACCOUNT_ID_KEYS)
|
||||
scope = cls._pick_string(mapping, cls.SCOPE_KEYS)
|
||||
return account_id, scope
|
||||
|
||||
@staticmethod
|
||||
def _pick_string(mapping: Dict[str, Any], keys: Tuple[str, ...]) -> Optional[str]:
|
||||
"""按优先级从字典里挑选第一个有效字符串。
|
||||
|
||||
Args:
|
||||
mapping: 待查询的字典。
|
||||
keys: 按优先级排列的候选键名。
|
||||
|
||||
Returns:
|
||||
Optional[str]: 第一个规范化后非空的字符串值;若不存在则返回 ``None``。
|
||||
"""
|
||||
for key in keys:
|
||||
value = mapping.get(key)
|
||||
if value is None:
|
||||
continue
|
||||
normalized = str(value).strip()
|
||||
if normalized:
|
||||
return normalized
|
||||
return None
|
||||
141
src/platform_io/routing.py
Normal file
141
src/platform_io/routing.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""提供 Platform IO 的轻量路由绑定表。"""
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from .types import RouteBinding, RouteKey
|
||||
|
||||
|
||||
class RouteTable:
|
||||
"""维护单张路由绑定表。
|
||||
|
||||
该实现不负责裁决“唯一 owner”,只负责保存绑定,并按
|
||||
``RouteKey.resolution_order()`` 解析出候选绑定列表。
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化空路由绑定表。"""
|
||||
|
||||
self._bindings: Dict[RouteKey, Dict[str, RouteBinding]] = {}
|
||||
|
||||
def bind(self, binding: RouteBinding) -> None:
|
||||
"""注册或更新一条路由绑定。
|
||||
|
||||
Args:
|
||||
binding: 要保存的路由绑定。
|
||||
"""
|
||||
|
||||
self._bindings.setdefault(binding.route_key, {})[binding.driver_id] = binding
|
||||
|
||||
def unbind(self, route_key: RouteKey, driver_id: Optional[str] = None) -> List[RouteBinding]:
|
||||
"""移除指定路由键上的绑定。
|
||||
|
||||
Args:
|
||||
route_key: 要移除绑定的路由键。
|
||||
driver_id: 可选的驱动 ID;为空时移除该路由键下全部绑定。
|
||||
|
||||
Returns:
|
||||
List[RouteBinding]: 被移除的绑定列表。
|
||||
"""
|
||||
|
||||
binding_map = self._bindings.get(route_key)
|
||||
if not binding_map:
|
||||
return []
|
||||
|
||||
if driver_id is None:
|
||||
removed = list(binding_map.values())
|
||||
self._bindings.pop(route_key, None)
|
||||
return self._sort_bindings(removed)
|
||||
|
||||
removed_binding = binding_map.pop(driver_id, None)
|
||||
if not binding_map:
|
||||
self._bindings.pop(route_key, None)
|
||||
return [removed_binding] if removed_binding is not None else []
|
||||
|
||||
def remove_bindings_by_driver(self, driver_id: str) -> List[RouteBinding]:
|
||||
"""移除某个驱动在整张表上的全部绑定。
|
||||
|
||||
Args:
|
||||
driver_id: 要移除绑定的驱动 ID。
|
||||
|
||||
Returns:
|
||||
List[RouteBinding]: 被移除的绑定列表。
|
||||
"""
|
||||
|
||||
removed_bindings: List[RouteBinding] = []
|
||||
empty_route_keys: List[RouteKey] = []
|
||||
for route_key, binding_map in self._bindings.items():
|
||||
removed_binding = binding_map.pop(driver_id, None)
|
||||
if removed_binding is not None:
|
||||
removed_bindings.append(removed_binding)
|
||||
if not binding_map:
|
||||
empty_route_keys.append(route_key)
|
||||
|
||||
for route_key in empty_route_keys:
|
||||
self._bindings.pop(route_key, None)
|
||||
|
||||
return self._sort_bindings(removed_bindings)
|
||||
|
||||
def list_bindings(self, route_key: Optional[RouteKey] = None) -> List[RouteBinding]:
|
||||
"""列出当前路由表中的绑定。
|
||||
|
||||
Args:
|
||||
route_key: 可选的路由键过滤条件。
|
||||
|
||||
Returns:
|
||||
List[RouteBinding]: 当前绑定列表。
|
||||
"""
|
||||
|
||||
if route_key is None:
|
||||
bindings: List[RouteBinding] = []
|
||||
for binding_map in self._bindings.values():
|
||||
bindings.extend(binding_map.values())
|
||||
return self._sort_bindings(bindings)
|
||||
|
||||
binding_map = self._bindings.get(route_key, {})
|
||||
return self._sort_bindings(list(binding_map.values()))
|
||||
|
||||
def resolve_bindings(self, route_key: RouteKey) -> List[RouteBinding]:
|
||||
"""按从具体到宽泛的顺序解析路由候选绑定。
|
||||
|
||||
Args:
|
||||
route_key: 待解析的路由键。
|
||||
|
||||
Returns:
|
||||
List[RouteBinding]: 去重后的候选绑定列表。
|
||||
"""
|
||||
|
||||
resolved_bindings: List[RouteBinding] = []
|
||||
seen_driver_ids: set[str] = set()
|
||||
for candidate_key in route_key.resolution_order():
|
||||
for binding in self.list_bindings(candidate_key):
|
||||
if binding.driver_id in seen_driver_ids:
|
||||
continue
|
||||
seen_driver_ids.add(binding.driver_id)
|
||||
resolved_bindings.append(binding)
|
||||
return resolved_bindings
|
||||
|
||||
def has_binding_for_driver(self, route_key: RouteKey, driver_id: str) -> bool:
|
||||
"""判断指定驱动是否在当前路由键解析结果中。
|
||||
|
||||
Args:
|
||||
route_key: 待解析的路由键。
|
||||
driver_id: 目标驱动 ID。
|
||||
|
||||
Returns:
|
||||
bool: 若驱动存在于解析结果中则返回 ``True``。
|
||||
"""
|
||||
|
||||
return any(binding.driver_id == driver_id for binding in self.resolve_bindings(route_key))
|
||||
|
||||
@staticmethod
|
||||
def _sort_bindings(bindings: List[RouteBinding]) -> List[RouteBinding]:
|
||||
"""按优先级降序排列绑定列表。
|
||||
|
||||
Args:
|
||||
bindings: 待排序的绑定列表。
|
||||
|
||||
Returns:
|
||||
List[RouteBinding]: 排序后的绑定列表。
|
||||
"""
|
||||
|
||||
return sorted(bindings, key=lambda item: item.priority, reverse=True)
|
||||
264
src/platform_io/types.py
Normal file
264
src/platform_io/types.py
Normal file
@@ -0,0 +1,264 @@
|
||||
"""定义 Platform IO 中间层共享的核心类型。
|
||||
|
||||
本模块放置路由、驱动、入站与出站等规范化数据结构,供 Broker
|
||||
层在 legacy 适配器链路和 plugin 适配器链路之间复用。
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
|
||||
|
||||
class DriverKind(str, Enum):
|
||||
"""底层收发驱动类型枚举。"""
|
||||
|
||||
LEGACY = "legacy"
|
||||
PLUGIN = "plugin"
|
||||
|
||||
|
||||
class DeliveryStatus(str, Enum):
|
||||
"""统一出站回执状态枚举。"""
|
||||
|
||||
PENDING = "pending"
|
||||
SENT = "sent"
|
||||
FAILED = "failed"
|
||||
DROPPED = "dropped"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class RouteKey:
|
||||
"""用于 Platform IO 路由决策的唯一键。
|
||||
|
||||
路由解析会按照“从最具体到最宽泛”的顺序进行回退,这样同一平台
|
||||
后续就能自然支持按账号、自定义 scope 等更细粒度的归属控制。
|
||||
|
||||
Attributes:
|
||||
platform: 平台名称,例如 ``qq``。
|
||||
account_id: 机器人账号 ID 或 self ID,用于区分同平台多身份。
|
||||
scope: 额外路由作用域,预留给未来的连接实例、租户或子通道等维度。
|
||||
"""
|
||||
|
||||
platform: str
|
||||
account_id: Optional[str] = None
|
||||
scope: Optional[str] = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""规范化并校验路由键字段。
|
||||
|
||||
Raises:
|
||||
ValueError: 当 ``platform`` 规范化后为空时抛出。
|
||||
"""
|
||||
platform = str(self.platform).strip()
|
||||
account_id = str(self.account_id).strip() if self.account_id is not None else None
|
||||
scope = str(self.scope).strip() if self.scope is not None else None
|
||||
|
||||
if not platform:
|
||||
raise ValueError("RouteKey.platform 不能为空")
|
||||
|
||||
object.__setattr__(self, "platform", platform)
|
||||
object.__setattr__(self, "account_id", account_id or None)
|
||||
object.__setattr__(self, "scope", scope or None)
|
||||
|
||||
def resolution_order(self) -> List["RouteKey"]:
|
||||
"""返回从最具体到最宽泛的路由匹配顺序。
|
||||
|
||||
Returns:
|
||||
List[RouteKey]: 按回退优先级排序的候选路由键列表。
|
||||
"""
|
||||
|
||||
keys: List[RouteKey] = [self]
|
||||
|
||||
if self.account_id is not None and self.scope is not None:
|
||||
keys.append(RouteKey(platform=self.platform, account_id=self.account_id, scope=None))
|
||||
keys.append(RouteKey(platform=self.platform, account_id=None, scope=self.scope))
|
||||
elif self.account_id is not None:
|
||||
keys.append(RouteKey(platform=self.platform, account_id=None, scope=None))
|
||||
elif self.scope is not None:
|
||||
keys.append(RouteKey(platform=self.platform, account_id=None, scope=None))
|
||||
|
||||
default_key = RouteKey(platform=self.platform, account_id=None, scope=None)
|
||||
if default_key not in keys:
|
||||
keys.append(default_key)
|
||||
|
||||
return keys
|
||||
|
||||
def to_dedupe_scope(self) -> str:
|
||||
"""生成跨驱动共享的去重作用域字符串。
|
||||
|
||||
Returns:
|
||||
str: 用于入站消息去重的稳定文本作用域键。
|
||||
"""
|
||||
|
||||
account_id = self.account_id or "*"
|
||||
scope = self.scope or "*"
|
||||
return f"{self.platform}:{account_id}:{scope}"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class DriverDescriptor:
|
||||
"""描述一个已注册的 Platform IO 驱动。
|
||||
|
||||
Attributes:
|
||||
driver_id: Broker 层内全局唯一的驱动标识。
|
||||
kind: 驱动实现类型,例如 legacy 或 plugin。
|
||||
platform: 驱动负责的平台名称。
|
||||
account_id: 可选的账号 ID 或 self ID。
|
||||
scope: 可选的额外路由作用域。
|
||||
plugin_id: 当驱动来自插件适配器时,对应的插件 ID。
|
||||
metadata: 预留给路由策略或观测能力的额外驱动元数据。
|
||||
"""
|
||||
|
||||
driver_id: str
|
||||
kind: DriverKind
|
||||
platform: str
|
||||
account_id: Optional[str] = None
|
||||
scope: Optional[str] = None
|
||||
plugin_id: Optional[str] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""规范化并校验驱动描述字段。
|
||||
|
||||
Raises:
|
||||
ValueError: 当 ``driver_id`` 或 ``platform`` 规范化后为空时抛出。
|
||||
"""
|
||||
driver_id = str(self.driver_id).strip()
|
||||
platform = str(self.platform).strip()
|
||||
plugin_id = str(self.plugin_id).strip() if self.plugin_id is not None else None
|
||||
|
||||
if not driver_id:
|
||||
raise ValueError("DriverDescriptor.driver_id 不能为空")
|
||||
if not platform:
|
||||
raise ValueError("DriverDescriptor.platform 不能为空")
|
||||
|
||||
object.__setattr__(self, "driver_id", driver_id)
|
||||
object.__setattr__(self, "platform", platform)
|
||||
object.__setattr__(self, "plugin_id", plugin_id or None)
|
||||
|
||||
@property
|
||||
def route_key(self) -> RouteKey:
|
||||
"""构造该驱动默认代表的路由键。
|
||||
|
||||
Returns:
|
||||
RouteKey: 当前驱动描述对应的规范化路由键。
|
||||
"""
|
||||
return RouteKey(platform=self.platform, account_id=self.account_id, scope=self.scope)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class RouteBinding:
|
||||
"""表示一条从路由键到驱动的绑定关系。
|
||||
|
||||
Attributes:
|
||||
route_key: 该绑定覆盖的路由键。
|
||||
driver_id: 拥有该路由的驱动 ID。
|
||||
driver_kind: 绑定驱动的类型。
|
||||
priority: 当同一路由键存在多条绑定时使用的相对优先级。
|
||||
metadata: 预留给未来路由策略的额外绑定元数据。
|
||||
"""
|
||||
|
||||
route_key: RouteKey
|
||||
driver_id: str
|
||||
driver_kind: DriverKind
|
||||
priority: int = 0
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""规范化并校验绑定字段。
|
||||
|
||||
Raises:
|
||||
ValueError: 当 ``driver_id`` 规范化后为空时抛出。
|
||||
"""
|
||||
driver_id = str(self.driver_id).strip()
|
||||
if not driver_id:
|
||||
raise ValueError("RouteBinding.driver_id 不能为空")
|
||||
object.__setattr__(self, "driver_id", driver_id)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class InboundMessageEnvelope:
|
||||
"""封装一次由驱动产出的规范化入站消息。
|
||||
|
||||
Attributes:
|
||||
route_key: 该入站消息解析出的路由键。
|
||||
driver_id: 产出该消息的驱动 ID。
|
||||
driver_kind: 产出该消息的驱动类型。
|
||||
external_message_id: 可选的平台侧消息 ID,用于去重。
|
||||
dedupe_key: 可选的显式去重键。当外部消息没有稳定 ``message_id`` 时,
|
||||
可由上游驱动提供稳定的技术性幂等键。若这里为空,中间层仅会继续
|
||||
回退到 ``external_message_id`` 或 ``session_message.message_id``,
|
||||
不会再根据 ``payload`` 内容猜测语义去重键。
|
||||
session_message: 可选的、已经完成规范化的 ``SessionMessage`` 对象。
|
||||
payload: 可选的原始字典载荷,供延迟转换或调试使用。
|
||||
metadata: 额外入站元数据,例如连接信息或追踪上下文。
|
||||
"""
|
||||
|
||||
route_key: RouteKey
|
||||
driver_id: str
|
||||
driver_kind: DriverKind
|
||||
external_message_id: Optional[str] = None
|
||||
dedupe_key: Optional[str] = None
|
||||
session_message: Optional["SessionMessage"] = None
|
||||
payload: Optional[Dict[str, Any]] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class DeliveryReceipt:
|
||||
"""表示一次出站投递尝试的统一结果。
|
||||
|
||||
Attributes:
|
||||
internal_message_id: Broker 跟踪的内部 ``SessionMessage.message_id``。
|
||||
route_key: 本次投递使用的路由键。
|
||||
status: 规范化后的投递状态。
|
||||
driver_id: 实际处理该投递的驱动 ID,可为空。
|
||||
driver_kind: 实际处理该投递的驱动类型,可为空。
|
||||
external_message_id: 驱动或适配器返回的平台侧消息 ID,可为空。
|
||||
error: 投递失败时的错误信息,可为空。
|
||||
metadata: 预留给回执、时间戳或平台特有信息的额外元数据。
|
||||
"""
|
||||
|
||||
internal_message_id: str
|
||||
route_key: RouteKey
|
||||
status: DeliveryStatus
|
||||
driver_id: Optional[str] = None
|
||||
driver_kind: Optional[DriverKind] = None
|
||||
external_message_id: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class DeliveryBatch:
|
||||
"""表示一次广播式出站投递的批量结果。
|
||||
|
||||
Attributes:
|
||||
internal_message_id: 内部消息 ID。
|
||||
route_key: 本次投递使用的路由键。
|
||||
receipts: 各条路由的独立投递回执列表。
|
||||
"""
|
||||
|
||||
internal_message_id: str
|
||||
route_key: RouteKey
|
||||
receipts: List[DeliveryReceipt] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def sent_receipts(self) -> List[DeliveryReceipt]:
|
||||
"""返回全部发送成功的回执。"""
|
||||
|
||||
return [receipt for receipt in self.receipts if receipt.status == DeliveryStatus.SENT]
|
||||
|
||||
@property
|
||||
def failed_receipts(self) -> List[DeliveryReceipt]:
|
||||
"""返回全部发送失败的回执。"""
|
||||
|
||||
return [receipt for receipt in self.receipts if receipt.status != DeliveryStatus.SENT]
|
||||
|
||||
@property
|
||||
def has_success(self) -> bool:
|
||||
"""返回当前批量投递是否至少命中一条成功回执。"""
|
||||
|
||||
return bool(self.sent_receipts)
|
||||
@@ -16,3 +16,9 @@ ENV_PLUGIN_DIRS = "MAIBOT_PLUGIN_DIRS"
|
||||
|
||||
ENV_HOST_VERSION = "MAIBOT_HOST_VERSION"
|
||||
"""Runner 读取的 Host 应用版本号,用于 manifest 兼容性校验"""
|
||||
|
||||
ENV_EXTERNAL_PLUGIN_IDS = "MAIBOT_EXTERNAL_PLUGIN_IDS"
|
||||
"""Runner 启动时可视为已满足的外部插件依赖版本映射(JSON 对象)"""
|
||||
|
||||
ENV_GLOBAL_CONFIG_SNAPSHOT = "MAIBOT_GLOBAL_CONFIG_SNAPSHOT"
|
||||
"""Runner 启动时注入的全局配置快照(JSON 对象)"""
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Protocol
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Protocol, Sequence
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("plugin_runtime.integration")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.plugin_runtime.host.component_registry import RegisteredComponent
|
||||
from src.plugin_runtime.host.api_registry import APIEntry
|
||||
from src.plugin_runtime.host.component_registry import ComponentEntry
|
||||
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
||||
|
||||
|
||||
@@ -14,18 +15,311 @@ class _RuntimeComponentManagerProtocol(Protocol):
|
||||
@property
|
||||
def supervisors(self) -> List["PluginSupervisor"]: ...
|
||||
|
||||
def _normalize_component_type(self, component_type: str) -> str: ...
|
||||
|
||||
def _is_api_component_type(self, component_type: str) -> bool: ...
|
||||
|
||||
def _serialize_api_entry(self, entry: "APIEntry") -> Dict[str, Any]: ...
|
||||
|
||||
def _serialize_api_component_entry(self, entry: "APIEntry") -> Dict[str, Any]: ...
|
||||
|
||||
def _is_api_visible_to_plugin(self, entry: "APIEntry", caller_plugin_id: str) -> bool: ...
|
||||
|
||||
def _normalize_api_reference(self, api_name: str, version: str = "") -> tuple[str, str]: ...
|
||||
|
||||
def _build_api_unavailable_error(self, entry: "APIEntry") -> str: ...
|
||||
|
||||
def _get_supervisor_for_plugin(self, plugin_id: str) -> Optional["PluginSupervisor"]: ...
|
||||
|
||||
def _resolve_api_target(
|
||||
self,
|
||||
caller_plugin_id: str,
|
||||
api_name: str,
|
||||
version: str = "",
|
||||
) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]: ...
|
||||
|
||||
def _resolve_api_toggle_target(
|
||||
self,
|
||||
name: str,
|
||||
version: str = "",
|
||||
) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]: ...
|
||||
|
||||
def _resolve_component_toggle_target(
|
||||
self, name: str, component_type: str
|
||||
) -> tuple[Optional["RegisteredComponent"], Optional[str]]: ...
|
||||
) -> tuple[Optional["ComponentEntry"], Optional[str]]: ...
|
||||
|
||||
def _find_duplicate_plugin_ids(self, plugin_dirs: List[Path]) -> Dict[str, List[Path]]: ...
|
||||
|
||||
def _iter_plugin_dirs(self) -> Iterable[Path]: ...
|
||||
|
||||
async def load_plugin_globally(self, plugin_id: str, reason: str = "manual") -> bool: ...
|
||||
|
||||
async def reload_plugins_globally(self, plugin_ids: Sequence[str], reason: str = "manual") -> bool: ...
|
||||
|
||||
|
||||
class RuntimeComponentCapabilityMixin:
|
||||
@staticmethod
|
||||
def _normalize_component_type(component_type: str) -> str:
|
||||
"""规范化组件类型名称。
|
||||
|
||||
Args:
|
||||
component_type: 原始组件类型。
|
||||
|
||||
Returns:
|
||||
str: 统一转为大写后的组件类型名。
|
||||
"""
|
||||
|
||||
return str(component_type or "").strip().upper()
|
||||
|
||||
@classmethod
|
||||
def _is_api_component_type(cls, component_type: str) -> bool:
|
||||
"""判断组件类型是否为 API。
|
||||
|
||||
Args:
|
||||
component_type: 原始组件类型。
|
||||
|
||||
Returns:
|
||||
bool: 是否为 API 组件类型。
|
||||
"""
|
||||
|
||||
return cls._normalize_component_type(component_type) == "API"
|
||||
|
||||
@staticmethod
|
||||
def _serialize_api_entry(entry: "APIEntry") -> Dict[str, Any]:
|
||||
"""将 API 组件条目序列化为能力返回值。
|
||||
|
||||
Args:
|
||||
entry: API 组件条目。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 适合通过能力层返回给插件的 API 元信息。
|
||||
"""
|
||||
|
||||
return {
|
||||
"name": entry.name,
|
||||
"full_name": entry.full_name,
|
||||
"plugin_id": entry.plugin_id,
|
||||
"description": entry.description,
|
||||
"version": entry.version,
|
||||
"public": entry.public,
|
||||
"enabled": entry.enabled,
|
||||
"dynamic": entry.dynamic,
|
||||
"offline_reason": entry.offline_reason,
|
||||
"metadata": dict(entry.metadata),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _serialize_api_component_entry(cls, entry: "APIEntry") -> Dict[str, Any]:
|
||||
"""将 API 条目序列化为通用组件视图。
|
||||
|
||||
Args:
|
||||
entry: API 组件条目。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 适合 ``component.get_all_plugins`` 返回的组件结构。
|
||||
"""
|
||||
|
||||
serialized_entry = cls._serialize_api_entry(entry)
|
||||
return {
|
||||
"name": serialized_entry["name"],
|
||||
"full_name": serialized_entry["full_name"],
|
||||
"type": "API",
|
||||
"enabled": serialized_entry["enabled"],
|
||||
"metadata": serialized_entry["metadata"],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _is_api_visible_to_plugin(entry: "APIEntry", caller_plugin_id: str) -> bool:
|
||||
"""判断某个 API 是否对调用方可见。
|
||||
|
||||
Args:
|
||||
entry: 目标 API 组件条目。
|
||||
caller_plugin_id: 调用方插件 ID。
|
||||
|
||||
Returns:
|
||||
bool: 是否允许当前插件可见并调用。
|
||||
"""
|
||||
|
||||
return entry.plugin_id == caller_plugin_id or entry.public
|
||||
|
||||
@staticmethod
|
||||
def _normalize_api_reference(api_name: str, version: str = "") -> tuple[str, str]:
|
||||
"""规范化 API 名称与版本参数。
|
||||
|
||||
支持在 ``api_name`` 中直接携带 ``@version`` 后缀。
|
||||
"""
|
||||
|
||||
normalized_api_name = str(api_name or "").strip()
|
||||
normalized_version = str(version or "").strip()
|
||||
if normalized_api_name and not normalized_version and "@" in normalized_api_name:
|
||||
candidate_name, candidate_version = normalized_api_name.rsplit("@", 1)
|
||||
candidate_name = candidate_name.strip()
|
||||
candidate_version = candidate_version.strip()
|
||||
if candidate_name and candidate_version:
|
||||
normalized_api_name = candidate_name
|
||||
normalized_version = candidate_version
|
||||
return normalized_api_name, normalized_version
|
||||
|
||||
@staticmethod
|
||||
def _build_api_unavailable_error(entry: "APIEntry") -> str:
|
||||
"""构造 API 当前不可用时的错误信息。"""
|
||||
|
||||
if entry.offline_reason:
|
||||
return entry.offline_reason
|
||||
return f"API {entry.registry_key} 当前不可用"
|
||||
|
||||
def _resolve_api_target(
|
||||
self: _RuntimeComponentManagerProtocol,
|
||||
caller_plugin_id: str,
|
||||
api_name: str,
|
||||
version: str = "",
|
||||
) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]:
|
||||
"""解析 API 名称到唯一可调用的目标组件。
|
||||
|
||||
Args:
|
||||
caller_plugin_id: 调用方插件 ID。
|
||||
api_name: API 名称,支持 ``plugin_id.api_name`` 或唯一短名。
|
||||
version: 可选的 API 版本。
|
||||
|
||||
Returns:
|
||||
tuple[Optional[PluginSupervisor], Optional[APIEntry], Optional[str]]:
|
||||
解析成功时返回 ``(监督器, API 条目, None)``,失败时返回错误信息。
|
||||
"""
|
||||
|
||||
normalized_api_name, normalized_version = self._normalize_api_reference(api_name, version)
|
||||
if not normalized_api_name:
|
||||
return None, None, "缺少必要参数 api_name"
|
||||
|
||||
if "." in normalized_api_name:
|
||||
target_plugin_id, target_api_name = normalized_api_name.rsplit(".", 1)
|
||||
try:
|
||||
supervisor = self._get_supervisor_for_plugin(target_plugin_id)
|
||||
except RuntimeError as exc:
|
||||
return None, None, str(exc)
|
||||
|
||||
if supervisor is None:
|
||||
return None, None, f"未找到 API 提供方插件: {target_plugin_id}"
|
||||
|
||||
entries = supervisor.api_registry.get_apis(
|
||||
plugin_id=target_plugin_id,
|
||||
name=target_api_name,
|
||||
version=normalized_version,
|
||||
enabled_only=False,
|
||||
)
|
||||
visible_enabled_entries = [
|
||||
entry
|
||||
for entry in entries
|
||||
if self._is_api_visible_to_plugin(entry, caller_plugin_id) and entry.enabled
|
||||
]
|
||||
visible_disabled_entries = [
|
||||
entry
|
||||
for entry in entries
|
||||
if self._is_api_visible_to_plugin(entry, caller_plugin_id) and not entry.enabled
|
||||
]
|
||||
if len(visible_enabled_entries) == 1:
|
||||
return supervisor, visible_enabled_entries[0], None
|
||||
if len(visible_enabled_entries) > 1:
|
||||
return None, None, f"API {normalized_api_name} 存在多个版本,请显式指定 version"
|
||||
if visible_disabled_entries:
|
||||
if len(visible_disabled_entries) == 1:
|
||||
return None, None, self._build_api_unavailable_error(visible_disabled_entries[0])
|
||||
return None, None, f"API {normalized_api_name} 存在多个已下线版本,请显式指定 version"
|
||||
if any(not self._is_api_visible_to_plugin(entry, caller_plugin_id) for entry in entries):
|
||||
return None, None, f"API {normalized_api_name} 未公开,禁止跨插件调用"
|
||||
if normalized_version:
|
||||
return None, None, f"未找到版本为 {normalized_version} 的 API: {normalized_api_name}"
|
||||
return None, None, f"未找到 API: {normalized_api_name}"
|
||||
|
||||
visible_enabled_matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
|
||||
visible_disabled_matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
|
||||
hidden_match_exists = False
|
||||
for supervisor in self.supervisors:
|
||||
for entry in supervisor.api_registry.get_apis(
|
||||
name=normalized_api_name,
|
||||
version=normalized_version,
|
||||
enabled_only=False,
|
||||
):
|
||||
if self._is_api_visible_to_plugin(entry, caller_plugin_id):
|
||||
if entry.enabled:
|
||||
visible_enabled_matches.append((supervisor, entry))
|
||||
else:
|
||||
visible_disabled_matches.append((supervisor, entry))
|
||||
else:
|
||||
hidden_match_exists = True
|
||||
|
||||
if len(visible_enabled_matches) == 1:
|
||||
return visible_enabled_matches[0][0], visible_enabled_matches[0][1], None
|
||||
if len(visible_enabled_matches) > 1:
|
||||
return None, None, f"API 名称不唯一: {normalized_api_name},请使用 plugin_id.api_name 或显式指定 version"
|
||||
if visible_disabled_matches:
|
||||
if len(visible_disabled_matches) == 1:
|
||||
return None, None, self._build_api_unavailable_error(visible_disabled_matches[0][1])
|
||||
return None, None, f"API {normalized_api_name} 存在多个已下线版本,请使用 plugin_id.api_name@version"
|
||||
if hidden_match_exists:
|
||||
return None, None, f"API {normalized_api_name} 未公开,禁止跨插件调用"
|
||||
if normalized_version:
|
||||
return None, None, f"未找到版本为 {normalized_version} 的 API: {normalized_api_name}"
|
||||
return None, None, f"未找到 API: {normalized_api_name}"
|
||||
|
||||
def _resolve_api_toggle_target(
|
||||
self: _RuntimeComponentManagerProtocol,
|
||||
name: str,
|
||||
version: str = "",
|
||||
) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]:
|
||||
"""解析需要启用或禁用的 API 组件。
|
||||
|
||||
Args:
|
||||
name: API 名称,支持 ``plugin_id.api_name`` 或唯一短名。
|
||||
version: 可选的 API 版本。
|
||||
|
||||
Returns:
|
||||
tuple[Optional[PluginSupervisor], Optional[APIEntry], Optional[str]]:
|
||||
解析成功时返回 ``(监督器, API 条目, None)``,失败时返回错误信息。
|
||||
"""
|
||||
|
||||
normalized_name, normalized_version = self._normalize_api_reference(name, version)
|
||||
if not normalized_name:
|
||||
return None, None, "缺少必要参数 name"
|
||||
|
||||
if "." in normalized_name:
|
||||
plugin_id, api_name = normalized_name.rsplit(".", 1)
|
||||
try:
|
||||
supervisor = self._get_supervisor_for_plugin(plugin_id)
|
||||
except RuntimeError as exc:
|
||||
return None, None, str(exc)
|
||||
|
||||
if supervisor is None:
|
||||
return None, None, f"未找到 API 提供方插件: {plugin_id}"
|
||||
|
||||
entries = supervisor.api_registry.get_apis(
|
||||
plugin_id=plugin_id,
|
||||
name=api_name,
|
||||
version=normalized_version,
|
||||
enabled_only=False,
|
||||
)
|
||||
if len(entries) == 1:
|
||||
return supervisor, entries[0], None
|
||||
if entries:
|
||||
return None, None, f"API {normalized_name} 存在多个版本,请显式指定 version"
|
||||
return None, None, f"未找到 API: {normalized_name}"
|
||||
|
||||
matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
|
||||
for supervisor in self.supervisors:
|
||||
matches.extend(
|
||||
(supervisor, entry)
|
||||
for entry in supervisor.api_registry.get_apis(
|
||||
name=normalized_name,
|
||||
version=normalized_version,
|
||||
enabled_only=False,
|
||||
)
|
||||
)
|
||||
|
||||
if len(matches) == 1:
|
||||
return matches[0][0], matches[0][1], None
|
||||
if len(matches) > 1:
|
||||
return None, None, f"API 名称不唯一: {normalized_name},请使用 plugin_id.api_name 或显式指定 version"
|
||||
return None, None, f"未找到 API: {normalized_name}"
|
||||
|
||||
async def _cap_component_get_all_plugins(
|
||||
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
|
||||
) -> Any:
|
||||
@@ -46,6 +340,10 @@ class RuntimeComponentCapabilityMixin:
|
||||
}
|
||||
for component in comps
|
||||
]
|
||||
components_list.extend(
|
||||
self._serialize_api_component_entry(entry)
|
||||
for entry in sv.api_registry.get_apis(plugin_id=pid, enabled_only=False)
|
||||
)
|
||||
result[pid] = {
|
||||
"name": pid,
|
||||
"version": reg.plugin_version,
|
||||
@@ -96,30 +394,35 @@ class RuntimeComponentCapabilityMixin:
|
||||
|
||||
def _resolve_component_toggle_target(
|
||||
self: _RuntimeComponentManagerProtocol, name: str, component_type: str
|
||||
) -> tuple[Optional["RegisteredComponent"], Optional[str]]:
|
||||
short_name_matches: List["RegisteredComponent"] = []
|
||||
) -> tuple[Optional["ComponentEntry"], Optional[str]]:
|
||||
normalized_component_type = self._normalize_component_type(component_type)
|
||||
short_name_matches: List["ComponentEntry"] = []
|
||||
for sv in self.supervisors:
|
||||
comp = sv.component_registry.get_component(name)
|
||||
if comp is not None and comp.component_type == component_type:
|
||||
if comp is not None and comp.component_type == normalized_component_type:
|
||||
return comp, None
|
||||
|
||||
short_name_matches.extend(
|
||||
candidate
|
||||
for candidate in sv.component_registry.get_components_by_type(component_type, enabled_only=False)
|
||||
for candidate in sv.component_registry.get_components_by_type(
|
||||
normalized_component_type,
|
||||
enabled_only=False,
|
||||
)
|
||||
if candidate.name == name
|
||||
)
|
||||
|
||||
if len(short_name_matches) == 1:
|
||||
return short_name_matches[0], None
|
||||
if len(short_name_matches) > 1:
|
||||
return None, f"组件名不唯一: {name} ({component_type}),请使用完整名 plugin_id.component_name"
|
||||
return None, f"未找到组件: {name} ({component_type})"
|
||||
return None, f"组件名不唯一: {name} ({normalized_component_type}),请使用完整名 plugin_id.component_name"
|
||||
return None, f"未找到组件: {name} ({normalized_component_type})"
|
||||
|
||||
async def _cap_component_enable(
|
||||
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
|
||||
) -> Any:
|
||||
name: str = args.get("name", "")
|
||||
component_type: str = args.get("component_type", "")
|
||||
version: str = args.get("version", "")
|
||||
scope: str = args.get("scope", "global")
|
||||
stream_id: str = args.get("stream_id", "")
|
||||
if not name or not component_type:
|
||||
@@ -127,6 +430,13 @@ class RuntimeComponentCapabilityMixin:
|
||||
if scope != "global" or stream_id:
|
||||
return {"success": False, "error": "当前仅支持全局组件启用,不支持 scope/stream_id 定位"}
|
||||
|
||||
if self._is_api_component_type(component_type):
|
||||
supervisor, api_entry, error = self._resolve_api_toggle_target(name, version)
|
||||
if supervisor is None or api_entry is None:
|
||||
return {"success": False, "error": error or f"未找到 API: {name}"}
|
||||
supervisor.api_registry.toggle_api_status(api_entry.registry_key, True)
|
||||
return {"success": True}
|
||||
|
||||
comp, error = self._resolve_component_toggle_target(name, component_type)
|
||||
if comp is None:
|
||||
return {"success": False, "error": error or f"未找到组件: {name} ({component_type})"}
|
||||
@@ -139,6 +449,7 @@ class RuntimeComponentCapabilityMixin:
|
||||
) -> Any:
|
||||
name: str = args.get("name", "")
|
||||
component_type: str = args.get("component_type", "")
|
||||
version: str = args.get("version", "")
|
||||
scope: str = args.get("scope", "global")
|
||||
stream_id: str = args.get("stream_id", "")
|
||||
if not name or not component_type:
|
||||
@@ -146,6 +457,13 @@ class RuntimeComponentCapabilityMixin:
|
||||
if scope != "global" or stream_id:
|
||||
return {"success": False, "error": "当前仅支持全局组件禁用,不支持 scope/stream_id 定位"}
|
||||
|
||||
if self._is_api_component_type(component_type):
|
||||
supervisor, api_entry, error = self._resolve_api_toggle_target(name, version)
|
||||
if supervisor is None or api_entry is None:
|
||||
return {"success": False, "error": error or f"未找到 API: {name}"}
|
||||
supervisor.api_registry.toggle_api_status(api_entry.registry_key, False)
|
||||
return {"success": True}
|
||||
|
||||
comp, error = self._resolve_component_toggle_target(name, component_type)
|
||||
if comp is None:
|
||||
return {"success": False, "error": error or f"未找到组件: {name} ({component_type})"}
|
||||
@@ -168,33 +486,14 @@ class RuntimeComponentCapabilityMixin:
|
||||
return {"success": False, "error": f"检测到重复插件 ID,拒绝热重载: {details}"}
|
||||
|
||||
try:
|
||||
registered_supervisor = self._get_supervisor_for_plugin(plugin_name)
|
||||
except RuntimeError as exc:
|
||||
return {"success": False, "error": str(exc)}
|
||||
loaded = await self.load_plugin_globally(plugin_name, reason=f"load {plugin_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.component.load_plugin] 热重载失败: {e}")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
if registered_supervisor is not None:
|
||||
try:
|
||||
reloaded = await registered_supervisor.reload_plugins(reason=f"load {plugin_name}")
|
||||
if reloaded:
|
||||
return {"success": True, "count": 1}
|
||||
return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.component.load_plugin] 热重载失败: {e}")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
for sv in self.supervisors:
|
||||
for pdir in sv._plugin_dirs:
|
||||
if (pdir / plugin_name).is_dir():
|
||||
try:
|
||||
reloaded = await sv.reload_plugins(reason=f"load {plugin_name}")
|
||||
if reloaded:
|
||||
return {"success": True, "count": 1}
|
||||
return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.component.load_plugin] 热重载失败: {e}")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
return {"success": False, "error": f"未找到插件: {plugin_name}"}
|
||||
if loaded:
|
||||
return {"success": True, "count": 1}
|
||||
return {"success": False, "error": f"插件 {plugin_name} 热重载失败"}
|
||||
|
||||
async def _cap_component_unload_plugin(
|
||||
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
|
||||
@@ -216,17 +515,204 @@ class RuntimeComponentCapabilityMixin:
|
||||
return {"success": False, "error": f"检测到重复插件 ID,拒绝热重载: {details}"}
|
||||
|
||||
try:
|
||||
sv = self._get_supervisor_for_plugin(plugin_name)
|
||||
reloaded = await self.reload_plugins_globally([plugin_name], reason=f"reload {plugin_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.component.reload_plugin] 热重载失败: {e}")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
if reloaded:
|
||||
return {"success": True}
|
||||
return {"success": False, "error": f"插件 {plugin_name} 热重载失败"}
|
||||
|
||||
async def _cap_api_call(
|
||||
self: _RuntimeComponentManagerProtocol,
|
||||
plugin_id: str,
|
||||
capability: str,
|
||||
args: Dict[str, Any],
|
||||
) -> Any:
|
||||
"""调用其他插件公开的 API。
|
||||
|
||||
Args:
|
||||
plugin_id: 当前调用方插件 ID。
|
||||
capability: 能力名称。
|
||||
args: 能力参数。
|
||||
|
||||
Returns:
|
||||
Any: API 调用结果。
|
||||
"""
|
||||
|
||||
del capability
|
||||
api_name = str(args.get("api_name", "") or "").strip()
|
||||
version = str(args.get("version", "") or "").strip()
|
||||
api_args = args.get("args", {})
|
||||
if not isinstance(api_args, dict):
|
||||
return {"success": False, "error": "参数 args 必须为字典"}
|
||||
|
||||
supervisor, entry, error = self._resolve_api_target(plugin_id, api_name, version)
|
||||
if supervisor is None or entry is None:
|
||||
return {"success": False, "error": error or "API 解析失败"}
|
||||
|
||||
invoke_args = dict(api_args)
|
||||
if entry.dynamic:
|
||||
invoke_args.setdefault("__maibot_api_name__", entry.name)
|
||||
invoke_args.setdefault("__maibot_api_full_name__", entry.full_name)
|
||||
invoke_args.setdefault("__maibot_api_version__", entry.version)
|
||||
|
||||
try:
|
||||
response = await supervisor.invoke_api(
|
||||
plugin_id=entry.plugin_id,
|
||||
component_name=entry.handler_name,
|
||||
args=invoke_args,
|
||||
timeout_ms=30000,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"[cap.api.call] 调用 API {entry.full_name} 失败: {exc}", exc_info=True)
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
if response.error:
|
||||
return {"success": False, "error": response.error.get("message", "API 调用失败")}
|
||||
|
||||
payload = response.payload if isinstance(response.payload, dict) else {}
|
||||
if not bool(payload.get("success", False)):
|
||||
result = payload.get("result")
|
||||
return {"success": False, "error": "" if result is None else str(result)}
|
||||
return {"success": True, "result": payload.get("result")}
|
||||
|
||||
async def _cap_api_get(
|
||||
self: _RuntimeComponentManagerProtocol,
|
||||
plugin_id: str,
|
||||
capability: str,
|
||||
args: Dict[str, Any],
|
||||
) -> Any:
|
||||
"""获取当前插件可见的单个 API 元信息。
|
||||
|
||||
Args:
|
||||
plugin_id: 当前调用方插件 ID。
|
||||
capability: 能力名称。
|
||||
args: 能力参数。
|
||||
|
||||
Returns:
|
||||
Any: API 元信息或 ``None``。
|
||||
"""
|
||||
|
||||
del capability
|
||||
api_name = str(args.get("api_name", "") or "").strip()
|
||||
version = str(args.get("version", "") or "").strip()
|
||||
if not api_name:
|
||||
return {"success": False, "error": "缺少必要参数 api_name"}
|
||||
|
||||
supervisor, entry, _error = self._resolve_api_target(plugin_id, api_name, version)
|
||||
if supervisor is None or entry is None:
|
||||
return {"success": True, "api": None}
|
||||
return {"success": True, "api": self._serialize_api_entry(entry)}
|
||||
|
||||
async def _cap_api_list(
|
||||
self: _RuntimeComponentManagerProtocol,
|
||||
plugin_id: str,
|
||||
capability: str,
|
||||
args: Dict[str, Any],
|
||||
) -> Any:
|
||||
"""列出当前插件可见的 API 列表。
|
||||
|
||||
Args:
|
||||
plugin_id: 当前调用方插件 ID。
|
||||
capability: 能力名称。
|
||||
args: 能力参数。
|
||||
|
||||
Returns:
|
||||
Any: API 元信息列表。
|
||||
"""
|
||||
|
||||
del capability
|
||||
target_plugin_id = str(args.get("plugin_id", "") or "").strip()
|
||||
api_name, version = self._normalize_api_reference(
|
||||
str(args.get("api_name", args.get("name", "")) or ""),
|
||||
str(args.get("version", "") or ""),
|
||||
)
|
||||
apis: List[Dict[str, Any]] = []
|
||||
for supervisor in self.supervisors:
|
||||
apis.extend(
|
||||
self._serialize_api_entry(entry)
|
||||
for entry in supervisor.api_registry.get_apis(
|
||||
plugin_id=target_plugin_id or None,
|
||||
name=api_name,
|
||||
version=version,
|
||||
enabled_only=True,
|
||||
)
|
||||
if self._is_api_visible_to_plugin(entry, plugin_id)
|
||||
)
|
||||
|
||||
apis.sort(key=lambda item: (str(item["plugin_id"]), str(item["name"]), str(item["version"])))
|
||||
return {"success": True, "apis": apis}
|
||||
|
||||
async def _cap_api_replace_dynamic(
|
||||
self: _RuntimeComponentManagerProtocol,
|
||||
plugin_id: str,
|
||||
capability: str,
|
||||
args: Dict[str, Any],
|
||||
) -> Any:
|
||||
"""替换插件自行维护的动态 API 列表。"""
|
||||
|
||||
del capability
|
||||
raw_apis = args.get("apis", [])
|
||||
offline_reason = str(args.get("offline_reason", "") or "").strip() or "动态 API 已下线"
|
||||
if not isinstance(raw_apis, list):
|
||||
return {"success": False, "error": "参数 apis 必须为列表"}
|
||||
|
||||
try:
|
||||
supervisor = self._get_supervisor_for_plugin(plugin_id)
|
||||
except RuntimeError as exc:
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
if sv is not None:
|
||||
try:
|
||||
reloaded = await sv.reload_plugins(reason=f"reload {plugin_name}")
|
||||
if reloaded:
|
||||
return {"success": True}
|
||||
return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.component.reload_plugin] 热重载失败: {e}")
|
||||
return {"success": False, "error": str(e)}
|
||||
return {"success": False, "error": f"未找到插件: {plugin_name}"}
|
||||
if supervisor is None:
|
||||
return {"success": False, "error": f"未找到插件: {plugin_id}"}
|
||||
|
||||
normalized_components: List[Dict[str, Any]] = []
|
||||
seen_registry_keys: set[str] = set()
|
||||
for index, raw_api in enumerate(raw_apis):
|
||||
if not isinstance(raw_api, dict):
|
||||
return {"success": False, "error": f"apis[{index}] 必须为字典"}
|
||||
|
||||
api_name = str(raw_api.get("name", "") or "").strip()
|
||||
component_type = str(raw_api.get("component_type", raw_api.get("type", "API")) or "").strip()
|
||||
if not api_name:
|
||||
return {"success": False, "error": f"apis[{index}] 缺少 name"}
|
||||
if not self._is_api_component_type(component_type):
|
||||
return {"success": False, "error": f"apis[{index}] 不是 API 组件"}
|
||||
|
||||
metadata = raw_api.get("metadata", {}) if isinstance(raw_api.get("metadata"), dict) else {}
|
||||
normalized_metadata = dict(metadata)
|
||||
normalized_metadata["dynamic"] = True
|
||||
version = str(normalized_metadata.get("version", "1") or "1").strip() or "1"
|
||||
registry_key = supervisor.api_registry.build_registry_key(plugin_id, api_name, version)
|
||||
if registry_key in seen_registry_keys:
|
||||
return {"success": False, "error": f"动态 API 重复声明: {registry_key}"}
|
||||
seen_registry_keys.add(registry_key)
|
||||
|
||||
existing_entry = supervisor.api_registry.get_api(
|
||||
plugin_id,
|
||||
api_name,
|
||||
version=version,
|
||||
enabled_only=False,
|
||||
)
|
||||
if existing_entry is not None and not existing_entry.dynamic:
|
||||
return {"success": False, "error": f"动态 API 不能覆盖静态 API: {registry_key}"}
|
||||
|
||||
normalized_components.append(
|
||||
{
|
||||
"name": api_name,
|
||||
"component_type": "API",
|
||||
"metadata": normalized_metadata,
|
||||
}
|
||||
)
|
||||
|
||||
registered_count, offlined_count = supervisor.api_registry.replace_plugin_dynamic_apis(
|
||||
plugin_id,
|
||||
normalized_components,
|
||||
offline_reason=offline_reason,
|
||||
)
|
||||
return {
|
||||
"success": True,
|
||||
"count": registered_count,
|
||||
"offlined": offlined_count,
|
||||
}
|
||||
|
||||
@@ -238,14 +238,14 @@ class RuntimeCoreCapabilityMixin:
|
||||
return {"success": False, "value": None, "error": str(e)}
|
||||
|
||||
async def _cap_config_get_plugin(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
from src.core.component_registry import component_registry as core_registry
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
|
||||
plugin_name: str = args.get("plugin_name", plugin_id)
|
||||
key: str = args.get("key", "")
|
||||
default = args.get("default")
|
||||
|
||||
try:
|
||||
config = core_registry.get_plugin_config(plugin_name)
|
||||
config = component_query_service.get_plugin_config(plugin_name)
|
||||
if config is None:
|
||||
return {"success": False, "value": default, "error": f"未找到插件 {plugin_name} 的配置"}
|
||||
|
||||
@@ -258,11 +258,11 @@ class RuntimeCoreCapabilityMixin:
|
||||
return {"success": False, "value": default, "error": str(e)}
|
||||
|
||||
async def _cap_config_get_all(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
from src.core.component_registry import component_registry as core_registry
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
|
||||
plugin_name: str = args.get("plugin_name", plugin_id)
|
||||
try:
|
||||
config = core_registry.get_plugin_config(plugin_name)
|
||||
config = component_query_service.get_plugin_config(plugin_name)
|
||||
if config is None:
|
||||
return {"success": True, "value": {}}
|
||||
return {"success": True, "value": config}
|
||||
|
||||
@@ -648,10 +648,10 @@ class RuntimeDataCapabilityMixin:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def _cap_tool_get_definitions(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
from src.core.component_registry import component_registry as core_registry
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
|
||||
try:
|
||||
tools = core_registry.get_llm_available_tools()
|
||||
tools = component_query_service.get_llm_available_tools()
|
||||
return {
|
||||
"success": True,
|
||||
"tools": [{"name": name, "definition": info.get_llm_definition()} for name, info in tools.items()],
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_runtime.host.capability_service import CapabilityImpl
|
||||
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -13,66 +14,80 @@ def register_capability_impls(manager: "PluginRuntimeManager", supervisor: Plugi
|
||||
"""向指定 Supervisor 注册主程序提供的能力实现。"""
|
||||
cap_service = supervisor.capability_service
|
||||
|
||||
cap_service.register_capability("send.text", manager._cap_send_text)
|
||||
cap_service.register_capability("send.emoji", manager._cap_send_emoji)
|
||||
cap_service.register_capability("send.image", manager._cap_send_image)
|
||||
cap_service.register_capability("send.command", manager._cap_send_command)
|
||||
cap_service.register_capability("send.custom", manager._cap_send_custom)
|
||||
def _register(name: str, impl: CapabilityImpl) -> None:
|
||||
"""注册单个能力实现。
|
||||
|
||||
cap_service.register_capability("llm.generate", manager._cap_llm_generate)
|
||||
cap_service.register_capability("llm.generate_with_tools", manager._cap_llm_generate_with_tools)
|
||||
cap_service.register_capability("llm.get_available_models", manager._cap_llm_get_available_models)
|
||||
Args:
|
||||
name: 能力名称。
|
||||
impl: 能力实现函数。
|
||||
"""
|
||||
cap_service.register_capability(name, impl)
|
||||
|
||||
cap_service.register_capability("config.get", manager._cap_config_get)
|
||||
cap_service.register_capability("config.get_plugin", manager._cap_config_get_plugin)
|
||||
cap_service.register_capability("config.get_all", manager._cap_config_get_all)
|
||||
_register("send.text", manager._cap_send_text)
|
||||
_register("send.emoji", manager._cap_send_emoji)
|
||||
_register("send.image", manager._cap_send_image)
|
||||
_register("send.command", manager._cap_send_command)
|
||||
_register("send.custom", manager._cap_send_custom)
|
||||
|
||||
cap_service.register_capability("database.query", manager._cap_database_query)
|
||||
cap_service.register_capability("database.save", manager._cap_database_save)
|
||||
cap_service.register_capability("database.get", manager._cap_database_get)
|
||||
cap_service.register_capability("database.delete", manager._cap_database_delete)
|
||||
cap_service.register_capability("database.count", manager._cap_database_count)
|
||||
_register("llm.generate", manager._cap_llm_generate)
|
||||
_register("llm.generate_with_tools", manager._cap_llm_generate_with_tools)
|
||||
_register("llm.get_available_models", manager._cap_llm_get_available_models)
|
||||
|
||||
cap_service.register_capability("chat.get_all_streams", manager._cap_chat_get_all_streams)
|
||||
cap_service.register_capability("chat.get_group_streams", manager._cap_chat_get_group_streams)
|
||||
cap_service.register_capability("chat.get_private_streams", manager._cap_chat_get_private_streams)
|
||||
cap_service.register_capability("chat.get_stream_by_group_id", manager._cap_chat_get_stream_by_group_id)
|
||||
cap_service.register_capability("chat.get_stream_by_user_id", manager._cap_chat_get_stream_by_user_id)
|
||||
_register("config.get", manager._cap_config_get)
|
||||
_register("config.get_plugin", manager._cap_config_get_plugin)
|
||||
_register("config.get_all", manager._cap_config_get_all)
|
||||
|
||||
cap_service.register_capability("message.get_by_time", manager._cap_message_get_by_time)
|
||||
cap_service.register_capability("message.get_by_time_in_chat", manager._cap_message_get_by_time_in_chat)
|
||||
cap_service.register_capability("message.get_recent", manager._cap_message_get_recent)
|
||||
cap_service.register_capability("message.count_new", manager._cap_message_count_new)
|
||||
cap_service.register_capability("message.build_readable", manager._cap_message_build_readable)
|
||||
_register("database.query", manager._cap_database_query)
|
||||
_register("database.save", manager._cap_database_save)
|
||||
_register("database.get", manager._cap_database_get)
|
||||
_register("database.delete", manager._cap_database_delete)
|
||||
_register("database.count", manager._cap_database_count)
|
||||
|
||||
cap_service.register_capability("person.get_id", manager._cap_person_get_id)
|
||||
cap_service.register_capability("person.get_value", manager._cap_person_get_value)
|
||||
cap_service.register_capability("person.get_id_by_name", manager._cap_person_get_id_by_name)
|
||||
_register("chat.get_all_streams", manager._cap_chat_get_all_streams)
|
||||
_register("chat.get_group_streams", manager._cap_chat_get_group_streams)
|
||||
_register("chat.get_private_streams", manager._cap_chat_get_private_streams)
|
||||
_register("chat.get_stream_by_group_id", manager._cap_chat_get_stream_by_group_id)
|
||||
_register("chat.get_stream_by_user_id", manager._cap_chat_get_stream_by_user_id)
|
||||
|
||||
cap_service.register_capability("emoji.get_by_description", manager._cap_emoji_get_by_description)
|
||||
cap_service.register_capability("emoji.get_random", manager._cap_emoji_get_random)
|
||||
cap_service.register_capability("emoji.get_count", manager._cap_emoji_get_count)
|
||||
cap_service.register_capability("emoji.get_emotions", manager._cap_emoji_get_emotions)
|
||||
cap_service.register_capability("emoji.get_all", manager._cap_emoji_get_all)
|
||||
cap_service.register_capability("emoji.get_info", manager._cap_emoji_get_info)
|
||||
cap_service.register_capability("emoji.register", manager._cap_emoji_register)
|
||||
cap_service.register_capability("emoji.delete", manager._cap_emoji_delete)
|
||||
_register("message.get_by_time", manager._cap_message_get_by_time)
|
||||
_register("message.get_by_time_in_chat", manager._cap_message_get_by_time_in_chat)
|
||||
_register("message.get_recent", manager._cap_message_get_recent)
|
||||
_register("message.count_new", manager._cap_message_count_new)
|
||||
_register("message.build_readable", manager._cap_message_build_readable)
|
||||
|
||||
cap_service.register_capability("frequency.get_current_talk_value", manager._cap_frequency_get_current_talk_value)
|
||||
cap_service.register_capability("frequency.set_adjust", manager._cap_frequency_set_adjust)
|
||||
cap_service.register_capability("frequency.get_adjust", manager._cap_frequency_get_adjust)
|
||||
_register("person.get_id", manager._cap_person_get_id)
|
||||
_register("person.get_value", manager._cap_person_get_value)
|
||||
_register("person.get_id_by_name", manager._cap_person_get_id_by_name)
|
||||
|
||||
cap_service.register_capability("tool.get_definitions", manager._cap_tool_get_definitions)
|
||||
_register("emoji.get_by_description", manager._cap_emoji_get_by_description)
|
||||
_register("emoji.get_random", manager._cap_emoji_get_random)
|
||||
_register("emoji.get_count", manager._cap_emoji_get_count)
|
||||
_register("emoji.get_emotions", manager._cap_emoji_get_emotions)
|
||||
_register("emoji.get_all", manager._cap_emoji_get_all)
|
||||
_register("emoji.get_info", manager._cap_emoji_get_info)
|
||||
_register("emoji.register", manager._cap_emoji_register)
|
||||
_register("emoji.delete", manager._cap_emoji_delete)
|
||||
|
||||
cap_service.register_capability("component.get_all_plugins", manager._cap_component_get_all_plugins)
|
||||
cap_service.register_capability("component.get_plugin_info", manager._cap_component_get_plugin_info)
|
||||
cap_service.register_capability("component.list_loaded_plugins", manager._cap_component_list_loaded_plugins)
|
||||
cap_service.register_capability("component.list_registered_plugins", manager._cap_component_list_registered_plugins)
|
||||
cap_service.register_capability("component.enable", manager._cap_component_enable)
|
||||
cap_service.register_capability("component.disable", manager._cap_component_disable)
|
||||
cap_service.register_capability("component.load_plugin", manager._cap_component_load_plugin)
|
||||
cap_service.register_capability("component.unload_plugin", manager._cap_component_unload_plugin)
|
||||
cap_service.register_capability("component.reload_plugin", manager._cap_component_reload_plugin)
|
||||
_register("frequency.get_current_talk_value", manager._cap_frequency_get_current_talk_value)
|
||||
_register("frequency.set_adjust", manager._cap_frequency_set_adjust)
|
||||
_register("frequency.get_adjust", manager._cap_frequency_get_adjust)
|
||||
|
||||
cap_service.register_capability("knowledge.search", manager._cap_knowledge_search)
|
||||
_register("tool.get_definitions", manager._cap_tool_get_definitions)
|
||||
|
||||
_register("api.call", manager._cap_api_call)
|
||||
_register("api.get", manager._cap_api_get)
|
||||
_register("api.list", manager._cap_api_list)
|
||||
_register("api.replace_dynamic", manager._cap_api_replace_dynamic)
|
||||
|
||||
_register("component.get_all_plugins", manager._cap_component_get_all_plugins)
|
||||
_register("component.get_plugin_info", manager._cap_component_get_plugin_info)
|
||||
_register("component.list_loaded_plugins", manager._cap_component_list_loaded_plugins)
|
||||
_register("component.list_registered_plugins", manager._cap_component_list_registered_plugins)
|
||||
_register("component.enable", manager._cap_component_enable)
|
||||
_register("component.disable", manager._cap_component_disable)
|
||||
_register("component.load_plugin", manager._cap_component_load_plugin)
|
||||
_register("component.unload_plugin", manager._cap_component_unload_plugin)
|
||||
_register("component.reload_plugin", manager._cap_component_reload_plugin)
|
||||
|
||||
_register("knowledge.search", manager._cap_knowledge_search)
|
||||
logger.debug("已注册全部主程序能力实现")
|
||||
|
||||
709
src/plugin_runtime/component_query.py
Normal file
709
src/plugin_runtime/component_query.py
Normal file
@@ -0,0 +1,709 @@
|
||||
"""插件运行时统一组件查询服务。
|
||||
|
||||
该模块统一从插件运行时的 Host ComponentRegistry 中聚合只读视图,
|
||||
供 HFC/PFC、Planner、ToolExecutor 和运行时能力层查询与调用。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.core.types import ActionActivationType, ActionInfo, CommandInfo, ComponentInfo, ComponentType, ToolInfo
|
||||
from src.llm_models.payload_content.tool_option import ToolParamType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.plugin_runtime.host.component_registry import ActionEntry, CommandEntry, ComponentEntry, ToolEntry
|
||||
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
||||
from src.plugin_runtime.integration import PluginRuntimeManager
|
||||
|
||||
logger = get_logger("plugin_runtime.component_query")
|
||||
|
||||
ActionExecutor = Callable[..., Awaitable[Any]]
|
||||
CommandExecutor = Callable[..., Awaitable[Tuple[bool, Optional[str], bool]]]
|
||||
ToolExecutor = Callable[..., Awaitable[Any]]
|
||||
|
||||
_HOST_COMPONENT_TYPE_MAP: Dict[ComponentType, str] = {
|
||||
ComponentType.ACTION: "ACTION",
|
||||
ComponentType.COMMAND: "COMMAND",
|
||||
ComponentType.TOOL: "TOOL",
|
||||
}
|
||||
_TOOL_PARAM_TYPE_MAP: Dict[str, ToolParamType] = {
|
||||
"string": ToolParamType.STRING,
|
||||
"integer": ToolParamType.INTEGER,
|
||||
"float": ToolParamType.FLOAT,
|
||||
"boolean": ToolParamType.BOOLEAN,
|
||||
"bool": ToolParamType.BOOLEAN,
|
||||
}
|
||||
|
||||
|
||||
class ComponentQueryService:
|
||||
"""插件运行时统一组件查询服务。
|
||||
|
||||
该对象不维护独立状态,只读取插件系统中的注册结果。
|
||||
所有注册、删除、配置写入等写操作都被显式禁用。
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _get_runtime_manager() -> "PluginRuntimeManager":
|
||||
"""获取插件运行时管理器单例。
|
||||
|
||||
Returns:
|
||||
PluginRuntimeManager: 当前全局插件运行时管理器。
|
||||
"""
|
||||
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
return get_plugin_runtime_manager()
|
||||
|
||||
def _iter_supervisors(self) -> list["PluginSupervisor"]:
|
||||
"""获取当前所有活跃的插件运行时监督器。
|
||||
|
||||
Returns:
|
||||
list[PluginSupervisor]: 当前运行中的监督器列表。
|
||||
"""
|
||||
|
||||
runtime_manager = self._get_runtime_manager()
|
||||
return list(runtime_manager.supervisors)
|
||||
|
||||
def _iter_component_entries(
|
||||
self,
|
||||
component_type: ComponentType,
|
||||
*,
|
||||
enabled_only: bool = True,
|
||||
) -> list[tuple["PluginSupervisor", "ComponentEntry"]]:
|
||||
"""遍历指定类型的全部组件条目。
|
||||
|
||||
Args:
|
||||
component_type: 目标组件类型。
|
||||
enabled_only: 是否仅返回启用状态的组件。
|
||||
|
||||
Returns:
|
||||
list[tuple[PluginSupervisor, ComponentEntry]]: ``(监督器, 组件条目)`` 列表。
|
||||
"""
|
||||
|
||||
host_component_type = _HOST_COMPONENT_TYPE_MAP.get(component_type)
|
||||
if host_component_type is None:
|
||||
return []
|
||||
|
||||
collected_entries: list[tuple["PluginSupervisor", "ComponentEntry"]] = []
|
||||
for supervisor in self._iter_supervisors():
|
||||
for component in supervisor.component_registry.get_components_by_type(
|
||||
host_component_type,
|
||||
enabled_only=enabled_only,
|
||||
):
|
||||
collected_entries.append((supervisor, component))
|
||||
return collected_entries
|
||||
|
||||
@staticmethod
|
||||
def _coerce_action_activation_type(raw_value: Any) -> ActionActivationType:
|
||||
"""规范化动作激活类型。
|
||||
|
||||
Args:
|
||||
raw_value: 原始激活类型值。
|
||||
|
||||
Returns:
|
||||
ActionActivationType: 规范化后的激活类型枚举。
|
||||
"""
|
||||
|
||||
normalized_value = str(raw_value or "").strip().lower()
|
||||
if normalized_value == ActionActivationType.NEVER.value:
|
||||
return ActionActivationType.NEVER
|
||||
if normalized_value == ActionActivationType.RANDOM.value:
|
||||
return ActionActivationType.RANDOM
|
||||
if normalized_value == ActionActivationType.KEYWORD.value:
|
||||
return ActionActivationType.KEYWORD
|
||||
return ActionActivationType.ALWAYS
|
||||
|
||||
@staticmethod
|
||||
def _coerce_float(value: Any, default: float = 0.0) -> float:
|
||||
"""将任意值安全转换为浮点数。
|
||||
|
||||
Args:
|
||||
value: 待转换的输入值。
|
||||
default: 转换失败时返回的默认值。
|
||||
|
||||
Returns:
|
||||
float: 转换后的浮点结果。
|
||||
"""
|
||||
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def _build_action_info(entry: "ActionEntry") -> ActionInfo:
|
||||
"""将运行时 Action 条目转换为核心动作信息。
|
||||
|
||||
Args:
|
||||
entry: 插件运行时中的 Action 条目。
|
||||
|
||||
Returns:
|
||||
ActionInfo: 供核心 Planner 使用的动作信息。
|
||||
"""
|
||||
|
||||
metadata = dict(entry.metadata)
|
||||
raw_action_parameters = metadata.get("action_parameters")
|
||||
action_parameters = (
|
||||
{
|
||||
str(param_name): str(param_description)
|
||||
for param_name, param_description in raw_action_parameters.items()
|
||||
}
|
||||
if isinstance(raw_action_parameters, dict)
|
||||
else {}
|
||||
)
|
||||
action_require = [
|
||||
str(item)
|
||||
for item in (metadata.get("action_require") or [])
|
||||
if item is not None and str(item).strip()
|
||||
]
|
||||
associated_types = [
|
||||
str(item)
|
||||
for item in (metadata.get("associated_types") or [])
|
||||
if item is not None and str(item).strip()
|
||||
]
|
||||
activation_keywords = [
|
||||
str(item)
|
||||
for item in (metadata.get("activation_keywords") or [])
|
||||
if item is not None and str(item).strip()
|
||||
]
|
||||
|
||||
return ActionInfo(
|
||||
name=entry.name,
|
||||
component_type=ComponentType.ACTION,
|
||||
description=str(metadata.get("description", "") or ""),
|
||||
enabled=bool(entry.enabled),
|
||||
plugin_name=entry.plugin_id,
|
||||
metadata=metadata,
|
||||
action_parameters=action_parameters,
|
||||
action_require=action_require,
|
||||
associated_types=associated_types,
|
||||
activation_type=ComponentQueryService._coerce_action_activation_type(metadata.get("activation_type")),
|
||||
random_activation_probability=ComponentQueryService._coerce_float(
|
||||
metadata.get("activation_probability"),
|
||||
0.0,
|
||||
),
|
||||
activation_keywords=activation_keywords,
|
||||
parallel_action=bool(metadata.get("parallel_action", False)),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_command_info(entry: "CommandEntry") -> CommandInfo:
|
||||
"""将运行时 Command 条目转换为核心命令信息。
|
||||
|
||||
Args:
|
||||
entry: 插件运行时中的 Command 条目。
|
||||
|
||||
Returns:
|
||||
CommandInfo: 供核心命令链使用的命令信息。
|
||||
"""
|
||||
|
||||
metadata = dict(entry.metadata)
|
||||
return CommandInfo(
|
||||
name=entry.name,
|
||||
component_type=ComponentType.COMMAND,
|
||||
description=str(metadata.get("description", "") or ""),
|
||||
enabled=bool(entry.enabled),
|
||||
plugin_name=entry.plugin_id,
|
||||
metadata=metadata,
|
||||
command_pattern=str(metadata.get("command_pattern", "") or ""),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _coerce_tool_param_type(raw_value: Any) -> ToolParamType:
|
||||
"""规范化工具参数类型。
|
||||
|
||||
Args:
|
||||
raw_value: 原始工具参数类型值。
|
||||
|
||||
Returns:
|
||||
ToolParamType: 规范化后的工具参数类型。
|
||||
"""
|
||||
|
||||
normalized_value = str(raw_value or "").strip().lower()
|
||||
return _TOOL_PARAM_TYPE_MAP.get(normalized_value, ToolParamType.STRING)
|
||||
|
||||
@staticmethod
|
||||
def _build_tool_parameters(entry: "ToolEntry") -> list[tuple[str, ToolParamType, str, bool, list[str] | None]]:
|
||||
"""将运行时工具参数元数据转换为核心 ToolInfo 参数列表。
|
||||
|
||||
Args:
|
||||
entry: 插件运行时中的 Tool 条目。
|
||||
|
||||
Returns:
|
||||
list[tuple[str, ToolParamType, str, bool, list[str] | None]]: 转换后的参数列表。
|
||||
"""
|
||||
|
||||
structured_parameters = entry.parameters if isinstance(entry.parameters, list) else []
|
||||
if not structured_parameters and isinstance(entry.parameters_raw, dict):
|
||||
structured_parameters = [
|
||||
{"name": key, **value}
|
||||
for key, value in entry.parameters_raw.items()
|
||||
if isinstance(value, dict)
|
||||
]
|
||||
|
||||
normalized_parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = []
|
||||
for parameter in structured_parameters:
|
||||
if not isinstance(parameter, dict):
|
||||
continue
|
||||
|
||||
parameter_name = str(parameter.get("name", "") or "").strip()
|
||||
if not parameter_name:
|
||||
continue
|
||||
|
||||
enum_values = parameter.get("enum")
|
||||
normalized_enum_values = (
|
||||
[str(item) for item in enum_values if item is not None]
|
||||
if isinstance(enum_values, list)
|
||||
else None
|
||||
)
|
||||
normalized_parameters.append(
|
||||
(
|
||||
parameter_name,
|
||||
ComponentQueryService._coerce_tool_param_type(parameter.get("param_type") or parameter.get("type")),
|
||||
str(parameter.get("description", "") or ""),
|
||||
bool(parameter.get("required", True)),
|
||||
normalized_enum_values,
|
||||
)
|
||||
)
|
||||
return normalized_parameters
|
||||
|
||||
@staticmethod
|
||||
def _build_tool_info(entry: "ToolEntry") -> ToolInfo:
|
||||
"""将运行时 Tool 条目转换为核心工具信息。
|
||||
|
||||
Args:
|
||||
entry: 插件运行时中的 Tool 条目。
|
||||
|
||||
Returns:
|
||||
ToolInfo: 供 ToolExecutor 与能力层使用的工具信息。
|
||||
"""
|
||||
|
||||
return ToolInfo(
|
||||
name=entry.name,
|
||||
component_type=ComponentType.TOOL,
|
||||
description=entry.description,
|
||||
enabled=bool(entry.enabled),
|
||||
plugin_name=entry.plugin_id,
|
||||
metadata=dict(entry.metadata),
|
||||
tool_parameters=ComponentQueryService._build_tool_parameters(entry),
|
||||
tool_description=entry.description,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _log_duplicate_component(component_type: ComponentType, component_name: str) -> None:
|
||||
"""记录重复组件名称冲突。
|
||||
|
||||
Args:
|
||||
component_type: 组件类型。
|
||||
component_name: 发生冲突的组件名称。
|
||||
"""
|
||||
|
||||
logger.warning(f"检测到重复{component_type.value}名称 {component_name},将只保留首个匹配项")
|
||||
|
||||
def _get_unique_component_entry(
|
||||
self,
|
||||
component_type: ComponentType,
|
||||
name: str,
|
||||
) -> Optional[tuple["PluginSupervisor", "ComponentEntry"]]:
|
||||
"""按组件短名解析唯一条目。
|
||||
|
||||
Args:
|
||||
component_type: 目标组件类型。
|
||||
name: 组件短名。
|
||||
|
||||
Returns:
|
||||
Optional[tuple[PluginSupervisor, ComponentEntry]]: 唯一命中的组件条目。
|
||||
"""
|
||||
|
||||
matched_entries = [
|
||||
(supervisor, entry)
|
||||
for supervisor, entry in self._iter_component_entries(component_type)
|
||||
if entry.name == name
|
||||
]
|
||||
if not matched_entries:
|
||||
return None
|
||||
if len(matched_entries) > 1:
|
||||
self._log_duplicate_component(component_type, name)
|
||||
return matched_entries[0]
|
||||
|
||||
def _collect_unique_component_infos(
|
||||
self,
|
||||
component_type: ComponentType,
|
||||
) -> Dict[str, ComponentInfo]:
|
||||
"""收集某类组件的唯一信息视图。
|
||||
|
||||
Args:
|
||||
component_type: 目标组件类型。
|
||||
|
||||
Returns:
|
||||
Dict[str, ComponentInfo]: 组件名到核心组件信息的映射。
|
||||
"""
|
||||
|
||||
collected_components: Dict[str, ComponentInfo] = {}
|
||||
for _supervisor, entry in self._iter_component_entries(component_type):
|
||||
if entry.name in collected_components:
|
||||
self._log_duplicate_component(component_type, entry.name)
|
||||
continue
|
||||
|
||||
if component_type == ComponentType.ACTION:
|
||||
collected_components[entry.name] = self._build_action_info(entry) # type: ignore[arg-type]
|
||||
elif component_type == ComponentType.COMMAND:
|
||||
collected_components[entry.name] = self._build_command_info(entry) # type: ignore[arg-type]
|
||||
elif component_type == ComponentType.TOOL:
|
||||
collected_components[entry.name] = self._build_tool_info(entry) # type: ignore[arg-type]
|
||||
return collected_components
|
||||
|
||||
@staticmethod
|
||||
def _extract_stream_id_from_action_kwargs(kwargs: Dict[str, Any]) -> str:
|
||||
"""从旧 ActionManager 参数中提取聊天流 ID。
|
||||
|
||||
Args:
|
||||
kwargs: 旧动作执行器收到的关键字参数。
|
||||
|
||||
Returns:
|
||||
str: 提取出的 ``stream_id``。
|
||||
"""
|
||||
|
||||
chat_stream = kwargs.get("chat_stream")
|
||||
if chat_stream is not None:
|
||||
try:
|
||||
return str(chat_stream.session_id)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
return str(kwargs.get("stream_id", "") or "")
|
||||
|
||||
@staticmethod
|
||||
def _build_action_executor(supervisor: "PluginSupervisor", plugin_id: str, component_name: str) -> ActionExecutor:
|
||||
"""构造动作执行 RPC 闭包。
|
||||
|
||||
Args:
|
||||
supervisor: 负责该组件的监督器。
|
||||
plugin_id: 插件 ID。
|
||||
component_name: 组件名称。
|
||||
|
||||
Returns:
|
||||
ActionExecutor: 兼容旧 Planner 的异步执行器。
|
||||
"""
|
||||
|
||||
async def _executor(**kwargs: Any) -> tuple[bool, str]:
|
||||
"""将核心动作调用桥接到插件运行时。
|
||||
|
||||
Args:
|
||||
**kwargs: 旧 ActionManager 传入的上下文参数。
|
||||
|
||||
Returns:
|
||||
tuple[bool, str]: ``(是否成功, 结果说明)``。
|
||||
"""
|
||||
|
||||
invoke_args: Dict[str, Any] = {}
|
||||
action_data = kwargs.get("action_data")
|
||||
if isinstance(action_data, dict):
|
||||
invoke_args.update(action_data)
|
||||
|
||||
stream_id = ComponentQueryService._extract_stream_id_from_action_kwargs(kwargs)
|
||||
invoke_args["action_data"] = action_data if isinstance(action_data, dict) else {}
|
||||
invoke_args["stream_id"] = stream_id
|
||||
invoke_args["chat_id"] = stream_id
|
||||
invoke_args["reasoning"] = str(kwargs.get("action_reasoning", "") or "")
|
||||
|
||||
if (thinking_id := kwargs.get("thinking_id")) is not None:
|
||||
invoke_args["thinking_id"] = str(thinking_id)
|
||||
if isinstance(kwargs.get("cycle_timers"), dict):
|
||||
invoke_args["cycle_timers"] = kwargs["cycle_timers"]
|
||||
if isinstance(kwargs.get("plugin_config"), dict):
|
||||
invoke_args["plugin_config"] = kwargs["plugin_config"]
|
||||
if isinstance(kwargs.get("log_prefix"), str):
|
||||
invoke_args["log_prefix"] = kwargs["log_prefix"]
|
||||
if isinstance(kwargs.get("shutting_down"), bool):
|
||||
invoke_args["shutting_down"] = kwargs["shutting_down"]
|
||||
|
||||
try:
|
||||
response = await supervisor.invoke_plugin(
|
||||
method="plugin.invoke_action",
|
||||
plugin_id=plugin_id,
|
||||
component_name=component_name,
|
||||
args=invoke_args,
|
||||
timeout_ms=30000,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"运行时 Action {plugin_id}.{component_name} 执行失败: {exc}", exc_info=True)
|
||||
return False, str(exc)
|
||||
|
||||
payload = response.payload if isinstance(response.payload, dict) else {}
|
||||
success = bool(payload.get("success", False))
|
||||
result = payload.get("result")
|
||||
if isinstance(result, (list, tuple)):
|
||||
if len(result) >= 2:
|
||||
return bool(result[0]), "" if result[1] is None else str(result[1])
|
||||
if len(result) == 1:
|
||||
return bool(result[0]), ""
|
||||
if success:
|
||||
return True, "" if result is None else str(result)
|
||||
return False, "" if result is None else str(result)
|
||||
|
||||
return _executor
|
||||
|
||||
@staticmethod
|
||||
def _build_command_executor(
|
||||
supervisor: "PluginSupervisor",
|
||||
plugin_id: str,
|
||||
component_name: str,
|
||||
metadata: Dict[str, Any],
|
||||
) -> CommandExecutor:
|
||||
"""构造命令执行 RPC 闭包。
|
||||
|
||||
Args:
|
||||
supervisor: 负责该组件的监督器。
|
||||
plugin_id: 插件 ID。
|
||||
component_name: 组件名称。
|
||||
metadata: 命令组件元数据。
|
||||
|
||||
Returns:
|
||||
CommandExecutor: 兼容旧消息命令链的执行器。
|
||||
"""
|
||||
|
||||
async def _executor(**kwargs: Any) -> tuple[bool, Optional[str], bool]:
|
||||
"""将核心命令调用桥接到插件运行时。
|
||||
|
||||
Args:
|
||||
**kwargs: 命令执行上下文参数。
|
||||
|
||||
Returns:
|
||||
tuple[bool, Optional[str], bool]: ``(是否成功, 返回文本, 是否拦截后续消息)``。
|
||||
"""
|
||||
|
||||
message = kwargs.get("message")
|
||||
matched_groups = kwargs.get("matched_groups")
|
||||
plugin_config = kwargs.get("plugin_config")
|
||||
invoke_args: Dict[str, Any] = {
|
||||
"text": str(getattr(message, "processed_plain_text", "") or ""),
|
||||
"stream_id": str(getattr(message, "session_id", "") or ""),
|
||||
"matched_groups": matched_groups if isinstance(matched_groups, dict) else {},
|
||||
}
|
||||
if isinstance(plugin_config, dict):
|
||||
invoke_args["plugin_config"] = plugin_config
|
||||
|
||||
try:
|
||||
response = await supervisor.invoke_plugin(
|
||||
method="plugin.invoke_command",
|
||||
plugin_id=plugin_id,
|
||||
component_name=component_name,
|
||||
args=invoke_args,
|
||||
timeout_ms=30000,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"运行时 Command {plugin_id}.{component_name} 执行失败: {exc}", exc_info=True)
|
||||
return False, str(exc), True
|
||||
|
||||
payload = response.payload if isinstance(response.payload, dict) else {}
|
||||
success = bool(payload.get("success", False))
|
||||
result = payload.get("result")
|
||||
intercept = bool(metadata.get("intercept_message_level", 0))
|
||||
response_text: Optional[str]
|
||||
|
||||
if isinstance(result, (list, tuple)) and len(result) >= 3:
|
||||
response_text = None if result[1] is None else str(result[1])
|
||||
intercept = bool(result[2])
|
||||
else:
|
||||
response_text = None if result is None else str(result)
|
||||
|
||||
return success, response_text, intercept
|
||||
|
||||
return _executor
|
||||
|
||||
@staticmethod
|
||||
def _build_tool_executor(supervisor: "PluginSupervisor", plugin_id: str, component_name: str) -> ToolExecutor:
|
||||
"""构造工具执行 RPC 闭包。
|
||||
|
||||
Args:
|
||||
supervisor: 负责该组件的监督器。
|
||||
plugin_id: 插件 ID。
|
||||
component_name: 组件名称。
|
||||
|
||||
Returns:
|
||||
ToolExecutor: 兼容旧 ToolExecutor 的异步执行器。
|
||||
"""
|
||||
|
||||
async def _executor(function_args: Dict[str, Any]) -> Any:
|
||||
"""将核心工具调用桥接到插件运行时。
|
||||
|
||||
Args:
|
||||
function_args: 工具调用参数。
|
||||
|
||||
Returns:
|
||||
Any: 插件工具返回结果;若结果不是字典,则会包装为 ``{"content": ...}``。
|
||||
"""
|
||||
|
||||
try:
|
||||
response = await supervisor.invoke_plugin(
|
||||
method="plugin.invoke_tool",
|
||||
plugin_id=plugin_id,
|
||||
component_name=component_name,
|
||||
args=function_args,
|
||||
timeout_ms=30000,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"运行时 Tool {plugin_id}.{component_name} 执行失败: {exc}", exc_info=True)
|
||||
return {"content": f"工具 {component_name} 执行失败: {exc}"}
|
||||
|
||||
payload = response.payload if isinstance(response.payload, dict) else {}
|
||||
result = payload.get("result")
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
return {"content": "" if result is None else str(result)}
|
||||
|
||||
return _executor
|
||||
|
||||
def get_action_info(self, name: str) -> Optional[ActionInfo]:
|
||||
"""获取指定动作的信息。
|
||||
|
||||
Args:
|
||||
name: 动作名称。
|
||||
|
||||
Returns:
|
||||
Optional[ActionInfo]: 匹配到的动作信息。
|
||||
"""
|
||||
|
||||
matched_entry = self._get_unique_component_entry(ComponentType.ACTION, name)
|
||||
if matched_entry is None:
|
||||
return None
|
||||
_supervisor, entry = matched_entry
|
||||
return self._build_action_info(entry) # type: ignore[arg-type]
|
||||
|
||||
def get_action_executor(self, name: str) -> Optional[ActionExecutor]:
|
||||
"""获取指定动作的执行器。
|
||||
|
||||
Args:
|
||||
name: 动作名称。
|
||||
|
||||
Returns:
|
||||
Optional[ActionExecutor]: 运行时 RPC 执行闭包。
|
||||
"""
|
||||
|
||||
matched_entry = self._get_unique_component_entry(ComponentType.ACTION, name)
|
||||
if matched_entry is None:
|
||||
return None
|
||||
supervisor, entry = matched_entry
|
||||
return self._build_action_executor(supervisor, entry.plugin_id, entry.name)
|
||||
|
||||
def get_default_actions(self) -> Dict[str, ActionInfo]:
|
||||
"""获取当前默认启用的动作集合。
|
||||
|
||||
Returns:
|
||||
Dict[str, ActionInfo]: 动作名到动作信息的映射。
|
||||
"""
|
||||
|
||||
action_infos = self._collect_unique_component_infos(ComponentType.ACTION)
|
||||
return {name: info for name, info in action_infos.items() if isinstance(info, ActionInfo) and info.enabled}
|
||||
|
||||
def find_command_by_text(self, text: str) -> Optional[Tuple[CommandExecutor, dict, CommandInfo]]:
|
||||
"""根据文本查找匹配的命令。
|
||||
|
||||
Args:
|
||||
text: 待匹配的文本内容。
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[CommandExecutor, dict, CommandInfo]]: 匹配结果。
|
||||
"""
|
||||
|
||||
for supervisor in self._iter_supervisors():
|
||||
match_result = supervisor.component_registry.find_command_by_text(text)
|
||||
if match_result is None:
|
||||
continue
|
||||
|
||||
entry, matched_groups = match_result
|
||||
command_info = self._build_command_info(entry) # type: ignore[arg-type]
|
||||
command_executor = self._build_command_executor(
|
||||
supervisor,
|
||||
entry.plugin_id,
|
||||
entry.name,
|
||||
dict(entry.metadata),
|
||||
)
|
||||
return command_executor, matched_groups, command_info
|
||||
return None
|
||||
|
||||
def get_tool_info(self, name: str) -> Optional[ToolInfo]:
|
||||
"""获取指定工具的信息。
|
||||
|
||||
Args:
|
||||
name: 工具名称。
|
||||
|
||||
Returns:
|
||||
Optional[ToolInfo]: 匹配到的工具信息。
|
||||
"""
|
||||
|
||||
matched_entry = self._get_unique_component_entry(ComponentType.TOOL, name)
|
||||
if matched_entry is None:
|
||||
return None
|
||||
_supervisor, entry = matched_entry
|
||||
return self._build_tool_info(entry) # type: ignore[arg-type]
|
||||
|
||||
def get_tool_executor(self, name: str) -> Optional[ToolExecutor]:
|
||||
"""获取指定工具的执行器。
|
||||
|
||||
Args:
|
||||
name: 工具名称。
|
||||
|
||||
Returns:
|
||||
Optional[ToolExecutor]: 运行时 RPC 执行闭包。
|
||||
"""
|
||||
|
||||
matched_entry = self._get_unique_component_entry(ComponentType.TOOL, name)
|
||||
if matched_entry is None:
|
||||
return None
|
||||
supervisor, entry = matched_entry
|
||||
return self._build_tool_executor(supervisor, entry.plugin_id, entry.name)
|
||||
|
||||
def get_llm_available_tools(self) -> Dict[str, ToolInfo]:
|
||||
"""获取当前可供 LLM 选择的工具集合。
|
||||
|
||||
Returns:
|
||||
Dict[str, ToolInfo]: 工具名到工具信息的映射。
|
||||
"""
|
||||
|
||||
tool_infos = self._collect_unique_component_infos(ComponentType.TOOL)
|
||||
return {name: info for name, info in tool_infos.items() if isinstance(info, ToolInfo) and info.enabled}
|
||||
|
||||
def get_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
|
||||
"""获取某类组件的全部信息。
|
||||
|
||||
Args:
|
||||
component_type: 组件类型。
|
||||
|
||||
Returns:
|
||||
Dict[str, ComponentInfo]: 组件名到组件信息的映射。
|
||||
"""
|
||||
|
||||
return self._collect_unique_component_infos(component_type)
|
||||
|
||||
def get_plugin_config(self, plugin_name: str) -> Optional[dict]:
|
||||
"""读取指定插件的配置文件内容。
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称。
|
||||
|
||||
Returns:
|
||||
Optional[dict]: 读取成功时返回配置字典;未找到时返回 ``None``。
|
||||
"""
|
||||
|
||||
runtime_manager = self._get_runtime_manager()
|
||||
try:
|
||||
supervisor = runtime_manager._get_supervisor_for_plugin(plugin_name)
|
||||
except RuntimeError as exc:
|
||||
logger.error(f"读取插件配置失败: {exc}")
|
||||
return None
|
||||
|
||||
if supervisor is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
return runtime_manager._load_plugin_config_for_supervisor(supervisor, plugin_name)
|
||||
except Exception as exc:
|
||||
logger.error(f"读取插件 {plugin_name} 配置失败: {exc}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
component_query_service = ComponentQueryService()
|
||||
349
src/plugin_runtime/host/api_registry.py
Normal file
349
src/plugin_runtime/host/api_registry.py
Normal file
@@ -0,0 +1,349 @@
|
||||
"""Host 侧插件 API 动态注册表。"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("plugin_runtime.host.api_registry")
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class APIEntry:
|
||||
"""API 组件条目。"""
|
||||
|
||||
name: str
|
||||
plugin_id: str
|
||||
description: str = ""
|
||||
version: str = "1"
|
||||
public: bool = False
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
enabled: bool = True
|
||||
handler_name: str = ""
|
||||
dynamic: bool = False
|
||||
offline_reason: str = ""
|
||||
disabled_session: Set[str] = field(default_factory=set)
|
||||
full_name: str = field(init=False)
|
||||
registry_key: str = field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""规范化 API 条目字段。"""
|
||||
|
||||
self.name = str(self.name or "").strip()
|
||||
self.plugin_id = str(self.plugin_id or "").strip()
|
||||
self.description = str(self.description or "").strip()
|
||||
self.version = str(self.version or "1").strip() or "1"
|
||||
self.handler_name = str(self.handler_name or self.name).strip() or self.name
|
||||
self.offline_reason = str(self.offline_reason or "").strip()
|
||||
self.full_name = f"{self.plugin_id}.{self.name}"
|
||||
self.registry_key = APIRegistry.build_registry_key(self.plugin_id, self.name, self.version)
|
||||
|
||||
@classmethod
|
||||
def from_metadata(cls, name: str, plugin_id: str, metadata: Dict[str, Any]) -> "APIEntry":
|
||||
"""根据 Runner 上报的元数据构造 API 条目。"""
|
||||
|
||||
safe_metadata = dict(metadata)
|
||||
return cls(
|
||||
name=name,
|
||||
plugin_id=plugin_id,
|
||||
description=str(safe_metadata.get("description", "") or ""),
|
||||
version=str(safe_metadata.get("version", "1") or "1"),
|
||||
public=bool(safe_metadata.get("public", False)),
|
||||
metadata=safe_metadata,
|
||||
enabled=bool(safe_metadata.get("enabled", True)),
|
||||
handler_name=str(safe_metadata.get("handler_name", name) or name),
|
||||
dynamic=bool(safe_metadata.get("dynamic", False)),
|
||||
offline_reason=str(safe_metadata.get("offline_reason", "") or ""),
|
||||
)
|
||||
|
||||
|
||||
class APIRegistry:
|
||||
"""Host 侧插件 API 动态注册表。
|
||||
|
||||
该注册表不直接面向 Runner,而是复用插件组件注册/卸载事件,
|
||||
维护面向 API 调用场景的专用索引。
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化 API 注册表。"""
|
||||
|
||||
self._apis: Dict[str, APIEntry] = {}
|
||||
self._by_full_name: Dict[str, List[APIEntry]] = {}
|
||||
self._by_plugin: Dict[str, List[APIEntry]] = {}
|
||||
self._by_name: Dict[str, List[APIEntry]] = {}
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清空全部 API 注册状态。"""
|
||||
|
||||
self._apis.clear()
|
||||
self._by_full_name.clear()
|
||||
self._by_plugin.clear()
|
||||
self._by_name.clear()
|
||||
|
||||
@staticmethod
|
||||
def _is_api_component(component_type: Any) -> bool:
|
||||
"""判断组件声明是否属于 API。"""
|
||||
|
||||
return str(component_type or "").strip().upper() == "API"
|
||||
|
||||
@staticmethod
|
||||
def _normalize_query_version(version: Any) -> str:
|
||||
"""规范化查询使用的版本字符串。"""
|
||||
|
||||
return str(version or "").strip()
|
||||
|
||||
@classmethod
|
||||
def _split_reference(cls, reference: str, version: Any = "") -> Tuple[str, str]:
|
||||
"""解析可能带 ``@version`` 后缀的 API 引用。"""
|
||||
|
||||
normalized_reference = str(reference or "").strip()
|
||||
normalized_version = cls._normalize_query_version(version)
|
||||
if normalized_reference and not normalized_version and "@" in normalized_reference:
|
||||
candidate_reference, candidate_version = normalized_reference.rsplit("@", 1)
|
||||
candidate_reference = candidate_reference.strip()
|
||||
candidate_version = candidate_version.strip()
|
||||
if candidate_reference and candidate_version:
|
||||
normalized_reference = candidate_reference
|
||||
normalized_version = candidate_version
|
||||
return normalized_reference, normalized_version
|
||||
|
||||
@staticmethod
|
||||
def build_registry_key(plugin_id: str, name: str, version: str) -> str:
|
||||
"""构造 API 注册表唯一键。"""
|
||||
|
||||
normalized_full_name = f"{str(plugin_id or '').strip()}.{str(name or '').strip()}"
|
||||
normalized_version = str(version or "1").strip() or "1"
|
||||
return f"{normalized_full_name}@{normalized_version}"
|
||||
|
||||
@staticmethod
|
||||
def check_api_enabled(entry: APIEntry, session_id: Optional[str] = None) -> bool:
|
||||
"""判断 API 条目当前是否处于启用状态。"""
|
||||
|
||||
if session_id and session_id in entry.disabled_session:
|
||||
return False
|
||||
return entry.enabled
|
||||
|
||||
def register_api(self, name: str, plugin_id: str, metadata: Dict[str, Any]) -> bool:
|
||||
"""注册单个 API 条目。"""
|
||||
|
||||
normalized_name = str(name or "").strip()
|
||||
if not normalized_name:
|
||||
logger.warning(f"插件 {plugin_id} 存在空 API 名称声明,已忽略")
|
||||
return False
|
||||
|
||||
entry = APIEntry.from_metadata(name=normalized_name, plugin_id=plugin_id, metadata=metadata)
|
||||
existing_entry = self._apis.get(entry.registry_key)
|
||||
if existing_entry is not None:
|
||||
logger.warning(f"API {entry.registry_key} 已存在,覆盖旧条目")
|
||||
self._remove_entry(existing_entry)
|
||||
|
||||
self._apis[entry.registry_key] = entry
|
||||
self._by_full_name.setdefault(entry.full_name, []).append(entry)
|
||||
self._by_plugin.setdefault(plugin_id, []).append(entry)
|
||||
self._by_name.setdefault(entry.name, []).append(entry)
|
||||
return True
|
||||
|
||||
def register_plugin_apis(self, plugin_id: str, components: List[Dict[str, Any]]) -> int:
|
||||
"""批量注册某个插件声明的全部 API。"""
|
||||
|
||||
count = 0
|
||||
for component in components:
|
||||
if not self._is_api_component(component.get("component_type")):
|
||||
continue
|
||||
if self.register_api(
|
||||
name=str(component.get("name", "") or ""),
|
||||
plugin_id=plugin_id,
|
||||
metadata=component.get("metadata", {}) if isinstance(component.get("metadata"), dict) else {},
|
||||
):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def replace_plugin_dynamic_apis(
|
||||
self,
|
||||
plugin_id: str,
|
||||
components: List[Dict[str, Any]],
|
||||
*,
|
||||
offline_reason: str = "动态 API 已下线",
|
||||
) -> Tuple[int, int]:
|
||||
"""替换指定插件当前声明的动态 API 集合。"""
|
||||
|
||||
normalized_offline_reason = str(offline_reason or "").strip() or "动态 API 已下线"
|
||||
desired_registry_keys: Set[str] = set()
|
||||
registered_count = 0
|
||||
|
||||
for component in components:
|
||||
if not self._is_api_component(component.get("component_type")):
|
||||
continue
|
||||
metadata = component.get("metadata", {}) if isinstance(component.get("metadata"), dict) else {}
|
||||
dynamic_metadata = dict(metadata)
|
||||
dynamic_metadata["dynamic"] = True
|
||||
dynamic_metadata.pop("offline_reason", None)
|
||||
|
||||
entry = APIEntry.from_metadata(
|
||||
name=str(component.get("name", "") or ""),
|
||||
plugin_id=plugin_id,
|
||||
metadata=dynamic_metadata,
|
||||
)
|
||||
desired_registry_keys.add(entry.registry_key)
|
||||
if self.register_api(entry.name, plugin_id, dynamic_metadata):
|
||||
registered_count += 1
|
||||
|
||||
offlined_count = 0
|
||||
for entry in list(self._by_plugin.get(plugin_id, [])):
|
||||
if not entry.dynamic or entry.registry_key in desired_registry_keys:
|
||||
continue
|
||||
entry.enabled = False
|
||||
entry.offline_reason = normalized_offline_reason
|
||||
entry.metadata["offline_reason"] = normalized_offline_reason
|
||||
offlined_count += 1
|
||||
|
||||
return registered_count, offlined_count
|
||||
|
||||
def _remove_entry(self, entry: APIEntry) -> None:
|
||||
"""从全部索引中移除单个 API 条目。"""
|
||||
|
||||
self._apis.pop(entry.registry_key, None)
|
||||
|
||||
full_name_entries = self._by_full_name.get(entry.full_name)
|
||||
if full_name_entries is not None:
|
||||
self._by_full_name[entry.full_name] = [
|
||||
candidate for candidate in full_name_entries if candidate is not entry
|
||||
]
|
||||
if not self._by_full_name[entry.full_name]:
|
||||
self._by_full_name.pop(entry.full_name, None)
|
||||
|
||||
plugin_entries = self._by_plugin.get(entry.plugin_id)
|
||||
if plugin_entries is not None:
|
||||
self._by_plugin[entry.plugin_id] = [candidate for candidate in plugin_entries if candidate is not entry]
|
||||
if not self._by_plugin[entry.plugin_id]:
|
||||
self._by_plugin.pop(entry.plugin_id, None)
|
||||
|
||||
name_entries = self._by_name.get(entry.name)
|
||||
if name_entries is not None:
|
||||
self._by_name[entry.name] = [candidate for candidate in name_entries if candidate is not entry]
|
||||
if not self._by_name[entry.name]:
|
||||
self._by_name.pop(entry.name, None)
|
||||
|
||||
def remove_apis_by_plugin(self, plugin_id: str) -> int:
|
||||
"""移除某个插件的全部 API。"""
|
||||
|
||||
entries = list(self._by_plugin.get(plugin_id, []))
|
||||
for entry in entries:
|
||||
self._remove_entry(entry)
|
||||
return len(entries)
|
||||
|
||||
def get_api_by_full_name(
|
||||
self,
|
||||
full_name: str,
|
||||
*,
|
||||
version: str = "",
|
||||
enabled_only: bool = True,
|
||||
session_id: Optional[str] = None,
|
||||
) -> Optional[APIEntry]:
|
||||
"""按完整名查询单个 API。"""
|
||||
|
||||
normalized_full_name, normalized_version = self._split_reference(full_name, version)
|
||||
if not normalized_full_name:
|
||||
return None
|
||||
|
||||
if normalized_version:
|
||||
entry = self._apis.get(f"{normalized_full_name}@{normalized_version}")
|
||||
if entry is None:
|
||||
return None
|
||||
if enabled_only and not self.check_api_enabled(entry, session_id):
|
||||
return None
|
||||
return entry
|
||||
|
||||
candidates = list(self._by_full_name.get(normalized_full_name, []))
|
||||
filtered_entries = [
|
||||
entry
|
||||
for entry in candidates
|
||||
if not enabled_only or self.check_api_enabled(entry, session_id)
|
||||
]
|
||||
if len(filtered_entries) != 1:
|
||||
return None
|
||||
return filtered_entries[0]
|
||||
|
||||
def get_api(
|
||||
self,
|
||||
plugin_id: str,
|
||||
name: str,
|
||||
*,
|
||||
version: str = "",
|
||||
enabled_only: bool = True,
|
||||
session_id: Optional[str] = None,
|
||||
) -> Optional[APIEntry]:
|
||||
"""按插件 ID、短名与版本查询单个 API。"""
|
||||
|
||||
return self.get_api_by_full_name(
|
||||
f"{plugin_id}.{name}",
|
||||
version=version,
|
||||
enabled_only=enabled_only,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
def get_apis(
|
||||
self,
|
||||
*,
|
||||
plugin_id: Optional[str] = None,
|
||||
name: str = "",
|
||||
version: str = "",
|
||||
enabled_only: bool = True,
|
||||
session_id: Optional[str] = None,
|
||||
) -> List[APIEntry]:
|
||||
"""查询 API 列表。"""
|
||||
|
||||
normalized_name = str(name or "").strip()
|
||||
normalized_version = self._normalize_query_version(version)
|
||||
|
||||
if plugin_id:
|
||||
candidates = list(self._by_plugin.get(plugin_id, []))
|
||||
elif normalized_name:
|
||||
candidates = list(self._by_name.get(normalized_name, []))
|
||||
else:
|
||||
candidates = list(self._apis.values())
|
||||
|
||||
filtered_entries: List[APIEntry] = []
|
||||
for entry in candidates:
|
||||
if plugin_id and entry.plugin_id != plugin_id:
|
||||
continue
|
||||
if normalized_name and entry.name != normalized_name:
|
||||
continue
|
||||
if normalized_version and entry.version != normalized_version:
|
||||
continue
|
||||
if enabled_only and not self.check_api_enabled(entry, session_id):
|
||||
continue
|
||||
filtered_entries.append(entry)
|
||||
|
||||
filtered_entries.sort(key=lambda entry: (entry.plugin_id, entry.name, entry.version))
|
||||
return filtered_entries
|
||||
|
||||
def toggle_api_status(
|
||||
self,
|
||||
full_name: str,
|
||||
enabled: bool,
|
||||
*,
|
||||
version: str = "",
|
||||
session_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""设置指定 API 的启用状态。"""
|
||||
|
||||
entry = self.get_api_by_full_name(
|
||||
full_name,
|
||||
version=version,
|
||||
enabled_only=False,
|
||||
session_id=session_id,
|
||||
)
|
||||
if entry is None:
|
||||
return False
|
||||
if session_id:
|
||||
if enabled:
|
||||
entry.disabled_session.discard(session_id)
|
||||
else:
|
||||
entry.disabled_session.add(session_id)
|
||||
else:
|
||||
entry.enabled = enabled
|
||||
if enabled:
|
||||
entry.offline_reason = ""
|
||||
entry.metadata.pop("offline_reason", None)
|
||||
return True
|
||||
67
src/plugin_runtime/host/authorization.py
Normal file
67
src/plugin_runtime/host/authorization.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""授权管理器
|
||||
|
||||
负责管理插件的能力授权以及校验
|
||||
每个插件在 manifest 中声明能力需求,Host 启动时签发能力令牌。
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
_ALWAYS_ALLOWED_CAPABILITIES = frozenset({"api.replace_dynamic"})
|
||||
|
||||
|
||||
@dataclass
|
||||
class CapabilityPermissionToken:
|
||||
"""能力令牌"""
|
||||
|
||||
plugin_id: str
|
||||
capabilities: Set[str] = field(default_factory=set)
|
||||
|
||||
|
||||
class AuthorizationManager:
|
||||
"""授权管理器
|
||||
|
||||
管理所有插件的能力令牌,提供授权校验。
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._permission_tokens: Dict[str, CapabilityPermissionToken] = {}
|
||||
|
||||
def register_plugin(self, plugin_id: str, capabilities: List[str]) -> CapabilityPermissionToken:
|
||||
"""为插件签发能力令牌"""
|
||||
token = CapabilityPermissionToken(plugin_id=plugin_id, capabilities=set(capabilities))
|
||||
self._permission_tokens[plugin_id] = token
|
||||
return token
|
||||
|
||||
def revoke_permission_token(self, plugin_id: str):
|
||||
"""移除插件的能力令牌。"""
|
||||
self._permission_tokens.pop(plugin_id, None)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清空所有能力令牌。"""
|
||||
self._permission_tokens.clear()
|
||||
|
||||
def check_capability(self, plugin_id: str, capability: str) -> Tuple[bool, str]:
|
||||
# sourcery skip: assign-if-exp, reintroduce-else, swap-if-else-branches, use-named-expression
|
||||
"""检查插件是否有权调用某项能力
|
||||
|
||||
Returns:
|
||||
return (bool, str): (是否有此能力, 原因)
|
||||
"""
|
||||
if capability in _ALWAYS_ALLOWED_CAPABILITIES:
|
||||
return True, ""
|
||||
|
||||
token = self._permission_tokens.get(plugin_id)
|
||||
if not token:
|
||||
return False, f"插件 {plugin_id} 未注册能力令牌"
|
||||
if capability not in token.capabilities:
|
||||
return False, f"插件 {plugin_id} 未获授权能力: {capability}"
|
||||
return True, ""
|
||||
|
||||
def get_token(self, plugin_id: str) -> Optional[CapabilityPermissionToken]:
|
||||
"""获取插件的能力令牌"""
|
||||
return self._permission_tokens.get(plugin_id)
|
||||
|
||||
def list_plugins(self) -> List[str]:
|
||||
"""列出所有已注册的插件"""
|
||||
return list(self._permission_tokens.keys())
|
||||
@@ -4,21 +4,19 @@ Host 端实现的能力服务,处理来自插件的 cap.* 请求。
|
||||
每个能力方法被注册到 RPC Server,接收 Runner 转发的请求并执行实际操作。
|
||||
"""
|
||||
|
||||
from typing import Any, Awaitable, Callable, Dict, List
|
||||
from typing import Any, Callable, Dict, List, Coroutine, TYPE_CHECKING
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_runtime.host.policy_engine import PolicyEngine
|
||||
from src.plugin_runtime.protocol.envelope import (
|
||||
CapabilityRequestPayload,
|
||||
CapabilityResponsePayload,
|
||||
Envelope,
|
||||
)
|
||||
from src.plugin_runtime.protocol.envelope import CapabilityRequestPayload, CapabilityResponsePayload, Envelope
|
||||
from src.plugin_runtime.protocol.errors import ErrorCode, RPCError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.plugin_runtime.host.authorization import AuthorizationManager
|
||||
|
||||
logger = get_logger("plugin_runtime.host.capability_service")
|
||||
|
||||
# 能力实现函数类型: (plugin_id, capability, args) -> result
|
||||
CapabilityImpl = Callable[[str, str, Dict[str, Any]], Awaitable[Any]]
|
||||
CapabilityImpl = Callable[[str, str, Dict[str, Any]], Coroutine[Any, Any, Any]]
|
||||
|
||||
|
||||
class CapabilityService:
|
||||
@@ -31,8 +29,13 @@ class CapabilityService:
|
||||
4. 执行实际操作并返回结果
|
||||
"""
|
||||
|
||||
def __init__(self, policy_engine: PolicyEngine) -> None:
|
||||
self._policy = policy_engine
|
||||
def __init__(self, authorization: "AuthorizationManager") -> None:
|
||||
"""初始化能力服务。
|
||||
|
||||
Args:
|
||||
authorization: 能力授权管理器。
|
||||
"""
|
||||
self._authorization = authorization
|
||||
# capability_name -> implementation
|
||||
self._implementations: Dict[str, CapabilityImpl] = {}
|
||||
|
||||
@@ -56,46 +59,32 @@ class CapabilityService:
|
||||
|
||||
try:
|
||||
req = CapabilityRequestPayload.model_validate(envelope.payload)
|
||||
except Exception as e:
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_BAD_PAYLOAD.value,
|
||||
f"能力调用 payload 格式错误: {e}",
|
||||
)
|
||||
except Exception as exc:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, f"能力调用 payload 非法: {exc}")
|
||||
|
||||
capability = req.capability
|
||||
args = req.args
|
||||
|
||||
# 1. 权限校验
|
||||
allowed, reason = self._policy.check_capability(plugin_id, capability, envelope.generation)
|
||||
allowed, reason = self._authorization.check_capability(plugin_id, capability)
|
||||
if not allowed:
|
||||
error_code = (
|
||||
ErrorCode.E_GENERATION_MISMATCH if "generation 不匹配" in reason else ErrorCode.E_CAPABILITY_DENIED
|
||||
)
|
||||
return envelope.make_error_response(
|
||||
error_code.value,
|
||||
reason,
|
||||
)
|
||||
return envelope.make_error_response(ErrorCode.E_CAPABILITY_DENIED.value, reason)
|
||||
|
||||
# 2. 查找实现
|
||||
impl = self._implementations.get(capability)
|
||||
if impl is None:
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_METHOD_NOT_ALLOWED.value,
|
||||
f"未注册的能力: {capability}",
|
||||
)
|
||||
return envelope.make_error_response(ErrorCode.E_METHOD_NOT_ALLOWED.value, f"未注册的能力: {capability}")
|
||||
|
||||
# 3. 执行
|
||||
try:
|
||||
result = await impl(plugin_id, capability, req.args)
|
||||
result = await impl(plugin_id, capability, args)
|
||||
resp_payload = CapabilityResponsePayload(success=True, result=result)
|
||||
return envelope.make_response(payload=resp_payload.model_dump())
|
||||
except RPCError as e:
|
||||
return envelope.make_error_response(e.code.value, e.message, e.details)
|
||||
except Exception as e:
|
||||
logger.error(f"能力 {capability} 执行异常: {e}", exc_info=True)
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_CAPABILITY_FAILED.value,
|
||||
str(e),
|
||||
)
|
||||
return envelope.make_error_response(ErrorCode.E_CAPABILITY_FAILED.value, str(e))
|
||||
|
||||
def list_capabilities(self) -> List[str]:
|
||||
"""列出所有已注册的能力"""
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Host-side ComponentRegistry
|
||||
|
||||
对齐旧系统 component_registry.py 的核心能力:
|
||||
- 按类型注册组件(action / command / tool / event_handler / workflow_step)
|
||||
- 按类型注册组件(action / command / tool / event_handler / workflow_handler / message_gateway)
|
||||
- 命名空间 (plugin_id.component_name)
|
||||
- 命令正则匹配
|
||||
- 组件启用/禁用
|
||||
@@ -9,8 +9,10 @@
|
||||
- 注册统计
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Set, TypedDict, Tuple
|
||||
|
||||
import contextlib
|
||||
import re
|
||||
|
||||
from src.common.logger import get_logger
|
||||
@@ -18,8 +20,28 @@ from src.common.logger import get_logger
|
||||
logger = get_logger("plugin_runtime.host.component_registry")
|
||||
|
||||
|
||||
class RegisteredComponent:
|
||||
"""已注册的组件条目"""
|
||||
class ComponentTypes(str, Enum):
|
||||
ACTION = "ACTION"
|
||||
COMMAND = "COMMAND"
|
||||
TOOL = "TOOL"
|
||||
EVENT_HANDLER = "EVENT_HANDLER"
|
||||
HOOK_HANDLER = "HOOK_HANDLER"
|
||||
MESSAGE_GATEWAY = "MESSAGE_GATEWAY"
|
||||
|
||||
|
||||
class StatusDict(TypedDict):
|
||||
total: int
|
||||
action: int
|
||||
command: int
|
||||
tool: int
|
||||
event_handler: int
|
||||
hook_handler: int
|
||||
message_gateway: int
|
||||
plugins: int
|
||||
|
||||
|
||||
class ComponentEntry:
|
||||
"""组件条目"""
|
||||
|
||||
__slots__ = (
|
||||
"name",
|
||||
@@ -28,31 +50,120 @@ class RegisteredComponent:
|
||||
"plugin_id",
|
||||
"metadata",
|
||||
"enabled",
|
||||
"_compiled_pattern",
|
||||
"compiled_pattern",
|
||||
"disabled_session",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
component_type: str,
|
||||
plugin_id: str,
|
||||
metadata: Dict[str, Any],
|
||||
) -> None:
|
||||
self.name = name
|
||||
self.full_name = f"{plugin_id}.{name}"
|
||||
self.component_type = component_type
|
||||
self.plugin_id = plugin_id
|
||||
self.metadata = metadata
|
||||
self.enabled = metadata.get("enabled", True)
|
||||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||||
self.name: str = name
|
||||
self.full_name: str = f"{plugin_id}.{name}"
|
||||
self.component_type: ComponentTypes = ComponentTypes(component_type)
|
||||
self.plugin_id: str = plugin_id
|
||||
self.metadata: Dict[str, Any] = metadata
|
||||
self.enabled: bool = metadata.get("enabled", True)
|
||||
self.disabled_session: Set[str] = set()
|
||||
|
||||
# 预编译命令正则(仅 command 类型)
|
||||
self._compiled_pattern: Optional[re.Pattern] = None
|
||||
if component_type == "command":
|
||||
if pattern := metadata.get("command_pattern", ""):
|
||||
try:
|
||||
self._compiled_pattern = re.compile(pattern)
|
||||
except re.error as e:
|
||||
logger.warning(f"命令 {self.full_name} 正则编译失败: {e}")
|
||||
|
||||
class ActionEntry(ComponentEntry):
|
||||
"""Action 组件条目"""
|
||||
|
||||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||||
super().__init__(name, component_type, plugin_id, metadata)
|
||||
|
||||
|
||||
class CommandEntry(ComponentEntry):
|
||||
"""Command 组件条目"""
|
||||
|
||||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||||
super().__init__(name, component_type, plugin_id, metadata)
|
||||
self.aliases: List[str] = metadata.get("aliases", [])
|
||||
self.compiled_pattern: Optional[re.Pattern] = None
|
||||
if pattern := metadata.get("command_pattern", ""):
|
||||
try:
|
||||
self.compiled_pattern = re.compile(pattern)
|
||||
except (re.error, TypeError) as e:
|
||||
logger.warning(f"命令 {self.full_name} 正则编译失败: {e}")
|
||||
|
||||
|
||||
class ToolEntry(ComponentEntry):
|
||||
"""Tool 组件条目"""
|
||||
|
||||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||||
self.description: str = metadata.get("description", "")
|
||||
self.parameters: List[Dict[str, Any]] = metadata.get("parameters", [])
|
||||
self.parameters_raw: List[Dict[str, Any]] = metadata.get("parameters_raw", [])
|
||||
super().__init__(name, component_type, plugin_id, metadata)
|
||||
|
||||
|
||||
class EventHandlerEntry(ComponentEntry):
|
||||
"""EventHandler 组件条目"""
|
||||
|
||||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||||
self.event_type: str = metadata.get("event_type", "")
|
||||
self.weight: int = metadata.get("weight", 0)
|
||||
self.intercept_message: bool = metadata.get("intercept_message", False)
|
||||
super().__init__(name, component_type, plugin_id, metadata)
|
||||
|
||||
|
||||
class HookHandlerEntry(ComponentEntry):
|
||||
"""WorkflowHandler 组件条目"""
|
||||
|
||||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||||
self.stage: str = metadata.get("stage", "")
|
||||
self.priority: int = metadata.get("priority", 0)
|
||||
self.blocking: bool = metadata.get("blocking", False)
|
||||
super().__init__(name, component_type, plugin_id, metadata)
|
||||
|
||||
|
||||
class MessageGatewayEntry(ComponentEntry):
|
||||
"""MessageGateway 组件条目"""
|
||||
|
||||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||||
self.route_type: str = self._normalize_route_type(metadata.get("route_type", ""))
|
||||
self.platform: str = str(metadata.get("platform", "") or "").strip()
|
||||
self.protocol: str = str(metadata.get("protocol", "") or "").strip()
|
||||
self.account_id: str = str(metadata.get("account_id", "") or "").strip()
|
||||
self.scope: str = str(metadata.get("scope", "") or "").strip()
|
||||
super().__init__(name, component_type, plugin_id, metadata)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_route_type(raw_value: Any) -> str:
|
||||
"""规范化消息网关路由类型。
|
||||
|
||||
Args:
|
||||
raw_value: 原始路由类型值。
|
||||
|
||||
Returns:
|
||||
str: 规范化后的路由类型。
|
||||
|
||||
Raises:
|
||||
ValueError: 当路由类型不受支持时抛出。
|
||||
"""
|
||||
|
||||
normalized_value = str(raw_value or "").strip().lower()
|
||||
route_type_aliases = {
|
||||
"send": "send",
|
||||
"receive": "receive",
|
||||
"recv": "receive",
|
||||
"recive": "receive",
|
||||
"duplex": "duplex",
|
||||
}
|
||||
route_type = route_type_aliases.get(normalized_value)
|
||||
if route_type is None:
|
||||
raise ValueError(f"MessageGateway 路由类型不合法: {raw_value}")
|
||||
return route_type
|
||||
|
||||
@property
|
||||
def supports_send(self) -> bool:
|
||||
"""返回当前网关是否支持出站。"""
|
||||
|
||||
return self.route_type in {"send", "duplex"}
|
||||
|
||||
@property
|
||||
def supports_receive(self) -> bool:
|
||||
"""返回当前网关是否支持入站。"""
|
||||
|
||||
return self.route_type in {"receive", "duplex"}
|
||||
|
||||
|
||||
class ComponentRegistry:
|
||||
@@ -64,19 +175,32 @@ class ComponentRegistry:
|
||||
|
||||
def __init__(self) -> None:
|
||||
# 全量索引
|
||||
self._components: Dict[str, RegisteredComponent] = {} # full_name -> comp
|
||||
self._components: Dict[str, ComponentEntry] = {} # full_name -> comp
|
||||
|
||||
# 按类型索引
|
||||
self._by_type: Dict[str, Dict[str, RegisteredComponent]] = {
|
||||
"action": {},
|
||||
"command": {},
|
||||
"tool": {},
|
||||
"event_handler": {},
|
||||
"workflow_step": {},
|
||||
}
|
||||
self._by_type: Dict[ComponentTypes, Dict[str, ComponentEntry]] = {
|
||||
comp_type: {} for comp_type in ComponentTypes
|
||||
} # component_type -> (full_name -> comp)
|
||||
|
||||
# 按插件索引
|
||||
self._by_plugin: Dict[str, List[RegisteredComponent]] = {}
|
||||
self._by_plugin: Dict[str, List[ComponentEntry]] = {}
|
||||
|
||||
@staticmethod
|
||||
def _normalize_component_type(component_type: str) -> ComponentTypes:
|
||||
"""规范化组件类型输入。
|
||||
|
||||
Args:
|
||||
component_type: 原始组件类型字符串。
|
||||
|
||||
Returns:
|
||||
ComponentTypes: 规范化后的组件类型枚举。
|
||||
|
||||
Raises:
|
||||
ValueError: 当组件类型不受支持时抛出。
|
||||
"""
|
||||
|
||||
normalized_value = str(component_type or "").strip().upper()
|
||||
return ComponentTypes(normalized_value)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清空全部组件注册状态。"""
|
||||
@@ -85,47 +209,64 @@ class ComponentRegistry:
|
||||
type_dict.clear()
|
||||
self._by_plugin.clear()
|
||||
|
||||
# ──── 注册 / 注销 ─────────────────────────────────────────
|
||||
# ====== 注册 / 注销 ======
|
||||
def register_component(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> bool:
|
||||
"""注册单个组件
|
||||
|
||||
Args:
|
||||
name: 组件名称(不含插件id前缀)
|
||||
component_type: 组件类型(如 `ACTION`、`COMMAND` 等)
|
||||
plugin_id: 插件id
|
||||
metadata: 组件元数据
|
||||
Returns:
|
||||
success (bool): 是否成功注册(失败原因通常是组件类型无效)
|
||||
"""
|
||||
try:
|
||||
normalized_type = self._normalize_component_type(component_type)
|
||||
if normalized_type == ComponentTypes.ACTION:
|
||||
comp = ActionEntry(name, normalized_type.value, plugin_id, metadata)
|
||||
elif normalized_type == ComponentTypes.COMMAND:
|
||||
comp = CommandEntry(name, normalized_type.value, plugin_id, metadata)
|
||||
elif normalized_type == ComponentTypes.TOOL:
|
||||
comp = ToolEntry(name, normalized_type.value, plugin_id, metadata)
|
||||
elif normalized_type == ComponentTypes.EVENT_HANDLER:
|
||||
comp = EventHandlerEntry(name, normalized_type.value, plugin_id, metadata)
|
||||
elif normalized_type == ComponentTypes.HOOK_HANDLER:
|
||||
comp = HookHandlerEntry(name, normalized_type.value, plugin_id, metadata)
|
||||
elif normalized_type == ComponentTypes.MESSAGE_GATEWAY:
|
||||
comp = MessageGatewayEntry(name, normalized_type.value, plugin_id, metadata)
|
||||
else:
|
||||
raise ValueError(f"组件类型 {component_type} 不存在")
|
||||
except ValueError:
|
||||
logger.error(f"组件类型 {component_type} 不存在")
|
||||
return False
|
||||
|
||||
def register_component(
|
||||
self,
|
||||
name: str,
|
||||
component_type: str,
|
||||
plugin_id: str,
|
||||
metadata: Dict[str, Any],
|
||||
) -> bool:
|
||||
"""注册单个组件。"""
|
||||
comp = RegisteredComponent(name, component_type, plugin_id, metadata)
|
||||
if comp.full_name in self._components:
|
||||
logger.warning(f"组件 {comp.full_name} 已存在,覆盖")
|
||||
old_comp = self._components[comp.full_name]
|
||||
# 从 _by_plugin 列表中移除旧条目,防止幽灵组件堆积
|
||||
old_list = self._by_plugin.get(old_comp.plugin_id)
|
||||
if old_list is not None:
|
||||
try:
|
||||
with contextlib.suppress(ValueError):
|
||||
old_list.remove(old_comp)
|
||||
except ValueError:
|
||||
pass
|
||||
# 从旧类型索引中移除,防止类型变更时幽灵残留
|
||||
if old_type_dict := self._by_type.get(old_comp.component_type):
|
||||
old_type_dict.pop(comp.full_name, None)
|
||||
|
||||
self._components[comp.full_name] = comp
|
||||
|
||||
if component_type not in self._by_type:
|
||||
self._by_type[component_type] = {}
|
||||
self._by_type[component_type][comp.full_name] = comp
|
||||
|
||||
self._by_type[comp.component_type][comp.full_name] = comp
|
||||
self._by_plugin.setdefault(plugin_id, []).append(comp)
|
||||
|
||||
return True
|
||||
|
||||
def register_plugin_components(
|
||||
self,
|
||||
plugin_id: str,
|
||||
components: List[Dict[str, Any]],
|
||||
) -> int:
|
||||
"""批量注册一个插件的所有组件,返回成功注册数。"""
|
||||
def register_plugin_components(self, plugin_id: str, components: List[Dict[str, Any]]) -> int:
|
||||
"""批量注册一个插件的所有组件,返回成功注册数。
|
||||
Args:
|
||||
plugin_id (str): 插件id
|
||||
components (List[Dict[str, Any]]): 组件字典列表,每个组件包含 name, component_type, metadata 等字段
|
||||
Returns:
|
||||
count (int): 成功注册的组件数量
|
||||
"""
|
||||
count = 0
|
||||
for comp_data in components:
|
||||
ok = self.register_component(
|
||||
@@ -139,7 +280,13 @@ class ComponentRegistry:
|
||||
return count
|
||||
|
||||
def remove_components_by_plugin(self, plugin_id: str) -> int:
|
||||
"""移除某个插件的所有组件,返回移除数量。"""
|
||||
"""移除某个插件的所有组件,返回移除数量。
|
||||
|
||||
Args:
|
||||
plugin_id (str): 插件id
|
||||
Returns:
|
||||
count (int): 移除的组件数量
|
||||
"""
|
||||
comps = self._by_plugin.pop(plugin_id, [])
|
||||
for comp in comps:
|
||||
self._components.pop(comp.full_name, None)
|
||||
@@ -147,106 +294,280 @@ class ComponentRegistry:
|
||||
type_dict.pop(comp.full_name, None)
|
||||
return len(comps)
|
||||
|
||||
# ──── 启用 / 禁用 ─────────────────────────────────────────
|
||||
# ====== 启用 / 禁用 ======
|
||||
def check_component_enabled(self, component: ComponentEntry, session_id: Optional[str] = None):
|
||||
if session_id and session_id in component.disabled_session:
|
||||
return False
|
||||
return component.enabled
|
||||
|
||||
def set_component_enabled(self, full_name: str, enabled: bool) -> bool:
|
||||
"""启用或禁用指定组件。"""
|
||||
def toggle_component_status(self, full_name: str, enabled: bool, session_id: Optional[str] = None) -> bool:
|
||||
"""启用或禁用指定组件。
|
||||
|
||||
Args:
|
||||
full_name (str): 组件全名
|
||||
enabled (bool): 使能情况
|
||||
session_id (Optional[str]): 可选的会话ID,仅对该会话禁用(如果提供)
|
||||
Returns:
|
||||
success (bool): 是否成功设置(失败原因通常是组件不存在)
|
||||
"""
|
||||
comp = self._components.get(full_name)
|
||||
if comp is None:
|
||||
return False
|
||||
comp.enabled = enabled
|
||||
if session_id:
|
||||
if enabled:
|
||||
comp.disabled_session.discard(session_id)
|
||||
else:
|
||||
comp.disabled_session.add(session_id)
|
||||
else:
|
||||
comp.enabled = enabled
|
||||
return True
|
||||
|
||||
def set_plugin_enabled(self, plugin_id: str, enabled: bool) -> int:
|
||||
"""批量启用或禁用某插件的所有组件。"""
|
||||
def set_component_enabled(self, full_name: str, enabled: bool, session_id: Optional[str] = None) -> bool:
|
||||
"""设置指定组件的启用状态。
|
||||
|
||||
Args:
|
||||
full_name: 组件全名。
|
||||
enabled: 目标启用状态。
|
||||
session_id: 可选的会话 ID,仅对该会话生效。
|
||||
|
||||
Returns:
|
||||
bool: 是否设置成功。
|
||||
"""
|
||||
|
||||
return self.toggle_component_status(full_name, enabled, session_id=session_id)
|
||||
|
||||
def toggle_plugin_status(self, plugin_id: str, enabled: bool, session_id: Optional[str] = None) -> int:
|
||||
"""批量启用或禁用某插件的所有组件。
|
||||
|
||||
Args:
|
||||
plugin_id (str): 插件id
|
||||
enabled (bool): 使能情况
|
||||
session_id (Optional[str]): 可选的会话ID,仅对该会话禁用(如果提供)
|
||||
Returns:
|
||||
count (int): 成功设置的组件数量(失败原因通常是插件不存在)
|
||||
"""
|
||||
comps = self._by_plugin.get(plugin_id, [])
|
||||
for comp in comps:
|
||||
comp.enabled = enabled
|
||||
if session_id:
|
||||
if enabled:
|
||||
comp.disabled_session.discard(session_id)
|
||||
else:
|
||||
comp.disabled_session.add(session_id)
|
||||
else:
|
||||
comp.enabled = enabled
|
||||
return len(comps)
|
||||
|
||||
# ──── 查询方法 ─────────────────────────────────────────────
|
||||
def get_component(self, full_name: str) -> Optional[ComponentEntry]:
|
||||
"""按全名查询。
|
||||
|
||||
def get_component(self, full_name: str) -> Optional[RegisteredComponent]:
|
||||
"""按全名查询。"""
|
||||
Args:
|
||||
full_name (str): 组件全名
|
||||
Returns:
|
||||
component (Optional[ComponentEntry]): 组件条目,未找到时为 None
|
||||
"""
|
||||
return self._components.get(full_name)
|
||||
|
||||
def get_components_by_type(self, component_type: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
|
||||
"""按类型查询。"""
|
||||
type_dict = self._by_type.get(component_type, {})
|
||||
def get_components_by_type(
|
||||
self, component_type: str, *, enabled_only: bool = True, session_id: Optional[str] = None
|
||||
) -> List[ComponentEntry]:
|
||||
"""按类型查询组件
|
||||
|
||||
Args:
|
||||
component_type (str): 组件类型(如 `ACTION`、`COMMAND` 等)
|
||||
enabled_only (bool): 是否仅返回启用的组件
|
||||
session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态
|
||||
Returns:
|
||||
components (List[ComponentEntry]): 组件条目列表
|
||||
"""
|
||||
try:
|
||||
comp_type = self._normalize_component_type(component_type)
|
||||
except ValueError:
|
||||
logger.error(f"组件类型 {component_type} 不存在")
|
||||
raise
|
||||
type_dict = self._by_type.get(comp_type, {})
|
||||
if enabled_only:
|
||||
return [c for c in type_dict.values() if c.enabled]
|
||||
return [c for c in type_dict.values() if self.check_component_enabled(c, session_id)]
|
||||
return list(type_dict.values())
|
||||
|
||||
def get_components_by_plugin(self, plugin_id: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
|
||||
"""按插件查询。"""
|
||||
comps = self._by_plugin.get(plugin_id, [])
|
||||
return [c for c in comps if c.enabled] if enabled_only else list(comps)
|
||||
def get_components_by_plugin(
|
||||
self, plugin_id: str, *, enabled_only: bool = True, session_id: Optional[str] = None
|
||||
) -> List[ComponentEntry]:
|
||||
"""按插件查询组件。
|
||||
|
||||
def find_command_by_text(self, text: str) -> Optional[tuple[RegisteredComponent, Dict[str, Any]]]:
|
||||
Args:
|
||||
plugin_id (str): 插件ID
|
||||
enabled_only (bool): 是否仅返回启用的组件
|
||||
session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态
|
||||
Returns:
|
||||
components (List[ComponentEntry]): 组件条目列表
|
||||
"""
|
||||
comps = self._by_plugin.get(plugin_id, [])
|
||||
return [c for c in comps if self.check_component_enabled(c, session_id)] if enabled_only else list(comps)
|
||||
|
||||
def find_command_by_text(
|
||||
self, text: str, session_id: Optional[str] = None
|
||||
) -> Optional[Tuple[ComponentEntry, Dict[str, Any]]]:
|
||||
"""通过文本匹配命令正则,返回 (组件, matched_groups) 元组。
|
||||
|
||||
matched_groups 为正则命名捕获组 dict,别名匹配时为空 dict。
|
||||
Args:
|
||||
text (str): 待匹配文本
|
||||
session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态
|
||||
Returns:
|
||||
result (Optional[tuple[ComponentEntry, Dict[str, Any]]]): 匹配到的组件及正则捕获组,未找到时为 None
|
||||
"""
|
||||
for comp in self._by_type.get("command", {}).values():
|
||||
if not comp.enabled:
|
||||
for comp in self._by_type.get(ComponentTypes.COMMAND, {}).values():
|
||||
if not self.check_component_enabled(comp, session_id):
|
||||
continue
|
||||
if comp._compiled_pattern:
|
||||
m = comp._compiled_pattern.search(text)
|
||||
if m:
|
||||
if not isinstance(comp, CommandEntry):
|
||||
continue
|
||||
if comp.compiled_pattern:
|
||||
if m := comp.compiled_pattern.search(text):
|
||||
return comp, m.groupdict()
|
||||
# 别名匹配
|
||||
aliases = comp.metadata.get("aliases", [])
|
||||
for alias in aliases:
|
||||
for alias in comp.aliases:
|
||||
if text.startswith(alias):
|
||||
return comp, {}
|
||||
return None
|
||||
|
||||
def get_event_handlers(self, event_type: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
|
||||
"""获取特定事件类型的所有 event_handler,按 weight 降序排列。"""
|
||||
handlers = []
|
||||
for comp in self._by_type.get("event_handler", {}).values():
|
||||
if enabled_only and not comp.enabled:
|
||||
def get_event_handlers(
|
||||
self, event_type: str, *, enabled_only: bool = True, session_id: Optional[str] = None
|
||||
) -> List[EventHandlerEntry]:
|
||||
"""查询指定事件类型的事件处理器组件。
|
||||
|
||||
Args:
|
||||
event_type (str): 事件类型
|
||||
enabled_only (bool): 是否仅返回启用的组件
|
||||
session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态
|
||||
Returns:
|
||||
handlers (List[EventHandlerEntry]): 符合条件的 EventHandler 组件列表,按 weight 降序排序
|
||||
"""
|
||||
handlers: List[EventHandlerEntry] = []
|
||||
for comp in self._by_type.get(ComponentTypes.EVENT_HANDLER, {}).values():
|
||||
if enabled_only and not self.check_component_enabled(comp, session_id):
|
||||
continue
|
||||
if comp.metadata.get("event_type") == event_type:
|
||||
if not isinstance(comp, EventHandlerEntry):
|
||||
continue
|
||||
if comp.event_type == event_type:
|
||||
handlers.append(comp)
|
||||
handlers.sort(key=lambda c: c.metadata.get("weight", 0), reverse=True)
|
||||
handlers.sort(key=lambda c: c.weight, reverse=True)
|
||||
return handlers
|
||||
|
||||
def get_workflow_steps(self, stage: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
|
||||
"""获取特定 workflow 阶段的所有步骤,按 priority 降序。"""
|
||||
steps = []
|
||||
for comp in self._by_type.get("workflow_step", {}).values():
|
||||
if enabled_only and not comp.enabled:
|
||||
def get_hook_handlers(
|
||||
self, stage: str, *, enabled_only: bool = True, session_id: Optional[str] = None
|
||||
) -> List[HookHandlerEntry]:
|
||||
"""获取特定 hook 阶段的所有步骤,按 priority 降序。
|
||||
|
||||
Args:
|
||||
stage: hook 名称
|
||||
enabled_only: 是否仅返回启用的组件
|
||||
session_id: 可选的会话ID,若提供则考虑会话禁用状态
|
||||
Returns:
|
||||
handlers (List[HookHandlerEntry]): 符合条件的 HookHandler 组件列表,按 priority 降序排序
|
||||
"""
|
||||
handlers: List[HookHandlerEntry] = []
|
||||
for comp in self._by_type.get(ComponentTypes.HOOK_HANDLER, {}).values():
|
||||
if enabled_only and not self.check_component_enabled(comp, session_id):
|
||||
continue
|
||||
if comp.metadata.get("stage") == stage:
|
||||
steps.append(comp)
|
||||
steps.sort(key=lambda c: c.metadata.get("priority", 0), reverse=True)
|
||||
return steps
|
||||
if not isinstance(comp, HookHandlerEntry):
|
||||
continue
|
||||
if comp.stage == stage:
|
||||
handlers.append(comp)
|
||||
handlers.sort(key=lambda c: c.priority, reverse=True)
|
||||
return handlers
|
||||
|
||||
def get_tools_for_llm(self, *, enabled_only: bool = True) -> List[Dict[str, Any]]:
|
||||
"""获取可供 LLM 使用的工具列表(openai function-calling 格式预览)。"""
|
||||
result: List[Dict[str, Any]] = []
|
||||
for comp in self.get_components_by_type("tool", enabled_only=enabled_only):
|
||||
tool_def: Dict[str, Any] = {
|
||||
"name": comp.full_name,
|
||||
"description": comp.metadata.get("description", ""),
|
||||
}
|
||||
# 从结构化参数或原始参数构建 parameters
|
||||
params = comp.metadata.get("parameters", [])
|
||||
params_raw = comp.metadata.get("parameters_raw", {})
|
||||
if params:
|
||||
tool_def["parameters"] = params
|
||||
elif params_raw:
|
||||
tool_def["parameters"] = params_raw
|
||||
result.append(tool_def)
|
||||
return result
|
||||
def get_message_gateway(
|
||||
self,
|
||||
plugin_id: str,
|
||||
name: str,
|
||||
*,
|
||||
enabled_only: bool = True,
|
||||
session_id: Optional[str] = None,
|
||||
) -> Optional[MessageGatewayEntry]:
|
||||
"""按插件和组件名获取单个消息网关。
|
||||
|
||||
# ──── 统计 ─────────────────────────────────────────────────
|
||||
Args:
|
||||
plugin_id: 插件 ID。
|
||||
name: 网关组件名称。
|
||||
enabled_only: 是否仅返回启用的组件。
|
||||
session_id: 可选的会话 ID。
|
||||
|
||||
def get_stats(self) -> Dict[str, int]:
|
||||
"""获取注册统计。"""
|
||||
stats: Dict[str, int] = {"total": len(self._components)}
|
||||
Returns:
|
||||
Optional[MessageGatewayEntry]: 若存在则返回消息网关条目。
|
||||
"""
|
||||
|
||||
component = self._components.get(f"{plugin_id}.{name}")
|
||||
if not isinstance(component, MessageGatewayEntry):
|
||||
return None
|
||||
if enabled_only and not self.check_component_enabled(component, session_id):
|
||||
return None
|
||||
return component
|
||||
|
||||
def get_message_gateways(
|
||||
self,
|
||||
*,
|
||||
plugin_id: Optional[str] = None,
|
||||
platform: str = "",
|
||||
route_type: str = "",
|
||||
enabled_only: bool = True,
|
||||
session_id: Optional[str] = None,
|
||||
) -> List[MessageGatewayEntry]:
|
||||
"""查询消息网关组件列表。
|
||||
|
||||
Args:
|
||||
plugin_id: 可选的插件 ID 过滤条件。
|
||||
platform: 可选的平台过滤条件。
|
||||
route_type: 可选的路由类型过滤条件。
|
||||
enabled_only: 是否仅返回启用的组件。
|
||||
session_id: 可选的会话 ID。
|
||||
|
||||
Returns:
|
||||
List[MessageGatewayEntry]: 符合条件的消息网关组件列表。
|
||||
"""
|
||||
|
||||
normalized_platform = str(platform or "").strip()
|
||||
normalized_route_type = str(route_type or "").strip().lower()
|
||||
gateways: List[MessageGatewayEntry] = []
|
||||
for comp in self._by_type.get(ComponentTypes.MESSAGE_GATEWAY, {}).values():
|
||||
if not isinstance(comp, MessageGatewayEntry):
|
||||
continue
|
||||
if plugin_id and comp.plugin_id != plugin_id:
|
||||
continue
|
||||
if enabled_only and not self.check_component_enabled(comp, session_id):
|
||||
continue
|
||||
if normalized_platform and comp.platform != normalized_platform:
|
||||
continue
|
||||
if normalized_route_type and comp.route_type != normalized_route_type:
|
||||
continue
|
||||
gateways.append(comp)
|
||||
return gateways
|
||||
|
||||
def get_tools(self, *, enabled_only: bool = True, session_id: Optional[str] = None) -> List[ToolEntry]:
|
||||
"""查询所有工具组件。
|
||||
|
||||
Args:
|
||||
enabled_only (bool): 是否仅返回启用的组件
|
||||
session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态
|
||||
Returns:
|
||||
tools (List[ToolEntry]): 符合条件的 Tool 组件列表
|
||||
"""
|
||||
tools: List[ToolEntry] = []
|
||||
for comp in self._by_type.get(ComponentTypes.TOOL, {}).values():
|
||||
if enabled_only and not self.check_component_enabled(comp, session_id):
|
||||
continue
|
||||
if isinstance(comp, ToolEntry):
|
||||
tools.append(comp)
|
||||
return tools
|
||||
|
||||
# ====== 统计信息 ======
|
||||
def get_stats(self) -> StatusDict:
|
||||
"""获取注册统计。
|
||||
|
||||
Returns:
|
||||
stats (StatusDict): 组件统计信息,包括总数、各类型数量、插件数量等
|
||||
"""
|
||||
stats: StatusDict = {"total": len(self._components)} # type: ignore
|
||||
for comp_type, type_dict in self._by_type.items():
|
||||
stats[comp_type] = len(type_dict)
|
||||
stats[comp_type.value.lower()] = len(type_dict)
|
||||
stats["plugins"] = len(self._by_plugin)
|
||||
return stats
|
||||
|
||||
@@ -4,40 +4,40 @@
|
||||
1. 按事件类型查询已注册的 event_handler(通过 ComponentRegistry)
|
||||
2. 按 weight 排序,依次通过 RPC 调用 Runner 中的处理器
|
||||
3. 支持阻塞(intercept_message)和非阻塞分发
|
||||
4. 事件结果历史记录
|
||||
4. 事件结果历史记录(有上限)
|
||||
"""
|
||||
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING
|
||||
|
||||
import asyncio
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry, RegisteredComponent
|
||||
|
||||
from .message_utils import PluginMessageUtils, MessageDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .supervisor import PluginRunnerSupervisor
|
||||
from .component_registry import ComponentRegistry, EventHandlerEntry
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
|
||||
logger = get_logger("plugin_runtime.host.event_dispatcher")
|
||||
|
||||
# invoke_fn 类型: async (plugin_id, component_name, args) -> response_payload dict
|
||||
InvokeFn = Callable[[str, str, Dict[str, Any]], Awaitable[Dict[str, Any]]]
|
||||
# 每个事件类型的最大历史记录数量,防止内存无限增长
|
||||
_MAX_HISTORY_LENGTH = 100
|
||||
|
||||
|
||||
@dataclass
|
||||
class EventResult:
|
||||
"""单个 EventHandler 的执行结果"""
|
||||
|
||||
__slots__ = ("handler_name", "success", "continue_processing", "modified_message", "custom_result")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handler_name: str,
|
||||
success: bool = True,
|
||||
continue_processing: bool = True,
|
||||
modified_message: Optional[Dict[str, Any]] = None,
|
||||
custom_result: Any = None,
|
||||
):
|
||||
self.handler_name = handler_name
|
||||
self.success = success
|
||||
self.continue_processing = continue_processing
|
||||
self.modified_message = modified_message
|
||||
self.custom_result = custom_result
|
||||
handler_name: str
|
||||
success: bool = field(default=True)
|
||||
continue_processing: bool = field(default=True)
|
||||
modified_message: Optional[MessageDict] = field(default=None)
|
||||
custom_result: Any = field(default=None)
|
||||
|
||||
|
||||
class EventDispatcher:
|
||||
@@ -48,17 +48,20 @@ class EventDispatcher:
|
||||
再通过提供的 invoke_fn 回调 RPC 到 Runner 执行。
|
||||
"""
|
||||
|
||||
def __init__(self, registry: ComponentRegistry) -> None:
|
||||
self._registry: ComponentRegistry = registry
|
||||
def __init__(self, component_registry: "ComponentRegistry") -> None:
|
||||
self._component_registry: "ComponentRegistry" = component_registry
|
||||
self._result_history: Dict[str, List[EventResult]] = {}
|
||||
self._history_enabled: Set[str] = set()
|
||||
# 保持 fire-and-forget task 的强引用,防止被 GC 回收
|
||||
self._background_tasks: Set[asyncio.Task] = set()
|
||||
|
||||
def enable_history(self, event_type: str) -> None:
|
||||
self._history_enabled.add(event_type)
|
||||
self._result_history.setdefault(event_type, [])
|
||||
|
||||
def disable_history(self, event_type: str) -> None:
|
||||
self._history_enabled.discard(event_type)
|
||||
self._result_history.pop(event_type, None)
|
||||
|
||||
def get_history(self, event_type: str) -> List[EventResult]:
|
||||
return self._result_history.get(event_type, [])
|
||||
|
||||
@@ -66,47 +69,58 @@ class EventDispatcher:
|
||||
if event_type in self._result_history:
|
||||
self._result_history[event_type] = []
|
||||
|
||||
async def stop(self):
|
||||
"""停止 EventDispatcher,取消所有未完成的后台任务"""
|
||||
for task in self._background_tasks:
|
||||
task.cancel()
|
||||
await asyncio.gather(*self._background_tasks, return_exceptions=True)
|
||||
self._background_tasks.clear()
|
||||
|
||||
async def dispatch_event(
|
||||
self,
|
||||
event_type: str,
|
||||
invoke_fn: InvokeFn,
|
||||
message: Optional[Dict[str, Any]] = None,
|
||||
supervisor: "PluginRunnerSupervisor",
|
||||
message: Optional["SessionMessage"] = None,
|
||||
extra_args: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[bool, Optional[Dict[str, Any]]]:
|
||||
"""分发事件到所有对应 handler。
|
||||
) -> Tuple[bool, Optional["SessionMessage"]]:
|
||||
"""分发事件到所有对应 handler 的便捷方法。
|
||||
|
||||
内置了通过 PluginSupervisor.invoke_plugin 调用 plugin.emit_event 的逻辑,
|
||||
无需调用方手动构造 invoke_fn 闭包。
|
||||
|
||||
Args:
|
||||
event_type: 事件类型字符串
|
||||
invoke_fn: 异步回调,签名 (plugin_id, component_name, args) -> response_payload dict
|
||||
supervisor: PluginSupervisor 实例,用于调用 invoke_plugin
|
||||
message: MaiMessages 序列化后的 dict(可选)
|
||||
extra_args: 额外参数
|
||||
|
||||
Returns:
|
||||
(should_continue, modified_message_dict)
|
||||
(should_continue, modified_message_dict) (bool, SessionMessage | None): (是否继续后续执行, 可选的修改后的消息)
|
||||
"""
|
||||
handlers = self._registry.get_event_handlers(event_type)
|
||||
if not handlers:
|
||||
handler_entries = self._component_registry.get_event_handlers(event_type)
|
||||
if not handler_entries:
|
||||
return True, None
|
||||
|
||||
should_continue = True
|
||||
modified_message: Optional[Dict[str, Any]] = None
|
||||
intercept_handlers: List[RegisteredComponent] = []
|
||||
async_handlers: List[RegisteredComponent] = []
|
||||
modified_message: Optional[MessageDict] = (
|
||||
PluginMessageUtils._session_message_to_dict(message) if message else None
|
||||
)
|
||||
intercept_handlers: List["EventHandlerEntry"] = []
|
||||
non_blocking_handlers: List["EventHandlerEntry"] = []
|
||||
|
||||
for handler in handlers:
|
||||
if handler.metadata.get("intercept_message", False):
|
||||
intercept_handlers.append(handler)
|
||||
for entry in handler_entries:
|
||||
if entry.intercept_message:
|
||||
intercept_handlers.append(entry)
|
||||
else:
|
||||
async_handlers.append(handler)
|
||||
non_blocking_handlers.append(entry)
|
||||
|
||||
for handler in intercept_handlers:
|
||||
for entry in intercept_handlers:
|
||||
args = {
|
||||
"event_type": event_type,
|
||||
"message": modified_message or message,
|
||||
"message": modified_message,
|
||||
**(extra_args or {}),
|
||||
}
|
||||
|
||||
result = await self._invoke_handler(invoke_fn, handler, args, event_type)
|
||||
result = await self._invoke_handler(supervisor, entry, args, event_type)
|
||||
if result and not result.continue_processing:
|
||||
should_continue = False
|
||||
break
|
||||
@@ -114,47 +128,57 @@ class EventDispatcher:
|
||||
modified_message = result.modified_message
|
||||
|
||||
if should_continue:
|
||||
final_message = modified_message or message
|
||||
for handler in async_handlers:
|
||||
async_message = final_message.copy() if isinstance(final_message, dict) else final_message
|
||||
final_message = modified_message
|
||||
for entry in non_blocking_handlers:
|
||||
async_message = final_message.copy() if final_message else final_message
|
||||
args = {
|
||||
"event_type": event_type,
|
||||
"message": async_message,
|
||||
**(extra_args or {}),
|
||||
}
|
||||
# 非阻塞:保持实例级强引用,防止 task 被 GC 回收
|
||||
task = asyncio.create_task(self._invoke_handler(invoke_fn, handler, args, event_type))
|
||||
task = asyncio.create_task(self._invoke_handler(supervisor, entry, args, event_type))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
return should_continue, modified_message
|
||||
try:
|
||||
modified_message_obj = (
|
||||
PluginMessageUtils._build_session_message_from_dict(modified_message) if modified_message else None # type: ignore
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"构建修改后的 SessionMessage 失败: {e}")
|
||||
modified_message_obj = None
|
||||
return should_continue, modified_message_obj
|
||||
|
||||
async def _invoke_handler(
|
||||
self,
|
||||
invoke_fn: InvokeFn,
|
||||
handler: RegisteredComponent,
|
||||
supervisor: "PluginRunnerSupervisor",
|
||||
handler_entry: "EventHandlerEntry",
|
||||
args: Dict[str, Any],
|
||||
event_type: str,
|
||||
) -> Optional[EventResult]:
|
||||
"""调用单个 handler 并收集结果。"""
|
||||
try:
|
||||
resp = await invoke_fn(handler.plugin_id, handler.name, args)
|
||||
resp_envelope = await supervisor.invoke_plugin(
|
||||
"plugin.emit_event", handler_entry.plugin_id, handler_entry.name, args
|
||||
)
|
||||
resp = resp_envelope.payload
|
||||
result = EventResult(
|
||||
handler_name=handler.full_name,
|
||||
handler_name=handler_entry.full_name,
|
||||
success=resp.get("success", True),
|
||||
continue_processing=resp.get("continue_processing", True),
|
||||
modified_message=resp.get("modified_message"),
|
||||
custom_result=resp.get("custom_result"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"EventHandler {handler.full_name} 执行失败: {e}", exc_info=True)
|
||||
result = EventResult(
|
||||
handler_name=handler.full_name,
|
||||
success=False,
|
||||
continue_processing=True,
|
||||
)
|
||||
logger.error(f"EventHandler {handler_entry.full_name} 执行失败: {e}", exc_info=True)
|
||||
result = EventResult(handler_name=handler_entry.full_name, success=False, continue_processing=True)
|
||||
|
||||
if event_type in self._history_enabled:
|
||||
self._result_history.setdefault(event_type, []).append(result)
|
||||
history_list = self._result_history.setdefault(event_type, [])
|
||||
history_list.append(result)
|
||||
# 自动清理超出限制的旧记录,防止内存无限增长
|
||||
if len(history_list) > _MAX_HISTORY_LENGTH:
|
||||
# 保留最新的 _MAX_HISTORY_LENGTH 条记录
|
||||
self._result_history[event_type] = history_list[-_MAX_HISTORY_LENGTH:]
|
||||
|
||||
return result
|
||||
|
||||
166
src/plugin_runtime/host/hook_dispatcher.py
Normal file
166
src/plugin_runtime/host/hook_dispatcher.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
Hook Dispatch 系统
|
||||
|
||||
插件可以注册自己的Hook,当特定函数被调用时,Hook Dispatch系统会将调用转发给插件的Hook处理函数。
|
||||
每个Hook的参数随Hook点位确定,因此参数是易变的。插件开发者需要根据Hook点位的定义来编写Hook处理函数。
|
||||
在参数/返回值匹配的情况下允许修改参数/返回值。
|
||||
|
||||
HookDispatcher 负责:
|
||||
1. 按 stage 查询已注册的 hook_handler(通过 ComponentRegistry)
|
||||
2. 按 priority 排序,区分 blocking 和非 blocking 模式
|
||||
3. blocking 模式:依次同步调用,支持修改参数/提前终止
|
||||
4. 非 blocking 模式:异步调用,不阻塞主流程
|
||||
5. 支持通过 global_config.plugin_runtime.hook_blocking_timeout_sec 设置超时上限
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .supervisor import PluginRunnerSupervisor
|
||||
from .component_registry import ComponentRegistry, HookHandlerEntry
|
||||
|
||||
logger = get_logger("plugin_runtime.host.hook_dispatcher")
|
||||
|
||||
|
||||
@dataclass
|
||||
class HookResult:
|
||||
"""单个 HookHandler 的执行结果"""
|
||||
|
||||
handler_name: str
|
||||
success: bool = field(default=True)
|
||||
continue_processing: bool = field(default=True)
|
||||
modified_kwargs: Optional[Dict[str, Any]] = field(default=None)
|
||||
custom_result: Any = field(default=None)
|
||||
|
||||
|
||||
class HookDispatcher:
|
||||
"""Host-side Hook 分发器
|
||||
|
||||
由业务层调用 hook_dispatch(),
|
||||
内部通过 ComponentRegistry 查询 handler,
|
||||
再通过提供的 invoke_fn 回调 RPC 到 Runner 执行。
|
||||
"""
|
||||
|
||||
def __init__(self, component_registry: "ComponentRegistry") -> None:
|
||||
"""初始化 HookDispatcher
|
||||
|
||||
Args:
|
||||
component_registry: ComponentRegistry 实例,用于查询已注册的 hook_handler
|
||||
"""
|
||||
self._component_registry: "ComponentRegistry" = component_registry
|
||||
self._background_tasks: Set[asyncio.Task] = set()
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""停止 HookDispatcher,取消所有未完成的后台任务"""
|
||||
for task in self._background_tasks:
|
||||
task.cancel()
|
||||
await asyncio.gather(*self._background_tasks, return_exceptions=True)
|
||||
self._background_tasks.clear()
|
||||
|
||||
async def hook_dispatch(
|
||||
self,
|
||||
stage: str,
|
||||
supervisor: "PluginRunnerSupervisor",
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
"""分发 hook 到所有对应 handler 的便捷方法。
|
||||
|
||||
内置了通过 PluginRunnerSupervisor.invoke_plugin 调用 plugin 的逻辑,
|
||||
无需调用方手动构造 invoke_fn 闭包。
|
||||
|
||||
Args:
|
||||
stage: hook 名称
|
||||
supervisor: PluginRunnerSupervisor 实例,用于调用 invoke_plugin
|
||||
**kwargs: 关键字参数,会展开传递给 handler
|
||||
|
||||
Returns:
|
||||
modified_kwargs (Dict[str, Any]): 经过所有 handler 修改后的关键字参数
|
||||
"""
|
||||
handler_entries = self._component_registry.get_hook_handlers(stage)
|
||||
if not handler_entries:
|
||||
return kwargs
|
||||
|
||||
current_kwargs = kwargs.copy()
|
||||
blocking_handlers: List["HookHandlerEntry"] = []
|
||||
non_blocking_handlers: List["HookHandlerEntry"] = []
|
||||
|
||||
# 分离 blocking 和非 blocking handler
|
||||
for entry in handler_entries:
|
||||
if entry.blocking:
|
||||
blocking_handlers.append(entry)
|
||||
else:
|
||||
non_blocking_handlers.append(entry)
|
||||
|
||||
# 处理 blocking handlers(同步调用,支持修改参数/提前终止)
|
||||
timeout = global_config.plugin_runtime.hook_blocking_timeout_sec or 30.0
|
||||
for entry in blocking_handlers:
|
||||
hook_args = {"stage": stage, **current_kwargs}
|
||||
try:
|
||||
# 应用超时控制
|
||||
result = await asyncio.wait_for(
|
||||
self._invoke_handler(supervisor, entry, hook_args),
|
||||
timeout=timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"Blocking HookHandler {entry.full_name} 执行超时 (>{timeout}秒),跳过")
|
||||
result = HookResult(handler_name=entry.full_name, success=False, continue_processing=True)
|
||||
|
||||
if result:
|
||||
if result.modified_kwargs is not None:
|
||||
current_kwargs = result.modified_kwargs
|
||||
if not result.continue_processing:
|
||||
logger.info(f"HookHandler {entry.full_name} 终止了后续处理")
|
||||
break
|
||||
|
||||
# 处理 non-blocking handlers(异步调用,不阻塞主流程)
|
||||
for entry in non_blocking_handlers:
|
||||
async_kwargs = current_kwargs.copy()
|
||||
hook_args = {"stage": stage, **async_kwargs}
|
||||
task = asyncio.create_task(
|
||||
asyncio.wait_for(self._invoke_handler(supervisor, entry, hook_args), timeout=timeout)
|
||||
)
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
return current_kwargs
|
||||
|
||||
async def _invoke_handler(
|
||||
self,
|
||||
supervisor: "PluginRunnerSupervisor",
|
||||
handler_entry: "HookHandlerEntry",
|
||||
args: Dict[str, Any],
|
||||
) -> Optional[HookResult]:
|
||||
"""调用单个 handler 并收集结果。
|
||||
|
||||
Args:
|
||||
supervisor: PluginRunnerSupervisor 实例
|
||||
handler_entry: HookHandlerEntry 实例
|
||||
args: 传递给 handler 的参数字典
|
||||
stage: hook 名称
|
||||
|
||||
Returns:
|
||||
Optional[HookResult]: 执行结果,如果执行失败则返回 None
|
||||
"""
|
||||
try:
|
||||
resp_envelope = await supervisor.invoke_plugin(
|
||||
"plugin.invoke_hook", handler_entry.plugin_id, handler_entry.name, args
|
||||
)
|
||||
resp = resp_envelope.payload
|
||||
result = HookResult(
|
||||
handler_name=handler_entry.full_name,
|
||||
success=resp.get("success", True),
|
||||
continue_processing=resp.get("continue_processing", True),
|
||||
modified_kwargs=resp.get("modified_kwargs"),
|
||||
custom_result=resp.get("custom_result"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"HookHandler {handler_entry.full_name} 执行失败:{e}", exc_info=True)
|
||||
result = HookResult(handler_name=handler_entry.full_name, success=False, continue_processing=True)
|
||||
|
||||
return result
|
||||
45
src/plugin_runtime/host/logger_bridge.py
Normal file
45
src/plugin_runtime/host/logger_bridge.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import logging as stdlib_logging
|
||||
from src.plugin_runtime.protocol.errors import ErrorCode
|
||||
from src.plugin_runtime.protocol.envelope import Envelope, LogBatchPayload
|
||||
class RunnerLogBridge:
|
||||
"""将 Runner 进程上报的批量日志重放到主进程的 Logger 中。
|
||||
|
||||
Runner 通过 ``runner.log_batch`` IPC 事件批量到达。
|
||||
每条 LogEntry 被重建为一个真实的 :class:`logging.LogRecord` 并直接
|
||||
调用 ``logging.getLogger(entry.logger_name).handle(record)``,
|
||||
从而接入主进程已配置好的 structlog Handler 链。
|
||||
"""
|
||||
|
||||
async def handle_log_batch(self, envelope: Envelope) -> Envelope:
|
||||
"""IPC 事件处理器:解析批量日志并重放到主进程 Logger。
|
||||
|
||||
Args:
|
||||
envelope: 方法名为 ``runner.log_batch`` 的 IPC 事件信封。
|
||||
|
||||
Returns:
|
||||
空响应信封(事件模式下将被忽略)。
|
||||
"""
|
||||
try:
|
||||
batch = LogBatchPayload.model_validate(envelope.payload)
|
||||
except Exception as exc:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
|
||||
|
||||
for entry in batch.entries:
|
||||
# 重建一个与原始日志尽量相符的 LogRecord
|
||||
record = stdlib_logging.LogRecord(
|
||||
name=entry.logger_name,
|
||||
level=entry.level,
|
||||
pathname="<runner>",
|
||||
lineno=0,
|
||||
msg=entry.message,
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
record.created = entry.timestamp_ms / 1000.0
|
||||
record.msecs = entry.timestamp_ms % 1000
|
||||
if entry.exception_text:
|
||||
record.exc_text = entry.exception_text
|
||||
|
||||
stdlib_logging.getLogger(entry.logger_name).handle(record)
|
||||
|
||||
return envelope.make_response(payload={"accepted": True, "count": len(batch.entries)})
|
||||
112
src/plugin_runtime/host/message_gateway.py
Normal file
112
src/plugin_runtime/host/message_gateway.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""Host 侧消息网关包装器。"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.platform_io import get_platform_io_manager
|
||||
|
||||
from .message_utils import PluginMessageUtils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from .component_registry import ComponentRegistry
|
||||
from .supervisor import PluginRunnerSupervisor
|
||||
|
||||
logger = get_logger("plugin_runtime.host.message_gateway")
|
||||
|
||||
|
||||
class MessageGateway:
|
||||
"""Host 侧消息网关包装器。"""
|
||||
|
||||
def __init__(self, component_registry: "ComponentRegistry") -> None:
|
||||
"""初始化消息网关。
|
||||
|
||||
Args:
|
||||
component_registry: 组件注册表。
|
||||
"""
|
||||
self._component_registry = component_registry
|
||||
|
||||
def build_session_message(self, external_message: Dict[str, Any]) -> "SessionMessage":
|
||||
"""将标准消息字典转换为 ``SessionMessage``。
|
||||
|
||||
Args:
|
||||
external_message: 外部消息的字典格式数据。
|
||||
|
||||
Returns:
|
||||
SessionMessage: 转换后的内部消息对象。
|
||||
|
||||
Raises:
|
||||
ValueError: 消息字典不合法时抛出。
|
||||
"""
|
||||
return PluginMessageUtils._build_session_message_from_dict(external_message)
|
||||
|
||||
def build_message_dict(self, internal_message: "SessionMessage") -> Dict[str, Any]:
|
||||
"""将 ``SessionMessage`` 转换为标准消息字典。
|
||||
|
||||
Args:
|
||||
internal_message: 内部消息对象。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 供消息网关插件消费的标准消息字典。
|
||||
"""
|
||||
return dict(PluginMessageUtils._session_message_to_dict(internal_message))
|
||||
|
||||
async def receive_external_message(self, external_message: Dict[str, Any]) -> None:
|
||||
"""接收外部消息并送入主消息链。
|
||||
|
||||
Args:
|
||||
external_message: 外部消息的字典格式数据。
|
||||
"""
|
||||
try:
|
||||
session_message = self.build_session_message(external_message)
|
||||
except Exception as e:
|
||||
logger.error(f"转换外部消息失败: {e}")
|
||||
return
|
||||
|
||||
from src.chat.message_receive.bot import chat_bot
|
||||
|
||||
await chat_bot.receive_message(session_message)
|
||||
|
||||
async def send_message_to_external(
|
||||
self,
|
||||
internal_message: "SessionMessage",
|
||||
supervisor: "PluginRunnerSupervisor",
|
||||
*,
|
||||
enabled_only: bool = True,
|
||||
save_to_db: bool = True,
|
||||
) -> bool:
|
||||
"""将内部消息通过 Platform IO 发送到外部平台。
|
||||
|
||||
Args:
|
||||
internal_message: 系统内部的 ``SessionMessage`` 对象。
|
||||
supervisor: 当前持有该消息网关的 Supervisor。
|
||||
enabled_only: 兼容旧签名的保留参数,当前未使用。
|
||||
save_to_db: 发送成功后是否写入数据库。
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功。
|
||||
"""
|
||||
del enabled_only
|
||||
del supervisor
|
||||
|
||||
platform_io_manager = get_platform_io_manager()
|
||||
if not platform_io_manager.is_started:
|
||||
logger.warning("Platform IO 尚未启动,无法通过适配器链路发送消息")
|
||||
return False
|
||||
|
||||
route_key = platform_io_manager.build_route_key_from_message(internal_message)
|
||||
delivery_batch = await platform_io_manager.send_message(internal_message, route_key)
|
||||
if not delivery_batch.has_success:
|
||||
logger.warning("通过消息网关链路发送消息失败: 未命中任何成功回执")
|
||||
return False
|
||||
|
||||
first_successful_receipt = delivery_batch.sent_receipts[0]
|
||||
internal_message.message_id = first_successful_receipt.external_message_id or internal_message.message_id
|
||||
if save_to_db:
|
||||
try:
|
||||
from src.common.utils.utils_message import MessageUtils
|
||||
|
||||
MessageUtils.store_message_to_db(internal_message)
|
||||
except Exception as e:
|
||||
logger.error(f"保存消息到数据库失败: {e}")
|
||||
return True
|
||||
487
src/plugin_runtime/host/message_utils.py
Normal file
487
src/plugin_runtime/host/message_utils.py
Normal file
@@ -0,0 +1,487 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, TypedDict
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.data_models.mai_message_data_model import UserInfo, GroupInfo, MessageInfo
|
||||
from src.common.data_models.message_component_data_model import (
|
||||
AtComponent,
|
||||
DictComponent,
|
||||
EmojiComponent,
|
||||
ForwardComponent,
|
||||
ForwardNodeComponent,
|
||||
ImageComponent,
|
||||
MessageSequence,
|
||||
ReplyComponent,
|
||||
StandardMessageComponents,
|
||||
TextComponent,
|
||||
VoiceComponent,
|
||||
)
|
||||
|
||||
logger = get_logger("plugin_runtime.host.message_utils")
|
||||
|
||||
|
||||
class UserInfoDict(TypedDict, total=False):
|
||||
user_id: str
|
||||
user_nickname: str
|
||||
user_cardname: Optional[str]
|
||||
|
||||
|
||||
class GroupInfoDict(TypedDict, total=False):
|
||||
group_id: str
|
||||
group_name: str
|
||||
|
||||
|
||||
class MessageInfoDict(TypedDict, total=False):
|
||||
user_info: UserInfoDict
|
||||
group_info: Optional[GroupInfoDict]
|
||||
additional_config: Dict[str, Any]
|
||||
|
||||
|
||||
class MessageDict(TypedDict, total=False):
|
||||
message_id: str
|
||||
timestamp: str
|
||||
platform: str
|
||||
message_info: MessageInfoDict
|
||||
raw_message: List[Dict[str, Any]]
|
||||
is_mentioned: bool
|
||||
is_at: bool
|
||||
is_emoji: bool
|
||||
is_picture: bool
|
||||
is_command: bool
|
||||
is_notify: bool
|
||||
session_id: str
|
||||
reply_to: Optional[str]
|
||||
processed_plain_text: Optional[str]
|
||||
display_message: Optional[str]
|
||||
|
||||
|
||||
class PluginMessageUtils:
|
||||
@staticmethod
|
||||
def _message_sequence_to_dict(message_sequence: MessageSequence) -> List[Dict[str, Any]]:
|
||||
"""将消息组件序列转换为插件运行时使用的字典结构。
|
||||
|
||||
Args:
|
||||
message_sequence: 待转换的消息组件序列。
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 供插件运行时协议使用的消息段字典列表。
|
||||
"""
|
||||
return [PluginMessageUtils._component_to_dict(component) for component in message_sequence.components]
|
||||
|
||||
@staticmethod
|
||||
def _component_to_dict(component: StandardMessageComponents) -> Dict[str, Any]:
|
||||
"""将单个消息组件转换为插件运行时字典结构。
|
||||
|
||||
Args:
|
||||
component: 待转换的消息组件。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 序列化后的消息组件字典。
|
||||
"""
|
||||
if isinstance(component, TextComponent):
|
||||
return {"type": "text", "data": component.text}
|
||||
|
||||
if isinstance(component, ImageComponent):
|
||||
serialized = {
|
||||
"type": "image",
|
||||
"data": component.content,
|
||||
"hash": component.binary_hash,
|
||||
}
|
||||
if component.binary_data:
|
||||
serialized["binary_data_base64"] = base64.b64encode(component.binary_data).decode("utf-8")
|
||||
return serialized
|
||||
|
||||
if isinstance(component, EmojiComponent):
|
||||
serialized = {
|
||||
"type": "emoji",
|
||||
"data": component.content,
|
||||
"hash": component.binary_hash,
|
||||
}
|
||||
if component.binary_data:
|
||||
serialized["binary_data_base64"] = base64.b64encode(component.binary_data).decode("utf-8")
|
||||
return serialized
|
||||
|
||||
if isinstance(component, VoiceComponent):
|
||||
serialized = {
|
||||
"type": "voice",
|
||||
"data": component.content,
|
||||
"hash": component.binary_hash,
|
||||
}
|
||||
if component.binary_data:
|
||||
serialized["binary_data_base64"] = base64.b64encode(component.binary_data).decode("utf-8")
|
||||
return serialized
|
||||
|
||||
if isinstance(component, AtComponent):
|
||||
return {
|
||||
"type": "at",
|
||||
"data": {
|
||||
"target_user_id": component.target_user_id,
|
||||
"target_user_nickname": component.target_user_nickname,
|
||||
"target_user_cardname": component.target_user_cardname,
|
||||
},
|
||||
}
|
||||
|
||||
if isinstance(component, ReplyComponent):
|
||||
return {
|
||||
"type": "reply",
|
||||
"data": {
|
||||
"target_message_id": component.target_message_id,
|
||||
"target_message_content": component.target_message_content,
|
||||
"target_message_sender_id": component.target_message_sender_id,
|
||||
"target_message_sender_nickname": component.target_message_sender_nickname,
|
||||
"target_message_sender_cardname": component.target_message_sender_cardname,
|
||||
},
|
||||
}
|
||||
|
||||
if isinstance(component, ForwardNodeComponent):
|
||||
return {
|
||||
"type": "forward",
|
||||
"data": [PluginMessageUtils._forward_component_to_dict(item) for item in component.forward_components],
|
||||
}
|
||||
|
||||
return {"type": "dict", "data": component.data}
|
||||
|
||||
@staticmethod
|
||||
def _forward_component_to_dict(component: ForwardComponent) -> Dict[str, Any]:
|
||||
"""将单个转发节点组件转换为字典结构。
|
||||
|
||||
Args:
|
||||
component: 待转换的转发节点组件。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 序列化后的转发节点字典。
|
||||
"""
|
||||
return {
|
||||
"user_id": component.user_id,
|
||||
"user_nickname": component.user_nickname,
|
||||
"user_cardname": component.user_cardname,
|
||||
"message_id": component.message_id,
|
||||
"content": [PluginMessageUtils._component_to_dict(item) for item in component.content],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _message_sequence_from_dict(raw_message_data: List[Dict[str, Any]]) -> MessageSequence:
|
||||
"""从插件运行时字典结构恢复消息组件序列。
|
||||
|
||||
Args:
|
||||
raw_message_data: 插件运行时消息段字典列表。
|
||||
|
||||
Returns:
|
||||
MessageSequence: 恢复后的消息组件序列。
|
||||
"""
|
||||
components = [PluginMessageUtils._component_from_dict(item) for item in raw_message_data]
|
||||
return MessageSequence(components=components)
|
||||
|
||||
@staticmethod
|
||||
def _component_from_dict(item: Dict[str, Any]) -> StandardMessageComponents:
|
||||
"""从插件运行时字典结构恢复单个消息组件。
|
||||
|
||||
Args:
|
||||
item: 单个消息组件的字典表示。
|
||||
|
||||
Returns:
|
||||
StandardMessageComponents: 恢复后的内部消息组件对象。
|
||||
"""
|
||||
item_type = str(item.get("type") or "").strip()
|
||||
if item_type == "text":
|
||||
return TextComponent(text=str(item.get("data") or ""))
|
||||
|
||||
if item_type == "image":
|
||||
return PluginMessageUtils._build_binary_component(ImageComponent, item)
|
||||
|
||||
if item_type == "emoji":
|
||||
return PluginMessageUtils._build_binary_component(EmojiComponent, item)
|
||||
|
||||
if item_type == "voice":
|
||||
return PluginMessageUtils._build_binary_component(VoiceComponent, item)
|
||||
|
||||
if item_type == "at":
|
||||
item_data = item.get("data", {})
|
||||
if not isinstance(item_data, dict):
|
||||
item_data = {}
|
||||
return AtComponent(
|
||||
target_user_id=str(item_data.get("target_user_id") or ""),
|
||||
target_user_nickname=PluginMessageUtils._normalize_optional_string(item_data.get("target_user_nickname")),
|
||||
target_user_cardname=PluginMessageUtils._normalize_optional_string(item_data.get("target_user_cardname")),
|
||||
)
|
||||
|
||||
if item_type == "reply":
|
||||
reply_data = item.get("data")
|
||||
if isinstance(reply_data, dict):
|
||||
return ReplyComponent(
|
||||
target_message_id=str(reply_data.get("target_message_id") or ""),
|
||||
target_message_content=PluginMessageUtils._normalize_optional_string(
|
||||
reply_data.get("target_message_content")
|
||||
),
|
||||
target_message_sender_id=PluginMessageUtils._normalize_optional_string(
|
||||
reply_data.get("target_message_sender_id")
|
||||
),
|
||||
target_message_sender_nickname=PluginMessageUtils._normalize_optional_string(
|
||||
reply_data.get("target_message_sender_nickname")
|
||||
),
|
||||
target_message_sender_cardname=PluginMessageUtils._normalize_optional_string(
|
||||
reply_data.get("target_message_sender_cardname")
|
||||
),
|
||||
)
|
||||
return ReplyComponent(target_message_id=str(reply_data or ""))
|
||||
|
||||
if item_type == "forward":
|
||||
forward_nodes: List[ForwardComponent] = []
|
||||
raw_forward_nodes = item.get("data", [])
|
||||
if isinstance(raw_forward_nodes, list):
|
||||
for node in raw_forward_nodes:
|
||||
if not isinstance(node, dict):
|
||||
continue
|
||||
raw_content = node.get("content", [])
|
||||
node_components: List[StandardMessageComponents] = []
|
||||
if isinstance(raw_content, list):
|
||||
node_components = [
|
||||
PluginMessageUtils._component_from_dict(content)
|
||||
for content in raw_content
|
||||
if isinstance(content, dict)
|
||||
]
|
||||
if not node_components:
|
||||
node_components = [TextComponent(text="[empty forward node]")]
|
||||
forward_nodes.append(
|
||||
ForwardComponent(
|
||||
user_nickname=str(node.get("user_nickname") or "未知用户"),
|
||||
user_id=PluginMessageUtils._normalize_optional_string(node.get("user_id")),
|
||||
user_cardname=PluginMessageUtils._normalize_optional_string(node.get("user_cardname")),
|
||||
message_id=str(node.get("message_id") or ""),
|
||||
content=node_components,
|
||||
)
|
||||
)
|
||||
if not forward_nodes:
|
||||
return DictComponent(data={"type": "forward", "data": item.get("data", [])})
|
||||
return ForwardNodeComponent(forward_components=forward_nodes)
|
||||
|
||||
component_data = item.get("data")
|
||||
if isinstance(component_data, dict):
|
||||
return DictComponent(data=component_data)
|
||||
return DictComponent(data=item)
|
||||
|
||||
@staticmethod
|
||||
def _build_binary_component(component_cls: Any, item: Dict[str, Any]) -> StandardMessageComponents:
|
||||
"""从字典构造带二进制负载的消息组件。
|
||||
|
||||
Args:
|
||||
component_cls: 目标组件类型。
|
||||
item: 消息组件字典。
|
||||
|
||||
Returns:
|
||||
StandardMessageComponents: 构造后的组件对象。
|
||||
"""
|
||||
content = str(item.get("data") or "")
|
||||
binary_hash = str(item.get("hash") or "")
|
||||
raw_binary_base64 = item.get("binary_data_base64")
|
||||
binary_data = b""
|
||||
if isinstance(raw_binary_base64, str) and raw_binary_base64:
|
||||
try:
|
||||
binary_data = base64.b64decode(raw_binary_base64)
|
||||
except Exception:
|
||||
binary_data = b""
|
||||
|
||||
if not binary_hash and binary_data:
|
||||
binary_hash = hashlib.sha256(binary_data).hexdigest()
|
||||
|
||||
return component_cls(binary_hash=binary_hash, content=content, binary_data=binary_data)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_optional_string(value: Any) -> Optional[str]:
|
||||
"""将任意值规范化为可选字符串。
|
||||
|
||||
Args:
|
||||
value: 待规范化的值。
|
||||
|
||||
Returns:
|
||||
Optional[str]: 规范化后的字符串;若值为空则返回 ``None``。
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
normalized_value = str(value)
|
||||
return normalized_value if normalized_value else None
|
||||
|
||||
@staticmethod
|
||||
def _message_info_to_dict(message_info: MessageInfo) -> MessageInfoDict:
|
||||
"""
|
||||
将 MessageInfo 对象转换为字典格式
|
||||
|
||||
Args:
|
||||
message_info: MessageInfo 对象
|
||||
|
||||
Returns:
|
||||
字典格式的消息信息
|
||||
"""
|
||||
user_info_dict = UserInfoDict(
|
||||
user_id=message_info.user_info.user_id,
|
||||
user_nickname=message_info.user_info.user_nickname,
|
||||
user_cardname=message_info.user_info.user_cardname,
|
||||
)
|
||||
|
||||
group_info_dict: Optional[GroupInfoDict] = None
|
||||
if message_info.group_info:
|
||||
group_info_dict = GroupInfoDict(
|
||||
group_id=message_info.group_info.group_id,
|
||||
group_name=message_info.group_info.group_name,
|
||||
)
|
||||
|
||||
return MessageInfoDict(
|
||||
user_info=user_info_dict,
|
||||
group_info=group_info_dict,
|
||||
additional_config=message_info.additional_config,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _session_message_to_dict(session_message: SessionMessage) -> MessageDict:
|
||||
"""
|
||||
将 SessionMessage 对象转换为字典格式(复用 MessageSequence.to_dict 方法)
|
||||
|
||||
Args:
|
||||
session_message: SessionMessage 对象
|
||||
|
||||
Returns:
|
||||
字典格式的消息
|
||||
"""
|
||||
# 转换基本信息
|
||||
message_dict = MessageDict(
|
||||
message_id=session_message.message_id,
|
||||
timestamp=str(session_message.timestamp.timestamp()), # 转换为时间戳字符串
|
||||
platform=session_message.platform,
|
||||
message_info=PluginMessageUtils._message_info_to_dict(session_message.message_info),
|
||||
raw_message=PluginMessageUtils._message_sequence_to_dict(session_message.raw_message),
|
||||
is_mentioned=session_message.is_mentioned,
|
||||
is_at=session_message.is_at,
|
||||
is_emoji=session_message.is_emoji,
|
||||
is_picture=session_message.is_picture,
|
||||
is_command=session_message.is_command,
|
||||
is_notify=session_message.is_notify,
|
||||
session_id=session_message.session_id,
|
||||
)
|
||||
|
||||
# 添加可选字段
|
||||
if session_message.reply_to is not None:
|
||||
message_dict["reply_to"] = session_message.reply_to
|
||||
if session_message.processed_plain_text is not None:
|
||||
message_dict["processed_plain_text"] = session_message.processed_plain_text
|
||||
if session_message.display_message is not None:
|
||||
message_dict["display_message"] = session_message.display_message
|
||||
|
||||
return message_dict
|
||||
|
||||
@staticmethod
|
||||
def _build_message_info_from_dict(message_info_dict: Dict[str, Any]) -> MessageInfo:
|
||||
"""
|
||||
从字典构建 MessageInfo 对象
|
||||
|
||||
Args:
|
||||
message_info_dict: 包含消息信息的字典
|
||||
|
||||
Returns:
|
||||
MessageInfo 对象
|
||||
"""
|
||||
# 构建用户信息
|
||||
user_info_dict = message_info_dict.get("user_info")
|
||||
if not user_info_dict or not isinstance(user_info_dict, dict):
|
||||
raise ValueError("消息字典中 'user_info' 字段无效")
|
||||
user_id = user_info_dict.get("user_id")
|
||||
user_nickname = user_info_dict.get("user_nickname")
|
||||
user_cardname = user_info_dict.get("user_cardname")
|
||||
if not isinstance(user_id, str) or not isinstance(user_nickname, str) or not user_id or not user_nickname:
|
||||
raise ValueError("消息字典中 'user_info' 字段缺少有效的 'user_id' 或 'user_nickname'")
|
||||
user_cardname = str(user_cardname) if user_cardname is not None else None
|
||||
user_info = UserInfo(user_id=user_id, user_nickname=user_nickname, user_cardname=user_cardname)
|
||||
|
||||
# 构建群信息
|
||||
if group_info_dict := message_info_dict.get("group_info"):
|
||||
group_id = group_info_dict.get("group_id")
|
||||
group_name = group_info_dict.get("group_name")
|
||||
if not isinstance(group_id, str) or not isinstance(group_name, str) or not group_id or not group_name:
|
||||
raise ValueError("消息字典中 'group_info' 字段缺少有效的 'group_id' 或 'group_name'")
|
||||
group_info = GroupInfo(group_id=group_id, group_name=group_name)
|
||||
else:
|
||||
group_info = None
|
||||
|
||||
# 获取额外配置
|
||||
additional_config: Dict[str, Any] = message_info_dict.get("additional_config", {})
|
||||
|
||||
return MessageInfo(user_info=user_info, group_info=group_info, additional_config=additional_config)
|
||||
|
||||
@staticmethod
|
||||
def _build_session_message_from_dict(message_dict: Dict[str, Any]) -> SessionMessage:
|
||||
"""
|
||||
从字典构建 SessionMessage 对象(递归处理消息组件)
|
||||
|
||||
Args:
|
||||
message_dict: 包含消息完整信息的字典
|
||||
|
||||
Returns:
|
||||
SessionMessage 对象
|
||||
"""
|
||||
# 提取基本信息
|
||||
message_id = message_dict["message_id"]
|
||||
timestamp_str: str = message_dict.get("timestamp", "")
|
||||
platform = message_dict["platform"]
|
||||
if not isinstance(message_id, str) or not message_id:
|
||||
raise ValueError("消息字典中缺少有效的 'message_id' 字段")
|
||||
if not isinstance(platform, str) or not platform:
|
||||
raise ValueError("消息字典中缺少有效的 'platform' 字段")
|
||||
|
||||
# 解析时间戳
|
||||
try:
|
||||
timestamp_float = float(timestamp_str)
|
||||
timestamp = datetime.fromtimestamp(timestamp_float)
|
||||
except (ValueError, TypeError):
|
||||
timestamp = datetime.now() # 如果解析失败,使用当前时间
|
||||
|
||||
# 创建 SessionMessage 实例
|
||||
session_message = SessionMessage(message_id=message_id, timestamp=timestamp, platform=platform)
|
||||
|
||||
# 构建消息信息
|
||||
session_message.message_info = PluginMessageUtils._build_message_info_from_dict(message_dict["message_info"])
|
||||
|
||||
# 构建原始消息组件序列(复用 MessageSequence.from_dict 方法)
|
||||
raw_message_data = message_dict["raw_message"]
|
||||
if isinstance(raw_message_data, list):
|
||||
session_message.raw_message = PluginMessageUtils._message_sequence_from_dict(raw_message_data)
|
||||
else:
|
||||
raise ValueError("消息字典中 'raw_message' 字段必须是一个列表")
|
||||
|
||||
# 设置其他可选属性
|
||||
session_message.is_mentioned = message_dict.get("is_mentioned", False)
|
||||
if not isinstance(session_message.is_mentioned, bool):
|
||||
session_message.is_mentioned = False
|
||||
session_message.is_at = message_dict.get("is_at", False)
|
||||
if not isinstance(session_message.is_at, bool):
|
||||
session_message.is_at = False
|
||||
session_message.is_emoji = message_dict.get("is_emoji", False)
|
||||
if not isinstance(session_message.is_emoji, bool):
|
||||
session_message.is_emoji = False
|
||||
session_message.is_picture = message_dict.get("is_picture", False)
|
||||
if not isinstance(session_message.is_picture, bool):
|
||||
session_message.is_picture = False
|
||||
session_message.is_command = message_dict.get("is_command", False)
|
||||
if not isinstance(session_message.is_command, bool):
|
||||
session_message.is_command = False
|
||||
session_message.is_notify = message_dict.get("is_notify", False)
|
||||
if not isinstance(session_message.is_notify, bool):
|
||||
session_message.is_notify = False
|
||||
session_message.session_id = message_dict.get("session_id", "")
|
||||
if not isinstance(session_message.session_id, str):
|
||||
session_message.session_id = ""
|
||||
session_message.reply_to = message_dict.get("reply_to")
|
||||
if session_message.reply_to is not None and not isinstance(session_message.reply_to, str):
|
||||
session_message.reply_to = None
|
||||
session_message.processed_plain_text = message_dict.get("processed_plain_text")
|
||||
if session_message.processed_plain_text is not None and not isinstance(
|
||||
session_message.processed_plain_text, str
|
||||
):
|
||||
session_message.processed_plain_text = None
|
||||
session_message.display_message = message_dict.get("display_message")
|
||||
if session_message.display_message is not None and not isinstance(session_message.display_message, str):
|
||||
session_message.display_message = None
|
||||
|
||||
return session_message
|
||||
@@ -1,97 +0,0 @@
|
||||
"""策略引擎
|
||||
|
||||
负责能力授权校验。
|
||||
每个插件在 manifest 中声明能力需求,Host 启动时签发能力令牌。
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
|
||||
@dataclass
|
||||
class CapabilityToken:
|
||||
"""能力令牌"""
|
||||
|
||||
plugin_id: str
|
||||
generation: int
|
||||
capabilities: Set[str] = field(default_factory=set)
|
||||
|
||||
|
||||
class PolicyEngine:
|
||||
"""策略引擎
|
||||
|
||||
管理所有插件的能力令牌,提供授权校验。
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._tokens: Dict[str, Dict[int, CapabilityToken]] = {}
|
||||
|
||||
def register_plugin(
|
||||
self,
|
||||
plugin_id: str,
|
||||
generation: int,
|
||||
capabilities: List[str],
|
||||
) -> CapabilityToken:
|
||||
"""为插件签发能力令牌"""
|
||||
token = CapabilityToken(
|
||||
plugin_id=plugin_id,
|
||||
generation=generation,
|
||||
capabilities=set(capabilities),
|
||||
)
|
||||
self._tokens.setdefault(plugin_id, {})[generation] = token
|
||||
return token
|
||||
|
||||
def revoke_plugin(self, plugin_id: str, generation: Optional[int] = None) -> None:
|
||||
"""撤销插件的能力令牌。"""
|
||||
if generation is None:
|
||||
self._tokens.pop(plugin_id, None)
|
||||
return
|
||||
|
||||
generations = self._tokens.get(plugin_id)
|
||||
if generations is None:
|
||||
return
|
||||
|
||||
generations.pop(generation, None)
|
||||
if not generations:
|
||||
self._tokens.pop(plugin_id, None)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清空所有能力令牌。"""
|
||||
self._tokens.clear()
|
||||
|
||||
def check_capability(self, plugin_id: str, capability: str, generation: Optional[int] = None) -> Tuple[bool, str]:
|
||||
"""检查插件是否有权调用某项能力
|
||||
|
||||
Returns:
|
||||
(allowed, reason)
|
||||
"""
|
||||
generations = self._tokens.get(plugin_id)
|
||||
if not generations:
|
||||
return False, f"插件 {plugin_id} 未注册能力令牌"
|
||||
|
||||
if generation is None:
|
||||
token = generations[max(generations)]
|
||||
else:
|
||||
token = generations.get(generation)
|
||||
if token is None:
|
||||
active_generation = max(generations)
|
||||
return False, f"插件 {plugin_id} generation 不匹配: {generation} != {active_generation}"
|
||||
|
||||
if capability not in token.capabilities:
|
||||
return False, f"插件 {plugin_id} 未获授权能力: {capability}"
|
||||
|
||||
if generation is not None and token.generation != generation:
|
||||
return False, f"插件 {plugin_id} generation 不匹配: {generation} != {token.generation}"
|
||||
|
||||
return True, ""
|
||||
|
||||
def get_token(self, plugin_id: str) -> Optional[CapabilityToken]:
|
||||
"""获取插件的能力令牌"""
|
||||
generations = self._tokens.get(plugin_id)
|
||||
if not generations:
|
||||
return None
|
||||
return generations[max(generations)]
|
||||
|
||||
def list_plugins(self) -> List[str]:
|
||||
"""列出所有已注册的插件"""
|
||||
return list(self._tokens.keys())
|
||||
@@ -7,7 +7,7 @@
|
||||
4. 请求-响应关联与超时管理
|
||||
"""
|
||||
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Coroutine
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
@@ -32,7 +32,7 @@ from src.plugin_runtime.transport.base import Connection, TransportServer
|
||||
logger = get_logger("plugin_runtime.host.rpc_server")
|
||||
|
||||
# RPC 方法处理器类型
|
||||
MethodHandler = Callable[[Envelope], Awaitable[Envelope]]
|
||||
MethodHandler = Callable[[Envelope], Coroutine[Any, Any, Envelope]]
|
||||
|
||||
|
||||
class RPCServer:
|
||||
@@ -55,108 +55,39 @@ class RPCServer:
|
||||
|
||||
self._id_gen = RequestIdGenerator()
|
||||
self._connection: Optional[Connection] = None # 当前活跃的 Runner 连接
|
||||
self._runner_id: Optional[str] = None
|
||||
self._runner_generation: int = 0
|
||||
self._staged_connection: Optional[Connection] = None
|
||||
self._staged_runner_id: Optional[str] = None
|
||||
self._staged_runner_generation: int = 0
|
||||
self._staging_takeover: bool = False
|
||||
|
||||
# 方法处理器注册表
|
||||
self._method_handlers: Dict[str, MethodHandler] = {}
|
||||
|
||||
# 等待响应的 pending 请求: request_id -> (Future, target_generation)
|
||||
self._pending_requests: Dict[int, Tuple[asyncio.Future, int]] = {}
|
||||
# 等待响应的 pending 请求: request_id -> Future
|
||||
self._pending_requests: Dict[int, asyncio.Future[Envelope]] = {}
|
||||
|
||||
# 发送队列(背压控制)
|
||||
self._send_queue: Optional[asyncio.Queue[Tuple[Connection, bytes, asyncio.Future[None]]]] = None
|
||||
self._send_worker_task: Optional[asyncio.Task] = None
|
||||
self._send_worker_task: Optional[asyncio.Task[None]] = None
|
||||
|
||||
# 运行状态
|
||||
self._running: bool = False
|
||||
self._tasks: List[asyncio.Task] = []
|
||||
self._tasks: List[asyncio.Task[None]] = []
|
||||
self._last_handshake_rejection_reason: str = ""
|
||||
self._connection_lock: asyncio.Lock = asyncio.Lock()
|
||||
|
||||
@property
|
||||
def session_token(self) -> str:
|
||||
return self._session_token
|
||||
|
||||
def reset_session_token(self) -> str:
|
||||
"""重新生成会话令牌(热重载时调用,防止旧 Runner 重连)"""
|
||||
self._session_token = secrets.token_hex(32)
|
||||
return self._session_token
|
||||
|
||||
def restore_session_token(self, token: str) -> None:
|
||||
"""恢复指定的会话令牌(热重载回滚时调用)"""
|
||||
self._session_token = token
|
||||
|
||||
@property
|
||||
def runner_generation(self) -> int:
|
||||
return self._runner_generation
|
||||
|
||||
@property
|
||||
def staged_generation(self) -> int:
|
||||
return self._staged_runner_generation
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._connection is not None and not self._connection.is_closed
|
||||
|
||||
def has_generation(self, generation: int) -> bool:
|
||||
return generation == self._runner_generation or (
|
||||
self._staged_connection is not None
|
||||
and not self._staged_connection.is_closed
|
||||
and generation == self._staged_runner_generation
|
||||
)
|
||||
@property
|
||||
def last_handshake_rejection_reason(self) -> str:
|
||||
"""返回最近一次握手被拒绝的原因。"""
|
||||
return self._last_handshake_rejection_reason
|
||||
|
||||
def begin_staged_takeover(self) -> None:
|
||||
"""允许新 Runner 以 staged 方式接入,待 Supervisor 验证后再切换为活跃连接。"""
|
||||
self._staging_takeover = True
|
||||
|
||||
async def commit_staged_takeover(self) -> None:
|
||||
"""提交 staged Runner,原活跃连接在提交后被关闭。"""
|
||||
if self._staged_connection is None or self._staged_connection.is_closed:
|
||||
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "没有可提交的新 Runner 连接")
|
||||
|
||||
old_connection = self._connection
|
||||
old_generation = self._runner_generation
|
||||
|
||||
self._connection = self._staged_connection
|
||||
self._runner_id = self._staged_runner_id
|
||||
self._runner_generation = self._staged_runner_generation
|
||||
|
||||
self._staged_connection = None
|
||||
self._staged_runner_id = None
|
||||
self._staged_runner_generation = 0
|
||||
self._staging_takeover = False
|
||||
|
||||
if stale_count := self._fail_pending_requests(
|
||||
ErrorCode.E_PLUGIN_CRASHED,
|
||||
"Runner 连接已被新 generation 接管",
|
||||
generation=old_generation,
|
||||
):
|
||||
logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求")
|
||||
|
||||
if old_connection and old_connection is not self._connection and not old_connection.is_closed:
|
||||
await old_connection.close()
|
||||
|
||||
async def rollback_staged_takeover(self) -> None:
|
||||
"""放弃 staged Runner,保留当前活跃连接。"""
|
||||
staged_connection = self._staged_connection
|
||||
staged_generation = self._staged_runner_generation
|
||||
|
||||
self._staged_connection = None
|
||||
self._staged_runner_id = None
|
||||
self._staged_runner_generation = 0
|
||||
self._staging_takeover = False
|
||||
|
||||
self._fail_pending_requests(
|
||||
ErrorCode.E_PLUGIN_CRASHED,
|
||||
"新 Runner 预热失败,已回滚",
|
||||
generation=staged_generation,
|
||||
)
|
||||
|
||||
if staged_connection and not staged_connection.is_closed:
|
||||
await staged_connection.close()
|
||||
def clear_handshake_state(self) -> None:
|
||||
"""清空最近一次握手拒绝状态。"""
|
||||
self._last_handshake_rejection_reason = ""
|
||||
|
||||
def register_method(self, method: str, handler: MethodHandler) -> None:
|
||||
"""注册 RPC 方法处理器"""
|
||||
@@ -165,6 +96,7 @@ class RPCServer:
|
||||
async def start(self) -> None:
|
||||
"""启动 RPC 服务器"""
|
||||
self._running = True
|
||||
self.clear_handshake_state()
|
||||
self._send_queue = asyncio.Queue(maxsize=self._send_queue_size)
|
||||
self._send_worker_task = asyncio.create_task(self._send_loop())
|
||||
await self._transport.start(self._handle_connection)
|
||||
@@ -173,14 +105,9 @@ class RPCServer:
|
||||
async def stop(self) -> None:
|
||||
"""停止 RPC 服务器"""
|
||||
self._running = False
|
||||
|
||||
# 取消所有 pending 请求
|
||||
for future, _generation in self._pending_requests.values():
|
||||
if not future.done():
|
||||
future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
|
||||
self._pending_requests.clear()
|
||||
|
||||
self._fail_queued_sends(ErrorCode.E_TIMEOUT, "服务器关闭")
|
||||
self.clear_handshake_state()
|
||||
self._fail_pending_requests(ErrorCode.E_SHUTTING_DOWN, "服务器正在关闭")
|
||||
self._fail_queued_sends(ErrorCode.E_SHUTTING_DOWN, "服务器正在关闭")
|
||||
|
||||
if self._send_worker_task:
|
||||
self._send_worker_task.cancel()
|
||||
@@ -198,10 +125,6 @@ class RPCServer:
|
||||
await self._connection.close()
|
||||
self._connection = None
|
||||
|
||||
if self._staged_connection:
|
||||
await self._staged_connection.close()
|
||||
self._staged_connection = None
|
||||
|
||||
await self._transport.stop()
|
||||
logger.info("RPC Server 已停止")
|
||||
|
||||
@@ -211,7 +134,6 @@ class RPCServer:
|
||||
plugin_id: str = "",
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
timeout_ms: int = 30000,
|
||||
target_generation: Optional[int] = None,
|
||||
) -> Envelope:
|
||||
"""向 Runner 发送 RPC 请求并等待响应
|
||||
|
||||
@@ -227,18 +149,14 @@ class RPCServer:
|
||||
Raises:
|
||||
RPCError: 调用失败
|
||||
"""
|
||||
generation = target_generation or self._runner_generation
|
||||
conn = self._get_connection_for_generation(generation)
|
||||
if conn is None or conn.is_closed:
|
||||
if not self._connection or self._connection.is_closed:
|
||||
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
|
||||
|
||||
request_id = self._id_gen.next()
|
||||
request_id = await self._id_gen.next()
|
||||
envelope = Envelope(
|
||||
request_id=request_id,
|
||||
message_type=MessageType.REQUEST,
|
||||
method=method,
|
||||
plugin_id=plugin_id,
|
||||
generation=generation,
|
||||
timeout_ms=timeout_ms,
|
||||
payload=payload or {},
|
||||
)
|
||||
@@ -246,12 +164,12 @@ class RPCServer:
|
||||
# 注册 pending future
|
||||
loop = asyncio.get_running_loop()
|
||||
future: asyncio.Future[Envelope] = loop.create_future()
|
||||
self._pending_requests[request_id] = (future, generation)
|
||||
self._pending_requests[request_id] = future
|
||||
|
||||
try:
|
||||
# 发送请求
|
||||
data = self._codec.encode_envelope(envelope)
|
||||
await self._enqueue_send(conn, data)
|
||||
await self._enqueue_send(self._connection, data)
|
||||
|
||||
# 等待响应
|
||||
timeout_sec = timeout_ms / 1000.0
|
||||
@@ -265,150 +183,136 @@ class RPCServer:
|
||||
raise
|
||||
raise RPCError(ErrorCode.E_UNKNOWN, str(e)) from e
|
||||
|
||||
async def send_event(self, method: str, plugin_id: str = "", payload: Optional[Dict[str, Any]] = None) -> None:
|
||||
"""向 Runner 发送单向事件(不等待响应)"""
|
||||
conn = self._connection
|
||||
if conn is None or conn.is_closed:
|
||||
return
|
||||
# ============ 内部方法 ============
|
||||
# ========= 发送循环 =========
|
||||
async def _send_loop(self) -> None:
|
||||
"""后台发送循环:串行消费发送队列,统一执行连接写入。"""
|
||||
if self._send_queue is None:
|
||||
raise RuntimeError("没有消息队列")
|
||||
|
||||
request_id = self._id_gen.next()
|
||||
envelope = Envelope(
|
||||
request_id=request_id,
|
||||
message_type=MessageType.EVENT,
|
||||
method=method,
|
||||
plugin_id=plugin_id,
|
||||
generation=self._runner_generation,
|
||||
payload=payload or {},
|
||||
)
|
||||
data = self._codec.encode_envelope(envelope)
|
||||
await self._enqueue_send(conn, data)
|
||||
while True:
|
||||
try:
|
||||
conn, data, send_future = await self._send_queue.get()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
# ─── 内部方法 ──────────────────────────────────────────────
|
||||
try:
|
||||
if conn.is_closed:
|
||||
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
|
||||
await conn.send_frame(data)
|
||||
if not send_future.done():
|
||||
send_future.set_result(None)
|
||||
except asyncio.CancelledError:
|
||||
if not send_future.done():
|
||||
send_future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
|
||||
raise
|
||||
except Exception as e:
|
||||
send_error = RPCError.from_exception(e, {ConnectionError: ErrorCode.E_PLUGIN_CRASHED})
|
||||
if not send_future.done():
|
||||
send_future.set_exception(send_error)
|
||||
finally:
|
||||
self._send_queue.task_done()
|
||||
|
||||
# ====== 发送循环方法 ======
|
||||
async def _handle_connection(self, conn: Connection) -> None:
|
||||
"""处理新的 Runner 连接"""
|
||||
logger.info("收到 Runner 连接")
|
||||
previous_connection = self._connection
|
||||
previous_generation = self._runner_generation
|
||||
|
||||
# 第一条消息必须是 runner.hello 握手
|
||||
try:
|
||||
role = await self._handle_handshake(conn)
|
||||
if role is None:
|
||||
await conn.close()
|
||||
return
|
||||
async with self._connection_lock:
|
||||
self.clear_handshake_state()
|
||||
success = await self._handle_handshake(conn)
|
||||
if not success:
|
||||
await conn.close()
|
||||
return
|
||||
logger.info("Runner staged 握手成功")
|
||||
self._connection = conn
|
||||
except Exception as e:
|
||||
logger.error(f"握手失败: {e}")
|
||||
await conn.close()
|
||||
return
|
||||
|
||||
if role == "staged":
|
||||
expected_generation = self._staged_runner_generation
|
||||
logger.info(
|
||||
f"Runner staged 握手成功: runner_id={self._staged_runner_id}, generation={self._staged_runner_generation}"
|
||||
)
|
||||
else:
|
||||
self._connection = conn
|
||||
expected_generation = self._runner_generation
|
||||
logger.info(f"Runner 握手成功: runner_id={self._runner_id}, generation={self._runner_generation}")
|
||||
|
||||
if previous_connection and previous_connection is not conn and not previous_connection.is_closed:
|
||||
logger.info("检测到新 Runner 已接管连接,关闭旧连接")
|
||||
if stale_count := self._fail_pending_requests(
|
||||
ErrorCode.E_PLUGIN_CRASHED,
|
||||
"Runner 连接已被新 generation 接管",
|
||||
generation=previous_generation,
|
||||
):
|
||||
logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求")
|
||||
await previous_connection.close()
|
||||
|
||||
# 启动消息接收循环
|
||||
try:
|
||||
await self._recv_loop(conn, expected_generation=expected_generation)
|
||||
await self._recv_loop(conn)
|
||||
except Exception as e:
|
||||
logger.error(f"连接异常断开: {e}")
|
||||
finally:
|
||||
if self._connection is conn:
|
||||
self._connection = None
|
||||
self._runner_id = None
|
||||
self._fail_pending_requests(
|
||||
ErrorCode.E_PLUGIN_CRASHED,
|
||||
"Runner 连接已断开",
|
||||
generation=expected_generation,
|
||||
)
|
||||
elif self._staged_connection is conn:
|
||||
self._staged_connection = None
|
||||
self._staged_runner_id = None
|
||||
self._staged_runner_generation = 0
|
||||
self._fail_pending_requests(
|
||||
ErrorCode.E_PLUGIN_CRASHED,
|
||||
"Staged Runner 连接已断开",
|
||||
generation=expected_generation,
|
||||
)
|
||||
should_fail_pending_requests = False
|
||||
async with self._connection_lock:
|
||||
if self._connection is conn:
|
||||
self._connection = None
|
||||
should_fail_pending_requests = True
|
||||
if should_fail_pending_requests:
|
||||
self._fail_pending_requests(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已断开")
|
||||
|
||||
async def _handle_handshake(self, conn: Connection) -> Optional[str]:
|
||||
async def _handle_handshake(self, conn: Connection) -> bool:
|
||||
"""处理 runner.hello 握手"""
|
||||
# 接收握手请求
|
||||
data = await asyncio.wait_for(conn.recv_frame(), timeout=10.0)
|
||||
envelope = self._codec.decode_envelope(data)
|
||||
|
||||
if envelope.method != "runner.hello":
|
||||
logger.error(f"期望 runner.hello,收到 {envelope.method}")
|
||||
self._last_handshake_rejection_reason = "首条消息必须为 runner.hello"
|
||||
error_resp = envelope.make_error_response(
|
||||
ErrorCode.E_PROTOCOL_MISMATCH.value,
|
||||
"首条消息必须为 runner.hello",
|
||||
)
|
||||
await conn.send_frame(self._codec.encode_envelope(error_resp))
|
||||
return None
|
||||
return False
|
||||
|
||||
# 解析握手 payload
|
||||
hello = HelloPayload.model_validate(envelope.payload)
|
||||
|
||||
# 校验会话令牌
|
||||
if hello.session_token != self._session_token:
|
||||
logger.error("会话令牌不匹配")
|
||||
resp_payload = HelloResponsePayload(
|
||||
accepted=False,
|
||||
reason="会话令牌无效",
|
||||
)
|
||||
self._last_handshake_rejection_reason = "会话令牌无效"
|
||||
resp_payload = HelloResponsePayload(accepted=False, reason=self._last_handshake_rejection_reason)
|
||||
resp = envelope.make_response(payload=resp_payload.model_dump())
|
||||
await conn.send_frame(self._codec.encode_envelope(resp))
|
||||
return None
|
||||
return False
|
||||
|
||||
# 若已有活跃连接,直接拒绝新的握手,避免后来的连接抢占当前通道。
|
||||
if self.is_connected:
|
||||
logger.warning("拒绝新的 Runner 连接:已有活跃连接")
|
||||
self._last_handshake_rejection_reason = "已有活跃 Runner 连接,拒绝新的握手"
|
||||
resp_payload = HelloResponsePayload(accepted=False, reason=self._last_handshake_rejection_reason)
|
||||
resp = envelope.make_response(payload=resp_payload.model_dump())
|
||||
await conn.send_frame(self._codec.encode_envelope(resp))
|
||||
return False
|
||||
|
||||
# 校验 SDK 版本
|
||||
if not self._check_sdk_version(hello.sdk_version):
|
||||
logger.error(f"SDK 版本不兼容: {hello.sdk_version}")
|
||||
self._last_handshake_rejection_reason = (
|
||||
f"SDK 版本 {hello.sdk_version} 不在支持范围 [{MIN_SDK_VERSION}, {MAX_SDK_VERSION}]"
|
||||
)
|
||||
resp_payload = HelloResponsePayload(
|
||||
accepted=False,
|
||||
reason=f"SDK 版本 {hello.sdk_version} 不在支持范围 [{MIN_SDK_VERSION}, {MAX_SDK_VERSION}]",
|
||||
reason=self._last_handshake_rejection_reason,
|
||||
)
|
||||
resp = envelope.make_response(payload=resp_payload.model_dump())
|
||||
await conn.send_frame(self._codec.encode_envelope(resp))
|
||||
return None
|
||||
return False
|
||||
|
||||
# 握手成功
|
||||
role = "active"
|
||||
assigned_generation = self._runner_generation + 1
|
||||
if self._staging_takeover and self.is_connected:
|
||||
role = "staged"
|
||||
self._staged_connection = conn
|
||||
self._staged_runner_id = hello.runner_id
|
||||
self._staged_runner_generation = assigned_generation
|
||||
else:
|
||||
self._runner_id = hello.runner_id
|
||||
self._runner_generation = assigned_generation
|
||||
|
||||
resp_payload = HelloResponsePayload(
|
||||
accepted=True,
|
||||
host_version=PROTOCOL_VERSION,
|
||||
assigned_generation=assigned_generation,
|
||||
)
|
||||
# 发送响应
|
||||
self.clear_handshake_state()
|
||||
resp_payload = HelloResponsePayload(accepted=True, host_version=PROTOCOL_VERSION)
|
||||
resp = envelope.make_response(payload=resp_payload.model_dump())
|
||||
await conn.send_frame(self._codec.encode_envelope(resp))
|
||||
return True
|
||||
|
||||
return role
|
||||
def _check_sdk_version(self, sdk_version: str) -> bool:
|
||||
"""检查 SDK 版本是否在支持范围内"""
|
||||
try:
|
||||
sdk_parts = _parse_version_tuple(sdk_version)
|
||||
min_parts = _parse_version_tuple(MIN_SDK_VERSION)
|
||||
max_parts = _parse_version_tuple(MAX_SDK_VERSION)
|
||||
return min_parts <= sdk_parts <= max_parts
|
||||
except (ValueError, AttributeError):
|
||||
return False
|
||||
|
||||
async def _recv_loop(self, conn: Connection, expected_generation: int) -> None:
|
||||
# ========= 接收循环 =========
|
||||
async def _recv_loop(self, conn: Connection) -> None:
|
||||
"""消息接收主循环"""
|
||||
while self._running and not conn.is_closed:
|
||||
try:
|
||||
@@ -430,109 +334,40 @@ class RPCServer:
|
||||
if envelope.is_response():
|
||||
self._handle_response(envelope)
|
||||
elif envelope.is_request():
|
||||
if envelope.generation != expected_generation:
|
||||
error_resp = envelope.make_error_response(
|
||||
ErrorCode.E_GENERATION_MISMATCH.value,
|
||||
f"过期 generation: {envelope.generation} != {expected_generation}",
|
||||
)
|
||||
await conn.send_frame(self._codec.encode_envelope(error_resp))
|
||||
continue
|
||||
# 异步处理请求(Runner 发来的能力调用)
|
||||
task = asyncio.create_task(self._handle_request(envelope, conn))
|
||||
self._tasks.append(task)
|
||||
task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None)
|
||||
elif envelope.is_event():
|
||||
if envelope.generation != expected_generation:
|
||||
logger.warning(
|
||||
f"忽略过期 generation 事件 {envelope.method}: {envelope.generation} != {expected_generation}"
|
||||
)
|
||||
continue
|
||||
task = asyncio.create_task(self._handle_event(envelope))
|
||||
elif envelope.is_broadcast():
|
||||
task = asyncio.create_task(self._handle_broadcast(envelope))
|
||||
self._tasks.append(task)
|
||||
task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None)
|
||||
else:
|
||||
logger.warning(f"未知的消息类型: {envelope.message_type}")
|
||||
continue
|
||||
|
||||
# ====== 接收循环内部方法 ======
|
||||
def _handle_response(self, envelope: Envelope) -> None:
|
||||
"""处理来自 Runner 的响应"""
|
||||
pending = self._pending_requests.get(envelope.request_id)
|
||||
if pending is None:
|
||||
pending_future = self._pending_requests.pop(envelope.request_id, None)
|
||||
if pending_future is None:
|
||||
return
|
||||
|
||||
future, expected_generation = pending
|
||||
if envelope.generation != expected_generation:
|
||||
logger.warning(
|
||||
f"忽略过期 generation 响应 {envelope.method}: {envelope.generation} != {expected_generation}"
|
||||
)
|
||||
return
|
||||
|
||||
self._pending_requests.pop(envelope.request_id, None)
|
||||
if not future.done():
|
||||
if not pending_future.done():
|
||||
if envelope.error:
|
||||
future.set_exception(RPCError.from_dict(envelope.error))
|
||||
pending_future.set_exception(RPCError.from_dict(envelope.error))
|
||||
else:
|
||||
future.set_result(envelope)
|
||||
|
||||
async def _enqueue_send(self, conn: Connection, data: bytes) -> None:
|
||||
"""通过发送队列串行发送消息,提供真实背压。"""
|
||||
if conn.is_closed:
|
||||
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
|
||||
|
||||
if self._send_queue is None:
|
||||
await conn.send_frame(data)
|
||||
return
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
send_future: asyncio.Future[None] = loop.create_future()
|
||||
|
||||
try:
|
||||
self._send_queue.put_nowait((conn, data, send_future))
|
||||
except asyncio.QueueFull:
|
||||
raise RPCError(ErrorCode.E_BACKPRESSURE, "发送队列已满") from None
|
||||
|
||||
await send_future
|
||||
|
||||
async def _send_loop(self) -> None:
|
||||
"""后台发送循环:串行消费发送队列,统一执行连接写入。"""
|
||||
if self._send_queue is None:
|
||||
return
|
||||
|
||||
while True:
|
||||
try:
|
||||
conn, data, send_future = await self._send_queue.get()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
try:
|
||||
if conn.is_closed:
|
||||
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
|
||||
await conn.send_frame(data)
|
||||
if not send_future.done():
|
||||
send_future.set_result(None)
|
||||
except asyncio.CancelledError:
|
||||
if not send_future.done():
|
||||
send_future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
|
||||
raise
|
||||
except Exception as e:
|
||||
send_error = e if isinstance(e, RPCError) else self._normalize_send_exception(e)
|
||||
if not send_future.done():
|
||||
send_future.set_exception(send_error)
|
||||
finally:
|
||||
self._send_queue.task_done()
|
||||
|
||||
@staticmethod
|
||||
def _normalize_send_exception(error: Exception) -> RPCError:
|
||||
if isinstance(error, ConnectionError):
|
||||
return RPCError(ErrorCode.E_PLUGIN_CRASHED, str(error))
|
||||
return RPCError(ErrorCode.E_UNKNOWN, str(error))
|
||||
pending_future.set_result(envelope)
|
||||
|
||||
async def _handle_request(self, envelope: Envelope, conn: Connection) -> None:
|
||||
"""处理来自 Runner 的请求(通常是能力调用 cap.*)"""
|
||||
handler = self._method_handlers.get(envelope.method)
|
||||
if handler is None:
|
||||
error_resp = envelope.make_error_response(
|
||||
target_method = envelope.method
|
||||
handler = self._method_handlers.get(target_method)
|
||||
if not handler:
|
||||
error_response = envelope.make_error_response(
|
||||
ErrorCode.E_METHOD_NOT_ALLOWED.value,
|
||||
f"未注册的方法: {envelope.method}",
|
||||
)
|
||||
await conn.send_frame(self._codec.encode_envelope(error_resp))
|
||||
await conn.send_frame(self._codec.encode_envelope(error_response))
|
||||
return
|
||||
|
||||
try:
|
||||
@@ -546,59 +381,25 @@ class RPCServer:
|
||||
error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
|
||||
await conn.send_frame(self._codec.encode_envelope(error_resp))
|
||||
|
||||
async def _handle_event(self, envelope: Envelope) -> None:
|
||||
"""处理来自 Runner 的事件"""
|
||||
async def _handle_broadcast(self, envelope: Envelope) -> None:
|
||||
if handler := self._method_handlers.get(envelope.method):
|
||||
try:
|
||||
result = await handler(envelope)
|
||||
# 检查 handler 返回的信封是否包含错误信息
|
||||
if result is not None and isinstance(result, Envelope) and result.error:
|
||||
if result.error:
|
||||
logger.warning(f"事件 {envelope.method} handler 返回错误: {result.error.get('message', '')}")
|
||||
except Exception as e:
|
||||
logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True)
|
||||
|
||||
@staticmethod
|
||||
def _check_sdk_version(sdk_version: str) -> bool:
|
||||
"""检查 SDK 版本是否在支持范围内"""
|
||||
try:
|
||||
sdk_parts = RPCServer._parse_version_tuple(sdk_version)
|
||||
min_parts = RPCServer._parse_version_tuple(MIN_SDK_VERSION)
|
||||
max_parts = RPCServer._parse_version_tuple(MAX_SDK_VERSION)
|
||||
return min_parts <= sdk_parts <= max_parts
|
||||
except (ValueError, AttributeError):
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _parse_version_tuple(version: str) -> Tuple[int, int, int]:
|
||||
base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version or "", flags=re.IGNORECASE)[0]
|
||||
base_version = base_version.split("+", 1)[0]
|
||||
parts = [part for part in base_version.split(".") if part != ""]
|
||||
while len(parts) < 3:
|
||||
parts.append("0")
|
||||
return (int(parts[0]), int(parts[1]), int(parts[2]))
|
||||
|
||||
def _get_connection_for_generation(self, generation: int) -> Optional[Connection]:
|
||||
if generation == self._runner_generation:
|
||||
return self._connection
|
||||
if generation == self._staged_runner_generation:
|
||||
return self._staged_connection
|
||||
return None
|
||||
|
||||
def _fail_pending_requests(
|
||||
self,
|
||||
error_code: ErrorCode,
|
||||
message: str,
|
||||
generation: Optional[int] = None,
|
||||
) -> int:
|
||||
stale_count = 0
|
||||
for request_id, (future, request_generation) in list(self._pending_requests.items()):
|
||||
if generation is not None and request_generation != generation:
|
||||
continue
|
||||
def _fail_pending_requests(self, error_code: ErrorCode, message: str) -> int:
|
||||
"""失败所有等待中的请求(如连接断开时)"""
|
||||
aborted_request_count = 0
|
||||
for future in self._pending_requests.values():
|
||||
if not future.done():
|
||||
future.set_exception(RPCError(error_code, message))
|
||||
stale_count += 1
|
||||
self._pending_requests.pop(request_id, None)
|
||||
return stale_count
|
||||
aborted_request_count += 1
|
||||
self._pending_requests.clear()
|
||||
return aborted_request_count
|
||||
|
||||
def _fail_queued_sends(self, error_code: ErrorCode, message: str) -> int:
|
||||
if self._send_queue is None:
|
||||
@@ -617,3 +418,31 @@ class RPCServer:
|
||||
self._send_queue.task_done()
|
||||
|
||||
return failed_count
|
||||
|
||||
async def _enqueue_send(self, conn: Connection, data: bytes) -> None:
|
||||
"""通过发送队列串行发送消息,提供真实背压。"""
|
||||
if conn.is_closed:
|
||||
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
|
||||
|
||||
if self._send_queue is None:
|
||||
await conn.send_frame(data)
|
||||
return
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
send_future: asyncio.Future[None] = loop.create_future()
|
||||
|
||||
try:
|
||||
self._send_queue.put_nowait((conn, data, send_future))
|
||||
except asyncio.QueueFull:
|
||||
raise RPCError(ErrorCode.E_BACK_PRESSURE, "发送队列已满") from None
|
||||
|
||||
await send_future
|
||||
|
||||
|
||||
def _parse_version_tuple(version: str) -> Tuple[int, int, int]:
|
||||
base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version or "", flags=re.IGNORECASE)[0]
|
||||
base_version = base_version.split("+", 1)[0]
|
||||
parts = [part for part in base_version.split(".") if part != ""]
|
||||
while len(parts) < 3:
|
||||
parts.append("0")
|
||||
return (int(parts[0]), int(parts[1]), int(parts[2]))
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,422 +0,0 @@
|
||||
"""Host-side WorkflowExecutor
|
||||
|
||||
6 阶段线性流转(INGRESS → PRE_PROCESS → PLAN → TOOL_EXECUTE → POST_PROCESS → EGRESS)
|
||||
|
||||
每个阶段执行顺序:
|
||||
1. Host-side pre-filter: 根据 hook filter 条件过滤不相关的 hook
|
||||
2. 按 priority 降序排列
|
||||
3. 串行执行 blocking hook(可修改 message,返回 HookResult)
|
||||
4. 并发执行 non-blocking hook(只读)
|
||||
5. 检查是否有 SKIP_STAGE 或 ABORT
|
||||
6. PLAN 阶段内置 Command 匹配路由
|
||||
|
||||
支持:
|
||||
- HookResult: CONTINUE / SKIP_STAGE / ABORT
|
||||
- ErrorPolicy: ABORT / SKIP / LOG (per-hook)
|
||||
- stage_outputs: 阶段间带命名空间的数据传递
|
||||
- modification_log: 消息修改审计
|
||||
"""
|
||||
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry, RegisteredComponent
|
||||
|
||||
logger = get_logger("plugin_runtime.host.workflow_executor")
|
||||
|
||||
# 阶段顺序
|
||||
STAGE_SEQUENCE: List[str] = [
|
||||
"ingress",
|
||||
"pre_process",
|
||||
"plan",
|
||||
"tool_execute",
|
||||
"post_process",
|
||||
"egress",
|
||||
]
|
||||
|
||||
# HookResult 常量(与 SDK HookResult enum 值对应)
|
||||
HOOK_CONTINUE = "continue"
|
||||
HOOK_SKIP_STAGE = "skip_stage"
|
||||
HOOK_ABORT = "abort"
|
||||
|
||||
|
||||
# blocking hook 全局最大超时(秒):即使 hook 声明 timeout_ms=0 也不会无限等待
|
||||
# 从配置文件读取,允许用户调整
|
||||
def _get_blocking_timeout() -> float:
|
||||
return global_config.plugin_runtime.workflow_blocking_timeout_sec
|
||||
|
||||
|
||||
class ModificationRecord:
|
||||
"""消息修改记录"""
|
||||
|
||||
__slots__ = ("stage", "hook_name", "timestamp", "fields_changed")
|
||||
|
||||
def __init__(self, stage: str, hook_name: str, fields_changed: List[str]) -> None:
|
||||
self.stage = stage
|
||||
self.hook_name = hook_name
|
||||
self.timestamp = time.perf_counter()
|
||||
self.fields_changed = fields_changed
|
||||
|
||||
|
||||
class WorkflowContext:
|
||||
"""Workflow 执行上下文"""
|
||||
|
||||
def __init__(self, trace_id: Optional[str] = None, stream_id: Optional[str] = None) -> None:
|
||||
self.trace_id = trace_id or uuid.uuid4().hex
|
||||
self.stream_id = stream_id
|
||||
self.timings: Dict[str, float] = {}
|
||||
self.errors: List[str] = []
|
||||
# 阶段间数据传递(按 stage 命名空间隔离)
|
||||
self.stage_outputs: Dict[str, Dict[str, Any]] = {}
|
||||
# 消息修改审计日志
|
||||
self.modification_log: List[ModificationRecord] = []
|
||||
# PLAN 阶段命令匹配结果
|
||||
self.matched_command: Optional[str] = None
|
||||
|
||||
def set_stage_output(self, stage: str, key: str, value: Any) -> None:
|
||||
self.stage_outputs.setdefault(stage, {})[key] = value
|
||||
|
||||
def get_stage_output(self, stage: str, key: str, default: Any = None) -> Any:
|
||||
return self.stage_outputs.get(stage, {}).get(key, default)
|
||||
|
||||
|
||||
class WorkflowResult:
|
||||
"""Workflow 执行结果"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
status: str = "completed", # completed / aborted / failed
|
||||
return_message: str = "",
|
||||
stopped_at: str = "",
|
||||
diagnostics: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
self.status = status
|
||||
self.return_message = return_message
|
||||
self.stopped_at = stopped_at
|
||||
self.diagnostics = diagnostics or {}
|
||||
|
||||
|
||||
# invoke_fn 签名
|
||||
InvokeFn = Callable[[str, str, Dict[str, Any]], Awaitable[Dict[str, Any]]]
|
||||
|
||||
|
||||
class WorkflowExecutor:
|
||||
"""Host-side Workflow 执行器
|
||||
|
||||
实现 stage-based pipeline + per-stage hook chain with priority + early return。
|
||||
"""
|
||||
|
||||
def __init__(self, registry: ComponentRegistry) -> None:
|
||||
self._registry = registry
|
||||
self._background_tasks: Set[asyncio.Task] = set()
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
invoke_fn: InvokeFn,
|
||||
message: Optional[Dict[str, Any]] = None,
|
||||
stream_id: Optional[str] = None,
|
||||
context: Optional[WorkflowContext] = None,
|
||||
command_invoke_fn: Optional[InvokeFn] = None,
|
||||
) -> Tuple[WorkflowResult, Optional[Dict[str, Any]], WorkflowContext]:
|
||||
"""执行 workflow pipeline。
|
||||
|
||||
Args:
|
||||
invoke_fn: 用于 workflow_step 的回调
|
||||
command_invoke_fn: 用于 command 的回调(走 plugin.invoke_command),
|
||||
未传则复用 invoke_fn
|
||||
|
||||
Returns:
|
||||
(result, final_message, context)
|
||||
"""
|
||||
ctx = context or WorkflowContext(stream_id=stream_id)
|
||||
current_message = dict(message) if message else None
|
||||
|
||||
for stage in STAGE_SEQUENCE:
|
||||
stage_start = time.perf_counter()
|
||||
|
||||
try:
|
||||
# PLAN 阶段: 先做 Command 路由
|
||||
if stage == "plan" and current_message:
|
||||
cmd_result = await self._route_command(command_invoke_fn or invoke_fn, current_message, ctx)
|
||||
if cmd_result is not None:
|
||||
# 命令匹配成功,跳过 PLAN 阶段的 hook,直接存结果进 stage_outputs
|
||||
ctx.set_stage_output("plan", "command_result", cmd_result)
|
||||
ctx.timings[stage] = time.perf_counter() - stage_start
|
||||
continue
|
||||
|
||||
# 获取该阶段所有 hook(已按 priority 降序排列)
|
||||
all_steps = self._registry.get_workflow_steps(stage)
|
||||
if not all_steps:
|
||||
ctx.timings[stage] = time.perf_counter() - stage_start
|
||||
continue
|
||||
|
||||
# 1. Pre-filter
|
||||
filtered_steps = self._pre_filter(all_steps, current_message)
|
||||
|
||||
# 2. 分离 blocking 和 non-blocking
|
||||
blocking_steps = [s for s in filtered_steps if s.metadata.get("blocking", True)]
|
||||
nonblocking_steps = [s for s in filtered_steps if not s.metadata.get("blocking", True)]
|
||||
|
||||
# 3. 串行执行 blocking hook
|
||||
skip_stage = False
|
||||
for step in blocking_steps:
|
||||
hook_result, modified, step_error = await self._invoke_step(
|
||||
invoke_fn, step, stage, ctx, current_message
|
||||
)
|
||||
|
||||
if step_error:
|
||||
error_policy = step.metadata.get("error_policy", "abort")
|
||||
ctx.errors.append(f"{step.full_name}: {step_error}")
|
||||
|
||||
if error_policy == "abort":
|
||||
ctx.timings[stage] = time.perf_counter() - stage_start
|
||||
return (
|
||||
WorkflowResult(
|
||||
status="failed",
|
||||
return_message=step_error,
|
||||
stopped_at=stage,
|
||||
diagnostics={"step": step.full_name, "trace_id": ctx.trace_id},
|
||||
),
|
||||
current_message,
|
||||
ctx,
|
||||
)
|
||||
elif error_policy == "skip":
|
||||
logger.warning(f"[{ctx.trace_id}] hook {step.full_name} 异常(skip): {step_error}")
|
||||
continue
|
||||
else: # log
|
||||
logger.warning(f"[{ctx.trace_id}] hook {step.full_name} 异常(log): {step_error}")
|
||||
continue
|
||||
|
||||
# 更新消息(仅 blocking hook 有权修改)
|
||||
if modified:
|
||||
changed_fields = (
|
||||
_diff_keys(current_message, modified) if current_message else list(modified.keys())
|
||||
)
|
||||
ctx.modification_log.append(ModificationRecord(stage, step.full_name, changed_fields))
|
||||
current_message = modified
|
||||
|
||||
if hook_result == HOOK_ABORT:
|
||||
ctx.timings[stage] = time.perf_counter() - stage_start
|
||||
return (
|
||||
WorkflowResult(
|
||||
status="aborted",
|
||||
return_message=f"aborted by {step.full_name}",
|
||||
stopped_at=stage,
|
||||
diagnostics={"step": step.full_name, "trace_id": ctx.trace_id},
|
||||
),
|
||||
current_message,
|
||||
ctx,
|
||||
)
|
||||
|
||||
if hook_result == HOOK_SKIP_STAGE:
|
||||
skip_stage = True
|
||||
break
|
||||
|
||||
# 4. 并发执行 non-blocking hook(只读,忽略返回值中的 modified_message)
|
||||
if nonblocking_steps and not skip_stage:
|
||||
for step in nonblocking_steps:
|
||||
self._track_background_task(
|
||||
asyncio.create_task(
|
||||
self._invoke_step_fire_and_forget(invoke_fn, step, stage, ctx, current_message)
|
||||
)
|
||||
)
|
||||
|
||||
ctx.timings[stage] = time.perf_counter() - stage_start
|
||||
|
||||
except Exception as e:
|
||||
ctx.timings[stage] = time.perf_counter() - stage_start
|
||||
ctx.errors.append(f"{stage}: {e}")
|
||||
logger.error(f"[{ctx.trace_id}] 阶段 {stage} 未捕获异常: {e}", exc_info=True)
|
||||
return (
|
||||
WorkflowResult(
|
||||
status="failed",
|
||||
return_message=str(e),
|
||||
stopped_at=stage,
|
||||
diagnostics={"trace_id": ctx.trace_id},
|
||||
),
|
||||
current_message,
|
||||
ctx,
|
||||
)
|
||||
|
||||
return (
|
||||
WorkflowResult(
|
||||
status="completed",
|
||||
return_message="workflow completed",
|
||||
diagnostics={"trace_id": ctx.trace_id},
|
||||
),
|
||||
current_message,
|
||||
ctx,
|
||||
)
|
||||
|
||||
def _track_background_task(self, task: asyncio.Task) -> None:
|
||||
"""保持 non-blocking workflow task 的强引用,直到任务结束。"""
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
# ─── 内部方法 ──────────────────────────────────────────────
|
||||
|
||||
def _pre_filter(
|
||||
self,
|
||||
steps: List[RegisteredComponent],
|
||||
message: Optional[Dict[str, Any]],
|
||||
) -> List[RegisteredComponent]:
|
||||
"""根据 hook 声明的 filter 条件预过滤,避免无意义的 IPC 调用。"""
|
||||
if not message:
|
||||
return steps
|
||||
|
||||
result = []
|
||||
for step in steps:
|
||||
filter_cond = step.metadata.get("filter", {})
|
||||
if not filter_cond:
|
||||
result.append(step)
|
||||
continue
|
||||
if self._match_filter(filter_cond, message):
|
||||
result.append(step)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _match_filter(filter_cond: Dict[str, Any], message: Dict[str, Any]) -> bool:
|
||||
"""简单 key-value 匹配过滤。
|
||||
|
||||
filter 中的每个 key 必须在 message 中存在且值相等,
|
||||
全部匹配才通过。
|
||||
"""
|
||||
for key, expected in filter_cond.items():
|
||||
actual = message.get(key)
|
||||
if (isinstance(expected, list) and actual not in expected) or (
|
||||
not isinstance(expected, list) and actual != expected
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
||||
async def _invoke_step(
|
||||
self,
|
||||
invoke_fn: InvokeFn,
|
||||
step: RegisteredComponent,
|
||||
stage: str,
|
||||
ctx: WorkflowContext,
|
||||
message: Optional[Dict[str, Any]],
|
||||
) -> Tuple[str, Optional[Dict[str, Any]], Optional[str]]:
|
||||
"""调用单个 blocking hook。
|
||||
|
||||
Returns:
|
||||
(hook_result, modified_message, error_string_or_None)
|
||||
"""
|
||||
timeout_ms = step.metadata.get("timeout_ms", 0)
|
||||
# 使用 hook 声明的超时,但不超过全局安全阀
|
||||
timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else _get_blocking_timeout()
|
||||
step_key = f"{stage}:{step.full_name}"
|
||||
step_start = time.perf_counter()
|
||||
|
||||
try:
|
||||
coro = invoke_fn(
|
||||
step.plugin_id,
|
||||
step.name,
|
||||
{
|
||||
"stage": stage,
|
||||
"trace_id": ctx.trace_id,
|
||||
"message": message,
|
||||
"stage_outputs": ctx.stage_outputs,
|
||||
},
|
||||
)
|
||||
resp = await asyncio.wait_for(coro, timeout=timeout_sec)
|
||||
ctx.timings[step_key] = time.perf_counter() - step_start
|
||||
|
||||
hook_result = resp.get("hook_result", HOOK_CONTINUE)
|
||||
modified_message = resp.get("modified_message")
|
||||
# 存 stage output(如果 hook 提供了)
|
||||
stage_out = resp.get("stage_output")
|
||||
if isinstance(stage_out, dict):
|
||||
for k, v in stage_out.items():
|
||||
ctx.set_stage_output(stage, k, v)
|
||||
|
||||
return hook_result, modified_message, None
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
ctx.timings[step_key] = time.perf_counter() - step_start
|
||||
return HOOK_CONTINUE, None, f"timeout after {timeout_ms}ms"
|
||||
|
||||
except Exception as e:
|
||||
ctx.timings[step_key] = time.perf_counter() - step_start
|
||||
return HOOK_CONTINUE, None, str(e)
|
||||
|
||||
async def _invoke_step_fire_and_forget(
|
||||
self,
|
||||
invoke_fn: InvokeFn,
|
||||
step: RegisteredComponent,
|
||||
stage: str,
|
||||
ctx: WorkflowContext,
|
||||
message: Optional[Dict[str, Any]],
|
||||
) -> None:
|
||||
"""Non-blocking hook 调用,只读,忽略结果。"""
|
||||
timeout_ms = step.metadata.get("timeout_ms", 0)
|
||||
# 使用 hook 声明的超时,但无声明时回退到全局安全阀,防止 task 泄漏
|
||||
timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else _get_blocking_timeout()
|
||||
|
||||
try:
|
||||
coro = invoke_fn(
|
||||
step.plugin_id,
|
||||
step.name,
|
||||
{
|
||||
"stage": stage,
|
||||
"trace_id": ctx.trace_id,
|
||||
"message": message,
|
||||
"stage_outputs": ctx.stage_outputs,
|
||||
},
|
||||
)
|
||||
await asyncio.wait_for(coro, timeout=timeout_sec)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"[{ctx.trace_id}] non-blocking hook {step.full_name} 超时 ({timeout_sec}s)")
|
||||
except Exception as e:
|
||||
logger.debug(f"[{ctx.trace_id}] non-blocking hook {step.full_name}: {e}")
|
||||
|
||||
async def _route_command(
|
||||
self,
|
||||
invoke_fn: InvokeFn,
|
||||
message: Dict[str, Any],
|
||||
ctx: WorkflowContext,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""PLAN 阶段内置 Command 路由。
|
||||
|
||||
在 registry 中查找匹配的 command 组件,
|
||||
匹配到则直接路由到对应 command handler,返回执行结果。
|
||||
不匹配则返回 None,让 PLAN 阶段的 hook 继续执行。
|
||||
"""
|
||||
plain_text = message.get("plain_text", "")
|
||||
if not plain_text:
|
||||
return None
|
||||
|
||||
match_result = self._registry.find_command_by_text(plain_text)
|
||||
if match_result is None:
|
||||
return None
|
||||
|
||||
matched, matched_groups = match_result
|
||||
|
||||
ctx.matched_command = matched.full_name
|
||||
logger.info(f"[{ctx.trace_id}] 命令匹配: {matched.full_name}")
|
||||
|
||||
try:
|
||||
return await invoke_fn(
|
||||
matched.plugin_id,
|
||||
matched.name,
|
||||
{
|
||||
"text": plain_text,
|
||||
"message": message,
|
||||
"trace_id": ctx.trace_id,
|
||||
"matched_groups": matched_groups,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[{ctx.trace_id}] 命令 {matched.full_name} 执行失败: {e}", exc_info=True)
|
||||
ctx.errors.append(f"command:{matched.full_name}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _diff_keys(old: Dict[str, Any], new: Dict[str, Any]) -> List[str]:
|
||||
"""返回 new 中与 old 不同的 key 列表。"""
|
||||
return [k for k, v in new.items() if k not in old or old[k] != v]
|
||||
@@ -8,23 +8,27 @@
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Dict, Iterable, List, Optional, Sequence, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Dict, Iterable, List, Optional, Sequence, Set, Tuple
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import tomlkit
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.config.config import config_manager
|
||||
from src.config.file_watcher import FileChange, FileWatcher
|
||||
from src.platform_io import DeliveryBatch, InboundMessageEnvelope, get_platform_io_manager
|
||||
from src.plugin_runtime.capabilities import (
|
||||
RuntimeComponentCapabilityMixin,
|
||||
RuntimeCoreCapabilityMixin,
|
||||
RuntimeDataCapabilityMixin,
|
||||
)
|
||||
from src.plugin_runtime.capabilities.registry import register_capability_impls
|
||||
from src.plugin_runtime.host.message_utils import MessageDict, PluginMessageUtils
|
||||
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
||||
|
||||
logger = get_logger("plugin_runtime.integration")
|
||||
@@ -55,6 +59,7 @@ class PluginRuntimeManager(
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化插件运行时管理器。"""
|
||||
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
||||
|
||||
self._builtin_supervisor: Optional[PluginSupervisor] = None
|
||||
@@ -63,6 +68,26 @@ class PluginRuntimeManager(
|
||||
self._plugin_file_watcher: Optional[FileWatcher] = None
|
||||
self._plugin_source_watcher_subscription_id: Optional[str] = None
|
||||
self._plugin_config_watcher_subscriptions: Dict[str, Tuple[Path, str]] = {}
|
||||
self._plugin_path_cache: Dict[str, Path] = {}
|
||||
self._manifest_validator: ManifestValidator = ManifestValidator()
|
||||
self._config_reload_callback: Callable[[Sequence[str]], Awaitable[None]] = self._handle_main_config_reload
|
||||
self._config_reload_callback_registered: bool = False
|
||||
|
||||
async def _dispatch_platform_inbound(self, envelope: InboundMessageEnvelope) -> None:
|
||||
"""接收 Platform IO 审核后的入站消息并送入主消息链。
|
||||
|
||||
Args:
|
||||
envelope: Platform IO 产出的入站封装。
|
||||
"""
|
||||
session_message = envelope.session_message
|
||||
if session_message is None and envelope.payload is not None:
|
||||
session_message = PluginMessageUtils._build_session_message_from_dict(dict(envelope.payload))
|
||||
if session_message is None:
|
||||
raise ValueError("Platform IO 入站封装缺少可用的 SessionMessage 或 payload")
|
||||
|
||||
from src.chat.message_receive.bot import chat_bot
|
||||
|
||||
await chat_bot.receive_message(session_message)
|
||||
|
||||
# ─── 插件目录 ─────────────────────────────────────────────
|
||||
|
||||
@@ -78,6 +103,42 @@ class PluginRuntimeManager(
|
||||
candidate = Path("plugins").resolve()
|
||||
return [candidate] if candidate.is_dir() else []
|
||||
|
||||
@classmethod
|
||||
def _discover_plugin_dependency_map(cls, plugin_dirs: Iterable[Path]) -> Dict[str, List[str]]:
|
||||
"""扫描指定插件目录集合,返回 ``plugin_id -> dependencies`` 映射。"""
|
||||
validator = ManifestValidator()
|
||||
return validator.build_plugin_dependency_map(plugin_dirs)
|
||||
|
||||
@classmethod
|
||||
def _build_group_start_order(
|
||||
cls,
|
||||
builtin_dirs: Sequence[Path],
|
||||
third_party_dirs: Sequence[Path],
|
||||
) -> List[str]:
|
||||
"""根据跨 Supervisor 依赖关系决定 Runner 启动顺序。"""
|
||||
|
||||
builtin_dependencies = cls._discover_plugin_dependency_map(builtin_dirs)
|
||||
third_party_dependencies = cls._discover_plugin_dependency_map(third_party_dirs)
|
||||
builtin_plugin_ids = set(builtin_dependencies)
|
||||
third_party_plugin_ids = set(third_party_dependencies)
|
||||
|
||||
builtin_needs_third_party = any(
|
||||
dependency in third_party_plugin_ids
|
||||
for dependencies in builtin_dependencies.values()
|
||||
for dependency in dependencies
|
||||
)
|
||||
third_party_needs_builtin = any(
|
||||
dependency in builtin_plugin_ids
|
||||
for dependencies in third_party_dependencies.values()
|
||||
for dependency in dependencies
|
||||
)
|
||||
|
||||
if builtin_needs_third_party and third_party_needs_builtin:
|
||||
raise RuntimeError("检测到跨 Supervisor 循环依赖,当前无法安全启动独立 Runner")
|
||||
if builtin_needs_third_party:
|
||||
return ["third_party", "builtin"]
|
||||
return ["builtin", "third_party"]
|
||||
|
||||
# ─── 生命周期 ─────────────────────────────────────────────
|
||||
|
||||
async def start(self) -> None:
|
||||
@@ -86,7 +147,7 @@ class PluginRuntimeManager(
|
||||
logger.warning("PluginRuntimeManager 已在运行中,跳过重复启动")
|
||||
return
|
||||
|
||||
_cfg = global_config.plugin_runtime
|
||||
_cfg = config_manager.get_global_config().plugin_runtime
|
||||
if not _cfg.enabled:
|
||||
logger.info("插件运行时已在配置中禁用,跳过启动")
|
||||
return
|
||||
@@ -108,6 +169,8 @@ class PluginRuntimeManager(
|
||||
logger.info("未找到任何插件目录,跳过插件运行时启动")
|
||||
return
|
||||
|
||||
platform_io_manager = get_platform_io_manager()
|
||||
|
||||
# 从配置读取自定义 IPC socket 路径(留空则自动生成)
|
||||
socket_path_base = _cfg.ipc_socket_path or None
|
||||
|
||||
@@ -132,19 +195,46 @@ class PluginRuntimeManager(
|
||||
|
||||
started_supervisors: List[PluginSupervisor] = []
|
||||
try:
|
||||
if self._builtin_supervisor:
|
||||
await self._builtin_supervisor.start()
|
||||
started_supervisors.append(self._builtin_supervisor)
|
||||
if self._third_party_supervisor:
|
||||
await self._third_party_supervisor.start()
|
||||
started_supervisors.append(self._third_party_supervisor)
|
||||
platform_io_manager.set_inbound_dispatcher(self._dispatch_platform_inbound)
|
||||
await platform_io_manager.ensure_send_pipeline_ready()
|
||||
|
||||
supervisor_groups: Dict[str, Optional[PluginSupervisor]] = {
|
||||
"builtin": self._builtin_supervisor,
|
||||
"third_party": self._third_party_supervisor,
|
||||
}
|
||||
start_order = self._build_group_start_order(builtin_dirs, third_party_dirs)
|
||||
|
||||
for group_name in start_order:
|
||||
supervisor = supervisor_groups.get(group_name)
|
||||
if supervisor is None:
|
||||
continue
|
||||
|
||||
external_plugin_versions = {
|
||||
plugin_id: plugin_version
|
||||
for started_supervisor in started_supervisors
|
||||
for plugin_id, plugin_version in started_supervisor.get_loaded_plugin_versions().items()
|
||||
}
|
||||
supervisor.set_external_available_plugins(external_plugin_versions)
|
||||
await supervisor.start()
|
||||
started_supervisors.append(supervisor)
|
||||
|
||||
await self._start_plugin_file_watcher()
|
||||
config_manager.register_reload_callback(self._config_reload_callback)
|
||||
self._config_reload_callback_registered = True
|
||||
self._started = True
|
||||
logger.info(f"插件运行时已启动 — 内置: {builtin_dirs or '无'}, 第三方: {third_party_dirs or '无'}")
|
||||
except Exception as e:
|
||||
logger.error(f"插件运行时启动失败: {e}", exc_info=True)
|
||||
await self._stop_plugin_file_watcher()
|
||||
if self._config_reload_callback_registered:
|
||||
config_manager.unregister_reload_callback(self._config_reload_callback)
|
||||
self._config_reload_callback_registered = False
|
||||
await asyncio.gather(*(sv.stop() for sv in started_supervisors), return_exceptions=True)
|
||||
platform_io_manager.clear_inbound_dispatcher()
|
||||
try:
|
||||
await platform_io_manager.stop()
|
||||
except Exception as platform_io_exc:
|
||||
logger.warning(f"Platform IO 停止失败: {platform_io_exc}")
|
||||
self._started = False
|
||||
self._builtin_supervisor = None
|
||||
self._third_party_supervisor = None
|
||||
@@ -154,7 +244,11 @@ class PluginRuntimeManager(
|
||||
if not self._started:
|
||||
return
|
||||
|
||||
platform_io_manager = get_platform_io_manager()
|
||||
await self._stop_plugin_file_watcher()
|
||||
if self._config_reload_callback_registered:
|
||||
config_manager.unregister_reload_callback(self._config_reload_callback)
|
||||
self._config_reload_callback_registered = False
|
||||
|
||||
coroutines: List[Coroutine[Any, Any, None]] = []
|
||||
if self._builtin_supervisor:
|
||||
@@ -162,18 +256,32 @@ class PluginRuntimeManager(
|
||||
if self._third_party_supervisor:
|
||||
coroutines.append(self._third_party_supervisor.stop())
|
||||
|
||||
stop_errors: List[str] = []
|
||||
try:
|
||||
await asyncio.gather(*coroutines, return_exceptions=True)
|
||||
logger.info("插件运行时已停止")
|
||||
except Exception as e:
|
||||
logger.error(f"插件运行时停止失败: {e}", exc_info=True)
|
||||
results = await asyncio.gather(*coroutines, return_exceptions=True)
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
stop_errors.append(str(result))
|
||||
|
||||
platform_io_manager.clear_inbound_dispatcher()
|
||||
try:
|
||||
await platform_io_manager.stop()
|
||||
except Exception as exc:
|
||||
stop_errors.append(f"Platform IO: {exc}")
|
||||
|
||||
if stop_errors:
|
||||
logger.error(f"插件运行时停止过程中存在错误: {'; '.join(stop_errors)}")
|
||||
else:
|
||||
logger.info("插件运行时已停止")
|
||||
finally:
|
||||
self._started = False
|
||||
self._builtin_supervisor = None
|
||||
self._third_party_supervisor = None
|
||||
self._plugin_path_cache.clear()
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""返回插件运行时是否处于启动状态。"""
|
||||
return self._started
|
||||
|
||||
@property
|
||||
@@ -181,11 +289,176 @@ class PluginRuntimeManager(
|
||||
"""获取所有活跃的 Supervisor"""
|
||||
return [s for s in (self._builtin_supervisor, self._third_party_supervisor) if s is not None]
|
||||
|
||||
def _build_registered_dependency_map(self) -> Dict[str, Set[str]]:
|
||||
"""根据当前已注册插件构建全局依赖图。"""
|
||||
|
||||
dependency_map: Dict[str, Set[str]] = {}
|
||||
for supervisor in self.supervisors:
|
||||
for plugin_id, registration in getattr(supervisor, "_registered_plugins", {}).items():
|
||||
dependency_map[plugin_id] = {
|
||||
str(dependency or "").strip()
|
||||
for dependency in getattr(registration, "dependencies", [])
|
||||
if str(dependency or "").strip()
|
||||
}
|
||||
return dependency_map
|
||||
|
||||
@staticmethod
|
||||
def _collect_reverse_dependents(
|
||||
plugin_ids: Set[str],
|
||||
dependency_map: Dict[str, Set[str]],
|
||||
) -> Set[str]:
|
||||
"""根据依赖图收集反向依赖闭包。"""
|
||||
|
||||
impacted_plugins: Set[str] = set(plugin_ids)
|
||||
changed = True
|
||||
|
||||
while changed:
|
||||
changed = False
|
||||
for registered_plugin_id, dependencies in dependency_map.items():
|
||||
if registered_plugin_id in impacted_plugins:
|
||||
continue
|
||||
if dependencies & impacted_plugins:
|
||||
impacted_plugins.add(registered_plugin_id)
|
||||
changed = True
|
||||
|
||||
return impacted_plugins
|
||||
|
||||
def _build_registered_supervisor_map(self) -> Dict[str, "PluginSupervisor"]:
|
||||
"""构建当前已注册插件到所属 Supervisor 的映射。"""
|
||||
|
||||
return {
|
||||
plugin_id: supervisor
|
||||
for supervisor in self.supervisors
|
||||
for plugin_id in supervisor.get_loaded_plugin_ids()
|
||||
}
|
||||
|
||||
def _build_external_available_plugins_for_supervisor(self, target_supervisor: "PluginSupervisor") -> Dict[str, str]:
|
||||
"""收集某个 Supervisor 可用的外部插件版本映射。"""
|
||||
|
||||
external_plugin_versions: Dict[str, str] = {}
|
||||
for supervisor in self.supervisors:
|
||||
if supervisor is target_supervisor:
|
||||
continue
|
||||
external_plugin_versions.update(supervisor.get_loaded_plugin_versions())
|
||||
return external_plugin_versions
|
||||
|
||||
def _find_supervisor_by_plugin_directory(self, plugin_id: str) -> Optional["PluginSupervisor"]:
|
||||
"""根据插件目录推断应负责该插件重载的 Supervisor。"""
|
||||
|
||||
for supervisor in self.supervisors:
|
||||
if self._get_plugin_path_for_supervisor(supervisor, plugin_id) is not None:
|
||||
return supervisor
|
||||
return None
|
||||
|
||||
def _warn_skipped_cross_supervisor_reload(
|
||||
self,
|
||||
requested_loaded_plugin_ids: Set[str],
|
||||
dependency_map: Dict[str, Set[str]],
|
||||
supervisor_by_plugin: Dict[str, "PluginSupervisor"],
|
||||
) -> None:
|
||||
"""记录因跨 Supervisor 边界而未参与联动重载的插件。"""
|
||||
|
||||
if not requested_loaded_plugin_ids:
|
||||
return
|
||||
|
||||
handled_plugin_ids: Set[str] = set()
|
||||
for supervisor in self.supervisors:
|
||||
local_requested_plugin_ids = {
|
||||
plugin_id
|
||||
for plugin_id in requested_loaded_plugin_ids
|
||||
if supervisor_by_plugin.get(plugin_id) is supervisor
|
||||
}
|
||||
if not local_requested_plugin_ids:
|
||||
continue
|
||||
|
||||
local_plugin_ids = set(supervisor.get_loaded_plugin_ids())
|
||||
local_dependency_map = {
|
||||
plugin_id: {
|
||||
dependency
|
||||
for dependency in dependency_map.get(plugin_id, set())
|
||||
if dependency in local_plugin_ids
|
||||
}
|
||||
for plugin_id in local_plugin_ids
|
||||
}
|
||||
handled_plugin_ids.update(
|
||||
self._collect_reverse_dependents(local_requested_plugin_ids, local_dependency_map)
|
||||
)
|
||||
|
||||
impacted_plugin_ids = self._collect_reverse_dependents(requested_loaded_plugin_ids, dependency_map)
|
||||
skipped_plugin_ids = sorted(impacted_plugin_ids - handled_plugin_ids)
|
||||
if not skipped_plugin_ids:
|
||||
return
|
||||
|
||||
logger.warning(
|
||||
f"插件 {', '.join(sorted(requested_loaded_plugin_ids))} 存在跨 Supervisor 依赖方未联动重载: "
|
||||
f"{', '.join(skipped_plugin_ids)}。当前仅在单个 Supervisor 内执行联动重载;"
|
||||
"跨 Supervisor API 调用仍然可用。如需联动重载,请将相关插件放在同一个 Supervisor 内。"
|
||||
)
|
||||
|
||||
async def reload_plugins_globally(self, plugin_ids: Sequence[str], reason: str = "manual") -> bool:
|
||||
"""按 Supervisor 分组执行精确重载。
|
||||
|
||||
仅在单个 Supervisor 内执行依赖联动;跨 Supervisor 依赖方仅记录告警,
|
||||
不再自动参与本次热重载。
|
||||
"""
|
||||
|
||||
normalized_plugin_ids = [
|
||||
normalized_plugin_id
|
||||
for plugin_id in plugin_ids
|
||||
if (normalized_plugin_id := str(plugin_id or "").strip())
|
||||
]
|
||||
if not normalized_plugin_ids:
|
||||
return True
|
||||
|
||||
dependency_map = self._build_registered_dependency_map()
|
||||
supervisor_by_plugin = self._build_registered_supervisor_map()
|
||||
supervisor_roots: Dict["PluginSupervisor", List[str]] = {}
|
||||
requested_loaded_plugin_ids: Set[str] = set()
|
||||
missing_plugin_ids: List[str] = []
|
||||
|
||||
for plugin_id in normalized_plugin_ids:
|
||||
supervisor = supervisor_by_plugin.get(plugin_id)
|
||||
if supervisor is not None:
|
||||
requested_loaded_plugin_ids.add(plugin_id)
|
||||
else:
|
||||
supervisor = self._find_supervisor_by_plugin_directory(plugin_id)
|
||||
|
||||
if supervisor is None:
|
||||
missing_plugin_ids.append(plugin_id)
|
||||
continue
|
||||
|
||||
if plugin_id not in supervisor_roots.setdefault(supervisor, []):
|
||||
supervisor_roots[supervisor].append(plugin_id)
|
||||
|
||||
if missing_plugin_ids:
|
||||
logger.warning(f"以下插件未找到可重载的 Supervisor,已跳过: {', '.join(sorted(missing_plugin_ids))}")
|
||||
|
||||
self._warn_skipped_cross_supervisor_reload(
|
||||
requested_loaded_plugin_ids=requested_loaded_plugin_ids,
|
||||
dependency_map=dependency_map,
|
||||
supervisor_by_plugin=supervisor_by_plugin,
|
||||
)
|
||||
|
||||
success = True
|
||||
for supervisor, root_plugin_ids in supervisor_roots.items():
|
||||
if not root_plugin_ids:
|
||||
continue
|
||||
|
||||
reloaded = await supervisor.reload_plugins(
|
||||
plugin_ids=root_plugin_ids,
|
||||
reason=reason,
|
||||
external_available_plugins=self._build_external_available_plugins_for_supervisor(supervisor),
|
||||
)
|
||||
success = success and reloaded
|
||||
|
||||
return success and not missing_plugin_ids
|
||||
|
||||
async def notify_plugin_config_updated(
|
||||
self,
|
||||
plugin_id: str,
|
||||
config_data: Optional[Dict[str, Any]] = None,
|
||||
config_version: str = "",
|
||||
config_scope: str = "self",
|
||||
) -> bool:
|
||||
"""向拥有该插件的 Supervisor 推送配置更新事件。
|
||||
|
||||
@@ -193,6 +466,7 @@ class PluginRuntimeManager(
|
||||
plugin_id: 插件 ID
|
||||
config_data: 可选的配置数据(如果为 None 则由 Supervisor 从磁盘加载)
|
||||
config_version: 可选的配置版本字符串,供 Supervisor 进行版本控制
|
||||
config_scope: 配置变更范围。
|
||||
"""
|
||||
if not self._started:
|
||||
return False
|
||||
@@ -209,23 +483,78 @@ class PluginRuntimeManager(
|
||||
config_payload = (
|
||||
config_data
|
||||
if config_data is not None
|
||||
else self._load_plugin_config_for_supervisor(plugin_id, plugin_dirs=sv._plugin_dirs)
|
||||
else self._load_plugin_config_for_supervisor(sv, plugin_id)
|
||||
)
|
||||
await sv.notify_plugin_config_updated(
|
||||
return await sv.notify_plugin_config_updated(
|
||||
plugin_id=plugin_id,
|
||||
config_data=config_payload,
|
||||
config_version=config_version,
|
||||
config_scope=config_scope,
|
||||
)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _normalize_config_reload_scopes(changed_scopes: Sequence[str]) -> tuple[str, ...]:
|
||||
"""规范化配置热重载范围列表。
|
||||
|
||||
Args:
|
||||
changed_scopes: 原始配置热重载范围列表。
|
||||
|
||||
Returns:
|
||||
tuple[str, ...]: 去重后的有效配置范围元组。
|
||||
"""
|
||||
|
||||
normalized_scopes: list[str] = []
|
||||
for scope in changed_scopes:
|
||||
normalized_scope = str(scope or "").strip().lower()
|
||||
if normalized_scope not in {"bot", "model"}:
|
||||
continue
|
||||
if normalized_scope not in normalized_scopes:
|
||||
normalized_scopes.append(normalized_scope)
|
||||
return tuple(normalized_scopes)
|
||||
|
||||
async def _broadcast_config_reload(self, scope: str, config_data: Dict[str, Any]) -> None:
|
||||
"""向订阅指定范围的插件广播配置热重载。
|
||||
|
||||
Args:
|
||||
scope: 配置变更范围,仅支持 ``bot`` 或 ``model``。
|
||||
config_data: 最新配置数据。
|
||||
"""
|
||||
|
||||
for supervisor in self.supervisors:
|
||||
for plugin_id in supervisor.get_config_reload_subscribers(scope):
|
||||
delivered = await supervisor.notify_plugin_config_updated(
|
||||
plugin_id=plugin_id,
|
||||
config_data=config_data,
|
||||
config_version="",
|
||||
config_scope=scope,
|
||||
)
|
||||
if not delivered:
|
||||
logger.warning(f"向插件 {plugin_id} 广播 {scope} 配置热重载失败")
|
||||
|
||||
async def _handle_main_config_reload(self, changed_scopes: Sequence[str]) -> None:
|
||||
"""处理 bot/model 主配置热重载广播。
|
||||
|
||||
Args:
|
||||
changed_scopes: 本次热重载命中的配置范围列表。
|
||||
"""
|
||||
|
||||
if not self._started:
|
||||
return
|
||||
|
||||
normalized_scopes = self._normalize_config_reload_scopes(changed_scopes)
|
||||
if "bot" in normalized_scopes:
|
||||
await self._broadcast_config_reload("bot", config_manager.get_global_config().model_dump(mode="json"))
|
||||
if "model" in normalized_scopes:
|
||||
await self._broadcast_config_reload("model", config_manager.get_model_config().model_dump(mode="json"))
|
||||
|
||||
# ─── 事件桥接 ──────────────────────────────────────────────
|
||||
|
||||
async def bridge_event(
|
||||
self,
|
||||
event_type_value: str,
|
||||
message_dict: Optional[Dict[str, Any]] = None,
|
||||
message_dict: Optional[MessageDict] = None,
|
||||
extra_args: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[bool, Optional[Dict[str, Any]]]:
|
||||
) -> Tuple[bool, Optional[MessageDict]]:
|
||||
"""将事件分发到所有 Supervisor
|
||||
|
||||
Returns:
|
||||
@@ -235,17 +564,23 @@ class PluginRuntimeManager(
|
||||
return True, None
|
||||
|
||||
new_event_type: str = _EVENT_TYPE_MAP.get(event_type_value, event_type_value)
|
||||
modified: Optional[Dict[str, Any]] = None
|
||||
modified: Optional[MessageDict] = None
|
||||
current_message: Optional["SessionMessage"] = (
|
||||
PluginMessageUtils._build_session_message_from_dict(dict(message_dict))
|
||||
if message_dict is not None
|
||||
else None
|
||||
)
|
||||
|
||||
for sv in self.supervisors:
|
||||
try:
|
||||
cont, mod = await sv.dispatch_event(
|
||||
event_type=new_event_type,
|
||||
message=modified or message_dict,
|
||||
message=current_message,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
if mod is not None:
|
||||
modified = mod
|
||||
current_message = mod
|
||||
modified = PluginMessageUtils._session_message_to_dict(mod)
|
||||
if not cont:
|
||||
return False, modified
|
||||
except Exception as e:
|
||||
@@ -295,6 +630,37 @@ class PluginRuntimeManager(
|
||||
timeout_ms=timeout_ms,
|
||||
)
|
||||
|
||||
async def try_send_message_via_platform_io(
|
||||
self,
|
||||
message: "SessionMessage",
|
||||
) -> Optional[DeliveryBatch]:
|
||||
"""尝试通过 Platform IO 中间层发送消息。
|
||||
|
||||
Args:
|
||||
message: 待发送的内部会话消息。
|
||||
|
||||
Returns:
|
||||
Optional[DeliveryBatch]: 若当前消息命中了至少一条发送路由,则返回
|
||||
实际发送结果;若没有可用路由或 Platform IO 尚未启动,则返回 ``None``。
|
||||
"""
|
||||
if not self._started:
|
||||
return None
|
||||
|
||||
platform_io_manager = get_platform_io_manager()
|
||||
if not platform_io_manager.is_started:
|
||||
return None
|
||||
|
||||
try:
|
||||
route_key = platform_io_manager.build_route_key_from_message(message)
|
||||
except Exception as exc:
|
||||
logger.warning(f"根据消息构造 Platform IO 路由键失败: {exc}")
|
||||
return None
|
||||
|
||||
if not platform_io_manager.resolve_drivers(route_key):
|
||||
return None
|
||||
|
||||
return await platform_io_manager.send_message(message, route_key)
|
||||
|
||||
def _get_supervisors_for_plugin(self, plugin_id: str) -> List["PluginSupervisor"]:
|
||||
"""返回当前持有指定插件的所有 Supervisor。
|
||||
|
||||
@@ -314,30 +680,38 @@ class PluginRuntimeManager(
|
||||
raise RuntimeError(f"插件 {plugin_id} 同时存在于多个 Supervisor 中,无法安全路由")
|
||||
return matches[0] if matches else None
|
||||
|
||||
@staticmethod
|
||||
def _find_duplicate_plugin_ids(plugin_dirs: List[Path]) -> Dict[str, List[Path]]:
|
||||
async def load_plugin_globally(self, plugin_id: str, reason: str = "manual") -> bool:
|
||||
"""加载或重载单个插件,并为其补齐跨 Supervisor 外部依赖。"""
|
||||
|
||||
normalized_plugin_id = str(plugin_id or "").strip()
|
||||
if not normalized_plugin_id:
|
||||
return False
|
||||
|
||||
try:
|
||||
registered_supervisor = self._get_supervisor_for_plugin(normalized_plugin_id)
|
||||
except RuntimeError:
|
||||
return False
|
||||
|
||||
if registered_supervisor is not None:
|
||||
return await self.reload_plugins_globally([normalized_plugin_id], reason=reason)
|
||||
|
||||
supervisor = self._find_supervisor_by_plugin_directory(normalized_plugin_id)
|
||||
if supervisor is None:
|
||||
return False
|
||||
|
||||
return await supervisor.reload_plugins(
|
||||
plugin_ids=[normalized_plugin_id],
|
||||
reason=reason,
|
||||
external_available_plugins=self._build_external_available_plugins_for_supervisor(supervisor),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _find_duplicate_plugin_ids(cls, plugin_dirs: List[Path]) -> Dict[str, List[Path]]:
|
||||
"""扫描插件目录,找出被多个目录重复声明的插件 ID。"""
|
||||
plugin_locations: Dict[str, List[Path]] = {}
|
||||
for base_dir in plugin_dirs:
|
||||
if not base_dir.is_dir():
|
||||
continue
|
||||
for entry in base_dir.iterdir():
|
||||
if not entry.is_dir():
|
||||
continue
|
||||
manifest_path = entry / "_manifest.json"
|
||||
plugin_path = entry / "plugin.py"
|
||||
if not manifest_path.exists() or not plugin_path.exists():
|
||||
continue
|
||||
|
||||
plugin_id = entry.name
|
||||
try:
|
||||
with open(manifest_path, "r", encoding="utf-8") as manifest_file:
|
||||
manifest = json.load(manifest_file)
|
||||
plugin_id = str(manifest.get("name", entry.name)).strip() or entry.name
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
plugin_locations.setdefault(plugin_id, []).append(entry)
|
||||
validator = ManifestValidator()
|
||||
for plugin_path, manifest in validator.iter_plugin_manifests(plugin_dirs):
|
||||
plugin_locations.setdefault(manifest.id, []).append(plugin_path)
|
||||
|
||||
return {
|
||||
plugin_id: sorted(dict.fromkeys(paths), key=lambda p: str(p))
|
||||
@@ -370,6 +744,7 @@ class PluginRuntimeManager(
|
||||
async def _stop_plugin_file_watcher(self) -> None:
|
||||
"""停止插件文件监视器,并清理所有已注册订阅。"""
|
||||
if self._plugin_file_watcher is None:
|
||||
self._plugin_path_cache.clear()
|
||||
return
|
||||
for _plugin_id, (_config_path, subscription_id) in list(self._plugin_config_watcher_subscriptions.items()):
|
||||
self._plugin_file_watcher.unsubscribe(subscription_id)
|
||||
@@ -379,12 +754,79 @@ class PluginRuntimeManager(
|
||||
self._plugin_source_watcher_subscription_id = None
|
||||
await self._plugin_file_watcher.stop()
|
||||
self._plugin_file_watcher = None
|
||||
self._plugin_path_cache.clear()
|
||||
|
||||
def _iter_plugin_dirs(self) -> Iterable[Path]:
|
||||
"""迭代所有 Supervisor 当前管理的插件根目录。"""
|
||||
for supervisor in self.supervisors:
|
||||
yield from getattr(supervisor, "_plugin_dirs", [])
|
||||
|
||||
@staticmethod
|
||||
def _iter_candidate_plugin_paths(plugin_dirs: Iterable[Path]) -> Iterable[Path]:
|
||||
"""迭代所有可能的插件目录路径。
|
||||
|
||||
Args:
|
||||
plugin_dirs: 一个或多个插件根目录。
|
||||
|
||||
Yields:
|
||||
Path: 单个插件目录路径。
|
||||
"""
|
||||
for plugin_dir in plugin_dirs:
|
||||
plugin_root = Path(plugin_dir).resolve()
|
||||
if not plugin_root.is_dir():
|
||||
continue
|
||||
for entry in plugin_root.iterdir():
|
||||
if entry.is_dir():
|
||||
yield entry.resolve()
|
||||
|
||||
def _read_plugin_id_from_plugin_path(self, plugin_path: Path) -> Optional[str]:
|
||||
"""从单个插件目录中读取 manifest 声明的插件 ID。
|
||||
|
||||
Args:
|
||||
plugin_path: 单个插件目录路径。
|
||||
|
||||
Returns:
|
||||
Optional[str]: 解析成功时返回插件 ID,否则返回 ``None``。
|
||||
"""
|
||||
return self._manifest_validator.read_plugin_id_from_plugin_path(plugin_path)
|
||||
|
||||
def _iter_discovered_plugin_paths(self, plugin_dirs: Iterable[Path]) -> Iterable[Tuple[str, Path]]:
|
||||
"""迭代目录中可解析到的插件 ID 与实际目录路径。
|
||||
|
||||
Args:
|
||||
plugin_dirs: 一个或多个插件根目录。
|
||||
|
||||
Yields:
|
||||
Tuple[str, Path]: ``(plugin_id, plugin_path)`` 二元组。
|
||||
"""
|
||||
for plugin_path in self._iter_candidate_plugin_paths(plugin_dirs):
|
||||
if plugin_id := self._read_plugin_id_from_plugin_path(plugin_path):
|
||||
yield plugin_id, plugin_path
|
||||
|
||||
def _get_plugin_path_for_supervisor(self, supervisor: Any, plugin_id: str) -> Optional[Path]:
|
||||
"""为指定 Supervisor 定位某个插件的实际目录。
|
||||
|
||||
Args:
|
||||
supervisor: 目标 Supervisor。
|
||||
plugin_id: 插件 ID。
|
||||
|
||||
Returns:
|
||||
Optional[Path]: 插件目录路径;未找到时返回 ``None``。
|
||||
"""
|
||||
cached_path = self._plugin_path_cache.get(plugin_id)
|
||||
if cached_path is not None:
|
||||
for plugin_dir in getattr(supervisor, "_plugin_dirs", []):
|
||||
if self._plugin_dir_matches(cached_path, Path(plugin_dir)):
|
||||
return cached_path
|
||||
|
||||
for candidate_plugin_id, plugin_path in self._iter_discovered_plugin_paths(getattr(supervisor, "_plugin_dirs", [])):
|
||||
if candidate_plugin_id != plugin_id:
|
||||
continue
|
||||
self._plugin_path_cache[plugin_id] = plugin_path
|
||||
return plugin_path
|
||||
|
||||
return None
|
||||
|
||||
def _refresh_plugin_config_watch_subscriptions(self) -> None:
|
||||
"""按当前已注册插件集合刷新 config.toml 的单插件订阅。
|
||||
|
||||
@@ -394,7 +836,11 @@ class PluginRuntimeManager(
|
||||
if self._plugin_file_watcher is None:
|
||||
return
|
||||
|
||||
desired_config_paths = dict(self._iter_registered_plugin_config_paths())
|
||||
desired_plugin_paths = dict(self._iter_registered_plugin_paths())
|
||||
self._plugin_path_cache = desired_plugin_paths.copy()
|
||||
desired_config_paths = {
|
||||
plugin_id: plugin_path / "config.toml" for plugin_id, plugin_path in desired_plugin_paths.items()
|
||||
}
|
||||
|
||||
for plugin_id, (_old_path, subscription_id) in list(self._plugin_config_watcher_subscriptions.items()):
|
||||
if desired_config_paths.get(plugin_id) == self._plugin_config_watcher_subscriptions[plugin_id][0]:
|
||||
@@ -418,28 +864,35 @@ class PluginRuntimeManager(
|
||||
"""为指定插件生成配置文件变更回调。"""
|
||||
|
||||
async def _callback(changes: Sequence[FileChange]) -> None:
|
||||
"""将 watcher 事件转发到指定插件的配置处理逻辑。
|
||||
|
||||
Args:
|
||||
changes: 当前批次收集到的文件变更列表。
|
||||
"""
|
||||
await self._handle_plugin_config_changes(plugin_id, changes)
|
||||
|
||||
return _callback
|
||||
|
||||
def _iter_registered_plugin_config_paths(self) -> Iterable[Tuple[str, Path]]:
|
||||
"""迭代当前所有已注册插件的 config.toml 路径。"""
|
||||
def _iter_registered_plugin_paths(self) -> Iterable[Tuple[str, Path]]:
|
||||
"""迭代当前所有已注册插件的实际目录路径。"""
|
||||
for supervisor in self.supervisors:
|
||||
for plugin_id in getattr(supervisor, "_registered_plugins", {}).keys():
|
||||
if config_path := self._get_plugin_config_path_for_supervisor(supervisor, plugin_id):
|
||||
yield plugin_id, config_path
|
||||
if plugin_path := self._get_plugin_path_for_supervisor(supervisor, plugin_id):
|
||||
yield plugin_id, plugin_path
|
||||
|
||||
def _get_plugin_config_path_for_supervisor(self, supervisor: Any, plugin_id: str) -> Optional[Path]:
|
||||
"""从指定 Supervisor 的插件目录中定位某个插件的 config.toml。"""
|
||||
for plugin_dir in getattr(supervisor, "_plugin_dirs", []):
|
||||
plugin_dir = Path(plugin_dir)
|
||||
plugin_path = plugin_dir.resolve() / plugin_id
|
||||
if plugin_path.is_dir():
|
||||
return plugin_path / "config.toml"
|
||||
return None
|
||||
plugin_path = self._get_plugin_path_for_supervisor(supervisor, plugin_id)
|
||||
return None if plugin_path is None else plugin_path / "config.toml"
|
||||
|
||||
async def _handle_plugin_config_changes(self, plugin_id: str, changes: Sequence[FileChange]) -> None:
|
||||
"""处理单个插件配置文件变化,并仅向目标插件推送配置更新。"""
|
||||
"""处理单个插件配置文件变化,并定向派发自配置热更新。
|
||||
|
||||
Args:
|
||||
plugin_id: 发生配置变更的插件 ID。
|
||||
changes: 当前批次收集到的配置文件变更列表。
|
||||
|
||||
"""
|
||||
if not self._started or not changes:
|
||||
return
|
||||
|
||||
@@ -453,18 +906,24 @@ class PluginRuntimeManager(
|
||||
return
|
||||
|
||||
try:
|
||||
await supervisor.notify_plugin_config_updated(
|
||||
config_payload = self._load_plugin_config_for_supervisor(supervisor, plugin_id)
|
||||
delivered = await supervisor.notify_plugin_config_updated(
|
||||
plugin_id=plugin_id,
|
||||
config_data=self._load_plugin_config_for_supervisor(plugin_id, getattr(supervisor, "_plugin_dirs", [])),
|
||||
config_data=config_payload,
|
||||
config_version="",
|
||||
config_scope="self",
|
||||
)
|
||||
if not delivered:
|
||||
logger.warning(f"插件 {plugin_id} 配置文件变更后通知失败")
|
||||
except Exception as exc:
|
||||
logger.warning(f"插件 {plugin_id} 配置热更新通知失败: {exc}")
|
||||
logger.warning(f"插件 {plugin_id} 配置文件变更处理失败: {exc}")
|
||||
|
||||
async def _handle_plugin_source_changes(self, changes: Sequence[FileChange]) -> None:
|
||||
"""处理插件源码相关变化。
|
||||
|
||||
这里仅负责源码、清单等会影响插件装载状态的文件;配置文件的变化会由
|
||||
单独的 per-plugin watcher 处理,避免把单插件配置更新放大成全量 reload。
|
||||
单独的 per-plugin watcher 处理,并定向派发给目标插件的
|
||||
``on_config_update()``,避免放大成不必要的跨插件 reload。
|
||||
"""
|
||||
if not self._started or not changes:
|
||||
return
|
||||
@@ -477,7 +936,7 @@ class PluginRuntimeManager(
|
||||
logger.error(f"检测到重复插件 ID,跳过本次插件热重载: {details}")
|
||||
return
|
||||
|
||||
reload_supervisors: List[Any] = []
|
||||
changed_plugin_ids: List[str] = []
|
||||
changed_paths = [change.path.resolve() for change in changes]
|
||||
|
||||
for supervisor in self.supervisors:
|
||||
@@ -485,13 +944,12 @@ class PluginRuntimeManager(
|
||||
plugin_id = self._match_plugin_id_for_supervisor(supervisor, path)
|
||||
if plugin_id is None:
|
||||
continue
|
||||
if (path.name in {"plugin.py", "_manifest.json"} or path.suffix == ".py") and supervisor not in reload_supervisors:
|
||||
reload_supervisors.append(supervisor)
|
||||
if path.name in {"plugin.py", "_manifest.json"} or path.suffix == ".py":
|
||||
if plugin_id not in changed_plugin_ids:
|
||||
changed_plugin_ids.append(plugin_id)
|
||||
|
||||
for supervisor in reload_supervisors:
|
||||
await supervisor.reload_plugins(reason="file_watcher")
|
||||
|
||||
if reload_supervisors:
|
||||
if changed_plugin_ids:
|
||||
await self.reload_plugins_globally(changed_plugin_ids, reason="file_watcher")
|
||||
self._refresh_plugin_config_watch_subscriptions()
|
||||
|
||||
@staticmethod
|
||||
@@ -502,36 +960,47 @@ class PluginRuntimeManager(
|
||||
|
||||
def _match_plugin_id_for_supervisor(self, supervisor: Any, path: Path) -> Optional[str]:
|
||||
"""根据变更路径为指定 Supervisor 推断受影响的插件 ID。"""
|
||||
for plugin_id, _reg in getattr(supervisor, "_registered_plugins", {}).items():
|
||||
for plugin_dir in getattr(supervisor, "_plugin_dirs", []):
|
||||
plugin_dir = Path(plugin_dir)
|
||||
candidate_dir = plugin_dir.resolve() / plugin_id
|
||||
if path == candidate_dir or path.is_relative_to(candidate_dir):
|
||||
return plugin_id
|
||||
resolved_path = path.resolve()
|
||||
|
||||
for plugin_id in getattr(supervisor, "_registered_plugins", {}).keys():
|
||||
plugin_path = self._get_plugin_path_for_supervisor(supervisor, plugin_id)
|
||||
if plugin_path is not None and (resolved_path == plugin_path or resolved_path.is_relative_to(plugin_path)):
|
||||
return plugin_id
|
||||
|
||||
for plugin_id, plugin_path in self._plugin_path_cache.items():
|
||||
if not any(self._plugin_dir_matches(plugin_path, Path(plugin_dir)) for plugin_dir in getattr(supervisor, "_plugin_dirs", [])):
|
||||
continue
|
||||
if resolved_path == plugin_path or resolved_path.is_relative_to(plugin_path):
|
||||
return plugin_id
|
||||
|
||||
for plugin_id, plugin_path in self._iter_discovered_plugin_paths(getattr(supervisor, "_plugin_dirs", [])):
|
||||
if resolved_path == plugin_path or resolved_path.is_relative_to(plugin_path):
|
||||
self._plugin_path_cache[plugin_id] = plugin_path
|
||||
return plugin_id
|
||||
|
||||
for plugin_dir in getattr(supervisor, "_plugin_dirs", []):
|
||||
plugin_dir = Path(plugin_dir)
|
||||
plugin_root = plugin_dir.resolve()
|
||||
if self._plugin_dir_matches(path, plugin_dir) and (relative_parts := path.relative_to(plugin_root).parts):
|
||||
return relative_parts[0]
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _load_plugin_config_for_supervisor(plugin_id: str, plugin_dirs: Iterable[Path]) -> Dict[str, Any]:
|
||||
def _load_plugin_config_for_supervisor(self, supervisor: Any, plugin_id: str) -> Dict[str, Any]:
|
||||
"""从给定插件目录集合中读取目标插件的配置内容。"""
|
||||
for plugin_dir in plugin_dirs:
|
||||
plugin_path = plugin_dir.resolve() / plugin_id
|
||||
if plugin_path.is_dir():
|
||||
config_path = plugin_path / "config.toml"
|
||||
if not config_path.exists():
|
||||
return {}
|
||||
with open(config_path, "r", encoding="utf-8") as handle:
|
||||
return tomlkit.load(handle).unwrap()
|
||||
return {}
|
||||
plugin_path = self._get_plugin_path_for_supervisor(supervisor, plugin_id)
|
||||
if plugin_path is None:
|
||||
return {}
|
||||
|
||||
config_path = plugin_path / "config.toml"
|
||||
if not config_path.exists():
|
||||
return {}
|
||||
|
||||
with open(config_path, "r", encoding="utf-8") as handle:
|
||||
return tomlkit.load(handle).unwrap()
|
||||
|
||||
# ─── 能力实现注册 ──────────────────────────────────────────
|
||||
|
||||
def _register_capability_impls(self, supervisor: "PluginSupervisor") -> None:
|
||||
"""向指定 Supervisor 注册主程序能力实现。
|
||||
|
||||
Args:
|
||||
supervisor: 需要注册能力实现的目标 Supervisor。
|
||||
"""
|
||||
register_capability_impls(self, supervisor)
|
||||
|
||||
|
||||
|
||||
@@ -7,52 +7,52 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import logging as stdlib_logging
|
||||
import time
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# ─── 协议常量 ──────────────────────────────────────────────────────
|
||||
|
||||
PROTOCOL_VERSION = "1.0"
|
||||
|
||||
# ====== 协议常量 ======
|
||||
PROTOCOL_VERSION = "1.0.0"
|
||||
# 支持的 SDK 版本范围(Host 在握手时校验)
|
||||
MIN_SDK_VERSION = "1.0.0"
|
||||
MAX_SDK_VERSION = "1.99.99"
|
||||
|
||||
|
||||
# ─── 消息类型 ──────────────────────────────────────────────────────
|
||||
MAX_SDK_VERSION = "2.99.99"
|
||||
|
||||
|
||||
# ====== 消息类型 ======
|
||||
class MessageType(str, Enum):
|
||||
"""RPC 消息类型"""
|
||||
|
||||
REQUEST = "request"
|
||||
RESPONSE = "response"
|
||||
EVENT = "event"
|
||||
BROADCAST = "broadcast"
|
||||
|
||||
|
||||
# ─── 请求 ID 生成器 ───────────────────────────────────────────────
|
||||
class ConfigReloadScope(str, Enum):
|
||||
"""配置热重载范围。"""
|
||||
|
||||
SELF = "self"
|
||||
BOT = "bot"
|
||||
MODEL = "model"
|
||||
|
||||
|
||||
# ====== 请求 ID 生成器 ======
|
||||
class RequestIdGenerator:
|
||||
"""单调递增 int64 请求 ID 生成器(线程安全由调用方保证或使用 asyncio)"""
|
||||
"""单调递增 int64 请求 ID 生成器"""
|
||||
|
||||
def __init__(self, start: int = 1) -> None:
|
||||
self._counter = start
|
||||
|
||||
def next(self) -> int:
|
||||
async def next(self) -> int:
|
||||
current = self._counter
|
||||
self._counter += 1
|
||||
return current
|
||||
|
||||
|
||||
# ─── Envelope 模型 ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
# ====== Envelope 模型 ======
|
||||
class Envelope(BaseModel):
|
||||
"""RPC 统一信封
|
||||
"""RPC 统一消息封装
|
||||
|
||||
所有 Host <-> Runner 消息均封装为此格式。
|
||||
序列化流程:Envelope -> .model_dump() -> MsgPack encode
|
||||
@@ -60,15 +60,23 @@ class Envelope(BaseModel):
|
||||
"""
|
||||
|
||||
protocol_version: str = Field(default=PROTOCOL_VERSION, description="协议版本")
|
||||
"""协议版本"""
|
||||
request_id: int = Field(description="单调递增请求 ID")
|
||||
"""单调递增请求 ID"""
|
||||
message_type: MessageType = Field(description="消息类型")
|
||||
"""消息类型"""
|
||||
method: str = Field(default="", description="RPC 方法名")
|
||||
"""RPC 方法名"""
|
||||
plugin_id: str = Field(default="", description="目标插件 ID")
|
||||
timestamp_ms: int = Field(default_factory=lambda: int(time.time() * 1000), description="发送时间戳(ms)")
|
||||
timeout_ms: int = Field(default=30000, description="相对超时(ms)")
|
||||
generation: int = Field(default=0, description="Runner generation 编号")
|
||||
"""目标插件 ID"""
|
||||
timestamp_ms: int = Field(default_factory=lambda: int(time.time() * 1000), description="发送时间戳 (ms)")
|
||||
"""发送时间戳 (ms)"""
|
||||
timeout_ms: int = Field(default=30000, description="相对超时 (ms)")
|
||||
"""相对超时 (ms)"""
|
||||
payload: Dict[str, Any] = Field(default_factory=dict, description="业务数据")
|
||||
error: Optional[Dict[str, Any]] = Field(default=None, description="错误信息(仅 response)")
|
||||
"""业务数据"""
|
||||
error: Optional[Dict[str, Any]] = Field(default=None, description="错误信息 (仅 response)")
|
||||
"""错误信息 (仅 response)"""
|
||||
|
||||
def is_request(self) -> bool:
|
||||
return self.message_type == MessageType.REQUEST
|
||||
@@ -76,8 +84,8 @@ class Envelope(BaseModel):
|
||||
def is_response(self) -> bool:
|
||||
return self.message_type == MessageType.RESPONSE
|
||||
|
||||
def is_event(self) -> bool:
|
||||
return self.message_type == MessageType.EVENT
|
||||
def is_broadcast(self) -> bool:
|
||||
return self.message_type == MessageType.BROADCAST
|
||||
|
||||
def make_response(
|
||||
self, payload: Optional[Dict[str, Any]] = None, error: Optional[Dict[str, Any]] = None
|
||||
@@ -89,7 +97,6 @@ class Envelope(BaseModel):
|
||||
message_type=MessageType.RESPONSE,
|
||||
method=self.method,
|
||||
plugin_id=self.plugin_id,
|
||||
generation=self.generation,
|
||||
payload=payload or {},
|
||||
error=error,
|
||||
)
|
||||
@@ -105,153 +112,302 @@ class Envelope(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
# ─── 握手消息 ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
# ====== 握手请求与响应 ======
|
||||
class HelloPayload(BaseModel):
|
||||
"""runner.hello 握手请求 payload"""
|
||||
|
||||
runner_id: str = Field(description="Runner 进程唯一标识")
|
||||
"""Runner 进程唯一标识"""
|
||||
sdk_version: str = Field(description="SDK 版本号")
|
||||
"""SDK 版本号"""
|
||||
session_token: str = Field(description="一次性会话令牌")
|
||||
"""一次性会话令牌"""
|
||||
|
||||
|
||||
class HelloResponsePayload(BaseModel):
|
||||
"""runner.hello 握手响应 payload"""
|
||||
|
||||
accepted: bool = Field(description="是否接受连接")
|
||||
"""是否接受连接"""
|
||||
host_version: str = Field(default="", description="Host 版本号")
|
||||
assigned_generation: int = Field(default=0, description="分配的 generation 编号")
|
||||
reason: str = Field(default="", description="拒绝原因(若 accepted=False)")
|
||||
|
||||
|
||||
# ─── 组件注册消息 ──────────────────────────────────────────────────
|
||||
"""Host 版本号"""
|
||||
reason: str = Field(default="", description="拒绝原因 (若 accepted=False)")
|
||||
"""拒绝原因 (若 `accepted`=`False`)"""
|
||||
|
||||
|
||||
# ====== 组件注册消息 ======
|
||||
class ComponentDeclaration(BaseModel):
|
||||
"""单个组件声明"""
|
||||
|
||||
name: str = Field(description="组件名称")
|
||||
component_type: str = Field(description="组件类型: action/command/tool/event_handler")
|
||||
"""组件名称"""
|
||||
component_type: str = Field(
|
||||
description="组件类型:action/command/tool/event_handler/hook_handler/message_gateway"
|
||||
)
|
||||
"""组件类型:`action`/`command`/`tool`/`event_handler`/`hook_handler`/`message_gateway`"""
|
||||
plugin_id: str = Field(description="所属插件 ID")
|
||||
"""所属插件 ID"""
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="组件元数据")
|
||||
"""组件元数据"""
|
||||
|
||||
|
||||
class RegisterComponentsPayload(BaseModel):
|
||||
"""plugin.register_components 请求 payload"""
|
||||
class RegisterPluginPayload(BaseModel):
|
||||
"""插件组件注册请求载荷。
|
||||
|
||||
该模型同时用于 ``plugin.register_components`` 与兼容旧命名的
|
||||
``plugin.register_plugin`` 请求。
|
||||
"""
|
||||
|
||||
plugin_id: str = Field(description="插件 ID")
|
||||
"""插件 ID"""
|
||||
plugin_version: str = Field(default="1.0.0", description="插件版本")
|
||||
"""插件版本"""
|
||||
components: List[ComponentDeclaration] = Field(default_factory=list, description="组件列表")
|
||||
"""组件列表"""
|
||||
capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表")
|
||||
"""所需能力列表"""
|
||||
dependencies: List[str] = Field(default_factory=list, description="插件级依赖插件 ID 列表")
|
||||
"""插件级依赖插件 ID 列表"""
|
||||
config_reload_subscriptions: List[str] = Field(default_factory=list, description="订阅的全局配置热重载范围")
|
||||
"""订阅的全局配置热重载范围"""
|
||||
|
||||
|
||||
class BootstrapPluginPayload(BaseModel):
|
||||
"""plugin.bootstrap 请求 payload"""
|
||||
|
||||
plugin_id: str = Field(description="插件 ID")
|
||||
"""插件 ID"""
|
||||
plugin_version: str = Field(default="1.0.0", description="插件版本")
|
||||
"""插件版本"""
|
||||
capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表")
|
||||
"""所需能力列表"""
|
||||
|
||||
|
||||
# ─── 调用消息 ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
# ====== 插件调用请求和响应 ======
|
||||
class InvokePayload(BaseModel):
|
||||
"""plugin.invoke_* 请求 payload"""
|
||||
"""plugin.invoke.* 请求 payload"""
|
||||
|
||||
component_name: str = Field(description="要调用的组件名称")
|
||||
"""要调用的组件名称"""
|
||||
args: Dict[str, Any] = Field(default_factory=dict, description="调用参数")
|
||||
"""调用参数"""
|
||||
|
||||
|
||||
class InvokeResultPayload(BaseModel):
|
||||
"""plugin.invoke_* 响应 payload"""
|
||||
"""plugin.invoke.* 响应 payload"""
|
||||
|
||||
success: bool = Field(description="是否成功")
|
||||
"""是否成功"""
|
||||
result: Any = Field(default=None, description="返回值")
|
||||
"""返回值"""
|
||||
|
||||
|
||||
# ─── 能力调用消息 ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
# ====== 能力调用消息 ======
|
||||
class CapabilityRequestPayload(BaseModel):
|
||||
"""cap.* 请求 payload(插件 -> Host 能力调用)"""
|
||||
|
||||
capability: str = Field(description="能力名称,如 send.text, db.query")
|
||||
"""能力名称,如 send.text, db.query"""
|
||||
args: Dict[str, Any] = Field(default_factory=dict, description="调用参数")
|
||||
"""调用参数"""
|
||||
|
||||
|
||||
class CapabilityResponsePayload(BaseModel):
|
||||
"""cap.* 响应 payload"""
|
||||
|
||||
success: bool = Field(description="是否成功")
|
||||
"""是否成功"""
|
||||
result: Any = Field(default=None, description="返回值")
|
||||
"""返回值"""
|
||||
|
||||
|
||||
# ─── 健康检查 ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
# ====== 健康检查 ======
|
||||
class HealthPayload(BaseModel):
|
||||
"""plugin.health 响应 payload"""
|
||||
|
||||
healthy: bool = Field(description="是否健康")
|
||||
"""是否健康"""
|
||||
loaded_plugins: List[str] = Field(default_factory=list, description="已加载的插件列表")
|
||||
uptime_ms: int = Field(default=0, description="运行时长(ms)")
|
||||
"""已加载的插件列表"""
|
||||
uptime_ms: int = Field(default=0, description="运行时长 (ms)")
|
||||
"""运行时长 (ms)"""
|
||||
|
||||
|
||||
class RunnerReadyPayload(BaseModel):
|
||||
"""runner.ready 请求 payload"""
|
||||
|
||||
loaded_plugins: List[str] = Field(default_factory=list, description="已完成初始化的插件列表")
|
||||
"""已完成初始化的插件列表"""
|
||||
failed_plugins: List[str] = Field(default_factory=list, description="初始化失败的插件列表")
|
||||
"""初始化失败的插件列表"""
|
||||
|
||||
|
||||
# ─── 配置更新 ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
# Host 侧现已支持配置更新推送:
|
||||
# - 总配置热重载完成后,PluginRuntimeManager 会向已加载插件推送配置更新事件。
|
||||
# - 插件目录下的 config.toml 变化由现有 FileWatcher 监听并转发为 plugin.config_updated。
|
||||
# ====== 配置更新 ======
|
||||
class ConfigUpdatedPayload(BaseModel):
|
||||
"""plugin.config_updated 事件 payload"""
|
||||
|
||||
plugin_id: str = Field(description="插件 ID")
|
||||
"""插件 ID"""
|
||||
config_scope: ConfigReloadScope = Field(description="配置变更范围")
|
||||
"""配置变更范围"""
|
||||
config_version: str = Field(description="新配置版本")
|
||||
"""新配置版本"""
|
||||
config_data: Dict[str, Any] = Field(default_factory=dict, description="配置内容")
|
||||
"""配置内容"""
|
||||
|
||||
|
||||
# ─── 关停 ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
# ====== 关停 ======
|
||||
class ShutdownPayload(BaseModel):
|
||||
"""plugin.shutdown / plugin.prepare_shutdown payload"""
|
||||
|
||||
reason: str = Field(default="normal", description="关停原因")
|
||||
drain_timeout_ms: int = Field(default=5000, description="排空超时(ms)")
|
||||
"""关停原因"""
|
||||
drain_timeout_ms: int = Field(default=5000, description="排空超时 (ms)")
|
||||
"""排空超时 (ms)"""
|
||||
|
||||
|
||||
# ─── 日志传输 ──────────────────────────────────────────────────────
|
||||
class UnregisterPluginPayload(BaseModel):
|
||||
"""插件注销请求载荷。"""
|
||||
|
||||
plugin_id: str = Field(description="插件 ID")
|
||||
"""插件 ID"""
|
||||
reason: str = Field(default="manual", description="注销原因")
|
||||
"""注销原因"""
|
||||
|
||||
|
||||
class ReloadPluginPayload(BaseModel):
|
||||
"""插件重载请求载荷。"""
|
||||
|
||||
plugin_id: str = Field(description="目标插件 ID")
|
||||
"""目标插件 ID"""
|
||||
reason: str = Field(default="manual", description="重载原因")
|
||||
"""重载原因"""
|
||||
external_available_plugins: Dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="可视为已满足的外部依赖插件版本映射",
|
||||
)
|
||||
"""可视为已满足的外部依赖插件版本映射"""
|
||||
|
||||
|
||||
class ReloadPluginsPayload(BaseModel):
|
||||
"""批量插件重载请求载荷。"""
|
||||
|
||||
plugin_ids: List[str] = Field(default_factory=list, description="目标插件 ID 列表")
|
||||
"""目标插件 ID 列表"""
|
||||
reason: str = Field(default="manual", description="重载原因")
|
||||
"""重载原因"""
|
||||
external_available_plugins: Dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="可视为已满足的外部依赖插件版本映射",
|
||||
)
|
||||
"""可视为已满足的外部依赖插件版本映射"""
|
||||
|
||||
|
||||
class ReloadPluginResultPayload(BaseModel):
|
||||
"""插件重载结果载荷。"""
|
||||
|
||||
success: bool = Field(description="是否重载成功")
|
||||
"""是否重载成功"""
|
||||
requested_plugin_id: str = Field(description="请求重载的插件 ID")
|
||||
"""请求重载的插件 ID"""
|
||||
reloaded_plugins: List[str] = Field(default_factory=list, description="成功完成重载的插件列表")
|
||||
"""成功完成重载的插件列表"""
|
||||
unloaded_plugins: List[str] = Field(default_factory=list, description="本次已卸载的插件列表")
|
||||
"""本次已卸载的插件列表"""
|
||||
failed_plugins: Dict[str, str] = Field(default_factory=dict, description="重载失败的插件及原因")
|
||||
"""重载失败的插件及原因"""
|
||||
|
||||
|
||||
class ReloadPluginsResultPayload(BaseModel):
|
||||
"""批量插件重载结果载荷。"""
|
||||
|
||||
success: bool = Field(description="是否重载成功")
|
||||
"""是否重载成功"""
|
||||
requested_plugin_ids: List[str] = Field(default_factory=list, description="请求重载的插件 ID 列表")
|
||||
"""请求重载的插件 ID 列表"""
|
||||
reloaded_plugins: List[str] = Field(default_factory=list, description="成功完成重载的插件列表")
|
||||
"""成功完成重载的插件列表"""
|
||||
unloaded_plugins: List[str] = Field(default_factory=list, description="本次已卸载的插件列表")
|
||||
"""本次已卸载的插件列表"""
|
||||
failed_plugins: Dict[str, str] = Field(default_factory=dict, description="重载失败的插件及原因")
|
||||
"""重载失败的插件及原因"""
|
||||
|
||||
|
||||
class MessageGatewayStateUpdatePayload(BaseModel):
|
||||
"""消息网关运行时状态更新载荷。"""
|
||||
|
||||
gateway_name: str = Field(description="消息网关组件名称")
|
||||
"""消息网关组件名称"""
|
||||
ready: bool = Field(description="当前链路是否已经就绪")
|
||||
"""当前链路是否已经就绪"""
|
||||
platform: str = Field(default="", description="当前链路负责的平台名称")
|
||||
"""当前链路负责的平台名称"""
|
||||
account_id: str = Field(default="", description="当前链路对应的账号 ID 或 self_id")
|
||||
"""当前链路对应的账号 ID 或 self_id"""
|
||||
scope: str = Field(default="", description="当前链路对应的可选路由作用域")
|
||||
"""当前链路对应的可选路由作用域"""
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="可选的运行时状态元数据")
|
||||
"""可选的运行时状态元数据"""
|
||||
|
||||
|
||||
class MessageGatewayStateUpdateResultPayload(BaseModel):
|
||||
"""消息网关运行时状态更新结果载荷。"""
|
||||
|
||||
accepted: bool = Field(description="Host 是否接受了本次状态更新")
|
||||
"""Host 是否接受了本次状态更新"""
|
||||
ready: bool = Field(description="Host 记录的当前就绪状态")
|
||||
"""Host 记录的当前就绪状态"""
|
||||
route_key: Dict[str, Any] = Field(default_factory=dict, description="当前生效的路由键")
|
||||
"""当前生效的路由键"""
|
||||
|
||||
|
||||
class RouteMessagePayload(BaseModel):
|
||||
"""消息网关向 Host 路由外部消息的请求载荷。"""
|
||||
|
||||
gateway_name: str = Field(description="接收消息的网关组件名称")
|
||||
"""接收消息的网关组件名称"""
|
||||
message: Dict[str, Any] = Field(description="符合 MessageDict 结构的标准消息字典")
|
||||
"""符合 MessageDict 结构的标准消息字典"""
|
||||
route_metadata: Dict[str, Any] = Field(default_factory=dict, description="可选的路由辅助元数据")
|
||||
"""可选的路由辅助元数据"""
|
||||
external_message_id: str = Field(default="", description="可选的外部平台消息 ID")
|
||||
"""可选的外部平台消息 ID"""
|
||||
dedupe_key: str = Field(default="", description="可选的显式去重键")
|
||||
"""可选的显式去重键"""
|
||||
|
||||
|
||||
class ReceiveExternalMessageResultPayload(BaseModel):
|
||||
"""外部消息注入结果载荷。"""
|
||||
|
||||
accepted: bool = Field(description="Host 是否接受了本次消息注入")
|
||||
"""Host 是否接受了本次消息注入"""
|
||||
route_key: Dict[str, Any] = Field(default_factory=dict, description="本次消息使用的归一路由键")
|
||||
"""本次消息使用的归一路由键"""
|
||||
|
||||
|
||||
RegisterPluginPayload.model_rebuild()
|
||||
|
||||
|
||||
# ====== 日志传输 ======
|
||||
|
||||
|
||||
class LogEntry(BaseModel):
|
||||
"""单条日志记录(Runner → Host 传输格式)"""
|
||||
|
||||
timestamp_ms: int = Field(
|
||||
description="日志时间戳,Unix epoch 毫秒",
|
||||
)
|
||||
level: int = Field(
|
||||
description=("stdlib logging 整数级别: 10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL"),
|
||||
)
|
||||
logger_name: str = Field(
|
||||
description="Logger 名称,如 plugin.my_plugin.submodule",
|
||||
)
|
||||
message: str = Field(
|
||||
description="经 Formatter 格式化后的完整日志消息(含 exc_info 文本)",
|
||||
)
|
||||
timestamp_ms: int = Field(description="日志时间戳,Unix epoch 毫秒")
|
||||
"""日志时间戳,Unix epoch 毫秒"""
|
||||
level: int = Field(description="stdlib logging 整数级别:10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL")
|
||||
"""stdlib logging 整数级别:10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL"""
|
||||
logger_name: str = Field(description="Logger 名称,如 plugin.my_plugin.submodule")
|
||||
"""Logger 名称,如 plugin.my_plugin.submodule"""
|
||||
message: str = Field(description="经 Formatter 格式化后的完整日志消息(含 exc_info 文本)")
|
||||
"""经 Formatter 格式化后的完整日志消息(含 exc_info 文本)"""
|
||||
exception_text: str = Field(
|
||||
default="",
|
||||
description="原始异常摘要(exc_text),供结构化消费;已嵌入 message 中",
|
||||
)
|
||||
"""原始异常摘要(exc_text),供结构化消费;已嵌入 message 中"""
|
||||
log_color_in_hex: Optional[str] = Field(default=None, description="日志颜色的十六进制字符串(如 #RRGGBB)")
|
||||
|
||||
@property
|
||||
def levelname(self) -> str:
|
||||
@@ -262,6 +418,5 @@ class LogEntry(BaseModel):
|
||||
class LogBatchPayload(BaseModel):
|
||||
"""runner.log_batch 事件 payload:Runner 端向 Host 批量推送日志记录"""
|
||||
|
||||
entries: List[LogEntry] = Field(
|
||||
description="本批次日志记录列表,按时间升序排列",
|
||||
)
|
||||
entries: List[LogEntry] = Field(description="本批次日志记录列表,按时间升序排列")
|
||||
"""本批次日志记录列表,按时间升序排列"""
|
||||
|
||||
@@ -18,17 +18,17 @@ class ErrorCode(str, Enum):
|
||||
E_TIMEOUT = "E_TIMEOUT"
|
||||
E_BAD_PAYLOAD = "E_BAD_PAYLOAD"
|
||||
E_PROTOCOL_MISMATCH = "E_PROTOCOL_MISMATCH"
|
||||
E_SHUTTING_DOWN = "E_SHUTTING_DOWN"
|
||||
|
||||
# 权限与策略
|
||||
E_UNAUTHORIZED = "E_UNAUTHORIZED"
|
||||
E_METHOD_NOT_ALLOWED = "E_METHOD_NOT_ALLOWED"
|
||||
E_BACKPRESSURE = "E_BACKPRESSURE"
|
||||
E_BACK_PRESSURE = "E_BACK_PRESSURE"
|
||||
E_HOST_OVERLOADED = "E_HOST_OVERLOADED"
|
||||
|
||||
# 插件生命周期
|
||||
E_PLUGIN_CRASHED = "E_PLUGIN_CRASHED"
|
||||
E_PLUGIN_NOT_FOUND = "E_PLUGIN_NOT_FOUND"
|
||||
E_GENERATION_MISMATCH = "E_GENERATION_MISMATCH"
|
||||
E_RELOAD_IN_PROGRESS = "E_RELOAD_IN_PROGRESS"
|
||||
|
||||
# 能力调用
|
||||
@@ -65,3 +65,13 @@ class RPCError(Exception):
|
||||
message=data.get("message", ""),
|
||||
details=data.get("details", {}),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_exception(cls, exception: Exception, code_mapping: Optional[Dict[type[Exception], ErrorCode]] = None):
|
||||
if isinstance(exception, cls):
|
||||
return exception
|
||||
if code_mapping:
|
||||
for exception_type, code in code_mapping.items():
|
||||
if isinstance(exception, exception_type):
|
||||
return cls(code=code, message=str(exception))
|
||||
return cls(ErrorCode.E_UNKNOWN, str(exception))
|
||||
|
||||
@@ -66,6 +66,12 @@ class RunnerIPCLogHandler(logging.Handler):
|
||||
ALLOWED_LOGGER_PREFIXES: tuple[str, ...] = ("plugin.", "plugin_runtime.", "_maibot_plugin_")
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化 Runner 端日志转发处理器。
|
||||
|
||||
创建有界日志缓冲区,并准备与 RPC 客户端绑定的后台刷新任务。
|
||||
此时不会启动任何异步任务;真正开始转发要等到 :meth:`start`
|
||||
被调用后才会发生。
|
||||
"""
|
||||
super().__init__()
|
||||
# deque(maxlen=N): append/popleft 在 CPython GIL 保护下线程安全
|
||||
self._buffer: collections.deque[LogEntry] = collections.deque(maxlen=self.QUEUE_MAX)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -13,16 +13,16 @@ from typing import Any, Dict, Iterator, List, Optional, Set, Tuple
|
||||
import contextlib
|
||||
import importlib
|
||||
import importlib.util
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
|
||||
from src.plugin_runtime.runner.manifest_validator import ManifestValidator, PluginManifest
|
||||
|
||||
logger = get_logger("plugin_runtime.runner.plugin_loader")
|
||||
|
||||
PluginCandidate = Tuple[Path, Dict[str, Any], Path]
|
||||
PluginCandidate = Tuple[Path, PluginManifest, Path]
|
||||
|
||||
|
||||
class PluginMeta:
|
||||
@@ -32,28 +32,28 @@ class PluginMeta:
|
||||
self,
|
||||
plugin_id: str,
|
||||
plugin_dir: str,
|
||||
module_name: str,
|
||||
plugin_instance: Any,
|
||||
manifest: Dict[str, Any],
|
||||
manifest: PluginManifest,
|
||||
) -> None:
|
||||
"""初始化插件元数据。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID。
|
||||
plugin_dir: 插件目录绝对路径。
|
||||
module_name: 插件入口模块名。
|
||||
plugin_instance: 插件实例对象。
|
||||
manifest: 解析后的强类型 Manifest。
|
||||
"""
|
||||
self.plugin_id = plugin_id
|
||||
self.plugin_dir = plugin_dir
|
||||
self.module_name = module_name
|
||||
self.instance = plugin_instance
|
||||
self.manifest = manifest
|
||||
self.version = manifest.get("version", "1.0.0")
|
||||
self.capabilities_required = manifest.get("capabilities", [])
|
||||
self.dependencies: List[str] = self._extract_dependencies(manifest)
|
||||
|
||||
@staticmethod
|
||||
def _extract_dependencies(manifest: Dict[str, Any]) -> List[str]:
|
||||
raw = manifest.get("dependencies", [])
|
||||
result: List[str] = []
|
||||
for dep in raw:
|
||||
if isinstance(dep, str):
|
||||
result.append(dep.strip())
|
||||
elif isinstance(dep, dict):
|
||||
if name := str(dep.get("name", "")).strip():
|
||||
result.append(name)
|
||||
return result
|
||||
self.version = manifest.version
|
||||
self.capabilities_required = list(manifest.capabilities)
|
||||
self.dependencies: List[str] = list(manifest.plugin_dependency_ids)
|
||||
self.component_handlers: Dict[str, str] = {}
|
||||
|
||||
|
||||
class PluginLoader:
|
||||
@@ -66,30 +66,52 @@ class PluginLoader:
|
||||
"""
|
||||
|
||||
def __init__(self, host_version: str = "") -> None:
|
||||
"""初始化插件加载器。
|
||||
|
||||
Args:
|
||||
host_version: Host 版本号,用于 manifest 兼容性校验。
|
||||
"""
|
||||
self._loaded_plugins: Dict[str, PluginMeta] = {}
|
||||
self._failed_plugins: Dict[str, str] = {}
|
||||
self._manifest_validator = ManifestValidator(host_version=host_version)
|
||||
self._compat_hook_installed = False
|
||||
|
||||
def discover_and_load(self, plugin_dirs: List[str]) -> List[PluginMeta]:
|
||||
"""扫描多个目录并加载所有插件(含依赖排序和 manifest 校验)
|
||||
def discover_and_load(
|
||||
self,
|
||||
plugin_dirs: List[str],
|
||||
extra_available: Optional[Dict[str, str]] = None,
|
||||
) -> List[PluginMeta]:
|
||||
"""扫描多个目录并加载所有插件。
|
||||
|
||||
Args:
|
||||
plugin_dirs: 插件目录列表
|
||||
plugin_dirs: 插件目录列表。
|
||||
extra_available: 额外视为已满足的外部依赖插件版本映射。
|
||||
|
||||
Returns:
|
||||
成功加载的插件元数据列表(按依赖顺序)
|
||||
List[PluginMeta]: 成功加载的插件元数据列表,按依赖顺序排列。
|
||||
"""
|
||||
candidates, duplicate_candidates = self._discover_candidates(plugin_dirs)
|
||||
self._record_duplicate_candidates(duplicate_candidates)
|
||||
|
||||
# 第二阶段:依赖解析(拓扑排序)
|
||||
load_order, failed_deps = self._resolve_dependencies(candidates)
|
||||
load_order, failed_deps = self._resolve_dependencies(candidates, extra_available=extra_available)
|
||||
self._record_failed_dependencies(failed_deps)
|
||||
|
||||
# 第三阶段:按依赖顺序加载
|
||||
return self._load_plugins_in_order(load_order, candidates)
|
||||
|
||||
def discover_candidates(self, plugin_dirs: List[str]) -> Tuple[Dict[str, PluginCandidate], Dict[str, List[Path]]]:
|
||||
"""扫描插件目录并返回候选插件。
|
||||
|
||||
Args:
|
||||
plugin_dirs: 需要扫描的插件根目录列表。
|
||||
|
||||
Returns:
|
||||
Tuple[Dict[str, PluginCandidate], Dict[str, List[Path]]]:
|
||||
候选插件映射和重复插件 ID 冲突映射。
|
||||
"""
|
||||
return self._discover_candidates(plugin_dirs)
|
||||
|
||||
def _discover_candidates(self, plugin_dirs: List[str]) -> Tuple[Dict[str, PluginCandidate], Dict[str, List[Path]]]:
|
||||
"""扫描插件目录并收集候选插件。"""
|
||||
candidates: Dict[str, PluginCandidate] = {}
|
||||
@@ -123,26 +145,17 @@ class PluginLoader:
|
||||
|
||||
def _discover_single_candidate(self, plugin_dir: Path) -> Optional[Tuple[str, PluginCandidate]]:
|
||||
"""发现并校验单个插件目录。"""
|
||||
manifest_path = plugin_dir / "_manifest.json"
|
||||
plugin_path = plugin_dir / "plugin.py"
|
||||
|
||||
if not manifest_path.exists() or not plugin_path.exists():
|
||||
if not plugin_path.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
with manifest_path.open("r", encoding="utf-8") as manifest_file:
|
||||
manifest: Dict[str, Any] = json.load(manifest_file)
|
||||
except Exception as e:
|
||||
self._failed_plugins[plugin_dir.name] = f"manifest 解析失败: {e}"
|
||||
logger.error(f"插件 {plugin_dir.name} manifest 解析失败: {e}")
|
||||
return None
|
||||
|
||||
if not self._manifest_validator.validate(manifest):
|
||||
manifest = self._manifest_validator.load_from_plugin_path(plugin_dir)
|
||||
if manifest is None:
|
||||
errors = "; ".join(self._manifest_validator.errors)
|
||||
self._failed_plugins[plugin_dir.name] = f"manifest 校验失败: {errors}"
|
||||
return None
|
||||
|
||||
plugin_id = str(manifest.get("name", plugin_dir.name)).strip() or plugin_dir.name
|
||||
plugin_id = manifest.id
|
||||
return plugin_id, (plugin_dir, manifest, plugin_path)
|
||||
|
||||
def _record_duplicate_candidates(self, duplicate_candidates: Dict[str, List[Path]]) -> None:
|
||||
@@ -170,7 +183,6 @@ class PluginLoader:
|
||||
plugin_dir, manifest, plugin_path = candidates[plugin_id]
|
||||
try:
|
||||
if meta := self._load_single_plugin(plugin_id, plugin_dir, manifest, plugin_path):
|
||||
self._loaded_plugins[meta.plugin_id] = meta
|
||||
results.append(meta)
|
||||
except Exception as e:
|
||||
self._failed_plugins[plugin_id] = str(e)
|
||||
@@ -182,45 +194,193 @@ class PluginLoader:
|
||||
"""获取已加载的插件"""
|
||||
return self._loaded_plugins.get(plugin_id)
|
||||
|
||||
def set_loaded_plugin(self, meta: PluginMeta) -> None:
|
||||
"""登记一个已经完成初始化的插件。
|
||||
|
||||
Args:
|
||||
meta: 待登记的插件元数据。
|
||||
"""
|
||||
self._loaded_plugins[meta.plugin_id] = meta
|
||||
|
||||
def remove_loaded_plugin(self, plugin_id: str) -> Optional[PluginMeta]:
|
||||
"""移除一个已加载插件的元数据。
|
||||
|
||||
Args:
|
||||
plugin_id: 待移除的插件 ID。
|
||||
|
||||
Returns:
|
||||
Optional[PluginMeta]: 被移除的插件元数据;不存在时返回 ``None``。
|
||||
"""
|
||||
return self._loaded_plugins.pop(plugin_id, None)
|
||||
|
||||
def purge_plugin_modules(self, plugin_id: str, plugin_dir: str) -> List[str]:
|
||||
"""清理指定插件目录下的模块缓存。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID。
|
||||
plugin_dir: 插件目录绝对路径。
|
||||
|
||||
Returns:
|
||||
List[str]: 已从 ``sys.modules`` 中移除的模块名列表。
|
||||
"""
|
||||
removed_modules: List[str] = []
|
||||
plugin_path = Path(plugin_dir).resolve()
|
||||
synthetic_module_name = self._build_safe_module_name(plugin_id)
|
||||
|
||||
for module_name, module in list(sys.modules.items()):
|
||||
if module_name == synthetic_module_name:
|
||||
removed_modules.append(module_name)
|
||||
sys.modules.pop(module_name, None)
|
||||
continue
|
||||
|
||||
module_file = getattr(module, "__file__", None)
|
||||
if module_file is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
module_path = Path(module_file).resolve()
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if module_path.is_relative_to(plugin_path):
|
||||
removed_modules.append(module_name)
|
||||
sys.modules.pop(module_name, None)
|
||||
|
||||
importlib.invalidate_caches()
|
||||
return removed_modules
|
||||
|
||||
@staticmethod
|
||||
def _build_safe_module_name(plugin_id: str) -> str:
|
||||
"""将插件 ID 转换为可用于动态导入的安全模块名。
|
||||
|
||||
Args:
|
||||
plugin_id: 原始插件 ID。
|
||||
|
||||
Returns:
|
||||
str: 仅包含字母、数字和下划线的合成模块名。
|
||||
"""
|
||||
normalized_plugin_id = re.sub(r"[^0-9A-Za-z_]", "_", str(plugin_id or "").strip())
|
||||
if normalized_plugin_id and normalized_plugin_id[0].isdigit():
|
||||
normalized_plugin_id = f"_{normalized_plugin_id}"
|
||||
return f"_maibot_plugin_{normalized_plugin_id or 'plugin'}"
|
||||
|
||||
def list_plugins(self) -> List[str]:
|
||||
"""列出所有已加载的插件 ID"""
|
||||
return list(self._loaded_plugins.keys())
|
||||
|
||||
@property
|
||||
def failed_plugins(self) -> Dict[str, str]:
|
||||
"""返回当前记录的失败插件原因映射。"""
|
||||
return dict(self._failed_plugins)
|
||||
|
||||
@property
|
||||
def manifest_validator(self) -> ManifestValidator:
|
||||
"""返回当前加载器持有的 Manifest 校验器。
|
||||
|
||||
Returns:
|
||||
ManifestValidator: 当前使用的 Manifest 校验器实例。
|
||||
"""
|
||||
return self._manifest_validator
|
||||
|
||||
# ──── 依赖解析 ────────────────────────────────────────────
|
||||
|
||||
def resolve_dependencies(
|
||||
self,
|
||||
candidates: Dict[str, PluginCandidate],
|
||||
extra_available: Optional[Dict[str, str]] = None,
|
||||
) -> Tuple[List[str], Dict[str, str]]:
|
||||
"""解析候选插件的依赖顺序。
|
||||
|
||||
Args:
|
||||
candidates: 待加载的候选插件集合。
|
||||
extra_available: 视为已满足的外部依赖插件版本映射。
|
||||
|
||||
Returns:
|
||||
Tuple[List[str], Dict[str, str]]: 可加载顺序和失败原因映射。
|
||||
"""
|
||||
return self._resolve_dependencies(candidates, extra_available=extra_available)
|
||||
|
||||
def load_candidate(self, plugin_id: str, candidate: PluginCandidate) -> Optional[PluginMeta]:
|
||||
"""加载单个候选插件模块。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID。
|
||||
candidate: 候选插件三元组。
|
||||
|
||||
Returns:
|
||||
Optional[PluginMeta]: 加载成功的插件元数据;失败时返回 ``None``。
|
||||
"""
|
||||
plugin_dir, manifest, plugin_path = candidate
|
||||
return self._load_single_plugin(plugin_id, plugin_dir, manifest, plugin_path)
|
||||
|
||||
def _resolve_dependencies(
|
||||
self,
|
||||
candidates: Dict[str, PluginCandidate],
|
||||
extra_available: Optional[Dict[str, str]] = None,
|
||||
) -> Tuple[List[str], Dict[str, str]]:
|
||||
"""拓扑排序解析加载顺序,返回 (有序列表, 失败项 {id: reason})。"""
|
||||
available = set(candidates.keys())
|
||||
satisfied_dependencies = {
|
||||
str(plugin_id or "").strip(): str(plugin_version or "").strip()
|
||||
for plugin_id, plugin_version in (extra_available or {}).items()
|
||||
if str(plugin_id or "").strip() and str(plugin_version or "").strip()
|
||||
}
|
||||
dep_graph: Dict[str, Set[str]] = {}
|
||||
failed: Dict[str, str] = {}
|
||||
|
||||
for pid, (_, manifest, _) in candidates.items():
|
||||
raw_deps = manifest.get("dependencies", [])
|
||||
resolved: Set[str] = set()
|
||||
missing: List[str] = []
|
||||
for dep in raw_deps:
|
||||
dep_name = dep if isinstance(dep, str) else str(dep.get("name", ""))
|
||||
dep_name = dep_name.strip()
|
||||
if not dep_name or dep_name == pid:
|
||||
missing_or_incompatible: List[str] = []
|
||||
|
||||
for dependency in manifest.plugin_dependencies:
|
||||
dependency_id = dependency.id
|
||||
if dependency_id in available:
|
||||
dependency_manifest = candidates[dependency_id][1]
|
||||
if not self._manifest_validator.is_plugin_dependency_satisfied(
|
||||
dependency,
|
||||
dependency_manifest.version,
|
||||
):
|
||||
missing_or_incompatible.append(
|
||||
f"{dependency_id} (需要 {dependency.version_spec},当前 {dependency_manifest.version})"
|
||||
)
|
||||
continue
|
||||
resolved.add(dependency_id)
|
||||
continue
|
||||
if dep_name in available:
|
||||
resolved.add(dep_name)
|
||||
else:
|
||||
missing.append(dep_name)
|
||||
if missing:
|
||||
failed[pid] = f"缺少依赖: {', '.join(missing)}"
|
||||
|
||||
external_dependency_version = satisfied_dependencies.get(dependency_id)
|
||||
if external_dependency_version is None:
|
||||
missing_or_incompatible.append(f"{dependency_id} (未找到依赖插件)")
|
||||
continue
|
||||
|
||||
if not self._manifest_validator.is_plugin_dependency_satisfied(
|
||||
dependency,
|
||||
external_dependency_version,
|
||||
):
|
||||
missing_or_incompatible.append(
|
||||
f"{dependency_id} (需要 {dependency.version_spec},当前 {external_dependency_version})"
|
||||
)
|
||||
|
||||
if missing_or_incompatible:
|
||||
failed[pid] = f"依赖未满足: {', '.join(missing_or_incompatible)}"
|
||||
dep_graph[pid] = resolved
|
||||
|
||||
# 移除失败项
|
||||
for pid in failed:
|
||||
dep_graph.pop(pid, None)
|
||||
# 迭代传播“依赖自身加载失败”到上游依赖方,避免误报为循环依赖
|
||||
changed = True
|
||||
while changed:
|
||||
changed = False
|
||||
failed_plugin_ids = set(failed)
|
||||
for pid, dependencies in list(dep_graph.items()):
|
||||
if pid in failed:
|
||||
dep_graph.pop(pid, None)
|
||||
continue
|
||||
|
||||
failed_dependencies = sorted(dependency for dependency in dependencies if dependency in failed_plugin_ids)
|
||||
if not failed_dependencies:
|
||||
continue
|
||||
|
||||
failed[pid] = f"依赖未满足: {', '.join(f'{dependency} (依赖插件加载失败)' for dependency in failed_dependencies)}"
|
||||
dep_graph.pop(pid, None)
|
||||
changed = True
|
||||
|
||||
# Kahn 拓扑排序
|
||||
indegree = {pid: len(deps) for pid, deps in dep_graph.items()}
|
||||
@@ -253,7 +413,7 @@ class PluginLoader:
|
||||
self,
|
||||
plugin_id: str,
|
||||
plugin_dir: Path,
|
||||
manifest: Dict[str, Any],
|
||||
manifest: PluginManifest,
|
||||
plugin_path: Path,
|
||||
) -> Optional[PluginMeta]:
|
||||
"""加载单个插件"""
|
||||
@@ -261,8 +421,12 @@ class PluginLoader:
|
||||
self._ensure_compat_hook()
|
||||
|
||||
# 动态导入插件模块
|
||||
module_name = f"_maibot_plugin_{plugin_id}"
|
||||
spec = importlib.util.spec_from_file_location(module_name, str(plugin_path))
|
||||
module_name = self._build_safe_module_name(plugin_id)
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
module_name,
|
||||
str(plugin_path),
|
||||
submodule_search_locations=[str(plugin_dir)],
|
||||
)
|
||||
if spec is None or spec.loader is None:
|
||||
logger.error(f"无法创建模块 spec: {plugin_path}")
|
||||
return None
|
||||
@@ -271,37 +435,73 @@ class PluginLoader:
|
||||
sys.modules[module_name] = module
|
||||
|
||||
plugin_parent_dir = plugin_dir.parent
|
||||
with self._temporary_sys_path_entry(plugin_parent_dir):
|
||||
spec.loader.exec_module(module)
|
||||
try:
|
||||
with self._temporary_sys_path_entry(plugin_parent_dir):
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
# 优先使用新版 create_plugin 工厂函数
|
||||
create_plugin = getattr(module, "create_plugin", None)
|
||||
if create_plugin is not None:
|
||||
instance = create_plugin()
|
||||
logger.info(f"插件 {plugin_id} v{manifest.get('version', '?')} 加载成功")
|
||||
return PluginMeta(
|
||||
plugin_id=plugin_id,
|
||||
plugin_dir=str(plugin_dir),
|
||||
plugin_instance=instance,
|
||||
manifest=manifest,
|
||||
)
|
||||
# 优先使用新版 create_plugin 工厂函数
|
||||
create_plugin = getattr(module, "create_plugin", None)
|
||||
if create_plugin is not None:
|
||||
instance = create_plugin()
|
||||
self._validate_sdk_plugin_contract(plugin_id, instance)
|
||||
logger.info(f"插件 {plugin_id} v{manifest.version} 加载成功")
|
||||
return PluginMeta(
|
||||
plugin_id=plugin_id,
|
||||
plugin_dir=str(plugin_dir),
|
||||
module_name=module_name,
|
||||
plugin_instance=instance,
|
||||
manifest=manifest,
|
||||
)
|
||||
|
||||
# 回退:检测旧版 @register_plugin 标记的 BasePlugin 子类
|
||||
instance = self._try_load_legacy_plugin(module, plugin_id)
|
||||
if instance is not None:
|
||||
logger.info(
|
||||
f"插件 {plugin_id} v{manifest.get('version', '?')} 通过旧版兼容层加载成功(请尽快迁移到 maibot_sdk)"
|
||||
)
|
||||
return PluginMeta(
|
||||
plugin_id=plugin_id,
|
||||
plugin_dir=str(plugin_dir),
|
||||
plugin_instance=instance,
|
||||
manifest=manifest,
|
||||
)
|
||||
# 回退:检测旧版 @register_plugin 标记的 BasePlugin 子类
|
||||
instance = self._try_load_legacy_plugin(module, plugin_id)
|
||||
if instance is not None:
|
||||
logger.info(
|
||||
f"插件 {plugin_id} v{manifest.version} 通过旧版兼容层加载成功(请尽快迁移到 maibot_sdk)"
|
||||
)
|
||||
return PluginMeta(
|
||||
plugin_id=plugin_id,
|
||||
plugin_dir=str(plugin_dir),
|
||||
module_name=module_name,
|
||||
plugin_instance=instance,
|
||||
manifest=manifest,
|
||||
)
|
||||
except Exception:
|
||||
sys.modules.pop(module_name, None)
|
||||
raise
|
||||
|
||||
logger.error(f"插件 {plugin_id} 缺少 create_plugin 工厂函数且未检测到旧版 BasePlugin")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _validate_sdk_plugin_contract(plugin_id: str, instance: Any) -> None:
|
||||
"""校验 SDK 插件的基础契约。
|
||||
|
||||
Args:
|
||||
plugin_id: 当前插件 ID。
|
||||
instance: ``create_plugin()`` 返回的插件实例。
|
||||
|
||||
Raises:
|
||||
TypeError: 当插件未覆盖必需生命周期方法或订阅声明不合法时抛出。
|
||||
"""
|
||||
|
||||
try:
|
||||
from maibot_sdk.plugin import MaiBotPlugin
|
||||
except ImportError:
|
||||
return
|
||||
|
||||
if not isinstance(instance, MaiBotPlugin):
|
||||
return
|
||||
|
||||
if type(instance).on_load is MaiBotPlugin.on_load:
|
||||
raise TypeError(f"插件 {plugin_id} 必须实现 on_load()")
|
||||
if type(instance).on_unload is MaiBotPlugin.on_unload:
|
||||
raise TypeError(f"插件 {plugin_id} 必须实现 on_unload()")
|
||||
if type(instance).on_config_update is MaiBotPlugin.on_config_update:
|
||||
raise TypeError(f"插件 {plugin_id} 必须实现 on_config_update()")
|
||||
|
||||
instance.get_config_reload_subscriptions()
|
||||
|
||||
@staticmethod
|
||||
@contextlib.contextmanager
|
||||
def _temporary_sys_path_entry(path: Path) -> Iterator[None]:
|
||||
|
||||
@@ -1,14 +1,6 @@
|
||||
"""Runner 端 RPC Client
|
||||
"""Runner 端 RPC 客户端。"""
|
||||
|
||||
负责:
|
||||
1. 连接 Host RPC Server
|
||||
2. 发送握手(runner.hello)
|
||||
3. 发送组件注册请求
|
||||
4. 接收并分发 Host 的调用请求
|
||||
5. 发送能力调用请求到 Host
|
||||
"""
|
||||
|
||||
from typing import Any, Awaitable, Callable, Dict, Optional, cast
|
||||
from typing import Any, Awaitable, Callable, Dict, Optional, Set, cast
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
@@ -29,12 +21,15 @@ from src.plugin_runtime.transport.factory import create_transport_client
|
||||
|
||||
logger = get_logger("plugin_runtime.runner.rpc_client")
|
||||
|
||||
# RPC 方法处理器类型
|
||||
MethodHandler = Callable[[Envelope], Awaitable[Envelope]]
|
||||
|
||||
|
||||
def _get_sdk_version() -> str:
|
||||
"""从 maibot_sdk 包元数据中读取实际版本号,失败时回退到 1.0.0。"""
|
||||
"""读取 SDK 版本号。
|
||||
|
||||
Returns:
|
||||
str: 已安装的 SDK 版本;读取失败时回退到 ``1.0.0``。
|
||||
"""
|
||||
try:
|
||||
from importlib.metadata import version
|
||||
|
||||
@@ -47,73 +42,78 @@ SDK_VERSION = _get_sdk_version()
|
||||
|
||||
|
||||
class RPCClient:
|
||||
"""Runner 端 RPC 客户端
|
||||
|
||||
管理与 Host 的 IPC 连接,支持双向 RPC 调用。
|
||||
"""
|
||||
"""Runner 端 RPC 客户端。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host_address: str,
|
||||
session_token: str,
|
||||
codec: Optional[Codec] = None,
|
||||
):
|
||||
self._host_address = host_address
|
||||
self._session_token = session_token
|
||||
self._codec = codec or MsgPackCodec()
|
||||
) -> None:
|
||||
"""初始化 RPC 客户端。
|
||||
|
||||
Args:
|
||||
host_address: Host 的 IPC 地址。
|
||||
session_token: 握手用会话令牌。
|
||||
codec: 可选的编解码器实现。
|
||||
"""
|
||||
self._host_address: str = host_address
|
||||
self._session_token: str = session_token
|
||||
self._codec: Codec = codec or MsgPackCodec()
|
||||
|
||||
self._id_gen = RequestIdGenerator()
|
||||
self._connection: Optional[Connection] = None
|
||||
self._runner_id = str(uuid.uuid4())
|
||||
self._generation: int = 0
|
||||
|
||||
# 方法处理器注册表(Host 发来的调用)
|
||||
self._runner_id: str = str(uuid.uuid4())
|
||||
self._method_handlers: Dict[str, MethodHandler] = {}
|
||||
|
||||
# 等待响应的 pending 请求: request_id -> Future
|
||||
self._pending_requests: Dict[int, asyncio.Future] = {}
|
||||
|
||||
# 运行状态
|
||||
self._running = False
|
||||
self._recv_task: Optional[asyncio.Task] = None
|
||||
self._background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
@property
|
||||
def generation(self) -> int:
|
||||
return self._generation
|
||||
self._pending_requests: Dict[int, asyncio.Future[Envelope]] = {}
|
||||
self._running: bool = False
|
||||
self._recv_task: Optional[asyncio.Task[None]] = None
|
||||
self._background_tasks: Set[asyncio.Task[Any]] = set()
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""返回当前连接是否可用。"""
|
||||
return self._connection is not None and not self._connection.is_closed
|
||||
|
||||
def register_method(self, method: str, handler: MethodHandler) -> None:
|
||||
"""注册方法处理器(处理 Host 发来的请求)"""
|
||||
"""注册 Host -> Runner 的 RPC 处理器。
|
||||
|
||||
Args:
|
||||
method: RPC 方法名。
|
||||
handler: 方法处理函数。
|
||||
"""
|
||||
self._method_handlers[method] = handler
|
||||
|
||||
def _require_connection(self) -> Connection:
|
||||
"""返回当前可用连接;若连接不可用则抛出 RPCError。"""
|
||||
"""返回当前可用连接。
|
||||
|
||||
Returns:
|
||||
Connection: 当前连接对象。
|
||||
|
||||
Raises:
|
||||
RPCError: 当前未连接到 Host。
|
||||
"""
|
||||
connection = self._connection
|
||||
if connection is None or connection.is_closed:
|
||||
raise RPCError(ErrorCode.E_UNKNOWN, "未连接到 Host")
|
||||
return cast(Connection, connection)
|
||||
|
||||
async def connect_and_handshake(self) -> bool:
|
||||
"""连接 Host 并完成握手
|
||||
"""连接 Host 并完成握手。
|
||||
|
||||
Returns:
|
||||
是否握手成功
|
||||
bool: 是否握手成功。
|
||||
"""
|
||||
client = create_transport_client(self._host_address)
|
||||
self._connection = await client.connect()
|
||||
connection = self._require_connection()
|
||||
|
||||
# 发送 runner.hello
|
||||
hello = HelloPayload(
|
||||
runner_id=self._runner_id,
|
||||
sdk_version=SDK_VERSION,
|
||||
session_token=self._session_token,
|
||||
)
|
||||
request_id = self._id_gen.next()
|
||||
request_id = await self._id_gen.next()
|
||||
envelope = Envelope(
|
||||
request_id=request_id,
|
||||
message_type=MessageType.REQUEST,
|
||||
@@ -121,33 +121,27 @@ class RPCClient:
|
||||
payload=hello.model_dump(),
|
||||
)
|
||||
|
||||
data = self._codec.encode_envelope(envelope)
|
||||
await connection.send_frame(data)
|
||||
await connection.send_frame(self._codec.encode_envelope(envelope))
|
||||
|
||||
# 接收握手响应
|
||||
resp_data = await asyncio.wait_for(connection.recv_frame(), timeout=10.0)
|
||||
resp = self._codec.decode_envelope(resp_data)
|
||||
response = self._codec.decode_envelope(resp_data)
|
||||
resp_payload = HelloResponsePayload.model_validate(response.payload)
|
||||
|
||||
resp_payload = HelloResponsePayload.model_validate(resp.payload)
|
||||
if not resp_payload.accepted:
|
||||
logger.error(f"握手被拒绝: {resp_payload.reason}")
|
||||
await self._connection.close()
|
||||
self._connection = None
|
||||
await self.disconnect()
|
||||
return False
|
||||
|
||||
self._generation = resp_payload.assigned_generation
|
||||
logger.info(f"握手成功: generation={self._generation}, host_version={resp_payload.host_version}")
|
||||
|
||||
# 启动消息接收循环
|
||||
logger.info(f"握手成功: host_version={resp_payload.host_version}")
|
||||
self._running = True
|
||||
self._recv_task = asyncio.create_task(self._recv_loop())
|
||||
|
||||
self._recv_task = asyncio.create_task(self._recv_loop(), name="RPCClient.recv")
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""断开连接"""
|
||||
"""断开与 Host 的连接并清理状态。"""
|
||||
self._running = False
|
||||
if self._recv_task:
|
||||
|
||||
if self._recv_task is not None:
|
||||
self._recv_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._recv_task
|
||||
@@ -160,13 +154,12 @@ class RPCClient:
|
||||
await asyncio.gather(*self._background_tasks, return_exceptions=True)
|
||||
self._background_tasks.clear()
|
||||
|
||||
# 取消所有 pending 请求
|
||||
for future in self._pending_requests.values():
|
||||
if not future.done():
|
||||
future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "连接关闭"))
|
||||
self._pending_requests.clear()
|
||||
|
||||
if self._connection:
|
||||
if self._connection is not None:
|
||||
await self._connection.close()
|
||||
self._connection = None
|
||||
|
||||
@@ -177,16 +170,27 @@ class RPCClient:
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
timeout_ms: int = 30000,
|
||||
) -> Envelope:
|
||||
"""向 Host 发送 RPC 请求并等待响应"""
|
||||
connection = self._require_connection()
|
||||
"""向 Host 发送 RPC 请求并等待响应。
|
||||
|
||||
request_id = self._id_gen.next()
|
||||
Args:
|
||||
method: RPC 方法名。
|
||||
plugin_id: 目标插件 ID。
|
||||
payload: 请求载荷。
|
||||
timeout_ms: 超时时间,单位毫秒。
|
||||
|
||||
Returns:
|
||||
Envelope: Host 返回的响应信封。
|
||||
|
||||
Raises:
|
||||
RPCError: 发送失败、超时或连接异常。
|
||||
"""
|
||||
connection = self._require_connection()
|
||||
request_id = await self._id_gen.next()
|
||||
envelope = Envelope(
|
||||
request_id=request_id,
|
||||
message_type=MessageType.REQUEST,
|
||||
method=method,
|
||||
plugin_id=plugin_id,
|
||||
generation=self._generation,
|
||||
timeout_ms=timeout_ms,
|
||||
payload=payload or {},
|
||||
)
|
||||
@@ -196,21 +200,16 @@ class RPCClient:
|
||||
self._pending_requests[request_id] = future
|
||||
|
||||
try:
|
||||
data = self._codec.encode_envelope(envelope)
|
||||
await connection.send_frame(data)
|
||||
|
||||
timeout_sec = timeout_ms / 1000.0
|
||||
return await asyncio.wait_for(future, timeout=timeout_sec)
|
||||
await connection.send_frame(self._codec.encode_envelope(envelope))
|
||||
return await asyncio.wait_for(future, timeout=timeout_ms / 1000.0)
|
||||
except asyncio.TimeoutError:
|
||||
self._pending_requests.pop(request_id, None)
|
||||
raise RPCError(ErrorCode.E_TIMEOUT, f"请求 {method} 超时 ({timeout_ms}ms)") from None
|
||||
except Exception as e:
|
||||
except Exception as exc:
|
||||
self._pending_requests.pop(request_id, None)
|
||||
if isinstance(e, RPCError):
|
||||
if isinstance(exc, RPCError):
|
||||
raise
|
||||
raise RPCError(ErrorCode.E_UNKNOWN, str(e)) from e
|
||||
|
||||
# ─── 内部方法 ──────────────────────────────────────────────
|
||||
raise RPCError(ErrorCode.E_UNKNOWN, str(exc)) from exc
|
||||
|
||||
async def send_event(
|
||||
self,
|
||||
@@ -218,33 +217,30 @@ class RPCClient:
|
||||
plugin_id: str = "",
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""向 Host 发送单向事件(fire-and-forget,不等待响应)。
|
||||
"""向 Host 发送单向广播消息。
|
||||
|
||||
Args:
|
||||
method: RPC 方法名,如 "runner.log_batch"。
|
||||
plugin_id: 目标插件 ID(可为空,表示 Runner 级消息)。
|
||||
payload: 事件数据。
|
||||
method: RPC 方法名。
|
||||
plugin_id: 目标插件 ID。
|
||||
payload: 广播载荷。
|
||||
"""
|
||||
if not self.is_connected:
|
||||
return
|
||||
|
||||
connection = self._require_connection()
|
||||
|
||||
request_id = self._id_gen.next()
|
||||
request_id = await self._id_gen.next()
|
||||
envelope = Envelope(
|
||||
request_id=request_id,
|
||||
message_type=MessageType.EVENT,
|
||||
message_type=MessageType.BROADCAST,
|
||||
method=method,
|
||||
plugin_id=plugin_id,
|
||||
generation=self._generation,
|
||||
payload=payload or {},
|
||||
)
|
||||
data = self._codec.encode_envelope(envelope)
|
||||
await connection.send_frame(data)
|
||||
await connection.send_frame(self._codec.encode_envelope(envelope))
|
||||
|
||||
async def _recv_loop(self) -> None:
|
||||
"""消息接收主循环"""
|
||||
while self._running and self._connection and not self._connection.is_closed:
|
||||
"""持续接收 Host 发来的消息并分发。"""
|
||||
while self._running and self._connection is not None and not self._connection.is_closed:
|
||||
try:
|
||||
data = await self._connection.recv_frame()
|
||||
except (asyncio.IncompleteReadError, ConnectionError):
|
||||
@@ -252,39 +248,47 @@ class RPCClient:
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"接收帧失败: {e}")
|
||||
except Exception as exc:
|
||||
logger.error(f"接收帧失败: {exc}")
|
||||
break
|
||||
|
||||
try:
|
||||
envelope = self._codec.decode_envelope(data)
|
||||
except Exception as e:
|
||||
logger.error(f"解码消息失败: {e}")
|
||||
except Exception as exc:
|
||||
logger.error(f"解码消息失败: {exc}")
|
||||
continue
|
||||
|
||||
if envelope.is_response():
|
||||
self._handle_response(envelope)
|
||||
elif envelope.is_request():
|
||||
self._track_background_task(asyncio.create_task(self._handle_request(envelope)))
|
||||
elif envelope.is_event():
|
||||
self._track_background_task(asyncio.create_task(self._handle_event(envelope)))
|
||||
elif envelope.is_broadcast():
|
||||
self._track_background_task(asyncio.create_task(self._handle_broadcast(envelope)))
|
||||
|
||||
def _handle_response(self, envelope: Envelope) -> None:
|
||||
"""处理来自 Host 的响应"""
|
||||
"""处理 Host 返回的响应。
|
||||
|
||||
Args:
|
||||
envelope: 响应信封。
|
||||
"""
|
||||
future = self._pending_requests.pop(envelope.request_id, None)
|
||||
if future and not future.done():
|
||||
if envelope.error:
|
||||
future.set_exception(RPCError.from_dict(envelope.error))
|
||||
else:
|
||||
future.set_result(envelope)
|
||||
if future is None or future.done():
|
||||
return
|
||||
if envelope.error:
|
||||
future.set_exception(RPCError.from_dict(envelope.error))
|
||||
else:
|
||||
future.set_result(envelope)
|
||||
|
||||
async def _handle_request(self, envelope: Envelope) -> None:
|
||||
"""处理来自 Host 的请求(调用插件组件)"""
|
||||
"""处理 Host 发来的请求。
|
||||
|
||||
Args:
|
||||
envelope: 请求信封。
|
||||
"""
|
||||
connection = self._connection
|
||||
if connection is None or connection.is_closed:
|
||||
logger.warning(f"处理请求 {envelope.method} 时连接已关闭,跳过响应")
|
||||
return
|
||||
connection = cast(Connection, connection)
|
||||
|
||||
handler = self._method_handlers.get(envelope.method)
|
||||
if handler is None:
|
||||
@@ -298,23 +302,34 @@ class RPCClient:
|
||||
try:
|
||||
response = await handler(envelope)
|
||||
await connection.send_frame(self._codec.encode_envelope(response))
|
||||
except RPCError as e:
|
||||
error_resp = envelope.make_error_response(e.code.value, e.message, e.details)
|
||||
except RPCError as exc:
|
||||
error_resp = envelope.make_error_response(exc.code.value, exc.message, exc.details)
|
||||
await connection.send_frame(self._codec.encode_envelope(error_resp))
|
||||
except Exception as e:
|
||||
logger.error(f"处理请求 {envelope.method} 异常: {e}", exc_info=True)
|
||||
error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
|
||||
except Exception as exc:
|
||||
logger.error(f"处理请求 {envelope.method} 异常: {exc}", exc_info=True)
|
||||
error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(exc))
|
||||
await connection.send_frame(self._codec.encode_envelope(error_resp))
|
||||
|
||||
async def _handle_event(self, envelope: Envelope) -> None:
|
||||
"""处理来自 Host 的事件"""
|
||||
if handler := self._method_handlers.get(envelope.method):
|
||||
try:
|
||||
await handler(envelope)
|
||||
except Exception as e:
|
||||
logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True)
|
||||
async def _handle_broadcast(self, envelope: Envelope) -> None:
|
||||
"""处理 Host 发来的广播事件。
|
||||
|
||||
def _track_background_task(self, task: asyncio.Task) -> None:
|
||||
"""保持后台任务强引用,直到其完成或被取消。"""
|
||||
Args:
|
||||
envelope: 广播信封。
|
||||
"""
|
||||
handler = self._method_handlers.get(envelope.method)
|
||||
if handler is None:
|
||||
return
|
||||
|
||||
try:
|
||||
await handler(envelope)
|
||||
except Exception as exc:
|
||||
logger.error(f"处理广播 {envelope.method} 异常: {exc}", exc_info=True)
|
||||
|
||||
def _track_background_task(self, task: asyncio.Task[Any]) -> None:
|
||||
"""持有后台任务强引用直到其结束。
|
||||
|
||||
Args:
|
||||
task: 后台任务。
|
||||
"""
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,9 @@
|
||||
"""Windows Named Pipe 传输实现。
|
||||
|
||||
适用于 Windows 平台,使用 asyncio ProactorEventLoop 的 named pipe 支持。
|
||||
|
||||
注意:Named Pipe 是 Windows 特有的 IPC 机制,
|
||||
在 Linux/macOS 平台上不可用。Unix-like 平台请使用 UDS 传输。
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, cast
|
||||
@@ -18,10 +21,12 @@ _DEFAULT_PIPE_PREFIX = "maibot-plugin"
|
||||
|
||||
|
||||
class _NamedPipeServerHandle(Protocol):
|
||||
"""Named Pipe 服务端句柄的协议定义。"""
|
||||
def close(self) -> None: ...
|
||||
|
||||
|
||||
class _NamedPipeEventLoop(Protocol):
|
||||
"""ProactorEventLoop 的协议定义,提供 named pipe 相关方法。"""
|
||||
async def start_serving_pipe(
|
||||
self,
|
||||
protocol_factory: Callable[[], asyncio.BaseProtocol],
|
||||
@@ -40,6 +45,15 @@ class _NamedPipeEventLoop(Protocol):
|
||||
|
||||
|
||||
def _normalize_pipe_address(pipe_name: Optional[str] = None) -> str:
|
||||
"""规范化 Named Pipe 地址。
|
||||
|
||||
Args:
|
||||
pipe_name: 管道名称。如果以 '\\\\.\\pipe\\' 开头则直接使用,
|
||||
否则会自动添加前缀。如果为 None 则生成随机名称。
|
||||
|
||||
Returns:
|
||||
规范化的管道地址(格式:\\\\.\\pipe\\name)
|
||||
"""
|
||||
if pipe_name and pipe_name.startswith(_PIPE_PREFIX):
|
||||
return pipe_name
|
||||
|
||||
@@ -55,12 +69,21 @@ def _normalize_pipe_address(pipe_name: Optional[str] = None) -> str:
|
||||
|
||||
|
||||
class NamedPipeConnection(Connection):
|
||||
"""基于 Windows Named Pipe 的连接。"""
|
||||
"""基于 Windows Named Pipe 的连接。
|
||||
|
||||
封装了底层 StreamReader/StreamWriter,提供分帧读写能力。
|
||||
"""
|
||||
|
||||
pass
|
||||
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
|
||||
super().__init__(reader, writer)
|
||||
|
||||
|
||||
class _NamedPipeServerProtocol(asyncio.StreamReaderProtocol):
|
||||
"""Named Pipe 服务端协议实现。
|
||||
|
||||
处理客户端连接的生命周期,包括连接建立、数据处理和连接关闭。
|
||||
"""
|
||||
|
||||
def __init__(self, handler: ConnectionHandler, loop: asyncio.AbstractEventLoop) -> None:
|
||||
self._reader: asyncio.StreamReader = asyncio.StreamReader()
|
||||
super().__init__(self._reader)
|
||||
@@ -69,39 +92,58 @@ class _NamedPipeServerProtocol(asyncio.StreamReaderProtocol):
|
||||
self._handler_task: Optional[asyncio.Task[None]] = None
|
||||
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
"""连接建立时的回调。"""
|
||||
super().connection_made(transport)
|
||||
writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), self, self._reader, self._loop)
|
||||
connection = NamedPipeConnection(self._reader, writer)
|
||||
self._handler_task = self._loop.create_task(self._run_handler(connection))
|
||||
# 使用 asyncio.create_task 确保任务正确调度
|
||||
self._handler_task = asyncio.create_task(self._run_handler(connection))
|
||||
self._handler_task.add_done_callback(self._on_handler_done)
|
||||
|
||||
async def _run_handler(self, connection: NamedPipeConnection) -> None:
|
||||
"""运行连接处理器。"""
|
||||
try:
|
||||
await self._handler(connection)
|
||||
finally:
|
||||
await connection.close()
|
||||
|
||||
def _on_handler_done(self, task: asyncio.Task[None]) -> None:
|
||||
"""连接处理器完成时的回调。"""
|
||||
if task.cancelled():
|
||||
return
|
||||
if exc := task.exception():
|
||||
self._loop.call_exception_handler(
|
||||
{
|
||||
"message": "Named pipe 连接处理失败",
|
||||
"exception": exc,
|
||||
"protocol": self,
|
||||
}
|
||||
)
|
||||
try:
|
||||
self._loop.call_exception_handler(
|
||||
{
|
||||
"message": "Named pipe 连接处理失败",
|
||||
"exception": exc,
|
||||
"protocol": self,
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
# 如果 loop 已经关闭,忽略异常
|
||||
pass
|
||||
|
||||
|
||||
class NamedPipeTransportServer(TransportServer):
|
||||
"""Windows Named Pipe 传输服务端。"""
|
||||
"""Windows Named Pipe 传输服务端。
|
||||
|
||||
使用 ProactorEventLoop 的 start_serving_pipe 方法监听客户端连接。
|
||||
"""
|
||||
|
||||
def __init__(self, pipe_name: Optional[str] = None) -> None:
|
||||
self._address: str = _normalize_pipe_address(pipe_name)
|
||||
self._servers: List[_NamedPipeServerHandle] = []
|
||||
|
||||
async def start(self, handler: ConnectionHandler) -> None:
|
||||
"""启动 Named Pipe 服务端。
|
||||
|
||||
Args:
|
||||
handler: 新连接到来时的回调函数
|
||||
|
||||
Raises:
|
||||
RuntimeError: 当在非 Windows 平台或事件循环不支持时
|
||||
"""
|
||||
if sys.platform != "win32":
|
||||
raise RuntimeError("Named pipe 仅支持 Windows")
|
||||
|
||||
@@ -116,32 +158,49 @@ class NamedPipeTransportServer(TransportServer):
|
||||
)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""停止 Named Pipe 服务端并清理资源。"""
|
||||
for server in self._servers:
|
||||
server.close()
|
||||
# 等待所有服务器句柄完全关闭
|
||||
await asyncio.gather(
|
||||
*[asyncio.sleep(0.1) for _ in self._servers],
|
||||
return_exceptions=True
|
||||
)
|
||||
self._servers.clear()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
def get_address(self) -> str:
|
||||
return self._address
|
||||
|
||||
|
||||
class NamedPipeTransportClient(TransportClient):
|
||||
"""Windows Named Pipe 传输客户端。"""
|
||||
"""Windows Named Pipe 传输客户端。
|
||||
|
||||
用于主动连接到 Named Pipe 服务端。
|
||||
"""
|
||||
|
||||
def __init__(self, address: str) -> None:
|
||||
self._address: str = _normalize_pipe_address(address)
|
||||
|
||||
async def connect(self) -> Connection:
|
||||
"""建立到 Named Pipe 服务端的连接。
|
||||
|
||||
Returns:
|
||||
NamedPipeConnection: 连接对象
|
||||
|
||||
Raises:
|
||||
NotImplementedError: 当在非 Windows 平台或事件循环不支持时
|
||||
"""
|
||||
if sys.platform != "win32":
|
||||
raise RuntimeError("Named pipe 仅支持 Windows")
|
||||
raise NotImplementedError("Named pipe 仅支持 Windows")
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
if not hasattr(loop, "create_pipe_connection"):
|
||||
raise RuntimeError("当前事件循环不支持 Windows named pipe")
|
||||
raise NotImplementedError("当前事件循环不支持 Windows named pipe")
|
||||
pipe_loop = cast(_NamedPipeEventLoop, loop)
|
||||
|
||||
reader = asyncio.StreamReader()
|
||||
protocol = asyncio.StreamReaderProtocol(reader)
|
||||
transport, _protocol = await pipe_loop.create_pipe_connection(lambda: protocol, self._address)
|
||||
writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), protocol, reader, loop)
|
||||
# 使用返回的 protocol 创建 StreamWriter
|
||||
writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), _protocol, reader, loop)
|
||||
return NamedPipeConnection(reader, writer)
|
||||
@@ -1,6 +1,9 @@
|
||||
"""Unix Domain Socket 传输实现
|
||||
|
||||
适用于 Linux / macOS 平台。
|
||||
|
||||
注意:UDS (Unix Domain Socket) 是 Unix-like 系统特有的 IPC 机制,
|
||||
在 Windows 平台上不可用。Windows 平台请使用 Named Pipe 传输。
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
@@ -8,20 +11,30 @@ from typing import Optional
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
from .base import Connection, ConnectionHandler, TransportClient, TransportServer
|
||||
|
||||
|
||||
class UDSConnection(Connection):
|
||||
"""基于 UDS 的连接"""
|
||||
"""基于 UDS 的连接
|
||||
|
||||
封装了底层 StreamReader/StreamWriter,提供分帧读写能力。
|
||||
"""
|
||||
|
||||
pass # 直接复用 Connection 基类的分帧读写
|
||||
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
|
||||
super().__init__(reader, writer)
|
||||
|
||||
|
||||
# Unix domain socket 路径的系统限制(sun_path 字段长度)
|
||||
# Linux: 108 字节, macOS: 104 字节
|
||||
_UDS_PATH_MAX = 104
|
||||
# Linux: 108 字节,macOS: 104 字节,其他 Unix: 通常 104 字节
|
||||
if sys.platform == "linux":
|
||||
_UDS_PATH_MAX = 108
|
||||
elif sys.platform == "darwin": # macOS
|
||||
_UDS_PATH_MAX = 104
|
||||
else:
|
||||
_UDS_PATH_MAX = 104 # 保守默认值
|
||||
|
||||
|
||||
class UDSTransportServer(TransportServer):
|
||||
@@ -44,6 +57,18 @@ class UDSTransportServer(TransportServer):
|
||||
self._server: Optional[asyncio.AbstractServer] = None
|
||||
|
||||
async def start(self, handler: ConnectionHandler) -> None:
|
||||
"""启动 UDS 服务端
|
||||
|
||||
Args:
|
||||
handler: 新连接到来时的回调函数
|
||||
|
||||
Raises:
|
||||
RuntimeError: 当在非 Unix 平台(如 Windows)上调用时
|
||||
"""
|
||||
# 平台检查:UDS 仅在 Unix-like 系统上可用
|
||||
if sys.platform == "win32":
|
||||
raise RuntimeError("UDS 不支持 Windows 平台,请使用 Named Pipe")
|
||||
|
||||
# 清理残留 socket 文件
|
||||
if self._socket_path.exists():
|
||||
self._socket_path.unlink()
|
||||
@@ -58,10 +83,16 @@ class UDSTransportServer(TransportServer):
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
self._server = await asyncio.start_unix_server(_on_connect, path=str(self._socket_path))
|
||||
try:
|
||||
self._server = await asyncio.start_unix_server(_on_connect, path=str(self._socket_path))
|
||||
|
||||
# 设置文件权限为仅当前用户可访问
|
||||
self._socket_path.chmod(0o600)
|
||||
# 设置文件权限为仅当前用户可访问
|
||||
self._socket_path.chmod(0o600)
|
||||
except Exception:
|
||||
# 启动失败时清理可能创建的目录和 socket 文件
|
||||
if self._socket_path.exists():
|
||||
self._socket_path.unlink()
|
||||
raise
|
||||
|
||||
async def stop(self) -> None:
|
||||
if self._server:
|
||||
@@ -77,11 +108,26 @@ class UDSTransportServer(TransportServer):
|
||||
|
||||
|
||||
class UDSTransportClient(TransportClient):
|
||||
"""UDS 传输客户端"""
|
||||
"""UDS 传输客户端
|
||||
|
||||
用于主动连接到 UDS 服务端。
|
||||
"""
|
||||
|
||||
def __init__(self, socket_path: Path) -> None:
|
||||
self._socket_path: Path = socket_path
|
||||
|
||||
async def connect(self) -> Connection:
|
||||
"""建立到 UDS 服务端的连接
|
||||
|
||||
Returns:
|
||||
UDSConnection: 连接对象
|
||||
|
||||
Raises:
|
||||
RuntimeError: 当在非 Unix 平台(如 Windows)上调用时
|
||||
"""
|
||||
# 平台检查:UDS 仅在 Unix-like 系统上可用
|
||||
if sys.platform == "win32":
|
||||
raise RuntimeError("UDS 不支持 Windows 平台,请使用 Named Pipe")
|
||||
|
||||
reader, writer = await asyncio.open_unix_connection(str(self._socket_path))
|
||||
return UDSConnection(reader, writer)
|
||||
|
||||
@@ -1,32 +1,28 @@
|
||||
{
|
||||
"manifest_version": 1,
|
||||
"name": "Emoji插件 (Emoji Actions)",
|
||||
"manifest_version": 2,
|
||||
"version": "2.0.0",
|
||||
"description": "可以发送和管理Emoji",
|
||||
"name": "Emoji插件 (Emoji Actions)",
|
||||
"description": "可以发送和管理 Emoji",
|
||||
"author": {
|
||||
"name": "SengokuCola",
|
||||
"url": "https://github.com/MaiM-with-u"
|
||||
},
|
||||
"license": "GPL-v3.0-or-later",
|
||||
"urls": {
|
||||
"repository": "https://github.com/MaiM-with-u/maibot",
|
||||
"homepage": "https://github.com/MaiM-with-u/maibot",
|
||||
"documentation": "https://github.com/MaiM-with-u/maibot",
|
||||
"issues": "https://github.com/MaiM-with-u/maibot/issues"
|
||||
},
|
||||
"host_application": {
|
||||
"min_version": "1.0.0"
|
||||
"min_version": "1.0.0",
|
||||
"max_version": "1.0.0"
|
||||
},
|
||||
"homepage_url": "https://github.com/MaiM-with-u/maibot",
|
||||
"repository_url": "https://github.com/MaiM-with-u/maibot",
|
||||
"keywords": ["emoji", "action", "built-in"],
|
||||
"categories": ["Emoji"],
|
||||
"default_locale": "zh-CN",
|
||||
"plugin_info": {
|
||||
"is_built_in": true,
|
||||
"plugin_type": "action_provider",
|
||||
"components": [
|
||||
{
|
||||
"type": "action",
|
||||
"name": "emoji",
|
||||
"description": "发送表情包辅助表达情绪"
|
||||
}
|
||||
]
|
||||
"sdk": {
|
||||
"min_version": "2.0.0",
|
||||
"max_version": "2.99.99"
|
||||
},
|
||||
"dependencies": [],
|
||||
"capabilities": [
|
||||
"emoji.get_random",
|
||||
"message.get_recent",
|
||||
@@ -34,5 +30,12 @@
|
||||
"llm.generate",
|
||||
"send.emoji",
|
||||
"config.get"
|
||||
]
|
||||
],
|
||||
"i18n": {
|
||||
"default_locale": "zh-CN",
|
||||
"supported_locales": [
|
||||
"zh-CN"
|
||||
]
|
||||
},
|
||||
"id": "builtin.emoji-plugin"
|
||||
}
|
||||
|
||||
@@ -3,11 +3,11 @@
|
||||
根据聊天上下文的情感,使用 LLM 选择并发送合适的表情包。
|
||||
"""
|
||||
|
||||
import random
|
||||
|
||||
from maibot_sdk import MaiBotPlugin, Action
|
||||
from maibot_sdk import Action, MaiBotPlugin
|
||||
from maibot_sdk.types import ActivationType
|
||||
|
||||
import random
|
||||
|
||||
|
||||
class EmojiPlugin(MaiBotPlugin):
|
||||
"""表情包插件"""
|
||||
@@ -95,10 +95,35 @@ class EmojiPlugin(MaiBotPlugin):
|
||||
return True, f"成功发送表情包:[表情包:{chosen_emotion}]"
|
||||
return False, "发送表情包失败"
|
||||
|
||||
async def on_load(self):
|
||||
async def on_load(self) -> None:
|
||||
"""处理插件加载。"""
|
||||
|
||||
# 从插件配置读取 emoji_chance 来覆盖默认概率
|
||||
await self.ctx.config.get("emoji.emoji_chance")
|
||||
|
||||
async def on_unload(self) -> None:
|
||||
"""处理插件卸载。"""
|
||||
|
||||
async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None:
|
||||
"""处理配置热重载事件。
|
||||
|
||||
Args:
|
||||
scope: 配置变更范围。
|
||||
config_data: 最新配置数据。
|
||||
version: 配置版本号。
|
||||
"""
|
||||
|
||||
del config_data
|
||||
del version
|
||||
if scope == "self":
|
||||
await self.ctx.config.get("emoji.emoji_chance")
|
||||
|
||||
|
||||
def create_plugin() -> EmojiPlugin:
|
||||
"""创建 Emoji 插件实例。
|
||||
|
||||
Returns:
|
||||
EmojiPlugin: 新的 Emoji 插件实例。
|
||||
"""
|
||||
|
||||
def create_plugin():
|
||||
return EmojiPlugin()
|
||||
|
||||
@@ -1,51 +1,46 @@
|
||||
{
|
||||
"manifest_version": 1,
|
||||
"name": "插件和组件管理 (Plugin and Component Management)",
|
||||
"manifest_version": 2,
|
||||
"version": "2.0.0",
|
||||
"description": "通过系统API管理插件和组件的生命周期,包括加载、卸载、启用和禁用等操作。",
|
||||
"name": "插件和组件管理 (Plugin and Component Management)",
|
||||
"description": "通过系统 API 管理插件和组件的生命周期,包括加载、卸载、启用和禁用等操作。",
|
||||
"author": {
|
||||
"name": "MaiBot团队",
|
||||
"url": "https://github.com/MaiM-with-u"
|
||||
},
|
||||
"license": "GPL-v3.0-or-later",
|
||||
"host_application": {
|
||||
"min_version": "1.0.0"
|
||||
"urls": {
|
||||
"repository": "https://github.com/MaiM-with-u/maibot",
|
||||
"homepage": "https://github.com/MaiM-with-u/maibot",
|
||||
"documentation": "https://github.com/MaiM-with-u/maibot",
|
||||
"issues": "https://github.com/MaiM-with-u/maibot/issues"
|
||||
},
|
||||
"homepage_url": "https://github.com/MaiM-with-u/maibot",
|
||||
"repository_url": "https://github.com/MaiM-with-u/maibot",
|
||||
"keywords": [
|
||||
"plugins",
|
||||
"components",
|
||||
"management",
|
||||
"built-in"
|
||||
"host_application": {
|
||||
"min_version": "1.0.0",
|
||||
"max_version": "1.0.0"
|
||||
},
|
||||
"sdk": {
|
||||
"min_version": "2.0.0",
|
||||
"max_version": "2.99.99"
|
||||
},
|
||||
"dependencies": [],
|
||||
"capabilities": [
|
||||
"component.get_all_plugins",
|
||||
"component.list_loaded_plugins",
|
||||
"component.list_registered_plugins",
|
||||
"component.enable",
|
||||
"component.disable",
|
||||
"component.load_plugin",
|
||||
"component.unload_plugin",
|
||||
"component.reload_plugin",
|
||||
"send.text",
|
||||
"config.get"
|
||||
],
|
||||
"categories": [
|
||||
"Core System",
|
||||
"Plugin Management"
|
||||
],
|
||||
"default_locale": "zh-CN",
|
||||
"locales_path": "_locales",
|
||||
"plugin_info": {
|
||||
"is_built_in": true,
|
||||
"plugin_type": "plugin_management",
|
||||
"capabilities": [
|
||||
"component.get_all_plugins",
|
||||
"component.list_loaded_plugins",
|
||||
"component.list_registered_plugins",
|
||||
"component.enable",
|
||||
"component.disable",
|
||||
"component.load_plugin",
|
||||
"component.unload_plugin",
|
||||
"component.reload_plugin",
|
||||
"send.text",
|
||||
"config.get"
|
||||
],
|
||||
"components": [
|
||||
{
|
||||
"type": "command",
|
||||
"name": "management",
|
||||
"description": "管理插件和组件的生命周期,包括加载、卸载、启用和禁用等操作。"
|
||||
}
|
||||
"i18n": {
|
||||
"default_locale": "zh-CN",
|
||||
"locales_path": "_locales",
|
||||
"supported_locales": [
|
||||
"zh-CN"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"id": "builtin.plugin-management"
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
通过 /pm 命令管理插件和组件的生命周期。
|
||||
"""
|
||||
|
||||
from maibot_sdk import MaiBotPlugin, Command
|
||||
from maibot_sdk import Command, MaiBotPlugin
|
||||
|
||||
|
||||
_VALID_COMPONENT_TYPES = ("action", "command", "event_handler")
|
||||
@@ -44,6 +44,12 @@ HELP_COMPONENT = (
|
||||
class PluginManagementPlugin(MaiBotPlugin):
|
||||
"""插件和组件管理插件"""
|
||||
|
||||
async def on_load(self) -> None:
|
||||
"""处理插件加载。"""
|
||||
|
||||
async def on_unload(self) -> None:
|
||||
"""处理插件卸载。"""
|
||||
|
||||
@Command(
|
||||
"management",
|
||||
description="管理插件和组件的生命周期",
|
||||
@@ -268,6 +274,25 @@ class PluginManagementPlugin(MaiBotPlugin):
|
||||
return components
|
||||
return []
|
||||
|
||||
async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None:
|
||||
"""处理配置热重载事件。
|
||||
|
||||
Args:
|
||||
scope: 配置变更范围。
|
||||
config_data: 最新配置数据。
|
||||
version: 配置版本号。
|
||||
"""
|
||||
|
||||
del scope
|
||||
del config_data
|
||||
del version
|
||||
|
||||
|
||||
def create_plugin() -> PluginManagementPlugin:
|
||||
"""创建插件管理插件实例。
|
||||
|
||||
Returns:
|
||||
PluginManagementPlugin: 新的插件管理插件实例。
|
||||
"""
|
||||
|
||||
def create_plugin():
|
||||
return PluginManagementPlugin()
|
||||
|
||||
@@ -1,155 +1,640 @@
|
||||
"""
|
||||
发送服务模块
|
||||
发送服务模块。
|
||||
|
||||
提供发送各种类型消息的核心功能。
|
||||
统一封装内部模块的出站消息发送逻辑:
|
||||
|
||||
1. 内部模块统一调用本模块。
|
||||
2. send service 只负责构造和预处理消息。
|
||||
3. 具体走插件链还是 legacy 旧链,由 Platform IO 内部统一决策。
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
|
||||
from maim_message import BaseMessageInfo, GroupInfo as MaimGroupInfo, MessageBase, Seg, UserInfo as MaimUserInfo
|
||||
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
|
||||
from src.chat.utils.utils import get_bot_account
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage
|
||||
from src.common.data_models.message_component_data_model import DictComponent, MessageSequence
|
||||
from src.chat.utils.utils import calculate_typing_time, get_bot_account
|
||||
from src.common.data_models.mai_message_data_model import GroupInfo, MaiMessage, MessageInfo, UserInfo
|
||||
from src.common.data_models.message_component_data_model import (
|
||||
AtComponent,
|
||||
DictComponent,
|
||||
EmojiComponent,
|
||||
ForwardNodeComponent,
|
||||
ImageComponent,
|
||||
MessageSequence,
|
||||
ReplyComponent,
|
||||
StandardMessageComponents,
|
||||
TextComponent,
|
||||
VoiceComponent,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_message import MessageUtils
|
||||
from src.config.config import global_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.platform_io import DeliveryBatch, get_platform_io_manager
|
||||
from src.platform_io.route_key_factory import RouteKeyFactory
|
||||
|
||||
logger = get_logger("send_service")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 内部实现函数
|
||||
# =============================================================================
|
||||
def _inherit_platform_io_route_metadata(target_stream: BotChatSession) -> Dict[str, object]:
|
||||
"""从目标会话继承 Platform IO 路由元数据。
|
||||
|
||||
Args:
|
||||
target_stream: 当前消息要发送到的会话对象。
|
||||
|
||||
Returns:
|
||||
Dict[str, object]: 可安全透传到出站消息 ``additional_config`` 中的
|
||||
路由辅助字段。
|
||||
"""
|
||||
inherited_metadata: Dict[str, object] = {}
|
||||
|
||||
context_message = target_stream.context.message if target_stream.context else None
|
||||
if context_message is not None:
|
||||
additional_config = context_message.message_info.additional_config
|
||||
if isinstance(additional_config, dict):
|
||||
for key in (*RouteKeyFactory.ACCOUNT_ID_KEYS, *RouteKeyFactory.SCOPE_KEYS):
|
||||
value = additional_config.get(key)
|
||||
if value is None:
|
||||
continue
|
||||
normalized_value = str(value).strip()
|
||||
if normalized_value:
|
||||
inherited_metadata[key] = value
|
||||
|
||||
# 当目标会话没有可继承的上下文消息时,至少补齐当前平台账号,
|
||||
# 让按 ``platform + account_id`` 绑定的路由仍有机会命中。
|
||||
if not RouteKeyFactory.extract_components(inherited_metadata)[0]:
|
||||
bot_account = get_bot_account(target_stream.platform)
|
||||
if bot_account:
|
||||
inherited_metadata["platform_io_account_id"] = bot_account
|
||||
|
||||
if target_stream.group_id and (normalized_group_id := str(target_stream.group_id).strip()):
|
||||
inherited_metadata["platform_io_target_group_id"] = normalized_group_id
|
||||
|
||||
if target_stream.user_id and (normalized_user_id := str(target_stream.user_id).strip()):
|
||||
inherited_metadata["platform_io_target_user_id"] = normalized_user_id
|
||||
|
||||
return inherited_metadata
|
||||
|
||||
|
||||
def _build_binary_component_from_base64(component_type: str, raw_data: str) -> StandardMessageComponents:
|
||||
"""根据 Base64 数据构造二进制消息组件。
|
||||
|
||||
Args:
|
||||
component_type: 组件类型名称。
|
||||
raw_data: Base64 编码后的二进制数据。
|
||||
|
||||
Returns:
|
||||
StandardMessageComponents: 转换后的内部消息组件。
|
||||
|
||||
Raises:
|
||||
ValueError: 当组件类型不受支持时抛出。
|
||||
"""
|
||||
binary_data = base64.b64decode(raw_data)
|
||||
binary_hash = hashlib.sha256(binary_data).hexdigest()
|
||||
|
||||
if component_type == "image":
|
||||
return ImageComponent(binary_hash=binary_hash, binary_data=binary_data)
|
||||
if component_type == "emoji":
|
||||
return EmojiComponent(binary_hash=binary_hash, binary_data=binary_data)
|
||||
if component_type == "voice":
|
||||
return VoiceComponent(binary_hash=binary_hash, binary_data=binary_data)
|
||||
raise ValueError(f"不支持的二进制组件类型: {component_type}")
|
||||
|
||||
|
||||
def _build_message_sequence_from_custom_message(
|
||||
message_type: str,
|
||||
content: str | Dict[str, Any],
|
||||
) -> MessageSequence:
|
||||
"""根据自定义消息类型构造内部消息组件序列。
|
||||
|
||||
Args:
|
||||
message_type: 自定义消息类型。
|
||||
content: 自定义消息内容。
|
||||
|
||||
Returns:
|
||||
MessageSequence: 转换后的消息组件序列。
|
||||
"""
|
||||
normalized_type = message_type.strip().lower()
|
||||
|
||||
if normalized_type == "text":
|
||||
return MessageSequence(components=[TextComponent(text=str(content))])
|
||||
|
||||
if normalized_type in {"image", "emoji", "voice"}:
|
||||
return MessageSequence(
|
||||
components=[_build_binary_component_from_base64(normalized_type, str(content))]
|
||||
)
|
||||
|
||||
if normalized_type == "at":
|
||||
return MessageSequence(components=[AtComponent(target_user_id=str(content))])
|
||||
|
||||
if normalized_type == "reply":
|
||||
return MessageSequence(components=[ReplyComponent(target_message_id=str(content))])
|
||||
|
||||
if normalized_type == "dict" and isinstance(content, dict):
|
||||
return MessageSequence(components=[DictComponent(data=deepcopy(content))])
|
||||
|
||||
return MessageSequence(
|
||||
components=[
|
||||
DictComponent(
|
||||
data={
|
||||
"type": normalized_type,
|
||||
"data": deepcopy(content),
|
||||
}
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _clone_message_sequence(message_sequence: MessageSequence) -> MessageSequence:
|
||||
"""复制消息组件序列,避免原对象被发送流程修改。
|
||||
|
||||
Args:
|
||||
message_sequence: 原始消息组件序列。
|
||||
|
||||
Returns:
|
||||
MessageSequence: 深拷贝后的消息组件序列。
|
||||
"""
|
||||
return deepcopy(message_sequence)
|
||||
|
||||
|
||||
def _detect_outbound_message_flags(message_sequence: MessageSequence) -> Dict[str, bool]:
|
||||
"""根据消息组件序列推断出站消息标记。
|
||||
|
||||
Args:
|
||||
message_sequence: 待发送的消息组件序列。
|
||||
|
||||
Returns:
|
||||
Dict[str, bool]: 包含 ``is_emoji``、``is_picture``、``is_command`` 的标记字典。
|
||||
"""
|
||||
if len(message_sequence.components) != 1:
|
||||
return {
|
||||
"is_emoji": False,
|
||||
"is_picture": False,
|
||||
"is_command": False,
|
||||
}
|
||||
|
||||
component = message_sequence.components[0]
|
||||
is_command = False
|
||||
if isinstance(component, DictComponent) and isinstance(component.data, dict):
|
||||
is_command = str(component.data.get("type") or "").strip().lower() == "command"
|
||||
|
||||
return {
|
||||
"is_emoji": isinstance(component, EmojiComponent),
|
||||
"is_picture": isinstance(component, ImageComponent),
|
||||
"is_command": is_command,
|
||||
}
|
||||
|
||||
|
||||
def _describe_message_sequence(message_sequence: MessageSequence) -> str:
|
||||
"""生成消息组件序列的简短描述文本。
|
||||
|
||||
Args:
|
||||
message_sequence: 待描述的消息组件序列。
|
||||
|
||||
Returns:
|
||||
str: 适用于日志的简短类型描述。
|
||||
"""
|
||||
if len(message_sequence.components) != 1:
|
||||
return "message_sequence"
|
||||
|
||||
component = message_sequence.components[0]
|
||||
if isinstance(component, DictComponent) and isinstance(component.data, dict):
|
||||
custom_type = str(component.data.get("type") or "").strip()
|
||||
return custom_type or "dict"
|
||||
|
||||
if isinstance(component, TextComponent):
|
||||
return component.format_name
|
||||
|
||||
if isinstance(component, ImageComponent):
|
||||
return component.format_name
|
||||
|
||||
if isinstance(component, EmojiComponent):
|
||||
return component.format_name
|
||||
|
||||
if isinstance(component, VoiceComponent):
|
||||
return component.format_name
|
||||
|
||||
if isinstance(component, AtComponent):
|
||||
return component.format_name
|
||||
|
||||
if isinstance(component, ReplyComponent):
|
||||
return component.format_name
|
||||
|
||||
if isinstance(component, ForwardNodeComponent):
|
||||
return component.format_name
|
||||
|
||||
return "unknown"
|
||||
|
||||
|
||||
def _build_processed_plain_text(message: SessionMessage) -> str:
|
||||
"""为出站消息构造轻量纯文本摘要。
|
||||
|
||||
Args:
|
||||
message: 待发送的内部消息对象。
|
||||
|
||||
Returns:
|
||||
str: 适用于日志与打字时长估算的纯文本摘要。
|
||||
"""
|
||||
processed_parts: List[str] = []
|
||||
for component in message.raw_message.components:
|
||||
if isinstance(component, TextComponent):
|
||||
processed_parts.append(component.text)
|
||||
continue
|
||||
|
||||
if isinstance(component, ImageComponent):
|
||||
processed_parts.append(component.content or "[图片]")
|
||||
continue
|
||||
|
||||
if isinstance(component, EmojiComponent):
|
||||
processed_parts.append(component.content or "[表情]")
|
||||
continue
|
||||
|
||||
if isinstance(component, VoiceComponent):
|
||||
processed_parts.append(component.content or "[语音]")
|
||||
continue
|
||||
|
||||
if isinstance(component, AtComponent):
|
||||
at_target = component.target_user_cardname or component.target_user_nickname or component.target_user_id
|
||||
processed_parts.append(f"@{at_target}")
|
||||
continue
|
||||
|
||||
if isinstance(component, ReplyComponent):
|
||||
processed_parts.append(component.target_message_content or "[回复消息]")
|
||||
continue
|
||||
|
||||
if isinstance(component, DictComponent):
|
||||
raw_type = component.data.get("type") if isinstance(component.data, dict) else None
|
||||
if isinstance(raw_type, str) and raw_type.strip():
|
||||
processed_parts.append(f"[{raw_type.strip()}消息]")
|
||||
else:
|
||||
processed_parts.append("[自定义消息]")
|
||||
continue
|
||||
|
||||
return " ".join(part for part in processed_parts if part)
|
||||
|
||||
|
||||
def _build_outbound_session_message(
|
||||
message_sequence: MessageSequence,
|
||||
stream_id: str,
|
||||
display_message: str = "",
|
||||
reply_message: Optional[MaiMessage] = None,
|
||||
selected_expressions: Optional[List[int]] = None,
|
||||
) -> Optional[SessionMessage]:
|
||||
"""根据目标会话构建待发送的内部消息对象。
|
||||
|
||||
Args:
|
||||
message_sequence: 待发送的消息组件序列。
|
||||
stream_id: 目标会话 ID。
|
||||
display_message: 用于界面展示的文本内容。
|
||||
reply_message: 被回复的锚点消息。
|
||||
selected_expressions: 可选的表情候选索引列表。
|
||||
|
||||
Returns:
|
||||
Optional[SessionMessage]: 构建成功时返回内部消息对象;若目标会话或
|
||||
机器人账号不存在,则返回 ``None``。
|
||||
"""
|
||||
target_stream = _chat_manager.get_session_by_session_id(stream_id)
|
||||
if target_stream is None:
|
||||
logger.error(f"[SendService] 未找到聊天流: {stream_id}")
|
||||
return None
|
||||
|
||||
bot_user_id = get_bot_account(target_stream.platform)
|
||||
if not bot_user_id:
|
||||
logger.error(f"[SendService] 平台 {target_stream.platform} 未配置机器人账号,无法发送消息")
|
||||
return None
|
||||
|
||||
current_time = time.time()
|
||||
message_id = f"send_api_{int(current_time * 1000)}"
|
||||
anchor_message = reply_message.deepcopy() if reply_message is not None else None
|
||||
|
||||
group_info: Optional[GroupInfo] = None
|
||||
if target_stream.group_id:
|
||||
group_name = ""
|
||||
if (
|
||||
target_stream.context
|
||||
and target_stream.context.message
|
||||
and target_stream.context.message.message_info.group_info
|
||||
):
|
||||
group_name = target_stream.context.message.message_info.group_info.group_name
|
||||
group_info = GroupInfo(
|
||||
group_id=target_stream.group_id,
|
||||
group_name=group_name,
|
||||
)
|
||||
|
||||
additional_config: Dict[str, object] = _inherit_platform_io_route_metadata(target_stream)
|
||||
if selected_expressions is not None:
|
||||
additional_config["selected_expressions"] = selected_expressions
|
||||
|
||||
outbound_message = SessionMessage(
|
||||
message_id=message_id,
|
||||
timestamp=datetime.fromtimestamp(current_time),
|
||||
platform=target_stream.platform,
|
||||
)
|
||||
outbound_message.message_info = MessageInfo(
|
||||
user_info=UserInfo(
|
||||
user_id=bot_user_id,
|
||||
user_nickname=global_config.bot.nickname,
|
||||
),
|
||||
group_info=group_info,
|
||||
additional_config=additional_config,
|
||||
)
|
||||
outbound_message.raw_message = _clone_message_sequence(message_sequence)
|
||||
outbound_message.session_id = target_stream.session_id
|
||||
outbound_message.display_message = display_message
|
||||
outbound_message.reply_to = anchor_message.message_id if anchor_message is not None else None
|
||||
message_flags = _detect_outbound_message_flags(outbound_message.raw_message)
|
||||
outbound_message.is_emoji = message_flags["is_emoji"]
|
||||
outbound_message.is_picture = message_flags["is_picture"]
|
||||
outbound_message.is_command = message_flags["is_command"]
|
||||
outbound_message.initialized = True
|
||||
return outbound_message
|
||||
|
||||
|
||||
def _ensure_reply_component(message: SessionMessage, reply_message_id: str) -> None:
|
||||
"""为消息补充回复组件。
|
||||
|
||||
Args:
|
||||
message: 待发送的内部消息对象。
|
||||
reply_message_id: 被引用消息的 ID。
|
||||
"""
|
||||
if message.raw_message.components:
|
||||
first_component = message.raw_message.components[0]
|
||||
if isinstance(first_component, ReplyComponent) and first_component.target_message_id == reply_message_id:
|
||||
return
|
||||
|
||||
message.raw_message.components.insert(0, ReplyComponent(target_message_id=reply_message_id))
|
||||
|
||||
|
||||
async def _prepare_message_for_platform_io(
|
||||
message: SessionMessage,
|
||||
*,
|
||||
typing: bool,
|
||||
set_reply: bool,
|
||||
reply_message_id: Optional[str],
|
||||
) -> None:
|
||||
"""为 Platform IO 发送链预处理消息。
|
||||
|
||||
Args:
|
||||
message: 待发送的内部消息对象。
|
||||
typing: 是否模拟打字等待。
|
||||
set_reply: 是否构建引用回复组件。
|
||||
reply_message_id: 被引用消息的 ID。
|
||||
|
||||
Raises:
|
||||
ValueError: 当要求设置引用回复但缺少 ``reply_message_id`` 时抛出。
|
||||
"""
|
||||
if set_reply:
|
||||
if not reply_message_id:
|
||||
raise ValueError("set_reply=True 时必须提供 reply_message_id")
|
||||
_ensure_reply_component(message, reply_message_id)
|
||||
|
||||
message.processed_plain_text = _build_processed_plain_text(message)
|
||||
if typing:
|
||||
typing_time = calculate_typing_time(
|
||||
input_string=message.processed_plain_text or "",
|
||||
is_emoji=message.is_emoji,
|
||||
)
|
||||
await asyncio.sleep(typing_time)
|
||||
|
||||
|
||||
def _store_sent_message(message: SessionMessage) -> None:
|
||||
"""将已成功发送的消息写入数据库。
|
||||
|
||||
Args:
|
||||
message: 已成功发送的内部消息对象。
|
||||
"""
|
||||
MessageUtils.store_message_to_db(message)
|
||||
|
||||
|
||||
def _log_platform_io_failures(delivery_batch: DeliveryBatch) -> None:
|
||||
"""输出 Platform IO 批量发送失败详情。
|
||||
|
||||
Args:
|
||||
delivery_batch: Platform IO 返回的批量回执。
|
||||
"""
|
||||
failed_details = "; ".join(
|
||||
f"driver={receipt.driver_id} status={receipt.status} error={receipt.error}"
|
||||
for receipt in delivery_batch.failed_receipts
|
||||
) or "未命中任何发送路由"
|
||||
logger.warning(
|
||||
"[SendService] Platform IO 发送失败: platform=%s %s",
|
||||
delivery_batch.route_key.platform,
|
||||
failed_details,
|
||||
)
|
||||
|
||||
|
||||
async def _send_via_platform_io(
|
||||
message: SessionMessage,
|
||||
*,
|
||||
typing: bool,
|
||||
set_reply: bool,
|
||||
reply_message_id: Optional[str],
|
||||
storage_message: bool,
|
||||
show_log: bool,
|
||||
) -> bool:
|
||||
"""通过 Platform IO 发送消息。
|
||||
|
||||
Args:
|
||||
message: 待发送的内部消息对象。
|
||||
typing: 是否模拟打字等待。
|
||||
set_reply: 是否设置引用回复。
|
||||
reply_message_id: 被引用消息的 ID。
|
||||
storage_message: 发送成功后是否写入数据库。
|
||||
show_log: 是否输出发送成功日志。
|
||||
|
||||
Returns:
|
||||
bool: 发送成功时返回 ``True``。
|
||||
"""
|
||||
platform_io_manager = get_platform_io_manager()
|
||||
try:
|
||||
await platform_io_manager.ensure_send_pipeline_ready()
|
||||
except Exception as exc:
|
||||
logger.error(f"[SendService] 准备 Platform IO 发送管线失败: {exc}")
|
||||
logger.debug(traceback.format_exc())
|
||||
return False
|
||||
|
||||
try:
|
||||
route_key = platform_io_manager.build_route_key_from_message(message)
|
||||
except Exception as exc:
|
||||
logger.warning(f"[SendService] 根据消息构造 Platform IO 路由键失败: {exc}")
|
||||
return False
|
||||
|
||||
try:
|
||||
await _prepare_message_for_platform_io(
|
||||
message,
|
||||
typing=typing,
|
||||
set_reply=set_reply,
|
||||
reply_message_id=reply_message_id,
|
||||
)
|
||||
delivery_batch = await platform_io_manager.send_message(
|
||||
message,
|
||||
route_key,
|
||||
metadata={"show_log": False},
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"[SendService] Platform IO 发送异常: {exc}")
|
||||
logger.debug(traceback.format_exc())
|
||||
return False
|
||||
|
||||
if delivery_batch.has_success:
|
||||
if storage_message:
|
||||
_store_sent_message(message)
|
||||
if show_log:
|
||||
successful_driver_ids = [
|
||||
receipt.driver_id or "unknown"
|
||||
for receipt in delivery_batch.sent_receipts
|
||||
]
|
||||
logger.info(
|
||||
"[SendService] 已通过 Platform IO 将消息发往平台 '%s' (drivers: %s)",
|
||||
route_key.platform,
|
||||
", ".join(successful_driver_ids),
|
||||
)
|
||||
return True
|
||||
|
||||
_log_platform_io_failures(delivery_batch)
|
||||
return False
|
||||
|
||||
|
||||
async def send_session_message(
|
||||
message: SessionMessage,
|
||||
*,
|
||||
typing: bool = False,
|
||||
set_reply: bool = False,
|
||||
reply_message_id: Optional[str] = None,
|
||||
storage_message: bool = True,
|
||||
show_log: bool = True,
|
||||
) -> bool:
|
||||
"""统一发送一条内部消息。
|
||||
|
||||
该方法是内部模块的统一发送入口:
|
||||
|
||||
1. 构造并维护内部消息对象。
|
||||
2. 由 Platform IO 统一决定走插件链还是 legacy 旧链。
|
||||
3. send service 不再自行判断底层发送路径。
|
||||
|
||||
Args:
|
||||
message: 待发送的内部消息对象。
|
||||
typing: 是否模拟打字等待。
|
||||
set_reply: 是否设置引用回复。
|
||||
reply_message_id: 被引用消息的 ID。
|
||||
storage_message: 发送成功后是否写入数据库。
|
||||
show_log: 是否输出发送日志。
|
||||
|
||||
Returns:
|
||||
bool: 发送成功时返回 ``True``,否则返回 ``False``。
|
||||
"""
|
||||
if not message.message_id:
|
||||
logger.error("[SendService] 消息缺少 message_id,无法发送")
|
||||
raise ValueError("消息缺少 message_id,无法发送")
|
||||
|
||||
return await _send_via_platform_io(
|
||||
message,
|
||||
typing=typing,
|
||||
set_reply=set_reply,
|
||||
reply_message_id=reply_message_id,
|
||||
storage_message=storage_message,
|
||||
show_log=show_log,
|
||||
)
|
||||
|
||||
|
||||
async def _send_to_target(
|
||||
message_segment: Seg,
|
||||
message_sequence: MessageSequence,
|
||||
stream_id: str,
|
||||
display_message: str = "",
|
||||
typing: bool = False,
|
||||
set_reply: bool = False,
|
||||
reply_message: Optional["SessionMessage"] = None,
|
||||
reply_message: Optional[MaiMessage] = None,
|
||||
storage_message: bool = True,
|
||||
show_log: bool = True,
|
||||
selected_expressions: Optional[List[int]] = None,
|
||||
) -> bool:
|
||||
"""向指定目标发送消息的内部实现"""
|
||||
"""向指定目标构建并发送消息。
|
||||
|
||||
Args:
|
||||
message_sequence: 待发送的消息组件序列。
|
||||
stream_id: 目标会话 ID。
|
||||
display_message: 用于界面展示的文本内容。
|
||||
typing: 是否显示输入中状态。
|
||||
set_reply: 是否在发送时附带引用回复。
|
||||
reply_message: 被回复的消息对象。
|
||||
storage_message: 是否将发送结果写入消息存储。
|
||||
show_log: 是否输出发送日志。
|
||||
selected_expressions: 可选的表情候选索引列表。
|
||||
|
||||
Returns:
|
||||
bool: 发送成功返回 ``True``,否则返回 ``False``。
|
||||
"""
|
||||
try:
|
||||
if set_reply and not reply_message:
|
||||
if set_reply and reply_message is None:
|
||||
logger.warning("[SendService] 使用引用回复,但未提供回复消息")
|
||||
return False
|
||||
|
||||
if show_log:
|
||||
logger.debug(f"[SendService] 发送{message_segment.type}消息到 {stream_id}")
|
||||
logger.debug(f"[SendService] 发送{_describe_message_sequence(message_sequence)}消息到 {stream_id}")
|
||||
|
||||
target_stream = _chat_manager.get_session_by_session_id(stream_id)
|
||||
if not target_stream:
|
||||
logger.error(f"[SendService] 未找到聊天流: {stream_id}")
|
||||
return False
|
||||
|
||||
message_sender = UniversalMessageSender()
|
||||
|
||||
current_time = time.time()
|
||||
message_id = f"send_api_{int(current_time * 1000)}"
|
||||
|
||||
anchor_message: Optional[MaiMessage] = None
|
||||
if reply_message:
|
||||
anchor_message = reply_message.deepcopy()
|
||||
if anchor_message:
|
||||
logger.debug(
|
||||
f"[SendService] 找到匹配的回复消息,发送者: {anchor_message.message_info.user_info.user_id}"
|
||||
)
|
||||
|
||||
group_info = None
|
||||
if target_stream.group_id:
|
||||
group_name = ""
|
||||
if target_stream.context and target_stream.context.message and target_stream.context.message.message_info.group_info:
|
||||
group_name = target_stream.context.message.message_info.group_info.group_name
|
||||
group_info = MaimGroupInfo(
|
||||
group_id=target_stream.group_id,
|
||||
group_name=group_name,
|
||||
platform=target_stream.platform,
|
||||
)
|
||||
|
||||
additional_config: dict[str, object] = {}
|
||||
if selected_expressions is not None:
|
||||
additional_config["selected_expressions"] = selected_expressions
|
||||
bot_user_id = get_bot_account(target_stream.platform)
|
||||
if not bot_user_id:
|
||||
logger.error(f"[SendService] 平台 {target_stream.platform} 未配置机器人账号,无法发送消息")
|
||||
return False
|
||||
|
||||
maim_message = MessageBase(
|
||||
message_info=BaseMessageInfo(
|
||||
platform=target_stream.platform,
|
||||
message_id=message_id,
|
||||
time=current_time,
|
||||
user_info=MaimUserInfo(
|
||||
user_id=bot_user_id,
|
||||
user_nickname=global_config.bot.nickname,
|
||||
platform=target_stream.platform,
|
||||
),
|
||||
group_info=group_info,
|
||||
additional_config=additional_config,
|
||||
),
|
||||
message_segment=message_segment,
|
||||
outbound_message = _build_outbound_session_message(
|
||||
message_sequence=message_sequence,
|
||||
stream_id=stream_id,
|
||||
display_message=display_message,
|
||||
reply_message=reply_message,
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
bot_message = SessionMessage.from_maim_message(maim_message)
|
||||
bot_message.session_id = target_stream.session_id
|
||||
bot_message.display_message = display_message
|
||||
bot_message.reply_to = anchor_message.message_id if anchor_message else None
|
||||
bot_message.is_emoji = message_segment.type == "emoji"
|
||||
bot_message.is_picture = message_segment.type == "image"
|
||||
bot_message.is_command = message_segment.type == "command"
|
||||
if outbound_message is None:
|
||||
return False
|
||||
|
||||
sent_msg = await message_sender.send_message(
|
||||
bot_message,
|
||||
sent = await send_session_message(
|
||||
outbound_message,
|
||||
typing=typing,
|
||||
set_reply=set_reply,
|
||||
reply_message_id=anchor_message.message_id if anchor_message else None,
|
||||
reply_message_id=reply_message.message_id if reply_message is not None else None,
|
||||
storage_message=storage_message,
|
||||
show_log=show_log,
|
||||
)
|
||||
|
||||
if sent_msg:
|
||||
if sent:
|
||||
logger.debug(f"[SendService] 成功发送消息到 {stream_id}")
|
||||
return True
|
||||
else:
|
||||
logger.error("[SendService] 发送消息失败")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SendService] 发送消息时出错: {e}")
|
||||
logger.error("[SendService] 发送消息失败")
|
||||
return False
|
||||
except Exception as exc:
|
||||
logger.error(f"[SendService] 发送消息时出错: {exc}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 公共函数 - 预定义类型的发送函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def text_to_stream(
|
||||
text: str,
|
||||
stream_id: str,
|
||||
typing: bool = False,
|
||||
set_reply: bool = False,
|
||||
reply_message: Optional["SessionMessage"] = None,
|
||||
reply_message: Optional[MaiMessage] = None,
|
||||
storage_message: bool = True,
|
||||
selected_expressions: Optional[List[int]] = None,
|
||||
) -> bool:
|
||||
"""向指定流发送文本消息"""
|
||||
"""向指定流发送文本消息。
|
||||
|
||||
Args:
|
||||
text: 要发送的文本内容。
|
||||
stream_id: 目标会话 ID。
|
||||
typing: 是否显示输入中状态。
|
||||
set_reply: 是否附带引用回复。
|
||||
reply_message: 被回复的消息对象。
|
||||
storage_message: 是否在发送成功后写入数据库。
|
||||
selected_expressions: 可选的表情候选索引列表。
|
||||
|
||||
Returns:
|
||||
bool: 发送成功时返回 ``True``。
|
||||
"""
|
||||
return await _send_to_target(
|
||||
message_segment=Seg(type="text", data=text),
|
||||
message_sequence=MessageSequence(components=[TextComponent(text=text)]),
|
||||
stream_id=stream_id,
|
||||
display_message="",
|
||||
typing=typing,
|
||||
@@ -165,11 +650,22 @@ async def emoji_to_stream(
|
||||
stream_id: str,
|
||||
storage_message: bool = True,
|
||||
set_reply: bool = False,
|
||||
reply_message: Optional["SessionMessage"] = None,
|
||||
reply_message: Optional[MaiMessage] = None,
|
||||
) -> bool:
|
||||
"""向指定流发送表情包"""
|
||||
"""向指定流发送表情消息。
|
||||
|
||||
Args:
|
||||
emoji_base64: 表情图片的 Base64 内容。
|
||||
stream_id: 目标会话 ID。
|
||||
storage_message: 是否在发送成功后写入数据库。
|
||||
set_reply: 是否附带引用回复。
|
||||
reply_message: 被回复的消息对象。
|
||||
|
||||
Returns:
|
||||
bool: 发送成功时返回 ``True``。
|
||||
"""
|
||||
return await _send_to_target(
|
||||
message_segment=Seg(type="emoji", data=emoji_base64),
|
||||
message_sequence=_build_message_sequence_from_custom_message("emoji", emoji_base64),
|
||||
stream_id=stream_id,
|
||||
display_message="",
|
||||
typing=False,
|
||||
@@ -184,11 +680,22 @@ async def image_to_stream(
|
||||
stream_id: str,
|
||||
storage_message: bool = True,
|
||||
set_reply: bool = False,
|
||||
reply_message: Optional["SessionMessage"] = None,
|
||||
reply_message: Optional[MaiMessage] = None,
|
||||
) -> bool:
|
||||
"""向指定流发送图片"""
|
||||
"""向指定流发送图片消息。
|
||||
|
||||
Args:
|
||||
image_base64: 图片的 Base64 内容。
|
||||
stream_id: 目标会话 ID。
|
||||
storage_message: 是否在发送成功后写入数据库。
|
||||
set_reply: 是否附带引用回复。
|
||||
reply_message: 被回复的消息对象。
|
||||
|
||||
Returns:
|
||||
bool: 发送成功时返回 ``True``。
|
||||
"""
|
||||
return await _send_to_target(
|
||||
message_segment=Seg(type="image", data=image_base64),
|
||||
message_sequence=_build_message_sequence_from_custom_message("image", image_base64),
|
||||
stream_id=stream_id,
|
||||
display_message="",
|
||||
typing=False,
|
||||
@@ -200,18 +707,33 @@ async def image_to_stream(
|
||||
|
||||
async def custom_to_stream(
|
||||
message_type: str,
|
||||
content: str | Dict,
|
||||
content: str | Dict[str, Any],
|
||||
stream_id: str,
|
||||
display_message: str = "",
|
||||
typing: bool = False,
|
||||
reply_message: Optional["SessionMessage"] = None,
|
||||
reply_message: Optional[MaiMessage] = None,
|
||||
set_reply: bool = False,
|
||||
storage_message: bool = True,
|
||||
show_log: bool = True,
|
||||
) -> bool:
|
||||
"""向指定流发送自定义类型消息"""
|
||||
"""向指定流发送自定义类型消息。
|
||||
|
||||
Args:
|
||||
message_type: 自定义消息类型。
|
||||
content: 自定义消息内容。
|
||||
stream_id: 目标会话 ID。
|
||||
display_message: 用于展示的文本内容。
|
||||
typing: 是否显示输入中状态。
|
||||
reply_message: 被回复的消息对象。
|
||||
set_reply: 是否附带引用回复。
|
||||
storage_message: 是否在发送成功后写入数据库。
|
||||
show_log: 是否输出发送日志。
|
||||
|
||||
Returns:
|
||||
bool: 发送成功时返回 ``True``。
|
||||
"""
|
||||
return await _send_to_target(
|
||||
message_segment=Seg(type=message_type, data=content), # type: ignore
|
||||
message_sequence=_build_message_sequence_from_custom_message(message_type, content),
|
||||
stream_id=stream_id,
|
||||
display_message=display_message,
|
||||
typing=typing,
|
||||
@@ -227,31 +749,33 @@ async def custom_reply_set_to_stream(
|
||||
stream_id: str,
|
||||
display_message: str = "",
|
||||
typing: bool = False,
|
||||
reply_message: Optional["SessionMessage"] = None,
|
||||
reply_message: Optional[MaiMessage] = None,
|
||||
set_reply: bool = False,
|
||||
storage_message: bool = True,
|
||||
show_log: bool = True,
|
||||
) -> bool:
|
||||
"""向指定流发送消息组件序列。"""
|
||||
flag: bool = True
|
||||
for component in reply_set.components:
|
||||
if isinstance(component, DictComponent):
|
||||
message_seg = Seg(type="dict", data=component.data) # type: ignore
|
||||
else:
|
||||
message_seg = await component.to_seg()
|
||||
status = await _send_to_target(
|
||||
message_segment=message_seg,
|
||||
stream_id=stream_id,
|
||||
display_message=display_message,
|
||||
typing=typing,
|
||||
reply_message=reply_message,
|
||||
set_reply=set_reply,
|
||||
storage_message=storage_message,
|
||||
show_log=show_log,
|
||||
)
|
||||
if not status:
|
||||
flag = False
|
||||
logger.error(f"[SendService] 发送消息组件失败,组件类型:{type(component).__name__}")
|
||||
set_reply = False
|
||||
"""向指定流发送消息组件序列。
|
||||
|
||||
return flag
|
||||
Args:
|
||||
reply_set: 待发送的消息组件序列。
|
||||
stream_id: 目标会话 ID。
|
||||
display_message: 用于展示的文本内容。
|
||||
typing: 是否显示输入中状态。
|
||||
reply_message: 被回复的消息对象。
|
||||
set_reply: 是否附带引用回复。
|
||||
storage_message: 是否在发送成功后写入数据库。
|
||||
show_log: 是否输出发送日志。
|
||||
|
||||
Returns:
|
||||
bool: 发送成功时返回 ``True``。
|
||||
"""
|
||||
return await _send_to_target(
|
||||
message_sequence=reply_set,
|
||||
stream_id=stream_id,
|
||||
display_message=display_message,
|
||||
typing=typing,
|
||||
reply_message=reply_message,
|
||||
set_reply=set_reply,
|
||||
storage_message=storage_message,
|
||||
show_log=show_log,
|
||||
)
|
||||
|
||||
175
src/webui/routers/chat/serializers.py
Normal file
175
src/webui/routers/chat/serializers.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""提供 WebUI 聊天路由使用的消息序列化能力。"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import base64
|
||||
|
||||
from src.common.data_models.message_component_data_model import (
|
||||
AtComponent,
|
||||
DictComponent,
|
||||
EmojiComponent,
|
||||
ForwardComponent,
|
||||
ForwardNodeComponent,
|
||||
ImageComponent,
|
||||
MessageSequence,
|
||||
ReplyComponent,
|
||||
StandardMessageComponents,
|
||||
TextComponent,
|
||||
VoiceComponent,
|
||||
)
|
||||
|
||||
|
||||
def serialize_message_sequence(message_sequence: MessageSequence) -> List[Dict[str, Any]]:
|
||||
"""将内部统一消息组件序列转换为 WebUI 富文本消息段。
|
||||
|
||||
Args:
|
||||
message_sequence: 内部统一消息组件序列。
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 可直接广播给 WebUI 前端的消息段列表。
|
||||
"""
|
||||
serialized_segments: List[Dict[str, Any]] = []
|
||||
for component in message_sequence.components:
|
||||
serialized_segment = serialize_message_component(component)
|
||||
if serialized_segment is not None:
|
||||
serialized_segments.append(serialized_segment)
|
||||
return serialized_segments
|
||||
|
||||
|
||||
def serialize_message_component(component: StandardMessageComponents) -> Optional[Dict[str, Any]]:
|
||||
"""将单个内部消息组件转换为 WebUI 消息段。
|
||||
|
||||
Args:
|
||||
component: 待序列化的内部消息组件。
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 序列化后的 WebUI 消息段;若组件不应展示则返回 ``None``。
|
||||
"""
|
||||
if isinstance(component, TextComponent):
|
||||
return {"type": "text", "data": component.text}
|
||||
|
||||
if isinstance(component, ImageComponent):
|
||||
return _serialize_binary_component(
|
||||
segment_type="image",
|
||||
mime_type="image/png",
|
||||
binary_data=component.binary_data,
|
||||
fallback_text=component.content,
|
||||
)
|
||||
|
||||
if isinstance(component, EmojiComponent):
|
||||
return _serialize_binary_component(
|
||||
segment_type="emoji",
|
||||
mime_type="image/gif",
|
||||
binary_data=component.binary_data,
|
||||
fallback_text=component.content,
|
||||
)
|
||||
|
||||
if isinstance(component, VoiceComponent):
|
||||
return _serialize_binary_component(
|
||||
segment_type="voice",
|
||||
mime_type="audio/wav",
|
||||
binary_data=component.binary_data,
|
||||
fallback_text=component.content,
|
||||
)
|
||||
|
||||
if isinstance(component, AtComponent):
|
||||
return {
|
||||
"type": "at",
|
||||
"data": {
|
||||
"target_user_id": component.target_user_id,
|
||||
"target_user_nickname": component.target_user_nickname,
|
||||
"target_user_cardname": component.target_user_cardname,
|
||||
},
|
||||
}
|
||||
|
||||
if isinstance(component, ReplyComponent):
|
||||
return {
|
||||
"type": "reply",
|
||||
"data": {
|
||||
"target_message_id": component.target_message_id,
|
||||
"target_message_content": component.target_message_content,
|
||||
"target_message_sender_id": component.target_message_sender_id,
|
||||
"target_message_sender_nickname": component.target_message_sender_nickname,
|
||||
"target_message_sender_cardname": component.target_message_sender_cardname,
|
||||
},
|
||||
}
|
||||
|
||||
if isinstance(component, ForwardNodeComponent):
|
||||
return {
|
||||
"type": "forward",
|
||||
"data": [_serialize_forward_component(item) for item in component.forward_components],
|
||||
}
|
||||
|
||||
if isinstance(component, DictComponent):
|
||||
return _serialize_dict_component(component.data)
|
||||
|
||||
return {"type": "unknown", "data": str(component)}
|
||||
|
||||
|
||||
def _serialize_binary_component(
|
||||
segment_type: str,
|
||||
mime_type: str,
|
||||
binary_data: bytes,
|
||||
fallback_text: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""序列化带二进制负载的消息组件。
|
||||
|
||||
Args:
|
||||
segment_type: WebUI 消息段类型。
|
||||
mime_type: 对应的数据 MIME 类型。
|
||||
binary_data: 组件二进制数据。
|
||||
fallback_text: 二进制缺失时可退化展示的文本。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 序列化后的 WebUI 消息段。
|
||||
"""
|
||||
if binary_data:
|
||||
encoded_payload = base64.b64encode(binary_data).decode()
|
||||
return {"type": segment_type, "data": f"data:{mime_type};base64,{encoded_payload}"}
|
||||
|
||||
if fallback_text:
|
||||
return {"type": "text", "data": fallback_text}
|
||||
|
||||
return {"type": "unknown", "original_type": segment_type, "data": ""}
|
||||
|
||||
|
||||
def _serialize_forward_component(component: ForwardComponent) -> Dict[str, Any]:
|
||||
"""序列化单个转发节点。
|
||||
|
||||
Args:
|
||||
component: 待序列化的转发节点组件。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: WebUI 可消费的转发节点字典。
|
||||
"""
|
||||
return {
|
||||
"message_id": component.message_id,
|
||||
"user_id": component.user_id,
|
||||
"user_nickname": component.user_nickname,
|
||||
"user_cardname": component.user_cardname,
|
||||
"content": serialize_message_sequence(MessageSequence(component.content)),
|
||||
}
|
||||
|
||||
|
||||
def _serialize_dict_component(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""最佳努力地序列化非标准字典组件。
|
||||
|
||||
Args:
|
||||
data: 原始字典组件内容。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 序列化后的 WebUI 消息段。
|
||||
"""
|
||||
raw_type = str(data.get("type") or "dict").strip()
|
||||
raw_payload = data.get("data", data)
|
||||
|
||||
if raw_type in {"text", "image", "emoji", "voice", "video", "file", "music", "face"}:
|
||||
return {"type": raw_type, "data": raw_payload}
|
||||
|
||||
if raw_type == "reply":
|
||||
return {"type": "reply", "data": raw_payload}
|
||||
|
||||
if raw_type == "forward" and isinstance(raw_payload, list):
|
||||
return {"type": "forward", "data": raw_payload}
|
||||
|
||||
return {"type": "unknown", "original_type": raw_type, "data": raw_payload}
|
||||
Reference in New Issue
Block a user