重构整个插件系统,尝试恢复可启动性,新增插件系统maibot-plugin-sdk依赖
This commit is contained in:
6
bot.py
6
bot.py
@@ -195,11 +195,11 @@ async def graceful_shutdown(): # sourcery skip: use-named-expression
|
|||||||
# except Exception as e:
|
# except Exception as e:
|
||||||
# logger.warning(f"关闭 WebUI 服务器时出错: {e}")
|
# logger.warning(f"关闭 WebUI 服务器时出错: {e}")
|
||||||
|
|
||||||
from src.plugin_system.core.events_manager import events_manager
|
from src.core.event_bus import event_bus
|
||||||
from src.plugin_system.base.component_types import EventType
|
from src.core.types import EventType
|
||||||
|
|
||||||
# 触发 ON_STOP 事件
|
# 触发 ON_STOP 事件
|
||||||
await events_manager.handle_mai_events(event_type=EventType.ON_STOP)
|
await event_bus.emit(event_type=EventType.ON_STOP)
|
||||||
|
|
||||||
# 停止新版本插件运行时
|
# 停止新版本插件运行时
|
||||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||||
|
|||||||
@@ -9,7 +9,7 @@
|
|||||||
},
|
},
|
||||||
"license": "GPL-v3.0-or-later",
|
"license": "GPL-v3.0-or-later",
|
||||||
"host_application": {
|
"host_application": {
|
||||||
"min_version": "0.10.3"
|
"min_version": "1.0.0"
|
||||||
},
|
},
|
||||||
"homepage_url": "https://github.com/SengokuCola/BetterFrequency",
|
"homepage_url": "https://github.com/SengokuCola/BetterFrequency",
|
||||||
"repository_url": "https://github.com/SengokuCola/BetterFrequency",
|
"repository_url": "https://github.com/SengokuCola/BetterFrequency",
|
||||||
|
|||||||
@@ -1,144 +1,87 @@
|
|||||||
from typing import List, Tuple, Type, Optional
|
"""发言频率控制插件 — 新 SDK 版本
|
||||||
from src.plugin_system import BasePlugin, register_plugin, BaseCommand, ComponentInfo, ConfigField
|
|
||||||
from src.plugin_system.apis import send_api, frequency_api
|
通过 /chat 命令设置和查看聊天频率。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from maibot_sdk import MaiBotPlugin, Command
|
||||||
|
|
||||||
|
|
||||||
class SetTalkFrequencyCommand(BaseCommand):
|
class BetterFrequencyPlugin(MaiBotPlugin):
|
||||||
"""设置当前聊天的talk_frequency值"""
|
"""聊天频率控制插件"""
|
||||||
|
|
||||||
command_name = "set_talk_frequency"
|
@Command(
|
||||||
command_description = "设置当前聊天的talk_frequency值:/chat talk_frequency <数字> 或 /chat t <数字>"
|
"set_talk_frequency",
|
||||||
command_pattern = r"^/chat\s+(?:talk_frequency|t)\s+(?P<value>[+-]?\d*\.?\d+)$"
|
description="设置当前聊天的talk_frequency值:/chat talk_frequency <数字> 或 /chat t <数字>",
|
||||||
|
pattern=r"^/chat\s+(?:talk_frequency|t)\s+(?P<value>[+-]?\d*\.?\d+)$",
|
||||||
|
)
|
||||||
|
async def handle_set_talk_frequency(
|
||||||
|
self, stream_id: str = "", matched_groups: dict | None = None, **kwargs
|
||||||
|
):
|
||||||
|
"""设置当前聊天的 talk_frequency"""
|
||||||
|
if not matched_groups or "value" not in matched_groups:
|
||||||
|
return False, "命令格式错误", False
|
||||||
|
|
||||||
|
value_str = matched_groups["value"]
|
||||||
|
if not value_str:
|
||||||
|
return False, "无法获取数值参数", False
|
||||||
|
|
||||||
async def execute(self) -> Tuple[bool, Optional[str], bool]:
|
|
||||||
try:
|
try:
|
||||||
# 获取命令参数 - 使用命名捕获组
|
|
||||||
if not self.matched_groups or "value" not in self.matched_groups:
|
|
||||||
return False, "命令格式错误", False
|
|
||||||
|
|
||||||
value_str = self.matched_groups["value"]
|
|
||||||
if not value_str:
|
|
||||||
return False, "无法获取数值参数", False
|
|
||||||
|
|
||||||
value = float(value_str)
|
value = float(value_str)
|
||||||
|
|
||||||
# 获取聊天流ID
|
|
||||||
if not self.message.chat_stream or not hasattr(self.message.chat_stream, "stream_id"):
|
|
||||||
return False, "无法获取聊天流信息", False
|
|
||||||
|
|
||||||
chat_id = self.message.chat_stream.stream_id
|
|
||||||
|
|
||||||
# 设置talk_frequency
|
|
||||||
frequency_api.set_talk_frequency_adjust(chat_id, value)
|
|
||||||
|
|
||||||
final_value = frequency_api.get_current_talk_value(chat_id)
|
|
||||||
adjust_value = frequency_api.get_talk_frequency_adjust(chat_id)
|
|
||||||
base_value = final_value / adjust_value
|
|
||||||
|
|
||||||
# 发送反馈消息(不保存到数据库)
|
|
||||||
await send_api.text_to_stream(
|
|
||||||
f"已设置当前聊天的talk_frequency调整值为: {value}\n当前talk_value: {final_value:.2f}\n发言频率调整: {adjust_value:.2f}\n基础值: {base_value:.2f}",
|
|
||||||
chat_id,
|
|
||||||
storage_message=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
return True, None, False
|
|
||||||
|
|
||||||
except ValueError:
|
except ValueError:
|
||||||
error_msg = "数值格式错误,请输入有效的数字"
|
await self.ctx.send.text("数值格式错误,请输入有效的数字", stream_id)
|
||||||
await self.send_text(error_msg, storage_message=False)
|
return False, "数值格式错误", False
|
||||||
return False, error_msg, False
|
|
||||||
except Exception as e:
|
if not stream_id:
|
||||||
error_msg = f"设置talk_frequency失败: {str(e)}"
|
return False, "无法获取聊天流信息", False
|
||||||
await self.send_text(error_msg, storage_message=False)
|
|
||||||
return False, error_msg, False
|
# 设置 talk_frequency
|
||||||
|
await self.ctx.frequency.set_adjust(stream_id, value)
|
||||||
|
|
||||||
|
# 获取当前状态
|
||||||
|
current = await self.ctx.frequency.get_current_talk_value(stream_id)
|
||||||
|
current_val = current if isinstance(current, (int, float)) else 0
|
||||||
|
adjust = await self.ctx.frequency.get_adjust(stream_id)
|
||||||
|
adjust_val = adjust if isinstance(adjust, (int, float)) else 1
|
||||||
|
base_val = current_val / adjust_val if adjust_val else 0
|
||||||
|
|
||||||
|
msg = (
|
||||||
|
f"已设置当前聊天的talk_frequency调整值为: {value}\n"
|
||||||
|
f"当前talk_value: {current_val:.2f}\n"
|
||||||
|
f"发言频率调整: {adjust_val:.2f}\n"
|
||||||
|
f"基础值: {base_val:.2f}"
|
||||||
|
)
|
||||||
|
await self.ctx.send.text(msg, stream_id)
|
||||||
|
return True, None, False
|
||||||
|
|
||||||
|
@Command(
|
||||||
|
"show_frequency",
|
||||||
|
description="显示当前聊天的频率控制状态:/chat show 或 /chat s",
|
||||||
|
pattern=r"^/chat\s+(?:show|s)$",
|
||||||
|
)
|
||||||
|
async def handle_show_frequency(self, stream_id: str = "", **kwargs):
|
||||||
|
"""显示当前频率控制状态"""
|
||||||
|
if not stream_id:
|
||||||
|
return False, "无法获取聊天流信息", False
|
||||||
|
|
||||||
|
current = await self.ctx.frequency.get_current_talk_value(stream_id)
|
||||||
|
current_val = current if isinstance(current, (int, float)) else 0
|
||||||
|
adjust = await self.ctx.frequency.get_adjust(stream_id)
|
||||||
|
adjust_val = adjust if isinstance(adjust, (int, float)) else 1
|
||||||
|
base_val = current_val / adjust_val if adjust_val else 0
|
||||||
|
|
||||||
|
status_msg = (
|
||||||
|
"当前聊天频率控制状态\n"
|
||||||
|
"Talk Value (发言频率):\n\n"
|
||||||
|
f" • 基础值: {base_val:.2f}\n"
|
||||||
|
f" • 发言频率调整: {adjust_val:.2f}\n"
|
||||||
|
f" • 当前值: {current_val:.2f}\n\n"
|
||||||
|
"使用命令:\n"
|
||||||
|
" • /chat talk_frequency <数字> 或 /chat t <数字> - 设置发言频率调整\n"
|
||||||
|
" • /chat show 或 /chat s - 显示当前状态"
|
||||||
|
)
|
||||||
|
await self.ctx.send.text(status_msg, stream_id)
|
||||||
|
return True, None, False
|
||||||
|
|
||||||
|
|
||||||
class ShowFrequencyCommand(BaseCommand):
|
def create_plugin():
|
||||||
"""显示当前聊天的频率控制状态"""
|
return BetterFrequencyPlugin()
|
||||||
|
|
||||||
command_name = "show_frequency"
|
|
||||||
command_description = "显示当前聊天的频率控制状态:/chat show 或 /chat s"
|
|
||||||
command_pattern = r"^/chat\s+(?:show|s)$"
|
|
||||||
|
|
||||||
async def execute(self) -> Tuple[bool, Optional[str], bool]:
|
|
||||||
try:
|
|
||||||
# 获取聊天流ID
|
|
||||||
if not self.message.chat_stream or not hasattr(self.message.chat_stream, "stream_id"):
|
|
||||||
return False, "无法获取聊天流信息", False
|
|
||||||
|
|
||||||
chat_id = self.message.chat_stream.stream_id
|
|
||||||
|
|
||||||
# 获取当前频率控制状态
|
|
||||||
current_talk_frequency = frequency_api.get_current_talk_value(chat_id)
|
|
||||||
talk_frequency_adjust = frequency_api.get_talk_frequency_adjust(chat_id)
|
|
||||||
base_value = current_talk_frequency / talk_frequency_adjust
|
|
||||||
|
|
||||||
# 构建显示消息
|
|
||||||
status_msg = f"""当前聊天频率控制状态
|
|
||||||
Talk Value (发言频率):
|
|
||||||
|
|
||||||
• 基础值: {base_value:.2f}
|
|
||||||
• 发言频率调整: {talk_frequency_adjust:.2f}
|
|
||||||
• 当前值: {current_talk_frequency:.2f}
|
|
||||||
|
|
||||||
使用命令:
|
|
||||||
• /chat talk_frequency <数字> 或 /chat t <数字> - 设置发言频率调整
|
|
||||||
• /chat show 或 /chat s - 显示当前状态"""
|
|
||||||
|
|
||||||
# 发送状态消息(不保存到数据库)
|
|
||||||
await send_api.text_to_stream(status_msg, chat_id, storage_message=False)
|
|
||||||
|
|
||||||
return True, None, False
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = f"获取频率控制状态失败: {str(e)}"
|
|
||||||
# 使用内置的send_text方法发送错误消息
|
|
||||||
await self.send_text(error_msg, storage_message=False)
|
|
||||||
return False, error_msg, False
|
|
||||||
|
|
||||||
|
|
||||||
# ===== 插件注册 =====
|
|
||||||
|
|
||||||
|
|
||||||
@register_plugin
|
|
||||||
class BetterFrequencyPlugin(BasePlugin):
|
|
||||||
"""BetterFrequency插件 - 控制聊天频率的插件"""
|
|
||||||
|
|
||||||
# 插件基本信息
|
|
||||||
plugin_name: str = "better_frequency_plugin"
|
|
||||||
enable_plugin: bool = True
|
|
||||||
dependencies: List[str] = []
|
|
||||||
python_dependencies: List[str] = []
|
|
||||||
config_file_name: str = "config.toml"
|
|
||||||
|
|
||||||
# 配置节描述
|
|
||||||
config_section_descriptions = {"plugin": "插件基本信息", "frequency": "频率控制配置", "features": "功能开关配置"}
|
|
||||||
|
|
||||||
# 配置Schema定义
|
|
||||||
config_schema: dict = {
|
|
||||||
"plugin": {
|
|
||||||
"name": ConfigField(type=str, default="better_frequency_plugin", description="插件名称"),
|
|
||||||
"version": ConfigField(type=str, default="1.0.2", description="插件版本"),
|
|
||||||
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
|
|
||||||
},
|
|
||||||
"frequency": {
|
|
||||||
"default_talk_adjust": ConfigField(type=float, default=1.0, description="默认talk_frequency调整值"),
|
|
||||||
"max_adjust_value": ConfigField(type=float, default=1.0, description="最大调整值"),
|
|
||||||
"min_adjust_value": ConfigField(type=float, default=0.0, description="最小调整值"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
|
||||||
components = []
|
|
||||||
|
|
||||||
# 根据配置决定是否注册命令组件
|
|
||||||
if self.config.get("features", {}).get("enable_commands", True):
|
|
||||||
components.extend(
|
|
||||||
[
|
|
||||||
(SetTalkFrequencyCommand.get_command_info(), SetTalkFrequencyCommand),
|
|
||||||
(ShowFrequencyCommand.get_command_info(), ShowFrequencyCommand),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
return components
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
{
|
{
|
||||||
"manifest_version": 1,
|
"manifest_version": 1,
|
||||||
"name": "BetterEmoji",
|
"name": "BetterEmoji",
|
||||||
"version": "1.0.0",
|
"version": "2.0.0",
|
||||||
"description": "更好的表情包管理插件",
|
"description": "更好的表情包管理插件",
|
||||||
"author": {
|
"author": {
|
||||||
"name": "SengokuCola",
|
"name": "SengokuCola",
|
||||||
@@ -9,7 +9,7 @@
|
|||||||
},
|
},
|
||||||
"license": "GPL-v3.0-or-later",
|
"license": "GPL-v3.0-or-later",
|
||||||
"host_application": {
|
"host_application": {
|
||||||
"min_version": "0.10.4"
|
"min_version": "1.0.0"
|
||||||
},
|
},
|
||||||
"homepage_url": "https://github.com/SengokuCola/BetterEmoji",
|
"homepage_url": "https://github.com/SengokuCola/BetterEmoji",
|
||||||
"repository_url": "https://github.com/SengokuCola/BetterEmoji",
|
"repository_url": "https://github.com/SengokuCola/BetterEmoji",
|
||||||
@@ -19,46 +19,49 @@
|
|||||||
"plugin"
|
"plugin"
|
||||||
],
|
],
|
||||||
"categories": [
|
"categories": [
|
||||||
"Examples",
|
"Emoji",
|
||||||
"Tutorial"
|
"Management"
|
||||||
],
|
],
|
||||||
"default_locale": "zh-CN",
|
"default_locale": "zh-CN",
|
||||||
"locales_path": "_locales",
|
"locales_path": "_locales",
|
||||||
"plugin_info": {
|
"plugin_info": {
|
||||||
"is_built_in": false,
|
"is_built_in": false,
|
||||||
"plugin_type": "emoji_manage",
|
"plugin_type": "emoji_manage",
|
||||||
|
"capabilities": [
|
||||||
|
"emoji.get_random",
|
||||||
|
"emoji.get_count",
|
||||||
|
"emoji.get_info",
|
||||||
|
"emoji.get_all",
|
||||||
|
"emoji.register_emoji",
|
||||||
|
"emoji.delete_emoji",
|
||||||
|
"send.text",
|
||||||
|
"send.forward"
|
||||||
|
],
|
||||||
"components": [
|
"components": [
|
||||||
{
|
{
|
||||||
"type": "action",
|
"type": "command",
|
||||||
"name": "hello_greeting",
|
"name": "add_emoji",
|
||||||
"description": "向用户发送问候消息"
|
"description": "添加表情包",
|
||||||
},
|
"pattern": "/emoji add"
|
||||||
{
|
|
||||||
"type": "action",
|
|
||||||
"name": "bye_greeting",
|
|
||||||
"description": "向用户发送告别消息",
|
|
||||||
"activation_modes": [
|
|
||||||
"keyword"
|
|
||||||
],
|
|
||||||
"keywords": [
|
|
||||||
"再见",
|
|
||||||
"bye",
|
|
||||||
"88",
|
|
||||||
"拜拜"
|
|
||||||
]
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"type": "command",
|
"type": "command",
|
||||||
"name": "time",
|
"name": "emoji_list",
|
||||||
"description": "查询当前时间",
|
"description": "列表表情包",
|
||||||
"pattern": "/time"
|
"pattern": "/emoji list"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "command",
|
||||||
|
"name": "delete_emoji",
|
||||||
|
"description": "删除表情包",
|
||||||
|
"pattern": "/emoji delete"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "command",
|
||||||
|
"name": "random_emojis",
|
||||||
|
"description": "发送多张随机表情包",
|
||||||
|
"pattern": "/random_emojis"
|
||||||
}
|
}
|
||||||
],
|
|
||||||
"features": [
|
|
||||||
"问候和告别功能",
|
|
||||||
"时间查询命令",
|
|
||||||
"配置文件示例",
|
|
||||||
"新手教程代码"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"id": "SengokuCola.BetterEmoji"
|
"id": "SengokuCola.BetterEmoji"
|
||||||
|
|||||||
@@ -1,399 +1,216 @@
|
|||||||
from typing import List, Tuple, Type
|
"""表情包管理插件 — 新 SDK 版本
|
||||||
from src.plugin_system import (
|
|
||||||
BasePlugin,
|
|
||||||
register_plugin,
|
|
||||||
BaseCommand,
|
|
||||||
ComponentInfo,
|
|
||||||
ConfigField,
|
|
||||||
ReplyContentType,
|
|
||||||
emoji_api,
|
|
||||||
)
|
|
||||||
from maim_message import Seg
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
|
|
||||||
logger = get_logger("emoji_manage_plugin")
|
通过 /emoji 命令管理表情包的添加、列表和删除。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import datetime
|
||||||
|
import hashlib
|
||||||
|
import re
|
||||||
|
|
||||||
|
from maibot_sdk import MaiBotPlugin, Command
|
||||||
|
|
||||||
|
|
||||||
class AddEmojiCommand(BaseCommand):
|
class EmojiManagePlugin(MaiBotPlugin):
|
||||||
command_name = "add_emoji"
|
"""表情包管理插件"""
|
||||||
command_description = "添加表情包"
|
|
||||||
command_pattern = r".*/emoji add.*"
|
|
||||||
|
|
||||||
async def execute(self) -> Tuple[bool, str, bool]:
|
# ===== 工具方法 =====
|
||||||
# 查找消息中的表情包
|
|
||||||
# logger.info(f"查找消息中的表情包: {self.message.message_segment}")
|
|
||||||
|
|
||||||
emoji_base64_list = self.find_and_return_emoji_in_message(self.message.message_segment)
|
@staticmethod
|
||||||
|
def _extract_emoji_base64(segments) -> list[str]:
|
||||||
|
"""从消息 segments 中提取 emoji/image 的 base64 数据。
|
||||||
|
|
||||||
|
segments 可以是 dict 列表或 Seg 对象列表(兼容两种格式)。
|
||||||
|
"""
|
||||||
|
results: list[str] = []
|
||||||
|
if not segments:
|
||||||
|
return results
|
||||||
|
|
||||||
|
if isinstance(segments, dict):
|
||||||
|
seg_type = segments.get("type", "")
|
||||||
|
if seg_type in ("emoji", "image"):
|
||||||
|
data = segments.get("data", "")
|
||||||
|
if data:
|
||||||
|
results.append(data)
|
||||||
|
elif seg_type == "seglist":
|
||||||
|
for child in segments.get("data", []):
|
||||||
|
results.extend(EmojiManagePlugin._extract_emoji_base64(child))
|
||||||
|
return results
|
||||||
|
|
||||||
|
# 如果有 .type 属性(Seg 对象)
|
||||||
|
if hasattr(segments, "type"):
|
||||||
|
seg_type = getattr(segments, "type", "")
|
||||||
|
if seg_type in ("emoji", "image"):
|
||||||
|
results.append(getattr(segments, "data", ""))
|
||||||
|
elif seg_type == "seglist":
|
||||||
|
for child in getattr(segments, "data", []):
|
||||||
|
results.extend(EmojiManagePlugin._extract_emoji_base64(child))
|
||||||
|
return results
|
||||||
|
|
||||||
|
# 列表
|
||||||
|
for seg in segments:
|
||||||
|
results.extend(EmojiManagePlugin._extract_emoji_base64(seg))
|
||||||
|
return results
|
||||||
|
|
||||||
|
# ===== Command 组件 =====
|
||||||
|
|
||||||
|
@Command("add_emoji", description="添加表情包", pattern=r".*/emoji add.*")
|
||||||
|
async def handle_add_emoji(self, stream_id: str = "", message_segments=None, **kwargs):
|
||||||
|
"""添加表情包"""
|
||||||
|
emoji_base64_list = self._extract_emoji_base64(message_segments)
|
||||||
if not emoji_base64_list:
|
if not emoji_base64_list:
|
||||||
|
await self.ctx.send.text("未在消息中找到表情包或图片", stream_id)
|
||||||
return False, "未在消息中找到表情包或图片", False
|
return False, "未在消息中找到表情包或图片", False
|
||||||
|
|
||||||
# 注册找到的表情包
|
|
||||||
success_count = 0
|
success_count = 0
|
||||||
fail_count = 0
|
fail_count = 0
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
for i, emoji_base64 in enumerate(emoji_base64_list):
|
for i, emoji_b64 in enumerate(emoji_base64_list):
|
||||||
try:
|
result = await self.ctx.emoji.register_emoji(emoji_b64)
|
||||||
# 使用emoji_api注册表情包(让API自动生成唯一文件名)
|
if isinstance(result, dict) and result.get("success"):
|
||||||
result = await emoji_api.register_emoji(emoji_base64)
|
success_count += 1
|
||||||
|
desc = result.get("description", "未知描述")
|
||||||
if result["success"]:
|
emotions = result.get("emotions", [])
|
||||||
success_count += 1
|
replaced = result.get("replaced", False)
|
||||||
description = result.get("description", "未知描述")
|
msg = f"表情包 {i + 1} 注册成功{'(替换旧表情包)' if replaced else '(新增表情包)'}"
|
||||||
emotions = result.get("emotions", [])
|
if desc:
|
||||||
replaced = result.get("replaced", False)
|
msg += f"\n描述: {desc}"
|
||||||
|
if emotions:
|
||||||
result_msg = f"表情包 {i + 1} 注册成功{'(替换旧表情包)' if replaced else '(新增表情包)'}"
|
msg += f"\n情感标签: {', '.join(emotions)}"
|
||||||
if description:
|
results.append(msg)
|
||||||
result_msg += f"\n描述: {description}"
|
|
||||||
if emotions:
|
|
||||||
result_msg += f"\n情感标签: {', '.join(emotions)}"
|
|
||||||
|
|
||||||
results.append(result_msg)
|
|
||||||
else:
|
|
||||||
fail_count += 1
|
|
||||||
error_msg = result.get("message", "注册失败")
|
|
||||||
results.append(f"表情包 {i + 1} 注册失败: {error_msg}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
fail_count += 1
|
|
||||||
results.append(f"表情包 {i + 1} 注册时发生错误: {str(e)}")
|
|
||||||
|
|
||||||
# 构建返回消息
|
|
||||||
total_count = success_count + fail_count
|
|
||||||
summary_msg = f"表情包注册完成: 成功 {success_count} 个,失败 {fail_count} 个,共处理 {total_count} 个"
|
|
||||||
|
|
||||||
# 如果有结果详情,添加到返回消息中
|
|
||||||
details_msg = ""
|
|
||||||
if results:
|
|
||||||
details_msg = "\n" + "\n".join(results)
|
|
||||||
final_msg = summary_msg + details_msg
|
|
||||||
else:
|
|
||||||
final_msg = summary_msg
|
|
||||||
|
|
||||||
# 使用表达器重写回复
|
|
||||||
try:
|
|
||||||
from src.plugin_system.apis import generator_api
|
|
||||||
|
|
||||||
# 构建重写数据
|
|
||||||
rewrite_data = {
|
|
||||||
"raw_reply": summary_msg,
|
|
||||||
"reason": f"注册了表情包:{details_msg}\n",
|
|
||||||
}
|
|
||||||
|
|
||||||
# 调用表达器重写
|
|
||||||
result_status, data = await generator_api.rewrite_reply(
|
|
||||||
chat_stream=self.message.chat_stream,
|
|
||||||
reply_data=rewrite_data,
|
|
||||||
)
|
|
||||||
|
|
||||||
if result_status:
|
|
||||||
# 发送重写后的回复
|
|
||||||
for reply_seg in data.reply_set.reply_data:
|
|
||||||
send_data = reply_seg.content
|
|
||||||
await self.send_text(send_data)
|
|
||||||
|
|
||||||
return success_count > 0, final_msg, success_count > 0
|
|
||||||
else:
|
else:
|
||||||
# 如果重写失败,发送原始消息
|
fail_count += 1
|
||||||
await self.send_text(final_msg)
|
err = result.get("message", "注册失败") if isinstance(result, dict) else "注册失败"
|
||||||
return success_count > 0, final_msg, success_count > 0
|
results.append(f"表情包 {i + 1} 注册失败: {err}")
|
||||||
|
|
||||||
except Exception as e:
|
total = success_count + fail_count
|
||||||
# 如果表达器调用失败,发送原始消息
|
summary = f"表情包注册完成: 成功 {success_count} 个,失败 {fail_count} 个,共处理 {total} 个"
|
||||||
logger.error(f"[add_emoji] 表达器重写失败: {e}")
|
if results:
|
||||||
await self.send_text(final_msg)
|
summary += "\n" + "\n".join(results)
|
||||||
return success_count > 0, final_msg, success_count > 0
|
|
||||||
|
|
||||||
def find_and_return_emoji_in_message(self, message_segments) -> List[str]:
|
await self.ctx.send.text(summary, stream_id)
|
||||||
emoji_base64_list = []
|
return success_count > 0, summary, success_count > 0
|
||||||
|
|
||||||
# 处理单个Seg对象的情况
|
@Command("emoji_list", description="列表表情包", pattern=r"^/emoji list(\s+\d+)?$")
|
||||||
if isinstance(message_segments, Seg):
|
async def handle_list_emoji(self, stream_id: str = "", raw_message: str = "", **kwargs):
|
||||||
if message_segments.type == "emoji":
|
"""列出表情包"""
|
||||||
emoji_base64_list.append(message_segments.data)
|
max_count = 10
|
||||||
elif message_segments.type == "image":
|
match = re.match(r"^/emoji list(?:\s+(\d+))?$", raw_message)
|
||||||
# 假设图片数据是base64编码的
|
|
||||||
emoji_base64_list.append(message_segments.data)
|
|
||||||
elif message_segments.type == "seglist":
|
|
||||||
# 递归处理嵌套的Seg列表
|
|
||||||
emoji_base64_list.extend(self.find_and_return_emoji_in_message(message_segments.data))
|
|
||||||
return emoji_base64_list
|
|
||||||
|
|
||||||
# 处理Seg列表的情况
|
|
||||||
for seg in message_segments:
|
|
||||||
if seg.type == "emoji":
|
|
||||||
emoji_base64_list.append(seg.data)
|
|
||||||
elif seg.type == "image":
|
|
||||||
# 假设图片数据是base64编码的
|
|
||||||
emoji_base64_list.append(seg.data)
|
|
||||||
elif seg.type == "seglist":
|
|
||||||
# 递归处理嵌套的Seg列表
|
|
||||||
emoji_base64_list.extend(self.find_and_return_emoji_in_message(seg.data))
|
|
||||||
return emoji_base64_list
|
|
||||||
|
|
||||||
|
|
||||||
class ListEmojiCommand(BaseCommand):
|
|
||||||
"""列表表情包Command - 响应/emoji list命令"""
|
|
||||||
|
|
||||||
command_name = "emoji_list"
|
|
||||||
command_description = "列表表情包"
|
|
||||||
|
|
||||||
# === 命令设置(必须填写)===
|
|
||||||
command_pattern = r"^/emoji list(\s+\d+)?$" # 匹配 "/emoji list" 或 "/emoji list 数量"
|
|
||||||
|
|
||||||
async def execute(self) -> Tuple[bool, str, bool]:
|
|
||||||
"""执行列表表情包"""
|
|
||||||
from src.plugin_system.apis import emoji_api
|
|
||||||
import datetime
|
|
||||||
|
|
||||||
# 解析命令参数
|
|
||||||
import re
|
|
||||||
|
|
||||||
match = re.match(r"^/emoji list(?:\s+(\d+))?$", self.message.raw_message)
|
|
||||||
max_count = 10 # 默认显示10个
|
|
||||||
if match and match.group(1):
|
if match and match.group(1):
|
||||||
max_count = min(int(match.group(1)), 50) # 最多显示50个
|
max_count = min(int(match.group(1)), 50)
|
||||||
|
|
||||||
# 获取当前时间
|
now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
time_format: str = self.get_config("time.format", "%Y-%m-%d %H:%M:%S") # type: ignore
|
|
||||||
now = datetime.datetime.now()
|
|
||||||
time_str = now.strftime(time_format)
|
|
||||||
|
|
||||||
# 获取表情包信息
|
count_result = await self.ctx.emoji.get_count()
|
||||||
emoji_count = emoji_api.get_count()
|
emoji_count = count_result if isinstance(count_result, int) else 0
|
||||||
emoji_info = emoji_api.get_info()
|
|
||||||
|
|
||||||
# 构建返回消息
|
info_result = await self.ctx.emoji.get_info()
|
||||||
message_lines = [
|
max_emoji = info_result.get("max_count", 0) if isinstance(info_result, dict) else 0
|
||||||
f"📊 表情包统计信息 ({time_str})",
|
available = info_result.get("available_emojis", 0) if isinstance(info_result, dict) else 0
|
||||||
f"• 总数: {emoji_count} / {emoji_info['max_count']}",
|
|
||||||
f"• 可用: {emoji_info['available_emojis']}",
|
lines = [
|
||||||
|
f"📊 表情包统计信息 ({now})",
|
||||||
|
f"• 总数: {emoji_count} / {max_emoji}",
|
||||||
|
f"• 可用: {available}",
|
||||||
]
|
]
|
||||||
|
|
||||||
if emoji_count == 0:
|
if emoji_count == 0:
|
||||||
message_lines.append("\n❌ 暂无表情包")
|
lines.append("\n❌ 暂无表情包")
|
||||||
final_message = "\n".join(message_lines)
|
await self.ctx.send.text("\n".join(lines), stream_id)
|
||||||
await self.send_text(final_message)
|
return True, "\n".join(lines), True
|
||||||
return True, final_message, True
|
|
||||||
|
|
||||||
# 获取所有表情包
|
all_result = await self.ctx.emoji.get_all()
|
||||||
all_emojis = await emoji_api.get_all()
|
all_emojis = all_result if isinstance(all_result, list) else []
|
||||||
if not all_emojis:
|
if not all_emojis:
|
||||||
message_lines.append("\n❌ 无法获取表情包列表")
|
lines.append("\n❌ 无法获取表情包列表")
|
||||||
final_message = "\n".join(message_lines)
|
await self.ctx.send.text("\n".join(lines), stream_id)
|
||||||
await self.send_text(final_message)
|
return False, "\n".join(lines), True
|
||||||
return False, final_message, True
|
|
||||||
|
|
||||||
# 显示前N个表情包
|
display = all_emojis[:max_count]
|
||||||
display_emojis = all_emojis[:max_count]
|
lines.append(f"\n📋 显示前 {len(display)} 个表情包:")
|
||||||
message_lines.append(f"\n📋 显示前 {len(display_emojis)} 个表情包:")
|
for i, emoji in enumerate(display, 1):
|
||||||
|
if isinstance(emoji, (list, tuple)) and len(emoji) >= 3:
|
||||||
|
_, desc, emotion = emoji[0], emoji[1], emoji[2]
|
||||||
|
elif isinstance(emoji, dict):
|
||||||
|
desc = emoji.get("description", "")
|
||||||
|
emotion = emoji.get("emotion", "")
|
||||||
|
else:
|
||||||
|
desc, emotion = str(emoji), ""
|
||||||
|
short_desc = desc[:50] + "..." if len(desc) > 50 else desc
|
||||||
|
lines.append(f"{i}. {short_desc} [{emotion}]")
|
||||||
|
|
||||||
for i, (_, description, emotion) in enumerate(display_emojis, 1):
|
|
||||||
# 截断过长的描述
|
|
||||||
short_desc = description[:50] + "..." if len(description) > 50 else description
|
|
||||||
message_lines.append(f"{i}. {short_desc} [{emotion}]")
|
|
||||||
|
|
||||||
# 如果还有更多表情包,显示总数
|
|
||||||
if len(all_emojis) > max_count:
|
if len(all_emojis) > max_count:
|
||||||
message_lines.append(f"\n💡 还有 {len(all_emojis) - max_count} 个表情包未显示")
|
lines.append(f"\n💡 还有 {len(all_emojis) - max_count} 个表情包未显示")
|
||||||
|
|
||||||
final_message = "\n".join(message_lines)
|
final = "\n".join(lines)
|
||||||
|
await self.ctx.send.text(final, stream_id)
|
||||||
# 直接发送文本消息
|
return True, final, True
|
||||||
await self.send_text(final_message)
|
|
||||||
|
|
||||||
return True, final_message, True
|
|
||||||
|
|
||||||
|
|
||||||
class DeleteEmojiCommand(BaseCommand):
|
|
||||||
command_name = "delete_emoji"
|
|
||||||
command_description = "删除表情包"
|
|
||||||
command_pattern = r".*/emoji delete.*"
|
|
||||||
|
|
||||||
async def execute(self) -> Tuple[bool, str, bool]:
|
|
||||||
# 查找消息中的表情包图片
|
|
||||||
logger.info(f"查找消息中的表情包用于删除: {self.message.message_segment}")
|
|
||||||
|
|
||||||
emoji_base64_list = self.find_and_return_emoji_in_message(self.message.message_segment)
|
|
||||||
|
|
||||||
|
@Command("delete_emoji", description="删除表情包", pattern=r".*/emoji delete.*")
|
||||||
|
async def handle_delete_emoji(self, stream_id: str = "", message_segments=None, **kwargs):
|
||||||
|
"""删除表情包"""
|
||||||
|
emoji_base64_list = self._extract_emoji_base64(message_segments)
|
||||||
if not emoji_base64_list:
|
if not emoji_base64_list:
|
||||||
return False, "未在消息中找到表情包或图片", False
|
await self.ctx.send.text("未在消息中找到表情包或图片", stream_id)
|
||||||
|
return False, "未找到表情包", False
|
||||||
|
|
||||||
# 删除找到的表情包
|
|
||||||
success_count = 0
|
success_count = 0
|
||||||
fail_count = 0
|
fail_count = 0
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
for i, emoji_base64 in enumerate(emoji_base64_list):
|
for i, emoji_b64 in enumerate(emoji_base64_list):
|
||||||
try:
|
# 计算哈希
|
||||||
# 计算图片的哈希值来查找对应的表情包
|
if isinstance(emoji_b64, str):
|
||||||
import base64
|
clean = emoji_b64.encode("ascii", errors="ignore").decode("ascii")
|
||||||
import hashlib
|
|
||||||
|
|
||||||
# 确保base64字符串只包含ASCII字符
|
|
||||||
if isinstance(emoji_base64, str):
|
|
||||||
emoji_base64_clean = emoji_base64.encode("ascii", errors="ignore").decode("ascii")
|
|
||||||
else:
|
|
||||||
emoji_base64_clean = str(emoji_base64)
|
|
||||||
|
|
||||||
# 计算哈希值
|
|
||||||
image_bytes = base64.b64decode(emoji_base64_clean)
|
|
||||||
emoji_hash = hashlib.md5(image_bytes).hexdigest()
|
|
||||||
|
|
||||||
# 使用emoji_api删除表情包
|
|
||||||
result = await emoji_api.delete_emoji(emoji_hash)
|
|
||||||
|
|
||||||
if result["success"]:
|
|
||||||
success_count += 1
|
|
||||||
description = result.get("description", "未知描述")
|
|
||||||
count_before = result.get("count_before", 0)
|
|
||||||
count_after = result.get("count_after", 0)
|
|
||||||
emotions = result.get("emotions", [])
|
|
||||||
|
|
||||||
result_msg = f"表情包 {i + 1} 删除成功"
|
|
||||||
if description:
|
|
||||||
result_msg += f"\n描述: {description}"
|
|
||||||
if emotions:
|
|
||||||
result_msg += f"\n情感标签: {', '.join(emotions)}"
|
|
||||||
result_msg += f"\n表情包数量: {count_before} → {count_after}"
|
|
||||||
|
|
||||||
results.append(result_msg)
|
|
||||||
else:
|
|
||||||
fail_count += 1
|
|
||||||
error_msg = result.get("message", "删除失败")
|
|
||||||
results.append(f"表情包 {i + 1} 删除失败: {error_msg}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
fail_count += 1
|
|
||||||
results.append(f"表情包 {i + 1} 删除时发生错误: {str(e)}")
|
|
||||||
|
|
||||||
# 构建返回消息
|
|
||||||
total_count = success_count + fail_count
|
|
||||||
summary_msg = f"表情包删除完成: 成功 {success_count} 个,失败 {fail_count} 个,共处理 {total_count} 个"
|
|
||||||
|
|
||||||
# 如果有结果详情,添加到返回消息中
|
|
||||||
details_msg = ""
|
|
||||||
if results:
|
|
||||||
details_msg = "\n" + "\n".join(results)
|
|
||||||
final_msg = summary_msg + details_msg
|
|
||||||
else:
|
|
||||||
final_msg = summary_msg
|
|
||||||
|
|
||||||
# 使用表达器重写回复
|
|
||||||
try:
|
|
||||||
from src.plugin_system.apis import generator_api
|
|
||||||
|
|
||||||
# 构建重写数据
|
|
||||||
rewrite_data = {
|
|
||||||
"raw_reply": summary_msg,
|
|
||||||
"reason": f"删除了表情包:{details_msg}\n",
|
|
||||||
}
|
|
||||||
|
|
||||||
# 调用表达器重写
|
|
||||||
result_status, data = await generator_api.rewrite_reply(
|
|
||||||
chat_stream=self.message.chat_stream,
|
|
||||||
reply_data=rewrite_data,
|
|
||||||
)
|
|
||||||
|
|
||||||
if result_status:
|
|
||||||
# 发送重写后的回复
|
|
||||||
for reply_seg in data.reply_set.reply_data:
|
|
||||||
send_data = reply_seg.content
|
|
||||||
await self.send_text(send_data)
|
|
||||||
|
|
||||||
return success_count > 0, final_msg, success_count > 0
|
|
||||||
else:
|
else:
|
||||||
# 如果重写失败,发送原始消息
|
clean = str(emoji_b64)
|
||||||
await self.send_text(final_msg)
|
image_bytes = base64.b64decode(clean)
|
||||||
return success_count > 0, final_msg, success_count > 0
|
emoji_hash = hashlib.md5(image_bytes).hexdigest() # noqa: S324
|
||||||
|
|
||||||
except Exception as e:
|
result = await self.ctx.emoji.delete_emoji(emoji_hash)
|
||||||
# 如果表达器调用失败,发送原始消息
|
if isinstance(result, dict) and result.get("success"):
|
||||||
logger.error(f"[delete_emoji] 表达器重写失败: {e}")
|
success_count += 1
|
||||||
await self.send_text(final_msg)
|
desc = result.get("description", "未知描述")
|
||||||
return success_count > 0, final_msg, success_count > 0
|
emotions = result.get("emotions", [])
|
||||||
|
before = result.get("count_before", 0)
|
||||||
|
after = result.get("count_after", 0)
|
||||||
|
msg = f"表情包 {i + 1} 删除成功"
|
||||||
|
if desc:
|
||||||
|
msg += f"\n描述: {desc}"
|
||||||
|
if emotions:
|
||||||
|
msg += f"\n情感标签: {', '.join(emotions)}"
|
||||||
|
msg += f"\n表情包数量: {before} → {after}"
|
||||||
|
results.append(msg)
|
||||||
|
else:
|
||||||
|
fail_count += 1
|
||||||
|
err = result.get("message", "删除失败") if isinstance(result, dict) else "删除失败"
|
||||||
|
results.append(f"表情包 {i + 1} 删除失败: {err}")
|
||||||
|
|
||||||
def find_and_return_emoji_in_message(self, message_segments) -> List[str]:
|
total = success_count + fail_count
|
||||||
emoji_base64_list = []
|
summary = f"表情包删除完成: 成功 {success_count} 个,失败 {fail_count} 个,共处理 {total} 个"
|
||||||
|
if results:
|
||||||
|
summary += "\n" + "\n".join(results)
|
||||||
|
|
||||||
# 处理单个Seg对象的情况
|
await self.ctx.send.text(summary, stream_id)
|
||||||
if isinstance(message_segments, Seg):
|
return success_count > 0, summary, success_count > 0
|
||||||
if message_segments.type == "emoji":
|
|
||||||
emoji_base64_list.append(message_segments.data)
|
|
||||||
elif message_segments.type == "image":
|
|
||||||
# 假设图片数据是base64编码的
|
|
||||||
emoji_base64_list.append(message_segments.data)
|
|
||||||
elif message_segments.type == "seglist":
|
|
||||||
# 递归处理嵌套的Seg列表
|
|
||||||
emoji_base64_list.extend(self.find_and_return_emoji_in_message(message_segments.data))
|
|
||||||
return emoji_base64_list
|
|
||||||
|
|
||||||
# 处理Seg列表的情况
|
@Command("random_emojis", description="发送多张随机表情包", pattern=r"^/random_emojis$")
|
||||||
for seg in message_segments:
|
async def handle_random_emojis(self, stream_id: str = "", **kwargs):
|
||||||
if seg.type == "emoji":
|
"""发送多张随机表情包"""
|
||||||
emoji_base64_list.append(seg.data)
|
result = await self.ctx.emoji.get_random(5)
|
||||||
elif seg.type == "image":
|
if not result or not result.get("success"):
|
||||||
# 假设图片数据是base64编码的
|
return False, "未找到表情包", False
|
||||||
emoji_base64_list.append(seg.data)
|
emojis = result.get("emojis", [])
|
||||||
elif seg.type == "seglist":
|
|
||||||
# 递归处理嵌套的Seg列表
|
|
||||||
emoji_base64_list.extend(self.find_and_return_emoji_in_message(seg.data))
|
|
||||||
return emoji_base64_list
|
|
||||||
|
|
||||||
|
|
||||||
class RandomEmojis(BaseCommand):
|
|
||||||
command_name = "random_emojis"
|
|
||||||
command_description = "发送多张随机表情包"
|
|
||||||
command_pattern = r"^/random_emojis$"
|
|
||||||
|
|
||||||
async def execute(self):
|
|
||||||
emojis = await emoji_api.get_random(5)
|
|
||||||
if not emojis:
|
if not emojis:
|
||||||
return False, "未找到表情包", False
|
return False, "未找到表情包", False
|
||||||
emoji_base64_list = []
|
messages = [
|
||||||
for emoji in emojis:
|
{"user_id": "0", "nickname": "神秘用户", "segments": [{"type": "image", "content": e.get("base64", "")}]}
|
||||||
emoji_base64_list.append(emoji[0])
|
for e in emojis
|
||||||
return await self.forward_images(emoji_base64_list)
|
|
||||||
|
|
||||||
async def forward_images(self, images: List[str]):
|
|
||||||
"""
|
|
||||||
把多张图片用合并转发的方式发给用户
|
|
||||||
"""
|
|
||||||
success = await self.send_forward([("0", "神秘用户", [(ReplyContentType.IMAGE, img)]) for img in images])
|
|
||||||
return (True, "已发送随机表情包", True) if success else (False, "发送随机表情包失败", False)
|
|
||||||
|
|
||||||
|
|
||||||
# ===== 插件注册 =====
|
|
||||||
|
|
||||||
|
|
||||||
@register_plugin
|
|
||||||
class EmojiManagePlugin(BasePlugin):
|
|
||||||
"""表情包管理插件 - 管理表情包"""
|
|
||||||
|
|
||||||
# 插件基本信息
|
|
||||||
plugin_name: str = "emoji_manage_plugin" # 内部标识符
|
|
||||||
enable_plugin: bool = False
|
|
||||||
dependencies: List[str] = [] # 插件依赖列表
|
|
||||||
python_dependencies: List[str] = [] # Python包依赖列表
|
|
||||||
config_file_name: str = "config.toml" # 配置文件名
|
|
||||||
|
|
||||||
# 配置节描述
|
|
||||||
config_section_descriptions = {"plugin": "插件基本信息", "emoji": "表情包功能配置"}
|
|
||||||
|
|
||||||
# 配置Schema定义
|
|
||||||
config_schema: dict = {
|
|
||||||
"plugin": {
|
|
||||||
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
|
|
||||||
"config_version": ConfigField(type=str, default="1.0.1", description="配置文件版本"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
|
||||||
return [
|
|
||||||
(RandomEmojis.get_command_info(), RandomEmojis),
|
|
||||||
(AddEmojiCommand.get_command_info(), AddEmojiCommand),
|
|
||||||
(ListEmojiCommand.get_command_info(), ListEmojiCommand),
|
|
||||||
(DeleteEmojiCommand.get_command_info(), DeleteEmojiCommand),
|
|
||||||
]
|
]
|
||||||
|
await self.ctx.send.forward(messages, stream_id)
|
||||||
|
return True, "已发送随机表情包", True
|
||||||
|
|
||||||
|
|
||||||
|
def create_plugin():
|
||||||
|
return EmojiManagePlugin()
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
{
|
{
|
||||||
"manifest_version": 1,
|
"manifest_version": 1,
|
||||||
"name": "Hello World 示例插件 (Hello World Plugin)",
|
"name": "Hello World 示例插件 (Hello World Plugin)",
|
||||||
"version": "1.0.0",
|
"version": "2.0.0",
|
||||||
"description": "我的第一个MaiCore插件,包含问候功能和时间查询等基础示例",
|
"description": "我的第一个MaiCore插件,包含问候功能和时间查询等基础示例",
|
||||||
"author": {
|
"author": {
|
||||||
"name": "MaiBot开发团队",
|
"name": "MaiBot开发团队",
|
||||||
@@ -9,7 +9,7 @@
|
|||||||
},
|
},
|
||||||
"license": "GPL-v3.0-or-later",
|
"license": "GPL-v3.0-or-later",
|
||||||
"host_application": {
|
"host_application": {
|
||||||
"min_version": "0.8.0"
|
"min_version": "1.0.0"
|
||||||
},
|
},
|
||||||
"homepage_url": "https://github.com/MaiM-with-u/maibot",
|
"homepage_url": "https://github.com/MaiM-with-u/maibot",
|
||||||
"repository_url": "https://github.com/MaiM-with-u/maibot",
|
"repository_url": "https://github.com/MaiM-with-u/maibot",
|
||||||
@@ -29,7 +29,19 @@
|
|||||||
"plugin_info": {
|
"plugin_info": {
|
||||||
"is_built_in": false,
|
"is_built_in": false,
|
||||||
"plugin_type": "example",
|
"plugin_type": "example",
|
||||||
|
"capabilities": [
|
||||||
|
"send.text",
|
||||||
|
"send.forward",
|
||||||
|
"send.hybrid",
|
||||||
|
"emoji.get_random",
|
||||||
|
"config.get"
|
||||||
|
],
|
||||||
"components": [
|
"components": [
|
||||||
|
{
|
||||||
|
"type": "tool",
|
||||||
|
"name": "compare_numbers",
|
||||||
|
"description": "比较两个数的大小"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"type": "action",
|
"type": "action",
|
||||||
"name": "hello_greeting",
|
"name": "hello_greeting",
|
||||||
@@ -39,28 +51,37 @@
|
|||||||
"type": "action",
|
"type": "action",
|
||||||
"name": "bye_greeting",
|
"name": "bye_greeting",
|
||||||
"description": "向用户发送告别消息",
|
"description": "向用户发送告别消息",
|
||||||
"activation_modes": [
|
"activation_modes": ["keyword"],
|
||||||
"keyword"
|
"keywords": ["再见", "bye", "88", "拜拜"]
|
||||||
],
|
|
||||||
"keywords": [
|
|
||||||
"再见",
|
|
||||||
"bye",
|
|
||||||
"88",
|
|
||||||
"拜拜"
|
|
||||||
]
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"type": "command",
|
"type": "command",
|
||||||
"name": "time",
|
"name": "time",
|
||||||
"description": "查询当前时间",
|
"description": "查询当前时间",
|
||||||
"pattern": "/time"
|
"pattern": "/time"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "command",
|
||||||
|
"name": "random_emojis",
|
||||||
|
"description": "发送多张随机表情包",
|
||||||
|
"pattern": "/random_emojis"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "command",
|
||||||
|
"name": "test",
|
||||||
|
"description": "测试命令",
|
||||||
|
"pattern": "/test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "event_handler",
|
||||||
|
"name": "print_message_handler",
|
||||||
|
"description": "打印接收到的消息"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "event_handler",
|
||||||
|
"name": "forward_messages_handler",
|
||||||
|
"description": "把接收到的消息转发到指定聊天ID"
|
||||||
}
|
}
|
||||||
],
|
|
||||||
"features": [
|
|
||||||
"问候和告别功能",
|
|
||||||
"时间查询命令",
|
|
||||||
"配置文件示例",
|
|
||||||
"新手教程代码"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"id": "MaiBot开发团队.maibot"
|
"id": "MaiBot开发团队.maibot"
|
||||||
|
|||||||
@@ -1,50 +1,30 @@
|
|||||||
|
"""Hello World 示例插件 — 新 SDK 版本
|
||||||
|
|
||||||
|
你的第一个 MaiCore 插件,包含问候功能、时间查询等基础示例。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import datetime
|
||||||
import random
|
import random
|
||||||
from typing import List, Tuple, Type, Any, Optional
|
|
||||||
from src.plugin_system import (
|
|
||||||
BasePlugin,
|
|
||||||
register_plugin,
|
|
||||||
BaseAction,
|
|
||||||
BaseCommand,
|
|
||||||
BaseTool,
|
|
||||||
ComponentInfo,
|
|
||||||
ActionActivationType,
|
|
||||||
ConfigField,
|
|
||||||
BaseEventHandler,
|
|
||||||
EventType,
|
|
||||||
MaiMessages,
|
|
||||||
ToolParamType,
|
|
||||||
ReplyContentType,
|
|
||||||
emoji_api,
|
|
||||||
)
|
|
||||||
from src.config.config import global_config
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
|
|
||||||
logger = get_logger("hello_world_plugin")
|
from maibot_sdk import MaiBotPlugin, Action, Command, Tool, EventHandler
|
||||||
|
from maibot_sdk.types import ActivationType, EventType, ToolParameterInfo, ToolParamType
|
||||||
|
|
||||||
|
|
||||||
class CompareNumbersTool(BaseTool):
|
class HelloWorldPlugin(MaiBotPlugin):
|
||||||
"""比较两个数大小的工具"""
|
"""Hello World 示例插件"""
|
||||||
|
|
||||||
name = "compare_numbers"
|
# ===== Tool 组件 =====
|
||||||
description = "使用工具 比较两个数的大小,返回较大的数"
|
|
||||||
parameters = [
|
|
||||||
("num1", ToolParamType.FLOAT, "第一个数字", True, None),
|
|
||||||
("num2", ToolParamType.FLOAT, "第二个数字", True, None),
|
|
||||||
]
|
|
||||||
available_for_llm = True
|
|
||||||
|
|
||||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""执行比较两个数的大小
|
|
||||||
|
|
||||||
Args:
|
|
||||||
function_args: 工具参数
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: 工具执行结果
|
|
||||||
"""
|
|
||||||
num1: int | float = function_args.get("num1") # type: ignore
|
|
||||||
num2: int | float = function_args.get("num2") # type: ignore
|
|
||||||
|
|
||||||
|
@Tool(
|
||||||
|
"compare_numbers",
|
||||||
|
description="使用工具比较两个数的大小,返回较大的数",
|
||||||
|
parameters=[
|
||||||
|
ToolParameterInfo(name="num1", param_type=ToolParamType.FLOAT, description="第一个数字", required=True),
|
||||||
|
ToolParameterInfo(name="num2", param_type=ToolParamType.FLOAT, description="第二个数字", required=True),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def handle_compare_numbers(self, num1: float = 0, num2: float = 0, **kwargs):
|
||||||
|
"""比较两个数的大小"""
|
||||||
try:
|
try:
|
||||||
if num1 > num2:
|
if num1 > num2:
|
||||||
result = f"{num1} 大于 {num2}"
|
result = f"{num1} 大于 {num2}"
|
||||||
@@ -52,270 +32,121 @@ class CompareNumbersTool(BaseTool):
|
|||||||
result = f"{num1} 小于 {num2}"
|
result = f"{num1} 小于 {num2}"
|
||||||
else:
|
else:
|
||||||
result = f"{num1} 等于 {num2}"
|
result = f"{num1} 等于 {num2}"
|
||||||
|
return {"name": "compare_numbers", "content": result}
|
||||||
return {"name": self.name, "content": result}
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"name": self.name, "content": f"比较数字失败,炸了: {str(e)}"}
|
return {"name": "compare_numbers", "content": f"比较数字失败,炸了: {e}"}
|
||||||
|
|
||||||
|
# ===== Action 组件 =====
|
||||||
|
|
||||||
# ===== Action组件 =====
|
@Action(
|
||||||
class HelloAction(BaseAction):
|
"hello_greeting",
|
||||||
"""问候Action - 简单的问候动作"""
|
description="向用户发送问候消息",
|
||||||
|
activation_type=ActivationType.ALWAYS,
|
||||||
# === 基本信息(必须填写)===
|
action_parameters={"greeting_message": "要发送的问候消息"},
|
||||||
action_name = "hello_greeting"
|
action_require=["需要发送友好问候时使用", "当有人向你问好时使用", "当你遇见没有见过的人时使用"],
|
||||||
action_description = "向用户发送问候消息"
|
associated_types=["text"],
|
||||||
activation_type = ActionActivationType.ALWAYS # 始终激活
|
)
|
||||||
|
async def handle_hello(self, stream_id: str = "", greeting_message: str = "", **kwargs):
|
||||||
# === 功能描述(必须填写)===
|
"""问候动作"""
|
||||||
action_parameters = {"greeting_message": "要发送的问候消息"}
|
config_result = await self.ctx.config.get("greeting.message")
|
||||||
action_require = ["需要发送友好问候时使用", "当有人向你问好时使用", "当你遇见没有见过的人时使用"]
|
base_message = config_result if isinstance(config_result, str) else "嗨!很开心见到你!😊"
|
||||||
associated_types = ["text"]
|
|
||||||
|
|
||||||
async def execute(self) -> Tuple[bool, str]:
|
|
||||||
"""执行问候动作 - 这是核心功能"""
|
|
||||||
# 发送问候消息
|
|
||||||
greeting_message = self.action_data.get("greeting_message", "")
|
|
||||||
base_message = self.get_config("greeting.message", "嗨!很开心见到你!😊")
|
|
||||||
message = base_message + greeting_message
|
message = base_message + greeting_message
|
||||||
await self.send_text(message)
|
await self.ctx.send.text(message, stream_id)
|
||||||
|
|
||||||
return True, "发送了问候消息"
|
return True, "发送了问候消息"
|
||||||
|
|
||||||
|
@Action(
|
||||||
class ByeAction(BaseAction):
|
"bye_greeting",
|
||||||
"""告别Action - 只在用户说再见时激活"""
|
description="向用户发送告别消息",
|
||||||
|
activation_type=ActivationType.KEYWORD,
|
||||||
action_name = "bye_greeting"
|
activation_keywords=["再见", "bye", "88", "拜拜"],
|
||||||
action_description = "向用户发送告别消息"
|
action_parameters={"bye_message": "要发送的告别消息"},
|
||||||
|
action_require=["用户要告别时使用", "当有人要离开时使用", "当有人和你说再见时使用"],
|
||||||
# 使用关键词激活
|
associated_types=["text"],
|
||||||
activation_type = ActionActivationType.KEYWORD
|
)
|
||||||
|
async def handle_bye(self, stream_id: str = "", bye_message: str = "", **kwargs):
|
||||||
# 关键词设置
|
"""告别动作"""
|
||||||
activation_keywords = ["再见", "bye", "88", "拜拜"]
|
|
||||||
keyword_case_sensitive = False
|
|
||||||
|
|
||||||
action_parameters = {"bye_message": "要发送的告别消息"}
|
|
||||||
action_require = [
|
|
||||||
"用户要告别时使用",
|
|
||||||
"当有人要离开时使用",
|
|
||||||
"当有人和你说再见时使用",
|
|
||||||
]
|
|
||||||
associated_types = ["text"]
|
|
||||||
|
|
||||||
async def execute(self) -> Tuple[bool, str]:
|
|
||||||
bye_message = self.action_data.get("bye_message", "")
|
|
||||||
|
|
||||||
message = f"再见!期待下次聊天!👋{bye_message}"
|
message = f"再见!期待下次聊天!👋{bye_message}"
|
||||||
await self.send_text(message)
|
await self.ctx.send.text(message, stream_id)
|
||||||
return True, "发送了告别消息"
|
return True, "发送了告别消息"
|
||||||
|
|
||||||
|
# ===== Command 组件 =====
|
||||||
|
|
||||||
class TimeCommand(BaseCommand):
|
@Command("time", description="查询当前时间", pattern=r"^/time$")
|
||||||
"""时间查询Command - 响应/time命令"""
|
async def handle_time(self, stream_id: str = "", **kwargs):
|
||||||
|
"""时间查询命令"""
|
||||||
command_name = "time"
|
config_result = await self.ctx.config.get("time.format")
|
||||||
command_description = "查询当前时间"
|
time_format = config_result if isinstance(config_result, str) else "%Y-%m-%d %H:%M:%S"
|
||||||
|
|
||||||
# === 命令设置(必须填写)===
|
|
||||||
command_pattern = r"^/time$" # 精确匹配 "/time" 命令
|
|
||||||
|
|
||||||
async def execute(self) -> Tuple[bool, str, bool]:
|
|
||||||
"""执行时间查询"""
|
|
||||||
import datetime
|
|
||||||
|
|
||||||
# 获取当前时间
|
|
||||||
time_format: str = self.get_config("time.format", "%Y-%m-%d %H:%M:%S") # type: ignore
|
|
||||||
now = datetime.datetime.now()
|
now = datetime.datetime.now()
|
||||||
time_str = now.strftime(time_format)
|
time_str = now.strftime(time_format)
|
||||||
|
await self.ctx.send.text(f"⏰ 当前时间:{time_str}", stream_id)
|
||||||
# 发送时间信息
|
|
||||||
message = f"⏰ 当前时间:{time_str}"
|
|
||||||
await self.send_text(message)
|
|
||||||
|
|
||||||
return True, f"显示了当前时间: {time_str}", True
|
return True, f"显示了当前时间: {time_str}", True
|
||||||
|
|
||||||
|
@Command("random_emojis", description="发送多张随机表情包", pattern=r"^/random_emojis$")
|
||||||
|
async def handle_random_emojis(self, stream_id: str = "", **kwargs):
|
||||||
|
"""发送多张随机表情包"""
|
||||||
|
result = await self.ctx.emoji.get_random(5)
|
||||||
|
if not result or not result.get("success"):
|
||||||
|
return False, "未找到表情包", False
|
||||||
|
emojis = result.get("emojis", [])
|
||||||
|
if not emojis:
|
||||||
|
return False, "未找到表情包", False
|
||||||
|
# 用转发消息发送多张图片
|
||||||
|
messages = [
|
||||||
|
{"user_id": "0", "nickname": "神秘用户", "segments": [{"type": "image", "content": e.get("base64", "")}]}
|
||||||
|
for e in emojis
|
||||||
|
]
|
||||||
|
await self.ctx.send.forward(messages, stream_id)
|
||||||
|
return True, "已发送随机表情包", True
|
||||||
|
|
||||||
class PrintMessage(BaseEventHandler):
|
@Command("test", description="测试命令", pattern=r"^/test$")
|
||||||
"""打印消息事件处理器 - 处理打印消息事件"""
|
async def handle_test(self, stream_id: str = "", **kwargs):
|
||||||
|
"""测试命令 — 发送简单测试消息"""
|
||||||
|
await self.ctx.send.text("测试正常!Bot 功能运行中 ✅", stream_id)
|
||||||
|
return True, "测试完成", True
|
||||||
|
|
||||||
event_type = EventType.ON_MESSAGE
|
# ===== EventHandler 组件 =====
|
||||||
handler_name = "print_message_handler"
|
|
||||||
handler_description = "打印接收到的消息"
|
|
||||||
|
|
||||||
async def execute(self, message: MaiMessages | None) -> Tuple[bool, bool, str | None, None, None]:
|
@EventHandler("print_message_handler", description="打印接收到的消息", event_type=EventType.ON_MESSAGE)
|
||||||
"""执行打印消息事件处理"""
|
async def handle_print_message(self, message=None, **kwargs):
|
||||||
# 打印接收到的消息
|
"""打印消息事件"""
|
||||||
if self.get_config("print_message.enabled", False):
|
config_result = await self.ctx.config.get("print_message.enabled")
|
||||||
print(f"接收到消息: {message.raw_message if message else '无效消息'}")
|
enabled = config_result if isinstance(config_result, bool) else False
|
||||||
|
if enabled and message:
|
||||||
|
raw = message.get("raw_message", "") if isinstance(message, dict) else str(message)
|
||||||
|
print(f"接收到消息: {raw}")
|
||||||
return True, True, "消息已打印", None, None
|
return True, True, "消息已打印", None, None
|
||||||
|
|
||||||
|
@EventHandler("forward_messages_handler", description="把接收到的消息转发到指定聊天ID", event_type=EventType.ON_MESSAGE)
|
||||||
class ForwardMessages(BaseEventHandler):
|
async def handle_forward_messages(self, message=None, stream_id: str = "", **kwargs):
|
||||||
"""
|
"""收集消息并定期转发"""
|
||||||
把接收到的消息转发到指定聊天ID
|
|
||||||
|
|
||||||
此组件是HYBRID消息和FORWARD消息的使用示例。
|
|
||||||
每收到10条消息,就会以1%的概率使用HYBRID消息转发,否则使用FORWARD消息转发。
|
|
||||||
"""
|
|
||||||
|
|
||||||
event_type = EventType.ON_MESSAGE
|
|
||||||
handler_name = "forward_messages_handler"
|
|
||||||
handler_description = "把接收到的消息转发到指定聊天ID"
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.counter = 0 # 用于计数转发的消息数量
|
|
||||||
self.messages: List[str] = []
|
|
||||||
|
|
||||||
async def execute(self, message: MaiMessages | None) -> Tuple[bool, bool, None, None, None]:
|
|
||||||
if not message:
|
if not message:
|
||||||
return True, True, None, None, None
|
return True, True, None, None, None
|
||||||
stream_id = message.stream_id or ""
|
plain_text = message.get("plain_text", "") if isinstance(message, dict) else ""
|
||||||
|
if not plain_text:
|
||||||
|
return True, True, None, None, None
|
||||||
|
|
||||||
if message.plain_text:
|
# 使用插件级状态收集消息
|
||||||
self.messages.append(message.plain_text)
|
if not hasattr(self, "_fwd_messages"):
|
||||||
self.counter += 1
|
self._fwd_messages: list[str] = []
|
||||||
if self.counter % 10 == 0:
|
self._fwd_counter: int = 0
|
||||||
|
|
||||||
|
self._fwd_messages.append(plain_text)
|
||||||
|
self._fwd_counter += 1
|
||||||
|
|
||||||
|
if self._fwd_counter % 10 == 0 and stream_id:
|
||||||
if random.random() < 0.01:
|
if random.random() < 0.01:
|
||||||
success = await self.send_hybrid(stream_id, [(ReplyContentType.TEXT, msg) for msg in self.messages])
|
segments = [{"type": "text", "content": msg} for msg in self._fwd_messages]
|
||||||
|
await self.ctx.send.hybrid(segments, stream_id)
|
||||||
else:
|
else:
|
||||||
success = await self.send_forward(
|
messages = [
|
||||||
stream_id,
|
{"user_id": "0", "nickname": "转发", "segments": [{"type": "text", "content": msg}]}
|
||||||
[
|
for msg in self._fwd_messages
|
||||||
(
|
]
|
||||||
str(global_config.bot.qq_account),
|
await self.ctx.send.forward(messages, stream_id)
|
||||||
str(global_config.bot.nickname),
|
self._fwd_messages = []
|
||||||
[(ReplyContentType.TEXT, msg)],
|
|
||||||
)
|
|
||||||
for msg in self.messages
|
|
||||||
],
|
|
||||||
)
|
|
||||||
if not success:
|
|
||||||
raise ValueError("转发消息失败")
|
|
||||||
self.messages = []
|
|
||||||
return True, True, None, None, None
|
return True, True, None, None, None
|
||||||
|
|
||||||
|
|
||||||
class RandomEmojis(BaseCommand):
|
def create_plugin():
|
||||||
command_name = "random_emojis"
|
return HelloWorldPlugin()
|
||||||
command_description = "发送多张随机表情包"
|
|
||||||
command_pattern = r"^/random_emojis$"
|
|
||||||
|
|
||||||
async def execute(self):
|
|
||||||
emojis = await emoji_api.get_random(5)
|
|
||||||
if not emojis:
|
|
||||||
return False, "未找到表情包", False
|
|
||||||
emoji_base64_list = []
|
|
||||||
for emoji in emojis:
|
|
||||||
emoji_base64_list.append(emoji[0])
|
|
||||||
return await self.forward_images(emoji_base64_list)
|
|
||||||
|
|
||||||
async def forward_images(self, images: List[str]):
|
|
||||||
"""
|
|
||||||
把多张图片用合并转发的方式发给用户
|
|
||||||
"""
|
|
||||||
success = await self.send_forward([("0", "神秘用户", [(ReplyContentType.IMAGE, img)]) for img in images])
|
|
||||||
return (True, "已发送随机表情包", True) if success else (False, "发送随机表情包失败", False)
|
|
||||||
|
|
||||||
|
|
||||||
class TestCommand(BaseCommand):
|
|
||||||
"""响应/test命令"""
|
|
||||||
|
|
||||||
command_name = "test"
|
|
||||||
command_description = "测试命令"
|
|
||||||
command_pattern = r"^/test$"
|
|
||||||
|
|
||||||
async def execute(self) -> Tuple[bool, Optional[str], int]:
|
|
||||||
"""执行测试命令"""
|
|
||||||
try:
|
|
||||||
from src.plugin_system.apis import generator_api
|
|
||||||
|
|
||||||
reply_reason = "这是一条测试消息。"
|
|
||||||
logger.info(f"测试命令:{reply_reason}")
|
|
||||||
result_status, data = await generator_api.generate_reply(
|
|
||||||
chat_stream=self.message.chat_stream,
|
|
||||||
reply_reason=reply_reason,
|
|
||||||
enable_chinese_typo=False,
|
|
||||||
extra_info=f'{reply_reason}用于测试bot的功能是否正常。请你按设定的人设表达一句"测试正常"',
|
|
||||||
)
|
|
||||||
if result_status:
|
|
||||||
# 发送生成的回复
|
|
||||||
if data and data.reply_set and data.reply_set.reply_data:
|
|
||||||
for reply_seg in data.reply_set.reply_data:
|
|
||||||
send_data = reply_seg.content
|
|
||||||
await self.send_text(send_data, storage_message=True)
|
|
||||||
logger.info(f"已回复: {send_data}")
|
|
||||||
return True, "", 1
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"表达器生成失败:{e}")
|
|
||||||
return True, "", 1
|
|
||||||
|
|
||||||
|
|
||||||
# ===== 插件注册 =====
|
|
||||||
|
|
||||||
|
|
||||||
@register_plugin
|
|
||||||
class HelloWorldPlugin(BasePlugin):
|
|
||||||
"""Hello World插件 - 你的第一个MaiCore插件"""
|
|
||||||
|
|
||||||
# 插件基本信息
|
|
||||||
plugin_name: str = "hello_world_plugin" # 内部标识符
|
|
||||||
enable_plugin: bool = False
|
|
||||||
dependencies: List[str] = [] # 插件依赖列表
|
|
||||||
python_dependencies: List[str] = [] # Python包依赖列表
|
|
||||||
config_file_name: str = "config.toml" # 配置文件名
|
|
||||||
|
|
||||||
# 配置节描述
|
|
||||||
config_section_descriptions = {"plugin": "插件基本信息", "greeting": "问候功能配置", "time": "时间查询配置"}
|
|
||||||
|
|
||||||
# 配置Schema定义
|
|
||||||
config_schema: dict = {
|
|
||||||
"plugin": {
|
|
||||||
"config_version": ConfigField(type=str, default="1.0.0", description="配置文件版本"),
|
|
||||||
"enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
|
|
||||||
},
|
|
||||||
"greeting": {
|
|
||||||
"message": ConfigField(
|
|
||||||
type=list, default=["嗨!很开心见到你!😊", "Ciallo~(∠・ω< )⌒★"], description="默认问候消息"
|
|
||||||
),
|
|
||||||
"enable_emoji": ConfigField(type=bool, default=True, description="是否启用表情符号"),
|
|
||||||
},
|
|
||||||
"time": {"format": ConfigField(type=str, default="%Y-%m-%d %H:%M:%S", description="时间显示格式")},
|
|
||||||
"print_message": {"enabled": ConfigField(type=bool, default=True, description="是否启用打印")},
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
|
||||||
return [
|
|
||||||
(HelloAction.get_action_info(), HelloAction),
|
|
||||||
(CompareNumbersTool.get_tool_info(), CompareNumbersTool), # 添加比较数字工具
|
|
||||||
(ByeAction.get_action_info(), ByeAction), # 添加告别Action
|
|
||||||
(TimeCommand.get_command_info(), TimeCommand),
|
|
||||||
(PrintMessage.get_handler_info(), PrintMessage),
|
|
||||||
(ForwardMessages.get_handler_info(), ForwardMessages),
|
|
||||||
(RandomEmojis.get_command_info(), RandomEmojis),
|
|
||||||
(TestCommand.get_command_info(), TestCommand),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# @register_plugin
|
|
||||||
# class HelloWorldEventPlugin(BaseEPlugin):
|
|
||||||
# """Hello World事件插件 - 处理问候和告别事件"""
|
|
||||||
|
|
||||||
# plugin_name = "hello_world_event_plugin"
|
|
||||||
# enable_plugin = False
|
|
||||||
# dependencies = []
|
|
||||||
# python_dependencies = []
|
|
||||||
# config_file_name = "event_config.toml"
|
|
||||||
|
|
||||||
# config_schema = {
|
|
||||||
# "plugin": {
|
|
||||||
# "name": ConfigField(type=str, default="hello_world_event_plugin", description="插件名称"),
|
|
||||||
# "version": ConfigField(type=str, default="1.0.0", description="插件版本"),
|
|
||||||
# "enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
|
|
||||||
# },
|
|
||||||
# }
|
|
||||||
|
|
||||||
# def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
|
||||||
# return [(PrintMessage.get_handler_info(), PrintMessage)]
|
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ dependencies = [
|
|||||||
"uvicorn>=0.35.0",
|
"uvicorn>=0.35.0",
|
||||||
"msgpack>=1.1.2",
|
"msgpack>=1.1.2",
|
||||||
"watchfiles>=1.1.1",
|
"watchfiles>=1.1.1",
|
||||||
|
"maibot-plugin-sdk>=1.0.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ fastapi>=0.116.0
|
|||||||
google-genai>=1.39.1
|
google-genai>=1.39.1
|
||||||
jieba>=0.42.1
|
jieba>=0.42.1
|
||||||
json-repair>=0.47.6
|
json-repair>=0.47.6
|
||||||
|
maibot-plugin-sdk>=1.0.0
|
||||||
maim-message>=0.6.2
|
maim-message>=0.6.2
|
||||||
matplotlib>=3.10.3
|
matplotlib>=3.10.3
|
||||||
numpy>=2.2.6
|
numpy>=2.2.6
|
||||||
|
|||||||
@@ -19,9 +19,10 @@ from src.chat.heart_flow.hfc_utils import CycleDetail
|
|||||||
from src.bw_learner.expression_learner import expression_learner_manager
|
from src.bw_learner.expression_learner import expression_learner_manager
|
||||||
from src.bw_learner.message_recorder import extract_and_distribute_messages
|
from src.bw_learner.message_recorder import extract_and_distribute_messages
|
||||||
from src.person_info.person_info import Person
|
from src.person_info.person_info import Person
|
||||||
from src.plugin_system.base.component_types import EventType, ActionInfo
|
from src.core.types import ActionInfo, EventType
|
||||||
from src.plugin_system.core import events_manager
|
from src.core.event_bus import event_bus
|
||||||
from src.plugin_system.apis import generator_api, send_api, message_api, database_api
|
from src.chat.event_helpers import build_event_message
|
||||||
|
from src.services import generator_service as generator_api, send_service as send_api, message_service as message_api, database_service as database_api
|
||||||
from src.chat.utils.chat_message_builder import (
|
from src.chat.utils.chat_message_builder import (
|
||||||
build_readable_messages_with_id,
|
build_readable_messages_with_id,
|
||||||
get_raw_msg_before_timestamp_with_chat,
|
get_raw_msg_before_timestamp_with_chat,
|
||||||
@@ -315,8 +316,9 @@ class BrainChatting:
|
|||||||
message_id_list=message_id_list,
|
message_id_list=message_id_list,
|
||||||
prompt_key="brain_planner",
|
prompt_key="brain_planner",
|
||||||
)
|
)
|
||||||
continue_flag, modified_message = await events_manager.handle_mai_events(
|
_event_msg = build_event_message(EventType.ON_PLAN, llm_prompt=prompt_info[0], stream_id=self.chat_stream.stream_id)
|
||||||
EventType.ON_PLAN, None, prompt_info[0], None, self.chat_stream.stream_id
|
continue_flag, modified_message = await event_bus.emit(
|
||||||
|
EventType.ON_PLAN, _event_msg
|
||||||
)
|
)
|
||||||
if not continue_flag:
|
if not continue_flag:
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -23,8 +23,8 @@ from src.chat.utils.chat_message_builder import (
|
|||||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||||
from src.chat.planner_actions.action_manager import ActionManager
|
from src.chat.planner_actions.action_manager import ActionManager
|
||||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||||
from src.plugin_system.base.component_types import ActionInfo, ComponentType, ActionActivationType
|
from src.core.types import ActionActivationType, ActionInfo, ComponentType
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
from src.core.component_registry import component_registry
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from src.common.data_models.info_data_model import TargetPersonInfo
|
from src.common.data_models.info_data_model import TargetPersonInfo
|
||||||
|
|||||||
169
src/chat/event_helpers.py
Normal file
169
src/chat/event_helpers.py
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
"""
|
||||||
|
事件消息构建工具
|
||||||
|
|
||||||
|
将 chat 层的消息对象 (SessionMessage / MessageSending) 转换为
|
||||||
|
核心事件系统使用的 MaiMessages,供调用 event_bus.emit() 前使用。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Optional, TYPE_CHECKING
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.core.types import EventType, MaiMessages
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.chat.message_receive.message import MessageSending, SessionMessage
|
||||||
|
from src.common.data_models.llm_data_model import LLMGenerationDataModel
|
||||||
|
|
||||||
|
logger = get_logger("event_helpers")
|
||||||
|
|
||||||
|
|
||||||
|
def build_event_message(
|
||||||
|
event_type: EventType | str,
|
||||||
|
message: Optional["SessionMessage | MessageSending | MaiMessages"] = None,
|
||||||
|
llm_prompt: Optional[str] = None,
|
||||||
|
llm_response: Optional["LLMGenerationDataModel"] = None,
|
||||||
|
stream_id: Optional[str] = None,
|
||||||
|
action_usage: Optional[List[str]] = None,
|
||||||
|
) -> Optional[MaiMessages]:
|
||||||
|
"""根据事件类型和输入,准备和转换消息对象。
|
||||||
|
|
||||||
|
迁移自 events_manager._prepare_message,保持相同的行为。
|
||||||
|
"""
|
||||||
|
if isinstance(message, MaiMessages):
|
||||||
|
return message.deepcopy()
|
||||||
|
|
||||||
|
if message:
|
||||||
|
return _transform_event_message(message, llm_prompt, llm_response)
|
||||||
|
|
||||||
|
if event_type not in (EventType.ON_START, EventType.ON_STOP):
|
||||||
|
assert stream_id, "如果没有消息,必须为非启动/关闭事件提供流ID"
|
||||||
|
if event_type in (EventType.ON_MESSAGE, EventType.ON_PLAN, EventType.POST_LLM, EventType.AFTER_LLM):
|
||||||
|
return _build_message_from_stream(stream_id, llm_prompt, llm_response)
|
||||||
|
else:
|
||||||
|
return _build_message_without_raw(stream_id, llm_prompt, llm_response, action_usage)
|
||||||
|
|
||||||
|
return None # ON_START / ON_STOP 没有消息体
|
||||||
|
|
||||||
|
|
||||||
|
def _transform_event_message(
|
||||||
|
message: "SessionMessage | MessageSending",
|
||||||
|
llm_prompt: Optional[str] = None,
|
||||||
|
llm_response: Optional["LLMGenerationDataModel"] = None,
|
||||||
|
) -> MaiMessages:
|
||||||
|
"""将 SessionMessage / MessageSending 转换为 MaiMessages。"""
|
||||||
|
from maim_message import Seg
|
||||||
|
from src.chat.message_receive.message import MessageSending
|
||||||
|
|
||||||
|
transformed = MaiMessages(
|
||||||
|
llm_prompt=llm_prompt,
|
||||||
|
llm_response_content=llm_response.content if llm_response else None,
|
||||||
|
llm_response_reasoning=llm_response.reasoning if llm_response else None,
|
||||||
|
llm_response_model=llm_response.model if llm_response else None,
|
||||||
|
llm_response_tool_call=llm_response.tool_calls if llm_response else None,
|
||||||
|
raw_message=message.processed_plain_text or "",
|
||||||
|
additional_data={},
|
||||||
|
)
|
||||||
|
|
||||||
|
# 消息段处理
|
||||||
|
if isinstance(message, MessageSending):
|
||||||
|
if message.message_segment.type == "seglist":
|
||||||
|
transformed.message_segments = list(message.message_segment.data) # type: ignore
|
||||||
|
else:
|
||||||
|
transformed.message_segments = [message.message_segment]
|
||||||
|
else:
|
||||||
|
transformed.message_segments = [Seg(type="text", data=message.processed_plain_text or "")]
|
||||||
|
|
||||||
|
# stream_id
|
||||||
|
transformed.stream_id = message.session_id if hasattr(message, "session_id") else ""
|
||||||
|
|
||||||
|
# 处理后文本
|
||||||
|
transformed.plain_text = message.processed_plain_text
|
||||||
|
|
||||||
|
# 基本信息
|
||||||
|
if isinstance(message, MessageSending):
|
||||||
|
transformed.message_base_info["platform"] = message.platform
|
||||||
|
if message.session.group_id:
|
||||||
|
transformed.is_group_message = True
|
||||||
|
group_name = ""
|
||||||
|
if (
|
||||||
|
message.session.context
|
||||||
|
and message.session.context.message
|
||||||
|
and message.session.context.message.message_info.group_info
|
||||||
|
):
|
||||||
|
group_name = message.session.context.message.message_info.group_info.group_name
|
||||||
|
transformed.message_base_info.update(
|
||||||
|
{
|
||||||
|
"group_id": message.session.group_id,
|
||||||
|
"group_name": group_name,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
transformed.message_base_info.update(
|
||||||
|
{
|
||||||
|
"user_id": message.bot_user_info.user_id,
|
||||||
|
"user_cardname": message.bot_user_info.user_cardname,
|
||||||
|
"user_nickname": message.bot_user_info.user_nickname,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if not transformed.is_group_message:
|
||||||
|
transformed.is_private_message = True
|
||||||
|
elif hasattr(message, "message_info") and message.message_info:
|
||||||
|
if message.platform:
|
||||||
|
transformed.message_base_info["platform"] = message.platform
|
||||||
|
if message.message_info.group_info:
|
||||||
|
transformed.is_group_message = True
|
||||||
|
transformed.message_base_info.update(
|
||||||
|
{
|
||||||
|
"group_id": message.message_info.group_info.group_id,
|
||||||
|
"group_name": message.message_info.group_info.group_name,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if message.message_info.user_info:
|
||||||
|
if not transformed.is_group_message:
|
||||||
|
transformed.is_private_message = True
|
||||||
|
transformed.message_base_info.update(
|
||||||
|
{
|
||||||
|
"user_id": message.message_info.user_info.user_id,
|
||||||
|
"user_cardname": message.message_info.user_info.user_cardname,
|
||||||
|
"user_nickname": message.message_info.user_info.user_nickname,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return transformed
|
||||||
|
|
||||||
|
|
||||||
|
def _build_message_from_stream(
|
||||||
|
stream_id: str,
|
||||||
|
llm_prompt: Optional[str] = None,
|
||||||
|
llm_response: Optional["LLMGenerationDataModel"] = None,
|
||||||
|
) -> MaiMessages:
|
||||||
|
"""从 stream_id 查找会话消息并转换。"""
|
||||||
|
from src.chat.message_receive.chat_manager import chat_manager
|
||||||
|
|
||||||
|
session = chat_manager.get_session_by_session_id(stream_id)
|
||||||
|
assert session, f"未找到流ID为 {stream_id} 的会话"
|
||||||
|
return _transform_event_message(session.context.message, llm_prompt, llm_response)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_message_without_raw(
|
||||||
|
stream_id: str,
|
||||||
|
llm_prompt: Optional[str] = None,
|
||||||
|
llm_response: Optional["LLMGenerationDataModel"] = None,
|
||||||
|
action_usage: Optional[List[str]] = None,
|
||||||
|
) -> MaiMessages:
|
||||||
|
"""没有原始消息对象时,从 stream_id 构建最小 MaiMessages。"""
|
||||||
|
from src.chat.message_receive.chat_manager import chat_manager
|
||||||
|
|
||||||
|
session = chat_manager.get_session_by_session_id(stream_id)
|
||||||
|
assert session, f"未找到流ID为 {stream_id} 的会话"
|
||||||
|
return MaiMessages(
|
||||||
|
stream_id=stream_id,
|
||||||
|
llm_prompt=llm_prompt,
|
||||||
|
llm_response_content=llm_response.content if llm_response else None,
|
||||||
|
llm_response_reasoning=llm_response.reasoning if llm_response else None,
|
||||||
|
llm_response_model=llm_response.model if llm_response else None,
|
||||||
|
llm_response_tool_call=llm_response.tool_calls if llm_response else None,
|
||||||
|
is_group_message=session.is_group_session,
|
||||||
|
is_private_message=not session.is_group_session,
|
||||||
|
action_usage=action_usage,
|
||||||
|
additional_data={"response_is_processed": True},
|
||||||
|
)
|
||||||
@@ -6,7 +6,7 @@ import time
|
|||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||||
from src.plugin_system.apis import send_api
|
from src.services import send_service as send_api
|
||||||
|
|
||||||
from src.common.message_repository import count_messages
|
from src.common.message_repository import count_messages
|
||||||
|
|
||||||
|
|||||||
@@ -11,8 +11,9 @@ from src.common.utils.utils_session import SessionUtils
|
|||||||
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
|
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
|
||||||
from src.chat.brain_chat.PFC.pfc_manager import PFCManager
|
from src.chat.brain_chat.PFC.pfc_manager import PFCManager
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.plugin_system.core import component_registry, events_manager, global_announcement_manager
|
from src.core.announcement_manager import global_announcement_manager
|
||||||
from src.plugin_system.base import BaseCommand, EventType
|
from src.core.component_registry import component_registry
|
||||||
|
from src.core.types import EventType
|
||||||
|
|
||||||
from .message import SessionMessage
|
from .message import SessionMessage
|
||||||
from .chat_manager import chat_manager
|
from .chat_manager import chat_manager
|
||||||
@@ -65,10 +66,10 @@ class ChatBot:
|
|||||||
try:
|
try:
|
||||||
text = message.processed_plain_text
|
text = message.processed_plain_text
|
||||||
|
|
||||||
# 使用新的组件注册中心查找命令
|
# 使用核心组件注册表查找命令
|
||||||
command_result = component_registry.find_command_by_text(text)
|
command_result = component_registry.find_command_by_text(text)
|
||||||
if command_result:
|
if command_result:
|
||||||
command_class, matched_groups, command_info = command_result
|
command_executor, matched_groups, command_info = command_result
|
||||||
plugin_name = command_info.plugin_name
|
plugin_name = command_info.plugin_name
|
||||||
command_name = command_info.name
|
command_name = command_info.name
|
||||||
if message.session_id and command_name in global_announcement_manager.get_disabled_chat_commands(
|
if message.session_id and command_name in global_announcement_manager.get_disabled_chat_commands(
|
||||||
@@ -82,20 +83,20 @@ class ChatBot:
|
|||||||
# 获取插件配置
|
# 获取插件配置
|
||||||
plugin_config = component_registry.get_plugin_config(plugin_name)
|
plugin_config = component_registry.get_plugin_config(plugin_name)
|
||||||
|
|
||||||
# 创建命令实例
|
|
||||||
command_instance: BaseCommand = command_class(message, plugin_config)
|
|
||||||
command_instance.set_matched_groups(matched_groups)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 执行命令
|
# 调用命令执行器
|
||||||
success, response, intercept_message_level = await command_instance.execute()
|
success, response, intercept_message_level = await command_executor(
|
||||||
|
message=message,
|
||||||
|
plugin_config=plugin_config,
|
||||||
|
matched_groups=matched_groups,
|
||||||
|
)
|
||||||
message.intercept_message_level = intercept_message_level
|
message.intercept_message_level = intercept_message_level
|
||||||
|
|
||||||
# 记录命令执行结果
|
# 记录命令执行结果
|
||||||
if success:
|
if success:
|
||||||
logger.info(f"命令执行成功: {command_class.__name__} (拦截等级: {intercept_message_level})")
|
logger.info(f"命令执行成功: {command_name} (拦截等级: {intercept_message_level})")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"命令执行失败: {command_class.__name__} - {response}")
|
logger.warning(f"命令执行失败: {command_name} - {response}")
|
||||||
|
|
||||||
# 根据命令的拦截设置决定是否继续处理消息
|
# 根据命令的拦截设置决定是否继续处理消息
|
||||||
return (
|
return (
|
||||||
@@ -105,14 +106,9 @@ class ChatBot:
|
|||||||
) # 找到命令,根据intercept_message决定是否继续
|
) # 找到命令,根据intercept_message决定是否继续
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"执行命令时出错: {command_class.__name__} - {e}")
|
logger.error(f"执行命令时出错: {command_name} - {e}")
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
try:
|
|
||||||
await command_instance.send_text(f"命令执行出错: {str(e)}")
|
|
||||||
except Exception as send_error:
|
|
||||||
logger.error(f"发送错误消息失败: {send_error}")
|
|
||||||
|
|
||||||
# 命令出错时,根据命令的拦截设置决定是否继续处理消息
|
# 命令出错时,根据命令的拦截设置决定是否继续处理消息
|
||||||
return True, str(e), False # 出错时继续处理消息
|
return True, str(e), False # 出错时继续处理消息
|
||||||
|
|
||||||
|
|||||||
@@ -318,11 +318,13 @@ class UniversalMessageSender:
|
|||||||
message.build_reply()
|
message.build_reply()
|
||||||
logger.debug(f"[{chat_id}] 选择回复引用消息: {message.processed_plain_text[:20]}...")
|
logger.debug(f"[{chat_id}] 选择回复引用消息: {message.processed_plain_text[:20]}...")
|
||||||
|
|
||||||
from src.plugin_system.core.events_manager import events_manager
|
from src.core.event_bus import event_bus
|
||||||
from src.plugin_system.base.component_types import EventType
|
from src.chat.event_helpers import build_event_message
|
||||||
|
from src.core.types import EventType
|
||||||
|
|
||||||
continue_flag, modified_message = await events_manager.handle_mai_events(
|
_event_msg = build_event_message(EventType.POST_SEND_PRE_PROCESS, message=message, stream_id=chat_id)
|
||||||
EventType.POST_SEND_PRE_PROCESS, message=message, stream_id=chat_id
|
continue_flag, modified_message = await event_bus.emit(
|
||||||
|
EventType.POST_SEND_PRE_PROCESS, _event_msg
|
||||||
)
|
)
|
||||||
if not continue_flag:
|
if not continue_flag:
|
||||||
logger.info(f"[{chat_id}] 消息发送被插件取消: {str(message.message_segment)[:100]}...")
|
logger.info(f"[{chat_id}] 消息发送被插件取消: {str(message.message_segment)[:100]}...")
|
||||||
@@ -336,8 +338,9 @@ class UniversalMessageSender:
|
|||||||
|
|
||||||
await message.process()
|
await message.process()
|
||||||
|
|
||||||
continue_flag, modified_message = await events_manager.handle_mai_events(
|
_event_msg = build_event_message(EventType.POST_SEND, message=message, stream_id=chat_id)
|
||||||
EventType.POST_SEND, message=message, stream_id=chat_id
|
continue_flag, modified_message = await event_bus.emit(
|
||||||
|
EventType.POST_SEND, _event_msg
|
||||||
)
|
)
|
||||||
if not continue_flag:
|
if not continue_flag:
|
||||||
logger.info(f"[{chat_id}] 消息发送被插件取消: {str(message.message_segment)[:100]}...")
|
logger.info(f"[{chat_id}] 消息发送被插件取消: {str(message.message_segment)[:100]}...")
|
||||||
@@ -360,8 +363,9 @@ class UniversalMessageSender:
|
|||||||
if not sent_msg:
|
if not sent_msg:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
continue_flag, modified_message = await events_manager.handle_mai_events(
|
_event_msg = build_event_message(EventType.AFTER_SEND, message=message, stream_id=chat_id)
|
||||||
EventType.AFTER_SEND, message=message, stream_id=chat_id
|
continue_flag, modified_message = await event_bus.emit(
|
||||||
|
EventType.AFTER_SEND, _event_msg
|
||||||
)
|
)
|
||||||
if not continue_flag:
|
if not continue_flag:
|
||||||
logger.info(f"[{chat_id}] 消息发送后续处理被插件取消: {str(message.message_segment)[:100]}...")
|
logger.info(f"[{chat_id}] 消息发送后续处理被插件取消: {str(message.message_segment)[:100]}...")
|
||||||
|
|||||||
@@ -1,20 +1,34 @@
|
|||||||
from typing import Dict, Optional, Type
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
from src.chat.message_receive.chat_manager import BotChatSession
|
from src.chat.message_receive.chat_manager import BotChatSession
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
from src.core.component_registry import component_registry, ActionExecutor
|
||||||
from src.plugin_system.base.component_types import ComponentType, ActionInfo
|
from src.core.types import ActionInfo, ComponentType
|
||||||
from src.plugin_system.base.base_action import BaseAction
|
|
||||||
|
|
||||||
logger = get_logger("action_manager")
|
logger = get_logger("action_manager")
|
||||||
|
|
||||||
|
|
||||||
|
class ActionHandle:
|
||||||
|
"""Action 执行句柄
|
||||||
|
|
||||||
|
不依赖任何插件基类,内部持有 executor (async callable) 和绑定参数。
|
||||||
|
brain_chat 调用 ``await handle.execute()`` 即可。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, executor: ActionExecutor, **kwargs):
|
||||||
|
self._executor = executor
|
||||||
|
self._kwargs = kwargs
|
||||||
|
|
||||||
|
async def execute(self) -> Tuple[bool, str]:
|
||||||
|
return await self._executor(**self._kwargs)
|
||||||
|
|
||||||
|
|
||||||
class ActionManager:
|
class ActionManager:
|
||||||
"""
|
"""
|
||||||
动作管理器,用于管理各种类型的动作
|
动作管理器,用于管理各种类型的动作
|
||||||
|
|
||||||
现在统一使用新插件系统,简化了原有的新旧兼容逻辑。
|
使用核心组件注册表的 executor-based 模式。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -39,9 +53,9 @@ class ActionManager:
|
|||||||
log_prefix: str,
|
log_prefix: str,
|
||||||
shutting_down: bool = False,
|
shutting_down: bool = False,
|
||||||
action_message: Optional[DatabaseMessages] = None,
|
action_message: Optional[DatabaseMessages] = None,
|
||||||
) -> Optional[BaseAction]:
|
) -> Optional[ActionHandle]:
|
||||||
"""
|
"""
|
||||||
创建动作处理器实例
|
创建动作执行句柄
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
action_name: 动作名称
|
action_name: 动作名称
|
||||||
@@ -52,30 +66,26 @@ class ActionManager:
|
|||||||
chat_stream: 聊天流
|
chat_stream: 聊天流
|
||||||
log_prefix: 日志前缀
|
log_prefix: 日志前缀
|
||||||
shutting_down: 是否正在关闭
|
shutting_down: 是否正在关闭
|
||||||
|
action_message: 动作消息记录
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional[BaseAction]: 创建的动作处理器实例,如果动作名称未注册则返回None
|
Optional[ActionHandle]: 执行句柄,如果动作未注册则返回 None
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 获取组件类 - 明确指定查询Action类型
|
executor = component_registry.get_action_executor(action_name)
|
||||||
component_class: Type[BaseAction] = component_registry.get_component_class(
|
if not executor:
|
||||||
action_name, ComponentType.ACTION
|
|
||||||
) # type: ignore
|
|
||||||
if not component_class:
|
|
||||||
logger.warning(f"{log_prefix} 未找到Action组件: {action_name}")
|
logger.warning(f"{log_prefix} 未找到Action组件: {action_name}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 获取组件信息
|
info = component_registry.get_action_info(action_name)
|
||||||
component_info = component_registry.get_component_info(action_name, ComponentType.ACTION)
|
if not info:
|
||||||
if not component_info:
|
|
||||||
logger.warning(f"{log_prefix} 未找到Action组件信息: {action_name}")
|
logger.warning(f"{log_prefix} 未找到Action组件信息: {action_name}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 获取插件配置
|
plugin_config = component_registry.get_plugin_config(info.plugin_name) or {}
|
||||||
plugin_config = component_registry.get_plugin_config(component_info.plugin_name)
|
|
||||||
|
|
||||||
# 创建动作实例
|
handle = ActionHandle(
|
||||||
instance = component_class(
|
executor,
|
||||||
action_data=action_data,
|
action_data=action_data,
|
||||||
action_reasoning=action_reasoning,
|
action_reasoning=action_reasoning,
|
||||||
cycle_timers=cycle_timers,
|
cycle_timers=cycle_timers,
|
||||||
@@ -87,11 +97,11 @@ class ActionManager:
|
|||||||
action_message=action_message,
|
action_message=action_message,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"创建Action实例成功: {action_name}")
|
logger.debug(f"创建Action执行句柄成功: {action_name}")
|
||||||
return instance
|
return handle
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"创建Action实例失败 {action_name}: {e}")
|
logger.error(f"创建Action执行句柄失败 {action_name}: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
|
|||||||
@@ -7,8 +7,8 @@ from src.config.config import global_config
|
|||||||
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
|
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
|
||||||
from src.chat.planner_actions.action_manager import ActionManager
|
from src.chat.planner_actions.action_manager import ActionManager
|
||||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
|
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
|
||||||
from src.plugin_system.base.component_types import ActionInfo, ActionActivationType
|
from src.core.types import ActionActivationType, ActionInfo
|
||||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
from src.core.announcement_manager import global_announcement_manager
|
||||||
|
|
||||||
logger = get_logger("action_manager")
|
logger = get_logger("action_manager")
|
||||||
|
|
||||||
|
|||||||
@@ -23,9 +23,9 @@ from src.chat.utils.chat_message_builder import (
|
|||||||
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
|
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
|
||||||
from src.chat.planner_actions.action_manager import ActionManager
|
from src.chat.planner_actions.action_manager import ActionManager
|
||||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||||
from src.plugin_system.base.component_types import ActionInfo, ComponentType, ActionActivationType
|
from src.core.types import ActionActivationType, ActionInfo, ComponentType
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
from src.core.component_registry import component_registry
|
||||||
from src.plugin_system.apis.message_api import translate_pid_to_description
|
from src.services.message_service import translate_pid_to_description
|
||||||
from src.person_info.person_info import Person
|
from src.person_info.person_info import Person
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|||||||
@@ -27,12 +27,12 @@ from src.chat.utils.chat_message_builder import (
|
|||||||
replace_user_references,
|
replace_user_references,
|
||||||
)
|
)
|
||||||
from src.bw_learner.expression_selector import expression_selector
|
from src.bw_learner.expression_selector import expression_selector
|
||||||
from src.plugin_system.apis.message_api import translate_pid_to_description
|
from src.services.message_service import translate_pid_to_description
|
||||||
|
|
||||||
# from src.memory_system.memory_activator import MemoryActivator
|
# from src.memory_system.memory_activator import MemoryActivator
|
||||||
from src.person_info.person_info import Person
|
from src.person_info.person_info import Person
|
||||||
from src.plugin_system.base.component_types import ActionInfo, EventType
|
from src.core.types import ActionInfo, EventType
|
||||||
from src.plugin_system.apis import llm_api
|
from src.services import llm_service as llm_api
|
||||||
|
|
||||||
from src.chat.logger.plan_reply_logger import PlanReplyLogger
|
from src.chat.logger.plan_reply_logger import PlanReplyLogger
|
||||||
from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt
|
from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt
|
||||||
@@ -56,7 +56,7 @@ class DefaultReplyer:
|
|||||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id)
|
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id)
|
||||||
self.heart_fc_sender = UniversalMessageSender()
|
self.heart_fc_sender = UniversalMessageSender()
|
||||||
|
|
||||||
from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor,不然会循环依赖
|
from src.chat.tool_executor import ToolExecutor
|
||||||
|
|
||||||
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.session_id, enable_cache=True, cache_ttl=3)
|
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.session_id, enable_cache=True, cache_ttl=3)
|
||||||
|
|
||||||
@@ -149,11 +149,13 @@ class DefaultReplyer:
|
|||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("记录reply日志失败")
|
logger.exception("记录reply日志失败")
|
||||||
return False, llm_response
|
return False, llm_response
|
||||||
from src.plugin_system.core.events_manager import events_manager
|
from src.core.event_bus import event_bus
|
||||||
|
from src.chat.event_helpers import build_event_message
|
||||||
|
|
||||||
if not from_plugin:
|
if not from_plugin:
|
||||||
continue_flag, modified_message = await events_manager.handle_mai_events(
|
_event_msg = build_event_message(EventType.POST_LLM, llm_prompt=prompt, stream_id=stream_id)
|
||||||
EventType.POST_LLM, None, prompt, None, stream_id=stream_id
|
continue_flag, modified_message = await event_bus.emit(
|
||||||
|
EventType.POST_LLM, _event_msg
|
||||||
)
|
)
|
||||||
if not continue_flag:
|
if not continue_flag:
|
||||||
raise UserWarning("插件于请求前中断了内容生成")
|
raise UserWarning("插件于请求前中断了内容生成")
|
||||||
@@ -217,8 +219,9 @@ class DefaultReplyer:
|
|||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("记录reply日志失败")
|
logger.exception("记录reply日志失败")
|
||||||
continue_flag, modified_message = await events_manager.handle_mai_events(
|
_event_msg = build_event_message(EventType.AFTER_LLM, llm_prompt=prompt, llm_response=llm_response, stream_id=stream_id)
|
||||||
EventType.AFTER_LLM, None, prompt, llm_response, stream_id=stream_id
|
continue_flag, modified_message = await event_bus.emit(
|
||||||
|
EventType.AFTER_LLM, _event_msg
|
||||||
)
|
)
|
||||||
if not from_plugin and not continue_flag:
|
if not from_plugin and not continue_flag:
|
||||||
raise UserWarning("插件于请求后取消了内容生成")
|
raise UserWarning("插件于请求后取消了内容生成")
|
||||||
|
|||||||
@@ -28,12 +28,12 @@ from src.chat.utils.chat_message_builder import (
|
|||||||
replace_user_references,
|
replace_user_references,
|
||||||
)
|
)
|
||||||
from src.bw_learner.expression_selector import expression_selector
|
from src.bw_learner.expression_selector import expression_selector
|
||||||
from src.plugin_system.apis.message_api import translate_pid_to_description
|
from src.services.message_service import translate_pid_to_description
|
||||||
|
|
||||||
# from src.memory_system.memory_activator import MemoryActivator
|
# from src.memory_system.memory_activator import MemoryActivator
|
||||||
from src.person_info.person_info import Person, is_person_known
|
from src.person_info.person_info import Person, is_person_known
|
||||||
from src.plugin_system.base.component_types import ActionInfo, EventType
|
from src.core.types import ActionInfo, EventType
|
||||||
from src.plugin_system.apis import llm_api
|
from src.services import llm_service as llm_api
|
||||||
from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt
|
from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt
|
||||||
from src.bw_learner.jargon_explainer import explain_jargon_in_context
|
from src.bw_learner.jargon_explainer import explain_jargon_in_context
|
||||||
|
|
||||||
@@ -55,7 +55,7 @@ class PrivateReplyer:
|
|||||||
self.heart_fc_sender = UniversalMessageSender()
|
self.heart_fc_sender = UniversalMessageSender()
|
||||||
# self.memory_activator = MemoryActivator()
|
# self.memory_activator = MemoryActivator()
|
||||||
|
|
||||||
from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor,不然会循环依赖
|
from src.chat.tool_executor import ToolExecutor
|
||||||
|
|
||||||
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.session_id, enable_cache=True, cache_ttl=3)
|
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.session_id, enable_cache=True, cache_ttl=3)
|
||||||
|
|
||||||
@@ -114,11 +114,13 @@ class PrivateReplyer:
|
|||||||
if not prompt:
|
if not prompt:
|
||||||
logger.warning("构建prompt失败,跳过回复生成")
|
logger.warning("构建prompt失败,跳过回复生成")
|
||||||
return False, llm_response
|
return False, llm_response
|
||||||
from src.plugin_system.core.events_manager import events_manager
|
from src.core.event_bus import event_bus
|
||||||
|
from src.chat.event_helpers import build_event_message
|
||||||
|
|
||||||
if not from_plugin:
|
if not from_plugin:
|
||||||
continue_flag, modified_message = await events_manager.handle_mai_events(
|
_event_msg = build_event_message(EventType.POST_LLM, llm_prompt=prompt, stream_id=stream_id)
|
||||||
EventType.POST_LLM, None, prompt, None, stream_id=stream_id
|
continue_flag, modified_message = await event_bus.emit(
|
||||||
|
EventType.POST_LLM, _event_msg
|
||||||
)
|
)
|
||||||
if not continue_flag:
|
if not continue_flag:
|
||||||
raise UserWarning("插件于请求前中断了内容生成")
|
raise UserWarning("插件于请求前中断了内容生成")
|
||||||
@@ -138,8 +140,9 @@ class PrivateReplyer:
|
|||||||
llm_response.reasoning = reasoning_content
|
llm_response.reasoning = reasoning_content
|
||||||
llm_response.model = model_name
|
llm_response.model = model_name
|
||||||
llm_response.tool_calls = tool_call
|
llm_response.tool_calls = tool_call
|
||||||
continue_flag, modified_message = await events_manager.handle_mai_events(
|
_event_msg = build_event_message(EventType.AFTER_LLM, llm_prompt=prompt, llm_response=llm_response, stream_id=stream_id)
|
||||||
EventType.AFTER_LLM, None, prompt, llm_response, stream_id=stream_id
|
continue_flag, modified_message = await event_bus.emit(
|
||||||
|
EventType.AFTER_LLM, _event_msg
|
||||||
)
|
)
|
||||||
if not from_plugin and not continue_flag:
|
if not from_plugin and not continue_flag:
|
||||||
raise UserWarning("插件于请求后取消了内容生成")
|
raise UserWarning("插件于请求后取消了内容生成")
|
||||||
|
|||||||
@@ -1,14 +1,23 @@
|
|||||||
|
"""
|
||||||
|
工具执行器
|
||||||
|
|
||||||
|
独立的工具执行组件,可以直接输入聊天消息内容,
|
||||||
|
自动判断并执行相应的工具,返回结构化的工具执行结果。
|
||||||
|
|
||||||
|
从 src.plugin_system.core.tool_use 迁移,使用新的核心组件注册表。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
import time
|
import time
|
||||||
from typing import List, Dict, Tuple, Optional, Any
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
|
|
||||||
from src.plugin_system.base.base_tool import BaseTool
|
|
||||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
|
||||||
from src.llm_models.utils_model import LLMRequest
|
|
||||||
from src.llm_models.payload_content import ToolCall
|
|
||||||
from src.config.config import global_config, model_config
|
|
||||||
from src.prompt.prompt_manager import prompt_manager
|
|
||||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.config.config import global_config, model_config
|
||||||
|
from src.core.announcement_manager import global_announcement_manager
|
||||||
|
from src.core.component_registry import component_registry
|
||||||
|
from src.llm_models.payload_content import ToolCall
|
||||||
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
from src.prompt.prompt_manager import prompt_manager
|
||||||
|
|
||||||
logger = get_logger("tool_use")
|
logger = get_logger("tool_use")
|
||||||
|
|
||||||
@@ -20,66 +29,38 @@ class ToolExecutor:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, chat_id: str, enable_cache: bool = True, cache_ttl: int = 3):
|
def __init__(self, chat_id: str, enable_cache: bool = True, cache_ttl: int = 3):
|
||||||
"""初始化工具执行器
|
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||||
|
|
||||||
Args:
|
|
||||||
executor_id: 执行器标识符,用于日志记录
|
|
||||||
enable_cache: 是否启用缓存机制
|
|
||||||
cache_ttl: 缓存生存时间(周期数)
|
|
||||||
"""
|
|
||||||
self.chat_id = chat_id
|
self.chat_id = chat_id
|
||||||
self.chat_stream = _chat_manager.get_session_by_session_id(self.chat_id)
|
self.chat_stream = _chat_manager.get_session_by_session_id(self.chat_id)
|
||||||
self.log_prefix = f"[{_chat_manager.get_session_name(self.chat_id) or self.chat_id}]"
|
self.log_prefix = f"[{_chat_manager.get_session_name(self.chat_id) or self.chat_id}]"
|
||||||
|
|
||||||
self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor")
|
self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor")
|
||||||
|
|
||||||
# 缓存配置
|
|
||||||
self.enable_cache = enable_cache
|
self.enable_cache = enable_cache
|
||||||
self.cache_ttl = cache_ttl
|
self.cache_ttl = cache_ttl
|
||||||
self.tool_cache = {} # 格式: {cache_key: {"result": result, "ttl": ttl, "timestamp": timestamp}}
|
self.tool_cache: Dict[str, dict] = {}
|
||||||
|
|
||||||
logger.info(f"{self.log_prefix}工具执行器初始化完成,缓存{'启用' if enable_cache else '禁用'},TTL={cache_ttl}")
|
logger.info(f"{self.log_prefix}工具执行器初始化完成,缓存{'启用' if enable_cache else '禁用'},TTL={cache_ttl}")
|
||||||
|
|
||||||
async def execute_from_chat_message(
|
async def execute_from_chat_message(
|
||||||
self, target_message: str, chat_history: str, sender: str, return_details: bool = False
|
self, target_message: str, chat_history: str, sender: str, return_details: bool = False
|
||||||
) -> Tuple[List[Dict[str, Any]], List[str], str]:
|
) -> Tuple[List[Dict[str, Any]], List[str], str]:
|
||||||
"""从聊天消息执行工具
|
"""从聊天消息执行工具"""
|
||||||
|
|
||||||
Args:
|
|
||||||
target_message: 目标消息内容
|
|
||||||
chat_history: 聊天历史
|
|
||||||
sender: 发送者
|
|
||||||
return_details: 是否返回详细信息(使用的工具列表和提示词)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
如果return_details为False: Tuple[List[Dict], List[str], str] - (工具执行结果列表, 空, 空)
|
|
||||||
如果return_details为True: Tuple[List[Dict], List[str], str] - (结果列表, 使用的工具, 提示词)
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 首先检查缓存
|
|
||||||
cache_key = self._generate_cache_key(target_message, chat_history, sender)
|
cache_key = self._generate_cache_key(target_message, chat_history, sender)
|
||||||
if cached_result := self._get_from_cache(cache_key):
|
if cached_result := self._get_from_cache(cache_key):
|
||||||
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具执行")
|
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具执行")
|
||||||
if not return_details:
|
if not return_details:
|
||||||
return cached_result, [], ""
|
return cached_result, [], ""
|
||||||
|
|
||||||
# 从缓存结果中提取工具名称
|
|
||||||
used_tools = [result.get("tool_name", "unknown") for result in cached_result]
|
used_tools = [result.get("tool_name", "unknown") for result in cached_result]
|
||||||
return cached_result, used_tools, ""
|
return cached_result, used_tools, ""
|
||||||
|
|
||||||
# 缓存未命中,执行工具调用
|
|
||||||
# 获取可用工具
|
|
||||||
tools = self._get_tool_definitions()
|
tools = self._get_tool_definitions()
|
||||||
|
|
||||||
# 如果没有可用工具,直接返回空内容
|
|
||||||
if not tools:
|
if not tools:
|
||||||
logger.debug(f"{self.log_prefix}没有可用工具,直接返回空内容")
|
logger.debug(f"{self.log_prefix}没有可用工具,直接返回空内容")
|
||||||
if return_details:
|
return [], [], ""
|
||||||
return [], [], ""
|
|
||||||
else:
|
|
||||||
return [], [], ""
|
|
||||||
|
|
||||||
# 构建工具调用提示词
|
|
||||||
prompt_template = prompt_manager.get_prompt("tool_executor")
|
prompt_template = prompt_manager.get_prompt("tool_executor")
|
||||||
prompt_template.add_context("target_message", target_message)
|
prompt_template.add_context("target_message", target_message)
|
||||||
prompt_template.add_context("chat_history", chat_history)
|
prompt_template.add_context("chat_history", chat_history)
|
||||||
@@ -90,15 +71,12 @@ class ToolExecutor:
|
|||||||
|
|
||||||
logger.debug(f"{self.log_prefix}开始LLM工具调用分析")
|
logger.debug(f"{self.log_prefix}开始LLM工具调用分析")
|
||||||
|
|
||||||
# 调用LLM进行工具决策
|
|
||||||
response, (reasoning_content, model_name, tool_calls) = await self.llm_model.generate_response_async(
|
response, (reasoning_content, model_name, tool_calls) = await self.llm_model.generate_response_async(
|
||||||
prompt=prompt, tools=tools, raise_when_empty=False
|
prompt=prompt, tools=tools, raise_when_empty=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# 执行工具调用
|
|
||||||
tool_results, used_tools = await self.execute_tool_calls(tool_calls)
|
tool_results, used_tools = await self.execute_tool_calls(tool_calls)
|
||||||
|
|
||||||
# 缓存结果
|
|
||||||
if tool_results:
|
if tool_results:
|
||||||
self._set_cache(cache_key, tool_results)
|
self._set_cache(cache_key, tool_results)
|
||||||
|
|
||||||
@@ -107,42 +85,30 @@ class ToolExecutor:
|
|||||||
|
|
||||||
if return_details:
|
if return_details:
|
||||||
return tool_results, used_tools, prompt
|
return tool_results, used_tools, prompt
|
||||||
else:
|
return tool_results, [], ""
|
||||||
return tool_results, [], ""
|
|
||||||
|
|
||||||
def _get_tool_definitions(self) -> List[Dict[str, Any]]:
|
def _get_tool_definitions(self) -> List[Dict[str, Any]]:
|
||||||
all_tools = get_llm_available_tool_definitions()
|
"""获取 LLM 可用的工具定义列表"""
|
||||||
|
all_tools = component_registry.get_llm_available_tools()
|
||||||
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
|
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
|
||||||
return [definition for name, definition in all_tools if name not in user_disabled_tools]
|
return [info.get_llm_definition() for name, info in all_tools.items() if name not in user_disabled_tools]
|
||||||
|
|
||||||
async def execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]:
|
async def execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]:
|
||||||
"""执行工具调用
|
"""执行工具调用列表"""
|
||||||
|
|
||||||
Args:
|
|
||||||
tool_calls: LLM返回的工具调用列表
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[List[Dict], List[str]]: (工具执行结果列表, 使用的工具名称列表)
|
|
||||||
"""
|
|
||||||
tool_results: List[Dict[str, Any]] = []
|
tool_results: List[Dict[str, Any]] = []
|
||||||
used_tools = []
|
used_tools: List[str] = []
|
||||||
|
|
||||||
if not tool_calls:
|
if not tool_calls:
|
||||||
logger.debug(f"{self.log_prefix}无需执行工具")
|
logger.debug(f"{self.log_prefix}无需执行工具")
|
||||||
return [], []
|
return [], []
|
||||||
|
|
||||||
# 提取tool_calls中的函数名称
|
|
||||||
func_names = [call.func_name for call in tool_calls if call.func_name]
|
func_names = [call.func_name for call in tool_calls if call.func_name]
|
||||||
|
|
||||||
logger.info(f"{self.log_prefix}开始执行工具调用: {func_names}")
|
logger.info(f"{self.log_prefix}开始执行工具调用: {func_names}")
|
||||||
|
|
||||||
# 执行每个工具调用
|
|
||||||
for tool_call in tool_calls:
|
for tool_call in tool_calls:
|
||||||
|
tool_name = tool_call.func_name
|
||||||
try:
|
try:
|
||||||
tool_name = tool_call.func_name
|
|
||||||
logger.debug(f"{self.log_prefix}执行工具: {tool_name}")
|
logger.debug(f"{self.log_prefix}执行工具: {tool_name}")
|
||||||
|
|
||||||
# 执行工具
|
|
||||||
result = await self.execute_tool_call(tool_call)
|
result = await self.execute_tool_call(tool_call)
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
@@ -156,7 +122,6 @@ class ToolExecutor:
|
|||||||
content = tool_info["content"]
|
content = tool_info["content"]
|
||||||
if not isinstance(content, (str, list, tuple)):
|
if not isinstance(content, (str, list, tuple)):
|
||||||
tool_info["content"] = str(content)
|
tool_info["content"] = str(content)
|
||||||
# 空内容直接跳过(空字符串、全空白字符串、空列表/空元组)
|
|
||||||
content_check = tool_info["content"]
|
content_check = tool_info["content"]
|
||||||
if (isinstance(content_check, str) and not content_check.strip()) or (
|
if (isinstance(content_check, str) and not content_check.strip()) or (
|
||||||
isinstance(content_check, (list, tuple)) and len(content_check) == 0
|
isinstance(content_check, (list, tuple)) and len(content_check) == 0
|
||||||
@@ -166,11 +131,10 @@ class ToolExecutor:
|
|||||||
|
|
||||||
tool_results.append(tool_info)
|
tool_results.append(tool_info)
|
||||||
used_tools.append(tool_name)
|
used_tools.append(tool_name)
|
||||||
preview = content[:200]
|
preview = str(content)[:200]
|
||||||
logger.debug(f"{self.log_prefix}工具{tool_name}结果内容: {preview}...")
|
logger.debug(f"{self.log_prefix}工具{tool_name}结果内容: {preview}...")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"{self.log_prefix}工具{tool_name}执行失败: {e}")
|
logger.error(f"{self.log_prefix}工具{tool_name}执行失败: {e}")
|
||||||
# 添加错误信息到结果中
|
|
||||||
error_info = {
|
error_info = {
|
||||||
"type": "tool_error",
|
"type": "tool_error",
|
||||||
"id": f"tool_error_{time.time()}",
|
"id": f"tool_error_{time.time()}",
|
||||||
@@ -182,122 +146,30 @@ class ToolExecutor:
|
|||||||
|
|
||||||
return tool_results, used_tools
|
return tool_results, used_tools
|
||||||
|
|
||||||
async def execute_tool_call(
|
async def execute_tool_call(self, tool_call: ToolCall) -> Optional[Dict[str, Any]]:
|
||||||
self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None
|
"""执行单个工具调用"""
|
||||||
) -> Optional[Dict[str, Any]]:
|
function_name = tool_call.func_name
|
||||||
# sourcery skip: use-assigned-variable
|
function_args = tool_call.args or {}
|
||||||
"""执行单个工具调用
|
function_args["llm_called"] = True
|
||||||
|
|
||||||
Args:
|
executor = component_registry.get_tool_executor(function_name)
|
||||||
tool_call: 工具调用对象
|
if not executor:
|
||||||
|
logger.warning(f"未知工具名称: {function_name}")
|
||||||
Returns:
|
|
||||||
Optional[Dict]: 工具调用结果,如果失败则返回None
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
function_name = tool_call.func_name
|
|
||||||
function_args = tool_call.args or {}
|
|
||||||
function_args["llm_called"] = True # 标记为LLM调用
|
|
||||||
|
|
||||||
# 获取对应工具实例
|
|
||||||
tool_instance = tool_instance or get_tool_instance(function_name, self.chat_stream)
|
|
||||||
if not tool_instance:
|
|
||||||
logger.warning(f"未知工具名称: {function_name}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 执行工具
|
|
||||||
result = await tool_instance.execute(function_args)
|
|
||||||
if result:
|
|
||||||
return {
|
|
||||||
"tool_call_id": tool_call.call_id,
|
|
||||||
"role": "tool",
|
|
||||||
"name": function_name,
|
|
||||||
"type": "function",
|
|
||||||
"content": result["content"],
|
|
||||||
}
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"执行工具调用时发生错误: {str(e)}")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str:
|
|
||||||
"""生成缓存键
|
|
||||||
|
|
||||||
Args:
|
|
||||||
target_message: 目标消息内容
|
|
||||||
chat_history: 聊天历史
|
|
||||||
sender: 发送者
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: 缓存键
|
|
||||||
"""
|
|
||||||
import hashlib
|
|
||||||
|
|
||||||
# 使用消息内容和群聊状态生成唯一缓存键
|
|
||||||
content = f"{target_message}_{chat_history}_{sender}"
|
|
||||||
return hashlib.md5(content.encode()).hexdigest()
|
|
||||||
|
|
||||||
def _get_from_cache(self, cache_key: str) -> Optional[List[Dict]]:
|
|
||||||
"""从缓存获取结果
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cache_key: 缓存键
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[List[Dict]]: 缓存的结果,如果不存在或过期则返回None
|
|
||||||
"""
|
|
||||||
if not self.enable_cache or cache_key not in self.tool_cache:
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
cache_item = self.tool_cache[cache_key]
|
result = await executor(function_args)
|
||||||
if cache_item["ttl"] <= 0:
|
if result:
|
||||||
# 缓存过期,删除
|
return {
|
||||||
del self.tool_cache[cache_key]
|
"tool_call_id": tool_call.call_id,
|
||||||
logger.debug(f"{self.log_prefix}缓存过期,删除缓存键: {cache_key}")
|
"role": "tool",
|
||||||
return None
|
"name": function_name,
|
||||||
|
"type": "function",
|
||||||
# 减少TTL
|
"content": result["content"],
|
||||||
cache_item["ttl"] -= 1
|
}
|
||||||
logger.debug(f"{self.log_prefix}使用缓存结果,剩余TTL: {cache_item['ttl']}")
|
return None
|
||||||
return cache_item["result"]
|
|
||||||
|
|
||||||
def _set_cache(self, cache_key: str, result: List[Dict]):
|
|
||||||
"""设置缓存
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cache_key: 缓存键
|
|
||||||
result: 要缓存的结果
|
|
||||||
"""
|
|
||||||
if not self.enable_cache:
|
|
||||||
return
|
|
||||||
|
|
||||||
self.tool_cache[cache_key] = {"result": result, "ttl": self.cache_ttl, "timestamp": time.time()}
|
|
||||||
logger.debug(f"{self.log_prefix}设置缓存,TTL: {self.cache_ttl}")
|
|
||||||
|
|
||||||
def _cleanup_expired_cache(self):
|
|
||||||
"""清理过期的缓存"""
|
|
||||||
if not self.enable_cache:
|
|
||||||
return
|
|
||||||
|
|
||||||
expired_keys = []
|
|
||||||
expired_keys.extend(cache_key for cache_key, cache_item in self.tool_cache.items() if cache_item["ttl"] <= 0)
|
|
||||||
for key in expired_keys:
|
|
||||||
del self.tool_cache[key]
|
|
||||||
|
|
||||||
if expired_keys:
|
|
||||||
logger.debug(f"{self.log_prefix}清理了{len(expired_keys)}个过期缓存")
|
|
||||||
|
|
||||||
async def execute_specific_tool_simple(self, tool_name: str, tool_args: Dict) -> Optional[Dict]:
|
async def execute_specific_tool_simple(self, tool_name: str, tool_args: Dict) -> Optional[Dict]:
|
||||||
"""直接执行指定工具
|
"""直接执行指定工具"""
|
||||||
|
|
||||||
Args:
|
|
||||||
tool_name: 工具名称
|
|
||||||
tool_args: 工具参数
|
|
||||||
validate_args: 是否验证参数
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[Dict]: 工具执行结果,失败时返回None
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
tool_call = ToolCall(
|
tool_call = ToolCall(
|
||||||
call_id=f"direct_tool_{time.time()}",
|
call_id=f"direct_tool_{time.time()}",
|
||||||
@@ -306,7 +178,6 @@ class ToolExecutor:
|
|||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"{self.log_prefix}直接执行工具: {tool_name}")
|
logger.info(f"{self.log_prefix}直接执行工具: {tool_name}")
|
||||||
|
|
||||||
result = await self.execute_tool_call(tool_call)
|
result = await self.execute_tool_call(tool_call)
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
@@ -325,86 +196,55 @@ class ToolExecutor:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# === 缓存方法 ===
|
||||||
|
|
||||||
|
def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str:
|
||||||
|
content = f"{target_message}_{chat_history}_{sender}"
|
||||||
|
return hashlib.md5(content.encode()).hexdigest()
|
||||||
|
|
||||||
|
def _get_from_cache(self, cache_key: str) -> Optional[List[Dict]]:
|
||||||
|
if not self.enable_cache or cache_key not in self.tool_cache:
|
||||||
|
return None
|
||||||
|
cache_item = self.tool_cache[cache_key]
|
||||||
|
if cache_item["ttl"] <= 0:
|
||||||
|
del self.tool_cache[cache_key]
|
||||||
|
return None
|
||||||
|
cache_item["ttl"] -= 1
|
||||||
|
return cache_item["result"]
|
||||||
|
|
||||||
|
def _set_cache(self, cache_key: str, result: List[Dict]):
|
||||||
|
if not self.enable_cache:
|
||||||
|
return
|
||||||
|
self.tool_cache[cache_key] = {"result": result, "ttl": self.cache_ttl, "timestamp": time.time()}
|
||||||
|
|
||||||
|
def _cleanup_expired_cache(self):
|
||||||
|
if not self.enable_cache:
|
||||||
|
return
|
||||||
|
expired = [k for k, v in self.tool_cache.items() if v["ttl"] <= 0]
|
||||||
|
for key in expired:
|
||||||
|
del self.tool_cache[key]
|
||||||
|
|
||||||
def clear_cache(self):
|
def clear_cache(self):
|
||||||
"""清空所有缓存"""
|
|
||||||
if self.enable_cache:
|
if self.enable_cache:
|
||||||
cache_count = len(self.tool_cache)
|
|
||||||
self.tool_cache.clear()
|
self.tool_cache.clear()
|
||||||
logger.info(f"{self.log_prefix}清空了{cache_count}个缓存项")
|
|
||||||
|
|
||||||
def get_cache_status(self) -> Dict:
|
def get_cache_status(self) -> Dict:
|
||||||
"""获取缓存状态信息
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict: 包含缓存统计信息的字典
|
|
||||||
"""
|
|
||||||
if not self.enable_cache:
|
if not self.enable_cache:
|
||||||
return {"enabled": False, "cache_count": 0}
|
return {"enabled": False, "cache_count": 0}
|
||||||
|
|
||||||
# 清理过期缓存
|
|
||||||
self._cleanup_expired_cache()
|
self._cleanup_expired_cache()
|
||||||
|
ttl_distribution: Dict[int, int] = {}
|
||||||
total_count = len(self.tool_cache)
|
for item in self.tool_cache.values():
|
||||||
ttl_distribution = {}
|
ttl = item["ttl"]
|
||||||
|
|
||||||
for cache_item in self.tool_cache.values():
|
|
||||||
ttl = cache_item["ttl"]
|
|
||||||
ttl_distribution[ttl] = ttl_distribution.get(ttl, 0) + 1
|
ttl_distribution[ttl] = ttl_distribution.get(ttl, 0) + 1
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"enabled": True,
|
"enabled": True,
|
||||||
"cache_count": total_count,
|
"cache_count": len(self.tool_cache),
|
||||||
"cache_ttl": self.cache_ttl,
|
"cache_ttl": self.cache_ttl,
|
||||||
"ttl_distribution": ttl_distribution,
|
"ttl_distribution": ttl_distribution,
|
||||||
}
|
}
|
||||||
|
|
||||||
def set_cache_config(self, enable_cache: Optional[bool] = None, cache_ttl: int = -1):
|
def set_cache_config(self, enable_cache: Optional[bool] = None, cache_ttl: int = -1):
|
||||||
"""动态修改缓存配置
|
|
||||||
|
|
||||||
Args:
|
|
||||||
enable_cache: 是否启用缓存
|
|
||||||
cache_ttl: 缓存TTL
|
|
||||||
"""
|
|
||||||
if enable_cache is not None:
|
if enable_cache is not None:
|
||||||
self.enable_cache = enable_cache
|
self.enable_cache = enable_cache
|
||||||
logger.info(f"{self.log_prefix}缓存状态修改为: {'启用' if enable_cache else '禁用'}")
|
|
||||||
|
|
||||||
if cache_ttl > 0:
|
if cache_ttl > 0:
|
||||||
self.cache_ttl = cache_ttl
|
self.cache_ttl = cache_ttl
|
||||||
logger.info(f"{self.log_prefix}缓存TTL修改为: {cache_ttl}")
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
ToolExecutor使用示例:
|
|
||||||
|
|
||||||
# 1. 基础使用 - 从聊天消息执行工具(启用缓存,默认TTL=3)
|
|
||||||
executor = ToolExecutor(executor_id="my_executor")
|
|
||||||
results, _, _ = await executor.execute_from_chat_message(
|
|
||||||
talking_message_str="今天天气怎么样?现在几点了?",
|
|
||||||
is_group_chat=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. 禁用缓存的执行器
|
|
||||||
no_cache_executor = ToolExecutor(executor_id="no_cache", enable_cache=False)
|
|
||||||
|
|
||||||
# 3. 自定义缓存TTL
|
|
||||||
long_cache_executor = ToolExecutor(executor_id="long_cache", cache_ttl=10)
|
|
||||||
|
|
||||||
# 4. 获取详细信息
|
|
||||||
results, used_tools, prompt = await executor.execute_from_chat_message(
|
|
||||||
talking_message_str="帮我查询Python相关知识",
|
|
||||||
is_group_chat=False,
|
|
||||||
return_details=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# 5. 直接执行特定工具
|
|
||||||
result = await executor.execute_specific_tool_simple(
|
|
||||||
tool_name="get_knowledge",
|
|
||||||
tool_args={"query": "机器学习"}
|
|
||||||
)
|
|
||||||
|
|
||||||
# 6. 缓存管理
|
|
||||||
cache_status = executor.get_cache_status() # 查看缓存状态
|
|
||||||
executor.clear_cache() # 清空缓存
|
|
||||||
executor.set_cache_config(cache_ttl=5) # 动态修改缓存配置
|
|
||||||
"""
|
|
||||||
@@ -4,7 +4,7 @@ from . import BaseDataModel
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .database_data_model import DatabaseMessages
|
from .database_data_model import DatabaseMessages
|
||||||
from src.plugin_system.base.component_types import ActionInfo
|
from src.core.types import ActionInfo
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
6
src/core/__init__.py
Normal file
6
src/core/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
"""
|
||||||
|
MaiBot 核心基础设施
|
||||||
|
|
||||||
|
提供与插件系统无关的核心类型定义、配置 schema 等基础设施。
|
||||||
|
这些类型被整个项目共享,包括内部模块、服务层、旧插件系统和新插件运行时。
|
||||||
|
"""
|
||||||
242
src/core/component_registry.py
Normal file
242
src/core/component_registry.py
Normal file
@@ -0,0 +1,242 @@
|
|||||||
|
"""
|
||||||
|
核心组件注册表
|
||||||
|
|
||||||
|
面向最终架构的组件管理:
|
||||||
|
- Action:注册 ActionInfo + 执行器(本地 callable 或 IPC 路由)
|
||||||
|
- Command:注册正则模式 + 执行器
|
||||||
|
- Tool:注册工具定义 + 执行器
|
||||||
|
|
||||||
|
不依赖任何插件基类,组件执行器是纯 async callable。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import Any, Awaitable, Callable, Dict, List, Optional, Pattern, Tuple, Union
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.core.types import (
|
||||||
|
ActionActivationType,
|
||||||
|
ActionInfo,
|
||||||
|
CommandInfo,
|
||||||
|
ComponentInfo,
|
||||||
|
ComponentType,
|
||||||
|
ToolInfo,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = get_logger("component_registry")
|
||||||
|
|
||||||
|
# 执行器类型
|
||||||
|
ActionExecutor = Callable[..., Awaitable[Any]]
|
||||||
|
CommandExecutor = Callable[..., Awaitable[Tuple[bool, Optional[str], bool]]]
|
||||||
|
ToolExecutor = Callable[..., Awaitable[Any]]
|
||||||
|
|
||||||
|
|
||||||
|
class ComponentRegistry:
|
||||||
|
"""核心组件注册表
|
||||||
|
|
||||||
|
管理 action、command、tool 三类组件。
|
||||||
|
每个组件由「元信息 + 执行器」构成,执行器是 async callable,
|
||||||
|
不需要继承任何基类。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# Action 注册
|
||||||
|
self._actions: Dict[str, ActionInfo] = {}
|
||||||
|
self._action_executors: Dict[str, ActionExecutor] = {}
|
||||||
|
self._default_actions: Dict[str, ActionInfo] = {}
|
||||||
|
|
||||||
|
# Command 注册
|
||||||
|
self._commands: Dict[str, CommandInfo] = {}
|
||||||
|
self._command_executors: Dict[str, CommandExecutor] = {}
|
||||||
|
self._command_patterns: Dict[Pattern, str] = {}
|
||||||
|
|
||||||
|
# Tool 注册
|
||||||
|
self._tools: Dict[str, ToolInfo] = {}
|
||||||
|
self._tool_executors: Dict[str, ToolExecutor] = {}
|
||||||
|
self._llm_available_tools: Dict[str, ToolInfo] = {}
|
||||||
|
|
||||||
|
# 插件配置(plugin_name -> config dict)
|
||||||
|
self._plugin_configs: Dict[str, dict] = {}
|
||||||
|
|
||||||
|
logger.info("核心组件注册表初始化完成")
|
||||||
|
|
||||||
|
# ========== Action ==========
|
||||||
|
|
||||||
|
def register_action(
|
||||||
|
self,
|
||||||
|
info: ActionInfo,
|
||||||
|
executor: ActionExecutor,
|
||||||
|
) -> bool:
|
||||||
|
"""注册 action
|
||||||
|
|
||||||
|
Args:
|
||||||
|
info: action 元信息
|
||||||
|
executor: 执行器,async callable
|
||||||
|
"""
|
||||||
|
name = info.name
|
||||||
|
if name in self._actions:
|
||||||
|
logger.warning(f"Action {name} 已存在,跳过注册")
|
||||||
|
return False
|
||||||
|
|
||||||
|
self._actions[name] = info
|
||||||
|
self._action_executors[name] = executor
|
||||||
|
|
||||||
|
if info.enabled:
|
||||||
|
self._default_actions[name] = info
|
||||||
|
|
||||||
|
logger.debug(f"注册 Action: {name}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_action_info(self, name: str) -> Optional[ActionInfo]:
|
||||||
|
return self._actions.get(name)
|
||||||
|
|
||||||
|
def get_action_executor(self, name: str) -> Optional[ActionExecutor]:
|
||||||
|
return self._action_executors.get(name)
|
||||||
|
|
||||||
|
def get_default_actions(self) -> Dict[str, ActionInfo]:
|
||||||
|
return self._default_actions.copy()
|
||||||
|
|
||||||
|
def get_all_actions(self) -> Dict[str, ActionInfo]:
|
||||||
|
return self._actions.copy()
|
||||||
|
|
||||||
|
def remove_action(self, name: str) -> bool:
|
||||||
|
if name not in self._actions:
|
||||||
|
return False
|
||||||
|
del self._actions[name]
|
||||||
|
self._action_executors.pop(name, None)
|
||||||
|
self._default_actions.pop(name, None)
|
||||||
|
logger.debug(f"移除 Action: {name}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
# ========== Command ==========
|
||||||
|
|
||||||
|
def register_command(
|
||||||
|
self,
|
||||||
|
info: CommandInfo,
|
||||||
|
executor: CommandExecutor,
|
||||||
|
) -> bool:
|
||||||
|
"""注册 command"""
|
||||||
|
name = info.name
|
||||||
|
if name in self._commands:
|
||||||
|
logger.warning(f"Command {name} 已存在,跳过注册")
|
||||||
|
return False
|
||||||
|
|
||||||
|
self._commands[name] = info
|
||||||
|
self._command_executors[name] = executor
|
||||||
|
|
||||||
|
if info.enabled and info.command_pattern:
|
||||||
|
pattern = re.compile(info.command_pattern, re.IGNORECASE | re.DOTALL)
|
||||||
|
self._command_patterns[pattern] = name
|
||||||
|
|
||||||
|
logger.debug(f"注册 Command: {name}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def find_command_by_text(
|
||||||
|
self, text: str
|
||||||
|
) -> Optional[Tuple[CommandExecutor, dict, CommandInfo]]:
|
||||||
|
"""根据文本查找匹配的命令
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(executor, matched_groups, command_info) 或 None
|
||||||
|
"""
|
||||||
|
candidates = [p for p in self._command_patterns if p.match(text)]
|
||||||
|
if not candidates:
|
||||||
|
return None
|
||||||
|
if len(candidates) > 1:
|
||||||
|
logger.warning(f"文本 '{text[:50]}' 匹配到多个命令模式,使用第一个")
|
||||||
|
pattern = candidates[0]
|
||||||
|
name = self._command_patterns[pattern]
|
||||||
|
return (
|
||||||
|
self._command_executors[name],
|
||||||
|
pattern.match(text).groupdict(), # type: ignore
|
||||||
|
self._commands[name],
|
||||||
|
)
|
||||||
|
|
||||||
|
def remove_command(self, name: str) -> bool:
|
||||||
|
if name not in self._commands:
|
||||||
|
return False
|
||||||
|
del self._commands[name]
|
||||||
|
self._command_executors.pop(name, None)
|
||||||
|
self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != name}
|
||||||
|
logger.debug(f"移除 Command: {name}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
# ========== Tool ==========
|
||||||
|
|
||||||
|
def register_tool(
|
||||||
|
self,
|
||||||
|
info: ToolInfo,
|
||||||
|
executor: ToolExecutor,
|
||||||
|
) -> bool:
|
||||||
|
"""注册 tool"""
|
||||||
|
name = info.name
|
||||||
|
if name in self._tools:
|
||||||
|
logger.warning(f"Tool {name} 已存在,跳过注册")
|
||||||
|
return False
|
||||||
|
|
||||||
|
self._tools[name] = info
|
||||||
|
self._tool_executors[name] = executor
|
||||||
|
|
||||||
|
if info.enabled:
|
||||||
|
self._llm_available_tools[name] = info
|
||||||
|
|
||||||
|
logger.debug(f"注册 Tool: {name}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_tool_info(self, name: str) -> Optional[ToolInfo]:
|
||||||
|
return self._tools.get(name)
|
||||||
|
|
||||||
|
def get_tool_executor(self, name: str) -> Optional[ToolExecutor]:
|
||||||
|
return self._tool_executors.get(name)
|
||||||
|
|
||||||
|
def get_llm_available_tools(self) -> Dict[str, ToolInfo]:
|
||||||
|
return self._llm_available_tools.copy()
|
||||||
|
|
||||||
|
def get_all_tools(self) -> Dict[str, ToolInfo]:
|
||||||
|
return self._tools.copy()
|
||||||
|
|
||||||
|
def remove_tool(self, name: str) -> bool:
|
||||||
|
if name not in self._tools:
|
||||||
|
return False
|
||||||
|
del self._tools[name]
|
||||||
|
self._tool_executors.pop(name, None)
|
||||||
|
self._llm_available_tools.pop(name, None)
|
||||||
|
logger.debug(f"移除 Tool: {name}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
# ========== 通用查询 ==========
|
||||||
|
|
||||||
|
def get_component_info(self, name: str, component_type: ComponentType) -> Optional[ComponentInfo]:
|
||||||
|
"""获取组件元信息"""
|
||||||
|
match component_type:
|
||||||
|
case ComponentType.ACTION:
|
||||||
|
return self._actions.get(name)
|
||||||
|
case ComponentType.COMMAND:
|
||||||
|
return self._commands.get(name)
|
||||||
|
case ComponentType.TOOL:
|
||||||
|
return self._tools.get(name)
|
||||||
|
case _:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
|
||||||
|
"""获取某类型的所有组件"""
|
||||||
|
match component_type:
|
||||||
|
case ComponentType.ACTION:
|
||||||
|
return dict(self._actions)
|
||||||
|
case ComponentType.COMMAND:
|
||||||
|
return dict(self._commands)
|
||||||
|
case ComponentType.TOOL:
|
||||||
|
return dict(self._tools)
|
||||||
|
case _:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# ========== 插件配置 ==========
|
||||||
|
|
||||||
|
def set_plugin_config(self, plugin_name: str, config: dict) -> None:
|
||||||
|
self._plugin_configs[plugin_name] = config
|
||||||
|
|
||||||
|
def get_plugin_config(self, plugin_name: str) -> Optional[dict]:
|
||||||
|
return self._plugin_configs.get(plugin_name)
|
||||||
|
|
||||||
|
|
||||||
|
# 全局单例
|
||||||
|
component_registry = ComponentRegistry()
|
||||||
216
src/core/event_bus.py
Normal file
216
src/core/event_bus.py
Normal file
@@ -0,0 +1,216 @@
|
|||||||
|
"""
|
||||||
|
核心事件总线
|
||||||
|
|
||||||
|
面向最终架构的事件系统:
|
||||||
|
- 内部 handler 直接注册 async callable
|
||||||
|
- IPC 插件通过 plugin_runtime 桥接
|
||||||
|
- 不依赖任何插件基类
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.core.types import EventType, MaiMessages
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.common.data_models.llm_data_model import LLMGenerationDataModel
|
||||||
|
|
||||||
|
logger = get_logger("event_bus")
|
||||||
|
|
||||||
|
# Handler 签名:接收 MaiMessages,返回 (continue, modified_message)
|
||||||
|
EventHandler = Callable[[Optional[MaiMessages]], Awaitable[Tuple[bool, Optional[MaiMessages]]]]
|
||||||
|
|
||||||
|
|
||||||
|
class EventBus:
|
||||||
|
"""核心事件总线
|
||||||
|
|
||||||
|
支持两种 handler:
|
||||||
|
- 拦截型(intercept=True):同步顺序执行,可修改消息、可中断流程
|
||||||
|
- 非拦截型(intercept=False):异步并发执行,fire-and-forget
|
||||||
|
|
||||||
|
handler 是纯 async callable,不需要继承任何基类。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# event_type -> [(handler, name, weight, intercept)]
|
||||||
|
self._handlers: Dict[EventType | str, List[_HandlerEntry]] = {}
|
||||||
|
self._running_tasks: Dict[str, List[asyncio.Task]] = {}
|
||||||
|
|
||||||
|
# 预注册所有内置事件类型
|
||||||
|
for event in EventType:
|
||||||
|
self._handlers[event] = []
|
||||||
|
|
||||||
|
def subscribe(
|
||||||
|
self,
|
||||||
|
event_type: EventType | str,
|
||||||
|
handler: EventHandler,
|
||||||
|
name: str,
|
||||||
|
weight: int = 0,
|
||||||
|
intercept: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""注册事件 handler
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_type: 事件类型
|
||||||
|
handler: async callable,签名 (Optional[MaiMessages]) -> (bool, Optional[MaiMessages])
|
||||||
|
name: handler 标识名
|
||||||
|
weight: 权重,越大越先执行
|
||||||
|
intercept: 是否为拦截型(同步执行,可中断流程)
|
||||||
|
"""
|
||||||
|
if event_type not in self._handlers:
|
||||||
|
self._handlers[event_type] = []
|
||||||
|
|
||||||
|
entry = _HandlerEntry(handler=handler, name=name, weight=weight, intercept=intercept)
|
||||||
|
self._handlers[event_type].append(entry)
|
||||||
|
self._handlers[event_type].sort(key=lambda e: e.weight, reverse=True)
|
||||||
|
logger.debug(f"注册事件 handler: {name} -> {event_type} (weight={weight}, intercept={intercept})")
|
||||||
|
|
||||||
|
def unsubscribe(self, event_type: EventType | str, name: str) -> bool:
|
||||||
|
"""取消注册事件 handler"""
|
||||||
|
handlers = self._handlers.get(event_type, [])
|
||||||
|
for i, entry in enumerate(handlers):
|
||||||
|
if entry.name == name:
|
||||||
|
del handlers[i]
|
||||||
|
logger.debug(f"取消注册事件 handler: {name} <- {event_type}")
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def emit(
|
||||||
|
self,
|
||||||
|
event_type: EventType | str,
|
||||||
|
message: Optional[MaiMessages] = None,
|
||||||
|
) -> Tuple[bool, Optional[MaiMessages]]:
|
||||||
|
"""触发事件
|
||||||
|
|
||||||
|
按权重顺序执行所有 handler:
|
||||||
|
- 拦截型 handler 同步执行,可修改消息和中断流程
|
||||||
|
- 非拦截型 handler 异步 fire-and-forget
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_type: 事件类型
|
||||||
|
message: 事件消息(可选)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(continue_flag, modified_message)
|
||||||
|
- continue_flag: False 表示某个拦截型 handler 要求中断
|
||||||
|
- modified_message: 被拦截型 handler 修改后的消息
|
||||||
|
"""
|
||||||
|
handlers = self._handlers.get(event_type, [])
|
||||||
|
if not handlers:
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
continue_flag = True
|
||||||
|
current_message = message.deepcopy() if message else None
|
||||||
|
|
||||||
|
for entry in handlers:
|
||||||
|
if entry.intercept:
|
||||||
|
try:
|
||||||
|
should_continue, modified = await entry.handler(current_message)
|
||||||
|
if modified is not None:
|
||||||
|
current_message = modified
|
||||||
|
if not should_continue:
|
||||||
|
continue_flag = False
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"拦截型 handler {entry.name} 执行异常: {e}", exc_info=True)
|
||||||
|
else:
|
||||||
|
self._fire_and_forget(entry, event_type, current_message)
|
||||||
|
|
||||||
|
# 桥接到 IPC 插件运行时
|
||||||
|
continue_flag, current_message = await self._bridge_to_ipc_runtime(
|
||||||
|
event_type, continue_flag, current_message
|
||||||
|
)
|
||||||
|
|
||||||
|
return continue_flag, current_message
|
||||||
|
|
||||||
|
async def cancel_handler_tasks(self, handler_name: str) -> None:
|
||||||
|
"""取消某个 handler 的所有运行中任务"""
|
||||||
|
tasks = self._running_tasks.pop(handler_name, [])
|
||||||
|
remaining = [t for t in tasks if not t.done()]
|
||||||
|
if remaining:
|
||||||
|
for t in remaining:
|
||||||
|
t.cancel()
|
||||||
|
await asyncio.gather(*remaining, return_exceptions=True)
|
||||||
|
logger.info(f"已取消 handler {handler_name} 的 {len(remaining)} 个任务")
|
||||||
|
|
||||||
|
# --- 内部方法 ---
|
||||||
|
|
||||||
|
def _fire_and_forget(
|
||||||
|
self,
|
||||||
|
entry: "_HandlerEntry",
|
||||||
|
event_type: EventType | str,
|
||||||
|
message: Optional[MaiMessages],
|
||||||
|
) -> None:
|
||||||
|
"""创建异步任务执行非拦截型 handler"""
|
||||||
|
try:
|
||||||
|
task = asyncio.create_task(entry.handler(message))
|
||||||
|
task.set_name(entry.name)
|
||||||
|
task.add_done_callback(lambda t: self._task_done_callback(t, entry.name))
|
||||||
|
self._running_tasks.setdefault(entry.name, []).append(task)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"创建 handler 任务 {entry.name} 失败: {e}", exc_info=True)
|
||||||
|
|
||||||
|
def _task_done_callback(self, task: asyncio.Task, handler_name: str) -> None:
|
||||||
|
"""异步任务完成回调"""
|
||||||
|
try:
|
||||||
|
if task.cancelled():
|
||||||
|
return
|
||||||
|
exc = task.exception()
|
||||||
|
if exc:
|
||||||
|
logger.error(f"handler {handler_name} 异步任务异常: {exc}")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
task_list = self._running_tasks.get(handler_name, [])
|
||||||
|
try:
|
||||||
|
task_list.remove(task)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _bridge_to_ipc_runtime(
|
||||||
|
self,
|
||||||
|
event_type: EventType | str,
|
||||||
|
continue_flag: bool,
|
||||||
|
message: Optional[MaiMessages],
|
||||||
|
) -> Tuple[bool, Optional[MaiMessages]]:
|
||||||
|
"""将事件桥接到 IPC 插件运行时"""
|
||||||
|
if not continue_flag:
|
||||||
|
return continue_flag, message
|
||||||
|
|
||||||
|
try:
|
||||||
|
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||||
|
|
||||||
|
prm = get_plugin_runtime_manager()
|
||||||
|
if not prm.is_running:
|
||||||
|
return continue_flag, message
|
||||||
|
|
||||||
|
event_value = event_type.value if isinstance(event_type, EventType) else str(event_type)
|
||||||
|
message_dict = message.to_dict() if message and hasattr(message, "to_dict") else None
|
||||||
|
|
||||||
|
new_continue, _ = await prm.bridge_event(
|
||||||
|
event_type_value=event_value,
|
||||||
|
message_dict=message_dict,
|
||||||
|
)
|
||||||
|
if not new_continue:
|
||||||
|
continue_flag = False
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"桥接事件到 IPC 运行时失败: {e}")
|
||||||
|
|
||||||
|
return continue_flag, message
|
||||||
|
|
||||||
|
|
||||||
|
class _HandlerEntry:
|
||||||
|
"""内部 handler 条目"""
|
||||||
|
|
||||||
|
__slots__ = ("handler", "name", "weight", "intercept")
|
||||||
|
|
||||||
|
def __init__(self, handler: EventHandler, name: str, weight: int, intercept: bool):
|
||||||
|
self.handler = handler
|
||||||
|
self.name = name
|
||||||
|
self.weight = weight
|
||||||
|
self.intercept = intercept
|
||||||
|
|
||||||
|
|
||||||
|
# 全局单例
|
||||||
|
event_bus = EventBus()
|
||||||
@@ -169,6 +169,14 @@ class ToolInfo(ComponentInfo):
|
|||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
self.component_type = ComponentType.TOOL
|
self.component_type = ComponentType.TOOL
|
||||||
|
|
||||||
|
def get_llm_definition(self) -> dict:
|
||||||
|
"""生成 LLM function-calling 所需的工具定义"""
|
||||||
|
return {
|
||||||
|
"name": self.name,
|
||||||
|
"description": self.tool_description,
|
||||||
|
"parameters": self.tool_parameters,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EventHandlerInfo(ComponentInfo):
|
class EventHandlerInfo(ComponentInfo):
|
||||||
@@ -10,7 +10,7 @@ from src.config.config import global_config, model_config
|
|||||||
from src.common.database.database_model import ChatHistory
|
from src.common.database.database_model import ChatHistory
|
||||||
from src.prompt.prompt_manager import prompt_manager
|
from src.prompt.prompt_manager import prompt_manager
|
||||||
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
|
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
|
||||||
from src.plugin_system.apis import llm_api
|
from src.services import llm_service as llm_api
|
||||||
from src.dream.dream_generator import generate_dream_summary
|
from src.dream.dream_generator import generate_dream_summary
|
||||||
|
|
||||||
# dream 工具工厂函数
|
# dream 工具工厂函数
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from src.llm_models.payload_content.message import RoleType, Message
|
|||||||
from src.prompt.prompt_manager import prompt_manager
|
from src.prompt.prompt_manager import prompt_manager
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.common.utils.utils_session import SessionUtils
|
from src.common.utils.utils_session import SessionUtils
|
||||||
from src.plugin_system.apis import send_api
|
from src.services import send_service as send_api
|
||||||
|
|
||||||
logger = get_logger("dream_generator")
|
logger = get_logger("dream_generator")
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from typing import Any, Dict, Optional
|
|||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.database.database_model import ChatHistory
|
from src.common.database.database_model import ChatHistory
|
||||||
from src.plugin_system.apis import database_api
|
from src.services import database_service as database_api
|
||||||
|
|
||||||
logger = get_logger("dream_agent")
|
logger = get_logger("dream_agent")
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from typing import Any, Dict, Optional
|
|||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.database.database_model import Jargon
|
from src.common.database.database_model import Jargon
|
||||||
from src.plugin_system.apis import database_api
|
from src.services import database_service as database_api
|
||||||
|
|
||||||
logger = get_logger("dream_agent")
|
logger = get_logger("dream_agent")
|
||||||
|
|
||||||
|
|||||||
20
src/main.py
20
src/main.py
@@ -18,10 +18,7 @@ from rich.traceback import install
|
|||||||
|
|
||||||
# from src.api.main import start_api_server
|
# from src.api.main import start_api_server
|
||||||
|
|
||||||
# 导入新的插件管理器
|
# 导入插件运行时
|
||||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
|
||||||
|
|
||||||
# 导入新版本插件运行时
|
|
||||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||||
|
|
||||||
# 导入消息API和traceback模块
|
# 导入消息API和traceback模块
|
||||||
@@ -31,8 +28,6 @@ from src.bw_learner.expression_auto_check_task import ExpressionAutoCheckTask
|
|||||||
|
|
||||||
from src.prompt.prompt_manager import prompt_manager
|
from src.prompt.prompt_manager import prompt_manager
|
||||||
|
|
||||||
# 插件系统现在使用统一的插件加载器
|
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
logger = get_logger("main")
|
logger = get_logger("main")
|
||||||
@@ -108,10 +103,7 @@ class MainSystem:
|
|||||||
# 启动LPMM
|
# 启动LPMM
|
||||||
lpmm_start_up()
|
lpmm_start_up()
|
||||||
|
|
||||||
# 加载所有actions,包括默认的和插件的
|
# 启动插件运行时(内置插件 + 第三方插件双子进程)
|
||||||
plugin_manager.load_all_plugins()
|
|
||||||
|
|
||||||
# 启动新版本插件运行时(与旧系统并行运行)
|
|
||||||
await get_plugin_runtime_manager().start()
|
await get_plugin_runtime_manager().start()
|
||||||
|
|
||||||
# 初始化表情管理器
|
# 初始化表情管理器
|
||||||
@@ -133,12 +125,12 @@ class MainSystem:
|
|||||||
prompt_manager.load_prompts()
|
prompt_manager.load_prompts()
|
||||||
|
|
||||||
# 触发 ON_START 事件
|
# 触发 ON_START 事件
|
||||||
from src.plugin_system.core.events_manager import events_manager
|
from src.core.event_bus import event_bus
|
||||||
from src.plugin_system.base.component_types import EventType
|
from src.core.types import EventType
|
||||||
|
|
||||||
await events_manager.handle_mai_events(event_type=EventType.ON_START)
|
await event_bus.emit(event_type=EventType.ON_START)
|
||||||
|
|
||||||
# 桥接 ON_START 事件到新版本插件运行时
|
# 分发 ON_START 事件到插件运行时
|
||||||
await get_plugin_runtime_manager().bridge_event("on_start")
|
await get_plugin_runtime_manager().bridge_event("on_start")
|
||||||
# logger.info("已触发 ON_START 事件")
|
# logger.info("已触发 ON_START 事件")
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from src.common.logger import get_logger
|
|||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
from src.config.config import model_config, global_config
|
from src.config.config import model_config, global_config
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.plugin_system.apis import message_api
|
from src.services import message_service as message_api
|
||||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||||
from src.chat.utils.utils import is_bot_self
|
from src.chat.utils.utils import is_bot_self
|
||||||
from src.person_info.person_info import Person
|
from src.person_info.person_info import Person
|
||||||
@@ -913,7 +913,7 @@ class ChatHistorySummarizer:
|
|||||||
"""存储到数据库"""
|
"""存储到数据库"""
|
||||||
try:
|
try:
|
||||||
from src.common.database.database_model import ChatHistory
|
from src.common.database.database_model import ChatHistory
|
||||||
from src.plugin_system.apis import database_api
|
from src.services import database_service as database_api
|
||||||
|
|
||||||
# 准备数据
|
# 准备数据
|
||||||
data = {
|
data = {
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from typing import List, Dict, Any, Optional, Tuple, Callable
|
|||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
from src.prompt.prompt_manager import prompt_manager
|
from src.prompt.prompt_manager import prompt_manager
|
||||||
from src.plugin_system.apis import llm_api
|
from src.services import llm_service as llm_api
|
||||||
from sqlmodel import select, col
|
from sqlmodel import select, col
|
||||||
from src.common.database.database import get_db_session
|
from src.common.database.database import get_db_session
|
||||||
from src.common.database.database_model import ThinkingQuestion
|
from src.common.database.database_model import ThinkingQuestion
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -22,8 +22,9 @@ class UDSTransportServer(TransportServer):
|
|||||||
|
|
||||||
def __init__(self, socket_path: str | None = None):
|
def __init__(self, socket_path: str | None = None):
|
||||||
if socket_path is None:
|
if socket_path is None:
|
||||||
# 默认放在临时目录
|
# 默认放在临时目录,使用 uuid 确保同一进程多实例不碰撞
|
||||||
socket_path = os.path.join(tempfile.gettempdir(), f"maibot-plugin-{os.getpid()}.sock")
|
import uuid
|
||||||
|
socket_path = os.path.join(tempfile.gettempdir(), f"maibot-plugin-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock")
|
||||||
self._socket_path = socket_path
|
self._socket_path = socket_path
|
||||||
self._server: asyncio.AbstractServer | None = None
|
self._server: asyncio.AbstractServer | None = None
|
||||||
|
|
||||||
|
|||||||
@@ -1,156 +0,0 @@
|
|||||||
"""
|
|
||||||
MaiBot 插件系统
|
|
||||||
|
|
||||||
提供统一的插件开发和管理框架
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 导出主要的公共接口
|
|
||||||
from .base import (
|
|
||||||
BasePlugin,
|
|
||||||
BaseAction,
|
|
||||||
BaseCommand,
|
|
||||||
BaseTool,
|
|
||||||
ConfigField,
|
|
||||||
ConfigSection,
|
|
||||||
ConfigLayout,
|
|
||||||
ConfigTab,
|
|
||||||
ComponentType,
|
|
||||||
ActionActivationType,
|
|
||||||
ChatMode,
|
|
||||||
ComponentInfo,
|
|
||||||
ActionInfo,
|
|
||||||
CommandInfo,
|
|
||||||
PluginInfo,
|
|
||||||
ToolInfo,
|
|
||||||
PythonDependency,
|
|
||||||
BaseEventHandler,
|
|
||||||
EventHandlerInfo,
|
|
||||||
EventType,
|
|
||||||
MaiMessages,
|
|
||||||
ToolParamType,
|
|
||||||
CustomEventHandlerResult,
|
|
||||||
PluginServiceInfo,
|
|
||||||
ReplyContentType,
|
|
||||||
ReplyContent,
|
|
||||||
ForwardNode,
|
|
||||||
ReplySetModel,
|
|
||||||
WorkflowContext,
|
|
||||||
WorkflowMessage,
|
|
||||||
WorkflowStage,
|
|
||||||
WorkflowStepInfo,
|
|
||||||
WorkflowStepResult,
|
|
||||||
WorkflowErrorCode,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 导入工具模块
|
|
||||||
from .utils import (
|
|
||||||
ManifestValidator,
|
|
||||||
# ManifestGenerator,
|
|
||||||
# validate_plugin_manifest,
|
|
||||||
# generate_plugin_manifest,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .apis import (
|
|
||||||
chat_api,
|
|
||||||
tool_api,
|
|
||||||
component_manage_api,
|
|
||||||
config_api,
|
|
||||||
database_api,
|
|
||||||
emoji_api,
|
|
||||||
generator_api,
|
|
||||||
llm_api,
|
|
||||||
message_api,
|
|
||||||
person_api,
|
|
||||||
plugin_manage_api,
|
|
||||||
plugin_service_api,
|
|
||||||
workflow_api,
|
|
||||||
send_api,
|
|
||||||
register_plugin,
|
|
||||||
get_logger,
|
|
||||||
)
|
|
||||||
|
|
||||||
from src.common.data_models.database_data_model import (
|
|
||||||
DatabaseMessages,
|
|
||||||
DatabaseUserInfo,
|
|
||||||
DatabaseGroupInfo,
|
|
||||||
DatabaseChatInfo,
|
|
||||||
)
|
|
||||||
from src.common.data_models.info_data_model import TargetPersonInfo, ActionPlannerInfo
|
|
||||||
from src.common.data_models.llm_data_model import LLMGenerationDataModel
|
|
||||||
|
|
||||||
|
|
||||||
__version__ = "2.0.0"
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
# API 模块
|
|
||||||
"chat_api",
|
|
||||||
"tool_api",
|
|
||||||
"component_manage_api",
|
|
||||||
"config_api",
|
|
||||||
"database_api",
|
|
||||||
"emoji_api",
|
|
||||||
"generator_api",
|
|
||||||
"llm_api",
|
|
||||||
"message_api",
|
|
||||||
"person_api",
|
|
||||||
"plugin_manage_api",
|
|
||||||
"plugin_service_api",
|
|
||||||
"workflow_api",
|
|
||||||
"send_api",
|
|
||||||
"auto_talk_api",
|
|
||||||
"register_plugin",
|
|
||||||
"get_logger",
|
|
||||||
# 基础类
|
|
||||||
"BasePlugin",
|
|
||||||
"BaseAction",
|
|
||||||
"BaseCommand",
|
|
||||||
"BaseTool",
|
|
||||||
"BaseEventHandler",
|
|
||||||
# 类型定义
|
|
||||||
"ComponentType",
|
|
||||||
"ActionActivationType",
|
|
||||||
"ChatMode",
|
|
||||||
"ComponentInfo",
|
|
||||||
"ActionInfo",
|
|
||||||
"CommandInfo",
|
|
||||||
"PluginInfo",
|
|
||||||
"ToolInfo",
|
|
||||||
"PythonDependency",
|
|
||||||
"EventHandlerInfo",
|
|
||||||
"EventType",
|
|
||||||
"ToolParamType",
|
|
||||||
# 消息
|
|
||||||
"ReplyContentType",
|
|
||||||
"ReplyContent",
|
|
||||||
"ForwardNode",
|
|
||||||
"ReplySetModel",
|
|
||||||
"MaiMessages",
|
|
||||||
"CustomEventHandlerResult",
|
|
||||||
"PluginServiceInfo",
|
|
||||||
"WorkflowContext",
|
|
||||||
"WorkflowMessage",
|
|
||||||
"WorkflowStage",
|
|
||||||
"WorkflowStepInfo",
|
|
||||||
"WorkflowStepResult",
|
|
||||||
"WorkflowErrorCode",
|
|
||||||
# 装饰器
|
|
||||||
"register_plugin",
|
|
||||||
"ConfigField",
|
|
||||||
"ConfigSection",
|
|
||||||
"ConfigLayout",
|
|
||||||
"ConfigTab",
|
|
||||||
# 工具函数
|
|
||||||
"ManifestValidator",
|
|
||||||
"get_logger",
|
|
||||||
# "ManifestGenerator",
|
|
||||||
# "validate_plugin_manifest",
|
|
||||||
# "generate_plugin_manifest",
|
|
||||||
# 数据模型
|
|
||||||
"DatabaseMessages",
|
|
||||||
"DatabaseUserInfo",
|
|
||||||
"DatabaseGroupInfo",
|
|
||||||
"DatabaseChatInfo",
|
|
||||||
"TargetPersonInfo",
|
|
||||||
"ActionPlannerInfo",
|
|
||||||
"LLMGenerationDataModel",
|
|
||||||
]
|
|
||||||
@@ -1,47 +0,0 @@
|
|||||||
"""
|
|
||||||
插件系统API模块
|
|
||||||
|
|
||||||
提供了插件开发所需的各种API
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 导入所有API模块
|
|
||||||
from src.plugin_system.apis import (
|
|
||||||
chat_api,
|
|
||||||
component_manage_api,
|
|
||||||
config_api,
|
|
||||||
database_api,
|
|
||||||
emoji_api,
|
|
||||||
generator_api,
|
|
||||||
llm_api,
|
|
||||||
message_api,
|
|
||||||
person_api,
|
|
||||||
plugin_manage_api,
|
|
||||||
plugin_service_api,
|
|
||||||
send_api,
|
|
||||||
tool_api,
|
|
||||||
frequency_api,
|
|
||||||
workflow_api,
|
|
||||||
)
|
|
||||||
from .logging_api import get_logger
|
|
||||||
from .plugin_register_api import register_plugin
|
|
||||||
|
|
||||||
# 导出所有API模块,使它们可以通过 apis.xxx 方式访问
|
|
||||||
__all__ = [
|
|
||||||
"chat_api",
|
|
||||||
"component_manage_api",
|
|
||||||
"config_api",
|
|
||||||
"database_api",
|
|
||||||
"emoji_api",
|
|
||||||
"generator_api",
|
|
||||||
"llm_api",
|
|
||||||
"message_api",
|
|
||||||
"person_api",
|
|
||||||
"plugin_manage_api",
|
|
||||||
"plugin_service_api",
|
|
||||||
"send_api",
|
|
||||||
"get_logger",
|
|
||||||
"register_plugin",
|
|
||||||
"tool_api",
|
|
||||||
"frequency_api",
|
|
||||||
"workflow_api",
|
|
||||||
]
|
|
||||||
@@ -1,323 +0,0 @@
|
|||||||
"""
|
|
||||||
聊天API模块
|
|
||||||
|
|
||||||
专门负责聊天信息的查询和管理,采用标准Python包设计模式
|
|
||||||
使用方式:
|
|
||||||
from src.plugin_system.apis import chat_api
|
|
||||||
streams = chat_api.get_all_group_streams()
|
|
||||||
chat_type = chat_api.get_stream_type(stream)
|
|
||||||
|
|
||||||
或者:
|
|
||||||
from src.plugin_system.apis.chat_api import ChatManager as chat
|
|
||||||
streams = chat.get_all_group_streams()
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import List, Dict, Any, Optional
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
|
|
||||||
|
|
||||||
logger = get_logger("chat_api")
|
|
||||||
|
|
||||||
|
|
||||||
class SpecialTypes(Enum):
|
|
||||||
"""特殊枚举类型"""
|
|
||||||
|
|
||||||
ALL_PLATFORMS = "all_platforms"
|
|
||||||
|
|
||||||
|
|
||||||
class ChatManager:
|
|
||||||
"""聊天管理器 - 专门负责聊天信息的查询和管理"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
|
|
||||||
# sourcery skip: for-append-to-extend
|
|
||||||
"""获取所有聊天流
|
|
||||||
|
|
||||||
Args:
|
|
||||||
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[BotChatSession]: 聊天流列表
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
TypeError: 如果 platform 不是字符串或 SpecialTypes 枚举类型
|
|
||||||
"""
|
|
||||||
if not isinstance(platform, (str, SpecialTypes)):
|
|
||||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
|
||||||
streams = []
|
|
||||||
try:
|
|
||||||
for _, stream in _chat_manager.sessions.items():
|
|
||||||
if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform:
|
|
||||||
streams.append(stream)
|
|
||||||
logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的聊天流")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[ChatAPI] 获取聊天流失败: {e}")
|
|
||||||
return streams
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
|
|
||||||
# sourcery skip: for-append-to-extend
|
|
||||||
"""获取所有群聊聊天流
|
|
||||||
|
|
||||||
Args:
|
|
||||||
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[BotChatSession]: 群聊聊天流列表
|
|
||||||
"""
|
|
||||||
if not isinstance(platform, (str, SpecialTypes)):
|
|
||||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
|
||||||
streams = []
|
|
||||||
try:
|
|
||||||
for _, stream in _chat_manager.sessions.items():
|
|
||||||
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.is_group_session:
|
|
||||||
streams.append(stream)
|
|
||||||
logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的群聊流")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[ChatAPI] 获取群聊流失败: {e}")
|
|
||||||
return streams
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
|
|
||||||
# sourcery skip: for-append-to-extend
|
|
||||||
"""获取所有私聊聊天流
|
|
||||||
|
|
||||||
Args:
|
|
||||||
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[BotChatSession]: 私聊聊天流列表
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
TypeError: 如果 platform 不是字符串或 SpecialTypes 枚举类型
|
|
||||||
"""
|
|
||||||
if not isinstance(platform, (str, SpecialTypes)):
|
|
||||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
|
||||||
streams = []
|
|
||||||
try:
|
|
||||||
for _, stream in _chat_manager.sessions.items():
|
|
||||||
if (
|
|
||||||
platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform
|
|
||||||
) and not stream.is_group_session:
|
|
||||||
streams.append(stream)
|
|
||||||
logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的私聊流")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[ChatAPI] 获取私聊流失败: {e}")
|
|
||||||
return streams
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_group_stream_by_group_id(
|
|
||||||
group_id: str, platform: Optional[str] | SpecialTypes = "qq"
|
|
||||||
) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast
|
|
||||||
"""根据群ID获取聊天流
|
|
||||||
|
|
||||||
Args:
|
|
||||||
group_id: 群聊ID
|
|
||||||
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[BotChatSession]: 聊天流对象,如果未找到返回None
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: 如果 group_id 为空字符串
|
|
||||||
TypeError: 如果 group_id 不是字符串类型或 platform 不是字符串或 SpecialTypes
|
|
||||||
"""
|
|
||||||
if not isinstance(group_id, str):
|
|
||||||
raise TypeError("group_id 必须是字符串类型")
|
|
||||||
if not isinstance(platform, (str, SpecialTypes)):
|
|
||||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
|
||||||
if not group_id:
|
|
||||||
raise ValueError("group_id 不能为空")
|
|
||||||
try:
|
|
||||||
for _, stream in _chat_manager.sessions.items():
|
|
||||||
if (
|
|
||||||
stream.is_group_session
|
|
||||||
and str(stream.group_id) == str(group_id)
|
|
||||||
and (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform)
|
|
||||||
):
|
|
||||||
logger.debug(f"[ChatAPI] 找到群ID {group_id} 的聊天流")
|
|
||||||
return stream
|
|
||||||
logger.warning(f"[ChatAPI] 未找到群ID {group_id} 的聊天流")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[ChatAPI] 查找群聊流失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_private_stream_by_user_id(
|
|
||||||
user_id: str, platform: Optional[str] | SpecialTypes = "qq"
|
|
||||||
) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast
|
|
||||||
"""根据用户ID获取私聊流
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: 用户ID
|
|
||||||
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[BotChatSession]: 聊天流对象,如果未找到返回None
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: 如果 user_id 为空字符串
|
|
||||||
TypeError: 如果 user_id 不是字符串类型或 platform 不是字符串或 SpecialTypes
|
|
||||||
"""
|
|
||||||
if not isinstance(user_id, str):
|
|
||||||
raise TypeError("user_id 必须是字符串类型")
|
|
||||||
if not isinstance(platform, (str, SpecialTypes)):
|
|
||||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
|
||||||
if not user_id:
|
|
||||||
raise ValueError("user_id 不能为空")
|
|
||||||
try:
|
|
||||||
for _, stream in _chat_manager.sessions.items():
|
|
||||||
if (
|
|
||||||
not stream.is_group_session
|
|
||||||
and str(stream.user_id) == str(user_id)
|
|
||||||
and (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform)
|
|
||||||
):
|
|
||||||
logger.debug(f"[ChatAPI] 找到用户ID {user_id} 的私聊流")
|
|
||||||
return stream
|
|
||||||
logger.warning(f"[ChatAPI] 未找到用户ID {user_id} 的私聊流")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[ChatAPI] 查找私聊流失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_stream_type(chat_stream: BotChatSession) -> str:
|
|
||||||
"""获取聊天流类型
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chat_stream: 聊天流对象
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: 聊天类型 ("group", "private", "unknown")
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
TypeError: 如果 chat_stream 不是 BotChatSession 类型
|
|
||||||
ValueError: 如果 chat_stream 为空
|
|
||||||
"""
|
|
||||||
if not isinstance(chat_stream, BotChatSession):
|
|
||||||
raise TypeError("chat_stream 必须是 BotChatSession 类型")
|
|
||||||
if not chat_stream:
|
|
||||||
raise ValueError("chat_stream 不能为 None")
|
|
||||||
|
|
||||||
return "group" if chat_stream.is_group_session else "private"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_stream_info(chat_stream: BotChatSession) -> Dict[str, Any]:
|
|
||||||
"""获取聊天流详细信息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chat_stream: 聊天流对象
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict ({str: Any}): 聊天流信息字典
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
TypeError: 如果 chat_stream 不是 BotChatSession 类型
|
|
||||||
ValueError: 如果 chat_stream 为空
|
|
||||||
"""
|
|
||||||
if not chat_stream:
|
|
||||||
raise ValueError("chat_stream 不能为 None")
|
|
||||||
if not isinstance(chat_stream, BotChatSession):
|
|
||||||
raise TypeError("chat_stream 必须是 BotChatSession 类型")
|
|
||||||
|
|
||||||
try:
|
|
||||||
info: Dict[str, Any] = {
|
|
||||||
"session_id": chat_stream.session_id,
|
|
||||||
"platform": chat_stream.platform,
|
|
||||||
"type": ChatManager.get_stream_type(chat_stream),
|
|
||||||
}
|
|
||||||
|
|
||||||
if chat_stream.is_group_session:
|
|
||||||
info["group_id"] = chat_stream.group_id
|
|
||||||
# Try to get group name from context
|
|
||||||
if chat_stream.context and chat_stream.context.message and chat_stream.context.message.message_info.group_info:
|
|
||||||
info["group_name"] = chat_stream.context.message.message_info.group_info.group_name or "未知群聊"
|
|
||||||
else:
|
|
||||||
info["group_name"] = "未知群聊"
|
|
||||||
else:
|
|
||||||
info["user_id"] = chat_stream.user_id
|
|
||||||
if chat_stream.context and chat_stream.context.message and chat_stream.context.message.message_info.user_info:
|
|
||||||
info["user_name"] = chat_stream.context.message.message_info.user_info.user_nickname
|
|
||||||
else:
|
|
||||||
info["user_name"] = "未知用户"
|
|
||||||
|
|
||||||
return info
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[ChatAPI] 获取聊天流信息失败: {e}")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_streams_summary() -> Dict[str, int]:
|
|
||||||
"""获取聊天流统计摘要
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict[str, int]: 包含各种统计信息的字典
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
all_streams = ChatManager.get_all_streams(SpecialTypes.ALL_PLATFORMS)
|
|
||||||
group_streams = ChatManager.get_group_streams(SpecialTypes.ALL_PLATFORMS)
|
|
||||||
private_streams = ChatManager.get_private_streams(SpecialTypes.ALL_PLATFORMS)
|
|
||||||
|
|
||||||
summary = {
|
|
||||||
"total_streams": len(all_streams),
|
|
||||||
"group_streams": len(group_streams),
|
|
||||||
"private_streams": len(private_streams),
|
|
||||||
"qq_streams": len([s for s in all_streams if s.platform == "qq"]),
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.debug(f"[ChatAPI] 聊天流统计: {summary}")
|
|
||||||
return summary
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[ChatAPI] 获取聊天流统计失败: {e}")
|
|
||||||
return {
|
|
||||||
"total_streams": 0,
|
|
||||||
"group_streams": 0,
|
|
||||||
"private_streams": 0,
|
|
||||||
"qq_streams": 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# 模块级别的便捷函数 - 类似 requests.get(), requests.post() 的设计
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
|
|
||||||
"""获取所有聊天流的便捷函数"""
|
|
||||||
return ChatManager.get_all_streams(platform)
|
|
||||||
|
|
||||||
|
|
||||||
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
|
|
||||||
"""获取群聊聊天流的便捷函数"""
|
|
||||||
return ChatManager.get_group_streams(platform)
|
|
||||||
|
|
||||||
|
|
||||||
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
|
|
||||||
"""获取私聊聊天流的便捷函数"""
|
|
||||||
return ChatManager.get_private_streams(platform)
|
|
||||||
|
|
||||||
|
|
||||||
def get_stream_by_group_id(group_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[BotChatSession]:
|
|
||||||
"""根据群ID获取聊天流的便捷函数"""
|
|
||||||
return ChatManager.get_group_stream_by_group_id(group_id, platform)
|
|
||||||
|
|
||||||
|
|
||||||
def get_stream_by_user_id(user_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[BotChatSession]:
|
|
||||||
"""根据用户ID获取私聊流的便捷函数"""
|
|
||||||
return ChatManager.get_private_stream_by_user_id(user_id, platform)
|
|
||||||
|
|
||||||
|
|
||||||
def get_stream_type(chat_stream: BotChatSession) -> str:
|
|
||||||
"""获取聊天流类型的便捷函数"""
|
|
||||||
return ChatManager.get_stream_type(chat_stream)
|
|
||||||
|
|
||||||
|
|
||||||
def get_stream_info(chat_stream: BotChatSession) -> Dict[str, Any]:
|
|
||||||
"""获取聊天流信息的便捷函数"""
|
|
||||||
return ChatManager.get_stream_info(chat_stream)
|
|
||||||
|
|
||||||
|
|
||||||
def get_streams_summary() -> Dict[str, int]:
|
|
||||||
"""获取聊天流统计摘要的便捷函数"""
|
|
||||||
return ChatManager.get_streams_summary()
|
|
||||||
@@ -1,268 +0,0 @@
|
|||||||
from typing import Optional, Union, Dict
|
|
||||||
from src.plugin_system.base.component_types import (
|
|
||||||
CommandInfo,
|
|
||||||
ActionInfo,
|
|
||||||
EventHandlerInfo,
|
|
||||||
PluginInfo,
|
|
||||||
ComponentType,
|
|
||||||
ToolInfo,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# === 插件信息查询 ===
|
|
||||||
def get_all_plugin_info() -> Dict[str, PluginInfo]:
|
|
||||||
"""
|
|
||||||
获取所有插件的信息。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: 包含所有插件信息的字典,键为插件名称,值为 PluginInfo 对象。
|
|
||||||
"""
|
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
|
||||||
|
|
||||||
return component_registry.get_all_plugins()
|
|
||||||
|
|
||||||
|
|
||||||
def get_plugin_info(plugin_name: str) -> Optional[PluginInfo]:
|
|
||||||
"""
|
|
||||||
获取指定插件的信息。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plugin_name (str): 插件名称。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
PluginInfo: 插件信息对象,如果插件不存在则返回 None。
|
|
||||||
"""
|
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
|
||||||
|
|
||||||
return component_registry.get_plugin_info(plugin_name)
|
|
||||||
|
|
||||||
|
|
||||||
# === 组件查询方法 ===
|
|
||||||
def get_component_info(
|
|
||||||
component_name: str, component_type: ComponentType
|
|
||||||
) -> Optional[Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
|
|
||||||
"""
|
|
||||||
获取指定组件的信息。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
component_name (str): 组件名称。
|
|
||||||
component_type (ComponentType): 组件类型。
|
|
||||||
Returns:
|
|
||||||
Union[CommandInfo, ActionInfo, EventHandlerInfo]: 组件信息对象,如果组件不存在则返回 None。
|
|
||||||
"""
|
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
|
||||||
|
|
||||||
return component_registry.get_component_info(component_name, component_type) # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
def get_components_info_by_type(
|
|
||||||
component_type: ComponentType,
|
|
||||||
) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
|
|
||||||
"""
|
|
||||||
获取指定类型的所有组件信息。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
component_type (ComponentType): 组件类型。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: 包含指定类型组件信息的字典,键为组件名称,值为对应的组件信息对象。
|
|
||||||
"""
|
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
|
||||||
|
|
||||||
return component_registry.get_components_by_type(component_type) # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
def get_enabled_components_info_by_type(
|
|
||||||
component_type: ComponentType,
|
|
||||||
) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
|
|
||||||
"""
|
|
||||||
获取指定类型的所有启用的组件信息。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
component_type (ComponentType): 组件类型。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: 包含指定类型启用组件信息的字典,键为组件名称,值为对应的组件信息对象。
|
|
||||||
"""
|
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
|
||||||
|
|
||||||
return component_registry.get_enabled_components_by_type(component_type) # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
# === Action 查询方法 ===
|
|
||||||
def get_registered_action_info(action_name: str) -> Optional[ActionInfo]:
|
|
||||||
"""
|
|
||||||
获取指定 Action 的注册信息。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
action_name (str): Action 名称。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ActionInfo: Action 信息对象,如果 Action 不存在则返回 None。
|
|
||||||
"""
|
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
|
||||||
|
|
||||||
return component_registry.get_registered_action_info(action_name)
|
|
||||||
|
|
||||||
|
|
||||||
def get_registered_command_info(command_name: str) -> Optional[CommandInfo]:
|
|
||||||
"""
|
|
||||||
获取指定 Command 的注册信息。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
command_name (str): Command 名称。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
CommandInfo: Command 信息对象,如果 Command 不存在则返回 None。
|
|
||||||
"""
|
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
|
||||||
|
|
||||||
return component_registry.get_registered_command_info(command_name)
|
|
||||||
|
|
||||||
|
|
||||||
def get_registered_tool_info(tool_name: str) -> Optional[ToolInfo]:
|
|
||||||
"""
|
|
||||||
获取指定 Tool 的注册信息。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tool_name (str): Tool 名称。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ToolInfo: Tool 信息对象,如果 Tool 不存在则返回 None。
|
|
||||||
"""
|
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
|
||||||
|
|
||||||
return component_registry.get_registered_tool_info(tool_name)
|
|
||||||
|
|
||||||
|
|
||||||
# === EventHandler 特定查询方法 ===
|
|
||||||
def get_registered_event_handler_info(
|
|
||||||
event_handler_name: str,
|
|
||||||
) -> Optional[EventHandlerInfo]:
|
|
||||||
"""
|
|
||||||
获取指定 EventHandler 的注册信息。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
event_handler_name (str): EventHandler 名称。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
EventHandlerInfo: EventHandler 信息对象,如果 EventHandler 不存在则返回 None。
|
|
||||||
"""
|
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
|
||||||
|
|
||||||
return component_registry.get_registered_event_handler_info(event_handler_name)
|
|
||||||
|
|
||||||
|
|
||||||
# === 组件管理方法 ===
|
|
||||||
def globally_enable_component(component_name: str, component_type: ComponentType) -> bool:
|
|
||||||
"""
|
|
||||||
全局启用指定组件。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
component_name (str): 组件名称。
|
|
||||||
component_type (ComponentType): 组件类型。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 启用成功返回 True,否则返回 False。
|
|
||||||
"""
|
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
|
||||||
|
|
||||||
return component_registry.enable_component(component_name, component_type)
|
|
||||||
|
|
||||||
|
|
||||||
async def globally_disable_component(component_name: str, component_type: ComponentType) -> bool:
|
|
||||||
"""
|
|
||||||
全局禁用指定组件。
|
|
||||||
|
|
||||||
**此函数是异步的,确保在异步环境中调用。**
|
|
||||||
|
|
||||||
Args:
|
|
||||||
component_name (str): 组件名称。
|
|
||||||
component_type (ComponentType): 组件类型。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 禁用成功返回 True,否则返回 False。
|
|
||||||
"""
|
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
|
||||||
|
|
||||||
return await component_registry.disable_component(component_name, component_type)
|
|
||||||
|
|
||||||
|
|
||||||
def locally_enable_component(component_name: str, component_type: ComponentType, stream_id: str) -> bool:
|
|
||||||
"""
|
|
||||||
局部启用指定组件。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
component_name (str): 组件名称。
|
|
||||||
component_type (ComponentType): 组件类型。
|
|
||||||
stream_id (str): 消息流 ID。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 启用成功返回 True,否则返回 False。
|
|
||||||
"""
|
|
||||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
|
||||||
|
|
||||||
match component_type:
|
|
||||||
case ComponentType.ACTION:
|
|
||||||
return global_announcement_manager.enable_specific_chat_action(stream_id, component_name)
|
|
||||||
case ComponentType.COMMAND:
|
|
||||||
return global_announcement_manager.enable_specific_chat_command(stream_id, component_name)
|
|
||||||
case ComponentType.TOOL:
|
|
||||||
return global_announcement_manager.enable_specific_chat_tool(stream_id, component_name)
|
|
||||||
case ComponentType.EVENT_HANDLER:
|
|
||||||
return global_announcement_manager.enable_specific_chat_event_handler(stream_id, component_name)
|
|
||||||
case _:
|
|
||||||
raise ValueError(f"未知 component type: {component_type}")
|
|
||||||
|
|
||||||
|
|
||||||
def locally_disable_component(component_name: str, component_type: ComponentType, stream_id: str) -> bool:
|
|
||||||
"""
|
|
||||||
局部禁用指定组件。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
component_name (str): 组件名称。
|
|
||||||
component_type (ComponentType): 组件类型。
|
|
||||||
stream_id (str): 消息流 ID。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 禁用成功返回 True,否则返回 False。
|
|
||||||
"""
|
|
||||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
|
||||||
|
|
||||||
match component_type:
|
|
||||||
case ComponentType.ACTION:
|
|
||||||
return global_announcement_manager.disable_specific_chat_action(stream_id, component_name)
|
|
||||||
case ComponentType.COMMAND:
|
|
||||||
return global_announcement_manager.disable_specific_chat_command(stream_id, component_name)
|
|
||||||
case ComponentType.TOOL:
|
|
||||||
return global_announcement_manager.disable_specific_chat_tool(stream_id, component_name)
|
|
||||||
case ComponentType.EVENT_HANDLER:
|
|
||||||
return global_announcement_manager.disable_specific_chat_event_handler(stream_id, component_name)
|
|
||||||
case _:
|
|
||||||
raise ValueError(f"未知 component type: {component_type}")
|
|
||||||
|
|
||||||
|
|
||||||
def get_locally_disabled_components(stream_id: str, component_type: ComponentType) -> list[str]:
|
|
||||||
"""
|
|
||||||
获取指定消息流中禁用的组件列表。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
stream_id (str): 消息流 ID。
|
|
||||||
component_type (ComponentType): 组件类型。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[str]: 禁用的组件名称列表。
|
|
||||||
"""
|
|
||||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
|
||||||
|
|
||||||
match component_type:
|
|
||||||
case ComponentType.ACTION:
|
|
||||||
return global_announcement_manager.get_disabled_chat_actions(stream_id)
|
|
||||||
case ComponentType.COMMAND:
|
|
||||||
return global_announcement_manager.get_disabled_chat_commands(stream_id)
|
|
||||||
case ComponentType.TOOL:
|
|
||||||
return global_announcement_manager.get_disabled_chat_tools(stream_id)
|
|
||||||
case ComponentType.EVENT_HANDLER:
|
|
||||||
return global_announcement_manager.get_disabled_chat_event_handlers(stream_id)
|
|
||||||
case _:
|
|
||||||
raise ValueError(f"未知 component type: {component_type}")
|
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class _SystemConstants:
|
|
||||||
PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.parent.absolute().resolve()
|
|
||||||
CONFIG_DIR: Path = PROJECT_ROOT / "config"
|
|
||||||
BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
|
|
||||||
MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
|
|
||||||
PLUGINS_DIR: Path = (PROJECT_ROOT / "plugins").resolve().absolute()
|
|
||||||
INTERNAL_PLUGINS_DIR: Path = (PROJECT_ROOT / "src" / "plugins").resolve().absolute()
|
|
||||||
|
|
||||||
|
|
||||||
_system_constants = _SystemConstants()
|
|
||||||
|
|
||||||
PROJECT_ROOT: Path = _system_constants.PROJECT_ROOT
|
|
||||||
CONFIG_DIR: Path = _system_constants.CONFIG_DIR
|
|
||||||
BOT_CONFIG_PATH: Path = _system_constants.BOT_CONFIG_PATH
|
|
||||||
MODEL_CONFIG_PATH: Path = _system_constants.MODEL_CONFIG_PATH
|
|
||||||
PLUGINS_DIR: Path = _system_constants.PLUGINS_DIR
|
|
||||||
INTERNAL_PLUGINS_DIR: Path = _system_constants.INTERNAL_PLUGINS_DIR
|
|
||||||
@@ -1,700 +0,0 @@
|
|||||||
"""
|
|
||||||
表情API模块
|
|
||||||
|
|
||||||
提供表情包相关功能,采用标准Python包设计模式
|
|
||||||
使用方式:
|
|
||||||
from src.plugin_system.apis import emoji_api
|
|
||||||
result = await emoji_api.get_by_description("开心")
|
|
||||||
count = emoji_api.get_count()
|
|
||||||
"""
|
|
||||||
|
|
||||||
import random
|
|
||||||
import base64
|
|
||||||
import os
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
from typing import Optional, Tuple, List, Dict, Any
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.chat.emoji_system.emoji_manager import emoji_manager, EMOJI_DIR
|
|
||||||
from src.common.utils.utils_image import ImageUtils
|
|
||||||
from src.config.config import global_config
|
|
||||||
|
|
||||||
logger = get_logger("emoji_api")
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# 表情包获取API函数
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]:
|
|
||||||
"""根据描述选择表情包
|
|
||||||
|
|
||||||
Args:
|
|
||||||
description: 表情包的描述文本,例如"开心"、"难过"、"愤怒"等
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[Tuple[str, str, str]]: (base64编码, 表情包描述, 匹配的情感标签) 或 None
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: 如果描述为空字符串
|
|
||||||
TypeError: 如果描述不是字符串类型
|
|
||||||
"""
|
|
||||||
if not description:
|
|
||||||
raise ValueError("描述不能为空")
|
|
||||||
if not isinstance(description, str):
|
|
||||||
raise TypeError("描述必须是字符串类型")
|
|
||||||
try:
|
|
||||||
logger.debug(f"[EmojiAPI] 根据描述获取表情包: {description}")
|
|
||||||
|
|
||||||
emoji_obj = await emoji_manager.get_emoji_for_emotion(description)
|
|
||||||
|
|
||||||
if not emoji_obj:
|
|
||||||
logger.warning(f"[EmojiAPI] 未找到匹配描述 '{description}' 的表情包")
|
|
||||||
return None
|
|
||||||
|
|
||||||
emoji_path = str(emoji_obj.full_path)
|
|
||||||
emoji_description = emoji_obj.description
|
|
||||||
matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else ""
|
|
||||||
emoji_base64 = ImageUtils.image_path_to_base64(emoji_path)
|
|
||||||
|
|
||||||
if not emoji_base64:
|
|
||||||
logger.error(f"[EmojiAPI] 无法将表情包文件转换为base64: {emoji_path}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
logger.debug(f"[EmojiAPI] 成功获取表情包: {emoji_description}, 匹配情感: {matched_emotion}")
|
|
||||||
return emoji_base64, emoji_description, matched_emotion
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[EmojiAPI] 获取表情包失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
|
|
||||||
"""随机获取指定数量的表情包
|
|
||||||
|
|
||||||
Args:
|
|
||||||
count: 要获取的表情包数量,默认为1
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Tuple[str, str, str]]: 包含(base64编码, 表情包描述, 随机情感标签)的元组列表,失败则返回空列表
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
TypeError: 如果count不是整数类型
|
|
||||||
ValueError: 如果count为负数
|
|
||||||
"""
|
|
||||||
if not isinstance(count, int):
|
|
||||||
raise TypeError("count 必须是整数类型")
|
|
||||||
if count < 0:
|
|
||||||
raise ValueError("count 不能为负数")
|
|
||||||
if count == 0:
|
|
||||||
logger.warning("[EmojiAPI] count 为0,返回空列表")
|
|
||||||
return []
|
|
||||||
|
|
||||||
try:
|
|
||||||
all_emojis = emoji_manager.emojis
|
|
||||||
|
|
||||||
if not all_emojis:
|
|
||||||
logger.warning("[EmojiAPI] 没有可用的表情包")
|
|
||||||
return []
|
|
||||||
|
|
||||||
# 过滤有效表情包
|
|
||||||
valid_emojis = [emoji for emoji in all_emojis if not emoji.is_deleted]
|
|
||||||
if not valid_emojis:
|
|
||||||
logger.warning("[EmojiAPI] 没有有效的表情包")
|
|
||||||
return []
|
|
||||||
|
|
||||||
if len(valid_emojis) < count:
|
|
||||||
logger.debug(
|
|
||||||
f"[EmojiAPI] 有效表情包数量 ({len(valid_emojis)}) 少于请求的数量 ({count}),将返回所有有效表情包"
|
|
||||||
)
|
|
||||||
count = len(valid_emojis)
|
|
||||||
|
|
||||||
# 随机选择
|
|
||||||
selected_emojis = random.sample(valid_emojis, count)
|
|
||||||
|
|
||||||
results = []
|
|
||||||
for selected_emoji in selected_emojis:
|
|
||||||
emoji_base64 = ImageUtils.image_path_to_base64(str(selected_emoji.full_path))
|
|
||||||
|
|
||||||
if not emoji_base64:
|
|
||||||
logger.error(f"[EmojiAPI] 无法转换表情包为base64: {selected_emoji.full_path}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
matched_emotion = random.choice(selected_emoji.emotion) if selected_emoji.emotion else "随机表情"
|
|
||||||
|
|
||||||
# 记录使用次数
|
|
||||||
emoji_manager.update_emoji_usage(selected_emoji)
|
|
||||||
results.append((emoji_base64, selected_emoji.description, matched_emotion))
|
|
||||||
|
|
||||||
if not results and count > 0:
|
|
||||||
logger.warning("[EmojiAPI] 随机获取表情包失败,没有一个可以成功处理")
|
|
||||||
return []
|
|
||||||
|
|
||||||
logger.debug(f"[EmojiAPI] 成功获取 {len(results)} 个随机表情包")
|
|
||||||
return results
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[EmojiAPI] 获取随机表情包失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
|
|
||||||
"""根据情感标签获取表情包
|
|
||||||
|
|
||||||
Args:
|
|
||||||
emotion: 情感标签,如"happy"、"sad"、"angry"等
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[Tuple[str, str, str]]: (base64编码, 表情包描述, 匹配的情感标签) 或 None
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: 如果情感标签为空字符串
|
|
||||||
TypeError: 如果情感标签不是字符串类型
|
|
||||||
"""
|
|
||||||
if not emotion:
|
|
||||||
raise ValueError("情感标签不能为空")
|
|
||||||
if not isinstance(emotion, str):
|
|
||||||
raise TypeError("情感标签必须是字符串类型")
|
|
||||||
try:
|
|
||||||
logger.info(f"[EmojiAPI] 根据情感获取表情包: {emotion}")
|
|
||||||
|
|
||||||
all_emojis = emoji_manager.emojis
|
|
||||||
|
|
||||||
# 筛选匹配情感的表情包
|
|
||||||
matching_emojis = []
|
|
||||||
matching_emojis.extend(
|
|
||||||
emoji_obj
|
|
||||||
for emoji_obj in all_emojis
|
|
||||||
if not emoji_obj.is_deleted and emotion.lower() in [e.lower() for e in emoji_obj.emotion]
|
|
||||||
)
|
|
||||||
if not matching_emojis:
|
|
||||||
logger.warning(f"[EmojiAPI] 未找到匹配情感 '{emotion}' 的表情包")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 随机选择匹配的表情包
|
|
||||||
selected_emoji = random.choice(matching_emojis)
|
|
||||||
emoji_base64 = ImageUtils.image_path_to_base64(selected_emoji.full_path)
|
|
||||||
|
|
||||||
if not emoji_base64:
|
|
||||||
logger.error(f"[EmojiAPI] 无法转换表情包为base64: {selected_emoji.full_path}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 记录使用次数
|
|
||||||
emoji_manager.update_emoji_usage(selected_emoji)
|
|
||||||
|
|
||||||
logger.info(f"[EmojiAPI] 成功获取情感表情包: {selected_emoji.description}")
|
|
||||||
return emoji_base64, selected_emoji.description, emotion
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[EmojiAPI] 根据情感获取表情包失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# 表情包信息查询API函数
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
def get_count() -> int:
|
|
||||||
"""获取表情包数量
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: 当前可用的表情包数量
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return len(emoji_manager.emojis)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[EmojiAPI] 获取表情包数量失败: {e}")
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
def get_info():
|
|
||||||
"""获取表情包系统信息
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: 包含表情包数量、最大数量、可用数量信息
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return {
|
|
||||||
"current_count": len(emoji_manager.emojis),
|
|
||||||
"max_count": global_config.emoji.max_reg_num,
|
|
||||||
"available_emojis": len([e for e in emoji_manager.emojis if not e.is_deleted]),
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[EmojiAPI] 获取表情包信息失败: {e}")
|
|
||||||
return {"current_count": 0, "max_count": 0, "available_emojis": 0}
|
|
||||||
|
|
||||||
|
|
||||||
def get_emotions() -> List[str]:
|
|
||||||
"""获取所有可用的情感标签
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list: 所有表情包的情感标签列表(去重)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
emotions = set()
|
|
||||||
|
|
||||||
for emoji_obj in emoji_manager.emojis:
|
|
||||||
if not emoji_obj.is_deleted and emoji_obj.emotion:
|
|
||||||
emotions.update(emoji_obj.emotion)
|
|
||||||
|
|
||||||
return sorted(list(emotions))
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[EmojiAPI] 获取情感标签失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
async def get_all() -> List[Tuple[str, str, str]]:
|
|
||||||
"""获取所有表情包
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Tuple[str, str, str]]: 包含(base64编码, 表情包描述, 随机情感标签)的元组列表
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
all_emojis = emoji_manager.emojis
|
|
||||||
|
|
||||||
if not all_emojis:
|
|
||||||
logger.warning("[EmojiAPI] 没有可用的表情包")
|
|
||||||
return []
|
|
||||||
|
|
||||||
results = []
|
|
||||||
for emoji_obj in all_emojis:
|
|
||||||
if emoji_obj.is_deleted:
|
|
||||||
continue
|
|
||||||
|
|
||||||
emoji_base64 = ImageUtils.image_path_to_base64(str(emoji_obj.full_path))
|
|
||||||
|
|
||||||
if not emoji_base64:
|
|
||||||
logger.error(f"[EmojiAPI] 无法转换表情包为base64: {emoji_obj.full_path}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else "随机表情"
|
|
||||||
results.append((emoji_base64, emoji_obj.description, matched_emotion))
|
|
||||||
|
|
||||||
logger.debug(f"[EmojiAPI] 成功获取 {len(results)} 个表情包")
|
|
||||||
return results
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[EmojiAPI] 获取所有表情包失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
def get_descriptions() -> List[str]:
|
|
||||||
"""获取所有表情包描述
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list: 所有可用表情包的描述列表
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
descriptions = []
|
|
||||||
|
|
||||||
descriptions.extend(
|
|
||||||
emoji_obj.description
|
|
||||||
for emoji_obj in emoji_manager.emojis
|
|
||||||
if not emoji_obj.is_deleted and emoji_obj.description
|
|
||||||
)
|
|
||||||
return descriptions
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[EmojiAPI] 获取表情包描述失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# 表情包注册API函数
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
async def register_emoji(image_base64: str, filename: Optional[str] = None) -> Dict[str, Any]:
|
|
||||||
"""注册新的表情包
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image_base64: 图片的base64编码
|
|
||||||
filename: 可选的文件名,如果未提供则自动生成
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict[str, Any]: 注册结果,包含以下字段:
|
|
||||||
- success: bool, 是否成功注册
|
|
||||||
- message: str, 结果消息
|
|
||||||
- description: Optional[str], 表情包描述(成功时)
|
|
||||||
- emotions: Optional[List[str]], 情感标签列表(成功时)
|
|
||||||
- replaced: Optional[bool], 是否替换了旧表情包(成功时)
|
|
||||||
- hash: Optional[str], 表情包哈希值(成功时)
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: 如果base64为空或无效
|
|
||||||
TypeError: 如果参数类型不正确
|
|
||||||
"""
|
|
||||||
if not image_base64:
|
|
||||||
raise ValueError("图片base64编码不能为空")
|
|
||||||
if not isinstance(image_base64, str):
|
|
||||||
raise TypeError("image_base64必须是字符串类型")
|
|
||||||
if filename is not None and not isinstance(filename, str):
|
|
||||||
raise TypeError("filename必须是字符串类型或None")
|
|
||||||
|
|
||||||
try:
|
|
||||||
logger.info(f"[EmojiAPI] 开始注册表情包,文件名: {filename or '自动生成'}")
|
|
||||||
|
|
||||||
# 1. 获取emoji管理器并检查容量
|
|
||||||
count_before = len(emoji_manager.emojis)
|
|
||||||
max_count = global_config.emoji.max_reg_num
|
|
||||||
|
|
||||||
# 2. 检查是否可以注册(未达到上限或启用替换)
|
|
||||||
can_register = count_before < max_count or (count_before >= max_count and global_config.emoji.do_replace)
|
|
||||||
|
|
||||||
if not can_register:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"message": f"表情包数量已达上限({count_before}/{max_count})且未启用替换功能",
|
|
||||||
"description": None,
|
|
||||||
"emotions": None,
|
|
||||||
"replaced": None,
|
|
||||||
"hash": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 3. 确保emoji目录存在
|
|
||||||
os.makedirs(EMOJI_DIR, exist_ok=True)
|
|
||||||
|
|
||||||
# 4. 生成文件名
|
|
||||||
if not filename:
|
|
||||||
# 基于时间戳、微秒和短base64生成唯一文件名
|
|
||||||
import time
|
|
||||||
|
|
||||||
timestamp = int(time.time())
|
|
||||||
microseconds = int(time.time() * 1000000) % 1000000 # 添加微秒级精度
|
|
||||||
|
|
||||||
# 生成12位随机标识符,使用base64编码(增加随机性)
|
|
||||||
import random
|
|
||||||
|
|
||||||
random_bytes = random.getrandbits(72).to_bytes(9, "big") # 72位 = 9字节 = 12位base64
|
|
||||||
short_id = base64.b64encode(random_bytes).decode("ascii")[:12].rstrip("=")
|
|
||||||
# 确保base64编码适合文件名(替换/和-)
|
|
||||||
short_id = short_id.replace("/", "_").replace("+", "-")
|
|
||||||
filename = f"emoji_{timestamp}_{microseconds}_{short_id}"
|
|
||||||
|
|
||||||
# 确保文件名有扩展名
|
|
||||||
if not filename.lower().endswith((".jpg", ".jpeg", ".png", ".gif")):
|
|
||||||
filename = f"{filename}.png" # 默认使用png格式
|
|
||||||
|
|
||||||
# 检查文件名是否已存在,如果存在则重新生成短标识符
|
|
||||||
temp_file_path = os.path.join(EMOJI_DIR, filename)
|
|
||||||
attempts = 0
|
|
||||||
max_attempts = 10
|
|
||||||
while os.path.exists(temp_file_path) and attempts < max_attempts:
|
|
||||||
# 重新生成短标识符
|
|
||||||
import random
|
|
||||||
|
|
||||||
random_bytes = random.getrandbits(48).to_bytes(6, "big")
|
|
||||||
short_id = base64.b64encode(random_bytes).decode("ascii")[:8].rstrip("=")
|
|
||||||
short_id = short_id.replace("/", "_").replace("+", "-")
|
|
||||||
|
|
||||||
# 分离文件名和扩展名,重新生成文件名
|
|
||||||
name_part, ext = os.path.splitext(filename)
|
|
||||||
# 去掉原来的标识符,添加新的
|
|
||||||
base_name = name_part.rsplit("_", 1)[0] # 移除最后一个_后的部分
|
|
||||||
filename = f"{base_name}_{short_id}{ext}"
|
|
||||||
temp_file_path = os.path.join(EMOJI_DIR, filename)
|
|
||||||
attempts += 1
|
|
||||||
|
|
||||||
# 如果还是冲突,使用UUID作为备用方案
|
|
||||||
if os.path.exists(temp_file_path):
|
|
||||||
uuid_short = str(uuid.uuid4())[:8]
|
|
||||||
name_part, ext = os.path.splitext(filename)
|
|
||||||
base_name = name_part.rsplit("_", 1)[0]
|
|
||||||
filename = f"{base_name}_{uuid_short}{ext}"
|
|
||||||
temp_file_path = os.path.join(EMOJI_DIR, filename)
|
|
||||||
|
|
||||||
# 如果UUID方案也冲突,添加序号
|
|
||||||
counter = 1
|
|
||||||
original_filename = filename
|
|
||||||
while os.path.exists(temp_file_path):
|
|
||||||
name_part, ext = os.path.splitext(original_filename)
|
|
||||||
filename = f"{name_part}_{counter}{ext}"
|
|
||||||
temp_file_path = os.path.join(EMOJI_DIR, filename)
|
|
||||||
counter += 1
|
|
||||||
|
|
||||||
# 防止无限循环,最多尝试100次
|
|
||||||
if counter > 100:
|
|
||||||
logger.error(f"[EmojiAPI] 无法生成唯一文件名,尝试次数过多: {original_filename}")
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"message": "无法生成唯一文件名,请稍后重试",
|
|
||||||
"description": None,
|
|
||||||
"emotions": None,
|
|
||||||
"replaced": None,
|
|
||||||
"hash": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 5. 保存base64图片到emoji目录
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 解码base64并保存图片
|
|
||||||
if not ImageUtils.base64_to_image(image_base64, temp_file_path):
|
|
||||||
logger.error(f"[EmojiAPI] 无法保存base64图片到文件: {temp_file_path}")
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"message": "无法保存图片文件",
|
|
||||||
"description": None,
|
|
||||||
"emotions": None,
|
|
||||||
"replaced": None,
|
|
||||||
"hash": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.debug(f"[EmojiAPI] 图片已保存到临时文件: {temp_file_path}")
|
|
||||||
|
|
||||||
except Exception as save_error:
|
|
||||||
logger.error(f"[EmojiAPI] 保存图片文件失败: {save_error}")
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"message": f"保存图片文件失败: {str(save_error)}",
|
|
||||||
"description": None,
|
|
||||||
"emotions": None,
|
|
||||||
"replaced": None,
|
|
||||||
"hash": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 6. 调用注册方法
|
|
||||||
register_success = await emoji_manager.register_emoji_by_filename(filename)
|
|
||||||
|
|
||||||
# 7. 清理临时文件(如果注册失败但文件还存在)
|
|
||||||
if not register_success and os.path.exists(temp_file_path):
|
|
||||||
try:
|
|
||||||
os.remove(temp_file_path)
|
|
||||||
logger.debug(f"[EmojiAPI] 已清理临时文件: {temp_file_path}")
|
|
||||||
except Exception as cleanup_error:
|
|
||||||
logger.warning(f"[EmojiAPI] 清理临时文件失败: {cleanup_error}")
|
|
||||||
|
|
||||||
# 8. 构建返回结果
|
|
||||||
if register_success:
|
|
||||||
count_after = len(emoji_manager.emojis)
|
|
||||||
replaced = count_after <= count_before # 如果数量没增加,说明是替换
|
|
||||||
|
|
||||||
# 尝试获取新注册的表情包信息
|
|
||||||
new_emoji_info = None
|
|
||||||
if count_after > count_before or replaced:
|
|
||||||
# 获取最新的表情包信息
|
|
||||||
try:
|
|
||||||
# 通过文件名查找新注册的表情包(注意:文件名在注册后可能已经改变)
|
|
||||||
for emoji_obj in reversed(emoji_manager.emojis):
|
|
||||||
if not emoji_obj.is_deleted and (
|
|
||||||
emoji_obj.file_name == filename
|
|
||||||
or (hasattr(emoji_obj, "full_path") and filename in str(emoji_obj.full_path))
|
|
||||||
):
|
|
||||||
new_emoji_info = emoji_obj
|
|
||||||
break
|
|
||||||
except Exception as find_error:
|
|
||||||
logger.warning(f"[EmojiAPI] 查找新注册表情包信息失败: {find_error}")
|
|
||||||
|
|
||||||
description = new_emoji_info.description if new_emoji_info else None
|
|
||||||
emotions = new_emoji_info.emotion if new_emoji_info else None
|
|
||||||
emoji_hash = new_emoji_info.emoji_hash if new_emoji_info else None
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"message": f"表情包注册成功 {'(替换旧表情包)' if replaced else '(新增表情包)'}",
|
|
||||||
"description": description,
|
|
||||||
"emotions": emotions,
|
|
||||||
"replaced": replaced,
|
|
||||||
"hash": emoji_hash,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"message": "表情包注册失败,可能因为重复、格式不支持或审核未通过",
|
|
||||||
"description": None,
|
|
||||||
"emotions": None,
|
|
||||||
"replaced": None,
|
|
||||||
"hash": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[EmojiAPI] 注册表情包时发生异常: {e}")
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"message": f"注册过程中发生错误: {str(e)}",
|
|
||||||
"description": None,
|
|
||||||
"emotions": None,
|
|
||||||
"replaced": None,
|
|
||||||
"hash": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# 表情包删除API函数
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
async def delete_emoji(emoji_hash: str) -> Dict[str, Any]:
|
|
||||||
"""删除表情包
|
|
||||||
|
|
||||||
Args:
|
|
||||||
emoji_hash: 要删除的表情包的哈希值
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict[str, Any]: 删除结果,包含以下字段:
|
|
||||||
- success: bool, 是否成功删除
|
|
||||||
- message: str, 结果消息
|
|
||||||
- count_before: Optional[int], 删除前的表情包数量
|
|
||||||
- count_after: Optional[int], 删除后的表情包数量
|
|
||||||
- description: Optional[str], 被删除的表情包描述(成功时)
|
|
||||||
- emotions: Optional[List[str]], 被删除的表情包情感标签(成功时)
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: 如果哈希值为空
|
|
||||||
TypeError: 如果哈希值不是字符串类型
|
|
||||||
"""
|
|
||||||
if not emoji_hash:
|
|
||||||
raise ValueError("表情包哈希值不能为空")
|
|
||||||
if not isinstance(emoji_hash, str):
|
|
||||||
raise TypeError("emoji_hash必须是字符串类型")
|
|
||||||
|
|
||||||
try:
|
|
||||||
logger.info(f"[EmojiAPI] 开始删除表情包,哈希值: {emoji_hash}")
|
|
||||||
|
|
||||||
# 1. 获取emoji管理器和删除前的数量
|
|
||||||
count_before = len(emoji_manager.emojis)
|
|
||||||
|
|
||||||
# 2. 获取被删除表情包的信息(用于返回结果)
|
|
||||||
deleted_emoji = None
|
|
||||||
try:
|
|
||||||
deleted_emoji = emoji_manager.get_emoji_by_hash(emoji_hash) or emoji_manager.get_emoji_by_hash_from_db(
|
|
||||||
emoji_hash
|
|
||||||
)
|
|
||||||
description = deleted_emoji.description if deleted_emoji else None
|
|
||||||
emotions = deleted_emoji.emotion if deleted_emoji else None
|
|
||||||
except Exception as info_error:
|
|
||||||
logger.warning(f"[EmojiAPI] 获取被删除表情包信息失败: {info_error}")
|
|
||||||
description = None
|
|
||||||
emotions = None
|
|
||||||
|
|
||||||
# 3. 执行删除操作
|
|
||||||
delete_success = False
|
|
||||||
if deleted_emoji:
|
|
||||||
delete_success = emoji_manager.delete_emoji(deleted_emoji)
|
|
||||||
|
|
||||||
# 4. 获取删除后的数量
|
|
||||||
count_after = len(emoji_manager.emojis)
|
|
||||||
|
|
||||||
# 5. 构建返回结果
|
|
||||||
if delete_success:
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"message": f"表情包删除成功 (哈希: {emoji_hash[:8]}...)",
|
|
||||||
"count_before": count_before,
|
|
||||||
"count_after": count_after,
|
|
||||||
"description": description,
|
|
||||||
"emotions": emotions,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"message": "表情包删除失败,可能因为哈希值不存在或删除过程出错",
|
|
||||||
"count_before": count_before,
|
|
||||||
"count_after": count_after,
|
|
||||||
"description": None,
|
|
||||||
"emotions": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[EmojiAPI] 删除表情包时发生异常: {e}")
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"message": f"删除过程中发生错误: {str(e)}",
|
|
||||||
"count_before": None,
|
|
||||||
"count_after": None,
|
|
||||||
"description": None,
|
|
||||||
"emotions": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def delete_emoji_by_description(description: str, exact_match: bool = False) -> Dict[str, Any]:
|
|
||||||
"""根据描述删除表情包
|
|
||||||
|
|
||||||
Args:
|
|
||||||
description: 表情包描述文本
|
|
||||||
exact_match: 是否精确匹配描述,False则为模糊匹配
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict[str, Any]: 删除结果,包含以下字段:
|
|
||||||
- success: bool, 是否成功删除
|
|
||||||
- message: str, 结果消息
|
|
||||||
- deleted_count: int, 删除的表情包数量
|
|
||||||
- deleted_hashes: List[str], 被删除的表情包哈希列表
|
|
||||||
- matched_count: int, 匹配到的表情包数量
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: 如果描述为空
|
|
||||||
TypeError: 如果描述不是字符串类型
|
|
||||||
"""
|
|
||||||
if not description:
|
|
||||||
raise ValueError("描述不能为空")
|
|
||||||
if not isinstance(description, str):
|
|
||||||
raise TypeError("description必须是字符串类型")
|
|
||||||
|
|
||||||
try:
|
|
||||||
logger.info(f"[EmojiAPI] 根据描述删除表情包: {description} (精确匹配: {exact_match})")
|
|
||||||
|
|
||||||
all_emojis = emoji_manager.emojis
|
|
||||||
|
|
||||||
# 筛选匹配的表情包
|
|
||||||
matching_emojis = []
|
|
||||||
for emoji_obj in all_emojis:
|
|
||||||
if emoji_obj.is_deleted:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if exact_match:
|
|
||||||
if emoji_obj.description == description:
|
|
||||||
matching_emojis.append(emoji_obj)
|
|
||||||
else:
|
|
||||||
if description.lower() in emoji_obj.description.lower():
|
|
||||||
matching_emojis.append(emoji_obj)
|
|
||||||
|
|
||||||
matched_count = len(matching_emojis)
|
|
||||||
if matched_count == 0:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"message": f"未找到匹配描述 '{description}' 的表情包",
|
|
||||||
"deleted_count": 0,
|
|
||||||
"deleted_hashes": [],
|
|
||||||
"matched_count": 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 删除匹配的表情包
|
|
||||||
deleted_count = 0
|
|
||||||
deleted_hashes = []
|
|
||||||
for emoji_obj in matching_emojis:
|
|
||||||
try:
|
|
||||||
delete_success = emoji_manager.delete_emoji(emoji_obj)
|
|
||||||
if delete_success:
|
|
||||||
deleted_count += 1
|
|
||||||
deleted_hashes.append(emoji_obj.emoji_hash)
|
|
||||||
except Exception as delete_error:
|
|
||||||
logger.error(f"[EmojiAPI] 删除表情包失败 (哈希: {emoji_obj.emoji_hash}): {delete_error}")
|
|
||||||
|
|
||||||
# 构建返回结果
|
|
||||||
if deleted_count > 0:
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"message": f"成功删除 {deleted_count} 个表情包 (匹配到 {matched_count} 个)",
|
|
||||||
"deleted_count": deleted_count,
|
|
||||||
"deleted_hashes": deleted_hashes,
|
|
||||||
"matched_count": matched_count,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"message": f"匹配到 {matched_count} 个表情包,但删除全部失败",
|
|
||||||
"deleted_count": 0,
|
|
||||||
"deleted_hashes": [],
|
|
||||||
"matched_count": matched_count,
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[EmojiAPI] 根据描述删除表情包时发生异常: {e}")
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"message": f"删除过程中发生错误: {str(e)}",
|
|
||||||
"deleted_count": 0,
|
|
||||||
"deleted_hashes": [],
|
|
||||||
"matched_count": 0,
|
|
||||||
}
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
from src.common.logger import get_logger
|
|
||||||
|
|
||||||
__all__ = ["get_logger"]
|
|
||||||
@@ -1,120 +0,0 @@
|
|||||||
from typing import Tuple, List
|
|
||||||
|
|
||||||
|
|
||||||
def list_loaded_plugins() -> List[str]:
|
|
||||||
"""
|
|
||||||
列出所有当前加载的插件。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[str]: 当前加载的插件名称列表。
|
|
||||||
"""
|
|
||||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
|
||||||
|
|
||||||
return plugin_manager.list_loaded_plugins()
|
|
||||||
|
|
||||||
|
|
||||||
def list_registered_plugins() -> List[str]:
|
|
||||||
"""
|
|
||||||
列出所有已注册的插件。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[str]: 已注册的插件名称列表。
|
|
||||||
"""
|
|
||||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
|
||||||
|
|
||||||
return plugin_manager.list_registered_plugins()
|
|
||||||
|
|
||||||
|
|
||||||
def get_plugin_path(plugin_name: str) -> str:
|
|
||||||
"""
|
|
||||||
获取指定插件的路径。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plugin_name (str): 插件名称。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: 插件目录的绝对路径。
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: 如果插件不存在。
|
|
||||||
"""
|
|
||||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
|
||||||
|
|
||||||
if plugin_path := plugin_manager.get_plugin_path(plugin_name):
|
|
||||||
return plugin_path
|
|
||||||
else:
|
|
||||||
raise ValueError(f"插件 '{plugin_name}' 不存在。")
|
|
||||||
|
|
||||||
|
|
||||||
async def remove_plugin(plugin_name: str) -> bool:
|
|
||||||
"""
|
|
||||||
卸载指定的插件。
|
|
||||||
|
|
||||||
**此函数是异步的,确保在异步环境中调用。**
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plugin_name (str): 要卸载的插件名称。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 卸载是否成功。
|
|
||||||
"""
|
|
||||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
|
||||||
|
|
||||||
return await plugin_manager.remove_registered_plugin(plugin_name)
|
|
||||||
|
|
||||||
|
|
||||||
async def reload_plugin(plugin_name: str) -> bool:
|
|
||||||
"""
|
|
||||||
重新加载指定的插件。
|
|
||||||
|
|
||||||
**此函数是异步的,确保在异步环境中调用。**
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plugin_name (str): 要重新加载的插件名称。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 重新加载是否成功。
|
|
||||||
"""
|
|
||||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
|
||||||
|
|
||||||
return await plugin_manager.reload_registered_plugin(plugin_name)
|
|
||||||
|
|
||||||
|
|
||||||
def load_plugin(plugin_name: str) -> Tuple[bool, int]:
|
|
||||||
"""
|
|
||||||
加载指定的插件。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plugin_name (str): 要加载的插件名称。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[bool, int]: 加载是否成功,成功或失败个数。
|
|
||||||
"""
|
|
||||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
|
||||||
|
|
||||||
return plugin_manager.load_registered_plugin_classes(plugin_name)
|
|
||||||
|
|
||||||
|
|
||||||
def add_plugin_directory(plugin_directory: str) -> bool:
|
|
||||||
"""
|
|
||||||
添加插件目录。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plugin_directory (str): 要添加的插件目录路径。
|
|
||||||
Returns:
|
|
||||||
bool: 添加是否成功。
|
|
||||||
"""
|
|
||||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
|
||||||
|
|
||||||
return plugin_manager.add_plugin_directory(plugin_directory)
|
|
||||||
|
|
||||||
|
|
||||||
def rescan_plugin_directory() -> Tuple[int, int]:
|
|
||||||
"""
|
|
||||||
重新扫描插件目录,加载新插件。
|
|
||||||
Returns:
|
|
||||||
Tuple[int, int]: 成功加载的插件数量和失败的插件数量。
|
|
||||||
"""
|
|
||||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
|
||||||
|
|
||||||
return plugin_manager.rescan_plugin_directory()
|
|
||||||
@@ -1,46 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
|
|
||||||
logger = get_logger("plugin_manager") # 复用plugin_manager名称
|
|
||||||
|
|
||||||
|
|
||||||
def register_plugin(cls):
|
|
||||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
|
||||||
from src.plugin_system.base.base_plugin import BasePlugin
|
|
||||||
|
|
||||||
"""插件注册装饰器
|
|
||||||
|
|
||||||
用法:
|
|
||||||
@register_plugin
|
|
||||||
class MyPlugin(BasePlugin):
|
|
||||||
plugin_name = "my_plugin"
|
|
||||||
plugin_description = "我的插件"
|
|
||||||
...
|
|
||||||
"""
|
|
||||||
if not issubclass(cls, BasePlugin):
|
|
||||||
logger.error(f"类 {cls.__name__} 不是 BasePlugin 的子类")
|
|
||||||
return cls
|
|
||||||
|
|
||||||
# 只是注册插件类,不立即实例化
|
|
||||||
# 插件管理器会负责实例化和注册
|
|
||||||
plugin_name: str = cls.plugin_name # type: ignore
|
|
||||||
if "." in plugin_name:
|
|
||||||
logger.error(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
|
|
||||||
raise ValueError(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
|
|
||||||
splitted_name = cls.__module__.split(".")
|
|
||||||
root_path = Path(__file__)
|
|
||||||
|
|
||||||
# 查找项目根目录
|
|
||||||
while not (root_path / "pyproject.toml").exists() and root_path.parent != root_path:
|
|
||||||
root_path = root_path.parent
|
|
||||||
|
|
||||||
if not (root_path / "pyproject.toml").exists():
|
|
||||||
logger.error(f"注册 {plugin_name} 无法找到项目根目录")
|
|
||||||
return cls
|
|
||||||
|
|
||||||
plugin_manager.plugin_classes[plugin_name] = cls
|
|
||||||
plugin_manager.plugin_paths[plugin_name] = str(Path(root_path, *splitted_name).resolve())
|
|
||||||
logger.debug(f"插件类已注册: {plugin_name}, 路径: {plugin_manager.plugin_paths[plugin_name]}")
|
|
||||||
|
|
||||||
return cls
|
|
||||||
@@ -1,56 +0,0 @@
|
|||||||
from typing import Any, Callable, Dict, Optional
|
|
||||||
|
|
||||||
from src.plugin_system.base.service_types import PluginServiceInfo
|
|
||||||
from src.plugin_system.core.plugin_service_registry import plugin_service_registry
|
|
||||||
|
|
||||||
|
|
||||||
def register_service(service_info: PluginServiceInfo, service_handler: Callable[..., Any]) -> bool:
|
|
||||||
"""注册插件服务。"""
|
|
||||||
return plugin_service_registry.register_service(service_info, service_handler)
|
|
||||||
|
|
||||||
|
|
||||||
def get_service(service_name: str, plugin_name: Optional[str] = None) -> Optional[PluginServiceInfo]:
|
|
||||||
"""获取插件服务元信息。"""
|
|
||||||
return plugin_service_registry.get_service(service_name, plugin_name)
|
|
||||||
|
|
||||||
|
|
||||||
def get_service_handler(service_name: str, plugin_name: Optional[str] = None) -> Optional[Callable[..., Any]]:
|
|
||||||
"""获取插件服务处理函数。"""
|
|
||||||
return plugin_service_registry.get_service_handler(service_name, plugin_name)
|
|
||||||
|
|
||||||
|
|
||||||
def list_services(plugin_name: Optional[str] = None, enabled_only: bool = False) -> Dict[str, PluginServiceInfo]:
|
|
||||||
"""列出插件服务。"""
|
|
||||||
return plugin_service_registry.list_services(plugin_name=plugin_name, enabled_only=enabled_only)
|
|
||||||
|
|
||||||
|
|
||||||
def enable_service(service_name: str, plugin_name: Optional[str] = None) -> bool:
|
|
||||||
"""启用插件服务。"""
|
|
||||||
return plugin_service_registry.enable_service(service_name, plugin_name)
|
|
||||||
|
|
||||||
|
|
||||||
def disable_service(service_name: str, plugin_name: Optional[str] = None) -> bool:
|
|
||||||
"""禁用插件服务。"""
|
|
||||||
return plugin_service_registry.disable_service(service_name, plugin_name)
|
|
||||||
|
|
||||||
|
|
||||||
def unregister_service(service_name: str, plugin_name: Optional[str] = None) -> bool:
|
|
||||||
"""注销插件服务。"""
|
|
||||||
return plugin_service_registry.unregister_service(service_name, plugin_name)
|
|
||||||
|
|
||||||
|
|
||||||
async def call_service(
|
|
||||||
service_name: str,
|
|
||||||
*args: Any,
|
|
||||||
plugin_name: Optional[str] = None,
|
|
||||||
caller_plugin: Optional[str] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> Any:
|
|
||||||
"""调用插件服务。"""
|
|
||||||
return await plugin_service_registry.call_service(
|
|
||||||
service_name,
|
|
||||||
*args,
|
|
||||||
plugin_name=plugin_name,
|
|
||||||
caller_plugin=caller_plugin,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
@@ -1,45 +0,0 @@
|
|||||||
from typing import Optional, Type, TYPE_CHECKING
|
|
||||||
from src.plugin_system.base.base_tool import BaseTool
|
|
||||||
from src.plugin_system.base.component_types import ComponentType
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from src.chat.message_receive.chat_manager import BotChatSession
|
|
||||||
|
|
||||||
logger = get_logger("tool_api")
|
|
||||||
|
|
||||||
|
|
||||||
def get_tool_instance(tool_name: str, chat_stream: Optional["BotChatSession"] = None) -> Optional[BaseTool]:
|
|
||||||
"""获取公开工具实例
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tool_name: 工具名称
|
|
||||||
chat_stream: 聊天流对象,用于传递聊天上下文信息
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[BaseTool]: 工具实例,如果未找到则返回None
|
|
||||||
"""
|
|
||||||
from src.plugin_system.core import component_registry
|
|
||||||
|
|
||||||
# 获取插件配置
|
|
||||||
tool_info = component_registry.get_component_info(tool_name, ComponentType.TOOL)
|
|
||||||
if tool_info:
|
|
||||||
plugin_config = component_registry.get_plugin_config(tool_info.plugin_name)
|
|
||||||
else:
|
|
||||||
plugin_config = None
|
|
||||||
|
|
||||||
tool_class: Type[BaseTool] = component_registry.get_component_class(tool_name, ComponentType.TOOL) # type: ignore
|
|
||||||
return tool_class(plugin_config, chat_stream) if tool_class else None
|
|
||||||
|
|
||||||
|
|
||||||
def get_llm_available_tool_definitions():
|
|
||||||
"""获取LLM可用的工具定义列表
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Tuple[str, Dict[str, Any]]]: 工具定义列表,为[("tool_name", 定义)]
|
|
||||||
"""
|
|
||||||
from src.plugin_system.core import component_registry
|
|
||||||
|
|
||||||
llm_available_tools = component_registry.get_llm_available_tools()
|
|
||||||
return [(name, tool_class.get_tool_definition()) for name, tool_class in llm_available_tools.items()]
|
|
||||||
@@ -1,77 +0,0 @@
|
|||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
||||||
|
|
||||||
from src.plugin_system.base.component_types import EventType, MaiMessages
|
|
||||||
from src.plugin_system.base.workflow_types import WorkflowContext, WorkflowStage, WorkflowStepInfo, WorkflowStepResult
|
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
|
||||||
from src.plugin_system.core.events_manager import events_manager
|
|
||||||
from src.plugin_system.core.workflow_engine import workflow_engine
|
|
||||||
|
|
||||||
|
|
||||||
def register_workflow_step(step_info: WorkflowStepInfo, step_handler: Callable[..., Any]) -> bool:
|
|
||||||
"""注册workflow step。"""
|
|
||||||
return component_registry.register_workflow_step(step_info, step_handler)
|
|
||||||
|
|
||||||
|
|
||||||
def get_steps_by_stage(stage: WorkflowStage, enabled_only: bool = False) -> Dict[str, WorkflowStepInfo]:
|
|
||||||
"""获取指定阶段的workflow steps。"""
|
|
||||||
return component_registry.get_steps_by_stage(stage, enabled_only=enabled_only)
|
|
||||||
|
|
||||||
|
|
||||||
def get_workflow_step(step_name: str, stage: Optional[WorkflowStage] = None) -> Optional[WorkflowStepInfo]:
|
|
||||||
"""获取workflow step元信息。"""
|
|
||||||
return component_registry.get_workflow_step(step_name, stage)
|
|
||||||
|
|
||||||
|
|
||||||
def get_workflow_step_handler(step_name: str, stage: Optional[WorkflowStage] = None) -> Optional[Callable[..., Any]]:
|
|
||||||
"""获取workflow step处理函数。"""
|
|
||||||
return component_registry.get_workflow_step_handler(step_name, stage)
|
|
||||||
|
|
||||||
|
|
||||||
def enable_workflow_step(step_name: str, stage: Optional[WorkflowStage] = None) -> bool:
|
|
||||||
"""启用workflow step。"""
|
|
||||||
return component_registry.enable_workflow_step(step_name, stage)
|
|
||||||
|
|
||||||
|
|
||||||
def disable_workflow_step(step_name: str, stage: Optional[WorkflowStage] = None) -> bool:
|
|
||||||
"""禁用workflow step。"""
|
|
||||||
return component_registry.disable_workflow_step(step_name, stage)
|
|
||||||
|
|
||||||
|
|
||||||
def get_execution_trace(trace_id: str) -> Optional[Dict[str, Any]]:
|
|
||||||
"""按trace_id获取workflow执行路径。"""
|
|
||||||
return workflow_engine.get_execution_trace(trace_id)
|
|
||||||
|
|
||||||
|
|
||||||
def clear_execution_trace(trace_id: str) -> bool:
|
|
||||||
"""清理trace执行路径记录。"""
|
|
||||||
return workflow_engine.clear_execution_trace(trace_id)
|
|
||||||
|
|
||||||
|
|
||||||
async def execute_workflow_message(
|
|
||||||
message: Optional[MaiMessages] = None,
|
|
||||||
stream_id: Optional[str] = None,
|
|
||||||
action_usage: Optional[List[str]] = None,
|
|
||||||
context: Optional[WorkflowContext] = None,
|
|
||||||
) -> Tuple[WorkflowStepResult, Optional[MaiMessages], WorkflowContext]:
|
|
||||||
"""执行workflow消息流转。"""
|
|
||||||
return await events_manager.handle_workflow_message(
|
|
||||||
message=message,
|
|
||||||
stream_id=stream_id,
|
|
||||||
action_usage=action_usage,
|
|
||||||
context=context,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def publish_event(
|
|
||||||
event_type: Union[EventType, str],
|
|
||||||
message: Optional[MaiMessages] = None,
|
|
||||||
stream_id: Optional[str] = None,
|
|
||||||
action_usage: Optional[List[str]] = None,
|
|
||||||
) -> Tuple[bool, Optional[MaiMessages]]:
|
|
||||||
"""发布事件(支持系统事件和自定义字符串事件)。"""
|
|
||||||
return await events_manager.handle_mai_events(
|
|
||||||
event_type=event_type,
|
|
||||||
message=message,
|
|
||||||
stream_id=stream_id,
|
|
||||||
action_usage=action_usage,
|
|
||||||
)
|
|
||||||
@@ -1,72 +0,0 @@
|
|||||||
"""
|
|
||||||
插件基础类模块
|
|
||||||
|
|
||||||
提供插件开发的基础类和类型定义
|
|
||||||
"""
|
|
||||||
|
|
||||||
from .base_plugin import BasePlugin
|
|
||||||
from .base_action import BaseAction
|
|
||||||
from .base_tool import BaseTool
|
|
||||||
from .base_command import BaseCommand
|
|
||||||
from .base_events_handler import BaseEventHandler
|
|
||||||
from .service_types import PluginServiceInfo
|
|
||||||
from .workflow_types import WorkflowContext, WorkflowMessage, WorkflowStage, WorkflowStepInfo, WorkflowStepResult
|
|
||||||
from .workflow_errors import WorkflowErrorCode
|
|
||||||
from .component_types import (
|
|
||||||
ComponentType,
|
|
||||||
ActionActivationType,
|
|
||||||
ChatMode,
|
|
||||||
ComponentInfo,
|
|
||||||
ActionInfo,
|
|
||||||
CommandInfo,
|
|
||||||
ToolInfo,
|
|
||||||
PluginInfo,
|
|
||||||
PythonDependency,
|
|
||||||
EventHandlerInfo,
|
|
||||||
EventType,
|
|
||||||
MaiMessages,
|
|
||||||
ToolParamType,
|
|
||||||
CustomEventHandlerResult,
|
|
||||||
ReplyContentType,
|
|
||||||
ReplyContent,
|
|
||||||
ForwardNode,
|
|
||||||
ReplySetModel,
|
|
||||||
)
|
|
||||||
from .config_types import ConfigField, ConfigSection, ConfigLayout, ConfigTab
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"BasePlugin",
|
|
||||||
"BaseAction",
|
|
||||||
"BaseCommand",
|
|
||||||
"BaseTool",
|
|
||||||
"ComponentType",
|
|
||||||
"ActionActivationType",
|
|
||||||
"ChatMode",
|
|
||||||
"ComponentInfo",
|
|
||||||
"ActionInfo",
|
|
||||||
"CommandInfo",
|
|
||||||
"ToolInfo",
|
|
||||||
"PluginInfo",
|
|
||||||
"PythonDependency",
|
|
||||||
"ConfigField",
|
|
||||||
"ConfigSection",
|
|
||||||
"ConfigLayout",
|
|
||||||
"ConfigTab",
|
|
||||||
"EventHandlerInfo",
|
|
||||||
"EventType",
|
|
||||||
"BaseEventHandler",
|
|
||||||
"MaiMessages",
|
|
||||||
"ToolParamType",
|
|
||||||
"CustomEventHandlerResult",
|
|
||||||
"ReplyContentType",
|
|
||||||
"ReplyContent",
|
|
||||||
"ForwardNode",
|
|
||||||
"ReplySetModel",
|
|
||||||
"PluginServiceInfo",
|
|
||||||
"WorkflowContext",
|
|
||||||
"WorkflowMessage",
|
|
||||||
"WorkflowStage",
|
|
||||||
"WorkflowStepInfo",
|
|
||||||
"WorkflowStepResult",
|
|
||||||
"WorkflowErrorCode",
|
|
||||||
]
|
|
||||||
@@ -1,533 +0,0 @@
|
|||||||
import time
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Tuple, Optional, TYPE_CHECKING, Dict, List
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.common.data_models.message_data_model import ReplyContentType, ReplyContent, ReplySetModel, ForwardNode
|
|
||||||
from src.chat.message_receive.chat_manager import BotChatSession
|
|
||||||
from src.plugin_system.base.component_types import ActionActivationType, ActionInfo, ComponentType
|
|
||||||
from src.plugin_system.apis import send_api, database_api, message_api
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
|
||||||
|
|
||||||
logger = get_logger("base_action")
|
|
||||||
|
|
||||||
|
|
||||||
class BaseAction(ABC):
|
|
||||||
"""Action组件基类
|
|
||||||
|
|
||||||
Action是插件的一种组件类型,用于处理聊天中的动作逻辑
|
|
||||||
|
|
||||||
子类可以通过类属性定义激活条件,这些会在实例化时转换为实例属性:
|
|
||||||
- focus_activation_type: 专注模式激活类型
|
|
||||||
- normal_activation_type: 普通模式激活类型
|
|
||||||
- activation_keywords: 激活关键词列表
|
|
||||||
- keyword_case_sensitive: 关键词是否区分大小写
|
|
||||||
- parallel_action: 是否允许并行执行
|
|
||||||
- random_activation_probability: 随机激活概率
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
action_data: dict,
|
|
||||||
action_reasoning: str,
|
|
||||||
cycle_timers: dict,
|
|
||||||
thinking_id: str,
|
|
||||||
chat_stream: BotChatSession,
|
|
||||||
plugin_config: Optional[dict] = None,
|
|
||||||
action_message: Optional["DatabaseMessages"] = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
# sourcery skip: hoist-similar-statement-from-if, merge-else-if-into-elif, move-assign-in-block, swap-if-else-branches, swap-nested-ifs
|
|
||||||
"""初始化Action组件
|
|
||||||
|
|
||||||
Args:
|
|
||||||
action_data: 动作数据
|
|
||||||
reasoning: 执行该动作的理由
|
|
||||||
cycle_timers: 计时器字典
|
|
||||||
thinking_id: 思考ID
|
|
||||||
chat_stream: 聊天流对象
|
|
||||||
log_prefix: 日志前缀
|
|
||||||
plugin_config: 插件配置字典
|
|
||||||
action_message: 消息数据
|
|
||||||
**kwargs: 其他参数
|
|
||||||
"""
|
|
||||||
if plugin_config is None:
|
|
||||||
plugin_config = {}
|
|
||||||
self.action_data = action_data
|
|
||||||
self.reasoning = ""
|
|
||||||
self.cycle_timers = cycle_timers
|
|
||||||
self.thinking_id = thinking_id
|
|
||||||
|
|
||||||
self.action_reasoning = action_reasoning
|
|
||||||
|
|
||||||
self.plugin_config = plugin_config or {}
|
|
||||||
"""对应的插件配置"""
|
|
||||||
|
|
||||||
# 设置动作基本信息实例属性
|
|
||||||
self.action_name: str = getattr(self, "action_name", self.__class__.__name__.lower().replace("action", ""))
|
|
||||||
"""Action的名字"""
|
|
||||||
self.action_description: str = getattr(self, "action_description", self.__doc__ or "Action组件")
|
|
||||||
"""Action的描述"""
|
|
||||||
self.action_parameters: dict = getattr(self.__class__, "action_parameters", {}).copy()
|
|
||||||
self.action_require: list[str] = getattr(self.__class__, "action_require", []).copy()
|
|
||||||
|
|
||||||
"""NORMAL模式下的激活类型"""
|
|
||||||
self.activation_type = self.__class__.activation_type
|
|
||||||
"""激活类型"""
|
|
||||||
self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0)
|
|
||||||
"""当激活类型为RANDOM时的概率"""
|
|
||||||
self.activation_keywords: list[str] = getattr(self.__class__, "activation_keywords", []).copy()
|
|
||||||
"""激活类型为KEYWORD时的KEYWORDS列表"""
|
|
||||||
self.keyword_case_sensitive: bool = getattr(self.__class__, "keyword_case_sensitive", False)
|
|
||||||
self.parallel_action: bool = getattr(self.__class__, "parallel_action", True)
|
|
||||||
self.associated_types: list[str] = getattr(self.__class__, "associated_types", []).copy()
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# 便捷属性 - 直接在初始化时获取常用聊天信息(带类型注解)
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
# 获取聊天流对象
|
|
||||||
self.chat_stream = chat_stream or kwargs.get("chat_stream")
|
|
||||||
self.chat_id = self.chat_stream.session_id
|
|
||||||
self.platform = getattr(self.chat_stream, "platform", None)
|
|
||||||
|
|
||||||
# 初始化基础信息(带类型注解)
|
|
||||||
self.action_message = action_message
|
|
||||||
|
|
||||||
self.group_id = None
|
|
||||||
self.group_name = None
|
|
||||||
self.user_id = None
|
|
||||||
self.user_nickname = None
|
|
||||||
self.is_group = False
|
|
||||||
self.target_id = None
|
|
||||||
|
|
||||||
self.group_id = (
|
|
||||||
str(self.action_message.chat_info.group_info.group_id) if self.action_message.chat_info.group_info else None
|
|
||||||
)
|
|
||||||
self.group_name = (
|
|
||||||
self.action_message.chat_info.group_info.group_name if self.action_message.chat_info.group_info else None
|
|
||||||
)
|
|
||||||
|
|
||||||
self.user_id = str(self.action_message.user_info.user_id)
|
|
||||||
self.user_nickname = self.action_message.user_info.user_nickname
|
|
||||||
|
|
||||||
if self.group_id:
|
|
||||||
self.is_group = True
|
|
||||||
self.target_id = self.group_id
|
|
||||||
self.log_prefix = f"[{self.group_name}]"
|
|
||||||
else:
|
|
||||||
self.is_group = False
|
|
||||||
self.target_id = self.user_id
|
|
||||||
self.log_prefix = f"[{self.user_nickname} 的 私聊]"
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def execute(self) -> Tuple[bool, str]:
|
|
||||||
"""执行Action的抽象方法,子类必须实现
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[bool, str]: (是否执行成功, 回复文本)
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def send_text(
|
|
||||||
self,
|
|
||||||
content: str,
|
|
||||||
set_reply: bool = False,
|
|
||||||
reply_message: Optional["DatabaseMessages"] = None,
|
|
||||||
typing: bool = False,
|
|
||||||
storage_message: bool = True,
|
|
||||||
) -> bool:
|
|
||||||
"""发送文本消息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
content: 文本内容
|
|
||||||
set_reply: 是否作为回复发送
|
|
||||||
reply_message: 回复的消息对象(当set_reply为True时必填)
|
|
||||||
typing: 是否计算输入时间
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否发送成功
|
|
||||||
"""
|
|
||||||
if not self.chat_id:
|
|
||||||
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
|
||||||
return False
|
|
||||||
|
|
||||||
return await send_api.text_to_stream(
|
|
||||||
text=content,
|
|
||||||
stream_id=self.chat_id,
|
|
||||||
set_reply=set_reply,
|
|
||||||
reply_message=reply_message,
|
|
||||||
typing=typing,
|
|
||||||
storage_message=storage_message,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_emoji(
|
|
||||||
self,
|
|
||||||
emoji_base64: str,
|
|
||||||
set_reply: bool = False,
|
|
||||||
reply_message: Optional["DatabaseMessages"] = None,
|
|
||||||
storage_message: bool = True,
|
|
||||||
) -> bool:
|
|
||||||
"""发送表情包
|
|
||||||
|
|
||||||
Args:
|
|
||||||
emoji_base64: 表情包的base64编码
|
|
||||||
set_reply: 是否作为回复发送
|
|
||||||
reply_message: 回复的消息对象(当set_reply为True时必填)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否发送成功
|
|
||||||
"""
|
|
||||||
if not self.chat_id:
|
|
||||||
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
|
||||||
return False
|
|
||||||
|
|
||||||
return await send_api.emoji_to_stream(
|
|
||||||
emoji_base64,
|
|
||||||
self.chat_id,
|
|
||||||
set_reply=set_reply,
|
|
||||||
reply_message=reply_message,
|
|
||||||
storage_message=storage_message,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_image(
|
|
||||||
self,
|
|
||||||
image_base64: str,
|
|
||||||
set_reply: bool = False,
|
|
||||||
reply_message: Optional["DatabaseMessages"] = None,
|
|
||||||
storage_message: bool = True,
|
|
||||||
) -> bool:
|
|
||||||
"""发送图片
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image_base64: 图片的base64编码
|
|
||||||
set_reply: 是否作为回复发送
|
|
||||||
reply_message: 回复的消息对象(当set_reply为True时必填)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否发送成功
|
|
||||||
"""
|
|
||||||
if not self.chat_id:
|
|
||||||
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
|
||||||
return False
|
|
||||||
|
|
||||||
return await send_api.image_to_stream(
|
|
||||||
image_base64,
|
|
||||||
self.chat_id,
|
|
||||||
set_reply=set_reply,
|
|
||||||
reply_message=reply_message,
|
|
||||||
storage_message=storage_message,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_command(
|
|
||||||
self,
|
|
||||||
command_name: str,
|
|
||||||
args: Optional[dict] = None,
|
|
||||||
display_message: str = "",
|
|
||||||
storage_message: bool = True,
|
|
||||||
) -> bool:
|
|
||||||
"""发送命令消息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
command_name: 命令名称
|
|
||||||
args: 命令参数
|
|
||||||
display_message: 显示消息
|
|
||||||
storage_message: 是否存储消息到数据库
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否发送成功
|
|
||||||
"""
|
|
||||||
if not self.chat_id:
|
|
||||||
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 构造命令数据
|
|
||||||
command_data = {"name": command_name, "args": args or {}}
|
|
||||||
|
|
||||||
return await send_api.command_to_stream(
|
|
||||||
command=command_data,
|
|
||||||
stream_id=self.chat_id,
|
|
||||||
storage_message=storage_message,
|
|
||||||
display_message=display_message,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_custom(
|
|
||||||
self,
|
|
||||||
message_type: str,
|
|
||||||
content: str | Dict,
|
|
||||||
typing: bool = False,
|
|
||||||
set_reply: bool = False,
|
|
||||||
reply_message: Optional["DatabaseMessages"] = None,
|
|
||||||
storage_message: bool = True,
|
|
||||||
) -> bool:
|
|
||||||
"""发送自定义类型消息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message_type: 消息类型,如"video"、"file"、"audio"等
|
|
||||||
content: 消息内容
|
|
||||||
typing: 是否显示正在输入
|
|
||||||
set_reply: 是否作为回复发送
|
|
||||||
reply_message: 回复的消息对象(set_reply 为 True时必填)
|
|
||||||
storage_message: 是否存储消息到数据库
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否发送成功
|
|
||||||
"""
|
|
||||||
if not self.chat_id:
|
|
||||||
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
|
||||||
return False
|
|
||||||
|
|
||||||
return await send_api.custom_to_stream(
|
|
||||||
message_type=message_type,
|
|
||||||
content=content,
|
|
||||||
stream_id=self.chat_id,
|
|
||||||
typing=typing,
|
|
||||||
set_reply=set_reply,
|
|
||||||
reply_message=reply_message,
|
|
||||||
storage_message=storage_message,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_hybrid(
|
|
||||||
self,
|
|
||||||
message_tuple_list: List[Tuple[ReplyContentType | str, str]],
|
|
||||||
typing: bool = False,
|
|
||||||
set_reply: bool = False,
|
|
||||||
reply_message: Optional["DatabaseMessages"] = None,
|
|
||||||
storage_message: bool = True,
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
发送混合类型消息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message_tuple_list: 包含消息类型和内容的元组列表,格式为 [(内容类型, 内容), ...]
|
|
||||||
typing: 是否计算打字时间
|
|
||||||
set_reply: 是否作为回复发送
|
|
||||||
reply_message: 回复的消息对象
|
|
||||||
"""
|
|
||||||
if not self.chat_id:
|
|
||||||
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
|
||||||
return False
|
|
||||||
reply_set = ReplySetModel()
|
|
||||||
reply_set.add_hybrid_content_by_raw(message_tuple_list)
|
|
||||||
return await send_api.custom_reply_set_to_stream(
|
|
||||||
reply_set=reply_set,
|
|
||||||
stream_id=self.chat_id,
|
|
||||||
typing=typing,
|
|
||||||
set_reply=set_reply,
|
|
||||||
reply_message=reply_message,
|
|
||||||
storage_message=storage_message,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_forward(
|
|
||||||
self,
|
|
||||||
messages_list: List[Tuple[str, str, List[Tuple[ReplyContentType | str, str]]] | str],
|
|
||||||
storage_message: bool = True,
|
|
||||||
) -> bool:
|
|
||||||
"""转发消息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages_list: 包含消息信息的列表,当传入自行生成的数据时,元素格式为 (sender_id, nickname, 消息体);当传入消息ID时,元素格式为 "message_id"
|
|
||||||
其中消息体的格式为 [(内容类型, 内容), ...]
|
|
||||||
任意长度的消息都需要使用列表的形式传入
|
|
||||||
storage_message: 是否存储消息到数据库
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否发送成功
|
|
||||||
"""
|
|
||||||
if not self.chat_id:
|
|
||||||
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
|
||||||
return False
|
|
||||||
reply_set = ReplySetModel()
|
|
||||||
forward_message_nodes: List[ForwardNode] = []
|
|
||||||
for message in messages_list:
|
|
||||||
if isinstance(message, str):
|
|
||||||
forward_message_node = ForwardNode.construct_as_id_reference(message)
|
|
||||||
elif isinstance(message, Tuple) and len(message) == 3:
|
|
||||||
sender_id, nickname, content_list = message
|
|
||||||
single_node_content_list: List[ReplyContent] = []
|
|
||||||
for node_content_type, node_content in content_list:
|
|
||||||
reply_node_content = ReplyContent(content_type=node_content_type, content=node_content)
|
|
||||||
single_node_content_list.append(reply_node_content)
|
|
||||||
forward_message_node = ForwardNode.construct_as_created_node(
|
|
||||||
user_id=sender_id, user_nickname=nickname, content=single_node_content_list
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.warning(f"{self.log_prefix} 转发消息时遇到无效的消息格式: {message}")
|
|
||||||
continue
|
|
||||||
forward_message_nodes.append(forward_message_node)
|
|
||||||
reply_set.add_forward_content(forward_message_nodes)
|
|
||||||
return await send_api.custom_reply_set_to_stream(
|
|
||||||
reply_set=reply_set,
|
|
||||||
stream_id=self.chat_id,
|
|
||||||
storage_message=storage_message,
|
|
||||||
set_reply=False,
|
|
||||||
reply_message=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_voice(self, audio_base64: str) -> bool:
|
|
||||||
"""
|
|
||||||
发送语音消息
|
|
||||||
Args:
|
|
||||||
audio_base64: 语音的base64编码
|
|
||||||
Returns:
|
|
||||||
bool: 是否发送成功
|
|
||||||
"""
|
|
||||||
if not audio_base64:
|
|
||||||
logger.error(f"{self.log_prefix} 缺少音频内容")
|
|
||||||
return False
|
|
||||||
reply_set = ReplySetModel()
|
|
||||||
reply_set.add_voice_content(audio_base64)
|
|
||||||
return await send_api.custom_reply_set_to_stream(
|
|
||||||
reply_set=reply_set,
|
|
||||||
stream_id=self.chat_id,
|
|
||||||
storage_message=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def store_action_info(
|
|
||||||
self,
|
|
||||||
action_build_into_prompt: bool = False,
|
|
||||||
action_prompt_display: str = "",
|
|
||||||
action_done: bool = True,
|
|
||||||
) -> None:
|
|
||||||
"""存储动作信息到数据库
|
|
||||||
|
|
||||||
Args:
|
|
||||||
action_build_into_prompt: 是否构建到提示中
|
|
||||||
action_prompt_display: 显示的action提示信息
|
|
||||||
action_done: action是否完成
|
|
||||||
"""
|
|
||||||
await database_api.store_action_info(
|
|
||||||
chat_stream=self.chat_stream,
|
|
||||||
action_build_into_prompt=action_build_into_prompt,
|
|
||||||
action_prompt_display=action_prompt_display,
|
|
||||||
action_done=action_done,
|
|
||||||
thinking_id=self.thinking_id,
|
|
||||||
action_data=self.action_data,
|
|
||||||
action_name=self.action_name,
|
|
||||||
action_reasoning=self.action_reasoning,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def wait_for_new_message(self, timeout: int = 1200) -> Tuple[bool, str]:
|
|
||||||
"""等待新消息或超时
|
|
||||||
|
|
||||||
在loop_start_time之后等待新消息,如果没有新消息且没有超时,就一直等待。
|
|
||||||
使用message_api检查self.chat_id对应的聊天中是否有新消息。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
timeout: 超时时间(秒),默认1200秒
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[bool, str]: (是否收到新消息, 空字符串)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 获取循环开始时间,如果没有则使用当前时间
|
|
||||||
loop_start_time = self.action_data.get("loop_start_time", time.time())
|
|
||||||
logger.info(f"{self.log_prefix} 开始等待新消息... (最长等待: {timeout}秒, 从时间点: {loop_start_time})")
|
|
||||||
|
|
||||||
# 确保有有效的chat_id
|
|
||||||
if not self.chat_id:
|
|
||||||
logger.error(f"{self.log_prefix} 等待新消息失败: 没有有效的chat_id")
|
|
||||||
return False, "没有有效的chat_id"
|
|
||||||
|
|
||||||
wait_start_time = asyncio.get_event_loop().time()
|
|
||||||
while True:
|
|
||||||
# 检查新消息
|
|
||||||
current_time = time.time()
|
|
||||||
new_message_count = message_api.count_new_messages(
|
|
||||||
chat_id=self.chat_id, start_time=loop_start_time, end_time=current_time
|
|
||||||
)
|
|
||||||
|
|
||||||
if new_message_count > 0:
|
|
||||||
logger.info(f"{self.log_prefix} 检测到{new_message_count}条新消息,聊天ID: {self.chat_id}")
|
|
||||||
return True, ""
|
|
||||||
|
|
||||||
# 检查超时
|
|
||||||
elapsed_time = asyncio.get_event_loop().time() - wait_start_time
|
|
||||||
if elapsed_time > timeout:
|
|
||||||
logger.warning(f"{self.log_prefix} 等待新消息超时({timeout}秒),聊天ID: {self.chat_id}")
|
|
||||||
return False, ""
|
|
||||||
|
|
||||||
# 每30秒记录一次等待状态
|
|
||||||
if int(elapsed_time) % 15 == 0 and int(elapsed_time) > 0:
|
|
||||||
logger.debug(f"{self.log_prefix} 已等待{int(elapsed_time)}秒,继续等待新消息...")
|
|
||||||
|
|
||||||
# 短暂休眠
|
|
||||||
await asyncio.sleep(0.5)
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
logger.info(f"{self.log_prefix} 等待新消息被中断 (CancelledError)")
|
|
||||||
return False, ""
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"{self.log_prefix} 等待新消息时发生错误: {e}")
|
|
||||||
return False, f"等待新消息失败: {str(e)}"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_action_info(cls) -> "ActionInfo":
|
|
||||||
"""从类属性生成ActionInfo
|
|
||||||
|
|
||||||
所有信息都从类属性中读取,确保一致性和完整性。
|
|
||||||
Action类必须定义所有必要的类属性。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ActionInfo: 生成的Action信息对象
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 从类属性读取名称,如果没有定义则使用类名自动生成
|
|
||||||
name = getattr(cls, "action_name", cls.__name__.lower().replace("action", ""))
|
|
||||||
if "." in name:
|
|
||||||
logger.error(f"Action名称 '{name}' 包含非法字符 '.',请使用下划线替代")
|
|
||||||
raise ValueError(f"Action名称 '{name}' 包含非法字符 '.',请使用下划线替代")
|
|
||||||
# 获取focus_activation_type和normal_activation_type
|
|
||||||
focus_activation_type = getattr(cls, "focus_activation_type", ActionActivationType.ALWAYS)
|
|
||||||
_normal_activation_type = getattr(cls, "normal_activation_type", ActionActivationType.ALWAYS)
|
|
||||||
|
|
||||||
# 处理activation_type:如果插件中声明了就用插件的值,否则默认使用focus_activation_type
|
|
||||||
activation_type = getattr(cls, "activation_type", focus_activation_type)
|
|
||||||
|
|
||||||
return ActionInfo(
|
|
||||||
name=name,
|
|
||||||
component_type=ComponentType.ACTION,
|
|
||||||
description=getattr(cls, "action_description", "Action动作"),
|
|
||||||
activation_type=activation_type,
|
|
||||||
activation_keywords=getattr(cls, "activation_keywords", []).copy(),
|
|
||||||
keyword_case_sensitive=getattr(cls, "keyword_case_sensitive", False),
|
|
||||||
parallel_action=getattr(cls, "parallel_action", True),
|
|
||||||
random_activation_probability=getattr(cls, "random_activation_probability", 0.0),
|
|
||||||
# 使用正确的字段名
|
|
||||||
action_parameters=getattr(cls, "action_parameters", {}).copy(),
|
|
||||||
action_require=getattr(cls, "action_require", []).copy(),
|
|
||||||
associated_types=getattr(cls, "associated_types", []).copy(),
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_config(self, key: str, default=None):
|
|
||||||
"""获取插件配置值,使用嵌套键访问
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key: 配置键名,使用嵌套访问如 "section.subsection.key"
|
|
||||||
default: 默认值
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Any: 配置值或默认值
|
|
||||||
"""
|
|
||||||
if not self.plugin_config:
|
|
||||||
return default
|
|
||||||
|
|
||||||
# 支持嵌套键访问
|
|
||||||
keys = key.split(".")
|
|
||||||
current = self.plugin_config
|
|
||||||
|
|
||||||
for k in keys:
|
|
||||||
if isinstance(current, dict) and k in current:
|
|
||||||
current = current[k]
|
|
||||||
else:
|
|
||||||
return default
|
|
||||||
|
|
||||||
return current
|
|
||||||
@@ -1,387 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Dict, Tuple, Optional, TYPE_CHECKING, List
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.common.data_models.message_data_model import ReplyContentType, ReplyContent, ReplySetModel, ForwardNode
|
|
||||||
from src.plugin_system.base.component_types import CommandInfo, ComponentType
|
|
||||||
from src.chat.message_receive.message import SessionMessage
|
|
||||||
from src.plugin_system.apis import send_api
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
|
||||||
|
|
||||||
logger = get_logger("base_command")
|
|
||||||
|
|
||||||
|
|
||||||
class BaseCommand(ABC):
|
|
||||||
"""Command组件基类
|
|
||||||
|
|
||||||
Command是插件的一种组件类型,用于处理命令请求
|
|
||||||
|
|
||||||
子类可以通过类属性定义命令模式:
|
|
||||||
- command_pattern: 命令匹配的正则表达式
|
|
||||||
- command_help: 命令帮助信息
|
|
||||||
- command_examples: 命令使用示例列表
|
|
||||||
"""
|
|
||||||
|
|
||||||
command_name: str = ""
|
|
||||||
"""Command组件的名称"""
|
|
||||||
command_description: str = ""
|
|
||||||
"""Command组件的描述"""
|
|
||||||
# 默认命令设置
|
|
||||||
command_pattern: str = r""
|
|
||||||
"""命令匹配的正则表达式"""
|
|
||||||
|
|
||||||
def __init__(self, message: SessionMessage, plugin_config: Optional[dict] = None):
|
|
||||||
"""初始化Command组件
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message: 接收到的消息对象
|
|
||||||
plugin_config: 插件配置字典
|
|
||||||
"""
|
|
||||||
self.message = message
|
|
||||||
self.matched_groups: Dict[str, str] = {} # 存储正则表达式匹配的命名组
|
|
||||||
self.plugin_config = plugin_config or {} # 直接存储插件配置字典
|
|
||||||
|
|
||||||
self.log_prefix = "[Command]"
|
|
||||||
|
|
||||||
logger.debug(f"{self.log_prefix} Command组件初始化完成")
|
|
||||||
|
|
||||||
def set_matched_groups(self, groups: Dict[str, str]) -> None:
|
|
||||||
"""设置正则表达式匹配的命名组
|
|
||||||
|
|
||||||
Args:
|
|
||||||
groups: 正则表达式匹配的命名组
|
|
||||||
"""
|
|
||||||
self.matched_groups = groups
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def execute(self) -> Tuple[bool, Optional[str], int]:
|
|
||||||
"""执行Command的抽象方法,子类必须实现
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[bool, Optional[str], int]: (是否执行成功, 可选的回复消息, 拦截消息力度,0代表不拦截,1代表仅不触发回复,replyer可见,2代表不触发回复,replyer不可见)
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get_config(self, key: str, default=None):
|
|
||||||
"""获取插件配置值,使用嵌套键访问
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key: 配置键名,使用嵌套访问如 "section.subsection.key"
|
|
||||||
default: 默认值
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Any: 配置值或默认值
|
|
||||||
"""
|
|
||||||
if not self.plugin_config:
|
|
||||||
return default
|
|
||||||
|
|
||||||
# 支持嵌套键访问
|
|
||||||
keys = key.split(".")
|
|
||||||
current = self.plugin_config
|
|
||||||
|
|
||||||
for k in keys:
|
|
||||||
if isinstance(current, dict) and k in current:
|
|
||||||
current = current[k]
|
|
||||||
else:
|
|
||||||
return default
|
|
||||||
|
|
||||||
return current
|
|
||||||
|
|
||||||
async def send_text(
|
|
||||||
self,
|
|
||||||
content: str,
|
|
||||||
set_reply: bool = False,
|
|
||||||
reply_message: Optional["DatabaseMessages"] = None,
|
|
||||||
storage_message: bool = True,
|
|
||||||
) -> bool:
|
|
||||||
"""发送回复消息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
content: 回复内容
|
|
||||||
set_reply: 是否作为回复发送
|
|
||||||
reply_message: 回复的消息对象(当set_reply为True时必填)
|
|
||||||
storage_message: 是否存储消息到数据库
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否发送成功
|
|
||||||
"""
|
|
||||||
# 获取聊天流信息
|
|
||||||
session_id = self.message.session_id
|
|
||||||
if not session_id:
|
|
||||||
logger.error(f"{self.log_prefix} 缺少session_id")
|
|
||||||
return False
|
|
||||||
|
|
||||||
return await send_api.text_to_stream(
|
|
||||||
text=content,
|
|
||||||
stream_id=session_id,
|
|
||||||
set_reply=set_reply,
|
|
||||||
reply_message=reply_message,
|
|
||||||
storage_message=storage_message,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_image(
|
|
||||||
self,
|
|
||||||
image_base64: str,
|
|
||||||
set_reply: bool = False,
|
|
||||||
reply_message: Optional["DatabaseMessages"] = None,
|
|
||||||
storage_message: bool = True,
|
|
||||||
) -> bool:
|
|
||||||
"""发送图片
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image_base64: 图片的base64编码
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否发送成功
|
|
||||||
"""
|
|
||||||
session_id = self.message.session_id
|
|
||||||
if not session_id:
|
|
||||||
logger.error(f"{self.log_prefix} 缺少session_id")
|
|
||||||
return False
|
|
||||||
|
|
||||||
return await send_api.image_to_stream(
|
|
||||||
image_base64,
|
|
||||||
session_id,
|
|
||||||
set_reply=set_reply,
|
|
||||||
reply_message=reply_message,
|
|
||||||
storage_message=storage_message,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_emoji(
|
|
||||||
self,
|
|
||||||
emoji_base64: str,
|
|
||||||
set_reply: bool = False,
|
|
||||||
reply_message: Optional["DatabaseMessages"] = None,
|
|
||||||
storage_message: bool = True,
|
|
||||||
) -> bool:
|
|
||||||
"""发送表情包
|
|
||||||
|
|
||||||
Args:
|
|
||||||
emoji_base64: 表情包的base64编码
|
|
||||||
set_reply: 是否作为回复发送
|
|
||||||
reply_message: 回复的消息对象(当set_reply为True时必填)
|
|
||||||
storage_message: 是否存储消息到数据库
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否发送成功
|
|
||||||
"""
|
|
||||||
session_id = self.message.session_id
|
|
||||||
if not session_id:
|
|
||||||
logger.error(f"{self.log_prefix} 缺少session_id")
|
|
||||||
return False
|
|
||||||
|
|
||||||
return await send_api.emoji_to_stream(
|
|
||||||
emoji_base64, session_id, set_reply=set_reply, reply_message=reply_message
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_command(
|
|
||||||
self,
|
|
||||||
command_name: str,
|
|
||||||
args: Optional[dict] = None,
|
|
||||||
display_message: str = "",
|
|
||||||
storage_message: bool = True,
|
|
||||||
) -> bool:
|
|
||||||
"""发送命令消息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
command_name: 命令名称
|
|
||||||
args: 命令参数
|
|
||||||
display_message: 显示消息
|
|
||||||
storage_message: 是否存储消息到数据库
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否发送成功
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 获取聊天流信息
|
|
||||||
session_id = self.message.session_id
|
|
||||||
if not session_id:
|
|
||||||
logger.error(f"{self.log_prefix} 缺少session_id")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 构造命令数据
|
|
||||||
command_data = {"name": command_name, "args": args or {}}
|
|
||||||
|
|
||||||
success = await send_api.command_to_stream(
|
|
||||||
command=command_data,
|
|
||||||
stream_id=session_id,
|
|
||||||
storage_message=storage_message,
|
|
||||||
display_message=display_message,
|
|
||||||
)
|
|
||||||
|
|
||||||
if success:
|
|
||||||
logger.info(f"{self.log_prefix} 成功发送命令: {command_name}")
|
|
||||||
else:
|
|
||||||
logger.error(f"{self.log_prefix} 发送命令失败: {command_name}")
|
|
||||||
|
|
||||||
return success
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"{self.log_prefix} 发送命令时出错: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def send_voice(self, voice_base64: str) -> bool:
|
|
||||||
"""
|
|
||||||
发送语音消息
|
|
||||||
Args:
|
|
||||||
voice_base64: 语音的base64编码
|
|
||||||
Returns:
|
|
||||||
bool: 是否发送成功
|
|
||||||
"""
|
|
||||||
session_id = self.message.session_id
|
|
||||||
if not session_id:
|
|
||||||
logger.error(f"{self.log_prefix} 缺少session_id")
|
|
||||||
return False
|
|
||||||
|
|
||||||
return await send_api.custom_to_stream(
|
|
||||||
message_type="voice",
|
|
||||||
content=voice_base64,
|
|
||||||
stream_id=session_id,
|
|
||||||
typing=False,
|
|
||||||
set_reply=False,
|
|
||||||
reply_message=None,
|
|
||||||
storage_message=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_hybrid(
|
|
||||||
self,
|
|
||||||
message_tuple_list: List[Tuple[ReplyContentType | str, str]],
|
|
||||||
typing: bool = False,
|
|
||||||
set_reply: bool = False,
|
|
||||||
reply_message: Optional["DatabaseMessages"] = None,
|
|
||||||
storage_message: bool = True,
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
发送混合类型消息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message_tuple_list: 包含消息类型和内容的元组列表,格式为 [(内容类型, 内容), ...]
|
|
||||||
typing: 是否显示正在输入
|
|
||||||
set_reply: 是否计算打字时间
|
|
||||||
reply_message: 回复的消息对象
|
|
||||||
storage_message: 是否存储消息到数据库
|
|
||||||
"""
|
|
||||||
session_id = self.message.session_id
|
|
||||||
if not session_id:
|
|
||||||
logger.error(f"{self.log_prefix} 缺少session_id")
|
|
||||||
return False
|
|
||||||
reply_set = ReplySetModel()
|
|
||||||
reply_set.add_hybrid_content_by_raw(message_tuple_list)
|
|
||||||
return await send_api.custom_reply_set_to_stream(
|
|
||||||
reply_set=reply_set,
|
|
||||||
stream_id=session_id,
|
|
||||||
typing=typing,
|
|
||||||
set_reply=set_reply,
|
|
||||||
reply_message=reply_message,
|
|
||||||
storage_message=storage_message,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_forward(
|
|
||||||
self,
|
|
||||||
messages_list: List[Tuple[str, str, List[Tuple[ReplyContentType | str, str]]] | str],
|
|
||||||
storage_message: bool = True,
|
|
||||||
) -> bool:
|
|
||||||
"""转发消息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages_list: 包含消息信息的列表,当传入自行生成的数据时,元素格式为 (sender_id, nickname, 消息体);当传入消息ID时,元素格式为 "message_id"
|
|
||||||
其中消息体的格式为 [(内容类型, 内容), ...]
|
|
||||||
任意长度的消息都需要使用列表的形式传入
|
|
||||||
storage_message: 是否存储消息到数据库
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否发送成功
|
|
||||||
"""
|
|
||||||
session_id = self.message.session_id
|
|
||||||
if not session_id:
|
|
||||||
logger.error(f"{self.log_prefix} 缺少session_id")
|
|
||||||
return False
|
|
||||||
reply_set = ReplySetModel()
|
|
||||||
forward_message_nodes: List[ForwardNode] = []
|
|
||||||
for message in messages_list:
|
|
||||||
if isinstance(message, str):
|
|
||||||
forward_message_node = ForwardNode.construct_as_id_reference(message)
|
|
||||||
elif isinstance(message, Tuple) and len(message) == 3:
|
|
||||||
sender_id, nickname, content_list = message
|
|
||||||
single_node_content_list: List[ReplyContent] = []
|
|
||||||
for node_content_type, node_content in content_list:
|
|
||||||
reply_node_content = ReplyContent(content_type=node_content_type, content=node_content)
|
|
||||||
single_node_content_list.append(reply_node_content)
|
|
||||||
forward_message_node = ForwardNode.construct_as_created_node(
|
|
||||||
user_id=sender_id, user_nickname=nickname, content=single_node_content_list
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.warning(f"{self.log_prefix} 转发消息时遇到无效的消息格式: {message}")
|
|
||||||
continue
|
|
||||||
forward_message_nodes.append(forward_message_node)
|
|
||||||
reply_set.add_forward_content(forward_message_nodes)
|
|
||||||
return await send_api.custom_reply_set_to_stream(
|
|
||||||
reply_set=reply_set,
|
|
||||||
stream_id=session_id,
|
|
||||||
storage_message=storage_message,
|
|
||||||
set_reply=False,
|
|
||||||
reply_message=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_custom(
|
|
||||||
self,
|
|
||||||
message_type: str,
|
|
||||||
content: str | Dict,
|
|
||||||
display_message: str = "",
|
|
||||||
typing: bool = False,
|
|
||||||
set_reply: bool = False,
|
|
||||||
reply_message: Optional["DatabaseMessages"] = None,
|
|
||||||
storage_message: bool = True,
|
|
||||||
) -> bool:
|
|
||||||
"""发送指定类型的回复消息到当前聊天环境
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message_type: 消息类型,如"text"、"image"、"emoji"、"voice"等
|
|
||||||
content: 消息内容
|
|
||||||
display_message: 显示消息(可选)
|
|
||||||
typing: 是否显示正在输入
|
|
||||||
set_reply: 是否作为回复发送
|
|
||||||
reply_message: 回复的消息对象(set_reply 为 True时必填)
|
|
||||||
storage_message: 是否存储消息到数据库
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否发送成功
|
|
||||||
"""
|
|
||||||
# 获取聊天流信息
|
|
||||||
session_id = self.message.session_id
|
|
||||||
if not session_id:
|
|
||||||
logger.error(f"{self.log_prefix} 缺少session_id")
|
|
||||||
return False
|
|
||||||
|
|
||||||
return await send_api.custom_to_stream(
|
|
||||||
message_type=message_type,
|
|
||||||
content=content,
|
|
||||||
stream_id=session_id,
|
|
||||||
display_message=display_message,
|
|
||||||
typing=typing,
|
|
||||||
set_reply=set_reply,
|
|
||||||
reply_message=reply_message,
|
|
||||||
storage_message=storage_message,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_command_info(cls) -> "CommandInfo":
|
|
||||||
"""从类属性生成CommandInfo
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: Command名称,如果不提供则使用类名
|
|
||||||
description: Command描述,如果不提供则使用类文档字符串
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
CommandInfo: 生成的Command信息对象
|
|
||||||
"""
|
|
||||||
if "." in cls.command_name:
|
|
||||||
logger.error(f"Command名称 '{cls.command_name}' 包含非法字符 '.',请使用下划线替代")
|
|
||||||
raise ValueError(f"Command名称 '{cls.command_name}' 包含非法字符 '.',请使用下划线替代")
|
|
||||||
return CommandInfo(
|
|
||||||
name=cls.command_name,
|
|
||||||
component_type=ComponentType.COMMAND,
|
|
||||||
description=cls.command_description,
|
|
||||||
command_pattern=cls.command_pattern,
|
|
||||||
)
|
|
||||||
@@ -1,381 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Tuple, Optional, Dict, List, TYPE_CHECKING
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.common.data_models.message_data_model import ReplyContentType, ReplySetModel, ReplyContent, ForwardNode
|
|
||||||
from src.plugin_system.apis import send_api
|
|
||||||
from .component_types import MaiMessages, EventType, EventHandlerInfo, ComponentType, CustomEventHandlerResult
|
|
||||||
|
|
||||||
logger = get_logger("base_event_handler")
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
|
||||||
|
|
||||||
|
|
||||||
class BaseEventHandler(ABC):
|
|
||||||
"""事件处理器基类
|
|
||||||
|
|
||||||
所有事件处理器都应该继承这个基类,提供事件处理的基本接口
|
|
||||||
"""
|
|
||||||
|
|
||||||
event_type: EventType | str = EventType.UNKNOWN
|
|
||||||
"""事件类型,默认为未知"""
|
|
||||||
handler_name: str = ""
|
|
||||||
"""处理器名称"""
|
|
||||||
handler_description: str = ""
|
|
||||||
"""处理器描述"""
|
|
||||||
weight: int = 0
|
|
||||||
"""处理器权重,越大权重越高"""
|
|
||||||
intercept_message: bool = False
|
|
||||||
"""是否拦截消息,默认为否"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.log_prefix = "[EventHandler]"
|
|
||||||
self.plugin_name = ""
|
|
||||||
"""对应插件名"""
|
|
||||||
self.plugin_config: Optional[Dict] = None
|
|
||||||
"""插件配置字典"""
|
|
||||||
if self.event_type == EventType.UNKNOWN:
|
|
||||||
raise NotImplementedError("事件处理器必须指定 event_type")
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def execute(
|
|
||||||
self, message: MaiMessages | None
|
|
||||||
) -> Tuple[bool, bool, Optional[str], Optional[CustomEventHandlerResult], Optional[MaiMessages]]:
|
|
||||||
"""执行事件处理的抽象方法,子类必须实现
|
|
||||||
Args:
|
|
||||||
message (MaiMessages | None): 事件消息对象,当你注册的事件为ON_START和ON_STOP时message为None
|
|
||||||
Returns:
|
|
||||||
Tuple[bool, bool, Optional[str], Optional[CustomEventHandlerResult], Optional[MaiMessages]]: (是否执行成功, 是否需要继续处理, 可选的返回消息, 可选的自定义结果,可选的修改后消息)
|
|
||||||
"""
|
|
||||||
raise NotImplementedError("子类必须实现 execute 方法")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_handler_info(cls) -> "EventHandlerInfo":
|
|
||||||
"""获取事件处理器的信息"""
|
|
||||||
# 从类属性读取名称,如果没有定义则使用类名自动生成S
|
|
||||||
name: str = getattr(cls, "handler_name", cls.__name__.lower().replace("handler", ""))
|
|
||||||
if "." in name:
|
|
||||||
logger.error(f"事件处理器名称 '{name}' 包含非法字符 '.',请使用下划线替代")
|
|
||||||
raise ValueError(f"事件处理器名称 '{name}' 包含非法字符 '.',请使用下划线替代")
|
|
||||||
return EventHandlerInfo(
|
|
||||||
name=name,
|
|
||||||
component_type=ComponentType.EVENT_HANDLER,
|
|
||||||
description=getattr(cls, "handler_description", "events处理器"),
|
|
||||||
event_type=cls.event_type,
|
|
||||||
weight=cls.weight,
|
|
||||||
intercept_message=cls.intercept_message,
|
|
||||||
)
|
|
||||||
|
|
||||||
def set_plugin_config(self, plugin_config: Dict) -> None:
|
|
||||||
"""设置插件配置
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plugin_config (dict): 插件配置字典
|
|
||||||
"""
|
|
||||||
self.plugin_config = plugin_config
|
|
||||||
|
|
||||||
def set_plugin_name(self, plugin_name: str) -> None:
|
|
||||||
"""设置插件名称
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plugin_name (str): 插件名称
|
|
||||||
"""
|
|
||||||
self.plugin_name = plugin_name
|
|
||||||
|
|
||||||
def get_config(self, key: str, default=None):
|
|
||||||
"""获取插件配置值,支持嵌套键访问
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key: 配置键名,支持嵌套访问如 "section.subsection.key"
|
|
||||||
default: 默认值
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Any: 配置值或默认值
|
|
||||||
"""
|
|
||||||
if not self.plugin_config:
|
|
||||||
return default
|
|
||||||
|
|
||||||
# 支持嵌套键访问
|
|
||||||
keys = key.split(".")
|
|
||||||
current = self.plugin_config
|
|
||||||
|
|
||||||
for k in keys:
|
|
||||||
if isinstance(current, dict) and k in current:
|
|
||||||
current = current[k]
|
|
||||||
else:
|
|
||||||
return default
|
|
||||||
|
|
||||||
return current
|
|
||||||
|
|
||||||
async def send_text(
|
|
||||||
self,
|
|
||||||
stream_id: str,
|
|
||||||
text: str,
|
|
||||||
set_reply: bool = False,
|
|
||||||
reply_message: Optional["DatabaseMessages"] = None,
|
|
||||||
typing: bool = False,
|
|
||||||
storage_message: bool = True,
|
|
||||||
) -> bool:
|
|
||||||
"""发送文本消息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
stream_id: 聊天ID
|
|
||||||
text: 文本内容
|
|
||||||
set_reply: 是否作为回复发送
|
|
||||||
reply_message: 回复的消息对象(当set_reply为True时必填)
|
|
||||||
typing: 是否计算输入时间
|
|
||||||
storage_message: 是否存储消息到数据库
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否发送成功
|
|
||||||
"""
|
|
||||||
if not stream_id:
|
|
||||||
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
|
||||||
return False
|
|
||||||
return await send_api.text_to_stream(
|
|
||||||
text=text,
|
|
||||||
stream_id=stream_id,
|
|
||||||
set_reply=set_reply,
|
|
||||||
reply_message=reply_message,
|
|
||||||
typing=typing,
|
|
||||||
storage_message=storage_message,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_emoji(
|
|
||||||
self,
|
|
||||||
stream_id: str,
|
|
||||||
emoji_base64: str,
|
|
||||||
set_reply: bool = False,
|
|
||||||
reply_message: Optional["DatabaseMessages"] = None,
|
|
||||||
storage_message: bool = True,
|
|
||||||
) -> bool:
|
|
||||||
"""发送表情消息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
emoji_base64: 表情的Base64编码
|
|
||||||
stream_id: 聊天ID
|
|
||||||
set_reply: 是否作为回复发送
|
|
||||||
reply_message: 回复的消息对象(当set_reply为True时必填)
|
|
||||||
storage_message: 是否存储消息到数据库
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否发送成功
|
|
||||||
"""
|
|
||||||
if not stream_id:
|
|
||||||
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
|
||||||
return False
|
|
||||||
return await send_api.emoji_to_stream(
|
|
||||||
emoji_base64=emoji_base64,
|
|
||||||
stream_id=stream_id,
|
|
||||||
set_reply=set_reply,
|
|
||||||
reply_message=reply_message,
|
|
||||||
storage_message=storage_message,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_image(
|
|
||||||
self,
|
|
||||||
stream_id: str,
|
|
||||||
image_base64: str,
|
|
||||||
set_reply: bool = False,
|
|
||||||
reply_message: Optional["DatabaseMessages"] = None,
|
|
||||||
storage_message: bool = True,
|
|
||||||
) -> bool:
|
|
||||||
"""发送图片消息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image_base64: 图片的Base64编码
|
|
||||||
stream_id: 聊天ID
|
|
||||||
set_reply: 是否作为回复发送
|
|
||||||
reply_message: 回复的消息对象(当set_reply为True时必填)
|
|
||||||
storage_message: 是否存储消息到数据库
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否发送成功
|
|
||||||
"""
|
|
||||||
if not stream_id:
|
|
||||||
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
|
||||||
return False
|
|
||||||
return await send_api.image_to_stream(
|
|
||||||
image_base64=image_base64,
|
|
||||||
stream_id=stream_id,
|
|
||||||
set_reply=set_reply,
|
|
||||||
reply_message=reply_message,
|
|
||||||
storage_message=storage_message,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_voice(
|
|
||||||
self,
|
|
||||||
stream_id: str,
|
|
||||||
audio_base64: str,
|
|
||||||
) -> bool:
|
|
||||||
"""发送语音消息
|
|
||||||
Args:
|
|
||||||
stream_id: 聊天ID
|
|
||||||
audio_base64: 语音的Base64编码
|
|
||||||
Returns:
|
|
||||||
bool: 是否发送成功
|
|
||||||
"""
|
|
||||||
if not stream_id:
|
|
||||||
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
|
||||||
return False
|
|
||||||
reply_set = ReplySetModel()
|
|
||||||
reply_set.add_voice_content(audio_base64)
|
|
||||||
return await send_api.custom_reply_set_to_stream(
|
|
||||||
reply_set=reply_set,
|
|
||||||
stream_id=stream_id,
|
|
||||||
storage_message=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_command(
|
|
||||||
self,
|
|
||||||
stream_id: str,
|
|
||||||
command_name: str,
|
|
||||||
command_args: Optional[dict] = None,
|
|
||||||
display_message: str = "",
|
|
||||||
storage_message: bool = True,
|
|
||||||
) -> bool:
|
|
||||||
"""发送命令消息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
stream_id: 流ID
|
|
||||||
command_name: 命令名称
|
|
||||||
command_args: 命令参数字典
|
|
||||||
display_message: 显示消息
|
|
||||||
storage_message: 是否存储消息到数据库
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否发送成功
|
|
||||||
"""
|
|
||||||
if not stream_id:
|
|
||||||
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 构造命令数据
|
|
||||||
command_data = {"name": command_name, "args": command_args or {}}
|
|
||||||
|
|
||||||
return await send_api.command_to_stream(
|
|
||||||
command=command_data,
|
|
||||||
stream_id=stream_id,
|
|
||||||
storage_message=storage_message,
|
|
||||||
display_message=display_message,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_custom(
|
|
||||||
self,
|
|
||||||
stream_id: str,
|
|
||||||
message_type: str,
|
|
||||||
content: str | Dict,
|
|
||||||
typing: bool = False,
|
|
||||||
set_reply: bool = False,
|
|
||||||
reply_message: Optional["DatabaseMessages"] = None,
|
|
||||||
storage_message: bool = True,
|
|
||||||
) -> bool:
|
|
||||||
"""发送自定义消息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
stream_id: 聊天ID
|
|
||||||
message_type: 消息类型
|
|
||||||
content: 消息内容,可以是字符串或字典
|
|
||||||
typing: 是否显示正在输入状态
|
|
||||||
set_reply: 是否作为回复发送
|
|
||||||
reply_message: 回复的消息对象(当set_reply为True时必填)
|
|
||||||
storage_message: 是否存储消息到数据库
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否发送成功
|
|
||||||
"""
|
|
||||||
if not stream_id:
|
|
||||||
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
|
||||||
return False
|
|
||||||
return await send_api.custom_to_stream(
|
|
||||||
message_type=message_type,
|
|
||||||
content=content,
|
|
||||||
stream_id=stream_id,
|
|
||||||
typing=typing,
|
|
||||||
set_reply=set_reply,
|
|
||||||
reply_message=reply_message,
|
|
||||||
storage_message=storage_message,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_hybrid(
|
|
||||||
self,
|
|
||||||
stream_id: str,
|
|
||||||
message_tuple_list: List[Tuple[ReplyContentType | str, str]],
|
|
||||||
typing: bool = False,
|
|
||||||
set_reply: bool = False,
|
|
||||||
reply_message: Optional["DatabaseMessages"] = None,
|
|
||||||
storage_message: bool = True,
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
发送混合类型消息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
stream_id: 流ID
|
|
||||||
message_tuple_list: 包含消息类型和内容的元组列表,格式为 [(内容类型, 内容), ...]
|
|
||||||
typing: 是否计算打字时间
|
|
||||||
set_reply: 是否作为回复发送
|
|
||||||
reply_message: 回复的消息对象
|
|
||||||
storage_message: 是否存储消息到数据库
|
|
||||||
"""
|
|
||||||
if not stream_id:
|
|
||||||
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
|
||||||
return False
|
|
||||||
reply_set = ReplySetModel()
|
|
||||||
reply_set.add_hybrid_content_by_raw(message_tuple_list)
|
|
||||||
return await send_api.custom_reply_set_to_stream(
|
|
||||||
reply_set=reply_set,
|
|
||||||
stream_id=stream_id,
|
|
||||||
typing=typing,
|
|
||||||
set_reply=set_reply,
|
|
||||||
reply_message=reply_message,
|
|
||||||
storage_message=storage_message,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_forward(
|
|
||||||
self,
|
|
||||||
stream_id: str,
|
|
||||||
messages_list: List[Tuple[str, str, List[Tuple[ReplyContentType | str, str]]] | str],
|
|
||||||
storage_message: bool = True,
|
|
||||||
) -> bool:
|
|
||||||
"""转发消息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
stream_id: 聊天ID
|
|
||||||
messages_list: 包含消息信息的列表,当传入自行生成的数据时,元素格式为 (sender_id, nickname, 消息体);当传入消息ID时,元素格式为 "message_id"
|
|
||||||
其中消息体的格式为 [(内容类型, 内容), ...]
|
|
||||||
任意长度的消息都需要使用列表的形式传入
|
|
||||||
storage_message: 是否存储消息到数据库
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否发送成功
|
|
||||||
"""
|
|
||||||
if not stream_id:
|
|
||||||
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
|
||||||
return False
|
|
||||||
reply_set = ReplySetModel()
|
|
||||||
forward_message_nodes: List[ForwardNode] = []
|
|
||||||
for message in messages_list:
|
|
||||||
if isinstance(message, str):
|
|
||||||
forward_message_node = ForwardNode.construct_as_id_reference(message)
|
|
||||||
elif isinstance(message, Tuple) and len(message) == 3:
|
|
||||||
sender_id, nickname, content_list = message
|
|
||||||
single_node_content_list: List[ReplyContent] = []
|
|
||||||
for node_content_type, node_content in content_list:
|
|
||||||
reply_node_content = ReplyContent(content_type=node_content_type, content=node_content)
|
|
||||||
single_node_content_list.append(reply_node_content)
|
|
||||||
forward_message_node = ForwardNode.construct_as_created_node(
|
|
||||||
user_id=sender_id, user_nickname=nickname, content=single_node_content_list
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.warning(f"{self.log_prefix} 转发消息时遇到无效的消息格式: {message}")
|
|
||||||
continue
|
|
||||||
forward_message_nodes.append(forward_message_node)
|
|
||||||
reply_set.add_forward_content(forward_message_nodes)
|
|
||||||
return await send_api.custom_reply_set_to_stream(
|
|
||||||
reply_set=reply_set,
|
|
||||||
stream_id=stream_id,
|
|
||||||
storage_message=storage_message,
|
|
||||||
set_reply=False,
|
|
||||||
reply_message=None,
|
|
||||||
)
|
|
||||||
@@ -1,102 +0,0 @@
|
|||||||
from abc import abstractmethod
|
|
||||||
from typing import Any, Callable, List, Type, Tuple, Union
|
|
||||||
from .plugin_base import PluginBase
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.plugin_system.base.component_types import ActionInfo, CommandInfo, EventHandlerInfo, ToolInfo
|
|
||||||
from src.plugin_system.base.workflow_types import WorkflowStepInfo
|
|
||||||
from .base_action import BaseAction
|
|
||||||
from .base_command import BaseCommand
|
|
||||||
from .base_events_handler import BaseEventHandler
|
|
||||||
from .base_tool import BaseTool
|
|
||||||
|
|
||||||
logger = get_logger("base_plugin")
|
|
||||||
|
|
||||||
|
|
||||||
class BasePlugin(PluginBase):
|
|
||||||
"""基于Action和Command的插件基类
|
|
||||||
|
|
||||||
所有上述类型的插件都应该继承这个基类,一个插件可以包含多种组件:
|
|
||||||
- Action组件:处理聊天中的动作
|
|
||||||
- Command组件:处理命令请求
|
|
||||||
- 未来可扩展:Scheduler、Listener等
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_plugin_components(
|
|
||||||
self,
|
|
||||||
) -> List[
|
|
||||||
Union[
|
|
||||||
Tuple[ActionInfo, Type[BaseAction]],
|
|
||||||
Tuple[CommandInfo, Type[BaseCommand]],
|
|
||||||
Tuple[EventHandlerInfo, Type[BaseEventHandler]],
|
|
||||||
Tuple[ToolInfo, Type[BaseTool]],
|
|
||||||
]
|
|
||||||
]:
|
|
||||||
"""获取插件包含的组件列表
|
|
||||||
|
|
||||||
子类必须实现此方法,返回组件信息和组件类的列表
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[tuple[ComponentInfo, Type]]: [(组件信息, 组件类), ...]
|
|
||||||
"""
|
|
||||||
raise NotImplementedError("Subclasses must implement this method")
|
|
||||||
|
|
||||||
def get_workflow_steps(self) -> List[Tuple[WorkflowStepInfo, Callable[..., Any]]]:
|
|
||||||
"""获取插件包含的workflow steps。
|
|
||||||
|
|
||||||
默认返回空列表,子类可按需覆盖。
|
|
||||||
"""
|
|
||||||
return []
|
|
||||||
|
|
||||||
def register_plugin(self) -> bool:
|
|
||||||
"""注册插件及其所有组件"""
|
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
|
||||||
|
|
||||||
components = self.get_plugin_components()
|
|
||||||
|
|
||||||
# 检查依赖
|
|
||||||
if not self._check_dependencies():
|
|
||||||
logger.error(f"{self.log_prefix} 依赖检查失败,跳过注册")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 注册所有组件
|
|
||||||
registered_components = []
|
|
||||||
for component_info, component_class in components:
|
|
||||||
component_info.plugin_name = self.plugin_name
|
|
||||||
if component_registry.register_component(component_info, component_class):
|
|
||||||
registered_components.append(component_info)
|
|
||||||
else:
|
|
||||||
logger.warning(f"{self.log_prefix} 组件 {component_info.name} 注册失败")
|
|
||||||
|
|
||||||
# 更新插件信息中的组件列表
|
|
||||||
self.plugin_info.components = registered_components
|
|
||||||
|
|
||||||
# 注册插件
|
|
||||||
if component_registry.register_plugin(self.plugin_info):
|
|
||||||
# 注册workflow steps(可选)
|
|
||||||
registered_step_count = 0
|
|
||||||
for step_info, step_handler in self.get_workflow_steps():
|
|
||||||
if not step_info.plugin_name:
|
|
||||||
step_info.plugin_name = self.plugin_name
|
|
||||||
elif step_info.plugin_name != self.plugin_name:
|
|
||||||
logger.warning(
|
|
||||||
f"{self.log_prefix} workflow step {step_info.name} 的plugin_name({step_info.plugin_name})与当前插件不一致,已覆盖为 {self.plugin_name}"
|
|
||||||
)
|
|
||||||
step_info.plugin_name = self.plugin_name
|
|
||||||
|
|
||||||
if component_registry.register_workflow_step(step_info, step_handler):
|
|
||||||
registered_step_count += 1
|
|
||||||
else:
|
|
||||||
logger.warning(f"{self.log_prefix} workflow step {step_info.full_name} 注册失败")
|
|
||||||
|
|
||||||
logger.debug(f"{self.log_prefix} 插件注册成功,包含 {len(registered_components)} 个组件")
|
|
||||||
if registered_step_count > 0:
|
|
||||||
logger.debug(f"{self.log_prefix} workflow steps 注册成功,数量: {registered_step_count}")
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
logger.error(f"{self.log_prefix} 插件注册失败")
|
|
||||||
return False
|
|
||||||
@@ -1,137 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Any, List, Optional, Tuple, TYPE_CHECKING
|
|
||||||
from rich.traceback import install
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.plugin_system.base.component_types import ComponentType, ToolInfo, ToolParamType
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from src.chat.message_receive.chat_manager import BotChatSession
|
|
||||||
|
|
||||||
install(extra_lines=3)
|
|
||||||
|
|
||||||
logger = get_logger("base_tool")
|
|
||||||
|
|
||||||
|
|
||||||
class BaseTool(ABC):
|
|
||||||
"""所有工具的基类"""
|
|
||||||
|
|
||||||
name: str = ""
|
|
||||||
"""工具的名称"""
|
|
||||||
description: str = ""
|
|
||||||
"""工具的描述"""
|
|
||||||
parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = []
|
|
||||||
"""工具的参数定义,为[("param_name", param_type, "description", required, enum_values)]格式
|
|
||||||
param_name: 参数名称
|
|
||||||
param_type: 参数类型
|
|
||||||
description: 参数描述
|
|
||||||
required: 是否必填
|
|
||||||
enum_values: 枚举值列表
|
|
||||||
例如: [("arg1", ToolParamType.STRING, "参数1描述", True, None), ("arg2", ToolParamType.INTEGER, "参数2描述", False, ["1", "2", "3"])]
|
|
||||||
"""
|
|
||||||
available_for_llm: bool = False
|
|
||||||
"""是否可供LLM使用"""
|
|
||||||
|
|
||||||
def __init__(self, plugin_config: Optional[dict] = None, chat_stream: Optional["BotChatSession"] = None):
|
|
||||||
"""初始化工具基类
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plugin_config: 插件配置字典
|
|
||||||
chat_stream: 聊天流对象,用于获取聊天上下文信息
|
|
||||||
"""
|
|
||||||
self.plugin_config = plugin_config or {} # 直接存储插件配置字典
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# 便捷属性 - 直接在初始化时获取常用聊天信息(与BaseAction保持一致)
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
# 获取聊天流对象
|
|
||||||
self.chat_stream = chat_stream
|
|
||||||
self.chat_id = self.chat_stream.session_id if self.chat_stream else None
|
|
||||||
self.platform = getattr(self.chat_stream, "platform", None) if self.chat_stream else None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_tool_definition(cls) -> dict[str, Any]:
|
|
||||||
"""获取工具定义,用于LLM工具调用
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: 工具定义字典
|
|
||||||
"""
|
|
||||||
if not cls.name or not cls.description or cls.parameters is None:
|
|
||||||
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
|
|
||||||
|
|
||||||
return {"name": cls.name, "description": cls.description, "parameters": cls.parameters}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_tool_info(cls) -> ToolInfo:
|
|
||||||
"""获取工具信息"""
|
|
||||||
if not cls.name or not cls.description or cls.parameters is None:
|
|
||||||
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
|
|
||||||
|
|
||||||
return ToolInfo(
|
|
||||||
name=cls.name,
|
|
||||||
tool_description=cls.description,
|
|
||||||
enabled=cls.available_for_llm,
|
|
||||||
tool_parameters=cls.parameters,
|
|
||||||
component_type=ComponentType.TOOL,
|
|
||||||
)
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""执行工具函数(供llm调用)
|
|
||||||
通过该方法,maicore会通过llm的tool call来调用工具
|
|
||||||
传入的是json格式的参数,符合parameters定义的格式
|
|
||||||
|
|
||||||
Args:
|
|
||||||
function_args: 工具调用参数
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: 工具执行结果
|
|
||||||
"""
|
|
||||||
raise NotImplementedError("子类必须实现execute方法")
|
|
||||||
|
|
||||||
async def direct_execute(self, **function_args: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""直接执行工具函数(供插件调用)
|
|
||||||
通过该方法,插件可以直接调用工具,而不需要传入字典格式的参数
|
|
||||||
插件可以直接调用此方法,用更加明了的方式传入参数
|
|
||||||
示例: result = await tool.direct_execute(arg1="参数",arg2="参数2")
|
|
||||||
|
|
||||||
工具开发者可以重写此方法以实现与llm调用差异化的执行逻辑
|
|
||||||
|
|
||||||
Args:
|
|
||||||
**function_args: 工具调用参数
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: 工具执行结果
|
|
||||||
"""
|
|
||||||
parameter_required = [param[0] for param in self.parameters if param[3]] # 获取所有必填参数名
|
|
||||||
for param_name in parameter_required:
|
|
||||||
if param_name not in function_args:
|
|
||||||
raise ValueError(f"工具类 {self.__class__.__name__} 缺少必要参数: {param_name}")
|
|
||||||
|
|
||||||
return await self.execute(function_args)
|
|
||||||
|
|
||||||
def get_config(self, key: str, default=None):
|
|
||||||
"""获取插件配置值,使用嵌套键访问
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key: 配置键名,使用嵌套访问如 "section.subsection.key"
|
|
||||||
default: 默认值
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Any: 配置值或默认值
|
|
||||||
"""
|
|
||||||
if not self.plugin_config:
|
|
||||||
return default
|
|
||||||
|
|
||||||
# 支持嵌套键访问
|
|
||||||
keys = key.split(".")
|
|
||||||
current = self.plugin_config
|
|
||||||
|
|
||||||
for k in keys:
|
|
||||||
if isinstance(current, dict) and k in current:
|
|
||||||
current = current[k]
|
|
||||||
else:
|
|
||||||
return default
|
|
||||||
|
|
||||||
return current
|
|
||||||
@@ -1,761 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Dict, List, Any, Union
|
|
||||||
import os
|
|
||||||
import inspect
|
|
||||||
import toml
|
|
||||||
import json
|
|
||||||
import shutil
|
|
||||||
import datetime
|
|
||||||
import re
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.plugin_system.base.component_types import (
|
|
||||||
PluginInfo,
|
|
||||||
PythonDependency,
|
|
||||||
)
|
|
||||||
from src.plugin_system.base.config_types import (
|
|
||||||
ConfigField,
|
|
||||||
ConfigSection,
|
|
||||||
ConfigLayout,
|
|
||||||
)
|
|
||||||
from src.plugin_system.utils.manifest_utils import ManifestValidator, VersionComparator
|
|
||||||
|
|
||||||
logger = get_logger("plugin_base")
|
|
||||||
|
|
||||||
|
|
||||||
class PluginBase(ABC):
|
|
||||||
"""插件总基类
|
|
||||||
|
|
||||||
所有衍生插件基类都应该继承自此类,这个类定义了插件的基本结构和行为。
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 插件基本信息(子类必须定义)
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def plugin_name(self) -> str:
|
|
||||||
return "" # 插件内部标识符(如 "hello_world_plugin")
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def enable_plugin(self) -> bool:
|
|
||||||
return True # 是否启用插件
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def dependencies(self) -> List[Union[str, Dict[str, Any]]]:
|
|
||||||
return [] # 依赖的其他插件
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def python_dependencies(self) -> List[PythonDependency]:
|
|
||||||
return [] # Python包依赖
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def config_file_name(self) -> str:
|
|
||||||
return "" # 配置文件名
|
|
||||||
|
|
||||||
# manifest文件相关
|
|
||||||
manifest_file_name: str = "_manifest.json" # manifest文件名
|
|
||||||
manifest_data: Dict[str, Any] = {} # manifest数据
|
|
||||||
|
|
||||||
# 配置定义
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def config_schema(self) -> Dict[str, Union[Dict[str, ConfigField], str]]:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
config_section_descriptions: Dict[str, Union[str, ConfigSection]] = {}
|
|
||||||
|
|
||||||
# 布局配置(可选,不定义则使用自动布局)
|
|
||||||
config_layout: ConfigLayout = None
|
|
||||||
|
|
||||||
def __init__(self, plugin_dir: str):
|
|
||||||
"""初始化插件
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plugin_dir: 插件目录路径,由插件管理器传递
|
|
||||||
"""
|
|
||||||
self.config: Dict[str, Any] = {} # 插件配置
|
|
||||||
self.plugin_dir = plugin_dir # 插件目录路径
|
|
||||||
self.log_prefix = f"[Plugin:{self.plugin_name}]"
|
|
||||||
|
|
||||||
# 加载manifest文件
|
|
||||||
self._load_manifest()
|
|
||||||
|
|
||||||
# 验证插件信息
|
|
||||||
self._validate_plugin_info()
|
|
||||||
|
|
||||||
# 加载插件配置
|
|
||||||
self._load_plugin_config()
|
|
||||||
|
|
||||||
# 从manifest获取显示信息
|
|
||||||
self.display_name = self.get_manifest_info("name", self.plugin_name)
|
|
||||||
self.plugin_version = self.get_manifest_info("version", "1.0.0")
|
|
||||||
self.plugin_description = self.get_manifest_info("description", "")
|
|
||||||
self.plugin_author = self._get_author_name()
|
|
||||||
|
|
||||||
# 创建插件信息对象
|
|
||||||
self.plugin_info = PluginInfo(
|
|
||||||
name=self.plugin_name,
|
|
||||||
display_name=self.display_name,
|
|
||||||
description=self.plugin_description,
|
|
||||||
version=self.plugin_version,
|
|
||||||
author=self.plugin_author,
|
|
||||||
enabled=self.enable_plugin,
|
|
||||||
is_built_in=False,
|
|
||||||
config_file=self.config_file_name or "",
|
|
||||||
dependencies=self._get_dependency_names(),
|
|
||||||
python_dependencies=self.python_dependencies.copy(),
|
|
||||||
# manifest相关信息
|
|
||||||
manifest_data=self.manifest_data.copy(),
|
|
||||||
license=self.get_manifest_info("license", ""),
|
|
||||||
homepage_url=self.get_manifest_info("homepage_url", ""),
|
|
||||||
repository_url=self.get_manifest_info("repository_url", ""),
|
|
||||||
keywords=self.get_manifest_info("keywords", []).copy() if self.get_manifest_info("keywords") else [],
|
|
||||||
categories=self.get_manifest_info("categories", []).copy() if self.get_manifest_info("categories") else [],
|
|
||||||
min_host_version=self.get_manifest_info("host_application.min_version", ""),
|
|
||||||
max_host_version=self.get_manifest_info("host_application.max_version", ""),
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(f"{self.log_prefix} 插件基类初始化完成")
|
|
||||||
|
|
||||||
def _validate_plugin_info(self):
|
|
||||||
"""验证插件基本信息"""
|
|
||||||
if not self.plugin_name:
|
|
||||||
raise ValueError(f"插件类 {self.__class__.__name__} 必须定义 plugin_name")
|
|
||||||
|
|
||||||
# 验证manifest中的必需信息
|
|
||||||
if not self.get_manifest_info("name"):
|
|
||||||
raise ValueError(f"插件 {self.plugin_name} 的manifest中缺少name字段")
|
|
||||||
if not self.get_manifest_info("description"):
|
|
||||||
raise ValueError(f"插件 {self.plugin_name} 的manifest中缺少description字段")
|
|
||||||
|
|
||||||
def _load_manifest(self): # sourcery skip: raise-from-previous-error
|
|
||||||
"""加载manifest文件(强制要求)"""
|
|
||||||
if not self.plugin_dir:
|
|
||||||
raise ValueError(f"{self.log_prefix} 没有插件目录路径,无法加载manifest")
|
|
||||||
|
|
||||||
manifest_path = os.path.join(self.plugin_dir, self.manifest_file_name)
|
|
||||||
|
|
||||||
if not os.path.exists(manifest_path):
|
|
||||||
error_msg = f"{self.log_prefix} 缺少必需的manifest文件: {manifest_path}"
|
|
||||||
logger.error(error_msg)
|
|
||||||
raise FileNotFoundError(error_msg)
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(manifest_path, "r", encoding="utf-8") as f:
|
|
||||||
self.manifest_data = json.load(f)
|
|
||||||
|
|
||||||
logger.debug(f"{self.log_prefix} 成功加载manifest文件: {manifest_path}")
|
|
||||||
|
|
||||||
# 验证manifest格式
|
|
||||||
self._validate_manifest()
|
|
||||||
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
error_msg = f"{self.log_prefix} manifest文件格式错误: {e}"
|
|
||||||
logger.error(error_msg)
|
|
||||||
raise ValueError(error_msg) # noqa
|
|
||||||
except IOError as e:
|
|
||||||
error_msg = f"{self.log_prefix} 读取manifest文件失败: {e}"
|
|
||||||
logger.error(error_msg)
|
|
||||||
raise IOError(error_msg) # noqa
|
|
||||||
|
|
||||||
def _get_author_name(self) -> str:
|
|
||||||
"""从manifest获取作者名称"""
|
|
||||||
author_info = self.get_manifest_info("author", {})
|
|
||||||
if isinstance(author_info, dict):
|
|
||||||
return author_info.get("name", "")
|
|
||||||
else:
|
|
||||||
return str(author_info) if author_info else ""
|
|
||||||
|
|
||||||
def _validate_manifest(self):
|
|
||||||
"""验证manifest文件格式(使用强化的验证器)"""
|
|
||||||
if not self.manifest_data:
|
|
||||||
raise ValueError(f"{self.log_prefix} manifest数据为空,验证失败")
|
|
||||||
|
|
||||||
validator = ManifestValidator()
|
|
||||||
is_valid = validator.validate_manifest(self.manifest_data)
|
|
||||||
|
|
||||||
# 记录验证结果
|
|
||||||
if validator.validation_errors or validator.validation_warnings:
|
|
||||||
report = validator.get_validation_report()
|
|
||||||
logger.info(f"{self.log_prefix} Manifest验证结果:\n{report}")
|
|
||||||
|
|
||||||
# 如果有验证错误,抛出异常
|
|
||||||
if not is_valid:
|
|
||||||
error_msg = f"{self.log_prefix} Manifest文件验证失败"
|
|
||||||
if validator.validation_errors:
|
|
||||||
error_msg += f": {'; '.join(validator.validation_errors)}"
|
|
||||||
raise ValueError(error_msg)
|
|
||||||
|
|
||||||
def get_manifest_info(self, key: str, default: Any = None) -> Any:
|
|
||||||
"""获取manifest信息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key: 信息键,支持点分割的嵌套键(如 "author.name")
|
|
||||||
default: 默认值
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Any: 对应的值
|
|
||||||
"""
|
|
||||||
if not self.manifest_data:
|
|
||||||
return default
|
|
||||||
|
|
||||||
keys = key.split(".")
|
|
||||||
value = self.manifest_data
|
|
||||||
|
|
||||||
for k in keys:
|
|
||||||
if isinstance(value, dict) and k in value:
|
|
||||||
value = value[k]
|
|
||||||
else:
|
|
||||||
return default
|
|
||||||
|
|
||||||
return value
|
|
||||||
|
|
||||||
def _format_toml_value(self, value: Any) -> str:
|
|
||||||
"""将Python值格式化为合法的TOML字符串"""
|
|
||||||
if isinstance(value, str):
|
|
||||||
return json.dumps(value, ensure_ascii=False)
|
|
||||||
if isinstance(value, bool):
|
|
||||||
return str(value).lower()
|
|
||||||
if isinstance(value, (int, float)):
|
|
||||||
return str(value)
|
|
||||||
if isinstance(value, list):
|
|
||||||
inner = ", ".join(self._format_toml_value(item) for item in value)
|
|
||||||
return f"[{inner}]"
|
|
||||||
if isinstance(value, dict):
|
|
||||||
items = [f"{k} = {self._format_toml_value(v)}" for k, v in value.items()]
|
|
||||||
return "{ " + ", ".join(items) + " }"
|
|
||||||
return json.dumps(value, ensure_ascii=False)
|
|
||||||
|
|
||||||
def _generate_and_save_default_config(self, config_file_path: str):
|
|
||||||
"""根据插件的Schema生成并保存默认配置文件"""
|
|
||||||
if not self.config_schema:
|
|
||||||
logger.debug(f"{self.log_prefix} 插件未定义config_schema,不生成配置文件")
|
|
||||||
return
|
|
||||||
|
|
||||||
toml_str = f"# {self.plugin_name} - 自动生成的配置文件\n"
|
|
||||||
plugin_description = self.get_manifest_info("description", "插件配置文件")
|
|
||||||
toml_str += f"# {plugin_description}\n\n"
|
|
||||||
|
|
||||||
# 遍历每个配置节
|
|
||||||
for section, fields in self.config_schema.items():
|
|
||||||
# 添加节描述
|
|
||||||
if section in self.config_section_descriptions:
|
|
||||||
toml_str += f"# {self.config_section_descriptions[section]}\n"
|
|
||||||
|
|
||||||
toml_str += f"[{section}]\n\n"
|
|
||||||
|
|
||||||
# 遍历节内的字段
|
|
||||||
if isinstance(fields, dict):
|
|
||||||
for field_name, field in fields.items():
|
|
||||||
if isinstance(field, ConfigField):
|
|
||||||
# 添加字段描述
|
|
||||||
toml_str += f"# {field.description}"
|
|
||||||
if field.required:
|
|
||||||
toml_str += " (必需)"
|
|
||||||
toml_str += "\n"
|
|
||||||
|
|
||||||
# 如果有示例值,添加示例
|
|
||||||
if field.example:
|
|
||||||
toml_str += f"# 示例: {field.example}\n"
|
|
||||||
|
|
||||||
# 如果有可选值,添加说明
|
|
||||||
if field.choices:
|
|
||||||
choices_str = ", ".join(map(str, field.choices))
|
|
||||||
toml_str += f"# 可选值: {choices_str}\n"
|
|
||||||
|
|
||||||
# 添加字段值
|
|
||||||
value = field.default
|
|
||||||
toml_str += f"{field_name} = {self._format_toml_value(value)}\n"
|
|
||||||
|
|
||||||
toml_str += "\n"
|
|
||||||
toml_str += "\n"
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(config_file_path, "w", encoding="utf-8") as f:
|
|
||||||
f.write(toml_str)
|
|
||||||
logger.info(f"{self.log_prefix} 已生成默认配置文件: {config_file_path}")
|
|
||||||
except IOError as e:
|
|
||||||
logger.error(f"{self.log_prefix} 保存默认配置文件失败: {e}", exc_info=True)
|
|
||||||
|
|
||||||
def _get_expected_config_version(self) -> str:
|
|
||||||
"""获取插件期望的配置版本号"""
|
|
||||||
# 从config_schema的plugin.config_version字段获取
|
|
||||||
if "plugin" in self.config_schema and isinstance(self.config_schema["plugin"], dict):
|
|
||||||
config_version_field = self.config_schema["plugin"].get("config_version")
|
|
||||||
if isinstance(config_version_field, ConfigField):
|
|
||||||
return config_version_field.default
|
|
||||||
return "1.0.0"
|
|
||||||
|
|
||||||
def _get_current_config_version(self, config: Dict[str, Any]) -> str:
|
|
||||||
"""从配置文件中获取当前版本号"""
|
|
||||||
if "plugin" in config and "config_version" in config["plugin"]:
|
|
||||||
return str(config["plugin"]["config_version"])
|
|
||||||
# 如果没有config_version字段,视为最早的版本
|
|
||||||
return "0.0.0"
|
|
||||||
|
|
||||||
def _backup_config_file(self, config_file_path: str) -> str:
|
|
||||||
"""备份配置文件"""
|
|
||||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
||||||
backup_path = f"{config_file_path}.backup_{timestamp}"
|
|
||||||
|
|
||||||
try:
|
|
||||||
shutil.copy2(config_file_path, backup_path)
|
|
||||||
logger.info(f"{self.log_prefix} 配置文件已备份到: {backup_path}")
|
|
||||||
return backup_path
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"{self.log_prefix} 备份配置文件失败: {e}")
|
|
||||||
return ""
|
|
||||||
|
|
||||||
def _migrate_config_values(self, old_config: Dict[str, Any], new_config: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""将旧配置值迁移到新配置结构中
|
|
||||||
|
|
||||||
Args:
|
|
||||||
old_config: 旧配置数据
|
|
||||||
new_config: 基于新schema生成的默认配置
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict[str, Any]: 迁移后的配置
|
|
||||||
"""
|
|
||||||
|
|
||||||
def migrate_section(
|
|
||||||
old_section: Dict[str, Any], new_section: Dict[str, Any], section_name: str
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""迁移单个配置节"""
|
|
||||||
result = new_section.copy()
|
|
||||||
|
|
||||||
for key, value in old_section.items():
|
|
||||||
if key in new_section:
|
|
||||||
# 特殊处理:config_version字段总是使用新版本
|
|
||||||
if section_name == "plugin" and key == "config_version":
|
|
||||||
# 保持新的版本号,不迁移旧值
|
|
||||||
logger.debug(
|
|
||||||
f"{self.log_prefix} 更新配置版本: {section_name}.{key} = {result[key]} (旧值: {value})"
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 键存在于新配置中,复制值
|
|
||||||
if isinstance(value, dict) and isinstance(new_section[key], dict):
|
|
||||||
# 递归处理嵌套字典
|
|
||||||
result[key] = migrate_section(value, new_section[key], f"{section_name}.{key}")
|
|
||||||
else:
|
|
||||||
result[key] = value
|
|
||||||
logger.debug(f"{self.log_prefix} 迁移配置: {section_name}.{key} = {value}")
|
|
||||||
else:
|
|
||||||
# 键在新配置中不存在,记录警告
|
|
||||||
logger.warning(f"{self.log_prefix} 配置项 {section_name}.{key} 在新版本中已被移除")
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
migrated_config = {}
|
|
||||||
|
|
||||||
# 迁移每个配置节
|
|
||||||
for section_name, new_section_data in new_config.items():
|
|
||||||
if (
|
|
||||||
section_name in old_config
|
|
||||||
and isinstance(old_config[section_name], dict)
|
|
||||||
and isinstance(new_section_data, dict)
|
|
||||||
):
|
|
||||||
migrated_config[section_name] = migrate_section(
|
|
||||||
old_config[section_name], new_section_data, section_name
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 新增的节或类型不匹配,使用默认值
|
|
||||||
migrated_config[section_name] = new_section_data
|
|
||||||
if section_name in old_config:
|
|
||||||
logger.warning(f"{self.log_prefix} 配置节 {section_name} 结构已改变,使用默认值")
|
|
||||||
|
|
||||||
# 检查旧配置中是否有新配置没有的节
|
|
||||||
for section_name in old_config:
|
|
||||||
if section_name not in migrated_config:
|
|
||||||
logger.warning(f"{self.log_prefix} 配置节 {section_name} 在新版本中已被移除")
|
|
||||||
|
|
||||||
return migrated_config
|
|
||||||
|
|
||||||
def _generate_config_from_schema(self) -> Dict[str, Any]:
|
|
||||||
# sourcery skip: dict-comprehension
|
|
||||||
"""根据schema生成配置数据结构(不写入文件)"""
|
|
||||||
if not self.config_schema:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
config_data = {}
|
|
||||||
|
|
||||||
# 遍历每个配置节
|
|
||||||
for section, fields in self.config_schema.items():
|
|
||||||
if isinstance(fields, dict):
|
|
||||||
section_data = {}
|
|
||||||
|
|
||||||
# 遍历节内的字段
|
|
||||||
for field_name, field in fields.items():
|
|
||||||
if isinstance(field, ConfigField):
|
|
||||||
section_data[field_name] = field.default
|
|
||||||
|
|
||||||
config_data[section] = section_data
|
|
||||||
|
|
||||||
return config_data
|
|
||||||
|
|
||||||
def _save_config_to_file(self, config_data: Dict[str, Any], config_file_path: str):
|
|
||||||
"""将配置数据保存为TOML文件(包含注释)"""
|
|
||||||
if not self.config_schema:
|
|
||||||
logger.debug(f"{self.log_prefix} 插件未定义config_schema,不生成配置文件")
|
|
||||||
return
|
|
||||||
|
|
||||||
toml_str = f"# {self.plugin_name} - 配置文件\n"
|
|
||||||
plugin_description = self.get_manifest_info("description", "插件配置文件")
|
|
||||||
toml_str += f"# {plugin_description}\n"
|
|
||||||
|
|
||||||
# 获取当前期望的配置版本
|
|
||||||
expected_version = self._get_expected_config_version()
|
|
||||||
toml_str += f"# 配置版本: {expected_version}\n\n"
|
|
||||||
|
|
||||||
# 遍历每个配置节
|
|
||||||
for section, fields in self.config_schema.items():
|
|
||||||
# 添加节描述
|
|
||||||
if section in self.config_section_descriptions:
|
|
||||||
toml_str += f"# {self.config_section_descriptions[section]}\n"
|
|
||||||
|
|
||||||
toml_str += f"[{section}]\n\n"
|
|
||||||
|
|
||||||
# 遍历节内的字段
|
|
||||||
if isinstance(fields, dict) and section in config_data:
|
|
||||||
section_data = config_data[section]
|
|
||||||
|
|
||||||
for field_name, field in fields.items():
|
|
||||||
if isinstance(field, ConfigField):
|
|
||||||
# 添加字段描述
|
|
||||||
toml_str += f"# {field.description}"
|
|
||||||
if field.required:
|
|
||||||
toml_str += " (必需)"
|
|
||||||
toml_str += "\n"
|
|
||||||
|
|
||||||
# 如果有示例值,添加示例
|
|
||||||
if field.example:
|
|
||||||
toml_str += f"# 示例: {field.example}\n"
|
|
||||||
|
|
||||||
# 如果有可选值,添加说明
|
|
||||||
if field.choices:
|
|
||||||
choices_str = ", ".join(map(str, field.choices))
|
|
||||||
toml_str += f"# 可选值: {choices_str}\n"
|
|
||||||
|
|
||||||
# 添加字段值(使用迁移后的值)
|
|
||||||
value = section_data.get(field_name, field.default)
|
|
||||||
toml_str += f"{field_name} = {self._format_toml_value(value)}\n"
|
|
||||||
|
|
||||||
toml_str += "\n"
|
|
||||||
toml_str += "\n"
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(config_file_path, "w", encoding="utf-8") as f:
|
|
||||||
f.write(toml_str)
|
|
||||||
logger.info(f"{self.log_prefix} 配置文件已保存: {config_file_path}")
|
|
||||||
except IOError as e:
|
|
||||||
logger.error(f"{self.log_prefix} 保存配置文件失败: {e}", exc_info=True)
|
|
||||||
|
|
||||||
def _load_plugin_config(self): # sourcery skip: extract-method
|
|
||||||
"""加载插件配置文件,支持版本检查和自动迁移"""
|
|
||||||
if not self.config_file_name:
|
|
||||||
logger.debug(f"{self.log_prefix} 未指定配置文件,跳过加载")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 优先使用传入的插件目录路径
|
|
||||||
if self.plugin_dir:
|
|
||||||
plugin_dir = self.plugin_dir
|
|
||||||
else:
|
|
||||||
# fallback:尝试从类的模块信息获取路径
|
|
||||||
try:
|
|
||||||
plugin_module_path = inspect.getfile(self.__class__)
|
|
||||||
plugin_dir = os.path.dirname(plugin_module_path)
|
|
||||||
except (TypeError, OSError):
|
|
||||||
# 最后的fallback:从模块的__file__属性获取
|
|
||||||
module = inspect.getmodule(self.__class__)
|
|
||||||
if module and hasattr(module, "__file__") and module.__file__:
|
|
||||||
plugin_dir = os.path.dirname(module.__file__)
|
|
||||||
else:
|
|
||||||
logger.warning(f"{self.log_prefix} 无法获取插件目录路径,跳过配置加载")
|
|
||||||
return
|
|
||||||
|
|
||||||
config_file_path = os.path.join(plugin_dir, self.config_file_name)
|
|
||||||
|
|
||||||
# 如果配置文件不存在,生成默认配置
|
|
||||||
if not os.path.exists(config_file_path):
|
|
||||||
logger.info(f"{self.log_prefix} 配置文件 {config_file_path} 不存在,将生成默认配置。")
|
|
||||||
self._generate_and_save_default_config(config_file_path)
|
|
||||||
|
|
||||||
if not os.path.exists(config_file_path):
|
|
||||||
logger.warning(f"{self.log_prefix} 配置文件 {config_file_path} 不存在且无法生成。")
|
|
||||||
return
|
|
||||||
|
|
||||||
file_ext = os.path.splitext(self.config_file_name)[1].lower()
|
|
||||||
|
|
||||||
if file_ext == ".toml":
|
|
||||||
# 加载现有配置
|
|
||||||
with open(config_file_path, "r", encoding="utf-8") as f:
|
|
||||||
existing_config = toml.load(f) or {}
|
|
||||||
|
|
||||||
# 检查配置版本
|
|
||||||
current_version = self._get_current_config_version(existing_config)
|
|
||||||
|
|
||||||
# 如果配置文件没有版本信息,跳过版本检查
|
|
||||||
if current_version == "0.0.0":
|
|
||||||
logger.debug(f"{self.log_prefix} 配置文件无版本信息,跳过版本检查")
|
|
||||||
self.config = existing_config
|
|
||||||
else:
|
|
||||||
expected_version = self._get_expected_config_version()
|
|
||||||
|
|
||||||
if current_version != expected_version:
|
|
||||||
logger.info(
|
|
||||||
f"{self.log_prefix} 检测到配置版本需要更新: 当前=v{current_version}, 期望=v{expected_version}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 生成新的默认配置结构
|
|
||||||
new_config_structure = self._generate_config_from_schema()
|
|
||||||
|
|
||||||
# 迁移旧配置值到新结构
|
|
||||||
migrated_config = self._migrate_config_values(existing_config, new_config_structure)
|
|
||||||
|
|
||||||
# 保存迁移后的配置
|
|
||||||
self._save_config_to_file(migrated_config, config_file_path)
|
|
||||||
|
|
||||||
logger.info(f"{self.log_prefix} 配置文件已从 v{current_version} 更新到 v{expected_version}")
|
|
||||||
|
|
||||||
self.config = migrated_config
|
|
||||||
else:
|
|
||||||
logger.debug(f"{self.log_prefix} 配置版本匹配 (v{current_version}),直接加载")
|
|
||||||
self.config = existing_config
|
|
||||||
|
|
||||||
logger.debug(f"{self.log_prefix} 配置已从 {config_file_path} 加载")
|
|
||||||
|
|
||||||
# 从配置中更新 enable_plugin
|
|
||||||
if "plugin" in self.config and "enabled" in self.config["plugin"]:
|
|
||||||
self.enable_plugin = self.config["plugin"]["enabled"] # type: ignore
|
|
||||||
logger.debug(f"{self.log_prefix} 从配置更新插件启用状态: {self.enable_plugin}")
|
|
||||||
else:
|
|
||||||
logger.warning(f"{self.log_prefix} 不支持的配置文件格式: {file_ext},仅支持 .toml")
|
|
||||||
self.config = {}
|
|
||||||
|
|
||||||
def _check_dependencies(self) -> bool:
|
|
||||||
"""检查插件依赖"""
|
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
|
||||||
|
|
||||||
if not self.dependencies:
|
|
||||||
return True
|
|
||||||
|
|
||||||
for dependency in self.dependencies:
|
|
||||||
dep_name, version_spec, min_version, max_version = self._parse_dependency(dependency)
|
|
||||||
if not dep_name:
|
|
||||||
logger.warning(f"{self.log_prefix} 跳过无效依赖声明: {dependency}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
dep_plugin_info = component_registry.get_plugin_info(dep_name)
|
|
||||||
if not dep_plugin_info:
|
|
||||||
logger.error(f"{self.log_prefix} 缺少依赖插件: {dep_name}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
dep_version = dep_plugin_info.version or "0.0.0"
|
|
||||||
|
|
||||||
if version_spec:
|
|
||||||
is_ok, msg = self._is_version_spec_satisfied(dep_version, version_spec)
|
|
||||||
if not is_ok:
|
|
||||||
logger.error(
|
|
||||||
f"{self.log_prefix} 依赖插件版本不满足: {dep_name} {version_spec}, 当前版本={dep_version} ({msg})"
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
if min_version or max_version:
|
|
||||||
is_ok, msg = VersionComparator.is_version_in_range(dep_version, min_version, max_version)
|
|
||||||
if not is_ok:
|
|
||||||
logger.error(
|
|
||||||
f"{self.log_prefix} 依赖插件版本不满足: {dep_name} 要求区间[{min_version or '-inf'}, {max_version or '+inf'}], 当前版本={dep_version} ({msg})"
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _get_dependency_names(self) -> List[str]:
|
|
||||||
"""获取依赖插件名称列表(用于插件信息展示和统计)。"""
|
|
||||||
dependency_names: List[str] = []
|
|
||||||
for dependency in self.dependencies:
|
|
||||||
dep_name, _, _, _ = self._parse_dependency(dependency)
|
|
||||||
if dep_name:
|
|
||||||
dependency_names.append(dep_name)
|
|
||||||
return dependency_names
|
|
||||||
|
|
||||||
def _parse_dependency(self, dependency: Any) -> tuple[str, str, str, str]:
|
|
||||||
"""解析依赖声明。
|
|
||||||
|
|
||||||
支持格式:
|
|
||||||
- "plugin_a"
|
|
||||||
- {"name": "plugin_a", "version": ">=1.2.0,<2.0.0"}
|
|
||||||
- {"name": "plugin_a", "min_version": "1.2.0", "max_version": "2.0.0"}
|
|
||||||
"""
|
|
||||||
if isinstance(dependency, str):
|
|
||||||
return dependency.strip(), "", "", ""
|
|
||||||
|
|
||||||
if isinstance(dependency, dict):
|
|
||||||
dep_name = str(dependency.get("name", "")).strip()
|
|
||||||
version_spec = str(dependency.get("version", "")).strip()
|
|
||||||
min_version = str(dependency.get("min_version", "")).strip()
|
|
||||||
max_version = str(dependency.get("max_version", "")).strip()
|
|
||||||
return dep_name, version_spec, min_version, max_version
|
|
||||||
|
|
||||||
return "", "", "", ""
|
|
||||||
|
|
||||||
def _is_version_spec_satisfied(self, version: str, version_spec: str) -> tuple[bool, str]:
|
|
||||||
"""检查版本是否满足表达式。
|
|
||||||
|
|
||||||
支持:==, >=, <=, >, <,可用逗号分隔多个条件。
|
|
||||||
示例:">=1.2.0,<2.0.0"
|
|
||||||
"""
|
|
||||||
normalized_version = VersionComparator.normalize_version(version)
|
|
||||||
clauses = [clause.strip() for clause in version_spec.split(",") if clause.strip()]
|
|
||||||
if not clauses:
|
|
||||||
return True, ""
|
|
||||||
|
|
||||||
operators_pattern = r"^(==|>=|<=|>|<)\s*(.+)$"
|
|
||||||
|
|
||||||
for clause in clauses:
|
|
||||||
if not (match := re.match(operators_pattern, clause)):
|
|
||||||
return False, f"无效版本约束表达式: {clause}"
|
|
||||||
|
|
||||||
operator, target_version = match.group(1), VersionComparator.normalize_version(match.group(2))
|
|
||||||
compare_result = VersionComparator.compare_versions(normalized_version, target_version)
|
|
||||||
|
|
||||||
is_satisfied = False
|
|
||||||
if operator == "==":
|
|
||||||
is_satisfied = compare_result == 0
|
|
||||||
elif operator == ">=":
|
|
||||||
is_satisfied = compare_result >= 0
|
|
||||||
elif operator == "<=":
|
|
||||||
is_satisfied = compare_result <= 0
|
|
||||||
elif operator == ">":
|
|
||||||
is_satisfied = compare_result > 0
|
|
||||||
elif operator == "<":
|
|
||||||
is_satisfied = compare_result < 0
|
|
||||||
|
|
||||||
if not is_satisfied:
|
|
||||||
return False, f"{normalized_version} 不满足约束 {operator}{target_version}"
|
|
||||||
|
|
||||||
return True, ""
|
|
||||||
|
|
||||||
def get_config(self, key: str, default: Any = None) -> Any:
|
|
||||||
"""获取插件配置值,支持嵌套键访问
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key: 配置键名,支持嵌套访问如 "section.subsection.key"
|
|
||||||
default: 默认值
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Any: 配置值或默认值
|
|
||||||
"""
|
|
||||||
# 支持嵌套键访问
|
|
||||||
keys = key.split(".")
|
|
||||||
current = self.config
|
|
||||||
|
|
||||||
for k in keys:
|
|
||||||
if isinstance(current, dict) and k in current:
|
|
||||||
current = current[k]
|
|
||||||
else:
|
|
||||||
return default
|
|
||||||
|
|
||||||
return current
|
|
||||||
|
|
||||||
def get_webui_config_schema(self) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
获取 WebUI 配置 Schema
|
|
||||||
|
|
||||||
返回完整的配置 schema,包含:
|
|
||||||
- 插件基本信息
|
|
||||||
- 所有 section 及其字段定义
|
|
||||||
- 布局配置
|
|
||||||
|
|
||||||
用于 WebUI 动态生成配置表单。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict: 完整的配置 schema
|
|
||||||
"""
|
|
||||||
schema = {
|
|
||||||
"plugin_id": self.plugin_name,
|
|
||||||
"plugin_info": {
|
|
||||||
"name": self.display_name,
|
|
||||||
"version": self.plugin_version,
|
|
||||||
"description": self.plugin_description,
|
|
||||||
"author": self.plugin_author,
|
|
||||||
},
|
|
||||||
"sections": {},
|
|
||||||
"layout": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 处理 sections
|
|
||||||
for section_name, fields in self.config_schema.items():
|
|
||||||
if not isinstance(fields, dict):
|
|
||||||
continue
|
|
||||||
|
|
||||||
section_data = {
|
|
||||||
"name": section_name,
|
|
||||||
"title": section_name,
|
|
||||||
"description": None,
|
|
||||||
"icon": None,
|
|
||||||
"collapsed": False,
|
|
||||||
"order": 0,
|
|
||||||
"fields": {},
|
|
||||||
}
|
|
||||||
|
|
||||||
# 获取 section 元数据
|
|
||||||
section_meta = self.config_section_descriptions.get(section_name)
|
|
||||||
if section_meta:
|
|
||||||
if isinstance(section_meta, str):
|
|
||||||
section_data["title"] = section_meta
|
|
||||||
elif isinstance(section_meta, ConfigSection):
|
|
||||||
section_data["title"] = section_meta.title
|
|
||||||
section_data["description"] = section_meta.description
|
|
||||||
section_data["icon"] = section_meta.icon
|
|
||||||
section_data["collapsed"] = section_meta.collapsed
|
|
||||||
section_data["order"] = section_meta.order
|
|
||||||
elif isinstance(section_meta, dict):
|
|
||||||
section_data.update(section_meta)
|
|
||||||
|
|
||||||
# 处理字段
|
|
||||||
for field_name, field_def in fields.items():
|
|
||||||
if isinstance(field_def, ConfigField):
|
|
||||||
field_data = field_def.to_dict()
|
|
||||||
field_data["name"] = field_name
|
|
||||||
section_data["fields"][field_name] = field_data
|
|
||||||
|
|
||||||
schema["sections"][section_name] = section_data
|
|
||||||
|
|
||||||
# 处理布局
|
|
||||||
if self.config_layout:
|
|
||||||
schema["layout"] = self.config_layout.to_dict()
|
|
||||||
else:
|
|
||||||
# 自动布局:按 section order 排序
|
|
||||||
schema["layout"] = {
|
|
||||||
"type": "auto",
|
|
||||||
"tabs": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
return schema
|
|
||||||
|
|
||||||
def get_current_config_values(self) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
获取当前配置值
|
|
||||||
|
|
||||||
返回插件当前的配置值(已从配置文件加载)。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict: 当前配置值
|
|
||||||
"""
|
|
||||||
return self.config.copy()
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def register_plugin(self) -> bool:
|
|
||||||
"""
|
|
||||||
注册插件到插件管理器
|
|
||||||
|
|
||||||
子类必须实现此方法,返回注册是否成功
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否成功注册插件
|
|
||||||
"""
|
|
||||||
raise NotImplementedError("Subclasses must implement this method")
|
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Any, Dict, List
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class PluginServiceInfo:
|
|
||||||
"""插件服务注册信息"""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
plugin_name: str
|
|
||||||
version: str = "1.0.0"
|
|
||||||
description: str = ""
|
|
||||||
enabled: bool = True
|
|
||||||
public: bool = False
|
|
||||||
allowed_callers: List[str] = field(default_factory=list)
|
|
||||||
params_schema: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
return_schema: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def full_name(self) -> str:
|
|
||||||
return f"{self.plugin_name}.{self.name}"
|
|
||||||
@@ -1,14 +0,0 @@
|
|||||||
from enum import Enum
|
|
||||||
|
|
||||||
|
|
||||||
class WorkflowErrorCode(Enum):
|
|
||||||
"""Workflow统一错误码"""
|
|
||||||
|
|
||||||
PLUGIN_NOT_READY = "PLUGIN_NOT_READY"
|
|
||||||
STEP_TIMEOUT = "STEP_TIMEOUT"
|
|
||||||
BAD_PAYLOAD = "BAD_PAYLOAD"
|
|
||||||
DOWNSTREAM_FAILED = "DOWNSTREAM_FAILED"
|
|
||||||
POLICY_BLOCKED = "POLICY_BLOCKED"
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return self.value
|
|
||||||
@@ -1,68 +0,0 @@
|
|||||||
from dataclasses import dataclass, field
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Any, Dict, List, Literal, Optional
|
|
||||||
import time
|
|
||||||
|
|
||||||
|
|
||||||
class WorkflowStage(Enum):
|
|
||||||
"""Workflow阶段定义(MVP固定阶段)"""
|
|
||||||
|
|
||||||
INGRESS = "ingress"
|
|
||||||
PRE_PROCESS = "pre_process"
|
|
||||||
PLAN = "plan"
|
|
||||||
TOOL_EXECUTE = "tool_execute"
|
|
||||||
POST_PROCESS = "post_process"
|
|
||||||
EGRESS = "egress"
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return self.value
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class WorkflowContext:
|
|
||||||
"""Workflow上下文"""
|
|
||||||
|
|
||||||
trace_id: str
|
|
||||||
stream_id: Optional[str] = None
|
|
||||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
timings: Dict[str, float] = field(default_factory=dict)
|
|
||||||
errors: List[str] = field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class WorkflowMessage:
|
|
||||||
"""Workflow消息包装"""
|
|
||||||
|
|
||||||
msg_type: str
|
|
||||||
payload: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
headers: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
mutable_flags: Dict[str, bool] = field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class WorkflowStepResult:
|
|
||||||
"""Workflow步骤结果"""
|
|
||||||
|
|
||||||
status: Literal["continue", "stop", "failed"] = "continue"
|
|
||||||
return_message: Optional[str] = None
|
|
||||||
diagnostics: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
events: List[Dict[str, Any]] = field(default_factory=list)
|
|
||||||
created_at: float = field(default_factory=time.time)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class WorkflowStepInfo:
|
|
||||||
"""Workflow步骤元数据"""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
stage: WorkflowStage
|
|
||||||
plugin_name: str
|
|
||||||
description: str = ""
|
|
||||||
enabled: bool = True
|
|
||||||
priority: int = 0
|
|
||||||
timeout_ms: int = 0
|
|
||||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def full_name(self) -> str:
|
|
||||||
return f"{self.plugin_name}.{self.name}"
|
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
"""
|
|
||||||
插件核心管理模块
|
|
||||||
|
|
||||||
提供插件的加载、注册和管理功能
|
|
||||||
"""
|
|
||||||
|
|
||||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
|
||||||
from src.plugin_system.core.events_manager import events_manager
|
|
||||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
|
||||||
from src.plugin_system.core.plugin_service_registry import plugin_service_registry
|
|
||||||
from src.plugin_system.core.workflow_engine import workflow_engine
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"plugin_manager",
|
|
||||||
"component_registry",
|
|
||||||
"events_manager",
|
|
||||||
"global_announcement_manager",
|
|
||||||
"plugin_service_registry",
|
|
||||||
"workflow_engine",
|
|
||||||
]
|
|
||||||
@@ -1,758 +0,0 @@
|
|||||||
import re
|
|
||||||
|
|
||||||
from typing import Callable, Dict, List, Optional, Any, Pattern, Tuple, Union, Type
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.plugin_system.base.component_types import (
|
|
||||||
ComponentInfo,
|
|
||||||
ActionInfo,
|
|
||||||
ToolInfo,
|
|
||||||
CommandInfo,
|
|
||||||
EventHandlerInfo,
|
|
||||||
PluginInfo,
|
|
||||||
ComponentType,
|
|
||||||
)
|
|
||||||
from src.plugin_system.base.base_command import BaseCommand
|
|
||||||
from src.plugin_system.base.base_action import BaseAction
|
|
||||||
from src.plugin_system.base.base_tool import BaseTool
|
|
||||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
|
||||||
from src.plugin_system.base.workflow_types import WorkflowStage, WorkflowStepInfo
|
|
||||||
|
|
||||||
logger = get_logger("component_registry")
|
|
||||||
|
|
||||||
|
|
||||||
class ComponentRegistry:
|
|
||||||
"""统一的组件注册中心
|
|
||||||
|
|
||||||
负责管理所有插件组件的注册、查询和生命周期管理
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
# 命名空间式组件名构成法 f"{component_type}.{component_name}"
|
|
||||||
self._components: Dict[str, ComponentInfo] = {}
|
|
||||||
"""组件注册表 命名空间式组件名 -> 组件信息"""
|
|
||||||
self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType}
|
|
||||||
"""类型 -> 组件原名称 -> 组件信息"""
|
|
||||||
self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseTool, BaseEventHandler]]] = {}
|
|
||||||
"""命名空间式组件名 -> 组件类"""
|
|
||||||
|
|
||||||
# 插件注册表
|
|
||||||
self._plugins: Dict[str, PluginInfo] = {}
|
|
||||||
"""插件名 -> 插件信息"""
|
|
||||||
|
|
||||||
# Action特定注册表
|
|
||||||
self._action_registry: Dict[str, Type[BaseAction]] = {}
|
|
||||||
"""Action注册表 action名 -> action类"""
|
|
||||||
self._default_actions: Dict[str, ActionInfo] = {}
|
|
||||||
"""默认动作集,即启用的Action集,用于重置ActionManager状态"""
|
|
||||||
|
|
||||||
# Command特定注册表
|
|
||||||
self._command_registry: Dict[str, Type[BaseCommand]] = {}
|
|
||||||
"""Command类注册表 command名 -> command类"""
|
|
||||||
self._command_patterns: Dict[Pattern, str] = {}
|
|
||||||
"""编译后的正则 -> command名"""
|
|
||||||
|
|
||||||
# 工具特定注册表
|
|
||||||
self._tool_registry: Dict[str, Type[BaseTool]] = {} # 工具名 -> 工具类
|
|
||||||
self._llm_available_tools: Dict[str, Type[BaseTool]] = {} # llm可用的工具名 -> 工具类
|
|
||||||
|
|
||||||
# EventHandler特定注册表
|
|
||||||
self._event_handler_registry: Dict[str, Type[BaseEventHandler]] = {}
|
|
||||||
"""event_handler名 -> event_handler类"""
|
|
||||||
self._enabled_event_handlers: Dict[str, Type[BaseEventHandler]] = {}
|
|
||||||
"""启用的事件处理器 event_handler名 -> event_handler类"""
|
|
||||||
|
|
||||||
# Workflow step注册表
|
|
||||||
self._workflow_steps: Dict[WorkflowStage, Dict[str, WorkflowStepInfo]] = {stage: {} for stage in WorkflowStage}
|
|
||||||
self._workflow_step_handlers: Dict[str, Callable[..., Any]] = {}
|
|
||||||
|
|
||||||
logger.info("组件注册中心初始化完成")
|
|
||||||
|
|
||||||
# == 注册方法 ==
|
|
||||||
|
|
||||||
def register_plugin(self, plugin_info: PluginInfo) -> bool:
|
|
||||||
"""注册插件
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plugin_info: 插件信息
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否注册成功
|
|
||||||
"""
|
|
||||||
plugin_name = plugin_info.name
|
|
||||||
|
|
||||||
if plugin_name in self._plugins:
|
|
||||||
logger.warning(f"插件 {plugin_name} 已存在,跳过注册")
|
|
||||||
return False
|
|
||||||
|
|
||||||
self._plugins[plugin_name] = plugin_info
|
|
||||||
logger.debug(f"已注册插件: {plugin_name} (组件数量: {len(plugin_info.components)})")
|
|
||||||
return True
|
|
||||||
|
|
||||||
def register_component(
|
|
||||||
self,
|
|
||||||
component_info: ComponentInfo,
|
|
||||||
component_class: Type[Union[BaseCommand, BaseAction, BaseEventHandler, BaseTool]],
|
|
||||||
) -> bool:
|
|
||||||
"""注册组件
|
|
||||||
|
|
||||||
Args:
|
|
||||||
component_info (ComponentInfo): 组件信息
|
|
||||||
component_class (Type[Union[BaseCommand, BaseAction, BaseEventHandler]]): 组件类
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否注册成功
|
|
||||||
"""
|
|
||||||
component_name = component_info.name
|
|
||||||
component_type = component_info.component_type
|
|
||||||
plugin_name = getattr(component_info, "plugin_name", "unknown")
|
|
||||||
if "." in component_name:
|
|
||||||
logger.error(f"组件名称 '{component_name}' 包含非法字符 '.',请使用下划线替代")
|
|
||||||
return False
|
|
||||||
if "." in plugin_name:
|
|
||||||
logger.error(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
|
|
||||||
return False
|
|
||||||
|
|
||||||
namespaced_name = f"{component_type}.{component_name}"
|
|
||||||
|
|
||||||
if namespaced_name in self._components:
|
|
||||||
existing_info = self._components[namespaced_name]
|
|
||||||
existing_plugin = getattr(existing_info, "plugin_name", "unknown")
|
|
||||||
|
|
||||||
logger.warning(
|
|
||||||
f"组件名冲突: '{plugin_name}' 插件的 {component_type} 类型组件 '{component_name}' 已被插件 '{existing_plugin}' 注册,跳过此组件注册"
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
self._components[namespaced_name] = component_info # 注册到通用注册表(使用命名空间化的名称)
|
|
||||||
self._components_by_type[component_type][component_name] = component_info # 类型内部仍使用原名
|
|
||||||
self._components_classes[namespaced_name] = component_class
|
|
||||||
|
|
||||||
# 根据组件类型进行特定注册(使用原始名称)
|
|
||||||
ret = False
|
|
||||||
match component_type:
|
|
||||||
case ComponentType.ACTION:
|
|
||||||
assert isinstance(component_info, ActionInfo)
|
|
||||||
assert issubclass(component_class, BaseAction)
|
|
||||||
ret = self._register_action_component(component_info, component_class)
|
|
||||||
case ComponentType.COMMAND:
|
|
||||||
assert isinstance(component_info, CommandInfo)
|
|
||||||
assert issubclass(component_class, BaseCommand)
|
|
||||||
ret = self._register_command_component(component_info, component_class)
|
|
||||||
case ComponentType.TOOL:
|
|
||||||
assert isinstance(component_info, ToolInfo)
|
|
||||||
assert issubclass(component_class, BaseTool)
|
|
||||||
ret = self._register_tool_component(component_info, component_class)
|
|
||||||
case ComponentType.EVENT_HANDLER:
|
|
||||||
assert isinstance(component_info, EventHandlerInfo)
|
|
||||||
assert issubclass(component_class, BaseEventHandler)
|
|
||||||
ret = self._register_event_handler_component(component_info, component_class)
|
|
||||||
case _:
|
|
||||||
logger.warning(f"未知组件类型: {component_type}")
|
|
||||||
|
|
||||||
if not ret:
|
|
||||||
return False
|
|
||||||
logger.debug(
|
|
||||||
f"已注册{component_type}组件: '{component_name}' -> '{namespaced_name}' "
|
|
||||||
f"({component_class.__name__}) [插件: {plugin_name}]"
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _register_action_component(self, action_info: ActionInfo, action_class: Type[BaseAction]) -> bool:
|
|
||||||
"""注册Action组件到Action特定注册表"""
|
|
||||||
if not (action_name := action_info.name):
|
|
||||||
logger.error(f"Action组件 {action_class.__name__} 必须指定名称")
|
|
||||||
return False
|
|
||||||
if not isinstance(action_info, ActionInfo) or not issubclass(action_class, BaseAction):
|
|
||||||
logger.error(f"注册失败: {action_name} 不是有效的Action")
|
|
||||||
return False
|
|
||||||
|
|
||||||
self._action_registry[action_name] = action_class
|
|
||||||
|
|
||||||
# 如果启用,添加到默认动作集
|
|
||||||
if action_info.enabled:
|
|
||||||
self._default_actions[action_name] = action_info
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _register_command_component(self, command_info: CommandInfo, command_class: Type[BaseCommand]) -> bool:
|
|
||||||
"""注册Command组件到Command特定注册表"""
|
|
||||||
if not (command_name := command_info.name):
|
|
||||||
logger.error(f"Command组件 {command_class.__name__} 必须指定名称")
|
|
||||||
return False
|
|
||||||
if not isinstance(command_info, CommandInfo) or not issubclass(command_class, BaseCommand):
|
|
||||||
logger.error(f"注册失败: {command_name} 不是有效的Command")
|
|
||||||
return False
|
|
||||||
|
|
||||||
self._command_registry[command_name] = command_class
|
|
||||||
|
|
||||||
# 如果启用了且有匹配模式
|
|
||||||
if command_info.enabled and command_info.command_pattern:
|
|
||||||
pattern = re.compile(command_info.command_pattern, re.IGNORECASE | re.DOTALL)
|
|
||||||
if pattern not in self._command_patterns:
|
|
||||||
self._command_patterns[pattern] = command_name
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令"
|
|
||||||
)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _register_tool_component(self, tool_info: ToolInfo, tool_class: Type[BaseTool]) -> bool:
|
|
||||||
"""注册Tool组件到Tool特定注册表"""
|
|
||||||
tool_name = tool_info.name
|
|
||||||
|
|
||||||
self._tool_registry[tool_name] = tool_class
|
|
||||||
|
|
||||||
# 如果是llm可用的且启用的工具,添加到 llm可用工具列表
|
|
||||||
if tool_info.enabled:
|
|
||||||
self._llm_available_tools[tool_name] = tool_class
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _register_event_handler_component(
|
|
||||||
self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]
|
|
||||||
) -> bool:
|
|
||||||
if not (handler_name := handler_info.name):
|
|
||||||
logger.error(f"EventHandler组件 {handler_class.__name__} 必须指定名称")
|
|
||||||
return False
|
|
||||||
if not isinstance(handler_info, EventHandlerInfo) or not issubclass(handler_class, BaseEventHandler):
|
|
||||||
logger.error(f"注册失败: {handler_name} 不是有效的EventHandler")
|
|
||||||
return False
|
|
||||||
|
|
||||||
self._event_handler_registry[handler_name] = handler_class
|
|
||||||
|
|
||||||
if not handler_info.enabled:
|
|
||||||
logger.warning(f"EventHandler组件 {handler_name} 未启用")
|
|
||||||
return True # 未启用,但是也是注册成功
|
|
||||||
|
|
||||||
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
|
||||||
|
|
||||||
if events_manager.register_event_subscriber(handler_info, handler_class):
|
|
||||||
self._enabled_event_handlers[handler_name] = handler_class
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
logger.error(f"注册事件处理器 {handler_name} 失败")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# === 组件移除相关 ===
|
|
||||||
|
|
||||||
async def remove_component(self, component_name: str, component_type: ComponentType, plugin_name: str) -> bool:
|
|
||||||
target_component_class = self.get_component_class(component_name, component_type)
|
|
||||||
if not target_component_class:
|
|
||||||
logger.warning(f"组件 {component_name} 未注册,无法移除")
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
match component_type:
|
|
||||||
case ComponentType.ACTION:
|
|
||||||
self._action_registry.pop(component_name)
|
|
||||||
self._default_actions.pop(component_name)
|
|
||||||
case ComponentType.COMMAND:
|
|
||||||
self._command_registry.pop(component_name)
|
|
||||||
keys_to_remove = [k for k, v in self._command_patterns.items() if v == component_name]
|
|
||||||
for key in keys_to_remove:
|
|
||||||
self._command_patterns.pop(key)
|
|
||||||
case ComponentType.TOOL:
|
|
||||||
self._tool_registry.pop(component_name)
|
|
||||||
self._llm_available_tools.pop(component_name)
|
|
||||||
case ComponentType.EVENT_HANDLER:
|
|
||||||
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
|
||||||
|
|
||||||
self._event_handler_registry.pop(component_name)
|
|
||||||
self._enabled_event_handlers.pop(component_name)
|
|
||||||
await events_manager.unregister_event_subscriber(component_name)
|
|
||||||
namespaced_name = f"{component_type}.{component_name}"
|
|
||||||
self._components.pop(namespaced_name)
|
|
||||||
self._components_by_type[component_type].pop(component_name)
|
|
||||||
self._components_classes.pop(namespaced_name)
|
|
||||||
logger.info(f"组件 {component_name} 已移除")
|
|
||||||
return True
|
|
||||||
except KeyError as e:
|
|
||||||
logger.warning(f"移除组件时未找到组件: {component_name}, 发生错误: {e}")
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"移除组件 {component_name} 时发生错误: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def remove_components_by_plugin(self, plugin_name: str) -> int:
|
|
||||||
"""移除某插件注册的所有组件。"""
|
|
||||||
targets = [
|
|
||||||
(component_info.name, component_info.component_type)
|
|
||||||
for component_info in self._components.values()
|
|
||||||
if component_info.plugin_name == plugin_name
|
|
||||||
]
|
|
||||||
|
|
||||||
removed_count = 0
|
|
||||||
for component_name, component_type in targets:
|
|
||||||
if await self.remove_component(component_name, component_type, plugin_name):
|
|
||||||
removed_count += 1
|
|
||||||
|
|
||||||
if removed_count:
|
|
||||||
logger.info(f"已移除插件 {plugin_name} 的组件数量: {removed_count}")
|
|
||||||
return removed_count
|
|
||||||
|
|
||||||
def remove_plugin_registry(self, plugin_name: str) -> bool:
|
|
||||||
"""移除插件注册信息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plugin_name: 插件名称
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否成功移除
|
|
||||||
"""
|
|
||||||
if plugin_name not in self._plugins:
|
|
||||||
logger.warning(f"插件 {plugin_name} 未注册,无法移除")
|
|
||||||
return False
|
|
||||||
del self._plugins[plugin_name]
|
|
||||||
self.remove_workflow_steps_by_plugin(plugin_name)
|
|
||||||
logger.info(f"插件 {plugin_name} 已移除")
|
|
||||||
return True
|
|
||||||
|
|
||||||
# === Workflow step 注册与查询 ===
|
|
||||||
|
|
||||||
def register_workflow_step(self, step_info: WorkflowStepInfo, step_handler: Callable[..., Any]) -> bool:
|
|
||||||
"""注册workflow步骤。"""
|
|
||||||
if not step_info.name or not step_info.plugin_name:
|
|
||||||
logger.error("workflow step 注册失败: step名称或插件名称为空")
|
|
||||||
return False
|
|
||||||
if "." in step_info.name:
|
|
||||||
logger.error(f"workflow step 名称 '{step_info.name}' 包含非法字符 '.',请使用下划线替代")
|
|
||||||
return False
|
|
||||||
if "." in step_info.plugin_name:
|
|
||||||
logger.error(f"workflow step 所属插件名称 '{step_info.plugin_name}' 包含非法字符 '.',请使用下划线替代")
|
|
||||||
return False
|
|
||||||
|
|
||||||
full_name = step_info.full_name
|
|
||||||
stage_registry = self._workflow_steps.get(step_info.stage)
|
|
||||||
if stage_registry is None:
|
|
||||||
logger.error(f"workflow step 注册失败: 未知阶段 {step_info.stage}")
|
|
||||||
return False
|
|
||||||
if full_name in stage_registry:
|
|
||||||
logger.warning(f"workflow step 已存在,跳过注册: {full_name}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
stage_registry[full_name] = step_info
|
|
||||||
self._workflow_step_handlers[full_name] = step_handler
|
|
||||||
logger.debug(f"已注册workflow step: {full_name} @ {step_info.stage}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
def get_steps_by_stage(self, stage: WorkflowStage, enabled_only: bool = False) -> Dict[str, WorkflowStepInfo]:
|
|
||||||
"""获取某阶段的workflow步骤。"""
|
|
||||||
steps = self._workflow_steps.get(stage, {})
|
|
||||||
if enabled_only:
|
|
||||||
return {name: info for name, info in steps.items() if info.enabled}
|
|
||||||
return steps.copy()
|
|
||||||
|
|
||||||
def get_workflow_step(self, step_name: str, stage: Optional[WorkflowStage] = None) -> Optional[WorkflowStepInfo]:
|
|
||||||
"""获取workflow step信息。
|
|
||||||
|
|
||||||
step_name支持两种:
|
|
||||||
- full_name: plugin_name.step_name
|
|
||||||
- short_name: step_name(若有冲突返回第一个并告警)
|
|
||||||
"""
|
|
||||||
candidates: List[WorkflowStepInfo] = []
|
|
||||||
|
|
||||||
target_stages = [stage] if stage else list(WorkflowStage)
|
|
||||||
for current_stage in target_stages:
|
|
||||||
current_steps = self._workflow_steps.get(current_stage, {})
|
|
||||||
if "." in step_name:
|
|
||||||
if step_info := current_steps.get(step_name):
|
|
||||||
return step_info
|
|
||||||
continue
|
|
||||||
candidates.extend([step_info for step_info in current_steps.values() if step_info.name == step_name])
|
|
||||||
|
|
||||||
if len(candidates) == 1:
|
|
||||||
return candidates[0]
|
|
||||||
if len(candidates) > 1:
|
|
||||||
logger.warning(f"workflow step 名称 '{step_name}' 存在多义,使用第一个匹配: {candidates[0].full_name}")
|
|
||||||
return candidates[0]
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_workflow_step_handler(
|
|
||||||
self, step_name: str, stage: Optional[WorkflowStage] = None
|
|
||||||
) -> Optional[Callable[..., Any]]:
|
|
||||||
"""获取workflow step处理函数。"""
|
|
||||||
if "." in step_name:
|
|
||||||
return self._workflow_step_handlers.get(step_name)
|
|
||||||
|
|
||||||
if step_info := self.get_workflow_step(step_name, stage):
|
|
||||||
return self._workflow_step_handlers.get(step_info.full_name)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def enable_workflow_step(self, step_name: str, stage: Optional[WorkflowStage] = None) -> bool:
|
|
||||||
"""启用workflow step。"""
|
|
||||||
step_info = self.get_workflow_step(step_name, stage)
|
|
||||||
if not step_info:
|
|
||||||
logger.warning(f"workflow step 未注册,无法启用: {step_name}")
|
|
||||||
return False
|
|
||||||
step_info.enabled = True
|
|
||||||
logger.info(f"workflow step 已启用: {step_info.full_name}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
def disable_workflow_step(self, step_name: str, stage: Optional[WorkflowStage] = None) -> bool:
|
|
||||||
"""禁用workflow step。"""
|
|
||||||
step_info = self.get_workflow_step(step_name, stage)
|
|
||||||
if not step_info:
|
|
||||||
logger.warning(f"workflow step 未注册,无法禁用: {step_name}")
|
|
||||||
return False
|
|
||||||
step_info.enabled = False
|
|
||||||
logger.info(f"workflow step 已禁用: {step_info.full_name}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
def remove_workflow_steps_by_plugin(self, plugin_name: str) -> int:
|
|
||||||
"""移除某插件注册的所有workflow step。"""
|
|
||||||
removed_count = 0
|
|
||||||
for stage in WorkflowStage:
|
|
||||||
stage_registry = self._workflow_steps.get(stage, {})
|
|
||||||
target_names = [name for name, info in stage_registry.items() if info.plugin_name == plugin_name]
|
|
||||||
for full_name in target_names:
|
|
||||||
stage_registry.pop(full_name, None)
|
|
||||||
self._workflow_step_handlers.pop(full_name, None)
|
|
||||||
removed_count += 1
|
|
||||||
|
|
||||||
if removed_count:
|
|
||||||
logger.info(f"已移除插件 {plugin_name} 的 workflow step 数量: {removed_count}")
|
|
||||||
return removed_count
|
|
||||||
|
|
||||||
# === 组件全局启用/禁用方法 ===
|
|
||||||
|
|
||||||
def enable_component(self, component_name: str, component_type: ComponentType) -> bool:
|
|
||||||
"""全局的启用某个组件
|
|
||||||
Parameters:
|
|
||||||
component_name: 组件名称
|
|
||||||
component_type: 组件类型
|
|
||||||
Returns:
|
|
||||||
bool: 启用成功返回True,失败返回False
|
|
||||||
"""
|
|
||||||
target_component_class = self.get_component_class(component_name, component_type)
|
|
||||||
target_component_info = self.get_component_info(component_name, component_type)
|
|
||||||
if not target_component_class or not target_component_info:
|
|
||||||
logger.warning(f"组件 {component_name} 未注册,无法启用")
|
|
||||||
return False
|
|
||||||
target_component_info.enabled = True
|
|
||||||
match component_type:
|
|
||||||
case ComponentType.ACTION:
|
|
||||||
assert isinstance(target_component_info, ActionInfo)
|
|
||||||
self._default_actions[component_name] = target_component_info
|
|
||||||
case ComponentType.COMMAND:
|
|
||||||
assert isinstance(target_component_info, CommandInfo)
|
|
||||||
pattern = target_component_info.command_pattern
|
|
||||||
self._command_patterns[re.compile(pattern)] = component_name
|
|
||||||
case ComponentType.TOOL:
|
|
||||||
assert isinstance(target_component_info, ToolInfo)
|
|
||||||
assert issubclass(target_component_class, BaseTool)
|
|
||||||
self._llm_available_tools[component_name] = target_component_class
|
|
||||||
case ComponentType.EVENT_HANDLER:
|
|
||||||
assert isinstance(target_component_info, EventHandlerInfo)
|
|
||||||
assert issubclass(target_component_class, BaseEventHandler)
|
|
||||||
self._enabled_event_handlers[component_name] = target_component_class
|
|
||||||
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
|
||||||
|
|
||||||
events_manager.register_event_subscriber(target_component_info, target_component_class)
|
|
||||||
namespaced_name = f"{component_type}.{component_name}"
|
|
||||||
self._components[namespaced_name].enabled = True
|
|
||||||
self._components_by_type[component_type][component_name].enabled = True
|
|
||||||
logger.info(f"组件 {component_name} 已启用")
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def disable_component(self, component_name: str, component_type: ComponentType) -> bool:
|
|
||||||
"""全局的禁用某个组件
|
|
||||||
Parameters:
|
|
||||||
component_name: 组件名称
|
|
||||||
component_type: 组件类型
|
|
||||||
Returns:
|
|
||||||
bool: 禁用成功返回True,失败返回False
|
|
||||||
"""
|
|
||||||
target_component_class = self.get_component_class(component_name, component_type)
|
|
||||||
target_component_info = self.get_component_info(component_name, component_type)
|
|
||||||
if not target_component_class or not target_component_info:
|
|
||||||
logger.warning(f"组件 {component_name} 未注册,无法禁用")
|
|
||||||
return False
|
|
||||||
target_component_info.enabled = False
|
|
||||||
try:
|
|
||||||
match component_type:
|
|
||||||
case ComponentType.ACTION:
|
|
||||||
self._default_actions.pop(component_name)
|
|
||||||
case ComponentType.COMMAND:
|
|
||||||
self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != component_name}
|
|
||||||
case ComponentType.TOOL:
|
|
||||||
self._llm_available_tools.pop(component_name)
|
|
||||||
case ComponentType.EVENT_HANDLER:
|
|
||||||
self._enabled_event_handlers.pop(component_name)
|
|
||||||
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
|
||||||
|
|
||||||
await events_manager.unregister_event_subscriber(component_name)
|
|
||||||
self._components[component_name].enabled = False
|
|
||||||
self._components_by_type[component_type][component_name].enabled = False
|
|
||||||
logger.info(f"组件 {component_name} 已禁用")
|
|
||||||
return True
|
|
||||||
except KeyError as e:
|
|
||||||
logger.warning(f"禁用组件时未找到组件或已禁用: {component_name}, 发生错误: {e}")
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"禁用组件 {component_name} 时发生错误: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# === 组件查询方法 ===
|
|
||||||
def get_component_info(
|
|
||||||
self, component_name: str, component_type: Optional[ComponentType] = None
|
|
||||||
) -> Optional[ComponentInfo]:
|
|
||||||
# sourcery skip: class-extract-method
|
|
||||||
"""获取组件信息,支持自动命名空间解析
|
|
||||||
|
|
||||||
Args:
|
|
||||||
component_name: 组件名称,可以是原始名称或命名空间化的名称
|
|
||||||
component_type: 组件类型,如果提供则优先在该类型中查找
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[ComponentInfo]: 组件信息或None
|
|
||||||
"""
|
|
||||||
# 1. 如果已经是命名空间化的名称,直接查找
|
|
||||||
if "." in component_name:
|
|
||||||
return self._components.get(component_name)
|
|
||||||
|
|
||||||
# 2. 如果指定了组件类型,构造命名空间化的名称查找
|
|
||||||
if component_type:
|
|
||||||
namespaced_name = f"{component_type}.{component_name}"
|
|
||||||
return self._components.get(namespaced_name)
|
|
||||||
|
|
||||||
# 3. 如果没有指定类型,尝试在所有命名空间中查找
|
|
||||||
candidates = []
|
|
||||||
for namespace_prefix in [types.value for types in ComponentType]:
|
|
||||||
namespaced_name = f"{namespace_prefix}.{component_name}"
|
|
||||||
if component_info := self._components.get(namespaced_name):
|
|
||||||
candidates.append((namespace_prefix, namespaced_name, component_info))
|
|
||||||
|
|
||||||
if len(candidates) == 1:
|
|
||||||
# 只有一个匹配,直接返回
|
|
||||||
return candidates[0][2]
|
|
||||||
elif len(candidates) > 1:
|
|
||||||
# 多个匹配,记录警告并返回第一个
|
|
||||||
namespaces = [ns for ns, _, _ in candidates]
|
|
||||||
logger.warning(
|
|
||||||
f"组件名称 '{component_name}' 在多个命名空间中存在: {namespaces},使用第一个匹配项: {candidates[0][1]}"
|
|
||||||
)
|
|
||||||
return candidates[0][2]
|
|
||||||
|
|
||||||
# 4. 都没找到
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_component_class(
|
|
||||||
self,
|
|
||||||
component_name: str,
|
|
||||||
component_type: Optional[ComponentType] = None,
|
|
||||||
) -> Optional[Union[Type[BaseCommand], Type[BaseAction], Type[BaseEventHandler], Type[BaseTool]]]:
|
|
||||||
"""获取组件类,支持自动命名空间解析
|
|
||||||
|
|
||||||
Args:
|
|
||||||
component_name: 组件名称,可以是原始名称或命名空间化的名称
|
|
||||||
component_type: 组件类型,如果提供则优先在该类型中查找
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[Union[BaseCommand, BaseAction]]: 组件类或None
|
|
||||||
"""
|
|
||||||
# 1. 如果已经是命名空间化的名称,直接查找
|
|
||||||
if "." in component_name:
|
|
||||||
return self._components_classes.get(component_name)
|
|
||||||
|
|
||||||
# 2. 如果指定了组件类型,构造命名空间化的名称查找
|
|
||||||
if component_type:
|
|
||||||
namespaced_name = f"{component_type.value}.{component_name}"
|
|
||||||
return self._components_classes.get(namespaced_name)
|
|
||||||
|
|
||||||
# 3. 如果没有指定类型,尝试在所有命名空间中查找
|
|
||||||
candidates = []
|
|
||||||
for namespace_prefix in [types.value for types in ComponentType]:
|
|
||||||
namespaced_name = f"{namespace_prefix}.{component_name}"
|
|
||||||
if component_class := self._components_classes.get(namespaced_name):
|
|
||||||
candidates.append((namespace_prefix, namespaced_name, component_class))
|
|
||||||
|
|
||||||
if len(candidates) == 1:
|
|
||||||
# 只有一个匹配,直接返回
|
|
||||||
_, full_name, cls = candidates[0]
|
|
||||||
logger.debug(f"自动解析组件: '{component_name}' -> '{full_name}'")
|
|
||||||
return cls
|
|
||||||
elif len(candidates) > 1:
|
|
||||||
# 多个匹配,记录警告并返回第一个
|
|
||||||
namespaces = [ns for ns, _, _ in candidates]
|
|
||||||
logger.warning(
|
|
||||||
f"组件名称 '{component_name}' 在多个命名空间中存在: {namespaces},使用第一个匹配项: {candidates[0][1]}"
|
|
||||||
)
|
|
||||||
return candidates[0][2]
|
|
||||||
|
|
||||||
# 4. 都没找到
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
|
|
||||||
"""获取指定类型的所有组件"""
|
|
||||||
return self._components_by_type.get(component_type, {}).copy()
|
|
||||||
|
|
||||||
def get_enabled_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
|
|
||||||
"""获取指定类型的所有启用组件"""
|
|
||||||
components = self.get_components_by_type(component_type)
|
|
||||||
return {name: info for name, info in components.items() if info.enabled}
|
|
||||||
|
|
||||||
# === Action特定查询方法 ===
|
|
||||||
|
|
||||||
def get_action_registry(self) -> Dict[str, Type[BaseAction]]:
|
|
||||||
"""获取Action注册表"""
|
|
||||||
return self._action_registry.copy()
|
|
||||||
|
|
||||||
def get_registered_action_info(self, action_name: str) -> Optional[ActionInfo]:
|
|
||||||
"""获取Action信息"""
|
|
||||||
info = self.get_component_info(action_name, ComponentType.ACTION)
|
|
||||||
return info if isinstance(info, ActionInfo) else None
|
|
||||||
|
|
||||||
def get_default_actions(self) -> Dict[str, ActionInfo]:
|
|
||||||
"""获取默认动作集"""
|
|
||||||
return self._default_actions.copy()
|
|
||||||
|
|
||||||
# === Command特定查询方法 ===
|
|
||||||
|
|
||||||
def get_command_registry(self) -> Dict[str, Type[BaseCommand]]:
|
|
||||||
"""获取Command注册表"""
|
|
||||||
return self._command_registry.copy()
|
|
||||||
|
|
||||||
def get_registered_command_info(self, command_name: str) -> Optional[CommandInfo]:
|
|
||||||
"""获取Command信息"""
|
|
||||||
info = self.get_component_info(command_name, ComponentType.COMMAND)
|
|
||||||
return info if isinstance(info, CommandInfo) else None
|
|
||||||
|
|
||||||
def get_command_patterns(self) -> Dict[Pattern, str]:
|
|
||||||
"""获取Command模式注册表"""
|
|
||||||
return self._command_patterns.copy()
|
|
||||||
|
|
||||||
def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, CommandInfo]]:
|
|
||||||
# sourcery skip: use-named-expression, use-next
|
|
||||||
"""根据文本查找匹配的命令
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: 输入文本
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple: (命令类, 匹配的命名组, 是否拦截消息, 插件名) 或 None
|
|
||||||
"""
|
|
||||||
|
|
||||||
candidates = [pattern for pattern in self._command_patterns if pattern.match(text)]
|
|
||||||
if not candidates:
|
|
||||||
return None
|
|
||||||
if len(candidates) > 1:
|
|
||||||
logger.warning(f"文本 '{text}' 匹配到多个命令模式: {candidates},使用第一个匹配")
|
|
||||||
command_name = self._command_patterns[candidates[0]]
|
|
||||||
command_info: CommandInfo = self.get_registered_command_info(command_name) # type: ignore
|
|
||||||
return (
|
|
||||||
self._command_registry[command_name],
|
|
||||||
candidates[0].match(text).groupdict(), # type: ignore
|
|
||||||
command_info,
|
|
||||||
)
|
|
||||||
|
|
||||||
# === Tool 特定查询方法 ===
|
|
||||||
def get_tool_registry(self) -> Dict[str, Type[BaseTool]]:
|
|
||||||
"""获取Tool注册表"""
|
|
||||||
return self._tool_registry.copy()
|
|
||||||
|
|
||||||
def get_llm_available_tools(self) -> Dict[str, Type[BaseTool]]:
|
|
||||||
"""获取LLM可用的Tool列表"""
|
|
||||||
return self._llm_available_tools.copy()
|
|
||||||
|
|
||||||
def get_registered_tool_info(self, tool_name: str) -> Optional[ToolInfo]:
|
|
||||||
"""获取Tool信息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tool_name: 工具名称
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ToolInfo: 工具信息对象,如果工具不存在则返回 None
|
|
||||||
"""
|
|
||||||
info = self.get_component_info(tool_name, ComponentType.TOOL)
|
|
||||||
return info if isinstance(info, ToolInfo) else None
|
|
||||||
|
|
||||||
# === EventHandler 特定查询方法 ===
|
|
||||||
|
|
||||||
def get_event_handler_registry(self) -> Dict[str, Type[BaseEventHandler]]:
|
|
||||||
"""获取事件处理器注册表"""
|
|
||||||
return self._event_handler_registry.copy()
|
|
||||||
|
|
||||||
def get_registered_event_handler_info(self, handler_name: str) -> Optional[EventHandlerInfo]:
|
|
||||||
"""获取事件处理器信息"""
|
|
||||||
info = self.get_component_info(handler_name, ComponentType.EVENT_HANDLER)
|
|
||||||
return info if isinstance(info, EventHandlerInfo) else None
|
|
||||||
|
|
||||||
def get_enabled_event_handlers(self) -> Dict[str, Type[BaseEventHandler]]:
|
|
||||||
"""获取启用的事件处理器"""
|
|
||||||
return self._enabled_event_handlers.copy()
|
|
||||||
|
|
||||||
# === 插件查询方法 ===
|
|
||||||
|
|
||||||
def get_plugin_info(self, plugin_name: str) -> Optional[PluginInfo]:
|
|
||||||
"""获取插件信息"""
|
|
||||||
return self._plugins.get(plugin_name)
|
|
||||||
|
|
||||||
def get_all_plugins(self) -> Dict[str, PluginInfo]:
|
|
||||||
"""获取所有插件"""
|
|
||||||
return self._plugins.copy()
|
|
||||||
|
|
||||||
# def get_enabled_plugins(self) -> Dict[str, PluginInfo]:
|
|
||||||
# """获取所有启用的插件"""
|
|
||||||
# return {name: info for name, info in self._plugins.items() if info.enabled}
|
|
||||||
|
|
||||||
def get_plugin_components(self, plugin_name: str) -> List[ComponentInfo]:
|
|
||||||
"""获取插件的所有组件"""
|
|
||||||
plugin_info = self.get_plugin_info(plugin_name)
|
|
||||||
return plugin_info.components if plugin_info else []
|
|
||||||
|
|
||||||
def get_plugin_config(self, plugin_name: str) -> Optional[dict]:
|
|
||||||
"""获取插件配置
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plugin_name: 插件名称
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[dict]: 插件配置字典或None
|
|
||||||
"""
|
|
||||||
# 从插件管理器获取插件实例的配置
|
|
||||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
|
||||||
|
|
||||||
plugin_instance = plugin_manager.get_plugin_instance(plugin_name)
|
|
||||||
return plugin_instance.config if plugin_instance else None
|
|
||||||
|
|
||||||
def get_registry_stats(self) -> Dict[str, Any]:
|
|
||||||
"""获取注册中心统计信息"""
|
|
||||||
action_components: int = 0
|
|
||||||
command_components: int = 0
|
|
||||||
tool_components: int = 0
|
|
||||||
events_handlers: int = 0
|
|
||||||
for component in self._components.values():
|
|
||||||
if component.component_type == ComponentType.ACTION:
|
|
||||||
action_components += 1
|
|
||||||
elif component.component_type == ComponentType.COMMAND:
|
|
||||||
command_components += 1
|
|
||||||
elif component.component_type == ComponentType.TOOL:
|
|
||||||
tool_components += 1
|
|
||||||
elif component.component_type == ComponentType.EVENT_HANDLER:
|
|
||||||
events_handlers += 1
|
|
||||||
|
|
||||||
workflow_step_count = sum(len(steps) for steps in self._workflow_steps.values())
|
|
||||||
enabled_workflow_step_count = sum(
|
|
||||||
len([step for step in steps.values() if step.enabled]) for steps in self._workflow_steps.values()
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"action_components": action_components,
|
|
||||||
"command_components": command_components,
|
|
||||||
"tool_components": tool_components,
|
|
||||||
"event_handlers": events_handlers,
|
|
||||||
"total_components": len(self._components),
|
|
||||||
"total_plugins": len(self._plugins),
|
|
||||||
"components_by_type": {
|
|
||||||
component_type.value: len(components) for component_type, components in self._components_by_type.items()
|
|
||||||
},
|
|
||||||
"enabled_components": len([c for c in self._components.values() if c.enabled]),
|
|
||||||
"enabled_plugins": len([p for p in self._plugins.values() if p.enabled]),
|
|
||||||
"workflow_steps": workflow_step_count,
|
|
||||||
"enabled_workflow_steps": enabled_workflow_step_count,
|
|
||||||
"workflow_steps_by_stage": {stage.value: len(steps) for stage, steps in self._workflow_steps.items()},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
component_registry = ComponentRegistry()
|
|
||||||
@@ -1,512 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import contextlib
|
|
||||||
from typing import List, Dict, Optional, Type, Tuple, TYPE_CHECKING
|
|
||||||
|
|
||||||
from src.chat.message_receive.message import MessageSending, SessionMessage
|
|
||||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages, CustomEventHandlerResult
|
|
||||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
|
||||||
from src.plugin_system.base.workflow_types import WorkflowContext, WorkflowStepResult
|
|
||||||
from .global_announcement_manager import global_announcement_manager
|
|
||||||
from .workflow_engine import workflow_engine
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from src.common.data_models.llm_data_model import LLMGenerationDataModel
|
|
||||||
|
|
||||||
logger = get_logger("events_manager")
|
|
||||||
|
|
||||||
|
|
||||||
class EventsManager:
|
|
||||||
def __init__(self):
|
|
||||||
# 有权重的 events 订阅者注册表
|
|
||||||
self._events_subscribers: Dict[EventType | str, List[BaseEventHandler]] = {}
|
|
||||||
self._handler_mapping: Dict[str, Type[BaseEventHandler]] = {} # 事件处理器映射表
|
|
||||||
self._handler_tasks: Dict[str, List[asyncio.Task]] = {} # 事件处理器正在处理的任务
|
|
||||||
self._events_result_history: Dict[EventType | str, List[CustomEventHandlerResult]] = {} # 事件的结果历史记录
|
|
||||||
self._history_enable_map: Dict[EventType | str, bool] = {} # 是否启用历史记录的映射表,同时作为events注册表
|
|
||||||
|
|
||||||
# 事件注册(同时作为注册样例)
|
|
||||||
for event in EventType:
|
|
||||||
self.register_event(event, enable_history_result=False)
|
|
||||||
|
|
||||||
def register_event(self, event_type: EventType | str, enable_history_result: bool = False):
|
|
||||||
if event_type in self._events_subscribers:
|
|
||||||
raise ValueError(f"事件类型 {event_type} 已存在")
|
|
||||||
self._events_subscribers[event_type] = []
|
|
||||||
self._history_enable_map[event_type] = enable_history_result
|
|
||||||
if enable_history_result:
|
|
||||||
self._events_result_history[event_type] = []
|
|
||||||
|
|
||||||
def register_event_subscriber(self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]) -> bool:
|
|
||||||
"""注册事件处理器
|
|
||||||
|
|
||||||
Args:
|
|
||||||
handler_info (EventHandlerInfo): 事件处理器信息
|
|
||||||
handler_class (Type[BaseEventHandler]): 事件处理器类
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否注册成功
|
|
||||||
"""
|
|
||||||
if not issubclass(handler_class, BaseEventHandler):
|
|
||||||
logger.error(f"类 {handler_class.__name__} 不是 BaseEventHandler 的子类")
|
|
||||||
return False
|
|
||||||
|
|
||||||
handler_name = handler_info.name
|
|
||||||
|
|
||||||
if handler_name in self._handler_mapping:
|
|
||||||
logger.warning(f"事件处理器 {handler_name} 已存在,跳过注册")
|
|
||||||
return False
|
|
||||||
|
|
||||||
if handler_info.event_type not in self._history_enable_map:
|
|
||||||
if isinstance(handler_info.event_type, str):
|
|
||||||
self.register_event(handler_info.event_type, enable_history_result=False)
|
|
||||||
logger.info(f"自动注册自定义事件类型: {handler_info.event_type}")
|
|
||||||
else:
|
|
||||||
logger.error(f"事件类型 {handler_info.event_type} 未注册,无法为其注册处理器 {handler_name}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
self._handler_mapping[handler_name] = handler_class
|
|
||||||
return self._insert_event_handler(handler_class, handler_info)
|
|
||||||
|
|
||||||
async def handle_mai_events(
|
|
||||||
self,
|
|
||||||
event_type: EventType | str,
|
|
||||||
message: Optional[SessionMessage | MessageSending | MaiMessages] = None,
|
|
||||||
llm_prompt: Optional[str] = None,
|
|
||||||
llm_response: Optional["LLMGenerationDataModel"] = None,
|
|
||||||
stream_id: Optional[str] = None,
|
|
||||||
action_usage: Optional[List[str]] = None,
|
|
||||||
) -> Tuple[bool, Optional[MaiMessages]]:
|
|
||||||
"""
|
|
||||||
处理所有事件,根据事件类型分发给订阅的处理器。
|
|
||||||
"""
|
|
||||||
from src.plugin_system.core import component_registry
|
|
||||||
|
|
||||||
continue_flag = True
|
|
||||||
|
|
||||||
# 1. 准备消息
|
|
||||||
transformed_message = self._prepare_message(
|
|
||||||
event_type, message, llm_prompt, llm_response, stream_id, action_usage # type: ignore[arg-type]
|
|
||||||
)
|
|
||||||
if transformed_message:
|
|
||||||
transformed_message = transformed_message.deepcopy()
|
|
||||||
|
|
||||||
# 2. 获取并遍历处理器
|
|
||||||
handlers = self._events_subscribers.get(event_type, [])
|
|
||||||
if not handlers:
|
|
||||||
return True, None
|
|
||||||
|
|
||||||
current_stream_id = transformed_message.stream_id if transformed_message else None
|
|
||||||
modified_message: Optional[MaiMessages] = None
|
|
||||||
for handler in handlers:
|
|
||||||
# 3. 前置检查和配置加载
|
|
||||||
if (
|
|
||||||
current_stream_id
|
|
||||||
and handler.handler_name
|
|
||||||
in global_announcement_manager.get_disabled_chat_event_handlers(current_stream_id)
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 统一加载插件配置
|
|
||||||
plugin_config = component_registry.get_plugin_config(handler.plugin_name) or {}
|
|
||||||
handler.set_plugin_config(plugin_config)
|
|
||||||
|
|
||||||
# 4. 根据类型分发任务
|
|
||||||
if (
|
|
||||||
handler.intercept_message or event_type == EventType.ON_STOP
|
|
||||||
): # 让ON_STOP的所有事件处理器都发挥作用,防止还没执行即被取消
|
|
||||||
# 阻塞执行,并更新 continue_flag
|
|
||||||
should_continue, modified_message = await self._dispatch_intercepting_handler_task(
|
|
||||||
handler, event_type, modified_message or transformed_message
|
|
||||||
)
|
|
||||||
continue_flag = continue_flag and should_continue
|
|
||||||
else:
|
|
||||||
# 异步执行,不阻塞
|
|
||||||
self._dispatch_handler_task(handler, event_type, transformed_message)
|
|
||||||
|
|
||||||
# 桥接到新版本插件运行时
|
|
||||||
continue_flag, modified_message = await self._bridge_to_new_runtime(
|
|
||||||
event_type, continue_flag, modified_message or transformed_message
|
|
||||||
)
|
|
||||||
|
|
||||||
return continue_flag, modified_message
|
|
||||||
|
|
||||||
async def handle_workflow_message(
|
|
||||||
self,
|
|
||||||
message: Optional[SessionMessage | MessageSending | MaiMessages] = None,
|
|
||||||
stream_id: Optional[str] = None,
|
|
||||||
action_usage: Optional[List[str]] = None,
|
|
||||||
context: Optional[WorkflowContext] = None,
|
|
||||||
) -> Tuple[WorkflowStepResult, Optional[MaiMessages], WorkflowContext]:
|
|
||||||
"""执行线性workflow消息流转(MVP兼容入口)。"""
|
|
||||||
initial_message = self._prepare_message(EventType.ON_MESSAGE_PRE_PROCESS, message=message, stream_id=stream_id)
|
|
||||||
|
|
||||||
async def _dispatch(
|
|
||||||
event_type: EventType | str,
|
|
||||||
workflow_message: Optional[MaiMessages],
|
|
||||||
workflow_stream_id: Optional[str],
|
|
||||||
workflow_action_usage: Optional[List[str]],
|
|
||||||
) -> Tuple[bool, Optional[MaiMessages]]:
|
|
||||||
return await self.handle_mai_events(
|
|
||||||
event_type=event_type,
|
|
||||||
message=workflow_message,
|
|
||||||
stream_id=workflow_stream_id,
|
|
||||||
action_usage=workflow_action_usage,
|
|
||||||
)
|
|
||||||
|
|
||||||
return await workflow_engine.execute_linear(
|
|
||||||
dispatch_event=_dispatch,
|
|
||||||
message=initial_message,
|
|
||||||
stream_id=stream_id,
|
|
||||||
action_usage=action_usage,
|
|
||||||
context=context,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def cancel_handler_tasks(self, handler_name: str) -> None:
|
|
||||||
tasks_to_be_cancelled = self._handler_tasks.get(handler_name, [])
|
|
||||||
if remaining_tasks := [task for task in tasks_to_be_cancelled if not task.done()]:
|
|
||||||
for task in remaining_tasks:
|
|
||||||
task.cancel()
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(asyncio.gather(*remaining_tasks, return_exceptions=True), timeout=5)
|
|
||||||
logger.info(f"已取消事件处理器 {handler_name} 的所有任务")
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
logger.warning(f"取消事件处理器 {handler_name} 的任务超时,开始强制取消")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"取消事件处理器 {handler_name} 的任务时发生异常: {e}")
|
|
||||||
if handler_name in self._handler_tasks:
|
|
||||||
del self._handler_tasks[handler_name]
|
|
||||||
|
|
||||||
async def unregister_event_subscriber(self, handler_name: str) -> bool:
|
|
||||||
"""取消注册事件处理器"""
|
|
||||||
if handler_name not in self._handler_mapping:
|
|
||||||
logger.warning(f"事件处理器 {handler_name} 不存在,无法取消注册")
|
|
||||||
return False
|
|
||||||
|
|
||||||
await self.cancel_handler_tasks(handler_name)
|
|
||||||
|
|
||||||
handler_class = self._handler_mapping.pop(handler_name)
|
|
||||||
if not self._remove_event_handler_instance(handler_class):
|
|
||||||
return False
|
|
||||||
|
|
||||||
logger.info(f"事件处理器 {handler_name} 已成功取消注册")
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def get_event_result_history(self, event_type: EventType | str) -> List[CustomEventHandlerResult]:
|
|
||||||
"""获取事件的结果历史记录"""
|
|
||||||
if event_type == EventType.UNKNOWN:
|
|
||||||
raise ValueError("未知事件类型")
|
|
||||||
if event_type not in self._history_enable_map:
|
|
||||||
raise ValueError(f"事件类型 {event_type} 未注册")
|
|
||||||
if not self._history_enable_map[event_type]:
|
|
||||||
raise ValueError(f"事件类型 {event_type} 的历史记录未启用")
|
|
||||||
|
|
||||||
return self._events_result_history[event_type]
|
|
||||||
|
|
||||||
async def clear_event_result_history(self, event_type: EventType | str) -> None:
|
|
||||||
"""清空事件的结果历史记录"""
|
|
||||||
if event_type == EventType.UNKNOWN:
|
|
||||||
raise ValueError("未知事件类型")
|
|
||||||
if event_type not in self._history_enable_map:
|
|
||||||
raise ValueError(f"事件类型 {event_type} 未注册")
|
|
||||||
if not self._history_enable_map[event_type]:
|
|
||||||
raise ValueError(f"事件类型 {event_type} 的历史记录未启用")
|
|
||||||
|
|
||||||
self._events_result_history[event_type] = []
|
|
||||||
|
|
||||||
def _insert_event_handler(self, handler_class: Type[BaseEventHandler], handler_info: EventHandlerInfo) -> bool:
|
|
||||||
"""插入事件处理器到对应的事件类型列表中并设置其插件配置"""
|
|
||||||
if handler_class.event_type == EventType.UNKNOWN:
|
|
||||||
logger.error(f"事件处理器 {handler_class.__name__} 的事件类型未知,无法注册")
|
|
||||||
return False
|
|
||||||
if handler_class.event_type not in self._events_subscribers:
|
|
||||||
self._events_subscribers[handler_class.event_type] = []
|
|
||||||
handler_instance = handler_class()
|
|
||||||
handler_instance.set_plugin_name(handler_info.plugin_name or "unknown")
|
|
||||||
self._events_subscribers[handler_class.event_type].append(handler_instance)
|
|
||||||
self._events_subscribers[handler_class.event_type].sort(key=lambda x: x.weight, reverse=True)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _remove_event_handler_instance(self, handler_class: Type[BaseEventHandler]) -> bool:
|
|
||||||
"""从事件类型列表中移除事件处理器"""
|
|
||||||
display_handler_name = handler_class.handler_name or handler_class.__name__
|
|
||||||
if handler_class.event_type == EventType.UNKNOWN:
|
|
||||||
logger.warning(f"事件处理器 {display_handler_name} 的事件类型未知,不存在于处理器列表中")
|
|
||||||
return False
|
|
||||||
|
|
||||||
handlers = self._events_subscribers[handler_class.event_type]
|
|
||||||
for i, handler in enumerate(handlers):
|
|
||||||
if isinstance(handler, handler_class):
|
|
||||||
del handlers[i]
|
|
||||||
logger.debug(f"事件处理器 {display_handler_name} 已移除")
|
|
||||||
return True
|
|
||||||
|
|
||||||
logger.warning(f"未找到事件处理器 {display_handler_name},无法移除")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _transform_event_message(
|
|
||||||
self,
|
|
||||||
message: SessionMessage | MessageSending,
|
|
||||||
llm_prompt: Optional[str] = None,
|
|
||||||
llm_response: Optional["LLMGenerationDataModel"] = None,
|
|
||||||
) -> MaiMessages:
|
|
||||||
"""转换事件消息格式"""
|
|
||||||
from maim_message import Seg
|
|
||||||
|
|
||||||
# 直接赋值部分内容
|
|
||||||
transformed_message = MaiMessages(
|
|
||||||
llm_prompt=llm_prompt,
|
|
||||||
llm_response_content=llm_response.content if llm_response else None,
|
|
||||||
llm_response_reasoning=llm_response.reasoning if llm_response else None,
|
|
||||||
llm_response_model=llm_response.model if llm_response else None,
|
|
||||||
llm_response_tool_call=llm_response.tool_calls if llm_response else None,
|
|
||||||
raw_message=message.processed_plain_text or "",
|
|
||||||
additional_data={},
|
|
||||||
)
|
|
||||||
|
|
||||||
# 消息段处理
|
|
||||||
if isinstance(message, MessageSending):
|
|
||||||
if message.message_segment.type == "seglist":
|
|
||||||
transformed_message.message_segments = list(message.message_segment.data) # type: ignore
|
|
||||||
else:
|
|
||||||
transformed_message.message_segments = [message.message_segment]
|
|
||||||
else:
|
|
||||||
# SessionMessage: 使用 processed_plain_text 构造简单段
|
|
||||||
transformed_message.message_segments = [Seg(type="text", data=message.processed_plain_text or "")]
|
|
||||||
|
|
||||||
# stream_id 处理
|
|
||||||
transformed_message.stream_id = message.session_id if hasattr(message, "session_id") else ""
|
|
||||||
|
|
||||||
# 处理后文本
|
|
||||||
transformed_message.plain_text = message.processed_plain_text
|
|
||||||
|
|
||||||
# 基本信息
|
|
||||||
if isinstance(message, MessageSending):
|
|
||||||
transformed_message.message_base_info["platform"] = message.platform
|
|
||||||
if message.session.group_id:
|
|
||||||
transformed_message.is_group_message = True
|
|
||||||
group_name = ""
|
|
||||||
if message.session.context and message.session.context.message and message.session.context.message.message_info.group_info:
|
|
||||||
group_name = message.session.context.message.message_info.group_info.group_name
|
|
||||||
transformed_message.message_base_info.update({
|
|
||||||
"group_id": message.session.group_id,
|
|
||||||
"group_name": group_name,
|
|
||||||
})
|
|
||||||
transformed_message.message_base_info.update({
|
|
||||||
"user_id": message.bot_user_info.user_id,
|
|
||||||
"user_cardname": message.bot_user_info.user_cardname,
|
|
||||||
"user_nickname": message.bot_user_info.user_nickname,
|
|
||||||
})
|
|
||||||
if not transformed_message.is_group_message:
|
|
||||||
transformed_message.is_private_message = True
|
|
||||||
elif hasattr(message, "message_info") and message.message_info:
|
|
||||||
if message.platform:
|
|
||||||
transformed_message.message_base_info["platform"] = message.platform
|
|
||||||
if message.message_info.group_info:
|
|
||||||
transformed_message.is_group_message = True
|
|
||||||
transformed_message.message_base_info.update({
|
|
||||||
"group_id": message.message_info.group_info.group_id,
|
|
||||||
"group_name": message.message_info.group_info.group_name,
|
|
||||||
})
|
|
||||||
if message.message_info.user_info:
|
|
||||||
if not transformed_message.is_group_message:
|
|
||||||
transformed_message.is_private_message = True
|
|
||||||
transformed_message.message_base_info.update({
|
|
||||||
"user_id": message.message_info.user_info.user_id,
|
|
||||||
"user_cardname": message.message_info.user_info.user_cardname,
|
|
||||||
"user_nickname": message.message_info.user_info.user_nickname,
|
|
||||||
})
|
|
||||||
|
|
||||||
return transformed_message
|
|
||||||
|
|
||||||
def _build_message_from_stream(
|
|
||||||
self, stream_id: str, llm_prompt: Optional[str] = None, llm_response: Optional["LLMGenerationDataModel"] = None
|
|
||||||
) -> MaiMessages:
|
|
||||||
"""从流ID构建消息"""
|
|
||||||
session = _chat_manager.get_session_by_session_id(stream_id)
|
|
||||||
assert session, f"未找到流ID为 {stream_id} 的会话"
|
|
||||||
message = session.context.message
|
|
||||||
return self._transform_event_message(message, llm_prompt, llm_response)
|
|
||||||
|
|
||||||
def _transform_event_without_message(
|
|
||||||
self,
|
|
||||||
stream_id: str,
|
|
||||||
llm_prompt: Optional[str] = None,
|
|
||||||
llm_response: Optional["LLMGenerationDataModel"] = None,
|
|
||||||
action_usage: Optional[List[str]] = None,
|
|
||||||
) -> MaiMessages:
|
|
||||||
"""没有message对象时进行转换"""
|
|
||||||
session = _chat_manager.get_session_by_session_id(stream_id)
|
|
||||||
assert session, f"未找到流ID为 {stream_id} 的会话"
|
|
||||||
return MaiMessages(
|
|
||||||
stream_id=stream_id,
|
|
||||||
llm_prompt=llm_prompt,
|
|
||||||
llm_response_content=(llm_response.content if llm_response else None),
|
|
||||||
llm_response_reasoning=(llm_response.reasoning if llm_response else None),
|
|
||||||
llm_response_model=(llm_response.model if llm_response else None),
|
|
||||||
llm_response_tool_call=(llm_response.tool_calls if llm_response else None),
|
|
||||||
is_group_message=session.is_group_session,
|
|
||||||
is_private_message=not session.is_group_session,
|
|
||||||
action_usage=action_usage,
|
|
||||||
additional_data={"response_is_processed": True},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _bridge_to_new_runtime(
|
|
||||||
self,
|
|
||||||
event_type: EventType | str,
|
|
||||||
continue_flag: bool,
|
|
||||||
message: Optional[MaiMessages],
|
|
||||||
) -> Tuple[bool, Optional[MaiMessages]]:
|
|
||||||
"""将事件桥接到新版本插件运行时
|
|
||||||
|
|
||||||
如果旧 handler 已经 abort(continue_flag=False),直接跳过。
|
|
||||||
"""
|
|
||||||
if not continue_flag:
|
|
||||||
return continue_flag, message
|
|
||||||
|
|
||||||
try:
|
|
||||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
|
||||||
|
|
||||||
prm = get_plugin_runtime_manager()
|
|
||||||
if not prm.is_running:
|
|
||||||
return continue_flag, message
|
|
||||||
|
|
||||||
event_value = event_type.value if isinstance(event_type, EventType) else str(event_type)
|
|
||||||
message_dict = message.to_dict() if message and hasattr(message, "to_dict") else None
|
|
||||||
|
|
||||||
new_continue, new_msg_dict = await prm.bridge_event(
|
|
||||||
event_type_value=event_value,
|
|
||||||
message_dict=message_dict,
|
|
||||||
)
|
|
||||||
# 新运行时返回 abort 则合并
|
|
||||||
if not new_continue:
|
|
||||||
continue_flag = False
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"桥接事件到新运行时失败: {e}")
|
|
||||||
|
|
||||||
return continue_flag, message
|
|
||||||
|
|
||||||
def _prepare_message(
|
|
||||||
self,
|
|
||||||
event_type: EventType | str,
|
|
||||||
message: Optional[SessionMessage | MessageSending | MaiMessages] = None,
|
|
||||||
llm_prompt: Optional[str] = None,
|
|
||||||
llm_response: Optional["LLMGenerationDataModel"] = None,
|
|
||||||
stream_id: Optional[str] = None,
|
|
||||||
action_usage: Optional[List[str]] = None,
|
|
||||||
) -> Optional[MaiMessages]:
|
|
||||||
"""根据事件类型和输入,准备和转换消息对象。"""
|
|
||||||
if isinstance(message, MaiMessages):
|
|
||||||
return message.deepcopy()
|
|
||||||
|
|
||||||
if message:
|
|
||||||
return self._transform_event_message(message, llm_prompt, llm_response)
|
|
||||||
|
|
||||||
if event_type not in [EventType.ON_START, EventType.ON_STOP]:
|
|
||||||
assert stream_id, "如果没有消息,必须为非启动/关闭事件提供流ID"
|
|
||||||
if event_type in [EventType.ON_MESSAGE, EventType.ON_PLAN, EventType.POST_LLM, EventType.AFTER_LLM]:
|
|
||||||
return self._build_message_from_stream(stream_id, llm_prompt, llm_response)
|
|
||||||
else:
|
|
||||||
return self._transform_event_without_message(stream_id, llm_prompt, llm_response, action_usage)
|
|
||||||
|
|
||||||
return None # ON_START, ON_STOP事件没有消息体
|
|
||||||
|
|
||||||
def _dispatch_handler_task(
|
|
||||||
self, handler: BaseEventHandler, event_type: EventType | str, message: Optional[MaiMessages] = None
|
|
||||||
):
|
|
||||||
"""分发一个非阻塞(异步)的事件处理任务。"""
|
|
||||||
if event_type == EventType.UNKNOWN:
|
|
||||||
raise ValueError("未知事件类型")
|
|
||||||
try:
|
|
||||||
task = asyncio.create_task(handler.execute(message))
|
|
||||||
|
|
||||||
task_name = f"{handler.plugin_name}-{handler.handler_name}"
|
|
||||||
task.set_name(task_name)
|
|
||||||
task.add_done_callback(lambda t: self._task_done_callback(t, event_type))
|
|
||||||
|
|
||||||
self._handler_tasks.setdefault(handler.handler_name, []).append(task)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"创建事件处理器任务 {handler.handler_name} 时发生异常: {e}", exc_info=True)
|
|
||||||
|
|
||||||
async def _dispatch_intercepting_handler_task(
|
|
||||||
self, handler: BaseEventHandler, event_type: EventType | str, message: Optional[MaiMessages] = None
|
|
||||||
) -> Tuple[bool, Optional[MaiMessages]]:
|
|
||||||
"""分发并等待一个阻塞(同步)的事件处理器,返回是否应继续处理。"""
|
|
||||||
if event_type == EventType.UNKNOWN:
|
|
||||||
raise ValueError("未知事件类型")
|
|
||||||
if event_type not in self._history_enable_map:
|
|
||||||
raise ValueError(f"事件类型 {event_type} 未注册")
|
|
||||||
try:
|
|
||||||
result = await handler.execute(message)
|
|
||||||
|
|
||||||
expected_fields = ["success", "continue_processing", "return_message", "custom_result", "modified_message"]
|
|
||||||
|
|
||||||
if not isinstance(result, tuple) or len(result) != 5:
|
|
||||||
if isinstance(result, tuple):
|
|
||||||
annotated = ", ".join(f"{name}={val!r}" for name, val in zip(expected_fields, result, strict=False))
|
|
||||||
actual_desc = f"{len(result)} 个元素 ({annotated})"
|
|
||||||
else:
|
|
||||||
actual_desc = f"非 tuple 类型: {type(result)}"
|
|
||||||
|
|
||||||
logger.error(
|
|
||||||
f"[{self.__class__.__name__}] EventHandler {handler.handler_name} 返回值不符合预期:\n"
|
|
||||||
f" 模块来源: {handler.__class__.__module__}.{handler.__class__.__name__}\n"
|
|
||||||
f" 期望: 5 个元素 ({', '.join(expected_fields)})\n"
|
|
||||||
f" 实际: {actual_desc}"
|
|
||||||
)
|
|
||||||
return True, None
|
|
||||||
|
|
||||||
success, continue_processing, return_message, custom_result, modified_message = result
|
|
||||||
|
|
||||||
if not success:
|
|
||||||
logger.error(f"EventHandler {handler.handler_name} 执行失败: {return_message}")
|
|
||||||
else:
|
|
||||||
logger.debug(f"EventHandler {handler.handler_name} 执行成功: {return_message}")
|
|
||||||
|
|
||||||
if self._history_enable_map[event_type] and custom_result:
|
|
||||||
self._events_result_history[event_type].append(custom_result)
|
|
||||||
|
|
||||||
return continue_processing, modified_message
|
|
||||||
|
|
||||||
except KeyError:
|
|
||||||
logger.error(f"事件 {event_type} 注册的历史记录启用情况与实际不符合")
|
|
||||||
return True, None
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"EventHandler {handler.handler_name} 发生异常: {e}", exc_info=True)
|
|
||||||
return True, None # 发生异常时默认不中断其他处理
|
|
||||||
|
|
||||||
def _task_done_callback(
|
|
||||||
self,
|
|
||||||
task: asyncio.Task[Tuple[bool, bool, str | None, CustomEventHandlerResult | None, MaiMessages | None]],
|
|
||||||
event_type: EventType | str,
|
|
||||||
):
|
|
||||||
"""任务完成回调"""
|
|
||||||
task_name = task.get_name() or "Unknown Task"
|
|
||||||
if event_type == EventType.UNKNOWN:
|
|
||||||
raise ValueError("未知事件类型")
|
|
||||||
if event_type not in self._history_enable_map:
|
|
||||||
raise ValueError(f"事件类型 {event_type} 未注册")
|
|
||||||
try:
|
|
||||||
success, _, result, custom_result, _ = task.result() # 忽略是否继续的标志和消息的修改,因为消息本身未被拦截
|
|
||||||
if success:
|
|
||||||
logger.debug(f"事件处理任务 {task_name} 已成功完成: {result}")
|
|
||||||
else:
|
|
||||||
logger.error(f"事件处理任务 {task_name} 执行失败: {result}")
|
|
||||||
|
|
||||||
if self._history_enable_map[event_type] and custom_result:
|
|
||||||
self._events_result_history[event_type].append(custom_result)
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
except KeyError:
|
|
||||||
logger.error(f"事件 {event_type} 注册的历史记录启用情况与实际不符合")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"事件处理任务 {task_name} 发生异常: {e}")
|
|
||||||
finally:
|
|
||||||
with contextlib.suppress(ValueError, KeyError):
|
|
||||||
self._handler_tasks[task_name].remove(task)
|
|
||||||
|
|
||||||
|
|
||||||
events_manager = EventsManager()
|
|
||||||
@@ -1,634 +0,0 @@
|
|||||||
from importlib.util import spec_from_file_location, module_from_spec
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, List, Optional, Set, Tuple, Type
|
|
||||||
from collections import deque
|
|
||||||
import os
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.plugin_system.base.plugin_base import PluginBase
|
|
||||||
from src.plugin_system.base.component_types import ComponentType
|
|
||||||
from src.plugin_system.utils.manifest_utils import VersionComparator
|
|
||||||
from .component_registry import component_registry
|
|
||||||
from .plugin_service_registry import plugin_service_registry
|
|
||||||
|
|
||||||
logger = get_logger("plugin_manager")
|
|
||||||
|
|
||||||
|
|
||||||
class PluginManager:
|
|
||||||
"""
|
|
||||||
插件管理器类
|
|
||||||
|
|
||||||
负责加载,重载和卸载插件,同时管理插件的所有组件
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.plugin_directories: List[str] = [] # 插件根目录列表
|
|
||||||
self.plugin_classes: Dict[str, Type[PluginBase]] = {} # 全局插件类注册表,插件名 -> 插件类
|
|
||||||
self.plugin_paths: Dict[str, str] = {} # 记录插件名到目录路径的映射,插件名 -> 目录路径
|
|
||||||
|
|
||||||
self.loaded_plugins: Dict[str, PluginBase] = {} # 已加载的插件类实例注册表,插件名 -> 插件类实例
|
|
||||||
self.failed_plugins: Dict[str, str] = {} # 记录加载失败的插件文件及其错误信息,插件名 -> 错误信息
|
|
||||||
|
|
||||||
# 确保插件目录存在
|
|
||||||
self._ensure_plugin_directories()
|
|
||||||
logger.info("插件管理器初始化完成")
|
|
||||||
|
|
||||||
# === 插件目录管理 ===
|
|
||||||
|
|
||||||
def add_plugin_directory(self, directory: str) -> bool:
|
|
||||||
"""添加插件目录"""
|
|
||||||
if os.path.exists(directory):
|
|
||||||
if directory not in self.plugin_directories:
|
|
||||||
self.plugin_directories.append(directory)
|
|
||||||
logger.debug(f"已添加插件目录: {directory}")
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
logger.warning(f"插件不可重复加载: {directory}")
|
|
||||||
else:
|
|
||||||
logger.warning(f"插件目录不存在: {directory}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# === 插件加载管理 ===
|
|
||||||
|
|
||||||
def load_all_plugins(self) -> Tuple[int, int]:
|
|
||||||
"""加载所有插件
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[int, int]: (插件数量, 组件数量)
|
|
||||||
"""
|
|
||||||
logger.debug("开始加载所有插件...")
|
|
||||||
|
|
||||||
# 第一阶段:加载所有插件模块(注册插件类)
|
|
||||||
total_loaded_modules = 0
|
|
||||||
total_failed_modules = 0
|
|
||||||
|
|
||||||
for directory in self.plugin_directories:
|
|
||||||
loaded, failed = self._load_plugin_modules_from_directory(directory)
|
|
||||||
total_loaded_modules += loaded
|
|
||||||
total_failed_modules += failed
|
|
||||||
|
|
||||||
logger.debug(f"插件模块加载完成 - 成功: {total_loaded_modules}, 失败: {total_failed_modules}")
|
|
||||||
|
|
||||||
# 第二阶段前:根据依赖关系决定插件加载顺序
|
|
||||||
dependency_graph, missing_dependencies = self._build_plugin_dependency_graph()
|
|
||||||
sorted_plugins, cycle_plugins = self._resolve_plugin_load_order(dependency_graph)
|
|
||||||
|
|
||||||
total_registered = 0
|
|
||||||
total_failed_registration = 0
|
|
||||||
|
|
||||||
# 先处理缺失依赖
|
|
||||||
for plugin_name, missing in missing_dependencies.items():
|
|
||||||
if not missing:
|
|
||||||
continue
|
|
||||||
missing_dep_names = ", ".join(sorted(missing))
|
|
||||||
self.failed_plugins[plugin_name] = f"缺少依赖插件: {missing_dep_names}"
|
|
||||||
logger.error(f"❌ 插件加载失败: {plugin_name} - 缺少依赖插件: {missing_dep_names}")
|
|
||||||
total_failed_registration += 1
|
|
||||||
|
|
||||||
# 再处理循环依赖
|
|
||||||
for plugin_name in sorted(cycle_plugins):
|
|
||||||
if plugin_name in missing_dependencies and missing_dependencies[plugin_name]:
|
|
||||||
continue
|
|
||||||
self.failed_plugins[plugin_name] = "检测到循环依赖"
|
|
||||||
logger.error(f"❌ 插件加载失败: {plugin_name} - 检测到循环依赖")
|
|
||||||
total_failed_registration += 1
|
|
||||||
|
|
||||||
# 最后按拓扑序加载可加载插件
|
|
||||||
for plugin_name in sorted_plugins:
|
|
||||||
if plugin_name in cycle_plugins:
|
|
||||||
continue
|
|
||||||
if plugin_name in missing_dependencies and missing_dependencies[plugin_name]:
|
|
||||||
continue
|
|
||||||
load_status, count = self.load_registered_plugin_classes(plugin_name)
|
|
||||||
if load_status:
|
|
||||||
total_registered += 1
|
|
||||||
else:
|
|
||||||
total_failed_registration += count
|
|
||||||
|
|
||||||
self._show_stats(total_registered, total_failed_registration)
|
|
||||||
|
|
||||||
return total_registered, total_failed_registration
|
|
||||||
|
|
||||||
def load_registered_plugin_classes(self, plugin_name: str) -> Tuple[bool, int]:
|
|
||||||
# sourcery skip: extract-duplicate-method, extract-method
|
|
||||||
"""
|
|
||||||
加载已经注册的插件类
|
|
||||||
"""
|
|
||||||
plugin_class = self.plugin_classes.get(plugin_name)
|
|
||||||
if not plugin_class:
|
|
||||||
logger.error(f"插件 {plugin_name} 的插件类未注册或不存在")
|
|
||||||
return False, 1
|
|
||||||
try:
|
|
||||||
# 使用记录的插件目录路径
|
|
||||||
plugin_dir = self.plugin_paths.get(plugin_name)
|
|
||||||
|
|
||||||
# 如果没有记录,直接返回失败
|
|
||||||
if not plugin_dir:
|
|
||||||
return False, 1
|
|
||||||
|
|
||||||
plugin_instance = plugin_class(plugin_dir=plugin_dir) # 实例化插件(可能因为缺少manifest而失败)
|
|
||||||
if not plugin_instance:
|
|
||||||
logger.error(f"插件 {plugin_name} 实例化失败")
|
|
||||||
return False, 1
|
|
||||||
# 检查插件是否启用
|
|
||||||
if not plugin_instance.enable_plugin:
|
|
||||||
logger.info(f"插件 {plugin_name} 已禁用,跳过加载")
|
|
||||||
return False, 0
|
|
||||||
|
|
||||||
# 检查版本兼容性
|
|
||||||
is_compatible, compatibility_error = self._check_plugin_version_compatibility(
|
|
||||||
plugin_name, plugin_instance.manifest_data
|
|
||||||
)
|
|
||||||
if not is_compatible:
|
|
||||||
self.failed_plugins[plugin_name] = compatibility_error
|
|
||||||
logger.error(f"❌ 插件加载失败: {plugin_name} - {compatibility_error}")
|
|
||||||
return False, 1
|
|
||||||
if plugin_instance.register_plugin():
|
|
||||||
self.loaded_plugins[plugin_name] = plugin_instance
|
|
||||||
self._show_plugin_components(plugin_name)
|
|
||||||
return True, 1
|
|
||||||
else:
|
|
||||||
self.failed_plugins[plugin_name] = "插件注册失败"
|
|
||||||
logger.error(f"❌ 插件注册失败: {plugin_name}")
|
|
||||||
return False, 1
|
|
||||||
|
|
||||||
except FileNotFoundError as e:
|
|
||||||
# manifest文件缺失
|
|
||||||
error_msg = f"缺少manifest文件: {str(e)}"
|
|
||||||
self.failed_plugins[plugin_name] = error_msg
|
|
||||||
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
|
|
||||||
return False, 1
|
|
||||||
|
|
||||||
except ValueError as e:
|
|
||||||
# manifest文件格式错误或验证失败
|
|
||||||
traceback.print_exc()
|
|
||||||
error_msg = f"manifest验证失败: {str(e)}"
|
|
||||||
self.failed_plugins[plugin_name] = error_msg
|
|
||||||
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
|
|
||||||
return False, 1
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
# 其他错误
|
|
||||||
error_msg = f"未知错误: {str(e)}"
|
|
||||||
self.failed_plugins[plugin_name] = error_msg
|
|
||||||
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
|
|
||||||
logger.debug("详细错误信息: ", exc_info=True)
|
|
||||||
return False, 1
|
|
||||||
|
|
||||||
async def remove_registered_plugin(self, plugin_name: str) -> bool:
|
|
||||||
"""
|
|
||||||
禁用插件模块
|
|
||||||
"""
|
|
||||||
if not plugin_name:
|
|
||||||
raise ValueError("插件名称不能为空")
|
|
||||||
if plugin_name not in self.loaded_plugins:
|
|
||||||
logger.warning(f"插件 {plugin_name} 未加载")
|
|
||||||
return False
|
|
||||||
plugin_instance = self.loaded_plugins[plugin_name]
|
|
||||||
plugin_info = plugin_instance.plugin_info
|
|
||||||
success = True
|
|
||||||
for component in plugin_info.components:
|
|
||||||
success &= await component_registry.remove_component(component.name, component.component_type, plugin_name)
|
|
||||||
success &= component_registry.remove_plugin_registry(plugin_name)
|
|
||||||
plugin_service_registry.remove_services_by_plugin(plugin_name)
|
|
||||||
del self.loaded_plugins[plugin_name]
|
|
||||||
return success
|
|
||||||
|
|
||||||
async def reload_registered_plugin(self, plugin_name: str) -> bool:
|
|
||||||
"""
|
|
||||||
重载插件模块
|
|
||||||
"""
|
|
||||||
old_instance = self.loaded_plugins.get(plugin_name)
|
|
||||||
if not old_instance:
|
|
||||||
logger.warning(f"插件 {plugin_name} 未加载,无法重载")
|
|
||||||
return False
|
|
||||||
|
|
||||||
if not await self.remove_registered_plugin(plugin_name):
|
|
||||||
return False
|
|
||||||
|
|
||||||
if not self.load_registered_plugin_classes(plugin_name)[0]:
|
|
||||||
logger.error(f"插件 {plugin_name} 重载失败,开始回滚旧实例")
|
|
||||||
rollback_ok = await self._rollback_failed_reload(plugin_name, old_instance)
|
|
||||||
if rollback_ok:
|
|
||||||
logger.info(f"插件 {plugin_name} 已回滚到旧版本实例")
|
|
||||||
else:
|
|
||||||
logger.error(f"插件 {plugin_name} 回滚失败,插件当前不可用")
|
|
||||||
return False
|
|
||||||
|
|
||||||
logger.debug(f"插件 {plugin_name} 重载成功")
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def _rollback_failed_reload(self, plugin_name: str, old_instance: PluginBase) -> bool:
|
|
||||||
"""重载失败后回滚旧实例。"""
|
|
||||||
try:
|
|
||||||
await component_registry.remove_components_by_plugin(plugin_name)
|
|
||||||
component_registry.remove_plugin_registry(plugin_name)
|
|
||||||
plugin_service_registry.remove_services_by_plugin(plugin_name)
|
|
||||||
|
|
||||||
if not old_instance.register_plugin():
|
|
||||||
logger.error(f"插件 {plugin_name} 回滚失败: 旧实例重新注册失败")
|
|
||||||
return False
|
|
||||||
|
|
||||||
self.loaded_plugins[plugin_name] = old_instance
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"插件 {plugin_name} 回滚异常: {e}", exc_info=True)
|
|
||||||
return False
|
|
||||||
|
|
||||||
def rescan_plugin_directory(self) -> Tuple[int, int]:
|
|
||||||
"""
|
|
||||||
重新扫描插件根目录
|
|
||||||
"""
|
|
||||||
total_success = 0
|
|
||||||
total_fail = 0
|
|
||||||
for directory in self.plugin_directories:
|
|
||||||
if os.path.exists(directory):
|
|
||||||
logger.debug(f"重新扫描插件根目录: {directory}")
|
|
||||||
success, fail = self._load_plugin_modules_from_directory(directory)
|
|
||||||
total_success += success
|
|
||||||
total_fail += fail
|
|
||||||
else:
|
|
||||||
logger.warning(f"插件根目录不存在: {directory}")
|
|
||||||
return total_success, total_fail
|
|
||||||
|
|
||||||
def get_plugin_instance(self, plugin_name: str) -> Optional["PluginBase"]:
|
|
||||||
"""获取插件实例
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plugin_name: 插件名称
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[BasePlugin]: 插件实例或None
|
|
||||||
"""
|
|
||||||
return self.loaded_plugins.get(plugin_name)
|
|
||||||
|
|
||||||
# === 查询方法 ===
|
|
||||||
def list_loaded_plugins(self) -> List[str]:
|
|
||||||
"""
|
|
||||||
列出所有当前加载的插件。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list: 当前加载的插件名称列表。
|
|
||||||
"""
|
|
||||||
return list(self.loaded_plugins.keys())
|
|
||||||
|
|
||||||
def list_registered_plugins(self) -> List[str]:
|
|
||||||
"""
|
|
||||||
列出所有已注册的插件类。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list: 已注册的插件类名称列表。
|
|
||||||
"""
|
|
||||||
return list(self.plugin_classes.keys())
|
|
||||||
|
|
||||||
def get_plugin_path(self, plugin_name: str) -> Optional[str]:
|
|
||||||
"""
|
|
||||||
获取指定插件的路径。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plugin_name: 插件名称
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[str]: 插件目录的绝对路径,如果插件不存在则返回None。
|
|
||||||
"""
|
|
||||||
return self.plugin_paths.get(plugin_name)
|
|
||||||
|
|
||||||
# === 私有方法 ===
|
|
||||||
# == 目录管理 ==
|
|
||||||
def _ensure_plugin_directories(self) -> None:
|
|
||||||
"""确保所有插件根目录存在,如果不存在则创建"""
|
|
||||||
default_directories = ["src/plugins/built_in", "plugins"]
|
|
||||||
|
|
||||||
for directory in default_directories:
|
|
||||||
if not os.path.exists(directory):
|
|
||||||
os.makedirs(directory, exist_ok=True)
|
|
||||||
logger.info(f"创建插件根目录: {directory}")
|
|
||||||
if directory not in self.plugin_directories:
|
|
||||||
self.plugin_directories.append(directory)
|
|
||||||
logger.debug(f"已添加插件根目录: {directory}")
|
|
||||||
else:
|
|
||||||
logger.warning(f"根目录不可重复加载: {directory}")
|
|
||||||
|
|
||||||
# == 插件加载 ==
|
|
||||||
|
|
||||||
def _load_plugin_modules_from_directory(self, directory: str) -> tuple[int, int]:
|
|
||||||
"""从指定目录加载插件模块"""
|
|
||||||
loaded_count = 0
|
|
||||||
failed_count = 0
|
|
||||||
|
|
||||||
if not os.path.exists(directory):
|
|
||||||
logger.warning(f"插件根目录不存在: {directory}")
|
|
||||||
return 0, 1
|
|
||||||
|
|
||||||
logger.debug(f"正在扫描插件根目录: {directory}")
|
|
||||||
|
|
||||||
# 遍历目录中的所有包
|
|
||||||
for item in os.listdir(directory):
|
|
||||||
item_path = os.path.join(directory, item)
|
|
||||||
|
|
||||||
if os.path.isdir(item_path) and not item.startswith(".") and not item.startswith("__"):
|
|
||||||
plugin_file = os.path.join(item_path, "plugin.py")
|
|
||||||
if os.path.exists(plugin_file):
|
|
||||||
if self._load_plugin_module_file(plugin_file):
|
|
||||||
loaded_count += 1
|
|
||||||
else:
|
|
||||||
failed_count += 1
|
|
||||||
|
|
||||||
return loaded_count, failed_count
|
|
||||||
|
|
||||||
def _load_plugin_module_file(self, plugin_file: str) -> bool:
|
|
||||||
# sourcery skip: extract-method
|
|
||||||
"""加载单个插件模块文件
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plugin_file: 插件文件路径
|
|
||||||
plugin_name: 插件名称
|
|
||||||
plugin_dir: 插件目录路径
|
|
||||||
"""
|
|
||||||
# 生成模块名
|
|
||||||
plugin_path = Path(plugin_file)
|
|
||||||
module_name = ".".join(plugin_path.parent.parts)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 动态导入插件模块
|
|
||||||
spec = spec_from_file_location(module_name, plugin_file)
|
|
||||||
if spec is None or spec.loader is None:
|
|
||||||
logger.error(f"无法创建模块规范: {plugin_file}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
module = module_from_spec(spec)
|
|
||||||
module.__package__ = module_name # 设置模块包名
|
|
||||||
spec.loader.exec_module(module)
|
|
||||||
|
|
||||||
logger.debug(f"插件模块加载成功: {plugin_file}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = f"加载插件模块 {plugin_file} 失败: {e}"
|
|
||||||
logger.error(error_msg)
|
|
||||||
self.failed_plugins[module_name] = error_msg
|
|
||||||
return False
|
|
||||||
|
|
||||||
# == 依赖解析与加载顺序 ==
|
|
||||||
|
|
||||||
def _extract_declared_dependencies(self, plugin_name: str, plugin_class: Type[PluginBase]) -> Set[str]:
|
|
||||||
"""提取插件声明的依赖。
|
|
||||||
|
|
||||||
兼容声明格式:
|
|
||||||
- list[str]
|
|
||||||
- list[dict],其中dict至少包含name键
|
|
||||||
"""
|
|
||||||
dependencies: Set[str] = set()
|
|
||||||
raw_dependencies = getattr(plugin_class, "dependencies", [])
|
|
||||||
|
|
||||||
# 兼容错误声明
|
|
||||||
if isinstance(raw_dependencies, property):
|
|
||||||
logger.warning(f"插件 {plugin_name} 的 dependencies 未声明为类属性,将按无依赖处理")
|
|
||||||
return dependencies
|
|
||||||
if not isinstance(raw_dependencies, list):
|
|
||||||
logger.warning(f"插件 {plugin_name} 的 dependencies 不是列表,将按无依赖处理")
|
|
||||||
return dependencies
|
|
||||||
|
|
||||||
for dependency in raw_dependencies:
|
|
||||||
dependency_name = ""
|
|
||||||
if isinstance(dependency, str):
|
|
||||||
dependency_name = dependency.strip()
|
|
||||||
elif isinstance(dependency, dict):
|
|
||||||
dependency_name = str(dependency.get("name", "")).strip()
|
|
||||||
else:
|
|
||||||
logger.warning(f"插件 {plugin_name} 包含不支持的依赖声明类型: {type(dependency)}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not dependency_name:
|
|
||||||
continue
|
|
||||||
if dependency_name == plugin_name:
|
|
||||||
logger.warning(f"插件 {plugin_name} 声明了对自身的依赖,已忽略")
|
|
||||||
continue
|
|
||||||
|
|
||||||
dependencies.add(dependency_name)
|
|
||||||
|
|
||||||
return dependencies
|
|
||||||
|
|
||||||
def _build_plugin_dependency_graph(self) -> Tuple[Dict[str, Set[str]], Dict[str, Set[str]]]:
|
|
||||||
"""构建依赖图并返回缺失依赖映射。"""
|
|
||||||
plugin_names = set(self.plugin_classes.keys())
|
|
||||||
dependency_graph: Dict[str, Set[str]] = {name: set() for name in plugin_names}
|
|
||||||
missing_dependencies: Dict[str, Set[str]] = {name: set() for name in plugin_names}
|
|
||||||
|
|
||||||
for plugin_name, plugin_class in self.plugin_classes.items():
|
|
||||||
declared_dependencies = self._extract_declared_dependencies(plugin_name, plugin_class)
|
|
||||||
for dependency_name in declared_dependencies:
|
|
||||||
if dependency_name in plugin_names:
|
|
||||||
dependency_graph[plugin_name].add(dependency_name)
|
|
||||||
else:
|
|
||||||
missing_dependencies[plugin_name].add(dependency_name)
|
|
||||||
|
|
||||||
return dependency_graph, missing_dependencies
|
|
||||||
|
|
||||||
def _resolve_plugin_load_order(self, dependency_graph: Dict[str, Set[str]]) -> Tuple[List[str], Set[str]]:
|
|
||||||
"""根据依赖图计算加载顺序,并检测循环依赖。"""
|
|
||||||
indegree: Dict[str, int] = {
|
|
||||||
plugin_name: len(dependencies) for plugin_name, dependencies in dependency_graph.items()
|
|
||||||
}
|
|
||||||
reverse_graph: Dict[str, Set[str]] = {plugin_name: set() for plugin_name in dependency_graph}
|
|
||||||
|
|
||||||
for plugin_name, dependencies in dependency_graph.items():
|
|
||||||
for dependency_name in dependencies:
|
|
||||||
reverse_graph[dependency_name].add(plugin_name)
|
|
||||||
|
|
||||||
zero_indegree_queue = deque(sorted([name for name, degree in indegree.items() if degree == 0]))
|
|
||||||
load_order: List[str] = []
|
|
||||||
|
|
||||||
while zero_indegree_queue:
|
|
||||||
current_plugin = zero_indegree_queue.popleft()
|
|
||||||
load_order.append(current_plugin)
|
|
||||||
|
|
||||||
for dependent_plugin in sorted(reverse_graph[current_plugin]):
|
|
||||||
indegree[dependent_plugin] -= 1
|
|
||||||
if indegree[dependent_plugin] == 0:
|
|
||||||
zero_indegree_queue.append(dependent_plugin)
|
|
||||||
|
|
||||||
cycle_plugins = {name for name, degree in indegree.items() if degree > 0}
|
|
||||||
if cycle_plugins:
|
|
||||||
logger.error(f"检测到循环依赖插件: {', '.join(sorted(cycle_plugins))}")
|
|
||||||
|
|
||||||
return load_order, cycle_plugins
|
|
||||||
|
|
||||||
# == 兼容性检查 ==
|
|
||||||
|
|
||||||
def _check_plugin_version_compatibility(self, plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]:
|
|
||||||
"""检查插件版本兼容性
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plugin_name: 插件名称
|
|
||||||
manifest_data: manifest数据
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[bool, str]: (是否兼容, 错误信息)
|
|
||||||
"""
|
|
||||||
if "host_application" not in manifest_data:
|
|
||||||
return True, "" # 没有版本要求,默认兼容
|
|
||||||
|
|
||||||
host_app = manifest_data["host_application"]
|
|
||||||
if not isinstance(host_app, dict):
|
|
||||||
return True, ""
|
|
||||||
|
|
||||||
min_version = host_app.get("min_version", "")
|
|
||||||
max_version = host_app.get("max_version", "")
|
|
||||||
|
|
||||||
if not min_version and not max_version:
|
|
||||||
return True, "" # 没有版本要求,默认兼容
|
|
||||||
|
|
||||||
try:
|
|
||||||
current_version = VersionComparator.get_current_host_version()
|
|
||||||
is_compatible, error_msg = VersionComparator.is_version_in_range(current_version, min_version, max_version)
|
|
||||||
if not is_compatible:
|
|
||||||
return False, f"版本不兼容: {error_msg}"
|
|
||||||
logger.debug(f"插件 {plugin_name} 版本兼容性检查通过")
|
|
||||||
return True, ""
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"插件 {plugin_name} 版本兼容性检查失败: {e}")
|
|
||||||
return False, f"插件 {plugin_name} 版本兼容性检查失败: {e}" # 检查失败时默认不允许加载
|
|
||||||
|
|
||||||
# == 显示统计与插件信息 ==
|
|
||||||
|
|
||||||
def _show_stats(self, total_registered: int, total_failed_registration: int):
|
|
||||||
# sourcery skip: low-code-quality
|
|
||||||
# 获取组件统计信息
|
|
||||||
stats = component_registry.get_registry_stats()
|
|
||||||
action_count = stats.get("action_components", 0)
|
|
||||||
command_count = stats.get("command_components", 0)
|
|
||||||
tool_count = stats.get("tool_components", 0)
|
|
||||||
event_handler_count = stats.get("event_handlers", 0)
|
|
||||||
total_components = stats.get("total_components", 0)
|
|
||||||
|
|
||||||
# 📋 显示插件加载总览
|
|
||||||
if total_registered > 0:
|
|
||||||
logger.info("🎉 插件系统加载完成!")
|
|
||||||
logger.info(
|
|
||||||
f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, EventHandler: {event_handler_count})"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 显示详细的插件列表
|
|
||||||
logger.info("📋 已加载插件详情:")
|
|
||||||
for plugin_name in self.loaded_plugins.keys():
|
|
||||||
if plugin_info := component_registry.get_plugin_info(plugin_name):
|
|
||||||
# 插件基本信息
|
|
||||||
version_info = f"v{plugin_info.version}" if plugin_info.version else ""
|
|
||||||
author_info = f"by {plugin_info.author}" if plugin_info.author else "unknown"
|
|
||||||
license_info = f"[{plugin_info.license}]" if plugin_info.license else ""
|
|
||||||
info_parts = [part for part in [version_info, author_info, license_info] if part]
|
|
||||||
extra_info = f" ({', '.join(info_parts)})" if info_parts else ""
|
|
||||||
|
|
||||||
logger.info(f" 📦 {plugin_info.display_name}{extra_info}")
|
|
||||||
|
|
||||||
# Manifest信息
|
|
||||||
if plugin_info.manifest_data:
|
|
||||||
"""
|
|
||||||
if plugin_info.keywords:
|
|
||||||
logger.info(f" 🏷️ 关键词: {', '.join(plugin_info.keywords)}")
|
|
||||||
if plugin_info.categories:
|
|
||||||
logger.info(f" 📁 分类: {', '.join(plugin_info.categories)}")
|
|
||||||
"""
|
|
||||||
if plugin_info.homepage_url:
|
|
||||||
logger.info(f" 🌐 主页: {plugin_info.homepage_url}")
|
|
||||||
|
|
||||||
# 组件列表
|
|
||||||
if plugin_info.components:
|
|
||||||
action_components = [
|
|
||||||
c for c in plugin_info.components if c.component_type == ComponentType.ACTION
|
|
||||||
]
|
|
||||||
command_components = [
|
|
||||||
c for c in plugin_info.components if c.component_type == ComponentType.COMMAND
|
|
||||||
]
|
|
||||||
tool_components = [c for c in plugin_info.components if c.component_type == ComponentType.TOOL]
|
|
||||||
event_handler_components = [
|
|
||||||
c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER
|
|
||||||
]
|
|
||||||
|
|
||||||
if action_components:
|
|
||||||
action_names = [c.name for c in action_components]
|
|
||||||
logger.info(f" 🎯 Action组件: {', '.join(action_names)}")
|
|
||||||
|
|
||||||
if command_components:
|
|
||||||
command_names = [c.name for c in command_components]
|
|
||||||
logger.info(f" ⚡ Command组件: {', '.join(command_names)}")
|
|
||||||
if tool_components:
|
|
||||||
tool_names = [c.name for c in tool_components]
|
|
||||||
logger.info(f" 🛠️ Tool组件: {', '.join(tool_names)}")
|
|
||||||
if event_handler_components:
|
|
||||||
event_handler_names = [c.name for c in event_handler_components]
|
|
||||||
logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_names)}")
|
|
||||||
|
|
||||||
# 依赖信息
|
|
||||||
if plugin_info.dependencies:
|
|
||||||
logger.info(f" 🔗 依赖: {', '.join(plugin_info.dependencies)}")
|
|
||||||
|
|
||||||
# 配置文件信息
|
|
||||||
if plugin_info.config_file:
|
|
||||||
config_status = "✅" if self.plugin_paths.get(plugin_name) else "❌"
|
|
||||||
logger.info(f" ⚙️ 配置: {plugin_info.config_file} {config_status}")
|
|
||||||
|
|
||||||
root_path = Path(__file__)
|
|
||||||
|
|
||||||
# 查找项目根目录
|
|
||||||
while not (root_path / "pyproject.toml").exists() and root_path.parent != root_path:
|
|
||||||
root_path = root_path.parent
|
|
||||||
|
|
||||||
# 显示目录统计
|
|
||||||
logger.info("📂 加载目录统计:")
|
|
||||||
for directory in self.plugin_directories:
|
|
||||||
if os.path.exists(directory):
|
|
||||||
plugins_in_dir = []
|
|
||||||
for plugin_name in self.loaded_plugins.keys():
|
|
||||||
plugin_path = self.plugin_paths.get(plugin_name, "")
|
|
||||||
if (
|
|
||||||
Path(plugin_path)
|
|
||||||
.resolve()
|
|
||||||
.is_relative_to(Path(os.path.join(str(root_path), directory)).resolve())
|
|
||||||
):
|
|
||||||
plugins_in_dir.append(plugin_name)
|
|
||||||
|
|
||||||
if plugins_in_dir:
|
|
||||||
logger.info(f" 📁 {directory}: {len(plugins_in_dir)}个插件 ({', '.join(plugins_in_dir)})")
|
|
||||||
else:
|
|
||||||
logger.info(f" 📁 {directory}: 0个插件")
|
|
||||||
|
|
||||||
# 失败信息
|
|
||||||
if total_failed_registration > 0:
|
|
||||||
logger.info(f"⚠️ 失败统计: {total_failed_registration}个插件加载失败")
|
|
||||||
for failed_plugin, error in self.failed_plugins.items():
|
|
||||||
logger.info(f" ❌ {failed_plugin}: {error}")
|
|
||||||
else:
|
|
||||||
logger.warning("😕 没有成功加载任何插件")
|
|
||||||
|
|
||||||
def _show_plugin_components(self, plugin_name: str) -> None:
|
|
||||||
if plugin_info := component_registry.get_plugin_info(plugin_name):
|
|
||||||
component_types = {}
|
|
||||||
for comp in plugin_info.components:
|
|
||||||
comp_type = comp.component_type.name
|
|
||||||
component_types[comp_type] = component_types.get(comp_type, 0) + 1
|
|
||||||
|
|
||||||
components_str = ", ".join([f"{count}个{ctype}" for ctype, count in component_types.items()])
|
|
||||||
|
|
||||||
# 显示manifest信息
|
|
||||||
manifest_info = ""
|
|
||||||
if plugin_info.license:
|
|
||||||
manifest_info += f" [{plugin_info.license}]"
|
|
||||||
if plugin_info.keywords:
|
|
||||||
manifest_info += f" 关键词: {', '.join(plugin_info.keywords[:3])}" # 只显示前3个关键词
|
|
||||||
if len(plugin_info.keywords) > 3:
|
|
||||||
manifest_info += "..."
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"✅ 插件加载成功: {plugin_name} v{plugin_info.version} ({components_str}){manifest_info} - {plugin_info.description}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.info(f"✅ 插件加载成功: {plugin_name}")
|
|
||||||
|
|
||||||
|
|
||||||
# 全局插件管理器实例
|
|
||||||
plugin_manager = PluginManager()
|
|
||||||
@@ -1,251 +0,0 @@
|
|||||||
from typing import Any, Callable, Dict, Optional
|
|
||||||
import inspect
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.plugin_system.base.service_types import PluginServiceInfo
|
|
||||||
|
|
||||||
logger = get_logger("plugin_service_registry")
|
|
||||||
|
|
||||||
|
|
||||||
class PluginServiceRegistry:
|
|
||||||
"""插件服务注册中心"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self._services: Dict[str, PluginServiceInfo] = {}
|
|
||||||
self._service_handlers: Dict[str, Callable[..., Any]] = {}
|
|
||||||
logger.info("插件服务注册中心初始化完成")
|
|
||||||
|
|
||||||
def register_service(self, service_info: PluginServiceInfo, service_handler: Callable[..., Any]) -> bool:
|
|
||||||
"""注册插件服务。"""
|
|
||||||
if not service_info.name or not service_info.plugin_name:
|
|
||||||
logger.error("插件服务注册失败: service名称或插件名称为空")
|
|
||||||
return False
|
|
||||||
if "." in service_info.name:
|
|
||||||
logger.error(f"插件服务名称 '{service_info.name}' 包含非法字符 '.',请使用下划线替代")
|
|
||||||
return False
|
|
||||||
if "." in service_info.plugin_name:
|
|
||||||
logger.error(f"插件服务所属插件名称 '{service_info.plugin_name}' 包含非法字符 '.',请使用下划线替代")
|
|
||||||
return False
|
|
||||||
if invalid_callers := [caller for caller in service_info.allowed_callers if "." in caller]:
|
|
||||||
logger.error(f"插件服务白名单包含非法调用方名称: {invalid_callers}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
full_name = service_info.full_name
|
|
||||||
if full_name in self._services:
|
|
||||||
logger.warning(f"插件服务已存在,拒绝重复注册: {full_name}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
self._services[full_name] = service_info
|
|
||||||
self._service_handlers[full_name] = service_handler
|
|
||||||
logger.debug(f"已注册插件服务: {full_name} (version={service_info.version})")
|
|
||||||
return True
|
|
||||||
|
|
||||||
def get_service(self, service_name: str, plugin_name: Optional[str] = None) -> Optional[PluginServiceInfo]:
|
|
||||||
"""获取插件服务元信息。
|
|
||||||
|
|
||||||
service_name支持:
|
|
||||||
- full_name: plugin_name.service_name
|
|
||||||
- short_name: service_name(当唯一时可解析)
|
|
||||||
"""
|
|
||||||
full_name = self._resolve_full_name(service_name, plugin_name)
|
|
||||||
return self._services.get(full_name) if full_name else None
|
|
||||||
|
|
||||||
def get_service_handler(self, service_name: str, plugin_name: Optional[str] = None) -> Optional[Callable[..., Any]]:
|
|
||||||
"""获取插件服务处理函数。"""
|
|
||||||
full_name = self._resolve_full_name(service_name, plugin_name)
|
|
||||||
return self._service_handlers.get(full_name) if full_name else None
|
|
||||||
|
|
||||||
def list_services(
|
|
||||||
self, plugin_name: Optional[str] = None, enabled_only: bool = False
|
|
||||||
) -> Dict[str, PluginServiceInfo]:
|
|
||||||
"""列出插件服务。"""
|
|
||||||
services = self._services.copy()
|
|
||||||
if plugin_name:
|
|
||||||
services = {name: info for name, info in services.items() if info.plugin_name == plugin_name}
|
|
||||||
if enabled_only:
|
|
||||||
services = {name: info for name, info in services.items() if info.enabled}
|
|
||||||
return services
|
|
||||||
|
|
||||||
def enable_service(self, service_name: str, plugin_name: Optional[str] = None) -> bool:
|
|
||||||
"""启用插件服务。"""
|
|
||||||
if not (service_info := self.get_service(service_name, plugin_name)):
|
|
||||||
logger.warning(f"插件服务未注册,无法启用: {service_name}")
|
|
||||||
return False
|
|
||||||
service_info.enabled = True
|
|
||||||
logger.info(f"插件服务已启用: {service_info.full_name}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
def disable_service(self, service_name: str, plugin_name: Optional[str] = None) -> bool:
|
|
||||||
"""禁用插件服务。"""
|
|
||||||
if not (service_info := self.get_service(service_name, plugin_name)):
|
|
||||||
logger.warning(f"插件服务未注册,无法禁用: {service_name}")
|
|
||||||
return False
|
|
||||||
service_info.enabled = False
|
|
||||||
logger.info(f"插件服务已禁用: {service_info.full_name}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
def unregister_service(self, service_name: str, plugin_name: Optional[str] = None) -> bool:
|
|
||||||
"""注销单个插件服务。"""
|
|
||||||
full_name = self._resolve_full_name(service_name, plugin_name)
|
|
||||||
if not full_name:
|
|
||||||
logger.warning(f"插件服务未注册,无法注销: {service_name}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
self._services.pop(full_name, None)
|
|
||||||
self._service_handlers.pop(full_name, None)
|
|
||||||
logger.info(f"插件服务已注销: {full_name}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
def remove_services_by_plugin(self, plugin_name: str) -> int:
|
|
||||||
"""移除某插件的所有注册服务。"""
|
|
||||||
target_names = [full_name for full_name, info in self._services.items() if info.plugin_name == plugin_name]
|
|
||||||
for full_name in target_names:
|
|
||||||
self._services.pop(full_name, None)
|
|
||||||
self._service_handlers.pop(full_name, None)
|
|
||||||
|
|
||||||
removed_count = len(target_names)
|
|
||||||
if removed_count:
|
|
||||||
logger.info(f"已移除插件 {plugin_name} 的服务数量: {removed_count}")
|
|
||||||
return removed_count
|
|
||||||
|
|
||||||
async def call_service(
|
|
||||||
self,
|
|
||||||
service_name: str,
|
|
||||||
*args: Any,
|
|
||||||
plugin_name: Optional[str] = None,
|
|
||||||
caller_plugin: Optional[str] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> Any:
|
|
||||||
"""调用插件服务(支持同步/异步handler)。"""
|
|
||||||
service_info = self.get_service(service_name, plugin_name)
|
|
||||||
if not service_info:
|
|
||||||
target_name = f"{plugin_name}.{service_name}" if plugin_name and "." not in service_name else service_name
|
|
||||||
raise ValueError(f"插件服务未注册: {target_name}")
|
|
||||||
|
|
||||||
if (
|
|
||||||
"." not in service_name
|
|
||||||
and plugin_name is None
|
|
||||||
and caller_plugin
|
|
||||||
and service_info.plugin_name != caller_plugin
|
|
||||||
):
|
|
||||||
raise PermissionError("跨插件服务调用必须使用完整服务名或显式指定plugin_name")
|
|
||||||
|
|
||||||
if not self._is_call_authorized(service_info, caller_plugin):
|
|
||||||
raise PermissionError(
|
|
||||||
f"调用被拒绝: caller={caller_plugin or 'anonymous'} 无权限访问服务 {service_info.full_name}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not service_info.enabled:
|
|
||||||
raise RuntimeError(f"插件服务已禁用: {service_info.full_name}")
|
|
||||||
|
|
||||||
handler = self.get_service_handler(service_name, plugin_name)
|
|
||||||
if not handler:
|
|
||||||
raise RuntimeError(f"插件服务处理器不存在: {service_info.full_name}")
|
|
||||||
|
|
||||||
self._validate_input_contract(service_info, args, kwargs)
|
|
||||||
|
|
||||||
result = handler(*args, **kwargs)
|
|
||||||
resolved_result = await result if inspect.isawaitable(result) else result
|
|
||||||
self._validate_output_contract(service_info, resolved_result)
|
|
||||||
return resolved_result
|
|
||||||
|
|
||||||
def _is_call_authorized(self, service_info: PluginServiceInfo, caller_plugin: Optional[str]) -> bool:
|
|
||||||
"""检查服务调用权限。"""
|
|
||||||
if caller_plugin is None:
|
|
||||||
return service_info.public
|
|
||||||
if caller_plugin == service_info.plugin_name:
|
|
||||||
return True
|
|
||||||
if service_info.public:
|
|
||||||
return True
|
|
||||||
allowed_callers = {caller.strip() for caller in service_info.allowed_callers if caller.strip()}
|
|
||||||
return "*" in allowed_callers or caller_plugin in allowed_callers
|
|
||||||
|
|
||||||
def _validate_input_contract(
|
|
||||||
self, service_info: PluginServiceInfo, args: tuple[Any, ...], kwargs: dict[str, Any]
|
|
||||||
) -> None:
|
|
||||||
"""校验服务入参契约。"""
|
|
||||||
schema = service_info.params_schema
|
|
||||||
if not schema:
|
|
||||||
return
|
|
||||||
|
|
||||||
properties = schema.get("properties", {}) if isinstance(schema, dict) else {}
|
|
||||||
is_invocation_schema = "args" in properties or "kwargs" in properties
|
|
||||||
|
|
||||||
if is_invocation_schema:
|
|
||||||
payload = {"args": list(args), "kwargs": kwargs}
|
|
||||||
self._validate_by_schema(payload, schema, path="params")
|
|
||||||
return
|
|
||||||
|
|
||||||
if args:
|
|
||||||
raise ValueError(f"服务 {service_info.full_name} 的入参契约不允许位置参数")
|
|
||||||
self._validate_by_schema(kwargs, schema, path="params")
|
|
||||||
|
|
||||||
def _validate_output_contract(self, service_info: PluginServiceInfo, value: Any) -> None:
|
|
||||||
"""校验服务返回值契约。"""
|
|
||||||
if not service_info.return_schema:
|
|
||||||
return
|
|
||||||
self._validate_by_schema(value, service_info.return_schema, path="return")
|
|
||||||
|
|
||||||
def _validate_by_schema(self, value: Any, schema: Dict[str, Any], path: str) -> None:
|
|
||||||
"""基于简化JSON-Schema校验数据。"""
|
|
||||||
expected_type = schema.get("type")
|
|
||||||
if expected_type:
|
|
||||||
self._validate_type(value, expected_type, path)
|
|
||||||
|
|
||||||
enum_values = schema.get("enum")
|
|
||||||
if enum_values is not None and value not in enum_values:
|
|
||||||
raise ValueError(f"{path} 不在枚举范围内: {value}")
|
|
||||||
|
|
||||||
if expected_type == "object":
|
|
||||||
properties = schema.get("properties", {})
|
|
||||||
required = schema.get("required", [])
|
|
||||||
|
|
||||||
for field in required:
|
|
||||||
if field not in value:
|
|
||||||
raise ValueError(f"{path}.{field} 为必填字段")
|
|
||||||
|
|
||||||
for field, field_value in value.items():
|
|
||||||
if field in properties:
|
|
||||||
self._validate_by_schema(field_value, properties[field], f"{path}.{field}")
|
|
||||||
elif schema.get("additionalProperties", True) is False:
|
|
||||||
raise ValueError(f"{path}.{field} 不允许额外字段")
|
|
||||||
|
|
||||||
if expected_type == "array":
|
|
||||||
if item_schema := schema.get("items"):
|
|
||||||
for index, item in enumerate(value):
|
|
||||||
self._validate_by_schema(item, item_schema, f"{path}[{index}]")
|
|
||||||
|
|
||||||
def _validate_type(self, value: Any, expected_type: str, path: str) -> None:
|
|
||||||
"""校验基础类型。"""
|
|
||||||
type_checkers: Dict[str, Callable[[Any], bool]] = {
|
|
||||||
"string": lambda item: isinstance(item, str),
|
|
||||||
"number": lambda item: isinstance(item, (int, float)) and not isinstance(item, bool),
|
|
||||||
"integer": lambda item: isinstance(item, int) and not isinstance(item, bool),
|
|
||||||
"boolean": lambda item: isinstance(item, bool),
|
|
||||||
"object": lambda item: isinstance(item, dict),
|
|
||||||
"array": lambda item: isinstance(item, list),
|
|
||||||
"null": lambda item: item is None,
|
|
||||||
}
|
|
||||||
checker = type_checkers.get(expected_type)
|
|
||||||
if checker and not checker(value):
|
|
||||||
raise TypeError(f"{path} 类型不匹配,期望 {expected_type},实际 {type(value).__name__}")
|
|
||||||
|
|
||||||
def _resolve_full_name(self, service_name: str, plugin_name: Optional[str] = None) -> Optional[str]:
|
|
||||||
"""解析服务全名。"""
|
|
||||||
if "." in service_name:
|
|
||||||
return service_name if service_name in self._services else None
|
|
||||||
|
|
||||||
if plugin_name:
|
|
||||||
full_name = f"{plugin_name}.{service_name}"
|
|
||||||
return full_name if full_name in self._services else None
|
|
||||||
|
|
||||||
candidates = [full_name for full_name, info in self._services.items() if info.name == service_name]
|
|
||||||
if len(candidates) == 1:
|
|
||||||
return candidates[0]
|
|
||||||
if len(candidates) > 1:
|
|
||||||
logger.warning(f"插件服务名称 '{service_name}' 存在多义,请传入plugin_name或使用完整服务名")
|
|
||||||
return None
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
plugin_service_registry = PluginServiceRegistry()
|
|
||||||
@@ -1,13 +0,0 @@
|
|||||||
- [x] 自定义事件
|
|
||||||
- [ ] <del>允许handler随时订阅</del>
|
|
||||||
- [x] 允许其他组件给handler增加订阅
|
|
||||||
- [x] 允许其他组件给handler取消订阅
|
|
||||||
- [ ] <del>允许一个handler订阅多个事件</del>
|
|
||||||
- [x] event激活时给handler传递参数
|
|
||||||
- [ ] handler能拿到所有handlers的结果(按照处理权重)
|
|
||||||
- [x] 随时注册
|
|
||||||
- [ ] <del>删除event</del>
|
|
||||||
- [ ] 必要性?
|
|
||||||
- [x] 能够更改prompt
|
|
||||||
- [x] 能够更改llm_response
|
|
||||||
- [x] 能够更改message
|
|
||||||
@@ -1,260 +0,0 @@
|
|||||||
from typing import Any, Awaitable, Callable, List, Optional, Tuple, Union
|
|
||||||
import asyncio
|
|
||||||
import inspect
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.plugin_system.base.component_types import EventType, MaiMessages
|
|
||||||
from src.plugin_system.base.workflow_errors import WorkflowErrorCode
|
|
||||||
from src.plugin_system.base.workflow_types import WorkflowContext, WorkflowStage, WorkflowStepResult
|
|
||||||
|
|
||||||
logger = get_logger("workflow_engine")
|
|
||||||
|
|
||||||
|
|
||||||
class WorkflowEngine:
|
|
||||||
"""线性Workflow执行器(MVP)"""
|
|
||||||
|
|
||||||
STAGE_EVENT_SEQUENCE: List[Tuple[WorkflowStage, Union[EventType, str]]] = [
|
|
||||||
(WorkflowStage.INGRESS, "workflow.ingress"),
|
|
||||||
(WorkflowStage.PRE_PROCESS, EventType.ON_MESSAGE_PRE_PROCESS),
|
|
||||||
(WorkflowStage.PLAN, EventType.ON_PLAN),
|
|
||||||
(WorkflowStage.TOOL_EXECUTE, "workflow.tool_execute"),
|
|
||||||
(WorkflowStage.POST_PROCESS, EventType.POST_SEND_PRE_PROCESS),
|
|
||||||
(WorkflowStage.EGRESS, "workflow.egress"),
|
|
||||||
]
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self._execution_history: dict[str, dict[str, Any]] = {}
|
|
||||||
|
|
||||||
async def execute_linear(
|
|
||||||
self,
|
|
||||||
dispatch_event: Callable[
|
|
||||||
[Union[EventType, str], Optional[MaiMessages], Optional[str], Optional[List[str]]],
|
|
||||||
Awaitable[Tuple[bool, Optional[MaiMessages]]],
|
|
||||||
],
|
|
||||||
message: Optional[MaiMessages] = None,
|
|
||||||
stream_id: Optional[str] = None,
|
|
||||||
action_usage: Optional[List[str]] = None,
|
|
||||||
context: Optional[WorkflowContext] = None,
|
|
||||||
) -> Tuple[WorkflowStepResult, Optional[MaiMessages], WorkflowContext]:
|
|
||||||
"""执行线性workflow。"""
|
|
||||||
workflow_context = context or WorkflowContext(trace_id=uuid.uuid4().hex, stream_id=stream_id)
|
|
||||||
current_message = message.deepcopy() if message else None
|
|
||||||
self._execution_history[workflow_context.trace_id] = {
|
|
||||||
"trace_id": workflow_context.trace_id,
|
|
||||||
"stream_id": workflow_context.stream_id,
|
|
||||||
"stages": [],
|
|
||||||
"errors": [],
|
|
||||||
"status": "running",
|
|
||||||
}
|
|
||||||
|
|
||||||
for stage, event_type in self.STAGE_EVENT_SEQUENCE:
|
|
||||||
stage_key = str(stage)
|
|
||||||
stage_start = time.perf_counter()
|
|
||||||
try:
|
|
||||||
should_continue, modified_message = await dispatch_event(
|
|
||||||
event_type,
|
|
||||||
current_message,
|
|
||||||
workflow_context.stream_id,
|
|
||||||
action_usage,
|
|
||||||
)
|
|
||||||
workflow_context.timings[stage_key] = time.perf_counter() - stage_start
|
|
||||||
self._execution_history[workflow_context.trace_id]["stages"].append(
|
|
||||||
{
|
|
||||||
"stage": stage_key,
|
|
||||||
"event_type": str(event_type),
|
|
||||||
"event_continue": should_continue,
|
|
||||||
"event_cost": workflow_context.timings[stage_key],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
if modified_message:
|
|
||||||
current_message = modified_message
|
|
||||||
|
|
||||||
if not should_continue:
|
|
||||||
logger.info(f"[trace_id={workflow_context.trace_id}] Workflow在阶段 {stage_key} 被中断")
|
|
||||||
return (
|
|
||||||
WorkflowStepResult(
|
|
||||||
status="stop",
|
|
||||||
return_message=f"workflow stopped at stage {stage_key}",
|
|
||||||
diagnostics={
|
|
||||||
"stage": stage_key,
|
|
||||||
"trace_id": workflow_context.trace_id,
|
|
||||||
"error_code": WorkflowErrorCode.POLICY_BLOCKED.value,
|
|
||||||
},
|
|
||||||
),
|
|
||||||
current_message,
|
|
||||||
workflow_context,
|
|
||||||
)
|
|
||||||
|
|
||||||
step_result = await self._execute_registered_steps(stage, workflow_context, current_message)
|
|
||||||
if step_result.status in ["stop", "failed"]:
|
|
||||||
self._execution_history[workflow_context.trace_id]["status"] = step_result.status
|
|
||||||
self._execution_history[workflow_context.trace_id]["errors"] = workflow_context.errors.copy()
|
|
||||||
return step_result, current_message, workflow_context
|
|
||||||
except Exception as e:
|
|
||||||
workflow_context.timings[stage_key] = time.perf_counter() - stage_start
|
|
||||||
workflow_context.errors.append(f"{stage_key}: {e}")
|
|
||||||
logger.error(
|
|
||||||
f"[trace_id={workflow_context.trace_id}] Workflow阶段 {stage_key} 执行异常: {e}", exc_info=True
|
|
||||||
)
|
|
||||||
self._execution_history[workflow_context.trace_id]["status"] = "failed"
|
|
||||||
self._execution_history[workflow_context.trace_id]["errors"] = workflow_context.errors.copy()
|
|
||||||
return (
|
|
||||||
WorkflowStepResult(
|
|
||||||
status="failed",
|
|
||||||
return_message=str(e),
|
|
||||||
diagnostics={
|
|
||||||
"stage": stage_key,
|
|
||||||
"trace_id": workflow_context.trace_id,
|
|
||||||
"error_code": WorkflowErrorCode.DOWNSTREAM_FAILED.value,
|
|
||||||
},
|
|
||||||
),
|
|
||||||
current_message,
|
|
||||||
workflow_context,
|
|
||||||
)
|
|
||||||
|
|
||||||
self._execution_history[workflow_context.trace_id]["status"] = "continue"
|
|
||||||
self._execution_history[workflow_context.trace_id]["errors"] = workflow_context.errors.copy()
|
|
||||||
return (
|
|
||||||
WorkflowStepResult(
|
|
||||||
status="continue",
|
|
||||||
return_message="workflow completed",
|
|
||||||
diagnostics={"trace_id": workflow_context.trace_id},
|
|
||||||
),
|
|
||||||
current_message,
|
|
||||||
workflow_context,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _execute_registered_steps(
|
|
||||||
self,
|
|
||||||
stage: WorkflowStage,
|
|
||||||
context: WorkflowContext,
|
|
||||||
message: Optional[MaiMessages],
|
|
||||||
) -> WorkflowStepResult:
|
|
||||||
"""执行指定阶段已注册的workflow步骤。"""
|
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
|
||||||
|
|
||||||
stage_steps = component_registry.get_steps_by_stage(stage, enabled_only=True)
|
|
||||||
sorted_steps = sorted(stage_steps.values(), key=lambda step_info: step_info.priority, reverse=True)
|
|
||||||
|
|
||||||
for step_info in sorted_steps:
|
|
||||||
handler = component_registry.get_workflow_step_handler(step_info.full_name, stage)
|
|
||||||
if not handler:
|
|
||||||
context.errors.append(f"{step_info.full_name}: handler not found")
|
|
||||||
continue
|
|
||||||
|
|
||||||
step_timing_key = f"{stage.value}:{step_info.full_name}"
|
|
||||||
step_start = time.perf_counter()
|
|
||||||
timeout_seconds = step_info.timeout_ms / 1000 if step_info.timeout_ms > 0 else None
|
|
||||||
|
|
||||||
try:
|
|
||||||
if inspect.iscoroutinefunction(handler):
|
|
||||||
coroutine = handler(context, message)
|
|
||||||
result = await asyncio.wait_for(coroutine, timeout_seconds) if timeout_seconds else await coroutine
|
|
||||||
else:
|
|
||||||
if timeout_seconds:
|
|
||||||
result = await asyncio.wait_for(asyncio.to_thread(handler, context, message), timeout_seconds)
|
|
||||||
else:
|
|
||||||
result = handler(context, message)
|
|
||||||
if inspect.isawaitable(result):
|
|
||||||
result = await asyncio.wait_for(result, timeout_seconds) if timeout_seconds else await result
|
|
||||||
context.timings[step_timing_key] = time.perf_counter() - step_start
|
|
||||||
|
|
||||||
normalized_result = self._normalize_step_result(result)
|
|
||||||
if normalized_result.status == "continue":
|
|
||||||
continue
|
|
||||||
|
|
||||||
normalized_result.diagnostics.setdefault("stage", stage.value)
|
|
||||||
normalized_result.diagnostics.setdefault("step", step_info.full_name)
|
|
||||||
normalized_result.diagnostics.setdefault("trace_id", context.trace_id)
|
|
||||||
if normalized_result.status == "failed":
|
|
||||||
context.errors.append(
|
|
||||||
f"{step_info.full_name}: {normalized_result.return_message or 'workflow step failed'}"
|
|
||||||
)
|
|
||||||
normalized_result.diagnostics.setdefault("error_code", WorkflowErrorCode.DOWNSTREAM_FAILED.value)
|
|
||||||
return normalized_result
|
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
context.timings[step_timing_key] = time.perf_counter() - step_start
|
|
||||||
timeout_message = f"workflow step timeout after {step_info.timeout_ms}ms"
|
|
||||||
context.errors.append(f"{step_info.full_name}: {timeout_message}")
|
|
||||||
logger.error(
|
|
||||||
f"[trace_id={context.trace_id}] Workflow step {step_info.full_name} 超时: {timeout_message}"
|
|
||||||
)
|
|
||||||
return WorkflowStepResult(
|
|
||||||
status="failed",
|
|
||||||
return_message=timeout_message,
|
|
||||||
diagnostics={
|
|
||||||
"stage": stage.value,
|
|
||||||
"step": step_info.full_name,
|
|
||||||
"trace_id": context.trace_id,
|
|
||||||
"error_code": WorkflowErrorCode.STEP_TIMEOUT.value,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
context.timings[step_timing_key] = time.perf_counter() - step_start
|
|
||||||
context.errors.append(f"{step_info.full_name}: {e}")
|
|
||||||
logger.error(
|
|
||||||
f"[trace_id={context.trace_id}] Workflow step {step_info.full_name} 执行异常: {e}", exc_info=True
|
|
||||||
)
|
|
||||||
return WorkflowStepResult(
|
|
||||||
status="failed",
|
|
||||||
return_message=str(e),
|
|
||||||
diagnostics={
|
|
||||||
"stage": stage.value,
|
|
||||||
"step": step_info.full_name,
|
|
||||||
"trace_id": context.trace_id,
|
|
||||||
"error_code": WorkflowErrorCode.DOWNSTREAM_FAILED.value,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
return WorkflowStepResult(status="continue", diagnostics={"stage": stage.value, "trace_id": context.trace_id})
|
|
||||||
|
|
||||||
def _normalize_step_result(self, result: Any) -> WorkflowStepResult:
|
|
||||||
"""归一化workflow step返回值。"""
|
|
||||||
if isinstance(result, WorkflowStepResult):
|
|
||||||
return result
|
|
||||||
if isinstance(result, bool):
|
|
||||||
if result:
|
|
||||||
return WorkflowStepResult(status="continue")
|
|
||||||
return WorkflowStepResult(
|
|
||||||
status="failed",
|
|
||||||
diagnostics={"error_code": WorkflowErrorCode.DOWNSTREAM_FAILED.value},
|
|
||||||
)
|
|
||||||
if result is None:
|
|
||||||
return WorkflowStepResult(status="continue")
|
|
||||||
if isinstance(result, str):
|
|
||||||
return WorkflowStepResult(status="continue", return_message=result)
|
|
||||||
if isinstance(result, dict):
|
|
||||||
status = result.get("status", "continue")
|
|
||||||
if status not in ["continue", "stop", "failed"]:
|
|
||||||
status = "failed"
|
|
||||||
return WorkflowStepResult(
|
|
||||||
status=status,
|
|
||||||
return_message=result.get("return_message"),
|
|
||||||
diagnostics=result.get("diagnostics", {}),
|
|
||||||
events=result.get("events", []),
|
|
||||||
)
|
|
||||||
return WorkflowStepResult(
|
|
||||||
status="failed",
|
|
||||||
return_message=f"unsupported step result type: {type(result)}",
|
|
||||||
diagnostics={"error_code": WorkflowErrorCode.BAD_PAYLOAD.value},
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_execution_trace(self, trace_id: str) -> Optional[dict[str, Any]]:
|
|
||||||
"""按trace_id获取workflow执行路径。"""
|
|
||||||
trace = self._execution_history.get(trace_id)
|
|
||||||
return trace.copy() if trace else None
|
|
||||||
|
|
||||||
def clear_execution_trace(self, trace_id: str) -> bool:
|
|
||||||
"""清理trace执行记录。"""
|
|
||||||
if trace_id in self._execution_history:
|
|
||||||
del self._execution_history[trace_id]
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
workflow_engine = WorkflowEngine()
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
"""
|
|
||||||
插件系统工具模块
|
|
||||||
|
|
||||||
提供插件开发和管理的实用工具
|
|
||||||
"""
|
|
||||||
|
|
||||||
from .manifest_utils import (
|
|
||||||
ManifestValidator,
|
|
||||||
# ManifestGenerator,
|
|
||||||
# validate_plugin_manifest,
|
|
||||||
# generate_plugin_manifest,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"ManifestValidator",
|
|
||||||
# "ManifestGenerator",
|
|
||||||
# "validate_plugin_manifest",
|
|
||||||
# "generate_plugin_manifest",
|
|
||||||
]
|
|
||||||
@@ -1,515 +0,0 @@
|
|||||||
"""
|
|
||||||
插件Manifest工具模块
|
|
||||||
|
|
||||||
提供manifest文件的验证、生成和管理功能
|
|
||||||
"""
|
|
||||||
|
|
||||||
import re
|
|
||||||
from typing import Dict, Any, Tuple
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.config.config import MMC_VERSION
|
|
||||||
|
|
||||||
# if TYPE_CHECKING:
|
|
||||||
# from src.plugin_system.base.base_plugin import BasePlugin
|
|
||||||
|
|
||||||
logger = get_logger("manifest_utils")
|
|
||||||
|
|
||||||
|
|
||||||
class VersionComparator:
|
|
||||||
"""版本号比较器
|
|
||||||
|
|
||||||
支持语义化版本号比较,自动处理snapshot版本,并支持向前兼容性检查
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 版本兼容性映射表(硬编码)
|
|
||||||
# 格式: {插件最大支持版本: [实际兼容的版本列表]}
|
|
||||||
COMPATIBILITY_MAP = {
|
|
||||||
# 0.8.x 系列向前兼容规则
|
|
||||||
"0.8.0": ["0.8.1", "0.8.2", "0.8.3", "0.8.4", "0.8.5", "0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"],
|
|
||||||
"0.8.1": ["0.8.2", "0.8.3", "0.8.4", "0.8.5", "0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"],
|
|
||||||
"0.8.2": ["0.8.3", "0.8.4", "0.8.5", "0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"],
|
|
||||||
"0.8.3": ["0.8.4", "0.8.5", "0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"],
|
|
||||||
"0.8.4": ["0.8.5", "0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"],
|
|
||||||
"0.8.5": ["0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"],
|
|
||||||
"0.8.6": ["0.8.7", "0.8.8", "0.8.9", "0.8.10"],
|
|
||||||
"0.8.7": ["0.8.8", "0.8.9", "0.8.10"],
|
|
||||||
"0.8.8": ["0.8.9", "0.8.10"],
|
|
||||||
"0.8.9": ["0.8.10"],
|
|
||||||
# 可以根据需要添加更多兼容映射
|
|
||||||
# "0.9.0": ["0.9.1", "0.9.2", "0.9.3"], # 示例:0.9.x系列兼容
|
|
||||||
}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def normalize_version(version: str) -> str:
|
|
||||||
"""标准化版本号,移除snapshot标识
|
|
||||||
|
|
||||||
Args:
|
|
||||||
version: 原始版本号,如 "0.8.0-snapshot.1"
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: 标准化后的版本号,如 "0.8.0"
|
|
||||||
"""
|
|
||||||
if not version:
|
|
||||||
return "0.0.0"
|
|
||||||
|
|
||||||
# 移除snapshot部分
|
|
||||||
normalized = re.sub(r"-snapshot\.\d+", "", version.strip())
|
|
||||||
|
|
||||||
# 确保版本号格式正确
|
|
||||||
if not re.match(r"^\d+(\.\d+){0,2}$", normalized):
|
|
||||||
# 如果不是有效的版本号格式,返回默认版本
|
|
||||||
return "0.0.0"
|
|
||||||
|
|
||||||
# 尝试补全版本号
|
|
||||||
parts = normalized.split(".")
|
|
||||||
while len(parts) < 3:
|
|
||||||
parts.append("0")
|
|
||||||
normalized = ".".join(parts[:3])
|
|
||||||
|
|
||||||
return normalized
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def parse_version(version: str) -> Tuple[int, int, int]:
|
|
||||||
"""解析版本号为元组
|
|
||||||
|
|
||||||
Args:
|
|
||||||
version: 版本号字符串
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[int, int, int]: (major, minor, patch)
|
|
||||||
"""
|
|
||||||
normalized = VersionComparator.normalize_version(version)
|
|
||||||
try:
|
|
||||||
parts = normalized.split(".")
|
|
||||||
return (int(parts[0]), int(parts[1]), int(parts[2]))
|
|
||||||
except (ValueError, IndexError):
|
|
||||||
logger.warning(f"无法解析版本号: {version},使用默认版本 0.0.0")
|
|
||||||
return (0, 0, 0)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def compare_versions(version1: str, version2: str) -> int:
|
|
||||||
"""比较两个版本号
|
|
||||||
|
|
||||||
Args:
|
|
||||||
version1: 第一个版本号
|
|
||||||
version2: 第二个版本号
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: -1 if version1 < version2, 0 if equal, 1 if version1 > version2
|
|
||||||
"""
|
|
||||||
v1_tuple = VersionComparator.parse_version(version1)
|
|
||||||
v2_tuple = VersionComparator.parse_version(version2)
|
|
||||||
|
|
||||||
if v1_tuple < v2_tuple:
|
|
||||||
return -1
|
|
||||||
elif v1_tuple > v2_tuple:
|
|
||||||
return 1
|
|
||||||
else:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def check_forward_compatibility(current_version: str, max_version: str) -> Tuple[bool, str]:
|
|
||||||
"""检查向前兼容性(仅使用兼容性映射表)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
current_version: 当前版本
|
|
||||||
max_version: 插件声明的最大支持版本
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[bool, str]: (是否兼容, 兼容信息)
|
|
||||||
"""
|
|
||||||
current_normalized = VersionComparator.normalize_version(current_version)
|
|
||||||
max_normalized = VersionComparator.normalize_version(max_version)
|
|
||||||
|
|
||||||
# 检查兼容性映射表
|
|
||||||
if max_normalized in VersionComparator.COMPATIBILITY_MAP:
|
|
||||||
compatible_versions = VersionComparator.COMPATIBILITY_MAP[max_normalized]
|
|
||||||
if current_normalized in compatible_versions:
|
|
||||||
return True, f"根据兼容性映射表,版本 {current_normalized} 与 {max_normalized} 兼容"
|
|
||||||
|
|
||||||
return False, ""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def is_version_in_range(version: str, min_version: str = "", max_version: str = "") -> Tuple[bool, str]:
|
|
||||||
"""检查版本是否在指定范围内,支持兼容性检查
|
|
||||||
|
|
||||||
Args:
|
|
||||||
version: 要检查的版本号
|
|
||||||
min_version: 最小版本号(可选)
|
|
||||||
max_version: 最大版本号(可选)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[bool, str]: (是否兼容, 错误信息或兼容信息)
|
|
||||||
"""
|
|
||||||
if not min_version and not max_version:
|
|
||||||
return True, ""
|
|
||||||
|
|
||||||
version_normalized = VersionComparator.normalize_version(version)
|
|
||||||
|
|
||||||
# 检查最小版本
|
|
||||||
if min_version:
|
|
||||||
min_normalized = VersionComparator.normalize_version(min_version)
|
|
||||||
if VersionComparator.compare_versions(version_normalized, min_normalized) < 0:
|
|
||||||
return False, f"版本 {version_normalized} 低于最小要求版本 {min_normalized}"
|
|
||||||
|
|
||||||
# 检查最大版本
|
|
||||||
if max_version:
|
|
||||||
max_normalized = VersionComparator.normalize_version(max_version)
|
|
||||||
comparison = VersionComparator.compare_versions(version_normalized, max_normalized)
|
|
||||||
|
|
||||||
if comparison > 0:
|
|
||||||
# 严格版本检查失败,尝试兼容性检查
|
|
||||||
is_compatible, compat_msg = VersionComparator.check_forward_compatibility(
|
|
||||||
version_normalized, max_normalized
|
|
||||||
)
|
|
||||||
|
|
||||||
if not is_compatible:
|
|
||||||
return False, f"版本 {version_normalized} 高于最大支持版本 {max_normalized},且无兼容性映射"
|
|
||||||
|
|
||||||
logger.info(f"版本兼容性检查:{compat_msg}")
|
|
||||||
return True, compat_msg
|
|
||||||
return True, ""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_current_host_version() -> str:
|
|
||||||
"""获取当前主机应用版本
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: 当前版本号
|
|
||||||
"""
|
|
||||||
return VersionComparator.normalize_version(MMC_VERSION)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def add_compatibility_mapping(base_version: str, compatible_versions: list) -> None:
|
|
||||||
"""动态添加兼容性映射
|
|
||||||
|
|
||||||
Args:
|
|
||||||
base_version: 基础版本(插件声明的最大支持版本)
|
|
||||||
compatible_versions: 兼容的版本列表
|
|
||||||
"""
|
|
||||||
base_normalized = VersionComparator.normalize_version(base_version)
|
|
||||||
VersionComparator.COMPATIBILITY_MAP[base_normalized] = [
|
|
||||||
VersionComparator.normalize_version(v) for v in compatible_versions
|
|
||||||
]
|
|
||||||
logger.info(f"添加兼容性映射:{base_normalized} -> {compatible_versions}")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_compatibility_info() -> Dict[str, list]:
|
|
||||||
"""获取当前的兼容性映射表
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict[str, list]: 兼容性映射表的副本
|
|
||||||
"""
|
|
||||||
return VersionComparator.COMPATIBILITY_MAP.copy()
|
|
||||||
|
|
||||||
|
|
||||||
class ManifestValidator:
|
|
||||||
"""Manifest文件验证器"""
|
|
||||||
|
|
||||||
# 必需字段(必须存在且不能为空)
|
|
||||||
REQUIRED_FIELDS = ["manifest_version", "name", "version", "description", "author"]
|
|
||||||
|
|
||||||
# 可选字段(可以不存在或为空)
|
|
||||||
OPTIONAL_FIELDS = [
|
|
||||||
"license",
|
|
||||||
"host_application",
|
|
||||||
"homepage_url",
|
|
||||||
"repository_url",
|
|
||||||
"keywords",
|
|
||||||
"categories",
|
|
||||||
"default_locale",
|
|
||||||
"locales_path",
|
|
||||||
"plugin_info",
|
|
||||||
]
|
|
||||||
|
|
||||||
# 建议填写的字段(会给出警告但不会导致验证失败)
|
|
||||||
RECOMMENDED_FIELDS = ["license", "keywords", "categories"]
|
|
||||||
|
|
||||||
SUPPORTED_MANIFEST_VERSIONS = [1]
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.validation_errors = []
|
|
||||||
self.validation_warnings = []
|
|
||||||
|
|
||||||
def validate_manifest(self, manifest_data: Dict[str, Any]) -> bool:
|
|
||||||
"""验证manifest数据
|
|
||||||
|
|
||||||
Args:
|
|
||||||
manifest_data: manifest数据字典
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否验证通过(只有错误会导致验证失败,警告不会)
|
|
||||||
"""
|
|
||||||
self.validation_errors.clear()
|
|
||||||
self.validation_warnings.clear()
|
|
||||||
|
|
||||||
# 检查必需字段
|
|
||||||
for field in self.REQUIRED_FIELDS:
|
|
||||||
if field not in manifest_data:
|
|
||||||
self.validation_errors.append(f"缺少必需字段: {field}")
|
|
||||||
elif not manifest_data[field]:
|
|
||||||
self.validation_errors.append(f"必需字段不能为空: {field}")
|
|
||||||
|
|
||||||
# 检查manifest版本
|
|
||||||
if "manifest_version" in manifest_data:
|
|
||||||
version = manifest_data["manifest_version"]
|
|
||||||
if version not in self.SUPPORTED_MANIFEST_VERSIONS:
|
|
||||||
self.validation_errors.append(
|
|
||||||
f"不支持的manifest版本: {version},支持的版本: {self.SUPPORTED_MANIFEST_VERSIONS}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 检查作者信息格式
|
|
||||||
if "author" in manifest_data:
|
|
||||||
author = manifest_data["author"]
|
|
||||||
if isinstance(author, dict):
|
|
||||||
if "name" not in author or not author["name"]:
|
|
||||||
self.validation_errors.append("作者信息缺少name字段或为空")
|
|
||||||
# url字段是可选的
|
|
||||||
if "url" in author and author["url"]:
|
|
||||||
url = author["url"]
|
|
||||||
if not (url.startswith("http://") or url.startswith("https://")):
|
|
||||||
self.validation_warnings.append("作者URL建议使用完整的URL格式")
|
|
||||||
elif isinstance(author, str):
|
|
||||||
if not author.strip():
|
|
||||||
self.validation_errors.append("作者信息不能为空")
|
|
||||||
else:
|
|
||||||
self.validation_errors.append("作者信息格式错误,应为字符串或包含name字段的对象")
|
|
||||||
# 检查主机应用版本要求(可选)
|
|
||||||
if "host_application" in manifest_data:
|
|
||||||
host_app = manifest_data["host_application"]
|
|
||||||
if isinstance(host_app, dict):
|
|
||||||
min_version = host_app.get("min_version", "")
|
|
||||||
max_version = host_app.get("max_version", "")
|
|
||||||
|
|
||||||
# 验证版本字段格式
|
|
||||||
for version_field in ["min_version", "max_version"]:
|
|
||||||
if version_field in host_app and not host_app[version_field]:
|
|
||||||
self.validation_warnings.append(f"host_application.{version_field}为空")
|
|
||||||
|
|
||||||
# 检查当前主机版本兼容性
|
|
||||||
if min_version or max_version:
|
|
||||||
current_version = VersionComparator.get_current_host_version()
|
|
||||||
is_compatible, error_msg = VersionComparator.is_version_in_range(
|
|
||||||
current_version, min_version, max_version
|
|
||||||
)
|
|
||||||
|
|
||||||
if not is_compatible:
|
|
||||||
self.validation_errors.append(f"版本兼容性检查失败: {error_msg} (当前版本: {current_version})")
|
|
||||||
else:
|
|
||||||
logger.debug(
|
|
||||||
f"版本兼容性检查通过: 当前版本 {current_version} 符合要求 [{min_version}, {max_version}]"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.validation_errors.append("host_application格式错误,应为对象")
|
|
||||||
|
|
||||||
# 检查URL格式(可选字段)
|
|
||||||
for url_field in ["homepage_url", "repository_url"]:
|
|
||||||
if url_field in manifest_data and manifest_data[url_field]:
|
|
||||||
url: str = manifest_data[url_field]
|
|
||||||
if not (url.startswith("http://") or url.startswith("https://")):
|
|
||||||
self.validation_warnings.append(f"{url_field}建议使用完整的URL格式")
|
|
||||||
|
|
||||||
# 检查数组字段格式(可选字段)
|
|
||||||
for list_field in ["keywords", "categories"]:
|
|
||||||
if list_field in manifest_data:
|
|
||||||
field_value = manifest_data[list_field]
|
|
||||||
if field_value is not None and not isinstance(field_value, list):
|
|
||||||
self.validation_errors.append(f"{list_field}应为数组格式")
|
|
||||||
elif isinstance(field_value, list):
|
|
||||||
# 检查数组元素是否为字符串
|
|
||||||
for i, item in enumerate(field_value):
|
|
||||||
if not isinstance(item, str):
|
|
||||||
self.validation_warnings.append(f"{list_field}[{i}]应为字符串")
|
|
||||||
|
|
||||||
# 检查建议字段(给出警告)
|
|
||||||
for field in self.RECOMMENDED_FIELDS:
|
|
||||||
if field not in manifest_data or not manifest_data[field]:
|
|
||||||
self.validation_warnings.append(f"建议填写字段: {field}")
|
|
||||||
|
|
||||||
# 检查plugin_info结构(可选)
|
|
||||||
if "plugin_info" in manifest_data:
|
|
||||||
plugin_info = manifest_data["plugin_info"]
|
|
||||||
if isinstance(plugin_info, dict):
|
|
||||||
# 检查components数组
|
|
||||||
if "components" in plugin_info:
|
|
||||||
components = plugin_info["components"]
|
|
||||||
if not isinstance(components, list):
|
|
||||||
self.validation_errors.append("plugin_info.components应为数组格式")
|
|
||||||
else:
|
|
||||||
for i, component in enumerate(components):
|
|
||||||
if not isinstance(component, dict):
|
|
||||||
self.validation_errors.append(f"plugin_info.components[{i}]应为对象")
|
|
||||||
else:
|
|
||||||
# 检查组件必需字段
|
|
||||||
for comp_field in ["type", "name", "description"]:
|
|
||||||
if comp_field not in component or not component[comp_field]:
|
|
||||||
self.validation_errors.append(
|
|
||||||
f"plugin_info.components[{i}]缺少必需字段: {comp_field}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.validation_errors.append("plugin_info应为对象格式")
|
|
||||||
|
|
||||||
return len(self.validation_errors) == 0
|
|
||||||
|
|
||||||
def get_validation_report(self) -> str:
|
|
||||||
"""获取验证报告"""
|
|
||||||
report = []
|
|
||||||
|
|
||||||
if self.validation_errors:
|
|
||||||
report.append("❌ 验证错误:")
|
|
||||||
report.extend(f" - {error}" for error in self.validation_errors)
|
|
||||||
if self.validation_warnings:
|
|
||||||
report.append("⚠️ 验证警告:")
|
|
||||||
report.extend(f" - {warning}" for warning in self.validation_warnings)
|
|
||||||
if not self.validation_errors and not self.validation_warnings:
|
|
||||||
report.append("✅ Manifest文件验证通过")
|
|
||||||
|
|
||||||
return "\n".join(report)
|
|
||||||
|
|
||||||
|
|
||||||
# class ManifestGenerator:
|
|
||||||
# """Manifest文件生成器"""
|
|
||||||
|
|
||||||
# def __init__(self):
|
|
||||||
# self.template = {
|
|
||||||
# "manifest_version": 1,
|
|
||||||
# "name": "",
|
|
||||||
# "version": "1.0.0",
|
|
||||||
# "description": "",
|
|
||||||
# "author": {"name": "", "url": ""},
|
|
||||||
# "license": "MIT",
|
|
||||||
# "host_application": {"min_version": "1.0.0", "max_version": "4.0.0"},
|
|
||||||
# "homepage_url": "",
|
|
||||||
# "repository_url": "",
|
|
||||||
# "keywords": [],
|
|
||||||
# "categories": [],
|
|
||||||
# "default_locale": "zh-CN",
|
|
||||||
# "locales_path": "_locales",
|
|
||||||
# }
|
|
||||||
|
|
||||||
# def generate_from_plugin(self, plugin_instance: BasePlugin) -> Dict[str, Any]:
|
|
||||||
# """从插件实例生成manifest
|
|
||||||
|
|
||||||
# Args:
|
|
||||||
# plugin_instance: BasePlugin实例
|
|
||||||
|
|
||||||
# Returns:
|
|
||||||
# Dict[str, Any]: 生成的manifest数据
|
|
||||||
# """
|
|
||||||
# manifest = self.template.copy()
|
|
||||||
|
|
||||||
# # 基本信息
|
|
||||||
# manifest["name"] = plugin_instance.plugin_name
|
|
||||||
# manifest["version"] = plugin_instance.plugin_version
|
|
||||||
# manifest["description"] = plugin_instance.plugin_description
|
|
||||||
|
|
||||||
# # 作者信息
|
|
||||||
# if plugin_instance.plugin_author:
|
|
||||||
# manifest["author"]["name"] = plugin_instance.plugin_author
|
|
||||||
|
|
||||||
# # 组件信息
|
|
||||||
# components = []
|
|
||||||
# plugin_components = plugin_instance.get_plugin_components()
|
|
||||||
|
|
||||||
# for component_info, component_class in plugin_components:
|
|
||||||
# component_data: Dict[str, Any] = {
|
|
||||||
# "type": component_info.component_type.value,
|
|
||||||
# "name": component_info.name,
|
|
||||||
# "description": component_info.description,
|
|
||||||
# }
|
|
||||||
|
|
||||||
# # 添加激活模式信息(对于Action组件)
|
|
||||||
# if hasattr(component_class, "focus_activation_type"):
|
|
||||||
# activation_modes = []
|
|
||||||
# if hasattr(component_class, "focus_activation_type"):
|
|
||||||
# activation_modes.append(component_class.focus_activation_type.value)
|
|
||||||
# if hasattr(component_class, "normal_activation_type"):
|
|
||||||
# activation_modes.append(component_class.normal_activation_type.value)
|
|
||||||
# component_data["activation_modes"] = list(set(activation_modes))
|
|
||||||
|
|
||||||
# # 添加关键词信息
|
|
||||||
# if hasattr(component_class, "activation_keywords"):
|
|
||||||
# keywords = getattr(component_class, "activation_keywords", [])
|
|
||||||
# if keywords:
|
|
||||||
# component_data["keywords"] = keywords
|
|
||||||
|
|
||||||
# components.append(component_data)
|
|
||||||
|
|
||||||
# manifest["plugin_info"] = {"is_built_in": True, "plugin_type": "general", "components": components}
|
|
||||||
|
|
||||||
# return manifest
|
|
||||||
|
|
||||||
# def save_manifest(self, manifest_data: Dict[str, Any], plugin_dir: str) -> bool:
|
|
||||||
# """保存manifest文件
|
|
||||||
|
|
||||||
# Args:
|
|
||||||
# manifest_data: manifest数据
|
|
||||||
# plugin_dir: 插件目录
|
|
||||||
|
|
||||||
# Returns:
|
|
||||||
# bool: 是否保存成功
|
|
||||||
# """
|
|
||||||
# try:
|
|
||||||
# manifest_path = os.path.join(plugin_dir, "_manifest.json")
|
|
||||||
# with open(manifest_path, "w", encoding="utf-8") as f:
|
|
||||||
# json.dump(manifest_data, f, ensure_ascii=False, indent=2)
|
|
||||||
# logger.info(f"Manifest文件已保存: {manifest_path}")
|
|
||||||
# return True
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.error(f"保存manifest文件失败: {e}")
|
|
||||||
# return False
|
|
||||||
|
|
||||||
|
|
||||||
# def validate_plugin_manifest(plugin_dir: str) -> bool:
|
|
||||||
# """验证插件目录中的manifest文件
|
|
||||||
|
|
||||||
# Args:
|
|
||||||
# plugin_dir: 插件目录路径
|
|
||||||
|
|
||||||
# Returns:
|
|
||||||
# bool: 是否验证通过
|
|
||||||
# """
|
|
||||||
# manifest_path = os.path.join(plugin_dir, "_manifest.json")
|
|
||||||
|
|
||||||
# if not os.path.exists(manifest_path):
|
|
||||||
# logger.warning(f"未找到manifest文件: {manifest_path}")
|
|
||||||
# return False
|
|
||||||
|
|
||||||
# try:
|
|
||||||
# with open(manifest_path, "r", encoding="utf-8") as f:
|
|
||||||
# manifest_data = json.load(f)
|
|
||||||
|
|
||||||
# validator = ManifestValidator()
|
|
||||||
# is_valid = validator.validate_manifest(manifest_data)
|
|
||||||
|
|
||||||
# logger.info(f"Manifest验证结果:\n{validator.get_validation_report()}")
|
|
||||||
|
|
||||||
# return is_valid
|
|
||||||
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.error(f"读取或验证manifest文件失败: {e}")
|
|
||||||
# return False
|
|
||||||
|
|
||||||
|
|
||||||
# def generate_plugin_manifest(plugin_instance: BasePlugin, save_to_file: bool = True) -> Optional[Dict[str, Any]]:
|
|
||||||
# """为插件生成manifest文件
|
|
||||||
|
|
||||||
# Args:
|
|
||||||
# plugin_instance: BasePlugin实例
|
|
||||||
# save_to_file: 是否保存到文件
|
|
||||||
|
|
||||||
# Returns:
|
|
||||||
# Optional[Dict[str, Any]]: 生成的manifest数据
|
|
||||||
# """
|
|
||||||
# try:
|
|
||||||
# generator = ManifestGenerator()
|
|
||||||
# manifest_data = generator.generate_from_plugin(plugin_instance)
|
|
||||||
|
|
||||||
# if save_to_file and plugin_instance.plugin_dir:
|
|
||||||
# generator.save_manifest(manifest_data, plugin_instance.plugin_dir)
|
|
||||||
|
|
||||||
# return manifest_data
|
|
||||||
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.error(f"生成manifest文件失败: {e}")
|
|
||||||
# return None
|
|
||||||
@@ -1,34 +1,38 @@
|
|||||||
{
|
{
|
||||||
"manifest_version": 1,
|
"manifest_version": 1,
|
||||||
"name": "Emoji插件 (Emoji Actions)",
|
"name": "Emoji插件 (Emoji Actions)",
|
||||||
"version": "1.0.0",
|
"version": "2.0.0",
|
||||||
"description": "可以发送和管理Emoji",
|
"description": "可以发送和管理Emoji",
|
||||||
"author": {
|
"author": {
|
||||||
"name": "SengokuCola",
|
"name": "SengokuCola",
|
||||||
"url": "https://github.com/MaiM-with-u"
|
"url": "https://github.com/MaiM-with-u"
|
||||||
},
|
},
|
||||||
"license": "GPL-v3.0-or-later",
|
"license": "GPL-v3.0-or-later",
|
||||||
|
|
||||||
"host_application": {
|
"host_application": {
|
||||||
"min_version": "0.10.0"
|
"min_version": "1.0.0"
|
||||||
},
|
},
|
||||||
"homepage_url": "https://github.com/MaiM-with-u/maibot",
|
"homepage_url": "https://github.com/MaiM-with-u/maibot",
|
||||||
"repository_url": "https://github.com/MaiM-with-u/maibot",
|
"repository_url": "https://github.com/MaiM-with-u/maibot",
|
||||||
"keywords": ["emoji", "action", "built-in"],
|
"keywords": ["emoji", "action", "built-in"],
|
||||||
"categories": ["Emoji"],
|
"categories": ["Emoji"],
|
||||||
|
|
||||||
"default_locale": "zh-CN",
|
"default_locale": "zh-CN",
|
||||||
"locales_path": "_locales",
|
|
||||||
|
|
||||||
"plugin_info": {
|
"plugin_info": {
|
||||||
"is_built_in": true,
|
"is_built_in": true,
|
||||||
"plugin_type": "action_provider",
|
"plugin_type": "action_provider",
|
||||||
"components": [
|
"components": [
|
||||||
{
|
{
|
||||||
"type": "action",
|
"type": "action",
|
||||||
"name": "emoji",
|
"name": "emoji",
|
||||||
"description": "发送表情包辅助表达情绪"
|
"description": "发送表情包辅助表达情绪"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
},
|
||||||
|
"capabilities": [
|
||||||
|
"emoji.get_random",
|
||||||
|
"message.get_recent",
|
||||||
|
"message.build_readable",
|
||||||
|
"llm.generate",
|
||||||
|
"send.emoji",
|
||||||
|
"config.get"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,151 +0,0 @@
|
|||||||
import random
|
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
# 导入新插件系统
|
|
||||||
from src.plugin_system import BaseAction, ActionActivationType
|
|
||||||
|
|
||||||
# 导入依赖的系统组件
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
|
|
||||||
# 导入API模块 - 标准Python包方式
|
|
||||||
from src.plugin_system.apis import emoji_api, llm_api, message_api
|
|
||||||
|
|
||||||
# NoReplyAction已集成到heartFC_chat.py中,不再需要导入
|
|
||||||
from src.config.config import global_config
|
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("emoji")
|
|
||||||
|
|
||||||
|
|
||||||
class EmojiAction(BaseAction):
|
|
||||||
"""表情动作 - 发送表情包"""
|
|
||||||
|
|
||||||
activation_type = ActionActivationType.RANDOM
|
|
||||||
random_activation_probability = global_config.emoji.emoji_chance
|
|
||||||
parallel_action = True
|
|
||||||
|
|
||||||
# 动作基本信息
|
|
||||||
action_name = "emoji"
|
|
||||||
action_description = "发送表情包辅助表达情绪"
|
|
||||||
|
|
||||||
# 动作参数定义
|
|
||||||
action_parameters = {}
|
|
||||||
|
|
||||||
# 动作使用场景
|
|
||||||
action_require = [
|
|
||||||
"发送表情包辅助表达情绪",
|
|
||||||
"表达情绪时可以选择使用",
|
|
||||||
"不要连续发送,如果你已经发过[表情包],就不要选择此动作",
|
|
||||||
]
|
|
||||||
|
|
||||||
# 关联类型
|
|
||||||
associated_types = ["emoji"]
|
|
||||||
|
|
||||||
async def execute(self) -> Tuple[bool, str]:
|
|
||||||
# sourcery skip: assign-if-exp, introduce-default-else, swap-if-else-branches, use-named-expression
|
|
||||||
"""执行表情动作"""
|
|
||||||
try:
|
|
||||||
# 1. 获取发送表情的原因
|
|
||||||
# reason = self.action_data.get("reason", "表达当前情绪")
|
|
||||||
reason = self.action_reasoning
|
|
||||||
|
|
||||||
# 2. 随机获取20个表情包
|
|
||||||
sampled_emojis = await emoji_api.get_random(30)
|
|
||||||
if not sampled_emojis:
|
|
||||||
logger.warning(f"{self.log_prefix} 无法获取随机表情包")
|
|
||||||
return False, "无法获取随机表情包"
|
|
||||||
|
|
||||||
# 3. 准备情感数据
|
|
||||||
emotion_map = {}
|
|
||||||
for b64, desc, emo in sampled_emojis:
|
|
||||||
if emo not in emotion_map:
|
|
||||||
emotion_map[emo] = []
|
|
||||||
emotion_map[emo].append((b64, desc))
|
|
||||||
|
|
||||||
available_emotions = list(emotion_map.keys())
|
|
||||||
available_emotions_str = ""
|
|
||||||
for emotion in available_emotions:
|
|
||||||
available_emotions_str += f"{emotion}\n"
|
|
||||||
|
|
||||||
if not available_emotions:
|
|
||||||
logger.warning(f"{self.log_prefix} 获取到的表情包均无情感标签, 将随机发送")
|
|
||||||
emoji_base64, emoji_description, _ = random.choice(sampled_emojis)
|
|
||||||
else:
|
|
||||||
# 获取最近的5条消息内容用于判断
|
|
||||||
recent_messages = message_api.get_recent_messages(chat_id=self.chat_id, limit=5)
|
|
||||||
messages_text = ""
|
|
||||||
if recent_messages:
|
|
||||||
# 使用message_api构建可读的消息字符串
|
|
||||||
messages_text = message_api.build_readable_messages(
|
|
||||||
messages=recent_messages,
|
|
||||||
timestamp_mode="normal_no_YMD",
|
|
||||||
truncate=False,
|
|
||||||
show_actions=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 4. 构建prompt让LLM选择情感
|
|
||||||
prompt = f"""你正在进行QQ聊天,你需要根据聊天记录,选出一个合适的情感标签。
|
|
||||||
请你根据以下原因和聊天记录进行选择
|
|
||||||
原因:{reason}
|
|
||||||
聊天记录:
|
|
||||||
{messages_text}
|
|
||||||
|
|
||||||
这里是可用的情感标签:
|
|
||||||
{available_emotions_str}
|
|
||||||
请直接返回最匹配的那个情感标签,不要进行任何解释或添加其他多余的文字。
|
|
||||||
"""
|
|
||||||
|
|
||||||
if global_config.debug.show_prompt:
|
|
||||||
logger.info(f"{self.log_prefix} 生成的LLM Prompt: {prompt}")
|
|
||||||
else:
|
|
||||||
logger.debug(f"{self.log_prefix} 生成的LLM Prompt: {prompt}")
|
|
||||||
|
|
||||||
# 5. 调用LLM
|
|
||||||
models = llm_api.get_available_models()
|
|
||||||
chat_model_config = models.get("utils") # 使用字典访问方式
|
|
||||||
if not chat_model_config:
|
|
||||||
logger.error(f"{self.log_prefix} 未找到'utils'模型配置,无法调用LLM")
|
|
||||||
return False, "未找到'utils'模型配置"
|
|
||||||
|
|
||||||
success, chosen_emotion, _, _ = await llm_api.generate_with_model(
|
|
||||||
prompt, model_config=chat_model_config, request_type="emoji.select"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not success:
|
|
||||||
logger.error(f"{self.log_prefix} LLM调用失败: {chosen_emotion}")
|
|
||||||
return False, f"LLM调用失败: {chosen_emotion}"
|
|
||||||
|
|
||||||
chosen_emotion = chosen_emotion.strip().replace('"', "").replace("'", "")
|
|
||||||
logger.info(f"{self.log_prefix} LLM选择的情感: {chosen_emotion}")
|
|
||||||
|
|
||||||
# 6. 根据选择的情感匹配表情包
|
|
||||||
if chosen_emotion in emotion_map:
|
|
||||||
emoji_base64, emoji_description = random.choice(emotion_map[chosen_emotion])
|
|
||||||
logger.info(f"{self.log_prefix} 发送表情包[{chosen_emotion}],原因: {reason}")
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
f"{self.log_prefix} LLM选择的情感 '{chosen_emotion}' 不在可用列表中, 将随机选择一个表情包"
|
|
||||||
)
|
|
||||||
emoji_base64, emoji_description, _ = random.choice(sampled_emojis)
|
|
||||||
|
|
||||||
# 7. 发送表情包
|
|
||||||
success = await self.send_emoji(emoji_base64)
|
|
||||||
|
|
||||||
if success:
|
|
||||||
# 存储动作信息
|
|
||||||
await self.store_action_info(
|
|
||||||
action_build_into_prompt=True,
|
|
||||||
action_prompt_display=f"你发送了表情包,原因:{reason}",
|
|
||||||
action_done=True,
|
|
||||||
)
|
|
||||||
return True, f"成功发送表情包:[表情包:{chosen_emotion}]"
|
|
||||||
else:
|
|
||||||
error_msg = "发送表情包失败"
|
|
||||||
logger.error(f"{self.log_prefix} {error_msg}")
|
|
||||||
|
|
||||||
await self.send_text("执行表情包动作失败")
|
|
||||||
return False, error_msg
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"{self.log_prefix} 表情动作执行失败: {e}", exc_info=True)
|
|
||||||
return False, f"表情发送失败: {str(e)}"
|
|
||||||
@@ -1,66 +1,116 @@
|
|||||||
"""
|
"""Emoji 插件 — 新 SDK 版本
|
||||||
核心动作插件
|
|
||||||
|
|
||||||
将系统核心动作(reply、no_reply、emoji)转换为新插件系统格式
|
根据聊天上下文的情感,使用 LLM 选择并发送合适的表情包。
|
||||||
这是系统的内置插件,提供基础的聊天交互功能
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Tuple, Type
|
import random
|
||||||
|
|
||||||
# 导入新插件系统
|
from maibot_sdk import MaiBotPlugin, Action
|
||||||
from src.plugin_system import BasePlugin, register_plugin, ComponentInfo
|
from maibot_sdk.types import ActivationType
|
||||||
from src.plugin_system.base.config_types import ConfigField
|
|
||||||
|
|
||||||
# 导入依赖的系统组件
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
|
|
||||||
from src.plugins.built_in.emoji_plugin.emoji import EmojiAction
|
|
||||||
|
|
||||||
logger = get_logger("core_actions")
|
|
||||||
|
|
||||||
|
|
||||||
@register_plugin
|
class EmojiPlugin(MaiBotPlugin):
|
||||||
class CoreActionsPlugin(BasePlugin):
|
"""表情包插件"""
|
||||||
"""核心动作插件
|
|
||||||
|
|
||||||
系统内置插件,提供基础的聊天交互功能:
|
@Action(
|
||||||
- Reply: 回复动作
|
"emoji",
|
||||||
- NoReply: 不回复动作
|
description="发送表情包辅助表达情绪",
|
||||||
- Emoji: 表情动作
|
activation_type=ActivationType.RANDOM,
|
||||||
|
activation_probability=0.3,
|
||||||
|
parallel_action=True,
|
||||||
|
action_require=[
|
||||||
|
"发送表情包辅助表达情绪",
|
||||||
|
"表达情绪时可以选择使用",
|
||||||
|
"不要连续发送,如果你已经发过[表情包],就不要选择此动作",
|
||||||
|
],
|
||||||
|
associated_types=["emoji"],
|
||||||
|
)
|
||||||
|
async def handle_emoji(self, stream_id: str = "", reasoning: str = "", chat_id: str = "", **kwargs):
|
||||||
|
"""执行表情动作"""
|
||||||
|
reason = reasoning or "表达当前情绪"
|
||||||
|
|
||||||
注意:插件基本信息优先从_manifest.json文件中读取
|
# 1. 随机获取30个表情包
|
||||||
"""
|
result = await self.ctx.emoji.get_random(30)
|
||||||
|
if not result or not result.get("success"):
|
||||||
|
return False, "无法获取随机表情包"
|
||||||
|
|
||||||
# 插件基本信息
|
sampled_emojis = result.get("emojis", [])
|
||||||
plugin_name: str = "core_actions" # 内部标识符
|
if not sampled_emojis:
|
||||||
enable_plugin: bool = True
|
return False, "无法获取随机表情包"
|
||||||
dependencies: list[str] = [] # 插件依赖列表
|
|
||||||
python_dependencies: list[str] = [] # Python包依赖列表
|
|
||||||
config_file_name: str = "config.toml"
|
|
||||||
|
|
||||||
# 配置节描述
|
# 2. 按情感分组
|
||||||
config_section_descriptions = {
|
emotion_map: dict[str, list] = {}
|
||||||
"plugin": "插件启用配置",
|
for emoji in sampled_emojis:
|
||||||
"components": "核心组件启用配置",
|
emo = emoji.get("emotion", "")
|
||||||
}
|
if emo not in emotion_map:
|
||||||
|
emotion_map[emo] = []
|
||||||
|
emotion_map[emo].append(emoji)
|
||||||
|
|
||||||
# 配置Schema定义
|
available_emotions = list(emotion_map.keys())
|
||||||
config_schema: dict = {
|
|
||||||
"plugin": {
|
|
||||||
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
|
|
||||||
"config_version": ConfigField(type=str, default="0.6.0", description="配置文件版本"),
|
|
||||||
},
|
|
||||||
"components": {
|
|
||||||
"enable_emoji": ConfigField(type=bool, default=True, description="是否启用发送表情/图片动作"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
if not available_emotions:
|
||||||
"""返回插件包含的组件列表"""
|
# 无情感标签,随机发送
|
||||||
|
chosen = random.choice(sampled_emojis)
|
||||||
|
await self.ctx.send.emoji(chosen["base64"], stream_id)
|
||||||
|
return True, "随机发送了表情包"
|
||||||
|
|
||||||
# --- 根据配置注册组件 ---
|
# 3. 获取最近消息作为上下文
|
||||||
components = []
|
messages_text = ""
|
||||||
if self.get_config("components.enable_emoji", True):
|
if chat_id:
|
||||||
components.append((EmojiAction.get_action_info(), EmojiAction))
|
recent_result = await self.ctx.message.get_recent(chat_id=chat_id, limit=5)
|
||||||
|
if recent_result and recent_result.get("success"):
|
||||||
|
readable_result = await self.ctx.call_capability(
|
||||||
|
"message.build_readable",
|
||||||
|
chat_id=chat_id,
|
||||||
|
start_time=0,
|
||||||
|
end_time=0,
|
||||||
|
limit=5,
|
||||||
|
timestamp_mode="normal_no_YMD",
|
||||||
|
truncate=False,
|
||||||
|
)
|
||||||
|
if readable_result and readable_result.get("success"):
|
||||||
|
messages_text = readable_result.get("text", "")
|
||||||
|
|
||||||
return components
|
# 4. 构建 prompt 让 LLM 选择情感
|
||||||
|
available_emotions_str = "\n".join(available_emotions)
|
||||||
|
prompt = f"""你正在进行QQ聊天,你需要根据聊天记录,选出一个合适的情感标签。
|
||||||
|
请你根据以下原因和聊天记录进行选择
|
||||||
|
原因:{reason}
|
||||||
|
聊天记录:
|
||||||
|
{messages_text}
|
||||||
|
|
||||||
|
这里是可用的情感标签:
|
||||||
|
{available_emotions_str}
|
||||||
|
请直接返回最匹配的那个情感标签,不要进行任何解释或添加其他多余的文字。
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 5. 调用 LLM
|
||||||
|
llm_result = await self.ctx.llm.generate(prompt=prompt, model_name="utils")
|
||||||
|
if not llm_result or not llm_result.get("success"):
|
||||||
|
chosen = random.choice(sampled_emojis)
|
||||||
|
await self.ctx.send.emoji(chosen["base64"], stream_id)
|
||||||
|
return True, "LLM调用失败,随机发送了表情包"
|
||||||
|
|
||||||
|
chosen_emotion = llm_result.get("response", "").strip().replace('"', "").replace("'", "")
|
||||||
|
|
||||||
|
# 6. 根据选择的情感匹配表情包
|
||||||
|
if chosen_emotion in emotion_map:
|
||||||
|
chosen = random.choice(emotion_map[chosen_emotion])
|
||||||
|
else:
|
||||||
|
chosen = random.choice(sampled_emojis)
|
||||||
|
|
||||||
|
# 7. 发送
|
||||||
|
send_result = await self.ctx.send.emoji(chosen["base64"], stream_id)
|
||||||
|
if send_result and send_result.get("success"):
|
||||||
|
return True, f"成功发送表情包:[表情包:{chosen_emotion}]"
|
||||||
|
return False, "发送表情包失败"
|
||||||
|
|
||||||
|
async def on_load(self):
|
||||||
|
# 从插件配置读取 emoji_chance 来覆盖默认概率
|
||||||
|
config_result = await self.ctx.config.get("emoji.emoji_chance")
|
||||||
|
if config_result and isinstance(config_result, dict) and config_result.get("success"):
|
||||||
|
pass # 配置已在宿主端管理
|
||||||
|
|
||||||
|
|
||||||
|
def create_plugin():
|
||||||
|
return EmojiPlugin()
|
||||||
|
|||||||
33
src/plugins/built_in/knowledge/_manifest.json
Normal file
33
src/plugins/built_in/knowledge/_manifest.json
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
{
|
||||||
|
"manifest_version": 1,
|
||||||
|
"name": "LPMM 知识库插件 (Knowledge Search)",
|
||||||
|
"version": "2.0.0",
|
||||||
|
"description": "从 LPMM 知识库中搜索相关信息,供 LLM 工具调用",
|
||||||
|
"author": {
|
||||||
|
"name": "MaiBot团队",
|
||||||
|
"url": "https://github.com/MaiM-with-u"
|
||||||
|
},
|
||||||
|
"license": "GPL-v3.0-or-later",
|
||||||
|
"host_application": {
|
||||||
|
"min_version": "1.0.0"
|
||||||
|
},
|
||||||
|
"homepage_url": "https://github.com/MaiM-with-u/maibot",
|
||||||
|
"repository_url": "https://github.com/MaiM-with-u/maibot",
|
||||||
|
"keywords": ["knowledge", "lpmm", "search", "tool", "built-in"],
|
||||||
|
"categories": ["Knowledge", "Tools"],
|
||||||
|
"default_locale": "zh-CN",
|
||||||
|
"plugin_info": {
|
||||||
|
"is_built_in": true,
|
||||||
|
"plugin_type": "tool_provider",
|
||||||
|
"components": [
|
||||||
|
{
|
||||||
|
"type": "tool",
|
||||||
|
"name": "lpmm_search_knowledge",
|
||||||
|
"description": "从知识库中搜索相关信息"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"capabilities": [
|
||||||
|
"knowledge.search"
|
||||||
|
]
|
||||||
|
}
|
||||||
@@ -1,62 +0,0 @@
|
|||||||
from typing import Dict, Any
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.config.config import global_config
|
|
||||||
from src.chat.knowledge import qa_manager
|
|
||||||
from src.plugin_system import BaseTool, ToolParamType
|
|
||||||
|
|
||||||
logger = get_logger("lpmm_get_knowledge_tool")
|
|
||||||
|
|
||||||
|
|
||||||
class SearchKnowledgeFromLPMMTool(BaseTool):
|
|
||||||
"""从LPMM知识库中搜索相关信息的工具"""
|
|
||||||
|
|
||||||
name = "lpmm_search_knowledge"
|
|
||||||
description = "从知识库中搜索相关信息,如果你需要知识,就使用这个工具"
|
|
||||||
parameters = [
|
|
||||||
("query", ToolParamType.STRING, "搜索查询关键词", True, None),
|
|
||||||
("limit", ToolParamType.INTEGER, "希望返回的相关知识条数,默认5", False, None),
|
|
||||||
]
|
|
||||||
available_for_llm = global_config.lpmm_knowledge.enable
|
|
||||||
|
|
||||||
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""执行知识库搜索
|
|
||||||
|
|
||||||
Args:
|
|
||||||
function_args: 工具参数
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict: 工具执行结果
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
query: str = function_args.get("query") # type: ignore
|
|
||||||
limit = function_args.get("limit", 5)
|
|
||||||
try:
|
|
||||||
limit_value = int(limit)
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
limit_value = 5
|
|
||||||
limit_value = max(1, limit_value)
|
|
||||||
# threshold = function_args.get("threshold", 0.4)
|
|
||||||
|
|
||||||
# 检查LPMM知识库是否启用
|
|
||||||
if qa_manager is None:
|
|
||||||
logger.debug("LPMM知识库已禁用,跳过知识获取")
|
|
||||||
return {"type": "info", "id": query, "content": "LPMM知识库已禁用"}
|
|
||||||
|
|
||||||
# 调用知识库搜索
|
|
||||||
|
|
||||||
knowledge_info = await qa_manager.get_knowledge(query, limit=limit_value)
|
|
||||||
|
|
||||||
logger.debug(f"知识库查询结果: {knowledge_info}")
|
|
||||||
|
|
||||||
if knowledge_info:
|
|
||||||
content = f"你知道这些知识: {knowledge_info}"
|
|
||||||
else:
|
|
||||||
content = f"你不太了解有关{query}的知识"
|
|
||||||
return {"type": "lpmm_knowledge", "id": query, "content": content}
|
|
||||||
except Exception as e:
|
|
||||||
# 捕获异常并记录错误
|
|
||||||
logger.error(f"知识库搜索工具执行失败: {str(e)}")
|
|
||||||
# 在其他异常情况下,确保 id 仍然是 query (如果它被定义了)
|
|
||||||
query_id = query if "query" in locals() else "unknown_query"
|
|
||||||
return {"type": "info", "id": query_id, "content": f"lpmm知识库搜索失败,炸了: {str(e)}"}
|
|
||||||
39
src/plugins/built_in/knowledge/plugin.py
Normal file
39
src/plugins/built_in/knowledge/plugin.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
"""LPMM 知识库搜索插件 — 新 SDK 版本
|
||||||
|
|
||||||
|
提供 LLM 可调用的知识库搜索工具。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from maibot_sdk import MaiBotPlugin, Tool
|
||||||
|
from maibot_sdk.types import ToolParameterInfo, ToolParamType
|
||||||
|
|
||||||
|
|
||||||
|
class KnowledgePlugin(MaiBotPlugin):
|
||||||
|
"""LPMM 知识库插件"""
|
||||||
|
|
||||||
|
@Tool(
|
||||||
|
"lpmm_search_knowledge",
|
||||||
|
description="从知识库中搜索相关信息,如果你需要知识,就使用这个工具",
|
||||||
|
parameters=[
|
||||||
|
ToolParameterInfo(name="query", param_type=ToolParamType.STRING, description="搜索查询关键词", required=True),
|
||||||
|
ToolParameterInfo(name="limit", param_type=ToolParamType.INTEGER, description="希望返回的相关知识条数,默认5", required=False, default=5),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def handle_lpmm_search_knowledge(self, query: str = "", limit: int = 5, **kwargs):
|
||||||
|
"""执行知识库搜索"""
|
||||||
|
if not query:
|
||||||
|
return {"type": "info", "id": "", "content": "未提供搜索关键词"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
limit_value = max(1, int(limit))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
limit_value = 5
|
||||||
|
|
||||||
|
result = await self.ctx.call_capability("knowledge.search", query=query, limit=limit_value)
|
||||||
|
if result and result.get("success"):
|
||||||
|
content = result.get("content", f"你不太了解有关{query}的知识")
|
||||||
|
return {"type": "lpmm_knowledge", "id": query, "content": content}
|
||||||
|
return {"type": "info", "id": query, "content": f"知识库搜索失败: {result}"}
|
||||||
|
|
||||||
|
|
||||||
|
def create_plugin():
|
||||||
|
return KnowledgePlugin()
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
{
|
{
|
||||||
"manifest_version": 1,
|
"manifest_version": 1,
|
||||||
"name": "插件和组件管理 (Plugin and Component Management)",
|
"name": "插件和组件管理 (Plugin and Component Management)",
|
||||||
"version": "1.0.0",
|
"version": "2.0.0",
|
||||||
"description": "通过系统API管理插件和组件的生命周期,包括加载、卸载、启用和禁用等操作。",
|
"description": "通过系统API管理插件和组件的生命周期,包括加载、卸载、启用和禁用等操作。",
|
||||||
"author": {
|
"author": {
|
||||||
"name": "MaiBot团队",
|
"name": "MaiBot团队",
|
||||||
@@ -9,7 +9,7 @@
|
|||||||
},
|
},
|
||||||
"license": "GPL-v3.0-or-later",
|
"license": "GPL-v3.0-or-later",
|
||||||
"host_application": {
|
"host_application": {
|
||||||
"min_version": "0.10.1"
|
"min_version": "1.0.0"
|
||||||
},
|
},
|
||||||
"homepage_url": "https://github.com/MaiM-with-u/maibot",
|
"homepage_url": "https://github.com/MaiM-with-u/maibot",
|
||||||
"repository_url": "https://github.com/MaiM-with-u/maibot",
|
"repository_url": "https://github.com/MaiM-with-u/maibot",
|
||||||
@@ -28,10 +28,22 @@
|
|||||||
"plugin_info": {
|
"plugin_info": {
|
||||||
"is_built_in": true,
|
"is_built_in": true,
|
||||||
"plugin_type": "plugin_management",
|
"plugin_type": "plugin_management",
|
||||||
|
"capabilities": [
|
||||||
|
"component.get_all_plugins",
|
||||||
|
"component.list_loaded_plugins",
|
||||||
|
"component.list_registered_plugins",
|
||||||
|
"component.enable",
|
||||||
|
"component.disable",
|
||||||
|
"component.load_plugin",
|
||||||
|
"component.unload_plugin",
|
||||||
|
"component.reload_plugin",
|
||||||
|
"send.text",
|
||||||
|
"config.get"
|
||||||
|
],
|
||||||
"components": [
|
"components": [
|
||||||
{
|
{
|
||||||
"type": "command",
|
"type": "command",
|
||||||
"name": "plugin_management",
|
"name": "management",
|
||||||
"description": "管理插件和组件的生命周期,包括加载、卸载、启用和禁用等操作。"
|
"description": "管理插件和组件的生命周期,包括加载、卸载、启用和禁用等操作。"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,454 +1,279 @@
|
|||||||
import asyncio
|
"""插件和组件管理 — 新 SDK 版本
|
||||||
|
|
||||||
from typing import List, Tuple, Type
|
通过 /pm 命令管理插件和组件的生命周期。
|
||||||
from src.plugin_system import (
|
"""
|
||||||
BasePlugin,
|
|
||||||
BaseCommand,
|
from maibot_sdk import MaiBotPlugin, Command
|
||||||
CommandInfo,
|
|
||||||
ConfigField,
|
|
||||||
register_plugin,
|
_VALID_COMPONENT_TYPES = ("action", "command", "event_handler")
|
||||||
plugin_manage_api,
|
|
||||||
component_manage_api,
|
HELP_ALL = (
|
||||||
ComponentInfo,
|
"管理命令帮助\n"
|
||||||
ComponentType,
|
"/pm help 管理命令提示\n"
|
||||||
send_api,
|
"/pm plugin 插件管理命令\n"
|
||||||
|
"/pm component 组件管理命令\n"
|
||||||
|
"使用 /pm plugin help 或 /pm component help 获取具体帮助"
|
||||||
|
)
|
||||||
|
HELP_PLUGIN = (
|
||||||
|
"插件管理命令帮助\n"
|
||||||
|
"/pm plugin help 插件管理命令提示\n"
|
||||||
|
"/pm plugin list 列出所有注册的插件\n"
|
||||||
|
"/pm plugin list_enabled 列出所有加载(启用)的插件\n"
|
||||||
|
"/pm plugin load <plugin_name> 加载指定插件\n"
|
||||||
|
"/pm plugin unload <plugin_name> 卸载指定插件\n"
|
||||||
|
"/pm plugin reload <plugin_name> 重新加载指定插件\n"
|
||||||
|
)
|
||||||
|
HELP_COMPONENT = (
|
||||||
|
"组件管理命令帮助\n"
|
||||||
|
"/pm component help 组件管理命令提示\n"
|
||||||
|
"/pm component list 列出所有注册的组件\n"
|
||||||
|
"/pm component list enabled <可选: type> 列出所有启用的组件\n"
|
||||||
|
"/pm component list disabled <可选: type> 列出所有禁用的组件\n"
|
||||||
|
" - <type> 可选项: local,代表当前聊天中的;global,代表全局的\n"
|
||||||
|
" - <type> 不填时为 global\n"
|
||||||
|
"/pm component list type <component_type> 列出已经注册的指定类型的组件\n"
|
||||||
|
"/pm component enable global <component_name> <component_type> 全局启用组件\n"
|
||||||
|
"/pm component enable local <component_name> <component_type> 本聊天启用组件\n"
|
||||||
|
"/pm component disable global <component_name> <component_type> 全局禁用组件\n"
|
||||||
|
"/pm component disable local <component_name> <component_type> 本聊天禁用组件\n"
|
||||||
|
" - <component_type> 可选项: action, command, event_handler\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ManagementCommand(BaseCommand):
|
class PluginManagementPlugin(MaiBotPlugin):
|
||||||
command_name: str = "management"
|
"""插件和组件管理插件"""
|
||||||
description: str = "管理命令"
|
|
||||||
command_pattern: str = r"(?P<manage_command>^/pm(\s[a-zA-Z0-9_]+)*\s*$)"
|
|
||||||
|
|
||||||
async def execute(self) -> Tuple[bool, str, bool]:
|
@Command(
|
||||||
# sourcery skip: merge-duplicate-blocks
|
"management",
|
||||||
if (
|
description="管理插件和组件的生命周期",
|
||||||
not self.message
|
pattern=r"(?P<manage_command>^/pm(\s[a-zA-Z0-9_]+)*\s*$)",
|
||||||
or not self.message.message_info
|
)
|
||||||
or not self.message.message_info.user_info
|
async def handle_management(
|
||||||
or str(self.message.message_info.user_info.user_id) not in self.get_config("plugin.permission", []) # type: ignore
|
self, stream_id: str = "", user_id: str = "", matched_groups: dict | None = None, **kwargs
|
||||||
):
|
):
|
||||||
await self._send_message("你没有权限使用插件管理命令")
|
"""处理 /pm 命令"""
|
||||||
|
# 权限检查
|
||||||
|
permission_result = await self.ctx.config.get("plugin.permission")
|
||||||
|
permission_list = permission_result if isinstance(permission_result, list) else []
|
||||||
|
if str(user_id) not in permission_list:
|
||||||
|
await self.ctx.send.text("你没有权限使用插件管理命令", stream_id)
|
||||||
return False, "没有权限", True
|
return False, "没有权限", True
|
||||||
if not self.message.chat_stream:
|
|
||||||
await self._send_message("无法获取聊天流信息")
|
if not stream_id:
|
||||||
return False, "无法获取聊天流信息", True
|
return False, "无法获取聊天流信息", True
|
||||||
self.stream_id = self.message.chat_stream.stream_id
|
|
||||||
if not self.stream_id:
|
raw_command = (matched_groups or {}).get("manage_command", "").strip()
|
||||||
await self._send_message("无法获取聊天流信息")
|
parts = raw_command.split(" ") if raw_command else ["/pm"]
|
||||||
return False, "无法获取聊天流信息", True
|
n = len(parts)
|
||||||
command_list = self.matched_groups["manage_command"].strip().split(" ")
|
|
||||||
if len(command_list) == 1:
|
# /pm
|
||||||
await self.show_help("all")
|
if n == 1:
|
||||||
|
await self.ctx.send.text(HELP_ALL, stream_id)
|
||||||
return True, "帮助已发送", True
|
return True, "帮助已发送", True
|
||||||
if len(command_list) == 2:
|
|
||||||
match command_list[1]:
|
|
||||||
case "plugin":
|
|
||||||
await self.show_help("plugin")
|
|
||||||
case "component":
|
|
||||||
await self.show_help("component")
|
|
||||||
case "help":
|
|
||||||
await self.show_help("all")
|
|
||||||
case _:
|
|
||||||
await self._send_message("插件管理命令不合法")
|
|
||||||
return False, "命令不合法", True
|
|
||||||
if len(command_list) == 3:
|
|
||||||
if command_list[1] == "plugin":
|
|
||||||
match command_list[2]:
|
|
||||||
case "help":
|
|
||||||
await self.show_help("plugin")
|
|
||||||
case "list":
|
|
||||||
await self._list_registered_plugins()
|
|
||||||
case "list_enabled":
|
|
||||||
await self._list_loaded_plugins()
|
|
||||||
case "rescan":
|
|
||||||
await self._rescan_plugin_dirs()
|
|
||||||
case _:
|
|
||||||
await self._send_message("插件管理命令不合法")
|
|
||||||
return False, "命令不合法", True
|
|
||||||
elif command_list[1] == "component":
|
|
||||||
if command_list[2] == "list":
|
|
||||||
await self._list_all_registered_components()
|
|
||||||
elif command_list[2] == "help":
|
|
||||||
await self.show_help("component")
|
|
||||||
else:
|
|
||||||
await self._send_message("插件管理命令不合法")
|
|
||||||
return False, "命令不合法", True
|
|
||||||
else:
|
|
||||||
await self._send_message("插件管理命令不合法")
|
|
||||||
return False, "命令不合法", True
|
|
||||||
if len(command_list) == 4:
|
|
||||||
if command_list[1] == "plugin":
|
|
||||||
match command_list[2]:
|
|
||||||
case "load":
|
|
||||||
await self._load_plugin(command_list[3])
|
|
||||||
case "unload":
|
|
||||||
await self._unload_plugin(command_list[3])
|
|
||||||
case "reload":
|
|
||||||
await self._reload_plugin(command_list[3])
|
|
||||||
case "add_dir":
|
|
||||||
await self._add_dir(command_list[3])
|
|
||||||
case _:
|
|
||||||
await self._send_message("插件管理命令不合法")
|
|
||||||
return False, "命令不合法", True
|
|
||||||
elif command_list[1] == "component":
|
|
||||||
if command_list[2] != "list":
|
|
||||||
await self._send_message("插件管理命令不合法")
|
|
||||||
return False, "命令不合法", True
|
|
||||||
if command_list[3] == "enabled":
|
|
||||||
await self._list_enabled_components()
|
|
||||||
elif command_list[3] == "disabled":
|
|
||||||
await self._list_disabled_components()
|
|
||||||
else:
|
|
||||||
await self._send_message("插件管理命令不合法")
|
|
||||||
return False, "命令不合法", True
|
|
||||||
else:
|
|
||||||
await self._send_message("插件管理命令不合法")
|
|
||||||
return False, "命令不合法", True
|
|
||||||
if len(command_list) == 5:
|
|
||||||
if command_list[1] != "component":
|
|
||||||
await self._send_message("插件管理命令不合法")
|
|
||||||
return False, "命令不合法", True
|
|
||||||
if command_list[2] != "list":
|
|
||||||
await self._send_message("插件管理命令不合法")
|
|
||||||
return False, "命令不合法", True
|
|
||||||
if command_list[3] == "enabled":
|
|
||||||
await self._list_enabled_components(target_type=command_list[4])
|
|
||||||
elif command_list[3] == "disabled":
|
|
||||||
await self._list_disabled_components(target_type=command_list[4])
|
|
||||||
elif command_list[3] == "type":
|
|
||||||
await self._list_registered_components_by_type(command_list[4])
|
|
||||||
else:
|
|
||||||
await self._send_message("插件管理命令不合法")
|
|
||||||
return False, "命令不合法", True
|
|
||||||
if len(command_list) == 6:
|
|
||||||
if command_list[1] != "component":
|
|
||||||
await self._send_message("插件管理命令不合法")
|
|
||||||
return False, "命令不合法", True
|
|
||||||
if command_list[2] == "enable":
|
|
||||||
if command_list[3] == "global":
|
|
||||||
await self._globally_enable_component(command_list[4], command_list[5])
|
|
||||||
elif command_list[3] == "local":
|
|
||||||
await self._locally_enable_component(command_list[4], command_list[5])
|
|
||||||
else:
|
|
||||||
await self._send_message("插件管理命令不合法")
|
|
||||||
return False, "命令不合法", True
|
|
||||||
elif command_list[2] == "disable":
|
|
||||||
if command_list[3] == "global":
|
|
||||||
await self._globally_disable_component(command_list[4], command_list[5])
|
|
||||||
elif command_list[3] == "local":
|
|
||||||
await self._locally_disable_component(command_list[4], command_list[5])
|
|
||||||
else:
|
|
||||||
await self._send_message("插件管理命令不合法")
|
|
||||||
return False, "命令不合法", True
|
|
||||||
else:
|
|
||||||
await self._send_message("插件管理命令不合法")
|
|
||||||
return False, "命令不合法", True
|
|
||||||
|
|
||||||
return True, "命令执行完成", True
|
# /pm <sub>
|
||||||
|
if n == 2:
|
||||||
|
sub = parts[1]
|
||||||
|
if sub == "plugin":
|
||||||
|
await self.ctx.send.text(HELP_PLUGIN, stream_id)
|
||||||
|
elif sub == "component":
|
||||||
|
await self.ctx.send.text(HELP_COMPONENT, stream_id)
|
||||||
|
elif sub == "help":
|
||||||
|
await self.ctx.send.text(HELP_ALL, stream_id)
|
||||||
|
else:
|
||||||
|
await self.ctx.send.text("插件管理命令不合法", stream_id)
|
||||||
|
return False, "命令不合法", True
|
||||||
|
return True, "帮助已发送", True
|
||||||
|
|
||||||
async def show_help(self, target: str):
|
# /pm plugin <action> / /pm component <action>
|
||||||
help_msg = ""
|
if n == 3:
|
||||||
match target:
|
if parts[1] == "plugin":
|
||||||
case "all":
|
await self._handle_plugin_3(parts[2], stream_id)
|
||||||
help_msg = (
|
elif parts[1] == "component":
|
||||||
"管理命令帮助\n"
|
if parts[2] == "list":
|
||||||
"/pm help 管理命令提示\n"
|
await self._list_all_components(stream_id)
|
||||||
"/pm plugin 插件管理命令\n"
|
elif parts[2] == "help":
|
||||||
"/pm component 组件管理命令\n"
|
await self.ctx.send.text(HELP_COMPONENT, stream_id)
|
||||||
"使用 /pm plugin help 或 /pm component help 获取具体帮助"
|
else:
|
||||||
)
|
await self.ctx.send.text("插件管理命令不合法", stream_id)
|
||||||
case "plugin":
|
return False, "命令不合法", True
|
||||||
help_msg = (
|
else:
|
||||||
"插件管理命令帮助\n"
|
await self.ctx.send.text("插件管理命令不合法", stream_id)
|
||||||
"/pm plugin help 插件管理命令提示\n"
|
return False, "命令不合法", True
|
||||||
"/pm plugin list 列出所有注册的插件\n"
|
return True, "命令执行完成", True
|
||||||
"/pm plugin list_enabled 列出所有加载(启用)的插件\n"
|
|
||||||
"/pm plugin rescan 重新扫描所有目录\n"
|
if n == 4:
|
||||||
"/pm plugin load <plugin_name> 加载指定插件\n"
|
if parts[1] == "plugin":
|
||||||
"/pm plugin unload <plugin_name> 卸载指定插件\n"
|
await self._handle_plugin_4(parts[2], parts[3], stream_id)
|
||||||
"/pm plugin reload <plugin_name> 重新加载指定插件\n"
|
elif parts[1] == "component":
|
||||||
"/pm plugin add_dir <directory_path> 添加插件目录\n"
|
if parts[2] == "list":
|
||||||
)
|
await self._handle_component_list_4(parts[3], stream_id)
|
||||||
case "component":
|
else:
|
||||||
help_msg = (
|
await self.ctx.send.text("插件管理命令不合法", stream_id)
|
||||||
"组件管理命令帮助\n"
|
return False, "命令不合法", True
|
||||||
"/pm component help 组件管理命令提示\n"
|
else:
|
||||||
"/pm component list 列出所有注册的组件\n"
|
await self.ctx.send.text("插件管理命令不合法", stream_id)
|
||||||
"/pm component list enabled <可选: type> 列出所有启用的组件\n"
|
return False, "命令不合法", True
|
||||||
"/pm component list disabled <可选: type> 列出所有禁用的组件\n"
|
return True, "命令执行完成", True
|
||||||
" - <type> 可选项: local,代表当前聊天中的;global,代表全局的\n"
|
|
||||||
" - <type> 不填时为 global\n"
|
if n == 5:
|
||||||
"/pm component list type <component_type> 列出已经注册的指定类型的组件\n"
|
if parts[1] != "component" or parts[2] != "list":
|
||||||
"/pm component enable global <component_name> <component_type> 全局启用组件\n"
|
await self.ctx.send.text("插件管理命令不合法", stream_id)
|
||||||
"/pm component enable local <component_name> <component_type> 本聊天启用组件\n"
|
return False, "命令不合法", True
|
||||||
"/pm component disable global <component_name> <component_type> 全局禁用组件\n"
|
await self._handle_component_list_5(parts[3], parts[4], stream_id)
|
||||||
"/pm component disable local <component_name> <component_type> 本聊天禁用组件\n"
|
return True, "命令执行完成", True
|
||||||
" - <component_type> 可选项: action, command, event_handler\n"
|
|
||||||
)
|
if n == 6:
|
||||||
|
if parts[1] != "component":
|
||||||
|
await self.ctx.send.text("插件管理命令不合法", stream_id)
|
||||||
|
return False, "命令不合法", True
|
||||||
|
await self._handle_component_toggle(parts[2], parts[3], parts[4], parts[5], stream_id)
|
||||||
|
return True, "命令执行完成", True
|
||||||
|
|
||||||
|
await self.ctx.send.text("插件管理命令不合法", stream_id)
|
||||||
|
return False, "命令不合法", True
|
||||||
|
|
||||||
|
# ------ plugin 子命令 ------
|
||||||
|
|
||||||
|
async def _handle_plugin_3(self, action: str, stream_id: str):
|
||||||
|
match action:
|
||||||
|
case "help":
|
||||||
|
await self.ctx.send.text(HELP_PLUGIN, stream_id)
|
||||||
|
case "list":
|
||||||
|
result = await self.ctx.component.list_registered_plugins()
|
||||||
|
plugins = result if isinstance(result, list) else []
|
||||||
|
await self.ctx.send.text(f"已注册的插件: {', '.join(plugins) if plugins else '无'}", stream_id)
|
||||||
|
case "list_enabled":
|
||||||
|
result = await self.ctx.component.list_loaded_plugins()
|
||||||
|
plugins = result if isinstance(result, list) else []
|
||||||
|
await self.ctx.send.text(f"已加载的插件: {', '.join(plugins) if plugins else '无'}", stream_id)
|
||||||
case _:
|
case _:
|
||||||
|
await self.ctx.send.text("插件管理命令不合法", stream_id)
|
||||||
|
|
||||||
|
async def _handle_plugin_4(self, action: str, name: str, stream_id: str):
|
||||||
|
match action:
|
||||||
|
case "load":
|
||||||
|
result = await self.ctx.component.load_plugin(name)
|
||||||
|
ok = result.get("success", False) if isinstance(result, dict) else bool(result)
|
||||||
|
msg = f"插件加载成功: {name}" if ok else f"插件加载失败: {name}"
|
||||||
|
await self.ctx.send.text(msg, stream_id)
|
||||||
|
case "unload":
|
||||||
|
result = await self.ctx.component.unload_plugin(name)
|
||||||
|
ok = result.get("success", False) if isinstance(result, dict) else bool(result)
|
||||||
|
msg = f"插件卸载成功: {name}" if ok else f"插件卸载失败: {name}"
|
||||||
|
await self.ctx.send.text(msg, stream_id)
|
||||||
|
case "reload":
|
||||||
|
result = await self.ctx.component.reload_plugin(name)
|
||||||
|
ok = result.get("success", False) if isinstance(result, dict) else bool(result)
|
||||||
|
msg = f"插件重新加载成功: {name}" if ok else f"插件重新加载失败: {name}"
|
||||||
|
await self.ctx.send.text(msg, stream_id)
|
||||||
|
case _:
|
||||||
|
await self.ctx.send.text("插件管理命令不合法", stream_id)
|
||||||
|
|
||||||
|
# ------ component 子命令 ------
|
||||||
|
|
||||||
|
async def _list_all_components(self, stream_id: str):
|
||||||
|
result = await self.ctx.component.get_all_plugins()
|
||||||
|
if not result:
|
||||||
|
await self.ctx.send.text("没有注册的组件", stream_id)
|
||||||
|
return
|
||||||
|
components = self._extract_components(result)
|
||||||
|
if not components:
|
||||||
|
await self.ctx.send.text("没有注册的组件", stream_id)
|
||||||
|
return
|
||||||
|
text = ", ".join(f"{c['name']} ({c['type']})" for c in components)
|
||||||
|
await self.ctx.send.text(f"已注册的组件: {text}", stream_id)
|
||||||
|
|
||||||
|
async def _handle_component_list_4(self, sub: str, stream_id: str):
|
||||||
|
if sub == "enabled":
|
||||||
|
await self._list_filtered_components("enabled", "global", stream_id)
|
||||||
|
elif sub == "disabled":
|
||||||
|
await self._list_filtered_components("disabled", "global", stream_id)
|
||||||
|
else:
|
||||||
|
await self.ctx.send.text("插件管理命令不合法", stream_id)
|
||||||
|
|
||||||
|
async def _handle_component_list_5(self, sub: str, arg: str, stream_id: str):
|
||||||
|
if sub in ("enabled", "disabled"):
|
||||||
|
await self._list_filtered_components(sub, arg, stream_id)
|
||||||
|
elif sub == "type":
|
||||||
|
if arg not in _VALID_COMPONENT_TYPES:
|
||||||
|
await self.ctx.send.text(f"未知组件类型: {arg}", stream_id)
|
||||||
return
|
return
|
||||||
await self._send_message(help_msg)
|
result = await self.ctx.component.get_all_plugins()
|
||||||
|
components = [c for c in self._extract_components(result) if c.get("type") == arg]
|
||||||
async def _list_loaded_plugins(self):
|
if not components:
|
||||||
plugins = plugin_manage_api.list_loaded_plugins()
|
await self.ctx.send.text(f"没有注册的 {arg} 组件", stream_id)
|
||||||
await self._send_message(f"已加载的插件: {', '.join(plugins)}")
|
return
|
||||||
|
text = ", ".join(f"{c['name']} ({c['type']})" for c in components)
|
||||||
async def _list_registered_plugins(self):
|
await self.ctx.send.text(f"注册的 {arg} 组件: {text}", stream_id)
|
||||||
plugins = plugin_manage_api.list_registered_plugins()
|
|
||||||
await self._send_message(f"已注册的插件: {', '.join(plugins)}")
|
|
||||||
|
|
||||||
async def _rescan_plugin_dirs(self):
|
|
||||||
plugin_manage_api.rescan_plugin_directory()
|
|
||||||
await self._send_message("插件目录重新扫描执行中")
|
|
||||||
|
|
||||||
async def _load_plugin(self, plugin_name: str):
|
|
||||||
success, count = plugin_manage_api.load_plugin(plugin_name)
|
|
||||||
if success:
|
|
||||||
await self._send_message(f"插件加载成功: {plugin_name}")
|
|
||||||
else:
|
else:
|
||||||
if count == 0:
|
await self.ctx.send.text("插件管理命令不合法", stream_id)
|
||||||
await self._send_message(f"插件{plugin_name}为禁用状态")
|
|
||||||
await self._send_message(f"插件加载失败: {plugin_name}")
|
|
||||||
|
|
||||||
async def _unload_plugin(self, plugin_name: str):
|
async def _list_filtered_components(self, filter_mode: str, scope: str, stream_id: str):
|
||||||
success = await plugin_manage_api.remove_plugin(plugin_name)
|
result = await self.ctx.component.get_all_plugins()
|
||||||
if success:
|
all_components = self._extract_components(result)
|
||||||
await self._send_message(f"插件卸载成功: {plugin_name}")
|
if not all_components:
|
||||||
|
await self.ctx.send.text("没有注册的组件", stream_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
if filter_mode == "enabled":
|
||||||
|
filtered = [c for c in all_components if c.get("enabled", False)]
|
||||||
|
label = "已启用"
|
||||||
else:
|
else:
|
||||||
await self._send_message(f"插件卸载失败: {plugin_name}")
|
filtered = [c for c in all_components if not c.get("enabled", False)]
|
||||||
|
label = "已禁用"
|
||||||
|
|
||||||
async def _reload_plugin(self, plugin_name: str):
|
scope_label = "全局" if scope == "global" else "本聊天"
|
||||||
success = await plugin_manage_api.reload_plugin(plugin_name)
|
if not filtered:
|
||||||
if success:
|
await self.ctx.send.text(f"没有满足条件的{label}{scope_label}组件", stream_id)
|
||||||
await self._send_message(f"插件重新加载成功: {plugin_name}")
|
return
|
||||||
|
text = ", ".join(f"{c['name']} ({c['type']})" for c in filtered)
|
||||||
|
await self.ctx.send.text(f"满足条件的{label}{scope_label}组件: {text}", stream_id)
|
||||||
|
|
||||||
|
async def _handle_component_toggle(
|
||||||
|
self, action: str, scope: str, comp_name: str, comp_type: str, stream_id: str
|
||||||
|
):
|
||||||
|
if action not in ("enable", "disable"):
|
||||||
|
await self.ctx.send.text("插件管理命令不合法", stream_id)
|
||||||
|
return
|
||||||
|
if scope not in ("global", "local"):
|
||||||
|
await self.ctx.send.text("插件管理命令不合法", stream_id)
|
||||||
|
return
|
||||||
|
if comp_type not in _VALID_COMPONENT_TYPES:
|
||||||
|
await self.ctx.send.text(f"未知组件类型: {comp_type}", stream_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
if action == "enable":
|
||||||
|
result = await self.ctx.component.enable_component(
|
||||||
|
comp_name, comp_type, scope=scope, stream_id=stream_id
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
await self._send_message(f"插件重新加载失败: {plugin_name}")
|
result = await self.ctx.component.disable_component(
|
||||||
|
comp_name, comp_type, scope=scope, stream_id=stream_id
|
||||||
|
)
|
||||||
|
|
||||||
async def _add_dir(self, dir_path: str):
|
ok = result.get("success", False) if isinstance(result, dict) else bool(result)
|
||||||
await self._send_message(f"正在添加插件目录: {dir_path}")
|
scope_label = "全局" if scope == "global" else "本地"
|
||||||
success = plugin_manage_api.add_plugin_directory(dir_path)
|
action_label = "启用" if action == "enable" else "禁用"
|
||||||
await asyncio.sleep(0.5) # 防止乱序发送
|
status = "成功" if ok else "失败"
|
||||||
if success:
|
await self.ctx.send.text(f"{scope_label}{action_label}组件{status}: {comp_name}", stream_id)
|
||||||
await self._send_message(f"插件目录添加成功: {dir_path}")
|
|
||||||
else:
|
|
||||||
await self._send_message(f"插件目录添加失败: {dir_path}")
|
|
||||||
|
|
||||||
def _fetch_all_registered_components(self) -> List[ComponentInfo]:
|
# ------ helpers ------
|
||||||
all_plugin_info = component_manage_api.get_all_plugin_info()
|
|
||||||
if not all_plugin_info:
|
@staticmethod
|
||||||
|
def _extract_components(result) -> list[dict]:
|
||||||
|
"""从 get_all_plugins 结果中提取所有组件列表"""
|
||||||
|
if not result:
|
||||||
return []
|
return []
|
||||||
|
if isinstance(result, dict):
|
||||||
components_info: List[ComponentInfo] = []
|
components = []
|
||||||
for plugin_info in all_plugin_info.values():
|
for plugin_info in result.values():
|
||||||
components_info.extend(plugin_info.components)
|
if isinstance(plugin_info, dict):
|
||||||
return components_info
|
components.extend(plugin_info.get("components", []))
|
||||||
|
return components
|
||||||
def _fetch_locally_disabled_components(self) -> List[str]:
|
return []
|
||||||
locally_disabled_components_actions = component_manage_api.get_locally_disabled_components(
|
|
||||||
self.message.chat_stream.stream_id, ComponentType.ACTION
|
|
||||||
)
|
|
||||||
locally_disabled_components_commands = component_manage_api.get_locally_disabled_components(
|
|
||||||
self.message.chat_stream.stream_id, ComponentType.COMMAND
|
|
||||||
)
|
|
||||||
locally_disabled_components_event_handlers = component_manage_api.get_locally_disabled_components(
|
|
||||||
self.message.chat_stream.stream_id, ComponentType.EVENT_HANDLER
|
|
||||||
)
|
|
||||||
return (
|
|
||||||
locally_disabled_components_actions
|
|
||||||
+ locally_disabled_components_commands
|
|
||||||
+ locally_disabled_components_event_handlers
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _list_all_registered_components(self):
|
|
||||||
components_info = self._fetch_all_registered_components()
|
|
||||||
if not components_info:
|
|
||||||
await self._send_message("没有注册的组件")
|
|
||||||
return
|
|
||||||
|
|
||||||
all_components_str = ", ".join(
|
|
||||||
f"{component.name} ({component.component_type})" for component in components_info
|
|
||||||
)
|
|
||||||
await self._send_message(f"已注册的组件: {all_components_str}")
|
|
||||||
|
|
||||||
async def _list_enabled_components(self, target_type: str = "global"):
|
|
||||||
components_info = self._fetch_all_registered_components()
|
|
||||||
if not components_info:
|
|
||||||
await self._send_message("没有注册的组件")
|
|
||||||
return
|
|
||||||
|
|
||||||
if target_type == "global":
|
|
||||||
enabled_components = [component for component in components_info if component.enabled]
|
|
||||||
if not enabled_components:
|
|
||||||
await self._send_message("没有满足条件的已启用全局组件")
|
|
||||||
return
|
|
||||||
enabled_components_str = ", ".join(
|
|
||||||
f"{component.name} ({component.component_type})" for component in enabled_components
|
|
||||||
)
|
|
||||||
await self._send_message(f"满足条件的已启用全局组件: {enabled_components_str}")
|
|
||||||
elif target_type == "local":
|
|
||||||
locally_disabled_components = self._fetch_locally_disabled_components()
|
|
||||||
enabled_components = [
|
|
||||||
component
|
|
||||||
for component in components_info
|
|
||||||
if (component.name not in locally_disabled_components and component.enabled)
|
|
||||||
]
|
|
||||||
if not enabled_components:
|
|
||||||
await self._send_message("本聊天没有满足条件的已启用组件")
|
|
||||||
return
|
|
||||||
enabled_components_str = ", ".join(
|
|
||||||
f"{component.name} ({component.component_type})" for component in enabled_components
|
|
||||||
)
|
|
||||||
await self._send_message(f"本聊天满足条件的已启用组件: {enabled_components_str}")
|
|
||||||
|
|
||||||
async def _list_disabled_components(self, target_type: str = "global"):
|
|
||||||
components_info = self._fetch_all_registered_components()
|
|
||||||
if not components_info:
|
|
||||||
await self._send_message("没有注册的组件")
|
|
||||||
return
|
|
||||||
|
|
||||||
if target_type == "global":
|
|
||||||
disabled_components = [component for component in components_info if not component.enabled]
|
|
||||||
if not disabled_components:
|
|
||||||
await self._send_message("没有满足条件的已禁用全局组件")
|
|
||||||
return
|
|
||||||
disabled_components_str = ", ".join(
|
|
||||||
f"{component.name} ({component.component_type})" for component in disabled_components
|
|
||||||
)
|
|
||||||
await self._send_message(f"满足条件的已禁用全局组件: {disabled_components_str}")
|
|
||||||
elif target_type == "local":
|
|
||||||
locally_disabled_components = self._fetch_locally_disabled_components()
|
|
||||||
disabled_components = [
|
|
||||||
component
|
|
||||||
for component in components_info
|
|
||||||
if (component.name in locally_disabled_components or not component.enabled)
|
|
||||||
]
|
|
||||||
if not disabled_components:
|
|
||||||
await self._send_message("本聊天没有满足条件的已禁用组件")
|
|
||||||
return
|
|
||||||
disabled_components_str = ", ".join(
|
|
||||||
f"{component.name} ({component.component_type})" for component in disabled_components
|
|
||||||
)
|
|
||||||
await self._send_message(f"本聊天满足条件的已禁用组件: {disabled_components_str}")
|
|
||||||
|
|
||||||
async def _list_registered_components_by_type(self, target_type: str):
|
|
||||||
match target_type:
|
|
||||||
case "action":
|
|
||||||
component_type = ComponentType.ACTION
|
|
||||||
case "command":
|
|
||||||
component_type = ComponentType.COMMAND
|
|
||||||
case "event_handler":
|
|
||||||
component_type = ComponentType.EVENT_HANDLER
|
|
||||||
case _:
|
|
||||||
await self._send_message(f"未知组件类型: {target_type}")
|
|
||||||
return
|
|
||||||
|
|
||||||
components_info = component_manage_api.get_components_info_by_type(component_type)
|
|
||||||
if not components_info:
|
|
||||||
await self._send_message(f"没有注册的 {target_type} 组件")
|
|
||||||
return
|
|
||||||
|
|
||||||
components_str = ", ".join(
|
|
||||||
f"{name} ({component.component_type})" for name, component in components_info.items()
|
|
||||||
)
|
|
||||||
await self._send_message(f"注册的 {target_type} 组件: {components_str}")
|
|
||||||
|
|
||||||
async def _globally_enable_component(self, component_name: str, component_type: str):
|
|
||||||
match component_type:
|
|
||||||
case "action":
|
|
||||||
target_component_type = ComponentType.ACTION
|
|
||||||
case "command":
|
|
||||||
target_component_type = ComponentType.COMMAND
|
|
||||||
case "event_handler":
|
|
||||||
target_component_type = ComponentType.EVENT_HANDLER
|
|
||||||
case _:
|
|
||||||
await self._send_message(f"未知组件类型: {component_type}")
|
|
||||||
return
|
|
||||||
if component_manage_api.globally_enable_component(component_name, target_component_type):
|
|
||||||
await self._send_message(f"全局启用组件成功: {component_name}")
|
|
||||||
else:
|
|
||||||
await self._send_message(f"全局启用组件失败: {component_name}")
|
|
||||||
|
|
||||||
async def _globally_disable_component(self, component_name: str, component_type: str):
|
|
||||||
match component_type:
|
|
||||||
case "action":
|
|
||||||
target_component_type = ComponentType.ACTION
|
|
||||||
case "command":
|
|
||||||
target_component_type = ComponentType.COMMAND
|
|
||||||
case "event_handler":
|
|
||||||
target_component_type = ComponentType.EVENT_HANDLER
|
|
||||||
case _:
|
|
||||||
await self._send_message(f"未知组件类型: {component_type}")
|
|
||||||
return
|
|
||||||
success = await component_manage_api.globally_disable_component(component_name, target_component_type)
|
|
||||||
if success:
|
|
||||||
await self._send_message(f"全局禁用组件成功: {component_name}")
|
|
||||||
else:
|
|
||||||
await self._send_message(f"全局禁用组件失败: {component_name}")
|
|
||||||
|
|
||||||
async def _locally_enable_component(self, component_name: str, component_type: str):
|
|
||||||
match component_type:
|
|
||||||
case "action":
|
|
||||||
target_component_type = ComponentType.ACTION
|
|
||||||
case "command":
|
|
||||||
target_component_type = ComponentType.COMMAND
|
|
||||||
case "event_handler":
|
|
||||||
target_component_type = ComponentType.EVENT_HANDLER
|
|
||||||
case _:
|
|
||||||
await self._send_message(f"未知组件类型: {component_type}")
|
|
||||||
return
|
|
||||||
if component_manage_api.locally_enable_component(
|
|
||||||
component_name,
|
|
||||||
target_component_type,
|
|
||||||
self.message.chat_stream.stream_id,
|
|
||||||
):
|
|
||||||
await self._send_message(f"本地启用组件成功: {component_name}")
|
|
||||||
else:
|
|
||||||
await self._send_message(f"本地启用组件失败: {component_name}")
|
|
||||||
|
|
||||||
async def _locally_disable_component(self, component_name: str, component_type: str):
|
|
||||||
match component_type:
|
|
||||||
case "action":
|
|
||||||
target_component_type = ComponentType.ACTION
|
|
||||||
case "command":
|
|
||||||
target_component_type = ComponentType.COMMAND
|
|
||||||
case "event_handler":
|
|
||||||
target_component_type = ComponentType.EVENT_HANDLER
|
|
||||||
case _:
|
|
||||||
await self._send_message(f"未知组件类型: {component_type}")
|
|
||||||
return
|
|
||||||
if component_manage_api.locally_disable_component(
|
|
||||||
component_name,
|
|
||||||
target_component_type,
|
|
||||||
self.message.chat_stream.stream_id,
|
|
||||||
):
|
|
||||||
await self._send_message(f"本地禁用组件成功: {component_name}")
|
|
||||||
else:
|
|
||||||
await self._send_message(f"本地禁用组件失败: {component_name}")
|
|
||||||
|
|
||||||
async def _send_message(self, message: str):
|
|
||||||
await send_api.text_to_stream(message, self.stream_id, typing=False, storage_message=False)
|
|
||||||
|
|
||||||
|
|
||||||
@register_plugin
|
def create_plugin():
|
||||||
class PluginManagementPlugin(BasePlugin):
|
return PluginManagementPlugin()
|
||||||
plugin_name: str = "plugin_management_plugin"
|
|
||||||
enable_plugin: bool = False
|
|
||||||
dependencies: list[str] = []
|
|
||||||
python_dependencies: list[str] = []
|
|
||||||
config_file_name: str = "config.toml"
|
|
||||||
config_schema: dict = {
|
|
||||||
"plugin": {
|
|
||||||
"enabled": ConfigField(bool, default=False, description="是否启用插件"),
|
|
||||||
"config_version": ConfigField(type=str, default="1.1.0", description="配置文件版本"),
|
|
||||||
"permission": ConfigField(
|
|
||||||
list, default=[], description="有权限使用插件管理命令的用户列表,请填写字符串形式的用户ID"
|
|
||||||
),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_plugin_components(self) -> List[Tuple[CommandInfo, Type[BaseCommand]]]:
|
|
||||||
components = []
|
|
||||||
if self.get_config("plugin.enabled", True):
|
|
||||||
components.append((ManagementCommand.get_command_info(), ManagementCommand))
|
|
||||||
return components
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
{
|
{
|
||||||
"manifest_version": 1,
|
"manifest_version": 1,
|
||||||
"name": "文本转语音插件 (Text-to-Speech)",
|
"name": "文本转语音插件 (Text-to-Speech)",
|
||||||
"version": "0.1.0",
|
"version": "2.0.0",
|
||||||
"description": "将文本转换为语音进行播放的插件,支持多种语音模式和智能语音输出场景判断。",
|
"description": "将文本转换为语音进行播放的插件,支持多种语音模式和智能语音输出场景判断。",
|
||||||
"author": {
|
"author": {
|
||||||
"name": "MaiBot团队",
|
"name": "MaiBot团队",
|
||||||
@@ -10,7 +10,7 @@
|
|||||||
"license": "GPL-v3.0-or-later",
|
"license": "GPL-v3.0-or-later",
|
||||||
|
|
||||||
"host_application": {
|
"host_application": {
|
||||||
"min_version": "0.8.0"
|
"min_version": "1.0.0"
|
||||||
},
|
},
|
||||||
"homepage_url": "https://github.com/MaiM-with-u/maibot",
|
"homepage_url": "https://github.com/MaiM-with-u/maibot",
|
||||||
"repository_url": "https://github.com/MaiM-with-u/maibot",
|
"repository_url": "https://github.com/MaiM-with-u/maibot",
|
||||||
|
|||||||
@@ -1,145 +1,55 @@
|
|||||||
from src.plugin_system.apis.plugin_register_api import register_plugin
|
"""TTS 插件 — 新 SDK 版本
|
||||||
from src.plugin_system.base.base_plugin import BasePlugin
|
|
||||||
from src.plugin_system.base.component_types import ComponentInfo
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.plugin_system.base.base_action import BaseAction, ActionActivationType
|
|
||||||
from src.plugin_system.base.config_types import ConfigField
|
|
||||||
from typing import Tuple, List, Type
|
|
||||||
|
|
||||||
logger = get_logger("tts")
|
将文本转换为语音进行播放。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
from maibot_sdk import MaiBotPlugin, Action
|
||||||
|
from maibot_sdk.types import ActivationType
|
||||||
|
|
||||||
|
|
||||||
class TTSAction(BaseAction):
|
class TTSPlugin(MaiBotPlugin):
|
||||||
"""TTS语音转换动作处理类"""
|
"""文本转语音插件"""
|
||||||
|
|
||||||
# 激活设置
|
|
||||||
activation_type = ActionActivationType.KEYWORD
|
|
||||||
activation_keywords = ["语音", "tts", "播报", "读出来", "语音播放", "听", "朗读"]
|
|
||||||
keyword_case_sensitive = False
|
|
||||||
parallel_action = False
|
|
||||||
|
|
||||||
# 动作基本信息
|
|
||||||
action_name = "tts_action"
|
|
||||||
action_description = "将文本转换为语音进行播放,适用于需要语音输出的场景"
|
|
||||||
|
|
||||||
# 动作参数定义
|
|
||||||
action_parameters = {
|
|
||||||
"voice_text": "你想用语音表达的内容,这段内容将会以语音形式发出",
|
|
||||||
}
|
|
||||||
|
|
||||||
# 动作使用场景
|
|
||||||
action_require = [
|
|
||||||
"当需要发送语音信息时使用",
|
|
||||||
"当用户明确要求使用语音功能时使用",
|
|
||||||
"当表达内容更适合用语音而不是文字传达时使用",
|
|
||||||
"当用户想听到语音回答而非阅读文本时使用",
|
|
||||||
]
|
|
||||||
|
|
||||||
# 关联类型
|
|
||||||
associated_types = ["tts_text"]
|
|
||||||
|
|
||||||
async def execute(self) -> Tuple[bool, str]:
|
|
||||||
"""处理TTS文本转语音动作"""
|
|
||||||
logger.info(f"{self.log_prefix} 执行TTS动作: {self.reasoning}")
|
|
||||||
|
|
||||||
# 获取要转换的文本
|
|
||||||
text = self.action_data.get("voice_text")
|
|
||||||
|
|
||||||
|
@Action(
|
||||||
|
"tts_action",
|
||||||
|
description="将文本转换为语音进行播放,适用于需要语音输出的场景",
|
||||||
|
activation_type=ActivationType.KEYWORD,
|
||||||
|
activation_keywords=["语音", "tts", "播报", "读出来", "语音播放", "听", "朗读"],
|
||||||
|
parallel_action=False,
|
||||||
|
action_parameters={"voice_text": "你想用语音表达的内容,这段内容将会以语音形式发出"},
|
||||||
|
action_require=[
|
||||||
|
"当需要发送语音信息时使用",
|
||||||
|
"当用户明确要求使用语音功能时使用",
|
||||||
|
"当表达内容更适合用语音而不是文字传达时使用",
|
||||||
|
"当用户想听到语音回答而非阅读文本时使用",
|
||||||
|
],
|
||||||
|
associated_types=["tts_text"],
|
||||||
|
)
|
||||||
|
async def handle_tts_action(self, stream_id: str = "", action_data: dict = None, reasoning: str = "", **kwargs):
|
||||||
|
"""处理 TTS 文本转语音动作"""
|
||||||
|
action_data = action_data or {}
|
||||||
|
text = action_data.get("voice_text", "")
|
||||||
if not text:
|
if not text:
|
||||||
logger.error(f"{self.log_prefix} 执行TTS动作时未提供文本内容")
|
|
||||||
return False, "执行TTS动作失败:未提供文本内容"
|
return False, "执行TTS动作失败:未提供文本内容"
|
||||||
|
|
||||||
# 确保文本适合TTS使用
|
# 文本预处理
|
||||||
processed_text = self._process_text_for_tts(text)
|
processed_text = re.sub(r"([!?,.;:。!?,、;:])\1+", r"\1", text)
|
||||||
|
|
||||||
try:
|
|
||||||
# 发送TTS消息
|
|
||||||
await self.send_custom(message_type="tts_text", content=processed_text)
|
|
||||||
|
|
||||||
# 记录动作信息
|
|
||||||
await self.store_action_info(
|
|
||||||
action_build_into_prompt=True, action_prompt_display="已经发送了语音消息。", action_done=True
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"{self.log_prefix} TTS动作执行成功,文本长度: {len(processed_text)}")
|
|
||||||
return True, "TTS动作执行成功"
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"{self.log_prefix} 执行TTS动作时出错: {e}")
|
|
||||||
return False, f"执行TTS动作时出错: {e}"
|
|
||||||
|
|
||||||
def _process_text_for_tts(self, text: str) -> str:
|
|
||||||
"""
|
|
||||||
处理文本使其更适合TTS使用
|
|
||||||
- 移除不必要的特殊字符和表情符号
|
|
||||||
- 修正标点符号以提高语音质量
|
|
||||||
- 优化文本结构使语音更流畅
|
|
||||||
"""
|
|
||||||
# 这里可以添加文本处理逻辑
|
|
||||||
# 例如:移除多余的标点、表情符号,优化语句结构等
|
|
||||||
|
|
||||||
# 简单示例实现
|
|
||||||
processed_text = text
|
|
||||||
|
|
||||||
# 移除多余的标点符号
|
|
||||||
import re
|
|
||||||
|
|
||||||
processed_text = re.sub(r"([!?,.;:。!?,、;:])\1+", r"\1", processed_text)
|
|
||||||
|
|
||||||
# 确保句子结尾有合适的标点
|
|
||||||
if not any(processed_text.endswith(end) for end in [".", "?", "!", "。", "!", "?"]):
|
if not any(processed_text.endswith(end) for end in [".", "?", "!", "。", "!", "?"]):
|
||||||
processed_text = f"{processed_text}。"
|
processed_text = f"{processed_text}。"
|
||||||
|
|
||||||
return processed_text
|
# 发送自定义 tts 消息
|
||||||
|
result = await self.ctx.call_capability(
|
||||||
|
"send.custom",
|
||||||
|
message_type="tts_text",
|
||||||
|
content=processed_text,
|
||||||
|
stream_id=stream_id,
|
||||||
|
)
|
||||||
|
if result and result.get("success"):
|
||||||
|
return True, "TTS动作执行成功"
|
||||||
|
return False, f"TTS动作执行失败: {result}"
|
||||||
|
|
||||||
|
|
||||||
@register_plugin
|
def create_plugin():
|
||||||
class TTSPlugin(BasePlugin):
|
return TTSPlugin()
|
||||||
"""TTS插件
|
|
||||||
- 这是文字转语音插件
|
|
||||||
- Normal模式下依靠关键词触发
|
|
||||||
- Focus模式下由LLM判断触发
|
|
||||||
- 具有一定的文本预处理能力
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 插件基本信息
|
|
||||||
plugin_name: str = "tts_plugin" # 内部标识符
|
|
||||||
enable_plugin: bool = True
|
|
||||||
dependencies: list[str] = [] # 插件依赖列表
|
|
||||||
python_dependencies: list[str] = [] # Python包依赖列表
|
|
||||||
config_file_name: str = "config.toml"
|
|
||||||
|
|
||||||
# 配置节描述
|
|
||||||
config_section_descriptions = {
|
|
||||||
"plugin": "插件基本信息配置",
|
|
||||||
"components": "组件启用控制",
|
|
||||||
"logging": "日志记录相关配置",
|
|
||||||
}
|
|
||||||
|
|
||||||
# 配置Schema定义
|
|
||||||
config_schema: dict = {
|
|
||||||
"plugin": {
|
|
||||||
"name": ConfigField(type=str, default="tts_plugin", description="插件名称", required=True),
|
|
||||||
"version": ConfigField(type=str, default="0.1.0", description="插件版本号"),
|
|
||||||
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
|
|
||||||
"description": ConfigField(type=str, default="文字转语音插件", description="插件描述", required=True),
|
|
||||||
},
|
|
||||||
"components": {"enable_tts": ConfigField(type=bool, default=True, description="是否启用TTS Action")},
|
|
||||||
"logging": {
|
|
||||||
"level": ConfigField(
|
|
||||||
type=str, default="INFO", description="日志记录级别", choices=["DEBUG", "INFO", "WARNING", "ERROR"]
|
|
||||||
),
|
|
||||||
"prefix": ConfigField(type=str, default="[TTS]", description="日志记录前缀"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
|
||||||
"""返回插件包含的组件列表"""
|
|
||||||
|
|
||||||
# 从配置获取组件启用状态
|
|
||||||
enable_tts = self.get_config("components.enable_tts", True)
|
|
||||||
components = [] # 添加Action组件
|
|
||||||
if enable_tts:
|
|
||||||
components.append((TTSAction.get_action_info(), TTSAction))
|
|
||||||
|
|
||||||
return components
|
|
||||||
|
|||||||
7
src/services/__init__.py
Normal file
7
src/services/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
"""
|
||||||
|
核心服务层
|
||||||
|
|
||||||
|
提供与具体插件系统无关的核心业务服务。
|
||||||
|
内部模块(chat、dream、memory 等)应直接使用此层,
|
||||||
|
而 plugin_system.apis 仅作为面向插件的薄包装。
|
||||||
|
"""
|
||||||
159
src/services/chat_service.py
Normal file
159
src/services/chat_service.py
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
"""
|
||||||
|
聊天服务模块
|
||||||
|
|
||||||
|
提供聊天信息查询和管理的核心功能。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("chat_service")
|
||||||
|
|
||||||
|
|
||||||
|
class SpecialTypes(Enum):
|
||||||
|
"""特殊枚举类型"""
|
||||||
|
|
||||||
|
ALL_PLATFORMS = "all_platforms"
|
||||||
|
|
||||||
|
|
||||||
|
class ChatManager:
|
||||||
|
"""聊天管理器 - 负责聊天信息的查询和管理"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
|
||||||
|
# sourcery skip: for-append-to-extend
|
||||||
|
if not isinstance(platform, (str, SpecialTypes)):
|
||||||
|
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||||
|
streams = []
|
||||||
|
try:
|
||||||
|
for _, stream in _chat_manager.sessions.items():
|
||||||
|
if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform:
|
||||||
|
streams.append(stream)
|
||||||
|
logger.debug(f"[ChatService] 获取到 {len(streams)} 个 {platform} 平台的聊天流")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[ChatService] 获取聊天流失败: {e}")
|
||||||
|
return streams
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
|
||||||
|
# sourcery skip: for-append-to-extend
|
||||||
|
if not isinstance(platform, (str, SpecialTypes)):
|
||||||
|
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||||
|
streams = []
|
||||||
|
try:
|
||||||
|
for _, stream in _chat_manager.sessions.items():
|
||||||
|
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.is_group_session:
|
||||||
|
streams.append(stream)
|
||||||
|
logger.debug(f"[ChatService] 获取到 {len(streams)} 个 {platform} 平台的群聊流")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[ChatService] 获取群聊流失败: {e}")
|
||||||
|
return streams
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
|
||||||
|
# sourcery skip: for-append-to-extend
|
||||||
|
if not isinstance(platform, (str, SpecialTypes)):
|
||||||
|
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||||
|
streams = []
|
||||||
|
try:
|
||||||
|
for _, stream in _chat_manager.sessions.items():
|
||||||
|
if (
|
||||||
|
platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform
|
||||||
|
) and not stream.is_group_session:
|
||||||
|
streams.append(stream)
|
||||||
|
logger.debug(f"[ChatService] 获取到 {len(streams)} 个 {platform} 平台的私聊流")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[ChatService] 获取私聊流失败: {e}")
|
||||||
|
return streams
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_group_stream_by_group_id(
|
||||||
|
group_id: str, platform: Optional[str] | SpecialTypes = "qq"
|
||||||
|
) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast
|
||||||
|
if not isinstance(group_id, str):
|
||||||
|
raise TypeError("group_id 必须是字符串类型")
|
||||||
|
if not isinstance(platform, (str, SpecialTypes)):
|
||||||
|
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||||
|
if not group_id:
|
||||||
|
raise ValueError("group_id 不能为空")
|
||||||
|
try:
|
||||||
|
for _, stream in _chat_manager.sessions.items():
|
||||||
|
if (
|
||||||
|
stream.is_group_session
|
||||||
|
and str(stream.group_id) == str(group_id)
|
||||||
|
and (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform)
|
||||||
|
):
|
||||||
|
logger.debug(f"[ChatService] 找到群ID {group_id} 的聊天流")
|
||||||
|
return stream
|
||||||
|
logger.warning(f"[ChatService] 未找到群ID {group_id} 的聊天流")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[ChatService] 查找群聊流失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_private_stream_by_user_id(
|
||||||
|
user_id: str, platform: Optional[str] | SpecialTypes = "qq"
|
||||||
|
) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast
|
||||||
|
if not isinstance(user_id, str):
|
||||||
|
raise TypeError("user_id 必须是字符串类型")
|
||||||
|
if not isinstance(platform, (str, SpecialTypes)):
|
||||||
|
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||||
|
if not user_id:
|
||||||
|
raise ValueError("user_id 不能为空")
|
||||||
|
try:
|
||||||
|
for _, stream in _chat_manager.sessions.items():
|
||||||
|
if (
|
||||||
|
not stream.is_group_session
|
||||||
|
and str(stream.user_id) == str(user_id)
|
||||||
|
and (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform)
|
||||||
|
):
|
||||||
|
logger.debug(f"[ChatService] 找到用户ID {user_id} 的私聊流")
|
||||||
|
return stream
|
||||||
|
logger.warning(f"[ChatService] 未找到用户ID {user_id} 的私聊流")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[ChatService] 查找私聊流失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_stream_type(chat_stream: BotChatSession) -> str:
|
||||||
|
if not isinstance(chat_stream, BotChatSession):
|
||||||
|
raise TypeError("chat_stream 必须是 BotChatSession 类型")
|
||||||
|
if not chat_stream:
|
||||||
|
raise ValueError("chat_stream 不能为 None")
|
||||||
|
|
||||||
|
return "group" if chat_stream.is_group_session else "private"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_stream_info(chat_stream: BotChatSession) -> Dict[str, Any]:
|
||||||
|
if not chat_stream:
|
||||||
|
raise ValueError("chat_stream 不能为 None")
|
||||||
|
if not isinstance(chat_stream, BotChatSession):
|
||||||
|
raise TypeError("chat_stream 必须是 BotChatSession 类型")
|
||||||
|
|
||||||
|
try:
|
||||||
|
info: Dict[str, Any] = {
|
||||||
|
"session_id": chat_stream.session_id,
|
||||||
|
"platform": chat_stream.platform,
|
||||||
|
"type": ChatManager.get_stream_type(chat_stream),
|
||||||
|
}
|
||||||
|
|
||||||
|
if chat_stream.is_group_session:
|
||||||
|
info["group_id"] = chat_stream.group_id
|
||||||
|
if chat_stream.context and chat_stream.context.message and chat_stream.context.message.message_info.group_info:
|
||||||
|
info["group_name"] = chat_stream.context.message.message_info.group_info.group_name or "未知群聊"
|
||||||
|
else:
|
||||||
|
info["group_name"] = "未知群聊"
|
||||||
|
else:
|
||||||
|
info["user_id"] = chat_stream.user_id
|
||||||
|
if chat_stream.context and chat_stream.context.message and chat_stream.context.message.message_info.user_info:
|
||||||
|
info["user_name"] = chat_stream.context.message.message_info.user_info.user_nickname
|
||||||
|
else:
|
||||||
|
info["user_name"] = "未知用户"
|
||||||
|
|
||||||
|
return info
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[ChatService] 获取聊天流信息失败: {e}")
|
||||||
|
return {}
|
||||||
@@ -1,28 +1,19 @@
|
|||||||
"""配置API模块
|
"""配置服务模块
|
||||||
|
|
||||||
提供了配置读取和用户信息获取等功能
|
提供配置读取的核心功能。
|
||||||
使用方式:
|
|
||||||
from src.plugin_system.apis import config_api
|
|
||||||
value = config_api.get_global_config("section.key")
|
|
||||||
platform, user_id = await config_api.get_user_id_by_person_name("用户名")
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
logger = get_logger("config_api")
|
logger = get_logger("config_service")
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# 配置访问API函数
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
def get_global_config(key: str, default: Any = None) -> Any:
|
def get_global_config(key: str, default: Any = None) -> Any:
|
||||||
"""
|
"""
|
||||||
安全地从全局配置中获取一个值。
|
安全地从全局配置中获取一个值。
|
||||||
插件应使用此方法读取全局配置,以保证只读和隔离性。
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
key: 命名空间式配置键名,使用嵌套访问,如 "section.subsection.key",大小写敏感
|
key: 命名空间式配置键名,使用嵌套访问,如 "section.subsection.key",大小写敏感
|
||||||
@@ -31,7 +22,6 @@ def get_global_config(key: str, default: Any = None) -> Any:
|
|||||||
Returns:
|
Returns:
|
||||||
Any: 配置值或默认值
|
Any: 配置值或默认值
|
||||||
"""
|
"""
|
||||||
# 支持嵌套键访问
|
|
||||||
keys = key.split(".")
|
keys = key.split(".")
|
||||||
current = global_config
|
current = global_config
|
||||||
|
|
||||||
@@ -43,7 +33,7 @@ def get_global_config(key: str, default: Any = None) -> Any:
|
|||||||
raise KeyError(f"配置中不存在子空间或键 '{k}'")
|
raise KeyError(f"配置中不存在子空间或键 '{k}'")
|
||||||
return current
|
return current
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"[ConfigAPI] 获取全局配置 {key} 失败: {e}")
|
logger.warning(f"[ConfigService] 获取全局配置 {key} 失败: {e}")
|
||||||
return default
|
return default
|
||||||
|
|
||||||
|
|
||||||
@@ -59,7 +49,6 @@ def get_plugin_config(plugin_config: dict, key: str, default: Any = None) -> Any
|
|||||||
Returns:
|
Returns:
|
||||||
Any: 配置值或默认值
|
Any: 配置值或默认值
|
||||||
"""
|
"""
|
||||||
# 支持嵌套键访问
|
|
||||||
keys = key.split(".")
|
keys = key.split(".")
|
||||||
current = plugin_config
|
current = plugin_config
|
||||||
|
|
||||||
@@ -73,5 +62,5 @@ def get_plugin_config(plugin_config: dict, key: str, default: Any = None) -> Any
|
|||||||
raise KeyError(f"配置中不存在子空间或键 '{k}'")
|
raise KeyError(f"配置中不存在子空间或键 '{k}'")
|
||||||
return current
|
return current
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"[ConfigAPI] 获取插件配置 {key} 失败: {e}")
|
logger.warning(f"[ConfigService] 获取插件配置 {key} 失败: {e}")
|
||||||
return default
|
return default
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
"""数据库API模块
|
"""数据库服务模块
|
||||||
|
|
||||||
提供数据库操作相关功能,统一使用 SQLModel/SQLAlchemy 兼容接口。
|
提供数据库操作相关的核心功能。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
@@ -10,7 +10,7 @@ from typing import Any, Optional
|
|||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
logger = get_logger("database_api")
|
logger = get_logger("database_service")
|
||||||
|
|
||||||
|
|
||||||
def _to_dict(record: Any) -> dict[str, Any]:
|
def _to_dict(record: Any) -> dict[str, Any]:
|
||||||
@@ -73,7 +73,7 @@ async def db_query(
|
|||||||
|
|
||||||
return query.count()
|
return query.count()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[DatabaseAPI] 数据库操作出错: {e}")
|
logger.error(f"[DatabaseService] 数据库操作出错: {e}")
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
if query_type == "get":
|
if query_type == "get":
|
||||||
return None if single_result else []
|
return None if single_result else []
|
||||||
@@ -93,7 +93,7 @@ async def db_save(model_class, data: dict[str, Any], key_field: Optional[str] =
|
|||||||
new_record = model_class.create(**data)
|
new_record = model_class.create(**data)
|
||||||
return _to_dict(new_record)
|
return _to_dict(new_record)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[DatabaseAPI] 保存数据库记录出错: {e}")
|
logger.error(f"[DatabaseService] 保存数据库记录出错: {e}")
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -119,7 +119,7 @@ async def db_get(
|
|||||||
return results[0] if results else None
|
return results[0] if results else None
|
||||||
return results
|
return results
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[DatabaseAPI] 获取数据库记录出错: {e}")
|
logger.error(f"[DatabaseService] 获取数据库记录出错: {e}")
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return None if single_result else []
|
return None if single_result else []
|
||||||
|
|
||||||
@@ -163,11 +163,11 @@ async def store_action_info(
|
|||||||
ActionRecords, data=record_data, key_field="action_id", key_value=record_data["action_id"]
|
ActionRecords, data=record_data, key_field="action_id", key_value=record_data["action_id"]
|
||||||
)
|
)
|
||||||
if saved_record:
|
if saved_record:
|
||||||
logger.debug(f"[DatabaseAPI] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})")
|
logger.debug(f"[DatabaseService] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})")
|
||||||
else:
|
else:
|
||||||
logger.error(f"[DatabaseAPI] 存储动作信息失败: {action_name}")
|
logger.error(f"[DatabaseService] 存储动作信息失败: {action_name}")
|
||||||
return saved_record
|
return saved_record
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[DatabaseAPI] 存储动作信息时发生错误: {e}")
|
logger.error(f"[DatabaseService] 存储动作信息时发生错误: {e}")
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return None
|
return None
|
||||||
406
src/services/emoji_service.py
Normal file
406
src/services/emoji_service.py
Normal file
@@ -0,0 +1,406 @@
|
|||||||
|
"""
|
||||||
|
表情服务模块
|
||||||
|
|
||||||
|
提供表情包相关的核心功能。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
from src.chat.emoji_system.emoji_manager import emoji_manager, EMOJI_DIR
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.common.utils.utils_image import ImageUtils
|
||||||
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
logger = get_logger("emoji_service")
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# 表情包获取函数
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]:
|
||||||
|
"""根据描述选择表情包"""
|
||||||
|
if not description:
|
||||||
|
raise ValueError("描述不能为空")
|
||||||
|
if not isinstance(description, str):
|
||||||
|
raise TypeError("描述必须是字符串类型")
|
||||||
|
try:
|
||||||
|
logger.debug(f"[EmojiService] 根据描述获取表情包: {description}")
|
||||||
|
|
||||||
|
emoji_obj = await emoji_manager.get_emoji_for_emotion(description)
|
||||||
|
|
||||||
|
if not emoji_obj:
|
||||||
|
logger.warning(f"[EmojiService] 未找到匹配描述 '{description}' 的表情包")
|
||||||
|
return None
|
||||||
|
|
||||||
|
emoji_path = str(emoji_obj.full_path)
|
||||||
|
emoji_description = emoji_obj.description
|
||||||
|
matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else ""
|
||||||
|
emoji_base64 = ImageUtils.image_path_to_base64(emoji_path)
|
||||||
|
|
||||||
|
if not emoji_base64:
|
||||||
|
logger.error(f"[EmojiService] 无法将表情包文件转换为base64: {emoji_path}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
logger.debug(f"[EmojiService] 成功获取表情包: {emoji_description}, 匹配情感: {matched_emotion}")
|
||||||
|
return emoji_base64, emoji_description, matched_emotion
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[EmojiService] 获取表情包失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
|
||||||
|
"""随机获取指定数量的表情包"""
|
||||||
|
if not isinstance(count, int):
|
||||||
|
raise TypeError("count 必须是整数类型")
|
||||||
|
if count < 0:
|
||||||
|
raise ValueError("count 不能为负数")
|
||||||
|
if count == 0:
|
||||||
|
logger.warning("[EmojiService] count 为0,返回空列表")
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
all_emojis = emoji_manager.emojis
|
||||||
|
|
||||||
|
if not all_emojis:
|
||||||
|
logger.warning("[EmojiService] 没有可用的表情包")
|
||||||
|
return []
|
||||||
|
|
||||||
|
valid_emojis = [emoji for emoji in all_emojis if not emoji.is_deleted]
|
||||||
|
if not valid_emojis:
|
||||||
|
logger.warning("[EmojiService] 没有有效的表情包")
|
||||||
|
return []
|
||||||
|
|
||||||
|
if len(valid_emojis) < count:
|
||||||
|
logger.debug(
|
||||||
|
f"[EmojiService] 有效表情包数量 ({len(valid_emojis)}) 少于请求的数量 ({count}),将返回所有有效表情包"
|
||||||
|
)
|
||||||
|
count = len(valid_emojis)
|
||||||
|
|
||||||
|
selected_emojis = random.sample(valid_emojis, count)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for selected_emoji in selected_emojis:
|
||||||
|
emoji_base64 = ImageUtils.image_path_to_base64(str(selected_emoji.full_path))
|
||||||
|
|
||||||
|
if not emoji_base64:
|
||||||
|
logger.error(f"[EmojiService] 无法转换表情包为base64: {selected_emoji.full_path}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
matched_emotion = random.choice(selected_emoji.emotion) if selected_emoji.emotion else "随机表情"
|
||||||
|
|
||||||
|
emoji_manager.update_emoji_usage(selected_emoji)
|
||||||
|
results.append((emoji_base64, selected_emoji.description, matched_emotion))
|
||||||
|
|
||||||
|
if not results and count > 0:
|
||||||
|
logger.warning("[EmojiService] 随机获取表情包失败,没有一个可以成功处理")
|
||||||
|
return []
|
||||||
|
|
||||||
|
logger.debug(f"[EmojiService] 成功获取 {len(results)} 个随机表情包")
|
||||||
|
return results
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[EmojiService] 获取随机表情包失败: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
|
||||||
|
"""根据情感标签获取表情包"""
|
||||||
|
if not emotion:
|
||||||
|
raise ValueError("情感标签不能为空")
|
||||||
|
if not isinstance(emotion, str):
|
||||||
|
raise TypeError("情感标签必须是字符串类型")
|
||||||
|
try:
|
||||||
|
logger.info(f"[EmojiService] 根据情感获取表情包: {emotion}")
|
||||||
|
|
||||||
|
all_emojis = emoji_manager.emojis
|
||||||
|
|
||||||
|
matching_emojis = []
|
||||||
|
matching_emojis.extend(
|
||||||
|
emoji_obj
|
||||||
|
for emoji_obj in all_emojis
|
||||||
|
if not emoji_obj.is_deleted and emotion.lower() in [e.lower() for e in emoji_obj.emotion]
|
||||||
|
)
|
||||||
|
if not matching_emojis:
|
||||||
|
logger.warning(f"[EmojiService] 未找到匹配情感 '{emotion}' 的表情包")
|
||||||
|
return None
|
||||||
|
|
||||||
|
selected_emoji = random.choice(matching_emojis)
|
||||||
|
emoji_base64 = ImageUtils.image_path_to_base64(selected_emoji.full_path)
|
||||||
|
|
||||||
|
if not emoji_base64:
|
||||||
|
logger.error(f"[EmojiService] 无法转换表情包为base64: {selected_emoji.full_path}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
emoji_manager.update_emoji_usage(selected_emoji)
|
||||||
|
|
||||||
|
logger.info(f"[EmojiService] 成功获取情感表情包: {selected_emoji.description}")
|
||||||
|
return emoji_base64, selected_emoji.description, emotion
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[EmojiService] 根据情感获取表情包失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# 表情包信息查询函数
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def get_count() -> int:
|
||||||
|
try:
|
||||||
|
return len(emoji_manager.emojis)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[EmojiService] 获取表情包数量失败: {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def get_info():
|
||||||
|
try:
|
||||||
|
return {
|
||||||
|
"current_count": len(emoji_manager.emojis),
|
||||||
|
"max_count": global_config.emoji.max_reg_num,
|
||||||
|
"available_emojis": len([e for e in emoji_manager.emojis if not e.is_deleted]),
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[EmojiService] 获取表情包信息失败: {e}")
|
||||||
|
return {"current_count": 0, "max_count": 0, "available_emojis": 0}
|
||||||
|
|
||||||
|
|
||||||
|
def get_emotions() -> List[str]:
|
||||||
|
try:
|
||||||
|
emotions = set()
|
||||||
|
|
||||||
|
for emoji_obj in emoji_manager.emojis:
|
||||||
|
if not emoji_obj.is_deleted and emoji_obj.emotion:
|
||||||
|
emotions.update(emoji_obj.emotion)
|
||||||
|
|
||||||
|
return sorted(list(emotions))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[EmojiService] 获取情感标签失败: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
async def get_all() -> List[Tuple[str, str, str]]:
|
||||||
|
try:
|
||||||
|
all_emojis = emoji_manager.emojis
|
||||||
|
|
||||||
|
if not all_emojis:
|
||||||
|
logger.warning("[EmojiService] 没有可用的表情包")
|
||||||
|
return []
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for emoji_obj in all_emojis:
|
||||||
|
if emoji_obj.is_deleted:
|
||||||
|
continue
|
||||||
|
|
||||||
|
emoji_base64 = ImageUtils.image_path_to_base64(str(emoji_obj.full_path))
|
||||||
|
|
||||||
|
if not emoji_base64:
|
||||||
|
logger.error(f"[EmojiService] 无法转换表情包为base64: {emoji_obj.full_path}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else "随机表情"
|
||||||
|
results.append((emoji_base64, emoji_obj.description, matched_emotion))
|
||||||
|
|
||||||
|
logger.debug(f"[EmojiService] 成功获取 {len(results)} 个表情包")
|
||||||
|
return results
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[EmojiService] 获取所有表情包失败: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def get_descriptions() -> List[str]:
|
||||||
|
try:
|
||||||
|
descriptions = []
|
||||||
|
|
||||||
|
descriptions.extend(
|
||||||
|
emoji_obj.description
|
||||||
|
for emoji_obj in emoji_manager.emojis
|
||||||
|
if not emoji_obj.is_deleted and emoji_obj.description
|
||||||
|
)
|
||||||
|
return descriptions
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[EmojiService] 获取表情包描述失败: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# 表情包注册函数
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
async def register_emoji(image_base64: str, filename: Optional[str] = None) -> Dict[str, Any]:
|
||||||
|
"""注册新的表情包"""
|
||||||
|
if not image_base64:
|
||||||
|
raise ValueError("图片base64编码不能为空")
|
||||||
|
if not isinstance(image_base64, str):
|
||||||
|
raise TypeError("image_base64必须是字符串类型")
|
||||||
|
if filename is not None and not isinstance(filename, str):
|
||||||
|
raise TypeError("filename必须是字符串类型或None")
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info(f"[EmojiService] 开始注册表情包,文件名: {filename or '自动生成'}")
|
||||||
|
|
||||||
|
count_before = len(emoji_manager.emojis)
|
||||||
|
max_count = global_config.emoji.max_reg_num
|
||||||
|
|
||||||
|
can_register = count_before < max_count or (count_before >= max_count and global_config.emoji.do_replace)
|
||||||
|
|
||||||
|
if not can_register:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": f"表情包数量已达上限({count_before}/{max_count})且未启用替换功能",
|
||||||
|
"description": None,
|
||||||
|
"emotions": None,
|
||||||
|
"replaced": None,
|
||||||
|
"hash": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
os.makedirs(EMOJI_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
if not filename:
|
||||||
|
import time as _time
|
||||||
|
|
||||||
|
timestamp = int(_time.time())
|
||||||
|
microseconds = int(_time.time() * 1000000) % 1000000
|
||||||
|
|
||||||
|
random_bytes = random.getrandbits(72).to_bytes(9, "big")
|
||||||
|
short_id = base64.b64encode(random_bytes).decode("ascii")[:12].rstrip("=")
|
||||||
|
short_id = short_id.replace("/", "_").replace("+", "-")
|
||||||
|
filename = f"emoji_{timestamp}_{microseconds}_{short_id}"
|
||||||
|
|
||||||
|
if not filename.lower().endswith((".jpg", ".jpeg", ".png", ".gif")):
|
||||||
|
filename = f"{filename}.png"
|
||||||
|
|
||||||
|
temp_file_path = os.path.join(EMOJI_DIR, filename)
|
||||||
|
attempts = 0
|
||||||
|
max_attempts = 10
|
||||||
|
while os.path.exists(temp_file_path) and attempts < max_attempts:
|
||||||
|
random_bytes = random.getrandbits(48).to_bytes(6, "big")
|
||||||
|
short_id = base64.b64encode(random_bytes).decode("ascii")[:8].rstrip("=")
|
||||||
|
short_id = short_id.replace("/", "_").replace("+", "-")
|
||||||
|
|
||||||
|
name_part, ext = os.path.splitext(filename)
|
||||||
|
base_name = name_part.rsplit("_", 1)[0]
|
||||||
|
filename = f"{base_name}_{short_id}{ext}"
|
||||||
|
temp_file_path = os.path.join(EMOJI_DIR, filename)
|
||||||
|
attempts += 1
|
||||||
|
|
||||||
|
if os.path.exists(temp_file_path):
|
||||||
|
uuid_short = str(uuid.uuid4())[:8]
|
||||||
|
name_part, ext = os.path.splitext(filename)
|
||||||
|
base_name = name_part.rsplit("_", 1)[0]
|
||||||
|
filename = f"{base_name}_{uuid_short}{ext}"
|
||||||
|
temp_file_path = os.path.join(EMOJI_DIR, filename)
|
||||||
|
|
||||||
|
counter = 1
|
||||||
|
original_filename = filename
|
||||||
|
while os.path.exists(temp_file_path):
|
||||||
|
name_part, ext = os.path.splitext(original_filename)
|
||||||
|
filename = f"{name_part}_{counter}{ext}"
|
||||||
|
temp_file_path = os.path.join(EMOJI_DIR, filename)
|
||||||
|
counter += 1
|
||||||
|
|
||||||
|
if counter > 100:
|
||||||
|
logger.error(f"[EmojiService] 无法生成唯一文件名,尝试次数过多: {original_filename}")
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": "无法生成唯一文件名,请稍后重试",
|
||||||
|
"description": None,
|
||||||
|
"emotions": None,
|
||||||
|
"replaced": None,
|
||||||
|
"hash": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not ImageUtils.base64_to_image(image_base64, temp_file_path):
|
||||||
|
logger.error(f"[EmojiService] 无法保存base64图片到文件: {temp_file_path}")
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": "无法保存图片文件",
|
||||||
|
"description": None,
|
||||||
|
"emotions": None,
|
||||||
|
"replaced": None,
|
||||||
|
"hash": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.debug(f"[EmojiService] 图片已保存到临时文件: {temp_file_path}")
|
||||||
|
|
||||||
|
except Exception as save_error:
|
||||||
|
logger.error(f"[EmojiService] 保存图片文件失败: {save_error}")
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": f"保存图片文件失败: {str(save_error)}",
|
||||||
|
"description": None,
|
||||||
|
"emotions": None,
|
||||||
|
"replaced": None,
|
||||||
|
"hash": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
register_success = await emoji_manager.register_emoji_by_filename(filename)
|
||||||
|
|
||||||
|
if not register_success and os.path.exists(temp_file_path):
|
||||||
|
try:
|
||||||
|
os.remove(temp_file_path)
|
||||||
|
logger.debug(f"[EmojiService] 已清理临时文件: {temp_file_path}")
|
||||||
|
except Exception as cleanup_error:
|
||||||
|
logger.warning(f"[EmojiService] 清理临时文件失败: {cleanup_error}")
|
||||||
|
|
||||||
|
if register_success:
|
||||||
|
count_after = len(emoji_manager.emojis)
|
||||||
|
replaced = count_after <= count_before
|
||||||
|
|
||||||
|
new_emoji_info = None
|
||||||
|
if count_after > count_before or replaced:
|
||||||
|
try:
|
||||||
|
for emoji_obj in reversed(emoji_manager.emojis):
|
||||||
|
if not emoji_obj.is_deleted and (
|
||||||
|
emoji_obj.file_name == filename
|
||||||
|
or (hasattr(emoji_obj, "full_path") and filename in str(emoji_obj.full_path))
|
||||||
|
):
|
||||||
|
new_emoji_info = emoji_obj
|
||||||
|
break
|
||||||
|
except Exception as find_error:
|
||||||
|
logger.warning(f"[EmojiService] 查找新注册表情包信息失败: {find_error}")
|
||||||
|
|
||||||
|
description = new_emoji_info.description if new_emoji_info else None
|
||||||
|
emotions = new_emoji_info.emotion if new_emoji_info else None
|
||||||
|
emoji_hash = new_emoji_info.emoji_hash if new_emoji_info else None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": f"表情包注册成功 {'(替换旧表情包)' if replaced else '(新增表情包)'}",
|
||||||
|
"description": description,
|
||||||
|
"emotions": emotions,
|
||||||
|
"replaced": replaced,
|
||||||
|
"hash": emoji_hash,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": "表情包注册失败,可能因为重复、格式不支持或审核未通过",
|
||||||
|
"description": None,
|
||||||
|
"emotions": None,
|
||||||
|
"replaced": None,
|
||||||
|
"hash": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[EmojiService] 注册表情包时发生异常: {e}")
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": f"注册过程中发生错误: {str(e)}",
|
||||||
|
"description": None,
|
||||||
|
"emotions": None,
|
||||||
|
"replaced": None,
|
||||||
|
"hash": None,
|
||||||
|
}
|
||||||
@@ -1,9 +1,11 @@
|
|||||||
from src.common.logger import get_logger
|
"""频率控制服务模块
|
||||||
|
|
||||||
|
提供聊天频率控制的核心功能。
|
||||||
|
"""
|
||||||
|
|
||||||
from src.chat.heart_flow.frequency_control import frequency_control_manager
|
from src.chat.heart_flow.frequency_control import frequency_control_manager
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
logger = get_logger("frequency_api")
|
|
||||||
|
|
||||||
|
|
||||||
def get_current_talk_value(chat_id: str) -> float:
|
def get_current_talk_value(chat_id: str) -> float:
|
||||||
return frequency_control_manager.get_or_create_frequency_control(
|
return frequency_control_manager.get_or_create_frequency_control(
|
||||||
@@ -1,39 +1,37 @@
|
|||||||
"""
|
"""
|
||||||
回复器API模块
|
回复器服务模块
|
||||||
|
|
||||||
提供回复器相关功能,采用标准Python包设计模式
|
提供回复器相关的核心功能。
|
||||||
使用方式:
|
|
||||||
from src.plugin_system.apis import generator_api
|
|
||||||
replyer = generator_api.get_replyer(chat_stream)
|
|
||||||
success, reply_set, _ = await generator_api.generate_reply(chat_stream, action_data, reasoning)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import traceback
|
import traceback
|
||||||
import time
|
import time
|
||||||
from typing import Tuple, Any, Dict, List, Optional, TYPE_CHECKING
|
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||||
|
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.common.data_models.message_data_model import ReplySetModel
|
from src.chat.logger.plan_reply_logger import PlanReplyLogger
|
||||||
|
from src.chat.message_receive.chat_manager import BotChatSession
|
||||||
from src.chat.replyer.group_generator import DefaultReplyer
|
from src.chat.replyer.group_generator import DefaultReplyer
|
||||||
from src.chat.replyer.private_generator import PrivateReplyer
|
from src.chat.replyer.private_generator import PrivateReplyer
|
||||||
from src.chat.message_receive.chat_manager import BotChatSession
|
|
||||||
from src.chat.utils.utils import process_llm_response
|
|
||||||
from src.chat.replyer.replyer_manager import replyer_manager
|
from src.chat.replyer.replyer_manager import replyer_manager
|
||||||
from src.plugin_system.base.component_types import ActionInfo
|
from src.chat.utils.utils import process_llm_response
|
||||||
from src.chat.logger.plan_reply_logger import PlanReplyLogger
|
from src.common.data_models.message_data_model import ReplySetModel
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.core.types import ActionInfo
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||||
from src.common.data_models.llm_data_model import LLMGenerationDataModel
|
from src.common.data_models.llm_data_model import LLMGenerationDataModel
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
logger = get_logger("generator_api")
|
logger = get_logger("generator_service")
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# 回复器获取API函数
|
# 回复器获取函数
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
@@ -42,39 +40,24 @@ def get_replyer(
|
|||||||
chat_id: Optional[str] = None,
|
chat_id: Optional[str] = None,
|
||||||
request_type: str = "replyer",
|
request_type: str = "replyer",
|
||||||
) -> Optional[DefaultReplyer | PrivateReplyer]:
|
) -> Optional[DefaultReplyer | PrivateReplyer]:
|
||||||
"""获取回复器对象
|
"""获取回复器对象"""
|
||||||
|
|
||||||
优先使用chat_stream,如果没有则使用chat_id直接查找。
|
|
||||||
使用 ReplyerManager 来管理实例,避免重复创建。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chat_stream: 聊天流对象(优先)
|
|
||||||
chat_id: 聊天ID(实际上就是stream_id)
|
|
||||||
request_type: 请求类型
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[DefaultReplyer]: 回复器对象,如果获取失败则返回None
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: chat_stream 和 chat_id 均为空
|
|
||||||
"""
|
|
||||||
if not chat_id and not chat_stream:
|
if not chat_id and not chat_stream:
|
||||||
raise ValueError("chat_stream 和 chat_id 不可均为空")
|
raise ValueError("chat_stream 和 chat_id 不可均为空")
|
||||||
try:
|
try:
|
||||||
logger.debug(f"[GeneratorAPI] 正在获取回复器,chat_id: {chat_id}, chat_stream: {'有' if chat_stream else '无'}")
|
logger.debug(f"[GeneratorService] 正在获取回复器,chat_id: {chat_id}, chat_stream: {'有' if chat_stream else '无'}")
|
||||||
return replyer_manager.get_replyer(
|
return replyer_manager.get_replyer(
|
||||||
chat_stream=chat_stream,
|
chat_stream=chat_stream,
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
request_type=request_type,
|
request_type=request_type,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[GeneratorAPI] 获取回复器时发生意外错误: {e}", exc_info=True)
|
logger.error(f"[GeneratorService] 获取回复器时发生意外错误: {e}", exc_info=True)
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# 回复生成API函数
|
# 回复生成函数
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
@@ -96,39 +79,15 @@ async def generate_reply(
|
|||||||
from_plugin: bool = True,
|
from_plugin: bool = True,
|
||||||
reply_time_point: Optional[float] = None,
|
reply_time_point: Optional[float] = None,
|
||||||
) -> Tuple[bool, Optional["LLMGenerationDataModel"]]:
|
) -> Tuple[bool, Optional["LLMGenerationDataModel"]]:
|
||||||
"""生成回复
|
"""生成回复"""
|
||||||
|
|
||||||
Args:
|
|
||||||
chat_stream: 聊天流对象(优先)
|
|
||||||
chat_id: 聊天ID(备用)
|
|
||||||
action_data: 动作数据(向下兼容,包含reply_to和extra_info)
|
|
||||||
reply_message: 回复的消息对象
|
|
||||||
extra_info: 额外信息,用于补充上下文
|
|
||||||
reply_reason: 回复原因
|
|
||||||
available_actions: 可用动作
|
|
||||||
chosen_actions: 已选动作
|
|
||||||
unknown_words: Planner 在 reply 动作中给出的未知词语列表,用于黑话检索
|
|
||||||
enable_tool: 是否启用工具调用
|
|
||||||
enable_splitter: 是否启用消息分割器
|
|
||||||
enable_chinese_typo: 是否启用错字生成器
|
|
||||||
return_prompt: 是否返回提示词
|
|
||||||
model_set_with_weight: 模型配置列表,每个元素为 (TaskConfig, weight) 元组
|
|
||||||
request_type: 请求类型(可选,记录LLM使用)
|
|
||||||
from_plugin: 是否来自插件
|
|
||||||
reply_time_point: 回复时间点
|
|
||||||
Returns:
|
|
||||||
Tuple[bool, List[Tuple[str, Any]], Optional[str]]: (是否成功, 回复集合, 提示词)
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
# 如果 reply_time_point 未传入,设置为当前时间戳
|
|
||||||
if reply_time_point is None:
|
if reply_time_point is None:
|
||||||
reply_time_point = time.time()
|
reply_time_point = time.time()
|
||||||
|
|
||||||
# 获取回复器
|
logger.debug("[GeneratorService] 开始生成回复")
|
||||||
logger.debug("[GeneratorAPI] 开始生成回复")
|
|
||||||
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
|
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
|
||||||
if not replyer:
|
if not replyer:
|
||||||
logger.error("[GeneratorAPI] 无法获取回复器")
|
logger.error("[GeneratorService] 无法获取回复器")
|
||||||
return False, None
|
return False, None
|
||||||
|
|
||||||
if action_data:
|
if action_data:
|
||||||
@@ -136,11 +95,9 @@ async def generate_reply(
|
|||||||
extra_info = action_data.get("extra_info", "")
|
extra_info = action_data.get("extra_info", "")
|
||||||
if not reply_reason:
|
if not reply_reason:
|
||||||
reply_reason = action_data.get("reason", "")
|
reply_reason = action_data.get("reason", "")
|
||||||
# 仅在 reply 场景下使用的未知词语解析(Planner JSON 中下发)
|
|
||||||
if unknown_words is None:
|
if unknown_words is None:
|
||||||
uw = action_data.get("unknown_words")
|
uw = action_data.get("unknown_words")
|
||||||
if isinstance(uw, list):
|
if isinstance(uw, list):
|
||||||
# 只保留非空字符串
|
|
||||||
cleaned: List[str] = []
|
cleaned: List[str] = []
|
||||||
for item in uw:
|
for item in uw:
|
||||||
if isinstance(item, str):
|
if isinstance(item, str):
|
||||||
@@ -150,7 +107,6 @@ async def generate_reply(
|
|||||||
if cleaned:
|
if cleaned:
|
||||||
unknown_words = cleaned
|
unknown_words = cleaned
|
||||||
|
|
||||||
# 调用回复器生成回复
|
|
||||||
success, llm_response = await replyer.generate_reply_with_context(
|
success, llm_response = await replyer.generate_reply_with_context(
|
||||||
extra_info=extra_info,
|
extra_info=extra_info,
|
||||||
available_actions=available_actions,
|
available_actions=available_actions,
|
||||||
@@ -166,7 +122,7 @@ async def generate_reply(
|
|||||||
log_reply=False,
|
log_reply=False,
|
||||||
)
|
)
|
||||||
if not success:
|
if not success:
|
||||||
logger.warning("[GeneratorAPI] 回复生成失败")
|
logger.warning("[GeneratorService] 回复生成失败")
|
||||||
return False, None
|
return False, None
|
||||||
reply_set: Optional[ReplySetModel] = None
|
reply_set: Optional[ReplySetModel] = None
|
||||||
if content := llm_response.content:
|
if content := llm_response.content:
|
||||||
@@ -176,9 +132,8 @@ async def generate_reply(
|
|||||||
for text in processed_response:
|
for text in processed_response:
|
||||||
reply_set.add_text_content(text)
|
reply_set.add_text_content(text)
|
||||||
llm_response.reply_set = reply_set
|
llm_response.reply_set = reply_set
|
||||||
logger.debug(f"[GeneratorAPI] 回复生成成功,生成了 {len(reply_set) if reply_set else 0} 个回复项")
|
logger.debug(f"[GeneratorService] 回复生成成功,生成了 {len(reply_set) if reply_set else 0} 个回复项")
|
||||||
|
|
||||||
# 统一在这里记录最终回复日志(包含分割后的 processed_output)
|
|
||||||
try:
|
try:
|
||||||
PlanReplyLogger.log_reply(
|
PlanReplyLogger.log_reply(
|
||||||
chat_id=chat_stream.session_id if chat_stream else (chat_id or ""),
|
chat_id=chat_stream.session_id if chat_stream else (chat_id or ""),
|
||||||
@@ -192,7 +147,7 @@ async def generate_reply(
|
|||||||
success=True,
|
success=True,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("[GeneratorAPI] 记录reply日志失败")
|
logger.exception("[GeneratorService] 记录reply日志失败")
|
||||||
|
|
||||||
return success, llm_response
|
return success, llm_response
|
||||||
|
|
||||||
@@ -200,11 +155,11 @@ async def generate_reply(
|
|||||||
raise ve
|
raise ve
|
||||||
|
|
||||||
except UserWarning as uw:
|
except UserWarning as uw:
|
||||||
logger.warning(f"[GeneratorAPI] 中断了生成: {uw}")
|
logger.warning(f"[GeneratorService] 中断了生成: {uw}")
|
||||||
return False, None
|
return False, None
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[GeneratorAPI] 生成回复时出错: {e}")
|
logger.error(f"[GeneratorService] 生成回复时出错: {e}")
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return False, None
|
return False, None
|
||||||
|
|
||||||
@@ -220,39 +175,20 @@ async def rewrite_reply(
|
|||||||
reply_to: str = "",
|
reply_to: str = "",
|
||||||
request_type: str = "generator_api",
|
request_type: str = "generator_api",
|
||||||
) -> Tuple[bool, Optional["LLMGenerationDataModel"]]:
|
) -> Tuple[bool, Optional["LLMGenerationDataModel"]]:
|
||||||
"""重写回复
|
"""重写回复"""
|
||||||
|
|
||||||
Args:
|
|
||||||
chat_stream: 聊天流对象(优先)
|
|
||||||
reply_data: 回复数据字典(向下兼容备用,当其他参数缺失时从此获取)
|
|
||||||
chat_id: 聊天ID(备用)
|
|
||||||
enable_splitter: 是否启用消息分割器
|
|
||||||
enable_chinese_typo: 是否启用错字生成器
|
|
||||||
model_set_with_weight: 模型配置列表,每个元素为 (TaskConfig, weight) 元组
|
|
||||||
raw_reply: 原始回复内容
|
|
||||||
reason: 回复原因
|
|
||||||
reply_to: 回复对象
|
|
||||||
return_prompt: 是否返回提示词
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[bool, List[Tuple[str, Any]]]: (是否成功, 回复集合)
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
# 获取回复器
|
|
||||||
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
|
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
|
||||||
if not replyer:
|
if not replyer:
|
||||||
logger.error("[GeneratorAPI] 无法获取回复器")
|
logger.error("[GeneratorService] 无法获取回复器")
|
||||||
return False, None
|
return False, None
|
||||||
|
|
||||||
logger.info("[GeneratorAPI] 开始重写回复")
|
logger.info("[GeneratorService] 开始重写回复")
|
||||||
|
|
||||||
# 如果参数缺失,从reply_data中获取
|
|
||||||
if reply_data:
|
if reply_data:
|
||||||
raw_reply = raw_reply or reply_data.get("raw_reply", "")
|
raw_reply = raw_reply or reply_data.get("raw_reply", "")
|
||||||
reason = reason or reply_data.get("reason", "")
|
reason = reason or reply_data.get("reason", "")
|
||||||
reply_to = reply_to or reply_data.get("reply_to", "")
|
reply_to = reply_to or reply_data.get("reply_to", "")
|
||||||
|
|
||||||
# 调用回复器重写回复
|
|
||||||
success, llm_response = await replyer.rewrite_reply_with_context(
|
success, llm_response = await replyer.rewrite_reply_with_context(
|
||||||
raw_reply=raw_reply,
|
raw_reply=raw_reply,
|
||||||
reason=reason,
|
reason=reason,
|
||||||
@@ -263,9 +199,9 @@ async def rewrite_reply(
|
|||||||
reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
|
reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
|
||||||
llm_response.reply_set = reply_set
|
llm_response.reply_set = reply_set
|
||||||
if success:
|
if success:
|
||||||
logger.info(f"[GeneratorAPI] 重写回复成功,生成了 {len(reply_set) if reply_set else 0} 个回复项")
|
logger.info(f"[GeneratorService] 重写回复成功,生成了 {len(reply_set) if reply_set else 0} 个回复项")
|
||||||
else:
|
else:
|
||||||
logger.warning("[GeneratorAPI] 重写回复失败")
|
logger.warning("[GeneratorService] 重写回复失败")
|
||||||
|
|
||||||
return success, llm_response
|
return success, llm_response
|
||||||
|
|
||||||
@@ -273,18 +209,12 @@ async def rewrite_reply(
|
|||||||
raise ve
|
raise ve
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[GeneratorAPI] 重写回复时出错: {e}")
|
logger.error(f"[GeneratorService] 重写回复时出错: {e}")
|
||||||
return False, None
|
return False, None
|
||||||
|
|
||||||
|
|
||||||
def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> Optional[ReplySetModel]:
|
def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> Optional[ReplySetModel]:
|
||||||
"""将文本处理为更拟人化的文本
|
"""将文本处理为更拟人化的文本"""
|
||||||
|
|
||||||
Args:
|
|
||||||
content: 文本内容
|
|
||||||
enable_splitter: 是否启用消息分割器
|
|
||||||
enable_chinese_typo: 是否启用错字生成器
|
|
||||||
"""
|
|
||||||
if not isinstance(content, str):
|
if not isinstance(content, str):
|
||||||
raise ValueError("content 必须是字符串类型")
|
raise ValueError("content 必须是字符串类型")
|
||||||
try:
|
try:
|
||||||
@@ -297,7 +227,7 @@ def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo:
|
|||||||
return reply_set
|
return reply_set
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[GeneratorAPI] 处理人形文本时出错: {e}")
|
logger.error(f"[GeneratorService] 处理人形文本时出错: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@@ -309,18 +239,18 @@ async def generate_response_custom(
|
|||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
|
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
|
||||||
if not replyer:
|
if not replyer:
|
||||||
logger.error("[GeneratorAPI] 无法获取回复器")
|
logger.error("[GeneratorService] 无法获取回复器")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.debug("[GeneratorAPI] 开始生成自定义回复")
|
logger.debug("[GeneratorService] 开始生成自定义回复")
|
||||||
response, _, _, _ = await replyer.llm_generate_content(prompt)
|
response, _, _, _ = await replyer.llm_generate_content(prompt)
|
||||||
if response:
|
if response:
|
||||||
logger.debug("[GeneratorAPI] 自定义回复生成成功")
|
logger.debug("[GeneratorService] 自定义回复生成成功")
|
||||||
return response
|
return response
|
||||||
else:
|
else:
|
||||||
logger.warning("[GeneratorAPI] 自定义回复生成失败")
|
logger.warning("[GeneratorService] 自定义回复生成失败")
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[GeneratorAPI] 生成自定义回复时出错: {e}")
|
logger.error(f"[GeneratorService] 生成自定义回复时出错: {e}")
|
||||||
return None
|
return None
|
||||||
@@ -1,26 +1,19 @@
|
|||||||
"""LLM API模块
|
"""LLM 服务模块
|
||||||
|
|
||||||
提供了与LLM模型交互的功能
|
提供与 LLM 模型交互的核心功能。
|
||||||
使用方式:
|
|
||||||
from src.plugin_system.apis import llm_api
|
|
||||||
models = llm_api.get_available_models()
|
|
||||||
success, response, reasoning, model_name = await llm_api.generate_with_model(prompt, model_config)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Tuple, Dict, List, Any, Optional, Callable
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.llm_models.payload_content.tool_option import ToolCall
|
|
||||||
from src.llm_models.payload_content.message import Message
|
|
||||||
from src.llm_models.model_client.base_client import BaseClient
|
|
||||||
from src.llm_models.utils_model import LLMRequest
|
|
||||||
from src.config.config import config_manager
|
from src.config.config import config_manager
|
||||||
from src.config.model_configs import TaskConfig
|
from src.config.model_configs import TaskConfig
|
||||||
|
from src.llm_models.model_client.base_client import BaseClient
|
||||||
|
from src.llm_models.payload_content.message import Message
|
||||||
|
from src.llm_models.payload_content.tool_option import ToolCall
|
||||||
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
|
||||||
logger = get_logger("llm_api")
|
logger = get_logger("llm_service")
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# LLM模型API函数
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
def get_available_models() -> Dict[str, TaskConfig]:
|
def get_available_models() -> Dict[str, TaskConfig]:
|
||||||
@@ -30,7 +23,6 @@ def get_available_models() -> Dict[str, TaskConfig]:
|
|||||||
Dict[str, Any]: 模型配置字典,key为模型名称,value为模型配置
|
Dict[str, Any]: 模型配置字典,key为模型名称,value为模型配置
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 自动获取所有属性并转换为字典形式
|
|
||||||
models = config_manager.get_model_config().model_task_config
|
models = config_manager.get_model_config().model_task_config
|
||||||
attrs = dir(models)
|
attrs = dir(models)
|
||||||
rets: Dict[str, TaskConfig] = {}
|
rets: Dict[str, TaskConfig] = {}
|
||||||
@@ -41,12 +33,12 @@ def get_available_models() -> Dict[str, TaskConfig]:
|
|||||||
if not callable(value) and isinstance(value, TaskConfig):
|
if not callable(value) and isinstance(value, TaskConfig):
|
||||||
rets[attr] = value
|
rets[attr] = value
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"[LLMAPI] 获取属性 {attr} 失败: {e}")
|
logger.debug(f"[LLMService] 获取属性 {attr} 失败: {e}")
|
||||||
continue
|
continue
|
||||||
return rets
|
return rets
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[LLMAPI] 获取可用模型失败: {e}")
|
logger.error(f"[LLMService] 获取可用模型失败: {e}")
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
@@ -68,9 +60,7 @@ async def generate_with_model(
|
|||||||
Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称)
|
Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称)
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# model_name_list = model_config.model_list
|
logger.debug(f"[LLMService] 完整提示词: {prompt}")
|
||||||
# logger.info(f"[LLMAPI] 使用模型集合 {model_name_list} 生成内容")
|
|
||||||
logger.debug(f"[LLMAPI] 完整提示词: {prompt}")
|
|
||||||
|
|
||||||
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
|
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
|
||||||
|
|
||||||
@@ -81,7 +71,7 @@ async def generate_with_model(
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"生成内容时出错: {str(e)}"
|
error_msg = f"生成内容时出错: {str(e)}"
|
||||||
logger.error(f"[LLMAPI] {error_msg}")
|
logger.error(f"[LLMService] {error_msg}")
|
||||||
return False, error_msg, "", ""
|
return False, error_msg, "", ""
|
||||||
|
|
||||||
|
|
||||||
@@ -104,7 +94,7 @@ async def generate_with_model_with_tools(
|
|||||||
max_tokens: 最大token数
|
max_tokens: 最大token数
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称)
|
Tuple[bool, str, str, str, List[ToolCall] | None]: (是否成功, 生成的内容, 推理过程, 模型名称, 工具调用列表)
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
model_name_list = model_config.model_list
|
model_name_list = model_config.model_list
|
||||||
@@ -120,7 +110,7 @@ async def generate_with_model_with_tools(
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"生成内容时出错: {str(e)}"
|
error_msg = f"生成内容时出错: {str(e)}"
|
||||||
logger.error(f"[LLMAPI] {error_msg}")
|
logger.error(f"[LLMService] {error_msg}")
|
||||||
return False, error_msg, "", "", None
|
return False, error_msg, "", "", None
|
||||||
|
|
||||||
|
|
||||||
@@ -161,5 +151,5 @@ async def generate_with_model_with_tools_by_message_factory(
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"生成内容时出错: {str(e)}"
|
error_msg = f"生成内容时出错: {str(e)}"
|
||||||
logger.error(f"[LLMAPI] {error_msg}")
|
logger.error(f"[LLMService] {error_msg}")
|
||||||
return False, error_msg, "", "", None
|
return False, error_msg, "", "", None
|
||||||
@@ -1,62 +1,44 @@
|
|||||||
"""
|
"""
|
||||||
消息API模块
|
消息服务模块
|
||||||
|
|
||||||
提供消息查询和构建成字符串的功能,采用标准Python包设计模式
|
提供消息查询和构建成字符串的核心功能。
|
||||||
使用方式:
|
|
||||||
from src.plugin_system.apis import message_api
|
|
||||||
messages = message_api.get_messages_by_time_in_chat(chat_id, start_time, end_time)
|
|
||||||
readable_text = message_api.build_readable_messages(messages)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from typing import List, Dict, Any, Tuple, Optional
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from sqlmodel import col, select
|
from sqlmodel import col, select
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
|
||||||
from src.common.database.database import get_db_session
|
|
||||||
from src.common.database.database_model import Images, ImageType
|
|
||||||
from src.chat.utils.utils import is_bot_self
|
|
||||||
from src.chat.utils.chat_message_builder import (
|
from src.chat.utils.chat_message_builder import (
|
||||||
get_raw_msg_by_timestamp,
|
|
||||||
get_raw_msg_by_timestamp_with_chat,
|
|
||||||
get_raw_msg_by_timestamp_with_chat_inclusive,
|
|
||||||
get_raw_msg_by_timestamp_with_chat_users,
|
|
||||||
get_raw_msg_by_timestamp_random,
|
|
||||||
get_raw_msg_by_timestamp_with_users,
|
|
||||||
get_raw_msg_before_timestamp,
|
|
||||||
get_raw_msg_before_timestamp_with_chat,
|
|
||||||
get_raw_msg_before_timestamp_with_users,
|
|
||||||
num_new_messages_since,
|
|
||||||
num_new_messages_since_with_users,
|
|
||||||
build_readable_messages,
|
build_readable_messages,
|
||||||
build_readable_messages_with_list,
|
build_readable_messages_with_list,
|
||||||
get_person_id_list,
|
get_person_id_list,
|
||||||
|
get_raw_msg_before_timestamp,
|
||||||
|
get_raw_msg_before_timestamp_with_chat,
|
||||||
|
get_raw_msg_before_timestamp_with_users,
|
||||||
|
get_raw_msg_by_timestamp,
|
||||||
|
get_raw_msg_by_timestamp_random,
|
||||||
|
get_raw_msg_by_timestamp_with_chat,
|
||||||
|
get_raw_msg_by_timestamp_with_chat_inclusive,
|
||||||
|
get_raw_msg_by_timestamp_with_chat_users,
|
||||||
|
get_raw_msg_by_timestamp_with_users,
|
||||||
|
num_new_messages_since,
|
||||||
|
num_new_messages_since_with_users,
|
||||||
)
|
)
|
||||||
|
from src.chat.utils.utils import is_bot_self
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
from src.common.database.database import get_db_session
|
||||||
|
from src.common.database.database_model import Images, ImageType
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# 消息查询API函数
|
# 消息查询函数
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
def get_messages_by_time(
|
def get_messages_by_time(
|
||||||
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
|
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
|
||||||
) -> List[DatabaseMessages]:
|
) -> List[DatabaseMessages]:
|
||||||
"""
|
|
||||||
获取指定时间范围内的消息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
start_time: 开始时间戳
|
|
||||||
end_time: 结束时间戳
|
|
||||||
limit: 限制返回的消息数量,0为不限制
|
|
||||||
limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录
|
|
||||||
filter_mai: 是否过滤麦麦自身的消息,默认为False
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Dict[str, Any]]: 消息列表
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: 如果参数不合法
|
|
||||||
"""
|
|
||||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
||||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
raise ValueError("start_time 和 end_time 必须是数字类型")
|
||||||
if limit < 0:
|
if limit < 0:
|
||||||
@@ -76,23 +58,6 @@ def get_messages_by_time_in_chat(
|
|||||||
filter_command: bool = False,
|
filter_command: bool = False,
|
||||||
filter_intercept_message_level: Optional[int] = None,
|
filter_intercept_message_level: Optional[int] = None,
|
||||||
) -> List[DatabaseMessages]:
|
) -> List[DatabaseMessages]:
|
||||||
"""
|
|
||||||
获取指定聊天中指定时间范围内的消息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chat_id: 聊天ID
|
|
||||||
start_time: 开始时间戳
|
|
||||||
end_time: 结束时间戳
|
|
||||||
limit: 限制返回的消息数量,0为不限制
|
|
||||||
limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录
|
|
||||||
filter_mai: 是否过滤麦麦自身的消息,默认为False
|
|
||||||
filter_command: 是否过滤命令消息,默认为False
|
|
||||||
Returns:
|
|
||||||
List[Dict[str, Any]]: 消息列表
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: 如果参数不合法
|
|
||||||
"""
|
|
||||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
||||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
raise ValueError("start_time 和 end_time 必须是数字类型")
|
||||||
if limit < 0:
|
if limit < 0:
|
||||||
@@ -101,10 +66,6 @@ def get_messages_by_time_in_chat(
|
|||||||
raise ValueError("chat_id 不能为空")
|
raise ValueError("chat_id 不能为空")
|
||||||
if not isinstance(chat_id, str):
|
if not isinstance(chat_id, str):
|
||||||
raise ValueError("chat_id 必须是字符串类型")
|
raise ValueError("chat_id 必须是字符串类型")
|
||||||
# if filter_mai:
|
|
||||||
# return filter_mai_messages(
|
|
||||||
# get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command)
|
|
||||||
# )
|
|
||||||
return get_raw_msg_by_timestamp_with_chat(
|
return get_raw_msg_by_timestamp_with_chat(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
timestamp_start=start_time,
|
timestamp_start=start_time,
|
||||||
@@ -127,23 +88,6 @@ def get_messages_by_time_in_chat_inclusive(
|
|||||||
filter_command: bool = False,
|
filter_command: bool = False,
|
||||||
filter_intercept_message_level: Optional[int] = None,
|
filter_intercept_message_level: Optional[int] = None,
|
||||||
) -> List[DatabaseMessages]:
|
) -> List[DatabaseMessages]:
|
||||||
"""
|
|
||||||
获取指定聊天中指定时间范围内的消息(包含边界)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chat_id: 聊天ID
|
|
||||||
start_time: 开始时间戳(包含)
|
|
||||||
end_time: 结束时间戳(包含)
|
|
||||||
limit: 限制返回的消息数量,0为不限制
|
|
||||||
limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录
|
|
||||||
filter_mai: 是否过滤麦麦自身的消息,默认为False
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Dict[str, Any]]: 消息列表
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: 如果参数不合法
|
|
||||||
"""
|
|
||||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
||||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
raise ValueError("start_time 和 end_time 必须是数字类型")
|
||||||
if limit < 0:
|
if limit < 0:
|
||||||
@@ -175,23 +119,6 @@ def get_messages_by_time_in_chat_for_users(
|
|||||||
limit: int = 0,
|
limit: int = 0,
|
||||||
limit_mode: str = "latest",
|
limit_mode: str = "latest",
|
||||||
) -> List[DatabaseMessages]:
|
) -> List[DatabaseMessages]:
|
||||||
"""
|
|
||||||
获取指定聊天中指定用户在指定时间范围内的消息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chat_id: 聊天ID
|
|
||||||
start_time: 开始时间戳
|
|
||||||
end_time: 结束时间戳
|
|
||||||
person_ids: 用户ID列表
|
|
||||||
limit: 限制返回的消息数量,0为不限制
|
|
||||||
limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Dict[str, Any]]: 消息列表
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: 如果参数不合法
|
|
||||||
"""
|
|
||||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
||||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
raise ValueError("start_time 和 end_time 必须是数字类型")
|
||||||
if limit < 0:
|
if limit < 0:
|
||||||
@@ -206,22 +133,6 @@ def get_messages_by_time_in_chat_for_users(
|
|||||||
def get_random_chat_messages(
|
def get_random_chat_messages(
|
||||||
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
|
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
|
||||||
) -> List[DatabaseMessages]:
|
) -> List[DatabaseMessages]:
|
||||||
"""
|
|
||||||
随机选择一个聊天,返回该聊天在指定时间范围内的消息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
start_time: 开始时间戳
|
|
||||||
end_time: 结束时间戳
|
|
||||||
limit: 限制返回的消息数量,0为不限制
|
|
||||||
limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录
|
|
||||||
filter_mai: 是否过滤麦麦自身的消息,默认为False
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Dict[str, Any]]: 消息列表
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: 如果参数不合法
|
|
||||||
"""
|
|
||||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
||||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
raise ValueError("start_time 和 end_time 必须是数字类型")
|
||||||
if limit < 0:
|
if limit < 0:
|
||||||
@@ -234,22 +145,6 @@ def get_random_chat_messages(
|
|||||||
def get_messages_by_time_for_users(
|
def get_messages_by_time_for_users(
|
||||||
start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest"
|
start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest"
|
||||||
) -> List[DatabaseMessages]:
|
) -> List[DatabaseMessages]:
|
||||||
"""
|
|
||||||
获取指定用户在所有聊天中指定时间范围内的消息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
start_time: 开始时间戳
|
|
||||||
end_time: 结束时间戳
|
|
||||||
person_ids: 用户ID列表
|
|
||||||
limit: 限制返回的消息数量,0为不限制
|
|
||||||
limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Dict[str, Any]]: 消息列表
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: 如果参数不合法
|
|
||||||
"""
|
|
||||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
||||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
raise ValueError("start_time 和 end_time 必须是数字类型")
|
||||||
if limit < 0:
|
if limit < 0:
|
||||||
@@ -258,20 +153,6 @@ def get_messages_by_time_for_users(
|
|||||||
|
|
||||||
|
|
||||||
def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[DatabaseMessages]:
|
def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[DatabaseMessages]:
|
||||||
"""
|
|
||||||
获取指定时间戳之前的消息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
timestamp: 时间戳
|
|
||||||
limit: 限制返回的消息数量,0为不限制
|
|
||||||
filter_mai: 是否过滤麦麦自身的消息,默认为False
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Dict[str, Any]]: 消息列表
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: 如果参数不合法
|
|
||||||
"""
|
|
||||||
if not isinstance(timestamp, (int, float)):
|
if not isinstance(timestamp, (int, float)):
|
||||||
raise ValueError("timestamp 必须是数字类型")
|
raise ValueError("timestamp 必须是数字类型")
|
||||||
if limit < 0:
|
if limit < 0:
|
||||||
@@ -288,21 +169,6 @@ def get_messages_before_time_in_chat(
|
|||||||
filter_mai: bool = False,
|
filter_mai: bool = False,
|
||||||
filter_intercept_message_level: Optional[int] = None,
|
filter_intercept_message_level: Optional[int] = None,
|
||||||
) -> List[DatabaseMessages]:
|
) -> List[DatabaseMessages]:
|
||||||
"""
|
|
||||||
获取指定聊天中指定时间戳之前的消息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chat_id: 聊天ID
|
|
||||||
timestamp: 时间戳
|
|
||||||
limit: 限制返回的消息数量,0为不限制
|
|
||||||
filter_mai: 是否过滤麦麦自身的消息,默认为False
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Dict[str, Any]]: 消息列表
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: 如果参数不合法
|
|
||||||
"""
|
|
||||||
if not isinstance(timestamp, (int, float)):
|
if not isinstance(timestamp, (int, float)):
|
||||||
raise ValueError("timestamp 必须是数字类型")
|
raise ValueError("timestamp 必须是数字类型")
|
||||||
if limit < 0:
|
if limit < 0:
|
||||||
@@ -325,20 +191,6 @@ def get_messages_before_time_in_chat(
|
|||||||
def get_messages_before_time_for_users(
|
def get_messages_before_time_for_users(
|
||||||
timestamp: float, person_ids: List[str], limit: int = 0
|
timestamp: float, person_ids: List[str], limit: int = 0
|
||||||
) -> List[DatabaseMessages]:
|
) -> List[DatabaseMessages]:
|
||||||
"""
|
|
||||||
获取指定用户在指定时间戳之前的消息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
timestamp: 时间戳
|
|
||||||
person_ids: 用户ID列表
|
|
||||||
limit: 限制返回的消息数量,0为不限制
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Dict[str, Any]]: 消息列表
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: 如果参数不合法
|
|
||||||
"""
|
|
||||||
if not isinstance(timestamp, (int, float)):
|
if not isinstance(timestamp, (int, float)):
|
||||||
raise ValueError("timestamp 必须是数字类型")
|
raise ValueError("timestamp 必须是数字类型")
|
||||||
if limit < 0:
|
if limit < 0:
|
||||||
@@ -349,22 +201,6 @@ def get_messages_before_time_for_users(
|
|||||||
def get_recent_messages(
|
def get_recent_messages(
|
||||||
chat_id: str, hours: float = 24.0, limit: int = 100, limit_mode: str = "latest", filter_mai: bool = False
|
chat_id: str, hours: float = 24.0, limit: int = 100, limit_mode: str = "latest", filter_mai: bool = False
|
||||||
) -> List[DatabaseMessages]:
|
) -> List[DatabaseMessages]:
|
||||||
"""
|
|
||||||
获取指定聊天中最近一段时间的消息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chat_id: 聊天ID
|
|
||||||
hours: 最近多少小时,默认24小时
|
|
||||||
limit: 限制返回的消息数量,默认100条
|
|
||||||
limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录
|
|
||||||
filter_mai: 是否过滤麦麦自身的消息,默认为False
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Dict[str, Any]]: 消息列表
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: 如果参数不合法s
|
|
||||||
"""
|
|
||||||
if not isinstance(hours, (int, float)) or hours < 0:
|
if not isinstance(hours, (int, float)) or hours < 0:
|
||||||
raise ValueError("hours 不能是负数")
|
raise ValueError("hours 不能是负数")
|
||||||
if not isinstance(limit, int) or limit < 0:
|
if not isinstance(limit, int) or limit < 0:
|
||||||
@@ -381,25 +217,11 @@ def get_recent_messages(
|
|||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# 消息计数API函数
|
# 消息计数函数
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional[float] = None) -> int:
|
def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional[float] = None) -> int:
|
||||||
"""
|
|
||||||
计算指定聊天中从开始时间到结束时间的新消息数量
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chat_id: 聊天ID
|
|
||||||
start_time: 开始时间戳
|
|
||||||
end_time: 结束时间戳,如果为None则使用当前时间
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: 新消息数量
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: 如果参数不合法
|
|
||||||
"""
|
|
||||||
if not isinstance(start_time, (int, float)):
|
if not isinstance(start_time, (int, float)):
|
||||||
raise ValueError("start_time 必须是数字类型")
|
raise ValueError("start_time 必须是数字类型")
|
||||||
if not chat_id:
|
if not chat_id:
|
||||||
@@ -410,21 +232,6 @@ def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional
|
|||||||
|
|
||||||
|
|
||||||
def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> int:
|
def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> int:
|
||||||
"""
|
|
||||||
计算指定聊天中指定用户从开始时间到结束时间的新消息数量
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chat_id: 聊天ID
|
|
||||||
start_time: 开始时间戳
|
|
||||||
end_time: 结束时间戳
|
|
||||||
person_ids: 用户ID列表
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: 新消息数量
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: 如果参数不合法
|
|
||||||
"""
|
|
||||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
||||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
raise ValueError("start_time 和 end_time 必须是数字类型")
|
||||||
if not chat_id:
|
if not chat_id:
|
||||||
@@ -435,7 +242,7 @@ def count_new_messages_for_users(chat_id: str, start_time: float, end_time: floa
|
|||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# 消息格式化API函数
|
# 消息格式化函数
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
@@ -447,21 +254,6 @@ def build_readable_messages_to_str(
|
|||||||
truncate: bool = False,
|
truncate: bool = False,
|
||||||
show_actions: bool = False,
|
show_actions: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
|
||||||
将消息列表构建成可读的字符串
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: 消息列表
|
|
||||||
replace_bot_name: 是否将机器人的名称替换为"你"
|
|
||||||
merge_messages: 是否合并连续消息
|
|
||||||
timestamp_mode: 时间戳显示模式,'relative'或'absolute'
|
|
||||||
read_mark: 已读标记时间戳,用于分割已读和未读消息
|
|
||||||
truncate: 是否截断长消息
|
|
||||||
show_actions: 是否显示动作记录
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
格式化后的可读字符串
|
|
||||||
"""
|
|
||||||
return build_readable_messages(messages, replace_bot_name, timestamp_mode, read_mark, truncate, show_actions)
|
return build_readable_messages(messages, replace_bot_name, timestamp_mode, read_mark, truncate, show_actions)
|
||||||
|
|
||||||
|
|
||||||
@@ -471,32 +263,10 @@ async def build_readable_messages_with_details(
|
|||||||
timestamp_mode: str = "relative",
|
timestamp_mode: str = "relative",
|
||||||
truncate: bool = False,
|
truncate: bool = False,
|
||||||
) -> Tuple[str, List[Tuple[float, str, str]]]:
|
) -> Tuple[str, List[Tuple[float, str, str]]]:
|
||||||
"""
|
|
||||||
将消息列表构建成可读的字符串,并返回详细信息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: 消息列表
|
|
||||||
replace_bot_name: 是否将机器人的名称替换为"你"
|
|
||||||
merge_messages: 是否合并连续消息
|
|
||||||
timestamp_mode: 时间戳显示模式,'relative'或'absolute'
|
|
||||||
truncate: 是否截断长消息
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
格式化后的可读字符串和详细信息元组列表(时间戳, 昵称, 内容)
|
|
||||||
"""
|
|
||||||
return await build_readable_messages_with_list(messages, replace_bot_name, timestamp_mode, truncate)
|
return await build_readable_messages_with_list(messages, replace_bot_name, timestamp_mode, truncate)
|
||||||
|
|
||||||
|
|
||||||
async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[str]:
|
async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[str]:
|
||||||
"""
|
|
||||||
从消息列表中提取不重复的用户ID列表
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: 消息列表
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
用户ID列表
|
|
||||||
"""
|
|
||||||
return await get_person_id_list(messages)
|
return await get_person_id_list(messages)
|
||||||
|
|
||||||
|
|
||||||
@@ -506,14 +276,7 @@ async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[s
|
|||||||
|
|
||||||
|
|
||||||
def filter_mai_messages(messages: List[DatabaseMessages]) -> List[DatabaseMessages]:
|
def filter_mai_messages(messages: List[DatabaseMessages]) -> List[DatabaseMessages]:
|
||||||
"""
|
"""从消息列表中移除麦麦的消息"""
|
||||||
从消息列表中移除麦麦的消息
|
|
||||||
Args:
|
|
||||||
messages: 消息列表,每个元素是消息字典
|
|
||||||
Returns:
|
|
||||||
过滤后的消息列表
|
|
||||||
"""
|
|
||||||
# 使用统一的 is_bot_self 函数判断是否是机器人自己(支持多平台,包括 WebUI)
|
|
||||||
return [msg for msg in messages if not is_bot_self(msg.user_info.platform, msg.user_info.user_id)]
|
return [msg for msg in messages if not is_bot_self(msg.user_info.platform, msg.user_info.user_id)]
|
||||||
|
|
||||||
|
|
||||||
@@ -1,22 +1,14 @@
|
|||||||
"""个人信息API模块
|
"""个人信息服务模块
|
||||||
|
|
||||||
提供个人信息查询功能,用于插件获取用户相关信息
|
提供个人信息查询的核心功能。
|
||||||
使用方式:
|
|
||||||
from src.plugin_system.apis import person_api
|
|
||||||
person_id = person_api.get_person_id("qq", 123456)
|
|
||||||
value = await person_api.get_person_value(person_id, "nickname")
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.person_info.person_info import Person
|
from src.person_info.person_info import Person
|
||||||
|
|
||||||
logger = get_logger("person_api")
|
logger = get_logger("person_service")
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# 个人信息API函数
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
def get_person_id(platform: str, user_id: int | str) -> str:
|
def get_person_id(platform: str, user_id: int | str) -> str:
|
||||||
@@ -28,14 +20,11 @@ def get_person_id(platform: str, user_id: int | str) -> str:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: 唯一的person_id(MD5哈希值)
|
str: 唯一的person_id(MD5哈希值)
|
||||||
|
|
||||||
示例:
|
|
||||||
person_id = person_api.get_person_id("qq", 123456)
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
return Person(platform=platform, user_id=str(user_id)).person_id
|
return Person(platform=platform, user_id=str(user_id)).person_id
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[PersonAPI] 获取person_id失败: platform={platform}, user_id={user_id}, error={e}")
|
logger.error(f"[PersonService] 获取person_id失败: platform={platform}, user_id={user_id}, error={e}")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
@@ -49,17 +38,13 @@ async def get_person_value(person_id: str, field_name: str, default: Any = None)
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Any: 字段值或默认值
|
Any: 字段值或默认值
|
||||||
|
|
||||||
示例:
|
|
||||||
nickname = await person_api.get_person_value(person_id, "nickname", "未知用户")
|
|
||||||
impression = await person_api.get_person_value(person_id, "impression")
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
person = Person(person_id=person_id)
|
person = Person(person_id=person_id)
|
||||||
value = getattr(person, field_name)
|
value = getattr(person, field_name)
|
||||||
return value if value is not None else default
|
return value if value is not None else default
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[PersonAPI] 获取用户信息失败: person_id={person_id}, field={field_name}, error={e}")
|
logger.error(f"[PersonService] 获取用户信息失败: person_id={person_id}, field={field_name}, error={e}")
|
||||||
return default
|
return default
|
||||||
|
|
||||||
|
|
||||||
@@ -71,13 +56,10 @@ def get_person_id_by_name(person_name: str) -> str:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: person_id,如果未找到返回空字符串
|
str: person_id,如果未找到返回空字符串
|
||||||
|
|
||||||
示例:
|
|
||||||
person_id = person_api.get_person_id_by_name("张三")
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
person = Person(person_name=person_name)
|
person = Person(person_name=person_name)
|
||||||
return person.person_id
|
return person.person_id
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[PersonAPI] 根据用户名获取person_id失败: person_name={person_name}, error={e}")
|
logger.error(f"[PersonService] 根据用户名获取person_id失败: person_name={person_name}, error={e}")
|
||||||
return ""
|
return ""
|
||||||
@@ -1,47 +1,32 @@
|
|||||||
"""
|
"""
|
||||||
发送API模块
|
发送服务模块
|
||||||
|
|
||||||
专门负责发送各种类型的消息,采用标准Python包设计模式
|
提供发送各种类型消息的核心功能。
|
||||||
|
|
||||||
使用方式:
|
|
||||||
from src.plugin_system.apis import send_api
|
|
||||||
|
|
||||||
# 方式1:直接使用stream_id(推荐)
|
|
||||||
await send_api.text_to_stream("hello", stream_id)
|
|
||||||
await send_api.emoji_to_stream(emoji_base64, stream_id)
|
|
||||||
await send_api.custom_to_stream("video", video_data, stream_id)
|
|
||||||
|
|
||||||
# 方式2:使用群聊/私聊指定函数
|
|
||||||
await send_api.text_to_group("hello", "123456")
|
|
||||||
await send_api.text_to_user("hello", "987654")
|
|
||||||
|
|
||||||
# 方式3:使用通用custom_message函数
|
|
||||||
await send_api.custom_message("video", video_data, "123456", True)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import traceback
|
import traceback
|
||||||
import time
|
import time
|
||||||
from typing import Optional, Union, Dict, List, TYPE_CHECKING, Tuple
|
from typing import Optional, Union, Dict, List, TYPE_CHECKING, Tuple
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from maim_message import MessageBase, BaseMessageInfo, Seg
|
||||||
from src.common.data_models.message_data_model import ReplyContentType
|
|
||||||
from src.config.config import global_config
|
|
||||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
|
||||||
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
|
|
||||||
from maim_message import Seg
|
|
||||||
|
|
||||||
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
|
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||||
from src.chat.message_receive.message import MessageSending
|
from src.chat.message_receive.message import MessageSending
|
||||||
|
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
|
||||||
|
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
|
||||||
|
from src.common.data_models.message_data_model import ReplyContentType
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.config.config import global_config
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
from src.common.data_models.message_data_model import ReplySetModel, ReplyContent, ForwardNode
|
from src.common.data_models.message_data_model import ForwardNode, ReplyContent, ReplySetModel
|
||||||
|
|
||||||
logger = get_logger("send_api")
|
logger = get_logger("send_service")
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# 内部实现函数(不暴露给外部)
|
# 内部实现函数
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
@@ -56,42 +41,25 @@ async def _send_to_target(
|
|||||||
show_log: bool = True,
|
show_log: bool = True,
|
||||||
selected_expressions: Optional[List[int]] = None,
|
selected_expressions: Optional[List[int]] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""向指定目标发送消息的内部实现
|
"""向指定目标发送消息的内部实现"""
|
||||||
|
|
||||||
Args:
|
|
||||||
message_segment:
|
|
||||||
stream_id: 目标流ID
|
|
||||||
display_message: 显示消息
|
|
||||||
typing: 是否模拟打字等待。
|
|
||||||
reply_to: 回复消息,格式为"发送者:消息内容"
|
|
||||||
storage_message: 是否存储消息到数据库
|
|
||||||
show_log: 发送是否显示日志
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否发送成功
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
if set_reply and not reply_message:
|
if set_reply and not reply_message:
|
||||||
logger.warning("[SendAPI] 使用引用回复,但未提供回复消息")
|
logger.warning("[SendService] 使用引用回复,但未提供回复消息")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if show_log:
|
if show_log:
|
||||||
logger.debug(f"[SendAPI] 发送{message_segment.type}消息到 {stream_id}")
|
logger.debug(f"[SendService] 发送{message_segment.type}消息到 {stream_id}")
|
||||||
|
|
||||||
# 查找目标聊天流
|
|
||||||
target_stream = _chat_manager.get_session_by_session_id(stream_id)
|
target_stream = _chat_manager.get_session_by_session_id(stream_id)
|
||||||
if not target_stream:
|
if not target_stream:
|
||||||
logger.error(f"[SendAPI] 未找到聊天流: {stream_id}")
|
logger.error(f"[SendService] 未找到聊天流: {stream_id}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 创建发送器
|
|
||||||
message_sender = UniversalMessageSender()
|
message_sender = UniversalMessageSender()
|
||||||
|
|
||||||
# 生成消息ID
|
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
message_id = f"send_api_{int(current_time * 1000)}"
|
message_id = f"send_api_{int(current_time * 1000)}"
|
||||||
|
|
||||||
# 构建机器人用户信息
|
|
||||||
bot_user_info = UserInfo(
|
bot_user_info = UserInfo(
|
||||||
user_id=global_config.bot.qq_account,
|
user_id=global_config.bot.qq_account,
|
||||||
user_nickname=global_config.bot.nickname,
|
user_nickname=global_config.bot.nickname,
|
||||||
@@ -102,17 +70,15 @@ async def _send_to_target(
|
|||||||
if reply_message:
|
if reply_message:
|
||||||
anchor_message = db_message_to_mai_message(reply_message)
|
anchor_message = db_message_to_mai_message(reply_message)
|
||||||
if anchor_message:
|
if anchor_message:
|
||||||
logger.debug(f"[SendAPI] 找到匹配的回复消息,发送者: {anchor_message.message_info.user_info.user_id}")
|
logger.debug(f"[SendService] 找到匹配的回复消息,发送者: {anchor_message.message_info.user_info.user_id}")
|
||||||
reply_to_platform_id = (
|
reply_to_platform_id = (
|
||||||
f"{anchor_message.platform}:{anchor_message.message_info.user_info.user_id}"
|
f"{anchor_message.platform}:{anchor_message.message_info.user_info.user_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 构建 sender_info(私聊时为接收者信息)
|
|
||||||
sender_info = None
|
sender_info = None
|
||||||
if target_stream.context and target_stream.context.message:
|
if target_stream.context and target_stream.context.message:
|
||||||
sender_info = target_stream.context.message.message_info.user_info
|
sender_info = target_stream.context.message.message_info.user_info
|
||||||
|
|
||||||
# 构建发送消息对象
|
|
||||||
bot_message = MessageSending(
|
bot_message = MessageSending(
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
session=target_stream,
|
session=target_stream,
|
||||||
@@ -128,7 +94,6 @@ async def _send_to_target(
|
|||||||
selected_expressions=selected_expressions,
|
selected_expressions=selected_expressions,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 发送消息
|
|
||||||
sent_msg = await message_sender.send_message(
|
sent_msg = await message_sender.send_message(
|
||||||
bot_message,
|
bot_message,
|
||||||
typing=typing,
|
typing=typing,
|
||||||
@@ -138,28 +103,22 @@ async def _send_to_target(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if sent_msg:
|
if sent_msg:
|
||||||
logger.debug(f"[SendAPI] 成功发送消息到 {stream_id}")
|
logger.debug(f"[SendService] 成功发送消息到 {stream_id}")
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
logger.error("[SendAPI] 发送消息失败")
|
logger.error("[SendService] 发送消息失败")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[SendAPI] 发送消息时出错: {e}")
|
logger.error(f"[SendService] 发送消息时出错: {e}")
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def db_message_to_mai_message(message_obj: "DatabaseMessages") -> Optional[MaiMessage]:
|
def db_message_to_mai_message(message_obj: "DatabaseMessages") -> Optional[MaiMessage]:
|
||||||
"""将数据库消息重建为 MaiMessage 对象,用于回复引用。
|
"""将数据库消息重建为 MaiMessage 对象,用于回复引用。"""
|
||||||
|
|
||||||
Args:
|
|
||||||
message_obj: 插件系统的 DatabaseMessages 数据对象
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[MaiMessage]: 构建的消息对象,如果信息不足则返回 None
|
|
||||||
"""
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from src.common.data_models.mai_message_data_model import GroupInfo, MessageInfo
|
from src.common.data_models.mai_message_data_model import GroupInfo, MessageInfo
|
||||||
from src.common.data_models.message_component_data_model import MessageSequence
|
from src.common.data_models.message_component_data_model import MessageSequence
|
||||||
|
|
||||||
@@ -190,7 +149,7 @@ def db_message_to_mai_message(message_obj: "DatabaseMessages") -> Optional[MaiMe
|
|||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# 公共API函数 - 预定义类型的发送函数
|
# 公共函数 - 预定义类型的发送函数
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
@@ -203,18 +162,7 @@ async def text_to_stream(
|
|||||||
storage_message: bool = True,
|
storage_message: bool = True,
|
||||||
selected_expressions: Optional[List[int]] = None,
|
selected_expressions: Optional[List[int]] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""向指定流发送文本消息
|
"""向指定流发送文本消息"""
|
||||||
|
|
||||||
Args:
|
|
||||||
text: 要发送的文本内容
|
|
||||||
stream_id: 聊天流ID
|
|
||||||
typing: 是否显示正在输入
|
|
||||||
reply_to: 回复消息,格式为"发送者:消息内容"
|
|
||||||
storage_message: 是否存储消息到数据库
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否发送成功
|
|
||||||
"""
|
|
||||||
return await _send_to_target(
|
return await _send_to_target(
|
||||||
message_segment=Seg(type="text", data=text),
|
message_segment=Seg(type="text", data=text),
|
||||||
stream_id=stream_id,
|
stream_id=stream_id,
|
||||||
@@ -234,16 +182,7 @@ async def emoji_to_stream(
|
|||||||
set_reply: bool = False,
|
set_reply: bool = False,
|
||||||
reply_message: Optional["DatabaseMessages"] = None,
|
reply_message: Optional["DatabaseMessages"] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""向指定流发送表情包
|
"""向指定流发送表情包"""
|
||||||
|
|
||||||
Args:
|
|
||||||
emoji_base64: 表情包的base64编码
|
|
||||||
stream_id: 聊天流ID
|
|
||||||
storage_message: 是否存储消息到数据库
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否发送成功
|
|
||||||
"""
|
|
||||||
return await _send_to_target(
|
return await _send_to_target(
|
||||||
message_segment=Seg(type="emoji", data=emoji_base64),
|
message_segment=Seg(type="emoji", data=emoji_base64),
|
||||||
stream_id=stream_id,
|
stream_id=stream_id,
|
||||||
@@ -262,16 +201,7 @@ async def image_to_stream(
|
|||||||
set_reply: bool = False,
|
set_reply: bool = False,
|
||||||
reply_message: Optional["DatabaseMessages"] = None,
|
reply_message: Optional["DatabaseMessages"] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""向指定流发送图片
|
"""向指定流发送图片"""
|
||||||
|
|
||||||
Args:
|
|
||||||
image_base64: 图片的base64编码
|
|
||||||
stream_id: 聊天流ID
|
|
||||||
storage_message: 是否存储消息到数据库
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否发送成功
|
|
||||||
"""
|
|
||||||
return await _send_to_target(
|
return await _send_to_target(
|
||||||
message_segment=Seg(type="image", data=image_base64),
|
message_segment=Seg(type="image", data=image_base64),
|
||||||
stream_id=stream_id,
|
stream_id=stream_id,
|
||||||
@@ -289,17 +219,7 @@ async def command_to_stream(
|
|||||||
storage_message: bool = True,
|
storage_message: bool = True,
|
||||||
display_message: str = "",
|
display_message: str = "",
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""向指定流发送命令
|
"""向指定流发送命令"""
|
||||||
|
|
||||||
Args:
|
|
||||||
command: 命令
|
|
||||||
stream_id: 聊天流ID
|
|
||||||
storage_message: 是否存储消息到数据库
|
|
||||||
display_message: 显示消息
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否发送成功
|
|
||||||
"""
|
|
||||||
return await _send_to_target(
|
return await _send_to_target(
|
||||||
message_segment=Seg(type="command", data=command), # type: ignore
|
message_segment=Seg(type="command", data=command), # type: ignore
|
||||||
stream_id=stream_id,
|
stream_id=stream_id,
|
||||||
@@ -321,20 +241,7 @@ async def custom_to_stream(
|
|||||||
storage_message: bool = True,
|
storage_message: bool = True,
|
||||||
show_log: bool = True,
|
show_log: bool = True,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""向指定流发送自定义类型消息
|
"""向指定流发送自定义类型消息"""
|
||||||
|
|
||||||
Args:
|
|
||||||
message_type: 消息类型,如"text"、"image"、"emoji"、"video"、"file"等
|
|
||||||
content: 消息内容(通常是base64编码或文本)
|
|
||||||
stream_id: 聊天流ID
|
|
||||||
display_message: 显示消息
|
|
||||||
typing: 是否显示正在输入
|
|
||||||
reply_to: 回复消息,格式为"发送者:消息内容"
|
|
||||||
storage_message: 是否存储消息到数据库
|
|
||||||
show_log: 是否显示日志
|
|
||||||
Returns:
|
|
||||||
bool: 是否发送成功
|
|
||||||
"""
|
|
||||||
return await _send_to_target(
|
return await _send_to_target(
|
||||||
message_segment=Seg(type=message_type, data=content), # type: ignore
|
message_segment=Seg(type=message_type, data=content), # type: ignore
|
||||||
stream_id=stream_id,
|
stream_id=stream_id,
|
||||||
@@ -350,25 +257,14 @@ async def custom_to_stream(
|
|||||||
async def custom_reply_set_to_stream(
|
async def custom_reply_set_to_stream(
|
||||||
reply_set: "ReplySetModel",
|
reply_set: "ReplySetModel",
|
||||||
stream_id: str,
|
stream_id: str,
|
||||||
display_message: str = "", # 基本没用
|
display_message: str = "",
|
||||||
typing: bool = False,
|
typing: bool = False,
|
||||||
reply_message: Optional["DatabaseMessages"] = None,
|
reply_message: Optional["DatabaseMessages"] = None,
|
||||||
set_reply: bool = False,
|
set_reply: bool = False,
|
||||||
storage_message: bool = True,
|
storage_message: bool = True,
|
||||||
show_log: bool = True,
|
show_log: bool = True,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""向指定流发送混合型消息集"""
|
||||||
向指定流发送混合型消息集
|
|
||||||
|
|
||||||
Args:
|
|
||||||
reply_set: ReplySetModel 对象,包含多个 ReplyContent
|
|
||||||
stream_id: 聊天流ID
|
|
||||||
display_message: 显示消息
|
|
||||||
typing: 是否显示正在输入
|
|
||||||
reply_to: 回复消息,格式为"发送者:消息内容"
|
|
||||||
storage_message: 是否存储消息到数据库
|
|
||||||
show_log: 是否显示日志
|
|
||||||
"""
|
|
||||||
flag: bool = True
|
flag: bool = True
|
||||||
for reply_content in reply_set.reply_data:
|
for reply_content in reply_set.reply_data:
|
||||||
status: bool = False
|
status: bool = False
|
||||||
@@ -386,20 +282,14 @@ async def custom_reply_set_to_stream(
|
|||||||
if not status:
|
if not status:
|
||||||
flag = False
|
flag = False
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[SendAPI] 发送{repr(reply_content.content_type)}消息失败,消息内容:{str(reply_content.content)[:100]}"
|
f"[SendService] 发送{repr(reply_content.content_type)}消息失败,消息内容:{str(reply_content.content)[:100]}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return flag
|
return flag
|
||||||
|
|
||||||
|
|
||||||
def _parse_content_to_seg(reply_content: "ReplyContent") -> Tuple[Seg, bool]:
|
def _parse_content_to_seg(reply_content: "ReplyContent") -> Tuple[Seg, bool]:
|
||||||
"""
|
"""把 ReplyContent 转换为 Seg 结构"""
|
||||||
把 ReplyContent 转换为 Seg 结构 (Forward 中仅递归一次)
|
|
||||||
Args:
|
|
||||||
reply_content: ReplyContent 对象
|
|
||||||
Returns:
|
|
||||||
Tuple[Seg, bool]: 转换后的 Seg 结构和是否需要typing的标志
|
|
||||||
"""
|
|
||||||
content_type = reply_content.content_type
|
content_type = reply_content.content_type
|
||||||
if content_type == ReplyContentType.TEXT:
|
if content_type == ReplyContentType.TEXT:
|
||||||
text_data: str = reply_content.content # type: ignore
|
text_data: str = reply_content.content # type: ignore
|
||||||
@@ -427,7 +317,7 @@ def _parse_content_to_seg(reply_content: "ReplyContent") -> Tuple[Seg, bool]:
|
|||||||
elif sub_content_type == ReplyContentType.EMOJI:
|
elif sub_content_type == ReplyContentType.EMOJI:
|
||||||
sub_seg_list.append(Seg(type="emoji", data=sub_content_data)) # type: ignore
|
sub_seg_list.append(Seg(type="emoji", data=sub_content_data)) # type: ignore
|
||||||
else:
|
else:
|
||||||
logger.warning(f"[SendAPI] 混合类型中不支持的子内容类型: {repr(sub_content_type)}")
|
logger.warning(f"[SendService] 混合类型中不支持的子内容类型: {repr(sub_content_type)}")
|
||||||
continue
|
continue
|
||||||
return Seg(type="seglist", data=sub_seg_list), True
|
return Seg(type="seglist", data=sub_seg_list), True
|
||||||
elif content_type == ReplyContentType.FORWARD:
|
elif content_type == ReplyContentType.FORWARD:
|
||||||
@@ -6,7 +6,7 @@ import json
|
|||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.toml_utils import save_toml_with_format
|
from src.common.toml_utils import save_toml_with_format
|
||||||
from src.config.config import MMC_VERSION
|
from src.config.config import MMC_VERSION
|
||||||
from src.plugin_system.base.config_types import ConfigField
|
from src.core.config_types import ConfigField
|
||||||
from src.webui.services.git_mirror_service import get_git_mirror_service, set_update_progress_callback
|
from src.webui.services.git_mirror_service import get_git_mirror_service, set_update_progress_callback
|
||||||
from src.webui.core import get_token_manager
|
from src.webui.core import get_token_manager
|
||||||
from src.webui.routers.websocket.plugin_progress import update_progress
|
from src.webui.routers.websocket.plugin_progress import update_progress
|
||||||
@@ -222,15 +222,16 @@ def coerce_types(schema_part: Dict[str, Any], config_part: Dict[str, Any]) -> No
|
|||||||
|
|
||||||
def find_plugin_instance(plugin_id: str) -> Optional[Any]:
|
def find_plugin_instance(plugin_id: str) -> Optional[Any]:
|
||||||
"""
|
"""
|
||||||
按 plugin_id 或 plugin_name 查找已加载的插件实例。
|
按 plugin_id 查找已加载的插件信息。
|
||||||
局部导入 plugin_manager 以规避循环依赖。
|
新运行时中插件运行在子进程,无法获取实例,返回注册信息。
|
||||||
"""
|
"""
|
||||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||||
|
|
||||||
for loaded_plugin_name in plugin_manager.list_loaded_plugins():
|
mgr = get_plugin_runtime_manager()
|
||||||
instance = plugin_manager.get_plugin_instance(loaded_plugin_name)
|
for sv in mgr.supervisors:
|
||||||
if instance and (instance.plugin_name == plugin_id or instance.get_manifest_info("id", "") == plugin_id):
|
reg = sv._registered_plugins.get(plugin_id)
|
||||||
return instance
|
if reg is not None:
|
||||||
|
return reg
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@@ -1497,26 +1498,10 @@ async def get_plugin_config_schema(
|
|||||||
logger.info(f"获取插件配置 Schema: {plugin_id}")
|
logger.info(f"获取插件配置 Schema: {plugin_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 尝试从已加载的插件中获取
|
# 新运行时中插件运行在子进程,无法直接获取实例的 webui_config_schema
|
||||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
# 尝试从文件系统读取
|
||||||
|
|
||||||
# 查找插件实例
|
|
||||||
plugin_instance = None
|
plugin_instance = None
|
||||||
|
|
||||||
# 遍历所有已加载的插件
|
|
||||||
for loaded_plugin_name in plugin_manager.list_loaded_plugins():
|
|
||||||
instance = plugin_manager.get_plugin_instance(loaded_plugin_name)
|
|
||||||
if instance:
|
|
||||||
# 匹配 plugin_name 或 manifest 中的 id
|
|
||||||
if instance.plugin_name == plugin_id:
|
|
||||||
plugin_instance = instance
|
|
||||||
break
|
|
||||||
# 也尝试匹配 manifest 中的 id
|
|
||||||
manifest_id = instance.get_manifest_info("id", "")
|
|
||||||
if manifest_id == plugin_id:
|
|
||||||
plugin_instance = instance
|
|
||||||
break
|
|
||||||
|
|
||||||
if plugin_instance and hasattr(plugin_instance, "get_webui_config_schema"):
|
if plugin_instance and hasattr(plugin_instance, "get_webui_config_schema"):
|
||||||
# 从插件实例获取 schema
|
# 从插件实例获取 schema
|
||||||
schema = plugin_instance.get_webui_config_schema()
|
schema = plugin_instance.get_webui_config_schema()
|
||||||
|
|||||||
Reference in New Issue
Block a user