重构整个插件系统,尝试恢复可启动性,新增插件系统maibot-plugin-sdk依赖

This commit is contained in:
DrSmoothl
2026-03-07 19:40:51 +08:00
parent 2e3dd44ee9
commit ce8d8dfd0a
90 changed files with 3785 additions and 10061 deletions

6
bot.py
View File

@@ -195,11 +195,11 @@ async def graceful_shutdown(): # sourcery skip: use-named-expression
# except Exception as e: # except Exception as e:
# logger.warning(f"关闭 WebUI 服务器时出错: {e}") # logger.warning(f"关闭 WebUI 服务器时出错: {e}")
from src.plugin_system.core.events_manager import events_manager from src.core.event_bus import event_bus
from src.plugin_system.base.component_types import EventType from src.core.types import EventType
# 触发 ON_STOP 事件 # 触发 ON_STOP 事件
await events_manager.handle_mai_events(event_type=EventType.ON_STOP) await event_bus.emit(event_type=EventType.ON_STOP)
# 停止新版本插件运行时 # 停止新版本插件运行时
from src.plugin_runtime.integration import get_plugin_runtime_manager from src.plugin_runtime.integration import get_plugin_runtime_manager

View File

@@ -9,7 +9,7 @@
}, },
"license": "GPL-v3.0-or-later", "license": "GPL-v3.0-or-later",
"host_application": { "host_application": {
"min_version": "0.10.3" "min_version": "1.0.0"
}, },
"homepage_url": "https://github.com/SengokuCola/BetterFrequency", "homepage_url": "https://github.com/SengokuCola/BetterFrequency",
"repository_url": "https://github.com/SengokuCola/BetterFrequency", "repository_url": "https://github.com/SengokuCola/BetterFrequency",

View File

@@ -1,144 +1,87 @@
from typing import List, Tuple, Type, Optional """发言频率控制插件 — 新 SDK 版本
from src.plugin_system import BasePlugin, register_plugin, BaseCommand, ComponentInfo, ConfigField
from src.plugin_system.apis import send_api, frequency_api 通过 /chat 命令设置和查看聊天频率。
"""
from maibot_sdk import MaiBotPlugin, Command
class SetTalkFrequencyCommand(BaseCommand): class BetterFrequencyPlugin(MaiBotPlugin):
"""设置当前聊天的talk_frequency值""" """聊天频率控制插件"""
command_name = "set_talk_frequency" @Command(
command_description = "设置当前聊天的talk_frequency值/chat talk_frequency <数字> 或 /chat t <数字>" "set_talk_frequency",
command_pattern = r"^/chat\s+(?:talk_frequency|t)\s+(?P<value>[+-]?\d*\.?\d+)$" description="设置当前聊天的talk_frequency值/chat talk_frequency <数字> 或 /chat t <数字>",
pattern=r"^/chat\s+(?:talk_frequency|t)\s+(?P<value>[+-]?\d*\.?\d+)$",
)
async def handle_set_talk_frequency(
self, stream_id: str = "", matched_groups: dict | None = None, **kwargs
):
"""设置当前聊天的 talk_frequency"""
if not matched_groups or "value" not in matched_groups:
return False, "命令格式错误", False
value_str = matched_groups["value"]
if not value_str:
return False, "无法获取数值参数", False
async def execute(self) -> Tuple[bool, Optional[str], bool]:
try: try:
# 获取命令参数 - 使用命名捕获组
if not self.matched_groups or "value" not in self.matched_groups:
return False, "命令格式错误", False
value_str = self.matched_groups["value"]
if not value_str:
return False, "无法获取数值参数", False
value = float(value_str) value = float(value_str)
# 获取聊天流ID
if not self.message.chat_stream or not hasattr(self.message.chat_stream, "stream_id"):
return False, "无法获取聊天流信息", False
chat_id = self.message.chat_stream.stream_id
# 设置talk_frequency
frequency_api.set_talk_frequency_adjust(chat_id, value)
final_value = frequency_api.get_current_talk_value(chat_id)
adjust_value = frequency_api.get_talk_frequency_adjust(chat_id)
base_value = final_value / adjust_value
# 发送反馈消息(不保存到数据库)
await send_api.text_to_stream(
f"已设置当前聊天的talk_frequency调整值为: {value}\n当前talk_value: {final_value:.2f}\n发言频率调整: {adjust_value:.2f}\n基础值: {base_value:.2f}",
chat_id,
storage_message=False,
)
return True, None, False
except ValueError: except ValueError:
error_msg = "数值格式错误,请输入有效的数字" await self.ctx.send.text("数值格式错误,请输入有效的数字", stream_id)
await self.send_text(error_msg, storage_message=False) return False, "数值格式错误", False
return False, error_msg, False
except Exception as e: if not stream_id:
error_msg = f"设置talk_frequency失败: {str(e)}" return False, "无法获取聊天流信息", False
await self.send_text(error_msg, storage_message=False)
return False, error_msg, False # 设置 talk_frequency
await self.ctx.frequency.set_adjust(stream_id, value)
# 获取当前状态
current = await self.ctx.frequency.get_current_talk_value(stream_id)
current_val = current if isinstance(current, (int, float)) else 0
adjust = await self.ctx.frequency.get_adjust(stream_id)
adjust_val = adjust if isinstance(adjust, (int, float)) else 1
base_val = current_val / adjust_val if adjust_val else 0
msg = (
f"已设置当前聊天的talk_frequency调整值为: {value}\n"
f"当前talk_value: {current_val:.2f}\n"
f"发言频率调整: {adjust_val:.2f}\n"
f"基础值: {base_val:.2f}"
)
await self.ctx.send.text(msg, stream_id)
return True, None, False
@Command(
"show_frequency",
description="显示当前聊天的频率控制状态:/chat show 或 /chat s",
pattern=r"^/chat\s+(?:show|s)$",
)
async def handle_show_frequency(self, stream_id: str = "", **kwargs):
"""显示当前频率控制状态"""
if not stream_id:
return False, "无法获取聊天流信息", False
current = await self.ctx.frequency.get_current_talk_value(stream_id)
current_val = current if isinstance(current, (int, float)) else 0
adjust = await self.ctx.frequency.get_adjust(stream_id)
adjust_val = adjust if isinstance(adjust, (int, float)) else 1
base_val = current_val / adjust_val if adjust_val else 0
status_msg = (
"当前聊天频率控制状态\n"
"Talk Value (发言频率):\n\n"
f" • 基础值: {base_val:.2f}\n"
f" • 发言频率调整: {adjust_val:.2f}\n"
f" • 当前值: {current_val:.2f}\n\n"
"使用命令:\n"
" • /chat talk_frequency <数字> 或 /chat t <数字> - 设置发言频率调整\n"
" • /chat show 或 /chat s - 显示当前状态"
)
await self.ctx.send.text(status_msg, stream_id)
return True, None, False
class ShowFrequencyCommand(BaseCommand): def create_plugin():
"""显示当前聊天的频率控制状态""" return BetterFrequencyPlugin()
command_name = "show_frequency"
command_description = "显示当前聊天的频率控制状态:/chat show 或 /chat s"
command_pattern = r"^/chat\s+(?:show|s)$"
async def execute(self) -> Tuple[bool, Optional[str], bool]:
try:
# 获取聊天流ID
if not self.message.chat_stream or not hasattr(self.message.chat_stream, "stream_id"):
return False, "无法获取聊天流信息", False
chat_id = self.message.chat_stream.stream_id
# 获取当前频率控制状态
current_talk_frequency = frequency_api.get_current_talk_value(chat_id)
talk_frequency_adjust = frequency_api.get_talk_frequency_adjust(chat_id)
base_value = current_talk_frequency / talk_frequency_adjust
# 构建显示消息
status_msg = f"""当前聊天频率控制状态
Talk Value (发言频率):
• 基础值: {base_value:.2f}
• 发言频率调整: {talk_frequency_adjust:.2f}
• 当前值: {current_talk_frequency:.2f}
使用命令:
• /chat talk_frequency <数字> 或 /chat t <数字> - 设置发言频率调整
• /chat show 或 /chat s - 显示当前状态"""
# 发送状态消息(不保存到数据库)
await send_api.text_to_stream(status_msg, chat_id, storage_message=False)
return True, None, False
except Exception as e:
error_msg = f"获取频率控制状态失败: {str(e)}"
# 使用内置的send_text方法发送错误消息
await self.send_text(error_msg, storage_message=False)
return False, error_msg, False
# ===== 插件注册 =====
@register_plugin
class BetterFrequencyPlugin(BasePlugin):
"""BetterFrequency插件 - 控制聊天频率的插件"""
# 插件基本信息
plugin_name: str = "better_frequency_plugin"
enable_plugin: bool = True
dependencies: List[str] = []
python_dependencies: List[str] = []
config_file_name: str = "config.toml"
# 配置节描述
config_section_descriptions = {"plugin": "插件基本信息", "frequency": "频率控制配置", "features": "功能开关配置"}
# 配置Schema定义
config_schema: dict = {
"plugin": {
"name": ConfigField(type=str, default="better_frequency_plugin", description="插件名称"),
"version": ConfigField(type=str, default="1.0.2", description="插件版本"),
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
},
"frequency": {
"default_talk_adjust": ConfigField(type=float, default=1.0, description="默认talk_frequency调整值"),
"max_adjust_value": ConfigField(type=float, default=1.0, description="最大调整值"),
"min_adjust_value": ConfigField(type=float, default=0.0, description="最小调整值"),
},
}
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
components = []
# 根据配置决定是否注册命令组件
if self.config.get("features", {}).get("enable_commands", True):
components.extend(
[
(SetTalkFrequencyCommand.get_command_info(), SetTalkFrequencyCommand),
(ShowFrequencyCommand.get_command_info(), ShowFrequencyCommand),
]
)
return components

View File

@@ -1,7 +1,7 @@
{ {
"manifest_version": 1, "manifest_version": 1,
"name": "BetterEmoji", "name": "BetterEmoji",
"version": "1.0.0", "version": "2.0.0",
"description": "更好的表情包管理插件", "description": "更好的表情包管理插件",
"author": { "author": {
"name": "SengokuCola", "name": "SengokuCola",
@@ -9,7 +9,7 @@
}, },
"license": "GPL-v3.0-or-later", "license": "GPL-v3.0-or-later",
"host_application": { "host_application": {
"min_version": "0.10.4" "min_version": "1.0.0"
}, },
"homepage_url": "https://github.com/SengokuCola/BetterEmoji", "homepage_url": "https://github.com/SengokuCola/BetterEmoji",
"repository_url": "https://github.com/SengokuCola/BetterEmoji", "repository_url": "https://github.com/SengokuCola/BetterEmoji",
@@ -19,46 +19,49 @@
"plugin" "plugin"
], ],
"categories": [ "categories": [
"Examples", "Emoji",
"Tutorial" "Management"
], ],
"default_locale": "zh-CN", "default_locale": "zh-CN",
"locales_path": "_locales", "locales_path": "_locales",
"plugin_info": { "plugin_info": {
"is_built_in": false, "is_built_in": false,
"plugin_type": "emoji_manage", "plugin_type": "emoji_manage",
"capabilities": [
"emoji.get_random",
"emoji.get_count",
"emoji.get_info",
"emoji.get_all",
"emoji.register_emoji",
"emoji.delete_emoji",
"send.text",
"send.forward"
],
"components": [ "components": [
{ {
"type": "action", "type": "command",
"name": "hello_greeting", "name": "add_emoji",
"description": "向用户发送问候消息" "description": "添加表情包",
}, "pattern": "/emoji add"
{
"type": "action",
"name": "bye_greeting",
"description": "向用户发送告别消息",
"activation_modes": [
"keyword"
],
"keywords": [
"再见",
"bye",
"88",
"拜拜"
]
}, },
{ {
"type": "command", "type": "command",
"name": "time", "name": "emoji_list",
"description": "查询当前时间", "description": "列表表情包",
"pattern": "/time" "pattern": "/emoji list"
},
{
"type": "command",
"name": "delete_emoji",
"description": "删除表情包",
"pattern": "/emoji delete"
},
{
"type": "command",
"name": "random_emojis",
"description": "发送多张随机表情包",
"pattern": "/random_emojis"
} }
],
"features": [
"问候和告别功能",
"时间查询命令",
"配置文件示例",
"新手教程代码"
] ]
}, },
"id": "SengokuCola.BetterEmoji" "id": "SengokuCola.BetterEmoji"

View File

@@ -1,399 +1,216 @@
from typing import List, Tuple, Type """表情包管理插件 — 新 SDK 版本
from src.plugin_system import (
BasePlugin,
register_plugin,
BaseCommand,
ComponentInfo,
ConfigField,
ReplyContentType,
emoji_api,
)
from maim_message import Seg
from src.common.logger import get_logger
logger = get_logger("emoji_manage_plugin") 通过 /emoji 命令管理表情包的添加、列表和删除。
"""
import base64
import datetime
import hashlib
import re
from maibot_sdk import MaiBotPlugin, Command
class AddEmojiCommand(BaseCommand): class EmojiManagePlugin(MaiBotPlugin):
command_name = "add_emoji" """表情包管理插件"""
command_description = "添加表情包"
command_pattern = r".*/emoji add.*"
async def execute(self) -> Tuple[bool, str, bool]: # ===== 工具方法 =====
# 查找消息中的表情包
# logger.info(f"查找消息中的表情包: {self.message.message_segment}")
emoji_base64_list = self.find_and_return_emoji_in_message(self.message.message_segment) @staticmethod
def _extract_emoji_base64(segments) -> list[str]:
"""从消息 segments 中提取 emoji/image 的 base64 数据。
segments 可以是 dict 列表或 Seg 对象列表(兼容两种格式)。
"""
results: list[str] = []
if not segments:
return results
if isinstance(segments, dict):
seg_type = segments.get("type", "")
if seg_type in ("emoji", "image"):
data = segments.get("data", "")
if data:
results.append(data)
elif seg_type == "seglist":
for child in segments.get("data", []):
results.extend(EmojiManagePlugin._extract_emoji_base64(child))
return results
# 如果有 .type 属性Seg 对象)
if hasattr(segments, "type"):
seg_type = getattr(segments, "type", "")
if seg_type in ("emoji", "image"):
results.append(getattr(segments, "data", ""))
elif seg_type == "seglist":
for child in getattr(segments, "data", []):
results.extend(EmojiManagePlugin._extract_emoji_base64(child))
return results
# 列表
for seg in segments:
results.extend(EmojiManagePlugin._extract_emoji_base64(seg))
return results
# ===== Command 组件 =====
@Command("add_emoji", description="添加表情包", pattern=r".*/emoji add.*")
async def handle_add_emoji(self, stream_id: str = "", message_segments=None, **kwargs):
"""添加表情包"""
emoji_base64_list = self._extract_emoji_base64(message_segments)
if not emoji_base64_list: if not emoji_base64_list:
await self.ctx.send.text("未在消息中找到表情包或图片", stream_id)
return False, "未在消息中找到表情包或图片", False return False, "未在消息中找到表情包或图片", False
# 注册找到的表情包
success_count = 0 success_count = 0
fail_count = 0 fail_count = 0
results = [] results = []
for i, emoji_base64 in enumerate(emoji_base64_list): for i, emoji_b64 in enumerate(emoji_base64_list):
try: result = await self.ctx.emoji.register_emoji(emoji_b64)
# 使用emoji_api注册表情包让API自动生成唯一文件名 if isinstance(result, dict) and result.get("success"):
result = await emoji_api.register_emoji(emoji_base64) success_count += 1
desc = result.get("description", "未知描述")
if result["success"]: emotions = result.get("emotions", [])
success_count += 1 replaced = result.get("replaced", False)
description = result.get("description", "未知描述") msg = f"表情包 {i + 1} 注册成功{'(替换旧表情包)' if replaced else '(新增表情包)'}"
emotions = result.get("emotions", []) if desc:
replaced = result.get("replaced", False) msg += f"\n描述: {desc}"
if emotions:
result_msg = f"表情包 {i + 1} 注册成功{'(替换旧表情包)' if replaced else '(新增表情包)'}" msg += f"\n情感标签: {', '.join(emotions)}"
if description: results.append(msg)
result_msg += f"\n描述: {description}"
if emotions:
result_msg += f"\n情感标签: {', '.join(emotions)}"
results.append(result_msg)
else:
fail_count += 1
error_msg = result.get("message", "注册失败")
results.append(f"表情包 {i + 1} 注册失败: {error_msg}")
except Exception as e:
fail_count += 1
results.append(f"表情包 {i + 1} 注册时发生错误: {str(e)}")
# 构建返回消息
total_count = success_count + fail_count
summary_msg = f"表情包注册完成: 成功 {success_count} 个,失败 {fail_count} 个,共处理 {total_count}"
# 如果有结果详情,添加到返回消息中
details_msg = ""
if results:
details_msg = "\n" + "\n".join(results)
final_msg = summary_msg + details_msg
else:
final_msg = summary_msg
# 使用表达器重写回复
try:
from src.plugin_system.apis import generator_api
# 构建重写数据
rewrite_data = {
"raw_reply": summary_msg,
"reason": f"注册了表情包:{details_msg}\n",
}
# 调用表达器重写
result_status, data = await generator_api.rewrite_reply(
chat_stream=self.message.chat_stream,
reply_data=rewrite_data,
)
if result_status:
# 发送重写后的回复
for reply_seg in data.reply_set.reply_data:
send_data = reply_seg.content
await self.send_text(send_data)
return success_count > 0, final_msg, success_count > 0
else: else:
# 如果重写失败,发送原始消息 fail_count += 1
await self.send_text(final_msg) err = result.get("message", "注册失败") if isinstance(result, dict) else "注册失败"
return success_count > 0, final_msg, success_count > 0 results.append(f"表情包 {i + 1} 注册失败: {err}")
except Exception as e: total = success_count + fail_count
# 如果表达器调用失败,发送原始消息 summary = f"表情包注册完成: 成功 {success_count} 个,失败 {fail_count} 个,共处理 {total}"
logger.error(f"[add_emoji] 表达器重写失败: {e}") if results:
await self.send_text(final_msg) summary += "\n" + "\n".join(results)
return success_count > 0, final_msg, success_count > 0
def find_and_return_emoji_in_message(self, message_segments) -> List[str]: await self.ctx.send.text(summary, stream_id)
emoji_base64_list = [] return success_count > 0, summary, success_count > 0
# 处理单个Seg对象的情况 @Command("emoji_list", description="列表表情包", pattern=r"^/emoji list(\s+\d+)?$")
if isinstance(message_segments, Seg): async def handle_list_emoji(self, stream_id: str = "", raw_message: str = "", **kwargs):
if message_segments.type == "emoji": """列出表情包"""
emoji_base64_list.append(message_segments.data) max_count = 10
elif message_segments.type == "image": match = re.match(r"^/emoji list(?:\s+(\d+))?$", raw_message)
# 假设图片数据是base64编码的
emoji_base64_list.append(message_segments.data)
elif message_segments.type == "seglist":
# 递归处理嵌套的Seg列表
emoji_base64_list.extend(self.find_and_return_emoji_in_message(message_segments.data))
return emoji_base64_list
# 处理Seg列表的情况
for seg in message_segments:
if seg.type == "emoji":
emoji_base64_list.append(seg.data)
elif seg.type == "image":
# 假设图片数据是base64编码的
emoji_base64_list.append(seg.data)
elif seg.type == "seglist":
# 递归处理嵌套的Seg列表
emoji_base64_list.extend(self.find_and_return_emoji_in_message(seg.data))
return emoji_base64_list
class ListEmojiCommand(BaseCommand):
"""列表表情包Command - 响应/emoji list命令"""
command_name = "emoji_list"
command_description = "列表表情包"
# === 命令设置(必须填写)===
command_pattern = r"^/emoji list(\s+\d+)?$" # 匹配 "/emoji list" 或 "/emoji list 数量"
async def execute(self) -> Tuple[bool, str, bool]:
"""执行列表表情包"""
from src.plugin_system.apis import emoji_api
import datetime
# 解析命令参数
import re
match = re.match(r"^/emoji list(?:\s+(\d+))?$", self.message.raw_message)
max_count = 10 # 默认显示10个
if match and match.group(1): if match and match.group(1):
max_count = min(int(match.group(1)), 50) # 最多显示50个 max_count = min(int(match.group(1)), 50)
# 获取当前时间 now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
time_format: str = self.get_config("time.format", "%Y-%m-%d %H:%M:%S") # type: ignore
now = datetime.datetime.now()
time_str = now.strftime(time_format)
# 获取表情包信息 count_result = await self.ctx.emoji.get_count()
emoji_count = emoji_api.get_count() emoji_count = count_result if isinstance(count_result, int) else 0
emoji_info = emoji_api.get_info()
# 构建返回消息 info_result = await self.ctx.emoji.get_info()
message_lines = [ max_emoji = info_result.get("max_count", 0) if isinstance(info_result, dict) else 0
f"📊 表情包统计信息 ({time_str})", available = info_result.get("available_emojis", 0) if isinstance(info_result, dict) else 0
f"• 总数: {emoji_count} / {emoji_info['max_count']}",
f"• 可用: {emoji_info['available_emojis']}", lines = [
f"📊 表情包统计信息 ({now})",
f"• 总数: {emoji_count} / {max_emoji}",
f"• 可用: {available}",
] ]
if emoji_count == 0: if emoji_count == 0:
message_lines.append("\n❌ 暂无表情包") lines.append("\n❌ 暂无表情包")
final_message = "\n".join(message_lines) await self.ctx.send.text("\n".join(lines), stream_id)
await self.send_text(final_message) return True, "\n".join(lines), True
return True, final_message, True
# 获取所有表情包 all_result = await self.ctx.emoji.get_all()
all_emojis = await emoji_api.get_all() all_emojis = all_result if isinstance(all_result, list) else []
if not all_emojis: if not all_emojis:
message_lines.append("\n❌ 无法获取表情包列表") lines.append("\n❌ 无法获取表情包列表")
final_message = "\n".join(message_lines) await self.ctx.send.text("\n".join(lines), stream_id)
await self.send_text(final_message) return False, "\n".join(lines), True
return False, final_message, True
# 显示前N个表情包 display = all_emojis[:max_count]
display_emojis = all_emojis[:max_count] lines.append(f"\n📋 显示前 {len(display)} 个表情包:")
message_lines.append(f"\n📋 显示前 {len(display_emojis)} 个表情包:") for i, emoji in enumerate(display, 1):
if isinstance(emoji, (list, tuple)) and len(emoji) >= 3:
_, desc, emotion = emoji[0], emoji[1], emoji[2]
elif isinstance(emoji, dict):
desc = emoji.get("description", "")
emotion = emoji.get("emotion", "")
else:
desc, emotion = str(emoji), ""
short_desc = desc[:50] + "..." if len(desc) > 50 else desc
lines.append(f"{i}. {short_desc} [{emotion}]")
for i, (_, description, emotion) in enumerate(display_emojis, 1):
# 截断过长的描述
short_desc = description[:50] + "..." if len(description) > 50 else description
message_lines.append(f"{i}. {short_desc} [{emotion}]")
# 如果还有更多表情包,显示总数
if len(all_emojis) > max_count: if len(all_emojis) > max_count:
message_lines.append(f"\n💡 还有 {len(all_emojis) - max_count} 个表情包未显示") lines.append(f"\n💡 还有 {len(all_emojis) - max_count} 个表情包未显示")
final_message = "\n".join(message_lines) final = "\n".join(lines)
await self.ctx.send.text(final, stream_id)
# 直接发送文本消息 return True, final, True
await self.send_text(final_message)
return True, final_message, True
class DeleteEmojiCommand(BaseCommand):
command_name = "delete_emoji"
command_description = "删除表情包"
command_pattern = r".*/emoji delete.*"
async def execute(self) -> Tuple[bool, str, bool]:
# 查找消息中的表情包图片
logger.info(f"查找消息中的表情包用于删除: {self.message.message_segment}")
emoji_base64_list = self.find_and_return_emoji_in_message(self.message.message_segment)
@Command("delete_emoji", description="删除表情包", pattern=r".*/emoji delete.*")
async def handle_delete_emoji(self, stream_id: str = "", message_segments=None, **kwargs):
"""删除表情包"""
emoji_base64_list = self._extract_emoji_base64(message_segments)
if not emoji_base64_list: if not emoji_base64_list:
return False, "未在消息中找到表情包或图片", False await self.ctx.send.text("未在消息中找到表情包或图片", stream_id)
return False, "未找到表情包", False
# 删除找到的表情包
success_count = 0 success_count = 0
fail_count = 0 fail_count = 0
results = [] results = []
for i, emoji_base64 in enumerate(emoji_base64_list): for i, emoji_b64 in enumerate(emoji_base64_list):
try: # 计算哈希
# 计算图片的哈希值来查找对应的表情包 if isinstance(emoji_b64, str):
import base64 clean = emoji_b64.encode("ascii", errors="ignore").decode("ascii")
import hashlib
# 确保base64字符串只包含ASCII字符
if isinstance(emoji_base64, str):
emoji_base64_clean = emoji_base64.encode("ascii", errors="ignore").decode("ascii")
else:
emoji_base64_clean = str(emoji_base64)
# 计算哈希值
image_bytes = base64.b64decode(emoji_base64_clean)
emoji_hash = hashlib.md5(image_bytes).hexdigest()
# 使用emoji_api删除表情包
result = await emoji_api.delete_emoji(emoji_hash)
if result["success"]:
success_count += 1
description = result.get("description", "未知描述")
count_before = result.get("count_before", 0)
count_after = result.get("count_after", 0)
emotions = result.get("emotions", [])
result_msg = f"表情包 {i + 1} 删除成功"
if description:
result_msg += f"\n描述: {description}"
if emotions:
result_msg += f"\n情感标签: {', '.join(emotions)}"
result_msg += f"\n表情包数量: {count_before}{count_after}"
results.append(result_msg)
else:
fail_count += 1
error_msg = result.get("message", "删除失败")
results.append(f"表情包 {i + 1} 删除失败: {error_msg}")
except Exception as e:
fail_count += 1
results.append(f"表情包 {i + 1} 删除时发生错误: {str(e)}")
# 构建返回消息
total_count = success_count + fail_count
summary_msg = f"表情包删除完成: 成功 {success_count} 个,失败 {fail_count} 个,共处理 {total_count}"
# 如果有结果详情,添加到返回消息中
details_msg = ""
if results:
details_msg = "\n" + "\n".join(results)
final_msg = summary_msg + details_msg
else:
final_msg = summary_msg
# 使用表达器重写回复
try:
from src.plugin_system.apis import generator_api
# 构建重写数据
rewrite_data = {
"raw_reply": summary_msg,
"reason": f"删除了表情包:{details_msg}\n",
}
# 调用表达器重写
result_status, data = await generator_api.rewrite_reply(
chat_stream=self.message.chat_stream,
reply_data=rewrite_data,
)
if result_status:
# 发送重写后的回复
for reply_seg in data.reply_set.reply_data:
send_data = reply_seg.content
await self.send_text(send_data)
return success_count > 0, final_msg, success_count > 0
else: else:
# 如果重写失败,发送原始消息 clean = str(emoji_b64)
await self.send_text(final_msg) image_bytes = base64.b64decode(clean)
return success_count > 0, final_msg, success_count > 0 emoji_hash = hashlib.md5(image_bytes).hexdigest() # noqa: S324
except Exception as e: result = await self.ctx.emoji.delete_emoji(emoji_hash)
# 如果表达器调用失败,发送原始消息 if isinstance(result, dict) and result.get("success"):
logger.error(f"[delete_emoji] 表达器重写失败: {e}") success_count += 1
await self.send_text(final_msg) desc = result.get("description", "未知描述")
return success_count > 0, final_msg, success_count > 0 emotions = result.get("emotions", [])
before = result.get("count_before", 0)
after = result.get("count_after", 0)
msg = f"表情包 {i + 1} 删除成功"
if desc:
msg += f"\n描述: {desc}"
if emotions:
msg += f"\n情感标签: {', '.join(emotions)}"
msg += f"\n表情包数量: {before}{after}"
results.append(msg)
else:
fail_count += 1
err = result.get("message", "删除失败") if isinstance(result, dict) else "删除失败"
results.append(f"表情包 {i + 1} 删除失败: {err}")
def find_and_return_emoji_in_message(self, message_segments) -> List[str]: total = success_count + fail_count
emoji_base64_list = [] summary = f"表情包删除完成: 成功 {success_count} 个,失败 {fail_count} 个,共处理 {total}"
if results:
summary += "\n" + "\n".join(results)
# 处理单个Seg对象的情况 await self.ctx.send.text(summary, stream_id)
if isinstance(message_segments, Seg): return success_count > 0, summary, success_count > 0
if message_segments.type == "emoji":
emoji_base64_list.append(message_segments.data)
elif message_segments.type == "image":
# 假设图片数据是base64编码的
emoji_base64_list.append(message_segments.data)
elif message_segments.type == "seglist":
# 递归处理嵌套的Seg列表
emoji_base64_list.extend(self.find_and_return_emoji_in_message(message_segments.data))
return emoji_base64_list
# 处理Seg列表的情况 @Command("random_emojis", description="发送多张随机表情包", pattern=r"^/random_emojis$")
for seg in message_segments: async def handle_random_emojis(self, stream_id: str = "", **kwargs):
if seg.type == "emoji": """发送多张随机表情包"""
emoji_base64_list.append(seg.data) result = await self.ctx.emoji.get_random(5)
elif seg.type == "image": if not result or not result.get("success"):
# 假设图片数据是base64编码的 return False, "未找到表情包", False
emoji_base64_list.append(seg.data) emojis = result.get("emojis", [])
elif seg.type == "seglist":
# 递归处理嵌套的Seg列表
emoji_base64_list.extend(self.find_and_return_emoji_in_message(seg.data))
return emoji_base64_list
class RandomEmojis(BaseCommand):
command_name = "random_emojis"
command_description = "发送多张随机表情包"
command_pattern = r"^/random_emojis$"
async def execute(self):
emojis = await emoji_api.get_random(5)
if not emojis: if not emojis:
return False, "未找到表情包", False return False, "未找到表情包", False
emoji_base64_list = [] messages = [
for emoji in emojis: {"user_id": "0", "nickname": "神秘用户", "segments": [{"type": "image", "content": e.get("base64", "")}]}
emoji_base64_list.append(emoji[0]) for e in emojis
return await self.forward_images(emoji_base64_list)
async def forward_images(self, images: List[str]):
"""
把多张图片用合并转发的方式发给用户
"""
success = await self.send_forward([("0", "神秘用户", [(ReplyContentType.IMAGE, img)]) for img in images])
return (True, "已发送随机表情包", True) if success else (False, "发送随机表情包失败", False)
# ===== 插件注册 =====
@register_plugin
class EmojiManagePlugin(BasePlugin):
"""表情包管理插件 - 管理表情包"""
# 插件基本信息
plugin_name: str = "emoji_manage_plugin" # 内部标识符
enable_plugin: bool = False
dependencies: List[str] = [] # 插件依赖列表
python_dependencies: List[str] = [] # Python包依赖列表
config_file_name: str = "config.toml" # 配置文件名
# 配置节描述
config_section_descriptions = {"plugin": "插件基本信息", "emoji": "表情包功能配置"}
# 配置Schema定义
config_schema: dict = {
"plugin": {
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
"config_version": ConfigField(type=str, default="1.0.1", description="配置文件版本"),
},
}
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
return [
(RandomEmojis.get_command_info(), RandomEmojis),
(AddEmojiCommand.get_command_info(), AddEmojiCommand),
(ListEmojiCommand.get_command_info(), ListEmojiCommand),
(DeleteEmojiCommand.get_command_info(), DeleteEmojiCommand),
] ]
await self.ctx.send.forward(messages, stream_id)
return True, "已发送随机表情包", True
def create_plugin():
return EmojiManagePlugin()

View File

@@ -1,7 +1,7 @@
{ {
"manifest_version": 1, "manifest_version": 1,
"name": "Hello World 示例插件 (Hello World Plugin)", "name": "Hello World 示例插件 (Hello World Plugin)",
"version": "1.0.0", "version": "2.0.0",
"description": "我的第一个MaiCore插件包含问候功能和时间查询等基础示例", "description": "我的第一个MaiCore插件包含问候功能和时间查询等基础示例",
"author": { "author": {
"name": "MaiBot开发团队", "name": "MaiBot开发团队",
@@ -9,7 +9,7 @@
}, },
"license": "GPL-v3.0-or-later", "license": "GPL-v3.0-or-later",
"host_application": { "host_application": {
"min_version": "0.8.0" "min_version": "1.0.0"
}, },
"homepage_url": "https://github.com/MaiM-with-u/maibot", "homepage_url": "https://github.com/MaiM-with-u/maibot",
"repository_url": "https://github.com/MaiM-with-u/maibot", "repository_url": "https://github.com/MaiM-with-u/maibot",
@@ -29,7 +29,19 @@
"plugin_info": { "plugin_info": {
"is_built_in": false, "is_built_in": false,
"plugin_type": "example", "plugin_type": "example",
"capabilities": [
"send.text",
"send.forward",
"send.hybrid",
"emoji.get_random",
"config.get"
],
"components": [ "components": [
{
"type": "tool",
"name": "compare_numbers",
"description": "比较两个数的大小"
},
{ {
"type": "action", "type": "action",
"name": "hello_greeting", "name": "hello_greeting",
@@ -39,28 +51,37 @@
"type": "action", "type": "action",
"name": "bye_greeting", "name": "bye_greeting",
"description": "向用户发送告别消息", "description": "向用户发送告别消息",
"activation_modes": [ "activation_modes": ["keyword"],
"keyword" "keywords": ["再见", "bye", "88", "拜拜"]
],
"keywords": [
"再见",
"bye",
"88",
"拜拜"
]
}, },
{ {
"type": "command", "type": "command",
"name": "time", "name": "time",
"description": "查询当前时间", "description": "查询当前时间",
"pattern": "/time" "pattern": "/time"
},
{
"type": "command",
"name": "random_emojis",
"description": "发送多张随机表情包",
"pattern": "/random_emojis"
},
{
"type": "command",
"name": "test",
"description": "测试命令",
"pattern": "/test"
},
{
"type": "event_handler",
"name": "print_message_handler",
"description": "打印接收到的消息"
},
{
"type": "event_handler",
"name": "forward_messages_handler",
"description": "把接收到的消息转发到指定聊天ID"
} }
],
"features": [
"问候和告别功能",
"时间查询命令",
"配置文件示例",
"新手教程代码"
] ]
}, },
"id": "MaiBot开发团队.maibot" "id": "MaiBot开发团队.maibot"

View File

@@ -1,50 +1,30 @@
"""Hello World 示例插件 — 新 SDK 版本
你的第一个 MaiCore 插件,包含问候功能、时间查询等基础示例。
"""
import datetime
import random import random
from typing import List, Tuple, Type, Any, Optional
from src.plugin_system import (
BasePlugin,
register_plugin,
BaseAction,
BaseCommand,
BaseTool,
ComponentInfo,
ActionActivationType,
ConfigField,
BaseEventHandler,
EventType,
MaiMessages,
ToolParamType,
ReplyContentType,
emoji_api,
)
from src.config.config import global_config
from src.common.logger import get_logger
logger = get_logger("hello_world_plugin") from maibot_sdk import MaiBotPlugin, Action, Command, Tool, EventHandler
from maibot_sdk.types import ActivationType, EventType, ToolParameterInfo, ToolParamType
class CompareNumbersTool(BaseTool): class HelloWorldPlugin(MaiBotPlugin):
"""比较两个数大小的工具""" """Hello World 示例插件"""
name = "compare_numbers" # ===== Tool 组件 =====
description = "使用工具 比较两个数的大小,返回较大的数"
parameters = [
("num1", ToolParamType.FLOAT, "第一个数字", True, None),
("num2", ToolParamType.FLOAT, "第二个数字", True, None),
]
available_for_llm = True
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
"""执行比较两个数的大小
Args:
function_args: 工具参数
Returns:
dict: 工具执行结果
"""
num1: int | float = function_args.get("num1") # type: ignore
num2: int | float = function_args.get("num2") # type: ignore
@Tool(
"compare_numbers",
description="使用工具比较两个数的大小,返回较大的数",
parameters=[
ToolParameterInfo(name="num1", param_type=ToolParamType.FLOAT, description="第一个数字", required=True),
ToolParameterInfo(name="num2", param_type=ToolParamType.FLOAT, description="第二个数字", required=True),
],
)
async def handle_compare_numbers(self, num1: float = 0, num2: float = 0, **kwargs):
"""比较两个数的大小"""
try: try:
if num1 > num2: if num1 > num2:
result = f"{num1} 大于 {num2}" result = f"{num1} 大于 {num2}"
@@ -52,270 +32,121 @@ class CompareNumbersTool(BaseTool):
result = f"{num1} 小于 {num2}" result = f"{num1} 小于 {num2}"
else: else:
result = f"{num1} 等于 {num2}" result = f"{num1} 等于 {num2}"
return {"name": "compare_numbers", "content": result}
return {"name": self.name, "content": result}
except Exception as e: except Exception as e:
return {"name": self.name, "content": f"比较数字失败,炸了: {str(e)}"} return {"name": "compare_numbers", "content": f"比较数字失败,炸了: {e}"}
# ===== Action 组件 =====
# ===== Action组件 ===== @Action(
class HelloAction(BaseAction): "hello_greeting",
"""问候Action - 简单的问候动作""" description="向用户发送问候消息",
activation_type=ActivationType.ALWAYS,
# === 基本信息(必须填写)=== action_parameters={"greeting_message": "要发送的问候消息"},
action_name = "hello_greeting" action_require=["需要发送友好问候时使用", "当有人向你问好时使用", "当你遇见没有见过的人时使用"],
action_description = "向用户发送问候消息" associated_types=["text"],
activation_type = ActionActivationType.ALWAYS # 始终激活 )
async def handle_hello(self, stream_id: str = "", greeting_message: str = "", **kwargs):
# === 功能描述(必须填写)=== """问候动作"""
action_parameters = {"greeting_message": "要发送的问候消息"} config_result = await self.ctx.config.get("greeting.message")
action_require = ["需要发送友好问候时使用", "当有人向你问好时使用", "当你遇见没有见过的人时使用"] base_message = config_result if isinstance(config_result, str) else "嗨!很开心见到你!😊"
associated_types = ["text"]
async def execute(self) -> Tuple[bool, str]:
"""执行问候动作 - 这是核心功能"""
# 发送问候消息
greeting_message = self.action_data.get("greeting_message", "")
base_message = self.get_config("greeting.message", "嗨!很开心见到你!😊")
message = base_message + greeting_message message = base_message + greeting_message
await self.send_text(message) await self.ctx.send.text(message, stream_id)
return True, "发送了问候消息" return True, "发送了问候消息"
@Action(
class ByeAction(BaseAction): "bye_greeting",
"""告别Action - 只在用户说再见时激活""" description="向用户发送告别消息",
activation_type=ActivationType.KEYWORD,
action_name = "bye_greeting" activation_keywords=["再见", "bye", "88", "拜拜"],
action_description = "向用户发送告别消息" action_parameters={"bye_message": "发送告别消息"},
action_require=["用户要告别时使用", "当有人要离开时使用", "当有人和你说再见时使用"],
# 使用关键词激活 associated_types=["text"],
activation_type = ActionActivationType.KEYWORD )
async def handle_bye(self, stream_id: str = "", bye_message: str = "", **kwargs):
# 关键词设置 """告别动作"""
activation_keywords = ["再见", "bye", "88", "拜拜"]
keyword_case_sensitive = False
action_parameters = {"bye_message": "要发送的告别消息"}
action_require = [
"用户要告别时使用",
"当有人要离开时使用",
"当有人和你说再见时使用",
]
associated_types = ["text"]
async def execute(self) -> Tuple[bool, str]:
bye_message = self.action_data.get("bye_message", "")
message = f"再见!期待下次聊天!👋{bye_message}" message = f"再见!期待下次聊天!👋{bye_message}"
await self.send_text(message) await self.ctx.send.text(message, stream_id)
return True, "发送了告别消息" return True, "发送了告别消息"
# ===== Command 组件 =====
class TimeCommand(BaseCommand): @Command("time", description="查询当前时间", pattern=r"^/time$")
"""时间查询Command - 响应/time命令""" async def handle_time(self, stream_id: str = "", **kwargs):
"""时间查询命令"""
command_name = "time" config_result = await self.ctx.config.get("time.format")
command_description = "查询当前时间" time_format = config_result if isinstance(config_result, str) else "%Y-%m-%d %H:%M:%S"
# === 命令设置(必须填写)===
command_pattern = r"^/time$" # 精确匹配 "/time" 命令
async def execute(self) -> Tuple[bool, str, bool]:
"""执行时间查询"""
import datetime
# 获取当前时间
time_format: str = self.get_config("time.format", "%Y-%m-%d %H:%M:%S") # type: ignore
now = datetime.datetime.now() now = datetime.datetime.now()
time_str = now.strftime(time_format) time_str = now.strftime(time_format)
await self.ctx.send.text(f"⏰ 当前时间:{time_str}", stream_id)
# 发送时间信息
message = f"⏰ 当前时间:{time_str}"
await self.send_text(message)
return True, f"显示了当前时间: {time_str}", True return True, f"显示了当前时间: {time_str}", True
@Command("random_emojis", description="发送多张随机表情包", pattern=r"^/random_emojis$")
async def handle_random_emojis(self, stream_id: str = "", **kwargs):
"""发送多张随机表情包"""
result = await self.ctx.emoji.get_random(5)
if not result or not result.get("success"):
return False, "未找到表情包", False
emojis = result.get("emojis", [])
if not emojis:
return False, "未找到表情包", False
# 用转发消息发送多张图片
messages = [
{"user_id": "0", "nickname": "神秘用户", "segments": [{"type": "image", "content": e.get("base64", "")}]}
for e in emojis
]
await self.ctx.send.forward(messages, stream_id)
return True, "已发送随机表情包", True
class PrintMessage(BaseEventHandler): @Command("test", description="测试命令", pattern=r"^/test$")
"""打印消息事件处理器 - 处理打印消息事件""" async def handle_test(self, stream_id: str = "", **kwargs):
"""测试命令 — 发送简单测试消息"""
await self.ctx.send.text("测试正常Bot 功能运行中 ✅", stream_id)
return True, "测试完成", True
event_type = EventType.ON_MESSAGE # ===== EventHandler 组件 =====
handler_name = "print_message_handler"
handler_description = "打印接收到的消息"
async def execute(self, message: MaiMessages | None) -> Tuple[bool, bool, str | None, None, None]: @EventHandler("print_message_handler", description="打印接收到的消息", event_type=EventType.ON_MESSAGE)
"""执行打印消息事件处理""" async def handle_print_message(self, message=None, **kwargs):
# 打印接收到的消息 """打印消息事件"""
if self.get_config("print_message.enabled", False): config_result = await self.ctx.config.get("print_message.enabled")
print(f"接收到消息: {message.raw_message if message else '无效消息'}") enabled = config_result if isinstance(config_result, bool) else False
if enabled and message:
raw = message.get("raw_message", "") if isinstance(message, dict) else str(message)
print(f"接收到消息: {raw}")
return True, True, "消息已打印", None, None return True, True, "消息已打印", None, None
@EventHandler("forward_messages_handler", description="把接收到的消息转发到指定聊天ID", event_type=EventType.ON_MESSAGE)
class ForwardMessages(BaseEventHandler): async def handle_forward_messages(self, message=None, stream_id: str = "", **kwargs):
""" """收集消息并定期转发"""
把接收到的消息转发到指定聊天ID
此组件是HYBRID消息和FORWARD消息的使用示例。
每收到10条消息就会以1%的概率使用HYBRID消息转发否则使用FORWARD消息转发。
"""
event_type = EventType.ON_MESSAGE
handler_name = "forward_messages_handler"
handler_description = "把接收到的消息转发到指定聊天ID"
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.counter = 0 # 用于计数转发的消息数量
self.messages: List[str] = []
async def execute(self, message: MaiMessages | None) -> Tuple[bool, bool, None, None, None]:
if not message: if not message:
return True, True, None, None, None return True, True, None, None, None
stream_id = message.stream_id or "" plain_text = message.get("plain_text", "") if isinstance(message, dict) else ""
if not plain_text:
return True, True, None, None, None
if message.plain_text: # 使用插件级状态收集消息
self.messages.append(message.plain_text) if not hasattr(self, "_fwd_messages"):
self.counter += 1 self._fwd_messages: list[str] = []
if self.counter % 10 == 0: self._fwd_counter: int = 0
self._fwd_messages.append(plain_text)
self._fwd_counter += 1
if self._fwd_counter % 10 == 0 and stream_id:
if random.random() < 0.01: if random.random() < 0.01:
success = await self.send_hybrid(stream_id, [(ReplyContentType.TEXT, msg) for msg in self.messages]) segments = [{"type": "text", "content": msg} for msg in self._fwd_messages]
await self.ctx.send.hybrid(segments, stream_id)
else: else:
success = await self.send_forward( messages = [
stream_id, {"user_id": "0", "nickname": "转发", "segments": [{"type": "text", "content": msg}]}
[ for msg in self._fwd_messages
( ]
str(global_config.bot.qq_account), await self.ctx.send.forward(messages, stream_id)
str(global_config.bot.nickname), self._fwd_messages = []
[(ReplyContentType.TEXT, msg)],
)
for msg in self.messages
],
)
if not success:
raise ValueError("转发消息失败")
self.messages = []
return True, True, None, None, None return True, True, None, None, None
class RandomEmojis(BaseCommand): def create_plugin():
command_name = "random_emojis" return HelloWorldPlugin()
command_description = "发送多张随机表情包"
command_pattern = r"^/random_emojis$"
async def execute(self):
emojis = await emoji_api.get_random(5)
if not emojis:
return False, "未找到表情包", False
emoji_base64_list = []
for emoji in emojis:
emoji_base64_list.append(emoji[0])
return await self.forward_images(emoji_base64_list)
async def forward_images(self, images: List[str]):
"""
把多张图片用合并转发的方式发给用户
"""
success = await self.send_forward([("0", "神秘用户", [(ReplyContentType.IMAGE, img)]) for img in images])
return (True, "已发送随机表情包", True) if success else (False, "发送随机表情包失败", False)
class TestCommand(BaseCommand):
"""响应/test命令"""
command_name = "test"
command_description = "测试命令"
command_pattern = r"^/test$"
async def execute(self) -> Tuple[bool, Optional[str], int]:
"""执行测试命令"""
try:
from src.plugin_system.apis import generator_api
reply_reason = "这是一条测试消息。"
logger.info(f"测试命令:{reply_reason}")
result_status, data = await generator_api.generate_reply(
chat_stream=self.message.chat_stream,
reply_reason=reply_reason,
enable_chinese_typo=False,
extra_info=f'{reply_reason}用于测试bot的功能是否正常。请你按设定的人设表达一句"测试正常"',
)
if result_status:
# 发送生成的回复
if data and data.reply_set and data.reply_set.reply_data:
for reply_seg in data.reply_set.reply_data:
send_data = reply_seg.content
await self.send_text(send_data, storage_message=True)
logger.info(f"已回复: {send_data}")
return True, "", 1
except Exception as e:
logger.error(f"表达器生成失败:{e}")
return True, "", 1
# ===== 插件注册 =====
@register_plugin
class HelloWorldPlugin(BasePlugin):
"""Hello World插件 - 你的第一个MaiCore插件"""
# 插件基本信息
plugin_name: str = "hello_world_plugin" # 内部标识符
enable_plugin: bool = False
dependencies: List[str] = [] # 插件依赖列表
python_dependencies: List[str] = [] # Python包依赖列表
config_file_name: str = "config.toml" # 配置文件名
# 配置节描述
config_section_descriptions = {"plugin": "插件基本信息", "greeting": "问候功能配置", "time": "时间查询配置"}
# 配置Schema定义
config_schema: dict = {
"plugin": {
"config_version": ConfigField(type=str, default="1.0.0", description="配置文件版本"),
"enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
},
"greeting": {
"message": ConfigField(
type=list, default=["嗨!很开心见到你!😊", "Ciallo(∠・ω< )⌒★"], description="默认问候消息"
),
"enable_emoji": ConfigField(type=bool, default=True, description="是否启用表情符号"),
},
"time": {"format": ConfigField(type=str, default="%Y-%m-%d %H:%M:%S", description="时间显示格式")},
"print_message": {"enabled": ConfigField(type=bool, default=True, description="是否启用打印")},
}
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
return [
(HelloAction.get_action_info(), HelloAction),
(CompareNumbersTool.get_tool_info(), CompareNumbersTool), # 添加比较数字工具
(ByeAction.get_action_info(), ByeAction), # 添加告别Action
(TimeCommand.get_command_info(), TimeCommand),
(PrintMessage.get_handler_info(), PrintMessage),
(ForwardMessages.get_handler_info(), ForwardMessages),
(RandomEmojis.get_command_info(), RandomEmojis),
(TestCommand.get_command_info(), TestCommand),
]
# @register_plugin
# class HelloWorldEventPlugin(BaseEPlugin):
# """Hello World事件插件 - 处理问候和告别事件"""
# plugin_name = "hello_world_event_plugin"
# enable_plugin = False
# dependencies = []
# python_dependencies = []
# config_file_name = "event_config.toml"
# config_schema = {
# "plugin": {
# "name": ConfigField(type=str, default="hello_world_event_plugin", description="插件名称"),
# "version": ConfigField(type=str, default="1.0.0", description="插件版本"),
# "enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
# },
# }
# def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
# return [(PrintMessage.get_handler_info(), PrintMessage)]

View File

@@ -37,6 +37,7 @@ dependencies = [
"uvicorn>=0.35.0", "uvicorn>=0.35.0",
"msgpack>=1.1.2", "msgpack>=1.1.2",
"watchfiles>=1.1.1", "watchfiles>=1.1.1",
"maibot-plugin-sdk>=1.0.0",
] ]

View File

@@ -6,6 +6,7 @@ fastapi>=0.116.0
google-genai>=1.39.1 google-genai>=1.39.1
jieba>=0.42.1 jieba>=0.42.1
json-repair>=0.47.6 json-repair>=0.47.6
maibot-plugin-sdk>=1.0.0
maim-message>=0.6.2 maim-message>=0.6.2
matplotlib>=3.10.3 matplotlib>=3.10.3
numpy>=2.2.6 numpy>=2.2.6

View File

@@ -19,9 +19,10 @@ from src.chat.heart_flow.hfc_utils import CycleDetail
from src.bw_learner.expression_learner import expression_learner_manager from src.bw_learner.expression_learner import expression_learner_manager
from src.bw_learner.message_recorder import extract_and_distribute_messages from src.bw_learner.message_recorder import extract_and_distribute_messages
from src.person_info.person_info import Person from src.person_info.person_info import Person
from src.plugin_system.base.component_types import EventType, ActionInfo from src.core.types import ActionInfo, EventType
from src.plugin_system.core import events_manager from src.core.event_bus import event_bus
from src.plugin_system.apis import generator_api, send_api, message_api, database_api from src.chat.event_helpers import build_event_message
from src.services import generator_service as generator_api, send_service as send_api, message_service as message_api, database_service as database_api
from src.chat.utils.chat_message_builder import ( from src.chat.utils.chat_message_builder import (
build_readable_messages_with_id, build_readable_messages_with_id,
get_raw_msg_before_timestamp_with_chat, get_raw_msg_before_timestamp_with_chat,
@@ -315,8 +316,9 @@ class BrainChatting:
message_id_list=message_id_list, message_id_list=message_id_list,
prompt_key="brain_planner", prompt_key="brain_planner",
) )
continue_flag, modified_message = await events_manager.handle_mai_events( _event_msg = build_event_message(EventType.ON_PLAN, llm_prompt=prompt_info[0], stream_id=self.chat_stream.stream_id)
EventType.ON_PLAN, None, prompt_info[0], None, self.chat_stream.stream_id continue_flag, modified_message = await event_bus.emit(
EventType.ON_PLAN, _event_msg
) )
if not continue_flag: if not continue_flag:
return False return False

View File

@@ -23,8 +23,8 @@ from src.chat.utils.chat_message_builder import (
from src.chat.utils.utils import get_chat_type_and_target_info from src.chat.utils.utils import get_chat_type_and_target_info
from src.chat.planner_actions.action_manager import ActionManager from src.chat.planner_actions.action_manager import ActionManager
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.plugin_system.base.component_types import ActionInfo, ComponentType, ActionActivationType from src.core.types import ActionActivationType, ActionInfo, ComponentType
from src.plugin_system.core.component_registry import component_registry from src.core.component_registry import component_registry
if TYPE_CHECKING: if TYPE_CHECKING:
from src.common.data_models.info_data_model import TargetPersonInfo from src.common.data_models.info_data_model import TargetPersonInfo

169
src/chat/event_helpers.py Normal file
View File

@@ -0,0 +1,169 @@
"""
事件消息构建工具
将 chat 层的消息对象 (SessionMessage / MessageSending) 转换为
核心事件系统使用的 MaiMessages供调用 event_bus.emit() 前使用。
"""
from typing import List, Optional, TYPE_CHECKING
from src.common.logger import get_logger
from src.core.types import EventType, MaiMessages
if TYPE_CHECKING:
from src.chat.message_receive.message import MessageSending, SessionMessage
from src.common.data_models.llm_data_model import LLMGenerationDataModel
logger = get_logger("event_helpers")
def build_event_message(
event_type: EventType | str,
message: Optional["SessionMessage | MessageSending | MaiMessages"] = None,
llm_prompt: Optional[str] = None,
llm_response: Optional["LLMGenerationDataModel"] = None,
stream_id: Optional[str] = None,
action_usage: Optional[List[str]] = None,
) -> Optional[MaiMessages]:
"""根据事件类型和输入,准备和转换消息对象。
迁移自 events_manager._prepare_message保持相同的行为。
"""
if isinstance(message, MaiMessages):
return message.deepcopy()
if message:
return _transform_event_message(message, llm_prompt, llm_response)
if event_type not in (EventType.ON_START, EventType.ON_STOP):
assert stream_id, "如果没有消息,必须为非启动/关闭事件提供流ID"
if event_type in (EventType.ON_MESSAGE, EventType.ON_PLAN, EventType.POST_LLM, EventType.AFTER_LLM):
return _build_message_from_stream(stream_id, llm_prompt, llm_response)
else:
return _build_message_without_raw(stream_id, llm_prompt, llm_response, action_usage)
return None # ON_START / ON_STOP 没有消息体
def _transform_event_message(
message: "SessionMessage | MessageSending",
llm_prompt: Optional[str] = None,
llm_response: Optional["LLMGenerationDataModel"] = None,
) -> MaiMessages:
"""将 SessionMessage / MessageSending 转换为 MaiMessages。"""
from maim_message import Seg
from src.chat.message_receive.message import MessageSending
transformed = MaiMessages(
llm_prompt=llm_prompt,
llm_response_content=llm_response.content if llm_response else None,
llm_response_reasoning=llm_response.reasoning if llm_response else None,
llm_response_model=llm_response.model if llm_response else None,
llm_response_tool_call=llm_response.tool_calls if llm_response else None,
raw_message=message.processed_plain_text or "",
additional_data={},
)
# 消息段处理
if isinstance(message, MessageSending):
if message.message_segment.type == "seglist":
transformed.message_segments = list(message.message_segment.data) # type: ignore
else:
transformed.message_segments = [message.message_segment]
else:
transformed.message_segments = [Seg(type="text", data=message.processed_plain_text or "")]
# stream_id
transformed.stream_id = message.session_id if hasattr(message, "session_id") else ""
# 处理后文本
transformed.plain_text = message.processed_plain_text
# 基本信息
if isinstance(message, MessageSending):
transformed.message_base_info["platform"] = message.platform
if message.session.group_id:
transformed.is_group_message = True
group_name = ""
if (
message.session.context
and message.session.context.message
and message.session.context.message.message_info.group_info
):
group_name = message.session.context.message.message_info.group_info.group_name
transformed.message_base_info.update(
{
"group_id": message.session.group_id,
"group_name": group_name,
}
)
transformed.message_base_info.update(
{
"user_id": message.bot_user_info.user_id,
"user_cardname": message.bot_user_info.user_cardname,
"user_nickname": message.bot_user_info.user_nickname,
}
)
if not transformed.is_group_message:
transformed.is_private_message = True
elif hasattr(message, "message_info") and message.message_info:
if message.platform:
transformed.message_base_info["platform"] = message.platform
if message.message_info.group_info:
transformed.is_group_message = True
transformed.message_base_info.update(
{
"group_id": message.message_info.group_info.group_id,
"group_name": message.message_info.group_info.group_name,
}
)
if message.message_info.user_info:
if not transformed.is_group_message:
transformed.is_private_message = True
transformed.message_base_info.update(
{
"user_id": message.message_info.user_info.user_id,
"user_cardname": message.message_info.user_info.user_cardname,
"user_nickname": message.message_info.user_info.user_nickname,
}
)
return transformed
def _build_message_from_stream(
stream_id: str,
llm_prompt: Optional[str] = None,
llm_response: Optional["LLMGenerationDataModel"] = None,
) -> MaiMessages:
"""从 stream_id 查找会话消息并转换。"""
from src.chat.message_receive.chat_manager import chat_manager
session = chat_manager.get_session_by_session_id(stream_id)
assert session, f"未找到流ID为 {stream_id} 的会话"
return _transform_event_message(session.context.message, llm_prompt, llm_response)
def _build_message_without_raw(
stream_id: str,
llm_prompt: Optional[str] = None,
llm_response: Optional["LLMGenerationDataModel"] = None,
action_usage: Optional[List[str]] = None,
) -> MaiMessages:
"""没有原始消息对象时,从 stream_id 构建最小 MaiMessages。"""
from src.chat.message_receive.chat_manager import chat_manager
session = chat_manager.get_session_by_session_id(stream_id)
assert session, f"未找到流ID为 {stream_id} 的会话"
return MaiMessages(
stream_id=stream_id,
llm_prompt=llm_prompt,
llm_response_content=llm_response.content if llm_response else None,
llm_response_reasoning=llm_response.reasoning if llm_response else None,
llm_response_model=llm_response.model if llm_response else None,
llm_response_tool_call=llm_response.tool_calls if llm_response else None,
is_group_message=session.is_group_session,
is_private_message=not session.is_group_session,
action_usage=action_usage,
additional_data={"response_is_processed": True},
)

View File

@@ -6,7 +6,7 @@ import time
from src.config.config import global_config from src.config.config import global_config
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.plugin_system.apis import send_api from src.services import send_service as send_api
from src.common.message_repository import count_messages from src.common.message_repository import count_messages

View File

@@ -11,8 +11,9 @@ from src.common.utils.utils_session import SessionUtils
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
from src.chat.brain_chat.PFC.pfc_manager import PFCManager from src.chat.brain_chat.PFC.pfc_manager import PFCManager
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.plugin_system.core import component_registry, events_manager, global_announcement_manager from src.core.announcement_manager import global_announcement_manager
from src.plugin_system.base import BaseCommand, EventType from src.core.component_registry import component_registry
from src.core.types import EventType
from .message import SessionMessage from .message import SessionMessage
from .chat_manager import chat_manager from .chat_manager import chat_manager
@@ -65,10 +66,10 @@ class ChatBot:
try: try:
text = message.processed_plain_text text = message.processed_plain_text
# 使用新的组件注册中心查找命令 # 使用核心组件注册查找命令
command_result = component_registry.find_command_by_text(text) command_result = component_registry.find_command_by_text(text)
if command_result: if command_result:
command_class, matched_groups, command_info = command_result command_executor, matched_groups, command_info = command_result
plugin_name = command_info.plugin_name plugin_name = command_info.plugin_name
command_name = command_info.name command_name = command_info.name
if message.session_id and command_name in global_announcement_manager.get_disabled_chat_commands( if message.session_id and command_name in global_announcement_manager.get_disabled_chat_commands(
@@ -82,20 +83,20 @@ class ChatBot:
# 获取插件配置 # 获取插件配置
plugin_config = component_registry.get_plugin_config(plugin_name) plugin_config = component_registry.get_plugin_config(plugin_name)
# 创建命令实例
command_instance: BaseCommand = command_class(message, plugin_config)
command_instance.set_matched_groups(matched_groups)
try: try:
# 执行命令 # 调用命令执行器
success, response, intercept_message_level = await command_instance.execute() success, response, intercept_message_level = await command_executor(
message=message,
plugin_config=plugin_config,
matched_groups=matched_groups,
)
message.intercept_message_level = intercept_message_level message.intercept_message_level = intercept_message_level
# 记录命令执行结果 # 记录命令执行结果
if success: if success:
logger.info(f"命令执行成功: {command_class.__name__} (拦截等级: {intercept_message_level})") logger.info(f"命令执行成功: {command_name} (拦截等级: {intercept_message_level})")
else: else:
logger.warning(f"命令执行失败: {command_class.__name__} - {response}") logger.warning(f"命令执行失败: {command_name} - {response}")
# 根据命令的拦截设置决定是否继续处理消息 # 根据命令的拦截设置决定是否继续处理消息
return ( return (
@@ -105,14 +106,9 @@ class ChatBot:
) # 找到命令根据intercept_message决定是否继续 ) # 找到命令根据intercept_message决定是否继续
except Exception as e: except Exception as e:
logger.error(f"执行命令时出错: {command_class.__name__} - {e}") logger.error(f"执行命令时出错: {command_name} - {e}")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
try:
await command_instance.send_text(f"命令执行出错: {str(e)}")
except Exception as send_error:
logger.error(f"发送错误消息失败: {send_error}")
# 命令出错时,根据命令的拦截设置决定是否继续处理消息 # 命令出错时,根据命令的拦截设置决定是否继续处理消息
return True, str(e), False # 出错时继续处理消息 return True, str(e), False # 出错时继续处理消息

View File

@@ -318,11 +318,13 @@ class UniversalMessageSender:
message.build_reply() message.build_reply()
logger.debug(f"[{chat_id}] 选择回复引用消息: {message.processed_plain_text[:20]}...") logger.debug(f"[{chat_id}] 选择回复引用消息: {message.processed_plain_text[:20]}...")
from src.plugin_system.core.events_manager import events_manager from src.core.event_bus import event_bus
from src.plugin_system.base.component_types import EventType from src.chat.event_helpers import build_event_message
from src.core.types import EventType
continue_flag, modified_message = await events_manager.handle_mai_events( _event_msg = build_event_message(EventType.POST_SEND_PRE_PROCESS, message=message, stream_id=chat_id)
EventType.POST_SEND_PRE_PROCESS, message=message, stream_id=chat_id continue_flag, modified_message = await event_bus.emit(
EventType.POST_SEND_PRE_PROCESS, _event_msg
) )
if not continue_flag: if not continue_flag:
logger.info(f"[{chat_id}] 消息发送被插件取消: {str(message.message_segment)[:100]}...") logger.info(f"[{chat_id}] 消息发送被插件取消: {str(message.message_segment)[:100]}...")
@@ -336,8 +338,9 @@ class UniversalMessageSender:
await message.process() await message.process()
continue_flag, modified_message = await events_manager.handle_mai_events( _event_msg = build_event_message(EventType.POST_SEND, message=message, stream_id=chat_id)
EventType.POST_SEND, message=message, stream_id=chat_id continue_flag, modified_message = await event_bus.emit(
EventType.POST_SEND, _event_msg
) )
if not continue_flag: if not continue_flag:
logger.info(f"[{chat_id}] 消息发送被插件取消: {str(message.message_segment)[:100]}...") logger.info(f"[{chat_id}] 消息发送被插件取消: {str(message.message_segment)[:100]}...")
@@ -360,8 +363,9 @@ class UniversalMessageSender:
if not sent_msg: if not sent_msg:
return False return False
continue_flag, modified_message = await events_manager.handle_mai_events( _event_msg = build_event_message(EventType.AFTER_SEND, message=message, stream_id=chat_id)
EventType.AFTER_SEND, message=message, stream_id=chat_id continue_flag, modified_message = await event_bus.emit(
EventType.AFTER_SEND, _event_msg
) )
if not continue_flag: if not continue_flag:
logger.info(f"[{chat_id}] 消息发送后续处理被插件取消: {str(message.message_segment)[:100]}...") logger.info(f"[{chat_id}] 消息发送后续处理被插件取消: {str(message.message_segment)[:100]}...")

View File

@@ -1,20 +1,34 @@
from typing import Dict, Optional, Type from typing import Dict, Optional, Tuple
from src.chat.message_receive.chat_manager import BotChatSession from src.chat.message_receive.chat_manager import BotChatSession
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.database_data_model import DatabaseMessages
from src.plugin_system.core.component_registry import component_registry from src.core.component_registry import component_registry, ActionExecutor
from src.plugin_system.base.component_types import ComponentType, ActionInfo from src.core.types import ActionInfo, ComponentType
from src.plugin_system.base.base_action import BaseAction
logger = get_logger("action_manager") logger = get_logger("action_manager")
class ActionHandle:
"""Action 执行句柄
不依赖任何插件基类,内部持有 executor (async callable) 和绑定参数。
brain_chat 调用 ``await handle.execute()`` 即可。
"""
def __init__(self, executor: ActionExecutor, **kwargs):
self._executor = executor
self._kwargs = kwargs
async def execute(self) -> Tuple[bool, str]:
return await self._executor(**self._kwargs)
class ActionManager: class ActionManager:
""" """
动作管理器,用于管理各种类型的动作 动作管理器,用于管理各种类型的动作
现在统一使用新插件系统,简化了原有的新旧兼容逻辑 使用核心组件注册表的 executor-based 模式
""" """
def __init__(self): def __init__(self):
@@ -39,9 +53,9 @@ class ActionManager:
log_prefix: str, log_prefix: str,
shutting_down: bool = False, shutting_down: bool = False,
action_message: Optional[DatabaseMessages] = None, action_message: Optional[DatabaseMessages] = None,
) -> Optional[BaseAction]: ) -> Optional[ActionHandle]:
""" """
创建动作处理器实例 创建动作执行句柄
Args: Args:
action_name: 动作名称 action_name: 动作名称
@@ -52,30 +66,26 @@ class ActionManager:
chat_stream: 聊天流 chat_stream: 聊天流
log_prefix: 日志前缀 log_prefix: 日志前缀
shutting_down: 是否正在关闭 shutting_down: 是否正在关闭
action_message: 动作消息记录
Returns: Returns:
Optional[BaseAction]: 创建的动作处理器实例,如果动作名称未注册则返回None Optional[ActionHandle]: 执行句柄,如果动作未注册则返回 None
""" """
try: try:
# 获取组件类 - 明确指定查询Action类型 executor = component_registry.get_action_executor(action_name)
component_class: Type[BaseAction] = component_registry.get_component_class( if not executor:
action_name, ComponentType.ACTION
) # type: ignore
if not component_class:
logger.warning(f"{log_prefix} 未找到Action组件: {action_name}") logger.warning(f"{log_prefix} 未找到Action组件: {action_name}")
return None return None
# 获取组件信息 info = component_registry.get_action_info(action_name)
component_info = component_registry.get_component_info(action_name, ComponentType.ACTION) if not info:
if not component_info:
logger.warning(f"{log_prefix} 未找到Action组件信息: {action_name}") logger.warning(f"{log_prefix} 未找到Action组件信息: {action_name}")
return None return None
# 获取插件配置 plugin_config = component_registry.get_plugin_config(info.plugin_name) or {}
plugin_config = component_registry.get_plugin_config(component_info.plugin_name)
# 创建动作实例 handle = ActionHandle(
instance = component_class( executor,
action_data=action_data, action_data=action_data,
action_reasoning=action_reasoning, action_reasoning=action_reasoning,
cycle_timers=cycle_timers, cycle_timers=cycle_timers,
@@ -87,11 +97,11 @@ class ActionManager:
action_message=action_message, action_message=action_message,
) )
logger.debug(f"创建Action实例成功: {action_name}") logger.debug(f"创建Action执行句柄成功: {action_name}")
return instance return handle
except Exception as e: except Exception as e:
logger.error(f"创建Action实例失败 {action_name}: {e}") logger.error(f"创建Action执行句柄失败 {action_name}: {e}")
import traceback import traceback
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())

View File

@@ -7,8 +7,8 @@ from src.config.config import global_config
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
from src.chat.planner_actions.action_manager import ActionManager from src.chat.planner_actions.action_manager import ActionManager
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
from src.plugin_system.base.component_types import ActionInfo, ActionActivationType from src.core.types import ActionActivationType, ActionInfo
from src.plugin_system.core.global_announcement_manager import global_announcement_manager from src.core.announcement_manager import global_announcement_manager
logger = get_logger("action_manager") logger = get_logger("action_manager")

View File

@@ -23,9 +23,9 @@ from src.chat.utils.chat_message_builder import (
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
from src.chat.planner_actions.action_manager import ActionManager from src.chat.planner_actions.action_manager import ActionManager
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.plugin_system.base.component_types import ActionInfo, ComponentType, ActionActivationType from src.core.types import ActionActivationType, ActionInfo, ComponentType
from src.plugin_system.core.component_registry import component_registry from src.core.component_registry import component_registry
from src.plugin_system.apis.message_api import translate_pid_to_description from src.services.message_service import translate_pid_to_description
from src.person_info.person_info import Person from src.person_info.person_info import Person
if TYPE_CHECKING: if TYPE_CHECKING:

View File

@@ -27,12 +27,12 @@ from src.chat.utils.chat_message_builder import (
replace_user_references, replace_user_references,
) )
from src.bw_learner.expression_selector import expression_selector from src.bw_learner.expression_selector import expression_selector
from src.plugin_system.apis.message_api import translate_pid_to_description from src.services.message_service import translate_pid_to_description
# from src.memory_system.memory_activator import MemoryActivator # from src.memory_system.memory_activator import MemoryActivator
from src.person_info.person_info import Person from src.person_info.person_info import Person
from src.plugin_system.base.component_types import ActionInfo, EventType from src.core.types import ActionInfo, EventType
from src.plugin_system.apis import llm_api from src.services import llm_service as llm_api
from src.chat.logger.plan_reply_logger import PlanReplyLogger from src.chat.logger.plan_reply_logger import PlanReplyLogger
from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt
@@ -56,7 +56,7 @@ class DefaultReplyer:
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id) self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id)
self.heart_fc_sender = UniversalMessageSender() self.heart_fc_sender = UniversalMessageSender()
from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor不然会循环依赖 from src.chat.tool_executor import ToolExecutor
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.session_id, enable_cache=True, cache_ttl=3) self.tool_executor = ToolExecutor(chat_id=self.chat_stream.session_id, enable_cache=True, cache_ttl=3)
@@ -149,11 +149,13 @@ class DefaultReplyer:
except Exception: except Exception:
logger.exception("记录reply日志失败") logger.exception("记录reply日志失败")
return False, llm_response return False, llm_response
from src.plugin_system.core.events_manager import events_manager from src.core.event_bus import event_bus
from src.chat.event_helpers import build_event_message
if not from_plugin: if not from_plugin:
continue_flag, modified_message = await events_manager.handle_mai_events( _event_msg = build_event_message(EventType.POST_LLM, llm_prompt=prompt, stream_id=stream_id)
EventType.POST_LLM, None, prompt, None, stream_id=stream_id continue_flag, modified_message = await event_bus.emit(
EventType.POST_LLM, _event_msg
) )
if not continue_flag: if not continue_flag:
raise UserWarning("插件于请求前中断了内容生成") raise UserWarning("插件于请求前中断了内容生成")
@@ -217,8 +219,9 @@ class DefaultReplyer:
) )
except Exception: except Exception:
logger.exception("记录reply日志失败") logger.exception("记录reply日志失败")
continue_flag, modified_message = await events_manager.handle_mai_events( _event_msg = build_event_message(EventType.AFTER_LLM, llm_prompt=prompt, llm_response=llm_response, stream_id=stream_id)
EventType.AFTER_LLM, None, prompt, llm_response, stream_id=stream_id continue_flag, modified_message = await event_bus.emit(
EventType.AFTER_LLM, _event_msg
) )
if not from_plugin and not continue_flag: if not from_plugin and not continue_flag:
raise UserWarning("插件于请求后取消了内容生成") raise UserWarning("插件于请求后取消了内容生成")

View File

@@ -28,12 +28,12 @@ from src.chat.utils.chat_message_builder import (
replace_user_references, replace_user_references,
) )
from src.bw_learner.expression_selector import expression_selector from src.bw_learner.expression_selector import expression_selector
from src.plugin_system.apis.message_api import translate_pid_to_description from src.services.message_service import translate_pid_to_description
# from src.memory_system.memory_activator import MemoryActivator # from src.memory_system.memory_activator import MemoryActivator
from src.person_info.person_info import Person, is_person_known from src.person_info.person_info import Person, is_person_known
from src.plugin_system.base.component_types import ActionInfo, EventType from src.core.types import ActionInfo, EventType
from src.plugin_system.apis import llm_api from src.services import llm_service as llm_api
from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt
from src.bw_learner.jargon_explainer import explain_jargon_in_context from src.bw_learner.jargon_explainer import explain_jargon_in_context
@@ -55,7 +55,7 @@ class PrivateReplyer:
self.heart_fc_sender = UniversalMessageSender() self.heart_fc_sender = UniversalMessageSender()
# self.memory_activator = MemoryActivator() # self.memory_activator = MemoryActivator()
from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor不然会循环依赖 from src.chat.tool_executor import ToolExecutor
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.session_id, enable_cache=True, cache_ttl=3) self.tool_executor = ToolExecutor(chat_id=self.chat_stream.session_id, enable_cache=True, cache_ttl=3)
@@ -114,11 +114,13 @@ class PrivateReplyer:
if not prompt: if not prompt:
logger.warning("构建prompt失败跳过回复生成") logger.warning("构建prompt失败跳过回复生成")
return False, llm_response return False, llm_response
from src.plugin_system.core.events_manager import events_manager from src.core.event_bus import event_bus
from src.chat.event_helpers import build_event_message
if not from_plugin: if not from_plugin:
continue_flag, modified_message = await events_manager.handle_mai_events( _event_msg = build_event_message(EventType.POST_LLM, llm_prompt=prompt, stream_id=stream_id)
EventType.POST_LLM, None, prompt, None, stream_id=stream_id continue_flag, modified_message = await event_bus.emit(
EventType.POST_LLM, _event_msg
) )
if not continue_flag: if not continue_flag:
raise UserWarning("插件于请求前中断了内容生成") raise UserWarning("插件于请求前中断了内容生成")
@@ -138,8 +140,9 @@ class PrivateReplyer:
llm_response.reasoning = reasoning_content llm_response.reasoning = reasoning_content
llm_response.model = model_name llm_response.model = model_name
llm_response.tool_calls = tool_call llm_response.tool_calls = tool_call
continue_flag, modified_message = await events_manager.handle_mai_events( _event_msg = build_event_message(EventType.AFTER_LLM, llm_prompt=prompt, llm_response=llm_response, stream_id=stream_id)
EventType.AFTER_LLM, None, prompt, llm_response, stream_id=stream_id continue_flag, modified_message = await event_bus.emit(
EventType.AFTER_LLM, _event_msg
) )
if not from_plugin and not continue_flag: if not from_plugin and not continue_flag:
raise UserWarning("插件于请求后取消了内容生成") raise UserWarning("插件于请求后取消了内容生成")

View File

@@ -1,14 +1,23 @@
"""
工具执行器
独立的工具执行组件可以直接输入聊天消息内容
自动判断并执行相应的工具返回结构化的工具执行结果
src.plugin_system.core.tool_use 迁移使用新的核心组件注册表
"""
import hashlib
import time import time
from typing import List, Dict, Tuple, Optional, Any from typing import Any, Dict, List, Optional, Tuple
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.llm_models.utils_model import LLMRequest
from src.llm_models.payload_content import ToolCall
from src.config.config import global_config, model_config
from src.prompt.prompt_manager import prompt_manager
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.core.announcement_manager import global_announcement_manager
from src.core.component_registry import component_registry
from src.llm_models.payload_content import ToolCall
from src.llm_models.utils_model import LLMRequest
from src.prompt.prompt_manager import prompt_manager
logger = get_logger("tool_use") logger = get_logger("tool_use")
@@ -20,66 +29,38 @@ class ToolExecutor:
""" """
def __init__(self, chat_id: str, enable_cache: bool = True, cache_ttl: int = 3): def __init__(self, chat_id: str, enable_cache: bool = True, cache_ttl: int = 3):
"""初始化工具执行器 from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
Args:
executor_id: 执行器标识符用于日志记录
enable_cache: 是否启用缓存机制
cache_ttl: 缓存生存时间周期数
"""
self.chat_id = chat_id self.chat_id = chat_id
self.chat_stream = _chat_manager.get_session_by_session_id(self.chat_id) self.chat_stream = _chat_manager.get_session_by_session_id(self.chat_id)
self.log_prefix = f"[{_chat_manager.get_session_name(self.chat_id) or self.chat_id}]" self.log_prefix = f"[{_chat_manager.get_session_name(self.chat_id) or self.chat_id}]"
self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor") self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor")
# 缓存配置
self.enable_cache = enable_cache self.enable_cache = enable_cache
self.cache_ttl = cache_ttl self.cache_ttl = cache_ttl
self.tool_cache = {} # 格式: {cache_key: {"result": result, "ttl": ttl, "timestamp": timestamp}} self.tool_cache: Dict[str, dict] = {}
logger.info(f"{self.log_prefix}工具执行器初始化完成,缓存{'启用' if enable_cache else '禁用'}TTL={cache_ttl}") logger.info(f"{self.log_prefix}工具执行器初始化完成,缓存{'启用' if enable_cache else '禁用'}TTL={cache_ttl}")
async def execute_from_chat_message( async def execute_from_chat_message(
self, target_message: str, chat_history: str, sender: str, return_details: bool = False self, target_message: str, chat_history: str, sender: str, return_details: bool = False
) -> Tuple[List[Dict[str, Any]], List[str], str]: ) -> Tuple[List[Dict[str, Any]], List[str], str]:
"""从聊天消息执行工具 """从聊天消息执行工具"""
Args:
target_message: 目标消息内容
chat_history: 聊天历史
sender: 发送者
return_details: 是否返回详细信息(使用的工具列表和提示词)
Returns:
如果return_details为False: Tuple[List[Dict], List[str], str] - (工具执行结果列表, , )
如果return_details为True: Tuple[List[Dict], List[str], str] - (结果列表, 使用的工具, 提示词)
"""
# 首先检查缓存
cache_key = self._generate_cache_key(target_message, chat_history, sender) cache_key = self._generate_cache_key(target_message, chat_history, sender)
if cached_result := self._get_from_cache(cache_key): if cached_result := self._get_from_cache(cache_key):
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具执行") logger.info(f"{self.log_prefix}使用缓存结果,跳过工具执行")
if not return_details: if not return_details:
return cached_result, [], "" return cached_result, [], ""
# 从缓存结果中提取工具名称
used_tools = [result.get("tool_name", "unknown") for result in cached_result] used_tools = [result.get("tool_name", "unknown") for result in cached_result]
return cached_result, used_tools, "" return cached_result, used_tools, ""
# 缓存未命中,执行工具调用
# 获取可用工具
tools = self._get_tool_definitions() tools = self._get_tool_definitions()
# 如果没有可用工具,直接返回空内容
if not tools: if not tools:
logger.debug(f"{self.log_prefix}没有可用工具,直接返回空内容") logger.debug(f"{self.log_prefix}没有可用工具,直接返回空内容")
if return_details: return [], [], ""
return [], [], ""
else:
return [], [], ""
# 构建工具调用提示词
prompt_template = prompt_manager.get_prompt("tool_executor") prompt_template = prompt_manager.get_prompt("tool_executor")
prompt_template.add_context("target_message", target_message) prompt_template.add_context("target_message", target_message)
prompt_template.add_context("chat_history", chat_history) prompt_template.add_context("chat_history", chat_history)
@@ -90,15 +71,12 @@ class ToolExecutor:
logger.debug(f"{self.log_prefix}开始LLM工具调用分析") logger.debug(f"{self.log_prefix}开始LLM工具调用分析")
# 调用LLM进行工具决策
response, (reasoning_content, model_name, tool_calls) = await self.llm_model.generate_response_async( response, (reasoning_content, model_name, tool_calls) = await self.llm_model.generate_response_async(
prompt=prompt, tools=tools, raise_when_empty=False prompt=prompt, tools=tools, raise_when_empty=False
) )
# 执行工具调用
tool_results, used_tools = await self.execute_tool_calls(tool_calls) tool_results, used_tools = await self.execute_tool_calls(tool_calls)
# 缓存结果
if tool_results: if tool_results:
self._set_cache(cache_key, tool_results) self._set_cache(cache_key, tool_results)
@@ -107,42 +85,30 @@ class ToolExecutor:
if return_details: if return_details:
return tool_results, used_tools, prompt return tool_results, used_tools, prompt
else: return tool_results, [], ""
return tool_results, [], ""
def _get_tool_definitions(self) -> List[Dict[str, Any]]: def _get_tool_definitions(self) -> List[Dict[str, Any]]:
all_tools = get_llm_available_tool_definitions() """获取 LLM 可用的工具定义列表"""
all_tools = component_registry.get_llm_available_tools()
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id) user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
return [definition for name, definition in all_tools if name not in user_disabled_tools] return [info.get_llm_definition() for name, info in all_tools.items() if name not in user_disabled_tools]
async def execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]: async def execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]:
"""执行工具调用 """执行工具调用列表"""
Args:
tool_calls: LLM返回的工具调用列表
Returns:
Tuple[List[Dict], List[str]]: (工具执行结果列表, 使用的工具名称列表)
"""
tool_results: List[Dict[str, Any]] = [] tool_results: List[Dict[str, Any]] = []
used_tools = [] used_tools: List[str] = []
if not tool_calls: if not tool_calls:
logger.debug(f"{self.log_prefix}无需执行工具") logger.debug(f"{self.log_prefix}无需执行工具")
return [], [] return [], []
# 提取tool_calls中的函数名称
func_names = [call.func_name for call in tool_calls if call.func_name] func_names = [call.func_name for call in tool_calls if call.func_name]
logger.info(f"{self.log_prefix}开始执行工具调用: {func_names}") logger.info(f"{self.log_prefix}开始执行工具调用: {func_names}")
# 执行每个工具调用
for tool_call in tool_calls: for tool_call in tool_calls:
tool_name = tool_call.func_name
try: try:
tool_name = tool_call.func_name
logger.debug(f"{self.log_prefix}执行工具: {tool_name}") logger.debug(f"{self.log_prefix}执行工具: {tool_name}")
# 执行工具
result = await self.execute_tool_call(tool_call) result = await self.execute_tool_call(tool_call)
if result: if result:
@@ -156,7 +122,6 @@ class ToolExecutor:
content = tool_info["content"] content = tool_info["content"]
if not isinstance(content, (str, list, tuple)): if not isinstance(content, (str, list, tuple)):
tool_info["content"] = str(content) tool_info["content"] = str(content)
# 空内容直接跳过(空字符串、全空白字符串、空列表/空元组)
content_check = tool_info["content"] content_check = tool_info["content"]
if (isinstance(content_check, str) and not content_check.strip()) or ( if (isinstance(content_check, str) and not content_check.strip()) or (
isinstance(content_check, (list, tuple)) and len(content_check) == 0 isinstance(content_check, (list, tuple)) and len(content_check) == 0
@@ -166,11 +131,10 @@ class ToolExecutor:
tool_results.append(tool_info) tool_results.append(tool_info)
used_tools.append(tool_name) used_tools.append(tool_name)
preview = content[:200] preview = str(content)[:200]
logger.debug(f"{self.log_prefix}工具{tool_name}结果内容: {preview}...") logger.debug(f"{self.log_prefix}工具{tool_name}结果内容: {preview}...")
except Exception as e: except Exception as e:
logger.error(f"{self.log_prefix}工具{tool_name}执行失败: {e}") logger.error(f"{self.log_prefix}工具{tool_name}执行失败: {e}")
# 添加错误信息到结果中
error_info = { error_info = {
"type": "tool_error", "type": "tool_error",
"id": f"tool_error_{time.time()}", "id": f"tool_error_{time.time()}",
@@ -182,122 +146,30 @@ class ToolExecutor:
return tool_results, used_tools return tool_results, used_tools
async def execute_tool_call( async def execute_tool_call(self, tool_call: ToolCall) -> Optional[Dict[str, Any]]:
self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None """执行单个工具调用"""
) -> Optional[Dict[str, Any]]: function_name = tool_call.func_name
# sourcery skip: use-assigned-variable function_args = tool_call.args or {}
"""执行单个工具调用 function_args["llm_called"] = True
Args: executor = component_registry.get_tool_executor(function_name)
tool_call: 工具调用对象 if not executor:
logger.warning(f"未知工具名称: {function_name}")
Returns:
Optional[Dict]: 工具调用结果如果失败则返回None
"""
try:
function_name = tool_call.func_name
function_args = tool_call.args or {}
function_args["llm_called"] = True # 标记为LLM调用
# 获取对应工具实例
tool_instance = tool_instance or get_tool_instance(function_name, self.chat_stream)
if not tool_instance:
logger.warning(f"未知工具名称: {function_name}")
return None
# 执行工具
result = await tool_instance.execute(function_args)
if result:
return {
"tool_call_id": tool_call.call_id,
"role": "tool",
"name": function_name,
"type": "function",
"content": result["content"],
}
return None
except Exception as e:
logger.error(f"执行工具调用时发生错误: {str(e)}")
raise e
def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str:
"""生成缓存键
Args:
target_message: 目标消息内容
chat_history: 聊天历史
sender: 发送者
Returns:
str: 缓存键
"""
import hashlib
# 使用消息内容和群聊状态生成唯一缓存键
content = f"{target_message}_{chat_history}_{sender}"
return hashlib.md5(content.encode()).hexdigest()
def _get_from_cache(self, cache_key: str) -> Optional[List[Dict]]:
"""从缓存获取结果
Args:
cache_key: 缓存键
Returns:
Optional[List[Dict]]: 缓存的结果如果不存在或过期则返回None
"""
if not self.enable_cache or cache_key not in self.tool_cache:
return None return None
cache_item = self.tool_cache[cache_key] result = await executor(function_args)
if cache_item["ttl"] <= 0: if result:
# 缓存过期,删除 return {
del self.tool_cache[cache_key] "tool_call_id": tool_call.call_id,
logger.debug(f"{self.log_prefix}缓存过期,删除缓存键: {cache_key}") "role": "tool",
return None "name": function_name,
"type": "function",
# 减少TTL "content": result["content"],
cache_item["ttl"] -= 1 }
logger.debug(f"{self.log_prefix}使用缓存结果剩余TTL: {cache_item['ttl']}") return None
return cache_item["result"]
def _set_cache(self, cache_key: str, result: List[Dict]):
"""设置缓存
Args:
cache_key: 缓存键
result: 要缓存的结果
"""
if not self.enable_cache:
return
self.tool_cache[cache_key] = {"result": result, "ttl": self.cache_ttl, "timestamp": time.time()}
logger.debug(f"{self.log_prefix}设置缓存TTL: {self.cache_ttl}")
def _cleanup_expired_cache(self):
"""清理过期的缓存"""
if not self.enable_cache:
return
expired_keys = []
expired_keys.extend(cache_key for cache_key, cache_item in self.tool_cache.items() if cache_item["ttl"] <= 0)
for key in expired_keys:
del self.tool_cache[key]
if expired_keys:
logger.debug(f"{self.log_prefix}清理了{len(expired_keys)}个过期缓存")
async def execute_specific_tool_simple(self, tool_name: str, tool_args: Dict) -> Optional[Dict]: async def execute_specific_tool_simple(self, tool_name: str, tool_args: Dict) -> Optional[Dict]:
"""直接执行指定工具 """直接执行指定工具"""
Args:
tool_name: 工具名称
tool_args: 工具参数
validate_args: 是否验证参数
Returns:
Optional[Dict]: 工具执行结果失败时返回None
"""
try: try:
tool_call = ToolCall( tool_call = ToolCall(
call_id=f"direct_tool_{time.time()}", call_id=f"direct_tool_{time.time()}",
@@ -306,7 +178,6 @@ class ToolExecutor:
) )
logger.info(f"{self.log_prefix}直接执行工具: {tool_name}") logger.info(f"{self.log_prefix}直接执行工具: {tool_name}")
result = await self.execute_tool_call(tool_call) result = await self.execute_tool_call(tool_call)
if result: if result:
@@ -325,86 +196,55 @@ class ToolExecutor:
return None return None
# === 缓存方法 ===
def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str:
content = f"{target_message}_{chat_history}_{sender}"
return hashlib.md5(content.encode()).hexdigest()
def _get_from_cache(self, cache_key: str) -> Optional[List[Dict]]:
if not self.enable_cache or cache_key not in self.tool_cache:
return None
cache_item = self.tool_cache[cache_key]
if cache_item["ttl"] <= 0:
del self.tool_cache[cache_key]
return None
cache_item["ttl"] -= 1
return cache_item["result"]
def _set_cache(self, cache_key: str, result: List[Dict]):
if not self.enable_cache:
return
self.tool_cache[cache_key] = {"result": result, "ttl": self.cache_ttl, "timestamp": time.time()}
def _cleanup_expired_cache(self):
if not self.enable_cache:
return
expired = [k for k, v in self.tool_cache.items() if v["ttl"] <= 0]
for key in expired:
del self.tool_cache[key]
def clear_cache(self): def clear_cache(self):
"""清空所有缓存"""
if self.enable_cache: if self.enable_cache:
cache_count = len(self.tool_cache)
self.tool_cache.clear() self.tool_cache.clear()
logger.info(f"{self.log_prefix}清空了{cache_count}个缓存项")
def get_cache_status(self) -> Dict: def get_cache_status(self) -> Dict:
"""获取缓存状态信息
Returns:
Dict: 包含缓存统计信息的字典
"""
if not self.enable_cache: if not self.enable_cache:
return {"enabled": False, "cache_count": 0} return {"enabled": False, "cache_count": 0}
# 清理过期缓存
self._cleanup_expired_cache() self._cleanup_expired_cache()
ttl_distribution: Dict[int, int] = {}
total_count = len(self.tool_cache) for item in self.tool_cache.values():
ttl_distribution = {} ttl = item["ttl"]
for cache_item in self.tool_cache.values():
ttl = cache_item["ttl"]
ttl_distribution[ttl] = ttl_distribution.get(ttl, 0) + 1 ttl_distribution[ttl] = ttl_distribution.get(ttl, 0) + 1
return { return {
"enabled": True, "enabled": True,
"cache_count": total_count, "cache_count": len(self.tool_cache),
"cache_ttl": self.cache_ttl, "cache_ttl": self.cache_ttl,
"ttl_distribution": ttl_distribution, "ttl_distribution": ttl_distribution,
} }
def set_cache_config(self, enable_cache: Optional[bool] = None, cache_ttl: int = -1): def set_cache_config(self, enable_cache: Optional[bool] = None, cache_ttl: int = -1):
"""动态修改缓存配置
Args:
enable_cache: 是否启用缓存
cache_ttl: 缓存TTL
"""
if enable_cache is not None: if enable_cache is not None:
self.enable_cache = enable_cache self.enable_cache = enable_cache
logger.info(f"{self.log_prefix}缓存状态修改为: {'启用' if enable_cache else '禁用'}")
if cache_ttl > 0: if cache_ttl > 0:
self.cache_ttl = cache_ttl self.cache_ttl = cache_ttl
logger.info(f"{self.log_prefix}缓存TTL修改为: {cache_ttl}")
"""
ToolExecutor使用示例
# 1. 基础使用 - 从聊天消息执行工具启用缓存默认TTL=3
executor = ToolExecutor(executor_id="my_executor")
results, _, _ = await executor.execute_from_chat_message(
talking_message_str="今天天气怎么样?现在几点了?",
is_group_chat=False
)
# 2. 禁用缓存的执行器
no_cache_executor = ToolExecutor(executor_id="no_cache", enable_cache=False)
# 3. 自定义缓存TTL
long_cache_executor = ToolExecutor(executor_id="long_cache", cache_ttl=10)
# 4. 获取详细信息
results, used_tools, prompt = await executor.execute_from_chat_message(
talking_message_str="帮我查询Python相关知识",
is_group_chat=False,
return_details=True
)
# 5. 直接执行特定工具
result = await executor.execute_specific_tool_simple(
tool_name="get_knowledge",
tool_args={"query": "机器学习"}
)
# 6. 缓存管理
cache_status = executor.get_cache_status() # 查看缓存状态
executor.clear_cache() # 清空缓存
executor.set_cache_config(cache_ttl=5) # 动态修改缓存配置
"""

View File

@@ -4,7 +4,7 @@ from . import BaseDataModel
if TYPE_CHECKING: if TYPE_CHECKING:
from .database_data_model import DatabaseMessages from .database_data_model import DatabaseMessages
from src.plugin_system.base.component_types import ActionInfo from src.core.types import ActionInfo
@dataclass @dataclass

6
src/core/__init__.py Normal file
View File

@@ -0,0 +1,6 @@
"""
MaiBot 核心基础设施
提供与插件系统无关的核心类型定义、配置 schema 等基础设施。
这些类型被整个项目共享,包括内部模块、服务层、旧插件系统和新插件运行时。
"""

View File

@@ -0,0 +1,242 @@
"""
核心组件注册表
面向最终架构的组件管理:
- Action注册 ActionInfo + 执行器(本地 callable 或 IPC 路由)
- Command注册正则模式 + 执行器
- Tool注册工具定义 + 执行器
不依赖任何插件基类,组件执行器是纯 async callable。
"""
import re
from typing import Any, Awaitable, Callable, Dict, List, Optional, Pattern, Tuple, Union
from src.common.logger import get_logger
from src.core.types import (
ActionActivationType,
ActionInfo,
CommandInfo,
ComponentInfo,
ComponentType,
ToolInfo,
)
logger = get_logger("component_registry")
# 执行器类型
ActionExecutor = Callable[..., Awaitable[Any]]
CommandExecutor = Callable[..., Awaitable[Tuple[bool, Optional[str], bool]]]
ToolExecutor = Callable[..., Awaitable[Any]]
class ComponentRegistry:
"""核心组件注册表
管理 action、command、tool 三类组件。
每个组件由「元信息 + 执行器」构成,执行器是 async callable
不需要继承任何基类。
"""
def __init__(self):
# Action 注册
self._actions: Dict[str, ActionInfo] = {}
self._action_executors: Dict[str, ActionExecutor] = {}
self._default_actions: Dict[str, ActionInfo] = {}
# Command 注册
self._commands: Dict[str, CommandInfo] = {}
self._command_executors: Dict[str, CommandExecutor] = {}
self._command_patterns: Dict[Pattern, str] = {}
# Tool 注册
self._tools: Dict[str, ToolInfo] = {}
self._tool_executors: Dict[str, ToolExecutor] = {}
self._llm_available_tools: Dict[str, ToolInfo] = {}
# 插件配置plugin_name -> config dict
self._plugin_configs: Dict[str, dict] = {}
logger.info("核心组件注册表初始化完成")
# ========== Action ==========
def register_action(
self,
info: ActionInfo,
executor: ActionExecutor,
) -> bool:
"""注册 action
Args:
info: action 元信息
executor: 执行器async callable
"""
name = info.name
if name in self._actions:
logger.warning(f"Action {name} 已存在,跳过注册")
return False
self._actions[name] = info
self._action_executors[name] = executor
if info.enabled:
self._default_actions[name] = info
logger.debug(f"注册 Action: {name}")
return True
def get_action_info(self, name: str) -> Optional[ActionInfo]:
return self._actions.get(name)
def get_action_executor(self, name: str) -> Optional[ActionExecutor]:
return self._action_executors.get(name)
def get_default_actions(self) -> Dict[str, ActionInfo]:
return self._default_actions.copy()
def get_all_actions(self) -> Dict[str, ActionInfo]:
return self._actions.copy()
def remove_action(self, name: str) -> bool:
if name not in self._actions:
return False
del self._actions[name]
self._action_executors.pop(name, None)
self._default_actions.pop(name, None)
logger.debug(f"移除 Action: {name}")
return True
# ========== Command ==========
def register_command(
self,
info: CommandInfo,
executor: CommandExecutor,
) -> bool:
"""注册 command"""
name = info.name
if name in self._commands:
logger.warning(f"Command {name} 已存在,跳过注册")
return False
self._commands[name] = info
self._command_executors[name] = executor
if info.enabled and info.command_pattern:
pattern = re.compile(info.command_pattern, re.IGNORECASE | re.DOTALL)
self._command_patterns[pattern] = name
logger.debug(f"注册 Command: {name}")
return True
def find_command_by_text(
self, text: str
) -> Optional[Tuple[CommandExecutor, dict, CommandInfo]]:
"""根据文本查找匹配的命令
Returns:
(executor, matched_groups, command_info) 或 None
"""
candidates = [p for p in self._command_patterns if p.match(text)]
if not candidates:
return None
if len(candidates) > 1:
logger.warning(f"文本 '{text[:50]}' 匹配到多个命令模式,使用第一个")
pattern = candidates[0]
name = self._command_patterns[pattern]
return (
self._command_executors[name],
pattern.match(text).groupdict(), # type: ignore
self._commands[name],
)
def remove_command(self, name: str) -> bool:
if name not in self._commands:
return False
del self._commands[name]
self._command_executors.pop(name, None)
self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != name}
logger.debug(f"移除 Command: {name}")
return True
# ========== Tool ==========
def register_tool(
self,
info: ToolInfo,
executor: ToolExecutor,
) -> bool:
"""注册 tool"""
name = info.name
if name in self._tools:
logger.warning(f"Tool {name} 已存在,跳过注册")
return False
self._tools[name] = info
self._tool_executors[name] = executor
if info.enabled:
self._llm_available_tools[name] = info
logger.debug(f"注册 Tool: {name}")
return True
def get_tool_info(self, name: str) -> Optional[ToolInfo]:
return self._tools.get(name)
def get_tool_executor(self, name: str) -> Optional[ToolExecutor]:
return self._tool_executors.get(name)
def get_llm_available_tools(self) -> Dict[str, ToolInfo]:
return self._llm_available_tools.copy()
def get_all_tools(self) -> Dict[str, ToolInfo]:
return self._tools.copy()
def remove_tool(self, name: str) -> bool:
if name not in self._tools:
return False
del self._tools[name]
self._tool_executors.pop(name, None)
self._llm_available_tools.pop(name, None)
logger.debug(f"移除 Tool: {name}")
return True
# ========== 通用查询 ==========
def get_component_info(self, name: str, component_type: ComponentType) -> Optional[ComponentInfo]:
"""获取组件元信息"""
match component_type:
case ComponentType.ACTION:
return self._actions.get(name)
case ComponentType.COMMAND:
return self._commands.get(name)
case ComponentType.TOOL:
return self._tools.get(name)
case _:
return None
def get_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
"""获取某类型的所有组件"""
match component_type:
case ComponentType.ACTION:
return dict(self._actions)
case ComponentType.COMMAND:
return dict(self._commands)
case ComponentType.TOOL:
return dict(self._tools)
case _:
return {}
# ========== 插件配置 ==========
def set_plugin_config(self, plugin_name: str, config: dict) -> None:
self._plugin_configs[plugin_name] = config
def get_plugin_config(self, plugin_name: str) -> Optional[dict]:
return self._plugin_configs.get(plugin_name)
# 全局单例
component_registry = ComponentRegistry()

216
src/core/event_bus.py Normal file
View File

@@ -0,0 +1,216 @@
"""
核心事件总线
面向最终架构的事件系统:
- 内部 handler 直接注册 async callable
- IPC 插件通过 plugin_runtime 桥接
- 不依赖任何插件基类
"""
import asyncio
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
from src.common.logger import get_logger
from src.core.types import EventType, MaiMessages
if TYPE_CHECKING:
from src.common.data_models.llm_data_model import LLMGenerationDataModel
logger = get_logger("event_bus")
# Handler 签名:接收 MaiMessages返回 (continue, modified_message)
EventHandler = Callable[[Optional[MaiMessages]], Awaitable[Tuple[bool, Optional[MaiMessages]]]]
class EventBus:
"""核心事件总线
支持两种 handler
- 拦截型intercept=True同步顺序执行可修改消息、可中断流程
- 非拦截型intercept=False异步并发执行fire-and-forget
handler 是纯 async callable不需要继承任何基类。
"""
def __init__(self):
# event_type -> [(handler, name, weight, intercept)]
self._handlers: Dict[EventType | str, List[_HandlerEntry]] = {}
self._running_tasks: Dict[str, List[asyncio.Task]] = {}
# 预注册所有内置事件类型
for event in EventType:
self._handlers[event] = []
def subscribe(
self,
event_type: EventType | str,
handler: EventHandler,
name: str,
weight: int = 0,
intercept: bool = False,
) -> None:
"""注册事件 handler
Args:
event_type: 事件类型
handler: async callable签名 (Optional[MaiMessages]) -> (bool, Optional[MaiMessages])
name: handler 标识名
weight: 权重,越大越先执行
intercept: 是否为拦截型(同步执行,可中断流程)
"""
if event_type not in self._handlers:
self._handlers[event_type] = []
entry = _HandlerEntry(handler=handler, name=name, weight=weight, intercept=intercept)
self._handlers[event_type].append(entry)
self._handlers[event_type].sort(key=lambda e: e.weight, reverse=True)
logger.debug(f"注册事件 handler: {name} -> {event_type} (weight={weight}, intercept={intercept})")
def unsubscribe(self, event_type: EventType | str, name: str) -> bool:
"""取消注册事件 handler"""
handlers = self._handlers.get(event_type, [])
for i, entry in enumerate(handlers):
if entry.name == name:
del handlers[i]
logger.debug(f"取消注册事件 handler: {name} <- {event_type}")
return True
return False
async def emit(
self,
event_type: EventType | str,
message: Optional[MaiMessages] = None,
) -> Tuple[bool, Optional[MaiMessages]]:
"""触发事件
按权重顺序执行所有 handler
- 拦截型 handler 同步执行,可修改消息和中断流程
- 非拦截型 handler 异步 fire-and-forget
Args:
event_type: 事件类型
message: 事件消息(可选)
Returns:
(continue_flag, modified_message)
- continue_flag: False 表示某个拦截型 handler 要求中断
- modified_message: 被拦截型 handler 修改后的消息
"""
handlers = self._handlers.get(event_type, [])
if not handlers:
return True, None
continue_flag = True
current_message = message.deepcopy() if message else None
for entry in handlers:
if entry.intercept:
try:
should_continue, modified = await entry.handler(current_message)
if modified is not None:
current_message = modified
if not should_continue:
continue_flag = False
break
except Exception as e:
logger.error(f"拦截型 handler {entry.name} 执行异常: {e}", exc_info=True)
else:
self._fire_and_forget(entry, event_type, current_message)
# 桥接到 IPC 插件运行时
continue_flag, current_message = await self._bridge_to_ipc_runtime(
event_type, continue_flag, current_message
)
return continue_flag, current_message
async def cancel_handler_tasks(self, handler_name: str) -> None:
"""取消某个 handler 的所有运行中任务"""
tasks = self._running_tasks.pop(handler_name, [])
remaining = [t for t in tasks if not t.done()]
if remaining:
for t in remaining:
t.cancel()
await asyncio.gather(*remaining, return_exceptions=True)
logger.info(f"已取消 handler {handler_name}{len(remaining)} 个任务")
# --- 内部方法 ---
def _fire_and_forget(
self,
entry: "_HandlerEntry",
event_type: EventType | str,
message: Optional[MaiMessages],
) -> None:
"""创建异步任务执行非拦截型 handler"""
try:
task = asyncio.create_task(entry.handler(message))
task.set_name(entry.name)
task.add_done_callback(lambda t: self._task_done_callback(t, entry.name))
self._running_tasks.setdefault(entry.name, []).append(task)
except Exception as e:
logger.error(f"创建 handler 任务 {entry.name} 失败: {e}", exc_info=True)
def _task_done_callback(self, task: asyncio.Task, handler_name: str) -> None:
"""异步任务完成回调"""
try:
if task.cancelled():
return
exc = task.exception()
if exc:
logger.error(f"handler {handler_name} 异步任务异常: {exc}")
except Exception:
pass
finally:
task_list = self._running_tasks.get(handler_name, [])
try:
task_list.remove(task)
except ValueError:
pass
async def _bridge_to_ipc_runtime(
self,
event_type: EventType | str,
continue_flag: bool,
message: Optional[MaiMessages],
) -> Tuple[bool, Optional[MaiMessages]]:
"""将事件桥接到 IPC 插件运行时"""
if not continue_flag:
return continue_flag, message
try:
from src.plugin_runtime.integration import get_plugin_runtime_manager
prm = get_plugin_runtime_manager()
if not prm.is_running:
return continue_flag, message
event_value = event_type.value if isinstance(event_type, EventType) else str(event_type)
message_dict = message.to_dict() if message and hasattr(message, "to_dict") else None
new_continue, _ = await prm.bridge_event(
event_type_value=event_value,
message_dict=message_dict,
)
if not new_continue:
continue_flag = False
except Exception as e:
logger.warning(f"桥接事件到 IPC 运行时失败: {e}")
return continue_flag, message
class _HandlerEntry:
"""内部 handler 条目"""
__slots__ = ("handler", "name", "weight", "intercept")
def __init__(self, handler: EventHandler, name: str, weight: int, intercept: bool):
self.handler = handler
self.name = name
self.weight = weight
self.intercept = intercept
# 全局单例
event_bus = EventBus()

View File

@@ -169,6 +169,14 @@ class ToolInfo(ComponentInfo):
super().__post_init__() super().__post_init__()
self.component_type = ComponentType.TOOL self.component_type = ComponentType.TOOL
def get_llm_definition(self) -> dict:
"""生成 LLM function-calling 所需的工具定义"""
return {
"name": self.name,
"description": self.tool_description,
"parameters": self.tool_parameters,
}
@dataclass @dataclass
class EventHandlerInfo(ComponentInfo): class EventHandlerInfo(ComponentInfo):

View File

@@ -10,7 +10,7 @@ from src.config.config import global_config, model_config
from src.common.database.database_model import ChatHistory from src.common.database.database_model import ChatHistory
from src.prompt.prompt_manager import prompt_manager from src.prompt.prompt_manager import prompt_manager
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
from src.plugin_system.apis import llm_api from src.services import llm_service as llm_api
from src.dream.dream_generator import generate_dream_summary from src.dream.dream_generator import generate_dream_summary
# dream 工具工厂函数 # dream 工具工厂函数

View File

@@ -8,7 +8,7 @@ from src.llm_models.payload_content.message import RoleType, Message
from src.prompt.prompt_manager import prompt_manager from src.prompt.prompt_manager import prompt_manager
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.common.utils.utils_session import SessionUtils from src.common.utils.utils_session import SessionUtils
from src.plugin_system.apis import send_api from src.services import send_service as send_api
logger = get_logger("dream_generator") logger = get_logger("dream_generator")

View File

@@ -2,7 +2,7 @@ from typing import Any, Dict, Optional
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.database_model import ChatHistory from src.common.database.database_model import ChatHistory
from src.plugin_system.apis import database_api from src.services import database_service as database_api
logger = get_logger("dream_agent") logger = get_logger("dream_agent")

View File

@@ -2,7 +2,7 @@ from typing import Any, Dict, Optional
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.database_model import Jargon from src.common.database.database_model import Jargon
from src.plugin_system.apis import database_api from src.services import database_service as database_api
logger = get_logger("dream_agent") logger = get_logger("dream_agent")

View File

@@ -18,10 +18,7 @@ from rich.traceback import install
# from src.api.main import start_api_server # from src.api.main import start_api_server
# 导入新的插件管理器 # 导入插件运行时
from src.plugin_system.core.plugin_manager import plugin_manager
# 导入新版本插件运行时
from src.plugin_runtime.integration import get_plugin_runtime_manager from src.plugin_runtime.integration import get_plugin_runtime_manager
# 导入消息API和traceback模块 # 导入消息API和traceback模块
@@ -31,8 +28,6 @@ from src.bw_learner.expression_auto_check_task import ExpressionAutoCheckTask
from src.prompt.prompt_manager import prompt_manager from src.prompt.prompt_manager import prompt_manager
# 插件系统现在使用统一的插件加载器
install(extra_lines=3) install(extra_lines=3)
logger = get_logger("main") logger = get_logger("main")
@@ -108,10 +103,7 @@ class MainSystem:
# 启动LPMM # 启动LPMM
lpmm_start_up() lpmm_start_up()
# 加载所有actions包括默认的和插件的 # 启动插件运行时(内置插件 + 第三方插件双子进程)
plugin_manager.load_all_plugins()
# 启动新版本插件运行时(与旧系统并行运行)
await get_plugin_runtime_manager().start() await get_plugin_runtime_manager().start()
# 初始化表情管理器 # 初始化表情管理器
@@ -133,12 +125,12 @@ class MainSystem:
prompt_manager.load_prompts() prompt_manager.load_prompts()
# 触发 ON_START 事件 # 触发 ON_START 事件
from src.plugin_system.core.events_manager import events_manager from src.core.event_bus import event_bus
from src.plugin_system.base.component_types import EventType from src.core.types import EventType
await events_manager.handle_mai_events(event_type=EventType.ON_START) await event_bus.emit(event_type=EventType.ON_START)
# 桥接 ON_START 事件到新版本插件运行时 # 分发 ON_START 事件到插件运行时
await get_plugin_runtime_manager().bridge_event("on_start") await get_plugin_runtime_manager().bridge_event("on_start")
# logger.info("已触发 ON_START 事件") # logger.info("已触发 ON_START 事件")
try: try:

View File

@@ -17,7 +17,7 @@ from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.database_data_model import DatabaseMessages
from src.config.config import model_config, global_config from src.config.config import model_config, global_config
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.plugin_system.apis import message_api from src.services import message_service as message_api
from src.chat.utils.chat_message_builder import build_readable_messages from src.chat.utils.chat_message_builder import build_readable_messages
from src.chat.utils.utils import is_bot_self from src.chat.utils.utils import is_bot_self
from src.person_info.person_info import Person from src.person_info.person_info import Person
@@ -913,7 +913,7 @@ class ChatHistorySummarizer:
"""存储到数据库""" """存储到数据库"""
try: try:
from src.common.database.database_model import ChatHistory from src.common.database.database_model import ChatHistory
from src.plugin_system.apis import database_api from src.services import database_service as database_api
# 准备数据 # 准备数据
data = { data = {

View File

@@ -7,7 +7,7 @@ from typing import List, Dict, Any, Optional, Tuple, Callable
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
from src.prompt.prompt_manager import prompt_manager from src.prompt.prompt_manager import prompt_manager
from src.plugin_system.apis import llm_api from src.services import llm_service as llm_api
from sqlmodel import select, col from sqlmodel import select, col
from src.common.database.database import get_db_session from src.common.database.database import get_db_session
from src.common.database.database_model import ThinkingQuestion from src.common.database.database_model import ThinkingQuestion

File diff suppressed because it is too large Load Diff

View File

@@ -22,8 +22,9 @@ class UDSTransportServer(TransportServer):
def __init__(self, socket_path: str | None = None): def __init__(self, socket_path: str | None = None):
if socket_path is None: if socket_path is None:
# 默认放在临时目录 # 默认放在临时目录,使用 uuid 确保同一进程多实例不碰撞
socket_path = os.path.join(tempfile.gettempdir(), f"maibot-plugin-{os.getpid()}.sock") import uuid
socket_path = os.path.join(tempfile.gettempdir(), f"maibot-plugin-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock")
self._socket_path = socket_path self._socket_path = socket_path
self._server: asyncio.AbstractServer | None = None self._server: asyncio.AbstractServer | None = None

View File

@@ -1,156 +0,0 @@
"""
MaiBot 插件系统
提供统一的插件开发和管理框架
"""
# 导出主要的公共接口
from .base import (
BasePlugin,
BaseAction,
BaseCommand,
BaseTool,
ConfigField,
ConfigSection,
ConfigLayout,
ConfigTab,
ComponentType,
ActionActivationType,
ChatMode,
ComponentInfo,
ActionInfo,
CommandInfo,
PluginInfo,
ToolInfo,
PythonDependency,
BaseEventHandler,
EventHandlerInfo,
EventType,
MaiMessages,
ToolParamType,
CustomEventHandlerResult,
PluginServiceInfo,
ReplyContentType,
ReplyContent,
ForwardNode,
ReplySetModel,
WorkflowContext,
WorkflowMessage,
WorkflowStage,
WorkflowStepInfo,
WorkflowStepResult,
WorkflowErrorCode,
)
# 导入工具模块
from .utils import (
ManifestValidator,
# ManifestGenerator,
# validate_plugin_manifest,
# generate_plugin_manifest,
)
from .apis import (
chat_api,
tool_api,
component_manage_api,
config_api,
database_api,
emoji_api,
generator_api,
llm_api,
message_api,
person_api,
plugin_manage_api,
plugin_service_api,
workflow_api,
send_api,
register_plugin,
get_logger,
)
from src.common.data_models.database_data_model import (
DatabaseMessages,
DatabaseUserInfo,
DatabaseGroupInfo,
DatabaseChatInfo,
)
from src.common.data_models.info_data_model import TargetPersonInfo, ActionPlannerInfo
from src.common.data_models.llm_data_model import LLMGenerationDataModel
__version__ = "2.0.0"
__all__ = [
# API 模块
"chat_api",
"tool_api",
"component_manage_api",
"config_api",
"database_api",
"emoji_api",
"generator_api",
"llm_api",
"message_api",
"person_api",
"plugin_manage_api",
"plugin_service_api",
"workflow_api",
"send_api",
"auto_talk_api",
"register_plugin",
"get_logger",
# 基础类
"BasePlugin",
"BaseAction",
"BaseCommand",
"BaseTool",
"BaseEventHandler",
# 类型定义
"ComponentType",
"ActionActivationType",
"ChatMode",
"ComponentInfo",
"ActionInfo",
"CommandInfo",
"PluginInfo",
"ToolInfo",
"PythonDependency",
"EventHandlerInfo",
"EventType",
"ToolParamType",
# 消息
"ReplyContentType",
"ReplyContent",
"ForwardNode",
"ReplySetModel",
"MaiMessages",
"CustomEventHandlerResult",
"PluginServiceInfo",
"WorkflowContext",
"WorkflowMessage",
"WorkflowStage",
"WorkflowStepInfo",
"WorkflowStepResult",
"WorkflowErrorCode",
# 装饰器
"register_plugin",
"ConfigField",
"ConfigSection",
"ConfigLayout",
"ConfigTab",
# 工具函数
"ManifestValidator",
"get_logger",
# "ManifestGenerator",
# "validate_plugin_manifest",
# "generate_plugin_manifest",
# 数据模型
"DatabaseMessages",
"DatabaseUserInfo",
"DatabaseGroupInfo",
"DatabaseChatInfo",
"TargetPersonInfo",
"ActionPlannerInfo",
"LLMGenerationDataModel",
]

View File

@@ -1,47 +0,0 @@
"""
插件系统API模块
提供了插件开发所需的各种API
"""
# 导入所有API模块
from src.plugin_system.apis import (
chat_api,
component_manage_api,
config_api,
database_api,
emoji_api,
generator_api,
llm_api,
message_api,
person_api,
plugin_manage_api,
plugin_service_api,
send_api,
tool_api,
frequency_api,
workflow_api,
)
from .logging_api import get_logger
from .plugin_register_api import register_plugin
# 导出所有API模块使它们可以通过 apis.xxx 方式访问
__all__ = [
"chat_api",
"component_manage_api",
"config_api",
"database_api",
"emoji_api",
"generator_api",
"llm_api",
"message_api",
"person_api",
"plugin_manage_api",
"plugin_service_api",
"send_api",
"get_logger",
"register_plugin",
"tool_api",
"frequency_api",
"workflow_api",
]

View File

@@ -1,323 +0,0 @@
"""
聊天API模块
专门负责聊天信息的查询和管理采用标准Python包设计模式
使用方式:
from src.plugin_system.apis import chat_api
streams = chat_api.get_all_group_streams()
chat_type = chat_api.get_stream_type(stream)
或者:
from src.plugin_system.apis.chat_api import ChatManager as chat
streams = chat.get_all_group_streams()
"""
from typing import List, Dict, Any, Optional
from enum import Enum
from src.common.logger import get_logger
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
logger = get_logger("chat_api")
class SpecialTypes(Enum):
"""特殊枚举类型"""
ALL_PLATFORMS = "all_platforms"
class ChatManager:
"""聊天管理器 - 专门负责聊天信息的查询和管理"""
@staticmethod
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
# sourcery skip: for-append-to-extend
"""获取所有聊天流
Args:
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
Returns:
List[BotChatSession]: 聊天流列表
Raises:
TypeError: 如果 platform 不是字符串或 SpecialTypes 枚举类型
"""
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = []
try:
for _, stream in _chat_manager.sessions.items():
if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform:
streams.append(stream)
logger.debug(f"[ChatAPI] 获取到 {len(streams)}{platform} 平台的聊天流")
except Exception as e:
logger.error(f"[ChatAPI] 获取聊天流失败: {e}")
return streams
@staticmethod
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
# sourcery skip: for-append-to-extend
"""获取所有群聊聊天流
Args:
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
Returns:
List[BotChatSession]: 群聊聊天流列表
"""
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = []
try:
for _, stream in _chat_manager.sessions.items():
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.is_group_session:
streams.append(stream)
logger.debug(f"[ChatAPI] 获取到 {len(streams)}{platform} 平台的群聊流")
except Exception as e:
logger.error(f"[ChatAPI] 获取群聊流失败: {e}")
return streams
@staticmethod
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
# sourcery skip: for-append-to-extend
"""获取所有私聊聊天流
Args:
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
Returns:
List[BotChatSession]: 私聊聊天流列表
Raises:
TypeError: 如果 platform 不是字符串或 SpecialTypes 枚举类型
"""
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = []
try:
for _, stream in _chat_manager.sessions.items():
if (
platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform
) and not stream.is_group_session:
streams.append(stream)
logger.debug(f"[ChatAPI] 获取到 {len(streams)}{platform} 平台的私聊流")
except Exception as e:
logger.error(f"[ChatAPI] 获取私聊流失败: {e}")
return streams
@staticmethod
def get_group_stream_by_group_id(
group_id: str, platform: Optional[str] | SpecialTypes = "qq"
) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast
"""根据群ID获取聊天流
Args:
group_id: 群聊ID
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
Returns:
Optional[BotChatSession]: 聊天流对象如果未找到返回None
Raises:
ValueError: 如果 group_id 为空字符串
TypeError: 如果 group_id 不是字符串类型或 platform 不是字符串或 SpecialTypes
"""
if not isinstance(group_id, str):
raise TypeError("group_id 必须是字符串类型")
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
if not group_id:
raise ValueError("group_id 不能为空")
try:
for _, stream in _chat_manager.sessions.items():
if (
stream.is_group_session
and str(stream.group_id) == str(group_id)
and (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform)
):
logger.debug(f"[ChatAPI] 找到群ID {group_id} 的聊天流")
return stream
logger.warning(f"[ChatAPI] 未找到群ID {group_id} 的聊天流")
except Exception as e:
logger.error(f"[ChatAPI] 查找群聊流失败: {e}")
return None
@staticmethod
def get_private_stream_by_user_id(
user_id: str, platform: Optional[str] | SpecialTypes = "qq"
) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast
"""根据用户ID获取私聊流
Args:
user_id: 用户ID
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
Returns:
Optional[BotChatSession]: 聊天流对象如果未找到返回None
Raises:
ValueError: 如果 user_id 为空字符串
TypeError: 如果 user_id 不是字符串类型或 platform 不是字符串或 SpecialTypes
"""
if not isinstance(user_id, str):
raise TypeError("user_id 必须是字符串类型")
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
if not user_id:
raise ValueError("user_id 不能为空")
try:
for _, stream in _chat_manager.sessions.items():
if (
not stream.is_group_session
and str(stream.user_id) == str(user_id)
and (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform)
):
logger.debug(f"[ChatAPI] 找到用户ID {user_id} 的私聊流")
return stream
logger.warning(f"[ChatAPI] 未找到用户ID {user_id} 的私聊流")
except Exception as e:
logger.error(f"[ChatAPI] 查找私聊流失败: {e}")
return None
@staticmethod
def get_stream_type(chat_stream: BotChatSession) -> str:
"""获取聊天流类型
Args:
chat_stream: 聊天流对象
Returns:
str: 聊天类型 ("group", "private", "unknown")
Raises:
TypeError: 如果 chat_stream 不是 BotChatSession 类型
ValueError: 如果 chat_stream 为空
"""
if not isinstance(chat_stream, BotChatSession):
raise TypeError("chat_stream 必须是 BotChatSession 类型")
if not chat_stream:
raise ValueError("chat_stream 不能为 None")
return "group" if chat_stream.is_group_session else "private"
@staticmethod
def get_stream_info(chat_stream: BotChatSession) -> Dict[str, Any]:
"""获取聊天流详细信息
Args:
chat_stream: 聊天流对象
Returns:
Dict ({str: Any}): 聊天流信息字典
Raises:
TypeError: 如果 chat_stream 不是 BotChatSession 类型
ValueError: 如果 chat_stream 为空
"""
if not chat_stream:
raise ValueError("chat_stream 不能为 None")
if not isinstance(chat_stream, BotChatSession):
raise TypeError("chat_stream 必须是 BotChatSession 类型")
try:
info: Dict[str, Any] = {
"session_id": chat_stream.session_id,
"platform": chat_stream.platform,
"type": ChatManager.get_stream_type(chat_stream),
}
if chat_stream.is_group_session:
info["group_id"] = chat_stream.group_id
# Try to get group name from context
if chat_stream.context and chat_stream.context.message and chat_stream.context.message.message_info.group_info:
info["group_name"] = chat_stream.context.message.message_info.group_info.group_name or "未知群聊"
else:
info["group_name"] = "未知群聊"
else:
info["user_id"] = chat_stream.user_id
if chat_stream.context and chat_stream.context.message and chat_stream.context.message.message_info.user_info:
info["user_name"] = chat_stream.context.message.message_info.user_info.user_nickname
else:
info["user_name"] = "未知用户"
return info
except Exception as e:
logger.error(f"[ChatAPI] 获取聊天流信息失败: {e}")
return {}
@staticmethod
def get_streams_summary() -> Dict[str, int]:
"""获取聊天流统计摘要
Returns:
Dict[str, int]: 包含各种统计信息的字典
"""
try:
all_streams = ChatManager.get_all_streams(SpecialTypes.ALL_PLATFORMS)
group_streams = ChatManager.get_group_streams(SpecialTypes.ALL_PLATFORMS)
private_streams = ChatManager.get_private_streams(SpecialTypes.ALL_PLATFORMS)
summary = {
"total_streams": len(all_streams),
"group_streams": len(group_streams),
"private_streams": len(private_streams),
"qq_streams": len([s for s in all_streams if s.platform == "qq"]),
}
logger.debug(f"[ChatAPI] 聊天流统计: {summary}")
return summary
except Exception as e:
logger.error(f"[ChatAPI] 获取聊天流统计失败: {e}")
return {
"total_streams": 0,
"group_streams": 0,
"private_streams": 0,
"qq_streams": 0,
}
# =============================================================================
# 模块级别的便捷函数 - 类似 requests.get(), requests.post() 的设计
# =============================================================================
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
"""获取所有聊天流的便捷函数"""
return ChatManager.get_all_streams(platform)
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
"""获取群聊聊天流的便捷函数"""
return ChatManager.get_group_streams(platform)
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
"""获取私聊聊天流的便捷函数"""
return ChatManager.get_private_streams(platform)
def get_stream_by_group_id(group_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[BotChatSession]:
"""根据群ID获取聊天流的便捷函数"""
return ChatManager.get_group_stream_by_group_id(group_id, platform)
def get_stream_by_user_id(user_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[BotChatSession]:
"""根据用户ID获取私聊流的便捷函数"""
return ChatManager.get_private_stream_by_user_id(user_id, platform)
def get_stream_type(chat_stream: BotChatSession) -> str:
"""获取聊天流类型的便捷函数"""
return ChatManager.get_stream_type(chat_stream)
def get_stream_info(chat_stream: BotChatSession) -> Dict[str, Any]:
"""获取聊天流信息的便捷函数"""
return ChatManager.get_stream_info(chat_stream)
def get_streams_summary() -> Dict[str, int]:
"""获取聊天流统计摘要的便捷函数"""
return ChatManager.get_streams_summary()

View File

@@ -1,268 +0,0 @@
from typing import Optional, Union, Dict
from src.plugin_system.base.component_types import (
CommandInfo,
ActionInfo,
EventHandlerInfo,
PluginInfo,
ComponentType,
ToolInfo,
)
# === 插件信息查询 ===
def get_all_plugin_info() -> Dict[str, PluginInfo]:
"""
获取所有插件的信息。
Returns:
dict: 包含所有插件信息的字典,键为插件名称,值为 PluginInfo 对象。
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.get_all_plugins()
def get_plugin_info(plugin_name: str) -> Optional[PluginInfo]:
"""
获取指定插件的信息。
Args:
plugin_name (str): 插件名称。
Returns:
PluginInfo: 插件信息对象,如果插件不存在则返回 None。
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.get_plugin_info(plugin_name)
# === 组件查询方法 ===
def get_component_info(
component_name: str, component_type: ComponentType
) -> Optional[Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
"""
获取指定组件的信息。
Args:
component_name (str): 组件名称。
component_type (ComponentType): 组件类型。
Returns:
Union[CommandInfo, ActionInfo, EventHandlerInfo]: 组件信息对象,如果组件不存在则返回 None。
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.get_component_info(component_name, component_type) # type: ignore
def get_components_info_by_type(
component_type: ComponentType,
) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
"""
获取指定类型的所有组件信息。
Args:
component_type (ComponentType): 组件类型。
Returns:
dict: 包含指定类型组件信息的字典,键为组件名称,值为对应的组件信息对象。
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.get_components_by_type(component_type) # type: ignore
def get_enabled_components_info_by_type(
component_type: ComponentType,
) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
"""
获取指定类型的所有启用的组件信息。
Args:
component_type (ComponentType): 组件类型。
Returns:
dict: 包含指定类型启用组件信息的字典,键为组件名称,值为对应的组件信息对象。
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.get_enabled_components_by_type(component_type) # type: ignore
# === Action 查询方法 ===
def get_registered_action_info(action_name: str) -> Optional[ActionInfo]:
"""
获取指定 Action 的注册信息。
Args:
action_name (str): Action 名称。
Returns:
ActionInfo: Action 信息对象,如果 Action 不存在则返回 None。
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.get_registered_action_info(action_name)
def get_registered_command_info(command_name: str) -> Optional[CommandInfo]:
"""
获取指定 Command 的注册信息。
Args:
command_name (str): Command 名称。
Returns:
CommandInfo: Command 信息对象,如果 Command 不存在则返回 None。
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.get_registered_command_info(command_name)
def get_registered_tool_info(tool_name: str) -> Optional[ToolInfo]:
"""
获取指定 Tool 的注册信息。
Args:
tool_name (str): Tool 名称。
Returns:
ToolInfo: Tool 信息对象,如果 Tool 不存在则返回 None。
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.get_registered_tool_info(tool_name)
# === EventHandler 特定查询方法 ===
def get_registered_event_handler_info(
event_handler_name: str,
) -> Optional[EventHandlerInfo]:
"""
获取指定 EventHandler 的注册信息。
Args:
event_handler_name (str): EventHandler 名称。
Returns:
EventHandlerInfo: EventHandler 信息对象,如果 EventHandler 不存在则返回 None。
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.get_registered_event_handler_info(event_handler_name)
# === 组件管理方法 ===
def globally_enable_component(component_name: str, component_type: ComponentType) -> bool:
"""
全局启用指定组件。
Args:
component_name (str): 组件名称。
component_type (ComponentType): 组件类型。
Returns:
bool: 启用成功返回 True否则返回 False。
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.enable_component(component_name, component_type)
async def globally_disable_component(component_name: str, component_type: ComponentType) -> bool:
"""
全局禁用指定组件。
**此函数是异步的,确保在异步环境中调用。**
Args:
component_name (str): 组件名称。
component_type (ComponentType): 组件类型。
Returns:
bool: 禁用成功返回 True否则返回 False。
"""
from src.plugin_system.core.component_registry import component_registry
return await component_registry.disable_component(component_name, component_type)
def locally_enable_component(component_name: str, component_type: ComponentType, stream_id: str) -> bool:
"""
局部启用指定组件。
Args:
component_name (str): 组件名称。
component_type (ComponentType): 组件类型。
stream_id (str): 消息流 ID。
Returns:
bool: 启用成功返回 True否则返回 False。
"""
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
match component_type:
case ComponentType.ACTION:
return global_announcement_manager.enable_specific_chat_action(stream_id, component_name)
case ComponentType.COMMAND:
return global_announcement_manager.enable_specific_chat_command(stream_id, component_name)
case ComponentType.TOOL:
return global_announcement_manager.enable_specific_chat_tool(stream_id, component_name)
case ComponentType.EVENT_HANDLER:
return global_announcement_manager.enable_specific_chat_event_handler(stream_id, component_name)
case _:
raise ValueError(f"未知 component type: {component_type}")
def locally_disable_component(component_name: str, component_type: ComponentType, stream_id: str) -> bool:
"""
局部禁用指定组件。
Args:
component_name (str): 组件名称。
component_type (ComponentType): 组件类型。
stream_id (str): 消息流 ID。
Returns:
bool: 禁用成功返回 True否则返回 False。
"""
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
match component_type:
case ComponentType.ACTION:
return global_announcement_manager.disable_specific_chat_action(stream_id, component_name)
case ComponentType.COMMAND:
return global_announcement_manager.disable_specific_chat_command(stream_id, component_name)
case ComponentType.TOOL:
return global_announcement_manager.disable_specific_chat_tool(stream_id, component_name)
case ComponentType.EVENT_HANDLER:
return global_announcement_manager.disable_specific_chat_event_handler(stream_id, component_name)
case _:
raise ValueError(f"未知 component type: {component_type}")
def get_locally_disabled_components(stream_id: str, component_type: ComponentType) -> list[str]:
"""
获取指定消息流中禁用的组件列表。
Args:
stream_id (str): 消息流 ID。
component_type (ComponentType): 组件类型。
Returns:
list[str]: 禁用的组件名称列表。
"""
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
match component_type:
case ComponentType.ACTION:
return global_announcement_manager.get_disabled_chat_actions(stream_id)
case ComponentType.COMMAND:
return global_announcement_manager.get_disabled_chat_commands(stream_id)
case ComponentType.TOOL:
return global_announcement_manager.get_disabled_chat_tools(stream_id)
case ComponentType.EVENT_HANDLER:
return global_announcement_manager.get_disabled_chat_event_handlers(stream_id)
case _:
raise ValueError(f"未知 component type: {component_type}")

View File

@@ -1,22 +0,0 @@
from pathlib import Path
from dataclasses import dataclass
@dataclass(frozen=True)
class _SystemConstants:
PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.parent.absolute().resolve()
CONFIG_DIR: Path = PROJECT_ROOT / "config"
BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
PLUGINS_DIR: Path = (PROJECT_ROOT / "plugins").resolve().absolute()
INTERNAL_PLUGINS_DIR: Path = (PROJECT_ROOT / "src" / "plugins").resolve().absolute()
_system_constants = _SystemConstants()
PROJECT_ROOT: Path = _system_constants.PROJECT_ROOT
CONFIG_DIR: Path = _system_constants.CONFIG_DIR
BOT_CONFIG_PATH: Path = _system_constants.BOT_CONFIG_PATH
MODEL_CONFIG_PATH: Path = _system_constants.MODEL_CONFIG_PATH
PLUGINS_DIR: Path = _system_constants.PLUGINS_DIR
INTERNAL_PLUGINS_DIR: Path = _system_constants.INTERNAL_PLUGINS_DIR

View File

@@ -1,700 +0,0 @@
"""
表情API模块
提供表情包相关功能采用标准Python包设计模式
使用方式:
from src.plugin_system.apis import emoji_api
result = await emoji_api.get_by_description("开心")
count = emoji_api.get_count()
"""
import random
import base64
import os
import uuid
from typing import Optional, Tuple, List, Dict, Any
from src.common.logger import get_logger
from src.chat.emoji_system.emoji_manager import emoji_manager, EMOJI_DIR
from src.common.utils.utils_image import ImageUtils
from src.config.config import global_config
logger = get_logger("emoji_api")
# =============================================================================
# 表情包获取API函数
# =============================================================================
async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]:
"""根据描述选择表情包
Args:
description: 表情包的描述文本,例如"开心""难过""愤怒"
Returns:
Optional[Tuple[str, str, str]]: (base64编码, 表情包描述, 匹配的情感标签) 或 None
Raises:
ValueError: 如果描述为空字符串
TypeError: 如果描述不是字符串类型
"""
if not description:
raise ValueError("描述不能为空")
if not isinstance(description, str):
raise TypeError("描述必须是字符串类型")
try:
logger.debug(f"[EmojiAPI] 根据描述获取表情包: {description}")
emoji_obj = await emoji_manager.get_emoji_for_emotion(description)
if not emoji_obj:
logger.warning(f"[EmojiAPI] 未找到匹配描述 '{description}' 的表情包")
return None
emoji_path = str(emoji_obj.full_path)
emoji_description = emoji_obj.description
matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else ""
emoji_base64 = ImageUtils.image_path_to_base64(emoji_path)
if not emoji_base64:
logger.error(f"[EmojiAPI] 无法将表情包文件转换为base64: {emoji_path}")
return None
logger.debug(f"[EmojiAPI] 成功获取表情包: {emoji_description}, 匹配情感: {matched_emotion}")
return emoji_base64, emoji_description, matched_emotion
except Exception as e:
logger.error(f"[EmojiAPI] 获取表情包失败: {e}")
return None
async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
"""随机获取指定数量的表情包
Args:
count: 要获取的表情包数量默认为1
Returns:
List[Tuple[str, str, str]]: 包含(base64编码, 表情包描述, 随机情感标签)的元组列表,失败则返回空列表
Raises:
TypeError: 如果count不是整数类型
ValueError: 如果count为负数
"""
if not isinstance(count, int):
raise TypeError("count 必须是整数类型")
if count < 0:
raise ValueError("count 不能为负数")
if count == 0:
logger.warning("[EmojiAPI] count 为0返回空列表")
return []
try:
all_emojis = emoji_manager.emojis
if not all_emojis:
logger.warning("[EmojiAPI] 没有可用的表情包")
return []
# 过滤有效表情包
valid_emojis = [emoji for emoji in all_emojis if not emoji.is_deleted]
if not valid_emojis:
logger.warning("[EmojiAPI] 没有有效的表情包")
return []
if len(valid_emojis) < count:
logger.debug(
f"[EmojiAPI] 有效表情包数量 ({len(valid_emojis)}) 少于请求的数量 ({count}),将返回所有有效表情包"
)
count = len(valid_emojis)
# 随机选择
selected_emojis = random.sample(valid_emojis, count)
results = []
for selected_emoji in selected_emojis:
emoji_base64 = ImageUtils.image_path_to_base64(str(selected_emoji.full_path))
if not emoji_base64:
logger.error(f"[EmojiAPI] 无法转换表情包为base64: {selected_emoji.full_path}")
continue
matched_emotion = random.choice(selected_emoji.emotion) if selected_emoji.emotion else "随机表情"
# 记录使用次数
emoji_manager.update_emoji_usage(selected_emoji)
results.append((emoji_base64, selected_emoji.description, matched_emotion))
if not results and count > 0:
logger.warning("[EmojiAPI] 随机获取表情包失败,没有一个可以成功处理")
return []
logger.debug(f"[EmojiAPI] 成功获取 {len(results)} 个随机表情包")
return results
except Exception as e:
logger.error(f"[EmojiAPI] 获取随机表情包失败: {e}")
return []
async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
"""根据情感标签获取表情包
Args:
emotion: 情感标签,如"happy""sad""angry"
Returns:
Optional[Tuple[str, str, str]]: (base64编码, 表情包描述, 匹配的情感标签) 或 None
Raises:
ValueError: 如果情感标签为空字符串
TypeError: 如果情感标签不是字符串类型
"""
if not emotion:
raise ValueError("情感标签不能为空")
if not isinstance(emotion, str):
raise TypeError("情感标签必须是字符串类型")
try:
logger.info(f"[EmojiAPI] 根据情感获取表情包: {emotion}")
all_emojis = emoji_manager.emojis
# 筛选匹配情感的表情包
matching_emojis = []
matching_emojis.extend(
emoji_obj
for emoji_obj in all_emojis
if not emoji_obj.is_deleted and emotion.lower() in [e.lower() for e in emoji_obj.emotion]
)
if not matching_emojis:
logger.warning(f"[EmojiAPI] 未找到匹配情感 '{emotion}' 的表情包")
return None
# 随机选择匹配的表情包
selected_emoji = random.choice(matching_emojis)
emoji_base64 = ImageUtils.image_path_to_base64(selected_emoji.full_path)
if not emoji_base64:
logger.error(f"[EmojiAPI] 无法转换表情包为base64: {selected_emoji.full_path}")
return None
# 记录使用次数
emoji_manager.update_emoji_usage(selected_emoji)
logger.info(f"[EmojiAPI] 成功获取情感表情包: {selected_emoji.description}")
return emoji_base64, selected_emoji.description, emotion
except Exception as e:
logger.error(f"[EmojiAPI] 根据情感获取表情包失败: {e}")
return None
# =============================================================================
# 表情包信息查询API函数
# =============================================================================
def get_count() -> int:
"""获取表情包数量
Returns:
int: 当前可用的表情包数量
"""
try:
return len(emoji_manager.emojis)
except Exception as e:
logger.error(f"[EmojiAPI] 获取表情包数量失败: {e}")
return 0
def get_info():
"""获取表情包系统信息
Returns:
dict: 包含表情包数量、最大数量、可用数量信息
"""
try:
return {
"current_count": len(emoji_manager.emojis),
"max_count": global_config.emoji.max_reg_num,
"available_emojis": len([e for e in emoji_manager.emojis if not e.is_deleted]),
}
except Exception as e:
logger.error(f"[EmojiAPI] 获取表情包信息失败: {e}")
return {"current_count": 0, "max_count": 0, "available_emojis": 0}
def get_emotions() -> List[str]:
"""获取所有可用的情感标签
Returns:
list: 所有表情包的情感标签列表(去重)
"""
try:
emotions = set()
for emoji_obj in emoji_manager.emojis:
if not emoji_obj.is_deleted and emoji_obj.emotion:
emotions.update(emoji_obj.emotion)
return sorted(list(emotions))
except Exception as e:
logger.error(f"[EmojiAPI] 获取情感标签失败: {e}")
return []
async def get_all() -> List[Tuple[str, str, str]]:
"""获取所有表情包
Returns:
List[Tuple[str, str, str]]: 包含(base64编码, 表情包描述, 随机情感标签)的元组列表
"""
try:
all_emojis = emoji_manager.emojis
if not all_emojis:
logger.warning("[EmojiAPI] 没有可用的表情包")
return []
results = []
for emoji_obj in all_emojis:
if emoji_obj.is_deleted:
continue
emoji_base64 = ImageUtils.image_path_to_base64(str(emoji_obj.full_path))
if not emoji_base64:
logger.error(f"[EmojiAPI] 无法转换表情包为base64: {emoji_obj.full_path}")
continue
matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else "随机表情"
results.append((emoji_base64, emoji_obj.description, matched_emotion))
logger.debug(f"[EmojiAPI] 成功获取 {len(results)} 个表情包")
return results
except Exception as e:
logger.error(f"[EmojiAPI] 获取所有表情包失败: {e}")
return []
def get_descriptions() -> List[str]:
"""获取所有表情包描述
Returns:
list: 所有可用表情包的描述列表
"""
try:
descriptions = []
descriptions.extend(
emoji_obj.description
for emoji_obj in emoji_manager.emojis
if not emoji_obj.is_deleted and emoji_obj.description
)
return descriptions
except Exception as e:
logger.error(f"[EmojiAPI] 获取表情包描述失败: {e}")
return []
# =============================================================================
# 表情包注册API函数
# =============================================================================
async def register_emoji(image_base64: str, filename: Optional[str] = None) -> Dict[str, Any]:
"""注册新的表情包
Args:
image_base64: 图片的base64编码
filename: 可选的文件名,如果未提供则自动生成
Returns:
Dict[str, Any]: 注册结果,包含以下字段:
- success: bool, 是否成功注册
- message: str, 结果消息
- description: Optional[str], 表情包描述(成功时)
- emotions: Optional[List[str]], 情感标签列表(成功时)
- replaced: Optional[bool], 是否替换了旧表情包(成功时)
- hash: Optional[str], 表情包哈希值(成功时)
Raises:
ValueError: 如果base64为空或无效
TypeError: 如果参数类型不正确
"""
if not image_base64:
raise ValueError("图片base64编码不能为空")
if not isinstance(image_base64, str):
raise TypeError("image_base64必须是字符串类型")
if filename is not None and not isinstance(filename, str):
raise TypeError("filename必须是字符串类型或None")
try:
logger.info(f"[EmojiAPI] 开始注册表情包,文件名: {filename or '自动生成'}")
# 1. 获取emoji管理器并检查容量
count_before = len(emoji_manager.emojis)
max_count = global_config.emoji.max_reg_num
# 2. 检查是否可以注册(未达到上限或启用替换)
can_register = count_before < max_count or (count_before >= max_count and global_config.emoji.do_replace)
if not can_register:
return {
"success": False,
"message": f"表情包数量已达上限({count_before}/{max_count})且未启用替换功能",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}
# 3. 确保emoji目录存在
os.makedirs(EMOJI_DIR, exist_ok=True)
# 4. 生成文件名
if not filename:
# 基于时间戳、微秒和短base64生成唯一文件名
import time
timestamp = int(time.time())
microseconds = int(time.time() * 1000000) % 1000000 # 添加微秒级精度
# 生成12位随机标识符使用base64编码增加随机性
import random
random_bytes = random.getrandbits(72).to_bytes(9, "big") # 72位 = 9字节 = 12位base64
short_id = base64.b64encode(random_bytes).decode("ascii")[:12].rstrip("=")
# 确保base64编码适合文件名替换/和-
short_id = short_id.replace("/", "_").replace("+", "-")
filename = f"emoji_{timestamp}_{microseconds}_{short_id}"
# 确保文件名有扩展名
if not filename.lower().endswith((".jpg", ".jpeg", ".png", ".gif")):
filename = f"{filename}.png" # 默认使用png格式
# 检查文件名是否已存在,如果存在则重新生成短标识符
temp_file_path = os.path.join(EMOJI_DIR, filename)
attempts = 0
max_attempts = 10
while os.path.exists(temp_file_path) and attempts < max_attempts:
# 重新生成短标识符
import random
random_bytes = random.getrandbits(48).to_bytes(6, "big")
short_id = base64.b64encode(random_bytes).decode("ascii")[:8].rstrip("=")
short_id = short_id.replace("/", "_").replace("+", "-")
# 分离文件名和扩展名,重新生成文件名
name_part, ext = os.path.splitext(filename)
# 去掉原来的标识符,添加新的
base_name = name_part.rsplit("_", 1)[0] # 移除最后一个_后的部分
filename = f"{base_name}_{short_id}{ext}"
temp_file_path = os.path.join(EMOJI_DIR, filename)
attempts += 1
# 如果还是冲突使用UUID作为备用方案
if os.path.exists(temp_file_path):
uuid_short = str(uuid.uuid4())[:8]
name_part, ext = os.path.splitext(filename)
base_name = name_part.rsplit("_", 1)[0]
filename = f"{base_name}_{uuid_short}{ext}"
temp_file_path = os.path.join(EMOJI_DIR, filename)
# 如果UUID方案也冲突添加序号
counter = 1
original_filename = filename
while os.path.exists(temp_file_path):
name_part, ext = os.path.splitext(original_filename)
filename = f"{name_part}_{counter}{ext}"
temp_file_path = os.path.join(EMOJI_DIR, filename)
counter += 1
# 防止无限循环最多尝试100次
if counter > 100:
logger.error(f"[EmojiAPI] 无法生成唯一文件名,尝试次数过多: {original_filename}")
return {
"success": False,
"message": "无法生成唯一文件名,请稍后重试",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}
# 5. 保存base64图片到emoji目录
try:
# 解码base64并保存图片
if not ImageUtils.base64_to_image(image_base64, temp_file_path):
logger.error(f"[EmojiAPI] 无法保存base64图片到文件: {temp_file_path}")
return {
"success": False,
"message": "无法保存图片文件",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}
logger.debug(f"[EmojiAPI] 图片已保存到临时文件: {temp_file_path}")
except Exception as save_error:
logger.error(f"[EmojiAPI] 保存图片文件失败: {save_error}")
return {
"success": False,
"message": f"保存图片文件失败: {str(save_error)}",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}
# 6. 调用注册方法
register_success = await emoji_manager.register_emoji_by_filename(filename)
# 7. 清理临时文件(如果注册失败但文件还存在)
if not register_success and os.path.exists(temp_file_path):
try:
os.remove(temp_file_path)
logger.debug(f"[EmojiAPI] 已清理临时文件: {temp_file_path}")
except Exception as cleanup_error:
logger.warning(f"[EmojiAPI] 清理临时文件失败: {cleanup_error}")
# 8. 构建返回结果
if register_success:
count_after = len(emoji_manager.emojis)
replaced = count_after <= count_before # 如果数量没增加,说明是替换
# 尝试获取新注册的表情包信息
new_emoji_info = None
if count_after > count_before or replaced:
# 获取最新的表情包信息
try:
# 通过文件名查找新注册的表情包(注意:文件名在注册后可能已经改变)
for emoji_obj in reversed(emoji_manager.emojis):
if not emoji_obj.is_deleted and (
emoji_obj.file_name == filename
or (hasattr(emoji_obj, "full_path") and filename in str(emoji_obj.full_path))
):
new_emoji_info = emoji_obj
break
except Exception as find_error:
logger.warning(f"[EmojiAPI] 查找新注册表情包信息失败: {find_error}")
description = new_emoji_info.description if new_emoji_info else None
emotions = new_emoji_info.emotion if new_emoji_info else None
emoji_hash = new_emoji_info.emoji_hash if new_emoji_info else None
return {
"success": True,
"message": f"表情包注册成功 {'(替换旧表情包)' if replaced else '(新增表情包)'}",
"description": description,
"emotions": emotions,
"replaced": replaced,
"hash": emoji_hash,
}
else:
return {
"success": False,
"message": "表情包注册失败,可能因为重复、格式不支持或审核未通过",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}
except Exception as e:
logger.error(f"[EmojiAPI] 注册表情包时发生异常: {e}")
return {
"success": False,
"message": f"注册过程中发生错误: {str(e)}",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}
# =============================================================================
# 表情包删除API函数
# =============================================================================
async def delete_emoji(emoji_hash: str) -> Dict[str, Any]:
"""删除表情包
Args:
emoji_hash: 要删除的表情包的哈希值
Returns:
Dict[str, Any]: 删除结果,包含以下字段:
- success: bool, 是否成功删除
- message: str, 结果消息
- count_before: Optional[int], 删除前的表情包数量
- count_after: Optional[int], 删除后的表情包数量
- description: Optional[str], 被删除的表情包描述(成功时)
- emotions: Optional[List[str]], 被删除的表情包情感标签(成功时)
Raises:
ValueError: 如果哈希值为空
TypeError: 如果哈希值不是字符串类型
"""
if not emoji_hash:
raise ValueError("表情包哈希值不能为空")
if not isinstance(emoji_hash, str):
raise TypeError("emoji_hash必须是字符串类型")
try:
logger.info(f"[EmojiAPI] 开始删除表情包,哈希值: {emoji_hash}")
# 1. 获取emoji管理器和删除前的数量
count_before = len(emoji_manager.emojis)
# 2. 获取被删除表情包的信息(用于返回结果)
deleted_emoji = None
try:
deleted_emoji = emoji_manager.get_emoji_by_hash(emoji_hash) or emoji_manager.get_emoji_by_hash_from_db(
emoji_hash
)
description = deleted_emoji.description if deleted_emoji else None
emotions = deleted_emoji.emotion if deleted_emoji else None
except Exception as info_error:
logger.warning(f"[EmojiAPI] 获取被删除表情包信息失败: {info_error}")
description = None
emotions = None
# 3. 执行删除操作
delete_success = False
if deleted_emoji:
delete_success = emoji_manager.delete_emoji(deleted_emoji)
# 4. 获取删除后的数量
count_after = len(emoji_manager.emojis)
# 5. 构建返回结果
if delete_success:
return {
"success": True,
"message": f"表情包删除成功 (哈希: {emoji_hash[:8]}...)",
"count_before": count_before,
"count_after": count_after,
"description": description,
"emotions": emotions,
}
else:
return {
"success": False,
"message": "表情包删除失败,可能因为哈希值不存在或删除过程出错",
"count_before": count_before,
"count_after": count_after,
"description": None,
"emotions": None,
}
except Exception as e:
logger.error(f"[EmojiAPI] 删除表情包时发生异常: {e}")
return {
"success": False,
"message": f"删除过程中发生错误: {str(e)}",
"count_before": None,
"count_after": None,
"description": None,
"emotions": None,
}
async def delete_emoji_by_description(description: str, exact_match: bool = False) -> Dict[str, Any]:
"""根据描述删除表情包
Args:
description: 表情包描述文本
exact_match: 是否精确匹配描述False则为模糊匹配
Returns:
Dict[str, Any]: 删除结果,包含以下字段:
- success: bool, 是否成功删除
- message: str, 结果消息
- deleted_count: int, 删除的表情包数量
- deleted_hashes: List[str], 被删除的表情包哈希列表
- matched_count: int, 匹配到的表情包数量
Raises:
ValueError: 如果描述为空
TypeError: 如果描述不是字符串类型
"""
if not description:
raise ValueError("描述不能为空")
if not isinstance(description, str):
raise TypeError("description必须是字符串类型")
try:
logger.info(f"[EmojiAPI] 根据描述删除表情包: {description} (精确匹配: {exact_match})")
all_emojis = emoji_manager.emojis
# 筛选匹配的表情包
matching_emojis = []
for emoji_obj in all_emojis:
if emoji_obj.is_deleted:
continue
if exact_match:
if emoji_obj.description == description:
matching_emojis.append(emoji_obj)
else:
if description.lower() in emoji_obj.description.lower():
matching_emojis.append(emoji_obj)
matched_count = len(matching_emojis)
if matched_count == 0:
return {
"success": False,
"message": f"未找到匹配描述 '{description}' 的表情包",
"deleted_count": 0,
"deleted_hashes": [],
"matched_count": 0,
}
# 删除匹配的表情包
deleted_count = 0
deleted_hashes = []
for emoji_obj in matching_emojis:
try:
delete_success = emoji_manager.delete_emoji(emoji_obj)
if delete_success:
deleted_count += 1
deleted_hashes.append(emoji_obj.emoji_hash)
except Exception as delete_error:
logger.error(f"[EmojiAPI] 删除表情包失败 (哈希: {emoji_obj.emoji_hash}): {delete_error}")
# 构建返回结果
if deleted_count > 0:
return {
"success": True,
"message": f"成功删除 {deleted_count} 个表情包 (匹配到 {matched_count} 个)",
"deleted_count": deleted_count,
"deleted_hashes": deleted_hashes,
"matched_count": matched_count,
}
else:
return {
"success": False,
"message": f"匹配到 {matched_count} 个表情包,但删除全部失败",
"deleted_count": 0,
"deleted_hashes": [],
"matched_count": matched_count,
}
except Exception as e:
logger.error(f"[EmojiAPI] 根据描述删除表情包时发生异常: {e}")
return {
"success": False,
"message": f"删除过程中发生错误: {str(e)}",
"deleted_count": 0,
"deleted_hashes": [],
"matched_count": 0,
}

View File

@@ -1,3 +0,0 @@
from src.common.logger import get_logger
__all__ = ["get_logger"]

View File

@@ -1,120 +0,0 @@
from typing import Tuple, List
def list_loaded_plugins() -> List[str]:
"""
列出所有当前加载的插件。
Returns:
List[str]: 当前加载的插件名称列表。
"""
from src.plugin_system.core.plugin_manager import plugin_manager
return plugin_manager.list_loaded_plugins()
def list_registered_plugins() -> List[str]:
"""
列出所有已注册的插件。
Returns:
List[str]: 已注册的插件名称列表。
"""
from src.plugin_system.core.plugin_manager import plugin_manager
return plugin_manager.list_registered_plugins()
def get_plugin_path(plugin_name: str) -> str:
"""
获取指定插件的路径。
Args:
plugin_name (str): 插件名称。
Returns:
str: 插件目录的绝对路径。
Raises:
ValueError: 如果插件不存在。
"""
from src.plugin_system.core.plugin_manager import plugin_manager
if plugin_path := plugin_manager.get_plugin_path(plugin_name):
return plugin_path
else:
raise ValueError(f"插件 '{plugin_name}' 不存在。")
async def remove_plugin(plugin_name: str) -> bool:
"""
卸载指定的插件。
**此函数是异步的,确保在异步环境中调用。**
Args:
plugin_name (str): 要卸载的插件名称。
Returns:
bool: 卸载是否成功。
"""
from src.plugin_system.core.plugin_manager import plugin_manager
return await plugin_manager.remove_registered_plugin(plugin_name)
async def reload_plugin(plugin_name: str) -> bool:
"""
重新加载指定的插件。
**此函数是异步的,确保在异步环境中调用。**
Args:
plugin_name (str): 要重新加载的插件名称。
Returns:
bool: 重新加载是否成功。
"""
from src.plugin_system.core.plugin_manager import plugin_manager
return await plugin_manager.reload_registered_plugin(plugin_name)
def load_plugin(plugin_name: str) -> Tuple[bool, int]:
"""
加载指定的插件。
Args:
plugin_name (str): 要加载的插件名称。
Returns:
Tuple[bool, int]: 加载是否成功,成功或失败个数。
"""
from src.plugin_system.core.plugin_manager import plugin_manager
return plugin_manager.load_registered_plugin_classes(plugin_name)
def add_plugin_directory(plugin_directory: str) -> bool:
"""
添加插件目录。
Args:
plugin_directory (str): 要添加的插件目录路径。
Returns:
bool: 添加是否成功。
"""
from src.plugin_system.core.plugin_manager import plugin_manager
return plugin_manager.add_plugin_directory(plugin_directory)
def rescan_plugin_directory() -> Tuple[int, int]:
"""
重新扫描插件目录,加载新插件。
Returns:
Tuple[int, int]: 成功加载的插件数量和失败的插件数量。
"""
from src.plugin_system.core.plugin_manager import plugin_manager
return plugin_manager.rescan_plugin_directory()

View File

@@ -1,46 +0,0 @@
from pathlib import Path
from src.common.logger import get_logger
logger = get_logger("plugin_manager") # 复用plugin_manager名称
def register_plugin(cls):
from src.plugin_system.core.plugin_manager import plugin_manager
from src.plugin_system.base.base_plugin import BasePlugin
"""插件注册装饰器
用法:
@register_plugin
class MyPlugin(BasePlugin):
plugin_name = "my_plugin"
plugin_description = "我的插件"
...
"""
if not issubclass(cls, BasePlugin):
logger.error(f"{cls.__name__} 不是 BasePlugin 的子类")
return cls
# 只是注册插件类,不立即实例化
# 插件管理器会负责实例化和注册
plugin_name: str = cls.plugin_name # type: ignore
if "." in plugin_name:
logger.error(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
raise ValueError(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
splitted_name = cls.__module__.split(".")
root_path = Path(__file__)
# 查找项目根目录
while not (root_path / "pyproject.toml").exists() and root_path.parent != root_path:
root_path = root_path.parent
if not (root_path / "pyproject.toml").exists():
logger.error(f"注册 {plugin_name} 无法找到项目根目录")
return cls
plugin_manager.plugin_classes[plugin_name] = cls
plugin_manager.plugin_paths[plugin_name] = str(Path(root_path, *splitted_name).resolve())
logger.debug(f"插件类已注册: {plugin_name}, 路径: {plugin_manager.plugin_paths[plugin_name]}")
return cls

View File

@@ -1,56 +0,0 @@
from typing import Any, Callable, Dict, Optional
from src.plugin_system.base.service_types import PluginServiceInfo
from src.plugin_system.core.plugin_service_registry import plugin_service_registry
def register_service(service_info: PluginServiceInfo, service_handler: Callable[..., Any]) -> bool:
"""注册插件服务。"""
return plugin_service_registry.register_service(service_info, service_handler)
def get_service(service_name: str, plugin_name: Optional[str] = None) -> Optional[PluginServiceInfo]:
"""获取插件服务元信息。"""
return plugin_service_registry.get_service(service_name, plugin_name)
def get_service_handler(service_name: str, plugin_name: Optional[str] = None) -> Optional[Callable[..., Any]]:
"""获取插件服务处理函数。"""
return plugin_service_registry.get_service_handler(service_name, plugin_name)
def list_services(plugin_name: Optional[str] = None, enabled_only: bool = False) -> Dict[str, PluginServiceInfo]:
"""列出插件服务。"""
return plugin_service_registry.list_services(plugin_name=plugin_name, enabled_only=enabled_only)
def enable_service(service_name: str, plugin_name: Optional[str] = None) -> bool:
"""启用插件服务。"""
return plugin_service_registry.enable_service(service_name, plugin_name)
def disable_service(service_name: str, plugin_name: Optional[str] = None) -> bool:
"""禁用插件服务。"""
return plugin_service_registry.disable_service(service_name, plugin_name)
def unregister_service(service_name: str, plugin_name: Optional[str] = None) -> bool:
"""注销插件服务。"""
return plugin_service_registry.unregister_service(service_name, plugin_name)
async def call_service(
service_name: str,
*args: Any,
plugin_name: Optional[str] = None,
caller_plugin: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""调用插件服务。"""
return await plugin_service_registry.call_service(
service_name,
*args,
plugin_name=plugin_name,
caller_plugin=caller_plugin,
**kwargs,
)

View File

@@ -1,45 +0,0 @@
from typing import Optional, Type, TYPE_CHECKING
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.base.component_types import ComponentType
from src.common.logger import get_logger
if TYPE_CHECKING:
from src.chat.message_receive.chat_manager import BotChatSession
logger = get_logger("tool_api")
def get_tool_instance(tool_name: str, chat_stream: Optional["BotChatSession"] = None) -> Optional[BaseTool]:
"""获取公开工具实例
Args:
tool_name: 工具名称
chat_stream: 聊天流对象,用于传递聊天上下文信息
Returns:
Optional[BaseTool]: 工具实例如果未找到则返回None
"""
from src.plugin_system.core import component_registry
# 获取插件配置
tool_info = component_registry.get_component_info(tool_name, ComponentType.TOOL)
if tool_info:
plugin_config = component_registry.get_plugin_config(tool_info.plugin_name)
else:
plugin_config = None
tool_class: Type[BaseTool] = component_registry.get_component_class(tool_name, ComponentType.TOOL) # type: ignore
return tool_class(plugin_config, chat_stream) if tool_class else None
def get_llm_available_tool_definitions():
"""获取LLM可用的工具定义列表
Returns:
List[Tuple[str, Dict[str, Any]]]: 工具定义列表,为[("tool_name", 定义)]
"""
from src.plugin_system.core import component_registry
llm_available_tools = component_registry.get_llm_available_tools()
return [(name, tool_class.get_tool_definition()) for name, tool_class in llm_available_tools.items()]

View File

@@ -1,77 +0,0 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from src.plugin_system.base.component_types import EventType, MaiMessages
from src.plugin_system.base.workflow_types import WorkflowContext, WorkflowStage, WorkflowStepInfo, WorkflowStepResult
from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.core.events_manager import events_manager
from src.plugin_system.core.workflow_engine import workflow_engine
def register_workflow_step(step_info: WorkflowStepInfo, step_handler: Callable[..., Any]) -> bool:
"""注册workflow step。"""
return component_registry.register_workflow_step(step_info, step_handler)
def get_steps_by_stage(stage: WorkflowStage, enabled_only: bool = False) -> Dict[str, WorkflowStepInfo]:
"""获取指定阶段的workflow steps。"""
return component_registry.get_steps_by_stage(stage, enabled_only=enabled_only)
def get_workflow_step(step_name: str, stage: Optional[WorkflowStage] = None) -> Optional[WorkflowStepInfo]:
"""获取workflow step元信息。"""
return component_registry.get_workflow_step(step_name, stage)
def get_workflow_step_handler(step_name: str, stage: Optional[WorkflowStage] = None) -> Optional[Callable[..., Any]]:
"""获取workflow step处理函数。"""
return component_registry.get_workflow_step_handler(step_name, stage)
def enable_workflow_step(step_name: str, stage: Optional[WorkflowStage] = None) -> bool:
"""启用workflow step。"""
return component_registry.enable_workflow_step(step_name, stage)
def disable_workflow_step(step_name: str, stage: Optional[WorkflowStage] = None) -> bool:
"""禁用workflow step。"""
return component_registry.disable_workflow_step(step_name, stage)
def get_execution_trace(trace_id: str) -> Optional[Dict[str, Any]]:
"""按trace_id获取workflow执行路径。"""
return workflow_engine.get_execution_trace(trace_id)
def clear_execution_trace(trace_id: str) -> bool:
"""清理trace执行路径记录。"""
return workflow_engine.clear_execution_trace(trace_id)
async def execute_workflow_message(
message: Optional[MaiMessages] = None,
stream_id: Optional[str] = None,
action_usage: Optional[List[str]] = None,
context: Optional[WorkflowContext] = None,
) -> Tuple[WorkflowStepResult, Optional[MaiMessages], WorkflowContext]:
"""执行workflow消息流转。"""
return await events_manager.handle_workflow_message(
message=message,
stream_id=stream_id,
action_usage=action_usage,
context=context,
)
async def publish_event(
event_type: Union[EventType, str],
message: Optional[MaiMessages] = None,
stream_id: Optional[str] = None,
action_usage: Optional[List[str]] = None,
) -> Tuple[bool, Optional[MaiMessages]]:
"""发布事件(支持系统事件和自定义字符串事件)。"""
return await events_manager.handle_mai_events(
event_type=event_type,
message=message,
stream_id=stream_id,
action_usage=action_usage,
)

View File

@@ -1,72 +0,0 @@
"""
插件基础类模块
提供插件开发的基础类和类型定义
"""
from .base_plugin import BasePlugin
from .base_action import BaseAction
from .base_tool import BaseTool
from .base_command import BaseCommand
from .base_events_handler import BaseEventHandler
from .service_types import PluginServiceInfo
from .workflow_types import WorkflowContext, WorkflowMessage, WorkflowStage, WorkflowStepInfo, WorkflowStepResult
from .workflow_errors import WorkflowErrorCode
from .component_types import (
ComponentType,
ActionActivationType,
ChatMode,
ComponentInfo,
ActionInfo,
CommandInfo,
ToolInfo,
PluginInfo,
PythonDependency,
EventHandlerInfo,
EventType,
MaiMessages,
ToolParamType,
CustomEventHandlerResult,
ReplyContentType,
ReplyContent,
ForwardNode,
ReplySetModel,
)
from .config_types import ConfigField, ConfigSection, ConfigLayout, ConfigTab
__all__ = [
"BasePlugin",
"BaseAction",
"BaseCommand",
"BaseTool",
"ComponentType",
"ActionActivationType",
"ChatMode",
"ComponentInfo",
"ActionInfo",
"CommandInfo",
"ToolInfo",
"PluginInfo",
"PythonDependency",
"ConfigField",
"ConfigSection",
"ConfigLayout",
"ConfigTab",
"EventHandlerInfo",
"EventType",
"BaseEventHandler",
"MaiMessages",
"ToolParamType",
"CustomEventHandlerResult",
"ReplyContentType",
"ReplyContent",
"ForwardNode",
"ReplySetModel",
"PluginServiceInfo",
"WorkflowContext",
"WorkflowMessage",
"WorkflowStage",
"WorkflowStepInfo",
"WorkflowStepResult",
"WorkflowErrorCode",
]

View File

@@ -1,533 +0,0 @@
import time
import asyncio
from abc import ABC, abstractmethod
from typing import Tuple, Optional, TYPE_CHECKING, Dict, List
from src.common.logger import get_logger
from src.common.data_models.message_data_model import ReplyContentType, ReplyContent, ReplySetModel, ForwardNode
from src.chat.message_receive.chat_manager import BotChatSession
from src.plugin_system.base.component_types import ActionActivationType, ActionInfo, ComponentType
from src.plugin_system.apis import send_api, database_api, message_api
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
logger = get_logger("base_action")
class BaseAction(ABC):
"""Action组件基类
Action是插件的一种组件类型用于处理聊天中的动作逻辑
子类可以通过类属性定义激活条件,这些会在实例化时转换为实例属性:
- focus_activation_type: 专注模式激活类型
- normal_activation_type: 普通模式激活类型
- activation_keywords: 激活关键词列表
- keyword_case_sensitive: 关键词是否区分大小写
- parallel_action: 是否允许并行执行
- random_activation_probability: 随机激活概率
"""
def __init__(
self,
action_data: dict,
action_reasoning: str,
cycle_timers: dict,
thinking_id: str,
chat_stream: BotChatSession,
plugin_config: Optional[dict] = None,
action_message: Optional["DatabaseMessages"] = None,
**kwargs,
):
# sourcery skip: hoist-similar-statement-from-if, merge-else-if-into-elif, move-assign-in-block, swap-if-else-branches, swap-nested-ifs
"""初始化Action组件
Args:
action_data: 动作数据
reasoning: 执行该动作的理由
cycle_timers: 计时器字典
thinking_id: 思考ID
chat_stream: 聊天流对象
log_prefix: 日志前缀
plugin_config: 插件配置字典
action_message: 消息数据
**kwargs: 其他参数
"""
if plugin_config is None:
plugin_config = {}
self.action_data = action_data
self.reasoning = ""
self.cycle_timers = cycle_timers
self.thinking_id = thinking_id
self.action_reasoning = action_reasoning
self.plugin_config = plugin_config or {}
"""对应的插件配置"""
# 设置动作基本信息实例属性
self.action_name: str = getattr(self, "action_name", self.__class__.__name__.lower().replace("action", ""))
"""Action的名字"""
self.action_description: str = getattr(self, "action_description", self.__doc__ or "Action组件")
"""Action的描述"""
self.action_parameters: dict = getattr(self.__class__, "action_parameters", {}).copy()
self.action_require: list[str] = getattr(self.__class__, "action_require", []).copy()
"""NORMAL模式下的激活类型"""
self.activation_type = self.__class__.activation_type
"""激活类型"""
self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0)
"""当激活类型为RANDOM时的概率"""
self.activation_keywords: list[str] = getattr(self.__class__, "activation_keywords", []).copy()
"""激活类型为KEYWORD时的KEYWORDS列表"""
self.keyword_case_sensitive: bool = getattr(self.__class__, "keyword_case_sensitive", False)
self.parallel_action: bool = getattr(self.__class__, "parallel_action", True)
self.associated_types: list[str] = getattr(self.__class__, "associated_types", []).copy()
# =============================================================================
# 便捷属性 - 直接在初始化时获取常用聊天信息(带类型注解)
# =============================================================================
# 获取聊天流对象
self.chat_stream = chat_stream or kwargs.get("chat_stream")
self.chat_id = self.chat_stream.session_id
self.platform = getattr(self.chat_stream, "platform", None)
# 初始化基础信息(带类型注解)
self.action_message = action_message
self.group_id = None
self.group_name = None
self.user_id = None
self.user_nickname = None
self.is_group = False
self.target_id = None
self.group_id = (
str(self.action_message.chat_info.group_info.group_id) if self.action_message.chat_info.group_info else None
)
self.group_name = (
self.action_message.chat_info.group_info.group_name if self.action_message.chat_info.group_info else None
)
self.user_id = str(self.action_message.user_info.user_id)
self.user_nickname = self.action_message.user_info.user_nickname
if self.group_id:
self.is_group = True
self.target_id = self.group_id
self.log_prefix = f"[{self.group_name}]"
else:
self.is_group = False
self.target_id = self.user_id
self.log_prefix = f"[{self.user_nickname} 的 私聊]"
logger.debug(
f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}"
)
@abstractmethod
async def execute(self) -> Tuple[bool, str]:
"""执行Action的抽象方法子类必须实现
Returns:
Tuple[bool, str]: (是否执行成功, 回复文本)
"""
pass
async def send_text(
self,
content: str,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
typing: bool = False,
storage_message: bool = True,
) -> bool:
"""发送文本消息
Args:
content: 文本内容
set_reply: 是否作为回复发送
reply_message: 回复的消息对象当set_reply为True时必填
typing: 是否计算输入时间
Returns:
bool: 是否发送成功
"""
if not self.chat_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
return await send_api.text_to_stream(
text=content,
stream_id=self.chat_id,
set_reply=set_reply,
reply_message=reply_message,
typing=typing,
storage_message=storage_message,
)
async def send_emoji(
self,
emoji_base64: str,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""发送表情包
Args:
emoji_base64: 表情包的base64编码
set_reply: 是否作为回复发送
reply_message: 回复的消息对象当set_reply为True时必填
Returns:
bool: 是否发送成功
"""
if not self.chat_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
return await send_api.emoji_to_stream(
emoji_base64,
self.chat_id,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
)
async def send_image(
self,
image_base64: str,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""发送图片
Args:
image_base64: 图片的base64编码
set_reply: 是否作为回复发送
reply_message: 回复的消息对象当set_reply为True时必填
Returns:
bool: 是否发送成功
"""
if not self.chat_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
return await send_api.image_to_stream(
image_base64,
self.chat_id,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
)
async def send_command(
self,
command_name: str,
args: Optional[dict] = None,
display_message: str = "",
storage_message: bool = True,
) -> bool:
"""发送命令消息
Args:
command_name: 命令名称
args: 命令参数
display_message: 显示消息
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
if not self.chat_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
# 构造命令数据
command_data = {"name": command_name, "args": args or {}}
return await send_api.command_to_stream(
command=command_data,
stream_id=self.chat_id,
storage_message=storage_message,
display_message=display_message,
)
async def send_custom(
self,
message_type: str,
content: str | Dict,
typing: bool = False,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""发送自定义类型消息
Args:
message_type: 消息类型,如"video""file""audio"
content: 消息内容
typing: 是否显示正在输入
set_reply: 是否作为回复发送
reply_message: 回复的消息对象set_reply 为 True时必填
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
if not self.chat_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
return await send_api.custom_to_stream(
message_type=message_type,
content=content,
stream_id=self.chat_id,
typing=typing,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
)
async def send_hybrid(
self,
message_tuple_list: List[Tuple[ReplyContentType | str, str]],
typing: bool = False,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""
发送混合类型消息
Args:
message_tuple_list: 包含消息类型和内容的元组列表,格式为 [(内容类型, 内容), ...]
typing: 是否计算打字时间
set_reply: 是否作为回复发送
reply_message: 回复的消息对象
"""
if not self.chat_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
reply_set = ReplySetModel()
reply_set.add_hybrid_content_by_raw(message_tuple_list)
return await send_api.custom_reply_set_to_stream(
reply_set=reply_set,
stream_id=self.chat_id,
typing=typing,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
)
async def send_forward(
self,
messages_list: List[Tuple[str, str, List[Tuple[ReplyContentType | str, str]]] | str],
storage_message: bool = True,
) -> bool:
"""转发消息
Args:
messages_list: 包含消息信息的列表,当传入自行生成的数据时,元素格式为 (sender_id, nickname, 消息体)当传入消息ID时元素格式为 "message_id"
其中消息体的格式为 [(内容类型, 内容), ...]
任意长度的消息都需要使用列表的形式传入
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
if not self.chat_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
reply_set = ReplySetModel()
forward_message_nodes: List[ForwardNode] = []
for message in messages_list:
if isinstance(message, str):
forward_message_node = ForwardNode.construct_as_id_reference(message)
elif isinstance(message, Tuple) and len(message) == 3:
sender_id, nickname, content_list = message
single_node_content_list: List[ReplyContent] = []
for node_content_type, node_content in content_list:
reply_node_content = ReplyContent(content_type=node_content_type, content=node_content)
single_node_content_list.append(reply_node_content)
forward_message_node = ForwardNode.construct_as_created_node(
user_id=sender_id, user_nickname=nickname, content=single_node_content_list
)
else:
logger.warning(f"{self.log_prefix} 转发消息时遇到无效的消息格式: {message}")
continue
forward_message_nodes.append(forward_message_node)
reply_set.add_forward_content(forward_message_nodes)
return await send_api.custom_reply_set_to_stream(
reply_set=reply_set,
stream_id=self.chat_id,
storage_message=storage_message,
set_reply=False,
reply_message=None,
)
async def send_voice(self, audio_base64: str) -> bool:
"""
发送语音消息
Args:
audio_base64: 语音的base64编码
Returns:
bool: 是否发送成功
"""
if not audio_base64:
logger.error(f"{self.log_prefix} 缺少音频内容")
return False
reply_set = ReplySetModel()
reply_set.add_voice_content(audio_base64)
return await send_api.custom_reply_set_to_stream(
reply_set=reply_set,
stream_id=self.chat_id,
storage_message=False,
)
async def store_action_info(
self,
action_build_into_prompt: bool = False,
action_prompt_display: str = "",
action_done: bool = True,
) -> None:
"""存储动作信息到数据库
Args:
action_build_into_prompt: 是否构建到提示中
action_prompt_display: 显示的action提示信息
action_done: action是否完成
"""
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=action_build_into_prompt,
action_prompt_display=action_prompt_display,
action_done=action_done,
thinking_id=self.thinking_id,
action_data=self.action_data,
action_name=self.action_name,
action_reasoning=self.action_reasoning,
)
async def wait_for_new_message(self, timeout: int = 1200) -> Tuple[bool, str]:
"""等待新消息或超时
在loop_start_time之后等待新消息如果没有新消息且没有超时就一直等待。
使用message_api检查self.chat_id对应的聊天中是否有新消息。
Args:
timeout: 超时时间默认1200秒
Returns:
Tuple[bool, str]: (是否收到新消息, 空字符串)
"""
try:
# 获取循环开始时间,如果没有则使用当前时间
loop_start_time = self.action_data.get("loop_start_time", time.time())
logger.info(f"{self.log_prefix} 开始等待新消息... (最长等待: {timeout}秒, 从时间点: {loop_start_time})")
# 确保有有效的chat_id
if not self.chat_id:
logger.error(f"{self.log_prefix} 等待新消息失败: 没有有效的chat_id")
return False, "没有有效的chat_id"
wait_start_time = asyncio.get_event_loop().time()
while True:
# 检查新消息
current_time = time.time()
new_message_count = message_api.count_new_messages(
chat_id=self.chat_id, start_time=loop_start_time, end_time=current_time
)
if new_message_count > 0:
logger.info(f"{self.log_prefix} 检测到{new_message_count}条新消息聊天ID: {self.chat_id}")
return True, ""
# 检查超时
elapsed_time = asyncio.get_event_loop().time() - wait_start_time
if elapsed_time > timeout:
logger.warning(f"{self.log_prefix} 等待新消息超时({timeout}秒)聊天ID: {self.chat_id}")
return False, ""
# 每30秒记录一次等待状态
if int(elapsed_time) % 15 == 0 and int(elapsed_time) > 0:
logger.debug(f"{self.log_prefix} 已等待{int(elapsed_time)}秒,继续等待新消息...")
# 短暂休眠
await asyncio.sleep(0.5)
except asyncio.CancelledError:
logger.info(f"{self.log_prefix} 等待新消息被中断 (CancelledError)")
return False, ""
except Exception as e:
logger.error(f"{self.log_prefix} 等待新消息时发生错误: {e}")
return False, f"等待新消息失败: {str(e)}"
@classmethod
def get_action_info(cls) -> "ActionInfo":
"""从类属性生成ActionInfo
所有信息都从类属性中读取,确保一致性和完整性。
Action类必须定义所有必要的类属性。
Returns:
ActionInfo: 生成的Action信息对象
"""
# 从类属性读取名称,如果没有定义则使用类名自动生成
name = getattr(cls, "action_name", cls.__name__.lower().replace("action", ""))
if "." in name:
logger.error(f"Action名称 '{name}' 包含非法字符 '.',请使用下划线替代")
raise ValueError(f"Action名称 '{name}' 包含非法字符 '.',请使用下划线替代")
# 获取focus_activation_type和normal_activation_type
focus_activation_type = getattr(cls, "focus_activation_type", ActionActivationType.ALWAYS)
_normal_activation_type = getattr(cls, "normal_activation_type", ActionActivationType.ALWAYS)
# 处理activation_type如果插件中声明了就用插件的值否则默认使用focus_activation_type
activation_type = getattr(cls, "activation_type", focus_activation_type)
return ActionInfo(
name=name,
component_type=ComponentType.ACTION,
description=getattr(cls, "action_description", "Action动作"),
activation_type=activation_type,
activation_keywords=getattr(cls, "activation_keywords", []).copy(),
keyword_case_sensitive=getattr(cls, "keyword_case_sensitive", False),
parallel_action=getattr(cls, "parallel_action", True),
random_activation_probability=getattr(cls, "random_activation_probability", 0.0),
# 使用正确的字段名
action_parameters=getattr(cls, "action_parameters", {}).copy(),
action_require=getattr(cls, "action_require", []).copy(),
associated_types=getattr(cls, "associated_types", []).copy(),
)
def get_config(self, key: str, default=None):
"""获取插件配置值,使用嵌套键访问
Args:
key: 配置键名,使用嵌套访问如 "section.subsection.key"
default: 默认值
Returns:
Any: 配置值或默认值
"""
if not self.plugin_config:
return default
# 支持嵌套键访问
keys = key.split(".")
current = self.plugin_config
for k in keys:
if isinstance(current, dict) and k in current:
current = current[k]
else:
return default
return current

View File

@@ -1,387 +0,0 @@
from abc import ABC, abstractmethod
from typing import Dict, Tuple, Optional, TYPE_CHECKING, List
from src.common.logger import get_logger
from src.common.data_models.message_data_model import ReplyContentType, ReplyContent, ReplySetModel, ForwardNode
from src.plugin_system.base.component_types import CommandInfo, ComponentType
from src.chat.message_receive.message import SessionMessage
from src.plugin_system.apis import send_api
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
logger = get_logger("base_command")
class BaseCommand(ABC):
"""Command组件基类
Command是插件的一种组件类型用于处理命令请求
子类可以通过类属性定义命令模式:
- command_pattern: 命令匹配的正则表达式
- command_help: 命令帮助信息
- command_examples: 命令使用示例列表
"""
command_name: str = ""
"""Command组件的名称"""
command_description: str = ""
"""Command组件的描述"""
# 默认命令设置
command_pattern: str = r""
"""命令匹配的正则表达式"""
def __init__(self, message: SessionMessage, plugin_config: Optional[dict] = None):
"""初始化Command组件
Args:
message: 接收到的消息对象
plugin_config: 插件配置字典
"""
self.message = message
self.matched_groups: Dict[str, str] = {} # 存储正则表达式匹配的命名组
self.plugin_config = plugin_config or {} # 直接存储插件配置字典
self.log_prefix = "[Command]"
logger.debug(f"{self.log_prefix} Command组件初始化完成")
def set_matched_groups(self, groups: Dict[str, str]) -> None:
"""设置正则表达式匹配的命名组
Args:
groups: 正则表达式匹配的命名组
"""
self.matched_groups = groups
@abstractmethod
async def execute(self) -> Tuple[bool, Optional[str], int]:
"""执行Command的抽象方法子类必须实现
Returns:
Tuple[bool, Optional[str], int]: (是否执行成功, 可选的回复消息, 拦截消息力度0代表不拦截1代表仅不触发回复replyer可见2代表不触发回复replyer不可见)
"""
pass
def get_config(self, key: str, default=None):
"""获取插件配置值,使用嵌套键访问
Args:
key: 配置键名,使用嵌套访问如 "section.subsection.key"
default: 默认值
Returns:
Any: 配置值或默认值
"""
if not self.plugin_config:
return default
# 支持嵌套键访问
keys = key.split(".")
current = self.plugin_config
for k in keys:
if isinstance(current, dict) and k in current:
current = current[k]
else:
return default
return current
async def send_text(
self,
content: str,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""发送回复消息
Args:
content: 回复内容
set_reply: 是否作为回复发送
reply_message: 回复的消息对象当set_reply为True时必填
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
# 获取聊天流信息
session_id = self.message.session_id
if not session_id:
logger.error(f"{self.log_prefix} 缺少session_id")
return False
return await send_api.text_to_stream(
text=content,
stream_id=session_id,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
)
async def send_image(
self,
image_base64: str,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""发送图片
Args:
image_base64: 图片的base64编码
Returns:
bool: 是否发送成功
"""
session_id = self.message.session_id
if not session_id:
logger.error(f"{self.log_prefix} 缺少session_id")
return False
return await send_api.image_to_stream(
image_base64,
session_id,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
)
async def send_emoji(
self,
emoji_base64: str,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""发送表情包
Args:
emoji_base64: 表情包的base64编码
set_reply: 是否作为回复发送
reply_message: 回复的消息对象当set_reply为True时必填
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
session_id = self.message.session_id
if not session_id:
logger.error(f"{self.log_prefix} 缺少session_id")
return False
return await send_api.emoji_to_stream(
emoji_base64, session_id, set_reply=set_reply, reply_message=reply_message
)
async def send_command(
self,
command_name: str,
args: Optional[dict] = None,
display_message: str = "",
storage_message: bool = True,
) -> bool:
"""发送命令消息
Args:
command_name: 命令名称
args: 命令参数
display_message: 显示消息
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
try:
# 获取聊天流信息
session_id = self.message.session_id
if not session_id:
logger.error(f"{self.log_prefix} 缺少session_id")
return False
# 构造命令数据
command_data = {"name": command_name, "args": args or {}}
success = await send_api.command_to_stream(
command=command_data,
stream_id=session_id,
storage_message=storage_message,
display_message=display_message,
)
if success:
logger.info(f"{self.log_prefix} 成功发送命令: {command_name}")
else:
logger.error(f"{self.log_prefix} 发送命令失败: {command_name}")
return success
except Exception as e:
logger.error(f"{self.log_prefix} 发送命令时出错: {e}")
return False
async def send_voice(self, voice_base64: str) -> bool:
"""
发送语音消息
Args:
voice_base64: 语音的base64编码
Returns:
bool: 是否发送成功
"""
session_id = self.message.session_id
if not session_id:
logger.error(f"{self.log_prefix} 缺少session_id")
return False
return await send_api.custom_to_stream(
message_type="voice",
content=voice_base64,
stream_id=session_id,
typing=False,
set_reply=False,
reply_message=None,
storage_message=False,
)
async def send_hybrid(
self,
message_tuple_list: List[Tuple[ReplyContentType | str, str]],
typing: bool = False,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""
发送混合类型消息
Args:
message_tuple_list: 包含消息类型和内容的元组列表,格式为 [(内容类型, 内容), ...]
typing: 是否显示正在输入
set_reply: 是否计算打字时间
reply_message: 回复的消息对象
storage_message: 是否存储消息到数据库
"""
session_id = self.message.session_id
if not session_id:
logger.error(f"{self.log_prefix} 缺少session_id")
return False
reply_set = ReplySetModel()
reply_set.add_hybrid_content_by_raw(message_tuple_list)
return await send_api.custom_reply_set_to_stream(
reply_set=reply_set,
stream_id=session_id,
typing=typing,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
)
async def send_forward(
self,
messages_list: List[Tuple[str, str, List[Tuple[ReplyContentType | str, str]]] | str],
storage_message: bool = True,
) -> bool:
"""转发消息
Args:
messages_list: 包含消息信息的列表,当传入自行生成的数据时,元素格式为 (sender_id, nickname, 消息体)当传入消息ID时元素格式为 "message_id"
其中消息体的格式为 [(内容类型, 内容), ...]
任意长度的消息都需要使用列表的形式传入
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
session_id = self.message.session_id
if not session_id:
logger.error(f"{self.log_prefix} 缺少session_id")
return False
reply_set = ReplySetModel()
forward_message_nodes: List[ForwardNode] = []
for message in messages_list:
if isinstance(message, str):
forward_message_node = ForwardNode.construct_as_id_reference(message)
elif isinstance(message, Tuple) and len(message) == 3:
sender_id, nickname, content_list = message
single_node_content_list: List[ReplyContent] = []
for node_content_type, node_content in content_list:
reply_node_content = ReplyContent(content_type=node_content_type, content=node_content)
single_node_content_list.append(reply_node_content)
forward_message_node = ForwardNode.construct_as_created_node(
user_id=sender_id, user_nickname=nickname, content=single_node_content_list
)
else:
logger.warning(f"{self.log_prefix} 转发消息时遇到无效的消息格式: {message}")
continue
forward_message_nodes.append(forward_message_node)
reply_set.add_forward_content(forward_message_nodes)
return await send_api.custom_reply_set_to_stream(
reply_set=reply_set,
stream_id=session_id,
storage_message=storage_message,
set_reply=False,
reply_message=None,
)
async def send_custom(
self,
message_type: str,
content: str | Dict,
display_message: str = "",
typing: bool = False,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""发送指定类型的回复消息到当前聊天环境
Args:
message_type: 消息类型,如"text""image""emoji""voice"
content: 消息内容
display_message: 显示消息(可选)
typing: 是否显示正在输入
set_reply: 是否作为回复发送
reply_message: 回复的消息对象set_reply 为 True时必填
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
# 获取聊天流信息
session_id = self.message.session_id
if not session_id:
logger.error(f"{self.log_prefix} 缺少session_id")
return False
return await send_api.custom_to_stream(
message_type=message_type,
content=content,
stream_id=session_id,
display_message=display_message,
typing=typing,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
)
@classmethod
def get_command_info(cls) -> "CommandInfo":
"""从类属性生成CommandInfo
Args:
name: Command名称如果不提供则使用类名
description: Command描述如果不提供则使用类文档字符串
Returns:
CommandInfo: 生成的Command信息对象
"""
if "." in cls.command_name:
logger.error(f"Command名称 '{cls.command_name}' 包含非法字符 '.',请使用下划线替代")
raise ValueError(f"Command名称 '{cls.command_name}' 包含非法字符 '.',请使用下划线替代")
return CommandInfo(
name=cls.command_name,
component_type=ComponentType.COMMAND,
description=cls.command_description,
command_pattern=cls.command_pattern,
)

View File

@@ -1,381 +0,0 @@
from abc import ABC, abstractmethod
from typing import Tuple, Optional, Dict, List, TYPE_CHECKING
from src.common.logger import get_logger
from src.common.data_models.message_data_model import ReplyContentType, ReplySetModel, ReplyContent, ForwardNode
from src.plugin_system.apis import send_api
from .component_types import MaiMessages, EventType, EventHandlerInfo, ComponentType, CustomEventHandlerResult
logger = get_logger("base_event_handler")
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
class BaseEventHandler(ABC):
"""事件处理器基类
所有事件处理器都应该继承这个基类,提供事件处理的基本接口
"""
event_type: EventType | str = EventType.UNKNOWN
"""事件类型,默认为未知"""
handler_name: str = ""
"""处理器名称"""
handler_description: str = ""
"""处理器描述"""
weight: int = 0
"""处理器权重,越大权重越高"""
intercept_message: bool = False
"""是否拦截消息,默认为否"""
def __init__(self):
self.log_prefix = "[EventHandler]"
self.plugin_name = ""
"""对应插件名"""
self.plugin_config: Optional[Dict] = None
"""插件配置字典"""
if self.event_type == EventType.UNKNOWN:
raise NotImplementedError("事件处理器必须指定 event_type")
@abstractmethod
async def execute(
self, message: MaiMessages | None
) -> Tuple[bool, bool, Optional[str], Optional[CustomEventHandlerResult], Optional[MaiMessages]]:
"""执行事件处理的抽象方法,子类必须实现
Args:
message (MaiMessages | None): 事件消息对象当你注册的事件为ON_START和ON_STOP时message为None
Returns:
Tuple[bool, bool, Optional[str], Optional[CustomEventHandlerResult], Optional[MaiMessages]]: (是否执行成功, 是否需要继续处理, 可选的返回消息, 可选的自定义结果,可选的修改后消息)
"""
raise NotImplementedError("子类必须实现 execute 方法")
@classmethod
def get_handler_info(cls) -> "EventHandlerInfo":
"""获取事件处理器的信息"""
# 从类属性读取名称如果没有定义则使用类名自动生成S
name: str = getattr(cls, "handler_name", cls.__name__.lower().replace("handler", ""))
if "." in name:
logger.error(f"事件处理器名称 '{name}' 包含非法字符 '.',请使用下划线替代")
raise ValueError(f"事件处理器名称 '{name}' 包含非法字符 '.',请使用下划线替代")
return EventHandlerInfo(
name=name,
component_type=ComponentType.EVENT_HANDLER,
description=getattr(cls, "handler_description", "events处理器"),
event_type=cls.event_type,
weight=cls.weight,
intercept_message=cls.intercept_message,
)
def set_plugin_config(self, plugin_config: Dict) -> None:
"""设置插件配置
Args:
plugin_config (dict): 插件配置字典
"""
self.plugin_config = plugin_config
def set_plugin_name(self, plugin_name: str) -> None:
"""设置插件名称
Args:
plugin_name (str): 插件名称
"""
self.plugin_name = plugin_name
def get_config(self, key: str, default=None):
"""获取插件配置值,支持嵌套键访问
Args:
key: 配置键名,支持嵌套访问如 "section.subsection.key"
default: 默认值
Returns:
Any: 配置值或默认值
"""
if not self.plugin_config:
return default
# 支持嵌套键访问
keys = key.split(".")
current = self.plugin_config
for k in keys:
if isinstance(current, dict) and k in current:
current = current[k]
else:
return default
return current
async def send_text(
self,
stream_id: str,
text: str,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
typing: bool = False,
storage_message: bool = True,
) -> bool:
"""发送文本消息
Args:
stream_id: 聊天ID
text: 文本内容
set_reply: 是否作为回复发送
reply_message: 回复的消息对象当set_reply为True时必填
typing: 是否计算输入时间
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
if not stream_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
return await send_api.text_to_stream(
text=text,
stream_id=stream_id,
set_reply=set_reply,
reply_message=reply_message,
typing=typing,
storage_message=storage_message,
)
async def send_emoji(
self,
stream_id: str,
emoji_base64: str,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""发送表情消息
Args:
emoji_base64: 表情的Base64编码
stream_id: 聊天ID
set_reply: 是否作为回复发送
reply_message: 回复的消息对象当set_reply为True时必填
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
if not stream_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
return await send_api.emoji_to_stream(
emoji_base64=emoji_base64,
stream_id=stream_id,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
)
async def send_image(
self,
stream_id: str,
image_base64: str,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""发送图片消息
Args:
image_base64: 图片的Base64编码
stream_id: 聊天ID
set_reply: 是否作为回复发送
reply_message: 回复的消息对象当set_reply为True时必填
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
if not stream_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
return await send_api.image_to_stream(
image_base64=image_base64,
stream_id=stream_id,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
)
async def send_voice(
self,
stream_id: str,
audio_base64: str,
) -> bool:
"""发送语音消息
Args:
stream_id: 聊天ID
audio_base64: 语音的Base64编码
Returns:
bool: 是否发送成功
"""
if not stream_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
reply_set = ReplySetModel()
reply_set.add_voice_content(audio_base64)
return await send_api.custom_reply_set_to_stream(
reply_set=reply_set,
stream_id=stream_id,
storage_message=False,
)
async def send_command(
self,
stream_id: str,
command_name: str,
command_args: Optional[dict] = None,
display_message: str = "",
storage_message: bool = True,
) -> bool:
"""发送命令消息
Args:
stream_id: 流ID
command_name: 命令名称
command_args: 命令参数字典
display_message: 显示消息
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
if not stream_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
# 构造命令数据
command_data = {"name": command_name, "args": command_args or {}}
return await send_api.command_to_stream(
command=command_data,
stream_id=stream_id,
storage_message=storage_message,
display_message=display_message,
)
async def send_custom(
self,
stream_id: str,
message_type: str,
content: str | Dict,
typing: bool = False,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""发送自定义消息
Args:
stream_id: 聊天ID
message_type: 消息类型
content: 消息内容,可以是字符串或字典
typing: 是否显示正在输入状态
set_reply: 是否作为回复发送
reply_message: 回复的消息对象当set_reply为True时必填
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
if not stream_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
return await send_api.custom_to_stream(
message_type=message_type,
content=content,
stream_id=stream_id,
typing=typing,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
)
async def send_hybrid(
self,
stream_id: str,
message_tuple_list: List[Tuple[ReplyContentType | str, str]],
typing: bool = False,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""
发送混合类型消息
Args:
stream_id: 流ID
message_tuple_list: 包含消息类型和内容的元组列表,格式为 [(内容类型, 内容), ...]
typing: 是否计算打字时间
set_reply: 是否作为回复发送
reply_message: 回复的消息对象
storage_message: 是否存储消息到数据库
"""
if not stream_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
reply_set = ReplySetModel()
reply_set.add_hybrid_content_by_raw(message_tuple_list)
return await send_api.custom_reply_set_to_stream(
reply_set=reply_set,
stream_id=stream_id,
typing=typing,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
)
async def send_forward(
self,
stream_id: str,
messages_list: List[Tuple[str, str, List[Tuple[ReplyContentType | str, str]]] | str],
storage_message: bool = True,
) -> bool:
"""转发消息
Args:
stream_id: 聊天ID
messages_list: 包含消息信息的列表,当传入自行生成的数据时,元素格式为 (sender_id, nickname, 消息体)当传入消息ID时元素格式为 "message_id"
其中消息体的格式为 [(内容类型, 内容), ...]
任意长度的消息都需要使用列表的形式传入
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
if not stream_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
reply_set = ReplySetModel()
forward_message_nodes: List[ForwardNode] = []
for message in messages_list:
if isinstance(message, str):
forward_message_node = ForwardNode.construct_as_id_reference(message)
elif isinstance(message, Tuple) and len(message) == 3:
sender_id, nickname, content_list = message
single_node_content_list: List[ReplyContent] = []
for node_content_type, node_content in content_list:
reply_node_content = ReplyContent(content_type=node_content_type, content=node_content)
single_node_content_list.append(reply_node_content)
forward_message_node = ForwardNode.construct_as_created_node(
user_id=sender_id, user_nickname=nickname, content=single_node_content_list
)
else:
logger.warning(f"{self.log_prefix} 转发消息时遇到无效的消息格式: {message}")
continue
forward_message_nodes.append(forward_message_node)
reply_set.add_forward_content(forward_message_nodes)
return await send_api.custom_reply_set_to_stream(
reply_set=reply_set,
stream_id=stream_id,
storage_message=storage_message,
set_reply=False,
reply_message=None,
)

View File

@@ -1,102 +0,0 @@
from abc import abstractmethod
from typing import Any, Callable, List, Type, Tuple, Union
from .plugin_base import PluginBase
from src.common.logger import get_logger
from src.plugin_system.base.component_types import ActionInfo, CommandInfo, EventHandlerInfo, ToolInfo
from src.plugin_system.base.workflow_types import WorkflowStepInfo
from .base_action import BaseAction
from .base_command import BaseCommand
from .base_events_handler import BaseEventHandler
from .base_tool import BaseTool
logger = get_logger("base_plugin")
class BasePlugin(PluginBase):
"""基于Action和Command的插件基类
所有上述类型的插件都应该继承这个基类,一个插件可以包含多种组件:
- Action组件处理聊天中的动作
- Command组件处理命令请求
- 未来可扩展Scheduler、Listener等
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@abstractmethod
def get_plugin_components(
self,
) -> List[
Union[
Tuple[ActionInfo, Type[BaseAction]],
Tuple[CommandInfo, Type[BaseCommand]],
Tuple[EventHandlerInfo, Type[BaseEventHandler]],
Tuple[ToolInfo, Type[BaseTool]],
]
]:
"""获取插件包含的组件列表
子类必须实现此方法,返回组件信息和组件类的列表
Returns:
List[tuple[ComponentInfo, Type]]: [(组件信息, 组件类), ...]
"""
raise NotImplementedError("Subclasses must implement this method")
def get_workflow_steps(self) -> List[Tuple[WorkflowStepInfo, Callable[..., Any]]]:
"""获取插件包含的workflow steps。
默认返回空列表,子类可按需覆盖。
"""
return []
def register_plugin(self) -> bool:
"""注册插件及其所有组件"""
from src.plugin_system.core.component_registry import component_registry
components = self.get_plugin_components()
# 检查依赖
if not self._check_dependencies():
logger.error(f"{self.log_prefix} 依赖检查失败,跳过注册")
return False
# 注册所有组件
registered_components = []
for component_info, component_class in components:
component_info.plugin_name = self.plugin_name
if component_registry.register_component(component_info, component_class):
registered_components.append(component_info)
else:
logger.warning(f"{self.log_prefix} 组件 {component_info.name} 注册失败")
# 更新插件信息中的组件列表
self.plugin_info.components = registered_components
# 注册插件
if component_registry.register_plugin(self.plugin_info):
# 注册workflow steps可选
registered_step_count = 0
for step_info, step_handler in self.get_workflow_steps():
if not step_info.plugin_name:
step_info.plugin_name = self.plugin_name
elif step_info.plugin_name != self.plugin_name:
logger.warning(
f"{self.log_prefix} workflow step {step_info.name} 的plugin_name({step_info.plugin_name})与当前插件不一致,已覆盖为 {self.plugin_name}"
)
step_info.plugin_name = self.plugin_name
if component_registry.register_workflow_step(step_info, step_handler):
registered_step_count += 1
else:
logger.warning(f"{self.log_prefix} workflow step {step_info.full_name} 注册失败")
logger.debug(f"{self.log_prefix} 插件注册成功,包含 {len(registered_components)} 个组件")
if registered_step_count > 0:
logger.debug(f"{self.log_prefix} workflow steps 注册成功,数量: {registered_step_count}")
return True
else:
logger.error(f"{self.log_prefix} 插件注册失败")
return False

View File

@@ -1,137 +0,0 @@
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Tuple, TYPE_CHECKING
from rich.traceback import install
from src.common.logger import get_logger
from src.plugin_system.base.component_types import ComponentType, ToolInfo, ToolParamType
if TYPE_CHECKING:
from src.chat.message_receive.chat_manager import BotChatSession
install(extra_lines=3)
logger = get_logger("base_tool")
class BaseTool(ABC):
"""所有工具的基类"""
name: str = ""
"""工具的名称"""
description: str = ""
"""工具的描述"""
parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = []
"""工具的参数定义,为[("param_name", param_type, "description", required, enum_values)]格式
param_name: 参数名称
param_type: 参数类型
description: 参数描述
required: 是否必填
enum_values: 枚举值列表
例如: [("arg1", ToolParamType.STRING, "参数1描述", True, None), ("arg2", ToolParamType.INTEGER, "参数2描述", False, ["1", "2", "3"])]
"""
available_for_llm: bool = False
"""是否可供LLM使用"""
def __init__(self, plugin_config: Optional[dict] = None, chat_stream: Optional["BotChatSession"] = None):
"""初始化工具基类
Args:
plugin_config: 插件配置字典
chat_stream: 聊天流对象,用于获取聊天上下文信息
"""
self.plugin_config = plugin_config or {} # 直接存储插件配置字典
# =============================================================================
# 便捷属性 - 直接在初始化时获取常用聊天信息与BaseAction保持一致
# =============================================================================
# 获取聊天流对象
self.chat_stream = chat_stream
self.chat_id = self.chat_stream.session_id if self.chat_stream else None
self.platform = getattr(self.chat_stream, "platform", None) if self.chat_stream else None
@classmethod
def get_tool_definition(cls) -> dict[str, Any]:
"""获取工具定义用于LLM工具调用
Returns:
dict: 工具定义字典
"""
if not cls.name or not cls.description or cls.parameters is None:
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
return {"name": cls.name, "description": cls.description, "parameters": cls.parameters}
@classmethod
def get_tool_info(cls) -> ToolInfo:
"""获取工具信息"""
if not cls.name or not cls.description or cls.parameters is None:
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
return ToolInfo(
name=cls.name,
tool_description=cls.description,
enabled=cls.available_for_llm,
tool_parameters=cls.parameters,
component_type=ComponentType.TOOL,
)
@abstractmethod
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
"""执行工具函数(供llm调用)
通过该方法maicore会通过llm的tool call来调用工具
传入的是json格式的参数符合parameters定义的格式
Args:
function_args: 工具调用参数
Returns:
dict: 工具执行结果
"""
raise NotImplementedError("子类必须实现execute方法")
async def direct_execute(self, **function_args: dict[str, Any]) -> dict[str, Any]:
"""直接执行工具函数(供插件调用)
通过该方法,插件可以直接调用工具,而不需要传入字典格式的参数
插件可以直接调用此方法,用更加明了的方式传入参数
示例: result = await tool.direct_execute(arg1="参数",arg2="参数2")
工具开发者可以重写此方法以实现与llm调用差异化的执行逻辑
Args:
**function_args: 工具调用参数
Returns:
dict: 工具执行结果
"""
parameter_required = [param[0] for param in self.parameters if param[3]] # 获取所有必填参数名
for param_name in parameter_required:
if param_name not in function_args:
raise ValueError(f"工具类 {self.__class__.__name__} 缺少必要参数: {param_name}")
return await self.execute(function_args)
def get_config(self, key: str, default=None):
"""获取插件配置值,使用嵌套键访问
Args:
key: 配置键名,使用嵌套访问如 "section.subsection.key"
default: 默认值
Returns:
Any: 配置值或默认值
"""
if not self.plugin_config:
return default
# 支持嵌套键访问
keys = key.split(".")
current = self.plugin_config
for k in keys:
if isinstance(current, dict) and k in current:
current = current[k]
else:
return default
return current

View File

@@ -1,761 +0,0 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Any, Union
import os
import inspect
import toml
import json
import shutil
import datetime
import re
from src.common.logger import get_logger
from src.plugin_system.base.component_types import (
PluginInfo,
PythonDependency,
)
from src.plugin_system.base.config_types import (
ConfigField,
ConfigSection,
ConfigLayout,
)
from src.plugin_system.utils.manifest_utils import ManifestValidator, VersionComparator
logger = get_logger("plugin_base")
class PluginBase(ABC):
"""插件总基类
所有衍生插件基类都应该继承自此类,这个类定义了插件的基本结构和行为。
"""
# 插件基本信息(子类必须定义)
@property
@abstractmethod
def plugin_name(self) -> str:
return "" # 插件内部标识符(如 "hello_world_plugin"
@property
@abstractmethod
def enable_plugin(self) -> bool:
return True # 是否启用插件
@property
@abstractmethod
def dependencies(self) -> List[Union[str, Dict[str, Any]]]:
return [] # 依赖的其他插件
@property
@abstractmethod
def python_dependencies(self) -> List[PythonDependency]:
return [] # Python包依赖
@property
@abstractmethod
def config_file_name(self) -> str:
return "" # 配置文件名
# manifest文件相关
manifest_file_name: str = "_manifest.json" # manifest文件名
manifest_data: Dict[str, Any] = {} # manifest数据
# 配置定义
@property
@abstractmethod
def config_schema(self) -> Dict[str, Union[Dict[str, ConfigField], str]]:
return {}
config_section_descriptions: Dict[str, Union[str, ConfigSection]] = {}
# 布局配置(可选,不定义则使用自动布局)
config_layout: ConfigLayout = None
def __init__(self, plugin_dir: str):
"""初始化插件
Args:
plugin_dir: 插件目录路径,由插件管理器传递
"""
self.config: Dict[str, Any] = {} # 插件配置
self.plugin_dir = plugin_dir # 插件目录路径
self.log_prefix = f"[Plugin:{self.plugin_name}]"
# 加载manifest文件
self._load_manifest()
# 验证插件信息
self._validate_plugin_info()
# 加载插件配置
self._load_plugin_config()
# 从manifest获取显示信息
self.display_name = self.get_manifest_info("name", self.plugin_name)
self.plugin_version = self.get_manifest_info("version", "1.0.0")
self.plugin_description = self.get_manifest_info("description", "")
self.plugin_author = self._get_author_name()
# 创建插件信息对象
self.plugin_info = PluginInfo(
name=self.plugin_name,
display_name=self.display_name,
description=self.plugin_description,
version=self.plugin_version,
author=self.plugin_author,
enabled=self.enable_plugin,
is_built_in=False,
config_file=self.config_file_name or "",
dependencies=self._get_dependency_names(),
python_dependencies=self.python_dependencies.copy(),
# manifest相关信息
manifest_data=self.manifest_data.copy(),
license=self.get_manifest_info("license", ""),
homepage_url=self.get_manifest_info("homepage_url", ""),
repository_url=self.get_manifest_info("repository_url", ""),
keywords=self.get_manifest_info("keywords", []).copy() if self.get_manifest_info("keywords") else [],
categories=self.get_manifest_info("categories", []).copy() if self.get_manifest_info("categories") else [],
min_host_version=self.get_manifest_info("host_application.min_version", ""),
max_host_version=self.get_manifest_info("host_application.max_version", ""),
)
logger.debug(f"{self.log_prefix} 插件基类初始化完成")
def _validate_plugin_info(self):
"""验证插件基本信息"""
if not self.plugin_name:
raise ValueError(f"插件类 {self.__class__.__name__} 必须定义 plugin_name")
# 验证manifest中的必需信息
if not self.get_manifest_info("name"):
raise ValueError(f"插件 {self.plugin_name} 的manifest中缺少name字段")
if not self.get_manifest_info("description"):
raise ValueError(f"插件 {self.plugin_name} 的manifest中缺少description字段")
def _load_manifest(self): # sourcery skip: raise-from-previous-error
"""加载manifest文件强制要求"""
if not self.plugin_dir:
raise ValueError(f"{self.log_prefix} 没有插件目录路径无法加载manifest")
manifest_path = os.path.join(self.plugin_dir, self.manifest_file_name)
if not os.path.exists(manifest_path):
error_msg = f"{self.log_prefix} 缺少必需的manifest文件: {manifest_path}"
logger.error(error_msg)
raise FileNotFoundError(error_msg)
try:
with open(manifest_path, "r", encoding="utf-8") as f:
self.manifest_data = json.load(f)
logger.debug(f"{self.log_prefix} 成功加载manifest文件: {manifest_path}")
# 验证manifest格式
self._validate_manifest()
except json.JSONDecodeError as e:
error_msg = f"{self.log_prefix} manifest文件格式错误: {e}"
logger.error(error_msg)
raise ValueError(error_msg) # noqa
except IOError as e:
error_msg = f"{self.log_prefix} 读取manifest文件失败: {e}"
logger.error(error_msg)
raise IOError(error_msg) # noqa
def _get_author_name(self) -> str:
"""从manifest获取作者名称"""
author_info = self.get_manifest_info("author", {})
if isinstance(author_info, dict):
return author_info.get("name", "")
else:
return str(author_info) if author_info else ""
def _validate_manifest(self):
"""验证manifest文件格式使用强化的验证器"""
if not self.manifest_data:
raise ValueError(f"{self.log_prefix} manifest数据为空验证失败")
validator = ManifestValidator()
is_valid = validator.validate_manifest(self.manifest_data)
# 记录验证结果
if validator.validation_errors or validator.validation_warnings:
report = validator.get_validation_report()
logger.info(f"{self.log_prefix} Manifest验证结果:\n{report}")
# 如果有验证错误,抛出异常
if not is_valid:
error_msg = f"{self.log_prefix} Manifest文件验证失败"
if validator.validation_errors:
error_msg += f": {'; '.join(validator.validation_errors)}"
raise ValueError(error_msg)
def get_manifest_info(self, key: str, default: Any = None) -> Any:
"""获取manifest信息
Args:
key: 信息键,支持点分割的嵌套键(如 "author.name"
default: 默认值
Returns:
Any: 对应的值
"""
if not self.manifest_data:
return default
keys = key.split(".")
value = self.manifest_data
for k in keys:
if isinstance(value, dict) and k in value:
value = value[k]
else:
return default
return value
def _format_toml_value(self, value: Any) -> str:
"""将Python值格式化为合法的TOML字符串"""
if isinstance(value, str):
return json.dumps(value, ensure_ascii=False)
if isinstance(value, bool):
return str(value).lower()
if isinstance(value, (int, float)):
return str(value)
if isinstance(value, list):
inner = ", ".join(self._format_toml_value(item) for item in value)
return f"[{inner}]"
if isinstance(value, dict):
items = [f"{k} = {self._format_toml_value(v)}" for k, v in value.items()]
return "{ " + ", ".join(items) + " }"
return json.dumps(value, ensure_ascii=False)
def _generate_and_save_default_config(self, config_file_path: str):
"""根据插件的Schema生成并保存默认配置文件"""
if not self.config_schema:
logger.debug(f"{self.log_prefix} 插件未定义config_schema不生成配置文件")
return
toml_str = f"# {self.plugin_name} - 自动生成的配置文件\n"
plugin_description = self.get_manifest_info("description", "插件配置文件")
toml_str += f"# {plugin_description}\n\n"
# 遍历每个配置节
for section, fields in self.config_schema.items():
# 添加节描述
if section in self.config_section_descriptions:
toml_str += f"# {self.config_section_descriptions[section]}\n"
toml_str += f"[{section}]\n\n"
# 遍历节内的字段
if isinstance(fields, dict):
for field_name, field in fields.items():
if isinstance(field, ConfigField):
# 添加字段描述
toml_str += f"# {field.description}"
if field.required:
toml_str += " (必需)"
toml_str += "\n"
# 如果有示例值,添加示例
if field.example:
toml_str += f"# 示例: {field.example}\n"
# 如果有可选值,添加说明
if field.choices:
choices_str = ", ".join(map(str, field.choices))
toml_str += f"# 可选值: {choices_str}\n"
# 添加字段值
value = field.default
toml_str += f"{field_name} = {self._format_toml_value(value)}\n"
toml_str += "\n"
toml_str += "\n"
try:
with open(config_file_path, "w", encoding="utf-8") as f:
f.write(toml_str)
logger.info(f"{self.log_prefix} 已生成默认配置文件: {config_file_path}")
except IOError as e:
logger.error(f"{self.log_prefix} 保存默认配置文件失败: {e}", exc_info=True)
def _get_expected_config_version(self) -> str:
"""获取插件期望的配置版本号"""
# 从config_schema的plugin.config_version字段获取
if "plugin" in self.config_schema and isinstance(self.config_schema["plugin"], dict):
config_version_field = self.config_schema["plugin"].get("config_version")
if isinstance(config_version_field, ConfigField):
return config_version_field.default
return "1.0.0"
def _get_current_config_version(self, config: Dict[str, Any]) -> str:
"""从配置文件中获取当前版本号"""
if "plugin" in config and "config_version" in config["plugin"]:
return str(config["plugin"]["config_version"])
# 如果没有config_version字段视为最早的版本
return "0.0.0"
def _backup_config_file(self, config_file_path: str) -> str:
"""备份配置文件"""
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
backup_path = f"{config_file_path}.backup_{timestamp}"
try:
shutil.copy2(config_file_path, backup_path)
logger.info(f"{self.log_prefix} 配置文件已备份到: {backup_path}")
return backup_path
except Exception as e:
logger.error(f"{self.log_prefix} 备份配置文件失败: {e}")
return ""
def _migrate_config_values(self, old_config: Dict[str, Any], new_config: Dict[str, Any]) -> Dict[str, Any]:
"""将旧配置值迁移到新配置结构中
Args:
old_config: 旧配置数据
new_config: 基于新schema生成的默认配置
Returns:
Dict[str, Any]: 迁移后的配置
"""
def migrate_section(
old_section: Dict[str, Any], new_section: Dict[str, Any], section_name: str
) -> Dict[str, Any]:
"""迁移单个配置节"""
result = new_section.copy()
for key, value in old_section.items():
if key in new_section:
# 特殊处理config_version字段总是使用新版本
if section_name == "plugin" and key == "config_version":
# 保持新的版本号,不迁移旧值
logger.debug(
f"{self.log_prefix} 更新配置版本: {section_name}.{key} = {result[key]} (旧值: {value})"
)
continue
# 键存在于新配置中,复制值
if isinstance(value, dict) and isinstance(new_section[key], dict):
# 递归处理嵌套字典
result[key] = migrate_section(value, new_section[key], f"{section_name}.{key}")
else:
result[key] = value
logger.debug(f"{self.log_prefix} 迁移配置: {section_name}.{key} = {value}")
else:
# 键在新配置中不存在,记录警告
logger.warning(f"{self.log_prefix} 配置项 {section_name}.{key} 在新版本中已被移除")
return result
migrated_config = {}
# 迁移每个配置节
for section_name, new_section_data in new_config.items():
if (
section_name in old_config
and isinstance(old_config[section_name], dict)
and isinstance(new_section_data, dict)
):
migrated_config[section_name] = migrate_section(
old_config[section_name], new_section_data, section_name
)
else:
# 新增的节或类型不匹配,使用默认值
migrated_config[section_name] = new_section_data
if section_name in old_config:
logger.warning(f"{self.log_prefix} 配置节 {section_name} 结构已改变,使用默认值")
# 检查旧配置中是否有新配置没有的节
for section_name in old_config:
if section_name not in migrated_config:
logger.warning(f"{self.log_prefix} 配置节 {section_name} 在新版本中已被移除")
return migrated_config
def _generate_config_from_schema(self) -> Dict[str, Any]:
# sourcery skip: dict-comprehension
"""根据schema生成配置数据结构不写入文件"""
if not self.config_schema:
return {}
config_data = {}
# 遍历每个配置节
for section, fields in self.config_schema.items():
if isinstance(fields, dict):
section_data = {}
# 遍历节内的字段
for field_name, field in fields.items():
if isinstance(field, ConfigField):
section_data[field_name] = field.default
config_data[section] = section_data
return config_data
def _save_config_to_file(self, config_data: Dict[str, Any], config_file_path: str):
"""将配置数据保存为TOML文件包含注释"""
if not self.config_schema:
logger.debug(f"{self.log_prefix} 插件未定义config_schema不生成配置文件")
return
toml_str = f"# {self.plugin_name} - 配置文件\n"
plugin_description = self.get_manifest_info("description", "插件配置文件")
toml_str += f"# {plugin_description}\n"
# 获取当前期望的配置版本
expected_version = self._get_expected_config_version()
toml_str += f"# 配置版本: {expected_version}\n\n"
# 遍历每个配置节
for section, fields in self.config_schema.items():
# 添加节描述
if section in self.config_section_descriptions:
toml_str += f"# {self.config_section_descriptions[section]}\n"
toml_str += f"[{section}]\n\n"
# 遍历节内的字段
if isinstance(fields, dict) and section in config_data:
section_data = config_data[section]
for field_name, field in fields.items():
if isinstance(field, ConfigField):
# 添加字段描述
toml_str += f"# {field.description}"
if field.required:
toml_str += " (必需)"
toml_str += "\n"
# 如果有示例值,添加示例
if field.example:
toml_str += f"# 示例: {field.example}\n"
# 如果有可选值,添加说明
if field.choices:
choices_str = ", ".join(map(str, field.choices))
toml_str += f"# 可选值: {choices_str}\n"
# 添加字段值(使用迁移后的值)
value = section_data.get(field_name, field.default)
toml_str += f"{field_name} = {self._format_toml_value(value)}\n"
toml_str += "\n"
toml_str += "\n"
try:
with open(config_file_path, "w", encoding="utf-8") as f:
f.write(toml_str)
logger.info(f"{self.log_prefix} 配置文件已保存: {config_file_path}")
except IOError as e:
logger.error(f"{self.log_prefix} 保存配置文件失败: {e}", exc_info=True)
def _load_plugin_config(self): # sourcery skip: extract-method
"""加载插件配置文件,支持版本检查和自动迁移"""
if not self.config_file_name:
logger.debug(f"{self.log_prefix} 未指定配置文件,跳过加载")
return
# 优先使用传入的插件目录路径
if self.plugin_dir:
plugin_dir = self.plugin_dir
else:
# fallback尝试从类的模块信息获取路径
try:
plugin_module_path = inspect.getfile(self.__class__)
plugin_dir = os.path.dirname(plugin_module_path)
except (TypeError, OSError):
# 最后的fallback从模块的__file__属性获取
module = inspect.getmodule(self.__class__)
if module and hasattr(module, "__file__") and module.__file__:
plugin_dir = os.path.dirname(module.__file__)
else:
logger.warning(f"{self.log_prefix} 无法获取插件目录路径,跳过配置加载")
return
config_file_path = os.path.join(plugin_dir, self.config_file_name)
# 如果配置文件不存在,生成默认配置
if not os.path.exists(config_file_path):
logger.info(f"{self.log_prefix} 配置文件 {config_file_path} 不存在,将生成默认配置。")
self._generate_and_save_default_config(config_file_path)
if not os.path.exists(config_file_path):
logger.warning(f"{self.log_prefix} 配置文件 {config_file_path} 不存在且无法生成。")
return
file_ext = os.path.splitext(self.config_file_name)[1].lower()
if file_ext == ".toml":
# 加载现有配置
with open(config_file_path, "r", encoding="utf-8") as f:
existing_config = toml.load(f) or {}
# 检查配置版本
current_version = self._get_current_config_version(existing_config)
# 如果配置文件没有版本信息,跳过版本检查
if current_version == "0.0.0":
logger.debug(f"{self.log_prefix} 配置文件无版本信息,跳过版本检查")
self.config = existing_config
else:
expected_version = self._get_expected_config_version()
if current_version != expected_version:
logger.info(
f"{self.log_prefix} 检测到配置版本需要更新: 当前=v{current_version}, 期望=v{expected_version}"
)
# 生成新的默认配置结构
new_config_structure = self._generate_config_from_schema()
# 迁移旧配置值到新结构
migrated_config = self._migrate_config_values(existing_config, new_config_structure)
# 保存迁移后的配置
self._save_config_to_file(migrated_config, config_file_path)
logger.info(f"{self.log_prefix} 配置文件已从 v{current_version} 更新到 v{expected_version}")
self.config = migrated_config
else:
logger.debug(f"{self.log_prefix} 配置版本匹配 (v{current_version}),直接加载")
self.config = existing_config
logger.debug(f"{self.log_prefix} 配置已从 {config_file_path} 加载")
# 从配置中更新 enable_plugin
if "plugin" in self.config and "enabled" in self.config["plugin"]:
self.enable_plugin = self.config["plugin"]["enabled"] # type: ignore
logger.debug(f"{self.log_prefix} 从配置更新插件启用状态: {self.enable_plugin}")
else:
logger.warning(f"{self.log_prefix} 不支持的配置文件格式: {file_ext},仅支持 .toml")
self.config = {}
def _check_dependencies(self) -> bool:
"""检查插件依赖"""
from src.plugin_system.core.component_registry import component_registry
if not self.dependencies:
return True
for dependency in self.dependencies:
dep_name, version_spec, min_version, max_version = self._parse_dependency(dependency)
if not dep_name:
logger.warning(f"{self.log_prefix} 跳过无效依赖声明: {dependency}")
continue
dep_plugin_info = component_registry.get_plugin_info(dep_name)
if not dep_plugin_info:
logger.error(f"{self.log_prefix} 缺少依赖插件: {dep_name}")
return False
dep_version = dep_plugin_info.version or "0.0.0"
if version_spec:
is_ok, msg = self._is_version_spec_satisfied(dep_version, version_spec)
if not is_ok:
logger.error(
f"{self.log_prefix} 依赖插件版本不满足: {dep_name} {version_spec}, 当前版本={dep_version} ({msg})"
)
return False
if min_version or max_version:
is_ok, msg = VersionComparator.is_version_in_range(dep_version, min_version, max_version)
if not is_ok:
logger.error(
f"{self.log_prefix} 依赖插件版本不满足: {dep_name} 要求区间[{min_version or '-inf'}, {max_version or '+inf'}], 当前版本={dep_version} ({msg})"
)
return False
return True
def _get_dependency_names(self) -> List[str]:
"""获取依赖插件名称列表(用于插件信息展示和统计)。"""
dependency_names: List[str] = []
for dependency in self.dependencies:
dep_name, _, _, _ = self._parse_dependency(dependency)
if dep_name:
dependency_names.append(dep_name)
return dependency_names
def _parse_dependency(self, dependency: Any) -> tuple[str, str, str, str]:
"""解析依赖声明。
支持格式:
- "plugin_a"
- {"name": "plugin_a", "version": ">=1.2.0,<2.0.0"}
- {"name": "plugin_a", "min_version": "1.2.0", "max_version": "2.0.0"}
"""
if isinstance(dependency, str):
return dependency.strip(), "", "", ""
if isinstance(dependency, dict):
dep_name = str(dependency.get("name", "")).strip()
version_spec = str(dependency.get("version", "")).strip()
min_version = str(dependency.get("min_version", "")).strip()
max_version = str(dependency.get("max_version", "")).strip()
return dep_name, version_spec, min_version, max_version
return "", "", "", ""
def _is_version_spec_satisfied(self, version: str, version_spec: str) -> tuple[bool, str]:
"""检查版本是否满足表达式。
支持:==, >=, <=, >, <,可用逗号分隔多个条件。
示例:">=1.2.0,<2.0.0"
"""
normalized_version = VersionComparator.normalize_version(version)
clauses = [clause.strip() for clause in version_spec.split(",") if clause.strip()]
if not clauses:
return True, ""
operators_pattern = r"^(==|>=|<=|>|<)\s*(.+)$"
for clause in clauses:
if not (match := re.match(operators_pattern, clause)):
return False, f"无效版本约束表达式: {clause}"
operator, target_version = match.group(1), VersionComparator.normalize_version(match.group(2))
compare_result = VersionComparator.compare_versions(normalized_version, target_version)
is_satisfied = False
if operator == "==":
is_satisfied = compare_result == 0
elif operator == ">=":
is_satisfied = compare_result >= 0
elif operator == "<=":
is_satisfied = compare_result <= 0
elif operator == ">":
is_satisfied = compare_result > 0
elif operator == "<":
is_satisfied = compare_result < 0
if not is_satisfied:
return False, f"{normalized_version} 不满足约束 {operator}{target_version}"
return True, ""
def get_config(self, key: str, default: Any = None) -> Any:
"""获取插件配置值,支持嵌套键访问
Args:
key: 配置键名,支持嵌套访问如 "section.subsection.key"
default: 默认值
Returns:
Any: 配置值或默认值
"""
# 支持嵌套键访问
keys = key.split(".")
current = self.config
for k in keys:
if isinstance(current, dict) and k in current:
current = current[k]
else:
return default
return current
def get_webui_config_schema(self) -> Dict[str, Any]:
"""
获取 WebUI 配置 Schema
返回完整的配置 schema包含
- 插件基本信息
- 所有 section 及其字段定义
- 布局配置
用于 WebUI 动态生成配置表单。
Returns:
Dict: 完整的配置 schema
"""
schema = {
"plugin_id": self.plugin_name,
"plugin_info": {
"name": self.display_name,
"version": self.plugin_version,
"description": self.plugin_description,
"author": self.plugin_author,
},
"sections": {},
"layout": None,
}
# 处理 sections
for section_name, fields in self.config_schema.items():
if not isinstance(fields, dict):
continue
section_data = {
"name": section_name,
"title": section_name,
"description": None,
"icon": None,
"collapsed": False,
"order": 0,
"fields": {},
}
# 获取 section 元数据
section_meta = self.config_section_descriptions.get(section_name)
if section_meta:
if isinstance(section_meta, str):
section_data["title"] = section_meta
elif isinstance(section_meta, ConfigSection):
section_data["title"] = section_meta.title
section_data["description"] = section_meta.description
section_data["icon"] = section_meta.icon
section_data["collapsed"] = section_meta.collapsed
section_data["order"] = section_meta.order
elif isinstance(section_meta, dict):
section_data.update(section_meta)
# 处理字段
for field_name, field_def in fields.items():
if isinstance(field_def, ConfigField):
field_data = field_def.to_dict()
field_data["name"] = field_name
section_data["fields"][field_name] = field_data
schema["sections"][section_name] = section_data
# 处理布局
if self.config_layout:
schema["layout"] = self.config_layout.to_dict()
else:
# 自动布局:按 section order 排序
schema["layout"] = {
"type": "auto",
"tabs": [],
}
return schema
def get_current_config_values(self) -> Dict[str, Any]:
"""
获取当前配置值
返回插件当前的配置值(已从配置文件加载)。
Returns:
Dict: 当前配置值
"""
return self.config.copy()
@abstractmethod
def register_plugin(self) -> bool:
"""
注册插件到插件管理器
子类必须实现此方法,返回注册是否成功
Returns:
bool: 是否成功注册插件
"""
raise NotImplementedError("Subclasses must implement this method")

View File

@@ -1,22 +0,0 @@
from dataclasses import dataclass, field
from typing import Any, Dict, List
@dataclass
class PluginServiceInfo:
"""插件服务注册信息"""
name: str
plugin_name: str
version: str = "1.0.0"
description: str = ""
enabled: bool = True
public: bool = False
allowed_callers: List[str] = field(default_factory=list)
params_schema: Dict[str, Any] = field(default_factory=dict)
return_schema: Dict[str, Any] = field(default_factory=dict)
metadata: Dict[str, Any] = field(default_factory=dict)
@property
def full_name(self) -> str:
return f"{self.plugin_name}.{self.name}"

View File

@@ -1,14 +0,0 @@
from enum import Enum
class WorkflowErrorCode(Enum):
"""Workflow统一错误码"""
PLUGIN_NOT_READY = "PLUGIN_NOT_READY"
STEP_TIMEOUT = "STEP_TIMEOUT"
BAD_PAYLOAD = "BAD_PAYLOAD"
DOWNSTREAM_FAILED = "DOWNSTREAM_FAILED"
POLICY_BLOCKED = "POLICY_BLOCKED"
def __str__(self) -> str:
return self.value

View File

@@ -1,68 +0,0 @@
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Literal, Optional
import time
class WorkflowStage(Enum):
"""Workflow阶段定义MVP固定阶段"""
INGRESS = "ingress"
PRE_PROCESS = "pre_process"
PLAN = "plan"
TOOL_EXECUTE = "tool_execute"
POST_PROCESS = "post_process"
EGRESS = "egress"
def __str__(self) -> str:
return self.value
@dataclass
class WorkflowContext:
"""Workflow上下文"""
trace_id: str
stream_id: Optional[str] = None
metadata: Dict[str, Any] = field(default_factory=dict)
timings: Dict[str, float] = field(default_factory=dict)
errors: List[str] = field(default_factory=list)
@dataclass
class WorkflowMessage:
"""Workflow消息包装"""
msg_type: str
payload: Dict[str, Any] = field(default_factory=dict)
headers: Dict[str, Any] = field(default_factory=dict)
mutable_flags: Dict[str, bool] = field(default_factory=dict)
@dataclass
class WorkflowStepResult:
"""Workflow步骤结果"""
status: Literal["continue", "stop", "failed"] = "continue"
return_message: Optional[str] = None
diagnostics: Dict[str, Any] = field(default_factory=dict)
events: List[Dict[str, Any]] = field(default_factory=list)
created_at: float = field(default_factory=time.time)
@dataclass
class WorkflowStepInfo:
"""Workflow步骤元数据"""
name: str
stage: WorkflowStage
plugin_name: str
description: str = ""
enabled: bool = True
priority: int = 0
timeout_ms: int = 0
metadata: Dict[str, Any] = field(default_factory=dict)
@property
def full_name(self) -> str:
return f"{self.plugin_name}.{self.name}"

View File

@@ -1,21 +0,0 @@
"""
插件核心管理模块
提供插件的加载、注册和管理功能
"""
from src.plugin_system.core.plugin_manager import plugin_manager
from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.core.events_manager import events_manager
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.plugin_system.core.plugin_service_registry import plugin_service_registry
from src.plugin_system.core.workflow_engine import workflow_engine
__all__ = [
"plugin_manager",
"component_registry",
"events_manager",
"global_announcement_manager",
"plugin_service_registry",
"workflow_engine",
]

View File

@@ -1,758 +0,0 @@
import re
from typing import Callable, Dict, List, Optional, Any, Pattern, Tuple, Union, Type
from src.common.logger import get_logger
from src.plugin_system.base.component_types import (
ComponentInfo,
ActionInfo,
ToolInfo,
CommandInfo,
EventHandlerInfo,
PluginInfo,
ComponentType,
)
from src.plugin_system.base.base_command import BaseCommand
from src.plugin_system.base.base_action import BaseAction
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.base.base_events_handler import BaseEventHandler
from src.plugin_system.base.workflow_types import WorkflowStage, WorkflowStepInfo
logger = get_logger("component_registry")
class ComponentRegistry:
"""统一的组件注册中心
负责管理所有插件组件的注册、查询和生命周期管理
"""
def __init__(self):
# 命名空间式组件名构成法 f"{component_type}.{component_name}"
self._components: Dict[str, ComponentInfo] = {}
"""组件注册表 命名空间式组件名 -> 组件信息"""
self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType}
"""类型 -> 组件原名称 -> 组件信息"""
self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseTool, BaseEventHandler]]] = {}
"""命名空间式组件名 -> 组件类"""
# 插件注册表
self._plugins: Dict[str, PluginInfo] = {}
"""插件名 -> 插件信息"""
# Action特定注册表
self._action_registry: Dict[str, Type[BaseAction]] = {}
"""Action注册表 action名 -> action类"""
self._default_actions: Dict[str, ActionInfo] = {}
"""默认动作集即启用的Action集用于重置ActionManager状态"""
# Command特定注册表
self._command_registry: Dict[str, Type[BaseCommand]] = {}
"""Command类注册表 command名 -> command类"""
self._command_patterns: Dict[Pattern, str] = {}
"""编译后的正则 -> command名"""
# 工具特定注册表
self._tool_registry: Dict[str, Type[BaseTool]] = {} # 工具名 -> 工具类
self._llm_available_tools: Dict[str, Type[BaseTool]] = {} # llm可用的工具名 -> 工具类
# EventHandler特定注册表
self._event_handler_registry: Dict[str, Type[BaseEventHandler]] = {}
"""event_handler名 -> event_handler类"""
self._enabled_event_handlers: Dict[str, Type[BaseEventHandler]] = {}
"""启用的事件处理器 event_handler名 -> event_handler类"""
# Workflow step注册表
self._workflow_steps: Dict[WorkflowStage, Dict[str, WorkflowStepInfo]] = {stage: {} for stage in WorkflowStage}
self._workflow_step_handlers: Dict[str, Callable[..., Any]] = {}
logger.info("组件注册中心初始化完成")
# == 注册方法 ==
def register_plugin(self, plugin_info: PluginInfo) -> bool:
"""注册插件
Args:
plugin_info: 插件信息
Returns:
bool: 是否注册成功
"""
plugin_name = plugin_info.name
if plugin_name in self._plugins:
logger.warning(f"插件 {plugin_name} 已存在,跳过注册")
return False
self._plugins[plugin_name] = plugin_info
logger.debug(f"已注册插件: {plugin_name} (组件数量: {len(plugin_info.components)})")
return True
def register_component(
self,
component_info: ComponentInfo,
component_class: Type[Union[BaseCommand, BaseAction, BaseEventHandler, BaseTool]],
) -> bool:
"""注册组件
Args:
component_info (ComponentInfo): 组件信息
component_class (Type[Union[BaseCommand, BaseAction, BaseEventHandler]]): 组件类
Returns:
bool: 是否注册成功
"""
component_name = component_info.name
component_type = component_info.component_type
plugin_name = getattr(component_info, "plugin_name", "unknown")
if "." in component_name:
logger.error(f"组件名称 '{component_name}' 包含非法字符 '.',请使用下划线替代")
return False
if "." in plugin_name:
logger.error(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
return False
namespaced_name = f"{component_type}.{component_name}"
if namespaced_name in self._components:
existing_info = self._components[namespaced_name]
existing_plugin = getattr(existing_info, "plugin_name", "unknown")
logger.warning(
f"组件名冲突: '{plugin_name}' 插件的 {component_type} 类型组件 '{component_name}' 已被插件 '{existing_plugin}' 注册,跳过此组件注册"
)
return False
self._components[namespaced_name] = component_info # 注册到通用注册表(使用命名空间化的名称)
self._components_by_type[component_type][component_name] = component_info # 类型内部仍使用原名
self._components_classes[namespaced_name] = component_class
# 根据组件类型进行特定注册(使用原始名称)
ret = False
match component_type:
case ComponentType.ACTION:
assert isinstance(component_info, ActionInfo)
assert issubclass(component_class, BaseAction)
ret = self._register_action_component(component_info, component_class)
case ComponentType.COMMAND:
assert isinstance(component_info, CommandInfo)
assert issubclass(component_class, BaseCommand)
ret = self._register_command_component(component_info, component_class)
case ComponentType.TOOL:
assert isinstance(component_info, ToolInfo)
assert issubclass(component_class, BaseTool)
ret = self._register_tool_component(component_info, component_class)
case ComponentType.EVENT_HANDLER:
assert isinstance(component_info, EventHandlerInfo)
assert issubclass(component_class, BaseEventHandler)
ret = self._register_event_handler_component(component_info, component_class)
case _:
logger.warning(f"未知组件类型: {component_type}")
if not ret:
return False
logger.debug(
f"已注册{component_type}组件: '{component_name}' -> '{namespaced_name}' "
f"({component_class.__name__}) [插件: {plugin_name}]"
)
return True
def _register_action_component(self, action_info: ActionInfo, action_class: Type[BaseAction]) -> bool:
"""注册Action组件到Action特定注册表"""
if not (action_name := action_info.name):
logger.error(f"Action组件 {action_class.__name__} 必须指定名称")
return False
if not isinstance(action_info, ActionInfo) or not issubclass(action_class, BaseAction):
logger.error(f"注册失败: {action_name} 不是有效的Action")
return False
self._action_registry[action_name] = action_class
# 如果启用,添加到默认动作集
if action_info.enabled:
self._default_actions[action_name] = action_info
return True
def _register_command_component(self, command_info: CommandInfo, command_class: Type[BaseCommand]) -> bool:
"""注册Command组件到Command特定注册表"""
if not (command_name := command_info.name):
logger.error(f"Command组件 {command_class.__name__} 必须指定名称")
return False
if not isinstance(command_info, CommandInfo) or not issubclass(command_class, BaseCommand):
logger.error(f"注册失败: {command_name} 不是有效的Command")
return False
self._command_registry[command_name] = command_class
# 如果启用了且有匹配模式
if command_info.enabled and command_info.command_pattern:
pattern = re.compile(command_info.command_pattern, re.IGNORECASE | re.DOTALL)
if pattern not in self._command_patterns:
self._command_patterns[pattern] = command_name
else:
logger.warning(
f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令"
)
return True
def _register_tool_component(self, tool_info: ToolInfo, tool_class: Type[BaseTool]) -> bool:
"""注册Tool组件到Tool特定注册表"""
tool_name = tool_info.name
self._tool_registry[tool_name] = tool_class
# 如果是llm可用的且启用的工具,添加到 llm可用工具列表
if tool_info.enabled:
self._llm_available_tools[tool_name] = tool_class
return True
def _register_event_handler_component(
self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]
) -> bool:
if not (handler_name := handler_info.name):
logger.error(f"EventHandler组件 {handler_class.__name__} 必须指定名称")
return False
if not isinstance(handler_info, EventHandlerInfo) or not issubclass(handler_class, BaseEventHandler):
logger.error(f"注册失败: {handler_name} 不是有效的EventHandler")
return False
self._event_handler_registry[handler_name] = handler_class
if not handler_info.enabled:
logger.warning(f"EventHandler组件 {handler_name} 未启用")
return True # 未启用,但是也是注册成功
from .events_manager import events_manager # 延迟导入防止循环导入问题
if events_manager.register_event_subscriber(handler_info, handler_class):
self._enabled_event_handlers[handler_name] = handler_class
return True
else:
logger.error(f"注册事件处理器 {handler_name} 失败")
return False
# === 组件移除相关 ===
async def remove_component(self, component_name: str, component_type: ComponentType, plugin_name: str) -> bool:
target_component_class = self.get_component_class(component_name, component_type)
if not target_component_class:
logger.warning(f"组件 {component_name} 未注册,无法移除")
return False
try:
match component_type:
case ComponentType.ACTION:
self._action_registry.pop(component_name)
self._default_actions.pop(component_name)
case ComponentType.COMMAND:
self._command_registry.pop(component_name)
keys_to_remove = [k for k, v in self._command_patterns.items() if v == component_name]
for key in keys_to_remove:
self._command_patterns.pop(key)
case ComponentType.TOOL:
self._tool_registry.pop(component_name)
self._llm_available_tools.pop(component_name)
case ComponentType.EVENT_HANDLER:
from .events_manager import events_manager # 延迟导入防止循环导入问题
self._event_handler_registry.pop(component_name)
self._enabled_event_handlers.pop(component_name)
await events_manager.unregister_event_subscriber(component_name)
namespaced_name = f"{component_type}.{component_name}"
self._components.pop(namespaced_name)
self._components_by_type[component_type].pop(component_name)
self._components_classes.pop(namespaced_name)
logger.info(f"组件 {component_name} 已移除")
return True
except KeyError as e:
logger.warning(f"移除组件时未找到组件: {component_name}, 发生错误: {e}")
return False
except Exception as e:
logger.error(f"移除组件 {component_name} 时发生错误: {e}")
return False
async def remove_components_by_plugin(self, plugin_name: str) -> int:
"""移除某插件注册的所有组件。"""
targets = [
(component_info.name, component_info.component_type)
for component_info in self._components.values()
if component_info.plugin_name == plugin_name
]
removed_count = 0
for component_name, component_type in targets:
if await self.remove_component(component_name, component_type, plugin_name):
removed_count += 1
if removed_count:
logger.info(f"已移除插件 {plugin_name} 的组件数量: {removed_count}")
return removed_count
def remove_plugin_registry(self, plugin_name: str) -> bool:
"""移除插件注册信息
Args:
plugin_name: 插件名称
Returns:
bool: 是否成功移除
"""
if plugin_name not in self._plugins:
logger.warning(f"插件 {plugin_name} 未注册,无法移除")
return False
del self._plugins[plugin_name]
self.remove_workflow_steps_by_plugin(plugin_name)
logger.info(f"插件 {plugin_name} 已移除")
return True
# === Workflow step 注册与查询 ===
def register_workflow_step(self, step_info: WorkflowStepInfo, step_handler: Callable[..., Any]) -> bool:
"""注册workflow步骤。"""
if not step_info.name or not step_info.plugin_name:
logger.error("workflow step 注册失败: step名称或插件名称为空")
return False
if "." in step_info.name:
logger.error(f"workflow step 名称 '{step_info.name}' 包含非法字符 '.',请使用下划线替代")
return False
if "." in step_info.plugin_name:
logger.error(f"workflow step 所属插件名称 '{step_info.plugin_name}' 包含非法字符 '.',请使用下划线替代")
return False
full_name = step_info.full_name
stage_registry = self._workflow_steps.get(step_info.stage)
if stage_registry is None:
logger.error(f"workflow step 注册失败: 未知阶段 {step_info.stage}")
return False
if full_name in stage_registry:
logger.warning(f"workflow step 已存在,跳过注册: {full_name}")
return False
stage_registry[full_name] = step_info
self._workflow_step_handlers[full_name] = step_handler
logger.debug(f"已注册workflow step: {full_name} @ {step_info.stage}")
return True
def get_steps_by_stage(self, stage: WorkflowStage, enabled_only: bool = False) -> Dict[str, WorkflowStepInfo]:
"""获取某阶段的workflow步骤。"""
steps = self._workflow_steps.get(stage, {})
if enabled_only:
return {name: info for name, info in steps.items() if info.enabled}
return steps.copy()
def get_workflow_step(self, step_name: str, stage: Optional[WorkflowStage] = None) -> Optional[WorkflowStepInfo]:
"""获取workflow step信息。
step_name支持两种
- full_name: plugin_name.step_name
- short_name: step_name若有冲突返回第一个并告警
"""
candidates: List[WorkflowStepInfo] = []
target_stages = [stage] if stage else list(WorkflowStage)
for current_stage in target_stages:
current_steps = self._workflow_steps.get(current_stage, {})
if "." in step_name:
if step_info := current_steps.get(step_name):
return step_info
continue
candidates.extend([step_info for step_info in current_steps.values() if step_info.name == step_name])
if len(candidates) == 1:
return candidates[0]
if len(candidates) > 1:
logger.warning(f"workflow step 名称 '{step_name}' 存在多义,使用第一个匹配: {candidates[0].full_name}")
return candidates[0]
return None
def get_workflow_step_handler(
self, step_name: str, stage: Optional[WorkflowStage] = None
) -> Optional[Callable[..., Any]]:
"""获取workflow step处理函数。"""
if "." in step_name:
return self._workflow_step_handlers.get(step_name)
if step_info := self.get_workflow_step(step_name, stage):
return self._workflow_step_handlers.get(step_info.full_name)
return None
def enable_workflow_step(self, step_name: str, stage: Optional[WorkflowStage] = None) -> bool:
"""启用workflow step。"""
step_info = self.get_workflow_step(step_name, stage)
if not step_info:
logger.warning(f"workflow step 未注册,无法启用: {step_name}")
return False
step_info.enabled = True
logger.info(f"workflow step 已启用: {step_info.full_name}")
return True
def disable_workflow_step(self, step_name: str, stage: Optional[WorkflowStage] = None) -> bool:
"""禁用workflow step。"""
step_info = self.get_workflow_step(step_name, stage)
if not step_info:
logger.warning(f"workflow step 未注册,无法禁用: {step_name}")
return False
step_info.enabled = False
logger.info(f"workflow step 已禁用: {step_info.full_name}")
return True
def remove_workflow_steps_by_plugin(self, plugin_name: str) -> int:
"""移除某插件注册的所有workflow step。"""
removed_count = 0
for stage in WorkflowStage:
stage_registry = self._workflow_steps.get(stage, {})
target_names = [name for name, info in stage_registry.items() if info.plugin_name == plugin_name]
for full_name in target_names:
stage_registry.pop(full_name, None)
self._workflow_step_handlers.pop(full_name, None)
removed_count += 1
if removed_count:
logger.info(f"已移除插件 {plugin_name} 的 workflow step 数量: {removed_count}")
return removed_count
# === 组件全局启用/禁用方法 ===
def enable_component(self, component_name: str, component_type: ComponentType) -> bool:
"""全局的启用某个组件
Parameters:
component_name: 组件名称
component_type: 组件类型
Returns:
bool: 启用成功返回True失败返回False
"""
target_component_class = self.get_component_class(component_name, component_type)
target_component_info = self.get_component_info(component_name, component_type)
if not target_component_class or not target_component_info:
logger.warning(f"组件 {component_name} 未注册,无法启用")
return False
target_component_info.enabled = True
match component_type:
case ComponentType.ACTION:
assert isinstance(target_component_info, ActionInfo)
self._default_actions[component_name] = target_component_info
case ComponentType.COMMAND:
assert isinstance(target_component_info, CommandInfo)
pattern = target_component_info.command_pattern
self._command_patterns[re.compile(pattern)] = component_name
case ComponentType.TOOL:
assert isinstance(target_component_info, ToolInfo)
assert issubclass(target_component_class, BaseTool)
self._llm_available_tools[component_name] = target_component_class
case ComponentType.EVENT_HANDLER:
assert isinstance(target_component_info, EventHandlerInfo)
assert issubclass(target_component_class, BaseEventHandler)
self._enabled_event_handlers[component_name] = target_component_class
from .events_manager import events_manager # 延迟导入防止循环导入问题
events_manager.register_event_subscriber(target_component_info, target_component_class)
namespaced_name = f"{component_type}.{component_name}"
self._components[namespaced_name].enabled = True
self._components_by_type[component_type][component_name].enabled = True
logger.info(f"组件 {component_name} 已启用")
return True
async def disable_component(self, component_name: str, component_type: ComponentType) -> bool:
"""全局的禁用某个组件
Parameters:
component_name: 组件名称
component_type: 组件类型
Returns:
bool: 禁用成功返回True失败返回False
"""
target_component_class = self.get_component_class(component_name, component_type)
target_component_info = self.get_component_info(component_name, component_type)
if not target_component_class or not target_component_info:
logger.warning(f"组件 {component_name} 未注册,无法禁用")
return False
target_component_info.enabled = False
try:
match component_type:
case ComponentType.ACTION:
self._default_actions.pop(component_name)
case ComponentType.COMMAND:
self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != component_name}
case ComponentType.TOOL:
self._llm_available_tools.pop(component_name)
case ComponentType.EVENT_HANDLER:
self._enabled_event_handlers.pop(component_name)
from .events_manager import events_manager # 延迟导入防止循环导入问题
await events_manager.unregister_event_subscriber(component_name)
self._components[component_name].enabled = False
self._components_by_type[component_type][component_name].enabled = False
logger.info(f"组件 {component_name} 已禁用")
return True
except KeyError as e:
logger.warning(f"禁用组件时未找到组件或已禁用: {component_name}, 发生错误: {e}")
return False
except Exception as e:
logger.error(f"禁用组件 {component_name} 时发生错误: {e}")
return False
# === 组件查询方法 ===
def get_component_info(
self, component_name: str, component_type: Optional[ComponentType] = None
) -> Optional[ComponentInfo]:
# sourcery skip: class-extract-method
"""获取组件信息,支持自动命名空间解析
Args:
component_name: 组件名称,可以是原始名称或命名空间化的名称
component_type: 组件类型,如果提供则优先在该类型中查找
Returns:
Optional[ComponentInfo]: 组件信息或None
"""
# 1. 如果已经是命名空间化的名称,直接查找
if "." in component_name:
return self._components.get(component_name)
# 2. 如果指定了组件类型,构造命名空间化的名称查找
if component_type:
namespaced_name = f"{component_type}.{component_name}"
return self._components.get(namespaced_name)
# 3. 如果没有指定类型,尝试在所有命名空间中查找
candidates = []
for namespace_prefix in [types.value for types in ComponentType]:
namespaced_name = f"{namespace_prefix}.{component_name}"
if component_info := self._components.get(namespaced_name):
candidates.append((namespace_prefix, namespaced_name, component_info))
if len(candidates) == 1:
# 只有一个匹配,直接返回
return candidates[0][2]
elif len(candidates) > 1:
# 多个匹配,记录警告并返回第一个
namespaces = [ns for ns, _, _ in candidates]
logger.warning(
f"组件名称 '{component_name}' 在多个命名空间中存在: {namespaces},使用第一个匹配项: {candidates[0][1]}"
)
return candidates[0][2]
# 4. 都没找到
return None
def get_component_class(
self,
component_name: str,
component_type: Optional[ComponentType] = None,
) -> Optional[Union[Type[BaseCommand], Type[BaseAction], Type[BaseEventHandler], Type[BaseTool]]]:
"""获取组件类,支持自动命名空间解析
Args:
component_name: 组件名称,可以是原始名称或命名空间化的名称
component_type: 组件类型,如果提供则优先在该类型中查找
Returns:
Optional[Union[BaseCommand, BaseAction]]: 组件类或None
"""
# 1. 如果已经是命名空间化的名称,直接查找
if "." in component_name:
return self._components_classes.get(component_name)
# 2. 如果指定了组件类型,构造命名空间化的名称查找
if component_type:
namespaced_name = f"{component_type.value}.{component_name}"
return self._components_classes.get(namespaced_name)
# 3. 如果没有指定类型,尝试在所有命名空间中查找
candidates = []
for namespace_prefix in [types.value for types in ComponentType]:
namespaced_name = f"{namespace_prefix}.{component_name}"
if component_class := self._components_classes.get(namespaced_name):
candidates.append((namespace_prefix, namespaced_name, component_class))
if len(candidates) == 1:
# 只有一个匹配,直接返回
_, full_name, cls = candidates[0]
logger.debug(f"自动解析组件: '{component_name}' -> '{full_name}'")
return cls
elif len(candidates) > 1:
# 多个匹配,记录警告并返回第一个
namespaces = [ns for ns, _, _ in candidates]
logger.warning(
f"组件名称 '{component_name}' 在多个命名空间中存在: {namespaces},使用第一个匹配项: {candidates[0][1]}"
)
return candidates[0][2]
# 4. 都没找到
return None
def get_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
"""获取指定类型的所有组件"""
return self._components_by_type.get(component_type, {}).copy()
def get_enabled_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
"""获取指定类型的所有启用组件"""
components = self.get_components_by_type(component_type)
return {name: info for name, info in components.items() if info.enabled}
# === Action特定查询方法 ===
def get_action_registry(self) -> Dict[str, Type[BaseAction]]:
"""获取Action注册表"""
return self._action_registry.copy()
def get_registered_action_info(self, action_name: str) -> Optional[ActionInfo]:
"""获取Action信息"""
info = self.get_component_info(action_name, ComponentType.ACTION)
return info if isinstance(info, ActionInfo) else None
def get_default_actions(self) -> Dict[str, ActionInfo]:
"""获取默认动作集"""
return self._default_actions.copy()
# === Command特定查询方法 ===
def get_command_registry(self) -> Dict[str, Type[BaseCommand]]:
"""获取Command注册表"""
return self._command_registry.copy()
def get_registered_command_info(self, command_name: str) -> Optional[CommandInfo]:
"""获取Command信息"""
info = self.get_component_info(command_name, ComponentType.COMMAND)
return info if isinstance(info, CommandInfo) else None
def get_command_patterns(self) -> Dict[Pattern, str]:
"""获取Command模式注册表"""
return self._command_patterns.copy()
def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, CommandInfo]]:
# sourcery skip: use-named-expression, use-next
"""根据文本查找匹配的命令
Args:
text: 输入文本
Returns:
Tuple: (命令类, 匹配的命名组, 是否拦截消息, 插件名) 或 None
"""
candidates = [pattern for pattern in self._command_patterns if pattern.match(text)]
if not candidates:
return None
if len(candidates) > 1:
logger.warning(f"文本 '{text}' 匹配到多个命令模式: {candidates},使用第一个匹配")
command_name = self._command_patterns[candidates[0]]
command_info: CommandInfo = self.get_registered_command_info(command_name) # type: ignore
return (
self._command_registry[command_name],
candidates[0].match(text).groupdict(), # type: ignore
command_info,
)
# === Tool 特定查询方法 ===
def get_tool_registry(self) -> Dict[str, Type[BaseTool]]:
"""获取Tool注册表"""
return self._tool_registry.copy()
def get_llm_available_tools(self) -> Dict[str, Type[BaseTool]]:
"""获取LLM可用的Tool列表"""
return self._llm_available_tools.copy()
def get_registered_tool_info(self, tool_name: str) -> Optional[ToolInfo]:
"""获取Tool信息
Args:
tool_name: 工具名称
Returns:
ToolInfo: 工具信息对象,如果工具不存在则返回 None
"""
info = self.get_component_info(tool_name, ComponentType.TOOL)
return info if isinstance(info, ToolInfo) else None
# === EventHandler 特定查询方法 ===
def get_event_handler_registry(self) -> Dict[str, Type[BaseEventHandler]]:
"""获取事件处理器注册表"""
return self._event_handler_registry.copy()
def get_registered_event_handler_info(self, handler_name: str) -> Optional[EventHandlerInfo]:
"""获取事件处理器信息"""
info = self.get_component_info(handler_name, ComponentType.EVENT_HANDLER)
return info if isinstance(info, EventHandlerInfo) else None
def get_enabled_event_handlers(self) -> Dict[str, Type[BaseEventHandler]]:
"""获取启用的事件处理器"""
return self._enabled_event_handlers.copy()
# === 插件查询方法 ===
def get_plugin_info(self, plugin_name: str) -> Optional[PluginInfo]:
"""获取插件信息"""
return self._plugins.get(plugin_name)
def get_all_plugins(self) -> Dict[str, PluginInfo]:
"""获取所有插件"""
return self._plugins.copy()
# def get_enabled_plugins(self) -> Dict[str, PluginInfo]:
# """获取所有启用的插件"""
# return {name: info for name, info in self._plugins.items() if info.enabled}
def get_plugin_components(self, plugin_name: str) -> List[ComponentInfo]:
"""获取插件的所有组件"""
plugin_info = self.get_plugin_info(plugin_name)
return plugin_info.components if plugin_info else []
def get_plugin_config(self, plugin_name: str) -> Optional[dict]:
"""获取插件配置
Args:
plugin_name: 插件名称
Returns:
Optional[dict]: 插件配置字典或None
"""
# 从插件管理器获取插件实例的配置
from src.plugin_system.core.plugin_manager import plugin_manager
plugin_instance = plugin_manager.get_plugin_instance(plugin_name)
return plugin_instance.config if plugin_instance else None
def get_registry_stats(self) -> Dict[str, Any]:
"""获取注册中心统计信息"""
action_components: int = 0
command_components: int = 0
tool_components: int = 0
events_handlers: int = 0
for component in self._components.values():
if component.component_type == ComponentType.ACTION:
action_components += 1
elif component.component_type == ComponentType.COMMAND:
command_components += 1
elif component.component_type == ComponentType.TOOL:
tool_components += 1
elif component.component_type == ComponentType.EVENT_HANDLER:
events_handlers += 1
workflow_step_count = sum(len(steps) for steps in self._workflow_steps.values())
enabled_workflow_step_count = sum(
len([step for step in steps.values() if step.enabled]) for steps in self._workflow_steps.values()
)
return {
"action_components": action_components,
"command_components": command_components,
"tool_components": tool_components,
"event_handlers": events_handlers,
"total_components": len(self._components),
"total_plugins": len(self._plugins),
"components_by_type": {
component_type.value: len(components) for component_type, components in self._components_by_type.items()
},
"enabled_components": len([c for c in self._components.values() if c.enabled]),
"enabled_plugins": len([p for p in self._plugins.values() if p.enabled]),
"workflow_steps": workflow_step_count,
"enabled_workflow_steps": enabled_workflow_step_count,
"workflow_steps_by_stage": {stage.value: len(steps) for stage, steps in self._workflow_steps.items()},
}
component_registry = ComponentRegistry()

View File

@@ -1,512 +0,0 @@
import asyncio
import contextlib
from typing import List, Dict, Optional, Type, Tuple, TYPE_CHECKING
from src.chat.message_receive.message import MessageSending, SessionMessage
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.common.logger import get_logger
from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages, CustomEventHandlerResult
from src.plugin_system.base.base_events_handler import BaseEventHandler
from src.plugin_system.base.workflow_types import WorkflowContext, WorkflowStepResult
from .global_announcement_manager import global_announcement_manager
from .workflow_engine import workflow_engine
if TYPE_CHECKING:
from src.common.data_models.llm_data_model import LLMGenerationDataModel
logger = get_logger("events_manager")
class EventsManager:
def __init__(self):
# 有权重的 events 订阅者注册表
self._events_subscribers: Dict[EventType | str, List[BaseEventHandler]] = {}
self._handler_mapping: Dict[str, Type[BaseEventHandler]] = {} # 事件处理器映射表
self._handler_tasks: Dict[str, List[asyncio.Task]] = {} # 事件处理器正在处理的任务
self._events_result_history: Dict[EventType | str, List[CustomEventHandlerResult]] = {} # 事件的结果历史记录
self._history_enable_map: Dict[EventType | str, bool] = {} # 是否启用历史记录的映射表同时作为events注册表
# 事件注册(同时作为注册样例)
for event in EventType:
self.register_event(event, enable_history_result=False)
def register_event(self, event_type: EventType | str, enable_history_result: bool = False):
if event_type in self._events_subscribers:
raise ValueError(f"事件类型 {event_type} 已存在")
self._events_subscribers[event_type] = []
self._history_enable_map[event_type] = enable_history_result
if enable_history_result:
self._events_result_history[event_type] = []
def register_event_subscriber(self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]) -> bool:
"""注册事件处理器
Args:
handler_info (EventHandlerInfo): 事件处理器信息
handler_class (Type[BaseEventHandler]): 事件处理器类
Returns:
bool: 是否注册成功
"""
if not issubclass(handler_class, BaseEventHandler):
logger.error(f"{handler_class.__name__} 不是 BaseEventHandler 的子类")
return False
handler_name = handler_info.name
if handler_name in self._handler_mapping:
logger.warning(f"事件处理器 {handler_name} 已存在,跳过注册")
return False
if handler_info.event_type not in self._history_enable_map:
if isinstance(handler_info.event_type, str):
self.register_event(handler_info.event_type, enable_history_result=False)
logger.info(f"自动注册自定义事件类型: {handler_info.event_type}")
else:
logger.error(f"事件类型 {handler_info.event_type} 未注册,无法为其注册处理器 {handler_name}")
return False
self._handler_mapping[handler_name] = handler_class
return self._insert_event_handler(handler_class, handler_info)
async def handle_mai_events(
self,
event_type: EventType | str,
message: Optional[SessionMessage | MessageSending | MaiMessages] = None,
llm_prompt: Optional[str] = None,
llm_response: Optional["LLMGenerationDataModel"] = None,
stream_id: Optional[str] = None,
action_usage: Optional[List[str]] = None,
) -> Tuple[bool, Optional[MaiMessages]]:
"""
处理所有事件,根据事件类型分发给订阅的处理器。
"""
from src.plugin_system.core import component_registry
continue_flag = True
# 1. 准备消息
transformed_message = self._prepare_message(
event_type, message, llm_prompt, llm_response, stream_id, action_usage # type: ignore[arg-type]
)
if transformed_message:
transformed_message = transformed_message.deepcopy()
# 2. 获取并遍历处理器
handlers = self._events_subscribers.get(event_type, [])
if not handlers:
return True, None
current_stream_id = transformed_message.stream_id if transformed_message else None
modified_message: Optional[MaiMessages] = None
for handler in handlers:
# 3. 前置检查和配置加载
if (
current_stream_id
and handler.handler_name
in global_announcement_manager.get_disabled_chat_event_handlers(current_stream_id)
):
continue
# 统一加载插件配置
plugin_config = component_registry.get_plugin_config(handler.plugin_name) or {}
handler.set_plugin_config(plugin_config)
# 4. 根据类型分发任务
if (
handler.intercept_message or event_type == EventType.ON_STOP
): # 让ON_STOP的所有事件处理器都发挥作用防止还没执行即被取消
# 阻塞执行,并更新 continue_flag
should_continue, modified_message = await self._dispatch_intercepting_handler_task(
handler, event_type, modified_message or transformed_message
)
continue_flag = continue_flag and should_continue
else:
# 异步执行,不阻塞
self._dispatch_handler_task(handler, event_type, transformed_message)
# 桥接到新版本插件运行时
continue_flag, modified_message = await self._bridge_to_new_runtime(
event_type, continue_flag, modified_message or transformed_message
)
return continue_flag, modified_message
async def handle_workflow_message(
self,
message: Optional[SessionMessage | MessageSending | MaiMessages] = None,
stream_id: Optional[str] = None,
action_usage: Optional[List[str]] = None,
context: Optional[WorkflowContext] = None,
) -> Tuple[WorkflowStepResult, Optional[MaiMessages], WorkflowContext]:
"""执行线性workflow消息流转MVP兼容入口"""
initial_message = self._prepare_message(EventType.ON_MESSAGE_PRE_PROCESS, message=message, stream_id=stream_id)
async def _dispatch(
event_type: EventType | str,
workflow_message: Optional[MaiMessages],
workflow_stream_id: Optional[str],
workflow_action_usage: Optional[List[str]],
) -> Tuple[bool, Optional[MaiMessages]]:
return await self.handle_mai_events(
event_type=event_type,
message=workflow_message,
stream_id=workflow_stream_id,
action_usage=workflow_action_usage,
)
return await workflow_engine.execute_linear(
dispatch_event=_dispatch,
message=initial_message,
stream_id=stream_id,
action_usage=action_usage,
context=context,
)
async def cancel_handler_tasks(self, handler_name: str) -> None:
tasks_to_be_cancelled = self._handler_tasks.get(handler_name, [])
if remaining_tasks := [task for task in tasks_to_be_cancelled if not task.done()]:
for task in remaining_tasks:
task.cancel()
try:
await asyncio.wait_for(asyncio.gather(*remaining_tasks, return_exceptions=True), timeout=5)
logger.info(f"已取消事件处理器 {handler_name} 的所有任务")
except asyncio.TimeoutError:
logger.warning(f"取消事件处理器 {handler_name} 的任务超时,开始强制取消")
except Exception as e:
logger.error(f"取消事件处理器 {handler_name} 的任务时发生异常: {e}")
if handler_name in self._handler_tasks:
del self._handler_tasks[handler_name]
async def unregister_event_subscriber(self, handler_name: str) -> bool:
"""取消注册事件处理器"""
if handler_name not in self._handler_mapping:
logger.warning(f"事件处理器 {handler_name} 不存在,无法取消注册")
return False
await self.cancel_handler_tasks(handler_name)
handler_class = self._handler_mapping.pop(handler_name)
if not self._remove_event_handler_instance(handler_class):
return False
logger.info(f"事件处理器 {handler_name} 已成功取消注册")
return True
async def get_event_result_history(self, event_type: EventType | str) -> List[CustomEventHandlerResult]:
"""获取事件的结果历史记录"""
if event_type == EventType.UNKNOWN:
raise ValueError("未知事件类型")
if event_type not in self._history_enable_map:
raise ValueError(f"事件类型 {event_type} 未注册")
if not self._history_enable_map[event_type]:
raise ValueError(f"事件类型 {event_type} 的历史记录未启用")
return self._events_result_history[event_type]
async def clear_event_result_history(self, event_type: EventType | str) -> None:
"""清空事件的结果历史记录"""
if event_type == EventType.UNKNOWN:
raise ValueError("未知事件类型")
if event_type not in self._history_enable_map:
raise ValueError(f"事件类型 {event_type} 未注册")
if not self._history_enable_map[event_type]:
raise ValueError(f"事件类型 {event_type} 的历史记录未启用")
self._events_result_history[event_type] = []
def _insert_event_handler(self, handler_class: Type[BaseEventHandler], handler_info: EventHandlerInfo) -> bool:
"""插入事件处理器到对应的事件类型列表中并设置其插件配置"""
if handler_class.event_type == EventType.UNKNOWN:
logger.error(f"事件处理器 {handler_class.__name__} 的事件类型未知,无法注册")
return False
if handler_class.event_type not in self._events_subscribers:
self._events_subscribers[handler_class.event_type] = []
handler_instance = handler_class()
handler_instance.set_plugin_name(handler_info.plugin_name or "unknown")
self._events_subscribers[handler_class.event_type].append(handler_instance)
self._events_subscribers[handler_class.event_type].sort(key=lambda x: x.weight, reverse=True)
return True
def _remove_event_handler_instance(self, handler_class: Type[BaseEventHandler]) -> bool:
"""从事件类型列表中移除事件处理器"""
display_handler_name = handler_class.handler_name or handler_class.__name__
if handler_class.event_type == EventType.UNKNOWN:
logger.warning(f"事件处理器 {display_handler_name} 的事件类型未知,不存在于处理器列表中")
return False
handlers = self._events_subscribers[handler_class.event_type]
for i, handler in enumerate(handlers):
if isinstance(handler, handler_class):
del handlers[i]
logger.debug(f"事件处理器 {display_handler_name} 已移除")
return True
logger.warning(f"未找到事件处理器 {display_handler_name},无法移除")
return False
def _transform_event_message(
self,
message: SessionMessage | MessageSending,
llm_prompt: Optional[str] = None,
llm_response: Optional["LLMGenerationDataModel"] = None,
) -> MaiMessages:
"""转换事件消息格式"""
from maim_message import Seg
# 直接赋值部分内容
transformed_message = MaiMessages(
llm_prompt=llm_prompt,
llm_response_content=llm_response.content if llm_response else None,
llm_response_reasoning=llm_response.reasoning if llm_response else None,
llm_response_model=llm_response.model if llm_response else None,
llm_response_tool_call=llm_response.tool_calls if llm_response else None,
raw_message=message.processed_plain_text or "",
additional_data={},
)
# 消息段处理
if isinstance(message, MessageSending):
if message.message_segment.type == "seglist":
transformed_message.message_segments = list(message.message_segment.data) # type: ignore
else:
transformed_message.message_segments = [message.message_segment]
else:
# SessionMessage: 使用 processed_plain_text 构造简单段
transformed_message.message_segments = [Seg(type="text", data=message.processed_plain_text or "")]
# stream_id 处理
transformed_message.stream_id = message.session_id if hasattr(message, "session_id") else ""
# 处理后文本
transformed_message.plain_text = message.processed_plain_text
# 基本信息
if isinstance(message, MessageSending):
transformed_message.message_base_info["platform"] = message.platform
if message.session.group_id:
transformed_message.is_group_message = True
group_name = ""
if message.session.context and message.session.context.message and message.session.context.message.message_info.group_info:
group_name = message.session.context.message.message_info.group_info.group_name
transformed_message.message_base_info.update({
"group_id": message.session.group_id,
"group_name": group_name,
})
transformed_message.message_base_info.update({
"user_id": message.bot_user_info.user_id,
"user_cardname": message.bot_user_info.user_cardname,
"user_nickname": message.bot_user_info.user_nickname,
})
if not transformed_message.is_group_message:
transformed_message.is_private_message = True
elif hasattr(message, "message_info") and message.message_info:
if message.platform:
transformed_message.message_base_info["platform"] = message.platform
if message.message_info.group_info:
transformed_message.is_group_message = True
transformed_message.message_base_info.update({
"group_id": message.message_info.group_info.group_id,
"group_name": message.message_info.group_info.group_name,
})
if message.message_info.user_info:
if not transformed_message.is_group_message:
transformed_message.is_private_message = True
transformed_message.message_base_info.update({
"user_id": message.message_info.user_info.user_id,
"user_cardname": message.message_info.user_info.user_cardname,
"user_nickname": message.message_info.user_info.user_nickname,
})
return transformed_message
def _build_message_from_stream(
self, stream_id: str, llm_prompt: Optional[str] = None, llm_response: Optional["LLMGenerationDataModel"] = None
) -> MaiMessages:
"""从流ID构建消息"""
session = _chat_manager.get_session_by_session_id(stream_id)
assert session, f"未找到流ID为 {stream_id} 的会话"
message = session.context.message
return self._transform_event_message(message, llm_prompt, llm_response)
def _transform_event_without_message(
self,
stream_id: str,
llm_prompt: Optional[str] = None,
llm_response: Optional["LLMGenerationDataModel"] = None,
action_usage: Optional[List[str]] = None,
) -> MaiMessages:
"""没有message对象时进行转换"""
session = _chat_manager.get_session_by_session_id(stream_id)
assert session, f"未找到流ID为 {stream_id} 的会话"
return MaiMessages(
stream_id=stream_id,
llm_prompt=llm_prompt,
llm_response_content=(llm_response.content if llm_response else None),
llm_response_reasoning=(llm_response.reasoning if llm_response else None),
llm_response_model=(llm_response.model if llm_response else None),
llm_response_tool_call=(llm_response.tool_calls if llm_response else None),
is_group_message=session.is_group_session,
is_private_message=not session.is_group_session,
action_usage=action_usage,
additional_data={"response_is_processed": True},
)
async def _bridge_to_new_runtime(
self,
event_type: EventType | str,
continue_flag: bool,
message: Optional[MaiMessages],
) -> Tuple[bool, Optional[MaiMessages]]:
"""将事件桥接到新版本插件运行时
如果旧 handler 已经 abortcontinue_flag=False直接跳过。
"""
if not continue_flag:
return continue_flag, message
try:
from src.plugin_runtime.integration import get_plugin_runtime_manager
prm = get_plugin_runtime_manager()
if not prm.is_running:
return continue_flag, message
event_value = event_type.value if isinstance(event_type, EventType) else str(event_type)
message_dict = message.to_dict() if message and hasattr(message, "to_dict") else None
new_continue, new_msg_dict = await prm.bridge_event(
event_type_value=event_value,
message_dict=message_dict,
)
# 新运行时返回 abort 则合并
if not new_continue:
continue_flag = False
except Exception as e:
logger.warning(f"桥接事件到新运行时失败: {e}")
return continue_flag, message
def _prepare_message(
self,
event_type: EventType | str,
message: Optional[SessionMessage | MessageSending | MaiMessages] = None,
llm_prompt: Optional[str] = None,
llm_response: Optional["LLMGenerationDataModel"] = None,
stream_id: Optional[str] = None,
action_usage: Optional[List[str]] = None,
) -> Optional[MaiMessages]:
"""根据事件类型和输入,准备和转换消息对象。"""
if isinstance(message, MaiMessages):
return message.deepcopy()
if message:
return self._transform_event_message(message, llm_prompt, llm_response)
if event_type not in [EventType.ON_START, EventType.ON_STOP]:
assert stream_id, "如果没有消息,必须为非启动/关闭事件提供流ID"
if event_type in [EventType.ON_MESSAGE, EventType.ON_PLAN, EventType.POST_LLM, EventType.AFTER_LLM]:
return self._build_message_from_stream(stream_id, llm_prompt, llm_response)
else:
return self._transform_event_without_message(stream_id, llm_prompt, llm_response, action_usage)
return None # ON_START, ON_STOP事件没有消息体
def _dispatch_handler_task(
self, handler: BaseEventHandler, event_type: EventType | str, message: Optional[MaiMessages] = None
):
"""分发一个非阻塞(异步)的事件处理任务。"""
if event_type == EventType.UNKNOWN:
raise ValueError("未知事件类型")
try:
task = asyncio.create_task(handler.execute(message))
task_name = f"{handler.plugin_name}-{handler.handler_name}"
task.set_name(task_name)
task.add_done_callback(lambda t: self._task_done_callback(t, event_type))
self._handler_tasks.setdefault(handler.handler_name, []).append(task)
except Exception as e:
logger.error(f"创建事件处理器任务 {handler.handler_name} 时发生异常: {e}", exc_info=True)
async def _dispatch_intercepting_handler_task(
self, handler: BaseEventHandler, event_type: EventType | str, message: Optional[MaiMessages] = None
) -> Tuple[bool, Optional[MaiMessages]]:
"""分发并等待一个阻塞(同步)的事件处理器,返回是否应继续处理。"""
if event_type == EventType.UNKNOWN:
raise ValueError("未知事件类型")
if event_type not in self._history_enable_map:
raise ValueError(f"事件类型 {event_type} 未注册")
try:
result = await handler.execute(message)
expected_fields = ["success", "continue_processing", "return_message", "custom_result", "modified_message"]
if not isinstance(result, tuple) or len(result) != 5:
if isinstance(result, tuple):
annotated = ", ".join(f"{name}={val!r}" for name, val in zip(expected_fields, result, strict=False))
actual_desc = f"{len(result)} 个元素 ({annotated})"
else:
actual_desc = f"非 tuple 类型: {type(result)}"
logger.error(
f"[{self.__class__.__name__}] EventHandler {handler.handler_name} 返回值不符合预期:\n"
f" 模块来源: {handler.__class__.__module__}.{handler.__class__.__name__}\n"
f" 期望: 5 个元素 ({', '.join(expected_fields)})\n"
f" 实际: {actual_desc}"
)
return True, None
success, continue_processing, return_message, custom_result, modified_message = result
if not success:
logger.error(f"EventHandler {handler.handler_name} 执行失败: {return_message}")
else:
logger.debug(f"EventHandler {handler.handler_name} 执行成功: {return_message}")
if self._history_enable_map[event_type] and custom_result:
self._events_result_history[event_type].append(custom_result)
return continue_processing, modified_message
except KeyError:
logger.error(f"事件 {event_type} 注册的历史记录启用情况与实际不符合")
return True, None
except Exception as e:
logger.error(f"EventHandler {handler.handler_name} 发生异常: {e}", exc_info=True)
return True, None # 发生异常时默认不中断其他处理
def _task_done_callback(
self,
task: asyncio.Task[Tuple[bool, bool, str | None, CustomEventHandlerResult | None, MaiMessages | None]],
event_type: EventType | str,
):
"""任务完成回调"""
task_name = task.get_name() or "Unknown Task"
if event_type == EventType.UNKNOWN:
raise ValueError("未知事件类型")
if event_type not in self._history_enable_map:
raise ValueError(f"事件类型 {event_type} 未注册")
try:
success, _, result, custom_result, _ = task.result() # 忽略是否继续的标志和消息的修改,因为消息本身未被拦截
if success:
logger.debug(f"事件处理任务 {task_name} 已成功完成: {result}")
else:
logger.error(f"事件处理任务 {task_name} 执行失败: {result}")
if self._history_enable_map[event_type] and custom_result:
self._events_result_history[event_type].append(custom_result)
except asyncio.CancelledError:
pass
except KeyError:
logger.error(f"事件 {event_type} 注册的历史记录启用情况与实际不符合")
except Exception as e:
logger.error(f"事件处理任务 {task_name} 发生异常: {e}")
finally:
with contextlib.suppress(ValueError, KeyError):
self._handler_tasks[task_name].remove(task)
events_manager = EventsManager()

View File

@@ -1,634 +0,0 @@
from importlib.util import spec_from_file_location, module_from_spec
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Type
from collections import deque
import os
import traceback
from src.common.logger import get_logger
from src.plugin_system.base.plugin_base import PluginBase
from src.plugin_system.base.component_types import ComponentType
from src.plugin_system.utils.manifest_utils import VersionComparator
from .component_registry import component_registry
from .plugin_service_registry import plugin_service_registry
logger = get_logger("plugin_manager")
class PluginManager:
"""
插件管理器类
负责加载,重载和卸载插件,同时管理插件的所有组件
"""
def __init__(self):
self.plugin_directories: List[str] = [] # 插件根目录列表
self.plugin_classes: Dict[str, Type[PluginBase]] = {} # 全局插件类注册表,插件名 -> 插件类
self.plugin_paths: Dict[str, str] = {} # 记录插件名到目录路径的映射,插件名 -> 目录路径
self.loaded_plugins: Dict[str, PluginBase] = {} # 已加载的插件类实例注册表,插件名 -> 插件类实例
self.failed_plugins: Dict[str, str] = {} # 记录加载失败的插件文件及其错误信息,插件名 -> 错误信息
# 确保插件目录存在
self._ensure_plugin_directories()
logger.info("插件管理器初始化完成")
# === 插件目录管理 ===
def add_plugin_directory(self, directory: str) -> bool:
"""添加插件目录"""
if os.path.exists(directory):
if directory not in self.plugin_directories:
self.plugin_directories.append(directory)
logger.debug(f"已添加插件目录: {directory}")
return True
else:
logger.warning(f"插件不可重复加载: {directory}")
else:
logger.warning(f"插件目录不存在: {directory}")
return False
# === 插件加载管理 ===
def load_all_plugins(self) -> Tuple[int, int]:
"""加载所有插件
Returns:
tuple[int, int]: (插件数量, 组件数量)
"""
logger.debug("开始加载所有插件...")
# 第一阶段:加载所有插件模块(注册插件类)
total_loaded_modules = 0
total_failed_modules = 0
for directory in self.plugin_directories:
loaded, failed = self._load_plugin_modules_from_directory(directory)
total_loaded_modules += loaded
total_failed_modules += failed
logger.debug(f"插件模块加载完成 - 成功: {total_loaded_modules}, 失败: {total_failed_modules}")
# 第二阶段前:根据依赖关系决定插件加载顺序
dependency_graph, missing_dependencies = self._build_plugin_dependency_graph()
sorted_plugins, cycle_plugins = self._resolve_plugin_load_order(dependency_graph)
total_registered = 0
total_failed_registration = 0
# 先处理缺失依赖
for plugin_name, missing in missing_dependencies.items():
if not missing:
continue
missing_dep_names = ", ".join(sorted(missing))
self.failed_plugins[plugin_name] = f"缺少依赖插件: {missing_dep_names}"
logger.error(f"❌ 插件加载失败: {plugin_name} - 缺少依赖插件: {missing_dep_names}")
total_failed_registration += 1
# 再处理循环依赖
for plugin_name in sorted(cycle_plugins):
if plugin_name in missing_dependencies and missing_dependencies[plugin_name]:
continue
self.failed_plugins[plugin_name] = "检测到循环依赖"
logger.error(f"❌ 插件加载失败: {plugin_name} - 检测到循环依赖")
total_failed_registration += 1
# 最后按拓扑序加载可加载插件
for plugin_name in sorted_plugins:
if plugin_name in cycle_plugins:
continue
if plugin_name in missing_dependencies and missing_dependencies[plugin_name]:
continue
load_status, count = self.load_registered_plugin_classes(plugin_name)
if load_status:
total_registered += 1
else:
total_failed_registration += count
self._show_stats(total_registered, total_failed_registration)
return total_registered, total_failed_registration
def load_registered_plugin_classes(self, plugin_name: str) -> Tuple[bool, int]:
# sourcery skip: extract-duplicate-method, extract-method
"""
加载已经注册的插件类
"""
plugin_class = self.plugin_classes.get(plugin_name)
if not plugin_class:
logger.error(f"插件 {plugin_name} 的插件类未注册或不存在")
return False, 1
try:
# 使用记录的插件目录路径
plugin_dir = self.plugin_paths.get(plugin_name)
# 如果没有记录,直接返回失败
if not plugin_dir:
return False, 1
plugin_instance = plugin_class(plugin_dir=plugin_dir) # 实例化插件可能因为缺少manifest而失败
if not plugin_instance:
logger.error(f"插件 {plugin_name} 实例化失败")
return False, 1
# 检查插件是否启用
if not plugin_instance.enable_plugin:
logger.info(f"插件 {plugin_name} 已禁用,跳过加载")
return False, 0
# 检查版本兼容性
is_compatible, compatibility_error = self._check_plugin_version_compatibility(
plugin_name, plugin_instance.manifest_data
)
if not is_compatible:
self.failed_plugins[plugin_name] = compatibility_error
logger.error(f"❌ 插件加载失败: {plugin_name} - {compatibility_error}")
return False, 1
if plugin_instance.register_plugin():
self.loaded_plugins[plugin_name] = plugin_instance
self._show_plugin_components(plugin_name)
return True, 1
else:
self.failed_plugins[plugin_name] = "插件注册失败"
logger.error(f"❌ 插件注册失败: {plugin_name}")
return False, 1
except FileNotFoundError as e:
# manifest文件缺失
error_msg = f"缺少manifest文件: {str(e)}"
self.failed_plugins[plugin_name] = error_msg
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
return False, 1
except ValueError as e:
# manifest文件格式错误或验证失败
traceback.print_exc()
error_msg = f"manifest验证失败: {str(e)}"
self.failed_plugins[plugin_name] = error_msg
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
return False, 1
except Exception as e:
# 其他错误
error_msg = f"未知错误: {str(e)}"
self.failed_plugins[plugin_name] = error_msg
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
logger.debug("详细错误信息: ", exc_info=True)
return False, 1
async def remove_registered_plugin(self, plugin_name: str) -> bool:
"""
禁用插件模块
"""
if not plugin_name:
raise ValueError("插件名称不能为空")
if plugin_name not in self.loaded_plugins:
logger.warning(f"插件 {plugin_name} 未加载")
return False
plugin_instance = self.loaded_plugins[plugin_name]
plugin_info = plugin_instance.plugin_info
success = True
for component in plugin_info.components:
success &= await component_registry.remove_component(component.name, component.component_type, plugin_name)
success &= component_registry.remove_plugin_registry(plugin_name)
plugin_service_registry.remove_services_by_plugin(plugin_name)
del self.loaded_plugins[plugin_name]
return success
async def reload_registered_plugin(self, plugin_name: str) -> bool:
"""
重载插件模块
"""
old_instance = self.loaded_plugins.get(plugin_name)
if not old_instance:
logger.warning(f"插件 {plugin_name} 未加载,无法重载")
return False
if not await self.remove_registered_plugin(plugin_name):
return False
if not self.load_registered_plugin_classes(plugin_name)[0]:
logger.error(f"插件 {plugin_name} 重载失败,开始回滚旧实例")
rollback_ok = await self._rollback_failed_reload(plugin_name, old_instance)
if rollback_ok:
logger.info(f"插件 {plugin_name} 已回滚到旧版本实例")
else:
logger.error(f"插件 {plugin_name} 回滚失败,插件当前不可用")
return False
logger.debug(f"插件 {plugin_name} 重载成功")
return True
async def _rollback_failed_reload(self, plugin_name: str, old_instance: PluginBase) -> bool:
"""重载失败后回滚旧实例。"""
try:
await component_registry.remove_components_by_plugin(plugin_name)
component_registry.remove_plugin_registry(plugin_name)
plugin_service_registry.remove_services_by_plugin(plugin_name)
if not old_instance.register_plugin():
logger.error(f"插件 {plugin_name} 回滚失败: 旧实例重新注册失败")
return False
self.loaded_plugins[plugin_name] = old_instance
return True
except Exception as e:
logger.error(f"插件 {plugin_name} 回滚异常: {e}", exc_info=True)
return False
def rescan_plugin_directory(self) -> Tuple[int, int]:
"""
重新扫描插件根目录
"""
total_success = 0
total_fail = 0
for directory in self.plugin_directories:
if os.path.exists(directory):
logger.debug(f"重新扫描插件根目录: {directory}")
success, fail = self._load_plugin_modules_from_directory(directory)
total_success += success
total_fail += fail
else:
logger.warning(f"插件根目录不存在: {directory}")
return total_success, total_fail
def get_plugin_instance(self, plugin_name: str) -> Optional["PluginBase"]:
"""获取插件实例
Args:
plugin_name: 插件名称
Returns:
Optional[BasePlugin]: 插件实例或None
"""
return self.loaded_plugins.get(plugin_name)
# === 查询方法 ===
def list_loaded_plugins(self) -> List[str]:
"""
列出所有当前加载的插件。
Returns:
list: 当前加载的插件名称列表。
"""
return list(self.loaded_plugins.keys())
def list_registered_plugins(self) -> List[str]:
"""
列出所有已注册的插件类。
Returns:
list: 已注册的插件类名称列表。
"""
return list(self.plugin_classes.keys())
def get_plugin_path(self, plugin_name: str) -> Optional[str]:
"""
获取指定插件的路径。
Args:
plugin_name: 插件名称
Returns:
Optional[str]: 插件目录的绝对路径如果插件不存在则返回None。
"""
return self.plugin_paths.get(plugin_name)
# === 私有方法 ===
# == 目录管理 ==
def _ensure_plugin_directories(self) -> None:
"""确保所有插件根目录存在,如果不存在则创建"""
default_directories = ["src/plugins/built_in", "plugins"]
for directory in default_directories:
if not os.path.exists(directory):
os.makedirs(directory, exist_ok=True)
logger.info(f"创建插件根目录: {directory}")
if directory not in self.plugin_directories:
self.plugin_directories.append(directory)
logger.debug(f"已添加插件根目录: {directory}")
else:
logger.warning(f"根目录不可重复加载: {directory}")
# == 插件加载 ==
def _load_plugin_modules_from_directory(self, directory: str) -> tuple[int, int]:
"""从指定目录加载插件模块"""
loaded_count = 0
failed_count = 0
if not os.path.exists(directory):
logger.warning(f"插件根目录不存在: {directory}")
return 0, 1
logger.debug(f"正在扫描插件根目录: {directory}")
# 遍历目录中的所有包
for item in os.listdir(directory):
item_path = os.path.join(directory, item)
if os.path.isdir(item_path) and not item.startswith(".") and not item.startswith("__"):
plugin_file = os.path.join(item_path, "plugin.py")
if os.path.exists(plugin_file):
if self._load_plugin_module_file(plugin_file):
loaded_count += 1
else:
failed_count += 1
return loaded_count, failed_count
def _load_plugin_module_file(self, plugin_file: str) -> bool:
# sourcery skip: extract-method
"""加载单个插件模块文件
Args:
plugin_file: 插件文件路径
plugin_name: 插件名称
plugin_dir: 插件目录路径
"""
# 生成模块名
plugin_path = Path(plugin_file)
module_name = ".".join(plugin_path.parent.parts)
try:
# 动态导入插件模块
spec = spec_from_file_location(module_name, plugin_file)
if spec is None or spec.loader is None:
logger.error(f"无法创建模块规范: {plugin_file}")
return False
module = module_from_spec(spec)
module.__package__ = module_name # 设置模块包名
spec.loader.exec_module(module)
logger.debug(f"插件模块加载成功: {plugin_file}")
return True
except Exception as e:
error_msg = f"加载插件模块 {plugin_file} 失败: {e}"
logger.error(error_msg)
self.failed_plugins[module_name] = error_msg
return False
# == 依赖解析与加载顺序 ==
def _extract_declared_dependencies(self, plugin_name: str, plugin_class: Type[PluginBase]) -> Set[str]:
"""提取插件声明的依赖。
兼容声明格式:
- list[str]
- list[dict]其中dict至少包含name键
"""
dependencies: Set[str] = set()
raw_dependencies = getattr(plugin_class, "dependencies", [])
# 兼容错误声明
if isinstance(raw_dependencies, property):
logger.warning(f"插件 {plugin_name} 的 dependencies 未声明为类属性,将按无依赖处理")
return dependencies
if not isinstance(raw_dependencies, list):
logger.warning(f"插件 {plugin_name} 的 dependencies 不是列表,将按无依赖处理")
return dependencies
for dependency in raw_dependencies:
dependency_name = ""
if isinstance(dependency, str):
dependency_name = dependency.strip()
elif isinstance(dependency, dict):
dependency_name = str(dependency.get("name", "")).strip()
else:
logger.warning(f"插件 {plugin_name} 包含不支持的依赖声明类型: {type(dependency)}")
continue
if not dependency_name:
continue
if dependency_name == plugin_name:
logger.warning(f"插件 {plugin_name} 声明了对自身的依赖,已忽略")
continue
dependencies.add(dependency_name)
return dependencies
def _build_plugin_dependency_graph(self) -> Tuple[Dict[str, Set[str]], Dict[str, Set[str]]]:
"""构建依赖图并返回缺失依赖映射。"""
plugin_names = set(self.plugin_classes.keys())
dependency_graph: Dict[str, Set[str]] = {name: set() for name in plugin_names}
missing_dependencies: Dict[str, Set[str]] = {name: set() for name in plugin_names}
for plugin_name, plugin_class in self.plugin_classes.items():
declared_dependencies = self._extract_declared_dependencies(plugin_name, plugin_class)
for dependency_name in declared_dependencies:
if dependency_name in plugin_names:
dependency_graph[plugin_name].add(dependency_name)
else:
missing_dependencies[plugin_name].add(dependency_name)
return dependency_graph, missing_dependencies
def _resolve_plugin_load_order(self, dependency_graph: Dict[str, Set[str]]) -> Tuple[List[str], Set[str]]:
"""根据依赖图计算加载顺序,并检测循环依赖。"""
indegree: Dict[str, int] = {
plugin_name: len(dependencies) for plugin_name, dependencies in dependency_graph.items()
}
reverse_graph: Dict[str, Set[str]] = {plugin_name: set() for plugin_name in dependency_graph}
for plugin_name, dependencies in dependency_graph.items():
for dependency_name in dependencies:
reverse_graph[dependency_name].add(plugin_name)
zero_indegree_queue = deque(sorted([name for name, degree in indegree.items() if degree == 0]))
load_order: List[str] = []
while zero_indegree_queue:
current_plugin = zero_indegree_queue.popleft()
load_order.append(current_plugin)
for dependent_plugin in sorted(reverse_graph[current_plugin]):
indegree[dependent_plugin] -= 1
if indegree[dependent_plugin] == 0:
zero_indegree_queue.append(dependent_plugin)
cycle_plugins = {name for name, degree in indegree.items() if degree > 0}
if cycle_plugins:
logger.error(f"检测到循环依赖插件: {', '.join(sorted(cycle_plugins))}")
return load_order, cycle_plugins
# == 兼容性检查 ==
def _check_plugin_version_compatibility(self, plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]:
"""检查插件版本兼容性
Args:
plugin_name: 插件名称
manifest_data: manifest数据
Returns:
Tuple[bool, str]: (是否兼容, 错误信息)
"""
if "host_application" not in manifest_data:
return True, "" # 没有版本要求,默认兼容
host_app = manifest_data["host_application"]
if not isinstance(host_app, dict):
return True, ""
min_version = host_app.get("min_version", "")
max_version = host_app.get("max_version", "")
if not min_version and not max_version:
return True, "" # 没有版本要求,默认兼容
try:
current_version = VersionComparator.get_current_host_version()
is_compatible, error_msg = VersionComparator.is_version_in_range(current_version, min_version, max_version)
if not is_compatible:
return False, f"版本不兼容: {error_msg}"
logger.debug(f"插件 {plugin_name} 版本兼容性检查通过")
return True, ""
except Exception as e:
logger.warning(f"插件 {plugin_name} 版本兼容性检查失败: {e}")
return False, f"插件 {plugin_name} 版本兼容性检查失败: {e}" # 检查失败时默认不允许加载
# == 显示统计与插件信息 ==
def _show_stats(self, total_registered: int, total_failed_registration: int):
# sourcery skip: low-code-quality
# 获取组件统计信息
stats = component_registry.get_registry_stats()
action_count = stats.get("action_components", 0)
command_count = stats.get("command_components", 0)
tool_count = stats.get("tool_components", 0)
event_handler_count = stats.get("event_handlers", 0)
total_components = stats.get("total_components", 0)
# 📋 显示插件加载总览
if total_registered > 0:
logger.info("🎉 插件系统加载完成!")
logger.info(
f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, EventHandler: {event_handler_count})"
)
# 显示详细的插件列表
logger.info("📋 已加载插件详情:")
for plugin_name in self.loaded_plugins.keys():
if plugin_info := component_registry.get_plugin_info(plugin_name):
# 插件基本信息
version_info = f"v{plugin_info.version}" if plugin_info.version else ""
author_info = f"by {plugin_info.author}" if plugin_info.author else "unknown"
license_info = f"[{plugin_info.license}]" if plugin_info.license else ""
info_parts = [part for part in [version_info, author_info, license_info] if part]
extra_info = f" ({', '.join(info_parts)})" if info_parts else ""
logger.info(f" 📦 {plugin_info.display_name}{extra_info}")
# Manifest信息
if plugin_info.manifest_data:
"""
if plugin_info.keywords:
logger.info(f" 🏷️ 关键词: {', '.join(plugin_info.keywords)}")
if plugin_info.categories:
logger.info(f" 📁 分类: {', '.join(plugin_info.categories)}")
"""
if plugin_info.homepage_url:
logger.info(f" 🌐 主页: {plugin_info.homepage_url}")
# 组件列表
if plugin_info.components:
action_components = [
c for c in plugin_info.components if c.component_type == ComponentType.ACTION
]
command_components = [
c for c in plugin_info.components if c.component_type == ComponentType.COMMAND
]
tool_components = [c for c in plugin_info.components if c.component_type == ComponentType.TOOL]
event_handler_components = [
c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER
]
if action_components:
action_names = [c.name for c in action_components]
logger.info(f" 🎯 Action组件: {', '.join(action_names)}")
if command_components:
command_names = [c.name for c in command_components]
logger.info(f" ⚡ Command组件: {', '.join(command_names)}")
if tool_components:
tool_names = [c.name for c in tool_components]
logger.info(f" 🛠️ Tool组件: {', '.join(tool_names)}")
if event_handler_components:
event_handler_names = [c.name for c in event_handler_components]
logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_names)}")
# 依赖信息
if plugin_info.dependencies:
logger.info(f" 🔗 依赖: {', '.join(plugin_info.dependencies)}")
# 配置文件信息
if plugin_info.config_file:
config_status = "" if self.plugin_paths.get(plugin_name) else ""
logger.info(f" ⚙️ 配置: {plugin_info.config_file} {config_status}")
root_path = Path(__file__)
# 查找项目根目录
while not (root_path / "pyproject.toml").exists() and root_path.parent != root_path:
root_path = root_path.parent
# 显示目录统计
logger.info("📂 加载目录统计:")
for directory in self.plugin_directories:
if os.path.exists(directory):
plugins_in_dir = []
for plugin_name in self.loaded_plugins.keys():
plugin_path = self.plugin_paths.get(plugin_name, "")
if (
Path(plugin_path)
.resolve()
.is_relative_to(Path(os.path.join(str(root_path), directory)).resolve())
):
plugins_in_dir.append(plugin_name)
if plugins_in_dir:
logger.info(f" 📁 {directory}: {len(plugins_in_dir)}个插件 ({', '.join(plugins_in_dir)})")
else:
logger.info(f" 📁 {directory}: 0个插件")
# 失败信息
if total_failed_registration > 0:
logger.info(f"⚠️ 失败统计: {total_failed_registration}个插件加载失败")
for failed_plugin, error in self.failed_plugins.items():
logger.info(f"{failed_plugin}: {error}")
else:
logger.warning("😕 没有成功加载任何插件")
def _show_plugin_components(self, plugin_name: str) -> None:
if plugin_info := component_registry.get_plugin_info(plugin_name):
component_types = {}
for comp in plugin_info.components:
comp_type = comp.component_type.name
component_types[comp_type] = component_types.get(comp_type, 0) + 1
components_str = ", ".join([f"{count}{ctype}" for ctype, count in component_types.items()])
# 显示manifest信息
manifest_info = ""
if plugin_info.license:
manifest_info += f" [{plugin_info.license}]"
if plugin_info.keywords:
manifest_info += f" 关键词: {', '.join(plugin_info.keywords[:3])}" # 只显示前3个关键词
if len(plugin_info.keywords) > 3:
manifest_info += "..."
logger.info(
f"✅ 插件加载成功: {plugin_name} v{plugin_info.version} ({components_str}){manifest_info} - {plugin_info.description}"
)
else:
logger.info(f"✅ 插件加载成功: {plugin_name}")
# 全局插件管理器实例
plugin_manager = PluginManager()

View File

@@ -1,251 +0,0 @@
from typing import Any, Callable, Dict, Optional
import inspect
from src.common.logger import get_logger
from src.plugin_system.base.service_types import PluginServiceInfo
logger = get_logger("plugin_service_registry")
class PluginServiceRegistry:
"""插件服务注册中心"""
def __init__(self):
self._services: Dict[str, PluginServiceInfo] = {}
self._service_handlers: Dict[str, Callable[..., Any]] = {}
logger.info("插件服务注册中心初始化完成")
def register_service(self, service_info: PluginServiceInfo, service_handler: Callable[..., Any]) -> bool:
"""注册插件服务。"""
if not service_info.name or not service_info.plugin_name:
logger.error("插件服务注册失败: service名称或插件名称为空")
return False
if "." in service_info.name:
logger.error(f"插件服务名称 '{service_info.name}' 包含非法字符 '.',请使用下划线替代")
return False
if "." in service_info.plugin_name:
logger.error(f"插件服务所属插件名称 '{service_info.plugin_name}' 包含非法字符 '.',请使用下划线替代")
return False
if invalid_callers := [caller for caller in service_info.allowed_callers if "." in caller]:
logger.error(f"插件服务白名单包含非法调用方名称: {invalid_callers}")
return False
full_name = service_info.full_name
if full_name in self._services:
logger.warning(f"插件服务已存在,拒绝重复注册: {full_name}")
return False
self._services[full_name] = service_info
self._service_handlers[full_name] = service_handler
logger.debug(f"已注册插件服务: {full_name} (version={service_info.version})")
return True
def get_service(self, service_name: str, plugin_name: Optional[str] = None) -> Optional[PluginServiceInfo]:
"""获取插件服务元信息。
service_name支持
- full_name: plugin_name.service_name
- short_name: service_name当唯一时可解析
"""
full_name = self._resolve_full_name(service_name, plugin_name)
return self._services.get(full_name) if full_name else None
def get_service_handler(self, service_name: str, plugin_name: Optional[str] = None) -> Optional[Callable[..., Any]]:
"""获取插件服务处理函数。"""
full_name = self._resolve_full_name(service_name, plugin_name)
return self._service_handlers.get(full_name) if full_name else None
def list_services(
self, plugin_name: Optional[str] = None, enabled_only: bool = False
) -> Dict[str, PluginServiceInfo]:
"""列出插件服务。"""
services = self._services.copy()
if plugin_name:
services = {name: info for name, info in services.items() if info.plugin_name == plugin_name}
if enabled_only:
services = {name: info for name, info in services.items() if info.enabled}
return services
def enable_service(self, service_name: str, plugin_name: Optional[str] = None) -> bool:
"""启用插件服务。"""
if not (service_info := self.get_service(service_name, plugin_name)):
logger.warning(f"插件服务未注册,无法启用: {service_name}")
return False
service_info.enabled = True
logger.info(f"插件服务已启用: {service_info.full_name}")
return True
def disable_service(self, service_name: str, plugin_name: Optional[str] = None) -> bool:
"""禁用插件服务。"""
if not (service_info := self.get_service(service_name, plugin_name)):
logger.warning(f"插件服务未注册,无法禁用: {service_name}")
return False
service_info.enabled = False
logger.info(f"插件服务已禁用: {service_info.full_name}")
return True
def unregister_service(self, service_name: str, plugin_name: Optional[str] = None) -> bool:
"""注销单个插件服务。"""
full_name = self._resolve_full_name(service_name, plugin_name)
if not full_name:
logger.warning(f"插件服务未注册,无法注销: {service_name}")
return False
self._services.pop(full_name, None)
self._service_handlers.pop(full_name, None)
logger.info(f"插件服务已注销: {full_name}")
return True
def remove_services_by_plugin(self, plugin_name: str) -> int:
"""移除某插件的所有注册服务。"""
target_names = [full_name for full_name, info in self._services.items() if info.plugin_name == plugin_name]
for full_name in target_names:
self._services.pop(full_name, None)
self._service_handlers.pop(full_name, None)
removed_count = len(target_names)
if removed_count:
logger.info(f"已移除插件 {plugin_name} 的服务数量: {removed_count}")
return removed_count
async def call_service(
self,
service_name: str,
*args: Any,
plugin_name: Optional[str] = None,
caller_plugin: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""调用插件服务(支持同步/异步handler"""
service_info = self.get_service(service_name, plugin_name)
if not service_info:
target_name = f"{plugin_name}.{service_name}" if plugin_name and "." not in service_name else service_name
raise ValueError(f"插件服务未注册: {target_name}")
if (
"." not in service_name
and plugin_name is None
and caller_plugin
and service_info.plugin_name != caller_plugin
):
raise PermissionError("跨插件服务调用必须使用完整服务名或显式指定plugin_name")
if not self._is_call_authorized(service_info, caller_plugin):
raise PermissionError(
f"调用被拒绝: caller={caller_plugin or 'anonymous'} 无权限访问服务 {service_info.full_name}"
)
if not service_info.enabled:
raise RuntimeError(f"插件服务已禁用: {service_info.full_name}")
handler = self.get_service_handler(service_name, plugin_name)
if not handler:
raise RuntimeError(f"插件服务处理器不存在: {service_info.full_name}")
self._validate_input_contract(service_info, args, kwargs)
result = handler(*args, **kwargs)
resolved_result = await result if inspect.isawaitable(result) else result
self._validate_output_contract(service_info, resolved_result)
return resolved_result
def _is_call_authorized(self, service_info: PluginServiceInfo, caller_plugin: Optional[str]) -> bool:
"""检查服务调用权限。"""
if caller_plugin is None:
return service_info.public
if caller_plugin == service_info.plugin_name:
return True
if service_info.public:
return True
allowed_callers = {caller.strip() for caller in service_info.allowed_callers if caller.strip()}
return "*" in allowed_callers or caller_plugin in allowed_callers
def _validate_input_contract(
self, service_info: PluginServiceInfo, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> None:
"""校验服务入参契约。"""
schema = service_info.params_schema
if not schema:
return
properties = schema.get("properties", {}) if isinstance(schema, dict) else {}
is_invocation_schema = "args" in properties or "kwargs" in properties
if is_invocation_schema:
payload = {"args": list(args), "kwargs": kwargs}
self._validate_by_schema(payload, schema, path="params")
return
if args:
raise ValueError(f"服务 {service_info.full_name} 的入参契约不允许位置参数")
self._validate_by_schema(kwargs, schema, path="params")
def _validate_output_contract(self, service_info: PluginServiceInfo, value: Any) -> None:
"""校验服务返回值契约。"""
if not service_info.return_schema:
return
self._validate_by_schema(value, service_info.return_schema, path="return")
def _validate_by_schema(self, value: Any, schema: Dict[str, Any], path: str) -> None:
"""基于简化JSON-Schema校验数据。"""
expected_type = schema.get("type")
if expected_type:
self._validate_type(value, expected_type, path)
enum_values = schema.get("enum")
if enum_values is not None and value not in enum_values:
raise ValueError(f"{path} 不在枚举范围内: {value}")
if expected_type == "object":
properties = schema.get("properties", {})
required = schema.get("required", [])
for field in required:
if field not in value:
raise ValueError(f"{path}.{field} 为必填字段")
for field, field_value in value.items():
if field in properties:
self._validate_by_schema(field_value, properties[field], f"{path}.{field}")
elif schema.get("additionalProperties", True) is False:
raise ValueError(f"{path}.{field} 不允许额外字段")
if expected_type == "array":
if item_schema := schema.get("items"):
for index, item in enumerate(value):
self._validate_by_schema(item, item_schema, f"{path}[{index}]")
def _validate_type(self, value: Any, expected_type: str, path: str) -> None:
"""校验基础类型。"""
type_checkers: Dict[str, Callable[[Any], bool]] = {
"string": lambda item: isinstance(item, str),
"number": lambda item: isinstance(item, (int, float)) and not isinstance(item, bool),
"integer": lambda item: isinstance(item, int) and not isinstance(item, bool),
"boolean": lambda item: isinstance(item, bool),
"object": lambda item: isinstance(item, dict),
"array": lambda item: isinstance(item, list),
"null": lambda item: item is None,
}
checker = type_checkers.get(expected_type)
if checker and not checker(value):
raise TypeError(f"{path} 类型不匹配,期望 {expected_type},实际 {type(value).__name__}")
def _resolve_full_name(self, service_name: str, plugin_name: Optional[str] = None) -> Optional[str]:
"""解析服务全名。"""
if "." in service_name:
return service_name if service_name in self._services else None
if plugin_name:
full_name = f"{plugin_name}.{service_name}"
return full_name if full_name in self._services else None
candidates = [full_name for full_name, info in self._services.items() if info.name == service_name]
if len(candidates) == 1:
return candidates[0]
if len(candidates) > 1:
logger.warning(f"插件服务名称 '{service_name}' 存在多义请传入plugin_name或使用完整服务名")
return None
return None
plugin_service_registry = PluginServiceRegistry()

View File

@@ -1,13 +0,0 @@
- [x] 自定义事件
- [ ] <del>允许handler随时订阅</del>
- [x] 允许其他组件给handler增加订阅
- [x] 允许其他组件给handler取消订阅
- [ ] <del>允许一个handler订阅多个事件</del>
- [x] event激活时给handler传递参数
- [ ] handler能拿到所有handlers的结果按照处理权重
- [x] 随时注册
- [ ] <del>删除event</del>
- [ ] 必要性?
- [x] 能够更改prompt
- [x] 能够更改llm_response
- [x] 能够更改message

View File

@@ -1,260 +0,0 @@
from typing import Any, Awaitable, Callable, List, Optional, Tuple, Union
import asyncio
import inspect
import time
import uuid
from src.common.logger import get_logger
from src.plugin_system.base.component_types import EventType, MaiMessages
from src.plugin_system.base.workflow_errors import WorkflowErrorCode
from src.plugin_system.base.workflow_types import WorkflowContext, WorkflowStage, WorkflowStepResult
logger = get_logger("workflow_engine")
class WorkflowEngine:
"""线性Workflow执行器MVP"""
STAGE_EVENT_SEQUENCE: List[Tuple[WorkflowStage, Union[EventType, str]]] = [
(WorkflowStage.INGRESS, "workflow.ingress"),
(WorkflowStage.PRE_PROCESS, EventType.ON_MESSAGE_PRE_PROCESS),
(WorkflowStage.PLAN, EventType.ON_PLAN),
(WorkflowStage.TOOL_EXECUTE, "workflow.tool_execute"),
(WorkflowStage.POST_PROCESS, EventType.POST_SEND_PRE_PROCESS),
(WorkflowStage.EGRESS, "workflow.egress"),
]
def __init__(self):
self._execution_history: dict[str, dict[str, Any]] = {}
async def execute_linear(
self,
dispatch_event: Callable[
[Union[EventType, str], Optional[MaiMessages], Optional[str], Optional[List[str]]],
Awaitable[Tuple[bool, Optional[MaiMessages]]],
],
message: Optional[MaiMessages] = None,
stream_id: Optional[str] = None,
action_usage: Optional[List[str]] = None,
context: Optional[WorkflowContext] = None,
) -> Tuple[WorkflowStepResult, Optional[MaiMessages], WorkflowContext]:
"""执行线性workflow。"""
workflow_context = context or WorkflowContext(trace_id=uuid.uuid4().hex, stream_id=stream_id)
current_message = message.deepcopy() if message else None
self._execution_history[workflow_context.trace_id] = {
"trace_id": workflow_context.trace_id,
"stream_id": workflow_context.stream_id,
"stages": [],
"errors": [],
"status": "running",
}
for stage, event_type in self.STAGE_EVENT_SEQUENCE:
stage_key = str(stage)
stage_start = time.perf_counter()
try:
should_continue, modified_message = await dispatch_event(
event_type,
current_message,
workflow_context.stream_id,
action_usage,
)
workflow_context.timings[stage_key] = time.perf_counter() - stage_start
self._execution_history[workflow_context.trace_id]["stages"].append(
{
"stage": stage_key,
"event_type": str(event_type),
"event_continue": should_continue,
"event_cost": workflow_context.timings[stage_key],
}
)
if modified_message:
current_message = modified_message
if not should_continue:
logger.info(f"[trace_id={workflow_context.trace_id}] Workflow在阶段 {stage_key} 被中断")
return (
WorkflowStepResult(
status="stop",
return_message=f"workflow stopped at stage {stage_key}",
diagnostics={
"stage": stage_key,
"trace_id": workflow_context.trace_id,
"error_code": WorkflowErrorCode.POLICY_BLOCKED.value,
},
),
current_message,
workflow_context,
)
step_result = await self._execute_registered_steps(stage, workflow_context, current_message)
if step_result.status in ["stop", "failed"]:
self._execution_history[workflow_context.trace_id]["status"] = step_result.status
self._execution_history[workflow_context.trace_id]["errors"] = workflow_context.errors.copy()
return step_result, current_message, workflow_context
except Exception as e:
workflow_context.timings[stage_key] = time.perf_counter() - stage_start
workflow_context.errors.append(f"{stage_key}: {e}")
logger.error(
f"[trace_id={workflow_context.trace_id}] Workflow阶段 {stage_key} 执行异常: {e}", exc_info=True
)
self._execution_history[workflow_context.trace_id]["status"] = "failed"
self._execution_history[workflow_context.trace_id]["errors"] = workflow_context.errors.copy()
return (
WorkflowStepResult(
status="failed",
return_message=str(e),
diagnostics={
"stage": stage_key,
"trace_id": workflow_context.trace_id,
"error_code": WorkflowErrorCode.DOWNSTREAM_FAILED.value,
},
),
current_message,
workflow_context,
)
self._execution_history[workflow_context.trace_id]["status"] = "continue"
self._execution_history[workflow_context.trace_id]["errors"] = workflow_context.errors.copy()
return (
WorkflowStepResult(
status="continue",
return_message="workflow completed",
diagnostics={"trace_id": workflow_context.trace_id},
),
current_message,
workflow_context,
)
async def _execute_registered_steps(
self,
stage: WorkflowStage,
context: WorkflowContext,
message: Optional[MaiMessages],
) -> WorkflowStepResult:
"""执行指定阶段已注册的workflow步骤。"""
from src.plugin_system.core.component_registry import component_registry
stage_steps = component_registry.get_steps_by_stage(stage, enabled_only=True)
sorted_steps = sorted(stage_steps.values(), key=lambda step_info: step_info.priority, reverse=True)
for step_info in sorted_steps:
handler = component_registry.get_workflow_step_handler(step_info.full_name, stage)
if not handler:
context.errors.append(f"{step_info.full_name}: handler not found")
continue
step_timing_key = f"{stage.value}:{step_info.full_name}"
step_start = time.perf_counter()
timeout_seconds = step_info.timeout_ms / 1000 if step_info.timeout_ms > 0 else None
try:
if inspect.iscoroutinefunction(handler):
coroutine = handler(context, message)
result = await asyncio.wait_for(coroutine, timeout_seconds) if timeout_seconds else await coroutine
else:
if timeout_seconds:
result = await asyncio.wait_for(asyncio.to_thread(handler, context, message), timeout_seconds)
else:
result = handler(context, message)
if inspect.isawaitable(result):
result = await asyncio.wait_for(result, timeout_seconds) if timeout_seconds else await result
context.timings[step_timing_key] = time.perf_counter() - step_start
normalized_result = self._normalize_step_result(result)
if normalized_result.status == "continue":
continue
normalized_result.diagnostics.setdefault("stage", stage.value)
normalized_result.diagnostics.setdefault("step", step_info.full_name)
normalized_result.diagnostics.setdefault("trace_id", context.trace_id)
if normalized_result.status == "failed":
context.errors.append(
f"{step_info.full_name}: {normalized_result.return_message or 'workflow step failed'}"
)
normalized_result.diagnostics.setdefault("error_code", WorkflowErrorCode.DOWNSTREAM_FAILED.value)
return normalized_result
except asyncio.TimeoutError:
context.timings[step_timing_key] = time.perf_counter() - step_start
timeout_message = f"workflow step timeout after {step_info.timeout_ms}ms"
context.errors.append(f"{step_info.full_name}: {timeout_message}")
logger.error(
f"[trace_id={context.trace_id}] Workflow step {step_info.full_name} 超时: {timeout_message}"
)
return WorkflowStepResult(
status="failed",
return_message=timeout_message,
diagnostics={
"stage": stage.value,
"step": step_info.full_name,
"trace_id": context.trace_id,
"error_code": WorkflowErrorCode.STEP_TIMEOUT.value,
},
)
except Exception as e:
context.timings[step_timing_key] = time.perf_counter() - step_start
context.errors.append(f"{step_info.full_name}: {e}")
logger.error(
f"[trace_id={context.trace_id}] Workflow step {step_info.full_name} 执行异常: {e}", exc_info=True
)
return WorkflowStepResult(
status="failed",
return_message=str(e),
diagnostics={
"stage": stage.value,
"step": step_info.full_name,
"trace_id": context.trace_id,
"error_code": WorkflowErrorCode.DOWNSTREAM_FAILED.value,
},
)
return WorkflowStepResult(status="continue", diagnostics={"stage": stage.value, "trace_id": context.trace_id})
def _normalize_step_result(self, result: Any) -> WorkflowStepResult:
"""归一化workflow step返回值。"""
if isinstance(result, WorkflowStepResult):
return result
if isinstance(result, bool):
if result:
return WorkflowStepResult(status="continue")
return WorkflowStepResult(
status="failed",
diagnostics={"error_code": WorkflowErrorCode.DOWNSTREAM_FAILED.value},
)
if result is None:
return WorkflowStepResult(status="continue")
if isinstance(result, str):
return WorkflowStepResult(status="continue", return_message=result)
if isinstance(result, dict):
status = result.get("status", "continue")
if status not in ["continue", "stop", "failed"]:
status = "failed"
return WorkflowStepResult(
status=status,
return_message=result.get("return_message"),
diagnostics=result.get("diagnostics", {}),
events=result.get("events", []),
)
return WorkflowStepResult(
status="failed",
return_message=f"unsupported step result type: {type(result)}",
diagnostics={"error_code": WorkflowErrorCode.BAD_PAYLOAD.value},
)
def get_execution_trace(self, trace_id: str) -> Optional[dict[str, Any]]:
"""按trace_id获取workflow执行路径。"""
trace = self._execution_history.get(trace_id)
return trace.copy() if trace else None
def clear_execution_trace(self, trace_id: str) -> bool:
"""清理trace执行记录。"""
if trace_id in self._execution_history:
del self._execution_history[trace_id]
return True
return False
workflow_engine = WorkflowEngine()

View File

@@ -1,19 +0,0 @@
"""
插件系统工具模块
提供插件开发和管理的实用工具
"""
from .manifest_utils import (
ManifestValidator,
# ManifestGenerator,
# validate_plugin_manifest,
# generate_plugin_manifest,
)
__all__ = [
"ManifestValidator",
# "ManifestGenerator",
# "validate_plugin_manifest",
# "generate_plugin_manifest",
]

View File

@@ -1,515 +0,0 @@
"""
插件Manifest工具模块
提供manifest文件的验证、生成和管理功能
"""
import re
from typing import Dict, Any, Tuple
from src.common.logger import get_logger
from src.config.config import MMC_VERSION
# if TYPE_CHECKING:
# from src.plugin_system.base.base_plugin import BasePlugin
logger = get_logger("manifest_utils")
class VersionComparator:
"""版本号比较器
支持语义化版本号比较自动处理snapshot版本并支持向前兼容性检查
"""
# 版本兼容性映射表(硬编码)
# 格式: {插件最大支持版本: [实际兼容的版本列表]}
COMPATIBILITY_MAP = {
# 0.8.x 系列向前兼容规则
"0.8.0": ["0.8.1", "0.8.2", "0.8.3", "0.8.4", "0.8.5", "0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"],
"0.8.1": ["0.8.2", "0.8.3", "0.8.4", "0.8.5", "0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"],
"0.8.2": ["0.8.3", "0.8.4", "0.8.5", "0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"],
"0.8.3": ["0.8.4", "0.8.5", "0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"],
"0.8.4": ["0.8.5", "0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"],
"0.8.5": ["0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"],
"0.8.6": ["0.8.7", "0.8.8", "0.8.9", "0.8.10"],
"0.8.7": ["0.8.8", "0.8.9", "0.8.10"],
"0.8.8": ["0.8.9", "0.8.10"],
"0.8.9": ["0.8.10"],
# 可以根据需要添加更多兼容映射
# "0.9.0": ["0.9.1", "0.9.2", "0.9.3"], # 示例0.9.x系列兼容
}
@staticmethod
def normalize_version(version: str) -> str:
"""标准化版本号移除snapshot标识
Args:
version: 原始版本号,如 "0.8.0-snapshot.1"
Returns:
str: 标准化后的版本号,如 "0.8.0"
"""
if not version:
return "0.0.0"
# 移除snapshot部分
normalized = re.sub(r"-snapshot\.\d+", "", version.strip())
# 确保版本号格式正确
if not re.match(r"^\d+(\.\d+){0,2}$", normalized):
# 如果不是有效的版本号格式,返回默认版本
return "0.0.0"
# 尝试补全版本号
parts = normalized.split(".")
while len(parts) < 3:
parts.append("0")
normalized = ".".join(parts[:3])
return normalized
@staticmethod
def parse_version(version: str) -> Tuple[int, int, int]:
"""解析版本号为元组
Args:
version: 版本号字符串
Returns:
Tuple[int, int, int]: (major, minor, patch)
"""
normalized = VersionComparator.normalize_version(version)
try:
parts = normalized.split(".")
return (int(parts[0]), int(parts[1]), int(parts[2]))
except (ValueError, IndexError):
logger.warning(f"无法解析版本号: {version},使用默认版本 0.0.0")
return (0, 0, 0)
@staticmethod
def compare_versions(version1: str, version2: str) -> int:
"""比较两个版本号
Args:
version1: 第一个版本号
version2: 第二个版本号
Returns:
int: -1 if version1 < version2, 0 if equal, 1 if version1 > version2
"""
v1_tuple = VersionComparator.parse_version(version1)
v2_tuple = VersionComparator.parse_version(version2)
if v1_tuple < v2_tuple:
return -1
elif v1_tuple > v2_tuple:
return 1
else:
return 0
@staticmethod
def check_forward_compatibility(current_version: str, max_version: str) -> Tuple[bool, str]:
"""检查向前兼容性(仅使用兼容性映射表)
Args:
current_version: 当前版本
max_version: 插件声明的最大支持版本
Returns:
Tuple[bool, str]: (是否兼容, 兼容信息)
"""
current_normalized = VersionComparator.normalize_version(current_version)
max_normalized = VersionComparator.normalize_version(max_version)
# 检查兼容性映射表
if max_normalized in VersionComparator.COMPATIBILITY_MAP:
compatible_versions = VersionComparator.COMPATIBILITY_MAP[max_normalized]
if current_normalized in compatible_versions:
return True, f"根据兼容性映射表,版本 {current_normalized}{max_normalized} 兼容"
return False, ""
@staticmethod
def is_version_in_range(version: str, min_version: str = "", max_version: str = "") -> Tuple[bool, str]:
"""检查版本是否在指定范围内,支持兼容性检查
Args:
version: 要检查的版本号
min_version: 最小版本号(可选)
max_version: 最大版本号(可选)
Returns:
Tuple[bool, str]: (是否兼容, 错误信息或兼容信息)
"""
if not min_version and not max_version:
return True, ""
version_normalized = VersionComparator.normalize_version(version)
# 检查最小版本
if min_version:
min_normalized = VersionComparator.normalize_version(min_version)
if VersionComparator.compare_versions(version_normalized, min_normalized) < 0:
return False, f"版本 {version_normalized} 低于最小要求版本 {min_normalized}"
# 检查最大版本
if max_version:
max_normalized = VersionComparator.normalize_version(max_version)
comparison = VersionComparator.compare_versions(version_normalized, max_normalized)
if comparison > 0:
# 严格版本检查失败,尝试兼容性检查
is_compatible, compat_msg = VersionComparator.check_forward_compatibility(
version_normalized, max_normalized
)
if not is_compatible:
return False, f"版本 {version_normalized} 高于最大支持版本 {max_normalized},且无兼容性映射"
logger.info(f"版本兼容性检查:{compat_msg}")
return True, compat_msg
return True, ""
@staticmethod
def get_current_host_version() -> str:
"""获取当前主机应用版本
Returns:
str: 当前版本号
"""
return VersionComparator.normalize_version(MMC_VERSION)
@staticmethod
def add_compatibility_mapping(base_version: str, compatible_versions: list) -> None:
"""动态添加兼容性映射
Args:
base_version: 基础版本(插件声明的最大支持版本)
compatible_versions: 兼容的版本列表
"""
base_normalized = VersionComparator.normalize_version(base_version)
VersionComparator.COMPATIBILITY_MAP[base_normalized] = [
VersionComparator.normalize_version(v) for v in compatible_versions
]
logger.info(f"添加兼容性映射:{base_normalized} -> {compatible_versions}")
@staticmethod
def get_compatibility_info() -> Dict[str, list]:
"""获取当前的兼容性映射表
Returns:
Dict[str, list]: 兼容性映射表的副本
"""
return VersionComparator.COMPATIBILITY_MAP.copy()
class ManifestValidator:
"""Manifest文件验证器"""
# 必需字段(必须存在且不能为空)
REQUIRED_FIELDS = ["manifest_version", "name", "version", "description", "author"]
# 可选字段(可以不存在或为空)
OPTIONAL_FIELDS = [
"license",
"host_application",
"homepage_url",
"repository_url",
"keywords",
"categories",
"default_locale",
"locales_path",
"plugin_info",
]
# 建议填写的字段(会给出警告但不会导致验证失败)
RECOMMENDED_FIELDS = ["license", "keywords", "categories"]
SUPPORTED_MANIFEST_VERSIONS = [1]
def __init__(self):
self.validation_errors = []
self.validation_warnings = []
def validate_manifest(self, manifest_data: Dict[str, Any]) -> bool:
"""验证manifest数据
Args:
manifest_data: manifest数据字典
Returns:
bool: 是否验证通过(只有错误会导致验证失败,警告不会)
"""
self.validation_errors.clear()
self.validation_warnings.clear()
# 检查必需字段
for field in self.REQUIRED_FIELDS:
if field not in manifest_data:
self.validation_errors.append(f"缺少必需字段: {field}")
elif not manifest_data[field]:
self.validation_errors.append(f"必需字段不能为空: {field}")
# 检查manifest版本
if "manifest_version" in manifest_data:
version = manifest_data["manifest_version"]
if version not in self.SUPPORTED_MANIFEST_VERSIONS:
self.validation_errors.append(
f"不支持的manifest版本: {version},支持的版本: {self.SUPPORTED_MANIFEST_VERSIONS}"
)
# 检查作者信息格式
if "author" in manifest_data:
author = manifest_data["author"]
if isinstance(author, dict):
if "name" not in author or not author["name"]:
self.validation_errors.append("作者信息缺少name字段或为空")
# url字段是可选的
if "url" in author and author["url"]:
url = author["url"]
if not (url.startswith("http://") or url.startswith("https://")):
self.validation_warnings.append("作者URL建议使用完整的URL格式")
elif isinstance(author, str):
if not author.strip():
self.validation_errors.append("作者信息不能为空")
else:
self.validation_errors.append("作者信息格式错误应为字符串或包含name字段的对象")
# 检查主机应用版本要求(可选)
if "host_application" in manifest_data:
host_app = manifest_data["host_application"]
if isinstance(host_app, dict):
min_version = host_app.get("min_version", "")
max_version = host_app.get("max_version", "")
# 验证版本字段格式
for version_field in ["min_version", "max_version"]:
if version_field in host_app and not host_app[version_field]:
self.validation_warnings.append(f"host_application.{version_field}为空")
# 检查当前主机版本兼容性
if min_version or max_version:
current_version = VersionComparator.get_current_host_version()
is_compatible, error_msg = VersionComparator.is_version_in_range(
current_version, min_version, max_version
)
if not is_compatible:
self.validation_errors.append(f"版本兼容性检查失败: {error_msg} (当前版本: {current_version})")
else:
logger.debug(
f"版本兼容性检查通过: 当前版本 {current_version} 符合要求 [{min_version}, {max_version}]"
)
else:
self.validation_errors.append("host_application格式错误应为对象")
# 检查URL格式可选字段
for url_field in ["homepage_url", "repository_url"]:
if url_field in manifest_data and manifest_data[url_field]:
url: str = manifest_data[url_field]
if not (url.startswith("http://") or url.startswith("https://")):
self.validation_warnings.append(f"{url_field}建议使用完整的URL格式")
# 检查数组字段格式(可选字段)
for list_field in ["keywords", "categories"]:
if list_field in manifest_data:
field_value = manifest_data[list_field]
if field_value is not None and not isinstance(field_value, list):
self.validation_errors.append(f"{list_field}应为数组格式")
elif isinstance(field_value, list):
# 检查数组元素是否为字符串
for i, item in enumerate(field_value):
if not isinstance(item, str):
self.validation_warnings.append(f"{list_field}[{i}]应为字符串")
# 检查建议字段(给出警告)
for field in self.RECOMMENDED_FIELDS:
if field not in manifest_data or not manifest_data[field]:
self.validation_warnings.append(f"建议填写字段: {field}")
# 检查plugin_info结构可选
if "plugin_info" in manifest_data:
plugin_info = manifest_data["plugin_info"]
if isinstance(plugin_info, dict):
# 检查components数组
if "components" in plugin_info:
components = plugin_info["components"]
if not isinstance(components, list):
self.validation_errors.append("plugin_info.components应为数组格式")
else:
for i, component in enumerate(components):
if not isinstance(component, dict):
self.validation_errors.append(f"plugin_info.components[{i}]应为对象")
else:
# 检查组件必需字段
for comp_field in ["type", "name", "description"]:
if comp_field not in component or not component[comp_field]:
self.validation_errors.append(
f"plugin_info.components[{i}]缺少必需字段: {comp_field}"
)
else:
self.validation_errors.append("plugin_info应为对象格式")
return len(self.validation_errors) == 0
def get_validation_report(self) -> str:
"""获取验证报告"""
report = []
if self.validation_errors:
report.append("❌ 验证错误:")
report.extend(f" - {error}" for error in self.validation_errors)
if self.validation_warnings:
report.append("⚠️ 验证警告:")
report.extend(f" - {warning}" for warning in self.validation_warnings)
if not self.validation_errors and not self.validation_warnings:
report.append("✅ Manifest文件验证通过")
return "\n".join(report)
# class ManifestGenerator:
# """Manifest文件生成器"""
# def __init__(self):
# self.template = {
# "manifest_version": 1,
# "name": "",
# "version": "1.0.0",
# "description": "",
# "author": {"name": "", "url": ""},
# "license": "MIT",
# "host_application": {"min_version": "1.0.0", "max_version": "4.0.0"},
# "homepage_url": "",
# "repository_url": "",
# "keywords": [],
# "categories": [],
# "default_locale": "zh-CN",
# "locales_path": "_locales",
# }
# def generate_from_plugin(self, plugin_instance: BasePlugin) -> Dict[str, Any]:
# """从插件实例生成manifest
# Args:
# plugin_instance: BasePlugin实例
# Returns:
# Dict[str, Any]: 生成的manifest数据
# """
# manifest = self.template.copy()
# # 基本信息
# manifest["name"] = plugin_instance.plugin_name
# manifest["version"] = plugin_instance.plugin_version
# manifest["description"] = plugin_instance.plugin_description
# # 作者信息
# if plugin_instance.plugin_author:
# manifest["author"]["name"] = plugin_instance.plugin_author
# # 组件信息
# components = []
# plugin_components = plugin_instance.get_plugin_components()
# for component_info, component_class in plugin_components:
# component_data: Dict[str, Any] = {
# "type": component_info.component_type.value,
# "name": component_info.name,
# "description": component_info.description,
# }
# # 添加激活模式信息对于Action组件
# if hasattr(component_class, "focus_activation_type"):
# activation_modes = []
# if hasattr(component_class, "focus_activation_type"):
# activation_modes.append(component_class.focus_activation_type.value)
# if hasattr(component_class, "normal_activation_type"):
# activation_modes.append(component_class.normal_activation_type.value)
# component_data["activation_modes"] = list(set(activation_modes))
# # 添加关键词信息
# if hasattr(component_class, "activation_keywords"):
# keywords = getattr(component_class, "activation_keywords", [])
# if keywords:
# component_data["keywords"] = keywords
# components.append(component_data)
# manifest["plugin_info"] = {"is_built_in": True, "plugin_type": "general", "components": components}
# return manifest
# def save_manifest(self, manifest_data: Dict[str, Any], plugin_dir: str) -> bool:
# """保存manifest文件
# Args:
# manifest_data: manifest数据
# plugin_dir: 插件目录
# Returns:
# bool: 是否保存成功
# """
# try:
# manifest_path = os.path.join(plugin_dir, "_manifest.json")
# with open(manifest_path, "w", encoding="utf-8") as f:
# json.dump(manifest_data, f, ensure_ascii=False, indent=2)
# logger.info(f"Manifest文件已保存: {manifest_path}")
# return True
# except Exception as e:
# logger.error(f"保存manifest文件失败: {e}")
# return False
# def validate_plugin_manifest(plugin_dir: str) -> bool:
# """验证插件目录中的manifest文件
# Args:
# plugin_dir: 插件目录路径
# Returns:
# bool: 是否验证通过
# """
# manifest_path = os.path.join(plugin_dir, "_manifest.json")
# if not os.path.exists(manifest_path):
# logger.warning(f"未找到manifest文件: {manifest_path}")
# return False
# try:
# with open(manifest_path, "r", encoding="utf-8") as f:
# manifest_data = json.load(f)
# validator = ManifestValidator()
# is_valid = validator.validate_manifest(manifest_data)
# logger.info(f"Manifest验证结果:\n{validator.get_validation_report()}")
# return is_valid
# except Exception as e:
# logger.error(f"读取或验证manifest文件失败: {e}")
# return False
# def generate_plugin_manifest(plugin_instance: BasePlugin, save_to_file: bool = True) -> Optional[Dict[str, Any]]:
# """为插件生成manifest文件
# Args:
# plugin_instance: BasePlugin实例
# save_to_file: 是否保存到文件
# Returns:
# Optional[Dict[str, Any]]: 生成的manifest数据
# """
# try:
# generator = ManifestGenerator()
# manifest_data = generator.generate_from_plugin(plugin_instance)
# if save_to_file and plugin_instance.plugin_dir:
# generator.save_manifest(manifest_data, plugin_instance.plugin_dir)
# return manifest_data
# except Exception as e:
# logger.error(f"生成manifest文件失败: {e}")
# return None

View File

@@ -1,34 +1,38 @@
{ {
"manifest_version": 1, "manifest_version": 1,
"name": "Emoji插件 (Emoji Actions)", "name": "Emoji插件 (Emoji Actions)",
"version": "1.0.0", "version": "2.0.0",
"description": "可以发送和管理Emoji", "description": "可以发送和管理Emoji",
"author": { "author": {
"name": "SengokuCola", "name": "SengokuCola",
"url": "https://github.com/MaiM-with-u" "url": "https://github.com/MaiM-with-u"
}, },
"license": "GPL-v3.0-or-later", "license": "GPL-v3.0-or-later",
"host_application": { "host_application": {
"min_version": "0.10.0" "min_version": "1.0.0"
}, },
"homepage_url": "https://github.com/MaiM-with-u/maibot", "homepage_url": "https://github.com/MaiM-with-u/maibot",
"repository_url": "https://github.com/MaiM-with-u/maibot", "repository_url": "https://github.com/MaiM-with-u/maibot",
"keywords": ["emoji", "action", "built-in"], "keywords": ["emoji", "action", "built-in"],
"categories": ["Emoji"], "categories": ["Emoji"],
"default_locale": "zh-CN", "default_locale": "zh-CN",
"locales_path": "_locales",
"plugin_info": { "plugin_info": {
"is_built_in": true, "is_built_in": true,
"plugin_type": "action_provider", "plugin_type": "action_provider",
"components": [ "components": [
{ {
"type": "action", "type": "action",
"name": "emoji", "name": "emoji",
"description": "发送表情包辅助表达情绪" "description": "发送表情包辅助表达情绪"
} }
] ]
} },
"capabilities": [
"emoji.get_random",
"message.get_recent",
"message.build_readable",
"llm.generate",
"send.emoji",
"config.get"
]
} }

View File

@@ -1,151 +0,0 @@
import random
from typing import Tuple
# 导入新插件系统
from src.plugin_system import BaseAction, ActionActivationType
# 导入依赖的系统组件
from src.common.logger import get_logger
# 导入API模块 - 标准Python包方式
from src.plugin_system.apis import emoji_api, llm_api, message_api
# NoReplyAction已集成到heartFC_chat.py中不再需要导入
from src.config.config import global_config
logger = get_logger("emoji")
class EmojiAction(BaseAction):
"""表情动作 - 发送表情包"""
activation_type = ActionActivationType.RANDOM
random_activation_probability = global_config.emoji.emoji_chance
parallel_action = True
# 动作基本信息
action_name = "emoji"
action_description = "发送表情包辅助表达情绪"
# 动作参数定义
action_parameters = {}
# 动作使用场景
action_require = [
"发送表情包辅助表达情绪",
"表达情绪时可以选择使用",
"不要连续发送,如果你已经发过[表情包],就不要选择此动作",
]
# 关联类型
associated_types = ["emoji"]
async def execute(self) -> Tuple[bool, str]:
# sourcery skip: assign-if-exp, introduce-default-else, swap-if-else-branches, use-named-expression
"""执行表情动作"""
try:
# 1. 获取发送表情的原因
# reason = self.action_data.get("reason", "表达当前情绪")
reason = self.action_reasoning
# 2. 随机获取20个表情包
sampled_emojis = await emoji_api.get_random(30)
if not sampled_emojis:
logger.warning(f"{self.log_prefix} 无法获取随机表情包")
return False, "无法获取随机表情包"
# 3. 准备情感数据
emotion_map = {}
for b64, desc, emo in sampled_emojis:
if emo not in emotion_map:
emotion_map[emo] = []
emotion_map[emo].append((b64, desc))
available_emotions = list(emotion_map.keys())
available_emotions_str = ""
for emotion in available_emotions:
available_emotions_str += f"{emotion}\n"
if not available_emotions:
logger.warning(f"{self.log_prefix} 获取到的表情包均无情感标签, 将随机发送")
emoji_base64, emoji_description, _ = random.choice(sampled_emojis)
else:
# 获取最近的5条消息内容用于判断
recent_messages = message_api.get_recent_messages(chat_id=self.chat_id, limit=5)
messages_text = ""
if recent_messages:
# 使用message_api构建可读的消息字符串
messages_text = message_api.build_readable_messages(
messages=recent_messages,
timestamp_mode="normal_no_YMD",
truncate=False,
show_actions=False,
)
# 4. 构建prompt让LLM选择情感
prompt = f"""你正在进行QQ聊天你需要根据聊天记录选出一个合适的情感标签。
请你根据以下原因和聊天记录进行选择
原因:{reason}
聊天记录:
{messages_text}
这里是可用的情感标签:
{available_emotions_str}
请直接返回最匹配的那个情感标签,不要进行任何解释或添加其他多余的文字。
"""
if global_config.debug.show_prompt:
logger.info(f"{self.log_prefix} 生成的LLM Prompt: {prompt}")
else:
logger.debug(f"{self.log_prefix} 生成的LLM Prompt: {prompt}")
# 5. 调用LLM
models = llm_api.get_available_models()
chat_model_config = models.get("utils") # 使用字典访问方式
if not chat_model_config:
logger.error(f"{self.log_prefix} 未找到'utils'模型配置无法调用LLM")
return False, "未找到'utils'模型配置"
success, chosen_emotion, _, _ = await llm_api.generate_with_model(
prompt, model_config=chat_model_config, request_type="emoji.select"
)
if not success:
logger.error(f"{self.log_prefix} LLM调用失败: {chosen_emotion}")
return False, f"LLM调用失败: {chosen_emotion}"
chosen_emotion = chosen_emotion.strip().replace('"', "").replace("'", "")
logger.info(f"{self.log_prefix} LLM选择的情感: {chosen_emotion}")
# 6. 根据选择的情感匹配表情包
if chosen_emotion in emotion_map:
emoji_base64, emoji_description = random.choice(emotion_map[chosen_emotion])
logger.info(f"{self.log_prefix} 发送表情包[{chosen_emotion}],原因: {reason}")
else:
logger.warning(
f"{self.log_prefix} LLM选择的情感 '{chosen_emotion}' 不在可用列表中, 将随机选择一个表情包"
)
emoji_base64, emoji_description, _ = random.choice(sampled_emojis)
# 7. 发送表情包
success = await self.send_emoji(emoji_base64)
if success:
# 存储动作信息
await self.store_action_info(
action_build_into_prompt=True,
action_prompt_display=f"你发送了表情包,原因:{reason}",
action_done=True,
)
return True, f"成功发送表情包:[表情包:{chosen_emotion}]"
else:
error_msg = "发送表情包失败"
logger.error(f"{self.log_prefix} {error_msg}")
await self.send_text("执行表情包动作失败")
return False, error_msg
except Exception as e:
logger.error(f"{self.log_prefix} 表情动作执行失败: {e}", exc_info=True)
return False, f"表情发送失败: {str(e)}"

View File

@@ -1,66 +1,116 @@
""" """Emoji 插件 — 新 SDK 版本
核心动作插件
将系统核心动作reply、no_reply、emoji转换为新插件系统格式 根据聊天上下文的情感,使用 LLM 选择并发送合适的表情包。
这是系统的内置插件,提供基础的聊天交互功能
""" """
from typing import List, Tuple, Type import random
# 导入新插件系统 from maibot_sdk import MaiBotPlugin, Action
from src.plugin_system import BasePlugin, register_plugin, ComponentInfo from maibot_sdk.types import ActivationType
from src.plugin_system.base.config_types import ConfigField
# 导入依赖的系统组件
from src.common.logger import get_logger
from src.plugins.built_in.emoji_plugin.emoji import EmojiAction
logger = get_logger("core_actions")
@register_plugin class EmojiPlugin(MaiBotPlugin):
class CoreActionsPlugin(BasePlugin): """表情包插件"""
"""核心动作插件
系统内置插件,提供基础的聊天交互功能: @Action(
- Reply: 回复动作 "emoji",
- NoReply: 不回复动作 description="发送表情包辅助表达情绪",
- Emoji: 表情动作 activation_type=ActivationType.RANDOM,
activation_probability=0.3,
parallel_action=True,
action_require=[
"发送表情包辅助表达情绪",
"表达情绪时可以选择使用",
"不要连续发送,如果你已经发过[表情包],就不要选择此动作",
],
associated_types=["emoji"],
)
async def handle_emoji(self, stream_id: str = "", reasoning: str = "", chat_id: str = "", **kwargs):
"""执行表情动作"""
reason = reasoning or "表达当前情绪"
注意插件基本信息优先从_manifest.json文件中读取 # 1. 随机获取30个表情包
""" result = await self.ctx.emoji.get_random(30)
if not result or not result.get("success"):
return False, "无法获取随机表情包"
# 插件基本信息 sampled_emojis = result.get("emojis", [])
plugin_name: str = "core_actions" # 内部标识符 if not sampled_emojis:
enable_plugin: bool = True return False, "无法获取随机表情包"
dependencies: list[str] = [] # 插件依赖列表
python_dependencies: list[str] = [] # Python包依赖列表
config_file_name: str = "config.toml"
# 配置节描述 # 2. 按情感分组
config_section_descriptions = { emotion_map: dict[str, list] = {}
"plugin": "插件启用配置", for emoji in sampled_emojis:
"components": "核心组件启用配置", emo = emoji.get("emotion", "")
} if emo not in emotion_map:
emotion_map[emo] = []
emotion_map[emo].append(emoji)
# 配置Schema定义 available_emotions = list(emotion_map.keys())
config_schema: dict = {
"plugin": {
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
"config_version": ConfigField(type=str, default="0.6.0", description="配置文件版本"),
},
"components": {
"enable_emoji": ConfigField(type=bool, default=True, description="是否启用发送表情/图片动作"),
},
}
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: if not available_emotions:
"""返回插件包含的组件列表""" # 无情感标签,随机发送
chosen = random.choice(sampled_emojis)
await self.ctx.send.emoji(chosen["base64"], stream_id)
return True, "随机发送了表情包"
# --- 根据配置注册组件 --- # 3. 获取最近消息作为上下文
components = [] messages_text = ""
if self.get_config("components.enable_emoji", True): if chat_id:
components.append((EmojiAction.get_action_info(), EmojiAction)) recent_result = await self.ctx.message.get_recent(chat_id=chat_id, limit=5)
if recent_result and recent_result.get("success"):
readable_result = await self.ctx.call_capability(
"message.build_readable",
chat_id=chat_id,
start_time=0,
end_time=0,
limit=5,
timestamp_mode="normal_no_YMD",
truncate=False,
)
if readable_result and readable_result.get("success"):
messages_text = readable_result.get("text", "")
return components # 4. 构建 prompt 让 LLM 选择情感
available_emotions_str = "\n".join(available_emotions)
prompt = f"""你正在进行QQ聊天你需要根据聊天记录选出一个合适的情感标签。
请你根据以下原因和聊天记录进行选择
原因:{reason}
聊天记录:
{messages_text}
这里是可用的情感标签:
{available_emotions_str}
请直接返回最匹配的那个情感标签,不要进行任何解释或添加其他多余的文字。
"""
# 5. 调用 LLM
llm_result = await self.ctx.llm.generate(prompt=prompt, model_name="utils")
if not llm_result or not llm_result.get("success"):
chosen = random.choice(sampled_emojis)
await self.ctx.send.emoji(chosen["base64"], stream_id)
return True, "LLM调用失败随机发送了表情包"
chosen_emotion = llm_result.get("response", "").strip().replace('"', "").replace("'", "")
# 6. 根据选择的情感匹配表情包
if chosen_emotion in emotion_map:
chosen = random.choice(emotion_map[chosen_emotion])
else:
chosen = random.choice(sampled_emojis)
# 7. 发送
send_result = await self.ctx.send.emoji(chosen["base64"], stream_id)
if send_result and send_result.get("success"):
return True, f"成功发送表情包:[表情包:{chosen_emotion}]"
return False, "发送表情包失败"
async def on_load(self):
# 从插件配置读取 emoji_chance 来覆盖默认概率
config_result = await self.ctx.config.get("emoji.emoji_chance")
if config_result and isinstance(config_result, dict) and config_result.get("success"):
pass # 配置已在宿主端管理
def create_plugin():
return EmojiPlugin()

View File

@@ -0,0 +1,33 @@
{
"manifest_version": 1,
"name": "LPMM 知识库插件 (Knowledge Search)",
"version": "2.0.0",
"description": "从 LPMM 知识库中搜索相关信息,供 LLM 工具调用",
"author": {
"name": "MaiBot团队",
"url": "https://github.com/MaiM-with-u"
},
"license": "GPL-v3.0-or-later",
"host_application": {
"min_version": "1.0.0"
},
"homepage_url": "https://github.com/MaiM-with-u/maibot",
"repository_url": "https://github.com/MaiM-with-u/maibot",
"keywords": ["knowledge", "lpmm", "search", "tool", "built-in"],
"categories": ["Knowledge", "Tools"],
"default_locale": "zh-CN",
"plugin_info": {
"is_built_in": true,
"plugin_type": "tool_provider",
"components": [
{
"type": "tool",
"name": "lpmm_search_knowledge",
"description": "从知识库中搜索相关信息"
}
]
},
"capabilities": [
"knowledge.search"
]
}

View File

@@ -1,62 +0,0 @@
from typing import Dict, Any
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.knowledge import qa_manager
from src.plugin_system import BaseTool, ToolParamType
logger = get_logger("lpmm_get_knowledge_tool")
class SearchKnowledgeFromLPMMTool(BaseTool):
"""从LPMM知识库中搜索相关信息的工具"""
name = "lpmm_search_knowledge"
description = "从知识库中搜索相关信息,如果你需要知识,就使用这个工具"
parameters = [
("query", ToolParamType.STRING, "搜索查询关键词", True, None),
("limit", ToolParamType.INTEGER, "希望返回的相关知识条数默认5", False, None),
]
available_for_llm = global_config.lpmm_knowledge.enable
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
"""执行知识库搜索
Args:
function_args: 工具参数
Returns:
Dict: 工具执行结果
"""
try:
query: str = function_args.get("query") # type: ignore
limit = function_args.get("limit", 5)
try:
limit_value = int(limit)
except (TypeError, ValueError):
limit_value = 5
limit_value = max(1, limit_value)
# threshold = function_args.get("threshold", 0.4)
# 检查LPMM知识库是否启用
if qa_manager is None:
logger.debug("LPMM知识库已禁用跳过知识获取")
return {"type": "info", "id": query, "content": "LPMM知识库已禁用"}
# 调用知识库搜索
knowledge_info = await qa_manager.get_knowledge(query, limit=limit_value)
logger.debug(f"知识库查询结果: {knowledge_info}")
if knowledge_info:
content = f"你知道这些知识: {knowledge_info}"
else:
content = f"你不太了解有关{query}的知识"
return {"type": "lpmm_knowledge", "id": query, "content": content}
except Exception as e:
# 捕获异常并记录错误
logger.error(f"知识库搜索工具执行失败: {str(e)}")
# 在其他异常情况下,确保 id 仍然是 query (如果它被定义了)
query_id = query if "query" in locals() else "unknown_query"
return {"type": "info", "id": query_id, "content": f"lpmm知识库搜索失败炸了: {str(e)}"}

View File

@@ -0,0 +1,39 @@
"""LPMM 知识库搜索插件 — 新 SDK 版本
提供 LLM 可调用的知识库搜索工具。
"""
from maibot_sdk import MaiBotPlugin, Tool
from maibot_sdk.types import ToolParameterInfo, ToolParamType
class KnowledgePlugin(MaiBotPlugin):
"""LPMM 知识库插件"""
@Tool(
"lpmm_search_knowledge",
description="从知识库中搜索相关信息,如果你需要知识,就使用这个工具",
parameters=[
ToolParameterInfo(name="query", param_type=ToolParamType.STRING, description="搜索查询关键词", required=True),
ToolParameterInfo(name="limit", param_type=ToolParamType.INTEGER, description="希望返回的相关知识条数默认5", required=False, default=5),
],
)
async def handle_lpmm_search_knowledge(self, query: str = "", limit: int = 5, **kwargs):
"""执行知识库搜索"""
if not query:
return {"type": "info", "id": "", "content": "未提供搜索关键词"}
try:
limit_value = max(1, int(limit))
except (TypeError, ValueError):
limit_value = 5
result = await self.ctx.call_capability("knowledge.search", query=query, limit=limit_value)
if result and result.get("success"):
content = result.get("content", f"你不太了解有关{query}的知识")
return {"type": "lpmm_knowledge", "id": query, "content": content}
return {"type": "info", "id": query, "content": f"知识库搜索失败: {result}"}
def create_plugin():
return KnowledgePlugin()

View File

@@ -1,7 +1,7 @@
{ {
"manifest_version": 1, "manifest_version": 1,
"name": "插件和组件管理 (Plugin and Component Management)", "name": "插件和组件管理 (Plugin and Component Management)",
"version": "1.0.0", "version": "2.0.0",
"description": "通过系统API管理插件和组件的生命周期包括加载、卸载、启用和禁用等操作。", "description": "通过系统API管理插件和组件的生命周期包括加载、卸载、启用和禁用等操作。",
"author": { "author": {
"name": "MaiBot团队", "name": "MaiBot团队",
@@ -9,7 +9,7 @@
}, },
"license": "GPL-v3.0-or-later", "license": "GPL-v3.0-or-later",
"host_application": { "host_application": {
"min_version": "0.10.1" "min_version": "1.0.0"
}, },
"homepage_url": "https://github.com/MaiM-with-u/maibot", "homepage_url": "https://github.com/MaiM-with-u/maibot",
"repository_url": "https://github.com/MaiM-with-u/maibot", "repository_url": "https://github.com/MaiM-with-u/maibot",
@@ -28,10 +28,22 @@
"plugin_info": { "plugin_info": {
"is_built_in": true, "is_built_in": true,
"plugin_type": "plugin_management", "plugin_type": "plugin_management",
"capabilities": [
"component.get_all_plugins",
"component.list_loaded_plugins",
"component.list_registered_plugins",
"component.enable",
"component.disable",
"component.load_plugin",
"component.unload_plugin",
"component.reload_plugin",
"send.text",
"config.get"
],
"components": [ "components": [
{ {
"type": "command", "type": "command",
"name": "plugin_management", "name": "management",
"description": "管理插件和组件的生命周期,包括加载、卸载、启用和禁用等操作。" "description": "管理插件和组件的生命周期,包括加载、卸载、启用和禁用等操作。"
} }
] ]

View File

@@ -1,454 +1,279 @@
import asyncio """插件和组件管理 — 新 SDK 版本
from typing import List, Tuple, Type 通过 /pm 命令管理插件和组件的生命周期。
from src.plugin_system import ( """
BasePlugin,
BaseCommand, from maibot_sdk import MaiBotPlugin, Command
CommandInfo,
ConfigField,
register_plugin, _VALID_COMPONENT_TYPES = ("action", "command", "event_handler")
plugin_manage_api,
component_manage_api, HELP_ALL = (
ComponentInfo, "管理命令帮助\n"
ComponentType, "/pm help 管理命令提示\n"
send_api, "/pm plugin 插件管理命令\n"
"/pm component 组件管理命令\n"
"使用 /pm plugin help 或 /pm component help 获取具体帮助"
)
HELP_PLUGIN = (
"插件管理命令帮助\n"
"/pm plugin help 插件管理命令提示\n"
"/pm plugin list 列出所有注册的插件\n"
"/pm plugin list_enabled 列出所有加载(启用)的插件\n"
"/pm plugin load <plugin_name> 加载指定插件\n"
"/pm plugin unload <plugin_name> 卸载指定插件\n"
"/pm plugin reload <plugin_name> 重新加载指定插件\n"
)
HELP_COMPONENT = (
"组件管理命令帮助\n"
"/pm component help 组件管理命令提示\n"
"/pm component list 列出所有注册的组件\n"
"/pm component list enabled <可选: type> 列出所有启用的组件\n"
"/pm component list disabled <可选: type> 列出所有禁用的组件\n"
" - <type> 可选项: local代表当前聊天中的global代表全局的\n"
" - <type> 不填时为 global\n"
"/pm component list type <component_type> 列出已经注册的指定类型的组件\n"
"/pm component enable global <component_name> <component_type> 全局启用组件\n"
"/pm component enable local <component_name> <component_type> 本聊天启用组件\n"
"/pm component disable global <component_name> <component_type> 全局禁用组件\n"
"/pm component disable local <component_name> <component_type> 本聊天禁用组件\n"
" - <component_type> 可选项: action, command, event_handler\n"
) )
class ManagementCommand(BaseCommand): class PluginManagementPlugin(MaiBotPlugin):
command_name: str = "management" """插件和组件管理插件"""
description: str = "管理命令"
command_pattern: str = r"(?P<manage_command>^/pm(\s[a-zA-Z0-9_]+)*\s*$)"
async def execute(self) -> Tuple[bool, str, bool]: @Command(
# sourcery skip: merge-duplicate-blocks "management",
if ( description="管理插件和组件的生命周期",
not self.message pattern=r"(?P<manage_command>^/pm(\s[a-zA-Z0-9_]+)*\s*$)",
or not self.message.message_info )
or not self.message.message_info.user_info async def handle_management(
or str(self.message.message_info.user_info.user_id) not in self.get_config("plugin.permission", []) # type: ignore self, stream_id: str = "", user_id: str = "", matched_groups: dict | None = None, **kwargs
): ):
await self._send_message("你没有权限使用插件管理命令") """处理 /pm 命令"""
# 权限检查
permission_result = await self.ctx.config.get("plugin.permission")
permission_list = permission_result if isinstance(permission_result, list) else []
if str(user_id) not in permission_list:
await self.ctx.send.text("你没有权限使用插件管理命令", stream_id)
return False, "没有权限", True return False, "没有权限", True
if not self.message.chat_stream:
await self._send_message("无法获取聊天流信息") if not stream_id:
return False, "无法获取聊天流信息", True return False, "无法获取聊天流信息", True
self.stream_id = self.message.chat_stream.stream_id
if not self.stream_id: raw_command = (matched_groups or {}).get("manage_command", "").strip()
await self._send_message("无法获取聊天流信息") parts = raw_command.split(" ") if raw_command else ["/pm"]
return False, "无法获取聊天流信息", True n = len(parts)
command_list = self.matched_groups["manage_command"].strip().split(" ")
if len(command_list) == 1: # /pm
await self.show_help("all") if n == 1:
await self.ctx.send.text(HELP_ALL, stream_id)
return True, "帮助已发送", True return True, "帮助已发送", True
if len(command_list) == 2:
match command_list[1]:
case "plugin":
await self.show_help("plugin")
case "component":
await self.show_help("component")
case "help":
await self.show_help("all")
case _:
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
if len(command_list) == 3:
if command_list[1] == "plugin":
match command_list[2]:
case "help":
await self.show_help("plugin")
case "list":
await self._list_registered_plugins()
case "list_enabled":
await self._list_loaded_plugins()
case "rescan":
await self._rescan_plugin_dirs()
case _:
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
elif command_list[1] == "component":
if command_list[2] == "list":
await self._list_all_registered_components()
elif command_list[2] == "help":
await self.show_help("component")
else:
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
else:
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
if len(command_list) == 4:
if command_list[1] == "plugin":
match command_list[2]:
case "load":
await self._load_plugin(command_list[3])
case "unload":
await self._unload_plugin(command_list[3])
case "reload":
await self._reload_plugin(command_list[3])
case "add_dir":
await self._add_dir(command_list[3])
case _:
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
elif command_list[1] == "component":
if command_list[2] != "list":
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
if command_list[3] == "enabled":
await self._list_enabled_components()
elif command_list[3] == "disabled":
await self._list_disabled_components()
else:
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
else:
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
if len(command_list) == 5:
if command_list[1] != "component":
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
if command_list[2] != "list":
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
if command_list[3] == "enabled":
await self._list_enabled_components(target_type=command_list[4])
elif command_list[3] == "disabled":
await self._list_disabled_components(target_type=command_list[4])
elif command_list[3] == "type":
await self._list_registered_components_by_type(command_list[4])
else:
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
if len(command_list) == 6:
if command_list[1] != "component":
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
if command_list[2] == "enable":
if command_list[3] == "global":
await self._globally_enable_component(command_list[4], command_list[5])
elif command_list[3] == "local":
await self._locally_enable_component(command_list[4], command_list[5])
else:
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
elif command_list[2] == "disable":
if command_list[3] == "global":
await self._globally_disable_component(command_list[4], command_list[5])
elif command_list[3] == "local":
await self._locally_disable_component(command_list[4], command_list[5])
else:
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
else:
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
return True, "命令执行完成", True # /pm <sub>
if n == 2:
sub = parts[1]
if sub == "plugin":
await self.ctx.send.text(HELP_PLUGIN, stream_id)
elif sub == "component":
await self.ctx.send.text(HELP_COMPONENT, stream_id)
elif sub == "help":
await self.ctx.send.text(HELP_ALL, stream_id)
else:
await self.ctx.send.text("插件管理命令不合法", stream_id)
return False, "命令不合法", True
return True, "帮助已发送", True
async def show_help(self, target: str): # /pm plugin <action> / /pm component <action>
help_msg = "" if n == 3:
match target: if parts[1] == "plugin":
case "all": await self._handle_plugin_3(parts[2], stream_id)
help_msg = ( elif parts[1] == "component":
"管理命令帮助\n" if parts[2] == "list":
"/pm help 管理命令提示\n" await self._list_all_components(stream_id)
"/pm plugin 插件管理命令\n" elif parts[2] == "help":
"/pm component 组件管理命令\n" await self.ctx.send.text(HELP_COMPONENT, stream_id)
"使用 /pm plugin help 或 /pm component help 获取具体帮助" else:
) await self.ctx.send.text("插件管理命令不合法", stream_id)
case "plugin": return False, "命令不合法", True
help_msg = ( else:
"插件管理命令帮助\n" await self.ctx.send.text("插件管理命令不合法", stream_id)
"/pm plugin help 插件管理命令提示\n" return False, "命令不合法", True
"/pm plugin list 列出所有注册的插件\n" return True, "命令执行完成", True
"/pm plugin list_enabled 列出所有加载(启用)的插件\n"
"/pm plugin rescan 重新扫描所有目录\n" if n == 4:
"/pm plugin load <plugin_name> 加载指定插件\n" if parts[1] == "plugin":
"/pm plugin unload <plugin_name> 卸载指定插件\n" await self._handle_plugin_4(parts[2], parts[3], stream_id)
"/pm plugin reload <plugin_name> 重新加载指定插件\n" elif parts[1] == "component":
"/pm plugin add_dir <directory_path> 添加插件目录\n" if parts[2] == "list":
) await self._handle_component_list_4(parts[3], stream_id)
case "component": else:
help_msg = ( await self.ctx.send.text("插件管理命令不合法", stream_id)
"组件管理命令帮助\n" return False, "命令不合法", True
"/pm component help 组件管理命令提示\n" else:
"/pm component list 列出所有注册的组件\n" await self.ctx.send.text("插件管理命令不合法", stream_id)
"/pm component list enabled <可选: type> 列出所有启用的组件\n" return False, "命令不合法", True
"/pm component list disabled <可选: type> 列出所有禁用的组件\n" return True, "命令执行完成", True
" - <type> 可选项: local代表当前聊天中的global代表全局的\n"
" - <type> 不填时为 global\n" if n == 5:
"/pm component list type <component_type> 列出已经注册的指定类型的组件\n" if parts[1] != "component" or parts[2] != "list":
"/pm component enable global <component_name> <component_type> 全局启用组件\n" await self.ctx.send.text("插件管理命令不合法", stream_id)
"/pm component enable local <component_name> <component_type> 本聊天启用组件\n" return False, "命令不合法", True
"/pm component disable global <component_name> <component_type> 全局禁用组件\n" await self._handle_component_list_5(parts[3], parts[4], stream_id)
"/pm component disable local <component_name> <component_type> 本聊天禁用组件\n" return True, "命令执行完成", True
" - <component_type> 可选项: action, command, event_handler\n"
) if n == 6:
if parts[1] != "component":
await self.ctx.send.text("插件管理命令不合法", stream_id)
return False, "命令不合法", True
await self._handle_component_toggle(parts[2], parts[3], parts[4], parts[5], stream_id)
return True, "命令执行完成", True
await self.ctx.send.text("插件管理命令不合法", stream_id)
return False, "命令不合法", True
# ------ plugin 子命令 ------
async def _handle_plugin_3(self, action: str, stream_id: str):
match action:
case "help":
await self.ctx.send.text(HELP_PLUGIN, stream_id)
case "list":
result = await self.ctx.component.list_registered_plugins()
plugins = result if isinstance(result, list) else []
await self.ctx.send.text(f"已注册的插件: {', '.join(plugins) if plugins else ''}", stream_id)
case "list_enabled":
result = await self.ctx.component.list_loaded_plugins()
plugins = result if isinstance(result, list) else []
await self.ctx.send.text(f"已加载的插件: {', '.join(plugins) if plugins else ''}", stream_id)
case _: case _:
await self.ctx.send.text("插件管理命令不合法", stream_id)
async def _handle_plugin_4(self, action: str, name: str, stream_id: str):
match action:
case "load":
result = await self.ctx.component.load_plugin(name)
ok = result.get("success", False) if isinstance(result, dict) else bool(result)
msg = f"插件加载成功: {name}" if ok else f"插件加载失败: {name}"
await self.ctx.send.text(msg, stream_id)
case "unload":
result = await self.ctx.component.unload_plugin(name)
ok = result.get("success", False) if isinstance(result, dict) else bool(result)
msg = f"插件卸载成功: {name}" if ok else f"插件卸载失败: {name}"
await self.ctx.send.text(msg, stream_id)
case "reload":
result = await self.ctx.component.reload_plugin(name)
ok = result.get("success", False) if isinstance(result, dict) else bool(result)
msg = f"插件重新加载成功: {name}" if ok else f"插件重新加载失败: {name}"
await self.ctx.send.text(msg, stream_id)
case _:
await self.ctx.send.text("插件管理命令不合法", stream_id)
# ------ component 子命令 ------
async def _list_all_components(self, stream_id: str):
result = await self.ctx.component.get_all_plugins()
if not result:
await self.ctx.send.text("没有注册的组件", stream_id)
return
components = self._extract_components(result)
if not components:
await self.ctx.send.text("没有注册的组件", stream_id)
return
text = ", ".join(f"{c['name']} ({c['type']})" for c in components)
await self.ctx.send.text(f"已注册的组件: {text}", stream_id)
async def _handle_component_list_4(self, sub: str, stream_id: str):
if sub == "enabled":
await self._list_filtered_components("enabled", "global", stream_id)
elif sub == "disabled":
await self._list_filtered_components("disabled", "global", stream_id)
else:
await self.ctx.send.text("插件管理命令不合法", stream_id)
async def _handle_component_list_5(self, sub: str, arg: str, stream_id: str):
if sub in ("enabled", "disabled"):
await self._list_filtered_components(sub, arg, stream_id)
elif sub == "type":
if arg not in _VALID_COMPONENT_TYPES:
await self.ctx.send.text(f"未知组件类型: {arg}", stream_id)
return return
await self._send_message(help_msg) result = await self.ctx.component.get_all_plugins()
components = [c for c in self._extract_components(result) if c.get("type") == arg]
async def _list_loaded_plugins(self): if not components:
plugins = plugin_manage_api.list_loaded_plugins() await self.ctx.send.text(f"没有注册的 {arg} 组件", stream_id)
await self._send_message(f"已加载的插件: {', '.join(plugins)}") return
text = ", ".join(f"{c['name']} ({c['type']})" for c in components)
async def _list_registered_plugins(self): await self.ctx.send.text(f"注册的 {arg} 组件: {text}", stream_id)
plugins = plugin_manage_api.list_registered_plugins()
await self._send_message(f"已注册的插件: {', '.join(plugins)}")
async def _rescan_plugin_dirs(self):
plugin_manage_api.rescan_plugin_directory()
await self._send_message("插件目录重新扫描执行中")
async def _load_plugin(self, plugin_name: str):
success, count = plugin_manage_api.load_plugin(plugin_name)
if success:
await self._send_message(f"插件加载成功: {plugin_name}")
else: else:
if count == 0: await self.ctx.send.text("插件管理命令不合法", stream_id)
await self._send_message(f"插件{plugin_name}为禁用状态")
await self._send_message(f"插件加载失败: {plugin_name}")
async def _unload_plugin(self, plugin_name: str): async def _list_filtered_components(self, filter_mode: str, scope: str, stream_id: str):
success = await plugin_manage_api.remove_plugin(plugin_name) result = await self.ctx.component.get_all_plugins()
if success: all_components = self._extract_components(result)
await self._send_message(f"插件卸载成功: {plugin_name}") if not all_components:
await self.ctx.send.text("没有注册的组件", stream_id)
return
if filter_mode == "enabled":
filtered = [c for c in all_components if c.get("enabled", False)]
label = "已启用"
else: else:
await self._send_message(f"插件卸载失败: {plugin_name}") filtered = [c for c in all_components if not c.get("enabled", False)]
label = "已禁用"
async def _reload_plugin(self, plugin_name: str): scope_label = "全局" if scope == "global" else "本聊天"
success = await plugin_manage_api.reload_plugin(plugin_name) if not filtered:
if success: await self.ctx.send.text(f"没有满足条件的{label}{scope_label}组件", stream_id)
await self._send_message(f"插件重新加载成功: {plugin_name}") return
text = ", ".join(f"{c['name']} ({c['type']})" for c in filtered)
await self.ctx.send.text(f"满足条件的{label}{scope_label}组件: {text}", stream_id)
async def _handle_component_toggle(
self, action: str, scope: str, comp_name: str, comp_type: str, stream_id: str
):
if action not in ("enable", "disable"):
await self.ctx.send.text("插件管理命令不合法", stream_id)
return
if scope not in ("global", "local"):
await self.ctx.send.text("插件管理命令不合法", stream_id)
return
if comp_type not in _VALID_COMPONENT_TYPES:
await self.ctx.send.text(f"未知组件类型: {comp_type}", stream_id)
return
if action == "enable":
result = await self.ctx.component.enable_component(
comp_name, comp_type, scope=scope, stream_id=stream_id
)
else: else:
await self._send_message(f"插件重新加载失败: {plugin_name}") result = await self.ctx.component.disable_component(
comp_name, comp_type, scope=scope, stream_id=stream_id
)
async def _add_dir(self, dir_path: str): ok = result.get("success", False) if isinstance(result, dict) else bool(result)
await self._send_message(f"正在添加插件目录: {dir_path}") scope_label = "全局" if scope == "global" else "本地"
success = plugin_manage_api.add_plugin_directory(dir_path) action_label = "启用" if action == "enable" else "禁用"
await asyncio.sleep(0.5) # 防止乱序发送 status = "成功" if ok else "失败"
if success: await self.ctx.send.text(f"{scope_label}{action_label}组件{status}: {comp_name}", stream_id)
await self._send_message(f"插件目录添加成功: {dir_path}")
else:
await self._send_message(f"插件目录添加失败: {dir_path}")
def _fetch_all_registered_components(self) -> List[ComponentInfo]: # ------ helpers ------
all_plugin_info = component_manage_api.get_all_plugin_info()
if not all_plugin_info: @staticmethod
def _extract_components(result) -> list[dict]:
"""从 get_all_plugins 结果中提取所有组件列表"""
if not result:
return [] return []
if isinstance(result, dict):
components_info: List[ComponentInfo] = [] components = []
for plugin_info in all_plugin_info.values(): for plugin_info in result.values():
components_info.extend(plugin_info.components) if isinstance(plugin_info, dict):
return components_info components.extend(plugin_info.get("components", []))
return components
def _fetch_locally_disabled_components(self) -> List[str]: return []
locally_disabled_components_actions = component_manage_api.get_locally_disabled_components(
self.message.chat_stream.stream_id, ComponentType.ACTION
)
locally_disabled_components_commands = component_manage_api.get_locally_disabled_components(
self.message.chat_stream.stream_id, ComponentType.COMMAND
)
locally_disabled_components_event_handlers = component_manage_api.get_locally_disabled_components(
self.message.chat_stream.stream_id, ComponentType.EVENT_HANDLER
)
return (
locally_disabled_components_actions
+ locally_disabled_components_commands
+ locally_disabled_components_event_handlers
)
async def _list_all_registered_components(self):
components_info = self._fetch_all_registered_components()
if not components_info:
await self._send_message("没有注册的组件")
return
all_components_str = ", ".join(
f"{component.name} ({component.component_type})" for component in components_info
)
await self._send_message(f"已注册的组件: {all_components_str}")
async def _list_enabled_components(self, target_type: str = "global"):
components_info = self._fetch_all_registered_components()
if not components_info:
await self._send_message("没有注册的组件")
return
if target_type == "global":
enabled_components = [component for component in components_info if component.enabled]
if not enabled_components:
await self._send_message("没有满足条件的已启用全局组件")
return
enabled_components_str = ", ".join(
f"{component.name} ({component.component_type})" for component in enabled_components
)
await self._send_message(f"满足条件的已启用全局组件: {enabled_components_str}")
elif target_type == "local":
locally_disabled_components = self._fetch_locally_disabled_components()
enabled_components = [
component
for component in components_info
if (component.name not in locally_disabled_components and component.enabled)
]
if not enabled_components:
await self._send_message("本聊天没有满足条件的已启用组件")
return
enabled_components_str = ", ".join(
f"{component.name} ({component.component_type})" for component in enabled_components
)
await self._send_message(f"本聊天满足条件的已启用组件: {enabled_components_str}")
async def _list_disabled_components(self, target_type: str = "global"):
components_info = self._fetch_all_registered_components()
if not components_info:
await self._send_message("没有注册的组件")
return
if target_type == "global":
disabled_components = [component for component in components_info if not component.enabled]
if not disabled_components:
await self._send_message("没有满足条件的已禁用全局组件")
return
disabled_components_str = ", ".join(
f"{component.name} ({component.component_type})" for component in disabled_components
)
await self._send_message(f"满足条件的已禁用全局组件: {disabled_components_str}")
elif target_type == "local":
locally_disabled_components = self._fetch_locally_disabled_components()
disabled_components = [
component
for component in components_info
if (component.name in locally_disabled_components or not component.enabled)
]
if not disabled_components:
await self._send_message("本聊天没有满足条件的已禁用组件")
return
disabled_components_str = ", ".join(
f"{component.name} ({component.component_type})" for component in disabled_components
)
await self._send_message(f"本聊天满足条件的已禁用组件: {disabled_components_str}")
async def _list_registered_components_by_type(self, target_type: str):
match target_type:
case "action":
component_type = ComponentType.ACTION
case "command":
component_type = ComponentType.COMMAND
case "event_handler":
component_type = ComponentType.EVENT_HANDLER
case _:
await self._send_message(f"未知组件类型: {target_type}")
return
components_info = component_manage_api.get_components_info_by_type(component_type)
if not components_info:
await self._send_message(f"没有注册的 {target_type} 组件")
return
components_str = ", ".join(
f"{name} ({component.component_type})" for name, component in components_info.items()
)
await self._send_message(f"注册的 {target_type} 组件: {components_str}")
async def _globally_enable_component(self, component_name: str, component_type: str):
match component_type:
case "action":
target_component_type = ComponentType.ACTION
case "command":
target_component_type = ComponentType.COMMAND
case "event_handler":
target_component_type = ComponentType.EVENT_HANDLER
case _:
await self._send_message(f"未知组件类型: {component_type}")
return
if component_manage_api.globally_enable_component(component_name, target_component_type):
await self._send_message(f"全局启用组件成功: {component_name}")
else:
await self._send_message(f"全局启用组件失败: {component_name}")
async def _globally_disable_component(self, component_name: str, component_type: str):
match component_type:
case "action":
target_component_type = ComponentType.ACTION
case "command":
target_component_type = ComponentType.COMMAND
case "event_handler":
target_component_type = ComponentType.EVENT_HANDLER
case _:
await self._send_message(f"未知组件类型: {component_type}")
return
success = await component_manage_api.globally_disable_component(component_name, target_component_type)
if success:
await self._send_message(f"全局禁用组件成功: {component_name}")
else:
await self._send_message(f"全局禁用组件失败: {component_name}")
async def _locally_enable_component(self, component_name: str, component_type: str):
match component_type:
case "action":
target_component_type = ComponentType.ACTION
case "command":
target_component_type = ComponentType.COMMAND
case "event_handler":
target_component_type = ComponentType.EVENT_HANDLER
case _:
await self._send_message(f"未知组件类型: {component_type}")
return
if component_manage_api.locally_enable_component(
component_name,
target_component_type,
self.message.chat_stream.stream_id,
):
await self._send_message(f"本地启用组件成功: {component_name}")
else:
await self._send_message(f"本地启用组件失败: {component_name}")
async def _locally_disable_component(self, component_name: str, component_type: str):
match component_type:
case "action":
target_component_type = ComponentType.ACTION
case "command":
target_component_type = ComponentType.COMMAND
case "event_handler":
target_component_type = ComponentType.EVENT_HANDLER
case _:
await self._send_message(f"未知组件类型: {component_type}")
return
if component_manage_api.locally_disable_component(
component_name,
target_component_type,
self.message.chat_stream.stream_id,
):
await self._send_message(f"本地禁用组件成功: {component_name}")
else:
await self._send_message(f"本地禁用组件失败: {component_name}")
async def _send_message(self, message: str):
await send_api.text_to_stream(message, self.stream_id, typing=False, storage_message=False)
@register_plugin def create_plugin():
class PluginManagementPlugin(BasePlugin): return PluginManagementPlugin()
plugin_name: str = "plugin_management_plugin"
enable_plugin: bool = False
dependencies: list[str] = []
python_dependencies: list[str] = []
config_file_name: str = "config.toml"
config_schema: dict = {
"plugin": {
"enabled": ConfigField(bool, default=False, description="是否启用插件"),
"config_version": ConfigField(type=str, default="1.1.0", description="配置文件版本"),
"permission": ConfigField(
list, default=[], description="有权限使用插件管理命令的用户列表请填写字符串形式的用户ID"
),
},
}
def get_plugin_components(self) -> List[Tuple[CommandInfo, Type[BaseCommand]]]:
components = []
if self.get_config("plugin.enabled", True):
components.append((ManagementCommand.get_command_info(), ManagementCommand))
return components

View File

@@ -1,7 +1,7 @@
{ {
"manifest_version": 1, "manifest_version": 1,
"name": "文本转语音插件 (Text-to-Speech)", "name": "文本转语音插件 (Text-to-Speech)",
"version": "0.1.0", "version": "2.0.0",
"description": "将文本转换为语音进行播放的插件,支持多种语音模式和智能语音输出场景判断。", "description": "将文本转换为语音进行播放的插件,支持多种语音模式和智能语音输出场景判断。",
"author": { "author": {
"name": "MaiBot团队", "name": "MaiBot团队",
@@ -10,7 +10,7 @@
"license": "GPL-v3.0-or-later", "license": "GPL-v3.0-or-later",
"host_application": { "host_application": {
"min_version": "0.8.0" "min_version": "1.0.0"
}, },
"homepage_url": "https://github.com/MaiM-with-u/maibot", "homepage_url": "https://github.com/MaiM-with-u/maibot",
"repository_url": "https://github.com/MaiM-with-u/maibot", "repository_url": "https://github.com/MaiM-with-u/maibot",

View File

@@ -1,145 +1,55 @@
from src.plugin_system.apis.plugin_register_api import register_plugin """TTS 插件 — 新 SDK 版本
from src.plugin_system.base.base_plugin import BasePlugin
from src.plugin_system.base.component_types import ComponentInfo
from src.common.logger import get_logger
from src.plugin_system.base.base_action import BaseAction, ActionActivationType
from src.plugin_system.base.config_types import ConfigField
from typing import Tuple, List, Type
logger = get_logger("tts") 将文本转换为语音进行播放。
"""
import re
from maibot_sdk import MaiBotPlugin, Action
from maibot_sdk.types import ActivationType
class TTSAction(BaseAction): class TTSPlugin(MaiBotPlugin):
"""TTS语音转换动作处理类""" """文本转语音插件"""
# 激活设置
activation_type = ActionActivationType.KEYWORD
activation_keywords = ["语音", "tts", "播报", "读出来", "语音播放", "", "朗读"]
keyword_case_sensitive = False
parallel_action = False
# 动作基本信息
action_name = "tts_action"
action_description = "将文本转换为语音进行播放,适用于需要语音输出的场景"
# 动作参数定义
action_parameters = {
"voice_text": "你想用语音表达的内容,这段内容将会以语音形式发出",
}
# 动作使用场景
action_require = [
"当需要发送语音信息时使用",
"当用户明确要求使用语音功能时使用",
"当表达内容更适合用语音而不是文字传达时使用",
"当用户想听到语音回答而非阅读文本时使用",
]
# 关联类型
associated_types = ["tts_text"]
async def execute(self) -> Tuple[bool, str]:
"""处理TTS文本转语音动作"""
logger.info(f"{self.log_prefix} 执行TTS动作: {self.reasoning}")
# 获取要转换的文本
text = self.action_data.get("voice_text")
@Action(
"tts_action",
description="将文本转换为语音进行播放,适用于需要语音输出的场景",
activation_type=ActivationType.KEYWORD,
activation_keywords=["语音", "tts", "播报", "读出来", "语音播放", "", "朗读"],
parallel_action=False,
action_parameters={"voice_text": "你想用语音表达的内容,这段内容将会以语音形式发出"},
action_require=[
"当需要发送语音信息时使用",
"当用户明确要求使用语音功能时使用",
"当表达内容更适合用语音而不是文字传达时使用",
"当用户想听到语音回答而非阅读文本时使用",
],
associated_types=["tts_text"],
)
async def handle_tts_action(self, stream_id: str = "", action_data: dict = None, reasoning: str = "", **kwargs):
"""处理 TTS 文本转语音动作"""
action_data = action_data or {}
text = action_data.get("voice_text", "")
if not text: if not text:
logger.error(f"{self.log_prefix} 执行TTS动作时未提供文本内容")
return False, "执行TTS动作失败未提供文本内容" return False, "执行TTS动作失败未提供文本内容"
# 确保文本适合TTS使用 # 文本预处理
processed_text = self._process_text_for_tts(text) processed_text = re.sub(r"([!?,.;:。!?,、;:])\1+", r"\1", text)
try:
# 发送TTS消息
await self.send_custom(message_type="tts_text", content=processed_text)
# 记录动作信息
await self.store_action_info(
action_build_into_prompt=True, action_prompt_display="已经发送了语音消息。", action_done=True
)
logger.info(f"{self.log_prefix} TTS动作执行成功文本长度: {len(processed_text)}")
return True, "TTS动作执行成功"
except Exception as e:
logger.error(f"{self.log_prefix} 执行TTS动作时出错: {e}")
return False, f"执行TTS动作时出错: {e}"
def _process_text_for_tts(self, text: str) -> str:
"""
处理文本使其更适合TTS使用
- 移除不必要的特殊字符和表情符号
- 修正标点符号以提高语音质量
- 优化文本结构使语音更流畅
"""
# 这里可以添加文本处理逻辑
# 例如:移除多余的标点、表情符号,优化语句结构等
# 简单示例实现
processed_text = text
# 移除多余的标点符号
import re
processed_text = re.sub(r"([!?,.;:。!?,、;:])\1+", r"\1", processed_text)
# 确保句子结尾有合适的标点
if not any(processed_text.endswith(end) for end in [".", "?", "!", "", "", ""]): if not any(processed_text.endswith(end) for end in [".", "?", "!", "", "", ""]):
processed_text = f"{processed_text}" processed_text = f"{processed_text}"
return processed_text # 发送自定义 tts 消息
result = await self.ctx.call_capability(
"send.custom",
message_type="tts_text",
content=processed_text,
stream_id=stream_id,
)
if result and result.get("success"):
return True, "TTS动作执行成功"
return False, f"TTS动作执行失败: {result}"
@register_plugin def create_plugin():
class TTSPlugin(BasePlugin): return TTSPlugin()
"""TTS插件
- 这是文字转语音插件
- Normal模式下依靠关键词触发
- Focus模式下由LLM判断触发
- 具有一定的文本预处理能力
"""
# 插件基本信息
plugin_name: str = "tts_plugin" # 内部标识符
enable_plugin: bool = True
dependencies: list[str] = [] # 插件依赖列表
python_dependencies: list[str] = [] # Python包依赖列表
config_file_name: str = "config.toml"
# 配置节描述
config_section_descriptions = {
"plugin": "插件基本信息配置",
"components": "组件启用控制",
"logging": "日志记录相关配置",
}
# 配置Schema定义
config_schema: dict = {
"plugin": {
"name": ConfigField(type=str, default="tts_plugin", description="插件名称", required=True),
"version": ConfigField(type=str, default="0.1.0", description="插件版本号"),
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
"description": ConfigField(type=str, default="文字转语音插件", description="插件描述", required=True),
},
"components": {"enable_tts": ConfigField(type=bool, default=True, description="是否启用TTS Action")},
"logging": {
"level": ConfigField(
type=str, default="INFO", description="日志记录级别", choices=["DEBUG", "INFO", "WARNING", "ERROR"]
),
"prefix": ConfigField(type=str, default="[TTS]", description="日志记录前缀"),
},
}
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
"""返回插件包含的组件列表"""
# 从配置获取组件启用状态
enable_tts = self.get_config("components.enable_tts", True)
components = [] # 添加Action组件
if enable_tts:
components.append((TTSAction.get_action_info(), TTSAction))
return components

7
src/services/__init__.py Normal file
View File

@@ -0,0 +1,7 @@
"""
核心服务层
提供与具体插件系统无关的核心业务服务。
内部模块chat、dream、memory 等)应直接使用此层,
而 plugin_system.apis 仅作为面向插件的薄包装。
"""

View File

@@ -0,0 +1,159 @@
"""
聊天服务模块
提供聊天信息查询和管理的核心功能。
"""
from typing import Any, Dict, List, Optional
from enum import Enum
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
from src.common.logger import get_logger
logger = get_logger("chat_service")
class SpecialTypes(Enum):
"""特殊枚举类型"""
ALL_PLATFORMS = "all_platforms"
class ChatManager:
"""聊天管理器 - 负责聊天信息的查询和管理"""
@staticmethod
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
# sourcery skip: for-append-to-extend
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = []
try:
for _, stream in _chat_manager.sessions.items():
if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform:
streams.append(stream)
logger.debug(f"[ChatService] 获取到 {len(streams)}{platform} 平台的聊天流")
except Exception as e:
logger.error(f"[ChatService] 获取聊天流失败: {e}")
return streams
@staticmethod
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
# sourcery skip: for-append-to-extend
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = []
try:
for _, stream in _chat_manager.sessions.items():
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.is_group_session:
streams.append(stream)
logger.debug(f"[ChatService] 获取到 {len(streams)}{platform} 平台的群聊流")
except Exception as e:
logger.error(f"[ChatService] 获取群聊流失败: {e}")
return streams
@staticmethod
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
# sourcery skip: for-append-to-extend
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = []
try:
for _, stream in _chat_manager.sessions.items():
if (
platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform
) and not stream.is_group_session:
streams.append(stream)
logger.debug(f"[ChatService] 获取到 {len(streams)}{platform} 平台的私聊流")
except Exception as e:
logger.error(f"[ChatService] 获取私聊流失败: {e}")
return streams
@staticmethod
def get_group_stream_by_group_id(
group_id: str, platform: Optional[str] | SpecialTypes = "qq"
) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast
if not isinstance(group_id, str):
raise TypeError("group_id 必须是字符串类型")
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
if not group_id:
raise ValueError("group_id 不能为空")
try:
for _, stream in _chat_manager.sessions.items():
if (
stream.is_group_session
and str(stream.group_id) == str(group_id)
and (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform)
):
logger.debug(f"[ChatService] 找到群ID {group_id} 的聊天流")
return stream
logger.warning(f"[ChatService] 未找到群ID {group_id} 的聊天流")
except Exception as e:
logger.error(f"[ChatService] 查找群聊流失败: {e}")
return None
@staticmethod
def get_private_stream_by_user_id(
user_id: str, platform: Optional[str] | SpecialTypes = "qq"
) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast
if not isinstance(user_id, str):
raise TypeError("user_id 必须是字符串类型")
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
if not user_id:
raise ValueError("user_id 不能为空")
try:
for _, stream in _chat_manager.sessions.items():
if (
not stream.is_group_session
and str(stream.user_id) == str(user_id)
and (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform)
):
logger.debug(f"[ChatService] 找到用户ID {user_id} 的私聊流")
return stream
logger.warning(f"[ChatService] 未找到用户ID {user_id} 的私聊流")
except Exception as e:
logger.error(f"[ChatService] 查找私聊流失败: {e}")
return None
@staticmethod
def get_stream_type(chat_stream: BotChatSession) -> str:
if not isinstance(chat_stream, BotChatSession):
raise TypeError("chat_stream 必须是 BotChatSession 类型")
if not chat_stream:
raise ValueError("chat_stream 不能为 None")
return "group" if chat_stream.is_group_session else "private"
@staticmethod
def get_stream_info(chat_stream: BotChatSession) -> Dict[str, Any]:
if not chat_stream:
raise ValueError("chat_stream 不能为 None")
if not isinstance(chat_stream, BotChatSession):
raise TypeError("chat_stream 必须是 BotChatSession 类型")
try:
info: Dict[str, Any] = {
"session_id": chat_stream.session_id,
"platform": chat_stream.platform,
"type": ChatManager.get_stream_type(chat_stream),
}
if chat_stream.is_group_session:
info["group_id"] = chat_stream.group_id
if chat_stream.context and chat_stream.context.message and chat_stream.context.message.message_info.group_info:
info["group_name"] = chat_stream.context.message.message_info.group_info.group_name or "未知群聊"
else:
info["group_name"] = "未知群聊"
else:
info["user_id"] = chat_stream.user_id
if chat_stream.context and chat_stream.context.message and chat_stream.context.message.message_info.user_info:
info["user_name"] = chat_stream.context.message.message_info.user_info.user_nickname
else:
info["user_name"] = "未知用户"
return info
except Exception as e:
logger.error(f"[ChatService] 获取聊天流信息失败: {e}")
return {}

View File

@@ -1,28 +1,19 @@
"""配置API模块 """配置服务模块
提供配置读取和用户信息获取等功能 提供配置读取的核心功能
使用方式
from src.plugin_system.apis import config_api
value = config_api.get_global_config("section.key")
platform, user_id = await config_api.get_user_id_by_person_name("用户名")
""" """
from typing import Any from typing import Any
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config
logger = get_logger("config_api") logger = get_logger("config_service")
# =============================================================================
# 配置访问API函数
# =============================================================================
def get_global_config(key: str, default: Any = None) -> Any: def get_global_config(key: str, default: Any = None) -> Any:
""" """
安全地从全局配置中获取一个值 安全地从全局配置中获取一个值
插件应使用此方法读取全局配置以保证只读和隔离性
Args: Args:
key: 命名空间式配置键名使用嵌套访问 "section.subsection.key"大小写敏感 key: 命名空间式配置键名使用嵌套访问 "section.subsection.key"大小写敏感
@@ -31,7 +22,6 @@ def get_global_config(key: str, default: Any = None) -> Any:
Returns: Returns:
Any: 配置值或默认值 Any: 配置值或默认值
""" """
# 支持嵌套键访问
keys = key.split(".") keys = key.split(".")
current = global_config current = global_config
@@ -43,7 +33,7 @@ def get_global_config(key: str, default: Any = None) -> Any:
raise KeyError(f"配置中不存在子空间或键 '{k}'") raise KeyError(f"配置中不存在子空间或键 '{k}'")
return current return current
except Exception as e: except Exception as e:
logger.warning(f"[ConfigAPI] 获取全局配置 {key} 失败: {e}") logger.warning(f"[ConfigService] 获取全局配置 {key} 失败: {e}")
return default return default
@@ -59,7 +49,6 @@ def get_plugin_config(plugin_config: dict, key: str, default: Any = None) -> Any
Returns: Returns:
Any: 配置值或默认值 Any: 配置值或默认值
""" """
# 支持嵌套键访问
keys = key.split(".") keys = key.split(".")
current = plugin_config current = plugin_config
@@ -73,5 +62,5 @@ def get_plugin_config(plugin_config: dict, key: str, default: Any = None) -> Any
raise KeyError(f"配置中不存在子空间或键 '{k}'") raise KeyError(f"配置中不存在子空间或键 '{k}'")
return current return current
except Exception as e: except Exception as e:
logger.warning(f"[ConfigAPI] 获取插件配置 {key} 失败: {e}") logger.warning(f"[ConfigService] 获取插件配置 {key} 失败: {e}")
return default return default

View File

@@ -1,6 +1,6 @@
"""数据库API模块 """数据库服务模块
提供数据库操作相关功能统一使用 SQLModel/SQLAlchemy 兼容接口 提供数据库操作相关的核心功能
""" """
import json import json
@@ -10,7 +10,7 @@ from typing import Any, Optional
from src.common.logger import get_logger from src.common.logger import get_logger
logger = get_logger("database_api") logger = get_logger("database_service")
def _to_dict(record: Any) -> dict[str, Any]: def _to_dict(record: Any) -> dict[str, Any]:
@@ -73,7 +73,7 @@ async def db_query(
return query.count() return query.count()
except Exception as e: except Exception as e:
logger.error(f"[DatabaseAPI] 数据库操作出错: {e}") logger.error(f"[DatabaseService] 数据库操作出错: {e}")
traceback.print_exc() traceback.print_exc()
if query_type == "get": if query_type == "get":
return None if single_result else [] return None if single_result else []
@@ -93,7 +93,7 @@ async def db_save(model_class, data: dict[str, Any], key_field: Optional[str] =
new_record = model_class.create(**data) new_record = model_class.create(**data)
return _to_dict(new_record) return _to_dict(new_record)
except Exception as e: except Exception as e:
logger.error(f"[DatabaseAPI] 保存数据库记录出错: {e}") logger.error(f"[DatabaseService] 保存数据库记录出错: {e}")
traceback.print_exc() traceback.print_exc()
return None return None
@@ -119,7 +119,7 @@ async def db_get(
return results[0] if results else None return results[0] if results else None
return results return results
except Exception as e: except Exception as e:
logger.error(f"[DatabaseAPI] 获取数据库记录出错: {e}") logger.error(f"[DatabaseService] 获取数据库记录出错: {e}")
traceback.print_exc() traceback.print_exc()
return None if single_result else [] return None if single_result else []
@@ -163,11 +163,11 @@ async def store_action_info(
ActionRecords, data=record_data, key_field="action_id", key_value=record_data["action_id"] ActionRecords, data=record_data, key_field="action_id", key_value=record_data["action_id"]
) )
if saved_record: if saved_record:
logger.debug(f"[DatabaseAPI] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})") logger.debug(f"[DatabaseService] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})")
else: else:
logger.error(f"[DatabaseAPI] 存储动作信息失败: {action_name}") logger.error(f"[DatabaseService] 存储动作信息失败: {action_name}")
return saved_record return saved_record
except Exception as e: except Exception as e:
logger.error(f"[DatabaseAPI] 存储动作信息时发生错误: {e}") logger.error(f"[DatabaseService] 存储动作信息时发生错误: {e}")
traceback.print_exc() traceback.print_exc()
return None return None

View File

@@ -0,0 +1,406 @@
"""
表情服务模块
提供表情包相关的核心功能。
"""
import base64
import os
import random
import uuid
from typing import Any, Dict, List, Optional, Tuple
from src.chat.emoji_system.emoji_manager import emoji_manager, EMOJI_DIR
from src.common.logger import get_logger
from src.common.utils.utils_image import ImageUtils
from src.config.config import global_config
logger = get_logger("emoji_service")
# =============================================================================
# 表情包获取函数
# =============================================================================
async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]:
"""根据描述选择表情包"""
if not description:
raise ValueError("描述不能为空")
if not isinstance(description, str):
raise TypeError("描述必须是字符串类型")
try:
logger.debug(f"[EmojiService] 根据描述获取表情包: {description}")
emoji_obj = await emoji_manager.get_emoji_for_emotion(description)
if not emoji_obj:
logger.warning(f"[EmojiService] 未找到匹配描述 '{description}' 的表情包")
return None
emoji_path = str(emoji_obj.full_path)
emoji_description = emoji_obj.description
matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else ""
emoji_base64 = ImageUtils.image_path_to_base64(emoji_path)
if not emoji_base64:
logger.error(f"[EmojiService] 无法将表情包文件转换为base64: {emoji_path}")
return None
logger.debug(f"[EmojiService] 成功获取表情包: {emoji_description}, 匹配情感: {matched_emotion}")
return emoji_base64, emoji_description, matched_emotion
except Exception as e:
logger.error(f"[EmojiService] 获取表情包失败: {e}")
return None
async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
"""随机获取指定数量的表情包"""
if not isinstance(count, int):
raise TypeError("count 必须是整数类型")
if count < 0:
raise ValueError("count 不能为负数")
if count == 0:
logger.warning("[EmojiService] count 为0返回空列表")
return []
try:
all_emojis = emoji_manager.emojis
if not all_emojis:
logger.warning("[EmojiService] 没有可用的表情包")
return []
valid_emojis = [emoji for emoji in all_emojis if not emoji.is_deleted]
if not valid_emojis:
logger.warning("[EmojiService] 没有有效的表情包")
return []
if len(valid_emojis) < count:
logger.debug(
f"[EmojiService] 有效表情包数量 ({len(valid_emojis)}) 少于请求的数量 ({count}),将返回所有有效表情包"
)
count = len(valid_emojis)
selected_emojis = random.sample(valid_emojis, count)
results = []
for selected_emoji in selected_emojis:
emoji_base64 = ImageUtils.image_path_to_base64(str(selected_emoji.full_path))
if not emoji_base64:
logger.error(f"[EmojiService] 无法转换表情包为base64: {selected_emoji.full_path}")
continue
matched_emotion = random.choice(selected_emoji.emotion) if selected_emoji.emotion else "随机表情"
emoji_manager.update_emoji_usage(selected_emoji)
results.append((emoji_base64, selected_emoji.description, matched_emotion))
if not results and count > 0:
logger.warning("[EmojiService] 随机获取表情包失败,没有一个可以成功处理")
return []
logger.debug(f"[EmojiService] 成功获取 {len(results)} 个随机表情包")
return results
except Exception as e:
logger.error(f"[EmojiService] 获取随机表情包失败: {e}")
return []
async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
"""根据情感标签获取表情包"""
if not emotion:
raise ValueError("情感标签不能为空")
if not isinstance(emotion, str):
raise TypeError("情感标签必须是字符串类型")
try:
logger.info(f"[EmojiService] 根据情感获取表情包: {emotion}")
all_emojis = emoji_manager.emojis
matching_emojis = []
matching_emojis.extend(
emoji_obj
for emoji_obj in all_emojis
if not emoji_obj.is_deleted and emotion.lower() in [e.lower() for e in emoji_obj.emotion]
)
if not matching_emojis:
logger.warning(f"[EmojiService] 未找到匹配情感 '{emotion}' 的表情包")
return None
selected_emoji = random.choice(matching_emojis)
emoji_base64 = ImageUtils.image_path_to_base64(selected_emoji.full_path)
if not emoji_base64:
logger.error(f"[EmojiService] 无法转换表情包为base64: {selected_emoji.full_path}")
return None
emoji_manager.update_emoji_usage(selected_emoji)
logger.info(f"[EmojiService] 成功获取情感表情包: {selected_emoji.description}")
return emoji_base64, selected_emoji.description, emotion
except Exception as e:
logger.error(f"[EmojiService] 根据情感获取表情包失败: {e}")
return None
# =============================================================================
# 表情包信息查询函数
# =============================================================================
def get_count() -> int:
try:
return len(emoji_manager.emojis)
except Exception as e:
logger.error(f"[EmojiService] 获取表情包数量失败: {e}")
return 0
def get_info():
try:
return {
"current_count": len(emoji_manager.emojis),
"max_count": global_config.emoji.max_reg_num,
"available_emojis": len([e for e in emoji_manager.emojis if not e.is_deleted]),
}
except Exception as e:
logger.error(f"[EmojiService] 获取表情包信息失败: {e}")
return {"current_count": 0, "max_count": 0, "available_emojis": 0}
def get_emotions() -> List[str]:
try:
emotions = set()
for emoji_obj in emoji_manager.emojis:
if not emoji_obj.is_deleted and emoji_obj.emotion:
emotions.update(emoji_obj.emotion)
return sorted(list(emotions))
except Exception as e:
logger.error(f"[EmojiService] 获取情感标签失败: {e}")
return []
async def get_all() -> List[Tuple[str, str, str]]:
try:
all_emojis = emoji_manager.emojis
if not all_emojis:
logger.warning("[EmojiService] 没有可用的表情包")
return []
results = []
for emoji_obj in all_emojis:
if emoji_obj.is_deleted:
continue
emoji_base64 = ImageUtils.image_path_to_base64(str(emoji_obj.full_path))
if not emoji_base64:
logger.error(f"[EmojiService] 无法转换表情包为base64: {emoji_obj.full_path}")
continue
matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else "随机表情"
results.append((emoji_base64, emoji_obj.description, matched_emotion))
logger.debug(f"[EmojiService] 成功获取 {len(results)} 个表情包")
return results
except Exception as e:
logger.error(f"[EmojiService] 获取所有表情包失败: {e}")
return []
def get_descriptions() -> List[str]:
try:
descriptions = []
descriptions.extend(
emoji_obj.description
for emoji_obj in emoji_manager.emojis
if not emoji_obj.is_deleted and emoji_obj.description
)
return descriptions
except Exception as e:
logger.error(f"[EmojiService] 获取表情包描述失败: {e}")
return []
# =============================================================================
# 表情包注册函数
# =============================================================================
async def register_emoji(image_base64: str, filename: Optional[str] = None) -> Dict[str, Any]:
"""注册新的表情包"""
if not image_base64:
raise ValueError("图片base64编码不能为空")
if not isinstance(image_base64, str):
raise TypeError("image_base64必须是字符串类型")
if filename is not None and not isinstance(filename, str):
raise TypeError("filename必须是字符串类型或None")
try:
logger.info(f"[EmojiService] 开始注册表情包,文件名: {filename or '自动生成'}")
count_before = len(emoji_manager.emojis)
max_count = global_config.emoji.max_reg_num
can_register = count_before < max_count or (count_before >= max_count and global_config.emoji.do_replace)
if not can_register:
return {
"success": False,
"message": f"表情包数量已达上限({count_before}/{max_count})且未启用替换功能",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}
os.makedirs(EMOJI_DIR, exist_ok=True)
if not filename:
import time as _time
timestamp = int(_time.time())
microseconds = int(_time.time() * 1000000) % 1000000
random_bytes = random.getrandbits(72).to_bytes(9, "big")
short_id = base64.b64encode(random_bytes).decode("ascii")[:12].rstrip("=")
short_id = short_id.replace("/", "_").replace("+", "-")
filename = f"emoji_{timestamp}_{microseconds}_{short_id}"
if not filename.lower().endswith((".jpg", ".jpeg", ".png", ".gif")):
filename = f"{filename}.png"
temp_file_path = os.path.join(EMOJI_DIR, filename)
attempts = 0
max_attempts = 10
while os.path.exists(temp_file_path) and attempts < max_attempts:
random_bytes = random.getrandbits(48).to_bytes(6, "big")
short_id = base64.b64encode(random_bytes).decode("ascii")[:8].rstrip("=")
short_id = short_id.replace("/", "_").replace("+", "-")
name_part, ext = os.path.splitext(filename)
base_name = name_part.rsplit("_", 1)[0]
filename = f"{base_name}_{short_id}{ext}"
temp_file_path = os.path.join(EMOJI_DIR, filename)
attempts += 1
if os.path.exists(temp_file_path):
uuid_short = str(uuid.uuid4())[:8]
name_part, ext = os.path.splitext(filename)
base_name = name_part.rsplit("_", 1)[0]
filename = f"{base_name}_{uuid_short}{ext}"
temp_file_path = os.path.join(EMOJI_DIR, filename)
counter = 1
original_filename = filename
while os.path.exists(temp_file_path):
name_part, ext = os.path.splitext(original_filename)
filename = f"{name_part}_{counter}{ext}"
temp_file_path = os.path.join(EMOJI_DIR, filename)
counter += 1
if counter > 100:
logger.error(f"[EmojiService] 无法生成唯一文件名,尝试次数过多: {original_filename}")
return {
"success": False,
"message": "无法生成唯一文件名,请稍后重试",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}
try:
if not ImageUtils.base64_to_image(image_base64, temp_file_path):
logger.error(f"[EmojiService] 无法保存base64图片到文件: {temp_file_path}")
return {
"success": False,
"message": "无法保存图片文件",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}
logger.debug(f"[EmojiService] 图片已保存到临时文件: {temp_file_path}")
except Exception as save_error:
logger.error(f"[EmojiService] 保存图片文件失败: {save_error}")
return {
"success": False,
"message": f"保存图片文件失败: {str(save_error)}",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}
register_success = await emoji_manager.register_emoji_by_filename(filename)
if not register_success and os.path.exists(temp_file_path):
try:
os.remove(temp_file_path)
logger.debug(f"[EmojiService] 已清理临时文件: {temp_file_path}")
except Exception as cleanup_error:
logger.warning(f"[EmojiService] 清理临时文件失败: {cleanup_error}")
if register_success:
count_after = len(emoji_manager.emojis)
replaced = count_after <= count_before
new_emoji_info = None
if count_after > count_before or replaced:
try:
for emoji_obj in reversed(emoji_manager.emojis):
if not emoji_obj.is_deleted and (
emoji_obj.file_name == filename
or (hasattr(emoji_obj, "full_path") and filename in str(emoji_obj.full_path))
):
new_emoji_info = emoji_obj
break
except Exception as find_error:
logger.warning(f"[EmojiService] 查找新注册表情包信息失败: {find_error}")
description = new_emoji_info.description if new_emoji_info else None
emotions = new_emoji_info.emotion if new_emoji_info else None
emoji_hash = new_emoji_info.emoji_hash if new_emoji_info else None
return {
"success": True,
"message": f"表情包注册成功 {'(替换旧表情包)' if replaced else '(新增表情包)'}",
"description": description,
"emotions": emotions,
"replaced": replaced,
"hash": emoji_hash,
}
else:
return {
"success": False,
"message": "表情包注册失败,可能因为重复、格式不支持或审核未通过",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}
except Exception as e:
logger.error(f"[EmojiService] 注册表情包时发生异常: {e}")
return {
"success": False,
"message": f"注册过程中发生错误: {str(e)}",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}

View File

@@ -1,9 +1,11 @@
from src.common.logger import get_logger """频率控制服务模块
提供聊天频率控制的核心功能
"""
from src.chat.heart_flow.frequency_control import frequency_control_manager from src.chat.heart_flow.frequency_control import frequency_control_manager
from src.config.config import global_config from src.config.config import global_config
logger = get_logger("frequency_api")
def get_current_talk_value(chat_id: str) -> float: def get_current_talk_value(chat_id: str) -> float:
return frequency_control_manager.get_or_create_frequency_control( return frequency_control_manager.get_or_create_frequency_control(

View File

@@ -1,39 +1,37 @@
""" """
回复器API模块 回复器服务模块
提供回复器相关功能采用标准Python包设计模式 提供回复器相关的核心功能
使用方式
from src.plugin_system.apis import generator_api
replyer = generator_api.get_replyer(chat_stream)
success, reply_set, _ = await generator_api.generate_reply(chat_stream, action_data, reasoning)
""" """
import traceback import traceback
import time import time
from typing import Tuple, Any, Dict, List, Optional, TYPE_CHECKING from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from rich.traceback import install from rich.traceback import install
from src.common.logger import get_logger
from src.common.data_models.message_data_model import ReplySetModel from src.chat.logger.plan_reply_logger import PlanReplyLogger
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.replyer.group_generator import DefaultReplyer from src.chat.replyer.group_generator import DefaultReplyer
from src.chat.replyer.private_generator import PrivateReplyer from src.chat.replyer.private_generator import PrivateReplyer
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.utils.utils import process_llm_response
from src.chat.replyer.replyer_manager import replyer_manager from src.chat.replyer.replyer_manager import replyer_manager
from src.plugin_system.base.component_types import ActionInfo from src.chat.utils.utils import process_llm_response
from src.chat.logger.plan_reply_logger import PlanReplyLogger from src.common.data_models.message_data_model import ReplySetModel
from src.common.logger import get_logger
from src.core.types import ActionInfo
if TYPE_CHECKING: if TYPE_CHECKING:
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.data_models.llm_data_model import LLMGenerationDataModel from src.common.data_models.llm_data_model import LLMGenerationDataModel
install(extra_lines=3) install(extra_lines=3)
logger = get_logger("generator_api") logger = get_logger("generator_service")
# ============================================================================= # =============================================================================
# 回复器获取API函数 # 回复器获取函数
# ============================================================================= # =============================================================================
@@ -42,39 +40,24 @@ def get_replyer(
chat_id: Optional[str] = None, chat_id: Optional[str] = None,
request_type: str = "replyer", request_type: str = "replyer",
) -> Optional[DefaultReplyer | PrivateReplyer]: ) -> Optional[DefaultReplyer | PrivateReplyer]:
"""获取回复器对象 """获取回复器对象"""
优先使用chat_stream如果没有则使用chat_id直接查找
使用 ReplyerManager 来管理实例避免重复创建
Args:
chat_stream: 聊天流对象优先
chat_id: 聊天ID实际上就是stream_id
request_type: 请求类型
Returns:
Optional[DefaultReplyer]: 回复器对象如果获取失败则返回None
Raises:
ValueError: chat_stream chat_id 均为空
"""
if not chat_id and not chat_stream: if not chat_id and not chat_stream:
raise ValueError("chat_stream 和 chat_id 不可均为空") raise ValueError("chat_stream 和 chat_id 不可均为空")
try: try:
logger.debug(f"[GeneratorAPI] 正在获取回复器chat_id: {chat_id}, chat_stream: {'' if chat_stream else ''}") logger.debug(f"[GeneratorService] 正在获取回复器chat_id: {chat_id}, chat_stream: {'' if chat_stream else ''}")
return replyer_manager.get_replyer( return replyer_manager.get_replyer(
chat_stream=chat_stream, chat_stream=chat_stream,
chat_id=chat_id, chat_id=chat_id,
request_type=request_type, request_type=request_type,
) )
except Exception as e: except Exception as e:
logger.error(f"[GeneratorAPI] 获取回复器时发生意外错误: {e}", exc_info=True) logger.error(f"[GeneratorService] 获取回复器时发生意外错误: {e}", exc_info=True)
traceback.print_exc() traceback.print_exc()
return None return None
# ============================================================================= # =============================================================================
# 回复生成API函数 # 回复生成函数
# ============================================================================= # =============================================================================
@@ -96,39 +79,15 @@ async def generate_reply(
from_plugin: bool = True, from_plugin: bool = True,
reply_time_point: Optional[float] = None, reply_time_point: Optional[float] = None,
) -> Tuple[bool, Optional["LLMGenerationDataModel"]]: ) -> Tuple[bool, Optional["LLMGenerationDataModel"]]:
"""生成回复 """生成回复"""
Args:
chat_stream: 聊天流对象优先
chat_id: 聊天ID备用
action_data: 动作数据向下兼容包含reply_to和extra_info
reply_message: 回复的消息对象
extra_info: 额外信息用于补充上下文
reply_reason: 回复原因
available_actions: 可用动作
chosen_actions: 已选动作
unknown_words: Planner reply 动作中给出的未知词语列表用于黑话检索
enable_tool: 是否启用工具调用
enable_splitter: 是否启用消息分割器
enable_chinese_typo: 是否启用错字生成器
return_prompt: 是否返回提示词
model_set_with_weight: 模型配置列表每个元素为 (TaskConfig, weight) 元组
request_type: 请求类型可选记录LLM使用
from_plugin: 是否来自插件
reply_time_point: 回复时间点
Returns:
Tuple[bool, List[Tuple[str, Any]], Optional[str]]: (是否成功, 回复集合, 提示词)
"""
try: try:
# 如果 reply_time_point 未传入,设置为当前时间戳
if reply_time_point is None: if reply_time_point is None:
reply_time_point = time.time() reply_time_point = time.time()
# 获取回复器 logger.debug("[GeneratorService] 开始生成回复")
logger.debug("[GeneratorAPI] 开始生成回复")
replyer = get_replyer(chat_stream, chat_id, request_type=request_type) replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
if not replyer: if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器") logger.error("[GeneratorService] 无法获取回复器")
return False, None return False, None
if action_data: if action_data:
@@ -136,11 +95,9 @@ async def generate_reply(
extra_info = action_data.get("extra_info", "") extra_info = action_data.get("extra_info", "")
if not reply_reason: if not reply_reason:
reply_reason = action_data.get("reason", "") reply_reason = action_data.get("reason", "")
# 仅在 reply 场景下使用的未知词语解析Planner JSON 中下发)
if unknown_words is None: if unknown_words is None:
uw = action_data.get("unknown_words") uw = action_data.get("unknown_words")
if isinstance(uw, list): if isinstance(uw, list):
# 只保留非空字符串
cleaned: List[str] = [] cleaned: List[str] = []
for item in uw: for item in uw:
if isinstance(item, str): if isinstance(item, str):
@@ -150,7 +107,6 @@ async def generate_reply(
if cleaned: if cleaned:
unknown_words = cleaned unknown_words = cleaned
# 调用回复器生成回复
success, llm_response = await replyer.generate_reply_with_context( success, llm_response = await replyer.generate_reply_with_context(
extra_info=extra_info, extra_info=extra_info,
available_actions=available_actions, available_actions=available_actions,
@@ -166,7 +122,7 @@ async def generate_reply(
log_reply=False, log_reply=False,
) )
if not success: if not success:
logger.warning("[GeneratorAPI] 回复生成失败") logger.warning("[GeneratorService] 回复生成失败")
return False, None return False, None
reply_set: Optional[ReplySetModel] = None reply_set: Optional[ReplySetModel] = None
if content := llm_response.content: if content := llm_response.content:
@@ -176,9 +132,8 @@ async def generate_reply(
for text in processed_response: for text in processed_response:
reply_set.add_text_content(text) reply_set.add_text_content(text)
llm_response.reply_set = reply_set llm_response.reply_set = reply_set
logger.debug(f"[GeneratorAPI] 回复生成成功,生成了 {len(reply_set) if reply_set else 0} 个回复项") logger.debug(f"[GeneratorService] 回复生成成功,生成了 {len(reply_set) if reply_set else 0} 个回复项")
# 统一在这里记录最终回复日志(包含分割后的 processed_output
try: try:
PlanReplyLogger.log_reply( PlanReplyLogger.log_reply(
chat_id=chat_stream.session_id if chat_stream else (chat_id or ""), chat_id=chat_stream.session_id if chat_stream else (chat_id or ""),
@@ -192,7 +147,7 @@ async def generate_reply(
success=True, success=True,
) )
except Exception: except Exception:
logger.exception("[GeneratorAPI] 记录reply日志失败") logger.exception("[GeneratorService] 记录reply日志失败")
return success, llm_response return success, llm_response
@@ -200,11 +155,11 @@ async def generate_reply(
raise ve raise ve
except UserWarning as uw: except UserWarning as uw:
logger.warning(f"[GeneratorAPI] 中断了生成: {uw}") logger.warning(f"[GeneratorService] 中断了生成: {uw}")
return False, None return False, None
except Exception as e: except Exception as e:
logger.error(f"[GeneratorAPI] 生成回复时出错: {e}") logger.error(f"[GeneratorService] 生成回复时出错: {e}")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return False, None return False, None
@@ -220,39 +175,20 @@ async def rewrite_reply(
reply_to: str = "", reply_to: str = "",
request_type: str = "generator_api", request_type: str = "generator_api",
) -> Tuple[bool, Optional["LLMGenerationDataModel"]]: ) -> Tuple[bool, Optional["LLMGenerationDataModel"]]:
"""重写回复 """重写回复"""
Args:
chat_stream: 聊天流对象优先
reply_data: 回复数据字典向下兼容备用当其他参数缺失时从此获取
chat_id: 聊天ID备用
enable_splitter: 是否启用消息分割器
enable_chinese_typo: 是否启用错字生成器
model_set_with_weight: 模型配置列表每个元素为 (TaskConfig, weight) 元组
raw_reply: 原始回复内容
reason: 回复原因
reply_to: 回复对象
return_prompt: 是否返回提示词
Returns:
Tuple[bool, List[Tuple[str, Any]]]: (是否成功, 回复集合)
"""
try: try:
# 获取回复器
replyer = get_replyer(chat_stream, chat_id, request_type=request_type) replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
if not replyer: if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器") logger.error("[GeneratorService] 无法获取回复器")
return False, None return False, None
logger.info("[GeneratorAPI] 开始重写回复") logger.info("[GeneratorService] 开始重写回复")
# 如果参数缺失从reply_data中获取
if reply_data: if reply_data:
raw_reply = raw_reply or reply_data.get("raw_reply", "") raw_reply = raw_reply or reply_data.get("raw_reply", "")
reason = reason or reply_data.get("reason", "") reason = reason or reply_data.get("reason", "")
reply_to = reply_to or reply_data.get("reply_to", "") reply_to = reply_to or reply_data.get("reply_to", "")
# 调用回复器重写回复
success, llm_response = await replyer.rewrite_reply_with_context( success, llm_response = await replyer.rewrite_reply_with_context(
raw_reply=raw_reply, raw_reply=raw_reply,
reason=reason, reason=reason,
@@ -263,9 +199,9 @@ async def rewrite_reply(
reply_set = process_human_text(content, enable_splitter, enable_chinese_typo) reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
llm_response.reply_set = reply_set llm_response.reply_set = reply_set
if success: if success:
logger.info(f"[GeneratorAPI] 重写回复成功,生成了 {len(reply_set) if reply_set else 0} 个回复项") logger.info(f"[GeneratorService] 重写回复成功,生成了 {len(reply_set) if reply_set else 0} 个回复项")
else: else:
logger.warning("[GeneratorAPI] 重写回复失败") logger.warning("[GeneratorService] 重写回复失败")
return success, llm_response return success, llm_response
@@ -273,18 +209,12 @@ async def rewrite_reply(
raise ve raise ve
except Exception as e: except Exception as e:
logger.error(f"[GeneratorAPI] 重写回复时出错: {e}") logger.error(f"[GeneratorService] 重写回复时出错: {e}")
return False, None return False, None
def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> Optional[ReplySetModel]: def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> Optional[ReplySetModel]:
"""将文本处理为更拟人化的文本 """将文本处理为更拟人化的文本"""
Args:
content: 文本内容
enable_splitter: 是否启用消息分割器
enable_chinese_typo: 是否启用错字生成器
"""
if not isinstance(content, str): if not isinstance(content, str):
raise ValueError("content 必须是字符串类型") raise ValueError("content 必须是字符串类型")
try: try:
@@ -297,7 +227,7 @@ def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo:
return reply_set return reply_set
except Exception as e: except Exception as e:
logger.error(f"[GeneratorAPI] 处理人形文本时出错: {e}") logger.error(f"[GeneratorService] 处理人形文本时出错: {e}")
return None return None
@@ -309,18 +239,18 @@ async def generate_response_custom(
) -> Optional[str]: ) -> Optional[str]:
replyer = get_replyer(chat_stream, chat_id, request_type=request_type) replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
if not replyer: if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器") logger.error("[GeneratorService] 无法获取回复器")
return None return None
try: try:
logger.debug("[GeneratorAPI] 开始生成自定义回复") logger.debug("[GeneratorService] 开始生成自定义回复")
response, _, _, _ = await replyer.llm_generate_content(prompt) response, _, _, _ = await replyer.llm_generate_content(prompt)
if response: if response:
logger.debug("[GeneratorAPI] 自定义回复生成成功") logger.debug("[GeneratorService] 自定义回复生成成功")
return response return response
else: else:
logger.warning("[GeneratorAPI] 自定义回复生成失败") logger.warning("[GeneratorService] 自定义回复生成失败")
return None return None
except Exception as e: except Exception as e:
logger.error(f"[GeneratorAPI] 生成自定义回复时出错: {e}") logger.error(f"[GeneratorService] 生成自定义回复时出错: {e}")
return None return None

View File

@@ -1,26 +1,19 @@
"""LLM API模块 """LLM 服务模块
提供了与LLM模型交互的功能 提供 LLM 模型交互的核心功能
使用方式
from src.plugin_system.apis import llm_api
models = llm_api.get_available_models()
success, response, reasoning, model_name = await llm_api.generate_with_model(prompt, model_config)
""" """
from typing import Tuple, Dict, List, Any, Optional, Callable from typing import Any, Callable, Dict, List, Optional, Tuple
from src.common.logger import get_logger from src.common.logger import get_logger
from src.llm_models.payload_content.tool_option import ToolCall
from src.llm_models.payload_content.message import Message
from src.llm_models.model_client.base_client import BaseClient
from src.llm_models.utils_model import LLMRequest
from src.config.config import config_manager from src.config.config import config_manager
from src.config.model_configs import TaskConfig from src.config.model_configs import TaskConfig
from src.llm_models.model_client.base_client import BaseClient
from src.llm_models.payload_content.message import Message
from src.llm_models.payload_content.tool_option import ToolCall
from src.llm_models.utils_model import LLMRequest
logger = get_logger("llm_api") logger = get_logger("llm_service")
# =============================================================================
# LLM模型API函数
# =============================================================================
def get_available_models() -> Dict[str, TaskConfig]: def get_available_models() -> Dict[str, TaskConfig]:
@@ -30,7 +23,6 @@ def get_available_models() -> Dict[str, TaskConfig]:
Dict[str, Any]: 模型配置字典key为模型名称value为模型配置 Dict[str, Any]: 模型配置字典key为模型名称value为模型配置
""" """
try: try:
# 自动获取所有属性并转换为字典形式
models = config_manager.get_model_config().model_task_config models = config_manager.get_model_config().model_task_config
attrs = dir(models) attrs = dir(models)
rets: Dict[str, TaskConfig] = {} rets: Dict[str, TaskConfig] = {}
@@ -41,12 +33,12 @@ def get_available_models() -> Dict[str, TaskConfig]:
if not callable(value) and isinstance(value, TaskConfig): if not callable(value) and isinstance(value, TaskConfig):
rets[attr] = value rets[attr] = value
except Exception as e: except Exception as e:
logger.debug(f"[LLMAPI] 获取属性 {attr} 失败: {e}") logger.debug(f"[LLMService] 获取属性 {attr} 失败: {e}")
continue continue
return rets return rets
except Exception as e: except Exception as e:
logger.error(f"[LLMAPI] 获取可用模型失败: {e}") logger.error(f"[LLMService] 获取可用模型失败: {e}")
return {} return {}
@@ -68,9 +60,7 @@ async def generate_with_model(
Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称) Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称)
""" """
try: try:
# model_name_list = model_config.model_list logger.debug(f"[LLMService] 完整提示词: {prompt}")
# logger.info(f"[LLMAPI] 使用模型集合 {model_name_list} 生成内容")
logger.debug(f"[LLMAPI] 完整提示词: {prompt}")
llm_request = LLMRequest(model_set=model_config, request_type=request_type) llm_request = LLMRequest(model_set=model_config, request_type=request_type)
@@ -81,7 +71,7 @@ async def generate_with_model(
except Exception as e: except Exception as e:
error_msg = f"生成内容时出错: {str(e)}" error_msg = f"生成内容时出错: {str(e)}"
logger.error(f"[LLMAPI] {error_msg}") logger.error(f"[LLMService] {error_msg}")
return False, error_msg, "", "" return False, error_msg, "", ""
@@ -104,7 +94,7 @@ async def generate_with_model_with_tools(
max_tokens: 最大token数 max_tokens: 最大token数
Returns: Returns:
Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称) Tuple[bool, str, str, str, List[ToolCall] | None]: (是否成功, 生成的内容, 推理过程, 模型名称, 工具调用列表)
""" """
try: try:
model_name_list = model_config.model_list model_name_list = model_config.model_list
@@ -120,7 +110,7 @@ async def generate_with_model_with_tools(
except Exception as e: except Exception as e:
error_msg = f"生成内容时出错: {str(e)}" error_msg = f"生成内容时出错: {str(e)}"
logger.error(f"[LLMAPI] {error_msg}") logger.error(f"[LLMService] {error_msg}")
return False, error_msg, "", "", None return False, error_msg, "", "", None
@@ -161,5 +151,5 @@ async def generate_with_model_with_tools_by_message_factory(
except Exception as e: except Exception as e:
error_msg = f"生成内容时出错: {str(e)}" error_msg = f"生成内容时出错: {str(e)}"
logger.error(f"[LLMAPI] {error_msg}") logger.error(f"[LLMService] {error_msg}")
return False, error_msg, "", "", None return False, error_msg, "", "", None

View File

@@ -1,62 +1,44 @@
""" """
消息API模块 消息服务模块
提供消息查询和构建成字符串的功能采用标准Python包设计模式 提供消息查询和构建成字符串的核心功能
使用方式
from src.plugin_system.apis import message_api
messages = message_api.get_messages_by_time_in_chat(chat_id, start_time, end_time)
readable_text = message_api.build_readable_messages(messages)
""" """
import time import time
from typing import List, Dict, Any, Tuple, Optional from typing import Any, Dict, List, Optional, Tuple
from sqlmodel import col, select from sqlmodel import col, select
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.database.database import get_db_session
from src.common.database.database_model import Images, ImageType
from src.chat.utils.utils import is_bot_self
from src.chat.utils.chat_message_builder import ( from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp,
get_raw_msg_by_timestamp_with_chat,
get_raw_msg_by_timestamp_with_chat_inclusive,
get_raw_msg_by_timestamp_with_chat_users,
get_raw_msg_by_timestamp_random,
get_raw_msg_by_timestamp_with_users,
get_raw_msg_before_timestamp,
get_raw_msg_before_timestamp_with_chat,
get_raw_msg_before_timestamp_with_users,
num_new_messages_since,
num_new_messages_since_with_users,
build_readable_messages, build_readable_messages,
build_readable_messages_with_list, build_readable_messages_with_list,
get_person_id_list, get_person_id_list,
get_raw_msg_before_timestamp,
get_raw_msg_before_timestamp_with_chat,
get_raw_msg_before_timestamp_with_users,
get_raw_msg_by_timestamp,
get_raw_msg_by_timestamp_random,
get_raw_msg_by_timestamp_with_chat,
get_raw_msg_by_timestamp_with_chat_inclusive,
get_raw_msg_by_timestamp_with_chat_users,
get_raw_msg_by_timestamp_with_users,
num_new_messages_since,
num_new_messages_since_with_users,
) )
from src.chat.utils.utils import is_bot_self
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.database.database import get_db_session
from src.common.database.database_model import Images, ImageType
# ============================================================================= # =============================================================================
# 消息查询API函数 # 消息查询函数
# ============================================================================= # =============================================================================
def get_messages_by_time( def get_messages_by_time(
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
) -> List[DatabaseMessages]: ) -> List[DatabaseMessages]:
"""
获取指定时间范围内的消息
Args:
start_time: 开始时间戳
end_time: 结束时间戳
limit: 限制返回的消息数量0为不限制
limit_mode: 当limit>0时生效'earliest'表示获取最早的记录'latest'表示获取最新的记录
filter_mai: 是否过滤麦麦自身的消息默认为False
Returns:
List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
"""
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型") raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0: if limit < 0:
@@ -76,23 +58,6 @@ def get_messages_by_time_in_chat(
filter_command: bool = False, filter_command: bool = False,
filter_intercept_message_level: Optional[int] = None, filter_intercept_message_level: Optional[int] = None,
) -> List[DatabaseMessages]: ) -> List[DatabaseMessages]:
"""
获取指定聊天中指定时间范围内的消息
Args:
chat_id: 聊天ID
start_time: 开始时间戳
end_time: 结束时间戳
limit: 限制返回的消息数量0为不限制
limit_mode: 当limit>0时生效'earliest'表示获取最早的记录'latest'表示获取最新的记录
filter_mai: 是否过滤麦麦自身的消息默认为False
filter_command: 是否过滤命令消息默认为False
Returns:
List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
"""
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型") raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0: if limit < 0:
@@ -101,10 +66,6 @@ def get_messages_by_time_in_chat(
raise ValueError("chat_id 不能为空") raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str): if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型") raise ValueError("chat_id 必须是字符串类型")
# if filter_mai:
# return filter_mai_messages(
# get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command)
# )
return get_raw_msg_by_timestamp_with_chat( return get_raw_msg_by_timestamp_with_chat(
chat_id=chat_id, chat_id=chat_id,
timestamp_start=start_time, timestamp_start=start_time,
@@ -127,23 +88,6 @@ def get_messages_by_time_in_chat_inclusive(
filter_command: bool = False, filter_command: bool = False,
filter_intercept_message_level: Optional[int] = None, filter_intercept_message_level: Optional[int] = None,
) -> List[DatabaseMessages]: ) -> List[DatabaseMessages]:
"""
获取指定聊天中指定时间范围内的消息包含边界
Args:
chat_id: 聊天ID
start_time: 开始时间戳包含
end_time: 结束时间戳包含
limit: 限制返回的消息数量0为不限制
limit_mode: 当limit>0时生效'earliest'表示获取最早的记录'latest'表示获取最新的记录
filter_mai: 是否过滤麦麦自身的消息默认为False
Returns:
List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
"""
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型") raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0: if limit < 0:
@@ -175,23 +119,6 @@ def get_messages_by_time_in_chat_for_users(
limit: int = 0, limit: int = 0,
limit_mode: str = "latest", limit_mode: str = "latest",
) -> List[DatabaseMessages]: ) -> List[DatabaseMessages]:
"""
获取指定聊天中指定用户在指定时间范围内的消息
Args:
chat_id: 聊天ID
start_time: 开始时间戳
end_time: 结束时间戳
person_ids: 用户ID列表
limit: 限制返回的消息数量0为不限制
limit_mode: 当limit>0时生效'earliest'表示获取最早的记录'latest'表示获取最新的记录
Returns:
List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
"""
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型") raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0: if limit < 0:
@@ -206,22 +133,6 @@ def get_messages_by_time_in_chat_for_users(
def get_random_chat_messages( def get_random_chat_messages(
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
) -> List[DatabaseMessages]: ) -> List[DatabaseMessages]:
"""
随机选择一个聊天返回该聊天在指定时间范围内的消息
Args:
start_time: 开始时间戳
end_time: 结束时间戳
limit: 限制返回的消息数量0为不限制
limit_mode: 当limit>0时生效'earliest'表示获取最早的记录'latest'表示获取最新的记录
filter_mai: 是否过滤麦麦自身的消息默认为False
Returns:
List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
"""
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型") raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0: if limit < 0:
@@ -234,22 +145,6 @@ def get_random_chat_messages(
def get_messages_by_time_for_users( def get_messages_by_time_for_users(
start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest" start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest"
) -> List[DatabaseMessages]: ) -> List[DatabaseMessages]:
"""
获取指定用户在所有聊天中指定时间范围内的消息
Args:
start_time: 开始时间戳
end_time: 结束时间戳
person_ids: 用户ID列表
limit: 限制返回的消息数量0为不限制
limit_mode: 当limit>0时生效'earliest'表示获取最早的记录'latest'表示获取最新的记录
Returns:
List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
"""
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型") raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0: if limit < 0:
@@ -258,20 +153,6 @@ def get_messages_by_time_for_users(
def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[DatabaseMessages]: def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[DatabaseMessages]:
"""
获取指定时间戳之前的消息
Args:
timestamp: 时间戳
limit: 限制返回的消息数量0为不限制
filter_mai: 是否过滤麦麦自身的消息默认为False
Returns:
List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
"""
if not isinstance(timestamp, (int, float)): if not isinstance(timestamp, (int, float)):
raise ValueError("timestamp 必须是数字类型") raise ValueError("timestamp 必须是数字类型")
if limit < 0: if limit < 0:
@@ -288,21 +169,6 @@ def get_messages_before_time_in_chat(
filter_mai: bool = False, filter_mai: bool = False,
filter_intercept_message_level: Optional[int] = None, filter_intercept_message_level: Optional[int] = None,
) -> List[DatabaseMessages]: ) -> List[DatabaseMessages]:
"""
获取指定聊天中指定时间戳之前的消息
Args:
chat_id: 聊天ID
timestamp: 时间戳
limit: 限制返回的消息数量0为不限制
filter_mai: 是否过滤麦麦自身的消息默认为False
Returns:
List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
"""
if not isinstance(timestamp, (int, float)): if not isinstance(timestamp, (int, float)):
raise ValueError("timestamp 必须是数字类型") raise ValueError("timestamp 必须是数字类型")
if limit < 0: if limit < 0:
@@ -325,20 +191,6 @@ def get_messages_before_time_in_chat(
def get_messages_before_time_for_users( def get_messages_before_time_for_users(
timestamp: float, person_ids: List[str], limit: int = 0 timestamp: float, person_ids: List[str], limit: int = 0
) -> List[DatabaseMessages]: ) -> List[DatabaseMessages]:
"""
获取指定用户在指定时间戳之前的消息
Args:
timestamp: 时间戳
person_ids: 用户ID列表
limit: 限制返回的消息数量0为不限制
Returns:
List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
"""
if not isinstance(timestamp, (int, float)): if not isinstance(timestamp, (int, float)):
raise ValueError("timestamp 必须是数字类型") raise ValueError("timestamp 必须是数字类型")
if limit < 0: if limit < 0:
@@ -349,22 +201,6 @@ def get_messages_before_time_for_users(
def get_recent_messages( def get_recent_messages(
chat_id: str, hours: float = 24.0, limit: int = 100, limit_mode: str = "latest", filter_mai: bool = False chat_id: str, hours: float = 24.0, limit: int = 100, limit_mode: str = "latest", filter_mai: bool = False
) -> List[DatabaseMessages]: ) -> List[DatabaseMessages]:
"""
获取指定聊天中最近一段时间的消息
Args:
chat_id: 聊天ID
hours: 最近多少小时默认24小时
limit: 限制返回的消息数量默认100条
limit_mode: 当limit>0时生效'earliest'表示获取最早的记录'latest'表示获取最新的记录
filter_mai: 是否过滤麦麦自身的消息默认为False
Returns:
List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法s
"""
if not isinstance(hours, (int, float)) or hours < 0: if not isinstance(hours, (int, float)) or hours < 0:
raise ValueError("hours 不能是负数") raise ValueError("hours 不能是负数")
if not isinstance(limit, int) or limit < 0: if not isinstance(limit, int) or limit < 0:
@@ -381,25 +217,11 @@ def get_recent_messages(
# ============================================================================= # =============================================================================
# 消息计数API函数 # 消息计数函数
# ============================================================================= # =============================================================================
def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional[float] = None) -> int: def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional[float] = None) -> int:
"""
计算指定聊天中从开始时间到结束时间的新消息数量
Args:
chat_id: 聊天ID
start_time: 开始时间戳
end_time: 结束时间戳如果为None则使用当前时间
Returns:
int: 新消息数量
Raises:
ValueError: 如果参数不合法
"""
if not isinstance(start_time, (int, float)): if not isinstance(start_time, (int, float)):
raise ValueError("start_time 必须是数字类型") raise ValueError("start_time 必须是数字类型")
if not chat_id: if not chat_id:
@@ -410,21 +232,6 @@ def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional
def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> int: def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> int:
"""
计算指定聊天中指定用户从开始时间到结束时间的新消息数量
Args:
chat_id: 聊天ID
start_time: 开始时间戳
end_time: 结束时间戳
person_ids: 用户ID列表
Returns:
int: 新消息数量
Raises:
ValueError: 如果参数不合法
"""
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型") raise ValueError("start_time 和 end_time 必须是数字类型")
if not chat_id: if not chat_id:
@@ -435,7 +242,7 @@ def count_new_messages_for_users(chat_id: str, start_time: float, end_time: floa
# ============================================================================= # =============================================================================
# 消息格式化API函数 # 消息格式化函数
# ============================================================================= # =============================================================================
@@ -447,21 +254,6 @@ def build_readable_messages_to_str(
truncate: bool = False, truncate: bool = False,
show_actions: bool = False, show_actions: bool = False,
) -> str: ) -> str:
"""
将消息列表构建成可读的字符串
Args:
messages: 消息列表
replace_bot_name: 是否将机器人的名称替换为""
merge_messages: 是否合并连续消息
timestamp_mode: 时间戳显示模式'relative''absolute'
read_mark: 已读标记时间戳用于分割已读和未读消息
truncate: 是否截断长消息
show_actions: 是否显示动作记录
Returns:
格式化后的可读字符串
"""
return build_readable_messages(messages, replace_bot_name, timestamp_mode, read_mark, truncate, show_actions) return build_readable_messages(messages, replace_bot_name, timestamp_mode, read_mark, truncate, show_actions)
@@ -471,32 +263,10 @@ async def build_readable_messages_with_details(
timestamp_mode: str = "relative", timestamp_mode: str = "relative",
truncate: bool = False, truncate: bool = False,
) -> Tuple[str, List[Tuple[float, str, str]]]: ) -> Tuple[str, List[Tuple[float, str, str]]]:
"""
将消息列表构建成可读的字符串并返回详细信息
Args:
messages: 消息列表
replace_bot_name: 是否将机器人的名称替换为""
merge_messages: 是否合并连续消息
timestamp_mode: 时间戳显示模式'relative''absolute'
truncate: 是否截断长消息
Returns:
格式化后的可读字符串和详细信息元组列表(时间戳, 昵称, 内容)
"""
return await build_readable_messages_with_list(messages, replace_bot_name, timestamp_mode, truncate) return await build_readable_messages_with_list(messages, replace_bot_name, timestamp_mode, truncate)
async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[str]: async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[str]:
"""
从消息列表中提取不重复的用户ID列表
Args:
messages: 消息列表
Returns:
用户ID列表
"""
return await get_person_id_list(messages) return await get_person_id_list(messages)
@@ -506,14 +276,7 @@ async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[s
def filter_mai_messages(messages: List[DatabaseMessages]) -> List[DatabaseMessages]: def filter_mai_messages(messages: List[DatabaseMessages]) -> List[DatabaseMessages]:
""" """从消息列表中移除麦麦的消息"""
从消息列表中移除麦麦的消息
Args:
messages: 消息列表每个元素是消息字典
Returns:
过滤后的消息列表
"""
# 使用统一的 is_bot_self 函数判断是否是机器人自己(支持多平台,包括 WebUI
return [msg for msg in messages if not is_bot_self(msg.user_info.platform, msg.user_info.user_id)] return [msg for msg in messages if not is_bot_self(msg.user_info.platform, msg.user_info.user_id)]

View File

@@ -1,22 +1,14 @@
"""个人信息API模块 """个人信息服务模块
提供个人信息查询功能用于插件获取用户相关信息 提供个人信息查询的核心功能
使用方式
from src.plugin_system.apis import person_api
person_id = person_api.get_person_id("qq", 123456)
value = await person_api.get_person_value(person_id, "nickname")
""" """
from typing import Any from typing import Any
from src.common.logger import get_logger from src.common.logger import get_logger
from src.person_info.person_info import Person from src.person_info.person_info import Person
logger = get_logger("person_api") logger = get_logger("person_service")
# =============================================================================
# 个人信息API函数
# =============================================================================
def get_person_id(platform: str, user_id: int | str) -> str: def get_person_id(platform: str, user_id: int | str) -> str:
@@ -28,14 +20,11 @@ def get_person_id(platform: str, user_id: int | str) -> str:
Returns: Returns:
str: 唯一的person_idMD5哈希值 str: 唯一的person_idMD5哈希值
示例:
person_id = person_api.get_person_id("qq", 123456)
""" """
try: try:
return Person(platform=platform, user_id=str(user_id)).person_id return Person(platform=platform, user_id=str(user_id)).person_id
except Exception as e: except Exception as e:
logger.error(f"[PersonAPI] 获取person_id失败: platform={platform}, user_id={user_id}, error={e}") logger.error(f"[PersonService] 获取person_id失败: platform={platform}, user_id={user_id}, error={e}")
return "" return ""
@@ -49,17 +38,13 @@ async def get_person_value(person_id: str, field_name: str, default: Any = None)
Returns: Returns:
Any: 字段值或默认值 Any: 字段值或默认值
示例:
nickname = await person_api.get_person_value(person_id, "nickname", "未知用户")
impression = await person_api.get_person_value(person_id, "impression")
""" """
try: try:
person = Person(person_id=person_id) person = Person(person_id=person_id)
value = getattr(person, field_name) value = getattr(person, field_name)
return value if value is not None else default return value if value is not None else default
except Exception as e: except Exception as e:
logger.error(f"[PersonAPI] 获取用户信息失败: person_id={person_id}, field={field_name}, error={e}") logger.error(f"[PersonService] 获取用户信息失败: person_id={person_id}, field={field_name}, error={e}")
return default return default
@@ -71,13 +56,10 @@ def get_person_id_by_name(person_name: str) -> str:
Returns: Returns:
str: person_id如果未找到返回空字符串 str: person_id如果未找到返回空字符串
示例:
person_id = person_api.get_person_id_by_name("张三")
""" """
try: try:
person = Person(person_name=person_name) person = Person(person_name=person_name)
return person.person_id return person.person_id
except Exception as e: except Exception as e:
logger.error(f"[PersonAPI] 根据用户名获取person_id失败: person_name={person_name}, error={e}") logger.error(f"[PersonService] 根据用户名获取person_id失败: person_name={person_name}, error={e}")
return "" return ""

View File

@@ -1,47 +1,32 @@
""" """
发送API模块 发送服务模块
专门负责发送各种类型消息采用标准Python包设计模式 提供发送各种类型消息的核心功能
使用方式
from src.plugin_system.apis import send_api
# 方式1直接使用stream_id推荐
await send_api.text_to_stream("hello", stream_id)
await send_api.emoji_to_stream(emoji_base64, stream_id)
await send_api.custom_to_stream("video", video_data, stream_id)
# 方式2使用群聊/私聊指定函数
await send_api.text_to_group("hello", "123456")
await send_api.text_to_user("hello", "987654")
# 方式3使用通用custom_message函数
await send_api.custom_message("video", video_data, "123456", True)
""" """
import traceback import traceback
import time import time
from typing import Optional, Union, Dict, List, TYPE_CHECKING, Tuple from typing import Optional, Union, Dict, List, TYPE_CHECKING, Tuple
from src.common.logger import get_logger from maim_message import MessageBase, BaseMessageInfo, Seg
from src.common.data_models.message_data_model import ReplyContentType
from src.config.config import global_config
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
from maim_message import Seg
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.chat.message_receive.message import MessageSending from src.chat.message_receive.message import MessageSending
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
from src.common.data_models.message_data_model import ReplyContentType
from src.common.logger import get_logger
from src.config.config import global_config
if TYPE_CHECKING: if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.message_data_model import ReplySetModel, ReplyContent, ForwardNode from src.common.data_models.message_data_model import ForwardNode, ReplyContent, ReplySetModel
logger = get_logger("send_api") logger = get_logger("send_service")
# ============================================================================= # =============================================================================
# 内部实现函数(不暴露给外部) # 内部实现函数
# ============================================================================= # =============================================================================
@@ -56,42 +41,25 @@ async def _send_to_target(
show_log: bool = True, show_log: bool = True,
selected_expressions: Optional[List[int]] = None, selected_expressions: Optional[List[int]] = None,
) -> bool: ) -> bool:
"""向指定目标发送消息的内部实现 """向指定目标发送消息的内部实现"""
Args:
message_segment:
stream_id: 目标流ID
display_message: 显示消息
typing: 是否模拟打字等待
reply_to: 回复消息格式为"发送者:消息内容"
storage_message: 是否存储消息到数据库
show_log: 发送是否显示日志
Returns:
bool: 是否发送成功
"""
try: try:
if set_reply and not reply_message: if set_reply and not reply_message:
logger.warning("[SendAPI] 使用引用回复,但未提供回复消息") logger.warning("[SendService] 使用引用回复,但未提供回复消息")
return False return False
if show_log: if show_log:
logger.debug(f"[SendAPI] 发送{message_segment.type}消息到 {stream_id}") logger.debug(f"[SendService] 发送{message_segment.type}消息到 {stream_id}")
# 查找目标聊天流
target_stream = _chat_manager.get_session_by_session_id(stream_id) target_stream = _chat_manager.get_session_by_session_id(stream_id)
if not target_stream: if not target_stream:
logger.error(f"[SendAPI] 未找到聊天流: {stream_id}") logger.error(f"[SendService] 未找到聊天流: {stream_id}")
return False return False
# 创建发送器
message_sender = UniversalMessageSender() message_sender = UniversalMessageSender()
# 生成消息ID
current_time = time.time() current_time = time.time()
message_id = f"send_api_{int(current_time * 1000)}" message_id = f"send_api_{int(current_time * 1000)}"
# 构建机器人用户信息
bot_user_info = UserInfo( bot_user_info = UserInfo(
user_id=global_config.bot.qq_account, user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname, user_nickname=global_config.bot.nickname,
@@ -102,17 +70,15 @@ async def _send_to_target(
if reply_message: if reply_message:
anchor_message = db_message_to_mai_message(reply_message) anchor_message = db_message_to_mai_message(reply_message)
if anchor_message: if anchor_message:
logger.debug(f"[SendAPI] 找到匹配的回复消息,发送者: {anchor_message.message_info.user_info.user_id}") logger.debug(f"[SendService] 找到匹配的回复消息,发送者: {anchor_message.message_info.user_info.user_id}")
reply_to_platform_id = ( reply_to_platform_id = (
f"{anchor_message.platform}:{anchor_message.message_info.user_info.user_id}" f"{anchor_message.platform}:{anchor_message.message_info.user_info.user_id}"
) )
# 构建 sender_info私聊时为接收者信息
sender_info = None sender_info = None
if target_stream.context and target_stream.context.message: if target_stream.context and target_stream.context.message:
sender_info = target_stream.context.message.message_info.user_info sender_info = target_stream.context.message.message_info.user_info
# 构建发送消息对象
bot_message = MessageSending( bot_message = MessageSending(
message_id=message_id, message_id=message_id,
session=target_stream, session=target_stream,
@@ -128,7 +94,6 @@ async def _send_to_target(
selected_expressions=selected_expressions, selected_expressions=selected_expressions,
) )
# 发送消息
sent_msg = await message_sender.send_message( sent_msg = await message_sender.send_message(
bot_message, bot_message,
typing=typing, typing=typing,
@@ -138,28 +103,22 @@ async def _send_to_target(
) )
if sent_msg: if sent_msg:
logger.debug(f"[SendAPI] 成功发送消息到 {stream_id}") logger.debug(f"[SendService] 成功发送消息到 {stream_id}")
return True return True
else: else:
logger.error("[SendAPI] 发送消息失败") logger.error("[SendService] 发送消息失败")
return False return False
except Exception as e: except Exception as e:
logger.error(f"[SendAPI] 发送消息时出错: {e}") logger.error(f"[SendService] 发送消息时出错: {e}")
traceback.print_exc() traceback.print_exc()
return False return False
def db_message_to_mai_message(message_obj: "DatabaseMessages") -> Optional[MaiMessage]: def db_message_to_mai_message(message_obj: "DatabaseMessages") -> Optional[MaiMessage]:
"""将数据库消息重建为 MaiMessage 对象,用于回复引用。 """将数据库消息重建为 MaiMessage 对象,用于回复引用。"""
Args:
message_obj: 插件系统的 DatabaseMessages 数据对象
Returns:
Optional[MaiMessage]: 构建的消息对象如果信息不足则返回 None
"""
from datetime import datetime from datetime import datetime
from src.common.data_models.mai_message_data_model import GroupInfo, MessageInfo from src.common.data_models.mai_message_data_model import GroupInfo, MessageInfo
from src.common.data_models.message_component_data_model import MessageSequence from src.common.data_models.message_component_data_model import MessageSequence
@@ -190,7 +149,7 @@ def db_message_to_mai_message(message_obj: "DatabaseMessages") -> Optional[MaiMe
# ============================================================================= # =============================================================================
# 公共API函数 - 预定义类型的发送函数 # 公共函数 - 预定义类型的发送函数
# ============================================================================= # =============================================================================
@@ -203,18 +162,7 @@ async def text_to_stream(
storage_message: bool = True, storage_message: bool = True,
selected_expressions: Optional[List[int]] = None, selected_expressions: Optional[List[int]] = None,
) -> bool: ) -> bool:
"""向指定流发送文本消息 """向指定流发送文本消息"""
Args:
text: 要发送的文本内容
stream_id: 聊天流ID
typing: 是否显示正在输入
reply_to: 回复消息格式为"发送者:消息内容"
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
return await _send_to_target( return await _send_to_target(
message_segment=Seg(type="text", data=text), message_segment=Seg(type="text", data=text),
stream_id=stream_id, stream_id=stream_id,
@@ -234,16 +182,7 @@ async def emoji_to_stream(
set_reply: bool = False, set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None, reply_message: Optional["DatabaseMessages"] = None,
) -> bool: ) -> bool:
"""向指定流发送表情包 """向指定流发送表情包"""
Args:
emoji_base64: 表情包的base64编码
stream_id: 聊天流ID
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
return await _send_to_target( return await _send_to_target(
message_segment=Seg(type="emoji", data=emoji_base64), message_segment=Seg(type="emoji", data=emoji_base64),
stream_id=stream_id, stream_id=stream_id,
@@ -262,16 +201,7 @@ async def image_to_stream(
set_reply: bool = False, set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None, reply_message: Optional["DatabaseMessages"] = None,
) -> bool: ) -> bool:
"""向指定流发送图片 """向指定流发送图片"""
Args:
image_base64: 图片的base64编码
stream_id: 聊天流ID
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
return await _send_to_target( return await _send_to_target(
message_segment=Seg(type="image", data=image_base64), message_segment=Seg(type="image", data=image_base64),
stream_id=stream_id, stream_id=stream_id,
@@ -289,17 +219,7 @@ async def command_to_stream(
storage_message: bool = True, storage_message: bool = True,
display_message: str = "", display_message: str = "",
) -> bool: ) -> bool:
"""向指定流发送命令 """向指定流发送命令"""
Args:
command: 命令
stream_id: 聊天流ID
storage_message: 是否存储消息到数据库
display_message: 显示消息
Returns:
bool: 是否发送成功
"""
return await _send_to_target( return await _send_to_target(
message_segment=Seg(type="command", data=command), # type: ignore message_segment=Seg(type="command", data=command), # type: ignore
stream_id=stream_id, stream_id=stream_id,
@@ -321,20 +241,7 @@ async def custom_to_stream(
storage_message: bool = True, storage_message: bool = True,
show_log: bool = True, show_log: bool = True,
) -> bool: ) -> bool:
"""向指定流发送自定义类型消息 """向指定流发送自定义类型消息"""
Args:
message_type: 消息类型"text""image""emoji""video""file"
content: 消息内容通常是base64编码或文本
stream_id: 聊天流ID
display_message: 显示消息
typing: 是否显示正在输入
reply_to: 回复消息格式为"发送者:消息内容"
storage_message: 是否存储消息到数据库
show_log: 是否显示日志
Returns:
bool: 是否发送成功
"""
return await _send_to_target( return await _send_to_target(
message_segment=Seg(type=message_type, data=content), # type: ignore message_segment=Seg(type=message_type, data=content), # type: ignore
stream_id=stream_id, stream_id=stream_id,
@@ -350,25 +257,14 @@ async def custom_to_stream(
async def custom_reply_set_to_stream( async def custom_reply_set_to_stream(
reply_set: "ReplySetModel", reply_set: "ReplySetModel",
stream_id: str, stream_id: str,
display_message: str = "", # 基本没用 display_message: str = "",
typing: bool = False, typing: bool = False,
reply_message: Optional["DatabaseMessages"] = None, reply_message: Optional["DatabaseMessages"] = None,
set_reply: bool = False, set_reply: bool = False,
storage_message: bool = True, storage_message: bool = True,
show_log: bool = True, show_log: bool = True,
) -> bool: ) -> bool:
""" """向指定流发送混合型消息集"""
向指定流发送混合型消息集
Args:
reply_set: ReplySetModel 对象包含多个 ReplyContent
stream_id: 聊天流ID
display_message: 显示消息
typing: 是否显示正在输入
reply_to: 回复消息格式为"发送者:消息内容"
storage_message: 是否存储消息到数据库
show_log: 是否显示日志
"""
flag: bool = True flag: bool = True
for reply_content in reply_set.reply_data: for reply_content in reply_set.reply_data:
status: bool = False status: bool = False
@@ -386,20 +282,14 @@ async def custom_reply_set_to_stream(
if not status: if not status:
flag = False flag = False
logger.error( logger.error(
f"[SendAPI] 发送{repr(reply_content.content_type)}消息失败,消息内容:{str(reply_content.content)[:100]}" f"[SendService] 发送{repr(reply_content.content_type)}消息失败,消息内容:{str(reply_content.content)[:100]}"
) )
return flag return flag
def _parse_content_to_seg(reply_content: "ReplyContent") -> Tuple[Seg, bool]: def _parse_content_to_seg(reply_content: "ReplyContent") -> Tuple[Seg, bool]:
""" """把 ReplyContent 转换为 Seg 结构"""
ReplyContent 转换为 Seg 结构 (Forward 中仅递归一次)
Args:
reply_content: ReplyContent 对象
Returns:
Tuple[Seg, bool]: 转换后的 Seg 结构和是否需要typing的标志
"""
content_type = reply_content.content_type content_type = reply_content.content_type
if content_type == ReplyContentType.TEXT: if content_type == ReplyContentType.TEXT:
text_data: str = reply_content.content # type: ignore text_data: str = reply_content.content # type: ignore
@@ -427,7 +317,7 @@ def _parse_content_to_seg(reply_content: "ReplyContent") -> Tuple[Seg, bool]:
elif sub_content_type == ReplyContentType.EMOJI: elif sub_content_type == ReplyContentType.EMOJI:
sub_seg_list.append(Seg(type="emoji", data=sub_content_data)) # type: ignore sub_seg_list.append(Seg(type="emoji", data=sub_content_data)) # type: ignore
else: else:
logger.warning(f"[SendAPI] 混合类型中不支持的子内容类型: {repr(sub_content_type)}") logger.warning(f"[SendService] 混合类型中不支持的子内容类型: {repr(sub_content_type)}")
continue continue
return Seg(type="seglist", data=sub_seg_list), True return Seg(type="seglist", data=sub_seg_list), True
elif content_type == ReplyContentType.FORWARD: elif content_type == ReplyContentType.FORWARD:

View File

@@ -6,7 +6,7 @@ import json
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.toml_utils import save_toml_with_format from src.common.toml_utils import save_toml_with_format
from src.config.config import MMC_VERSION from src.config.config import MMC_VERSION
from src.plugin_system.base.config_types import ConfigField from src.core.config_types import ConfigField
from src.webui.services.git_mirror_service import get_git_mirror_service, set_update_progress_callback from src.webui.services.git_mirror_service import get_git_mirror_service, set_update_progress_callback
from src.webui.core import get_token_manager from src.webui.core import get_token_manager
from src.webui.routers.websocket.plugin_progress import update_progress from src.webui.routers.websocket.plugin_progress import update_progress
@@ -222,15 +222,16 @@ def coerce_types(schema_part: Dict[str, Any], config_part: Dict[str, Any]) -> No
def find_plugin_instance(plugin_id: str) -> Optional[Any]: def find_plugin_instance(plugin_id: str) -> Optional[Any]:
""" """
按 plugin_id 或 plugin_name 查找已加载的插件实例 按 plugin_id 查找已加载的插件信息
局部导入 plugin_manager 以规避循环依赖 新运行时中插件运行在子进程,无法获取实例,返回注册信息
""" """
from src.plugin_system.core.plugin_manager import plugin_manager from src.plugin_runtime.integration import get_plugin_runtime_manager
for loaded_plugin_name in plugin_manager.list_loaded_plugins(): mgr = get_plugin_runtime_manager()
instance = plugin_manager.get_plugin_instance(loaded_plugin_name) for sv in mgr.supervisors:
if instance and (instance.plugin_name == plugin_id or instance.get_manifest_info("id", "") == plugin_id): reg = sv._registered_plugins.get(plugin_id)
return instance if reg is not None:
return reg
return None return None
@@ -1497,26 +1498,10 @@ async def get_plugin_config_schema(
logger.info(f"获取插件配置 Schema: {plugin_id}") logger.info(f"获取插件配置 Schema: {plugin_id}")
try: try:
# 尝试从已加载的插件中获取 # 新运行时中插件运行在子进程,无法直接获取实例的 webui_config_schema
from src.plugin_system.core.plugin_manager import plugin_manager # 尝试从文件系统读取
# 查找插件实例
plugin_instance = None plugin_instance = None
# 遍历所有已加载的插件
for loaded_plugin_name in plugin_manager.list_loaded_plugins():
instance = plugin_manager.get_plugin_instance(loaded_plugin_name)
if instance:
# 匹配 plugin_name 或 manifest 中的 id
if instance.plugin_name == plugin_id:
plugin_instance = instance
break
# 也尝试匹配 manifest 中的 id
manifest_id = instance.get_manifest_info("id", "")
if manifest_id == plugin_id:
plugin_instance = instance
break
if plugin_instance and hasattr(plugin_instance, "get_webui_config_schema"): if plugin_instance and hasattr(plugin_instance, "get_webui_config_schema"):
# 从插件实例获取 schema # 从插件实例获取 schema
schema = plugin_instance.get_webui_config_schema() schema = plugin_instance.get_webui_config_schema()