diff --git a/bot.py b/bot.py index e182c1eb..4414f551 100644 --- a/bot.py +++ b/bot.py @@ -195,11 +195,11 @@ async def graceful_shutdown(): # sourcery skip: use-named-expression # except Exception as e: # logger.warning(f"关闭 WebUI 服务器时出错: {e}") - from src.plugin_system.core.events_manager import events_manager - from src.plugin_system.base.component_types import EventType + from src.core.event_bus import event_bus + from src.core.types import EventType # 触发 ON_STOP 事件 - await events_manager.handle_mai_events(event_type=EventType.ON_STOP) + await event_bus.emit(event_type=EventType.ON_STOP) # 停止新版本插件运行时 from src.plugin_runtime.integration import get_plugin_runtime_manager diff --git a/plugins/ChatFrequency/_manifest.json b/plugins/ChatFrequency/_manifest.json index 669de6b8..241242ed 100644 --- a/plugins/ChatFrequency/_manifest.json +++ b/plugins/ChatFrequency/_manifest.json @@ -9,7 +9,7 @@ }, "license": "GPL-v3.0-or-later", "host_application": { - "min_version": "0.10.3" + "min_version": "1.0.0" }, "homepage_url": "https://github.com/SengokuCola/BetterFrequency", "repository_url": "https://github.com/SengokuCola/BetterFrequency", diff --git a/plugins/ChatFrequency/plugin.py b/plugins/ChatFrequency/plugin.py index e2231c86..97fae04f 100644 --- a/plugins/ChatFrequency/plugin.py +++ b/plugins/ChatFrequency/plugin.py @@ -1,144 +1,87 @@ -from typing import List, Tuple, Type, Optional -from src.plugin_system import BasePlugin, register_plugin, BaseCommand, ComponentInfo, ConfigField -from src.plugin_system.apis import send_api, frequency_api +"""发言频率控制插件 — 新 SDK 版本 + +通过 /chat 命令设置和查看聊天频率。 +""" + +from maibot_sdk import MaiBotPlugin, Command -class SetTalkFrequencyCommand(BaseCommand): - """设置当前聊天的talk_frequency值""" +class BetterFrequencyPlugin(MaiBotPlugin): + """聊天频率控制插件""" - command_name = "set_talk_frequency" - command_description = "设置当前聊天的talk_frequency值:/chat talk_frequency <数字> 或 /chat t <数字>" - command_pattern = r"^/chat\s+(?:talk_frequency|t)\s+(?P[+-]?\d*\.?\d+)$" + @Command( + "set_talk_frequency", + description="设置当前聊天的talk_frequency值:/chat talk_frequency <数字> 或 /chat t <数字>", + pattern=r"^/chat\s+(?:talk_frequency|t)\s+(?P[+-]?\d*\.?\d+)$", + ) + async def handle_set_talk_frequency( + self, stream_id: str = "", matched_groups: dict | None = None, **kwargs + ): + """设置当前聊天的 talk_frequency""" + if not matched_groups or "value" not in matched_groups: + return False, "命令格式错误", False + + value_str = matched_groups["value"] + if not value_str: + return False, "无法获取数值参数", False - async def execute(self) -> Tuple[bool, Optional[str], bool]: try: - # 获取命令参数 - 使用命名捕获组 - if not self.matched_groups or "value" not in self.matched_groups: - return False, "命令格式错误", False - - value_str = self.matched_groups["value"] - if not value_str: - return False, "无法获取数值参数", False - value = float(value_str) - - # 获取聊天流ID - if not self.message.chat_stream or not hasattr(self.message.chat_stream, "stream_id"): - return False, "无法获取聊天流信息", False - - chat_id = self.message.chat_stream.stream_id - - # 设置talk_frequency - frequency_api.set_talk_frequency_adjust(chat_id, value) - - final_value = frequency_api.get_current_talk_value(chat_id) - adjust_value = frequency_api.get_talk_frequency_adjust(chat_id) - base_value = final_value / adjust_value - - # 发送反馈消息(不保存到数据库) - await send_api.text_to_stream( - f"已设置当前聊天的talk_frequency调整值为: {value}\n当前talk_value: {final_value:.2f}\n发言频率调整: {adjust_value:.2f}\n基础值: {base_value:.2f}", - chat_id, - storage_message=False, - ) - - return True, None, False - except ValueError: - error_msg = "数值格式错误,请输入有效的数字" - await self.send_text(error_msg, storage_message=False) - return False, error_msg, False - except Exception as e: - error_msg = f"设置talk_frequency失败: {str(e)}" - await self.send_text(error_msg, storage_message=False) - return False, error_msg, False + await self.ctx.send.text("数值格式错误,请输入有效的数字", stream_id) + return False, "数值格式错误", False + + if not stream_id: + return False, "无法获取聊天流信息", False + + # 设置 talk_frequency + await self.ctx.frequency.set_adjust(stream_id, value) + + # 获取当前状态 + current = await self.ctx.frequency.get_current_talk_value(stream_id) + current_val = current if isinstance(current, (int, float)) else 0 + adjust = await self.ctx.frequency.get_adjust(stream_id) + adjust_val = adjust if isinstance(adjust, (int, float)) else 1 + base_val = current_val / adjust_val if adjust_val else 0 + + msg = ( + f"已设置当前聊天的talk_frequency调整值为: {value}\n" + f"当前talk_value: {current_val:.2f}\n" + f"发言频率调整: {adjust_val:.2f}\n" + f"基础值: {base_val:.2f}" + ) + await self.ctx.send.text(msg, stream_id) + return True, None, False + + @Command( + "show_frequency", + description="显示当前聊天的频率控制状态:/chat show 或 /chat s", + pattern=r"^/chat\s+(?:show|s)$", + ) + async def handle_show_frequency(self, stream_id: str = "", **kwargs): + """显示当前频率控制状态""" + if not stream_id: + return False, "无法获取聊天流信息", False + + current = await self.ctx.frequency.get_current_talk_value(stream_id) + current_val = current if isinstance(current, (int, float)) else 0 + adjust = await self.ctx.frequency.get_adjust(stream_id) + adjust_val = adjust if isinstance(adjust, (int, float)) else 1 + base_val = current_val / adjust_val if adjust_val else 0 + + status_msg = ( + "当前聊天频率控制状态\n" + "Talk Value (发言频率):\n\n" + f" • 基础值: {base_val:.2f}\n" + f" • 发言频率调整: {adjust_val:.2f}\n" + f" • 当前值: {current_val:.2f}\n\n" + "使用命令:\n" + " • /chat talk_frequency <数字> 或 /chat t <数字> - 设置发言频率调整\n" + " • /chat show 或 /chat s - 显示当前状态" + ) + await self.ctx.send.text(status_msg, stream_id) + return True, None, False -class ShowFrequencyCommand(BaseCommand): - """显示当前聊天的频率控制状态""" - - command_name = "show_frequency" - command_description = "显示当前聊天的频率控制状态:/chat show 或 /chat s" - command_pattern = r"^/chat\s+(?:show|s)$" - - async def execute(self) -> Tuple[bool, Optional[str], bool]: - try: - # 获取聊天流ID - if not self.message.chat_stream or not hasattr(self.message.chat_stream, "stream_id"): - return False, "无法获取聊天流信息", False - - chat_id = self.message.chat_stream.stream_id - - # 获取当前频率控制状态 - current_talk_frequency = frequency_api.get_current_talk_value(chat_id) - talk_frequency_adjust = frequency_api.get_talk_frequency_adjust(chat_id) - base_value = current_talk_frequency / talk_frequency_adjust - - # 构建显示消息 - status_msg = f"""当前聊天频率控制状态 -Talk Value (发言频率): - - • 基础值: {base_value:.2f} - • 发言频率调整: {talk_frequency_adjust:.2f} - • 当前值: {current_talk_frequency:.2f} - -使用命令: - • /chat talk_frequency <数字> 或 /chat t <数字> - 设置发言频率调整 - • /chat show 或 /chat s - 显示当前状态""" - - # 发送状态消息(不保存到数据库) - await send_api.text_to_stream(status_msg, chat_id, storage_message=False) - - return True, None, False - - except Exception as e: - error_msg = f"获取频率控制状态失败: {str(e)}" - # 使用内置的send_text方法发送错误消息 - await self.send_text(error_msg, storage_message=False) - return False, error_msg, False - - -# ===== 插件注册 ===== - - -@register_plugin -class BetterFrequencyPlugin(BasePlugin): - """BetterFrequency插件 - 控制聊天频率的插件""" - - # 插件基本信息 - plugin_name: str = "better_frequency_plugin" - enable_plugin: bool = True - dependencies: List[str] = [] - python_dependencies: List[str] = [] - config_file_name: str = "config.toml" - - # 配置节描述 - config_section_descriptions = {"plugin": "插件基本信息", "frequency": "频率控制配置", "features": "功能开关配置"} - - # 配置Schema定义 - config_schema: dict = { - "plugin": { - "name": ConfigField(type=str, default="better_frequency_plugin", description="插件名称"), - "version": ConfigField(type=str, default="1.0.2", description="插件版本"), - "enabled": ConfigField(type=bool, default=True, description="是否启用插件"), - }, - "frequency": { - "default_talk_adjust": ConfigField(type=float, default=1.0, description="默认talk_frequency调整值"), - "max_adjust_value": ConfigField(type=float, default=1.0, description="最大调整值"), - "min_adjust_value": ConfigField(type=float, default=0.0, description="最小调整值"), - }, - } - - def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: - components = [] - - # 根据配置决定是否注册命令组件 - if self.config.get("features", {}).get("enable_commands", True): - components.extend( - [ - (SetTalkFrequencyCommand.get_command_info(), SetTalkFrequencyCommand), - (ShowFrequencyCommand.get_command_info(), ShowFrequencyCommand), - ] - ) - - return components +def create_plugin(): + return BetterFrequencyPlugin() diff --git a/plugins/emoji_manage_plugin/_manifest.json b/plugins/emoji_manage_plugin/_manifest.json index 68f5c679..3af69023 100644 --- a/plugins/emoji_manage_plugin/_manifest.json +++ b/plugins/emoji_manage_plugin/_manifest.json @@ -1,7 +1,7 @@ { "manifest_version": 1, "name": "BetterEmoji", - "version": "1.0.0", + "version": "2.0.0", "description": "更好的表情包管理插件", "author": { "name": "SengokuCola", @@ -9,7 +9,7 @@ }, "license": "GPL-v3.0-or-later", "host_application": { - "min_version": "0.10.4" + "min_version": "1.0.0" }, "homepage_url": "https://github.com/SengokuCola/BetterEmoji", "repository_url": "https://github.com/SengokuCola/BetterEmoji", @@ -19,46 +19,49 @@ "plugin" ], "categories": [ - "Examples", - "Tutorial" + "Emoji", + "Management" ], "default_locale": "zh-CN", "locales_path": "_locales", "plugin_info": { "is_built_in": false, "plugin_type": "emoji_manage", + "capabilities": [ + "emoji.get_random", + "emoji.get_count", + "emoji.get_info", + "emoji.get_all", + "emoji.register_emoji", + "emoji.delete_emoji", + "send.text", + "send.forward" + ], "components": [ { - "type": "action", - "name": "hello_greeting", - "description": "向用户发送问候消息" - }, - { - "type": "action", - "name": "bye_greeting", - "description": "向用户发送告别消息", - "activation_modes": [ - "keyword" - ], - "keywords": [ - "再见", - "bye", - "88", - "拜拜" - ] + "type": "command", + "name": "add_emoji", + "description": "添加表情包", + "pattern": "/emoji add" }, { "type": "command", - "name": "time", - "description": "查询当前时间", - "pattern": "/time" + "name": "emoji_list", + "description": "列表表情包", + "pattern": "/emoji list" + }, + { + "type": "command", + "name": "delete_emoji", + "description": "删除表情包", + "pattern": "/emoji delete" + }, + { + "type": "command", + "name": "random_emojis", + "description": "发送多张随机表情包", + "pattern": "/random_emojis" } - ], - "features": [ - "问候和告别功能", - "时间查询命令", - "配置文件示例", - "新手教程代码" ] }, "id": "SengokuCola.BetterEmoji" diff --git a/plugins/emoji_manage_plugin/plugin.py b/plugins/emoji_manage_plugin/plugin.py index 453afbc2..89c3a3cb 100644 --- a/plugins/emoji_manage_plugin/plugin.py +++ b/plugins/emoji_manage_plugin/plugin.py @@ -1,399 +1,216 @@ -from typing import List, Tuple, Type -from src.plugin_system import ( - BasePlugin, - register_plugin, - BaseCommand, - ComponentInfo, - ConfigField, - ReplyContentType, - emoji_api, -) -from maim_message import Seg -from src.common.logger import get_logger +"""表情包管理插件 — 新 SDK 版本 -logger = get_logger("emoji_manage_plugin") +通过 /emoji 命令管理表情包的添加、列表和删除。 +""" + +import base64 +import datetime +import hashlib +import re + +from maibot_sdk import MaiBotPlugin, Command -class AddEmojiCommand(BaseCommand): - command_name = "add_emoji" - command_description = "添加表情包" - command_pattern = r".*/emoji add.*" +class EmojiManagePlugin(MaiBotPlugin): + """表情包管理插件""" - async def execute(self) -> Tuple[bool, str, bool]: - # 查找消息中的表情包 - # logger.info(f"查找消息中的表情包: {self.message.message_segment}") + # ===== 工具方法 ===== - emoji_base64_list = self.find_and_return_emoji_in_message(self.message.message_segment) + @staticmethod + def _extract_emoji_base64(segments) -> list[str]: + """从消息 segments 中提取 emoji/image 的 base64 数据。 + segments 可以是 dict 列表或 Seg 对象列表(兼容两种格式)。 + """ + results: list[str] = [] + if not segments: + return results + + if isinstance(segments, dict): + seg_type = segments.get("type", "") + if seg_type in ("emoji", "image"): + data = segments.get("data", "") + if data: + results.append(data) + elif seg_type == "seglist": + for child in segments.get("data", []): + results.extend(EmojiManagePlugin._extract_emoji_base64(child)) + return results + + # 如果有 .type 属性(Seg 对象) + if hasattr(segments, "type"): + seg_type = getattr(segments, "type", "") + if seg_type in ("emoji", "image"): + results.append(getattr(segments, "data", "")) + elif seg_type == "seglist": + for child in getattr(segments, "data", []): + results.extend(EmojiManagePlugin._extract_emoji_base64(child)) + return results + + # 列表 + for seg in segments: + results.extend(EmojiManagePlugin._extract_emoji_base64(seg)) + return results + + # ===== Command 组件 ===== + + @Command("add_emoji", description="添加表情包", pattern=r".*/emoji add.*") + async def handle_add_emoji(self, stream_id: str = "", message_segments=None, **kwargs): + """添加表情包""" + emoji_base64_list = self._extract_emoji_base64(message_segments) if not emoji_base64_list: + await self.ctx.send.text("未在消息中找到表情包或图片", stream_id) return False, "未在消息中找到表情包或图片", False - # 注册找到的表情包 success_count = 0 fail_count = 0 results = [] - for i, emoji_base64 in enumerate(emoji_base64_list): - try: - # 使用emoji_api注册表情包(让API自动生成唯一文件名) - result = await emoji_api.register_emoji(emoji_base64) - - if result["success"]: - success_count += 1 - description = result.get("description", "未知描述") - emotions = result.get("emotions", []) - replaced = result.get("replaced", False) - - result_msg = f"表情包 {i + 1} 注册成功{'(替换旧表情包)' if replaced else '(新增表情包)'}" - if description: - result_msg += f"\n描述: {description}" - if emotions: - result_msg += f"\n情感标签: {', '.join(emotions)}" - - results.append(result_msg) - else: - fail_count += 1 - error_msg = result.get("message", "注册失败") - results.append(f"表情包 {i + 1} 注册失败: {error_msg}") - - except Exception as e: - fail_count += 1 - results.append(f"表情包 {i + 1} 注册时发生错误: {str(e)}") - - # 构建返回消息 - total_count = success_count + fail_count - summary_msg = f"表情包注册完成: 成功 {success_count} 个,失败 {fail_count} 个,共处理 {total_count} 个" - - # 如果有结果详情,添加到返回消息中 - details_msg = "" - if results: - details_msg = "\n" + "\n".join(results) - final_msg = summary_msg + details_msg - else: - final_msg = summary_msg - - # 使用表达器重写回复 - try: - from src.plugin_system.apis import generator_api - - # 构建重写数据 - rewrite_data = { - "raw_reply": summary_msg, - "reason": f"注册了表情包:{details_msg}\n", - } - - # 调用表达器重写 - result_status, data = await generator_api.rewrite_reply( - chat_stream=self.message.chat_stream, - reply_data=rewrite_data, - ) - - if result_status: - # 发送重写后的回复 - for reply_seg in data.reply_set.reply_data: - send_data = reply_seg.content - await self.send_text(send_data) - - return success_count > 0, final_msg, success_count > 0 + for i, emoji_b64 in enumerate(emoji_base64_list): + result = await self.ctx.emoji.register_emoji(emoji_b64) + if isinstance(result, dict) and result.get("success"): + success_count += 1 + desc = result.get("description", "未知描述") + emotions = result.get("emotions", []) + replaced = result.get("replaced", False) + msg = f"表情包 {i + 1} 注册成功{'(替换旧表情包)' if replaced else '(新增表情包)'}" + if desc: + msg += f"\n描述: {desc}" + if emotions: + msg += f"\n情感标签: {', '.join(emotions)}" + results.append(msg) else: - # 如果重写失败,发送原始消息 - await self.send_text(final_msg) - return success_count > 0, final_msg, success_count > 0 + fail_count += 1 + err = result.get("message", "注册失败") if isinstance(result, dict) else "注册失败" + results.append(f"表情包 {i + 1} 注册失败: {err}") - except Exception as e: - # 如果表达器调用失败,发送原始消息 - logger.error(f"[add_emoji] 表达器重写失败: {e}") - await self.send_text(final_msg) - return success_count > 0, final_msg, success_count > 0 + total = success_count + fail_count + summary = f"表情包注册完成: 成功 {success_count} 个,失败 {fail_count} 个,共处理 {total} 个" + if results: + summary += "\n" + "\n".join(results) - def find_and_return_emoji_in_message(self, message_segments) -> List[str]: - emoji_base64_list = [] + await self.ctx.send.text(summary, stream_id) + return success_count > 0, summary, success_count > 0 - # 处理单个Seg对象的情况 - if isinstance(message_segments, Seg): - if message_segments.type == "emoji": - emoji_base64_list.append(message_segments.data) - elif message_segments.type == "image": - # 假设图片数据是base64编码的 - emoji_base64_list.append(message_segments.data) - elif message_segments.type == "seglist": - # 递归处理嵌套的Seg列表 - emoji_base64_list.extend(self.find_and_return_emoji_in_message(message_segments.data)) - return emoji_base64_list - - # 处理Seg列表的情况 - for seg in message_segments: - if seg.type == "emoji": - emoji_base64_list.append(seg.data) - elif seg.type == "image": - # 假设图片数据是base64编码的 - emoji_base64_list.append(seg.data) - elif seg.type == "seglist": - # 递归处理嵌套的Seg列表 - emoji_base64_list.extend(self.find_and_return_emoji_in_message(seg.data)) - return emoji_base64_list - - -class ListEmojiCommand(BaseCommand): - """列表表情包Command - 响应/emoji list命令""" - - command_name = "emoji_list" - command_description = "列表表情包" - - # === 命令设置(必须填写)=== - command_pattern = r"^/emoji list(\s+\d+)?$" # 匹配 "/emoji list" 或 "/emoji list 数量" - - async def execute(self) -> Tuple[bool, str, bool]: - """执行列表表情包""" - from src.plugin_system.apis import emoji_api - import datetime - - # 解析命令参数 - import re - - match = re.match(r"^/emoji list(?:\s+(\d+))?$", self.message.raw_message) - max_count = 10 # 默认显示10个 + @Command("emoji_list", description="列表表情包", pattern=r"^/emoji list(\s+\d+)?$") + async def handle_list_emoji(self, stream_id: str = "", raw_message: str = "", **kwargs): + """列出表情包""" + max_count = 10 + match = re.match(r"^/emoji list(?:\s+(\d+))?$", raw_message) if match and match.group(1): - max_count = min(int(match.group(1)), 50) # 最多显示50个 + max_count = min(int(match.group(1)), 50) - # 获取当前时间 - time_format: str = self.get_config("time.format", "%Y-%m-%d %H:%M:%S") # type: ignore - now = datetime.datetime.now() - time_str = now.strftime(time_format) + now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") - # 获取表情包信息 - emoji_count = emoji_api.get_count() - emoji_info = emoji_api.get_info() + count_result = await self.ctx.emoji.get_count() + emoji_count = count_result if isinstance(count_result, int) else 0 - # 构建返回消息 - message_lines = [ - f"📊 表情包统计信息 ({time_str})", - f"• 总数: {emoji_count} / {emoji_info['max_count']}", - f"• 可用: {emoji_info['available_emojis']}", + info_result = await self.ctx.emoji.get_info() + max_emoji = info_result.get("max_count", 0) if isinstance(info_result, dict) else 0 + available = info_result.get("available_emojis", 0) if isinstance(info_result, dict) else 0 + + lines = [ + f"📊 表情包统计信息 ({now})", + f"• 总数: {emoji_count} / {max_emoji}", + f"• 可用: {available}", ] if emoji_count == 0: - message_lines.append("\n❌ 暂无表情包") - final_message = "\n".join(message_lines) - await self.send_text(final_message) - return True, final_message, True + lines.append("\n❌ 暂无表情包") + await self.ctx.send.text("\n".join(lines), stream_id) + return True, "\n".join(lines), True - # 获取所有表情包 - all_emojis = await emoji_api.get_all() + all_result = await self.ctx.emoji.get_all() + all_emojis = all_result if isinstance(all_result, list) else [] if not all_emojis: - message_lines.append("\n❌ 无法获取表情包列表") - final_message = "\n".join(message_lines) - await self.send_text(final_message) - return False, final_message, True + lines.append("\n❌ 无法获取表情包列表") + await self.ctx.send.text("\n".join(lines), stream_id) + return False, "\n".join(lines), True - # 显示前N个表情包 - display_emojis = all_emojis[:max_count] - message_lines.append(f"\n📋 显示前 {len(display_emojis)} 个表情包:") + display = all_emojis[:max_count] + lines.append(f"\n📋 显示前 {len(display)} 个表情包:") + for i, emoji in enumerate(display, 1): + if isinstance(emoji, (list, tuple)) and len(emoji) >= 3: + _, desc, emotion = emoji[0], emoji[1], emoji[2] + elif isinstance(emoji, dict): + desc = emoji.get("description", "") + emotion = emoji.get("emotion", "") + else: + desc, emotion = str(emoji), "" + short_desc = desc[:50] + "..." if len(desc) > 50 else desc + lines.append(f"{i}. {short_desc} [{emotion}]") - for i, (_, description, emotion) in enumerate(display_emojis, 1): - # 截断过长的描述 - short_desc = description[:50] + "..." if len(description) > 50 else description - message_lines.append(f"{i}. {short_desc} [{emotion}]") - - # 如果还有更多表情包,显示总数 if len(all_emojis) > max_count: - message_lines.append(f"\n💡 还有 {len(all_emojis) - max_count} 个表情包未显示") + lines.append(f"\n💡 还有 {len(all_emojis) - max_count} 个表情包未显示") - final_message = "\n".join(message_lines) - - # 直接发送文本消息 - await self.send_text(final_message) - - return True, final_message, True - - -class DeleteEmojiCommand(BaseCommand): - command_name = "delete_emoji" - command_description = "删除表情包" - command_pattern = r".*/emoji delete.*" - - async def execute(self) -> Tuple[bool, str, bool]: - # 查找消息中的表情包图片 - logger.info(f"查找消息中的表情包用于删除: {self.message.message_segment}") - - emoji_base64_list = self.find_and_return_emoji_in_message(self.message.message_segment) + final = "\n".join(lines) + await self.ctx.send.text(final, stream_id) + return True, final, True + @Command("delete_emoji", description="删除表情包", pattern=r".*/emoji delete.*") + async def handle_delete_emoji(self, stream_id: str = "", message_segments=None, **kwargs): + """删除表情包""" + emoji_base64_list = self._extract_emoji_base64(message_segments) if not emoji_base64_list: - return False, "未在消息中找到表情包或图片", False + await self.ctx.send.text("未在消息中找到表情包或图片", stream_id) + return False, "未找到表情包", False - # 删除找到的表情包 success_count = 0 fail_count = 0 results = [] - for i, emoji_base64 in enumerate(emoji_base64_list): - try: - # 计算图片的哈希值来查找对应的表情包 - import base64 - import hashlib - - # 确保base64字符串只包含ASCII字符 - if isinstance(emoji_base64, str): - emoji_base64_clean = emoji_base64.encode("ascii", errors="ignore").decode("ascii") - else: - emoji_base64_clean = str(emoji_base64) - - # 计算哈希值 - image_bytes = base64.b64decode(emoji_base64_clean) - emoji_hash = hashlib.md5(image_bytes).hexdigest() - - # 使用emoji_api删除表情包 - result = await emoji_api.delete_emoji(emoji_hash) - - if result["success"]: - success_count += 1 - description = result.get("description", "未知描述") - count_before = result.get("count_before", 0) - count_after = result.get("count_after", 0) - emotions = result.get("emotions", []) - - result_msg = f"表情包 {i + 1} 删除成功" - if description: - result_msg += f"\n描述: {description}" - if emotions: - result_msg += f"\n情感标签: {', '.join(emotions)}" - result_msg += f"\n表情包数量: {count_before} → {count_after}" - - results.append(result_msg) - else: - fail_count += 1 - error_msg = result.get("message", "删除失败") - results.append(f"表情包 {i + 1} 删除失败: {error_msg}") - - except Exception as e: - fail_count += 1 - results.append(f"表情包 {i + 1} 删除时发生错误: {str(e)}") - - # 构建返回消息 - total_count = success_count + fail_count - summary_msg = f"表情包删除完成: 成功 {success_count} 个,失败 {fail_count} 个,共处理 {total_count} 个" - - # 如果有结果详情,添加到返回消息中 - details_msg = "" - if results: - details_msg = "\n" + "\n".join(results) - final_msg = summary_msg + details_msg - else: - final_msg = summary_msg - - # 使用表达器重写回复 - try: - from src.plugin_system.apis import generator_api - - # 构建重写数据 - rewrite_data = { - "raw_reply": summary_msg, - "reason": f"删除了表情包:{details_msg}\n", - } - - # 调用表达器重写 - result_status, data = await generator_api.rewrite_reply( - chat_stream=self.message.chat_stream, - reply_data=rewrite_data, - ) - - if result_status: - # 发送重写后的回复 - for reply_seg in data.reply_set.reply_data: - send_data = reply_seg.content - await self.send_text(send_data) - - return success_count > 0, final_msg, success_count > 0 + for i, emoji_b64 in enumerate(emoji_base64_list): + # 计算哈希 + if isinstance(emoji_b64, str): + clean = emoji_b64.encode("ascii", errors="ignore").decode("ascii") else: - # 如果重写失败,发送原始消息 - await self.send_text(final_msg) - return success_count > 0, final_msg, success_count > 0 + clean = str(emoji_b64) + image_bytes = base64.b64decode(clean) + emoji_hash = hashlib.md5(image_bytes).hexdigest() # noqa: S324 - except Exception as e: - # 如果表达器调用失败,发送原始消息 - logger.error(f"[delete_emoji] 表达器重写失败: {e}") - await self.send_text(final_msg) - return success_count > 0, final_msg, success_count > 0 + result = await self.ctx.emoji.delete_emoji(emoji_hash) + if isinstance(result, dict) and result.get("success"): + success_count += 1 + desc = result.get("description", "未知描述") + emotions = result.get("emotions", []) + before = result.get("count_before", 0) + after = result.get("count_after", 0) + msg = f"表情包 {i + 1} 删除成功" + if desc: + msg += f"\n描述: {desc}" + if emotions: + msg += f"\n情感标签: {', '.join(emotions)}" + msg += f"\n表情包数量: {before} → {after}" + results.append(msg) + else: + fail_count += 1 + err = result.get("message", "删除失败") if isinstance(result, dict) else "删除失败" + results.append(f"表情包 {i + 1} 删除失败: {err}") - def find_and_return_emoji_in_message(self, message_segments) -> List[str]: - emoji_base64_list = [] + total = success_count + fail_count + summary = f"表情包删除完成: 成功 {success_count} 个,失败 {fail_count} 个,共处理 {total} 个" + if results: + summary += "\n" + "\n".join(results) - # 处理单个Seg对象的情况 - if isinstance(message_segments, Seg): - if message_segments.type == "emoji": - emoji_base64_list.append(message_segments.data) - elif message_segments.type == "image": - # 假设图片数据是base64编码的 - emoji_base64_list.append(message_segments.data) - elif message_segments.type == "seglist": - # 递归处理嵌套的Seg列表 - emoji_base64_list.extend(self.find_and_return_emoji_in_message(message_segments.data)) - return emoji_base64_list + await self.ctx.send.text(summary, stream_id) + return success_count > 0, summary, success_count > 0 - # 处理Seg列表的情况 - for seg in message_segments: - if seg.type == "emoji": - emoji_base64_list.append(seg.data) - elif seg.type == "image": - # 假设图片数据是base64编码的 - emoji_base64_list.append(seg.data) - elif seg.type == "seglist": - # 递归处理嵌套的Seg列表 - emoji_base64_list.extend(self.find_and_return_emoji_in_message(seg.data)) - return emoji_base64_list - - -class RandomEmojis(BaseCommand): - command_name = "random_emojis" - command_description = "发送多张随机表情包" - command_pattern = r"^/random_emojis$" - - async def execute(self): - emojis = await emoji_api.get_random(5) + @Command("random_emojis", description="发送多张随机表情包", pattern=r"^/random_emojis$") + async def handle_random_emojis(self, stream_id: str = "", **kwargs): + """发送多张随机表情包""" + result = await self.ctx.emoji.get_random(5) + if not result or not result.get("success"): + return False, "未找到表情包", False + emojis = result.get("emojis", []) if not emojis: return False, "未找到表情包", False - emoji_base64_list = [] - for emoji in emojis: - emoji_base64_list.append(emoji[0]) - return await self.forward_images(emoji_base64_list) - - async def forward_images(self, images: List[str]): - """ - 把多张图片用合并转发的方式发给用户 - """ - success = await self.send_forward([("0", "神秘用户", [(ReplyContentType.IMAGE, img)]) for img in images]) - return (True, "已发送随机表情包", True) if success else (False, "发送随机表情包失败", False) - - -# ===== 插件注册 ===== - - -@register_plugin -class EmojiManagePlugin(BasePlugin): - """表情包管理插件 - 管理表情包""" - - # 插件基本信息 - plugin_name: str = "emoji_manage_plugin" # 内部标识符 - enable_plugin: bool = False - dependencies: List[str] = [] # 插件依赖列表 - python_dependencies: List[str] = [] # Python包依赖列表 - config_file_name: str = "config.toml" # 配置文件名 - - # 配置节描述 - config_section_descriptions = {"plugin": "插件基本信息", "emoji": "表情包功能配置"} - - # 配置Schema定义 - config_schema: dict = { - "plugin": { - "enabled": ConfigField(type=bool, default=True, description="是否启用插件"), - "config_version": ConfigField(type=str, default="1.0.1", description="配置文件版本"), - }, - } - - def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: - return [ - (RandomEmojis.get_command_info(), RandomEmojis), - (AddEmojiCommand.get_command_info(), AddEmojiCommand), - (ListEmojiCommand.get_command_info(), ListEmojiCommand), - (DeleteEmojiCommand.get_command_info(), DeleteEmojiCommand), + messages = [ + {"user_id": "0", "nickname": "神秘用户", "segments": [{"type": "image", "content": e.get("base64", "")}]} + for e in emojis ] + await self.ctx.send.forward(messages, stream_id) + return True, "已发送随机表情包", True + + +def create_plugin(): + return EmojiManagePlugin() diff --git a/plugins/hello_world_plugin/_manifest.json b/plugins/hello_world_plugin/_manifest.json index 25234a00..dc9fc474 100644 --- a/plugins/hello_world_plugin/_manifest.json +++ b/plugins/hello_world_plugin/_manifest.json @@ -1,7 +1,7 @@ { "manifest_version": 1, "name": "Hello World 示例插件 (Hello World Plugin)", - "version": "1.0.0", + "version": "2.0.0", "description": "我的第一个MaiCore插件,包含问候功能和时间查询等基础示例", "author": { "name": "MaiBot开发团队", @@ -9,7 +9,7 @@ }, "license": "GPL-v3.0-or-later", "host_application": { - "min_version": "0.8.0" + "min_version": "1.0.0" }, "homepage_url": "https://github.com/MaiM-with-u/maibot", "repository_url": "https://github.com/MaiM-with-u/maibot", @@ -29,7 +29,19 @@ "plugin_info": { "is_built_in": false, "plugin_type": "example", + "capabilities": [ + "send.text", + "send.forward", + "send.hybrid", + "emoji.get_random", + "config.get" + ], "components": [ + { + "type": "tool", + "name": "compare_numbers", + "description": "比较两个数的大小" + }, { "type": "action", "name": "hello_greeting", @@ -39,28 +51,37 @@ "type": "action", "name": "bye_greeting", "description": "向用户发送告别消息", - "activation_modes": [ - "keyword" - ], - "keywords": [ - "再见", - "bye", - "88", - "拜拜" - ] + "activation_modes": ["keyword"], + "keywords": ["再见", "bye", "88", "拜拜"] }, { "type": "command", "name": "time", "description": "查询当前时间", "pattern": "/time" + }, + { + "type": "command", + "name": "random_emojis", + "description": "发送多张随机表情包", + "pattern": "/random_emojis" + }, + { + "type": "command", + "name": "test", + "description": "测试命令", + "pattern": "/test" + }, + { + "type": "event_handler", + "name": "print_message_handler", + "description": "打印接收到的消息" + }, + { + "type": "event_handler", + "name": "forward_messages_handler", + "description": "把接收到的消息转发到指定聊天ID" } - ], - "features": [ - "问候和告别功能", - "时间查询命令", - "配置文件示例", - "新手教程代码" ] }, "id": "MaiBot开发团队.maibot" diff --git a/plugins/hello_world_plugin/plugin.py b/plugins/hello_world_plugin/plugin.py index 76567817..4e8460c0 100644 --- a/plugins/hello_world_plugin/plugin.py +++ b/plugins/hello_world_plugin/plugin.py @@ -1,50 +1,30 @@ +"""Hello World 示例插件 — 新 SDK 版本 + +你的第一个 MaiCore 插件,包含问候功能、时间查询等基础示例。 +""" + +import datetime import random -from typing import List, Tuple, Type, Any, Optional -from src.plugin_system import ( - BasePlugin, - register_plugin, - BaseAction, - BaseCommand, - BaseTool, - ComponentInfo, - ActionActivationType, - ConfigField, - BaseEventHandler, - EventType, - MaiMessages, - ToolParamType, - ReplyContentType, - emoji_api, -) -from src.config.config import global_config -from src.common.logger import get_logger -logger = get_logger("hello_world_plugin") +from maibot_sdk import MaiBotPlugin, Action, Command, Tool, EventHandler +from maibot_sdk.types import ActivationType, EventType, ToolParameterInfo, ToolParamType -class CompareNumbersTool(BaseTool): - """比较两个数大小的工具""" +class HelloWorldPlugin(MaiBotPlugin): + """Hello World 示例插件""" - name = "compare_numbers" - description = "使用工具 比较两个数的大小,返回较大的数" - parameters = [ - ("num1", ToolParamType.FLOAT, "第一个数字", True, None), - ("num2", ToolParamType.FLOAT, "第二个数字", True, None), - ] - available_for_llm = True - - async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: - """执行比较两个数的大小 - - Args: - function_args: 工具参数 - - Returns: - dict: 工具执行结果 - """ - num1: int | float = function_args.get("num1") # type: ignore - num2: int | float = function_args.get("num2") # type: ignore + # ===== Tool 组件 ===== + @Tool( + "compare_numbers", + description="使用工具比较两个数的大小,返回较大的数", + parameters=[ + ToolParameterInfo(name="num1", param_type=ToolParamType.FLOAT, description="第一个数字", required=True), + ToolParameterInfo(name="num2", param_type=ToolParamType.FLOAT, description="第二个数字", required=True), + ], + ) + async def handle_compare_numbers(self, num1: float = 0, num2: float = 0, **kwargs): + """比较两个数的大小""" try: if num1 > num2: result = f"{num1} 大于 {num2}" @@ -52,270 +32,121 @@ class CompareNumbersTool(BaseTool): result = f"{num1} 小于 {num2}" else: result = f"{num1} 等于 {num2}" - - return {"name": self.name, "content": result} + return {"name": "compare_numbers", "content": result} except Exception as e: - return {"name": self.name, "content": f"比较数字失败,炸了: {str(e)}"} + return {"name": "compare_numbers", "content": f"比较数字失败,炸了: {e}"} + # ===== Action 组件 ===== -# ===== Action组件 ===== -class HelloAction(BaseAction): - """问候Action - 简单的问候动作""" - - # === 基本信息(必须填写)=== - action_name = "hello_greeting" - action_description = "向用户发送问候消息" - activation_type = ActionActivationType.ALWAYS # 始终激活 - - # === 功能描述(必须填写)=== - action_parameters = {"greeting_message": "要发送的问候消息"} - action_require = ["需要发送友好问候时使用", "当有人向你问好时使用", "当你遇见没有见过的人时使用"] - associated_types = ["text"] - - async def execute(self) -> Tuple[bool, str]: - """执行问候动作 - 这是核心功能""" - # 发送问候消息 - greeting_message = self.action_data.get("greeting_message", "") - base_message = self.get_config("greeting.message", "嗨!很开心见到你!😊") + @Action( + "hello_greeting", + description="向用户发送问候消息", + activation_type=ActivationType.ALWAYS, + action_parameters={"greeting_message": "要发送的问候消息"}, + action_require=["需要发送友好问候时使用", "当有人向你问好时使用", "当你遇见没有见过的人时使用"], + associated_types=["text"], + ) + async def handle_hello(self, stream_id: str = "", greeting_message: str = "", **kwargs): + """问候动作""" + config_result = await self.ctx.config.get("greeting.message") + base_message = config_result if isinstance(config_result, str) else "嗨!很开心见到你!😊" message = base_message + greeting_message - await self.send_text(message) - + await self.ctx.send.text(message, stream_id) return True, "发送了问候消息" - -class ByeAction(BaseAction): - """告别Action - 只在用户说再见时激活""" - - action_name = "bye_greeting" - action_description = "向用户发送告别消息" - - # 使用关键词激活 - activation_type = ActionActivationType.KEYWORD - - # 关键词设置 - activation_keywords = ["再见", "bye", "88", "拜拜"] - keyword_case_sensitive = False - - action_parameters = {"bye_message": "要发送的告别消息"} - action_require = [ - "用户要告别时使用", - "当有人要离开时使用", - "当有人和你说再见时使用", - ] - associated_types = ["text"] - - async def execute(self) -> Tuple[bool, str]: - bye_message = self.action_data.get("bye_message", "") - + @Action( + "bye_greeting", + description="向用户发送告别消息", + activation_type=ActivationType.KEYWORD, + activation_keywords=["再见", "bye", "88", "拜拜"], + action_parameters={"bye_message": "要发送的告别消息"}, + action_require=["用户要告别时使用", "当有人要离开时使用", "当有人和你说再见时使用"], + associated_types=["text"], + ) + async def handle_bye(self, stream_id: str = "", bye_message: str = "", **kwargs): + """告别动作""" message = f"再见!期待下次聊天!👋{bye_message}" - await self.send_text(message) + await self.ctx.send.text(message, stream_id) return True, "发送了告别消息" + # ===== Command 组件 ===== -class TimeCommand(BaseCommand): - """时间查询Command - 响应/time命令""" - - command_name = "time" - command_description = "查询当前时间" - - # === 命令设置(必须填写)=== - command_pattern = r"^/time$" # 精确匹配 "/time" 命令 - - async def execute(self) -> Tuple[bool, str, bool]: - """执行时间查询""" - import datetime - - # 获取当前时间 - time_format: str = self.get_config("time.format", "%Y-%m-%d %H:%M:%S") # type: ignore + @Command("time", description="查询当前时间", pattern=r"^/time$") + async def handle_time(self, stream_id: str = "", **kwargs): + """时间查询命令""" + config_result = await self.ctx.config.get("time.format") + time_format = config_result if isinstance(config_result, str) else "%Y-%m-%d %H:%M:%S" now = datetime.datetime.now() time_str = now.strftime(time_format) - - # 发送时间信息 - message = f"⏰ 当前时间:{time_str}" - await self.send_text(message) - + await self.ctx.send.text(f"⏰ 当前时间:{time_str}", stream_id) return True, f"显示了当前时间: {time_str}", True + @Command("random_emojis", description="发送多张随机表情包", pattern=r"^/random_emojis$") + async def handle_random_emojis(self, stream_id: str = "", **kwargs): + """发送多张随机表情包""" + result = await self.ctx.emoji.get_random(5) + if not result or not result.get("success"): + return False, "未找到表情包", False + emojis = result.get("emojis", []) + if not emojis: + return False, "未找到表情包", False + # 用转发消息发送多张图片 + messages = [ + {"user_id": "0", "nickname": "神秘用户", "segments": [{"type": "image", "content": e.get("base64", "")}]} + for e in emojis + ] + await self.ctx.send.forward(messages, stream_id) + return True, "已发送随机表情包", True -class PrintMessage(BaseEventHandler): - """打印消息事件处理器 - 处理打印消息事件""" + @Command("test", description="测试命令", pattern=r"^/test$") + async def handle_test(self, stream_id: str = "", **kwargs): + """测试命令 — 发送简单测试消息""" + await self.ctx.send.text("测试正常!Bot 功能运行中 ✅", stream_id) + return True, "测试完成", True - event_type = EventType.ON_MESSAGE - handler_name = "print_message_handler" - handler_description = "打印接收到的消息" + # ===== EventHandler 组件 ===== - async def execute(self, message: MaiMessages | None) -> Tuple[bool, bool, str | None, None, None]: - """执行打印消息事件处理""" - # 打印接收到的消息 - if self.get_config("print_message.enabled", False): - print(f"接收到消息: {message.raw_message if message else '无效消息'}") + @EventHandler("print_message_handler", description="打印接收到的消息", event_type=EventType.ON_MESSAGE) + async def handle_print_message(self, message=None, **kwargs): + """打印消息事件""" + config_result = await self.ctx.config.get("print_message.enabled") + enabled = config_result if isinstance(config_result, bool) else False + if enabled and message: + raw = message.get("raw_message", "") if isinstance(message, dict) else str(message) + print(f"接收到消息: {raw}") return True, True, "消息已打印", None, None - -class ForwardMessages(BaseEventHandler): - """ - 把接收到的消息转发到指定聊天ID - - 此组件是HYBRID消息和FORWARD消息的使用示例。 - 每收到10条消息,就会以1%的概率使用HYBRID消息转发,否则使用FORWARD消息转发。 - """ - - event_type = EventType.ON_MESSAGE - handler_name = "forward_messages_handler" - handler_description = "把接收到的消息转发到指定聊天ID" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.counter = 0 # 用于计数转发的消息数量 - self.messages: List[str] = [] - - async def execute(self, message: MaiMessages | None) -> Tuple[bool, bool, None, None, None]: + @EventHandler("forward_messages_handler", description="把接收到的消息转发到指定聊天ID", event_type=EventType.ON_MESSAGE) + async def handle_forward_messages(self, message=None, stream_id: str = "", **kwargs): + """收集消息并定期转发""" if not message: return True, True, None, None, None - stream_id = message.stream_id or "" + plain_text = message.get("plain_text", "") if isinstance(message, dict) else "" + if not plain_text: + return True, True, None, None, None - if message.plain_text: - self.messages.append(message.plain_text) - self.counter += 1 - if self.counter % 10 == 0: + # 使用插件级状态收集消息 + if not hasattr(self, "_fwd_messages"): + self._fwd_messages: list[str] = [] + self._fwd_counter: int = 0 + + self._fwd_messages.append(plain_text) + self._fwd_counter += 1 + + if self._fwd_counter % 10 == 0 and stream_id: if random.random() < 0.01: - success = await self.send_hybrid(stream_id, [(ReplyContentType.TEXT, msg) for msg in self.messages]) + segments = [{"type": "text", "content": msg} for msg in self._fwd_messages] + await self.ctx.send.hybrid(segments, stream_id) else: - success = await self.send_forward( - stream_id, - [ - ( - str(global_config.bot.qq_account), - str(global_config.bot.nickname), - [(ReplyContentType.TEXT, msg)], - ) - for msg in self.messages - ], - ) - if not success: - raise ValueError("转发消息失败") - self.messages = [] + messages = [ + {"user_id": "0", "nickname": "转发", "segments": [{"type": "text", "content": msg}]} + for msg in self._fwd_messages + ] + await self.ctx.send.forward(messages, stream_id) + self._fwd_messages = [] + return True, True, None, None, None -class RandomEmojis(BaseCommand): - command_name = "random_emojis" - command_description = "发送多张随机表情包" - command_pattern = r"^/random_emojis$" - - async def execute(self): - emojis = await emoji_api.get_random(5) - if not emojis: - return False, "未找到表情包", False - emoji_base64_list = [] - for emoji in emojis: - emoji_base64_list.append(emoji[0]) - return await self.forward_images(emoji_base64_list) - - async def forward_images(self, images: List[str]): - """ - 把多张图片用合并转发的方式发给用户 - """ - success = await self.send_forward([("0", "神秘用户", [(ReplyContentType.IMAGE, img)]) for img in images]) - return (True, "已发送随机表情包", True) if success else (False, "发送随机表情包失败", False) - - -class TestCommand(BaseCommand): - """响应/test命令""" - - command_name = "test" - command_description = "测试命令" - command_pattern = r"^/test$" - - async def execute(self) -> Tuple[bool, Optional[str], int]: - """执行测试命令""" - try: - from src.plugin_system.apis import generator_api - - reply_reason = "这是一条测试消息。" - logger.info(f"测试命令:{reply_reason}") - result_status, data = await generator_api.generate_reply( - chat_stream=self.message.chat_stream, - reply_reason=reply_reason, - enable_chinese_typo=False, - extra_info=f'{reply_reason}用于测试bot的功能是否正常。请你按设定的人设表达一句"测试正常"', - ) - if result_status: - # 发送生成的回复 - if data and data.reply_set and data.reply_set.reply_data: - for reply_seg in data.reply_set.reply_data: - send_data = reply_seg.content - await self.send_text(send_data, storage_message=True) - logger.info(f"已回复: {send_data}") - return True, "", 1 - except Exception as e: - logger.error(f"表达器生成失败:{e}") - return True, "", 1 - - -# ===== 插件注册 ===== - - -@register_plugin -class HelloWorldPlugin(BasePlugin): - """Hello World插件 - 你的第一个MaiCore插件""" - - # 插件基本信息 - plugin_name: str = "hello_world_plugin" # 内部标识符 - enable_plugin: bool = False - dependencies: List[str] = [] # 插件依赖列表 - python_dependencies: List[str] = [] # Python包依赖列表 - config_file_name: str = "config.toml" # 配置文件名 - - # 配置节描述 - config_section_descriptions = {"plugin": "插件基本信息", "greeting": "问候功能配置", "time": "时间查询配置"} - - # 配置Schema定义 - config_schema: dict = { - "plugin": { - "config_version": ConfigField(type=str, default="1.0.0", description="配置文件版本"), - "enabled": ConfigField(type=bool, default=False, description="是否启用插件"), - }, - "greeting": { - "message": ConfigField( - type=list, default=["嗨!很开心见到你!😊", "Ciallo~(∠・ω< )⌒★"], description="默认问候消息" - ), - "enable_emoji": ConfigField(type=bool, default=True, description="是否启用表情符号"), - }, - "time": {"format": ConfigField(type=str, default="%Y-%m-%d %H:%M:%S", description="时间显示格式")}, - "print_message": {"enabled": ConfigField(type=bool, default=True, description="是否启用打印")}, - } - - def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: - return [ - (HelloAction.get_action_info(), HelloAction), - (CompareNumbersTool.get_tool_info(), CompareNumbersTool), # 添加比较数字工具 - (ByeAction.get_action_info(), ByeAction), # 添加告别Action - (TimeCommand.get_command_info(), TimeCommand), - (PrintMessage.get_handler_info(), PrintMessage), - (ForwardMessages.get_handler_info(), ForwardMessages), - (RandomEmojis.get_command_info(), RandomEmojis), - (TestCommand.get_command_info(), TestCommand), - ] - - -# @register_plugin -# class HelloWorldEventPlugin(BaseEPlugin): -# """Hello World事件插件 - 处理问候和告别事件""" - -# plugin_name = "hello_world_event_plugin" -# enable_plugin = False -# dependencies = [] -# python_dependencies = [] -# config_file_name = "event_config.toml" - -# config_schema = { -# "plugin": { -# "name": ConfigField(type=str, default="hello_world_event_plugin", description="插件名称"), -# "version": ConfigField(type=str, default="1.0.0", description="插件版本"), -# "enabled": ConfigField(type=bool, default=True, description="是否启用插件"), -# }, -# } - -# def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: -# return [(PrintMessage.get_handler_info(), PrintMessage)] +def create_plugin(): + return HelloWorldPlugin() diff --git a/pyproject.toml b/pyproject.toml index b84324fd..cc6cb884 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ "uvicorn>=0.35.0", "msgpack>=1.1.2", "watchfiles>=1.1.1", + "maibot-plugin-sdk>=1.0.0", ] diff --git a/requirements.txt b/requirements.txt index 2ddfe645..b65420de 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ fastapi>=0.116.0 google-genai>=1.39.1 jieba>=0.42.1 json-repair>=0.47.6 +maibot-plugin-sdk>=1.0.0 maim-message>=0.6.2 matplotlib>=3.10.3 numpy>=2.2.6 diff --git a/src/chat/brain_chat/brain_chat.py b/src/chat/brain_chat/brain_chat.py index 85ccc238..befa1675 100644 --- a/src/chat/brain_chat/brain_chat.py +++ b/src/chat/brain_chat/brain_chat.py @@ -19,9 +19,10 @@ from src.chat.heart_flow.hfc_utils import CycleDetail from src.bw_learner.expression_learner import expression_learner_manager from src.bw_learner.message_recorder import extract_and_distribute_messages from src.person_info.person_info import Person -from src.plugin_system.base.component_types import EventType, ActionInfo -from src.plugin_system.core import events_manager -from src.plugin_system.apis import generator_api, send_api, message_api, database_api +from src.core.types import ActionInfo, EventType +from src.core.event_bus import event_bus +from src.chat.event_helpers import build_event_message +from src.services import generator_service as generator_api, send_service as send_api, message_service as message_api, database_service as database_api from src.chat.utils.chat_message_builder import ( build_readable_messages_with_id, get_raw_msg_before_timestamp_with_chat, @@ -315,8 +316,9 @@ class BrainChatting: message_id_list=message_id_list, prompt_key="brain_planner", ) - continue_flag, modified_message = await events_manager.handle_mai_events( - EventType.ON_PLAN, None, prompt_info[0], None, self.chat_stream.stream_id + _event_msg = build_event_message(EventType.ON_PLAN, llm_prompt=prompt_info[0], stream_id=self.chat_stream.stream_id) + continue_flag, modified_message = await event_bus.emit( + EventType.ON_PLAN, _event_msg ) if not continue_flag: return False diff --git a/src/chat/brain_chat/brain_planner.py b/src/chat/brain_chat/brain_planner.py index 44eaf0bc..281830b0 100644 --- a/src/chat/brain_chat/brain_planner.py +++ b/src/chat/brain_chat/brain_planner.py @@ -23,8 +23,8 @@ from src.chat.utils.chat_message_builder import ( from src.chat.utils.utils import get_chat_type_and_target_info from src.chat.planner_actions.action_manager import ActionManager from src.chat.message_receive.chat_manager import chat_manager as _chat_manager -from src.plugin_system.base.component_types import ActionInfo, ComponentType, ActionActivationType -from src.plugin_system.core.component_registry import component_registry +from src.core.types import ActionActivationType, ActionInfo, ComponentType +from src.core.component_registry import component_registry if TYPE_CHECKING: from src.common.data_models.info_data_model import TargetPersonInfo diff --git a/src/chat/event_helpers.py b/src/chat/event_helpers.py new file mode 100644 index 00000000..818c4f88 --- /dev/null +++ b/src/chat/event_helpers.py @@ -0,0 +1,169 @@ +""" +事件消息构建工具 + +将 chat 层的消息对象 (SessionMessage / MessageSending) 转换为 +核心事件系统使用的 MaiMessages,供调用 event_bus.emit() 前使用。 +""" + +from typing import List, Optional, TYPE_CHECKING + +from src.common.logger import get_logger +from src.core.types import EventType, MaiMessages + +if TYPE_CHECKING: + from src.chat.message_receive.message import MessageSending, SessionMessage + from src.common.data_models.llm_data_model import LLMGenerationDataModel + +logger = get_logger("event_helpers") + + +def build_event_message( + event_type: EventType | str, + message: Optional["SessionMessage | MessageSending | MaiMessages"] = None, + llm_prompt: Optional[str] = None, + llm_response: Optional["LLMGenerationDataModel"] = None, + stream_id: Optional[str] = None, + action_usage: Optional[List[str]] = None, +) -> Optional[MaiMessages]: + """根据事件类型和输入,准备和转换消息对象。 + + 迁移自 events_manager._prepare_message,保持相同的行为。 + """ + if isinstance(message, MaiMessages): + return message.deepcopy() + + if message: + return _transform_event_message(message, llm_prompt, llm_response) + + if event_type not in (EventType.ON_START, EventType.ON_STOP): + assert stream_id, "如果没有消息,必须为非启动/关闭事件提供流ID" + if event_type in (EventType.ON_MESSAGE, EventType.ON_PLAN, EventType.POST_LLM, EventType.AFTER_LLM): + return _build_message_from_stream(stream_id, llm_prompt, llm_response) + else: + return _build_message_without_raw(stream_id, llm_prompt, llm_response, action_usage) + + return None # ON_START / ON_STOP 没有消息体 + + +def _transform_event_message( + message: "SessionMessage | MessageSending", + llm_prompt: Optional[str] = None, + llm_response: Optional["LLMGenerationDataModel"] = None, +) -> MaiMessages: + """将 SessionMessage / MessageSending 转换为 MaiMessages。""" + from maim_message import Seg + from src.chat.message_receive.message import MessageSending + + transformed = MaiMessages( + llm_prompt=llm_prompt, + llm_response_content=llm_response.content if llm_response else None, + llm_response_reasoning=llm_response.reasoning if llm_response else None, + llm_response_model=llm_response.model if llm_response else None, + llm_response_tool_call=llm_response.tool_calls if llm_response else None, + raw_message=message.processed_plain_text or "", + additional_data={}, + ) + + # 消息段处理 + if isinstance(message, MessageSending): + if message.message_segment.type == "seglist": + transformed.message_segments = list(message.message_segment.data) # type: ignore + else: + transformed.message_segments = [message.message_segment] + else: + transformed.message_segments = [Seg(type="text", data=message.processed_plain_text or "")] + + # stream_id + transformed.stream_id = message.session_id if hasattr(message, "session_id") else "" + + # 处理后文本 + transformed.plain_text = message.processed_plain_text + + # 基本信息 + if isinstance(message, MessageSending): + transformed.message_base_info["platform"] = message.platform + if message.session.group_id: + transformed.is_group_message = True + group_name = "" + if ( + message.session.context + and message.session.context.message + and message.session.context.message.message_info.group_info + ): + group_name = message.session.context.message.message_info.group_info.group_name + transformed.message_base_info.update( + { + "group_id": message.session.group_id, + "group_name": group_name, + } + ) + transformed.message_base_info.update( + { + "user_id": message.bot_user_info.user_id, + "user_cardname": message.bot_user_info.user_cardname, + "user_nickname": message.bot_user_info.user_nickname, + } + ) + if not transformed.is_group_message: + transformed.is_private_message = True + elif hasattr(message, "message_info") and message.message_info: + if message.platform: + transformed.message_base_info["platform"] = message.platform + if message.message_info.group_info: + transformed.is_group_message = True + transformed.message_base_info.update( + { + "group_id": message.message_info.group_info.group_id, + "group_name": message.message_info.group_info.group_name, + } + ) + if message.message_info.user_info: + if not transformed.is_group_message: + transformed.is_private_message = True + transformed.message_base_info.update( + { + "user_id": message.message_info.user_info.user_id, + "user_cardname": message.message_info.user_info.user_cardname, + "user_nickname": message.message_info.user_info.user_nickname, + } + ) + + return transformed + + +def _build_message_from_stream( + stream_id: str, + llm_prompt: Optional[str] = None, + llm_response: Optional["LLMGenerationDataModel"] = None, +) -> MaiMessages: + """从 stream_id 查找会话消息并转换。""" + from src.chat.message_receive.chat_manager import chat_manager + + session = chat_manager.get_session_by_session_id(stream_id) + assert session, f"未找到流ID为 {stream_id} 的会话" + return _transform_event_message(session.context.message, llm_prompt, llm_response) + + +def _build_message_without_raw( + stream_id: str, + llm_prompt: Optional[str] = None, + llm_response: Optional["LLMGenerationDataModel"] = None, + action_usage: Optional[List[str]] = None, +) -> MaiMessages: + """没有原始消息对象时,从 stream_id 构建最小 MaiMessages。""" + from src.chat.message_receive.chat_manager import chat_manager + + session = chat_manager.get_session_by_session_id(stream_id) + assert session, f"未找到流ID为 {stream_id} 的会话" + return MaiMessages( + stream_id=stream_id, + llm_prompt=llm_prompt, + llm_response_content=llm_response.content if llm_response else None, + llm_response_reasoning=llm_response.reasoning if llm_response else None, + llm_response_model=llm_response.model if llm_response else None, + llm_response_tool_call=llm_response.tool_calls if llm_response else None, + is_group_message=session.is_group_session, + is_private_message=not session.is_group_session, + action_usage=action_usage, + additional_data={"response_is_processed": True}, + ) diff --git a/src/chat/heart_flow/hfc_utils.py b/src/chat/heart_flow/hfc_utils.py index 3ad36ac8..36820d1c 100644 --- a/src/chat/heart_flow/hfc_utils.py +++ b/src/chat/heart_flow/hfc_utils.py @@ -6,7 +6,7 @@ import time from src.config.config import global_config from src.common.logger import get_logger from src.chat.message_receive.chat_manager import chat_manager as _chat_manager -from src.plugin_system.apis import send_api +from src.services import send_service as send_api from src.common.message_repository import count_messages diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index de1dc7dd..eec9b3c0 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -11,8 +11,9 @@ from src.common.utils.utils_session import SessionUtils from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver from src.chat.brain_chat.PFC.pfc_manager import PFCManager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager -from src.plugin_system.core import component_registry, events_manager, global_announcement_manager -from src.plugin_system.base import BaseCommand, EventType +from src.core.announcement_manager import global_announcement_manager +from src.core.component_registry import component_registry +from src.core.types import EventType from .message import SessionMessage from .chat_manager import chat_manager @@ -65,10 +66,10 @@ class ChatBot: try: text = message.processed_plain_text - # 使用新的组件注册中心查找命令 + # 使用核心组件注册表查找命令 command_result = component_registry.find_command_by_text(text) if command_result: - command_class, matched_groups, command_info = command_result + command_executor, matched_groups, command_info = command_result plugin_name = command_info.plugin_name command_name = command_info.name if message.session_id and command_name in global_announcement_manager.get_disabled_chat_commands( @@ -82,20 +83,20 @@ class ChatBot: # 获取插件配置 plugin_config = component_registry.get_plugin_config(plugin_name) - # 创建命令实例 - command_instance: BaseCommand = command_class(message, plugin_config) - command_instance.set_matched_groups(matched_groups) - try: - # 执行命令 - success, response, intercept_message_level = await command_instance.execute() + # 调用命令执行器 + success, response, intercept_message_level = await command_executor( + message=message, + plugin_config=plugin_config, + matched_groups=matched_groups, + ) message.intercept_message_level = intercept_message_level # 记录命令执行结果 if success: - logger.info(f"命令执行成功: {command_class.__name__} (拦截等级: {intercept_message_level})") + logger.info(f"命令执行成功: {command_name} (拦截等级: {intercept_message_level})") else: - logger.warning(f"命令执行失败: {command_class.__name__} - {response}") + logger.warning(f"命令执行失败: {command_name} - {response}") # 根据命令的拦截设置决定是否继续处理消息 return ( @@ -105,14 +106,9 @@ class ChatBot: ) # 找到命令,根据intercept_message决定是否继续 except Exception as e: - logger.error(f"执行命令时出错: {command_class.__name__} - {e}") + logger.error(f"执行命令时出错: {command_name} - {e}") logger.error(traceback.format_exc()) - try: - await command_instance.send_text(f"命令执行出错: {str(e)}") - except Exception as send_error: - logger.error(f"发送错误消息失败: {send_error}") - # 命令出错时,根据命令的拦截设置决定是否继续处理消息 return True, str(e), False # 出错时继续处理消息 diff --git a/src/chat/message_receive/uni_message_sender.py b/src/chat/message_receive/uni_message_sender.py index eb04109c..de980d9c 100644 --- a/src/chat/message_receive/uni_message_sender.py +++ b/src/chat/message_receive/uni_message_sender.py @@ -318,11 +318,13 @@ class UniversalMessageSender: message.build_reply() logger.debug(f"[{chat_id}] 选择回复引用消息: {message.processed_plain_text[:20]}...") - from src.plugin_system.core.events_manager import events_manager - from src.plugin_system.base.component_types import EventType + from src.core.event_bus import event_bus + from src.chat.event_helpers import build_event_message + from src.core.types import EventType - continue_flag, modified_message = await events_manager.handle_mai_events( - EventType.POST_SEND_PRE_PROCESS, message=message, stream_id=chat_id + _event_msg = build_event_message(EventType.POST_SEND_PRE_PROCESS, message=message, stream_id=chat_id) + continue_flag, modified_message = await event_bus.emit( + EventType.POST_SEND_PRE_PROCESS, _event_msg ) if not continue_flag: logger.info(f"[{chat_id}] 消息发送被插件取消: {str(message.message_segment)[:100]}...") @@ -336,8 +338,9 @@ class UniversalMessageSender: await message.process() - continue_flag, modified_message = await events_manager.handle_mai_events( - EventType.POST_SEND, message=message, stream_id=chat_id + _event_msg = build_event_message(EventType.POST_SEND, message=message, stream_id=chat_id) + continue_flag, modified_message = await event_bus.emit( + EventType.POST_SEND, _event_msg ) if not continue_flag: logger.info(f"[{chat_id}] 消息发送被插件取消: {str(message.message_segment)[:100]}...") @@ -360,8 +363,9 @@ class UniversalMessageSender: if not sent_msg: return False - continue_flag, modified_message = await events_manager.handle_mai_events( - EventType.AFTER_SEND, message=message, stream_id=chat_id + _event_msg = build_event_message(EventType.AFTER_SEND, message=message, stream_id=chat_id) + continue_flag, modified_message = await event_bus.emit( + EventType.AFTER_SEND, _event_msg ) if not continue_flag: logger.info(f"[{chat_id}] 消息发送后续处理被插件取消: {str(message.message_segment)[:100]}...") diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index 2e11474a..05673778 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -1,20 +1,34 @@ -from typing import Dict, Optional, Type +from typing import Dict, Optional, Tuple from src.chat.message_receive.chat_manager import BotChatSession from src.common.logger import get_logger from src.common.data_models.database_data_model import DatabaseMessages -from src.plugin_system.core.component_registry import component_registry -from src.plugin_system.base.component_types import ComponentType, ActionInfo -from src.plugin_system.base.base_action import BaseAction +from src.core.component_registry import component_registry, ActionExecutor +from src.core.types import ActionInfo, ComponentType logger = get_logger("action_manager") +class ActionHandle: + """Action 执行句柄 + + 不依赖任何插件基类,内部持有 executor (async callable) 和绑定参数。 + brain_chat 调用 ``await handle.execute()`` 即可。 + """ + + def __init__(self, executor: ActionExecutor, **kwargs): + self._executor = executor + self._kwargs = kwargs + + async def execute(self) -> Tuple[bool, str]: + return await self._executor(**self._kwargs) + + class ActionManager: """ 动作管理器,用于管理各种类型的动作 - 现在统一使用新插件系统,简化了原有的新旧兼容逻辑。 + 使用核心组件注册表的 executor-based 模式。 """ def __init__(self): @@ -39,9 +53,9 @@ class ActionManager: log_prefix: str, shutting_down: bool = False, action_message: Optional[DatabaseMessages] = None, - ) -> Optional[BaseAction]: + ) -> Optional[ActionHandle]: """ - 创建动作处理器实例 + 创建动作执行句柄 Args: action_name: 动作名称 @@ -52,30 +66,26 @@ class ActionManager: chat_stream: 聊天流 log_prefix: 日志前缀 shutting_down: 是否正在关闭 + action_message: 动作消息记录 Returns: - Optional[BaseAction]: 创建的动作处理器实例,如果动作名称未注册则返回None + Optional[ActionHandle]: 执行句柄,如果动作未注册则返回 None """ try: - # 获取组件类 - 明确指定查询Action类型 - component_class: Type[BaseAction] = component_registry.get_component_class( - action_name, ComponentType.ACTION - ) # type: ignore - if not component_class: + executor = component_registry.get_action_executor(action_name) + if not executor: logger.warning(f"{log_prefix} 未找到Action组件: {action_name}") return None - # 获取组件信息 - component_info = component_registry.get_component_info(action_name, ComponentType.ACTION) - if not component_info: + info = component_registry.get_action_info(action_name) + if not info: logger.warning(f"{log_prefix} 未找到Action组件信息: {action_name}") return None - # 获取插件配置 - plugin_config = component_registry.get_plugin_config(component_info.plugin_name) + plugin_config = component_registry.get_plugin_config(info.plugin_name) or {} - # 创建动作实例 - instance = component_class( + handle = ActionHandle( + executor, action_data=action_data, action_reasoning=action_reasoning, cycle_timers=cycle_timers, @@ -87,11 +97,11 @@ class ActionManager: action_message=action_message, ) - logger.debug(f"创建Action实例成功: {action_name}") - return instance + logger.debug(f"创建Action执行句柄成功: {action_name}") + return handle except Exception as e: - logger.error(f"创建Action实例失败 {action_name}: {e}") + logger.error(f"创建Action执行句柄失败 {action_name}: {e}") import traceback logger.error(traceback.format_exc()) diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py index b7ea858a..a34f42d4 100644 --- a/src/chat/planner_actions/action_modifier.py +++ b/src/chat/planner_actions/action_modifier.py @@ -7,8 +7,8 @@ from src.config.config import global_config from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager from src.chat.planner_actions.action_manager import ActionManager from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages -from src.plugin_system.base.component_types import ActionInfo, ActionActivationType -from src.plugin_system.core.global_announcement_manager import global_announcement_manager +from src.core.types import ActionActivationType, ActionInfo +from src.core.announcement_manager import global_announcement_manager logger = get_logger("action_manager") diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py index c9c7f2f8..e7b013e6 100644 --- a/src/chat/planner_actions/planner.py +++ b/src/chat/planner_actions/planner.py @@ -23,9 +23,9 @@ from src.chat.utils.chat_message_builder import ( from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self from src.chat.planner_actions.action_manager import ActionManager from src.chat.message_receive.chat_manager import chat_manager as _chat_manager -from src.plugin_system.base.component_types import ActionInfo, ComponentType, ActionActivationType -from src.plugin_system.core.component_registry import component_registry -from src.plugin_system.apis.message_api import translate_pid_to_description +from src.core.types import ActionActivationType, ActionInfo, ComponentType +from src.core.component_registry import component_registry +from src.services.message_service import translate_pid_to_description from src.person_info.person_info import Person if TYPE_CHECKING: diff --git a/src/chat/replyer/group_generator.py b/src/chat/replyer/group_generator.py index f49ffda3..1889f144 100644 --- a/src/chat/replyer/group_generator.py +++ b/src/chat/replyer/group_generator.py @@ -27,12 +27,12 @@ from src.chat.utils.chat_message_builder import ( replace_user_references, ) from src.bw_learner.expression_selector import expression_selector -from src.plugin_system.apis.message_api import translate_pid_to_description +from src.services.message_service import translate_pid_to_description # from src.memory_system.memory_activator import MemoryActivator from src.person_info.person_info import Person -from src.plugin_system.base.component_types import ActionInfo, EventType -from src.plugin_system.apis import llm_api +from src.core.types import ActionInfo, EventType +from src.services import llm_service as llm_api from src.chat.logger.plan_reply_logger import PlanReplyLogger from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt @@ -56,7 +56,7 @@ class DefaultReplyer: self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id) self.heart_fc_sender = UniversalMessageSender() - from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor,不然会循环依赖 + from src.chat.tool_executor import ToolExecutor self.tool_executor = ToolExecutor(chat_id=self.chat_stream.session_id, enable_cache=True, cache_ttl=3) @@ -149,11 +149,13 @@ class DefaultReplyer: except Exception: logger.exception("记录reply日志失败") return False, llm_response - from src.plugin_system.core.events_manager import events_manager + from src.core.event_bus import event_bus + from src.chat.event_helpers import build_event_message if not from_plugin: - continue_flag, modified_message = await events_manager.handle_mai_events( - EventType.POST_LLM, None, prompt, None, stream_id=stream_id + _event_msg = build_event_message(EventType.POST_LLM, llm_prompt=prompt, stream_id=stream_id) + continue_flag, modified_message = await event_bus.emit( + EventType.POST_LLM, _event_msg ) if not continue_flag: raise UserWarning("插件于请求前中断了内容生成") @@ -217,8 +219,9 @@ class DefaultReplyer: ) except Exception: logger.exception("记录reply日志失败") - continue_flag, modified_message = await events_manager.handle_mai_events( - EventType.AFTER_LLM, None, prompt, llm_response, stream_id=stream_id + _event_msg = build_event_message(EventType.AFTER_LLM, llm_prompt=prompt, llm_response=llm_response, stream_id=stream_id) + continue_flag, modified_message = await event_bus.emit( + EventType.AFTER_LLM, _event_msg ) if not from_plugin and not continue_flag: raise UserWarning("插件于请求后取消了内容生成") diff --git a/src/chat/replyer/private_generator.py b/src/chat/replyer/private_generator.py index c79b0c54..ee6f98bc 100644 --- a/src/chat/replyer/private_generator.py +++ b/src/chat/replyer/private_generator.py @@ -28,12 +28,12 @@ from src.chat.utils.chat_message_builder import ( replace_user_references, ) from src.bw_learner.expression_selector import expression_selector -from src.plugin_system.apis.message_api import translate_pid_to_description +from src.services.message_service import translate_pid_to_description # from src.memory_system.memory_activator import MemoryActivator from src.person_info.person_info import Person, is_person_known -from src.plugin_system.base.component_types import ActionInfo, EventType -from src.plugin_system.apis import llm_api +from src.core.types import ActionInfo, EventType +from src.services import llm_service as llm_api from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt from src.bw_learner.jargon_explainer import explain_jargon_in_context @@ -55,7 +55,7 @@ class PrivateReplyer: self.heart_fc_sender = UniversalMessageSender() # self.memory_activator = MemoryActivator() - from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor,不然会循环依赖 + from src.chat.tool_executor import ToolExecutor self.tool_executor = ToolExecutor(chat_id=self.chat_stream.session_id, enable_cache=True, cache_ttl=3) @@ -114,11 +114,13 @@ class PrivateReplyer: if not prompt: logger.warning("构建prompt失败,跳过回复生成") return False, llm_response - from src.plugin_system.core.events_manager import events_manager + from src.core.event_bus import event_bus + from src.chat.event_helpers import build_event_message if not from_plugin: - continue_flag, modified_message = await events_manager.handle_mai_events( - EventType.POST_LLM, None, prompt, None, stream_id=stream_id + _event_msg = build_event_message(EventType.POST_LLM, llm_prompt=prompt, stream_id=stream_id) + continue_flag, modified_message = await event_bus.emit( + EventType.POST_LLM, _event_msg ) if not continue_flag: raise UserWarning("插件于请求前中断了内容生成") @@ -138,8 +140,9 @@ class PrivateReplyer: llm_response.reasoning = reasoning_content llm_response.model = model_name llm_response.tool_calls = tool_call - continue_flag, modified_message = await events_manager.handle_mai_events( - EventType.AFTER_LLM, None, prompt, llm_response, stream_id=stream_id + _event_msg = build_event_message(EventType.AFTER_LLM, llm_prompt=prompt, llm_response=llm_response, stream_id=stream_id) + continue_flag, modified_message = await event_bus.emit( + EventType.AFTER_LLM, _event_msg ) if not from_plugin and not continue_flag: raise UserWarning("插件于请求后取消了内容生成") diff --git a/src/plugin_system/core/tool_use.py b/src/chat/tool_executor.py similarity index 53% rename from src/plugin_system/core/tool_use.py rename to src/chat/tool_executor.py index 0cdbb472..d449f7a1 100644 --- a/src/plugin_system/core/tool_use.py +++ b/src/chat/tool_executor.py @@ -1,14 +1,23 @@ +""" +工具执行器 + +独立的工具执行组件,可以直接输入聊天消息内容, +自动判断并执行相应的工具,返回结构化的工具执行结果。 + +从 src.plugin_system.core.tool_use 迁移,使用新的核心组件注册表。 +""" + +import hashlib import time -from typing import List, Dict, Tuple, Optional, Any -from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance -from src.plugin_system.base.base_tool import BaseTool -from src.plugin_system.core.global_announcement_manager import global_announcement_manager -from src.llm_models.utils_model import LLMRequest -from src.llm_models.payload_content import ToolCall -from src.config.config import global_config, model_config -from src.prompt.prompt_manager import prompt_manager -from src.chat.message_receive.chat_manager import chat_manager as _chat_manager +from typing import Any, Dict, List, Optional, Tuple + from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.core.announcement_manager import global_announcement_manager +from src.core.component_registry import component_registry +from src.llm_models.payload_content import ToolCall +from src.llm_models.utils_model import LLMRequest +from src.prompt.prompt_manager import prompt_manager logger = get_logger("tool_use") @@ -20,66 +29,38 @@ class ToolExecutor: """ def __init__(self, chat_id: str, enable_cache: bool = True, cache_ttl: int = 3): - """初始化工具执行器 + from src.chat.message_receive.chat_manager import chat_manager as _chat_manager - Args: - executor_id: 执行器标识符,用于日志记录 - enable_cache: 是否启用缓存机制 - cache_ttl: 缓存生存时间(周期数) - """ self.chat_id = chat_id self.chat_stream = _chat_manager.get_session_by_session_id(self.chat_id) self.log_prefix = f"[{_chat_manager.get_session_name(self.chat_id) or self.chat_id}]" self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor") - # 缓存配置 self.enable_cache = enable_cache self.cache_ttl = cache_ttl - self.tool_cache = {} # 格式: {cache_key: {"result": result, "ttl": ttl, "timestamp": timestamp}} + self.tool_cache: Dict[str, dict] = {} logger.info(f"{self.log_prefix}工具执行器初始化完成,缓存{'启用' if enable_cache else '禁用'},TTL={cache_ttl}") async def execute_from_chat_message( self, target_message: str, chat_history: str, sender: str, return_details: bool = False ) -> Tuple[List[Dict[str, Any]], List[str], str]: - """从聊天消息执行工具 + """从聊天消息执行工具""" - Args: - target_message: 目标消息内容 - chat_history: 聊天历史 - sender: 发送者 - return_details: 是否返回详细信息(使用的工具列表和提示词) - - Returns: - 如果return_details为False: Tuple[List[Dict], List[str], str] - (工具执行结果列表, 空, 空) - 如果return_details为True: Tuple[List[Dict], List[str], str] - (结果列表, 使用的工具, 提示词) - """ - - # 首先检查缓存 cache_key = self._generate_cache_key(target_message, chat_history, sender) if cached_result := self._get_from_cache(cache_key): logger.info(f"{self.log_prefix}使用缓存结果,跳过工具执行") if not return_details: return cached_result, [], "" - - # 从缓存结果中提取工具名称 used_tools = [result.get("tool_name", "unknown") for result in cached_result] return cached_result, used_tools, "" - # 缓存未命中,执行工具调用 - # 获取可用工具 tools = self._get_tool_definitions() - - # 如果没有可用工具,直接返回空内容 if not tools: logger.debug(f"{self.log_prefix}没有可用工具,直接返回空内容") - if return_details: - return [], [], "" - else: - return [], [], "" + return [], [], "" - # 构建工具调用提示词 prompt_template = prompt_manager.get_prompt("tool_executor") prompt_template.add_context("target_message", target_message) prompt_template.add_context("chat_history", chat_history) @@ -90,15 +71,12 @@ class ToolExecutor: logger.debug(f"{self.log_prefix}开始LLM工具调用分析") - # 调用LLM进行工具决策 response, (reasoning_content, model_name, tool_calls) = await self.llm_model.generate_response_async( prompt=prompt, tools=tools, raise_when_empty=False ) - # 执行工具调用 tool_results, used_tools = await self.execute_tool_calls(tool_calls) - # 缓存结果 if tool_results: self._set_cache(cache_key, tool_results) @@ -107,42 +85,30 @@ class ToolExecutor: if return_details: return tool_results, used_tools, prompt - else: - return tool_results, [], "" + return tool_results, [], "" def _get_tool_definitions(self) -> List[Dict[str, Any]]: - all_tools = get_llm_available_tool_definitions() + """获取 LLM 可用的工具定义列表""" + all_tools = component_registry.get_llm_available_tools() user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id) - return [definition for name, definition in all_tools if name not in user_disabled_tools] + return [info.get_llm_definition() for name, info in all_tools.items() if name not in user_disabled_tools] async def execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]: - """执行工具调用 - - Args: - tool_calls: LLM返回的工具调用列表 - - Returns: - Tuple[List[Dict], List[str]]: (工具执行结果列表, 使用的工具名称列表) - """ + """执行工具调用列表""" tool_results: List[Dict[str, Any]] = [] - used_tools = [] + used_tools: List[str] = [] if not tool_calls: logger.debug(f"{self.log_prefix}无需执行工具") return [], [] - # 提取tool_calls中的函数名称 func_names = [call.func_name for call in tool_calls if call.func_name] - logger.info(f"{self.log_prefix}开始执行工具调用: {func_names}") - # 执行每个工具调用 for tool_call in tool_calls: + tool_name = tool_call.func_name try: - tool_name = tool_call.func_name logger.debug(f"{self.log_prefix}执行工具: {tool_name}") - - # 执行工具 result = await self.execute_tool_call(tool_call) if result: @@ -156,7 +122,6 @@ class ToolExecutor: content = tool_info["content"] if not isinstance(content, (str, list, tuple)): tool_info["content"] = str(content) - # 空内容直接跳过(空字符串、全空白字符串、空列表/空元组) content_check = tool_info["content"] if (isinstance(content_check, str) and not content_check.strip()) or ( isinstance(content_check, (list, tuple)) and len(content_check) == 0 @@ -166,11 +131,10 @@ class ToolExecutor: tool_results.append(tool_info) used_tools.append(tool_name) - preview = content[:200] + preview = str(content)[:200] logger.debug(f"{self.log_prefix}工具{tool_name}结果内容: {preview}...") except Exception as e: logger.error(f"{self.log_prefix}工具{tool_name}执行失败: {e}") - # 添加错误信息到结果中 error_info = { "type": "tool_error", "id": f"tool_error_{time.time()}", @@ -182,122 +146,30 @@ class ToolExecutor: return tool_results, used_tools - async def execute_tool_call( - self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None - ) -> Optional[Dict[str, Any]]: - # sourcery skip: use-assigned-variable - """执行单个工具调用 + async def execute_tool_call(self, tool_call: ToolCall) -> Optional[Dict[str, Any]]: + """执行单个工具调用""" + function_name = tool_call.func_name + function_args = tool_call.args or {} + function_args["llm_called"] = True - Args: - tool_call: 工具调用对象 - - Returns: - Optional[Dict]: 工具调用结果,如果失败则返回None - """ - try: - function_name = tool_call.func_name - function_args = tool_call.args or {} - function_args["llm_called"] = True # 标记为LLM调用 - - # 获取对应工具实例 - tool_instance = tool_instance or get_tool_instance(function_name, self.chat_stream) - if not tool_instance: - logger.warning(f"未知工具名称: {function_name}") - return None - - # 执行工具 - result = await tool_instance.execute(function_args) - if result: - return { - "tool_call_id": tool_call.call_id, - "role": "tool", - "name": function_name, - "type": "function", - "content": result["content"], - } - return None - except Exception as e: - logger.error(f"执行工具调用时发生错误: {str(e)}") - raise e - - def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str: - """生成缓存键 - - Args: - target_message: 目标消息内容 - chat_history: 聊天历史 - sender: 发送者 - - Returns: - str: 缓存键 - """ - import hashlib - - # 使用消息内容和群聊状态生成唯一缓存键 - content = f"{target_message}_{chat_history}_{sender}" - return hashlib.md5(content.encode()).hexdigest() - - def _get_from_cache(self, cache_key: str) -> Optional[List[Dict]]: - """从缓存获取结果 - - Args: - cache_key: 缓存键 - - Returns: - Optional[List[Dict]]: 缓存的结果,如果不存在或过期则返回None - """ - if not self.enable_cache or cache_key not in self.tool_cache: + executor = component_registry.get_tool_executor(function_name) + if not executor: + logger.warning(f"未知工具名称: {function_name}") return None - cache_item = self.tool_cache[cache_key] - if cache_item["ttl"] <= 0: - # 缓存过期,删除 - del self.tool_cache[cache_key] - logger.debug(f"{self.log_prefix}缓存过期,删除缓存键: {cache_key}") - return None - - # 减少TTL - cache_item["ttl"] -= 1 - logger.debug(f"{self.log_prefix}使用缓存结果,剩余TTL: {cache_item['ttl']}") - return cache_item["result"] - - def _set_cache(self, cache_key: str, result: List[Dict]): - """设置缓存 - - Args: - cache_key: 缓存键 - result: 要缓存的结果 - """ - if not self.enable_cache: - return - - self.tool_cache[cache_key] = {"result": result, "ttl": self.cache_ttl, "timestamp": time.time()} - logger.debug(f"{self.log_prefix}设置缓存,TTL: {self.cache_ttl}") - - def _cleanup_expired_cache(self): - """清理过期的缓存""" - if not self.enable_cache: - return - - expired_keys = [] - expired_keys.extend(cache_key for cache_key, cache_item in self.tool_cache.items() if cache_item["ttl"] <= 0) - for key in expired_keys: - del self.tool_cache[key] - - if expired_keys: - logger.debug(f"{self.log_prefix}清理了{len(expired_keys)}个过期缓存") + result = await executor(function_args) + if result: + return { + "tool_call_id": tool_call.call_id, + "role": "tool", + "name": function_name, + "type": "function", + "content": result["content"], + } + return None async def execute_specific_tool_simple(self, tool_name: str, tool_args: Dict) -> Optional[Dict]: - """直接执行指定工具 - - Args: - tool_name: 工具名称 - tool_args: 工具参数 - validate_args: 是否验证参数 - - Returns: - Optional[Dict]: 工具执行结果,失败时返回None - """ + """直接执行指定工具""" try: tool_call = ToolCall( call_id=f"direct_tool_{time.time()}", @@ -306,7 +178,6 @@ class ToolExecutor: ) logger.info(f"{self.log_prefix}直接执行工具: {tool_name}") - result = await self.execute_tool_call(tool_call) if result: @@ -325,86 +196,55 @@ class ToolExecutor: return None + # === 缓存方法 === + + def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str: + content = f"{target_message}_{chat_history}_{sender}" + return hashlib.md5(content.encode()).hexdigest() + + def _get_from_cache(self, cache_key: str) -> Optional[List[Dict]]: + if not self.enable_cache or cache_key not in self.tool_cache: + return None + cache_item = self.tool_cache[cache_key] + if cache_item["ttl"] <= 0: + del self.tool_cache[cache_key] + return None + cache_item["ttl"] -= 1 + return cache_item["result"] + + def _set_cache(self, cache_key: str, result: List[Dict]): + if not self.enable_cache: + return + self.tool_cache[cache_key] = {"result": result, "ttl": self.cache_ttl, "timestamp": time.time()} + + def _cleanup_expired_cache(self): + if not self.enable_cache: + return + expired = [k for k, v in self.tool_cache.items() if v["ttl"] <= 0] + for key in expired: + del self.tool_cache[key] + def clear_cache(self): - """清空所有缓存""" if self.enable_cache: - cache_count = len(self.tool_cache) self.tool_cache.clear() - logger.info(f"{self.log_prefix}清空了{cache_count}个缓存项") def get_cache_status(self) -> Dict: - """获取缓存状态信息 - - Returns: - Dict: 包含缓存统计信息的字典 - """ if not self.enable_cache: return {"enabled": False, "cache_count": 0} - - # 清理过期缓存 self._cleanup_expired_cache() - - total_count = len(self.tool_cache) - ttl_distribution = {} - - for cache_item in self.tool_cache.values(): - ttl = cache_item["ttl"] + ttl_distribution: Dict[int, int] = {} + for item in self.tool_cache.values(): + ttl = item["ttl"] ttl_distribution[ttl] = ttl_distribution.get(ttl, 0) + 1 - return { "enabled": True, - "cache_count": total_count, + "cache_count": len(self.tool_cache), "cache_ttl": self.cache_ttl, "ttl_distribution": ttl_distribution, } def set_cache_config(self, enable_cache: Optional[bool] = None, cache_ttl: int = -1): - """动态修改缓存配置 - - Args: - enable_cache: 是否启用缓存 - cache_ttl: 缓存TTL - """ if enable_cache is not None: self.enable_cache = enable_cache - logger.info(f"{self.log_prefix}缓存状态修改为: {'启用' if enable_cache else '禁用'}") - if cache_ttl > 0: self.cache_ttl = cache_ttl - logger.info(f"{self.log_prefix}缓存TTL修改为: {cache_ttl}") - - -""" -ToolExecutor使用示例: - -# 1. 基础使用 - 从聊天消息执行工具(启用缓存,默认TTL=3) -executor = ToolExecutor(executor_id="my_executor") -results, _, _ = await executor.execute_from_chat_message( - talking_message_str="今天天气怎么样?现在几点了?", - is_group_chat=False -) - -# 2. 禁用缓存的执行器 -no_cache_executor = ToolExecutor(executor_id="no_cache", enable_cache=False) - -# 3. 自定义缓存TTL -long_cache_executor = ToolExecutor(executor_id="long_cache", cache_ttl=10) - -# 4. 获取详细信息 -results, used_tools, prompt = await executor.execute_from_chat_message( - talking_message_str="帮我查询Python相关知识", - is_group_chat=False, - return_details=True -) - -# 5. 直接执行特定工具 -result = await executor.execute_specific_tool_simple( - tool_name="get_knowledge", - tool_args={"query": "机器学习"} -) - -# 6. 缓存管理 -cache_status = executor.get_cache_status() # 查看缓存状态 -executor.clear_cache() # 清空缓存 -executor.set_cache_config(cache_ttl=5) # 动态修改缓存配置 -""" diff --git a/src/common/data_models/info_data_model.py b/src/common/data_models/info_data_model.py index 24cbc640..13e53eb3 100644 --- a/src/common/data_models/info_data_model.py +++ b/src/common/data_models/info_data_model.py @@ -4,7 +4,7 @@ from . import BaseDataModel if TYPE_CHECKING: from .database_data_model import DatabaseMessages - from src.plugin_system.base.component_types import ActionInfo + from src.core.types import ActionInfo @dataclass diff --git a/src/core/__init__.py b/src/core/__init__.py new file mode 100644 index 00000000..82910380 --- /dev/null +++ b/src/core/__init__.py @@ -0,0 +1,6 @@ +""" +MaiBot 核心基础设施 + +提供与插件系统无关的核心类型定义、配置 schema 等基础设施。 +这些类型被整个项目共享,包括内部模块、服务层、旧插件系统和新插件运行时。 +""" diff --git a/src/plugin_system/core/global_announcement_manager.py b/src/core/announcement_manager.py similarity index 100% rename from src/plugin_system/core/global_announcement_manager.py rename to src/core/announcement_manager.py diff --git a/src/core/component_registry.py b/src/core/component_registry.py new file mode 100644 index 00000000..b9c7af73 --- /dev/null +++ b/src/core/component_registry.py @@ -0,0 +1,242 @@ +""" +核心组件注册表 + +面向最终架构的组件管理: +- Action:注册 ActionInfo + 执行器(本地 callable 或 IPC 路由) +- Command:注册正则模式 + 执行器 +- Tool:注册工具定义 + 执行器 + +不依赖任何插件基类,组件执行器是纯 async callable。 +""" + +import re +from typing import Any, Awaitable, Callable, Dict, List, Optional, Pattern, Tuple, Union + +from src.common.logger import get_logger +from src.core.types import ( + ActionActivationType, + ActionInfo, + CommandInfo, + ComponentInfo, + ComponentType, + ToolInfo, +) + +logger = get_logger("component_registry") + +# 执行器类型 +ActionExecutor = Callable[..., Awaitable[Any]] +CommandExecutor = Callable[..., Awaitable[Tuple[bool, Optional[str], bool]]] +ToolExecutor = Callable[..., Awaitable[Any]] + + +class ComponentRegistry: + """核心组件注册表 + + 管理 action、command、tool 三类组件。 + 每个组件由「元信息 + 执行器」构成,执行器是 async callable, + 不需要继承任何基类。 + """ + + def __init__(self): + # Action 注册 + self._actions: Dict[str, ActionInfo] = {} + self._action_executors: Dict[str, ActionExecutor] = {} + self._default_actions: Dict[str, ActionInfo] = {} + + # Command 注册 + self._commands: Dict[str, CommandInfo] = {} + self._command_executors: Dict[str, CommandExecutor] = {} + self._command_patterns: Dict[Pattern, str] = {} + + # Tool 注册 + self._tools: Dict[str, ToolInfo] = {} + self._tool_executors: Dict[str, ToolExecutor] = {} + self._llm_available_tools: Dict[str, ToolInfo] = {} + + # 插件配置(plugin_name -> config dict) + self._plugin_configs: Dict[str, dict] = {} + + logger.info("核心组件注册表初始化完成") + + # ========== Action ========== + + def register_action( + self, + info: ActionInfo, + executor: ActionExecutor, + ) -> bool: + """注册 action + + Args: + info: action 元信息 + executor: 执行器,async callable + """ + name = info.name + if name in self._actions: + logger.warning(f"Action {name} 已存在,跳过注册") + return False + + self._actions[name] = info + self._action_executors[name] = executor + + if info.enabled: + self._default_actions[name] = info + + logger.debug(f"注册 Action: {name}") + return True + + def get_action_info(self, name: str) -> Optional[ActionInfo]: + return self._actions.get(name) + + def get_action_executor(self, name: str) -> Optional[ActionExecutor]: + return self._action_executors.get(name) + + def get_default_actions(self) -> Dict[str, ActionInfo]: + return self._default_actions.copy() + + def get_all_actions(self) -> Dict[str, ActionInfo]: + return self._actions.copy() + + def remove_action(self, name: str) -> bool: + if name not in self._actions: + return False + del self._actions[name] + self._action_executors.pop(name, None) + self._default_actions.pop(name, None) + logger.debug(f"移除 Action: {name}") + return True + + # ========== Command ========== + + def register_command( + self, + info: CommandInfo, + executor: CommandExecutor, + ) -> bool: + """注册 command""" + name = info.name + if name in self._commands: + logger.warning(f"Command {name} 已存在,跳过注册") + return False + + self._commands[name] = info + self._command_executors[name] = executor + + if info.enabled and info.command_pattern: + pattern = re.compile(info.command_pattern, re.IGNORECASE | re.DOTALL) + self._command_patterns[pattern] = name + + logger.debug(f"注册 Command: {name}") + return True + + def find_command_by_text( + self, text: str + ) -> Optional[Tuple[CommandExecutor, dict, CommandInfo]]: + """根据文本查找匹配的命令 + + Returns: + (executor, matched_groups, command_info) 或 None + """ + candidates = [p for p in self._command_patterns if p.match(text)] + if not candidates: + return None + if len(candidates) > 1: + logger.warning(f"文本 '{text[:50]}' 匹配到多个命令模式,使用第一个") + pattern = candidates[0] + name = self._command_patterns[pattern] + return ( + self._command_executors[name], + pattern.match(text).groupdict(), # type: ignore + self._commands[name], + ) + + def remove_command(self, name: str) -> bool: + if name not in self._commands: + return False + del self._commands[name] + self._command_executors.pop(name, None) + self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != name} + logger.debug(f"移除 Command: {name}") + return True + + # ========== Tool ========== + + def register_tool( + self, + info: ToolInfo, + executor: ToolExecutor, + ) -> bool: + """注册 tool""" + name = info.name + if name in self._tools: + logger.warning(f"Tool {name} 已存在,跳过注册") + return False + + self._tools[name] = info + self._tool_executors[name] = executor + + if info.enabled: + self._llm_available_tools[name] = info + + logger.debug(f"注册 Tool: {name}") + return True + + def get_tool_info(self, name: str) -> Optional[ToolInfo]: + return self._tools.get(name) + + def get_tool_executor(self, name: str) -> Optional[ToolExecutor]: + return self._tool_executors.get(name) + + def get_llm_available_tools(self) -> Dict[str, ToolInfo]: + return self._llm_available_tools.copy() + + def get_all_tools(self) -> Dict[str, ToolInfo]: + return self._tools.copy() + + def remove_tool(self, name: str) -> bool: + if name not in self._tools: + return False + del self._tools[name] + self._tool_executors.pop(name, None) + self._llm_available_tools.pop(name, None) + logger.debug(f"移除 Tool: {name}") + return True + + # ========== 通用查询 ========== + + def get_component_info(self, name: str, component_type: ComponentType) -> Optional[ComponentInfo]: + """获取组件元信息""" + match component_type: + case ComponentType.ACTION: + return self._actions.get(name) + case ComponentType.COMMAND: + return self._commands.get(name) + case ComponentType.TOOL: + return self._tools.get(name) + case _: + return None + + def get_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]: + """获取某类型的所有组件""" + match component_type: + case ComponentType.ACTION: + return dict(self._actions) + case ComponentType.COMMAND: + return dict(self._commands) + case ComponentType.TOOL: + return dict(self._tools) + case _: + return {} + + # ========== 插件配置 ========== + + def set_plugin_config(self, plugin_name: str, config: dict) -> None: + self._plugin_configs[plugin_name] = config + + def get_plugin_config(self, plugin_name: str) -> Optional[dict]: + return self._plugin_configs.get(plugin_name) + + +# 全局单例 +component_registry = ComponentRegistry() diff --git a/src/plugin_system/base/config_types.py b/src/core/config_types.py similarity index 100% rename from src/plugin_system/base/config_types.py rename to src/core/config_types.py diff --git a/src/core/event_bus.py b/src/core/event_bus.py new file mode 100644 index 00000000..f4dbed46 --- /dev/null +++ b/src/core/event_bus.py @@ -0,0 +1,216 @@ +""" +核心事件总线 + +面向最终架构的事件系统: +- 内部 handler 直接注册 async callable +- IPC 插件通过 plugin_runtime 桥接 +- 不依赖任何插件基类 +""" + +import asyncio +from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING + +from src.common.logger import get_logger +from src.core.types import EventType, MaiMessages + +if TYPE_CHECKING: + from src.common.data_models.llm_data_model import LLMGenerationDataModel + +logger = get_logger("event_bus") + +# Handler 签名:接收 MaiMessages,返回 (continue, modified_message) +EventHandler = Callable[[Optional[MaiMessages]], Awaitable[Tuple[bool, Optional[MaiMessages]]]] + + +class EventBus: + """核心事件总线 + + 支持两种 handler: + - 拦截型(intercept=True):同步顺序执行,可修改消息、可中断流程 + - 非拦截型(intercept=False):异步并发执行,fire-and-forget + + handler 是纯 async callable,不需要继承任何基类。 + """ + + def __init__(self): + # event_type -> [(handler, name, weight, intercept)] + self._handlers: Dict[EventType | str, List[_HandlerEntry]] = {} + self._running_tasks: Dict[str, List[asyncio.Task]] = {} + + # 预注册所有内置事件类型 + for event in EventType: + self._handlers[event] = [] + + def subscribe( + self, + event_type: EventType | str, + handler: EventHandler, + name: str, + weight: int = 0, + intercept: bool = False, + ) -> None: + """注册事件 handler + + Args: + event_type: 事件类型 + handler: async callable,签名 (Optional[MaiMessages]) -> (bool, Optional[MaiMessages]) + name: handler 标识名 + weight: 权重,越大越先执行 + intercept: 是否为拦截型(同步执行,可中断流程) + """ + if event_type not in self._handlers: + self._handlers[event_type] = [] + + entry = _HandlerEntry(handler=handler, name=name, weight=weight, intercept=intercept) + self._handlers[event_type].append(entry) + self._handlers[event_type].sort(key=lambda e: e.weight, reverse=True) + logger.debug(f"注册事件 handler: {name} -> {event_type} (weight={weight}, intercept={intercept})") + + def unsubscribe(self, event_type: EventType | str, name: str) -> bool: + """取消注册事件 handler""" + handlers = self._handlers.get(event_type, []) + for i, entry in enumerate(handlers): + if entry.name == name: + del handlers[i] + logger.debug(f"取消注册事件 handler: {name} <- {event_type}") + return True + return False + + async def emit( + self, + event_type: EventType | str, + message: Optional[MaiMessages] = None, + ) -> Tuple[bool, Optional[MaiMessages]]: + """触发事件 + + 按权重顺序执行所有 handler: + - 拦截型 handler 同步执行,可修改消息和中断流程 + - 非拦截型 handler 异步 fire-and-forget + + Args: + event_type: 事件类型 + message: 事件消息(可选) + + Returns: + (continue_flag, modified_message) + - continue_flag: False 表示某个拦截型 handler 要求中断 + - modified_message: 被拦截型 handler 修改后的消息 + """ + handlers = self._handlers.get(event_type, []) + if not handlers: + return True, None + + continue_flag = True + current_message = message.deepcopy() if message else None + + for entry in handlers: + if entry.intercept: + try: + should_continue, modified = await entry.handler(current_message) + if modified is not None: + current_message = modified + if not should_continue: + continue_flag = False + break + except Exception as e: + logger.error(f"拦截型 handler {entry.name} 执行异常: {e}", exc_info=True) + else: + self._fire_and_forget(entry, event_type, current_message) + + # 桥接到 IPC 插件运行时 + continue_flag, current_message = await self._bridge_to_ipc_runtime( + event_type, continue_flag, current_message + ) + + return continue_flag, current_message + + async def cancel_handler_tasks(self, handler_name: str) -> None: + """取消某个 handler 的所有运行中任务""" + tasks = self._running_tasks.pop(handler_name, []) + remaining = [t for t in tasks if not t.done()] + if remaining: + for t in remaining: + t.cancel() + await asyncio.gather(*remaining, return_exceptions=True) + logger.info(f"已取消 handler {handler_name} 的 {len(remaining)} 个任务") + + # --- 内部方法 --- + + def _fire_and_forget( + self, + entry: "_HandlerEntry", + event_type: EventType | str, + message: Optional[MaiMessages], + ) -> None: + """创建异步任务执行非拦截型 handler""" + try: + task = asyncio.create_task(entry.handler(message)) + task.set_name(entry.name) + task.add_done_callback(lambda t: self._task_done_callback(t, entry.name)) + self._running_tasks.setdefault(entry.name, []).append(task) + except Exception as e: + logger.error(f"创建 handler 任务 {entry.name} 失败: {e}", exc_info=True) + + def _task_done_callback(self, task: asyncio.Task, handler_name: str) -> None: + """异步任务完成回调""" + try: + if task.cancelled(): + return + exc = task.exception() + if exc: + logger.error(f"handler {handler_name} 异步任务异常: {exc}") + except Exception: + pass + finally: + task_list = self._running_tasks.get(handler_name, []) + try: + task_list.remove(task) + except ValueError: + pass + + async def _bridge_to_ipc_runtime( + self, + event_type: EventType | str, + continue_flag: bool, + message: Optional[MaiMessages], + ) -> Tuple[bool, Optional[MaiMessages]]: + """将事件桥接到 IPC 插件运行时""" + if not continue_flag: + return continue_flag, message + + try: + from src.plugin_runtime.integration import get_plugin_runtime_manager + + prm = get_plugin_runtime_manager() + if not prm.is_running: + return continue_flag, message + + event_value = event_type.value if isinstance(event_type, EventType) else str(event_type) + message_dict = message.to_dict() if message and hasattr(message, "to_dict") else None + + new_continue, _ = await prm.bridge_event( + event_type_value=event_value, + message_dict=message_dict, + ) + if not new_continue: + continue_flag = False + except Exception as e: + logger.warning(f"桥接事件到 IPC 运行时失败: {e}") + + return continue_flag, message + + +class _HandlerEntry: + """内部 handler 条目""" + + __slots__ = ("handler", "name", "weight", "intercept") + + def __init__(self, handler: EventHandler, name: str, weight: int, intercept: bool): + self.handler = handler + self.name = name + self.weight = weight + self.intercept = intercept + + +# 全局单例 +event_bus = EventBus() diff --git a/src/plugin_system/base/component_types.py b/src/core/types.py similarity index 98% rename from src/plugin_system/base/component_types.py rename to src/core/types.py index d9d6a06f..d7975576 100644 --- a/src/plugin_system/base/component_types.py +++ b/src/core/types.py @@ -169,6 +169,14 @@ class ToolInfo(ComponentInfo): super().__post_init__() self.component_type = ComponentType.TOOL + def get_llm_definition(self) -> dict: + """生成 LLM function-calling 所需的工具定义""" + return { + "name": self.name, + "description": self.tool_description, + "parameters": self.tool_parameters, + } + @dataclass class EventHandlerInfo(ComponentInfo): diff --git a/src/dream/dream_agent.py b/src/dream/dream_agent.py index 79acceeb..97b4b5d9 100644 --- a/src/dream/dream_agent.py +++ b/src/dream/dream_agent.py @@ -10,7 +10,7 @@ from src.config.config import global_config, model_config from src.common.database.database_model import ChatHistory from src.prompt.prompt_manager import prompt_manager from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message -from src.plugin_system.apis import llm_api +from src.services import llm_service as llm_api from src.dream.dream_generator import generate_dream_summary # dream 工具工厂函数 diff --git a/src/dream/dream_generator.py b/src/dream/dream_generator.py index 4934bee1..ff709a5f 100644 --- a/src/dream/dream_generator.py +++ b/src/dream/dream_generator.py @@ -8,7 +8,7 @@ from src.llm_models.payload_content.message import RoleType, Message from src.prompt.prompt_manager import prompt_manager from src.llm_models.utils_model import LLMRequest from src.common.utils.utils_session import SessionUtils -from src.plugin_system.apis import send_api +from src.services import send_service as send_api logger = get_logger("dream_generator") diff --git a/src/dream/tools/update_chat_history_tool.py b/src/dream/tools/update_chat_history_tool.py index acb98d45..2853dddd 100644 --- a/src/dream/tools/update_chat_history_tool.py +++ b/src/dream/tools/update_chat_history_tool.py @@ -2,7 +2,7 @@ from typing import Any, Dict, Optional from src.common.logger import get_logger from src.common.database.database_model import ChatHistory -from src.plugin_system.apis import database_api +from src.services import database_service as database_api logger = get_logger("dream_agent") diff --git a/src/dream/tools/update_jargon_tool.py b/src/dream/tools/update_jargon_tool.py index 1d559cf6..7ef17cb6 100644 --- a/src/dream/tools/update_jargon_tool.py +++ b/src/dream/tools/update_jargon_tool.py @@ -2,7 +2,7 @@ from typing import Any, Dict, Optional from src.common.logger import get_logger from src.common.database.database_model import Jargon -from src.plugin_system.apis import database_api +from src.services import database_service as database_api logger = get_logger("dream_agent") diff --git a/src/main.py b/src/main.py index c3634037..8b43c9f8 100644 --- a/src/main.py +++ b/src/main.py @@ -18,10 +18,7 @@ from rich.traceback import install # from src.api.main import start_api_server -# 导入新的插件管理器 -from src.plugin_system.core.plugin_manager import plugin_manager - -# 导入新版本插件运行时 +# 导入插件运行时 from src.plugin_runtime.integration import get_plugin_runtime_manager # 导入消息API和traceback模块 @@ -31,8 +28,6 @@ from src.bw_learner.expression_auto_check_task import ExpressionAutoCheckTask from src.prompt.prompt_manager import prompt_manager -# 插件系统现在使用统一的插件加载器 - install(extra_lines=3) logger = get_logger("main") @@ -108,10 +103,7 @@ class MainSystem: # 启动LPMM lpmm_start_up() - # 加载所有actions,包括默认的和插件的 - plugin_manager.load_all_plugins() - - # 启动新版本插件运行时(与旧系统并行运行) + # 启动插件运行时(内置插件 + 第三方插件双子进程) await get_plugin_runtime_manager().start() # 初始化表情管理器 @@ -133,12 +125,12 @@ class MainSystem: prompt_manager.load_prompts() # 触发 ON_START 事件 - from src.plugin_system.core.events_manager import events_manager - from src.plugin_system.base.component_types import EventType + from src.core.event_bus import event_bus + from src.core.types import EventType - await events_manager.handle_mai_events(event_type=EventType.ON_START) + await event_bus.emit(event_type=EventType.ON_START) - # 桥接 ON_START 事件到新版本插件运行时 + # 分发 ON_START 事件到插件运行时 await get_plugin_runtime_manager().bridge_event("on_start") # logger.info("已触发 ON_START 事件") try: diff --git a/src/memory_system/chat_history_summarizer.py b/src/memory_system/chat_history_summarizer.py index f8ff7adc..7dabdd98 100644 --- a/src/memory_system/chat_history_summarizer.py +++ b/src/memory_system/chat_history_summarizer.py @@ -17,7 +17,7 @@ from src.common.logger import get_logger from src.common.data_models.database_data_model import DatabaseMessages from src.config.config import model_config, global_config from src.llm_models.utils_model import LLMRequest -from src.plugin_system.apis import message_api +from src.services import message_service as message_api from src.chat.utils.chat_message_builder import build_readable_messages from src.chat.utils.utils import is_bot_self from src.person_info.person_info import Person @@ -913,7 +913,7 @@ class ChatHistorySummarizer: """存储到数据库""" try: from src.common.database.database_model import ChatHistory - from src.plugin_system.apis import database_api + from src.services import database_service as database_api # 准备数据 data = { diff --git a/src/memory_system/memory_retrieval.py b/src/memory_system/memory_retrieval.py index e0d4fcda..a14a9c67 100644 --- a/src/memory_system/memory_retrieval.py +++ b/src/memory_system/memory_retrieval.py @@ -7,7 +7,7 @@ from typing import List, Dict, Any, Optional, Tuple, Callable from src.common.logger import get_logger from src.config.config import global_config, model_config from src.prompt.prompt_manager import prompt_manager -from src.plugin_system.apis import llm_api +from src.services import llm_service as llm_api from sqlmodel import select, col from src.common.database.database import get_db_session from src.common.database.database_model import ThinkingQuestion diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py index 111f72d5..7c6a5be4 100644 --- a/src/plugin_runtime/integration.py +++ b/src/plugin_runtime/integration.py @@ -1,21 +1,17 @@ -"""新版本插件运行时与主程序的集成层 +"""插件运行时与主程序的集成层 提供 PluginRuntimeManager 单例,负责: -1. 管理 PluginSupervisor 的生命周期(启动 / 停止) -2. 将旧系统的 EventType 桥接到新运行时的 event dispatch -3. 将新运行时注册的 command 合并到旧系统的命令查找流程 -4. 提供统一的能力实现注册接口,使新插件可以调用主程序功能 - -在过渡期内,新旧插件系统共存: -- 旧插件继续通过 plugin_manager / component_registry 加载和执行 -- 新插件通过 PluginSupervisor + Runner 子进程加载和执行 -- 事件和命令在两套系统间桥接 +1. 管理双 PluginSupervisor 的生命周期(内置插件 / 第三方插件各一个子进程) +2. 将 EventType 桥接到运行时的 event dispatch +3. 在运行时的 ComponentRegistry 中查找命令 +4. 提供统一的能力实现注册接口,使插件可以调用主程序功能 """ from __future__ import annotations from typing import Any +import asyncio import os from src.common.logger import get_logger @@ -38,81 +34,111 @@ _EVENT_TYPE_MAP: dict[str, str] = { class PluginRuntimeManager: - """新版本插件运行时管理器(单例) + """插件运行时管理器(单例) - 作为主程序与 PluginSupervisor 之间的桥梁。 + 内置插件与第三方插件分别运行在各自的 Supervisor / Runner 子进程中。 """ def __init__(self) -> None: from src.plugin_runtime.host.supervisor import PluginSupervisor - self._supervisor: PluginSupervisor | None = None + self._builtin_supervisor: PluginSupervisor | None = None + self._thirdparty_supervisor: PluginSupervisor | None = None self._started: bool = False - def _get_plugin_dirs(self) -> list[str]: - """获取新版本插件目录列表 + # ─── 插件目录 ───────────────────────────────────────────── - 新版本插件放在 plugins/ 目录中,与旧版本共存。 - 只有包含 _manifest.json 的插件目录会被新 Runner 加载。 - """ - dirs: list[str] = [] - for candidate in ("plugins",): - abs_path: str = os.path.abspath(candidate) - if os.path.isdir(abs_path): - dirs.append(abs_path) - return dirs + @staticmethod + def _get_builtin_plugin_dirs() -> list[str]: + """内置插件目录: src/plugins/built_in/""" + candidate = os.path.abspath(os.path.join("src", "plugins", "built_in")) + return [candidate] if os.path.isdir(candidate) else [] + + @staticmethod + def _get_thirdparty_plugin_dirs() -> list[str]: + """第三方插件目录: plugins/""" + candidate = os.path.abspath("plugins") + return [candidate] if os.path.isdir(candidate) else [] + + # ─── 生命周期 ───────────────────────────────────────────── async def start(self) -> None: - """启动新版本插件运行时 - - 应在 plugin_manager.load_all_plugins() 之后调用。 - """ + """启动双子进程插件运行时""" if self._started: logger.warning("PluginRuntimeManager 已在运行中,跳过重复启动") return from src.plugin_runtime.host.supervisor import PluginSupervisor - plugin_dirs: list[str] = self._get_plugin_dirs() - if not plugin_dirs: - logger.info("未找到插件目录,跳过新版本插件运行时启动") + builtin_dirs = self._get_builtin_plugin_dirs() + thirdparty_dirs = self._get_thirdparty_plugin_dirs() + + if not builtin_dirs and not thirdparty_dirs: + logger.info("未找到任何插件目录,跳过插件运行时启动") return - self._supervisor = PluginSupervisor(plugin_dirs=plugin_dirs) + # 创建两个 Supervisor,各自拥有独立的 socket / Runner 子进程 + if builtin_dirs: + self._builtin_supervisor = PluginSupervisor( + plugin_dirs=builtin_dirs, + socket_path=None, # 自动生成 + ) + self._register_capability_impls(self._builtin_supervisor) - # 注册主程序提供的能力实现 - self._register_capability_impls() + if thirdparty_dirs: + self._thirdparty_supervisor = PluginSupervisor( + plugin_dirs=thirdparty_dirs, + socket_path=None, + ) + self._register_capability_impls(self._thirdparty_supervisor) + + # 并行启动 + coros = [] + if self._builtin_supervisor: + coros.append(self._builtin_supervisor.start()) + if self._thirdparty_supervisor: + coros.append(self._thirdparty_supervisor.start()) try: - await self._supervisor.start() + await asyncio.gather(*coros) self._started = True - logger.info(f"新版本插件运行时已启动,监控目录: {plugin_dirs}") + logger.info( + f"插件运行时已启动 — 内置: {builtin_dirs or '无'}, 第三方: {thirdparty_dirs or '无'}" + ) except Exception as e: - logger.error(f"新版本插件运行时启动失败: {e}", exc_info=True) - self._supervisor = None + logger.error(f"插件运行时启动失败: {e}", exc_info=True) + self._builtin_supervisor = None + self._thirdparty_supervisor = None async def stop(self) -> None: - """停止新版本插件运行时""" - if not self._started or self._supervisor is None: + """停止所有插件运行时""" + if not self._started: return + coros = [] + if self._builtin_supervisor: + coros.append(self._builtin_supervisor.stop()) + if self._thirdparty_supervisor: + coros.append(self._thirdparty_supervisor.stop()) + try: - await self._supervisor.stop() - logger.info("新版本插件运行时已停止") + await asyncio.gather(*coros, return_exceptions=True) + logger.info("插件运行时已停止") except Exception as e: - logger.error(f"新版本插件运行时停止失败: {e}", exc_info=True) + logger.error(f"插件运行时停止失败: {e}", exc_info=True) finally: self._started = False - self._supervisor = None + self._builtin_supervisor = None + self._thirdparty_supervisor = None @property def is_running(self) -> bool: return self._started @property - def supervisor(self) -> Any: - """获取底层 Supervisor(供高级用途)""" - return self._supervisor + def supervisors(self) -> list[Any]: + """获取所有活跃的 Supervisor""" + return [s for s in (self._builtin_supervisor, self._thirdparty_supervisor) if s is not None] # ─── 事件桥接 ────────────────────────────────────────────── @@ -122,124 +148,1273 @@ class PluginRuntimeManager: message_dict: dict[str, Any] | None = None, extra_args: dict[str, Any] | None = None, ) -> tuple[bool, dict[str, Any] | None]: - """将旧系统事件转发到新版本插件运行时 - - Args: - event_type_value: 旧 EventType 的 .value(如 "on_message") - message_dict: 序列化后的消息字典(MaiMessages 转 dict) - extra_args: 额外参数 + """将事件分发到所有 Supervisor Returns: (continue_flag, modified_message_dict) """ - if not self._started or self._supervisor is None: + if not self._started: return True, None new_event_type: str = _EVENT_TYPE_MAP.get(event_type_value, event_type_value) + modified: dict[str, Any] | None = None - try: - return await self._supervisor.dispatch_event( - event_type=new_event_type, - message=message_dict, - extra_args=extra_args, - ) - except Exception as e: - logger.error(f"桥接事件 {new_event_type} 到新运行时失败: {e}", exc_info=True) - return True, None + for sv in self.supervisors: + try: + cont, mod = await sv.dispatch_event( + event_type=new_event_type, + message=message_dict, + extra_args=extra_args, + ) + if mod is not None: + modified = mod + if not cont: + return False, modified + except Exception as e: + logger.error(f"事件 {new_event_type} 分发失败: {e}", exc_info=True) - # ─── 命令桥接 ────────────────────────────────────────────── + return True, modified + + # ─── 命令查找 ────────────────────────────────────────────── def find_command_by_text(self, text: str) -> dict[str, Any] | None: - """在新版本插件运行时的 ComponentRegistry 中查找命令 - - Returns: - 匹配结果字典 {"component": RegisteredComponent, "match": re.Match} - 或 None - """ - if not self._started or self._supervisor is None: + """在所有 Supervisor 的 ComponentRegistry 中查找命令""" + if not self._started: return None - return self._supervisor.component_registry.find_command_by_text(text) + for sv in self.supervisors: + result = sv.component_registry.find_command_by_text(text) + if result is not None: + return result + return None # ─── 能力实现注册 ────────────────────────────────────────── - def _register_capability_impls(self) -> None: - """注册主程序提供的能力实现 + def _register_capability_impls(self, supervisor: Any) -> None: + """向指定 Supervisor 注册主程序提供的能力实现""" + cap_service = supervisor.capability_service - 新版本插件通过 cap.request RPC 请求能力调用, - Host 端的 CapabilityService 需要真正的能力实现来处理这些请求。 - 这里注册主程序中可用的功能接口。 - """ - if self._supervisor is None: - return - - cap_service = self._supervisor.capability_service - - # 注册 send.* 能力 + # ── send.* ───────────────────────────────────────── cap_service.register_capability("send.text", self._cap_send_text) cap_service.register_capability("send.emoji", self._cap_send_emoji) cap_service.register_capability("send.image", self._cap_send_image) + cap_service.register_capability("send.command", self._cap_send_command) + cap_service.register_capability("send.custom", self._cap_send_custom) - # 注册 llm.* 能力 + # ── llm.* ───────────────────────────────────────── cap_service.register_capability("llm.generate", self._cap_llm_generate) + cap_service.register_capability("llm.generate_with_tools", self._cap_llm_generate_with_tools) + cap_service.register_capability("llm.get_available_models", self._cap_llm_get_available_models) - # 注册 config.* 能力 + # ── config.* ────────────────────────────────────── cap_service.register_capability("config.get", self._cap_config_get) + cap_service.register_capability("config.get_plugin", self._cap_config_get_plugin) - logger.debug("已注册主程序能力实现") + # ── database.* ──────────────────────────────────── + cap_service.register_capability("database.query", self._cap_database_query) + cap_service.register_capability("database.save", self._cap_database_save) + cap_service.register_capability("database.get", self._cap_database_get) - # ─── 能力实现 ────────────────────────────────────────────── + # ── chat.* ──────────────────────────────────────── + cap_service.register_capability("chat.get_all_streams", self._cap_chat_get_all_streams) + cap_service.register_capability("chat.get_group_streams", self._cap_chat_get_group_streams) + cap_service.register_capability("chat.get_private_streams", self._cap_chat_get_private_streams) + cap_service.register_capability("chat.get_stream_by_group_id", self._cap_chat_get_stream_by_group_id) + cap_service.register_capability("chat.get_stream_by_user_id", self._cap_chat_get_stream_by_user_id) + + # ── message.* ───────────────────────────────────── + cap_service.register_capability("message.get_by_time", self._cap_message_get_by_time) + cap_service.register_capability("message.get_by_time_in_chat", self._cap_message_get_by_time_in_chat) + cap_service.register_capability("message.get_recent", self._cap_message_get_recent) + cap_service.register_capability("message.count_new", self._cap_message_count_new) + cap_service.register_capability("message.build_readable", self._cap_message_build_readable) + + # ── person.* ────────────────────────────────────── + cap_service.register_capability("person.get_id", self._cap_person_get_id) + cap_service.register_capability("person.get_value", self._cap_person_get_value) + cap_service.register_capability("person.get_id_by_name", self._cap_person_get_id_by_name) + + # ── emoji.* ─────────────────────────────────────── + cap_service.register_capability("emoji.get_by_description", self._cap_emoji_get_by_description) + cap_service.register_capability("emoji.get_random", self._cap_emoji_get_random) + cap_service.register_capability("emoji.get_count", self._cap_emoji_get_count) + cap_service.register_capability("emoji.get_emotions", self._cap_emoji_get_emotions) + cap_service.register_capability("emoji.get_all", self._cap_emoji_get_all) + cap_service.register_capability("emoji.get_info", self._cap_emoji_get_info) + cap_service.register_capability("emoji.register", self._cap_emoji_register) + cap_service.register_capability("emoji.delete", self._cap_emoji_delete) + + # ── frequency.* ─────────────────────────────────── + cap_service.register_capability("frequency.get_current_talk_value", self._cap_frequency_get_current_talk_value) + cap_service.register_capability("frequency.set_adjust", self._cap_frequency_set_adjust) + cap_service.register_capability("frequency.get_adjust", self._cap_frequency_get_adjust) + + # ── tool.* ──────────────────────────────────────── + cap_service.register_capability("tool.get_definitions", self._cap_tool_get_definitions) + + # ── component.* ─────────────────────────────────── + cap_service.register_capability("component.get_all_plugins", self._cap_component_get_all_plugins) + cap_service.register_capability("component.get_plugin_info", self._cap_component_get_plugin_info) + cap_service.register_capability("component.list_loaded_plugins", self._cap_component_list_loaded_plugins) + cap_service.register_capability("component.list_registered_plugins", self._cap_component_list_registered_plugins) + cap_service.register_capability("component.enable", self._cap_component_enable) + cap_service.register_capability("component.disable", self._cap_component_disable) + cap_service.register_capability("component.load_plugin", self._cap_component_load_plugin) + cap_service.register_capability("component.unload_plugin", self._cap_component_unload_plugin) + cap_service.register_capability("component.reload_plugin", self._cap_component_reload_plugin) + + # ── knowledge.* ─────────────────────────────────── + cap_service.register_capability("knowledge.search", self._cap_knowledge_search) + + # ── logging.* ───────────────────────────────────── + cap_service.register_capability("logging.log", self._cap_logging_log) + + logger.debug("已注册全部主程序能力实现") + + # ═════════════════════════════════════════════════════════ + # send.* 能力实现 + # ═════════════════════════════════════════════════════════ @staticmethod async def _cap_send_text(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: - """发送文本消息能力实现 + """发送文本消息 - 注意: chat_stream 模块已被移除,send.text 能力暂不可用, - 待新的消息发送接口稳定后再接入。 + args: text, stream_id, typing?, set_reply?, storage_message? """ - return {"success": False, "error": "send.text 尚未接入(chat_stream 已移除)"} + from src.services import send_service as send_api + + text: str = args.get("text", "") + stream_id: str = args.get("stream_id", "") + if not text or not stream_id: + return {"success": False, "error": "缺少必要参数 text 或 stream_id"} + + try: + result = await send_api.text_to_stream( + text=text, + stream_id=stream_id, + typing=args.get("typing", False), + set_reply=args.get("set_reply", False), + storage_message=args.get("storage_message", True), + ) + return {"success": result} + except Exception as e: + logger.error(f"[cap.send.text] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} @staticmethod async def _cap_send_emoji(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: - """发送表情能力实现""" - return {"success": False, "error": "send.emoji 尚未实现"} + """发送表情 + + args: emoji_base64, stream_id, storage_message? + """ + from src.services import send_service as send_api + + emoji_base64: str = args.get("emoji_base64", "") + stream_id: str = args.get("stream_id", "") + if not emoji_base64 or not stream_id: + return {"success": False, "error": "缺少必要参数 emoji_base64 或 stream_id"} + + try: + result = await send_api.emoji_to_stream( + emoji_base64=emoji_base64, + stream_id=stream_id, + storage_message=args.get("storage_message", True), + ) + return {"success": result} + except Exception as e: + logger.error(f"[cap.send.emoji] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} @staticmethod async def _cap_send_image(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: - """发送图片能力实现""" - return {"success": False, "error": "send.image 尚未实现"} + """发送图片 + + args: image_base64, stream_id, storage_message? + """ + from src.services import send_service as send_api + + image_base64: str = args.get("image_base64", "") + stream_id: str = args.get("stream_id", "") + if not image_base64 or not stream_id: + return {"success": False, "error": "缺少必要参数 image_base64 或 stream_id"} + + try: + result = await send_api.image_to_stream( + image_base64=image_base64, + stream_id=stream_id, + storage_message=args.get("storage_message", True), + ) + return {"success": result} + except Exception as e: + logger.error(f"[cap.send.image] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + @staticmethod + async def _cap_send_command(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """发送命令 + + args: command, stream_id, storage_message?, display_message? + """ + from src.services import send_service as send_api + + command = args.get("command", "") + stream_id: str = args.get("stream_id", "") + if not command or not stream_id: + return {"success": False, "error": "缺少必要参数 command 或 stream_id"} + + try: + result = await send_api.command_to_stream( + command=command, + stream_id=stream_id, + storage_message=args.get("storage_message", True), + display_message=args.get("display_message", ""), + ) + return {"success": result} + except Exception as e: + logger.error(f"[cap.send.command] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + @staticmethod + async def _cap_send_custom(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """发送自定义类型消息 + + args: message_type, content, stream_id, display_message?, typing?, storage_message? + """ + from src.services import send_service as send_api + + message_type: str = args.get("message_type", "") + content = args.get("content", "") + stream_id: str = args.get("stream_id", "") + if not message_type or not stream_id: + return {"success": False, "error": "缺少必要参数 message_type 或 stream_id"} + + try: + result = await send_api.custom_to_stream( + message_type=message_type, + content=content, + stream_id=stream_id, + display_message=args.get("display_message", ""), + typing=args.get("typing", False), + storage_message=args.get("storage_message", True), + ) + return {"success": result} + except Exception as e: + logger.error(f"[cap.send.custom] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + # ═════════════════════════════════════════════════════════ + # llm.* 能力实现 + # ═════════════════════════════════════════════════════════ @staticmethod async def _cap_llm_generate(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: - """LLM 生成能力实现""" - return {"success": False, "error": "llm.generate 尚未完全接入,请使用旧系统的 LLM API"} + """LLM 生成 + + args: prompt, model_name?, temperature?, max_tokens? + """ + from src.services import llm_service as llm_api + + prompt: str = args.get("prompt", "") + if not prompt: + return {"success": False, "error": "缺少必要参数 prompt"} + + model_name: str = args.get("model_name", "") + temperature = args.get("temperature") + max_tokens = args.get("max_tokens") + + try: + models = llm_api.get_available_models() + if model_name and model_name in models: + model_config = models[model_name] + else: + # 选取第一个可用模型配置 + if not models: + return {"success": False, "error": "没有可用的模型配置"} + model_config = next(iter(models.values())) + + success, response, reasoning, used_model = await llm_api.generate_with_model( + prompt=prompt, + model_config=model_config, + request_type=f"plugin.{plugin_id}", + temperature=temperature, + max_tokens=max_tokens, + ) + return { + "success": success, + "response": response, + "reasoning": reasoning, + "model_name": used_model, + } + except Exception as e: + logger.error(f"[cap.llm.generate] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + @staticmethod + async def _cap_llm_generate_with_tools(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """LLM 带工具生成 + + args: prompt, model_name?, tool_options?, temperature?, max_tokens? + """ + from src.services import llm_service as llm_api + + prompt: str = args.get("prompt", "") + if not prompt: + return {"success": False, "error": "缺少必要参数 prompt"} + + model_name: str = args.get("model_name", "") + tool_options = args.get("tool_options") + temperature = args.get("temperature") + max_tokens = args.get("max_tokens") + + try: + models = llm_api.get_available_models() + if model_name and model_name in models: + model_config = models[model_name] + else: + if not models: + return {"success": False, "error": "没有可用的模型配置"} + model_config = next(iter(models.values())) + + success, response, reasoning, used_model, tool_calls = await llm_api.generate_with_model_with_tools( + prompt=prompt, + model_config=model_config, + tool_options=tool_options, + request_type=f"plugin.{plugin_id}", + temperature=temperature, + max_tokens=max_tokens, + ) + # 将 ToolCall 对象序列化为 dict + serialized_tool_calls = None + if tool_calls: + serialized_tool_calls = [ + {"id": tc.id, "function": {"name": tc.function.name, "arguments": tc.function.arguments}} + for tc in tool_calls + if hasattr(tc, "function") + ] + return { + "success": success, + "response": response, + "reasoning": reasoning, + "model_name": used_model, + "tool_calls": serialized_tool_calls, + } + except Exception as e: + logger.error(f"[cap.llm.generate_with_tools] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + @staticmethod + async def _cap_llm_get_available_models(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """获取可用模型列表""" + from src.services import llm_service as llm_api + + try: + models = llm_api.get_available_models() + return {"success": True, "models": list(models.keys())} + except Exception as e: + logger.error(f"[cap.llm.get_available_models] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + # ═════════════════════════════════════════════════════════ + # config.* 能力实现 + # ═════════════════════════════════════════════════════════ @staticmethod async def _cap_config_get(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: - """配置读取能力实现""" - from src.plugin_system.core import component_registry as old_registry + """读取全局配置 + + args: key, default? + """ + from src.services import config_service as config_api + + key: str = args.get("key", "") + default = args.get("default") + + if not key: + return {"success": False, "value": None, "error": "缺少必要参数 key"} + + try: + value = config_api.get_global_config(key, default) + return {"success": True, "value": value} + except Exception as e: + return {"success": False, "value": None, "error": str(e)} + + @staticmethod + async def _cap_config_get_plugin(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """读取插件配置 + + args: key, default?, plugin_name? + """ + from src.core.component_registry import component_registry as core_registry plugin_name: str = args.get("plugin_name", plugin_id) key: str = args.get("key", "") + default = args.get("default") try: - config = old_registry.get_plugin_config(plugin_name) + config = core_registry.get_plugin_config(plugin_name) if config is None: - return {"success": False, "value": None, "error": f"未找到插件 {plugin_name} 的配置"} + return {"success": False, "value": default, "error": f"未找到插件 {plugin_name} 的配置"} if key: - parts = key.split(".") - value: Any = config - for part in parts: - if isinstance(value, dict): - value = value.get(part) - else: - return {"success": False, "value": None, "error": f"配置路径无效: {key}"} + from src.services import config_service as config_api + + value = config_api.get_plugin_config(config, key, default) return {"success": True, "value": value} return {"success": True, "value": config} except Exception as e: - return {"success": False, "value": None, "error": str(e)} + return {"success": False, "value": default, "error": str(e)} + + # ═════════════════════════════════════════════════════════ + # database.* 能力实现 + # ═════════════════════════════════════════════════════════ + + @staticmethod + async def _cap_database_query(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """数据库查询 + + args: model_name, query_type?, filters?, limit?, order_by?, data?, single_result? + model_name 应为 src.common.database.database_model 中的类名。 + """ + from src.services import database_service as database_api + + model_name: str = args.get("model_name", "") + if not model_name: + return {"success": False, "error": "缺少必要参数 model_name"} + + try: + import src.common.database.database_model as db_models + + model_class = getattr(db_models, model_name, None) + if model_class is None: + return {"success": False, "error": f"未找到数据模型: {model_name}"} + + result = await database_api.db_query( + model_class=model_class, + data=args.get("data"), + query_type=args.get("query_type", "get"), + filters=args.get("filters"), + limit=args.get("limit"), + order_by=args.get("order_by"), + single_result=args.get("single_result", False), + ) + return {"success": True, "result": result} + except Exception as e: + logger.error(f"[cap.database.query] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + @staticmethod + async def _cap_database_save(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """数据库保存 + + args: model_name, data, key_field?, key_value? + """ + from src.services import database_service as database_api + + model_name: str = args.get("model_name", "") + data: dict[str, Any] | None = args.get("data") + if not model_name or not data: + return {"success": False, "error": "缺少必要参数 model_name 或 data"} + + try: + import src.common.database.database_model as db_models + + model_class = getattr(db_models, model_name, None) + if model_class is None: + return {"success": False, "error": f"未找到数据模型: {model_name}"} + + result = await database_api.db_save( + model_class=model_class, + data=data, + key_field=args.get("key_field"), + key_value=args.get("key_value"), + ) + return {"success": True, "result": result} + except Exception as e: + logger.error(f"[cap.database.save] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + @staticmethod + async def _cap_database_get(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """数据库简单查询 + + args: model_name, filters?, limit?, order_by?, single_result? + """ + from src.services import database_service as database_api + + model_name: str = args.get("model_name", "") + if not model_name: + return {"success": False, "error": "缺少必要参数 model_name"} + + try: + import src.common.database.database_model as db_models + + model_class = getattr(db_models, model_name, None) + if model_class is None: + return {"success": False, "error": f"未找到数据模型: {model_name}"} + + result = await database_api.db_get( + model_class=model_class, + filters=args.get("filters"), + limit=args.get("limit"), + order_by=args.get("order_by"), + single_result=args.get("single_result", False), + ) + return {"success": True, "result": result} + except Exception as e: + logger.error(f"[cap.database.get] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + # ═════════════════════════════════════════════════════════ + # chat.* 能力实现 + # ═════════════════════════════════════════════════════════ + + @staticmethod + def _serialize_stream(stream: Any) -> dict[str, Any]: + """将 BotChatSession 序列化为可通过 RPC 传输的字典""" + return { + "session_id": getattr(stream, "session_id", ""), + "platform": getattr(stream, "platform", ""), + "user_id": getattr(stream, "user_id", ""), + "group_id": getattr(stream, "group_id", ""), + "is_group_session": getattr(stream, "is_group_session", False), + } + + @staticmethod + async def _cap_chat_get_all_streams(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """获取所有聊天流 + + args: platform? + """ + from src.services import chat_service as chat_api + + platform: str = args.get("platform", "qq") + try: + streams = chat_api.ChatManager.get_all_streams(platform=platform) + return { + "success": True, + "streams": [PluginRuntimeManager._serialize_stream(s) for s in streams], + } + except Exception as e: + logger.error(f"[cap.chat.get_all_streams] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + @staticmethod + async def _cap_chat_get_group_streams(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """获取所有群聊流 + + args: platform? + """ + from src.services import chat_service as chat_api + + platform: str = args.get("platform", "qq") + try: + streams = chat_api.ChatManager.get_group_streams(platform=platform) + return { + "success": True, + "streams": [PluginRuntimeManager._serialize_stream(s) for s in streams], + } + except Exception as e: + logger.error(f"[cap.chat.get_group_streams] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + @staticmethod + async def _cap_chat_get_private_streams(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """获取所有私聊流 + + args: platform? + """ + from src.services import chat_service as chat_api + + platform: str = args.get("platform", "qq") + try: + streams = chat_api.ChatManager.get_private_streams(platform=platform) + return { + "success": True, + "streams": [PluginRuntimeManager._serialize_stream(s) for s in streams], + } + except Exception as e: + logger.error(f"[cap.chat.get_private_streams] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + @staticmethod + async def _cap_chat_get_stream_by_group_id(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """按群 ID 查找聊天流 + + args: group_id, platform? + """ + from src.services import chat_service as chat_api + + group_id: str = args.get("group_id", "") + if not group_id: + return {"success": False, "error": "缺少必要参数 group_id"} + + platform: str = args.get("platform", "qq") + try: + stream = chat_api.ChatManager.get_group_stream_by_group_id(group_id=group_id, platform=platform) + if stream is None: + return {"success": True, "stream": None} + return {"success": True, "stream": PluginRuntimeManager._serialize_stream(stream)} + except Exception as e: + logger.error(f"[cap.chat.get_stream_by_group_id] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + @staticmethod + async def _cap_chat_get_stream_by_user_id(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """按用户 ID 查找私聊流 + + args: user_id, platform? + """ + from src.services import chat_service as chat_api + + user_id: str = args.get("user_id", "") + if not user_id: + return {"success": False, "error": "缺少必要参数 user_id"} + + platform: str = args.get("platform", "qq") + try: + stream = chat_api.ChatManager.get_private_stream_by_user_id(user_id=user_id, platform=platform) + if stream is None: + return {"success": True, "stream": None} + return {"success": True, "stream": PluginRuntimeManager._serialize_stream(stream)} + except Exception as e: + logger.error(f"[cap.chat.get_stream_by_user_id] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + # ═════════════════════════════════════════════════════════ + # message.* 能力实现 + # ═════════════════════════════════════════════════════════ + + @staticmethod + def _serialize_messages(messages: list) -> list[dict[str, Any]]: + """将 DatabaseMessages 列表序列化为 dict 列表""" + result: list[dict[str, Any]] = [] + for msg in messages: + if hasattr(msg, "model_dump"): + result.append(msg.model_dump()) + elif hasattr(msg, "__dict__"): + result.append(dict(msg.__dict__)) + else: + result.append(str(msg)) + return result + + @staticmethod + async def _cap_message_get_by_time(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """按时间范围查询消息 + + args: start_time, end_time, limit?, filter_mai? + """ + from src.services import message_service as message_api + + start_time = args.get("start_time", 0.0) + end_time = args.get("end_time", 0.0) + + try: + messages = message_api.get_messages_by_time( + start_time=float(start_time), + end_time=float(end_time), + limit=args.get("limit", 0), + limit_mode=args.get("limit_mode", "latest"), + filter_mai=args.get("filter_mai", False), + ) + return {"success": True, "messages": PluginRuntimeManager._serialize_messages(messages)} + except Exception as e: + logger.error(f"[cap.message.get_by_time] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + @staticmethod + async def _cap_message_get_by_time_in_chat(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """按时间范围查询指定聊天消息 + + args: chat_id, start_time, end_time, limit?, filter_mai?, filter_command? + """ + from src.services import message_service as message_api + + chat_id: str = args.get("chat_id", "") + if not chat_id: + return {"success": False, "error": "缺少必要参数 chat_id"} + + try: + messages = message_api.get_messages_by_time_in_chat( + chat_id=chat_id, + start_time=float(args.get("start_time", 0.0)), + end_time=float(args.get("end_time", 0.0)), + limit=args.get("limit", 0), + limit_mode=args.get("limit_mode", "latest"), + filter_mai=args.get("filter_mai", False), + filter_command=args.get("filter_command", False), + ) + return {"success": True, "messages": PluginRuntimeManager._serialize_messages(messages)} + except Exception as e: + logger.error(f"[cap.message.get_by_time_in_chat] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + @staticmethod + async def _cap_message_get_recent(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """获取最近的消息 + + args: chat_id, hours?, limit?, filter_mai? + """ + from src.services import message_service as message_api + + chat_id: str = args.get("chat_id", "") + if not chat_id: + return {"success": False, "error": "缺少必要参数 chat_id"} + + try: + messages = message_api.get_recent_messages( + chat_id=chat_id, + hours=float(args.get("hours", 24.0)), + limit=args.get("limit", 100), + limit_mode=args.get("limit_mode", "latest"), + filter_mai=args.get("filter_mai", False), + ) + return {"success": True, "messages": PluginRuntimeManager._serialize_messages(messages)} + except Exception as e: + logger.error(f"[cap.message.get_recent] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + @staticmethod + async def _cap_message_count_new(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """统计新消息数量 + + args: chat_id, start_time?, end_time? + """ + from src.services import message_service as message_api + + chat_id: str = args.get("chat_id", "") + if not chat_id: + return {"success": False, "error": "缺少必要参数 chat_id"} + + try: + count = message_api.count_new_messages( + chat_id=chat_id, + start_time=float(args.get("start_time", 0.0)), + end_time=args.get("end_time"), + ) + return {"success": True, "count": count} + except Exception as e: + logger.error(f"[cap.message.count_new] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + @staticmethod + async def _cap_message_build_readable(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """将消息列表构建成可读字符串 + + args: chat_id, start_time, end_time, limit?, replace_bot_name?, timestamp_mode? + """ + from src.services import message_service as message_api + + chat_id: str = args.get("chat_id", "") + if not chat_id: + return {"success": False, "error": "缺少必要参数 chat_id"} + + try: + messages = message_api.get_messages_by_time_in_chat( + chat_id=chat_id, + start_time=float(args.get("start_time", 0.0)), + end_time=float(args.get("end_time", 0.0)), + limit=args.get("limit", 0), + ) + readable = message_api.build_readable_messages_to_str( + messages=messages, + replace_bot_name=args.get("replace_bot_name", True), + timestamp_mode=args.get("timestamp_mode", "relative"), + truncate=args.get("truncate", False), + ) + return {"success": True, "text": readable} + except Exception as e: + logger.error(f"[cap.message.build_readable] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + # ═════════════════════════════════════════════════════════ + # person.* 能力实现 + # ═════════════════════════════════════════════════════════ + + @staticmethod + async def _cap_person_get_id(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """获取 person_id + + args: platform, user_id + """ + from src.services import person_service as person_api + + platform: str = args.get("platform", "") + user_id = args.get("user_id", "") + if not platform or not user_id: + return {"success": False, "error": "缺少必要参数 platform 或 user_id"} + + try: + pid = person_api.get_person_id(platform=platform, user_id=user_id) + return {"success": True, "person_id": pid} + except Exception as e: + logger.error(f"[cap.person.get_id] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + @staticmethod + async def _cap_person_get_value(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """获取用户字段值 + + args: person_id, field_name, default? + """ + from src.services import person_service as person_api + + person_id: str = args.get("person_id", "") + field_name: str = args.get("field_name", "") + if not person_id or not field_name: + return {"success": False, "error": "缺少必要参数 person_id 或 field_name"} + + try: + value = await person_api.get_person_value( + person_id=person_id, + field_name=field_name, + default=args.get("default"), + ) + return {"success": True, "value": value} + except Exception as e: + logger.error(f"[cap.person.get_value] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + @staticmethod + async def _cap_person_get_id_by_name(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """根据用户名获取 person_id + + args: person_name + """ + from src.services import person_service as person_api + + person_name: str = args.get("person_name", "") + if not person_name: + return {"success": False, "error": "缺少必要参数 person_name"} + + try: + pid = person_api.get_person_id_by_name(person_name=person_name) + return {"success": True, "person_id": pid} + except Exception as e: + logger.error(f"[cap.person.get_id_by_name] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + # ═════════════════════════════════════════════════════════ + # emoji.* 能力实现 + # ═════════════════════════════════════════════════════════ + + @staticmethod + async def _cap_emoji_get_by_description(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """根据描述获取表情包 + + args: description + """ + from src.services import emoji_service as emoji_api + + description: str = args.get("description", "") + if not description: + return {"success": False, "error": "缺少必要参数 description"} + + try: + result = await emoji_api.get_by_description(description=description) + if result is None: + return {"success": True, "emoji": None} + emoji_base64, emoji_desc, matched_emotion = result + return { + "success": True, + "emoji": { + "base64": emoji_base64, + "description": emoji_desc, + "emotion": matched_emotion, + }, + } + except Exception as e: + logger.error(f"[cap.emoji.get_by_description] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + @staticmethod + async def _cap_emoji_get_random(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """随机获取表情包 + + args: count? + """ + from src.services import emoji_service as emoji_api + + count: int = args.get("count", 1) + try: + results = await emoji_api.get_random(count=count) + emojis = [ + {"base64": b64, "description": desc, "emotion": emo} for b64, desc, emo in results + ] + return {"success": True, "emojis": emojis} + except Exception as e: + logger.error(f"[cap.emoji.get_random] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + @staticmethod + async def _cap_emoji_get_count(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """获取表情包数量""" + from src.services import emoji_service as emoji_api + + try: + return {"success": True, "count": emoji_api.get_count()} + except Exception as e: + logger.error(f"[cap.emoji.get_count] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + @staticmethod + async def _cap_emoji_get_emotions(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """获取所有情绪标签""" + from src.services import emoji_service as emoji_api + + try: + return {"success": True, "emotions": emoji_api.get_emotions()} + except Exception as e: + logger.error(f"[cap.emoji.get_emotions] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + @staticmethod + async def _cap_emoji_get_all(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """获取所有表情包""" + from src.services import emoji_service as emoji_api + + try: + results = await emoji_api.get_all() + emojis = [ + {"base64": b64, "description": desc, "emotion": emo} for b64, desc, emo in results + ] if results else [] + return {"success": True, "emojis": emojis} + except Exception as e: + logger.error(f"[cap.emoji.get_all] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + @staticmethod + async def _cap_emoji_get_info(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """获取表情包统计信息""" + from src.services import emoji_service as emoji_api + + try: + return {"success": True, "info": emoji_api.get_info()} + except Exception as e: + logger.error(f"[cap.emoji.get_info] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + @staticmethod + async def _cap_emoji_register(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """注册表情包 + + args: emoji_base64 + """ + from src.services import emoji_service as emoji_api + + emoji_base64: str = args.get("emoji_base64", "") + if not emoji_base64: + return {"success": False, "error": "缺少必要参数 emoji_base64"} + + try: + result = await emoji_api.register_emoji(emoji_base64) + return result + except Exception as e: + logger.error(f"[cap.emoji.register] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + @staticmethod + async def _cap_emoji_delete(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """删除表情包 + + args: emoji_hash + """ + from src.services import emoji_service as emoji_api + + emoji_hash: str = args.get("emoji_hash", "") + if not emoji_hash: + return {"success": False, "error": "缺少必要参数 emoji_hash"} + + try: + result = await emoji_api.delete_emoji(emoji_hash) + return result + except Exception as e: + logger.error(f"[cap.emoji.delete] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + # ═════════════════════════════════════════════════════════ + # frequency.* 能力实现 + # ═════════════════════════════════════════════════════════ + + @staticmethod + async def _cap_frequency_get_current_talk_value(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """获取当前说话频率值 + + args: chat_id + """ + from src.services import frequency_service as frequency_api + + chat_id: str = args.get("chat_id", "") + if not chat_id: + return {"success": False, "error": "缺少必要参数 chat_id"} + + try: + value = frequency_api.get_current_talk_value(chat_id) + return {"success": True, "value": value} + except Exception as e: + logger.error(f"[cap.frequency.get_current_talk_value] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + @staticmethod + async def _cap_frequency_set_adjust(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """设置说话频率调整值 + + args: chat_id, value + """ + from src.services import frequency_service as frequency_api + + chat_id: str = args.get("chat_id", "") + value = args.get("value") + if not chat_id or value is None: + return {"success": False, "error": "缺少必要参数 chat_id 或 value"} + + try: + frequency_api.set_talk_frequency_adjust(chat_id, float(value)) + return {"success": True} + except Exception as e: + logger.error(f"[cap.frequency.set_adjust] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + @staticmethod + async def _cap_frequency_get_adjust(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """获取说话频率调整值 + + args: chat_id + """ + from src.services import frequency_service as frequency_api + + chat_id: str = args.get("chat_id", "") + if not chat_id: + return {"success": False, "error": "缺少必要参数 chat_id"} + + try: + value = frequency_api.get_talk_frequency_adjust(chat_id) + return {"success": True, "value": value} + except Exception as e: + logger.error(f"[cap.frequency.get_adjust] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + # ═════════════════════════════════════════════════════════ + # tool.* 能力实现 + # ═════════════════════════════════════════════════════════ + + @staticmethod + async def _cap_tool_get_definitions(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """获取 LLM 可用的工具定义列表""" + from src.core.component_registry import component_registry as core_registry + + try: + tools = core_registry.get_llm_available_tools() + return { + "success": True, + "tools": [{"name": name, "definition": info.get_llm_definition()} for name, info in tools.items()], + } + except Exception as e: + logger.error(f"[cap.tool.get_definitions] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + # ═════════════════════════════════════════════════════════ + # component.* 能力实现 + # ═════════════════════════════════════════════════════════ + + @staticmethod + async def _cap_component_get_all_plugins(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """获取所有插件信息(汇总所有 Supervisor 的注册信息)""" + mgr = get_plugin_runtime_manager() + result: dict[str, Any] = {} + for sv in mgr.supervisors: + for pid, reg in sv._registered_plugins.items(): + result[pid] = { + "name": pid, + "version": reg.plugin_version, + "description": "", + "author": "", + "enabled": True, + } + return {"success": True, "plugins": result} + + @staticmethod + async def _cap_component_get_plugin_info(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """获取指定插件信息 + + args: plugin_name + """ + plugin_name: str = args.get("plugin_name", plugin_id) + mgr = get_plugin_runtime_manager() + for sv in mgr.supervisors: + reg = sv._registered_plugins.get(plugin_name) + if reg is not None: + return { + "success": True, + "plugin": { + "name": plugin_name, + "version": reg.plugin_version, + "description": "", + "author": "", + "enabled": True, + }, + } + return {"success": False, "error": f"未找到插件: {plugin_name}"} + + @staticmethod + async def _cap_component_list_loaded_plugins(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """列出已加载的插件""" + mgr = get_plugin_runtime_manager() + plugins: list[str] = [] + for sv in mgr.supervisors: + plugins.extend(sv._registered_plugins.keys()) + return {"success": True, "plugins": plugins} + + @staticmethod + async def _cap_component_list_registered_plugins(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """列出已注册的插件(同 list_loaded)""" + mgr = get_plugin_runtime_manager() + plugins: list[str] = [] + for sv in mgr.supervisors: + plugins.extend(sv._registered_plugins.keys()) + return {"success": True, "plugins": plugins} + + @staticmethod + async def _cap_component_enable(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """启用组件 + + args: name, component_type + """ + name: str = args.get("name", "") + component_type: str = args.get("component_type", "") + if not name or not component_type: + return {"success": False, "error": "缺少必要参数 name 或 component_type"} + + mgr = get_plugin_runtime_manager() + for sv in mgr.supervisors: + comp = sv.component_registry.get_component(name) + if comp is not None: + comp.enabled = True + return {"success": True} + return {"success": False, "error": f"未找到组件: {name}"} + + @staticmethod + async def _cap_component_disable(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """禁用组件 + + args: name, component_type + """ + name: str = args.get("name", "") + component_type: str = args.get("component_type", "") + if not name or not component_type: + return {"success": False, "error": "缺少必要参数 name 或 component_type"} + + mgr = get_plugin_runtime_manager() + for sv in mgr.supervisors: + comp = sv.component_registry.get_component(name) + if comp is not None: + comp.enabled = False + return {"success": True} + return {"success": False, "error": f"未找到组件: {name}"} + + @staticmethod + async def _cap_component_load_plugin(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """加载插件(在新运行时中通过热重载实现) + + args: plugin_name + """ + plugin_name: str = args.get("plugin_name", "") + if not plugin_name: + return {"success": False, "error": "缺少必要参数 plugin_name"} + + mgr = get_plugin_runtime_manager() + for sv in mgr.supervisors: + try: + await sv.reload_plugins(reason=f"load {plugin_name}") + return {"success": True, "count": 1} + except Exception as e: + logger.error(f"[cap.component.load_plugin] 热重载失败: {e}") + return {"success": False, "error": f"无法加载插件: {plugin_name}"} + + @staticmethod + async def _cap_component_unload_plugin(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """卸载插件(在新运行时中不支持单独卸载) + + args: plugin_name + """ + return {"success": False, "error": "新运行时不支持单独卸载插件,请使用 reload"} + + @staticmethod + async def _cap_component_reload_plugin(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """重新加载插件(触发对应 Supervisor 的热重载) + + args: plugin_name + """ + plugin_name: str = args.get("plugin_name", "") + if not plugin_name: + return {"success": False, "error": "缺少必要参数 plugin_name"} + + mgr = get_plugin_runtime_manager() + for sv in mgr.supervisors: + if plugin_name in sv._registered_plugins: + try: + await sv.reload_plugins(reason=f"reload {plugin_name}") + return {"success": True} + except Exception as e: + logger.error(f"[cap.component.reload_plugin] 热重载失败: {e}") + return {"success": False, "error": str(e)} + return {"success": False, "error": f"未找到插件: {plugin_name}"} + + # ═════════════════════════════════════════════════════════ + # knowledge.* 能力实现 + # ═════════════════════════════════════════════════════════ + + @staticmethod + async def _cap_knowledge_search(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """从 LPMM 知识库搜索知识 + + args: query, limit? + """ + query: str = args.get("query", "") + if not query: + return {"success": False, "error": "缺少必要参数 query"} + + limit = args.get("limit", 5) + try: + limit_value = max(1, int(limit)) + except (TypeError, ValueError): + limit_value = 5 + + try: + from src.chat.knowledge import qa_manager + + if qa_manager is None: + return {"success": True, "content": "LPMM知识库已禁用"} + + knowledge_info = await qa_manager.get_knowledge(query, limit=limit_value) + if knowledge_info: + content = f"你知道这些知识: {knowledge_info}" + else: + content = f"你不太了解有关{query}的知识" + return {"success": True, "content": content} + except Exception as e: + logger.error(f"[cap.knowledge.search] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + # ═════════════════════════════════════════════════════════ + # logging.* 能力实现 + # ═════════════════════════════════════════════════════════ + + @staticmethod + async def _cap_logging_log(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: + """插件日志记录 + + args: level?, message + """ + level: str = args.get("level", "info").lower() + message: str = args.get("message", "") + if not message: + return {"success": False, "error": "缺少必要参数 message"} + + plugin_logger = get_logger(f"plugin.{plugin_id}") + log_fn = getattr(plugin_logger, level, plugin_logger.info) + log_fn(message) + return {"success": True} # ─── 单例 ────────────────────────────────────────────────── diff --git a/src/plugin_runtime/transport/uds.py b/src/plugin_runtime/transport/uds.py index 461f3197..36129968 100644 --- a/src/plugin_runtime/transport/uds.py +++ b/src/plugin_runtime/transport/uds.py @@ -22,8 +22,9 @@ class UDSTransportServer(TransportServer): def __init__(self, socket_path: str | None = None): if socket_path is None: - # 默认放在临时目录 - socket_path = os.path.join(tempfile.gettempdir(), f"maibot-plugin-{os.getpid()}.sock") + # 默认放在临时目录,使用 uuid 确保同一进程多实例不碰撞 + import uuid + socket_path = os.path.join(tempfile.gettempdir(), f"maibot-plugin-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock") self._socket_path = socket_path self._server: asyncio.AbstractServer | None = None diff --git a/src/plugin_system/__init__.py b/src/plugin_system/__init__.py deleted file mode 100644 index 2c33726e..00000000 --- a/src/plugin_system/__init__.py +++ /dev/null @@ -1,156 +0,0 @@ -""" -MaiBot 插件系统 - -提供统一的插件开发和管理框架 -""" - -# 导出主要的公共接口 -from .base import ( - BasePlugin, - BaseAction, - BaseCommand, - BaseTool, - ConfigField, - ConfigSection, - ConfigLayout, - ConfigTab, - ComponentType, - ActionActivationType, - ChatMode, - ComponentInfo, - ActionInfo, - CommandInfo, - PluginInfo, - ToolInfo, - PythonDependency, - BaseEventHandler, - EventHandlerInfo, - EventType, - MaiMessages, - ToolParamType, - CustomEventHandlerResult, - PluginServiceInfo, - ReplyContentType, - ReplyContent, - ForwardNode, - ReplySetModel, - WorkflowContext, - WorkflowMessage, - WorkflowStage, - WorkflowStepInfo, - WorkflowStepResult, - WorkflowErrorCode, -) - -# 导入工具模块 -from .utils import ( - ManifestValidator, - # ManifestGenerator, - # validate_plugin_manifest, - # generate_plugin_manifest, -) - -from .apis import ( - chat_api, - tool_api, - component_manage_api, - config_api, - database_api, - emoji_api, - generator_api, - llm_api, - message_api, - person_api, - plugin_manage_api, - plugin_service_api, - workflow_api, - send_api, - register_plugin, - get_logger, -) - -from src.common.data_models.database_data_model import ( - DatabaseMessages, - DatabaseUserInfo, - DatabaseGroupInfo, - DatabaseChatInfo, -) -from src.common.data_models.info_data_model import TargetPersonInfo, ActionPlannerInfo -from src.common.data_models.llm_data_model import LLMGenerationDataModel - - -__version__ = "2.0.0" - -__all__ = [ - # API 模块 - "chat_api", - "tool_api", - "component_manage_api", - "config_api", - "database_api", - "emoji_api", - "generator_api", - "llm_api", - "message_api", - "person_api", - "plugin_manage_api", - "plugin_service_api", - "workflow_api", - "send_api", - "auto_talk_api", - "register_plugin", - "get_logger", - # 基础类 - "BasePlugin", - "BaseAction", - "BaseCommand", - "BaseTool", - "BaseEventHandler", - # 类型定义 - "ComponentType", - "ActionActivationType", - "ChatMode", - "ComponentInfo", - "ActionInfo", - "CommandInfo", - "PluginInfo", - "ToolInfo", - "PythonDependency", - "EventHandlerInfo", - "EventType", - "ToolParamType", - # 消息 - "ReplyContentType", - "ReplyContent", - "ForwardNode", - "ReplySetModel", - "MaiMessages", - "CustomEventHandlerResult", - "PluginServiceInfo", - "WorkflowContext", - "WorkflowMessage", - "WorkflowStage", - "WorkflowStepInfo", - "WorkflowStepResult", - "WorkflowErrorCode", - # 装饰器 - "register_plugin", - "ConfigField", - "ConfigSection", - "ConfigLayout", - "ConfigTab", - # 工具函数 - "ManifestValidator", - "get_logger", - # "ManifestGenerator", - # "validate_plugin_manifest", - # "generate_plugin_manifest", - # 数据模型 - "DatabaseMessages", - "DatabaseUserInfo", - "DatabaseGroupInfo", - "DatabaseChatInfo", - "TargetPersonInfo", - "ActionPlannerInfo", - "LLMGenerationDataModel", -] diff --git a/src/plugin_system/apis/__init__.py b/src/plugin_system/apis/__init__.py deleted file mode 100644 index 75e64af1..00000000 --- a/src/plugin_system/apis/__init__.py +++ /dev/null @@ -1,47 +0,0 @@ -""" -插件系统API模块 - -提供了插件开发所需的各种API -""" - -# 导入所有API模块 -from src.plugin_system.apis import ( - chat_api, - component_manage_api, - config_api, - database_api, - emoji_api, - generator_api, - llm_api, - message_api, - person_api, - plugin_manage_api, - plugin_service_api, - send_api, - tool_api, - frequency_api, - workflow_api, -) -from .logging_api import get_logger -from .plugin_register_api import register_plugin - -# 导出所有API模块,使它们可以通过 apis.xxx 方式访问 -__all__ = [ - "chat_api", - "component_manage_api", - "config_api", - "database_api", - "emoji_api", - "generator_api", - "llm_api", - "message_api", - "person_api", - "plugin_manage_api", - "plugin_service_api", - "send_api", - "get_logger", - "register_plugin", - "tool_api", - "frequency_api", - "workflow_api", -] diff --git a/src/plugin_system/apis/chat_api.py b/src/plugin_system/apis/chat_api.py deleted file mode 100644 index faab54b0..00000000 --- a/src/plugin_system/apis/chat_api.py +++ /dev/null @@ -1,323 +0,0 @@ -""" -聊天API模块 - -专门负责聊天信息的查询和管理,采用标准Python包设计模式 -使用方式: - from src.plugin_system.apis import chat_api - streams = chat_api.get_all_group_streams() - chat_type = chat_api.get_stream_type(stream) - -或者: - from src.plugin_system.apis.chat_api import ChatManager as chat - streams = chat.get_all_group_streams() -""" - -from typing import List, Dict, Any, Optional -from enum import Enum - -from src.common.logger import get_logger -from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager - -logger = get_logger("chat_api") - - -class SpecialTypes(Enum): - """特殊枚举类型""" - - ALL_PLATFORMS = "all_platforms" - - -class ChatManager: - """聊天管理器 - 专门负责聊天信息的查询和管理""" - - @staticmethod - def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]: - # sourcery skip: for-append-to-extend - """获取所有聊天流 - - Args: - platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流 - - Returns: - List[BotChatSession]: 聊天流列表 - - Raises: - TypeError: 如果 platform 不是字符串或 SpecialTypes 枚举类型 - """ - if not isinstance(platform, (str, SpecialTypes)): - raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举") - streams = [] - try: - for _, stream in _chat_manager.sessions.items(): - if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform: - streams.append(stream) - logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的聊天流") - except Exception as e: - logger.error(f"[ChatAPI] 获取聊天流失败: {e}") - return streams - - @staticmethod - def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]: - # sourcery skip: for-append-to-extend - """获取所有群聊聊天流 - - Args: - platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流 - - Returns: - List[BotChatSession]: 群聊聊天流列表 - """ - if not isinstance(platform, (str, SpecialTypes)): - raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举") - streams = [] - try: - for _, stream in _chat_manager.sessions.items(): - if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.is_group_session: - streams.append(stream) - logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的群聊流") - except Exception as e: - logger.error(f"[ChatAPI] 获取群聊流失败: {e}") - return streams - - @staticmethod - def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]: - # sourcery skip: for-append-to-extend - """获取所有私聊聊天流 - - Args: - platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流 - - Returns: - List[BotChatSession]: 私聊聊天流列表 - - Raises: - TypeError: 如果 platform 不是字符串或 SpecialTypes 枚举类型 - """ - if not isinstance(platform, (str, SpecialTypes)): - raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举") - streams = [] - try: - for _, stream in _chat_manager.sessions.items(): - if ( - platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform - ) and not stream.is_group_session: - streams.append(stream) - logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的私聊流") - except Exception as e: - logger.error(f"[ChatAPI] 获取私聊流失败: {e}") - return streams - - @staticmethod - def get_group_stream_by_group_id( - group_id: str, platform: Optional[str] | SpecialTypes = "qq" - ) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast - """根据群ID获取聊天流 - - Args: - group_id: 群聊ID - platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流 - - Returns: - Optional[BotChatSession]: 聊天流对象,如果未找到返回None - - Raises: - ValueError: 如果 group_id 为空字符串 - TypeError: 如果 group_id 不是字符串类型或 platform 不是字符串或 SpecialTypes - """ - if not isinstance(group_id, str): - raise TypeError("group_id 必须是字符串类型") - if not isinstance(platform, (str, SpecialTypes)): - raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举") - if not group_id: - raise ValueError("group_id 不能为空") - try: - for _, stream in _chat_manager.sessions.items(): - if ( - stream.is_group_session - and str(stream.group_id) == str(group_id) - and (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) - ): - logger.debug(f"[ChatAPI] 找到群ID {group_id} 的聊天流") - return stream - logger.warning(f"[ChatAPI] 未找到群ID {group_id} 的聊天流") - except Exception as e: - logger.error(f"[ChatAPI] 查找群聊流失败: {e}") - return None - - @staticmethod - def get_private_stream_by_user_id( - user_id: str, platform: Optional[str] | SpecialTypes = "qq" - ) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast - """根据用户ID获取私聊流 - - Args: - user_id: 用户ID - platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流 - - Returns: - Optional[BotChatSession]: 聊天流对象,如果未找到返回None - - Raises: - ValueError: 如果 user_id 为空字符串 - TypeError: 如果 user_id 不是字符串类型或 platform 不是字符串或 SpecialTypes - """ - if not isinstance(user_id, str): - raise TypeError("user_id 必须是字符串类型") - if not isinstance(platform, (str, SpecialTypes)): - raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举") - if not user_id: - raise ValueError("user_id 不能为空") - try: - for _, stream in _chat_manager.sessions.items(): - if ( - not stream.is_group_session - and str(stream.user_id) == str(user_id) - and (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) - ): - logger.debug(f"[ChatAPI] 找到用户ID {user_id} 的私聊流") - return stream - logger.warning(f"[ChatAPI] 未找到用户ID {user_id} 的私聊流") - except Exception as e: - logger.error(f"[ChatAPI] 查找私聊流失败: {e}") - return None - - @staticmethod - def get_stream_type(chat_stream: BotChatSession) -> str: - """获取聊天流类型 - - Args: - chat_stream: 聊天流对象 - - Returns: - str: 聊天类型 ("group", "private", "unknown") - - Raises: - TypeError: 如果 chat_stream 不是 BotChatSession 类型 - ValueError: 如果 chat_stream 为空 - """ - if not isinstance(chat_stream, BotChatSession): - raise TypeError("chat_stream 必须是 BotChatSession 类型") - if not chat_stream: - raise ValueError("chat_stream 不能为 None") - - return "group" if chat_stream.is_group_session else "private" - - @staticmethod - def get_stream_info(chat_stream: BotChatSession) -> Dict[str, Any]: - """获取聊天流详细信息 - - Args: - chat_stream: 聊天流对象 - - Returns: - Dict ({str: Any}): 聊天流信息字典 - - Raises: - TypeError: 如果 chat_stream 不是 BotChatSession 类型 - ValueError: 如果 chat_stream 为空 - """ - if not chat_stream: - raise ValueError("chat_stream 不能为 None") - if not isinstance(chat_stream, BotChatSession): - raise TypeError("chat_stream 必须是 BotChatSession 类型") - - try: - info: Dict[str, Any] = { - "session_id": chat_stream.session_id, - "platform": chat_stream.platform, - "type": ChatManager.get_stream_type(chat_stream), - } - - if chat_stream.is_group_session: - info["group_id"] = chat_stream.group_id - # Try to get group name from context - if chat_stream.context and chat_stream.context.message and chat_stream.context.message.message_info.group_info: - info["group_name"] = chat_stream.context.message.message_info.group_info.group_name or "未知群聊" - else: - info["group_name"] = "未知群聊" - else: - info["user_id"] = chat_stream.user_id - if chat_stream.context and chat_stream.context.message and chat_stream.context.message.message_info.user_info: - info["user_name"] = chat_stream.context.message.message_info.user_info.user_nickname - else: - info["user_name"] = "未知用户" - - return info - except Exception as e: - logger.error(f"[ChatAPI] 获取聊天流信息失败: {e}") - return {} - - @staticmethod - def get_streams_summary() -> Dict[str, int]: - """获取聊天流统计摘要 - - Returns: - Dict[str, int]: 包含各种统计信息的字典 - """ - try: - all_streams = ChatManager.get_all_streams(SpecialTypes.ALL_PLATFORMS) - group_streams = ChatManager.get_group_streams(SpecialTypes.ALL_PLATFORMS) - private_streams = ChatManager.get_private_streams(SpecialTypes.ALL_PLATFORMS) - - summary = { - "total_streams": len(all_streams), - "group_streams": len(group_streams), - "private_streams": len(private_streams), - "qq_streams": len([s for s in all_streams if s.platform == "qq"]), - } - - logger.debug(f"[ChatAPI] 聊天流统计: {summary}") - return summary - except Exception as e: - logger.error(f"[ChatAPI] 获取聊天流统计失败: {e}") - return { - "total_streams": 0, - "group_streams": 0, - "private_streams": 0, - "qq_streams": 0, - } - - -# ============================================================================= -# 模块级别的便捷函数 - 类似 requests.get(), requests.post() 的设计 -# ============================================================================= - - -def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]: - """获取所有聊天流的便捷函数""" - return ChatManager.get_all_streams(platform) - - -def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]: - """获取群聊聊天流的便捷函数""" - return ChatManager.get_group_streams(platform) - - -def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]: - """获取私聊聊天流的便捷函数""" - return ChatManager.get_private_streams(platform) - - -def get_stream_by_group_id(group_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[BotChatSession]: - """根据群ID获取聊天流的便捷函数""" - return ChatManager.get_group_stream_by_group_id(group_id, platform) - - -def get_stream_by_user_id(user_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[BotChatSession]: - """根据用户ID获取私聊流的便捷函数""" - return ChatManager.get_private_stream_by_user_id(user_id, platform) - - -def get_stream_type(chat_stream: BotChatSession) -> str: - """获取聊天流类型的便捷函数""" - return ChatManager.get_stream_type(chat_stream) - - -def get_stream_info(chat_stream: BotChatSession) -> Dict[str, Any]: - """获取聊天流信息的便捷函数""" - return ChatManager.get_stream_info(chat_stream) - - -def get_streams_summary() -> Dict[str, int]: - """获取聊天流统计摘要的便捷函数""" - return ChatManager.get_streams_summary() diff --git a/src/plugin_system/apis/component_manage_api.py b/src/plugin_system/apis/component_manage_api.py deleted file mode 100644 index 1ffa0833..00000000 --- a/src/plugin_system/apis/component_manage_api.py +++ /dev/null @@ -1,268 +0,0 @@ -from typing import Optional, Union, Dict -from src.plugin_system.base.component_types import ( - CommandInfo, - ActionInfo, - EventHandlerInfo, - PluginInfo, - ComponentType, - ToolInfo, -) - - -# === 插件信息查询 === -def get_all_plugin_info() -> Dict[str, PluginInfo]: - """ - 获取所有插件的信息。 - - Returns: - dict: 包含所有插件信息的字典,键为插件名称,值为 PluginInfo 对象。 - """ - from src.plugin_system.core.component_registry import component_registry - - return component_registry.get_all_plugins() - - -def get_plugin_info(plugin_name: str) -> Optional[PluginInfo]: - """ - 获取指定插件的信息。 - - Args: - plugin_name (str): 插件名称。 - - Returns: - PluginInfo: 插件信息对象,如果插件不存在则返回 None。 - """ - from src.plugin_system.core.component_registry import component_registry - - return component_registry.get_plugin_info(plugin_name) - - -# === 组件查询方法 === -def get_component_info( - component_name: str, component_type: ComponentType -) -> Optional[Union[CommandInfo, ActionInfo, EventHandlerInfo]]: - """ - 获取指定组件的信息。 - - Args: - component_name (str): 组件名称。 - component_type (ComponentType): 组件类型。 - Returns: - Union[CommandInfo, ActionInfo, EventHandlerInfo]: 组件信息对象,如果组件不存在则返回 None。 - """ - from src.plugin_system.core.component_registry import component_registry - - return component_registry.get_component_info(component_name, component_type) # type: ignore - - -def get_components_info_by_type( - component_type: ComponentType, -) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]: - """ - 获取指定类型的所有组件信息。 - - Args: - component_type (ComponentType): 组件类型。 - - Returns: - dict: 包含指定类型组件信息的字典,键为组件名称,值为对应的组件信息对象。 - """ - from src.plugin_system.core.component_registry import component_registry - - return component_registry.get_components_by_type(component_type) # type: ignore - - -def get_enabled_components_info_by_type( - component_type: ComponentType, -) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]: - """ - 获取指定类型的所有启用的组件信息。 - - Args: - component_type (ComponentType): 组件类型。 - - Returns: - dict: 包含指定类型启用组件信息的字典,键为组件名称,值为对应的组件信息对象。 - """ - from src.plugin_system.core.component_registry import component_registry - - return component_registry.get_enabled_components_by_type(component_type) # type: ignore - - -# === Action 查询方法 === -def get_registered_action_info(action_name: str) -> Optional[ActionInfo]: - """ - 获取指定 Action 的注册信息。 - - Args: - action_name (str): Action 名称。 - - Returns: - ActionInfo: Action 信息对象,如果 Action 不存在则返回 None。 - """ - from src.plugin_system.core.component_registry import component_registry - - return component_registry.get_registered_action_info(action_name) - - -def get_registered_command_info(command_name: str) -> Optional[CommandInfo]: - """ - 获取指定 Command 的注册信息。 - - Args: - command_name (str): Command 名称。 - - Returns: - CommandInfo: Command 信息对象,如果 Command 不存在则返回 None。 - """ - from src.plugin_system.core.component_registry import component_registry - - return component_registry.get_registered_command_info(command_name) - - -def get_registered_tool_info(tool_name: str) -> Optional[ToolInfo]: - """ - 获取指定 Tool 的注册信息。 - - Args: - tool_name (str): Tool 名称。 - - Returns: - ToolInfo: Tool 信息对象,如果 Tool 不存在则返回 None。 - """ - from src.plugin_system.core.component_registry import component_registry - - return component_registry.get_registered_tool_info(tool_name) - - -# === EventHandler 特定查询方法 === -def get_registered_event_handler_info( - event_handler_name: str, -) -> Optional[EventHandlerInfo]: - """ - 获取指定 EventHandler 的注册信息。 - - Args: - event_handler_name (str): EventHandler 名称。 - - Returns: - EventHandlerInfo: EventHandler 信息对象,如果 EventHandler 不存在则返回 None。 - """ - from src.plugin_system.core.component_registry import component_registry - - return component_registry.get_registered_event_handler_info(event_handler_name) - - -# === 组件管理方法 === -def globally_enable_component(component_name: str, component_type: ComponentType) -> bool: - """ - 全局启用指定组件。 - - Args: - component_name (str): 组件名称。 - component_type (ComponentType): 组件类型。 - - Returns: - bool: 启用成功返回 True,否则返回 False。 - """ - from src.plugin_system.core.component_registry import component_registry - - return component_registry.enable_component(component_name, component_type) - - -async def globally_disable_component(component_name: str, component_type: ComponentType) -> bool: - """ - 全局禁用指定组件。 - - **此函数是异步的,确保在异步环境中调用。** - - Args: - component_name (str): 组件名称。 - component_type (ComponentType): 组件类型。 - - Returns: - bool: 禁用成功返回 True,否则返回 False。 - """ - from src.plugin_system.core.component_registry import component_registry - - return await component_registry.disable_component(component_name, component_type) - - -def locally_enable_component(component_name: str, component_type: ComponentType, stream_id: str) -> bool: - """ - 局部启用指定组件。 - - Args: - component_name (str): 组件名称。 - component_type (ComponentType): 组件类型。 - stream_id (str): 消息流 ID。 - - Returns: - bool: 启用成功返回 True,否则返回 False。 - """ - from src.plugin_system.core.global_announcement_manager import global_announcement_manager - - match component_type: - case ComponentType.ACTION: - return global_announcement_manager.enable_specific_chat_action(stream_id, component_name) - case ComponentType.COMMAND: - return global_announcement_manager.enable_specific_chat_command(stream_id, component_name) - case ComponentType.TOOL: - return global_announcement_manager.enable_specific_chat_tool(stream_id, component_name) - case ComponentType.EVENT_HANDLER: - return global_announcement_manager.enable_specific_chat_event_handler(stream_id, component_name) - case _: - raise ValueError(f"未知 component type: {component_type}") - - -def locally_disable_component(component_name: str, component_type: ComponentType, stream_id: str) -> bool: - """ - 局部禁用指定组件。 - - Args: - component_name (str): 组件名称。 - component_type (ComponentType): 组件类型。 - stream_id (str): 消息流 ID。 - - Returns: - bool: 禁用成功返回 True,否则返回 False。 - """ - from src.plugin_system.core.global_announcement_manager import global_announcement_manager - - match component_type: - case ComponentType.ACTION: - return global_announcement_manager.disable_specific_chat_action(stream_id, component_name) - case ComponentType.COMMAND: - return global_announcement_manager.disable_specific_chat_command(stream_id, component_name) - case ComponentType.TOOL: - return global_announcement_manager.disable_specific_chat_tool(stream_id, component_name) - case ComponentType.EVENT_HANDLER: - return global_announcement_manager.disable_specific_chat_event_handler(stream_id, component_name) - case _: - raise ValueError(f"未知 component type: {component_type}") - - -def get_locally_disabled_components(stream_id: str, component_type: ComponentType) -> list[str]: - """ - 获取指定消息流中禁用的组件列表。 - - Args: - stream_id (str): 消息流 ID。 - component_type (ComponentType): 组件类型。 - - Returns: - list[str]: 禁用的组件名称列表。 - """ - from src.plugin_system.core.global_announcement_manager import global_announcement_manager - - match component_type: - case ComponentType.ACTION: - return global_announcement_manager.get_disabled_chat_actions(stream_id) - case ComponentType.COMMAND: - return global_announcement_manager.get_disabled_chat_commands(stream_id) - case ComponentType.TOOL: - return global_announcement_manager.get_disabled_chat_tools(stream_id) - case ComponentType.EVENT_HANDLER: - return global_announcement_manager.get_disabled_chat_event_handlers(stream_id) - case _: - raise ValueError(f"未知 component type: {component_type}") diff --git a/src/plugin_system/apis/constants.py b/src/plugin_system/apis/constants.py deleted file mode 100644 index 88d74dca..00000000 --- a/src/plugin_system/apis/constants.py +++ /dev/null @@ -1,22 +0,0 @@ -from pathlib import Path -from dataclasses import dataclass - - -@dataclass(frozen=True) -class _SystemConstants: - PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.parent.absolute().resolve() - CONFIG_DIR: Path = PROJECT_ROOT / "config" - BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute() - MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute() - PLUGINS_DIR: Path = (PROJECT_ROOT / "plugins").resolve().absolute() - INTERNAL_PLUGINS_DIR: Path = (PROJECT_ROOT / "src" / "plugins").resolve().absolute() - - -_system_constants = _SystemConstants() - -PROJECT_ROOT: Path = _system_constants.PROJECT_ROOT -CONFIG_DIR: Path = _system_constants.CONFIG_DIR -BOT_CONFIG_PATH: Path = _system_constants.BOT_CONFIG_PATH -MODEL_CONFIG_PATH: Path = _system_constants.MODEL_CONFIG_PATH -PLUGINS_DIR: Path = _system_constants.PLUGINS_DIR -INTERNAL_PLUGINS_DIR: Path = _system_constants.INTERNAL_PLUGINS_DIR diff --git a/src/plugin_system/apis/emoji_api.py b/src/plugin_system/apis/emoji_api.py deleted file mode 100644 index cbc6dc50..00000000 --- a/src/plugin_system/apis/emoji_api.py +++ /dev/null @@ -1,700 +0,0 @@ -""" -表情API模块 - -提供表情包相关功能,采用标准Python包设计模式 -使用方式: - from src.plugin_system.apis import emoji_api - result = await emoji_api.get_by_description("开心") - count = emoji_api.get_count() -""" - -import random -import base64 -import os -import uuid - -from typing import Optional, Tuple, List, Dict, Any -from src.common.logger import get_logger -from src.chat.emoji_system.emoji_manager import emoji_manager, EMOJI_DIR -from src.common.utils.utils_image import ImageUtils -from src.config.config import global_config - -logger = get_logger("emoji_api") - - -# ============================================================================= -# 表情包获取API函数 -# ============================================================================= - - -async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]: - """根据描述选择表情包 - - Args: - description: 表情包的描述文本,例如"开心"、"难过"、"愤怒"等 - - Returns: - Optional[Tuple[str, str, str]]: (base64编码, 表情包描述, 匹配的情感标签) 或 None - - Raises: - ValueError: 如果描述为空字符串 - TypeError: 如果描述不是字符串类型 - """ - if not description: - raise ValueError("描述不能为空") - if not isinstance(description, str): - raise TypeError("描述必须是字符串类型") - try: - logger.debug(f"[EmojiAPI] 根据描述获取表情包: {description}") - - emoji_obj = await emoji_manager.get_emoji_for_emotion(description) - - if not emoji_obj: - logger.warning(f"[EmojiAPI] 未找到匹配描述 '{description}' 的表情包") - return None - - emoji_path = str(emoji_obj.full_path) - emoji_description = emoji_obj.description - matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else "" - emoji_base64 = ImageUtils.image_path_to_base64(emoji_path) - - if not emoji_base64: - logger.error(f"[EmojiAPI] 无法将表情包文件转换为base64: {emoji_path}") - return None - - logger.debug(f"[EmojiAPI] 成功获取表情包: {emoji_description}, 匹配情感: {matched_emotion}") - return emoji_base64, emoji_description, matched_emotion - - except Exception as e: - logger.error(f"[EmojiAPI] 获取表情包失败: {e}") - return None - - -async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]: - """随机获取指定数量的表情包 - - Args: - count: 要获取的表情包数量,默认为1 - - Returns: - List[Tuple[str, str, str]]: 包含(base64编码, 表情包描述, 随机情感标签)的元组列表,失败则返回空列表 - - Raises: - TypeError: 如果count不是整数类型 - ValueError: 如果count为负数 - """ - if not isinstance(count, int): - raise TypeError("count 必须是整数类型") - if count < 0: - raise ValueError("count 不能为负数") - if count == 0: - logger.warning("[EmojiAPI] count 为0,返回空列表") - return [] - - try: - all_emojis = emoji_manager.emojis - - if not all_emojis: - logger.warning("[EmojiAPI] 没有可用的表情包") - return [] - - # 过滤有效表情包 - valid_emojis = [emoji for emoji in all_emojis if not emoji.is_deleted] - if not valid_emojis: - logger.warning("[EmojiAPI] 没有有效的表情包") - return [] - - if len(valid_emojis) < count: - logger.debug( - f"[EmojiAPI] 有效表情包数量 ({len(valid_emojis)}) 少于请求的数量 ({count}),将返回所有有效表情包" - ) - count = len(valid_emojis) - - # 随机选择 - selected_emojis = random.sample(valid_emojis, count) - - results = [] - for selected_emoji in selected_emojis: - emoji_base64 = ImageUtils.image_path_to_base64(str(selected_emoji.full_path)) - - if not emoji_base64: - logger.error(f"[EmojiAPI] 无法转换表情包为base64: {selected_emoji.full_path}") - continue - - matched_emotion = random.choice(selected_emoji.emotion) if selected_emoji.emotion else "随机表情" - - # 记录使用次数 - emoji_manager.update_emoji_usage(selected_emoji) - results.append((emoji_base64, selected_emoji.description, matched_emotion)) - - if not results and count > 0: - logger.warning("[EmojiAPI] 随机获取表情包失败,没有一个可以成功处理") - return [] - - logger.debug(f"[EmojiAPI] 成功获取 {len(results)} 个随机表情包") - return results - - except Exception as e: - logger.error(f"[EmojiAPI] 获取随机表情包失败: {e}") - return [] - - -async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]: - """根据情感标签获取表情包 - - Args: - emotion: 情感标签,如"happy"、"sad"、"angry"等 - - Returns: - Optional[Tuple[str, str, str]]: (base64编码, 表情包描述, 匹配的情感标签) 或 None - - Raises: - ValueError: 如果情感标签为空字符串 - TypeError: 如果情感标签不是字符串类型 - """ - if not emotion: - raise ValueError("情感标签不能为空") - if not isinstance(emotion, str): - raise TypeError("情感标签必须是字符串类型") - try: - logger.info(f"[EmojiAPI] 根据情感获取表情包: {emotion}") - - all_emojis = emoji_manager.emojis - - # 筛选匹配情感的表情包 - matching_emojis = [] - matching_emojis.extend( - emoji_obj - for emoji_obj in all_emojis - if not emoji_obj.is_deleted and emotion.lower() in [e.lower() for e in emoji_obj.emotion] - ) - if not matching_emojis: - logger.warning(f"[EmojiAPI] 未找到匹配情感 '{emotion}' 的表情包") - return None - - # 随机选择匹配的表情包 - selected_emoji = random.choice(matching_emojis) - emoji_base64 = ImageUtils.image_path_to_base64(selected_emoji.full_path) - - if not emoji_base64: - logger.error(f"[EmojiAPI] 无法转换表情包为base64: {selected_emoji.full_path}") - return None - - # 记录使用次数 - emoji_manager.update_emoji_usage(selected_emoji) - - logger.info(f"[EmojiAPI] 成功获取情感表情包: {selected_emoji.description}") - return emoji_base64, selected_emoji.description, emotion - - except Exception as e: - logger.error(f"[EmojiAPI] 根据情感获取表情包失败: {e}") - return None - - -# ============================================================================= -# 表情包信息查询API函数 -# ============================================================================= - - -def get_count() -> int: - """获取表情包数量 - - Returns: - int: 当前可用的表情包数量 - """ - try: - return len(emoji_manager.emojis) - except Exception as e: - logger.error(f"[EmojiAPI] 获取表情包数量失败: {e}") - return 0 - - -def get_info(): - """获取表情包系统信息 - - Returns: - dict: 包含表情包数量、最大数量、可用数量信息 - """ - try: - return { - "current_count": len(emoji_manager.emojis), - "max_count": global_config.emoji.max_reg_num, - "available_emojis": len([e for e in emoji_manager.emojis if not e.is_deleted]), - } - except Exception as e: - logger.error(f"[EmojiAPI] 获取表情包信息失败: {e}") - return {"current_count": 0, "max_count": 0, "available_emojis": 0} - - -def get_emotions() -> List[str]: - """获取所有可用的情感标签 - - Returns: - list: 所有表情包的情感标签列表(去重) - """ - try: - emotions = set() - - for emoji_obj in emoji_manager.emojis: - if not emoji_obj.is_deleted and emoji_obj.emotion: - emotions.update(emoji_obj.emotion) - - return sorted(list(emotions)) - except Exception as e: - logger.error(f"[EmojiAPI] 获取情感标签失败: {e}") - return [] - - -async def get_all() -> List[Tuple[str, str, str]]: - """获取所有表情包 - - Returns: - List[Tuple[str, str, str]]: 包含(base64编码, 表情包描述, 随机情感标签)的元组列表 - """ - try: - all_emojis = emoji_manager.emojis - - if not all_emojis: - logger.warning("[EmojiAPI] 没有可用的表情包") - return [] - - results = [] - for emoji_obj in all_emojis: - if emoji_obj.is_deleted: - continue - - emoji_base64 = ImageUtils.image_path_to_base64(str(emoji_obj.full_path)) - - if not emoji_base64: - logger.error(f"[EmojiAPI] 无法转换表情包为base64: {emoji_obj.full_path}") - continue - - matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else "随机表情" - results.append((emoji_base64, emoji_obj.description, matched_emotion)) - - logger.debug(f"[EmojiAPI] 成功获取 {len(results)} 个表情包") - return results - - except Exception as e: - logger.error(f"[EmojiAPI] 获取所有表情包失败: {e}") - return [] - - -def get_descriptions() -> List[str]: - """获取所有表情包描述 - - Returns: - list: 所有可用表情包的描述列表 - """ - try: - descriptions = [] - - descriptions.extend( - emoji_obj.description - for emoji_obj in emoji_manager.emojis - if not emoji_obj.is_deleted and emoji_obj.description - ) - return descriptions - except Exception as e: - logger.error(f"[EmojiAPI] 获取表情包描述失败: {e}") - return [] - - -# ============================================================================= -# 表情包注册API函数 -# ============================================================================= - - -async def register_emoji(image_base64: str, filename: Optional[str] = None) -> Dict[str, Any]: - """注册新的表情包 - - Args: - image_base64: 图片的base64编码 - filename: 可选的文件名,如果未提供则自动生成 - - Returns: - Dict[str, Any]: 注册结果,包含以下字段: - - success: bool, 是否成功注册 - - message: str, 结果消息 - - description: Optional[str], 表情包描述(成功时) - - emotions: Optional[List[str]], 情感标签列表(成功时) - - replaced: Optional[bool], 是否替换了旧表情包(成功时) - - hash: Optional[str], 表情包哈希值(成功时) - - Raises: - ValueError: 如果base64为空或无效 - TypeError: 如果参数类型不正确 - """ - if not image_base64: - raise ValueError("图片base64编码不能为空") - if not isinstance(image_base64, str): - raise TypeError("image_base64必须是字符串类型") - if filename is not None and not isinstance(filename, str): - raise TypeError("filename必须是字符串类型或None") - - try: - logger.info(f"[EmojiAPI] 开始注册表情包,文件名: {filename or '自动生成'}") - - # 1. 获取emoji管理器并检查容量 - count_before = len(emoji_manager.emojis) - max_count = global_config.emoji.max_reg_num - - # 2. 检查是否可以注册(未达到上限或启用替换) - can_register = count_before < max_count or (count_before >= max_count and global_config.emoji.do_replace) - - if not can_register: - return { - "success": False, - "message": f"表情包数量已达上限({count_before}/{max_count})且未启用替换功能", - "description": None, - "emotions": None, - "replaced": None, - "hash": None, - } - - # 3. 确保emoji目录存在 - os.makedirs(EMOJI_DIR, exist_ok=True) - - # 4. 生成文件名 - if not filename: - # 基于时间戳、微秒和短base64生成唯一文件名 - import time - - timestamp = int(time.time()) - microseconds = int(time.time() * 1000000) % 1000000 # 添加微秒级精度 - - # 生成12位随机标识符,使用base64编码(增加随机性) - import random - - random_bytes = random.getrandbits(72).to_bytes(9, "big") # 72位 = 9字节 = 12位base64 - short_id = base64.b64encode(random_bytes).decode("ascii")[:12].rstrip("=") - # 确保base64编码适合文件名(替换/和-) - short_id = short_id.replace("/", "_").replace("+", "-") - filename = f"emoji_{timestamp}_{microseconds}_{short_id}" - - # 确保文件名有扩展名 - if not filename.lower().endswith((".jpg", ".jpeg", ".png", ".gif")): - filename = f"{filename}.png" # 默认使用png格式 - - # 检查文件名是否已存在,如果存在则重新生成短标识符 - temp_file_path = os.path.join(EMOJI_DIR, filename) - attempts = 0 - max_attempts = 10 - while os.path.exists(temp_file_path) and attempts < max_attempts: - # 重新生成短标识符 - import random - - random_bytes = random.getrandbits(48).to_bytes(6, "big") - short_id = base64.b64encode(random_bytes).decode("ascii")[:8].rstrip("=") - short_id = short_id.replace("/", "_").replace("+", "-") - - # 分离文件名和扩展名,重新生成文件名 - name_part, ext = os.path.splitext(filename) - # 去掉原来的标识符,添加新的 - base_name = name_part.rsplit("_", 1)[0] # 移除最后一个_后的部分 - filename = f"{base_name}_{short_id}{ext}" - temp_file_path = os.path.join(EMOJI_DIR, filename) - attempts += 1 - - # 如果还是冲突,使用UUID作为备用方案 - if os.path.exists(temp_file_path): - uuid_short = str(uuid.uuid4())[:8] - name_part, ext = os.path.splitext(filename) - base_name = name_part.rsplit("_", 1)[0] - filename = f"{base_name}_{uuid_short}{ext}" - temp_file_path = os.path.join(EMOJI_DIR, filename) - - # 如果UUID方案也冲突,添加序号 - counter = 1 - original_filename = filename - while os.path.exists(temp_file_path): - name_part, ext = os.path.splitext(original_filename) - filename = f"{name_part}_{counter}{ext}" - temp_file_path = os.path.join(EMOJI_DIR, filename) - counter += 1 - - # 防止无限循环,最多尝试100次 - if counter > 100: - logger.error(f"[EmojiAPI] 无法生成唯一文件名,尝试次数过多: {original_filename}") - return { - "success": False, - "message": "无法生成唯一文件名,请稍后重试", - "description": None, - "emotions": None, - "replaced": None, - "hash": None, - } - - # 5. 保存base64图片到emoji目录 - - try: - # 解码base64并保存图片 - if not ImageUtils.base64_to_image(image_base64, temp_file_path): - logger.error(f"[EmojiAPI] 无法保存base64图片到文件: {temp_file_path}") - return { - "success": False, - "message": "无法保存图片文件", - "description": None, - "emotions": None, - "replaced": None, - "hash": None, - } - - logger.debug(f"[EmojiAPI] 图片已保存到临时文件: {temp_file_path}") - - except Exception as save_error: - logger.error(f"[EmojiAPI] 保存图片文件失败: {save_error}") - return { - "success": False, - "message": f"保存图片文件失败: {str(save_error)}", - "description": None, - "emotions": None, - "replaced": None, - "hash": None, - } - - # 6. 调用注册方法 - register_success = await emoji_manager.register_emoji_by_filename(filename) - - # 7. 清理临时文件(如果注册失败但文件还存在) - if not register_success and os.path.exists(temp_file_path): - try: - os.remove(temp_file_path) - logger.debug(f"[EmojiAPI] 已清理临时文件: {temp_file_path}") - except Exception as cleanup_error: - logger.warning(f"[EmojiAPI] 清理临时文件失败: {cleanup_error}") - - # 8. 构建返回结果 - if register_success: - count_after = len(emoji_manager.emojis) - replaced = count_after <= count_before # 如果数量没增加,说明是替换 - - # 尝试获取新注册的表情包信息 - new_emoji_info = None - if count_after > count_before or replaced: - # 获取最新的表情包信息 - try: - # 通过文件名查找新注册的表情包(注意:文件名在注册后可能已经改变) - for emoji_obj in reversed(emoji_manager.emojis): - if not emoji_obj.is_deleted and ( - emoji_obj.file_name == filename - or (hasattr(emoji_obj, "full_path") and filename in str(emoji_obj.full_path)) - ): - new_emoji_info = emoji_obj - break - except Exception as find_error: - logger.warning(f"[EmojiAPI] 查找新注册表情包信息失败: {find_error}") - - description = new_emoji_info.description if new_emoji_info else None - emotions = new_emoji_info.emotion if new_emoji_info else None - emoji_hash = new_emoji_info.emoji_hash if new_emoji_info else None - - return { - "success": True, - "message": f"表情包注册成功 {'(替换旧表情包)' if replaced else '(新增表情包)'}", - "description": description, - "emotions": emotions, - "replaced": replaced, - "hash": emoji_hash, - } - else: - return { - "success": False, - "message": "表情包注册失败,可能因为重复、格式不支持或审核未通过", - "description": None, - "emotions": None, - "replaced": None, - "hash": None, - } - - except Exception as e: - logger.error(f"[EmojiAPI] 注册表情包时发生异常: {e}") - return { - "success": False, - "message": f"注册过程中发生错误: {str(e)}", - "description": None, - "emotions": None, - "replaced": None, - "hash": None, - } - - -# ============================================================================= -# 表情包删除API函数 -# ============================================================================= - - -async def delete_emoji(emoji_hash: str) -> Dict[str, Any]: - """删除表情包 - - Args: - emoji_hash: 要删除的表情包的哈希值 - - Returns: - Dict[str, Any]: 删除结果,包含以下字段: - - success: bool, 是否成功删除 - - message: str, 结果消息 - - count_before: Optional[int], 删除前的表情包数量 - - count_after: Optional[int], 删除后的表情包数量 - - description: Optional[str], 被删除的表情包描述(成功时) - - emotions: Optional[List[str]], 被删除的表情包情感标签(成功时) - - Raises: - ValueError: 如果哈希值为空 - TypeError: 如果哈希值不是字符串类型 - """ - if not emoji_hash: - raise ValueError("表情包哈希值不能为空") - if not isinstance(emoji_hash, str): - raise TypeError("emoji_hash必须是字符串类型") - - try: - logger.info(f"[EmojiAPI] 开始删除表情包,哈希值: {emoji_hash}") - - # 1. 获取emoji管理器和删除前的数量 - count_before = len(emoji_manager.emojis) - - # 2. 获取被删除表情包的信息(用于返回结果) - deleted_emoji = None - try: - deleted_emoji = emoji_manager.get_emoji_by_hash(emoji_hash) or emoji_manager.get_emoji_by_hash_from_db( - emoji_hash - ) - description = deleted_emoji.description if deleted_emoji else None - emotions = deleted_emoji.emotion if deleted_emoji else None - except Exception as info_error: - logger.warning(f"[EmojiAPI] 获取被删除表情包信息失败: {info_error}") - description = None - emotions = None - - # 3. 执行删除操作 - delete_success = False - if deleted_emoji: - delete_success = emoji_manager.delete_emoji(deleted_emoji) - - # 4. 获取删除后的数量 - count_after = len(emoji_manager.emojis) - - # 5. 构建返回结果 - if delete_success: - return { - "success": True, - "message": f"表情包删除成功 (哈希: {emoji_hash[:8]}...)", - "count_before": count_before, - "count_after": count_after, - "description": description, - "emotions": emotions, - } - else: - return { - "success": False, - "message": "表情包删除失败,可能因为哈希值不存在或删除过程出错", - "count_before": count_before, - "count_after": count_after, - "description": None, - "emotions": None, - } - - except Exception as e: - logger.error(f"[EmojiAPI] 删除表情包时发生异常: {e}") - return { - "success": False, - "message": f"删除过程中发生错误: {str(e)}", - "count_before": None, - "count_after": None, - "description": None, - "emotions": None, - } - - -async def delete_emoji_by_description(description: str, exact_match: bool = False) -> Dict[str, Any]: - """根据描述删除表情包 - - Args: - description: 表情包描述文本 - exact_match: 是否精确匹配描述,False则为模糊匹配 - - Returns: - Dict[str, Any]: 删除结果,包含以下字段: - - success: bool, 是否成功删除 - - message: str, 结果消息 - - deleted_count: int, 删除的表情包数量 - - deleted_hashes: List[str], 被删除的表情包哈希列表 - - matched_count: int, 匹配到的表情包数量 - - Raises: - ValueError: 如果描述为空 - TypeError: 如果描述不是字符串类型 - """ - if not description: - raise ValueError("描述不能为空") - if not isinstance(description, str): - raise TypeError("description必须是字符串类型") - - try: - logger.info(f"[EmojiAPI] 根据描述删除表情包: {description} (精确匹配: {exact_match})") - - all_emojis = emoji_manager.emojis - - # 筛选匹配的表情包 - matching_emojis = [] - for emoji_obj in all_emojis: - if emoji_obj.is_deleted: - continue - - if exact_match: - if emoji_obj.description == description: - matching_emojis.append(emoji_obj) - else: - if description.lower() in emoji_obj.description.lower(): - matching_emojis.append(emoji_obj) - - matched_count = len(matching_emojis) - if matched_count == 0: - return { - "success": False, - "message": f"未找到匹配描述 '{description}' 的表情包", - "deleted_count": 0, - "deleted_hashes": [], - "matched_count": 0, - } - - # 删除匹配的表情包 - deleted_count = 0 - deleted_hashes = [] - for emoji_obj in matching_emojis: - try: - delete_success = emoji_manager.delete_emoji(emoji_obj) - if delete_success: - deleted_count += 1 - deleted_hashes.append(emoji_obj.emoji_hash) - except Exception as delete_error: - logger.error(f"[EmojiAPI] 删除表情包失败 (哈希: {emoji_obj.emoji_hash}): {delete_error}") - - # 构建返回结果 - if deleted_count > 0: - return { - "success": True, - "message": f"成功删除 {deleted_count} 个表情包 (匹配到 {matched_count} 个)", - "deleted_count": deleted_count, - "deleted_hashes": deleted_hashes, - "matched_count": matched_count, - } - else: - return { - "success": False, - "message": f"匹配到 {matched_count} 个表情包,但删除全部失败", - "deleted_count": 0, - "deleted_hashes": [], - "matched_count": matched_count, - } - - except Exception as e: - logger.error(f"[EmojiAPI] 根据描述删除表情包时发生异常: {e}") - return { - "success": False, - "message": f"删除过程中发生错误: {str(e)}", - "deleted_count": 0, - "deleted_hashes": [], - "matched_count": 0, - } diff --git a/src/plugin_system/apis/logging_api.py b/src/plugin_system/apis/logging_api.py deleted file mode 100644 index 7aeec413..00000000 --- a/src/plugin_system/apis/logging_api.py +++ /dev/null @@ -1,3 +0,0 @@ -from src.common.logger import get_logger - -__all__ = ["get_logger"] diff --git a/src/plugin_system/apis/plugin_manage_api.py b/src/plugin_system/apis/plugin_manage_api.py deleted file mode 100644 index d428eb28..00000000 --- a/src/plugin_system/apis/plugin_manage_api.py +++ /dev/null @@ -1,120 +0,0 @@ -from typing import Tuple, List - - -def list_loaded_plugins() -> List[str]: - """ - 列出所有当前加载的插件。 - - Returns: - List[str]: 当前加载的插件名称列表。 - """ - from src.plugin_system.core.plugin_manager import plugin_manager - - return plugin_manager.list_loaded_plugins() - - -def list_registered_plugins() -> List[str]: - """ - 列出所有已注册的插件。 - - Returns: - List[str]: 已注册的插件名称列表。 - """ - from src.plugin_system.core.plugin_manager import plugin_manager - - return plugin_manager.list_registered_plugins() - - -def get_plugin_path(plugin_name: str) -> str: - """ - 获取指定插件的路径。 - - Args: - plugin_name (str): 插件名称。 - - Returns: - str: 插件目录的绝对路径。 - - Raises: - ValueError: 如果插件不存在。 - """ - from src.plugin_system.core.plugin_manager import plugin_manager - - if plugin_path := plugin_manager.get_plugin_path(plugin_name): - return plugin_path - else: - raise ValueError(f"插件 '{plugin_name}' 不存在。") - - -async def remove_plugin(plugin_name: str) -> bool: - """ - 卸载指定的插件。 - - **此函数是异步的,确保在异步环境中调用。** - - Args: - plugin_name (str): 要卸载的插件名称。 - - Returns: - bool: 卸载是否成功。 - """ - from src.plugin_system.core.plugin_manager import plugin_manager - - return await plugin_manager.remove_registered_plugin(plugin_name) - - -async def reload_plugin(plugin_name: str) -> bool: - """ - 重新加载指定的插件。 - - **此函数是异步的,确保在异步环境中调用。** - - Args: - plugin_name (str): 要重新加载的插件名称。 - - Returns: - bool: 重新加载是否成功。 - """ - from src.plugin_system.core.plugin_manager import plugin_manager - - return await plugin_manager.reload_registered_plugin(plugin_name) - - -def load_plugin(plugin_name: str) -> Tuple[bool, int]: - """ - 加载指定的插件。 - - Args: - plugin_name (str): 要加载的插件名称。 - - Returns: - Tuple[bool, int]: 加载是否成功,成功或失败个数。 - """ - from src.plugin_system.core.plugin_manager import plugin_manager - - return plugin_manager.load_registered_plugin_classes(plugin_name) - - -def add_plugin_directory(plugin_directory: str) -> bool: - """ - 添加插件目录。 - - Args: - plugin_directory (str): 要添加的插件目录路径。 - Returns: - bool: 添加是否成功。 - """ - from src.plugin_system.core.plugin_manager import plugin_manager - - return plugin_manager.add_plugin_directory(plugin_directory) - - -def rescan_plugin_directory() -> Tuple[int, int]: - """ - 重新扫描插件目录,加载新插件。 - Returns: - Tuple[int, int]: 成功加载的插件数量和失败的插件数量。 - """ - from src.plugin_system.core.plugin_manager import plugin_manager - - return plugin_manager.rescan_plugin_directory() diff --git a/src/plugin_system/apis/plugin_register_api.py b/src/plugin_system/apis/plugin_register_api.py deleted file mode 100644 index 2e14b0c8..00000000 --- a/src/plugin_system/apis/plugin_register_api.py +++ /dev/null @@ -1,46 +0,0 @@ -from pathlib import Path - -from src.common.logger import get_logger - -logger = get_logger("plugin_manager") # 复用plugin_manager名称 - - -def register_plugin(cls): - from src.plugin_system.core.plugin_manager import plugin_manager - from src.plugin_system.base.base_plugin import BasePlugin - - """插件注册装饰器 - - 用法: - @register_plugin - class MyPlugin(BasePlugin): - plugin_name = "my_plugin" - plugin_description = "我的插件" - ... - """ - if not issubclass(cls, BasePlugin): - logger.error(f"类 {cls.__name__} 不是 BasePlugin 的子类") - return cls - - # 只是注册插件类,不立即实例化 - # 插件管理器会负责实例化和注册 - plugin_name: str = cls.plugin_name # type: ignore - if "." in plugin_name: - logger.error(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代") - raise ValueError(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代") - splitted_name = cls.__module__.split(".") - root_path = Path(__file__) - - # 查找项目根目录 - while not (root_path / "pyproject.toml").exists() and root_path.parent != root_path: - root_path = root_path.parent - - if not (root_path / "pyproject.toml").exists(): - logger.error(f"注册 {plugin_name} 无法找到项目根目录") - return cls - - plugin_manager.plugin_classes[plugin_name] = cls - plugin_manager.plugin_paths[plugin_name] = str(Path(root_path, *splitted_name).resolve()) - logger.debug(f"插件类已注册: {plugin_name}, 路径: {plugin_manager.plugin_paths[plugin_name]}") - - return cls diff --git a/src/plugin_system/apis/plugin_service_api.py b/src/plugin_system/apis/plugin_service_api.py deleted file mode 100644 index 4c783f04..00000000 --- a/src/plugin_system/apis/plugin_service_api.py +++ /dev/null @@ -1,56 +0,0 @@ -from typing import Any, Callable, Dict, Optional - -from src.plugin_system.base.service_types import PluginServiceInfo -from src.plugin_system.core.plugin_service_registry import plugin_service_registry - - -def register_service(service_info: PluginServiceInfo, service_handler: Callable[..., Any]) -> bool: - """注册插件服务。""" - return plugin_service_registry.register_service(service_info, service_handler) - - -def get_service(service_name: str, plugin_name: Optional[str] = None) -> Optional[PluginServiceInfo]: - """获取插件服务元信息。""" - return plugin_service_registry.get_service(service_name, plugin_name) - - -def get_service_handler(service_name: str, plugin_name: Optional[str] = None) -> Optional[Callable[..., Any]]: - """获取插件服务处理函数。""" - return plugin_service_registry.get_service_handler(service_name, plugin_name) - - -def list_services(plugin_name: Optional[str] = None, enabled_only: bool = False) -> Dict[str, PluginServiceInfo]: - """列出插件服务。""" - return plugin_service_registry.list_services(plugin_name=plugin_name, enabled_only=enabled_only) - - -def enable_service(service_name: str, plugin_name: Optional[str] = None) -> bool: - """启用插件服务。""" - return plugin_service_registry.enable_service(service_name, plugin_name) - - -def disable_service(service_name: str, plugin_name: Optional[str] = None) -> bool: - """禁用插件服务。""" - return plugin_service_registry.disable_service(service_name, plugin_name) - - -def unregister_service(service_name: str, plugin_name: Optional[str] = None) -> bool: - """注销插件服务。""" - return plugin_service_registry.unregister_service(service_name, plugin_name) - - -async def call_service( - service_name: str, - *args: Any, - plugin_name: Optional[str] = None, - caller_plugin: Optional[str] = None, - **kwargs: Any, -) -> Any: - """调用插件服务。""" - return await plugin_service_registry.call_service( - service_name, - *args, - plugin_name=plugin_name, - caller_plugin=caller_plugin, - **kwargs, - ) diff --git a/src/plugin_system/apis/tool_api.py b/src/plugin_system/apis/tool_api.py deleted file mode 100644 index 00464ea7..00000000 --- a/src/plugin_system/apis/tool_api.py +++ /dev/null @@ -1,45 +0,0 @@ -from typing import Optional, Type, TYPE_CHECKING -from src.plugin_system.base.base_tool import BaseTool -from src.plugin_system.base.component_types import ComponentType - -from src.common.logger import get_logger - -if TYPE_CHECKING: - from src.chat.message_receive.chat_manager import BotChatSession - -logger = get_logger("tool_api") - - -def get_tool_instance(tool_name: str, chat_stream: Optional["BotChatSession"] = None) -> Optional[BaseTool]: - """获取公开工具实例 - - Args: - tool_name: 工具名称 - chat_stream: 聊天流对象,用于传递聊天上下文信息 - - Returns: - Optional[BaseTool]: 工具实例,如果未找到则返回None - """ - from src.plugin_system.core import component_registry - - # 获取插件配置 - tool_info = component_registry.get_component_info(tool_name, ComponentType.TOOL) - if tool_info: - plugin_config = component_registry.get_plugin_config(tool_info.plugin_name) - else: - plugin_config = None - - tool_class: Type[BaseTool] = component_registry.get_component_class(tool_name, ComponentType.TOOL) # type: ignore - return tool_class(plugin_config, chat_stream) if tool_class else None - - -def get_llm_available_tool_definitions(): - """获取LLM可用的工具定义列表 - - Returns: - List[Tuple[str, Dict[str, Any]]]: 工具定义列表,为[("tool_name", 定义)] - """ - from src.plugin_system.core import component_registry - - llm_available_tools = component_registry.get_llm_available_tools() - return [(name, tool_class.get_tool_definition()) for name, tool_class in llm_available_tools.items()] diff --git a/src/plugin_system/apis/workflow_api.py b/src/plugin_system/apis/workflow_api.py deleted file mode 100644 index c7d45e88..00000000 --- a/src/plugin_system/apis/workflow_api.py +++ /dev/null @@ -1,77 +0,0 @@ -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -from src.plugin_system.base.component_types import EventType, MaiMessages -from src.plugin_system.base.workflow_types import WorkflowContext, WorkflowStage, WorkflowStepInfo, WorkflowStepResult -from src.plugin_system.core.component_registry import component_registry -from src.plugin_system.core.events_manager import events_manager -from src.plugin_system.core.workflow_engine import workflow_engine - - -def register_workflow_step(step_info: WorkflowStepInfo, step_handler: Callable[..., Any]) -> bool: - """注册workflow step。""" - return component_registry.register_workflow_step(step_info, step_handler) - - -def get_steps_by_stage(stage: WorkflowStage, enabled_only: bool = False) -> Dict[str, WorkflowStepInfo]: - """获取指定阶段的workflow steps。""" - return component_registry.get_steps_by_stage(stage, enabled_only=enabled_only) - - -def get_workflow_step(step_name: str, stage: Optional[WorkflowStage] = None) -> Optional[WorkflowStepInfo]: - """获取workflow step元信息。""" - return component_registry.get_workflow_step(step_name, stage) - - -def get_workflow_step_handler(step_name: str, stage: Optional[WorkflowStage] = None) -> Optional[Callable[..., Any]]: - """获取workflow step处理函数。""" - return component_registry.get_workflow_step_handler(step_name, stage) - - -def enable_workflow_step(step_name: str, stage: Optional[WorkflowStage] = None) -> bool: - """启用workflow step。""" - return component_registry.enable_workflow_step(step_name, stage) - - -def disable_workflow_step(step_name: str, stage: Optional[WorkflowStage] = None) -> bool: - """禁用workflow step。""" - return component_registry.disable_workflow_step(step_name, stage) - - -def get_execution_trace(trace_id: str) -> Optional[Dict[str, Any]]: - """按trace_id获取workflow执行路径。""" - return workflow_engine.get_execution_trace(trace_id) - - -def clear_execution_trace(trace_id: str) -> bool: - """清理trace执行路径记录。""" - return workflow_engine.clear_execution_trace(trace_id) - - -async def execute_workflow_message( - message: Optional[MaiMessages] = None, - stream_id: Optional[str] = None, - action_usage: Optional[List[str]] = None, - context: Optional[WorkflowContext] = None, -) -> Tuple[WorkflowStepResult, Optional[MaiMessages], WorkflowContext]: - """执行workflow消息流转。""" - return await events_manager.handle_workflow_message( - message=message, - stream_id=stream_id, - action_usage=action_usage, - context=context, - ) - - -async def publish_event( - event_type: Union[EventType, str], - message: Optional[MaiMessages] = None, - stream_id: Optional[str] = None, - action_usage: Optional[List[str]] = None, -) -> Tuple[bool, Optional[MaiMessages]]: - """发布事件(支持系统事件和自定义字符串事件)。""" - return await events_manager.handle_mai_events( - event_type=event_type, - message=message, - stream_id=stream_id, - action_usage=action_usage, - ) diff --git a/src/plugin_system/base/__init__.py b/src/plugin_system/base/__init__.py deleted file mode 100644 index 7326fcec..00000000 --- a/src/plugin_system/base/__init__.py +++ /dev/null @@ -1,72 +0,0 @@ -""" -插件基础类模块 - -提供插件开发的基础类和类型定义 -""" - -from .base_plugin import BasePlugin -from .base_action import BaseAction -from .base_tool import BaseTool -from .base_command import BaseCommand -from .base_events_handler import BaseEventHandler -from .service_types import PluginServiceInfo -from .workflow_types import WorkflowContext, WorkflowMessage, WorkflowStage, WorkflowStepInfo, WorkflowStepResult -from .workflow_errors import WorkflowErrorCode -from .component_types import ( - ComponentType, - ActionActivationType, - ChatMode, - ComponentInfo, - ActionInfo, - CommandInfo, - ToolInfo, - PluginInfo, - PythonDependency, - EventHandlerInfo, - EventType, - MaiMessages, - ToolParamType, - CustomEventHandlerResult, - ReplyContentType, - ReplyContent, - ForwardNode, - ReplySetModel, -) -from .config_types import ConfigField, ConfigSection, ConfigLayout, ConfigTab - -__all__ = [ - "BasePlugin", - "BaseAction", - "BaseCommand", - "BaseTool", - "ComponentType", - "ActionActivationType", - "ChatMode", - "ComponentInfo", - "ActionInfo", - "CommandInfo", - "ToolInfo", - "PluginInfo", - "PythonDependency", - "ConfigField", - "ConfigSection", - "ConfigLayout", - "ConfigTab", - "EventHandlerInfo", - "EventType", - "BaseEventHandler", - "MaiMessages", - "ToolParamType", - "CustomEventHandlerResult", - "ReplyContentType", - "ReplyContent", - "ForwardNode", - "ReplySetModel", - "PluginServiceInfo", - "WorkflowContext", - "WorkflowMessage", - "WorkflowStage", - "WorkflowStepInfo", - "WorkflowStepResult", - "WorkflowErrorCode", -] diff --git a/src/plugin_system/base/base_action.py b/src/plugin_system/base/base_action.py deleted file mode 100644 index 1f2af8a3..00000000 --- a/src/plugin_system/base/base_action.py +++ /dev/null @@ -1,533 +0,0 @@ -import time -import asyncio - -from abc import ABC, abstractmethod -from typing import Tuple, Optional, TYPE_CHECKING, Dict, List - -from src.common.logger import get_logger -from src.common.data_models.message_data_model import ReplyContentType, ReplyContent, ReplySetModel, ForwardNode -from src.chat.message_receive.chat_manager import BotChatSession -from src.plugin_system.base.component_types import ActionActivationType, ActionInfo, ComponentType -from src.plugin_system.apis import send_api, database_api, message_api - -if TYPE_CHECKING: - from src.common.data_models.database_data_model import DatabaseMessages - -logger = get_logger("base_action") - - -class BaseAction(ABC): - """Action组件基类 - - Action是插件的一种组件类型,用于处理聊天中的动作逻辑 - - 子类可以通过类属性定义激活条件,这些会在实例化时转换为实例属性: - - focus_activation_type: 专注模式激活类型 - - normal_activation_type: 普通模式激活类型 - - activation_keywords: 激活关键词列表 - - keyword_case_sensitive: 关键词是否区分大小写 - - parallel_action: 是否允许并行执行 - - random_activation_probability: 随机激活概率 - """ - - def __init__( - self, - action_data: dict, - action_reasoning: str, - cycle_timers: dict, - thinking_id: str, - chat_stream: BotChatSession, - plugin_config: Optional[dict] = None, - action_message: Optional["DatabaseMessages"] = None, - **kwargs, - ): - # sourcery skip: hoist-similar-statement-from-if, merge-else-if-into-elif, move-assign-in-block, swap-if-else-branches, swap-nested-ifs - """初始化Action组件 - - Args: - action_data: 动作数据 - reasoning: 执行该动作的理由 - cycle_timers: 计时器字典 - thinking_id: 思考ID - chat_stream: 聊天流对象 - log_prefix: 日志前缀 - plugin_config: 插件配置字典 - action_message: 消息数据 - **kwargs: 其他参数 - """ - if plugin_config is None: - plugin_config = {} - self.action_data = action_data - self.reasoning = "" - self.cycle_timers = cycle_timers - self.thinking_id = thinking_id - - self.action_reasoning = action_reasoning - - self.plugin_config = plugin_config or {} - """对应的插件配置""" - - # 设置动作基本信息实例属性 - self.action_name: str = getattr(self, "action_name", self.__class__.__name__.lower().replace("action", "")) - """Action的名字""" - self.action_description: str = getattr(self, "action_description", self.__doc__ or "Action组件") - """Action的描述""" - self.action_parameters: dict = getattr(self.__class__, "action_parameters", {}).copy() - self.action_require: list[str] = getattr(self.__class__, "action_require", []).copy() - - """NORMAL模式下的激活类型""" - self.activation_type = self.__class__.activation_type - """激活类型""" - self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0) - """当激活类型为RANDOM时的概率""" - self.activation_keywords: list[str] = getattr(self.__class__, "activation_keywords", []).copy() - """激活类型为KEYWORD时的KEYWORDS列表""" - self.keyword_case_sensitive: bool = getattr(self.__class__, "keyword_case_sensitive", False) - self.parallel_action: bool = getattr(self.__class__, "parallel_action", True) - self.associated_types: list[str] = getattr(self.__class__, "associated_types", []).copy() - - # ============================================================================= - # 便捷属性 - 直接在初始化时获取常用聊天信息(带类型注解) - # ============================================================================= - - # 获取聊天流对象 - self.chat_stream = chat_stream or kwargs.get("chat_stream") - self.chat_id = self.chat_stream.session_id - self.platform = getattr(self.chat_stream, "platform", None) - - # 初始化基础信息(带类型注解) - self.action_message = action_message - - self.group_id = None - self.group_name = None - self.user_id = None - self.user_nickname = None - self.is_group = False - self.target_id = None - - self.group_id = ( - str(self.action_message.chat_info.group_info.group_id) if self.action_message.chat_info.group_info else None - ) - self.group_name = ( - self.action_message.chat_info.group_info.group_name if self.action_message.chat_info.group_info else None - ) - - self.user_id = str(self.action_message.user_info.user_id) - self.user_nickname = self.action_message.user_info.user_nickname - - if self.group_id: - self.is_group = True - self.target_id = self.group_id - self.log_prefix = f"[{self.group_name}]" - else: - self.is_group = False - self.target_id = self.user_id - self.log_prefix = f"[{self.user_nickname} 的 私聊]" - - logger.debug( - f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}" - ) - - @abstractmethod - async def execute(self) -> Tuple[bool, str]: - """执行Action的抽象方法,子类必须实现 - - Returns: - Tuple[bool, str]: (是否执行成功, 回复文本) - """ - pass - - async def send_text( - self, - content: str, - set_reply: bool = False, - reply_message: Optional["DatabaseMessages"] = None, - typing: bool = False, - storage_message: bool = True, - ) -> bool: - """发送文本消息 - - Args: - content: 文本内容 - set_reply: 是否作为回复发送 - reply_message: 回复的消息对象(当set_reply为True时必填) - typing: 是否计算输入时间 - - Returns: - bool: 是否发送成功 - """ - if not self.chat_id: - logger.error(f"{self.log_prefix} 缺少聊天ID") - return False - - return await send_api.text_to_stream( - text=content, - stream_id=self.chat_id, - set_reply=set_reply, - reply_message=reply_message, - typing=typing, - storage_message=storage_message, - ) - - async def send_emoji( - self, - emoji_base64: str, - set_reply: bool = False, - reply_message: Optional["DatabaseMessages"] = None, - storage_message: bool = True, - ) -> bool: - """发送表情包 - - Args: - emoji_base64: 表情包的base64编码 - set_reply: 是否作为回复发送 - reply_message: 回复的消息对象(当set_reply为True时必填) - - Returns: - bool: 是否发送成功 - """ - if not self.chat_id: - logger.error(f"{self.log_prefix} 缺少聊天ID") - return False - - return await send_api.emoji_to_stream( - emoji_base64, - self.chat_id, - set_reply=set_reply, - reply_message=reply_message, - storage_message=storage_message, - ) - - async def send_image( - self, - image_base64: str, - set_reply: bool = False, - reply_message: Optional["DatabaseMessages"] = None, - storage_message: bool = True, - ) -> bool: - """发送图片 - - Args: - image_base64: 图片的base64编码 - set_reply: 是否作为回复发送 - reply_message: 回复的消息对象(当set_reply为True时必填) - - Returns: - bool: 是否发送成功 - """ - if not self.chat_id: - logger.error(f"{self.log_prefix} 缺少聊天ID") - return False - - return await send_api.image_to_stream( - image_base64, - self.chat_id, - set_reply=set_reply, - reply_message=reply_message, - storage_message=storage_message, - ) - - async def send_command( - self, - command_name: str, - args: Optional[dict] = None, - display_message: str = "", - storage_message: bool = True, - ) -> bool: - """发送命令消息 - - Args: - command_name: 命令名称 - args: 命令参数 - display_message: 显示消息 - storage_message: 是否存储消息到数据库 - - Returns: - bool: 是否发送成功 - """ - if not self.chat_id: - logger.error(f"{self.log_prefix} 缺少聊天ID") - return False - - # 构造命令数据 - command_data = {"name": command_name, "args": args or {}} - - return await send_api.command_to_stream( - command=command_data, - stream_id=self.chat_id, - storage_message=storage_message, - display_message=display_message, - ) - - async def send_custom( - self, - message_type: str, - content: str | Dict, - typing: bool = False, - set_reply: bool = False, - reply_message: Optional["DatabaseMessages"] = None, - storage_message: bool = True, - ) -> bool: - """发送自定义类型消息 - - Args: - message_type: 消息类型,如"video"、"file"、"audio"等 - content: 消息内容 - typing: 是否显示正在输入 - set_reply: 是否作为回复发送 - reply_message: 回复的消息对象(set_reply 为 True时必填) - storage_message: 是否存储消息到数据库 - - Returns: - bool: 是否发送成功 - """ - if not self.chat_id: - logger.error(f"{self.log_prefix} 缺少聊天ID") - return False - - return await send_api.custom_to_stream( - message_type=message_type, - content=content, - stream_id=self.chat_id, - typing=typing, - set_reply=set_reply, - reply_message=reply_message, - storage_message=storage_message, - ) - - async def send_hybrid( - self, - message_tuple_list: List[Tuple[ReplyContentType | str, str]], - typing: bool = False, - set_reply: bool = False, - reply_message: Optional["DatabaseMessages"] = None, - storage_message: bool = True, - ) -> bool: - """ - 发送混合类型消息 - - Args: - message_tuple_list: 包含消息类型和内容的元组列表,格式为 [(内容类型, 内容), ...] - typing: 是否计算打字时间 - set_reply: 是否作为回复发送 - reply_message: 回复的消息对象 - """ - if not self.chat_id: - logger.error(f"{self.log_prefix} 缺少聊天ID") - return False - reply_set = ReplySetModel() - reply_set.add_hybrid_content_by_raw(message_tuple_list) - return await send_api.custom_reply_set_to_stream( - reply_set=reply_set, - stream_id=self.chat_id, - typing=typing, - set_reply=set_reply, - reply_message=reply_message, - storage_message=storage_message, - ) - - async def send_forward( - self, - messages_list: List[Tuple[str, str, List[Tuple[ReplyContentType | str, str]]] | str], - storage_message: bool = True, - ) -> bool: - """转发消息 - - Args: - messages_list: 包含消息信息的列表,当传入自行生成的数据时,元素格式为 (sender_id, nickname, 消息体);当传入消息ID时,元素格式为 "message_id" - 其中消息体的格式为 [(内容类型, 内容), ...] - 任意长度的消息都需要使用列表的形式传入 - storage_message: 是否存储消息到数据库 - - Returns: - bool: 是否发送成功 - """ - if not self.chat_id: - logger.error(f"{self.log_prefix} 缺少聊天ID") - return False - reply_set = ReplySetModel() - forward_message_nodes: List[ForwardNode] = [] - for message in messages_list: - if isinstance(message, str): - forward_message_node = ForwardNode.construct_as_id_reference(message) - elif isinstance(message, Tuple) and len(message) == 3: - sender_id, nickname, content_list = message - single_node_content_list: List[ReplyContent] = [] - for node_content_type, node_content in content_list: - reply_node_content = ReplyContent(content_type=node_content_type, content=node_content) - single_node_content_list.append(reply_node_content) - forward_message_node = ForwardNode.construct_as_created_node( - user_id=sender_id, user_nickname=nickname, content=single_node_content_list - ) - else: - logger.warning(f"{self.log_prefix} 转发消息时遇到无效的消息格式: {message}") - continue - forward_message_nodes.append(forward_message_node) - reply_set.add_forward_content(forward_message_nodes) - return await send_api.custom_reply_set_to_stream( - reply_set=reply_set, - stream_id=self.chat_id, - storage_message=storage_message, - set_reply=False, - reply_message=None, - ) - - async def send_voice(self, audio_base64: str) -> bool: - """ - 发送语音消息 - Args: - audio_base64: 语音的base64编码 - Returns: - bool: 是否发送成功 - """ - if not audio_base64: - logger.error(f"{self.log_prefix} 缺少音频内容") - return False - reply_set = ReplySetModel() - reply_set.add_voice_content(audio_base64) - return await send_api.custom_reply_set_to_stream( - reply_set=reply_set, - stream_id=self.chat_id, - storage_message=False, - ) - - async def store_action_info( - self, - action_build_into_prompt: bool = False, - action_prompt_display: str = "", - action_done: bool = True, - ) -> None: - """存储动作信息到数据库 - - Args: - action_build_into_prompt: 是否构建到提示中 - action_prompt_display: 显示的action提示信息 - action_done: action是否完成 - """ - await database_api.store_action_info( - chat_stream=self.chat_stream, - action_build_into_prompt=action_build_into_prompt, - action_prompt_display=action_prompt_display, - action_done=action_done, - thinking_id=self.thinking_id, - action_data=self.action_data, - action_name=self.action_name, - action_reasoning=self.action_reasoning, - ) - - async def wait_for_new_message(self, timeout: int = 1200) -> Tuple[bool, str]: - """等待新消息或超时 - - 在loop_start_time之后等待新消息,如果没有新消息且没有超时,就一直等待。 - 使用message_api检查self.chat_id对应的聊天中是否有新消息。 - - Args: - timeout: 超时时间(秒),默认1200秒 - - Returns: - Tuple[bool, str]: (是否收到新消息, 空字符串) - """ - try: - # 获取循环开始时间,如果没有则使用当前时间 - loop_start_time = self.action_data.get("loop_start_time", time.time()) - logger.info(f"{self.log_prefix} 开始等待新消息... (最长等待: {timeout}秒, 从时间点: {loop_start_time})") - - # 确保有有效的chat_id - if not self.chat_id: - logger.error(f"{self.log_prefix} 等待新消息失败: 没有有效的chat_id") - return False, "没有有效的chat_id" - - wait_start_time = asyncio.get_event_loop().time() - while True: - # 检查新消息 - current_time = time.time() - new_message_count = message_api.count_new_messages( - chat_id=self.chat_id, start_time=loop_start_time, end_time=current_time - ) - - if new_message_count > 0: - logger.info(f"{self.log_prefix} 检测到{new_message_count}条新消息,聊天ID: {self.chat_id}") - return True, "" - - # 检查超时 - elapsed_time = asyncio.get_event_loop().time() - wait_start_time - if elapsed_time > timeout: - logger.warning(f"{self.log_prefix} 等待新消息超时({timeout}秒),聊天ID: {self.chat_id}") - return False, "" - - # 每30秒记录一次等待状态 - if int(elapsed_time) % 15 == 0 and int(elapsed_time) > 0: - logger.debug(f"{self.log_prefix} 已等待{int(elapsed_time)}秒,继续等待新消息...") - - # 短暂休眠 - await asyncio.sleep(0.5) - - except asyncio.CancelledError: - logger.info(f"{self.log_prefix} 等待新消息被中断 (CancelledError)") - return False, "" - except Exception as e: - logger.error(f"{self.log_prefix} 等待新消息时发生错误: {e}") - return False, f"等待新消息失败: {str(e)}" - - @classmethod - def get_action_info(cls) -> "ActionInfo": - """从类属性生成ActionInfo - - 所有信息都从类属性中读取,确保一致性和完整性。 - Action类必须定义所有必要的类属性。 - - Returns: - ActionInfo: 生成的Action信息对象 - """ - - # 从类属性读取名称,如果没有定义则使用类名自动生成 - name = getattr(cls, "action_name", cls.__name__.lower().replace("action", "")) - if "." in name: - logger.error(f"Action名称 '{name}' 包含非法字符 '.',请使用下划线替代") - raise ValueError(f"Action名称 '{name}' 包含非法字符 '.',请使用下划线替代") - # 获取focus_activation_type和normal_activation_type - focus_activation_type = getattr(cls, "focus_activation_type", ActionActivationType.ALWAYS) - _normal_activation_type = getattr(cls, "normal_activation_type", ActionActivationType.ALWAYS) - - # 处理activation_type:如果插件中声明了就用插件的值,否则默认使用focus_activation_type - activation_type = getattr(cls, "activation_type", focus_activation_type) - - return ActionInfo( - name=name, - component_type=ComponentType.ACTION, - description=getattr(cls, "action_description", "Action动作"), - activation_type=activation_type, - activation_keywords=getattr(cls, "activation_keywords", []).copy(), - keyword_case_sensitive=getattr(cls, "keyword_case_sensitive", False), - parallel_action=getattr(cls, "parallel_action", True), - random_activation_probability=getattr(cls, "random_activation_probability", 0.0), - # 使用正确的字段名 - action_parameters=getattr(cls, "action_parameters", {}).copy(), - action_require=getattr(cls, "action_require", []).copy(), - associated_types=getattr(cls, "associated_types", []).copy(), - ) - - def get_config(self, key: str, default=None): - """获取插件配置值,使用嵌套键访问 - - Args: - key: 配置键名,使用嵌套访问如 "section.subsection.key" - default: 默认值 - - Returns: - Any: 配置值或默认值 - """ - if not self.plugin_config: - return default - - # 支持嵌套键访问 - keys = key.split(".") - current = self.plugin_config - - for k in keys: - if isinstance(current, dict) and k in current: - current = current[k] - else: - return default - - return current diff --git a/src/plugin_system/base/base_command.py b/src/plugin_system/base/base_command.py deleted file mode 100644 index ffa1e46b..00000000 --- a/src/plugin_system/base/base_command.py +++ /dev/null @@ -1,387 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Dict, Tuple, Optional, TYPE_CHECKING, List -from src.common.logger import get_logger -from src.common.data_models.message_data_model import ReplyContentType, ReplyContent, ReplySetModel, ForwardNode -from src.plugin_system.base.component_types import CommandInfo, ComponentType -from src.chat.message_receive.message import SessionMessage -from src.plugin_system.apis import send_api - -if TYPE_CHECKING: - from src.common.data_models.database_data_model import DatabaseMessages - -logger = get_logger("base_command") - - -class BaseCommand(ABC): - """Command组件基类 - - Command是插件的一种组件类型,用于处理命令请求 - - 子类可以通过类属性定义命令模式: - - command_pattern: 命令匹配的正则表达式 - - command_help: 命令帮助信息 - - command_examples: 命令使用示例列表 - """ - - command_name: str = "" - """Command组件的名称""" - command_description: str = "" - """Command组件的描述""" - # 默认命令设置 - command_pattern: str = r"" - """命令匹配的正则表达式""" - - def __init__(self, message: SessionMessage, plugin_config: Optional[dict] = None): - """初始化Command组件 - - Args: - message: 接收到的消息对象 - plugin_config: 插件配置字典 - """ - self.message = message - self.matched_groups: Dict[str, str] = {} # 存储正则表达式匹配的命名组 - self.plugin_config = plugin_config or {} # 直接存储插件配置字典 - - self.log_prefix = "[Command]" - - logger.debug(f"{self.log_prefix} Command组件初始化完成") - - def set_matched_groups(self, groups: Dict[str, str]) -> None: - """设置正则表达式匹配的命名组 - - Args: - groups: 正则表达式匹配的命名组 - """ - self.matched_groups = groups - - @abstractmethod - async def execute(self) -> Tuple[bool, Optional[str], int]: - """执行Command的抽象方法,子类必须实现 - - Returns: - Tuple[bool, Optional[str], int]: (是否执行成功, 可选的回复消息, 拦截消息力度,0代表不拦截,1代表仅不触发回复,replyer可见,2代表不触发回复,replyer不可见) - """ - pass - - def get_config(self, key: str, default=None): - """获取插件配置值,使用嵌套键访问 - - Args: - key: 配置键名,使用嵌套访问如 "section.subsection.key" - default: 默认值 - - Returns: - Any: 配置值或默认值 - """ - if not self.plugin_config: - return default - - # 支持嵌套键访问 - keys = key.split(".") - current = self.plugin_config - - for k in keys: - if isinstance(current, dict) and k in current: - current = current[k] - else: - return default - - return current - - async def send_text( - self, - content: str, - set_reply: bool = False, - reply_message: Optional["DatabaseMessages"] = None, - storage_message: bool = True, - ) -> bool: - """发送回复消息 - - Args: - content: 回复内容 - set_reply: 是否作为回复发送 - reply_message: 回复的消息对象(当set_reply为True时必填) - storage_message: 是否存储消息到数据库 - - Returns: - bool: 是否发送成功 - """ - # 获取聊天流信息 - session_id = self.message.session_id - if not session_id: - logger.error(f"{self.log_prefix} 缺少session_id") - return False - - return await send_api.text_to_stream( - text=content, - stream_id=session_id, - set_reply=set_reply, - reply_message=reply_message, - storage_message=storage_message, - ) - - async def send_image( - self, - image_base64: str, - set_reply: bool = False, - reply_message: Optional["DatabaseMessages"] = None, - storage_message: bool = True, - ) -> bool: - """发送图片 - - Args: - image_base64: 图片的base64编码 - - Returns: - bool: 是否发送成功 - """ - session_id = self.message.session_id - if not session_id: - logger.error(f"{self.log_prefix} 缺少session_id") - return False - - return await send_api.image_to_stream( - image_base64, - session_id, - set_reply=set_reply, - reply_message=reply_message, - storage_message=storage_message, - ) - - async def send_emoji( - self, - emoji_base64: str, - set_reply: bool = False, - reply_message: Optional["DatabaseMessages"] = None, - storage_message: bool = True, - ) -> bool: - """发送表情包 - - Args: - emoji_base64: 表情包的base64编码 - set_reply: 是否作为回复发送 - reply_message: 回复的消息对象(当set_reply为True时必填) - storage_message: 是否存储消息到数据库 - - Returns: - bool: 是否发送成功 - """ - session_id = self.message.session_id - if not session_id: - logger.error(f"{self.log_prefix} 缺少session_id") - return False - - return await send_api.emoji_to_stream( - emoji_base64, session_id, set_reply=set_reply, reply_message=reply_message - ) - - async def send_command( - self, - command_name: str, - args: Optional[dict] = None, - display_message: str = "", - storage_message: bool = True, - ) -> bool: - """发送命令消息 - - Args: - command_name: 命令名称 - args: 命令参数 - display_message: 显示消息 - storage_message: 是否存储消息到数据库 - - Returns: - bool: 是否发送成功 - """ - try: - # 获取聊天流信息 - session_id = self.message.session_id - if not session_id: - logger.error(f"{self.log_prefix} 缺少session_id") - return False - - # 构造命令数据 - command_data = {"name": command_name, "args": args or {}} - - success = await send_api.command_to_stream( - command=command_data, - stream_id=session_id, - storage_message=storage_message, - display_message=display_message, - ) - - if success: - logger.info(f"{self.log_prefix} 成功发送命令: {command_name}") - else: - logger.error(f"{self.log_prefix} 发送命令失败: {command_name}") - - return success - - except Exception as e: - logger.error(f"{self.log_prefix} 发送命令时出错: {e}") - return False - - async def send_voice(self, voice_base64: str) -> bool: - """ - 发送语音消息 - Args: - voice_base64: 语音的base64编码 - Returns: - bool: 是否发送成功 - """ - session_id = self.message.session_id - if not session_id: - logger.error(f"{self.log_prefix} 缺少session_id") - return False - - return await send_api.custom_to_stream( - message_type="voice", - content=voice_base64, - stream_id=session_id, - typing=False, - set_reply=False, - reply_message=None, - storage_message=False, - ) - - async def send_hybrid( - self, - message_tuple_list: List[Tuple[ReplyContentType | str, str]], - typing: bool = False, - set_reply: bool = False, - reply_message: Optional["DatabaseMessages"] = None, - storage_message: bool = True, - ) -> bool: - """ - 发送混合类型消息 - - Args: - message_tuple_list: 包含消息类型和内容的元组列表,格式为 [(内容类型, 内容), ...] - typing: 是否显示正在输入 - set_reply: 是否计算打字时间 - reply_message: 回复的消息对象 - storage_message: 是否存储消息到数据库 - """ - session_id = self.message.session_id - if not session_id: - logger.error(f"{self.log_prefix} 缺少session_id") - return False - reply_set = ReplySetModel() - reply_set.add_hybrid_content_by_raw(message_tuple_list) - return await send_api.custom_reply_set_to_stream( - reply_set=reply_set, - stream_id=session_id, - typing=typing, - set_reply=set_reply, - reply_message=reply_message, - storage_message=storage_message, - ) - - async def send_forward( - self, - messages_list: List[Tuple[str, str, List[Tuple[ReplyContentType | str, str]]] | str], - storage_message: bool = True, - ) -> bool: - """转发消息 - - Args: - messages_list: 包含消息信息的列表,当传入自行生成的数据时,元素格式为 (sender_id, nickname, 消息体);当传入消息ID时,元素格式为 "message_id" - 其中消息体的格式为 [(内容类型, 内容), ...] - 任意长度的消息都需要使用列表的形式传入 - storage_message: 是否存储消息到数据库 - - Returns: - bool: 是否发送成功 - """ - session_id = self.message.session_id - if not session_id: - logger.error(f"{self.log_prefix} 缺少session_id") - return False - reply_set = ReplySetModel() - forward_message_nodes: List[ForwardNode] = [] - for message in messages_list: - if isinstance(message, str): - forward_message_node = ForwardNode.construct_as_id_reference(message) - elif isinstance(message, Tuple) and len(message) == 3: - sender_id, nickname, content_list = message - single_node_content_list: List[ReplyContent] = [] - for node_content_type, node_content in content_list: - reply_node_content = ReplyContent(content_type=node_content_type, content=node_content) - single_node_content_list.append(reply_node_content) - forward_message_node = ForwardNode.construct_as_created_node( - user_id=sender_id, user_nickname=nickname, content=single_node_content_list - ) - else: - logger.warning(f"{self.log_prefix} 转发消息时遇到无效的消息格式: {message}") - continue - forward_message_nodes.append(forward_message_node) - reply_set.add_forward_content(forward_message_nodes) - return await send_api.custom_reply_set_to_stream( - reply_set=reply_set, - stream_id=session_id, - storage_message=storage_message, - set_reply=False, - reply_message=None, - ) - - async def send_custom( - self, - message_type: str, - content: str | Dict, - display_message: str = "", - typing: bool = False, - set_reply: bool = False, - reply_message: Optional["DatabaseMessages"] = None, - storage_message: bool = True, - ) -> bool: - """发送指定类型的回复消息到当前聊天环境 - - Args: - message_type: 消息类型,如"text"、"image"、"emoji"、"voice"等 - content: 消息内容 - display_message: 显示消息(可选) - typing: 是否显示正在输入 - set_reply: 是否作为回复发送 - reply_message: 回复的消息对象(set_reply 为 True时必填) - storage_message: 是否存储消息到数据库 - - Returns: - bool: 是否发送成功 - """ - # 获取聊天流信息 - session_id = self.message.session_id - if not session_id: - logger.error(f"{self.log_prefix} 缺少session_id") - return False - - return await send_api.custom_to_stream( - message_type=message_type, - content=content, - stream_id=session_id, - display_message=display_message, - typing=typing, - set_reply=set_reply, - reply_message=reply_message, - storage_message=storage_message, - ) - - @classmethod - def get_command_info(cls) -> "CommandInfo": - """从类属性生成CommandInfo - - Args: - name: Command名称,如果不提供则使用类名 - description: Command描述,如果不提供则使用类文档字符串 - - Returns: - CommandInfo: 生成的Command信息对象 - """ - if "." in cls.command_name: - logger.error(f"Command名称 '{cls.command_name}' 包含非法字符 '.',请使用下划线替代") - raise ValueError(f"Command名称 '{cls.command_name}' 包含非法字符 '.',请使用下划线替代") - return CommandInfo( - name=cls.command_name, - component_type=ComponentType.COMMAND, - description=cls.command_description, - command_pattern=cls.command_pattern, - ) diff --git a/src/plugin_system/base/base_events_handler.py b/src/plugin_system/base/base_events_handler.py deleted file mode 100644 index d31af6f4..00000000 --- a/src/plugin_system/base/base_events_handler.py +++ /dev/null @@ -1,381 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Tuple, Optional, Dict, List, TYPE_CHECKING - -from src.common.logger import get_logger -from src.common.data_models.message_data_model import ReplyContentType, ReplySetModel, ReplyContent, ForwardNode -from src.plugin_system.apis import send_api -from .component_types import MaiMessages, EventType, EventHandlerInfo, ComponentType, CustomEventHandlerResult - -logger = get_logger("base_event_handler") - -if TYPE_CHECKING: - from src.common.data_models.database_data_model import DatabaseMessages - - -class BaseEventHandler(ABC): - """事件处理器基类 - - 所有事件处理器都应该继承这个基类,提供事件处理的基本接口 - """ - - event_type: EventType | str = EventType.UNKNOWN - """事件类型,默认为未知""" - handler_name: str = "" - """处理器名称""" - handler_description: str = "" - """处理器描述""" - weight: int = 0 - """处理器权重,越大权重越高""" - intercept_message: bool = False - """是否拦截消息,默认为否""" - - def __init__(self): - self.log_prefix = "[EventHandler]" - self.plugin_name = "" - """对应插件名""" - self.plugin_config: Optional[Dict] = None - """插件配置字典""" - if self.event_type == EventType.UNKNOWN: - raise NotImplementedError("事件处理器必须指定 event_type") - - @abstractmethod - async def execute( - self, message: MaiMessages | None - ) -> Tuple[bool, bool, Optional[str], Optional[CustomEventHandlerResult], Optional[MaiMessages]]: - """执行事件处理的抽象方法,子类必须实现 - Args: - message (MaiMessages | None): 事件消息对象,当你注册的事件为ON_START和ON_STOP时message为None - Returns: - Tuple[bool, bool, Optional[str], Optional[CustomEventHandlerResult], Optional[MaiMessages]]: (是否执行成功, 是否需要继续处理, 可选的返回消息, 可选的自定义结果,可选的修改后消息) - """ - raise NotImplementedError("子类必须实现 execute 方法") - - @classmethod - def get_handler_info(cls) -> "EventHandlerInfo": - """获取事件处理器的信息""" - # 从类属性读取名称,如果没有定义则使用类名自动生成S - name: str = getattr(cls, "handler_name", cls.__name__.lower().replace("handler", "")) - if "." in name: - logger.error(f"事件处理器名称 '{name}' 包含非法字符 '.',请使用下划线替代") - raise ValueError(f"事件处理器名称 '{name}' 包含非法字符 '.',请使用下划线替代") - return EventHandlerInfo( - name=name, - component_type=ComponentType.EVENT_HANDLER, - description=getattr(cls, "handler_description", "events处理器"), - event_type=cls.event_type, - weight=cls.weight, - intercept_message=cls.intercept_message, - ) - - def set_plugin_config(self, plugin_config: Dict) -> None: - """设置插件配置 - - Args: - plugin_config (dict): 插件配置字典 - """ - self.plugin_config = plugin_config - - def set_plugin_name(self, plugin_name: str) -> None: - """设置插件名称 - - Args: - plugin_name (str): 插件名称 - """ - self.plugin_name = plugin_name - - def get_config(self, key: str, default=None): - """获取插件配置值,支持嵌套键访问 - - Args: - key: 配置键名,支持嵌套访问如 "section.subsection.key" - default: 默认值 - - Returns: - Any: 配置值或默认值 - """ - if not self.plugin_config: - return default - - # 支持嵌套键访问 - keys = key.split(".") - current = self.plugin_config - - for k in keys: - if isinstance(current, dict) and k in current: - current = current[k] - else: - return default - - return current - - async def send_text( - self, - stream_id: str, - text: str, - set_reply: bool = False, - reply_message: Optional["DatabaseMessages"] = None, - typing: bool = False, - storage_message: bool = True, - ) -> bool: - """发送文本消息 - - Args: - stream_id: 聊天ID - text: 文本内容 - set_reply: 是否作为回复发送 - reply_message: 回复的消息对象(当set_reply为True时必填) - typing: 是否计算输入时间 - storage_message: 是否存储消息到数据库 - - Returns: - bool: 是否发送成功 - """ - if not stream_id: - logger.error(f"{self.log_prefix} 缺少聊天ID") - return False - return await send_api.text_to_stream( - text=text, - stream_id=stream_id, - set_reply=set_reply, - reply_message=reply_message, - typing=typing, - storage_message=storage_message, - ) - - async def send_emoji( - self, - stream_id: str, - emoji_base64: str, - set_reply: bool = False, - reply_message: Optional["DatabaseMessages"] = None, - storage_message: bool = True, - ) -> bool: - """发送表情消息 - - Args: - emoji_base64: 表情的Base64编码 - stream_id: 聊天ID - set_reply: 是否作为回复发送 - reply_message: 回复的消息对象(当set_reply为True时必填) - storage_message: 是否存储消息到数据库 - - Returns: - bool: 是否发送成功 - """ - if not stream_id: - logger.error(f"{self.log_prefix} 缺少聊天ID") - return False - return await send_api.emoji_to_stream( - emoji_base64=emoji_base64, - stream_id=stream_id, - set_reply=set_reply, - reply_message=reply_message, - storage_message=storage_message, - ) - - async def send_image( - self, - stream_id: str, - image_base64: str, - set_reply: bool = False, - reply_message: Optional["DatabaseMessages"] = None, - storage_message: bool = True, - ) -> bool: - """发送图片消息 - - Args: - image_base64: 图片的Base64编码 - stream_id: 聊天ID - set_reply: 是否作为回复发送 - reply_message: 回复的消息对象(当set_reply为True时必填) - storage_message: 是否存储消息到数据库 - - Returns: - bool: 是否发送成功 - """ - if not stream_id: - logger.error(f"{self.log_prefix} 缺少聊天ID") - return False - return await send_api.image_to_stream( - image_base64=image_base64, - stream_id=stream_id, - set_reply=set_reply, - reply_message=reply_message, - storage_message=storage_message, - ) - - async def send_voice( - self, - stream_id: str, - audio_base64: str, - ) -> bool: - """发送语音消息 - Args: - stream_id: 聊天ID - audio_base64: 语音的Base64编码 - Returns: - bool: 是否发送成功 - """ - if not stream_id: - logger.error(f"{self.log_prefix} 缺少聊天ID") - return False - reply_set = ReplySetModel() - reply_set.add_voice_content(audio_base64) - return await send_api.custom_reply_set_to_stream( - reply_set=reply_set, - stream_id=stream_id, - storage_message=False, - ) - - async def send_command( - self, - stream_id: str, - command_name: str, - command_args: Optional[dict] = None, - display_message: str = "", - storage_message: bool = True, - ) -> bool: - """发送命令消息 - - Args: - stream_id: 流ID - command_name: 命令名称 - command_args: 命令参数字典 - display_message: 显示消息 - storage_message: 是否存储消息到数据库 - - Returns: - bool: 是否发送成功 - """ - if not stream_id: - logger.error(f"{self.log_prefix} 缺少聊天ID") - return False - - # 构造命令数据 - command_data = {"name": command_name, "args": command_args or {}} - - return await send_api.command_to_stream( - command=command_data, - stream_id=stream_id, - storage_message=storage_message, - display_message=display_message, - ) - - async def send_custom( - self, - stream_id: str, - message_type: str, - content: str | Dict, - typing: bool = False, - set_reply: bool = False, - reply_message: Optional["DatabaseMessages"] = None, - storage_message: bool = True, - ) -> bool: - """发送自定义消息 - - Args: - stream_id: 聊天ID - message_type: 消息类型 - content: 消息内容,可以是字符串或字典 - typing: 是否显示正在输入状态 - set_reply: 是否作为回复发送 - reply_message: 回复的消息对象(当set_reply为True时必填) - storage_message: 是否存储消息到数据库 - - Returns: - bool: 是否发送成功 - """ - if not stream_id: - logger.error(f"{self.log_prefix} 缺少聊天ID") - return False - return await send_api.custom_to_stream( - message_type=message_type, - content=content, - stream_id=stream_id, - typing=typing, - set_reply=set_reply, - reply_message=reply_message, - storage_message=storage_message, - ) - - async def send_hybrid( - self, - stream_id: str, - message_tuple_list: List[Tuple[ReplyContentType | str, str]], - typing: bool = False, - set_reply: bool = False, - reply_message: Optional["DatabaseMessages"] = None, - storage_message: bool = True, - ) -> bool: - """ - 发送混合类型消息 - - Args: - stream_id: 流ID - message_tuple_list: 包含消息类型和内容的元组列表,格式为 [(内容类型, 内容), ...] - typing: 是否计算打字时间 - set_reply: 是否作为回复发送 - reply_message: 回复的消息对象 - storage_message: 是否存储消息到数据库 - """ - if not stream_id: - logger.error(f"{self.log_prefix} 缺少聊天ID") - return False - reply_set = ReplySetModel() - reply_set.add_hybrid_content_by_raw(message_tuple_list) - return await send_api.custom_reply_set_to_stream( - reply_set=reply_set, - stream_id=stream_id, - typing=typing, - set_reply=set_reply, - reply_message=reply_message, - storage_message=storage_message, - ) - - async def send_forward( - self, - stream_id: str, - messages_list: List[Tuple[str, str, List[Tuple[ReplyContentType | str, str]]] | str], - storage_message: bool = True, - ) -> bool: - """转发消息 - - Args: - stream_id: 聊天ID - messages_list: 包含消息信息的列表,当传入自行生成的数据时,元素格式为 (sender_id, nickname, 消息体);当传入消息ID时,元素格式为 "message_id" - 其中消息体的格式为 [(内容类型, 内容), ...] - 任意长度的消息都需要使用列表的形式传入 - storage_message: 是否存储消息到数据库 - - Returns: - bool: 是否发送成功 - """ - if not stream_id: - logger.error(f"{self.log_prefix} 缺少聊天ID") - return False - reply_set = ReplySetModel() - forward_message_nodes: List[ForwardNode] = [] - for message in messages_list: - if isinstance(message, str): - forward_message_node = ForwardNode.construct_as_id_reference(message) - elif isinstance(message, Tuple) and len(message) == 3: - sender_id, nickname, content_list = message - single_node_content_list: List[ReplyContent] = [] - for node_content_type, node_content in content_list: - reply_node_content = ReplyContent(content_type=node_content_type, content=node_content) - single_node_content_list.append(reply_node_content) - forward_message_node = ForwardNode.construct_as_created_node( - user_id=sender_id, user_nickname=nickname, content=single_node_content_list - ) - else: - logger.warning(f"{self.log_prefix} 转发消息时遇到无效的消息格式: {message}") - continue - forward_message_nodes.append(forward_message_node) - reply_set.add_forward_content(forward_message_nodes) - return await send_api.custom_reply_set_to_stream( - reply_set=reply_set, - stream_id=stream_id, - storage_message=storage_message, - set_reply=False, - reply_message=None, - ) diff --git a/src/plugin_system/base/base_plugin.py b/src/plugin_system/base/base_plugin.py deleted file mode 100644 index 6b278ef4..00000000 --- a/src/plugin_system/base/base_plugin.py +++ /dev/null @@ -1,102 +0,0 @@ -from abc import abstractmethod -from typing import Any, Callable, List, Type, Tuple, Union -from .plugin_base import PluginBase - -from src.common.logger import get_logger -from src.plugin_system.base.component_types import ActionInfo, CommandInfo, EventHandlerInfo, ToolInfo -from src.plugin_system.base.workflow_types import WorkflowStepInfo -from .base_action import BaseAction -from .base_command import BaseCommand -from .base_events_handler import BaseEventHandler -from .base_tool import BaseTool - -logger = get_logger("base_plugin") - - -class BasePlugin(PluginBase): - """基于Action和Command的插件基类 - - 所有上述类型的插件都应该继承这个基类,一个插件可以包含多种组件: - - Action组件:处理聊天中的动作 - - Command组件:处理命令请求 - - 未来可扩展:Scheduler、Listener等 - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - @abstractmethod - def get_plugin_components( - self, - ) -> List[ - Union[ - Tuple[ActionInfo, Type[BaseAction]], - Tuple[CommandInfo, Type[BaseCommand]], - Tuple[EventHandlerInfo, Type[BaseEventHandler]], - Tuple[ToolInfo, Type[BaseTool]], - ] - ]: - """获取插件包含的组件列表 - - 子类必须实现此方法,返回组件信息和组件类的列表 - - Returns: - List[tuple[ComponentInfo, Type]]: [(组件信息, 组件类), ...] - """ - raise NotImplementedError("Subclasses must implement this method") - - def get_workflow_steps(self) -> List[Tuple[WorkflowStepInfo, Callable[..., Any]]]: - """获取插件包含的workflow steps。 - - 默认返回空列表,子类可按需覆盖。 - """ - return [] - - def register_plugin(self) -> bool: - """注册插件及其所有组件""" - from src.plugin_system.core.component_registry import component_registry - - components = self.get_plugin_components() - - # 检查依赖 - if not self._check_dependencies(): - logger.error(f"{self.log_prefix} 依赖检查失败,跳过注册") - return False - - # 注册所有组件 - registered_components = [] - for component_info, component_class in components: - component_info.plugin_name = self.plugin_name - if component_registry.register_component(component_info, component_class): - registered_components.append(component_info) - else: - logger.warning(f"{self.log_prefix} 组件 {component_info.name} 注册失败") - - # 更新插件信息中的组件列表 - self.plugin_info.components = registered_components - - # 注册插件 - if component_registry.register_plugin(self.plugin_info): - # 注册workflow steps(可选) - registered_step_count = 0 - for step_info, step_handler in self.get_workflow_steps(): - if not step_info.plugin_name: - step_info.plugin_name = self.plugin_name - elif step_info.plugin_name != self.plugin_name: - logger.warning( - f"{self.log_prefix} workflow step {step_info.name} 的plugin_name({step_info.plugin_name})与当前插件不一致,已覆盖为 {self.plugin_name}" - ) - step_info.plugin_name = self.plugin_name - - if component_registry.register_workflow_step(step_info, step_handler): - registered_step_count += 1 - else: - logger.warning(f"{self.log_prefix} workflow step {step_info.full_name} 注册失败") - - logger.debug(f"{self.log_prefix} 插件注册成功,包含 {len(registered_components)} 个组件") - if registered_step_count > 0: - logger.debug(f"{self.log_prefix} workflow steps 注册成功,数量: {registered_step_count}") - return True - else: - logger.error(f"{self.log_prefix} 插件注册失败") - return False diff --git a/src/plugin_system/base/base_tool.py b/src/plugin_system/base/base_tool.py deleted file mode 100644 index 3938027e..00000000 --- a/src/plugin_system/base/base_tool.py +++ /dev/null @@ -1,137 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any, List, Optional, Tuple, TYPE_CHECKING -from rich.traceback import install - -from src.common.logger import get_logger -from src.plugin_system.base.component_types import ComponentType, ToolInfo, ToolParamType - -if TYPE_CHECKING: - from src.chat.message_receive.chat_manager import BotChatSession - -install(extra_lines=3) - -logger = get_logger("base_tool") - - -class BaseTool(ABC): - """所有工具的基类""" - - name: str = "" - """工具的名称""" - description: str = "" - """工具的描述""" - parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = [] - """工具的参数定义,为[("param_name", param_type, "description", required, enum_values)]格式 - param_name: 参数名称 - param_type: 参数类型 - description: 参数描述 - required: 是否必填 - enum_values: 枚举值列表 - 例如: [("arg1", ToolParamType.STRING, "参数1描述", True, None), ("arg2", ToolParamType.INTEGER, "参数2描述", False, ["1", "2", "3"])] - """ - available_for_llm: bool = False - """是否可供LLM使用""" - - def __init__(self, plugin_config: Optional[dict] = None, chat_stream: Optional["BotChatSession"] = None): - """初始化工具基类 - - Args: - plugin_config: 插件配置字典 - chat_stream: 聊天流对象,用于获取聊天上下文信息 - """ - self.plugin_config = plugin_config or {} # 直接存储插件配置字典 - - # ============================================================================= - # 便捷属性 - 直接在初始化时获取常用聊天信息(与BaseAction保持一致) - # ============================================================================= - - # 获取聊天流对象 - self.chat_stream = chat_stream - self.chat_id = self.chat_stream.session_id if self.chat_stream else None - self.platform = getattr(self.chat_stream, "platform", None) if self.chat_stream else None - - @classmethod - def get_tool_definition(cls) -> dict[str, Any]: - """获取工具定义,用于LLM工具调用 - - Returns: - dict: 工具定义字典 - """ - if not cls.name or not cls.description or cls.parameters is None: - raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性") - - return {"name": cls.name, "description": cls.description, "parameters": cls.parameters} - - @classmethod - def get_tool_info(cls) -> ToolInfo: - """获取工具信息""" - if not cls.name or not cls.description or cls.parameters is None: - raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性") - - return ToolInfo( - name=cls.name, - tool_description=cls.description, - enabled=cls.available_for_llm, - tool_parameters=cls.parameters, - component_type=ComponentType.TOOL, - ) - - @abstractmethod - async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: - """执行工具函数(供llm调用) - 通过该方法,maicore会通过llm的tool call来调用工具 - 传入的是json格式的参数,符合parameters定义的格式 - - Args: - function_args: 工具调用参数 - - Returns: - dict: 工具执行结果 - """ - raise NotImplementedError("子类必须实现execute方法") - - async def direct_execute(self, **function_args: dict[str, Any]) -> dict[str, Any]: - """直接执行工具函数(供插件调用) - 通过该方法,插件可以直接调用工具,而不需要传入字典格式的参数 - 插件可以直接调用此方法,用更加明了的方式传入参数 - 示例: result = await tool.direct_execute(arg1="参数",arg2="参数2") - - 工具开发者可以重写此方法以实现与llm调用差异化的执行逻辑 - - Args: - **function_args: 工具调用参数 - - Returns: - dict: 工具执行结果 - """ - parameter_required = [param[0] for param in self.parameters if param[3]] # 获取所有必填参数名 - for param_name in parameter_required: - if param_name not in function_args: - raise ValueError(f"工具类 {self.__class__.__name__} 缺少必要参数: {param_name}") - - return await self.execute(function_args) - - def get_config(self, key: str, default=None): - """获取插件配置值,使用嵌套键访问 - - Args: - key: 配置键名,使用嵌套访问如 "section.subsection.key" - default: 默认值 - - Returns: - Any: 配置值或默认值 - """ - if not self.plugin_config: - return default - - # 支持嵌套键访问 - keys = key.split(".") - current = self.plugin_config - - for k in keys: - if isinstance(current, dict) and k in current: - current = current[k] - else: - return default - - return current diff --git a/src/plugin_system/base/plugin_base.py b/src/plugin_system/base/plugin_base.py deleted file mode 100644 index 7076af4d..00000000 --- a/src/plugin_system/base/plugin_base.py +++ /dev/null @@ -1,761 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Dict, List, Any, Union -import os -import inspect -import toml -import json -import shutil -import datetime -import re - -from src.common.logger import get_logger -from src.plugin_system.base.component_types import ( - PluginInfo, - PythonDependency, -) -from src.plugin_system.base.config_types import ( - ConfigField, - ConfigSection, - ConfigLayout, -) -from src.plugin_system.utils.manifest_utils import ManifestValidator, VersionComparator - -logger = get_logger("plugin_base") - - -class PluginBase(ABC): - """插件总基类 - - 所有衍生插件基类都应该继承自此类,这个类定义了插件的基本结构和行为。 - """ - - # 插件基本信息(子类必须定义) - @property - @abstractmethod - def plugin_name(self) -> str: - return "" # 插件内部标识符(如 "hello_world_plugin") - - @property - @abstractmethod - def enable_plugin(self) -> bool: - return True # 是否启用插件 - - @property - @abstractmethod - def dependencies(self) -> List[Union[str, Dict[str, Any]]]: - return [] # 依赖的其他插件 - - @property - @abstractmethod - def python_dependencies(self) -> List[PythonDependency]: - return [] # Python包依赖 - - @property - @abstractmethod - def config_file_name(self) -> str: - return "" # 配置文件名 - - # manifest文件相关 - manifest_file_name: str = "_manifest.json" # manifest文件名 - manifest_data: Dict[str, Any] = {} # manifest数据 - - # 配置定义 - @property - @abstractmethod - def config_schema(self) -> Dict[str, Union[Dict[str, ConfigField], str]]: - return {} - - config_section_descriptions: Dict[str, Union[str, ConfigSection]] = {} - - # 布局配置(可选,不定义则使用自动布局) - config_layout: ConfigLayout = None - - def __init__(self, plugin_dir: str): - """初始化插件 - - Args: - plugin_dir: 插件目录路径,由插件管理器传递 - """ - self.config: Dict[str, Any] = {} # 插件配置 - self.plugin_dir = plugin_dir # 插件目录路径 - self.log_prefix = f"[Plugin:{self.plugin_name}]" - - # 加载manifest文件 - self._load_manifest() - - # 验证插件信息 - self._validate_plugin_info() - - # 加载插件配置 - self._load_plugin_config() - - # 从manifest获取显示信息 - self.display_name = self.get_manifest_info("name", self.plugin_name) - self.plugin_version = self.get_manifest_info("version", "1.0.0") - self.plugin_description = self.get_manifest_info("description", "") - self.plugin_author = self._get_author_name() - - # 创建插件信息对象 - self.plugin_info = PluginInfo( - name=self.plugin_name, - display_name=self.display_name, - description=self.plugin_description, - version=self.plugin_version, - author=self.plugin_author, - enabled=self.enable_plugin, - is_built_in=False, - config_file=self.config_file_name or "", - dependencies=self._get_dependency_names(), - python_dependencies=self.python_dependencies.copy(), - # manifest相关信息 - manifest_data=self.manifest_data.copy(), - license=self.get_manifest_info("license", ""), - homepage_url=self.get_manifest_info("homepage_url", ""), - repository_url=self.get_manifest_info("repository_url", ""), - keywords=self.get_manifest_info("keywords", []).copy() if self.get_manifest_info("keywords") else [], - categories=self.get_manifest_info("categories", []).copy() if self.get_manifest_info("categories") else [], - min_host_version=self.get_manifest_info("host_application.min_version", ""), - max_host_version=self.get_manifest_info("host_application.max_version", ""), - ) - - logger.debug(f"{self.log_prefix} 插件基类初始化完成") - - def _validate_plugin_info(self): - """验证插件基本信息""" - if not self.plugin_name: - raise ValueError(f"插件类 {self.__class__.__name__} 必须定义 plugin_name") - - # 验证manifest中的必需信息 - if not self.get_manifest_info("name"): - raise ValueError(f"插件 {self.plugin_name} 的manifest中缺少name字段") - if not self.get_manifest_info("description"): - raise ValueError(f"插件 {self.plugin_name} 的manifest中缺少description字段") - - def _load_manifest(self): # sourcery skip: raise-from-previous-error - """加载manifest文件(强制要求)""" - if not self.plugin_dir: - raise ValueError(f"{self.log_prefix} 没有插件目录路径,无法加载manifest") - - manifest_path = os.path.join(self.plugin_dir, self.manifest_file_name) - - if not os.path.exists(manifest_path): - error_msg = f"{self.log_prefix} 缺少必需的manifest文件: {manifest_path}" - logger.error(error_msg) - raise FileNotFoundError(error_msg) - - try: - with open(manifest_path, "r", encoding="utf-8") as f: - self.manifest_data = json.load(f) - - logger.debug(f"{self.log_prefix} 成功加载manifest文件: {manifest_path}") - - # 验证manifest格式 - self._validate_manifest() - - except json.JSONDecodeError as e: - error_msg = f"{self.log_prefix} manifest文件格式错误: {e}" - logger.error(error_msg) - raise ValueError(error_msg) # noqa - except IOError as e: - error_msg = f"{self.log_prefix} 读取manifest文件失败: {e}" - logger.error(error_msg) - raise IOError(error_msg) # noqa - - def _get_author_name(self) -> str: - """从manifest获取作者名称""" - author_info = self.get_manifest_info("author", {}) - if isinstance(author_info, dict): - return author_info.get("name", "") - else: - return str(author_info) if author_info else "" - - def _validate_manifest(self): - """验证manifest文件格式(使用强化的验证器)""" - if not self.manifest_data: - raise ValueError(f"{self.log_prefix} manifest数据为空,验证失败") - - validator = ManifestValidator() - is_valid = validator.validate_manifest(self.manifest_data) - - # 记录验证结果 - if validator.validation_errors or validator.validation_warnings: - report = validator.get_validation_report() - logger.info(f"{self.log_prefix} Manifest验证结果:\n{report}") - - # 如果有验证错误,抛出异常 - if not is_valid: - error_msg = f"{self.log_prefix} Manifest文件验证失败" - if validator.validation_errors: - error_msg += f": {'; '.join(validator.validation_errors)}" - raise ValueError(error_msg) - - def get_manifest_info(self, key: str, default: Any = None) -> Any: - """获取manifest信息 - - Args: - key: 信息键,支持点分割的嵌套键(如 "author.name") - default: 默认值 - - Returns: - Any: 对应的值 - """ - if not self.manifest_data: - return default - - keys = key.split(".") - value = self.manifest_data - - for k in keys: - if isinstance(value, dict) and k in value: - value = value[k] - else: - return default - - return value - - def _format_toml_value(self, value: Any) -> str: - """将Python值格式化为合法的TOML字符串""" - if isinstance(value, str): - return json.dumps(value, ensure_ascii=False) - if isinstance(value, bool): - return str(value).lower() - if isinstance(value, (int, float)): - return str(value) - if isinstance(value, list): - inner = ", ".join(self._format_toml_value(item) for item in value) - return f"[{inner}]" - if isinstance(value, dict): - items = [f"{k} = {self._format_toml_value(v)}" for k, v in value.items()] - return "{ " + ", ".join(items) + " }" - return json.dumps(value, ensure_ascii=False) - - def _generate_and_save_default_config(self, config_file_path: str): - """根据插件的Schema生成并保存默认配置文件""" - if not self.config_schema: - logger.debug(f"{self.log_prefix} 插件未定义config_schema,不生成配置文件") - return - - toml_str = f"# {self.plugin_name} - 自动生成的配置文件\n" - plugin_description = self.get_manifest_info("description", "插件配置文件") - toml_str += f"# {plugin_description}\n\n" - - # 遍历每个配置节 - for section, fields in self.config_schema.items(): - # 添加节描述 - if section in self.config_section_descriptions: - toml_str += f"# {self.config_section_descriptions[section]}\n" - - toml_str += f"[{section}]\n\n" - - # 遍历节内的字段 - if isinstance(fields, dict): - for field_name, field in fields.items(): - if isinstance(field, ConfigField): - # 添加字段描述 - toml_str += f"# {field.description}" - if field.required: - toml_str += " (必需)" - toml_str += "\n" - - # 如果有示例值,添加示例 - if field.example: - toml_str += f"# 示例: {field.example}\n" - - # 如果有可选值,添加说明 - if field.choices: - choices_str = ", ".join(map(str, field.choices)) - toml_str += f"# 可选值: {choices_str}\n" - - # 添加字段值 - value = field.default - toml_str += f"{field_name} = {self._format_toml_value(value)}\n" - - toml_str += "\n" - toml_str += "\n" - - try: - with open(config_file_path, "w", encoding="utf-8") as f: - f.write(toml_str) - logger.info(f"{self.log_prefix} 已生成默认配置文件: {config_file_path}") - except IOError as e: - logger.error(f"{self.log_prefix} 保存默认配置文件失败: {e}", exc_info=True) - - def _get_expected_config_version(self) -> str: - """获取插件期望的配置版本号""" - # 从config_schema的plugin.config_version字段获取 - if "plugin" in self.config_schema and isinstance(self.config_schema["plugin"], dict): - config_version_field = self.config_schema["plugin"].get("config_version") - if isinstance(config_version_field, ConfigField): - return config_version_field.default - return "1.0.0" - - def _get_current_config_version(self, config: Dict[str, Any]) -> str: - """从配置文件中获取当前版本号""" - if "plugin" in config and "config_version" in config["plugin"]: - return str(config["plugin"]["config_version"]) - # 如果没有config_version字段,视为最早的版本 - return "0.0.0" - - def _backup_config_file(self, config_file_path: str) -> str: - """备份配置文件""" - timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - backup_path = f"{config_file_path}.backup_{timestamp}" - - try: - shutil.copy2(config_file_path, backup_path) - logger.info(f"{self.log_prefix} 配置文件已备份到: {backup_path}") - return backup_path - except Exception as e: - logger.error(f"{self.log_prefix} 备份配置文件失败: {e}") - return "" - - def _migrate_config_values(self, old_config: Dict[str, Any], new_config: Dict[str, Any]) -> Dict[str, Any]: - """将旧配置值迁移到新配置结构中 - - Args: - old_config: 旧配置数据 - new_config: 基于新schema生成的默认配置 - - Returns: - Dict[str, Any]: 迁移后的配置 - """ - - def migrate_section( - old_section: Dict[str, Any], new_section: Dict[str, Any], section_name: str - ) -> Dict[str, Any]: - """迁移单个配置节""" - result = new_section.copy() - - for key, value in old_section.items(): - if key in new_section: - # 特殊处理:config_version字段总是使用新版本 - if section_name == "plugin" and key == "config_version": - # 保持新的版本号,不迁移旧值 - logger.debug( - f"{self.log_prefix} 更新配置版本: {section_name}.{key} = {result[key]} (旧值: {value})" - ) - continue - - # 键存在于新配置中,复制值 - if isinstance(value, dict) and isinstance(new_section[key], dict): - # 递归处理嵌套字典 - result[key] = migrate_section(value, new_section[key], f"{section_name}.{key}") - else: - result[key] = value - logger.debug(f"{self.log_prefix} 迁移配置: {section_name}.{key} = {value}") - else: - # 键在新配置中不存在,记录警告 - logger.warning(f"{self.log_prefix} 配置项 {section_name}.{key} 在新版本中已被移除") - - return result - - migrated_config = {} - - # 迁移每个配置节 - for section_name, new_section_data in new_config.items(): - if ( - section_name in old_config - and isinstance(old_config[section_name], dict) - and isinstance(new_section_data, dict) - ): - migrated_config[section_name] = migrate_section( - old_config[section_name], new_section_data, section_name - ) - else: - # 新增的节或类型不匹配,使用默认值 - migrated_config[section_name] = new_section_data - if section_name in old_config: - logger.warning(f"{self.log_prefix} 配置节 {section_name} 结构已改变,使用默认值") - - # 检查旧配置中是否有新配置没有的节 - for section_name in old_config: - if section_name not in migrated_config: - logger.warning(f"{self.log_prefix} 配置节 {section_name} 在新版本中已被移除") - - return migrated_config - - def _generate_config_from_schema(self) -> Dict[str, Any]: - # sourcery skip: dict-comprehension - """根据schema生成配置数据结构(不写入文件)""" - if not self.config_schema: - return {} - - config_data = {} - - # 遍历每个配置节 - for section, fields in self.config_schema.items(): - if isinstance(fields, dict): - section_data = {} - - # 遍历节内的字段 - for field_name, field in fields.items(): - if isinstance(field, ConfigField): - section_data[field_name] = field.default - - config_data[section] = section_data - - return config_data - - def _save_config_to_file(self, config_data: Dict[str, Any], config_file_path: str): - """将配置数据保存为TOML文件(包含注释)""" - if not self.config_schema: - logger.debug(f"{self.log_prefix} 插件未定义config_schema,不生成配置文件") - return - - toml_str = f"# {self.plugin_name} - 配置文件\n" - plugin_description = self.get_manifest_info("description", "插件配置文件") - toml_str += f"# {plugin_description}\n" - - # 获取当前期望的配置版本 - expected_version = self._get_expected_config_version() - toml_str += f"# 配置版本: {expected_version}\n\n" - - # 遍历每个配置节 - for section, fields in self.config_schema.items(): - # 添加节描述 - if section in self.config_section_descriptions: - toml_str += f"# {self.config_section_descriptions[section]}\n" - - toml_str += f"[{section}]\n\n" - - # 遍历节内的字段 - if isinstance(fields, dict) and section in config_data: - section_data = config_data[section] - - for field_name, field in fields.items(): - if isinstance(field, ConfigField): - # 添加字段描述 - toml_str += f"# {field.description}" - if field.required: - toml_str += " (必需)" - toml_str += "\n" - - # 如果有示例值,添加示例 - if field.example: - toml_str += f"# 示例: {field.example}\n" - - # 如果有可选值,添加说明 - if field.choices: - choices_str = ", ".join(map(str, field.choices)) - toml_str += f"# 可选值: {choices_str}\n" - - # 添加字段值(使用迁移后的值) - value = section_data.get(field_name, field.default) - toml_str += f"{field_name} = {self._format_toml_value(value)}\n" - - toml_str += "\n" - toml_str += "\n" - - try: - with open(config_file_path, "w", encoding="utf-8") as f: - f.write(toml_str) - logger.info(f"{self.log_prefix} 配置文件已保存: {config_file_path}") - except IOError as e: - logger.error(f"{self.log_prefix} 保存配置文件失败: {e}", exc_info=True) - - def _load_plugin_config(self): # sourcery skip: extract-method - """加载插件配置文件,支持版本检查和自动迁移""" - if not self.config_file_name: - logger.debug(f"{self.log_prefix} 未指定配置文件,跳过加载") - return - - # 优先使用传入的插件目录路径 - if self.plugin_dir: - plugin_dir = self.plugin_dir - else: - # fallback:尝试从类的模块信息获取路径 - try: - plugin_module_path = inspect.getfile(self.__class__) - plugin_dir = os.path.dirname(plugin_module_path) - except (TypeError, OSError): - # 最后的fallback:从模块的__file__属性获取 - module = inspect.getmodule(self.__class__) - if module and hasattr(module, "__file__") and module.__file__: - plugin_dir = os.path.dirname(module.__file__) - else: - logger.warning(f"{self.log_prefix} 无法获取插件目录路径,跳过配置加载") - return - - config_file_path = os.path.join(plugin_dir, self.config_file_name) - - # 如果配置文件不存在,生成默认配置 - if not os.path.exists(config_file_path): - logger.info(f"{self.log_prefix} 配置文件 {config_file_path} 不存在,将生成默认配置。") - self._generate_and_save_default_config(config_file_path) - - if not os.path.exists(config_file_path): - logger.warning(f"{self.log_prefix} 配置文件 {config_file_path} 不存在且无法生成。") - return - - file_ext = os.path.splitext(self.config_file_name)[1].lower() - - if file_ext == ".toml": - # 加载现有配置 - with open(config_file_path, "r", encoding="utf-8") as f: - existing_config = toml.load(f) or {} - - # 检查配置版本 - current_version = self._get_current_config_version(existing_config) - - # 如果配置文件没有版本信息,跳过版本检查 - if current_version == "0.0.0": - logger.debug(f"{self.log_prefix} 配置文件无版本信息,跳过版本检查") - self.config = existing_config - else: - expected_version = self._get_expected_config_version() - - if current_version != expected_version: - logger.info( - f"{self.log_prefix} 检测到配置版本需要更新: 当前=v{current_version}, 期望=v{expected_version}" - ) - - # 生成新的默认配置结构 - new_config_structure = self._generate_config_from_schema() - - # 迁移旧配置值到新结构 - migrated_config = self._migrate_config_values(existing_config, new_config_structure) - - # 保存迁移后的配置 - self._save_config_to_file(migrated_config, config_file_path) - - logger.info(f"{self.log_prefix} 配置文件已从 v{current_version} 更新到 v{expected_version}") - - self.config = migrated_config - else: - logger.debug(f"{self.log_prefix} 配置版本匹配 (v{current_version}),直接加载") - self.config = existing_config - - logger.debug(f"{self.log_prefix} 配置已从 {config_file_path} 加载") - - # 从配置中更新 enable_plugin - if "plugin" in self.config and "enabled" in self.config["plugin"]: - self.enable_plugin = self.config["plugin"]["enabled"] # type: ignore - logger.debug(f"{self.log_prefix} 从配置更新插件启用状态: {self.enable_plugin}") - else: - logger.warning(f"{self.log_prefix} 不支持的配置文件格式: {file_ext},仅支持 .toml") - self.config = {} - - def _check_dependencies(self) -> bool: - """检查插件依赖""" - from src.plugin_system.core.component_registry import component_registry - - if not self.dependencies: - return True - - for dependency in self.dependencies: - dep_name, version_spec, min_version, max_version = self._parse_dependency(dependency) - if not dep_name: - logger.warning(f"{self.log_prefix} 跳过无效依赖声明: {dependency}") - continue - - dep_plugin_info = component_registry.get_plugin_info(dep_name) - if not dep_plugin_info: - logger.error(f"{self.log_prefix} 缺少依赖插件: {dep_name}") - return False - - dep_version = dep_plugin_info.version or "0.0.0" - - if version_spec: - is_ok, msg = self._is_version_spec_satisfied(dep_version, version_spec) - if not is_ok: - logger.error( - f"{self.log_prefix} 依赖插件版本不满足: {dep_name} {version_spec}, 当前版本={dep_version} ({msg})" - ) - return False - - if min_version or max_version: - is_ok, msg = VersionComparator.is_version_in_range(dep_version, min_version, max_version) - if not is_ok: - logger.error( - f"{self.log_prefix} 依赖插件版本不满足: {dep_name} 要求区间[{min_version or '-inf'}, {max_version or '+inf'}], 当前版本={dep_version} ({msg})" - ) - return False - - return True - - def _get_dependency_names(self) -> List[str]: - """获取依赖插件名称列表(用于插件信息展示和统计)。""" - dependency_names: List[str] = [] - for dependency in self.dependencies: - dep_name, _, _, _ = self._parse_dependency(dependency) - if dep_name: - dependency_names.append(dep_name) - return dependency_names - - def _parse_dependency(self, dependency: Any) -> tuple[str, str, str, str]: - """解析依赖声明。 - - 支持格式: - - "plugin_a" - - {"name": "plugin_a", "version": ">=1.2.0,<2.0.0"} - - {"name": "plugin_a", "min_version": "1.2.0", "max_version": "2.0.0"} - """ - if isinstance(dependency, str): - return dependency.strip(), "", "", "" - - if isinstance(dependency, dict): - dep_name = str(dependency.get("name", "")).strip() - version_spec = str(dependency.get("version", "")).strip() - min_version = str(dependency.get("min_version", "")).strip() - max_version = str(dependency.get("max_version", "")).strip() - return dep_name, version_spec, min_version, max_version - - return "", "", "", "" - - def _is_version_spec_satisfied(self, version: str, version_spec: str) -> tuple[bool, str]: - """检查版本是否满足表达式。 - - 支持:==, >=, <=, >, <,可用逗号分隔多个条件。 - 示例:">=1.2.0,<2.0.0" - """ - normalized_version = VersionComparator.normalize_version(version) - clauses = [clause.strip() for clause in version_spec.split(",") if clause.strip()] - if not clauses: - return True, "" - - operators_pattern = r"^(==|>=|<=|>|<)\s*(.+)$" - - for clause in clauses: - if not (match := re.match(operators_pattern, clause)): - return False, f"无效版本约束表达式: {clause}" - - operator, target_version = match.group(1), VersionComparator.normalize_version(match.group(2)) - compare_result = VersionComparator.compare_versions(normalized_version, target_version) - - is_satisfied = False - if operator == "==": - is_satisfied = compare_result == 0 - elif operator == ">=": - is_satisfied = compare_result >= 0 - elif operator == "<=": - is_satisfied = compare_result <= 0 - elif operator == ">": - is_satisfied = compare_result > 0 - elif operator == "<": - is_satisfied = compare_result < 0 - - if not is_satisfied: - return False, f"{normalized_version} 不满足约束 {operator}{target_version}" - - return True, "" - - def get_config(self, key: str, default: Any = None) -> Any: - """获取插件配置值,支持嵌套键访问 - - Args: - key: 配置键名,支持嵌套访问如 "section.subsection.key" - default: 默认值 - - Returns: - Any: 配置值或默认值 - """ - # 支持嵌套键访问 - keys = key.split(".") - current = self.config - - for k in keys: - if isinstance(current, dict) and k in current: - current = current[k] - else: - return default - - return current - - def get_webui_config_schema(self) -> Dict[str, Any]: - """ - 获取 WebUI 配置 Schema - - 返回完整的配置 schema,包含: - - 插件基本信息 - - 所有 section 及其字段定义 - - 布局配置 - - 用于 WebUI 动态生成配置表单。 - - Returns: - Dict: 完整的配置 schema - """ - schema = { - "plugin_id": self.plugin_name, - "plugin_info": { - "name": self.display_name, - "version": self.plugin_version, - "description": self.plugin_description, - "author": self.plugin_author, - }, - "sections": {}, - "layout": None, - } - - # 处理 sections - for section_name, fields in self.config_schema.items(): - if not isinstance(fields, dict): - continue - - section_data = { - "name": section_name, - "title": section_name, - "description": None, - "icon": None, - "collapsed": False, - "order": 0, - "fields": {}, - } - - # 获取 section 元数据 - section_meta = self.config_section_descriptions.get(section_name) - if section_meta: - if isinstance(section_meta, str): - section_data["title"] = section_meta - elif isinstance(section_meta, ConfigSection): - section_data["title"] = section_meta.title - section_data["description"] = section_meta.description - section_data["icon"] = section_meta.icon - section_data["collapsed"] = section_meta.collapsed - section_data["order"] = section_meta.order - elif isinstance(section_meta, dict): - section_data.update(section_meta) - - # 处理字段 - for field_name, field_def in fields.items(): - if isinstance(field_def, ConfigField): - field_data = field_def.to_dict() - field_data["name"] = field_name - section_data["fields"][field_name] = field_data - - schema["sections"][section_name] = section_data - - # 处理布局 - if self.config_layout: - schema["layout"] = self.config_layout.to_dict() - else: - # 自动布局:按 section order 排序 - schema["layout"] = { - "type": "auto", - "tabs": [], - } - - return schema - - def get_current_config_values(self) -> Dict[str, Any]: - """ - 获取当前配置值 - - 返回插件当前的配置值(已从配置文件加载)。 - - Returns: - Dict: 当前配置值 - """ - return self.config.copy() - - @abstractmethod - def register_plugin(self) -> bool: - """ - 注册插件到插件管理器 - - 子类必须实现此方法,返回注册是否成功 - - Returns: - bool: 是否成功注册插件 - """ - raise NotImplementedError("Subclasses must implement this method") diff --git a/src/plugin_system/base/service_types.py b/src/plugin_system/base/service_types.py deleted file mode 100644 index 0a789311..00000000 --- a/src/plugin_system/base/service_types.py +++ /dev/null @@ -1,22 +0,0 @@ -from dataclasses import dataclass, field -from typing import Any, Dict, List - - -@dataclass -class PluginServiceInfo: - """插件服务注册信息""" - - name: str - plugin_name: str - version: str = "1.0.0" - description: str = "" - enabled: bool = True - public: bool = False - allowed_callers: List[str] = field(default_factory=list) - params_schema: Dict[str, Any] = field(default_factory=dict) - return_schema: Dict[str, Any] = field(default_factory=dict) - metadata: Dict[str, Any] = field(default_factory=dict) - - @property - def full_name(self) -> str: - return f"{self.plugin_name}.{self.name}" diff --git a/src/plugin_system/base/workflow_errors.py b/src/plugin_system/base/workflow_errors.py deleted file mode 100644 index 0f4f74d0..00000000 --- a/src/plugin_system/base/workflow_errors.py +++ /dev/null @@ -1,14 +0,0 @@ -from enum import Enum - - -class WorkflowErrorCode(Enum): - """Workflow统一错误码""" - - PLUGIN_NOT_READY = "PLUGIN_NOT_READY" - STEP_TIMEOUT = "STEP_TIMEOUT" - BAD_PAYLOAD = "BAD_PAYLOAD" - DOWNSTREAM_FAILED = "DOWNSTREAM_FAILED" - POLICY_BLOCKED = "POLICY_BLOCKED" - - def __str__(self) -> str: - return self.value diff --git a/src/plugin_system/base/workflow_types.py b/src/plugin_system/base/workflow_types.py deleted file mode 100644 index bedcec8d..00000000 --- a/src/plugin_system/base/workflow_types.py +++ /dev/null @@ -1,68 +0,0 @@ -from dataclasses import dataclass, field -from enum import Enum -from typing import Any, Dict, List, Literal, Optional -import time - - -class WorkflowStage(Enum): - """Workflow阶段定义(MVP固定阶段)""" - - INGRESS = "ingress" - PRE_PROCESS = "pre_process" - PLAN = "plan" - TOOL_EXECUTE = "tool_execute" - POST_PROCESS = "post_process" - EGRESS = "egress" - - def __str__(self) -> str: - return self.value - - -@dataclass -class WorkflowContext: - """Workflow上下文""" - - trace_id: str - stream_id: Optional[str] = None - metadata: Dict[str, Any] = field(default_factory=dict) - timings: Dict[str, float] = field(default_factory=dict) - errors: List[str] = field(default_factory=list) - - -@dataclass -class WorkflowMessage: - """Workflow消息包装""" - - msg_type: str - payload: Dict[str, Any] = field(default_factory=dict) - headers: Dict[str, Any] = field(default_factory=dict) - mutable_flags: Dict[str, bool] = field(default_factory=dict) - - -@dataclass -class WorkflowStepResult: - """Workflow步骤结果""" - - status: Literal["continue", "stop", "failed"] = "continue" - return_message: Optional[str] = None - diagnostics: Dict[str, Any] = field(default_factory=dict) - events: List[Dict[str, Any]] = field(default_factory=list) - created_at: float = field(default_factory=time.time) - - -@dataclass -class WorkflowStepInfo: - """Workflow步骤元数据""" - - name: str - stage: WorkflowStage - plugin_name: str - description: str = "" - enabled: bool = True - priority: int = 0 - timeout_ms: int = 0 - metadata: Dict[str, Any] = field(default_factory=dict) - - @property - def full_name(self) -> str: - return f"{self.plugin_name}.{self.name}" diff --git a/src/plugin_system/core/__init__.py b/src/plugin_system/core/__init__.py deleted file mode 100644 index 01d97646..00000000 --- a/src/plugin_system/core/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -""" -插件核心管理模块 - -提供插件的加载、注册和管理功能 -""" - -from src.plugin_system.core.plugin_manager import plugin_manager -from src.plugin_system.core.component_registry import component_registry -from src.plugin_system.core.events_manager import events_manager -from src.plugin_system.core.global_announcement_manager import global_announcement_manager -from src.plugin_system.core.plugin_service_registry import plugin_service_registry -from src.plugin_system.core.workflow_engine import workflow_engine - -__all__ = [ - "plugin_manager", - "component_registry", - "events_manager", - "global_announcement_manager", - "plugin_service_registry", - "workflow_engine", -] diff --git a/src/plugin_system/core/component_registry.py b/src/plugin_system/core/component_registry.py deleted file mode 100644 index 7320f09a..00000000 --- a/src/plugin_system/core/component_registry.py +++ /dev/null @@ -1,758 +0,0 @@ -import re - -from typing import Callable, Dict, List, Optional, Any, Pattern, Tuple, Union, Type - -from src.common.logger import get_logger -from src.plugin_system.base.component_types import ( - ComponentInfo, - ActionInfo, - ToolInfo, - CommandInfo, - EventHandlerInfo, - PluginInfo, - ComponentType, -) -from src.plugin_system.base.base_command import BaseCommand -from src.plugin_system.base.base_action import BaseAction -from src.plugin_system.base.base_tool import BaseTool -from src.plugin_system.base.base_events_handler import BaseEventHandler -from src.plugin_system.base.workflow_types import WorkflowStage, WorkflowStepInfo - -logger = get_logger("component_registry") - - -class ComponentRegistry: - """统一的组件注册中心 - - 负责管理所有插件组件的注册、查询和生命周期管理 - """ - - def __init__(self): - # 命名空间式组件名构成法 f"{component_type}.{component_name}" - self._components: Dict[str, ComponentInfo] = {} - """组件注册表 命名空间式组件名 -> 组件信息""" - self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType} - """类型 -> 组件原名称 -> 组件信息""" - self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseTool, BaseEventHandler]]] = {} - """命名空间式组件名 -> 组件类""" - - # 插件注册表 - self._plugins: Dict[str, PluginInfo] = {} - """插件名 -> 插件信息""" - - # Action特定注册表 - self._action_registry: Dict[str, Type[BaseAction]] = {} - """Action注册表 action名 -> action类""" - self._default_actions: Dict[str, ActionInfo] = {} - """默认动作集,即启用的Action集,用于重置ActionManager状态""" - - # Command特定注册表 - self._command_registry: Dict[str, Type[BaseCommand]] = {} - """Command类注册表 command名 -> command类""" - self._command_patterns: Dict[Pattern, str] = {} - """编译后的正则 -> command名""" - - # 工具特定注册表 - self._tool_registry: Dict[str, Type[BaseTool]] = {} # 工具名 -> 工具类 - self._llm_available_tools: Dict[str, Type[BaseTool]] = {} # llm可用的工具名 -> 工具类 - - # EventHandler特定注册表 - self._event_handler_registry: Dict[str, Type[BaseEventHandler]] = {} - """event_handler名 -> event_handler类""" - self._enabled_event_handlers: Dict[str, Type[BaseEventHandler]] = {} - """启用的事件处理器 event_handler名 -> event_handler类""" - - # Workflow step注册表 - self._workflow_steps: Dict[WorkflowStage, Dict[str, WorkflowStepInfo]] = {stage: {} for stage in WorkflowStage} - self._workflow_step_handlers: Dict[str, Callable[..., Any]] = {} - - logger.info("组件注册中心初始化完成") - - # == 注册方法 == - - def register_plugin(self, plugin_info: PluginInfo) -> bool: - """注册插件 - - Args: - plugin_info: 插件信息 - - Returns: - bool: 是否注册成功 - """ - plugin_name = plugin_info.name - - if plugin_name in self._plugins: - logger.warning(f"插件 {plugin_name} 已存在,跳过注册") - return False - - self._plugins[plugin_name] = plugin_info - logger.debug(f"已注册插件: {plugin_name} (组件数量: {len(plugin_info.components)})") - return True - - def register_component( - self, - component_info: ComponentInfo, - component_class: Type[Union[BaseCommand, BaseAction, BaseEventHandler, BaseTool]], - ) -> bool: - """注册组件 - - Args: - component_info (ComponentInfo): 组件信息 - component_class (Type[Union[BaseCommand, BaseAction, BaseEventHandler]]): 组件类 - - Returns: - bool: 是否注册成功 - """ - component_name = component_info.name - component_type = component_info.component_type - plugin_name = getattr(component_info, "plugin_name", "unknown") - if "." in component_name: - logger.error(f"组件名称 '{component_name}' 包含非法字符 '.',请使用下划线替代") - return False - if "." in plugin_name: - logger.error(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代") - return False - - namespaced_name = f"{component_type}.{component_name}" - - if namespaced_name in self._components: - existing_info = self._components[namespaced_name] - existing_plugin = getattr(existing_info, "plugin_name", "unknown") - - logger.warning( - f"组件名冲突: '{plugin_name}' 插件的 {component_type} 类型组件 '{component_name}' 已被插件 '{existing_plugin}' 注册,跳过此组件注册" - ) - return False - - self._components[namespaced_name] = component_info # 注册到通用注册表(使用命名空间化的名称) - self._components_by_type[component_type][component_name] = component_info # 类型内部仍使用原名 - self._components_classes[namespaced_name] = component_class - - # 根据组件类型进行特定注册(使用原始名称) - ret = False - match component_type: - case ComponentType.ACTION: - assert isinstance(component_info, ActionInfo) - assert issubclass(component_class, BaseAction) - ret = self._register_action_component(component_info, component_class) - case ComponentType.COMMAND: - assert isinstance(component_info, CommandInfo) - assert issubclass(component_class, BaseCommand) - ret = self._register_command_component(component_info, component_class) - case ComponentType.TOOL: - assert isinstance(component_info, ToolInfo) - assert issubclass(component_class, BaseTool) - ret = self._register_tool_component(component_info, component_class) - case ComponentType.EVENT_HANDLER: - assert isinstance(component_info, EventHandlerInfo) - assert issubclass(component_class, BaseEventHandler) - ret = self._register_event_handler_component(component_info, component_class) - case _: - logger.warning(f"未知组件类型: {component_type}") - - if not ret: - return False - logger.debug( - f"已注册{component_type}组件: '{component_name}' -> '{namespaced_name}' " - f"({component_class.__name__}) [插件: {plugin_name}]" - ) - return True - - def _register_action_component(self, action_info: ActionInfo, action_class: Type[BaseAction]) -> bool: - """注册Action组件到Action特定注册表""" - if not (action_name := action_info.name): - logger.error(f"Action组件 {action_class.__name__} 必须指定名称") - return False - if not isinstance(action_info, ActionInfo) or not issubclass(action_class, BaseAction): - logger.error(f"注册失败: {action_name} 不是有效的Action") - return False - - self._action_registry[action_name] = action_class - - # 如果启用,添加到默认动作集 - if action_info.enabled: - self._default_actions[action_name] = action_info - - return True - - def _register_command_component(self, command_info: CommandInfo, command_class: Type[BaseCommand]) -> bool: - """注册Command组件到Command特定注册表""" - if not (command_name := command_info.name): - logger.error(f"Command组件 {command_class.__name__} 必须指定名称") - return False - if not isinstance(command_info, CommandInfo) or not issubclass(command_class, BaseCommand): - logger.error(f"注册失败: {command_name} 不是有效的Command") - return False - - self._command_registry[command_name] = command_class - - # 如果启用了且有匹配模式 - if command_info.enabled and command_info.command_pattern: - pattern = re.compile(command_info.command_pattern, re.IGNORECASE | re.DOTALL) - if pattern not in self._command_patterns: - self._command_patterns[pattern] = command_name - else: - logger.warning( - f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令" - ) - - return True - - def _register_tool_component(self, tool_info: ToolInfo, tool_class: Type[BaseTool]) -> bool: - """注册Tool组件到Tool特定注册表""" - tool_name = tool_info.name - - self._tool_registry[tool_name] = tool_class - - # 如果是llm可用的且启用的工具,添加到 llm可用工具列表 - if tool_info.enabled: - self._llm_available_tools[tool_name] = tool_class - - return True - - def _register_event_handler_component( - self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler] - ) -> bool: - if not (handler_name := handler_info.name): - logger.error(f"EventHandler组件 {handler_class.__name__} 必须指定名称") - return False - if not isinstance(handler_info, EventHandlerInfo) or not issubclass(handler_class, BaseEventHandler): - logger.error(f"注册失败: {handler_name} 不是有效的EventHandler") - return False - - self._event_handler_registry[handler_name] = handler_class - - if not handler_info.enabled: - logger.warning(f"EventHandler组件 {handler_name} 未启用") - return True # 未启用,但是也是注册成功 - - from .events_manager import events_manager # 延迟导入防止循环导入问题 - - if events_manager.register_event_subscriber(handler_info, handler_class): - self._enabled_event_handlers[handler_name] = handler_class - return True - else: - logger.error(f"注册事件处理器 {handler_name} 失败") - return False - - # === 组件移除相关 === - - async def remove_component(self, component_name: str, component_type: ComponentType, plugin_name: str) -> bool: - target_component_class = self.get_component_class(component_name, component_type) - if not target_component_class: - logger.warning(f"组件 {component_name} 未注册,无法移除") - return False - try: - match component_type: - case ComponentType.ACTION: - self._action_registry.pop(component_name) - self._default_actions.pop(component_name) - case ComponentType.COMMAND: - self._command_registry.pop(component_name) - keys_to_remove = [k for k, v in self._command_patterns.items() if v == component_name] - for key in keys_to_remove: - self._command_patterns.pop(key) - case ComponentType.TOOL: - self._tool_registry.pop(component_name) - self._llm_available_tools.pop(component_name) - case ComponentType.EVENT_HANDLER: - from .events_manager import events_manager # 延迟导入防止循环导入问题 - - self._event_handler_registry.pop(component_name) - self._enabled_event_handlers.pop(component_name) - await events_manager.unregister_event_subscriber(component_name) - namespaced_name = f"{component_type}.{component_name}" - self._components.pop(namespaced_name) - self._components_by_type[component_type].pop(component_name) - self._components_classes.pop(namespaced_name) - logger.info(f"组件 {component_name} 已移除") - return True - except KeyError as e: - logger.warning(f"移除组件时未找到组件: {component_name}, 发生错误: {e}") - return False - except Exception as e: - logger.error(f"移除组件 {component_name} 时发生错误: {e}") - return False - - async def remove_components_by_plugin(self, plugin_name: str) -> int: - """移除某插件注册的所有组件。""" - targets = [ - (component_info.name, component_info.component_type) - for component_info in self._components.values() - if component_info.plugin_name == plugin_name - ] - - removed_count = 0 - for component_name, component_type in targets: - if await self.remove_component(component_name, component_type, plugin_name): - removed_count += 1 - - if removed_count: - logger.info(f"已移除插件 {plugin_name} 的组件数量: {removed_count}") - return removed_count - - def remove_plugin_registry(self, plugin_name: str) -> bool: - """移除插件注册信息 - - Args: - plugin_name: 插件名称 - - Returns: - bool: 是否成功移除 - """ - if plugin_name not in self._plugins: - logger.warning(f"插件 {plugin_name} 未注册,无法移除") - return False - del self._plugins[plugin_name] - self.remove_workflow_steps_by_plugin(plugin_name) - logger.info(f"插件 {plugin_name} 已移除") - return True - - # === Workflow step 注册与查询 === - - def register_workflow_step(self, step_info: WorkflowStepInfo, step_handler: Callable[..., Any]) -> bool: - """注册workflow步骤。""" - if not step_info.name or not step_info.plugin_name: - logger.error("workflow step 注册失败: step名称或插件名称为空") - return False - if "." in step_info.name: - logger.error(f"workflow step 名称 '{step_info.name}' 包含非法字符 '.',请使用下划线替代") - return False - if "." in step_info.plugin_name: - logger.error(f"workflow step 所属插件名称 '{step_info.plugin_name}' 包含非法字符 '.',请使用下划线替代") - return False - - full_name = step_info.full_name - stage_registry = self._workflow_steps.get(step_info.stage) - if stage_registry is None: - logger.error(f"workflow step 注册失败: 未知阶段 {step_info.stage}") - return False - if full_name in stage_registry: - logger.warning(f"workflow step 已存在,跳过注册: {full_name}") - return False - - stage_registry[full_name] = step_info - self._workflow_step_handlers[full_name] = step_handler - logger.debug(f"已注册workflow step: {full_name} @ {step_info.stage}") - return True - - def get_steps_by_stage(self, stage: WorkflowStage, enabled_only: bool = False) -> Dict[str, WorkflowStepInfo]: - """获取某阶段的workflow步骤。""" - steps = self._workflow_steps.get(stage, {}) - if enabled_only: - return {name: info for name, info in steps.items() if info.enabled} - return steps.copy() - - def get_workflow_step(self, step_name: str, stage: Optional[WorkflowStage] = None) -> Optional[WorkflowStepInfo]: - """获取workflow step信息。 - - step_name支持两种: - - full_name: plugin_name.step_name - - short_name: step_name(若有冲突返回第一个并告警) - """ - candidates: List[WorkflowStepInfo] = [] - - target_stages = [stage] if stage else list(WorkflowStage) - for current_stage in target_stages: - current_steps = self._workflow_steps.get(current_stage, {}) - if "." in step_name: - if step_info := current_steps.get(step_name): - return step_info - continue - candidates.extend([step_info for step_info in current_steps.values() if step_info.name == step_name]) - - if len(candidates) == 1: - return candidates[0] - if len(candidates) > 1: - logger.warning(f"workflow step 名称 '{step_name}' 存在多义,使用第一个匹配: {candidates[0].full_name}") - return candidates[0] - return None - - def get_workflow_step_handler( - self, step_name: str, stage: Optional[WorkflowStage] = None - ) -> Optional[Callable[..., Any]]: - """获取workflow step处理函数。""" - if "." in step_name: - return self._workflow_step_handlers.get(step_name) - - if step_info := self.get_workflow_step(step_name, stage): - return self._workflow_step_handlers.get(step_info.full_name) - return None - - def enable_workflow_step(self, step_name: str, stage: Optional[WorkflowStage] = None) -> bool: - """启用workflow step。""" - step_info = self.get_workflow_step(step_name, stage) - if not step_info: - logger.warning(f"workflow step 未注册,无法启用: {step_name}") - return False - step_info.enabled = True - logger.info(f"workflow step 已启用: {step_info.full_name}") - return True - - def disable_workflow_step(self, step_name: str, stage: Optional[WorkflowStage] = None) -> bool: - """禁用workflow step。""" - step_info = self.get_workflow_step(step_name, stage) - if not step_info: - logger.warning(f"workflow step 未注册,无法禁用: {step_name}") - return False - step_info.enabled = False - logger.info(f"workflow step 已禁用: {step_info.full_name}") - return True - - def remove_workflow_steps_by_plugin(self, plugin_name: str) -> int: - """移除某插件注册的所有workflow step。""" - removed_count = 0 - for stage in WorkflowStage: - stage_registry = self._workflow_steps.get(stage, {}) - target_names = [name for name, info in stage_registry.items() if info.plugin_name == plugin_name] - for full_name in target_names: - stage_registry.pop(full_name, None) - self._workflow_step_handlers.pop(full_name, None) - removed_count += 1 - - if removed_count: - logger.info(f"已移除插件 {plugin_name} 的 workflow step 数量: {removed_count}") - return removed_count - - # === 组件全局启用/禁用方法 === - - def enable_component(self, component_name: str, component_type: ComponentType) -> bool: - """全局的启用某个组件 - Parameters: - component_name: 组件名称 - component_type: 组件类型 - Returns: - bool: 启用成功返回True,失败返回False - """ - target_component_class = self.get_component_class(component_name, component_type) - target_component_info = self.get_component_info(component_name, component_type) - if not target_component_class or not target_component_info: - logger.warning(f"组件 {component_name} 未注册,无法启用") - return False - target_component_info.enabled = True - match component_type: - case ComponentType.ACTION: - assert isinstance(target_component_info, ActionInfo) - self._default_actions[component_name] = target_component_info - case ComponentType.COMMAND: - assert isinstance(target_component_info, CommandInfo) - pattern = target_component_info.command_pattern - self._command_patterns[re.compile(pattern)] = component_name - case ComponentType.TOOL: - assert isinstance(target_component_info, ToolInfo) - assert issubclass(target_component_class, BaseTool) - self._llm_available_tools[component_name] = target_component_class - case ComponentType.EVENT_HANDLER: - assert isinstance(target_component_info, EventHandlerInfo) - assert issubclass(target_component_class, BaseEventHandler) - self._enabled_event_handlers[component_name] = target_component_class - from .events_manager import events_manager # 延迟导入防止循环导入问题 - - events_manager.register_event_subscriber(target_component_info, target_component_class) - namespaced_name = f"{component_type}.{component_name}" - self._components[namespaced_name].enabled = True - self._components_by_type[component_type][component_name].enabled = True - logger.info(f"组件 {component_name} 已启用") - return True - - async def disable_component(self, component_name: str, component_type: ComponentType) -> bool: - """全局的禁用某个组件 - Parameters: - component_name: 组件名称 - component_type: 组件类型 - Returns: - bool: 禁用成功返回True,失败返回False - """ - target_component_class = self.get_component_class(component_name, component_type) - target_component_info = self.get_component_info(component_name, component_type) - if not target_component_class or not target_component_info: - logger.warning(f"组件 {component_name} 未注册,无法禁用") - return False - target_component_info.enabled = False - try: - match component_type: - case ComponentType.ACTION: - self._default_actions.pop(component_name) - case ComponentType.COMMAND: - self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != component_name} - case ComponentType.TOOL: - self._llm_available_tools.pop(component_name) - case ComponentType.EVENT_HANDLER: - self._enabled_event_handlers.pop(component_name) - from .events_manager import events_manager # 延迟导入防止循环导入问题 - - await events_manager.unregister_event_subscriber(component_name) - self._components[component_name].enabled = False - self._components_by_type[component_type][component_name].enabled = False - logger.info(f"组件 {component_name} 已禁用") - return True - except KeyError as e: - logger.warning(f"禁用组件时未找到组件或已禁用: {component_name}, 发生错误: {e}") - return False - except Exception as e: - logger.error(f"禁用组件 {component_name} 时发生错误: {e}") - return False - - # === 组件查询方法 === - def get_component_info( - self, component_name: str, component_type: Optional[ComponentType] = None - ) -> Optional[ComponentInfo]: - # sourcery skip: class-extract-method - """获取组件信息,支持自动命名空间解析 - - Args: - component_name: 组件名称,可以是原始名称或命名空间化的名称 - component_type: 组件类型,如果提供则优先在该类型中查找 - - Returns: - Optional[ComponentInfo]: 组件信息或None - """ - # 1. 如果已经是命名空间化的名称,直接查找 - if "." in component_name: - return self._components.get(component_name) - - # 2. 如果指定了组件类型,构造命名空间化的名称查找 - if component_type: - namespaced_name = f"{component_type}.{component_name}" - return self._components.get(namespaced_name) - - # 3. 如果没有指定类型,尝试在所有命名空间中查找 - candidates = [] - for namespace_prefix in [types.value for types in ComponentType]: - namespaced_name = f"{namespace_prefix}.{component_name}" - if component_info := self._components.get(namespaced_name): - candidates.append((namespace_prefix, namespaced_name, component_info)) - - if len(candidates) == 1: - # 只有一个匹配,直接返回 - return candidates[0][2] - elif len(candidates) > 1: - # 多个匹配,记录警告并返回第一个 - namespaces = [ns for ns, _, _ in candidates] - logger.warning( - f"组件名称 '{component_name}' 在多个命名空间中存在: {namespaces},使用第一个匹配项: {candidates[0][1]}" - ) - return candidates[0][2] - - # 4. 都没找到 - return None - - def get_component_class( - self, - component_name: str, - component_type: Optional[ComponentType] = None, - ) -> Optional[Union[Type[BaseCommand], Type[BaseAction], Type[BaseEventHandler], Type[BaseTool]]]: - """获取组件类,支持自动命名空间解析 - - Args: - component_name: 组件名称,可以是原始名称或命名空间化的名称 - component_type: 组件类型,如果提供则优先在该类型中查找 - - Returns: - Optional[Union[BaseCommand, BaseAction]]: 组件类或None - """ - # 1. 如果已经是命名空间化的名称,直接查找 - if "." in component_name: - return self._components_classes.get(component_name) - - # 2. 如果指定了组件类型,构造命名空间化的名称查找 - if component_type: - namespaced_name = f"{component_type.value}.{component_name}" - return self._components_classes.get(namespaced_name) - - # 3. 如果没有指定类型,尝试在所有命名空间中查找 - candidates = [] - for namespace_prefix in [types.value for types in ComponentType]: - namespaced_name = f"{namespace_prefix}.{component_name}" - if component_class := self._components_classes.get(namespaced_name): - candidates.append((namespace_prefix, namespaced_name, component_class)) - - if len(candidates) == 1: - # 只有一个匹配,直接返回 - _, full_name, cls = candidates[0] - logger.debug(f"自动解析组件: '{component_name}' -> '{full_name}'") - return cls - elif len(candidates) > 1: - # 多个匹配,记录警告并返回第一个 - namespaces = [ns for ns, _, _ in candidates] - logger.warning( - f"组件名称 '{component_name}' 在多个命名空间中存在: {namespaces},使用第一个匹配项: {candidates[0][1]}" - ) - return candidates[0][2] - - # 4. 都没找到 - return None - - def get_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]: - """获取指定类型的所有组件""" - return self._components_by_type.get(component_type, {}).copy() - - def get_enabled_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]: - """获取指定类型的所有启用组件""" - components = self.get_components_by_type(component_type) - return {name: info for name, info in components.items() if info.enabled} - - # === Action特定查询方法 === - - def get_action_registry(self) -> Dict[str, Type[BaseAction]]: - """获取Action注册表""" - return self._action_registry.copy() - - def get_registered_action_info(self, action_name: str) -> Optional[ActionInfo]: - """获取Action信息""" - info = self.get_component_info(action_name, ComponentType.ACTION) - return info if isinstance(info, ActionInfo) else None - - def get_default_actions(self) -> Dict[str, ActionInfo]: - """获取默认动作集""" - return self._default_actions.copy() - - # === Command特定查询方法 === - - def get_command_registry(self) -> Dict[str, Type[BaseCommand]]: - """获取Command注册表""" - return self._command_registry.copy() - - def get_registered_command_info(self, command_name: str) -> Optional[CommandInfo]: - """获取Command信息""" - info = self.get_component_info(command_name, ComponentType.COMMAND) - return info if isinstance(info, CommandInfo) else None - - def get_command_patterns(self) -> Dict[Pattern, str]: - """获取Command模式注册表""" - return self._command_patterns.copy() - - def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, CommandInfo]]: - # sourcery skip: use-named-expression, use-next - """根据文本查找匹配的命令 - - Args: - text: 输入文本 - - Returns: - Tuple: (命令类, 匹配的命名组, 是否拦截消息, 插件名) 或 None - """ - - candidates = [pattern for pattern in self._command_patterns if pattern.match(text)] - if not candidates: - return None - if len(candidates) > 1: - logger.warning(f"文本 '{text}' 匹配到多个命令模式: {candidates},使用第一个匹配") - command_name = self._command_patterns[candidates[0]] - command_info: CommandInfo = self.get_registered_command_info(command_name) # type: ignore - return ( - self._command_registry[command_name], - candidates[0].match(text).groupdict(), # type: ignore - command_info, - ) - - # === Tool 特定查询方法 === - def get_tool_registry(self) -> Dict[str, Type[BaseTool]]: - """获取Tool注册表""" - return self._tool_registry.copy() - - def get_llm_available_tools(self) -> Dict[str, Type[BaseTool]]: - """获取LLM可用的Tool列表""" - return self._llm_available_tools.copy() - - def get_registered_tool_info(self, tool_name: str) -> Optional[ToolInfo]: - """获取Tool信息 - - Args: - tool_name: 工具名称 - - Returns: - ToolInfo: 工具信息对象,如果工具不存在则返回 None - """ - info = self.get_component_info(tool_name, ComponentType.TOOL) - return info if isinstance(info, ToolInfo) else None - - # === EventHandler 特定查询方法 === - - def get_event_handler_registry(self) -> Dict[str, Type[BaseEventHandler]]: - """获取事件处理器注册表""" - return self._event_handler_registry.copy() - - def get_registered_event_handler_info(self, handler_name: str) -> Optional[EventHandlerInfo]: - """获取事件处理器信息""" - info = self.get_component_info(handler_name, ComponentType.EVENT_HANDLER) - return info if isinstance(info, EventHandlerInfo) else None - - def get_enabled_event_handlers(self) -> Dict[str, Type[BaseEventHandler]]: - """获取启用的事件处理器""" - return self._enabled_event_handlers.copy() - - # === 插件查询方法 === - - def get_plugin_info(self, plugin_name: str) -> Optional[PluginInfo]: - """获取插件信息""" - return self._plugins.get(plugin_name) - - def get_all_plugins(self) -> Dict[str, PluginInfo]: - """获取所有插件""" - return self._plugins.copy() - - # def get_enabled_plugins(self) -> Dict[str, PluginInfo]: - # """获取所有启用的插件""" - # return {name: info for name, info in self._plugins.items() if info.enabled} - - def get_plugin_components(self, plugin_name: str) -> List[ComponentInfo]: - """获取插件的所有组件""" - plugin_info = self.get_plugin_info(plugin_name) - return plugin_info.components if plugin_info else [] - - def get_plugin_config(self, plugin_name: str) -> Optional[dict]: - """获取插件配置 - - Args: - plugin_name: 插件名称 - - Returns: - Optional[dict]: 插件配置字典或None - """ - # 从插件管理器获取插件实例的配置 - from src.plugin_system.core.plugin_manager import plugin_manager - - plugin_instance = plugin_manager.get_plugin_instance(plugin_name) - return plugin_instance.config if plugin_instance else None - - def get_registry_stats(self) -> Dict[str, Any]: - """获取注册中心统计信息""" - action_components: int = 0 - command_components: int = 0 - tool_components: int = 0 - events_handlers: int = 0 - for component in self._components.values(): - if component.component_type == ComponentType.ACTION: - action_components += 1 - elif component.component_type == ComponentType.COMMAND: - command_components += 1 - elif component.component_type == ComponentType.TOOL: - tool_components += 1 - elif component.component_type == ComponentType.EVENT_HANDLER: - events_handlers += 1 - - workflow_step_count = sum(len(steps) for steps in self._workflow_steps.values()) - enabled_workflow_step_count = sum( - len([step for step in steps.values() if step.enabled]) for steps in self._workflow_steps.values() - ) - - return { - "action_components": action_components, - "command_components": command_components, - "tool_components": tool_components, - "event_handlers": events_handlers, - "total_components": len(self._components), - "total_plugins": len(self._plugins), - "components_by_type": { - component_type.value: len(components) for component_type, components in self._components_by_type.items() - }, - "enabled_components": len([c for c in self._components.values() if c.enabled]), - "enabled_plugins": len([p for p in self._plugins.values() if p.enabled]), - "workflow_steps": workflow_step_count, - "enabled_workflow_steps": enabled_workflow_step_count, - "workflow_steps_by_stage": {stage.value: len(steps) for stage, steps in self._workflow_steps.items()}, - } - - -component_registry = ComponentRegistry() diff --git a/src/plugin_system/core/events_manager.py b/src/plugin_system/core/events_manager.py deleted file mode 100644 index fa0a0637..00000000 --- a/src/plugin_system/core/events_manager.py +++ /dev/null @@ -1,512 +0,0 @@ -import asyncio -import contextlib -from typing import List, Dict, Optional, Type, Tuple, TYPE_CHECKING - -from src.chat.message_receive.message import MessageSending, SessionMessage -from src.chat.message_receive.chat_manager import chat_manager as _chat_manager -from src.common.logger import get_logger -from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages, CustomEventHandlerResult -from src.plugin_system.base.base_events_handler import BaseEventHandler -from src.plugin_system.base.workflow_types import WorkflowContext, WorkflowStepResult -from .global_announcement_manager import global_announcement_manager -from .workflow_engine import workflow_engine - -if TYPE_CHECKING: - from src.common.data_models.llm_data_model import LLMGenerationDataModel - -logger = get_logger("events_manager") - - -class EventsManager: - def __init__(self): - # 有权重的 events 订阅者注册表 - self._events_subscribers: Dict[EventType | str, List[BaseEventHandler]] = {} - self._handler_mapping: Dict[str, Type[BaseEventHandler]] = {} # 事件处理器映射表 - self._handler_tasks: Dict[str, List[asyncio.Task]] = {} # 事件处理器正在处理的任务 - self._events_result_history: Dict[EventType | str, List[CustomEventHandlerResult]] = {} # 事件的结果历史记录 - self._history_enable_map: Dict[EventType | str, bool] = {} # 是否启用历史记录的映射表,同时作为events注册表 - - # 事件注册(同时作为注册样例) - for event in EventType: - self.register_event(event, enable_history_result=False) - - def register_event(self, event_type: EventType | str, enable_history_result: bool = False): - if event_type in self._events_subscribers: - raise ValueError(f"事件类型 {event_type} 已存在") - self._events_subscribers[event_type] = [] - self._history_enable_map[event_type] = enable_history_result - if enable_history_result: - self._events_result_history[event_type] = [] - - def register_event_subscriber(self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]) -> bool: - """注册事件处理器 - - Args: - handler_info (EventHandlerInfo): 事件处理器信息 - handler_class (Type[BaseEventHandler]): 事件处理器类 - - Returns: - bool: 是否注册成功 - """ - if not issubclass(handler_class, BaseEventHandler): - logger.error(f"类 {handler_class.__name__} 不是 BaseEventHandler 的子类") - return False - - handler_name = handler_info.name - - if handler_name in self._handler_mapping: - logger.warning(f"事件处理器 {handler_name} 已存在,跳过注册") - return False - - if handler_info.event_type not in self._history_enable_map: - if isinstance(handler_info.event_type, str): - self.register_event(handler_info.event_type, enable_history_result=False) - logger.info(f"自动注册自定义事件类型: {handler_info.event_type}") - else: - logger.error(f"事件类型 {handler_info.event_type} 未注册,无法为其注册处理器 {handler_name}") - return False - - self._handler_mapping[handler_name] = handler_class - return self._insert_event_handler(handler_class, handler_info) - - async def handle_mai_events( - self, - event_type: EventType | str, - message: Optional[SessionMessage | MessageSending | MaiMessages] = None, - llm_prompt: Optional[str] = None, - llm_response: Optional["LLMGenerationDataModel"] = None, - stream_id: Optional[str] = None, - action_usage: Optional[List[str]] = None, - ) -> Tuple[bool, Optional[MaiMessages]]: - """ - 处理所有事件,根据事件类型分发给订阅的处理器。 - """ - from src.plugin_system.core import component_registry - - continue_flag = True - - # 1. 准备消息 - transformed_message = self._prepare_message( - event_type, message, llm_prompt, llm_response, stream_id, action_usage # type: ignore[arg-type] - ) - if transformed_message: - transformed_message = transformed_message.deepcopy() - - # 2. 获取并遍历处理器 - handlers = self._events_subscribers.get(event_type, []) - if not handlers: - return True, None - - current_stream_id = transformed_message.stream_id if transformed_message else None - modified_message: Optional[MaiMessages] = None - for handler in handlers: - # 3. 前置检查和配置加载 - if ( - current_stream_id - and handler.handler_name - in global_announcement_manager.get_disabled_chat_event_handlers(current_stream_id) - ): - continue - - # 统一加载插件配置 - plugin_config = component_registry.get_plugin_config(handler.plugin_name) or {} - handler.set_plugin_config(plugin_config) - - # 4. 根据类型分发任务 - if ( - handler.intercept_message or event_type == EventType.ON_STOP - ): # 让ON_STOP的所有事件处理器都发挥作用,防止还没执行即被取消 - # 阻塞执行,并更新 continue_flag - should_continue, modified_message = await self._dispatch_intercepting_handler_task( - handler, event_type, modified_message or transformed_message - ) - continue_flag = continue_flag and should_continue - else: - # 异步执行,不阻塞 - self._dispatch_handler_task(handler, event_type, transformed_message) - - # 桥接到新版本插件运行时 - continue_flag, modified_message = await self._bridge_to_new_runtime( - event_type, continue_flag, modified_message or transformed_message - ) - - return continue_flag, modified_message - - async def handle_workflow_message( - self, - message: Optional[SessionMessage | MessageSending | MaiMessages] = None, - stream_id: Optional[str] = None, - action_usage: Optional[List[str]] = None, - context: Optional[WorkflowContext] = None, - ) -> Tuple[WorkflowStepResult, Optional[MaiMessages], WorkflowContext]: - """执行线性workflow消息流转(MVP兼容入口)。""" - initial_message = self._prepare_message(EventType.ON_MESSAGE_PRE_PROCESS, message=message, stream_id=stream_id) - - async def _dispatch( - event_type: EventType | str, - workflow_message: Optional[MaiMessages], - workflow_stream_id: Optional[str], - workflow_action_usage: Optional[List[str]], - ) -> Tuple[bool, Optional[MaiMessages]]: - return await self.handle_mai_events( - event_type=event_type, - message=workflow_message, - stream_id=workflow_stream_id, - action_usage=workflow_action_usage, - ) - - return await workflow_engine.execute_linear( - dispatch_event=_dispatch, - message=initial_message, - stream_id=stream_id, - action_usage=action_usage, - context=context, - ) - - async def cancel_handler_tasks(self, handler_name: str) -> None: - tasks_to_be_cancelled = self._handler_tasks.get(handler_name, []) - if remaining_tasks := [task for task in tasks_to_be_cancelled if not task.done()]: - for task in remaining_tasks: - task.cancel() - try: - await asyncio.wait_for(asyncio.gather(*remaining_tasks, return_exceptions=True), timeout=5) - logger.info(f"已取消事件处理器 {handler_name} 的所有任务") - except asyncio.TimeoutError: - logger.warning(f"取消事件处理器 {handler_name} 的任务超时,开始强制取消") - except Exception as e: - logger.error(f"取消事件处理器 {handler_name} 的任务时发生异常: {e}") - if handler_name in self._handler_tasks: - del self._handler_tasks[handler_name] - - async def unregister_event_subscriber(self, handler_name: str) -> bool: - """取消注册事件处理器""" - if handler_name not in self._handler_mapping: - logger.warning(f"事件处理器 {handler_name} 不存在,无法取消注册") - return False - - await self.cancel_handler_tasks(handler_name) - - handler_class = self._handler_mapping.pop(handler_name) - if not self._remove_event_handler_instance(handler_class): - return False - - logger.info(f"事件处理器 {handler_name} 已成功取消注册") - return True - - async def get_event_result_history(self, event_type: EventType | str) -> List[CustomEventHandlerResult]: - """获取事件的结果历史记录""" - if event_type == EventType.UNKNOWN: - raise ValueError("未知事件类型") - if event_type not in self._history_enable_map: - raise ValueError(f"事件类型 {event_type} 未注册") - if not self._history_enable_map[event_type]: - raise ValueError(f"事件类型 {event_type} 的历史记录未启用") - - return self._events_result_history[event_type] - - async def clear_event_result_history(self, event_type: EventType | str) -> None: - """清空事件的结果历史记录""" - if event_type == EventType.UNKNOWN: - raise ValueError("未知事件类型") - if event_type not in self._history_enable_map: - raise ValueError(f"事件类型 {event_type} 未注册") - if not self._history_enable_map[event_type]: - raise ValueError(f"事件类型 {event_type} 的历史记录未启用") - - self._events_result_history[event_type] = [] - - def _insert_event_handler(self, handler_class: Type[BaseEventHandler], handler_info: EventHandlerInfo) -> bool: - """插入事件处理器到对应的事件类型列表中并设置其插件配置""" - if handler_class.event_type == EventType.UNKNOWN: - logger.error(f"事件处理器 {handler_class.__name__} 的事件类型未知,无法注册") - return False - if handler_class.event_type not in self._events_subscribers: - self._events_subscribers[handler_class.event_type] = [] - handler_instance = handler_class() - handler_instance.set_plugin_name(handler_info.plugin_name or "unknown") - self._events_subscribers[handler_class.event_type].append(handler_instance) - self._events_subscribers[handler_class.event_type].sort(key=lambda x: x.weight, reverse=True) - - return True - - def _remove_event_handler_instance(self, handler_class: Type[BaseEventHandler]) -> bool: - """从事件类型列表中移除事件处理器""" - display_handler_name = handler_class.handler_name or handler_class.__name__ - if handler_class.event_type == EventType.UNKNOWN: - logger.warning(f"事件处理器 {display_handler_name} 的事件类型未知,不存在于处理器列表中") - return False - - handlers = self._events_subscribers[handler_class.event_type] - for i, handler in enumerate(handlers): - if isinstance(handler, handler_class): - del handlers[i] - logger.debug(f"事件处理器 {display_handler_name} 已移除") - return True - - logger.warning(f"未找到事件处理器 {display_handler_name},无法移除") - return False - - def _transform_event_message( - self, - message: SessionMessage | MessageSending, - llm_prompt: Optional[str] = None, - llm_response: Optional["LLMGenerationDataModel"] = None, - ) -> MaiMessages: - """转换事件消息格式""" - from maim_message import Seg - - # 直接赋值部分内容 - transformed_message = MaiMessages( - llm_prompt=llm_prompt, - llm_response_content=llm_response.content if llm_response else None, - llm_response_reasoning=llm_response.reasoning if llm_response else None, - llm_response_model=llm_response.model if llm_response else None, - llm_response_tool_call=llm_response.tool_calls if llm_response else None, - raw_message=message.processed_plain_text or "", - additional_data={}, - ) - - # 消息段处理 - if isinstance(message, MessageSending): - if message.message_segment.type == "seglist": - transformed_message.message_segments = list(message.message_segment.data) # type: ignore - else: - transformed_message.message_segments = [message.message_segment] - else: - # SessionMessage: 使用 processed_plain_text 构造简单段 - transformed_message.message_segments = [Seg(type="text", data=message.processed_plain_text or "")] - - # stream_id 处理 - transformed_message.stream_id = message.session_id if hasattr(message, "session_id") else "" - - # 处理后文本 - transformed_message.plain_text = message.processed_plain_text - - # 基本信息 - if isinstance(message, MessageSending): - transformed_message.message_base_info["platform"] = message.platform - if message.session.group_id: - transformed_message.is_group_message = True - group_name = "" - if message.session.context and message.session.context.message and message.session.context.message.message_info.group_info: - group_name = message.session.context.message.message_info.group_info.group_name - transformed_message.message_base_info.update({ - "group_id": message.session.group_id, - "group_name": group_name, - }) - transformed_message.message_base_info.update({ - "user_id": message.bot_user_info.user_id, - "user_cardname": message.bot_user_info.user_cardname, - "user_nickname": message.bot_user_info.user_nickname, - }) - if not transformed_message.is_group_message: - transformed_message.is_private_message = True - elif hasattr(message, "message_info") and message.message_info: - if message.platform: - transformed_message.message_base_info["platform"] = message.platform - if message.message_info.group_info: - transformed_message.is_group_message = True - transformed_message.message_base_info.update({ - "group_id": message.message_info.group_info.group_id, - "group_name": message.message_info.group_info.group_name, - }) - if message.message_info.user_info: - if not transformed_message.is_group_message: - transformed_message.is_private_message = True - transformed_message.message_base_info.update({ - "user_id": message.message_info.user_info.user_id, - "user_cardname": message.message_info.user_info.user_cardname, - "user_nickname": message.message_info.user_info.user_nickname, - }) - - return transformed_message - - def _build_message_from_stream( - self, stream_id: str, llm_prompt: Optional[str] = None, llm_response: Optional["LLMGenerationDataModel"] = None - ) -> MaiMessages: - """从流ID构建消息""" - session = _chat_manager.get_session_by_session_id(stream_id) - assert session, f"未找到流ID为 {stream_id} 的会话" - message = session.context.message - return self._transform_event_message(message, llm_prompt, llm_response) - - def _transform_event_without_message( - self, - stream_id: str, - llm_prompt: Optional[str] = None, - llm_response: Optional["LLMGenerationDataModel"] = None, - action_usage: Optional[List[str]] = None, - ) -> MaiMessages: - """没有message对象时进行转换""" - session = _chat_manager.get_session_by_session_id(stream_id) - assert session, f"未找到流ID为 {stream_id} 的会话" - return MaiMessages( - stream_id=stream_id, - llm_prompt=llm_prompt, - llm_response_content=(llm_response.content if llm_response else None), - llm_response_reasoning=(llm_response.reasoning if llm_response else None), - llm_response_model=(llm_response.model if llm_response else None), - llm_response_tool_call=(llm_response.tool_calls if llm_response else None), - is_group_message=session.is_group_session, - is_private_message=not session.is_group_session, - action_usage=action_usage, - additional_data={"response_is_processed": True}, - ) - - async def _bridge_to_new_runtime( - self, - event_type: EventType | str, - continue_flag: bool, - message: Optional[MaiMessages], - ) -> Tuple[bool, Optional[MaiMessages]]: - """将事件桥接到新版本插件运行时 - - 如果旧 handler 已经 abort(continue_flag=False),直接跳过。 - """ - if not continue_flag: - return continue_flag, message - - try: - from src.plugin_runtime.integration import get_plugin_runtime_manager - - prm = get_plugin_runtime_manager() - if not prm.is_running: - return continue_flag, message - - event_value = event_type.value if isinstance(event_type, EventType) else str(event_type) - message_dict = message.to_dict() if message and hasattr(message, "to_dict") else None - - new_continue, new_msg_dict = await prm.bridge_event( - event_type_value=event_value, - message_dict=message_dict, - ) - # 新运行时返回 abort 则合并 - if not new_continue: - continue_flag = False - - except Exception as e: - logger.warning(f"桥接事件到新运行时失败: {e}") - - return continue_flag, message - - def _prepare_message( - self, - event_type: EventType | str, - message: Optional[SessionMessage | MessageSending | MaiMessages] = None, - llm_prompt: Optional[str] = None, - llm_response: Optional["LLMGenerationDataModel"] = None, - stream_id: Optional[str] = None, - action_usage: Optional[List[str]] = None, - ) -> Optional[MaiMessages]: - """根据事件类型和输入,准备和转换消息对象。""" - if isinstance(message, MaiMessages): - return message.deepcopy() - - if message: - return self._transform_event_message(message, llm_prompt, llm_response) - - if event_type not in [EventType.ON_START, EventType.ON_STOP]: - assert stream_id, "如果没有消息,必须为非启动/关闭事件提供流ID" - if event_type in [EventType.ON_MESSAGE, EventType.ON_PLAN, EventType.POST_LLM, EventType.AFTER_LLM]: - return self._build_message_from_stream(stream_id, llm_prompt, llm_response) - else: - return self._transform_event_without_message(stream_id, llm_prompt, llm_response, action_usage) - - return None # ON_START, ON_STOP事件没有消息体 - - def _dispatch_handler_task( - self, handler: BaseEventHandler, event_type: EventType | str, message: Optional[MaiMessages] = None - ): - """分发一个非阻塞(异步)的事件处理任务。""" - if event_type == EventType.UNKNOWN: - raise ValueError("未知事件类型") - try: - task = asyncio.create_task(handler.execute(message)) - - task_name = f"{handler.plugin_name}-{handler.handler_name}" - task.set_name(task_name) - task.add_done_callback(lambda t: self._task_done_callback(t, event_type)) - - self._handler_tasks.setdefault(handler.handler_name, []).append(task) - except Exception as e: - logger.error(f"创建事件处理器任务 {handler.handler_name} 时发生异常: {e}", exc_info=True) - - async def _dispatch_intercepting_handler_task( - self, handler: BaseEventHandler, event_type: EventType | str, message: Optional[MaiMessages] = None - ) -> Tuple[bool, Optional[MaiMessages]]: - """分发并等待一个阻塞(同步)的事件处理器,返回是否应继续处理。""" - if event_type == EventType.UNKNOWN: - raise ValueError("未知事件类型") - if event_type not in self._history_enable_map: - raise ValueError(f"事件类型 {event_type} 未注册") - try: - result = await handler.execute(message) - - expected_fields = ["success", "continue_processing", "return_message", "custom_result", "modified_message"] - - if not isinstance(result, tuple) or len(result) != 5: - if isinstance(result, tuple): - annotated = ", ".join(f"{name}={val!r}" for name, val in zip(expected_fields, result, strict=False)) - actual_desc = f"{len(result)} 个元素 ({annotated})" - else: - actual_desc = f"非 tuple 类型: {type(result)}" - - logger.error( - f"[{self.__class__.__name__}] EventHandler {handler.handler_name} 返回值不符合预期:\n" - f" 模块来源: {handler.__class__.__module__}.{handler.__class__.__name__}\n" - f" 期望: 5 个元素 ({', '.join(expected_fields)})\n" - f" 实际: {actual_desc}" - ) - return True, None - - success, continue_processing, return_message, custom_result, modified_message = result - - if not success: - logger.error(f"EventHandler {handler.handler_name} 执行失败: {return_message}") - else: - logger.debug(f"EventHandler {handler.handler_name} 执行成功: {return_message}") - - if self._history_enable_map[event_type] and custom_result: - self._events_result_history[event_type].append(custom_result) - - return continue_processing, modified_message - - except KeyError: - logger.error(f"事件 {event_type} 注册的历史记录启用情况与实际不符合") - return True, None - except Exception as e: - logger.error(f"EventHandler {handler.handler_name} 发生异常: {e}", exc_info=True) - return True, None # 发生异常时默认不中断其他处理 - - def _task_done_callback( - self, - task: asyncio.Task[Tuple[bool, bool, str | None, CustomEventHandlerResult | None, MaiMessages | None]], - event_type: EventType | str, - ): - """任务完成回调""" - task_name = task.get_name() or "Unknown Task" - if event_type == EventType.UNKNOWN: - raise ValueError("未知事件类型") - if event_type not in self._history_enable_map: - raise ValueError(f"事件类型 {event_type} 未注册") - try: - success, _, result, custom_result, _ = task.result() # 忽略是否继续的标志和消息的修改,因为消息本身未被拦截 - if success: - logger.debug(f"事件处理任务 {task_name} 已成功完成: {result}") - else: - logger.error(f"事件处理任务 {task_name} 执行失败: {result}") - - if self._history_enable_map[event_type] and custom_result: - self._events_result_history[event_type].append(custom_result) - except asyncio.CancelledError: - pass - except KeyError: - logger.error(f"事件 {event_type} 注册的历史记录启用情况与实际不符合") - except Exception as e: - logger.error(f"事件处理任务 {task_name} 发生异常: {e}") - finally: - with contextlib.suppress(ValueError, KeyError): - self._handler_tasks[task_name].remove(task) - - -events_manager = EventsManager() diff --git a/src/plugin_system/core/plugin_manager.py b/src/plugin_system/core/plugin_manager.py deleted file mode 100644 index 62803ee6..00000000 --- a/src/plugin_system/core/plugin_manager.py +++ /dev/null @@ -1,634 +0,0 @@ -from importlib.util import spec_from_file_location, module_from_spec -from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Tuple, Type -from collections import deque -import os -import traceback - - -from src.common.logger import get_logger -from src.plugin_system.base.plugin_base import PluginBase -from src.plugin_system.base.component_types import ComponentType -from src.plugin_system.utils.manifest_utils import VersionComparator -from .component_registry import component_registry -from .plugin_service_registry import plugin_service_registry - -logger = get_logger("plugin_manager") - - -class PluginManager: - """ - 插件管理器类 - - 负责加载,重载和卸载插件,同时管理插件的所有组件 - """ - - def __init__(self): - self.plugin_directories: List[str] = [] # 插件根目录列表 - self.plugin_classes: Dict[str, Type[PluginBase]] = {} # 全局插件类注册表,插件名 -> 插件类 - self.plugin_paths: Dict[str, str] = {} # 记录插件名到目录路径的映射,插件名 -> 目录路径 - - self.loaded_plugins: Dict[str, PluginBase] = {} # 已加载的插件类实例注册表,插件名 -> 插件类实例 - self.failed_plugins: Dict[str, str] = {} # 记录加载失败的插件文件及其错误信息,插件名 -> 错误信息 - - # 确保插件目录存在 - self._ensure_plugin_directories() - logger.info("插件管理器初始化完成") - - # === 插件目录管理 === - - def add_plugin_directory(self, directory: str) -> bool: - """添加插件目录""" - if os.path.exists(directory): - if directory not in self.plugin_directories: - self.plugin_directories.append(directory) - logger.debug(f"已添加插件目录: {directory}") - return True - else: - logger.warning(f"插件不可重复加载: {directory}") - else: - logger.warning(f"插件目录不存在: {directory}") - return False - - # === 插件加载管理 === - - def load_all_plugins(self) -> Tuple[int, int]: - """加载所有插件 - - Returns: - tuple[int, int]: (插件数量, 组件数量) - """ - logger.debug("开始加载所有插件...") - - # 第一阶段:加载所有插件模块(注册插件类) - total_loaded_modules = 0 - total_failed_modules = 0 - - for directory in self.plugin_directories: - loaded, failed = self._load_plugin_modules_from_directory(directory) - total_loaded_modules += loaded - total_failed_modules += failed - - logger.debug(f"插件模块加载完成 - 成功: {total_loaded_modules}, 失败: {total_failed_modules}") - - # 第二阶段前:根据依赖关系决定插件加载顺序 - dependency_graph, missing_dependencies = self._build_plugin_dependency_graph() - sorted_plugins, cycle_plugins = self._resolve_plugin_load_order(dependency_graph) - - total_registered = 0 - total_failed_registration = 0 - - # 先处理缺失依赖 - for plugin_name, missing in missing_dependencies.items(): - if not missing: - continue - missing_dep_names = ", ".join(sorted(missing)) - self.failed_plugins[plugin_name] = f"缺少依赖插件: {missing_dep_names}" - logger.error(f"❌ 插件加载失败: {plugin_name} - 缺少依赖插件: {missing_dep_names}") - total_failed_registration += 1 - - # 再处理循环依赖 - for plugin_name in sorted(cycle_plugins): - if plugin_name in missing_dependencies and missing_dependencies[plugin_name]: - continue - self.failed_plugins[plugin_name] = "检测到循环依赖" - logger.error(f"❌ 插件加载失败: {plugin_name} - 检测到循环依赖") - total_failed_registration += 1 - - # 最后按拓扑序加载可加载插件 - for plugin_name in sorted_plugins: - if plugin_name in cycle_plugins: - continue - if plugin_name in missing_dependencies and missing_dependencies[plugin_name]: - continue - load_status, count = self.load_registered_plugin_classes(plugin_name) - if load_status: - total_registered += 1 - else: - total_failed_registration += count - - self._show_stats(total_registered, total_failed_registration) - - return total_registered, total_failed_registration - - def load_registered_plugin_classes(self, plugin_name: str) -> Tuple[bool, int]: - # sourcery skip: extract-duplicate-method, extract-method - """ - 加载已经注册的插件类 - """ - plugin_class = self.plugin_classes.get(plugin_name) - if not plugin_class: - logger.error(f"插件 {plugin_name} 的插件类未注册或不存在") - return False, 1 - try: - # 使用记录的插件目录路径 - plugin_dir = self.plugin_paths.get(plugin_name) - - # 如果没有记录,直接返回失败 - if not plugin_dir: - return False, 1 - - plugin_instance = plugin_class(plugin_dir=plugin_dir) # 实例化插件(可能因为缺少manifest而失败) - if not plugin_instance: - logger.error(f"插件 {plugin_name} 实例化失败") - return False, 1 - # 检查插件是否启用 - if not plugin_instance.enable_plugin: - logger.info(f"插件 {plugin_name} 已禁用,跳过加载") - return False, 0 - - # 检查版本兼容性 - is_compatible, compatibility_error = self._check_plugin_version_compatibility( - plugin_name, plugin_instance.manifest_data - ) - if not is_compatible: - self.failed_plugins[plugin_name] = compatibility_error - logger.error(f"❌ 插件加载失败: {plugin_name} - {compatibility_error}") - return False, 1 - if plugin_instance.register_plugin(): - self.loaded_plugins[plugin_name] = plugin_instance - self._show_plugin_components(plugin_name) - return True, 1 - else: - self.failed_plugins[plugin_name] = "插件注册失败" - logger.error(f"❌ 插件注册失败: {plugin_name}") - return False, 1 - - except FileNotFoundError as e: - # manifest文件缺失 - error_msg = f"缺少manifest文件: {str(e)}" - self.failed_plugins[plugin_name] = error_msg - logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}") - return False, 1 - - except ValueError as e: - # manifest文件格式错误或验证失败 - traceback.print_exc() - error_msg = f"manifest验证失败: {str(e)}" - self.failed_plugins[plugin_name] = error_msg - logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}") - return False, 1 - - except Exception as e: - # 其他错误 - error_msg = f"未知错误: {str(e)}" - self.failed_plugins[plugin_name] = error_msg - logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}") - logger.debug("详细错误信息: ", exc_info=True) - return False, 1 - - async def remove_registered_plugin(self, plugin_name: str) -> bool: - """ - 禁用插件模块 - """ - if not plugin_name: - raise ValueError("插件名称不能为空") - if plugin_name not in self.loaded_plugins: - logger.warning(f"插件 {plugin_name} 未加载") - return False - plugin_instance = self.loaded_plugins[plugin_name] - plugin_info = plugin_instance.plugin_info - success = True - for component in plugin_info.components: - success &= await component_registry.remove_component(component.name, component.component_type, plugin_name) - success &= component_registry.remove_plugin_registry(plugin_name) - plugin_service_registry.remove_services_by_plugin(plugin_name) - del self.loaded_plugins[plugin_name] - return success - - async def reload_registered_plugin(self, plugin_name: str) -> bool: - """ - 重载插件模块 - """ - old_instance = self.loaded_plugins.get(plugin_name) - if not old_instance: - logger.warning(f"插件 {plugin_name} 未加载,无法重载") - return False - - if not await self.remove_registered_plugin(plugin_name): - return False - - if not self.load_registered_plugin_classes(plugin_name)[0]: - logger.error(f"插件 {plugin_name} 重载失败,开始回滚旧实例") - rollback_ok = await self._rollback_failed_reload(plugin_name, old_instance) - if rollback_ok: - logger.info(f"插件 {plugin_name} 已回滚到旧版本实例") - else: - logger.error(f"插件 {plugin_name} 回滚失败,插件当前不可用") - return False - - logger.debug(f"插件 {plugin_name} 重载成功") - return True - - async def _rollback_failed_reload(self, plugin_name: str, old_instance: PluginBase) -> bool: - """重载失败后回滚旧实例。""" - try: - await component_registry.remove_components_by_plugin(plugin_name) - component_registry.remove_plugin_registry(plugin_name) - plugin_service_registry.remove_services_by_plugin(plugin_name) - - if not old_instance.register_plugin(): - logger.error(f"插件 {plugin_name} 回滚失败: 旧实例重新注册失败") - return False - - self.loaded_plugins[plugin_name] = old_instance - return True - except Exception as e: - logger.error(f"插件 {plugin_name} 回滚异常: {e}", exc_info=True) - return False - - def rescan_plugin_directory(self) -> Tuple[int, int]: - """ - 重新扫描插件根目录 - """ - total_success = 0 - total_fail = 0 - for directory in self.plugin_directories: - if os.path.exists(directory): - logger.debug(f"重新扫描插件根目录: {directory}") - success, fail = self._load_plugin_modules_from_directory(directory) - total_success += success - total_fail += fail - else: - logger.warning(f"插件根目录不存在: {directory}") - return total_success, total_fail - - def get_plugin_instance(self, plugin_name: str) -> Optional["PluginBase"]: - """获取插件实例 - - Args: - plugin_name: 插件名称 - - Returns: - Optional[BasePlugin]: 插件实例或None - """ - return self.loaded_plugins.get(plugin_name) - - # === 查询方法 === - def list_loaded_plugins(self) -> List[str]: - """ - 列出所有当前加载的插件。 - - Returns: - list: 当前加载的插件名称列表。 - """ - return list(self.loaded_plugins.keys()) - - def list_registered_plugins(self) -> List[str]: - """ - 列出所有已注册的插件类。 - - Returns: - list: 已注册的插件类名称列表。 - """ - return list(self.plugin_classes.keys()) - - def get_plugin_path(self, plugin_name: str) -> Optional[str]: - """ - 获取指定插件的路径。 - - Args: - plugin_name: 插件名称 - - Returns: - Optional[str]: 插件目录的绝对路径,如果插件不存在则返回None。 - """ - return self.plugin_paths.get(plugin_name) - - # === 私有方法 === - # == 目录管理 == - def _ensure_plugin_directories(self) -> None: - """确保所有插件根目录存在,如果不存在则创建""" - default_directories = ["src/plugins/built_in", "plugins"] - - for directory in default_directories: - if not os.path.exists(directory): - os.makedirs(directory, exist_ok=True) - logger.info(f"创建插件根目录: {directory}") - if directory not in self.plugin_directories: - self.plugin_directories.append(directory) - logger.debug(f"已添加插件根目录: {directory}") - else: - logger.warning(f"根目录不可重复加载: {directory}") - - # == 插件加载 == - - def _load_plugin_modules_from_directory(self, directory: str) -> tuple[int, int]: - """从指定目录加载插件模块""" - loaded_count = 0 - failed_count = 0 - - if not os.path.exists(directory): - logger.warning(f"插件根目录不存在: {directory}") - return 0, 1 - - logger.debug(f"正在扫描插件根目录: {directory}") - - # 遍历目录中的所有包 - for item in os.listdir(directory): - item_path = os.path.join(directory, item) - - if os.path.isdir(item_path) and not item.startswith(".") and not item.startswith("__"): - plugin_file = os.path.join(item_path, "plugin.py") - if os.path.exists(plugin_file): - if self._load_plugin_module_file(plugin_file): - loaded_count += 1 - else: - failed_count += 1 - - return loaded_count, failed_count - - def _load_plugin_module_file(self, plugin_file: str) -> bool: - # sourcery skip: extract-method - """加载单个插件模块文件 - - Args: - plugin_file: 插件文件路径 - plugin_name: 插件名称 - plugin_dir: 插件目录路径 - """ - # 生成模块名 - plugin_path = Path(plugin_file) - module_name = ".".join(plugin_path.parent.parts) - - try: - # 动态导入插件模块 - spec = spec_from_file_location(module_name, plugin_file) - if spec is None or spec.loader is None: - logger.error(f"无法创建模块规范: {plugin_file}") - return False - - module = module_from_spec(spec) - module.__package__ = module_name # 设置模块包名 - spec.loader.exec_module(module) - - logger.debug(f"插件模块加载成功: {plugin_file}") - return True - - except Exception as e: - error_msg = f"加载插件模块 {plugin_file} 失败: {e}" - logger.error(error_msg) - self.failed_plugins[module_name] = error_msg - return False - - # == 依赖解析与加载顺序 == - - def _extract_declared_dependencies(self, plugin_name: str, plugin_class: Type[PluginBase]) -> Set[str]: - """提取插件声明的依赖。 - - 兼容声明格式: - - list[str] - - list[dict],其中dict至少包含name键 - """ - dependencies: Set[str] = set() - raw_dependencies = getattr(plugin_class, "dependencies", []) - - # 兼容错误声明 - if isinstance(raw_dependencies, property): - logger.warning(f"插件 {plugin_name} 的 dependencies 未声明为类属性,将按无依赖处理") - return dependencies - if not isinstance(raw_dependencies, list): - logger.warning(f"插件 {plugin_name} 的 dependencies 不是列表,将按无依赖处理") - return dependencies - - for dependency in raw_dependencies: - dependency_name = "" - if isinstance(dependency, str): - dependency_name = dependency.strip() - elif isinstance(dependency, dict): - dependency_name = str(dependency.get("name", "")).strip() - else: - logger.warning(f"插件 {plugin_name} 包含不支持的依赖声明类型: {type(dependency)}") - continue - - if not dependency_name: - continue - if dependency_name == plugin_name: - logger.warning(f"插件 {plugin_name} 声明了对自身的依赖,已忽略") - continue - - dependencies.add(dependency_name) - - return dependencies - - def _build_plugin_dependency_graph(self) -> Tuple[Dict[str, Set[str]], Dict[str, Set[str]]]: - """构建依赖图并返回缺失依赖映射。""" - plugin_names = set(self.plugin_classes.keys()) - dependency_graph: Dict[str, Set[str]] = {name: set() for name in plugin_names} - missing_dependencies: Dict[str, Set[str]] = {name: set() for name in plugin_names} - - for plugin_name, plugin_class in self.plugin_classes.items(): - declared_dependencies = self._extract_declared_dependencies(plugin_name, plugin_class) - for dependency_name in declared_dependencies: - if dependency_name in plugin_names: - dependency_graph[plugin_name].add(dependency_name) - else: - missing_dependencies[plugin_name].add(dependency_name) - - return dependency_graph, missing_dependencies - - def _resolve_plugin_load_order(self, dependency_graph: Dict[str, Set[str]]) -> Tuple[List[str], Set[str]]: - """根据依赖图计算加载顺序,并检测循环依赖。""" - indegree: Dict[str, int] = { - plugin_name: len(dependencies) for plugin_name, dependencies in dependency_graph.items() - } - reverse_graph: Dict[str, Set[str]] = {plugin_name: set() for plugin_name in dependency_graph} - - for plugin_name, dependencies in dependency_graph.items(): - for dependency_name in dependencies: - reverse_graph[dependency_name].add(plugin_name) - - zero_indegree_queue = deque(sorted([name for name, degree in indegree.items() if degree == 0])) - load_order: List[str] = [] - - while zero_indegree_queue: - current_plugin = zero_indegree_queue.popleft() - load_order.append(current_plugin) - - for dependent_plugin in sorted(reverse_graph[current_plugin]): - indegree[dependent_plugin] -= 1 - if indegree[dependent_plugin] == 0: - zero_indegree_queue.append(dependent_plugin) - - cycle_plugins = {name for name, degree in indegree.items() if degree > 0} - if cycle_plugins: - logger.error(f"检测到循环依赖插件: {', '.join(sorted(cycle_plugins))}") - - return load_order, cycle_plugins - - # == 兼容性检查 == - - def _check_plugin_version_compatibility(self, plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]: - """检查插件版本兼容性 - - Args: - plugin_name: 插件名称 - manifest_data: manifest数据 - - Returns: - Tuple[bool, str]: (是否兼容, 错误信息) - """ - if "host_application" not in manifest_data: - return True, "" # 没有版本要求,默认兼容 - - host_app = manifest_data["host_application"] - if not isinstance(host_app, dict): - return True, "" - - min_version = host_app.get("min_version", "") - max_version = host_app.get("max_version", "") - - if not min_version and not max_version: - return True, "" # 没有版本要求,默认兼容 - - try: - current_version = VersionComparator.get_current_host_version() - is_compatible, error_msg = VersionComparator.is_version_in_range(current_version, min_version, max_version) - if not is_compatible: - return False, f"版本不兼容: {error_msg}" - logger.debug(f"插件 {plugin_name} 版本兼容性检查通过") - return True, "" - - except Exception as e: - logger.warning(f"插件 {plugin_name} 版本兼容性检查失败: {e}") - return False, f"插件 {plugin_name} 版本兼容性检查失败: {e}" # 检查失败时默认不允许加载 - - # == 显示统计与插件信息 == - - def _show_stats(self, total_registered: int, total_failed_registration: int): - # sourcery skip: low-code-quality - # 获取组件统计信息 - stats = component_registry.get_registry_stats() - action_count = stats.get("action_components", 0) - command_count = stats.get("command_components", 0) - tool_count = stats.get("tool_components", 0) - event_handler_count = stats.get("event_handlers", 0) - total_components = stats.get("total_components", 0) - - # 📋 显示插件加载总览 - if total_registered > 0: - logger.info("🎉 插件系统加载完成!") - logger.info( - f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, EventHandler: {event_handler_count})" - ) - - # 显示详细的插件列表 - logger.info("📋 已加载插件详情:") - for plugin_name in self.loaded_plugins.keys(): - if plugin_info := component_registry.get_plugin_info(plugin_name): - # 插件基本信息 - version_info = f"v{plugin_info.version}" if plugin_info.version else "" - author_info = f"by {plugin_info.author}" if plugin_info.author else "unknown" - license_info = f"[{plugin_info.license}]" if plugin_info.license else "" - info_parts = [part for part in [version_info, author_info, license_info] if part] - extra_info = f" ({', '.join(info_parts)})" if info_parts else "" - - logger.info(f" 📦 {plugin_info.display_name}{extra_info}") - - # Manifest信息 - if plugin_info.manifest_data: - """ - if plugin_info.keywords: - logger.info(f" 🏷️ 关键词: {', '.join(plugin_info.keywords)}") - if plugin_info.categories: - logger.info(f" 📁 分类: {', '.join(plugin_info.categories)}") - """ - if plugin_info.homepage_url: - logger.info(f" 🌐 主页: {plugin_info.homepage_url}") - - # 组件列表 - if plugin_info.components: - action_components = [ - c for c in plugin_info.components if c.component_type == ComponentType.ACTION - ] - command_components = [ - c for c in plugin_info.components if c.component_type == ComponentType.COMMAND - ] - tool_components = [c for c in plugin_info.components if c.component_type == ComponentType.TOOL] - event_handler_components = [ - c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER - ] - - if action_components: - action_names = [c.name for c in action_components] - logger.info(f" 🎯 Action组件: {', '.join(action_names)}") - - if command_components: - command_names = [c.name for c in command_components] - logger.info(f" ⚡ Command组件: {', '.join(command_names)}") - if tool_components: - tool_names = [c.name for c in tool_components] - logger.info(f" 🛠️ Tool组件: {', '.join(tool_names)}") - if event_handler_components: - event_handler_names = [c.name for c in event_handler_components] - logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_names)}") - - # 依赖信息 - if plugin_info.dependencies: - logger.info(f" 🔗 依赖: {', '.join(plugin_info.dependencies)}") - - # 配置文件信息 - if plugin_info.config_file: - config_status = "✅" if self.plugin_paths.get(plugin_name) else "❌" - logger.info(f" ⚙️ 配置: {plugin_info.config_file} {config_status}") - - root_path = Path(__file__) - - # 查找项目根目录 - while not (root_path / "pyproject.toml").exists() and root_path.parent != root_path: - root_path = root_path.parent - - # 显示目录统计 - logger.info("📂 加载目录统计:") - for directory in self.plugin_directories: - if os.path.exists(directory): - plugins_in_dir = [] - for plugin_name in self.loaded_plugins.keys(): - plugin_path = self.plugin_paths.get(plugin_name, "") - if ( - Path(plugin_path) - .resolve() - .is_relative_to(Path(os.path.join(str(root_path), directory)).resolve()) - ): - plugins_in_dir.append(plugin_name) - - if plugins_in_dir: - logger.info(f" 📁 {directory}: {len(plugins_in_dir)}个插件 ({', '.join(plugins_in_dir)})") - else: - logger.info(f" 📁 {directory}: 0个插件") - - # 失败信息 - if total_failed_registration > 0: - logger.info(f"⚠️ 失败统计: {total_failed_registration}个插件加载失败") - for failed_plugin, error in self.failed_plugins.items(): - logger.info(f" ❌ {failed_plugin}: {error}") - else: - logger.warning("😕 没有成功加载任何插件") - - def _show_plugin_components(self, plugin_name: str) -> None: - if plugin_info := component_registry.get_plugin_info(plugin_name): - component_types = {} - for comp in plugin_info.components: - comp_type = comp.component_type.name - component_types[comp_type] = component_types.get(comp_type, 0) + 1 - - components_str = ", ".join([f"{count}个{ctype}" for ctype, count in component_types.items()]) - - # 显示manifest信息 - manifest_info = "" - if plugin_info.license: - manifest_info += f" [{plugin_info.license}]" - if plugin_info.keywords: - manifest_info += f" 关键词: {', '.join(plugin_info.keywords[:3])}" # 只显示前3个关键词 - if len(plugin_info.keywords) > 3: - manifest_info += "..." - - logger.info( - f"✅ 插件加载成功: {plugin_name} v{plugin_info.version} ({components_str}){manifest_info} - {plugin_info.description}" - ) - else: - logger.info(f"✅ 插件加载成功: {plugin_name}") - - -# 全局插件管理器实例 -plugin_manager = PluginManager() diff --git a/src/plugin_system/core/plugin_service_registry.py b/src/plugin_system/core/plugin_service_registry.py deleted file mode 100644 index d94e5b3a..00000000 --- a/src/plugin_system/core/plugin_service_registry.py +++ /dev/null @@ -1,251 +0,0 @@ -from typing import Any, Callable, Dict, Optional -import inspect - -from src.common.logger import get_logger -from src.plugin_system.base.service_types import PluginServiceInfo - -logger = get_logger("plugin_service_registry") - - -class PluginServiceRegistry: - """插件服务注册中心""" - - def __init__(self): - self._services: Dict[str, PluginServiceInfo] = {} - self._service_handlers: Dict[str, Callable[..., Any]] = {} - logger.info("插件服务注册中心初始化完成") - - def register_service(self, service_info: PluginServiceInfo, service_handler: Callable[..., Any]) -> bool: - """注册插件服务。""" - if not service_info.name or not service_info.plugin_name: - logger.error("插件服务注册失败: service名称或插件名称为空") - return False - if "." in service_info.name: - logger.error(f"插件服务名称 '{service_info.name}' 包含非法字符 '.',请使用下划线替代") - return False - if "." in service_info.plugin_name: - logger.error(f"插件服务所属插件名称 '{service_info.plugin_name}' 包含非法字符 '.',请使用下划线替代") - return False - if invalid_callers := [caller for caller in service_info.allowed_callers if "." in caller]: - logger.error(f"插件服务白名单包含非法调用方名称: {invalid_callers}") - return False - - full_name = service_info.full_name - if full_name in self._services: - logger.warning(f"插件服务已存在,拒绝重复注册: {full_name}") - return False - - self._services[full_name] = service_info - self._service_handlers[full_name] = service_handler - logger.debug(f"已注册插件服务: {full_name} (version={service_info.version})") - return True - - def get_service(self, service_name: str, plugin_name: Optional[str] = None) -> Optional[PluginServiceInfo]: - """获取插件服务元信息。 - - service_name支持: - - full_name: plugin_name.service_name - - short_name: service_name(当唯一时可解析) - """ - full_name = self._resolve_full_name(service_name, plugin_name) - return self._services.get(full_name) if full_name else None - - def get_service_handler(self, service_name: str, plugin_name: Optional[str] = None) -> Optional[Callable[..., Any]]: - """获取插件服务处理函数。""" - full_name = self._resolve_full_name(service_name, plugin_name) - return self._service_handlers.get(full_name) if full_name else None - - def list_services( - self, plugin_name: Optional[str] = None, enabled_only: bool = False - ) -> Dict[str, PluginServiceInfo]: - """列出插件服务。""" - services = self._services.copy() - if plugin_name: - services = {name: info for name, info in services.items() if info.plugin_name == plugin_name} - if enabled_only: - services = {name: info for name, info in services.items() if info.enabled} - return services - - def enable_service(self, service_name: str, plugin_name: Optional[str] = None) -> bool: - """启用插件服务。""" - if not (service_info := self.get_service(service_name, plugin_name)): - logger.warning(f"插件服务未注册,无法启用: {service_name}") - return False - service_info.enabled = True - logger.info(f"插件服务已启用: {service_info.full_name}") - return True - - def disable_service(self, service_name: str, plugin_name: Optional[str] = None) -> bool: - """禁用插件服务。""" - if not (service_info := self.get_service(service_name, plugin_name)): - logger.warning(f"插件服务未注册,无法禁用: {service_name}") - return False - service_info.enabled = False - logger.info(f"插件服务已禁用: {service_info.full_name}") - return True - - def unregister_service(self, service_name: str, plugin_name: Optional[str] = None) -> bool: - """注销单个插件服务。""" - full_name = self._resolve_full_name(service_name, plugin_name) - if not full_name: - logger.warning(f"插件服务未注册,无法注销: {service_name}") - return False - - self._services.pop(full_name, None) - self._service_handlers.pop(full_name, None) - logger.info(f"插件服务已注销: {full_name}") - return True - - def remove_services_by_plugin(self, plugin_name: str) -> int: - """移除某插件的所有注册服务。""" - target_names = [full_name for full_name, info in self._services.items() if info.plugin_name == plugin_name] - for full_name in target_names: - self._services.pop(full_name, None) - self._service_handlers.pop(full_name, None) - - removed_count = len(target_names) - if removed_count: - logger.info(f"已移除插件 {plugin_name} 的服务数量: {removed_count}") - return removed_count - - async def call_service( - self, - service_name: str, - *args: Any, - plugin_name: Optional[str] = None, - caller_plugin: Optional[str] = None, - **kwargs: Any, - ) -> Any: - """调用插件服务(支持同步/异步handler)。""" - service_info = self.get_service(service_name, plugin_name) - if not service_info: - target_name = f"{plugin_name}.{service_name}" if plugin_name and "." not in service_name else service_name - raise ValueError(f"插件服务未注册: {target_name}") - - if ( - "." not in service_name - and plugin_name is None - and caller_plugin - and service_info.plugin_name != caller_plugin - ): - raise PermissionError("跨插件服务调用必须使用完整服务名或显式指定plugin_name") - - if not self._is_call_authorized(service_info, caller_plugin): - raise PermissionError( - f"调用被拒绝: caller={caller_plugin or 'anonymous'} 无权限访问服务 {service_info.full_name}" - ) - - if not service_info.enabled: - raise RuntimeError(f"插件服务已禁用: {service_info.full_name}") - - handler = self.get_service_handler(service_name, plugin_name) - if not handler: - raise RuntimeError(f"插件服务处理器不存在: {service_info.full_name}") - - self._validate_input_contract(service_info, args, kwargs) - - result = handler(*args, **kwargs) - resolved_result = await result if inspect.isawaitable(result) else result - self._validate_output_contract(service_info, resolved_result) - return resolved_result - - def _is_call_authorized(self, service_info: PluginServiceInfo, caller_plugin: Optional[str]) -> bool: - """检查服务调用权限。""" - if caller_plugin is None: - return service_info.public - if caller_plugin == service_info.plugin_name: - return True - if service_info.public: - return True - allowed_callers = {caller.strip() for caller in service_info.allowed_callers if caller.strip()} - return "*" in allowed_callers or caller_plugin in allowed_callers - - def _validate_input_contract( - self, service_info: PluginServiceInfo, args: tuple[Any, ...], kwargs: dict[str, Any] - ) -> None: - """校验服务入参契约。""" - schema = service_info.params_schema - if not schema: - return - - properties = schema.get("properties", {}) if isinstance(schema, dict) else {} - is_invocation_schema = "args" in properties or "kwargs" in properties - - if is_invocation_schema: - payload = {"args": list(args), "kwargs": kwargs} - self._validate_by_schema(payload, schema, path="params") - return - - if args: - raise ValueError(f"服务 {service_info.full_name} 的入参契约不允许位置参数") - self._validate_by_schema(kwargs, schema, path="params") - - def _validate_output_contract(self, service_info: PluginServiceInfo, value: Any) -> None: - """校验服务返回值契约。""" - if not service_info.return_schema: - return - self._validate_by_schema(value, service_info.return_schema, path="return") - - def _validate_by_schema(self, value: Any, schema: Dict[str, Any], path: str) -> None: - """基于简化JSON-Schema校验数据。""" - expected_type = schema.get("type") - if expected_type: - self._validate_type(value, expected_type, path) - - enum_values = schema.get("enum") - if enum_values is not None and value not in enum_values: - raise ValueError(f"{path} 不在枚举范围内: {value}") - - if expected_type == "object": - properties = schema.get("properties", {}) - required = schema.get("required", []) - - for field in required: - if field not in value: - raise ValueError(f"{path}.{field} 为必填字段") - - for field, field_value in value.items(): - if field in properties: - self._validate_by_schema(field_value, properties[field], f"{path}.{field}") - elif schema.get("additionalProperties", True) is False: - raise ValueError(f"{path}.{field} 不允许额外字段") - - if expected_type == "array": - if item_schema := schema.get("items"): - for index, item in enumerate(value): - self._validate_by_schema(item, item_schema, f"{path}[{index}]") - - def _validate_type(self, value: Any, expected_type: str, path: str) -> None: - """校验基础类型。""" - type_checkers: Dict[str, Callable[[Any], bool]] = { - "string": lambda item: isinstance(item, str), - "number": lambda item: isinstance(item, (int, float)) and not isinstance(item, bool), - "integer": lambda item: isinstance(item, int) and not isinstance(item, bool), - "boolean": lambda item: isinstance(item, bool), - "object": lambda item: isinstance(item, dict), - "array": lambda item: isinstance(item, list), - "null": lambda item: item is None, - } - checker = type_checkers.get(expected_type) - if checker and not checker(value): - raise TypeError(f"{path} 类型不匹配,期望 {expected_type},实际 {type(value).__name__}") - - def _resolve_full_name(self, service_name: str, plugin_name: Optional[str] = None) -> Optional[str]: - """解析服务全名。""" - if "." in service_name: - return service_name if service_name in self._services else None - - if plugin_name: - full_name = f"{plugin_name}.{service_name}" - return full_name if full_name in self._services else None - - candidates = [full_name for full_name, info in self._services.items() if info.name == service_name] - if len(candidates) == 1: - return candidates[0] - if len(candidates) > 1: - logger.warning(f"插件服务名称 '{service_name}' 存在多义,请传入plugin_name或使用完整服务名") - return None - return None - - -plugin_service_registry = PluginServiceRegistry() diff --git a/src/plugin_system/core/to_do_event.md b/src/plugin_system/core/to_do_event.md deleted file mode 100644 index dd7b9fab..00000000 --- a/src/plugin_system/core/to_do_event.md +++ /dev/null @@ -1,13 +0,0 @@ -- [x] 自定义事件 -- [ ] 允许handler随时订阅 -- [x] 允许其他组件给handler增加订阅 -- [x] 允许其他组件给handler取消订阅 -- [ ] 允许一个handler订阅多个事件 -- [x] event激活时给handler传递参数 -- [ ] handler能拿到所有handlers的结果(按照处理权重) -- [x] 随时注册 -- [ ] 删除event - - [ ] 必要性? -- [x] 能够更改prompt -- [x] 能够更改llm_response -- [x] 能够更改message \ No newline at end of file diff --git a/src/plugin_system/core/workflow_engine.py b/src/plugin_system/core/workflow_engine.py deleted file mode 100644 index 1a3fbcf7..00000000 --- a/src/plugin_system/core/workflow_engine.py +++ /dev/null @@ -1,260 +0,0 @@ -from typing import Any, Awaitable, Callable, List, Optional, Tuple, Union -import asyncio -import inspect -import time -import uuid - -from src.common.logger import get_logger -from src.plugin_system.base.component_types import EventType, MaiMessages -from src.plugin_system.base.workflow_errors import WorkflowErrorCode -from src.plugin_system.base.workflow_types import WorkflowContext, WorkflowStage, WorkflowStepResult - -logger = get_logger("workflow_engine") - - -class WorkflowEngine: - """线性Workflow执行器(MVP)""" - - STAGE_EVENT_SEQUENCE: List[Tuple[WorkflowStage, Union[EventType, str]]] = [ - (WorkflowStage.INGRESS, "workflow.ingress"), - (WorkflowStage.PRE_PROCESS, EventType.ON_MESSAGE_PRE_PROCESS), - (WorkflowStage.PLAN, EventType.ON_PLAN), - (WorkflowStage.TOOL_EXECUTE, "workflow.tool_execute"), - (WorkflowStage.POST_PROCESS, EventType.POST_SEND_PRE_PROCESS), - (WorkflowStage.EGRESS, "workflow.egress"), - ] - - def __init__(self): - self._execution_history: dict[str, dict[str, Any]] = {} - - async def execute_linear( - self, - dispatch_event: Callable[ - [Union[EventType, str], Optional[MaiMessages], Optional[str], Optional[List[str]]], - Awaitable[Tuple[bool, Optional[MaiMessages]]], - ], - message: Optional[MaiMessages] = None, - stream_id: Optional[str] = None, - action_usage: Optional[List[str]] = None, - context: Optional[WorkflowContext] = None, - ) -> Tuple[WorkflowStepResult, Optional[MaiMessages], WorkflowContext]: - """执行线性workflow。""" - workflow_context = context or WorkflowContext(trace_id=uuid.uuid4().hex, stream_id=stream_id) - current_message = message.deepcopy() if message else None - self._execution_history[workflow_context.trace_id] = { - "trace_id": workflow_context.trace_id, - "stream_id": workflow_context.stream_id, - "stages": [], - "errors": [], - "status": "running", - } - - for stage, event_type in self.STAGE_EVENT_SEQUENCE: - stage_key = str(stage) - stage_start = time.perf_counter() - try: - should_continue, modified_message = await dispatch_event( - event_type, - current_message, - workflow_context.stream_id, - action_usage, - ) - workflow_context.timings[stage_key] = time.perf_counter() - stage_start - self._execution_history[workflow_context.trace_id]["stages"].append( - { - "stage": stage_key, - "event_type": str(event_type), - "event_continue": should_continue, - "event_cost": workflow_context.timings[stage_key], - } - ) - - if modified_message: - current_message = modified_message - - if not should_continue: - logger.info(f"[trace_id={workflow_context.trace_id}] Workflow在阶段 {stage_key} 被中断") - return ( - WorkflowStepResult( - status="stop", - return_message=f"workflow stopped at stage {stage_key}", - diagnostics={ - "stage": stage_key, - "trace_id": workflow_context.trace_id, - "error_code": WorkflowErrorCode.POLICY_BLOCKED.value, - }, - ), - current_message, - workflow_context, - ) - - step_result = await self._execute_registered_steps(stage, workflow_context, current_message) - if step_result.status in ["stop", "failed"]: - self._execution_history[workflow_context.trace_id]["status"] = step_result.status - self._execution_history[workflow_context.trace_id]["errors"] = workflow_context.errors.copy() - return step_result, current_message, workflow_context - except Exception as e: - workflow_context.timings[stage_key] = time.perf_counter() - stage_start - workflow_context.errors.append(f"{stage_key}: {e}") - logger.error( - f"[trace_id={workflow_context.trace_id}] Workflow阶段 {stage_key} 执行异常: {e}", exc_info=True - ) - self._execution_history[workflow_context.trace_id]["status"] = "failed" - self._execution_history[workflow_context.trace_id]["errors"] = workflow_context.errors.copy() - return ( - WorkflowStepResult( - status="failed", - return_message=str(e), - diagnostics={ - "stage": stage_key, - "trace_id": workflow_context.trace_id, - "error_code": WorkflowErrorCode.DOWNSTREAM_FAILED.value, - }, - ), - current_message, - workflow_context, - ) - - self._execution_history[workflow_context.trace_id]["status"] = "continue" - self._execution_history[workflow_context.trace_id]["errors"] = workflow_context.errors.copy() - return ( - WorkflowStepResult( - status="continue", - return_message="workflow completed", - diagnostics={"trace_id": workflow_context.trace_id}, - ), - current_message, - workflow_context, - ) - - async def _execute_registered_steps( - self, - stage: WorkflowStage, - context: WorkflowContext, - message: Optional[MaiMessages], - ) -> WorkflowStepResult: - """执行指定阶段已注册的workflow步骤。""" - from src.plugin_system.core.component_registry import component_registry - - stage_steps = component_registry.get_steps_by_stage(stage, enabled_only=True) - sorted_steps = sorted(stage_steps.values(), key=lambda step_info: step_info.priority, reverse=True) - - for step_info in sorted_steps: - handler = component_registry.get_workflow_step_handler(step_info.full_name, stage) - if not handler: - context.errors.append(f"{step_info.full_name}: handler not found") - continue - - step_timing_key = f"{stage.value}:{step_info.full_name}" - step_start = time.perf_counter() - timeout_seconds = step_info.timeout_ms / 1000 if step_info.timeout_ms > 0 else None - - try: - if inspect.iscoroutinefunction(handler): - coroutine = handler(context, message) - result = await asyncio.wait_for(coroutine, timeout_seconds) if timeout_seconds else await coroutine - else: - if timeout_seconds: - result = await asyncio.wait_for(asyncio.to_thread(handler, context, message), timeout_seconds) - else: - result = handler(context, message) - if inspect.isawaitable(result): - result = await asyncio.wait_for(result, timeout_seconds) if timeout_seconds else await result - context.timings[step_timing_key] = time.perf_counter() - step_start - - normalized_result = self._normalize_step_result(result) - if normalized_result.status == "continue": - continue - - normalized_result.diagnostics.setdefault("stage", stage.value) - normalized_result.diagnostics.setdefault("step", step_info.full_name) - normalized_result.diagnostics.setdefault("trace_id", context.trace_id) - if normalized_result.status == "failed": - context.errors.append( - f"{step_info.full_name}: {normalized_result.return_message or 'workflow step failed'}" - ) - normalized_result.diagnostics.setdefault("error_code", WorkflowErrorCode.DOWNSTREAM_FAILED.value) - return normalized_result - - except asyncio.TimeoutError: - context.timings[step_timing_key] = time.perf_counter() - step_start - timeout_message = f"workflow step timeout after {step_info.timeout_ms}ms" - context.errors.append(f"{step_info.full_name}: {timeout_message}") - logger.error( - f"[trace_id={context.trace_id}] Workflow step {step_info.full_name} 超时: {timeout_message}" - ) - return WorkflowStepResult( - status="failed", - return_message=timeout_message, - diagnostics={ - "stage": stage.value, - "step": step_info.full_name, - "trace_id": context.trace_id, - "error_code": WorkflowErrorCode.STEP_TIMEOUT.value, - }, - ) - - except Exception as e: - context.timings[step_timing_key] = time.perf_counter() - step_start - context.errors.append(f"{step_info.full_name}: {e}") - logger.error( - f"[trace_id={context.trace_id}] Workflow step {step_info.full_name} 执行异常: {e}", exc_info=True - ) - return WorkflowStepResult( - status="failed", - return_message=str(e), - diagnostics={ - "stage": stage.value, - "step": step_info.full_name, - "trace_id": context.trace_id, - "error_code": WorkflowErrorCode.DOWNSTREAM_FAILED.value, - }, - ) - - return WorkflowStepResult(status="continue", diagnostics={"stage": stage.value, "trace_id": context.trace_id}) - - def _normalize_step_result(self, result: Any) -> WorkflowStepResult: - """归一化workflow step返回值。""" - if isinstance(result, WorkflowStepResult): - return result - if isinstance(result, bool): - if result: - return WorkflowStepResult(status="continue") - return WorkflowStepResult( - status="failed", - diagnostics={"error_code": WorkflowErrorCode.DOWNSTREAM_FAILED.value}, - ) - if result is None: - return WorkflowStepResult(status="continue") - if isinstance(result, str): - return WorkflowStepResult(status="continue", return_message=result) - if isinstance(result, dict): - status = result.get("status", "continue") - if status not in ["continue", "stop", "failed"]: - status = "failed" - return WorkflowStepResult( - status=status, - return_message=result.get("return_message"), - diagnostics=result.get("diagnostics", {}), - events=result.get("events", []), - ) - return WorkflowStepResult( - status="failed", - return_message=f"unsupported step result type: {type(result)}", - diagnostics={"error_code": WorkflowErrorCode.BAD_PAYLOAD.value}, - ) - - def get_execution_trace(self, trace_id: str) -> Optional[dict[str, Any]]: - """按trace_id获取workflow执行路径。""" - trace = self._execution_history.get(trace_id) - return trace.copy() if trace else None - - def clear_execution_trace(self, trace_id: str) -> bool: - """清理trace执行记录。""" - if trace_id in self._execution_history: - del self._execution_history[trace_id] - return True - return False - - -workflow_engine = WorkflowEngine() diff --git a/src/plugin_system/utils/__init__.py b/src/plugin_system/utils/__init__.py deleted file mode 100644 index bf49e3fa..00000000 --- a/src/plugin_system/utils/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -""" -插件系统工具模块 - -提供插件开发和管理的实用工具 -""" - -from .manifest_utils import ( - ManifestValidator, - # ManifestGenerator, - # validate_plugin_manifest, - # generate_plugin_manifest, -) - -__all__ = [ - "ManifestValidator", - # "ManifestGenerator", - # "validate_plugin_manifest", - # "generate_plugin_manifest", -] diff --git a/src/plugin_system/utils/manifest_utils.py b/src/plugin_system/utils/manifest_utils.py deleted file mode 100644 index d070b733..00000000 --- a/src/plugin_system/utils/manifest_utils.py +++ /dev/null @@ -1,515 +0,0 @@ -""" -插件Manifest工具模块 - -提供manifest文件的验证、生成和管理功能 -""" - -import re -from typing import Dict, Any, Tuple -from src.common.logger import get_logger -from src.config.config import MMC_VERSION - -# if TYPE_CHECKING: -# from src.plugin_system.base.base_plugin import BasePlugin - -logger = get_logger("manifest_utils") - - -class VersionComparator: - """版本号比较器 - - 支持语义化版本号比较,自动处理snapshot版本,并支持向前兼容性检查 - """ - - # 版本兼容性映射表(硬编码) - # 格式: {插件最大支持版本: [实际兼容的版本列表]} - COMPATIBILITY_MAP = { - # 0.8.x 系列向前兼容规则 - "0.8.0": ["0.8.1", "0.8.2", "0.8.3", "0.8.4", "0.8.5", "0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"], - "0.8.1": ["0.8.2", "0.8.3", "0.8.4", "0.8.5", "0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"], - "0.8.2": ["0.8.3", "0.8.4", "0.8.5", "0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"], - "0.8.3": ["0.8.4", "0.8.5", "0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"], - "0.8.4": ["0.8.5", "0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"], - "0.8.5": ["0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"], - "0.8.6": ["0.8.7", "0.8.8", "0.8.9", "0.8.10"], - "0.8.7": ["0.8.8", "0.8.9", "0.8.10"], - "0.8.8": ["0.8.9", "0.8.10"], - "0.8.9": ["0.8.10"], - # 可以根据需要添加更多兼容映射 - # "0.9.0": ["0.9.1", "0.9.2", "0.9.3"], # 示例:0.9.x系列兼容 - } - - @staticmethod - def normalize_version(version: str) -> str: - """标准化版本号,移除snapshot标识 - - Args: - version: 原始版本号,如 "0.8.0-snapshot.1" - - Returns: - str: 标准化后的版本号,如 "0.8.0" - """ - if not version: - return "0.0.0" - - # 移除snapshot部分 - normalized = re.sub(r"-snapshot\.\d+", "", version.strip()) - - # 确保版本号格式正确 - if not re.match(r"^\d+(\.\d+){0,2}$", normalized): - # 如果不是有效的版本号格式,返回默认版本 - return "0.0.0" - - # 尝试补全版本号 - parts = normalized.split(".") - while len(parts) < 3: - parts.append("0") - normalized = ".".join(parts[:3]) - - return normalized - - @staticmethod - def parse_version(version: str) -> Tuple[int, int, int]: - """解析版本号为元组 - - Args: - version: 版本号字符串 - - Returns: - Tuple[int, int, int]: (major, minor, patch) - """ - normalized = VersionComparator.normalize_version(version) - try: - parts = normalized.split(".") - return (int(parts[0]), int(parts[1]), int(parts[2])) - except (ValueError, IndexError): - logger.warning(f"无法解析版本号: {version},使用默认版本 0.0.0") - return (0, 0, 0) - - @staticmethod - def compare_versions(version1: str, version2: str) -> int: - """比较两个版本号 - - Args: - version1: 第一个版本号 - version2: 第二个版本号 - - Returns: - int: -1 if version1 < version2, 0 if equal, 1 if version1 > version2 - """ - v1_tuple = VersionComparator.parse_version(version1) - v2_tuple = VersionComparator.parse_version(version2) - - if v1_tuple < v2_tuple: - return -1 - elif v1_tuple > v2_tuple: - return 1 - else: - return 0 - - @staticmethod - def check_forward_compatibility(current_version: str, max_version: str) -> Tuple[bool, str]: - """检查向前兼容性(仅使用兼容性映射表) - - Args: - current_version: 当前版本 - max_version: 插件声明的最大支持版本 - - Returns: - Tuple[bool, str]: (是否兼容, 兼容信息) - """ - current_normalized = VersionComparator.normalize_version(current_version) - max_normalized = VersionComparator.normalize_version(max_version) - - # 检查兼容性映射表 - if max_normalized in VersionComparator.COMPATIBILITY_MAP: - compatible_versions = VersionComparator.COMPATIBILITY_MAP[max_normalized] - if current_normalized in compatible_versions: - return True, f"根据兼容性映射表,版本 {current_normalized} 与 {max_normalized} 兼容" - - return False, "" - - @staticmethod - def is_version_in_range(version: str, min_version: str = "", max_version: str = "") -> Tuple[bool, str]: - """检查版本是否在指定范围内,支持兼容性检查 - - Args: - version: 要检查的版本号 - min_version: 最小版本号(可选) - max_version: 最大版本号(可选) - - Returns: - Tuple[bool, str]: (是否兼容, 错误信息或兼容信息) - """ - if not min_version and not max_version: - return True, "" - - version_normalized = VersionComparator.normalize_version(version) - - # 检查最小版本 - if min_version: - min_normalized = VersionComparator.normalize_version(min_version) - if VersionComparator.compare_versions(version_normalized, min_normalized) < 0: - return False, f"版本 {version_normalized} 低于最小要求版本 {min_normalized}" - - # 检查最大版本 - if max_version: - max_normalized = VersionComparator.normalize_version(max_version) - comparison = VersionComparator.compare_versions(version_normalized, max_normalized) - - if comparison > 0: - # 严格版本检查失败,尝试兼容性检查 - is_compatible, compat_msg = VersionComparator.check_forward_compatibility( - version_normalized, max_normalized - ) - - if not is_compatible: - return False, f"版本 {version_normalized} 高于最大支持版本 {max_normalized},且无兼容性映射" - - logger.info(f"版本兼容性检查:{compat_msg}") - return True, compat_msg - return True, "" - - @staticmethod - def get_current_host_version() -> str: - """获取当前主机应用版本 - - Returns: - str: 当前版本号 - """ - return VersionComparator.normalize_version(MMC_VERSION) - - @staticmethod - def add_compatibility_mapping(base_version: str, compatible_versions: list) -> None: - """动态添加兼容性映射 - - Args: - base_version: 基础版本(插件声明的最大支持版本) - compatible_versions: 兼容的版本列表 - """ - base_normalized = VersionComparator.normalize_version(base_version) - VersionComparator.COMPATIBILITY_MAP[base_normalized] = [ - VersionComparator.normalize_version(v) for v in compatible_versions - ] - logger.info(f"添加兼容性映射:{base_normalized} -> {compatible_versions}") - - @staticmethod - def get_compatibility_info() -> Dict[str, list]: - """获取当前的兼容性映射表 - - Returns: - Dict[str, list]: 兼容性映射表的副本 - """ - return VersionComparator.COMPATIBILITY_MAP.copy() - - -class ManifestValidator: - """Manifest文件验证器""" - - # 必需字段(必须存在且不能为空) - REQUIRED_FIELDS = ["manifest_version", "name", "version", "description", "author"] - - # 可选字段(可以不存在或为空) - OPTIONAL_FIELDS = [ - "license", - "host_application", - "homepage_url", - "repository_url", - "keywords", - "categories", - "default_locale", - "locales_path", - "plugin_info", - ] - - # 建议填写的字段(会给出警告但不会导致验证失败) - RECOMMENDED_FIELDS = ["license", "keywords", "categories"] - - SUPPORTED_MANIFEST_VERSIONS = [1] - - def __init__(self): - self.validation_errors = [] - self.validation_warnings = [] - - def validate_manifest(self, manifest_data: Dict[str, Any]) -> bool: - """验证manifest数据 - - Args: - manifest_data: manifest数据字典 - - Returns: - bool: 是否验证通过(只有错误会导致验证失败,警告不会) - """ - self.validation_errors.clear() - self.validation_warnings.clear() - - # 检查必需字段 - for field in self.REQUIRED_FIELDS: - if field not in manifest_data: - self.validation_errors.append(f"缺少必需字段: {field}") - elif not manifest_data[field]: - self.validation_errors.append(f"必需字段不能为空: {field}") - - # 检查manifest版本 - if "manifest_version" in manifest_data: - version = manifest_data["manifest_version"] - if version not in self.SUPPORTED_MANIFEST_VERSIONS: - self.validation_errors.append( - f"不支持的manifest版本: {version},支持的版本: {self.SUPPORTED_MANIFEST_VERSIONS}" - ) - - # 检查作者信息格式 - if "author" in manifest_data: - author = manifest_data["author"] - if isinstance(author, dict): - if "name" not in author or not author["name"]: - self.validation_errors.append("作者信息缺少name字段或为空") - # url字段是可选的 - if "url" in author and author["url"]: - url = author["url"] - if not (url.startswith("http://") or url.startswith("https://")): - self.validation_warnings.append("作者URL建议使用完整的URL格式") - elif isinstance(author, str): - if not author.strip(): - self.validation_errors.append("作者信息不能为空") - else: - self.validation_errors.append("作者信息格式错误,应为字符串或包含name字段的对象") - # 检查主机应用版本要求(可选) - if "host_application" in manifest_data: - host_app = manifest_data["host_application"] - if isinstance(host_app, dict): - min_version = host_app.get("min_version", "") - max_version = host_app.get("max_version", "") - - # 验证版本字段格式 - for version_field in ["min_version", "max_version"]: - if version_field in host_app and not host_app[version_field]: - self.validation_warnings.append(f"host_application.{version_field}为空") - - # 检查当前主机版本兼容性 - if min_version or max_version: - current_version = VersionComparator.get_current_host_version() - is_compatible, error_msg = VersionComparator.is_version_in_range( - current_version, min_version, max_version - ) - - if not is_compatible: - self.validation_errors.append(f"版本兼容性检查失败: {error_msg} (当前版本: {current_version})") - else: - logger.debug( - f"版本兼容性检查通过: 当前版本 {current_version} 符合要求 [{min_version}, {max_version}]" - ) - else: - self.validation_errors.append("host_application格式错误,应为对象") - - # 检查URL格式(可选字段) - for url_field in ["homepage_url", "repository_url"]: - if url_field in manifest_data and manifest_data[url_field]: - url: str = manifest_data[url_field] - if not (url.startswith("http://") or url.startswith("https://")): - self.validation_warnings.append(f"{url_field}建议使用完整的URL格式") - - # 检查数组字段格式(可选字段) - for list_field in ["keywords", "categories"]: - if list_field in manifest_data: - field_value = manifest_data[list_field] - if field_value is not None and not isinstance(field_value, list): - self.validation_errors.append(f"{list_field}应为数组格式") - elif isinstance(field_value, list): - # 检查数组元素是否为字符串 - for i, item in enumerate(field_value): - if not isinstance(item, str): - self.validation_warnings.append(f"{list_field}[{i}]应为字符串") - - # 检查建议字段(给出警告) - for field in self.RECOMMENDED_FIELDS: - if field not in manifest_data or not manifest_data[field]: - self.validation_warnings.append(f"建议填写字段: {field}") - - # 检查plugin_info结构(可选) - if "plugin_info" in manifest_data: - plugin_info = manifest_data["plugin_info"] - if isinstance(plugin_info, dict): - # 检查components数组 - if "components" in plugin_info: - components = plugin_info["components"] - if not isinstance(components, list): - self.validation_errors.append("plugin_info.components应为数组格式") - else: - for i, component in enumerate(components): - if not isinstance(component, dict): - self.validation_errors.append(f"plugin_info.components[{i}]应为对象") - else: - # 检查组件必需字段 - for comp_field in ["type", "name", "description"]: - if comp_field not in component or not component[comp_field]: - self.validation_errors.append( - f"plugin_info.components[{i}]缺少必需字段: {comp_field}" - ) - else: - self.validation_errors.append("plugin_info应为对象格式") - - return len(self.validation_errors) == 0 - - def get_validation_report(self) -> str: - """获取验证报告""" - report = [] - - if self.validation_errors: - report.append("❌ 验证错误:") - report.extend(f" - {error}" for error in self.validation_errors) - if self.validation_warnings: - report.append("⚠️ 验证警告:") - report.extend(f" - {warning}" for warning in self.validation_warnings) - if not self.validation_errors and not self.validation_warnings: - report.append("✅ Manifest文件验证通过") - - return "\n".join(report) - - -# class ManifestGenerator: -# """Manifest文件生成器""" - -# def __init__(self): -# self.template = { -# "manifest_version": 1, -# "name": "", -# "version": "1.0.0", -# "description": "", -# "author": {"name": "", "url": ""}, -# "license": "MIT", -# "host_application": {"min_version": "1.0.0", "max_version": "4.0.0"}, -# "homepage_url": "", -# "repository_url": "", -# "keywords": [], -# "categories": [], -# "default_locale": "zh-CN", -# "locales_path": "_locales", -# } - -# def generate_from_plugin(self, plugin_instance: BasePlugin) -> Dict[str, Any]: -# """从插件实例生成manifest - -# Args: -# plugin_instance: BasePlugin实例 - -# Returns: -# Dict[str, Any]: 生成的manifest数据 -# """ -# manifest = self.template.copy() - -# # 基本信息 -# manifest["name"] = plugin_instance.plugin_name -# manifest["version"] = plugin_instance.plugin_version -# manifest["description"] = plugin_instance.plugin_description - -# # 作者信息 -# if plugin_instance.plugin_author: -# manifest["author"]["name"] = plugin_instance.plugin_author - -# # 组件信息 -# components = [] -# plugin_components = plugin_instance.get_plugin_components() - -# for component_info, component_class in plugin_components: -# component_data: Dict[str, Any] = { -# "type": component_info.component_type.value, -# "name": component_info.name, -# "description": component_info.description, -# } - -# # 添加激活模式信息(对于Action组件) -# if hasattr(component_class, "focus_activation_type"): -# activation_modes = [] -# if hasattr(component_class, "focus_activation_type"): -# activation_modes.append(component_class.focus_activation_type.value) -# if hasattr(component_class, "normal_activation_type"): -# activation_modes.append(component_class.normal_activation_type.value) -# component_data["activation_modes"] = list(set(activation_modes)) - -# # 添加关键词信息 -# if hasattr(component_class, "activation_keywords"): -# keywords = getattr(component_class, "activation_keywords", []) -# if keywords: -# component_data["keywords"] = keywords - -# components.append(component_data) - -# manifest["plugin_info"] = {"is_built_in": True, "plugin_type": "general", "components": components} - -# return manifest - -# def save_manifest(self, manifest_data: Dict[str, Any], plugin_dir: str) -> bool: -# """保存manifest文件 - -# Args: -# manifest_data: manifest数据 -# plugin_dir: 插件目录 - -# Returns: -# bool: 是否保存成功 -# """ -# try: -# manifest_path = os.path.join(plugin_dir, "_manifest.json") -# with open(manifest_path, "w", encoding="utf-8") as f: -# json.dump(manifest_data, f, ensure_ascii=False, indent=2) -# logger.info(f"Manifest文件已保存: {manifest_path}") -# return True -# except Exception as e: -# logger.error(f"保存manifest文件失败: {e}") -# return False - - -# def validate_plugin_manifest(plugin_dir: str) -> bool: -# """验证插件目录中的manifest文件 - -# Args: -# plugin_dir: 插件目录路径 - -# Returns: -# bool: 是否验证通过 -# """ -# manifest_path = os.path.join(plugin_dir, "_manifest.json") - -# if not os.path.exists(manifest_path): -# logger.warning(f"未找到manifest文件: {manifest_path}") -# return False - -# try: -# with open(manifest_path, "r", encoding="utf-8") as f: -# manifest_data = json.load(f) - -# validator = ManifestValidator() -# is_valid = validator.validate_manifest(manifest_data) - -# logger.info(f"Manifest验证结果:\n{validator.get_validation_report()}") - -# return is_valid - -# except Exception as e: -# logger.error(f"读取或验证manifest文件失败: {e}") -# return False - - -# def generate_plugin_manifest(plugin_instance: BasePlugin, save_to_file: bool = True) -> Optional[Dict[str, Any]]: -# """为插件生成manifest文件 - -# Args: -# plugin_instance: BasePlugin实例 -# save_to_file: 是否保存到文件 - -# Returns: -# Optional[Dict[str, Any]]: 生成的manifest数据 -# """ -# try: -# generator = ManifestGenerator() -# manifest_data = generator.generate_from_plugin(plugin_instance) - -# if save_to_file and plugin_instance.plugin_dir: -# generator.save_manifest(manifest_data, plugin_instance.plugin_dir) - -# return manifest_data - -# except Exception as e: -# logger.error(f"生成manifest文件失败: {e}") -# return None diff --git a/src/plugins/built_in/emoji_plugin/_manifest.json b/src/plugins/built_in/emoji_plugin/_manifest.json index 33fce7cb..d4d262e7 100644 --- a/src/plugins/built_in/emoji_plugin/_manifest.json +++ b/src/plugins/built_in/emoji_plugin/_manifest.json @@ -1,34 +1,38 @@ { "manifest_version": 1, "name": "Emoji插件 (Emoji Actions)", - "version": "1.0.0", + "version": "2.0.0", "description": "可以发送和管理Emoji", "author": { "name": "SengokuCola", "url": "https://github.com/MaiM-with-u" }, "license": "GPL-v3.0-or-later", - "host_application": { - "min_version": "0.10.0" + "min_version": "1.0.0" }, "homepage_url": "https://github.com/MaiM-with-u/maibot", "repository_url": "https://github.com/MaiM-with-u/maibot", "keywords": ["emoji", "action", "built-in"], "categories": ["Emoji"], - "default_locale": "zh-CN", - "locales_path": "_locales", - "plugin_info": { "is_built_in": true, "plugin_type": "action_provider", "components": [ { "type": "action", - "name": "emoji", + "name": "emoji", "description": "发送表情包辅助表达情绪" } ] - } + }, + "capabilities": [ + "emoji.get_random", + "message.get_recent", + "message.build_readable", + "llm.generate", + "send.emoji", + "config.get" + ] } diff --git a/src/plugins/built_in/emoji_plugin/emoji.py b/src/plugins/built_in/emoji_plugin/emoji.py deleted file mode 100644 index 8599620a..00000000 --- a/src/plugins/built_in/emoji_plugin/emoji.py +++ /dev/null @@ -1,151 +0,0 @@ -import random -from typing import Tuple - -# 导入新插件系统 -from src.plugin_system import BaseAction, ActionActivationType - -# 导入依赖的系统组件 -from src.common.logger import get_logger - -# 导入API模块 - 标准Python包方式 -from src.plugin_system.apis import emoji_api, llm_api, message_api - -# NoReplyAction已集成到heartFC_chat.py中,不再需要导入 -from src.config.config import global_config - - -logger = get_logger("emoji") - - -class EmojiAction(BaseAction): - """表情动作 - 发送表情包""" - - activation_type = ActionActivationType.RANDOM - random_activation_probability = global_config.emoji.emoji_chance - parallel_action = True - - # 动作基本信息 - action_name = "emoji" - action_description = "发送表情包辅助表达情绪" - - # 动作参数定义 - action_parameters = {} - - # 动作使用场景 - action_require = [ - "发送表情包辅助表达情绪", - "表达情绪时可以选择使用", - "不要连续发送,如果你已经发过[表情包],就不要选择此动作", - ] - - # 关联类型 - associated_types = ["emoji"] - - async def execute(self) -> Tuple[bool, str]: - # sourcery skip: assign-if-exp, introduce-default-else, swap-if-else-branches, use-named-expression - """执行表情动作""" - try: - # 1. 获取发送表情的原因 - # reason = self.action_data.get("reason", "表达当前情绪") - reason = self.action_reasoning - - # 2. 随机获取20个表情包 - sampled_emojis = await emoji_api.get_random(30) - if not sampled_emojis: - logger.warning(f"{self.log_prefix} 无法获取随机表情包") - return False, "无法获取随机表情包" - - # 3. 准备情感数据 - emotion_map = {} - for b64, desc, emo in sampled_emojis: - if emo not in emotion_map: - emotion_map[emo] = [] - emotion_map[emo].append((b64, desc)) - - available_emotions = list(emotion_map.keys()) - available_emotions_str = "" - for emotion in available_emotions: - available_emotions_str += f"{emotion}\n" - - if not available_emotions: - logger.warning(f"{self.log_prefix} 获取到的表情包均无情感标签, 将随机发送") - emoji_base64, emoji_description, _ = random.choice(sampled_emojis) - else: - # 获取最近的5条消息内容用于判断 - recent_messages = message_api.get_recent_messages(chat_id=self.chat_id, limit=5) - messages_text = "" - if recent_messages: - # 使用message_api构建可读的消息字符串 - messages_text = message_api.build_readable_messages( - messages=recent_messages, - timestamp_mode="normal_no_YMD", - truncate=False, - show_actions=False, - ) - - # 4. 构建prompt让LLM选择情感 - prompt = f"""你正在进行QQ聊天,你需要根据聊天记录,选出一个合适的情感标签。 -请你根据以下原因和聊天记录进行选择 -原因:{reason} -聊天记录: -{messages_text} - -这里是可用的情感标签: -{available_emotions_str} -请直接返回最匹配的那个情感标签,不要进行任何解释或添加其他多余的文字。 - """ - - if global_config.debug.show_prompt: - logger.info(f"{self.log_prefix} 生成的LLM Prompt: {prompt}") - else: - logger.debug(f"{self.log_prefix} 生成的LLM Prompt: {prompt}") - - # 5. 调用LLM - models = llm_api.get_available_models() - chat_model_config = models.get("utils") # 使用字典访问方式 - if not chat_model_config: - logger.error(f"{self.log_prefix} 未找到'utils'模型配置,无法调用LLM") - return False, "未找到'utils'模型配置" - - success, chosen_emotion, _, _ = await llm_api.generate_with_model( - prompt, model_config=chat_model_config, request_type="emoji.select" - ) - - if not success: - logger.error(f"{self.log_prefix} LLM调用失败: {chosen_emotion}") - return False, f"LLM调用失败: {chosen_emotion}" - - chosen_emotion = chosen_emotion.strip().replace('"', "").replace("'", "") - logger.info(f"{self.log_prefix} LLM选择的情感: {chosen_emotion}") - - # 6. 根据选择的情感匹配表情包 - if chosen_emotion in emotion_map: - emoji_base64, emoji_description = random.choice(emotion_map[chosen_emotion]) - logger.info(f"{self.log_prefix} 发送表情包[{chosen_emotion}],原因: {reason}") - else: - logger.warning( - f"{self.log_prefix} LLM选择的情感 '{chosen_emotion}' 不在可用列表中, 将随机选择一个表情包" - ) - emoji_base64, emoji_description, _ = random.choice(sampled_emojis) - - # 7. 发送表情包 - success = await self.send_emoji(emoji_base64) - - if success: - # 存储动作信息 - await self.store_action_info( - action_build_into_prompt=True, - action_prompt_display=f"你发送了表情包,原因:{reason}", - action_done=True, - ) - return True, f"成功发送表情包:[表情包:{chosen_emotion}]" - else: - error_msg = "发送表情包失败" - logger.error(f"{self.log_prefix} {error_msg}") - - await self.send_text("执行表情包动作失败") - return False, error_msg - - except Exception as e: - logger.error(f"{self.log_prefix} 表情动作执行失败: {e}", exc_info=True) - return False, f"表情发送失败: {str(e)}" diff --git a/src/plugins/built_in/emoji_plugin/plugin.py b/src/plugins/built_in/emoji_plugin/plugin.py index 98b1e01a..5b5c7b93 100644 --- a/src/plugins/built_in/emoji_plugin/plugin.py +++ b/src/plugins/built_in/emoji_plugin/plugin.py @@ -1,66 +1,116 @@ -""" -核心动作插件 +"""Emoji 插件 — 新 SDK 版本 -将系统核心动作(reply、no_reply、emoji)转换为新插件系统格式 -这是系统的内置插件,提供基础的聊天交互功能 +根据聊天上下文的情感,使用 LLM 选择并发送合适的表情包。 """ -from typing import List, Tuple, Type +import random -# 导入新插件系统 -from src.plugin_system import BasePlugin, register_plugin, ComponentInfo -from src.plugin_system.base.config_types import ConfigField - -# 导入依赖的系统组件 -from src.common.logger import get_logger - -from src.plugins.built_in.emoji_plugin.emoji import EmojiAction - -logger = get_logger("core_actions") +from maibot_sdk import MaiBotPlugin, Action +from maibot_sdk.types import ActivationType -@register_plugin -class CoreActionsPlugin(BasePlugin): - """核心动作插件 +class EmojiPlugin(MaiBotPlugin): + """表情包插件""" - 系统内置插件,提供基础的聊天交互功能: - - Reply: 回复动作 - - NoReply: 不回复动作 - - Emoji: 表情动作 + @Action( + "emoji", + description="发送表情包辅助表达情绪", + activation_type=ActivationType.RANDOM, + activation_probability=0.3, + parallel_action=True, + action_require=[ + "发送表情包辅助表达情绪", + "表达情绪时可以选择使用", + "不要连续发送,如果你已经发过[表情包],就不要选择此动作", + ], + associated_types=["emoji"], + ) + async def handle_emoji(self, stream_id: str = "", reasoning: str = "", chat_id: str = "", **kwargs): + """执行表情动作""" + reason = reasoning or "表达当前情绪" - 注意:插件基本信息优先从_manifest.json文件中读取 - """ + # 1. 随机获取30个表情包 + result = await self.ctx.emoji.get_random(30) + if not result or not result.get("success"): + return False, "无法获取随机表情包" - # 插件基本信息 - plugin_name: str = "core_actions" # 内部标识符 - enable_plugin: bool = True - dependencies: list[str] = [] # 插件依赖列表 - python_dependencies: list[str] = [] # Python包依赖列表 - config_file_name: str = "config.toml" + sampled_emojis = result.get("emojis", []) + if not sampled_emojis: + return False, "无法获取随机表情包" - # 配置节描述 - config_section_descriptions = { - "plugin": "插件启用配置", - "components": "核心组件启用配置", - } + # 2. 按情感分组 + emotion_map: dict[str, list] = {} + for emoji in sampled_emojis: + emo = emoji.get("emotion", "") + if emo not in emotion_map: + emotion_map[emo] = [] + emotion_map[emo].append(emoji) - # 配置Schema定义 - config_schema: dict = { - "plugin": { - "enabled": ConfigField(type=bool, default=True, description="是否启用插件"), - "config_version": ConfigField(type=str, default="0.6.0", description="配置文件版本"), - }, - "components": { - "enable_emoji": ConfigField(type=bool, default=True, description="是否启用发送表情/图片动作"), - }, - } + available_emotions = list(emotion_map.keys()) - def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: - """返回插件包含的组件列表""" + if not available_emotions: + # 无情感标签,随机发送 + chosen = random.choice(sampled_emojis) + await self.ctx.send.emoji(chosen["base64"], stream_id) + return True, "随机发送了表情包" - # --- 根据配置注册组件 --- - components = [] - if self.get_config("components.enable_emoji", True): - components.append((EmojiAction.get_action_info(), EmojiAction)) + # 3. 获取最近消息作为上下文 + messages_text = "" + if chat_id: + recent_result = await self.ctx.message.get_recent(chat_id=chat_id, limit=5) + if recent_result and recent_result.get("success"): + readable_result = await self.ctx.call_capability( + "message.build_readable", + chat_id=chat_id, + start_time=0, + end_time=0, + limit=5, + timestamp_mode="normal_no_YMD", + truncate=False, + ) + if readable_result and readable_result.get("success"): + messages_text = readable_result.get("text", "") - return components + # 4. 构建 prompt 让 LLM 选择情感 + available_emotions_str = "\n".join(available_emotions) + prompt = f"""你正在进行QQ聊天,你需要根据聊天记录,选出一个合适的情感标签。 +请你根据以下原因和聊天记录进行选择 +原因:{reason} +聊天记录: +{messages_text} + +这里是可用的情感标签: +{available_emotions_str} +请直接返回最匹配的那个情感标签,不要进行任何解释或添加其他多余的文字。 +""" + + # 5. 调用 LLM + llm_result = await self.ctx.llm.generate(prompt=prompt, model_name="utils") + if not llm_result or not llm_result.get("success"): + chosen = random.choice(sampled_emojis) + await self.ctx.send.emoji(chosen["base64"], stream_id) + return True, "LLM调用失败,随机发送了表情包" + + chosen_emotion = llm_result.get("response", "").strip().replace('"', "").replace("'", "") + + # 6. 根据选择的情感匹配表情包 + if chosen_emotion in emotion_map: + chosen = random.choice(emotion_map[chosen_emotion]) + else: + chosen = random.choice(sampled_emojis) + + # 7. 发送 + send_result = await self.ctx.send.emoji(chosen["base64"], stream_id) + if send_result and send_result.get("success"): + return True, f"成功发送表情包:[表情包:{chosen_emotion}]" + return False, "发送表情包失败" + + async def on_load(self): + # 从插件配置读取 emoji_chance 来覆盖默认概率 + config_result = await self.ctx.config.get("emoji.emoji_chance") + if config_result and isinstance(config_result, dict) and config_result.get("success"): + pass # 配置已在宿主端管理 + + +def create_plugin(): + return EmojiPlugin() diff --git a/src/plugins/built_in/knowledge/_manifest.json b/src/plugins/built_in/knowledge/_manifest.json new file mode 100644 index 00000000..06295135 --- /dev/null +++ b/src/plugins/built_in/knowledge/_manifest.json @@ -0,0 +1,33 @@ +{ + "manifest_version": 1, + "name": "LPMM 知识库插件 (Knowledge Search)", + "version": "2.0.0", + "description": "从 LPMM 知识库中搜索相关信息,供 LLM 工具调用", + "author": { + "name": "MaiBot团队", + "url": "https://github.com/MaiM-with-u" + }, + "license": "GPL-v3.0-or-later", + "host_application": { + "min_version": "1.0.0" + }, + "homepage_url": "https://github.com/MaiM-with-u/maibot", + "repository_url": "https://github.com/MaiM-with-u/maibot", + "keywords": ["knowledge", "lpmm", "search", "tool", "built-in"], + "categories": ["Knowledge", "Tools"], + "default_locale": "zh-CN", + "plugin_info": { + "is_built_in": true, + "plugin_type": "tool_provider", + "components": [ + { + "type": "tool", + "name": "lpmm_search_knowledge", + "description": "从知识库中搜索相关信息" + } + ] + }, + "capabilities": [ + "knowledge.search" + ] +} diff --git a/src/plugins/built_in/knowledge/lpmm_get_knowledge.py b/src/plugins/built_in/knowledge/lpmm_get_knowledge.py deleted file mode 100644 index bb627e5e..00000000 --- a/src/plugins/built_in/knowledge/lpmm_get_knowledge.py +++ /dev/null @@ -1,62 +0,0 @@ -from typing import Dict, Any - -from src.common.logger import get_logger -from src.config.config import global_config -from src.chat.knowledge import qa_manager -from src.plugin_system import BaseTool, ToolParamType - -logger = get_logger("lpmm_get_knowledge_tool") - - -class SearchKnowledgeFromLPMMTool(BaseTool): - """从LPMM知识库中搜索相关信息的工具""" - - name = "lpmm_search_knowledge" - description = "从知识库中搜索相关信息,如果你需要知识,就使用这个工具" - parameters = [ - ("query", ToolParamType.STRING, "搜索查询关键词", True, None), - ("limit", ToolParamType.INTEGER, "希望返回的相关知识条数,默认5", False, None), - ] - available_for_llm = global_config.lpmm_knowledge.enable - - async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: - """执行知识库搜索 - - Args: - function_args: 工具参数 - - Returns: - Dict: 工具执行结果 - """ - try: - query: str = function_args.get("query") # type: ignore - limit = function_args.get("limit", 5) - try: - limit_value = int(limit) - except (TypeError, ValueError): - limit_value = 5 - limit_value = max(1, limit_value) - # threshold = function_args.get("threshold", 0.4) - - # 检查LPMM知识库是否启用 - if qa_manager is None: - logger.debug("LPMM知识库已禁用,跳过知识获取") - return {"type": "info", "id": query, "content": "LPMM知识库已禁用"} - - # 调用知识库搜索 - - knowledge_info = await qa_manager.get_knowledge(query, limit=limit_value) - - logger.debug(f"知识库查询结果: {knowledge_info}") - - if knowledge_info: - content = f"你知道这些知识: {knowledge_info}" - else: - content = f"你不太了解有关{query}的知识" - return {"type": "lpmm_knowledge", "id": query, "content": content} - except Exception as e: - # 捕获异常并记录错误 - logger.error(f"知识库搜索工具执行失败: {str(e)}") - # 在其他异常情况下,确保 id 仍然是 query (如果它被定义了) - query_id = query if "query" in locals() else "unknown_query" - return {"type": "info", "id": query_id, "content": f"lpmm知识库搜索失败,炸了: {str(e)}"} diff --git a/src/plugins/built_in/knowledge/plugin.py b/src/plugins/built_in/knowledge/plugin.py new file mode 100644 index 00000000..92185521 --- /dev/null +++ b/src/plugins/built_in/knowledge/plugin.py @@ -0,0 +1,39 @@ +"""LPMM 知识库搜索插件 — 新 SDK 版本 + +提供 LLM 可调用的知识库搜索工具。 +""" + +from maibot_sdk import MaiBotPlugin, Tool +from maibot_sdk.types import ToolParameterInfo, ToolParamType + + +class KnowledgePlugin(MaiBotPlugin): + """LPMM 知识库插件""" + + @Tool( + "lpmm_search_knowledge", + description="从知识库中搜索相关信息,如果你需要知识,就使用这个工具", + parameters=[ + ToolParameterInfo(name="query", param_type=ToolParamType.STRING, description="搜索查询关键词", required=True), + ToolParameterInfo(name="limit", param_type=ToolParamType.INTEGER, description="希望返回的相关知识条数,默认5", required=False, default=5), + ], + ) + async def handle_lpmm_search_knowledge(self, query: str = "", limit: int = 5, **kwargs): + """执行知识库搜索""" + if not query: + return {"type": "info", "id": "", "content": "未提供搜索关键词"} + + try: + limit_value = max(1, int(limit)) + except (TypeError, ValueError): + limit_value = 5 + + result = await self.ctx.call_capability("knowledge.search", query=query, limit=limit_value) + if result and result.get("success"): + content = result.get("content", f"你不太了解有关{query}的知识") + return {"type": "lpmm_knowledge", "id": query, "content": content} + return {"type": "info", "id": query, "content": f"知识库搜索失败: {result}"} + + +def create_plugin(): + return KnowledgePlugin() diff --git a/src/plugins/built_in/plugin_management/_manifest.json b/src/plugins/built_in/plugin_management/_manifest.json index a0175d77..a5b52835 100644 --- a/src/plugins/built_in/plugin_management/_manifest.json +++ b/src/plugins/built_in/plugin_management/_manifest.json @@ -1,7 +1,7 @@ { "manifest_version": 1, "name": "插件和组件管理 (Plugin and Component Management)", - "version": "1.0.0", + "version": "2.0.0", "description": "通过系统API管理插件和组件的生命周期,包括加载、卸载、启用和禁用等操作。", "author": { "name": "MaiBot团队", @@ -9,7 +9,7 @@ }, "license": "GPL-v3.0-or-later", "host_application": { - "min_version": "0.10.1" + "min_version": "1.0.0" }, "homepage_url": "https://github.com/MaiM-with-u/maibot", "repository_url": "https://github.com/MaiM-with-u/maibot", @@ -28,10 +28,22 @@ "plugin_info": { "is_built_in": true, "plugin_type": "plugin_management", + "capabilities": [ + "component.get_all_plugins", + "component.list_loaded_plugins", + "component.list_registered_plugins", + "component.enable", + "component.disable", + "component.load_plugin", + "component.unload_plugin", + "component.reload_plugin", + "send.text", + "config.get" + ], "components": [ { "type": "command", - "name": "plugin_management", + "name": "management", "description": "管理插件和组件的生命周期,包括加载、卸载、启用和禁用等操作。" } ] diff --git a/src/plugins/built_in/plugin_management/plugin.py b/src/plugins/built_in/plugin_management/plugin.py index ba60f451..81c77faa 100644 --- a/src/plugins/built_in/plugin_management/plugin.py +++ b/src/plugins/built_in/plugin_management/plugin.py @@ -1,454 +1,279 @@ -import asyncio +"""插件和组件管理 — 新 SDK 版本 -from typing import List, Tuple, Type -from src.plugin_system import ( - BasePlugin, - BaseCommand, - CommandInfo, - ConfigField, - register_plugin, - plugin_manage_api, - component_manage_api, - ComponentInfo, - ComponentType, - send_api, +通过 /pm 命令管理插件和组件的生命周期。 +""" + +from maibot_sdk import MaiBotPlugin, Command + + +_VALID_COMPONENT_TYPES = ("action", "command", "event_handler") + +HELP_ALL = ( + "管理命令帮助\n" + "/pm help 管理命令提示\n" + "/pm plugin 插件管理命令\n" + "/pm component 组件管理命令\n" + "使用 /pm plugin help 或 /pm component help 获取具体帮助" +) +HELP_PLUGIN = ( + "插件管理命令帮助\n" + "/pm plugin help 插件管理命令提示\n" + "/pm plugin list 列出所有注册的插件\n" + "/pm plugin list_enabled 列出所有加载(启用)的插件\n" + "/pm plugin load 加载指定插件\n" + "/pm plugin unload 卸载指定插件\n" + "/pm plugin reload 重新加载指定插件\n" +) +HELP_COMPONENT = ( + "组件管理命令帮助\n" + "/pm component help 组件管理命令提示\n" + "/pm component list 列出所有注册的组件\n" + "/pm component list enabled <可选: type> 列出所有启用的组件\n" + "/pm component list disabled <可选: type> 列出所有禁用的组件\n" + " - 可选项: local,代表当前聊天中的;global,代表全局的\n" + " - 不填时为 global\n" + "/pm component list type 列出已经注册的指定类型的组件\n" + "/pm component enable global 全局启用组件\n" + "/pm component enable local 本聊天启用组件\n" + "/pm component disable global 全局禁用组件\n" + "/pm component disable local 本聊天禁用组件\n" + " - 可选项: action, command, event_handler\n" ) -class ManagementCommand(BaseCommand): - command_name: str = "management" - description: str = "管理命令" - command_pattern: str = r"(?P^/pm(\s[a-zA-Z0-9_]+)*\s*$)" +class PluginManagementPlugin(MaiBotPlugin): + """插件和组件管理插件""" - async def execute(self) -> Tuple[bool, str, bool]: - # sourcery skip: merge-duplicate-blocks - if ( - not self.message - or not self.message.message_info - or not self.message.message_info.user_info - or str(self.message.message_info.user_info.user_id) not in self.get_config("plugin.permission", []) # type: ignore - ): - await self._send_message("你没有权限使用插件管理命令") + @Command( + "management", + description="管理插件和组件的生命周期", + pattern=r"(?P^/pm(\s[a-zA-Z0-9_]+)*\s*$)", + ) + async def handle_management( + self, stream_id: str = "", user_id: str = "", matched_groups: dict | None = None, **kwargs + ): + """处理 /pm 命令""" + # 权限检查 + permission_result = await self.ctx.config.get("plugin.permission") + permission_list = permission_result if isinstance(permission_result, list) else [] + if str(user_id) not in permission_list: + await self.ctx.send.text("你没有权限使用插件管理命令", stream_id) return False, "没有权限", True - if not self.message.chat_stream: - await self._send_message("无法获取聊天流信息") + + if not stream_id: return False, "无法获取聊天流信息", True - self.stream_id = self.message.chat_stream.stream_id - if not self.stream_id: - await self._send_message("无法获取聊天流信息") - return False, "无法获取聊天流信息", True - command_list = self.matched_groups["manage_command"].strip().split(" ") - if len(command_list) == 1: - await self.show_help("all") + + raw_command = (matched_groups or {}).get("manage_command", "").strip() + parts = raw_command.split(" ") if raw_command else ["/pm"] + n = len(parts) + + # /pm + if n == 1: + await self.ctx.send.text(HELP_ALL, stream_id) return True, "帮助已发送", True - if len(command_list) == 2: - match command_list[1]: - case "plugin": - await self.show_help("plugin") - case "component": - await self.show_help("component") - case "help": - await self.show_help("all") - case _: - await self._send_message("插件管理命令不合法") - return False, "命令不合法", True - if len(command_list) == 3: - if command_list[1] == "plugin": - match command_list[2]: - case "help": - await self.show_help("plugin") - case "list": - await self._list_registered_plugins() - case "list_enabled": - await self._list_loaded_plugins() - case "rescan": - await self._rescan_plugin_dirs() - case _: - await self._send_message("插件管理命令不合法") - return False, "命令不合法", True - elif command_list[1] == "component": - if command_list[2] == "list": - await self._list_all_registered_components() - elif command_list[2] == "help": - await self.show_help("component") - else: - await self._send_message("插件管理命令不合法") - return False, "命令不合法", True - else: - await self._send_message("插件管理命令不合法") - return False, "命令不合法", True - if len(command_list) == 4: - if command_list[1] == "plugin": - match command_list[2]: - case "load": - await self._load_plugin(command_list[3]) - case "unload": - await self._unload_plugin(command_list[3]) - case "reload": - await self._reload_plugin(command_list[3]) - case "add_dir": - await self._add_dir(command_list[3]) - case _: - await self._send_message("插件管理命令不合法") - return False, "命令不合法", True - elif command_list[1] == "component": - if command_list[2] != "list": - await self._send_message("插件管理命令不合法") - return False, "命令不合法", True - if command_list[3] == "enabled": - await self._list_enabled_components() - elif command_list[3] == "disabled": - await self._list_disabled_components() - else: - await self._send_message("插件管理命令不合法") - return False, "命令不合法", True - else: - await self._send_message("插件管理命令不合法") - return False, "命令不合法", True - if len(command_list) == 5: - if command_list[1] != "component": - await self._send_message("插件管理命令不合法") - return False, "命令不合法", True - if command_list[2] != "list": - await self._send_message("插件管理命令不合法") - return False, "命令不合法", True - if command_list[3] == "enabled": - await self._list_enabled_components(target_type=command_list[4]) - elif command_list[3] == "disabled": - await self._list_disabled_components(target_type=command_list[4]) - elif command_list[3] == "type": - await self._list_registered_components_by_type(command_list[4]) - else: - await self._send_message("插件管理命令不合法") - return False, "命令不合法", True - if len(command_list) == 6: - if command_list[1] != "component": - await self._send_message("插件管理命令不合法") - return False, "命令不合法", True - if command_list[2] == "enable": - if command_list[3] == "global": - await self._globally_enable_component(command_list[4], command_list[5]) - elif command_list[3] == "local": - await self._locally_enable_component(command_list[4], command_list[5]) - else: - await self._send_message("插件管理命令不合法") - return False, "命令不合法", True - elif command_list[2] == "disable": - if command_list[3] == "global": - await self._globally_disable_component(command_list[4], command_list[5]) - elif command_list[3] == "local": - await self._locally_disable_component(command_list[4], command_list[5]) - else: - await self._send_message("插件管理命令不合法") - return False, "命令不合法", True - else: - await self._send_message("插件管理命令不合法") - return False, "命令不合法", True - return True, "命令执行完成", True + # /pm + if n == 2: + sub = parts[1] + if sub == "plugin": + await self.ctx.send.text(HELP_PLUGIN, stream_id) + elif sub == "component": + await self.ctx.send.text(HELP_COMPONENT, stream_id) + elif sub == "help": + await self.ctx.send.text(HELP_ALL, stream_id) + else: + await self.ctx.send.text("插件管理命令不合法", stream_id) + return False, "命令不合法", True + return True, "帮助已发送", True - async def show_help(self, target: str): - help_msg = "" - match target: - case "all": - help_msg = ( - "管理命令帮助\n" - "/pm help 管理命令提示\n" - "/pm plugin 插件管理命令\n" - "/pm component 组件管理命令\n" - "使用 /pm plugin help 或 /pm component help 获取具体帮助" - ) - case "plugin": - help_msg = ( - "插件管理命令帮助\n" - "/pm plugin help 插件管理命令提示\n" - "/pm plugin list 列出所有注册的插件\n" - "/pm plugin list_enabled 列出所有加载(启用)的插件\n" - "/pm plugin rescan 重新扫描所有目录\n" - "/pm plugin load 加载指定插件\n" - "/pm plugin unload 卸载指定插件\n" - "/pm plugin reload 重新加载指定插件\n" - "/pm plugin add_dir 添加插件目录\n" - ) - case "component": - help_msg = ( - "组件管理命令帮助\n" - "/pm component help 组件管理命令提示\n" - "/pm component list 列出所有注册的组件\n" - "/pm component list enabled <可选: type> 列出所有启用的组件\n" - "/pm component list disabled <可选: type> 列出所有禁用的组件\n" - " - 可选项: local,代表当前聊天中的;global,代表全局的\n" - " - 不填时为 global\n" - "/pm component list type 列出已经注册的指定类型的组件\n" - "/pm component enable global 全局启用组件\n" - "/pm component enable local 本聊天启用组件\n" - "/pm component disable global 全局禁用组件\n" - "/pm component disable local 本聊天禁用组件\n" - " - 可选项: action, command, event_handler\n" - ) + # /pm plugin / /pm component + if n == 3: + if parts[1] == "plugin": + await self._handle_plugin_3(parts[2], stream_id) + elif parts[1] == "component": + if parts[2] == "list": + await self._list_all_components(stream_id) + elif parts[2] == "help": + await self.ctx.send.text(HELP_COMPONENT, stream_id) + else: + await self.ctx.send.text("插件管理命令不合法", stream_id) + return False, "命令不合法", True + else: + await self.ctx.send.text("插件管理命令不合法", stream_id) + return False, "命令不合法", True + return True, "命令执行完成", True + + if n == 4: + if parts[1] == "plugin": + await self._handle_plugin_4(parts[2], parts[3], stream_id) + elif parts[1] == "component": + if parts[2] == "list": + await self._handle_component_list_4(parts[3], stream_id) + else: + await self.ctx.send.text("插件管理命令不合法", stream_id) + return False, "命令不合法", True + else: + await self.ctx.send.text("插件管理命令不合法", stream_id) + return False, "命令不合法", True + return True, "命令执行完成", True + + if n == 5: + if parts[1] != "component" or parts[2] != "list": + await self.ctx.send.text("插件管理命令不合法", stream_id) + return False, "命令不合法", True + await self._handle_component_list_5(parts[3], parts[4], stream_id) + return True, "命令执行完成", True + + if n == 6: + if parts[1] != "component": + await self.ctx.send.text("插件管理命令不合法", stream_id) + return False, "命令不合法", True + await self._handle_component_toggle(parts[2], parts[3], parts[4], parts[5], stream_id) + return True, "命令执行完成", True + + await self.ctx.send.text("插件管理命令不合法", stream_id) + return False, "命令不合法", True + + # ------ plugin 子命令 ------ + + async def _handle_plugin_3(self, action: str, stream_id: str): + match action: + case "help": + await self.ctx.send.text(HELP_PLUGIN, stream_id) + case "list": + result = await self.ctx.component.list_registered_plugins() + plugins = result if isinstance(result, list) else [] + await self.ctx.send.text(f"已注册的插件: {', '.join(plugins) if plugins else '无'}", stream_id) + case "list_enabled": + result = await self.ctx.component.list_loaded_plugins() + plugins = result if isinstance(result, list) else [] + await self.ctx.send.text(f"已加载的插件: {', '.join(plugins) if plugins else '无'}", stream_id) case _: + await self.ctx.send.text("插件管理命令不合法", stream_id) + + async def _handle_plugin_4(self, action: str, name: str, stream_id: str): + match action: + case "load": + result = await self.ctx.component.load_plugin(name) + ok = result.get("success", False) if isinstance(result, dict) else bool(result) + msg = f"插件加载成功: {name}" if ok else f"插件加载失败: {name}" + await self.ctx.send.text(msg, stream_id) + case "unload": + result = await self.ctx.component.unload_plugin(name) + ok = result.get("success", False) if isinstance(result, dict) else bool(result) + msg = f"插件卸载成功: {name}" if ok else f"插件卸载失败: {name}" + await self.ctx.send.text(msg, stream_id) + case "reload": + result = await self.ctx.component.reload_plugin(name) + ok = result.get("success", False) if isinstance(result, dict) else bool(result) + msg = f"插件重新加载成功: {name}" if ok else f"插件重新加载失败: {name}" + await self.ctx.send.text(msg, stream_id) + case _: + await self.ctx.send.text("插件管理命令不合法", stream_id) + + # ------ component 子命令 ------ + + async def _list_all_components(self, stream_id: str): + result = await self.ctx.component.get_all_plugins() + if not result: + await self.ctx.send.text("没有注册的组件", stream_id) + return + components = self._extract_components(result) + if not components: + await self.ctx.send.text("没有注册的组件", stream_id) + return + text = ", ".join(f"{c['name']} ({c['type']})" for c in components) + await self.ctx.send.text(f"已注册的组件: {text}", stream_id) + + async def _handle_component_list_4(self, sub: str, stream_id: str): + if sub == "enabled": + await self._list_filtered_components("enabled", "global", stream_id) + elif sub == "disabled": + await self._list_filtered_components("disabled", "global", stream_id) + else: + await self.ctx.send.text("插件管理命令不合法", stream_id) + + async def _handle_component_list_5(self, sub: str, arg: str, stream_id: str): + if sub in ("enabled", "disabled"): + await self._list_filtered_components(sub, arg, stream_id) + elif sub == "type": + if arg not in _VALID_COMPONENT_TYPES: + await self.ctx.send.text(f"未知组件类型: {arg}", stream_id) return - await self._send_message(help_msg) - - async def _list_loaded_plugins(self): - plugins = plugin_manage_api.list_loaded_plugins() - await self._send_message(f"已加载的插件: {', '.join(plugins)}") - - async def _list_registered_plugins(self): - plugins = plugin_manage_api.list_registered_plugins() - await self._send_message(f"已注册的插件: {', '.join(plugins)}") - - async def _rescan_plugin_dirs(self): - plugin_manage_api.rescan_plugin_directory() - await self._send_message("插件目录重新扫描执行中") - - async def _load_plugin(self, plugin_name: str): - success, count = plugin_manage_api.load_plugin(plugin_name) - if success: - await self._send_message(f"插件加载成功: {plugin_name}") + result = await self.ctx.component.get_all_plugins() + components = [c for c in self._extract_components(result) if c.get("type") == arg] + if not components: + await self.ctx.send.text(f"没有注册的 {arg} 组件", stream_id) + return + text = ", ".join(f"{c['name']} ({c['type']})" for c in components) + await self.ctx.send.text(f"注册的 {arg} 组件: {text}", stream_id) else: - if count == 0: - await self._send_message(f"插件{plugin_name}为禁用状态") - await self._send_message(f"插件加载失败: {plugin_name}") + await self.ctx.send.text("插件管理命令不合法", stream_id) - async def _unload_plugin(self, plugin_name: str): - success = await plugin_manage_api.remove_plugin(plugin_name) - if success: - await self._send_message(f"插件卸载成功: {plugin_name}") + async def _list_filtered_components(self, filter_mode: str, scope: str, stream_id: str): + result = await self.ctx.component.get_all_plugins() + all_components = self._extract_components(result) + if not all_components: + await self.ctx.send.text("没有注册的组件", stream_id) + return + + if filter_mode == "enabled": + filtered = [c for c in all_components if c.get("enabled", False)] + label = "已启用" else: - await self._send_message(f"插件卸载失败: {plugin_name}") + filtered = [c for c in all_components if not c.get("enabled", False)] + label = "已禁用" - async def _reload_plugin(self, plugin_name: str): - success = await plugin_manage_api.reload_plugin(plugin_name) - if success: - await self._send_message(f"插件重新加载成功: {plugin_name}") + scope_label = "全局" if scope == "global" else "本聊天" + if not filtered: + await self.ctx.send.text(f"没有满足条件的{label}{scope_label}组件", stream_id) + return + text = ", ".join(f"{c['name']} ({c['type']})" for c in filtered) + await self.ctx.send.text(f"满足条件的{label}{scope_label}组件: {text}", stream_id) + + async def _handle_component_toggle( + self, action: str, scope: str, comp_name: str, comp_type: str, stream_id: str + ): + if action not in ("enable", "disable"): + await self.ctx.send.text("插件管理命令不合法", stream_id) + return + if scope not in ("global", "local"): + await self.ctx.send.text("插件管理命令不合法", stream_id) + return + if comp_type not in _VALID_COMPONENT_TYPES: + await self.ctx.send.text(f"未知组件类型: {comp_type}", stream_id) + return + + if action == "enable": + result = await self.ctx.component.enable_component( + comp_name, comp_type, scope=scope, stream_id=stream_id + ) else: - await self._send_message(f"插件重新加载失败: {plugin_name}") + result = await self.ctx.component.disable_component( + comp_name, comp_type, scope=scope, stream_id=stream_id + ) - async def _add_dir(self, dir_path: str): - await self._send_message(f"正在添加插件目录: {dir_path}") - success = plugin_manage_api.add_plugin_directory(dir_path) - await asyncio.sleep(0.5) # 防止乱序发送 - if success: - await self._send_message(f"插件目录添加成功: {dir_path}") - else: - await self._send_message(f"插件目录添加失败: {dir_path}") + ok = result.get("success", False) if isinstance(result, dict) else bool(result) + scope_label = "全局" if scope == "global" else "本地" + action_label = "启用" if action == "enable" else "禁用" + status = "成功" if ok else "失败" + await self.ctx.send.text(f"{scope_label}{action_label}组件{status}: {comp_name}", stream_id) - def _fetch_all_registered_components(self) -> List[ComponentInfo]: - all_plugin_info = component_manage_api.get_all_plugin_info() - if not all_plugin_info: + # ------ helpers ------ + + @staticmethod + def _extract_components(result) -> list[dict]: + """从 get_all_plugins 结果中提取所有组件列表""" + if not result: return [] - - components_info: List[ComponentInfo] = [] - for plugin_info in all_plugin_info.values(): - components_info.extend(plugin_info.components) - return components_info - - def _fetch_locally_disabled_components(self) -> List[str]: - locally_disabled_components_actions = component_manage_api.get_locally_disabled_components( - self.message.chat_stream.stream_id, ComponentType.ACTION - ) - locally_disabled_components_commands = component_manage_api.get_locally_disabled_components( - self.message.chat_stream.stream_id, ComponentType.COMMAND - ) - locally_disabled_components_event_handlers = component_manage_api.get_locally_disabled_components( - self.message.chat_stream.stream_id, ComponentType.EVENT_HANDLER - ) - return ( - locally_disabled_components_actions - + locally_disabled_components_commands - + locally_disabled_components_event_handlers - ) - - async def _list_all_registered_components(self): - components_info = self._fetch_all_registered_components() - if not components_info: - await self._send_message("没有注册的组件") - return - - all_components_str = ", ".join( - f"{component.name} ({component.component_type})" for component in components_info - ) - await self._send_message(f"已注册的组件: {all_components_str}") - - async def _list_enabled_components(self, target_type: str = "global"): - components_info = self._fetch_all_registered_components() - if not components_info: - await self._send_message("没有注册的组件") - return - - if target_type == "global": - enabled_components = [component for component in components_info if component.enabled] - if not enabled_components: - await self._send_message("没有满足条件的已启用全局组件") - return - enabled_components_str = ", ".join( - f"{component.name} ({component.component_type})" for component in enabled_components - ) - await self._send_message(f"满足条件的已启用全局组件: {enabled_components_str}") - elif target_type == "local": - locally_disabled_components = self._fetch_locally_disabled_components() - enabled_components = [ - component - for component in components_info - if (component.name not in locally_disabled_components and component.enabled) - ] - if not enabled_components: - await self._send_message("本聊天没有满足条件的已启用组件") - return - enabled_components_str = ", ".join( - f"{component.name} ({component.component_type})" for component in enabled_components - ) - await self._send_message(f"本聊天满足条件的已启用组件: {enabled_components_str}") - - async def _list_disabled_components(self, target_type: str = "global"): - components_info = self._fetch_all_registered_components() - if not components_info: - await self._send_message("没有注册的组件") - return - - if target_type == "global": - disabled_components = [component for component in components_info if not component.enabled] - if not disabled_components: - await self._send_message("没有满足条件的已禁用全局组件") - return - disabled_components_str = ", ".join( - f"{component.name} ({component.component_type})" for component in disabled_components - ) - await self._send_message(f"满足条件的已禁用全局组件: {disabled_components_str}") - elif target_type == "local": - locally_disabled_components = self._fetch_locally_disabled_components() - disabled_components = [ - component - for component in components_info - if (component.name in locally_disabled_components or not component.enabled) - ] - if not disabled_components: - await self._send_message("本聊天没有满足条件的已禁用组件") - return - disabled_components_str = ", ".join( - f"{component.name} ({component.component_type})" for component in disabled_components - ) - await self._send_message(f"本聊天满足条件的已禁用组件: {disabled_components_str}") - - async def _list_registered_components_by_type(self, target_type: str): - match target_type: - case "action": - component_type = ComponentType.ACTION - case "command": - component_type = ComponentType.COMMAND - case "event_handler": - component_type = ComponentType.EVENT_HANDLER - case _: - await self._send_message(f"未知组件类型: {target_type}") - return - - components_info = component_manage_api.get_components_info_by_type(component_type) - if not components_info: - await self._send_message(f"没有注册的 {target_type} 组件") - return - - components_str = ", ".join( - f"{name} ({component.component_type})" for name, component in components_info.items() - ) - await self._send_message(f"注册的 {target_type} 组件: {components_str}") - - async def _globally_enable_component(self, component_name: str, component_type: str): - match component_type: - case "action": - target_component_type = ComponentType.ACTION - case "command": - target_component_type = ComponentType.COMMAND - case "event_handler": - target_component_type = ComponentType.EVENT_HANDLER - case _: - await self._send_message(f"未知组件类型: {component_type}") - return - if component_manage_api.globally_enable_component(component_name, target_component_type): - await self._send_message(f"全局启用组件成功: {component_name}") - else: - await self._send_message(f"全局启用组件失败: {component_name}") - - async def _globally_disable_component(self, component_name: str, component_type: str): - match component_type: - case "action": - target_component_type = ComponentType.ACTION - case "command": - target_component_type = ComponentType.COMMAND - case "event_handler": - target_component_type = ComponentType.EVENT_HANDLER - case _: - await self._send_message(f"未知组件类型: {component_type}") - return - success = await component_manage_api.globally_disable_component(component_name, target_component_type) - if success: - await self._send_message(f"全局禁用组件成功: {component_name}") - else: - await self._send_message(f"全局禁用组件失败: {component_name}") - - async def _locally_enable_component(self, component_name: str, component_type: str): - match component_type: - case "action": - target_component_type = ComponentType.ACTION - case "command": - target_component_type = ComponentType.COMMAND - case "event_handler": - target_component_type = ComponentType.EVENT_HANDLER - case _: - await self._send_message(f"未知组件类型: {component_type}") - return - if component_manage_api.locally_enable_component( - component_name, - target_component_type, - self.message.chat_stream.stream_id, - ): - await self._send_message(f"本地启用组件成功: {component_name}") - else: - await self._send_message(f"本地启用组件失败: {component_name}") - - async def _locally_disable_component(self, component_name: str, component_type: str): - match component_type: - case "action": - target_component_type = ComponentType.ACTION - case "command": - target_component_type = ComponentType.COMMAND - case "event_handler": - target_component_type = ComponentType.EVENT_HANDLER - case _: - await self._send_message(f"未知组件类型: {component_type}") - return - if component_manage_api.locally_disable_component( - component_name, - target_component_type, - self.message.chat_stream.stream_id, - ): - await self._send_message(f"本地禁用组件成功: {component_name}") - else: - await self._send_message(f"本地禁用组件失败: {component_name}") - - async def _send_message(self, message: str): - await send_api.text_to_stream(message, self.stream_id, typing=False, storage_message=False) + if isinstance(result, dict): + components = [] + for plugin_info in result.values(): + if isinstance(plugin_info, dict): + components.extend(plugin_info.get("components", [])) + return components + return [] -@register_plugin -class PluginManagementPlugin(BasePlugin): - plugin_name: str = "plugin_management_plugin" - enable_plugin: bool = False - dependencies: list[str] = [] - python_dependencies: list[str] = [] - config_file_name: str = "config.toml" - config_schema: dict = { - "plugin": { - "enabled": ConfigField(bool, default=False, description="是否启用插件"), - "config_version": ConfigField(type=str, default="1.1.0", description="配置文件版本"), - "permission": ConfigField( - list, default=[], description="有权限使用插件管理命令的用户列表,请填写字符串形式的用户ID" - ), - }, - } - - def get_plugin_components(self) -> List[Tuple[CommandInfo, Type[BaseCommand]]]: - components = [] - if self.get_config("plugin.enabled", True): - components.append((ManagementCommand.get_command_info(), ManagementCommand)) - return components +def create_plugin(): + return PluginManagementPlugin() diff --git a/src/plugins/built_in/tts_plugin/_manifest.json b/src/plugins/built_in/tts_plugin/_manifest.json index 01e83907..8b43b9e7 100644 --- a/src/plugins/built_in/tts_plugin/_manifest.json +++ b/src/plugins/built_in/tts_plugin/_manifest.json @@ -1,7 +1,7 @@ { "manifest_version": 1, "name": "文本转语音插件 (Text-to-Speech)", - "version": "0.1.0", + "version": "2.0.0", "description": "将文本转换为语音进行播放的插件,支持多种语音模式和智能语音输出场景判断。", "author": { "name": "MaiBot团队", @@ -10,7 +10,7 @@ "license": "GPL-v3.0-or-later", "host_application": { - "min_version": "0.8.0" + "min_version": "1.0.0" }, "homepage_url": "https://github.com/MaiM-with-u/maibot", "repository_url": "https://github.com/MaiM-with-u/maibot", diff --git a/src/plugins/built_in/tts_plugin/plugin.py b/src/plugins/built_in/tts_plugin/plugin.py index 62dfc91a..8b22ebb5 100644 --- a/src/plugins/built_in/tts_plugin/plugin.py +++ b/src/plugins/built_in/tts_plugin/plugin.py @@ -1,145 +1,55 @@ -from src.plugin_system.apis.plugin_register_api import register_plugin -from src.plugin_system.base.base_plugin import BasePlugin -from src.plugin_system.base.component_types import ComponentInfo -from src.common.logger import get_logger -from src.plugin_system.base.base_action import BaseAction, ActionActivationType -from src.plugin_system.base.config_types import ConfigField -from typing import Tuple, List, Type +"""TTS 插件 — 新 SDK 版本 -logger = get_logger("tts") +将文本转换为语音进行播放。 +""" + +import re + +from maibot_sdk import MaiBotPlugin, Action +from maibot_sdk.types import ActivationType -class TTSAction(BaseAction): - """TTS语音转换动作处理类""" - - # 激活设置 - activation_type = ActionActivationType.KEYWORD - activation_keywords = ["语音", "tts", "播报", "读出来", "语音播放", "听", "朗读"] - keyword_case_sensitive = False - parallel_action = False - - # 动作基本信息 - action_name = "tts_action" - action_description = "将文本转换为语音进行播放,适用于需要语音输出的场景" - - # 动作参数定义 - action_parameters = { - "voice_text": "你想用语音表达的内容,这段内容将会以语音形式发出", - } - - # 动作使用场景 - action_require = [ - "当需要发送语音信息时使用", - "当用户明确要求使用语音功能时使用", - "当表达内容更适合用语音而不是文字传达时使用", - "当用户想听到语音回答而非阅读文本时使用", - ] - - # 关联类型 - associated_types = ["tts_text"] - - async def execute(self) -> Tuple[bool, str]: - """处理TTS文本转语音动作""" - logger.info(f"{self.log_prefix} 执行TTS动作: {self.reasoning}") - - # 获取要转换的文本 - text = self.action_data.get("voice_text") +class TTSPlugin(MaiBotPlugin): + """文本转语音插件""" + @Action( + "tts_action", + description="将文本转换为语音进行播放,适用于需要语音输出的场景", + activation_type=ActivationType.KEYWORD, + activation_keywords=["语音", "tts", "播报", "读出来", "语音播放", "听", "朗读"], + parallel_action=False, + action_parameters={"voice_text": "你想用语音表达的内容,这段内容将会以语音形式发出"}, + action_require=[ + "当需要发送语音信息时使用", + "当用户明确要求使用语音功能时使用", + "当表达内容更适合用语音而不是文字传达时使用", + "当用户想听到语音回答而非阅读文本时使用", + ], + associated_types=["tts_text"], + ) + async def handle_tts_action(self, stream_id: str = "", action_data: dict = None, reasoning: str = "", **kwargs): + """处理 TTS 文本转语音动作""" + action_data = action_data or {} + text = action_data.get("voice_text", "") if not text: - logger.error(f"{self.log_prefix} 执行TTS动作时未提供文本内容") return False, "执行TTS动作失败:未提供文本内容" - # 确保文本适合TTS使用 - processed_text = self._process_text_for_tts(text) - - try: - # 发送TTS消息 - await self.send_custom(message_type="tts_text", content=processed_text) - - # 记录动作信息 - await self.store_action_info( - action_build_into_prompt=True, action_prompt_display="已经发送了语音消息。", action_done=True - ) - - logger.info(f"{self.log_prefix} TTS动作执行成功,文本长度: {len(processed_text)}") - return True, "TTS动作执行成功" - - except Exception as e: - logger.error(f"{self.log_prefix} 执行TTS动作时出错: {e}") - return False, f"执行TTS动作时出错: {e}" - - def _process_text_for_tts(self, text: str) -> str: - """ - 处理文本使其更适合TTS使用 - - 移除不必要的特殊字符和表情符号 - - 修正标点符号以提高语音质量 - - 优化文本结构使语音更流畅 - """ - # 这里可以添加文本处理逻辑 - # 例如:移除多余的标点、表情符号,优化语句结构等 - - # 简单示例实现 - processed_text = text - - # 移除多余的标点符号 - import re - - processed_text = re.sub(r"([!?,.;:。!?,、;:])\1+", r"\1", processed_text) - - # 确保句子结尾有合适的标点 + # 文本预处理 + processed_text = re.sub(r"([!?,.;:。!?,、;:])\1+", r"\1", text) if not any(processed_text.endswith(end) for end in [".", "?", "!", "。", "!", "?"]): processed_text = f"{processed_text}。" - return processed_text + # 发送自定义 tts 消息 + result = await self.ctx.call_capability( + "send.custom", + message_type="tts_text", + content=processed_text, + stream_id=stream_id, + ) + if result and result.get("success"): + return True, "TTS动作执行成功" + return False, f"TTS动作执行失败: {result}" -@register_plugin -class TTSPlugin(BasePlugin): - """TTS插件 - - 这是文字转语音插件 - - Normal模式下依靠关键词触发 - - Focus模式下由LLM判断触发 - - 具有一定的文本预处理能力 - """ - - # 插件基本信息 - plugin_name: str = "tts_plugin" # 内部标识符 - enable_plugin: bool = True - dependencies: list[str] = [] # 插件依赖列表 - python_dependencies: list[str] = [] # Python包依赖列表 - config_file_name: str = "config.toml" - - # 配置节描述 - config_section_descriptions = { - "plugin": "插件基本信息配置", - "components": "组件启用控制", - "logging": "日志记录相关配置", - } - - # 配置Schema定义 - config_schema: dict = { - "plugin": { - "name": ConfigField(type=str, default="tts_plugin", description="插件名称", required=True), - "version": ConfigField(type=str, default="0.1.0", description="插件版本号"), - "enabled": ConfigField(type=bool, default=True, description="是否启用插件"), - "description": ConfigField(type=str, default="文字转语音插件", description="插件描述", required=True), - }, - "components": {"enable_tts": ConfigField(type=bool, default=True, description="是否启用TTS Action")}, - "logging": { - "level": ConfigField( - type=str, default="INFO", description="日志记录级别", choices=["DEBUG", "INFO", "WARNING", "ERROR"] - ), - "prefix": ConfigField(type=str, default="[TTS]", description="日志记录前缀"), - }, - } - - def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: - """返回插件包含的组件列表""" - - # 从配置获取组件启用状态 - enable_tts = self.get_config("components.enable_tts", True) - components = [] # 添加Action组件 - if enable_tts: - components.append((TTSAction.get_action_info(), TTSAction)) - - return components +def create_plugin(): + return TTSPlugin() diff --git a/src/services/__init__.py b/src/services/__init__.py new file mode 100644 index 00000000..74bd007b --- /dev/null +++ b/src/services/__init__.py @@ -0,0 +1,7 @@ +""" +核心服务层 + +提供与具体插件系统无关的核心业务服务。 +内部模块(chat、dream、memory 等)应直接使用此层, +而 plugin_system.apis 仅作为面向插件的薄包装。 +""" diff --git a/src/services/chat_service.py b/src/services/chat_service.py new file mode 100644 index 00000000..ea6e6541 --- /dev/null +++ b/src/services/chat_service.py @@ -0,0 +1,159 @@ +""" +聊天服务模块 + +提供聊天信息查询和管理的核心功能。 +""" + +from typing import Any, Dict, List, Optional +from enum import Enum + +from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager +from src.common.logger import get_logger + +logger = get_logger("chat_service") + + +class SpecialTypes(Enum): + """特殊枚举类型""" + + ALL_PLATFORMS = "all_platforms" + + +class ChatManager: + """聊天管理器 - 负责聊天信息的查询和管理""" + + @staticmethod + def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]: + # sourcery skip: for-append-to-extend + if not isinstance(platform, (str, SpecialTypes)): + raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举") + streams = [] + try: + for _, stream in _chat_manager.sessions.items(): + if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform: + streams.append(stream) + logger.debug(f"[ChatService] 获取到 {len(streams)} 个 {platform} 平台的聊天流") + except Exception as e: + logger.error(f"[ChatService] 获取聊天流失败: {e}") + return streams + + @staticmethod + def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]: + # sourcery skip: for-append-to-extend + if not isinstance(platform, (str, SpecialTypes)): + raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举") + streams = [] + try: + for _, stream in _chat_manager.sessions.items(): + if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.is_group_session: + streams.append(stream) + logger.debug(f"[ChatService] 获取到 {len(streams)} 个 {platform} 平台的群聊流") + except Exception as e: + logger.error(f"[ChatService] 获取群聊流失败: {e}") + return streams + + @staticmethod + def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]: + # sourcery skip: for-append-to-extend + if not isinstance(platform, (str, SpecialTypes)): + raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举") + streams = [] + try: + for _, stream in _chat_manager.sessions.items(): + if ( + platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform + ) and not stream.is_group_session: + streams.append(stream) + logger.debug(f"[ChatService] 获取到 {len(streams)} 个 {platform} 平台的私聊流") + except Exception as e: + logger.error(f"[ChatService] 获取私聊流失败: {e}") + return streams + + @staticmethod + def get_group_stream_by_group_id( + group_id: str, platform: Optional[str] | SpecialTypes = "qq" + ) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast + if not isinstance(group_id, str): + raise TypeError("group_id 必须是字符串类型") + if not isinstance(platform, (str, SpecialTypes)): + raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举") + if not group_id: + raise ValueError("group_id 不能为空") + try: + for _, stream in _chat_manager.sessions.items(): + if ( + stream.is_group_session + and str(stream.group_id) == str(group_id) + and (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) + ): + logger.debug(f"[ChatService] 找到群ID {group_id} 的聊天流") + return stream + logger.warning(f"[ChatService] 未找到群ID {group_id} 的聊天流") + except Exception as e: + logger.error(f"[ChatService] 查找群聊流失败: {e}") + return None + + @staticmethod + def get_private_stream_by_user_id( + user_id: str, platform: Optional[str] | SpecialTypes = "qq" + ) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast + if not isinstance(user_id, str): + raise TypeError("user_id 必须是字符串类型") + if not isinstance(platform, (str, SpecialTypes)): + raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举") + if not user_id: + raise ValueError("user_id 不能为空") + try: + for _, stream in _chat_manager.sessions.items(): + if ( + not stream.is_group_session + and str(stream.user_id) == str(user_id) + and (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) + ): + logger.debug(f"[ChatService] 找到用户ID {user_id} 的私聊流") + return stream + logger.warning(f"[ChatService] 未找到用户ID {user_id} 的私聊流") + except Exception as e: + logger.error(f"[ChatService] 查找私聊流失败: {e}") + return None + + @staticmethod + def get_stream_type(chat_stream: BotChatSession) -> str: + if not isinstance(chat_stream, BotChatSession): + raise TypeError("chat_stream 必须是 BotChatSession 类型") + if not chat_stream: + raise ValueError("chat_stream 不能为 None") + + return "group" if chat_stream.is_group_session else "private" + + @staticmethod + def get_stream_info(chat_stream: BotChatSession) -> Dict[str, Any]: + if not chat_stream: + raise ValueError("chat_stream 不能为 None") + if not isinstance(chat_stream, BotChatSession): + raise TypeError("chat_stream 必须是 BotChatSession 类型") + + try: + info: Dict[str, Any] = { + "session_id": chat_stream.session_id, + "platform": chat_stream.platform, + "type": ChatManager.get_stream_type(chat_stream), + } + + if chat_stream.is_group_session: + info["group_id"] = chat_stream.group_id + if chat_stream.context and chat_stream.context.message and chat_stream.context.message.message_info.group_info: + info["group_name"] = chat_stream.context.message.message_info.group_info.group_name or "未知群聊" + else: + info["group_name"] = "未知群聊" + else: + info["user_id"] = chat_stream.user_id + if chat_stream.context and chat_stream.context.message and chat_stream.context.message.message_info.user_info: + info["user_name"] = chat_stream.context.message.message_info.user_info.user_nickname + else: + info["user_name"] = "未知用户" + + return info + except Exception as e: + logger.error(f"[ChatService] 获取聊天流信息失败: {e}") + return {} diff --git a/src/plugin_system/apis/config_api.py b/src/services/config_service.py similarity index 67% rename from src/plugin_system/apis/config_api.py rename to src/services/config_service.py index 05556414..5cca332e 100644 --- a/src/plugin_system/apis/config_api.py +++ b/src/services/config_service.py @@ -1,28 +1,19 @@ -"""配置API模块 +"""配置服务模块 -提供了配置读取和用户信息获取等功能 -使用方式: - from src.plugin_system.apis import config_api - value = config_api.get_global_config("section.key") - platform, user_id = await config_api.get_user_id_by_person_name("用户名") +提供配置读取的核心功能。 """ from typing import Any + from src.common.logger import get_logger from src.config.config import global_config -logger = get_logger("config_api") - - -# ============================================================================= -# 配置访问API函数 -# ============================================================================= +logger = get_logger("config_service") def get_global_config(key: str, default: Any = None) -> Any: """ 安全地从全局配置中获取一个值。 - 插件应使用此方法读取全局配置,以保证只读和隔离性。 Args: key: 命名空间式配置键名,使用嵌套访问,如 "section.subsection.key",大小写敏感 @@ -31,7 +22,6 @@ def get_global_config(key: str, default: Any = None) -> Any: Returns: Any: 配置值或默认值 """ - # 支持嵌套键访问 keys = key.split(".") current = global_config @@ -43,7 +33,7 @@ def get_global_config(key: str, default: Any = None) -> Any: raise KeyError(f"配置中不存在子空间或键 '{k}'") return current except Exception as e: - logger.warning(f"[ConfigAPI] 获取全局配置 {key} 失败: {e}") + logger.warning(f"[ConfigService] 获取全局配置 {key} 失败: {e}") return default @@ -59,7 +49,6 @@ def get_plugin_config(plugin_config: dict, key: str, default: Any = None) -> Any Returns: Any: 配置值或默认值 """ - # 支持嵌套键访问 keys = key.split(".") current = plugin_config @@ -73,5 +62,5 @@ def get_plugin_config(plugin_config: dict, key: str, default: Any = None) -> Any raise KeyError(f"配置中不存在子空间或键 '{k}'") return current except Exception as e: - logger.warning(f"[ConfigAPI] 获取插件配置 {key} 失败: {e}") + logger.warning(f"[ConfigService] 获取插件配置 {key} 失败: {e}") return default diff --git a/src/plugin_system/apis/database_api.py b/src/services/database_service.py similarity index 89% rename from src/plugin_system/apis/database_api.py rename to src/services/database_service.py index 9af4e078..31705745 100644 --- a/src/plugin_system/apis/database_api.py +++ b/src/services/database_service.py @@ -1,6 +1,6 @@ -"""数据库API模块 +"""数据库服务模块 -提供数据库操作相关功能,统一使用 SQLModel/SQLAlchemy 兼容接口。 +提供数据库操作相关的核心功能。 """ import json @@ -10,7 +10,7 @@ from typing import Any, Optional from src.common.logger import get_logger -logger = get_logger("database_api") +logger = get_logger("database_service") def _to_dict(record: Any) -> dict[str, Any]: @@ -73,7 +73,7 @@ async def db_query( return query.count() except Exception as e: - logger.error(f"[DatabaseAPI] 数据库操作出错: {e}") + logger.error(f"[DatabaseService] 数据库操作出错: {e}") traceback.print_exc() if query_type == "get": return None if single_result else [] @@ -93,7 +93,7 @@ async def db_save(model_class, data: dict[str, Any], key_field: Optional[str] = new_record = model_class.create(**data) return _to_dict(new_record) except Exception as e: - logger.error(f"[DatabaseAPI] 保存数据库记录出错: {e}") + logger.error(f"[DatabaseService] 保存数据库记录出错: {e}") traceback.print_exc() return None @@ -119,7 +119,7 @@ async def db_get( return results[0] if results else None return results except Exception as e: - logger.error(f"[DatabaseAPI] 获取数据库记录出错: {e}") + logger.error(f"[DatabaseService] 获取数据库记录出错: {e}") traceback.print_exc() return None if single_result else [] @@ -163,11 +163,11 @@ async def store_action_info( ActionRecords, data=record_data, key_field="action_id", key_value=record_data["action_id"] ) if saved_record: - logger.debug(f"[DatabaseAPI] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})") + logger.debug(f"[DatabaseService] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})") else: - logger.error(f"[DatabaseAPI] 存储动作信息失败: {action_name}") + logger.error(f"[DatabaseService] 存储动作信息失败: {action_name}") return saved_record except Exception as e: - logger.error(f"[DatabaseAPI] 存储动作信息时发生错误: {e}") + logger.error(f"[DatabaseService] 存储动作信息时发生错误: {e}") traceback.print_exc() return None diff --git a/src/services/emoji_service.py b/src/services/emoji_service.py new file mode 100644 index 00000000..f6d14348 --- /dev/null +++ b/src/services/emoji_service.py @@ -0,0 +1,406 @@ +""" +表情服务模块 + +提供表情包相关的核心功能。 +""" + +import base64 +import os +import random +import uuid + +from typing import Any, Dict, List, Optional, Tuple + +from src.chat.emoji_system.emoji_manager import emoji_manager, EMOJI_DIR +from src.common.logger import get_logger +from src.common.utils.utils_image import ImageUtils +from src.config.config import global_config + +logger = get_logger("emoji_service") + + +# ============================================================================= +# 表情包获取函数 +# ============================================================================= + + +async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]: + """根据描述选择表情包""" + if not description: + raise ValueError("描述不能为空") + if not isinstance(description, str): + raise TypeError("描述必须是字符串类型") + try: + logger.debug(f"[EmojiService] 根据描述获取表情包: {description}") + + emoji_obj = await emoji_manager.get_emoji_for_emotion(description) + + if not emoji_obj: + logger.warning(f"[EmojiService] 未找到匹配描述 '{description}' 的表情包") + return None + + emoji_path = str(emoji_obj.full_path) + emoji_description = emoji_obj.description + matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else "" + emoji_base64 = ImageUtils.image_path_to_base64(emoji_path) + + if not emoji_base64: + logger.error(f"[EmojiService] 无法将表情包文件转换为base64: {emoji_path}") + return None + + logger.debug(f"[EmojiService] 成功获取表情包: {emoji_description}, 匹配情感: {matched_emotion}") + return emoji_base64, emoji_description, matched_emotion + + except Exception as e: + logger.error(f"[EmojiService] 获取表情包失败: {e}") + return None + + +async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]: + """随机获取指定数量的表情包""" + if not isinstance(count, int): + raise TypeError("count 必须是整数类型") + if count < 0: + raise ValueError("count 不能为负数") + if count == 0: + logger.warning("[EmojiService] count 为0,返回空列表") + return [] + + try: + all_emojis = emoji_manager.emojis + + if not all_emojis: + logger.warning("[EmojiService] 没有可用的表情包") + return [] + + valid_emojis = [emoji for emoji in all_emojis if not emoji.is_deleted] + if not valid_emojis: + logger.warning("[EmojiService] 没有有效的表情包") + return [] + + if len(valid_emojis) < count: + logger.debug( + f"[EmojiService] 有效表情包数量 ({len(valid_emojis)}) 少于请求的数量 ({count}),将返回所有有效表情包" + ) + count = len(valid_emojis) + + selected_emojis = random.sample(valid_emojis, count) + + results = [] + for selected_emoji in selected_emojis: + emoji_base64 = ImageUtils.image_path_to_base64(str(selected_emoji.full_path)) + + if not emoji_base64: + logger.error(f"[EmojiService] 无法转换表情包为base64: {selected_emoji.full_path}") + continue + + matched_emotion = random.choice(selected_emoji.emotion) if selected_emoji.emotion else "随机表情" + + emoji_manager.update_emoji_usage(selected_emoji) + results.append((emoji_base64, selected_emoji.description, matched_emotion)) + + if not results and count > 0: + logger.warning("[EmojiService] 随机获取表情包失败,没有一个可以成功处理") + return [] + + logger.debug(f"[EmojiService] 成功获取 {len(results)} 个随机表情包") + return results + + except Exception as e: + logger.error(f"[EmojiService] 获取随机表情包失败: {e}") + return [] + + +async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]: + """根据情感标签获取表情包""" + if not emotion: + raise ValueError("情感标签不能为空") + if not isinstance(emotion, str): + raise TypeError("情感标签必须是字符串类型") + try: + logger.info(f"[EmojiService] 根据情感获取表情包: {emotion}") + + all_emojis = emoji_manager.emojis + + matching_emojis = [] + matching_emojis.extend( + emoji_obj + for emoji_obj in all_emojis + if not emoji_obj.is_deleted and emotion.lower() in [e.lower() for e in emoji_obj.emotion] + ) + if not matching_emojis: + logger.warning(f"[EmojiService] 未找到匹配情感 '{emotion}' 的表情包") + return None + + selected_emoji = random.choice(matching_emojis) + emoji_base64 = ImageUtils.image_path_to_base64(selected_emoji.full_path) + + if not emoji_base64: + logger.error(f"[EmojiService] 无法转换表情包为base64: {selected_emoji.full_path}") + return None + + emoji_manager.update_emoji_usage(selected_emoji) + + logger.info(f"[EmojiService] 成功获取情感表情包: {selected_emoji.description}") + return emoji_base64, selected_emoji.description, emotion + + except Exception as e: + logger.error(f"[EmojiService] 根据情感获取表情包失败: {e}") + return None + + +# ============================================================================= +# 表情包信息查询函数 +# ============================================================================= + + +def get_count() -> int: + try: + return len(emoji_manager.emojis) + except Exception as e: + logger.error(f"[EmojiService] 获取表情包数量失败: {e}") + return 0 + + +def get_info(): + try: + return { + "current_count": len(emoji_manager.emojis), + "max_count": global_config.emoji.max_reg_num, + "available_emojis": len([e for e in emoji_manager.emojis if not e.is_deleted]), + } + except Exception as e: + logger.error(f"[EmojiService] 获取表情包信息失败: {e}") + return {"current_count": 0, "max_count": 0, "available_emojis": 0} + + +def get_emotions() -> List[str]: + try: + emotions = set() + + for emoji_obj in emoji_manager.emojis: + if not emoji_obj.is_deleted and emoji_obj.emotion: + emotions.update(emoji_obj.emotion) + + return sorted(list(emotions)) + except Exception as e: + logger.error(f"[EmojiService] 获取情感标签失败: {e}") + return [] + + +async def get_all() -> List[Tuple[str, str, str]]: + try: + all_emojis = emoji_manager.emojis + + if not all_emojis: + logger.warning("[EmojiService] 没有可用的表情包") + return [] + + results = [] + for emoji_obj in all_emojis: + if emoji_obj.is_deleted: + continue + + emoji_base64 = ImageUtils.image_path_to_base64(str(emoji_obj.full_path)) + + if not emoji_base64: + logger.error(f"[EmojiService] 无法转换表情包为base64: {emoji_obj.full_path}") + continue + + matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else "随机表情" + results.append((emoji_base64, emoji_obj.description, matched_emotion)) + + logger.debug(f"[EmojiService] 成功获取 {len(results)} 个表情包") + return results + + except Exception as e: + logger.error(f"[EmojiService] 获取所有表情包失败: {e}") + return [] + + +def get_descriptions() -> List[str]: + try: + descriptions = [] + + descriptions.extend( + emoji_obj.description + for emoji_obj in emoji_manager.emojis + if not emoji_obj.is_deleted and emoji_obj.description + ) + return descriptions + except Exception as e: + logger.error(f"[EmojiService] 获取表情包描述失败: {e}") + return [] + + +# ============================================================================= +# 表情包注册函数 +# ============================================================================= + + +async def register_emoji(image_base64: str, filename: Optional[str] = None) -> Dict[str, Any]: + """注册新的表情包""" + if not image_base64: + raise ValueError("图片base64编码不能为空") + if not isinstance(image_base64, str): + raise TypeError("image_base64必须是字符串类型") + if filename is not None and not isinstance(filename, str): + raise TypeError("filename必须是字符串类型或None") + + try: + logger.info(f"[EmojiService] 开始注册表情包,文件名: {filename or '自动生成'}") + + count_before = len(emoji_manager.emojis) + max_count = global_config.emoji.max_reg_num + + can_register = count_before < max_count or (count_before >= max_count and global_config.emoji.do_replace) + + if not can_register: + return { + "success": False, + "message": f"表情包数量已达上限({count_before}/{max_count})且未启用替换功能", + "description": None, + "emotions": None, + "replaced": None, + "hash": None, + } + + os.makedirs(EMOJI_DIR, exist_ok=True) + + if not filename: + import time as _time + + timestamp = int(_time.time()) + microseconds = int(_time.time() * 1000000) % 1000000 + + random_bytes = random.getrandbits(72).to_bytes(9, "big") + short_id = base64.b64encode(random_bytes).decode("ascii")[:12].rstrip("=") + short_id = short_id.replace("/", "_").replace("+", "-") + filename = f"emoji_{timestamp}_{microseconds}_{short_id}" + + if not filename.lower().endswith((".jpg", ".jpeg", ".png", ".gif")): + filename = f"{filename}.png" + + temp_file_path = os.path.join(EMOJI_DIR, filename) + attempts = 0 + max_attempts = 10 + while os.path.exists(temp_file_path) and attempts < max_attempts: + random_bytes = random.getrandbits(48).to_bytes(6, "big") + short_id = base64.b64encode(random_bytes).decode("ascii")[:8].rstrip("=") + short_id = short_id.replace("/", "_").replace("+", "-") + + name_part, ext = os.path.splitext(filename) + base_name = name_part.rsplit("_", 1)[0] + filename = f"{base_name}_{short_id}{ext}" + temp_file_path = os.path.join(EMOJI_DIR, filename) + attempts += 1 + + if os.path.exists(temp_file_path): + uuid_short = str(uuid.uuid4())[:8] + name_part, ext = os.path.splitext(filename) + base_name = name_part.rsplit("_", 1)[0] + filename = f"{base_name}_{uuid_short}{ext}" + temp_file_path = os.path.join(EMOJI_DIR, filename) + + counter = 1 + original_filename = filename + while os.path.exists(temp_file_path): + name_part, ext = os.path.splitext(original_filename) + filename = f"{name_part}_{counter}{ext}" + temp_file_path = os.path.join(EMOJI_DIR, filename) + counter += 1 + + if counter > 100: + logger.error(f"[EmojiService] 无法生成唯一文件名,尝试次数过多: {original_filename}") + return { + "success": False, + "message": "无法生成唯一文件名,请稍后重试", + "description": None, + "emotions": None, + "replaced": None, + "hash": None, + } + + try: + if not ImageUtils.base64_to_image(image_base64, temp_file_path): + logger.error(f"[EmojiService] 无法保存base64图片到文件: {temp_file_path}") + return { + "success": False, + "message": "无法保存图片文件", + "description": None, + "emotions": None, + "replaced": None, + "hash": None, + } + + logger.debug(f"[EmojiService] 图片已保存到临时文件: {temp_file_path}") + + except Exception as save_error: + logger.error(f"[EmojiService] 保存图片文件失败: {save_error}") + return { + "success": False, + "message": f"保存图片文件失败: {str(save_error)}", + "description": None, + "emotions": None, + "replaced": None, + "hash": None, + } + + register_success = await emoji_manager.register_emoji_by_filename(filename) + + if not register_success and os.path.exists(temp_file_path): + try: + os.remove(temp_file_path) + logger.debug(f"[EmojiService] 已清理临时文件: {temp_file_path}") + except Exception as cleanup_error: + logger.warning(f"[EmojiService] 清理临时文件失败: {cleanup_error}") + + if register_success: + count_after = len(emoji_manager.emojis) + replaced = count_after <= count_before + + new_emoji_info = None + if count_after > count_before or replaced: + try: + for emoji_obj in reversed(emoji_manager.emojis): + if not emoji_obj.is_deleted and ( + emoji_obj.file_name == filename + or (hasattr(emoji_obj, "full_path") and filename in str(emoji_obj.full_path)) + ): + new_emoji_info = emoji_obj + break + except Exception as find_error: + logger.warning(f"[EmojiService] 查找新注册表情包信息失败: {find_error}") + + description = new_emoji_info.description if new_emoji_info else None + emotions = new_emoji_info.emotion if new_emoji_info else None + emoji_hash = new_emoji_info.emoji_hash if new_emoji_info else None + + return { + "success": True, + "message": f"表情包注册成功 {'(替换旧表情包)' if replaced else '(新增表情包)'}", + "description": description, + "emotions": emotions, + "replaced": replaced, + "hash": emoji_hash, + } + else: + return { + "success": False, + "message": "表情包注册失败,可能因为重复、格式不支持或审核未通过", + "description": None, + "emotions": None, + "replaced": None, + "hash": None, + } + + except Exception as e: + logger.error(f"[EmojiService] 注册表情包时发生异常: {e}") + return { + "success": False, + "message": f"注册过程中发生错误: {str(e)}", + "description": None, + "emotions": None, + "replaced": None, + "hash": None, + } diff --git a/src/plugin_system/apis/frequency_api.py b/src/services/frequency_service.py similarity index 90% rename from src/plugin_system/apis/frequency_api.py rename to src/services/frequency_service.py index 9cde1d90..eceb6b95 100644 --- a/src/plugin_system/apis/frequency_api.py +++ b/src/services/frequency_service.py @@ -1,9 +1,11 @@ -from src.common.logger import get_logger +"""频率控制服务模块 + +提供聊天频率控制的核心功能。 +""" + from src.chat.heart_flow.frequency_control import frequency_control_manager from src.config.config import global_config -logger = get_logger("frequency_api") - def get_current_talk_value(chat_id: str) -> float: return frequency_control_manager.get_or_create_frequency_control( diff --git a/src/plugin_system/apis/generator_api.py b/src/services/generator_service.py similarity index 60% rename from src/plugin_system/apis/generator_api.py rename to src/services/generator_service.py index 3217817c..8b5c5152 100644 --- a/src/plugin_system/apis/generator_api.py +++ b/src/services/generator_service.py @@ -1,39 +1,37 @@ """ -回复器API模块 +回复器服务模块 -提供回复器相关功能,采用标准Python包设计模式 -使用方式: - from src.plugin_system.apis import generator_api - replyer = generator_api.get_replyer(chat_stream) - success, reply_set, _ = await generator_api.generate_reply(chat_stream, action_data, reasoning) +提供回复器相关的核心功能。 """ import traceback import time -from typing import Tuple, Any, Dict, List, Optional, TYPE_CHECKING +from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING + from rich.traceback import install -from src.common.logger import get_logger -from src.common.data_models.message_data_model import ReplySetModel + +from src.chat.logger.plan_reply_logger import PlanReplyLogger +from src.chat.message_receive.chat_manager import BotChatSession from src.chat.replyer.group_generator import DefaultReplyer from src.chat.replyer.private_generator import PrivateReplyer -from src.chat.message_receive.chat_manager import BotChatSession -from src.chat.utils.utils import process_llm_response from src.chat.replyer.replyer_manager import replyer_manager -from src.plugin_system.base.component_types import ActionInfo -from src.chat.logger.plan_reply_logger import PlanReplyLogger +from src.chat.utils.utils import process_llm_response +from src.common.data_models.message_data_model import ReplySetModel +from src.common.logger import get_logger +from src.core.types import ActionInfo if TYPE_CHECKING: - from src.common.data_models.info_data_model import ActionPlannerInfo from src.common.data_models.database_data_model import DatabaseMessages + from src.common.data_models.info_data_model import ActionPlannerInfo from src.common.data_models.llm_data_model import LLMGenerationDataModel install(extra_lines=3) -logger = get_logger("generator_api") +logger = get_logger("generator_service") # ============================================================================= -# 回复器获取API函数 +# 回复器获取函数 # ============================================================================= @@ -42,39 +40,24 @@ def get_replyer( chat_id: Optional[str] = None, request_type: str = "replyer", ) -> Optional[DefaultReplyer | PrivateReplyer]: - """获取回复器对象 - - 优先使用chat_stream,如果没有则使用chat_id直接查找。 - 使用 ReplyerManager 来管理实例,避免重复创建。 - - Args: - chat_stream: 聊天流对象(优先) - chat_id: 聊天ID(实际上就是stream_id) - request_type: 请求类型 - - Returns: - Optional[DefaultReplyer]: 回复器对象,如果获取失败则返回None - - Raises: - ValueError: chat_stream 和 chat_id 均为空 - """ + """获取回复器对象""" if not chat_id and not chat_stream: raise ValueError("chat_stream 和 chat_id 不可均为空") try: - logger.debug(f"[GeneratorAPI] 正在获取回复器,chat_id: {chat_id}, chat_stream: {'有' if chat_stream else '无'}") + logger.debug(f"[GeneratorService] 正在获取回复器,chat_id: {chat_id}, chat_stream: {'有' if chat_stream else '无'}") return replyer_manager.get_replyer( chat_stream=chat_stream, chat_id=chat_id, request_type=request_type, ) except Exception as e: - logger.error(f"[GeneratorAPI] 获取回复器时发生意外错误: {e}", exc_info=True) + logger.error(f"[GeneratorService] 获取回复器时发生意外错误: {e}", exc_info=True) traceback.print_exc() return None # ============================================================================= -# 回复生成API函数 +# 回复生成函数 # ============================================================================= @@ -96,39 +79,15 @@ async def generate_reply( from_plugin: bool = True, reply_time_point: Optional[float] = None, ) -> Tuple[bool, Optional["LLMGenerationDataModel"]]: - """生成回复 - - Args: - chat_stream: 聊天流对象(优先) - chat_id: 聊天ID(备用) - action_data: 动作数据(向下兼容,包含reply_to和extra_info) - reply_message: 回复的消息对象 - extra_info: 额外信息,用于补充上下文 - reply_reason: 回复原因 - available_actions: 可用动作 - chosen_actions: 已选动作 - unknown_words: Planner 在 reply 动作中给出的未知词语列表,用于黑话检索 - enable_tool: 是否启用工具调用 - enable_splitter: 是否启用消息分割器 - enable_chinese_typo: 是否启用错字生成器 - return_prompt: 是否返回提示词 - model_set_with_weight: 模型配置列表,每个元素为 (TaskConfig, weight) 元组 - request_type: 请求类型(可选,记录LLM使用) - from_plugin: 是否来自插件 - reply_time_point: 回复时间点 - Returns: - Tuple[bool, List[Tuple[str, Any]], Optional[str]]: (是否成功, 回复集合, 提示词) - """ + """生成回复""" try: - # 如果 reply_time_point 未传入,设置为当前时间戳 if reply_time_point is None: reply_time_point = time.time() - # 获取回复器 - logger.debug("[GeneratorAPI] 开始生成回复") + logger.debug("[GeneratorService] 开始生成回复") replyer = get_replyer(chat_stream, chat_id, request_type=request_type) if not replyer: - logger.error("[GeneratorAPI] 无法获取回复器") + logger.error("[GeneratorService] 无法获取回复器") return False, None if action_data: @@ -136,11 +95,9 @@ async def generate_reply( extra_info = action_data.get("extra_info", "") if not reply_reason: reply_reason = action_data.get("reason", "") - # 仅在 reply 场景下使用的未知词语解析(Planner JSON 中下发) if unknown_words is None: uw = action_data.get("unknown_words") if isinstance(uw, list): - # 只保留非空字符串 cleaned: List[str] = [] for item in uw: if isinstance(item, str): @@ -150,7 +107,6 @@ async def generate_reply( if cleaned: unknown_words = cleaned - # 调用回复器生成回复 success, llm_response = await replyer.generate_reply_with_context( extra_info=extra_info, available_actions=available_actions, @@ -166,7 +122,7 @@ async def generate_reply( log_reply=False, ) if not success: - logger.warning("[GeneratorAPI] 回复生成失败") + logger.warning("[GeneratorService] 回复生成失败") return False, None reply_set: Optional[ReplySetModel] = None if content := llm_response.content: @@ -176,9 +132,8 @@ async def generate_reply( for text in processed_response: reply_set.add_text_content(text) llm_response.reply_set = reply_set - logger.debug(f"[GeneratorAPI] 回复生成成功,生成了 {len(reply_set) if reply_set else 0} 个回复项") + logger.debug(f"[GeneratorService] 回复生成成功,生成了 {len(reply_set) if reply_set else 0} 个回复项") - # 统一在这里记录最终回复日志(包含分割后的 processed_output) try: PlanReplyLogger.log_reply( chat_id=chat_stream.session_id if chat_stream else (chat_id or ""), @@ -192,7 +147,7 @@ async def generate_reply( success=True, ) except Exception: - logger.exception("[GeneratorAPI] 记录reply日志失败") + logger.exception("[GeneratorService] 记录reply日志失败") return success, llm_response @@ -200,11 +155,11 @@ async def generate_reply( raise ve except UserWarning as uw: - logger.warning(f"[GeneratorAPI] 中断了生成: {uw}") + logger.warning(f"[GeneratorService] 中断了生成: {uw}") return False, None except Exception as e: - logger.error(f"[GeneratorAPI] 生成回复时出错: {e}") + logger.error(f"[GeneratorService] 生成回复时出错: {e}") logger.error(traceback.format_exc()) return False, None @@ -220,39 +175,20 @@ async def rewrite_reply( reply_to: str = "", request_type: str = "generator_api", ) -> Tuple[bool, Optional["LLMGenerationDataModel"]]: - """重写回复 - - Args: - chat_stream: 聊天流对象(优先) - reply_data: 回复数据字典(向下兼容备用,当其他参数缺失时从此获取) - chat_id: 聊天ID(备用) - enable_splitter: 是否启用消息分割器 - enable_chinese_typo: 是否启用错字生成器 - model_set_with_weight: 模型配置列表,每个元素为 (TaskConfig, weight) 元组 - raw_reply: 原始回复内容 - reason: 回复原因 - reply_to: 回复对象 - return_prompt: 是否返回提示词 - - Returns: - Tuple[bool, List[Tuple[str, Any]]]: (是否成功, 回复集合) - """ + """重写回复""" try: - # 获取回复器 replyer = get_replyer(chat_stream, chat_id, request_type=request_type) if not replyer: - logger.error("[GeneratorAPI] 无法获取回复器") + logger.error("[GeneratorService] 无法获取回复器") return False, None - logger.info("[GeneratorAPI] 开始重写回复") + logger.info("[GeneratorService] 开始重写回复") - # 如果参数缺失,从reply_data中获取 if reply_data: raw_reply = raw_reply or reply_data.get("raw_reply", "") reason = reason or reply_data.get("reason", "") reply_to = reply_to or reply_data.get("reply_to", "") - # 调用回复器重写回复 success, llm_response = await replyer.rewrite_reply_with_context( raw_reply=raw_reply, reason=reason, @@ -263,9 +199,9 @@ async def rewrite_reply( reply_set = process_human_text(content, enable_splitter, enable_chinese_typo) llm_response.reply_set = reply_set if success: - logger.info(f"[GeneratorAPI] 重写回复成功,生成了 {len(reply_set) if reply_set else 0} 个回复项") + logger.info(f"[GeneratorService] 重写回复成功,生成了 {len(reply_set) if reply_set else 0} 个回复项") else: - logger.warning("[GeneratorAPI] 重写回复失败") + logger.warning("[GeneratorService] 重写回复失败") return success, llm_response @@ -273,18 +209,12 @@ async def rewrite_reply( raise ve except Exception as e: - logger.error(f"[GeneratorAPI] 重写回复时出错: {e}") + logger.error(f"[GeneratorService] 重写回复时出错: {e}") return False, None def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> Optional[ReplySetModel]: - """将文本处理为更拟人化的文本 - - Args: - content: 文本内容 - enable_splitter: 是否启用消息分割器 - enable_chinese_typo: 是否启用错字生成器 - """ + """将文本处理为更拟人化的文本""" if not isinstance(content, str): raise ValueError("content 必须是字符串类型") try: @@ -297,7 +227,7 @@ def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: return reply_set except Exception as e: - logger.error(f"[GeneratorAPI] 处理人形文本时出错: {e}") + logger.error(f"[GeneratorService] 处理人形文本时出错: {e}") return None @@ -309,18 +239,18 @@ async def generate_response_custom( ) -> Optional[str]: replyer = get_replyer(chat_stream, chat_id, request_type=request_type) if not replyer: - logger.error("[GeneratorAPI] 无法获取回复器") + logger.error("[GeneratorService] 无法获取回复器") return None try: - logger.debug("[GeneratorAPI] 开始生成自定义回复") + logger.debug("[GeneratorService] 开始生成自定义回复") response, _, _, _ = await replyer.llm_generate_content(prompt) if response: - logger.debug("[GeneratorAPI] 自定义回复生成成功") + logger.debug("[GeneratorService] 自定义回复生成成功") return response else: - logger.warning("[GeneratorAPI] 自定义回复生成失败") + logger.warning("[GeneratorService] 自定义回复生成失败") return None except Exception as e: - logger.error(f"[GeneratorAPI] 生成自定义回复时出错: {e}") + logger.error(f"[GeneratorService] 生成自定义回复时出错: {e}") return None diff --git a/src/plugin_system/apis/llm_api.py b/src/services/llm_service.py similarity index 81% rename from src/plugin_system/apis/llm_api.py rename to src/services/llm_service.py index f35b1102..b267e67c 100644 --- a/src/plugin_system/apis/llm_api.py +++ b/src/services/llm_service.py @@ -1,26 +1,19 @@ -"""LLM API模块 +"""LLM 服务模块 -提供了与LLM模型交互的功能 -使用方式: - from src.plugin_system.apis import llm_api - models = llm_api.get_available_models() - success, response, reasoning, model_name = await llm_api.generate_with_model(prompt, model_config) +提供与 LLM 模型交互的核心功能。 """ -from typing import Tuple, Dict, List, Any, Optional, Callable +from typing import Any, Callable, Dict, List, Optional, Tuple + from src.common.logger import get_logger -from src.llm_models.payload_content.tool_option import ToolCall -from src.llm_models.payload_content.message import Message -from src.llm_models.model_client.base_client import BaseClient -from src.llm_models.utils_model import LLMRequest from src.config.config import config_manager from src.config.model_configs import TaskConfig +from src.llm_models.model_client.base_client import BaseClient +from src.llm_models.payload_content.message import Message +from src.llm_models.payload_content.tool_option import ToolCall +from src.llm_models.utils_model import LLMRequest -logger = get_logger("llm_api") - -# ============================================================================= -# LLM模型API函数 -# ============================================================================= +logger = get_logger("llm_service") def get_available_models() -> Dict[str, TaskConfig]: @@ -30,7 +23,6 @@ def get_available_models() -> Dict[str, TaskConfig]: Dict[str, Any]: 模型配置字典,key为模型名称,value为模型配置 """ try: - # 自动获取所有属性并转换为字典形式 models = config_manager.get_model_config().model_task_config attrs = dir(models) rets: Dict[str, TaskConfig] = {} @@ -41,12 +33,12 @@ def get_available_models() -> Dict[str, TaskConfig]: if not callable(value) and isinstance(value, TaskConfig): rets[attr] = value except Exception as e: - logger.debug(f"[LLMAPI] 获取属性 {attr} 失败: {e}") + logger.debug(f"[LLMService] 获取属性 {attr} 失败: {e}") continue return rets except Exception as e: - logger.error(f"[LLMAPI] 获取可用模型失败: {e}") + logger.error(f"[LLMService] 获取可用模型失败: {e}") return {} @@ -68,9 +60,7 @@ async def generate_with_model( Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称) """ try: - # model_name_list = model_config.model_list - # logger.info(f"[LLMAPI] 使用模型集合 {model_name_list} 生成内容") - logger.debug(f"[LLMAPI] 完整提示词: {prompt}") + logger.debug(f"[LLMService] 完整提示词: {prompt}") llm_request = LLMRequest(model_set=model_config, request_type=request_type) @@ -81,7 +71,7 @@ async def generate_with_model( except Exception as e: error_msg = f"生成内容时出错: {str(e)}" - logger.error(f"[LLMAPI] {error_msg}") + logger.error(f"[LLMService] {error_msg}") return False, error_msg, "", "" @@ -104,7 +94,7 @@ async def generate_with_model_with_tools( max_tokens: 最大token数 Returns: - Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称) + Tuple[bool, str, str, str, List[ToolCall] | None]: (是否成功, 生成的内容, 推理过程, 模型名称, 工具调用列表) """ try: model_name_list = model_config.model_list @@ -120,7 +110,7 @@ async def generate_with_model_with_tools( except Exception as e: error_msg = f"生成内容时出错: {str(e)}" - logger.error(f"[LLMAPI] {error_msg}") + logger.error(f"[LLMService] {error_msg}") return False, error_msg, "", "", None @@ -161,5 +151,5 @@ async def generate_with_model_with_tools_by_message_factory( except Exception as e: error_msg = f"生成内容时出错: {str(e)}" - logger.error(f"[LLMAPI] {error_msg}") + logger.error(f"[LLMService] {error_msg}") return False, error_msg, "", "", None diff --git a/src/plugin_system/apis/message_api.py b/src/services/message_service.py similarity index 60% rename from src/plugin_system/apis/message_api.py rename to src/services/message_service.py index b7d1d2cf..7b175dfe 100644 --- a/src/plugin_system/apis/message_api.py +++ b/src/services/message_service.py @@ -1,62 +1,44 @@ """ -消息API模块 +消息服务模块 -提供消息查询和构建成字符串的功能,采用标准Python包设计模式 -使用方式: - from src.plugin_system.apis import message_api - messages = message_api.get_messages_by_time_in_chat(chat_id, start_time, end_time) - readable_text = message_api.build_readable_messages(messages) +提供消息查询和构建成字符串的核心功能。 """ import time -from typing import List, Dict, Any, Tuple, Optional +from typing import Any, Dict, List, Optional, Tuple + from sqlmodel import col, select -from src.common.data_models.database_data_model import DatabaseMessages -from src.common.database.database import get_db_session -from src.common.database.database_model import Images, ImageType -from src.chat.utils.utils import is_bot_self + from src.chat.utils.chat_message_builder import ( - get_raw_msg_by_timestamp, - get_raw_msg_by_timestamp_with_chat, - get_raw_msg_by_timestamp_with_chat_inclusive, - get_raw_msg_by_timestamp_with_chat_users, - get_raw_msg_by_timestamp_random, - get_raw_msg_by_timestamp_with_users, - get_raw_msg_before_timestamp, - get_raw_msg_before_timestamp_with_chat, - get_raw_msg_before_timestamp_with_users, - num_new_messages_since, - num_new_messages_since_with_users, build_readable_messages, build_readable_messages_with_list, get_person_id_list, + get_raw_msg_before_timestamp, + get_raw_msg_before_timestamp_with_chat, + get_raw_msg_before_timestamp_with_users, + get_raw_msg_by_timestamp, + get_raw_msg_by_timestamp_random, + get_raw_msg_by_timestamp_with_chat, + get_raw_msg_by_timestamp_with_chat_inclusive, + get_raw_msg_by_timestamp_with_chat_users, + get_raw_msg_by_timestamp_with_users, + num_new_messages_since, + num_new_messages_since_with_users, ) +from src.chat.utils.utils import is_bot_self +from src.common.data_models.database_data_model import DatabaseMessages +from src.common.database.database import get_db_session +from src.common.database.database_model import Images, ImageType # ============================================================================= -# 消息查询API函数 +# 消息查询函数 # ============================================================================= def get_messages_by_time( start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False ) -> List[DatabaseMessages]: - """ - 获取指定时间范围内的消息 - - Args: - start_time: 开始时间戳 - end_time: 结束时间戳 - limit: 限制返回的消息数量,0为不限制 - limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录 - filter_mai: 是否过滤麦麦自身的消息,默认为False - - Returns: - List[Dict[str, Any]]: 消息列表 - - Raises: - ValueError: 如果参数不合法 - """ if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): raise ValueError("start_time 和 end_time 必须是数字类型") if limit < 0: @@ -76,23 +58,6 @@ def get_messages_by_time_in_chat( filter_command: bool = False, filter_intercept_message_level: Optional[int] = None, ) -> List[DatabaseMessages]: - """ - 获取指定聊天中指定时间范围内的消息 - - Args: - chat_id: 聊天ID - start_time: 开始时间戳 - end_time: 结束时间戳 - limit: 限制返回的消息数量,0为不限制 - limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录 - filter_mai: 是否过滤麦麦自身的消息,默认为False - filter_command: 是否过滤命令消息,默认为False - Returns: - List[Dict[str, Any]]: 消息列表 - - Raises: - ValueError: 如果参数不合法 - """ if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): raise ValueError("start_time 和 end_time 必须是数字类型") if limit < 0: @@ -101,10 +66,6 @@ def get_messages_by_time_in_chat( raise ValueError("chat_id 不能为空") if not isinstance(chat_id, str): raise ValueError("chat_id 必须是字符串类型") - # if filter_mai: - # return filter_mai_messages( - # get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command) - # ) return get_raw_msg_by_timestamp_with_chat( chat_id=chat_id, timestamp_start=start_time, @@ -127,23 +88,6 @@ def get_messages_by_time_in_chat_inclusive( filter_command: bool = False, filter_intercept_message_level: Optional[int] = None, ) -> List[DatabaseMessages]: - """ - 获取指定聊天中指定时间范围内的消息(包含边界) - - Args: - chat_id: 聊天ID - start_time: 开始时间戳(包含) - end_time: 结束时间戳(包含) - limit: 限制返回的消息数量,0为不限制 - limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录 - filter_mai: 是否过滤麦麦自身的消息,默认为False - - Returns: - List[Dict[str, Any]]: 消息列表 - - Raises: - ValueError: 如果参数不合法 - """ if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): raise ValueError("start_time 和 end_time 必须是数字类型") if limit < 0: @@ -175,23 +119,6 @@ def get_messages_by_time_in_chat_for_users( limit: int = 0, limit_mode: str = "latest", ) -> List[DatabaseMessages]: - """ - 获取指定聊天中指定用户在指定时间范围内的消息 - - Args: - chat_id: 聊天ID - start_time: 开始时间戳 - end_time: 结束时间戳 - person_ids: 用户ID列表 - limit: 限制返回的消息数量,0为不限制 - limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录 - - Returns: - List[Dict[str, Any]]: 消息列表 - - Raises: - ValueError: 如果参数不合法 - """ if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): raise ValueError("start_time 和 end_time 必须是数字类型") if limit < 0: @@ -206,22 +133,6 @@ def get_messages_by_time_in_chat_for_users( def get_random_chat_messages( start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False ) -> List[DatabaseMessages]: - """ - 随机选择一个聊天,返回该聊天在指定时间范围内的消息 - - Args: - start_time: 开始时间戳 - end_time: 结束时间戳 - limit: 限制返回的消息数量,0为不限制 - limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录 - filter_mai: 是否过滤麦麦自身的消息,默认为False - - Returns: - List[Dict[str, Any]]: 消息列表 - - Raises: - ValueError: 如果参数不合法 - """ if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): raise ValueError("start_time 和 end_time 必须是数字类型") if limit < 0: @@ -234,22 +145,6 @@ def get_random_chat_messages( def get_messages_by_time_for_users( start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest" ) -> List[DatabaseMessages]: - """ - 获取指定用户在所有聊天中指定时间范围内的消息 - - Args: - start_time: 开始时间戳 - end_time: 结束时间戳 - person_ids: 用户ID列表 - limit: 限制返回的消息数量,0为不限制 - limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录 - - Returns: - List[Dict[str, Any]]: 消息列表 - - Raises: - ValueError: 如果参数不合法 - """ if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): raise ValueError("start_time 和 end_time 必须是数字类型") if limit < 0: @@ -258,20 +153,6 @@ def get_messages_by_time_for_users( def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[DatabaseMessages]: - """ - 获取指定时间戳之前的消息 - - Args: - timestamp: 时间戳 - limit: 限制返回的消息数量,0为不限制 - filter_mai: 是否过滤麦麦自身的消息,默认为False - - Returns: - List[Dict[str, Any]]: 消息列表 - - Raises: - ValueError: 如果参数不合法 - """ if not isinstance(timestamp, (int, float)): raise ValueError("timestamp 必须是数字类型") if limit < 0: @@ -288,21 +169,6 @@ def get_messages_before_time_in_chat( filter_mai: bool = False, filter_intercept_message_level: Optional[int] = None, ) -> List[DatabaseMessages]: - """ - 获取指定聊天中指定时间戳之前的消息 - - Args: - chat_id: 聊天ID - timestamp: 时间戳 - limit: 限制返回的消息数量,0为不限制 - filter_mai: 是否过滤麦麦自身的消息,默认为False - - Returns: - List[Dict[str, Any]]: 消息列表 - - Raises: - ValueError: 如果参数不合法 - """ if not isinstance(timestamp, (int, float)): raise ValueError("timestamp 必须是数字类型") if limit < 0: @@ -325,20 +191,6 @@ def get_messages_before_time_in_chat( def get_messages_before_time_for_users( timestamp: float, person_ids: List[str], limit: int = 0 ) -> List[DatabaseMessages]: - """ - 获取指定用户在指定时间戳之前的消息 - - Args: - timestamp: 时间戳 - person_ids: 用户ID列表 - limit: 限制返回的消息数量,0为不限制 - - Returns: - List[Dict[str, Any]]: 消息列表 - - Raises: - ValueError: 如果参数不合法 - """ if not isinstance(timestamp, (int, float)): raise ValueError("timestamp 必须是数字类型") if limit < 0: @@ -349,22 +201,6 @@ def get_messages_before_time_for_users( def get_recent_messages( chat_id: str, hours: float = 24.0, limit: int = 100, limit_mode: str = "latest", filter_mai: bool = False ) -> List[DatabaseMessages]: - """ - 获取指定聊天中最近一段时间的消息 - - Args: - chat_id: 聊天ID - hours: 最近多少小时,默认24小时 - limit: 限制返回的消息数量,默认100条 - limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录 - filter_mai: 是否过滤麦麦自身的消息,默认为False - - Returns: - List[Dict[str, Any]]: 消息列表 - - Raises: - ValueError: 如果参数不合法s - """ if not isinstance(hours, (int, float)) or hours < 0: raise ValueError("hours 不能是负数") if not isinstance(limit, int) or limit < 0: @@ -381,25 +217,11 @@ def get_recent_messages( # ============================================================================= -# 消息计数API函数 +# 消息计数函数 # ============================================================================= def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional[float] = None) -> int: - """ - 计算指定聊天中从开始时间到结束时间的新消息数量 - - Args: - chat_id: 聊天ID - start_time: 开始时间戳 - end_time: 结束时间戳,如果为None则使用当前时间 - - Returns: - int: 新消息数量 - - Raises: - ValueError: 如果参数不合法 - """ if not isinstance(start_time, (int, float)): raise ValueError("start_time 必须是数字类型") if not chat_id: @@ -410,21 +232,6 @@ def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> int: - """ - 计算指定聊天中指定用户从开始时间到结束时间的新消息数量 - - Args: - chat_id: 聊天ID - start_time: 开始时间戳 - end_time: 结束时间戳 - person_ids: 用户ID列表 - - Returns: - int: 新消息数量 - - Raises: - ValueError: 如果参数不合法 - """ if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): raise ValueError("start_time 和 end_time 必须是数字类型") if not chat_id: @@ -435,7 +242,7 @@ def count_new_messages_for_users(chat_id: str, start_time: float, end_time: floa # ============================================================================= -# 消息格式化API函数 +# 消息格式化函数 # ============================================================================= @@ -447,21 +254,6 @@ def build_readable_messages_to_str( truncate: bool = False, show_actions: bool = False, ) -> str: - """ - 将消息列表构建成可读的字符串 - - Args: - messages: 消息列表 - replace_bot_name: 是否将机器人的名称替换为"你" - merge_messages: 是否合并连续消息 - timestamp_mode: 时间戳显示模式,'relative'或'absolute' - read_mark: 已读标记时间戳,用于分割已读和未读消息 - truncate: 是否截断长消息 - show_actions: 是否显示动作记录 - - Returns: - 格式化后的可读字符串 - """ return build_readable_messages(messages, replace_bot_name, timestamp_mode, read_mark, truncate, show_actions) @@ -471,32 +263,10 @@ async def build_readable_messages_with_details( timestamp_mode: str = "relative", truncate: bool = False, ) -> Tuple[str, List[Tuple[float, str, str]]]: - """ - 将消息列表构建成可读的字符串,并返回详细信息 - - Args: - messages: 消息列表 - replace_bot_name: 是否将机器人的名称替换为"你" - merge_messages: 是否合并连续消息 - timestamp_mode: 时间戳显示模式,'relative'或'absolute' - truncate: 是否截断长消息 - - Returns: - 格式化后的可读字符串和详细信息元组列表(时间戳, 昵称, 内容) - """ return await build_readable_messages_with_list(messages, replace_bot_name, timestamp_mode, truncate) async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[str]: - """ - 从消息列表中提取不重复的用户ID列表 - - Args: - messages: 消息列表 - - Returns: - 用户ID列表 - """ return await get_person_id_list(messages) @@ -506,14 +276,7 @@ async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[s def filter_mai_messages(messages: List[DatabaseMessages]) -> List[DatabaseMessages]: - """ - 从消息列表中移除麦麦的消息 - Args: - messages: 消息列表,每个元素是消息字典 - Returns: - 过滤后的消息列表 - """ - # 使用统一的 is_bot_self 函数判断是否是机器人自己(支持多平台,包括 WebUI) + """从消息列表中移除麦麦的消息""" return [msg for msg in messages if not is_bot_self(msg.user_info.platform, msg.user_info.user_id)] diff --git a/src/plugin_system/apis/person_api.py b/src/services/person_service.py similarity index 55% rename from src/plugin_system/apis/person_api.py rename to src/services/person_service.py index ed904003..74c02feb 100644 --- a/src/plugin_system/apis/person_api.py +++ b/src/services/person_service.py @@ -1,22 +1,14 @@ -"""个人信息API模块 +"""个人信息服务模块 -提供个人信息查询功能,用于插件获取用户相关信息 -使用方式: - from src.plugin_system.apis import person_api - person_id = person_api.get_person_id("qq", 123456) - value = await person_api.get_person_value(person_id, "nickname") +提供个人信息查询的核心功能。 """ from typing import Any + from src.common.logger import get_logger from src.person_info.person_info import Person -logger = get_logger("person_api") - - -# ============================================================================= -# 个人信息API函数 -# ============================================================================= +logger = get_logger("person_service") def get_person_id(platform: str, user_id: int | str) -> str: @@ -28,14 +20,11 @@ def get_person_id(platform: str, user_id: int | str) -> str: Returns: str: 唯一的person_id(MD5哈希值) - - 示例: - person_id = person_api.get_person_id("qq", 123456) """ try: return Person(platform=platform, user_id=str(user_id)).person_id except Exception as e: - logger.error(f"[PersonAPI] 获取person_id失败: platform={platform}, user_id={user_id}, error={e}") + logger.error(f"[PersonService] 获取person_id失败: platform={platform}, user_id={user_id}, error={e}") return "" @@ -49,17 +38,13 @@ async def get_person_value(person_id: str, field_name: str, default: Any = None) Returns: Any: 字段值或默认值 - - 示例: - nickname = await person_api.get_person_value(person_id, "nickname", "未知用户") - impression = await person_api.get_person_value(person_id, "impression") """ try: person = Person(person_id=person_id) value = getattr(person, field_name) return value if value is not None else default except Exception as e: - logger.error(f"[PersonAPI] 获取用户信息失败: person_id={person_id}, field={field_name}, error={e}") + logger.error(f"[PersonService] 获取用户信息失败: person_id={person_id}, field={field_name}, error={e}") return default @@ -71,13 +56,10 @@ def get_person_id_by_name(person_name: str) -> str: Returns: str: person_id,如果未找到返回空字符串 - - 示例: - person_id = person_api.get_person_id_by_name("张三") """ try: person = Person(person_name=person_name) return person.person_id except Exception as e: - logger.error(f"[PersonAPI] 根据用户名获取person_id失败: person_name={person_name}, error={e}") + logger.error(f"[PersonService] 根据用户名获取person_id失败: person_name={person_name}, error={e}") return "" diff --git a/src/plugin_system/apis/send_api.py b/src/services/send_service.py similarity index 70% rename from src/plugin_system/apis/send_api.py rename to src/services/send_service.py index b18f8378..0f2eac85 100644 --- a/src/plugin_system/apis/send_api.py +++ b/src/services/send_service.py @@ -1,47 +1,32 @@ """ -发送API模块 +发送服务模块 -专门负责发送各种类型的消息,采用标准Python包设计模式 - -使用方式: - from src.plugin_system.apis import send_api - - # 方式1:直接使用stream_id(推荐) - await send_api.text_to_stream("hello", stream_id) - await send_api.emoji_to_stream(emoji_base64, stream_id) - await send_api.custom_to_stream("video", video_data, stream_id) - - # 方式2:使用群聊/私聊指定函数 - await send_api.text_to_group("hello", "123456") - await send_api.text_to_user("hello", "987654") - - # 方式3:使用通用custom_message函数 - await send_api.custom_message("video", video_data, "123456", True) +提供发送各种类型消息的核心功能。 """ import traceback import time from typing import Optional, Union, Dict, List, TYPE_CHECKING, Tuple -from src.common.logger import get_logger -from src.common.data_models.message_data_model import ReplyContentType -from src.config.config import global_config -from src.chat.message_receive.chat_manager import chat_manager as _chat_manager -from src.chat.message_receive.uni_message_sender import UniversalMessageSender -from maim_message import Seg +from maim_message import MessageBase, BaseMessageInfo, Seg -from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo +from src.chat.message_receive.chat_manager import chat_manager as _chat_manager from src.chat.message_receive.message import MessageSending +from src.chat.message_receive.uni_message_sender import UniversalMessageSender +from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo +from src.common.data_models.message_data_model import ReplyContentType +from src.common.logger import get_logger +from src.config.config import global_config if TYPE_CHECKING: from src.common.data_models.database_data_model import DatabaseMessages - from src.common.data_models.message_data_model import ReplySetModel, ReplyContent, ForwardNode + from src.common.data_models.message_data_model import ForwardNode, ReplyContent, ReplySetModel -logger = get_logger("send_api") +logger = get_logger("send_service") # ============================================================================= -# 内部实现函数(不暴露给外部) +# 内部实现函数 # ============================================================================= @@ -56,42 +41,25 @@ async def _send_to_target( show_log: bool = True, selected_expressions: Optional[List[int]] = None, ) -> bool: - """向指定目标发送消息的内部实现 - - Args: - message_segment: - stream_id: 目标流ID - display_message: 显示消息 - typing: 是否模拟打字等待。 - reply_to: 回复消息,格式为"发送者:消息内容" - storage_message: 是否存储消息到数据库 - show_log: 发送是否显示日志 - - Returns: - bool: 是否发送成功 - """ + """向指定目标发送消息的内部实现""" try: if set_reply and not reply_message: - logger.warning("[SendAPI] 使用引用回复,但未提供回复消息") + logger.warning("[SendService] 使用引用回复,但未提供回复消息") return False if show_log: - logger.debug(f"[SendAPI] 发送{message_segment.type}消息到 {stream_id}") + logger.debug(f"[SendService] 发送{message_segment.type}消息到 {stream_id}") - # 查找目标聊天流 target_stream = _chat_manager.get_session_by_session_id(stream_id) if not target_stream: - logger.error(f"[SendAPI] 未找到聊天流: {stream_id}") + logger.error(f"[SendService] 未找到聊天流: {stream_id}") return False - # 创建发送器 message_sender = UniversalMessageSender() - # 生成消息ID current_time = time.time() message_id = f"send_api_{int(current_time * 1000)}" - # 构建机器人用户信息 bot_user_info = UserInfo( user_id=global_config.bot.qq_account, user_nickname=global_config.bot.nickname, @@ -102,17 +70,15 @@ async def _send_to_target( if reply_message: anchor_message = db_message_to_mai_message(reply_message) if anchor_message: - logger.debug(f"[SendAPI] 找到匹配的回复消息,发送者: {anchor_message.message_info.user_info.user_id}") + logger.debug(f"[SendService] 找到匹配的回复消息,发送者: {anchor_message.message_info.user_info.user_id}") reply_to_platform_id = ( f"{anchor_message.platform}:{anchor_message.message_info.user_info.user_id}" ) - # 构建 sender_info(私聊时为接收者信息) sender_info = None if target_stream.context and target_stream.context.message: sender_info = target_stream.context.message.message_info.user_info - # 构建发送消息对象 bot_message = MessageSending( message_id=message_id, session=target_stream, @@ -128,7 +94,6 @@ async def _send_to_target( selected_expressions=selected_expressions, ) - # 发送消息 sent_msg = await message_sender.send_message( bot_message, typing=typing, @@ -138,28 +103,22 @@ async def _send_to_target( ) if sent_msg: - logger.debug(f"[SendAPI] 成功发送消息到 {stream_id}") + logger.debug(f"[SendService] 成功发送消息到 {stream_id}") return True else: - logger.error("[SendAPI] 发送消息失败") + logger.error("[SendService] 发送消息失败") return False except Exception as e: - logger.error(f"[SendAPI] 发送消息时出错: {e}") + logger.error(f"[SendService] 发送消息时出错: {e}") traceback.print_exc() return False def db_message_to_mai_message(message_obj: "DatabaseMessages") -> Optional[MaiMessage]: - """将数据库消息重建为 MaiMessage 对象,用于回复引用。 - - Args: - message_obj: 插件系统的 DatabaseMessages 数据对象 - - Returns: - Optional[MaiMessage]: 构建的消息对象,如果信息不足则返回 None - """ + """将数据库消息重建为 MaiMessage 对象,用于回复引用。""" from datetime import datetime + from src.common.data_models.mai_message_data_model import GroupInfo, MessageInfo from src.common.data_models.message_component_data_model import MessageSequence @@ -190,7 +149,7 @@ def db_message_to_mai_message(message_obj: "DatabaseMessages") -> Optional[MaiMe # ============================================================================= -# 公共API函数 - 预定义类型的发送函数 +# 公共函数 - 预定义类型的发送函数 # ============================================================================= @@ -203,18 +162,7 @@ async def text_to_stream( storage_message: bool = True, selected_expressions: Optional[List[int]] = None, ) -> bool: - """向指定流发送文本消息 - - Args: - text: 要发送的文本内容 - stream_id: 聊天流ID - typing: 是否显示正在输入 - reply_to: 回复消息,格式为"发送者:消息内容" - storage_message: 是否存储消息到数据库 - - Returns: - bool: 是否发送成功 - """ + """向指定流发送文本消息""" return await _send_to_target( message_segment=Seg(type="text", data=text), stream_id=stream_id, @@ -234,16 +182,7 @@ async def emoji_to_stream( set_reply: bool = False, reply_message: Optional["DatabaseMessages"] = None, ) -> bool: - """向指定流发送表情包 - - Args: - emoji_base64: 表情包的base64编码 - stream_id: 聊天流ID - storage_message: 是否存储消息到数据库 - - Returns: - bool: 是否发送成功 - """ + """向指定流发送表情包""" return await _send_to_target( message_segment=Seg(type="emoji", data=emoji_base64), stream_id=stream_id, @@ -262,16 +201,7 @@ async def image_to_stream( set_reply: bool = False, reply_message: Optional["DatabaseMessages"] = None, ) -> bool: - """向指定流发送图片 - - Args: - image_base64: 图片的base64编码 - stream_id: 聊天流ID - storage_message: 是否存储消息到数据库 - - Returns: - bool: 是否发送成功 - """ + """向指定流发送图片""" return await _send_to_target( message_segment=Seg(type="image", data=image_base64), stream_id=stream_id, @@ -289,17 +219,7 @@ async def command_to_stream( storage_message: bool = True, display_message: str = "", ) -> bool: - """向指定流发送命令 - - Args: - command: 命令 - stream_id: 聊天流ID - storage_message: 是否存储消息到数据库 - display_message: 显示消息 - - Returns: - bool: 是否发送成功 - """ + """向指定流发送命令""" return await _send_to_target( message_segment=Seg(type="command", data=command), # type: ignore stream_id=stream_id, @@ -321,20 +241,7 @@ async def custom_to_stream( storage_message: bool = True, show_log: bool = True, ) -> bool: - """向指定流发送自定义类型消息 - - Args: - message_type: 消息类型,如"text"、"image"、"emoji"、"video"、"file"等 - content: 消息内容(通常是base64编码或文本) - stream_id: 聊天流ID - display_message: 显示消息 - typing: 是否显示正在输入 - reply_to: 回复消息,格式为"发送者:消息内容" - storage_message: 是否存储消息到数据库 - show_log: 是否显示日志 - Returns: - bool: 是否发送成功 - """ + """向指定流发送自定义类型消息""" return await _send_to_target( message_segment=Seg(type=message_type, data=content), # type: ignore stream_id=stream_id, @@ -350,25 +257,14 @@ async def custom_to_stream( async def custom_reply_set_to_stream( reply_set: "ReplySetModel", stream_id: str, - display_message: str = "", # 基本没用 + display_message: str = "", typing: bool = False, reply_message: Optional["DatabaseMessages"] = None, set_reply: bool = False, storage_message: bool = True, show_log: bool = True, ) -> bool: - """ - 向指定流发送混合型消息集 - - Args: - reply_set: ReplySetModel 对象,包含多个 ReplyContent - stream_id: 聊天流ID - display_message: 显示消息 - typing: 是否显示正在输入 - reply_to: 回复消息,格式为"发送者:消息内容" - storage_message: 是否存储消息到数据库 - show_log: 是否显示日志 - """ + """向指定流发送混合型消息集""" flag: bool = True for reply_content in reply_set.reply_data: status: bool = False @@ -386,20 +282,14 @@ async def custom_reply_set_to_stream( if not status: flag = False logger.error( - f"[SendAPI] 发送{repr(reply_content.content_type)}消息失败,消息内容:{str(reply_content.content)[:100]}" + f"[SendService] 发送{repr(reply_content.content_type)}消息失败,消息内容:{str(reply_content.content)[:100]}" ) return flag def _parse_content_to_seg(reply_content: "ReplyContent") -> Tuple[Seg, bool]: - """ - 把 ReplyContent 转换为 Seg 结构 (Forward 中仅递归一次) - Args: - reply_content: ReplyContent 对象 - Returns: - Tuple[Seg, bool]: 转换后的 Seg 结构和是否需要typing的标志 - """ + """把 ReplyContent 转换为 Seg 结构""" content_type = reply_content.content_type if content_type == ReplyContentType.TEXT: text_data: str = reply_content.content # type: ignore @@ -427,7 +317,7 @@ def _parse_content_to_seg(reply_content: "ReplyContent") -> Tuple[Seg, bool]: elif sub_content_type == ReplyContentType.EMOJI: sub_seg_list.append(Seg(type="emoji", data=sub_content_data)) # type: ignore else: - logger.warning(f"[SendAPI] 混合类型中不支持的子内容类型: {repr(sub_content_type)}") + logger.warning(f"[SendService] 混合类型中不支持的子内容类型: {repr(sub_content_type)}") continue return Seg(type="seglist", data=sub_seg_list), True elif content_type == ReplyContentType.FORWARD: diff --git a/src/webui/routers/plugin.py b/src/webui/routers/plugin.py index ab6ca479..32102235 100644 --- a/src/webui/routers/plugin.py +++ b/src/webui/routers/plugin.py @@ -6,7 +6,7 @@ import json from src.common.logger import get_logger from src.common.toml_utils import save_toml_with_format from src.config.config import MMC_VERSION -from src.plugin_system.base.config_types import ConfigField +from src.core.config_types import ConfigField from src.webui.services.git_mirror_service import get_git_mirror_service, set_update_progress_callback from src.webui.core import get_token_manager from src.webui.routers.websocket.plugin_progress import update_progress @@ -222,15 +222,16 @@ def coerce_types(schema_part: Dict[str, Any], config_part: Dict[str, Any]) -> No def find_plugin_instance(plugin_id: str) -> Optional[Any]: """ - 按 plugin_id 或 plugin_name 查找已加载的插件实例。 - 局部导入 plugin_manager 以规避循环依赖。 + 按 plugin_id 查找已加载的插件信息。 + 新运行时中插件运行在子进程,无法获取实例,返回注册信息。 """ - from src.plugin_system.core.plugin_manager import plugin_manager + from src.plugin_runtime.integration import get_plugin_runtime_manager - for loaded_plugin_name in plugin_manager.list_loaded_plugins(): - instance = plugin_manager.get_plugin_instance(loaded_plugin_name) - if instance and (instance.plugin_name == plugin_id or instance.get_manifest_info("id", "") == plugin_id): - return instance + mgr = get_plugin_runtime_manager() + for sv in mgr.supervisors: + reg = sv._registered_plugins.get(plugin_id) + if reg is not None: + return reg return None @@ -1497,26 +1498,10 @@ async def get_plugin_config_schema( logger.info(f"获取插件配置 Schema: {plugin_id}") try: - # 尝试从已加载的插件中获取 - from src.plugin_system.core.plugin_manager import plugin_manager - - # 查找插件实例 + # 新运行时中插件运行在子进程,无法直接获取实例的 webui_config_schema + # 尝试从文件系统读取 plugin_instance = None - # 遍历所有已加载的插件 - for loaded_plugin_name in plugin_manager.list_loaded_plugins(): - instance = plugin_manager.get_plugin_instance(loaded_plugin_name) - if instance: - # 匹配 plugin_name 或 manifest 中的 id - if instance.plugin_name == plugin_id: - plugin_instance = instance - break - # 也尝试匹配 manifest 中的 id - manifest_id = instance.get_manifest_info("id", "") - if manifest_id == plugin_id: - plugin_instance = instance - break - if plugin_instance and hasattr(plugin_instance, "get_webui_config_schema"): # 从插件实例获取 schema schema = plugin_instance.get_webui_config_schema()