重构整个插件系统,尝试恢复可启动性,新增插件系统maibot-plugin-sdk依赖

This commit is contained in:
DrSmoothl
2026-03-07 19:40:51 +08:00
parent 2e3dd44ee9
commit ce8d8dfd0a
90 changed files with 3785 additions and 10061 deletions

6
bot.py
View File

@@ -195,11 +195,11 @@ async def graceful_shutdown(): # sourcery skip: use-named-expression
# except Exception as e:
# logger.warning(f"关闭 WebUI 服务器时出错: {e}")
from src.plugin_system.core.events_manager import events_manager
from src.plugin_system.base.component_types import EventType
from src.core.event_bus import event_bus
from src.core.types import EventType
# 触发 ON_STOP 事件
await events_manager.handle_mai_events(event_type=EventType.ON_STOP)
await event_bus.emit(event_type=EventType.ON_STOP)
# 停止新版本插件运行时
from src.plugin_runtime.integration import get_plugin_runtime_manager

View File

@@ -9,7 +9,7 @@
},
"license": "GPL-v3.0-or-later",
"host_application": {
"min_version": "0.10.3"
"min_version": "1.0.0"
},
"homepage_url": "https://github.com/SengokuCola/BetterFrequency",
"repository_url": "https://github.com/SengokuCola/BetterFrequency",

View File

@@ -1,144 +1,87 @@
from typing import List, Tuple, Type, Optional
from src.plugin_system import BasePlugin, register_plugin, BaseCommand, ComponentInfo, ConfigField
from src.plugin_system.apis import send_api, frequency_api
"""发言频率控制插件 — 新 SDK 版本
通过 /chat 命令设置和查看聊天频率。
"""
from maibot_sdk import MaiBotPlugin, Command
class SetTalkFrequencyCommand(BaseCommand):
"""设置当前聊天的talk_frequency值"""
class BetterFrequencyPlugin(MaiBotPlugin):
"""聊天频率控制插件"""
command_name = "set_talk_frequency"
command_description = "设置当前聊天的talk_frequency值/chat talk_frequency <数字> 或 /chat t <数字>"
command_pattern = r"^/chat\s+(?:talk_frequency|t)\s+(?P<value>[+-]?\d*\.?\d+)$"
async def execute(self) -> Tuple[bool, Optional[str], bool]:
try:
# 获取命令参数 - 使用命名捕获组
if not self.matched_groups or "value" not in self.matched_groups:
@Command(
"set_talk_frequency",
description="设置当前聊天的talk_frequency值/chat talk_frequency <数字> 或 /chat t <数字>",
pattern=r"^/chat\s+(?:talk_frequency|t)\s+(?P<value>[+-]?\d*\.?\d+)$",
)
async def handle_set_talk_frequency(
self, stream_id: str = "", matched_groups: dict | None = None, **kwargs
):
"""设置当前聊天的 talk_frequency"""
if not matched_groups or "value" not in matched_groups:
return False, "命令格式错误", False
value_str = self.matched_groups["value"]
value_str = matched_groups["value"]
if not value_str:
return False, "无法获取数值参数", False
try:
value = float(value_str)
except ValueError:
await self.ctx.send.text("数值格式错误,请输入有效的数字", stream_id)
return False, "数值格式错误", False
# 获取聊天流ID
if not self.message.chat_stream or not hasattr(self.message.chat_stream, "stream_id"):
if not stream_id:
return False, "无法获取聊天流信息", False
chat_id = self.message.chat_stream.stream_id
# 设置 talk_frequency
frequency_api.set_talk_frequency_adjust(chat_id, value)
await self.ctx.frequency.set_adjust(stream_id, value)
final_value = frequency_api.get_current_talk_value(chat_id)
adjust_value = frequency_api.get_talk_frequency_adjust(chat_id)
base_value = final_value / adjust_value
# 获取当前状态
current = await self.ctx.frequency.get_current_talk_value(stream_id)
current_val = current if isinstance(current, (int, float)) else 0
adjust = await self.ctx.frequency.get_adjust(stream_id)
adjust_val = adjust if isinstance(adjust, (int, float)) else 1
base_val = current_val / adjust_val if adjust_val else 0
# 发送反馈消息(不保存到数据库)
await send_api.text_to_stream(
f"已设置当前聊天的talk_frequency调整值为: {value}\n当前talk_value: {final_value:.2f}\n发言频率调整: {adjust_value:.2f}\n基础值: {base_value:.2f}",
chat_id,
storage_message=False,
msg = (
f"已设置当前聊天的talk_frequency调整值为: {value}\n"
f"当前talk_value: {current_val:.2f}\n"
f"发言频率调整: {adjust_val:.2f}\n"
f"基础值: {base_val:.2f}"
)
await self.ctx.send.text(msg, stream_id)
return True, None, False
except ValueError:
error_msg = "数值格式错误,请输入有效的数字"
await self.send_text(error_msg, storage_message=False)
return False, error_msg, False
except Exception as e:
error_msg = f"设置talk_frequency失败: {str(e)}"
await self.send_text(error_msg, storage_message=False)
return False, error_msg, False
class ShowFrequencyCommand(BaseCommand):
"""显示当前聊天的频率控制状态"""
command_name = "show_frequency"
command_description = "显示当前聊天的频率控制状态:/chat show 或 /chat s"
command_pattern = r"^/chat\s+(?:show|s)$"
async def execute(self) -> Tuple[bool, Optional[str], bool]:
try:
# 获取聊天流ID
if not self.message.chat_stream or not hasattr(self.message.chat_stream, "stream_id"):
@Command(
"show_frequency",
description="显示当前聊天的频率控制状态:/chat show 或 /chat s",
pattern=r"^/chat\s+(?:show|s)$",
)
async def handle_show_frequency(self, stream_id: str = "", **kwargs):
"""显示当前频率控制状态"""
if not stream_id:
return False, "无法获取聊天流信息", False
chat_id = self.message.chat_stream.stream_id
# 获取当前频率控制状态
current_talk_frequency = frequency_api.get_current_talk_value(chat_id)
talk_frequency_adjust = frequency_api.get_talk_frequency_adjust(chat_id)
base_value = current_talk_frequency / talk_frequency_adjust
# 构建显示消息
status_msg = f"""当前聊天频率控制状态
Talk Value (发言频率):
• 基础值: {base_value:.2f}
• 发言频率调整: {talk_frequency_adjust:.2f}
• 当前值: {current_talk_frequency:.2f}
使用命令:
• /chat talk_frequency <数字> 或 /chat t <数字> - 设置发言频率调整
• /chat show 或 /chat s - 显示当前状态"""
# 发送状态消息(不保存到数据库)
await send_api.text_to_stream(status_msg, chat_id, storage_message=False)
current = await self.ctx.frequency.get_current_talk_value(stream_id)
current_val = current if isinstance(current, (int, float)) else 0
adjust = await self.ctx.frequency.get_adjust(stream_id)
adjust_val = adjust if isinstance(adjust, (int, float)) else 1
base_val = current_val / adjust_val if adjust_val else 0
status_msg = (
"当前聊天频率控制状态\n"
"Talk Value (发言频率):\n\n"
f" • 基础值: {base_val:.2f}\n"
f" • 发言频率调整: {adjust_val:.2f}\n"
f" • 当前值: {current_val:.2f}\n\n"
"使用命令:\n"
" • /chat talk_frequency <数字> 或 /chat t <数字> - 设置发言频率调整\n"
" • /chat show 或 /chat s - 显示当前状态"
)
await self.ctx.send.text(status_msg, stream_id)
return True, None, False
except Exception as e:
error_msg = f"获取频率控制状态失败: {str(e)}"
# 使用内置的send_text方法发送错误消息
await self.send_text(error_msg, storage_message=False)
return False, error_msg, False
# ===== 插件注册 =====
@register_plugin
class BetterFrequencyPlugin(BasePlugin):
"""BetterFrequency插件 - 控制聊天频率的插件"""
# 插件基本信息
plugin_name: str = "better_frequency_plugin"
enable_plugin: bool = True
dependencies: List[str] = []
python_dependencies: List[str] = []
config_file_name: str = "config.toml"
# 配置节描述
config_section_descriptions = {"plugin": "插件基本信息", "frequency": "频率控制配置", "features": "功能开关配置"}
# 配置Schema定义
config_schema: dict = {
"plugin": {
"name": ConfigField(type=str, default="better_frequency_plugin", description="插件名称"),
"version": ConfigField(type=str, default="1.0.2", description="插件版本"),
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
},
"frequency": {
"default_talk_adjust": ConfigField(type=float, default=1.0, description="默认talk_frequency调整值"),
"max_adjust_value": ConfigField(type=float, default=1.0, description="最大调整值"),
"min_adjust_value": ConfigField(type=float, default=0.0, description="最小调整值"),
},
}
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
components = []
# 根据配置决定是否注册命令组件
if self.config.get("features", {}).get("enable_commands", True):
components.extend(
[
(SetTalkFrequencyCommand.get_command_info(), SetTalkFrequencyCommand),
(ShowFrequencyCommand.get_command_info(), ShowFrequencyCommand),
]
)
return components
def create_plugin():
return BetterFrequencyPlugin()

View File

@@ -1,7 +1,7 @@
{
"manifest_version": 1,
"name": "BetterEmoji",
"version": "1.0.0",
"version": "2.0.0",
"description": "更好的表情包管理插件",
"author": {
"name": "SengokuCola",
@@ -9,7 +9,7 @@
},
"license": "GPL-v3.0-or-later",
"host_application": {
"min_version": "0.10.4"
"min_version": "1.0.0"
},
"homepage_url": "https://github.com/SengokuCola/BetterEmoji",
"repository_url": "https://github.com/SengokuCola/BetterEmoji",
@@ -19,46 +19,49 @@
"plugin"
],
"categories": [
"Examples",
"Tutorial"
"Emoji",
"Management"
],
"default_locale": "zh-CN",
"locales_path": "_locales",
"plugin_info": {
"is_built_in": false,
"plugin_type": "emoji_manage",
"capabilities": [
"emoji.get_random",
"emoji.get_count",
"emoji.get_info",
"emoji.get_all",
"emoji.register_emoji",
"emoji.delete_emoji",
"send.text",
"send.forward"
],
"components": [
{
"type": "action",
"name": "hello_greeting",
"description": "向用户发送问候消息"
},
{
"type": "action",
"name": "bye_greeting",
"description": "向用户发送告别消息",
"activation_modes": [
"keyword"
],
"keywords": [
"再见",
"bye",
"88",
"拜拜"
]
"type": "command",
"name": "add_emoji",
"description": "添加表情包",
"pattern": "/emoji add"
},
{
"type": "command",
"name": "time",
"description": "查询当前时间",
"pattern": "/time"
"name": "emoji_list",
"description": "列表表情包",
"pattern": "/emoji list"
},
{
"type": "command",
"name": "delete_emoji",
"description": "删除表情包",
"pattern": "/emoji delete"
},
{
"type": "command",
"name": "random_emojis",
"description": "发送多张随机表情包",
"pattern": "/random_emojis"
}
],
"features": [
"问候和告别功能",
"时间查询命令",
"配置文件示例",
"新手教程代码"
]
},
"id": "SengokuCola.BetterEmoji"

View File

@@ -1,399 +1,216 @@
from typing import List, Tuple, Type
from src.plugin_system import (
BasePlugin,
register_plugin,
BaseCommand,
ComponentInfo,
ConfigField,
ReplyContentType,
emoji_api,
)
from maim_message import Seg
from src.common.logger import get_logger
"""表情包管理插件 — 新 SDK 版本
logger = get_logger("emoji_manage_plugin")
通过 /emoji 命令管理表情包的添加、列表和删除。
"""
import base64
import datetime
import hashlib
import re
from maibot_sdk import MaiBotPlugin, Command
class AddEmojiCommand(BaseCommand):
command_name = "add_emoji"
command_description = "添加表情包"
command_pattern = r".*/emoji add.*"
class EmojiManagePlugin(MaiBotPlugin):
"""表情包管理插件"""
async def execute(self) -> Tuple[bool, str, bool]:
# 查找消息中的表情包
# logger.info(f"查找消息中的表情包: {self.message.message_segment}")
# ===== 工具方法 =====
emoji_base64_list = self.find_and_return_emoji_in_message(self.message.message_segment)
@staticmethod
def _extract_emoji_base64(segments) -> list[str]:
"""从消息 segments 中提取 emoji/image 的 base64 数据。
segments 可以是 dict 列表或 Seg 对象列表(兼容两种格式)。
"""
results: list[str] = []
if not segments:
return results
if isinstance(segments, dict):
seg_type = segments.get("type", "")
if seg_type in ("emoji", "image"):
data = segments.get("data", "")
if data:
results.append(data)
elif seg_type == "seglist":
for child in segments.get("data", []):
results.extend(EmojiManagePlugin._extract_emoji_base64(child))
return results
# 如果有 .type 属性Seg 对象)
if hasattr(segments, "type"):
seg_type = getattr(segments, "type", "")
if seg_type in ("emoji", "image"):
results.append(getattr(segments, "data", ""))
elif seg_type == "seglist":
for child in getattr(segments, "data", []):
results.extend(EmojiManagePlugin._extract_emoji_base64(child))
return results
# 列表
for seg in segments:
results.extend(EmojiManagePlugin._extract_emoji_base64(seg))
return results
# ===== Command 组件 =====
@Command("add_emoji", description="添加表情包", pattern=r".*/emoji add.*")
async def handle_add_emoji(self, stream_id: str = "", message_segments=None, **kwargs):
"""添加表情包"""
emoji_base64_list = self._extract_emoji_base64(message_segments)
if not emoji_base64_list:
await self.ctx.send.text("未在消息中找到表情包或图片", stream_id)
return False, "未在消息中找到表情包或图片", False
# 注册找到的表情包
success_count = 0
fail_count = 0
results = []
for i, emoji_base64 in enumerate(emoji_base64_list):
try:
# 使用emoji_api注册表情包让API自动生成唯一文件名
result = await emoji_api.register_emoji(emoji_base64)
if result["success"]:
for i, emoji_b64 in enumerate(emoji_base64_list):
result = await self.ctx.emoji.register_emoji(emoji_b64)
if isinstance(result, dict) and result.get("success"):
success_count += 1
description = result.get("description", "未知描述")
desc = result.get("description", "未知描述")
emotions = result.get("emotions", [])
replaced = result.get("replaced", False)
result_msg = f"表情包 {i + 1} 注册成功{'(替换旧表情包)' if replaced else '(新增表情包)'}"
if description:
result_msg += f"\n描述: {description}"
msg = f"表情包 {i + 1} 注册成功{'(替换旧表情包)' if replaced else '(新增表情包)'}"
if desc:
msg += f"\n描述: {desc}"
if emotions:
result_msg += f"\n情感标签: {', '.join(emotions)}"
results.append(result_msg)
msg += f"\n情感标签: {', '.join(emotions)}"
results.append(msg)
else:
fail_count += 1
error_msg = result.get("message", "注册失败")
results.append(f"表情包 {i + 1} 注册失败: {error_msg}")
err = result.get("message", "注册失败") if isinstance(result, dict) else "注册失败"
results.append(f"表情包 {i + 1} 注册失败: {err}")
except Exception as e:
fail_count += 1
results.append(f"表情包 {i + 1} 注册时发生错误: {str(e)}")
# 构建返回消息
total_count = success_count + fail_count
summary_msg = f"表情包注册完成: 成功 {success_count} 个,失败 {fail_count} 个,共处理 {total_count}"
# 如果有结果详情,添加到返回消息中
details_msg = ""
total = success_count + fail_count
summary = f"表情包注册完成: 成功 {success_count} 个,失败 {fail_count} 个,共处理 {total}"
if results:
details_msg = "\n" + "\n".join(results)
final_msg = summary_msg + details_msg
else:
final_msg = summary_msg
summary += "\n" + "\n".join(results)
# 使用表达器重写回复
try:
from src.plugin_system.apis import generator_api
await self.ctx.send.text(summary, stream_id)
return success_count > 0, summary, success_count > 0
# 构建重写数据
rewrite_data = {
"raw_reply": summary_msg,
"reason": f"注册了表情包:{details_msg}\n",
}
# 调用表达器重写
result_status, data = await generator_api.rewrite_reply(
chat_stream=self.message.chat_stream,
reply_data=rewrite_data,
)
if result_status:
# 发送重写后的回复
for reply_seg in data.reply_set.reply_data:
send_data = reply_seg.content
await self.send_text(send_data)
return success_count > 0, final_msg, success_count > 0
else:
# 如果重写失败,发送原始消息
await self.send_text(final_msg)
return success_count > 0, final_msg, success_count > 0
except Exception as e:
# 如果表达器调用失败,发送原始消息
logger.error(f"[add_emoji] 表达器重写失败: {e}")
await self.send_text(final_msg)
return success_count > 0, final_msg, success_count > 0
def find_and_return_emoji_in_message(self, message_segments) -> List[str]:
emoji_base64_list = []
# 处理单个Seg对象的情况
if isinstance(message_segments, Seg):
if message_segments.type == "emoji":
emoji_base64_list.append(message_segments.data)
elif message_segments.type == "image":
# 假设图片数据是base64编码的
emoji_base64_list.append(message_segments.data)
elif message_segments.type == "seglist":
# 递归处理嵌套的Seg列表
emoji_base64_list.extend(self.find_and_return_emoji_in_message(message_segments.data))
return emoji_base64_list
# 处理Seg列表的情况
for seg in message_segments:
if seg.type == "emoji":
emoji_base64_list.append(seg.data)
elif seg.type == "image":
# 假设图片数据是base64编码的
emoji_base64_list.append(seg.data)
elif seg.type == "seglist":
# 递归处理嵌套的Seg列表
emoji_base64_list.extend(self.find_and_return_emoji_in_message(seg.data))
return emoji_base64_list
class ListEmojiCommand(BaseCommand):
"""列表表情包Command - 响应/emoji list命令"""
command_name = "emoji_list"
command_description = "列表表情包"
# === 命令设置(必须填写)===
command_pattern = r"^/emoji list(\s+\d+)?$" # 匹配 "/emoji list" 或 "/emoji list 数量"
async def execute(self) -> Tuple[bool, str, bool]:
"""执行列表表情包"""
from src.plugin_system.apis import emoji_api
import datetime
# 解析命令参数
import re
match = re.match(r"^/emoji list(?:\s+(\d+))?$", self.message.raw_message)
max_count = 10 # 默认显示10个
@Command("emoji_list", description="列表表情包", pattern=r"^/emoji list(\s+\d+)?$")
async def handle_list_emoji(self, stream_id: str = "", raw_message: str = "", **kwargs):
"""列出表情包"""
max_count = 10
match = re.match(r"^/emoji list(?:\s+(\d+))?$", raw_message)
if match and match.group(1):
max_count = min(int(match.group(1)), 50) # 最多显示50个
max_count = min(int(match.group(1)), 50)
# 获取当前时间
time_format: str = self.get_config("time.format", "%Y-%m-%d %H:%M:%S") # type: ignore
now = datetime.datetime.now()
time_str = now.strftime(time_format)
now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# 获取表情包信息
emoji_count = emoji_api.get_count()
emoji_info = emoji_api.get_info()
count_result = await self.ctx.emoji.get_count()
emoji_count = count_result if isinstance(count_result, int) else 0
# 构建返回消息
message_lines = [
f"📊 表情包统计信息 ({time_str})",
f"• 总数: {emoji_count} / {emoji_info['max_count']}",
f"• 可用: {emoji_info['available_emojis']}",
info_result = await self.ctx.emoji.get_info()
max_emoji = info_result.get("max_count", 0) if isinstance(info_result, dict) else 0
available = info_result.get("available_emojis", 0) if isinstance(info_result, dict) else 0
lines = [
f"📊 表情包统计信息 ({now})",
f"• 总数: {emoji_count} / {max_emoji}",
f"• 可用: {available}",
]
if emoji_count == 0:
message_lines.append("\n❌ 暂无表情包")
final_message = "\n".join(message_lines)
await self.send_text(final_message)
return True, final_message, True
lines.append("\n❌ 暂无表情包")
await self.ctx.send.text("\n".join(lines), stream_id)
return True, "\n".join(lines), True
# 获取所有表情包
all_emojis = await emoji_api.get_all()
all_result = await self.ctx.emoji.get_all()
all_emojis = all_result if isinstance(all_result, list) else []
if not all_emojis:
message_lines.append("\n❌ 无法获取表情包列表")
final_message = "\n".join(message_lines)
await self.send_text(final_message)
return False, final_message, True
lines.append("\n❌ 无法获取表情包列表")
await self.ctx.send.text("\n".join(lines), stream_id)
return False, "\n".join(lines), True
# 显示前N个表情包
display_emojis = all_emojis[:max_count]
message_lines.append(f"\n📋 显示前 {len(display_emojis)} 个表情包:")
display = all_emojis[:max_count]
lines.append(f"\n📋 显示前 {len(display)} 个表情包:")
for i, emoji in enumerate(display, 1):
if isinstance(emoji, (list, tuple)) and len(emoji) >= 3:
_, desc, emotion = emoji[0], emoji[1], emoji[2]
elif isinstance(emoji, dict):
desc = emoji.get("description", "")
emotion = emoji.get("emotion", "")
else:
desc, emotion = str(emoji), ""
short_desc = desc[:50] + "..." if len(desc) > 50 else desc
lines.append(f"{i}. {short_desc} [{emotion}]")
for i, (_, description, emotion) in enumerate(display_emojis, 1):
# 截断过长的描述
short_desc = description[:50] + "..." if len(description) > 50 else description
message_lines.append(f"{i}. {short_desc} [{emotion}]")
# 如果还有更多表情包,显示总数
if len(all_emojis) > max_count:
message_lines.append(f"\n💡 还有 {len(all_emojis) - max_count} 个表情包未显示")
lines.append(f"\n💡 还有 {len(all_emojis) - max_count} 个表情包未显示")
final_message = "\n".join(message_lines)
# 直接发送文本消息
await self.send_text(final_message)
return True, final_message, True
class DeleteEmojiCommand(BaseCommand):
command_name = "delete_emoji"
command_description = "删除表情包"
command_pattern = r".*/emoji delete.*"
async def execute(self) -> Tuple[bool, str, bool]:
# 查找消息中的表情包图片
logger.info(f"查找消息中的表情包用于删除: {self.message.message_segment}")
emoji_base64_list = self.find_and_return_emoji_in_message(self.message.message_segment)
final = "\n".join(lines)
await self.ctx.send.text(final, stream_id)
return True, final, True
@Command("delete_emoji", description="删除表情包", pattern=r".*/emoji delete.*")
async def handle_delete_emoji(self, stream_id: str = "", message_segments=None, **kwargs):
"""删除表情包"""
emoji_base64_list = self._extract_emoji_base64(message_segments)
if not emoji_base64_list:
return False, "未在消息中找到表情包或图片", False
await self.ctx.send.text("未在消息中找到表情包或图片", stream_id)
return False, "未找到表情包", False
# 删除找到的表情包
success_count = 0
fail_count = 0
results = []
for i, emoji_base64 in enumerate(emoji_base64_list):
try:
# 计算图片的哈希值来查找对应的表情包
import base64
import hashlib
# 确保base64字符串只包含ASCII字符
if isinstance(emoji_base64, str):
emoji_base64_clean = emoji_base64.encode("ascii", errors="ignore").decode("ascii")
for i, emoji_b64 in enumerate(emoji_base64_list):
# 计算哈希
if isinstance(emoji_b64, str):
clean = emoji_b64.encode("ascii", errors="ignore").decode("ascii")
else:
emoji_base64_clean = str(emoji_base64)
clean = str(emoji_b64)
image_bytes = base64.b64decode(clean)
emoji_hash = hashlib.md5(image_bytes).hexdigest() # noqa: S324
# 计算哈希值
image_bytes = base64.b64decode(emoji_base64_clean)
emoji_hash = hashlib.md5(image_bytes).hexdigest()
# 使用emoji_api删除表情包
result = await emoji_api.delete_emoji(emoji_hash)
if result["success"]:
result = await self.ctx.emoji.delete_emoji(emoji_hash)
if isinstance(result, dict) and result.get("success"):
success_count += 1
description = result.get("description", "未知描述")
count_before = result.get("count_before", 0)
count_after = result.get("count_after", 0)
desc = result.get("description", "未知描述")
emotions = result.get("emotions", [])
result_msg = f"表情包 {i + 1} 删除成功"
if description:
result_msg += f"\n描述: {description}"
before = result.get("count_before", 0)
after = result.get("count_after", 0)
msg = f"表情包 {i + 1} 删除成功"
if desc:
msg += f"\n描述: {desc}"
if emotions:
result_msg += f"\n情感标签: {', '.join(emotions)}"
result_msg += f"\n表情包数量: {count_before}{count_after}"
results.append(result_msg)
msg += f"\n情感标签: {', '.join(emotions)}"
msg += f"\n表情包数量: {before}{after}"
results.append(msg)
else:
fail_count += 1
error_msg = result.get("message", "删除失败")
results.append(f"表情包 {i + 1} 删除失败: {error_msg}")
err = result.get("message", "删除失败") if isinstance(result, dict) else "删除失败"
results.append(f"表情包 {i + 1} 删除失败: {err}")
except Exception as e:
fail_count += 1
results.append(f"表情包 {i + 1} 删除时发生错误: {str(e)}")
# 构建返回消息
total_count = success_count + fail_count
summary_msg = f"表情包删除完成: 成功 {success_count} 个,失败 {fail_count} 个,共处理 {total_count}"
# 如果有结果详情,添加到返回消息中
details_msg = ""
total = success_count + fail_count
summary = f"表情包删除完成: 成功 {success_count} 个,失败 {fail_count} 个,共处理 {total}"
if results:
details_msg = "\n" + "\n".join(results)
final_msg = summary_msg + details_msg
else:
final_msg = summary_msg
summary += "\n" + "\n".join(results)
# 使用表达器重写回复
try:
from src.plugin_system.apis import generator_api
await self.ctx.send.text(summary, stream_id)
return success_count > 0, summary, success_count > 0
# 构建重写数据
rewrite_data = {
"raw_reply": summary_msg,
"reason": f"删除了表情包:{details_msg}\n",
}
# 调用表达器重写
result_status, data = await generator_api.rewrite_reply(
chat_stream=self.message.chat_stream,
reply_data=rewrite_data,
)
if result_status:
# 发送重写后的回复
for reply_seg in data.reply_set.reply_data:
send_data = reply_seg.content
await self.send_text(send_data)
return success_count > 0, final_msg, success_count > 0
else:
# 如果重写失败,发送原始消息
await self.send_text(final_msg)
return success_count > 0, final_msg, success_count > 0
except Exception as e:
# 如果表达器调用失败,发送原始消息
logger.error(f"[delete_emoji] 表达器重写失败: {e}")
await self.send_text(final_msg)
return success_count > 0, final_msg, success_count > 0
def find_and_return_emoji_in_message(self, message_segments) -> List[str]:
emoji_base64_list = []
# 处理单个Seg对象的情况
if isinstance(message_segments, Seg):
if message_segments.type == "emoji":
emoji_base64_list.append(message_segments.data)
elif message_segments.type == "image":
# 假设图片数据是base64编码的
emoji_base64_list.append(message_segments.data)
elif message_segments.type == "seglist":
# 递归处理嵌套的Seg列表
emoji_base64_list.extend(self.find_and_return_emoji_in_message(message_segments.data))
return emoji_base64_list
# 处理Seg列表的情况
for seg in message_segments:
if seg.type == "emoji":
emoji_base64_list.append(seg.data)
elif seg.type == "image":
# 假设图片数据是base64编码的
emoji_base64_list.append(seg.data)
elif seg.type == "seglist":
# 递归处理嵌套的Seg列表
emoji_base64_list.extend(self.find_and_return_emoji_in_message(seg.data))
return emoji_base64_list
class RandomEmojis(BaseCommand):
command_name = "random_emojis"
command_description = "发送多张随机表情包"
command_pattern = r"^/random_emojis$"
async def execute(self):
emojis = await emoji_api.get_random(5)
@Command("random_emojis", description="发送多张随机表情包", pattern=r"^/random_emojis$")
async def handle_random_emojis(self, stream_id: str = "", **kwargs):
"""发送多张随机表情包"""
result = await self.ctx.emoji.get_random(5)
if not result or not result.get("success"):
return False, "未找到表情包", False
emojis = result.get("emojis", [])
if not emojis:
return False, "未找到表情包", False
emoji_base64_list = []
for emoji in emojis:
emoji_base64_list.append(emoji[0])
return await self.forward_images(emoji_base64_list)
async def forward_images(self, images: List[str]):
"""
把多张图片用合并转发的方式发给用户
"""
success = await self.send_forward([("0", "神秘用户", [(ReplyContentType.IMAGE, img)]) for img in images])
return (True, "已发送随机表情包", True) if success else (False, "发送随机表情包失败", False)
# ===== 插件注册 =====
@register_plugin
class EmojiManagePlugin(BasePlugin):
"""表情包管理插件 - 管理表情包"""
# 插件基本信息
plugin_name: str = "emoji_manage_plugin" # 内部标识符
enable_plugin: bool = False
dependencies: List[str] = [] # 插件依赖列表
python_dependencies: List[str] = [] # Python包依赖列表
config_file_name: str = "config.toml" # 配置文件名
# 配置节描述
config_section_descriptions = {"plugin": "插件基本信息", "emoji": "表情包功能配置"}
# 配置Schema定义
config_schema: dict = {
"plugin": {
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
"config_version": ConfigField(type=str, default="1.0.1", description="配置文件版本"),
},
}
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
return [
(RandomEmojis.get_command_info(), RandomEmojis),
(AddEmojiCommand.get_command_info(), AddEmojiCommand),
(ListEmojiCommand.get_command_info(), ListEmojiCommand),
(DeleteEmojiCommand.get_command_info(), DeleteEmojiCommand),
messages = [
{"user_id": "0", "nickname": "神秘用户", "segments": [{"type": "image", "content": e.get("base64", "")}]}
for e in emojis
]
await self.ctx.send.forward(messages, stream_id)
return True, "已发送随机表情包", True
def create_plugin():
return EmojiManagePlugin()

View File

@@ -1,7 +1,7 @@
{
"manifest_version": 1,
"name": "Hello World 示例插件 (Hello World Plugin)",
"version": "1.0.0",
"version": "2.0.0",
"description": "我的第一个MaiCore插件包含问候功能和时间查询等基础示例",
"author": {
"name": "MaiBot开发团队",
@@ -9,7 +9,7 @@
},
"license": "GPL-v3.0-or-later",
"host_application": {
"min_version": "0.8.0"
"min_version": "1.0.0"
},
"homepage_url": "https://github.com/MaiM-with-u/maibot",
"repository_url": "https://github.com/MaiM-with-u/maibot",
@@ -29,7 +29,19 @@
"plugin_info": {
"is_built_in": false,
"plugin_type": "example",
"capabilities": [
"send.text",
"send.forward",
"send.hybrid",
"emoji.get_random",
"config.get"
],
"components": [
{
"type": "tool",
"name": "compare_numbers",
"description": "比较两个数的大小"
},
{
"type": "action",
"name": "hello_greeting",
@@ -39,28 +51,37 @@
"type": "action",
"name": "bye_greeting",
"description": "向用户发送告别消息",
"activation_modes": [
"keyword"
],
"keywords": [
"再见",
"bye",
"88",
"拜拜"
]
"activation_modes": ["keyword"],
"keywords": ["再见", "bye", "88", "拜拜"]
},
{
"type": "command",
"name": "time",
"description": "查询当前时间",
"pattern": "/time"
},
{
"type": "command",
"name": "random_emojis",
"description": "发送多张随机表情包",
"pattern": "/random_emojis"
},
{
"type": "command",
"name": "test",
"description": "测试命令",
"pattern": "/test"
},
{
"type": "event_handler",
"name": "print_message_handler",
"description": "打印接收到的消息"
},
{
"type": "event_handler",
"name": "forward_messages_handler",
"description": "把接收到的消息转发到指定聊天ID"
}
],
"features": [
"问候和告别功能",
"时间查询命令",
"配置文件示例",
"新手教程代码"
]
},
"id": "MaiBot开发团队.maibot"

View File

@@ -1,50 +1,30 @@
import random
from typing import List, Tuple, Type, Any, Optional
from src.plugin_system import (
BasePlugin,
register_plugin,
BaseAction,
BaseCommand,
BaseTool,
ComponentInfo,
ActionActivationType,
ConfigField,
BaseEventHandler,
EventType,
MaiMessages,
ToolParamType,
ReplyContentType,
emoji_api,
)
from src.config.config import global_config
from src.common.logger import get_logger
"""Hello World 示例插件 — 新 SDK 版本
logger = get_logger("hello_world_plugin")
class CompareNumbersTool(BaseTool):
"""比较两个数大小的工具"""
name = "compare_numbers"
description = "使用工具 比较两个数的大小,返回较大的数"
parameters = [
("num1", ToolParamType.FLOAT, "第一个数字", True, None),
("num2", ToolParamType.FLOAT, "第二个数字", True, None),
]
available_for_llm = True
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
"""执行比较两个数的大小
Args:
function_args: 工具参数
Returns:
dict: 工具执行结果
你的第一个 MaiCore 插件,包含问候功能、时间查询等基础示例。
"""
num1: int | float = function_args.get("num1") # type: ignore
num2: int | float = function_args.get("num2") # type: ignore
import datetime
import random
from maibot_sdk import MaiBotPlugin, Action, Command, Tool, EventHandler
from maibot_sdk.types import ActivationType, EventType, ToolParameterInfo, ToolParamType
class HelloWorldPlugin(MaiBotPlugin):
"""Hello World 示例插件"""
# ===== Tool 组件 =====
@Tool(
"compare_numbers",
description="使用工具比较两个数的大小,返回较大的数",
parameters=[
ToolParameterInfo(name="num1", param_type=ToolParamType.FLOAT, description="第一个数字", required=True),
ToolParameterInfo(name="num2", param_type=ToolParamType.FLOAT, description="第二个数字", required=True),
],
)
async def handle_compare_numbers(self, num1: float = 0, num2: float = 0, **kwargs):
"""比较两个数的大小"""
try:
if num1 > num2:
result = f"{num1} 大于 {num2}"
@@ -52,270 +32,121 @@ class CompareNumbersTool(BaseTool):
result = f"{num1} 小于 {num2}"
else:
result = f"{num1} 等于 {num2}"
return {"name": self.name, "content": result}
return {"name": "compare_numbers", "content": result}
except Exception as e:
return {"name": self.name, "content": f"比较数字失败,炸了: {str(e)}"}
return {"name": "compare_numbers", "content": f"比较数字失败,炸了: {e}"}
# ===== Action 组件 =====
class HelloAction(BaseAction):
"""问候Action - 简单的问候动作"""
# === 基本信息(必须填写)===
action_name = "hello_greeting"
action_description = "向用户发送问候消息"
activation_type = ActionActivationType.ALWAYS # 始终激活
# === 功能描述(必须填写)===
action_parameters = {"greeting_message": "要发送的问候消息"}
action_require = ["需要发送友好问候时使用", "当有人向你问好时使用", "当你遇见没有见过的人时使用"]
associated_types = ["text"]
async def execute(self) -> Tuple[bool, str]:
"""执行问候动作 - 这是核心功能"""
# 发送问候消息
greeting_message = self.action_data.get("greeting_message", "")
base_message = self.get_config("greeting.message", "嗨!很开心见到你!😊")
@Action(
"hello_greeting",
description="向用户发送问候消息",
activation_type=ActivationType.ALWAYS,
action_parameters={"greeting_message": "要发送的问候消息"},
action_require=["需要发送友好问候时使用", "当有人向你问好时使用", "当你遇见没有见过的人时使用"],
associated_types=["text"],
)
async def handle_hello(self, stream_id: str = "", greeting_message: str = "", **kwargs):
"""问候动作"""
config_result = await self.ctx.config.get("greeting.message")
base_message = config_result if isinstance(config_result, str) else "嗨!很开心见到你!😊"
message = base_message + greeting_message
await self.send_text(message)
await self.ctx.send.text(message, stream_id)
return True, "发送了问候消息"
class ByeAction(BaseAction):
"""告别Action - 只在用户说再见时激活"""
action_name = "bye_greeting"
action_description = "向用户发送告别消息"
# 使用关键词激活
activation_type = ActionActivationType.KEYWORD
# 关键词设置
activation_keywords = ["再见", "bye", "88", "拜拜"]
keyword_case_sensitive = False
action_parameters = {"bye_message": "要发送的告别消息"}
action_require = [
"用户要告别时使用",
"当有人要离开时使用",
"当有人和你说再见时使用",
]
associated_types = ["text"]
async def execute(self) -> Tuple[bool, str]:
bye_message = self.action_data.get("bye_message", "")
@Action(
"bye_greeting",
description="向用户发送告别消息",
activation_type=ActivationType.KEYWORD,
activation_keywords=["再见", "bye", "88", "拜拜"],
action_parameters={"bye_message": "发送告别消息"},
action_require=["用户要告别时使用", "当有人要离开时使用", "当有人和你说再见时使用"],
associated_types=["text"],
)
async def handle_bye(self, stream_id: str = "", bye_message: str = "", **kwargs):
"""告别动作"""
message = f"再见!期待下次聊天!👋{bye_message}"
await self.send_text(message)
await self.ctx.send.text(message, stream_id)
return True, "发送了告别消息"
# ===== Command 组件 =====
class TimeCommand(BaseCommand):
"""时间查询Command - 响应/time命令"""
command_name = "time"
command_description = "查询当前时间"
# === 命令设置(必须填写)===
command_pattern = r"^/time$" # 精确匹配 "/time" 命令
async def execute(self) -> Tuple[bool, str, bool]:
"""执行时间查询"""
import datetime
# 获取当前时间
time_format: str = self.get_config("time.format", "%Y-%m-%d %H:%M:%S") # type: ignore
@Command("time", description="查询当前时间", pattern=r"^/time$")
async def handle_time(self, stream_id: str = "", **kwargs):
"""时间查询命令"""
config_result = await self.ctx.config.get("time.format")
time_format = config_result if isinstance(config_result, str) else "%Y-%m-%d %H:%M:%S"
now = datetime.datetime.now()
time_str = now.strftime(time_format)
# 发送时间信息
message = f"⏰ 当前时间:{time_str}"
await self.send_text(message)
await self.ctx.send.text(f"⏰ 当前时间:{time_str}", stream_id)
return True, f"显示了当前时间: {time_str}", True
class PrintMessage(BaseEventHandler):
"""打印消息事件处理器 - 处理打印消息事件"""
event_type = EventType.ON_MESSAGE
handler_name = "print_message_handler"
handler_description = "打印接收到的消息"
async def execute(self, message: MaiMessages | None) -> Tuple[bool, bool, str | None, None, None]:
"""执行打印消息事件处理"""
# 打印接收到的消息
if self.get_config("print_message.enabled", False):
print(f"接收到消息: {message.raw_message if message else '无效消息'}")
return True, True, "消息已打印", None, None
class ForwardMessages(BaseEventHandler):
"""
把接收到的消息转发到指定聊天ID
此组件是HYBRID消息和FORWARD消息的使用示例。
每收到10条消息就会以1%的概率使用HYBRID消息转发否则使用FORWARD消息转发。
"""
event_type = EventType.ON_MESSAGE
handler_name = "forward_messages_handler"
handler_description = "把接收到的消息转发到指定聊天ID"
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.counter = 0 # 用于计数转发的消息数量
self.messages: List[str] = []
async def execute(self, message: MaiMessages | None) -> Tuple[bool, bool, None, None, None]:
if not message:
return True, True, None, None, None
stream_id = message.stream_id or ""
if message.plain_text:
self.messages.append(message.plain_text)
self.counter += 1
if self.counter % 10 == 0:
if random.random() < 0.01:
success = await self.send_hybrid(stream_id, [(ReplyContentType.TEXT, msg) for msg in self.messages])
else:
success = await self.send_forward(
stream_id,
[
(
str(global_config.bot.qq_account),
str(global_config.bot.nickname),
[(ReplyContentType.TEXT, msg)],
)
for msg in self.messages
],
)
if not success:
raise ValueError("转发消息失败")
self.messages = []
return True, True, None, None, None
class RandomEmojis(BaseCommand):
command_name = "random_emojis"
command_description = "发送多张随机表情包"
command_pattern = r"^/random_emojis$"
async def execute(self):
emojis = await emoji_api.get_random(5)
@Command("random_emojis", description="发送多张随机表情包", pattern=r"^/random_emojis$")
async def handle_random_emojis(self, stream_id: str = "", **kwargs):
"""发送多张随机表情包"""
result = await self.ctx.emoji.get_random(5)
if not result or not result.get("success"):
return False, "未找到表情包", False
emojis = result.get("emojis", [])
if not emojis:
return False, "未找到表情包", False
emoji_base64_list = []
for emoji in emojis:
emoji_base64_list.append(emoji[0])
return await self.forward_images(emoji_base64_list)
async def forward_images(self, images: List[str]):
"""
把多张图片用合并转发的方式发给用户
"""
success = await self.send_forward([("0", "神秘用户", [(ReplyContentType.IMAGE, img)]) for img in images])
return (True, "已发送随机表情包", True) if success else (False, "发送随机表情包失败", False)
class TestCommand(BaseCommand):
"""响应/test命令"""
command_name = "test"
command_description = "测试命令"
command_pattern = r"^/test$"
async def execute(self) -> Tuple[bool, Optional[str], int]:
"""执行测试命令"""
try:
from src.plugin_system.apis import generator_api
reply_reason = "这是一条测试消息。"
logger.info(f"测试命令:{reply_reason}")
result_status, data = await generator_api.generate_reply(
chat_stream=self.message.chat_stream,
reply_reason=reply_reason,
enable_chinese_typo=False,
extra_info=f'{reply_reason}用于测试bot的功能是否正常。请你按设定的人设表达一句"测试正常"',
)
if result_status:
# 发送生成的回复
if data and data.reply_set and data.reply_set.reply_data:
for reply_seg in data.reply_set.reply_data:
send_data = reply_seg.content
await self.send_text(send_data, storage_message=True)
logger.info(f"已回复: {send_data}")
return True, "", 1
except Exception as e:
logger.error(f"表达器生成失败:{e}")
return True, "", 1
# ===== 插件注册 =====
@register_plugin
class HelloWorldPlugin(BasePlugin):
"""Hello World插件 - 你的第一个MaiCore插件"""
# 插件基本信息
plugin_name: str = "hello_world_plugin" # 内部标识符
enable_plugin: bool = False
dependencies: List[str] = [] # 插件依赖列表
python_dependencies: List[str] = [] # Python包依赖列表
config_file_name: str = "config.toml" # 配置文件名
# 配置节描述
config_section_descriptions = {"plugin": "插件基本信息", "greeting": "问候功能配置", "time": "时间查询配置"}
# 配置Schema定义
config_schema: dict = {
"plugin": {
"config_version": ConfigField(type=str, default="1.0.0", description="配置文件版本"),
"enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
},
"greeting": {
"message": ConfigField(
type=list, default=["嗨!很开心见到你!😊", "Ciallo(∠・ω< )⌒★"], description="默认问候消息"
),
"enable_emoji": ConfigField(type=bool, default=True, description="是否启用表情符号"),
},
"time": {"format": ConfigField(type=str, default="%Y-%m-%d %H:%M:%S", description="时间显示格式")},
"print_message": {"enabled": ConfigField(type=bool, default=True, description="是否启用打印")},
}
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
return [
(HelloAction.get_action_info(), HelloAction),
(CompareNumbersTool.get_tool_info(), CompareNumbersTool), # 添加比较数字工具
(ByeAction.get_action_info(), ByeAction), # 添加告别Action
(TimeCommand.get_command_info(), TimeCommand),
(PrintMessage.get_handler_info(), PrintMessage),
(ForwardMessages.get_handler_info(), ForwardMessages),
(RandomEmojis.get_command_info(), RandomEmojis),
(TestCommand.get_command_info(), TestCommand),
# 用转发消息发送多张图片
messages = [
{"user_id": "0", "nickname": "神秘用户", "segments": [{"type": "image", "content": e.get("base64", "")}]}
for e in emojis
]
await self.ctx.send.forward(messages, stream_id)
return True, "已发送随机表情包", True
@Command("test", description="测试命令", pattern=r"^/test$")
async def handle_test(self, stream_id: str = "", **kwargs):
"""测试命令 — 发送简单测试消息"""
await self.ctx.send.text("测试正常Bot 功能运行中 ✅", stream_id)
return True, "测试完成", True
# ===== EventHandler 组件 =====
@EventHandler("print_message_handler", description="打印接收到的消息", event_type=EventType.ON_MESSAGE)
async def handle_print_message(self, message=None, **kwargs):
"""打印消息事件"""
config_result = await self.ctx.config.get("print_message.enabled")
enabled = config_result if isinstance(config_result, bool) else False
if enabled and message:
raw = message.get("raw_message", "") if isinstance(message, dict) else str(message)
print(f"接收到消息: {raw}")
return True, True, "消息已打印", None, None
@EventHandler("forward_messages_handler", description="把接收到的消息转发到指定聊天ID", event_type=EventType.ON_MESSAGE)
async def handle_forward_messages(self, message=None, stream_id: str = "", **kwargs):
"""收集消息并定期转发"""
if not message:
return True, True, None, None, None
plain_text = message.get("plain_text", "") if isinstance(message, dict) else ""
if not plain_text:
return True, True, None, None, None
# 使用插件级状态收集消息
if not hasattr(self, "_fwd_messages"):
self._fwd_messages: list[str] = []
self._fwd_counter: int = 0
self._fwd_messages.append(plain_text)
self._fwd_counter += 1
if self._fwd_counter % 10 == 0 and stream_id:
if random.random() < 0.01:
segments = [{"type": "text", "content": msg} for msg in self._fwd_messages]
await self.ctx.send.hybrid(segments, stream_id)
else:
messages = [
{"user_id": "0", "nickname": "转发", "segments": [{"type": "text", "content": msg}]}
for msg in self._fwd_messages
]
await self.ctx.send.forward(messages, stream_id)
self._fwd_messages = []
return True, True, None, None, None
# @register_plugin
# class HelloWorldEventPlugin(BaseEPlugin):
# """Hello World事件插件 - 处理问候和告别事件"""
# plugin_name = "hello_world_event_plugin"
# enable_plugin = False
# dependencies = []
# python_dependencies = []
# config_file_name = "event_config.toml"
# config_schema = {
# "plugin": {
# "name": ConfigField(type=str, default="hello_world_event_plugin", description="插件名称"),
# "version": ConfigField(type=str, default="1.0.0", description="插件版本"),
# "enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
# },
# }
# def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
# return [(PrintMessage.get_handler_info(), PrintMessage)]
def create_plugin():
return HelloWorldPlugin()

View File

@@ -37,6 +37,7 @@ dependencies = [
"uvicorn>=0.35.0",
"msgpack>=1.1.2",
"watchfiles>=1.1.1",
"maibot-plugin-sdk>=1.0.0",
]

View File

@@ -6,6 +6,7 @@ fastapi>=0.116.0
google-genai>=1.39.1
jieba>=0.42.1
json-repair>=0.47.6
maibot-plugin-sdk>=1.0.0
maim-message>=0.6.2
matplotlib>=3.10.3
numpy>=2.2.6

View File

@@ -19,9 +19,10 @@ from src.chat.heart_flow.hfc_utils import CycleDetail
from src.bw_learner.expression_learner import expression_learner_manager
from src.bw_learner.message_recorder import extract_and_distribute_messages
from src.person_info.person_info import Person
from src.plugin_system.base.component_types import EventType, ActionInfo
from src.plugin_system.core import events_manager
from src.plugin_system.apis import generator_api, send_api, message_api, database_api
from src.core.types import ActionInfo, EventType
from src.core.event_bus import event_bus
from src.chat.event_helpers import build_event_message
from src.services import generator_service as generator_api, send_service as send_api, message_service as message_api, database_service as database_api
from src.chat.utils.chat_message_builder import (
build_readable_messages_with_id,
get_raw_msg_before_timestamp_with_chat,
@@ -315,8 +316,9 @@ class BrainChatting:
message_id_list=message_id_list,
prompt_key="brain_planner",
)
continue_flag, modified_message = await events_manager.handle_mai_events(
EventType.ON_PLAN, None, prompt_info[0], None, self.chat_stream.stream_id
_event_msg = build_event_message(EventType.ON_PLAN, llm_prompt=prompt_info[0], stream_id=self.chat_stream.stream_id)
continue_flag, modified_message = await event_bus.emit(
EventType.ON_PLAN, _event_msg
)
if not continue_flag:
return False

View File

@@ -23,8 +23,8 @@ from src.chat.utils.chat_message_builder import (
from src.chat.utils.utils import get_chat_type_and_target_info
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.plugin_system.base.component_types import ActionInfo, ComponentType, ActionActivationType
from src.plugin_system.core.component_registry import component_registry
from src.core.types import ActionActivationType, ActionInfo, ComponentType
from src.core.component_registry import component_registry
if TYPE_CHECKING:
from src.common.data_models.info_data_model import TargetPersonInfo

169
src/chat/event_helpers.py Normal file
View File

@@ -0,0 +1,169 @@
"""
事件消息构建工具
将 chat 层的消息对象 (SessionMessage / MessageSending) 转换为
核心事件系统使用的 MaiMessages供调用 event_bus.emit() 前使用。
"""
from typing import List, Optional, TYPE_CHECKING
from src.common.logger import get_logger
from src.core.types import EventType, MaiMessages
if TYPE_CHECKING:
from src.chat.message_receive.message import MessageSending, SessionMessage
from src.common.data_models.llm_data_model import LLMGenerationDataModel
logger = get_logger("event_helpers")
def build_event_message(
event_type: EventType | str,
message: Optional["SessionMessage | MessageSending | MaiMessages"] = None,
llm_prompt: Optional[str] = None,
llm_response: Optional["LLMGenerationDataModel"] = None,
stream_id: Optional[str] = None,
action_usage: Optional[List[str]] = None,
) -> Optional[MaiMessages]:
"""根据事件类型和输入,准备和转换消息对象。
迁移自 events_manager._prepare_message保持相同的行为。
"""
if isinstance(message, MaiMessages):
return message.deepcopy()
if message:
return _transform_event_message(message, llm_prompt, llm_response)
if event_type not in (EventType.ON_START, EventType.ON_STOP):
assert stream_id, "如果没有消息,必须为非启动/关闭事件提供流ID"
if event_type in (EventType.ON_MESSAGE, EventType.ON_PLAN, EventType.POST_LLM, EventType.AFTER_LLM):
return _build_message_from_stream(stream_id, llm_prompt, llm_response)
else:
return _build_message_without_raw(stream_id, llm_prompt, llm_response, action_usage)
return None # ON_START / ON_STOP 没有消息体
def _transform_event_message(
message: "SessionMessage | MessageSending",
llm_prompt: Optional[str] = None,
llm_response: Optional["LLMGenerationDataModel"] = None,
) -> MaiMessages:
"""将 SessionMessage / MessageSending 转换为 MaiMessages。"""
from maim_message import Seg
from src.chat.message_receive.message import MessageSending
transformed = MaiMessages(
llm_prompt=llm_prompt,
llm_response_content=llm_response.content if llm_response else None,
llm_response_reasoning=llm_response.reasoning if llm_response else None,
llm_response_model=llm_response.model if llm_response else None,
llm_response_tool_call=llm_response.tool_calls if llm_response else None,
raw_message=message.processed_plain_text or "",
additional_data={},
)
# 消息段处理
if isinstance(message, MessageSending):
if message.message_segment.type == "seglist":
transformed.message_segments = list(message.message_segment.data) # type: ignore
else:
transformed.message_segments = [message.message_segment]
else:
transformed.message_segments = [Seg(type="text", data=message.processed_plain_text or "")]
# stream_id
transformed.stream_id = message.session_id if hasattr(message, "session_id") else ""
# 处理后文本
transformed.plain_text = message.processed_plain_text
# 基本信息
if isinstance(message, MessageSending):
transformed.message_base_info["platform"] = message.platform
if message.session.group_id:
transformed.is_group_message = True
group_name = ""
if (
message.session.context
and message.session.context.message
and message.session.context.message.message_info.group_info
):
group_name = message.session.context.message.message_info.group_info.group_name
transformed.message_base_info.update(
{
"group_id": message.session.group_id,
"group_name": group_name,
}
)
transformed.message_base_info.update(
{
"user_id": message.bot_user_info.user_id,
"user_cardname": message.bot_user_info.user_cardname,
"user_nickname": message.bot_user_info.user_nickname,
}
)
if not transformed.is_group_message:
transformed.is_private_message = True
elif hasattr(message, "message_info") and message.message_info:
if message.platform:
transformed.message_base_info["platform"] = message.platform
if message.message_info.group_info:
transformed.is_group_message = True
transformed.message_base_info.update(
{
"group_id": message.message_info.group_info.group_id,
"group_name": message.message_info.group_info.group_name,
}
)
if message.message_info.user_info:
if not transformed.is_group_message:
transformed.is_private_message = True
transformed.message_base_info.update(
{
"user_id": message.message_info.user_info.user_id,
"user_cardname": message.message_info.user_info.user_cardname,
"user_nickname": message.message_info.user_info.user_nickname,
}
)
return transformed
def _build_message_from_stream(
stream_id: str,
llm_prompt: Optional[str] = None,
llm_response: Optional["LLMGenerationDataModel"] = None,
) -> MaiMessages:
"""从 stream_id 查找会话消息并转换。"""
from src.chat.message_receive.chat_manager import chat_manager
session = chat_manager.get_session_by_session_id(stream_id)
assert session, f"未找到流ID为 {stream_id} 的会话"
return _transform_event_message(session.context.message, llm_prompt, llm_response)
def _build_message_without_raw(
stream_id: str,
llm_prompt: Optional[str] = None,
llm_response: Optional["LLMGenerationDataModel"] = None,
action_usage: Optional[List[str]] = None,
) -> MaiMessages:
"""没有原始消息对象时,从 stream_id 构建最小 MaiMessages。"""
from src.chat.message_receive.chat_manager import chat_manager
session = chat_manager.get_session_by_session_id(stream_id)
assert session, f"未找到流ID为 {stream_id} 的会话"
return MaiMessages(
stream_id=stream_id,
llm_prompt=llm_prompt,
llm_response_content=llm_response.content if llm_response else None,
llm_response_reasoning=llm_response.reasoning if llm_response else None,
llm_response_model=llm_response.model if llm_response else None,
llm_response_tool_call=llm_response.tool_calls if llm_response else None,
is_group_message=session.is_group_session,
is_private_message=not session.is_group_session,
action_usage=action_usage,
additional_data={"response_is_processed": True},
)

View File

@@ -6,7 +6,7 @@ import time
from src.config.config import global_config
from src.common.logger import get_logger
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.plugin_system.apis import send_api
from src.services import send_service as send_api
from src.common.message_repository import count_messages

View File

@@ -11,8 +11,9 @@ from src.common.utils.utils_session import SessionUtils
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
from src.chat.brain_chat.PFC.pfc_manager import PFCManager
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.plugin_system.core import component_registry, events_manager, global_announcement_manager
from src.plugin_system.base import BaseCommand, EventType
from src.core.announcement_manager import global_announcement_manager
from src.core.component_registry import component_registry
from src.core.types import EventType
from .message import SessionMessage
from .chat_manager import chat_manager
@@ -65,10 +66,10 @@ class ChatBot:
try:
text = message.processed_plain_text
# 使用新的组件注册中心查找命令
# 使用核心组件注册查找命令
command_result = component_registry.find_command_by_text(text)
if command_result:
command_class, matched_groups, command_info = command_result
command_executor, matched_groups, command_info = command_result
plugin_name = command_info.plugin_name
command_name = command_info.name
if message.session_id and command_name in global_announcement_manager.get_disabled_chat_commands(
@@ -82,20 +83,20 @@ class ChatBot:
# 获取插件配置
plugin_config = component_registry.get_plugin_config(plugin_name)
# 创建命令实例
command_instance: BaseCommand = command_class(message, plugin_config)
command_instance.set_matched_groups(matched_groups)
try:
# 执行命令
success, response, intercept_message_level = await command_instance.execute()
# 调用命令执行器
success, response, intercept_message_level = await command_executor(
message=message,
plugin_config=plugin_config,
matched_groups=matched_groups,
)
message.intercept_message_level = intercept_message_level
# 记录命令执行结果
if success:
logger.info(f"命令执行成功: {command_class.__name__} (拦截等级: {intercept_message_level})")
logger.info(f"命令执行成功: {command_name} (拦截等级: {intercept_message_level})")
else:
logger.warning(f"命令执行失败: {command_class.__name__} - {response}")
logger.warning(f"命令执行失败: {command_name} - {response}")
# 根据命令的拦截设置决定是否继续处理消息
return (
@@ -105,14 +106,9 @@ class ChatBot:
) # 找到命令根据intercept_message决定是否继续
except Exception as e:
logger.error(f"执行命令时出错: {command_class.__name__} - {e}")
logger.error(f"执行命令时出错: {command_name} - {e}")
logger.error(traceback.format_exc())
try:
await command_instance.send_text(f"命令执行出错: {str(e)}")
except Exception as send_error:
logger.error(f"发送错误消息失败: {send_error}")
# 命令出错时,根据命令的拦截设置决定是否继续处理消息
return True, str(e), False # 出错时继续处理消息

View File

@@ -318,11 +318,13 @@ class UniversalMessageSender:
message.build_reply()
logger.debug(f"[{chat_id}] 选择回复引用消息: {message.processed_plain_text[:20]}...")
from src.plugin_system.core.events_manager import events_manager
from src.plugin_system.base.component_types import EventType
from src.core.event_bus import event_bus
from src.chat.event_helpers import build_event_message
from src.core.types import EventType
continue_flag, modified_message = await events_manager.handle_mai_events(
EventType.POST_SEND_PRE_PROCESS, message=message, stream_id=chat_id
_event_msg = build_event_message(EventType.POST_SEND_PRE_PROCESS, message=message, stream_id=chat_id)
continue_flag, modified_message = await event_bus.emit(
EventType.POST_SEND_PRE_PROCESS, _event_msg
)
if not continue_flag:
logger.info(f"[{chat_id}] 消息发送被插件取消: {str(message.message_segment)[:100]}...")
@@ -336,8 +338,9 @@ class UniversalMessageSender:
await message.process()
continue_flag, modified_message = await events_manager.handle_mai_events(
EventType.POST_SEND, message=message, stream_id=chat_id
_event_msg = build_event_message(EventType.POST_SEND, message=message, stream_id=chat_id)
continue_flag, modified_message = await event_bus.emit(
EventType.POST_SEND, _event_msg
)
if not continue_flag:
logger.info(f"[{chat_id}] 消息发送被插件取消: {str(message.message_segment)[:100]}...")
@@ -360,8 +363,9 @@ class UniversalMessageSender:
if not sent_msg:
return False
continue_flag, modified_message = await events_manager.handle_mai_events(
EventType.AFTER_SEND, message=message, stream_id=chat_id
_event_msg = build_event_message(EventType.AFTER_SEND, message=message, stream_id=chat_id)
continue_flag, modified_message = await event_bus.emit(
EventType.AFTER_SEND, _event_msg
)
if not continue_flag:
logger.info(f"[{chat_id}] 消息发送后续处理被插件取消: {str(message.message_segment)[:100]}...")

View File

@@ -1,20 +1,34 @@
from typing import Dict, Optional, Type
from typing import Dict, Optional, Tuple
from src.chat.message_receive.chat_manager import BotChatSession
from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.base.component_types import ComponentType, ActionInfo
from src.plugin_system.base.base_action import BaseAction
from src.core.component_registry import component_registry, ActionExecutor
from src.core.types import ActionInfo, ComponentType
logger = get_logger("action_manager")
class ActionHandle:
"""Action 执行句柄
不依赖任何插件基类,内部持有 executor (async callable) 和绑定参数。
brain_chat 调用 ``await handle.execute()`` 即可。
"""
def __init__(self, executor: ActionExecutor, **kwargs):
self._executor = executor
self._kwargs = kwargs
async def execute(self) -> Tuple[bool, str]:
return await self._executor(**self._kwargs)
class ActionManager:
"""
动作管理器,用于管理各种类型的动作
现在统一使用新插件系统,简化了原有的新旧兼容逻辑
使用核心组件注册表的 executor-based 模式
"""
def __init__(self):
@@ -39,9 +53,9 @@ class ActionManager:
log_prefix: str,
shutting_down: bool = False,
action_message: Optional[DatabaseMessages] = None,
) -> Optional[BaseAction]:
) -> Optional[ActionHandle]:
"""
创建动作处理器实例
创建动作执行句柄
Args:
action_name: 动作名称
@@ -52,30 +66,26 @@ class ActionManager:
chat_stream: 聊天流
log_prefix: 日志前缀
shutting_down: 是否正在关闭
action_message: 动作消息记录
Returns:
Optional[BaseAction]: 创建的动作处理器实例,如果动作名称未注册则返回None
Optional[ActionHandle]: 执行句柄,如果动作未注册则返回 None
"""
try:
# 获取组件类 - 明确指定查询Action类型
component_class: Type[BaseAction] = component_registry.get_component_class(
action_name, ComponentType.ACTION
) # type: ignore
if not component_class:
executor = component_registry.get_action_executor(action_name)
if not executor:
logger.warning(f"{log_prefix} 未找到Action组件: {action_name}")
return None
# 获取组件信息
component_info = component_registry.get_component_info(action_name, ComponentType.ACTION)
if not component_info:
info = component_registry.get_action_info(action_name)
if not info:
logger.warning(f"{log_prefix} 未找到Action组件信息: {action_name}")
return None
# 获取插件配置
plugin_config = component_registry.get_plugin_config(component_info.plugin_name)
plugin_config = component_registry.get_plugin_config(info.plugin_name) or {}
# 创建动作实例
instance = component_class(
handle = ActionHandle(
executor,
action_data=action_data,
action_reasoning=action_reasoning,
cycle_timers=cycle_timers,
@@ -87,11 +97,11 @@ class ActionManager:
action_message=action_message,
)
logger.debug(f"创建Action实例成功: {action_name}")
return instance
logger.debug(f"创建Action执行句柄成功: {action_name}")
return handle
except Exception as e:
logger.error(f"创建Action实例失败 {action_name}: {e}")
logger.error(f"创建Action执行句柄失败 {action_name}: {e}")
import traceback
logger.error(traceback.format_exc())

View File

@@ -7,8 +7,8 @@ from src.config.config import global_config
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
from src.plugin_system.base.component_types import ActionInfo, ActionActivationType
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.core.types import ActionActivationType, ActionInfo
from src.core.announcement_manager import global_announcement_manager
logger = get_logger("action_manager")

View File

@@ -23,9 +23,9 @@ from src.chat.utils.chat_message_builder import (
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.plugin_system.base.component_types import ActionInfo, ComponentType, ActionActivationType
from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.apis.message_api import translate_pid_to_description
from src.core.types import ActionActivationType, ActionInfo, ComponentType
from src.core.component_registry import component_registry
from src.services.message_service import translate_pid_to_description
from src.person_info.person_info import Person
if TYPE_CHECKING:

View File

@@ -27,12 +27,12 @@ from src.chat.utils.chat_message_builder import (
replace_user_references,
)
from src.bw_learner.expression_selector import expression_selector
from src.plugin_system.apis.message_api import translate_pid_to_description
from src.services.message_service import translate_pid_to_description
# from src.memory_system.memory_activator import MemoryActivator
from src.person_info.person_info import Person
from src.plugin_system.base.component_types import ActionInfo, EventType
from src.plugin_system.apis import llm_api
from src.core.types import ActionInfo, EventType
from src.services import llm_service as llm_api
from src.chat.logger.plan_reply_logger import PlanReplyLogger
from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt
@@ -56,7 +56,7 @@ class DefaultReplyer:
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id)
self.heart_fc_sender = UniversalMessageSender()
from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor不然会循环依赖
from src.chat.tool_executor import ToolExecutor
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.session_id, enable_cache=True, cache_ttl=3)
@@ -149,11 +149,13 @@ class DefaultReplyer:
except Exception:
logger.exception("记录reply日志失败")
return False, llm_response
from src.plugin_system.core.events_manager import events_manager
from src.core.event_bus import event_bus
from src.chat.event_helpers import build_event_message
if not from_plugin:
continue_flag, modified_message = await events_manager.handle_mai_events(
EventType.POST_LLM, None, prompt, None, stream_id=stream_id
_event_msg = build_event_message(EventType.POST_LLM, llm_prompt=prompt, stream_id=stream_id)
continue_flag, modified_message = await event_bus.emit(
EventType.POST_LLM, _event_msg
)
if not continue_flag:
raise UserWarning("插件于请求前中断了内容生成")
@@ -217,8 +219,9 @@ class DefaultReplyer:
)
except Exception:
logger.exception("记录reply日志失败")
continue_flag, modified_message = await events_manager.handle_mai_events(
EventType.AFTER_LLM, None, prompt, llm_response, stream_id=stream_id
_event_msg = build_event_message(EventType.AFTER_LLM, llm_prompt=prompt, llm_response=llm_response, stream_id=stream_id)
continue_flag, modified_message = await event_bus.emit(
EventType.AFTER_LLM, _event_msg
)
if not from_plugin and not continue_flag:
raise UserWarning("插件于请求后取消了内容生成")

View File

@@ -28,12 +28,12 @@ from src.chat.utils.chat_message_builder import (
replace_user_references,
)
from src.bw_learner.expression_selector import expression_selector
from src.plugin_system.apis.message_api import translate_pid_to_description
from src.services.message_service import translate_pid_to_description
# from src.memory_system.memory_activator import MemoryActivator
from src.person_info.person_info import Person, is_person_known
from src.plugin_system.base.component_types import ActionInfo, EventType
from src.plugin_system.apis import llm_api
from src.core.types import ActionInfo, EventType
from src.services import llm_service as llm_api
from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt
from src.bw_learner.jargon_explainer import explain_jargon_in_context
@@ -55,7 +55,7 @@ class PrivateReplyer:
self.heart_fc_sender = UniversalMessageSender()
# self.memory_activator = MemoryActivator()
from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor不然会循环依赖
from src.chat.tool_executor import ToolExecutor
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.session_id, enable_cache=True, cache_ttl=3)
@@ -114,11 +114,13 @@ class PrivateReplyer:
if not prompt:
logger.warning("构建prompt失败跳过回复生成")
return False, llm_response
from src.plugin_system.core.events_manager import events_manager
from src.core.event_bus import event_bus
from src.chat.event_helpers import build_event_message
if not from_plugin:
continue_flag, modified_message = await events_manager.handle_mai_events(
EventType.POST_LLM, None, prompt, None, stream_id=stream_id
_event_msg = build_event_message(EventType.POST_LLM, llm_prompt=prompt, stream_id=stream_id)
continue_flag, modified_message = await event_bus.emit(
EventType.POST_LLM, _event_msg
)
if not continue_flag:
raise UserWarning("插件于请求前中断了内容生成")
@@ -138,8 +140,9 @@ class PrivateReplyer:
llm_response.reasoning = reasoning_content
llm_response.model = model_name
llm_response.tool_calls = tool_call
continue_flag, modified_message = await events_manager.handle_mai_events(
EventType.AFTER_LLM, None, prompt, llm_response, stream_id=stream_id
_event_msg = build_event_message(EventType.AFTER_LLM, llm_prompt=prompt, llm_response=llm_response, stream_id=stream_id)
continue_flag, modified_message = await event_bus.emit(
EventType.AFTER_LLM, _event_msg
)
if not from_plugin and not continue_flag:
raise UserWarning("插件于请求后取消了内容生成")

View File

@@ -1,14 +1,23 @@
"""
工具执行器
独立的工具执行组件可以直接输入聊天消息内容
自动判断并执行相应的工具返回结构化的工具执行结果
src.plugin_system.core.tool_use 迁移使用新的核心组件注册表
"""
import hashlib
import time
from typing import List, Dict, Tuple, Optional, Any
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.llm_models.utils_model import LLMRequest
from src.llm_models.payload_content import ToolCall
from src.config.config import global_config, model_config
from src.prompt.prompt_manager import prompt_manager
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from typing import Any, Dict, List, Optional, Tuple
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.core.announcement_manager import global_announcement_manager
from src.core.component_registry import component_registry
from src.llm_models.payload_content import ToolCall
from src.llm_models.utils_model import LLMRequest
from src.prompt.prompt_manager import prompt_manager
logger = get_logger("tool_use")
@@ -20,66 +29,38 @@ class ToolExecutor:
"""
def __init__(self, chat_id: str, enable_cache: bool = True, cache_ttl: int = 3):
"""初始化工具执行器
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
Args:
executor_id: 执行器标识符用于日志记录
enable_cache: 是否启用缓存机制
cache_ttl: 缓存生存时间周期数
"""
self.chat_id = chat_id
self.chat_stream = _chat_manager.get_session_by_session_id(self.chat_id)
self.log_prefix = f"[{_chat_manager.get_session_name(self.chat_id) or self.chat_id}]"
self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor")
# 缓存配置
self.enable_cache = enable_cache
self.cache_ttl = cache_ttl
self.tool_cache = {} # 格式: {cache_key: {"result": result, "ttl": ttl, "timestamp": timestamp}}
self.tool_cache: Dict[str, dict] = {}
logger.info(f"{self.log_prefix}工具执行器初始化完成,缓存{'启用' if enable_cache else '禁用'}TTL={cache_ttl}")
async def execute_from_chat_message(
self, target_message: str, chat_history: str, sender: str, return_details: bool = False
) -> Tuple[List[Dict[str, Any]], List[str], str]:
"""从聊天消息执行工具
"""从聊天消息执行工具"""
Args:
target_message: 目标消息内容
chat_history: 聊天历史
sender: 发送者
return_details: 是否返回详细信息(使用的工具列表和提示词)
Returns:
如果return_details为False: Tuple[List[Dict], List[str], str] - (工具执行结果列表, , )
如果return_details为True: Tuple[List[Dict], List[str], str] - (结果列表, 使用的工具, 提示词)
"""
# 首先检查缓存
cache_key = self._generate_cache_key(target_message, chat_history, sender)
if cached_result := self._get_from_cache(cache_key):
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具执行")
if not return_details:
return cached_result, [], ""
# 从缓存结果中提取工具名称
used_tools = [result.get("tool_name", "unknown") for result in cached_result]
return cached_result, used_tools, ""
# 缓存未命中,执行工具调用
# 获取可用工具
tools = self._get_tool_definitions()
# 如果没有可用工具,直接返回空内容
if not tools:
logger.debug(f"{self.log_prefix}没有可用工具,直接返回空内容")
if return_details:
return [], [], ""
else:
return [], [], ""
# 构建工具调用提示词
prompt_template = prompt_manager.get_prompt("tool_executor")
prompt_template.add_context("target_message", target_message)
prompt_template.add_context("chat_history", chat_history)
@@ -90,15 +71,12 @@ class ToolExecutor:
logger.debug(f"{self.log_prefix}开始LLM工具调用分析")
# 调用LLM进行工具决策
response, (reasoning_content, model_name, tool_calls) = await self.llm_model.generate_response_async(
prompt=prompt, tools=tools, raise_when_empty=False
)
# 执行工具调用
tool_results, used_tools = await self.execute_tool_calls(tool_calls)
# 缓存结果
if tool_results:
self._set_cache(cache_key, tool_results)
@@ -107,42 +85,30 @@ class ToolExecutor:
if return_details:
return tool_results, used_tools, prompt
else:
return tool_results, [], ""
def _get_tool_definitions(self) -> List[Dict[str, Any]]:
all_tools = get_llm_available_tool_definitions()
"""获取 LLM 可用的工具定义列表"""
all_tools = component_registry.get_llm_available_tools()
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
return [definition for name, definition in all_tools if name not in user_disabled_tools]
return [info.get_llm_definition() for name, info in all_tools.items() if name not in user_disabled_tools]
async def execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]:
"""执行工具调用
Args:
tool_calls: LLM返回的工具调用列表
Returns:
Tuple[List[Dict], List[str]]: (工具执行结果列表, 使用的工具名称列表)
"""
"""执行工具调用列表"""
tool_results: List[Dict[str, Any]] = []
used_tools = []
used_tools: List[str] = []
if not tool_calls:
logger.debug(f"{self.log_prefix}无需执行工具")
return [], []
# 提取tool_calls中的函数名称
func_names = [call.func_name for call in tool_calls if call.func_name]
logger.info(f"{self.log_prefix}开始执行工具调用: {func_names}")
# 执行每个工具调用
for tool_call in tool_calls:
try:
tool_name = tool_call.func_name
try:
logger.debug(f"{self.log_prefix}执行工具: {tool_name}")
# 执行工具
result = await self.execute_tool_call(tool_call)
if result:
@@ -156,7 +122,6 @@ class ToolExecutor:
content = tool_info["content"]
if not isinstance(content, (str, list, tuple)):
tool_info["content"] = str(content)
# 空内容直接跳过(空字符串、全空白字符串、空列表/空元组)
content_check = tool_info["content"]
if (isinstance(content_check, str) and not content_check.strip()) or (
isinstance(content_check, (list, tuple)) and len(content_check) == 0
@@ -166,11 +131,10 @@ class ToolExecutor:
tool_results.append(tool_info)
used_tools.append(tool_name)
preview = content[:200]
preview = str(content)[:200]
logger.debug(f"{self.log_prefix}工具{tool_name}结果内容: {preview}...")
except Exception as e:
logger.error(f"{self.log_prefix}工具{tool_name}执行失败: {e}")
# 添加错误信息到结果中
error_info = {
"type": "tool_error",
"id": f"tool_error_{time.time()}",
@@ -182,31 +146,18 @@ class ToolExecutor:
return tool_results, used_tools
async def execute_tool_call(
self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None
) -> Optional[Dict[str, Any]]:
# sourcery skip: use-assigned-variable
"""执行单个工具调用
Args:
tool_call: 工具调用对象
Returns:
Optional[Dict]: 工具调用结果如果失败则返回None
"""
try:
async def execute_tool_call(self, tool_call: ToolCall) -> Optional[Dict[str, Any]]:
"""执行单个工具调用"""
function_name = tool_call.func_name
function_args = tool_call.args or {}
function_args["llm_called"] = True # 标记为LLM调用
function_args["llm_called"] = True
# 获取对应工具实例
tool_instance = tool_instance or get_tool_instance(function_name, self.chat_stream)
if not tool_instance:
executor = component_registry.get_tool_executor(function_name)
if not executor:
logger.warning(f"未知工具名称: {function_name}")
return None
# 执行工具
result = await tool_instance.execute(function_args)
result = await executor(function_args)
if result:
return {
"tool_call_id": tool_call.call_id,
@@ -216,88 +167,9 @@ class ToolExecutor:
"content": result["content"],
}
return None
except Exception as e:
logger.error(f"执行工具调用时发生错误: {str(e)}")
raise e
def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str:
"""生成缓存键
Args:
target_message: 目标消息内容
chat_history: 聊天历史
sender: 发送者
Returns:
str: 缓存键
"""
import hashlib
# 使用消息内容和群聊状态生成唯一缓存键
content = f"{target_message}_{chat_history}_{sender}"
return hashlib.md5(content.encode()).hexdigest()
def _get_from_cache(self, cache_key: str) -> Optional[List[Dict]]:
"""从缓存获取结果
Args:
cache_key: 缓存键
Returns:
Optional[List[Dict]]: 缓存的结果如果不存在或过期则返回None
"""
if not self.enable_cache or cache_key not in self.tool_cache:
return None
cache_item = self.tool_cache[cache_key]
if cache_item["ttl"] <= 0:
# 缓存过期,删除
del self.tool_cache[cache_key]
logger.debug(f"{self.log_prefix}缓存过期,删除缓存键: {cache_key}")
return None
# 减少TTL
cache_item["ttl"] -= 1
logger.debug(f"{self.log_prefix}使用缓存结果剩余TTL: {cache_item['ttl']}")
return cache_item["result"]
def _set_cache(self, cache_key: str, result: List[Dict]):
"""设置缓存
Args:
cache_key: 缓存键
result: 要缓存的结果
"""
if not self.enable_cache:
return
self.tool_cache[cache_key] = {"result": result, "ttl": self.cache_ttl, "timestamp": time.time()}
logger.debug(f"{self.log_prefix}设置缓存TTL: {self.cache_ttl}")
def _cleanup_expired_cache(self):
"""清理过期的缓存"""
if not self.enable_cache:
return
expired_keys = []
expired_keys.extend(cache_key for cache_key, cache_item in self.tool_cache.items() if cache_item["ttl"] <= 0)
for key in expired_keys:
del self.tool_cache[key]
if expired_keys:
logger.debug(f"{self.log_prefix}清理了{len(expired_keys)}个过期缓存")
async def execute_specific_tool_simple(self, tool_name: str, tool_args: Dict) -> Optional[Dict]:
"""直接执行指定工具
Args:
tool_name: 工具名称
tool_args: 工具参数
validate_args: 是否验证参数
Returns:
Optional[Dict]: 工具执行结果失败时返回None
"""
"""直接执行指定工具"""
try:
tool_call = ToolCall(
call_id=f"direct_tool_{time.time()}",
@@ -306,7 +178,6 @@ class ToolExecutor:
)
logger.info(f"{self.log_prefix}直接执行工具: {tool_name}")
result = await self.execute_tool_call(tool_call)
if result:
@@ -325,86 +196,55 @@ class ToolExecutor:
return None
# === 缓存方法 ===
def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str:
content = f"{target_message}_{chat_history}_{sender}"
return hashlib.md5(content.encode()).hexdigest()
def _get_from_cache(self, cache_key: str) -> Optional[List[Dict]]:
if not self.enable_cache or cache_key not in self.tool_cache:
return None
cache_item = self.tool_cache[cache_key]
if cache_item["ttl"] <= 0:
del self.tool_cache[cache_key]
return None
cache_item["ttl"] -= 1
return cache_item["result"]
def _set_cache(self, cache_key: str, result: List[Dict]):
if not self.enable_cache:
return
self.tool_cache[cache_key] = {"result": result, "ttl": self.cache_ttl, "timestamp": time.time()}
def _cleanup_expired_cache(self):
if not self.enable_cache:
return
expired = [k for k, v in self.tool_cache.items() if v["ttl"] <= 0]
for key in expired:
del self.tool_cache[key]
def clear_cache(self):
"""清空所有缓存"""
if self.enable_cache:
cache_count = len(self.tool_cache)
self.tool_cache.clear()
logger.info(f"{self.log_prefix}清空了{cache_count}个缓存项")
def get_cache_status(self) -> Dict:
"""获取缓存状态信息
Returns:
Dict: 包含缓存统计信息的字典
"""
if not self.enable_cache:
return {"enabled": False, "cache_count": 0}
# 清理过期缓存
self._cleanup_expired_cache()
total_count = len(self.tool_cache)
ttl_distribution = {}
for cache_item in self.tool_cache.values():
ttl = cache_item["ttl"]
ttl_distribution: Dict[int, int] = {}
for item in self.tool_cache.values():
ttl = item["ttl"]
ttl_distribution[ttl] = ttl_distribution.get(ttl, 0) + 1
return {
"enabled": True,
"cache_count": total_count,
"cache_count": len(self.tool_cache),
"cache_ttl": self.cache_ttl,
"ttl_distribution": ttl_distribution,
}
def set_cache_config(self, enable_cache: Optional[bool] = None, cache_ttl: int = -1):
"""动态修改缓存配置
Args:
enable_cache: 是否启用缓存
cache_ttl: 缓存TTL
"""
if enable_cache is not None:
self.enable_cache = enable_cache
logger.info(f"{self.log_prefix}缓存状态修改为: {'启用' if enable_cache else '禁用'}")
if cache_ttl > 0:
self.cache_ttl = cache_ttl
logger.info(f"{self.log_prefix}缓存TTL修改为: {cache_ttl}")
"""
ToolExecutor使用示例
# 1. 基础使用 - 从聊天消息执行工具启用缓存默认TTL=3
executor = ToolExecutor(executor_id="my_executor")
results, _, _ = await executor.execute_from_chat_message(
talking_message_str="今天天气怎么样?现在几点了?",
is_group_chat=False
)
# 2. 禁用缓存的执行器
no_cache_executor = ToolExecutor(executor_id="no_cache", enable_cache=False)
# 3. 自定义缓存TTL
long_cache_executor = ToolExecutor(executor_id="long_cache", cache_ttl=10)
# 4. 获取详细信息
results, used_tools, prompt = await executor.execute_from_chat_message(
talking_message_str="帮我查询Python相关知识",
is_group_chat=False,
return_details=True
)
# 5. 直接执行特定工具
result = await executor.execute_specific_tool_simple(
tool_name="get_knowledge",
tool_args={"query": "机器学习"}
)
# 6. 缓存管理
cache_status = executor.get_cache_status() # 查看缓存状态
executor.clear_cache() # 清空缓存
executor.set_cache_config(cache_ttl=5) # 动态修改缓存配置
"""

View File

@@ -4,7 +4,7 @@ from . import BaseDataModel
if TYPE_CHECKING:
from .database_data_model import DatabaseMessages
from src.plugin_system.base.component_types import ActionInfo
from src.core.types import ActionInfo
@dataclass

6
src/core/__init__.py Normal file
View File

@@ -0,0 +1,6 @@
"""
MaiBot 核心基础设施
提供与插件系统无关的核心类型定义、配置 schema 等基础设施。
这些类型被整个项目共享,包括内部模块、服务层、旧插件系统和新插件运行时。
"""

View File

@@ -0,0 +1,242 @@
"""
核心组件注册表
面向最终架构的组件管理:
- Action注册 ActionInfo + 执行器(本地 callable 或 IPC 路由)
- Command注册正则模式 + 执行器
- Tool注册工具定义 + 执行器
不依赖任何插件基类,组件执行器是纯 async callable。
"""
import re
from typing import Any, Awaitable, Callable, Dict, List, Optional, Pattern, Tuple, Union
from src.common.logger import get_logger
from src.core.types import (
ActionActivationType,
ActionInfo,
CommandInfo,
ComponentInfo,
ComponentType,
ToolInfo,
)
logger = get_logger("component_registry")
# 执行器类型
ActionExecutor = Callable[..., Awaitable[Any]]
CommandExecutor = Callable[..., Awaitable[Tuple[bool, Optional[str], bool]]]
ToolExecutor = Callable[..., Awaitable[Any]]
class ComponentRegistry:
"""核心组件注册表
管理 action、command、tool 三类组件。
每个组件由「元信息 + 执行器」构成,执行器是 async callable
不需要继承任何基类。
"""
def __init__(self):
# Action 注册
self._actions: Dict[str, ActionInfo] = {}
self._action_executors: Dict[str, ActionExecutor] = {}
self._default_actions: Dict[str, ActionInfo] = {}
# Command 注册
self._commands: Dict[str, CommandInfo] = {}
self._command_executors: Dict[str, CommandExecutor] = {}
self._command_patterns: Dict[Pattern, str] = {}
# Tool 注册
self._tools: Dict[str, ToolInfo] = {}
self._tool_executors: Dict[str, ToolExecutor] = {}
self._llm_available_tools: Dict[str, ToolInfo] = {}
# 插件配置plugin_name -> config dict
self._plugin_configs: Dict[str, dict] = {}
logger.info("核心组件注册表初始化完成")
# ========== Action ==========
def register_action(
self,
info: ActionInfo,
executor: ActionExecutor,
) -> bool:
"""注册 action
Args:
info: action 元信息
executor: 执行器async callable
"""
name = info.name
if name in self._actions:
logger.warning(f"Action {name} 已存在,跳过注册")
return False
self._actions[name] = info
self._action_executors[name] = executor
if info.enabled:
self._default_actions[name] = info
logger.debug(f"注册 Action: {name}")
return True
def get_action_info(self, name: str) -> Optional[ActionInfo]:
return self._actions.get(name)
def get_action_executor(self, name: str) -> Optional[ActionExecutor]:
return self._action_executors.get(name)
def get_default_actions(self) -> Dict[str, ActionInfo]:
return self._default_actions.copy()
def get_all_actions(self) -> Dict[str, ActionInfo]:
return self._actions.copy()
def remove_action(self, name: str) -> bool:
if name not in self._actions:
return False
del self._actions[name]
self._action_executors.pop(name, None)
self._default_actions.pop(name, None)
logger.debug(f"移除 Action: {name}")
return True
# ========== Command ==========
def register_command(
self,
info: CommandInfo,
executor: CommandExecutor,
) -> bool:
"""注册 command"""
name = info.name
if name in self._commands:
logger.warning(f"Command {name} 已存在,跳过注册")
return False
self._commands[name] = info
self._command_executors[name] = executor
if info.enabled and info.command_pattern:
pattern = re.compile(info.command_pattern, re.IGNORECASE | re.DOTALL)
self._command_patterns[pattern] = name
logger.debug(f"注册 Command: {name}")
return True
def find_command_by_text(
self, text: str
) -> Optional[Tuple[CommandExecutor, dict, CommandInfo]]:
"""根据文本查找匹配的命令
Returns:
(executor, matched_groups, command_info) 或 None
"""
candidates = [p for p in self._command_patterns if p.match(text)]
if not candidates:
return None
if len(candidates) > 1:
logger.warning(f"文本 '{text[:50]}' 匹配到多个命令模式,使用第一个")
pattern = candidates[0]
name = self._command_patterns[pattern]
return (
self._command_executors[name],
pattern.match(text).groupdict(), # type: ignore
self._commands[name],
)
def remove_command(self, name: str) -> bool:
if name not in self._commands:
return False
del self._commands[name]
self._command_executors.pop(name, None)
self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != name}
logger.debug(f"移除 Command: {name}")
return True
# ========== Tool ==========
def register_tool(
self,
info: ToolInfo,
executor: ToolExecutor,
) -> bool:
"""注册 tool"""
name = info.name
if name in self._tools:
logger.warning(f"Tool {name} 已存在,跳过注册")
return False
self._tools[name] = info
self._tool_executors[name] = executor
if info.enabled:
self._llm_available_tools[name] = info
logger.debug(f"注册 Tool: {name}")
return True
def get_tool_info(self, name: str) -> Optional[ToolInfo]:
return self._tools.get(name)
def get_tool_executor(self, name: str) -> Optional[ToolExecutor]:
return self._tool_executors.get(name)
def get_llm_available_tools(self) -> Dict[str, ToolInfo]:
return self._llm_available_tools.copy()
def get_all_tools(self) -> Dict[str, ToolInfo]:
return self._tools.copy()
def remove_tool(self, name: str) -> bool:
if name not in self._tools:
return False
del self._tools[name]
self._tool_executors.pop(name, None)
self._llm_available_tools.pop(name, None)
logger.debug(f"移除 Tool: {name}")
return True
# ========== 通用查询 ==========
def get_component_info(self, name: str, component_type: ComponentType) -> Optional[ComponentInfo]:
"""获取组件元信息"""
match component_type:
case ComponentType.ACTION:
return self._actions.get(name)
case ComponentType.COMMAND:
return self._commands.get(name)
case ComponentType.TOOL:
return self._tools.get(name)
case _:
return None
def get_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
"""获取某类型的所有组件"""
match component_type:
case ComponentType.ACTION:
return dict(self._actions)
case ComponentType.COMMAND:
return dict(self._commands)
case ComponentType.TOOL:
return dict(self._tools)
case _:
return {}
# ========== 插件配置 ==========
def set_plugin_config(self, plugin_name: str, config: dict) -> None:
self._plugin_configs[plugin_name] = config
def get_plugin_config(self, plugin_name: str) -> Optional[dict]:
return self._plugin_configs.get(plugin_name)
# 全局单例
component_registry = ComponentRegistry()

216
src/core/event_bus.py Normal file
View File

@@ -0,0 +1,216 @@
"""
核心事件总线
面向最终架构的事件系统:
- 内部 handler 直接注册 async callable
- IPC 插件通过 plugin_runtime 桥接
- 不依赖任何插件基类
"""
import asyncio
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
from src.common.logger import get_logger
from src.core.types import EventType, MaiMessages
if TYPE_CHECKING:
from src.common.data_models.llm_data_model import LLMGenerationDataModel
logger = get_logger("event_bus")
# Handler 签名:接收 MaiMessages返回 (continue, modified_message)
EventHandler = Callable[[Optional[MaiMessages]], Awaitable[Tuple[bool, Optional[MaiMessages]]]]
class EventBus:
"""核心事件总线
支持两种 handler
- 拦截型intercept=True同步顺序执行可修改消息、可中断流程
- 非拦截型intercept=False异步并发执行fire-and-forget
handler 是纯 async callable不需要继承任何基类。
"""
def __init__(self):
# event_type -> [(handler, name, weight, intercept)]
self._handlers: Dict[EventType | str, List[_HandlerEntry]] = {}
self._running_tasks: Dict[str, List[asyncio.Task]] = {}
# 预注册所有内置事件类型
for event in EventType:
self._handlers[event] = []
def subscribe(
self,
event_type: EventType | str,
handler: EventHandler,
name: str,
weight: int = 0,
intercept: bool = False,
) -> None:
"""注册事件 handler
Args:
event_type: 事件类型
handler: async callable签名 (Optional[MaiMessages]) -> (bool, Optional[MaiMessages])
name: handler 标识名
weight: 权重,越大越先执行
intercept: 是否为拦截型(同步执行,可中断流程)
"""
if event_type not in self._handlers:
self._handlers[event_type] = []
entry = _HandlerEntry(handler=handler, name=name, weight=weight, intercept=intercept)
self._handlers[event_type].append(entry)
self._handlers[event_type].sort(key=lambda e: e.weight, reverse=True)
logger.debug(f"注册事件 handler: {name} -> {event_type} (weight={weight}, intercept={intercept})")
def unsubscribe(self, event_type: EventType | str, name: str) -> bool:
"""取消注册事件 handler"""
handlers = self._handlers.get(event_type, [])
for i, entry in enumerate(handlers):
if entry.name == name:
del handlers[i]
logger.debug(f"取消注册事件 handler: {name} <- {event_type}")
return True
return False
async def emit(
self,
event_type: EventType | str,
message: Optional[MaiMessages] = None,
) -> Tuple[bool, Optional[MaiMessages]]:
"""触发事件
按权重顺序执行所有 handler
- 拦截型 handler 同步执行,可修改消息和中断流程
- 非拦截型 handler 异步 fire-and-forget
Args:
event_type: 事件类型
message: 事件消息(可选)
Returns:
(continue_flag, modified_message)
- continue_flag: False 表示某个拦截型 handler 要求中断
- modified_message: 被拦截型 handler 修改后的消息
"""
handlers = self._handlers.get(event_type, [])
if not handlers:
return True, None
continue_flag = True
current_message = message.deepcopy() if message else None
for entry in handlers:
if entry.intercept:
try:
should_continue, modified = await entry.handler(current_message)
if modified is not None:
current_message = modified
if not should_continue:
continue_flag = False
break
except Exception as e:
logger.error(f"拦截型 handler {entry.name} 执行异常: {e}", exc_info=True)
else:
self._fire_and_forget(entry, event_type, current_message)
# 桥接到 IPC 插件运行时
continue_flag, current_message = await self._bridge_to_ipc_runtime(
event_type, continue_flag, current_message
)
return continue_flag, current_message
async def cancel_handler_tasks(self, handler_name: str) -> None:
"""取消某个 handler 的所有运行中任务"""
tasks = self._running_tasks.pop(handler_name, [])
remaining = [t for t in tasks if not t.done()]
if remaining:
for t in remaining:
t.cancel()
await asyncio.gather(*remaining, return_exceptions=True)
logger.info(f"已取消 handler {handler_name}{len(remaining)} 个任务")
# --- 内部方法 ---
def _fire_and_forget(
self,
entry: "_HandlerEntry",
event_type: EventType | str,
message: Optional[MaiMessages],
) -> None:
"""创建异步任务执行非拦截型 handler"""
try:
task = asyncio.create_task(entry.handler(message))
task.set_name(entry.name)
task.add_done_callback(lambda t: self._task_done_callback(t, entry.name))
self._running_tasks.setdefault(entry.name, []).append(task)
except Exception as e:
logger.error(f"创建 handler 任务 {entry.name} 失败: {e}", exc_info=True)
def _task_done_callback(self, task: asyncio.Task, handler_name: str) -> None:
"""异步任务完成回调"""
try:
if task.cancelled():
return
exc = task.exception()
if exc:
logger.error(f"handler {handler_name} 异步任务异常: {exc}")
except Exception:
pass
finally:
task_list = self._running_tasks.get(handler_name, [])
try:
task_list.remove(task)
except ValueError:
pass
async def _bridge_to_ipc_runtime(
self,
event_type: EventType | str,
continue_flag: bool,
message: Optional[MaiMessages],
) -> Tuple[bool, Optional[MaiMessages]]:
"""将事件桥接到 IPC 插件运行时"""
if not continue_flag:
return continue_flag, message
try:
from src.plugin_runtime.integration import get_plugin_runtime_manager
prm = get_plugin_runtime_manager()
if not prm.is_running:
return continue_flag, message
event_value = event_type.value if isinstance(event_type, EventType) else str(event_type)
message_dict = message.to_dict() if message and hasattr(message, "to_dict") else None
new_continue, _ = await prm.bridge_event(
event_type_value=event_value,
message_dict=message_dict,
)
if not new_continue:
continue_flag = False
except Exception as e:
logger.warning(f"桥接事件到 IPC 运行时失败: {e}")
return continue_flag, message
class _HandlerEntry:
"""内部 handler 条目"""
__slots__ = ("handler", "name", "weight", "intercept")
def __init__(self, handler: EventHandler, name: str, weight: int, intercept: bool):
self.handler = handler
self.name = name
self.weight = weight
self.intercept = intercept
# 全局单例
event_bus = EventBus()

View File

@@ -169,6 +169,14 @@ class ToolInfo(ComponentInfo):
super().__post_init__()
self.component_type = ComponentType.TOOL
def get_llm_definition(self) -> dict:
"""生成 LLM function-calling 所需的工具定义"""
return {
"name": self.name,
"description": self.tool_description,
"parameters": self.tool_parameters,
}
@dataclass
class EventHandlerInfo(ComponentInfo):

View File

@@ -10,7 +10,7 @@ from src.config.config import global_config, model_config
from src.common.database.database_model import ChatHistory
from src.prompt.prompt_manager import prompt_manager
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
from src.plugin_system.apis import llm_api
from src.services import llm_service as llm_api
from src.dream.dream_generator import generate_dream_summary
# dream 工具工厂函数

View File

@@ -8,7 +8,7 @@ from src.llm_models.payload_content.message import RoleType, Message
from src.prompt.prompt_manager import prompt_manager
from src.llm_models.utils_model import LLMRequest
from src.common.utils.utils_session import SessionUtils
from src.plugin_system.apis import send_api
from src.services import send_service as send_api
logger = get_logger("dream_generator")

View File

@@ -2,7 +2,7 @@ from typing import Any, Dict, Optional
from src.common.logger import get_logger
from src.common.database.database_model import ChatHistory
from src.plugin_system.apis import database_api
from src.services import database_service as database_api
logger = get_logger("dream_agent")

View File

@@ -2,7 +2,7 @@ from typing import Any, Dict, Optional
from src.common.logger import get_logger
from src.common.database.database_model import Jargon
from src.plugin_system.apis import database_api
from src.services import database_service as database_api
logger = get_logger("dream_agent")

View File

@@ -18,10 +18,7 @@ from rich.traceback import install
# from src.api.main import start_api_server
# 导入新的插件管理器
from src.plugin_system.core.plugin_manager import plugin_manager
# 导入新版本插件运行时
# 导入插件运行时
from src.plugin_runtime.integration import get_plugin_runtime_manager
# 导入消息API和traceback模块
@@ -31,8 +28,6 @@ from src.bw_learner.expression_auto_check_task import ExpressionAutoCheckTask
from src.prompt.prompt_manager import prompt_manager
# 插件系统现在使用统一的插件加载器
install(extra_lines=3)
logger = get_logger("main")
@@ -108,10 +103,7 @@ class MainSystem:
# 启动LPMM
lpmm_start_up()
# 加载所有actions包括默认的和插件的
plugin_manager.load_all_plugins()
# 启动新版本插件运行时(与旧系统并行运行)
# 启动插件运行时(内置插件 + 第三方插件双子进程)
await get_plugin_runtime_manager().start()
# 初始化表情管理器
@@ -133,12 +125,12 @@ class MainSystem:
prompt_manager.load_prompts()
# 触发 ON_START 事件
from src.plugin_system.core.events_manager import events_manager
from src.plugin_system.base.component_types import EventType
from src.core.event_bus import event_bus
from src.core.types import EventType
await events_manager.handle_mai_events(event_type=EventType.ON_START)
await event_bus.emit(event_type=EventType.ON_START)
# 桥接 ON_START 事件到新版本插件运行时
# 分发 ON_START 事件到插件运行时
await get_plugin_runtime_manager().bridge_event("on_start")
# logger.info("已触发 ON_START 事件")
try:

View File

@@ -17,7 +17,7 @@ from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.config.config import model_config, global_config
from src.llm_models.utils_model import LLMRequest
from src.plugin_system.apis import message_api
from src.services import message_service as message_api
from src.chat.utils.chat_message_builder import build_readable_messages
from src.chat.utils.utils import is_bot_self
from src.person_info.person_info import Person
@@ -913,7 +913,7 @@ class ChatHistorySummarizer:
"""存储到数据库"""
try:
from src.common.database.database_model import ChatHistory
from src.plugin_system.apis import database_api
from src.services import database_service as database_api
# 准备数据
data = {

View File

@@ -7,7 +7,7 @@ from typing import List, Dict, Any, Optional, Tuple, Callable
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.prompt.prompt_manager import prompt_manager
from src.plugin_system.apis import llm_api
from src.services import llm_service as llm_api
from sqlmodel import select, col
from src.common.database.database import get_db_session
from src.common.database.database_model import ThinkingQuestion

File diff suppressed because it is too large Load Diff

View File

@@ -22,8 +22,9 @@ class UDSTransportServer(TransportServer):
def __init__(self, socket_path: str | None = None):
if socket_path is None:
# 默认放在临时目录
socket_path = os.path.join(tempfile.gettempdir(), f"maibot-plugin-{os.getpid()}.sock")
# 默认放在临时目录,使用 uuid 确保同一进程多实例不碰撞
import uuid
socket_path = os.path.join(tempfile.gettempdir(), f"maibot-plugin-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock")
self._socket_path = socket_path
self._server: asyncio.AbstractServer | None = None

View File

@@ -1,156 +0,0 @@
"""
MaiBot 插件系统
提供统一的插件开发和管理框架
"""
# 导出主要的公共接口
from .base import (
BasePlugin,
BaseAction,
BaseCommand,
BaseTool,
ConfigField,
ConfigSection,
ConfigLayout,
ConfigTab,
ComponentType,
ActionActivationType,
ChatMode,
ComponentInfo,
ActionInfo,
CommandInfo,
PluginInfo,
ToolInfo,
PythonDependency,
BaseEventHandler,
EventHandlerInfo,
EventType,
MaiMessages,
ToolParamType,
CustomEventHandlerResult,
PluginServiceInfo,
ReplyContentType,
ReplyContent,
ForwardNode,
ReplySetModel,
WorkflowContext,
WorkflowMessage,
WorkflowStage,
WorkflowStepInfo,
WorkflowStepResult,
WorkflowErrorCode,
)
# 导入工具模块
from .utils import (
ManifestValidator,
# ManifestGenerator,
# validate_plugin_manifest,
# generate_plugin_manifest,
)
from .apis import (
chat_api,
tool_api,
component_manage_api,
config_api,
database_api,
emoji_api,
generator_api,
llm_api,
message_api,
person_api,
plugin_manage_api,
plugin_service_api,
workflow_api,
send_api,
register_plugin,
get_logger,
)
from src.common.data_models.database_data_model import (
DatabaseMessages,
DatabaseUserInfo,
DatabaseGroupInfo,
DatabaseChatInfo,
)
from src.common.data_models.info_data_model import TargetPersonInfo, ActionPlannerInfo
from src.common.data_models.llm_data_model import LLMGenerationDataModel
__version__ = "2.0.0"
__all__ = [
# API 模块
"chat_api",
"tool_api",
"component_manage_api",
"config_api",
"database_api",
"emoji_api",
"generator_api",
"llm_api",
"message_api",
"person_api",
"plugin_manage_api",
"plugin_service_api",
"workflow_api",
"send_api",
"auto_talk_api",
"register_plugin",
"get_logger",
# 基础类
"BasePlugin",
"BaseAction",
"BaseCommand",
"BaseTool",
"BaseEventHandler",
# 类型定义
"ComponentType",
"ActionActivationType",
"ChatMode",
"ComponentInfo",
"ActionInfo",
"CommandInfo",
"PluginInfo",
"ToolInfo",
"PythonDependency",
"EventHandlerInfo",
"EventType",
"ToolParamType",
# 消息
"ReplyContentType",
"ReplyContent",
"ForwardNode",
"ReplySetModel",
"MaiMessages",
"CustomEventHandlerResult",
"PluginServiceInfo",
"WorkflowContext",
"WorkflowMessage",
"WorkflowStage",
"WorkflowStepInfo",
"WorkflowStepResult",
"WorkflowErrorCode",
# 装饰器
"register_plugin",
"ConfigField",
"ConfigSection",
"ConfigLayout",
"ConfigTab",
# 工具函数
"ManifestValidator",
"get_logger",
# "ManifestGenerator",
# "validate_plugin_manifest",
# "generate_plugin_manifest",
# 数据模型
"DatabaseMessages",
"DatabaseUserInfo",
"DatabaseGroupInfo",
"DatabaseChatInfo",
"TargetPersonInfo",
"ActionPlannerInfo",
"LLMGenerationDataModel",
]

View File

@@ -1,47 +0,0 @@
"""
插件系统API模块
提供了插件开发所需的各种API
"""
# 导入所有API模块
from src.plugin_system.apis import (
chat_api,
component_manage_api,
config_api,
database_api,
emoji_api,
generator_api,
llm_api,
message_api,
person_api,
plugin_manage_api,
plugin_service_api,
send_api,
tool_api,
frequency_api,
workflow_api,
)
from .logging_api import get_logger
from .plugin_register_api import register_plugin
# 导出所有API模块使它们可以通过 apis.xxx 方式访问
__all__ = [
"chat_api",
"component_manage_api",
"config_api",
"database_api",
"emoji_api",
"generator_api",
"llm_api",
"message_api",
"person_api",
"plugin_manage_api",
"plugin_service_api",
"send_api",
"get_logger",
"register_plugin",
"tool_api",
"frequency_api",
"workflow_api",
]

View File

@@ -1,323 +0,0 @@
"""
聊天API模块
专门负责聊天信息的查询和管理采用标准Python包设计模式
使用方式:
from src.plugin_system.apis import chat_api
streams = chat_api.get_all_group_streams()
chat_type = chat_api.get_stream_type(stream)
或者:
from src.plugin_system.apis.chat_api import ChatManager as chat
streams = chat.get_all_group_streams()
"""
from typing import List, Dict, Any, Optional
from enum import Enum
from src.common.logger import get_logger
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
logger = get_logger("chat_api")
class SpecialTypes(Enum):
"""特殊枚举类型"""
ALL_PLATFORMS = "all_platforms"
class ChatManager:
"""聊天管理器 - 专门负责聊天信息的查询和管理"""
@staticmethod
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
# sourcery skip: for-append-to-extend
"""获取所有聊天流
Args:
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
Returns:
List[BotChatSession]: 聊天流列表
Raises:
TypeError: 如果 platform 不是字符串或 SpecialTypes 枚举类型
"""
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = []
try:
for _, stream in _chat_manager.sessions.items():
if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform:
streams.append(stream)
logger.debug(f"[ChatAPI] 获取到 {len(streams)}{platform} 平台的聊天流")
except Exception as e:
logger.error(f"[ChatAPI] 获取聊天流失败: {e}")
return streams
@staticmethod
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
# sourcery skip: for-append-to-extend
"""获取所有群聊聊天流
Args:
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
Returns:
List[BotChatSession]: 群聊聊天流列表
"""
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = []
try:
for _, stream in _chat_manager.sessions.items():
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.is_group_session:
streams.append(stream)
logger.debug(f"[ChatAPI] 获取到 {len(streams)}{platform} 平台的群聊流")
except Exception as e:
logger.error(f"[ChatAPI] 获取群聊流失败: {e}")
return streams
@staticmethod
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
# sourcery skip: for-append-to-extend
"""获取所有私聊聊天流
Args:
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
Returns:
List[BotChatSession]: 私聊聊天流列表
Raises:
TypeError: 如果 platform 不是字符串或 SpecialTypes 枚举类型
"""
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = []
try:
for _, stream in _chat_manager.sessions.items():
if (
platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform
) and not stream.is_group_session:
streams.append(stream)
logger.debug(f"[ChatAPI] 获取到 {len(streams)}{platform} 平台的私聊流")
except Exception as e:
logger.error(f"[ChatAPI] 获取私聊流失败: {e}")
return streams
@staticmethod
def get_group_stream_by_group_id(
group_id: str, platform: Optional[str] | SpecialTypes = "qq"
) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast
"""根据群ID获取聊天流
Args:
group_id: 群聊ID
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
Returns:
Optional[BotChatSession]: 聊天流对象如果未找到返回None
Raises:
ValueError: 如果 group_id 为空字符串
TypeError: 如果 group_id 不是字符串类型或 platform 不是字符串或 SpecialTypes
"""
if not isinstance(group_id, str):
raise TypeError("group_id 必须是字符串类型")
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
if not group_id:
raise ValueError("group_id 不能为空")
try:
for _, stream in _chat_manager.sessions.items():
if (
stream.is_group_session
and str(stream.group_id) == str(group_id)
and (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform)
):
logger.debug(f"[ChatAPI] 找到群ID {group_id} 的聊天流")
return stream
logger.warning(f"[ChatAPI] 未找到群ID {group_id} 的聊天流")
except Exception as e:
logger.error(f"[ChatAPI] 查找群聊流失败: {e}")
return None
@staticmethod
def get_private_stream_by_user_id(
user_id: str, platform: Optional[str] | SpecialTypes = "qq"
) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast
"""根据用户ID获取私聊流
Args:
user_id: 用户ID
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
Returns:
Optional[BotChatSession]: 聊天流对象如果未找到返回None
Raises:
ValueError: 如果 user_id 为空字符串
TypeError: 如果 user_id 不是字符串类型或 platform 不是字符串或 SpecialTypes
"""
if not isinstance(user_id, str):
raise TypeError("user_id 必须是字符串类型")
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
if not user_id:
raise ValueError("user_id 不能为空")
try:
for _, stream in _chat_manager.sessions.items():
if (
not stream.is_group_session
and str(stream.user_id) == str(user_id)
and (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform)
):
logger.debug(f"[ChatAPI] 找到用户ID {user_id} 的私聊流")
return stream
logger.warning(f"[ChatAPI] 未找到用户ID {user_id} 的私聊流")
except Exception as e:
logger.error(f"[ChatAPI] 查找私聊流失败: {e}")
return None
@staticmethod
def get_stream_type(chat_stream: BotChatSession) -> str:
"""获取聊天流类型
Args:
chat_stream: 聊天流对象
Returns:
str: 聊天类型 ("group", "private", "unknown")
Raises:
TypeError: 如果 chat_stream 不是 BotChatSession 类型
ValueError: 如果 chat_stream 为空
"""
if not isinstance(chat_stream, BotChatSession):
raise TypeError("chat_stream 必须是 BotChatSession 类型")
if not chat_stream:
raise ValueError("chat_stream 不能为 None")
return "group" if chat_stream.is_group_session else "private"
@staticmethod
def get_stream_info(chat_stream: BotChatSession) -> Dict[str, Any]:
"""获取聊天流详细信息
Args:
chat_stream: 聊天流对象
Returns:
Dict ({str: Any}): 聊天流信息字典
Raises:
TypeError: 如果 chat_stream 不是 BotChatSession 类型
ValueError: 如果 chat_stream 为空
"""
if not chat_stream:
raise ValueError("chat_stream 不能为 None")
if not isinstance(chat_stream, BotChatSession):
raise TypeError("chat_stream 必须是 BotChatSession 类型")
try:
info: Dict[str, Any] = {
"session_id": chat_stream.session_id,
"platform": chat_stream.platform,
"type": ChatManager.get_stream_type(chat_stream),
}
if chat_stream.is_group_session:
info["group_id"] = chat_stream.group_id
# Try to get group name from context
if chat_stream.context and chat_stream.context.message and chat_stream.context.message.message_info.group_info:
info["group_name"] = chat_stream.context.message.message_info.group_info.group_name or "未知群聊"
else:
info["group_name"] = "未知群聊"
else:
info["user_id"] = chat_stream.user_id
if chat_stream.context and chat_stream.context.message and chat_stream.context.message.message_info.user_info:
info["user_name"] = chat_stream.context.message.message_info.user_info.user_nickname
else:
info["user_name"] = "未知用户"
return info
except Exception as e:
logger.error(f"[ChatAPI] 获取聊天流信息失败: {e}")
return {}
@staticmethod
def get_streams_summary() -> Dict[str, int]:
"""获取聊天流统计摘要
Returns:
Dict[str, int]: 包含各种统计信息的字典
"""
try:
all_streams = ChatManager.get_all_streams(SpecialTypes.ALL_PLATFORMS)
group_streams = ChatManager.get_group_streams(SpecialTypes.ALL_PLATFORMS)
private_streams = ChatManager.get_private_streams(SpecialTypes.ALL_PLATFORMS)
summary = {
"total_streams": len(all_streams),
"group_streams": len(group_streams),
"private_streams": len(private_streams),
"qq_streams": len([s for s in all_streams if s.platform == "qq"]),
}
logger.debug(f"[ChatAPI] 聊天流统计: {summary}")
return summary
except Exception as e:
logger.error(f"[ChatAPI] 获取聊天流统计失败: {e}")
return {
"total_streams": 0,
"group_streams": 0,
"private_streams": 0,
"qq_streams": 0,
}
# =============================================================================
# 模块级别的便捷函数 - 类似 requests.get(), requests.post() 的设计
# =============================================================================
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
"""获取所有聊天流的便捷函数"""
return ChatManager.get_all_streams(platform)
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
"""获取群聊聊天流的便捷函数"""
return ChatManager.get_group_streams(platform)
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
"""获取私聊聊天流的便捷函数"""
return ChatManager.get_private_streams(platform)
def get_stream_by_group_id(group_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[BotChatSession]:
"""根据群ID获取聊天流的便捷函数"""
return ChatManager.get_group_stream_by_group_id(group_id, platform)
def get_stream_by_user_id(user_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[BotChatSession]:
"""根据用户ID获取私聊流的便捷函数"""
return ChatManager.get_private_stream_by_user_id(user_id, platform)
def get_stream_type(chat_stream: BotChatSession) -> str:
"""获取聊天流类型的便捷函数"""
return ChatManager.get_stream_type(chat_stream)
def get_stream_info(chat_stream: BotChatSession) -> Dict[str, Any]:
"""获取聊天流信息的便捷函数"""
return ChatManager.get_stream_info(chat_stream)
def get_streams_summary() -> Dict[str, int]:
"""获取聊天流统计摘要的便捷函数"""
return ChatManager.get_streams_summary()

View File

@@ -1,268 +0,0 @@
from typing import Optional, Union, Dict
from src.plugin_system.base.component_types import (
CommandInfo,
ActionInfo,
EventHandlerInfo,
PluginInfo,
ComponentType,
ToolInfo,
)
# === 插件信息查询 ===
def get_all_plugin_info() -> Dict[str, PluginInfo]:
"""
获取所有插件的信息。
Returns:
dict: 包含所有插件信息的字典,键为插件名称,值为 PluginInfo 对象。
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.get_all_plugins()
def get_plugin_info(plugin_name: str) -> Optional[PluginInfo]:
"""
获取指定插件的信息。
Args:
plugin_name (str): 插件名称。
Returns:
PluginInfo: 插件信息对象,如果插件不存在则返回 None。
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.get_plugin_info(plugin_name)
# === 组件查询方法 ===
def get_component_info(
component_name: str, component_type: ComponentType
) -> Optional[Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
"""
获取指定组件的信息。
Args:
component_name (str): 组件名称。
component_type (ComponentType): 组件类型。
Returns:
Union[CommandInfo, ActionInfo, EventHandlerInfo]: 组件信息对象,如果组件不存在则返回 None。
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.get_component_info(component_name, component_type) # type: ignore
def get_components_info_by_type(
component_type: ComponentType,
) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
"""
获取指定类型的所有组件信息。
Args:
component_type (ComponentType): 组件类型。
Returns:
dict: 包含指定类型组件信息的字典,键为组件名称,值为对应的组件信息对象。
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.get_components_by_type(component_type) # type: ignore
def get_enabled_components_info_by_type(
component_type: ComponentType,
) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
"""
获取指定类型的所有启用的组件信息。
Args:
component_type (ComponentType): 组件类型。
Returns:
dict: 包含指定类型启用组件信息的字典,键为组件名称,值为对应的组件信息对象。
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.get_enabled_components_by_type(component_type) # type: ignore
# === Action 查询方法 ===
def get_registered_action_info(action_name: str) -> Optional[ActionInfo]:
"""
获取指定 Action 的注册信息。
Args:
action_name (str): Action 名称。
Returns:
ActionInfo: Action 信息对象,如果 Action 不存在则返回 None。
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.get_registered_action_info(action_name)
def get_registered_command_info(command_name: str) -> Optional[CommandInfo]:
"""
获取指定 Command 的注册信息。
Args:
command_name (str): Command 名称。
Returns:
CommandInfo: Command 信息对象,如果 Command 不存在则返回 None。
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.get_registered_command_info(command_name)
def get_registered_tool_info(tool_name: str) -> Optional[ToolInfo]:
"""
获取指定 Tool 的注册信息。
Args:
tool_name (str): Tool 名称。
Returns:
ToolInfo: Tool 信息对象,如果 Tool 不存在则返回 None。
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.get_registered_tool_info(tool_name)
# === EventHandler 特定查询方法 ===
def get_registered_event_handler_info(
event_handler_name: str,
) -> Optional[EventHandlerInfo]:
"""
获取指定 EventHandler 的注册信息。
Args:
event_handler_name (str): EventHandler 名称。
Returns:
EventHandlerInfo: EventHandler 信息对象,如果 EventHandler 不存在则返回 None。
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.get_registered_event_handler_info(event_handler_name)
# === 组件管理方法 ===
def globally_enable_component(component_name: str, component_type: ComponentType) -> bool:
"""
全局启用指定组件。
Args:
component_name (str): 组件名称。
component_type (ComponentType): 组件类型。
Returns:
bool: 启用成功返回 True否则返回 False。
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.enable_component(component_name, component_type)
async def globally_disable_component(component_name: str, component_type: ComponentType) -> bool:
"""
全局禁用指定组件。
**此函数是异步的,确保在异步环境中调用。**
Args:
component_name (str): 组件名称。
component_type (ComponentType): 组件类型。
Returns:
bool: 禁用成功返回 True否则返回 False。
"""
from src.plugin_system.core.component_registry import component_registry
return await component_registry.disable_component(component_name, component_type)
def locally_enable_component(component_name: str, component_type: ComponentType, stream_id: str) -> bool:
"""
局部启用指定组件。
Args:
component_name (str): 组件名称。
component_type (ComponentType): 组件类型。
stream_id (str): 消息流 ID。
Returns:
bool: 启用成功返回 True否则返回 False。
"""
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
match component_type:
case ComponentType.ACTION:
return global_announcement_manager.enable_specific_chat_action(stream_id, component_name)
case ComponentType.COMMAND:
return global_announcement_manager.enable_specific_chat_command(stream_id, component_name)
case ComponentType.TOOL:
return global_announcement_manager.enable_specific_chat_tool(stream_id, component_name)
case ComponentType.EVENT_HANDLER:
return global_announcement_manager.enable_specific_chat_event_handler(stream_id, component_name)
case _:
raise ValueError(f"未知 component type: {component_type}")
def locally_disable_component(component_name: str, component_type: ComponentType, stream_id: str) -> bool:
"""
局部禁用指定组件。
Args:
component_name (str): 组件名称。
component_type (ComponentType): 组件类型。
stream_id (str): 消息流 ID。
Returns:
bool: 禁用成功返回 True否则返回 False。
"""
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
match component_type:
case ComponentType.ACTION:
return global_announcement_manager.disable_specific_chat_action(stream_id, component_name)
case ComponentType.COMMAND:
return global_announcement_manager.disable_specific_chat_command(stream_id, component_name)
case ComponentType.TOOL:
return global_announcement_manager.disable_specific_chat_tool(stream_id, component_name)
case ComponentType.EVENT_HANDLER:
return global_announcement_manager.disable_specific_chat_event_handler(stream_id, component_name)
case _:
raise ValueError(f"未知 component type: {component_type}")
def get_locally_disabled_components(stream_id: str, component_type: ComponentType) -> list[str]:
"""
获取指定消息流中禁用的组件列表。
Args:
stream_id (str): 消息流 ID。
component_type (ComponentType): 组件类型。
Returns:
list[str]: 禁用的组件名称列表。
"""
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
match component_type:
case ComponentType.ACTION:
return global_announcement_manager.get_disabled_chat_actions(stream_id)
case ComponentType.COMMAND:
return global_announcement_manager.get_disabled_chat_commands(stream_id)
case ComponentType.TOOL:
return global_announcement_manager.get_disabled_chat_tools(stream_id)
case ComponentType.EVENT_HANDLER:
return global_announcement_manager.get_disabled_chat_event_handlers(stream_id)
case _:
raise ValueError(f"未知 component type: {component_type}")

View File

@@ -1,22 +0,0 @@
from pathlib import Path
from dataclasses import dataclass
@dataclass(frozen=True)
class _SystemConstants:
PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.parent.absolute().resolve()
CONFIG_DIR: Path = PROJECT_ROOT / "config"
BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
PLUGINS_DIR: Path = (PROJECT_ROOT / "plugins").resolve().absolute()
INTERNAL_PLUGINS_DIR: Path = (PROJECT_ROOT / "src" / "plugins").resolve().absolute()
_system_constants = _SystemConstants()
PROJECT_ROOT: Path = _system_constants.PROJECT_ROOT
CONFIG_DIR: Path = _system_constants.CONFIG_DIR
BOT_CONFIG_PATH: Path = _system_constants.BOT_CONFIG_PATH
MODEL_CONFIG_PATH: Path = _system_constants.MODEL_CONFIG_PATH
PLUGINS_DIR: Path = _system_constants.PLUGINS_DIR
INTERNAL_PLUGINS_DIR: Path = _system_constants.INTERNAL_PLUGINS_DIR

View File

@@ -1,700 +0,0 @@
"""
表情API模块
提供表情包相关功能采用标准Python包设计模式
使用方式:
from src.plugin_system.apis import emoji_api
result = await emoji_api.get_by_description("开心")
count = emoji_api.get_count()
"""
import random
import base64
import os
import uuid
from typing import Optional, Tuple, List, Dict, Any
from src.common.logger import get_logger
from src.chat.emoji_system.emoji_manager import emoji_manager, EMOJI_DIR
from src.common.utils.utils_image import ImageUtils
from src.config.config import global_config
logger = get_logger("emoji_api")
# =============================================================================
# 表情包获取API函数
# =============================================================================
async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]:
"""根据描述选择表情包
Args:
description: 表情包的描述文本,例如"开心""难过""愤怒"
Returns:
Optional[Tuple[str, str, str]]: (base64编码, 表情包描述, 匹配的情感标签) 或 None
Raises:
ValueError: 如果描述为空字符串
TypeError: 如果描述不是字符串类型
"""
if not description:
raise ValueError("描述不能为空")
if not isinstance(description, str):
raise TypeError("描述必须是字符串类型")
try:
logger.debug(f"[EmojiAPI] 根据描述获取表情包: {description}")
emoji_obj = await emoji_manager.get_emoji_for_emotion(description)
if not emoji_obj:
logger.warning(f"[EmojiAPI] 未找到匹配描述 '{description}' 的表情包")
return None
emoji_path = str(emoji_obj.full_path)
emoji_description = emoji_obj.description
matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else ""
emoji_base64 = ImageUtils.image_path_to_base64(emoji_path)
if not emoji_base64:
logger.error(f"[EmojiAPI] 无法将表情包文件转换为base64: {emoji_path}")
return None
logger.debug(f"[EmojiAPI] 成功获取表情包: {emoji_description}, 匹配情感: {matched_emotion}")
return emoji_base64, emoji_description, matched_emotion
except Exception as e:
logger.error(f"[EmojiAPI] 获取表情包失败: {e}")
return None
async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
"""随机获取指定数量的表情包
Args:
count: 要获取的表情包数量默认为1
Returns:
List[Tuple[str, str, str]]: 包含(base64编码, 表情包描述, 随机情感标签)的元组列表,失败则返回空列表
Raises:
TypeError: 如果count不是整数类型
ValueError: 如果count为负数
"""
if not isinstance(count, int):
raise TypeError("count 必须是整数类型")
if count < 0:
raise ValueError("count 不能为负数")
if count == 0:
logger.warning("[EmojiAPI] count 为0返回空列表")
return []
try:
all_emojis = emoji_manager.emojis
if not all_emojis:
logger.warning("[EmojiAPI] 没有可用的表情包")
return []
# 过滤有效表情包
valid_emojis = [emoji for emoji in all_emojis if not emoji.is_deleted]
if not valid_emojis:
logger.warning("[EmojiAPI] 没有有效的表情包")
return []
if len(valid_emojis) < count:
logger.debug(
f"[EmojiAPI] 有效表情包数量 ({len(valid_emojis)}) 少于请求的数量 ({count}),将返回所有有效表情包"
)
count = len(valid_emojis)
# 随机选择
selected_emojis = random.sample(valid_emojis, count)
results = []
for selected_emoji in selected_emojis:
emoji_base64 = ImageUtils.image_path_to_base64(str(selected_emoji.full_path))
if not emoji_base64:
logger.error(f"[EmojiAPI] 无法转换表情包为base64: {selected_emoji.full_path}")
continue
matched_emotion = random.choice(selected_emoji.emotion) if selected_emoji.emotion else "随机表情"
# 记录使用次数
emoji_manager.update_emoji_usage(selected_emoji)
results.append((emoji_base64, selected_emoji.description, matched_emotion))
if not results and count > 0:
logger.warning("[EmojiAPI] 随机获取表情包失败,没有一个可以成功处理")
return []
logger.debug(f"[EmojiAPI] 成功获取 {len(results)} 个随机表情包")
return results
except Exception as e:
logger.error(f"[EmojiAPI] 获取随机表情包失败: {e}")
return []
async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
"""根据情感标签获取表情包
Args:
emotion: 情感标签,如"happy""sad""angry"
Returns:
Optional[Tuple[str, str, str]]: (base64编码, 表情包描述, 匹配的情感标签) 或 None
Raises:
ValueError: 如果情感标签为空字符串
TypeError: 如果情感标签不是字符串类型
"""
if not emotion:
raise ValueError("情感标签不能为空")
if not isinstance(emotion, str):
raise TypeError("情感标签必须是字符串类型")
try:
logger.info(f"[EmojiAPI] 根据情感获取表情包: {emotion}")
all_emojis = emoji_manager.emojis
# 筛选匹配情感的表情包
matching_emojis = []
matching_emojis.extend(
emoji_obj
for emoji_obj in all_emojis
if not emoji_obj.is_deleted and emotion.lower() in [e.lower() for e in emoji_obj.emotion]
)
if not matching_emojis:
logger.warning(f"[EmojiAPI] 未找到匹配情感 '{emotion}' 的表情包")
return None
# 随机选择匹配的表情包
selected_emoji = random.choice(matching_emojis)
emoji_base64 = ImageUtils.image_path_to_base64(selected_emoji.full_path)
if not emoji_base64:
logger.error(f"[EmojiAPI] 无法转换表情包为base64: {selected_emoji.full_path}")
return None
# 记录使用次数
emoji_manager.update_emoji_usage(selected_emoji)
logger.info(f"[EmojiAPI] 成功获取情感表情包: {selected_emoji.description}")
return emoji_base64, selected_emoji.description, emotion
except Exception as e:
logger.error(f"[EmojiAPI] 根据情感获取表情包失败: {e}")
return None
# =============================================================================
# 表情包信息查询API函数
# =============================================================================
def get_count() -> int:
"""获取表情包数量
Returns:
int: 当前可用的表情包数量
"""
try:
return len(emoji_manager.emojis)
except Exception as e:
logger.error(f"[EmojiAPI] 获取表情包数量失败: {e}")
return 0
def get_info():
"""获取表情包系统信息
Returns:
dict: 包含表情包数量、最大数量、可用数量信息
"""
try:
return {
"current_count": len(emoji_manager.emojis),
"max_count": global_config.emoji.max_reg_num,
"available_emojis": len([e for e in emoji_manager.emojis if not e.is_deleted]),
}
except Exception as e:
logger.error(f"[EmojiAPI] 获取表情包信息失败: {e}")
return {"current_count": 0, "max_count": 0, "available_emojis": 0}
def get_emotions() -> List[str]:
"""获取所有可用的情感标签
Returns:
list: 所有表情包的情感标签列表(去重)
"""
try:
emotions = set()
for emoji_obj in emoji_manager.emojis:
if not emoji_obj.is_deleted and emoji_obj.emotion:
emotions.update(emoji_obj.emotion)
return sorted(list(emotions))
except Exception as e:
logger.error(f"[EmojiAPI] 获取情感标签失败: {e}")
return []
async def get_all() -> List[Tuple[str, str, str]]:
"""获取所有表情包
Returns:
List[Tuple[str, str, str]]: 包含(base64编码, 表情包描述, 随机情感标签)的元组列表
"""
try:
all_emojis = emoji_manager.emojis
if not all_emojis:
logger.warning("[EmojiAPI] 没有可用的表情包")
return []
results = []
for emoji_obj in all_emojis:
if emoji_obj.is_deleted:
continue
emoji_base64 = ImageUtils.image_path_to_base64(str(emoji_obj.full_path))
if not emoji_base64:
logger.error(f"[EmojiAPI] 无法转换表情包为base64: {emoji_obj.full_path}")
continue
matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else "随机表情"
results.append((emoji_base64, emoji_obj.description, matched_emotion))
logger.debug(f"[EmojiAPI] 成功获取 {len(results)} 个表情包")
return results
except Exception as e:
logger.error(f"[EmojiAPI] 获取所有表情包失败: {e}")
return []
def get_descriptions() -> List[str]:
"""获取所有表情包描述
Returns:
list: 所有可用表情包的描述列表
"""
try:
descriptions = []
descriptions.extend(
emoji_obj.description
for emoji_obj in emoji_manager.emojis
if not emoji_obj.is_deleted and emoji_obj.description
)
return descriptions
except Exception as e:
logger.error(f"[EmojiAPI] 获取表情包描述失败: {e}")
return []
# =============================================================================
# 表情包注册API函数
# =============================================================================
async def register_emoji(image_base64: str, filename: Optional[str] = None) -> Dict[str, Any]:
"""注册新的表情包
Args:
image_base64: 图片的base64编码
filename: 可选的文件名,如果未提供则自动生成
Returns:
Dict[str, Any]: 注册结果,包含以下字段:
- success: bool, 是否成功注册
- message: str, 结果消息
- description: Optional[str], 表情包描述(成功时)
- emotions: Optional[List[str]], 情感标签列表(成功时)
- replaced: Optional[bool], 是否替换了旧表情包(成功时)
- hash: Optional[str], 表情包哈希值(成功时)
Raises:
ValueError: 如果base64为空或无效
TypeError: 如果参数类型不正确
"""
if not image_base64:
raise ValueError("图片base64编码不能为空")
if not isinstance(image_base64, str):
raise TypeError("image_base64必须是字符串类型")
if filename is not None and not isinstance(filename, str):
raise TypeError("filename必须是字符串类型或None")
try:
logger.info(f"[EmojiAPI] 开始注册表情包,文件名: {filename or '自动生成'}")
# 1. 获取emoji管理器并检查容量
count_before = len(emoji_manager.emojis)
max_count = global_config.emoji.max_reg_num
# 2. 检查是否可以注册(未达到上限或启用替换)
can_register = count_before < max_count or (count_before >= max_count and global_config.emoji.do_replace)
if not can_register:
return {
"success": False,
"message": f"表情包数量已达上限({count_before}/{max_count})且未启用替换功能",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}
# 3. 确保emoji目录存在
os.makedirs(EMOJI_DIR, exist_ok=True)
# 4. 生成文件名
if not filename:
# 基于时间戳、微秒和短base64生成唯一文件名
import time
timestamp = int(time.time())
microseconds = int(time.time() * 1000000) % 1000000 # 添加微秒级精度
# 生成12位随机标识符使用base64编码增加随机性
import random
random_bytes = random.getrandbits(72).to_bytes(9, "big") # 72位 = 9字节 = 12位base64
short_id = base64.b64encode(random_bytes).decode("ascii")[:12].rstrip("=")
# 确保base64编码适合文件名替换/和-
short_id = short_id.replace("/", "_").replace("+", "-")
filename = f"emoji_{timestamp}_{microseconds}_{short_id}"
# 确保文件名有扩展名
if not filename.lower().endswith((".jpg", ".jpeg", ".png", ".gif")):
filename = f"{filename}.png" # 默认使用png格式
# 检查文件名是否已存在,如果存在则重新生成短标识符
temp_file_path = os.path.join(EMOJI_DIR, filename)
attempts = 0
max_attempts = 10
while os.path.exists(temp_file_path) and attempts < max_attempts:
# 重新生成短标识符
import random
random_bytes = random.getrandbits(48).to_bytes(6, "big")
short_id = base64.b64encode(random_bytes).decode("ascii")[:8].rstrip("=")
short_id = short_id.replace("/", "_").replace("+", "-")
# 分离文件名和扩展名,重新生成文件名
name_part, ext = os.path.splitext(filename)
# 去掉原来的标识符,添加新的
base_name = name_part.rsplit("_", 1)[0] # 移除最后一个_后的部分
filename = f"{base_name}_{short_id}{ext}"
temp_file_path = os.path.join(EMOJI_DIR, filename)
attempts += 1
# 如果还是冲突使用UUID作为备用方案
if os.path.exists(temp_file_path):
uuid_short = str(uuid.uuid4())[:8]
name_part, ext = os.path.splitext(filename)
base_name = name_part.rsplit("_", 1)[0]
filename = f"{base_name}_{uuid_short}{ext}"
temp_file_path = os.path.join(EMOJI_DIR, filename)
# 如果UUID方案也冲突添加序号
counter = 1
original_filename = filename
while os.path.exists(temp_file_path):
name_part, ext = os.path.splitext(original_filename)
filename = f"{name_part}_{counter}{ext}"
temp_file_path = os.path.join(EMOJI_DIR, filename)
counter += 1
# 防止无限循环最多尝试100次
if counter > 100:
logger.error(f"[EmojiAPI] 无法生成唯一文件名,尝试次数过多: {original_filename}")
return {
"success": False,
"message": "无法生成唯一文件名,请稍后重试",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}
# 5. 保存base64图片到emoji目录
try:
# 解码base64并保存图片
if not ImageUtils.base64_to_image(image_base64, temp_file_path):
logger.error(f"[EmojiAPI] 无法保存base64图片到文件: {temp_file_path}")
return {
"success": False,
"message": "无法保存图片文件",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}
logger.debug(f"[EmojiAPI] 图片已保存到临时文件: {temp_file_path}")
except Exception as save_error:
logger.error(f"[EmojiAPI] 保存图片文件失败: {save_error}")
return {
"success": False,
"message": f"保存图片文件失败: {str(save_error)}",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}
# 6. 调用注册方法
register_success = await emoji_manager.register_emoji_by_filename(filename)
# 7. 清理临时文件(如果注册失败但文件还存在)
if not register_success and os.path.exists(temp_file_path):
try:
os.remove(temp_file_path)
logger.debug(f"[EmojiAPI] 已清理临时文件: {temp_file_path}")
except Exception as cleanup_error:
logger.warning(f"[EmojiAPI] 清理临时文件失败: {cleanup_error}")
# 8. 构建返回结果
if register_success:
count_after = len(emoji_manager.emojis)
replaced = count_after <= count_before # 如果数量没增加,说明是替换
# 尝试获取新注册的表情包信息
new_emoji_info = None
if count_after > count_before or replaced:
# 获取最新的表情包信息
try:
# 通过文件名查找新注册的表情包(注意:文件名在注册后可能已经改变)
for emoji_obj in reversed(emoji_manager.emojis):
if not emoji_obj.is_deleted and (
emoji_obj.file_name == filename
or (hasattr(emoji_obj, "full_path") and filename in str(emoji_obj.full_path))
):
new_emoji_info = emoji_obj
break
except Exception as find_error:
logger.warning(f"[EmojiAPI] 查找新注册表情包信息失败: {find_error}")
description = new_emoji_info.description if new_emoji_info else None
emotions = new_emoji_info.emotion if new_emoji_info else None
emoji_hash = new_emoji_info.emoji_hash if new_emoji_info else None
return {
"success": True,
"message": f"表情包注册成功 {'(替换旧表情包)' if replaced else '(新增表情包)'}",
"description": description,
"emotions": emotions,
"replaced": replaced,
"hash": emoji_hash,
}
else:
return {
"success": False,
"message": "表情包注册失败,可能因为重复、格式不支持或审核未通过",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}
except Exception as e:
logger.error(f"[EmojiAPI] 注册表情包时发生异常: {e}")
return {
"success": False,
"message": f"注册过程中发生错误: {str(e)}",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}
# =============================================================================
# 表情包删除API函数
# =============================================================================
async def delete_emoji(emoji_hash: str) -> Dict[str, Any]:
"""删除表情包
Args:
emoji_hash: 要删除的表情包的哈希值
Returns:
Dict[str, Any]: 删除结果,包含以下字段:
- success: bool, 是否成功删除
- message: str, 结果消息
- count_before: Optional[int], 删除前的表情包数量
- count_after: Optional[int], 删除后的表情包数量
- description: Optional[str], 被删除的表情包描述(成功时)
- emotions: Optional[List[str]], 被删除的表情包情感标签(成功时)
Raises:
ValueError: 如果哈希值为空
TypeError: 如果哈希值不是字符串类型
"""
if not emoji_hash:
raise ValueError("表情包哈希值不能为空")
if not isinstance(emoji_hash, str):
raise TypeError("emoji_hash必须是字符串类型")
try:
logger.info(f"[EmojiAPI] 开始删除表情包,哈希值: {emoji_hash}")
# 1. 获取emoji管理器和删除前的数量
count_before = len(emoji_manager.emojis)
# 2. 获取被删除表情包的信息(用于返回结果)
deleted_emoji = None
try:
deleted_emoji = emoji_manager.get_emoji_by_hash(emoji_hash) or emoji_manager.get_emoji_by_hash_from_db(
emoji_hash
)
description = deleted_emoji.description if deleted_emoji else None
emotions = deleted_emoji.emotion if deleted_emoji else None
except Exception as info_error:
logger.warning(f"[EmojiAPI] 获取被删除表情包信息失败: {info_error}")
description = None
emotions = None
# 3. 执行删除操作
delete_success = False
if deleted_emoji:
delete_success = emoji_manager.delete_emoji(deleted_emoji)
# 4. 获取删除后的数量
count_after = len(emoji_manager.emojis)
# 5. 构建返回结果
if delete_success:
return {
"success": True,
"message": f"表情包删除成功 (哈希: {emoji_hash[:8]}...)",
"count_before": count_before,
"count_after": count_after,
"description": description,
"emotions": emotions,
}
else:
return {
"success": False,
"message": "表情包删除失败,可能因为哈希值不存在或删除过程出错",
"count_before": count_before,
"count_after": count_after,
"description": None,
"emotions": None,
}
except Exception as e:
logger.error(f"[EmojiAPI] 删除表情包时发生异常: {e}")
return {
"success": False,
"message": f"删除过程中发生错误: {str(e)}",
"count_before": None,
"count_after": None,
"description": None,
"emotions": None,
}
async def delete_emoji_by_description(description: str, exact_match: bool = False) -> Dict[str, Any]:
"""根据描述删除表情包
Args:
description: 表情包描述文本
exact_match: 是否精确匹配描述False则为模糊匹配
Returns:
Dict[str, Any]: 删除结果,包含以下字段:
- success: bool, 是否成功删除
- message: str, 结果消息
- deleted_count: int, 删除的表情包数量
- deleted_hashes: List[str], 被删除的表情包哈希列表
- matched_count: int, 匹配到的表情包数量
Raises:
ValueError: 如果描述为空
TypeError: 如果描述不是字符串类型
"""
if not description:
raise ValueError("描述不能为空")
if not isinstance(description, str):
raise TypeError("description必须是字符串类型")
try:
logger.info(f"[EmojiAPI] 根据描述删除表情包: {description} (精确匹配: {exact_match})")
all_emojis = emoji_manager.emojis
# 筛选匹配的表情包
matching_emojis = []
for emoji_obj in all_emojis:
if emoji_obj.is_deleted:
continue
if exact_match:
if emoji_obj.description == description:
matching_emojis.append(emoji_obj)
else:
if description.lower() in emoji_obj.description.lower():
matching_emojis.append(emoji_obj)
matched_count = len(matching_emojis)
if matched_count == 0:
return {
"success": False,
"message": f"未找到匹配描述 '{description}' 的表情包",
"deleted_count": 0,
"deleted_hashes": [],
"matched_count": 0,
}
# 删除匹配的表情包
deleted_count = 0
deleted_hashes = []
for emoji_obj in matching_emojis:
try:
delete_success = emoji_manager.delete_emoji(emoji_obj)
if delete_success:
deleted_count += 1
deleted_hashes.append(emoji_obj.emoji_hash)
except Exception as delete_error:
logger.error(f"[EmojiAPI] 删除表情包失败 (哈希: {emoji_obj.emoji_hash}): {delete_error}")
# 构建返回结果
if deleted_count > 0:
return {
"success": True,
"message": f"成功删除 {deleted_count} 个表情包 (匹配到 {matched_count} 个)",
"deleted_count": deleted_count,
"deleted_hashes": deleted_hashes,
"matched_count": matched_count,
}
else:
return {
"success": False,
"message": f"匹配到 {matched_count} 个表情包,但删除全部失败",
"deleted_count": 0,
"deleted_hashes": [],
"matched_count": matched_count,
}
except Exception as e:
logger.error(f"[EmojiAPI] 根据描述删除表情包时发生异常: {e}")
return {
"success": False,
"message": f"删除过程中发生错误: {str(e)}",
"deleted_count": 0,
"deleted_hashes": [],
"matched_count": 0,
}

View File

@@ -1,3 +0,0 @@
from src.common.logger import get_logger
__all__ = ["get_logger"]

View File

@@ -1,120 +0,0 @@
from typing import Tuple, List
def list_loaded_plugins() -> List[str]:
"""
列出所有当前加载的插件。
Returns:
List[str]: 当前加载的插件名称列表。
"""
from src.plugin_system.core.plugin_manager import plugin_manager
return plugin_manager.list_loaded_plugins()
def list_registered_plugins() -> List[str]:
"""
列出所有已注册的插件。
Returns:
List[str]: 已注册的插件名称列表。
"""
from src.plugin_system.core.plugin_manager import plugin_manager
return plugin_manager.list_registered_plugins()
def get_plugin_path(plugin_name: str) -> str:
"""
获取指定插件的路径。
Args:
plugin_name (str): 插件名称。
Returns:
str: 插件目录的绝对路径。
Raises:
ValueError: 如果插件不存在。
"""
from src.plugin_system.core.plugin_manager import plugin_manager
if plugin_path := plugin_manager.get_plugin_path(plugin_name):
return plugin_path
else:
raise ValueError(f"插件 '{plugin_name}' 不存在。")
async def remove_plugin(plugin_name: str) -> bool:
"""
卸载指定的插件。
**此函数是异步的,确保在异步环境中调用。**
Args:
plugin_name (str): 要卸载的插件名称。
Returns:
bool: 卸载是否成功。
"""
from src.plugin_system.core.plugin_manager import plugin_manager
return await plugin_manager.remove_registered_plugin(plugin_name)
async def reload_plugin(plugin_name: str) -> bool:
"""
重新加载指定的插件。
**此函数是异步的,确保在异步环境中调用。**
Args:
plugin_name (str): 要重新加载的插件名称。
Returns:
bool: 重新加载是否成功。
"""
from src.plugin_system.core.plugin_manager import plugin_manager
return await plugin_manager.reload_registered_plugin(plugin_name)
def load_plugin(plugin_name: str) -> Tuple[bool, int]:
"""
加载指定的插件。
Args:
plugin_name (str): 要加载的插件名称。
Returns:
Tuple[bool, int]: 加载是否成功,成功或失败个数。
"""
from src.plugin_system.core.plugin_manager import plugin_manager
return plugin_manager.load_registered_plugin_classes(plugin_name)
def add_plugin_directory(plugin_directory: str) -> bool:
"""
添加插件目录。
Args:
plugin_directory (str): 要添加的插件目录路径。
Returns:
bool: 添加是否成功。
"""
from src.plugin_system.core.plugin_manager import plugin_manager
return plugin_manager.add_plugin_directory(plugin_directory)
def rescan_plugin_directory() -> Tuple[int, int]:
"""
重新扫描插件目录,加载新插件。
Returns:
Tuple[int, int]: 成功加载的插件数量和失败的插件数量。
"""
from src.plugin_system.core.plugin_manager import plugin_manager
return plugin_manager.rescan_plugin_directory()

View File

@@ -1,46 +0,0 @@
from pathlib import Path
from src.common.logger import get_logger
logger = get_logger("plugin_manager") # 复用plugin_manager名称
def register_plugin(cls):
from src.plugin_system.core.plugin_manager import plugin_manager
from src.plugin_system.base.base_plugin import BasePlugin
"""插件注册装饰器
用法:
@register_plugin
class MyPlugin(BasePlugin):
plugin_name = "my_plugin"
plugin_description = "我的插件"
...
"""
if not issubclass(cls, BasePlugin):
logger.error(f"{cls.__name__} 不是 BasePlugin 的子类")
return cls
# 只是注册插件类,不立即实例化
# 插件管理器会负责实例化和注册
plugin_name: str = cls.plugin_name # type: ignore
if "." in plugin_name:
logger.error(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
raise ValueError(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
splitted_name = cls.__module__.split(".")
root_path = Path(__file__)
# 查找项目根目录
while not (root_path / "pyproject.toml").exists() and root_path.parent != root_path:
root_path = root_path.parent
if not (root_path / "pyproject.toml").exists():
logger.error(f"注册 {plugin_name} 无法找到项目根目录")
return cls
plugin_manager.plugin_classes[plugin_name] = cls
plugin_manager.plugin_paths[plugin_name] = str(Path(root_path, *splitted_name).resolve())
logger.debug(f"插件类已注册: {plugin_name}, 路径: {plugin_manager.plugin_paths[plugin_name]}")
return cls

View File

@@ -1,56 +0,0 @@
from typing import Any, Callable, Dict, Optional
from src.plugin_system.base.service_types import PluginServiceInfo
from src.plugin_system.core.plugin_service_registry import plugin_service_registry
def register_service(service_info: PluginServiceInfo, service_handler: Callable[..., Any]) -> bool:
"""注册插件服务。"""
return plugin_service_registry.register_service(service_info, service_handler)
def get_service(service_name: str, plugin_name: Optional[str] = None) -> Optional[PluginServiceInfo]:
"""获取插件服务元信息。"""
return plugin_service_registry.get_service(service_name, plugin_name)
def get_service_handler(service_name: str, plugin_name: Optional[str] = None) -> Optional[Callable[..., Any]]:
"""获取插件服务处理函数。"""
return plugin_service_registry.get_service_handler(service_name, plugin_name)
def list_services(plugin_name: Optional[str] = None, enabled_only: bool = False) -> Dict[str, PluginServiceInfo]:
"""列出插件服务。"""
return plugin_service_registry.list_services(plugin_name=plugin_name, enabled_only=enabled_only)
def enable_service(service_name: str, plugin_name: Optional[str] = None) -> bool:
"""启用插件服务。"""
return plugin_service_registry.enable_service(service_name, plugin_name)
def disable_service(service_name: str, plugin_name: Optional[str] = None) -> bool:
"""禁用插件服务。"""
return plugin_service_registry.disable_service(service_name, plugin_name)
def unregister_service(service_name: str, plugin_name: Optional[str] = None) -> bool:
"""注销插件服务。"""
return plugin_service_registry.unregister_service(service_name, plugin_name)
async def call_service(
service_name: str,
*args: Any,
plugin_name: Optional[str] = None,
caller_plugin: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""调用插件服务。"""
return await plugin_service_registry.call_service(
service_name,
*args,
plugin_name=plugin_name,
caller_plugin=caller_plugin,
**kwargs,
)

View File

@@ -1,45 +0,0 @@
from typing import Optional, Type, TYPE_CHECKING
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.base.component_types import ComponentType
from src.common.logger import get_logger
if TYPE_CHECKING:
from src.chat.message_receive.chat_manager import BotChatSession
logger = get_logger("tool_api")
def get_tool_instance(tool_name: str, chat_stream: Optional["BotChatSession"] = None) -> Optional[BaseTool]:
"""获取公开工具实例
Args:
tool_name: 工具名称
chat_stream: 聊天流对象,用于传递聊天上下文信息
Returns:
Optional[BaseTool]: 工具实例如果未找到则返回None
"""
from src.plugin_system.core import component_registry
# 获取插件配置
tool_info = component_registry.get_component_info(tool_name, ComponentType.TOOL)
if tool_info:
plugin_config = component_registry.get_plugin_config(tool_info.plugin_name)
else:
plugin_config = None
tool_class: Type[BaseTool] = component_registry.get_component_class(tool_name, ComponentType.TOOL) # type: ignore
return tool_class(plugin_config, chat_stream) if tool_class else None
def get_llm_available_tool_definitions():
"""获取LLM可用的工具定义列表
Returns:
List[Tuple[str, Dict[str, Any]]]: 工具定义列表,为[("tool_name", 定义)]
"""
from src.plugin_system.core import component_registry
llm_available_tools = component_registry.get_llm_available_tools()
return [(name, tool_class.get_tool_definition()) for name, tool_class in llm_available_tools.items()]

View File

@@ -1,77 +0,0 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from src.plugin_system.base.component_types import EventType, MaiMessages
from src.plugin_system.base.workflow_types import WorkflowContext, WorkflowStage, WorkflowStepInfo, WorkflowStepResult
from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.core.events_manager import events_manager
from src.plugin_system.core.workflow_engine import workflow_engine
def register_workflow_step(step_info: WorkflowStepInfo, step_handler: Callable[..., Any]) -> bool:
"""注册workflow step。"""
return component_registry.register_workflow_step(step_info, step_handler)
def get_steps_by_stage(stage: WorkflowStage, enabled_only: bool = False) -> Dict[str, WorkflowStepInfo]:
"""获取指定阶段的workflow steps。"""
return component_registry.get_steps_by_stage(stage, enabled_only=enabled_only)
def get_workflow_step(step_name: str, stage: Optional[WorkflowStage] = None) -> Optional[WorkflowStepInfo]:
"""获取workflow step元信息。"""
return component_registry.get_workflow_step(step_name, stage)
def get_workflow_step_handler(step_name: str, stage: Optional[WorkflowStage] = None) -> Optional[Callable[..., Any]]:
"""获取workflow step处理函数。"""
return component_registry.get_workflow_step_handler(step_name, stage)
def enable_workflow_step(step_name: str, stage: Optional[WorkflowStage] = None) -> bool:
"""启用workflow step。"""
return component_registry.enable_workflow_step(step_name, stage)
def disable_workflow_step(step_name: str, stage: Optional[WorkflowStage] = None) -> bool:
"""禁用workflow step。"""
return component_registry.disable_workflow_step(step_name, stage)
def get_execution_trace(trace_id: str) -> Optional[Dict[str, Any]]:
"""按trace_id获取workflow执行路径。"""
return workflow_engine.get_execution_trace(trace_id)
def clear_execution_trace(trace_id: str) -> bool:
"""清理trace执行路径记录。"""
return workflow_engine.clear_execution_trace(trace_id)
async def execute_workflow_message(
message: Optional[MaiMessages] = None,
stream_id: Optional[str] = None,
action_usage: Optional[List[str]] = None,
context: Optional[WorkflowContext] = None,
) -> Tuple[WorkflowStepResult, Optional[MaiMessages], WorkflowContext]:
"""执行workflow消息流转。"""
return await events_manager.handle_workflow_message(
message=message,
stream_id=stream_id,
action_usage=action_usage,
context=context,
)
async def publish_event(
event_type: Union[EventType, str],
message: Optional[MaiMessages] = None,
stream_id: Optional[str] = None,
action_usage: Optional[List[str]] = None,
) -> Tuple[bool, Optional[MaiMessages]]:
"""发布事件(支持系统事件和自定义字符串事件)。"""
return await events_manager.handle_mai_events(
event_type=event_type,
message=message,
stream_id=stream_id,
action_usage=action_usage,
)

View File

@@ -1,72 +0,0 @@
"""
插件基础类模块
提供插件开发的基础类和类型定义
"""
from .base_plugin import BasePlugin
from .base_action import BaseAction
from .base_tool import BaseTool
from .base_command import BaseCommand
from .base_events_handler import BaseEventHandler
from .service_types import PluginServiceInfo
from .workflow_types import WorkflowContext, WorkflowMessage, WorkflowStage, WorkflowStepInfo, WorkflowStepResult
from .workflow_errors import WorkflowErrorCode
from .component_types import (
ComponentType,
ActionActivationType,
ChatMode,
ComponentInfo,
ActionInfo,
CommandInfo,
ToolInfo,
PluginInfo,
PythonDependency,
EventHandlerInfo,
EventType,
MaiMessages,
ToolParamType,
CustomEventHandlerResult,
ReplyContentType,
ReplyContent,
ForwardNode,
ReplySetModel,
)
from .config_types import ConfigField, ConfigSection, ConfigLayout, ConfigTab
__all__ = [
"BasePlugin",
"BaseAction",
"BaseCommand",
"BaseTool",
"ComponentType",
"ActionActivationType",
"ChatMode",
"ComponentInfo",
"ActionInfo",
"CommandInfo",
"ToolInfo",
"PluginInfo",
"PythonDependency",
"ConfigField",
"ConfigSection",
"ConfigLayout",
"ConfigTab",
"EventHandlerInfo",
"EventType",
"BaseEventHandler",
"MaiMessages",
"ToolParamType",
"CustomEventHandlerResult",
"ReplyContentType",
"ReplyContent",
"ForwardNode",
"ReplySetModel",
"PluginServiceInfo",
"WorkflowContext",
"WorkflowMessage",
"WorkflowStage",
"WorkflowStepInfo",
"WorkflowStepResult",
"WorkflowErrorCode",
]

View File

@@ -1,533 +0,0 @@
import time
import asyncio
from abc import ABC, abstractmethod
from typing import Tuple, Optional, TYPE_CHECKING, Dict, List
from src.common.logger import get_logger
from src.common.data_models.message_data_model import ReplyContentType, ReplyContent, ReplySetModel, ForwardNode
from src.chat.message_receive.chat_manager import BotChatSession
from src.plugin_system.base.component_types import ActionActivationType, ActionInfo, ComponentType
from src.plugin_system.apis import send_api, database_api, message_api
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
logger = get_logger("base_action")
class BaseAction(ABC):
"""Action组件基类
Action是插件的一种组件类型用于处理聊天中的动作逻辑
子类可以通过类属性定义激活条件,这些会在实例化时转换为实例属性:
- focus_activation_type: 专注模式激活类型
- normal_activation_type: 普通模式激活类型
- activation_keywords: 激活关键词列表
- keyword_case_sensitive: 关键词是否区分大小写
- parallel_action: 是否允许并行执行
- random_activation_probability: 随机激活概率
"""
def __init__(
self,
action_data: dict,
action_reasoning: str,
cycle_timers: dict,
thinking_id: str,
chat_stream: BotChatSession,
plugin_config: Optional[dict] = None,
action_message: Optional["DatabaseMessages"] = None,
**kwargs,
):
# sourcery skip: hoist-similar-statement-from-if, merge-else-if-into-elif, move-assign-in-block, swap-if-else-branches, swap-nested-ifs
"""初始化Action组件
Args:
action_data: 动作数据
reasoning: 执行该动作的理由
cycle_timers: 计时器字典
thinking_id: 思考ID
chat_stream: 聊天流对象
log_prefix: 日志前缀
plugin_config: 插件配置字典
action_message: 消息数据
**kwargs: 其他参数
"""
if plugin_config is None:
plugin_config = {}
self.action_data = action_data
self.reasoning = ""
self.cycle_timers = cycle_timers
self.thinking_id = thinking_id
self.action_reasoning = action_reasoning
self.plugin_config = plugin_config or {}
"""对应的插件配置"""
# 设置动作基本信息实例属性
self.action_name: str = getattr(self, "action_name", self.__class__.__name__.lower().replace("action", ""))
"""Action的名字"""
self.action_description: str = getattr(self, "action_description", self.__doc__ or "Action组件")
"""Action的描述"""
self.action_parameters: dict = getattr(self.__class__, "action_parameters", {}).copy()
self.action_require: list[str] = getattr(self.__class__, "action_require", []).copy()
"""NORMAL模式下的激活类型"""
self.activation_type = self.__class__.activation_type
"""激活类型"""
self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0)
"""当激活类型为RANDOM时的概率"""
self.activation_keywords: list[str] = getattr(self.__class__, "activation_keywords", []).copy()
"""激活类型为KEYWORD时的KEYWORDS列表"""
self.keyword_case_sensitive: bool = getattr(self.__class__, "keyword_case_sensitive", False)
self.parallel_action: bool = getattr(self.__class__, "parallel_action", True)
self.associated_types: list[str] = getattr(self.__class__, "associated_types", []).copy()
# =============================================================================
# 便捷属性 - 直接在初始化时获取常用聊天信息(带类型注解)
# =============================================================================
# 获取聊天流对象
self.chat_stream = chat_stream or kwargs.get("chat_stream")
self.chat_id = self.chat_stream.session_id
self.platform = getattr(self.chat_stream, "platform", None)
# 初始化基础信息(带类型注解)
self.action_message = action_message
self.group_id = None
self.group_name = None
self.user_id = None
self.user_nickname = None
self.is_group = False
self.target_id = None
self.group_id = (
str(self.action_message.chat_info.group_info.group_id) if self.action_message.chat_info.group_info else None
)
self.group_name = (
self.action_message.chat_info.group_info.group_name if self.action_message.chat_info.group_info else None
)
self.user_id = str(self.action_message.user_info.user_id)
self.user_nickname = self.action_message.user_info.user_nickname
if self.group_id:
self.is_group = True
self.target_id = self.group_id
self.log_prefix = f"[{self.group_name}]"
else:
self.is_group = False
self.target_id = self.user_id
self.log_prefix = f"[{self.user_nickname} 的 私聊]"
logger.debug(
f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}"
)
@abstractmethod
async def execute(self) -> Tuple[bool, str]:
"""执行Action的抽象方法子类必须实现
Returns:
Tuple[bool, str]: (是否执行成功, 回复文本)
"""
pass
async def send_text(
self,
content: str,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
typing: bool = False,
storage_message: bool = True,
) -> bool:
"""发送文本消息
Args:
content: 文本内容
set_reply: 是否作为回复发送
reply_message: 回复的消息对象当set_reply为True时必填
typing: 是否计算输入时间
Returns:
bool: 是否发送成功
"""
if not self.chat_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
return await send_api.text_to_stream(
text=content,
stream_id=self.chat_id,
set_reply=set_reply,
reply_message=reply_message,
typing=typing,
storage_message=storage_message,
)
async def send_emoji(
self,
emoji_base64: str,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""发送表情包
Args:
emoji_base64: 表情包的base64编码
set_reply: 是否作为回复发送
reply_message: 回复的消息对象当set_reply为True时必填
Returns:
bool: 是否发送成功
"""
if not self.chat_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
return await send_api.emoji_to_stream(
emoji_base64,
self.chat_id,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
)
async def send_image(
self,
image_base64: str,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""发送图片
Args:
image_base64: 图片的base64编码
set_reply: 是否作为回复发送
reply_message: 回复的消息对象当set_reply为True时必填
Returns:
bool: 是否发送成功
"""
if not self.chat_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
return await send_api.image_to_stream(
image_base64,
self.chat_id,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
)
async def send_command(
self,
command_name: str,
args: Optional[dict] = None,
display_message: str = "",
storage_message: bool = True,
) -> bool:
"""发送命令消息
Args:
command_name: 命令名称
args: 命令参数
display_message: 显示消息
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
if not self.chat_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
# 构造命令数据
command_data = {"name": command_name, "args": args or {}}
return await send_api.command_to_stream(
command=command_data,
stream_id=self.chat_id,
storage_message=storage_message,
display_message=display_message,
)
async def send_custom(
self,
message_type: str,
content: str | Dict,
typing: bool = False,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""发送自定义类型消息
Args:
message_type: 消息类型,如"video""file""audio"
content: 消息内容
typing: 是否显示正在输入
set_reply: 是否作为回复发送
reply_message: 回复的消息对象set_reply 为 True时必填
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
if not self.chat_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
return await send_api.custom_to_stream(
message_type=message_type,
content=content,
stream_id=self.chat_id,
typing=typing,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
)
async def send_hybrid(
self,
message_tuple_list: List[Tuple[ReplyContentType | str, str]],
typing: bool = False,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""
发送混合类型消息
Args:
message_tuple_list: 包含消息类型和内容的元组列表,格式为 [(内容类型, 内容), ...]
typing: 是否计算打字时间
set_reply: 是否作为回复发送
reply_message: 回复的消息对象
"""
if not self.chat_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
reply_set = ReplySetModel()
reply_set.add_hybrid_content_by_raw(message_tuple_list)
return await send_api.custom_reply_set_to_stream(
reply_set=reply_set,
stream_id=self.chat_id,
typing=typing,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
)
async def send_forward(
self,
messages_list: List[Tuple[str, str, List[Tuple[ReplyContentType | str, str]]] | str],
storage_message: bool = True,
) -> bool:
"""转发消息
Args:
messages_list: 包含消息信息的列表,当传入自行生成的数据时,元素格式为 (sender_id, nickname, 消息体)当传入消息ID时元素格式为 "message_id"
其中消息体的格式为 [(内容类型, 内容), ...]
任意长度的消息都需要使用列表的形式传入
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
if not self.chat_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
reply_set = ReplySetModel()
forward_message_nodes: List[ForwardNode] = []
for message in messages_list:
if isinstance(message, str):
forward_message_node = ForwardNode.construct_as_id_reference(message)
elif isinstance(message, Tuple) and len(message) == 3:
sender_id, nickname, content_list = message
single_node_content_list: List[ReplyContent] = []
for node_content_type, node_content in content_list:
reply_node_content = ReplyContent(content_type=node_content_type, content=node_content)
single_node_content_list.append(reply_node_content)
forward_message_node = ForwardNode.construct_as_created_node(
user_id=sender_id, user_nickname=nickname, content=single_node_content_list
)
else:
logger.warning(f"{self.log_prefix} 转发消息时遇到无效的消息格式: {message}")
continue
forward_message_nodes.append(forward_message_node)
reply_set.add_forward_content(forward_message_nodes)
return await send_api.custom_reply_set_to_stream(
reply_set=reply_set,
stream_id=self.chat_id,
storage_message=storage_message,
set_reply=False,
reply_message=None,
)
async def send_voice(self, audio_base64: str) -> bool:
"""
发送语音消息
Args:
audio_base64: 语音的base64编码
Returns:
bool: 是否发送成功
"""
if not audio_base64:
logger.error(f"{self.log_prefix} 缺少音频内容")
return False
reply_set = ReplySetModel()
reply_set.add_voice_content(audio_base64)
return await send_api.custom_reply_set_to_stream(
reply_set=reply_set,
stream_id=self.chat_id,
storage_message=False,
)
async def store_action_info(
self,
action_build_into_prompt: bool = False,
action_prompt_display: str = "",
action_done: bool = True,
) -> None:
"""存储动作信息到数据库
Args:
action_build_into_prompt: 是否构建到提示中
action_prompt_display: 显示的action提示信息
action_done: action是否完成
"""
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=action_build_into_prompt,
action_prompt_display=action_prompt_display,
action_done=action_done,
thinking_id=self.thinking_id,
action_data=self.action_data,
action_name=self.action_name,
action_reasoning=self.action_reasoning,
)
async def wait_for_new_message(self, timeout: int = 1200) -> Tuple[bool, str]:
"""等待新消息或超时
在loop_start_time之后等待新消息如果没有新消息且没有超时就一直等待。
使用message_api检查self.chat_id对应的聊天中是否有新消息。
Args:
timeout: 超时时间默认1200秒
Returns:
Tuple[bool, str]: (是否收到新消息, 空字符串)
"""
try:
# 获取循环开始时间,如果没有则使用当前时间
loop_start_time = self.action_data.get("loop_start_time", time.time())
logger.info(f"{self.log_prefix} 开始等待新消息... (最长等待: {timeout}秒, 从时间点: {loop_start_time})")
# 确保有有效的chat_id
if not self.chat_id:
logger.error(f"{self.log_prefix} 等待新消息失败: 没有有效的chat_id")
return False, "没有有效的chat_id"
wait_start_time = asyncio.get_event_loop().time()
while True:
# 检查新消息
current_time = time.time()
new_message_count = message_api.count_new_messages(
chat_id=self.chat_id, start_time=loop_start_time, end_time=current_time
)
if new_message_count > 0:
logger.info(f"{self.log_prefix} 检测到{new_message_count}条新消息聊天ID: {self.chat_id}")
return True, ""
# 检查超时
elapsed_time = asyncio.get_event_loop().time() - wait_start_time
if elapsed_time > timeout:
logger.warning(f"{self.log_prefix} 等待新消息超时({timeout}秒)聊天ID: {self.chat_id}")
return False, ""
# 每30秒记录一次等待状态
if int(elapsed_time) % 15 == 0 and int(elapsed_time) > 0:
logger.debug(f"{self.log_prefix} 已等待{int(elapsed_time)}秒,继续等待新消息...")
# 短暂休眠
await asyncio.sleep(0.5)
except asyncio.CancelledError:
logger.info(f"{self.log_prefix} 等待新消息被中断 (CancelledError)")
return False, ""
except Exception as e:
logger.error(f"{self.log_prefix} 等待新消息时发生错误: {e}")
return False, f"等待新消息失败: {str(e)}"
@classmethod
def get_action_info(cls) -> "ActionInfo":
"""从类属性生成ActionInfo
所有信息都从类属性中读取,确保一致性和完整性。
Action类必须定义所有必要的类属性。
Returns:
ActionInfo: 生成的Action信息对象
"""
# 从类属性读取名称,如果没有定义则使用类名自动生成
name = getattr(cls, "action_name", cls.__name__.lower().replace("action", ""))
if "." in name:
logger.error(f"Action名称 '{name}' 包含非法字符 '.',请使用下划线替代")
raise ValueError(f"Action名称 '{name}' 包含非法字符 '.',请使用下划线替代")
# 获取focus_activation_type和normal_activation_type
focus_activation_type = getattr(cls, "focus_activation_type", ActionActivationType.ALWAYS)
_normal_activation_type = getattr(cls, "normal_activation_type", ActionActivationType.ALWAYS)
# 处理activation_type如果插件中声明了就用插件的值否则默认使用focus_activation_type
activation_type = getattr(cls, "activation_type", focus_activation_type)
return ActionInfo(
name=name,
component_type=ComponentType.ACTION,
description=getattr(cls, "action_description", "Action动作"),
activation_type=activation_type,
activation_keywords=getattr(cls, "activation_keywords", []).copy(),
keyword_case_sensitive=getattr(cls, "keyword_case_sensitive", False),
parallel_action=getattr(cls, "parallel_action", True),
random_activation_probability=getattr(cls, "random_activation_probability", 0.0),
# 使用正确的字段名
action_parameters=getattr(cls, "action_parameters", {}).copy(),
action_require=getattr(cls, "action_require", []).copy(),
associated_types=getattr(cls, "associated_types", []).copy(),
)
def get_config(self, key: str, default=None):
"""获取插件配置值,使用嵌套键访问
Args:
key: 配置键名,使用嵌套访问如 "section.subsection.key"
default: 默认值
Returns:
Any: 配置值或默认值
"""
if not self.plugin_config:
return default
# 支持嵌套键访问
keys = key.split(".")
current = self.plugin_config
for k in keys:
if isinstance(current, dict) and k in current:
current = current[k]
else:
return default
return current

View File

@@ -1,387 +0,0 @@
from abc import ABC, abstractmethod
from typing import Dict, Tuple, Optional, TYPE_CHECKING, List
from src.common.logger import get_logger
from src.common.data_models.message_data_model import ReplyContentType, ReplyContent, ReplySetModel, ForwardNode
from src.plugin_system.base.component_types import CommandInfo, ComponentType
from src.chat.message_receive.message import SessionMessage
from src.plugin_system.apis import send_api
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
logger = get_logger("base_command")
class BaseCommand(ABC):
"""Command组件基类
Command是插件的一种组件类型用于处理命令请求
子类可以通过类属性定义命令模式:
- command_pattern: 命令匹配的正则表达式
- command_help: 命令帮助信息
- command_examples: 命令使用示例列表
"""
command_name: str = ""
"""Command组件的名称"""
command_description: str = ""
"""Command组件的描述"""
# 默认命令设置
command_pattern: str = r""
"""命令匹配的正则表达式"""
def __init__(self, message: SessionMessage, plugin_config: Optional[dict] = None):
"""初始化Command组件
Args:
message: 接收到的消息对象
plugin_config: 插件配置字典
"""
self.message = message
self.matched_groups: Dict[str, str] = {} # 存储正则表达式匹配的命名组
self.plugin_config = plugin_config or {} # 直接存储插件配置字典
self.log_prefix = "[Command]"
logger.debug(f"{self.log_prefix} Command组件初始化完成")
def set_matched_groups(self, groups: Dict[str, str]) -> None:
"""设置正则表达式匹配的命名组
Args:
groups: 正则表达式匹配的命名组
"""
self.matched_groups = groups
@abstractmethod
async def execute(self) -> Tuple[bool, Optional[str], int]:
"""执行Command的抽象方法子类必须实现
Returns:
Tuple[bool, Optional[str], int]: (是否执行成功, 可选的回复消息, 拦截消息力度0代表不拦截1代表仅不触发回复replyer可见2代表不触发回复replyer不可见)
"""
pass
def get_config(self, key: str, default=None):
"""获取插件配置值,使用嵌套键访问
Args:
key: 配置键名,使用嵌套访问如 "section.subsection.key"
default: 默认值
Returns:
Any: 配置值或默认值
"""
if not self.plugin_config:
return default
# 支持嵌套键访问
keys = key.split(".")
current = self.plugin_config
for k in keys:
if isinstance(current, dict) and k in current:
current = current[k]
else:
return default
return current
async def send_text(
self,
content: str,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""发送回复消息
Args:
content: 回复内容
set_reply: 是否作为回复发送
reply_message: 回复的消息对象当set_reply为True时必填
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
# 获取聊天流信息
session_id = self.message.session_id
if not session_id:
logger.error(f"{self.log_prefix} 缺少session_id")
return False
return await send_api.text_to_stream(
text=content,
stream_id=session_id,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
)
async def send_image(
self,
image_base64: str,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""发送图片
Args:
image_base64: 图片的base64编码
Returns:
bool: 是否发送成功
"""
session_id = self.message.session_id
if not session_id:
logger.error(f"{self.log_prefix} 缺少session_id")
return False
return await send_api.image_to_stream(
image_base64,
session_id,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
)
async def send_emoji(
self,
emoji_base64: str,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""发送表情包
Args:
emoji_base64: 表情包的base64编码
set_reply: 是否作为回复发送
reply_message: 回复的消息对象当set_reply为True时必填
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
session_id = self.message.session_id
if not session_id:
logger.error(f"{self.log_prefix} 缺少session_id")
return False
return await send_api.emoji_to_stream(
emoji_base64, session_id, set_reply=set_reply, reply_message=reply_message
)
async def send_command(
self,
command_name: str,
args: Optional[dict] = None,
display_message: str = "",
storage_message: bool = True,
) -> bool:
"""发送命令消息
Args:
command_name: 命令名称
args: 命令参数
display_message: 显示消息
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
try:
# 获取聊天流信息
session_id = self.message.session_id
if not session_id:
logger.error(f"{self.log_prefix} 缺少session_id")
return False
# 构造命令数据
command_data = {"name": command_name, "args": args or {}}
success = await send_api.command_to_stream(
command=command_data,
stream_id=session_id,
storage_message=storage_message,
display_message=display_message,
)
if success:
logger.info(f"{self.log_prefix} 成功发送命令: {command_name}")
else:
logger.error(f"{self.log_prefix} 发送命令失败: {command_name}")
return success
except Exception as e:
logger.error(f"{self.log_prefix} 发送命令时出错: {e}")
return False
async def send_voice(self, voice_base64: str) -> bool:
"""
发送语音消息
Args:
voice_base64: 语音的base64编码
Returns:
bool: 是否发送成功
"""
session_id = self.message.session_id
if not session_id:
logger.error(f"{self.log_prefix} 缺少session_id")
return False
return await send_api.custom_to_stream(
message_type="voice",
content=voice_base64,
stream_id=session_id,
typing=False,
set_reply=False,
reply_message=None,
storage_message=False,
)
async def send_hybrid(
self,
message_tuple_list: List[Tuple[ReplyContentType | str, str]],
typing: bool = False,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""
发送混合类型消息
Args:
message_tuple_list: 包含消息类型和内容的元组列表,格式为 [(内容类型, 内容), ...]
typing: 是否显示正在输入
set_reply: 是否计算打字时间
reply_message: 回复的消息对象
storage_message: 是否存储消息到数据库
"""
session_id = self.message.session_id
if not session_id:
logger.error(f"{self.log_prefix} 缺少session_id")
return False
reply_set = ReplySetModel()
reply_set.add_hybrid_content_by_raw(message_tuple_list)
return await send_api.custom_reply_set_to_stream(
reply_set=reply_set,
stream_id=session_id,
typing=typing,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
)
async def send_forward(
self,
messages_list: List[Tuple[str, str, List[Tuple[ReplyContentType | str, str]]] | str],
storage_message: bool = True,
) -> bool:
"""转发消息
Args:
messages_list: 包含消息信息的列表,当传入自行生成的数据时,元素格式为 (sender_id, nickname, 消息体)当传入消息ID时元素格式为 "message_id"
其中消息体的格式为 [(内容类型, 内容), ...]
任意长度的消息都需要使用列表的形式传入
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
session_id = self.message.session_id
if not session_id:
logger.error(f"{self.log_prefix} 缺少session_id")
return False
reply_set = ReplySetModel()
forward_message_nodes: List[ForwardNode] = []
for message in messages_list:
if isinstance(message, str):
forward_message_node = ForwardNode.construct_as_id_reference(message)
elif isinstance(message, Tuple) and len(message) == 3:
sender_id, nickname, content_list = message
single_node_content_list: List[ReplyContent] = []
for node_content_type, node_content in content_list:
reply_node_content = ReplyContent(content_type=node_content_type, content=node_content)
single_node_content_list.append(reply_node_content)
forward_message_node = ForwardNode.construct_as_created_node(
user_id=sender_id, user_nickname=nickname, content=single_node_content_list
)
else:
logger.warning(f"{self.log_prefix} 转发消息时遇到无效的消息格式: {message}")
continue
forward_message_nodes.append(forward_message_node)
reply_set.add_forward_content(forward_message_nodes)
return await send_api.custom_reply_set_to_stream(
reply_set=reply_set,
stream_id=session_id,
storage_message=storage_message,
set_reply=False,
reply_message=None,
)
async def send_custom(
self,
message_type: str,
content: str | Dict,
display_message: str = "",
typing: bool = False,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""发送指定类型的回复消息到当前聊天环境
Args:
message_type: 消息类型,如"text""image""emoji""voice"
content: 消息内容
display_message: 显示消息(可选)
typing: 是否显示正在输入
set_reply: 是否作为回复发送
reply_message: 回复的消息对象set_reply 为 True时必填
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
# 获取聊天流信息
session_id = self.message.session_id
if not session_id:
logger.error(f"{self.log_prefix} 缺少session_id")
return False
return await send_api.custom_to_stream(
message_type=message_type,
content=content,
stream_id=session_id,
display_message=display_message,
typing=typing,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
)
@classmethod
def get_command_info(cls) -> "CommandInfo":
"""从类属性生成CommandInfo
Args:
name: Command名称如果不提供则使用类名
description: Command描述如果不提供则使用类文档字符串
Returns:
CommandInfo: 生成的Command信息对象
"""
if "." in cls.command_name:
logger.error(f"Command名称 '{cls.command_name}' 包含非法字符 '.',请使用下划线替代")
raise ValueError(f"Command名称 '{cls.command_name}' 包含非法字符 '.',请使用下划线替代")
return CommandInfo(
name=cls.command_name,
component_type=ComponentType.COMMAND,
description=cls.command_description,
command_pattern=cls.command_pattern,
)

View File

@@ -1,381 +0,0 @@
from abc import ABC, abstractmethod
from typing import Tuple, Optional, Dict, List, TYPE_CHECKING
from src.common.logger import get_logger
from src.common.data_models.message_data_model import ReplyContentType, ReplySetModel, ReplyContent, ForwardNode
from src.plugin_system.apis import send_api
from .component_types import MaiMessages, EventType, EventHandlerInfo, ComponentType, CustomEventHandlerResult
logger = get_logger("base_event_handler")
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
class BaseEventHandler(ABC):
"""事件处理器基类
所有事件处理器都应该继承这个基类,提供事件处理的基本接口
"""
event_type: EventType | str = EventType.UNKNOWN
"""事件类型,默认为未知"""
handler_name: str = ""
"""处理器名称"""
handler_description: str = ""
"""处理器描述"""
weight: int = 0
"""处理器权重,越大权重越高"""
intercept_message: bool = False
"""是否拦截消息,默认为否"""
def __init__(self):
self.log_prefix = "[EventHandler]"
self.plugin_name = ""
"""对应插件名"""
self.plugin_config: Optional[Dict] = None
"""插件配置字典"""
if self.event_type == EventType.UNKNOWN:
raise NotImplementedError("事件处理器必须指定 event_type")
@abstractmethod
async def execute(
self, message: MaiMessages | None
) -> Tuple[bool, bool, Optional[str], Optional[CustomEventHandlerResult], Optional[MaiMessages]]:
"""执行事件处理的抽象方法,子类必须实现
Args:
message (MaiMessages | None): 事件消息对象当你注册的事件为ON_START和ON_STOP时message为None
Returns:
Tuple[bool, bool, Optional[str], Optional[CustomEventHandlerResult], Optional[MaiMessages]]: (是否执行成功, 是否需要继续处理, 可选的返回消息, 可选的自定义结果,可选的修改后消息)
"""
raise NotImplementedError("子类必须实现 execute 方法")
@classmethod
def get_handler_info(cls) -> "EventHandlerInfo":
"""获取事件处理器的信息"""
# 从类属性读取名称如果没有定义则使用类名自动生成S
name: str = getattr(cls, "handler_name", cls.__name__.lower().replace("handler", ""))
if "." in name:
logger.error(f"事件处理器名称 '{name}' 包含非法字符 '.',请使用下划线替代")
raise ValueError(f"事件处理器名称 '{name}' 包含非法字符 '.',请使用下划线替代")
return EventHandlerInfo(
name=name,
component_type=ComponentType.EVENT_HANDLER,
description=getattr(cls, "handler_description", "events处理器"),
event_type=cls.event_type,
weight=cls.weight,
intercept_message=cls.intercept_message,
)
def set_plugin_config(self, plugin_config: Dict) -> None:
"""设置插件配置
Args:
plugin_config (dict): 插件配置字典
"""
self.plugin_config = plugin_config
def set_plugin_name(self, plugin_name: str) -> None:
"""设置插件名称
Args:
plugin_name (str): 插件名称
"""
self.plugin_name = plugin_name
def get_config(self, key: str, default=None):
"""获取插件配置值,支持嵌套键访问
Args:
key: 配置键名,支持嵌套访问如 "section.subsection.key"
default: 默认值
Returns:
Any: 配置值或默认值
"""
if not self.plugin_config:
return default
# 支持嵌套键访问
keys = key.split(".")
current = self.plugin_config
for k in keys:
if isinstance(current, dict) and k in current:
current = current[k]
else:
return default
return current
async def send_text(
self,
stream_id: str,
text: str,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
typing: bool = False,
storage_message: bool = True,
) -> bool:
"""发送文本消息
Args:
stream_id: 聊天ID
text: 文本内容
set_reply: 是否作为回复发送
reply_message: 回复的消息对象当set_reply为True时必填
typing: 是否计算输入时间
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
if not stream_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
return await send_api.text_to_stream(
text=text,
stream_id=stream_id,
set_reply=set_reply,
reply_message=reply_message,
typing=typing,
storage_message=storage_message,
)
async def send_emoji(
self,
stream_id: str,
emoji_base64: str,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""发送表情消息
Args:
emoji_base64: 表情的Base64编码
stream_id: 聊天ID
set_reply: 是否作为回复发送
reply_message: 回复的消息对象当set_reply为True时必填
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
if not stream_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
return await send_api.emoji_to_stream(
emoji_base64=emoji_base64,
stream_id=stream_id,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
)
async def send_image(
self,
stream_id: str,
image_base64: str,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""发送图片消息
Args:
image_base64: 图片的Base64编码
stream_id: 聊天ID
set_reply: 是否作为回复发送
reply_message: 回复的消息对象当set_reply为True时必填
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
if not stream_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
return await send_api.image_to_stream(
image_base64=image_base64,
stream_id=stream_id,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
)
async def send_voice(
self,
stream_id: str,
audio_base64: str,
) -> bool:
"""发送语音消息
Args:
stream_id: 聊天ID
audio_base64: 语音的Base64编码
Returns:
bool: 是否发送成功
"""
if not stream_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
reply_set = ReplySetModel()
reply_set.add_voice_content(audio_base64)
return await send_api.custom_reply_set_to_stream(
reply_set=reply_set,
stream_id=stream_id,
storage_message=False,
)
async def send_command(
self,
stream_id: str,
command_name: str,
command_args: Optional[dict] = None,
display_message: str = "",
storage_message: bool = True,
) -> bool:
"""发送命令消息
Args:
stream_id: 流ID
command_name: 命令名称
command_args: 命令参数字典
display_message: 显示消息
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
if not stream_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
# 构造命令数据
command_data = {"name": command_name, "args": command_args or {}}
return await send_api.command_to_stream(
command=command_data,
stream_id=stream_id,
storage_message=storage_message,
display_message=display_message,
)
async def send_custom(
self,
stream_id: str,
message_type: str,
content: str | Dict,
typing: bool = False,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""发送自定义消息
Args:
stream_id: 聊天ID
message_type: 消息类型
content: 消息内容,可以是字符串或字典
typing: 是否显示正在输入状态
set_reply: 是否作为回复发送
reply_message: 回复的消息对象当set_reply为True时必填
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
if not stream_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
return await send_api.custom_to_stream(
message_type=message_type,
content=content,
stream_id=stream_id,
typing=typing,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
)
async def send_hybrid(
self,
stream_id: str,
message_tuple_list: List[Tuple[ReplyContentType | str, str]],
typing: bool = False,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""
发送混合类型消息
Args:
stream_id: 流ID
message_tuple_list: 包含消息类型和内容的元组列表,格式为 [(内容类型, 内容), ...]
typing: 是否计算打字时间
set_reply: 是否作为回复发送
reply_message: 回复的消息对象
storage_message: 是否存储消息到数据库
"""
if not stream_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
reply_set = ReplySetModel()
reply_set.add_hybrid_content_by_raw(message_tuple_list)
return await send_api.custom_reply_set_to_stream(
reply_set=reply_set,
stream_id=stream_id,
typing=typing,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
)
async def send_forward(
self,
stream_id: str,
messages_list: List[Tuple[str, str, List[Tuple[ReplyContentType | str, str]]] | str],
storage_message: bool = True,
) -> bool:
"""转发消息
Args:
stream_id: 聊天ID
messages_list: 包含消息信息的列表,当传入自行生成的数据时,元素格式为 (sender_id, nickname, 消息体)当传入消息ID时元素格式为 "message_id"
其中消息体的格式为 [(内容类型, 内容), ...]
任意长度的消息都需要使用列表的形式传入
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
if not stream_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
reply_set = ReplySetModel()
forward_message_nodes: List[ForwardNode] = []
for message in messages_list:
if isinstance(message, str):
forward_message_node = ForwardNode.construct_as_id_reference(message)
elif isinstance(message, Tuple) and len(message) == 3:
sender_id, nickname, content_list = message
single_node_content_list: List[ReplyContent] = []
for node_content_type, node_content in content_list:
reply_node_content = ReplyContent(content_type=node_content_type, content=node_content)
single_node_content_list.append(reply_node_content)
forward_message_node = ForwardNode.construct_as_created_node(
user_id=sender_id, user_nickname=nickname, content=single_node_content_list
)
else:
logger.warning(f"{self.log_prefix} 转发消息时遇到无效的消息格式: {message}")
continue
forward_message_nodes.append(forward_message_node)
reply_set.add_forward_content(forward_message_nodes)
return await send_api.custom_reply_set_to_stream(
reply_set=reply_set,
stream_id=stream_id,
storage_message=storage_message,
set_reply=False,
reply_message=None,
)

View File

@@ -1,102 +0,0 @@
from abc import abstractmethod
from typing import Any, Callable, List, Type, Tuple, Union
from .plugin_base import PluginBase
from src.common.logger import get_logger
from src.plugin_system.base.component_types import ActionInfo, CommandInfo, EventHandlerInfo, ToolInfo
from src.plugin_system.base.workflow_types import WorkflowStepInfo
from .base_action import BaseAction
from .base_command import BaseCommand
from .base_events_handler import BaseEventHandler
from .base_tool import BaseTool
logger = get_logger("base_plugin")
class BasePlugin(PluginBase):
"""基于Action和Command的插件基类
所有上述类型的插件都应该继承这个基类,一个插件可以包含多种组件:
- Action组件处理聊天中的动作
- Command组件处理命令请求
- 未来可扩展Scheduler、Listener等
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@abstractmethod
def get_plugin_components(
self,
) -> List[
Union[
Tuple[ActionInfo, Type[BaseAction]],
Tuple[CommandInfo, Type[BaseCommand]],
Tuple[EventHandlerInfo, Type[BaseEventHandler]],
Tuple[ToolInfo, Type[BaseTool]],
]
]:
"""获取插件包含的组件列表
子类必须实现此方法,返回组件信息和组件类的列表
Returns:
List[tuple[ComponentInfo, Type]]: [(组件信息, 组件类), ...]
"""
raise NotImplementedError("Subclasses must implement this method")
def get_workflow_steps(self) -> List[Tuple[WorkflowStepInfo, Callable[..., Any]]]:
"""获取插件包含的workflow steps。
默认返回空列表,子类可按需覆盖。
"""
return []
def register_plugin(self) -> bool:
"""注册插件及其所有组件"""
from src.plugin_system.core.component_registry import component_registry
components = self.get_plugin_components()
# 检查依赖
if not self._check_dependencies():
logger.error(f"{self.log_prefix} 依赖检查失败,跳过注册")
return False
# 注册所有组件
registered_components = []
for component_info, component_class in components:
component_info.plugin_name = self.plugin_name
if component_registry.register_component(component_info, component_class):
registered_components.append(component_info)
else:
logger.warning(f"{self.log_prefix} 组件 {component_info.name} 注册失败")
# 更新插件信息中的组件列表
self.plugin_info.components = registered_components
# 注册插件
if component_registry.register_plugin(self.plugin_info):
# 注册workflow steps可选
registered_step_count = 0
for step_info, step_handler in self.get_workflow_steps():
if not step_info.plugin_name:
step_info.plugin_name = self.plugin_name
elif step_info.plugin_name != self.plugin_name:
logger.warning(
f"{self.log_prefix} workflow step {step_info.name} 的plugin_name({step_info.plugin_name})与当前插件不一致,已覆盖为 {self.plugin_name}"
)
step_info.plugin_name = self.plugin_name
if component_registry.register_workflow_step(step_info, step_handler):
registered_step_count += 1
else:
logger.warning(f"{self.log_prefix} workflow step {step_info.full_name} 注册失败")
logger.debug(f"{self.log_prefix} 插件注册成功,包含 {len(registered_components)} 个组件")
if registered_step_count > 0:
logger.debug(f"{self.log_prefix} workflow steps 注册成功,数量: {registered_step_count}")
return True
else:
logger.error(f"{self.log_prefix} 插件注册失败")
return False

View File

@@ -1,137 +0,0 @@
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Tuple, TYPE_CHECKING
from rich.traceback import install
from src.common.logger import get_logger
from src.plugin_system.base.component_types import ComponentType, ToolInfo, ToolParamType
if TYPE_CHECKING:
from src.chat.message_receive.chat_manager import BotChatSession
install(extra_lines=3)
logger = get_logger("base_tool")
class BaseTool(ABC):
"""所有工具的基类"""
name: str = ""
"""工具的名称"""
description: str = ""
"""工具的描述"""
parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = []
"""工具的参数定义,为[("param_name", param_type, "description", required, enum_values)]格式
param_name: 参数名称
param_type: 参数类型
description: 参数描述
required: 是否必填
enum_values: 枚举值列表
例如: [("arg1", ToolParamType.STRING, "参数1描述", True, None), ("arg2", ToolParamType.INTEGER, "参数2描述", False, ["1", "2", "3"])]
"""
available_for_llm: bool = False
"""是否可供LLM使用"""
def __init__(self, plugin_config: Optional[dict] = None, chat_stream: Optional["BotChatSession"] = None):
"""初始化工具基类
Args:
plugin_config: 插件配置字典
chat_stream: 聊天流对象,用于获取聊天上下文信息
"""
self.plugin_config = plugin_config or {} # 直接存储插件配置字典
# =============================================================================
# 便捷属性 - 直接在初始化时获取常用聊天信息与BaseAction保持一致
# =============================================================================
# 获取聊天流对象
self.chat_stream = chat_stream
self.chat_id = self.chat_stream.session_id if self.chat_stream else None
self.platform = getattr(self.chat_stream, "platform", None) if self.chat_stream else None
@classmethod
def get_tool_definition(cls) -> dict[str, Any]:
"""获取工具定义用于LLM工具调用
Returns:
dict: 工具定义字典
"""
if not cls.name or not cls.description or cls.parameters is None:
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
return {"name": cls.name, "description": cls.description, "parameters": cls.parameters}
@classmethod
def get_tool_info(cls) -> ToolInfo:
"""获取工具信息"""
if not cls.name or not cls.description or cls.parameters is None:
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
return ToolInfo(
name=cls.name,
tool_description=cls.description,
enabled=cls.available_for_llm,
tool_parameters=cls.parameters,
component_type=ComponentType.TOOL,
)
@abstractmethod
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
"""执行工具函数(供llm调用)
通过该方法maicore会通过llm的tool call来调用工具
传入的是json格式的参数符合parameters定义的格式
Args:
function_args: 工具调用参数
Returns:
dict: 工具执行结果
"""
raise NotImplementedError("子类必须实现execute方法")
async def direct_execute(self, **function_args: dict[str, Any]) -> dict[str, Any]:
"""直接执行工具函数(供插件调用)
通过该方法,插件可以直接调用工具,而不需要传入字典格式的参数
插件可以直接调用此方法,用更加明了的方式传入参数
示例: result = await tool.direct_execute(arg1="参数",arg2="参数2")
工具开发者可以重写此方法以实现与llm调用差异化的执行逻辑
Args:
**function_args: 工具调用参数
Returns:
dict: 工具执行结果
"""
parameter_required = [param[0] for param in self.parameters if param[3]] # 获取所有必填参数名
for param_name in parameter_required:
if param_name not in function_args:
raise ValueError(f"工具类 {self.__class__.__name__} 缺少必要参数: {param_name}")
return await self.execute(function_args)
def get_config(self, key: str, default=None):
"""获取插件配置值,使用嵌套键访问
Args:
key: 配置键名,使用嵌套访问如 "section.subsection.key"
default: 默认值
Returns:
Any: 配置值或默认值
"""
if not self.plugin_config:
return default
# 支持嵌套键访问
keys = key.split(".")
current = self.plugin_config
for k in keys:
if isinstance(current, dict) and k in current:
current = current[k]
else:
return default
return current

View File

@@ -1,761 +0,0 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Any, Union
import os
import inspect
import toml
import json
import shutil
import datetime
import re
from src.common.logger import get_logger
from src.plugin_system.base.component_types import (
PluginInfo,
PythonDependency,
)
from src.plugin_system.base.config_types import (
ConfigField,
ConfigSection,
ConfigLayout,
)
from src.plugin_system.utils.manifest_utils import ManifestValidator, VersionComparator
logger = get_logger("plugin_base")
class PluginBase(ABC):
"""插件总基类
所有衍生插件基类都应该继承自此类,这个类定义了插件的基本结构和行为。
"""
# 插件基本信息(子类必须定义)
@property
@abstractmethod
def plugin_name(self) -> str:
return "" # 插件内部标识符(如 "hello_world_plugin"
@property
@abstractmethod
def enable_plugin(self) -> bool:
return True # 是否启用插件
@property
@abstractmethod
def dependencies(self) -> List[Union[str, Dict[str, Any]]]:
return [] # 依赖的其他插件
@property
@abstractmethod
def python_dependencies(self) -> List[PythonDependency]:
return [] # Python包依赖
@property
@abstractmethod
def config_file_name(self) -> str:
return "" # 配置文件名
# manifest文件相关
manifest_file_name: str = "_manifest.json" # manifest文件名
manifest_data: Dict[str, Any] = {} # manifest数据
# 配置定义
@property
@abstractmethod
def config_schema(self) -> Dict[str, Union[Dict[str, ConfigField], str]]:
return {}
config_section_descriptions: Dict[str, Union[str, ConfigSection]] = {}
# 布局配置(可选,不定义则使用自动布局)
config_layout: ConfigLayout = None
def __init__(self, plugin_dir: str):
"""初始化插件
Args:
plugin_dir: 插件目录路径,由插件管理器传递
"""
self.config: Dict[str, Any] = {} # 插件配置
self.plugin_dir = plugin_dir # 插件目录路径
self.log_prefix = f"[Plugin:{self.plugin_name}]"
# 加载manifest文件
self._load_manifest()
# 验证插件信息
self._validate_plugin_info()
# 加载插件配置
self._load_plugin_config()
# 从manifest获取显示信息
self.display_name = self.get_manifest_info("name", self.plugin_name)
self.plugin_version = self.get_manifest_info("version", "1.0.0")
self.plugin_description = self.get_manifest_info("description", "")
self.plugin_author = self._get_author_name()
# 创建插件信息对象
self.plugin_info = PluginInfo(
name=self.plugin_name,
display_name=self.display_name,
description=self.plugin_description,
version=self.plugin_version,
author=self.plugin_author,
enabled=self.enable_plugin,
is_built_in=False,
config_file=self.config_file_name or "",
dependencies=self._get_dependency_names(),
python_dependencies=self.python_dependencies.copy(),
# manifest相关信息
manifest_data=self.manifest_data.copy(),
license=self.get_manifest_info("license", ""),
homepage_url=self.get_manifest_info("homepage_url", ""),
repository_url=self.get_manifest_info("repository_url", ""),
keywords=self.get_manifest_info("keywords", []).copy() if self.get_manifest_info("keywords") else [],
categories=self.get_manifest_info("categories", []).copy() if self.get_manifest_info("categories") else [],
min_host_version=self.get_manifest_info("host_application.min_version", ""),
max_host_version=self.get_manifest_info("host_application.max_version", ""),
)
logger.debug(f"{self.log_prefix} 插件基类初始化完成")
def _validate_plugin_info(self):
"""验证插件基本信息"""
if not self.plugin_name:
raise ValueError(f"插件类 {self.__class__.__name__} 必须定义 plugin_name")
# 验证manifest中的必需信息
if not self.get_manifest_info("name"):
raise ValueError(f"插件 {self.plugin_name} 的manifest中缺少name字段")
if not self.get_manifest_info("description"):
raise ValueError(f"插件 {self.plugin_name} 的manifest中缺少description字段")
def _load_manifest(self): # sourcery skip: raise-from-previous-error
"""加载manifest文件强制要求"""
if not self.plugin_dir:
raise ValueError(f"{self.log_prefix} 没有插件目录路径无法加载manifest")
manifest_path = os.path.join(self.plugin_dir, self.manifest_file_name)
if not os.path.exists(manifest_path):
error_msg = f"{self.log_prefix} 缺少必需的manifest文件: {manifest_path}"
logger.error(error_msg)
raise FileNotFoundError(error_msg)
try:
with open(manifest_path, "r", encoding="utf-8") as f:
self.manifest_data = json.load(f)
logger.debug(f"{self.log_prefix} 成功加载manifest文件: {manifest_path}")
# 验证manifest格式
self._validate_manifest()
except json.JSONDecodeError as e:
error_msg = f"{self.log_prefix} manifest文件格式错误: {e}"
logger.error(error_msg)
raise ValueError(error_msg) # noqa
except IOError as e:
error_msg = f"{self.log_prefix} 读取manifest文件失败: {e}"
logger.error(error_msg)
raise IOError(error_msg) # noqa
def _get_author_name(self) -> str:
"""从manifest获取作者名称"""
author_info = self.get_manifest_info("author", {})
if isinstance(author_info, dict):
return author_info.get("name", "")
else:
return str(author_info) if author_info else ""
def _validate_manifest(self):
"""验证manifest文件格式使用强化的验证器"""
if not self.manifest_data:
raise ValueError(f"{self.log_prefix} manifest数据为空验证失败")
validator = ManifestValidator()
is_valid = validator.validate_manifest(self.manifest_data)
# 记录验证结果
if validator.validation_errors or validator.validation_warnings:
report = validator.get_validation_report()
logger.info(f"{self.log_prefix} Manifest验证结果:\n{report}")
# 如果有验证错误,抛出异常
if not is_valid:
error_msg = f"{self.log_prefix} Manifest文件验证失败"
if validator.validation_errors:
error_msg += f": {'; '.join(validator.validation_errors)}"
raise ValueError(error_msg)
def get_manifest_info(self, key: str, default: Any = None) -> Any:
"""获取manifest信息
Args:
key: 信息键,支持点分割的嵌套键(如 "author.name"
default: 默认值
Returns:
Any: 对应的值
"""
if not self.manifest_data:
return default
keys = key.split(".")
value = self.manifest_data
for k in keys:
if isinstance(value, dict) and k in value:
value = value[k]
else:
return default
return value
def _format_toml_value(self, value: Any) -> str:
"""将Python值格式化为合法的TOML字符串"""
if isinstance(value, str):
return json.dumps(value, ensure_ascii=False)
if isinstance(value, bool):
return str(value).lower()
if isinstance(value, (int, float)):
return str(value)
if isinstance(value, list):
inner = ", ".join(self._format_toml_value(item) for item in value)
return f"[{inner}]"
if isinstance(value, dict):
items = [f"{k} = {self._format_toml_value(v)}" for k, v in value.items()]
return "{ " + ", ".join(items) + " }"
return json.dumps(value, ensure_ascii=False)
def _generate_and_save_default_config(self, config_file_path: str):
"""根据插件的Schema生成并保存默认配置文件"""
if not self.config_schema:
logger.debug(f"{self.log_prefix} 插件未定义config_schema不生成配置文件")
return
toml_str = f"# {self.plugin_name} - 自动生成的配置文件\n"
plugin_description = self.get_manifest_info("description", "插件配置文件")
toml_str += f"# {plugin_description}\n\n"
# 遍历每个配置节
for section, fields in self.config_schema.items():
# 添加节描述
if section in self.config_section_descriptions:
toml_str += f"# {self.config_section_descriptions[section]}\n"
toml_str += f"[{section}]\n\n"
# 遍历节内的字段
if isinstance(fields, dict):
for field_name, field in fields.items():
if isinstance(field, ConfigField):
# 添加字段描述
toml_str += f"# {field.description}"
if field.required:
toml_str += " (必需)"
toml_str += "\n"
# 如果有示例值,添加示例
if field.example:
toml_str += f"# 示例: {field.example}\n"
# 如果有可选值,添加说明
if field.choices:
choices_str = ", ".join(map(str, field.choices))
toml_str += f"# 可选值: {choices_str}\n"
# 添加字段值
value = field.default
toml_str += f"{field_name} = {self._format_toml_value(value)}\n"
toml_str += "\n"
toml_str += "\n"
try:
with open(config_file_path, "w", encoding="utf-8") as f:
f.write(toml_str)
logger.info(f"{self.log_prefix} 已生成默认配置文件: {config_file_path}")
except IOError as e:
logger.error(f"{self.log_prefix} 保存默认配置文件失败: {e}", exc_info=True)
def _get_expected_config_version(self) -> str:
"""获取插件期望的配置版本号"""
# 从config_schema的plugin.config_version字段获取
if "plugin" in self.config_schema and isinstance(self.config_schema["plugin"], dict):
config_version_field = self.config_schema["plugin"].get("config_version")
if isinstance(config_version_field, ConfigField):
return config_version_field.default
return "1.0.0"
def _get_current_config_version(self, config: Dict[str, Any]) -> str:
"""从配置文件中获取当前版本号"""
if "plugin" in config and "config_version" in config["plugin"]:
return str(config["plugin"]["config_version"])
# 如果没有config_version字段视为最早的版本
return "0.0.0"
def _backup_config_file(self, config_file_path: str) -> str:
"""备份配置文件"""
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
backup_path = f"{config_file_path}.backup_{timestamp}"
try:
shutil.copy2(config_file_path, backup_path)
logger.info(f"{self.log_prefix} 配置文件已备份到: {backup_path}")
return backup_path
except Exception as e:
logger.error(f"{self.log_prefix} 备份配置文件失败: {e}")
return ""
def _migrate_config_values(self, old_config: Dict[str, Any], new_config: Dict[str, Any]) -> Dict[str, Any]:
"""将旧配置值迁移到新配置结构中
Args:
old_config: 旧配置数据
new_config: 基于新schema生成的默认配置
Returns:
Dict[str, Any]: 迁移后的配置
"""
def migrate_section(
old_section: Dict[str, Any], new_section: Dict[str, Any], section_name: str
) -> Dict[str, Any]:
"""迁移单个配置节"""
result = new_section.copy()
for key, value in old_section.items():
if key in new_section:
# 特殊处理config_version字段总是使用新版本
if section_name == "plugin" and key == "config_version":
# 保持新的版本号,不迁移旧值
logger.debug(
f"{self.log_prefix} 更新配置版本: {section_name}.{key} = {result[key]} (旧值: {value})"
)
continue
# 键存在于新配置中,复制值
if isinstance(value, dict) and isinstance(new_section[key], dict):
# 递归处理嵌套字典
result[key] = migrate_section(value, new_section[key], f"{section_name}.{key}")
else:
result[key] = value
logger.debug(f"{self.log_prefix} 迁移配置: {section_name}.{key} = {value}")
else:
# 键在新配置中不存在,记录警告
logger.warning(f"{self.log_prefix} 配置项 {section_name}.{key} 在新版本中已被移除")
return result
migrated_config = {}
# 迁移每个配置节
for section_name, new_section_data in new_config.items():
if (
section_name in old_config
and isinstance(old_config[section_name], dict)
and isinstance(new_section_data, dict)
):
migrated_config[section_name] = migrate_section(
old_config[section_name], new_section_data, section_name
)
else:
# 新增的节或类型不匹配,使用默认值
migrated_config[section_name] = new_section_data
if section_name in old_config:
logger.warning(f"{self.log_prefix} 配置节 {section_name} 结构已改变,使用默认值")
# 检查旧配置中是否有新配置没有的节
for section_name in old_config:
if section_name not in migrated_config:
logger.warning(f"{self.log_prefix} 配置节 {section_name} 在新版本中已被移除")
return migrated_config
def _generate_config_from_schema(self) -> Dict[str, Any]:
# sourcery skip: dict-comprehension
"""根据schema生成配置数据结构不写入文件"""
if not self.config_schema:
return {}
config_data = {}
# 遍历每个配置节
for section, fields in self.config_schema.items():
if isinstance(fields, dict):
section_data = {}
# 遍历节内的字段
for field_name, field in fields.items():
if isinstance(field, ConfigField):
section_data[field_name] = field.default
config_data[section] = section_data
return config_data
def _save_config_to_file(self, config_data: Dict[str, Any], config_file_path: str):
"""将配置数据保存为TOML文件包含注释"""
if not self.config_schema:
logger.debug(f"{self.log_prefix} 插件未定义config_schema不生成配置文件")
return
toml_str = f"# {self.plugin_name} - 配置文件\n"
plugin_description = self.get_manifest_info("description", "插件配置文件")
toml_str += f"# {plugin_description}\n"
# 获取当前期望的配置版本
expected_version = self._get_expected_config_version()
toml_str += f"# 配置版本: {expected_version}\n\n"
# 遍历每个配置节
for section, fields in self.config_schema.items():
# 添加节描述
if section in self.config_section_descriptions:
toml_str += f"# {self.config_section_descriptions[section]}\n"
toml_str += f"[{section}]\n\n"
# 遍历节内的字段
if isinstance(fields, dict) and section in config_data:
section_data = config_data[section]
for field_name, field in fields.items():
if isinstance(field, ConfigField):
# 添加字段描述
toml_str += f"# {field.description}"
if field.required:
toml_str += " (必需)"
toml_str += "\n"
# 如果有示例值,添加示例
if field.example:
toml_str += f"# 示例: {field.example}\n"
# 如果有可选值,添加说明
if field.choices:
choices_str = ", ".join(map(str, field.choices))
toml_str += f"# 可选值: {choices_str}\n"
# 添加字段值(使用迁移后的值)
value = section_data.get(field_name, field.default)
toml_str += f"{field_name} = {self._format_toml_value(value)}\n"
toml_str += "\n"
toml_str += "\n"
try:
with open(config_file_path, "w", encoding="utf-8") as f:
f.write(toml_str)
logger.info(f"{self.log_prefix} 配置文件已保存: {config_file_path}")
except IOError as e:
logger.error(f"{self.log_prefix} 保存配置文件失败: {e}", exc_info=True)
def _load_plugin_config(self): # sourcery skip: extract-method
"""加载插件配置文件,支持版本检查和自动迁移"""
if not self.config_file_name:
logger.debug(f"{self.log_prefix} 未指定配置文件,跳过加载")
return
# 优先使用传入的插件目录路径
if self.plugin_dir:
plugin_dir = self.plugin_dir
else:
# fallback尝试从类的模块信息获取路径
try:
plugin_module_path = inspect.getfile(self.__class__)
plugin_dir = os.path.dirname(plugin_module_path)
except (TypeError, OSError):
# 最后的fallback从模块的__file__属性获取
module = inspect.getmodule(self.__class__)
if module and hasattr(module, "__file__") and module.__file__:
plugin_dir = os.path.dirname(module.__file__)
else:
logger.warning(f"{self.log_prefix} 无法获取插件目录路径,跳过配置加载")
return
config_file_path = os.path.join(plugin_dir, self.config_file_name)
# 如果配置文件不存在,生成默认配置
if not os.path.exists(config_file_path):
logger.info(f"{self.log_prefix} 配置文件 {config_file_path} 不存在,将生成默认配置。")
self._generate_and_save_default_config(config_file_path)
if not os.path.exists(config_file_path):
logger.warning(f"{self.log_prefix} 配置文件 {config_file_path} 不存在且无法生成。")
return
file_ext = os.path.splitext(self.config_file_name)[1].lower()
if file_ext == ".toml":
# 加载现有配置
with open(config_file_path, "r", encoding="utf-8") as f:
existing_config = toml.load(f) or {}
# 检查配置版本
current_version = self._get_current_config_version(existing_config)
# 如果配置文件没有版本信息,跳过版本检查
if current_version == "0.0.0":
logger.debug(f"{self.log_prefix} 配置文件无版本信息,跳过版本检查")
self.config = existing_config
else:
expected_version = self._get_expected_config_version()
if current_version != expected_version:
logger.info(
f"{self.log_prefix} 检测到配置版本需要更新: 当前=v{current_version}, 期望=v{expected_version}"
)
# 生成新的默认配置结构
new_config_structure = self._generate_config_from_schema()
# 迁移旧配置值到新结构
migrated_config = self._migrate_config_values(existing_config, new_config_structure)
# 保存迁移后的配置
self._save_config_to_file(migrated_config, config_file_path)
logger.info(f"{self.log_prefix} 配置文件已从 v{current_version} 更新到 v{expected_version}")
self.config = migrated_config
else:
logger.debug(f"{self.log_prefix} 配置版本匹配 (v{current_version}),直接加载")
self.config = existing_config
logger.debug(f"{self.log_prefix} 配置已从 {config_file_path} 加载")
# 从配置中更新 enable_plugin
if "plugin" in self.config and "enabled" in self.config["plugin"]:
self.enable_plugin = self.config["plugin"]["enabled"] # type: ignore
logger.debug(f"{self.log_prefix} 从配置更新插件启用状态: {self.enable_plugin}")
else:
logger.warning(f"{self.log_prefix} 不支持的配置文件格式: {file_ext},仅支持 .toml")
self.config = {}
def _check_dependencies(self) -> bool:
"""检查插件依赖"""
from src.plugin_system.core.component_registry import component_registry
if not self.dependencies:
return True
for dependency in self.dependencies:
dep_name, version_spec, min_version, max_version = self._parse_dependency(dependency)
if not dep_name:
logger.warning(f"{self.log_prefix} 跳过无效依赖声明: {dependency}")
continue
dep_plugin_info = component_registry.get_plugin_info(dep_name)
if not dep_plugin_info:
logger.error(f"{self.log_prefix} 缺少依赖插件: {dep_name}")
return False
dep_version = dep_plugin_info.version or "0.0.0"
if version_spec:
is_ok, msg = self._is_version_spec_satisfied(dep_version, version_spec)
if not is_ok:
logger.error(
f"{self.log_prefix} 依赖插件版本不满足: {dep_name} {version_spec}, 当前版本={dep_version} ({msg})"
)
return False
if min_version or max_version:
is_ok, msg = VersionComparator.is_version_in_range(dep_version, min_version, max_version)
if not is_ok:
logger.error(
f"{self.log_prefix} 依赖插件版本不满足: {dep_name} 要求区间[{min_version or '-inf'}, {max_version or '+inf'}], 当前版本={dep_version} ({msg})"
)
return False
return True
def _get_dependency_names(self) -> List[str]:
"""获取依赖插件名称列表(用于插件信息展示和统计)。"""
dependency_names: List[str] = []
for dependency in self.dependencies:
dep_name, _, _, _ = self._parse_dependency(dependency)
if dep_name:
dependency_names.append(dep_name)
return dependency_names
def _parse_dependency(self, dependency: Any) -> tuple[str, str, str, str]:
"""解析依赖声明。
支持格式:
- "plugin_a"
- {"name": "plugin_a", "version": ">=1.2.0,<2.0.0"}
- {"name": "plugin_a", "min_version": "1.2.0", "max_version": "2.0.0"}
"""
if isinstance(dependency, str):
return dependency.strip(), "", "", ""
if isinstance(dependency, dict):
dep_name = str(dependency.get("name", "")).strip()
version_spec = str(dependency.get("version", "")).strip()
min_version = str(dependency.get("min_version", "")).strip()
max_version = str(dependency.get("max_version", "")).strip()
return dep_name, version_spec, min_version, max_version
return "", "", "", ""
def _is_version_spec_satisfied(self, version: str, version_spec: str) -> tuple[bool, str]:
"""检查版本是否满足表达式。
支持:==, >=, <=, >, <,可用逗号分隔多个条件。
示例:">=1.2.0,<2.0.0"
"""
normalized_version = VersionComparator.normalize_version(version)
clauses = [clause.strip() for clause in version_spec.split(",") if clause.strip()]
if not clauses:
return True, ""
operators_pattern = r"^(==|>=|<=|>|<)\s*(.+)$"
for clause in clauses:
if not (match := re.match(operators_pattern, clause)):
return False, f"无效版本约束表达式: {clause}"
operator, target_version = match.group(1), VersionComparator.normalize_version(match.group(2))
compare_result = VersionComparator.compare_versions(normalized_version, target_version)
is_satisfied = False
if operator == "==":
is_satisfied = compare_result == 0
elif operator == ">=":
is_satisfied = compare_result >= 0
elif operator == "<=":
is_satisfied = compare_result <= 0
elif operator == ">":
is_satisfied = compare_result > 0
elif operator == "<":
is_satisfied = compare_result < 0
if not is_satisfied:
return False, f"{normalized_version} 不满足约束 {operator}{target_version}"
return True, ""
def get_config(self, key: str, default: Any = None) -> Any:
"""获取插件配置值,支持嵌套键访问
Args:
key: 配置键名,支持嵌套访问如 "section.subsection.key"
default: 默认值
Returns:
Any: 配置值或默认值
"""
# 支持嵌套键访问
keys = key.split(".")
current = self.config
for k in keys:
if isinstance(current, dict) and k in current:
current = current[k]
else:
return default
return current
def get_webui_config_schema(self) -> Dict[str, Any]:
"""
获取 WebUI 配置 Schema
返回完整的配置 schema包含
- 插件基本信息
- 所有 section 及其字段定义
- 布局配置
用于 WebUI 动态生成配置表单。
Returns:
Dict: 完整的配置 schema
"""
schema = {
"plugin_id": self.plugin_name,
"plugin_info": {
"name": self.display_name,
"version": self.plugin_version,
"description": self.plugin_description,
"author": self.plugin_author,
},
"sections": {},
"layout": None,
}
# 处理 sections
for section_name, fields in self.config_schema.items():
if not isinstance(fields, dict):
continue
section_data = {
"name": section_name,
"title": section_name,
"description": None,
"icon": None,
"collapsed": False,
"order": 0,
"fields": {},
}
# 获取 section 元数据
section_meta = self.config_section_descriptions.get(section_name)
if section_meta:
if isinstance(section_meta, str):
section_data["title"] = section_meta
elif isinstance(section_meta, ConfigSection):
section_data["title"] = section_meta.title
section_data["description"] = section_meta.description
section_data["icon"] = section_meta.icon
section_data["collapsed"] = section_meta.collapsed
section_data["order"] = section_meta.order
elif isinstance(section_meta, dict):
section_data.update(section_meta)
# 处理字段
for field_name, field_def in fields.items():
if isinstance(field_def, ConfigField):
field_data = field_def.to_dict()
field_data["name"] = field_name
section_data["fields"][field_name] = field_data
schema["sections"][section_name] = section_data
# 处理布局
if self.config_layout:
schema["layout"] = self.config_layout.to_dict()
else:
# 自动布局:按 section order 排序
schema["layout"] = {
"type": "auto",
"tabs": [],
}
return schema
def get_current_config_values(self) -> Dict[str, Any]:
"""
获取当前配置值
返回插件当前的配置值(已从配置文件加载)。
Returns:
Dict: 当前配置值
"""
return self.config.copy()
@abstractmethod
def register_plugin(self) -> bool:
"""
注册插件到插件管理器
子类必须实现此方法,返回注册是否成功
Returns:
bool: 是否成功注册插件
"""
raise NotImplementedError("Subclasses must implement this method")

View File

@@ -1,22 +0,0 @@
from dataclasses import dataclass, field
from typing import Any, Dict, List
@dataclass
class PluginServiceInfo:
"""插件服务注册信息"""
name: str
plugin_name: str
version: str = "1.0.0"
description: str = ""
enabled: bool = True
public: bool = False
allowed_callers: List[str] = field(default_factory=list)
params_schema: Dict[str, Any] = field(default_factory=dict)
return_schema: Dict[str, Any] = field(default_factory=dict)
metadata: Dict[str, Any] = field(default_factory=dict)
@property
def full_name(self) -> str:
return f"{self.plugin_name}.{self.name}"

View File

@@ -1,14 +0,0 @@
from enum import Enum
class WorkflowErrorCode(Enum):
"""Workflow统一错误码"""
PLUGIN_NOT_READY = "PLUGIN_NOT_READY"
STEP_TIMEOUT = "STEP_TIMEOUT"
BAD_PAYLOAD = "BAD_PAYLOAD"
DOWNSTREAM_FAILED = "DOWNSTREAM_FAILED"
POLICY_BLOCKED = "POLICY_BLOCKED"
def __str__(self) -> str:
return self.value

View File

@@ -1,68 +0,0 @@
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Literal, Optional
import time
class WorkflowStage(Enum):
"""Workflow阶段定义MVP固定阶段"""
INGRESS = "ingress"
PRE_PROCESS = "pre_process"
PLAN = "plan"
TOOL_EXECUTE = "tool_execute"
POST_PROCESS = "post_process"
EGRESS = "egress"
def __str__(self) -> str:
return self.value
@dataclass
class WorkflowContext:
"""Workflow上下文"""
trace_id: str
stream_id: Optional[str] = None
metadata: Dict[str, Any] = field(default_factory=dict)
timings: Dict[str, float] = field(default_factory=dict)
errors: List[str] = field(default_factory=list)
@dataclass
class WorkflowMessage:
"""Workflow消息包装"""
msg_type: str
payload: Dict[str, Any] = field(default_factory=dict)
headers: Dict[str, Any] = field(default_factory=dict)
mutable_flags: Dict[str, bool] = field(default_factory=dict)
@dataclass
class WorkflowStepResult:
"""Workflow步骤结果"""
status: Literal["continue", "stop", "failed"] = "continue"
return_message: Optional[str] = None
diagnostics: Dict[str, Any] = field(default_factory=dict)
events: List[Dict[str, Any]] = field(default_factory=list)
created_at: float = field(default_factory=time.time)
@dataclass
class WorkflowStepInfo:
"""Workflow步骤元数据"""
name: str
stage: WorkflowStage
plugin_name: str
description: str = ""
enabled: bool = True
priority: int = 0
timeout_ms: int = 0
metadata: Dict[str, Any] = field(default_factory=dict)
@property
def full_name(self) -> str:
return f"{self.plugin_name}.{self.name}"

View File

@@ -1,21 +0,0 @@
"""
插件核心管理模块
提供插件的加载、注册和管理功能
"""
from src.plugin_system.core.plugin_manager import plugin_manager
from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.core.events_manager import events_manager
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.plugin_system.core.plugin_service_registry import plugin_service_registry
from src.plugin_system.core.workflow_engine import workflow_engine
__all__ = [
"plugin_manager",
"component_registry",
"events_manager",
"global_announcement_manager",
"plugin_service_registry",
"workflow_engine",
]

View File

@@ -1,758 +0,0 @@
import re
from typing import Callable, Dict, List, Optional, Any, Pattern, Tuple, Union, Type
from src.common.logger import get_logger
from src.plugin_system.base.component_types import (
ComponentInfo,
ActionInfo,
ToolInfo,
CommandInfo,
EventHandlerInfo,
PluginInfo,
ComponentType,
)
from src.plugin_system.base.base_command import BaseCommand
from src.plugin_system.base.base_action import BaseAction
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.base.base_events_handler import BaseEventHandler
from src.plugin_system.base.workflow_types import WorkflowStage, WorkflowStepInfo
logger = get_logger("component_registry")
class ComponentRegistry:
"""统一的组件注册中心
负责管理所有插件组件的注册、查询和生命周期管理
"""
def __init__(self):
# 命名空间式组件名构成法 f"{component_type}.{component_name}"
self._components: Dict[str, ComponentInfo] = {}
"""组件注册表 命名空间式组件名 -> 组件信息"""
self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType}
"""类型 -> 组件原名称 -> 组件信息"""
self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseTool, BaseEventHandler]]] = {}
"""命名空间式组件名 -> 组件类"""
# 插件注册表
self._plugins: Dict[str, PluginInfo] = {}
"""插件名 -> 插件信息"""
# Action特定注册表
self._action_registry: Dict[str, Type[BaseAction]] = {}
"""Action注册表 action名 -> action类"""
self._default_actions: Dict[str, ActionInfo] = {}
"""默认动作集即启用的Action集用于重置ActionManager状态"""
# Command特定注册表
self._command_registry: Dict[str, Type[BaseCommand]] = {}
"""Command类注册表 command名 -> command类"""
self._command_patterns: Dict[Pattern, str] = {}
"""编译后的正则 -> command名"""
# 工具特定注册表
self._tool_registry: Dict[str, Type[BaseTool]] = {} # 工具名 -> 工具类
self._llm_available_tools: Dict[str, Type[BaseTool]] = {} # llm可用的工具名 -> 工具类
# EventHandler特定注册表
self._event_handler_registry: Dict[str, Type[BaseEventHandler]] = {}
"""event_handler名 -> event_handler类"""
self._enabled_event_handlers: Dict[str, Type[BaseEventHandler]] = {}
"""启用的事件处理器 event_handler名 -> event_handler类"""
# Workflow step注册表
self._workflow_steps: Dict[WorkflowStage, Dict[str, WorkflowStepInfo]] = {stage: {} for stage in WorkflowStage}
self._workflow_step_handlers: Dict[str, Callable[..., Any]] = {}
logger.info("组件注册中心初始化完成")
# == 注册方法 ==
def register_plugin(self, plugin_info: PluginInfo) -> bool:
"""注册插件
Args:
plugin_info: 插件信息
Returns:
bool: 是否注册成功
"""
plugin_name = plugin_info.name
if plugin_name in self._plugins:
logger.warning(f"插件 {plugin_name} 已存在,跳过注册")
return False
self._plugins[plugin_name] = plugin_info
logger.debug(f"已注册插件: {plugin_name} (组件数量: {len(plugin_info.components)})")
return True
def register_component(
self,
component_info: ComponentInfo,
component_class: Type[Union[BaseCommand, BaseAction, BaseEventHandler, BaseTool]],
) -> bool:
"""注册组件
Args:
component_info (ComponentInfo): 组件信息
component_class (Type[Union[BaseCommand, BaseAction, BaseEventHandler]]): 组件类
Returns:
bool: 是否注册成功
"""
component_name = component_info.name
component_type = component_info.component_type
plugin_name = getattr(component_info, "plugin_name", "unknown")
if "." in component_name:
logger.error(f"组件名称 '{component_name}' 包含非法字符 '.',请使用下划线替代")
return False
if "." in plugin_name:
logger.error(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
return False
namespaced_name = f"{component_type}.{component_name}"
if namespaced_name in self._components:
existing_info = self._components[namespaced_name]
existing_plugin = getattr(existing_info, "plugin_name", "unknown")
logger.warning(
f"组件名冲突: '{plugin_name}' 插件的 {component_type} 类型组件 '{component_name}' 已被插件 '{existing_plugin}' 注册,跳过此组件注册"
)
return False
self._components[namespaced_name] = component_info # 注册到通用注册表(使用命名空间化的名称)
self._components_by_type[component_type][component_name] = component_info # 类型内部仍使用原名
self._components_classes[namespaced_name] = component_class
# 根据组件类型进行特定注册(使用原始名称)
ret = False
match component_type:
case ComponentType.ACTION:
assert isinstance(component_info, ActionInfo)
assert issubclass(component_class, BaseAction)
ret = self._register_action_component(component_info, component_class)
case ComponentType.COMMAND:
assert isinstance(component_info, CommandInfo)
assert issubclass(component_class, BaseCommand)
ret = self._register_command_component(component_info, component_class)
case ComponentType.TOOL:
assert isinstance(component_info, ToolInfo)
assert issubclass(component_class, BaseTool)
ret = self._register_tool_component(component_info, component_class)
case ComponentType.EVENT_HANDLER:
assert isinstance(component_info, EventHandlerInfo)
assert issubclass(component_class, BaseEventHandler)
ret = self._register_event_handler_component(component_info, component_class)
case _:
logger.warning(f"未知组件类型: {component_type}")
if not ret:
return False
logger.debug(
f"已注册{component_type}组件: '{component_name}' -> '{namespaced_name}' "
f"({component_class.__name__}) [插件: {plugin_name}]"
)
return True
def _register_action_component(self, action_info: ActionInfo, action_class: Type[BaseAction]) -> bool:
"""注册Action组件到Action特定注册表"""
if not (action_name := action_info.name):
logger.error(f"Action组件 {action_class.__name__} 必须指定名称")
return False
if not isinstance(action_info, ActionInfo) or not issubclass(action_class, BaseAction):
logger.error(f"注册失败: {action_name} 不是有效的Action")
return False
self._action_registry[action_name] = action_class
# 如果启用,添加到默认动作集
if action_info.enabled:
self._default_actions[action_name] = action_info
return True
def _register_command_component(self, command_info: CommandInfo, command_class: Type[BaseCommand]) -> bool:
"""注册Command组件到Command特定注册表"""
if not (command_name := command_info.name):
logger.error(f"Command组件 {command_class.__name__} 必须指定名称")
return False
if not isinstance(command_info, CommandInfo) or not issubclass(command_class, BaseCommand):
logger.error(f"注册失败: {command_name} 不是有效的Command")
return False
self._command_registry[command_name] = command_class
# 如果启用了且有匹配模式
if command_info.enabled and command_info.command_pattern:
pattern = re.compile(command_info.command_pattern, re.IGNORECASE | re.DOTALL)
if pattern not in self._command_patterns:
self._command_patterns[pattern] = command_name
else:
logger.warning(
f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令"
)
return True
def _register_tool_component(self, tool_info: ToolInfo, tool_class: Type[BaseTool]) -> bool:
"""注册Tool组件到Tool特定注册表"""
tool_name = tool_info.name
self._tool_registry[tool_name] = tool_class
# 如果是llm可用的且启用的工具,添加到 llm可用工具列表
if tool_info.enabled:
self._llm_available_tools[tool_name] = tool_class
return True
def _register_event_handler_component(
self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]
) -> bool:
if not (handler_name := handler_info.name):
logger.error(f"EventHandler组件 {handler_class.__name__} 必须指定名称")
return False
if not isinstance(handler_info, EventHandlerInfo) or not issubclass(handler_class, BaseEventHandler):
logger.error(f"注册失败: {handler_name} 不是有效的EventHandler")
return False
self._event_handler_registry[handler_name] = handler_class
if not handler_info.enabled:
logger.warning(f"EventHandler组件 {handler_name} 未启用")
return True # 未启用,但是也是注册成功
from .events_manager import events_manager # 延迟导入防止循环导入问题
if events_manager.register_event_subscriber(handler_info, handler_class):
self._enabled_event_handlers[handler_name] = handler_class
return True
else:
logger.error(f"注册事件处理器 {handler_name} 失败")
return False
# === 组件移除相关 ===
async def remove_component(self, component_name: str, component_type: ComponentType, plugin_name: str) -> bool:
target_component_class = self.get_component_class(component_name, component_type)
if not target_component_class:
logger.warning(f"组件 {component_name} 未注册,无法移除")
return False
try:
match component_type:
case ComponentType.ACTION:
self._action_registry.pop(component_name)
self._default_actions.pop(component_name)
case ComponentType.COMMAND:
self._command_registry.pop(component_name)
keys_to_remove = [k for k, v in self._command_patterns.items() if v == component_name]
for key in keys_to_remove:
self._command_patterns.pop(key)
case ComponentType.TOOL:
self._tool_registry.pop(component_name)
self._llm_available_tools.pop(component_name)
case ComponentType.EVENT_HANDLER:
from .events_manager import events_manager # 延迟导入防止循环导入问题
self._event_handler_registry.pop(component_name)
self._enabled_event_handlers.pop(component_name)
await events_manager.unregister_event_subscriber(component_name)
namespaced_name = f"{component_type}.{component_name}"
self._components.pop(namespaced_name)
self._components_by_type[component_type].pop(component_name)
self._components_classes.pop(namespaced_name)
logger.info(f"组件 {component_name} 已移除")
return True
except KeyError as e:
logger.warning(f"移除组件时未找到组件: {component_name}, 发生错误: {e}")
return False
except Exception as e:
logger.error(f"移除组件 {component_name} 时发生错误: {e}")
return False
async def remove_components_by_plugin(self, plugin_name: str) -> int:
"""移除某插件注册的所有组件。"""
targets = [
(component_info.name, component_info.component_type)
for component_info in self._components.values()
if component_info.plugin_name == plugin_name
]
removed_count = 0
for component_name, component_type in targets:
if await self.remove_component(component_name, component_type, plugin_name):
removed_count += 1
if removed_count:
logger.info(f"已移除插件 {plugin_name} 的组件数量: {removed_count}")
return removed_count
def remove_plugin_registry(self, plugin_name: str) -> bool:
"""移除插件注册信息
Args:
plugin_name: 插件名称
Returns:
bool: 是否成功移除
"""
if plugin_name not in self._plugins:
logger.warning(f"插件 {plugin_name} 未注册,无法移除")
return False
del self._plugins[plugin_name]
self.remove_workflow_steps_by_plugin(plugin_name)
logger.info(f"插件 {plugin_name} 已移除")
return True
# === Workflow step 注册与查询 ===
def register_workflow_step(self, step_info: WorkflowStepInfo, step_handler: Callable[..., Any]) -> bool:
"""注册workflow步骤。"""
if not step_info.name or not step_info.plugin_name:
logger.error("workflow step 注册失败: step名称或插件名称为空")
return False
if "." in step_info.name:
logger.error(f"workflow step 名称 '{step_info.name}' 包含非法字符 '.',请使用下划线替代")
return False
if "." in step_info.plugin_name:
logger.error(f"workflow step 所属插件名称 '{step_info.plugin_name}' 包含非法字符 '.',请使用下划线替代")
return False
full_name = step_info.full_name
stage_registry = self._workflow_steps.get(step_info.stage)
if stage_registry is None:
logger.error(f"workflow step 注册失败: 未知阶段 {step_info.stage}")
return False
if full_name in stage_registry:
logger.warning(f"workflow step 已存在,跳过注册: {full_name}")
return False
stage_registry[full_name] = step_info
self._workflow_step_handlers[full_name] = step_handler
logger.debug(f"已注册workflow step: {full_name} @ {step_info.stage}")
return True
def get_steps_by_stage(self, stage: WorkflowStage, enabled_only: bool = False) -> Dict[str, WorkflowStepInfo]:
"""获取某阶段的workflow步骤。"""
steps = self._workflow_steps.get(stage, {})
if enabled_only:
return {name: info for name, info in steps.items() if info.enabled}
return steps.copy()
def get_workflow_step(self, step_name: str, stage: Optional[WorkflowStage] = None) -> Optional[WorkflowStepInfo]:
"""获取workflow step信息。
step_name支持两种
- full_name: plugin_name.step_name
- short_name: step_name若有冲突返回第一个并告警
"""
candidates: List[WorkflowStepInfo] = []
target_stages = [stage] if stage else list(WorkflowStage)
for current_stage in target_stages:
current_steps = self._workflow_steps.get(current_stage, {})
if "." in step_name:
if step_info := current_steps.get(step_name):
return step_info
continue
candidates.extend([step_info for step_info in current_steps.values() if step_info.name == step_name])
if len(candidates) == 1:
return candidates[0]
if len(candidates) > 1:
logger.warning(f"workflow step 名称 '{step_name}' 存在多义,使用第一个匹配: {candidates[0].full_name}")
return candidates[0]
return None
def get_workflow_step_handler(
self, step_name: str, stage: Optional[WorkflowStage] = None
) -> Optional[Callable[..., Any]]:
"""获取workflow step处理函数。"""
if "." in step_name:
return self._workflow_step_handlers.get(step_name)
if step_info := self.get_workflow_step(step_name, stage):
return self._workflow_step_handlers.get(step_info.full_name)
return None
def enable_workflow_step(self, step_name: str, stage: Optional[WorkflowStage] = None) -> bool:
"""启用workflow step。"""
step_info = self.get_workflow_step(step_name, stage)
if not step_info:
logger.warning(f"workflow step 未注册,无法启用: {step_name}")
return False
step_info.enabled = True
logger.info(f"workflow step 已启用: {step_info.full_name}")
return True
def disable_workflow_step(self, step_name: str, stage: Optional[WorkflowStage] = None) -> bool:
"""禁用workflow step。"""
step_info = self.get_workflow_step(step_name, stage)
if not step_info:
logger.warning(f"workflow step 未注册,无法禁用: {step_name}")
return False
step_info.enabled = False
logger.info(f"workflow step 已禁用: {step_info.full_name}")
return True
def remove_workflow_steps_by_plugin(self, plugin_name: str) -> int:
"""移除某插件注册的所有workflow step。"""
removed_count = 0
for stage in WorkflowStage:
stage_registry = self._workflow_steps.get(stage, {})
target_names = [name for name, info in stage_registry.items() if info.plugin_name == plugin_name]
for full_name in target_names:
stage_registry.pop(full_name, None)
self._workflow_step_handlers.pop(full_name, None)
removed_count += 1
if removed_count:
logger.info(f"已移除插件 {plugin_name} 的 workflow step 数量: {removed_count}")
return removed_count
# === 组件全局启用/禁用方法 ===
def enable_component(self, component_name: str, component_type: ComponentType) -> bool:
"""全局的启用某个组件
Parameters:
component_name: 组件名称
component_type: 组件类型
Returns:
bool: 启用成功返回True失败返回False
"""
target_component_class = self.get_component_class(component_name, component_type)
target_component_info = self.get_component_info(component_name, component_type)
if not target_component_class or not target_component_info:
logger.warning(f"组件 {component_name} 未注册,无法启用")
return False
target_component_info.enabled = True
match component_type:
case ComponentType.ACTION:
assert isinstance(target_component_info, ActionInfo)
self._default_actions[component_name] = target_component_info
case ComponentType.COMMAND:
assert isinstance(target_component_info, CommandInfo)
pattern = target_component_info.command_pattern
self._command_patterns[re.compile(pattern)] = component_name
case ComponentType.TOOL:
assert isinstance(target_component_info, ToolInfo)
assert issubclass(target_component_class, BaseTool)
self._llm_available_tools[component_name] = target_component_class
case ComponentType.EVENT_HANDLER:
assert isinstance(target_component_info, EventHandlerInfo)
assert issubclass(target_component_class, BaseEventHandler)
self._enabled_event_handlers[component_name] = target_component_class
from .events_manager import events_manager # 延迟导入防止循环导入问题
events_manager.register_event_subscriber(target_component_info, target_component_class)
namespaced_name = f"{component_type}.{component_name}"
self._components[namespaced_name].enabled = True
self._components_by_type[component_type][component_name].enabled = True
logger.info(f"组件 {component_name} 已启用")
return True
async def disable_component(self, component_name: str, component_type: ComponentType) -> bool:
"""全局的禁用某个组件
Parameters:
component_name: 组件名称
component_type: 组件类型
Returns:
bool: 禁用成功返回True失败返回False
"""
target_component_class = self.get_component_class(component_name, component_type)
target_component_info = self.get_component_info(component_name, component_type)
if not target_component_class or not target_component_info:
logger.warning(f"组件 {component_name} 未注册,无法禁用")
return False
target_component_info.enabled = False
try:
match component_type:
case ComponentType.ACTION:
self._default_actions.pop(component_name)
case ComponentType.COMMAND:
self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != component_name}
case ComponentType.TOOL:
self._llm_available_tools.pop(component_name)
case ComponentType.EVENT_HANDLER:
self._enabled_event_handlers.pop(component_name)
from .events_manager import events_manager # 延迟导入防止循环导入问题
await events_manager.unregister_event_subscriber(component_name)
self._components[component_name].enabled = False
self._components_by_type[component_type][component_name].enabled = False
logger.info(f"组件 {component_name} 已禁用")
return True
except KeyError as e:
logger.warning(f"禁用组件时未找到组件或已禁用: {component_name}, 发生错误: {e}")
return False
except Exception as e:
logger.error(f"禁用组件 {component_name} 时发生错误: {e}")
return False
# === 组件查询方法 ===
def get_component_info(
self, component_name: str, component_type: Optional[ComponentType] = None
) -> Optional[ComponentInfo]:
# sourcery skip: class-extract-method
"""获取组件信息,支持自动命名空间解析
Args:
component_name: 组件名称,可以是原始名称或命名空间化的名称
component_type: 组件类型,如果提供则优先在该类型中查找
Returns:
Optional[ComponentInfo]: 组件信息或None
"""
# 1. 如果已经是命名空间化的名称,直接查找
if "." in component_name:
return self._components.get(component_name)
# 2. 如果指定了组件类型,构造命名空间化的名称查找
if component_type:
namespaced_name = f"{component_type}.{component_name}"
return self._components.get(namespaced_name)
# 3. 如果没有指定类型,尝试在所有命名空间中查找
candidates = []
for namespace_prefix in [types.value for types in ComponentType]:
namespaced_name = f"{namespace_prefix}.{component_name}"
if component_info := self._components.get(namespaced_name):
candidates.append((namespace_prefix, namespaced_name, component_info))
if len(candidates) == 1:
# 只有一个匹配,直接返回
return candidates[0][2]
elif len(candidates) > 1:
# 多个匹配,记录警告并返回第一个
namespaces = [ns for ns, _, _ in candidates]
logger.warning(
f"组件名称 '{component_name}' 在多个命名空间中存在: {namespaces},使用第一个匹配项: {candidates[0][1]}"
)
return candidates[0][2]
# 4. 都没找到
return None
def get_component_class(
self,
component_name: str,
component_type: Optional[ComponentType] = None,
) -> Optional[Union[Type[BaseCommand], Type[BaseAction], Type[BaseEventHandler], Type[BaseTool]]]:
"""获取组件类,支持自动命名空间解析
Args:
component_name: 组件名称,可以是原始名称或命名空间化的名称
component_type: 组件类型,如果提供则优先在该类型中查找
Returns:
Optional[Union[BaseCommand, BaseAction]]: 组件类或None
"""
# 1. 如果已经是命名空间化的名称,直接查找
if "." in component_name:
return self._components_classes.get(component_name)
# 2. 如果指定了组件类型,构造命名空间化的名称查找
if component_type:
namespaced_name = f"{component_type.value}.{component_name}"
return self._components_classes.get(namespaced_name)
# 3. 如果没有指定类型,尝试在所有命名空间中查找
candidates = []
for namespace_prefix in [types.value for types in ComponentType]:
namespaced_name = f"{namespace_prefix}.{component_name}"
if component_class := self._components_classes.get(namespaced_name):
candidates.append((namespace_prefix, namespaced_name, component_class))
if len(candidates) == 1:
# 只有一个匹配,直接返回
_, full_name, cls = candidates[0]
logger.debug(f"自动解析组件: '{component_name}' -> '{full_name}'")
return cls
elif len(candidates) > 1:
# 多个匹配,记录警告并返回第一个
namespaces = [ns for ns, _, _ in candidates]
logger.warning(
f"组件名称 '{component_name}' 在多个命名空间中存在: {namespaces},使用第一个匹配项: {candidates[0][1]}"
)
return candidates[0][2]
# 4. 都没找到
return None
def get_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
"""获取指定类型的所有组件"""
return self._components_by_type.get(component_type, {}).copy()
def get_enabled_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
"""获取指定类型的所有启用组件"""
components = self.get_components_by_type(component_type)
return {name: info for name, info in components.items() if info.enabled}
# === Action特定查询方法 ===
def get_action_registry(self) -> Dict[str, Type[BaseAction]]:
"""获取Action注册表"""
return self._action_registry.copy()
def get_registered_action_info(self, action_name: str) -> Optional[ActionInfo]:
"""获取Action信息"""
info = self.get_component_info(action_name, ComponentType.ACTION)
return info if isinstance(info, ActionInfo) else None
def get_default_actions(self) -> Dict[str, ActionInfo]:
"""获取默认动作集"""
return self._default_actions.copy()
# === Command特定查询方法 ===
def get_command_registry(self) -> Dict[str, Type[BaseCommand]]:
"""获取Command注册表"""
return self._command_registry.copy()
def get_registered_command_info(self, command_name: str) -> Optional[CommandInfo]:
"""获取Command信息"""
info = self.get_component_info(command_name, ComponentType.COMMAND)
return info if isinstance(info, CommandInfo) else None
def get_command_patterns(self) -> Dict[Pattern, str]:
"""获取Command模式注册表"""
return self._command_patterns.copy()
def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, CommandInfo]]:
# sourcery skip: use-named-expression, use-next
"""根据文本查找匹配的命令
Args:
text: 输入文本
Returns:
Tuple: (命令类, 匹配的命名组, 是否拦截消息, 插件名) 或 None
"""
candidates = [pattern for pattern in self._command_patterns if pattern.match(text)]
if not candidates:
return None
if len(candidates) > 1:
logger.warning(f"文本 '{text}' 匹配到多个命令模式: {candidates},使用第一个匹配")
command_name = self._command_patterns[candidates[0]]
command_info: CommandInfo = self.get_registered_command_info(command_name) # type: ignore
return (
self._command_registry[command_name],
candidates[0].match(text).groupdict(), # type: ignore
command_info,
)
# === Tool 特定查询方法 ===
def get_tool_registry(self) -> Dict[str, Type[BaseTool]]:
"""获取Tool注册表"""
return self._tool_registry.copy()
def get_llm_available_tools(self) -> Dict[str, Type[BaseTool]]:
"""获取LLM可用的Tool列表"""
return self._llm_available_tools.copy()
def get_registered_tool_info(self, tool_name: str) -> Optional[ToolInfo]:
"""获取Tool信息
Args:
tool_name: 工具名称
Returns:
ToolInfo: 工具信息对象,如果工具不存在则返回 None
"""
info = self.get_component_info(tool_name, ComponentType.TOOL)
return info if isinstance(info, ToolInfo) else None
# === EventHandler 特定查询方法 ===
def get_event_handler_registry(self) -> Dict[str, Type[BaseEventHandler]]:
"""获取事件处理器注册表"""
return self._event_handler_registry.copy()
def get_registered_event_handler_info(self, handler_name: str) -> Optional[EventHandlerInfo]:
"""获取事件处理器信息"""
info = self.get_component_info(handler_name, ComponentType.EVENT_HANDLER)
return info if isinstance(info, EventHandlerInfo) else None
def get_enabled_event_handlers(self) -> Dict[str, Type[BaseEventHandler]]:
"""获取启用的事件处理器"""
return self._enabled_event_handlers.copy()
# === 插件查询方法 ===
def get_plugin_info(self, plugin_name: str) -> Optional[PluginInfo]:
"""获取插件信息"""
return self._plugins.get(plugin_name)
def get_all_plugins(self) -> Dict[str, PluginInfo]:
"""获取所有插件"""
return self._plugins.copy()
# def get_enabled_plugins(self) -> Dict[str, PluginInfo]:
# """获取所有启用的插件"""
# return {name: info for name, info in self._plugins.items() if info.enabled}
def get_plugin_components(self, plugin_name: str) -> List[ComponentInfo]:
"""获取插件的所有组件"""
plugin_info = self.get_plugin_info(plugin_name)
return plugin_info.components if plugin_info else []
def get_plugin_config(self, plugin_name: str) -> Optional[dict]:
"""获取插件配置
Args:
plugin_name: 插件名称
Returns:
Optional[dict]: 插件配置字典或None
"""
# 从插件管理器获取插件实例的配置
from src.plugin_system.core.plugin_manager import plugin_manager
plugin_instance = plugin_manager.get_plugin_instance(plugin_name)
return plugin_instance.config if plugin_instance else None
def get_registry_stats(self) -> Dict[str, Any]:
"""获取注册中心统计信息"""
action_components: int = 0
command_components: int = 0
tool_components: int = 0
events_handlers: int = 0
for component in self._components.values():
if component.component_type == ComponentType.ACTION:
action_components += 1
elif component.component_type == ComponentType.COMMAND:
command_components += 1
elif component.component_type == ComponentType.TOOL:
tool_components += 1
elif component.component_type == ComponentType.EVENT_HANDLER:
events_handlers += 1
workflow_step_count = sum(len(steps) for steps in self._workflow_steps.values())
enabled_workflow_step_count = sum(
len([step for step in steps.values() if step.enabled]) for steps in self._workflow_steps.values()
)
return {
"action_components": action_components,
"command_components": command_components,
"tool_components": tool_components,
"event_handlers": events_handlers,
"total_components": len(self._components),
"total_plugins": len(self._plugins),
"components_by_type": {
component_type.value: len(components) for component_type, components in self._components_by_type.items()
},
"enabled_components": len([c for c in self._components.values() if c.enabled]),
"enabled_plugins": len([p for p in self._plugins.values() if p.enabled]),
"workflow_steps": workflow_step_count,
"enabled_workflow_steps": enabled_workflow_step_count,
"workflow_steps_by_stage": {stage.value: len(steps) for stage, steps in self._workflow_steps.items()},
}
component_registry = ComponentRegistry()

View File

@@ -1,512 +0,0 @@
import asyncio
import contextlib
from typing import List, Dict, Optional, Type, Tuple, TYPE_CHECKING
from src.chat.message_receive.message import MessageSending, SessionMessage
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.common.logger import get_logger
from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages, CustomEventHandlerResult
from src.plugin_system.base.base_events_handler import BaseEventHandler
from src.plugin_system.base.workflow_types import WorkflowContext, WorkflowStepResult
from .global_announcement_manager import global_announcement_manager
from .workflow_engine import workflow_engine
if TYPE_CHECKING:
from src.common.data_models.llm_data_model import LLMGenerationDataModel
logger = get_logger("events_manager")
class EventsManager:
def __init__(self):
# 有权重的 events 订阅者注册表
self._events_subscribers: Dict[EventType | str, List[BaseEventHandler]] = {}
self._handler_mapping: Dict[str, Type[BaseEventHandler]] = {} # 事件处理器映射表
self._handler_tasks: Dict[str, List[asyncio.Task]] = {} # 事件处理器正在处理的任务
self._events_result_history: Dict[EventType | str, List[CustomEventHandlerResult]] = {} # 事件的结果历史记录
self._history_enable_map: Dict[EventType | str, bool] = {} # 是否启用历史记录的映射表同时作为events注册表
# 事件注册(同时作为注册样例)
for event in EventType:
self.register_event(event, enable_history_result=False)
def register_event(self, event_type: EventType | str, enable_history_result: bool = False):
if event_type in self._events_subscribers:
raise ValueError(f"事件类型 {event_type} 已存在")
self._events_subscribers[event_type] = []
self._history_enable_map[event_type] = enable_history_result
if enable_history_result:
self._events_result_history[event_type] = []
def register_event_subscriber(self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]) -> bool:
"""注册事件处理器
Args:
handler_info (EventHandlerInfo): 事件处理器信息
handler_class (Type[BaseEventHandler]): 事件处理器类
Returns:
bool: 是否注册成功
"""
if not issubclass(handler_class, BaseEventHandler):
logger.error(f"{handler_class.__name__} 不是 BaseEventHandler 的子类")
return False
handler_name = handler_info.name
if handler_name in self._handler_mapping:
logger.warning(f"事件处理器 {handler_name} 已存在,跳过注册")
return False
if handler_info.event_type not in self._history_enable_map:
if isinstance(handler_info.event_type, str):
self.register_event(handler_info.event_type, enable_history_result=False)
logger.info(f"自动注册自定义事件类型: {handler_info.event_type}")
else:
logger.error(f"事件类型 {handler_info.event_type} 未注册,无法为其注册处理器 {handler_name}")
return False
self._handler_mapping[handler_name] = handler_class
return self._insert_event_handler(handler_class, handler_info)
async def handle_mai_events(
self,
event_type: EventType | str,
message: Optional[SessionMessage | MessageSending | MaiMessages] = None,
llm_prompt: Optional[str] = None,
llm_response: Optional["LLMGenerationDataModel"] = None,
stream_id: Optional[str] = None,
action_usage: Optional[List[str]] = None,
) -> Tuple[bool, Optional[MaiMessages]]:
"""
处理所有事件,根据事件类型分发给订阅的处理器。
"""
from src.plugin_system.core import component_registry
continue_flag = True
# 1. 准备消息
transformed_message = self._prepare_message(
event_type, message, llm_prompt, llm_response, stream_id, action_usage # type: ignore[arg-type]
)
if transformed_message:
transformed_message = transformed_message.deepcopy()
# 2. 获取并遍历处理器
handlers = self._events_subscribers.get(event_type, [])
if not handlers:
return True, None
current_stream_id = transformed_message.stream_id if transformed_message else None
modified_message: Optional[MaiMessages] = None
for handler in handlers:
# 3. 前置检查和配置加载
if (
current_stream_id
and handler.handler_name
in global_announcement_manager.get_disabled_chat_event_handlers(current_stream_id)
):
continue
# 统一加载插件配置
plugin_config = component_registry.get_plugin_config(handler.plugin_name) or {}
handler.set_plugin_config(plugin_config)
# 4. 根据类型分发任务
if (
handler.intercept_message or event_type == EventType.ON_STOP
): # 让ON_STOP的所有事件处理器都发挥作用防止还没执行即被取消
# 阻塞执行,并更新 continue_flag
should_continue, modified_message = await self._dispatch_intercepting_handler_task(
handler, event_type, modified_message or transformed_message
)
continue_flag = continue_flag and should_continue
else:
# 异步执行,不阻塞
self._dispatch_handler_task(handler, event_type, transformed_message)
# 桥接到新版本插件运行时
continue_flag, modified_message = await self._bridge_to_new_runtime(
event_type, continue_flag, modified_message or transformed_message
)
return continue_flag, modified_message
async def handle_workflow_message(
self,
message: Optional[SessionMessage | MessageSending | MaiMessages] = None,
stream_id: Optional[str] = None,
action_usage: Optional[List[str]] = None,
context: Optional[WorkflowContext] = None,
) -> Tuple[WorkflowStepResult, Optional[MaiMessages], WorkflowContext]:
"""执行线性workflow消息流转MVP兼容入口"""
initial_message = self._prepare_message(EventType.ON_MESSAGE_PRE_PROCESS, message=message, stream_id=stream_id)
async def _dispatch(
event_type: EventType | str,
workflow_message: Optional[MaiMessages],
workflow_stream_id: Optional[str],
workflow_action_usage: Optional[List[str]],
) -> Tuple[bool, Optional[MaiMessages]]:
return await self.handle_mai_events(
event_type=event_type,
message=workflow_message,
stream_id=workflow_stream_id,
action_usage=workflow_action_usage,
)
return await workflow_engine.execute_linear(
dispatch_event=_dispatch,
message=initial_message,
stream_id=stream_id,
action_usage=action_usage,
context=context,
)
async def cancel_handler_tasks(self, handler_name: str) -> None:
tasks_to_be_cancelled = self._handler_tasks.get(handler_name, [])
if remaining_tasks := [task for task in tasks_to_be_cancelled if not task.done()]:
for task in remaining_tasks:
task.cancel()
try:
await asyncio.wait_for(asyncio.gather(*remaining_tasks, return_exceptions=True), timeout=5)
logger.info(f"已取消事件处理器 {handler_name} 的所有任务")
except asyncio.TimeoutError:
logger.warning(f"取消事件处理器 {handler_name} 的任务超时,开始强制取消")
except Exception as e:
logger.error(f"取消事件处理器 {handler_name} 的任务时发生异常: {e}")
if handler_name in self._handler_tasks:
del self._handler_tasks[handler_name]
async def unregister_event_subscriber(self, handler_name: str) -> bool:
"""取消注册事件处理器"""
if handler_name not in self._handler_mapping:
logger.warning(f"事件处理器 {handler_name} 不存在,无法取消注册")
return False
await self.cancel_handler_tasks(handler_name)
handler_class = self._handler_mapping.pop(handler_name)
if not self._remove_event_handler_instance(handler_class):
return False
logger.info(f"事件处理器 {handler_name} 已成功取消注册")
return True
async def get_event_result_history(self, event_type: EventType | str) -> List[CustomEventHandlerResult]:
"""获取事件的结果历史记录"""
if event_type == EventType.UNKNOWN:
raise ValueError("未知事件类型")
if event_type not in self._history_enable_map:
raise ValueError(f"事件类型 {event_type} 未注册")
if not self._history_enable_map[event_type]:
raise ValueError(f"事件类型 {event_type} 的历史记录未启用")
return self._events_result_history[event_type]
async def clear_event_result_history(self, event_type: EventType | str) -> None:
"""清空事件的结果历史记录"""
if event_type == EventType.UNKNOWN:
raise ValueError("未知事件类型")
if event_type not in self._history_enable_map:
raise ValueError(f"事件类型 {event_type} 未注册")
if not self._history_enable_map[event_type]:
raise ValueError(f"事件类型 {event_type} 的历史记录未启用")
self._events_result_history[event_type] = []
def _insert_event_handler(self, handler_class: Type[BaseEventHandler], handler_info: EventHandlerInfo) -> bool:
"""插入事件处理器到对应的事件类型列表中并设置其插件配置"""
if handler_class.event_type == EventType.UNKNOWN:
logger.error(f"事件处理器 {handler_class.__name__} 的事件类型未知,无法注册")
return False
if handler_class.event_type not in self._events_subscribers:
self._events_subscribers[handler_class.event_type] = []
handler_instance = handler_class()
handler_instance.set_plugin_name(handler_info.plugin_name or "unknown")
self._events_subscribers[handler_class.event_type].append(handler_instance)
self._events_subscribers[handler_class.event_type].sort(key=lambda x: x.weight, reverse=True)
return True
def _remove_event_handler_instance(self, handler_class: Type[BaseEventHandler]) -> bool:
"""从事件类型列表中移除事件处理器"""
display_handler_name = handler_class.handler_name or handler_class.__name__
if handler_class.event_type == EventType.UNKNOWN:
logger.warning(f"事件处理器 {display_handler_name} 的事件类型未知,不存在于处理器列表中")
return False
handlers = self._events_subscribers[handler_class.event_type]
for i, handler in enumerate(handlers):
if isinstance(handler, handler_class):
del handlers[i]
logger.debug(f"事件处理器 {display_handler_name} 已移除")
return True
logger.warning(f"未找到事件处理器 {display_handler_name},无法移除")
return False
def _transform_event_message(
self,
message: SessionMessage | MessageSending,
llm_prompt: Optional[str] = None,
llm_response: Optional["LLMGenerationDataModel"] = None,
) -> MaiMessages:
"""转换事件消息格式"""
from maim_message import Seg
# 直接赋值部分内容
transformed_message = MaiMessages(
llm_prompt=llm_prompt,
llm_response_content=llm_response.content if llm_response else None,
llm_response_reasoning=llm_response.reasoning if llm_response else None,
llm_response_model=llm_response.model if llm_response else None,
llm_response_tool_call=llm_response.tool_calls if llm_response else None,
raw_message=message.processed_plain_text or "",
additional_data={},
)
# 消息段处理
if isinstance(message, MessageSending):
if message.message_segment.type == "seglist":
transformed_message.message_segments = list(message.message_segment.data) # type: ignore
else:
transformed_message.message_segments = [message.message_segment]
else:
# SessionMessage: 使用 processed_plain_text 构造简单段
transformed_message.message_segments = [Seg(type="text", data=message.processed_plain_text or "")]
# stream_id 处理
transformed_message.stream_id = message.session_id if hasattr(message, "session_id") else ""
# 处理后文本
transformed_message.plain_text = message.processed_plain_text
# 基本信息
if isinstance(message, MessageSending):
transformed_message.message_base_info["platform"] = message.platform
if message.session.group_id:
transformed_message.is_group_message = True
group_name = ""
if message.session.context and message.session.context.message and message.session.context.message.message_info.group_info:
group_name = message.session.context.message.message_info.group_info.group_name
transformed_message.message_base_info.update({
"group_id": message.session.group_id,
"group_name": group_name,
})
transformed_message.message_base_info.update({
"user_id": message.bot_user_info.user_id,
"user_cardname": message.bot_user_info.user_cardname,
"user_nickname": message.bot_user_info.user_nickname,
})
if not transformed_message.is_group_message:
transformed_message.is_private_message = True
elif hasattr(message, "message_info") and message.message_info:
if message.platform:
transformed_message.message_base_info["platform"] = message.platform
if message.message_info.group_info:
transformed_message.is_group_message = True
transformed_message.message_base_info.update({
"group_id": message.message_info.group_info.group_id,
"group_name": message.message_info.group_info.group_name,
})
if message.message_info.user_info:
if not transformed_message.is_group_message:
transformed_message.is_private_message = True
transformed_message.message_base_info.update({
"user_id": message.message_info.user_info.user_id,
"user_cardname": message.message_info.user_info.user_cardname,
"user_nickname": message.message_info.user_info.user_nickname,
})
return transformed_message
def _build_message_from_stream(
self, stream_id: str, llm_prompt: Optional[str] = None, llm_response: Optional["LLMGenerationDataModel"] = None
) -> MaiMessages:
"""从流ID构建消息"""
session = _chat_manager.get_session_by_session_id(stream_id)
assert session, f"未找到流ID为 {stream_id} 的会话"
message = session.context.message
return self._transform_event_message(message, llm_prompt, llm_response)
def _transform_event_without_message(
self,
stream_id: str,
llm_prompt: Optional[str] = None,
llm_response: Optional["LLMGenerationDataModel"] = None,
action_usage: Optional[List[str]] = None,
) -> MaiMessages:
"""没有message对象时进行转换"""
session = _chat_manager.get_session_by_session_id(stream_id)
assert session, f"未找到流ID为 {stream_id} 的会话"
return MaiMessages(
stream_id=stream_id,
llm_prompt=llm_prompt,
llm_response_content=(llm_response.content if llm_response else None),
llm_response_reasoning=(llm_response.reasoning if llm_response else None),
llm_response_model=(llm_response.model if llm_response else None),
llm_response_tool_call=(llm_response.tool_calls if llm_response else None),
is_group_message=session.is_group_session,
is_private_message=not session.is_group_session,
action_usage=action_usage,
additional_data={"response_is_processed": True},
)
async def _bridge_to_new_runtime(
self,
event_type: EventType | str,
continue_flag: bool,
message: Optional[MaiMessages],
) -> Tuple[bool, Optional[MaiMessages]]:
"""将事件桥接到新版本插件运行时
如果旧 handler 已经 abortcontinue_flag=False直接跳过。
"""
if not continue_flag:
return continue_flag, message
try:
from src.plugin_runtime.integration import get_plugin_runtime_manager
prm = get_plugin_runtime_manager()
if not prm.is_running:
return continue_flag, message
event_value = event_type.value if isinstance(event_type, EventType) else str(event_type)
message_dict = message.to_dict() if message and hasattr(message, "to_dict") else None
new_continue, new_msg_dict = await prm.bridge_event(
event_type_value=event_value,
message_dict=message_dict,
)
# 新运行时返回 abort 则合并
if not new_continue:
continue_flag = False
except Exception as e:
logger.warning(f"桥接事件到新运行时失败: {e}")
return continue_flag, message
def _prepare_message(
self,
event_type: EventType | str,
message: Optional[SessionMessage | MessageSending | MaiMessages] = None,
llm_prompt: Optional[str] = None,
llm_response: Optional["LLMGenerationDataModel"] = None,
stream_id: Optional[str] = None,
action_usage: Optional[List[str]] = None,
) -> Optional[MaiMessages]:
"""根据事件类型和输入,准备和转换消息对象。"""
if isinstance(message, MaiMessages):
return message.deepcopy()
if message:
return self._transform_event_message(message, llm_prompt, llm_response)
if event_type not in [EventType.ON_START, EventType.ON_STOP]:
assert stream_id, "如果没有消息,必须为非启动/关闭事件提供流ID"
if event_type in [EventType.ON_MESSAGE, EventType.ON_PLAN, EventType.POST_LLM, EventType.AFTER_LLM]:
return self._build_message_from_stream(stream_id, llm_prompt, llm_response)
else:
return self._transform_event_without_message(stream_id, llm_prompt, llm_response, action_usage)
return None # ON_START, ON_STOP事件没有消息体
def _dispatch_handler_task(
self, handler: BaseEventHandler, event_type: EventType | str, message: Optional[MaiMessages] = None
):
"""分发一个非阻塞(异步)的事件处理任务。"""
if event_type == EventType.UNKNOWN:
raise ValueError("未知事件类型")
try:
task = asyncio.create_task(handler.execute(message))
task_name = f"{handler.plugin_name}-{handler.handler_name}"
task.set_name(task_name)
task.add_done_callback(lambda t: self._task_done_callback(t, event_type))
self._handler_tasks.setdefault(handler.handler_name, []).append(task)
except Exception as e:
logger.error(f"创建事件处理器任务 {handler.handler_name} 时发生异常: {e}", exc_info=True)
async def _dispatch_intercepting_handler_task(
self, handler: BaseEventHandler, event_type: EventType | str, message: Optional[MaiMessages] = None
) -> Tuple[bool, Optional[MaiMessages]]:
"""分发并等待一个阻塞(同步)的事件处理器,返回是否应继续处理。"""
if event_type == EventType.UNKNOWN:
raise ValueError("未知事件类型")
if event_type not in self._history_enable_map:
raise ValueError(f"事件类型 {event_type} 未注册")
try:
result = await handler.execute(message)
expected_fields = ["success", "continue_processing", "return_message", "custom_result", "modified_message"]
if not isinstance(result, tuple) or len(result) != 5:
if isinstance(result, tuple):
annotated = ", ".join(f"{name}={val!r}" for name, val in zip(expected_fields, result, strict=False))
actual_desc = f"{len(result)} 个元素 ({annotated})"
else:
actual_desc = f"非 tuple 类型: {type(result)}"
logger.error(
f"[{self.__class__.__name__}] EventHandler {handler.handler_name} 返回值不符合预期:\n"
f" 模块来源: {handler.__class__.__module__}.{handler.__class__.__name__}\n"
f" 期望: 5 个元素 ({', '.join(expected_fields)})\n"
f" 实际: {actual_desc}"
)
return True, None
success, continue_processing, return_message, custom_result, modified_message = result
if not success:
logger.error(f"EventHandler {handler.handler_name} 执行失败: {return_message}")
else:
logger.debug(f"EventHandler {handler.handler_name} 执行成功: {return_message}")
if self._history_enable_map[event_type] and custom_result:
self._events_result_history[event_type].append(custom_result)
return continue_processing, modified_message
except KeyError:
logger.error(f"事件 {event_type} 注册的历史记录启用情况与实际不符合")
return True, None
except Exception as e:
logger.error(f"EventHandler {handler.handler_name} 发生异常: {e}", exc_info=True)
return True, None # 发生异常时默认不中断其他处理
def _task_done_callback(
self,
task: asyncio.Task[Tuple[bool, bool, str | None, CustomEventHandlerResult | None, MaiMessages | None]],
event_type: EventType | str,
):
"""任务完成回调"""
task_name = task.get_name() or "Unknown Task"
if event_type == EventType.UNKNOWN:
raise ValueError("未知事件类型")
if event_type not in self._history_enable_map:
raise ValueError(f"事件类型 {event_type} 未注册")
try:
success, _, result, custom_result, _ = task.result() # 忽略是否继续的标志和消息的修改,因为消息本身未被拦截
if success:
logger.debug(f"事件处理任务 {task_name} 已成功完成: {result}")
else:
logger.error(f"事件处理任务 {task_name} 执行失败: {result}")
if self._history_enable_map[event_type] and custom_result:
self._events_result_history[event_type].append(custom_result)
except asyncio.CancelledError:
pass
except KeyError:
logger.error(f"事件 {event_type} 注册的历史记录启用情况与实际不符合")
except Exception as e:
logger.error(f"事件处理任务 {task_name} 发生异常: {e}")
finally:
with contextlib.suppress(ValueError, KeyError):
self._handler_tasks[task_name].remove(task)
events_manager = EventsManager()

View File

@@ -1,634 +0,0 @@
from importlib.util import spec_from_file_location, module_from_spec
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Type
from collections import deque
import os
import traceback
from src.common.logger import get_logger
from src.plugin_system.base.plugin_base import PluginBase
from src.plugin_system.base.component_types import ComponentType
from src.plugin_system.utils.manifest_utils import VersionComparator
from .component_registry import component_registry
from .plugin_service_registry import plugin_service_registry
logger = get_logger("plugin_manager")
class PluginManager:
"""
插件管理器类
负责加载,重载和卸载插件,同时管理插件的所有组件
"""
def __init__(self):
self.plugin_directories: List[str] = [] # 插件根目录列表
self.plugin_classes: Dict[str, Type[PluginBase]] = {} # 全局插件类注册表,插件名 -> 插件类
self.plugin_paths: Dict[str, str] = {} # 记录插件名到目录路径的映射,插件名 -> 目录路径
self.loaded_plugins: Dict[str, PluginBase] = {} # 已加载的插件类实例注册表,插件名 -> 插件类实例
self.failed_plugins: Dict[str, str] = {} # 记录加载失败的插件文件及其错误信息,插件名 -> 错误信息
# 确保插件目录存在
self._ensure_plugin_directories()
logger.info("插件管理器初始化完成")
# === 插件目录管理 ===
def add_plugin_directory(self, directory: str) -> bool:
"""添加插件目录"""
if os.path.exists(directory):
if directory not in self.plugin_directories:
self.plugin_directories.append(directory)
logger.debug(f"已添加插件目录: {directory}")
return True
else:
logger.warning(f"插件不可重复加载: {directory}")
else:
logger.warning(f"插件目录不存在: {directory}")
return False
# === 插件加载管理 ===
def load_all_plugins(self) -> Tuple[int, int]:
"""加载所有插件
Returns:
tuple[int, int]: (插件数量, 组件数量)
"""
logger.debug("开始加载所有插件...")
# 第一阶段:加载所有插件模块(注册插件类)
total_loaded_modules = 0
total_failed_modules = 0
for directory in self.plugin_directories:
loaded, failed = self._load_plugin_modules_from_directory(directory)
total_loaded_modules += loaded
total_failed_modules += failed
logger.debug(f"插件模块加载完成 - 成功: {total_loaded_modules}, 失败: {total_failed_modules}")
# 第二阶段前:根据依赖关系决定插件加载顺序
dependency_graph, missing_dependencies = self._build_plugin_dependency_graph()
sorted_plugins, cycle_plugins = self._resolve_plugin_load_order(dependency_graph)
total_registered = 0
total_failed_registration = 0
# 先处理缺失依赖
for plugin_name, missing in missing_dependencies.items():
if not missing:
continue
missing_dep_names = ", ".join(sorted(missing))
self.failed_plugins[plugin_name] = f"缺少依赖插件: {missing_dep_names}"
logger.error(f"❌ 插件加载失败: {plugin_name} - 缺少依赖插件: {missing_dep_names}")
total_failed_registration += 1
# 再处理循环依赖
for plugin_name in sorted(cycle_plugins):
if plugin_name in missing_dependencies and missing_dependencies[plugin_name]:
continue
self.failed_plugins[plugin_name] = "检测到循环依赖"
logger.error(f"❌ 插件加载失败: {plugin_name} - 检测到循环依赖")
total_failed_registration += 1
# 最后按拓扑序加载可加载插件
for plugin_name in sorted_plugins:
if plugin_name in cycle_plugins:
continue
if plugin_name in missing_dependencies and missing_dependencies[plugin_name]:
continue
load_status, count = self.load_registered_plugin_classes(plugin_name)
if load_status:
total_registered += 1
else:
total_failed_registration += count
self._show_stats(total_registered, total_failed_registration)
return total_registered, total_failed_registration
def load_registered_plugin_classes(self, plugin_name: str) -> Tuple[bool, int]:
# sourcery skip: extract-duplicate-method, extract-method
"""
加载已经注册的插件类
"""
plugin_class = self.plugin_classes.get(plugin_name)
if not plugin_class:
logger.error(f"插件 {plugin_name} 的插件类未注册或不存在")
return False, 1
try:
# 使用记录的插件目录路径
plugin_dir = self.plugin_paths.get(plugin_name)
# 如果没有记录,直接返回失败
if not plugin_dir:
return False, 1
plugin_instance = plugin_class(plugin_dir=plugin_dir) # 实例化插件可能因为缺少manifest而失败
if not plugin_instance:
logger.error(f"插件 {plugin_name} 实例化失败")
return False, 1
# 检查插件是否启用
if not plugin_instance.enable_plugin:
logger.info(f"插件 {plugin_name} 已禁用,跳过加载")
return False, 0
# 检查版本兼容性
is_compatible, compatibility_error = self._check_plugin_version_compatibility(
plugin_name, plugin_instance.manifest_data
)
if not is_compatible:
self.failed_plugins[plugin_name] = compatibility_error
logger.error(f"❌ 插件加载失败: {plugin_name} - {compatibility_error}")
return False, 1
if plugin_instance.register_plugin():
self.loaded_plugins[plugin_name] = plugin_instance
self._show_plugin_components(plugin_name)
return True, 1
else:
self.failed_plugins[plugin_name] = "插件注册失败"
logger.error(f"❌ 插件注册失败: {plugin_name}")
return False, 1
except FileNotFoundError as e:
# manifest文件缺失
error_msg = f"缺少manifest文件: {str(e)}"
self.failed_plugins[plugin_name] = error_msg
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
return False, 1
except ValueError as e:
# manifest文件格式错误或验证失败
traceback.print_exc()
error_msg = f"manifest验证失败: {str(e)}"
self.failed_plugins[plugin_name] = error_msg
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
return False, 1
except Exception as e:
# 其他错误
error_msg = f"未知错误: {str(e)}"
self.failed_plugins[plugin_name] = error_msg
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
logger.debug("详细错误信息: ", exc_info=True)
return False, 1
async def remove_registered_plugin(self, plugin_name: str) -> bool:
"""
禁用插件模块
"""
if not plugin_name:
raise ValueError("插件名称不能为空")
if plugin_name not in self.loaded_plugins:
logger.warning(f"插件 {plugin_name} 未加载")
return False
plugin_instance = self.loaded_plugins[plugin_name]
plugin_info = plugin_instance.plugin_info
success = True
for component in plugin_info.components:
success &= await component_registry.remove_component(component.name, component.component_type, plugin_name)
success &= component_registry.remove_plugin_registry(plugin_name)
plugin_service_registry.remove_services_by_plugin(plugin_name)
del self.loaded_plugins[plugin_name]
return success
async def reload_registered_plugin(self, plugin_name: str) -> bool:
"""
重载插件模块
"""
old_instance = self.loaded_plugins.get(plugin_name)
if not old_instance:
logger.warning(f"插件 {plugin_name} 未加载,无法重载")
return False
if not await self.remove_registered_plugin(plugin_name):
return False
if not self.load_registered_plugin_classes(plugin_name)[0]:
logger.error(f"插件 {plugin_name} 重载失败,开始回滚旧实例")
rollback_ok = await self._rollback_failed_reload(plugin_name, old_instance)
if rollback_ok:
logger.info(f"插件 {plugin_name} 已回滚到旧版本实例")
else:
logger.error(f"插件 {plugin_name} 回滚失败,插件当前不可用")
return False
logger.debug(f"插件 {plugin_name} 重载成功")
return True
async def _rollback_failed_reload(self, plugin_name: str, old_instance: PluginBase) -> bool:
"""重载失败后回滚旧实例。"""
try:
await component_registry.remove_components_by_plugin(plugin_name)
component_registry.remove_plugin_registry(plugin_name)
plugin_service_registry.remove_services_by_plugin(plugin_name)
if not old_instance.register_plugin():
logger.error(f"插件 {plugin_name} 回滚失败: 旧实例重新注册失败")
return False
self.loaded_plugins[plugin_name] = old_instance
return True
except Exception as e:
logger.error(f"插件 {plugin_name} 回滚异常: {e}", exc_info=True)
return False
def rescan_plugin_directory(self) -> Tuple[int, int]:
"""
重新扫描插件根目录
"""
total_success = 0
total_fail = 0
for directory in self.plugin_directories:
if os.path.exists(directory):
logger.debug(f"重新扫描插件根目录: {directory}")
success, fail = self._load_plugin_modules_from_directory(directory)
total_success += success
total_fail += fail
else:
logger.warning(f"插件根目录不存在: {directory}")
return total_success, total_fail
def get_plugin_instance(self, plugin_name: str) -> Optional["PluginBase"]:
"""获取插件实例
Args:
plugin_name: 插件名称
Returns:
Optional[BasePlugin]: 插件实例或None
"""
return self.loaded_plugins.get(plugin_name)
# === 查询方法 ===
def list_loaded_plugins(self) -> List[str]:
"""
列出所有当前加载的插件。
Returns:
list: 当前加载的插件名称列表。
"""
return list(self.loaded_plugins.keys())
def list_registered_plugins(self) -> List[str]:
"""
列出所有已注册的插件类。
Returns:
list: 已注册的插件类名称列表。
"""
return list(self.plugin_classes.keys())
def get_plugin_path(self, plugin_name: str) -> Optional[str]:
"""
获取指定插件的路径。
Args:
plugin_name: 插件名称
Returns:
Optional[str]: 插件目录的绝对路径如果插件不存在则返回None。
"""
return self.plugin_paths.get(plugin_name)
# === 私有方法 ===
# == 目录管理 ==
def _ensure_plugin_directories(self) -> None:
"""确保所有插件根目录存在,如果不存在则创建"""
default_directories = ["src/plugins/built_in", "plugins"]
for directory in default_directories:
if not os.path.exists(directory):
os.makedirs(directory, exist_ok=True)
logger.info(f"创建插件根目录: {directory}")
if directory not in self.plugin_directories:
self.plugin_directories.append(directory)
logger.debug(f"已添加插件根目录: {directory}")
else:
logger.warning(f"根目录不可重复加载: {directory}")
# == 插件加载 ==
def _load_plugin_modules_from_directory(self, directory: str) -> tuple[int, int]:
"""从指定目录加载插件模块"""
loaded_count = 0
failed_count = 0
if not os.path.exists(directory):
logger.warning(f"插件根目录不存在: {directory}")
return 0, 1
logger.debug(f"正在扫描插件根目录: {directory}")
# 遍历目录中的所有包
for item in os.listdir(directory):
item_path = os.path.join(directory, item)
if os.path.isdir(item_path) and not item.startswith(".") and not item.startswith("__"):
plugin_file = os.path.join(item_path, "plugin.py")
if os.path.exists(plugin_file):
if self._load_plugin_module_file(plugin_file):
loaded_count += 1
else:
failed_count += 1
return loaded_count, failed_count
def _load_plugin_module_file(self, plugin_file: str) -> bool:
# sourcery skip: extract-method
"""加载单个插件模块文件
Args:
plugin_file: 插件文件路径
plugin_name: 插件名称
plugin_dir: 插件目录路径
"""
# 生成模块名
plugin_path = Path(plugin_file)
module_name = ".".join(plugin_path.parent.parts)
try:
# 动态导入插件模块
spec = spec_from_file_location(module_name, plugin_file)
if spec is None or spec.loader is None:
logger.error(f"无法创建模块规范: {plugin_file}")
return False
module = module_from_spec(spec)
module.__package__ = module_name # 设置模块包名
spec.loader.exec_module(module)
logger.debug(f"插件模块加载成功: {plugin_file}")
return True
except Exception as e:
error_msg = f"加载插件模块 {plugin_file} 失败: {e}"
logger.error(error_msg)
self.failed_plugins[module_name] = error_msg
return False
# == 依赖解析与加载顺序 ==
def _extract_declared_dependencies(self, plugin_name: str, plugin_class: Type[PluginBase]) -> Set[str]:
"""提取插件声明的依赖。
兼容声明格式:
- list[str]
- list[dict]其中dict至少包含name键
"""
dependencies: Set[str] = set()
raw_dependencies = getattr(plugin_class, "dependencies", [])
# 兼容错误声明
if isinstance(raw_dependencies, property):
logger.warning(f"插件 {plugin_name} 的 dependencies 未声明为类属性,将按无依赖处理")
return dependencies
if not isinstance(raw_dependencies, list):
logger.warning(f"插件 {plugin_name} 的 dependencies 不是列表,将按无依赖处理")
return dependencies
for dependency in raw_dependencies:
dependency_name = ""
if isinstance(dependency, str):
dependency_name = dependency.strip()
elif isinstance(dependency, dict):
dependency_name = str(dependency.get("name", "")).strip()
else:
logger.warning(f"插件 {plugin_name} 包含不支持的依赖声明类型: {type(dependency)}")
continue
if not dependency_name:
continue
if dependency_name == plugin_name:
logger.warning(f"插件 {plugin_name} 声明了对自身的依赖,已忽略")
continue
dependencies.add(dependency_name)
return dependencies
def _build_plugin_dependency_graph(self) -> Tuple[Dict[str, Set[str]], Dict[str, Set[str]]]:
"""构建依赖图并返回缺失依赖映射。"""
plugin_names = set(self.plugin_classes.keys())
dependency_graph: Dict[str, Set[str]] = {name: set() for name in plugin_names}
missing_dependencies: Dict[str, Set[str]] = {name: set() for name in plugin_names}
for plugin_name, plugin_class in self.plugin_classes.items():
declared_dependencies = self._extract_declared_dependencies(plugin_name, plugin_class)
for dependency_name in declared_dependencies:
if dependency_name in plugin_names:
dependency_graph[plugin_name].add(dependency_name)
else:
missing_dependencies[plugin_name].add(dependency_name)
return dependency_graph, missing_dependencies
def _resolve_plugin_load_order(self, dependency_graph: Dict[str, Set[str]]) -> Tuple[List[str], Set[str]]:
"""根据依赖图计算加载顺序,并检测循环依赖。"""
indegree: Dict[str, int] = {
plugin_name: len(dependencies) for plugin_name, dependencies in dependency_graph.items()
}
reverse_graph: Dict[str, Set[str]] = {plugin_name: set() for plugin_name in dependency_graph}
for plugin_name, dependencies in dependency_graph.items():
for dependency_name in dependencies:
reverse_graph[dependency_name].add(plugin_name)
zero_indegree_queue = deque(sorted([name for name, degree in indegree.items() if degree == 0]))
load_order: List[str] = []
while zero_indegree_queue:
current_plugin = zero_indegree_queue.popleft()
load_order.append(current_plugin)
for dependent_plugin in sorted(reverse_graph[current_plugin]):
indegree[dependent_plugin] -= 1
if indegree[dependent_plugin] == 0:
zero_indegree_queue.append(dependent_plugin)
cycle_plugins = {name for name, degree in indegree.items() if degree > 0}
if cycle_plugins:
logger.error(f"检测到循环依赖插件: {', '.join(sorted(cycle_plugins))}")
return load_order, cycle_plugins
# == 兼容性检查 ==
def _check_plugin_version_compatibility(self, plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]:
"""检查插件版本兼容性
Args:
plugin_name: 插件名称
manifest_data: manifest数据
Returns:
Tuple[bool, str]: (是否兼容, 错误信息)
"""
if "host_application" not in manifest_data:
return True, "" # 没有版本要求,默认兼容
host_app = manifest_data["host_application"]
if not isinstance(host_app, dict):
return True, ""
min_version = host_app.get("min_version", "")
max_version = host_app.get("max_version", "")
if not min_version and not max_version:
return True, "" # 没有版本要求,默认兼容
try:
current_version = VersionComparator.get_current_host_version()
is_compatible, error_msg = VersionComparator.is_version_in_range(current_version, min_version, max_version)
if not is_compatible:
return False, f"版本不兼容: {error_msg}"
logger.debug(f"插件 {plugin_name} 版本兼容性检查通过")
return True, ""
except Exception as e:
logger.warning(f"插件 {plugin_name} 版本兼容性检查失败: {e}")
return False, f"插件 {plugin_name} 版本兼容性检查失败: {e}" # 检查失败时默认不允许加载
# == 显示统计与插件信息 ==
def _show_stats(self, total_registered: int, total_failed_registration: int):
# sourcery skip: low-code-quality
# 获取组件统计信息
stats = component_registry.get_registry_stats()
action_count = stats.get("action_components", 0)
command_count = stats.get("command_components", 0)
tool_count = stats.get("tool_components", 0)
event_handler_count = stats.get("event_handlers", 0)
total_components = stats.get("total_components", 0)
# 📋 显示插件加载总览
if total_registered > 0:
logger.info("🎉 插件系统加载完成!")
logger.info(
f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, EventHandler: {event_handler_count})"
)
# 显示详细的插件列表
logger.info("📋 已加载插件详情:")
for plugin_name in self.loaded_plugins.keys():
if plugin_info := component_registry.get_plugin_info(plugin_name):
# 插件基本信息
version_info = f"v{plugin_info.version}" if plugin_info.version else ""
author_info = f"by {plugin_info.author}" if plugin_info.author else "unknown"
license_info = f"[{plugin_info.license}]" if plugin_info.license else ""
info_parts = [part for part in [version_info, author_info, license_info] if part]
extra_info = f" ({', '.join(info_parts)})" if info_parts else ""
logger.info(f" 📦 {plugin_info.display_name}{extra_info}")
# Manifest信息
if plugin_info.manifest_data:
"""
if plugin_info.keywords:
logger.info(f" 🏷️ 关键词: {', '.join(plugin_info.keywords)}")
if plugin_info.categories:
logger.info(f" 📁 分类: {', '.join(plugin_info.categories)}")
"""
if plugin_info.homepage_url:
logger.info(f" 🌐 主页: {plugin_info.homepage_url}")
# 组件列表
if plugin_info.components:
action_components = [
c for c in plugin_info.components if c.component_type == ComponentType.ACTION
]
command_components = [
c for c in plugin_info.components if c.component_type == ComponentType.COMMAND
]
tool_components = [c for c in plugin_info.components if c.component_type == ComponentType.TOOL]
event_handler_components = [
c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER
]
if action_components:
action_names = [c.name for c in action_components]
logger.info(f" 🎯 Action组件: {', '.join(action_names)}")
if command_components:
command_names = [c.name for c in command_components]
logger.info(f" ⚡ Command组件: {', '.join(command_names)}")
if tool_components:
tool_names = [c.name for c in tool_components]
logger.info(f" 🛠️ Tool组件: {', '.join(tool_names)}")
if event_handler_components:
event_handler_names = [c.name for c in event_handler_components]
logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_names)}")
# 依赖信息
if plugin_info.dependencies:
logger.info(f" 🔗 依赖: {', '.join(plugin_info.dependencies)}")
# 配置文件信息
if plugin_info.config_file:
config_status = "" if self.plugin_paths.get(plugin_name) else ""
logger.info(f" ⚙️ 配置: {plugin_info.config_file} {config_status}")
root_path = Path(__file__)
# 查找项目根目录
while not (root_path / "pyproject.toml").exists() and root_path.parent != root_path:
root_path = root_path.parent
# 显示目录统计
logger.info("📂 加载目录统计:")
for directory in self.plugin_directories:
if os.path.exists(directory):
plugins_in_dir = []
for plugin_name in self.loaded_plugins.keys():
plugin_path = self.plugin_paths.get(plugin_name, "")
if (
Path(plugin_path)
.resolve()
.is_relative_to(Path(os.path.join(str(root_path), directory)).resolve())
):
plugins_in_dir.append(plugin_name)
if plugins_in_dir:
logger.info(f" 📁 {directory}: {len(plugins_in_dir)}个插件 ({', '.join(plugins_in_dir)})")
else:
logger.info(f" 📁 {directory}: 0个插件")
# 失败信息
if total_failed_registration > 0:
logger.info(f"⚠️ 失败统计: {total_failed_registration}个插件加载失败")
for failed_plugin, error in self.failed_plugins.items():
logger.info(f"{failed_plugin}: {error}")
else:
logger.warning("😕 没有成功加载任何插件")
def _show_plugin_components(self, plugin_name: str) -> None:
if plugin_info := component_registry.get_plugin_info(plugin_name):
component_types = {}
for comp in plugin_info.components:
comp_type = comp.component_type.name
component_types[comp_type] = component_types.get(comp_type, 0) + 1
components_str = ", ".join([f"{count}{ctype}" for ctype, count in component_types.items()])
# 显示manifest信息
manifest_info = ""
if plugin_info.license:
manifest_info += f" [{plugin_info.license}]"
if plugin_info.keywords:
manifest_info += f" 关键词: {', '.join(plugin_info.keywords[:3])}" # 只显示前3个关键词
if len(plugin_info.keywords) > 3:
manifest_info += "..."
logger.info(
f"✅ 插件加载成功: {plugin_name} v{plugin_info.version} ({components_str}){manifest_info} - {plugin_info.description}"
)
else:
logger.info(f"✅ 插件加载成功: {plugin_name}")
# 全局插件管理器实例
plugin_manager = PluginManager()

View File

@@ -1,251 +0,0 @@
from typing import Any, Callable, Dict, Optional
import inspect
from src.common.logger import get_logger
from src.plugin_system.base.service_types import PluginServiceInfo
logger = get_logger("plugin_service_registry")
class PluginServiceRegistry:
"""插件服务注册中心"""
def __init__(self):
self._services: Dict[str, PluginServiceInfo] = {}
self._service_handlers: Dict[str, Callable[..., Any]] = {}
logger.info("插件服务注册中心初始化完成")
def register_service(self, service_info: PluginServiceInfo, service_handler: Callable[..., Any]) -> bool:
"""注册插件服务。"""
if not service_info.name or not service_info.plugin_name:
logger.error("插件服务注册失败: service名称或插件名称为空")
return False
if "." in service_info.name:
logger.error(f"插件服务名称 '{service_info.name}' 包含非法字符 '.',请使用下划线替代")
return False
if "." in service_info.plugin_name:
logger.error(f"插件服务所属插件名称 '{service_info.plugin_name}' 包含非法字符 '.',请使用下划线替代")
return False
if invalid_callers := [caller for caller in service_info.allowed_callers if "." in caller]:
logger.error(f"插件服务白名单包含非法调用方名称: {invalid_callers}")
return False
full_name = service_info.full_name
if full_name in self._services:
logger.warning(f"插件服务已存在,拒绝重复注册: {full_name}")
return False
self._services[full_name] = service_info
self._service_handlers[full_name] = service_handler
logger.debug(f"已注册插件服务: {full_name} (version={service_info.version})")
return True
def get_service(self, service_name: str, plugin_name: Optional[str] = None) -> Optional[PluginServiceInfo]:
"""获取插件服务元信息。
service_name支持
- full_name: plugin_name.service_name
- short_name: service_name当唯一时可解析
"""
full_name = self._resolve_full_name(service_name, plugin_name)
return self._services.get(full_name) if full_name else None
def get_service_handler(self, service_name: str, plugin_name: Optional[str] = None) -> Optional[Callable[..., Any]]:
"""获取插件服务处理函数。"""
full_name = self._resolve_full_name(service_name, plugin_name)
return self._service_handlers.get(full_name) if full_name else None
def list_services(
self, plugin_name: Optional[str] = None, enabled_only: bool = False
) -> Dict[str, PluginServiceInfo]:
"""列出插件服务。"""
services = self._services.copy()
if plugin_name:
services = {name: info for name, info in services.items() if info.plugin_name == plugin_name}
if enabled_only:
services = {name: info for name, info in services.items() if info.enabled}
return services
def enable_service(self, service_name: str, plugin_name: Optional[str] = None) -> bool:
"""启用插件服务。"""
if not (service_info := self.get_service(service_name, plugin_name)):
logger.warning(f"插件服务未注册,无法启用: {service_name}")
return False
service_info.enabled = True
logger.info(f"插件服务已启用: {service_info.full_name}")
return True
def disable_service(self, service_name: str, plugin_name: Optional[str] = None) -> bool:
"""禁用插件服务。"""
if not (service_info := self.get_service(service_name, plugin_name)):
logger.warning(f"插件服务未注册,无法禁用: {service_name}")
return False
service_info.enabled = False
logger.info(f"插件服务已禁用: {service_info.full_name}")
return True
def unregister_service(self, service_name: str, plugin_name: Optional[str] = None) -> bool:
"""注销单个插件服务。"""
full_name = self._resolve_full_name(service_name, plugin_name)
if not full_name:
logger.warning(f"插件服务未注册,无法注销: {service_name}")
return False
self._services.pop(full_name, None)
self._service_handlers.pop(full_name, None)
logger.info(f"插件服务已注销: {full_name}")
return True
def remove_services_by_plugin(self, plugin_name: str) -> int:
"""移除某插件的所有注册服务。"""
target_names = [full_name for full_name, info in self._services.items() if info.plugin_name == plugin_name]
for full_name in target_names:
self._services.pop(full_name, None)
self._service_handlers.pop(full_name, None)
removed_count = len(target_names)
if removed_count:
logger.info(f"已移除插件 {plugin_name} 的服务数量: {removed_count}")
return removed_count
async def call_service(
self,
service_name: str,
*args: Any,
plugin_name: Optional[str] = None,
caller_plugin: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""调用插件服务(支持同步/异步handler"""
service_info = self.get_service(service_name, plugin_name)
if not service_info:
target_name = f"{plugin_name}.{service_name}" if plugin_name and "." not in service_name else service_name
raise ValueError(f"插件服务未注册: {target_name}")
if (
"." not in service_name
and plugin_name is None
and caller_plugin
and service_info.plugin_name != caller_plugin
):
raise PermissionError("跨插件服务调用必须使用完整服务名或显式指定plugin_name")
if not self._is_call_authorized(service_info, caller_plugin):
raise PermissionError(
f"调用被拒绝: caller={caller_plugin or 'anonymous'} 无权限访问服务 {service_info.full_name}"
)
if not service_info.enabled:
raise RuntimeError(f"插件服务已禁用: {service_info.full_name}")
handler = self.get_service_handler(service_name, plugin_name)
if not handler:
raise RuntimeError(f"插件服务处理器不存在: {service_info.full_name}")
self._validate_input_contract(service_info, args, kwargs)
result = handler(*args, **kwargs)
resolved_result = await result if inspect.isawaitable(result) else result
self._validate_output_contract(service_info, resolved_result)
return resolved_result
def _is_call_authorized(self, service_info: PluginServiceInfo, caller_plugin: Optional[str]) -> bool:
"""检查服务调用权限。"""
if caller_plugin is None:
return service_info.public
if caller_plugin == service_info.plugin_name:
return True
if service_info.public:
return True
allowed_callers = {caller.strip() for caller in service_info.allowed_callers if caller.strip()}
return "*" in allowed_callers or caller_plugin in allowed_callers
def _validate_input_contract(
self, service_info: PluginServiceInfo, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> None:
"""校验服务入参契约。"""
schema = service_info.params_schema
if not schema:
return
properties = schema.get("properties", {}) if isinstance(schema, dict) else {}
is_invocation_schema = "args" in properties or "kwargs" in properties
if is_invocation_schema:
payload = {"args": list(args), "kwargs": kwargs}
self._validate_by_schema(payload, schema, path="params")
return
if args:
raise ValueError(f"服务 {service_info.full_name} 的入参契约不允许位置参数")
self._validate_by_schema(kwargs, schema, path="params")
def _validate_output_contract(self, service_info: PluginServiceInfo, value: Any) -> None:
"""校验服务返回值契约。"""
if not service_info.return_schema:
return
self._validate_by_schema(value, service_info.return_schema, path="return")
def _validate_by_schema(self, value: Any, schema: Dict[str, Any], path: str) -> None:
"""基于简化JSON-Schema校验数据。"""
expected_type = schema.get("type")
if expected_type:
self._validate_type(value, expected_type, path)
enum_values = schema.get("enum")
if enum_values is not None and value not in enum_values:
raise ValueError(f"{path} 不在枚举范围内: {value}")
if expected_type == "object":
properties = schema.get("properties", {})
required = schema.get("required", [])
for field in required:
if field not in value:
raise ValueError(f"{path}.{field} 为必填字段")
for field, field_value in value.items():
if field in properties:
self._validate_by_schema(field_value, properties[field], f"{path}.{field}")
elif schema.get("additionalProperties", True) is False:
raise ValueError(f"{path}.{field} 不允许额外字段")
if expected_type == "array":
if item_schema := schema.get("items"):
for index, item in enumerate(value):
self._validate_by_schema(item, item_schema, f"{path}[{index}]")
def _validate_type(self, value: Any, expected_type: str, path: str) -> None:
"""校验基础类型。"""
type_checkers: Dict[str, Callable[[Any], bool]] = {
"string": lambda item: isinstance(item, str),
"number": lambda item: isinstance(item, (int, float)) and not isinstance(item, bool),
"integer": lambda item: isinstance(item, int) and not isinstance(item, bool),
"boolean": lambda item: isinstance(item, bool),
"object": lambda item: isinstance(item, dict),
"array": lambda item: isinstance(item, list),
"null": lambda item: item is None,
}
checker = type_checkers.get(expected_type)
if checker and not checker(value):
raise TypeError(f"{path} 类型不匹配,期望 {expected_type},实际 {type(value).__name__}")
def _resolve_full_name(self, service_name: str, plugin_name: Optional[str] = None) -> Optional[str]:
"""解析服务全名。"""
if "." in service_name:
return service_name if service_name in self._services else None
if plugin_name:
full_name = f"{plugin_name}.{service_name}"
return full_name if full_name in self._services else None
candidates = [full_name for full_name, info in self._services.items() if info.name == service_name]
if len(candidates) == 1:
return candidates[0]
if len(candidates) > 1:
logger.warning(f"插件服务名称 '{service_name}' 存在多义请传入plugin_name或使用完整服务名")
return None
return None
plugin_service_registry = PluginServiceRegistry()

View File

@@ -1,13 +0,0 @@
- [x] 自定义事件
- [ ] <del>允许handler随时订阅</del>
- [x] 允许其他组件给handler增加订阅
- [x] 允许其他组件给handler取消订阅
- [ ] <del>允许一个handler订阅多个事件</del>
- [x] event激活时给handler传递参数
- [ ] handler能拿到所有handlers的结果按照处理权重
- [x] 随时注册
- [ ] <del>删除event</del>
- [ ] 必要性?
- [x] 能够更改prompt
- [x] 能够更改llm_response
- [x] 能够更改message

View File

@@ -1,260 +0,0 @@
from typing import Any, Awaitable, Callable, List, Optional, Tuple, Union
import asyncio
import inspect
import time
import uuid
from src.common.logger import get_logger
from src.plugin_system.base.component_types import EventType, MaiMessages
from src.plugin_system.base.workflow_errors import WorkflowErrorCode
from src.plugin_system.base.workflow_types import WorkflowContext, WorkflowStage, WorkflowStepResult
logger = get_logger("workflow_engine")
class WorkflowEngine:
"""线性Workflow执行器MVP"""
STAGE_EVENT_SEQUENCE: List[Tuple[WorkflowStage, Union[EventType, str]]] = [
(WorkflowStage.INGRESS, "workflow.ingress"),
(WorkflowStage.PRE_PROCESS, EventType.ON_MESSAGE_PRE_PROCESS),
(WorkflowStage.PLAN, EventType.ON_PLAN),
(WorkflowStage.TOOL_EXECUTE, "workflow.tool_execute"),
(WorkflowStage.POST_PROCESS, EventType.POST_SEND_PRE_PROCESS),
(WorkflowStage.EGRESS, "workflow.egress"),
]
def __init__(self):
self._execution_history: dict[str, dict[str, Any]] = {}
async def execute_linear(
self,
dispatch_event: Callable[
[Union[EventType, str], Optional[MaiMessages], Optional[str], Optional[List[str]]],
Awaitable[Tuple[bool, Optional[MaiMessages]]],
],
message: Optional[MaiMessages] = None,
stream_id: Optional[str] = None,
action_usage: Optional[List[str]] = None,
context: Optional[WorkflowContext] = None,
) -> Tuple[WorkflowStepResult, Optional[MaiMessages], WorkflowContext]:
"""执行线性workflow。"""
workflow_context = context or WorkflowContext(trace_id=uuid.uuid4().hex, stream_id=stream_id)
current_message = message.deepcopy() if message else None
self._execution_history[workflow_context.trace_id] = {
"trace_id": workflow_context.trace_id,
"stream_id": workflow_context.stream_id,
"stages": [],
"errors": [],
"status": "running",
}
for stage, event_type in self.STAGE_EVENT_SEQUENCE:
stage_key = str(stage)
stage_start = time.perf_counter()
try:
should_continue, modified_message = await dispatch_event(
event_type,
current_message,
workflow_context.stream_id,
action_usage,
)
workflow_context.timings[stage_key] = time.perf_counter() - stage_start
self._execution_history[workflow_context.trace_id]["stages"].append(
{
"stage": stage_key,
"event_type": str(event_type),
"event_continue": should_continue,
"event_cost": workflow_context.timings[stage_key],
}
)
if modified_message:
current_message = modified_message
if not should_continue:
logger.info(f"[trace_id={workflow_context.trace_id}] Workflow在阶段 {stage_key} 被中断")
return (
WorkflowStepResult(
status="stop",
return_message=f"workflow stopped at stage {stage_key}",
diagnostics={
"stage": stage_key,
"trace_id": workflow_context.trace_id,
"error_code": WorkflowErrorCode.POLICY_BLOCKED.value,
},
),
current_message,
workflow_context,
)
step_result = await self._execute_registered_steps(stage, workflow_context, current_message)
if step_result.status in ["stop", "failed"]:
self._execution_history[workflow_context.trace_id]["status"] = step_result.status
self._execution_history[workflow_context.trace_id]["errors"] = workflow_context.errors.copy()
return step_result, current_message, workflow_context
except Exception as e:
workflow_context.timings[stage_key] = time.perf_counter() - stage_start
workflow_context.errors.append(f"{stage_key}: {e}")
logger.error(
f"[trace_id={workflow_context.trace_id}] Workflow阶段 {stage_key} 执行异常: {e}", exc_info=True
)
self._execution_history[workflow_context.trace_id]["status"] = "failed"
self._execution_history[workflow_context.trace_id]["errors"] = workflow_context.errors.copy()
return (
WorkflowStepResult(
status="failed",
return_message=str(e),
diagnostics={
"stage": stage_key,
"trace_id": workflow_context.trace_id,
"error_code": WorkflowErrorCode.DOWNSTREAM_FAILED.value,
},
),
current_message,
workflow_context,
)
self._execution_history[workflow_context.trace_id]["status"] = "continue"
self._execution_history[workflow_context.trace_id]["errors"] = workflow_context.errors.copy()
return (
WorkflowStepResult(
status="continue",
return_message="workflow completed",
diagnostics={"trace_id": workflow_context.trace_id},
),
current_message,
workflow_context,
)
async def _execute_registered_steps(
self,
stage: WorkflowStage,
context: WorkflowContext,
message: Optional[MaiMessages],
) -> WorkflowStepResult:
"""执行指定阶段已注册的workflow步骤。"""
from src.plugin_system.core.component_registry import component_registry
stage_steps = component_registry.get_steps_by_stage(stage, enabled_only=True)
sorted_steps = sorted(stage_steps.values(), key=lambda step_info: step_info.priority, reverse=True)
for step_info in sorted_steps:
handler = component_registry.get_workflow_step_handler(step_info.full_name, stage)
if not handler:
context.errors.append(f"{step_info.full_name}: handler not found")
continue
step_timing_key = f"{stage.value}:{step_info.full_name}"
step_start = time.perf_counter()
timeout_seconds = step_info.timeout_ms / 1000 if step_info.timeout_ms > 0 else None
try:
if inspect.iscoroutinefunction(handler):
coroutine = handler(context, message)
result = await asyncio.wait_for(coroutine, timeout_seconds) if timeout_seconds else await coroutine
else:
if timeout_seconds:
result = await asyncio.wait_for(asyncio.to_thread(handler, context, message), timeout_seconds)
else:
result = handler(context, message)
if inspect.isawaitable(result):
result = await asyncio.wait_for(result, timeout_seconds) if timeout_seconds else await result
context.timings[step_timing_key] = time.perf_counter() - step_start
normalized_result = self._normalize_step_result(result)
if normalized_result.status == "continue":
continue
normalized_result.diagnostics.setdefault("stage", stage.value)
normalized_result.diagnostics.setdefault("step", step_info.full_name)
normalized_result.diagnostics.setdefault("trace_id", context.trace_id)
if normalized_result.status == "failed":
context.errors.append(
f"{step_info.full_name}: {normalized_result.return_message or 'workflow step failed'}"
)
normalized_result.diagnostics.setdefault("error_code", WorkflowErrorCode.DOWNSTREAM_FAILED.value)
return normalized_result
except asyncio.TimeoutError:
context.timings[step_timing_key] = time.perf_counter() - step_start
timeout_message = f"workflow step timeout after {step_info.timeout_ms}ms"
context.errors.append(f"{step_info.full_name}: {timeout_message}")
logger.error(
f"[trace_id={context.trace_id}] Workflow step {step_info.full_name} 超时: {timeout_message}"
)
return WorkflowStepResult(
status="failed",
return_message=timeout_message,
diagnostics={
"stage": stage.value,
"step": step_info.full_name,
"trace_id": context.trace_id,
"error_code": WorkflowErrorCode.STEP_TIMEOUT.value,
},
)
except Exception as e:
context.timings[step_timing_key] = time.perf_counter() - step_start
context.errors.append(f"{step_info.full_name}: {e}")
logger.error(
f"[trace_id={context.trace_id}] Workflow step {step_info.full_name} 执行异常: {e}", exc_info=True
)
return WorkflowStepResult(
status="failed",
return_message=str(e),
diagnostics={
"stage": stage.value,
"step": step_info.full_name,
"trace_id": context.trace_id,
"error_code": WorkflowErrorCode.DOWNSTREAM_FAILED.value,
},
)
return WorkflowStepResult(status="continue", diagnostics={"stage": stage.value, "trace_id": context.trace_id})
def _normalize_step_result(self, result: Any) -> WorkflowStepResult:
"""归一化workflow step返回值。"""
if isinstance(result, WorkflowStepResult):
return result
if isinstance(result, bool):
if result:
return WorkflowStepResult(status="continue")
return WorkflowStepResult(
status="failed",
diagnostics={"error_code": WorkflowErrorCode.DOWNSTREAM_FAILED.value},
)
if result is None:
return WorkflowStepResult(status="continue")
if isinstance(result, str):
return WorkflowStepResult(status="continue", return_message=result)
if isinstance(result, dict):
status = result.get("status", "continue")
if status not in ["continue", "stop", "failed"]:
status = "failed"
return WorkflowStepResult(
status=status,
return_message=result.get("return_message"),
diagnostics=result.get("diagnostics", {}),
events=result.get("events", []),
)
return WorkflowStepResult(
status="failed",
return_message=f"unsupported step result type: {type(result)}",
diagnostics={"error_code": WorkflowErrorCode.BAD_PAYLOAD.value},
)
def get_execution_trace(self, trace_id: str) -> Optional[dict[str, Any]]:
"""按trace_id获取workflow执行路径。"""
trace = self._execution_history.get(trace_id)
return trace.copy() if trace else None
def clear_execution_trace(self, trace_id: str) -> bool:
"""清理trace执行记录。"""
if trace_id in self._execution_history:
del self._execution_history[trace_id]
return True
return False
workflow_engine = WorkflowEngine()

View File

@@ -1,19 +0,0 @@
"""
插件系统工具模块
提供插件开发和管理的实用工具
"""
from .manifest_utils import (
ManifestValidator,
# ManifestGenerator,
# validate_plugin_manifest,
# generate_plugin_manifest,
)
__all__ = [
"ManifestValidator",
# "ManifestGenerator",
# "validate_plugin_manifest",
# "generate_plugin_manifest",
]

View File

@@ -1,515 +0,0 @@
"""
插件Manifest工具模块
提供manifest文件的验证、生成和管理功能
"""
import re
from typing import Dict, Any, Tuple
from src.common.logger import get_logger
from src.config.config import MMC_VERSION
# if TYPE_CHECKING:
# from src.plugin_system.base.base_plugin import BasePlugin
logger = get_logger("manifest_utils")
class VersionComparator:
"""版本号比较器
支持语义化版本号比较自动处理snapshot版本并支持向前兼容性检查
"""
# 版本兼容性映射表(硬编码)
# 格式: {插件最大支持版本: [实际兼容的版本列表]}
COMPATIBILITY_MAP = {
# 0.8.x 系列向前兼容规则
"0.8.0": ["0.8.1", "0.8.2", "0.8.3", "0.8.4", "0.8.5", "0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"],
"0.8.1": ["0.8.2", "0.8.3", "0.8.4", "0.8.5", "0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"],
"0.8.2": ["0.8.3", "0.8.4", "0.8.5", "0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"],
"0.8.3": ["0.8.4", "0.8.5", "0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"],
"0.8.4": ["0.8.5", "0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"],
"0.8.5": ["0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"],
"0.8.6": ["0.8.7", "0.8.8", "0.8.9", "0.8.10"],
"0.8.7": ["0.8.8", "0.8.9", "0.8.10"],
"0.8.8": ["0.8.9", "0.8.10"],
"0.8.9": ["0.8.10"],
# 可以根据需要添加更多兼容映射
# "0.9.0": ["0.9.1", "0.9.2", "0.9.3"], # 示例0.9.x系列兼容
}
@staticmethod
def normalize_version(version: str) -> str:
"""标准化版本号移除snapshot标识
Args:
version: 原始版本号,如 "0.8.0-snapshot.1"
Returns:
str: 标准化后的版本号,如 "0.8.0"
"""
if not version:
return "0.0.0"
# 移除snapshot部分
normalized = re.sub(r"-snapshot\.\d+", "", version.strip())
# 确保版本号格式正确
if not re.match(r"^\d+(\.\d+){0,2}$", normalized):
# 如果不是有效的版本号格式,返回默认版本
return "0.0.0"
# 尝试补全版本号
parts = normalized.split(".")
while len(parts) < 3:
parts.append("0")
normalized = ".".join(parts[:3])
return normalized
@staticmethod
def parse_version(version: str) -> Tuple[int, int, int]:
"""解析版本号为元组
Args:
version: 版本号字符串
Returns:
Tuple[int, int, int]: (major, minor, patch)
"""
normalized = VersionComparator.normalize_version(version)
try:
parts = normalized.split(".")
return (int(parts[0]), int(parts[1]), int(parts[2]))
except (ValueError, IndexError):
logger.warning(f"无法解析版本号: {version},使用默认版本 0.0.0")
return (0, 0, 0)
@staticmethod
def compare_versions(version1: str, version2: str) -> int:
"""比较两个版本号
Args:
version1: 第一个版本号
version2: 第二个版本号
Returns:
int: -1 if version1 < version2, 0 if equal, 1 if version1 > version2
"""
v1_tuple = VersionComparator.parse_version(version1)
v2_tuple = VersionComparator.parse_version(version2)
if v1_tuple < v2_tuple:
return -1
elif v1_tuple > v2_tuple:
return 1
else:
return 0
@staticmethod
def check_forward_compatibility(current_version: str, max_version: str) -> Tuple[bool, str]:
"""检查向前兼容性(仅使用兼容性映射表)
Args:
current_version: 当前版本
max_version: 插件声明的最大支持版本
Returns:
Tuple[bool, str]: (是否兼容, 兼容信息)
"""
current_normalized = VersionComparator.normalize_version(current_version)
max_normalized = VersionComparator.normalize_version(max_version)
# 检查兼容性映射表
if max_normalized in VersionComparator.COMPATIBILITY_MAP:
compatible_versions = VersionComparator.COMPATIBILITY_MAP[max_normalized]
if current_normalized in compatible_versions:
return True, f"根据兼容性映射表,版本 {current_normalized}{max_normalized} 兼容"
return False, ""
@staticmethod
def is_version_in_range(version: str, min_version: str = "", max_version: str = "") -> Tuple[bool, str]:
"""检查版本是否在指定范围内,支持兼容性检查
Args:
version: 要检查的版本号
min_version: 最小版本号(可选)
max_version: 最大版本号(可选)
Returns:
Tuple[bool, str]: (是否兼容, 错误信息或兼容信息)
"""
if not min_version and not max_version:
return True, ""
version_normalized = VersionComparator.normalize_version(version)
# 检查最小版本
if min_version:
min_normalized = VersionComparator.normalize_version(min_version)
if VersionComparator.compare_versions(version_normalized, min_normalized) < 0:
return False, f"版本 {version_normalized} 低于最小要求版本 {min_normalized}"
# 检查最大版本
if max_version:
max_normalized = VersionComparator.normalize_version(max_version)
comparison = VersionComparator.compare_versions(version_normalized, max_normalized)
if comparison > 0:
# 严格版本检查失败,尝试兼容性检查
is_compatible, compat_msg = VersionComparator.check_forward_compatibility(
version_normalized, max_normalized
)
if not is_compatible:
return False, f"版本 {version_normalized} 高于最大支持版本 {max_normalized},且无兼容性映射"
logger.info(f"版本兼容性检查:{compat_msg}")
return True, compat_msg
return True, ""
@staticmethod
def get_current_host_version() -> str:
"""获取当前主机应用版本
Returns:
str: 当前版本号
"""
return VersionComparator.normalize_version(MMC_VERSION)
@staticmethod
def add_compatibility_mapping(base_version: str, compatible_versions: list) -> None:
"""动态添加兼容性映射
Args:
base_version: 基础版本(插件声明的最大支持版本)
compatible_versions: 兼容的版本列表
"""
base_normalized = VersionComparator.normalize_version(base_version)
VersionComparator.COMPATIBILITY_MAP[base_normalized] = [
VersionComparator.normalize_version(v) for v in compatible_versions
]
logger.info(f"添加兼容性映射:{base_normalized} -> {compatible_versions}")
@staticmethod
def get_compatibility_info() -> Dict[str, list]:
"""获取当前的兼容性映射表
Returns:
Dict[str, list]: 兼容性映射表的副本
"""
return VersionComparator.COMPATIBILITY_MAP.copy()
class ManifestValidator:
"""Manifest文件验证器"""
# 必需字段(必须存在且不能为空)
REQUIRED_FIELDS = ["manifest_version", "name", "version", "description", "author"]
# 可选字段(可以不存在或为空)
OPTIONAL_FIELDS = [
"license",
"host_application",
"homepage_url",
"repository_url",
"keywords",
"categories",
"default_locale",
"locales_path",
"plugin_info",
]
# 建议填写的字段(会给出警告但不会导致验证失败)
RECOMMENDED_FIELDS = ["license", "keywords", "categories"]
SUPPORTED_MANIFEST_VERSIONS = [1]
def __init__(self):
self.validation_errors = []
self.validation_warnings = []
def validate_manifest(self, manifest_data: Dict[str, Any]) -> bool:
"""验证manifest数据
Args:
manifest_data: manifest数据字典
Returns:
bool: 是否验证通过(只有错误会导致验证失败,警告不会)
"""
self.validation_errors.clear()
self.validation_warnings.clear()
# 检查必需字段
for field in self.REQUIRED_FIELDS:
if field not in manifest_data:
self.validation_errors.append(f"缺少必需字段: {field}")
elif not manifest_data[field]:
self.validation_errors.append(f"必需字段不能为空: {field}")
# 检查manifest版本
if "manifest_version" in manifest_data:
version = manifest_data["manifest_version"]
if version not in self.SUPPORTED_MANIFEST_VERSIONS:
self.validation_errors.append(
f"不支持的manifest版本: {version},支持的版本: {self.SUPPORTED_MANIFEST_VERSIONS}"
)
# 检查作者信息格式
if "author" in manifest_data:
author = manifest_data["author"]
if isinstance(author, dict):
if "name" not in author or not author["name"]:
self.validation_errors.append("作者信息缺少name字段或为空")
# url字段是可选的
if "url" in author and author["url"]:
url = author["url"]
if not (url.startswith("http://") or url.startswith("https://")):
self.validation_warnings.append("作者URL建议使用完整的URL格式")
elif isinstance(author, str):
if not author.strip():
self.validation_errors.append("作者信息不能为空")
else:
self.validation_errors.append("作者信息格式错误应为字符串或包含name字段的对象")
# 检查主机应用版本要求(可选)
if "host_application" in manifest_data:
host_app = manifest_data["host_application"]
if isinstance(host_app, dict):
min_version = host_app.get("min_version", "")
max_version = host_app.get("max_version", "")
# 验证版本字段格式
for version_field in ["min_version", "max_version"]:
if version_field in host_app and not host_app[version_field]:
self.validation_warnings.append(f"host_application.{version_field}为空")
# 检查当前主机版本兼容性
if min_version or max_version:
current_version = VersionComparator.get_current_host_version()
is_compatible, error_msg = VersionComparator.is_version_in_range(
current_version, min_version, max_version
)
if not is_compatible:
self.validation_errors.append(f"版本兼容性检查失败: {error_msg} (当前版本: {current_version})")
else:
logger.debug(
f"版本兼容性检查通过: 当前版本 {current_version} 符合要求 [{min_version}, {max_version}]"
)
else:
self.validation_errors.append("host_application格式错误应为对象")
# 检查URL格式可选字段
for url_field in ["homepage_url", "repository_url"]:
if url_field in manifest_data and manifest_data[url_field]:
url: str = manifest_data[url_field]
if not (url.startswith("http://") or url.startswith("https://")):
self.validation_warnings.append(f"{url_field}建议使用完整的URL格式")
# 检查数组字段格式(可选字段)
for list_field in ["keywords", "categories"]:
if list_field in manifest_data:
field_value = manifest_data[list_field]
if field_value is not None and not isinstance(field_value, list):
self.validation_errors.append(f"{list_field}应为数组格式")
elif isinstance(field_value, list):
# 检查数组元素是否为字符串
for i, item in enumerate(field_value):
if not isinstance(item, str):
self.validation_warnings.append(f"{list_field}[{i}]应为字符串")
# 检查建议字段(给出警告)
for field in self.RECOMMENDED_FIELDS:
if field not in manifest_data or not manifest_data[field]:
self.validation_warnings.append(f"建议填写字段: {field}")
# 检查plugin_info结构可选
if "plugin_info" in manifest_data:
plugin_info = manifest_data["plugin_info"]
if isinstance(plugin_info, dict):
# 检查components数组
if "components" in plugin_info:
components = plugin_info["components"]
if not isinstance(components, list):
self.validation_errors.append("plugin_info.components应为数组格式")
else:
for i, component in enumerate(components):
if not isinstance(component, dict):
self.validation_errors.append(f"plugin_info.components[{i}]应为对象")
else:
# 检查组件必需字段
for comp_field in ["type", "name", "description"]:
if comp_field not in component or not component[comp_field]:
self.validation_errors.append(
f"plugin_info.components[{i}]缺少必需字段: {comp_field}"
)
else:
self.validation_errors.append("plugin_info应为对象格式")
return len(self.validation_errors) == 0
def get_validation_report(self) -> str:
"""获取验证报告"""
report = []
if self.validation_errors:
report.append("❌ 验证错误:")
report.extend(f" - {error}" for error in self.validation_errors)
if self.validation_warnings:
report.append("⚠️ 验证警告:")
report.extend(f" - {warning}" for warning in self.validation_warnings)
if not self.validation_errors and not self.validation_warnings:
report.append("✅ Manifest文件验证通过")
return "\n".join(report)
# class ManifestGenerator:
# """Manifest文件生成器"""
# def __init__(self):
# self.template = {
# "manifest_version": 1,
# "name": "",
# "version": "1.0.0",
# "description": "",
# "author": {"name": "", "url": ""},
# "license": "MIT",
# "host_application": {"min_version": "1.0.0", "max_version": "4.0.0"},
# "homepage_url": "",
# "repository_url": "",
# "keywords": [],
# "categories": [],
# "default_locale": "zh-CN",
# "locales_path": "_locales",
# }
# def generate_from_plugin(self, plugin_instance: BasePlugin) -> Dict[str, Any]:
# """从插件实例生成manifest
# Args:
# plugin_instance: BasePlugin实例
# Returns:
# Dict[str, Any]: 生成的manifest数据
# """
# manifest = self.template.copy()
# # 基本信息
# manifest["name"] = plugin_instance.plugin_name
# manifest["version"] = plugin_instance.plugin_version
# manifest["description"] = plugin_instance.plugin_description
# # 作者信息
# if plugin_instance.plugin_author:
# manifest["author"]["name"] = plugin_instance.plugin_author
# # 组件信息
# components = []
# plugin_components = plugin_instance.get_plugin_components()
# for component_info, component_class in plugin_components:
# component_data: Dict[str, Any] = {
# "type": component_info.component_type.value,
# "name": component_info.name,
# "description": component_info.description,
# }
# # 添加激活模式信息对于Action组件
# if hasattr(component_class, "focus_activation_type"):
# activation_modes = []
# if hasattr(component_class, "focus_activation_type"):
# activation_modes.append(component_class.focus_activation_type.value)
# if hasattr(component_class, "normal_activation_type"):
# activation_modes.append(component_class.normal_activation_type.value)
# component_data["activation_modes"] = list(set(activation_modes))
# # 添加关键词信息
# if hasattr(component_class, "activation_keywords"):
# keywords = getattr(component_class, "activation_keywords", [])
# if keywords:
# component_data["keywords"] = keywords
# components.append(component_data)
# manifest["plugin_info"] = {"is_built_in": True, "plugin_type": "general", "components": components}
# return manifest
# def save_manifest(self, manifest_data: Dict[str, Any], plugin_dir: str) -> bool:
# """保存manifest文件
# Args:
# manifest_data: manifest数据
# plugin_dir: 插件目录
# Returns:
# bool: 是否保存成功
# """
# try:
# manifest_path = os.path.join(plugin_dir, "_manifest.json")
# with open(manifest_path, "w", encoding="utf-8") as f:
# json.dump(manifest_data, f, ensure_ascii=False, indent=2)
# logger.info(f"Manifest文件已保存: {manifest_path}")
# return True
# except Exception as e:
# logger.error(f"保存manifest文件失败: {e}")
# return False
# def validate_plugin_manifest(plugin_dir: str) -> bool:
# """验证插件目录中的manifest文件
# Args:
# plugin_dir: 插件目录路径
# Returns:
# bool: 是否验证通过
# """
# manifest_path = os.path.join(plugin_dir, "_manifest.json")
# if not os.path.exists(manifest_path):
# logger.warning(f"未找到manifest文件: {manifest_path}")
# return False
# try:
# with open(manifest_path, "r", encoding="utf-8") as f:
# manifest_data = json.load(f)
# validator = ManifestValidator()
# is_valid = validator.validate_manifest(manifest_data)
# logger.info(f"Manifest验证结果:\n{validator.get_validation_report()}")
# return is_valid
# except Exception as e:
# logger.error(f"读取或验证manifest文件失败: {e}")
# return False
# def generate_plugin_manifest(plugin_instance: BasePlugin, save_to_file: bool = True) -> Optional[Dict[str, Any]]:
# """为插件生成manifest文件
# Args:
# plugin_instance: BasePlugin实例
# save_to_file: 是否保存到文件
# Returns:
# Optional[Dict[str, Any]]: 生成的manifest数据
# """
# try:
# generator = ManifestGenerator()
# manifest_data = generator.generate_from_plugin(plugin_instance)
# if save_to_file and plugin_instance.plugin_dir:
# generator.save_manifest(manifest_data, plugin_instance.plugin_dir)
# return manifest_data
# except Exception as e:
# logger.error(f"生成manifest文件失败: {e}")
# return None

View File

@@ -1,25 +1,21 @@
{
"manifest_version": 1,
"name": "Emoji插件 (Emoji Actions)",
"version": "1.0.0",
"version": "2.0.0",
"description": "可以发送和管理Emoji",
"author": {
"name": "SengokuCola",
"url": "https://github.com/MaiM-with-u"
},
"license": "GPL-v3.0-or-later",
"host_application": {
"min_version": "0.10.0"
"min_version": "1.0.0"
},
"homepage_url": "https://github.com/MaiM-with-u/maibot",
"repository_url": "https://github.com/MaiM-with-u/maibot",
"keywords": ["emoji", "action", "built-in"],
"categories": ["Emoji"],
"default_locale": "zh-CN",
"locales_path": "_locales",
"plugin_info": {
"is_built_in": true,
"plugin_type": "action_provider",
@@ -30,5 +26,13 @@
"description": "发送表情包辅助表达情绪"
}
]
}
},
"capabilities": [
"emoji.get_random",
"message.get_recent",
"message.build_readable",
"llm.generate",
"send.emoji",
"config.get"
]
}

View File

@@ -1,151 +0,0 @@
import random
from typing import Tuple
# 导入新插件系统
from src.plugin_system import BaseAction, ActionActivationType
# 导入依赖的系统组件
from src.common.logger import get_logger
# 导入API模块 - 标准Python包方式
from src.plugin_system.apis import emoji_api, llm_api, message_api
# NoReplyAction已集成到heartFC_chat.py中不再需要导入
from src.config.config import global_config
logger = get_logger("emoji")
class EmojiAction(BaseAction):
"""表情动作 - 发送表情包"""
activation_type = ActionActivationType.RANDOM
random_activation_probability = global_config.emoji.emoji_chance
parallel_action = True
# 动作基本信息
action_name = "emoji"
action_description = "发送表情包辅助表达情绪"
# 动作参数定义
action_parameters = {}
# 动作使用场景
action_require = [
"发送表情包辅助表达情绪",
"表达情绪时可以选择使用",
"不要连续发送,如果你已经发过[表情包],就不要选择此动作",
]
# 关联类型
associated_types = ["emoji"]
async def execute(self) -> Tuple[bool, str]:
# sourcery skip: assign-if-exp, introduce-default-else, swap-if-else-branches, use-named-expression
"""执行表情动作"""
try:
# 1. 获取发送表情的原因
# reason = self.action_data.get("reason", "表达当前情绪")
reason = self.action_reasoning
# 2. 随机获取20个表情包
sampled_emojis = await emoji_api.get_random(30)
if not sampled_emojis:
logger.warning(f"{self.log_prefix} 无法获取随机表情包")
return False, "无法获取随机表情包"
# 3. 准备情感数据
emotion_map = {}
for b64, desc, emo in sampled_emojis:
if emo not in emotion_map:
emotion_map[emo] = []
emotion_map[emo].append((b64, desc))
available_emotions = list(emotion_map.keys())
available_emotions_str = ""
for emotion in available_emotions:
available_emotions_str += f"{emotion}\n"
if not available_emotions:
logger.warning(f"{self.log_prefix} 获取到的表情包均无情感标签, 将随机发送")
emoji_base64, emoji_description, _ = random.choice(sampled_emojis)
else:
# 获取最近的5条消息内容用于判断
recent_messages = message_api.get_recent_messages(chat_id=self.chat_id, limit=5)
messages_text = ""
if recent_messages:
# 使用message_api构建可读的消息字符串
messages_text = message_api.build_readable_messages(
messages=recent_messages,
timestamp_mode="normal_no_YMD",
truncate=False,
show_actions=False,
)
# 4. 构建prompt让LLM选择情感
prompt = f"""你正在进行QQ聊天你需要根据聊天记录选出一个合适的情感标签。
请你根据以下原因和聊天记录进行选择
原因:{reason}
聊天记录:
{messages_text}
这里是可用的情感标签:
{available_emotions_str}
请直接返回最匹配的那个情感标签,不要进行任何解释或添加其他多余的文字。
"""
if global_config.debug.show_prompt:
logger.info(f"{self.log_prefix} 生成的LLM Prompt: {prompt}")
else:
logger.debug(f"{self.log_prefix} 生成的LLM Prompt: {prompt}")
# 5. 调用LLM
models = llm_api.get_available_models()
chat_model_config = models.get("utils") # 使用字典访问方式
if not chat_model_config:
logger.error(f"{self.log_prefix} 未找到'utils'模型配置无法调用LLM")
return False, "未找到'utils'模型配置"
success, chosen_emotion, _, _ = await llm_api.generate_with_model(
prompt, model_config=chat_model_config, request_type="emoji.select"
)
if not success:
logger.error(f"{self.log_prefix} LLM调用失败: {chosen_emotion}")
return False, f"LLM调用失败: {chosen_emotion}"
chosen_emotion = chosen_emotion.strip().replace('"', "").replace("'", "")
logger.info(f"{self.log_prefix} LLM选择的情感: {chosen_emotion}")
# 6. 根据选择的情感匹配表情包
if chosen_emotion in emotion_map:
emoji_base64, emoji_description = random.choice(emotion_map[chosen_emotion])
logger.info(f"{self.log_prefix} 发送表情包[{chosen_emotion}],原因: {reason}")
else:
logger.warning(
f"{self.log_prefix} LLM选择的情感 '{chosen_emotion}' 不在可用列表中, 将随机选择一个表情包"
)
emoji_base64, emoji_description, _ = random.choice(sampled_emojis)
# 7. 发送表情包
success = await self.send_emoji(emoji_base64)
if success:
# 存储动作信息
await self.store_action_info(
action_build_into_prompt=True,
action_prompt_display=f"你发送了表情包,原因:{reason}",
action_done=True,
)
return True, f"成功发送表情包:[表情包:{chosen_emotion}]"
else:
error_msg = "发送表情包失败"
logger.error(f"{self.log_prefix} {error_msg}")
await self.send_text("执行表情包动作失败")
return False, error_msg
except Exception as e:
logger.error(f"{self.log_prefix} 表情动作执行失败: {e}", exc_info=True)
return False, f"表情发送失败: {str(e)}"

View File

@@ -1,66 +1,116 @@
"""
核心动作插件
"""Emoji 插件 — 新 SDK 版本
将系统核心动作reply、no_reply、emoji转换为新插件系统格式
这是系统的内置插件,提供基础的聊天交互功能
根据聊天上下文的情感,使用 LLM 选择并发送合适的表情包。
"""
from typing import List, Tuple, Type
import random
# 导入新插件系统
from src.plugin_system import BasePlugin, register_plugin, ComponentInfo
from src.plugin_system.base.config_types import ConfigField
# 导入依赖的系统组件
from src.common.logger import get_logger
from src.plugins.built_in.emoji_plugin.emoji import EmojiAction
logger = get_logger("core_actions")
from maibot_sdk import MaiBotPlugin, Action
from maibot_sdk.types import ActivationType
@register_plugin
class CoreActionsPlugin(BasePlugin):
"""核心动作插件
class EmojiPlugin(MaiBotPlugin):
"""表情包插件"""
系统内置插件,提供基础的聊天交互功能:
- Reply: 回复动作
- NoReply: 不回复动作
- Emoji: 表情动作
@Action(
"emoji",
description="发送表情包辅助表达情绪",
activation_type=ActivationType.RANDOM,
activation_probability=0.3,
parallel_action=True,
action_require=[
"发送表情包辅助表达情绪",
"表达情绪时可以选择使用",
"不要连续发送,如果你已经发过[表情包],就不要选择此动作",
],
associated_types=["emoji"],
)
async def handle_emoji(self, stream_id: str = "", reasoning: str = "", chat_id: str = "", **kwargs):
"""执行表情动作"""
reason = reasoning or "表达当前情绪"
注意插件基本信息优先从_manifest.json文件中读取
# 1. 随机获取30个表情包
result = await self.ctx.emoji.get_random(30)
if not result or not result.get("success"):
return False, "无法获取随机表情包"
sampled_emojis = result.get("emojis", [])
if not sampled_emojis:
return False, "无法获取随机表情包"
# 2. 按情感分组
emotion_map: dict[str, list] = {}
for emoji in sampled_emojis:
emo = emoji.get("emotion", "")
if emo not in emotion_map:
emotion_map[emo] = []
emotion_map[emo].append(emoji)
available_emotions = list(emotion_map.keys())
if not available_emotions:
# 无情感标签,随机发送
chosen = random.choice(sampled_emojis)
await self.ctx.send.emoji(chosen["base64"], stream_id)
return True, "随机发送了表情包"
# 3. 获取最近消息作为上下文
messages_text = ""
if chat_id:
recent_result = await self.ctx.message.get_recent(chat_id=chat_id, limit=5)
if recent_result and recent_result.get("success"):
readable_result = await self.ctx.call_capability(
"message.build_readable",
chat_id=chat_id,
start_time=0,
end_time=0,
limit=5,
timestamp_mode="normal_no_YMD",
truncate=False,
)
if readable_result and readable_result.get("success"):
messages_text = readable_result.get("text", "")
# 4. 构建 prompt 让 LLM 选择情感
available_emotions_str = "\n".join(available_emotions)
prompt = f"""你正在进行QQ聊天你需要根据聊天记录选出一个合适的情感标签。
请你根据以下原因和聊天记录进行选择
原因:{reason}
聊天记录:
{messages_text}
这里是可用的情感标签:
{available_emotions_str}
请直接返回最匹配的那个情感标签,不要进行任何解释或添加其他多余的文字。
"""
# 插件基本信息
plugin_name: str = "core_actions" # 内部标识符
enable_plugin: bool = True
dependencies: list[str] = [] # 插件依赖列表
python_dependencies: list[str] = [] # Python包依赖列表
config_file_name: str = "config.toml"
# 5. 调用 LLM
llm_result = await self.ctx.llm.generate(prompt=prompt, model_name="utils")
if not llm_result or not llm_result.get("success"):
chosen = random.choice(sampled_emojis)
await self.ctx.send.emoji(chosen["base64"], stream_id)
return True, "LLM调用失败随机发送了表情包"
# 配置节描述
config_section_descriptions = {
"plugin": "插件启用配置",
"components": "核心组件启用配置",
}
chosen_emotion = llm_result.get("response", "").strip().replace('"', "").replace("'", "")
# 配置Schema定义
config_schema: dict = {
"plugin": {
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
"config_version": ConfigField(type=str, default="0.6.0", description="配置文件版本"),
},
"components": {
"enable_emoji": ConfigField(type=bool, default=True, description="是否启用发送表情/图片动作"),
},
}
# 6. 根据选择的情感匹配表情包
if chosen_emotion in emotion_map:
chosen = random.choice(emotion_map[chosen_emotion])
else:
chosen = random.choice(sampled_emojis)
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
"""返回插件包含的组件列表"""
# 7. 发送
send_result = await self.ctx.send.emoji(chosen["base64"], stream_id)
if send_result and send_result.get("success"):
return True, f"成功发送表情包:[表情包:{chosen_emotion}]"
return False, "发送表情包失败"
# --- 根据配置注册组件 ---
components = []
if self.get_config("components.enable_emoji", True):
components.append((EmojiAction.get_action_info(), EmojiAction))
async def on_load(self):
# 从插件配置读取 emoji_chance 来覆盖默认概率
config_result = await self.ctx.config.get("emoji.emoji_chance")
if config_result and isinstance(config_result, dict) and config_result.get("success"):
pass # 配置已在宿主端管理
return components
def create_plugin():
return EmojiPlugin()

View File

@@ -0,0 +1,33 @@
{
"manifest_version": 1,
"name": "LPMM 知识库插件 (Knowledge Search)",
"version": "2.0.0",
"description": "从 LPMM 知识库中搜索相关信息,供 LLM 工具调用",
"author": {
"name": "MaiBot团队",
"url": "https://github.com/MaiM-with-u"
},
"license": "GPL-v3.0-or-later",
"host_application": {
"min_version": "1.0.0"
},
"homepage_url": "https://github.com/MaiM-with-u/maibot",
"repository_url": "https://github.com/MaiM-with-u/maibot",
"keywords": ["knowledge", "lpmm", "search", "tool", "built-in"],
"categories": ["Knowledge", "Tools"],
"default_locale": "zh-CN",
"plugin_info": {
"is_built_in": true,
"plugin_type": "tool_provider",
"components": [
{
"type": "tool",
"name": "lpmm_search_knowledge",
"description": "从知识库中搜索相关信息"
}
]
},
"capabilities": [
"knowledge.search"
]
}

View File

@@ -1,62 +0,0 @@
from typing import Dict, Any
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.knowledge import qa_manager
from src.plugin_system import BaseTool, ToolParamType
logger = get_logger("lpmm_get_knowledge_tool")
class SearchKnowledgeFromLPMMTool(BaseTool):
"""从LPMM知识库中搜索相关信息的工具"""
name = "lpmm_search_knowledge"
description = "从知识库中搜索相关信息,如果你需要知识,就使用这个工具"
parameters = [
("query", ToolParamType.STRING, "搜索查询关键词", True, None),
("limit", ToolParamType.INTEGER, "希望返回的相关知识条数默认5", False, None),
]
available_for_llm = global_config.lpmm_knowledge.enable
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
"""执行知识库搜索
Args:
function_args: 工具参数
Returns:
Dict: 工具执行结果
"""
try:
query: str = function_args.get("query") # type: ignore
limit = function_args.get("limit", 5)
try:
limit_value = int(limit)
except (TypeError, ValueError):
limit_value = 5
limit_value = max(1, limit_value)
# threshold = function_args.get("threshold", 0.4)
# 检查LPMM知识库是否启用
if qa_manager is None:
logger.debug("LPMM知识库已禁用跳过知识获取")
return {"type": "info", "id": query, "content": "LPMM知识库已禁用"}
# 调用知识库搜索
knowledge_info = await qa_manager.get_knowledge(query, limit=limit_value)
logger.debug(f"知识库查询结果: {knowledge_info}")
if knowledge_info:
content = f"你知道这些知识: {knowledge_info}"
else:
content = f"你不太了解有关{query}的知识"
return {"type": "lpmm_knowledge", "id": query, "content": content}
except Exception as e:
# 捕获异常并记录错误
logger.error(f"知识库搜索工具执行失败: {str(e)}")
# 在其他异常情况下,确保 id 仍然是 query (如果它被定义了)
query_id = query if "query" in locals() else "unknown_query"
return {"type": "info", "id": query_id, "content": f"lpmm知识库搜索失败炸了: {str(e)}"}

View File

@@ -0,0 +1,39 @@
"""LPMM 知识库搜索插件 — 新 SDK 版本
提供 LLM 可调用的知识库搜索工具。
"""
from maibot_sdk import MaiBotPlugin, Tool
from maibot_sdk.types import ToolParameterInfo, ToolParamType
class KnowledgePlugin(MaiBotPlugin):
"""LPMM 知识库插件"""
@Tool(
"lpmm_search_knowledge",
description="从知识库中搜索相关信息,如果你需要知识,就使用这个工具",
parameters=[
ToolParameterInfo(name="query", param_type=ToolParamType.STRING, description="搜索查询关键词", required=True),
ToolParameterInfo(name="limit", param_type=ToolParamType.INTEGER, description="希望返回的相关知识条数默认5", required=False, default=5),
],
)
async def handle_lpmm_search_knowledge(self, query: str = "", limit: int = 5, **kwargs):
"""执行知识库搜索"""
if not query:
return {"type": "info", "id": "", "content": "未提供搜索关键词"}
try:
limit_value = max(1, int(limit))
except (TypeError, ValueError):
limit_value = 5
result = await self.ctx.call_capability("knowledge.search", query=query, limit=limit_value)
if result and result.get("success"):
content = result.get("content", f"你不太了解有关{query}的知识")
return {"type": "lpmm_knowledge", "id": query, "content": content}
return {"type": "info", "id": query, "content": f"知识库搜索失败: {result}"}
def create_plugin():
return KnowledgePlugin()

View File

@@ -1,7 +1,7 @@
{
"manifest_version": 1,
"name": "插件和组件管理 (Plugin and Component Management)",
"version": "1.0.0",
"version": "2.0.0",
"description": "通过系统API管理插件和组件的生命周期包括加载、卸载、启用和禁用等操作。",
"author": {
"name": "MaiBot团队",
@@ -9,7 +9,7 @@
},
"license": "GPL-v3.0-or-later",
"host_application": {
"min_version": "0.10.1"
"min_version": "1.0.0"
},
"homepage_url": "https://github.com/MaiM-with-u/maibot",
"repository_url": "https://github.com/MaiM-with-u/maibot",
@@ -28,10 +28,22 @@
"plugin_info": {
"is_built_in": true,
"plugin_type": "plugin_management",
"capabilities": [
"component.get_all_plugins",
"component.list_loaded_plugins",
"component.list_registered_plugins",
"component.enable",
"component.disable",
"component.load_plugin",
"component.unload_plugin",
"component.reload_plugin",
"send.text",
"config.get"
],
"components": [
{
"type": "command",
"name": "plugin_management",
"name": "management",
"description": "管理插件和组件的生命周期,包括加载、卸载、启用和禁用等操作。"
}
]

View File

@@ -1,177 +1,30 @@
import asyncio
"""插件和组件管理 — 新 SDK 版本
from typing import List, Tuple, Type
from src.plugin_system import (
BasePlugin,
BaseCommand,
CommandInfo,
ConfigField,
register_plugin,
plugin_manage_api,
component_manage_api,
ComponentInfo,
ComponentType,
send_api,
)
通过 /pm 命令管理插件和组件的生命周期。
"""
from maibot_sdk import MaiBotPlugin, Command
class ManagementCommand(BaseCommand):
command_name: str = "management"
description: str = "管理命令"
command_pattern: str = r"(?P<manage_command>^/pm(\s[a-zA-Z0-9_]+)*\s*$)"
_VALID_COMPONENT_TYPES = ("action", "command", "event_handler")
async def execute(self) -> Tuple[bool, str, bool]:
# sourcery skip: merge-duplicate-blocks
if (
not self.message
or not self.message.message_info
or not self.message.message_info.user_info
or str(self.message.message_info.user_info.user_id) not in self.get_config("plugin.permission", []) # type: ignore
):
await self._send_message("你没有权限使用插件管理命令")
return False, "没有权限", True
if not self.message.chat_stream:
await self._send_message("无法获取聊天流信息")
return False, "无法获取聊天流信息", True
self.stream_id = self.message.chat_stream.stream_id
if not self.stream_id:
await self._send_message("无法获取聊天流信息")
return False, "无法获取聊天流信息", True
command_list = self.matched_groups["manage_command"].strip().split(" ")
if len(command_list) == 1:
await self.show_help("all")
return True, "帮助已发送", True
if len(command_list) == 2:
match command_list[1]:
case "plugin":
await self.show_help("plugin")
case "component":
await self.show_help("component")
case "help":
await self.show_help("all")
case _:
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
if len(command_list) == 3:
if command_list[1] == "plugin":
match command_list[2]:
case "help":
await self.show_help("plugin")
case "list":
await self._list_registered_plugins()
case "list_enabled":
await self._list_loaded_plugins()
case "rescan":
await self._rescan_plugin_dirs()
case _:
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
elif command_list[1] == "component":
if command_list[2] == "list":
await self._list_all_registered_components()
elif command_list[2] == "help":
await self.show_help("component")
else:
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
else:
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
if len(command_list) == 4:
if command_list[1] == "plugin":
match command_list[2]:
case "load":
await self._load_plugin(command_list[3])
case "unload":
await self._unload_plugin(command_list[3])
case "reload":
await self._reload_plugin(command_list[3])
case "add_dir":
await self._add_dir(command_list[3])
case _:
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
elif command_list[1] == "component":
if command_list[2] != "list":
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
if command_list[3] == "enabled":
await self._list_enabled_components()
elif command_list[3] == "disabled":
await self._list_disabled_components()
else:
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
else:
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
if len(command_list) == 5:
if command_list[1] != "component":
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
if command_list[2] != "list":
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
if command_list[3] == "enabled":
await self._list_enabled_components(target_type=command_list[4])
elif command_list[3] == "disabled":
await self._list_disabled_components(target_type=command_list[4])
elif command_list[3] == "type":
await self._list_registered_components_by_type(command_list[4])
else:
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
if len(command_list) == 6:
if command_list[1] != "component":
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
if command_list[2] == "enable":
if command_list[3] == "global":
await self._globally_enable_component(command_list[4], command_list[5])
elif command_list[3] == "local":
await self._locally_enable_component(command_list[4], command_list[5])
else:
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
elif command_list[2] == "disable":
if command_list[3] == "global":
await self._globally_disable_component(command_list[4], command_list[5])
elif command_list[3] == "local":
await self._locally_disable_component(command_list[4], command_list[5])
else:
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
else:
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
return True, "命令执行完成", True
async def show_help(self, target: str):
help_msg = ""
match target:
case "all":
help_msg = (
HELP_ALL = (
"管理命令帮助\n"
"/pm help 管理命令提示\n"
"/pm plugin 插件管理命令\n"
"/pm component 组件管理命令\n"
"使用 /pm plugin help 或 /pm component help 获取具体帮助"
)
case "plugin":
help_msg = (
HELP_PLUGIN = (
"插件管理命令帮助\n"
"/pm plugin help 插件管理命令提示\n"
"/pm plugin list 列出所有注册的插件\n"
"/pm plugin list_enabled 列出所有加载(启用)的插件\n"
"/pm plugin rescan 重新扫描所有目录\n"
"/pm plugin load <plugin_name> 加载指定插件\n"
"/pm plugin unload <plugin_name> 卸载指定插件\n"
"/pm plugin reload <plugin_name> 重新加载指定插件\n"
"/pm plugin add_dir <directory_path> 添加插件目录\n"
)
case "component":
help_msg = (
HELP_COMPONENT = (
"组件管理命令帮助\n"
"/pm component help 组件管理命令提示\n"
"/pm component list 列出所有注册的组件\n"
@@ -186,269 +39,241 @@ class ManagementCommand(BaseCommand):
"/pm component disable local <component_name> <component_type> 本聊天禁用组件\n"
" - <component_type> 可选项: action, command, event_handler\n"
)
class PluginManagementPlugin(MaiBotPlugin):
"""插件和组件管理插件"""
@Command(
"management",
description="管理插件和组件的生命周期",
pattern=r"(?P<manage_command>^/pm(\s[a-zA-Z0-9_]+)*\s*$)",
)
async def handle_management(
self, stream_id: str = "", user_id: str = "", matched_groups: dict | None = None, **kwargs
):
"""处理 /pm 命令"""
# 权限检查
permission_result = await self.ctx.config.get("plugin.permission")
permission_list = permission_result if isinstance(permission_result, list) else []
if str(user_id) not in permission_list:
await self.ctx.send.text("你没有权限使用插件管理命令", stream_id)
return False, "没有权限", True
if not stream_id:
return False, "无法获取聊天流信息", True
raw_command = (matched_groups or {}).get("manage_command", "").strip()
parts = raw_command.split(" ") if raw_command else ["/pm"]
n = len(parts)
# /pm
if n == 1:
await self.ctx.send.text(HELP_ALL, stream_id)
return True, "帮助已发送", True
# /pm <sub>
if n == 2:
sub = parts[1]
if sub == "plugin":
await self.ctx.send.text(HELP_PLUGIN, stream_id)
elif sub == "component":
await self.ctx.send.text(HELP_COMPONENT, stream_id)
elif sub == "help":
await self.ctx.send.text(HELP_ALL, stream_id)
else:
await self.ctx.send.text("插件管理命令不合法", stream_id)
return False, "命令不合法", True
return True, "帮助已发送", True
# /pm plugin <action> / /pm component <action>
if n == 3:
if parts[1] == "plugin":
await self._handle_plugin_3(parts[2], stream_id)
elif parts[1] == "component":
if parts[2] == "list":
await self._list_all_components(stream_id)
elif parts[2] == "help":
await self.ctx.send.text(HELP_COMPONENT, stream_id)
else:
await self.ctx.send.text("插件管理命令不合法", stream_id)
return False, "命令不合法", True
else:
await self.ctx.send.text("插件管理命令不合法", stream_id)
return False, "命令不合法", True
return True, "命令执行完成", True
if n == 4:
if parts[1] == "plugin":
await self._handle_plugin_4(parts[2], parts[3], stream_id)
elif parts[1] == "component":
if parts[2] == "list":
await self._handle_component_list_4(parts[3], stream_id)
else:
await self.ctx.send.text("插件管理命令不合法", stream_id)
return False, "命令不合法", True
else:
await self.ctx.send.text("插件管理命令不合法", stream_id)
return False, "命令不合法", True
return True, "命令执行完成", True
if n == 5:
if parts[1] != "component" or parts[2] != "list":
await self.ctx.send.text("插件管理命令不合法", stream_id)
return False, "命令不合法", True
await self._handle_component_list_5(parts[3], parts[4], stream_id)
return True, "命令执行完成", True
if n == 6:
if parts[1] != "component":
await self.ctx.send.text("插件管理命令不合法", stream_id)
return False, "命令不合法", True
await self._handle_component_toggle(parts[2], parts[3], parts[4], parts[5], stream_id)
return True, "命令执行完成", True
await self.ctx.send.text("插件管理命令不合法", stream_id)
return False, "命令不合法", True
# ------ plugin 子命令 ------
async def _handle_plugin_3(self, action: str, stream_id: str):
match action:
case "help":
await self.ctx.send.text(HELP_PLUGIN, stream_id)
case "list":
result = await self.ctx.component.list_registered_plugins()
plugins = result if isinstance(result, list) else []
await self.ctx.send.text(f"已注册的插件: {', '.join(plugins) if plugins else ''}", stream_id)
case "list_enabled":
result = await self.ctx.component.list_loaded_plugins()
plugins = result if isinstance(result, list) else []
await self.ctx.send.text(f"已加载的插件: {', '.join(plugins) if plugins else ''}", stream_id)
case _:
await self.ctx.send.text("插件管理命令不合法", stream_id)
async def _handle_plugin_4(self, action: str, name: str, stream_id: str):
match action:
case "load":
result = await self.ctx.component.load_plugin(name)
ok = result.get("success", False) if isinstance(result, dict) else bool(result)
msg = f"插件加载成功: {name}" if ok else f"插件加载失败: {name}"
await self.ctx.send.text(msg, stream_id)
case "unload":
result = await self.ctx.component.unload_plugin(name)
ok = result.get("success", False) if isinstance(result, dict) else bool(result)
msg = f"插件卸载成功: {name}" if ok else f"插件卸载失败: {name}"
await self.ctx.send.text(msg, stream_id)
case "reload":
result = await self.ctx.component.reload_plugin(name)
ok = result.get("success", False) if isinstance(result, dict) else bool(result)
msg = f"插件重新加载成功: {name}" if ok else f"插件重新加载失败: {name}"
await self.ctx.send.text(msg, stream_id)
case _:
await self.ctx.send.text("插件管理命令不合法", stream_id)
# ------ component 子命令 ------
async def _list_all_components(self, stream_id: str):
result = await self.ctx.component.get_all_plugins()
if not result:
await self.ctx.send.text("没有注册的组件", stream_id)
return
await self._send_message(help_msg)
components = self._extract_components(result)
if not components:
await self.ctx.send.text("没有注册的组件", stream_id)
return
text = ", ".join(f"{c['name']} ({c['type']})" for c in components)
await self.ctx.send.text(f"已注册的组件: {text}", stream_id)
async def _list_loaded_plugins(self):
plugins = plugin_manage_api.list_loaded_plugins()
await self._send_message(f"已加载的插件: {', '.join(plugins)}")
async def _list_registered_plugins(self):
plugins = plugin_manage_api.list_registered_plugins()
await self._send_message(f"已注册的插件: {', '.join(plugins)}")
async def _rescan_plugin_dirs(self):
plugin_manage_api.rescan_plugin_directory()
await self._send_message("插件目录重新扫描执行中")
async def _load_plugin(self, plugin_name: str):
success, count = plugin_manage_api.load_plugin(plugin_name)
if success:
await self._send_message(f"插件加载成功: {plugin_name}")
async def _handle_component_list_4(self, sub: str, stream_id: str):
if sub == "enabled":
await self._list_filtered_components("enabled", "global", stream_id)
elif sub == "disabled":
await self._list_filtered_components("disabled", "global", stream_id)
else:
if count == 0:
await self._send_message(f"插件{plugin_name}为禁用状态")
await self._send_message(f"插件加载失败: {plugin_name}")
await self.ctx.send.text("插件管理命令不合法", stream_id)
async def _unload_plugin(self, plugin_name: str):
success = await plugin_manage_api.remove_plugin(plugin_name)
if success:
await self._send_message(f"插件卸载成功: {plugin_name}")
async def _handle_component_list_5(self, sub: str, arg: str, stream_id: str):
if sub in ("enabled", "disabled"):
await self._list_filtered_components(sub, arg, stream_id)
elif sub == "type":
if arg not in _VALID_COMPONENT_TYPES:
await self.ctx.send.text(f"未知组件类型: {arg}", stream_id)
return
result = await self.ctx.component.get_all_plugins()
components = [c for c in self._extract_components(result) if c.get("type") == arg]
if not components:
await self.ctx.send.text(f"没有注册的 {arg} 组件", stream_id)
return
text = ", ".join(f"{c['name']} ({c['type']})" for c in components)
await self.ctx.send.text(f"注册的 {arg} 组件: {text}", stream_id)
else:
await self._send_message(f"插件卸载失败: {plugin_name}")
await self.ctx.send.text("插件管理命令不合法", stream_id)
async def _reload_plugin(self, plugin_name: str):
success = await plugin_manage_api.reload_plugin(plugin_name)
if success:
await self._send_message(f"插件重新加载成功: {plugin_name}")
async def _list_filtered_components(self, filter_mode: str, scope: str, stream_id: str):
result = await self.ctx.component.get_all_plugins()
all_components = self._extract_components(result)
if not all_components:
await self.ctx.send.text("没有注册的组件", stream_id)
return
if filter_mode == "enabled":
filtered = [c for c in all_components if c.get("enabled", False)]
label = "已启用"
else:
await self._send_message(f"插件重新加载失败: {plugin_name}")
filtered = [c for c in all_components if not c.get("enabled", False)]
label = "已禁用"
async def _add_dir(self, dir_path: str):
await self._send_message(f"正在添加插件目录: {dir_path}")
success = plugin_manage_api.add_plugin_directory(dir_path)
await asyncio.sleep(0.5) # 防止乱序发送
if success:
await self._send_message(f"插件目录添加成功: {dir_path}")
scope_label = "全局" if scope == "global" else "本聊天"
if not filtered:
await self.ctx.send.text(f"没有满足条件的{label}{scope_label}组件", stream_id)
return
text = ", ".join(f"{c['name']} ({c['type']})" for c in filtered)
await self.ctx.send.text(f"满足条件的{label}{scope_label}组件: {text}", stream_id)
async def _handle_component_toggle(
self, action: str, scope: str, comp_name: str, comp_type: str, stream_id: str
):
if action not in ("enable", "disable"):
await self.ctx.send.text("插件管理命令不合法", stream_id)
return
if scope not in ("global", "local"):
await self.ctx.send.text("插件管理命令不合法", stream_id)
return
if comp_type not in _VALID_COMPONENT_TYPES:
await self.ctx.send.text(f"未知组件类型: {comp_type}", stream_id)
return
if action == "enable":
result = await self.ctx.component.enable_component(
comp_name, comp_type, scope=scope, stream_id=stream_id
)
else:
await self._send_message(f"插件目录添加失败: {dir_path}")
result = await self.ctx.component.disable_component(
comp_name, comp_type, scope=scope, stream_id=stream_id
)
def _fetch_all_registered_components(self) -> List[ComponentInfo]:
all_plugin_info = component_manage_api.get_all_plugin_info()
if not all_plugin_info:
ok = result.get("success", False) if isinstance(result, dict) else bool(result)
scope_label = "全局" if scope == "global" else "本地"
action_label = "启用" if action == "enable" else "禁用"
status = "成功" if ok else "失败"
await self.ctx.send.text(f"{scope_label}{action_label}组件{status}: {comp_name}", stream_id)
# ------ helpers ------
@staticmethod
def _extract_components(result) -> list[dict]:
"""从 get_all_plugins 结果中提取所有组件列表"""
if not result:
return []
if isinstance(result, dict):
components = []
for plugin_info in result.values():
if isinstance(plugin_info, dict):
components.extend(plugin_info.get("components", []))
return components
return []
components_info: List[ComponentInfo] = []
for plugin_info in all_plugin_info.values():
components_info.extend(plugin_info.components)
return components_info
def _fetch_locally_disabled_components(self) -> List[str]:
locally_disabled_components_actions = component_manage_api.get_locally_disabled_components(
self.message.chat_stream.stream_id, ComponentType.ACTION
)
locally_disabled_components_commands = component_manage_api.get_locally_disabled_components(
self.message.chat_stream.stream_id, ComponentType.COMMAND
)
locally_disabled_components_event_handlers = component_manage_api.get_locally_disabled_components(
self.message.chat_stream.stream_id, ComponentType.EVENT_HANDLER
)
return (
locally_disabled_components_actions
+ locally_disabled_components_commands
+ locally_disabled_components_event_handlers
)
async def _list_all_registered_components(self):
components_info = self._fetch_all_registered_components()
if not components_info:
await self._send_message("没有注册的组件")
return
all_components_str = ", ".join(
f"{component.name} ({component.component_type})" for component in components_info
)
await self._send_message(f"已注册的组件: {all_components_str}")
async def _list_enabled_components(self, target_type: str = "global"):
components_info = self._fetch_all_registered_components()
if not components_info:
await self._send_message("没有注册的组件")
return
if target_type == "global":
enabled_components = [component for component in components_info if component.enabled]
if not enabled_components:
await self._send_message("没有满足条件的已启用全局组件")
return
enabled_components_str = ", ".join(
f"{component.name} ({component.component_type})" for component in enabled_components
)
await self._send_message(f"满足条件的已启用全局组件: {enabled_components_str}")
elif target_type == "local":
locally_disabled_components = self._fetch_locally_disabled_components()
enabled_components = [
component
for component in components_info
if (component.name not in locally_disabled_components and component.enabled)
]
if not enabled_components:
await self._send_message("本聊天没有满足条件的已启用组件")
return
enabled_components_str = ", ".join(
f"{component.name} ({component.component_type})" for component in enabled_components
)
await self._send_message(f"本聊天满足条件的已启用组件: {enabled_components_str}")
async def _list_disabled_components(self, target_type: str = "global"):
components_info = self._fetch_all_registered_components()
if not components_info:
await self._send_message("没有注册的组件")
return
if target_type == "global":
disabled_components = [component for component in components_info if not component.enabled]
if not disabled_components:
await self._send_message("没有满足条件的已禁用全局组件")
return
disabled_components_str = ", ".join(
f"{component.name} ({component.component_type})" for component in disabled_components
)
await self._send_message(f"满足条件的已禁用全局组件: {disabled_components_str}")
elif target_type == "local":
locally_disabled_components = self._fetch_locally_disabled_components()
disabled_components = [
component
for component in components_info
if (component.name in locally_disabled_components or not component.enabled)
]
if not disabled_components:
await self._send_message("本聊天没有满足条件的已禁用组件")
return
disabled_components_str = ", ".join(
f"{component.name} ({component.component_type})" for component in disabled_components
)
await self._send_message(f"本聊天满足条件的已禁用组件: {disabled_components_str}")
async def _list_registered_components_by_type(self, target_type: str):
match target_type:
case "action":
component_type = ComponentType.ACTION
case "command":
component_type = ComponentType.COMMAND
case "event_handler":
component_type = ComponentType.EVENT_HANDLER
case _:
await self._send_message(f"未知组件类型: {target_type}")
return
components_info = component_manage_api.get_components_info_by_type(component_type)
if not components_info:
await self._send_message(f"没有注册的 {target_type} 组件")
return
components_str = ", ".join(
f"{name} ({component.component_type})" for name, component in components_info.items()
)
await self._send_message(f"注册的 {target_type} 组件: {components_str}")
async def _globally_enable_component(self, component_name: str, component_type: str):
match component_type:
case "action":
target_component_type = ComponentType.ACTION
case "command":
target_component_type = ComponentType.COMMAND
case "event_handler":
target_component_type = ComponentType.EVENT_HANDLER
case _:
await self._send_message(f"未知组件类型: {component_type}")
return
if component_manage_api.globally_enable_component(component_name, target_component_type):
await self._send_message(f"全局启用组件成功: {component_name}")
else:
await self._send_message(f"全局启用组件失败: {component_name}")
async def _globally_disable_component(self, component_name: str, component_type: str):
match component_type:
case "action":
target_component_type = ComponentType.ACTION
case "command":
target_component_type = ComponentType.COMMAND
case "event_handler":
target_component_type = ComponentType.EVENT_HANDLER
case _:
await self._send_message(f"未知组件类型: {component_type}")
return
success = await component_manage_api.globally_disable_component(component_name, target_component_type)
if success:
await self._send_message(f"全局禁用组件成功: {component_name}")
else:
await self._send_message(f"全局禁用组件失败: {component_name}")
async def _locally_enable_component(self, component_name: str, component_type: str):
match component_type:
case "action":
target_component_type = ComponentType.ACTION
case "command":
target_component_type = ComponentType.COMMAND
case "event_handler":
target_component_type = ComponentType.EVENT_HANDLER
case _:
await self._send_message(f"未知组件类型: {component_type}")
return
if component_manage_api.locally_enable_component(
component_name,
target_component_type,
self.message.chat_stream.stream_id,
):
await self._send_message(f"本地启用组件成功: {component_name}")
else:
await self._send_message(f"本地启用组件失败: {component_name}")
async def _locally_disable_component(self, component_name: str, component_type: str):
match component_type:
case "action":
target_component_type = ComponentType.ACTION
case "command":
target_component_type = ComponentType.COMMAND
case "event_handler":
target_component_type = ComponentType.EVENT_HANDLER
case _:
await self._send_message(f"未知组件类型: {component_type}")
return
if component_manage_api.locally_disable_component(
component_name,
target_component_type,
self.message.chat_stream.stream_id,
):
await self._send_message(f"本地禁用组件成功: {component_name}")
else:
await self._send_message(f"本地禁用组件失败: {component_name}")
async def _send_message(self, message: str):
await send_api.text_to_stream(message, self.stream_id, typing=False, storage_message=False)
@register_plugin
class PluginManagementPlugin(BasePlugin):
plugin_name: str = "plugin_management_plugin"
enable_plugin: bool = False
dependencies: list[str] = []
python_dependencies: list[str] = []
config_file_name: str = "config.toml"
config_schema: dict = {
"plugin": {
"enabled": ConfigField(bool, default=False, description="是否启用插件"),
"config_version": ConfigField(type=str, default="1.1.0", description="配置文件版本"),
"permission": ConfigField(
list, default=[], description="有权限使用插件管理命令的用户列表请填写字符串形式的用户ID"
),
},
}
def get_plugin_components(self) -> List[Tuple[CommandInfo, Type[BaseCommand]]]:
components = []
if self.get_config("plugin.enabled", True):
components.append((ManagementCommand.get_command_info(), ManagementCommand))
return components
def create_plugin():
return PluginManagementPlugin()

View File

@@ -1,7 +1,7 @@
{
"manifest_version": 1,
"name": "文本转语音插件 (Text-to-Speech)",
"version": "0.1.0",
"version": "2.0.0",
"description": "将文本转换为语音进行播放的插件,支持多种语音模式和智能语音输出场景判断。",
"author": {
"name": "MaiBot团队",
@@ -10,7 +10,7 @@
"license": "GPL-v3.0-or-later",
"host_application": {
"min_version": "0.8.0"
"min_version": "1.0.0"
},
"homepage_url": "https://github.com/MaiM-with-u/maibot",
"repository_url": "https://github.com/MaiM-with-u/maibot",

View File

@@ -1,145 +1,55 @@
from src.plugin_system.apis.plugin_register_api import register_plugin
from src.plugin_system.base.base_plugin import BasePlugin
from src.plugin_system.base.component_types import ComponentInfo
from src.common.logger import get_logger
from src.plugin_system.base.base_action import BaseAction, ActionActivationType
from src.plugin_system.base.config_types import ConfigField
from typing import Tuple, List, Type
"""TTS 插件 — 新 SDK 版本
logger = get_logger("tts")
将文本转换为语音进行播放。
"""
import re
from maibot_sdk import MaiBotPlugin, Action
from maibot_sdk.types import ActivationType
class TTSAction(BaseAction):
"""TTS语音转换动作处理类"""
class TTSPlugin(MaiBotPlugin):
"""文本转语音插件"""
# 激活设置
activation_type = ActionActivationType.KEYWORD
activation_keywords = ["语音", "tts", "播报", "读出来", "语音播放", "", "朗读"]
keyword_case_sensitive = False
parallel_action = False
# 动作基本信息
action_name = "tts_action"
action_description = "将文本转换为语音进行播放,适用于需要语音输出的场景"
# 动作参数定义
action_parameters = {
"voice_text": "你想用语音表达的内容,这段内容将会以语音形式发出",
}
# 动作使用场景
@Action(
"tts_action",
description="将文本转换为语音进行播放,适用于需要语音输出的场景",
activation_type=ActivationType.KEYWORD,
activation_keywords=["语音", "tts", "播报", "读出来", "语音播放", "", "朗读"],
parallel_action=False,
action_parameters={"voice_text": "你想用语音表达的内容,这段内容将会以语音形式发出"},
action_require=[
"当需要发送语音信息时使用",
"当用户明确要求使用语音功能时使用",
"当表达内容更适合用语音而不是文字传达时使用",
"当用户想听到语音回答而非阅读文本时使用",
]
# 关联类型
associated_types = ["tts_text"]
async def execute(self) -> Tuple[bool, str]:
],
associated_types=["tts_text"],
)
async def handle_tts_action(self, stream_id: str = "", action_data: dict = None, reasoning: str = "", **kwargs):
"""处理 TTS 文本转语音动作"""
logger.info(f"{self.log_prefix} 执行TTS动作: {self.reasoning}")
# 获取要转换的文本
text = self.action_data.get("voice_text")
action_data = action_data or {}
text = action_data.get("voice_text", "")
if not text:
logger.error(f"{self.log_prefix} 执行TTS动作时未提供文本内容")
return False, "执行TTS动作失败未提供文本内容"
# 确保文本适合TTS使用
processed_text = self._process_text_for_tts(text)
try:
# 发送TTS消息
await self.send_custom(message_type="tts_text", content=processed_text)
# 记录动作信息
await self.store_action_info(
action_build_into_prompt=True, action_prompt_display="已经发送了语音消息。", action_done=True
)
logger.info(f"{self.log_prefix} TTS动作执行成功文本长度: {len(processed_text)}")
return True, "TTS动作执行成功"
except Exception as e:
logger.error(f"{self.log_prefix} 执行TTS动作时出错: {e}")
return False, f"执行TTS动作时出错: {e}"
def _process_text_for_tts(self, text: str) -> str:
"""
处理文本使其更适合TTS使用
- 移除不必要的特殊字符和表情符号
- 修正标点符号以提高语音质量
- 优化文本结构使语音更流畅
"""
# 这里可以添加文本处理逻辑
# 例如:移除多余的标点、表情符号,优化语句结构等
# 简单示例实现
processed_text = text
# 移除多余的标点符号
import re
processed_text = re.sub(r"([!?,.;:。!?,、;:])\1+", r"\1", processed_text)
# 确保句子结尾有合适的标点
# 文本预处理
processed_text = re.sub(r"([!?,.;:。!?,、;:])\1+", r"\1", text)
if not any(processed_text.endswith(end) for end in [".", "?", "!", "", "", ""]):
processed_text = f"{processed_text}"
return processed_text
# 发送自定义 tts 消息
result = await self.ctx.call_capability(
"send.custom",
message_type="tts_text",
content=processed_text,
stream_id=stream_id,
)
if result and result.get("success"):
return True, "TTS动作执行成功"
return False, f"TTS动作执行失败: {result}"
@register_plugin
class TTSPlugin(BasePlugin):
"""TTS插件
- 这是文字转语音插件
- Normal模式下依靠关键词触发
- Focus模式下由LLM判断触发
- 具有一定的文本预处理能力
"""
# 插件基本信息
plugin_name: str = "tts_plugin" # 内部标识符
enable_plugin: bool = True
dependencies: list[str] = [] # 插件依赖列表
python_dependencies: list[str] = [] # Python包依赖列表
config_file_name: str = "config.toml"
# 配置节描述
config_section_descriptions = {
"plugin": "插件基本信息配置",
"components": "组件启用控制",
"logging": "日志记录相关配置",
}
# 配置Schema定义
config_schema: dict = {
"plugin": {
"name": ConfigField(type=str, default="tts_plugin", description="插件名称", required=True),
"version": ConfigField(type=str, default="0.1.0", description="插件版本号"),
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
"description": ConfigField(type=str, default="文字转语音插件", description="插件描述", required=True),
},
"components": {"enable_tts": ConfigField(type=bool, default=True, description="是否启用TTS Action")},
"logging": {
"level": ConfigField(
type=str, default="INFO", description="日志记录级别", choices=["DEBUG", "INFO", "WARNING", "ERROR"]
),
"prefix": ConfigField(type=str, default="[TTS]", description="日志记录前缀"),
},
}
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
"""返回插件包含的组件列表"""
# 从配置获取组件启用状态
enable_tts = self.get_config("components.enable_tts", True)
components = [] # 添加Action组件
if enable_tts:
components.append((TTSAction.get_action_info(), TTSAction))
return components
def create_plugin():
return TTSPlugin()

7
src/services/__init__.py Normal file
View File

@@ -0,0 +1,7 @@
"""
核心服务层
提供与具体插件系统无关的核心业务服务。
内部模块chat、dream、memory 等)应直接使用此层,
而 plugin_system.apis 仅作为面向插件的薄包装。
"""

View File

@@ -0,0 +1,159 @@
"""
聊天服务模块
提供聊天信息查询和管理的核心功能。
"""
from typing import Any, Dict, List, Optional
from enum import Enum
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
from src.common.logger import get_logger
logger = get_logger("chat_service")
class SpecialTypes(Enum):
"""特殊枚举类型"""
ALL_PLATFORMS = "all_platforms"
class ChatManager:
"""聊天管理器 - 负责聊天信息的查询和管理"""
@staticmethod
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
# sourcery skip: for-append-to-extend
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = []
try:
for _, stream in _chat_manager.sessions.items():
if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform:
streams.append(stream)
logger.debug(f"[ChatService] 获取到 {len(streams)}{platform} 平台的聊天流")
except Exception as e:
logger.error(f"[ChatService] 获取聊天流失败: {e}")
return streams
@staticmethod
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
# sourcery skip: for-append-to-extend
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = []
try:
for _, stream in _chat_manager.sessions.items():
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.is_group_session:
streams.append(stream)
logger.debug(f"[ChatService] 获取到 {len(streams)}{platform} 平台的群聊流")
except Exception as e:
logger.error(f"[ChatService] 获取群聊流失败: {e}")
return streams
@staticmethod
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
# sourcery skip: for-append-to-extend
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = []
try:
for _, stream in _chat_manager.sessions.items():
if (
platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform
) and not stream.is_group_session:
streams.append(stream)
logger.debug(f"[ChatService] 获取到 {len(streams)}{platform} 平台的私聊流")
except Exception as e:
logger.error(f"[ChatService] 获取私聊流失败: {e}")
return streams
@staticmethod
def get_group_stream_by_group_id(
group_id: str, platform: Optional[str] | SpecialTypes = "qq"
) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast
if not isinstance(group_id, str):
raise TypeError("group_id 必须是字符串类型")
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
if not group_id:
raise ValueError("group_id 不能为空")
try:
for _, stream in _chat_manager.sessions.items():
if (
stream.is_group_session
and str(stream.group_id) == str(group_id)
and (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform)
):
logger.debug(f"[ChatService] 找到群ID {group_id} 的聊天流")
return stream
logger.warning(f"[ChatService] 未找到群ID {group_id} 的聊天流")
except Exception as e:
logger.error(f"[ChatService] 查找群聊流失败: {e}")
return None
@staticmethod
def get_private_stream_by_user_id(
user_id: str, platform: Optional[str] | SpecialTypes = "qq"
) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast
if not isinstance(user_id, str):
raise TypeError("user_id 必须是字符串类型")
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
if not user_id:
raise ValueError("user_id 不能为空")
try:
for _, stream in _chat_manager.sessions.items():
if (
not stream.is_group_session
and str(stream.user_id) == str(user_id)
and (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform)
):
logger.debug(f"[ChatService] 找到用户ID {user_id} 的私聊流")
return stream
logger.warning(f"[ChatService] 未找到用户ID {user_id} 的私聊流")
except Exception as e:
logger.error(f"[ChatService] 查找私聊流失败: {e}")
return None
@staticmethod
def get_stream_type(chat_stream: BotChatSession) -> str:
if not isinstance(chat_stream, BotChatSession):
raise TypeError("chat_stream 必须是 BotChatSession 类型")
if not chat_stream:
raise ValueError("chat_stream 不能为 None")
return "group" if chat_stream.is_group_session else "private"
@staticmethod
def get_stream_info(chat_stream: BotChatSession) -> Dict[str, Any]:
if not chat_stream:
raise ValueError("chat_stream 不能为 None")
if not isinstance(chat_stream, BotChatSession):
raise TypeError("chat_stream 必须是 BotChatSession 类型")
try:
info: Dict[str, Any] = {
"session_id": chat_stream.session_id,
"platform": chat_stream.platform,
"type": ChatManager.get_stream_type(chat_stream),
}
if chat_stream.is_group_session:
info["group_id"] = chat_stream.group_id
if chat_stream.context and chat_stream.context.message and chat_stream.context.message.message_info.group_info:
info["group_name"] = chat_stream.context.message.message_info.group_info.group_name or "未知群聊"
else:
info["group_name"] = "未知群聊"
else:
info["user_id"] = chat_stream.user_id
if chat_stream.context and chat_stream.context.message and chat_stream.context.message.message_info.user_info:
info["user_name"] = chat_stream.context.message.message_info.user_info.user_nickname
else:
info["user_name"] = "未知用户"
return info
except Exception as e:
logger.error(f"[ChatService] 获取聊天流信息失败: {e}")
return {}

View File

@@ -1,28 +1,19 @@
"""配置API模块
"""配置服务模块
提供配置读取和用户信息获取等功能
使用方式
from src.plugin_system.apis import config_api
value = config_api.get_global_config("section.key")
platform, user_id = await config_api.get_user_id_by_person_name("用户名")
提供配置读取的核心功能
"""
from typing import Any
from src.common.logger import get_logger
from src.config.config import global_config
logger = get_logger("config_api")
# =============================================================================
# 配置访问API函数
# =============================================================================
logger = get_logger("config_service")
def get_global_config(key: str, default: Any = None) -> Any:
"""
安全地从全局配置中获取一个值
插件应使用此方法读取全局配置以保证只读和隔离性
Args:
key: 命名空间式配置键名使用嵌套访问 "section.subsection.key"大小写敏感
@@ -31,7 +22,6 @@ def get_global_config(key: str, default: Any = None) -> Any:
Returns:
Any: 配置值或默认值
"""
# 支持嵌套键访问
keys = key.split(".")
current = global_config
@@ -43,7 +33,7 @@ def get_global_config(key: str, default: Any = None) -> Any:
raise KeyError(f"配置中不存在子空间或键 '{k}'")
return current
except Exception as e:
logger.warning(f"[ConfigAPI] 获取全局配置 {key} 失败: {e}")
logger.warning(f"[ConfigService] 获取全局配置 {key} 失败: {e}")
return default
@@ -59,7 +49,6 @@ def get_plugin_config(plugin_config: dict, key: str, default: Any = None) -> Any
Returns:
Any: 配置值或默认值
"""
# 支持嵌套键访问
keys = key.split(".")
current = plugin_config
@@ -73,5 +62,5 @@ def get_plugin_config(plugin_config: dict, key: str, default: Any = None) -> Any
raise KeyError(f"配置中不存在子空间或键 '{k}'")
return current
except Exception as e:
logger.warning(f"[ConfigAPI] 获取插件配置 {key} 失败: {e}")
logger.warning(f"[ConfigService] 获取插件配置 {key} 失败: {e}")
return default

View File

@@ -1,6 +1,6 @@
"""数据库API模块
"""数据库服务模块
提供数据库操作相关功能统一使用 SQLModel/SQLAlchemy 兼容接口
提供数据库操作相关的核心功能
"""
import json
@@ -10,7 +10,7 @@ from typing import Any, Optional
from src.common.logger import get_logger
logger = get_logger("database_api")
logger = get_logger("database_service")
def _to_dict(record: Any) -> dict[str, Any]:
@@ -73,7 +73,7 @@ async def db_query(
return query.count()
except Exception as e:
logger.error(f"[DatabaseAPI] 数据库操作出错: {e}")
logger.error(f"[DatabaseService] 数据库操作出错: {e}")
traceback.print_exc()
if query_type == "get":
return None if single_result else []
@@ -93,7 +93,7 @@ async def db_save(model_class, data: dict[str, Any], key_field: Optional[str] =
new_record = model_class.create(**data)
return _to_dict(new_record)
except Exception as e:
logger.error(f"[DatabaseAPI] 保存数据库记录出错: {e}")
logger.error(f"[DatabaseService] 保存数据库记录出错: {e}")
traceback.print_exc()
return None
@@ -119,7 +119,7 @@ async def db_get(
return results[0] if results else None
return results
except Exception as e:
logger.error(f"[DatabaseAPI] 获取数据库记录出错: {e}")
logger.error(f"[DatabaseService] 获取数据库记录出错: {e}")
traceback.print_exc()
return None if single_result else []
@@ -163,11 +163,11 @@ async def store_action_info(
ActionRecords, data=record_data, key_field="action_id", key_value=record_data["action_id"]
)
if saved_record:
logger.debug(f"[DatabaseAPI] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})")
logger.debug(f"[DatabaseService] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})")
else:
logger.error(f"[DatabaseAPI] 存储动作信息失败: {action_name}")
logger.error(f"[DatabaseService] 存储动作信息失败: {action_name}")
return saved_record
except Exception as e:
logger.error(f"[DatabaseAPI] 存储动作信息时发生错误: {e}")
logger.error(f"[DatabaseService] 存储动作信息时发生错误: {e}")
traceback.print_exc()
return None

View File

@@ -0,0 +1,406 @@
"""
表情服务模块
提供表情包相关的核心功能。
"""
import base64
import os
import random
import uuid
from typing import Any, Dict, List, Optional, Tuple
from src.chat.emoji_system.emoji_manager import emoji_manager, EMOJI_DIR
from src.common.logger import get_logger
from src.common.utils.utils_image import ImageUtils
from src.config.config import global_config
logger = get_logger("emoji_service")
# =============================================================================
# 表情包获取函数
# =============================================================================
async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]:
"""根据描述选择表情包"""
if not description:
raise ValueError("描述不能为空")
if not isinstance(description, str):
raise TypeError("描述必须是字符串类型")
try:
logger.debug(f"[EmojiService] 根据描述获取表情包: {description}")
emoji_obj = await emoji_manager.get_emoji_for_emotion(description)
if not emoji_obj:
logger.warning(f"[EmojiService] 未找到匹配描述 '{description}' 的表情包")
return None
emoji_path = str(emoji_obj.full_path)
emoji_description = emoji_obj.description
matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else ""
emoji_base64 = ImageUtils.image_path_to_base64(emoji_path)
if not emoji_base64:
logger.error(f"[EmojiService] 无法将表情包文件转换为base64: {emoji_path}")
return None
logger.debug(f"[EmojiService] 成功获取表情包: {emoji_description}, 匹配情感: {matched_emotion}")
return emoji_base64, emoji_description, matched_emotion
except Exception as e:
logger.error(f"[EmojiService] 获取表情包失败: {e}")
return None
async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
"""随机获取指定数量的表情包"""
if not isinstance(count, int):
raise TypeError("count 必须是整数类型")
if count < 0:
raise ValueError("count 不能为负数")
if count == 0:
logger.warning("[EmojiService] count 为0返回空列表")
return []
try:
all_emojis = emoji_manager.emojis
if not all_emojis:
logger.warning("[EmojiService] 没有可用的表情包")
return []
valid_emojis = [emoji for emoji in all_emojis if not emoji.is_deleted]
if not valid_emojis:
logger.warning("[EmojiService] 没有有效的表情包")
return []
if len(valid_emojis) < count:
logger.debug(
f"[EmojiService] 有效表情包数量 ({len(valid_emojis)}) 少于请求的数量 ({count}),将返回所有有效表情包"
)
count = len(valid_emojis)
selected_emojis = random.sample(valid_emojis, count)
results = []
for selected_emoji in selected_emojis:
emoji_base64 = ImageUtils.image_path_to_base64(str(selected_emoji.full_path))
if not emoji_base64:
logger.error(f"[EmojiService] 无法转换表情包为base64: {selected_emoji.full_path}")
continue
matched_emotion = random.choice(selected_emoji.emotion) if selected_emoji.emotion else "随机表情"
emoji_manager.update_emoji_usage(selected_emoji)
results.append((emoji_base64, selected_emoji.description, matched_emotion))
if not results and count > 0:
logger.warning("[EmojiService] 随机获取表情包失败,没有一个可以成功处理")
return []
logger.debug(f"[EmojiService] 成功获取 {len(results)} 个随机表情包")
return results
except Exception as e:
logger.error(f"[EmojiService] 获取随机表情包失败: {e}")
return []
async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
"""根据情感标签获取表情包"""
if not emotion:
raise ValueError("情感标签不能为空")
if not isinstance(emotion, str):
raise TypeError("情感标签必须是字符串类型")
try:
logger.info(f"[EmojiService] 根据情感获取表情包: {emotion}")
all_emojis = emoji_manager.emojis
matching_emojis = []
matching_emojis.extend(
emoji_obj
for emoji_obj in all_emojis
if not emoji_obj.is_deleted and emotion.lower() in [e.lower() for e in emoji_obj.emotion]
)
if not matching_emojis:
logger.warning(f"[EmojiService] 未找到匹配情感 '{emotion}' 的表情包")
return None
selected_emoji = random.choice(matching_emojis)
emoji_base64 = ImageUtils.image_path_to_base64(selected_emoji.full_path)
if not emoji_base64:
logger.error(f"[EmojiService] 无法转换表情包为base64: {selected_emoji.full_path}")
return None
emoji_manager.update_emoji_usage(selected_emoji)
logger.info(f"[EmojiService] 成功获取情感表情包: {selected_emoji.description}")
return emoji_base64, selected_emoji.description, emotion
except Exception as e:
logger.error(f"[EmojiService] 根据情感获取表情包失败: {e}")
return None
# =============================================================================
# 表情包信息查询函数
# =============================================================================
def get_count() -> int:
try:
return len(emoji_manager.emojis)
except Exception as e:
logger.error(f"[EmojiService] 获取表情包数量失败: {e}")
return 0
def get_info():
try:
return {
"current_count": len(emoji_manager.emojis),
"max_count": global_config.emoji.max_reg_num,
"available_emojis": len([e for e in emoji_manager.emojis if not e.is_deleted]),
}
except Exception as e:
logger.error(f"[EmojiService] 获取表情包信息失败: {e}")
return {"current_count": 0, "max_count": 0, "available_emojis": 0}
def get_emotions() -> List[str]:
try:
emotions = set()
for emoji_obj in emoji_manager.emojis:
if not emoji_obj.is_deleted and emoji_obj.emotion:
emotions.update(emoji_obj.emotion)
return sorted(list(emotions))
except Exception as e:
logger.error(f"[EmojiService] 获取情感标签失败: {e}")
return []
async def get_all() -> List[Tuple[str, str, str]]:
try:
all_emojis = emoji_manager.emojis
if not all_emojis:
logger.warning("[EmojiService] 没有可用的表情包")
return []
results = []
for emoji_obj in all_emojis:
if emoji_obj.is_deleted:
continue
emoji_base64 = ImageUtils.image_path_to_base64(str(emoji_obj.full_path))
if not emoji_base64:
logger.error(f"[EmojiService] 无法转换表情包为base64: {emoji_obj.full_path}")
continue
matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else "随机表情"
results.append((emoji_base64, emoji_obj.description, matched_emotion))
logger.debug(f"[EmojiService] 成功获取 {len(results)} 个表情包")
return results
except Exception as e:
logger.error(f"[EmojiService] 获取所有表情包失败: {e}")
return []
def get_descriptions() -> List[str]:
try:
descriptions = []
descriptions.extend(
emoji_obj.description
for emoji_obj in emoji_manager.emojis
if not emoji_obj.is_deleted and emoji_obj.description
)
return descriptions
except Exception as e:
logger.error(f"[EmojiService] 获取表情包描述失败: {e}")
return []
# =============================================================================
# 表情包注册函数
# =============================================================================
async def register_emoji(image_base64: str, filename: Optional[str] = None) -> Dict[str, Any]:
"""注册新的表情包"""
if not image_base64:
raise ValueError("图片base64编码不能为空")
if not isinstance(image_base64, str):
raise TypeError("image_base64必须是字符串类型")
if filename is not None and not isinstance(filename, str):
raise TypeError("filename必须是字符串类型或None")
try:
logger.info(f"[EmojiService] 开始注册表情包,文件名: {filename or '自动生成'}")
count_before = len(emoji_manager.emojis)
max_count = global_config.emoji.max_reg_num
can_register = count_before < max_count or (count_before >= max_count and global_config.emoji.do_replace)
if not can_register:
return {
"success": False,
"message": f"表情包数量已达上限({count_before}/{max_count})且未启用替换功能",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}
os.makedirs(EMOJI_DIR, exist_ok=True)
if not filename:
import time as _time
timestamp = int(_time.time())
microseconds = int(_time.time() * 1000000) % 1000000
random_bytes = random.getrandbits(72).to_bytes(9, "big")
short_id = base64.b64encode(random_bytes).decode("ascii")[:12].rstrip("=")
short_id = short_id.replace("/", "_").replace("+", "-")
filename = f"emoji_{timestamp}_{microseconds}_{short_id}"
if not filename.lower().endswith((".jpg", ".jpeg", ".png", ".gif")):
filename = f"{filename}.png"
temp_file_path = os.path.join(EMOJI_DIR, filename)
attempts = 0
max_attempts = 10
while os.path.exists(temp_file_path) and attempts < max_attempts:
random_bytes = random.getrandbits(48).to_bytes(6, "big")
short_id = base64.b64encode(random_bytes).decode("ascii")[:8].rstrip("=")
short_id = short_id.replace("/", "_").replace("+", "-")
name_part, ext = os.path.splitext(filename)
base_name = name_part.rsplit("_", 1)[0]
filename = f"{base_name}_{short_id}{ext}"
temp_file_path = os.path.join(EMOJI_DIR, filename)
attempts += 1
if os.path.exists(temp_file_path):
uuid_short = str(uuid.uuid4())[:8]
name_part, ext = os.path.splitext(filename)
base_name = name_part.rsplit("_", 1)[0]
filename = f"{base_name}_{uuid_short}{ext}"
temp_file_path = os.path.join(EMOJI_DIR, filename)
counter = 1
original_filename = filename
while os.path.exists(temp_file_path):
name_part, ext = os.path.splitext(original_filename)
filename = f"{name_part}_{counter}{ext}"
temp_file_path = os.path.join(EMOJI_DIR, filename)
counter += 1
if counter > 100:
logger.error(f"[EmojiService] 无法生成唯一文件名,尝试次数过多: {original_filename}")
return {
"success": False,
"message": "无法生成唯一文件名,请稍后重试",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}
try:
if not ImageUtils.base64_to_image(image_base64, temp_file_path):
logger.error(f"[EmojiService] 无法保存base64图片到文件: {temp_file_path}")
return {
"success": False,
"message": "无法保存图片文件",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}
logger.debug(f"[EmojiService] 图片已保存到临时文件: {temp_file_path}")
except Exception as save_error:
logger.error(f"[EmojiService] 保存图片文件失败: {save_error}")
return {
"success": False,
"message": f"保存图片文件失败: {str(save_error)}",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}
register_success = await emoji_manager.register_emoji_by_filename(filename)
if not register_success and os.path.exists(temp_file_path):
try:
os.remove(temp_file_path)
logger.debug(f"[EmojiService] 已清理临时文件: {temp_file_path}")
except Exception as cleanup_error:
logger.warning(f"[EmojiService] 清理临时文件失败: {cleanup_error}")
if register_success:
count_after = len(emoji_manager.emojis)
replaced = count_after <= count_before
new_emoji_info = None
if count_after > count_before or replaced:
try:
for emoji_obj in reversed(emoji_manager.emojis):
if not emoji_obj.is_deleted and (
emoji_obj.file_name == filename
or (hasattr(emoji_obj, "full_path") and filename in str(emoji_obj.full_path))
):
new_emoji_info = emoji_obj
break
except Exception as find_error:
logger.warning(f"[EmojiService] 查找新注册表情包信息失败: {find_error}")
description = new_emoji_info.description if new_emoji_info else None
emotions = new_emoji_info.emotion if new_emoji_info else None
emoji_hash = new_emoji_info.emoji_hash if new_emoji_info else None
return {
"success": True,
"message": f"表情包注册成功 {'(替换旧表情包)' if replaced else '(新增表情包)'}",
"description": description,
"emotions": emotions,
"replaced": replaced,
"hash": emoji_hash,
}
else:
return {
"success": False,
"message": "表情包注册失败,可能因为重复、格式不支持或审核未通过",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}
except Exception as e:
logger.error(f"[EmojiService] 注册表情包时发生异常: {e}")
return {
"success": False,
"message": f"注册过程中发生错误: {str(e)}",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}

View File

@@ -1,9 +1,11 @@
from src.common.logger import get_logger
"""频率控制服务模块
提供聊天频率控制的核心功能
"""
from src.chat.heart_flow.frequency_control import frequency_control_manager
from src.config.config import global_config
logger = get_logger("frequency_api")
def get_current_talk_value(chat_id: str) -> float:
return frequency_control_manager.get_or_create_frequency_control(

View File

@@ -1,39 +1,37 @@
"""
回复器API模块
回复器服务模块
提供回复器相关功能采用标准Python包设计模式
使用方式
from src.plugin_system.apis import generator_api
replyer = generator_api.get_replyer(chat_stream)
success, reply_set, _ = await generator_api.generate_reply(chat_stream, action_data, reasoning)
提供回复器相关的核心功能
"""
import traceback
import time
from typing import Tuple, Any, Dict, List, Optional, TYPE_CHECKING
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from rich.traceback import install
from src.common.logger import get_logger
from src.common.data_models.message_data_model import ReplySetModel
from src.chat.logger.plan_reply_logger import PlanReplyLogger
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.replyer.group_generator import DefaultReplyer
from src.chat.replyer.private_generator import PrivateReplyer
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.utils.utils import process_llm_response
from src.chat.replyer.replyer_manager import replyer_manager
from src.plugin_system.base.component_types import ActionInfo
from src.chat.logger.plan_reply_logger import PlanReplyLogger
from src.chat.utils.utils import process_llm_response
from src.common.data_models.message_data_model import ReplySetModel
from src.common.logger import get_logger
from src.core.types import ActionInfo
if TYPE_CHECKING:
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.data_models.llm_data_model import LLMGenerationDataModel
install(extra_lines=3)
logger = get_logger("generator_api")
logger = get_logger("generator_service")
# =============================================================================
# 回复器获取API函数
# 回复器获取函数
# =============================================================================
@@ -42,39 +40,24 @@ def get_replyer(
chat_id: Optional[str] = None,
request_type: str = "replyer",
) -> Optional[DefaultReplyer | PrivateReplyer]:
"""获取回复器对象
优先使用chat_stream如果没有则使用chat_id直接查找
使用 ReplyerManager 来管理实例避免重复创建
Args:
chat_stream: 聊天流对象优先
chat_id: 聊天ID实际上就是stream_id
request_type: 请求类型
Returns:
Optional[DefaultReplyer]: 回复器对象如果获取失败则返回None
Raises:
ValueError: chat_stream chat_id 均为空
"""
"""获取回复器对象"""
if not chat_id and not chat_stream:
raise ValueError("chat_stream 和 chat_id 不可均为空")
try:
logger.debug(f"[GeneratorAPI] 正在获取回复器chat_id: {chat_id}, chat_stream: {'' if chat_stream else ''}")
logger.debug(f"[GeneratorService] 正在获取回复器chat_id: {chat_id}, chat_stream: {'' if chat_stream else ''}")
return replyer_manager.get_replyer(
chat_stream=chat_stream,
chat_id=chat_id,
request_type=request_type,
)
except Exception as e:
logger.error(f"[GeneratorAPI] 获取回复器时发生意外错误: {e}", exc_info=True)
logger.error(f"[GeneratorService] 获取回复器时发生意外错误: {e}", exc_info=True)
traceback.print_exc()
return None
# =============================================================================
# 回复生成API函数
# 回复生成函数
# =============================================================================
@@ -96,39 +79,15 @@ async def generate_reply(
from_plugin: bool = True,
reply_time_point: Optional[float] = None,
) -> Tuple[bool, Optional["LLMGenerationDataModel"]]:
"""生成回复
Args:
chat_stream: 聊天流对象优先
chat_id: 聊天ID备用
action_data: 动作数据向下兼容包含reply_to和extra_info
reply_message: 回复的消息对象
extra_info: 额外信息用于补充上下文
reply_reason: 回复原因
available_actions: 可用动作
chosen_actions: 已选动作
unknown_words: Planner reply 动作中给出的未知词语列表用于黑话检索
enable_tool: 是否启用工具调用
enable_splitter: 是否启用消息分割器
enable_chinese_typo: 是否启用错字生成器
return_prompt: 是否返回提示词
model_set_with_weight: 模型配置列表每个元素为 (TaskConfig, weight) 元组
request_type: 请求类型可选记录LLM使用
from_plugin: 是否来自插件
reply_time_point: 回复时间点
Returns:
Tuple[bool, List[Tuple[str, Any]], Optional[str]]: (是否成功, 回复集合, 提示词)
"""
"""生成回复"""
try:
# 如果 reply_time_point 未传入,设置为当前时间戳
if reply_time_point is None:
reply_time_point = time.time()
# 获取回复器
logger.debug("[GeneratorAPI] 开始生成回复")
logger.debug("[GeneratorService] 开始生成回复")
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器")
logger.error("[GeneratorService] 无法获取回复器")
return False, None
if action_data:
@@ -136,11 +95,9 @@ async def generate_reply(
extra_info = action_data.get("extra_info", "")
if not reply_reason:
reply_reason = action_data.get("reason", "")
# 仅在 reply 场景下使用的未知词语解析Planner JSON 中下发)
if unknown_words is None:
uw = action_data.get("unknown_words")
if isinstance(uw, list):
# 只保留非空字符串
cleaned: List[str] = []
for item in uw:
if isinstance(item, str):
@@ -150,7 +107,6 @@ async def generate_reply(
if cleaned:
unknown_words = cleaned
# 调用回复器生成回复
success, llm_response = await replyer.generate_reply_with_context(
extra_info=extra_info,
available_actions=available_actions,
@@ -166,7 +122,7 @@ async def generate_reply(
log_reply=False,
)
if not success:
logger.warning("[GeneratorAPI] 回复生成失败")
logger.warning("[GeneratorService] 回复生成失败")
return False, None
reply_set: Optional[ReplySetModel] = None
if content := llm_response.content:
@@ -176,9 +132,8 @@ async def generate_reply(
for text in processed_response:
reply_set.add_text_content(text)
llm_response.reply_set = reply_set
logger.debug(f"[GeneratorAPI] 回复生成成功,生成了 {len(reply_set) if reply_set else 0} 个回复项")
logger.debug(f"[GeneratorService] 回复生成成功,生成了 {len(reply_set) if reply_set else 0} 个回复项")
# 统一在这里记录最终回复日志(包含分割后的 processed_output
try:
PlanReplyLogger.log_reply(
chat_id=chat_stream.session_id if chat_stream else (chat_id or ""),
@@ -192,7 +147,7 @@ async def generate_reply(
success=True,
)
except Exception:
logger.exception("[GeneratorAPI] 记录reply日志失败")
logger.exception("[GeneratorService] 记录reply日志失败")
return success, llm_response
@@ -200,11 +155,11 @@ async def generate_reply(
raise ve
except UserWarning as uw:
logger.warning(f"[GeneratorAPI] 中断了生成: {uw}")
logger.warning(f"[GeneratorService] 中断了生成: {uw}")
return False, None
except Exception as e:
logger.error(f"[GeneratorAPI] 生成回复时出错: {e}")
logger.error(f"[GeneratorService] 生成回复时出错: {e}")
logger.error(traceback.format_exc())
return False, None
@@ -220,39 +175,20 @@ async def rewrite_reply(
reply_to: str = "",
request_type: str = "generator_api",
) -> Tuple[bool, Optional["LLMGenerationDataModel"]]:
"""重写回复
Args:
chat_stream: 聊天流对象优先
reply_data: 回复数据字典向下兼容备用当其他参数缺失时从此获取
chat_id: 聊天ID备用
enable_splitter: 是否启用消息分割器
enable_chinese_typo: 是否启用错字生成器
model_set_with_weight: 模型配置列表每个元素为 (TaskConfig, weight) 元组
raw_reply: 原始回复内容
reason: 回复原因
reply_to: 回复对象
return_prompt: 是否返回提示词
Returns:
Tuple[bool, List[Tuple[str, Any]]]: (是否成功, 回复集合)
"""
"""重写回复"""
try:
# 获取回复器
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器")
logger.error("[GeneratorService] 无法获取回复器")
return False, None
logger.info("[GeneratorAPI] 开始重写回复")
logger.info("[GeneratorService] 开始重写回复")
# 如果参数缺失从reply_data中获取
if reply_data:
raw_reply = raw_reply or reply_data.get("raw_reply", "")
reason = reason or reply_data.get("reason", "")
reply_to = reply_to or reply_data.get("reply_to", "")
# 调用回复器重写回复
success, llm_response = await replyer.rewrite_reply_with_context(
raw_reply=raw_reply,
reason=reason,
@@ -263,9 +199,9 @@ async def rewrite_reply(
reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
llm_response.reply_set = reply_set
if success:
logger.info(f"[GeneratorAPI] 重写回复成功,生成了 {len(reply_set) if reply_set else 0} 个回复项")
logger.info(f"[GeneratorService] 重写回复成功,生成了 {len(reply_set) if reply_set else 0} 个回复项")
else:
logger.warning("[GeneratorAPI] 重写回复失败")
logger.warning("[GeneratorService] 重写回复失败")
return success, llm_response
@@ -273,18 +209,12 @@ async def rewrite_reply(
raise ve
except Exception as e:
logger.error(f"[GeneratorAPI] 重写回复时出错: {e}")
logger.error(f"[GeneratorService] 重写回复时出错: {e}")
return False, None
def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> Optional[ReplySetModel]:
"""将文本处理为更拟人化的文本
Args:
content: 文本内容
enable_splitter: 是否启用消息分割器
enable_chinese_typo: 是否启用错字生成器
"""
"""将文本处理为更拟人化的文本"""
if not isinstance(content, str):
raise ValueError("content 必须是字符串类型")
try:
@@ -297,7 +227,7 @@ def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo:
return reply_set
except Exception as e:
logger.error(f"[GeneratorAPI] 处理人形文本时出错: {e}")
logger.error(f"[GeneratorService] 处理人形文本时出错: {e}")
return None
@@ -309,18 +239,18 @@ async def generate_response_custom(
) -> Optional[str]:
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器")
logger.error("[GeneratorService] 无法获取回复器")
return None
try:
logger.debug("[GeneratorAPI] 开始生成自定义回复")
logger.debug("[GeneratorService] 开始生成自定义回复")
response, _, _, _ = await replyer.llm_generate_content(prompt)
if response:
logger.debug("[GeneratorAPI] 自定义回复生成成功")
logger.debug("[GeneratorService] 自定义回复生成成功")
return response
else:
logger.warning("[GeneratorAPI] 自定义回复生成失败")
logger.warning("[GeneratorService] 自定义回复生成失败")
return None
except Exception as e:
logger.error(f"[GeneratorAPI] 生成自定义回复时出错: {e}")
logger.error(f"[GeneratorService] 生成自定义回复时出错: {e}")
return None

View File

@@ -1,26 +1,19 @@
"""LLM API模块
"""LLM 服务模块
提供了与LLM模型交互的功能
使用方式
from src.plugin_system.apis import llm_api
models = llm_api.get_available_models()
success, response, reasoning, model_name = await llm_api.generate_with_model(prompt, model_config)
提供 LLM 模型交互的核心功能
"""
from typing import Tuple, Dict, List, Any, Optional, Callable
from typing import Any, Callable, Dict, List, Optional, Tuple
from src.common.logger import get_logger
from src.llm_models.payload_content.tool_option import ToolCall
from src.llm_models.payload_content.message import Message
from src.llm_models.model_client.base_client import BaseClient
from src.llm_models.utils_model import LLMRequest
from src.config.config import config_manager
from src.config.model_configs import TaskConfig
from src.llm_models.model_client.base_client import BaseClient
from src.llm_models.payload_content.message import Message
from src.llm_models.payload_content.tool_option import ToolCall
from src.llm_models.utils_model import LLMRequest
logger = get_logger("llm_api")
# =============================================================================
# LLM模型API函数
# =============================================================================
logger = get_logger("llm_service")
def get_available_models() -> Dict[str, TaskConfig]:
@@ -30,7 +23,6 @@ def get_available_models() -> Dict[str, TaskConfig]:
Dict[str, Any]: 模型配置字典key为模型名称value为模型配置
"""
try:
# 自动获取所有属性并转换为字典形式
models = config_manager.get_model_config().model_task_config
attrs = dir(models)
rets: Dict[str, TaskConfig] = {}
@@ -41,12 +33,12 @@ def get_available_models() -> Dict[str, TaskConfig]:
if not callable(value) and isinstance(value, TaskConfig):
rets[attr] = value
except Exception as e:
logger.debug(f"[LLMAPI] 获取属性 {attr} 失败: {e}")
logger.debug(f"[LLMService] 获取属性 {attr} 失败: {e}")
continue
return rets
except Exception as e:
logger.error(f"[LLMAPI] 获取可用模型失败: {e}")
logger.error(f"[LLMService] 获取可用模型失败: {e}")
return {}
@@ -68,9 +60,7 @@ async def generate_with_model(
Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称)
"""
try:
# model_name_list = model_config.model_list
# logger.info(f"[LLMAPI] 使用模型集合 {model_name_list} 生成内容")
logger.debug(f"[LLMAPI] 完整提示词: {prompt}")
logger.debug(f"[LLMService] 完整提示词: {prompt}")
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
@@ -81,7 +71,7 @@ async def generate_with_model(
except Exception as e:
error_msg = f"生成内容时出错: {str(e)}"
logger.error(f"[LLMAPI] {error_msg}")
logger.error(f"[LLMService] {error_msg}")
return False, error_msg, "", ""
@@ -104,7 +94,7 @@ async def generate_with_model_with_tools(
max_tokens: 最大token数
Returns:
Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称)
Tuple[bool, str, str, str, List[ToolCall] | None]: (是否成功, 生成的内容, 推理过程, 模型名称, 工具调用列表)
"""
try:
model_name_list = model_config.model_list
@@ -120,7 +110,7 @@ async def generate_with_model_with_tools(
except Exception as e:
error_msg = f"生成内容时出错: {str(e)}"
logger.error(f"[LLMAPI] {error_msg}")
logger.error(f"[LLMService] {error_msg}")
return False, error_msg, "", "", None
@@ -161,5 +151,5 @@ async def generate_with_model_with_tools_by_message_factory(
except Exception as e:
error_msg = f"生成内容时出错: {str(e)}"
logger.error(f"[LLMAPI] {error_msg}")
logger.error(f"[LLMService] {error_msg}")
return False, error_msg, "", "", None

View File

@@ -1,62 +1,44 @@
"""
消息API模块
消息服务模块
提供消息查询和构建成字符串的功能采用标准Python包设计模式
使用方式
from src.plugin_system.apis import message_api
messages = message_api.get_messages_by_time_in_chat(chat_id, start_time, end_time)
readable_text = message_api.build_readable_messages(messages)
提供消息查询和构建成字符串的核心功能
"""
import time
from typing import List, Dict, Any, Tuple, Optional
from typing import Any, Dict, List, Optional, Tuple
from sqlmodel import col, select
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.database.database import get_db_session
from src.common.database.database_model import Images, ImageType
from src.chat.utils.utils import is_bot_self
from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp,
get_raw_msg_by_timestamp_with_chat,
get_raw_msg_by_timestamp_with_chat_inclusive,
get_raw_msg_by_timestamp_with_chat_users,
get_raw_msg_by_timestamp_random,
get_raw_msg_by_timestamp_with_users,
get_raw_msg_before_timestamp,
get_raw_msg_before_timestamp_with_chat,
get_raw_msg_before_timestamp_with_users,
num_new_messages_since,
num_new_messages_since_with_users,
build_readable_messages,
build_readable_messages_with_list,
get_person_id_list,
get_raw_msg_before_timestamp,
get_raw_msg_before_timestamp_with_chat,
get_raw_msg_before_timestamp_with_users,
get_raw_msg_by_timestamp,
get_raw_msg_by_timestamp_random,
get_raw_msg_by_timestamp_with_chat,
get_raw_msg_by_timestamp_with_chat_inclusive,
get_raw_msg_by_timestamp_with_chat_users,
get_raw_msg_by_timestamp_with_users,
num_new_messages_since,
num_new_messages_since_with_users,
)
from src.chat.utils.utils import is_bot_self
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.database.database import get_db_session
from src.common.database.database_model import Images, ImageType
# =============================================================================
# 消息查询API函数
# 消息查询函数
# =============================================================================
def get_messages_by_time(
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
) -> List[DatabaseMessages]:
"""
获取指定时间范围内的消息
Args:
start_time: 开始时间戳
end_time: 结束时间戳
limit: 限制返回的消息数量0为不限制
limit_mode: 当limit>0时生效'earliest'表示获取最早的记录'latest'表示获取最新的记录
filter_mai: 是否过滤麦麦自身的消息默认为False
Returns:
List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
"""
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
@@ -76,23 +58,6 @@ def get_messages_by_time_in_chat(
filter_command: bool = False,
filter_intercept_message_level: Optional[int] = None,
) -> List[DatabaseMessages]:
"""
获取指定聊天中指定时间范围内的消息
Args:
chat_id: 聊天ID
start_time: 开始时间戳
end_time: 结束时间戳
limit: 限制返回的消息数量0为不限制
limit_mode: 当limit>0时生效'earliest'表示获取最早的记录'latest'表示获取最新的记录
filter_mai: 是否过滤麦麦自身的消息默认为False
filter_command: 是否过滤命令消息默认为False
Returns:
List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
"""
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
@@ -101,10 +66,6 @@ def get_messages_by_time_in_chat(
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
# if filter_mai:
# return filter_mai_messages(
# get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command)
# )
return get_raw_msg_by_timestamp_with_chat(
chat_id=chat_id,
timestamp_start=start_time,
@@ -127,23 +88,6 @@ def get_messages_by_time_in_chat_inclusive(
filter_command: bool = False,
filter_intercept_message_level: Optional[int] = None,
) -> List[DatabaseMessages]:
"""
获取指定聊天中指定时间范围内的消息包含边界
Args:
chat_id: 聊天ID
start_time: 开始时间戳包含
end_time: 结束时间戳包含
limit: 限制返回的消息数量0为不限制
limit_mode: 当limit>0时生效'earliest'表示获取最早的记录'latest'表示获取最新的记录
filter_mai: 是否过滤麦麦自身的消息默认为False
Returns:
List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
"""
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
@@ -175,23 +119,6 @@ def get_messages_by_time_in_chat_for_users(
limit: int = 0,
limit_mode: str = "latest",
) -> List[DatabaseMessages]:
"""
获取指定聊天中指定用户在指定时间范围内的消息
Args:
chat_id: 聊天ID
start_time: 开始时间戳
end_time: 结束时间戳
person_ids: 用户ID列表
limit: 限制返回的消息数量0为不限制
limit_mode: 当limit>0时生效'earliest'表示获取最早的记录'latest'表示获取最新的记录
Returns:
List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
"""
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
@@ -206,22 +133,6 @@ def get_messages_by_time_in_chat_for_users(
def get_random_chat_messages(
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
) -> List[DatabaseMessages]:
"""
随机选择一个聊天返回该聊天在指定时间范围内的消息
Args:
start_time: 开始时间戳
end_time: 结束时间戳
limit: 限制返回的消息数量0为不限制
limit_mode: 当limit>0时生效'earliest'表示获取最早的记录'latest'表示获取最新的记录
filter_mai: 是否过滤麦麦自身的消息默认为False
Returns:
List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
"""
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
@@ -234,22 +145,6 @@ def get_random_chat_messages(
def get_messages_by_time_for_users(
start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest"
) -> List[DatabaseMessages]:
"""
获取指定用户在所有聊天中指定时间范围内的消息
Args:
start_time: 开始时间戳
end_time: 结束时间戳
person_ids: 用户ID列表
limit: 限制返回的消息数量0为不限制
limit_mode: 当limit>0时生效'earliest'表示获取最早的记录'latest'表示获取最新的记录
Returns:
List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
"""
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
@@ -258,20 +153,6 @@ def get_messages_by_time_for_users(
def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[DatabaseMessages]:
"""
获取指定时间戳之前的消息
Args:
timestamp: 时间戳
limit: 限制返回的消息数量0为不限制
filter_mai: 是否过滤麦麦自身的消息默认为False
Returns:
List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
"""
if not isinstance(timestamp, (int, float)):
raise ValueError("timestamp 必须是数字类型")
if limit < 0:
@@ -288,21 +169,6 @@ def get_messages_before_time_in_chat(
filter_mai: bool = False,
filter_intercept_message_level: Optional[int] = None,
) -> List[DatabaseMessages]:
"""
获取指定聊天中指定时间戳之前的消息
Args:
chat_id: 聊天ID
timestamp: 时间戳
limit: 限制返回的消息数量0为不限制
filter_mai: 是否过滤麦麦自身的消息默认为False
Returns:
List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
"""
if not isinstance(timestamp, (int, float)):
raise ValueError("timestamp 必须是数字类型")
if limit < 0:
@@ -325,20 +191,6 @@ def get_messages_before_time_in_chat(
def get_messages_before_time_for_users(
timestamp: float, person_ids: List[str], limit: int = 0
) -> List[DatabaseMessages]:
"""
获取指定用户在指定时间戳之前的消息
Args:
timestamp: 时间戳
person_ids: 用户ID列表
limit: 限制返回的消息数量0为不限制
Returns:
List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
"""
if not isinstance(timestamp, (int, float)):
raise ValueError("timestamp 必须是数字类型")
if limit < 0:
@@ -349,22 +201,6 @@ def get_messages_before_time_for_users(
def get_recent_messages(
chat_id: str, hours: float = 24.0, limit: int = 100, limit_mode: str = "latest", filter_mai: bool = False
) -> List[DatabaseMessages]:
"""
获取指定聊天中最近一段时间的消息
Args:
chat_id: 聊天ID
hours: 最近多少小时默认24小时
limit: 限制返回的消息数量默认100条
limit_mode: 当limit>0时生效'earliest'表示获取最早的记录'latest'表示获取最新的记录
filter_mai: 是否过滤麦麦自身的消息默认为False
Returns:
List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法s
"""
if not isinstance(hours, (int, float)) or hours < 0:
raise ValueError("hours 不能是负数")
if not isinstance(limit, int) or limit < 0:
@@ -381,25 +217,11 @@ def get_recent_messages(
# =============================================================================
# 消息计数API函数
# 消息计数函数
# =============================================================================
def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional[float] = None) -> int:
"""
计算指定聊天中从开始时间到结束时间的新消息数量
Args:
chat_id: 聊天ID
start_time: 开始时间戳
end_time: 结束时间戳如果为None则使用当前时间
Returns:
int: 新消息数量
Raises:
ValueError: 如果参数不合法
"""
if not isinstance(start_time, (int, float)):
raise ValueError("start_time 必须是数字类型")
if not chat_id:
@@ -410,21 +232,6 @@ def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional
def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> int:
"""
计算指定聊天中指定用户从开始时间到结束时间的新消息数量
Args:
chat_id: 聊天ID
start_time: 开始时间戳
end_time: 结束时间戳
person_ids: 用户ID列表
Returns:
int: 新消息数量
Raises:
ValueError: 如果参数不合法
"""
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if not chat_id:
@@ -435,7 +242,7 @@ def count_new_messages_for_users(chat_id: str, start_time: float, end_time: floa
# =============================================================================
# 消息格式化API函数
# 消息格式化函数
# =============================================================================
@@ -447,21 +254,6 @@ def build_readable_messages_to_str(
truncate: bool = False,
show_actions: bool = False,
) -> str:
"""
将消息列表构建成可读的字符串
Args:
messages: 消息列表
replace_bot_name: 是否将机器人的名称替换为""
merge_messages: 是否合并连续消息
timestamp_mode: 时间戳显示模式'relative''absolute'
read_mark: 已读标记时间戳用于分割已读和未读消息
truncate: 是否截断长消息
show_actions: 是否显示动作记录
Returns:
格式化后的可读字符串
"""
return build_readable_messages(messages, replace_bot_name, timestamp_mode, read_mark, truncate, show_actions)
@@ -471,32 +263,10 @@ async def build_readable_messages_with_details(
timestamp_mode: str = "relative",
truncate: bool = False,
) -> Tuple[str, List[Tuple[float, str, str]]]:
"""
将消息列表构建成可读的字符串并返回详细信息
Args:
messages: 消息列表
replace_bot_name: 是否将机器人的名称替换为""
merge_messages: 是否合并连续消息
timestamp_mode: 时间戳显示模式'relative''absolute'
truncate: 是否截断长消息
Returns:
格式化后的可读字符串和详细信息元组列表(时间戳, 昵称, 内容)
"""
return await build_readable_messages_with_list(messages, replace_bot_name, timestamp_mode, truncate)
async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[str]:
"""
从消息列表中提取不重复的用户ID列表
Args:
messages: 消息列表
Returns:
用户ID列表
"""
return await get_person_id_list(messages)
@@ -506,14 +276,7 @@ async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[s
def filter_mai_messages(messages: List[DatabaseMessages]) -> List[DatabaseMessages]:
"""
从消息列表中移除麦麦的消息
Args:
messages: 消息列表每个元素是消息字典
Returns:
过滤后的消息列表
"""
# 使用统一的 is_bot_self 函数判断是否是机器人自己(支持多平台,包括 WebUI
"""从消息列表中移除麦麦的消息"""
return [msg for msg in messages if not is_bot_self(msg.user_info.platform, msg.user_info.user_id)]

View File

@@ -1,22 +1,14 @@
"""个人信息API模块
"""个人信息服务模块
提供个人信息查询功能用于插件获取用户相关信息
使用方式
from src.plugin_system.apis import person_api
person_id = person_api.get_person_id("qq", 123456)
value = await person_api.get_person_value(person_id, "nickname")
提供个人信息查询的核心功能
"""
from typing import Any
from src.common.logger import get_logger
from src.person_info.person_info import Person
logger = get_logger("person_api")
# =============================================================================
# 个人信息API函数
# =============================================================================
logger = get_logger("person_service")
def get_person_id(platform: str, user_id: int | str) -> str:
@@ -28,14 +20,11 @@ def get_person_id(platform: str, user_id: int | str) -> str:
Returns:
str: 唯一的person_idMD5哈希值
示例:
person_id = person_api.get_person_id("qq", 123456)
"""
try:
return Person(platform=platform, user_id=str(user_id)).person_id
except Exception as e:
logger.error(f"[PersonAPI] 获取person_id失败: platform={platform}, user_id={user_id}, error={e}")
logger.error(f"[PersonService] 获取person_id失败: platform={platform}, user_id={user_id}, error={e}")
return ""
@@ -49,17 +38,13 @@ async def get_person_value(person_id: str, field_name: str, default: Any = None)
Returns:
Any: 字段值或默认值
示例:
nickname = await person_api.get_person_value(person_id, "nickname", "未知用户")
impression = await person_api.get_person_value(person_id, "impression")
"""
try:
person = Person(person_id=person_id)
value = getattr(person, field_name)
return value if value is not None else default
except Exception as e:
logger.error(f"[PersonAPI] 获取用户信息失败: person_id={person_id}, field={field_name}, error={e}")
logger.error(f"[PersonService] 获取用户信息失败: person_id={person_id}, field={field_name}, error={e}")
return default
@@ -71,13 +56,10 @@ def get_person_id_by_name(person_name: str) -> str:
Returns:
str: person_id如果未找到返回空字符串
示例:
person_id = person_api.get_person_id_by_name("张三")
"""
try:
person = Person(person_name=person_name)
return person.person_id
except Exception as e:
logger.error(f"[PersonAPI] 根据用户名获取person_id失败: person_name={person_name}, error={e}")
logger.error(f"[PersonService] 根据用户名获取person_id失败: person_name={person_name}, error={e}")
return ""

View File

@@ -1,47 +1,32 @@
"""
发送API模块
发送服务模块
专门负责发送各种类型消息采用标准Python包设计模式
使用方式
from src.plugin_system.apis import send_api
# 方式1直接使用stream_id推荐
await send_api.text_to_stream("hello", stream_id)
await send_api.emoji_to_stream(emoji_base64, stream_id)
await send_api.custom_to_stream("video", video_data, stream_id)
# 方式2使用群聊/私聊指定函数
await send_api.text_to_group("hello", "123456")
await send_api.text_to_user("hello", "987654")
# 方式3使用通用custom_message函数
await send_api.custom_message("video", video_data, "123456", True)
提供发送各种类型消息的核心功能
"""
import traceback
import time
from typing import Optional, Union, Dict, List, TYPE_CHECKING, Tuple
from src.common.logger import get_logger
from src.common.data_models.message_data_model import ReplyContentType
from src.config.config import global_config
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
from maim_message import Seg
from maim_message import MessageBase, BaseMessageInfo, Seg
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.chat.message_receive.message import MessageSending
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
from src.common.data_models.message_data_model import ReplyContentType
from src.common.logger import get_logger
from src.config.config import global_config
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.message_data_model import ReplySetModel, ReplyContent, ForwardNode
from src.common.data_models.message_data_model import ForwardNode, ReplyContent, ReplySetModel
logger = get_logger("send_api")
logger = get_logger("send_service")
# =============================================================================
# 内部实现函数(不暴露给外部)
# 内部实现函数
# =============================================================================
@@ -56,42 +41,25 @@ async def _send_to_target(
show_log: bool = True,
selected_expressions: Optional[List[int]] = None,
) -> bool:
"""向指定目标发送消息的内部实现
Args:
message_segment:
stream_id: 目标流ID
display_message: 显示消息
typing: 是否模拟打字等待
reply_to: 回复消息格式为"发送者:消息内容"
storage_message: 是否存储消息到数据库
show_log: 发送是否显示日志
Returns:
bool: 是否发送成功
"""
"""向指定目标发送消息的内部实现"""
try:
if set_reply and not reply_message:
logger.warning("[SendAPI] 使用引用回复,但未提供回复消息")
logger.warning("[SendService] 使用引用回复,但未提供回复消息")
return False
if show_log:
logger.debug(f"[SendAPI] 发送{message_segment.type}消息到 {stream_id}")
logger.debug(f"[SendService] 发送{message_segment.type}消息到 {stream_id}")
# 查找目标聊天流
target_stream = _chat_manager.get_session_by_session_id(stream_id)
if not target_stream:
logger.error(f"[SendAPI] 未找到聊天流: {stream_id}")
logger.error(f"[SendService] 未找到聊天流: {stream_id}")
return False
# 创建发送器
message_sender = UniversalMessageSender()
# 生成消息ID
current_time = time.time()
message_id = f"send_api_{int(current_time * 1000)}"
# 构建机器人用户信息
bot_user_info = UserInfo(
user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname,
@@ -102,17 +70,15 @@ async def _send_to_target(
if reply_message:
anchor_message = db_message_to_mai_message(reply_message)
if anchor_message:
logger.debug(f"[SendAPI] 找到匹配的回复消息,发送者: {anchor_message.message_info.user_info.user_id}")
logger.debug(f"[SendService] 找到匹配的回复消息,发送者: {anchor_message.message_info.user_info.user_id}")
reply_to_platform_id = (
f"{anchor_message.platform}:{anchor_message.message_info.user_info.user_id}"
)
# 构建 sender_info私聊时为接收者信息
sender_info = None
if target_stream.context and target_stream.context.message:
sender_info = target_stream.context.message.message_info.user_info
# 构建发送消息对象
bot_message = MessageSending(
message_id=message_id,
session=target_stream,
@@ -128,7 +94,6 @@ async def _send_to_target(
selected_expressions=selected_expressions,
)
# 发送消息
sent_msg = await message_sender.send_message(
bot_message,
typing=typing,
@@ -138,28 +103,22 @@ async def _send_to_target(
)
if sent_msg:
logger.debug(f"[SendAPI] 成功发送消息到 {stream_id}")
logger.debug(f"[SendService] 成功发送消息到 {stream_id}")
return True
else:
logger.error("[SendAPI] 发送消息失败")
logger.error("[SendService] 发送消息失败")
return False
except Exception as e:
logger.error(f"[SendAPI] 发送消息时出错: {e}")
logger.error(f"[SendService] 发送消息时出错: {e}")
traceback.print_exc()
return False
def db_message_to_mai_message(message_obj: "DatabaseMessages") -> Optional[MaiMessage]:
"""将数据库消息重建为 MaiMessage 对象,用于回复引用。
Args:
message_obj: 插件系统的 DatabaseMessages 数据对象
Returns:
Optional[MaiMessage]: 构建的消息对象如果信息不足则返回 None
"""
"""将数据库消息重建为 MaiMessage 对象,用于回复引用。"""
from datetime import datetime
from src.common.data_models.mai_message_data_model import GroupInfo, MessageInfo
from src.common.data_models.message_component_data_model import MessageSequence
@@ -190,7 +149,7 @@ def db_message_to_mai_message(message_obj: "DatabaseMessages") -> Optional[MaiMe
# =============================================================================
# 公共API函数 - 预定义类型的发送函数
# 公共函数 - 预定义类型的发送函数
# =============================================================================
@@ -203,18 +162,7 @@ async def text_to_stream(
storage_message: bool = True,
selected_expressions: Optional[List[int]] = None,
) -> bool:
"""向指定流发送文本消息
Args:
text: 要发送的文本内容
stream_id: 聊天流ID
typing: 是否显示正在输入
reply_to: 回复消息格式为"发送者:消息内容"
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
"""向指定流发送文本消息"""
return await _send_to_target(
message_segment=Seg(type="text", data=text),
stream_id=stream_id,
@@ -234,16 +182,7 @@ async def emoji_to_stream(
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
) -> bool:
"""向指定流发送表情包
Args:
emoji_base64: 表情包的base64编码
stream_id: 聊天流ID
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
"""向指定流发送表情包"""
return await _send_to_target(
message_segment=Seg(type="emoji", data=emoji_base64),
stream_id=stream_id,
@@ -262,16 +201,7 @@ async def image_to_stream(
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
) -> bool:
"""向指定流发送图片
Args:
image_base64: 图片的base64编码
stream_id: 聊天流ID
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
"""向指定流发送图片"""
return await _send_to_target(
message_segment=Seg(type="image", data=image_base64),
stream_id=stream_id,
@@ -289,17 +219,7 @@ async def command_to_stream(
storage_message: bool = True,
display_message: str = "",
) -> bool:
"""向指定流发送命令
Args:
command: 命令
stream_id: 聊天流ID
storage_message: 是否存储消息到数据库
display_message: 显示消息
Returns:
bool: 是否发送成功
"""
"""向指定流发送命令"""
return await _send_to_target(
message_segment=Seg(type="command", data=command), # type: ignore
stream_id=stream_id,
@@ -321,20 +241,7 @@ async def custom_to_stream(
storage_message: bool = True,
show_log: bool = True,
) -> bool:
"""向指定流发送自定义类型消息
Args:
message_type: 消息类型"text""image""emoji""video""file"
content: 消息内容通常是base64编码或文本
stream_id: 聊天流ID
display_message: 显示消息
typing: 是否显示正在输入
reply_to: 回复消息格式为"发送者:消息内容"
storage_message: 是否存储消息到数据库
show_log: 是否显示日志
Returns:
bool: 是否发送成功
"""
"""向指定流发送自定义类型消息"""
return await _send_to_target(
message_segment=Seg(type=message_type, data=content), # type: ignore
stream_id=stream_id,
@@ -350,25 +257,14 @@ async def custom_to_stream(
async def custom_reply_set_to_stream(
reply_set: "ReplySetModel",
stream_id: str,
display_message: str = "", # 基本没用
display_message: str = "",
typing: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
set_reply: bool = False,
storage_message: bool = True,
show_log: bool = True,
) -> bool:
"""
向指定流发送混合型消息集
Args:
reply_set: ReplySetModel 对象包含多个 ReplyContent
stream_id: 聊天流ID
display_message: 显示消息
typing: 是否显示正在输入
reply_to: 回复消息格式为"发送者:消息内容"
storage_message: 是否存储消息到数据库
show_log: 是否显示日志
"""
"""向指定流发送混合型消息集"""
flag: bool = True
for reply_content in reply_set.reply_data:
status: bool = False
@@ -386,20 +282,14 @@ async def custom_reply_set_to_stream(
if not status:
flag = False
logger.error(
f"[SendAPI] 发送{repr(reply_content.content_type)}消息失败,消息内容:{str(reply_content.content)[:100]}"
f"[SendService] 发送{repr(reply_content.content_type)}消息失败,消息内容:{str(reply_content.content)[:100]}"
)
return flag
def _parse_content_to_seg(reply_content: "ReplyContent") -> Tuple[Seg, bool]:
"""
ReplyContent 转换为 Seg 结构 (Forward 中仅递归一次)
Args:
reply_content: ReplyContent 对象
Returns:
Tuple[Seg, bool]: 转换后的 Seg 结构和是否需要typing的标志
"""
"""把 ReplyContent 转换为 Seg 结构"""
content_type = reply_content.content_type
if content_type == ReplyContentType.TEXT:
text_data: str = reply_content.content # type: ignore
@@ -427,7 +317,7 @@ def _parse_content_to_seg(reply_content: "ReplyContent") -> Tuple[Seg, bool]:
elif sub_content_type == ReplyContentType.EMOJI:
sub_seg_list.append(Seg(type="emoji", data=sub_content_data)) # type: ignore
else:
logger.warning(f"[SendAPI] 混合类型中不支持的子内容类型: {repr(sub_content_type)}")
logger.warning(f"[SendService] 混合类型中不支持的子内容类型: {repr(sub_content_type)}")
continue
return Seg(type="seglist", data=sub_seg_list), True
elif content_type == ReplyContentType.FORWARD:

View File

@@ -6,7 +6,7 @@ import json
from src.common.logger import get_logger
from src.common.toml_utils import save_toml_with_format
from src.config.config import MMC_VERSION
from src.plugin_system.base.config_types import ConfigField
from src.core.config_types import ConfigField
from src.webui.services.git_mirror_service import get_git_mirror_service, set_update_progress_callback
from src.webui.core import get_token_manager
from src.webui.routers.websocket.plugin_progress import update_progress
@@ -222,15 +222,16 @@ def coerce_types(schema_part: Dict[str, Any], config_part: Dict[str, Any]) -> No
def find_plugin_instance(plugin_id: str) -> Optional[Any]:
"""
按 plugin_id 或 plugin_name 查找已加载的插件实例
局部导入 plugin_manager 以规避循环依赖
按 plugin_id 查找已加载的插件信息
新运行时中插件运行在子进程,无法获取实例,返回注册信息
"""
from src.plugin_system.core.plugin_manager import plugin_manager
from src.plugin_runtime.integration import get_plugin_runtime_manager
for loaded_plugin_name in plugin_manager.list_loaded_plugins():
instance = plugin_manager.get_plugin_instance(loaded_plugin_name)
if instance and (instance.plugin_name == plugin_id or instance.get_manifest_info("id", "") == plugin_id):
return instance
mgr = get_plugin_runtime_manager()
for sv in mgr.supervisors:
reg = sv._registered_plugins.get(plugin_id)
if reg is not None:
return reg
return None
@@ -1497,26 +1498,10 @@ async def get_plugin_config_schema(
logger.info(f"获取插件配置 Schema: {plugin_id}")
try:
# 尝试从已加载的插件中获取
from src.plugin_system.core.plugin_manager import plugin_manager
# 查找插件实例
# 新运行时中插件运行在子进程,无法直接获取实例的 webui_config_schema
# 尝试从文件系统读取
plugin_instance = None
# 遍历所有已加载的插件
for loaded_plugin_name in plugin_manager.list_loaded_plugins():
instance = plugin_manager.get_plugin_instance(loaded_plugin_name)
if instance:
# 匹配 plugin_name 或 manifest 中的 id
if instance.plugin_name == plugin_id:
plugin_instance = instance
break
# 也尝试匹配 manifest 中的 id
manifest_id = instance.get_manifest_info("id", "")
if manifest_id == plugin_id:
plugin_instance = instance
break
if plugin_instance and hasattr(plugin_instance, "get_webui_config_schema"):
# 从插件实例获取 schema
schema = plugin_instance.get_webui_config_schema()