diff --git a/AGENTS.md b/AGENTS.md
index b4caaaf1..b3456610 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -17,6 +17,7 @@
1. 尽量保持良好的注释
2. 如果原来的代码中有注释,则重构的时候,除非这部分代码被删除,否则相同功能的代码应该保留注释(可以对注释进行修改以保持准确性,但不应该删除注释)。
3. 如果原来的代码中没有注释,则重构的时候,如果某个功能块的代码较长或者逻辑较为复杂,则应该添加注释来解释这部分代码的功能和逻辑。
+4. 对于类,方法以及模块的注释,首选使用的注释格式为 Google DocStr 格式,但保证语言为简体中文
## 类型注解规范
1. 重构代码时,如果原来的代码中有类型注解,则相同功能的代码应该保留类型注解(可以对类型注解进行修改以保持准确性,但不应该删除类型注解)。
2. 重构代码时,如果原来的代码中没有类型注解,则重构的时候,如果某个函数的功能较为复杂或者参数较多,则应该添加类型注解来提高代码的可读性和可维护性。(对于简单的变量,可以不添加类型注解)
@@ -35,3 +36,7 @@
# 运行/调试/构建/测试/依赖
优先使用uv
依赖项以 pyproject.toml 为准
+
+# 语言规范
+
+项目的首选语言为简体中文,无论是注释语言,日志展示语言,还是 WebUI 展示语言都应该首要以简体中文为首要实现目标
diff --git a/README.md b/README.md
index 7c41c8cf..3f3851b7 100644
--- a/README.md
+++ b/README.md
@@ -1,21 +1,21 @@
-
简体中文 |
English
+
简体中文 |
English
diff --git a/docs/minimal-cross-platform-plan.md b/docs/minimal-cross-platform-plan.md
index d0b6707b..2f0a86bd 100644
--- a/docs/minimal-cross-platform-plan.md
+++ b/docs/minimal-cross-platform-plan.md
@@ -41,7 +41,7 @@ This plan is based on the checked-in code, not on assumptions from previous draf
| `src/person_info/person_info.py:247` | `_is_bot_self(self, platform, user_id)` | Duplicate logic with same QQ fallback |
Wrong-order call sites (8 total):
-- `src/bw_learner/expression_learner.py` x3 (lines 158, 241, 301)
+- `src/learners/expression_learner.py` x3 (lines 158, 241, 301)
- `src/common/utils/utils_message.py` x4 (lines 370, 440, 476, 515)
- `src/webui/routers/chat/support.py` x1 (line 65)
@@ -122,7 +122,7 @@ Make `src/chat/utils/utils.py::is_bot_self(platform, user_id)` the only real imp
- `src/common/utils/system_utils.py`
- `src/chat/utils/utils.py`
- `src/person_info/person_info.py`
-- `src/bw_learner/expression_learner.py`
+- `src/learners/expression_learner.py`
- `src/common/utils/utils_message.py`
- `src/webui/routers/chat/support.py`
- tests
@@ -468,7 +468,7 @@ When stopping, name: the exact file(s), the blocking mismatch, why it is outside
| Phase | Allowed files |
|-------|---------------|
-| Phase 0 | `src/common/utils/system_utils.py`, `src/chat/utils/utils.py`, `src/person_info/person_info.py`, `src/bw_learner/expression_learner.py`, `src/common/utils/utils_message.py`, `src/webui/routers/chat/support.py`, tests (including `pytests/utils_test/message_utils_test.py`) |
+| Phase 0 | `src/common/utils/system_utils.py`, `src/chat/utils/utils.py`, `src/person_info/person_info.py`, `src/learners/expression_learner.py`, `src/common/utils/utils_message.py`, `src/webui/routers/chat/support.py`, tests (including `pytests/utils_test/message_utils_test.py`) |
| Phase 1 | `src/chat/utils/utils.py`, `src/chat/planner_actions/planner.py`, `src/chat/utils/statistic.py`, `src/common/message_repository.py`, `src/webui/routers/chat/support.py`, `src/services/send_service.py`, `src/chat/replyer/group_generator.py`, `src/chat/replyer/private_generator.py`, `src/chat/brain_chat/PFC/message_sender.py`, `src/person_info/person_info.py`, tests |
### INVALID OUTPUT EXAMPLES
diff --git a/plugins/ChatFrequency/_manifest.json b/plugins/ChatFrequency/_manifest.json
index 241242ed..56417665 100644
--- a/plugins/ChatFrequency/_manifest.json
+++ b/plugins/ChatFrequency/_manifest.json
@@ -1,58 +1,40 @@
{
- "manifest_version": 1,
- "name": "发言频率控制插件|BetterFrequency Plugin",
+ "manifest_version": 2,
"version": "2.0.0",
- "description": "控制聊天频率,支持设置focus_value和talk_frequency调整值,提供命令",
+ "name": "发言频率控制插件|BetterFrequency Plugin",
+ "description": "控制聊天频率,支持设置 focus_value 和 talk_frequency 调整值,并提供命令入口。",
"author": {
"name": "SengokuCola",
"url": "https://github.com/MaiM-with-u"
},
"license": "GPL-v3.0-or-later",
- "host_application": {
- "min_version": "1.0.0"
+ "urls": {
+ "repository": "https://github.com/SengokuCola/BetterFrequency",
+ "homepage": "https://github.com/SengokuCola/BetterFrequency",
+ "documentation": "https://github.com/SengokuCola/BetterFrequency",
+ "issues": "https://github.com/SengokuCola/BetterFrequency/issues"
},
- "homepage_url": "https://github.com/SengokuCola/BetterFrequency",
- "repository_url": "https://github.com/SengokuCola/BetterFrequency",
- "keywords": [
- "frequency",
- "control",
- "talk_frequency",
- "plugin",
- "shortcut"
+ "host_application": {
+ "min_version": "1.0.0",
+ "max_version": "1.0.0"
+ },
+ "sdk": {
+ "min_version": "2.0.0",
+ "max_version": "2.99.99"
+ },
+ "dependencies": [],
+ "capabilities": [
+ "send.text",
+ "frequency.set_adjust",
+ "frequency.get_current_talk_value",
+ "frequency.get_adjust"
],
- "categories": [
- "Chat",
- "Frequency",
- "Control"
- ],
- "default_locale": "zh-CN",
- "locales_path": "_locales",
- "plugin_info": {
- "is_built_in": false,
- "plugin_type": "frequency",
- "components": [
- {
- "type": "command",
- "name": "set_talk_frequency",
- "description": "设置当前聊天的talk_frequency调整值",
- "pattern": "/chat talk_frequency <数字> 或 /chat t <数字>"
- },
- {
- "type": "command",
- "name": "show_frequency",
- "description": "显示当前聊天的频率控制状态",
- "pattern": "/chat show 或 /chat s"
- }
- ],
- "features": [
- "设置talk_frequency调整值",
- "调整当前聊天的发言频率",
- "显示当前频率控制状态",
- "实时频率控制调整",
- "命令执行反馈(不保存消息)",
- "支持完整命令和简化命令",
- "快速操作支持"
+ "i18n": {
+ "default_locale": "zh-CN",
+ "locales_path": "_locales",
+ "supported_locales": [
+ "zh-CN"
]
},
- "id": "SengokuCola.BetterFrequency"
-}
\ No newline at end of file
+ "id": "sengokucola.betterfrequency"
+}
diff --git a/plugins/ChatFrequency/plugin.py b/plugins/ChatFrequency/plugin.py
index b3f69384..0e9f5a0c 100644
--- a/plugins/ChatFrequency/plugin.py
+++ b/plugins/ChatFrequency/plugin.py
@@ -3,12 +3,18 @@
通过 /chat 命令设置和查看聊天频率。
"""
-from maibot_sdk import MaiBotPlugin, Command
+from maibot_sdk import Command, MaiBotPlugin
class BetterFrequencyPlugin(MaiBotPlugin):
"""聊天频率控制插件"""
+ async def on_load(self) -> None:
+ """处理插件加载。"""
+
+ async def on_unload(self) -> None:
+ """处理插件卸载。"""
+
@Command(
"set_talk_frequency",
description="设置当前聊天的talk_frequency值:/chat talk_frequency <数字> 或 /chat t <数字>",
@@ -80,6 +86,25 @@ class BetterFrequencyPlugin(MaiBotPlugin):
await self.ctx.send.text(status_msg, stream_id)
return True, None, False
+ async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None:
+ """处理配置热重载事件。
+
+ Args:
+ scope: 配置变更范围。
+ config_data: 最新配置数据。
+ version: 配置版本号。
+ """
+
+ del scope
+ del config_data
+ del version
+
+
+def create_plugin() -> BetterFrequencyPlugin:
+ """创建聊天频率插件实例。
+
+ Returns:
+ BetterFrequencyPlugin: 新的聊天频率插件实例。
+ """
-def create_plugin():
return BetterFrequencyPlugin()
diff --git a/plugins/MaiBot_MCPBridgePlugin/_manifest.json b/plugins/MaiBot_MCPBridgePlugin/_manifest.json
index 85225a43..d2e08ab4 100644
--- a/plugins/MaiBot_MCPBridgePlugin/_manifest.json
+++ b/plugins/MaiBot_MCPBridgePlugin/_manifest.json
@@ -1,67 +1,42 @@
{
- "manifest_version": 1,
- "name": "MCP桥接插件",
+ "manifest_version": 2,
"version": "2.0.0",
- "description": "将 MCP (Model Context Protocol) 服务器的工具桥接到 MaiBot,使麦麦能够调用外部 MCP 工具",
+ "name": "MCP桥接插件",
+ "description": "将 MCP (Model Context Protocol) 服务器的工具桥接到 MaiBot,使麦麦能够调用外部 MCP 工具。",
"author": {
"name": "CharTyr",
"url": "https://github.com/CharTyr"
},
"license": "AGPL-3.0",
- "host_application": {
- "min_version": "0.11.6"
+ "urls": {
+ "repository": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin",
+ "homepage": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin",
+ "documentation": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin",
+ "issues": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin/issues"
},
- "homepage_url": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin",
- "repository_url": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin",
- "keywords": [
- "mcp",
- "bridge",
- "tool",
- "integration",
- "resources",
- "prompts",
- "post-process",
- "cache",
- "trace",
- "permissions",
- "import",
- "export",
- "claude-desktop",
- "workflow",
- "react",
- "agent"
+ "host_application": {
+ "min_version": "0.11.6",
+ "max_version": "1.0.0"
+ },
+ "sdk": {
+ "min_version": "2.0.0",
+ "max_version": "2.99.99"
+ },
+ "dependencies": [
+ {
+ "type": "python_package",
+ "name": "mcp",
+ "version_spec": ">=0.0.0"
+ }
],
- "categories": [
- "工具扩展",
- "外部集成"
+ "capabilities": [
+ "send.text"
],
- "default_locale": "zh-CN",
- "plugin_info": {
- "is_built_in": false,
- "components": [],
- "features": [
- "支持多个 MCP 服务器",
- "自动发现并注册 MCP 工具",
- "支持 stdio、SSE、HTTP、Streamable HTTP 四种传输方式",
- "工具参数自动转换",
- "心跳检测与自动重连",
- "调用统计(次数、成功率、耗时)",
- "WebUI 配置支持",
- "Resources 支持(实验性)",
- "Prompts 支持(实验性)",
- "结果后处理(LLM 摘要提炼)",
- "工具禁用管理",
- "调用链路追踪",
- "工具调用缓存(LRU)",
- "工具权限控制(群/用户级别)",
- "配置导入导出(Claude Desktop mcpServers)",
- "断路器模式(故障快速失败)",
- "状态实时刷新",
- "Workflow 硬流程(顺序执行多个工具)",
- "Workflow 快速添加(表单式配置)",
- "ReAct 软流程(LLM 自主多轮调用)",
- "双轨制架构(软流程 + 硬流程)"
+ "i18n": {
+ "default_locale": "zh-CN",
+ "supported_locales": [
+ "zh-CN"
]
},
- "id": "MaiBot Community.MCPBridgePlugin"
+ "id": "chartyr.mcpbridge-plugin"
}
diff --git a/plugins/emoji_manage_plugin/_manifest.json b/plugins/emoji_manage_plugin/_manifest.json
index 3af69023..998cb7da 100644
--- a/plugins/emoji_manage_plugin/_manifest.json
+++ b/plugins/emoji_manage_plugin/_manifest.json
@@ -1,68 +1,44 @@
{
- "manifest_version": 1,
- "name": "BetterEmoji",
+ "manifest_version": 2,
"version": "2.0.0",
+ "name": "BetterEmoji",
"description": "更好的表情包管理插件",
"author": {
"name": "SengokuCola",
"url": "https://github.com/SengokuCola"
},
"license": "GPL-v3.0-or-later",
- "host_application": {
- "min_version": "1.0.0"
+ "urls": {
+ "repository": "https://github.com/SengokuCola/BetterEmoji",
+ "homepage": "https://github.com/SengokuCola/BetterEmoji",
+ "documentation": "https://github.com/SengokuCola/BetterEmoji",
+ "issues": "https://github.com/SengokuCola/BetterEmoji/issues"
},
- "homepage_url": "https://github.com/SengokuCola/BetterEmoji",
- "repository_url": "https://github.com/SengokuCola/BetterEmoji",
- "keywords": [
- "emoji",
- "manage",
- "plugin"
+ "host_application": {
+ "min_version": "1.0.0",
+ "max_version": "1.0.0"
+ },
+ "sdk": {
+ "min_version": "2.0.0",
+ "max_version": "2.99.99"
+ },
+ "dependencies": [],
+ "capabilities": [
+ "emoji.get_random",
+ "emoji.get_count",
+ "emoji.get_info",
+ "emoji.get_all",
+ "emoji.register_emoji",
+ "emoji.delete_emoji",
+ "send.text",
+ "send.forward"
],
- "categories": [
- "Emoji",
- "Management"
- ],
- "default_locale": "zh-CN",
- "locales_path": "_locales",
- "plugin_info": {
- "is_built_in": false,
- "plugin_type": "emoji_manage",
- "capabilities": [
- "emoji.get_random",
- "emoji.get_count",
- "emoji.get_info",
- "emoji.get_all",
- "emoji.register_emoji",
- "emoji.delete_emoji",
- "send.text",
- "send.forward"
- ],
- "components": [
- {
- "type": "command",
- "name": "add_emoji",
- "description": "添加表情包",
- "pattern": "/emoji add"
- },
- {
- "type": "command",
- "name": "emoji_list",
- "description": "列表表情包",
- "pattern": "/emoji list"
- },
- {
- "type": "command",
- "name": "delete_emoji",
- "description": "删除表情包",
- "pattern": "/emoji delete"
- },
- {
- "type": "command",
- "name": "random_emojis",
- "description": "发送多张随机表情包",
- "pattern": "/random_emojis"
- }
+ "i18n": {
+ "default_locale": "zh-CN",
+ "locales_path": "_locales",
+ "supported_locales": [
+ "zh-CN"
]
},
- "id": "SengokuCola.BetterEmoji"
-}
\ No newline at end of file
+ "id": "sengokucola.betteremoji"
+}
diff --git a/plugins/emoji_manage_plugin/plugin.py b/plugins/emoji_manage_plugin/plugin.py
index f3c5f677..9362c828 100644
--- a/plugins/emoji_manage_plugin/plugin.py
+++ b/plugins/emoji_manage_plugin/plugin.py
@@ -3,17 +3,23 @@
通过 /emoji 命令管理表情包的添加、列表和删除。
"""
+from maibot_sdk import Command, MaiBotPlugin
+
import base64
import datetime
import hashlib
import re
-from maibot_sdk import MaiBotPlugin, Command
-
class EmojiManagePlugin(MaiBotPlugin):
"""表情包管理插件"""
+ async def on_load(self) -> None:
+ """处理插件加载。"""
+
+ async def on_unload(self) -> None:
+ """处理插件卸载。"""
+
# ===== 工具方法 =====
@staticmethod
@@ -208,6 +214,25 @@ class EmojiManagePlugin(MaiBotPlugin):
await self.ctx.send.forward(messages, stream_id)
return True, "已发送随机表情包", True
+ async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None:
+ """处理配置热重载事件。
+
+ Args:
+ scope: 配置变更范围。
+ config_data: 最新配置数据。
+ version: 配置版本号。
+ """
+
+ del scope
+ del config_data
+ del version
+
+
+def create_plugin() -> EmojiManagePlugin:
+ """创建表情包管理插件实例。
+
+ Returns:
+ EmojiManagePlugin: 新的表情包管理插件实例。
+ """
-def create_plugin():
return EmojiManagePlugin()
diff --git a/plugins/hello_world_plugin/_manifest.json b/plugins/hello_world_plugin/_manifest.json
index dc9fc474..e2bc694d 100644
--- a/plugins/hello_world_plugin/_manifest.json
+++ b/plugins/hello_world_plugin/_manifest.json
@@ -1,88 +1,41 @@
{
- "manifest_version": 1,
- "name": "Hello World 示例插件 (Hello World Plugin)",
+ "manifest_version": 2,
"version": "2.0.0",
- "description": "我的第一个MaiCore插件,包含问候功能和时间查询等基础示例",
+ "name": "Hello World 示例插件 (Hello World Plugin)",
+ "description": "我的第一个 MaiCore 插件,包含问候功能和时间查询等基础示例",
"author": {
"name": "MaiBot开发团队",
"url": "https://github.com/MaiM-with-u"
},
"license": "GPL-v3.0-or-later",
- "host_application": {
- "min_version": "1.0.0"
+ "urls": {
+ "repository": "https://github.com/MaiM-with-u/maibot",
+ "homepage": "https://github.com/MaiM-with-u/maibot",
+ "documentation": "https://github.com/MaiM-with-u/maibot",
+ "issues": "https://github.com/MaiM-with-u/maibot/issues"
},
- "homepage_url": "https://github.com/MaiM-with-u/maibot",
- "repository_url": "https://github.com/MaiM-with-u/maibot",
- "keywords": [
- "demo",
- "example",
- "hello",
- "greeting",
- "tutorial"
+ "host_application": {
+ "min_version": "1.0.0",
+ "max_version": "1.0.0"
+ },
+ "sdk": {
+ "min_version": "2.0.0",
+ "max_version": "2.99.99"
+ },
+ "dependencies": [],
+ "capabilities": [
+ "send.text",
+ "send.forward",
+ "send.hybrid",
+ "emoji.get_random",
+ "config.get"
],
- "categories": [
- "Examples",
- "Tutorial"
- ],
- "default_locale": "zh-CN",
- "locales_path": "_locales",
- "plugin_info": {
- "is_built_in": false,
- "plugin_type": "example",
- "capabilities": [
- "send.text",
- "send.forward",
- "send.hybrid",
- "emoji.get_random",
- "config.get"
- ],
- "components": [
- {
- "type": "tool",
- "name": "compare_numbers",
- "description": "比较两个数的大小"
- },
- {
- "type": "action",
- "name": "hello_greeting",
- "description": "向用户发送问候消息"
- },
- {
- "type": "action",
- "name": "bye_greeting",
- "description": "向用户发送告别消息",
- "activation_modes": ["keyword"],
- "keywords": ["再见", "bye", "88", "拜拜"]
- },
- {
- "type": "command",
- "name": "time",
- "description": "查询当前时间",
- "pattern": "/time"
- },
- {
- "type": "command",
- "name": "random_emojis",
- "description": "发送多张随机表情包",
- "pattern": "/random_emojis"
- },
- {
- "type": "command",
- "name": "test",
- "description": "测试命令",
- "pattern": "/test"
- },
- {
- "type": "event_handler",
- "name": "print_message_handler",
- "description": "打印接收到的消息"
- },
- {
- "type": "event_handler",
- "name": "forward_messages_handler",
- "description": "把接收到的消息转发到指定聊天ID"
- }
+ "i18n": {
+ "default_locale": "zh-CN",
+ "locales_path": "_locales",
+ "supported_locales": [
+ "zh-CN"
]
},
- "id": "MaiBot开发团队.maibot"
-}
\ No newline at end of file
+ "id": "maibot-team.hello-world-plugin"
+}
diff --git a/plugins/hello_world_plugin/plugin.py b/plugins/hello_world_plugin/plugin.py
index fbba9d10..4d1f37af 100644
--- a/plugins/hello_world_plugin/plugin.py
+++ b/plugins/hello_world_plugin/plugin.py
@@ -3,16 +3,22 @@
你的第一个 MaiCore 插件,包含问候功能、时间查询等基础示例。
"""
+from maibot_sdk import Action, Command, EventHandler, MaiBotPlugin, Tool
+from maibot_sdk.types import ActivationType, EventType, ToolParameterInfo, ToolParamType
+
import datetime
import random
-from maibot_sdk import MaiBotPlugin, Action, Command, Tool, EventHandler
-from maibot_sdk.types import ActivationType, EventType, ToolParameterInfo, ToolParamType
-
class HelloWorldPlugin(MaiBotPlugin):
"""Hello World 示例插件"""
+ async def on_load(self) -> None:
+ """处理插件加载。"""
+
+ async def on_unload(self) -> None:
+ """处理插件卸载。"""
+
# ===== Tool 组件 =====
@Tool(
@@ -146,6 +152,25 @@ class HelloWorldPlugin(MaiBotPlugin):
return True, True, None, None, None
+ async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None:
+ """处理配置热重载事件。
+
+ Args:
+ scope: 配置变更范围。
+ config_data: 最新配置数据。
+ version: 配置版本号。
+ """
+
+ del scope
+ del config_data
+ del version
+
+
+def create_plugin() -> HelloWorldPlugin:
+ """创建 Hello World 示例插件实例。
+
+ Returns:
+ HelloWorldPlugin: 新的示例插件实例。
+ """
-def create_plugin():
return HelloWorldPlugin()
diff --git a/pyproject.toml b/pyproject.toml
index 70aa42cf..90135e04 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,7 +19,7 @@ dependencies = [
"jieba>=0.42.1",
"json-repair>=0.47.6",
"maim-message>=0.6.2",
- "maibot-plugin-sdk>=1.2.3,<2.0.0",
+ "maibot-plugin-sdk>=2.0.0",
"msgpack>=1.1.2",
"numpy>=2.2.6",
"openai>=1.95.0",
@@ -55,6 +55,8 @@ dev = [
[tool.uv]
index-url = "https://pypi.tuna.tsinghua.edu.cn/simple"
+[tool.uv.sources]
+maibot-plugin-sdk = { path = "packages/maibot-plugin-sdk", editable = true }
[tool.ruff]
diff --git a/pytests/common_test/test_expression_auto_check_task.py b/pytests/common_test/test_expression_auto_check_task.py
new file mode 100644
index 00000000..da8c59e1
--- /dev/null
+++ b/pytests/common_test/test_expression_auto_check_task.py
@@ -0,0 +1,89 @@
+"""测试表达方式自动检查任务的数据库读取行为。"""
+
+from contextlib import contextmanager
+from typing import Generator
+
+import pytest
+from sqlalchemy.pool import StaticPool
+from sqlmodel import Session, SQLModel, create_engine
+
+from src.bw_learner.expression_auto_check_task import ExpressionAutoCheckTask
+from src.common.database.database_model import Expression
+
+
+@pytest.fixture(name="expression_auto_check_engine")
+def expression_auto_check_engine_fixture() -> Generator:
+ """创建用于表达方式自动检查任务测试的内存数据库引擎。
+
+ Yields:
+ Generator: 供测试使用的 SQLite 内存引擎。
+ """
+
+ engine = create_engine(
+ "sqlite://",
+ connect_args={"check_same_thread": False},
+ poolclass=StaticPool,
+ )
+ SQLModel.metadata.create_all(engine)
+ yield engine
+
+
+@pytest.mark.asyncio
+async def test_select_expressions_uses_read_only_session(
+ monkeypatch: pytest.MonkeyPatch,
+ expression_auto_check_engine,
+) -> None:
+ """选择表达方式时应使用只读会话,并在离开会话后安全读取 ORM 字段。"""
+
+ import src.bw_learner.expression_auto_check_task as expression_auto_check_task_module
+
+ with Session(expression_auto_check_engine) as session:
+ session.add(
+ Expression(
+ situation="表达情绪高涨或生理反应",
+ style="发送💦表情符号",
+ content_list='["表达情绪高涨或生理反应"]',
+ count=1,
+ session_id="session-a",
+ checked=False,
+ rejected=False,
+ )
+ )
+ session.commit()
+
+ auto_commit_calls: list[bool] = []
+
+ @contextmanager
+ def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
+ """构造带自动提交语义的测试会话工厂。
+
+ Args:
+ auto_commit: 退出上下文时是否自动提交。
+
+ Yields:
+ Generator[Session, None, None]: SQLModel 会话对象。
+ """
+
+ auto_commit_calls.append(auto_commit)
+ session = Session(expression_auto_check_engine)
+ try:
+ yield session
+ if auto_commit:
+ session.commit()
+ except Exception:
+ session.rollback()
+ raise
+ finally:
+ session.close()
+
+ monkeypatch.setattr(expression_auto_check_task_module, "get_db_session", fake_get_db_session)
+ monkeypatch.setattr(expression_auto_check_task_module.random, "sample", lambda entries, _count: list(entries))
+
+ task = ExpressionAutoCheckTask()
+ expressions = await task._select_expressions(1)
+
+ assert auto_commit_calls == [False]
+ assert len(expressions) == 1
+ assert expressions[0].id is not None
+ assert expressions[0].situation == "表达情绪高涨或生理反应"
+ assert expressions[0].style == "发送💦表情符号"
diff --git a/pytests/common_test/test_expression_learner.py b/pytests/common_test/test_expression_learner.py
new file mode 100644
index 00000000..951aa424
--- /dev/null
+++ b/pytests/common_test/test_expression_learner.py
@@ -0,0 +1,81 @@
+"""测试表达方式学习器的数据库读取行为。"""
+
+from contextlib import contextmanager
+from typing import Generator
+
+import pytest
+from sqlalchemy.pool import StaticPool
+from sqlmodel import Session, SQLModel, create_engine
+
+from src.bw_learner.expression_learner import ExpressionLearner
+from src.common.database.database_model import Expression
+
+
+@pytest.fixture(name="expression_learner_engine")
+def expression_learner_engine_fixture() -> Generator:
+ """创建用于表达方式学习器测试的内存数据库引擎。
+
+ Yields:
+ Generator: 供测试使用的 SQLite 内存引擎。
+ """
+ engine = create_engine(
+ "sqlite://",
+ connect_args={"check_same_thread": False},
+ poolclass=StaticPool,
+ )
+ SQLModel.metadata.create_all(engine)
+ yield engine
+
+
+def test_find_similar_expression_uses_read_only_session_and_history_content(
+ monkeypatch: pytest.MonkeyPatch,
+ expression_learner_engine,
+) -> None:
+ """查找相似表达方式时,应能在离开会话后安全使用结果,并比较历史情景内容。"""
+ import src.bw_learner.expression_learner as expression_learner_module
+
+ with Session(expression_learner_engine) as session:
+ session.add(
+ Expression(
+ situation="发送汗滴表情",
+ style="发送💦表情符号",
+ content_list='["表达情绪高涨或生理反应"]',
+ count=1,
+ session_id="session-a",
+ checked=False,
+ rejected=False,
+ )
+ )
+ session.commit()
+
+ @contextmanager
+ def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
+ """构造带自动提交语义的测试会话工厂。
+
+ Args:
+ auto_commit: 退出上下文时是否自动提交。
+
+ Yields:
+ Generator[Session, None, None]: SQLModel 会话对象。
+ """
+ session = Session(expression_learner_engine)
+ try:
+ yield session
+ if auto_commit:
+ session.commit()
+ except Exception:
+ session.rollback()
+ raise
+ finally:
+ session.close()
+
+ monkeypatch.setattr(expression_learner_module, "get_db_session", fake_get_db_session)
+
+ learner = ExpressionLearner(session_id="session-a")
+ result = learner._find_similar_expression("表达情绪高涨或生理反应")
+
+ assert result is not None
+ expression, similarity = result
+ assert expression.item_id is not None
+ assert expression.style == "发送💦表情符号"
+ assert similarity == pytest.approx(1.0)
diff --git a/pytests/common_test/test_expression_schema.py b/pytests/common_test/test_expression_schema.py
new file mode 100644
index 00000000..31fcd98f
--- /dev/null
+++ b/pytests/common_test/test_expression_schema.py
@@ -0,0 +1,78 @@
+"""测试表达方式表结构和基础插入行为。"""
+
+from typing import Generator
+
+import pytest
+from sqlalchemy.pool import StaticPool
+from sqlmodel import Session, SQLModel, create_engine
+
+from src.common.database.database_model import Expression
+
+
+@pytest.fixture(name="expression_engine")
+def expression_engine_fixture() -> Generator:
+ """创建仅用于表达方式表测试的内存数据库引擎。
+
+ Yields:
+ Generator: 供测试使用的 SQLite 内存引擎。
+ """
+ engine = create_engine(
+ "sqlite://",
+ connect_args={"check_same_thread": False},
+ poolclass=StaticPool,
+ )
+ SQLModel.metadata.create_all(engine)
+ yield engine
+
+
+def test_expression_insert_assigns_auto_increment_id(expression_engine) -> None:
+ """表达方式表在新库中应能自动分配自增主键。"""
+ with Session(expression_engine) as session:
+ expression = Expression(
+ situation="表达情绪高涨或生理反应",
+ style="发送💦表情符号",
+ content_list='["表达情绪高涨或生理反应"]',
+ count=1,
+ session_id="session-a",
+ checked=False,
+ rejected=False,
+ )
+ session.add(expression)
+ session.commit()
+ session.refresh(expression)
+
+ assert expression.id is not None
+ assert expression.id > 0
+
+
+def test_expression_insert_allows_same_situation_style(expression_engine) -> None:
+ """相同情景和风格的表达方式记录不应再被错误绑定到复合主键。"""
+ with Session(expression_engine) as session:
+ first_expression = Expression(
+ situation="对重复行为的默契响应",
+ style="持续性跟发相同内容",
+ content_list='["对重复行为的默契响应"]',
+ count=1,
+ session_id="session-a",
+ checked=False,
+ rejected=False,
+ )
+ second_expression = Expression(
+ situation="对重复行为的默契响应",
+ style="持续性跟发相同内容",
+ content_list='["对重复行为的默契响应-变体"]',
+ count=2,
+ session_id="session-b",
+ checked=False,
+ rejected=False,
+ )
+
+ session.add(first_expression)
+ session.add(second_expression)
+ session.commit()
+ session.refresh(first_expression)
+ session.refresh(second_expression)
+
+ assert first_expression.id is not None
+ assert second_expression.id is not None
+ assert first_expression.id != second_expression.id
diff --git a/pytests/common_test/test_jargon_miner.py b/pytests/common_test/test_jargon_miner.py
new file mode 100644
index 00000000..bf81e4d2
--- /dev/null
+++ b/pytests/common_test/test_jargon_miner.py
@@ -0,0 +1,90 @@
+"""测试黑话学习器的数据库读取行为。"""
+
+from contextlib import contextmanager
+from typing import Generator
+
+import pytest
+from sqlalchemy.pool import StaticPool
+from sqlmodel import Session, SQLModel, create_engine, select
+
+from src.bw_learner.jargon_miner import JargonMiner
+from src.common.database.database_model import Jargon
+
+
+@pytest.fixture(name="jargon_miner_engine")
+def jargon_miner_engine_fixture() -> Generator:
+ """创建用于黑话学习器测试的内存数据库引擎。
+
+ Yields:
+ Generator: 供测试使用的 SQLite 内存引擎。
+ """
+ engine = create_engine(
+ "sqlite://",
+ connect_args={"check_same_thread": False},
+ poolclass=StaticPool,
+ )
+ SQLModel.metadata.create_all(engine)
+ yield engine
+
+
+@pytest.mark.asyncio
+async def test_process_extracted_entries_updates_existing_jargon_without_detached_session(
+ monkeypatch: pytest.MonkeyPatch,
+ jargon_miner_engine,
+) -> None:
+ """更新已有黑话时,不应因会话关闭导致 ORM 实例失效。"""
+ import src.bw_learner.jargon_miner as jargon_miner_module
+
+ with Session(jargon_miner_engine) as session:
+ session.add(
+ Jargon(
+ content="VF8V4L",
+ raw_content='["[1] first"]',
+ meaning="",
+ session_id_dict='{"session-a": 1}',
+ count=0,
+ is_jargon=True,
+ is_complete=False,
+ is_global=False,
+ last_inference_count=0,
+ )
+ )
+ session.commit()
+
+ @contextmanager
+ def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
+ """构造带自动提交语义的测试会话工厂。
+
+ Args:
+ auto_commit: 退出上下文时是否自动提交。
+
+ Yields:
+ Generator[Session, None, None]: SQLModel 会话对象。
+ """
+ session = Session(jargon_miner_engine)
+ try:
+ yield session
+ if auto_commit:
+ session.commit()
+ except Exception:
+ session.rollback()
+ raise
+ finally:
+ session.close()
+
+ monkeypatch.setattr(jargon_miner_module, "get_db_session", fake_get_db_session)
+
+ jargon_miner = JargonMiner(session_id="session-a", session_name="测试群")
+ await jargon_miner.process_extracted_entries(
+ [{"content": "VF8V4L", "raw_content": {"[2] second"}}],
+ )
+
+ with Session(jargon_miner_engine) as session:
+ db_jargon = session.exec(select(Jargon).where(Jargon.content == "VF8V4L")).one()
+
+ assert db_jargon.count == 1
+ assert db_jargon.session_id_dict == '{"session-a": 2}'
+ assert sorted(db_jargon.raw_content and __import__("json").loads(db_jargon.raw_content)) == [
+ "[1] first",
+ "[2] second",
+ ]
diff --git a/pytests/common_test/test_jargon_schema.py b/pytests/common_test/test_jargon_schema.py
new file mode 100644
index 00000000..909392ab
--- /dev/null
+++ b/pytests/common_test/test_jargon_schema.py
@@ -0,0 +1,84 @@
+"""测试黑话表结构和基础插入行为。"""
+
+from typing import Generator
+
+import pytest
+from sqlalchemy.pool import StaticPool
+from sqlmodel import Session, SQLModel, create_engine
+
+from src.common.database.database_model import Jargon
+
+
+@pytest.fixture(name="jargon_engine")
+def jargon_engine_fixture() -> Generator:
+ """创建仅用于黑话表测试的内存数据库引擎。
+
+ Yields:
+ Generator: 供测试使用的 SQLite 内存引擎。
+ """
+ engine = create_engine(
+ "sqlite://",
+ connect_args={"check_same_thread": False},
+ poolclass=StaticPool,
+ )
+ SQLModel.metadata.create_all(engine)
+ yield engine
+
+
+def test_jargon_insert_assigns_auto_increment_id(jargon_engine) -> None:
+ """黑话表在新库中应能自动分配自增主键。"""
+ with Session(jargon_engine) as session:
+ jargon = Jargon(
+ content="VF8V4L",
+ raw_content='["[1] test"]',
+ meaning="",
+ session_id_dict='{"session-a": 1}',
+ count=1,
+ is_jargon=True,
+ is_complete=False,
+ is_global=True,
+ last_inference_count=0,
+ )
+ session.add(jargon)
+ session.commit()
+ session.refresh(jargon)
+
+ assert jargon.id is not None
+ assert jargon.id > 0
+
+
+def test_jargon_insert_allows_same_content_with_different_rows(jargon_engine) -> None:
+ """黑话内容不应再被错误地绑成复合主键的一部分。"""
+ with Session(jargon_engine) as session:
+ first_jargon = Jargon(
+ content="表情1",
+ raw_content='["[1] first"]',
+ meaning="",
+ session_id_dict='{"session-a": 1}',
+ count=1,
+ is_jargon=True,
+ is_complete=False,
+ is_global=False,
+ last_inference_count=0,
+ )
+ second_jargon = Jargon(
+ content="表情1",
+ raw_content='["[1] second"]',
+ meaning="",
+ session_id_dict='{"session-b": 1}',
+ count=1,
+ is_jargon=True,
+ is_complete=False,
+ is_global=False,
+ last_inference_count=0,
+ )
+
+ session.add(first_jargon)
+ session.add(second_jargon)
+ session.commit()
+ session.refresh(first_jargon)
+ session.refresh(second_jargon)
+
+ assert first_jargon.id is not None
+ assert second_jargon.id is not None
+ assert first_jargon.id != second_jargon.id
diff --git a/pytests/common_test/test_person_info_group_cardname.py b/pytests/common_test/test_person_info_group_cardname.py
new file mode 100644
index 00000000..62a63f43
--- /dev/null
+++ b/pytests/common_test/test_person_info_group_cardname.py
@@ -0,0 +1,355 @@
+"""人物信息群名片字段兼容测试。"""
+
+from __future__ import annotations
+
+from importlib.util import module_from_spec, spec_from_file_location
+from pathlib import Path
+from types import ModuleType, SimpleNamespace
+from typing import Any
+
+import json
+import sys
+
+import pytest
+
+from src.common.data_models.person_info_data_model import dump_group_cardname_records, parse_group_cardname_json
+
+
+class _DummyLogger:
+ """模拟日志记录器。"""
+
+ def debug(self, message: str) -> None:
+ """记录调试日志。
+
+ Args:
+ message: 日志内容。
+ """
+ del message
+
+ def info(self, message: str) -> None:
+ """记录信息日志。
+
+ Args:
+ message: 日志内容。
+ """
+ del message
+
+ def warning(self, message: str) -> None:
+ """记录警告日志。
+
+ Args:
+ message: 日志内容。
+ """
+ del message
+
+ def error(self, message: str) -> None:
+ """记录错误日志。
+
+ Args:
+ message: 日志内容。
+ """
+ del message
+
+
+class _DummyStatement:
+ """模拟 SQL 查询语句对象。"""
+
+ def where(self, condition: Any) -> "_DummyStatement":
+ """附加过滤条件。
+
+ Args:
+ condition: 过滤条件。
+
+ Returns:
+ _DummyStatement: 当前语句对象。
+ """
+ del condition
+ return self
+
+ def limit(self, value: int) -> "_DummyStatement":
+ """限制返回条数。
+
+ Args:
+ value: 条数限制。
+
+ Returns:
+ _DummyStatement: 当前语句对象。
+ """
+ del value
+ return self
+
+
+class _DummyColumn:
+ """模拟 SQLModel 列对象。"""
+
+ def is_not(self, value: Any) -> "_DummyColumn":
+ """模拟 `IS NOT` 条件构造。
+
+ Args:
+ value: 比较值。
+
+ Returns:
+ _DummyColumn: 当前列对象。
+ """
+ del value
+ return self
+
+ def __eq__(self, other: Any) -> "_DummyColumn":
+ """模拟等值条件构造。
+
+ Args:
+ other: 比较值。
+
+ Returns:
+ _DummyColumn: 当前列对象。
+ """
+ del other
+ return self
+
+
+class _DummyResult:
+ """模拟数据库查询结果。"""
+
+ def __init__(self, record: Any) -> None:
+ """初始化查询结果。
+
+ Args:
+ record: 待返回的首条记录。
+ """
+ self._record = record
+
+ def first(self) -> Any:
+ """返回第一条记录。
+
+ Returns:
+ Any: 首条记录。
+ """
+ return self._record
+
+ def all(self) -> list[Any]:
+ """返回全部结果。
+
+ Returns:
+ list[Any]: 结果列表。
+ """
+ if self._record is None:
+ return []
+ return self._record if isinstance(self._record, list) else [self._record]
+
+
+class _DummySession:
+ """模拟数据库 Session。"""
+
+ def __init__(self, record: Any) -> None:
+ """初始化 Session。
+
+ Args:
+ record: `first()` 应返回的记录。
+ """
+ self.record = record
+ self.added_records: list[Any] = []
+
+ def __enter__(self) -> "_DummySession":
+ """进入上下文管理器。
+
+ Returns:
+ _DummySession: 当前 Session。
+ """
+ return self
+
+ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+ """退出上下文管理器。
+
+ Args:
+ exc_type: 异常类型。
+ exc_val: 异常值。
+ exc_tb: 异常回溯。
+ """
+ del exc_type
+ del exc_val
+ del exc_tb
+
+ def exec(self, statement: Any) -> _DummyResult:
+ """执行查询。
+
+ Args:
+ statement: 查询语句。
+
+ Returns:
+ _DummyResult: 模拟结果对象。
+ """
+ del statement
+ return _DummyResult(self.record)
+
+ def add(self, record: Any) -> None:
+ """记录被添加的对象。
+
+ Args:
+ record: 被写入 Session 的对象。
+ """
+ self.added_records.append(record)
+
+
+class _DummyPersonInfoRecord:
+ """模拟 `PersonInfo` ORM 模型。"""
+
+ person_id = "person_id"
+ person_name = "person_name"
+
+ def __init__(self, **kwargs: Any) -> None:
+ """使用关键字参数初始化记录对象。
+
+ Args:
+ **kwargs: 字段值。
+ """
+ for key, value in kwargs.items():
+ setattr(self, key, value)
+
+
+def _load_person_module(monkeypatch: pytest.MonkeyPatch, session: _DummySession) -> ModuleType:
+ """加载带依赖桩的 `person_info` 模块。
+
+ Args:
+ monkeypatch: Pytest monkeypatch 工具。
+ session: 提供给模块使用的假数据库 Session。
+
+ Returns:
+ ModuleType: 加载后的模块对象。
+ """
+ logger_module = ModuleType("src.common.logger")
+ logger_module.get_logger = lambda name: _DummyLogger()
+ monkeypatch.setitem(sys.modules, "src.common.logger", logger_module)
+
+ database_module = ModuleType("src.common.database.database")
+ database_module.get_db_session = lambda: session
+ monkeypatch.setitem(sys.modules, "src.common.database.database", database_module)
+
+ database_model_module = ModuleType("src.common.database.database_model")
+ database_model_module.PersonInfo = _DummyPersonInfoRecord
+ monkeypatch.setitem(sys.modules, "src.common.database.database_model", database_model_module)
+
+ llm_module = ModuleType("src.llm_models.utils_model")
+
+ class _DummyLLMRequest:
+ """模拟 LLMRequest。"""
+
+ def __init__(self, model_set: Any, request_type: str) -> None:
+ """初始化假请求对象。
+
+ Args:
+ model_set: 模型配置。
+ request_type: 请求类型。
+ """
+ del model_set
+ del request_type
+
+ llm_module.LLMRequest = _DummyLLMRequest
+ monkeypatch.setitem(sys.modules, "src.llm_models.utils_model", llm_module)
+
+ config_module = ModuleType("src.config.config")
+ config_module.global_config = SimpleNamespace(bot=SimpleNamespace(nickname="MaiBot"))
+ config_module.model_config = SimpleNamespace(model_task_config=SimpleNamespace(tool_use="tool_use", utils="utils"))
+ monkeypatch.setitem(sys.modules, "src.config.config", config_module)
+
+ chat_manager_module = ModuleType("src.chat.message_receive.chat_manager")
+ chat_manager_module.chat_manager = SimpleNamespace()
+ monkeypatch.setitem(sys.modules, "src.chat.message_receive.chat_manager", chat_manager_module)
+
+ module_path = Path(__file__).resolve().parents[2] / "src" / "person_info" / "person_info.py"
+ spec = spec_from_file_location("person_info_group_cardname_test_module", module_path)
+ assert spec is not None and spec.loader is not None
+
+ module = module_from_spec(spec)
+ monkeypatch.setitem(sys.modules, spec.name, module)
+ spec.loader.exec_module(module)
+
+ monkeypatch.setattr(module, "select", lambda *args: _DummyStatement())
+ monkeypatch.setattr(module, "col", lambda field: _DummyColumn())
+ return module
+
+
+def test_parse_group_cardname_json_uses_canonical_key() -> None:
+ """群名片 JSON 解析应只使用 `group_cardname` 键名。"""
+ parsed = parse_group_cardname_json(
+ json.dumps(
+ [
+ {"group_id": "1001", "group_cardname": "现行字段"},
+ ],
+ ensure_ascii=False,
+ )
+ )
+
+ assert parsed is not None
+ assert [(item.group_id, item.group_cardname) for item in parsed] == [
+ ("1001", "现行字段"),
+ ]
+
+
+def test_dump_group_cardname_records_uses_canonical_key() -> None:
+ """群名片序列化应输出 `group_cardname` 键名。"""
+ dumped = dump_group_cardname_records(
+ [
+ {"group_id": "1001", "group_cardname": "群昵称"},
+ ]
+ )
+
+ assert json.loads(dumped) == [{"group_id": "1001", "group_cardname": "群昵称"}]
+
+
+def test_person_sync_to_database_uses_group_cardname_field(monkeypatch: pytest.MonkeyPatch) -> None:
+ """同步人物信息时应写入数据库模型的 `group_cardname` 字段。"""
+ record = _DummyPersonInfoRecord()
+ session = _DummySession(record)
+ module = _load_person_module(monkeypatch, session)
+
+ person = module.Person.__new__(module.Person)
+ person.is_known = True
+ person.person_id = "person-1"
+ person.platform = "qq"
+ person.user_id = "10001"
+ person.nickname = "看番的龙"
+ person.person_name = "看番的龙"
+ person.name_reason = "测试"
+ person.know_times = 1
+ person.know_since = 1700000000.0
+ person.last_know = 1700000100.0
+ person.memory_points = ["喜好:番剧:0.8"]
+ person.group_cardname_list = [{"group_id": "20001", "group_cardname": "白泽大人"}]
+
+ person.sync_to_database()
+
+ assert record.group_cardname == '[{"group_id": "20001", "group_cardname": "白泽大人"}]'
+ assert not hasattr(record, "group_nickname")
+
+
+def test_person_load_from_database_normalizes_group_cardname_payload(monkeypatch: pytest.MonkeyPatch) -> None:
+ """从数据库加载人物信息时应读取标准 `group_cardname` 结构。"""
+ record = _DummyPersonInfoRecord(
+ user_id="10001",
+ platform="qq",
+ is_known=True,
+ user_nickname="看番的龙",
+ person_name="看番的龙",
+ name_reason=None,
+ know_counts=2,
+ memory_points='["喜好:番剧:0.8"]',
+ group_cardname=json.dumps(
+ [
+ {"group_id": "20001", "group_cardname": "白泽大人"},
+ ],
+ ensure_ascii=False,
+ ),
+ )
+ session = _DummySession(record)
+ module = _load_person_module(monkeypatch, session)
+
+ person = module.Person.__new__(module.Person)
+ person.person_id = "person-1"
+ person.memory_points = []
+ person.group_cardname_list = []
+
+ person.load_from_database()
+
+ assert person.group_cardname_list == [
+ {"group_id": "20001", "group_cardname": "白泽大人"},
+ ]
diff --git a/pytests/test_message_gateway_runtime.py b/pytests/test_message_gateway_runtime.py
new file mode 100644
index 00000000..9650bc10
--- /dev/null
+++ b/pytests/test_message_gateway_runtime.py
@@ -0,0 +1,170 @@
+"""消息网关运行时状态同步测试。"""
+
+from typing import Any, Dict
+
+import pytest
+
+from src.platform_io.manager import PlatformIOManager
+from src.platform_io.types import RouteKey
+from src.plugin_runtime.host.supervisor import PluginSupervisor
+from src.plugin_runtime.protocol.envelope import Envelope, MessageType
+
+
+def _make_request(method: str, plugin_id: str, payload: Dict[str, Any]) -> Envelope:
+ """构造一个 RPC 请求信封。
+
+ Args:
+ method: RPC 方法名。
+ plugin_id: 目标插件 ID。
+ payload: 请求载荷。
+
+ Returns:
+ Envelope: 标准 RPC 请求信封。
+ """
+
+ return Envelope(
+ request_id=1,
+ message_type=MessageType.REQUEST,
+ method=method,
+ plugin_id=plugin_id,
+ payload=payload,
+ )
+
+
+@pytest.mark.asyncio
+async def test_message_gateway_runtime_state_binds_send_and_receive_routes(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """消息网关就绪后应同时绑定发送表和接收表。"""
+
+ import src.plugin_runtime.host.supervisor as supervisor_module
+
+ platform_io_manager = PlatformIOManager()
+ monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager)
+
+ supervisor = PluginSupervisor(plugin_dirs=[])
+ register_response = await supervisor._handle_register_plugin(
+ _make_request(
+ "plugin.register_components",
+ "napcat_plugin",
+ {
+ "plugin_id": "napcat_plugin",
+ "plugin_version": "1.0.0",
+ "components": [
+ {
+ "name": "napcat_gateway",
+ "component_type": "MESSAGE_GATEWAY",
+ "plugin_id": "napcat_plugin",
+ "metadata": {
+ "route_type": "duplex",
+ "platform": "qq",
+ "protocol": "napcat",
+ },
+ }
+ ],
+ "capabilities_required": [],
+ },
+ )
+ )
+
+ assert register_response.error is None
+ response = await supervisor._handle_update_message_gateway_state(
+ _make_request(
+ "host.update_message_gateway_state",
+ "napcat_plugin",
+ {
+ "gateway_name": "napcat_gateway",
+ "ready": True,
+ "platform": "qq",
+ "account_id": "10001",
+ "scope": "primary",
+ "metadata": {},
+ },
+ )
+ )
+
+ assert response.error is None
+ assert response.payload["accepted"] is True
+
+ send_bindings = platform_io_manager.send_route_table.resolve_bindings(
+ RouteKey(platform="qq", account_id="10001", scope="primary")
+ )
+ receive_bindings = platform_io_manager.receive_route_table.resolve_bindings(
+ RouteKey(platform="qq", account_id="10001", scope="primary")
+ )
+
+ assert [binding.driver_id for binding in send_bindings] == ["gateway:napcat_plugin:napcat_gateway"]
+ assert [binding.driver_id for binding in receive_bindings] == ["gateway:napcat_plugin:napcat_gateway"]
+
+
+@pytest.mark.asyncio
+async def test_message_gateway_runtime_state_unbinds_routes_when_not_ready(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """消息网关断开后应撤销发送表和接收表中的绑定。"""
+
+ import src.plugin_runtime.host.supervisor as supervisor_module
+
+ platform_io_manager = PlatformIOManager()
+ monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager)
+
+ supervisor = PluginSupervisor(plugin_dirs=[])
+ await supervisor._handle_register_plugin(
+ _make_request(
+ "plugin.register_components",
+ "napcat_plugin",
+ {
+ "plugin_id": "napcat_plugin",
+ "plugin_version": "1.0.0",
+ "components": [
+ {
+ "name": "napcat_gateway",
+ "component_type": "MESSAGE_GATEWAY",
+ "plugin_id": "napcat_plugin",
+ "metadata": {
+ "route_type": "duplex",
+ "platform": "qq",
+ "protocol": "napcat",
+ },
+ }
+ ],
+ "capabilities_required": [],
+ },
+ )
+ )
+
+ await supervisor._handle_update_message_gateway_state(
+ _make_request(
+ "host.update_message_gateway_state",
+ "napcat_plugin",
+ {
+ "gateway_name": "napcat_gateway",
+ "ready": True,
+ "platform": "qq",
+ "account_id": "10001",
+ "scope": "primary",
+ "metadata": {},
+ },
+ )
+ )
+ response = await supervisor._handle_update_message_gateway_state(
+ _make_request(
+ "host.update_message_gateway_state",
+ "napcat_plugin",
+ {
+ "gateway_name": "napcat_gateway",
+ "ready": False,
+ "platform": "qq",
+ "account_id": "",
+ "scope": "",
+ "metadata": {},
+ },
+ )
+ )
+
+ assert response.error is None
+ assert response.payload["accepted"] is True
+ assert platform_io_manager.send_route_table.resolve_bindings(RouteKey(platform="qq", account_id="10001")) == []
+ assert (
+ platform_io_manager.receive_route_table.resolve_bindings(RouteKey(platform="qq", account_id="10001")) == []
+ )
diff --git a/pytests/test_napcat_adapter_sdk.py b/pytests/test_napcat_adapter_sdk.py
new file mode 100644
index 00000000..c6b1fdbd
--- /dev/null
+++ b/pytests/test_napcat_adapter_sdk.py
@@ -0,0 +1,132 @@
+"""NapCat 插件与新 SDK 对接测试。"""
+
+from pathlib import Path
+from typing import Any, Dict, List
+
+import importlib
+import logging
+import sys
+
+import pytest
+
+PROJECT_ROOT = Path(__file__).resolve().parents[1]
+PLUGINS_ROOT = PROJECT_ROOT / "plugins"
+SDK_ROOT = PROJECT_ROOT / "packages" / "maibot-plugin-sdk"
+
+for import_path in (str(PLUGINS_ROOT), str(SDK_ROOT)):
+ if import_path not in sys.path:
+ sys.path.insert(0, import_path)
+
+
+class _FakeGatewayCapability:
+ """用于捕获消息网关状态上报的测试替身。"""
+
+ def __init__(self) -> None:
+ """初始化测试替身。"""
+
+ self.calls: List[Dict[str, Any]] = []
+
+ async def update_state(
+ self,
+ gateway_name: str,
+ *,
+ ready: bool,
+ platform: str = "",
+ account_id: str = "",
+ scope: str = "",
+ metadata: Dict[str, Any] | None = None,
+ ) -> bool:
+ """记录一次状态上报请求。
+
+ Args:
+ gateway_name: 网关组件名称。
+ ready: 当前是否就绪。
+ platform: 平台名称。
+ account_id: 账号 ID。
+ scope: 路由作用域。
+ metadata: 附加元数据。
+
+ Returns:
+ bool: 始终返回 ``True``,模拟 Host 接受状态更新。
+ """
+
+ self.calls.append(
+ {
+ "gateway_name": gateway_name,
+ "ready": ready,
+ "platform": platform,
+ "account_id": account_id,
+ "scope": scope,
+ "metadata": metadata or {},
+ }
+ )
+ return True
+
+
+def _load_napcat_sdk_symbols() -> tuple[Any, Any, Any, Any]:
+ """动态加载 NapCat 插件测试所需的符号。
+
+ Returns:
+ tuple[Any, Any, Any, Any]:
+ 依次返回网关名常量、配置类、插件类和运行时状态管理器类。
+ """
+
+ constants_module = importlib.import_module("napcat_adapter.constants")
+ config_module = importlib.import_module("napcat_adapter.config")
+ plugin_module = importlib.import_module("napcat_adapter.plugin")
+ runtime_state_module = importlib.import_module("napcat_adapter.runtime_state")
+ return (
+ constants_module.NAPCAT_GATEWAY_NAME,
+ config_module.NapCatServerConfig,
+ plugin_module.NapCatAdapterPlugin,
+ runtime_state_module.NapCatRuntimeStateManager,
+ )
+
+
+def test_napcat_plugin_collects_duplex_message_gateway() -> None:
+ """NapCat 插件应声明新的双工消息网关组件。"""
+
+ napcat_gateway_name, _napcat_server_config, napcat_plugin_cls, _runtime_state_cls = _load_napcat_sdk_symbols()
+ plugin = napcat_plugin_cls()
+ components = plugin.get_components()
+ gateway_components = [
+ component
+ for component in components
+ if component.get("type") == "MESSAGE_GATEWAY"
+ ]
+
+ assert len(gateway_components) == 1
+ gateway_component = gateway_components[0]
+ assert gateway_component["name"] == napcat_gateway_name
+ assert gateway_component["metadata"]["route_type"] == "duplex"
+ assert gateway_component["metadata"]["platform"] == "qq"
+ assert gateway_component["metadata"]["protocol"] == "napcat"
+
+
+@pytest.mark.asyncio
+async def test_runtime_state_reports_via_gateway_capability() -> None:
+ """NapCat 运行时状态应通过新的消息网关能力上报。"""
+
+ napcat_gateway_name, napcat_server_config_cls, _napcat_plugin_cls, runtime_state_cls = _load_napcat_sdk_symbols()
+ gateway_capability = _FakeGatewayCapability()
+ runtime_state_manager = runtime_state_cls(
+ gateway_capability=gateway_capability,
+ logger=logging.getLogger("test.napcat_adapter"),
+ gateway_name=napcat_gateway_name,
+ )
+
+ connected = await runtime_state_manager.report_connected(
+ "10001",
+ napcat_server_config_cls(connection_id="primary"),
+ )
+ await runtime_state_manager.report_disconnected()
+
+ assert connected is True
+ assert gateway_capability.calls[0]["gateway_name"] == napcat_gateway_name
+ assert gateway_capability.calls[0]["ready"] is True
+ assert gateway_capability.calls[0]["platform"] == "qq"
+ assert gateway_capability.calls[0]["account_id"] == "10001"
+ assert gateway_capability.calls[0]["scope"] == "primary"
+ assert gateway_capability.calls[1]["gateway_name"] == napcat_gateway_name
+ assert gateway_capability.calls[1]["ready"] is False
+ assert gateway_capability.calls[1]["platform"] == "qq"
diff --git a/pytests/test_platform_io_dedupe.py b/pytests/test_platform_io_dedupe.py
new file mode 100644
index 00000000..d6bdd1dd
--- /dev/null
+++ b/pytests/test_platform_io_dedupe.py
@@ -0,0 +1,209 @@
+"""Platform IO 入站去重策略测试。"""
+
+from types import SimpleNamespace
+from typing import Any, Dict, List, Optional
+
+import pytest
+
+from src.platform_io.drivers.base import PlatformIODriver
+from src.platform_io.manager import PlatformIOManager
+from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey
+
+
+def _build_envelope(
+ *,
+ dedupe_key: str | None = None,
+ external_message_id: str | None = None,
+ session_message_id: str | None = None,
+ payload: Optional[Dict[str, Any]] = None,
+) -> InboundMessageEnvelope:
+ """构造测试用入站信封。
+
+ Args:
+ dedupe_key: 显式去重键。
+ external_message_id: 平台侧消息 ID。
+ session_message_id: 规范化消息对象上的消息 ID。
+ payload: 原始载荷。
+
+ Returns:
+ InboundMessageEnvelope: 测试用入站消息信封。
+ """
+ session_message = None
+ if session_message_id is not None:
+ session_message = SimpleNamespace(message_id=session_message_id)
+
+ return InboundMessageEnvelope(
+ route_key=RouteKey(platform="qq", account_id="10001", scope="main"),
+ driver_id="plugin.napcat",
+ driver_kind=DriverKind.PLUGIN,
+ dedupe_key=dedupe_key,
+ external_message_id=external_message_id,
+ session_message=session_message,
+ payload=payload,
+ )
+
+
+class _StubPlatformIODriver(PlatformIODriver):
+ """测试用 Platform IO 驱动。"""
+
+ async def send_message(
+ self,
+ message: Any,
+ route_key: RouteKey,
+ metadata: Optional[Dict[str, Any]] = None,
+ ) -> DeliveryReceipt:
+ """返回一个固定的成功回执。
+
+ Args:
+ message: 待发送的消息对象。
+ route_key: 本次发送使用的路由键。
+ metadata: 额外发送元数据。
+
+ Returns:
+ DeliveryReceipt: 固定的成功回执。
+ """
+ return DeliveryReceipt(
+ internal_message_id=str(getattr(message, "message_id", "stub-message-id")),
+ route_key=route_key,
+ status=DeliveryStatus.SENT,
+ driver_id=self.driver_id,
+ driver_kind=self.descriptor.kind,
+ )
+
+
+def _build_manager() -> PlatformIOManager:
+ """构造带有最小接收路由的 Broker 管理器。
+
+ Returns:
+ PlatformIOManager: 已注册测试驱动并绑定接收路由的 Broker。
+ """
+ manager = PlatformIOManager()
+ driver = _StubPlatformIODriver(
+ DriverDescriptor(
+ driver_id="plugin.napcat",
+ kind=DriverKind.PLUGIN,
+ platform="qq",
+ account_id="10001",
+ scope="main",
+ )
+ )
+ manager.register_driver(driver)
+ manager.bind_receive_route(
+ RouteBinding(
+ route_key=RouteKey(platform="qq", account_id="10001", scope="main"),
+ driver_id=driver.driver_id,
+ driver_kind=driver.descriptor.kind,
+ )
+ )
+ return manager
+
+
+class TestPlatformIODedupe:
+ """Platform IO 去重测试。"""
+
+ @pytest.mark.asyncio
+ async def test_accept_inbound_dedupes_by_external_message_id(self) -> None:
+ """相同平台消息 ID 的重复入站应被抑制。"""
+ manager = _build_manager()
+ accepted_envelopes: List[InboundMessageEnvelope] = []
+
+ async def dispatcher(envelope: InboundMessageEnvelope) -> None:
+ """记录被成功接收的入站消息。
+
+ Args:
+ envelope: 被 Broker 接受的入站消息。
+ """
+ accepted_envelopes.append(envelope)
+
+ manager.set_inbound_dispatcher(dispatcher)
+
+ first_envelope = _build_envelope(
+ external_message_id="msg-1",
+ payload={"message": "hello"},
+ )
+ second_envelope = _build_envelope(
+ external_message_id="msg-1",
+ payload={"message": "hello"},
+ )
+
+ assert await manager.accept_inbound(first_envelope) is True
+ assert await manager.accept_inbound(second_envelope) is False
+ assert len(accepted_envelopes) == 1
+
+ @pytest.mark.asyncio
+ async def test_accept_inbound_without_stable_identity_does_not_guess_duplicate(self) -> None:
+ """缺少稳定身份时,不应仅凭 payload 内容猜测重复消息。"""
+ manager = _build_manager()
+ accepted_envelopes: List[InboundMessageEnvelope] = []
+
+ async def dispatcher(envelope: InboundMessageEnvelope) -> None:
+ """记录被成功接收的入站消息。
+
+ Args:
+ envelope: 被 Broker 接受的入站消息。
+ """
+ accepted_envelopes.append(envelope)
+
+ manager.set_inbound_dispatcher(dispatcher)
+
+ first_envelope = _build_envelope(payload={"message": "same-payload"})
+ second_envelope = _build_envelope(payload={"message": "same-payload"})
+
+ assert await manager.accept_inbound(first_envelope) is True
+ assert await manager.accept_inbound(second_envelope) is True
+ assert len(accepted_envelopes) == 2
+
+ def test_build_inbound_dedupe_key_prefers_explicit_identity(self) -> None:
+ """去重键应只来自显式或稳定的技术身份。"""
+ explicit_envelope = _build_envelope(dedupe_key="dedupe-1", external_message_id="msg-1")
+ session_message_envelope = _build_envelope(session_message_id="session-1")
+ payload_only_envelope = _build_envelope(payload={"message": "hello"})
+
+ assert PlatformIOManager._build_inbound_dedupe_key(explicit_envelope) == "plugin.napcat:dedupe-1"
+ assert PlatformIOManager._build_inbound_dedupe_key(session_message_envelope) == "plugin.napcat:session-1"
+ assert PlatformIOManager._build_inbound_dedupe_key(payload_only_envelope) is None
+
+ @pytest.mark.asyncio
+ async def test_send_message_fans_out_to_all_matching_routes(self) -> None:
+ """同一路由命中多条发送链路时应全部发送。"""
+
+ manager = PlatformIOManager()
+ first_driver = _StubPlatformIODriver(
+ DriverDescriptor(
+ driver_id="plugin.gateway_a",
+ kind=DriverKind.PLUGIN,
+ platform="qq",
+ )
+ )
+ second_driver = _StubPlatformIODriver(
+ DriverDescriptor(
+ driver_id="plugin.gateway_b",
+ kind=DriverKind.PLUGIN,
+ platform="qq",
+ )
+ )
+ manager.register_driver(first_driver)
+ manager.register_driver(second_driver)
+ manager.bind_send_route(
+ RouteBinding(
+ route_key=RouteKey(platform="qq"),
+ driver_id=first_driver.driver_id,
+ driver_kind=first_driver.descriptor.kind,
+ )
+ )
+ manager.bind_send_route(
+ RouteBinding(
+ route_key=RouteKey(platform="qq"),
+ driver_id=second_driver.driver_id,
+ driver_kind=second_driver.descriptor.kind,
+ )
+ )
+
+ message = SimpleNamespace(message_id="internal-msg-1")
+ result = await manager.send_message(message, RouteKey(platform="qq"))
+
+ assert result.has_success is True
+ assert [receipt.driver_id for receipt in result.sent_receipts] == [
+ "plugin.gateway_a",
+ "plugin.gateway_b",
+ ]
diff --git a/pytests/test_platform_io_legacy_driver.py b/pytests/test_platform_io_legacy_driver.py
new file mode 100644
index 00000000..76f14d8f
--- /dev/null
+++ b/pytests/test_platform_io_legacy_driver.py
@@ -0,0 +1,178 @@
+"""Platform IO legacy driver 回归测试。"""
+
+from typing import Any, Dict, Optional
+
+import pytest
+
+from src.chat.utils import utils as chat_utils
+from src.chat.message_receive import uni_message_sender
+from src.platform_io.drivers.base import PlatformIODriver
+from src.platform_io.drivers.legacy_driver import LegacyPlatformDriver
+from src.platform_io.manager import PlatformIOManager
+from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, RouteBinding, RouteKey
+
+
+class _PluginDriver(PlatformIODriver):
+ """测试用插件发送驱动。"""
+
+ def __init__(self, driver_id: str, platform: str) -> None:
+ """初始化测试驱动。
+
+ Args:
+ driver_id: 驱动 ID。
+ platform: 负责的平台名称。
+ """
+ super().__init__(
+ DriverDescriptor(
+ driver_id=driver_id,
+ kind=DriverKind.PLUGIN,
+ platform=platform,
+ plugin_id="test.plugin",
+ )
+ )
+
+ async def send_message(
+ self,
+ message: Any,
+ route_key: RouteKey,
+ metadata: Optional[Dict[str, Any]] = None,
+ ) -> DeliveryReceipt:
+ """返回一个固定成功回执。
+
+ Args:
+ message: 待发送消息。
+ route_key: 当前路由键。
+ metadata: 发送元数据。
+
+ Returns:
+ DeliveryReceipt: 固定成功回执。
+ """
+ del metadata
+ return DeliveryReceipt(
+ internal_message_id=str(message.message_id),
+ route_key=route_key,
+ status=DeliveryStatus.SENT,
+ driver_id=self.driver_id,
+ driver_kind=self.descriptor.kind,
+ )
+
+
+@pytest.mark.asyncio
+async def test_platform_io_uses_legacy_driver_when_no_explicit_send_route(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """没有显式发送路由时,应由 Platform IO 回退到 legacy driver。"""
+ manager = PlatformIOManager()
+ monkeypatch.setattr(chat_utils, "get_all_bot_accounts", lambda: {"qq": "bot-qq"})
+
+ try:
+ await manager.ensure_send_pipeline_ready()
+
+ fallback_drivers = manager.resolve_drivers(RouteKey(platform="qq"))
+ assert [driver.driver_id for driver in fallback_drivers] == ["legacy.send.qq"]
+
+ plugin_driver = _PluginDriver(driver_id="plugin.qq.sender", platform="qq")
+ await manager.add_driver(plugin_driver)
+ manager.bind_send_route(
+ RouteBinding(
+ route_key=RouteKey(platform="qq"),
+ driver_id=plugin_driver.driver_id,
+ driver_kind=plugin_driver.descriptor.kind,
+ )
+ )
+
+ explicit_drivers = manager.resolve_drivers(RouteKey(platform="qq"))
+ assert [driver.driver_id for driver in explicit_drivers] == ["plugin.qq.sender", "legacy.send.qq"]
+ finally:
+ await manager.stop()
+
+
+@pytest.mark.asyncio
+async def test_platform_io_broadcasts_to_plugin_and_legacy_driver(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """同一路由命中插件驱动与 legacy driver 时,应同时广播发送。"""
+
+ manager = PlatformIOManager()
+ legacy_calls: list[dict[str, Any]] = []
+ monkeypatch.setattr(chat_utils, "get_all_bot_accounts", lambda: {"qq": "bot-qq"})
+
+ async def _fake_send_prepared_message_to_platform(message: Any, show_log: bool = True) -> bool:
+ """记录 legacy driver 调用。"""
+
+ legacy_calls.append({"message": message, "show_log": show_log})
+ return True
+
+ monkeypatch.setattr(
+ uni_message_sender,
+ "send_prepared_message_to_platform",
+ _fake_send_prepared_message_to_platform,
+ )
+
+ try:
+ await manager.ensure_send_pipeline_ready()
+
+ plugin_driver = _PluginDriver(driver_id="plugin.qq.sender", platform="qq")
+ await manager.add_driver(plugin_driver)
+ manager.bind_send_route(
+ RouteBinding(
+ route_key=RouteKey(platform="qq"),
+ driver_id=plugin_driver.driver_id,
+ driver_kind=plugin_driver.descriptor.kind,
+ )
+ )
+
+ message = type("FakeMessage", (), {"message_id": "message-1"})()
+ batch = await manager.send_message(
+ message=message,
+ route_key=RouteKey(platform="qq"),
+ metadata={"show_log": False},
+ )
+
+ assert sorted(receipt.driver_id for receipt in batch.sent_receipts) == [
+ "legacy.send.qq",
+ "plugin.qq.sender",
+ ]
+ assert batch.failed_receipts == []
+ assert len(legacy_calls) == 1
+ assert legacy_calls[0]["message"] is message
+ assert legacy_calls[0]["show_log"] is False
+ finally:
+ await manager.stop()
+
+
+@pytest.mark.asyncio
+async def test_legacy_platform_driver_uses_prepared_universal_sender(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """legacy driver 应复用已预处理消息的旧链发送函数。"""
+ calls: list[dict[str, Any]] = []
+
+ async def _fake_send_prepared_message_to_platform(message: Any, show_log: bool = True) -> bool:
+ """记录 legacy driver 调用。"""
+ calls.append({"message": message, "show_log": show_log})
+ return True
+
+ monkeypatch.setattr(
+ uni_message_sender,
+ "send_prepared_message_to_platform",
+ _fake_send_prepared_message_to_platform,
+ )
+
+ driver = LegacyPlatformDriver(
+ driver_id="legacy.send.qq",
+ platform="qq",
+ account_id="bot-qq",
+ )
+ message = type("FakeMessage", (), {"message_id": "message-1"})()
+ receipt = await driver.send_message(
+ message=message,
+ route_key=RouteKey(platform="qq"),
+ metadata={"show_log": False},
+ )
+
+ assert len(calls) == 1
+ assert calls[0]["message"] is message
+ assert calls[0]["show_log"] is False
+ assert receipt.status == DeliveryStatus.SENT
+ assert receipt.driver_id == "legacy.send.qq"
diff --git a/pytests/test_plugin_message_utils_runtime.py b/pytests/test_plugin_message_utils_runtime.py
new file mode 100644
index 00000000..cb4b5341
--- /dev/null
+++ b/pytests/test_plugin_message_utils_runtime.py
@@ -0,0 +1,87 @@
+from datetime import datetime
+from pathlib import Path
+
+import sys
+
+from src.chat.message_receive.message import SessionMessage
+from src.common.data_models.mai_message_data_model import GroupInfo, MessageInfo, UserInfo
+from src.common.data_models.message_component_data_model import (
+ ForwardComponent,
+ ForwardNodeComponent,
+ ImageComponent,
+ MessageSequence,
+ ReplyComponent,
+ TextComponent,
+ VoiceComponent,
+)
+from src.plugin_runtime.host.message_utils import PluginMessageUtils
+
+
+PROJECT_ROOT = Path(__file__).resolve().parents[1]
+if str(PROJECT_ROOT) not in sys.path:
+ sys.path.insert(0, str(PROJECT_ROOT))
+
+
+def test_plugin_message_utils_preserves_binary_components_and_reply_metadata() -> None:
+ message = SessionMessage(message_id="msg-1", timestamp=datetime.now(), platform="qq")
+ message.message_info = MessageInfo(
+ user_info=UserInfo(user_id="10001", user_nickname="tester"),
+ group_info=GroupInfo(group_id="20001", group_name="group"),
+ additional_config={"self_id": "999"},
+ )
+ message.session_id = "qq:20001:10001"
+ message.processed_plain_text = "binary payload"
+ message.display_message = "binary payload"
+ message.raw_message = MessageSequence(
+ components=[
+ TextComponent("hello"),
+ ImageComponent(binary_hash="", binary_data=b"image-bytes", content=""),
+ VoiceComponent(binary_hash="", binary_data=b"voice-bytes", content=""),
+ ReplyComponent(
+ target_message_id="origin-1",
+ target_message_content="origin text",
+ target_message_sender_id="42",
+ target_message_sender_nickname="alice",
+ target_message_sender_cardname="Alice",
+ ),
+ ForwardNodeComponent(
+ forward_components=[
+ ForwardComponent(
+ user_nickname="bob",
+ user_id="43",
+ user_cardname="Bob",
+ message_id="forward-1",
+ content=[
+ TextComponent("node-text"),
+ ImageComponent(binary_hash="", binary_data=b"node-image", content=""),
+ ],
+ )
+ ]
+ ),
+ ]
+ )
+
+ message_dict = PluginMessageUtils._session_message_to_dict(message)
+ rebuilt_message = PluginMessageUtils._build_session_message_from_dict(dict(message_dict))
+
+ image_component = rebuilt_message.raw_message.components[1]
+ voice_component = rebuilt_message.raw_message.components[2]
+ reply_component = rebuilt_message.raw_message.components[3]
+ forward_component = rebuilt_message.raw_message.components[4]
+
+ assert isinstance(image_component, ImageComponent)
+ assert image_component.binary_data == b"image-bytes"
+
+ assert isinstance(voice_component, VoiceComponent)
+ assert voice_component.binary_data == b"voice-bytes"
+
+ assert isinstance(reply_component, ReplyComponent)
+ assert reply_component.target_message_id == "origin-1"
+ assert reply_component.target_message_content == "origin text"
+ assert reply_component.target_message_sender_id == "42"
+ assert reply_component.target_message_sender_nickname == "alice"
+ assert reply_component.target_message_sender_cardname == "Alice"
+
+ assert isinstance(forward_component, ForwardNodeComponent)
+ assert isinstance(forward_component.forward_components[0].content[1], ImageComponent)
+ assert forward_component.forward_components[0].content[1].binary_data == b"node-image"
diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py
index 2c703161..e3247f05 100644
--- a/pytests/test_plugin_runtime.py
+++ b/pytests/test_plugin_runtime.py
@@ -3,6 +3,7 @@
验证协议层、传输层、RPC 通信链路的正确性。
"""
+from pathlib import Path
from types import SimpleNamespace
import asyncio
@@ -18,6 +19,104 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "packages", "maibot-plugin-sdk"))
+def build_test_manifest(
+ plugin_id: str,
+ *,
+ version: str = "1.0.0",
+ name: str = "测试插件",
+ description: str = "测试插件描述",
+ dependencies: list[dict[str, str]] | None = None,
+ capabilities: list[str] | None = None,
+ host_min_version: str = "0.12.0",
+ host_max_version: str = "1.0.0",
+ sdk_min_version: str = "2.0.0",
+ sdk_max_version: str = "2.99.99",
+) -> dict[str, object]:
+ """构造一个合法的 Manifest v2 测试样例。
+
+ Args:
+ plugin_id: 插件 ID。
+ version: 插件版本。
+ name: 展示名称。
+ description: 插件描述。
+ dependencies: 依赖声明列表。
+ capabilities: 能力声明列表。
+ host_min_version: Host 最低支持版本。
+ host_max_version: Host 最高支持版本。
+ sdk_min_version: SDK 最低支持版本。
+ sdk_max_version: SDK 最高支持版本。
+
+ Returns:
+ dict[str, object]: 可直接序列化为 ``_manifest.json`` 的字典。
+ """
+ return {
+ "manifest_version": 2,
+ "version": version,
+ "name": name,
+ "description": description,
+ "author": {
+ "name": "tester",
+ "url": "https://example.com/tester",
+ },
+ "license": "MIT",
+ "urls": {
+ "repository": f"https://example.com/{plugin_id}",
+ },
+ "host_application": {
+ "min_version": host_min_version,
+ "max_version": host_max_version,
+ },
+ "sdk": {
+ "min_version": sdk_min_version,
+ "max_version": sdk_max_version,
+ },
+ "dependencies": dependencies or [],
+ "capabilities": capabilities or [],
+ "i18n": {
+ "default_locale": "zh-CN",
+ "supported_locales": ["zh-CN"],
+ },
+ "id": plugin_id,
+ }
+
+
+def build_test_manifest_model(
+ plugin_id: str,
+ *,
+ version: str = "1.0.0",
+ dependencies: list[dict[str, str]] | None = None,
+ capabilities: list[str] | None = None,
+ host_version: str = "1.0.0",
+ sdk_version: str = "2.0.1",
+) -> object:
+ """构造一个已经通过校验的强类型 Manifest 测试对象。
+
+ Args:
+ plugin_id: 插件 ID。
+ version: 插件版本。
+ dependencies: 依赖声明列表。
+ capabilities: 能力声明列表。
+ host_version: 当前测试使用的 Host 版本。
+ sdk_version: 当前测试使用的 SDK 版本。
+
+ Returns:
+ object: ``PluginManifest`` 实例。
+ """
+ from src.plugin_runtime.runner.manifest_validator import ManifestValidator
+
+ validator = ManifestValidator(host_version=host_version, sdk_version=sdk_version)
+ manifest = validator.parse_manifest(
+ build_test_manifest(
+ plugin_id,
+ version=version,
+ dependencies=dependencies,
+ capabilities=capabilities,
+ )
+ )
+ assert manifest is not None
+ return manifest
+
+
# ─── 协议层测试 ───────────────────────────────────────────
@@ -441,8 +540,8 @@ class TestSDK:
def set_plugin_config(self, config):
self.configs.append(config)
- async def on_config_update(self, config, version):
- self.updates.append((config, version, list(self.configs)))
+ async def on_config_update(self, scope, config, version):
+ self.updates.append((scope, config, version, list(self.configs)))
runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[])
plugin = DummyPlugin()
@@ -453,14 +552,60 @@ class TestSDK:
message_type=MessageType.REQUEST,
method="plugin.config_updated",
plugin_id="demo_plugin",
- payload={"config_data": {"enabled": True}, "config_version": "v2"},
+ payload={
+ "plugin_id": "demo_plugin",
+ "config_scope": "self",
+ "config_data": {"enabled": True},
+ "config_version": "v2",
+ },
)
response = await runner._handle_config_updated(envelope)
assert response.payload["acknowledged"] is True
assert plugin.configs == [{"enabled": True}]
- assert plugin.updates == [({"enabled": True}, "v2", [{"enabled": True}])]
+ assert plugin.updates == [("self", {"enabled": True}, "v2", [{"enabled": True}])]
+
+ @pytest.mark.asyncio
+ async def test_runner_global_config_update_does_not_override_plugin_config(self):
+ """bot/model 广播不应覆盖插件自身配置缓存。"""
+ from src.plugin_runtime.protocol.envelope import Envelope, MessageType
+ from src.plugin_runtime.runner.runner_main import PluginRunner
+
+ class DummyPlugin:
+ def __init__(self):
+ self.configs = []
+ self.updates = []
+
+ def set_plugin_config(self, config):
+ self.configs.append(config)
+
+ async def on_config_update(self, scope, config, version):
+ self.updates.append((scope, config, version, list(self.configs)))
+
+ runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[])
+ plugin = DummyPlugin()
+ runner._loader._loaded_plugins["demo_plugin"] = SimpleNamespace(instance=plugin)
+ plugin.set_plugin_config({"plugin_enabled": True})
+
+ envelope = Envelope(
+ request_id=1,
+ message_type=MessageType.REQUEST,
+ method="plugin.config_updated",
+ plugin_id="demo_plugin",
+ payload={
+ "plugin_id": "demo_plugin",
+ "config_scope": "model",
+ "config_data": {"models": []},
+ "config_version": "",
+ },
+ )
+
+ response = await runner._handle_config_updated(envelope)
+
+ assert response.payload["acknowledged"] is True
+ assert plugin.configs == [{"plugin_enabled": True}]
+ assert plugin.updates == [("model", {"models": []}, "", [{"plugin_enabled": True}])]
@pytest.mark.asyncio
async def test_runner_bootstraps_capabilities_before_on_load(self, monkeypatch):
@@ -486,10 +631,10 @@ class TestSDK:
"timeout_ms": timeout_ms,
}
)
- if method == "cap.request":
+ if method == "cap.call":
bootstrap_methods = [call["method"] for call in self.calls[:-1]]
assert "plugin.bootstrap" in bootstrap_methods
- return SimpleNamespace(error=None, payload={"result": {"success": True}})
+ return SimpleNamespace(error=None, payload={"success": True})
return SimpleNamespace(error=None, payload={"accepted": True})
async def disconnect(self):
@@ -529,7 +674,102 @@ class TestSDK:
await runner.run()
methods = [call["method"] for call in runner._rpc_client.calls]
- assert methods == ["plugin.bootstrap", "cap.request", "plugin.register_components", "runner.ready"]
+ assert methods == ["plugin.bootstrap", "plugin.register_components", "cap.call", "runner.ready"]
+
+ @pytest.mark.asyncio
+ async def test_runner_batch_reload_merges_overlapping_reverse_dependents(self, monkeypatch):
+ """批量重载应只对重叠依赖闭包执行一次 unload/load。"""
+ from src.plugin_runtime.runner.runner_main import PluginRunner
+
+ runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[])
+ plugin_a_id = "test.plugin-a"
+ plugin_b_id = "test.plugin-b"
+ plugin_c_id = "test.plugin-c"
+
+ def build_meta(plugin_id: str, dependencies: list[str]) -> SimpleNamespace:
+ return SimpleNamespace(
+ plugin_id=plugin_id,
+ dependencies=dependencies,
+ plugin_dir=f"/tmp/{plugin_id}",
+ version="1.0.0",
+ instance=SimpleNamespace(),
+ )
+
+ loaded_metas = {
+ plugin_a_id: build_meta(plugin_a_id, []),
+ plugin_b_id: build_meta(plugin_b_id, [plugin_a_id]),
+ plugin_c_id: build_meta(plugin_c_id, [plugin_b_id]),
+ }
+ reloaded_metas = {
+ plugin_id: build_meta(plugin_id, list(meta.dependencies))
+ for plugin_id, meta in loaded_metas.items()
+ }
+ candidates = {
+ plugin_a_id: (
+ "dir_plugin_a",
+ build_test_manifest_model(plugin_a_id),
+ "plugin_a/plugin.py",
+ ),
+ plugin_b_id: (
+ "dir_plugin_b",
+ build_test_manifest_model(
+ plugin_b_id,
+ dependencies=[{"type": "plugin", "id": plugin_a_id, "version_spec": ">=1.0.0,<2.0.0"}],
+ ),
+ "plugin_b/plugin.py",
+ ),
+ plugin_c_id: (
+ "dir_plugin_c",
+ build_test_manifest_model(
+ plugin_c_id,
+ dependencies=[{"type": "plugin", "id": plugin_b_id, "version_spec": ">=1.0.0,<2.0.0"}],
+ ),
+ "plugin_c/plugin.py",
+ ),
+ }
+ unloaded_plugins: list[str] = []
+ activated_plugins: list[str] = []
+
+ monkeypatch.setattr(runner._loader, "discover_candidates", lambda plugin_dirs: (candidates, {}))
+ monkeypatch.setattr(runner._loader, "list_plugins", lambda: sorted(loaded_metas.keys()))
+ monkeypatch.setattr(runner._loader, "get_plugin", lambda plugin_id: loaded_metas.get(plugin_id))
+ monkeypatch.setattr(
+ runner._loader,
+ "remove_loaded_plugin",
+ lambda plugin_id: loaded_metas.pop(plugin_id, None),
+ )
+ monkeypatch.setattr(runner._loader, "purge_plugin_modules", lambda plugin_id, plugin_dir: [])
+ monkeypatch.setattr(
+ runner._loader,
+ "resolve_dependencies",
+ lambda reload_candidates, extra_available=None: (sorted(reload_candidates.keys()), {}),
+ )
+ monkeypatch.setattr(
+ runner._loader,
+ "load_candidate",
+ lambda plugin_id, candidate: reloaded_metas[plugin_id],
+ )
+
+ async def fake_unload_plugin(meta, reason, purge_modules=False):
+ del reason, purge_modules
+ unloaded_plugins.append(meta.plugin_id)
+ loaded_metas.pop(meta.plugin_id, None)
+
+ async def fake_activate_plugin(meta):
+ activated_plugins.append(meta.plugin_id)
+ loaded_metas[meta.plugin_id] = meta
+ return True
+
+ monkeypatch.setattr(runner, "_unload_plugin", fake_unload_plugin)
+ monkeypatch.setattr(runner, "_activate_plugin", fake_activate_plugin)
+
+ result = await runner._reload_plugins_by_ids([plugin_a_id, plugin_b_id], reason="manual")
+
+ assert result.success is True
+ assert result.requested_plugin_ids == [plugin_a_id, plugin_b_id]
+ assert unloaded_plugins == [plugin_c_id, plugin_b_id, plugin_a_id]
+ assert activated_plugins == [plugin_a_id, plugin_b_id, plugin_c_id]
+ assert result.reloaded_plugins == [plugin_a_id, plugin_b_id, plugin_c_id]
class TestPluginSdkUsage:
@@ -712,65 +952,77 @@ class TestManifestValidator:
def test_valid_manifest(self):
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
- validator = ManifestValidator()
- manifest = {
- "manifest_version": 1,
- "name": "test_plugin",
- "version": "1.0.0",
- "description": "测试插件",
- "author": "test",
- }
+ validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1")
+ manifest = build_test_manifest("test.valid-plugin", capabilities=["send.text"])
assert validator.validate(manifest) is True
assert len(validator.errors) == 0
+ assert validator.warnings == []
def test_missing_required_fields(self):
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
- validator = ManifestValidator()
- manifest = {"manifest_version": 1}
+ validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1")
+ manifest = {"manifest_version": 2}
assert validator.validate(manifest) is False
- assert len(validator.errors) >= 4 # name, version, description, author
+ assert len(validator.errors) >= 6
+ assert any("缺少必需字段" in error for error in validator.errors)
def test_unsupported_manifest_version(self):
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
- validator = ManifestValidator()
- manifest = {
- "manifest_version": 999,
- "name": "test",
- "version": "1.0",
- "description": "d",
- "author": "a",
- }
+ validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1")
+ manifest = build_test_manifest("test.invalid-version")
+ manifest["manifest_version"] = 999
assert validator.validate(manifest) is False
assert any("manifest_version" in e for e in validator.errors)
def test_host_version_compatibility(self):
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
- validator = ManifestValidator(host_version="0.8.5")
- manifest = {
- "name": "test",
- "version": "1.0",
- "description": "d",
- "author": "a",
- "host_application": {"min_version": "0.9.0"},
- }
+ validator = ManifestValidator(host_version="0.8.5", sdk_version="2.0.1")
+ manifest = build_test_manifest(
+ "test.host-check",
+ host_min_version="0.9.0",
+ host_max_version="1.0.0",
+ )
assert validator.validate(manifest) is False
assert any("Host 版本不兼容" in e for e in validator.errors)
- def test_recommended_fields_warning(self):
+ def test_sdk_version_compatibility(self):
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
- validator = ManifestValidator()
- manifest = {
- "name": "test",
- "version": "1.0",
- "description": "d",
- "author": "a",
- }
- validator.validate(manifest)
- assert len(validator.warnings) >= 3 # license, keywords, categories
+ validator = ManifestValidator(host_version="1.0.0", sdk_version="1.9.9")
+ manifest = build_test_manifest("test.sdk-check")
+ assert validator.validate(manifest) is False
+ assert any("SDK 版本不兼容" in e for e in validator.errors)
+
+ def test_extra_fields_are_rejected(self):
+ from src.plugin_runtime.runner.manifest_validator import ManifestValidator
+
+ validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1")
+ manifest = build_test_manifest("test.extra-field")
+ manifest["unexpected"] = True
+
+ assert validator.validate(manifest) is False
+ assert any("存在未声明字段" in error for error in validator.errors)
+
+ def test_python_package_conflict_rejects_manifest(self):
+ from src.plugin_runtime.runner.manifest_validator import ManifestValidator
+
+ validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1")
+ manifest = build_test_manifest(
+ "test.numpy-conflict",
+ dependencies=[
+ {
+ "type": "python_package",
+ "name": "numpy",
+ "version_spec": ">=999.0.0",
+ }
+ ],
+ )
+
+ assert validator.validate(manifest) is False
+ assert any("Python 包依赖冲突" in error for error in validator.errors)
class TestVersionComparator:
@@ -812,59 +1064,83 @@ class TestDependencyResolution:
loader = PluginLoader()
candidates = {
- "core": ("dir_core", {"name": "core", "version": "1.0", "description": "d", "author": "a"}, "plugin.py"),
- "auth": (
- "dir_auth",
- {"name": "auth", "version": "1.0", "description": "d", "author": "a", "dependencies": ["core"]},
+ "test.core": (
+ "dir_core",
+ build_test_manifest_model("test.core"),
"plugin.py",
),
- "api": (
+ "test.auth": (
+ "dir_auth",
+ build_test_manifest_model(
+ "test.auth",
+ dependencies=[
+ {"type": "plugin", "id": "test.core", "version_spec": ">=1.0.0,<2.0.0"},
+ ],
+ ),
+ "plugin.py",
+ ),
+ "test.api": (
"dir_api",
- {"name": "api", "version": "1.0", "description": "d", "author": "a", "dependencies": ["core", "auth"]},
+ build_test_manifest_model(
+ "test.api",
+ dependencies=[
+ {"type": "plugin", "id": "test.core", "version_spec": ">=1.0.0,<2.0.0"},
+ {"type": "plugin", "id": "test.auth", "version_spec": ">=1.0.0,<2.0.0"},
+ ],
+ ),
"plugin.py",
),
}
order, failed = loader._resolve_dependencies(candidates)
assert len(failed) == 0
- assert order.index("core") < order.index("auth")
- assert order.index("auth") < order.index("api")
+ assert order.index("test.core") < order.index("test.auth")
+ assert order.index("test.auth") < order.index("test.api")
def test_missing_dependency(self):
from src.plugin_runtime.runner.plugin_loader import PluginLoader
loader = PluginLoader()
candidates = {
- "plugin_a": (
+ "test.plugin-a": (
"dir_a",
- {
- "name": "plugin_a",
- "version": "1.0",
- "description": "d",
- "author": "a",
- "dependencies": ["nonexistent"],
- },
+ build_test_manifest_model(
+ "test.plugin-a",
+ dependencies=[
+ {"type": "plugin", "id": "test.nonexistent", "version_spec": ">=1.0.0,<2.0.0"},
+ ],
+ ),
"plugin.py",
),
}
order, failed = loader._resolve_dependencies(candidates)
- assert "plugin_a" in failed
- assert "缺少依赖" in failed["plugin_a"]
+ assert "test.plugin-a" in failed
+ assert "依赖未满足" in failed["test.plugin-a"]
def test_circular_dependency(self):
from src.plugin_runtime.runner.plugin_loader import PluginLoader
loader = PluginLoader()
candidates = {
- "a": (
+ "test.a": (
"dir_a",
- {"name": "a", "version": "1.0", "description": "d", "author": "x", "dependencies": ["b"]},
+ build_test_manifest_model(
+ "test.a",
+ dependencies=[
+ {"type": "plugin", "id": "test.b", "version_spec": ">=1.0.0,<2.0.0"},
+ ],
+ ),
"p.py",
),
- "b": (
+ "test.b": (
"dir_b",
- {"name": "b", "version": "1.0", "description": "d", "author": "x", "dependencies": ["a"]},
+ build_test_manifest_model(
+ "test.b",
+ dependencies=[
+ {"type": "plugin", "id": "test.a", "version_spec": ">=1.0.0,<2.0.0"},
+ ],
+ ),
"p.py",
),
}
@@ -882,12 +1158,11 @@ class TestDependencyResolution:
(plugin_dir / "_manifest.json").write_text(
json.dumps(
- {
- "name": "grok_search_plugin",
- "version": "1.0.0",
- "description": "demo",
- "author": "tester",
- }
+ build_test_manifest(
+ "test.grok-search-plugin",
+ name="grok_search_plugin",
+ description="demo",
+ )
),
encoding="utf-8",
)
@@ -907,14 +1182,130 @@ class TestDependencyResolution:
loader = PluginLoader()
loaded = loader.discover_and_load([str(plugin_root)])
- assert [meta.plugin_id for meta in loaded] == ["grok_search_plugin"]
+ assert [meta.plugin_id for meta in loaded] == ["test.grok-search-plugin"]
assert loader.failed_plugins == {}
assert loaded[0].instance.answer() == 42
+ def test_loader_requires_sdk_plugin_to_override_on_config_update(self, tmp_path):
+ from src.plugin_runtime.runner.plugin_loader import PluginLoader
+
+ plugin_root = tmp_path / "plugins"
+ plugin_root.mkdir()
+ plugin_dir = plugin_root / "demo_plugin"
+ plugin_dir.mkdir()
+
+ (plugin_dir / "_manifest.json").write_text(
+ json.dumps(
+ build_test_manifest(
+ "test.demo-plugin",
+ name="demo_plugin",
+ description="demo",
+ )
+ ),
+ encoding="utf-8",
+ )
+ (plugin_dir / "plugin.py").write_text(
+ "from maibot_sdk import MaiBotPlugin\n\n"
+ "class DemoPlugin(MaiBotPlugin):\n"
+ " async def on_load(self):\n"
+ " pass\n\n"
+ " async def on_unload(self):\n"
+ " pass\n\n"
+ "def create_plugin():\n"
+ " return DemoPlugin()\n",
+ encoding="utf-8",
+ )
+
+ loader = PluginLoader()
+ loaded = loader.discover_and_load([str(plugin_root)])
+
+ assert loaded == []
+ assert "test.demo-plugin" in loader.failed_plugins
+ assert "on_config_update" in loader.failed_plugins["test.demo-plugin"]
+
+ def test_loader_requires_sdk_plugin_to_override_on_load(self, tmp_path):
+ from src.plugin_runtime.runner.plugin_loader import PluginLoader
+
+ plugin_root = tmp_path / "plugins"
+ plugin_root.mkdir()
+ plugin_dir = plugin_root / "demo_plugin"
+ plugin_dir.mkdir()
+
+ (plugin_dir / "_manifest.json").write_text(
+ json.dumps(
+ build_test_manifest(
+ "test.demo-plugin",
+ name="demo_plugin",
+ description="demo",
+ )
+ ),
+ encoding="utf-8",
+ )
+ (plugin_dir / "plugin.py").write_text(
+ "from maibot_sdk import MaiBotPlugin\n\n"
+ "class DemoPlugin(MaiBotPlugin):\n"
+ " async def on_unload(self):\n"
+ " pass\n\n"
+ " async def on_config_update(self, scope, config_data, version):\n"
+ " pass\n\n"
+ "def create_plugin():\n"
+ " return DemoPlugin()\n",
+ encoding="utf-8",
+ )
+
+ loader = PluginLoader()
+ loaded = loader.discover_and_load([str(plugin_root)])
+
+ assert loaded == []
+ assert "test.demo-plugin" in loader.failed_plugins
+ assert "on_load" in loader.failed_plugins["test.demo-plugin"]
+
+ def test_loader_requires_sdk_plugin_to_override_on_unload(self, tmp_path):
+ from src.plugin_runtime.runner.plugin_loader import PluginLoader
+
+ plugin_root = tmp_path / "plugins"
+ plugin_root.mkdir()
+ plugin_dir = plugin_root / "demo_plugin"
+ plugin_dir.mkdir()
+
+ (plugin_dir / "_manifest.json").write_text(
+ json.dumps(
+ build_test_manifest(
+ "test.demo-plugin",
+ name="demo_plugin",
+ description="demo",
+ )
+ ),
+ encoding="utf-8",
+ )
+ (plugin_dir / "plugin.py").write_text(
+ "from maibot_sdk import MaiBotPlugin\n\n"
+ "class DemoPlugin(MaiBotPlugin):\n"
+ " async def on_load(self):\n"
+ " pass\n\n"
+ " async def on_config_update(self, scope, config_data, version):\n"
+ " pass\n\n"
+ "def create_plugin():\n"
+ " return DemoPlugin()\n",
+ encoding="utf-8",
+ )
+
+ loader = PluginLoader()
+ loaded = loader.discover_and_load([str(plugin_root)])
+
+ assert loaded == []
+ assert "test.demo-plugin" in loader.failed_plugins
+ assert "on_unload" in loader.failed_plugins["test.demo-plugin"]
+
def test_isolate_sys_path_preserves_plugin_dirs(self):
+ import builtins
+ import importlib
+
from src.plugin_runtime.runner import runner_main
plugin_root = os.path.normpath("/tmp/maibot-plugin-root")
+ original_import = builtins.__import__
+ original_import_module = importlib.import_module
original_path = list(sys.path)
original_meta_path = list(sys.meta_path)
@@ -926,9 +1317,155 @@ class TestDependencyResolution:
assert plugin_root in sys.path
finally:
+ builtins.__import__ = original_import
+ importlib.import_module = original_import_module
sys.path[:] = original_path
sys.meta_path[:] = original_meta_path
+ def test_isolate_sys_path_blocks_disallowed_src_imports(self):
+ import builtins
+ import importlib
+
+ from src.plugin_runtime.runner import runner_main
+
+ original_import = builtins.__import__
+ original_import_module = importlib.import_module
+ original_path = list(sys.path)
+ original_meta_path = list(sys.meta_path)
+ sys.modules.pop("src.forbidden_demo", None)
+
+ try:
+ runner_main._isolate_sys_path([])
+ plugin_globals = {
+ "__name__": "_maibot_plugin_demo",
+ "__package__": "_maibot_plugin_demo",
+ "importlib": importlib,
+ }
+
+ with pytest.raises(ImportError, match="不允许导入主程序模块"):
+ exec('importlib.import_module("src.forbidden_demo")', plugin_globals)
+ finally:
+ builtins.__import__ = original_import
+ importlib.import_module = original_import_module
+ sys.path[:] = original_path
+ sys.meta_path[:] = original_meta_path
+ sys.modules.pop("src.forbidden_demo", None)
+
+ def test_isolate_sys_path_blocks_preloaded_runtime_modules(self):
+ import builtins
+ import importlib
+
+ from src.plugin_runtime.runner import runner_main
+
+ original_import = builtins.__import__
+ original_import_module = importlib.import_module
+ original_path = list(sys.path)
+ original_meta_path = list(sys.meta_path)
+
+ try:
+ runner_main._isolate_sys_path([])
+ plugin_globals = {
+ "__name__": "_maibot_plugin_demo",
+ "__package__": "_maibot_plugin_demo",
+ "importlib": importlib,
+ }
+
+ with pytest.raises(ImportError, match="rpc_client"):
+ exec('importlib.import_module("src.plugin_runtime.runner.rpc_client")', plugin_globals)
+ finally:
+ builtins.__import__ = original_import
+ importlib.import_module = original_import_module
+ sys.path[:] = original_path
+ sys.meta_path[:] = original_meta_path
+
+ def test_isolate_sys_path_keeps_legacy_logger_import_available(self):
+ import builtins
+ import importlib
+
+ from src.plugin_runtime.runner import runner_main
+
+ original_import = builtins.__import__
+ original_import_module = importlib.import_module
+ original_path = list(sys.path)
+ original_meta_path = list(sys.meta_path)
+
+ try:
+ runner_main._isolate_sys_path([])
+ plugin_globals = {
+ "__name__": "_maibot_plugin_demo",
+ "__package__": "_maibot_plugin_demo",
+ "importlib": importlib,
+ }
+
+ exec('logger_module = importlib.import_module("src.common.logger")', plugin_globals)
+ logger_module = plugin_globals["logger_module"]
+ assert callable(logger_module.get_logger)
+ finally:
+ builtins.__import__ = original_import
+ importlib.import_module = original_import_module
+ sys.path[:] = original_path
+ sys.meta_path[:] = original_meta_path
+
+ def test_isolate_sys_path_keeps_runtime_imports_working(self):
+ import builtins
+ import importlib
+
+ from src.plugin_runtime.runner import runner_main
+
+ original_import = builtins.__import__
+ original_import_module = importlib.import_module
+ original_path = list(sys.path)
+ original_meta_path = list(sys.meta_path)
+
+ try:
+ runner_main._isolate_sys_path([])
+
+ uds_module = importlib.import_module("src.plugin_runtime.transport.uds")
+ assert hasattr(uds_module, "UDSTransportClient")
+ finally:
+ builtins.__import__ = original_import
+ importlib.import_module = original_import_module
+ sys.path[:] = original_path
+ sys.meta_path[:] = original_meta_path
+
+ @pytest.mark.asyncio
+ async def test_async_main_removes_sensitive_runtime_env_vars(self, monkeypatch):
+ from src.plugin_runtime.runner import runner_main
+
+ captured = {}
+
+ class FakeRunner:
+ def __init__(
+ self,
+ host_address: str,
+ session_token: str,
+ plugin_dirs: list[str],
+ external_available_plugins: dict[str, str] | None = None,
+ ) -> None:
+ captured["host_address"] = host_address
+ captured["session_token"] = session_token
+ captured["plugin_dirs"] = plugin_dirs
+ captured["external_available_plugins"] = external_available_plugins or {}
+
+ async def run(self) -> None:
+ assert os.environ.get(runner_main.ENV_IPC_ADDRESS) is None
+ assert os.environ.get(runner_main.ENV_SESSION_TOKEN) is None
+
+ monkeypatch.setenv(runner_main.ENV_IPC_ADDRESS, "tcp://127.0.0.1:9999")
+ monkeypatch.setenv(runner_main.ENV_SESSION_TOKEN, "secret-token")
+ monkeypatch.setenv(runner_main.ENV_PLUGIN_DIRS, "/tmp/plugins")
+ monkeypatch.setenv(runner_main.ENV_EXTERNAL_PLUGIN_IDS, '{"demo.plugin":"1.0.0"}')
+ monkeypatch.setattr(runner_main, "_install_shutdown_signal_handlers", lambda callback: None)
+ monkeypatch.setattr(runner_main, "_isolate_sys_path", lambda plugin_dirs: None)
+ monkeypatch.setattr(runner_main, "PluginRunner", FakeRunner)
+
+ await runner_main._async_main()
+
+ assert captured["host_address"] == "tcp://127.0.0.1:9999"
+ assert captured["session_token"] == "secret-token"
+ assert captured["plugin_dirs"] == ["/tmp/plugins"]
+ assert captured["external_available_plugins"] == {"demo.plugin": "1.0.0"}
+
# ─── Host-side ComponentRegistry 测试 ──────────────────────
@@ -973,6 +1510,30 @@ class TestComponentRegistry:
assert stats["command"] == 1
assert stats["tool"] == 1
+ def test_register_command_with_invalid_regex_only_warns(self, monkeypatch):
+ from src.plugin_runtime.host.component_registry import ComponentRegistry
+
+ reg = ComponentRegistry()
+ warnings: list[str] = []
+ monkeypatch.setattr(
+ "src.plugin_runtime.host.component_registry.logger.warning",
+ lambda message: warnings.append(str(message)),
+ )
+
+ success = reg.register_component(
+ "broken",
+ "command",
+ "plugin_a",
+ {
+ "command_pattern": "[",
+ },
+ )
+
+ assert success is True
+ assert reg.get_component("plugin_a.broken") is not None
+ assert warnings
+ assert "plugin_a.broken" in warnings[0]
+
def test_query_by_type(self):
from src.plugin_runtime.host.component_registry import ComponentRegistry
@@ -1664,6 +2225,67 @@ class TestWorkflowExecutor:
class TestRPCServer:
"""RPC Server 代际保护测试"""
+ @pytest.mark.asyncio
+ async def test_reject_second_active_runner_connection(self):
+ from src.plugin_runtime.host.rpc_server import RPCServer
+ from src.plugin_runtime.protocol.codec import MsgPackCodec
+ from src.plugin_runtime.protocol.envelope import Envelope, HelloPayload, HelloResponsePayload, MessageType
+
+ class DummyTransport:
+ async def start(self, handler):
+ return None
+
+ async def stop(self):
+ return None
+
+ def get_address(self):
+ return "dummy"
+
+ class FakeConnection:
+ def __init__(self, incoming_frames: list[bytes]):
+ self._incoming_frames = list(incoming_frames)
+ self.sent_frames: list[bytes] = []
+ self.is_closed = False
+
+ async def recv_frame(self):
+ return self._incoming_frames.pop(0)
+
+ async def send_frame(self, data):
+ self.sent_frames.append(data)
+
+ async def close(self):
+ self.is_closed = True
+
+ codec = MsgPackCodec()
+ server = RPCServer(transport=DummyTransport(), session_token="session-token")
+ active_conn = SimpleNamespace(is_closed=False)
+ server._connection = active_conn
+
+ hello = HelloPayload(
+ runner_id="runner-b",
+ sdk_version="1.0.0",
+ session_token="session-token",
+ )
+ envelope = Envelope(
+ request_id=1,
+ message_type=MessageType.REQUEST,
+ method="runner.hello",
+ payload=hello.model_dump(),
+ )
+ incoming_conn = FakeConnection([codec.encode_envelope(envelope)])
+
+ await server._handle_connection(incoming_conn)
+
+ assert incoming_conn.is_closed is True
+ assert server._connection is active_conn
+ assert server.last_handshake_rejection_reason == "已有活跃 Runner 连接,拒绝新的握手"
+ assert len(incoming_conn.sent_frames) == 1
+
+ response = codec.decode_envelope(incoming_conn.sent_frames[0])
+ response_payload = HelloResponsePayload.model_validate(response.payload)
+ assert response_payload.accepted is False
+ assert response_payload.reason == "已有活跃 Runner 连接,拒绝新的握手"
+
def test_ignore_stale_generation_response(self):
from src.plugin_runtime.host.rpc_server import RPCServer
from src.plugin_runtime.protocol.envelope import Envelope, MessageType
@@ -2012,6 +2634,39 @@ class TestSupervisor:
assert supervisor.component_registry.get_component("plugin_a.handler") is not None
assert supervisor.component_registry.get_component("plugin_a.obsolete") is None
+ @pytest.mark.asyncio
+ async def test_reload_plugins_uses_batch_rpc_for_multiple_roots(self):
+ from src.plugin_runtime.host.supervisor import PluginSupervisor
+ from src.plugin_runtime.protocol.envelope import ReloadPluginsResultPayload
+
+ supervisor = PluginSupervisor(plugin_dirs=[])
+ sent_requests: list[tuple[str, dict[str, object], int]] = []
+
+ class FakeRPCServer:
+ async def send_request(self, method, payload, timeout_ms=5000, **kwargs):
+ del kwargs
+ sent_requests.append((method, payload, timeout_ms))
+ return SimpleNamespace(
+ payload=ReloadPluginsResultPayload(
+ success=True,
+ requested_plugin_ids=["plugin_a", "plugin_b"],
+ reloaded_plugins=["plugin_a", "plugin_b", "plugin_c"],
+ unloaded_plugins=["plugin_c", "plugin_b", "plugin_a"],
+ ).model_dump()
+ )
+
+ supervisor._rpc_server = FakeRPCServer()
+
+ reloaded = await supervisor.reload_plugins(["plugin_a", "plugin_b", "plugin_a"], reason="manual")
+
+ assert reloaded is True
+ assert len(sent_requests) == 1
+ method, payload, timeout_ms = sent_requests[0]
+ assert method == "plugin.reload_batch"
+ assert payload["plugin_ids"] == ["plugin_a", "plugin_b"]
+ assert payload["reason"] == "manual"
+ assert timeout_ms >= 10000
+
@pytest.mark.asyncio
async def test_reload_rolls_back_when_runner_ready_not_received(self, monkeypatch):
from src.plugin_runtime.host.supervisor import PluginSupervisor
@@ -2152,8 +2807,11 @@ class TestIntegration:
self.supervisors = [FakeSupervisor("plugin_a"), FakeSupervisor("plugin_b")]
monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager())
+ manager = integration_module.PluginRuntimeManager()
+ manager._builtin_supervisor = FakeSupervisor("plugin_a")
+ manager._third_party_supervisor = FakeSupervisor("plugin_b")
- result = await integration_module.PluginRuntimeManager._cap_component_enable(
+ result = await manager._cap_component_enable(
"plugin_a",
"component.enable",
{"name": "shared", "component_type": "tool", "scope": "global", "stream_id": ""},
@@ -2182,8 +2840,10 @@ class TestIntegration:
self.supervisors = [FakeSupervisor()]
monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager())
+ manager = integration_module.PluginRuntimeManager()
+ manager._builtin_supervisor = FakeSupervisor()
- result = await integration_module.PluginRuntimeManager._cap_component_disable(
+ result = await manager._cap_component_disable(
"plugin_a",
"component.disable",
{"name": "plugin_a.handler", "component_type": "tool", "scope": "stream", "stream_id": "s1"},
@@ -2197,6 +2857,8 @@ class TestIntegration:
from src.plugin_runtime import integration as integration_module
instances = []
+ builtin_dir = Path("builtin")
+ thirdparty_dir = Path("thirdparty")
class FakeCapabilityService:
def register_capability(self, name, impl):
@@ -2204,11 +2866,21 @@ class TestIntegration:
class FakeSupervisor:
def __init__(self, plugin_dirs=None, socket_path=None):
- self.plugin_dirs = plugin_dirs or []
+ self._plugin_dirs = plugin_dirs or []
self.capability_service = FakeCapabilityService()
+ self.external_plugin_versions = {}
self.stopped = False
instances.append(self)
+ def set_external_available_plugins(self, plugin_versions):
+ self.external_plugin_versions = dict(plugin_versions)
+
+ def get_loaded_plugin_ids(self):
+ return []
+
+ def get_loaded_plugin_versions(self):
+ return {}
+
async def start(self):
if len(instances) == 2 and self is instances[1]:
raise RuntimeError("boom")
@@ -2217,10 +2889,10 @@ class TestIntegration:
self.stopped = True
monkeypatch.setattr(
- integration_module.PluginRuntimeManager, "_get_builtin_plugin_dirs", staticmethod(lambda: ["builtin"])
+ integration_module.PluginRuntimeManager, "_get_builtin_plugin_dirs", staticmethod(lambda: [builtin_dir])
)
monkeypatch.setattr(
- integration_module.PluginRuntimeManager, "_get_thirdparty_plugin_dirs", staticmethod(lambda: ["thirdparty"])
+ integration_module.PluginRuntimeManager, "_get_third_party_plugin_dirs", staticmethod(lambda: [thirdparty_dir])
)
import src.plugin_runtime.host.supervisor as supervisor_module
@@ -2238,6 +2910,7 @@ class TestIntegration:
async def test_handle_plugin_source_changes_only_reload_matching_supervisor(self, monkeypatch, tmp_path):
from src.config.file_watcher import FileChange
from src.plugin_runtime import integration as integration_module
+ import json
builtin_root = tmp_path / "src" / "plugins" / "built_in"
thirdparty_root = tmp_path / "plugins"
@@ -2247,6 +2920,10 @@ class TestIntegration:
beta_dir.mkdir(parents=True)
(alpha_dir / "config.toml").write_text("enabled = true\n", encoding="utf-8")
(beta_dir / "config.toml").write_text("enabled = false\n", encoding="utf-8")
+ (alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
+ (beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
+ (alpha_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.alpha")), encoding="utf-8")
+ (beta_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.beta")), encoding="utf-8")
monkeypatch.chdir(tmp_path)
@@ -2257,8 +2934,14 @@ class TestIntegration:
self.reload_reasons = []
self.config_updates = []
- async def reload_plugins(self, reason="manual"):
- self.reload_reasons.append(reason)
+ def get_loaded_plugin_ids(self):
+ return sorted(self._registered_plugins.keys())
+
+ def get_loaded_plugin_versions(self):
+ return {plugin_id: "1.0.0" for plugin_id in self._registered_plugins}
+
+ async def reload_plugins(self, plugin_ids=None, reason="manual", external_available_plugins=None):
+ self.reload_reasons.append((plugin_ids, reason, external_available_plugins or {}))
async def notify_plugin_config_updated(self, plugin_id, config_data, config_version=""):
self.config_updates.append((plugin_id, config_data, config_version))
@@ -2266,8 +2949,8 @@ class TestIntegration:
manager = integration_module.PluginRuntimeManager()
manager._started = True
- manager._builtin_supervisor = FakeSupervisor([builtin_root], {"alpha": object()})
- manager._third_party_supervisor = FakeSupervisor([thirdparty_root], {"beta": object()})
+ manager._builtin_supervisor = FakeSupervisor([builtin_root], {"test.alpha": object()})
+ manager._third_party_supervisor = FakeSupervisor([thirdparty_root], {"test.beta": object()})
changes = [
FileChange(change_type=1, path=beta_dir / "plugin.py"),
@@ -2283,15 +2966,71 @@ class TestIntegration:
await manager._handle_plugin_source_changes(changes)
assert manager._builtin_supervisor.reload_reasons == []
- assert manager._third_party_supervisor.reload_reasons == ["file_watcher"]
+ assert manager._third_party_supervisor.reload_reasons == [
+ (["test.beta"], "file_watcher", {"test.alpha": "1.0.0"})
+ ]
assert manager._builtin_supervisor.config_updates == []
assert manager._third_party_supervisor.config_updates == []
assert refresh_calls == [True]
+ @pytest.mark.asyncio
+ async def test_reload_plugins_globally_warns_and_skips_cross_supervisor_dependents(self, monkeypatch):
+ from src.plugin_runtime import integration as integration_module
+
+ class FakeRegistration:
+ def __init__(self, dependencies):
+ self.dependencies = dependencies
+
+ class FakeSupervisor:
+ def __init__(self, registrations):
+ self._registered_plugins = registrations
+ self.reload_calls = []
+
+ def get_loaded_plugin_ids(self):
+ return sorted(self._registered_plugins.keys())
+
+ def get_loaded_plugin_versions(self):
+ return {plugin_id: "1.0.0" for plugin_id in self._registered_plugins}
+
+ async def reload_plugins(self, plugin_ids=None, reason="manual", external_available_plugins=None):
+ self.reload_calls.append((plugin_ids, reason, dict(sorted((external_available_plugins or {}).items()))))
+ return True
+
+ builtin_supervisor = FakeSupervisor({"test.alpha": FakeRegistration([])})
+ third_party_supervisor = FakeSupervisor(
+ {
+ "test.beta": FakeRegistration(["test.alpha"]),
+ "test.gamma": FakeRegistration(["test.beta"]),
+ }
+ )
+
+ manager = integration_module.PluginRuntimeManager()
+ manager._builtin_supervisor = builtin_supervisor
+ manager._third_party_supervisor = third_party_supervisor
+ warning_messages = []
+
+ monkeypatch.setattr(
+ integration_module.logger,
+ "warning",
+ lambda message: warning_messages.append(message),
+ )
+
+ reloaded = await manager.reload_plugins_globally(["test.alpha"], reason="manual")
+
+ assert reloaded is True
+ assert builtin_supervisor.reload_calls == [
+ (["test.alpha"], "manual", {"test.beta": "1.0.0", "test.gamma": "1.0.0"})
+ ]
+ assert third_party_supervisor.reload_calls == []
+ assert len(warning_messages) == 1
+ assert "test.beta, test.gamma" in warning_messages[0]
+ assert "跨 Supervisor API 调用仍然可用" in warning_messages[0]
+
@pytest.mark.asyncio
async def test_handle_plugin_config_changes_only_notify_target_plugin(self, monkeypatch, tmp_path):
from src.plugin_runtime import integration as integration_module
from src.config.file_watcher import FileChange
+ import json
builtin_root = tmp_path / "src" / "plugins" / "built_in"
thirdparty_root = tmp_path / "plugins"
@@ -2301,6 +3040,10 @@ class TestIntegration:
beta_dir.mkdir(parents=True)
(alpha_dir / "config.toml").write_text("enabled = true\n", encoding="utf-8")
(beta_dir / "config.toml").write_text("enabled = false\n", encoding="utf-8")
+ (alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
+ (beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
+ (alpha_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.alpha")), encoding="utf-8")
+ (beta_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.beta")), encoding="utf-8")
monkeypatch.chdir(tmp_path)
@@ -2310,25 +3053,97 @@ class TestIntegration:
self._registered_plugins = {plugin_id: object() for plugin_id in plugins}
self.config_updates = []
- async def notify_plugin_config_updated(self, plugin_id, config_data, config_version=""):
- self.config_updates.append((plugin_id, config_data, config_version))
+ async def notify_plugin_config_updated(
+ self,
+ plugin_id,
+ config_data,
+ config_version="",
+ config_scope="self",
+ ):
+ self.config_updates.append((plugin_id, config_data, config_version, config_scope))
return True
manager = integration_module.PluginRuntimeManager()
manager._started = True
- manager._builtin_supervisor = FakeSupervisor([builtin_root], ["alpha"])
- manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["beta"])
+ manager._builtin_supervisor = FakeSupervisor([builtin_root], ["test.alpha"])
+ manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["test.beta"])
await manager._handle_plugin_config_changes(
- "alpha",
+ "test.alpha",
[FileChange(change_type=1, path=alpha_dir / "config.toml")],
)
- assert manager._builtin_supervisor.config_updates == [("alpha", {"enabled": True}, "")]
+ assert manager._builtin_supervisor.config_updates == [("test.alpha", {"enabled": True}, "", "self")]
assert manager._third_party_supervisor.config_updates == []
+ @pytest.mark.asyncio
+ async def test_handle_main_config_reload_only_notifies_subscribers(self, monkeypatch):
+ from src.plugin_runtime import integration as integration_module
+
+ class FakeRegistration:
+ def __init__(self, subscriptions):
+ self.config_reload_subscriptions = subscriptions
+
+ class FakeSupervisor:
+ def __init__(self, registrations):
+ self._registered_plugins = registrations
+ self.config_updates = []
+
+ def get_config_reload_subscribers(self, scope):
+ matched_plugins = []
+ for plugin_id, registration in self._registered_plugins.items():
+ if scope in registration.config_reload_subscriptions:
+ matched_plugins.append(plugin_id)
+ return matched_plugins
+
+ async def notify_plugin_config_updated(
+ self,
+ plugin_id,
+ config_data,
+ config_version="",
+ config_scope="self",
+ ):
+ self.config_updates.append((plugin_id, config_data, config_version, config_scope))
+ return True
+
+ fake_global = SimpleNamespace(plugin_runtime=SimpleNamespace(enabled=True))
+ monkeypatch.setattr(
+ integration_module.config_manager,
+ "get_global_config",
+ lambda: SimpleNamespace(model_dump=lambda: {"bot": {"name": "MaiBot"}}, plugin_runtime=fake_global.plugin_runtime),
+ )
+ monkeypatch.setattr(
+ integration_module.config_manager,
+ "get_model_config",
+ lambda: SimpleNamespace(model_dump=lambda: {"models": [{"name": "demo"}]}),
+ )
+
+ manager = integration_module.PluginRuntimeManager()
+ manager._started = True
+ manager._builtin_supervisor = FakeSupervisor(
+ {
+ "test.alpha": FakeRegistration(["bot"]),
+ "test.beta": FakeRegistration([]),
+ }
+ )
+ manager._third_party_supervisor = FakeSupervisor(
+ {
+ "test.gamma": FakeRegistration(["model"]),
+ }
+ )
+
+ await manager._handle_main_config_reload(["bot", "model"])
+
+ assert manager._builtin_supervisor.config_updates == [
+ ("test.alpha", {"bot": {"name": "MaiBot"}}, "", "bot")
+ ]
+ assert manager._third_party_supervisor.config_updates == [
+ ("test.gamma", {"models": [{"name": "demo"}]}, "", "model")
+ ]
+
def test_refresh_plugin_config_watch_subscriptions_registers_per_plugin(self, tmp_path):
from src.plugin_runtime import integration as integration_module
+ import json
builtin_root = tmp_path / "src" / "plugins" / "built_in"
thirdparty_root = tmp_path / "plugins"
@@ -2336,6 +3151,10 @@ class TestIntegration:
beta_dir = thirdparty_root / "beta"
alpha_dir.mkdir(parents=True)
beta_dir.mkdir(parents=True)
+ (alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
+ (beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
+ (alpha_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.alpha")), encoding="utf-8")
+ (beta_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.beta")), encoding="utf-8")
class FakeWatcher:
def __init__(self):
@@ -2358,12 +3177,12 @@ class TestIntegration:
manager = integration_module.PluginRuntimeManager()
manager._plugin_file_watcher = FakeWatcher()
- manager._builtin_supervisor = FakeSupervisor([builtin_root], ["alpha"])
- manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["beta"])
+ manager._builtin_supervisor = FakeSupervisor([builtin_root], ["test.alpha"])
+ manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["test.beta"])
manager._refresh_plugin_config_watch_subscriptions()
- assert set(manager._plugin_config_watcher_subscriptions.keys()) == {"alpha", "beta"}
+ assert set(manager._plugin_config_watcher_subscriptions.keys()) == {"test.alpha", "test.beta"}
assert {
subscription["paths"][0] for subscription in manager._plugin_file_watcher.subscriptions
} == {alpha_dir / "config.toml", beta_dir / "config.toml"}
@@ -2372,55 +3191,30 @@ class TestIntegration:
async def test_component_reload_plugin_returns_failure_when_reload_rolls_back(self, monkeypatch):
from src.plugin_runtime import integration as integration_module
- class FakeSupervisor:
- def __init__(self):
- self._registered_plugins = {"alpha": object()}
+ manager = integration_module.PluginRuntimeManager()
+ monkeypatch.setattr(manager, "reload_plugins_globally", lambda plugin_ids, reason="manual": asyncio.sleep(0, False))
- async def reload_plugins(self, reason="manual"):
- return False
-
- class FakeManager:
- def __init__(self):
- self.supervisors = [FakeSupervisor()]
-
- monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager())
-
- result = await integration_module.PluginRuntimeManager._cap_component_reload_plugin(
+ result = await manager._cap_component_reload_plugin(
"plugin_a",
"component.reload_plugin",
{"plugin_name": "alpha"},
)
assert result["success"] is False
- assert "已回滚" in result["error"]
+ assert result["error"] == "插件 alpha 热重载失败"
@pytest.mark.asyncio
async def test_component_load_plugin_returns_failure_when_reload_rolls_back(self, monkeypatch, tmp_path):
from src.plugin_runtime import integration as integration_module
- plugin_root = tmp_path / "plugins"
- plugin_root.mkdir()
- (plugin_root / "alpha").mkdir()
+ manager = integration_module.PluginRuntimeManager()
+ monkeypatch.setattr(manager, "load_plugin_globally", lambda plugin_id, reason="manual": asyncio.sleep(0, False))
- class FakeSupervisor:
- def __init__(self):
- self._registered_plugins = {}
- self._plugin_dirs = [str(plugin_root)]
-
- async def reload_plugins(self, reason="manual"):
- return False
-
- class FakeManager:
- def __init__(self):
- self.supervisors = [FakeSupervisor()]
-
- monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager())
-
- result = await integration_module.PluginRuntimeManager._cap_component_load_plugin(
+ result = await manager._cap_component_load_plugin(
"plugin_a",
"component.load_plugin",
{"plugin_name": "alpha"},
)
assert result["success"] is False
- assert "已回滚" in result["error"]
+ assert result["error"] == "插件 alpha 热重载失败"
diff --git a/pytests/test_plugin_runtime_action_bridge.py b/pytests/test_plugin_runtime_action_bridge.py
new file mode 100644
index 00000000..e13dfaf3
--- /dev/null
+++ b/pytests/test_plugin_runtime_action_bridge.py
@@ -0,0 +1,284 @@
+"""核心组件查询层与插件运行时聚合测试。"""
+
+from types import SimpleNamespace
+from typing import Any
+
+import pytest
+
+import src.plugin_runtime.integration as integration_module
+
+from src.core.types import ActionInfo, ToolInfo
+from src.plugin_runtime.component_query import component_query_service
+from src.plugin_runtime.host.supervisor import PluginSupervisor
+
+
+class _FakeRuntimeManager:
+ """测试用插件运行时管理器。"""
+
+ def __init__(self, supervisor: PluginSupervisor, plugin_id: str, plugin_config: dict[str, Any]) -> None:
+ """初始化测试用运行时管理器。
+
+ Args:
+ supervisor: 持有测试组件的监督器。
+ plugin_id: 目标插件 ID。
+ plugin_config: 需要返回的插件配置。
+ """
+
+ self.supervisors = [supervisor]
+ self._plugin_id = plugin_id
+ self._plugin_config = plugin_config
+
+ def _get_supervisor_for_plugin(self, plugin_id: str) -> PluginSupervisor | None:
+ """按插件 ID 返回对应监督器。
+
+ Args:
+ plugin_id: 目标插件 ID。
+
+ Returns:
+ PluginSupervisor | None: 命中时返回监督器。
+ """
+
+ return self.supervisors[0] if plugin_id == self._plugin_id else None
+
+ def _load_plugin_config_for_supervisor(self, supervisor: Any, plugin_id: str) -> dict[str, Any]:
+ """返回测试配置。
+
+ Args:
+ supervisor: 监督器实例。
+ plugin_id: 目标插件 ID。
+
+ Returns:
+ dict[str, Any]: 测试配置内容。
+ """
+
+ del supervisor
+ if plugin_id != self._plugin_id:
+ return {}
+ return dict(self._plugin_config)
+
+
+def _install_runtime_manager(
+ monkeypatch: pytest.MonkeyPatch,
+ supervisor: PluginSupervisor,
+ plugin_id: str,
+ plugin_config: dict[str, Any] | None = None,
+) -> None:
+ """为测试安装假的运行时管理器。
+
+ Args:
+ monkeypatch: pytest monkeypatch 对象。
+ supervisor: 持有测试组件的监督器。
+ plugin_id: 测试插件 ID。
+ plugin_config: 可选的测试配置内容。
+ """
+
+ fake_manager = _FakeRuntimeManager(supervisor, plugin_id, plugin_config or {"enabled": True})
+ monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: fake_manager)
+
+
+@pytest.mark.asyncio
+async def test_core_component_registry_reads_runtime_action_and_executor(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """核心查询层应直接读取运行时 Action,并返回 RPC 执行闭包。"""
+
+ plugin_id = "runtime_action_bridge_plugin"
+ action_name = "runtime_action_bridge_test"
+ supervisor = PluginSupervisor(plugin_dirs=[])
+ captured: dict[str, Any] = {}
+
+ supervisor.component_registry.register_component(
+ name=action_name,
+ component_type="ACTION",
+ plugin_id=plugin_id,
+ metadata={
+ "description": "发送一个测试回复",
+ "enabled": True,
+ "activation_type": "keyword",
+ "activation_probability": 0.25,
+ "activation_keywords": ["测试", "hello"],
+ "action_parameters": {"target": "目标对象"},
+ "action_require": ["需要发送回复时使用"],
+ "associated_types": ["text"],
+ "parallel_action": True,
+ },
+ )
+ _install_runtime_manager(monkeypatch, supervisor, plugin_id, {"enabled": True, "mode": "test"})
+
+ async def fake_invoke_plugin(
+ method: str,
+ plugin_id: str,
+ component_name: str,
+ args: dict[str, Any] | None = None,
+ timeout_ms: int = 30000,
+ ) -> Any:
+ """模拟动作 RPC 调用。"""
+
+ captured["method"] = method
+ captured["plugin_id"] = plugin_id
+ captured["component_name"] = component_name
+ captured["args"] = args or {}
+ captured["timeout_ms"] = timeout_ms
+ return SimpleNamespace(payload={"success": True, "result": (True, "runtime action executed")})
+
+ monkeypatch.setattr(supervisor, "invoke_plugin", fake_invoke_plugin)
+
+ action_info = component_query_service.get_action_info(action_name)
+ assert isinstance(action_info, ActionInfo)
+ assert action_info.plugin_name == plugin_id
+ assert action_info.description == "发送一个测试回复"
+ assert action_info.activation_keywords == ["测试", "hello"]
+ assert action_info.random_activation_probability == 0.25
+ assert action_info.parallel_action is True
+ assert action_name in component_query_service.get_default_actions()
+ assert component_query_service.get_plugin_config(plugin_id) == {"enabled": True, "mode": "test"}
+
+ executor = component_query_service.get_action_executor(action_name)
+ assert executor is not None
+
+ success, reason = await executor(
+ action_data={"target": "MaiBot"},
+ action_reasoning="当前适合使用这个动作",
+ cycle_timers={"planner": 0.1},
+ thinking_id="tid-1",
+ chat_stream=SimpleNamespace(session_id="stream-1"),
+ log_prefix="[test]",
+ shutting_down=False,
+ plugin_config={"enabled": True},
+ )
+
+ assert success is True
+ assert reason == "runtime action executed"
+ assert captured["method"] == "plugin.invoke_action"
+ assert captured["plugin_id"] == plugin_id
+ assert captured["component_name"] == action_name
+ assert captured["args"]["stream_id"] == "stream-1"
+ assert captured["args"]["chat_id"] == "stream-1"
+ assert captured["args"]["reasoning"] == "当前适合使用这个动作"
+ assert captured["args"]["target"] == "MaiBot"
+ assert captured["args"]["action_data"] == {"target": "MaiBot"}
+
+
+@pytest.mark.asyncio
+async def test_core_component_registry_reads_runtime_command_and_executor(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """核心查询层应直接使用运行时命令匹配与执行闭包。"""
+
+ plugin_id = "runtime_command_bridge_plugin"
+ command_name = "runtime_command_bridge_test"
+ supervisor = PluginSupervisor(plugin_dirs=[])
+ captured: dict[str, Any] = {}
+
+ supervisor.component_registry.register_component(
+ name=command_name,
+ component_type="COMMAND",
+ plugin_id=plugin_id,
+ metadata={
+ "description": "测试命令",
+ "enabled": True,
+ "command_pattern": r"^/test(?:\s+.+)?$",
+ "aliases": ["/hello"],
+ "intercept_message_level": 1,
+ },
+ )
+ _install_runtime_manager(monkeypatch, supervisor, plugin_id, {"mode": "command"})
+
+ async def fake_invoke_plugin(
+ method: str,
+ plugin_id: str,
+ component_name: str,
+ args: dict[str, Any] | None = None,
+ timeout_ms: int = 30000,
+ ) -> Any:
+ """模拟命令 RPC 调用。"""
+
+ captured["method"] = method
+ captured["plugin_id"] = plugin_id
+ captured["component_name"] = component_name
+ captured["args"] = args or {}
+ captured["timeout_ms"] = timeout_ms
+ return SimpleNamespace(payload={"success": True, "result": (True, "command ok", True)})
+
+ monkeypatch.setattr(supervisor, "invoke_plugin", fake_invoke_plugin)
+
+ matched = component_query_service.find_command_by_text("/test hello")
+ assert matched is not None
+ command_executor, matched_groups, command_info = matched
+
+ assert matched_groups == {}
+ assert command_info.plugin_name == plugin_id
+ assert command_info.command_pattern == r"^/test(?:\s+.+)?$"
+
+ success, response_text, intercept = await command_executor(
+ message=SimpleNamespace(processed_plain_text="/test hello", session_id="stream-2"),
+ plugin_config={"mode": "command"},
+ matched_groups=matched_groups,
+ )
+
+ assert success is True
+ assert response_text == "command ok"
+ assert intercept is True
+ assert captured["method"] == "plugin.invoke_command"
+ assert captured["plugin_id"] == plugin_id
+ assert captured["component_name"] == command_name
+ assert captured["args"]["text"] == "/test hello"
+ assert captured["args"]["stream_id"] == "stream-2"
+ assert captured["args"]["plugin_config"] == {"mode": "command"}
+
+
+@pytest.mark.asyncio
+async def test_core_component_registry_reads_runtime_tools_and_executor(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """核心查询层应直接读取运行时 Tool,并返回 RPC 执行闭包。"""
+
+ plugin_id = "runtime_tool_bridge_plugin"
+ tool_name = "runtime_tool_bridge_test"
+ supervisor = PluginSupervisor(plugin_dirs=[])
+
+ supervisor.component_registry.register_component(
+ name=tool_name,
+ component_type="TOOL",
+ plugin_id=plugin_id,
+ metadata={
+ "description": "测试工具",
+ "enabled": True,
+ "parameters": [
+ {
+ "name": "query",
+ "param_type": "string",
+ "description": "查询词",
+ "required": True,
+ }
+ ],
+ },
+ )
+ _install_runtime_manager(monkeypatch, supervisor, plugin_id)
+
+ async def fake_invoke_plugin(
+ method: str,
+ plugin_id: str,
+ component_name: str,
+ args: dict[str, Any] | None = None,
+ timeout_ms: int = 30000,
+ ) -> Any:
+ """模拟工具 RPC 调用。"""
+
+ del timeout_ms
+ assert method == "plugin.invoke_tool"
+ assert plugin_id == "runtime_tool_bridge_plugin"
+ assert component_name == "runtime_tool_bridge_test"
+ assert args == {"query": "MaiBot"}
+ return SimpleNamespace(payload={"success": True, "result": {"content": "tool ok"}})
+
+ monkeypatch.setattr(supervisor, "invoke_plugin", fake_invoke_plugin)
+
+ tool_info = component_query_service.get_tool_info(tool_name)
+ assert isinstance(tool_info, ToolInfo)
+ assert tool_info.tool_description == "测试工具"
+ assert tool_name in component_query_service.get_llm_available_tools()
+
+ executor = component_query_service.get_tool_executor(tool_name)
+ assert executor is not None
+ assert await executor({"query": "MaiBot"}) == {"content": "tool ok"}
diff --git a/pytests/test_plugin_runtime_api.py b/pytests/test_plugin_runtime_api.py
new file mode 100644
index 00000000..58a8e6ba
--- /dev/null
+++ b/pytests/test_plugin_runtime_api.py
@@ -0,0 +1,524 @@
+"""插件 API 注册与调用测试。"""
+
+from types import SimpleNamespace
+from typing import Any, Dict, List
+
+import pytest
+
+from src.plugin_runtime.integration import PluginRuntimeManager
+from src.plugin_runtime.host.supervisor import PluginSupervisor
+from src.plugin_runtime.protocol.envelope import (
+ ComponentDeclaration,
+ Envelope,
+ MessageType,
+ RegisterPluginPayload,
+ UnregisterPluginPayload,
+)
+
+
+def _build_manager(*supervisors: PluginSupervisor) -> PluginRuntimeManager:
+ """构造一个最小可用的插件运行时管理器。
+
+ Args:
+ *supervisors: 需要挂载的监督器列表。
+
+ Returns:
+ PluginRuntimeManager: 已注入监督器的运行时管理器。
+ """
+
+ manager = PluginRuntimeManager()
+ if supervisors:
+ manager._builtin_supervisor = supervisors[0]
+ if len(supervisors) > 1:
+ manager._third_party_supervisor = supervisors[1]
+ return manager
+
+
+async def _register_plugin(
+ supervisor: PluginSupervisor,
+ plugin_id: str,
+ components: List[Dict[str, Any]],
+) -> Envelope:
+ """通过 Supervisor 注册测试插件。
+
+ Args:
+ supervisor: 目标监督器。
+ plugin_id: 测试插件 ID。
+ components: 组件声明列表。
+
+ Returns:
+ Envelope: 注册响应信封。
+ """
+
+ payload = RegisterPluginPayload(
+ plugin_id=plugin_id,
+ plugin_version="1.0.0",
+ components=[
+ ComponentDeclaration(
+ name=str(component.get("name", "") or ""),
+ component_type=str(component.get("component_type", "") or ""),
+ plugin_id=plugin_id,
+ metadata=component.get("metadata", {}) if isinstance(component.get("metadata"), dict) else {},
+ )
+ for component in components
+ ],
+ )
+ return await supervisor._handle_register_plugin(
+ Envelope(
+ request_id=1,
+ message_type=MessageType.REQUEST,
+ method="plugin.register_components",
+ plugin_id=plugin_id,
+ payload=payload.model_dump(),
+ )
+ )
+
+
+async def _unregister_plugin(supervisor: PluginSupervisor, plugin_id: str) -> Envelope:
+ """通过 Supervisor 注销测试插件。
+
+ Args:
+ supervisor: 目标监督器。
+ plugin_id: 测试插件 ID。
+
+ Returns:
+ Envelope: 注销响应信封。
+ """
+
+ payload = UnregisterPluginPayload(plugin_id=plugin_id, reason="test")
+ return await supervisor._handle_unregister_plugin(
+ Envelope(
+ request_id=2,
+ message_type=MessageType.REQUEST,
+ method="plugin.unregister",
+ plugin_id=plugin_id,
+ payload=payload.model_dump(),
+ )
+ )
+
+
+@pytest.mark.asyncio
+async def test_register_plugin_syncs_dedicated_api_registry() -> None:
+ """插件注册时应将 API 同步到独立注册表,而不是通用组件表。"""
+
+ supervisor = PluginSupervisor(plugin_dirs=[])
+ response = await _register_plugin(
+ supervisor,
+ "provider",
+ [
+ {
+ "name": "render_html",
+ "component_type": "API",
+ "metadata": {
+ "description": "渲染 HTML",
+ "version": "1",
+ "public": True,
+ },
+ }
+ ],
+ )
+
+ assert response.payload["accepted"] is True
+ assert response.payload["registered_components"] == 0
+ assert response.payload["registered_apis"] == 1
+ assert supervisor.api_registry.get_api("provider", "render_html") is not None
+ assert supervisor.component_registry.get_component("provider.render_html") is None
+
+ unregister_response = await _unregister_plugin(supervisor, "provider")
+ assert unregister_response.payload["removed_apis"] == 1
+ assert supervisor.api_registry.get_api("provider", "render_html") is None
+
+
+@pytest.mark.asyncio
+async def test_api_call_allows_public_api_between_plugins(monkeypatch: pytest.MonkeyPatch) -> None:
+ """公开 API 应允许其他插件通过 Host 转发调用。"""
+
+ provider_supervisor = PluginSupervisor(plugin_dirs=[])
+ consumer_supervisor = PluginSupervisor(plugin_dirs=[])
+ await _register_plugin(
+ provider_supervisor,
+ "provider",
+ [
+ {
+ "name": "render_html",
+ "component_type": "API",
+ "metadata": {
+ "description": "渲染 HTML",
+ "version": "1",
+ "public": True,
+ },
+ }
+ ],
+ )
+ await _register_plugin(consumer_supervisor, "consumer", [])
+
+ captured: Dict[str, Any] = {}
+
+ async def fake_invoke_api(
+ plugin_id: str,
+ component_name: str,
+ args: Dict[str, Any] | None = None,
+ timeout_ms: int = 30000,
+ ) -> Any:
+ """模拟 API RPC 调用。"""
+
+ captured["plugin_id"] = plugin_id
+ captured["component_name"] = component_name
+ captured["args"] = args or {}
+ captured["timeout_ms"] = timeout_ms
+ return SimpleNamespace(error=None, payload={"success": True, "result": {"image": "ok"}})
+
+ monkeypatch.setattr(provider_supervisor, "invoke_api", fake_invoke_api)
+
+ manager = _build_manager(provider_supervisor, consumer_supervisor)
+ result = await manager._cap_api_call(
+ "consumer",
+ "api.call",
+ {
+ "api_name": "provider.render_html",
+ "version": "1",
+ "args": {"html": "
Hello
"},
+ },
+ )
+
+ assert result == {"success": True, "result": {"image": "ok"}}
+ assert captured["plugin_id"] == "provider"
+ assert captured["component_name"] == "render_html"
+ assert captured["args"] == {"html": "
Hello
"}
+
+
+@pytest.mark.asyncio
+async def test_api_call_rejects_private_api_between_plugins() -> None:
+ """未公开的 API 默认不允许跨插件调用。"""
+
+ provider_supervisor = PluginSupervisor(plugin_dirs=[])
+ consumer_supervisor = PluginSupervisor(plugin_dirs=[])
+ await _register_plugin(
+ provider_supervisor,
+ "provider",
+ [
+ {
+ "name": "secret_api",
+ "component_type": "API",
+ "metadata": {
+ "description": "私有 API",
+ "version": "1",
+ "public": False,
+ },
+ }
+ ],
+ )
+ await _register_plugin(consumer_supervisor, "consumer", [])
+
+ manager = _build_manager(provider_supervisor, consumer_supervisor)
+ result = await manager._cap_api_call(
+ "consumer",
+ "api.call",
+ {
+ "api_name": "provider.secret_api",
+ "args": {},
+ },
+ )
+
+ assert result["success"] is False
+ assert "未公开" in str(result["error"])
+
+
+@pytest.mark.asyncio
+async def test_api_list_and_component_toggle_use_dedicated_registry() -> None:
+ """API 列表与组件启停应直接作用于独立 API 注册表。"""
+
+ provider_supervisor = PluginSupervisor(plugin_dirs=[])
+ consumer_supervisor = PluginSupervisor(plugin_dirs=[])
+ await _register_plugin(
+ provider_supervisor,
+ "provider",
+ [
+ {
+ "name": "public_api",
+ "component_type": "API",
+ "metadata": {"version": "1", "public": True},
+ },
+ {
+ "name": "private_api",
+ "component_type": "API",
+ "metadata": {"version": "1", "public": False},
+ },
+ ],
+ )
+ await _register_plugin(
+ consumer_supervisor,
+ "consumer",
+ [
+ {
+ "name": "self_private_api",
+ "component_type": "API",
+ "metadata": {"version": "1", "public": False},
+ }
+ ],
+ )
+
+ manager = _build_manager(provider_supervisor, consumer_supervisor)
+ list_result = await manager._cap_api_list("consumer", "api.list", {})
+
+ assert list_result["success"] is True
+ api_names = {(item["plugin_id"], item["name"]) for item in list_result["apis"]}
+ assert ("provider", "public_api") in api_names
+ assert ("provider", "private_api") not in api_names
+ assert ("consumer", "self_private_api") in api_names
+
+ disable_result = await manager._cap_component_disable(
+ "consumer",
+ "component.disable",
+ {
+ "name": "provider.public_api",
+ "component_type": "API",
+ "scope": "global",
+ "stream_id": "",
+ },
+ )
+ assert disable_result["success"] is True
+ assert provider_supervisor.api_registry.get_api("provider", "public_api", enabled_only=True) is None
+
+ enable_result = await manager._cap_component_enable(
+ "consumer",
+ "component.enable",
+ {
+ "name": "provider.public_api",
+ "component_type": "API",
+ "scope": "global",
+ "stream_id": "",
+ },
+ )
+ assert enable_result["success"] is True
+ assert provider_supervisor.api_registry.get_api("provider", "public_api", enabled_only=True) is not None
+
+
+@pytest.mark.asyncio
+async def test_api_registry_supports_multiple_versions_with_distinct_handlers(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """同名 API 不同版本应可并存,并按版本路由到不同处理器。"""
+
+ provider_supervisor = PluginSupervisor(plugin_dirs=[])
+ consumer_supervisor = PluginSupervisor(plugin_dirs=[])
+ await _register_plugin(
+ provider_supervisor,
+ "provider",
+ [
+ {
+ "name": "render_html",
+ "component_type": "API",
+ "metadata": {
+ "description": "渲染 HTML v1",
+ "version": "1",
+ "public": True,
+ "handler_name": "handle_render_html_v1",
+ },
+ },
+ {
+ "name": "render_html",
+ "component_type": "API",
+ "metadata": {
+ "description": "渲染 HTML v2",
+ "version": "2",
+ "public": True,
+ "handler_name": "handle_render_html_v2",
+ },
+ },
+ ],
+ )
+ await _register_plugin(consumer_supervisor, "consumer", [])
+
+ captured: Dict[str, Any] = {}
+
+ async def fake_invoke_api(
+ plugin_id: str,
+ component_name: str,
+ args: Dict[str, Any] | None = None,
+ timeout_ms: int = 30000,
+ ) -> Any:
+ """模拟多版本 API 调用。"""
+
+ captured["plugin_id"] = plugin_id
+ captured["component_name"] = component_name
+ captured["args"] = args or {}
+ captured["timeout_ms"] = timeout_ms
+ return SimpleNamespace(error=None, payload={"success": True, "result": {"image": "ok"}})
+
+ monkeypatch.setattr(provider_supervisor, "invoke_api", fake_invoke_api)
+ manager = _build_manager(provider_supervisor, consumer_supervisor)
+
+ ambiguous_result = await manager._cap_api_call(
+ "consumer",
+ "api.call",
+ {
+ "api_name": "provider.render_html",
+ "args": {"html": "
Hello
"},
+ },
+ )
+ assert ambiguous_result["success"] is False
+ assert "多个版本" in str(ambiguous_result["error"])
+
+ disable_ambiguous_result = await manager._cap_component_disable(
+ "consumer",
+ "component.disable",
+ {
+ "name": "provider.render_html",
+ "component_type": "API",
+ "scope": "global",
+ "stream_id": "",
+ },
+ )
+ assert disable_ambiguous_result["success"] is False
+ assert "多个版本" in str(disable_ambiguous_result["error"])
+
+ disable_v1_result = await manager._cap_component_disable(
+ "consumer",
+ "component.disable",
+ {
+ "name": "provider.render_html",
+ "component_type": "API",
+ "scope": "global",
+ "stream_id": "",
+ "version": "1",
+ },
+ )
+ assert disable_v1_result["success"] is True
+ assert provider_supervisor.api_registry.get_api("provider", "render_html", version="1", enabled_only=True) is None
+ assert provider_supervisor.api_registry.get_api("provider", "render_html", version="2", enabled_only=True) is not None
+
+ result = await manager._cap_api_call(
+ "consumer",
+ "api.call",
+ {
+ "api_name": "provider.render_html",
+ "version": "2",
+ "args": {"html": "
Hello
"},
+ },
+ )
+
+ assert result == {"success": True, "result": {"image": "ok"}}
+ assert captured["plugin_id"] == "provider"
+ assert captured["component_name"] == "handle_render_html_v2"
+ assert captured["args"] == {"html": "
Hello
"}
+
+
+@pytest.mark.asyncio
+async def test_api_replace_dynamic_can_offline_removed_entries(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """动态 API 替换后,被移除的 API 应返回明确下线错误。"""
+
+ supervisor = PluginSupervisor(plugin_dirs=[])
+ await _register_plugin(supervisor, "provider", [])
+ manager = _build_manager(supervisor)
+
+ captured: Dict[str, Any] = {}
+
+ async def fake_invoke_api(
+ plugin_id: str,
+ component_name: str,
+ args: Dict[str, Any] | None = None,
+ timeout_ms: int = 30000,
+ ) -> Any:
+ """模拟动态 API 调用。"""
+
+ captured["plugin_id"] = plugin_id
+ captured["component_name"] = component_name
+ captured["args"] = args or {}
+ captured["timeout_ms"] = timeout_ms
+ return SimpleNamespace(error=None, payload={"success": True, "result": {"ok": True}})
+
+ monkeypatch.setattr(supervisor, "invoke_api", fake_invoke_api)
+
+ replace_result = await manager._cap_api_replace_dynamic(
+ "provider",
+ "api.replace_dynamic",
+ {
+ "apis": [
+ {
+ "name": "mcp.search",
+ "type": "API",
+ "metadata": {
+ "version": "1",
+ "public": True,
+ "handler_name": "dynamic_search",
+ },
+ },
+ {
+ "name": "mcp.read",
+ "type": "API",
+ "metadata": {
+ "version": "1",
+ "public": True,
+ "handler_name": "dynamic_read",
+ },
+ },
+ ],
+ "offline_reason": "MCP 服务器已关闭",
+ },
+ )
+
+ assert replace_result["success"] is True
+ assert replace_result["count"] == 2
+ list_result = await manager._cap_api_list("provider", "api.list", {"plugin_id": "provider"})
+ assert {(item["name"], item["version"]) for item in list_result["apis"]} == {
+ ("mcp.read", "1"),
+ ("mcp.search", "1"),
+ }
+
+ call_result = await manager._cap_api_call(
+ "provider",
+ "api.call",
+ {
+ "api_name": "provider.mcp.search",
+ "version": "1",
+ "args": {"query": "hello"},
+ },
+ )
+ assert call_result == {"success": True, "result": {"ok": True}}
+ assert captured["component_name"] == "dynamic_search"
+ assert captured["args"]["query"] == "hello"
+ assert captured["args"]["__maibot_api_name__"] == "mcp.search"
+ assert captured["args"]["__maibot_api_version__"] == "1"
+
+ second_replace_result = await manager._cap_api_replace_dynamic(
+ "provider",
+ "api.replace_dynamic",
+ {
+ "apis": [
+ {
+ "name": "mcp.read",
+ "type": "API",
+ "metadata": {
+ "version": "1",
+ "public": True,
+ "handler_name": "dynamic_read",
+ },
+ }
+ ],
+ "offline_reason": "MCP 服务器已关闭",
+ },
+ )
+
+ assert second_replace_result["success"] is True
+ assert second_replace_result["count"] == 1
+ assert second_replace_result["offlined"] == 1
+
+ offlined_call_result = await manager._cap_api_call(
+ "provider",
+ "api.call",
+ {
+ "api_name": "provider.mcp.search",
+ "version": "1",
+ "args": {},
+ },
+ )
+ assert offlined_call_result["success"] is False
+ assert "MCP 服务器已关闭" in str(offlined_call_result["error"])
+
+ list_after_replace = await manager._cap_api_list("provider", "api.list", {"plugin_id": "provider"})
+ assert {(item["name"], item["version"]) for item in list_after_replace["apis"]} == {
+ ("mcp.read", "1"),
+ }
diff --git a/pytests/test_send_service.py b/pytests/test_send_service.py
new file mode 100644
index 00000000..16aad080
--- /dev/null
+++ b/pytests/test_send_service.py
@@ -0,0 +1,154 @@
+"""发送服务回归测试。"""
+
+from types import SimpleNamespace
+from typing import Any, Dict, List
+
+import pytest
+
+from src.chat.message_receive.chat_manager import BotChatSession
+from src.services import send_service
+
+
+class _FakePlatformIOManager:
+ """用于测试的 Platform IO 管理器假对象。"""
+
+ def __init__(self, delivery_batch: Any) -> None:
+ """初始化假 Platform IO 管理器。
+
+ Args:
+ delivery_batch: 发送时返回的批量回执。
+ """
+ self._delivery_batch = delivery_batch
+ self.ensure_calls = 0
+ self.sent_messages: List[Dict[str, Any]] = []
+
+ async def ensure_send_pipeline_ready(self) -> None:
+ """记录发送管线准备调用次数。"""
+ self.ensure_calls += 1
+
+ def build_route_key_from_message(self, message: Any) -> Any:
+ """根据消息构造假的路由键。
+
+ Args:
+ message: 待发送的内部消息对象。
+
+ Returns:
+ Any: 简化后的路由键对象。
+ """
+ del message
+ return SimpleNamespace(platform="qq")
+
+ async def send_message(self, message: Any, route_key: Any, metadata: Dict[str, Any]) -> Any:
+ """记录发送请求并返回预设回执。
+
+ Args:
+ message: 待发送的内部消息对象。
+ route_key: 本次发送使用的路由键。
+ metadata: 发送元数据。
+
+ Returns:
+ Any: 预设的批量发送回执。
+ """
+ self.sent_messages.append(
+ {
+ "message": message,
+ "route_key": route_key,
+ "metadata": metadata,
+ }
+ )
+ return self._delivery_batch
+
+
+def _build_target_stream() -> BotChatSession:
+ """构造一个最小可用的目标会话对象。
+
+ Returns:
+ BotChatSession: 测试用会话对象。
+ """
+ return BotChatSession(
+ session_id="test-session",
+ platform="qq",
+ user_id="target-user",
+ group_id=None,
+ )
+
+
+def test_inherit_platform_io_route_metadata_falls_back_to_bot_account(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """没有上下文消息时,也应回填当前平台账号用于账号级路由命中。"""
+
+ monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq" if platform == "qq" else "")
+
+ metadata = send_service._inherit_platform_io_route_metadata(_build_target_stream())
+
+ assert metadata["platform_io_account_id"] == "bot-qq"
+ assert metadata["platform_io_target_user_id"] == "target-user"
+
+
+@pytest.mark.asyncio
+async def test_text_to_stream_delegates_to_platform_io(monkeypatch: pytest.MonkeyPatch) -> None:
+ """send service 应将发送职责统一交给 Platform IO。"""
+ fake_manager = _FakePlatformIOManager(
+ delivery_batch=SimpleNamespace(
+ has_success=True,
+ sent_receipts=[SimpleNamespace(driver_id="plugin.qq.sender")],
+ failed_receipts=[],
+ route_key=SimpleNamespace(platform="qq"),
+ )
+ )
+ stored_messages: List[Any] = []
+
+ monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_manager)
+ monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
+ monkeypatch.setattr(
+ send_service._chat_manager,
+ "get_session_by_session_id",
+ lambda stream_id: _build_target_stream() if stream_id == "test-session" else None,
+ )
+ monkeypatch.setattr(
+ send_service.MessageUtils,
+ "store_message_to_db",
+ lambda message: stored_messages.append(message),
+ )
+
+ result = await send_service.text_to_stream(text="你好", stream_id="test-session")
+
+ assert result is True
+ assert fake_manager.ensure_calls == 1
+ assert len(fake_manager.sent_messages) == 1
+ assert fake_manager.sent_messages[0]["metadata"] == {"show_log": False}
+ assert len(stored_messages) == 1
+
+
+@pytest.mark.asyncio
+async def test_text_to_stream_returns_false_when_platform_io_fails(monkeypatch: pytest.MonkeyPatch) -> None:
+ """Platform IO 批量发送全部失败时,应直接向上返回失败。"""
+ fake_manager = _FakePlatformIOManager(
+ delivery_batch=SimpleNamespace(
+ has_success=False,
+ sent_receipts=[],
+ failed_receipts=[
+ SimpleNamespace(
+ driver_id="plugin.qq.sender",
+ status="failed",
+ error="network error",
+ )
+ ],
+ route_key=SimpleNamespace(platform="qq"),
+ )
+ )
+
+ monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_manager)
+ monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
+ monkeypatch.setattr(
+ send_service._chat_manager,
+ "get_session_by_session_id",
+ lambda stream_id: _build_target_stream() if stream_id == "test-session" else None,
+ )
+
+ result = await send_service.text_to_stream(text="发送失败", stream_id="test-session")
+
+ assert result is False
+ assert fake_manager.ensure_calls == 1
+ assert len(fake_manager.sent_messages) == 1
diff --git a/pytests/utils_test/statistic_test.py b/pytests/utils_test/statistic_test.py
new file mode 100644
index 00000000..d3d8c18a
--- /dev/null
+++ b/pytests/utils_test/statistic_test.py
@@ -0,0 +1,115 @@
+"""统计模块数据库会话行为测试。"""
+
+from __future__ import annotations
+
+from contextlib import contextmanager
+from datetime import datetime, timedelta
+from types import ModuleType
+from typing import Any, Callable, Iterator
+
+import sys
+
+import pytest
+
+from src.chat.utils import statistic
+
+
+class _DummyResult:
+ """模拟 SQLModel 查询结果对象。"""
+
+ def all(self) -> list[Any]:
+ """返回空结果集。
+
+ Returns:
+ list[Any]: 空列表。
+ """
+ return []
+
+
+class _DummySession:
+ """模拟数据库 Session。"""
+
+ def exec(self, statement: Any) -> _DummyResult:
+ """执行查询语句并返回空结果。
+
+ Args:
+ statement: 待执行的查询语句。
+
+ Returns:
+ _DummyResult: 空结果对象。
+ """
+ del statement
+ return _DummyResult()
+
+
+def _build_fake_get_db_session(calls: list[bool]) -> Callable[[bool], Iterator[_DummySession]]:
+ """构造一个记录 auto_commit 参数的假会话工厂。
+
+ Args:
+ calls: 用于记录每次调用 auto_commit 参数的列表。
+
+ Returns:
+ Callable[[bool], Iterator[_DummySession]]: 可替换 `get_db_session` 的上下文管理器工厂。
+ """
+
+ @contextmanager
+ def _fake_get_db_session(auto_commit: bool = True) -> Iterator[_DummySession]:
+ """记录会话参数并返回假 Session。
+
+ Args:
+ auto_commit: 是否启用自动提交。
+
+ Yields:
+ Iterator[_DummySession]: 假 Session 对象。
+ """
+ calls.append(auto_commit)
+ yield _DummySession()
+
+ return _fake_get_db_session
+
+
+def _build_statistic_task() -> statistic.StatisticOutputTask:
+ """构造一个最小可用的统计任务实例。
+
+ Returns:
+ statistic.StatisticOutputTask: 跳过 `__init__` 的测试实例。
+ """
+ task = statistic.StatisticOutputTask.__new__(statistic.StatisticOutputTask)
+ task.name_mapping = {}
+ return task
+
+
+def _is_bot_self(platform: str, user_id: str) -> bool:
+ """返回固定的非机器人身份判断结果。
+
+ Args:
+ platform: 平台名称。
+ user_id: 用户 ID。
+
+ Returns:
+ bool: 始终返回 ``False``。
+ """
+ del platform
+ del user_id
+ return False
+
+
+def test_statistic_read_queries_disable_auto_commit(monkeypatch: pytest.MonkeyPatch) -> None:
+ """统计模块的纯读查询应关闭自动提交,避免 Session 退出后对象被 expire。"""
+ calls: list[bool] = []
+ now = datetime.now()
+ task = _build_statistic_task()
+
+ monkeypatch.setattr(statistic, "get_db_session", _build_fake_get_db_session(calls))
+
+ utils_module = ModuleType("src.chat.utils.utils")
+ utils_module.is_bot_self = _is_bot_self
+ monkeypatch.setitem(sys.modules, "src.chat.utils.utils", utils_module)
+
+ statistic.StatisticOutputTask._fetch_online_time_since(now)
+ statistic.StatisticOutputTask._fetch_model_usage_since(now)
+ task._collect_message_count_for_period([("last_hour", now - timedelta(hours=1))])
+ task._collect_interval_data(now, hours=1, interval_minutes=60)
+ task._collect_metrics_interval_data(now, hours=1, interval_hours=1)
+
+ assert calls == [False] * 9
diff --git a/pytests/utils_test/test_session_utils.py b/pytests/utils_test/test_session_utils.py
new file mode 100644
index 00000000..c44e2eba
--- /dev/null
+++ b/pytests/utils_test/test_session_utils.py
@@ -0,0 +1,42 @@
+from types import SimpleNamespace
+
+from src.chat.message_receive.chat_manager import ChatManager
+from src.common.utils.utils_session import SessionUtils
+
+
+def test_calculate_session_id_distinguishes_account_and_scope() -> None:
+ base_session_id = SessionUtils.calculate_session_id("qq", user_id="42")
+ same_base_session_id = SessionUtils.calculate_session_id("qq", user_id="42")
+ account_scoped_session_id = SessionUtils.calculate_session_id("qq", user_id="42", account_id="123")
+ route_scoped_session_id = SessionUtils.calculate_session_id("qq", user_id="42", account_id="123", scope="main")
+
+ assert base_session_id == same_base_session_id
+ assert account_scoped_session_id != base_session_id
+ assert route_scoped_session_id != account_scoped_session_id
+
+
+def test_chat_manager_register_message_uses_route_metadata() -> None:
+ chat_manager = ChatManager()
+ message = SimpleNamespace(
+ platform="qq",
+ session_id="",
+ message_info=SimpleNamespace(
+ user_info=SimpleNamespace(user_id="42"),
+ group_info=SimpleNamespace(group_id="1000"),
+ additional_config={
+ "platform_io_account_id": "123",
+ "platform_io_scope": "main",
+ },
+ ),
+ )
+
+ chat_manager.register_message(message)
+
+ assert message.session_id == SessionUtils.calculate_session_id(
+ "qq",
+ user_id="42",
+ group_id="1000",
+ account_id="123",
+ scope="main",
+ )
+ assert chat_manager.last_messages[message.session_id] is message
diff --git a/src/chat/brain_chat/PFC/message_sender.py b/src/chat/brain_chat/PFC/message_sender.py
index ec5fb5ba..b9da905c 100644
--- a/src/chat/brain_chat/PFC/message_sender.py
+++ b/src/chat/brain_chat/PFC/message_sender.py
@@ -1,27 +1,28 @@
-import time
+"""PFC 侧消息发送封装。"""
+
from typing import Optional
-from maim_message import Seg
from rich.traceback import install
-from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
from src.chat.message_receive.chat_manager import BotChatSession
-from src.chat.message_receive.message import MessageSending
-from src.chat.message_receive.uni_message_sender import UniversalMessageSender
-from src.chat.utils.utils import get_bot_account
+from src.common.data_models.mai_message_data_model import MaiMessage
from src.common.logger import get_logger
-from src.config.config import global_config
+from src.services import send_service as send_api
install(extra_lines=3)
-
logger = get_logger("message_sender")
class DirectMessageSender:
- """直接消息发送器"""
+ """直接消息发送器。"""
- def __init__(self, private_name: str):
+ def __init__(self, private_name: str) -> None:
+ """初始化直接消息发送器。
+
+ Args:
+ private_name: 当前私聊实例的名称。
+ """
self.private_name = private_name
async def send_message(
@@ -30,58 +31,31 @@ class DirectMessageSender:
content: str,
reply_to_message: Optional[MaiMessage] = None,
) -> None:
- """发送消息到聊天流
+ """发送文本消息到聊天流。
Args:
- chat_stream: 聊天会话
- content: 消息内容
- reply_to_message: 要回复的消息(可选)
+ chat_stream: 目标聊天会话。
+ content: 待发送的文本内容。
+ reply_to_message: 可选的引用回复锚点消息。
+
+ Raises:
+ RuntimeError: 当消息发送失败时抛出。
"""
try:
- # 创建消息内容
- segments = Seg(type="seglist", data=[Seg(type="text", data=content)])
-
- # 获取麦麦的信息
- bot_user_id = get_bot_account(chat_stream.platform)
- if not bot_user_id:
- logger.error(f"[私聊][{self.private_name}]平台 {chat_stream.platform} 未配置机器人账号,无法发送消息")
- raise RuntimeError(f"平台 {chat_stream.platform} 未配置机器人账号")
- bot_user_info = UserInfo(
- user_id=bot_user_id,
- user_nickname=global_config.bot.nickname,
+ sent = await send_api.text_to_stream(
+ text=content,
+ stream_id=chat_stream.session_id,
+ set_reply=reply_to_message is not None,
+ reply_message=reply_to_message,
+ storage_message=True,
)
- # 用当前时间作为message_id,和之前那套sender一样
- message_id = f"dm{round(time.time(), 2)}"
-
- # 构建发送者信息(私聊时为接收者)
- sender_info = None
- if reply_to_message and reply_to_message.message_info and reply_to_message.message_info.user_info:
- sender_info = reply_to_message.message_info.user_info
-
- # 构建消息对象
- message = MessageSending(
- message_id=message_id,
- session=chat_stream,
- bot_user_info=bot_user_info,
- sender_info=sender_info,
- message_segment=segments,
- reply=reply_to_message,
- is_head=True,
- is_emoji=False,
- thinking_start_time=time.time(),
- )
-
- # 发送消息
- message_sender = UniversalMessageSender()
- sent = await message_sender.send_message(message, typing=False, set_reply=False, storage_message=True)
-
if sent:
logger.info(f"[私聊][{self.private_name}]PFC消息已发送: {content}")
- else:
- logger.error(f"[私聊][{self.private_name}]PFC消息发送失败")
- raise RuntimeError("消息发送失败")
+ return
- except Exception as e:
- logger.error(f"[私聊][{self.private_name}]PFC消息发送失败: {str(e)}")
+ logger.error(f"[私聊][{self.private_name}]PFC消息发送失败")
+ raise RuntimeError("消息发送失败")
+ except Exception as exc:
+ logger.error(f"[私聊][{self.private_name}]PFC消息发送失败: {exc}")
raise
diff --git a/src/chat/brain_chat/brain_chat.py b/src/chat/brain_chat/brain_chat.py
index 2b4863ac..1e9e648a 100644
--- a/src/chat/brain_chat/brain_chat.py
+++ b/src/chat/brain_chat/brain_chat.py
@@ -8,8 +8,8 @@ from rich.traceback import install
from src.config.config import global_config
from src.common.logger import get_logger
from src.common.utils.utils_config import ExpressionConfigUtils
-from src.bw_learner.expression_learner import ExpressionLearner
-from src.bw_learner.jargon_miner import JargonMiner
+from src.learners.expression_learner import ExpressionLearner
+from src.learners.jargon_miner import JargonMiner
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.chat.message_receive.message import SessionMessage
diff --git a/src/chat/brain_chat/brain_planner.py b/src/chat/brain_chat/brain_planner.py
index 12b103a0..709be8ee 100644
--- a/src/chat/brain_chat/brain_planner.py
+++ b/src/chat/brain_chat/brain_planner.py
@@ -1,30 +1,32 @@
+from datetime import datetime
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+
import json
-import time
-import traceback
import random
import re
-from typing import Dict, Optional, Tuple, List, TYPE_CHECKING
-from rich.traceback import install
-from datetime import datetime
-from json_repair import repair_json
+import time
+import traceback
+
+from json_repair import repair_json
+from rich.traceback import install
-from src.llm_models.utils_model import LLMRequest
-from src.config.config import global_config, model_config
-from src.common.logger import get_logger
from src.chat.logger.plan_reply_logger import PlanReplyLogger
+from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
+from src.chat.planner_actions.action_manager import ActionManager
+from src.chat.utils.utils import get_chat_type_and_target_info
from src.common.data_models.info_data_model import ActionPlannerInfo
+from src.common.logger import get_logger
from src.common.utils.utils_action import ActionUtils
+from src.config.config import global_config, model_config
+from src.core.types import ActionActivationType, ActionInfo, ComponentType
+from src.llm_models.utils_model import LLMRequest
+from src.plugin_runtime.component_query import component_query_service
from src.prompt.prompt_manager import prompt_manager
from src.services.message_service import (
build_readable_messages_with_id,
get_actions_by_timestamp_with_chat,
get_messages_before_time_in_chat,
)
-from src.chat.utils.utils import get_chat_type_and_target_info
-from src.chat.planner_actions.action_manager import ActionManager
-from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
-from src.core.types import ActionActivationType, ActionInfo, ComponentType
-from src.core.component_registry import component_registry
if TYPE_CHECKING:
from src.common.data_models.info_data_model import TargetPersonInfo
@@ -320,7 +322,7 @@ class BrainPlanner:
current_available_actions_dict = self.action_manager.get_using_actions()
# 获取完整的动作信息
- all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore
+ all_registered_actions: Dict[str, ActionInfo] = component_query_service.get_components_by_type( # type: ignore
ComponentType.ACTION
)
current_available_actions = {}
diff --git a/src/chat/heart_flow/heartFC_chat - 副本.py b/src/chat/heart_flow/heartFC_chat - 副本.py
deleted file mode 100644
index c805597d..00000000
--- a/src/chat/heart_flow/heartFC_chat - 副本.py
+++ /dev/null
@@ -1,734 +0,0 @@
-import asyncio
-import time
-import traceback
-import random
-from typing import List, Optional, Dict, Any, Tuple, TYPE_CHECKING
-from rich.traceback import install
-
-from src.config.config import global_config
-from src.common.logger import get_logger
-from src.common.data_models.info_data_model import ActionPlannerInfo
-from src.common.data_models.message_data_model import ReplyContentType
-from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
-from src.chat.utils.prompt_builder import global_prompt_manager
-from src.chat.utils.timer_calculator import Timer
-from src.chat.planner_actions.planner import ActionPlanner
-from src.chat.planner_actions.action_modifier import ActionModifier
-from src.chat.planner_actions.action_manager import ActionManager
-from src.chat.heart_flow.hfc_utils import CycleDetail
-from src.bw_learner.expression_learner import expression_learner_manager
-from src.chat.heart_flow.frequency_control import frequency_control_manager
-from src.bw_learner.message_recorder import extract_and_distribute_messages
-from src.person_info.person_info import Person
-from src.plugin_system.base.component_types import EventType, ActionInfo
-from src.plugin_system.core import events_manager
-from src.plugin_system.apis import generator_api, send_api, message_api, database_api
-from src.chat.utils.chat_message_builder import (
- build_readable_messages_with_id,
- get_raw_msg_before_timestamp_with_chat,
-)
-from src.chat.utils.utils import record_replyer_action_temp
-from src.memory_system.chat_history_summarizer import ChatHistorySummarizer
-
-if TYPE_CHECKING:
- from src.common.data_models.database_data_model import DatabaseMessages
- from src.common.data_models.message_data_model import ReplySetModel
-
-
-ERROR_LOOP_INFO = {
- "loop_plan_info": {
- "action_result": {
- "action_type": "error",
- "action_data": {},
- "reasoning": "循环处理失败",
- },
- },
- "loop_action_info": {
- "action_taken": False,
- "reply_text": "",
- "command": "",
- "taken_time": time.time(),
- },
-}
-
-
-install(extra_lines=3)
-
-# 注释:原来的动作修改超时常量已移除,因为改为顺序执行
-
-logger = get_logger("hfc") # Logger Name Changed
-
-
-class HeartFChatting:
- """
- 管理一个连续的Focus Chat循环
- 用于在特定聊天流中生成回复。
- 其生命周期现在由其关联的 SubHeartflow 的 FOCUSED 状态控制。
- """
-
- def __init__(self, chat_id: str):
- """
- HeartFChatting 初始化函数
-
- 参数:
- chat_id: 聊天流唯一标识符(如stream_id)
- on_stop_focus_chat: 当收到stop_focus_chat命令时调用的回调函数
- performance_version: 性能记录版本号,用于区分不同启动版本
- """
- # 基础属性
- self.stream_id: str = chat_id # 聊天流ID
- self.chat_stream: ChatStream = get_chat_manager().get_stream(self.stream_id) # type: ignore
- if not self.chat_stream:
- raise ValueError(f"无法找到聊天流: {self.stream_id}")
- self.log_prefix = f"[{get_chat_manager().get_stream_name(self.stream_id) or self.stream_id}]"
-
- self.expression_learner = expression_learner_manager.get_expression_learner(self.stream_id)
-
- self.action_manager = ActionManager()
- self.action_planner = ActionPlanner(chat_id=self.stream_id, action_manager=self.action_manager)
- self.action_modifier = ActionModifier(action_manager=self.action_manager, chat_id=self.stream_id)
-
- # 循环控制内部状态
- self.running: bool = False
- self._loop_task: Optional[asyncio.Task] = None # 主循环任务
-
- # 添加循环信息管理相关的属性
- self.history_loop: List[CycleDetail] = []
- self._cycle_counter = 0
- self._current_cycle_detail: CycleDetail = None # type: ignore
-
- self.last_read_time = time.time() - 2
-
- self.is_mute = False
-
- self.last_active_time = time.time() # 记录上一次非noreply时间
-
- self.question_probability_multiplier = 1
- self.questioned = False
-
- # 跟踪连续 no_reply 次数,用于动态调整阈值
- self.consecutive_no_reply_count = 0
-
- # 聊天内容概括器
- self.chat_history_summarizer = ChatHistorySummarizer(chat_id=self.stream_id)
-
- async def start(self):
- """检查是否需要启动主循环,如果未激活则启动。"""
-
- # 如果循环已经激活,直接返回
- if self.running:
- logger.debug(f"{self.log_prefix} HeartFChatting 已激活,无需重复启动")
- return
-
- try:
- # 标记为活动状态,防止重复启动
- self.running = True
-
- self._loop_task = asyncio.create_task(self._main_chat_loop())
- self._loop_task.add_done_callback(self._handle_loop_completion)
-
- # 启动聊天内容概括器的后台定期检查循环
- await self.chat_history_summarizer.start()
-
- logger.info(f"{self.log_prefix} HeartFChatting 启动完成")
-
- except Exception as e:
- # 启动失败时重置状态
- self.running = False
- self._loop_task = None
- logger.error(f"{self.log_prefix} HeartFChatting 启动失败: {e}")
- raise
-
- def _handle_loop_completion(self, task: asyncio.Task):
- """当 _hfc_loop 任务完成时执行的回调。"""
- try:
- if exception := task.exception():
- logger.error(f"{self.log_prefix} HeartFChatting: 脱离了聊天(异常): {exception}")
- logger.error(traceback.format_exc()) # Log full traceback for exceptions
- else:
- logger.info(f"{self.log_prefix} HeartFChatting: 脱离了聊天 (外部停止)")
- except asyncio.CancelledError:
- logger.info(f"{self.log_prefix} HeartFChatting: 结束了聊天")
-
- def start_cycle(self) -> Tuple[Dict[str, float], str]:
- self._cycle_counter += 1
- self._current_cycle_detail = CycleDetail(self._cycle_counter)
- self._current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}"
- cycle_timers = {}
- return cycle_timers, self._current_cycle_detail.thinking_id
-
- def end_cycle(self, loop_info, cycle_timers):
- self._current_cycle_detail.set_loop_info(loop_info)
- self.history_loop.append(self._current_cycle_detail)
- self._current_cycle_detail.timers = cycle_timers
- self._current_cycle_detail.end_time = time.time()
-
- def print_cycle_info(self, cycle_timers):
- # 记录循环信息和计时器结果
- timer_strings = []
- for name, elapsed in cycle_timers.items():
- if elapsed < 0.1:
- # 不显示小于0.1秒的计时器
- continue
- formatted_time = f"{elapsed:.2f}秒"
- timer_strings.append(f"{name}: {formatted_time}")
-
- logger.info(
- f"{self.log_prefix} 第{self._current_cycle_detail.cycle_id}次思考,"
- f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒;" # type: ignore
- + (f"详情: {'; '.join(timer_strings)}" if timer_strings else "")
- )
-
- async def _loopbody(self):
- recent_messages_list = message_api.get_messages_by_time_in_chat(
- chat_id=self.stream_id,
- start_time=self.last_read_time,
- end_time=time.time(),
- limit=20,
- limit_mode="latest",
- filter_mai=True,
- filter_command=False,
- filter_intercept_message_level=0,
- )
-
- # 根据连续 no_reply 次数动态调整阈值
- # 3次 no_reply 时,阈值调高到 1.5(50%概率为1,50%概率为2)
- # 5次 no_reply 时,提高到 2(大于等于两条消息的阈值)
- if self.consecutive_no_reply_count >= 5:
- threshold = 2
- elif self.consecutive_no_reply_count >= 3:
- # 1.5 的含义:50%概率为1,50%概率为2
- threshold = 2 if random.random() < 0.5 else 1
- else:
- threshold = 1
-
- if len(recent_messages_list) >= threshold:
- # for message in recent_messages_list:
- # print(message.processed_plain_text)
-
- self.last_read_time = time.time()
-
- # !此处使at或者提及必定回复
- mentioned_message = None
- for message in recent_messages_list:
- if (message.is_mentioned or message.is_at) and global_config.chat.mentioned_bot_reply:
- mentioned_message = message
-
- # logger.info(f"{self.log_prefix} 当前talk_value: {global_config.chat.get_talk_value(self.stream_id)}")
-
- # *控制频率用
- if mentioned_message:
- await self._observe(recent_messages_list=recent_messages_list, force_reply_message=mentioned_message)
- elif (
- random.random()
- < global_config.chat.get_talk_value(self.stream_id)
- * frequency_control_manager.get_or_create_frequency_control(self.stream_id).get_talk_frequency_adjust()
- ):
- await self._observe(recent_messages_list=recent_messages_list)
- else:
- # 没有提到,继续保持沉默,等待5秒防止频繁触发
- await asyncio.sleep(10)
- return True
- else:
- await asyncio.sleep(0.2)
- return True
- return True
-
- async def _send_and_store_reply(
- self,
- response_set: "ReplySetModel",
- action_message: "DatabaseMessages",
- cycle_timers: Dict[str, float],
- thinking_id,
- actions,
- selected_expressions: Optional[List[int]] = None,
- quote_message: Optional[bool] = None,
- ) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
- with Timer("回复发送", cycle_timers):
- reply_text = await self._send_response(
- reply_set=response_set,
- message_data=action_message,
- selected_expressions=selected_expressions,
- quote_message=quote_message,
- )
-
- # 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
- platform = action_message.chat_info.platform
- if platform is None:
- platform = getattr(self.chat_stream, "platform", "unknown")
-
- person = Person(platform=platform, user_id=action_message.user_info.user_id)
- person_name = person.person_name
- action_prompt_display = f"你对{person_name}进行了回复:{reply_text}"
-
- await database_api.store_action_info(
- chat_stream=self.chat_stream,
- action_build_into_prompt=False,
- action_prompt_display=action_prompt_display,
- action_done=True,
- thinking_id=thinking_id,
- action_data={"reply_text": reply_text},
- action_name="reply",
- )
-
- # 构建循环信息
- loop_info: Dict[str, Any] = {
- "loop_plan_info": {
- "action_result": actions,
- },
- "loop_action_info": {
- "action_taken": True,
- "reply_text": reply_text,
- "command": "",
- "taken_time": time.time(),
- },
- }
-
- return loop_info, reply_text, cycle_timers
-
- async def _observe(
- self, # interest_value: float = 0.0,
- recent_messages_list: Optional[List["DatabaseMessages"]] = None,
- force_reply_message: Optional["DatabaseMessages"] = None,
- ) -> bool: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
- if recent_messages_list is None:
- recent_messages_list = []
- _reply_text = "" # 初始化reply_text变量,避免UnboundLocalError
-
- start_time = time.time()
- async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
- # 通过 MessageRecorder 统一提取消息并分发给 expression_learner 和 jargon_miner
- # 在 replyer 执行时触发,统一管理时间窗口,避免重复获取消息
- asyncio.create_task(extract_and_distribute_messages(self.stream_id))
-
- # 添加curious检测任务 - 检测聊天记录中的矛盾、冲突或需要提问的内容
- # asyncio.create_task(check_and_make_question(self.stream_id))
- # 添加聊天内容概括任务 - 累积、打包和压缩聊天记录
- # 注意:后台循环已在start()中启动,这里作为额外触发点,在有思考时立即处理
- # asyncio.create_task(self.chat_history_summarizer.process())
-
- cycle_timers, thinking_id = self.start_cycle()
- logger.info(
- f"{self.log_prefix} 开始第{self._cycle_counter}次思考(频率: {global_config.chat.get_talk_value(self.stream_id)})"
- )
-
- # 第一步:动作检查
- available_actions: Dict[str, ActionInfo] = {}
- try:
- await self.action_modifier.modify_actions()
- available_actions = self.action_manager.get_using_actions()
- except Exception as e:
- logger.error(f"{self.log_prefix} 动作修改失败: {e}")
-
- # 执行planner
- is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info()
-
- message_list_before_now = get_raw_msg_before_timestamp_with_chat(
- chat_id=self.stream_id,
- timestamp=time.time(),
- limit=int(global_config.chat.max_context_size * 0.6),
- filter_intercept_message_level=1,
- )
- chat_content_block, message_id_list = build_readable_messages_with_id(
- messages=message_list_before_now,
- timestamp_mode="normal_no_YMD",
- read_mark=self.action_planner.last_obs_time_mark,
- truncate=True,
- show_actions=True,
- )
-
- prompt_info = await self.action_planner.build_planner_prompt(
- is_group_chat=is_group_chat,
- chat_target_info=chat_target_info,
- current_available_actions=available_actions,
- chat_content_block=chat_content_block,
- message_id_list=message_id_list,
- )
- continue_flag, modified_message = await events_manager.handle_mai_events(
- EventType.ON_PLAN, None, prompt_info[0], None, self.chat_stream.stream_id
- )
- if not continue_flag:
- return False
- if modified_message and modified_message._modify_flags.modify_llm_prompt:
- prompt_info = (modified_message.llm_prompt, prompt_info[1])
-
- with Timer("规划器", cycle_timers):
- action_to_use_info = await self.action_planner.plan(
- loop_start_time=self.last_read_time,
- available_actions=available_actions,
- force_reply_message=force_reply_message,
- )
-
- logger.info(
- f"{self.log_prefix} 决定执行{len(action_to_use_info)}个动作: {' '.join([a.action_type for a in action_to_use_info])}"
- )
-
- # 3. 并行执行所有动作
- action_tasks = [
- asyncio.create_task(
- self._execute_action(action, action_to_use_info, thinking_id, available_actions, cycle_timers)
- )
- for action in action_to_use_info
- ]
-
- # 并行执行所有任务
- results = await asyncio.gather(*action_tasks, return_exceptions=True)
-
- # 处理执行结果
- reply_loop_info = None
- reply_text_from_reply = ""
- action_success = False
- action_reply_text = ""
-
- excute_result_str = ""
- for result in results:
- excute_result_str += f"{result['action_type']} 执行结果:{result['result']}\n"
-
- if isinstance(result, BaseException):
- logger.error(f"{self.log_prefix} 动作执行异常: {result}")
- continue
-
- if result["action_type"] != "reply":
- action_success = result["success"]
- action_reply_text = result["result"]
- elif result["action_type"] == "reply":
- if result["success"]:
- reply_loop_info = result["loop_info"]
- reply_text_from_reply = result["result"]
- else:
- logger.warning(f"{self.log_prefix} 回复动作执行失败")
-
- self.action_planner.add_plan_excute_log(result=excute_result_str)
-
- # 构建最终的循环信息
- if reply_loop_info:
- # 如果有回复信息,使用回复的loop_info作为基础
- loop_info = reply_loop_info
- # 更新动作执行信息
- loop_info["loop_action_info"].update(
- {
- "action_taken": action_success,
- "taken_time": time.time(),
- }
- )
- _reply_text = reply_text_from_reply
- else:
- # 没有回复信息,构建纯动作的loop_info
- loop_info = {
- "loop_plan_info": {
- "action_result": action_to_use_info,
- },
- "loop_action_info": {
- "action_taken": action_success,
- "reply_text": action_reply_text,
- "taken_time": time.time(),
- },
- }
- _reply_text = action_reply_text
-
- self.end_cycle(loop_info, cycle_timers)
- self.print_cycle_info(cycle_timers)
-
- end_time = time.time()
- if end_time - start_time < global_config.chat.planner_smooth:
- wait_time = global_config.chat.planner_smooth - (end_time - start_time)
- await asyncio.sleep(wait_time)
- else:
- await asyncio.sleep(0.1)
- return True
-
- async def _main_chat_loop(self):
- """主循环,持续进行计划并可能回复消息,直到被外部取消。"""
- try:
- while self.running:
- # 主循环
- success = await self._loopbody()
- await asyncio.sleep(0.1)
- if not success:
- break
- except asyncio.CancelledError:
- # 设置了关闭标志位后被取消是正常流程
- logger.info(f"{self.log_prefix} 麦麦已关闭聊天")
- except Exception:
- logger.error(f"{self.log_prefix} 麦麦聊天意外错误,将于3s后尝试重新启动")
- print(traceback.format_exc())
- await asyncio.sleep(3)
- self._loop_task = asyncio.create_task(self._main_chat_loop())
- logger.error(f"{self.log_prefix} 结束了当前聊天循环")
-
- async def _handle_action(
- self,
- action: str,
- action_reasoning: str,
- action_data: dict,
- cycle_timers: Dict[str, float],
- thinking_id: str,
- action_message: Optional["DatabaseMessages"] = None,
- ) -> tuple[bool, str, str]:
- """
- 处理规划动作,使用动作工厂创建相应的动作处理器
-
- 参数:
- action: 动作类型
- action_reasoning: 决策理由
- action_data: 动作数据,包含不同动作需要的参数
- cycle_timers: 计时器字典
- thinking_id: 思考ID
- action_message: 消息数据
- 返回:
- tuple[bool, str, str]: (是否执行了动作, 思考消息ID, 命令)
- """
- try:
- # 使用工厂创建动作处理器实例
- try:
- action_handler = self.action_manager.create_action(
- action_name=action,
- action_data=action_data,
- cycle_timers=cycle_timers,
- thinking_id=thinking_id,
- chat_stream=self.chat_stream,
- log_prefix=self.log_prefix,
- action_reasoning=action_reasoning,
- action_message=action_message,
- )
- except Exception as e:
- logger.error(f"{self.log_prefix} 创建动作处理器时出错: {e}")
- traceback.print_exc()
- return False, ""
-
- # 处理动作并获取结果(固定记录一次动作信息)
- result = await action_handler.execute()
- success, action_text = result
-
- return success, action_text
-
- except Exception as e:
- logger.error(f"{self.log_prefix} 处理{action}时出错: {e}")
- traceback.print_exc()
- return False, ""
-
- async def _send_response(
- self,
- reply_set: "ReplySetModel",
- message_data: "DatabaseMessages",
- selected_expressions: Optional[List[int]] = None,
- quote_message: Optional[bool] = None,
- ) -> str:
- # 根据 llm_quote 配置决定是否使用 quote_message 参数
- if global_config.chat.llm_quote:
- # 如果配置为 true,使用 llm_quote 参数决定是否引用回复
- if quote_message is None:
- logger.warning(f"{self.log_prefix} quote_message 参数为空,不引用")
- need_reply = False
- else:
- need_reply = quote_message
- if need_reply:
- logger.info(f"{self.log_prefix} LLM 决定使用引用回复")
- else:
- # 如果配置为 false,使用原来的模式
- new_message_count = message_api.count_new_messages(
- chat_id=self.chat_stream.stream_id, start_time=self.last_read_time, end_time=time.time()
- )
- need_reply = new_message_count >= random.randint(2, 3) or time.time() - self.last_read_time > 90
- if need_reply:
- logger.info(f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,使用引用回复,或者上次回复时间超过90秒")
-
- reply_text = ""
- first_replied = False
- for reply_content in reply_set.reply_data:
- if reply_content.content_type != ReplyContentType.TEXT:
- continue
- data: str = reply_content.content # type: ignore
- if not first_replied:
- await send_api.text_to_stream(
- text=data,
- stream_id=self.chat_stream.stream_id,
- reply_message=message_data,
- set_reply=need_reply,
- typing=False,
- selected_expressions=selected_expressions,
- )
- first_replied = True
- else:
- await send_api.text_to_stream(
- text=data,
- stream_id=self.chat_stream.stream_id,
- reply_message=message_data,
- set_reply=False,
- typing=True,
- selected_expressions=selected_expressions,
- )
- reply_text += data
-
- return reply_text
-
- async def _execute_action(
- self,
- action_planner_info: ActionPlannerInfo,
- chosen_action_plan_infos: List[ActionPlannerInfo],
- thinking_id: str,
- available_actions: Dict[str, ActionInfo],
- cycle_timers: Dict[str, float],
- ):
- """执行单个动作的通用函数"""
- try:
- with Timer(f"动作{action_planner_info.action_type}", cycle_timers):
- # 直接当场执行no_reply逻辑
- if action_planner_info.action_type == "no_reply":
- # 直接处理no_reply逻辑,不再通过动作系统
- reason = action_planner_info.reasoning or "选择不回复"
- # logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}")
-
- # 增加连续 no_reply 计数
- self.consecutive_no_reply_count += 1
-
- await database_api.store_action_info(
- chat_stream=self.chat_stream,
- action_build_into_prompt=False,
- action_prompt_display=reason,
- action_done=True,
- thinking_id=thinking_id,
- action_data={},
- action_name="no_reply",
- action_reasoning=reason,
- )
-
- return {"action_type": "no_reply", "success": True, "result": "选择不回复", "command": ""}
-
- elif action_planner_info.action_type == "reply":
- # 直接当场执行reply逻辑
- self.questioned = False
- # 刷新主动发言状态
- # 重置连续 no_reply 计数
- self.consecutive_no_reply_count = 0
-
- reason = action_planner_info.reasoning or ""
- # 根据 think_mode 配置决定 think_level 的值
- think_mode = global_config.chat.think_mode
- if think_mode == "default":
- think_level = 0
- elif think_mode == "deep":
- think_level = 1
- elif think_mode == "dynamic":
- # dynamic 模式:从 planner 返回的 action_data 中获取
- think_level = action_planner_info.action_data.get("think_level", 1)
- else:
- # 默认使用 default 模式
- think_level = 0
- # 使用 action_reasoning(planner 的整体思考理由)作为 reply_reason
- planner_reasoning = action_planner_info.action_reasoning or reason
-
- record_replyer_action_temp(
- chat_id=self.stream_id,
- reason=reason,
- think_level=think_level,
- )
-
- await database_api.store_action_info(
- chat_stream=self.chat_stream,
- action_build_into_prompt=False,
- action_prompt_display=reason,
- action_done=True,
- thinking_id=thinking_id,
- action_data={},
- action_name="reply",
- action_reasoning=reason,
- )
-
- # 从 Planner 的 action_data 中提取未知词语列表(仅在 reply 时使用)
- unknown_words = None
- quote_message = None
- if isinstance(action_planner_info.action_data, dict):
- uw = action_planner_info.action_data.get("unknown_words")
- if isinstance(uw, list):
- cleaned_uw: List[str] = []
- for item in uw:
- if isinstance(item, str):
- s = item.strip()
- if s:
- cleaned_uw.append(s)
- if cleaned_uw:
- unknown_words = cleaned_uw
-
- # 从 Planner 的 action_data 中提取 quote_message 参数
- qm = action_planner_info.action_data.get("quote")
- if qm is not None:
- # 支持多种格式:true/false, "true"/"false", 1/0
- if isinstance(qm, bool):
- quote_message = qm
- elif isinstance(qm, str):
- quote_message = qm.lower() in ("true", "1", "yes")
- elif isinstance(qm, (int, float)):
- quote_message = bool(qm)
-
- logger.info(f"{self.log_prefix} {qm}引用回复设置: {quote_message}")
-
- success, llm_response = await generator_api.generate_reply(
- chat_stream=self.chat_stream,
- reply_message=action_planner_info.action_message,
- available_actions=available_actions,
- chosen_actions=chosen_action_plan_infos,
- reply_reason=planner_reasoning,
- unknown_words=unknown_words,
- enable_tool=global_config.tool.enable_tool,
- request_type="replyer",
- from_plugin=False,
- reply_time_point=action_planner_info.action_data.get("loop_start_time", time.time()),
- think_level=think_level,
- )
-
- if not success or not llm_response or not llm_response.reply_set:
- if action_planner_info.action_message:
- logger.info(f"对 {action_planner_info.action_message.processed_plain_text} 的回复生成失败")
- else:
- logger.info("回复生成失败")
- return {"action_type": "reply", "success": False, "result": "回复生成失败", "loop_info": None}
-
- response_set = llm_response.reply_set
- selected_expressions = llm_response.selected_expressions
- loop_info, reply_text, _ = await self._send_and_store_reply(
- response_set=response_set,
- action_message=action_planner_info.action_message, # type: ignore
- cycle_timers=cycle_timers,
- thinking_id=thinking_id,
- actions=chosen_action_plan_infos,
- selected_expressions=selected_expressions,
- quote_message=quote_message,
- )
- self.last_active_time = time.time()
- return {
- "action_type": "reply",
- "success": True,
- "result": f"你使用reply动作,对' {action_planner_info.action_message.processed_plain_text} '这句话进行了回复,回复内容为: '{reply_text}'",
- "loop_info": loop_info,
- }
-
- else:
- # 执行普通动作
- with Timer("动作执行", cycle_timers):
- success, result = await self._handle_action(
- action=action_planner_info.action_type,
- action_reasoning=action_planner_info.action_reasoning or "",
- action_data=action_planner_info.action_data or {},
- cycle_timers=cycle_timers,
- thinking_id=thinking_id,
- action_message=action_planner_info.action_message,
- )
-
- self.last_active_time = time.time()
- return {
- "action_type": action_planner_info.action_type,
- "success": success,
- "result": result,
- }
-
- except Exception as e:
- logger.error(f"{self.log_prefix} 执行动作时出错: {e}")
- logger.error(f"{self.log_prefix} 错误信息: {traceback.format_exc()}")
- return {
- "action_type": action_planner_info.action_type,
- "success": False,
- "result": "",
- "loop_info": None,
- "error": str(e),
- }
diff --git a/src/chat/heart_flow/heartFC_chat.py b/src/chat/heart_flow/heartFC_chat.py
index af0beb4e..74d94773 100644
--- a/src/chat/heart_flow/heartFC_chat.py
+++ b/src/chat/heart_flow/heartFC_chat.py
@@ -1,377 +1,231 @@
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from rich.traceback import install
+from typing import List, Optional, TYPE_CHECKING
import asyncio
import random
import time
import traceback
-from rich.traceback import install
-
-from src.bw_learner.expression_learner import ExpressionLearner
-from src.bw_learner.jargon_miner import JargonMiner
-from src.chat.event_helpers import build_event_message
-from src.chat.logger.plan_reply_logger import PlanReplyLogger
-from src.chat.message_receive.chat_manager import BotChatSession
-from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
-from src.chat.planner_actions.action_manager import ActionManager
-from src.chat.planner_actions.action_modifier import ActionModifier
-from src.chat.planner_actions.planner import ActionPlanner
-from src.chat.utils.prompt_builder import global_prompt_manager
-from src.chat.utils.timer_calculator import Timer
-from src.chat.utils.utils import record_replyer_action_temp
-from src.common.data_models.info_data_model import ActionPlannerInfo
-from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
+from src.chat.message_receive.chat_manager import chat_manager
from src.common.logger import get_logger
from src.common.utils.utils_config import ChatConfigUtils, ExpressionConfigUtils
from src.config.config import global_config
from src.config.file_watcher import FileChange
-from src.core.event_bus import event_bus
-from src.core.types import ActionInfo, EventType
-from src.person_info.person_info import Person
-from src.services import (
- database_service as database_api,
- generator_service as generator_api,
- message_service as message_api,
- send_service as send_api,
-)
-from src.services.message_service import build_readable_messages_with_id, get_messages_before_time_in_chat
+from src.learners.expression_learner import ExpressionLearner
+from src.learners.jargon_miner import JargonMiner
from .heartFC_utils import CycleDetail
if TYPE_CHECKING:
from src.chat.message_receive.message import SessionMessage
-
install(extra_lines=5)
logger = get_logger("heartFC_chat")
class HeartFChatting:
- """管理一个持续运行的 Focus Chat 会话。"""
+ """
+ 管理一个连续的Focus Chat聊天会话
+ 用于在特定的聊天会话里面生成回复
+ """
def __init__(self, session_id: str):
- self.session_id = session_id
- self.chat_stream: BotChatSession = _chat_manager.get_session_by_session_id(self.session_id) # type: ignore[assignment]
- if not self.chat_stream:
- raise ValueError(f"无法找到聊天会话 {self.session_id}")
+ """
+ 初始化 HeartFChatting 实例
- session_name = _chat_manager.get_session_name(session_id) or session_id
+ Args:
+ session_id: 聊天会话ID
+ """
+ # 基础属性
+ self.session_id = session_id
+ session_name = chat_manager.get_session_name(session_id) or session_id
self.log_prefix = f"[{session_name}]"
self.session_name = session_name
- self.action_manager = ActionManager()
- self.action_planner = ActionPlanner(chat_id=self.session_id, action_manager=self.action_manager)
- self.action_modifier = ActionModifier(action_manager=self.action_manager, chat_id=self.session_id)
-
+ # 系统运行状态
self._running: bool = False
self._loop_task: Optional[asyncio.Task] = None
+ self._cycle_counter: int = 0
+ self._hfc_lock: asyncio.Lock = asyncio.Lock() # 用于保护 _hfc_func 的并发访问
+ # 聊天频率相关
+ self._consecutive_no_reply_count = 0 # 跟踪连续 no_reply 次数,用于动态调整阈值
+ self._talk_frequency_adjust: float = 1.0 # 发言频率修正值,默认为1.0,可以根据需要调整
+
+ # HFC内消息缓存
+ self.message_cache: List[SessionMessage] = []
+
+ # Asyncio Event 用于控制循环的开始和结束
self._cycle_event = asyncio.Event()
- self._hfc_lock = asyncio.Lock()
-
- self._cycle_counter = 0
- self._current_cycle_detail: Optional[CycleDetail] = None
- self.history_loop: List[CycleDetail] = []
-
- self.last_read_time = time.time() - 2
- self.last_active_time = time.time()
- self._talk_frequency_adjust = 1.0
- self._consecutive_no_reply_count = 0
-
- self.message_cache: List["SessionMessage"] = []
-
- self._min_messages_for_extraction = 30
- self._min_extraction_interval = 60
- self._last_extraction_time = 0.0
+ # 表达方式相关内容
+ self._min_messages_for_extraction = 30 # 最少提取消息数
+ self._min_extraction_interval = 60 # 最小提取时间间隔,单位为秒
+ self._last_extraction_time: float = 0.0 # 上次提取的时间戳
expr_use, jargon_learn, expr_learn = ExpressionConfigUtils.get_expression_config_for_chat(session_id)
- self._enable_expression_use = expr_use
- self._enable_expression_learning = expr_learn
- self._enable_jargon_learning = jargon_learn
- self._expression_learner = ExpressionLearner(session_id)
- self._jargon_miner = JargonMiner(session_id, session_name=session_name)
+ self._enable_expression_use = expr_use # 允许使用表达方式,但不一定启用学习
+ self._enable_expression_learning = expr_learn # 允许学习表达方式
+ self._enable_jargon_learning = jargon_learn # 允许学习黑话
+ # 表达学习器
+ self._expression_learner: ExpressionLearner = ExpressionLearner(session_id)
+ # 黑话挖掘器
+ self._jargon_miner: JargonMiner = JargonMiner(session_id, session_name=session_name)
+
+ # TODO: ChatSummarizer 聊天总结器重构
+
+ # ====== 公开方法 ======
async def start(self):
+ """启动 HeartFChatting 的主循环"""
+ # 先检查是否已经启动运行
if self._running:
- logger.debug(f"{self.log_prefix} HeartFChatting 已在运行中")
+ logger.debug(f"{self.log_prefix} 已经在运行中,无需重复启动")
return
try:
self._running = True
- self._cycle_event.clear()
+ self._cycle_event.clear() # 确保事件初始状态为未设置
+
self._loop_task = asyncio.create_task(self.main_loop())
self._loop_task.add_done_callback(self._handle_loop_completion)
+
logger.info(f"{self.log_prefix} HeartFChatting 启动完成")
- except Exception as exc:
- logger.error(f"{self.log_prefix} HeartFChatting 启动失败: {exc}", exc_info=True)
- self._running = False
- self._cycle_event.set()
- self._loop_task = None
+ except Exception as e:
+ logger.error(f"{self.log_prefix} 启动 HeartFChatting 失败: {e}", exc_info=True)
+ self._running = False # 确保状态正确
+ self._cycle_event.set() # 确保事件被设置,避免死锁
+ self._loop_task = None # 确保任务引用被清理
raise
async def stop(self):
+ """停止 HeartFChatting 的主循环"""
if not self._running:
- logger.debug(f"{self.log_prefix} HeartFChatting 已停止")
+ logger.debug(f"{self.log_prefix} HeartFChatting 已经停止,无需重复停止")
return
self._running = False
- self._cycle_event.set()
+ self._cycle_event.set() # 触发事件,通知循环结束
if self._loop_task:
- self._loop_task.cancel()
+ self._loop_task.cancel() # 取消主循环任务
try:
- await self._loop_task
+ await self._loop_task # 等待任务完成
except asyncio.CancelledError:
- logger.info(f"{self.log_prefix} HeartFChatting 主循环已取消")
- except Exception as exc:
- logger.error(f"{self.log_prefix} 停止 HeartFChatting 时发生错误: {exc}", exc_info=True)
+ logger.info(f"{self.log_prefix} HeartFChatting 主循环已成功取消")
+ except Exception as e:
+ logger.error(f"{self.log_prefix} 停止 HeartFChatting 时发生错误: {e}", exc_info=True)
finally:
- self._loop_task = None
+ self._loop_task = None # 确保任务引用被清理
logger.info(f"{self.log_prefix} HeartFChatting 已停止")
def adjust_talk_frequency(self, new_value: float):
+ """调整发言频率的调整值
+
+ Args:
+ new_value: 新的修正值,必须为非负数。值越大,修正发言频率越高;值越小,修正发言频率越低。
+ """
self._talk_frequency_adjust = max(0.0, new_value)
async def register_message(self, message: "SessionMessage"):
+ """注册一条消息到 HeartFChatting 的缓存中,并检测其是否产生提及,决定是否唤醒聊天
+
+ Args:
+ message: 待注册的消息对象
+ """
self.message_cache.append(message)
-
+ # 先检查at必回复
if global_config.chat.inevitable_at_reply and message.is_at:
- self.last_read_time = time.time()
- async with self._hfc_lock:
- await self._judge_and_response(mentioned_message=message, recent_messages_list=[message])
- return
-
+ async with self._hfc_lock: # 确保与主循环逻辑的互斥访问
+ await self._judge_and_response(message)
+ return # 直接返回,避免同一条消息被主循环再次处理
+ # 再检查提及必回复
if global_config.chat.mentioned_bot_reply and message.is_mentioned:
- self.last_read_time = time.time()
- async with self._hfc_lock:
- await self._judge_and_response(mentioned_message=message, recent_messages_list=[message])
+ # 直接获取锁,确保一定一定触发回复逻辑,不受当前是否正在执行主循环的影响
+ async with self._hfc_lock: # 确保与主循环逻辑的互斥访问
+ await self._judge_and_response(message)
return
async def main_loop(self):
try:
while self._running and not self._cycle_event.is_set():
if not self._hfc_lock.locked():
- async with self._hfc_lock:
+ async with self._hfc_lock: # 确保主循环逻辑的互斥访问
await self._hfc_func()
- await asyncio.sleep(0.1)
+ await asyncio.sleep(5)
except asyncio.CancelledError:
- logger.info(f"{self.log_prefix} HeartFChatting: 主循环被取消")
- except Exception as exc:
- logger.error(f"{self.log_prefix} HeartFChatting: 主循环异常: {exc}", exc_info=True)
- await self.stop()
+ logger.info(f"{self.log_prefix} HeartFChatting: 主循环被取消,正在关闭")
+ except Exception as e:
+ logger.error(f"{self.log_prefix} 麦麦聊天意外错误: {e},将于3s后尝试重新启动")
+ await self.stop() # 确保状态正确
await asyncio.sleep(3)
- await self.start()
+ await self.start() # 尝试重新启动
async def _config_callback(self, file_change: Optional[FileChange] = None):
- del file_change
- expr_use, jargon_learn, expr_learn = ExpressionConfigUtils.get_expression_config_for_chat(self.session_id)
- self._enable_expression_use = expr_use
- self._enable_expression_learning = expr_learn
- self._enable_jargon_learning = jargon_learn
+ """配置文件变更回调函数"""
+ # TODO: 根据配置文件变动重新计算相关参数:
+ """
+ 需要计算的参数:
+ self._enable_expression_use = expr_use # 允许使用表达方式,但不一定启用学习
+ self._enable_expression_learning = expr_learn # 允许学习表达方式
+ self._enable_jargon_learning = jargon_learn # 允许学习黑话
+ """
- async def _hfc_func(self):
- recent_messages_list = message_api.get_messages_by_time_in_chat(
- chat_id=self.session_id,
- start_time=self.last_read_time,
- end_time=time.time(),
- limit=20,
- limit_mode="latest",
- filter_mai=True,
- filter_command=False,
- filter_intercept_message_level=1,
- )
+ # ====== 心流聊天核心逻辑 ======
+ async def _hfc_func(self, mentioned_message: Optional["SessionMessage"] = None):
+ """心流聊天的主循环逻辑"""
+ if self._consecutive_no_reply_count >= 5:
+ threshold = 2
+ elif self._consecutive_no_reply_count >= 3:
+ threshold = 2 if random.random() < 0.5 else 1
+ else:
+ threshold = 1
- if len(recent_messages_list) < 1:
+ if len(self.message_cache) < threshold:
await asyncio.sleep(0.2)
return True
- self.last_read_time = time.time()
-
- mentioned_message: Optional["SessionMessage"] = None
- for message in recent_messages_list:
- if global_config.chat.inevitable_at_reply and message.is_at:
- mentioned_message = message
- elif global_config.chat.mentioned_bot_reply and message.is_mentioned:
- mentioned_message = message
-
- talk_value = ChatConfigUtils.get_talk_value(self.session_id) * self._talk_frequency_adjust
- if mentioned_message:
- await self._judge_and_response(mentioned_message=mentioned_message, recent_messages_list=recent_messages_list)
- elif random.random() < talk_value:
- await self._judge_and_response(recent_messages_list=recent_messages_list)
+ talk_value_threshold = (
+ random.random() * ChatConfigUtils.get_talk_value(self.session_id) * self._talk_frequency_adjust
+ )
+ if mentioned_message and global_config.chat.mentioned_bot_reply:
+ await self._judge_and_response(mentioned_message)
+ elif random.random() < talk_value_threshold:
+ await self._judge_and_response()
return True
- async def _judge_and_response(
- self,
- mentioned_message: Optional["SessionMessage"] = None,
- recent_messages_list: Optional[List["SessionMessage"]] = None,
- ):
- recent_messages = list(recent_messages_list or self.message_cache[-20:])
- if recent_messages:
- asyncio.create_task(self._trigger_expression_learning(recent_messages))
-
- cycle_timers, thinking_id = self._start_cycle()
+ async def _judge_and_response(self, mentioned_message: Optional["SessionMessage"] = None):
+ """判定和生成回复"""
+ asyncio.create_task(self._trigger_expression_learning(self.message_cache))
+ # TODO: 完成反思器之后的逻辑
+ start_time = time.time()
+ current_cycle_detail = self._start_cycle()
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
- try:
- async with global_prompt_manager.async_message_scope(self._get_template_name()):
- available_actions: Dict[str, ActionInfo] = {}
- try:
- await self.action_modifier.modify_actions()
- available_actions = self.action_manager.get_using_actions()
- except Exception as exc:
- logger.error(f"{self.log_prefix} 动作修改失败: {exc}", exc_info=True)
+ # TODO: 动作检查逻辑
+ # TODO: Planner逻辑
+ # TODO: 动作执行逻辑
- is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info()
- message_list_before_now = get_messages_before_time_in_chat(
- chat_id=self.session_id,
- timestamp=time.time(),
- limit=int(global_config.chat.max_context_size * 0.6),
- filter_intercept_message_level=1,
- )
- chat_content_block, message_id_list = build_readable_messages_with_id(
- messages=message_list_before_now,
- timestamp_mode="normal_no_YMD",
- read_mark=self.action_planner.last_obs_time_mark,
- truncate=True,
- show_actions=True,
- )
-
- prompt, filtered_actions = await self._build_planner_prompt_with_event(
- available_actions=available_actions,
- is_group_chat=is_group_chat,
- chat_target_info=chat_target_info,
- chat_content_block=chat_content_block,
- message_id_list=message_id_list,
- )
- if prompt is None:
- return False
-
- with Timer("规划器", cycle_timers):
- reasoning, action_to_use_info, llm_raw_output, llm_reasoning, llm_duration_ms = (
- await self.action_planner._execute_main_planner(
- prompt=prompt,
- message_id_list=message_id_list,
- filtered_actions=filtered_actions,
- available_actions=available_actions,
- loop_start_time=self.last_read_time,
- )
- )
-
- action_to_use_info = self._ensure_force_reply_action(
- actions=action_to_use_info,
- force_reply_message=mentioned_message,
- available_actions=available_actions,
- )
- self.action_planner.add_plan_log(reasoning, action_to_use_info)
- self.action_planner.last_obs_time_mark = time.time()
- self._log_plan(
- prompt=prompt,
- reasoning=reasoning,
- llm_raw_output=llm_raw_output,
- llm_reasoning=llm_reasoning,
- llm_duration_ms=llm_duration_ms,
- actions=action_to_use_info,
- )
-
- logger.info(
- f"{self.log_prefix} 决定执行{len(action_to_use_info)}个动作: {' '.join([a.action_type for a in action_to_use_info])}"
- )
-
- action_tasks = [
- asyncio.create_task(
- self._execute_action(
- action,
- action_to_use_info,
- thinking_id,
- available_actions,
- cycle_timers,
- )
- )
- for action in action_to_use_info
- ]
- results = await asyncio.gather(*action_tasks, return_exceptions=True)
-
- reply_loop_info = None
- reply_text_from_reply = ""
- action_success = False
- action_reply_text = ""
- execute_result_str = ""
-
- for result in results:
- if isinstance(result, BaseException):
- logger.error(f"{self.log_prefix} 动作执行异常: {result}", exc_info=True)
- continue
-
- execute_result_str += f"{result['action_type']} 执行结果:{result['result']}\n"
- if result["action_type"] == "reply":
- if result["success"]:
- reply_loop_info = result["loop_info"]
- reply_text_from_reply = result["result"]
- else:
- logger.warning(f"{self.log_prefix} reply 动作执行失败")
- else:
- action_success = result["success"]
- action_reply_text = result["result"]
-
- self.action_planner.add_plan_excute_log(result=execute_result_str)
-
- if reply_loop_info:
- loop_info = reply_loop_info
- loop_info["loop_action_info"].update(
- {
- "action_taken": action_success,
- "taken_time": time.time(),
- }
- )
- else:
- loop_info = {
- "loop_plan_info": {
- "action_result": action_to_use_info,
- },
- "loop_action_info": {
- "action_taken": action_success,
- "reply_text": action_reply_text,
- "taken_time": time.time(),
- },
- }
- reply_text_from_reply = action_reply_text
-
- current_cycle_detail = self._end_cycle(self._current_cycle_detail, loop_info)
- logger.debug(f"{self.log_prefix} 本轮最终输出: {reply_text_from_reply}")
- return current_cycle_detail is not None
- except Exception as exc:
- logger.error(f"{self.log_prefix} 判定与回复流程失败: {exc}", exc_info=True)
- if self._current_cycle_detail:
- self._end_cycle(
- self._current_cycle_detail,
- {
- "loop_plan_info": {"action_result": []},
- "loop_action_info": {
- "action_taken": False,
- "reply_text": "",
- "taken_time": time.time(),
- "error": str(exc),
- },
- },
- )
- return False
+ cycle_detail = self._end_cycle(current_cycle_detail)
+ if wait_time := global_config.chat.planner_smooth - (time.time() - start_time) > 0:
+ await asyncio.sleep(wait_time)
+ else:
+ await asyncio.sleep(0.1) # 最小等待时间,避免过快循环
+ return True
def _handle_loop_completion(self, task: asyncio.Task):
+ """当 _hfc_func 任务完成时执行的回调。"""
try:
if exception := task.exception():
- logger.error(f"{self.log_prefix} HeartFChatting: 主循环异常退出: {exception}")
- logger.error(traceback.format_exc())
+ logger.error(f"{self.log_prefix} HeartFChatting: 脱离了聊天(异常): {exception}")
+ logger.error(traceback.format_exc()) # Log full traceback for exceptions
else:
- logger.info(f"{self.log_prefix} HeartFChatting: 主循环已退出")
+ logger.info(f"{self.log_prefix} HeartFChatting: 脱离了聊天 (外部停止)")
except asyncio.CancelledError:
- logger.info(f"{self.log_prefix} HeartFChatting: 聊天已结束")
+ logger.info(f"{self.log_prefix} HeartFChatting: 结束了聊天")
+ # ====== 学习器触发逻辑 ======
async def _trigger_expression_learning(self, messages: List["SessionMessage"]):
- if not messages:
- return
-
self._expression_learner.add_messages(messages)
if time.time() - self._last_extraction_time < self._min_extraction_interval:
return
@@ -379,14 +233,12 @@ class HeartFChatting:
return
if not self._enable_expression_learning:
return
-
extraction_end_time = time.time()
logger.info(
f"聊天流 {self.session_name} 提取到 {len(messages)} 条消息,"
f"时间窗口: {self._last_extraction_time:.2f} - {extraction_end_time:.2f}"
)
self._last_extraction_time = extraction_end_time
-
try:
jargon_miner = self._jargon_miner if self._enable_jargon_learning else None
learnt_style = await self._expression_learner.learn(jargon_miner)
@@ -394,398 +246,43 @@ class HeartFChatting:
logger.info(f"{self.log_prefix} 表达学习完成")
else:
logger.debug(f"{self.log_prefix} 表达学习未获得有效结果")
- except Exception as exc:
- logger.error(f"{self.log_prefix} 表达学习失败: {exc}", exc_info=True)
+ except Exception as e:
+ logger.error(f"{self.log_prefix} 表达学习失败: {e}", exc_info=True)
- def _start_cycle(self) -> Tuple[Dict[str, float], str]:
+ # ====== 记录循环执行信息相关逻辑 ======
+ def _start_cycle(self) -> CycleDetail:
self._cycle_counter += 1
- self._current_cycle_detail = CycleDetail(cycle_id=self._cycle_counter)
- self._current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}"
- return self._current_cycle_detail.time_records, self._current_cycle_detail.thinking_id
+ current_cycle_detail = CycleDetail(cycle_id=self._cycle_counter)
+ current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}"
+ return current_cycle_detail
- def _end_cycle(self, cycle_detail: Optional[CycleDetail], loop_info: Optional[Dict[str, Any]] = None):
- if cycle_detail is None:
- return None
-
- cycle_detail.loop_plan_info = (loop_info or {}).get("loop_plan_info")
- cycle_detail.loop_action_info = (loop_info or {}).get("loop_action_info")
+ def _end_cycle(self, cycle_detail: CycleDetail, only_long_execution: bool = True):
cycle_detail.end_time = time.time()
- self.history_loop.append(cycle_detail)
-
- timer_strings = [
+ timer_strings: List[str] = [
f"{name}: {duration:.2f}s"
for name, duration in cycle_detail.time_records.items()
- if duration >= 0.1
+ if not only_long_execution or duration >= 0.1
]
logger.info(
- f"{self.log_prefix} 第{cycle_detail.cycle_id} 个心流循环完成,"
- f"耗时: {cycle_detail.end_time - cycle_detail.start_time:.2f}s;"
+ f"{self.log_prefix} 第 {cycle_detail.cycle_id} 个心流循环完成"
+ f"耗时: {cycle_detail.end_time - cycle_detail.start_time:.2f}秒\n"
f"详细计时: {', '.join(timer_strings) if timer_strings else '无'}"
)
+
return cycle_detail
- async def _execute_action(
- self,
- action_planner_info: ActionPlannerInfo,
- chosen_action_plan_infos: List[ActionPlannerInfo],
- thinking_id: str,
- available_actions: Dict[str, ActionInfo],
- cycle_timers: Dict[str, float],
- ):
- try:
- with Timer(f"动作{action_planner_info.action_type}", cycle_timers):
- if action_planner_info.action_type == "no_reply":
- reason = action_planner_info.reasoning or "选择不回复"
- self._consecutive_no_reply_count += 1
- await database_api.store_action_info(
- chat_stream=self.chat_stream,
- display_prompt=reason,
- thinking_id=thinking_id,
- action_data={},
- action_name="no_reply",
- action_reasoning=reason,
- )
- return {
- "action_type": "no_reply",
- "success": True,
- "result": "选择不回复",
- "loop_info": None,
- }
+ # ====== Action相关逻辑 ======
+ async def _execute_action(self, *args, **kwargs):
+ """原ExecuteAction"""
+ raise NotImplementedError("执行动作的逻辑尚未实现") # TODO: 实现动作执行的逻辑,替换掉*args, **kwargs*占位符
- if action_planner_info.action_type == "reply":
- self._consecutive_no_reply_count = 0
- reason = action_planner_info.reasoning or ""
- think_level = self._get_think_level(action_planner_info)
- planner_reasoning = action_planner_info.action_reasoning or reason
+ async def _execute_other_actions(self, *args, **kwargs):
+ """原HandleAction"""
+ raise NotImplementedError(
+ "执行其他动作的逻辑尚未实现"
+ ) # TODO: 实现其他动作执行的逻辑, 替换掉*args, **kwargs*占位符
- record_replyer_action_temp(
- chat_id=self.session_id,
- reason=reason,
- think_level=think_level,
- )
- await database_api.store_action_info(
- chat_stream=self.chat_stream,
- display_prompt=reason,
- thinking_id=thinking_id,
- action_data={},
- action_name="reply",
- action_reasoning=reason,
- )
-
- unknown_words, quote_message = self._extract_reply_metadata(action_planner_info)
- success, llm_response = await generator_api.generate_reply(
- chat_stream=self.chat_stream,
- reply_message=action_planner_info.action_message,
- available_actions=available_actions,
- chosen_actions=chosen_action_plan_infos,
- reply_reason=planner_reasoning,
- unknown_words=unknown_words,
- enable_tool=global_config.tool.enable_tool,
- request_type="replyer",
- from_plugin=False,
- reply_time_point=action_planner_info.action_data.get("loop_start_time", time.time())
- if action_planner_info.action_data
- else time.time(),
- think_level=think_level,
- )
- if not success or not llm_response or not llm_response.reply_set:
- if action_planner_info.action_message:
- logger.info(
- f"对 {action_planner_info.action_message.processed_plain_text} 的回复生成失败"
- )
- else:
- logger.info(f"{self.log_prefix} 回复生成失败")
- return {
- "action_type": "reply",
- "success": False,
- "result": "回复生成失败",
- "loop_info": None,
- }
-
- loop_info, reply_text, _ = await self._send_and_store_reply(
- response_set=llm_response.reply_set,
- action_message=action_planner_info.action_message, # type: ignore[arg-type]
- cycle_timers=cycle_timers,
- thinking_id=thinking_id,
- actions=chosen_action_plan_infos,
- selected_expressions=llm_response.selected_expressions,
- quote_message=quote_message,
- )
- self.last_active_time = time.time()
- return {
- "action_type": "reply",
- "success": True,
- "result": reply_text,
- "loop_info": loop_info,
- }
-
- with Timer("动作执行", cycle_timers):
- success, result = await self._handle_action(
- action=action_planner_info.action_type,
- action_reasoning=action_planner_info.action_reasoning or "",
- action_data=action_planner_info.action_data or {},
- cycle_timers=cycle_timers,
- thinking_id=thinking_id,
- action_message=action_planner_info.action_message,
- )
- if success:
- self.last_active_time = time.time()
- return {
- "action_type": action_planner_info.action_type,
- "success": success,
- "result": result,
- "loop_info": None,
- }
- except Exception as exc:
- logger.error(f"{self.log_prefix} 执行动作时出错: {exc}", exc_info=True)
- return {
- "action_type": action_planner_info.action_type,
- "success": False,
- "result": "",
- "loop_info": None,
- "error": str(exc),
- }
-
- async def _handle_action(
- self,
- action: str,
- action_reasoning: str,
- action_data: dict,
- cycle_timers: Dict[str, float],
- thinking_id: str,
- action_message: Optional["SessionMessage"] = None,
- ) -> Tuple[bool, str]:
- try:
- action_handler = self.action_manager.create_action(
- action_name=action,
- action_data=action_data,
- action_reasoning=action_reasoning,
- cycle_timers=cycle_timers,
- thinking_id=thinking_id,
- chat_stream=self.chat_stream,
- log_prefix=self.log_prefix,
- action_message=action_message,
- )
- if not action_handler:
- logger.warning(f"{self.log_prefix} 未能创建动作处理器: {action}")
- return False, ""
-
- success, action_text = await action_handler.execute()
- return success, action_text
- except Exception as exc:
- logger.error(f"{self.log_prefix} 处理动作 {action} 时出错: {exc}", exc_info=True)
- return False, ""
-
- async def _send_and_store_reply(
- self,
- response_set: MessageSequence,
- action_message: "SessionMessage",
- cycle_timers: Dict[str, float],
- thinking_id: str,
- actions: List[ActionPlannerInfo],
- selected_expressions: Optional[List[int]] = None,
- quote_message: Optional[bool] = None,
- ) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
- with Timer("回复发送", cycle_timers):
- reply_text = await self._send_response(
- reply_set=response_set,
- message_data=action_message,
- selected_expressions=selected_expressions,
- quote_message=quote_message,
- )
-
- platform = action_message.platform or getattr(self.chat_stream, "platform", "unknown")
- person = Person(platform=platform, user_id=action_message.message_info.user_info.user_id)
- action_prompt_display = f"你对{person.person_name}进行了回复:{reply_text}"
- await database_api.store_action_info(
- chat_stream=self.chat_stream,
- display_prompt=action_prompt_display,
- thinking_id=thinking_id,
- action_data={"reply_text": reply_text},
- action_name="reply",
- )
-
- loop_info: Dict[str, Any] = {
- "loop_plan_info": {
- "action_result": actions,
- },
- "loop_action_info": {
- "action_taken": True,
- "reply_text": reply_text,
- "command": "",
- "taken_time": time.time(),
- },
- }
- return loop_info, reply_text, cycle_timers
-
- async def _send_response(
- self,
- reply_set: MessageSequence,
- message_data: "SessionMessage",
- selected_expressions: Optional[List[int]] = None,
- quote_message: Optional[bool] = None,
- ) -> str:
- if global_config.chat.llm_quote:
- need_reply = bool(quote_message)
- else:
- new_message_count = message_api.count_new_messages(
- chat_id=self.session_id,
- start_time=self.last_read_time,
- end_time=time.time(),
- )
- need_reply = new_message_count >= random.randint(2, 3) or time.time() - self.last_read_time > 90
-
- reply_text = ""
- first_replied = False
- for component in reply_set.components:
- if not isinstance(component, TextComponent):
- continue
- data = component.text
- if not first_replied:
- await send_api.text_to_stream(
- text=data,
- stream_id=self.session_id,
- reply_message=message_data,
- set_reply=need_reply,
- typing=False,
- selected_expressions=selected_expressions,
- )
- first_replied = True
- else:
- await send_api.text_to_stream(
- text=data,
- stream_id=self.session_id,
- reply_message=message_data,
- set_reply=False,
- typing=True,
- selected_expressions=selected_expressions,
- )
- reply_text += data
- return reply_text
-
- async def _build_planner_prompt_with_event(
- self,
- available_actions: Dict[str, ActionInfo],
- is_group_chat: bool,
- chat_target_info: Any,
- chat_content_block: str,
- message_id_list: List[Tuple[str, "SessionMessage"]],
- ) -> Tuple[Optional[str], Dict[str, ActionInfo]]:
- filtered_actions = self.action_planner._filter_actions_by_activation_type(available_actions, chat_content_block)
- prompt, _ = await self.action_planner.build_planner_prompt(
- is_group_chat=is_group_chat,
- chat_target_info=chat_target_info,
- current_available_actions=filtered_actions,
- chat_content_block=chat_content_block,
- message_id_list=message_id_list,
- )
- event_message = build_event_message(EventType.ON_PLAN, llm_prompt=prompt, stream_id=self.session_id)
- continue_flag, modified_message = await event_bus.emit(EventType.ON_PLAN, event_message)
- if not continue_flag:
- logger.info(f"{self.log_prefix} ON_PLAN 事件中止了本轮 HFC")
- return None, filtered_actions
- if modified_message and modified_message._modify_flags.modify_llm_prompt and modified_message.llm_prompt:
- prompt = modified_message.llm_prompt
- return prompt, filtered_actions
-
- def _ensure_force_reply_action(
- self,
- actions: List[ActionPlannerInfo],
- force_reply_message: Optional["SessionMessage"],
- available_actions: Dict[str, ActionInfo],
- ) -> List[ActionPlannerInfo]:
- if not force_reply_message:
- return actions
-
- has_reply_to_force_message = any(
- action.action_type == "reply"
- and action.action_message
- and action.action_message.message_id == force_reply_message.message_id
- for action in actions
- )
- if has_reply_to_force_message:
- return actions
-
- actions = [action for action in actions if action.action_type != "no_reply"]
- actions.insert(
- 0,
- ActionPlannerInfo(
- action_type="reply",
- reasoning="用户提及了我,必须回复该消息",
- action_data={"loop_start_time": self.last_read_time},
- action_message=force_reply_message,
- available_actions=available_actions,
- action_reasoning=None,
- ),
- )
- logger.info(f"{self.log_prefix} 检测到强制回复消息,已补充 reply 动作")
- return actions
-
- def _log_plan(
- self,
- prompt: str,
- reasoning: str,
- llm_raw_output: Optional[str],
- llm_reasoning: Optional[str],
- llm_duration_ms: Optional[float],
- actions: List[ActionPlannerInfo],
- ) -> None:
- try:
- PlanReplyLogger.log_plan(
- chat_id=self.session_id,
- prompt=prompt,
- reasoning=reasoning,
- raw_output=llm_raw_output,
- raw_reasoning=llm_reasoning,
- actions=actions,
- timing={
- "llm_duration_ms": round(llm_duration_ms, 2) if llm_duration_ms is not None else None,
- "loop_start_time": self.last_read_time,
- },
- extra=None,
- )
- except Exception:
- logger.exception(f"{self.log_prefix} 记录 plan 日志失败")
-
- def _extract_reply_metadata(
- self,
- action_planner_info: ActionPlannerInfo,
- ) -> Tuple[Optional[List[str]], Optional[bool]]:
- unknown_words: Optional[List[str]] = None
- quote_message: Optional[bool] = None
- action_data = action_planner_info.action_data or {}
-
- raw_unknown_words = action_data.get("unknown_words")
- if isinstance(raw_unknown_words, list):
- cleaned_unknown_words = []
- for item in raw_unknown_words:
- if isinstance(item, str) and (cleaned_item := item.strip()):
- cleaned_unknown_words.append(cleaned_item)
- if cleaned_unknown_words:
- unknown_words = cleaned_unknown_words
-
- raw_quote = action_data.get("quote")
- if isinstance(raw_quote, bool):
- quote_message = raw_quote
- elif isinstance(raw_quote, str):
- quote_message = raw_quote.lower() in {"true", "1", "yes"}
- elif isinstance(raw_quote, (int, float)):
- quote_message = bool(raw_quote)
-
- return unknown_words, quote_message
-
- def _get_think_level(self, action_planner_info: ActionPlannerInfo) -> int:
- think_mode = global_config.chat.think_mode
- if think_mode == "default":
- return 0
- if think_mode == "deep":
- return 1
- if think_mode == "dynamic":
- action_data = action_planner_info.action_data or {}
- return int(action_data.get("think_level", 1))
- return 0
-
- def _get_template_name(self) -> Optional[str]:
- if self.chat_stream.context:
- return self.chat_stream.context.template_name
- return None
+ # ====== 响应发送相关方法 ======
+ async def _send_response(self, *args, **kwargs):
+ raise NotImplementedError("发送回复的逻辑尚未实现") # TODO: 实现发送回复的逻辑,替换掉*args, **kwargs*占位符
+ # 传入的消息至少应该是个MessageSequence实例,最好是SessionMessage实例,随后可直接转化为MessageSending实例
diff --git a/src/chat/heart_flow/heartflow.py b/src/chat/heart_flow/heartflow.py
deleted file mode 100644
index febff2d5..00000000
--- a/src/chat/heart_flow/heartflow.py
+++ /dev/null
@@ -1,42 +0,0 @@
-import traceback
-from typing import Any, Optional, Dict
-
-from src.chat.message_receive.chat_stream import get_chat_manager
-from src.common.logger import get_logger
-from src.chat.heart_flow.heartFC_chat import HeartFChatting
-from src.chat.brain_chat.brain_chat import BrainChatting
-from src.chat.message_receive.chat_stream import ChatStream
-
-logger = get_logger("heartflow")
-
-
-class Heartflow:
- """主心流协调器,负责初始化并协调聊天"""
-
- def __init__(self):
- self.heartflow_chat_list: Dict[Any, HeartFChatting | BrainChatting] = {}
-
- async def get_or_create_heartflow_chat(self, chat_id: Any) -> Optional[HeartFChatting | BrainChatting]:
- """获取或创建一个新的HeartFChatting实例"""
- try:
- if chat_id in self.heartflow_chat_list:
- if chat := self.heartflow_chat_list.get(chat_id):
- return chat
- else:
- chat_stream: ChatStream | None = get_chat_manager().get_stream(chat_id)
- if not chat_stream:
- raise ValueError(f"未找到 chat_id={chat_id} 的聊天流")
- if chat_stream.group_info:
- new_chat = HeartFChatting(chat_id=chat_id)
- else:
- new_chat = BrainChatting(chat_id=chat_id)
- await new_chat.start()
- self.heartflow_chat_list[chat_id] = new_chat
- return new_chat
- except Exception as e:
- logger.error(f"创建心流聊天 {chat_id} 失败: {e}", exc_info=True)
- traceback.print_exc()
- return None
-
-
-heartflow = Heartflow()
diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py
index df7d28fc..1fc4ef53 100644
--- a/src/chat/message_receive/bot.py
+++ b/src/chat/message_receive/bot.py
@@ -1,19 +1,20 @@
from contextlib import suppress
-import traceback
-import os
-
-from maim_message import MessageBase
from typing import Any, Dict, Optional
+import os
+import traceback
+from maim_message import MessageBase
+
+from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
from src.common.logger import get_logger
from src.common.utils.utils_message import MessageUtils
from src.common.utils.utils_session import SessionUtils
-from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
+from src.platform_io.route_key_factory import RouteKeyFactory
# from src.chat.brain_chat.PFC.pfc_manager import PFCManager
from src.core.announcement_manager import global_announcement_manager
-from src.core.component_registry import component_registry
+from src.plugin_runtime.component_query import component_query_service
from .message import SessionMessage
from .chat_manager import chat_manager
@@ -58,16 +59,22 @@ class ChatBot:
logger.error(f"创建PFC聊天失败: {e}")
logger.error(traceback.format_exc())
- async def _process_commands(self, message: SessionMessage):
- # sourcery skip: use-named-expression
- """使用新插件系统处理命令"""
+ async def _process_commands(self, message: SessionMessage) -> tuple[bool, Optional[str], bool]:
+ """使用统一组件注册表处理命令。
+
+ Args:
+ message: 当前待处理的会话消息。
+
+ Returns:
+ tuple[bool, Optional[str], bool]: ``(是否命中命令, 命令响应文本, 是否继续后续处理)``。
+ """
if not message.processed_plain_text:
return False, None, True # 没有文本内容,继续处理消息
try:
text = message.processed_plain_text
- # 使用核心组件注册表查找命令
- command_result = component_registry.find_command_by_text(text)
+ # 使用插件运行时统一查询服务查找命令
+ command_result = component_query_service.find_command_by_text(text)
if command_result:
command_executor, matched_groups, command_info = command_result
plugin_name = command_info.plugin_name
@@ -81,7 +88,7 @@ class ChatBot:
message.is_command = True
# 获取插件配置
- plugin_config = component_registry.get_plugin_config(plugin_name)
+ plugin_config = component_query_service.get_plugin_config(plugin_name)
try:
# 调用命令执行器
@@ -112,88 +119,32 @@ class ChatBot:
# 命令出错时,根据命令的拦截设置决定是否继续处理消息
return True, str(e), False # 出错时继续处理消息
- # 没有找到旧系统命令,尝试新版本插件运行时
- new_cmd_result = await self._process_new_runtime_command(message)
- return new_cmd_result if new_cmd_result is not None else (False, None, True)
+ return False, None, True
except Exception as e:
logger.error(f"处理命令时出错: {e}")
return False, None, True # 出错时继续处理消息
- async def _process_new_runtime_command(self, message: SessionMessage):
- """尝试在新版本插件运行时中查找并执行命令
-
- Returns:
- (found, response, continue_processing) 三元组,
- 或 None 表示新运行时中也未找到匹配命令。
- """
- from src.plugin_runtime.integration import get_plugin_runtime_manager
-
- prm = get_plugin_runtime_manager()
- if not prm.is_running:
- return None
-
- matched = prm.find_command_by_text(message.processed_plain_text)
- if matched is None:
- return None
-
- command_name = matched["name"]
- if message.session_id and command_name in global_announcement_manager.get_disabled_chat_commands(
- message.session_id
- ):
- logger.info(f"[新运行时] 用户禁用的命令,跳过处理: {matched['full_name']}")
- return False, None, True
-
- message.is_command = True
- logger.info(f"[新运行时] 匹配命令: {matched['full_name']}")
-
- try:
- resp = await prm.invoke_plugin(
- method="plugin.invoke_command",
- plugin_id=matched["plugin_id"],
- component_name=matched["name"],
- args={
- "text": message.processed_plain_text,
- "stream_id": message.session_id or "",
- "matched_groups": matched.get("matched_groups") or {},
- },
- timeout_ms=30000,
- )
-
- payload = resp.payload
- success = payload.get("success", False)
- cmd_result = payload.get("result")
-
- # 拦截位优先从命令返回值中获取(支持运行时动态决定),
- # 回退到组件 metadata 中的静态声明
- if isinstance(cmd_result, (list, tuple)) and len(cmd_result) >= 3:
- # 命令返回 (found, response_text, intercept_bool) 三元组
- response_text = cmd_result[1] if cmd_result[1] is not None else ""
- intercept = bool(cmd_result[2])
- else:
- response_text = cmd_result if cmd_result is not None else ""
- intercept = bool(matched["metadata"].get("intercept_message_level", 0))
-
- self._mark_command_message(message, int(intercept))
-
- if success:
- logger.info(f"[新运行时] 命令执行成功: {matched['full_name']}")
- else:
- logger.warning(f"[新运行时] 命令执行失败: {matched['full_name']} - {response_text}")
-
- return True, response_text, not intercept
-
- except Exception as e:
- logger.error(f"[新运行时] 执行命令 {matched['full_name']} 异常: {e}", exc_info=True)
- return True, str(e), True
-
@staticmethod
def _mark_command_message(message: SessionMessage, intercept_message_level: int) -> None:
+ """标记消息已经被命令链消费。
+
+ Args:
+ message: 待标记的会话消息。
+ intercept_message_level: 命令设置的拦截级别。
+ """
+
message.is_command = True
message.message_info.additional_config["intercept_message_level"] = intercept_message_level
@staticmethod
def _store_intercepted_command_message(message: SessionMessage) -> None:
+ """将被命令链拦截的消息写入数据库。
+
+ Args:
+ message: 已完成命令处理的会话消息。
+ """
+
MessageUtils.store_message_to_db(message)
async def _handle_command_processing_result(
@@ -310,13 +261,28 @@ class ChatBot:
# logger.debug(str(message_data))
maim_raw_message = MessageBase.from_dict(message_data)
message = SessionMessage.from_maim_message(maim_raw_message)
+ await self.receive_message(message)
+
+ except Exception as e:
+ logger.error(f"预处理消息失败: {e}")
+ traceback.print_exc()
+
+ async def receive_message(self, message: SessionMessage):
+ try:
group_info = message.message_info.group_info
user_info = message.message_info.user_info
+ account_id = None
+ scope = None
+ additional_config = message.message_info.additional_config
+ if isinstance(additional_config, dict):
+ account_id, scope = RouteKeyFactory.extract_components(additional_config)
session_id = SessionUtils.calculate_session_id(
message.platform,
user_id=message.message_info.user_info.user_id,
group_id=group_info.group_id if group_info else None,
+ account_id=account_id,
+ scope=scope,
)
message.session_id = session_id # 正确初始化session_id
@@ -359,24 +325,24 @@ class ChatBot:
platform = message.platform
user_id = user_info.user_id
group_id = group_info.group_id if group_info else None
- _ = await chat_manager.get_or_create_session(platform, user_id, group_id) # 确保会话存在
- try:
- from src.services.memory_flow_service import memory_automation_service
-
- await memory_automation_service.on_incoming_message(message)
- except Exception as exc:
- logger.warning(f"[长期记忆自动总结] 注册会话总结器失败: {exc}")
+ _ = await chat_manager.get_or_create_session(
+ platform,
+ user_id,
+ group_id,
+ account_id=account_id,
+ scope=scope,
+ ) # 确保会话存在
# message.update_chat_stream(chat)
# 命令处理 - 使用新插件系统检查并处理命令
# 注意:命令返回的 response 当前只用于日志记录和流程判断,
# 不会在这里自动作为回复消息发送回会话。
- is_command, cmd_result, continue_process = await self._process_commands(message)
+ # is_command, cmd_result, continue_process = await self._process_commands(message)
- # 如果是命令且不需要继续处理,则直接返回
- if is_command and await self._handle_command_processing_result(message, cmd_result, continue_process):
- return
+ # # 如果是命令且不需要继续处理,则直接返回
+ # if is_command and await self._handle_command_processing_result(message, cmd_result, continue_process):
+ # return
# continue_flag, modified_message = await events_manager.handle_mai_events(EventType.ON_MESSAGE, message)
# if not continue_flag:
diff --git a/src/chat/message_receive/chat_manager.py b/src/chat/message_receive/chat_manager.py
index b11d233c..48d89956 100644
--- a/src/chat/message_receive/chat_manager.py
+++ b/src/chat/message_receive/chat_manager.py
@@ -1,15 +1,16 @@
+import asyncio
from datetime import datetime
+from typing import TYPE_CHECKING, Dict, List, Optional
+
from rich.traceback import install
from sqlmodel import select
-from typing import Optional, TYPE_CHECKING, List, Dict
-import asyncio
-
-from src.common.logger import get_logger
from src.common.data_models.chat_session_data_model import MaiChatSession
-from src.common.database.database_model import ChatSession
from src.common.database.database import get_db_session
+from src.common.database.database_model import ChatSession
+from src.common.logger import get_logger
from src.common.utils.utils_session import SessionUtils
+from src.platform_io.route_key_factory import RouteKeyFactory
if TYPE_CHECKING:
from .message import SessionMessage
@@ -82,7 +83,12 @@ class ChatManager:
logger.error(f"初始化聊天管理器出现错误: {e}")
async def get_or_create_session(
- self, platform: str, user_id: str, group_id: Optional[str] = None
+ self,
+ platform: str,
+ user_id: str,
+ group_id: Optional[str] = None,
+ account_id: Optional[str] = None,
+ scope: Optional[str] = None,
) -> BotChatSession:
"""获取会话,如果不存在则创建一个新会话;一个封装方法。
@@ -90,12 +96,20 @@ class ChatManager:
platform: 平台
user_id: 用户ID
group_id: 群ID(如果是群聊)
+ account_id: 平台账号 ID
+ scope: 路由作用域
Returns:
return (BotChatSession) 会话对象
Raises:
Exception: 获取或创建会话时发生错误
"""
- session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
+ session_id = SessionUtils.calculate_session_id(
+ platform,
+ user_id=user_id,
+ group_id=group_id,
+ account_id=account_id,
+ scope=scope,
+ )
if session := self.get_session_by_session_id(session_id):
session.update_active_time()
return session
@@ -131,7 +145,18 @@ class ChatManager:
raise ValueError("消息缺少平台信息")
user_id = message.message_info.user_info.user_id
group_id = message.message_info.group_info.group_id if message.message_info.group_info else None
- session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
+ account_id = None
+ scope = None
+ additional_config = message.message_info.additional_config
+ if isinstance(additional_config, dict):
+ account_id, scope = RouteKeyFactory.extract_components(additional_config)
+ session_id = SessionUtils.calculate_session_id(
+ platform,
+ user_id=user_id,
+ group_id=group_id,
+ account_id=account_id,
+ scope=scope,
+ )
message.session_id = session_id # 确保消息的session_id正确设置
self.last_messages[session_id] = message
@@ -188,7 +213,12 @@ class ChatManager:
return None
def get_session_by_info(
- self, platform: str, user_id: Optional[str] = None, group_id: Optional[str] = None
+ self,
+ platform: str,
+ user_id: Optional[str] = None,
+ group_id: Optional[str] = None,
+ account_id: Optional[str] = None,
+ scope: Optional[str] = None,
) -> Optional[BotChatSession]:
"""根据平台、用户ID和群ID获取对应的会话
@@ -196,10 +226,18 @@ class ChatManager:
platform: 平台
user_id: 用户ID
group_id: 群ID(如果是群聊)
+ account_id: 平台账号 ID
+ scope: 路由作用域
Returns:
return (Optional[BotChatSession]): 会话对象,如果不存在则返回None
"""
- session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
+ session_id = SessionUtils.calculate_session_id(
+ platform,
+ user_id=user_id,
+ group_id=group_id,
+ account_id=account_id,
+ scope=scope,
+ )
return self.get_session_by_session_id(session_id)
def get_session_by_session_id(self, session_id: str) -> Optional[BotChatSession]:
diff --git a/src/chat/message_receive/uni_message_sender.py b/src/chat/message_receive/uni_message_sender.py
index 369c0c51..246a8350 100644
--- a/src/chat/message_receive/uni_message_sender.py
+++ b/src/chat/message_receive/uni_message_sender.py
@@ -1,31 +1,37 @@
-from rich.traceback import install
-from typing import Optional
+from typing import Any, Optional, Tuple
import asyncio
+import traceback
+from rich.traceback import install
-from src.common.message_server.api import get_global_api
-from src.common.logger import get_logger
-from src.common.database.database import get_db_session
from src.chat.message_receive.message import SessionMessage
+from src.chat.utils.utils import calculate_typing_time, truncate_message
from src.common.data_models.message_component_data_model import ReplyComponent
-from src.chat.utils.utils import truncate_message
-from src.chat.utils.utils import calculate_typing_time
+from src.common.database.database import get_db_session
+from src.common.logger import get_logger
+from src.common.message_server.api import get_global_api
+from src.webui.routers.chat.serializers import serialize_message_sequence
install(extra_lines=3)
logger = get_logger("sender")
# WebUI 聊天室的消息广播器(延迟导入避免循环依赖)
-_webui_chat_broadcaster = None
+_webui_chat_broadcaster: Optional[Tuple[Any, Optional[str]]] = None
# 虚拟群 ID 前缀(与 chat_routes.py 保持一致)
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
# TODO: 重构完成后完成webui相关
-def get_webui_chat_broadcaster():
- """获取 WebUI 聊天室广播器"""
+def get_webui_chat_broadcaster() -> Tuple[Any, Optional[str]]:
+ """获取 WebUI 聊天室广播器。
+
+ Returns:
+ Tuple[Any, Optional[str]]: ``(chat_manager, platform_name)`` 二元组;
+ 若 WebUI 相关模块不可用,则元素会退化为 ``None``。
+ """
global _webui_chat_broadcaster
if _webui_chat_broadcaster is None:
try:
@@ -38,102 +44,35 @@ def get_webui_chat_broadcaster():
def is_webui_virtual_group(group_id: str) -> bool:
- """检查是否是 WebUI 虚拟群"""
- return group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX)
-
-
-def parse_message_segments(segment) -> list:
- """解析消息段,转换为 WebUI 可用的格式
-
- 参考 NapCat 适配器的消息解析逻辑
+ """检查是否是 WebUI 虚拟群。
Args:
- segment: Seg 消息段对象
+ group_id: 待判断的群 ID。
Returns:
- list: 消息段列表,每个元素为 {"type": "...", "data": ...}
+ bool: 若群 ID 属于 WebUI 虚拟群则返回 ``True``。
"""
-
- result = []
-
- if segment is None:
- return result
-
- if segment.type == "seglist":
- # 处理消息段列表
- if segment.data:
- for seg in segment.data:
- result.extend(parse_message_segments(seg))
- elif segment.type == "text":
- # 文本消息
- if segment.data:
- result.append({"type": "text", "data": segment.data})
- elif segment.type == "image":
- # 图片消息(base64)
- if segment.data:
- result.append({"type": "image", "data": f"data:image/png;base64,{segment.data}"})
- elif segment.type == "emoji":
- # 表情包消息(base64)
- if segment.data:
- result.append({"type": "emoji", "data": f"data:image/gif;base64,{segment.data}"})
- elif segment.type == "imageurl":
- # 图片链接消息
- if segment.data:
- result.append({"type": "image", "data": segment.data})
- elif segment.type == "face":
- # 原生表情
- result.append({"type": "face", "data": segment.data})
- elif segment.type == "voice":
- # 语音消息(base64)
- if segment.data:
- result.append({"type": "voice", "data": f"data:audio/wav;base64,{segment.data}"})
- elif segment.type == "voiceurl":
- # 语音链接
- if segment.data:
- result.append({"type": "voice", "data": segment.data})
- elif segment.type == "video":
- # 视频消息(base64)
- if segment.data:
- result.append({"type": "video", "data": f"data:video/mp4;base64,{segment.data}"})
- elif segment.type == "videourl":
- # 视频链接
- if segment.data:
- result.append({"type": "video", "data": segment.data})
- elif segment.type == "music":
- # 音乐消息
- result.append({"type": "music", "data": segment.data})
- elif segment.type == "file":
- # 文件消息
- result.append({"type": "file", "data": segment.data})
- elif segment.type == "reply":
- # 回复消息
- result.append({"type": "reply", "data": segment.data})
- elif segment.type == "forward":
- # 转发消息
- forward_items = []
- if segment.data:
- for item in segment.data:
- forward_items.append(
- {
- "content": parse_message_segments(item.get("message_segment", {}))
- if isinstance(item, dict)
- else []
- }
- )
- result.append({"type": "forward", "data": forward_items})
- else:
- # 未知类型,尝试作为文本处理
- if segment.data:
- result.append({"type": "unknown", "original_type": segment.type, "data": str(segment.data)})
-
- return result
+ return bool(group_id) and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX)
-async def _send_message(message: MessageSending, show_log=True) -> bool:
- """合并后的消息发送函数,包含WS发送和日志记录"""
+async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
+ """执行统一的消息发送流程。
+
+ 发送顺序为:
+ 1. WebUI 特殊链路
+ 2. 旧版 ``maim_message`` / API Server 链路
+
+ Args:
+ message: 待发送的内部会话消息。
+ show_log: 是否输出发送成功日志。
+
+ Returns:
+ bool: 是否最终发送成功。
+ """
message_preview = truncate_message(message.processed_plain_text, max_length=200)
platform = message.platform
- group_id = message.session.group_id
+ group_info = message.message_info.group_info
+ group_id = group_info.group_id if group_info is not None else ""
try:
# 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息
@@ -146,7 +85,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
from src.config.config import global_config
# 解析消息段,获取富文本内容
- message_segments = parse_message_segments(message.message_segment)
+ message_segments = serialize_message_sequence(message.raw_message)
# 判断消息类型
# 如果只有一个文本段,使用简单的 text 类型
@@ -185,7 +124,15 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
return True
# Fallback 逻辑: 尝试通过 API Server 发送
- async def send_with_new_api(legacy_exception=None):
+ async def send_with_new_api(legacy_exception: Optional[Exception] = None) -> bool:
+ """通过 API Server 回退链路发送消息。
+
+ Args:
+ legacy_exception: 旧发送链已经抛出的异常;若回退也失败,则重新抛出。
+
+ Returns:
+ bool: 回退链路是否发送成功。
+ """
try:
from src.config.config import global_config
@@ -286,10 +233,24 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
raise e # 重新抛出其他异常
-class UniversalMessageSender:
- """管理消息的注册、即时处理、发送和存储,并跟踪思考状态。"""
+async def send_prepared_message_to_platform(message: SessionMessage, show_log: bool = True) -> bool:
+ """发送一条已完成预处理的消息到底层平台。
- def __init__(self):
+ Args:
+ message: 已经完成回复组件注入、文本处理等预处理的消息对象。
+ show_log: 是否输出发送成功日志。
+
+ Returns:
+ bool: 发送成功时返回 ``True``。
+ """
+ return await _send_message(message, show_log=show_log)
+
+
+class UniversalMessageSender:
+ """旧链与 WebUI 的底层发送器。"""
+
+ def __init__(self) -> None:
+ """初始化统一消息发送器。"""
pass
async def send_message(
@@ -300,18 +261,19 @@ class UniversalMessageSender:
reply_message_id: Optional[str] = None,
storage_message: bool = True,
show_log: bool = True,
- ):
- """
- 处理、发送并存储一条消息。
+ ) -> bool:
+ """通过旧链或 WebUI 发送并存储一条消息。
- 参数:
- message: MessageSession 对象,待发送的消息。
+ Args:
+ message: 待发送的内部消息对象。
typing: 是否模拟打字等待。
- set_reply: 是否构建回复引用消息。
+ set_reply: 是否构建引用回复消息。
+ reply_message_id: 被引用消息的 ID。
+ storage_message: 是否在发送成功后写入数据库。
+ show_log: 是否输出发送日志。
-
- 用法:
- - typing=True 时,发送前会有打字等待。
+ Returns:
+ bool: 发送成功时返回 ``True``。
"""
if not message.message_id:
logger.error("消息缺少 message_id,无法发送")
@@ -364,7 +326,7 @@ class UniversalMessageSender:
)
await asyncio.sleep(typing_time)
- sent_msg = await _send_message(message, show_log=show_log)
+ sent_msg = await send_prepared_message_to_platform(message, show_log=show_log)
if not sent_msg:
return False
diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py
index 167cdcab..8133ac18 100644
--- a/src/chat/planner_actions/action_manager.py
+++ b/src/chat/planner_actions/action_manager.py
@@ -3,8 +3,8 @@ from typing import Dict, Optional, Tuple
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.message_receive.message import SessionMessage
from src.common.logger import get_logger
-from src.core.component_registry import component_registry, ActionExecutor
from src.core.types import ActionInfo
+from src.plugin_runtime.component_query import ActionExecutor, component_query_service
logger = get_logger("action_manager")
@@ -28,7 +28,7 @@ class ActionManager:
"""
动作管理器,用于管理各种类型的动作
- 使用核心组件注册表的 executor-based 模式。
+ 使用插件运行时统一查询服务的 executor-based 模式。
"""
def __init__(self):
@@ -38,7 +38,7 @@ class ActionManager:
self._using_actions: Dict[str, ActionInfo] = {}
# 初始化时将默认动作加载到使用中的动作
- self._using_actions = component_registry.get_default_actions()
+ self._using_actions = component_query_service.get_default_actions()
# === 执行Action方法 ===
@@ -72,17 +72,17 @@ class ActionManager:
Optional[ActionHandle]: 执行句柄,如果动作未注册则返回 None
"""
try:
- executor = component_registry.get_action_executor(action_name)
+ executor = component_query_service.get_action_executor(action_name)
if not executor:
logger.warning(f"{log_prefix} 未找到Action组件: {action_name}")
return None
- info = component_registry.get_action_info(action_name)
+ info = component_query_service.get_action_info(action_name)
if not info:
logger.warning(f"{log_prefix} 未找到Action组件信息: {action_name}")
return None
- plugin_config = component_registry.get_plugin_config(info.plugin_name) or {}
+ plugin_config = component_query_service.get_plugin_config(info.plugin_name) or {}
handle = ActionHandle(
executor,
@@ -133,5 +133,5 @@ class ActionManager:
def restore_actions(self) -> None:
"""恢复到默认动作集"""
actions_to_restore = list(self._using_actions.keys())
- self._using_actions = component_registry.get_default_actions()
+ self._using_actions = component_query_service.get_default_actions()
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")
diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py
index 5184abcb..b21efa6b 100644
--- a/src/chat/planner_actions/planner.py
+++ b/src/chat/planner_actions/planner.py
@@ -1,33 +1,36 @@
+from collections import OrderedDict
+from datetime import datetime
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+import contextlib
import json
-import time
-import traceback
import random
import re
-import contextlib
-from typing import Dict, Optional, Tuple, List, TYPE_CHECKING, Union
-from collections import OrderedDict
-from rich.traceback import install
-from datetime import datetime
+import time
+import traceback
+
from json_repair import repair_json
-from src.llm_models.utils_model import LLMRequest
-from src.config.config import global_config, model_config
-from src.common.logger import get_logger
+from rich.traceback import install
+
from src.chat.logger.plan_reply_logger import PlanReplyLogger
+from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
+from src.chat.message_receive.message import SessionMessage
+from src.chat.planner_actions.action_manager import ActionManager
+from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
from src.common.data_models.info_data_model import ActionPlannerInfo
+from src.common.logger import get_logger
+from src.config.config import global_config, model_config
+from src.core.types import ActionActivationType, ActionInfo, ComponentType
+from src.llm_models.utils_model import LLMRequest
+from src.person_info.person_info import Person
+from src.plugin_runtime.component_query import component_query_service
from src.prompt.prompt_manager import prompt_manager
from src.services.message_service import (
build_readable_messages_with_id,
- replace_user_references,
get_messages_before_time_in_chat,
+ replace_user_references,
translate_pid_to_description,
)
-from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
-from src.chat.planner_actions.action_manager import ActionManager
-from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
-from src.chat.message_receive.message import SessionMessage
-from src.core.types import ActionActivationType, ActionInfo, ComponentType
-from src.core.component_registry import component_registry
-from src.person_info.person_info import Person
if TYPE_CHECKING:
from src.common.data_models.info_data_model import TargetPersonInfo
@@ -634,7 +637,7 @@ class ActionPlanner:
current_available_actions_dict = self.action_manager.get_using_actions()
# 获取完整的动作信息
- all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore
+ all_registered_actions: Dict[str, ActionInfo] = component_query_service.get_components_by_type( # type: ignore
ComponentType.ACTION
)
current_available_actions = {}
diff --git a/src/chat/replyer/group_generator.py b/src/chat/replyer/group_generator.py
index 003009b8..4ffa14a7 100644
--- a/src/chat/replyer/group_generator.py
+++ b/src/chat/replyer/group_generator.py
@@ -1,6 +1,7 @@
import traceback
import time
import asyncio
+import importlib
import random
import re
@@ -16,7 +17,6 @@ from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo as MaimUser
from src.common.data_models.mai_message_data_model import MaiMessage
from src.chat.message_receive.message import SessionMessage
from src.chat.message_receive.chat_manager import BotChatSession
-from src.chat.message_receive.uni_message_sender import UniversalMessageSender
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
from src.chat.utils.utils import get_bot_account, get_chat_type_and_target_info, is_bot_self
from src.prompt.prompt_manager import prompt_manager
@@ -26,7 +26,7 @@ from src.services.message_service import (
replace_user_references,
translate_pid_to_description,
)
-from src.bw_learner.expression_selector import expression_selector
+from src.learners.expression_selector import expression_selector
# from src.memory_system.memory_activator import MemoryActivator
from src.person_info.person_info import Person
@@ -35,8 +35,7 @@ from src.services import llm_service as llm_api
from src.chat.logger.plan_reply_logger import PlanReplyLogger
from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt
-from src.memory_system.retrieval_tools import get_tool_registry
-from src.bw_learner.jargon_explainer_old import explain_jargon_in_context, retrieve_concepts_with_jargon
+from src.learners.jargon_explainer_old import explain_jargon_in_context, retrieve_concepts_with_jargon
from src.chat.utils.common_utils import TempMethodsExpression
init_memory_retrieval_sys()
@@ -51,10 +50,15 @@ class DefaultReplyer:
chat_stream: BotChatSession,
request_type: str = "replyer",
):
+ """初始化群聊回复器。
+
+ Args:
+ chat_stream: 当前绑定的聊天会话。
+ request_type: LLM 请求类型标识。
+ """
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
self.chat_stream = chat_stream
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id)
- self.heart_fc_sender = UniversalMessageSender()
from src.chat.tool_executor import ToolExecutor
@@ -1129,7 +1133,10 @@ class DefaultReplyer:
user_id=bot_user_id,
user_nickname=global_config.bot.nickname,
),
- additional_config={},
+ additional_config={
+ "platform_io_target_group_id": self.chat_stream.group_id,
+ "platform_io_target_user_id": self.chat_stream.user_id,
+ },
),
message_segment=message_segment,
)
@@ -1164,14 +1171,29 @@ class DefaultReplyer:
async def get_prompt_info(self, message: str, sender: str, target: str):
related_info = ""
start_time = time.time()
- search_knowledge_tool = get_tool_registry().get_tool("search_long_term_memory")
- if search_knowledge_tool is None:
- logger.debug("长期记忆检索工具未注册,跳过获取知识内容")
+ try:
+ knowledge_module = importlib.import_module("src.plugins.built_in.knowledge.lpmm_get_knowledge")
+ except ImportError:
+ logger.debug("LPMM知识库工具模块不存在,跳过获取知识库内容")
return ""
- logger.debug(f"获取长期记忆内容,元消息:{message[:30]}...,消息长度: {len(message)}")
+ search_knowledge_tool = getattr(knowledge_module, "SearchKnowledgeFromLPMMTool", None)
+ if search_knowledge_tool is None:
+ logger.debug("LPMM知识库工具未提供 SearchKnowledgeFromLPMMTool,跳过获取知识库内容")
+ return ""
+
+ logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
+ # 从LPMM知识库获取知识
try:
- template_prompt = prompt_manager.get_prompt("memory_get_knowledge")
+ # 检查LPMM知识库是否启用
+ if not global_config.lpmm_knowledge.enable:
+ logger.debug("LPMM知识库未启用,跳过获取知识库内容")
+ return ""
+
+ if global_config.lpmm_knowledge.lpmm_mode == "agent":
+ return ""
+
+ template_prompt = prompt_manager.get_prompt("lpmm_get_knowledge")
template_prompt.add_context("bot_name", global_config.bot.nickname)
template_prompt.add_context("time_now", lambda _: time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
template_prompt.add_context("chat_history", message)
@@ -1187,31 +1209,24 @@ class DefaultReplyer:
# logger.info(f"工具调用提示词: {prompt}")
# logger.info(f"工具调用: {tool_calls}")
- if not tool_calls:
- logger.debug("模型认为不需要使用长期记忆")
+ if tool_calls:
+ result = await self.tool_executor.execute_tool_call(tool_calls[0])
+ end_time = time.time()
+ if not result or not result.get("content"):
+ logger.debug("从LPMM知识库获取知识失败,返回空知识...")
+ return ""
+ found_knowledge_from_lpmm = result.get("content", "")
+ logger.info(
+ f"从LPMM知识库获取知识,相关信息:{found_knowledge_from_lpmm[:100]}...,信息长度: {len(found_knowledge_from_lpmm)}"
+ )
+ related_info += found_knowledge_from_lpmm
+ logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}秒")
+ logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
+
+ return f"你有以下这些**知识**:\n{related_info}\n请你**记住上面的知识**,之后可能会用到。\n"
+ else:
+ logger.debug("模型认为不需要使用LPMM知识库")
return ""
-
- related_chunks: List[str] = []
- for tool_call in tool_calls:
- if tool_call.func_name != "search_long_term_memory":
- continue
- tool_args = dict(tool_call.args or {})
- tool_args.setdefault("chat_id", self.chat_stream.session_id)
- result_text = await search_knowledge_tool.execute(**tool_args)
- if result_text and "未找到" not in result_text:
- related_chunks.append(result_text)
-
- if not related_chunks:
- logger.debug("长期记忆未返回有效信息")
- return ""
-
- related_info = "\n".join(related_chunks)
- end_time = time.time()
- logger.info(f"从长期记忆获取知识,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
- logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}秒")
- logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
-
- return f"你有以下这些**知识**:\n{related_info}\n请你**记住上面的知识**,之后可能会用到。\n"
except Exception as e:
logger.error(f"获取知识库内容时发生异常: {str(e)}")
return ""
diff --git a/src/chat/replyer/private_generator.py b/src/chat/replyer/private_generator.py
index 3b70bb2c..c125a42f 100644
--- a/src/chat/replyer/private_generator.py
+++ b/src/chat/replyer/private_generator.py
@@ -16,7 +16,6 @@ from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo as MaimUser
from src.common.data_models.mai_message_data_model import MaiMessage
from src.chat.message_receive.message import SessionMessage
from src.chat.message_receive.chat_manager import BotChatSession
-from src.chat.message_receive.uni_message_sender import UniversalMessageSender
from src.chat.utils.timer_calculator import Timer
from src.chat.utils.utils import get_bot_account, get_chat_type_and_target_info, is_bot_self
from src.prompt.prompt_manager import prompt_manager
@@ -27,13 +26,13 @@ from src.services.message_service import (
replace_user_references,
translate_pid_to_description,
)
-from src.bw_learner.expression_selector import expression_selector
+from src.learners.expression_selector import expression_selector
# from src.memory_system.memory_activator import MemoryActivator
from src.person_info.person_info import Person, is_person_known
from src.core.types import ActionInfo, EventType
from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt
-from src.bw_learner.jargon_explainer_old import explain_jargon_in_context
+from src.learners.jargon_explainer_old import explain_jargon_in_context
init_memory_retrieval_sys()
@@ -47,10 +46,15 @@ class PrivateReplyer:
chat_stream: BotChatSession,
request_type: str = "replyer",
):
+ """初始化私聊回复器。
+
+ Args:
+ chat_stream: 当前绑定的聊天会话。
+ request_type: LLM 请求类型标识。
+ """
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
self.chat_stream = chat_stream
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id)
- self.heart_fc_sender = UniversalMessageSender()
# self.memory_activator = MemoryActivator()
from src.chat.tool_executor import ToolExecutor
@@ -970,7 +974,9 @@ class PrivateReplyer:
user_nickname=global_config.bot.nickname,
),
group_info=None,
- additional_config={},
+ additional_config={
+ "platform_io_target_user_id": self.chat_stream.user_id,
+ },
),
message_segment=message_segment,
)
diff --git a/src/chat/tool_executor.py b/src/chat/tool_executor.py
index d449f7a1..aa99fce8 100644
--- a/src/chat/tool_executor.py
+++ b/src/chat/tool_executor.py
@@ -1,22 +1,20 @@
-"""
-工具执行器
+"""工具执行器。
独立的工具执行组件,可以直接输入聊天消息内容,
自动判断并执行相应的工具,返回结构化的工具执行结果。
-
-从 src.plugin_system.core.tool_use 迁移,使用新的核心组件注册表。
"""
+from typing import Any, Dict, List, Optional, Tuple
+
import hashlib
import time
-from typing import Any, Dict, List, Optional, Tuple
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.core.announcement_manager import global_announcement_manager
-from src.core.component_registry import component_registry
from src.llm_models.payload_content import ToolCall
from src.llm_models.utils_model import LLMRequest
+from src.plugin_runtime.component_query import component_query_service
from src.prompt.prompt_manager import prompt_manager
logger = get_logger("tool_use")
@@ -89,7 +87,7 @@ class ToolExecutor:
def _get_tool_definitions(self) -> List[Dict[str, Any]]:
"""获取 LLM 可用的工具定义列表"""
- all_tools = component_registry.get_llm_available_tools()
+ all_tools = component_query_service.get_llm_available_tools()
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
return [info.get_llm_definition() for name, info in all_tools.items() if name not in user_disabled_tools]
@@ -152,7 +150,7 @@ class ToolExecutor:
function_args = tool_call.args or {}
function_args["llm_called"] = True
- executor = component_registry.get_tool_executor(function_name)
+ executor = component_query_service.get_tool_executor(function_name)
if not executor:
logger.warning(f"未知工具名称: {function_name}")
return None
diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py
index ede10a41..51e5e643 100644
--- a/src/chat/utils/statistic.py
+++ b/src/chat/utils/statistic.py
@@ -436,14 +436,14 @@ class StatisticOutputTask(AsyncTask):
@staticmethod
def _fetch_online_time_since(query_start_time: datetime) -> list[tuple[datetime, datetime]]:
- with get_db_session() as session:
+ with get_db_session(auto_commit=False) as session:
statement = select(OnlineTime).where(col(OnlineTime.end_timestamp) >= query_start_time)
records = session.exec(statement).all()
return [(record.start_timestamp, record.end_timestamp) for record in records]
@staticmethod
def _fetch_model_usage_since(query_start_time: datetime) -> list[dict[str, object]]:
- with get_db_session() as session:
+ with get_db_session(auto_commit=False) as session:
statement = select(ModelUsage).where(col(ModelUsage.timestamp) >= query_start_time)
records = session.exec(statement).all()
return [
@@ -664,7 +664,7 @@ class StatisticOutputTask(AsyncTask):
}
query_start_timestamp = collect_period[-1][1]
- with get_db_session() as session:
+ with get_db_session(auto_commit=False) as session:
statement = select(Messages).where(col(Messages.timestamp) >= query_start_timestamp)
messages = session.exec(statement).all()
for message in messages:
@@ -713,7 +713,7 @@ class StatisticOutputTask(AsyncTask):
# 使用 ActionRecords 中的 reply 动作次数作为回复数基准
try:
action_query_start_timestamp = collect_period[-1][1]
- with get_db_session() as session:
+ with get_db_session(auto_commit=False) as session:
statement = select(ActionRecord).where(col(ActionRecord.timestamp) >= action_query_start_timestamp)
actions = session.exec(statement).all()
for action in actions:
@@ -1750,7 +1750,7 @@ class StatisticOutputTask(AsyncTask):
# 查询消息记录
query_start_timestamp = start_time.timestamp()
- with get_db_session() as session:
+ with get_db_session(auto_commit=False) as session:
statement = select(Messages).where(col(Messages.timestamp) >= start_time)
messages = session.exec(statement).all()
for message in messages:
@@ -2131,7 +2131,7 @@ class StatisticOutputTask(AsyncTask):
# 查询消息记录
query_start_timestamp = start_time.timestamp()
- with get_db_session() as session:
+ with get_db_session(auto_commit=False) as session:
statement = select(Messages).where(col(Messages.timestamp) >= start_time)
messages = session.exec(statement).all()
for message in messages:
diff --git a/src/common/data_models/person_info_data_model.py b/src/common/data_models/person_info_data_model.py
index 4cbb62d8..1b239356 100644
--- a/src/common/data_models/person_info_data_model.py
+++ b/src/common/data_models/person_info_data_model.py
@@ -1,6 +1,6 @@
-from dataclasses import dataclass
+from dataclasses import asdict, dataclass
from datetime import datetime
-from typing import Optional, List
+from typing import Any, List, Mapping, Optional, Sequence
import json
@@ -15,6 +15,76 @@ class GroupCardnameInfo:
group_cardname: str
+def _normalize_group_cardname_item(raw_item: Mapping[str, Any]) -> Optional[GroupCardnameInfo]:
+ """将单条群名片数据规范化为统一结构。
+
+ Args:
+ raw_item: 原始群名片字典,必须包含 `group_id` 和 `group_cardname`。
+
+ Returns:
+ Optional[GroupCardnameInfo]: 规范化后的群名片信息;若数据不完整则返回 ``None``。
+ """
+ group_id = str(raw_item.get("group_id") or "").strip()
+ group_cardname = str(raw_item.get("group_cardname") or "").strip()
+ if not group_id or not group_cardname:
+ return None
+ return GroupCardnameInfo(group_id=group_id, group_cardname=group_cardname)
+
+
+def parse_group_cardname_json(group_cardname_json: Optional[str]) -> Optional[List[GroupCardnameInfo]]:
+ """解析数据库中的群名片 JSON 字段。
+
+ Args:
+ group_cardname_json: 数据库存储的群名片 JSON 字符串。
+
+ Returns:
+ Optional[List[GroupCardnameInfo]]: 解析并规范化后的群名片列表;若字段为空或无有效项则返回 ``None``。
+
+ Raises:
+ json.JSONDecodeError: 当 JSON 文本格式非法时抛出。
+ TypeError: 当输入值类型不符合 `json.loads()` 要求时抛出。
+ """
+ if not group_cardname_json:
+ return None
+
+ raw_items = json.loads(group_cardname_json)
+ if not isinstance(raw_items, list):
+ return None
+
+ normalized_items: List[GroupCardnameInfo] = []
+ for raw_item in raw_items:
+ if not isinstance(raw_item, Mapping):
+ continue
+ if normalized_item := _normalize_group_cardname_item(raw_item):
+ normalized_items.append(normalized_item)
+
+ return normalized_items or None
+
+
+def dump_group_cardname_records(
+ group_cardname_records: Optional[Sequence[GroupCardnameInfo | Mapping[str, Any]]],
+) -> str:
+ """将群名片列表序列化为数据库使用的标准 JSON 字符串。
+
+ Args:
+ group_cardname_records: 待序列化的群名片列表,支持 `GroupCardnameInfo`
+ 对象和包含 `group_id` / `group_cardname` 的字典。
+
+ Returns:
+ str: 统一使用 `group_cardname` 键名的 JSON 字符串。
+ """
+ normalized_items: List[GroupCardnameInfo] = []
+ for raw_item in group_cardname_records or []:
+ if isinstance(raw_item, GroupCardnameInfo):
+ normalized_items.append(raw_item)
+ continue
+ if isinstance(raw_item, Mapping):
+ if normalized_item := _normalize_group_cardname_item(raw_item):
+ normalized_items.append(normalized_item)
+
+ return json.dumps([asdict(item) for item in normalized_items], ensure_ascii=False)
+
+
class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]):
def __init__(
self,
@@ -58,9 +128,16 @@ class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]):
"""最后一次被认识的时间"""
@classmethod
- def from_db_instance(cls, db_record: "PersonInfo"):
- nickname_json = json.loads(db_record.group_cardname) if db_record.group_cardname else None
- group_cardname_list = [GroupCardnameInfo(**item) for item in nickname_json] if nickname_json else None
+ def from_db_instance(cls, db_record: "PersonInfo") -> "MaiPersonInfo":
+ """从数据库记录构造人物信息数据模型。
+
+ Args:
+ db_record: 数据库中的人物信息记录。
+
+ Returns:
+ MaiPersonInfo: 转换后的数据模型对象。
+ """
+ group_cardname_list = parse_group_cardname_json(db_record.group_cardname)
memory_points = json.loads(db_record.memory_points) if db_record.memory_points else None
return cls(
is_known=db_record.is_known,
@@ -78,9 +155,12 @@ class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]):
)
def to_db_instance(self) -> "PersonInfo":
- group_cardname = (
- json.dumps([gc.__dict__ for gc in self.group_cardname_list]) if self.group_cardname_list else None
- )
+ """将当前数据模型转换为数据库记录对象。
+
+ Returns:
+ PersonInfo: 可直接写入数据库的模型实例。
+ """
+ group_cardname = dump_group_cardname_records(self.group_cardname_list)
return PersonInfo(
is_known=self.is_known,
person_id=self.person_id,
diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py
index a0993a77..5b274c43 100644
--- a/src/common/database/database_model.py
+++ b/src/common/database/database_model.py
@@ -1,8 +1,9 @@
-from typing import Optional
-from sqlalchemy import Column, Float, Enum as SQLEnum, DateTime
-from sqlmodel import SQLModel, Field, LargeBinary
-from enum import Enum
from datetime import datetime
+from enum import Enum
+from typing import Optional
+
+from sqlalchemy import Column, DateTime, Enum as SQLEnum, Float
+from sqlmodel import Field, LargeBinary, SQLModel
class ModelUser(str, Enum):
@@ -172,8 +173,8 @@ class Expression(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True) # 自增主键
- situation: str = Field(index=True, max_length=255, primary_key=True) # 情景
- style: str = Field(index=True, max_length=255, primary_key=True) # 风格
+ situation: str = Field(index=True, max_length=255) # 情景
+ style: str = Field(index=True, max_length=255) # 风格
# context: str # 上下文
# up_content: str
@@ -200,7 +201,7 @@ class Jargon(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True) # 自增主键
- content: str = Field(index=True, max_length=255, primary_key=True) # 黑话内容
+ content: str = Field(index=True, max_length=255) # 黑话内容
raw_content: Optional[str] = Field(default=None, nullable=True) # 原始内容,未处理的黑话内容,为List[str]
meaning: str # 黑话含义
diff --git a/src/common/logger_color_and_mapping.py b/src/common/logger_color_and_mapping.py
index 1aabafbc..863c9c1e 100644
--- a/src/common/logger_color_and_mapping.py
+++ b/src/common/logger_color_and_mapping.py
@@ -1,9 +1,8 @@
# 定义模块颜色映射
-from typing import Optional, Tuple, Dict
-
import itertools
import os
import sys
+from typing import Dict, Optional, Tuple
MODULE_COLORS: Dict[str, Tuple[str, Optional[str], bool]] = {
@@ -54,15 +53,19 @@ MODULE_COLORS: Dict[str, Tuple[str, Optional[str], bool]] = {
"component_registry": ("#ffaf00", None, False),
"plugin_runtime.integration": ("#d75f00", None, False),
"plugin_runtime.host.supervisor": ("#ff5f00", None, False),
+ "plugin_runtime.host.runner_manager": ("#ff5f00", None, False),
"plugin_runtime.host.rpc_server": ("#ff8700", None, False),
"plugin_runtime.host.component_registry": ("#ffaf00", None, False),
"plugin_runtime.host.capability_service": ("#ffd700", None, False),
"plugin_runtime.host.event_dispatcher": ("#87d700", None, False),
- "plugin_runtime.host.workflow_executor": ("#5fd7af", None, False),
+ "plugin_runtime.host.hook_dispatcher": ("#5fd7af", None, False),
+ "plugin_runtime.host.message_gateway": ("#5fd7d7", None, False),
+ "plugin_runtime.host.message_utils": ("#5faf87", None, False),
"plugin_runtime.runner.main": ("#d787ff", None, False),
"plugin_runtime.runner.rpc_client": ("#8787ff", None, False),
"plugin_runtime.runner.manifest_validator": ("#5fafff", None, False),
"plugin_runtime.runner.plugin_loader": ("#00afaf", None, False),
+ "plugin.maibot-team.napcat-adapter": ("#00af87", None, False),
"webui": ("#5f87ff", None, False),
"webui.app": ("#5f87d7", None, False),
"webui.api": ("#5fafff", None, False),
@@ -157,15 +160,20 @@ MODULE_ALIASES = {
"chat_history_summarizer": "聊天概括器",
"plugin_runtime.integration": "IPC插件系统",
"plugin_runtime.host.supervisor": "插件监督器",
+ "plugin_runtime.host.runner_manager": "插件监督器",
"plugin_runtime.host.rpc_server": "插件RPC服务",
"plugin_runtime.host.component_registry": "插件组件注册",
"plugin_runtime.host.capability_service": "插件能力服务",
"plugin_runtime.host.event_dispatcher": "插件事件分发",
+ "plugin_runtime.host.hook_dispatcher": "插件Hook分发",
+ "plugin_runtime.host.message_gateway": "插件消息网关",
+ "plugin_runtime.host.message_utils": "插件消息工具",
"plugin_runtime.host.workflow_executor": "插件工作流",
"plugin_runtime.runner.main": "插件运行器",
"plugin_runtime.runner.rpc_client": "插件RPC客户端",
"plugin_runtime.runner.manifest_validator": "插件清单校验",
"plugin_runtime.runner.plugin_loader": "插件加载器",
+ "plugin.maibot-team.napcat-adapter": "NapCat内置适配器",
"webui": "WebUI",
"webui.app": "WebUI应用",
"webui.api": "WebUI接口",
diff --git a/src/common/message_server/server.py b/src/common/message_server/server.py
index 77a931e5..e75da4e7 100644
--- a/src/common/message_server/server.py
+++ b/src/common/message_server/server.py
@@ -21,7 +21,7 @@ class Server:
self._server: Optional[UvicornServer] = None
self.set_address(host, port)
- def register_router(self, router: APIRouter, prefix: str = ""):
+ def register_router(self, router: APIRouter, prefix: str = ""):
"""注册路由
APIRouter 用于对相关的路由端点进行分组和模块化管理:
diff --git a/src/common/utils/utils_session.py b/src/common/utils/utils_session.py
index a383f5a2..1b6d8f72 100644
--- a/src/common/utils/utils_session.py
+++ b/src/common/utils/utils_session.py
@@ -5,13 +5,22 @@ import hashlib
class SessionUtils:
@staticmethod
- def calculate_session_id(platform: str, *, user_id: Optional[str] = None, group_id: Optional[str] = None) -> str:
+ def calculate_session_id(
+ platform: str,
+ *,
+ user_id: Optional[str] = None,
+ group_id: Optional[str] = None,
+ account_id: Optional[str] = None,
+ scope: Optional[str] = None,
+ ) -> str:
"""计算session_id
Args:
platform: 平台名称
user_id: 用户ID(如果是私聊)
group_id: 群ID(如果是群聊)
+ account_id: 当前平台账号 ID,可选
+ scope: 当前路由作用域,可选
Returns:
str: 计算得到的会话ID
Raises:
@@ -19,8 +28,15 @@ class SessionUtils:
"""
if not user_id and not group_id:
raise ValueError("UserID 或 GroupID 必须提供其一")
+
+ route_components = []
+ if account_id:
+ route_components.append(f"account:{account_id}")
+ if scope:
+ route_components.append(f"scope:{scope}")
+
if group_id:
- components = [platform, group_id]
+ components = [platform, *route_components, group_id]
else:
- components = [platform, user_id, "private"]
+ components = [platform, *route_components, user_id, "private"]
return hashlib.md5("_".join(components).encode()).hexdigest()
diff --git a/src/config/config.py b/src/config/config.py
index ff5941bf..bee81efb 100644
--- a/src/config/config.py
+++ b/src/config/config.py
@@ -4,6 +4,7 @@ from typing import Any, Callable, Mapping, Sequence, TypeVar
import asyncio
import copy
+import inspect
import sys
import tomlkit
@@ -61,6 +62,7 @@ MODEL_CONFIG_VERSION: str = "1.12.0"
logger = get_logger("config")
T = TypeVar("T", bound="ConfigBase")
+ConfigReloadCallback = Callable[[Sequence[str]], object] | Callable[[], object]
class Config(ConfigBase):
@@ -190,7 +192,7 @@ class ConfigManager:
self.global_config: Config | None = None
self.model_config: ModelConfig | None = None
self._reload_lock: asyncio.Lock = asyncio.Lock()
- self._reload_callbacks: list[Callable[[], object]] = []
+ self._reload_callbacks: list[ConfigReloadCallback] = []
self._file_watcher: FileWatcher | None = None
self._file_watcher_subscription_id: str | None = None
self._hot_reload_min_interval_s: float = 1.0
@@ -226,16 +228,125 @@ class ConfigManager:
raise RuntimeError(t("config.model_not_initialized"))
return self.model_config
- def register_reload_callback(self, callback: Callable[[], object]) -> None:
+ def register_reload_callback(self, callback: ConfigReloadCallback) -> None:
+ """注册配置热重载回调。
+
+ Args:
+ callback: 配置热重载回调。允许无参回调,也允许接收
+ ``Sequence[str]`` 类型的变更范围列表。
+ """
+
self._reload_callbacks.append(callback)
- def unregister_reload_callback(self, callback: Callable[[], object]) -> None:
+ def unregister_reload_callback(self, callback: ConfigReloadCallback) -> None:
+ """注销配置热重载回调。
+
+ Args:
+ callback: 先前注册过的回调对象。
+ """
+
try:
self._reload_callbacks.remove(callback)
except ValueError:
return
- async def reload_config(self) -> bool:
+ @staticmethod
+ def _normalize_changed_scopes(changed_scopes: Sequence[str] | None) -> tuple[str, ...]:
+ """规范化配置变更范围列表。
+
+ Args:
+ changed_scopes: 原始配置变更范围。
+
+ Returns:
+ tuple[str, ...]: 去重后的配置变更范围元组。
+ """
+
+ if not changed_scopes:
+ return ("bot", "model")
+
+ normalized_scopes: list[str] = []
+ for scope in changed_scopes:
+ normalized_scope = str(scope or "").strip().lower()
+ if normalized_scope not in {"bot", "model"}:
+ continue
+ if normalized_scope not in normalized_scopes:
+ normalized_scopes.append(normalized_scope)
+ return tuple(normalized_scopes)
+
+ @staticmethod
+ def _resolve_changed_scopes(changes: Sequence[FileChange]) -> tuple[str, ...]:
+ """根据文件变更列表推断配置变更范围。
+
+ Args:
+ changes: 文件监听器返回的变更列表。
+
+ Returns:
+ tuple[str, ...]: 命中的配置变更范围元组。
+ """
+
+ changed_scopes: list[str] = []
+ for change in changes:
+ file_name = change.path.name
+ if file_name == "bot_config.toml" and "bot" not in changed_scopes:
+ changed_scopes.append("bot")
+ if file_name == "model_config.toml" and "model" not in changed_scopes:
+ changed_scopes.append("model")
+ return tuple(changed_scopes)
+
+ @staticmethod
+ def _callback_accepts_scopes(callback: ConfigReloadCallback) -> bool:
+ """判断回调是否接收配置变更范围参数。
+
+ Args:
+ callback: 待检测的回调对象。
+
+ Returns:
+ bool: 若回调可接收一个位置参数或可变位置参数,则返回 ``True``。
+ """
+
+ try:
+ parameters = inspect.signature(callback).parameters.values()
+ except (TypeError, ValueError):
+ return False
+
+ positional_params = {
+ inspect.Parameter.POSITIONAL_ONLY,
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ }
+ for parameter in parameters:
+ if parameter.kind == inspect.Parameter.VAR_POSITIONAL:
+ return True
+ if parameter.kind in positional_params:
+ return True
+ return False
+
+ async def _invoke_reload_callback(
+ self,
+ callback: ConfigReloadCallback,
+ changed_scopes: Sequence[str],
+ ) -> None:
+ """执行单个配置热重载回调。
+
+ Args:
+ callback: 要执行的回调对象。
+ changed_scopes: 本次热重载命中的配置范围。
+ """
+
+ result = callback(changed_scopes) if self._callback_accepts_scopes(callback) else callback()
+ if asyncio.iscoroutine(result):
+ await result
+
+ async def reload_config(self, changed_scopes: Sequence[str] | None = None) -> bool:
+ """重新加载主配置和模型配置。
+
+ Args:
+ changed_scopes: 本次触发热重载的配置范围。
+
+ Returns:
+ bool: 是否重载成功。
+ """
+
+ normalized_scopes = self._normalize_changed_scopes(changed_scopes)
async with self._reload_lock:
try:
global_config_new, global_updated = load_config_from_file(
@@ -265,9 +376,7 @@ class ConfigManager:
for callback in list(self._reload_callbacks):
try:
- result = callback()
- if asyncio.iscoroutine(result):
- await result
+ await self._invoke_reload_callback(callback, normalized_scopes)
except Exception as exc:
logger.warning(t("config.reload_callback_failed", error=exc))
return True
@@ -312,6 +421,12 @@ class ConfigManager:
self._file_watcher = None
async def _handle_file_changes(self, changes: Sequence[FileChange]) -> None:
+ """处理主配置与模型配置文件变更。
+
+ Args:
+ changes: 当前批次收集到的文件变更列表。
+ """
+
if not changes:
return
now_monotonic = asyncio.get_running_loop().time()
@@ -321,7 +436,11 @@ class ConfigManager:
self._last_hot_reload_monotonic = now_monotonic
logger.info(t("config.file_change_detected"))
try:
- await asyncio.wait_for(self.reload_config(), timeout=self._hot_reload_timeout_s)
+ changed_scopes = self._resolve_changed_scopes(changes)
+ await asyncio.wait_for(
+ self.reload_config(changed_scopes=changed_scopes),
+ timeout=self._hot_reload_timeout_s,
+ )
except asyncio.TimeoutError:
logger.error(t("config.reload_timeout", timeout_seconds=self._hot_reload_timeout_s))
diff --git a/src/config/official_configs.py b/src/config/official_configs.py
index 20c2c2c8..fde3f800 100644
--- a/src/config/official_configs.py
+++ b/src/config/official_configs.py
@@ -1633,24 +1633,6 @@ class PluginRuntimeConfig(ConfigBase):
)
"""启用插件系统"""
- builtin_plugin_dir: str = Field(
- default="src/plugins/built_in",
- json_schema_extra={
- "x-widget": "input",
- "x-icon": "folder",
- },
- )
- """内置插件目录(相对于项目根目录)"""
-
- thirdparty_plugin_dir: str = Field(
- default="plugins",
- json_schema_extra={
- "x-widget": "input",
- "x-icon": "folder-open",
- },
- )
- """第三方插件目录(相对于项目根目录)"""
-
health_check_interval_sec: float = Field(
default=30.0,
json_schema_extra={
@@ -1678,14 +1660,14 @@ class PluginRuntimeConfig(ConfigBase):
)
"""等待 Runner 子进程启动并注册的超时时间(秒)"""
- workflow_blocking_timeout_sec: float = Field(
- default=120.0,
+ hook_blocking_timeout_sec: float = Field(
+ default=30,
json_schema_extra={
"x-widget": "number",
"x-icon": "timer",
},
)
- """Workflow 阻塞步骤的全局超时上限(秒)"""
+ """Hook 阻塞步骤的全局超时上限(秒)"""
ipc_socket_path: str = Field(
default="",
@@ -1694,4 +1676,7 @@ class PluginRuntimeConfig(ConfigBase):
"x-icon": "link",
},
)
- """_wrap_\n 自定义 IPC Socket 路径(仅 Linux/macOS 生效)\n 留空则自动生成临时路径"""
+ """
+ 自定义 IPC Socket 路径(仅 Linux/macOS 生效)
+ 留空则自动生成临时路径
+ """
diff --git a/src/core/component_registry.py b/src/core/component_registry.py
deleted file mode 100644
index bb58682a..00000000
--- a/src/core/component_registry.py
+++ /dev/null
@@ -1,239 +0,0 @@
-"""
-核心组件注册表
-
-面向最终架构的组件管理:
-- Action:注册 ActionInfo + 执行器(本地 callable 或 IPC 路由)
-- Command:注册正则模式 + 执行器
-- Tool:注册工具定义 + 执行器
-
-不依赖任何插件基类,组件执行器是纯 async callable。
-"""
-
-import re
-from typing import Any, Awaitable, Callable, Dict, Optional, Pattern, Tuple
-
-from src.common.logger import get_logger
-from src.core.types import (
- ActionInfo,
- CommandInfo,
- ComponentInfo,
- ComponentType,
- ToolInfo,
-)
-
-logger = get_logger("component_registry")
-
-# 执行器类型
-ActionExecutor = Callable[..., Awaitable[Any]]
-CommandExecutor = Callable[..., Awaitable[Tuple[bool, Optional[str], bool]]]
-ToolExecutor = Callable[..., Awaitable[Any]]
-
-
-class ComponentRegistry:
- """核心组件注册表
-
- 管理 action、command、tool 三类组件。
- 每个组件由「元信息 + 执行器」构成,执行器是 async callable,
- 不需要继承任何基类。
- """
-
- def __init__(self):
- # Action 注册
- self._actions: Dict[str, ActionInfo] = {}
- self._action_executors: Dict[str, ActionExecutor] = {}
- self._default_actions: Dict[str, ActionInfo] = {}
-
- # Command 注册
- self._commands: Dict[str, CommandInfo] = {}
- self._command_executors: Dict[str, CommandExecutor] = {}
- self._command_patterns: Dict[Pattern, str] = {}
-
- # Tool 注册
- self._tools: Dict[str, ToolInfo] = {}
- self._tool_executors: Dict[str, ToolExecutor] = {}
- self._llm_available_tools: Dict[str, ToolInfo] = {}
-
- # 插件配置(plugin_name -> config dict)
- self._plugin_configs: Dict[str, dict] = {}
-
- logger.info("核心组件注册表初始化完成")
-
- # ========== Action ==========
-
- def register_action(
- self,
- info: ActionInfo,
- executor: ActionExecutor,
- ) -> bool:
- """注册 action
-
- Args:
- info: action 元信息
- executor: 执行器,async callable
- """
- name = info.name
- if name in self._actions:
- logger.warning(f"Action {name} 已存在,跳过注册")
- return False
-
- self._actions[name] = info
- self._action_executors[name] = executor
-
- if info.enabled:
- self._default_actions[name] = info
-
- logger.debug(f"注册 Action: {name}")
- return True
-
- def get_action_info(self, name: str) -> Optional[ActionInfo]:
- return self._actions.get(name)
-
- def get_action_executor(self, name: str) -> Optional[ActionExecutor]:
- return self._action_executors.get(name)
-
- def get_default_actions(self) -> Dict[str, ActionInfo]:
- return self._default_actions.copy()
-
- def get_all_actions(self) -> Dict[str, ActionInfo]:
- return self._actions.copy()
-
- def remove_action(self, name: str) -> bool:
- if name not in self._actions:
- return False
- del self._actions[name]
- self._action_executors.pop(name, None)
- self._default_actions.pop(name, None)
- logger.debug(f"移除 Action: {name}")
- return True
-
- # ========== Command ==========
-
- def register_command(
- self,
- info: CommandInfo,
- executor: CommandExecutor,
- ) -> bool:
- """注册 command"""
- name = info.name
- if name in self._commands:
- logger.warning(f"Command {name} 已存在,跳过注册")
- return False
-
- self._commands[name] = info
- self._command_executors[name] = executor
-
- if info.enabled and info.command_pattern:
- pattern = re.compile(info.command_pattern, re.IGNORECASE | re.DOTALL)
- self._command_patterns[pattern] = name
-
- logger.debug(f"注册 Command: {name}")
- return True
-
- def find_command_by_text(self, text: str) -> Optional[Tuple[CommandExecutor, dict, CommandInfo]]:
- """根据文本查找匹配的命令
-
- Returns:
- (executor, matched_groups, command_info) 或 None
- """
- candidates = [p for p in self._command_patterns if p.match(text)]
- if not candidates:
- return None
- if len(candidates) > 1:
- logger.warning(f"文本 '{text[:50]}' 匹配到多个命令模式,使用第一个")
- pattern = candidates[0]
- name = self._command_patterns[pattern]
- return (
- self._command_executors[name],
- pattern.match(text).groupdict(), # type: ignore
- self._commands[name],
- )
-
- def remove_command(self, name: str) -> bool:
- if name not in self._commands:
- return False
- del self._commands[name]
- self._command_executors.pop(name, None)
- self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != name}
- logger.debug(f"移除 Command: {name}")
- return True
-
- # ========== Tool ==========
-
- def register_tool(
- self,
- info: ToolInfo,
- executor: ToolExecutor,
- ) -> bool:
- """注册 tool"""
- name = info.name
- if name in self._tools:
- logger.warning(f"Tool {name} 已存在,跳过注册")
- return False
-
- self._tools[name] = info
- self._tool_executors[name] = executor
-
- if info.enabled:
- self._llm_available_tools[name] = info
-
- logger.debug(f"注册 Tool: {name}")
- return True
-
- def get_tool_info(self, name: str) -> Optional[ToolInfo]:
- return self._tools.get(name)
-
- def get_tool_executor(self, name: str) -> Optional[ToolExecutor]:
- return self._tool_executors.get(name)
-
- def get_llm_available_tools(self) -> Dict[str, ToolInfo]:
- return self._llm_available_tools.copy()
-
- def get_all_tools(self) -> Dict[str, ToolInfo]:
- return self._tools.copy()
-
- def remove_tool(self, name: str) -> bool:
- if name not in self._tools:
- return False
- del self._tools[name]
- self._tool_executors.pop(name, None)
- self._llm_available_tools.pop(name, None)
- logger.debug(f"移除 Tool: {name}")
- return True
-
- # ========== 通用查询 ==========
-
- def get_component_info(self, name: str, component_type: ComponentType) -> Optional[ComponentInfo]:
- """获取组件元信息"""
- match component_type:
- case ComponentType.ACTION:
- return self._actions.get(name)
- case ComponentType.COMMAND:
- return self._commands.get(name)
- case ComponentType.TOOL:
- return self._tools.get(name)
- case _:
- return None
-
- def get_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
- """获取某类型的所有组件"""
- match component_type:
- case ComponentType.ACTION:
- return dict(self._actions)
- case ComponentType.COMMAND:
- return dict(self._commands)
- case ComponentType.TOOL:
- return dict(self._tools)
- case _:
- return {}
-
- # ========== 插件配置 ==========
-
- def set_plugin_config(self, plugin_name: str, config: dict) -> None:
- self._plugin_configs[plugin_name] = config
-
- def get_plugin_config(self, plugin_name: str) -> Optional[dict]:
- return self._plugin_configs.get(plugin_name)
-
-
-# 全局单例
-component_registry = ComponentRegistry()
diff --git a/src/bw_learner/expression_auto_check_task.py b/src/learners/expression_auto_check_task.py
similarity index 96%
rename from src/bw_learner/expression_auto_check_task.py
rename to src/learners/expression_auto_check_task.py
index d90eb4da..e5af1057 100644
--- a/src/bw_learner/expression_auto_check_task.py
+++ b/src/learners/expression_auto_check_task.py
@@ -3,19 +3,19 @@
功能:
1. 定期随机选取指定数量的表达方式
-2. 使用LLM进行评估
+2. 使用 LLM 进行评估
3. 通过评估的:rejected=0, checked=1
4. 未通过评估的:rejected=1, checked=1
"""
-from typing import List
import asyncio
import json
import random
+from typing import List
from sqlmodel import select
-from src.bw_learner.expression_review_store import get_review_state, set_review_state
+from src.learners.expression_review_store import get_review_state, set_review_state
from src.common.database.database import get_db_session
from src.common.database.database_model import Expression
from src.common.logger import get_logger
@@ -146,7 +146,8 @@ class ExpressionAutoCheckTask(AsyncTask):
选中的表达方式列表
"""
try:
- with get_db_session() as session:
+ # 这里只做查询,避免退出上下文时自动提交导致 ORM 实例过期。
+ with get_db_session(auto_commit=False) as session:
statement = select(Expression)
all_expressions = session.exec(statement).all()
diff --git a/src/bw_learner/expression_learner.py b/src/learners/expression_learner.py
similarity index 90%
rename from src/bw_learner/expression_learner.py
rename to src/learners/expression_learner.py
index 43e4ee7d..b82ae1fa 100644
--- a/src/bw_learner/expression_learner.py
+++ b/src/learners/expression_learner.py
@@ -329,7 +329,13 @@ class ExpressionLearner:
return filtered_expressions
# ====== DB 操作相关 ======
- async def _upsert_expression_to_db(self, situation: str, style: str):
+ async def _upsert_expression_to_db(self, situation: str, style: str) -> None:
+ """将表达方式写入数据库,存在时更新,不存在时新增。
+
+ Args:
+ situation: 表达方式对应的使用情景。
+ style: 表达方式风格。
+ """
expr, similarity = self._find_similar_expression(situation) or (None, 0)
if expr:
# 根据相似度决定是否使用 LLM 总结
@@ -340,7 +346,13 @@ class ExpressionLearner:
# 没有找到匹配的记录,创建新记录
self._create_expression(situation, style)
- def _create_expression(self, situation: str, style: str):
+ def _create_expression(self, situation: str, style: str) -> None:
+ """创建新的表达方式记录。
+
+ Args:
+ situation: 表达方式对应的使用情景。
+ style: 表达方式风格。
+ """
content_list = [situation]
try:
with get_db_session() as db:
@@ -353,6 +365,7 @@ class ExpressionLearner:
last_active_time=datetime.now(),
)
db.add(new_expr)
+ db.flush()
except Exception as e:
logger.error(f"创建表达方式失败: {e}")
@@ -448,25 +461,43 @@ class ExpressionLearner:
def _find_similar_expression(
self, situation: str, similarity_threshold: float = 0.75
) -> Optional[Tuple[MaiExpression, float]]:
- """在数据库中查找相似的表达方式"""
+ """在数据库中查找相似的表达方式。
+
+ Args:
+ situation: 当前待匹配的情景描述。
+ similarity_threshold: 认定为相似表达方式的最低相似度阈值。
+
+ Returns:
+ Optional[Tuple[MaiExpression, float]]: 若找到最相似的表达方式,则返回
+ ``(表达方式对象, 相似度)``;否则返回 ``None``。
+ """
try:
- with get_db_session() as session:
+ with get_db_session(auto_commit=False) as session:
statement = select(Expression).filter_by(session_id=self.session_id)
expressions = session.exec(statement).all()
- best_match: Optional[Expression] = None
- best_similarity = 0.0
+ best_match: Optional[MaiExpression] = None
+ best_similarity = 0.0
+
+ for db_expression in expressions:
+ expression = MaiExpression.from_db_instance(db_expression)
+ candidate_situations = [expression.situation, *expression.content]
+ for candidate_situation in candidate_situations:
+ normalized_candidate_situation = candidate_situation.strip()
+ if not normalized_candidate_situation:
+ continue
+ similarity = difflib.SequenceMatcher(
+ None,
+ situation,
+ normalized_candidate_situation,
+ ).ratio()
+ if similarity > similarity_threshold and similarity > best_similarity:
+ best_similarity = similarity
+ best_match = expression
- for expr in expressions:
- content_list = json.loads(expr.content_list)
- for situation in content_list:
- similarity = difflib.SequenceMatcher(None, situation, expr.situation).ratio()
- if similarity > similarity_threshold and similarity > best_similarity:
- best_similarity = similarity
- best_match = expr
if best_match:
- logger.debug(f"找到相似表达方式情景 [ID: {best_match.id}],相似度: {best_similarity:.2f}")
- return MaiExpression.from_db_instance(best_match), best_similarity
+ logger.debug(f"找到相似表达方式情景 [ID: {best_match.item_id}],相似度: {best_similarity:.2f}")
+ return best_match, best_similarity
except Exception as e:
logger.error(f"查找相似表达方式失败: {e}")
diff --git a/src/bw_learner/expression_review_store.py b/src/learners/expression_review_store.py
similarity index 100%
rename from src/bw_learner/expression_review_store.py
rename to src/learners/expression_review_store.py
diff --git a/src/bw_learner/expression_selector.py b/src/learners/expression_selector.py
similarity index 99%
rename from src/bw_learner/expression_selector.py
rename to src/learners/expression_selector.py
index c6cfe469..c96e84cf 100644
--- a/src/bw_learner/expression_selector.py
+++ b/src/learners/expression_selector.py
@@ -9,7 +9,7 @@ from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.prompt.prompt_manager import prompt_manager
-from src.bw_learner.learner_utils_old import weighted_sample
+from src.learners.learner_utils_old import weighted_sample
from src.chat.utils.common_utils import TempMethodsExpression
logger = get_logger("expression_selector")
diff --git a/src/bw_learner/expression_utils.py b/src/learners/expression_utils.py
similarity index 100%
rename from src/bw_learner/expression_utils.py
rename to src/learners/expression_utils.py
diff --git a/src/bw_learner/jargon_explainer.py b/src/learners/jargon_explainer.py
similarity index 100%
rename from src/bw_learner/jargon_explainer.py
rename to src/learners/jargon_explainer.py
diff --git a/src/bw_learner/jargon_explainer_old.py b/src/learners/jargon_explainer_old.py
similarity index 99%
rename from src/bw_learner/jargon_explainer_old.py
rename to src/learners/jargon_explainer_old.py
index 94031b4a..0cfafa82 100644
--- a/src/bw_learner/jargon_explainer_old.py
+++ b/src/learners/jargon_explainer_old.py
@@ -7,8 +7,8 @@ from src.common.database.database_model import Jargon
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.prompt.prompt_manager import prompt_manager
-from src.bw_learner.jargon_explainer import search_jargon
-from src.bw_learner.learner_utils_old import (
+from src.learners.jargon_miner_old import search_jargon
+from src.learners.learner_utils_old import (
is_bot_message,
contains_bot_self_name,
parse_chat_id_list,
diff --git a/src/bw_learner/jargon_miner.py b/src/learners/jargon_miner.py
similarity index 93%
rename from src/bw_learner/jargon_miner.py
rename to src/learners/jargon_miner.py
index 2fbf8a2e..32926894 100644
--- a/src/bw_learner/jargon_miner.py
+++ b/src/learners/jargon_miner.py
@@ -1,17 +1,18 @@
from collections import OrderedDict
-from json_repair import repair_json
-from sqlmodel import select
-from typing import List, Optional, Dict, Callable, TypedDict, Set
+from typing import Callable, Dict, List, Optional, Set, TypedDict
import asyncio
import json
import random
-from src.common.logger import get_logger
+from json_repair import repair_json
+from sqlmodel import select
+
+from src.common.data_models.jargon_data_model import MaiJargon
from src.common.database.database import get_db_session
from src.common.database.database_model import Jargon
-from src.common.data_models.jargon_data_model import MaiJargon
-from src.config.config import model_config, global_config
+from src.common.logger import get_logger
+from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.prompt.prompt_manager import prompt_manager
@@ -198,7 +199,7 @@ class JargonMiner:
async def process_extracted_entries(
self, entries: List[JargonEntry], person_name_filter: Optional[Callable[[str], bool]] = None
- ):
+ ) -> None:
"""
处理已提取的黑话条目(从 expression_learner 路由过来的)
@@ -229,7 +230,7 @@ class JargonMiner:
content = entry["content"]
raw_content_set = entry["raw_content"]
try:
- with get_db_session() as session:
+ with get_db_session(auto_commit=False) as session:
jargon_items = session.exec(select(Jargon).filter_by(content=content)).all()
except Exception as e:
logger.error(f"查询黑话 '{content}' 失败: {e}")
@@ -273,11 +274,12 @@ class JargonMiner:
try:
with get_db_session() as session:
session.add(new_jargon)
+ session.flush()
+ saved += 1
+ self._add_to_cache(content)
except Exception as e:
logger.error(f"保存新黑话 '{content}' 失败: {e}")
continue
- finally:
- self._add_to_cache(content)
# 固定输出提取的jargon结果,格式化为可读形式(只要有提取结果就输出)
if uniq_entries:
# 收集所有提取的jargon内容
@@ -304,7 +306,13 @@ class JargonMiner:
removed_content, _ = self.cache.popitem(last=False)
logger.debug(f"缓存已满,移除最旧的黑话: {removed_content}")
- def _update_jargon(self, db_jargon: Jargon, raw_content_set: Set[str]):
+ def _update_jargon(self, db_jargon: Jargon, raw_content_set: Set[str]) -> None:
+ """更新已有黑话记录并写回数据库。
+
+ Args:
+ db_jargon: 已命中的黑话 ORM 对象。
+ raw_content_set: 本次新增的原始上下文集合。
+ """
db_jargon.count += 1
existing_raw_content: List[str] = []
if db_jargon.raw_content:
@@ -326,7 +334,17 @@ class JargonMiner:
try:
with get_db_session() as session:
- session.add(db_jargon)
+ if db_jargon.id is None:
+ raise ValueError("黑话记录缺少 id,无法更新数据库")
+ statement = select(Jargon).filter_by(id=db_jargon.id).limit(1)
+ if persisted_jargon := session.exec(statement).first():
+ persisted_jargon.count = db_jargon.count
+ persisted_jargon.raw_content = db_jargon.raw_content
+ persisted_jargon.session_id_dict = db_jargon.session_id_dict
+ persisted_jargon.is_global = db_jargon.is_global
+ session.add(persisted_jargon)
+ else:
+ logger.warning(f"黑话 ID {db_jargon.id} 在数据库中未找到,无法更新")
except Exception as e:
logger.error(f"更新黑话 '{db_jargon.content}' 失败: {e}")
diff --git a/src/bw_learner/learner_utils.py b/src/learners/learner_utils.py
similarity index 100%
rename from src/bw_learner/learner_utils.py
rename to src/learners/learner_utils.py
diff --git a/src/bw_learner/learner_utils_old.py b/src/learners/learner_utils_old.py
similarity index 100%
rename from src/bw_learner/learner_utils_old.py
rename to src/learners/learner_utils_old.py
diff --git a/src/main.py b/src/main.py
index 1bfa91b0..6e568df5 100644
--- a/src/main.py
+++ b/src/main.py
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING
import asyncio
import time
-from src.bw_learner.expression_auto_check_task import ExpressionAutoCheckTask
+from src.learners.expression_auto_check_task import ExpressionAutoCheckTask
from src.chat.emoji_system.emoji_manager import emoji_manager
from src.chat.message_receive.bot import chat_bot
from src.chat.message_receive.chat_manager import chat_manager
diff --git a/src/memory_system/memory_retrieval.py b/src/memory_system/memory_retrieval.py
index 49e5ca02..2eadd05a 100644
--- a/src/memory_system/memory_retrieval.py
+++ b/src/memory_system/memory_retrieval.py
@@ -14,7 +14,7 @@ from src.common.database.database_model import ThinkingQuestion
from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
-from src.bw_learner.jargon_explainer_old import retrieve_concepts_with_jargon
+from src.learners.jargon_explainer_old import retrieve_concepts_with_jargon
logger = get_logger("memory_retrieval")
diff --git a/src/memory_system/retrieval_tools/query_words.py b/src/memory_system/retrieval_tools/query_words.py
index 66fb3c46..ee28b934 100644
--- a/src/memory_system/retrieval_tools/query_words.py
+++ b/src/memory_system/retrieval_tools/query_words.py
@@ -4,7 +4,7 @@
"""
from src.common.logger import get_logger
-from src.bw_learner.jargon_explainer_old import retrieve_concepts_with_jargon
+from src.learners.jargon_explainer_old import retrieve_concepts_with_jargon
from .tool_registry import register_memory_retrieval_tool
logger = get_logger("memory_retrieval_tools")
diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py
index 960de4aa..15ef0049 100644
--- a/src/person_info/person_info.py
+++ b/src/person_info/person_info.py
@@ -1,24 +1,24 @@
-import hashlib
+from datetime import datetime
+from typing import Dict, Optional, Union
+
import asyncio
+import hashlib
import json
-import time
-import random
import math
+import random
+import time
from json_repair import repair_json
-from typing import Union, Optional, Dict, List
-from datetime import datetime
-from sqlalchemy import or_
from sqlmodel import col, select
-from src.common.logger import get_logger
+from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
+from src.common.data_models.person_info_data_model import dump_group_cardname_records, parse_group_cardname_json
from src.common.database.database import get_db_session
from src.common.database.database_model import PersonInfo
-from src.llm_models.utils_model import LLMRequest
+from src.common.logger import get_logger
from src.config.config import global_config, model_config
-from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
-from src.services.memory_service import memory_service
+from src.llm_models.utils_model import LLMRequest
logger = get_logger("person_info")
@@ -28,6 +28,32 @@ relation_selection_model = LLMRequest(
)
+def _to_group_cardname_records(group_cardname_json: Optional[str]) -> list[dict[str, str]]:
+ """将数据库中的群名片 JSON 转换为 `Person` 内部使用的结构。
+
+ Args:
+ group_cardname_json: 数据库存储的群名片 JSON 字符串。
+
+ Returns:
+ list[dict[str, str]]: 统一使用 `group_cardname` 键名的群名片列表。
+
+ Raises:
+ json.JSONDecodeError: 当 JSON 文本格式非法时抛出。
+ TypeError: 当输入值类型不符合 `json.loads()` 要求时抛出。
+ """
+ group_cardname_list = parse_group_cardname_json(group_cardname_json)
+ if not group_cardname_list:
+ return []
+
+ return [
+ {
+ "group_id": group_cardname.group_id,
+ "group_cardname": group_cardname.group_cardname,
+ }
+ for group_cardname in group_cardname_list
+ ]
+
+
def get_person_id(platform: str, user_id: Union[int, str]) -> str:
"""获取唯一id"""
if "-" in platform:
@@ -39,60 +65,16 @@ def get_person_id(platform: str, user_id: Union[int, str]) -> str:
def get_person_id_by_person_name(person_name: str) -> str:
"""根据用户名获取用户ID"""
- clean_name = str(person_name or "").strip()
- if not clean_name:
- return ""
try:
with get_db_session() as session:
- statement = (
- select(PersonInfo)
- .where(
- or_(
- col(PersonInfo.person_name) == clean_name,
- col(PersonInfo.user_nickname) == clean_name,
- )
- )
- .limit(1)
- )
- record = session.exec(statement).first()
- if record and record.person_id:
- return record.person_id
-
- statement = (
- select(PersonInfo)
- .where(PersonInfo.group_cardname.contains(clean_name))
- .limit(1)
- )
+ statement = select(PersonInfo).where(col(PersonInfo.person_name) == person_name).limit(1)
record = session.exec(statement).first()
return record.person_id if record else ""
except Exception as e:
- logger.error(f"根据用户名 {clean_name} 获取用户ID时出错: {e}")
+ logger.error(f"根据用户名 {person_name} 获取用户ID时出错: {e}")
return ""
-def resolve_person_id_for_memory(
- *,
- person_name: str = "",
- platform: str = "",
- user_id: Optional[Union[int, str]] = None,
-) -> str:
- """统一人物记忆链路中的 person_id 解析。
-
- 优先使用已知的人物名称/别名,其次退回到平台 + user_id 的稳定 ID。
- """
- name_token = str(person_name or "").strip()
- if name_token:
- resolved = get_person_id_by_person_name(name_token)
- if resolved:
- return resolved
-
- platform_token = str(platform or "").strip()
- user_token = str(user_id or "").strip()
- if platform_token and user_token:
- return get_person_id(platform_token, user_token)
- return ""
-
-
def is_person_known(
person_id: Optional[str] = None,
user_id: Optional[str] = None,
@@ -277,7 +259,7 @@ class Person:
person.know_since = time.time()
person.last_know = time.time()
person.memory_points = []
- person.group_nick_name = [] # 初始化群昵称列表
+ person.group_cardname_list = [] # 初始化群名片列表
# 如果是群聊,添加群昵称
if group_id and group_nick_name:
@@ -315,7 +297,7 @@ class Person:
self.platform = platform
self.nickname = global_config.bot.nickname
self.person_name = global_config.bot.nickname
- self.group_nick_name: list[dict[str, str]] = []
+ self.group_cardname_list: list[dict[str, str]] = []
return
self.user_id = ""
@@ -354,7 +336,7 @@ class Person:
self.know_since = None
self.last_know: Optional[float] = None
self.memory_points = []
- self.group_nick_name: list[dict[str, str]] = [] # 群昵称列表,存储 {"group_id": str, "group_nick_name": str}
+ self.group_cardname_list: list[dict[str, str]] = [] # 群名片列表,存储 {"group_id": str, "group_cardname": str}
# 从数据库加载数据
self.load_from_database()
@@ -454,16 +436,16 @@ class Person:
return
# 检查是否已存在该群号的记录
- for item in self.group_nick_name:
+ for item in self.group_cardname_list:
if item.get("group_id") == group_id:
# 更新现有记录
- item["group_nick_name"] = group_nick_name
+ item["group_cardname"] = group_nick_name
self.sync_to_database()
logger.debug(f"更新用户 {self.person_id} 在群 {group_id} 的群昵称为 {group_nick_name}")
return
# 添加新记录
- self.group_nick_name.append({"group_id": group_id, "group_nick_name": group_nick_name})
+ self.group_cardname_list.append({"group_id": group_id, "group_cardname": group_nick_name})
self.sync_to_database()
logger.debug(f"添加用户 {self.person_id} 在群 {group_id} 的群昵称 {group_nick_name}")
@@ -498,20 +480,15 @@ class Person:
else:
self.memory_points = []
- # 处理group_nick_name字段(JSON格式的列表)
+ # 处理 group_cardname 字段(JSON 格式的列表)
if record.group_cardname:
try:
- loaded_group_nick_names = json.loads(record.group_cardname)
- # 确保是列表格式
- if isinstance(loaded_group_nick_names, list):
- self.group_nick_name = loaded_group_nick_names
- else:
- self.group_nick_name = []
+ self.group_cardname_list = _to_group_cardname_records(record.group_cardname)
except (json.JSONDecodeError, TypeError):
logger.warning(f"解析用户 {self.person_id} 的group_cardname字段失败,使用默认值")
- self.group_nick_name = []
+ self.group_cardname_list = []
else:
- self.group_nick_name = []
+ self.group_cardname_list = []
logger.debug(f"已从数据库加载用户 {self.person_id} 的信息")
else:
@@ -532,11 +509,7 @@ class Person:
if self.memory_points
else json.dumps([], ensure_ascii=False)
)
- group_nickname_value = (
- json.dumps(self.group_nick_name, ensure_ascii=False)
- if self.group_nick_name
- else json.dumps([], ensure_ascii=False)
- )
+ group_cardname_value = dump_group_cardname_records(self.group_cardname_list)
first_known_time = datetime.fromtimestamp(self.know_since) if self.know_since else None
last_known_time = datetime.fromtimestamp(self.last_know) if self.last_know else None
@@ -556,7 +529,7 @@ class Person:
record.first_known_time = first_known_time
record.last_known_time = last_known_time
record.memory_points = memory_points_value
- record.group_nickname = group_nickname_value
+ record.group_cardname = group_cardname_value
session.add(record)
logger.debug(f"已同步用户 {self.person_id} 的信息到数据库")
else:
@@ -572,7 +545,7 @@ class Person:
first_known_time=first_known_time,
last_known_time=last_known_time,
memory_points=memory_points_value,
- group_nickname=group_nickname_value,
+ group_cardname=group_cardname_value,
)
session.add(record)
logger.debug(f"已创建用户 {self.person_id} 的信息到数据库")
@@ -583,79 +556,79 @@ class Person:
async def build_relationship(self, chat_content: str = "", info_type=""):
if not self.is_known:
return ""
+ # 构建points文本
+
nickname_str = ""
if self.person_name != self.nickname:
nickname_str = f"(ta在{self.platform}上的昵称是{self.nickname})"
- async def _select_traits(query_text: str, traits: List[str], limit: int = 3) -> List[str]:
- clean_traits = [trait.strip() for trait in traits if isinstance(trait, str) and trait.strip()]
- if not clean_traits:
- return []
- if not query_text:
- return clean_traits[:limit]
+ relation_info = ""
- numbered_traits = "\n".join(f"{index}. {trait}" for index, trait in enumerate(clean_traits, start=1))
- prompt = f"""当前关注内容:
-{query_text}
+ points_text = ""
+ category_list = self.get_all_category()
-候选人物信息:
-{numbered_traits}
+ if chat_content:
+ prompt = f"""当前聊天内容:
+{chat_content}
-请从候选人物信息中选择与当前关注内容最相关的编号,并用<>包裹输出,不要输出其他内容。
-例如:
-<1><3>
-如果都不相关,请输出
"""
+分类列表:
+{category_list}
+**要求**:请你根据当前聊天内容,从以下分类中选择一个与聊天内容相关的分类,并用<>包裹输出,不要输出其他内容,不要输出引号或[],严格用<>包裹:
+例如:
+<分类1><分类2><分类3>......
+如果没有相关的分类,请输出"""
- try:
- response, _ = await relation_selection_model.generate_response_async(prompt)
- selected_traits: List[str] = []
- for raw_index in extract_categories_from_response(response):
- if raw_index == "none":
- return []
- try:
- trait_index = int(raw_index) - 1
- except ValueError:
- continue
- if 0 <= trait_index < len(clean_traits):
- trait = clean_traits[trait_index]
- if trait not in selected_traits:
- selected_traits.append(trait)
- if selected_traits:
- return selected_traits[:limit]
- except Exception as e:
- logger.debug(f"筛选人物画像信息失败,使用默认画像摘要: {e}")
+ response, _ = await relation_selection_model.generate_response_async(prompt)
+ # print(prompt)
+ # print(response)
+ category_list = extract_categories_from_response(response)
+ if "none" not in category_list:
+ for category in category_list:
+ random_memory = self.get_random_memory_by_category(category, 2)
+ if random_memory:
+ random_memory_str = "\n".join(
+ [get_memory_content_from_memory(memory) for memory in random_memory]
+ )
+ points_text = f"有关 {category} 的内容:{random_memory_str}"
+ break
+ elif info_type:
+ prompt = f"""你需要获取用户{self.person_name}的 **{info_type}** 信息。
- return clean_traits[:limit]
-
- profile = await memory_service.get_person_profile(self.person_id, limit=8)
- relation_parts: List[str] = []
- if profile.summary.strip():
- relation_parts.append(profile.summary.strip())
-
- query_text = str(chat_content or info_type or "").strip()
- selected_traits = await _select_traits(query_text, profile.traits, limit=3)
- if not selected_traits and not query_text:
- selected_traits = [trait for trait in profile.traits if trait][:2]
-
- for trait in selected_traits:
- clean_trait = str(trait).strip()
- if clean_trait and clean_trait not in relation_parts:
- relation_parts.append(clean_trait)
-
- for evidence in profile.evidence:
- content = str(evidence.get("content", "") or "").strip()
- if content and content not in relation_parts:
- relation_parts.append(content)
- if len(relation_parts) >= 4:
- break
+现有信息类别列表:
+{category_list}
+**要求**:请你根据**{info_type}**,从以下分类中选择一个与**{info_type}**相关的分类,并用<>包裹输出,不要输出其他内容,不要输出引号或[],严格用<>包裹:
+例如:
+<分类1><分类2><分类3>......
+如果没有相关的分类,请输出"""
+ response, _ = await relation_selection_model.generate_response_async(prompt)
+ # print(prompt)
+ # print(response)
+ category_list = extract_categories_from_response(response)
+ if "none" not in category_list:
+ for category in category_list:
+ random_memory = self.get_random_memory_by_category(category, 3)
+ if random_memory:
+ random_memory_str = "\n".join(
+ [get_memory_content_from_memory(memory) for memory in random_memory]
+ )
+ points_text = f"有关 {category} 的内容:{random_memory_str}"
+ break
+ else:
+ for category in category_list:
+ random_memory = self.get_random_memory_by_category(category, 1)[0]
+ if random_memory:
+ points_text = f"有关 {category} 的内容:{get_memory_content_from_memory(random_memory)}"
+ break
points_info = ""
- if relation_parts:
- points_info = f"你还记得有关{self.person_name}的内容:{';'.join(relation_parts[:3])}"
+ if points_text:
+ points_info = f"你还记得有关{self.person_name}的内容:{points_text}"
if not (nickname_str or points_info):
return ""
- return f"{self.person_name}:{nickname_str}{points_info}"
+ relation_info = f"{self.person_name}:{nickname_str}{points_info}"
+
+ return relation_info
class PersonInfoManager:
@@ -822,7 +795,7 @@ person_info_manager = PersonInfoManager()
async def store_person_memory_from_answer(person_name: str, memory_content: str, chat_id: str) -> None:
- """将人物事实写入统一长期记忆
+ """将人物信息存入person_info的memory_points
Args:
person_name: 人物名称
@@ -830,11 +803,6 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
chat_id: 聊天ID
"""
try:
- content = str(memory_content or "").strip()
- if not content:
- logger.debug("人物记忆内容为空,跳过写入")
- return
-
# 从 chat_id 获取 session
session = _chat_manager.get_session_by_session_id(chat_id)
if not session:
@@ -845,14 +813,16 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
# 尝试从person_name查找person_id
# 首先尝试通过person_name查找
- person_id = resolve_person_id_for_memory(
- person_name=person_name,
- platform=platform,
- user_id=session.user_id,
- )
+ person_id = get_person_id_by_person_name(person_name)
+
if not person_id:
- logger.warning(f"无法确定person_id for person_name: {person_name}, chat_id: {chat_id}")
- return
+ # 如果通过person_name找不到,尝试从 session 获取 user_id
+ if platform and session.user_id:
+ user_id = session.user_id
+ person_id = get_person_id(platform, user_id)
+ else:
+ logger.warning(f"无法确定person_id for person_name: {person_name}, chat_id: {chat_id}")
+ return
# 创建或获取Person对象
person = Person(person_id=person_id)
@@ -861,34 +831,39 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
logger.warning(f"用户 {person_name} (person_id: {person_id}) 尚未认识,无法存储记忆")
return
- memory_hash = hashlib.sha256(f"{person_id}\n{content}".encode("utf-8")).hexdigest()[:16]
- result = await memory_service.ingest_text(
- external_id=f"person_fact:{person_id}:{memory_hash}",
- source_type="person_fact",
- text=content,
- chat_id=chat_id,
- person_ids=[person_id],
- participants=[person.person_name or person_name],
- timestamp=time.time(),
- tags=["person_fact"],
- metadata={
- "person_id": person_id,
- "person_name": person.person_name or person_name,
- "platform": platform,
- "source": "person_info.store_person_memory_from_answer",
- },
- respect_filter=True,
- user_id=str(session.user_id or "").strip(),
- group_id=str(session.group_id or "").strip(),
- )
+ # 确定记忆分类(可以根据memory_content判断,这里使用通用分类)
+ category = "其他" # 默认分类,可以根据需要调整
- if result.success:
- if result.detail == "chat_filtered":
- logger.debug(f"人物长期记忆被聊天过滤策略跳过: {person_name} (person_id: {person_id})")
- else:
- logger.info(f"成功写入人物长期记忆: {person_name} (person_id: {person_id})")
+ # 记忆点格式:category:content:weight
+ weight = "1.0" # 默认权重
+ memory_point = f"{category}:{memory_content}:{weight}"
+
+ # 添加到memory_points
+ if not person.memory_points:
+ person.memory_points = []
+
+ # 检查是否已存在相似的记忆点(避免重复)
+ is_duplicate = False
+ for existing_point in person.memory_points:
+ if existing_point and isinstance(existing_point, str):
+ parts = existing_point.split(":", 2)
+ if len(parts) >= 2:
+ existing_content = parts[1].strip()
+ # 简单相似度检查(如果内容相同或非常相似,则跳过)
+ if (
+ existing_content == memory_content
+ or memory_content in existing_content
+ or existing_content in memory_content
+ ):
+ is_duplicate = True
+ break
+
+ if not is_duplicate:
+ person.memory_points.append(memory_point)
+ person.sync_to_database()
+ logger.info(f"成功添加记忆点到 {person_name} (person_id: {person_id}): {memory_point}")
else:
- logger.warning(f"写入人物长期记忆失败: {person_name} (person_id: {person_id}) | {result.detail}")
+ logger.debug(f"记忆点已存在,跳过: {memory_point}")
except Exception as e:
logger.error(f"存储人物记忆失败: {e}")
diff --git a/src/platform_io/__init__.py b/src/platform_io/__init__.py
new file mode 100644
index 00000000..c91535d1
--- /dev/null
+++ b/src/platform_io/__init__.py
@@ -0,0 +1,34 @@
+"""导出 Platform IO 层的公开入口。
+
+当前仍处于地基阶段,调用方应优先从这里导入共享类型和全局管理器,
+而不是直接依赖更底层的私有子模块。
+"""
+
+from .manager import PlatformIOManager, get_platform_io_manager
+from .route_key_factory import RouteKeyFactory
+from .routing import RouteTable
+from .types import (
+ DeliveryBatch,
+ DeliveryReceipt,
+ DeliveryStatus,
+ DriverDescriptor,
+ DriverKind,
+ InboundMessageEnvelope,
+ RouteBinding,
+ RouteKey,
+)
+
+__all__ = [
+ "DeliveryBatch",
+ "DeliveryReceipt",
+ "DeliveryStatus",
+ "DriverDescriptor",
+ "DriverKind",
+ "InboundMessageEnvelope",
+ "PlatformIOManager",
+ "RouteKeyFactory",
+ "RouteBinding",
+ "RouteKey",
+ "RouteTable",
+ "get_platform_io_manager",
+]
diff --git a/src/platform_io/dedupe.py b/src/platform_io/dedupe.py
new file mode 100644
index 00000000..4c5c55a2
--- /dev/null
+++ b/src/platform_io/dedupe.py
@@ -0,0 +1,133 @@
+"""提供 Platform IO 的轻量入站消息去重能力。
+
+当前实现基于 ``dict + heapq``:
+- ``dict`` 保存去重键到过期时间的映射
+- ``heapq`` 维护按过期时间排序的小顶堆
+
+这样就不需要在每次检查时全表扫描,而是通过懒清理逐步弹出已经过期
+或已经失效的堆节点。
+"""
+
+from typing import Dict, List, Tuple
+
+import heapq
+import time
+
+
+class MessageDeduplicator:
+ """使用基于 TTL 的内存缓存进行入站消息去重。
+
+ 主要用于解决同一条外部消息被重复送入 Core 的问题,例如双路径并存、
+ 适配器重试、重连或重复回调等场景。Broker 可以借助这个组件在进入
+ Core 前先拦住重复投递,避免重复处理、重复回复和重复入库。
+
+ 当前实现使用 ``dict + heapq`` 维护过期时间:
+ - ``dict`` 负责 ``O(1)`` 级别的去重键查找
+ - ``heapq`` 负责按过期时间顺序做懒清理
+
+ 这比“每次调用都全表扫描过期项”的实现更适合高吞吐消息场景。
+
+ Notes:
+ 复杂度说明如下,设 ``n`` 为当前缓存中的有效去重键数量:
+
+ - 单次 ``mark_seen()`` 在常见路径下的时间复杂度接近 ``O(log n)``
+ - 从长期摊还角度看,``mark_seen()`` 的时间复杂度也接近 ``O(log n)``
+ - 如果某次调用恰好触发一批过期键的集中清理,则该次调用的最坏时间复杂度
+ 可达到 ``O(k log n)``,其中 ``k`` 为本次被弹出或清理的键数量
+ - 空间复杂度为 ``O(n)``
+ """
+
+ def __init__(self, ttl_seconds: float = 300.0, max_entries: int = 10000) -> None:
+ """初始化去重器。
+
+ Args:
+ ttl_seconds: 每个去重键在缓存中的保留时长,单位为秒。
+ max_entries: 缓存允许保留的最大有效键数量,超出后会触发
+ 机会性淘汰。
+
+ Raises:
+ ValueError: 当 ``ttl_seconds`` 或 ``max_entries`` 非正数时抛出。
+ """
+ if ttl_seconds <= 0:
+ raise ValueError("ttl_seconds 必须大于 0")
+ if max_entries <= 0:
+ raise ValueError("max_entries 必须大于 0")
+
+ self._ttl_seconds = ttl_seconds
+ self._max_entries = max_entries
+ self._expire_heap: List[Tuple[float, str]] = []
+ self._seen: Dict[str, float] = {}
+
+ def mark_seen(self, dedupe_key: str) -> bool:
+ """标记一条去重键已经出现过。
+
+ Args:
+ dedupe_key: 能稳定标识一条外部入站消息的去重键。
+
+ Returns:
+ bool: 若该键在当前 TTL 窗口内首次出现则返回 ``True``,
+ 否则返回 ``False``。
+
+ Notes:
+ 方法会先基于小顶堆做一次懒清理,再判断当前键是否仍在有效期内。
+ 如果缓存已达到上限,则会优先淘汰“最早过期的仍然有效的键”。
+
+ 复杂度方面,常见路径下该方法接近 ``O(log n)``;如果恰好需要
+ 集中清理一批过期键,则单次调用最坏可达到 ``O(k log n)``。
+ """
+ now = time.monotonic()
+ self._purge_expired(now)
+
+ expires_at = self._seen.get(dedupe_key)
+ if expires_at is not None and expires_at > now:
+ return False
+
+ if len(self._seen) >= self._max_entries:
+ self._evict_earliest_live()
+
+ expires_at = now + self._ttl_seconds
+ self._seen[dedupe_key] = expires_at
+ heapq.heappush(self._expire_heap, (expires_at, dedupe_key))
+ return True
+
+ def clear(self) -> None:
+ """清空全部去重缓存。"""
+ self._expire_heap.clear()
+ self._seen.clear()
+
+ def _purge_expired(self, now: float) -> None:
+ """从缓存中清理已经过期的去重键。
+
+ Args:
+ now: 当前单调时钟时间戳。
+
+ Notes:
+ 堆中可能存在旧版本节点。例如同一个 ``dedupe_key`` 被重新写入后,
+ 旧的过期时间节点仍会留在堆里。这里会通过和 ``dict`` 中当前值比对,
+ 跳过这类失效节点。
+ """
+ while self._expire_heap and self._expire_heap[0][0] <= now:
+ expires_at, dedupe_key = heapq.heappop(self._expire_heap)
+ current_expires_at = self._seen.get(dedupe_key)
+ if current_expires_at is None:
+ continue
+ if current_expires_at != expires_at:
+ continue
+ self._seen.pop(dedupe_key, None)
+
+ def _evict_earliest_live(self) -> None:
+ """当缓存达到容量上限时,淘汰一条最早过期的有效键。
+
+ Notes:
+ 堆顶可能是已经过期或已失效的旧节点,因此这里同样需要循环弹出,
+ 直到找到一条当前仍然在 ``dict`` 中生效的键。
+ """
+ while self._expire_heap:
+ expires_at, dedupe_key = heapq.heappop(self._expire_heap)
+ current_expires_at = self._seen.get(dedupe_key)
+ if current_expires_at is None:
+ continue
+ if current_expires_at != expires_at:
+ continue
+ self._seen.pop(dedupe_key, None)
+ return
diff --git a/src/platform_io/drivers/__init__.py b/src/platform_io/drivers/__init__.py
new file mode 100644
index 00000000..b12120cf
--- /dev/null
+++ b/src/platform_io/drivers/__init__.py
@@ -0,0 +1,11 @@
+"""导出 Platform IO 层的公开驱动类型。"""
+
+from .base import PlatformIODriver
+from .legacy_driver import LegacyPlatformDriver
+from .plugin_driver import PluginPlatformDriver
+
+__all__ = [
+ "LegacyPlatformDriver",
+ "PlatformIODriver",
+ "PluginPlatformDriver",
+]
diff --git a/src/platform_io/drivers/base.py b/src/platform_io/drivers/base.py
new file mode 100644
index 00000000..c6173d8c
--- /dev/null
+++ b/src/platform_io/drivers/base.py
@@ -0,0 +1,104 @@
+"""定义 Platform IO 传输驱动的基础抽象协议。"""
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional
+
+from src.platform_io.types import DeliveryReceipt, DriverDescriptor, InboundMessageEnvelope, RouteKey
+
+if TYPE_CHECKING:
+ from src.chat.message_receive.message import SessionMessage
+
+InboundHandler = Callable[[InboundMessageEnvelope], Awaitable[bool]]
+
+
+class PlatformIODriver(ABC):
+ """定义所有 Platform IO 驱动都必须实现的最小契约。
+
+ 当前实现故意保持接口很小,让中间层可以先落地,再逐步把 legacy
+ 与 plugin 路径的真实收发能力迁入这套协议之下。
+ """
+
+ def __init__(self, descriptor: DriverDescriptor) -> None:
+ """使用驱动描述对象初始化驱动。
+
+ Args:
+ descriptor: 注册到 Broker 中的静态驱动元数据。
+ """
+ self._descriptor = descriptor
+ self._inbound_handler: Optional[InboundHandler] = None
+
+ @property
+ def descriptor(self) -> DriverDescriptor:
+ """返回当前驱动的描述对象。
+
+ Returns:
+ DriverDescriptor: 当前驱动实例对应的描述对象。
+ """
+ return self._descriptor
+
+ @property
+ def driver_id(self) -> str:
+ """返回驱动标识。
+
+ Returns:
+ str: 当前驱动的唯一 ID。
+ """
+ return self._descriptor.driver_id
+
+ def set_inbound_handler(self, handler: InboundHandler) -> None:
+ """注册入站消息交回 Broker 的回调函数。
+
+ Args:
+ handler: 将规范化入站封装继续转发给 Broker 的异步回调。
+ """
+ self._inbound_handler = handler
+
+ def clear_inbound_handler(self) -> None:
+ """清除当前注册的入站回调函数。"""
+ self._inbound_handler = None
+
+ async def emit_inbound(self, envelope: InboundMessageEnvelope) -> bool:
+ """将一条入站封装转交给 Broker 回调。
+
+ Args:
+ envelope: 由驱动产出的规范化入站封装。
+
+ Returns:
+ bool: 若 Broker 接受该入站消息则返回 ``True``,否则返回 ``False``。
+ """
+
+ if self._inbound_handler is None:
+ return False
+ return await self._inbound_handler(envelope)
+
+ async def start(self) -> None:
+ """启动驱动生命周期。
+
+ 子类后续若需要初始化逻辑,可以覆盖这个钩子。
+ """
+ return None
+
+ async def stop(self) -> None:
+ """停止驱动生命周期。
+
+ 子类后续若需要清理逻辑,可以覆盖这个钩子。
+ """
+ return None
+
+ @abstractmethod
+ async def send_message(
+ self,
+ message: "SessionMessage",
+ route_key: RouteKey,
+ metadata: Optional[Dict[str, Any]] = None,
+ ) -> DeliveryReceipt:
+ """通过具体驱动发送一条消息。
+
+ Args:
+ message: 要投递的内部会话消息。
+ route_key: Broker 为本次投递选中的路由键。
+ metadata: 本次出站投递可选的 Broker 侧元数据。
+
+ Returns:
+ DeliveryReceipt: 规范化后的投递结果。
+ """
diff --git a/src/platform_io/drivers/legacy_driver.py b/src/platform_io/drivers/legacy_driver.py
new file mode 100644
index 00000000..ef90c772
--- /dev/null
+++ b/src/platform_io/drivers/legacy_driver.py
@@ -0,0 +1,92 @@
+"""提供 Platform IO 的 legacy 传输驱动实现。"""
+
+from typing import TYPE_CHECKING, Any, Dict, Optional
+
+from src.platform_io.drivers.base import PlatformIODriver
+from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, RouteKey
+
+if TYPE_CHECKING:
+ from src.chat.message_receive.message import SessionMessage
+
+
+class LegacyPlatformDriver(PlatformIODriver):
+ """面向 ``UniversalMessageSender`` 旧链的 Platform IO 驱动。"""
+
+ def __init__(
+ self,
+ driver_id: str,
+ platform: str,
+ account_id: Optional[str] = None,
+ scope: Optional[str] = None,
+ metadata: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ """初始化一个 legacy 驱动描述对象。
+
+ Args:
+ driver_id: Broker 内的唯一驱动 ID。
+ platform: 该 legacy 适配器链路负责的平台。
+ account_id: 可选的账号 ID。
+ scope: 可选的额外路由作用域。
+ metadata: 可选的额外驱动元数据。
+ """
+ descriptor = DriverDescriptor(
+ driver_id=driver_id,
+ kind=DriverKind.LEGACY,
+ platform=platform,
+ account_id=account_id,
+ scope=scope,
+ metadata=metadata or {},
+ )
+ super().__init__(descriptor)
+
+ async def send_message(
+ self,
+ message: "SessionMessage",
+ route_key: RouteKey,
+ metadata: Optional[Dict[str, Any]] = None,
+ ) -> DeliveryReceipt:
+ """通过旧链发送一条已经过预处理的消息。
+
+ Args:
+ message: 要投递的内部会话消息。
+ route_key: Broker 为本次投递选择的路由键。
+ metadata: 本次出站投递可选的 Broker 侧元数据。
+
+ Returns:
+ DeliveryReceipt: 规范化后的发送回执。
+ """
+ from src.chat.message_receive.uni_message_sender import send_prepared_message_to_platform
+
+ show_log = False
+ if isinstance(metadata, dict):
+ show_log = bool(metadata.get("show_log", False))
+
+ try:
+ sent = await send_prepared_message_to_platform(message, show_log=show_log)
+ except Exception as exc:
+ return DeliveryReceipt(
+ internal_message_id=message.message_id,
+ route_key=route_key,
+ status=DeliveryStatus.FAILED,
+ driver_id=self.driver_id,
+ driver_kind=self.descriptor.kind,
+ error=str(exc),
+ )
+
+ if not sent:
+ return DeliveryReceipt(
+ internal_message_id=message.message_id,
+ route_key=route_key,
+ status=DeliveryStatus.FAILED,
+ driver_id=self.driver_id,
+ driver_kind=self.descriptor.kind,
+ error="旧链发送失败",
+ )
+
+ return DeliveryReceipt(
+ internal_message_id=message.message_id,
+ route_key=route_key,
+ status=DeliveryStatus.SENT,
+ driver_id=self.driver_id,
+ driver_kind=self.descriptor.kind,
+ )
diff --git a/src/platform_io/drivers/plugin_driver.py b/src/platform_io/drivers/plugin_driver.py
new file mode 100644
index 00000000..c03204ad
--- /dev/null
+++ b/src/platform_io/drivers/plugin_driver.py
@@ -0,0 +1,211 @@
+"""提供 Platform IO 的插件消息网关驱动实现。"""
+
+from typing import TYPE_CHECKING, Any, Dict, Optional, Protocol
+
+from src.platform_io.drivers.base import PlatformIODriver
+from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, RouteKey
+
+if TYPE_CHECKING:
+ from src.chat.message_receive.message import SessionMessage
+
+
+class _GatewaySupervisorProtocol(Protocol):
+ """消息网关驱动依赖的 Supervisor 最小协议。"""
+
+ async def invoke_message_gateway(
+ self,
+ plugin_id: str,
+ component_name: str,
+ args: Optional[Dict[str, Any]] = None,
+ timeout_ms: int = 30000,
+ ) -> Any:
+ """调用插件声明的消息网关方法。"""
+
+
+class PluginPlatformDriver(PlatformIODriver):
+ """面向插件消息网关链路的 Platform IO 驱动。"""
+
+ def __init__(
+ self,
+ driver_id: str,
+ platform: str,
+ supervisor: _GatewaySupervisorProtocol,
+ component_name: str,
+ *,
+ supports_send: bool,
+ account_id: Optional[str] = None,
+ scope: Optional[str] = None,
+ plugin_id: Optional[str] = None,
+ metadata: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ """初始化一个插件消息网关驱动。
+
+ Args:
+ driver_id: Broker 内的唯一驱动 ID。
+ platform: 该消息网关负责的平台名称。
+ supervisor: 持有该插件的 Supervisor。
+ component_name: 出站时要调用的网关组件名称。
+ supports_send: 当前驱动是否具备出站能力。
+ account_id: 可选的账号 ID 或 self ID。
+ scope: 可选的额外路由作用域。
+ plugin_id: 拥有该实现的插件 ID。
+ metadata: 可选的额外驱动元数据。
+ """
+
+ descriptor = DriverDescriptor(
+ driver_id=driver_id,
+ kind=DriverKind.PLUGIN,
+ platform=platform,
+ account_id=account_id,
+ scope=scope,
+ plugin_id=plugin_id,
+ metadata=metadata or {},
+ )
+ super().__init__(descriptor)
+ self._supervisor = supervisor
+ self._component_name = component_name
+ self._supports_send = supports_send
+
+ async def send_message(
+ self,
+ message: "SessionMessage",
+ route_key: RouteKey,
+ metadata: Optional[Dict[str, Any]] = None,
+ ) -> DeliveryReceipt:
+ """通过插件消息网关发送消息。
+
+ Args:
+ message: 要投递的内部会话消息。
+ route_key: Broker 为本次投递选择的路由键。
+ metadata: 可选的发送元数据。
+
+ Returns:
+ DeliveryReceipt: 规范化后的发送回执。
+ """
+
+ if not self._supports_send:
+ return DeliveryReceipt(
+ internal_message_id=message.message_id,
+ route_key=route_key,
+ status=DeliveryStatus.FAILED,
+ driver_id=self.driver_id,
+ driver_kind=self.descriptor.kind,
+ error="当前消息网关仅支持接收,不支持发送",
+ )
+
+ from src.plugin_runtime.host.message_utils import PluginMessageUtils
+
+ plugin_id = self.descriptor.plugin_id or ""
+ if not plugin_id:
+ return DeliveryReceipt(
+ internal_message_id=message.message_id,
+ route_key=route_key,
+ status=DeliveryStatus.FAILED,
+ driver_id=self.driver_id,
+ driver_kind=self.descriptor.kind,
+ error="插件消息网关驱动缺少 plugin_id",
+ )
+
+ try:
+ message_dict = PluginMessageUtils._session_message_to_dict(message)
+ response = await self._supervisor.invoke_message_gateway(
+ plugin_id=plugin_id,
+ component_name=self._component_name,
+ args={
+ "message": message_dict,
+ "route": {
+ "platform": route_key.platform,
+ "account_id": route_key.account_id,
+ "scope": route_key.scope,
+ },
+ "metadata": metadata or {},
+ },
+ timeout_ms=30000,
+ )
+ except Exception as exc:
+ return DeliveryReceipt(
+ internal_message_id=message.message_id,
+ route_key=route_key,
+ status=DeliveryStatus.FAILED,
+ driver_id=self.driver_id,
+ driver_kind=self.descriptor.kind,
+ error=str(exc),
+ )
+
+ return self._build_receipt(message.message_id, route_key, response)
+
+ def _build_receipt(self, internal_message_id: str, route_key: RouteKey, response: Any) -> DeliveryReceipt:
+ """将网关调用响应归一化为出站回执。
+
+ Args:
+ internal_message_id: 内部消息 ID。
+ route_key: 本次投递的路由键。
+ response: Supervisor 返回的 RPC 响应对象。
+
+ Returns:
+ DeliveryReceipt: 标准化后的出站回执。
+ """
+
+ if getattr(response, "error", None):
+ error = response.error.get("message", "消息网关发送失败")
+ return DeliveryReceipt(
+ internal_message_id=internal_message_id,
+ route_key=route_key,
+ status=DeliveryStatus.FAILED,
+ driver_id=self.driver_id,
+ driver_kind=self.descriptor.kind,
+ error=error,
+ )
+
+ payload = getattr(response, "payload", {})
+ invoke_success = bool(payload.get("success", False)) if isinstance(payload, dict) else False
+ if not invoke_success:
+ return DeliveryReceipt(
+ internal_message_id=internal_message_id,
+ route_key=route_key,
+ status=DeliveryStatus.FAILED,
+ driver_id=self.driver_id,
+ driver_kind=self.descriptor.kind,
+ error=str(payload.get("result", "消息网关发送失败")) if isinstance(payload, dict) else "消息网关发送失败",
+ )
+
+ result = payload.get("result") if isinstance(payload, dict) else None
+ if isinstance(result, dict):
+ if result.get("success") is False:
+ return DeliveryReceipt(
+ internal_message_id=internal_message_id,
+ route_key=route_key,
+ status=DeliveryStatus.FAILED,
+ driver_id=self.driver_id,
+ driver_kind=self.descriptor.kind,
+ error=str(result.get("error", "消息网关发送失败")),
+ metadata=result.get("metadata", {}) if isinstance(result.get("metadata"), dict) else {},
+ )
+ external_message_id = str(result.get("external_message_id") or result.get("message_id") or "") or None
+ return DeliveryReceipt(
+ internal_message_id=internal_message_id,
+ route_key=route_key,
+ status=DeliveryStatus.SENT,
+ driver_id=self.driver_id,
+ driver_kind=self.descriptor.kind,
+ external_message_id=external_message_id,
+ metadata=result.get("metadata", {}) if isinstance(result.get("metadata"), dict) else {},
+ )
+
+ if isinstance(result, str) and result.strip():
+ return DeliveryReceipt(
+ internal_message_id=internal_message_id,
+ route_key=route_key,
+ status=DeliveryStatus.SENT,
+ driver_id=self.driver_id,
+ driver_kind=self.descriptor.kind,
+ external_message_id=result.strip(),
+ )
+
+ return DeliveryReceipt(
+ internal_message_id=internal_message_id,
+ route_key=route_key,
+ status=DeliveryStatus.SENT,
+ driver_id=self.driver_id,
+ driver_kind=self.descriptor.kind,
+ )
diff --git a/src/platform_io/manager.py b/src/platform_io/manager.py
new file mode 100644
index 00000000..dee553a6
--- /dev/null
+++ b/src/platform_io/manager.py
@@ -0,0 +1,611 @@
+"""提供 Platform IO 层的中心 Broker 管理器。"""
+
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional
+
+from src.common.logger import get_logger
+from src.platform_io.drivers.base import PlatformIODriver
+
+from .dedupe import MessageDeduplicator
+from .outbound_tracker import OutboundTracker
+from .route_key_factory import RouteKeyFactory
+from .registry import DriverRegistry
+from .routing import RouteTable
+from .types import DeliveryBatch, DeliveryReceipt, DeliveryStatus, InboundMessageEnvelope, RouteBinding, RouteKey
+
+if TYPE_CHECKING:
+ from src.chat.message_receive.message import SessionMessage
+
+logger = get_logger("platform_io.manager")
+
+InboundDispatcher = Callable[[InboundMessageEnvelope], Awaitable[None]]
+
+
+class PlatformIOManager:
+ """统一协调平台消息 IO 的路由、去重与状态跟踪。
+
+ 与旧实现不同,这个管理器不再负责“多条链路谁该接管平台”的裁决,
+ 只维护发送表和接收表两张轻量路由表:
+
+ - 发送时:解析所有命中的发送绑定并全部投递。
+ - 接收时:只校验当前驱动是否已登记为可接收链路,然后全部放行给上层。
+ - 去重时:仅对单条链路做技术性重放抑制,不做跨链路语义去重。
+ """
+
+ def __init__(self) -> None:
+ """初始化 Broker 管理器及其内存状态。"""
+ self._driver_registry = DriverRegistry()
+ self._send_route_table = RouteTable()
+ self._receive_route_table = RouteTable()
+ self._legacy_send_drivers: Dict[str, PlatformIODriver] = {}
+ self._deduplicator = MessageDeduplicator()
+ self._outbound_tracker = OutboundTracker()
+ self._inbound_dispatcher: Optional[InboundDispatcher] = None
+ self._started = False
+
+ @property
+ def is_started(self) -> bool:
+ """返回 Broker 当前是否已进入运行态。
+
+ Returns:
+ bool: 若 Broker 已启动则返回 ``True``。
+ """
+ return self._started
+
+ async def start(self) -> None:
+ """启动 Broker,并依次启动当前已注册的全部驱动。
+
+ Raises:
+ Exception: 当某个驱动启动失败时,异常会继续上抛;已成功启动的驱动
+ 会被自动回滚停止。
+ """
+ if self._started:
+ return
+
+ started_drivers: List[PlatformIODriver] = []
+ try:
+ for driver in self._driver_registry.list():
+ await driver.start()
+ started_drivers.append(driver)
+ except Exception:
+ for driver in reversed(started_drivers):
+ try:
+ await driver.stop()
+ except Exception:
+ logger.exception(f"回滚驱动停止失败: driver_id={driver.driver_id}")
+ raise
+
+ self._started = True
+
+ async def ensure_send_pipeline_ready(self) -> None:
+ """确保出站发送管线已准备就绪。
+
+ 该方法会先同步 legacy fallback driver,再在需要时启动 Broker。
+ send service 应只调用这一层准备入口,而不是自行判断旧链或插件链。
+ """
+ await self._sync_legacy_send_drivers()
+ if not self._started:
+ await self.start()
+
+ async def stop(self) -> None:
+ """停止 Broker,并按逆序停止全部已注册驱动。
+
+ 停止完成后,会同步清空仅对当前运行周期有效的去重缓存和出站跟踪状态,
+ 避免下一次启动时继续沿用上一个运行周期的瞬时内存数据。
+
+ Raises:
+ RuntimeError: 当一个或多个驱动停止失败时抛出汇总异常。
+ """
+ if not self._started:
+ return
+
+ stop_errors: List[str] = []
+ for driver in reversed(self._driver_registry.list()):
+ try:
+ await driver.stop()
+ except Exception as exc:
+ stop_errors.append(f"{driver.driver_id}: {exc}")
+ logger.exception(f"驱动停止失败: driver_id={driver.driver_id}")
+
+ self._started = False
+ self._deduplicator.clear()
+ self._outbound_tracker.clear()
+ if stop_errors:
+ raise RuntimeError(f"部分驱动停止失败: {'; '.join(stop_errors)}")
+
+ async def add_driver(self, driver: PlatformIODriver) -> None:
+ """向运行中的 Broker 注册并启动一个驱动。
+
+ 如果 Broker 尚未启动,则该方法等价于 ``register_driver()``。
+
+ Args:
+ driver: 要添加的驱动实例。
+
+ Raises:
+ Exception: 当驱动启动失败时,注册会自动回滚,异常继续上抛。
+ """
+ self._register_driver_internal(driver)
+ if not self._started:
+ return
+
+ try:
+ await driver.start()
+ except Exception:
+ self._unregister_driver_internal(driver.driver_id)
+ raise
+
+ async def remove_driver(self, driver_id: str) -> Optional[PlatformIODriver]:
+ """从运行中的 Broker 停止并移除一个驱动。
+
+ 如果 Broker 尚未启动,则该方法等价于 ``unregister_driver()``。
+
+ Args:
+ driver_id: 要移除的驱动 ID。
+
+ Returns:
+ Optional[PlatformIODriver]: 若驱动存在,则返回被移除的驱动实例。
+
+ Raises:
+ Exception: 当 Broker 运行中且驱动停止失败时,异常会继续上抛。
+ """
+ if not self._started:
+ return self.unregister_driver(driver_id)
+
+ driver = self._driver_registry.get(driver_id)
+ if driver is None:
+ return None
+
+ await driver.stop()
+ return self._unregister_driver_internal(driver_id)
+
+ @property
+ def driver_registry(self) -> DriverRegistry:
+ """返回管理器持有的驱动注册表。
+
+ Returns:
+ DriverRegistry: 用于保存全部已注册驱动的注册表。
+ """
+ return self._driver_registry
+
+ @property
+ def send_route_table(self) -> RouteTable:
+ """返回发送路由表。"""
+
+ return self._send_route_table
+
+ @property
+ def receive_route_table(self) -> RouteTable:
+ """返回接收路由表。"""
+
+ return self._receive_route_table
+
+ @property
+ def deduplicator(self) -> MessageDeduplicator:
+ """返回管理器持有的入站去重器。
+
+ Returns:
+ MessageDeduplicator: 用于抑制重复入站的去重器。
+ """
+ return self._deduplicator
+
+ @property
+ def outbound_tracker(self) -> OutboundTracker:
+ """返回管理器持有的出站跟踪器。
+
+ Returns:
+ OutboundTracker: 用于记录出站 pending 状态与回执的跟踪器。
+ """
+ return self._outbound_tracker
+
+ def set_inbound_dispatcher(self, dispatcher: InboundDispatcher) -> None:
+ """设置统一的入站分发回调。
+
+ Args:
+ dispatcher: 接收已通过 Broker 审核的入站封装,并继续送入
+ Core 下一处理阶段的异步回调。
+ """
+
+ self._inbound_dispatcher = dispatcher
+
+ def clear_inbound_dispatcher(self) -> None:
+ """清除当前的入站分发回调。"""
+ self._inbound_dispatcher = None
+
+ @property
+ def has_inbound_dispatcher(self) -> bool:
+ """返回当前是否已经配置入站分发回调。
+
+ Returns:
+ bool: 若已经配置入站分发回调则返回 ``True``。
+ """
+ return self._inbound_dispatcher is not None
+
+ def register_driver(self, driver: PlatformIODriver) -> None:
+ """注册驱动,并把它的入站回调挂到 Broker。
+
+ Args:
+ driver: 要注册的驱动实例。
+
+ Raises:
+ RuntimeError: 当 Broker 已经处于运行态时抛出。此时应改用
+ ``add_driver()`` 以保证驱动生命周期和注册状态一致。
+ """
+ if self._started:
+ raise RuntimeError("Broker 运行中不允许直接 register_driver,请改用 add_driver()")
+
+ self._register_driver_internal(driver)
+
+ def _register_driver_internal(self, driver: PlatformIODriver) -> None:
+ """执行不带运行态限制的内部驱动注册。
+
+ Args:
+ driver: 要注册的驱动实例。
+ """
+ driver.set_inbound_handler(self.accept_inbound)
+ self._driver_registry.register(driver)
+
+ def unregister_driver(self, driver_id: str) -> Optional[PlatformIODriver]:
+ """从 Broker 注销一个驱动。
+
+ Args:
+ driver_id: 要移除的驱动 ID。
+
+ Returns:
+ Optional[PlatformIODriver]: 若驱动存在,则返回被移除的驱动实例。
+
+ Raises:
+ RuntimeError: 当 Broker 已经处于运行态时抛出。此时应改用
+ ``remove_driver()``,避免驱动停止与路由解绑脱节。
+ """
+ if self._started:
+ raise RuntimeError("Broker 运行中不允许直接 unregister_driver,请改用 remove_driver()")
+
+ return self._unregister_driver_internal(driver_id)
+
+ def _unregister_driver_internal(self, driver_id: str) -> Optional[PlatformIODriver]:
+ """执行不带运行态限制的内部驱动注销。
+
+ Args:
+ driver_id: 要移除的驱动 ID。
+
+ Returns:
+ Optional[PlatformIODriver]: 若驱动存在,则返回被移除的驱动实例。
+ """
+ removed_driver = self._driver_registry.unregister(driver_id)
+ if removed_driver is None:
+ return None
+
+ removed_driver.clear_inbound_handler()
+ self._send_route_table.remove_bindings_by_driver(driver_id)
+ self._receive_route_table.remove_bindings_by_driver(driver_id)
+ self._legacy_send_drivers = {
+ platform: driver
+ for platform, driver in self._legacy_send_drivers.items()
+ if driver.driver_id != driver_id
+ }
+ return removed_driver
+
+ async def _sync_legacy_send_drivers(self) -> None:
+ """根据当前配置同步 legacy fallback driver。"""
+ from src.chat.utils.utils import get_all_bot_accounts
+ from src.platform_io.drivers.legacy_driver import LegacyPlatformDriver
+
+ desired_accounts = get_all_bot_accounts()
+ desired_platforms = set(desired_accounts.keys())
+ current_platforms = set(self._legacy_send_drivers.keys())
+
+ for platform in sorted(current_platforms - desired_platforms):
+ await self._remove_legacy_send_driver(platform)
+
+ for platform, account_id in desired_accounts.items():
+ existing_driver = self._legacy_send_drivers.get(platform)
+ if existing_driver is not None and existing_driver.descriptor.account_id == account_id:
+ continue
+
+ if existing_driver is not None:
+ await self._remove_legacy_send_driver(platform)
+
+ driver = LegacyPlatformDriver(
+ driver_id=f"legacy.send.{platform}",
+ platform=platform,
+ account_id=account_id,
+ )
+ if self._started:
+ await self.add_driver(driver)
+ else:
+ self.register_driver(driver)
+ self._legacy_send_drivers[platform] = driver
+
+ async def _remove_legacy_send_driver(self, platform: str) -> None:
+ """移除指定平台的 legacy fallback driver。
+
+ Args:
+ platform: 要移除的目标平台。
+ """
+ driver = self._legacy_send_drivers.get(platform)
+ if driver is None:
+ return
+
+ if self._started:
+ await self.remove_driver(driver.driver_id)
+ else:
+ self.unregister_driver(driver.driver_id)
+ self._legacy_send_drivers.pop(platform, None)
+
+ def bind_send_route(self, binding: RouteBinding) -> None:
+ """为某个路由键绑定发送驱动。
+
+ Args:
+ binding: 要保存的路由绑定。
+
+ Raises:
+ ValueError: 当绑定引用了不存在的驱动,或者绑定与驱动描述不一致时抛出。
+ """
+ driver = self._driver_registry.get(binding.driver_id)
+ if driver is None:
+ raise ValueError(f"驱动 {binding.driver_id} 未注册,无法绑定路由")
+
+ self._validate_binding_against_driver(binding, driver)
+ self._send_route_table.bind(binding)
+
+ def bind_receive_route(self, binding: RouteBinding) -> None:
+ """为某个路由键绑定接收驱动。
+
+ Args:
+ binding: 要保存的路由绑定。
+
+ Raises:
+ ValueError: 当绑定引用了不存在的驱动,或者绑定与驱动描述不一致时抛出。
+ """
+ driver = self._driver_registry.get(binding.driver_id)
+ if driver is None:
+ raise ValueError(f"驱动 {binding.driver_id} 未注册,无法绑定路由")
+
+ self._validate_binding_against_driver(binding, driver)
+ self._receive_route_table.bind(binding)
+
+ def unbind_send_route(self, route_key: RouteKey, driver_id: Optional[str] = None) -> None:
+ """移除发送路由绑定。
+
+ Args:
+ route_key: 要移除绑定的路由键。
+ driver_id: 可选的特定驱动 ID。
+ """
+
+ self._send_route_table.unbind(route_key, driver_id)
+
+ def unbind_receive_route(self, route_key: RouteKey, driver_id: Optional[str] = None) -> None:
+ """移除接收路由绑定。
+
+ Args:
+ route_key: 要移除绑定的路由键。
+ driver_id: 可选的特定驱动 ID。
+ """
+
+ self._receive_route_table.unbind(route_key, driver_id)
+
+ def resolve_drivers(self, route_key: RouteKey) -> List[PlatformIODriver]:
+ """解析某个路由键当前命中的全部发送驱动。
+
+ Args:
+ route_key: 要解析的路由键。
+
+ Returns:
+ List[PlatformIODriver]: 当前命中的全部发送驱动。
+ """
+
+ drivers: List[PlatformIODriver] = []
+ seen_driver_ids: set[str] = set()
+ for binding in self._send_route_table.resolve_bindings(route_key):
+ driver = self._driver_registry.get(binding.driver_id)
+ if driver is not None and driver.driver_id not in seen_driver_ids:
+ drivers.append(driver)
+ seen_driver_ids.add(driver.driver_id)
+
+ fallback_driver = self._legacy_send_drivers.get(route_key.platform)
+ if fallback_driver is not None:
+ descriptor = fallback_driver.descriptor
+ account_matches = descriptor.account_id is None or route_key.account_id in (None, descriptor.account_id)
+ scope_matches = descriptor.scope is None or route_key.scope in (None, descriptor.scope)
+ if account_matches and scope_matches and fallback_driver.driver_id not in seen_driver_ids:
+ drivers.append(fallback_driver)
+
+ return drivers
+
+ @staticmethod
+ def build_route_key_from_message(message: "SessionMessage") -> RouteKey:
+ """根据 ``SessionMessage`` 构造路由键。
+
+ Args:
+ message: 内部会话消息对象。
+
+ Returns:
+ RouteKey: 由消息内容提取出的规范化路由键。
+ """
+ return RouteKeyFactory.from_session_message(message)
+
+ @staticmethod
+ def build_route_key_from_message_dict(message_dict: Dict[str, Any]) -> RouteKey:
+ """根据消息字典构造路由键。
+
+ Args:
+ message_dict: Host 与插件之间传输的消息字典。
+
+ Returns:
+ RouteKey: 由消息字典提取出的规范化路由键。
+ """
+ return RouteKeyFactory.from_message_dict(message_dict)
+
+ async def accept_inbound(self, envelope: InboundMessageEnvelope) -> bool:
+ """处理一条由驱动上报的入站封装。
+
+ Args:
+ envelope: 由传输驱动产出的入站封装。
+
+ Returns:
+ bool: 若消息被接受并继续转发给入站分发器,则返回 ``True``,
+ 否则返回 ``False``。
+ """
+
+ if not self._receive_route_table.has_binding_for_driver(envelope.route_key, envelope.driver_id):
+ logger.info(
+ f"忽略未登记到接收路由表的入站消息: route={envelope.route_key} "
+ f"driver={envelope.driver_id}"
+ )
+ return False
+
+ if self._inbound_dispatcher is None:
+ logger.debug("PlatformIOManager 尚未配置 inbound dispatcher,暂不继续分发")
+ return False
+
+ dedupe_key = self._build_inbound_dedupe_key(envelope)
+ if dedupe_key is not None:
+ if not self._deduplicator.mark_seen(dedupe_key):
+ logger.info(f"忽略重复入站消息: dedupe_key={dedupe_key}")
+ return False
+
+ await self._inbound_dispatcher(envelope)
+ return True
+
+ async def send_message(
+ self,
+ message: "SessionMessage",
+ route_key: RouteKey,
+ metadata: Optional[Dict[str, Any]] = None,
+ ) -> DeliveryBatch:
+ """通过 Broker 选中的全部发送驱动广播一条消息。
+
+ Args:
+ message: 要投递的内部会话消息。
+ route_key: 本次出站投递选择的路由键。
+ metadata: 可选的额外 Broker 侧元数据。
+
+ Returns:
+ DeliveryBatch: 规范化后的批量出站回执。
+ """
+ drivers = self.resolve_drivers(route_key)
+ if not drivers:
+ return DeliveryBatch(internal_message_id=message.message_id, route_key=route_key)
+
+ receipts: List[DeliveryReceipt] = []
+ for driver in drivers:
+ try:
+ self._outbound_tracker.begin_tracking(
+ internal_message_id=message.message_id,
+ route_key=route_key,
+ driver_id=driver.driver_id,
+ metadata=metadata,
+ )
+ except ValueError as exc:
+ receipts.append(
+ DeliveryReceipt(
+ internal_message_id=message.message_id,
+ route_key=route_key,
+ status=DeliveryStatus.FAILED,
+ driver_id=driver.driver_id,
+ driver_kind=driver.descriptor.kind,
+ error=str(exc),
+ )
+ )
+ continue
+
+ try:
+ receipt = await driver.send_message(message=message, route_key=route_key, metadata=metadata)
+ except Exception as exc:
+ receipt = DeliveryReceipt(
+ internal_message_id=message.message_id,
+ route_key=route_key,
+ status=DeliveryStatus.FAILED,
+ driver_id=driver.driver_id,
+ driver_kind=driver.descriptor.kind,
+ error=str(exc),
+ )
+
+ self._outbound_tracker.finish_tracking(receipt)
+ receipts.append(receipt)
+
+ return DeliveryBatch(
+ internal_message_id=message.message_id,
+ route_key=route_key,
+ receipts=receipts,
+ )
+
+ @staticmethod
+ def _build_inbound_dedupe_key(envelope: InboundMessageEnvelope) -> Optional[str]:
+ """构造用于入站抑制的去重键。
+
+ Args:
+ envelope: 当前正在处理的入站封装。
+
+ Returns:
+ Optional[str]: 若可以构造稳定去重键则返回该键,否则返回 ``None``。
+
+ Notes:
+ 这里仅接受上游显式提供的稳定消息身份,例如 ``dedupe_key``、
+ 平台侧 ``external_message_id`` 或已经完成规范化的
+ ``session_message.message_id``。Broker 不再根据 ``payload`` 内容
+ 猜测语义去重键,避免把“短时间内两条内容刚好完全相同”的合法消息
+ 误判为重复入站。
+ """
+ raw_dedupe_key = envelope.dedupe_key or envelope.external_message_id
+ if raw_dedupe_key is None and envelope.session_message is not None:
+ raw_dedupe_key = envelope.session_message.message_id
+ if raw_dedupe_key is None:
+ return None
+
+ normalized_dedupe_key = str(raw_dedupe_key).strip()
+ if not normalized_dedupe_key:
+ return None
+
+ return f"{envelope.driver_id}:{normalized_dedupe_key}"
+
+ @staticmethod
+ def _validate_binding_against_driver(binding: RouteBinding, driver: PlatformIODriver) -> None:
+ """校验路由绑定与驱动描述是否一致。
+
+ Args:
+ binding: 待校验的路由绑定。
+ driver: 被绑定的驱动实例。
+
+ Raises:
+ ValueError: 当绑定类型、平台或更细粒度路由维度与驱动描述冲突时抛出。
+ """
+ descriptor = driver.descriptor
+ if binding.driver_kind != descriptor.kind:
+ raise ValueError(
+ f"路由绑定的 driver_kind={binding.driver_kind} 与驱动 {driver.driver_id} 的类型 "
+ f"{descriptor.kind} 不一致"
+ )
+
+ if binding.route_key.platform != descriptor.platform:
+ raise ValueError(
+ f"路由绑定的平台 {binding.route_key.platform} 与驱动 {driver.driver_id} 的平台 "
+ f"{descriptor.platform} 不一致"
+ )
+
+ if descriptor.account_id is not None and binding.route_key.account_id not in (None, descriptor.account_id):
+ raise ValueError(
+ f"路由绑定的 account_id={binding.route_key.account_id} 与驱动 {driver.driver_id} 的 "
+ f"account_id={descriptor.account_id} 冲突"
+ )
+
+ if descriptor.scope is not None and binding.route_key.scope not in (None, descriptor.scope):
+ raise ValueError(
+ f"路由绑定的 scope={binding.route_key.scope} 与驱动 {driver.driver_id} 的 "
+ f"scope={descriptor.scope} 冲突"
+ )
+
+
+_platform_io_manager: Optional[PlatformIOManager] = None
+
+
+def get_platform_io_manager() -> PlatformIOManager:
+ """返回全局 ``PlatformIOManager`` 单例。
+
+ Returns:
+ PlatformIOManager: 进程级共享的 Broker 管理器实例。
+ """
+
+ global _platform_io_manager
+ if _platform_io_manager is None:
+ _platform_io_manager = PlatformIOManager()
+ return _platform_io_manager
diff --git a/src/platform_io/outbound_tracker.py b/src/platform_io/outbound_tracker.py
new file mode 100644
index 00000000..3725691f
--- /dev/null
+++ b/src/platform_io/outbound_tracker.py
@@ -0,0 +1,286 @@
+"""跟踪 Platform IO 层的出站投递状态。
+
+当前实现基于两组 ``dict + heapq``:
+- ``_pending`` 和 ``_pending_expire_heap`` 负责管理待完成的出站记录
+- ``_receipts_by_external_id`` 和 ``_receipt_expire_heap`` 负责管理已完成回执索引
+
+这样就不需要在每次读写时全表扫描过期项,而是通过懒清理逐步弹出已经过期
+或已经失效的堆节点。
+"""
+
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Tuple
+
+import heapq
+import time
+
+from .types import DeliveryReceipt, RouteKey
+
+
+@dataclass(slots=True)
+class PendingOutboundRecord:
+ """表示一条仍在等待完成的出站投递记录。
+
+ Attributes:
+ internal_message_id: 正在跟踪的内部 ``SessionMessage.message_id``。
+ route_key: 该出站投递开始时使用的路由键。
+ driver_id: 负责这次出站投递的驱动 ID。
+ created_at: 开始跟踪时记录的单调时钟时间戳。
+ expires_at: 该待完成记录预计过期的单调时钟时间戳。
+ metadata: 与待完成记录一同保留的额外 Broker 侧元数据。
+ """
+
+ internal_message_id: str
+ route_key: RouteKey
+ driver_id: str
+ created_at: float = field(default_factory=time.monotonic)
+ expires_at: float = 0.0
+ metadata: Dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass(slots=True)
+class StoredDeliveryReceipt:
+ """表示一条已完成并暂存的出站回执。
+
+ Attributes:
+ receipt: 规范化后的出站投递回执。
+ stored_at: 回执被写入索引时记录的单调时钟时间戳。
+ expires_at: 该回执索引预计过期的单调时钟时间戳。
+ """
+
+ receipt: DeliveryReceipt
+ stored_at: float = field(default_factory=time.monotonic)
+ expires_at: float = 0.0
+
+
+class OutboundTracker:
+ """统一跟踪出站消息的 pending 状态与最终回执。
+
+ 主要用于解决出站消息在发送过程中“状态散落在不同路径里”的问题:
+ - 发送开始后,需要在最终回执返回前保留一份 pending 状态
+ - 平台返回 ``external_message_id`` 后,需要保留一段时间的回执索引
+
+ 当前实现使用 ``dict + heapq`` 做 TTL 管理:
+ - ``dict`` 提供 ``O(1)`` 级别的主键查询
+ - ``heapq`` 提供按过期时间排序的懒清理能力
+
+ 这比“每次 begin/finish/get 都全表扫描”的实现更适合高吞吐出站场景。
+
+ Notes:
+ 复杂度说明如下,设 ``p`` 为当前有效 pending 数量,``r`` 为当前有效回执数量:
+
+ - ``begin_tracking()``、``finish_tracking()`` 的常见路径时间复杂度接近
+ ``O(log p)`` 或 ``O(log r)``
+ - ``get_pending()``、``get_receipt_by_external_id()`` 的查询本身是 ``O(1)``
+ ,连同懒清理一起看,长期摊还复杂度接近 ``O(log n)``
+ - 如果某次调用恰好触发一批过期节点的集中清理,则该次调用的最坏时间复杂度
+ 可达到 ``O(k log n)``,其中 ``k`` 为本次被弹出的节点数量
+ - 空间复杂度为 ``O(p + r)``
+ """
+
+ def __init__(self, ttl_seconds: float = 1800.0) -> None:
+ """初始化出站跟踪器。
+
+ Args:
+ ttl_seconds: 待完成记录与按外部消息 ID 建立的回执索引保留时长,
+ 单位为秒。
+
+ Raises:
+ ValueError: 当 ``ttl_seconds`` 非正数时抛出。
+ """
+ if ttl_seconds <= 0:
+ raise ValueError("ttl_seconds 必须大于 0")
+
+ self._ttl_seconds = ttl_seconds
+ self._pending: Dict[Tuple[str, str], PendingOutboundRecord] = {}
+ self._pending_expire_heap: List[Tuple[float, str, str]] = []
+ self._receipts_by_external_id: Dict[str, StoredDeliveryReceipt] = {}
+ self._receipt_expire_heap: List[Tuple[float, str]] = []
+
+ @staticmethod
+ def _build_pending_key(internal_message_id: str, driver_id: str) -> Tuple[str, str]:
+ """构造单条出站跟踪记录的唯一键。
+
+ Args:
+ internal_message_id: 内部消息 ID。
+ driver_id: 负责当前投递的驱动 ID。
+
+ Returns:
+ Tuple[str, str]: ``(internal_message_id, driver_id)`` 组合键。
+ """
+ return internal_message_id, driver_id
+
+ def begin_tracking(
+ self,
+ internal_message_id: str,
+ route_key: RouteKey,
+ driver_id: str,
+ metadata: Optional[Dict[str, Any]] = None,
+ ) -> PendingOutboundRecord:
+ """开始跟踪一次出站投递。
+
+ Args:
+ internal_message_id: 正在投递的内部消息 ID。
+ route_key: 这次出站投递选择的路由键。
+ driver_id: 负责本次投递的驱动 ID。
+ metadata: 可选的额外元数据,会一并保存在待完成记录中。
+
+ Returns:
+ PendingOutboundRecord: 新创建的待完成记录。
+
+ Raises:
+ ValueError: 当同一个 ``internal_message_id`` 与 ``driver_id`` 组合已经存在
+ 未完成记录时抛出。
+ """
+ now = time.monotonic()
+ self._cleanup_expired(now)
+ pending_key = self._build_pending_key(internal_message_id, driver_id)
+
+ if pending_key in self._pending:
+ raise ValueError(f"消息 {internal_message_id} 在驱动 {driver_id} 上已存在未完成的出站跟踪记录")
+
+ expires_at = now + self._ttl_seconds
+ record = PendingOutboundRecord(
+ internal_message_id=internal_message_id,
+ route_key=route_key,
+ driver_id=driver_id,
+ created_at=now,
+ expires_at=expires_at,
+ metadata=metadata or {},
+ )
+ self._pending[pending_key] = record
+ heapq.heappush(self._pending_expire_heap, (expires_at, internal_message_id, driver_id))
+ return record
+
+ def finish_tracking(self, receipt: DeliveryReceipt) -> Optional[PendingOutboundRecord]:
+ """使用最终回执结束一条出站跟踪。
+
+ Args:
+ receipt: 规范化后的最终投递回执。
+
+ Returns:
+ Optional[PendingOutboundRecord]: 若此前存在待完成记录,则返回该记录。
+ """
+ now = time.monotonic()
+ self._cleanup_expired(now)
+
+ pending_record: Optional[PendingOutboundRecord] = None
+ if receipt.driver_id:
+ pending_key = self._build_pending_key(receipt.internal_message_id, receipt.driver_id)
+ pending_record = self._pending.pop(pending_key, None)
+ else:
+ matched_records = [
+ key
+ for key, record in self._pending.items()
+ if record.internal_message_id == receipt.internal_message_id
+ ]
+ if len(matched_records) == 1:
+ pending_record = self._pending.pop(matched_records[0], None)
+
+ if receipt.external_message_id:
+ expires_at = now + self._ttl_seconds
+ self._receipts_by_external_id[receipt.external_message_id] = StoredDeliveryReceipt(
+ receipt=receipt,
+ stored_at=now,
+ expires_at=expires_at,
+ )
+ heapq.heappush(self._receipt_expire_heap, (expires_at, receipt.external_message_id))
+ return pending_record
+
+ def get_pending(
+ self,
+ internal_message_id: str,
+ driver_id: Optional[str] = None,
+ ) -> Optional[PendingOutboundRecord]:
+ """根据内部消息 ID 查询待完成记录。
+
+ Args:
+ internal_message_id: 要查询的内部消息 ID。
+ driver_id: 可选的驱动 ID;提供后仅返回该驱动上的待完成记录。
+
+ Returns:
+ Optional[PendingOutboundRecord]: 若记录仍存在,则返回对应待完成记录。
+ """
+ self._cleanup_expired(time.monotonic())
+
+ if driver_id:
+ return self._pending.get(self._build_pending_key(internal_message_id, driver_id))
+
+ matched_records = [
+ record
+ for record in self._pending.values()
+ if record.internal_message_id == internal_message_id
+ ]
+ if len(matched_records) == 1:
+ return matched_records[0]
+ return None
+
+ def get_receipt_by_external_id(self, external_message_id: str) -> Optional[DeliveryReceipt]:
+ """根据外部平台消息 ID 查询已完成回执。
+
+ Args:
+ external_message_id: 要查询的平台侧消息 ID。
+
+ Returns:
+ Optional[DeliveryReceipt]: 若存在对应回执,则返回该回执。
+ """
+ self._cleanup_expired(time.monotonic())
+ stored_receipt = self._receipts_by_external_id.get(external_message_id)
+ return stored_receipt.receipt if stored_receipt else None
+
+ def clear(self) -> None:
+ """清空全部待完成记录与已保存回执。"""
+ self._pending.clear()
+ self._pending_expire_heap.clear()
+ self._receipts_by_external_id.clear()
+ self._receipt_expire_heap.clear()
+
+ def _cleanup_expired(self, now: float) -> None:
+ """清理内存中已经过期的待完成记录与已保存回执。
+
+ Args:
+ now: 当前单调时钟时间戳。
+ """
+ self._cleanup_expired_pending(now)
+ self._cleanup_expired_receipts(now)
+
+ def _cleanup_expired_pending(self, now: float) -> None:
+ """清理已经过期的待完成记录。
+
+ Args:
+ now: 当前单调时钟时间戳。
+
+ Notes:
+ 堆中可能存在已经失效的旧节点。例如某条记录提前 ``finish`` 后,
+ 它原本的过期节点仍可能留在堆里。这里会通过和 ``dict`` 中当前记录的
+ ``expires_at`` 对比,跳过这类旧节点。
+ """
+ while self._pending_expire_heap and self._pending_expire_heap[0][0] <= now:
+ expires_at, internal_message_id, driver_id = heapq.heappop(self._pending_expire_heap)
+ pending_key = self._build_pending_key(internal_message_id, driver_id)
+ current_record = self._pending.get(pending_key)
+ if current_record is None:
+ continue
+ if current_record.expires_at != expires_at:
+ continue
+ self._pending.pop(pending_key, None)
+
+ def _cleanup_expired_receipts(self, now: float) -> None:
+ """清理已经过期的回执索引。
+
+ Args:
+ now: 当前单调时钟时间戳。
+
+ Notes:
+ 同一个 ``external_message_id`` 在极端情况下可能被重复写入索引,
+ 因此这里同样需要通过 ``expires_at`` 和当前 ``dict`` 中的值比对,
+ 跳过已经失效的旧堆节点。
+ """
+ while self._receipt_expire_heap and self._receipt_expire_heap[0][0] <= now:
+ expires_at, external_message_id = heapq.heappop(self._receipt_expire_heap)
+ current_receipt = self._receipts_by_external_id.get(external_message_id)
+ if current_receipt is None:
+ continue
+ if current_receipt.expires_at != expires_at:
+ continue
+ self._receipts_by_external_id.pop(external_message_id, None)
diff --git a/src/platform_io/registry.py b/src/platform_io/registry.py
new file mode 100644
index 00000000..9ad8ea8a
--- /dev/null
+++ b/src/platform_io/registry.py
@@ -0,0 +1,70 @@
+"""提供 Platform IO 的驱动注册与查询能力。"""
+
+from typing import Dict, List, Optional
+
+from src.platform_io.drivers.base import PlatformIODriver
+from src.platform_io.types import DriverKind
+
+
+class DriverRegistry:
+ """集中保存已注册的 Platform IO 驱动,并提供基础查询接口。"""
+
+ def __init__(self) -> None:
+ """初始化一个空的驱动注册表。"""
+ self._drivers: Dict[str, PlatformIODriver] = {}
+
+ def register(self, driver: PlatformIODriver) -> None:
+ """注册一个驱动实例。
+
+ Args:
+ driver: 要注册的驱动实例。
+
+ Raises:
+ ValueError: 当驱动 ID 已经存在时抛出。
+ """
+ if driver.driver_id in self._drivers:
+ raise ValueError(f"驱动 {driver.driver_id} 已注册")
+ self._drivers[driver.driver_id] = driver
+
+ def unregister(self, driver_id: str) -> Optional[PlatformIODriver]:
+ """按驱动 ID 注销一个驱动。
+
+ Args:
+ driver_id: 要移除的驱动 ID。
+
+ Returns:
+ Optional[PlatformIODriver]: 若驱动存在,则返回被移除的驱动实例。
+ """
+ return self._drivers.pop(driver_id, None)
+
+ def get(self, driver_id: str) -> Optional[PlatformIODriver]:
+ """按驱动 ID 获取驱动实例。
+
+ Args:
+ driver_id: 要查询的驱动 ID。
+
+ Returns:
+ Optional[PlatformIODriver]: 若存在匹配驱动,则返回该驱动实例。
+ """
+ return self._drivers.get(driver_id)
+
+ def list(self, *, kind: Optional[DriverKind] = None, platform: Optional[str] = None) -> List[PlatformIODriver]:
+ """列出已注册驱动,并支持可选过滤。
+
+ Args:
+ kind: 可选的驱动类型过滤条件。
+ platform: 可选的平台名称过滤条件。
+
+ Returns:
+ List[PlatformIODriver]: 符合过滤条件的驱动列表。
+ """
+ drivers = list(self._drivers.values())
+ if kind is not None:
+ drivers = [driver for driver in drivers if driver.descriptor.kind == kind]
+ if platform is not None:
+ drivers = [driver for driver in drivers if driver.descriptor.platform == platform]
+ return drivers
+
+ def clear(self) -> None:
+ """清空全部已注册驱动。"""
+ self._drivers.clear()
diff --git a/src/platform_io/route_key_factory.py b/src/platform_io/route_key_factory.py
new file mode 100644
index 00000000..05bac6e8
--- /dev/null
+++ b/src/platform_io/route_key_factory.py
@@ -0,0 +1,150 @@
+"""提供 Platform IO 路由键的统一提取与构造能力。
+
+这层的目标不是直接接入具体消息链,而是先把“未来接线时用什么字段构造
+RouteKey”约定下来,避免 legacy 和 plugin 两条链路各自发明一套隐式规则。
+"""
+
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
+
+from .types import RouteKey
+
+if TYPE_CHECKING:
+ from src.chat.message_receive.message import SessionMessage
+
+
+class RouteKeyFactory:
+ """统一构造 ``RouteKey`` 的工厂。
+
+ 当前约定会优先从消息字典顶层、``message_info``、``additional_config`` 或传入 metadata 中提取
+ 以下字段:
+
+ - account_id: ``platform_io_account_id`` / ``account_id`` / ``self_id`` / ``bot_account``
+ - scope: ``platform_io_scope`` / ``route_scope`` / ``adapter_scope`` / ``connection_id``
+
+ 这样即使上游主链暂时还没有正式的 ``self_id`` 字段,中间层也能先统一
+ 约定提取口径,等具体消息链接入时直接复用。
+ """
+
+ ACCOUNT_ID_KEYS = (
+ "platform_io_account_id",
+ "account_id",
+ "self_id",
+ "bot_account",
+ )
+ SCOPE_KEYS = (
+ "platform_io_scope",
+ "route_scope",
+ "adapter_scope",
+ "connection_id",
+ )
+
+ @classmethod
+ def from_platform(
+ cls,
+ platform: str,
+ *,
+ account_id: Optional[str] = None,
+ scope: Optional[str] = None,
+ metadata: Optional[Dict[str, Any]] = None,
+ ) -> RouteKey:
+ """根据平台名和可选 metadata 构造 ``RouteKey``。
+
+ Args:
+ platform: 平台名称。
+ account_id: 显式传入的账号 ID;若为空,则尝试从 metadata 提取。
+ scope: 显式传入的路由作用域;若为空,则尝试从 metadata 提取。
+ metadata: 可选的元数据字典。
+
+ Returns:
+ RouteKey: 构造出的规范化路由键。
+ """
+ extracted_account_id, extracted_scope = cls.extract_components(metadata)
+ return RouteKey(
+ platform=platform,
+ account_id=account_id or extracted_account_id,
+ scope=scope or extracted_scope,
+ )
+
+ @classmethod
+ def from_message_dict(cls, message_dict: Dict[str, Any]) -> RouteKey:
+ """从消息字典中提取 ``RouteKey``。
+
+ Args:
+ message_dict: Host 与插件之间传输的消息字典。
+
+ Returns:
+ RouteKey: 构造出的规范化路由键。
+
+ Raises:
+ ValueError: 当消息字典缺少有效 ``platform`` 字段时抛出。
+ """
+ platform = str(message_dict.get("platform") or "").strip()
+ if not platform:
+ raise ValueError("消息字典缺少有效的 platform 字段,无法构造 RouteKey")
+
+ message_info = message_dict.get("message_info", {})
+ additional_config = {}
+ if isinstance(message_info, dict):
+ raw_additional_config = message_info.get("additional_config", {})
+ if isinstance(raw_additional_config, dict):
+ additional_config = raw_additional_config
+
+ explicit_account_id, explicit_scope = cls.extract_components(message_dict)
+ message_info_account_id, message_info_scope = cls.extract_components(message_info)
+ metadata_account_id, metadata_scope = cls.extract_components(additional_config)
+ return RouteKey(
+ platform=platform,
+ account_id=explicit_account_id or message_info_account_id or metadata_account_id,
+ scope=explicit_scope or message_info_scope or metadata_scope,
+ )
+
+ @classmethod
+ def from_session_message(cls, message: "SessionMessage") -> RouteKey:
+ """从 ``SessionMessage`` 中提取 ``RouteKey``。
+
+ Args:
+ message: 内部会话消息对象。
+
+ Returns:
+ RouteKey: 构造出的规范化路由键。
+ """
+ additional_config = message.message_info.additional_config or {}
+ metadata = additional_config if isinstance(additional_config, dict) else {}
+ return cls.from_platform(message.platform, metadata=metadata)
+
+ @classmethod
+ def extract_components(cls, mapping: Optional[Dict[str, Any]]) -> Tuple[Optional[str], Optional[str]]:
+ """从任意字典中提取 ``account_id`` 与 ``scope``。
+
+ Args:
+ mapping: 待提取的字典;若为空或不是字典,则返回空结果。
+
+ Returns:
+ Tuple[Optional[str], Optional[str]]: ``(account_id, scope)``。
+ """
+ if not mapping or not isinstance(mapping, dict):
+ return None, None
+
+ account_id = cls._pick_string(mapping, cls.ACCOUNT_ID_KEYS)
+ scope = cls._pick_string(mapping, cls.SCOPE_KEYS)
+ return account_id, scope
+
+ @staticmethod
+ def _pick_string(mapping: Dict[str, Any], keys: Tuple[str, ...]) -> Optional[str]:
+ """按优先级从字典里挑选第一个有效字符串。
+
+ Args:
+ mapping: 待查询的字典。
+ keys: 按优先级排列的候选键名。
+
+ Returns:
+ Optional[str]: 第一个规范化后非空的字符串值;若不存在则返回 ``None``。
+ """
+ for key in keys:
+ value = mapping.get(key)
+ if value is None:
+ continue
+ normalized = str(value).strip()
+ if normalized:
+ return normalized
+ return None
diff --git a/src/platform_io/routing.py b/src/platform_io/routing.py
new file mode 100644
index 00000000..2a9b41ef
--- /dev/null
+++ b/src/platform_io/routing.py
@@ -0,0 +1,141 @@
+"""提供 Platform IO 的轻量路由绑定表。"""
+
+from typing import Dict, List, Optional
+
+from .types import RouteBinding, RouteKey
+
+
+class RouteTable:
+ """维护单张路由绑定表。
+
+ 该实现不负责裁决“唯一 owner”,只负责保存绑定,并按
+ ``RouteKey.resolution_order()`` 解析出候选绑定列表。
+ """
+
+ def __init__(self) -> None:
+ """初始化空路由绑定表。"""
+
+ self._bindings: Dict[RouteKey, Dict[str, RouteBinding]] = {}
+
+ def bind(self, binding: RouteBinding) -> None:
+ """注册或更新一条路由绑定。
+
+ Args:
+ binding: 要保存的路由绑定。
+ """
+
+ self._bindings.setdefault(binding.route_key, {})[binding.driver_id] = binding
+
+ def unbind(self, route_key: RouteKey, driver_id: Optional[str] = None) -> List[RouteBinding]:
+ """移除指定路由键上的绑定。
+
+ Args:
+ route_key: 要移除绑定的路由键。
+ driver_id: 可选的驱动 ID;为空时移除该路由键下全部绑定。
+
+ Returns:
+ List[RouteBinding]: 被移除的绑定列表。
+ """
+
+ binding_map = self._bindings.get(route_key)
+ if not binding_map:
+ return []
+
+ if driver_id is None:
+ removed = list(binding_map.values())
+ self._bindings.pop(route_key, None)
+ return self._sort_bindings(removed)
+
+ removed_binding = binding_map.pop(driver_id, None)
+ if not binding_map:
+ self._bindings.pop(route_key, None)
+ return [removed_binding] if removed_binding is not None else []
+
+ def remove_bindings_by_driver(self, driver_id: str) -> List[RouteBinding]:
+ """移除某个驱动在整张表上的全部绑定。
+
+ Args:
+ driver_id: 要移除绑定的驱动 ID。
+
+ Returns:
+ List[RouteBinding]: 被移除的绑定列表。
+ """
+
+ removed_bindings: List[RouteBinding] = []
+ empty_route_keys: List[RouteKey] = []
+ for route_key, binding_map in self._bindings.items():
+ removed_binding = binding_map.pop(driver_id, None)
+ if removed_binding is not None:
+ removed_bindings.append(removed_binding)
+ if not binding_map:
+ empty_route_keys.append(route_key)
+
+ for route_key in empty_route_keys:
+ self._bindings.pop(route_key, None)
+
+ return self._sort_bindings(removed_bindings)
+
+ def list_bindings(self, route_key: Optional[RouteKey] = None) -> List[RouteBinding]:
+ """列出当前路由表中的绑定。
+
+ Args:
+ route_key: 可选的路由键过滤条件。
+
+ Returns:
+ List[RouteBinding]: 当前绑定列表。
+ """
+
+ if route_key is None:
+ bindings: List[RouteBinding] = []
+ for binding_map in self._bindings.values():
+ bindings.extend(binding_map.values())
+ return self._sort_bindings(bindings)
+
+ binding_map = self._bindings.get(route_key, {})
+ return self._sort_bindings(list(binding_map.values()))
+
+ def resolve_bindings(self, route_key: RouteKey) -> List[RouteBinding]:
+ """按从具体到宽泛的顺序解析路由候选绑定。
+
+ Args:
+ route_key: 待解析的路由键。
+
+ Returns:
+ List[RouteBinding]: 去重后的候选绑定列表。
+ """
+
+ resolved_bindings: List[RouteBinding] = []
+ seen_driver_ids: set[str] = set()
+ for candidate_key in route_key.resolution_order():
+ for binding in self.list_bindings(candidate_key):
+ if binding.driver_id in seen_driver_ids:
+ continue
+ seen_driver_ids.add(binding.driver_id)
+ resolved_bindings.append(binding)
+ return resolved_bindings
+
+ def has_binding_for_driver(self, route_key: RouteKey, driver_id: str) -> bool:
+ """判断指定驱动是否在当前路由键解析结果中。
+
+ Args:
+ route_key: 待解析的路由键。
+ driver_id: 目标驱动 ID。
+
+ Returns:
+ bool: 若驱动存在于解析结果中则返回 ``True``。
+ """
+
+ return any(binding.driver_id == driver_id for binding in self.resolve_bindings(route_key))
+
+ @staticmethod
+ def _sort_bindings(bindings: List[RouteBinding]) -> List[RouteBinding]:
+ """按优先级降序排列绑定列表。
+
+ Args:
+ bindings: 待排序的绑定列表。
+
+ Returns:
+ List[RouteBinding]: 排序后的绑定列表。
+ """
+
+ return sorted(bindings, key=lambda item: item.priority, reverse=True)
diff --git a/src/platform_io/types.py b/src/platform_io/types.py
new file mode 100644
index 00000000..200eca51
--- /dev/null
+++ b/src/platform_io/types.py
@@ -0,0 +1,264 @@
+"""定义 Platform IO 中间层共享的核心类型。
+
+本模块放置路由、驱动、入站与出站等规范化数据结构,供 Broker
+层在 legacy 适配器链路和 plugin 适配器链路之间复用。
+"""
+
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+if TYPE_CHECKING:
+ from src.chat.message_receive.message import SessionMessage
+
+
+class DriverKind(str, Enum):
+ """底层收发驱动类型枚举。"""
+
+ LEGACY = "legacy"
+ PLUGIN = "plugin"
+
+
+class DeliveryStatus(str, Enum):
+ """统一出站回执状态枚举。"""
+
+ PENDING = "pending"
+ SENT = "sent"
+ FAILED = "failed"
+ DROPPED = "dropped"
+
+
+@dataclass(frozen=True, slots=True)
+class RouteKey:
+ """用于 Platform IO 路由决策的唯一键。
+
+ 路由解析会按照“从最具体到最宽泛”的顺序进行回退,这样同一平台
+ 后续就能自然支持按账号、自定义 scope 等更细粒度的归属控制。
+
+ Attributes:
+ platform: 平台名称,例如 ``qq``。
+ account_id: 机器人账号 ID 或 self ID,用于区分同平台多身份。
+ scope: 额外路由作用域,预留给未来的连接实例、租户或子通道等维度。
+ """
+
+ platform: str
+ account_id: Optional[str] = None
+ scope: Optional[str] = None
+
+ def __post_init__(self) -> None:
+ """规范化并校验路由键字段。
+
+ Raises:
+ ValueError: 当 ``platform`` 规范化后为空时抛出。
+ """
+ platform = str(self.platform).strip()
+ account_id = str(self.account_id).strip() if self.account_id is not None else None
+ scope = str(self.scope).strip() if self.scope is not None else None
+
+ if not platform:
+ raise ValueError("RouteKey.platform 不能为空")
+
+ object.__setattr__(self, "platform", platform)
+ object.__setattr__(self, "account_id", account_id or None)
+ object.__setattr__(self, "scope", scope or None)
+
+ def resolution_order(self) -> List["RouteKey"]:
+ """返回从最具体到最宽泛的路由匹配顺序。
+
+ Returns:
+ List[RouteKey]: 按回退优先级排序的候选路由键列表。
+ """
+
+ keys: List[RouteKey] = [self]
+
+ if self.account_id is not None and self.scope is not None:
+ keys.append(RouteKey(platform=self.platform, account_id=self.account_id, scope=None))
+ keys.append(RouteKey(platform=self.platform, account_id=None, scope=self.scope))
+ elif self.account_id is not None:
+ keys.append(RouteKey(platform=self.platform, account_id=None, scope=None))
+ elif self.scope is not None:
+ keys.append(RouteKey(platform=self.platform, account_id=None, scope=None))
+
+ default_key = RouteKey(platform=self.platform, account_id=None, scope=None)
+ if default_key not in keys:
+ keys.append(default_key)
+
+ return keys
+
+ def to_dedupe_scope(self) -> str:
+ """生成跨驱动共享的去重作用域字符串。
+
+ Returns:
+ str: 用于入站消息去重的稳定文本作用域键。
+ """
+
+ account_id = self.account_id or "*"
+ scope = self.scope or "*"
+ return f"{self.platform}:{account_id}:{scope}"
+
+
+@dataclass(frozen=True, slots=True)
+class DriverDescriptor:
+ """描述一个已注册的 Platform IO 驱动。
+
+ Attributes:
+ driver_id: Broker 层内全局唯一的驱动标识。
+ kind: 驱动实现类型,例如 legacy 或 plugin。
+ platform: 驱动负责的平台名称。
+ account_id: 可选的账号 ID 或 self ID。
+ scope: 可选的额外路由作用域。
+ plugin_id: 当驱动来自插件适配器时,对应的插件 ID。
+ metadata: 预留给路由策略或观测能力的额外驱动元数据。
+ """
+
+ driver_id: str
+ kind: DriverKind
+ platform: str
+ account_id: Optional[str] = None
+ scope: Optional[str] = None
+ plugin_id: Optional[str] = None
+ metadata: Dict[str, Any] = field(default_factory=dict)
+
+ def __post_init__(self) -> None:
+ """规范化并校验驱动描述字段。
+
+ Raises:
+ ValueError: 当 ``driver_id`` 或 ``platform`` 规范化后为空时抛出。
+ """
+ driver_id = str(self.driver_id).strip()
+ platform = str(self.platform).strip()
+ plugin_id = str(self.plugin_id).strip() if self.plugin_id is not None else None
+
+ if not driver_id:
+ raise ValueError("DriverDescriptor.driver_id 不能为空")
+ if not platform:
+ raise ValueError("DriverDescriptor.platform 不能为空")
+
+ object.__setattr__(self, "driver_id", driver_id)
+ object.__setattr__(self, "platform", platform)
+ object.__setattr__(self, "plugin_id", plugin_id or None)
+
+ @property
+ def route_key(self) -> RouteKey:
+ """构造该驱动默认代表的路由键。
+
+ Returns:
+ RouteKey: 当前驱动描述对应的规范化路由键。
+ """
+ return RouteKey(platform=self.platform, account_id=self.account_id, scope=self.scope)
+
+
+@dataclass(frozen=True, slots=True)
+class RouteBinding:
+ """表示一条从路由键到驱动的绑定关系。
+
+ Attributes:
+ route_key: 该绑定覆盖的路由键。
+ driver_id: 拥有该路由的驱动 ID。
+ driver_kind: 绑定驱动的类型。
+ priority: 当同一路由键存在多条绑定时使用的相对优先级。
+ metadata: 预留给未来路由策略的额外绑定元数据。
+ """
+
+ route_key: RouteKey
+ driver_id: str
+ driver_kind: DriverKind
+ priority: int = 0
+ metadata: Dict[str, Any] = field(default_factory=dict)
+
+ def __post_init__(self) -> None:
+ """规范化并校验绑定字段。
+
+ Raises:
+ ValueError: 当 ``driver_id`` 规范化后为空时抛出。
+ """
+ driver_id = str(self.driver_id).strip()
+ if not driver_id:
+ raise ValueError("RouteBinding.driver_id 不能为空")
+ object.__setattr__(self, "driver_id", driver_id)
+
+
+@dataclass(slots=True)
+class InboundMessageEnvelope:
+ """封装一次由驱动产出的规范化入站消息。
+
+ Attributes:
+ route_key: 该入站消息解析出的路由键。
+ driver_id: 产出该消息的驱动 ID。
+ driver_kind: 产出该消息的驱动类型。
+ external_message_id: 可选的平台侧消息 ID,用于去重。
+ dedupe_key: 可选的显式去重键。当外部消息没有稳定 ``message_id`` 时,
+ 可由上游驱动提供稳定的技术性幂等键。若这里为空,中间层仅会继续
+ 回退到 ``external_message_id`` 或 ``session_message.message_id``,
+ 不会再根据 ``payload`` 内容猜测语义去重键。
+ session_message: 可选的、已经完成规范化的 ``SessionMessage`` 对象。
+ payload: 可选的原始字典载荷,供延迟转换或调试使用。
+ metadata: 额外入站元数据,例如连接信息或追踪上下文。
+ """
+
+ route_key: RouteKey
+ driver_id: str
+ driver_kind: DriverKind
+ external_message_id: Optional[str] = None
+ dedupe_key: Optional[str] = None
+ session_message: Optional["SessionMessage"] = None
+ payload: Optional[Dict[str, Any]] = None
+ metadata: Dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass(slots=True)
+class DeliveryReceipt:
+ """表示一次出站投递尝试的统一结果。
+
+ Attributes:
+ internal_message_id: Broker 跟踪的内部 ``SessionMessage.message_id``。
+ route_key: 本次投递使用的路由键。
+ status: 规范化后的投递状态。
+ driver_id: 实际处理该投递的驱动 ID,可为空。
+ driver_kind: 实际处理该投递的驱动类型,可为空。
+ external_message_id: 驱动或适配器返回的平台侧消息 ID,可为空。
+ error: 投递失败时的错误信息,可为空。
+ metadata: 预留给回执、时间戳或平台特有信息的额外元数据。
+ """
+
+ internal_message_id: str
+ route_key: RouteKey
+ status: DeliveryStatus
+ driver_id: Optional[str] = None
+ driver_kind: Optional[DriverKind] = None
+ external_message_id: Optional[str] = None
+ error: Optional[str] = None
+ metadata: Dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass(slots=True)
+class DeliveryBatch:
+ """表示一次广播式出站投递的批量结果。
+
+ Attributes:
+ internal_message_id: 内部消息 ID。
+ route_key: 本次投递使用的路由键。
+ receipts: 各条路由的独立投递回执列表。
+ """
+
+ internal_message_id: str
+ route_key: RouteKey
+ receipts: List[DeliveryReceipt] = field(default_factory=list)
+
+ @property
+ def sent_receipts(self) -> List[DeliveryReceipt]:
+ """返回全部发送成功的回执。"""
+
+ return [receipt for receipt in self.receipts if receipt.status == DeliveryStatus.SENT]
+
+ @property
+ def failed_receipts(self) -> List[DeliveryReceipt]:
+ """返回全部发送失败的回执。"""
+
+ return [receipt for receipt in self.receipts if receipt.status != DeliveryStatus.SENT]
+
+ @property
+ def has_success(self) -> bool:
+ """返回当前批量投递是否至少命中一条成功回执。"""
+
+ return bool(self.sent_receipts)
diff --git a/src/plugin_runtime/__init__.py b/src/plugin_runtime/__init__.py
index a881d399..7f2d789f 100644
--- a/src/plugin_runtime/__init__.py
+++ b/src/plugin_runtime/__init__.py
@@ -16,3 +16,9 @@ ENV_PLUGIN_DIRS = "MAIBOT_PLUGIN_DIRS"
ENV_HOST_VERSION = "MAIBOT_HOST_VERSION"
"""Runner 读取的 Host 应用版本号,用于 manifest 兼容性校验"""
+
+ENV_EXTERNAL_PLUGIN_IDS = "MAIBOT_EXTERNAL_PLUGIN_IDS"
+"""Runner 启动时可视为已满足的外部插件依赖版本映射(JSON 对象)"""
+
+ENV_GLOBAL_CONFIG_SNAPSHOT = "MAIBOT_GLOBAL_CONFIG_SNAPSHOT"
+"""Runner 启动时注入的全局配置快照(JSON 对象)"""
diff --git a/src/plugin_runtime/capabilities/components.py b/src/plugin_runtime/capabilities/components.py
index aa7ceb46..33b54c64 100644
--- a/src/plugin_runtime/capabilities/components.py
+++ b/src/plugin_runtime/capabilities/components.py
@@ -1,12 +1,13 @@
from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Protocol
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Protocol, Sequence
from src.common.logger import get_logger
logger = get_logger("plugin_runtime.integration")
if TYPE_CHECKING:
- from src.plugin_runtime.host.component_registry import RegisteredComponent
+ from src.plugin_runtime.host.api_registry import APIEntry
+ from src.plugin_runtime.host.component_registry import ComponentEntry
from src.plugin_runtime.host.supervisor import PluginSupervisor
@@ -14,18 +15,311 @@ class _RuntimeComponentManagerProtocol(Protocol):
@property
def supervisors(self) -> List["PluginSupervisor"]: ...
+ def _normalize_component_type(self, component_type: str) -> str: ...
+
+ def _is_api_component_type(self, component_type: str) -> bool: ...
+
+ def _serialize_api_entry(self, entry: "APIEntry") -> Dict[str, Any]: ...
+
+ def _serialize_api_component_entry(self, entry: "APIEntry") -> Dict[str, Any]: ...
+
+ def _is_api_visible_to_plugin(self, entry: "APIEntry", caller_plugin_id: str) -> bool: ...
+
+ def _normalize_api_reference(self, api_name: str, version: str = "") -> tuple[str, str]: ...
+
+ def _build_api_unavailable_error(self, entry: "APIEntry") -> str: ...
+
def _get_supervisor_for_plugin(self, plugin_id: str) -> Optional["PluginSupervisor"]: ...
+ def _resolve_api_target(
+ self,
+ caller_plugin_id: str,
+ api_name: str,
+ version: str = "",
+ ) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]: ...
+
+ def _resolve_api_toggle_target(
+ self,
+ name: str,
+ version: str = "",
+ ) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]: ...
+
def _resolve_component_toggle_target(
self, name: str, component_type: str
- ) -> tuple[Optional["RegisteredComponent"], Optional[str]]: ...
+ ) -> tuple[Optional["ComponentEntry"], Optional[str]]: ...
def _find_duplicate_plugin_ids(self, plugin_dirs: List[Path]) -> Dict[str, List[Path]]: ...
def _iter_plugin_dirs(self) -> Iterable[Path]: ...
+ async def load_plugin_globally(self, plugin_id: str, reason: str = "manual") -> bool: ...
+
+ async def reload_plugins_globally(self, plugin_ids: Sequence[str], reason: str = "manual") -> bool: ...
+
class RuntimeComponentCapabilityMixin:
+ @staticmethod
+ def _normalize_component_type(component_type: str) -> str:
+ """规范化组件类型名称。
+
+ Args:
+ component_type: 原始组件类型。
+
+ Returns:
+ str: 统一转为大写后的组件类型名。
+ """
+
+ return str(component_type or "").strip().upper()
+
+ @classmethod
+ def _is_api_component_type(cls, component_type: str) -> bool:
+ """判断组件类型是否为 API。
+
+ Args:
+ component_type: 原始组件类型。
+
+ Returns:
+ bool: 是否为 API 组件类型。
+ """
+
+ return cls._normalize_component_type(component_type) == "API"
+
+ @staticmethod
+ def _serialize_api_entry(entry: "APIEntry") -> Dict[str, Any]:
+ """将 API 组件条目序列化为能力返回值。
+
+ Args:
+ entry: API 组件条目。
+
+ Returns:
+ Dict[str, Any]: 适合通过能力层返回给插件的 API 元信息。
+ """
+
+ return {
+ "name": entry.name,
+ "full_name": entry.full_name,
+ "plugin_id": entry.plugin_id,
+ "description": entry.description,
+ "version": entry.version,
+ "public": entry.public,
+ "enabled": entry.enabled,
+ "dynamic": entry.dynamic,
+ "offline_reason": entry.offline_reason,
+ "metadata": dict(entry.metadata),
+ }
+
+ @classmethod
+ def _serialize_api_component_entry(cls, entry: "APIEntry") -> Dict[str, Any]:
+ """将 API 条目序列化为通用组件视图。
+
+ Args:
+ entry: API 组件条目。
+
+ Returns:
+ Dict[str, Any]: 适合 ``component.get_all_plugins`` 返回的组件结构。
+ """
+
+ serialized_entry = cls._serialize_api_entry(entry)
+ return {
+ "name": serialized_entry["name"],
+ "full_name": serialized_entry["full_name"],
+ "type": "API",
+ "enabled": serialized_entry["enabled"],
+ "metadata": serialized_entry["metadata"],
+ }
+
+ @staticmethod
+ def _is_api_visible_to_plugin(entry: "APIEntry", caller_plugin_id: str) -> bool:
+ """判断某个 API 是否对调用方可见。
+
+ Args:
+ entry: 目标 API 组件条目。
+ caller_plugin_id: 调用方插件 ID。
+
+ Returns:
+ bool: 是否允许当前插件可见并调用。
+ """
+
+ return entry.plugin_id == caller_plugin_id or entry.public
+
+ @staticmethod
+ def _normalize_api_reference(api_name: str, version: str = "") -> tuple[str, str]:
+ """规范化 API 名称与版本参数。
+
+ 支持在 ``api_name`` 中直接携带 ``@version`` 后缀。
+ """
+
+ normalized_api_name = str(api_name or "").strip()
+ normalized_version = str(version or "").strip()
+ if normalized_api_name and not normalized_version and "@" in normalized_api_name:
+ candidate_name, candidate_version = normalized_api_name.rsplit("@", 1)
+ candidate_name = candidate_name.strip()
+ candidate_version = candidate_version.strip()
+ if candidate_name and candidate_version:
+ normalized_api_name = candidate_name
+ normalized_version = candidate_version
+ return normalized_api_name, normalized_version
+
+ @staticmethod
+ def _build_api_unavailable_error(entry: "APIEntry") -> str:
+ """构造 API 当前不可用时的错误信息。"""
+
+ if entry.offline_reason:
+ return entry.offline_reason
+ return f"API {entry.registry_key} 当前不可用"
+
+ def _resolve_api_target(
+ self: _RuntimeComponentManagerProtocol,
+ caller_plugin_id: str,
+ api_name: str,
+ version: str = "",
+ ) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]:
+ """解析 API 名称到唯一可调用的目标组件。
+
+ Args:
+ caller_plugin_id: 调用方插件 ID。
+ api_name: API 名称,支持 ``plugin_id.api_name`` 或唯一短名。
+ version: 可选的 API 版本。
+
+ Returns:
+ tuple[Optional[PluginSupervisor], Optional[APIEntry], Optional[str]]:
+ 解析成功时返回 ``(监督器, API 条目, None)``,失败时返回错误信息。
+ """
+
+ normalized_api_name, normalized_version = self._normalize_api_reference(api_name, version)
+ if not normalized_api_name:
+ return None, None, "缺少必要参数 api_name"
+
+ if "." in normalized_api_name:
+ target_plugin_id, target_api_name = normalized_api_name.rsplit(".", 1)
+ try:
+ supervisor = self._get_supervisor_for_plugin(target_plugin_id)
+ except RuntimeError as exc:
+ return None, None, str(exc)
+
+ if supervisor is None:
+ return None, None, f"未找到 API 提供方插件: {target_plugin_id}"
+
+ entries = supervisor.api_registry.get_apis(
+ plugin_id=target_plugin_id,
+ name=target_api_name,
+ version=normalized_version,
+ enabled_only=False,
+ )
+ visible_enabled_entries = [
+ entry
+ for entry in entries
+ if self._is_api_visible_to_plugin(entry, caller_plugin_id) and entry.enabled
+ ]
+ visible_disabled_entries = [
+ entry
+ for entry in entries
+ if self._is_api_visible_to_plugin(entry, caller_plugin_id) and not entry.enabled
+ ]
+ if len(visible_enabled_entries) == 1:
+ return supervisor, visible_enabled_entries[0], None
+ if len(visible_enabled_entries) > 1:
+ return None, None, f"API {normalized_api_name} 存在多个版本,请显式指定 version"
+ if visible_disabled_entries:
+ if len(visible_disabled_entries) == 1:
+ return None, None, self._build_api_unavailable_error(visible_disabled_entries[0])
+ return None, None, f"API {normalized_api_name} 存在多个已下线版本,请显式指定 version"
+ if any(not self._is_api_visible_to_plugin(entry, caller_plugin_id) for entry in entries):
+ return None, None, f"API {normalized_api_name} 未公开,禁止跨插件调用"
+ if normalized_version:
+ return None, None, f"未找到版本为 {normalized_version} 的 API: {normalized_api_name}"
+ return None, None, f"未找到 API: {normalized_api_name}"
+
+ visible_enabled_matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
+ visible_disabled_matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
+ hidden_match_exists = False
+ for supervisor in self.supervisors:
+ for entry in supervisor.api_registry.get_apis(
+ name=normalized_api_name,
+ version=normalized_version,
+ enabled_only=False,
+ ):
+ if self._is_api_visible_to_plugin(entry, caller_plugin_id):
+ if entry.enabled:
+ visible_enabled_matches.append((supervisor, entry))
+ else:
+ visible_disabled_matches.append((supervisor, entry))
+ else:
+ hidden_match_exists = True
+
+ if len(visible_enabled_matches) == 1:
+ return visible_enabled_matches[0][0], visible_enabled_matches[0][1], None
+ if len(visible_enabled_matches) > 1:
+ return None, None, f"API 名称不唯一: {normalized_api_name},请使用 plugin_id.api_name 或显式指定 version"
+ if visible_disabled_matches:
+ if len(visible_disabled_matches) == 1:
+ return None, None, self._build_api_unavailable_error(visible_disabled_matches[0][1])
+ return None, None, f"API {normalized_api_name} 存在多个已下线版本,请使用 plugin_id.api_name@version"
+ if hidden_match_exists:
+ return None, None, f"API {normalized_api_name} 未公开,禁止跨插件调用"
+ if normalized_version:
+ return None, None, f"未找到版本为 {normalized_version} 的 API: {normalized_api_name}"
+ return None, None, f"未找到 API: {normalized_api_name}"
+
+ def _resolve_api_toggle_target(
+ self: _RuntimeComponentManagerProtocol,
+ name: str,
+ version: str = "",
+ ) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]:
+ """解析需要启用或禁用的 API 组件。
+
+ Args:
+ name: API 名称,支持 ``plugin_id.api_name`` 或唯一短名。
+ version: 可选的 API 版本。
+
+ Returns:
+ tuple[Optional[PluginSupervisor], Optional[APIEntry], Optional[str]]:
+ 解析成功时返回 ``(监督器, API 条目, None)``,失败时返回错误信息。
+ """
+
+ normalized_name, normalized_version = self._normalize_api_reference(name, version)
+ if not normalized_name:
+ return None, None, "缺少必要参数 name"
+
+ if "." in normalized_name:
+ plugin_id, api_name = normalized_name.rsplit(".", 1)
+ try:
+ supervisor = self._get_supervisor_for_plugin(plugin_id)
+ except RuntimeError as exc:
+ return None, None, str(exc)
+
+ if supervisor is None:
+ return None, None, f"未找到 API 提供方插件: {plugin_id}"
+
+ entries = supervisor.api_registry.get_apis(
+ plugin_id=plugin_id,
+ name=api_name,
+ version=normalized_version,
+ enabled_only=False,
+ )
+ if len(entries) == 1:
+ return supervisor, entries[0], None
+ if entries:
+ return None, None, f"API {normalized_name} 存在多个版本,请显式指定 version"
+ return None, None, f"未找到 API: {normalized_name}"
+
+ matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
+ for supervisor in self.supervisors:
+ matches.extend(
+ (supervisor, entry)
+ for entry in supervisor.api_registry.get_apis(
+ name=normalized_name,
+ version=normalized_version,
+ enabled_only=False,
+ )
+ )
+
+ if len(matches) == 1:
+ return matches[0][0], matches[0][1], None
+ if len(matches) > 1:
+ return None, None, f"API 名称不唯一: {normalized_name},请使用 plugin_id.api_name 或显式指定 version"
+ return None, None, f"未找到 API: {normalized_name}"
+
async def _cap_component_get_all_plugins(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
) -> Any:
@@ -46,6 +340,10 @@ class RuntimeComponentCapabilityMixin:
}
for component in comps
]
+ components_list.extend(
+ self._serialize_api_component_entry(entry)
+ for entry in sv.api_registry.get_apis(plugin_id=pid, enabled_only=False)
+ )
result[pid] = {
"name": pid,
"version": reg.plugin_version,
@@ -96,30 +394,35 @@ class RuntimeComponentCapabilityMixin:
def _resolve_component_toggle_target(
self: _RuntimeComponentManagerProtocol, name: str, component_type: str
- ) -> tuple[Optional["RegisteredComponent"], Optional[str]]:
- short_name_matches: List["RegisteredComponent"] = []
+ ) -> tuple[Optional["ComponentEntry"], Optional[str]]:
+ normalized_component_type = self._normalize_component_type(component_type)
+ short_name_matches: List["ComponentEntry"] = []
for sv in self.supervisors:
comp = sv.component_registry.get_component(name)
- if comp is not None and comp.component_type == component_type:
+ if comp is not None and comp.component_type == normalized_component_type:
return comp, None
short_name_matches.extend(
candidate
- for candidate in sv.component_registry.get_components_by_type(component_type, enabled_only=False)
+ for candidate in sv.component_registry.get_components_by_type(
+ normalized_component_type,
+ enabled_only=False,
+ )
if candidate.name == name
)
if len(short_name_matches) == 1:
return short_name_matches[0], None
if len(short_name_matches) > 1:
- return None, f"组件名不唯一: {name} ({component_type}),请使用完整名 plugin_id.component_name"
- return None, f"未找到组件: {name} ({component_type})"
+ return None, f"组件名不唯一: {name} ({normalized_component_type}),请使用完整名 plugin_id.component_name"
+ return None, f"未找到组件: {name} ({normalized_component_type})"
async def _cap_component_enable(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
) -> Any:
name: str = args.get("name", "")
component_type: str = args.get("component_type", "")
+ version: str = args.get("version", "")
scope: str = args.get("scope", "global")
stream_id: str = args.get("stream_id", "")
if not name or not component_type:
@@ -127,6 +430,13 @@ class RuntimeComponentCapabilityMixin:
if scope != "global" or stream_id:
return {"success": False, "error": "当前仅支持全局组件启用,不支持 scope/stream_id 定位"}
+ if self._is_api_component_type(component_type):
+ supervisor, api_entry, error = self._resolve_api_toggle_target(name, version)
+ if supervisor is None or api_entry is None:
+ return {"success": False, "error": error or f"未找到 API: {name}"}
+ supervisor.api_registry.toggle_api_status(api_entry.registry_key, True)
+ return {"success": True}
+
comp, error = self._resolve_component_toggle_target(name, component_type)
if comp is None:
return {"success": False, "error": error or f"未找到组件: {name} ({component_type})"}
@@ -139,6 +449,7 @@ class RuntimeComponentCapabilityMixin:
) -> Any:
name: str = args.get("name", "")
component_type: str = args.get("component_type", "")
+ version: str = args.get("version", "")
scope: str = args.get("scope", "global")
stream_id: str = args.get("stream_id", "")
if not name or not component_type:
@@ -146,6 +457,13 @@ class RuntimeComponentCapabilityMixin:
if scope != "global" or stream_id:
return {"success": False, "error": "当前仅支持全局组件禁用,不支持 scope/stream_id 定位"}
+ if self._is_api_component_type(component_type):
+ supervisor, api_entry, error = self._resolve_api_toggle_target(name, version)
+ if supervisor is None or api_entry is None:
+ return {"success": False, "error": error or f"未找到 API: {name}"}
+ supervisor.api_registry.toggle_api_status(api_entry.registry_key, False)
+ return {"success": True}
+
comp, error = self._resolve_component_toggle_target(name, component_type)
if comp is None:
return {"success": False, "error": error or f"未找到组件: {name} ({component_type})"}
@@ -168,33 +486,14 @@ class RuntimeComponentCapabilityMixin:
return {"success": False, "error": f"检测到重复插件 ID,拒绝热重载: {details}"}
try:
- registered_supervisor = self._get_supervisor_for_plugin(plugin_name)
- except RuntimeError as exc:
- return {"success": False, "error": str(exc)}
+ loaded = await self.load_plugin_globally(plugin_name, reason=f"load {plugin_name}")
+ except Exception as e:
+ logger.error(f"[cap.component.load_plugin] 热重载失败: {e}")
+ return {"success": False, "error": str(e)}
- if registered_supervisor is not None:
- try:
- reloaded = await registered_supervisor.reload_plugins(reason=f"load {plugin_name}")
- if reloaded:
- return {"success": True, "count": 1}
- return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
- except Exception as e:
- logger.error(f"[cap.component.load_plugin] 热重载失败: {e}")
- return {"success": False, "error": str(e)}
-
- for sv in self.supervisors:
- for pdir in sv._plugin_dirs:
- if (pdir / plugin_name).is_dir():
- try:
- reloaded = await sv.reload_plugins(reason=f"load {plugin_name}")
- if reloaded:
- return {"success": True, "count": 1}
- return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
- except Exception as e:
- logger.error(f"[cap.component.load_plugin] 热重载失败: {e}")
- return {"success": False, "error": str(e)}
-
- return {"success": False, "error": f"未找到插件: {plugin_name}"}
+ if loaded:
+ return {"success": True, "count": 1}
+ return {"success": False, "error": f"插件 {plugin_name} 热重载失败"}
async def _cap_component_unload_plugin(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
@@ -216,17 +515,204 @@ class RuntimeComponentCapabilityMixin:
return {"success": False, "error": f"检测到重复插件 ID,拒绝热重载: {details}"}
try:
- sv = self._get_supervisor_for_plugin(plugin_name)
+ reloaded = await self.reload_plugins_globally([plugin_name], reason=f"reload {plugin_name}")
+ except Exception as e:
+ logger.error(f"[cap.component.reload_plugin] 热重载失败: {e}")
+ return {"success": False, "error": str(e)}
+
+ if reloaded:
+ return {"success": True}
+ return {"success": False, "error": f"插件 {plugin_name} 热重载失败"}
+
+ async def _cap_api_call(
+ self: _RuntimeComponentManagerProtocol,
+ plugin_id: str,
+ capability: str,
+ args: Dict[str, Any],
+ ) -> Any:
+ """调用其他插件公开的 API。
+
+ Args:
+ plugin_id: 当前调用方插件 ID。
+ capability: 能力名称。
+ args: 能力参数。
+
+ Returns:
+ Any: API 调用结果。
+ """
+
+ del capability
+ api_name = str(args.get("api_name", "") or "").strip()
+ version = str(args.get("version", "") or "").strip()
+ api_args = args.get("args", {})
+ if not isinstance(api_args, dict):
+ return {"success": False, "error": "参数 args 必须为字典"}
+
+ supervisor, entry, error = self._resolve_api_target(plugin_id, api_name, version)
+ if supervisor is None or entry is None:
+ return {"success": False, "error": error or "API 解析失败"}
+
+ invoke_args = dict(api_args)
+ if entry.dynamic:
+ invoke_args.setdefault("__maibot_api_name__", entry.name)
+ invoke_args.setdefault("__maibot_api_full_name__", entry.full_name)
+ invoke_args.setdefault("__maibot_api_version__", entry.version)
+
+ try:
+ response = await supervisor.invoke_api(
+ plugin_id=entry.plugin_id,
+ component_name=entry.handler_name,
+ args=invoke_args,
+ timeout_ms=30000,
+ )
+ except Exception as exc:
+ logger.error(f"[cap.api.call] 调用 API {entry.full_name} 失败: {exc}", exc_info=True)
+ return {"success": False, "error": str(exc)}
+
+ if response.error:
+ return {"success": False, "error": response.error.get("message", "API 调用失败")}
+
+ payload = response.payload if isinstance(response.payload, dict) else {}
+ if not bool(payload.get("success", False)):
+ result = payload.get("result")
+ return {"success": False, "error": "" if result is None else str(result)}
+ return {"success": True, "result": payload.get("result")}
+
+ async def _cap_api_get(
+ self: _RuntimeComponentManagerProtocol,
+ plugin_id: str,
+ capability: str,
+ args: Dict[str, Any],
+ ) -> Any:
+ """获取当前插件可见的单个 API 元信息。
+
+ Args:
+ plugin_id: 当前调用方插件 ID。
+ capability: 能力名称。
+ args: 能力参数。
+
+ Returns:
+ Any: API 元信息或 ``None``。
+ """
+
+ del capability
+ api_name = str(args.get("api_name", "") or "").strip()
+ version = str(args.get("version", "") or "").strip()
+ if not api_name:
+ return {"success": False, "error": "缺少必要参数 api_name"}
+
+ supervisor, entry, _error = self._resolve_api_target(plugin_id, api_name, version)
+ if supervisor is None or entry is None:
+ return {"success": True, "api": None}
+ return {"success": True, "api": self._serialize_api_entry(entry)}
+
+ async def _cap_api_list(
+ self: _RuntimeComponentManagerProtocol,
+ plugin_id: str,
+ capability: str,
+ args: Dict[str, Any],
+ ) -> Any:
+ """列出当前插件可见的 API 列表。
+
+ Args:
+ plugin_id: 当前调用方插件 ID。
+ capability: 能力名称。
+ args: 能力参数。
+
+ Returns:
+ Any: API 元信息列表。
+ """
+
+ del capability
+ target_plugin_id = str(args.get("plugin_id", "") or "").strip()
+ api_name, version = self._normalize_api_reference(
+ str(args.get("api_name", args.get("name", "")) or ""),
+ str(args.get("version", "") or ""),
+ )
+ apis: List[Dict[str, Any]] = []
+ for supervisor in self.supervisors:
+ apis.extend(
+ self._serialize_api_entry(entry)
+ for entry in supervisor.api_registry.get_apis(
+ plugin_id=target_plugin_id or None,
+ name=api_name,
+ version=version,
+ enabled_only=True,
+ )
+ if self._is_api_visible_to_plugin(entry, plugin_id)
+ )
+
+ apis.sort(key=lambda item: (str(item["plugin_id"]), str(item["name"]), str(item["version"])))
+ return {"success": True, "apis": apis}
+
+ async def _cap_api_replace_dynamic(
+ self: _RuntimeComponentManagerProtocol,
+ plugin_id: str,
+ capability: str,
+ args: Dict[str, Any],
+ ) -> Any:
+ """替换插件自行维护的动态 API 列表。"""
+
+ del capability
+ raw_apis = args.get("apis", [])
+ offline_reason = str(args.get("offline_reason", "") or "").strip() or "动态 API 已下线"
+ if not isinstance(raw_apis, list):
+ return {"success": False, "error": "参数 apis 必须为列表"}
+
+ try:
+ supervisor = self._get_supervisor_for_plugin(plugin_id)
except RuntimeError as exc:
return {"success": False, "error": str(exc)}
- if sv is not None:
- try:
- reloaded = await sv.reload_plugins(reason=f"reload {plugin_name}")
- if reloaded:
- return {"success": True}
- return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
- except Exception as e:
- logger.error(f"[cap.component.reload_plugin] 热重载失败: {e}")
- return {"success": False, "error": str(e)}
- return {"success": False, "error": f"未找到插件: {plugin_name}"}
+ if supervisor is None:
+ return {"success": False, "error": f"未找到插件: {plugin_id}"}
+
+ normalized_components: List[Dict[str, Any]] = []
+ seen_registry_keys: set[str] = set()
+ for index, raw_api in enumerate(raw_apis):
+ if not isinstance(raw_api, dict):
+ return {"success": False, "error": f"apis[{index}] 必须为字典"}
+
+ api_name = str(raw_api.get("name", "") or "").strip()
+ component_type = str(raw_api.get("component_type", raw_api.get("type", "API")) or "").strip()
+ if not api_name:
+ return {"success": False, "error": f"apis[{index}] 缺少 name"}
+ if not self._is_api_component_type(component_type):
+ return {"success": False, "error": f"apis[{index}] 不是 API 组件"}
+
+ metadata = raw_api.get("metadata", {}) if isinstance(raw_api.get("metadata"), dict) else {}
+ normalized_metadata = dict(metadata)
+ normalized_metadata["dynamic"] = True
+ version = str(normalized_metadata.get("version", "1") or "1").strip() or "1"
+ registry_key = supervisor.api_registry.build_registry_key(plugin_id, api_name, version)
+ if registry_key in seen_registry_keys:
+ return {"success": False, "error": f"动态 API 重复声明: {registry_key}"}
+ seen_registry_keys.add(registry_key)
+
+ existing_entry = supervisor.api_registry.get_api(
+ plugin_id,
+ api_name,
+ version=version,
+ enabled_only=False,
+ )
+ if existing_entry is not None and not existing_entry.dynamic:
+ return {"success": False, "error": f"动态 API 不能覆盖静态 API: {registry_key}"}
+
+ normalized_components.append(
+ {
+ "name": api_name,
+ "component_type": "API",
+ "metadata": normalized_metadata,
+ }
+ )
+
+ registered_count, offlined_count = supervisor.api_registry.replace_plugin_dynamic_apis(
+ plugin_id,
+ normalized_components,
+ offline_reason=offline_reason,
+ )
+ return {
+ "success": True,
+ "count": registered_count,
+ "offlined": offlined_count,
+ }
diff --git a/src/plugin_runtime/capabilities/core.py b/src/plugin_runtime/capabilities/core.py
index def5f03d..9bb1755b 100644
--- a/src/plugin_runtime/capabilities/core.py
+++ b/src/plugin_runtime/capabilities/core.py
@@ -238,14 +238,14 @@ class RuntimeCoreCapabilityMixin:
return {"success": False, "value": None, "error": str(e)}
async def _cap_config_get_plugin(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
- from src.core.component_registry import component_registry as core_registry
+ from src.plugin_runtime.component_query import component_query_service
plugin_name: str = args.get("plugin_name", plugin_id)
key: str = args.get("key", "")
default = args.get("default")
try:
- config = core_registry.get_plugin_config(plugin_name)
+ config = component_query_service.get_plugin_config(plugin_name)
if config is None:
return {"success": False, "value": default, "error": f"未找到插件 {plugin_name} 的配置"}
@@ -258,11 +258,11 @@ class RuntimeCoreCapabilityMixin:
return {"success": False, "value": default, "error": str(e)}
async def _cap_config_get_all(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
- from src.core.component_registry import component_registry as core_registry
+ from src.plugin_runtime.component_query import component_query_service
plugin_name: str = args.get("plugin_name", plugin_id)
try:
- config = core_registry.get_plugin_config(plugin_name)
+ config = component_query_service.get_plugin_config(plugin_name)
if config is None:
return {"success": True, "value": {}}
return {"success": True, "value": config}
diff --git a/src/plugin_runtime/capabilities/data.py b/src/plugin_runtime/capabilities/data.py
index c8139c16..32843d09 100644
--- a/src/plugin_runtime/capabilities/data.py
+++ b/src/plugin_runtime/capabilities/data.py
@@ -648,10 +648,10 @@ class RuntimeDataCapabilityMixin:
return {"success": False, "error": str(e)}
async def _cap_tool_get_definitions(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
- from src.core.component_registry import component_registry as core_registry
+ from src.plugin_runtime.component_query import component_query_service
try:
- tools = core_registry.get_llm_available_tools()
+ tools = component_query_service.get_llm_available_tools()
return {
"success": True,
"tools": [{"name": name, "definition": info.get_llm_definition()} for name, info in tools.items()],
diff --git a/src/plugin_runtime/capabilities/registry.py b/src/plugin_runtime/capabilities/registry.py
index abce97dc..7f87604d 100644
--- a/src/plugin_runtime/capabilities/registry.py
+++ b/src/plugin_runtime/capabilities/registry.py
@@ -1,6 +1,7 @@
from typing import TYPE_CHECKING
from src.common.logger import get_logger
+from src.plugin_runtime.host.capability_service import CapabilityImpl
from src.plugin_runtime.host.supervisor import PluginSupervisor
if TYPE_CHECKING:
@@ -13,66 +14,80 @@ def register_capability_impls(manager: "PluginRuntimeManager", supervisor: Plugi
"""向指定 Supervisor 注册主程序提供的能力实现。"""
cap_service = supervisor.capability_service
- cap_service.register_capability("send.text", manager._cap_send_text)
- cap_service.register_capability("send.emoji", manager._cap_send_emoji)
- cap_service.register_capability("send.image", manager._cap_send_image)
- cap_service.register_capability("send.command", manager._cap_send_command)
- cap_service.register_capability("send.custom", manager._cap_send_custom)
+ def _register(name: str, impl: CapabilityImpl) -> None:
+ """注册单个能力实现。
- cap_service.register_capability("llm.generate", manager._cap_llm_generate)
- cap_service.register_capability("llm.generate_with_tools", manager._cap_llm_generate_with_tools)
- cap_service.register_capability("llm.get_available_models", manager._cap_llm_get_available_models)
+ Args:
+ name: 能力名称。
+ impl: 能力实现函数。
+ """
+ cap_service.register_capability(name, impl)
- cap_service.register_capability("config.get", manager._cap_config_get)
- cap_service.register_capability("config.get_plugin", manager._cap_config_get_plugin)
- cap_service.register_capability("config.get_all", manager._cap_config_get_all)
+ _register("send.text", manager._cap_send_text)
+ _register("send.emoji", manager._cap_send_emoji)
+ _register("send.image", manager._cap_send_image)
+ _register("send.command", manager._cap_send_command)
+ _register("send.custom", manager._cap_send_custom)
- cap_service.register_capability("database.query", manager._cap_database_query)
- cap_service.register_capability("database.save", manager._cap_database_save)
- cap_service.register_capability("database.get", manager._cap_database_get)
- cap_service.register_capability("database.delete", manager._cap_database_delete)
- cap_service.register_capability("database.count", manager._cap_database_count)
+ _register("llm.generate", manager._cap_llm_generate)
+ _register("llm.generate_with_tools", manager._cap_llm_generate_with_tools)
+ _register("llm.get_available_models", manager._cap_llm_get_available_models)
- cap_service.register_capability("chat.get_all_streams", manager._cap_chat_get_all_streams)
- cap_service.register_capability("chat.get_group_streams", manager._cap_chat_get_group_streams)
- cap_service.register_capability("chat.get_private_streams", manager._cap_chat_get_private_streams)
- cap_service.register_capability("chat.get_stream_by_group_id", manager._cap_chat_get_stream_by_group_id)
- cap_service.register_capability("chat.get_stream_by_user_id", manager._cap_chat_get_stream_by_user_id)
+ _register("config.get", manager._cap_config_get)
+ _register("config.get_plugin", manager._cap_config_get_plugin)
+ _register("config.get_all", manager._cap_config_get_all)
- cap_service.register_capability("message.get_by_time", manager._cap_message_get_by_time)
- cap_service.register_capability("message.get_by_time_in_chat", manager._cap_message_get_by_time_in_chat)
- cap_service.register_capability("message.get_recent", manager._cap_message_get_recent)
- cap_service.register_capability("message.count_new", manager._cap_message_count_new)
- cap_service.register_capability("message.build_readable", manager._cap_message_build_readable)
+ _register("database.query", manager._cap_database_query)
+ _register("database.save", manager._cap_database_save)
+ _register("database.get", manager._cap_database_get)
+ _register("database.delete", manager._cap_database_delete)
+ _register("database.count", manager._cap_database_count)
- cap_service.register_capability("person.get_id", manager._cap_person_get_id)
- cap_service.register_capability("person.get_value", manager._cap_person_get_value)
- cap_service.register_capability("person.get_id_by_name", manager._cap_person_get_id_by_name)
+ _register("chat.get_all_streams", manager._cap_chat_get_all_streams)
+ _register("chat.get_group_streams", manager._cap_chat_get_group_streams)
+ _register("chat.get_private_streams", manager._cap_chat_get_private_streams)
+ _register("chat.get_stream_by_group_id", manager._cap_chat_get_stream_by_group_id)
+ _register("chat.get_stream_by_user_id", manager._cap_chat_get_stream_by_user_id)
- cap_service.register_capability("emoji.get_by_description", manager._cap_emoji_get_by_description)
- cap_service.register_capability("emoji.get_random", manager._cap_emoji_get_random)
- cap_service.register_capability("emoji.get_count", manager._cap_emoji_get_count)
- cap_service.register_capability("emoji.get_emotions", manager._cap_emoji_get_emotions)
- cap_service.register_capability("emoji.get_all", manager._cap_emoji_get_all)
- cap_service.register_capability("emoji.get_info", manager._cap_emoji_get_info)
- cap_service.register_capability("emoji.register", manager._cap_emoji_register)
- cap_service.register_capability("emoji.delete", manager._cap_emoji_delete)
+ _register("message.get_by_time", manager._cap_message_get_by_time)
+ _register("message.get_by_time_in_chat", manager._cap_message_get_by_time_in_chat)
+ _register("message.get_recent", manager._cap_message_get_recent)
+ _register("message.count_new", manager._cap_message_count_new)
+ _register("message.build_readable", manager._cap_message_build_readable)
- cap_service.register_capability("frequency.get_current_talk_value", manager._cap_frequency_get_current_talk_value)
- cap_service.register_capability("frequency.set_adjust", manager._cap_frequency_set_adjust)
- cap_service.register_capability("frequency.get_adjust", manager._cap_frequency_get_adjust)
+ _register("person.get_id", manager._cap_person_get_id)
+ _register("person.get_value", manager._cap_person_get_value)
+ _register("person.get_id_by_name", manager._cap_person_get_id_by_name)
- cap_service.register_capability("tool.get_definitions", manager._cap_tool_get_definitions)
+ _register("emoji.get_by_description", manager._cap_emoji_get_by_description)
+ _register("emoji.get_random", manager._cap_emoji_get_random)
+ _register("emoji.get_count", manager._cap_emoji_get_count)
+ _register("emoji.get_emotions", manager._cap_emoji_get_emotions)
+ _register("emoji.get_all", manager._cap_emoji_get_all)
+ _register("emoji.get_info", manager._cap_emoji_get_info)
+ _register("emoji.register", manager._cap_emoji_register)
+ _register("emoji.delete", manager._cap_emoji_delete)
- cap_service.register_capability("component.get_all_plugins", manager._cap_component_get_all_plugins)
- cap_service.register_capability("component.get_plugin_info", manager._cap_component_get_plugin_info)
- cap_service.register_capability("component.list_loaded_plugins", manager._cap_component_list_loaded_plugins)
- cap_service.register_capability("component.list_registered_plugins", manager._cap_component_list_registered_plugins)
- cap_service.register_capability("component.enable", manager._cap_component_enable)
- cap_service.register_capability("component.disable", manager._cap_component_disable)
- cap_service.register_capability("component.load_plugin", manager._cap_component_load_plugin)
- cap_service.register_capability("component.unload_plugin", manager._cap_component_unload_plugin)
- cap_service.register_capability("component.reload_plugin", manager._cap_component_reload_plugin)
+ _register("frequency.get_current_talk_value", manager._cap_frequency_get_current_talk_value)
+ _register("frequency.set_adjust", manager._cap_frequency_set_adjust)
+ _register("frequency.get_adjust", manager._cap_frequency_get_adjust)
- cap_service.register_capability("knowledge.search", manager._cap_knowledge_search)
+ _register("tool.get_definitions", manager._cap_tool_get_definitions)
+
+ _register("api.call", manager._cap_api_call)
+ _register("api.get", manager._cap_api_get)
+ _register("api.list", manager._cap_api_list)
+ _register("api.replace_dynamic", manager._cap_api_replace_dynamic)
+
+ _register("component.get_all_plugins", manager._cap_component_get_all_plugins)
+ _register("component.get_plugin_info", manager._cap_component_get_plugin_info)
+ _register("component.list_loaded_plugins", manager._cap_component_list_loaded_plugins)
+ _register("component.list_registered_plugins", manager._cap_component_list_registered_plugins)
+ _register("component.enable", manager._cap_component_enable)
+ _register("component.disable", manager._cap_component_disable)
+ _register("component.load_plugin", manager._cap_component_load_plugin)
+ _register("component.unload_plugin", manager._cap_component_unload_plugin)
+ _register("component.reload_plugin", manager._cap_component_reload_plugin)
+
+ _register("knowledge.search", manager._cap_knowledge_search)
logger.debug("已注册全部主程序能力实现")
diff --git a/src/plugin_runtime/component_query.py b/src/plugin_runtime/component_query.py
new file mode 100644
index 00000000..7d23d202
--- /dev/null
+++ b/src/plugin_runtime/component_query.py
@@ -0,0 +1,709 @@
+"""插件运行时统一组件查询服务。
+
+该模块统一从插件运行时的 Host ComponentRegistry 中聚合只读视图,
+供 HFC/PFC、Planner、ToolExecutor 和运行时能力层查询与调用。
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tuple
+
+from src.common.logger import get_logger
+from src.core.types import ActionActivationType, ActionInfo, CommandInfo, ComponentInfo, ComponentType, ToolInfo
+from src.llm_models.payload_content.tool_option import ToolParamType
+
+if TYPE_CHECKING:
+ from src.plugin_runtime.host.component_registry import ActionEntry, CommandEntry, ComponentEntry, ToolEntry
+ from src.plugin_runtime.host.supervisor import PluginSupervisor
+ from src.plugin_runtime.integration import PluginRuntimeManager
+
+logger = get_logger("plugin_runtime.component_query")
+
+ActionExecutor = Callable[..., Awaitable[Any]]
+CommandExecutor = Callable[..., Awaitable[Tuple[bool, Optional[str], bool]]]
+ToolExecutor = Callable[..., Awaitable[Any]]
+
+_HOST_COMPONENT_TYPE_MAP: Dict[ComponentType, str] = {
+ ComponentType.ACTION: "ACTION",
+ ComponentType.COMMAND: "COMMAND",
+ ComponentType.TOOL: "TOOL",
+}
+_TOOL_PARAM_TYPE_MAP: Dict[str, ToolParamType] = {
+ "string": ToolParamType.STRING,
+ "integer": ToolParamType.INTEGER,
+ "float": ToolParamType.FLOAT,
+ "boolean": ToolParamType.BOOLEAN,
+ "bool": ToolParamType.BOOLEAN,
+}
+
+
+class ComponentQueryService:
+ """插件运行时统一组件查询服务。
+
+ 该对象不维护独立状态,只读取插件系统中的注册结果。
+ 所有注册、删除、配置写入等写操作都被显式禁用。
+ """
+
+ @staticmethod
+ def _get_runtime_manager() -> "PluginRuntimeManager":
+ """获取插件运行时管理器单例。
+
+ Returns:
+ PluginRuntimeManager: 当前全局插件运行时管理器。
+ """
+
+ from src.plugin_runtime.integration import get_plugin_runtime_manager
+
+ return get_plugin_runtime_manager()
+
+ def _iter_supervisors(self) -> list["PluginSupervisor"]:
+ """获取当前所有活跃的插件运行时监督器。
+
+ Returns:
+ list[PluginSupervisor]: 当前运行中的监督器列表。
+ """
+
+ runtime_manager = self._get_runtime_manager()
+ return list(runtime_manager.supervisors)
+
+ def _iter_component_entries(
+ self,
+ component_type: ComponentType,
+ *,
+ enabled_only: bool = True,
+ ) -> list[tuple["PluginSupervisor", "ComponentEntry"]]:
+ """遍历指定类型的全部组件条目。
+
+ Args:
+ component_type: 目标组件类型。
+ enabled_only: 是否仅返回启用状态的组件。
+
+ Returns:
+ list[tuple[PluginSupervisor, ComponentEntry]]: ``(监督器, 组件条目)`` 列表。
+ """
+
+ host_component_type = _HOST_COMPONENT_TYPE_MAP.get(component_type)
+ if host_component_type is None:
+ return []
+
+ collected_entries: list[tuple["PluginSupervisor", "ComponentEntry"]] = []
+ for supervisor in self._iter_supervisors():
+ for component in supervisor.component_registry.get_components_by_type(
+ host_component_type,
+ enabled_only=enabled_only,
+ ):
+ collected_entries.append((supervisor, component))
+ return collected_entries
+
+ @staticmethod
+ def _coerce_action_activation_type(raw_value: Any) -> ActionActivationType:
+ """规范化动作激活类型。
+
+ Args:
+ raw_value: 原始激活类型值。
+
+ Returns:
+ ActionActivationType: 规范化后的激活类型枚举。
+ """
+
+ normalized_value = str(raw_value or "").strip().lower()
+ if normalized_value == ActionActivationType.NEVER.value:
+ return ActionActivationType.NEVER
+ if normalized_value == ActionActivationType.RANDOM.value:
+ return ActionActivationType.RANDOM
+ if normalized_value == ActionActivationType.KEYWORD.value:
+ return ActionActivationType.KEYWORD
+ return ActionActivationType.ALWAYS
+
+ @staticmethod
+ def _coerce_float(value: Any, default: float = 0.0) -> float:
+ """将任意值安全转换为浮点数。
+
+ Args:
+ value: 待转换的输入值。
+ default: 转换失败时返回的默认值。
+
+ Returns:
+ float: 转换后的浮点结果。
+ """
+
+ try:
+ return float(value)
+ except (TypeError, ValueError):
+ return default
+
+ @staticmethod
+ def _build_action_info(entry: "ActionEntry") -> ActionInfo:
+ """将运行时 Action 条目转换为核心动作信息。
+
+ Args:
+ entry: 插件运行时中的 Action 条目。
+
+ Returns:
+ ActionInfo: 供核心 Planner 使用的动作信息。
+ """
+
+ metadata = dict(entry.metadata)
+ raw_action_parameters = metadata.get("action_parameters")
+ action_parameters = (
+ {
+ str(param_name): str(param_description)
+ for param_name, param_description in raw_action_parameters.items()
+ }
+ if isinstance(raw_action_parameters, dict)
+ else {}
+ )
+ action_require = [
+ str(item)
+ for item in (metadata.get("action_require") or [])
+ if item is not None and str(item).strip()
+ ]
+ associated_types = [
+ str(item)
+ for item in (metadata.get("associated_types") or [])
+ if item is not None and str(item).strip()
+ ]
+ activation_keywords = [
+ str(item)
+ for item in (metadata.get("activation_keywords") or [])
+ if item is not None and str(item).strip()
+ ]
+
+ return ActionInfo(
+ name=entry.name,
+ component_type=ComponentType.ACTION,
+ description=str(metadata.get("description", "") or ""),
+ enabled=bool(entry.enabled),
+ plugin_name=entry.plugin_id,
+ metadata=metadata,
+ action_parameters=action_parameters,
+ action_require=action_require,
+ associated_types=associated_types,
+ activation_type=ComponentQueryService._coerce_action_activation_type(metadata.get("activation_type")),
+ random_activation_probability=ComponentQueryService._coerce_float(
+ metadata.get("activation_probability"),
+ 0.0,
+ ),
+ activation_keywords=activation_keywords,
+ parallel_action=bool(metadata.get("parallel_action", False)),
+ )
+
+ @staticmethod
+ def _build_command_info(entry: "CommandEntry") -> CommandInfo:
+ """将运行时 Command 条目转换为核心命令信息。
+
+ Args:
+ entry: 插件运行时中的 Command 条目。
+
+ Returns:
+ CommandInfo: 供核心命令链使用的命令信息。
+ """
+
+ metadata = dict(entry.metadata)
+ return CommandInfo(
+ name=entry.name,
+ component_type=ComponentType.COMMAND,
+ description=str(metadata.get("description", "") or ""),
+ enabled=bool(entry.enabled),
+ plugin_name=entry.plugin_id,
+ metadata=metadata,
+ command_pattern=str(metadata.get("command_pattern", "") or ""),
+ )
+
+ @staticmethod
+ def _coerce_tool_param_type(raw_value: Any) -> ToolParamType:
+ """规范化工具参数类型。
+
+ Args:
+ raw_value: 原始工具参数类型值。
+
+ Returns:
+ ToolParamType: 规范化后的工具参数类型。
+ """
+
+ normalized_value = str(raw_value or "").strip().lower()
+ return _TOOL_PARAM_TYPE_MAP.get(normalized_value, ToolParamType.STRING)
+
+ @staticmethod
+ def _build_tool_parameters(entry: "ToolEntry") -> list[tuple[str, ToolParamType, str, bool, list[str] | None]]:
+ """将运行时工具参数元数据转换为核心 ToolInfo 参数列表。
+
+ Args:
+ entry: 插件运行时中的 Tool 条目。
+
+ Returns:
+ list[tuple[str, ToolParamType, str, bool, list[str] | None]]: 转换后的参数列表。
+ """
+
+ structured_parameters = entry.parameters if isinstance(entry.parameters, list) else []
+ if not structured_parameters and isinstance(entry.parameters_raw, dict):
+ structured_parameters = [
+ {"name": key, **value}
+ for key, value in entry.parameters_raw.items()
+ if isinstance(value, dict)
+ ]
+
+ normalized_parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = []
+ for parameter in structured_parameters:
+ if not isinstance(parameter, dict):
+ continue
+
+ parameter_name = str(parameter.get("name", "") or "").strip()
+ if not parameter_name:
+ continue
+
+ enum_values = parameter.get("enum")
+ normalized_enum_values = (
+ [str(item) for item in enum_values if item is not None]
+ if isinstance(enum_values, list)
+ else None
+ )
+ normalized_parameters.append(
+ (
+ parameter_name,
+ ComponentQueryService._coerce_tool_param_type(parameter.get("param_type") or parameter.get("type")),
+ str(parameter.get("description", "") or ""),
+ bool(parameter.get("required", True)),
+ normalized_enum_values,
+ )
+ )
+ return normalized_parameters
+
+ @staticmethod
+ def _build_tool_info(entry: "ToolEntry") -> ToolInfo:
+ """将运行时 Tool 条目转换为核心工具信息。
+
+ Args:
+ entry: 插件运行时中的 Tool 条目。
+
+ Returns:
+ ToolInfo: 供 ToolExecutor 与能力层使用的工具信息。
+ """
+
+ return ToolInfo(
+ name=entry.name,
+ component_type=ComponentType.TOOL,
+ description=entry.description,
+ enabled=bool(entry.enabled),
+ plugin_name=entry.plugin_id,
+ metadata=dict(entry.metadata),
+ tool_parameters=ComponentQueryService._build_tool_parameters(entry),
+ tool_description=entry.description,
+ )
+
+ @staticmethod
+ def _log_duplicate_component(component_type: ComponentType, component_name: str) -> None:
+ """记录重复组件名称冲突。
+
+ Args:
+ component_type: 组件类型。
+ component_name: 发生冲突的组件名称。
+ """
+
+ logger.warning(f"检测到重复{component_type.value}名称 {component_name},将只保留首个匹配项")
+
+ def _get_unique_component_entry(
+ self,
+ component_type: ComponentType,
+ name: str,
+ ) -> Optional[tuple["PluginSupervisor", "ComponentEntry"]]:
+ """按组件短名解析唯一条目。
+
+ Args:
+ component_type: 目标组件类型。
+ name: 组件短名。
+
+ Returns:
+ Optional[tuple[PluginSupervisor, ComponentEntry]]: 唯一命中的组件条目。
+ """
+
+ matched_entries = [
+ (supervisor, entry)
+ for supervisor, entry in self._iter_component_entries(component_type)
+ if entry.name == name
+ ]
+ if not matched_entries:
+ return None
+ if len(matched_entries) > 1:
+ self._log_duplicate_component(component_type, name)
+ return matched_entries[0]
+
+ def _collect_unique_component_infos(
+ self,
+ component_type: ComponentType,
+ ) -> Dict[str, ComponentInfo]:
+ """收集某类组件的唯一信息视图。
+
+ Args:
+ component_type: 目标组件类型。
+
+ Returns:
+ Dict[str, ComponentInfo]: 组件名到核心组件信息的映射。
+ """
+
+ collected_components: Dict[str, ComponentInfo] = {}
+ for _supervisor, entry in self._iter_component_entries(component_type):
+ if entry.name in collected_components:
+ self._log_duplicate_component(component_type, entry.name)
+ continue
+
+ if component_type == ComponentType.ACTION:
+ collected_components[entry.name] = self._build_action_info(entry) # type: ignore[arg-type]
+ elif component_type == ComponentType.COMMAND:
+ collected_components[entry.name] = self._build_command_info(entry) # type: ignore[arg-type]
+ elif component_type == ComponentType.TOOL:
+ collected_components[entry.name] = self._build_tool_info(entry) # type: ignore[arg-type]
+ return collected_components
+
+ @staticmethod
+ def _extract_stream_id_from_action_kwargs(kwargs: Dict[str, Any]) -> str:
+ """从旧 ActionManager 参数中提取聊天流 ID。
+
+ Args:
+ kwargs: 旧动作执行器收到的关键字参数。
+
+ Returns:
+ str: 提取出的 ``stream_id``。
+ """
+
+ chat_stream = kwargs.get("chat_stream")
+ if chat_stream is not None:
+ try:
+ return str(chat_stream.session_id)
+ except AttributeError:
+ pass
+
+ return str(kwargs.get("stream_id", "") or "")
+
+ @staticmethod
+ def _build_action_executor(supervisor: "PluginSupervisor", plugin_id: str, component_name: str) -> ActionExecutor:
+ """构造动作执行 RPC 闭包。
+
+ Args:
+ supervisor: 负责该组件的监督器。
+ plugin_id: 插件 ID。
+ component_name: 组件名称。
+
+ Returns:
+ ActionExecutor: 兼容旧 Planner 的异步执行器。
+ """
+
+ async def _executor(**kwargs: Any) -> tuple[bool, str]:
+ """将核心动作调用桥接到插件运行时。
+
+ Args:
+ **kwargs: 旧 ActionManager 传入的上下文参数。
+
+ Returns:
+ tuple[bool, str]: ``(是否成功, 结果说明)``。
+ """
+
+ invoke_args: Dict[str, Any] = {}
+ action_data = kwargs.get("action_data")
+ if isinstance(action_data, dict):
+ invoke_args.update(action_data)
+
+ stream_id = ComponentQueryService._extract_stream_id_from_action_kwargs(kwargs)
+ invoke_args["action_data"] = action_data if isinstance(action_data, dict) else {}
+ invoke_args["stream_id"] = stream_id
+ invoke_args["chat_id"] = stream_id
+ invoke_args["reasoning"] = str(kwargs.get("action_reasoning", "") or "")
+
+ if (thinking_id := kwargs.get("thinking_id")) is not None:
+ invoke_args["thinking_id"] = str(thinking_id)
+ if isinstance(kwargs.get("cycle_timers"), dict):
+ invoke_args["cycle_timers"] = kwargs["cycle_timers"]
+ if isinstance(kwargs.get("plugin_config"), dict):
+ invoke_args["plugin_config"] = kwargs["plugin_config"]
+ if isinstance(kwargs.get("log_prefix"), str):
+ invoke_args["log_prefix"] = kwargs["log_prefix"]
+ if isinstance(kwargs.get("shutting_down"), bool):
+ invoke_args["shutting_down"] = kwargs["shutting_down"]
+
+ try:
+ response = await supervisor.invoke_plugin(
+ method="plugin.invoke_action",
+ plugin_id=plugin_id,
+ component_name=component_name,
+ args=invoke_args,
+ timeout_ms=30000,
+ )
+ except Exception as exc:
+ logger.error(f"运行时 Action {plugin_id}.{component_name} 执行失败: {exc}", exc_info=True)
+ return False, str(exc)
+
+ payload = response.payload if isinstance(response.payload, dict) else {}
+ success = bool(payload.get("success", False))
+ result = payload.get("result")
+ if isinstance(result, (list, tuple)):
+ if len(result) >= 2:
+ return bool(result[0]), "" if result[1] is None else str(result[1])
+ if len(result) == 1:
+ return bool(result[0]), ""
+ if success:
+ return True, "" if result is None else str(result)
+ return False, "" if result is None else str(result)
+
+ return _executor
+
+ @staticmethod
+ def _build_command_executor(
+ supervisor: "PluginSupervisor",
+ plugin_id: str,
+ component_name: str,
+ metadata: Dict[str, Any],
+ ) -> CommandExecutor:
+ """构造命令执行 RPC 闭包。
+
+ Args:
+ supervisor: 负责该组件的监督器。
+ plugin_id: 插件 ID。
+ component_name: 组件名称。
+ metadata: 命令组件元数据。
+
+ Returns:
+ CommandExecutor: 兼容旧消息命令链的执行器。
+ """
+
+ async def _executor(**kwargs: Any) -> tuple[bool, Optional[str], bool]:
+ """将核心命令调用桥接到插件运行时。
+
+ Args:
+ **kwargs: 命令执行上下文参数。
+
+ Returns:
+ tuple[bool, Optional[str], bool]: ``(是否成功, 返回文本, 是否拦截后续消息)``。
+ """
+
+ message = kwargs.get("message")
+ matched_groups = kwargs.get("matched_groups")
+ plugin_config = kwargs.get("plugin_config")
+ invoke_args: Dict[str, Any] = {
+ "text": str(getattr(message, "processed_plain_text", "") or ""),
+ "stream_id": str(getattr(message, "session_id", "") or ""),
+ "matched_groups": matched_groups if isinstance(matched_groups, dict) else {},
+ }
+ if isinstance(plugin_config, dict):
+ invoke_args["plugin_config"] = plugin_config
+
+ try:
+ response = await supervisor.invoke_plugin(
+ method="plugin.invoke_command",
+ plugin_id=plugin_id,
+ component_name=component_name,
+ args=invoke_args,
+ timeout_ms=30000,
+ )
+ except Exception as exc:
+ logger.error(f"运行时 Command {plugin_id}.{component_name} 执行失败: {exc}", exc_info=True)
+ return False, str(exc), True
+
+ payload = response.payload if isinstance(response.payload, dict) else {}
+ success = bool(payload.get("success", False))
+ result = payload.get("result")
+ intercept = bool(metadata.get("intercept_message_level", 0))
+ response_text: Optional[str]
+
+ if isinstance(result, (list, tuple)) and len(result) >= 3:
+ response_text = None if result[1] is None else str(result[1])
+ intercept = bool(result[2])
+ else:
+ response_text = None if result is None else str(result)
+
+ return success, response_text, intercept
+
+ return _executor
+
+ @staticmethod
+ def _build_tool_executor(supervisor: "PluginSupervisor", plugin_id: str, component_name: str) -> ToolExecutor:
+ """构造工具执行 RPC 闭包。
+
+ Args:
+ supervisor: 负责该组件的监督器。
+ plugin_id: 插件 ID。
+ component_name: 组件名称。
+
+ Returns:
+ ToolExecutor: 兼容旧 ToolExecutor 的异步执行器。
+ """
+
+ async def _executor(function_args: Dict[str, Any]) -> Any:
+ """将核心工具调用桥接到插件运行时。
+
+ Args:
+ function_args: 工具调用参数。
+
+ Returns:
+ Any: 插件工具返回结果;若结果不是字典,则会包装为 ``{"content": ...}``。
+ """
+
+ try:
+ response = await supervisor.invoke_plugin(
+ method="plugin.invoke_tool",
+ plugin_id=plugin_id,
+ component_name=component_name,
+ args=function_args,
+ timeout_ms=30000,
+ )
+ except Exception as exc:
+ logger.error(f"运行时 Tool {plugin_id}.{component_name} 执行失败: {exc}", exc_info=True)
+ return {"content": f"工具 {component_name} 执行失败: {exc}"}
+
+ payload = response.payload if isinstance(response.payload, dict) else {}
+ result = payload.get("result")
+ if isinstance(result, dict):
+ return result
+ return {"content": "" if result is None else str(result)}
+
+ return _executor
+
+ def get_action_info(self, name: str) -> Optional[ActionInfo]:
+ """获取指定动作的信息。
+
+ Args:
+ name: 动作名称。
+
+ Returns:
+ Optional[ActionInfo]: 匹配到的动作信息。
+ """
+
+ matched_entry = self._get_unique_component_entry(ComponentType.ACTION, name)
+ if matched_entry is None:
+ return None
+ _supervisor, entry = matched_entry
+ return self._build_action_info(entry) # type: ignore[arg-type]
+
+ def get_action_executor(self, name: str) -> Optional[ActionExecutor]:
+ """获取指定动作的执行器。
+
+ Args:
+ name: 动作名称。
+
+ Returns:
+ Optional[ActionExecutor]: 运行时 RPC 执行闭包。
+ """
+
+ matched_entry = self._get_unique_component_entry(ComponentType.ACTION, name)
+ if matched_entry is None:
+ return None
+ supervisor, entry = matched_entry
+ return self._build_action_executor(supervisor, entry.plugin_id, entry.name)
+
+ def get_default_actions(self) -> Dict[str, ActionInfo]:
+ """获取当前默认启用的动作集合。
+
+ Returns:
+ Dict[str, ActionInfo]: 动作名到动作信息的映射。
+ """
+
+ action_infos = self._collect_unique_component_infos(ComponentType.ACTION)
+ return {name: info for name, info in action_infos.items() if isinstance(info, ActionInfo) and info.enabled}
+
+ def find_command_by_text(self, text: str) -> Optional[Tuple[CommandExecutor, dict, CommandInfo]]:
+ """根据文本查找匹配的命令。
+
+ Args:
+ text: 待匹配的文本内容。
+
+ Returns:
+ Optional[Tuple[CommandExecutor, dict, CommandInfo]]: 匹配结果。
+ """
+
+ for supervisor in self._iter_supervisors():
+ match_result = supervisor.component_registry.find_command_by_text(text)
+ if match_result is None:
+ continue
+
+ entry, matched_groups = match_result
+ command_info = self._build_command_info(entry) # type: ignore[arg-type]
+ command_executor = self._build_command_executor(
+ supervisor,
+ entry.plugin_id,
+ entry.name,
+ dict(entry.metadata),
+ )
+ return command_executor, matched_groups, command_info
+ return None
+
+ def get_tool_info(self, name: str) -> Optional[ToolInfo]:
+ """获取指定工具的信息。
+
+ Args:
+ name: 工具名称。
+
+ Returns:
+ Optional[ToolInfo]: 匹配到的工具信息。
+ """
+
+ matched_entry = self._get_unique_component_entry(ComponentType.TOOL, name)
+ if matched_entry is None:
+ return None
+ _supervisor, entry = matched_entry
+ return self._build_tool_info(entry) # type: ignore[arg-type]
+
+ def get_tool_executor(self, name: str) -> Optional[ToolExecutor]:
+ """获取指定工具的执行器。
+
+ Args:
+ name: 工具名称。
+
+ Returns:
+ Optional[ToolExecutor]: 运行时 RPC 执行闭包。
+ """
+
+ matched_entry = self._get_unique_component_entry(ComponentType.TOOL, name)
+ if matched_entry is None:
+ return None
+ supervisor, entry = matched_entry
+ return self._build_tool_executor(supervisor, entry.plugin_id, entry.name)
+
+ def get_llm_available_tools(self) -> Dict[str, ToolInfo]:
+ """获取当前可供 LLM 选择的工具集合。
+
+ Returns:
+ Dict[str, ToolInfo]: 工具名到工具信息的映射。
+ """
+
+ tool_infos = self._collect_unique_component_infos(ComponentType.TOOL)
+ return {name: info for name, info in tool_infos.items() if isinstance(info, ToolInfo) and info.enabled}
+
+ def get_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
+ """获取某类组件的全部信息。
+
+ Args:
+ component_type: 组件类型。
+
+ Returns:
+ Dict[str, ComponentInfo]: 组件名到组件信息的映射。
+ """
+
+ return self._collect_unique_component_infos(component_type)
+
+ def get_plugin_config(self, plugin_name: str) -> Optional[dict]:
+ """读取指定插件的配置文件内容。
+
+ Args:
+ plugin_name: 插件名称。
+
+ Returns:
+ Optional[dict]: 读取成功时返回配置字典;未找到时返回 ``None``。
+ """
+
+ runtime_manager = self._get_runtime_manager()
+ try:
+ supervisor = runtime_manager._get_supervisor_for_plugin(plugin_name)
+ except RuntimeError as exc:
+ logger.error(f"读取插件配置失败: {exc}")
+ return None
+
+ if supervisor is None:
+ return None
+
+ try:
+ return runtime_manager._load_plugin_config_for_supervisor(supervisor, plugin_name)
+ except Exception as exc:
+ logger.error(f"读取插件 {plugin_name} 配置失败: {exc}", exc_info=True)
+ return None
+
+
+component_query_service = ComponentQueryService()
diff --git a/src/plugin_runtime/host/api_registry.py b/src/plugin_runtime/host/api_registry.py
new file mode 100644
index 00000000..1cbc05f6
--- /dev/null
+++ b/src/plugin_runtime/host/api_registry.py
@@ -0,0 +1,349 @@
+"""Host 侧插件 API 动态注册表。"""
+
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Set, Tuple
+
+from src.common.logger import get_logger
+
+logger = get_logger("plugin_runtime.host.api_registry")
+
+
+@dataclass(slots=True)
+class APIEntry:
+ """API 组件条目。"""
+
+ name: str
+ plugin_id: str
+ description: str = ""
+ version: str = "1"
+ public: bool = False
+ metadata: Dict[str, Any] = field(default_factory=dict)
+ enabled: bool = True
+ handler_name: str = ""
+ dynamic: bool = False
+ offline_reason: str = ""
+ disabled_session: Set[str] = field(default_factory=set)
+ full_name: str = field(init=False)
+ registry_key: str = field(init=False)
+
+ def __post_init__(self) -> None:
+ """规范化 API 条目字段。"""
+
+ self.name = str(self.name or "").strip()
+ self.plugin_id = str(self.plugin_id or "").strip()
+ self.description = str(self.description or "").strip()
+ self.version = str(self.version or "1").strip() or "1"
+ self.handler_name = str(self.handler_name or self.name).strip() or self.name
+ self.offline_reason = str(self.offline_reason or "").strip()
+ self.full_name = f"{self.plugin_id}.{self.name}"
+ self.registry_key = APIRegistry.build_registry_key(self.plugin_id, self.name, self.version)
+
+ @classmethod
+ def from_metadata(cls, name: str, plugin_id: str, metadata: Dict[str, Any]) -> "APIEntry":
+ """根据 Runner 上报的元数据构造 API 条目。"""
+
+ safe_metadata = dict(metadata)
+ return cls(
+ name=name,
+ plugin_id=plugin_id,
+ description=str(safe_metadata.get("description", "") or ""),
+ version=str(safe_metadata.get("version", "1") or "1"),
+ public=bool(safe_metadata.get("public", False)),
+ metadata=safe_metadata,
+ enabled=bool(safe_metadata.get("enabled", True)),
+ handler_name=str(safe_metadata.get("handler_name", name) or name),
+ dynamic=bool(safe_metadata.get("dynamic", False)),
+ offline_reason=str(safe_metadata.get("offline_reason", "") or ""),
+ )
+
+
+class APIRegistry:
+ """Host 侧插件 API 动态注册表。
+
+ 该注册表不直接面向 Runner,而是复用插件组件注册/卸载事件,
+ 维护面向 API 调用场景的专用索引。
+ """
+
+ def __init__(self) -> None:
+ """初始化 API 注册表。"""
+
+ self._apis: Dict[str, APIEntry] = {}
+ self._by_full_name: Dict[str, List[APIEntry]] = {}
+ self._by_plugin: Dict[str, List[APIEntry]] = {}
+ self._by_name: Dict[str, List[APIEntry]] = {}
+
+ def clear(self) -> None:
+ """清空全部 API 注册状态。"""
+
+ self._apis.clear()
+ self._by_full_name.clear()
+ self._by_plugin.clear()
+ self._by_name.clear()
+
+ @staticmethod
+ def _is_api_component(component_type: Any) -> bool:
+ """判断组件声明是否属于 API。"""
+
+ return str(component_type or "").strip().upper() == "API"
+
+ @staticmethod
+ def _normalize_query_version(version: Any) -> str:
+ """规范化查询使用的版本字符串。"""
+
+ return str(version or "").strip()
+
+ @classmethod
+ def _split_reference(cls, reference: str, version: Any = "") -> Tuple[str, str]:
+ """解析可能带 ``@version`` 后缀的 API 引用。"""
+
+ normalized_reference = str(reference or "").strip()
+ normalized_version = cls._normalize_query_version(version)
+ if normalized_reference and not normalized_version and "@" in normalized_reference:
+ candidate_reference, candidate_version = normalized_reference.rsplit("@", 1)
+ candidate_reference = candidate_reference.strip()
+ candidate_version = candidate_version.strip()
+ if candidate_reference and candidate_version:
+ normalized_reference = candidate_reference
+ normalized_version = candidate_version
+ return normalized_reference, normalized_version
+
+ @staticmethod
+ def build_registry_key(plugin_id: str, name: str, version: str) -> str:
+ """构造 API 注册表唯一键。"""
+
+ normalized_full_name = f"{str(plugin_id or '').strip()}.{str(name or '').strip()}"
+ normalized_version = str(version or "1").strip() or "1"
+ return f"{normalized_full_name}@{normalized_version}"
+
+ @staticmethod
+ def check_api_enabled(entry: APIEntry, session_id: Optional[str] = None) -> bool:
+ """判断 API 条目当前是否处于启用状态。"""
+
+ if session_id and session_id in entry.disabled_session:
+ return False
+ return entry.enabled
+
+ def register_api(self, name: str, plugin_id: str, metadata: Dict[str, Any]) -> bool:
+ """注册单个 API 条目。"""
+
+ normalized_name = str(name or "").strip()
+ if not normalized_name:
+ logger.warning(f"插件 {plugin_id} 存在空 API 名称声明,已忽略")
+ return False
+
+ entry = APIEntry.from_metadata(name=normalized_name, plugin_id=plugin_id, metadata=metadata)
+ existing_entry = self._apis.get(entry.registry_key)
+ if existing_entry is not None:
+ logger.warning(f"API {entry.registry_key} 已存在,覆盖旧条目")
+ self._remove_entry(existing_entry)
+
+ self._apis[entry.registry_key] = entry
+ self._by_full_name.setdefault(entry.full_name, []).append(entry)
+ self._by_plugin.setdefault(plugin_id, []).append(entry)
+ self._by_name.setdefault(entry.name, []).append(entry)
+ return True
+
+ def register_plugin_apis(self, plugin_id: str, components: List[Dict[str, Any]]) -> int:
+ """批量注册某个插件声明的全部 API。"""
+
+ count = 0
+ for component in components:
+ if not self._is_api_component(component.get("component_type")):
+ continue
+ if self.register_api(
+ name=str(component.get("name", "") or ""),
+ plugin_id=plugin_id,
+ metadata=component.get("metadata", {}) if isinstance(component.get("metadata"), dict) else {},
+ ):
+ count += 1
+ return count
+
+ def replace_plugin_dynamic_apis(
+ self,
+ plugin_id: str,
+ components: List[Dict[str, Any]],
+ *,
+ offline_reason: str = "动态 API 已下线",
+ ) -> Tuple[int, int]:
+ """替换指定插件当前声明的动态 API 集合。"""
+
+ normalized_offline_reason = str(offline_reason or "").strip() or "动态 API 已下线"
+ desired_registry_keys: Set[str] = set()
+ registered_count = 0
+
+ for component in components:
+ if not self._is_api_component(component.get("component_type")):
+ continue
+ metadata = component.get("metadata", {}) if isinstance(component.get("metadata"), dict) else {}
+ dynamic_metadata = dict(metadata)
+ dynamic_metadata["dynamic"] = True
+ dynamic_metadata.pop("offline_reason", None)
+
+ entry = APIEntry.from_metadata(
+ name=str(component.get("name", "") or ""),
+ plugin_id=plugin_id,
+ metadata=dynamic_metadata,
+ )
+ desired_registry_keys.add(entry.registry_key)
+ if self.register_api(entry.name, plugin_id, dynamic_metadata):
+ registered_count += 1
+
+ offlined_count = 0
+ for entry in list(self._by_plugin.get(plugin_id, [])):
+ if not entry.dynamic or entry.registry_key in desired_registry_keys:
+ continue
+ entry.enabled = False
+ entry.offline_reason = normalized_offline_reason
+ entry.metadata["offline_reason"] = normalized_offline_reason
+ offlined_count += 1
+
+ return registered_count, offlined_count
+
+ def _remove_entry(self, entry: APIEntry) -> None:
+ """从全部索引中移除单个 API 条目。"""
+
+ self._apis.pop(entry.registry_key, None)
+
+ full_name_entries = self._by_full_name.get(entry.full_name)
+ if full_name_entries is not None:
+ self._by_full_name[entry.full_name] = [
+ candidate for candidate in full_name_entries if candidate is not entry
+ ]
+ if not self._by_full_name[entry.full_name]:
+ self._by_full_name.pop(entry.full_name, None)
+
+ plugin_entries = self._by_plugin.get(entry.plugin_id)
+ if plugin_entries is not None:
+ self._by_plugin[entry.plugin_id] = [candidate for candidate in plugin_entries if candidate is not entry]
+ if not self._by_plugin[entry.plugin_id]:
+ self._by_plugin.pop(entry.plugin_id, None)
+
+ name_entries = self._by_name.get(entry.name)
+ if name_entries is not None:
+ self._by_name[entry.name] = [candidate for candidate in name_entries if candidate is not entry]
+ if not self._by_name[entry.name]:
+ self._by_name.pop(entry.name, None)
+
+ def remove_apis_by_plugin(self, plugin_id: str) -> int:
+ """移除某个插件的全部 API。"""
+
+ entries = list(self._by_plugin.get(plugin_id, []))
+ for entry in entries:
+ self._remove_entry(entry)
+ return len(entries)
+
+ def get_api_by_full_name(
+ self,
+ full_name: str,
+ *,
+ version: str = "",
+ enabled_only: bool = True,
+ session_id: Optional[str] = None,
+ ) -> Optional[APIEntry]:
+ """按完整名查询单个 API。"""
+
+ normalized_full_name, normalized_version = self._split_reference(full_name, version)
+ if not normalized_full_name:
+ return None
+
+ if normalized_version:
+ entry = self._apis.get(f"{normalized_full_name}@{normalized_version}")
+ if entry is None:
+ return None
+ if enabled_only and not self.check_api_enabled(entry, session_id):
+ return None
+ return entry
+
+ candidates = list(self._by_full_name.get(normalized_full_name, []))
+ filtered_entries = [
+ entry
+ for entry in candidates
+ if not enabled_only or self.check_api_enabled(entry, session_id)
+ ]
+ if len(filtered_entries) != 1:
+ return None
+ return filtered_entries[0]
+
+ def get_api(
+ self,
+ plugin_id: str,
+ name: str,
+ *,
+ version: str = "",
+ enabled_only: bool = True,
+ session_id: Optional[str] = None,
+ ) -> Optional[APIEntry]:
+ """按插件 ID、短名与版本查询单个 API。"""
+
+ return self.get_api_by_full_name(
+ f"{plugin_id}.{name}",
+ version=version,
+ enabled_only=enabled_only,
+ session_id=session_id,
+ )
+
+ def get_apis(
+ self,
+ *,
+ plugin_id: Optional[str] = None,
+ name: str = "",
+ version: str = "",
+ enabled_only: bool = True,
+ session_id: Optional[str] = None,
+ ) -> List[APIEntry]:
+ """查询 API 列表。"""
+
+ normalized_name = str(name or "").strip()
+ normalized_version = self._normalize_query_version(version)
+
+ if plugin_id:
+ candidates = list(self._by_plugin.get(plugin_id, []))
+ elif normalized_name:
+ candidates = list(self._by_name.get(normalized_name, []))
+ else:
+ candidates = list(self._apis.values())
+
+ filtered_entries: List[APIEntry] = []
+ for entry in candidates:
+ if plugin_id and entry.plugin_id != plugin_id:
+ continue
+ if normalized_name and entry.name != normalized_name:
+ continue
+ if normalized_version and entry.version != normalized_version:
+ continue
+ if enabled_only and not self.check_api_enabled(entry, session_id):
+ continue
+ filtered_entries.append(entry)
+
+ filtered_entries.sort(key=lambda entry: (entry.plugin_id, entry.name, entry.version))
+ return filtered_entries
+
+ def toggle_api_status(
+ self,
+ full_name: str,
+ enabled: bool,
+ *,
+ version: str = "",
+ session_id: Optional[str] = None,
+ ) -> bool:
+ """设置指定 API 的启用状态。"""
+
+ entry = self.get_api_by_full_name(
+ full_name,
+ version=version,
+ enabled_only=False,
+ session_id=session_id,
+ )
+ if entry is None:
+ return False
+ if session_id:
+ if enabled:
+ entry.disabled_session.discard(session_id)
+ else:
+ entry.disabled_session.add(session_id)
+ else:
+ entry.enabled = enabled
+ if enabled:
+ entry.offline_reason = ""
+ entry.metadata.pop("offline_reason", None)
+ return True
diff --git a/src/plugin_runtime/host/authorization.py b/src/plugin_runtime/host/authorization.py
new file mode 100644
index 00000000..70593768
--- /dev/null
+++ b/src/plugin_runtime/host/authorization.py
@@ -0,0 +1,67 @@
+"""授权管理器
+
+负责管理插件的能力授权以及校验
+每个插件在 manifest 中声明能力需求,Host 启动时签发能力令牌。
+"""
+
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Set, Tuple
+
+_ALWAYS_ALLOWED_CAPABILITIES = frozenset({"api.replace_dynamic"})
+
+
+@dataclass
+class CapabilityPermissionToken:
+ """能力令牌"""
+
+ plugin_id: str
+ capabilities: Set[str] = field(default_factory=set)
+
+
+class AuthorizationManager:
+ """授权管理器
+
+ 管理所有插件的能力令牌,提供授权校验。
+ """
+
+ def __init__(self) -> None:
+ self._permission_tokens: Dict[str, CapabilityPermissionToken] = {}
+
+ def register_plugin(self, plugin_id: str, capabilities: List[str]) -> CapabilityPermissionToken:
+ """为插件签发能力令牌"""
+ token = CapabilityPermissionToken(plugin_id=plugin_id, capabilities=set(capabilities))
+ self._permission_tokens[plugin_id] = token
+ return token
+
+ def revoke_permission_token(self, plugin_id: str):
+ """移除插件的能力令牌。"""
+ self._permission_tokens.pop(plugin_id, None)
+
+ def clear(self) -> None:
+ """清空所有能力令牌。"""
+ self._permission_tokens.clear()
+
+ def check_capability(self, plugin_id: str, capability: str) -> Tuple[bool, str]:
+ # sourcery skip: assign-if-exp, reintroduce-else, swap-if-else-branches, use-named-expression
+ """检查插件是否有权调用某项能力
+
+ Returns:
+ return (bool, str): (是否有此能力, 原因)
+ """
+ if capability in _ALWAYS_ALLOWED_CAPABILITIES:
+ return True, ""
+
+ token = self._permission_tokens.get(plugin_id)
+ if not token:
+ return False, f"插件 {plugin_id} 未注册能力令牌"
+ if capability not in token.capabilities:
+ return False, f"插件 {plugin_id} 未获授权能力: {capability}"
+ return True, ""
+
+ def get_token(self, plugin_id: str) -> Optional[CapabilityPermissionToken]:
+ """获取插件的能力令牌"""
+ return self._permission_tokens.get(plugin_id)
+
+ def list_plugins(self) -> List[str]:
+ """列出所有已注册的插件"""
+ return list(self._permission_tokens.keys())
diff --git a/src/plugin_runtime/host/capability_service.py b/src/plugin_runtime/host/capability_service.py
index 6685ff60..0ff31fe1 100644
--- a/src/plugin_runtime/host/capability_service.py
+++ b/src/plugin_runtime/host/capability_service.py
@@ -4,21 +4,19 @@ Host 端实现的能力服务,处理来自插件的 cap.* 请求。
每个能力方法被注册到 RPC Server,接收 Runner 转发的请求并执行实际操作。
"""
-from typing import Any, Awaitable, Callable, Dict, List
+from typing import Any, Callable, Dict, List, Coroutine, TYPE_CHECKING
from src.common.logger import get_logger
-from src.plugin_runtime.host.policy_engine import PolicyEngine
-from src.plugin_runtime.protocol.envelope import (
- CapabilityRequestPayload,
- CapabilityResponsePayload,
- Envelope,
-)
+from src.plugin_runtime.protocol.envelope import CapabilityRequestPayload, CapabilityResponsePayload, Envelope
from src.plugin_runtime.protocol.errors import ErrorCode, RPCError
+if TYPE_CHECKING:
+ from src.plugin_runtime.host.authorization import AuthorizationManager
+
logger = get_logger("plugin_runtime.host.capability_service")
# 能力实现函数类型: (plugin_id, capability, args) -> result
-CapabilityImpl = Callable[[str, str, Dict[str, Any]], Awaitable[Any]]
+CapabilityImpl = Callable[[str, str, Dict[str, Any]], Coroutine[Any, Any, Any]]
class CapabilityService:
@@ -31,8 +29,13 @@ class CapabilityService:
4. 执行实际操作并返回结果
"""
- def __init__(self, policy_engine: PolicyEngine) -> None:
- self._policy = policy_engine
+ def __init__(self, authorization: "AuthorizationManager") -> None:
+ """初始化能力服务。
+
+ Args:
+ authorization: 能力授权管理器。
+ """
+ self._authorization = authorization
# capability_name -> implementation
self._implementations: Dict[str, CapabilityImpl] = {}
@@ -56,46 +59,32 @@ class CapabilityService:
try:
req = CapabilityRequestPayload.model_validate(envelope.payload)
- except Exception as e:
- return envelope.make_error_response(
- ErrorCode.E_BAD_PAYLOAD.value,
- f"能力调用 payload 格式错误: {e}",
- )
+ except Exception as exc:
+ return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, f"能力调用 payload 非法: {exc}")
capability = req.capability
+ args = req.args
# 1. 权限校验
- allowed, reason = self._policy.check_capability(plugin_id, capability, envelope.generation)
+ allowed, reason = self._authorization.check_capability(plugin_id, capability)
if not allowed:
- error_code = (
- ErrorCode.E_GENERATION_MISMATCH if "generation 不匹配" in reason else ErrorCode.E_CAPABILITY_DENIED
- )
- return envelope.make_error_response(
- error_code.value,
- reason,
- )
+ return envelope.make_error_response(ErrorCode.E_CAPABILITY_DENIED.value, reason)
# 2. 查找实现
impl = self._implementations.get(capability)
if impl is None:
- return envelope.make_error_response(
- ErrorCode.E_METHOD_NOT_ALLOWED.value,
- f"未注册的能力: {capability}",
- )
+ return envelope.make_error_response(ErrorCode.E_METHOD_NOT_ALLOWED.value, f"未注册的能力: {capability}")
# 3. 执行
try:
- result = await impl(plugin_id, capability, req.args)
+ result = await impl(plugin_id, capability, args)
resp_payload = CapabilityResponsePayload(success=True, result=result)
return envelope.make_response(payload=resp_payload.model_dump())
except RPCError as e:
return envelope.make_error_response(e.code.value, e.message, e.details)
except Exception as e:
logger.error(f"能力 {capability} 执行异常: {e}", exc_info=True)
- return envelope.make_error_response(
- ErrorCode.E_CAPABILITY_FAILED.value,
- str(e),
- )
+ return envelope.make_error_response(ErrorCode.E_CAPABILITY_FAILED.value, str(e))
def list_capabilities(self) -> List[str]:
"""列出所有已注册的能力"""
diff --git a/src/plugin_runtime/host/component_registry.py b/src/plugin_runtime/host/component_registry.py
index 220a19c0..97fdca30 100644
--- a/src/plugin_runtime/host/component_registry.py
+++ b/src/plugin_runtime/host/component_registry.py
@@ -1,7 +1,7 @@
"""Host-side ComponentRegistry
对齐旧系统 component_registry.py 的核心能力:
-- 按类型注册组件(action / command / tool / event_handler / workflow_step)
+- 按类型注册组件(action / command / tool / event_handler / workflow_handler / message_gateway)
- 命名空间 (plugin_id.component_name)
- 命令正则匹配
- 组件启用/禁用
@@ -9,8 +9,10 @@
- 注册统计
"""
-from typing import Any, Dict, List, Optional
+from enum import Enum
+from typing import Any, Dict, List, Optional, Set, TypedDict, Tuple
+import contextlib
import re
from src.common.logger import get_logger
@@ -18,8 +20,28 @@ from src.common.logger import get_logger
logger = get_logger("plugin_runtime.host.component_registry")
-class RegisteredComponent:
- """已注册的组件条目"""
+class ComponentTypes(str, Enum):
+ ACTION = "ACTION"
+ COMMAND = "COMMAND"
+ TOOL = "TOOL"
+ EVENT_HANDLER = "EVENT_HANDLER"
+ HOOK_HANDLER = "HOOK_HANDLER"
+ MESSAGE_GATEWAY = "MESSAGE_GATEWAY"
+
+
+class StatusDict(TypedDict):
+ total: int
+ action: int
+ command: int
+ tool: int
+ event_handler: int
+ hook_handler: int
+ message_gateway: int
+ plugins: int
+
+
+class ComponentEntry:
+ """组件条目"""
__slots__ = (
"name",
@@ -28,31 +50,120 @@ class RegisteredComponent:
"plugin_id",
"metadata",
"enabled",
- "_compiled_pattern",
+ "compiled_pattern",
+ "disabled_session",
)
- def __init__(
- self,
- name: str,
- component_type: str,
- plugin_id: str,
- metadata: Dict[str, Any],
- ) -> None:
- self.name = name
- self.full_name = f"{plugin_id}.{name}"
- self.component_type = component_type
- self.plugin_id = plugin_id
- self.metadata = metadata
- self.enabled = metadata.get("enabled", True)
+ def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
+ self.name: str = name
+ self.full_name: str = f"{plugin_id}.{name}"
+ self.component_type: ComponentTypes = ComponentTypes(component_type)
+ self.plugin_id: str = plugin_id
+ self.metadata: Dict[str, Any] = metadata
+ self.enabled: bool = metadata.get("enabled", True)
+ self.disabled_session: Set[str] = set()
- # 预编译命令正则(仅 command 类型)
- self._compiled_pattern: Optional[re.Pattern] = None
- if component_type == "command":
- if pattern := metadata.get("command_pattern", ""):
- try:
- self._compiled_pattern = re.compile(pattern)
- except re.error as e:
- logger.warning(f"命令 {self.full_name} 正则编译失败: {e}")
+
+class ActionEntry(ComponentEntry):
+ """Action 组件条目"""
+
+ def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
+ super().__init__(name, component_type, plugin_id, metadata)
+
+
+class CommandEntry(ComponentEntry):
+ """Command 组件条目"""
+
+ def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
+ super().__init__(name, component_type, plugin_id, metadata)
+ self.aliases: List[str] = metadata.get("aliases", [])
+ self.compiled_pattern: Optional[re.Pattern] = None
+ if pattern := metadata.get("command_pattern", ""):
+ try:
+ self.compiled_pattern = re.compile(pattern)
+ except (re.error, TypeError) as e:
+ logger.warning(f"命令 {self.full_name} 正则编译失败: {e}")
+
+
+class ToolEntry(ComponentEntry):
+ """Tool 组件条目"""
+
+ def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
+ self.description: str = metadata.get("description", "")
+ self.parameters: List[Dict[str, Any]] = metadata.get("parameters", [])
+ self.parameters_raw: List[Dict[str, Any]] = metadata.get("parameters_raw", [])
+ super().__init__(name, component_type, plugin_id, metadata)
+
+
+class EventHandlerEntry(ComponentEntry):
+ """EventHandler 组件条目"""
+
+ def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
+ self.event_type: str = metadata.get("event_type", "")
+ self.weight: int = metadata.get("weight", 0)
+ self.intercept_message: bool = metadata.get("intercept_message", False)
+ super().__init__(name, component_type, plugin_id, metadata)
+
+
+class HookHandlerEntry(ComponentEntry):
+ """WorkflowHandler 组件条目"""
+
+ def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
+ self.stage: str = metadata.get("stage", "")
+ self.priority: int = metadata.get("priority", 0)
+ self.blocking: bool = metadata.get("blocking", False)
+ super().__init__(name, component_type, plugin_id, metadata)
+
+
+class MessageGatewayEntry(ComponentEntry):
+ """MessageGateway 组件条目"""
+
+ def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
+ self.route_type: str = self._normalize_route_type(metadata.get("route_type", ""))
+ self.platform: str = str(metadata.get("platform", "") or "").strip()
+ self.protocol: str = str(metadata.get("protocol", "") or "").strip()
+ self.account_id: str = str(metadata.get("account_id", "") or "").strip()
+ self.scope: str = str(metadata.get("scope", "") or "").strip()
+ super().__init__(name, component_type, plugin_id, metadata)
+
+ @staticmethod
+ def _normalize_route_type(raw_value: Any) -> str:
+ """规范化消息网关路由类型。
+
+ Args:
+ raw_value: 原始路由类型值。
+
+ Returns:
+ str: 规范化后的路由类型。
+
+ Raises:
+ ValueError: 当路由类型不受支持时抛出。
+ """
+
+ normalized_value = str(raw_value or "").strip().lower()
+ route_type_aliases = {
+ "send": "send",
+ "receive": "receive",
+ "recv": "receive",
+ "recive": "receive",
+ "duplex": "duplex",
+ }
+ route_type = route_type_aliases.get(normalized_value)
+ if route_type is None:
+ raise ValueError(f"MessageGateway 路由类型不合法: {raw_value}")
+ return route_type
+
+ @property
+ def supports_send(self) -> bool:
+ """返回当前网关是否支持出站。"""
+
+ return self.route_type in {"send", "duplex"}
+
+ @property
+ def supports_receive(self) -> bool:
+ """返回当前网关是否支持入站。"""
+
+ return self.route_type in {"receive", "duplex"}
class ComponentRegistry:
@@ -64,19 +175,32 @@ class ComponentRegistry:
def __init__(self) -> None:
# 全量索引
- self._components: Dict[str, RegisteredComponent] = {} # full_name -> comp
+ self._components: Dict[str, ComponentEntry] = {} # full_name -> comp
# 按类型索引
- self._by_type: Dict[str, Dict[str, RegisteredComponent]] = {
- "action": {},
- "command": {},
- "tool": {},
- "event_handler": {},
- "workflow_step": {},
- }
+ self._by_type: Dict[ComponentTypes, Dict[str, ComponentEntry]] = {
+ comp_type: {} for comp_type in ComponentTypes
+ } # component_type -> (full_name -> comp)
# 按插件索引
- self._by_plugin: Dict[str, List[RegisteredComponent]] = {}
+ self._by_plugin: Dict[str, List[ComponentEntry]] = {}
+
+ @staticmethod
+ def _normalize_component_type(component_type: str) -> ComponentTypes:
+ """规范化组件类型输入。
+
+ Args:
+ component_type: 原始组件类型字符串。
+
+ Returns:
+ ComponentTypes: 规范化后的组件类型枚举。
+
+ Raises:
+ ValueError: 当组件类型不受支持时抛出。
+ """
+
+ normalized_value = str(component_type or "").strip().upper()
+ return ComponentTypes(normalized_value)
def clear(self) -> None:
"""清空全部组件注册状态。"""
@@ -85,47 +209,64 @@ class ComponentRegistry:
type_dict.clear()
self._by_plugin.clear()
- # ──── 注册 / 注销 ─────────────────────────────────────────
+ # ====== 注册 / 注销 ======
+ def register_component(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> bool:
+ """注册单个组件
+
+ Args:
+ name: 组件名称(不含插件id前缀)
+ component_type: 组件类型(如 `ACTION`、`COMMAND` 等)
+ plugin_id: 插件id
+ metadata: 组件元数据
+ Returns:
+ success (bool): 是否成功注册(失败原因通常是组件类型无效)
+ """
+ try:
+ normalized_type = self._normalize_component_type(component_type)
+ if normalized_type == ComponentTypes.ACTION:
+ comp = ActionEntry(name, normalized_type.value, plugin_id, metadata)
+ elif normalized_type == ComponentTypes.COMMAND:
+ comp = CommandEntry(name, normalized_type.value, plugin_id, metadata)
+ elif normalized_type == ComponentTypes.TOOL:
+ comp = ToolEntry(name, normalized_type.value, plugin_id, metadata)
+ elif normalized_type == ComponentTypes.EVENT_HANDLER:
+ comp = EventHandlerEntry(name, normalized_type.value, plugin_id, metadata)
+ elif normalized_type == ComponentTypes.HOOK_HANDLER:
+ comp = HookHandlerEntry(name, normalized_type.value, plugin_id, metadata)
+ elif normalized_type == ComponentTypes.MESSAGE_GATEWAY:
+ comp = MessageGatewayEntry(name, normalized_type.value, plugin_id, metadata)
+ else:
+ raise ValueError(f"组件类型 {component_type} 不存在")
+ except ValueError:
+ logger.error(f"组件类型 {component_type} 不存在")
+ return False
- def register_component(
- self,
- name: str,
- component_type: str,
- plugin_id: str,
- metadata: Dict[str, Any],
- ) -> bool:
- """注册单个组件。"""
- comp = RegisteredComponent(name, component_type, plugin_id, metadata)
if comp.full_name in self._components:
logger.warning(f"组件 {comp.full_name} 已存在,覆盖")
old_comp = self._components[comp.full_name]
# 从 _by_plugin 列表中移除旧条目,防止幽灵组件堆积
old_list = self._by_plugin.get(old_comp.plugin_id)
if old_list is not None:
- try:
+ with contextlib.suppress(ValueError):
old_list.remove(old_comp)
- except ValueError:
- pass
# 从旧类型索引中移除,防止类型变更时幽灵残留
if old_type_dict := self._by_type.get(old_comp.component_type):
old_type_dict.pop(comp.full_name, None)
self._components[comp.full_name] = comp
-
- if component_type not in self._by_type:
- self._by_type[component_type] = {}
- self._by_type[component_type][comp.full_name] = comp
-
+ self._by_type[comp.component_type][comp.full_name] = comp
self._by_plugin.setdefault(plugin_id, []).append(comp)
return True
- def register_plugin_components(
- self,
- plugin_id: str,
- components: List[Dict[str, Any]],
- ) -> int:
- """批量注册一个插件的所有组件,返回成功注册数。"""
+ def register_plugin_components(self, plugin_id: str, components: List[Dict[str, Any]]) -> int:
+ """批量注册一个插件的所有组件,返回成功注册数。
+ Args:
+ plugin_id (str): 插件id
+ components (List[Dict[str, Any]]): 组件字典列表,每个组件包含 name, component_type, metadata 等字段
+ Returns:
+ count (int): 成功注册的组件数量
+ """
count = 0
for comp_data in components:
ok = self.register_component(
@@ -139,7 +280,13 @@ class ComponentRegistry:
return count
def remove_components_by_plugin(self, plugin_id: str) -> int:
- """移除某个插件的所有组件,返回移除数量。"""
+ """移除某个插件的所有组件,返回移除数量。
+
+ Args:
+ plugin_id (str): 插件id
+ Returns:
+ count (int): 移除的组件数量
+ """
comps = self._by_plugin.pop(plugin_id, [])
for comp in comps:
self._components.pop(comp.full_name, None)
@@ -147,106 +294,280 @@ class ComponentRegistry:
type_dict.pop(comp.full_name, None)
return len(comps)
- # ──── 启用 / 禁用 ─────────────────────────────────────────
+ # ====== 启用 / 禁用 ======
+ def check_component_enabled(self, component: ComponentEntry, session_id: Optional[str] = None):
+ if session_id and session_id in component.disabled_session:
+ return False
+ return component.enabled
- def set_component_enabled(self, full_name: str, enabled: bool) -> bool:
- """启用或禁用指定组件。"""
+ def toggle_component_status(self, full_name: str, enabled: bool, session_id: Optional[str] = None) -> bool:
+ """启用或禁用指定组件。
+
+ Args:
+ full_name (str): 组件全名
+ enabled (bool): 使能情况
+ session_id (Optional[str]): 可选的会话ID,仅对该会话禁用(如果提供)
+ Returns:
+ success (bool): 是否成功设置(失败原因通常是组件不存在)
+ """
comp = self._components.get(full_name)
if comp is None:
return False
- comp.enabled = enabled
+ if session_id:
+ if enabled:
+ comp.disabled_session.discard(session_id)
+ else:
+ comp.disabled_session.add(session_id)
+ else:
+ comp.enabled = enabled
return True
- def set_plugin_enabled(self, plugin_id: str, enabled: bool) -> int:
- """批量启用或禁用某插件的所有组件。"""
+ def set_component_enabled(self, full_name: str, enabled: bool, session_id: Optional[str] = None) -> bool:
+ """设置指定组件的启用状态。
+
+ Args:
+ full_name: 组件全名。
+ enabled: 目标启用状态。
+ session_id: 可选的会话 ID,仅对该会话生效。
+
+ Returns:
+ bool: 是否设置成功。
+ """
+
+ return self.toggle_component_status(full_name, enabled, session_id=session_id)
+
+ def toggle_plugin_status(self, plugin_id: str, enabled: bool, session_id: Optional[str] = None) -> int:
+ """批量启用或禁用某插件的所有组件。
+
+ Args:
+ plugin_id (str): 插件id
+ enabled (bool): 使能情况
+ session_id (Optional[str]): 可选的会话ID,仅对该会话禁用(如果提供)
+ Returns:
+ count (int): 成功设置的组件数量(失败原因通常是插件不存在)
+ """
comps = self._by_plugin.get(plugin_id, [])
for comp in comps:
- comp.enabled = enabled
+ if session_id:
+ if enabled:
+ comp.disabled_session.discard(session_id)
+ else:
+ comp.disabled_session.add(session_id)
+ else:
+ comp.enabled = enabled
return len(comps)
- # ──── 查询方法 ─────────────────────────────────────────────
+ def get_component(self, full_name: str) -> Optional[ComponentEntry]:
+ """按全名查询。
- def get_component(self, full_name: str) -> Optional[RegisteredComponent]:
- """按全名查询。"""
+ Args:
+ full_name (str): 组件全名
+ Returns:
+ component (Optional[ComponentEntry]): 组件条目,未找到时为 None
+ """
return self._components.get(full_name)
- def get_components_by_type(self, component_type: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
- """按类型查询。"""
- type_dict = self._by_type.get(component_type, {})
+ def get_components_by_type(
+ self, component_type: str, *, enabled_only: bool = True, session_id: Optional[str] = None
+ ) -> List[ComponentEntry]:
+ """按类型查询组件
+
+ Args:
+ component_type (str): 组件类型(如 `ACTION`、`COMMAND` 等)
+ enabled_only (bool): 是否仅返回启用的组件
+ session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态
+ Returns:
+ components (List[ComponentEntry]): 组件条目列表
+ """
+ try:
+ comp_type = self._normalize_component_type(component_type)
+ except ValueError:
+ logger.error(f"组件类型 {component_type} 不存在")
+ raise
+ type_dict = self._by_type.get(comp_type, {})
if enabled_only:
- return [c for c in type_dict.values() if c.enabled]
+ return [c for c in type_dict.values() if self.check_component_enabled(c, session_id)]
return list(type_dict.values())
- def get_components_by_plugin(self, plugin_id: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
- """按插件查询。"""
- comps = self._by_plugin.get(plugin_id, [])
- return [c for c in comps if c.enabled] if enabled_only else list(comps)
+ def get_components_by_plugin(
+ self, plugin_id: str, *, enabled_only: bool = True, session_id: Optional[str] = None
+ ) -> List[ComponentEntry]:
+ """按插件查询组件。
- def find_command_by_text(self, text: str) -> Optional[tuple[RegisteredComponent, Dict[str, Any]]]:
+ Args:
+ plugin_id (str): 插件ID
+ enabled_only (bool): 是否仅返回启用的组件
+ session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态
+ Returns:
+ components (List[ComponentEntry]): 组件条目列表
+ """
+ comps = self._by_plugin.get(plugin_id, [])
+ return [c for c in comps if self.check_component_enabled(c, session_id)] if enabled_only else list(comps)
+
+ def find_command_by_text(
+ self, text: str, session_id: Optional[str] = None
+ ) -> Optional[Tuple[ComponentEntry, Dict[str, Any]]]:
"""通过文本匹配命令正则,返回 (组件, matched_groups) 元组。
matched_groups 为正则命名捕获组 dict,别名匹配时为空 dict。
+ Args:
+ text (str): 待匹配文本
+ session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态
+ Returns:
+ result (Optional[tuple[ComponentEntry, Dict[str, Any]]]): 匹配到的组件及正则捕获组,未找到时为 None
"""
- for comp in self._by_type.get("command", {}).values():
- if not comp.enabled:
+ for comp in self._by_type.get(ComponentTypes.COMMAND, {}).values():
+ if not self.check_component_enabled(comp, session_id):
continue
- if comp._compiled_pattern:
- m = comp._compiled_pattern.search(text)
- if m:
+ if not isinstance(comp, CommandEntry):
+ continue
+ if comp.compiled_pattern:
+ if m := comp.compiled_pattern.search(text):
return comp, m.groupdict()
# 别名匹配
- aliases = comp.metadata.get("aliases", [])
- for alias in aliases:
+ for alias in comp.aliases:
if text.startswith(alias):
return comp, {}
return None
- def get_event_handlers(self, event_type: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
- """获取特定事件类型的所有 event_handler,按 weight 降序排列。"""
- handlers = []
- for comp in self._by_type.get("event_handler", {}).values():
- if enabled_only and not comp.enabled:
+ def get_event_handlers(
+ self, event_type: str, *, enabled_only: bool = True, session_id: Optional[str] = None
+ ) -> List[EventHandlerEntry]:
+ """查询指定事件类型的事件处理器组件。
+
+ Args:
+ event_type (str): 事件类型
+ enabled_only (bool): 是否仅返回启用的组件
+ session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态
+ Returns:
+ handlers (List[EventHandlerEntry]): 符合条件的 EventHandler 组件列表,按 weight 降序排序
+ """
+ handlers: List[EventHandlerEntry] = []
+ for comp in self._by_type.get(ComponentTypes.EVENT_HANDLER, {}).values():
+ if enabled_only and not self.check_component_enabled(comp, session_id):
continue
- if comp.metadata.get("event_type") == event_type:
+ if not isinstance(comp, EventHandlerEntry):
+ continue
+ if comp.event_type == event_type:
handlers.append(comp)
- handlers.sort(key=lambda c: c.metadata.get("weight", 0), reverse=True)
+ handlers.sort(key=lambda c: c.weight, reverse=True)
return handlers
- def get_workflow_steps(self, stage: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
- """获取特定 workflow 阶段的所有步骤,按 priority 降序。"""
- steps = []
- for comp in self._by_type.get("workflow_step", {}).values():
- if enabled_only and not comp.enabled:
+ def get_hook_handlers(
+ self, stage: str, *, enabled_only: bool = True, session_id: Optional[str] = None
+ ) -> List[HookHandlerEntry]:
+ """获取特定 hook 阶段的所有步骤,按 priority 降序。
+
+ Args:
+ stage: hook 名称
+ enabled_only: 是否仅返回启用的组件
+ session_id: 可选的会话ID,若提供则考虑会话禁用状态
+ Returns:
+ handlers (List[HookHandlerEntry]): 符合条件的 HookHandler 组件列表,按 priority 降序排序
+ """
+ handlers: List[HookHandlerEntry] = []
+ for comp in self._by_type.get(ComponentTypes.HOOK_HANDLER, {}).values():
+ if enabled_only and not self.check_component_enabled(comp, session_id):
continue
- if comp.metadata.get("stage") == stage:
- steps.append(comp)
- steps.sort(key=lambda c: c.metadata.get("priority", 0), reverse=True)
- return steps
+ if not isinstance(comp, HookHandlerEntry):
+ continue
+ if comp.stage == stage:
+ handlers.append(comp)
+ handlers.sort(key=lambda c: c.priority, reverse=True)
+ return handlers
- def get_tools_for_llm(self, *, enabled_only: bool = True) -> List[Dict[str, Any]]:
- """获取可供 LLM 使用的工具列表(openai function-calling 格式预览)。"""
- result: List[Dict[str, Any]] = []
- for comp in self.get_components_by_type("tool", enabled_only=enabled_only):
- tool_def: Dict[str, Any] = {
- "name": comp.full_name,
- "description": comp.metadata.get("description", ""),
- }
- # 从结构化参数或原始参数构建 parameters
- params = comp.metadata.get("parameters", [])
- params_raw = comp.metadata.get("parameters_raw", {})
- if params:
- tool_def["parameters"] = params
- elif params_raw:
- tool_def["parameters"] = params_raw
- result.append(tool_def)
- return result
+ def get_message_gateway(
+ self,
+ plugin_id: str,
+ name: str,
+ *,
+ enabled_only: bool = True,
+ session_id: Optional[str] = None,
+ ) -> Optional[MessageGatewayEntry]:
+ """按插件和组件名获取单个消息网关。
- # ──── 统计 ─────────────────────────────────────────────────
+ Args:
+ plugin_id: 插件 ID。
+ name: 网关组件名称。
+ enabled_only: 是否仅返回启用的组件。
+ session_id: 可选的会话 ID。
- def get_stats(self) -> Dict[str, int]:
- """获取注册统计。"""
- stats: Dict[str, int] = {"total": len(self._components)}
+ Returns:
+ Optional[MessageGatewayEntry]: 若存在则返回消息网关条目。
+ """
+
+ component = self._components.get(f"{plugin_id}.{name}")
+ if not isinstance(component, MessageGatewayEntry):
+ return None
+ if enabled_only and not self.check_component_enabled(component, session_id):
+ return None
+ return component
+
+ def get_message_gateways(
+ self,
+ *,
+ plugin_id: Optional[str] = None,
+ platform: str = "",
+ route_type: str = "",
+ enabled_only: bool = True,
+ session_id: Optional[str] = None,
+ ) -> List[MessageGatewayEntry]:
+ """查询消息网关组件列表。
+
+ Args:
+ plugin_id: 可选的插件 ID 过滤条件。
+ platform: 可选的平台过滤条件。
+ route_type: 可选的路由类型过滤条件。
+ enabled_only: 是否仅返回启用的组件。
+ session_id: 可选的会话 ID。
+
+ Returns:
+ List[MessageGatewayEntry]: 符合条件的消息网关组件列表。
+ """
+
+ normalized_platform = str(platform or "").strip()
+ normalized_route_type = str(route_type or "").strip().lower()
+ gateways: List[MessageGatewayEntry] = []
+ for comp in self._by_type.get(ComponentTypes.MESSAGE_GATEWAY, {}).values():
+ if not isinstance(comp, MessageGatewayEntry):
+ continue
+ if plugin_id and comp.plugin_id != plugin_id:
+ continue
+ if enabled_only and not self.check_component_enabled(comp, session_id):
+ continue
+ if normalized_platform and comp.platform != normalized_platform:
+ continue
+ if normalized_route_type and comp.route_type != normalized_route_type:
+ continue
+ gateways.append(comp)
+ return gateways
+
+ def get_tools(self, *, enabled_only: bool = True, session_id: Optional[str] = None) -> List[ToolEntry]:
+ """查询所有工具组件。
+
+ Args:
+ enabled_only (bool): 是否仅返回启用的组件
+ session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态
+ Returns:
+ tools (List[ToolEntry]): 符合条件的 Tool 组件列表
+ """
+ tools: List[ToolEntry] = []
+ for comp in self._by_type.get(ComponentTypes.TOOL, {}).values():
+ if enabled_only and not self.check_component_enabled(comp, session_id):
+ continue
+ if isinstance(comp, ToolEntry):
+ tools.append(comp)
+ return tools
+
+ # ====== 统计信息 ======
+ def get_stats(self) -> StatusDict:
+ """获取注册统计。
+
+ Returns:
+ stats (StatusDict): 组件统计信息,包括总数、各类型数量、插件数量等
+ """
+ stats: StatusDict = {"total": len(self._components)} # type: ignore
for comp_type, type_dict in self._by_type.items():
- stats[comp_type] = len(type_dict)
+ stats[comp_type.value.lower()] = len(type_dict)
stats["plugins"] = len(self._by_plugin)
return stats
diff --git a/src/plugin_runtime/host/event_dispatcher.py b/src/plugin_runtime/host/event_dispatcher.py
index 720e93d7..d252b6ee 100644
--- a/src/plugin_runtime/host/event_dispatcher.py
+++ b/src/plugin_runtime/host/event_dispatcher.py
@@ -4,40 +4,40 @@
1. 按事件类型查询已注册的 event_handler(通过 ComponentRegistry)
2. 按 weight 排序,依次通过 RPC 调用 Runner 中的处理器
3. 支持阻塞(intercept_message)和非阻塞分发
-4. 事件结果历史记录
+4. 事件结果历史记录(有上限)
"""
-from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple
+from dataclasses import dataclass, field
+from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING
import asyncio
from src.common.logger import get_logger
-from src.plugin_runtime.host.component_registry import ComponentRegistry, RegisteredComponent
+
+from .message_utils import PluginMessageUtils, MessageDict
+
+if TYPE_CHECKING:
+ from .supervisor import PluginRunnerSupervisor
+ from .component_registry import ComponentRegistry, EventHandlerEntry
+ from src.chat.message_receive.message import SessionMessage
logger = get_logger("plugin_runtime.host.event_dispatcher")
# invoke_fn 类型: async (plugin_id, component_name, args) -> response_payload dict
InvokeFn = Callable[[str, str, Dict[str, Any]], Awaitable[Dict[str, Any]]]
+# 每个事件类型的最大历史记录数量,防止内存无限增长
+_MAX_HISTORY_LENGTH = 100
+@dataclass
class EventResult:
"""单个 EventHandler 的执行结果"""
- __slots__ = ("handler_name", "success", "continue_processing", "modified_message", "custom_result")
-
- def __init__(
- self,
- handler_name: str,
- success: bool = True,
- continue_processing: bool = True,
- modified_message: Optional[Dict[str, Any]] = None,
- custom_result: Any = None,
- ):
- self.handler_name = handler_name
- self.success = success
- self.continue_processing = continue_processing
- self.modified_message = modified_message
- self.custom_result = custom_result
+ handler_name: str
+ success: bool = field(default=True)
+ continue_processing: bool = field(default=True)
+ modified_message: Optional[MessageDict] = field(default=None)
+ custom_result: Any = field(default=None)
class EventDispatcher:
@@ -48,17 +48,20 @@ class EventDispatcher:
再通过提供的 invoke_fn 回调 RPC 到 Runner 执行。
"""
- def __init__(self, registry: ComponentRegistry) -> None:
- self._registry: ComponentRegistry = registry
+ def __init__(self, component_registry: "ComponentRegistry") -> None:
+ self._component_registry: "ComponentRegistry" = component_registry
self._result_history: Dict[str, List[EventResult]] = {}
self._history_enabled: Set[str] = set()
- # 保持 fire-and-forget task 的强引用,防止被 GC 回收
self._background_tasks: Set[asyncio.Task] = set()
def enable_history(self, event_type: str) -> None:
self._history_enabled.add(event_type)
self._result_history.setdefault(event_type, [])
+ def disable_history(self, event_type: str) -> None:
+ self._history_enabled.discard(event_type)
+ self._result_history.pop(event_type, None)
+
def get_history(self, event_type: str) -> List[EventResult]:
return self._result_history.get(event_type, [])
@@ -66,47 +69,58 @@ class EventDispatcher:
if event_type in self._result_history:
self._result_history[event_type] = []
+ async def stop(self):
+ """停止 EventDispatcher,取消所有未完成的后台任务"""
+ for task in self._background_tasks:
+ task.cancel()
+ await asyncio.gather(*self._background_tasks, return_exceptions=True)
+ self._background_tasks.clear()
+
async def dispatch_event(
self,
event_type: str,
- invoke_fn: InvokeFn,
- message: Optional[Dict[str, Any]] = None,
+ supervisor: "PluginRunnerSupervisor",
+ message: Optional["SessionMessage"] = None,
extra_args: Optional[Dict[str, Any]] = None,
- ) -> Tuple[bool, Optional[Dict[str, Any]]]:
- """分发事件到所有对应 handler。
+ ) -> Tuple[bool, Optional["SessionMessage"]]:
+ """分发事件到所有对应 handler 的便捷方法。
+
+ 内置了通过 PluginSupervisor.invoke_plugin 调用 plugin.emit_event 的逻辑,
+ 无需调用方手动构造 invoke_fn 闭包。
Args:
event_type: 事件类型字符串
- invoke_fn: 异步回调,签名 (plugin_id, component_name, args) -> response_payload dict
+ supervisor: PluginSupervisor 实例,用于调用 invoke_plugin
message: MaiMessages 序列化后的 dict(可选)
extra_args: 额外参数
Returns:
- (should_continue, modified_message_dict)
+ (should_continue, modified_message_dict) (bool, SessionMessage | None): (是否继续后续执行, 可选的修改后的消息)
"""
- handlers = self._registry.get_event_handlers(event_type)
- if not handlers:
+ handler_entries = self._component_registry.get_event_handlers(event_type)
+ if not handler_entries:
return True, None
should_continue = True
- modified_message: Optional[Dict[str, Any]] = None
- intercept_handlers: List[RegisteredComponent] = []
- async_handlers: List[RegisteredComponent] = []
+ modified_message: Optional[MessageDict] = (
+ PluginMessageUtils._session_message_to_dict(message) if message else None
+ )
+ intercept_handlers: List["EventHandlerEntry"] = []
+ non_blocking_handlers: List["EventHandlerEntry"] = []
- for handler in handlers:
- if handler.metadata.get("intercept_message", False):
- intercept_handlers.append(handler)
+ for entry in handler_entries:
+ if entry.intercept_message:
+ intercept_handlers.append(entry)
else:
- async_handlers.append(handler)
+ non_blocking_handlers.append(entry)
- for handler in intercept_handlers:
+ for entry in intercept_handlers:
args = {
"event_type": event_type,
- "message": modified_message or message,
+ "message": modified_message,
**(extra_args or {}),
}
-
- result = await self._invoke_handler(invoke_fn, handler, args, event_type)
+ result = await self._invoke_handler(supervisor, entry, args, event_type)
if result and not result.continue_processing:
should_continue = False
break
@@ -114,47 +128,57 @@ class EventDispatcher:
modified_message = result.modified_message
if should_continue:
- final_message = modified_message or message
- for handler in async_handlers:
- async_message = final_message.copy() if isinstance(final_message, dict) else final_message
+ final_message = modified_message
+ for entry in non_blocking_handlers:
+ async_message = final_message.copy() if final_message else final_message
args = {
"event_type": event_type,
"message": async_message,
**(extra_args or {}),
}
# 非阻塞:保持实例级强引用,防止 task 被 GC 回收
- task = asyncio.create_task(self._invoke_handler(invoke_fn, handler, args, event_type))
+ task = asyncio.create_task(self._invoke_handler(supervisor, entry, args, event_type))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
-
- return should_continue, modified_message
+ try:
+ modified_message_obj = (
+ PluginMessageUtils._build_session_message_from_dict(modified_message) if modified_message else None # type: ignore
+ )
+ except Exception as e:
+ logger.error(f"构建修改后的 SessionMessage 失败: {e}")
+ modified_message_obj = None
+ return should_continue, modified_message_obj
async def _invoke_handler(
self,
- invoke_fn: InvokeFn,
- handler: RegisteredComponent,
+ supervisor: "PluginRunnerSupervisor",
+ handler_entry: "EventHandlerEntry",
args: Dict[str, Any],
event_type: str,
) -> Optional[EventResult]:
"""调用单个 handler 并收集结果。"""
try:
- resp = await invoke_fn(handler.plugin_id, handler.name, args)
+ resp_envelope = await supervisor.invoke_plugin(
+ "plugin.emit_event", handler_entry.plugin_id, handler_entry.name, args
+ )
+ resp = resp_envelope.payload
result = EventResult(
- handler_name=handler.full_name,
+ handler_name=handler_entry.full_name,
success=resp.get("success", True),
continue_processing=resp.get("continue_processing", True),
modified_message=resp.get("modified_message"),
custom_result=resp.get("custom_result"),
)
except Exception as e:
- logger.error(f"EventHandler {handler.full_name} 执行失败: {e}", exc_info=True)
- result = EventResult(
- handler_name=handler.full_name,
- success=False,
- continue_processing=True,
- )
+ logger.error(f"EventHandler {handler_entry.full_name} 执行失败: {e}", exc_info=True)
+ result = EventResult(handler_name=handler_entry.full_name, success=False, continue_processing=True)
if event_type in self._history_enabled:
- self._result_history.setdefault(event_type, []).append(result)
+ history_list = self._result_history.setdefault(event_type, [])
+ history_list.append(result)
+ # 自动清理超出限制的旧记录,防止内存无限增长
+ if len(history_list) > _MAX_HISTORY_LENGTH:
+ # 保留最新的 _MAX_HISTORY_LENGTH 条记录
+ self._result_history[event_type] = history_list[-_MAX_HISTORY_LENGTH:]
return result
diff --git a/src/plugin_runtime/host/hook_dispatcher.py b/src/plugin_runtime/host/hook_dispatcher.py
new file mode 100644
index 00000000..d5e88448
--- /dev/null
+++ b/src/plugin_runtime/host/hook_dispatcher.py
@@ -0,0 +1,166 @@
+"""
+Hook Dispatch 系统
+
+插件可以注册自己的Hook,当特定函数被调用时,Hook Dispatch系统会将调用转发给插件的Hook处理函数。
+每个Hook的参数随Hook点位确定,因此参数是易变的。插件开发者需要根据Hook点位的定义来编写Hook处理函数。
+在参数/返回值匹配的情况下允许修改参数/返回值。
+
+HookDispatcher 负责:
+1. 按 stage 查询已注册的 hook_handler(通过 ComponentRegistry)
+2. 按 priority 排序,区分 blocking 和非 blocking 模式
+3. blocking 模式:依次同步调用,支持修改参数/提前终止
+4. 非 blocking 模式:异步调用,不阻塞主流程
+5. 支持通过 global_config.plugin_runtime.hook_blocking_timeout_sec 设置超时上限
+"""
+
+import asyncio
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING
+
+from src.common.logger import get_logger
+from src.config.config import global_config
+
+
+if TYPE_CHECKING:
+ from .supervisor import PluginRunnerSupervisor
+ from .component_registry import ComponentRegistry, HookHandlerEntry
+
+logger = get_logger("plugin_runtime.host.hook_dispatcher")
+
+
+@dataclass
+class HookResult:
+ """单个 HookHandler 的执行结果"""
+
+ handler_name: str
+ success: bool = field(default=True)
+ continue_processing: bool = field(default=True)
+ modified_kwargs: Optional[Dict[str, Any]] = field(default=None)
+ custom_result: Any = field(default=None)
+
+
+class HookDispatcher:
+ """Host-side Hook 分发器
+
+ 由业务层调用 hook_dispatch(),
+ 内部通过 ComponentRegistry 查询 handler,
+ 再通过提供的 invoke_fn 回调 RPC 到 Runner 执行。
+ """
+
+ def __init__(self, component_registry: "ComponentRegistry") -> None:
+ """初始化 HookDispatcher
+
+ Args:
+ component_registry: ComponentRegistry 实例,用于查询已注册的 hook_handler
+ """
+ self._component_registry: "ComponentRegistry" = component_registry
+ self._background_tasks: Set[asyncio.Task] = set()
+
+ async def stop(self) -> None:
+ """停止 HookDispatcher,取消所有未完成的后台任务"""
+ for task in self._background_tasks:
+ task.cancel()
+ await asyncio.gather(*self._background_tasks, return_exceptions=True)
+ self._background_tasks.clear()
+
+ async def hook_dispatch(
+ self,
+ stage: str,
+ supervisor: "PluginRunnerSupervisor",
+ **kwargs: Any,
+ ) -> Dict[str, Any]:
+ """分发 hook 到所有对应 handler 的便捷方法。
+
+ 内置了通过 PluginRunnerSupervisor.invoke_plugin 调用 plugin 的逻辑,
+ 无需调用方手动构造 invoke_fn 闭包。
+
+ Args:
+ stage: hook 名称
+ supervisor: PluginRunnerSupervisor 实例,用于调用 invoke_plugin
+ **kwargs: 关键字参数,会展开传递给 handler
+
+ Returns:
+ modified_kwargs (Dict[str, Any]): 经过所有 handler 修改后的关键字参数
+ """
+ handler_entries = self._component_registry.get_hook_handlers(stage)
+ if not handler_entries:
+ return kwargs
+
+ current_kwargs = kwargs.copy()
+ blocking_handlers: List["HookHandlerEntry"] = []
+ non_blocking_handlers: List["HookHandlerEntry"] = []
+
+ # 分离 blocking 和非 blocking handler
+ for entry in handler_entries:
+ if entry.blocking:
+ blocking_handlers.append(entry)
+ else:
+ non_blocking_handlers.append(entry)
+
+ # 处理 blocking handlers(同步调用,支持修改参数/提前终止)
+ timeout = global_config.plugin_runtime.hook_blocking_timeout_sec or 30.0
+ for entry in blocking_handlers:
+ hook_args = {"stage": stage, **current_kwargs}
+ try:
+ # 应用超时控制
+ result = await asyncio.wait_for(
+ self._invoke_handler(supervisor, entry, hook_args),
+ timeout=timeout,
+ )
+ except asyncio.TimeoutError:
+ logger.error(f"Blocking HookHandler {entry.full_name} 执行超时 (>{timeout}秒),跳过")
+ result = HookResult(handler_name=entry.full_name, success=False, continue_processing=True)
+
+ if result:
+ if result.modified_kwargs is not None:
+ current_kwargs = result.modified_kwargs
+ if not result.continue_processing:
+ logger.info(f"HookHandler {entry.full_name} 终止了后续处理")
+ break
+
+ # 处理 non-blocking handlers(异步调用,不阻塞主流程)
+ for entry in non_blocking_handlers:
+ async_kwargs = current_kwargs.copy()
+ hook_args = {"stage": stage, **async_kwargs}
+ task = asyncio.create_task(
+ asyncio.wait_for(self._invoke_handler(supervisor, entry, hook_args), timeout=timeout)
+ )
+ self._background_tasks.add(task)
+ task.add_done_callback(self._background_tasks.discard)
+
+ return current_kwargs
+
+ async def _invoke_handler(
+ self,
+ supervisor: "PluginRunnerSupervisor",
+ handler_entry: "HookHandlerEntry",
+ args: Dict[str, Any],
+ ) -> Optional[HookResult]:
+ """调用单个 handler 并收集结果。
+
+ Args:
+ supervisor: PluginRunnerSupervisor 实例
+ handler_entry: HookHandlerEntry 实例
+ args: 传递给 handler 的参数字典
+ stage: hook 名称
+
+ Returns:
+ Optional[HookResult]: 执行结果,如果执行失败则返回 None
+ """
+ try:
+ resp_envelope = await supervisor.invoke_plugin(
+ "plugin.invoke_hook", handler_entry.plugin_id, handler_entry.name, args
+ )
+ resp = resp_envelope.payload
+ result = HookResult(
+ handler_name=handler_entry.full_name,
+ success=resp.get("success", True),
+ continue_processing=resp.get("continue_processing", True),
+ modified_kwargs=resp.get("modified_kwargs"),
+ custom_result=resp.get("custom_result"),
+ )
+ except Exception as e:
+ logger.error(f"HookHandler {handler_entry.full_name} 执行失败:{e}", exc_info=True)
+ result = HookResult(handler_name=handler_entry.full_name, success=False, continue_processing=True)
+
+ return result
diff --git a/src/plugin_runtime/host/logger_bridge.py b/src/plugin_runtime/host/logger_bridge.py
new file mode 100644
index 00000000..f2213dfe
--- /dev/null
+++ b/src/plugin_runtime/host/logger_bridge.py
@@ -0,0 +1,45 @@
+import logging as stdlib_logging
+from src.plugin_runtime.protocol.errors import ErrorCode
+from src.plugin_runtime.protocol.envelope import Envelope, LogBatchPayload
+class RunnerLogBridge:
+ """将 Runner 进程上报的批量日志重放到主进程的 Logger 中。
+
+ Runner 通过 ``runner.log_batch`` IPC 事件批量到达。
+ 每条 LogEntry 被重建为一个真实的 :class:`logging.LogRecord` 并直接
+ 调用 ``logging.getLogger(entry.logger_name).handle(record)``,
+ 从而接入主进程已配置好的 structlog Handler 链。
+ """
+
+ async def handle_log_batch(self, envelope: Envelope) -> Envelope:
+ """IPC 事件处理器:解析批量日志并重放到主进程 Logger。
+
+ Args:
+ envelope: 方法名为 ``runner.log_batch`` 的 IPC 事件信封。
+
+ Returns:
+ 空响应信封(事件模式下将被忽略)。
+ """
+ try:
+ batch = LogBatchPayload.model_validate(envelope.payload)
+ except Exception as exc:
+ return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
+
+ for entry in batch.entries:
+ # 重建一个与原始日志尽量相符的 LogRecord
+ record = stdlib_logging.LogRecord(
+ name=entry.logger_name,
+ level=entry.level,
+ pathname="",
+ lineno=0,
+ msg=entry.message,
+ args=(),
+ exc_info=None,
+ )
+ record.created = entry.timestamp_ms / 1000.0
+ record.msecs = entry.timestamp_ms % 1000
+ if entry.exception_text:
+ record.exc_text = entry.exception_text
+
+ stdlib_logging.getLogger(entry.logger_name).handle(record)
+
+ return envelope.make_response(payload={"accepted": True, "count": len(batch.entries)})
\ No newline at end of file
diff --git a/src/plugin_runtime/host/message_gateway.py b/src/plugin_runtime/host/message_gateway.py
new file mode 100644
index 00000000..90f94493
--- /dev/null
+++ b/src/plugin_runtime/host/message_gateway.py
@@ -0,0 +1,112 @@
+"""Host 侧消息网关包装器。"""
+
+from typing import TYPE_CHECKING, Any, Dict
+
+from src.common.logger import get_logger
+from src.platform_io import get_platform_io_manager
+
+from .message_utils import PluginMessageUtils
+
+if TYPE_CHECKING:
+ from src.chat.message_receive.message import SessionMessage
+ from .component_registry import ComponentRegistry
+ from .supervisor import PluginRunnerSupervisor
+
+logger = get_logger("plugin_runtime.host.message_gateway")
+
+
+class MessageGateway:
+ """Host 侧消息网关包装器。"""
+
+ def __init__(self, component_registry: "ComponentRegistry") -> None:
+ """初始化消息网关。
+
+ Args:
+ component_registry: 组件注册表。
+ """
+ self._component_registry = component_registry
+
+ def build_session_message(self, external_message: Dict[str, Any]) -> "SessionMessage":
+ """将标准消息字典转换为 ``SessionMessage``。
+
+ Args:
+ external_message: 外部消息的字典格式数据。
+
+ Returns:
+ SessionMessage: 转换后的内部消息对象。
+
+ Raises:
+ ValueError: 消息字典不合法时抛出。
+ """
+ return PluginMessageUtils._build_session_message_from_dict(external_message)
+
+ def build_message_dict(self, internal_message: "SessionMessage") -> Dict[str, Any]:
+ """将 ``SessionMessage`` 转换为标准消息字典。
+
+ Args:
+ internal_message: 内部消息对象。
+
+ Returns:
+ Dict[str, Any]: 供消息网关插件消费的标准消息字典。
+ """
+ return dict(PluginMessageUtils._session_message_to_dict(internal_message))
+
+ async def receive_external_message(self, external_message: Dict[str, Any]) -> None:
+ """接收外部消息并送入主消息链。
+
+ Args:
+ external_message: 外部消息的字典格式数据。
+ """
+ try:
+ session_message = self.build_session_message(external_message)
+ except Exception as e:
+ logger.error(f"转换外部消息失败: {e}")
+ return
+
+ from src.chat.message_receive.bot import chat_bot
+
+ await chat_bot.receive_message(session_message)
+
+ async def send_message_to_external(
+ self,
+ internal_message: "SessionMessage",
+ supervisor: "PluginRunnerSupervisor",
+ *,
+ enabled_only: bool = True,
+ save_to_db: bool = True,
+ ) -> bool:
+ """将内部消息通过 Platform IO 发送到外部平台。
+
+ Args:
+ internal_message: 系统内部的 ``SessionMessage`` 对象。
+ supervisor: 当前持有该消息网关的 Supervisor。
+ enabled_only: 兼容旧签名的保留参数,当前未使用。
+ save_to_db: 发送成功后是否写入数据库。
+
+ Returns:
+ bool: 是否发送成功。
+ """
+ del enabled_only
+ del supervisor
+
+ platform_io_manager = get_platform_io_manager()
+ if not platform_io_manager.is_started:
+ logger.warning("Platform IO 尚未启动,无法通过适配器链路发送消息")
+ return False
+
+ route_key = platform_io_manager.build_route_key_from_message(internal_message)
+ delivery_batch = await platform_io_manager.send_message(internal_message, route_key)
+ if not delivery_batch.has_success:
+ logger.warning("通过消息网关链路发送消息失败: 未命中任何成功回执")
+ return False
+
+ first_successful_receipt = delivery_batch.sent_receipts[0]
+ internal_message.message_id = first_successful_receipt.external_message_id or internal_message.message_id
+ if save_to_db:
+ try:
+ from src.common.utils.utils_message import MessageUtils
+
+ MessageUtils.store_message_to_db(internal_message)
+ except Exception as e:
+ logger.error(f"保存消息到数据库失败: {e}")
+ return True
diff --git a/src/plugin_runtime/host/message_utils.py b/src/plugin_runtime/host/message_utils.py
new file mode 100644
index 00000000..2f6aa01b
--- /dev/null
+++ b/src/plugin_runtime/host/message_utils.py
@@ -0,0 +1,487 @@
+from datetime import datetime
+from typing import Any, Dict, List, Optional, TypedDict
+
+import base64
+import hashlib
+
+from src.common.logger import get_logger
+from src.chat.message_receive.message import SessionMessage
+from src.common.data_models.mai_message_data_model import UserInfo, GroupInfo, MessageInfo
+from src.common.data_models.message_component_data_model import (
+ AtComponent,
+ DictComponent,
+ EmojiComponent,
+ ForwardComponent,
+ ForwardNodeComponent,
+ ImageComponent,
+ MessageSequence,
+ ReplyComponent,
+ StandardMessageComponents,
+ TextComponent,
+ VoiceComponent,
+)
+
+logger = get_logger("plugin_runtime.host.message_utils")
+
+
+class UserInfoDict(TypedDict, total=False):
+ user_id: str
+ user_nickname: str
+ user_cardname: Optional[str]
+
+
+class GroupInfoDict(TypedDict, total=False):
+ group_id: str
+ group_name: str
+
+
+class MessageInfoDict(TypedDict, total=False):
+ user_info: UserInfoDict
+ group_info: Optional[GroupInfoDict]
+ additional_config: Dict[str, Any]
+
+
+class MessageDict(TypedDict, total=False):
+ message_id: str
+ timestamp: str
+ platform: str
+ message_info: MessageInfoDict
+ raw_message: List[Dict[str, Any]]
+ is_mentioned: bool
+ is_at: bool
+ is_emoji: bool
+ is_picture: bool
+ is_command: bool
+ is_notify: bool
+ session_id: str
+ reply_to: Optional[str]
+ processed_plain_text: Optional[str]
+ display_message: Optional[str]
+
+
+class PluginMessageUtils:
+ @staticmethod
+ def _message_sequence_to_dict(message_sequence: MessageSequence) -> List[Dict[str, Any]]:
+ """将消息组件序列转换为插件运行时使用的字典结构。
+
+ Args:
+ message_sequence: 待转换的消息组件序列。
+
+ Returns:
+ List[Dict[str, Any]]: 供插件运行时协议使用的消息段字典列表。
+ """
+ return [PluginMessageUtils._component_to_dict(component) for component in message_sequence.components]
+
+ @staticmethod
+ def _component_to_dict(component: StandardMessageComponents) -> Dict[str, Any]:
+ """将单个消息组件转换为插件运行时字典结构。
+
+ Args:
+ component: 待转换的消息组件。
+
+ Returns:
+ Dict[str, Any]: 序列化后的消息组件字典。
+ """
+ if isinstance(component, TextComponent):
+ return {"type": "text", "data": component.text}
+
+ if isinstance(component, ImageComponent):
+ serialized = {
+ "type": "image",
+ "data": component.content,
+ "hash": component.binary_hash,
+ }
+ if component.binary_data:
+ serialized["binary_data_base64"] = base64.b64encode(component.binary_data).decode("utf-8")
+ return serialized
+
+ if isinstance(component, EmojiComponent):
+ serialized = {
+ "type": "emoji",
+ "data": component.content,
+ "hash": component.binary_hash,
+ }
+ if component.binary_data:
+ serialized["binary_data_base64"] = base64.b64encode(component.binary_data).decode("utf-8")
+ return serialized
+
+ if isinstance(component, VoiceComponent):
+ serialized = {
+ "type": "voice",
+ "data": component.content,
+ "hash": component.binary_hash,
+ }
+ if component.binary_data:
+ serialized["binary_data_base64"] = base64.b64encode(component.binary_data).decode("utf-8")
+ return serialized
+
+ if isinstance(component, AtComponent):
+ return {
+ "type": "at",
+ "data": {
+ "target_user_id": component.target_user_id,
+ "target_user_nickname": component.target_user_nickname,
+ "target_user_cardname": component.target_user_cardname,
+ },
+ }
+
+ if isinstance(component, ReplyComponent):
+ return {
+ "type": "reply",
+ "data": {
+ "target_message_id": component.target_message_id,
+ "target_message_content": component.target_message_content,
+ "target_message_sender_id": component.target_message_sender_id,
+ "target_message_sender_nickname": component.target_message_sender_nickname,
+ "target_message_sender_cardname": component.target_message_sender_cardname,
+ },
+ }
+
+ if isinstance(component, ForwardNodeComponent):
+ return {
+ "type": "forward",
+ "data": [PluginMessageUtils._forward_component_to_dict(item) for item in component.forward_components],
+ }
+
+ return {"type": "dict", "data": component.data}
+
+ @staticmethod
+ def _forward_component_to_dict(component: ForwardComponent) -> Dict[str, Any]:
+ """将单个转发节点组件转换为字典结构。
+
+ Args:
+ component: 待转换的转发节点组件。
+
+ Returns:
+ Dict[str, Any]: 序列化后的转发节点字典。
+ """
+ return {
+ "user_id": component.user_id,
+ "user_nickname": component.user_nickname,
+ "user_cardname": component.user_cardname,
+ "message_id": component.message_id,
+ "content": [PluginMessageUtils._component_to_dict(item) for item in component.content],
+ }
+
+ @staticmethod
+ def _message_sequence_from_dict(raw_message_data: List[Dict[str, Any]]) -> MessageSequence:
+ """从插件运行时字典结构恢复消息组件序列。
+
+ Args:
+ raw_message_data: 插件运行时消息段字典列表。
+
+ Returns:
+ MessageSequence: 恢复后的消息组件序列。
+ """
+ components = [PluginMessageUtils._component_from_dict(item) for item in raw_message_data]
+ return MessageSequence(components=components)
+
+ @staticmethod
+ def _component_from_dict(item: Dict[str, Any]) -> StandardMessageComponents:
+ """从插件运行时字典结构恢复单个消息组件。
+
+ Args:
+ item: 单个消息组件的字典表示。
+
+ Returns:
+ StandardMessageComponents: 恢复后的内部消息组件对象。
+ """
+ item_type = str(item.get("type") or "").strip()
+ if item_type == "text":
+ return TextComponent(text=str(item.get("data") or ""))
+
+ if item_type == "image":
+ return PluginMessageUtils._build_binary_component(ImageComponent, item)
+
+ if item_type == "emoji":
+ return PluginMessageUtils._build_binary_component(EmojiComponent, item)
+
+ if item_type == "voice":
+ return PluginMessageUtils._build_binary_component(VoiceComponent, item)
+
+ if item_type == "at":
+ item_data = item.get("data", {})
+ if not isinstance(item_data, dict):
+ item_data = {}
+ return AtComponent(
+ target_user_id=str(item_data.get("target_user_id") or ""),
+ target_user_nickname=PluginMessageUtils._normalize_optional_string(item_data.get("target_user_nickname")),
+ target_user_cardname=PluginMessageUtils._normalize_optional_string(item_data.get("target_user_cardname")),
+ )
+
+ if item_type == "reply":
+ reply_data = item.get("data")
+ if isinstance(reply_data, dict):
+ return ReplyComponent(
+ target_message_id=str(reply_data.get("target_message_id") or ""),
+ target_message_content=PluginMessageUtils._normalize_optional_string(
+ reply_data.get("target_message_content")
+ ),
+ target_message_sender_id=PluginMessageUtils._normalize_optional_string(
+ reply_data.get("target_message_sender_id")
+ ),
+ target_message_sender_nickname=PluginMessageUtils._normalize_optional_string(
+ reply_data.get("target_message_sender_nickname")
+ ),
+ target_message_sender_cardname=PluginMessageUtils._normalize_optional_string(
+ reply_data.get("target_message_sender_cardname")
+ ),
+ )
+ return ReplyComponent(target_message_id=str(reply_data or ""))
+
+ if item_type == "forward":
+ forward_nodes: List[ForwardComponent] = []
+ raw_forward_nodes = item.get("data", [])
+ if isinstance(raw_forward_nodes, list):
+ for node in raw_forward_nodes:
+ if not isinstance(node, dict):
+ continue
+ raw_content = node.get("content", [])
+ node_components: List[StandardMessageComponents] = []
+ if isinstance(raw_content, list):
+ node_components = [
+ PluginMessageUtils._component_from_dict(content)
+ for content in raw_content
+ if isinstance(content, dict)
+ ]
+ if not node_components:
+ node_components = [TextComponent(text="[empty forward node]")]
+ forward_nodes.append(
+ ForwardComponent(
+ user_nickname=str(node.get("user_nickname") or "未知用户"),
+ user_id=PluginMessageUtils._normalize_optional_string(node.get("user_id")),
+ user_cardname=PluginMessageUtils._normalize_optional_string(node.get("user_cardname")),
+ message_id=str(node.get("message_id") or ""),
+ content=node_components,
+ )
+ )
+ if not forward_nodes:
+ return DictComponent(data={"type": "forward", "data": item.get("data", [])})
+ return ForwardNodeComponent(forward_components=forward_nodes)
+
+ component_data = item.get("data")
+ if isinstance(component_data, dict):
+ return DictComponent(data=component_data)
+ return DictComponent(data=item)
+
+ @staticmethod
+ def _build_binary_component(component_cls: Any, item: Dict[str, Any]) -> StandardMessageComponents:
+ """从字典构造带二进制负载的消息组件。
+
+ Args:
+ component_cls: 目标组件类型。
+ item: 消息组件字典。
+
+ Returns:
+ StandardMessageComponents: 构造后的组件对象。
+ """
+ content = str(item.get("data") or "")
+ binary_hash = str(item.get("hash") or "")
+ raw_binary_base64 = item.get("binary_data_base64")
+ binary_data = b""
+ if isinstance(raw_binary_base64, str) and raw_binary_base64:
+ try:
+ binary_data = base64.b64decode(raw_binary_base64)
+ except Exception:
+ binary_data = b""
+
+ if not binary_hash and binary_data:
+ binary_hash = hashlib.sha256(binary_data).hexdigest()
+
+ return component_cls(binary_hash=binary_hash, content=content, binary_data=binary_data)
+
+ @staticmethod
+ def _normalize_optional_string(value: Any) -> Optional[str]:
+ """将任意值规范化为可选字符串。
+
+ Args:
+ value: 待规范化的值。
+
+ Returns:
+ Optional[str]: 规范化后的字符串;若值为空则返回 ``None``。
+ """
+ if value is None:
+ return None
+ normalized_value = str(value)
+ return normalized_value if normalized_value else None
+
+ @staticmethod
+ def _message_info_to_dict(message_info: MessageInfo) -> MessageInfoDict:
+ """
+ 将 MessageInfo 对象转换为字典格式
+
+ Args:
+ message_info: MessageInfo 对象
+
+ Returns:
+ 字典格式的消息信息
+ """
+ user_info_dict = UserInfoDict(
+ user_id=message_info.user_info.user_id,
+ user_nickname=message_info.user_info.user_nickname,
+ user_cardname=message_info.user_info.user_cardname,
+ )
+
+ group_info_dict: Optional[GroupInfoDict] = None
+ if message_info.group_info:
+ group_info_dict = GroupInfoDict(
+ group_id=message_info.group_info.group_id,
+ group_name=message_info.group_info.group_name,
+ )
+
+ return MessageInfoDict(
+ user_info=user_info_dict,
+ group_info=group_info_dict,
+ additional_config=message_info.additional_config,
+ )
+
+ @staticmethod
+ def _session_message_to_dict(session_message: SessionMessage) -> MessageDict:
+ """
+ 将 SessionMessage 对象转换为字典格式(复用 MessageSequence.to_dict 方法)
+
+ Args:
+ session_message: SessionMessage 对象
+
+ Returns:
+ 字典格式的消息
+ """
+ # 转换基本信息
+ message_dict = MessageDict(
+ message_id=session_message.message_id,
+ timestamp=str(session_message.timestamp.timestamp()), # 转换为时间戳字符串
+ platform=session_message.platform,
+ message_info=PluginMessageUtils._message_info_to_dict(session_message.message_info),
+ raw_message=PluginMessageUtils._message_sequence_to_dict(session_message.raw_message),
+ is_mentioned=session_message.is_mentioned,
+ is_at=session_message.is_at,
+ is_emoji=session_message.is_emoji,
+ is_picture=session_message.is_picture,
+ is_command=session_message.is_command,
+ is_notify=session_message.is_notify,
+ session_id=session_message.session_id,
+ )
+
+ # 添加可选字段
+ if session_message.reply_to is not None:
+ message_dict["reply_to"] = session_message.reply_to
+ if session_message.processed_plain_text is not None:
+ message_dict["processed_plain_text"] = session_message.processed_plain_text
+ if session_message.display_message is not None:
+ message_dict["display_message"] = session_message.display_message
+
+ return message_dict
+
+ @staticmethod
+ def _build_message_info_from_dict(message_info_dict: Dict[str, Any]) -> MessageInfo:
+ """
+ 从字典构建 MessageInfo 对象
+
+ Args:
+ message_info_dict: 包含消息信息的字典
+
+ Returns:
+ MessageInfo 对象
+ """
+ # 构建用户信息
+ user_info_dict = message_info_dict.get("user_info")
+ if not user_info_dict or not isinstance(user_info_dict, dict):
+ raise ValueError("消息字典中 'user_info' 字段无效")
+ user_id = user_info_dict.get("user_id")
+ user_nickname = user_info_dict.get("user_nickname")
+ user_cardname = user_info_dict.get("user_cardname")
+ if not isinstance(user_id, str) or not isinstance(user_nickname, str) or not user_id or not user_nickname:
+ raise ValueError("消息字典中 'user_info' 字段缺少有效的 'user_id' 或 'user_nickname'")
+ user_cardname = str(user_cardname) if user_cardname is not None else None
+ user_info = UserInfo(user_id=user_id, user_nickname=user_nickname, user_cardname=user_cardname)
+
+ # 构建群信息
+ if group_info_dict := message_info_dict.get("group_info"):
+ group_id = group_info_dict.get("group_id")
+ group_name = group_info_dict.get("group_name")
+ if not isinstance(group_id, str) or not isinstance(group_name, str) or not group_id or not group_name:
+ raise ValueError("消息字典中 'group_info' 字段缺少有效的 'group_id' 或 'group_name'")
+ group_info = GroupInfo(group_id=group_id, group_name=group_name)
+ else:
+ group_info = None
+
+ # 获取额外配置
+ additional_config: Dict[str, Any] = message_info_dict.get("additional_config", {})
+
+ return MessageInfo(user_info=user_info, group_info=group_info, additional_config=additional_config)
+
+ @staticmethod
+ def _build_session_message_from_dict(message_dict: Dict[str, Any]) -> SessionMessage:
+ """
+ 从字典构建 SessionMessage 对象(递归处理消息组件)
+
+ Args:
+ message_dict: 包含消息完整信息的字典
+
+ Returns:
+ SessionMessage 对象
+ """
+ # 提取基本信息
+ message_id = message_dict["message_id"]
+ timestamp_str: str = message_dict.get("timestamp", "")
+ platform = message_dict["platform"]
+ if not isinstance(message_id, str) or not message_id:
+ raise ValueError("消息字典中缺少有效的 'message_id' 字段")
+ if not isinstance(platform, str) or not platform:
+ raise ValueError("消息字典中缺少有效的 'platform' 字段")
+
+ # 解析时间戳
+ try:
+ timestamp_float = float(timestamp_str)
+ timestamp = datetime.fromtimestamp(timestamp_float)
+ except (ValueError, TypeError):
+ timestamp = datetime.now() # 如果解析失败,使用当前时间
+
+ # 创建 SessionMessage 实例
+ session_message = SessionMessage(message_id=message_id, timestamp=timestamp, platform=platform)
+
+ # 构建消息信息
+ session_message.message_info = PluginMessageUtils._build_message_info_from_dict(message_dict["message_info"])
+
+ # 构建原始消息组件序列(复用 MessageSequence.from_dict 方法)
+ raw_message_data = message_dict["raw_message"]
+ if isinstance(raw_message_data, list):
+ session_message.raw_message = PluginMessageUtils._message_sequence_from_dict(raw_message_data)
+ else:
+ raise ValueError("消息字典中 'raw_message' 字段必须是一个列表")
+
+ # 设置其他可选属性
+ session_message.is_mentioned = message_dict.get("is_mentioned", False)
+ if not isinstance(session_message.is_mentioned, bool):
+ session_message.is_mentioned = False
+ session_message.is_at = message_dict.get("is_at", False)
+ if not isinstance(session_message.is_at, bool):
+ session_message.is_at = False
+ session_message.is_emoji = message_dict.get("is_emoji", False)
+ if not isinstance(session_message.is_emoji, bool):
+ session_message.is_emoji = False
+ session_message.is_picture = message_dict.get("is_picture", False)
+ if not isinstance(session_message.is_picture, bool):
+ session_message.is_picture = False
+ session_message.is_command = message_dict.get("is_command", False)
+ if not isinstance(session_message.is_command, bool):
+ session_message.is_command = False
+ session_message.is_notify = message_dict.get("is_notify", False)
+ if not isinstance(session_message.is_notify, bool):
+ session_message.is_notify = False
+ session_message.session_id = message_dict.get("session_id", "")
+ if not isinstance(session_message.session_id, str):
+ session_message.session_id = ""
+ session_message.reply_to = message_dict.get("reply_to")
+ if session_message.reply_to is not None and not isinstance(session_message.reply_to, str):
+ session_message.reply_to = None
+ session_message.processed_plain_text = message_dict.get("processed_plain_text")
+ if session_message.processed_plain_text is not None and not isinstance(
+ session_message.processed_plain_text, str
+ ):
+ session_message.processed_plain_text = None
+ session_message.display_message = message_dict.get("display_message")
+ if session_message.display_message is not None and not isinstance(session_message.display_message, str):
+ session_message.display_message = None
+
+ return session_message
diff --git a/src/plugin_runtime/host/policy_engine.py b/src/plugin_runtime/host/policy_engine.py
deleted file mode 100644
index 61b32480..00000000
--- a/src/plugin_runtime/host/policy_engine.py
+++ /dev/null
@@ -1,97 +0,0 @@
-"""策略引擎
-
-负责能力授权校验。
-每个插件在 manifest 中声明能力需求,Host 启动时签发能力令牌。
-"""
-
-from dataclasses import dataclass, field
-from typing import Dict, List, Optional, Set, Tuple
-
-
-@dataclass
-class CapabilityToken:
- """能力令牌"""
-
- plugin_id: str
- generation: int
- capabilities: Set[str] = field(default_factory=set)
-
-
-class PolicyEngine:
- """策略引擎
-
- 管理所有插件的能力令牌,提供授权校验。
- """
-
- def __init__(self) -> None:
- self._tokens: Dict[str, Dict[int, CapabilityToken]] = {}
-
- def register_plugin(
- self,
- plugin_id: str,
- generation: int,
- capabilities: List[str],
- ) -> CapabilityToken:
- """为插件签发能力令牌"""
- token = CapabilityToken(
- plugin_id=plugin_id,
- generation=generation,
- capabilities=set(capabilities),
- )
- self._tokens.setdefault(plugin_id, {})[generation] = token
- return token
-
- def revoke_plugin(self, plugin_id: str, generation: Optional[int] = None) -> None:
- """撤销插件的能力令牌。"""
- if generation is None:
- self._tokens.pop(plugin_id, None)
- return
-
- generations = self._tokens.get(plugin_id)
- if generations is None:
- return
-
- generations.pop(generation, None)
- if not generations:
- self._tokens.pop(plugin_id, None)
-
- def clear(self) -> None:
- """清空所有能力令牌。"""
- self._tokens.clear()
-
- def check_capability(self, plugin_id: str, capability: str, generation: Optional[int] = None) -> Tuple[bool, str]:
- """检查插件是否有权调用某项能力
-
- Returns:
- (allowed, reason)
- """
- generations = self._tokens.get(plugin_id)
- if not generations:
- return False, f"插件 {plugin_id} 未注册能力令牌"
-
- if generation is None:
- token = generations[max(generations)]
- else:
- token = generations.get(generation)
- if token is None:
- active_generation = max(generations)
- return False, f"插件 {plugin_id} generation 不匹配: {generation} != {active_generation}"
-
- if capability not in token.capabilities:
- return False, f"插件 {plugin_id} 未获授权能力: {capability}"
-
- if generation is not None and token.generation != generation:
- return False, f"插件 {plugin_id} generation 不匹配: {generation} != {token.generation}"
-
- return True, ""
-
- def get_token(self, plugin_id: str) -> Optional[CapabilityToken]:
- """获取插件的能力令牌"""
- generations = self._tokens.get(plugin_id)
- if not generations:
- return None
- return generations[max(generations)]
-
- def list_plugins(self) -> List[str]:
- """列出所有已注册的插件"""
- return list(self._tokens.keys())
diff --git a/src/plugin_runtime/host/rpc_server.py b/src/plugin_runtime/host/rpc_server.py
index 79fe0d9a..eb6768c2 100644
--- a/src/plugin_runtime/host/rpc_server.py
+++ b/src/plugin_runtime/host/rpc_server.py
@@ -7,7 +7,7 @@
4. 请求-响应关联与超时管理
"""
-from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Coroutine
import asyncio
import contextlib
@@ -32,7 +32,7 @@ from src.plugin_runtime.transport.base import Connection, TransportServer
logger = get_logger("plugin_runtime.host.rpc_server")
# RPC 方法处理器类型
-MethodHandler = Callable[[Envelope], Awaitable[Envelope]]
+MethodHandler = Callable[[Envelope], Coroutine[Any, Any, Envelope]]
class RPCServer:
@@ -55,108 +55,39 @@ class RPCServer:
self._id_gen = RequestIdGenerator()
self._connection: Optional[Connection] = None # 当前活跃的 Runner 连接
- self._runner_id: Optional[str] = None
- self._runner_generation: int = 0
- self._staged_connection: Optional[Connection] = None
- self._staged_runner_id: Optional[str] = None
- self._staged_runner_generation: int = 0
- self._staging_takeover: bool = False
# 方法处理器注册表
self._method_handlers: Dict[str, MethodHandler] = {}
- # 等待响应的 pending 请求: request_id -> (Future, target_generation)
- self._pending_requests: Dict[int, Tuple[asyncio.Future, int]] = {}
+ # 等待响应的 pending 请求: request_id -> Future
+ self._pending_requests: Dict[int, asyncio.Future[Envelope]] = {}
# 发送队列(背压控制)
self._send_queue: Optional[asyncio.Queue[Tuple[Connection, bytes, asyncio.Future[None]]]] = None
- self._send_worker_task: Optional[asyncio.Task] = None
+ self._send_worker_task: Optional[asyncio.Task[None]] = None
# 运行状态
self._running: bool = False
- self._tasks: List[asyncio.Task] = []
+ self._tasks: List[asyncio.Task[None]] = []
+ self._last_handshake_rejection_reason: str = ""
+ self._connection_lock: asyncio.Lock = asyncio.Lock()
@property
def session_token(self) -> str:
return self._session_token
- def reset_session_token(self) -> str:
- """重新生成会话令牌(热重载时调用,防止旧 Runner 重连)"""
- self._session_token = secrets.token_hex(32)
- return self._session_token
-
- def restore_session_token(self, token: str) -> None:
- """恢复指定的会话令牌(热重载回滚时调用)"""
- self._session_token = token
-
- @property
- def runner_generation(self) -> int:
- return self._runner_generation
-
- @property
- def staged_generation(self) -> int:
- return self._staged_runner_generation
-
@property
def is_connected(self) -> bool:
return self._connection is not None and not self._connection.is_closed
- def has_generation(self, generation: int) -> bool:
- return generation == self._runner_generation or (
- self._staged_connection is not None
- and not self._staged_connection.is_closed
- and generation == self._staged_runner_generation
- )
+ @property
+ def last_handshake_rejection_reason(self) -> str:
+ """返回最近一次握手被拒绝的原因。"""
+ return self._last_handshake_rejection_reason
- def begin_staged_takeover(self) -> None:
- """允许新 Runner 以 staged 方式接入,待 Supervisor 验证后再切换为活跃连接。"""
- self._staging_takeover = True
-
- async def commit_staged_takeover(self) -> None:
- """提交 staged Runner,原活跃连接在提交后被关闭。"""
- if self._staged_connection is None or self._staged_connection.is_closed:
- raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "没有可提交的新 Runner 连接")
-
- old_connection = self._connection
- old_generation = self._runner_generation
-
- self._connection = self._staged_connection
- self._runner_id = self._staged_runner_id
- self._runner_generation = self._staged_runner_generation
-
- self._staged_connection = None
- self._staged_runner_id = None
- self._staged_runner_generation = 0
- self._staging_takeover = False
-
- if stale_count := self._fail_pending_requests(
- ErrorCode.E_PLUGIN_CRASHED,
- "Runner 连接已被新 generation 接管",
- generation=old_generation,
- ):
- logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求")
-
- if old_connection and old_connection is not self._connection and not old_connection.is_closed:
- await old_connection.close()
-
- async def rollback_staged_takeover(self) -> None:
- """放弃 staged Runner,保留当前活跃连接。"""
- staged_connection = self._staged_connection
- staged_generation = self._staged_runner_generation
-
- self._staged_connection = None
- self._staged_runner_id = None
- self._staged_runner_generation = 0
- self._staging_takeover = False
-
- self._fail_pending_requests(
- ErrorCode.E_PLUGIN_CRASHED,
- "新 Runner 预热失败,已回滚",
- generation=staged_generation,
- )
-
- if staged_connection and not staged_connection.is_closed:
- await staged_connection.close()
+ def clear_handshake_state(self) -> None:
+ """清空最近一次握手拒绝状态。"""
+ self._last_handshake_rejection_reason = ""
def register_method(self, method: str, handler: MethodHandler) -> None:
"""注册 RPC 方法处理器"""
@@ -165,6 +96,7 @@ class RPCServer:
async def start(self) -> None:
"""启动 RPC 服务器"""
self._running = True
+ self.clear_handshake_state()
self._send_queue = asyncio.Queue(maxsize=self._send_queue_size)
self._send_worker_task = asyncio.create_task(self._send_loop())
await self._transport.start(self._handle_connection)
@@ -173,14 +105,9 @@ class RPCServer:
async def stop(self) -> None:
"""停止 RPC 服务器"""
self._running = False
-
- # 取消所有 pending 请求
- for future, _generation in self._pending_requests.values():
- if not future.done():
- future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
- self._pending_requests.clear()
-
- self._fail_queued_sends(ErrorCode.E_TIMEOUT, "服务器关闭")
+ self.clear_handshake_state()
+ self._fail_pending_requests(ErrorCode.E_SHUTTING_DOWN, "服务器正在关闭")
+ self._fail_queued_sends(ErrorCode.E_SHUTTING_DOWN, "服务器正在关闭")
if self._send_worker_task:
self._send_worker_task.cancel()
@@ -198,10 +125,6 @@ class RPCServer:
await self._connection.close()
self._connection = None
- if self._staged_connection:
- await self._staged_connection.close()
- self._staged_connection = None
-
await self._transport.stop()
logger.info("RPC Server 已停止")
@@ -211,7 +134,6 @@ class RPCServer:
plugin_id: str = "",
payload: Optional[Dict[str, Any]] = None,
timeout_ms: int = 30000,
- target_generation: Optional[int] = None,
) -> Envelope:
"""向 Runner 发送 RPC 请求并等待响应
@@ -227,18 +149,14 @@ class RPCServer:
Raises:
RPCError: 调用失败
"""
- generation = target_generation or self._runner_generation
- conn = self._get_connection_for_generation(generation)
- if conn is None or conn.is_closed:
+ if not self._connection or self._connection.is_closed:
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
-
- request_id = self._id_gen.next()
+ request_id = await self._id_gen.next()
envelope = Envelope(
request_id=request_id,
message_type=MessageType.REQUEST,
method=method,
plugin_id=plugin_id,
- generation=generation,
timeout_ms=timeout_ms,
payload=payload or {},
)
@@ -246,12 +164,12 @@ class RPCServer:
# 注册 pending future
loop = asyncio.get_running_loop()
future: asyncio.Future[Envelope] = loop.create_future()
- self._pending_requests[request_id] = (future, generation)
+ self._pending_requests[request_id] = future
try:
# 发送请求
data = self._codec.encode_envelope(envelope)
- await self._enqueue_send(conn, data)
+ await self._enqueue_send(self._connection, data)
# 等待响应
timeout_sec = timeout_ms / 1000.0
@@ -265,150 +183,136 @@ class RPCServer:
raise
raise RPCError(ErrorCode.E_UNKNOWN, str(e)) from e
- async def send_event(self, method: str, plugin_id: str = "", payload: Optional[Dict[str, Any]] = None) -> None:
- """向 Runner 发送单向事件(不等待响应)"""
- conn = self._connection
- if conn is None or conn.is_closed:
- return
+ # ============ 内部方法 ============
+ # ========= 发送循环 =========
+ async def _send_loop(self) -> None:
+ """后台发送循环:串行消费发送队列,统一执行连接写入。"""
+ if self._send_queue is None:
+ raise RuntimeError("没有消息队列")
- request_id = self._id_gen.next()
- envelope = Envelope(
- request_id=request_id,
- message_type=MessageType.EVENT,
- method=method,
- plugin_id=plugin_id,
- generation=self._runner_generation,
- payload=payload or {},
- )
- data = self._codec.encode_envelope(envelope)
- await self._enqueue_send(conn, data)
+ while True:
+ try:
+ conn, data, send_future = await self._send_queue.get()
+ except asyncio.CancelledError:
+ break
- # ─── 内部方法 ──────────────────────────────────────────────
+ try:
+ if conn.is_closed:
+ raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
+ await conn.send_frame(data)
+ if not send_future.done():
+ send_future.set_result(None)
+ except asyncio.CancelledError:
+ if not send_future.done():
+ send_future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
+ raise
+ except Exception as e:
+ send_error = RPCError.from_exception(e, {ConnectionError: ErrorCode.E_PLUGIN_CRASHED})
+ if not send_future.done():
+ send_future.set_exception(send_error)
+ finally:
+ self._send_queue.task_done()
+ # ====== 发送循环方法 ======
async def _handle_connection(self, conn: Connection) -> None:
"""处理新的 Runner 连接"""
logger.info("收到 Runner 连接")
- previous_connection = self._connection
- previous_generation = self._runner_generation
-
- # 第一条消息必须是 runner.hello 握手
try:
- role = await self._handle_handshake(conn)
- if role is None:
- await conn.close()
- return
+ async with self._connection_lock:
+ self.clear_handshake_state()
+ success = await self._handle_handshake(conn)
+ if not success:
+ await conn.close()
+ return
+ logger.info("Runner staged 握手成功")
+ self._connection = conn
except Exception as e:
logger.error(f"握手失败: {e}")
await conn.close()
return
- if role == "staged":
- expected_generation = self._staged_runner_generation
- logger.info(
- f"Runner staged 握手成功: runner_id={self._staged_runner_id}, generation={self._staged_runner_generation}"
- )
- else:
- self._connection = conn
- expected_generation = self._runner_generation
- logger.info(f"Runner 握手成功: runner_id={self._runner_id}, generation={self._runner_generation}")
-
- if previous_connection and previous_connection is not conn and not previous_connection.is_closed:
- logger.info("检测到新 Runner 已接管连接,关闭旧连接")
- if stale_count := self._fail_pending_requests(
- ErrorCode.E_PLUGIN_CRASHED,
- "Runner 连接已被新 generation 接管",
- generation=previous_generation,
- ):
- logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求")
- await previous_connection.close()
-
# 启动消息接收循环
try:
- await self._recv_loop(conn, expected_generation=expected_generation)
+ await self._recv_loop(conn)
except Exception as e:
logger.error(f"连接异常断开: {e}")
finally:
- if self._connection is conn:
- self._connection = None
- self._runner_id = None
- self._fail_pending_requests(
- ErrorCode.E_PLUGIN_CRASHED,
- "Runner 连接已断开",
- generation=expected_generation,
- )
- elif self._staged_connection is conn:
- self._staged_connection = None
- self._staged_runner_id = None
- self._staged_runner_generation = 0
- self._fail_pending_requests(
- ErrorCode.E_PLUGIN_CRASHED,
- "Staged Runner 连接已断开",
- generation=expected_generation,
- )
+ should_fail_pending_requests = False
+ async with self._connection_lock:
+ if self._connection is conn:
+ self._connection = None
+ should_fail_pending_requests = True
+ if should_fail_pending_requests:
+ self._fail_pending_requests(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已断开")
- async def _handle_handshake(self, conn: Connection) -> Optional[str]:
+ async def _handle_handshake(self, conn: Connection) -> bool:
"""处理 runner.hello 握手"""
# 接收握手请求
data = await asyncio.wait_for(conn.recv_frame(), timeout=10.0)
envelope = self._codec.decode_envelope(data)
-
if envelope.method != "runner.hello":
logger.error(f"期望 runner.hello,收到 {envelope.method}")
+ self._last_handshake_rejection_reason = "首条消息必须为 runner.hello"
error_resp = envelope.make_error_response(
ErrorCode.E_PROTOCOL_MISMATCH.value,
"首条消息必须为 runner.hello",
)
await conn.send_frame(self._codec.encode_envelope(error_resp))
- return None
+ return False
# 解析握手 payload
hello = HelloPayload.model_validate(envelope.payload)
-
# 校验会话令牌
if hello.session_token != self._session_token:
logger.error("会话令牌不匹配")
- resp_payload = HelloResponsePayload(
- accepted=False,
- reason="会话令牌无效",
- )
+ self._last_handshake_rejection_reason = "会话令牌无效"
+ resp_payload = HelloResponsePayload(accepted=False, reason=self._last_handshake_rejection_reason)
resp = envelope.make_response(payload=resp_payload.model_dump())
await conn.send_frame(self._codec.encode_envelope(resp))
- return None
+ return False
+
+ # 若已有活跃连接,直接拒绝新的握手,避免后来的连接抢占当前通道。
+ if self.is_connected:
+ logger.warning("拒绝新的 Runner 连接:已有活跃连接")
+ self._last_handshake_rejection_reason = "已有活跃 Runner 连接,拒绝新的握手"
+ resp_payload = HelloResponsePayload(accepted=False, reason=self._last_handshake_rejection_reason)
+ resp = envelope.make_response(payload=resp_payload.model_dump())
+ await conn.send_frame(self._codec.encode_envelope(resp))
+ return False
# 校验 SDK 版本
if not self._check_sdk_version(hello.sdk_version):
logger.error(f"SDK 版本不兼容: {hello.sdk_version}")
+ self._last_handshake_rejection_reason = (
+ f"SDK 版本 {hello.sdk_version} 不在支持范围 [{MIN_SDK_VERSION}, {MAX_SDK_VERSION}]"
+ )
resp_payload = HelloResponsePayload(
accepted=False,
- reason=f"SDK 版本 {hello.sdk_version} 不在支持范围 [{MIN_SDK_VERSION}, {MAX_SDK_VERSION}]",
+ reason=self._last_handshake_rejection_reason,
)
resp = envelope.make_response(payload=resp_payload.model_dump())
await conn.send_frame(self._codec.encode_envelope(resp))
- return None
+ return False
- # 握手成功
- role = "active"
- assigned_generation = self._runner_generation + 1
- if self._staging_takeover and self.is_connected:
- role = "staged"
- self._staged_connection = conn
- self._staged_runner_id = hello.runner_id
- self._staged_runner_generation = assigned_generation
- else:
- self._runner_id = hello.runner_id
- self._runner_generation = assigned_generation
-
- resp_payload = HelloResponsePayload(
- accepted=True,
- host_version=PROTOCOL_VERSION,
- assigned_generation=assigned_generation,
- )
+ # 发送响应
+ self.clear_handshake_state()
+ resp_payload = HelloResponsePayload(accepted=True, host_version=PROTOCOL_VERSION)
resp = envelope.make_response(payload=resp_payload.model_dump())
await conn.send_frame(self._codec.encode_envelope(resp))
+ return True
- return role
+ def _check_sdk_version(self, sdk_version: str) -> bool:
+ """检查 SDK 版本是否在支持范围内"""
+ try:
+ sdk_parts = _parse_version_tuple(sdk_version)
+ min_parts = _parse_version_tuple(MIN_SDK_VERSION)
+ max_parts = _parse_version_tuple(MAX_SDK_VERSION)
+ return min_parts <= sdk_parts <= max_parts
+ except (ValueError, AttributeError):
+ return False
- async def _recv_loop(self, conn: Connection, expected_generation: int) -> None:
+ # ========= 接收循环 =========
+ async def _recv_loop(self, conn: Connection) -> None:
"""消息接收主循环"""
while self._running and not conn.is_closed:
try:
@@ -430,109 +334,40 @@ class RPCServer:
if envelope.is_response():
self._handle_response(envelope)
elif envelope.is_request():
- if envelope.generation != expected_generation:
- error_resp = envelope.make_error_response(
- ErrorCode.E_GENERATION_MISMATCH.value,
- f"过期 generation: {envelope.generation} != {expected_generation}",
- )
- await conn.send_frame(self._codec.encode_envelope(error_resp))
- continue
# 异步处理请求(Runner 发来的能力调用)
task = asyncio.create_task(self._handle_request(envelope, conn))
self._tasks.append(task)
task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None)
- elif envelope.is_event():
- if envelope.generation != expected_generation:
- logger.warning(
- f"忽略过期 generation 事件 {envelope.method}: {envelope.generation} != {expected_generation}"
- )
- continue
- task = asyncio.create_task(self._handle_event(envelope))
+ elif envelope.is_broadcast():
+ task = asyncio.create_task(self._handle_broadcast(envelope))
self._tasks.append(task)
task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None)
+ else:
+ logger.warning(f"未知的消息类型: {envelope.message_type}")
+ continue
+ # ====== 接收循环内部方法 ======
def _handle_response(self, envelope: Envelope) -> None:
"""处理来自 Runner 的响应"""
- pending = self._pending_requests.get(envelope.request_id)
- if pending is None:
+ pending_future = self._pending_requests.pop(envelope.request_id, None)
+ if pending_future is None:
return
-
- future, expected_generation = pending
- if envelope.generation != expected_generation:
- logger.warning(
- f"忽略过期 generation 响应 {envelope.method}: {envelope.generation} != {expected_generation}"
- )
- return
-
- self._pending_requests.pop(envelope.request_id, None)
- if not future.done():
+ if not pending_future.done():
if envelope.error:
- future.set_exception(RPCError.from_dict(envelope.error))
+ pending_future.set_exception(RPCError.from_dict(envelope.error))
else:
- future.set_result(envelope)
-
- async def _enqueue_send(self, conn: Connection, data: bytes) -> None:
- """通过发送队列串行发送消息,提供真实背压。"""
- if conn.is_closed:
- raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
-
- if self._send_queue is None:
- await conn.send_frame(data)
- return
-
- loop = asyncio.get_running_loop()
- send_future: asyncio.Future[None] = loop.create_future()
-
- try:
- self._send_queue.put_nowait((conn, data, send_future))
- except asyncio.QueueFull:
- raise RPCError(ErrorCode.E_BACKPRESSURE, "发送队列已满") from None
-
- await send_future
-
- async def _send_loop(self) -> None:
- """后台发送循环:串行消费发送队列,统一执行连接写入。"""
- if self._send_queue is None:
- return
-
- while True:
- try:
- conn, data, send_future = await self._send_queue.get()
- except asyncio.CancelledError:
- break
-
- try:
- if conn.is_closed:
- raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
- await conn.send_frame(data)
- if not send_future.done():
- send_future.set_result(None)
- except asyncio.CancelledError:
- if not send_future.done():
- send_future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
- raise
- except Exception as e:
- send_error = e if isinstance(e, RPCError) else self._normalize_send_exception(e)
- if not send_future.done():
- send_future.set_exception(send_error)
- finally:
- self._send_queue.task_done()
-
- @staticmethod
- def _normalize_send_exception(error: Exception) -> RPCError:
- if isinstance(error, ConnectionError):
- return RPCError(ErrorCode.E_PLUGIN_CRASHED, str(error))
- return RPCError(ErrorCode.E_UNKNOWN, str(error))
+ pending_future.set_result(envelope)
async def _handle_request(self, envelope: Envelope, conn: Connection) -> None:
"""处理来自 Runner 的请求(通常是能力调用 cap.*)"""
- handler = self._method_handlers.get(envelope.method)
- if handler is None:
- error_resp = envelope.make_error_response(
+ target_method = envelope.method
+ handler = self._method_handlers.get(target_method)
+ if not handler:
+ error_response = envelope.make_error_response(
ErrorCode.E_METHOD_NOT_ALLOWED.value,
f"未注册的方法: {envelope.method}",
)
- await conn.send_frame(self._codec.encode_envelope(error_resp))
+ await conn.send_frame(self._codec.encode_envelope(error_response))
return
try:
@@ -546,59 +381,25 @@ class RPCServer:
error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
await conn.send_frame(self._codec.encode_envelope(error_resp))
- async def _handle_event(self, envelope: Envelope) -> None:
- """处理来自 Runner 的事件"""
+ async def _handle_broadcast(self, envelope: Envelope) -> None:
if handler := self._method_handlers.get(envelope.method):
try:
result = await handler(envelope)
# 检查 handler 返回的信封是否包含错误信息
- if result is not None and isinstance(result, Envelope) and result.error:
+ if result.error:
logger.warning(f"事件 {envelope.method} handler 返回错误: {result.error.get('message', '')}")
except Exception as e:
logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True)
- @staticmethod
- def _check_sdk_version(sdk_version: str) -> bool:
- """检查 SDK 版本是否在支持范围内"""
- try:
- sdk_parts = RPCServer._parse_version_tuple(sdk_version)
- min_parts = RPCServer._parse_version_tuple(MIN_SDK_VERSION)
- max_parts = RPCServer._parse_version_tuple(MAX_SDK_VERSION)
- return min_parts <= sdk_parts <= max_parts
- except (ValueError, AttributeError):
- return False
-
- @staticmethod
- def _parse_version_tuple(version: str) -> Tuple[int, int, int]:
- base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version or "", flags=re.IGNORECASE)[0]
- base_version = base_version.split("+", 1)[0]
- parts = [part for part in base_version.split(".") if part != ""]
- while len(parts) < 3:
- parts.append("0")
- return (int(parts[0]), int(parts[1]), int(parts[2]))
-
- def _get_connection_for_generation(self, generation: int) -> Optional[Connection]:
- if generation == self._runner_generation:
- return self._connection
- if generation == self._staged_runner_generation:
- return self._staged_connection
- return None
-
- def _fail_pending_requests(
- self,
- error_code: ErrorCode,
- message: str,
- generation: Optional[int] = None,
- ) -> int:
- stale_count = 0
- for request_id, (future, request_generation) in list(self._pending_requests.items()):
- if generation is not None and request_generation != generation:
- continue
+ def _fail_pending_requests(self, error_code: ErrorCode, message: str) -> int:
+ """失败所有等待中的请求(如连接断开时)"""
+ aborted_request_count = 0
+ for future in self._pending_requests.values():
if not future.done():
future.set_exception(RPCError(error_code, message))
- stale_count += 1
- self._pending_requests.pop(request_id, None)
- return stale_count
+ aborted_request_count += 1
+ self._pending_requests.clear()
+ return aborted_request_count
def _fail_queued_sends(self, error_code: ErrorCode, message: str) -> int:
if self._send_queue is None:
@@ -617,3 +418,31 @@ class RPCServer:
self._send_queue.task_done()
return failed_count
+
+ async def _enqueue_send(self, conn: Connection, data: bytes) -> None:
+ """通过发送队列串行发送消息,提供真实背压。"""
+ if conn.is_closed:
+ raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
+
+ if self._send_queue is None:
+ await conn.send_frame(data)
+ return
+
+ loop = asyncio.get_running_loop()
+ send_future: asyncio.Future[None] = loop.create_future()
+
+ try:
+ self._send_queue.put_nowait((conn, data, send_future))
+ except asyncio.QueueFull:
+ raise RPCError(ErrorCode.E_BACK_PRESSURE, "发送队列已满") from None
+
+ await send_future
+
+
+def _parse_version_tuple(version: str) -> Tuple[int, int, int]:
+ base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version or "", flags=re.IGNORECASE)[0]
+ base_version = base_version.split("+", 1)[0]
+ parts = [part for part in base_version.split(".") if part != ""]
+ while len(parts) < 3:
+ parts.append("0")
+ return (int(parts[0]), int(parts[1]), int(parts[2]))
diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py
index bfa00cbf..08638d16 100644
--- a/src/plugin_runtime/host/supervisor.py
+++ b/src/plugin_runtime/host/supervisor.py
@@ -1,97 +1,80 @@
-"""Supervisor - 插件生命周期管理
-
-负责:
-1. 拉起 Runner 子进程
-2. 健康检查 + 崩溃自动重启
-3. 代码热重载(generation 切换)
-4. 优雅关停
-"""
-
-from typing import Any, Dict, List, Optional, Tuple
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import asyncio
import contextlib
-import logging as stdlib_logging
+import json
import os
import sys
-from pathlib import Path
from src.common.logger import get_logger
-from src.config.config import MMC_VERSION, global_config
-from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN
-from src.plugin_runtime.host.capability_service import CapabilityService
-from src.plugin_runtime.host.component_registry import ComponentRegistry
-from src.plugin_runtime.host.event_dispatcher import EventDispatcher
-from src.plugin_runtime.host.policy_engine import PolicyEngine
-from src.plugin_runtime.host.rpc_server import RPCServer
-from src.plugin_runtime.host.workflow_executor import WorkflowExecutor, WorkflowContext, WorkflowResult
+from src.config.config import config_manager, global_config
+from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, get_platform_io_manager
+from src.platform_io.drivers import PluginPlatformDriver
+from src.platform_io.route_key_factory import RouteKeyFactory
+from src.plugin_runtime import (
+ ENV_EXTERNAL_PLUGIN_IDS,
+ ENV_GLOBAL_CONFIG_SNAPSHOT,
+ ENV_HOST_VERSION,
+ ENV_IPC_ADDRESS,
+ ENV_PLUGIN_DIRS,
+ ENV_SESSION_TOKEN,
+)
from src.plugin_runtime.protocol.envelope import (
BootstrapPluginPayload,
+ ConfigReloadScope,
ConfigUpdatedPayload,
Envelope,
HealthPayload,
- LogBatchPayload,
- RegisterComponentsPayload,
+ MessageGatewayStateUpdatePayload,
+ MessageGatewayStateUpdateResultPayload,
+ PROTOCOL_VERSION,
+ ReceiveExternalMessageResultPayload,
+ RegisterPluginPayload,
+ ReloadPluginResultPayload,
+ ReloadPluginsPayload,
+ ReloadPluginsResultPayload,
+ RouteMessagePayload,
RunnerReadyPayload,
ShutdownPayload,
+ UnregisterPluginPayload,
)
+from src.plugin_runtime.protocol.codec import MsgPackCodec
from src.plugin_runtime.protocol.errors import ErrorCode, RPCError
from src.plugin_runtime.transport.factory import create_transport_server
-logger = get_logger("plugin_runtime.host.supervisor")
+from .authorization import AuthorizationManager
+from .api_registry import APIRegistry
+from .capability_service import CapabilityService
+from .component_registry import ComponentRegistry
+from .event_dispatcher import EventDispatcher
+from .hook_dispatcher import HookDispatcher
+from .logger_bridge import RunnerLogBridge
+from .message_gateway import MessageGateway
+from .rpc_server import RPCServer
+
+if TYPE_CHECKING:
+ from src.chat.message_receive.message import SessionMessage
+
+logger = get_logger("plugin_runtime.host.runner_manager")
+
+@dataclass(slots=True)
+class _MessageGatewayRuntimeState:
+ """保存消息网关当前的运行时连接状态。"""
+
+ ready: bool = False
+ platform: Optional[str] = None
+ account_id: Optional[str] = None
+ scope: Optional[str] = None
+ metadata: Dict[str, Any] = field(default_factory=dict)
-# ─── 日志桥 ──────────────────────────────────────────────────────
+class PluginRunnerSupervisor:
+ """插件 Runner 监督器。
-
-class RunnerLogBridge:
- """将 Runner 进程上报的批量日志重放到主进程的 Logger 中。
-
- Runner 通过 ``runner.log_batch`` IPC 事件批量到达。
- 每条 LogEntry 被重建为一个真实的 :class:`logging.LogRecord` 并直接
- 调用 ``logging.getLogger(entry.logger_name).handle(record)``,
- 从而接入主进程已配置好的 structlog Handler 链。
- """
-
- async def handle_log_batch(self, envelope: Envelope) -> Envelope:
- """IPC 事件处理器:解析批量日志并重放到主进程 Logger。
-
- Args:
- envelope: 方法名为 ``runner.log_batch`` 的 IPC 事件信封。
-
- Returns:
- 空响应信封(事件模式下将被忽略)。
- """
- try:
- batch = LogBatchPayload.model_validate(envelope.payload)
- except Exception as exc:
- return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
-
- for entry in batch.entries:
- # 重建一个与原始日志尽量相符的 LogRecord
- record = stdlib_logging.LogRecord(
- name=entry.logger_name,
- level=entry.level,
- pathname="",
- lineno=0,
- msg=entry.message,
- args=(),
- exc_info=None,
- )
- record.created = entry.timestamp_ms / 1000.0
- record.msecs = entry.timestamp_ms % 1000
- if entry.exception_text:
- record.exc_text = entry.exception_text
-
- stdlib_logging.getLogger(entry.logger_name).handle(record)
-
- return envelope.make_response(payload={"accepted": True, "count": len(batch.entries)})
-
-
-class PluginSupervisor:
- """插件 Supervisor
-
- Host 端的核心管理器,负责整个插件 Runner 进程的生命周期。
+ 负责 Host 侧与单个 Runner 子进程之间的生命周期、内部 RPC、
+ 健康检查和插件级重载协调。
"""
def __init__(
@@ -101,196 +84,253 @@ class PluginSupervisor:
health_check_interval_sec: Optional[float] = None,
max_restart_attempts: Optional[int] = None,
runner_spawn_timeout_sec: Optional[float] = None,
- ):
- _cfg = global_config.plugin_runtime
- self._plugin_dirs = plugin_dirs or []
- self._health_interval = (
- health_check_interval_sec if health_check_interval_sec is not None else _cfg.health_check_interval_sec
- )
- self._runner_spawn_timeout = (
- runner_spawn_timeout_sec if runner_spawn_timeout_sec is not None else _cfg.runner_spawn_timeout_sec
- )
+ ) -> None:
+ """初始化 Supervisor。
+
+ Args:
+ plugin_dirs: 由当前 Runner 负责加载的插件目录列表。
+ socket_path: 自定义 IPC 地址;留空时由传输层自动生成。
+ health_check_interval_sec: 健康检查间隔,单位秒。
+ max_restart_attempts: 自动重启 Runner 的最大次数。
+ runner_spawn_timeout_sec: 等待 Runner 建连并就绪的超时时间,单位秒。
+ """
+ runtime_config = global_config.plugin_runtime
+ self._plugin_dirs: List[Path] = plugin_dirs or []
+ self._health_interval: float = health_check_interval_sec or runtime_config.health_check_interval_sec or 30.0
+ self._runner_spawn_timeout: float = (
+ runner_spawn_timeout_sec or runtime_config.runner_spawn_timeout_sec or 30.0
+ )
+ self._max_restart_attempts: int = max_restart_attempts or runtime_config.max_restart_attempts or 3
- # 基础设施
self._transport = create_transport_server(socket_path=socket_path)
- self._policy = PolicyEngine()
- self._capability_service = CapabilityService(self._policy)
+ self._authorization = AuthorizationManager()
+ self._capability_service = CapabilityService(self._authorization)
+ self._api_registry = APIRegistry()
self._component_registry = ComponentRegistry()
self._event_dispatcher = EventDispatcher(self._component_registry)
- self._workflow_executor = WorkflowExecutor(self._component_registry)
-
- # 编解码
- from src.plugin_runtime.protocol.codec import MsgPackCodec
+ self._hook_dispatcher = HookDispatcher(self._component_registry)
+ self._message_gateway = MessageGateway(self._component_registry)
+ self._log_bridge = RunnerLogBridge()
codec = MsgPackCodec()
+ self._rpc_server = RPCServer(transport=self._transport, codec=codec)
- self._rpc_server = RPCServer(
- transport=self._transport,
- codec=codec,
- )
-
- # Runner 子进程
self._runner_process: Optional[asyncio.subprocess.Process] = None
- self._runner_generation: int = 0
- self._max_restart_attempts: int = (
- max_restart_attempts if max_restart_attempts is not None else _cfg.max_restart_attempts
- )
+ self._registered_plugins: Dict[str, RegisterPluginPayload] = {}
+ self._message_gateway_states: Dict[str, Dict[str, _MessageGatewayRuntimeState]] = {}
+ self._external_available_plugins: Dict[str, str] = {}
+ self._runner_ready_events: asyncio.Event = asyncio.Event()
+ self._runner_ready_payloads: RunnerReadyPayload = RunnerReadyPayload()
+ self._health_task: Optional[asyncio.Task[None]] = None
+ self._stderr_drain_task: Optional[asyncio.Task[None]] = None
self._restart_count: int = 0
+ self._running: bool = False
- # 已注册的插件组件信息
- self._registered_plugins: Dict[str, RegisterComponentsPayload] = {}
- self._staged_registered_plugins: Dict[str, RegisterComponentsPayload] = {}
- self._runner_ready_events: Dict[int, asyncio.Event] = {}
- self._runner_ready_payloads: Dict[int, RunnerReadyPayload] = {}
-
- # 后台任务
- self._health_task: Optional[asyncio.Task] = None
- # Runner stderr 流排空任务(仅保留 stderr,用于 IPC 建立前的启动日志倒空、致命错误输出等场景)
- self._stderr_drain_task: Optional[asyncio.Task] = None
- self._running = False
-
- # Runner 日志桥(将 Runner 上报的批量日志重放到主进程 Logger)
- self._log_bridge: RunnerLogBridge = RunnerLogBridge()
-
- # 注册内部 RPC 方法
self._register_internal_methods()
@property
- def policy_engine(self) -> PolicyEngine:
- return self._policy
+ def authorization_manager(self) -> AuthorizationManager:
+ """返回授权管理器。"""
+ return self._authorization
@property
def capability_service(self) -> CapabilityService:
+ """返回能力服务。"""
return self._capability_service
+ @property
+ def api_registry(self) -> APIRegistry:
+ """返回 API 专用注册表。"""
+ return self._api_registry
+
@property
def component_registry(self) -> ComponentRegistry:
+ """返回组件注册表。"""
return self._component_registry
@property
def event_dispatcher(self) -> EventDispatcher:
+ """返回事件分发器。"""
return self._event_dispatcher
@property
- def workflow_executor(self) -> WorkflowExecutor:
- return self._workflow_executor
+ def hook_dispatcher(self) -> HookDispatcher:
+ """返回 Hook 分发器。"""
+ return self._hook_dispatcher
+
+ @property
+ def message_gateway(self) -> MessageGateway:
+ """返回消息网关。"""
+ return self._message_gateway
@property
def rpc_server(self) -> RPCServer:
+ """返回底层 RPC 服务端。"""
return self._rpc_server
+ def set_external_available_plugins(self, plugin_versions: Dict[str, str]) -> None:
+ """设置当前 Runner 启动/重载时可视为已满足的外部依赖版本映射。
+
+ Args:
+ plugin_versions: 外部插件版本映射,键为插件 ID,值为插件版本。
+ """
+ self._external_available_plugins = {
+ str(plugin_id or "").strip(): str(plugin_version or "").strip()
+ for plugin_id, plugin_version in plugin_versions.items()
+ if str(plugin_id or "").strip() and str(plugin_version or "").strip()
+ }
+
+ def get_loaded_plugin_ids(self) -> List[str]:
+ """返回当前 Supervisor 已注册的插件 ID 列表。"""
+
+ return sorted(self._registered_plugins.keys())
+
+ def get_loaded_plugin_versions(self) -> Dict[str, str]:
+ """返回当前 Supervisor 已注册插件的版本映射。
+
+ Returns:
+ Dict[str, str]: 已注册插件版本映射,键为插件 ID,值为插件版本。
+ """
+ return {
+ plugin_id: registration.plugin_version
+ for plugin_id, registration in self._registered_plugins.items()
+ }
+
+ @staticmethod
+ def _normalize_reload_plugin_ids(plugin_ids: Optional[List[str] | str]) -> List[str]:
+ """规范化批量重载入参。
+
+ Args:
+ plugin_ids: 原始插件 ID 列表或单个插件 ID。
+
+ Returns:
+ List[str]: 去重且去空白后的插件 ID 列表。
+ """
+
+ raw_plugin_ids: List[str]
+ if plugin_ids is None:
+ raw_plugin_ids = []
+ elif isinstance(plugin_ids, str):
+ raw_plugin_ids = [plugin_ids]
+ else:
+ raw_plugin_ids = list(plugin_ids)
+
+ normalized_plugin_ids: List[str] = []
+ seen_plugin_ids: set[str] = set()
+ for plugin_id in raw_plugin_ids:
+ normalized_plugin_id = str(plugin_id or "").strip()
+ if not normalized_plugin_id or normalized_plugin_id in seen_plugin_ids:
+ continue
+ seen_plugin_ids.add(normalized_plugin_id)
+ normalized_plugin_ids.append(normalized_plugin_id)
+ return normalized_plugin_ids
+
async def dispatch_event(
self,
event_type: str,
- message: Optional[Dict[str, Any]] = None,
+ message: Optional["SessionMessage"] = None,
extra_args: Optional[Dict[str, Any]] = None,
- ) -> Tuple[bool, Optional[Dict[str, Any]]]:
- """分发事件到所有对应 handler 的快捷方法。"""
+ ) -> Tuple[bool, Optional["SessionMessage"]]:
+ """分发事件到已注册的事件处理器。
- async def _invoke(plugin_id: str, component_name: str, args: Dict[str, Any]) -> Dict[str, Any]:
- resp = await self.invoke_plugin(
- method="plugin.emit_event",
- plugin_id=plugin_id,
- component_name=component_name,
- args=args,
- )
- return resp.payload
+ Args:
+ event_type: 事件类型。
+ message: 可选的消息对象。
+ extra_args: 附加参数。
- return await self._event_dispatcher.dispatch_event(
- event_type=event_type,
- invoke_fn=_invoke,
- message=message,
- extra_args=extra_args,
- )
+ Returns:
+ Tuple[bool, Optional[SessionMessage]]: 是否继续处理,以及插件可能修改后的消息。
+ """
+ return await self._event_dispatcher.dispatch_event(event_type, self, message, extra_args)
- async def execute_workflow(
+ async def dispatch_hook(self, stage: str, **kwargs: Any) -> Dict[str, Any]:
+ """分发 Hook 到已注册的 Hook 处理器。
+
+ Args:
+ stage: Hook 阶段名称。
+ **kwargs: 传递给 Hook 的关键字参数。
+
+ Returns:
+ Dict[str, Any]: 经 Hook 修改后的参数字典。
+ """
+ return await self._hook_dispatcher.hook_dispatch(stage, self, **kwargs)
+
+ async def send_message_to_external(
self,
- message: Optional[Dict[str, Any]] = None,
- stream_id: Optional[str] = None,
- context: Optional[WorkflowContext] = None,
- ) -> Tuple[WorkflowResult, Optional[Dict[str, Any]], WorkflowContext]:
- """执行 Workflow Pipeline 的快捷方法。"""
+ internal_message: "SessionMessage",
+ *,
+ enabled_only: bool = True,
+ save_to_db: bool = True,
+ ) -> bool:
+ """通过插件消息网关发送外部消息。
- async def _invoke(plugin_id: str, component_name: str, args: Dict[str, Any]) -> Dict[str, Any]:
- resp = await self.invoke_plugin(
- method="plugin.invoke_workflow_step",
- plugin_id=plugin_id,
- component_name=component_name,
- args=args,
- )
- payload = resp.payload
- if payload.get("success"):
- result = payload.get("result")
- return result if isinstance(result, dict) else {}
- raise RuntimeError(payload.get("result", "workflow step invoke failed"))
+ Args:
+ internal_message: 系统内部消息对象。
+ enabled_only: 是否仅使用启用的网关组件。
+ save_to_db: 发送成功后是否写入数据库。
- async def _command_invoke(plugin_id: str, component_name: str, args: Dict[str, Any]) -> Dict[str, Any]:
- """命令走 plugin.invoke_command,保留原始返回值结构。"""
- resp = await self.invoke_plugin(
- method="plugin.invoke_command",
- plugin_id=plugin_id,
- component_name=component_name,
- args=args,
- )
- return resp.payload
-
- return await self._workflow_executor.execute(
- invoke_fn=_invoke,
- message=message,
- stream_id=stream_id,
- context=context,
- command_invoke_fn=_command_invoke,
+ Returns:
+ bool: 是否发送成功。
+ """
+ return await self._message_gateway.send_message_to_external(
+ internal_message,
+ self,
+ enabled_only=enabled_only,
+ save_to_db=save_to_db,
)
async def start(self) -> None:
- """启动 Supervisor
+ """启动 Supervisor。"""
+ if self._running:
+ logger.warning("PluginRunnerSupervisor 已在运行,跳过重复启动")
+ return
- 1. 启动 RPC Server
- 2. 拉起 Runner 子进程
- 3. 启动健康检查
- """
self._running = True
+ self._restart_count = 0
+ self._clear_runner_state()
- # 启动 RPC Server
- await self._rpc_server.start()
-
- # 计算预期 generation(与 reload_plugins 保持一致)
- expected_generation = self._rpc_server.runner_generation + 1
-
- # 拉起 Runner 进程
- await self._spawn_runner()
-
- # 等待 Runner 完成连接和初始化,避免 start() 返回时 Runner 尚未就绪
try:
- await self._wait_for_runner_generation(expected_generation, timeout_sec=self._runner_spawn_timeout)
- await self._wait_for_runner_ready(expected_generation, timeout_sec=self._runner_spawn_timeout)
- except TimeoutError:
- if not self._rpc_server.is_connected:
- logger.warning(f"Runner 未在 {self._runner_spawn_timeout}s 内完成连接,后续操作可能失败")
- else:
- logger.warning(f"Runner 未在 {self._runner_spawn_timeout}s 内完成初始化,后续操作可能失败")
+ await self._rpc_server.start()
+ await self._spawn_runner()
- # 启动健康检查
- self._health_task = asyncio.create_task(self._health_check_loop())
+ try:
+ await self._wait_for_runner_connection(timeout_sec=self._runner_spawn_timeout)
+ await self._wait_for_runner_ready(timeout_sec=self._runner_spawn_timeout)
+ except TimeoutError:
+ if not self._rpc_server.is_connected:
+ logger.warning("Runner 未在限定时间内完成连接,后续操作可能失败")
+ else:
+ logger.warning("Runner 未在限定时间内完成初始化,后续操作可能失败")
+ except Exception:
+ await self._shutdown_runner(reason="startup_failed")
+ await self._rpc_server.stop()
+ self._clear_runner_state()
+ self._running = False
+ raise
- logger.info("PluginSupervisor 已启动")
+ self._health_task = asyncio.create_task(self._health_check_loop(), name="PluginRunnerSupervisor.health")
+ logger.info("PluginRunnerSupervisor 已启动")
async def stop(self) -> None:
- """停止 Supervisor"""
+ """停止 Supervisor。"""
+ if not self._running:
+ return
+
self._running = False
- # 停止健康检查
- if self._health_task:
+ if self._health_task is not None:
self._health_task.cancel()
+ with contextlib.suppress(asyncio.CancelledError):
+ await self._health_task
self._health_task = None
- # 优雅关停 Runner
- await self._shutdown_runner()
-
- # 停止 RPC Server
+ await self._event_dispatcher.stop()
+ await self._hook_dispatcher.stop()
+ await self._shutdown_runner(reason="host_stop")
await self._rpc_server.stop()
+ self._clear_runner_state()
- logger.info("PluginSupervisor 已停止")
+ logger.info("PluginRunnerSupervisor 已停止")
async def invoke_plugin(
self,
@@ -300,444 +340,1068 @@ class PluginSupervisor:
args: Optional[Dict[str, Any]] = None,
timeout_ms: int = 30000,
) -> Envelope:
- """调用插件组件
+ """调用 Runner 内的插件组件。
- 由主进程业务逻辑调用,通过 RPC 转发给 Runner。
+ Args:
+ method: RPC 方法名。
+ plugin_id: 目标插件 ID。
+ component_name: 组件名。
+ args: 调用参数。
+ timeout_ms: RPC 超时时间,单位毫秒。
+
+ Returns:
+ Envelope: RPC 响应信封。
"""
return await self._rpc_server.send_request(
- method=method,
+ method,
+ plugin_id,
+ {"component_name": component_name, "args": args or {}},
+ timeout_ms,
+ )
+
+ async def invoke_message_gateway(
+ self,
+ plugin_id: str,
+ component_name: str,
+ args: Optional[Dict[str, Any]] = None,
+ timeout_ms: int = 30000,
+ ) -> Envelope:
+ """调用插件声明的消息网关方法。
+
+ Args:
+ plugin_id: 目标插件 ID。
+ component_name: 消息网关组件名称。
+ args: 传递给网关方法的关键字参数。
+ timeout_ms: RPC 超时时间,单位毫秒。
+
+ Returns:
+ Envelope: Runner 返回的响应信封。
+ """
+
+ return await self.invoke_plugin(
+ method="plugin.invoke_message_gateway",
plugin_id=plugin_id,
- payload={
- "component_name": component_name,
- "args": args or {},
- },
+ component_name=component_name,
+ args=args,
timeout_ms=timeout_ms,
)
- async def reload_plugins(self, reason: str = "manual") -> bool:
- """热重载所有插件(进程级 generation 切换)
+ async def invoke_api(
+ self,
+ plugin_id: str,
+ component_name: str,
+ args: Optional[Dict[str, Any]] = None,
+ timeout_ms: int = 30000,
+ ) -> Envelope:
+ """调用插件声明的 API 方法。
- 1. 拉起新 Runner
- 2. 等待新 Runner 完成注册和健康检查
- 3. 关停旧 Runner
+ Args:
+ plugin_id: 目标插件 ID。
+ component_name: API 组件名称。
+ args: 传递给 API 方法的关键字参数。
+ timeout_ms: RPC 超时时间,单位毫秒。
+
+ Returns:
+ Envelope: Runner 返回的响应信封。
"""
- logger.info(f"开始热重载插件,原因: {reason}")
- # 保存旧进程引用和旧 session token(回滚时需要恢复)
- old_process = self._runner_process
- old_registered_plugins = dict(self._registered_plugins)
- old_session_token = self._rpc_server.session_token
- expected_generation = self._rpc_server.runner_generation + 1
+ return await self.invoke_plugin(
+ method="plugin.invoke_api",
+ plugin_id=plugin_id,
+ component_name=component_name,
+ args=args,
+ timeout_ms=timeout_ms,
+ )
- # 允许新 Runner 以 staged 方式接入,验证通过后再切换活跃连接
- self._rpc_server.begin_staged_takeover()
- self._staged_registered_plugins.clear()
+ async def reload_plugin(
+ self,
+ plugin_id: str,
+ reason: str = "manual",
+ external_available_plugins: Optional[Dict[str, str]] = None,
+ ) -> bool:
+ """按插件 ID 触发精确重载。
- # 重新生成 session token,防止被终止的旧 Runner 重连
- self._rpc_server.reset_session_token()
+ Args:
+ plugin_id: 目标插件 ID。
+ reason: 重载原因。
+ external_available_plugins: 视为已满足的外部依赖插件版本映射。
- # 注意:不在此处调用 _clear_runtime_state()。
- # 旧组件在新 Runner 完成注册前继续提供服务,避免热重载窗口期内
- # dispatch_event / execute_workflow 找不到任何组件导致消息静默丢失。
- # ComponentRegistry.register_component 对同名组件是覆盖式写入,安全。
-
- # 拉起新 Runner
+ Returns:
+ bool: 是否重载成功。
+ """
try:
- await self._spawn_runner()
- await self._wait_for_runner_generation(
- expected_generation,
- timeout_sec=self._runner_spawn_timeout,
- allow_staged=True,
+ response = await self._rpc_server.send_request(
+ "plugin.reload",
+ plugin_id=plugin_id,
+ payload={
+ "plugin_id": plugin_id,
+ "reason": reason,
+ "external_available_plugins": external_available_plugins or self._external_available_plugins,
+ },
+ timeout_ms=max(int(self._runner_spawn_timeout * 1000), 10000),
)
- await self._wait_for_runner_ready(expected_generation, timeout_sec=self._runner_spawn_timeout)
- resp = await self._rpc_server.send_request(
- "plugin.health",
- timeout_ms=5000,
- target_generation=expected_generation,
- )
- health = HealthPayload.model_validate(resp.payload)
- if not health.healthy:
- raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "新 Runner 健康检查失败")
- await self._rpc_server.commit_staged_takeover()
- except Exception as e:
- logger.error(f"新 Runner 健康检查失败: {e},回滚")
- await self._terminate_process(self._runner_process, old_process)
- await self._rpc_server.rollback_staged_takeover()
- self._runner_process = old_process
- self._rpc_server.restore_session_token(old_session_token)
- self._staged_registered_plugins.clear()
- self._registered_plugins = dict(old_registered_plugins)
- self._rebuild_runtime_state()
+ except Exception as exc:
+ logger.error(f"插件 {plugin_id} 重载请求失败: {exc}")
return False
- self._runner_generation = self._rpc_server.runner_generation
- self._registered_plugins = dict(self._staged_registered_plugins)
- self._staged_registered_plugins.clear()
- self._rebuild_runtime_state()
+ result = ReloadPluginResultPayload.model_validate(response.payload)
+ if not result.success:
+ logger.warning(f"插件 {plugin_id} 重载失败: {result.failed_plugins}")
+ return result.success
- # 关停旧 Runner
- if old_process and old_process.returncode is None:
- try:
- old_process.terminate()
- await asyncio.wait_for(old_process.wait(), timeout=10.0)
- except asyncio.TimeoutError:
- old_process.kill()
+ async def reload_plugins(
+ self,
+ plugin_ids: Optional[List[str] | str] = None,
+ reason: str = "manual",
+ external_available_plugins: Optional[Dict[str, str]] = None,
+ ) -> bool:
+ """批量重载插件。
- logger.info("热重载完成")
- return True
+ Args:
+ plugin_ids: 目标插件 ID 列表;为空时重载当前已注册的全部插件。
+ reason: 重载原因。
+ external_available_plugins: 视为已满足的外部依赖插件版本映射。
+
+ Returns:
+ bool: 是否全部重载成功。
+ """
+ ordered_plugin_ids = self._normalize_reload_plugin_ids(plugin_ids)
+ if not ordered_plugin_ids:
+ ordered_plugin_ids = list(self._registered_plugins.keys())
+ if not ordered_plugin_ids:
+ return True
+
+ if len(ordered_plugin_ids) == 1:
+ return await self.reload_plugin(
+ plugin_id=ordered_plugin_ids[0],
+ reason=reason,
+ external_available_plugins=external_available_plugins,
+ )
+
+ try:
+ response = await self._rpc_server.send_request(
+ "plugin.reload_batch",
+ payload=ReloadPluginsPayload(
+ plugin_ids=ordered_plugin_ids,
+ reason=reason,
+ external_available_plugins=external_available_plugins or self._external_available_plugins,
+ ).model_dump(),
+ timeout_ms=max(int(self._runner_spawn_timeout * 1000), 10000),
+ )
+ except Exception as exc:
+ logger.error(f"插件批量重载请求失败: {exc}")
+ return False
+
+ result = ReloadPluginsResultPayload.model_validate(response.payload)
+ if not result.success:
+ logger.warning(f"插件批量重载失败: {result.failed_plugins}")
+ return result.success
async def notify_plugin_config_updated(
self,
plugin_id: str,
- config_data: Dict[str, Any],
+ config_data: Optional[Dict[str, Any]] = None,
config_version: str = "",
+ config_scope: str | ConfigReloadScope = "self",
) -> bool:
- """通知指定插件其配置已更新。"""
- if plugin_id not in self._registered_plugins:
+ """向 Runner 推送插件配置更新。
+
+ Args:
+ plugin_id: 目标插件 ID。
+ config_data: 配置内容。
+ config_version: 配置版本号。
+ config_scope: 配置变更范围。
+
+ Returns:
+ bool: 请求是否成功送达并被 Runner 接受。
+ """
+ try:
+ normalized_scope = ConfigReloadScope(config_scope)
+ except ValueError:
+ logger.warning(f"插件 {plugin_id} 配置更新通知失败: 非法的 config_scope={config_scope}")
return False
payload = ConfigUpdatedPayload(
plugin_id=plugin_id,
+ config_scope=normalized_scope,
config_version=config_version,
- config_data=config_data,
+ config_data=config_data or {},
)
- await self._rpc_server.send_request(
- "plugin.config_updated",
- plugin_id=plugin_id,
- payload=payload.model_dump(),
- timeout_ms=5000,
- )
- return True
+ try:
+ response = await self._rpc_server.send_request(
+ "plugin.config_updated",
+ plugin_id=plugin_id,
+ payload=payload.model_dump(),
+ timeout_ms=10000,
+ )
+ except Exception as exc:
+ logger.warning(f"插件 {plugin_id} 配置更新通知失败: {exc}")
+ return False
- # ─── 内部方法 ──────────────────────────────────────────────
+ return bool(response.payload.get("acknowledged", False))
+
+ def get_config_reload_subscribers(self, scope: str) -> List[str]:
+ """返回订阅指定全局配置广播的插件列表。
+
+ Args:
+ scope: 配置变更范围,仅支持 ``bot`` 或 ``model``。
+
+ Returns:
+ List[str]: 已声明订阅该范围的插件 ID 列表。
+ """
+
+ return [
+ plugin_id
+ for plugin_id, registration in self._registered_plugins.items()
+ if scope in registration.config_reload_subscriptions
+ ]
+
+ async def _wait_for_runner_connection(self, timeout_sec: float) -> None:
+ """等待 Runner 建立 RPC 连接。
+
+ Args:
+ timeout_sec: 超时时间,单位秒。
+
+ Raises:
+ TimeoutError: 在超时时间内 Runner 未完成连接。
+ """
+
+ async def wait_for_connection() -> None:
+ """轮询等待 RPC 连接建立。"""
+ while True:
+ if self._rpc_server.is_connected:
+ return
+
+ if not self._running:
+ raise RuntimeError("Supervisor 已停止,等待 Runner 连接已取消")
+
+ if failure_reason := self._get_runner_startup_failure_reason():
+ raise RuntimeError(f"等待 Runner 连接失败: {failure_reason}")
+
+ await asyncio.sleep(0.1)
+
+ try:
+ await asyncio.wait_for(wait_for_connection(), timeout=timeout_sec)
+ logger.info("Runner 已连接到 RPC Server")
+ except asyncio.TimeoutError as exc:
+ raise TimeoutError(f"等待 Runner 连接超时({timeout_sec}s)") from exc
+
+ async def _wait_for_runner_ready(self, timeout_sec: float = 30.0) -> RunnerReadyPayload:
+ """等待 Runner 完成启动初始化。
+
+ Args:
+ timeout_sec: 超时时间,单位秒。
+
+ Returns:
+ RunnerReadyPayload: Runner 上报的就绪信息。
+
+ Raises:
+ TimeoutError: 在超时时间内 Runner 未完成初始化。
+ """
+ async def wait_for_ready() -> RunnerReadyPayload:
+ """轮询等待 Runner 上报就绪。"""
+ while True:
+ if self._runner_ready_events.is_set():
+ return self._runner_ready_payloads
+
+ if not self._running:
+ raise RuntimeError("Supervisor 已停止,等待 Runner 就绪已取消")
+
+ if failure_reason := self._get_runner_startup_failure_reason():
+ raise RuntimeError(f"等待 Runner 就绪失败: {failure_reason}")
+
+ if not self._rpc_server.is_connected:
+ raise RuntimeError("等待 Runner 就绪失败: Runner RPC 连接已断开")
+
+ await asyncio.sleep(0.1)
+
+ try:
+ payload = await asyncio.wait_for(wait_for_ready(), timeout=timeout_sec)
+ logger.info("Runner 已完成初始化并上报就绪")
+ return payload
+ except asyncio.TimeoutError as exc:
+ raise TimeoutError(f"等待 Runner 就绪超时({timeout_sec}s)") from exc
def _register_internal_methods(self) -> None:
- """注册 Host 端的 RPC 方法处理器"""
- # Runner -> Host 的能力调用统一走 capability_service
- self._rpc_server.register_method("cap.request", self._capability_service.handle_capability_request)
+ """注册 Host 侧内部 RPC 方法。"""
+ self._rpc_server.register_method("cap.call", self._capability_service.handle_capability_request)
+ self._rpc_server.register_method("host.route_message", self._handle_route_message)
+ self._rpc_server.register_method("host.update_message_gateway_state", self._handle_update_message_gateway_state)
self._rpc_server.register_method("plugin.bootstrap", self._handle_bootstrap_plugin)
- # 插件注册
- self._rpc_server.register_method("plugin.register_components", self._handle_register_components)
- self._rpc_server.register_method("runner.ready", self._handle_runner_ready)
- # Runner 日志批量上报
+ self._rpc_server.register_method("plugin.register_components", self._handle_register_plugin)
+ self._rpc_server.register_method("plugin.register_plugin", self._handle_register_plugin)
+ self._rpc_server.register_method("plugin.unregister", self._handle_unregister_plugin)
self._rpc_server.register_method("runner.log_batch", self._log_bridge.handle_log_batch)
+ self._rpc_server.register_method("runner.ready", self._handle_runner_ready)
async def _handle_bootstrap_plugin(self, envelope: Envelope) -> Envelope:
- """处理插件 bootstrap 请求,仅同步能力令牌。"""
+ """处理插件 bootstrap 请求。
+
+ Args:
+ envelope: RPC 请求信封。
+
+ Returns:
+ Envelope: RPC 响应信封。
+ """
try:
- bootstrap = BootstrapPluginPayload.model_validate(envelope.payload)
- except Exception as e:
- return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e))
+ payload = BootstrapPluginPayload.model_validate(envelope.payload)
+ except Exception as exc:
+ return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
- active_generation = self._rpc_server.runner_generation
- staged_generation = self._rpc_server.staged_generation
- if envelope.generation not in {active_generation, staged_generation}:
- return envelope.make_error_response(
- ErrorCode.E_GENERATION_MISMATCH.value,
- f"插件 bootstrap generation 过期: {envelope.generation} 不在已知代际中",
- )
-
- if bootstrap.capabilities_required:
- self._policy.register_plugin(
- plugin_id=bootstrap.plugin_id,
- generation=envelope.generation,
- capabilities=bootstrap.capabilities_required,
- )
+ if payload.capabilities_required:
+ self._authorization.register_plugin(payload.plugin_id, payload.capabilities_required)
else:
- self._policy.revoke_plugin(bootstrap.plugin_id, generation=envelope.generation)
+ self._authorization.revoke_permission_token(payload.plugin_id)
- return envelope.make_response(payload={"accepted": True})
+ return envelope.make_response(payload={"accepted": True, "plugin_id": payload.plugin_id})
- async def _handle_register_components(self, envelope: Envelope) -> Envelope:
- """处理插件组件注册请求"""
+ async def _handle_register_plugin(self, envelope: Envelope) -> Envelope:
+ """处理插件组件注册请求。
+
+ Args:
+ envelope: RPC 请求信封。
+
+ Returns:
+ Envelope: RPC 响应信封。
+ """
try:
- reg = RegisterComponentsPayload.model_validate(envelope.payload)
- except Exception as e:
- return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e))
+ payload = RegisterPluginPayload.model_validate(envelope.payload)
+ except Exception as exc:
+ return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
- active_generation = self._rpc_server.runner_generation
- staged_generation = self._rpc_server.staged_generation
- if envelope.generation not in {active_generation, staged_generation}:
+ component_declarations = [component.model_dump() for component in payload.components]
+ runtime_components, api_components = self._split_component_declarations(component_declarations)
+ self._component_registry.remove_components_by_plugin(payload.plugin_id)
+ self._api_registry.remove_apis_by_plugin(payload.plugin_id)
+ await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id)
+
+ registered_count = self._component_registry.register_plugin_components(
+ payload.plugin_id,
+ runtime_components,
+ )
+ registered_api_count = self._api_registry.register_plugin_apis(payload.plugin_id, api_components)
+ self._registered_plugins[payload.plugin_id] = payload
+ self._message_gateway_states[payload.plugin_id] = {}
+
+ return envelope.make_response(
+ payload={
+ "accepted": True,
+ "plugin_id": payload.plugin_id,
+ "registered_components": registered_count,
+ "registered_apis": registered_api_count,
+ "message_gateways": len(
+ self._component_registry.get_message_gateways(plugin_id=payload.plugin_id, enabled_only=False)
+ ),
+ }
+ )
+
+ async def _handle_unregister_plugin(self, envelope: Envelope) -> Envelope:
+ """处理插件注销请求。
+
+ Args:
+ envelope: RPC 请求信封。
+
+ Returns:
+ Envelope: RPC 响应信封。
+ """
+ try:
+ payload = UnregisterPluginPayload.model_validate(envelope.payload)
+ except Exception as exc:
+ return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
+
+ removed_components = self._component_registry.remove_components_by_plugin(payload.plugin_id)
+ removed_apis = self._api_registry.remove_apis_by_plugin(payload.plugin_id)
+ self._authorization.revoke_permission_token(payload.plugin_id)
+ removed_registration = self._registered_plugins.pop(payload.plugin_id, None) is not None
+ await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id)
+ self._message_gateway_states.pop(payload.plugin_id, None)
+
+ return envelope.make_response(
+ payload={
+ "accepted": True,
+ "plugin_id": payload.plugin_id,
+ "reason": payload.reason,
+ "removed_components": removed_components,
+ "removed_apis": removed_apis,
+ "removed_registration": removed_registration,
+ }
+ )
+
+ @staticmethod
+ def _is_api_component(component: Dict[str, Any]) -> bool:
+ """判断组件声明是否属于 API。
+
+ Args:
+ component: 原始组件声明字典。
+
+ Returns:
+ bool: 是否为 API 组件。
+ """
+
+ return str(component.get("component_type", "") or "").strip().upper() == "API"
+
+ def _split_component_declarations(
+ self,
+ components: List[Dict[str, Any]],
+ ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
+ """拆分通用组件声明和 API 声明。
+
+ Args:
+ components: Runner 上报的原始组件声明列表。
+
+ Returns:
+ Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
+ 第一个列表为需要进入通用组件表的声明,
+ 第二个列表为需要进入 API 专用表的声明。
+ """
+
+ runtime_components: List[Dict[str, Any]] = []
+ api_components: List[Dict[str, Any]] = []
+ for component in components:
+ if self._is_api_component(component):
+ api_components.append(component)
+ else:
+ runtime_components.append(component)
+ return runtime_components, api_components
+
+ @staticmethod
+ def _build_message_gateway_driver_id(plugin_id: str, gateway_name: str) -> str:
+ """构造消息网关驱动 ID。
+
+ Args:
+ plugin_id: 插件 ID。
+ gateway_name: 网关组件名称。
+
+ Returns:
+ str: 对应 Platform IO 中的驱动 ID。
+ """
+
+ return f"gateway:{plugin_id}:{gateway_name}"
+
+ @staticmethod
+ def _normalize_runtime_route_value(value: str) -> Optional[str]:
+ """规范化运行时路由字段。
+
+ Args:
+ value: 待规范化的原始字符串。
+
+ Returns:
+ Optional[str]: 规范化后非空则返回字符串,否则返回 ``None``。
+ """
+
+ normalized_value = str(value or "").strip()
+ return normalized_value or None
+
+ def _resolve_message_gateway_entry(
+ self,
+ plugin_id: str,
+ gateway_name: str,
+ ) -> Optional[Any]:
+ """解析指定插件的消息网关组件。
+
+ Args:
+ plugin_id: 插件 ID。
+ gateway_name: 网关组件名称;为空时按兼容规则推断。
+
+ Returns:
+ Optional[Any]: 匹配到的消息网关组件条目。
+ """
+
+ if gateway_name:
+ return self._component_registry.get_message_gateway(
+ plugin_id=plugin_id,
+ name=gateway_name,
+ enabled_only=False,
+ )
+
+ gateways = self._component_registry.get_message_gateways(plugin_id=plugin_id, enabled_only=False)
+ return gateways[0] if len(gateways) == 1 else None
+
+ async def _register_message_gateway_driver(
+ self,
+ plugin_id: str,
+ gateway_entry: Any,
+ route_key: RouteKey,
+ ) -> None:
+ """为消息网关注册驱动并绑定发送/接收路由。
+
+ Args:
+ plugin_id: 插件 ID。
+ gateway_entry: 消息网关组件条目。
+ route_key: 当前链路对应的路由键。
+ """
+
+ await self._unregister_message_gateway_driver(plugin_id, gateway_entry.name)
+
+ platform_io_manager = get_platform_io_manager()
+ driver = PluginPlatformDriver(
+ driver_id=self._build_message_gateway_driver_id(plugin_id, gateway_entry.name),
+ platform=route_key.platform,
+ account_id=route_key.account_id,
+ scope=route_key.scope,
+ plugin_id=plugin_id,
+ component_name=gateway_entry.name,
+ supports_send=bool(gateway_entry.supports_send),
+ supervisor=self,
+ metadata={
+ "protocol": gateway_entry.protocol,
+ "route_type": gateway_entry.route_type,
+ **gateway_entry.metadata,
+ },
+ )
+
+ try:
+ if platform_io_manager.is_started:
+ await platform_io_manager.add_driver(driver)
+ else:
+ platform_io_manager.register_driver(driver)
+ except Exception:
+ with contextlib.suppress(Exception):
+ if platform_io_manager.is_started:
+ await platform_io_manager.remove_driver(driver.driver_id)
+ else:
+ platform_io_manager.unregister_driver(driver.driver_id)
+ raise
+
+ binding_metadata = {
+ "plugin_id": plugin_id,
+ "gateway_name": gateway_entry.name,
+ "protocol": gateway_entry.protocol,
+ "route_type": gateway_entry.route_type,
+ **gateway_entry.metadata,
+ }
+ binding = RouteBinding(
+ route_key=route_key,
+ driver_id=driver.driver_id,
+ driver_kind=DriverKind.PLUGIN,
+ metadata=binding_metadata,
+ )
+ if gateway_entry.supports_send:
+ platform_io_manager.bind_send_route(binding)
+ if gateway_entry.supports_receive:
+ platform_io_manager.bind_receive_route(binding)
+
+ async def _unregister_message_gateway_driver(self, plugin_id: str, gateway_name: str) -> None:
+ """从 Platform IO 注销单个消息网关驱动。
+
+ Args:
+ plugin_id: 插件 ID。
+ gateway_name: 网关组件名称。
+ """
+
+ platform_io_manager = get_platform_io_manager()
+ driver_id = self._build_message_gateway_driver_id(plugin_id, gateway_name)
+ platform_io_manager.send_route_table.remove_bindings_by_driver(driver_id)
+ platform_io_manager.receive_route_table.remove_bindings_by_driver(driver_id)
+
+ with contextlib.suppress(Exception):
+ if platform_io_manager.is_started:
+ await platform_io_manager.remove_driver(driver_id)
+ else:
+ platform_io_manager.unregister_driver(driver_id)
+
+ async def _unregister_all_message_gateway_drivers_for_plugin(self, plugin_id: str) -> None:
+ """注销指定插件的全部消息网关驱动。
+
+ Args:
+ plugin_id: 插件 ID。
+ """
+
+ gateway_names = list(self._message_gateway_states.get(plugin_id, {}).keys())
+ for gateway_name in gateway_names:
+ await self._unregister_message_gateway_driver(plugin_id, gateway_name)
+
+ def _build_message_gateway_route_key(
+ self,
+ gateway_entry: Any,
+ payload: MessageGatewayStateUpdatePayload,
+ ) -> RouteKey:
+ """根据消息网关运行时状态构造路由键。
+
+ Args:
+ gateway_entry: 消息网关组件条目。
+ payload: 网关上报的运行时状态。
+
+ Returns:
+ RouteKey: 当前链路对应的路由键。
+
+ Raises:
+ ValueError: 当平台信息缺失时抛出。
+ """
+
+ if not (platform := str(payload.platform or gateway_entry.platform or "").strip()):
+ raise ValueError(f"消息网关 {gateway_entry.full_name} 未提供有效的平台名称")
+
+ return RouteKey(
+ platform=platform,
+ account_id=self._normalize_runtime_route_value(payload.account_id) or gateway_entry.account_id or None,
+ scope=self._normalize_runtime_route_value(payload.scope) or gateway_entry.scope or None,
+ )
+
+ def _apply_message_gateway_state(
+ self,
+ plugin_id: str,
+ gateway_entry: Any,
+ payload: MessageGatewayStateUpdatePayload,
+ ) -> Tuple[_MessageGatewayRuntimeState, Dict[str, Any]]:
+ """应用消息网关运行时状态,并同步 Platform IO 路由。
+
+ Args:
+ plugin_id: 插件 ID。
+ gateway_entry: 消息网关组件条目。
+ payload: 网关上报的运行时状态。
+
+ Returns:
+ Tuple[_MessageGatewayRuntimeState, Dict[str, Any]]: 更新后的状态与路由键字典。
+ """
+
+ plugin_states = self._message_gateway_states.setdefault(plugin_id, {})
+ if not payload.ready:
+ runtime_state = _MessageGatewayRuntimeState(
+ ready=False,
+ platform=self._normalize_runtime_route_value(payload.platform) or gateway_entry.platform or None,
+ account_id=self._normalize_runtime_route_value(payload.account_id) or gateway_entry.account_id or None,
+ scope=self._normalize_runtime_route_value(payload.scope) or gateway_entry.scope or None,
+ metadata=dict(payload.metadata),
+ )
+ plugin_states[gateway_entry.name] = runtime_state
+ return runtime_state, {}
+
+ route_key = self._build_message_gateway_route_key(gateway_entry, payload)
+ runtime_state = _MessageGatewayRuntimeState(
+ ready=True,
+ platform=route_key.platform,
+ account_id=route_key.account_id,
+ scope=route_key.scope,
+ metadata=dict(payload.metadata),
+ )
+ plugin_states[gateway_entry.name] = runtime_state
+ return runtime_state, {
+ "platform": route_key.platform,
+ "account_id": route_key.account_id,
+ "scope": route_key.scope,
+ }
+
+ @staticmethod
+ def _attach_inbound_route_metadata(
+ session_message: "SessionMessage",
+ route_key: RouteKey,
+ route_metadata: Dict[str, Any],
+ ) -> None:
+ """将入站路由信息写回消息的 ``additional_config``。
+
+ Args:
+ session_message: 已构造好的内部消息对象。
+ route_key: Host 为该消息解析出的标准路由键。
+ route_metadata: 插件通过 RPC 补充的原始路由辅助元数据。
+ """
+
+ additional_config = session_message.message_info.additional_config
+ if not isinstance(additional_config, dict):
+ additional_config = {}
+ session_message.message_info.additional_config = additional_config
+
+ for key, value in route_metadata.items():
+ if value is None:
+ continue
+ normalized_value = str(value).strip()
+ if normalized_value:
+ additional_config[key] = value
+
+ if route_key.account_id:
+ additional_config.setdefault("platform_io_account_id", route_key.account_id)
+ if route_key.scope:
+ additional_config.setdefault("platform_io_scope", route_key.scope)
+
+ def _build_inbound_route_key(
+ self,
+ gateway_entry: Any,
+ runtime_state: _MessageGatewayRuntimeState,
+ message: Dict[str, Any],
+ route_metadata: Dict[str, Any],
+ ) -> RouteKey:
+ """为入站消息构造归一路由键。
+
+ Args:
+ gateway_entry: 接收消息的网关组件条目。
+ runtime_state: 当前网关的运行时状态。
+ message: 标准消息字典。
+ route_metadata: 插件补充的路由辅助元数据。
+
+ Returns:
+ RouteKey: 供 Platform IO 使用的规范化路由键。
+ """
+
+ platform = str(
+ message.get("platform")
+ or route_metadata.get("platform")
+ or runtime_state.platform
+ or gateway_entry.platform
+ or ""
+ ).strip()
+ if not platform:
+ raise ValueError(f"消息网关 {gateway_entry.full_name} 的入站消息缺少平台信息")
+
+ try:
+ route_key = RouteKeyFactory.from_message_dict(message)
+ except Exception:
+ route_key = RouteKey(platform=platform)
+
+ route_account_id, route_scope = RouteKeyFactory.extract_components(route_metadata)
+ account_id = route_key.account_id or route_account_id or runtime_state.account_id or gateway_entry.account_id or None
+ scope = route_key.scope or route_scope or runtime_state.scope or gateway_entry.scope or None
+ return RouteKey(
+ platform=platform,
+ account_id=account_id,
+ scope=scope,
+ )
+
+ async def _handle_update_message_gateway_state(self, envelope: Envelope) -> Envelope:
+ """处理消息网关上报的运行时状态更新。
+
+ Args:
+ envelope: RPC 请求信封。
+
+ Returns:
+ Envelope: 状态更新处理结果。
+ """
+
+ try:
+ payload = MessageGatewayStateUpdatePayload.model_validate(envelope.payload)
+ except Exception as exc:
+ return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
+
+ gateway_entry = self._resolve_message_gateway_entry(envelope.plugin_id, payload.gateway_name)
+ if gateway_entry is None:
return envelope.make_error_response(
- ErrorCode.E_GENERATION_MISMATCH.value,
- f"组件注册 generation 过期: {envelope.generation} 不在已知代际中",
+ ErrorCode.E_METHOD_NOT_ALLOWED.value,
+ f"插件 {envelope.plugin_id} 未声明消息网关 {payload.gateway_name or ''}",
)
- if envelope.generation == staged_generation and staged_generation != 0:
- self._staged_registered_plugins[reg.plugin_id] = reg
- logger.info(
- f"插件 {reg.plugin_id} v{reg.plugin_version} staged 注册成功,"
- f"组件数: {len(reg.components)}, 能力需求: {reg.capabilities_required}"
+ try:
+ if payload.ready:
+ route_key = self._build_message_gateway_route_key(gateway_entry, payload)
+ await self._register_message_gateway_driver(envelope.plugin_id, gateway_entry, route_key)
+ else:
+ await self._unregister_message_gateway_driver(envelope.plugin_id, gateway_entry.name)
+ runtime_state, route_key_dict = self._apply_message_gateway_state(
+ plugin_id=envelope.plugin_id,
+ gateway_entry=gateway_entry,
+ payload=payload,
)
- return envelope.make_response(payload={"accepted": True, "staged": True})
+ except Exception as exc:
+ return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
- self._registered_plugins[reg.plugin_id] = reg
-
- # 在策略引擎中注册插件
- self._policy.register_plugin(
- plugin_id=reg.plugin_id,
- generation=envelope.generation,
- capabilities=reg.capabilities_required or [],
+ response = MessageGatewayStateUpdateResultPayload(
+ accepted=True,
+ ready=runtime_state.ready,
+ route_key=route_key_dict,
)
+ return envelope.make_response(payload=response.model_dump())
- # 同 generation 下重新注册时,以本次声明为准,避免残留幽灵组件
- self._component_registry.remove_components_by_plugin(reg.plugin_id)
- self._component_registry.register_plugin_components(
- plugin_id=reg.plugin_id,
- components=[c.model_dump() for c in reg.components],
+ async def _handle_route_message(self, envelope: Envelope) -> Envelope:
+ """处理消息网关上报的外部入站消息。
+
+ Args:
+ envelope: RPC 请求信封。
+
+ Returns:
+ Envelope: 注入结果响应。
+ """
+
+ try:
+ payload = RouteMessagePayload.model_validate(envelope.payload)
+ except Exception as exc:
+ return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
+
+ gateway_entry = self._resolve_message_gateway_entry(envelope.plugin_id, payload.gateway_name)
+ if gateway_entry is None or not bool(gateway_entry.supports_receive):
+ return envelope.make_error_response(
+ ErrorCode.E_METHOD_NOT_ALLOWED.value,
+ f"插件 {envelope.plugin_id} 未声明可接收的消息网关 {payload.gateway_name}",
+ )
+
+ runtime_state = self._message_gateway_states.get(envelope.plugin_id, {}).get(
+ gateway_entry.name,
+ _MessageGatewayRuntimeState(),
)
+ if not runtime_state.ready:
+ return envelope.make_error_response(
+ ErrorCode.E_METHOD_NOT_ALLOWED.value,
+ f"消息网关 {gateway_entry.full_name} 尚未就绪,不能注入外部消息",
+ )
- stats = self._component_registry.get_stats()
- logger.info(
- f"插件 {reg.plugin_id} v{reg.plugin_version} 注册成功,"
- f"组件数: {len(reg.components)}, 能力需求: {reg.capabilities_required},"
- f"注册表总计: {stats}"
+ try:
+ route_key = self._build_inbound_route_key(
+ gateway_entry=gateway_entry,
+ runtime_state=runtime_state,
+ message=payload.message,
+ route_metadata=payload.route_metadata,
+ )
+ session_message = self._message_gateway.build_session_message(payload.message)
+ self._attach_inbound_route_metadata(session_message, route_key, payload.route_metadata)
+ except Exception as exc:
+ return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
+
+ platform_io_manager = get_platform_io_manager()
+ accepted = await platform_io_manager.accept_inbound(
+ InboundMessageEnvelope(
+ route_key=route_key,
+ driver_id=self._build_message_gateway_driver_id(envelope.plugin_id, gateway_entry.name),
+ driver_kind=DriverKind.PLUGIN,
+ external_message_id=payload.external_message_id or str(payload.message.get("message_id") or "") or None,
+ dedupe_key=payload.dedupe_key or None,
+ session_message=session_message,
+ payload=payload.message,
+ metadata={
+ "plugin_id": envelope.plugin_id,
+ "gateway_name": gateway_entry.name,
+ "protocol": gateway_entry.protocol,
+ **payload.route_metadata,
+ },
+ )
)
-
- return envelope.make_response(payload={"accepted": True})
+ response = ReceiveExternalMessageResultPayload(
+ accepted=accepted,
+ route_key={
+ "platform": route_key.platform,
+ "account_id": route_key.account_id,
+ "scope": route_key.scope,
+ },
+ )
+ return envelope.make_response(payload=response.model_dump())
async def _handle_runner_ready(self, envelope: Envelope) -> Envelope:
- """处理 Runner 初始化完成信号。"""
- try:
- ready = RunnerReadyPayload.model_validate(envelope.payload)
- except Exception as e:
- return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e))
+ """处理 Runner 就绪通知。
- event = self._runner_ready_events.setdefault(envelope.generation, asyncio.Event())
- self._runner_ready_payloads[envelope.generation] = ready
- event.set()
- logger.info(
- f"Runner generation={envelope.generation} 已就绪,成功插件数: {len(ready.loaded_plugins)},"
- f"失败插件数: {len(ready.failed_plugins)}"
- )
+ Args:
+ envelope: RPC 请求信封。
+
+ Returns:
+ Envelope: RPC 响应信封。
+ """
+ try:
+ payload = RunnerReadyPayload.model_validate(envelope.payload)
+ except Exception as exc:
+ return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
+
+ self._runner_ready_payloads = payload
+ self._runner_ready_events.set()
return envelope.make_response(payload={"accepted": True})
+ def _build_runner_environment(self) -> Dict[str, str]:
+ """构建拉起 Runner 所需的环境变量。
+
+ Returns:
+ Dict[str, str]: 传递给 Runner 进程的环境变量映射。
+ """
+ global_config_snapshot = config_manager.get_global_config().model_dump(mode="json")
+ global_config_snapshot["model"] = config_manager.get_model_config().model_dump(mode="json")
+ return {
+ ENV_EXTERNAL_PLUGIN_IDS: json.dumps(self._external_available_plugins, ensure_ascii=False),
+ ENV_GLOBAL_CONFIG_SNAPSHOT: json.dumps(global_config_snapshot, ensure_ascii=False),
+ ENV_HOST_VERSION: PROTOCOL_VERSION,
+ ENV_IPC_ADDRESS: self._transport.get_address(),
+ ENV_PLUGIN_DIRS: os.pathsep.join(str(path) for path in self._plugin_dirs),
+ ENV_SESSION_TOKEN: self._rpc_server.session_token,
+ }
+
async def _spawn_runner(self) -> None:
- """拉起 Runner 子进程"""
- runner_module = "src.plugin_runtime.runner.runner_main"
- address = self._transport.get_address()
- token = self._rpc_server.session_token
+ """拉起 Runner 子进程。"""
+ if self._runner_process is not None and self._runner_process.returncode is None:
+ logger.warning("Runner 已在运行,跳过重复拉起")
+ return
+
+ self._clear_runner_state()
env = os.environ.copy()
- env[ENV_IPC_ADDRESS] = address
- env[ENV_SESSION_TOKEN] = token
- env[ENV_PLUGIN_DIRS] = os.pathsep.join(str(p) for p in self._plugin_dirs)
- env[ENV_HOST_VERSION] = MMC_VERSION
+ env.update(self._build_runner_environment())
self._runner_process = await asyncio.create_subprocess_exec(
sys.executable,
"-m",
- runner_module,
+ "src.plugin_runtime.runner.runner_main",
env=env,
- # stdout 不捕获:Runner 的日志均通过 IPC 传㛹(RunnerIPCLogHandler)
- stdout=None,
- # stderr 捕获为 PIPE,仅用于 IPC 建立前的进程级致命错误输出
+ stdout=asyncio.subprocess.DEVNULL,
stderr=asyncio.subprocess.PIPE,
)
- self._attach_stderr_drain(self._runner_process)
- self._runner_generation = self._rpc_server.runner_generation
- logger.info(f"Runner 子进程已启动: pid={self._runner_process.pid}, generation={self._runner_generation}")
-
- async def _shutdown_runner(self) -> None:
- """优雅关停 Runner"""
- if not self._runner_process or self._runner_process.returncode is not None:
- return
-
- # 发送 prepare_shutdown
- try:
- if self._rpc_server.is_connected:
- shutdown_payload = ShutdownPayload(reason="host_shutdown", drain_timeout_ms=5000)
- await self._rpc_server.send_request(
- "plugin.prepare_shutdown",
- payload=shutdown_payload.model_dump(),
- timeout_ms=5000,
- )
- await self._rpc_server.send_request(
- "plugin.shutdown",
- payload=shutdown_payload.model_dump(),
- timeout_ms=5000,
- )
- except Exception as e:
- logger.warning(f"发送关停命令失败: {e}")
-
- # 等待进程退出
- try:
- await asyncio.wait_for(self._runner_process.wait(), timeout=10.0)
- except asyncio.TimeoutError:
- logger.warning("Runner 未在超时内退出,强制终止")
- self._runner_process.kill()
- await self._runner_process.wait()
-
- await self._cleanup_stderr_drain()
-
- async def _health_check_loop(self) -> None:
- """周期性健康检查 + 崩溃自动重启"""
- while self._running:
- await asyncio.sleep(self._health_interval)
-
- # 检查 Runner 进程是否意外退出
- if self._runner_process and self._runner_process.returncode is not None:
- exit_code = self._runner_process.returncode
- logger.warning(f"Runner 进程已退出 (exit_code={exit_code})")
-
- if self._restart_count < self._max_restart_attempts:
- self._restart_count += 1
- logger.info(f"尝试重启 Runner ({self._restart_count}/{self._max_restart_attempts})")
- # 清理旧的组件注册
- for plugin_id in list(self._registered_plugins.keys()):
- self._component_registry.remove_components_by_plugin(plugin_id)
- self._policy.revoke_plugin(plugin_id)
- self._registered_plugins.clear()
-
- try:
- self._clear_runtime_state()
- # 重新生成 session token,防止旧 Runner 僵尸进程用旧 token 重连
- self._rpc_server.reset_session_token()
- await self._spawn_runner()
- except Exception as e:
- logger.error(f"Runner 重启失败: {e}", exc_info=True)
- else:
- logger.error(f"Runner 连续崩溃 {self._max_restart_attempts} 次,停止重启")
- continue
-
- if not self._rpc_server.is_connected:
- logger.warning("Runner 未连接,跳过健康检查")
- continue
-
- try:
- resp = await self._rpc_server.send_request("plugin.health", timeout_ms=5000)
- health = HealthPayload.model_validate(resp.payload)
- if not health.healthy:
- logger.warning(f"Runner 健康检查异常: {health}")
- else:
- # 健康检查成功,重置重启计数
- self._restart_count = 0
- except RPCError as e:
- logger.error(f"健康检查失败: {e}")
- except asyncio.CancelledError:
- break
- except Exception as e:
- logger.error(f"健康检查异常: {e}")
-
- async def _wait_for_runner_generation(
- self,
- expected_generation: int,
- timeout_sec: float,
- allow_staged: bool = False,
- ) -> None:
- """等待指定代际的 Runner 完成连接。"""
- deadline = asyncio.get_running_loop().time() + timeout_sec
- while asyncio.get_running_loop().time() < deadline:
- if allow_staged and self._rpc_server.has_generation(expected_generation):
- return
- if self._rpc_server.is_connected and self._rpc_server.runner_generation >= expected_generation:
- self._runner_generation = self._rpc_server.runner_generation
- return
- await asyncio.sleep(0.1)
- raise TimeoutError(f"等待 Runner generation {expected_generation} 超时")
-
- async def _wait_for_runner_ready(self, expected_generation: int, timeout_sec: float) -> RunnerReadyPayload:
- """等待指定代际的 Runner 完成初始化。"""
- event = self._runner_ready_events.setdefault(expected_generation, asyncio.Event())
- await asyncio.wait_for(event.wait(), timeout=timeout_sec)
- return self._runner_ready_payloads.get(expected_generation, RunnerReadyPayload())
-
- def _clear_runtime_state(self) -> None:
- """清空当前插件注册态。"""
- self._component_registry.clear()
- self._policy.clear()
- self._registered_plugins.clear()
- self._staged_registered_plugins.clear()
-
- def _rebuild_runtime_state(self) -> None:
- """根据已记录的插件注册信息重建运行时状态。"""
- self._component_registry.clear()
- self._policy.clear()
- for reg in self._registered_plugins.values():
- self._policy.register_plugin(
- plugin_id=reg.plugin_id,
- generation=self._rpc_server.runner_generation,
- capabilities=reg.capabilities_required or [],
- )
- self._component_registry.register_plugin_components(
- plugin_id=reg.plugin_id,
- components=[c.model_dump() for c in reg.components],
+ if self._runner_process.stderr is not None:
+ self._stderr_drain_task = asyncio.create_task(
+ self._drain_runner_stderr(self._runner_process.stderr),
+ name="PluginRunnerSupervisor.stderr",
)
- def _attach_stderr_drain(self, process: asyncio.subprocess.Process) -> None:
- """为 Runner stderr 创建排空任务,捕获 IPC 建立前的进程级错误输出。
+ logger.info(f"Runner 已拉起,pid={self._runner_process.pid}")
- stderr 中的内容通常是:
- - Runner 启动早期(握手完成之前)的日志
- - 进程级致命错误(ImportError、SyntaxError等)
- - 异常进程退出前的最后输出
-
- 握手成功后,插件的所有日志均经由 RunnerIPCLogHandler 通过 IPC 传输。
- """
- if process.stderr is None:
- return
- task = asyncio.create_task(
- self._drain_runner_stderr(process.stderr, process.pid),
- name=f"runner_stderr_drain:{process.pid}",
- )
- self._stderr_drain_task = task
- task.add_done_callback(
- lambda done_task: None if self._stderr_drain_task is not done_task else self._clear_stderr_drain_task()
- )
-
- def _clear_stderr_drain_task(self) -> None:
- self._stderr_drain_task = None
-
- async def _drain_runner_stderr(
- self,
- stream: asyncio.StreamReader,
- pid: int,
- ) -> None:
- """持续读取 Runner stderr 并转发到 Host Logger,防止 PIPE 锡死子进程。
+ async def _drain_runner_stderr(self, stream: asyncio.StreamReader) -> None:
+ """持续排空 Runner 的 stderr。
Args:
- stream: Runner 子进程的 stderr 流。
- pid: 子进程 PID,仅用于日志上下文。
+ stream: Runner 的 stderr 流。
"""
try:
while True:
line = await stream.readline()
if not line:
- break
- if message := line.decode(errors="replace").rstrip():
- # 将 stderr 输出以 WARNING 级展示:
- # 如果 Runner 正常运行,此流应当无输出;
- # 有输出说明进程级错误发生,需要出现在主进程日志中
- logger.warning(f"[runner:{pid}:stderr] {message}")
+ return
+ if message := line.decode("utf-8", errors="replace").rstrip():
+ logger.warning(f"[runner-stderr] {message}")
except asyncio.CancelledError:
raise
except Exception as exc:
- logger.debug(f"读取 Runner stderr 失败 (pid={pid}): {exc}")
+ logger.warning(f"排空 Runner stderr 失败: {exc}")
- async def _cleanup_stderr_drain(self) -> None:
- """等待并取消 stderr 排空任务。"""
- if self._stderr_drain_task is None:
- return
- task = self._stderr_drain_task
- self._stderr_drain_task = None
- if not task.done():
- task.cancel()
- with contextlib.suppress(Exception):
- await asyncio.gather(task, return_exceptions=True)
+ async def _shutdown_runner(self, reason: str = "normal") -> None:
+ """优雅关闭 Runner 子进程。
- @staticmethod
- async def _terminate_process(
- process: Optional[asyncio.subprocess.Process],
- keep_process: Optional[asyncio.subprocess.Process] = None,
- ) -> None:
- """终止指定进程,但跳过需要保留的旧进程引用。"""
- if process is None or process is keep_process or process.returncode is not None:
+ Args:
+ reason: 关停原因。
+ """
+ process = self._runner_process
+ if process is None:
return
- process.terminate()
+ payload = ShutdownPayload(reason=reason)
+
+ if process.returncode is None and self._rpc_server.is_connected:
+ with contextlib.suppress(Exception):
+ await self._rpc_server.send_request(
+ "plugin.prepare_shutdown",
+ payload=payload.model_dump(),
+ timeout_ms=payload.drain_timeout_ms,
+ )
+ with contextlib.suppress(Exception):
+ await self._rpc_server.send_request(
+ "plugin.shutdown",
+ payload=payload.model_dump(),
+ timeout_ms=payload.drain_timeout_ms,
+ )
+
+ if process.returncode is None:
+ try:
+ await asyncio.wait_for(process.wait(), timeout=max(payload.drain_timeout_ms / 1000.0, 1.0))
+ except asyncio.TimeoutError:
+ logger.warning("Runner 优雅退出超时,尝试 terminate")
+ process.terminate()
+ try:
+ await asyncio.wait_for(process.wait(), timeout=5.0)
+ except asyncio.TimeoutError:
+ logger.warning("Runner terminate 超时,尝试 kill")
+ process.kill()
+ with contextlib.suppress(Exception):
+ await asyncio.wait_for(process.wait(), timeout=5.0)
+
+ self._runner_process = None
+
+ if self._stderr_drain_task is not None:
+ self._stderr_drain_task.cancel()
+ with contextlib.suppress(asyncio.CancelledError):
+ await self._stderr_drain_task
+ self._stderr_drain_task = None
+
+ for plugin_id in list(self._message_gateway_states.keys()):
+ await self._unregister_all_message_gateway_drivers_for_plugin(plugin_id)
+ self._clear_runner_state()
+
+ async def _health_check_loop(self) -> None:
+ """周期性检查 Runner 健康状态,并在必要时重启。"""
+ timeout_ms = max(int(self._health_interval * 1000), 1000)
+
+ while self._running:
+ try:
+ await asyncio.sleep(self._health_interval)
+ except asyncio.CancelledError:
+ return
+
+ if not self._running:
+ return
+
+ process = self._runner_process
+ if process is None or process.returncode is not None:
+ reason = "runner_process_exited" if process is not None else "runner_process_missing"
+ restarted = await self._restart_runner(reason=reason)
+ if not restarted:
+ return
+ continue
+
+ try:
+ response = await self._rpc_server.send_request("plugin.health", timeout_ms=timeout_ms)
+ health = HealthPayload.model_validate(response.payload)
+ if not health.healthy:
+ restarted = await self._restart_runner(reason="health_check_unhealthy")
+ if not restarted:
+ return
+ except asyncio.CancelledError:
+ return
+ except (RPCError, Exception) as exc:
+ logger.warning(f"Runner 健康检查失败: {exc}")
+ restarted = await self._restart_runner(reason="health_check_failed")
+ if not restarted:
+ return
+
+ async def _restart_runner(self, reason: str) -> bool:
+ """在 Runner 异常时执行整进程级重启。
+
+ Args:
+ reason: 触发重启的原因。
+
+ Returns:
+ bool: 是否重启成功。
+ """
+ if not self._running:
+ return False
+
+ if self._restart_count >= self._max_restart_attempts:
+ logger.error(f"Runner 自动重启次数已达上限,停止重启。reason={reason}")
+ return False
+
+ self._restart_count += 1
+ logger.warning(f"准备重启 Runner,第 {self._restart_count} 次,reason={reason}")
+
+ await self._shutdown_runner(reason=reason)
+
try:
- await asyncio.wait_for(process.wait(), timeout=10.0)
- except asyncio.TimeoutError:
- process.kill()
- await process.wait()
+ await self._spawn_runner()
+ await self._wait_for_runner_connection(timeout_sec=self._runner_spawn_timeout)
+ await self._wait_for_runner_ready(timeout_sec=self._runner_spawn_timeout)
+ except Exception as exc:
+ await self._shutdown_runner(reason="restart_failed")
+ logger.error(f"Runner 重启失败: {exc}", exc_info=True)
+ return False
+
+ self._restart_count = 0
+ logger.info("Runner 已成功重启")
+ return True
+
+ def _clear_runner_state(self) -> None:
+ """清理当前 Runner 对应的 Host 侧注册状态。"""
+ self._authorization.clear()
+ self._api_registry.clear()
+ self._component_registry.clear()
+ self._registered_plugins.clear()
+ self._message_gateway_states.clear()
+ self._runner_ready_events = asyncio.Event()
+ self._runner_ready_payloads = RunnerReadyPayload()
+ self._rpc_server.clear_handshake_state()
+
+ def _get_runner_startup_failure_reason(self) -> Optional[str]:
+ """获取 Runner 在启动阶段已经暴露出的失败原因。
+
+ Returns:
+ Optional[str]: 若已检测到失败则返回失败原因,否则返回 ``None``。
+ """
+ if handshake_reason := self._rpc_server.last_handshake_rejection_reason:
+ return f"握手被拒绝: {handshake_reason}"
+
+ process = self._runner_process
+ if process is None:
+ return "Runner 进程不存在"
+
+ if process.returncode is not None:
+ return f"Runner 进程已退出,退出码 {process.returncode}"
+
+ return None
+
+
+PluginSupervisor = PluginRunnerSupervisor
diff --git a/src/plugin_runtime/host/workflow_executor.py b/src/plugin_runtime/host/workflow_executor.py
deleted file mode 100644
index 3037e9dd..00000000
--- a/src/plugin_runtime/host/workflow_executor.py
+++ /dev/null
@@ -1,422 +0,0 @@
-"""Host-side WorkflowExecutor
-
-6 阶段线性流转(INGRESS → PRE_PROCESS → PLAN → TOOL_EXECUTE → POST_PROCESS → EGRESS)
-
-每个阶段执行顺序:
-1. Host-side pre-filter: 根据 hook filter 条件过滤不相关的 hook
-2. 按 priority 降序排列
-3. 串行执行 blocking hook(可修改 message,返回 HookResult)
-4. 并发执行 non-blocking hook(只读)
-5. 检查是否有 SKIP_STAGE 或 ABORT
-6. PLAN 阶段内置 Command 匹配路由
-
-支持:
-- HookResult: CONTINUE / SKIP_STAGE / ABORT
-- ErrorPolicy: ABORT / SKIP / LOG (per-hook)
-- stage_outputs: 阶段间带命名空间的数据传递
-- modification_log: 消息修改审计
-"""
-
-from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple
-
-import asyncio
-import time
-import uuid
-
-from src.common.logger import get_logger
-from src.config.config import global_config
-from src.plugin_runtime.host.component_registry import ComponentRegistry, RegisteredComponent
-
-logger = get_logger("plugin_runtime.host.workflow_executor")
-
-# 阶段顺序
-STAGE_SEQUENCE: List[str] = [
- "ingress",
- "pre_process",
- "plan",
- "tool_execute",
- "post_process",
- "egress",
-]
-
-# HookResult 常量(与 SDK HookResult enum 值对应)
-HOOK_CONTINUE = "continue"
-HOOK_SKIP_STAGE = "skip_stage"
-HOOK_ABORT = "abort"
-
-
-# blocking hook 全局最大超时(秒):即使 hook 声明 timeout_ms=0 也不会无限等待
-# 从配置文件读取,允许用户调整
-def _get_blocking_timeout() -> float:
- return global_config.plugin_runtime.workflow_blocking_timeout_sec
-
-
-class ModificationRecord:
- """消息修改记录"""
-
- __slots__ = ("stage", "hook_name", "timestamp", "fields_changed")
-
- def __init__(self, stage: str, hook_name: str, fields_changed: List[str]) -> None:
- self.stage = stage
- self.hook_name = hook_name
- self.timestamp = time.perf_counter()
- self.fields_changed = fields_changed
-
-
-class WorkflowContext:
- """Workflow 执行上下文"""
-
- def __init__(self, trace_id: Optional[str] = None, stream_id: Optional[str] = None) -> None:
- self.trace_id = trace_id or uuid.uuid4().hex
- self.stream_id = stream_id
- self.timings: Dict[str, float] = {}
- self.errors: List[str] = []
- # 阶段间数据传递(按 stage 命名空间隔离)
- self.stage_outputs: Dict[str, Dict[str, Any]] = {}
- # 消息修改审计日志
- self.modification_log: List[ModificationRecord] = []
- # PLAN 阶段命令匹配结果
- self.matched_command: Optional[str] = None
-
- def set_stage_output(self, stage: str, key: str, value: Any) -> None:
- self.stage_outputs.setdefault(stage, {})[key] = value
-
- def get_stage_output(self, stage: str, key: str, default: Any = None) -> Any:
- return self.stage_outputs.get(stage, {}).get(key, default)
-
-
-class WorkflowResult:
- """Workflow 执行结果"""
-
- def __init__(
- self,
- status: str = "completed", # completed / aborted / failed
- return_message: str = "",
- stopped_at: str = "",
- diagnostics: Optional[Dict[str, Any]] = None,
- ) -> None:
- self.status = status
- self.return_message = return_message
- self.stopped_at = stopped_at
- self.diagnostics = diagnostics or {}
-
-
-# invoke_fn 签名
-InvokeFn = Callable[[str, str, Dict[str, Any]], Awaitable[Dict[str, Any]]]
-
-
-class WorkflowExecutor:
- """Host-side Workflow 执行器
-
- 实现 stage-based pipeline + per-stage hook chain with priority + early return。
- """
-
- def __init__(self, registry: ComponentRegistry) -> None:
- self._registry = registry
- self._background_tasks: Set[asyncio.Task] = set()
-
- async def execute(
- self,
- invoke_fn: InvokeFn,
- message: Optional[Dict[str, Any]] = None,
- stream_id: Optional[str] = None,
- context: Optional[WorkflowContext] = None,
- command_invoke_fn: Optional[InvokeFn] = None,
- ) -> Tuple[WorkflowResult, Optional[Dict[str, Any]], WorkflowContext]:
- """执行 workflow pipeline。
-
- Args:
- invoke_fn: 用于 workflow_step 的回调
- command_invoke_fn: 用于 command 的回调(走 plugin.invoke_command),
- 未传则复用 invoke_fn
-
- Returns:
- (result, final_message, context)
- """
- ctx = context or WorkflowContext(stream_id=stream_id)
- current_message = dict(message) if message else None
-
- for stage in STAGE_SEQUENCE:
- stage_start = time.perf_counter()
-
- try:
- # PLAN 阶段: 先做 Command 路由
- if stage == "plan" and current_message:
- cmd_result = await self._route_command(command_invoke_fn or invoke_fn, current_message, ctx)
- if cmd_result is not None:
- # 命令匹配成功,跳过 PLAN 阶段的 hook,直接存结果进 stage_outputs
- ctx.set_stage_output("plan", "command_result", cmd_result)
- ctx.timings[stage] = time.perf_counter() - stage_start
- continue
-
- # 获取该阶段所有 hook(已按 priority 降序排列)
- all_steps = self._registry.get_workflow_steps(stage)
- if not all_steps:
- ctx.timings[stage] = time.perf_counter() - stage_start
- continue
-
- # 1. Pre-filter
- filtered_steps = self._pre_filter(all_steps, current_message)
-
- # 2. 分离 blocking 和 non-blocking
- blocking_steps = [s for s in filtered_steps if s.metadata.get("blocking", True)]
- nonblocking_steps = [s for s in filtered_steps if not s.metadata.get("blocking", True)]
-
- # 3. 串行执行 blocking hook
- skip_stage = False
- for step in blocking_steps:
- hook_result, modified, step_error = await self._invoke_step(
- invoke_fn, step, stage, ctx, current_message
- )
-
- if step_error:
- error_policy = step.metadata.get("error_policy", "abort")
- ctx.errors.append(f"{step.full_name}: {step_error}")
-
- if error_policy == "abort":
- ctx.timings[stage] = time.perf_counter() - stage_start
- return (
- WorkflowResult(
- status="failed",
- return_message=step_error,
- stopped_at=stage,
- diagnostics={"step": step.full_name, "trace_id": ctx.trace_id},
- ),
- current_message,
- ctx,
- )
- elif error_policy == "skip":
- logger.warning(f"[{ctx.trace_id}] hook {step.full_name} 异常(skip): {step_error}")
- continue
- else: # log
- logger.warning(f"[{ctx.trace_id}] hook {step.full_name} 异常(log): {step_error}")
- continue
-
- # 更新消息(仅 blocking hook 有权修改)
- if modified:
- changed_fields = (
- _diff_keys(current_message, modified) if current_message else list(modified.keys())
- )
- ctx.modification_log.append(ModificationRecord(stage, step.full_name, changed_fields))
- current_message = modified
-
- if hook_result == HOOK_ABORT:
- ctx.timings[stage] = time.perf_counter() - stage_start
- return (
- WorkflowResult(
- status="aborted",
- return_message=f"aborted by {step.full_name}",
- stopped_at=stage,
- diagnostics={"step": step.full_name, "trace_id": ctx.trace_id},
- ),
- current_message,
- ctx,
- )
-
- if hook_result == HOOK_SKIP_STAGE:
- skip_stage = True
- break
-
- # 4. 并发执行 non-blocking hook(只读,忽略返回值中的 modified_message)
- if nonblocking_steps and not skip_stage:
- for step in nonblocking_steps:
- self._track_background_task(
- asyncio.create_task(
- self._invoke_step_fire_and_forget(invoke_fn, step, stage, ctx, current_message)
- )
- )
-
- ctx.timings[stage] = time.perf_counter() - stage_start
-
- except Exception as e:
- ctx.timings[stage] = time.perf_counter() - stage_start
- ctx.errors.append(f"{stage}: {e}")
- logger.error(f"[{ctx.trace_id}] 阶段 {stage} 未捕获异常: {e}", exc_info=True)
- return (
- WorkflowResult(
- status="failed",
- return_message=str(e),
- stopped_at=stage,
- diagnostics={"trace_id": ctx.trace_id},
- ),
- current_message,
- ctx,
- )
-
- return (
- WorkflowResult(
- status="completed",
- return_message="workflow completed",
- diagnostics={"trace_id": ctx.trace_id},
- ),
- current_message,
- ctx,
- )
-
- def _track_background_task(self, task: asyncio.Task) -> None:
- """保持 non-blocking workflow task 的强引用,直到任务结束。"""
- self._background_tasks.add(task)
- task.add_done_callback(self._background_tasks.discard)
-
- # ─── 内部方法 ──────────────────────────────────────────────
-
- def _pre_filter(
- self,
- steps: List[RegisteredComponent],
- message: Optional[Dict[str, Any]],
- ) -> List[RegisteredComponent]:
- """根据 hook 声明的 filter 条件预过滤,避免无意义的 IPC 调用。"""
- if not message:
- return steps
-
- result = []
- for step in steps:
- filter_cond = step.metadata.get("filter", {})
- if not filter_cond:
- result.append(step)
- continue
- if self._match_filter(filter_cond, message):
- result.append(step)
- return result
-
- @staticmethod
- def _match_filter(filter_cond: Dict[str, Any], message: Dict[str, Any]) -> bool:
- """简单 key-value 匹配过滤。
-
- filter 中的每个 key 必须在 message 中存在且值相等,
- 全部匹配才通过。
- """
- for key, expected in filter_cond.items():
- actual = message.get(key)
- if (isinstance(expected, list) and actual not in expected) or (
- not isinstance(expected, list) and actual != expected
- ):
- return False
- return True
-
- async def _invoke_step(
- self,
- invoke_fn: InvokeFn,
- step: RegisteredComponent,
- stage: str,
- ctx: WorkflowContext,
- message: Optional[Dict[str, Any]],
- ) -> Tuple[str, Optional[Dict[str, Any]], Optional[str]]:
- """调用单个 blocking hook。
-
- Returns:
- (hook_result, modified_message, error_string_or_None)
- """
- timeout_ms = step.metadata.get("timeout_ms", 0)
- # 使用 hook 声明的超时,但不超过全局安全阀
- timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else _get_blocking_timeout()
- step_key = f"{stage}:{step.full_name}"
- step_start = time.perf_counter()
-
- try:
- coro = invoke_fn(
- step.plugin_id,
- step.name,
- {
- "stage": stage,
- "trace_id": ctx.trace_id,
- "message": message,
- "stage_outputs": ctx.stage_outputs,
- },
- )
- resp = await asyncio.wait_for(coro, timeout=timeout_sec)
- ctx.timings[step_key] = time.perf_counter() - step_start
-
- hook_result = resp.get("hook_result", HOOK_CONTINUE)
- modified_message = resp.get("modified_message")
- # 存 stage output(如果 hook 提供了)
- stage_out = resp.get("stage_output")
- if isinstance(stage_out, dict):
- for k, v in stage_out.items():
- ctx.set_stage_output(stage, k, v)
-
- return hook_result, modified_message, None
-
- except asyncio.TimeoutError:
- ctx.timings[step_key] = time.perf_counter() - step_start
- return HOOK_CONTINUE, None, f"timeout after {timeout_ms}ms"
-
- except Exception as e:
- ctx.timings[step_key] = time.perf_counter() - step_start
- return HOOK_CONTINUE, None, str(e)
-
- async def _invoke_step_fire_and_forget(
- self,
- invoke_fn: InvokeFn,
- step: RegisteredComponent,
- stage: str,
- ctx: WorkflowContext,
- message: Optional[Dict[str, Any]],
- ) -> None:
- """Non-blocking hook 调用,只读,忽略结果。"""
- timeout_ms = step.metadata.get("timeout_ms", 0)
- # 使用 hook 声明的超时,但无声明时回退到全局安全阀,防止 task 泄漏
- timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else _get_blocking_timeout()
-
- try:
- coro = invoke_fn(
- step.plugin_id,
- step.name,
- {
- "stage": stage,
- "trace_id": ctx.trace_id,
- "message": message,
- "stage_outputs": ctx.stage_outputs,
- },
- )
- await asyncio.wait_for(coro, timeout=timeout_sec)
- except asyncio.TimeoutError:
- logger.warning(f"[{ctx.trace_id}] non-blocking hook {step.full_name} 超时 ({timeout_sec}s)")
- except Exception as e:
- logger.debug(f"[{ctx.trace_id}] non-blocking hook {step.full_name}: {e}")
-
- async def _route_command(
- self,
- invoke_fn: InvokeFn,
- message: Dict[str, Any],
- ctx: WorkflowContext,
- ) -> Optional[Dict[str, Any]]:
- """PLAN 阶段内置 Command 路由。
-
- 在 registry 中查找匹配的 command 组件,
- 匹配到则直接路由到对应 command handler,返回执行结果。
- 不匹配则返回 None,让 PLAN 阶段的 hook 继续执行。
- """
- plain_text = message.get("plain_text", "")
- if not plain_text:
- return None
-
- match_result = self._registry.find_command_by_text(plain_text)
- if match_result is None:
- return None
-
- matched, matched_groups = match_result
-
- ctx.matched_command = matched.full_name
- logger.info(f"[{ctx.trace_id}] 命令匹配: {matched.full_name}")
-
- try:
- return await invoke_fn(
- matched.plugin_id,
- matched.name,
- {
- "text": plain_text,
- "message": message,
- "trace_id": ctx.trace_id,
- "matched_groups": matched_groups,
- },
- )
- except Exception as e:
- logger.error(f"[{ctx.trace_id}] 命令 {matched.full_name} 执行失败: {e}", exc_info=True)
- ctx.errors.append(f"command:{matched.full_name}: {e}")
- return None
-
-
-def _diff_keys(old: Dict[str, Any], new: Dict[str, Any]) -> List[str]:
- """返回 new 中与 old 不同的 key 列表。"""
- return [k for k, v in new.items() if k not in old or old[k] != v]
diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py
index 04c8e324..c34f5ef5 100644
--- a/src/plugin_runtime/integration.py
+++ b/src/plugin_runtime/integration.py
@@ -8,23 +8,27 @@
"""
from pathlib import Path
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Dict, Iterable, List, Optional, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Dict, Iterable, List, Optional, Sequence, Set, Tuple
import asyncio
-import json
+
import tomlkit
from src.common.logger import get_logger
-from src.config.config import global_config
+from src.config.config import config_manager
from src.config.file_watcher import FileChange, FileWatcher
+from src.platform_io import DeliveryBatch, InboundMessageEnvelope, get_platform_io_manager
from src.plugin_runtime.capabilities import (
RuntimeComponentCapabilityMixin,
RuntimeCoreCapabilityMixin,
RuntimeDataCapabilityMixin,
)
from src.plugin_runtime.capabilities.registry import register_capability_impls
+from src.plugin_runtime.host.message_utils import MessageDict, PluginMessageUtils
+from src.plugin_runtime.runner.manifest_validator import ManifestValidator
if TYPE_CHECKING:
+ from src.chat.message_receive.message import SessionMessage
from src.plugin_runtime.host.supervisor import PluginSupervisor
logger = get_logger("plugin_runtime.integration")
@@ -55,6 +59,7 @@ class PluginRuntimeManager(
"""
def __init__(self) -> None:
+ """初始化插件运行时管理器。"""
from src.plugin_runtime.host.supervisor import PluginSupervisor
self._builtin_supervisor: Optional[PluginSupervisor] = None
@@ -63,6 +68,26 @@ class PluginRuntimeManager(
self._plugin_file_watcher: Optional[FileWatcher] = None
self._plugin_source_watcher_subscription_id: Optional[str] = None
self._plugin_config_watcher_subscriptions: Dict[str, Tuple[Path, str]] = {}
+ self._plugin_path_cache: Dict[str, Path] = {}
+ self._manifest_validator: ManifestValidator = ManifestValidator()
+ self._config_reload_callback: Callable[[Sequence[str]], Awaitable[None]] = self._handle_main_config_reload
+ self._config_reload_callback_registered: bool = False
+
+ async def _dispatch_platform_inbound(self, envelope: InboundMessageEnvelope) -> None:
+ """接收 Platform IO 审核后的入站消息并送入主消息链。
+
+ Args:
+ envelope: Platform IO 产出的入站封装。
+ """
+ session_message = envelope.session_message
+ if session_message is None and envelope.payload is not None:
+ session_message = PluginMessageUtils._build_session_message_from_dict(dict(envelope.payload))
+ if session_message is None:
+ raise ValueError("Platform IO 入站封装缺少可用的 SessionMessage 或 payload")
+
+ from src.chat.message_receive.bot import chat_bot
+
+ await chat_bot.receive_message(session_message)
# ─── 插件目录 ─────────────────────────────────────────────
@@ -78,6 +103,42 @@ class PluginRuntimeManager(
candidate = Path("plugins").resolve()
return [candidate] if candidate.is_dir() else []
+ @classmethod
+ def _discover_plugin_dependency_map(cls, plugin_dirs: Iterable[Path]) -> Dict[str, List[str]]:
+ """扫描指定插件目录集合,返回 ``plugin_id -> dependencies`` 映射。"""
+ validator = ManifestValidator()
+ return validator.build_plugin_dependency_map(plugin_dirs)
+
+ @classmethod
+ def _build_group_start_order(
+ cls,
+ builtin_dirs: Sequence[Path],
+ third_party_dirs: Sequence[Path],
+ ) -> List[str]:
+ """根据跨 Supervisor 依赖关系决定 Runner 启动顺序。"""
+
+ builtin_dependencies = cls._discover_plugin_dependency_map(builtin_dirs)
+ third_party_dependencies = cls._discover_plugin_dependency_map(third_party_dirs)
+ builtin_plugin_ids = set(builtin_dependencies)
+ third_party_plugin_ids = set(third_party_dependencies)
+
+ builtin_needs_third_party = any(
+ dependency in third_party_plugin_ids
+ for dependencies in builtin_dependencies.values()
+ for dependency in dependencies
+ )
+ third_party_needs_builtin = any(
+ dependency in builtin_plugin_ids
+ for dependencies in third_party_dependencies.values()
+ for dependency in dependencies
+ )
+
+ if builtin_needs_third_party and third_party_needs_builtin:
+ raise RuntimeError("检测到跨 Supervisor 循环依赖,当前无法安全启动独立 Runner")
+ if builtin_needs_third_party:
+ return ["third_party", "builtin"]
+ return ["builtin", "third_party"]
+
# ─── 生命周期 ─────────────────────────────────────────────
async def start(self) -> None:
@@ -86,7 +147,7 @@ class PluginRuntimeManager(
logger.warning("PluginRuntimeManager 已在运行中,跳过重复启动")
return
- _cfg = global_config.plugin_runtime
+ _cfg = config_manager.get_global_config().plugin_runtime
if not _cfg.enabled:
logger.info("插件运行时已在配置中禁用,跳过启动")
return
@@ -108,6 +169,8 @@ class PluginRuntimeManager(
logger.info("未找到任何插件目录,跳过插件运行时启动")
return
+ platform_io_manager = get_platform_io_manager()
+
# 从配置读取自定义 IPC socket 路径(留空则自动生成)
socket_path_base = _cfg.ipc_socket_path or None
@@ -132,19 +195,46 @@ class PluginRuntimeManager(
started_supervisors: List[PluginSupervisor] = []
try:
- if self._builtin_supervisor:
- await self._builtin_supervisor.start()
- started_supervisors.append(self._builtin_supervisor)
- if self._third_party_supervisor:
- await self._third_party_supervisor.start()
- started_supervisors.append(self._third_party_supervisor)
+ platform_io_manager.set_inbound_dispatcher(self._dispatch_platform_inbound)
+ await platform_io_manager.ensure_send_pipeline_ready()
+
+ supervisor_groups: Dict[str, Optional[PluginSupervisor]] = {
+ "builtin": self._builtin_supervisor,
+ "third_party": self._third_party_supervisor,
+ }
+ start_order = self._build_group_start_order(builtin_dirs, third_party_dirs)
+
+ for group_name in start_order:
+ supervisor = supervisor_groups.get(group_name)
+ if supervisor is None:
+ continue
+
+ external_plugin_versions = {
+ plugin_id: plugin_version
+ for started_supervisor in started_supervisors
+ for plugin_id, plugin_version in started_supervisor.get_loaded_plugin_versions().items()
+ }
+ supervisor.set_external_available_plugins(external_plugin_versions)
+ await supervisor.start()
+ started_supervisors.append(supervisor)
+
await self._start_plugin_file_watcher()
+ config_manager.register_reload_callback(self._config_reload_callback)
+ self._config_reload_callback_registered = True
self._started = True
logger.info(f"插件运行时已启动 — 内置: {builtin_dirs or '无'}, 第三方: {third_party_dirs or '无'}")
except Exception as e:
logger.error(f"插件运行时启动失败: {e}", exc_info=True)
await self._stop_plugin_file_watcher()
+ if self._config_reload_callback_registered:
+ config_manager.unregister_reload_callback(self._config_reload_callback)
+ self._config_reload_callback_registered = False
await asyncio.gather(*(sv.stop() for sv in started_supervisors), return_exceptions=True)
+ platform_io_manager.clear_inbound_dispatcher()
+ try:
+ await platform_io_manager.stop()
+ except Exception as platform_io_exc:
+ logger.warning(f"Platform IO 停止失败: {platform_io_exc}")
self._started = False
self._builtin_supervisor = None
self._third_party_supervisor = None
@@ -154,7 +244,11 @@ class PluginRuntimeManager(
if not self._started:
return
+ platform_io_manager = get_platform_io_manager()
await self._stop_plugin_file_watcher()
+ if self._config_reload_callback_registered:
+ config_manager.unregister_reload_callback(self._config_reload_callback)
+ self._config_reload_callback_registered = False
coroutines: List[Coroutine[Any, Any, None]] = []
if self._builtin_supervisor:
@@ -162,18 +256,32 @@ class PluginRuntimeManager(
if self._third_party_supervisor:
coroutines.append(self._third_party_supervisor.stop())
+ stop_errors: List[str] = []
try:
- await asyncio.gather(*coroutines, return_exceptions=True)
- logger.info("插件运行时已停止")
- except Exception as e:
- logger.error(f"插件运行时停止失败: {e}", exc_info=True)
+ results = await asyncio.gather(*coroutines, return_exceptions=True)
+ for result in results:
+ if isinstance(result, Exception):
+ stop_errors.append(str(result))
+
+ platform_io_manager.clear_inbound_dispatcher()
+ try:
+ await platform_io_manager.stop()
+ except Exception as exc:
+ stop_errors.append(f"Platform IO: {exc}")
+
+ if stop_errors:
+ logger.error(f"插件运行时停止过程中存在错误: {'; '.join(stop_errors)}")
+ else:
+ logger.info("插件运行时已停止")
finally:
self._started = False
self._builtin_supervisor = None
self._third_party_supervisor = None
+ self._plugin_path_cache.clear()
@property
def is_running(self) -> bool:
+ """返回插件运行时是否处于启动状态。"""
return self._started
@property
@@ -181,11 +289,176 @@ class PluginRuntimeManager(
"""获取所有活跃的 Supervisor"""
return [s for s in (self._builtin_supervisor, self._third_party_supervisor) if s is not None]
+ def _build_registered_dependency_map(self) -> Dict[str, Set[str]]:
+ """根据当前已注册插件构建全局依赖图。"""
+
+ dependency_map: Dict[str, Set[str]] = {}
+ for supervisor in self.supervisors:
+ for plugin_id, registration in getattr(supervisor, "_registered_plugins", {}).items():
+ dependency_map[plugin_id] = {
+ str(dependency or "").strip()
+ for dependency in getattr(registration, "dependencies", [])
+ if str(dependency or "").strip()
+ }
+ return dependency_map
+
+ @staticmethod
+ def _collect_reverse_dependents(
+ plugin_ids: Set[str],
+ dependency_map: Dict[str, Set[str]],
+ ) -> Set[str]:
+ """根据依赖图收集反向依赖闭包。"""
+
+ impacted_plugins: Set[str] = set(plugin_ids)
+ changed = True
+
+ while changed:
+ changed = False
+ for registered_plugin_id, dependencies in dependency_map.items():
+ if registered_plugin_id in impacted_plugins:
+ continue
+ if dependencies & impacted_plugins:
+ impacted_plugins.add(registered_plugin_id)
+ changed = True
+
+ return impacted_plugins
+
+ def _build_registered_supervisor_map(self) -> Dict[str, "PluginSupervisor"]:
+ """构建当前已注册插件到所属 Supervisor 的映射。"""
+
+ return {
+ plugin_id: supervisor
+ for supervisor in self.supervisors
+ for plugin_id in supervisor.get_loaded_plugin_ids()
+ }
+
+ def _build_external_available_plugins_for_supervisor(self, target_supervisor: "PluginSupervisor") -> Dict[str, str]:
+ """收集某个 Supervisor 可用的外部插件版本映射。"""
+
+ external_plugin_versions: Dict[str, str] = {}
+ for supervisor in self.supervisors:
+ if supervisor is target_supervisor:
+ continue
+ external_plugin_versions.update(supervisor.get_loaded_plugin_versions())
+ return external_plugin_versions
+
+ def _find_supervisor_by_plugin_directory(self, plugin_id: str) -> Optional["PluginSupervisor"]:
+ """根据插件目录推断应负责该插件重载的 Supervisor。"""
+
+ for supervisor in self.supervisors:
+ if self._get_plugin_path_for_supervisor(supervisor, plugin_id) is not None:
+ return supervisor
+ return None
+
+ def _warn_skipped_cross_supervisor_reload(
+ self,
+ requested_loaded_plugin_ids: Set[str],
+ dependency_map: Dict[str, Set[str]],
+ supervisor_by_plugin: Dict[str, "PluginSupervisor"],
+ ) -> None:
+ """记录因跨 Supervisor 边界而未参与联动重载的插件。"""
+
+ if not requested_loaded_plugin_ids:
+ return
+
+ handled_plugin_ids: Set[str] = set()
+ for supervisor in self.supervisors:
+ local_requested_plugin_ids = {
+ plugin_id
+ for plugin_id in requested_loaded_plugin_ids
+ if supervisor_by_plugin.get(plugin_id) is supervisor
+ }
+ if not local_requested_plugin_ids:
+ continue
+
+ local_plugin_ids = set(supervisor.get_loaded_plugin_ids())
+ local_dependency_map = {
+ plugin_id: {
+ dependency
+ for dependency in dependency_map.get(plugin_id, set())
+ if dependency in local_plugin_ids
+ }
+ for plugin_id in local_plugin_ids
+ }
+ handled_plugin_ids.update(
+ self._collect_reverse_dependents(local_requested_plugin_ids, local_dependency_map)
+ )
+
+ impacted_plugin_ids = self._collect_reverse_dependents(requested_loaded_plugin_ids, dependency_map)
+ skipped_plugin_ids = sorted(impacted_plugin_ids - handled_plugin_ids)
+ if not skipped_plugin_ids:
+ return
+
+ logger.warning(
+ f"插件 {', '.join(sorted(requested_loaded_plugin_ids))} 存在跨 Supervisor 依赖方未联动重载: "
+ f"{', '.join(skipped_plugin_ids)}。当前仅在单个 Supervisor 内执行联动重载;"
+ "跨 Supervisor API 调用仍然可用。如需联动重载,请将相关插件放在同一个 Supervisor 内。"
+ )
+
+ async def reload_plugins_globally(self, plugin_ids: Sequence[str], reason: str = "manual") -> bool:
+ """按 Supervisor 分组执行精确重载。
+
+ 仅在单个 Supervisor 内执行依赖联动;跨 Supervisor 依赖方仅记录告警,
+ 不再自动参与本次热重载。
+ """
+
+ normalized_plugin_ids = [
+ normalized_plugin_id
+ for plugin_id in plugin_ids
+ if (normalized_plugin_id := str(plugin_id or "").strip())
+ ]
+ if not normalized_plugin_ids:
+ return True
+
+ dependency_map = self._build_registered_dependency_map()
+ supervisor_by_plugin = self._build_registered_supervisor_map()
+ supervisor_roots: Dict["PluginSupervisor", List[str]] = {}
+ requested_loaded_plugin_ids: Set[str] = set()
+ missing_plugin_ids: List[str] = []
+
+ for plugin_id in normalized_plugin_ids:
+ supervisor = supervisor_by_plugin.get(plugin_id)
+ if supervisor is not None:
+ requested_loaded_plugin_ids.add(plugin_id)
+ else:
+ supervisor = self._find_supervisor_by_plugin_directory(plugin_id)
+
+ if supervisor is None:
+ missing_plugin_ids.append(plugin_id)
+ continue
+
+ if plugin_id not in supervisor_roots.setdefault(supervisor, []):
+ supervisor_roots[supervisor].append(plugin_id)
+
+ if missing_plugin_ids:
+ logger.warning(f"以下插件未找到可重载的 Supervisor,已跳过: {', '.join(sorted(missing_plugin_ids))}")
+
+ self._warn_skipped_cross_supervisor_reload(
+ requested_loaded_plugin_ids=requested_loaded_plugin_ids,
+ dependency_map=dependency_map,
+ supervisor_by_plugin=supervisor_by_plugin,
+ )
+
+ success = True
+ for supervisor, root_plugin_ids in supervisor_roots.items():
+ if not root_plugin_ids:
+ continue
+
+ reloaded = await supervisor.reload_plugins(
+ plugin_ids=root_plugin_ids,
+ reason=reason,
+ external_available_plugins=self._build_external_available_plugins_for_supervisor(supervisor),
+ )
+ success = success and reloaded
+
+ return success and not missing_plugin_ids
+
async def notify_plugin_config_updated(
self,
plugin_id: str,
config_data: Optional[Dict[str, Any]] = None,
config_version: str = "",
+ config_scope: str = "self",
) -> bool:
"""向拥有该插件的 Supervisor 推送配置更新事件。
@@ -193,6 +466,7 @@ class PluginRuntimeManager(
plugin_id: 插件 ID
config_data: 可选的配置数据(如果为 None 则由 Supervisor 从磁盘加载)
config_version: 可选的配置版本字符串,供 Supervisor 进行版本控制
+ config_scope: 配置变更范围。
"""
if not self._started:
return False
@@ -209,23 +483,78 @@ class PluginRuntimeManager(
config_payload = (
config_data
if config_data is not None
- else self._load_plugin_config_for_supervisor(plugin_id, plugin_dirs=sv._plugin_dirs)
+ else self._load_plugin_config_for_supervisor(sv, plugin_id)
)
- await sv.notify_plugin_config_updated(
+ return await sv.notify_plugin_config_updated(
plugin_id=plugin_id,
config_data=config_payload,
config_version=config_version,
+ config_scope=config_scope,
)
- return True
+
+ @staticmethod
+ def _normalize_config_reload_scopes(changed_scopes: Sequence[str]) -> tuple[str, ...]:
+ """规范化配置热重载范围列表。
+
+ Args:
+ changed_scopes: 原始配置热重载范围列表。
+
+ Returns:
+ tuple[str, ...]: 去重后的有效配置范围元组。
+ """
+
+ normalized_scopes: list[str] = []
+ for scope in changed_scopes:
+ normalized_scope = str(scope or "").strip().lower()
+ if normalized_scope not in {"bot", "model"}:
+ continue
+ if normalized_scope not in normalized_scopes:
+ normalized_scopes.append(normalized_scope)
+ return tuple(normalized_scopes)
+
+ async def _broadcast_config_reload(self, scope: str, config_data: Dict[str, Any]) -> None:
+ """向订阅指定范围的插件广播配置热重载。
+
+ Args:
+ scope: 配置变更范围,仅支持 ``bot`` 或 ``model``。
+ config_data: 最新配置数据。
+ """
+
+ for supervisor in self.supervisors:
+ for plugin_id in supervisor.get_config_reload_subscribers(scope):
+ delivered = await supervisor.notify_plugin_config_updated(
+ plugin_id=plugin_id,
+ config_data=config_data,
+ config_version="",
+ config_scope=scope,
+ )
+ if not delivered:
+ logger.warning(f"向插件 {plugin_id} 广播 {scope} 配置热重载失败")
+
+ async def _handle_main_config_reload(self, changed_scopes: Sequence[str]) -> None:
+ """处理 bot/model 主配置热重载广播。
+
+ Args:
+ changed_scopes: 本次热重载命中的配置范围列表。
+ """
+
+ if not self._started:
+ return
+
+ normalized_scopes = self._normalize_config_reload_scopes(changed_scopes)
+ if "bot" in normalized_scopes:
+ await self._broadcast_config_reload("bot", config_manager.get_global_config().model_dump(mode="json"))
+ if "model" in normalized_scopes:
+ await self._broadcast_config_reload("model", config_manager.get_model_config().model_dump(mode="json"))
# ─── 事件桥接 ──────────────────────────────────────────────
async def bridge_event(
self,
event_type_value: str,
- message_dict: Optional[Dict[str, Any]] = None,
+ message_dict: Optional[MessageDict] = None,
extra_args: Optional[Dict[str, Any]] = None,
- ) -> Tuple[bool, Optional[Dict[str, Any]]]:
+ ) -> Tuple[bool, Optional[MessageDict]]:
"""将事件分发到所有 Supervisor
Returns:
@@ -235,17 +564,23 @@ class PluginRuntimeManager(
return True, None
new_event_type: str = _EVENT_TYPE_MAP.get(event_type_value, event_type_value)
- modified: Optional[Dict[str, Any]] = None
+ modified: Optional[MessageDict] = None
+ current_message: Optional["SessionMessage"] = (
+ PluginMessageUtils._build_session_message_from_dict(dict(message_dict))
+ if message_dict is not None
+ else None
+ )
for sv in self.supervisors:
try:
cont, mod = await sv.dispatch_event(
event_type=new_event_type,
- message=modified or message_dict,
+ message=current_message,
extra_args=extra_args,
)
if mod is not None:
- modified = mod
+ current_message = mod
+ modified = PluginMessageUtils._session_message_to_dict(mod)
if not cont:
return False, modified
except Exception as e:
@@ -295,6 +630,37 @@ class PluginRuntimeManager(
timeout_ms=timeout_ms,
)
+ async def try_send_message_via_platform_io(
+ self,
+ message: "SessionMessage",
+ ) -> Optional[DeliveryBatch]:
+ """尝试通过 Platform IO 中间层发送消息。
+
+ Args:
+ message: 待发送的内部会话消息。
+
+ Returns:
+ Optional[DeliveryBatch]: 若当前消息命中了至少一条发送路由,则返回
+ 实际发送结果;若没有可用路由或 Platform IO 尚未启动,则返回 ``None``。
+ """
+ if not self._started:
+ return None
+
+ platform_io_manager = get_platform_io_manager()
+ if not platform_io_manager.is_started:
+ return None
+
+ try:
+ route_key = platform_io_manager.build_route_key_from_message(message)
+ except Exception as exc:
+ logger.warning(f"根据消息构造 Platform IO 路由键失败: {exc}")
+ return None
+
+ if not platform_io_manager.resolve_drivers(route_key):
+ return None
+
+ return await platform_io_manager.send_message(message, route_key)
+
def _get_supervisors_for_plugin(self, plugin_id: str) -> List["PluginSupervisor"]:
"""返回当前持有指定插件的所有 Supervisor。
@@ -314,30 +680,38 @@ class PluginRuntimeManager(
raise RuntimeError(f"插件 {plugin_id} 同时存在于多个 Supervisor 中,无法安全路由")
return matches[0] if matches else None
- @staticmethod
- def _find_duplicate_plugin_ids(plugin_dirs: List[Path]) -> Dict[str, List[Path]]:
+ async def load_plugin_globally(self, plugin_id: str, reason: str = "manual") -> bool:
+ """加载或重载单个插件,并为其补齐跨 Supervisor 外部依赖。"""
+
+ normalized_plugin_id = str(plugin_id or "").strip()
+ if not normalized_plugin_id:
+ return False
+
+ try:
+ registered_supervisor = self._get_supervisor_for_plugin(normalized_plugin_id)
+ except RuntimeError:
+ return False
+
+ if registered_supervisor is not None:
+ return await self.reload_plugins_globally([normalized_plugin_id], reason=reason)
+
+ supervisor = self._find_supervisor_by_plugin_directory(normalized_plugin_id)
+ if supervisor is None:
+ return False
+
+ return await supervisor.reload_plugins(
+ plugin_ids=[normalized_plugin_id],
+ reason=reason,
+ external_available_plugins=self._build_external_available_plugins_for_supervisor(supervisor),
+ )
+
+ @classmethod
+ def _find_duplicate_plugin_ids(cls, plugin_dirs: List[Path]) -> Dict[str, List[Path]]:
"""扫描插件目录,找出被多个目录重复声明的插件 ID。"""
plugin_locations: Dict[str, List[Path]] = {}
- for base_dir in plugin_dirs:
- if not base_dir.is_dir():
- continue
- for entry in base_dir.iterdir():
- if not entry.is_dir():
- continue
- manifest_path = entry / "_manifest.json"
- plugin_path = entry / "plugin.py"
- if not manifest_path.exists() or not plugin_path.exists():
- continue
-
- plugin_id = entry.name
- try:
- with open(manifest_path, "r", encoding="utf-8") as manifest_file:
- manifest = json.load(manifest_file)
- plugin_id = str(manifest.get("name", entry.name)).strip() or entry.name
- except Exception:
- continue
-
- plugin_locations.setdefault(plugin_id, []).append(entry)
+ validator = ManifestValidator()
+ for plugin_path, manifest in validator.iter_plugin_manifests(plugin_dirs):
+ plugin_locations.setdefault(manifest.id, []).append(plugin_path)
return {
plugin_id: sorted(dict.fromkeys(paths), key=lambda p: str(p))
@@ -370,6 +744,7 @@ class PluginRuntimeManager(
async def _stop_plugin_file_watcher(self) -> None:
"""停止插件文件监视器,并清理所有已注册订阅。"""
if self._plugin_file_watcher is None:
+ self._plugin_path_cache.clear()
return
for _plugin_id, (_config_path, subscription_id) in list(self._plugin_config_watcher_subscriptions.items()):
self._plugin_file_watcher.unsubscribe(subscription_id)
@@ -379,12 +754,79 @@ class PluginRuntimeManager(
self._plugin_source_watcher_subscription_id = None
await self._plugin_file_watcher.stop()
self._plugin_file_watcher = None
+ self._plugin_path_cache.clear()
def _iter_plugin_dirs(self) -> Iterable[Path]:
"""迭代所有 Supervisor 当前管理的插件根目录。"""
for supervisor in self.supervisors:
yield from getattr(supervisor, "_plugin_dirs", [])
+ @staticmethod
+ def _iter_candidate_plugin_paths(plugin_dirs: Iterable[Path]) -> Iterable[Path]:
+ """迭代所有可能的插件目录路径。
+
+ Args:
+ plugin_dirs: 一个或多个插件根目录。
+
+ Yields:
+ Path: 单个插件目录路径。
+ """
+ for plugin_dir in plugin_dirs:
+ plugin_root = Path(plugin_dir).resolve()
+ if not plugin_root.is_dir():
+ continue
+ for entry in plugin_root.iterdir():
+ if entry.is_dir():
+ yield entry.resolve()
+
+ def _read_plugin_id_from_plugin_path(self, plugin_path: Path) -> Optional[str]:
+ """从单个插件目录中读取 manifest 声明的插件 ID。
+
+ Args:
+ plugin_path: 单个插件目录路径。
+
+ Returns:
+ Optional[str]: 解析成功时返回插件 ID,否则返回 ``None``。
+ """
+ return self._manifest_validator.read_plugin_id_from_plugin_path(plugin_path)
+
+ def _iter_discovered_plugin_paths(self, plugin_dirs: Iterable[Path]) -> Iterable[Tuple[str, Path]]:
+ """迭代目录中可解析到的插件 ID 与实际目录路径。
+
+ Args:
+ plugin_dirs: 一个或多个插件根目录。
+
+ Yields:
+ Tuple[str, Path]: ``(plugin_id, plugin_path)`` 二元组。
+ """
+ for plugin_path in self._iter_candidate_plugin_paths(plugin_dirs):
+ if plugin_id := self._read_plugin_id_from_plugin_path(plugin_path):
+ yield plugin_id, plugin_path
+
+ def _get_plugin_path_for_supervisor(self, supervisor: Any, plugin_id: str) -> Optional[Path]:
+ """为指定 Supervisor 定位某个插件的实际目录。
+
+ Args:
+ supervisor: 目标 Supervisor。
+ plugin_id: 插件 ID。
+
+ Returns:
+ Optional[Path]: 插件目录路径;未找到时返回 ``None``。
+ """
+ cached_path = self._plugin_path_cache.get(plugin_id)
+ if cached_path is not None:
+ for plugin_dir in getattr(supervisor, "_plugin_dirs", []):
+ if self._plugin_dir_matches(cached_path, Path(plugin_dir)):
+ return cached_path
+
+ for candidate_plugin_id, plugin_path in self._iter_discovered_plugin_paths(getattr(supervisor, "_plugin_dirs", [])):
+ if candidate_plugin_id != plugin_id:
+ continue
+ self._plugin_path_cache[plugin_id] = plugin_path
+ return plugin_path
+
+ return None
+
def _refresh_plugin_config_watch_subscriptions(self) -> None:
"""按当前已注册插件集合刷新 config.toml 的单插件订阅。
@@ -394,7 +836,11 @@ class PluginRuntimeManager(
if self._plugin_file_watcher is None:
return
- desired_config_paths = dict(self._iter_registered_plugin_config_paths())
+ desired_plugin_paths = dict(self._iter_registered_plugin_paths())
+ self._plugin_path_cache = desired_plugin_paths.copy()
+ desired_config_paths = {
+ plugin_id: plugin_path / "config.toml" for plugin_id, plugin_path in desired_plugin_paths.items()
+ }
for plugin_id, (_old_path, subscription_id) in list(self._plugin_config_watcher_subscriptions.items()):
if desired_config_paths.get(plugin_id) == self._plugin_config_watcher_subscriptions[plugin_id][0]:
@@ -418,28 +864,35 @@ class PluginRuntimeManager(
"""为指定插件生成配置文件变更回调。"""
async def _callback(changes: Sequence[FileChange]) -> None:
+ """将 watcher 事件转发到指定插件的配置处理逻辑。
+
+ Args:
+ changes: 当前批次收集到的文件变更列表。
+ """
await self._handle_plugin_config_changes(plugin_id, changes)
return _callback
- def _iter_registered_plugin_config_paths(self) -> Iterable[Tuple[str, Path]]:
- """迭代当前所有已注册插件的 config.toml 路径。"""
+ def _iter_registered_plugin_paths(self) -> Iterable[Tuple[str, Path]]:
+ """迭代当前所有已注册插件的实际目录路径。"""
for supervisor in self.supervisors:
for plugin_id in getattr(supervisor, "_registered_plugins", {}).keys():
- if config_path := self._get_plugin_config_path_for_supervisor(supervisor, plugin_id):
- yield plugin_id, config_path
+ if plugin_path := self._get_plugin_path_for_supervisor(supervisor, plugin_id):
+ yield plugin_id, plugin_path
def _get_plugin_config_path_for_supervisor(self, supervisor: Any, plugin_id: str) -> Optional[Path]:
"""从指定 Supervisor 的插件目录中定位某个插件的 config.toml。"""
- for plugin_dir in getattr(supervisor, "_plugin_dirs", []):
- plugin_dir = Path(plugin_dir)
- plugin_path = plugin_dir.resolve() / plugin_id
- if plugin_path.is_dir():
- return plugin_path / "config.toml"
- return None
+ plugin_path = self._get_plugin_path_for_supervisor(supervisor, plugin_id)
+ return None if plugin_path is None else plugin_path / "config.toml"
async def _handle_plugin_config_changes(self, plugin_id: str, changes: Sequence[FileChange]) -> None:
- """处理单个插件配置文件变化,并仅向目标插件推送配置更新。"""
+ """处理单个插件配置文件变化,并定向派发自配置热更新。
+
+ Args:
+ plugin_id: 发生配置变更的插件 ID。
+ changes: 当前批次收集到的配置文件变更列表。
+
+ """
if not self._started or not changes:
return
@@ -453,18 +906,24 @@ class PluginRuntimeManager(
return
try:
- await supervisor.notify_plugin_config_updated(
+ config_payload = self._load_plugin_config_for_supervisor(supervisor, plugin_id)
+ delivered = await supervisor.notify_plugin_config_updated(
plugin_id=plugin_id,
- config_data=self._load_plugin_config_for_supervisor(plugin_id, getattr(supervisor, "_plugin_dirs", [])),
+ config_data=config_payload,
+ config_version="",
+ config_scope="self",
)
+ if not delivered:
+ logger.warning(f"插件 {plugin_id} 配置文件变更后通知失败")
except Exception as exc:
- logger.warning(f"插件 {plugin_id} 配置热更新通知失败: {exc}")
+ logger.warning(f"插件 {plugin_id} 配置文件变更处理失败: {exc}")
async def _handle_plugin_source_changes(self, changes: Sequence[FileChange]) -> None:
"""处理插件源码相关变化。
这里仅负责源码、清单等会影响插件装载状态的文件;配置文件的变化会由
- 单独的 per-plugin watcher 处理,避免把单插件配置更新放大成全量 reload。
+ 单独的 per-plugin watcher 处理,并定向派发给目标插件的
+ ``on_config_update()``,避免放大成不必要的跨插件 reload。
"""
if not self._started or not changes:
return
@@ -477,7 +936,7 @@ class PluginRuntimeManager(
logger.error(f"检测到重复插件 ID,跳过本次插件热重载: {details}")
return
- reload_supervisors: List[Any] = []
+ changed_plugin_ids: List[str] = []
changed_paths = [change.path.resolve() for change in changes]
for supervisor in self.supervisors:
@@ -485,13 +944,12 @@ class PluginRuntimeManager(
plugin_id = self._match_plugin_id_for_supervisor(supervisor, path)
if plugin_id is None:
continue
- if (path.name in {"plugin.py", "_manifest.json"} or path.suffix == ".py") and supervisor not in reload_supervisors:
- reload_supervisors.append(supervisor)
+ if path.name in {"plugin.py", "_manifest.json"} or path.suffix == ".py":
+ if plugin_id not in changed_plugin_ids:
+ changed_plugin_ids.append(plugin_id)
- for supervisor in reload_supervisors:
- await supervisor.reload_plugins(reason="file_watcher")
-
- if reload_supervisors:
+ if changed_plugin_ids:
+ await self.reload_plugins_globally(changed_plugin_ids, reason="file_watcher")
self._refresh_plugin_config_watch_subscriptions()
@staticmethod
@@ -502,36 +960,47 @@ class PluginRuntimeManager(
def _match_plugin_id_for_supervisor(self, supervisor: Any, path: Path) -> Optional[str]:
"""根据变更路径为指定 Supervisor 推断受影响的插件 ID。"""
- for plugin_id, _reg in getattr(supervisor, "_registered_plugins", {}).items():
- for plugin_dir in getattr(supervisor, "_plugin_dirs", []):
- plugin_dir = Path(plugin_dir)
- candidate_dir = plugin_dir.resolve() / plugin_id
- if path == candidate_dir or path.is_relative_to(candidate_dir):
- return plugin_id
+ resolved_path = path.resolve()
+
+ for plugin_id in getattr(supervisor, "_registered_plugins", {}).keys():
+ plugin_path = self._get_plugin_path_for_supervisor(supervisor, plugin_id)
+ if plugin_path is not None and (resolved_path == plugin_path or resolved_path.is_relative_to(plugin_path)):
+ return plugin_id
+
+ for plugin_id, plugin_path in self._plugin_path_cache.items():
+ if not any(self._plugin_dir_matches(plugin_path, Path(plugin_dir)) for plugin_dir in getattr(supervisor, "_plugin_dirs", [])):
+ continue
+ if resolved_path == plugin_path or resolved_path.is_relative_to(plugin_path):
+ return plugin_id
+
+ for plugin_id, plugin_path in self._iter_discovered_plugin_paths(getattr(supervisor, "_plugin_dirs", [])):
+ if resolved_path == plugin_path or resolved_path.is_relative_to(plugin_path):
+ self._plugin_path_cache[plugin_id] = plugin_path
+ return plugin_id
- for plugin_dir in getattr(supervisor, "_plugin_dirs", []):
- plugin_dir = Path(plugin_dir)
- plugin_root = plugin_dir.resolve()
- if self._plugin_dir_matches(path, plugin_dir) and (relative_parts := path.relative_to(plugin_root).parts):
- return relative_parts[0]
return None
- @staticmethod
- def _load_plugin_config_for_supervisor(plugin_id: str, plugin_dirs: Iterable[Path]) -> Dict[str, Any]:
+ def _load_plugin_config_for_supervisor(self, supervisor: Any, plugin_id: str) -> Dict[str, Any]:
"""从给定插件目录集合中读取目标插件的配置内容。"""
- for plugin_dir in plugin_dirs:
- plugin_path = plugin_dir.resolve() / plugin_id
- if plugin_path.is_dir():
- config_path = plugin_path / "config.toml"
- if not config_path.exists():
- return {}
- with open(config_path, "r", encoding="utf-8") as handle:
- return tomlkit.load(handle).unwrap()
- return {}
+ plugin_path = self._get_plugin_path_for_supervisor(supervisor, plugin_id)
+ if plugin_path is None:
+ return {}
+
+ config_path = plugin_path / "config.toml"
+ if not config_path.exists():
+ return {}
+
+ with open(config_path, "r", encoding="utf-8") as handle:
+ return tomlkit.load(handle).unwrap()
# ─── 能力实现注册 ──────────────────────────────────────────
def _register_capability_impls(self, supervisor: "PluginSupervisor") -> None:
+ """向指定 Supervisor 注册主程序能力实现。
+
+ Args:
+ supervisor: 需要注册能力实现的目标 Supervisor。
+ """
register_capability_impls(self, supervisor)
diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py
index bcfb2758..e738d019 100644
--- a/src/plugin_runtime/protocol/envelope.py
+++ b/src/plugin_runtime/protocol/envelope.py
@@ -7,52 +7,52 @@
from enum import Enum
from typing import Any, Dict, List, Optional
-from pydantic import BaseModel, Field
-
import logging as stdlib_logging
import time
+from pydantic import BaseModel, Field
-# ─── 协议常量 ──────────────────────────────────────────────────────
-
-PROTOCOL_VERSION = "1.0"
+# ====== 协议常量 ======
+PROTOCOL_VERSION = "1.0.0"
# 支持的 SDK 版本范围(Host 在握手时校验)
MIN_SDK_VERSION = "1.0.0"
-MAX_SDK_VERSION = "1.99.99"
-
-
-# ─── 消息类型 ──────────────────────────────────────────────────────
+MAX_SDK_VERSION = "2.99.99"
+# ====== 消息类型 ======
class MessageType(str, Enum):
"""RPC 消息类型"""
REQUEST = "request"
RESPONSE = "response"
- EVENT = "event"
+ BROADCAST = "broadcast"
-# ─── 请求 ID 生成器 ───────────────────────────────────────────────
+class ConfigReloadScope(str, Enum):
+ """配置热重载范围。"""
+
+ SELF = "self"
+ BOT = "bot"
+ MODEL = "model"
+# ====== 请求 ID 生成器 ======
class RequestIdGenerator:
- """单调递增 int64 请求 ID 生成器(线程安全由调用方保证或使用 asyncio)"""
+ """单调递增 int64 请求 ID 生成器"""
def __init__(self, start: int = 1) -> None:
self._counter = start
- def next(self) -> int:
+ async def next(self) -> int:
current = self._counter
self._counter += 1
return current
-# ─── Envelope 模型 ─────────────────────────────────────────────────
-
-
+# ====== Envelope 模型 ======
class Envelope(BaseModel):
- """RPC 统一信封
+ """RPC 统一消息封装
所有 Host <-> Runner 消息均封装为此格式。
序列化流程:Envelope -> .model_dump() -> MsgPack encode
@@ -60,15 +60,23 @@ class Envelope(BaseModel):
"""
protocol_version: str = Field(default=PROTOCOL_VERSION, description="协议版本")
+ """协议版本"""
request_id: int = Field(description="单调递增请求 ID")
+ """单调递增请求 ID"""
message_type: MessageType = Field(description="消息类型")
+ """消息类型"""
method: str = Field(default="", description="RPC 方法名")
+ """RPC 方法名"""
plugin_id: str = Field(default="", description="目标插件 ID")
- timestamp_ms: int = Field(default_factory=lambda: int(time.time() * 1000), description="发送时间戳(ms)")
- timeout_ms: int = Field(default=30000, description="相对超时(ms)")
- generation: int = Field(default=0, description="Runner generation 编号")
+ """目标插件 ID"""
+ timestamp_ms: int = Field(default_factory=lambda: int(time.time() * 1000), description="发送时间戳 (ms)")
+ """发送时间戳 (ms)"""
+ timeout_ms: int = Field(default=30000, description="相对超时 (ms)")
+ """相对超时 (ms)"""
payload: Dict[str, Any] = Field(default_factory=dict, description="业务数据")
- error: Optional[Dict[str, Any]] = Field(default=None, description="错误信息(仅 response)")
+ """业务数据"""
+ error: Optional[Dict[str, Any]] = Field(default=None, description="错误信息 (仅 response)")
+ """错误信息 (仅 response)"""
def is_request(self) -> bool:
return self.message_type == MessageType.REQUEST
@@ -76,8 +84,8 @@ class Envelope(BaseModel):
def is_response(self) -> bool:
return self.message_type == MessageType.RESPONSE
- def is_event(self) -> bool:
- return self.message_type == MessageType.EVENT
+ def is_broadcast(self) -> bool:
+ return self.message_type == MessageType.BROADCAST
def make_response(
self, payload: Optional[Dict[str, Any]] = None, error: Optional[Dict[str, Any]] = None
@@ -89,7 +97,6 @@ class Envelope(BaseModel):
message_type=MessageType.RESPONSE,
method=self.method,
plugin_id=self.plugin_id,
- generation=self.generation,
payload=payload or {},
error=error,
)
@@ -105,153 +112,302 @@ class Envelope(BaseModel):
)
-# ─── 握手消息 ──────────────────────────────────────────────────────
-
-
+# ====== 握手请求与响应 ======
class HelloPayload(BaseModel):
"""runner.hello 握手请求 payload"""
runner_id: str = Field(description="Runner 进程唯一标识")
+ """Runner 进程唯一标识"""
sdk_version: str = Field(description="SDK 版本号")
+ """SDK 版本号"""
session_token: str = Field(description="一次性会话令牌")
+ """一次性会话令牌"""
class HelloResponsePayload(BaseModel):
"""runner.hello 握手响应 payload"""
accepted: bool = Field(description="是否接受连接")
+ """是否接受连接"""
host_version: str = Field(default="", description="Host 版本号")
- assigned_generation: int = Field(default=0, description="分配的 generation 编号")
- reason: str = Field(default="", description="拒绝原因(若 accepted=False)")
-
-
-# ─── 组件注册消息 ──────────────────────────────────────────────────
+ """Host 版本号"""
+ reason: str = Field(default="", description="拒绝原因 (若 accepted=False)")
+ """拒绝原因 (若 `accepted`=`False`)"""
+# ====== 组件注册消息 ======
class ComponentDeclaration(BaseModel):
"""单个组件声明"""
name: str = Field(description="组件名称")
- component_type: str = Field(description="组件类型: action/command/tool/event_handler")
+ """组件名称"""
+ component_type: str = Field(
+ description="组件类型:action/command/tool/event_handler/hook_handler/message_gateway"
+ )
+ """组件类型:`action`/`command`/`tool`/`event_handler`/`hook_handler`/`message_gateway`"""
plugin_id: str = Field(description="所属插件 ID")
+ """所属插件 ID"""
metadata: Dict[str, Any] = Field(default_factory=dict, description="组件元数据")
+ """组件元数据"""
-class RegisterComponentsPayload(BaseModel):
- """plugin.register_components 请求 payload"""
+class RegisterPluginPayload(BaseModel):
+ """插件组件注册请求载荷。
+
+ 该模型同时用于 ``plugin.register_components`` 与兼容旧命名的
+ ``plugin.register_plugin`` 请求。
+ """
plugin_id: str = Field(description="插件 ID")
+ """插件 ID"""
plugin_version: str = Field(default="1.0.0", description="插件版本")
+ """插件版本"""
components: List[ComponentDeclaration] = Field(default_factory=list, description="组件列表")
+ """组件列表"""
capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表")
+ """所需能力列表"""
+ dependencies: List[str] = Field(default_factory=list, description="插件级依赖插件 ID 列表")
+ """插件级依赖插件 ID 列表"""
+ config_reload_subscriptions: List[str] = Field(default_factory=list, description="订阅的全局配置热重载范围")
+ """订阅的全局配置热重载范围"""
class BootstrapPluginPayload(BaseModel):
"""plugin.bootstrap 请求 payload"""
plugin_id: str = Field(description="插件 ID")
+ """插件 ID"""
plugin_version: str = Field(default="1.0.0", description="插件版本")
+ """插件版本"""
capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表")
+ """所需能力列表"""
-# ─── 调用消息 ──────────────────────────────────────────────────────
-
-
+# ====== 插件调用请求和响应 ======
class InvokePayload(BaseModel):
- """plugin.invoke_* 请求 payload"""
+ """plugin.invoke.* 请求 payload"""
component_name: str = Field(description="要调用的组件名称")
+ """要调用的组件名称"""
args: Dict[str, Any] = Field(default_factory=dict, description="调用参数")
+ """调用参数"""
class InvokeResultPayload(BaseModel):
- """plugin.invoke_* 响应 payload"""
+ """plugin.invoke.* 响应 payload"""
success: bool = Field(description="是否成功")
+ """是否成功"""
result: Any = Field(default=None, description="返回值")
+ """返回值"""
-# ─── 能力调用消息 ──────────────────────────────────────────────────
-
-
+# ====== 能力调用消息 ======
class CapabilityRequestPayload(BaseModel):
"""cap.* 请求 payload(插件 -> Host 能力调用)"""
capability: str = Field(description="能力名称,如 send.text, db.query")
+ """能力名称,如 send.text, db.query"""
args: Dict[str, Any] = Field(default_factory=dict, description="调用参数")
+ """调用参数"""
class CapabilityResponsePayload(BaseModel):
"""cap.* 响应 payload"""
success: bool = Field(description="是否成功")
+ """是否成功"""
result: Any = Field(default=None, description="返回值")
+ """返回值"""
-# ─── 健康检查 ──────────────────────────────────────────────────────
-
-
+# ====== 健康检查 ======
class HealthPayload(BaseModel):
"""plugin.health 响应 payload"""
healthy: bool = Field(description="是否健康")
+ """是否健康"""
loaded_plugins: List[str] = Field(default_factory=list, description="已加载的插件列表")
- uptime_ms: int = Field(default=0, description="运行时长(ms)")
+ """已加载的插件列表"""
+ uptime_ms: int = Field(default=0, description="运行时长 (ms)")
+ """运行时长 (ms)"""
class RunnerReadyPayload(BaseModel):
"""runner.ready 请求 payload"""
loaded_plugins: List[str] = Field(default_factory=list, description="已完成初始化的插件列表")
+ """已完成初始化的插件列表"""
failed_plugins: List[str] = Field(default_factory=list, description="初始化失败的插件列表")
+ """初始化失败的插件列表"""
-# ─── 配置更新 ──────────────────────────────────────────────────────
-
-
-# Host 侧现已支持配置更新推送:
-# - 总配置热重载完成后,PluginRuntimeManager 会向已加载插件推送配置更新事件。
-# - 插件目录下的 config.toml 变化由现有 FileWatcher 监听并转发为 plugin.config_updated。
+# ====== 配置更新 ======
class ConfigUpdatedPayload(BaseModel):
"""plugin.config_updated 事件 payload"""
plugin_id: str = Field(description="插件 ID")
+ """插件 ID"""
+ config_scope: ConfigReloadScope = Field(description="配置变更范围")
+ """配置变更范围"""
config_version: str = Field(description="新配置版本")
+ """新配置版本"""
config_data: Dict[str, Any] = Field(default_factory=dict, description="配置内容")
+ """配置内容"""
-# ─── 关停 ──────────────────────────────────────────────────────────
-
-
+# ====== 关停 ======
class ShutdownPayload(BaseModel):
"""plugin.shutdown / plugin.prepare_shutdown payload"""
reason: str = Field(default="normal", description="关停原因")
- drain_timeout_ms: int = Field(default=5000, description="排空超时(ms)")
+ """关停原因"""
+ drain_timeout_ms: int = Field(default=5000, description="排空超时 (ms)")
+ """排空超时 (ms)"""
-# ─── 日志传输 ──────────────────────────────────────────────────────
+class UnregisterPluginPayload(BaseModel):
+ """插件注销请求载荷。"""
+
+ plugin_id: str = Field(description="插件 ID")
+ """插件 ID"""
+ reason: str = Field(default="manual", description="注销原因")
+ """注销原因"""
+
+
+class ReloadPluginPayload(BaseModel):
+ """插件重载请求载荷。"""
+
+ plugin_id: str = Field(description="目标插件 ID")
+ """目标插件 ID"""
+ reason: str = Field(default="manual", description="重载原因")
+ """重载原因"""
+ external_available_plugins: Dict[str, str] = Field(
+ default_factory=dict,
+ description="可视为已满足的外部依赖插件版本映射",
+ )
+ """可视为已满足的外部依赖插件版本映射"""
+
+
+class ReloadPluginsPayload(BaseModel):
+ """批量插件重载请求载荷。"""
+
+ plugin_ids: List[str] = Field(default_factory=list, description="目标插件 ID 列表")
+ """目标插件 ID 列表"""
+ reason: str = Field(default="manual", description="重载原因")
+ """重载原因"""
+ external_available_plugins: Dict[str, str] = Field(
+ default_factory=dict,
+ description="可视为已满足的外部依赖插件版本映射",
+ )
+ """可视为已满足的外部依赖插件版本映射"""
+
+
+class ReloadPluginResultPayload(BaseModel):
+ """插件重载结果载荷。"""
+
+ success: bool = Field(description="是否重载成功")
+ """是否重载成功"""
+ requested_plugin_id: str = Field(description="请求重载的插件 ID")
+ """请求重载的插件 ID"""
+ reloaded_plugins: List[str] = Field(default_factory=list, description="成功完成重载的插件列表")
+ """成功完成重载的插件列表"""
+ unloaded_plugins: List[str] = Field(default_factory=list, description="本次已卸载的插件列表")
+ """本次已卸载的插件列表"""
+ failed_plugins: Dict[str, str] = Field(default_factory=dict, description="重载失败的插件及原因")
+ """重载失败的插件及原因"""
+
+
+class ReloadPluginsResultPayload(BaseModel):
+ """批量插件重载结果载荷。"""
+
+ success: bool = Field(description="是否重载成功")
+ """是否重载成功"""
+ requested_plugin_ids: List[str] = Field(default_factory=list, description="请求重载的插件 ID 列表")
+ """请求重载的插件 ID 列表"""
+ reloaded_plugins: List[str] = Field(default_factory=list, description="成功完成重载的插件列表")
+ """成功完成重载的插件列表"""
+ unloaded_plugins: List[str] = Field(default_factory=list, description="本次已卸载的插件列表")
+ """本次已卸载的插件列表"""
+ failed_plugins: Dict[str, str] = Field(default_factory=dict, description="重载失败的插件及原因")
+ """重载失败的插件及原因"""
+
+
+class MessageGatewayStateUpdatePayload(BaseModel):
+ """消息网关运行时状态更新载荷。"""
+
+ gateway_name: str = Field(description="消息网关组件名称")
+ """消息网关组件名称"""
+ ready: bool = Field(description="当前链路是否已经就绪")
+ """当前链路是否已经就绪"""
+ platform: str = Field(default="", description="当前链路负责的平台名称")
+ """当前链路负责的平台名称"""
+ account_id: str = Field(default="", description="当前链路对应的账号 ID 或 self_id")
+ """当前链路对应的账号 ID 或 self_id"""
+ scope: str = Field(default="", description="当前链路对应的可选路由作用域")
+ """当前链路对应的可选路由作用域"""
+ metadata: Dict[str, Any] = Field(default_factory=dict, description="可选的运行时状态元数据")
+ """可选的运行时状态元数据"""
+
+
+class MessageGatewayStateUpdateResultPayload(BaseModel):
+ """消息网关运行时状态更新结果载荷。"""
+
+ accepted: bool = Field(description="Host 是否接受了本次状态更新")
+ """Host 是否接受了本次状态更新"""
+ ready: bool = Field(description="Host 记录的当前就绪状态")
+ """Host 记录的当前就绪状态"""
+ route_key: Dict[str, Any] = Field(default_factory=dict, description="当前生效的路由键")
+ """当前生效的路由键"""
+
+
+class RouteMessagePayload(BaseModel):
+ """消息网关向 Host 路由外部消息的请求载荷。"""
+
+ gateway_name: str = Field(description="接收消息的网关组件名称")
+ """接收消息的网关组件名称"""
+ message: Dict[str, Any] = Field(description="符合 MessageDict 结构的标准消息字典")
+ """符合 MessageDict 结构的标准消息字典"""
+ route_metadata: Dict[str, Any] = Field(default_factory=dict, description="可选的路由辅助元数据")
+ """可选的路由辅助元数据"""
+ external_message_id: str = Field(default="", description="可选的外部平台消息 ID")
+ """可选的外部平台消息 ID"""
+ dedupe_key: str = Field(default="", description="可选的显式去重键")
+ """可选的显式去重键"""
+
+
+class ReceiveExternalMessageResultPayload(BaseModel):
+ """外部消息注入结果载荷。"""
+
+ accepted: bool = Field(description="Host 是否接受了本次消息注入")
+ """Host 是否接受了本次消息注入"""
+ route_key: Dict[str, Any] = Field(default_factory=dict, description="本次消息使用的归一路由键")
+ """本次消息使用的归一路由键"""
+
+
+RegisterPluginPayload.model_rebuild()
+
+
+# ====== 日志传输 ======
class LogEntry(BaseModel):
"""单条日志记录(Runner → Host 传输格式)"""
- timestamp_ms: int = Field(
- description="日志时间戳,Unix epoch 毫秒",
- )
- level: int = Field(
- description=("stdlib logging 整数级别: 10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL"),
- )
- logger_name: str = Field(
- description="Logger 名称,如 plugin.my_plugin.submodule",
- )
- message: str = Field(
- description="经 Formatter 格式化后的完整日志消息(含 exc_info 文本)",
- )
+ timestamp_ms: int = Field(description="日志时间戳,Unix epoch 毫秒")
+ """日志时间戳,Unix epoch 毫秒"""
+ level: int = Field(description="stdlib logging 整数级别:10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL")
+ """stdlib logging 整数级别:10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL"""
+ logger_name: str = Field(description="Logger 名称,如 plugin.my_plugin.submodule")
+ """Logger 名称,如 plugin.my_plugin.submodule"""
+ message: str = Field(description="经 Formatter 格式化后的完整日志消息(含 exc_info 文本)")
+ """经 Formatter 格式化后的完整日志消息(含 exc_info 文本)"""
exception_text: str = Field(
default="",
description="原始异常摘要(exc_text),供结构化消费;已嵌入 message 中",
)
+ """原始异常摘要(exc_text),供结构化消费;已嵌入 message 中"""
+ log_color_in_hex: Optional[str] = Field(default=None, description="日志颜色的十六进制字符串(如 #RRGGBB)")
@property
def levelname(self) -> str:
@@ -262,6 +418,5 @@ class LogEntry(BaseModel):
class LogBatchPayload(BaseModel):
"""runner.log_batch 事件 payload:Runner 端向 Host 批量推送日志记录"""
- entries: List[LogEntry] = Field(
- description="本批次日志记录列表,按时间升序排列",
- )
+ entries: List[LogEntry] = Field(description="本批次日志记录列表,按时间升序排列")
+ """本批次日志记录列表,按时间升序排列"""
diff --git a/src/plugin_runtime/protocol/errors.py b/src/plugin_runtime/protocol/errors.py
index dcae6b8f..d2b9228b 100644
--- a/src/plugin_runtime/protocol/errors.py
+++ b/src/plugin_runtime/protocol/errors.py
@@ -18,17 +18,17 @@ class ErrorCode(str, Enum):
E_TIMEOUT = "E_TIMEOUT"
E_BAD_PAYLOAD = "E_BAD_PAYLOAD"
E_PROTOCOL_MISMATCH = "E_PROTOCOL_MISMATCH"
+ E_SHUTTING_DOWN = "E_SHUTTING_DOWN"
# 权限与策略
E_UNAUTHORIZED = "E_UNAUTHORIZED"
E_METHOD_NOT_ALLOWED = "E_METHOD_NOT_ALLOWED"
- E_BACKPRESSURE = "E_BACKPRESSURE"
+ E_BACK_PRESSURE = "E_BACK_PRESSURE"
E_HOST_OVERLOADED = "E_HOST_OVERLOADED"
# 插件生命周期
E_PLUGIN_CRASHED = "E_PLUGIN_CRASHED"
E_PLUGIN_NOT_FOUND = "E_PLUGIN_NOT_FOUND"
- E_GENERATION_MISMATCH = "E_GENERATION_MISMATCH"
E_RELOAD_IN_PROGRESS = "E_RELOAD_IN_PROGRESS"
# 能力调用
@@ -65,3 +65,13 @@ class RPCError(Exception):
message=data.get("message", ""),
details=data.get("details", {}),
)
+
+ @classmethod
+ def from_exception(cls, exception: Exception, code_mapping: Optional[Dict[type[Exception], ErrorCode]] = None):
+ if isinstance(exception, cls):
+ return exception
+ if code_mapping:
+ for exception_type, code in code_mapping.items():
+ if isinstance(exception, exception_type):
+ return cls(code=code, message=str(exception))
+ return cls(ErrorCode.E_UNKNOWN, str(exception))
diff --git a/src/plugin_runtime/runner/log_handler.py b/src/plugin_runtime/runner/log_handler.py
index b5a0a328..6f42940f 100644
--- a/src/plugin_runtime/runner/log_handler.py
+++ b/src/plugin_runtime/runner/log_handler.py
@@ -66,6 +66,12 @@ class RunnerIPCLogHandler(logging.Handler):
ALLOWED_LOGGER_PREFIXES: tuple[str, ...] = ("plugin.", "plugin_runtime.", "_maibot_plugin_")
def __init__(self) -> None:
+ """初始化 Runner 端日志转发处理器。
+
+ 创建有界日志缓冲区,并准备与 RPC 客户端绑定的后台刷新任务。
+ 此时不会启动任何异步任务;真正开始转发要等到 :meth:`start`
+ 被调用后才会发生。
+ """
super().__init__()
# deque(maxlen=N): append/popleft 在 CPython GIL 保护下线程安全
self._buffer: collections.deque[LogEntry] = collections.deque(maxlen=self.QUEUE_MAX)
diff --git a/src/plugin_runtime/runner/manifest_validator.py b/src/plugin_runtime/runner/manifest_validator.py
index b6990850..33c2b1e5 100644
--- a/src/plugin_runtime/runner/manifest_validator.py
+++ b/src/plugin_runtime/runner/manifest_validator.py
@@ -1,28 +1,55 @@
-"""Manifest 校验与版本兼容性
+"""Manifest 校验与解析。
-从旧系统的 ManifestValidator / VersionComparator 对齐移植,
-适配新 plugin_runtime 的 _manifest.json 格式。
+集中负责插件 ``_manifest.json`` 的读取、结构校验、运行时兼容性判断,
+以及插件依赖/Python 包依赖的解析逻辑。
"""
-from typing import Any, Dict, List, Tuple
+from functools import lru_cache
+from importlib import metadata as importlib_metadata
+from pathlib import Path
+from typing import Annotated, Any, Dict, Iterable, List, Literal, Optional, Tuple, Union
+import json
import re
+import tomllib
+
+from packaging.requirements import InvalidRequirement, Requirement
+from packaging.specifiers import InvalidSpecifier, SpecifierSet
+from packaging.utils import canonicalize_name
+from packaging.version import InvalidVersion, Version
+from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator, model_validator
from src.common.logger import get_logger
logger = get_logger("plugin_runtime.runner.manifest_validator")
+_SEMVER_PATTERN = re.compile(r"^\d+\.\d+\.\d+$")
+_PLUGIN_ID_PATTERN = re.compile(r"^[a-z0-9]+(?:[.-][a-z0-9]+)+$")
+_PACKAGE_NAME_PATTERN = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*$")
+_HTTP_URL_PATTERN = re.compile(r"^https?://.+$")
+
class VersionComparator:
- """语义化版本号比较器"""
+ """语义化版本号比较器。"""
@staticmethod
def normalize_version(version: str) -> str:
+ """将版本号规范化为三段式语义版本字符串。
+
+ Args:
+ version: 原始版本号字符串。
+
+ Returns:
+ str: 规范化后的 ``major.minor.patch`` 形式版本号。
+ 当输入为空或格式非法时返回 ``0.0.0``。
+ """
if not version:
return "0.0.0"
- normalized = re.sub(r"-snapshot\.\d+", "", version.strip())
+
+ normalized = re.sub(r"-snapshot\.\d+", "", str(version).strip())
if not re.match(r"^\d+(\.\d+){0,2}$", normalized):
return "0.0.0"
+
parts = normalized.split(".")
while len(parts) < 3:
parts.append("0")
@@ -30,6 +57,15 @@ class VersionComparator:
@staticmethod
def parse_version(version: str) -> Tuple[int, int, int]:
+ """将版本字符串解析为可比较的整数元组。
+
+ Args:
+ version: 原始版本号字符串。
+
+ Returns:
+ Tuple[int, int, int]: 三段式版本号对应的整数元组。
+ 当解析失败时返回 ``(0, 0, 0)``。
+ """
normalized = VersionComparator.normalize_version(version)
try:
parts = normalized.split(".")
@@ -39,98 +75,1072 @@ class VersionComparator:
@staticmethod
def compare(v1: str, v2: str) -> int:
+ """比较两个版本号的大小关系。
+
+ Args:
+ v1: 第一个版本号。
+ v2: 第二个版本号。
+
+ Returns:
+ int: ``-1`` 表示 ``v1 < v2``,``1`` 表示 ``v1 > v2``,
+ ``0`` 表示两者相等。
+ """
t1 = VersionComparator.parse_version(v1)
t2 = VersionComparator.parse_version(v2)
if t1 < t2:
return -1
- elif t1 > t2:
+ if t1 > t2:
return 1
return 0
@staticmethod
def is_in_range(version: str, min_version: str = "", max_version: str = "") -> Tuple[bool, str]:
+ """判断版本号是否落在给定闭区间内。
+
+ Args:
+ version: 待检查的版本号。
+ min_version: 允许的最小版本号,留空表示不限制下界。
+ max_version: 允许的最大版本号,留空表示不限制上界。
+
+ Returns:
+ Tuple[bool, str]: 第一项表示是否满足要求,第二项为失败原因;
+ 当校验通过时第二项为空字符串。
+ """
if not min_version and not max_version:
return True, ""
- vn = VersionComparator.normalize_version(version)
+
+ normalized_version = VersionComparator.normalize_version(version)
if min_version:
- mn = VersionComparator.normalize_version(min_version)
- if VersionComparator.compare(vn, mn) < 0:
- return False, f"版本 {vn} 低于最小要求 {mn}"
+ normalized_min_version = VersionComparator.normalize_version(min_version)
+ if VersionComparator.compare(normalized_version, normalized_min_version) < 0:
+ return False, f"版本 {normalized_version} 低于最小要求 {normalized_min_version}"
if max_version:
- mx = VersionComparator.normalize_version(max_version)
- if VersionComparator.compare(vn, mx) > 0:
- return False, f"版本 {vn} 高于最大支持 {mx}"
+ normalized_max_version = VersionComparator.normalize_version(max_version)
+ if VersionComparator.compare(normalized_version, normalized_max_version) > 0:
+ return False, f"版本 {normalized_version} 高于最大支持 {normalized_max_version}"
return True, ""
+ @staticmethod
+ def is_valid_semver(version: str) -> bool:
+ """判断字符串是否为严格三段式语义版本号。
+
+ Args:
+ version: 待检查的版本号字符串。
+
+ Returns:
+ bool: 是否满足 ``X.Y.Z`` 格式。
+ """
+ return bool(_SEMVER_PATTERN.fullmatch(str(version or "").strip()))
+
+
+class _StrictManifestModel(BaseModel):
+ """Manifest 解析使用的严格基类模型。"""
+
+ model_config = ConfigDict(extra="forbid", frozen=True, str_strip_whitespace=True)
+
+
+class ManifestAuthor(_StrictManifestModel):
+ """插件作者信息。"""
+
+ name: str = Field(description="作者名称")
+ url: str = Field(description="作者主页地址")
+
+ @field_validator("name")
+ @classmethod
+ def _validate_name(cls, value: str) -> str:
+ """校验作者名称。
+
+ Args:
+ value: 原始作者名称。
+
+ Returns:
+ str: 规范化后的作者名称。
+
+ Raises:
+ ValueError: 当字段为空时抛出。
+ """
+ if not value:
+ raise ValueError("不能为空")
+ return value
+
+ @field_validator("url")
+ @classmethod
+ def _validate_url(cls, value: str) -> str:
+ """校验作者主页地址。
+
+ Args:
+ value: 原始主页地址。
+
+ Returns:
+ str: 规范化后的主页地址。
+
+ Raises:
+ ValueError: 当字段为空或不是 HTTP/HTTPS URL 时抛出。
+ """
+ if not value:
+ raise ValueError("不能为空")
+ if not _HTTP_URL_PATTERN.fullmatch(value):
+ raise ValueError("必须为 http:// 或 https:// 开头的 URL")
+ return value
+
+
+class ManifestUrls(_StrictManifestModel):
+ """插件相关链接集合。"""
+
+ repository: str = Field(description="插件仓库地址")
+ homepage: Optional[str] = Field(default=None, description="插件主页地址")
+ documentation: Optional[str] = Field(default=None, description="插件文档地址")
+ issues: Optional[str] = Field(default=None, description="插件问题反馈地址")
+
+ @field_validator("repository")
+ @classmethod
+ def _validate_repository(cls, value: str) -> str:
+ """校验仓库地址。
+
+ Args:
+ value: 原始仓库地址。
+
+ Returns:
+ str: 规范化后的仓库地址。
+
+ Raises:
+ ValueError: 当字段为空或不是 HTTP/HTTPS URL 时抛出。
+ """
+ if not value:
+ raise ValueError("不能为空")
+ if not _HTTP_URL_PATTERN.fullmatch(value):
+ raise ValueError("必须为 http:// 或 https:// 开头的 URL")
+ return value
+
+ @field_validator("homepage", "documentation", "issues")
+ @classmethod
+ def _validate_optional_url(cls, value: Optional[str]) -> Optional[str]:
+ """校验可选链接字段。
+
+ Args:
+ value: 原始链接值。
+
+ Returns:
+ Optional[str]: 合法的链接值。
+
+ Raises:
+ ValueError: 当提供的值不是 HTTP/HTTPS URL 时抛出。
+ """
+ if value is None:
+ return None
+ if not value:
+ raise ValueError("不能为空字符串")
+ if not _HTTP_URL_PATTERN.fullmatch(value):
+ raise ValueError("必须为 http:// 或 https:// 开头的 URL")
+ return value
+
+
+class ManifestVersionRange(_StrictManifestModel):
+ """版本闭区间声明。"""
+
+ min_version: str = Field(description="最小版本,闭区间")
+ max_version: str = Field(description="最大版本,闭区间")
+
+ @field_validator("min_version", "max_version")
+ @classmethod
+ def _validate_version(cls, value: str) -> str:
+ """校验版本号格式。
+
+ Args:
+ value: 原始版本号。
+
+ Returns:
+ str: 合法的版本号。
+
+ Raises:
+ ValueError: 当版本号不是严格三段式语义版本时抛出。
+ """
+ if not VersionComparator.is_valid_semver(value):
+ raise ValueError("必须为严格三段式版本号,例如 1.0.0")
+ return value
+
+ @model_validator(mode="after")
+ def _validate_range(self) -> "ManifestVersionRange":
+ """校验版本区间上下界关系。
+
+ Returns:
+ ManifestVersionRange: 当前对象本身。
+
+ Raises:
+ ValueError: 当最小版本大于最大版本时抛出。
+ """
+ if VersionComparator.compare(self.min_version, self.max_version) > 0:
+ raise ValueError("min_version 不能大于 max_version")
+ return self
+
+
+class ManifestI18n(_StrictManifestModel):
+ """国际化配置。"""
+
+ default_locale: str = Field(description="默认语言")
+ locales_path: Optional[str] = Field(default=None, description="语言资源目录")
+ supported_locales: List[str] = Field(default_factory=list, description="支持的语言列表")
+
+ @field_validator("default_locale")
+ @classmethod
+ def _validate_default_locale(cls, value: str) -> str:
+ """校验默认语言。
+
+ Args:
+ value: 原始默认语言。
+
+ Returns:
+ str: 规范化后的默认语言。
+
+ Raises:
+ ValueError: 当字段为空时抛出。
+ """
+ if not value:
+ raise ValueError("不能为空")
+ return value
+
+ @field_validator("locales_path")
+ @classmethod
+ def _validate_locales_path(cls, value: Optional[str]) -> Optional[str]:
+ """校验语言资源目录。
+
+ Args:
+ value: 原始语言资源目录。
+
+ Returns:
+ Optional[str]: 合法的目录值。
+
+ Raises:
+ ValueError: 当值为空字符串时抛出。
+ """
+ if value is None:
+ return None
+ if not value:
+ raise ValueError("不能为空字符串")
+ return value
+
+ @field_validator("supported_locales")
+ @classmethod
+ def _validate_supported_locales(cls, value: List[str]) -> List[str]:
+ """校验支持语言列表。
+
+ Args:
+ value: 原始语言列表。
+
+ Returns:
+ List[str]: 去重后的语言列表。
+
+ Raises:
+ ValueError: 当列表项为空时抛出。
+ """
+ normalized_locales: List[str] = []
+ for locale in value:
+ normalized_locale = str(locale or "").strip()
+ if not normalized_locale:
+ raise ValueError("语言列表中存在空值")
+ if normalized_locale not in normalized_locales:
+ normalized_locales.append(normalized_locale)
+ return normalized_locales
+
+ @model_validator(mode="after")
+ def _validate_default_locale_membership(self) -> "ManifestI18n":
+ """校验默认语言是否位于支持列表中。
+
+ Returns:
+ ManifestI18n: 当前对象本身。
+
+ Raises:
+ ValueError: 当 ``supported_locales`` 非空但未包含 ``default_locale`` 时抛出。
+ """
+ if self.supported_locales and self.default_locale not in self.supported_locales:
+ raise ValueError("default_locale 必须包含在 supported_locales 中")
+ return self
+
+
+class PluginDependencyDefinition(_StrictManifestModel):
+ """插件级依赖声明。"""
+
+ type: Literal["plugin"] = Field(description="依赖类型")
+ id: str = Field(description="依赖插件 ID")
+ version_spec: str = Field(description="版本约束表达式")
+
+ @field_validator("id")
+ @classmethod
+ def _validate_id(cls, value: str) -> str:
+ """校验依赖插件 ID。
+
+ Args:
+ value: 原始依赖插件 ID。
+
+ Returns:
+ str: 合法的依赖插件 ID。
+
+ Raises:
+ ValueError: 当 ID 不符合规则时抛出。
+ """
+ if not _PLUGIN_ID_PATTERN.fullmatch(value):
+ raise ValueError("必须使用小写字母/数字,并以点号或横线分隔,例如 github.author.plugin")
+ return value
+
+ @field_validator("version_spec")
+ @classmethod
+ def _validate_version_spec(cls, value: str) -> str:
+ """校验插件依赖版本约束。
+
+ Args:
+ value: 原始版本约束表达式。
+
+ Returns:
+ str: 合法的版本约束表达式。
+
+ Raises:
+ ValueError: 当表达式无效时抛出。
+ """
+ if not value:
+ raise ValueError("不能为空")
+ try:
+ SpecifierSet(value)
+ except InvalidSpecifier as exc:
+ raise ValueError(f"无效的版本约束: {exc}") from exc
+ return value
+
+
+class PythonPackageDependencyDefinition(_StrictManifestModel):
+ """Python 包依赖声明。"""
+
+ type: Literal["python_package"] = Field(description="依赖类型")
+ name: str = Field(description="Python 包名")
+ version_spec: str = Field(description="版本约束表达式")
+
+ @field_validator("name")
+ @classmethod
+ def _validate_name(cls, value: str) -> str:
+ """校验 Python 包名。
+
+ Args:
+ value: 原始包名。
+
+ Returns:
+ str: 合法的包名。
+
+ Raises:
+ ValueError: 当包名不合法时抛出。
+ """
+ if not _PACKAGE_NAME_PATTERN.fullmatch(value):
+ raise ValueError("包名只能包含字母、数字、点号、下划线和横线")
+ return value
+
+ @field_validator("version_spec")
+ @classmethod
+ def _validate_version_spec(cls, value: str) -> str:
+ """校验 Python 包版本约束。
+
+ Args:
+ value: 原始版本约束表达式。
+
+ Returns:
+ str: 合法的版本约束表达式。
+
+ Raises:
+ ValueError: 当表达式无效时抛出。
+ """
+ if not value:
+ raise ValueError("不能为空")
+ try:
+ Requirement(f"placeholder{value}")
+ except InvalidRequirement as exc:
+ raise ValueError(f"无效的版本约束: {exc}") from exc
+ return value
+
+
+ManifestDependencyDefinition = Annotated[
+ Union[PluginDependencyDefinition, PythonPackageDependencyDefinition],
+ Field(discriminator="type"),
+]
+
+
+class PluginManifest(_StrictManifestModel):
+ """插件 Manifest v2 强类型模型。"""
+
+ manifest_version: Literal[2] = Field(description="Manifest 协议版本")
+ version: str = Field(description="插件版本")
+ name: str = Field(description="插件展示名称")
+ description: str = Field(description="插件描述")
+ author: ManifestAuthor = Field(description="插件作者信息")
+ license: str = Field(description="插件协议")
+ urls: ManifestUrls = Field(description="插件相关链接")
+ host_application: ManifestVersionRange = Field(description="Host 兼容区间")
+ sdk: ManifestVersionRange = Field(description="SDK 兼容区间")
+ dependencies: List[ManifestDependencyDefinition] = Field(default_factory=list, description="依赖声明")
+ capabilities: List[str] = Field(description="插件声明的能力请求")
+ i18n: ManifestI18n = Field(description="国际化配置")
+ id: str = Field(description="稳定插件 ID")
+
+ @field_validator("version")
+ @classmethod
+ def _validate_version(cls, value: str) -> str:
+ """校验插件版本号格式。
+
+ Args:
+ value: 原始插件版本号。
+
+ Returns:
+ str: 合法的插件版本号。
+
+ Raises:
+ ValueError: 当版本号不是严格三段式语义版本时抛出。
+ """
+ if not VersionComparator.is_valid_semver(value):
+ raise ValueError("必须为严格三段式版本号,例如 1.0.0")
+ return value
+
+ @field_validator("name", "description", "license", "id")
+ @classmethod
+ def _validate_required_string(cls, value: str, info: Any) -> str:
+ """校验必填字符串字段。
+
+ Args:
+ value: 原始字段值。
+ info: Pydantic 字段上下文。
+
+ Returns:
+ str: 合法的字段值。
+
+ Raises:
+ ValueError: 当字段为空或格式不合法时抛出。
+ """
+ if not value:
+ raise ValueError("不能为空")
+ if info.field_name == "id" and not _PLUGIN_ID_PATTERN.fullmatch(value):
+ raise ValueError("必须使用小写字母/数字,并以点号或横线分隔,例如 github.author.plugin")
+ return value
+
+ @field_validator("capabilities")
+ @classmethod
+ def _validate_capabilities(cls, value: List[str]) -> List[str]:
+ """校验能力声明列表。
+
+ Args:
+ value: 原始能力声明列表。
+
+ Returns:
+ List[str]: 去重后的能力列表。
+
+ Raises:
+ ValueError: 当列表为空项或能力名为空时抛出。
+ """
+ normalized_capabilities: List[str] = []
+ for capability in value:
+ normalized_capability = str(capability or "").strip()
+ if not normalized_capability:
+ raise ValueError("capabilities 中存在空能力名")
+ if normalized_capability not in normalized_capabilities:
+ normalized_capabilities.append(normalized_capability)
+ return normalized_capabilities
+
+ @model_validator(mode="after")
+ def _validate_dependencies(self) -> "PluginManifest":
+ """校验依赖声明集合。
+
+ Returns:
+ PluginManifest: 当前对象本身。
+
+ Raises:
+ ValueError: 当依赖项重复或插件依赖自身时抛出。
+ """
+ plugin_dependency_ids: set[str] = set()
+ python_package_names: set[str] = set()
+
+ for dependency in self.dependencies:
+ if isinstance(dependency, PluginDependencyDefinition):
+ if dependency.id == self.id:
+ raise ValueError("dependencies 中的插件依赖不能依赖自身")
+ if dependency.id in plugin_dependency_ids:
+ raise ValueError(f"存在重复的插件依赖声明: {dependency.id}")
+ plugin_dependency_ids.add(dependency.id)
+ continue
+
+ normalized_package_name = canonicalize_name(dependency.name)
+ if normalized_package_name in python_package_names:
+ raise ValueError(f"存在重复的 Python 包依赖声明: {dependency.name}")
+ python_package_names.add(normalized_package_name)
+
+ return self
+
+ @property
+ def plugin_dependencies(self) -> List[PluginDependencyDefinition]:
+ """返回插件级依赖列表。
+
+ Returns:
+ List[PluginDependencyDefinition]: 所有 ``type=plugin`` 的依赖项。
+ """
+ return [dependency for dependency in self.dependencies if isinstance(dependency, PluginDependencyDefinition)]
+
+ @property
+ def python_package_dependencies(self) -> List[PythonPackageDependencyDefinition]:
+ """返回 Python 包依赖列表。
+
+ Returns:
+ List[PythonPackageDependencyDefinition]: 所有 ``type=python_package`` 的依赖项。
+ """
+ return [
+ dependency
+ for dependency in self.dependencies
+ if isinstance(dependency, PythonPackageDependencyDefinition)
+ ]
+
+ @property
+ def plugin_dependency_ids(self) -> List[str]:
+ """返回插件级依赖的插件 ID 列表。
+
+ Returns:
+ List[str]: 所有插件级依赖的插件 ID。
+ """
+ return [dependency.id for dependency in self.plugin_dependencies]
+
class ManifestValidator:
- """_manifest.json 校验器"""
+ """严格的插件 Manifest v2 校验器。"""
- REQUIRED_FIELDS = ["name", "version", "description", "author"]
- RECOMMENDED_FIELDS = ["license", "keywords", "categories"]
- SUPPORTED_MANIFEST_VERSIONS = [1, 2]
+ SUPPORTED_MANIFEST_VERSIONS = [2]
- def __init__(self, host_version: str = "") -> None:
- self._host_version = host_version
+ def __init__(
+ self,
+ host_version: str = "",
+ sdk_version: str = "",
+ project_root: Optional[Path] = None,
+ ) -> None:
+ """初始化 Manifest 校验器。
+
+ Args:
+ host_version: 当前 Host 版本号;留空时自动从主程序 ``pyproject.toml`` 读取。
+ sdk_version: 当前 SDK 版本号;留空时自动从运行环境中探测。
+ project_root: 项目根目录;留空时自动推断。
+ """
+ self._project_root: Path = project_root or self._resolve_project_root()
+ self._host_version: str = host_version or self._detect_default_host_version(self._project_root)
+ self._sdk_version: str = sdk_version or self._detect_default_sdk_version(self._project_root)
self.errors: List[str] = []
self.warnings: List[str] = []
def validate(self, manifest: Dict[str, Any]) -> bool:
- """校验 manifest 数据,返回是否通过(errors 为空即通过)。"""
+ """校验 manifest 数据,返回是否通过。
+
+ Args:
+ manifest: 待校验的 Manifest 原始字典。
+
+ Returns:
+ bool: 校验是否通过。
+ """
+ return self.parse_manifest(manifest) is not None
+
+ def parse_manifest(self, manifest: Dict[str, Any]) -> Optional[PluginManifest]:
+ """解析并校验 manifest 字典。
+
+ Args:
+ manifest: 待解析的 Manifest 原始字典。
+
+ Returns:
+ Optional[PluginManifest]: 解析成功时返回强类型 Manifest;失败时返回 ``None``。
+ """
self.errors.clear()
self.warnings.clear()
- self._check_required_fields(manifest)
- self._check_manifest_version(manifest)
- self._check_author(manifest)
- self._check_host_compatibility(manifest)
- self._check_recommended(manifest)
+ try:
+ parsed_manifest = PluginManifest.model_validate(manifest)
+ except ValidationError as exc:
+ self.errors.extend(self._format_validation_errors(exc))
+ self._log_errors()
+ return None
+ self._validate_runtime_compatibility(parsed_manifest)
if self.errors:
- for e in self.errors:
- logger.error(f"Manifest 校验失败: {e}")
- if self.warnings:
- for w in self.warnings:
- logger.warning(f"Manifest 警告: {w}")
+ self._log_errors()
+ return None
- return len(self.errors) == 0
+ return parsed_manifest
- def _check_required_fields(self, manifest: Dict[str, Any]) -> None:
- for field in self.REQUIRED_FIELDS:
- if field not in manifest:
- self.errors.append(f"缺少必需字段: {field}")
- elif not manifest[field]:
- self.errors.append(f"必需字段不能为空: {field}")
+ def load_from_plugin_path(self, plugin_path: Path, require_entrypoint: bool = True) -> Optional[PluginManifest]:
+ """从插件目录读取并解析 manifest。
- def _check_manifest_version(self, manifest: Dict[str, Any]) -> None:
- mv = manifest.get("manifest_version")
- if mv is not None and mv not in self.SUPPORTED_MANIFEST_VERSIONS:
- self.errors.append(f"不支持的 manifest_version: {mv},支持: {self.SUPPORTED_MANIFEST_VERSIONS}")
+ Args:
+ plugin_path: 单个插件目录路径。
+ require_entrypoint: 是否要求目录内存在 ``plugin.py`` 入口文件。
- def _check_author(self, manifest: Dict[str, Any]) -> None:
- author = manifest.get("author")
- if author is None:
- return
- if isinstance(author, dict):
- if "name" not in author or not author["name"]:
- self.errors.append("author 对象缺少 name 字段")
- elif isinstance(author, str):
- if not author.strip():
- self.errors.append("author 不能为空")
- else:
- self.errors.append("author 应为字符串或 {name, url} 对象")
+ Returns:
+ Optional[PluginManifest]: 解析成功时返回强类型 Manifest;失败时返回 ``None``。
+ """
+ self.errors.clear()
+ self.warnings.clear()
- def _check_host_compatibility(self, manifest: Dict[str, Any]) -> None:
- host_app = manifest.get("host_application")
- if not isinstance(host_app, dict) or not self._host_version:
- return
- min_v = host_app.get("min_version", "")
- max_v = host_app.get("max_version", "")
- ok, msg = VersionComparator.is_in_range(self._host_version, min_v, max_v)
- if not ok:
- self.errors.append(f"Host 版本不兼容: {msg} (当前 Host: {self._host_version})")
+ manifest_path = plugin_path / "_manifest.json"
+ entrypoint_path = plugin_path / "plugin.py"
- def _check_recommended(self, manifest: Dict[str, Any]) -> None:
- for field in self.RECOMMENDED_FIELDS:
- if field not in manifest or not manifest[field]:
- self.warnings.append(f"建议填写字段: {field}")
+ if not manifest_path.is_file():
+ self.errors.append("缺少 _manifest.json")
+ return None
+ if require_entrypoint and not entrypoint_path.is_file():
+ self.errors.append("缺少 plugin.py")
+ return None
+
+ try:
+ with manifest_path.open("r", encoding="utf-8") as manifest_file:
+ manifest_data = json.load(manifest_file)
+ except Exception as exc:
+ self.errors.append(f"manifest 解析失败: {exc}")
+ self._log_errors()
+ return None
+
+ if not isinstance(manifest_data, dict):
+ self.errors.append("manifest 顶层必须为 JSON 对象")
+ self._log_errors()
+ return None
+
+ return self.parse_manifest(manifest_data)
+
+ def iter_plugin_manifests(
+ self,
+ plugin_dirs: Iterable[Path],
+ require_entrypoint: bool = True,
+ ) -> Iterable[Tuple[Path, PluginManifest]]:
+ """扫描插件根目录并迭代所有可成功解析的 Manifest。
+
+ Args:
+ plugin_dirs: 一个或多个插件根目录。
+ require_entrypoint: 是否要求每个插件目录内存在 ``plugin.py``。
+
+ Yields:
+ Tuple[Path, PluginManifest]: ``(插件目录路径, 解析结果)`` 二元组。
+ """
+ for plugin_root in plugin_dirs:
+ normalized_root = Path(plugin_root).resolve()
+ if not normalized_root.is_dir():
+ continue
+
+ for candidate_path in sorted(entry.resolve() for entry in normalized_root.iterdir() if entry.is_dir()):
+ parsed_manifest = self.load_from_plugin_path(candidate_path, require_entrypoint=require_entrypoint)
+ if parsed_manifest is None:
+ continue
+ yield candidate_path, parsed_manifest
+
+ def build_plugin_dependency_map(
+ self,
+ plugin_dirs: Iterable[Path],
+ require_entrypoint: bool = True,
+ ) -> Dict[str, List[str]]:
+ """扫描目录并构建 ``plugin_id -> 依赖插件 ID 列表`` 映射。
+
+ Args:
+ plugin_dirs: 一个或多个插件根目录。
+ require_entrypoint: 是否要求每个插件目录内存在 ``plugin.py``。
+
+ Returns:
+ Dict[str, List[str]]: 所有成功解析到的插件依赖映射。
+ """
+ dependency_map: Dict[str, List[str]] = {}
+ for _plugin_path, manifest in self.iter_plugin_manifests(plugin_dirs, require_entrypoint=require_entrypoint):
+ dependency_map[manifest.id] = manifest.plugin_dependency_ids
+ return dependency_map
+
+ def read_plugin_id_from_plugin_path(self, plugin_path: Path, require_entrypoint: bool = True) -> Optional[str]:
+ """从单个插件目录中读取规范化后的插件 ID。
+
+ Args:
+ plugin_path: 单个插件目录路径。
+ require_entrypoint: 是否要求目录内存在 ``plugin.py``。
+
+ Returns:
+ Optional[str]: 解析成功时返回插件 ID,否则返回 ``None``。
+ """
+ manifest = self.load_from_plugin_path(plugin_path, require_entrypoint=require_entrypoint)
+ if manifest is None:
+ return None
+ return manifest.id
+
+ def get_unsatisfied_plugin_dependencies(
+ self,
+ manifest: PluginManifest,
+ available_plugin_versions: Dict[str, str],
+ ) -> List[str]:
+ """返回当前 Manifest 尚未满足的插件依赖项。
+
+ Args:
+ manifest: 目标插件的强类型 Manifest。
+ available_plugin_versions: 当前可用插件版本映射,键为插件 ID,值为插件版本。
+
+ Returns:
+ List[str]: 未满足依赖的错误描述列表。
+ """
+ unsatisfied_dependencies: List[str] = []
+ for dependency in manifest.plugin_dependencies:
+ dependency_version = available_plugin_versions.get(dependency.id)
+ if not dependency_version:
+ unsatisfied_dependencies.append(f"{dependency.id} (未找到依赖插件)")
+ continue
+
+ if not self._version_matches_specifier(dependency_version, dependency.version_spec):
+ unsatisfied_dependencies.append(
+ f"{dependency.id} (需要 {dependency.version_spec},当前 {dependency_version})"
+ )
+
+ return unsatisfied_dependencies
+
+ def is_plugin_dependency_satisfied(
+ self,
+ dependency: PluginDependencyDefinition,
+ plugin_version: str,
+ ) -> bool:
+ """判断单个插件依赖是否被指定版本满足。
+
+ Args:
+ dependency: 插件级依赖声明。
+ plugin_version: 当前可用的插件版本号。
+
+ Returns:
+ bool: 是否满足版本约束。
+ """
+ return self._version_matches_specifier(plugin_version, dependency.version_spec)
+
+ def _validate_runtime_compatibility(self, manifest: PluginManifest) -> None:
+ """校验运行时版本兼容性与 Python 包依赖。
+
+ Args:
+ manifest: 已通过结构校验的强类型 Manifest。
+ """
+ host_ok, host_message = VersionComparator.is_in_range(
+ self._host_version,
+ manifest.host_application.min_version,
+ manifest.host_application.max_version,
+ )
+ if not host_ok:
+ self.errors.append(f"Host 版本不兼容: {host_message} (当前 Host: {self._host_version})")
+
+ sdk_ok, sdk_message = VersionComparator.is_in_range(
+ self._sdk_version,
+ manifest.sdk.min_version,
+ manifest.sdk.max_version,
+ )
+ if not sdk_ok:
+ self.errors.append(f"SDK 版本不兼容: {sdk_message} (当前 SDK: {self._sdk_version})")
+
+ self._validate_python_package_dependencies(manifest)
+
+ def _validate_python_package_dependencies(self, manifest: PluginManifest) -> None:
+ """校验 Python 包依赖与主程序运行环境是否冲突。
+
+ Args:
+ manifest: 已通过结构校验的强类型 Manifest。
+ """
+ host_requirements = self._load_host_dependency_requirements(self._project_root)
+
+ for dependency in manifest.python_package_dependencies:
+ normalized_package_name = canonicalize_name(dependency.name)
+ package_specifier = self._build_specifier_set(dependency.version_spec)
+ if package_specifier is None:
+ self.errors.append(
+ f"Python 包依赖 {dependency.name} 的版本约束无效: {dependency.version_spec}"
+ )
+ continue
+
+ installed_version = self._get_installed_package_version(dependency.name)
+ host_requirement = host_requirements.get(normalized_package_name)
+
+ if installed_version is not None and not self._version_matches_specifier(
+ installed_version,
+ dependency.version_spec,
+ ):
+ self.errors.append(
+ f"Python 包依赖冲突: {dependency.name} 需要 {dependency.version_spec},"
+ f"当前运行环境为 {installed_version}"
+ )
+ continue
+
+ if host_requirement is None:
+ continue
+
+ if not self._requirements_may_overlap(host_requirement.specifier, package_specifier):
+ host_specifier = str(host_requirement.specifier or "")
+ self.errors.append(
+ f"Python 包依赖冲突: {dependency.name} 需要 {dependency.version_spec},"
+ f"主程序依赖约束为 {host_specifier or '任意版本'}"
+ )
+
+ def _log_errors(self) -> None:
+ """输出当前累计的 Manifest 校验错误。"""
+ for error_message in self.errors:
+ logger.error(f"Manifest 校验失败: {error_message}")
+
+ @classmethod
+ def _resolve_project_root(cls) -> Path:
+ """推断当前项目根目录。
+
+ Returns:
+ Path: 项目根目录路径。
+ """
+ return Path(__file__).resolve().parents[3]
+
+ @classmethod
+ @lru_cache(maxsize=None)
+ def _detect_default_host_version(cls, project_root: Path) -> str:
+ """从主程序 ``pyproject.toml`` 探测 Host 版本号。
+
+ Args:
+ project_root: 项目根目录。
+
+ Returns:
+ str: 探测到的 Host 版本号;失败时返回空字符串。
+ """
+ pyproject_path = project_root / "pyproject.toml"
+ try:
+ with pyproject_path.open("rb") as pyproject_file:
+ pyproject_data = tomllib.load(pyproject_file)
+ except Exception:
+ return ""
+
+ project_data = pyproject_data.get("project", {})
+ if not isinstance(project_data, dict):
+ return ""
+
+ raw_version = str(project_data.get("version", "") or "").strip()
+ if VersionComparator.is_valid_semver(raw_version):
+ return raw_version
+ return ""
+
+ @classmethod
+ @lru_cache(maxsize=None)
+ def _detect_default_sdk_version(cls, project_root: Path) -> str:
+ """探测当前运行环境中的 SDK 版本号。
+
+ Args:
+ project_root: 项目根目录。
+
+ Returns:
+ str: 探测到的 SDK 版本号;失败时返回空字符串。
+ """
+ try:
+ raw_version = importlib_metadata.version("maibot-plugin-sdk")
+ if VersionComparator.is_valid_semver(raw_version):
+ return raw_version
+ except importlib_metadata.PackageNotFoundError:
+ pass
+
+ sdk_pyproject_path = project_root / "packages" / "maibot-plugin-sdk" / "pyproject.toml"
+ try:
+ with sdk_pyproject_path.open("rb") as pyproject_file:
+ pyproject_data = tomllib.load(pyproject_file)
+ except Exception:
+ return ""
+
+ project_data = pyproject_data.get("project", {})
+ if not isinstance(project_data, dict):
+ return ""
+
+ raw_version = str(project_data.get("version", "") or "").strip()
+ if VersionComparator.is_valid_semver(raw_version):
+ return raw_version
+ return ""
+
+ @classmethod
+ @lru_cache(maxsize=None)
+ def _load_host_dependency_requirements(cls, project_root: Path) -> Dict[str, Requirement]:
+ """加载主程序 ``pyproject.toml`` 中声明的依赖约束。
+
+ Args:
+ project_root: 项目根目录。
+
+ Returns:
+ Dict[str, Requirement]: 以规范化包名为键的 Requirement 映射。
+ """
+ pyproject_path = project_root / "pyproject.toml"
+ try:
+ with pyproject_path.open("rb") as pyproject_file:
+ pyproject_data = tomllib.load(pyproject_file)
+ except Exception:
+ return {}
+
+ project_data = pyproject_data.get("project", {})
+ if not isinstance(project_data, dict):
+ return {}
+
+ raw_dependencies = project_data.get("dependencies", [])
+ if not isinstance(raw_dependencies, list):
+ return {}
+
+ requirements: Dict[str, Requirement] = {}
+ for raw_dependency in raw_dependencies:
+ dependency_text = str(raw_dependency or "").strip()
+ if not dependency_text:
+ continue
+
+ try:
+ requirement = Requirement(dependency_text)
+ except InvalidRequirement:
+ continue
+
+ requirements[canonicalize_name(requirement.name)] = requirement
+
+ return requirements
+
+ @staticmethod
+ def _get_installed_package_version(package_name: str) -> Optional[str]:
+ """获取当前运行环境中指定 Python 包的安装版本。
+
+ Args:
+ package_name: 待查询的包名。
+
+ Returns:
+ Optional[str]: 已安装版本号;未安装时返回 ``None``。
+ """
+ try:
+ return importlib_metadata.version(package_name)
+ except importlib_metadata.PackageNotFoundError:
+ return None
+
+ @staticmethod
+ def _build_specifier_set(version_spec: str) -> Optional[SpecifierSet]:
+ """构造版本约束对象。
+
+ Args:
+ version_spec: 版本约束字符串。
+
+ Returns:
+ Optional[SpecifierSet]: 构造成功时返回约束对象,否则返回 ``None``。
+ """
+ try:
+ return SpecifierSet(version_spec)
+ except InvalidSpecifier:
+ return None
+
+ @staticmethod
+ def _version_matches_specifier(version: str, version_spec: str) -> bool:
+ """判断版本是否满足给定约束。
+
+ Args:
+ version: 待判断的版本号。
+ version_spec: 版本约束表达式。
+
+ Returns:
+ bool: 是否满足约束。
+ """
+ try:
+ normalized_version = Version(version)
+ specifier_set = SpecifierSet(version_spec)
+ except (InvalidVersion, InvalidSpecifier):
+ return False
+ return specifier_set.contains(normalized_version, prereleases=True)
+
+ @classmethod
+ def _requirements_may_overlap(cls, left: SpecifierSet, right: SpecifierSet) -> bool:
+ """粗略判断两个版本约束是否存在交集。
+
+ Args:
+ left: 左侧版本约束。
+ right: 右侧版本约束。
+
+ Returns:
+ bool: 若可能存在交集则返回 ``True``,否则返回 ``False``。
+ """
+ candidate_versions = cls._build_candidate_versions(left, right)
+ for candidate_version in candidate_versions:
+ if left.contains(candidate_version, prereleases=True) and right.contains(candidate_version, prereleases=True):
+ return True
+ return False
+
+ @classmethod
+ def _build_candidate_versions(cls, left: SpecifierSet, right: SpecifierSet) -> List[Version]:
+ """为两个版本约束构造一组用于交集探测的候选版本。
+
+ Args:
+ left: 左侧版本约束。
+ right: 右侧版本约束。
+
+ Returns:
+ List[Version]: 去重后的候选版本列表。
+ """
+ candidate_versions: List[Version] = [Version("0.0.0")]
+ for specifier in tuple(left) + tuple(right):
+ for candidate_version in cls._expand_candidate_versions(specifier.version):
+ if candidate_version not in candidate_versions:
+ candidate_versions.append(candidate_version)
+ return candidate_versions
+
+ @staticmethod
+ def _expand_candidate_versions(raw_version: str) -> List[Version]:
+ """根据边界版本扩展出一组邻近候选版本。
+
+ Args:
+ raw_version: 约束中出现的边界版本字符串。
+
+ Returns:
+ List[Version]: 可用于交集探测的候选版本列表。
+ """
+ normalized_text = raw_version.replace("*", "0")
+ try:
+ boundary_version = Version(normalized_text)
+ except InvalidVersion:
+ return []
+
+ release_parts = list(boundary_version.release[:3])
+ while len(release_parts) < 3:
+ release_parts.append(0)
+ major, minor, patch = release_parts[:3]
+
+ candidates = {
+ Version(f"{major}.{minor}.{patch}"),
+ Version(f"{major}.{minor}.{patch + 1}"),
+ }
+ if patch > 0:
+ candidates.add(Version(f"{major}.{minor}.{patch - 1}"))
+ elif minor > 0:
+ candidates.add(Version(f"{major}.{minor - 1}.999"))
+ elif major > 0:
+ candidates.add(Version(f"{major - 1}.999.999"))
+
+ return sorted(candidates)
+
+ @classmethod
+ def _format_validation_errors(cls, exc: ValidationError) -> List[str]:
+ """将 Pydantic 校验错误转换为中文错误列表。
+
+ Args:
+ exc: Pydantic 抛出的校验异常。
+
+ Returns:
+ List[str]: 中文错误描述列表。
+ """
+ error_messages: List[str] = []
+ for error in exc.errors():
+ location = cls._format_error_location(error.get("loc", ()))
+ error_type = str(error.get("type", ""))
+ error_input = error.get("input")
+ error_context = error.get("ctx", {}) or {}
+
+ if error_type == "missing":
+ error_messages.append(f"缺少必需字段: {location}")
+ elif error_type == "extra_forbidden":
+ error_messages.append(f"存在未声明字段: {location}")
+ elif error_type == "literal_error":
+ expected_values = error_context.get("expected")
+ error_messages.append(f"字段 {location} 的值不合法,必须为 {expected_values}")
+ elif error_type == "model_type":
+ error_messages.append(f"字段 {location} 必须为对象")
+ elif error_type.endswith("_type"):
+ error_messages.append(f"字段 {location} 的类型不正确")
+ elif error_type == "value_error":
+ error_messages.append(f"字段 {location} 校验失败: {error_context.get('error')}")
+ else:
+ error_messages.append(f"字段 {location} 校验失败: {error.get('msg', error_input)}")
+
+ return error_messages
+
+ @staticmethod
+ def _format_error_location(location: Tuple[Any, ...]) -> str:
+ """格式化校验错误字段路径。
+
+ Args:
+ location: Pydantic 提供的字段路径元组。
+
+ Returns:
+ str: 点号连接后的字段路径。
+ """
+ return ".".join(str(item) for item in location) if location else ""
diff --git a/src/plugin_runtime/runner/plugin_loader.py b/src/plugin_runtime/runner/plugin_loader.py
index 11ba45e7..6e85714b 100644
--- a/src/plugin_runtime/runner/plugin_loader.py
+++ b/src/plugin_runtime/runner/plugin_loader.py
@@ -13,16 +13,16 @@ from typing import Any, Dict, Iterator, List, Optional, Set, Tuple
import contextlib
import importlib
import importlib.util
-import json
import os
+import re
import sys
from src.common.logger import get_logger
-from src.plugin_runtime.runner.manifest_validator import ManifestValidator
+from src.plugin_runtime.runner.manifest_validator import ManifestValidator, PluginManifest
logger = get_logger("plugin_runtime.runner.plugin_loader")
-PluginCandidate = Tuple[Path, Dict[str, Any], Path]
+PluginCandidate = Tuple[Path, PluginManifest, Path]
class PluginMeta:
@@ -32,28 +32,28 @@ class PluginMeta:
self,
plugin_id: str,
plugin_dir: str,
+ module_name: str,
plugin_instance: Any,
- manifest: Dict[str, Any],
+ manifest: PluginManifest,
) -> None:
+ """初始化插件元数据。
+
+ Args:
+ plugin_id: 插件 ID。
+ plugin_dir: 插件目录绝对路径。
+ module_name: 插件入口模块名。
+ plugin_instance: 插件实例对象。
+ manifest: 解析后的强类型 Manifest。
+ """
self.plugin_id = plugin_id
self.plugin_dir = plugin_dir
+ self.module_name = module_name
self.instance = plugin_instance
self.manifest = manifest
- self.version = manifest.get("version", "1.0.0")
- self.capabilities_required = manifest.get("capabilities", [])
- self.dependencies: List[str] = self._extract_dependencies(manifest)
-
- @staticmethod
- def _extract_dependencies(manifest: Dict[str, Any]) -> List[str]:
- raw = manifest.get("dependencies", [])
- result: List[str] = []
- for dep in raw:
- if isinstance(dep, str):
- result.append(dep.strip())
- elif isinstance(dep, dict):
- if name := str(dep.get("name", "")).strip():
- result.append(name)
- return result
+ self.version = manifest.version
+ self.capabilities_required = list(manifest.capabilities)
+ self.dependencies: List[str] = list(manifest.plugin_dependency_ids)
+ self.component_handlers: Dict[str, str] = {}
class PluginLoader:
@@ -66,30 +66,52 @@ class PluginLoader:
"""
def __init__(self, host_version: str = "") -> None:
+ """初始化插件加载器。
+
+ Args:
+ host_version: Host 版本号,用于 manifest 兼容性校验。
+ """
self._loaded_plugins: Dict[str, PluginMeta] = {}
self._failed_plugins: Dict[str, str] = {}
self._manifest_validator = ManifestValidator(host_version=host_version)
self._compat_hook_installed = False
- def discover_and_load(self, plugin_dirs: List[str]) -> List[PluginMeta]:
- """扫描多个目录并加载所有插件(含依赖排序和 manifest 校验)
+ def discover_and_load(
+ self,
+ plugin_dirs: List[str],
+ extra_available: Optional[Dict[str, str]] = None,
+ ) -> List[PluginMeta]:
+ """扫描多个目录并加载所有插件。
Args:
- plugin_dirs: 插件目录列表
+ plugin_dirs: 插件目录列表。
+ extra_available: 额外视为已满足的外部依赖插件版本映射。
Returns:
- 成功加载的插件元数据列表(按依赖顺序)
+ List[PluginMeta]: 成功加载的插件元数据列表,按依赖顺序排列。
"""
candidates, duplicate_candidates = self._discover_candidates(plugin_dirs)
self._record_duplicate_candidates(duplicate_candidates)
# 第二阶段:依赖解析(拓扑排序)
- load_order, failed_deps = self._resolve_dependencies(candidates)
+ load_order, failed_deps = self._resolve_dependencies(candidates, extra_available=extra_available)
self._record_failed_dependencies(failed_deps)
# 第三阶段:按依赖顺序加载
return self._load_plugins_in_order(load_order, candidates)
+ def discover_candidates(self, plugin_dirs: List[str]) -> Tuple[Dict[str, PluginCandidate], Dict[str, List[Path]]]:
+ """扫描插件目录并返回候选插件。
+
+ Args:
+ plugin_dirs: 需要扫描的插件根目录列表。
+
+ Returns:
+ Tuple[Dict[str, PluginCandidate], Dict[str, List[Path]]]:
+ 候选插件映射和重复插件 ID 冲突映射。
+ """
+ return self._discover_candidates(plugin_dirs)
+
def _discover_candidates(self, plugin_dirs: List[str]) -> Tuple[Dict[str, PluginCandidate], Dict[str, List[Path]]]:
"""扫描插件目录并收集候选插件。"""
candidates: Dict[str, PluginCandidate] = {}
@@ -123,26 +145,17 @@ class PluginLoader:
def _discover_single_candidate(self, plugin_dir: Path) -> Optional[Tuple[str, PluginCandidate]]:
"""发现并校验单个插件目录。"""
- manifest_path = plugin_dir / "_manifest.json"
plugin_path = plugin_dir / "plugin.py"
-
- if not manifest_path.exists() or not plugin_path.exists():
+ if not plugin_path.exists():
return None
- try:
- with manifest_path.open("r", encoding="utf-8") as manifest_file:
- manifest: Dict[str, Any] = json.load(manifest_file)
- except Exception as e:
- self._failed_plugins[plugin_dir.name] = f"manifest 解析失败: {e}"
- logger.error(f"插件 {plugin_dir.name} manifest 解析失败: {e}")
- return None
-
- if not self._manifest_validator.validate(manifest):
+ manifest = self._manifest_validator.load_from_plugin_path(plugin_dir)
+ if manifest is None:
errors = "; ".join(self._manifest_validator.errors)
self._failed_plugins[plugin_dir.name] = f"manifest 校验失败: {errors}"
return None
- plugin_id = str(manifest.get("name", plugin_dir.name)).strip() or plugin_dir.name
+ plugin_id = manifest.id
return plugin_id, (plugin_dir, manifest, plugin_path)
def _record_duplicate_candidates(self, duplicate_candidates: Dict[str, List[Path]]) -> None:
@@ -170,7 +183,6 @@ class PluginLoader:
plugin_dir, manifest, plugin_path = candidates[plugin_id]
try:
if meta := self._load_single_plugin(plugin_id, plugin_dir, manifest, plugin_path):
- self._loaded_plugins[meta.plugin_id] = meta
results.append(meta)
except Exception as e:
self._failed_plugins[plugin_id] = str(e)
@@ -182,45 +194,193 @@ class PluginLoader:
"""获取已加载的插件"""
return self._loaded_plugins.get(plugin_id)
+ def set_loaded_plugin(self, meta: PluginMeta) -> None:
+ """登记一个已经完成初始化的插件。
+
+ Args:
+ meta: 待登记的插件元数据。
+ """
+ self._loaded_plugins[meta.plugin_id] = meta
+
+ def remove_loaded_plugin(self, plugin_id: str) -> Optional[PluginMeta]:
+ """移除一个已加载插件的元数据。
+
+ Args:
+ plugin_id: 待移除的插件 ID。
+
+ Returns:
+ Optional[PluginMeta]: 被移除的插件元数据;不存在时返回 ``None``。
+ """
+ return self._loaded_plugins.pop(plugin_id, None)
+
+ def purge_plugin_modules(self, plugin_id: str, plugin_dir: str) -> List[str]:
+ """清理指定插件目录下的模块缓存。
+
+ Args:
+ plugin_id: 插件 ID。
+ plugin_dir: 插件目录绝对路径。
+
+ Returns:
+ List[str]: 已从 ``sys.modules`` 中移除的模块名列表。
+ """
+ removed_modules: List[str] = []
+ plugin_path = Path(plugin_dir).resolve()
+ synthetic_module_name = self._build_safe_module_name(plugin_id)
+
+ for module_name, module in list(sys.modules.items()):
+ if module_name == synthetic_module_name:
+ removed_modules.append(module_name)
+ sys.modules.pop(module_name, None)
+ continue
+
+ module_file = getattr(module, "__file__", None)
+ if module_file is None:
+ continue
+
+ try:
+ module_path = Path(module_file).resolve()
+ except Exception:
+ continue
+
+ if module_path.is_relative_to(plugin_path):
+ removed_modules.append(module_name)
+ sys.modules.pop(module_name, None)
+
+ importlib.invalidate_caches()
+ return removed_modules
+
+ @staticmethod
+ def _build_safe_module_name(plugin_id: str) -> str:
+ """将插件 ID 转换为可用于动态导入的安全模块名。
+
+ Args:
+ plugin_id: 原始插件 ID。
+
+ Returns:
+ str: 仅包含字母、数字和下划线的合成模块名。
+ """
+ normalized_plugin_id = re.sub(r"[^0-9A-Za-z_]", "_", str(plugin_id or "").strip())
+ if normalized_plugin_id and normalized_plugin_id[0].isdigit():
+ normalized_plugin_id = f"_{normalized_plugin_id}"
+ return f"_maibot_plugin_{normalized_plugin_id or 'plugin'}"
+
def list_plugins(self) -> List[str]:
"""列出所有已加载的插件 ID"""
return list(self._loaded_plugins.keys())
@property
def failed_plugins(self) -> Dict[str, str]:
+ """返回当前记录的失败插件原因映射。"""
return dict(self._failed_plugins)
+ @property
+ def manifest_validator(self) -> ManifestValidator:
+ """返回当前加载器持有的 Manifest 校验器。
+
+ Returns:
+ ManifestValidator: 当前使用的 Manifest 校验器实例。
+ """
+ return self._manifest_validator
+
# ──── 依赖解析 ────────────────────────────────────────────
+ def resolve_dependencies(
+ self,
+ candidates: Dict[str, PluginCandidate],
+ extra_available: Optional[Dict[str, str]] = None,
+ ) -> Tuple[List[str], Dict[str, str]]:
+ """解析候选插件的依赖顺序。
+
+ Args:
+ candidates: 待加载的候选插件集合。
+ extra_available: 视为已满足的外部依赖插件版本映射。
+
+ Returns:
+ Tuple[List[str], Dict[str, str]]: 可加载顺序和失败原因映射。
+ """
+ return self._resolve_dependencies(candidates, extra_available=extra_available)
+
+ def load_candidate(self, plugin_id: str, candidate: PluginCandidate) -> Optional[PluginMeta]:
+ """加载单个候选插件模块。
+
+ Args:
+ plugin_id: 插件 ID。
+ candidate: 候选插件三元组。
+
+ Returns:
+ Optional[PluginMeta]: 加载成功的插件元数据;失败时返回 ``None``。
+ """
+ plugin_dir, manifest, plugin_path = candidate
+ return self._load_single_plugin(plugin_id, plugin_dir, manifest, plugin_path)
+
def _resolve_dependencies(
self,
candidates: Dict[str, PluginCandidate],
+ extra_available: Optional[Dict[str, str]] = None,
) -> Tuple[List[str], Dict[str, str]]:
"""拓扑排序解析加载顺序,返回 (有序列表, 失败项 {id: reason})。"""
available = set(candidates.keys())
+ satisfied_dependencies = {
+ str(plugin_id or "").strip(): str(plugin_version or "").strip()
+ for plugin_id, plugin_version in (extra_available or {}).items()
+ if str(plugin_id or "").strip() and str(plugin_version or "").strip()
+ }
dep_graph: Dict[str, Set[str]] = {}
failed: Dict[str, str] = {}
for pid, (_, manifest, _) in candidates.items():
- raw_deps = manifest.get("dependencies", [])
resolved: Set[str] = set()
- missing: List[str] = []
- for dep in raw_deps:
- dep_name = dep if isinstance(dep, str) else str(dep.get("name", ""))
- dep_name = dep_name.strip()
- if not dep_name or dep_name == pid:
+ missing_or_incompatible: List[str] = []
+
+ for dependency in manifest.plugin_dependencies:
+ dependency_id = dependency.id
+ if dependency_id in available:
+ dependency_manifest = candidates[dependency_id][1]
+ if not self._manifest_validator.is_plugin_dependency_satisfied(
+ dependency,
+ dependency_manifest.version,
+ ):
+ missing_or_incompatible.append(
+ f"{dependency_id} (需要 {dependency.version_spec},当前 {dependency_manifest.version})"
+ )
+ continue
+ resolved.add(dependency_id)
continue
- if dep_name in available:
- resolved.add(dep_name)
- else:
- missing.append(dep_name)
- if missing:
- failed[pid] = f"缺少依赖: {', '.join(missing)}"
+
+ external_dependency_version = satisfied_dependencies.get(dependency_id)
+ if external_dependency_version is None:
+ missing_or_incompatible.append(f"{dependency_id} (未找到依赖插件)")
+ continue
+
+ if not self._manifest_validator.is_plugin_dependency_satisfied(
+ dependency,
+ external_dependency_version,
+ ):
+ missing_or_incompatible.append(
+ f"{dependency_id} (需要 {dependency.version_spec},当前 {external_dependency_version})"
+ )
+
+ if missing_or_incompatible:
+ failed[pid] = f"依赖未满足: {', '.join(missing_or_incompatible)}"
dep_graph[pid] = resolved
- # 移除失败项
- for pid in failed:
- dep_graph.pop(pid, None)
+ # 迭代传播“依赖自身加载失败”到上游依赖方,避免误报为循环依赖
+ changed = True
+ while changed:
+ changed = False
+ failed_plugin_ids = set(failed)
+ for pid, dependencies in list(dep_graph.items()):
+ if pid in failed:
+ dep_graph.pop(pid, None)
+ continue
+
+ failed_dependencies = sorted(dependency for dependency in dependencies if dependency in failed_plugin_ids)
+ if not failed_dependencies:
+ continue
+
+ failed[pid] = f"依赖未满足: {', '.join(f'{dependency} (依赖插件加载失败)' for dependency in failed_dependencies)}"
+ dep_graph.pop(pid, None)
+ changed = True
# Kahn 拓扑排序
indegree = {pid: len(deps) for pid, deps in dep_graph.items()}
@@ -253,7 +413,7 @@ class PluginLoader:
self,
plugin_id: str,
plugin_dir: Path,
- manifest: Dict[str, Any],
+ manifest: PluginManifest,
plugin_path: Path,
) -> Optional[PluginMeta]:
"""加载单个插件"""
@@ -261,8 +421,12 @@ class PluginLoader:
self._ensure_compat_hook()
# 动态导入插件模块
- module_name = f"_maibot_plugin_{plugin_id}"
- spec = importlib.util.spec_from_file_location(module_name, str(plugin_path))
+ module_name = self._build_safe_module_name(plugin_id)
+ spec = importlib.util.spec_from_file_location(
+ module_name,
+ str(plugin_path),
+ submodule_search_locations=[str(plugin_dir)],
+ )
if spec is None or spec.loader is None:
logger.error(f"无法创建模块 spec: {plugin_path}")
return None
@@ -271,37 +435,73 @@ class PluginLoader:
sys.modules[module_name] = module
plugin_parent_dir = plugin_dir.parent
- with self._temporary_sys_path_entry(plugin_parent_dir):
- spec.loader.exec_module(module)
+ try:
+ with self._temporary_sys_path_entry(plugin_parent_dir):
+ spec.loader.exec_module(module)
- # 优先使用新版 create_plugin 工厂函数
- create_plugin = getattr(module, "create_plugin", None)
- if create_plugin is not None:
- instance = create_plugin()
- logger.info(f"插件 {plugin_id} v{manifest.get('version', '?')} 加载成功")
- return PluginMeta(
- plugin_id=plugin_id,
- plugin_dir=str(plugin_dir),
- plugin_instance=instance,
- manifest=manifest,
- )
+ # 优先使用新版 create_plugin 工厂函数
+ create_plugin = getattr(module, "create_plugin", None)
+ if create_plugin is not None:
+ instance = create_plugin()
+ self._validate_sdk_plugin_contract(plugin_id, instance)
+ logger.info(f"插件 {plugin_id} v{manifest.version} 加载成功")
+ return PluginMeta(
+ plugin_id=plugin_id,
+ plugin_dir=str(plugin_dir),
+ module_name=module_name,
+ plugin_instance=instance,
+ manifest=manifest,
+ )
- # 回退:检测旧版 @register_plugin 标记的 BasePlugin 子类
- instance = self._try_load_legacy_plugin(module, plugin_id)
- if instance is not None:
- logger.info(
- f"插件 {plugin_id} v{manifest.get('version', '?')} 通过旧版兼容层加载成功(请尽快迁移到 maibot_sdk)"
- )
- return PluginMeta(
- plugin_id=plugin_id,
- plugin_dir=str(plugin_dir),
- plugin_instance=instance,
- manifest=manifest,
- )
+ # 回退:检测旧版 @register_plugin 标记的 BasePlugin 子类
+ instance = self._try_load_legacy_plugin(module, plugin_id)
+ if instance is not None:
+ logger.info(
+ f"插件 {plugin_id} v{manifest.version} 通过旧版兼容层加载成功(请尽快迁移到 maibot_sdk)"
+ )
+ return PluginMeta(
+ plugin_id=plugin_id,
+ plugin_dir=str(plugin_dir),
+ module_name=module_name,
+ plugin_instance=instance,
+ manifest=manifest,
+ )
+ except Exception:
+ sys.modules.pop(module_name, None)
+ raise
logger.error(f"插件 {plugin_id} 缺少 create_plugin 工厂函数且未检测到旧版 BasePlugin")
return None
+ @staticmethod
+ def _validate_sdk_plugin_contract(plugin_id: str, instance: Any) -> None:
+ """校验 SDK 插件的基础契约。
+
+ Args:
+ plugin_id: 当前插件 ID。
+ instance: ``create_plugin()`` 返回的插件实例。
+
+ Raises:
+ TypeError: 当插件未覆盖必需生命周期方法或订阅声明不合法时抛出。
+ """
+
+ try:
+ from maibot_sdk.plugin import MaiBotPlugin
+ except ImportError:
+ return
+
+ if not isinstance(instance, MaiBotPlugin):
+ return
+
+ if type(instance).on_load is MaiBotPlugin.on_load:
+ raise TypeError(f"插件 {plugin_id} 必须实现 on_load()")
+ if type(instance).on_unload is MaiBotPlugin.on_unload:
+ raise TypeError(f"插件 {plugin_id} 必须实现 on_unload()")
+ if type(instance).on_config_update is MaiBotPlugin.on_config_update:
+ raise TypeError(f"插件 {plugin_id} 必须实现 on_config_update()")
+
+ instance.get_config_reload_subscriptions()
+
@staticmethod
@contextlib.contextmanager
def _temporary_sys_path_entry(path: Path) -> Iterator[None]:
diff --git a/src/plugin_runtime/runner/rpc_client.py b/src/plugin_runtime/runner/rpc_client.py
index 6a1d59d5..dc917cc8 100644
--- a/src/plugin_runtime/runner/rpc_client.py
+++ b/src/plugin_runtime/runner/rpc_client.py
@@ -1,14 +1,6 @@
-"""Runner 端 RPC Client
+"""Runner 端 RPC 客户端。"""
-负责:
-1. 连接 Host RPC Server
-2. 发送握手(runner.hello)
-3. 发送组件注册请求
-4. 接收并分发 Host 的调用请求
-5. 发送能力调用请求到 Host
-"""
-
-from typing import Any, Awaitable, Callable, Dict, Optional, cast
+from typing import Any, Awaitable, Callable, Dict, Optional, Set, cast
import asyncio
import contextlib
@@ -29,12 +21,15 @@ from src.plugin_runtime.transport.factory import create_transport_client
logger = get_logger("plugin_runtime.runner.rpc_client")
-# RPC 方法处理器类型
MethodHandler = Callable[[Envelope], Awaitable[Envelope]]
def _get_sdk_version() -> str:
- """从 maibot_sdk 包元数据中读取实际版本号,失败时回退到 1.0.0。"""
+ """读取 SDK 版本号。
+
+ Returns:
+ str: 已安装的 SDK 版本;读取失败时回退到 ``1.0.0``。
+ """
try:
from importlib.metadata import version
@@ -47,73 +42,78 @@ SDK_VERSION = _get_sdk_version()
class RPCClient:
- """Runner 端 RPC 客户端
-
- 管理与 Host 的 IPC 连接,支持双向 RPC 调用。
- """
+ """Runner 端 RPC 客户端。"""
def __init__(
self,
host_address: str,
session_token: str,
codec: Optional[Codec] = None,
- ):
- self._host_address = host_address
- self._session_token = session_token
- self._codec = codec or MsgPackCodec()
+ ) -> None:
+ """初始化 RPC 客户端。
+
+ Args:
+ host_address: Host 的 IPC 地址。
+ session_token: 握手用会话令牌。
+ codec: 可选的编解码器实现。
+ """
+ self._host_address: str = host_address
+ self._session_token: str = session_token
+ self._codec: Codec = codec or MsgPackCodec()
self._id_gen = RequestIdGenerator()
self._connection: Optional[Connection] = None
- self._runner_id = str(uuid.uuid4())
- self._generation: int = 0
-
- # 方法处理器注册表(Host 发来的调用)
+ self._runner_id: str = str(uuid.uuid4())
self._method_handlers: Dict[str, MethodHandler] = {}
-
- # 等待响应的 pending 请求: request_id -> Future
- self._pending_requests: Dict[int, asyncio.Future] = {}
-
- # 运行状态
- self._running = False
- self._recv_task: Optional[asyncio.Task] = None
- self._background_tasks: set[asyncio.Task] = set()
-
- @property
- def generation(self) -> int:
- return self._generation
+ self._pending_requests: Dict[int, asyncio.Future[Envelope]] = {}
+ self._running: bool = False
+ self._recv_task: Optional[asyncio.Task[None]] = None
+ self._background_tasks: Set[asyncio.Task[Any]] = set()
@property
def is_connected(self) -> bool:
+ """返回当前连接是否可用。"""
return self._connection is not None and not self._connection.is_closed
def register_method(self, method: str, handler: MethodHandler) -> None:
- """注册方法处理器(处理 Host 发来的请求)"""
+ """注册 Host -> Runner 的 RPC 处理器。
+
+ Args:
+ method: RPC 方法名。
+ handler: 方法处理函数。
+ """
self._method_handlers[method] = handler
def _require_connection(self) -> Connection:
- """返回当前可用连接;若连接不可用则抛出 RPCError。"""
+ """返回当前可用连接。
+
+ Returns:
+ Connection: 当前连接对象。
+
+ Raises:
+ RPCError: 当前未连接到 Host。
+ """
connection = self._connection
if connection is None or connection.is_closed:
raise RPCError(ErrorCode.E_UNKNOWN, "未连接到 Host")
return cast(Connection, connection)
async def connect_and_handshake(self) -> bool:
- """连接 Host 并完成握手
+ """连接 Host 并完成握手。
Returns:
- 是否握手成功
+ bool: 是否握手成功。
"""
client = create_transport_client(self._host_address)
self._connection = await client.connect()
connection = self._require_connection()
- # 发送 runner.hello
hello = HelloPayload(
runner_id=self._runner_id,
sdk_version=SDK_VERSION,
session_token=self._session_token,
)
- request_id = self._id_gen.next()
+ request_id = await self._id_gen.next()
envelope = Envelope(
request_id=request_id,
message_type=MessageType.REQUEST,
@@ -121,33 +121,27 @@ class RPCClient:
payload=hello.model_dump(),
)
- data = self._codec.encode_envelope(envelope)
- await connection.send_frame(data)
+ await connection.send_frame(self._codec.encode_envelope(envelope))
- # 接收握手响应
resp_data = await asyncio.wait_for(connection.recv_frame(), timeout=10.0)
- resp = self._codec.decode_envelope(resp_data)
+ response = self._codec.decode_envelope(resp_data)
+ resp_payload = HelloResponsePayload.model_validate(response.payload)
- resp_payload = HelloResponsePayload.model_validate(resp.payload)
if not resp_payload.accepted:
logger.error(f"握手被拒绝: {resp_payload.reason}")
- await self._connection.close()
- self._connection = None
+ await self.disconnect()
return False
- self._generation = resp_payload.assigned_generation
- logger.info(f"握手成功: generation={self._generation}, host_version={resp_payload.host_version}")
-
- # 启动消息接收循环
+ logger.info(f"握手成功: host_version={resp_payload.host_version}")
self._running = True
- self._recv_task = asyncio.create_task(self._recv_loop())
-
+ self._recv_task = asyncio.create_task(self._recv_loop(), name="RPCClient.recv")
return True
async def disconnect(self) -> None:
- """断开连接"""
+ """断开与 Host 的连接并清理状态。"""
self._running = False
- if self._recv_task:
+
+ if self._recv_task is not None:
self._recv_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._recv_task
@@ -160,13 +154,12 @@ class RPCClient:
await asyncio.gather(*self._background_tasks, return_exceptions=True)
self._background_tasks.clear()
- # 取消所有 pending 请求
for future in self._pending_requests.values():
if not future.done():
future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "连接关闭"))
self._pending_requests.clear()
- if self._connection:
+ if self._connection is not None:
await self._connection.close()
self._connection = None
@@ -177,16 +170,27 @@ class RPCClient:
payload: Optional[Dict[str, Any]] = None,
timeout_ms: int = 30000,
) -> Envelope:
- """向 Host 发送 RPC 请求并等待响应"""
- connection = self._require_connection()
+ """向 Host 发送 RPC 请求并等待响应。
- request_id = self._id_gen.next()
+ Args:
+ method: RPC 方法名。
+ plugin_id: 目标插件 ID。
+ payload: 请求载荷。
+ timeout_ms: 超时时间,单位毫秒。
+
+ Returns:
+ Envelope: Host 返回的响应信封。
+
+ Raises:
+ RPCError: 发送失败、超时或连接异常。
+ """
+ connection = self._require_connection()
+ request_id = await self._id_gen.next()
envelope = Envelope(
request_id=request_id,
message_type=MessageType.REQUEST,
method=method,
plugin_id=plugin_id,
- generation=self._generation,
timeout_ms=timeout_ms,
payload=payload or {},
)
@@ -196,21 +200,16 @@ class RPCClient:
self._pending_requests[request_id] = future
try:
- data = self._codec.encode_envelope(envelope)
- await connection.send_frame(data)
-
- timeout_sec = timeout_ms / 1000.0
- return await asyncio.wait_for(future, timeout=timeout_sec)
+ await connection.send_frame(self._codec.encode_envelope(envelope))
+ return await asyncio.wait_for(future, timeout=timeout_ms / 1000.0)
except asyncio.TimeoutError:
self._pending_requests.pop(request_id, None)
raise RPCError(ErrorCode.E_TIMEOUT, f"请求 {method} 超时 ({timeout_ms}ms)") from None
- except Exception as e:
+ except Exception as exc:
self._pending_requests.pop(request_id, None)
- if isinstance(e, RPCError):
+ if isinstance(exc, RPCError):
raise
- raise RPCError(ErrorCode.E_UNKNOWN, str(e)) from e
-
- # ─── 内部方法 ──────────────────────────────────────────────
+ raise RPCError(ErrorCode.E_UNKNOWN, str(exc)) from exc
async def send_event(
self,
@@ -218,33 +217,30 @@ class RPCClient:
plugin_id: str = "",
payload: Optional[Dict[str, Any]] = None,
) -> None:
- """向 Host 发送单向事件(fire-and-forget,不等待响应)。
+ """向 Host 发送单向广播消息。
Args:
- method: RPC 方法名,如 "runner.log_batch"。
- plugin_id: 目标插件 ID(可为空,表示 Runner 级消息)。
- payload: 事件数据。
+ method: RPC 方法名。
+ plugin_id: 目标插件 ID。
+ payload: 广播载荷。
"""
if not self.is_connected:
return
connection = self._require_connection()
-
- request_id = self._id_gen.next()
+ request_id = await self._id_gen.next()
envelope = Envelope(
request_id=request_id,
- message_type=MessageType.EVENT,
+ message_type=MessageType.BROADCAST,
method=method,
plugin_id=plugin_id,
- generation=self._generation,
payload=payload or {},
)
- data = self._codec.encode_envelope(envelope)
- await connection.send_frame(data)
+ await connection.send_frame(self._codec.encode_envelope(envelope))
async def _recv_loop(self) -> None:
- """消息接收主循环"""
- while self._running and self._connection and not self._connection.is_closed:
+ """持续接收 Host 发来的消息并分发。"""
+ while self._running and self._connection is not None and not self._connection.is_closed:
try:
data = await self._connection.recv_frame()
except (asyncio.IncompleteReadError, ConnectionError):
@@ -252,39 +248,47 @@ class RPCClient:
break
except asyncio.CancelledError:
break
- except Exception as e:
- logger.error(f"接收帧失败: {e}")
+ except Exception as exc:
+ logger.error(f"接收帧失败: {exc}")
break
try:
envelope = self._codec.decode_envelope(data)
- except Exception as e:
- logger.error(f"解码消息失败: {e}")
+ except Exception as exc:
+ logger.error(f"解码消息失败: {exc}")
continue
if envelope.is_response():
self._handle_response(envelope)
elif envelope.is_request():
self._track_background_task(asyncio.create_task(self._handle_request(envelope)))
- elif envelope.is_event():
- self._track_background_task(asyncio.create_task(self._handle_event(envelope)))
+ elif envelope.is_broadcast():
+ self._track_background_task(asyncio.create_task(self._handle_broadcast(envelope)))
def _handle_response(self, envelope: Envelope) -> None:
- """处理来自 Host 的响应"""
+ """处理 Host 返回的响应。
+
+ Args:
+ envelope: 响应信封。
+ """
future = self._pending_requests.pop(envelope.request_id, None)
- if future and not future.done():
- if envelope.error:
- future.set_exception(RPCError.from_dict(envelope.error))
- else:
- future.set_result(envelope)
+ if future is None or future.done():
+ return
+ if envelope.error:
+ future.set_exception(RPCError.from_dict(envelope.error))
+ else:
+ future.set_result(envelope)
async def _handle_request(self, envelope: Envelope) -> None:
- """处理来自 Host 的请求(调用插件组件)"""
+ """处理 Host 发来的请求。
+
+ Args:
+ envelope: 请求信封。
+ """
connection = self._connection
if connection is None or connection.is_closed:
logger.warning(f"处理请求 {envelope.method} 时连接已关闭,跳过响应")
return
- connection = cast(Connection, connection)
handler = self._method_handlers.get(envelope.method)
if handler is None:
@@ -298,23 +302,34 @@ class RPCClient:
try:
response = await handler(envelope)
await connection.send_frame(self._codec.encode_envelope(response))
- except RPCError as e:
- error_resp = envelope.make_error_response(e.code.value, e.message, e.details)
+ except RPCError as exc:
+ error_resp = envelope.make_error_response(exc.code.value, exc.message, exc.details)
await connection.send_frame(self._codec.encode_envelope(error_resp))
- except Exception as e:
- logger.error(f"处理请求 {envelope.method} 异常: {e}", exc_info=True)
- error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
+ except Exception as exc:
+ logger.error(f"处理请求 {envelope.method} 异常: {exc}", exc_info=True)
+ error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(exc))
await connection.send_frame(self._codec.encode_envelope(error_resp))
- async def _handle_event(self, envelope: Envelope) -> None:
- """处理来自 Host 的事件"""
- if handler := self._method_handlers.get(envelope.method):
- try:
- await handler(envelope)
- except Exception as e:
- logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True)
+ async def _handle_broadcast(self, envelope: Envelope) -> None:
+ """处理 Host 发来的广播事件。
- def _track_background_task(self, task: asyncio.Task) -> None:
- """保持后台任务强引用,直到其完成或被取消。"""
+ Args:
+ envelope: 广播信封。
+ """
+ handler = self._method_handlers.get(envelope.method)
+ if handler is None:
+ return
+
+ try:
+ await handler(envelope)
+ except Exception as exc:
+ logger.error(f"处理广播 {envelope.method} 异常: {exc}", exc_info=True)
+
+ def _track_background_task(self, task: asyncio.Task[Any]) -> None:
+ """持有后台任务强引用直到其结束。
+
+ Args:
+ task: 后台任务。
+ """
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py
index dae1cfa1..d1ebc064 100644
--- a/src/plugin_runtime/runner/runner_main.py
+++ b/src/plugin_runtime/runner/runner_main.py
@@ -9,13 +9,13 @@
6. 转发插件的能力调用到 Host
"""
-from typing import Any, Callable, List, Optional, Protocol, cast
-
from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, Protocol, Set, cast
import asyncio
import contextlib
import inspect
+import json
import logging as stdlib_logging
import os
import signal
@@ -24,27 +24,59 @@ import time
import tomllib
from src.common.logger import get_console_handler, get_logger, initialize_logging
-from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN
+from src.plugin_runtime import (
+ ENV_EXTERNAL_PLUGIN_IDS,
+ ENV_HOST_VERSION,
+ ENV_IPC_ADDRESS,
+ ENV_PLUGIN_DIRS,
+ ENV_SESSION_TOKEN,
+)
from src.plugin_runtime.protocol.envelope import (
BootstrapPluginPayload,
ComponentDeclaration,
+ ConfigUpdatedPayload,
Envelope,
HealthPayload,
InvokePayload,
InvokeResultPayload,
- RegisterComponentsPayload,
+ RegisterPluginPayload,
+ ReloadPluginPayload,
+ ReloadPluginResultPayload,
+ ReloadPluginsPayload,
+ ReloadPluginsResultPayload,
RunnerReadyPayload,
+ UnregisterPluginPayload,
)
from src.plugin_runtime.protocol.errors import ErrorCode
from src.plugin_runtime.runner.log_handler import RunnerIPCLogHandler
-from src.plugin_runtime.runner.plugin_loader import PluginLoader, PluginMeta
+from src.plugin_runtime.runner.plugin_loader import PluginCandidate, PluginLoader, PluginMeta
from src.plugin_runtime.runner.rpc_client import RPCClient
logger = get_logger("plugin_runtime.runner.main")
+_PLUGIN_ALLOWED_RAW_HOST_METHODS = frozenset(
+ {
+ "cap.call",
+ "host.route_message",
+ "host.update_message_gateway_state",
+ }
+)
+
class _ContextAwarePlugin(Protocol):
- def _set_context(self, context: Any) -> None: ...
+ """支持注入运行时上下文的插件协议。
+
+ 该协议用于描述 Runner 在激活插件时依赖的最小接口。
+ 只要插件实例实现了 ``_set_context`` 方法,就可以被 Runner
+ 注入 ``PluginContext`` 或兼容层上下文对象。
+ """
+
+ def _set_context(self, context: Any) -> None:
+ """为插件实例注入运行时上下文。
+
+ Args:
+ context: 由 Runner 构造的上下文对象。
+ """
def _install_shutdown_signal_handlers(
@@ -89,22 +121,37 @@ class PluginRunner:
host_address: str,
session_token: str,
plugin_dirs: List[str],
+ external_available_plugins: Optional[Dict[str, str]] = None,
) -> None:
+ """初始化 Runner。
+
+ Args:
+ host_address: Host 的 IPC 地址。
+ session_token: 握手用会话令牌。
+ plugin_dirs: 当前 Runner 负责扫描的插件目录列表。
+ external_available_plugins: 视为已满足的外部依赖插件版本映射。
+ """
self._host_address: str = host_address
self._session_token: str = session_token
- self._plugin_dirs: list[str] = plugin_dirs
+ self._plugin_dirs: List[str] = plugin_dirs
+ self._external_available_plugins: Dict[str, str] = {
+ str(plugin_id or "").strip(): str(plugin_version or "").strip()
+ for plugin_id, plugin_version in (external_available_plugins or {}).items()
+ if str(plugin_id or "").strip() and str(plugin_version or "").strip()
+ }
self._rpc_client: RPCClient = RPCClient(host_address, session_token)
self._loader: PluginLoader = PluginLoader(host_version=os.getenv(ENV_HOST_VERSION, ""))
self._start_time: float = time.monotonic()
self._shutting_down: bool = False
+ self._reload_lock: asyncio.Lock = asyncio.Lock()
# IPC 日志 Handler:握手成功后安装,将所有 stdlib logging 转发到 Host
self._log_handler: Optional[RunnerIPCLogHandler] = None
- self._suspended_console_handlers: list[stdlib_logging.Handler] = []
+ self._suspended_console_handlers: List[stdlib_logging.Handler] = []
async def run(self) -> None:
- """Runner 主入口"""
+ """运行 Runner 主循环。"""
# 1. 连接 Host
logger.info(f"Runner 启动,连接 Host: {self._host_address}")
ok = await self._rpc_client.connect_and_handshake()
@@ -119,36 +166,18 @@ class PluginRunner:
self._register_handlers()
# 3. 加载插件
- plugins = self._loader.discover_and_load(self._plugin_dirs)
+ plugins = self._loader.discover_and_load(
+ self._plugin_dirs,
+ extra_available=self._external_available_plugins,
+ )
logger.info(f"已加载 {len(plugins)} 个插件")
# 4. 注入 PluginContext + 调用 on_load 生命周期钩子
- failed_plugins: set[str] = set()
+ failed_plugins: Set[str] = set(self._loader.failed_plugins.keys())
for meta in plugins:
- instance = meta.instance
- self._inject_context(meta.plugin_id, instance)
- self._apply_plugin_config(meta)
- if not await self._bootstrap_plugin(meta):
- failed_plugins.add(meta.plugin_id)
- continue
- if hasattr(instance, "on_load"):
- try:
- ret = instance.on_load()
- if asyncio.iscoroutine(ret):
- await ret
- except Exception as e:
- logger.error(f"插件 {meta.plugin_id} on_load 失败,跳过注册: {e}", exc_info=True)
- failed_plugins.add(meta.plugin_id)
- await self._deactivate_plugin(meta)
-
- # 5. 向 Host 注册所有插件的组件(跳过 on_load 失败的插件)
- for meta in plugins:
- if meta.plugin_id in failed_plugins:
- continue
- ok = await self._register_plugin(meta)
+ ok = await self._activate_plugin(meta)
if not ok:
failed_plugins.add(meta.plugin_id)
- await self._deactivate_plugin(meta)
successful_plugins = [meta.plugin_id for meta in plugins if meta.plugin_id not in failed_plugins]
await self._notify_ready(successful_plugins, sorted(failed_plugins))
@@ -217,7 +246,7 @@ class PluginRunner:
"""为插件实例创建并注入 PluginContext。
对新版 MaiBotPlugin(具有 _set_context 方法):创建 PluginContext 并注入。
- 对旧版 LegacyPluginAdapter(具有 _set_context 方法,由适配器代理):同上。
+ 对旧版 LegacyPluginAdapter(具有 _set_context 方法,由兼容代理封装):同上。
"""
if not hasattr(instance, "_set_context"):
return
@@ -232,9 +261,11 @@ class PluginRunner:
bound_plugin_id = plugin_id
async def _rpc_call(
- method: str, plugin_id: str = "", payload: Optional[dict[str, Any]] = None
+ method: str,
+ plugin_id: str = "",
+ payload: Optional[Dict[str, Any]] = None,
) -> Any:
- """桥接 PluginContext.call_capability → RPCClient.send_request。
+ """桥接 PluginContext 的原始 RPC 调用到 Host。
无论调用方传入何种 plugin_id,实际发往 Host 的 plugin_id
始终绑定为当前插件实例,避免伪造其他插件身份申请能力。
@@ -243,21 +274,26 @@ class PluginRunner:
logger.warning(
f"插件 {bound_plugin_id} 尝试以 {plugin_id} 身份发起 RPC,已强制绑定回自身身份"
)
+ normalized_method = str(method or "").strip()
+ if normalized_method not in _PLUGIN_ALLOWED_RAW_HOST_METHODS:
+ raise PermissionError(
+ f"插件 {bound_plugin_id} 不允许直接调用 Host 原始 RPC 方法: "
+ f"{normalized_method or ''}"
+ )
resp = await rpc_client.send_request(
- method=method,
+ method=normalized_method,
plugin_id=bound_plugin_id,
payload=payload or {},
)
- # 从响应信封中提取业务结果
if resp.error:
raise RuntimeError(resp.error.get("message", "能力调用失败"))
- return resp.payload.get("result")
+ return resp.payload
ctx = PluginContext(plugin_id=plugin_id, rpc_call=_rpc_call)
cast(_ContextAwarePlugin, instance)._set_context(ctx)
logger.debug(f"已为插件 {plugin_id} 注入 PluginContext")
- def _apply_plugin_config(self, meta: PluginMeta, config_data: Optional[dict[str, Any]] = None) -> None:
+ def _apply_plugin_config(self, meta: PluginMeta, config_data: Optional[Dict[str, Any]] = None) -> None:
"""在 Runner 侧为插件实例注入当前插件配置。"""
instance = meta.instance
if not hasattr(instance, "set_plugin_config"):
@@ -270,7 +306,7 @@ class PluginRunner:
logger.warning(f"插件 {meta.plugin_id} 配置注入失败: {exc}")
@staticmethod
- def _load_plugin_config(plugin_dir: str) -> dict[str, Any]:
+ def _load_plugin_config(plugin_dir: str) -> Dict[str, Any]:
"""从插件目录读取 config.toml。"""
config_path = Path(plugin_dir) / "config.toml"
if not config_path.exists():
@@ -286,16 +322,60 @@ class PluginRunner:
return loaded if isinstance(loaded, dict) else {}
def _register_handlers(self) -> None:
- """注册方法处理器"""
+ """注册 Host -> Runner 的方法处理器。"""
self._rpc_client.register_method("plugin.invoke_command", self._handle_invoke)
self._rpc_client.register_method("plugin.invoke_action", self._handle_invoke)
+ self._rpc_client.register_method("plugin.invoke_api", self._handle_invoke)
self._rpc_client.register_method("plugin.invoke_tool", self._handle_invoke)
+ self._rpc_client.register_method("plugin.invoke_message_gateway", self._handle_invoke)
self._rpc_client.register_method("plugin.emit_event", self._handle_event_invoke)
+ self._rpc_client.register_method("plugin.invoke_hook", self._handle_hook_invoke)
self._rpc_client.register_method("plugin.invoke_workflow_step", self._handle_workflow_step)
self._rpc_client.register_method("plugin.health", self._handle_health)
self._rpc_client.register_method("plugin.prepare_shutdown", self._handle_prepare_shutdown)
self._rpc_client.register_method("plugin.shutdown", self._handle_shutdown)
self._rpc_client.register_method("plugin.config_updated", self._handle_config_updated)
+ self._rpc_client.register_method("plugin.reload", self._handle_reload_plugin)
+ self._rpc_client.register_method("plugin.reload_batch", self._handle_reload_plugins)
+
+ @staticmethod
+ def _resolve_component_handler_name(meta: PluginMeta, component_name: str) -> str:
+ """解析组件名对应的真实处理函数名。
+
+ Args:
+ meta: 已加载插件的元数据。
+ component_name: Host 侧请求中的组件声明名。
+
+ Returns:
+ str: 实际应在插件实例上查找的方法名。
+ """
+ return str(meta.component_handlers.get(component_name, component_name) or component_name)
+
+ def _resolve_component_handler(self, meta: PluginMeta, component_name: str) -> Any:
+ """根据组件声明名解析插件实例上的可调用处理函数。
+
+ Args:
+ meta: 已加载插件的元数据。
+ component_name: Host 侧请求中的组件声明名。
+
+ Returns:
+ Any: 解析到的可调用对象;未找到时返回 ``None``。
+ """
+ instance = meta.instance
+ handler_name = self._resolve_component_handler_name(meta, component_name)
+ handler_method = getattr(instance, handler_name, None)
+ if handler_method is not None:
+ return handler_method
+
+ if handler_name != component_name:
+ legacy_style_handler = getattr(instance, f"handle_{component_name}", None)
+ if legacy_style_handler is not None:
+ return legacy_style_handler
+
+ prefixed_handler = getattr(instance, f"handle_{component_name}", None)
+ if prefixed_handler is not None:
+ return prefixed_handler
+ return getattr(instance, component_name, None)
async def _bootstrap_plugin(self, meta: PluginMeta, capabilities_required: Optional[List[str]] = None) -> bool:
"""向 Host 同步插件 bootstrap 能力令牌。"""
@@ -308,12 +388,14 @@ class PluginRunner:
)
try:
- await self._rpc_client.send_request(
+ response = await self._rpc_client.send_request(
"plugin.bootstrap",
plugin_id=meta.plugin_id,
payload=payload.model_dump(),
timeout_ms=10000,
)
+ if response.error:
+ raise RuntimeError(response.error.get("message", "插件 bootstrap 失败"))
return True
except Exception as e:
logger.error(f"插件 {meta.plugin_id} bootstrap 失败: {e}")
@@ -324,45 +406,500 @@ class PluginRunner:
await self._bootstrap_plugin(meta, capabilities_required=[])
async def _register_plugin(self, meta: PluginMeta) -> bool:
- """向 Host 注册单个插件"""
+ """向 Host 注册单个插件。
+
+ Args:
+ meta: 待注册的插件元数据。
+
+ Returns:
+ bool: 是否注册成功。
+ """
# 收集插件组件声明
components: List[ComponentDeclaration] = []
+ config_reload_subscriptions: List[str] = []
instance = meta.instance
# 从插件实例获取组件声明(SDK 插件须实现 get_components 方法)
if hasattr(instance, "get_components"):
- components.extend(
- ComponentDeclaration(
- name=comp_info.get("name", ""),
- component_type=comp_info.get("type", ""),
- plugin_id=meta.plugin_id,
- metadata=comp_info.get("metadata", {}),
- )
- for comp_info in instance.get_components()
- )
+ meta.component_handlers.clear()
+ for comp_info in instance.get_components():
+ if not isinstance(comp_info, dict):
+ continue
- reg_payload = RegisterComponentsPayload(
+ component_name = str(comp_info.get("name", "") or "").strip()
+ raw_metadata = comp_info.get("metadata", {})
+ component_metadata = raw_metadata if isinstance(raw_metadata, dict) else {}
+
+ if component_name:
+ handler_name = str(component_metadata.get("handler_name", component_name) or component_name).strip()
+ meta.component_handlers[component_name] = handler_name or component_name
+
+ components.append(
+ ComponentDeclaration(
+ name=component_name,
+ component_type=str(comp_info.get("type", "") or "").strip(),
+ plugin_id=meta.plugin_id,
+ metadata=component_metadata,
+ )
+ )
+ if hasattr(instance, "get_config_reload_subscriptions"):
+ config_reload_subscriptions = list(instance.get_config_reload_subscriptions())
+
+ reg_payload = RegisterPluginPayload(
plugin_id=meta.plugin_id,
plugin_version=meta.version,
components=components,
capabilities_required=meta.capabilities_required,
+ dependencies=meta.dependencies,
+ config_reload_subscriptions=config_reload_subscriptions,
)
try:
- _resp = await self._rpc_client.send_request(
+ response = await self._rpc_client.send_request(
"plugin.register_components",
plugin_id=meta.plugin_id,
payload=reg_payload.model_dump(),
timeout_ms=10000,
)
+ if response.error:
+ raise RuntimeError(response.error.get("message", "插件注册失败"))
logger.info(f"插件 {meta.plugin_id} 注册完成")
return True
except Exception as e:
logger.error(f"插件 {meta.plugin_id} 注册失败: {e}")
return False
+ async def _unregister_plugin(self, plugin_id: str, reason: str) -> None:
+ """通知 Host 注销指定插件。
+
+ Args:
+ plugin_id: 目标插件 ID。
+ reason: 注销原因。
+ """
+ payload = UnregisterPluginPayload(plugin_id=plugin_id, reason=reason)
+ try:
+ await self._rpc_client.send_request(
+ "plugin.unregister",
+ plugin_id=plugin_id,
+ payload=payload.model_dump(),
+ timeout_ms=10000,
+ )
+ except Exception as exc:
+ logger.warning(f"插件 {plugin_id} 注销通知失败: {exc}")
+
+ async def _invoke_plugin_on_load(self, meta: PluginMeta) -> bool:
+ """执行插件的 ``on_load`` 生命周期。
+
+ Args:
+ meta: 待初始化的插件元数据。
+
+ Returns:
+ bool: 生命周期是否执行成功。
+ """
+ instance = meta.instance
+ if not hasattr(instance, "on_load"):
+ return True
+
+ try:
+ result = instance.on_load()
+ if asyncio.iscoroutine(result):
+ await result
+ return True
+ except Exception as exc:
+ logger.error(f"插件 {meta.plugin_id} on_load 失败: {exc}", exc_info=True)
+ return False
+
+ async def _invoke_plugin_on_unload(self, meta: PluginMeta) -> None:
+ """执行插件的 ``on_unload`` 生命周期。
+
+ Args:
+ meta: 待卸载的插件元数据。
+ """
+ instance = meta.instance
+ if not hasattr(instance, "on_unload"):
+ return
+
+ try:
+ result = instance.on_unload()
+ if asyncio.iscoroutine(result):
+ await result
+ except Exception as exc:
+ logger.error(f"插件 {meta.plugin_id} on_unload 失败: {exc}", exc_info=True)
+
+ async def _activate_plugin(self, meta: PluginMeta) -> bool:
+ """完成插件注入、授权、生命周期和组件注册。
+
+ Args:
+ meta: 待激活的插件元数据。
+
+ Returns:
+ bool: 是否激活成功。
+ """
+ self._inject_context(meta.plugin_id, meta.instance)
+ self._apply_plugin_config(meta)
+
+ if not await self._bootstrap_plugin(meta):
+ self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
+ return False
+
+ if not await self._register_plugin(meta):
+ await self._invoke_plugin_on_unload(meta)
+ await self._deactivate_plugin(meta)
+ self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
+ return False
+
+ if not await self._invoke_plugin_on_load(meta):
+ await self._unregister_plugin(meta.plugin_id, reason="on_load_failed")
+ await self._deactivate_plugin(meta)
+ self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
+ return False
+
+ self._loader.set_loaded_plugin(meta)
+ return True
+
+ async def _unload_plugin(self, meta: PluginMeta, reason: str, *, purge_modules: bool = True) -> None:
+ """卸载单个插件并清理 Host/Runner 两侧状态。
+
+ Args:
+ meta: 待卸载的插件元数据。
+ reason: 卸载原因。
+ purge_modules: 是否在卸载完成后清理插件模块缓存。
+ """
+ await self._invoke_plugin_on_unload(meta)
+ await self._unregister_plugin(meta.plugin_id, reason)
+ await self._deactivate_plugin(meta)
+ self._loader.remove_loaded_plugin(meta.plugin_id)
+ if purge_modules:
+ self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
+
+ def _collect_reverse_dependents(self, plugin_id: str) -> Set[str]:
+ """收集依赖指定插件的所有已加载插件。
+
+ Args:
+ plugin_id: 根插件 ID。
+
+ Returns:
+ Set[str]: 目标插件及其所有反向依赖插件集合。
+ """
+ impacted_plugins: Set[str] = {plugin_id}
+ changed = True
+
+ while changed:
+ changed = False
+ for loaded_plugin_id in self._loader.list_plugins():
+ if loaded_plugin_id in impacted_plugins:
+ continue
+
+ meta = self._loader.get_plugin(loaded_plugin_id)
+ if meta is None:
+ continue
+
+ if any(dependency in impacted_plugins for dependency in meta.dependencies):
+ impacted_plugins.add(loaded_plugin_id)
+ changed = True
+
+ return impacted_plugins
+
+ def _collect_reverse_dependents_for_roots(self, plugin_ids: Set[str]) -> Set[str]:
+ """收集多个根插件对应的反向依赖并集。
+
+ Args:
+ plugin_ids: 根插件 ID 集合。
+
+ Returns:
+ Set[str]: 所有根插件及其反向依赖并集。
+ """
+
+ impacted_plugins: Set[str] = set()
+ for plugin_id in sorted(plugin_ids):
+ impacted_plugins.update(self._collect_reverse_dependents(plugin_id))
+ return impacted_plugins
+
+ def _build_unload_order(self, plugin_ids: Set[str]) -> List[str]:
+ """构建受影响插件的卸载顺序。
+
+ Args:
+ plugin_ids: 需要卸载的插件集合。
+
+ Returns:
+ List[str]: 依赖方优先的卸载顺序。
+ """
+ dependency_graph: Dict[str, Set[str]] = {}
+ for plugin_id in plugin_ids:
+ meta = self._loader.get_plugin(plugin_id)
+ if meta is None:
+ dependency_graph[plugin_id] = set()
+ continue
+ dependency_graph[plugin_id] = {dependency for dependency in meta.dependencies if dependency in plugin_ids}
+
+ indegree: Dict[str, int] = {plugin_id: len(dependencies) for plugin_id, dependencies in dependency_graph.items()}
+ reverse_graph: Dict[str, Set[str]] = {plugin_id: set() for plugin_id in dependency_graph}
+
+ for plugin_id, dependencies in dependency_graph.items():
+ for dependency in dependencies:
+ reverse_graph.setdefault(dependency, set()).add(plugin_id)
+
+ queue: List[str] = sorted(plugin_id for plugin_id, degree in indegree.items() if degree == 0)
+ load_order: List[str] = []
+
+ while queue:
+ current_plugin_id = queue.pop(0)
+ load_order.append(current_plugin_id)
+ for dependent_plugin_id in sorted(reverse_graph.get(current_plugin_id, set())):
+ indegree[dependent_plugin_id] -= 1
+ if indegree[dependent_plugin_id] == 0:
+ queue.append(dependent_plugin_id)
+ queue.sort()
+
+ return list(reversed(load_order))
+
+ @staticmethod
+ def _normalize_requested_plugin_ids(plugin_ids: List[str]) -> List[str]:
+ """规范化批量重载请求中的插件 ID 列表。"""
+
+ normalized_plugin_ids: List[str] = []
+ seen_plugin_ids: Set[str] = set()
+ for plugin_id in plugin_ids:
+ normalized_plugin_id = str(plugin_id or "").strip()
+ if not normalized_plugin_id or normalized_plugin_id in seen_plugin_ids:
+ continue
+ seen_plugin_ids.add(normalized_plugin_id)
+ normalized_plugin_ids.append(normalized_plugin_id)
+ return normalized_plugin_ids
+
+ @staticmethod
+ def _finalize_failed_reload_messages(
+ failed_plugins: Dict[str, str],
+ rollback_failures: Dict[str, str],
+ ) -> Dict[str, str]:
+ """在重载失败后补充回滚结果说明。"""
+
+ finalized_failures: Dict[str, str] = {}
+ for failed_plugin_id, failure_reason in failed_plugins.items():
+ rollback_failure = rollback_failures.get(failed_plugin_id)
+ if rollback_failure:
+ finalized_failures[failed_plugin_id] = (
+ f"{failure_reason};且旧版本恢复失败: {rollback_failure}"
+ )
+ else:
+ finalized_failures[failed_plugin_id] = f"{failure_reason}(已恢复旧版本)"
+
+ for failed_plugin_id, rollback_failure in rollback_failures.items():
+ if failed_plugin_id not in finalized_failures:
+ finalized_failures[failed_plugin_id] = f"旧版本恢复失败: {rollback_failure}"
+
+ return finalized_failures
+
+ async def _reload_plugin_by_id(
+ self,
+ plugin_id: str,
+ reason: str,
+ external_available_plugins: Optional[Dict[str, str]] = None,
+ ) -> ReloadPluginResultPayload:
+ """按插件 ID 在 Runner 进程内执行精确重载。
+
+ Args:
+ plugin_id: 目标插件 ID。
+ reason: 重载原因。
+ external_available_plugins: 视为已满足的外部依赖插件版本映射。
+
+ Returns:
+ ReloadPluginResultPayload: 结构化重载结果。
+ """
+ batch_result = await self._reload_plugins_by_ids(
+ [plugin_id],
+ reason,
+ external_available_plugins=external_available_plugins,
+ )
+ return ReloadPluginResultPayload(
+ success=batch_result.success,
+ requested_plugin_id=plugin_id,
+ reloaded_plugins=batch_result.reloaded_plugins,
+ unloaded_plugins=batch_result.unloaded_plugins,
+ failed_plugins=batch_result.failed_plugins,
+ )
+
+ async def _reload_plugins_by_ids(
+ self,
+ plugin_ids: List[str],
+ reason: str,
+ external_available_plugins: Optional[Dict[str, str]] = None,
+ ) -> ReloadPluginsResultPayload:
+ """按插件 ID 列表在 Runner 进程内执行一次批量重载。"""
+
+ normalized_plugin_ids = self._normalize_requested_plugin_ids(plugin_ids)
+ if not normalized_plugin_ids:
+ return ReloadPluginsResultPayload(success=True, requested_plugin_ids=[])
+
+ candidates, duplicate_candidates = self._loader.discover_candidates(self._plugin_dirs)
+ failed_plugins: Dict[str, str] = {}
+ normalized_external_available = {
+ str(candidate_plugin_id or "").strip(): str(candidate_plugin_version or "").strip()
+ for candidate_plugin_id, candidate_plugin_version in (external_available_plugins or {}).items()
+ if str(candidate_plugin_id or "").strip() and str(candidate_plugin_version or "").strip()
+ }
+
+ loaded_plugin_ids = set(self._loader.list_plugins())
+ reload_root_ids: Set[str] = set()
+ for plugin_id in normalized_plugin_ids:
+ if plugin_id in duplicate_candidates:
+ conflict_paths = ", ".join(str(path) for path in duplicate_candidates[plugin_id])
+ failed_plugins[plugin_id] = f"检测到重复插件 ID: {conflict_paths}"
+ continue
+
+ plugin_is_loaded = plugin_id in loaded_plugin_ids
+ plugin_has_candidate = plugin_id in candidates
+ if not plugin_is_loaded and not plugin_has_candidate:
+ failed_plugins[plugin_id] = "插件不存在或未找到合法的 manifest/plugin.py"
+ continue
+
+ reload_root_ids.add(plugin_id)
+
+ if not reload_root_ids:
+ return ReloadPluginsResultPayload(
+ success=False,
+ requested_plugin_ids=normalized_plugin_ids,
+ failed_plugins=failed_plugins,
+ )
+
+ target_plugin_ids: Set[str] = {
+ plugin_id for plugin_id in reload_root_ids if plugin_id not in loaded_plugin_ids
+ }
+ if loaded_root_plugin_ids := reload_root_ids & loaded_plugin_ids:
+ target_plugin_ids.update(self._collect_reverse_dependents_for_roots(loaded_root_plugin_ids))
+
+ unload_order = self._build_unload_order(target_plugin_ids & loaded_plugin_ids)
+ unloaded_plugins: List[str] = []
+ retained_plugin_ids = loaded_plugin_ids - set(unload_order)
+ rollback_metas: Dict[str, PluginMeta] = {}
+
+ for unload_plugin_id in unload_order:
+ meta = self._loader.get_plugin(unload_plugin_id)
+ if meta is None:
+ continue
+ rollback_metas[unload_plugin_id] = meta
+ await self._unload_plugin(meta, reason=reason, purge_modules=False)
+ self._loader.purge_plugin_modules(unload_plugin_id, meta.plugin_dir)
+ unloaded_plugins.append(unload_plugin_id)
+
+ reload_candidates: Dict[str, PluginCandidate] = {}
+ for target_plugin_id in target_plugin_ids:
+ candidate = candidates.get(target_plugin_id)
+ if candidate is None:
+ failed_plugins[target_plugin_id] = "插件目录已不存在"
+ continue
+ reload_candidates[target_plugin_id] = candidate
+
+ load_order, dependency_failures = self._loader.resolve_dependencies(
+ reload_candidates,
+ extra_available={
+ **normalized_external_available,
+ **{
+ retained_plugin_id: retained_meta.version
+ for retained_plugin_id in retained_plugin_ids
+ if (retained_meta := self._loader.get_plugin(retained_plugin_id)) is not None
+ },
+ },
+ )
+ failed_plugins.update(dependency_failures)
+
+ available_plugins = {
+ **normalized_external_available,
+ **{
+ retained_plugin_id: retained_meta.version
+ for retained_plugin_id in retained_plugin_ids
+ if (retained_meta := self._loader.get_plugin(retained_plugin_id)) is not None
+ },
+ }
+ reloaded_plugins: List[str] = []
+
+ for load_plugin_id in load_order:
+ if load_plugin_id in failed_plugins:
+ continue
+
+ candidate = reload_candidates.get(load_plugin_id)
+ if candidate is None:
+ continue
+
+ _, manifest, _ = candidate
+ if unsatisfied_dependencies := self._loader.manifest_validator.get_unsatisfied_plugin_dependencies(
+ manifest,
+ available_plugin_versions=available_plugins,
+ ):
+ failed_plugins[load_plugin_id] = f"依赖未满足: {', '.join(unsatisfied_dependencies)}"
+ continue
+
+ meta = self._loader.load_candidate(load_plugin_id, candidate)
+ if meta is None:
+ failed_plugins[load_plugin_id] = "插件模块加载失败"
+ continue
+
+ activated = await self._activate_plugin(meta)
+ if not activated:
+ failed_plugins[load_plugin_id] = "插件初始化失败"
+ continue
+
+ available_plugins[load_plugin_id] = meta.version
+ reloaded_plugins.append(load_plugin_id)
+
+ if failed_plugins:
+ rollback_failures: Dict[str, str] = {}
+
+ for reloaded_plugin_id in reversed(reloaded_plugins):
+ reloaded_meta = self._loader.get_plugin(reloaded_plugin_id)
+ if reloaded_meta is None:
+ continue
+
+ try:
+ await self._unload_plugin(
+ reloaded_meta,
+ reason=f"{reason}_rollback_cleanup",
+ purge_modules=False,
+ )
+ except Exception as exc:
+ rollback_failures[reloaded_plugin_id] = f"清理失败: {exc}"
+ finally:
+ self._loader.purge_plugin_modules(reloaded_plugin_id, reloaded_meta.plugin_dir)
+
+ for rollback_plugin_id in reversed(unload_order):
+ rollback_meta = rollback_metas.get(rollback_plugin_id)
+ if rollback_meta is None:
+ continue
+
+ try:
+ restored = await self._activate_plugin(rollback_meta)
+ except Exception as exc:
+ rollback_failures[rollback_plugin_id] = str(exc)
+ continue
+
+ if not restored:
+ rollback_failures[rollback_plugin_id] = "无法重新激活旧版本"
+
+ return ReloadPluginsResultPayload(
+ success=False,
+ requested_plugin_ids=normalized_plugin_ids,
+ reloaded_plugins=[],
+ unloaded_plugins=unloaded_plugins,
+ failed_plugins=self._finalize_failed_reload_messages(failed_plugins, rollback_failures),
+ )
+
+ requested_plugin_success = all(plugin_id in reloaded_plugins for plugin_id in reload_root_ids)
+
+ return ReloadPluginsResultPayload(
+ success=requested_plugin_success and not failed_plugins,
+ requested_plugin_ids=normalized_plugin_ids,
+ reloaded_plugins=reloaded_plugins,
+ unloaded_plugins=unloaded_plugins,
+ failed_plugins=failed_plugins,
+ )
+
async def _notify_ready(self, loaded_plugins: List[str], failed_plugins: List[str]) -> None:
- """通知 Host 当前 generation 已完成插件初始化。"""
+ """通知 Host 当前 Runner 已完成插件初始化。
+
+ Args:
+ loaded_plugins: 成功初始化的插件列表。
+ failed_plugins: 初始化失败的插件列表。
+ """
payload = RunnerReadyPayload(
loaded_plugins=loaded_plugins,
failed_plugins=failed_plugins,
@@ -388,19 +925,13 @@ class PluginRunner:
f"插件 {plugin_id} 未加载",
)
- # 调用插件实例的组件方法
- instance = meta.instance
component_name = invoke.component_name
-
- # 优先查找 handle_ 或直接 方法(新版 SDK 插件)
- handler_method = getattr(instance, f"handle_{component_name}", None)
- if handler_method is None:
- handler_method = getattr(instance, component_name, None)
+ handler_method = self._resolve_component_handler(meta, component_name)
# 回退: 旧版 LegacyPluginAdapter 通过 invoke_component 统一桥接
- if (handler_method is None or not callable(handler_method)) and hasattr(instance, "invoke_component"):
+ if (handler_method is None or not callable(handler_method)) and hasattr(meta.instance, "invoke_component"):
try:
- result = await instance.invoke_component(component_name, **invoke.args)
+ result = await meta.instance.invoke_component(component_name, **invoke.args)
resp_payload = InvokeResultPayload(success=True, result=result)
return envelope.make_response(payload=resp_payload.model_dump())
except Exception as e:
@@ -447,11 +978,8 @@ class PluginRunner:
f"插件 {plugin_id} 未加载",
)
- instance = meta.instance
component_name = invoke.component_name
- handler_method = getattr(instance, f"handle_{component_name}", None)
- if handler_method is None:
- handler_method = getattr(instance, component_name, None)
+ handler_method = self._resolve_component_handler(meta, component_name)
if handler_method is None or not callable(handler_method):
return envelope.make_error_response(
@@ -487,6 +1015,60 @@ class PluginRunner:
logger.error(f"插件 {plugin_id} event_handler {component_name} 执行异常: {e}", exc_info=True)
return envelope.make_response(payload={"success": False, "continue_processing": True})
+ async def _handle_hook_invoke(self, envelope: Envelope) -> Envelope:
+ """处理 HookHandler 调用请求。
+
+ Args:
+ envelope: RPC 请求信封。
+
+ Returns:
+ Envelope: 标准化后的 Hook 调用结果。
+ """
+ try:
+ invoke = InvokePayload.model_validate(envelope.payload)
+ except Exception as exc:
+ return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
+
+ plugin_id = envelope.plugin_id
+ meta = self._loader.get_plugin(plugin_id)
+ if meta is None:
+ return envelope.make_error_response(
+ ErrorCode.E_PLUGIN_NOT_FOUND.value,
+ f"插件 {plugin_id} 未加载",
+ )
+
+ component_name = invoke.component_name
+ handler_method = self._resolve_component_handler(meta, component_name)
+ if handler_method is None or not callable(handler_method):
+ return envelope.make_error_response(
+ ErrorCode.E_METHOD_NOT_ALLOWED.value,
+ f"插件 {plugin_id} 无组件: {component_name}",
+ )
+
+ try:
+ raw = (
+ await handler_method(**invoke.args)
+ if inspect.iscoroutinefunction(handler_method)
+ else handler_method(**invoke.args)
+ )
+ except Exception as exc:
+ logger.error(f"插件 {plugin_id} hook_handler {component_name} 执行异常: {exc}", exc_info=True)
+ return envelope.make_response(payload={"success": False, "continue_processing": True})
+
+ if raw is None:
+ result = {"success": True, "continue_processing": True}
+ elif isinstance(raw, dict):
+ result = {
+ "success": True,
+ "continue_processing": raw.get("continue_processing", True),
+ "modified_kwargs": raw.get("modified_kwargs"),
+ "custom_result": raw.get("custom_result"),
+ }
+ else:
+ result = {"success": True, "continue_processing": True, "custom_result": raw}
+
+ return envelope.make_response(payload=result)
+
async def _handle_workflow_step(self, envelope: Envelope) -> Envelope:
"""处理 WorkflowStep 调用请求
@@ -506,9 +1088,8 @@ class PluginRunner:
f"插件 {plugin_id} 未加载",
)
- instance = meta.instance
component_name = invoke.component_name
- handler_method = getattr(instance, f"handle_{component_name}", None) or getattr(instance, component_name, None)
+ handler_method = self._resolve_component_handler(meta, component_name)
if handler_method is None or not callable(handler_method):
return envelope.make_error_response(
@@ -557,36 +1138,92 @@ class PluginRunner:
async def _handle_shutdown(self, envelope: Envelope) -> Envelope:
"""处理关停 — 调用所有插件的 on_unload 后退出"""
logger.info("收到 shutdown 信号,开始调用 on_unload")
- for plugin_id in self._loader.list_plugins():
+ for plugin_id in list(self._loader.list_plugins()):
meta = self._loader.get_plugin(plugin_id)
- if meta and hasattr(meta.instance, "on_unload"):
- try:
- ret = meta.instance.on_unload()
- if asyncio.iscoroutine(ret):
- await ret
- except Exception as e:
- logger.error(f"插件 {plugin_id} on_unload 失败: {e}", exc_info=True)
+ if meta is not None:
+ await self._unload_plugin(meta, reason="runner_shutdown")
self._shutting_down = True
return envelope.make_response(payload={"acknowledged": True})
async def _handle_config_updated(self, envelope: Envelope) -> Envelope:
- """处理配置更新事件"""
+ """处理配置更新事件。"""
+ try:
+ payload = ConfigUpdatedPayload.model_validate(envelope.payload)
+ except Exception as exc:
+ return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
+
plugin_id = envelope.plugin_id
if meta := self._loader.get_plugin(plugin_id):
try:
- config_data = envelope.payload.get("config_data", {})
- config_version = envelope.payload.get("config_version", "")
- self._apply_plugin_config(meta, config_data=config_data)
- if hasattr(meta.instance, "on_config_update"):
- ret = meta.instance.on_config_update(config_data, config_version)
- # 兼容同步和异步的 on_config_update 实现
- if asyncio.iscoroutine(ret):
- await ret
+ config_scope = payload.config_scope.value
+ if config_scope == "self":
+ self._apply_plugin_config(meta, config_data=payload.config_data)
+ if not hasattr(meta.instance, "on_config_update"):
+ raise AttributeError("插件缺少 on_config_update() 实现")
+
+ ret = meta.instance.on_config_update(
+ config_scope,
+ payload.config_data,
+ payload.config_version,
+ )
+ if asyncio.iscoroutine(ret):
+ await ret
except Exception as e:
logger.error(f"插件 {plugin_id} 配置更新失败: {e}")
return envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
return envelope.make_response(payload={"acknowledged": True})
+ async def _handle_reload_plugin(self, envelope: Envelope) -> Envelope:
+ """处理按插件 ID 的精确重载请求。
+
+ Args:
+ envelope: RPC 请求信封。
+
+ Returns:
+ Envelope: 结构化重载结果。
+ """
+ try:
+ payload = ReloadPluginPayload.model_validate(envelope.payload)
+ except Exception as exc:
+ return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
+
+ if self._reload_lock.locked():
+ return envelope.make_error_response(
+ ErrorCode.E_RELOAD_IN_PROGRESS.value,
+ f"插件 {payload.plugin_id} 重载请求被拒绝:已有重载任务正在执行",
+ )
+
+ async with self._reload_lock:
+ result = await self._reload_plugin_by_id(
+ payload.plugin_id,
+ payload.reason,
+ external_available_plugins=dict(payload.external_available_plugins),
+ )
+ return envelope.make_response(payload=result.model_dump())
+
+ async def _handle_reload_plugins(self, envelope: Envelope) -> Envelope:
+ """处理批量插件重载请求。"""
+
+ try:
+ payload = ReloadPluginsPayload.model_validate(envelope.payload)
+ except Exception as exc:
+ return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
+
+ if self._reload_lock.locked():
+ requested_plugin_ids = ", ".join(self._normalize_requested_plugin_ids(payload.plugin_ids)) or ""
+ return envelope.make_error_response(
+ ErrorCode.E_RELOAD_IN_PROGRESS.value,
+ f"插件 {requested_plugin_ids} 批量重载请求被拒绝:已有重载任务正在执行",
+ )
+
+ async with self._reload_lock:
+ result = await self._reload_plugins_by_ids(
+ list(payload.plugin_ids),
+ payload.reason,
+ external_available_plugins=dict(payload.external_available_plugins),
+ )
+ return envelope.make_response(payload=result.model_dump())
+
def request_capability(self) -> RPCClient:
"""获取 RPC 客户端(供 SDK 使用,发起能力调用)"""
return self._rpc_client
@@ -598,9 +1235,14 @@ class PluginRunner:
def _isolate_sys_path(plugin_dirs: List[str]) -> None:
"""清理 sys.path,限制 Runner 子进程只能访问标准库、SDK 和插件目录。
- 防止插件代码 import 主程序模块读取运行时数据。
+ 同时阻止插件代码直接导入主程序内部 ``src.*`` 模块,并清理可直接从
+ ``sys.modules`` 摸到的高权限叶子模块,避免绕过 SDK / capability 边界。
"""
- import importlib.abc
+ from importlib import util as importlib_util
+ from types import ModuleType
+
+ import builtins
+ import importlib
import sysconfig
# 保留: 标准库路径 + site-packages(含 SDK 和依赖)
@@ -631,43 +1273,145 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
for d in plugin_dir_paths:
allowed.add(d)
- # 添加项目根目录(使得 src.plugin_runtime / src.common 可导入)
- runtime_root = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
- allowed.add(runtime_root)
-
preserved_paths = [p for p in sys.path if p in allowed]
- for extra_path in [*plugin_dir_paths, runtime_root]:
+ for extra_path in plugin_dir_paths:
if extra_path not in preserved_paths:
preserved_paths.append(extra_path)
sys.path[:] = preserved_paths
- # 安装 import 钩子,阻止插件导入主程序核心模块
- # 仅允许 src.plugin_runtime 和 src.common,拒绝其他 src.* 子包
- class _PluginImportBlocker(importlib.abc.MetaPathFinder):
- """阻止 Runner 子进程导入主程序核心模块。
+ # 仅为旧版插件兼容层保留极小的 src.* 可见面:
+ # - src.plugin_system.*: 通过 maibot_sdk.compat 导入钩子重定向
+ # - src.common.logger: 仓库内仍有少量旧插件沿用该日志入口
+ allowed_src_exact_modules = frozenset(
+ {
+ "src",
+ "src.common",
+ "src.common.logger",
+ "src.common.logger_color_and_mapping",
+ }
+ )
+ allowed_src_prefixes = ("src.plugin_system",)
+ plugin_module_prefix = "_maibot_plugin_"
- 只放行 src.plugin_runtime 和 src.common,
- 拒绝 src.chat_module / src.services 等主程序内部包。
- """
+ def _is_allowed_src_module(fullname: str) -> bool:
+ """判断给定 src.* 模块是否在 Runner 允许列表中。"""
+ if fullname in allowed_src_exact_modules:
+ return True
+ return any(fullname == prefix or fullname.startswith(f"{prefix}.") for prefix in allowed_src_prefixes)
- _ALLOWED_SRC_PREFIXES = ("src.plugin_runtime", "src.common")
+ def _resolve_requester_name(import_globals: Any = None) -> str:
+ """解析当前导入请求的发起模块名。"""
+ if isinstance(import_globals, dict):
+ for key in ("__name__", "__package__"):
+ value = import_globals.get(key)
+ if isinstance(value, str) and value:
+ return value
- def find_module(self, fullname, path=None):
- return self if self._should_block(fullname) else None
+ frame = inspect.currentframe()
+ try:
+ current = frame.f_back if frame is not None else None
+ while current is not None:
+ module_name = current.f_globals.get("__name__", "")
+ if not isinstance(module_name, str) or not module_name:
+ current = current.f_back
+ continue
+ if module_name == __name__ or module_name.startswith("importlib"):
+ current = current.f_back
+ continue
+ return module_name
+ return ""
+ finally:
+ del frame
- def load_module(self, fullname):
- raise ImportError(f"Runner 子进程不允许导入主程序模块: {fullname}")
+ def _is_plugin_import_request(import_globals: Any = None) -> bool:
+ """判断当前导入是否由插件模块直接发起。"""
+ requester_name = _resolve_requester_name(import_globals)
+ return requester_name.startswith(plugin_module_prefix)
- def _should_block(self, fullname: str) -> bool:
- # 放行非 src.* 的导入、以及 "src" 本身
- if not fullname.startswith("src.") or fullname == "src":
- return False
- # 放行白名单前缀
- return not any(
- fullname == prefix or fullname.startswith(f"{prefix}.") for prefix in self._ALLOWED_SRC_PREFIXES
- )
+ def _format_block_message(fullname: str) -> str:
+ """构造统一的拒绝导入错误信息。"""
+ return (
+ f"Runner 子进程不允许导入主程序模块: {fullname}。"
+ "请改用 maibot_sdk 或 src.plugin_system 兼容层提供的接口。"
+ )
- sys.meta_path.insert(0, _PluginImportBlocker())
+ def _iter_requested_src_modules(name: str, fromlist: Any) -> List[str]:
+ """展开本次导入请求涉及的 src.* 模块名。"""
+ requested_modules = [name]
+ if not name.startswith("src") or not fromlist:
+ return requested_modules
+
+ for item in fromlist:
+ if not isinstance(item, str) or not item or item == "*":
+ continue
+ requested_modules.append(f"{name}.{item}")
+ return requested_modules
+
+ def _assert_plugin_import_allowed(name: str, import_globals: Any = None, fromlist: Any = ()) -> None:
+ """在插件发起导入时校验目标 src.* 模块是否允许访问。"""
+ if not _is_plugin_import_request(import_globals):
+ return
+
+ for requested_module in _iter_requested_src_modules(name, fromlist):
+ if not requested_module.startswith("src"):
+ continue
+ if _is_allowed_src_module(requested_module):
+ continue
+ raise ImportError(_format_block_message(requested_module))
+
+ def _detach_module_from_parent(fullname: str, module: ModuleType) -> None:
+ """从父模块上移除已清理模块的属性引用。"""
+ parent_name, _, child_name = fullname.rpartition(".")
+ if not parent_name or not child_name:
+ return
+
+ parent_module = sys.modules.get(parent_name)
+ if parent_module is None:
+ return
+ if getattr(parent_module, child_name, None) is module:
+ with contextlib.suppress(AttributeError):
+ delattr(parent_module, child_name)
+
+ # 仅清理已加载的叶子模块,保留包对象给 Runner 自己的延迟导入和相对导入使用。
+ existing_src_modules = sorted(
+ (
+ (module_name, module)
+ for module_name, module in list(sys.modules.items())
+ if module_name == "src" or module_name.startswith("src.")
+ ),
+ key=lambda item: item[0].count("."),
+ reverse=True,
+ )
+ for module_name, module in existing_src_modules:
+ if _is_allowed_src_module(module_name) or hasattr(module, "__path__"):
+ continue
+ _detach_module_from_parent(module_name, module)
+ sys.modules.pop(module_name, None)
+
+ # ``import`` 语句与 ``importlib.import_module`` 走的是不同入口,因此两边都需要兜底。
+ builtins_module = cast(Any, builtins)
+ original_import = getattr(builtins_module, "__maibot_runner_original_import__", builtins.__import__)
+ builtins_module.__maibot_runner_original_import__ = original_import
+
+ def _guarded_import(name: str, globals: Any = None, locals: Any = None, fromlist: Any = (), level: int = 0) -> Any:
+ if level == 0:
+ _assert_plugin_import_allowed(name, import_globals=globals, fromlist=fromlist)
+ return original_import(name, globals, locals, fromlist, level)
+
+ cast(Any, _guarded_import).__maibot_runner_plugin_import_guard__ = True
+ builtins.__import__ = _guarded_import
+
+ importlib_module = cast(Any, importlib)
+ original_import_module = getattr(importlib_module, "__maibot_runner_original_import_module__", importlib.import_module)
+ importlib_module.__maibot_runner_original_import_module__ = original_import_module
+
+ def _guarded_import_module(name: str, package: Optional[str] = None) -> Any:
+ resolved_name = importlib_util.resolve_name(name, package) if name.startswith(".") else name
+ _assert_plugin_import_allowed(resolved_name)
+ return original_import_module(name, package)
+
+ cast(Any, _guarded_import_module).__maibot_runner_plugin_import_guard__ = True
+ importlib.import_module = _guarded_import_module
# ─── 进程入口 ──────────────────────────────────────────────
@@ -675,8 +1419,9 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
async def _async_main() -> None:
"""异步主入口"""
- host_address = os.environ.get(ENV_IPC_ADDRESS, "")
- session_token = os.environ.get(ENV_SESSION_TOKEN, "")
+ host_address = os.environ.pop(ENV_IPC_ADDRESS, "")
+ external_plugin_ids_raw = os.environ.get(ENV_EXTERNAL_PLUGIN_IDS, "")
+ session_token = os.environ.pop(ENV_SESSION_TOKEN, "")
plugin_dirs_str = os.environ.get(ENV_PLUGIN_DIRS, "")
if not host_address or not session_token:
@@ -684,14 +1429,31 @@ async def _async_main() -> None:
sys.exit(1)
plugin_dirs = [d for d in plugin_dirs_str.split(os.pathsep) if d]
+ try:
+ external_plugin_ids = json.loads(external_plugin_ids_raw) if external_plugin_ids_raw else {}
+ except json.JSONDecodeError:
+ logger.warning("解析外部依赖插件版本映射失败,已回退为空映射")
+ external_plugin_ids = {}
+ if not isinstance(external_plugin_ids, dict):
+ logger.warning("外部依赖插件版本映射格式非法,已回退为空映射")
+ external_plugin_ids = {}
# sys.path 隔离: 只保留标准库、SDK 包、插件目录
_isolate_sys_path(plugin_dirs)
- runner = PluginRunner(host_address, session_token, plugin_dirs)
+ runner = PluginRunner(
+ host_address,
+ session_token,
+ plugin_dirs,
+ external_available_plugins={
+ str(plugin_id): str(plugin_version)
+ for plugin_id, plugin_version in external_plugin_ids.items()
+ },
+ )
# 注册信号处理
def _mark_runner_shutting_down() -> None:
+ """标记 Runner 即将进入关停流程。"""
runner._shutting_down = True
_install_shutdown_signal_handlers(_mark_runner_shutting_down)
diff --git a/src/plugin_runtime/transport/named_pipe.py b/src/plugin_runtime/transport/named_pipe.py
index a759507d..7fd39bc9 100644
--- a/src/plugin_runtime/transport/named_pipe.py
+++ b/src/plugin_runtime/transport/named_pipe.py
@@ -1,6 +1,9 @@
"""Windows Named Pipe 传输实现。
适用于 Windows 平台,使用 asyncio ProactorEventLoop 的 named pipe 支持。
+
+注意:Named Pipe 是 Windows 特有的 IPC 机制,
+在 Linux/macOS 平台上不可用。Unix-like 平台请使用 UDS 传输。
"""
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, cast
@@ -18,10 +21,12 @@ _DEFAULT_PIPE_PREFIX = "maibot-plugin"
class _NamedPipeServerHandle(Protocol):
+ """Named Pipe 服务端句柄的协议定义。"""
def close(self) -> None: ...
class _NamedPipeEventLoop(Protocol):
+ """ProactorEventLoop 的协议定义,提供 named pipe 相关方法。"""
async def start_serving_pipe(
self,
protocol_factory: Callable[[], asyncio.BaseProtocol],
@@ -40,6 +45,15 @@ class _NamedPipeEventLoop(Protocol):
def _normalize_pipe_address(pipe_name: Optional[str] = None) -> str:
+ """规范化 Named Pipe 地址。
+
+ Args:
+ pipe_name: 管道名称。如果以 '\\\\.\\pipe\\' 开头则直接使用,
+ 否则会自动添加前缀。如果为 None 则生成随机名称。
+
+ Returns:
+ 规范化的管道地址(格式:\\\\.\\pipe\\name)
+ """
if pipe_name and pipe_name.startswith(_PIPE_PREFIX):
return pipe_name
@@ -55,12 +69,21 @@ def _normalize_pipe_address(pipe_name: Optional[str] = None) -> str:
class NamedPipeConnection(Connection):
- """基于 Windows Named Pipe 的连接。"""
+ """基于 Windows Named Pipe 的连接。
+
+ 封装了底层 StreamReader/StreamWriter,提供分帧读写能力。
+ """
- pass
+ def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
+ super().__init__(reader, writer)
class _NamedPipeServerProtocol(asyncio.StreamReaderProtocol):
+ """Named Pipe 服务端协议实现。
+
+ 处理客户端连接的生命周期,包括连接建立、数据处理和连接关闭。
+ """
+
def __init__(self, handler: ConnectionHandler, loop: asyncio.AbstractEventLoop) -> None:
self._reader: asyncio.StreamReader = asyncio.StreamReader()
super().__init__(self._reader)
@@ -69,39 +92,58 @@ class _NamedPipeServerProtocol(asyncio.StreamReaderProtocol):
self._handler_task: Optional[asyncio.Task[None]] = None
def connection_made(self, transport: asyncio.BaseTransport) -> None:
+ """连接建立时的回调。"""
super().connection_made(transport)
writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), self, self._reader, self._loop)
connection = NamedPipeConnection(self._reader, writer)
- self._handler_task = self._loop.create_task(self._run_handler(connection))
+ # 使用 asyncio.create_task 确保任务正确调度
+ self._handler_task = asyncio.create_task(self._run_handler(connection))
self._handler_task.add_done_callback(self._on_handler_done)
async def _run_handler(self, connection: NamedPipeConnection) -> None:
+ """运行连接处理器。"""
try:
await self._handler(connection)
finally:
await connection.close()
def _on_handler_done(self, task: asyncio.Task[None]) -> None:
+ """连接处理器完成时的回调。"""
if task.cancelled():
return
if exc := task.exception():
- self._loop.call_exception_handler(
- {
- "message": "Named pipe 连接处理失败",
- "exception": exc,
- "protocol": self,
- }
- )
+ try:
+ self._loop.call_exception_handler(
+ {
+ "message": "Named pipe 连接处理失败",
+ "exception": exc,
+ "protocol": self,
+ }
+ )
+ except Exception:
+ # 如果 loop 已经关闭,忽略异常
+ pass
class NamedPipeTransportServer(TransportServer):
- """Windows Named Pipe 传输服务端。"""
+ """Windows Named Pipe 传输服务端。
+
+ 使用 ProactorEventLoop 的 start_serving_pipe 方法监听客户端连接。
+ """
def __init__(self, pipe_name: Optional[str] = None) -> None:
self._address: str = _normalize_pipe_address(pipe_name)
self._servers: List[_NamedPipeServerHandle] = []
async def start(self, handler: ConnectionHandler) -> None:
+ """启动 Named Pipe 服务端。
+
+ Args:
+ handler: 新连接到来时的回调函数
+
+ Raises:
+ RuntimeError: 当在非 Windows 平台或事件循环不支持时
+ """
if sys.platform != "win32":
raise RuntimeError("Named pipe 仅支持 Windows")
@@ -116,32 +158,49 @@ class NamedPipeTransportServer(TransportServer):
)
async def stop(self) -> None:
+ """停止 Named Pipe 服务端并清理资源。"""
for server in self._servers:
server.close()
+ # 等待所有服务器句柄完全关闭
+ await asyncio.gather(
+ *[asyncio.sleep(0.1) for _ in self._servers],
+ return_exceptions=True
+ )
self._servers.clear()
- await asyncio.sleep(0)
def get_address(self) -> str:
return self._address
class NamedPipeTransportClient(TransportClient):
- """Windows Named Pipe 传输客户端。"""
+ """Windows Named Pipe 传输客户端。
+
+ 用于主动连接到 Named Pipe 服务端。
+ """
def __init__(self, address: str) -> None:
self._address: str = _normalize_pipe_address(address)
async def connect(self) -> Connection:
+ """建立到 Named Pipe 服务端的连接。
+
+ Returns:
+ NamedPipeConnection: 连接对象
+
+ Raises:
+ NotImplementedError: 当在非 Windows 平台或事件循环不支持时
+ """
if sys.platform != "win32":
- raise RuntimeError("Named pipe 仅支持 Windows")
+ raise NotImplementedError("Named pipe 仅支持 Windows")
loop = asyncio.get_running_loop()
if not hasattr(loop, "create_pipe_connection"):
- raise RuntimeError("当前事件循环不支持 Windows named pipe")
+ raise NotImplementedError("当前事件循环不支持 Windows named pipe")
pipe_loop = cast(_NamedPipeEventLoop, loop)
reader = asyncio.StreamReader()
protocol = asyncio.StreamReaderProtocol(reader)
transport, _protocol = await pipe_loop.create_pipe_connection(lambda: protocol, self._address)
- writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), protocol, reader, loop)
+ # 使用返回的 protocol 创建 StreamWriter
+ writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), _protocol, reader, loop)
return NamedPipeConnection(reader, writer)
\ No newline at end of file
diff --git a/src/plugin_runtime/transport/uds.py b/src/plugin_runtime/transport/uds.py
index 47bf033b..af71ea5d 100644
--- a/src/plugin_runtime/transport/uds.py
+++ b/src/plugin_runtime/transport/uds.py
@@ -1,6 +1,9 @@
"""Unix Domain Socket 传输实现
适用于 Linux / macOS 平台。
+
+注意:UDS (Unix Domain Socket) 是 Unix-like 系统特有的 IPC 机制,
+在 Windows 平台上不可用。Windows 平台请使用 Named Pipe 传输。
"""
from pathlib import Path
@@ -8,20 +11,30 @@ from typing import Optional
import asyncio
import os
+import sys
import tempfile
from .base import Connection, ConnectionHandler, TransportClient, TransportServer
class UDSConnection(Connection):
- """基于 UDS 的连接"""
+ """基于 UDS 的连接
+
+ 封装了底层 StreamReader/StreamWriter,提供分帧读写能力。
+ """
- pass # 直接复用 Connection 基类的分帧读写
+ def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
+ super().__init__(reader, writer)
# Unix domain socket 路径的系统限制(sun_path 字段长度)
-# Linux: 108 字节, macOS: 104 字节
-_UDS_PATH_MAX = 104
+# Linux: 108 字节,macOS: 104 字节,其他 Unix: 通常 104 字节
+if sys.platform == "linux":
+ _UDS_PATH_MAX = 108
+elif sys.platform == "darwin": # macOS
+ _UDS_PATH_MAX = 104
+else:
+ _UDS_PATH_MAX = 104 # 保守默认值
class UDSTransportServer(TransportServer):
@@ -44,6 +57,18 @@ class UDSTransportServer(TransportServer):
self._server: Optional[asyncio.AbstractServer] = None
async def start(self, handler: ConnectionHandler) -> None:
+ """启动 UDS 服务端
+
+ Args:
+ handler: 新连接到来时的回调函数
+
+ Raises:
+ RuntimeError: 当在非 Unix 平台(如 Windows)上调用时
+ """
+ # 平台检查:UDS 仅在 Unix-like 系统上可用
+ if sys.platform == "win32":
+ raise RuntimeError("UDS 不支持 Windows 平台,请使用 Named Pipe")
+
# 清理残留 socket 文件
if self._socket_path.exists():
self._socket_path.unlink()
@@ -58,10 +83,16 @@ class UDSTransportServer(TransportServer):
finally:
await conn.close()
- self._server = await asyncio.start_unix_server(_on_connect, path=str(self._socket_path))
+ try:
+ self._server = await asyncio.start_unix_server(_on_connect, path=str(self._socket_path))
- # 设置文件权限为仅当前用户可访问
- self._socket_path.chmod(0o600)
+ # 设置文件权限为仅当前用户可访问
+ self._socket_path.chmod(0o600)
+ except Exception:
+ # 启动失败时清理可能创建的目录和 socket 文件
+ if self._socket_path.exists():
+ self._socket_path.unlink()
+ raise
async def stop(self) -> None:
if self._server:
@@ -77,11 +108,26 @@ class UDSTransportServer(TransportServer):
class UDSTransportClient(TransportClient):
- """UDS 传输客户端"""
+ """UDS 传输客户端
+
+ 用于主动连接到 UDS 服务端。
+ """
def __init__(self, socket_path: Path) -> None:
self._socket_path: Path = socket_path
async def connect(self) -> Connection:
+ """建立到 UDS 服务端的连接
+
+ Returns:
+ UDSConnection: 连接对象
+
+ Raises:
+ RuntimeError: 当在非 Unix 平台(如 Windows)上调用时
+ """
+ # 平台检查:UDS 仅在 Unix-like 系统上可用
+ if sys.platform == "win32":
+ raise RuntimeError("UDS 不支持 Windows 平台,请使用 Named Pipe")
+
reader, writer = await asyncio.open_unix_connection(str(self._socket_path))
return UDSConnection(reader, writer)
diff --git a/src/plugins/built_in/emoji_plugin/_manifest.json b/src/plugins/built_in/emoji_plugin/_manifest.json
index d4d262e7..5b53abad 100644
--- a/src/plugins/built_in/emoji_plugin/_manifest.json
+++ b/src/plugins/built_in/emoji_plugin/_manifest.json
@@ -1,32 +1,28 @@
{
- "manifest_version": 1,
- "name": "Emoji插件 (Emoji Actions)",
+ "manifest_version": 2,
"version": "2.0.0",
- "description": "可以发送和管理Emoji",
+ "name": "Emoji插件 (Emoji Actions)",
+ "description": "可以发送和管理 Emoji",
"author": {
"name": "SengokuCola",
"url": "https://github.com/MaiM-with-u"
},
"license": "GPL-v3.0-or-later",
+ "urls": {
+ "repository": "https://github.com/MaiM-with-u/maibot",
+ "homepage": "https://github.com/MaiM-with-u/maibot",
+ "documentation": "https://github.com/MaiM-with-u/maibot",
+ "issues": "https://github.com/MaiM-with-u/maibot/issues"
+ },
"host_application": {
- "min_version": "1.0.0"
+ "min_version": "1.0.0",
+ "max_version": "1.0.0"
},
- "homepage_url": "https://github.com/MaiM-with-u/maibot",
- "repository_url": "https://github.com/MaiM-with-u/maibot",
- "keywords": ["emoji", "action", "built-in"],
- "categories": ["Emoji"],
- "default_locale": "zh-CN",
- "plugin_info": {
- "is_built_in": true,
- "plugin_type": "action_provider",
- "components": [
- {
- "type": "action",
- "name": "emoji",
- "description": "发送表情包辅助表达情绪"
- }
- ]
+ "sdk": {
+ "min_version": "2.0.0",
+ "max_version": "2.99.99"
},
+ "dependencies": [],
"capabilities": [
"emoji.get_random",
"message.get_recent",
@@ -34,5 +30,12 @@
"llm.generate",
"send.emoji",
"config.get"
- ]
+ ],
+ "i18n": {
+ "default_locale": "zh-CN",
+ "supported_locales": [
+ "zh-CN"
+ ]
+ },
+ "id": "builtin.emoji-plugin"
}
diff --git a/src/plugins/built_in/emoji_plugin/plugin.py b/src/plugins/built_in/emoji_plugin/plugin.py
index b946931b..cc6b87c5 100644
--- a/src/plugins/built_in/emoji_plugin/plugin.py
+++ b/src/plugins/built_in/emoji_plugin/plugin.py
@@ -3,11 +3,11 @@
根据聊天上下文的情感,使用 LLM 选择并发送合适的表情包。
"""
-import random
-
-from maibot_sdk import MaiBotPlugin, Action
+from maibot_sdk import Action, MaiBotPlugin
from maibot_sdk.types import ActivationType
+import random
+
class EmojiPlugin(MaiBotPlugin):
"""表情包插件"""
@@ -95,10 +95,35 @@ class EmojiPlugin(MaiBotPlugin):
return True, f"成功发送表情包:[表情包:{chosen_emotion}]"
return False, "发送表情包失败"
- async def on_load(self):
+ async def on_load(self) -> None:
+ """处理插件加载。"""
+
# 从插件配置读取 emoji_chance 来覆盖默认概率
await self.ctx.config.get("emoji.emoji_chance")
+ async def on_unload(self) -> None:
+ """处理插件卸载。"""
+
+ async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None:
+ """处理配置热重载事件。
+
+ Args:
+ scope: 配置变更范围。
+ config_data: 最新配置数据。
+ version: 配置版本号。
+ """
+
+ del config_data
+ del version
+ if scope == "self":
+ await self.ctx.config.get("emoji.emoji_chance")
+
+
+def create_plugin() -> EmojiPlugin:
+ """创建 Emoji 插件实例。
+
+ Returns:
+ EmojiPlugin: 新的 Emoji 插件实例。
+ """
-def create_plugin():
return EmojiPlugin()
diff --git a/src/plugins/built_in/plugin_management/_manifest.json b/src/plugins/built_in/plugin_management/_manifest.json
index a5b52835..a2bfa9ce 100644
--- a/src/plugins/built_in/plugin_management/_manifest.json
+++ b/src/plugins/built_in/plugin_management/_manifest.json
@@ -1,51 +1,46 @@
{
- "manifest_version": 1,
- "name": "插件和组件管理 (Plugin and Component Management)",
+ "manifest_version": 2,
"version": "2.0.0",
- "description": "通过系统API管理插件和组件的生命周期,包括加载、卸载、启用和禁用等操作。",
+ "name": "插件和组件管理 (Plugin and Component Management)",
+ "description": "通过系统 API 管理插件和组件的生命周期,包括加载、卸载、启用和禁用等操作。",
"author": {
"name": "MaiBot团队",
"url": "https://github.com/MaiM-with-u"
},
"license": "GPL-v3.0-or-later",
- "host_application": {
- "min_version": "1.0.0"
+ "urls": {
+ "repository": "https://github.com/MaiM-with-u/maibot",
+ "homepage": "https://github.com/MaiM-with-u/maibot",
+ "documentation": "https://github.com/MaiM-with-u/maibot",
+ "issues": "https://github.com/MaiM-with-u/maibot/issues"
},
- "homepage_url": "https://github.com/MaiM-with-u/maibot",
- "repository_url": "https://github.com/MaiM-with-u/maibot",
- "keywords": [
- "plugins",
- "components",
- "management",
- "built-in"
+ "host_application": {
+ "min_version": "1.0.0",
+ "max_version": "1.0.0"
+ },
+ "sdk": {
+ "min_version": "2.0.0",
+ "max_version": "2.99.99"
+ },
+ "dependencies": [],
+ "capabilities": [
+ "component.get_all_plugins",
+ "component.list_loaded_plugins",
+ "component.list_registered_plugins",
+ "component.enable",
+ "component.disable",
+ "component.load_plugin",
+ "component.unload_plugin",
+ "component.reload_plugin",
+ "send.text",
+ "config.get"
],
- "categories": [
- "Core System",
- "Plugin Management"
- ],
- "default_locale": "zh-CN",
- "locales_path": "_locales",
- "plugin_info": {
- "is_built_in": true,
- "plugin_type": "plugin_management",
- "capabilities": [
- "component.get_all_plugins",
- "component.list_loaded_plugins",
- "component.list_registered_plugins",
- "component.enable",
- "component.disable",
- "component.load_plugin",
- "component.unload_plugin",
- "component.reload_plugin",
- "send.text",
- "config.get"
- ],
- "components": [
- {
- "type": "command",
- "name": "management",
- "description": "管理插件和组件的生命周期,包括加载、卸载、启用和禁用等操作。"
- }
+ "i18n": {
+ "default_locale": "zh-CN",
+ "locales_path": "_locales",
+ "supported_locales": [
+ "zh-CN"
]
- }
-}
\ No newline at end of file
+ },
+ "id": "builtin.plugin-management"
+}
diff --git a/src/plugins/built_in/plugin_management/plugin.py b/src/plugins/built_in/plugin_management/plugin.py
index fe0888c6..aa2da795 100644
--- a/src/plugins/built_in/plugin_management/plugin.py
+++ b/src/plugins/built_in/plugin_management/plugin.py
@@ -3,7 +3,7 @@
通过 /pm 命令管理插件和组件的生命周期。
"""
-from maibot_sdk import MaiBotPlugin, Command
+from maibot_sdk import Command, MaiBotPlugin
_VALID_COMPONENT_TYPES = ("action", "command", "event_handler")
@@ -44,6 +44,12 @@ HELP_COMPONENT = (
class PluginManagementPlugin(MaiBotPlugin):
"""插件和组件管理插件"""
+ async def on_load(self) -> None:
+ """处理插件加载。"""
+
+ async def on_unload(self) -> None:
+ """处理插件卸载。"""
+
@Command(
"management",
description="管理插件和组件的生命周期",
@@ -268,6 +274,25 @@ class PluginManagementPlugin(MaiBotPlugin):
return components
return []
+ async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None:
+ """处理配置热重载事件。
+
+ Args:
+ scope: 配置变更范围。
+ config_data: 最新配置数据。
+ version: 配置版本号。
+ """
+
+ del scope
+ del config_data
+ del version
+
+
+def create_plugin() -> PluginManagementPlugin:
+ """创建插件管理插件实例。
+
+ Returns:
+ PluginManagementPlugin: 新的插件管理插件实例。
+ """
-def create_plugin():
return PluginManagementPlugin()
diff --git a/src/services/send_service.py b/src/services/send_service.py
index 7af55716..134fb15e 100644
--- a/src/services/send_service.py
+++ b/src/services/send_service.py
@@ -1,155 +1,640 @@
"""
-发送服务模块
+发送服务模块。
-提供发送各种类型消息的核心功能。
+统一封装内部模块的出站消息发送逻辑:
+
+1. 内部模块统一调用本模块。
+2. send service 只负责构造和预处理消息。
+3. 具体走插件链还是 legacy 旧链,由 Platform IO 内部统一决策。
"""
-from typing import Dict, List, Optional, TYPE_CHECKING
+from copy import deepcopy
+from typing import Any, Dict, List, Optional
+import asyncio
+import base64
+import hashlib
import time
import traceback
+from datetime import datetime
-from maim_message import BaseMessageInfo, GroupInfo as MaimGroupInfo, MessageBase, Seg, UserInfo as MaimUserInfo
-
+from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.chat.message_receive.message import SessionMessage
-from src.chat.message_receive.uni_message_sender import UniversalMessageSender
-from src.chat.utils.utils import get_bot_account
-from src.common.data_models.mai_message_data_model import MaiMessage
-from src.common.data_models.message_component_data_model import DictComponent, MessageSequence
+from src.chat.utils.utils import calculate_typing_time, get_bot_account
+from src.common.data_models.mai_message_data_model import GroupInfo, MaiMessage, MessageInfo, UserInfo
+from src.common.data_models.message_component_data_model import (
+ AtComponent,
+ DictComponent,
+ EmojiComponent,
+ ForwardNodeComponent,
+ ImageComponent,
+ MessageSequence,
+ ReplyComponent,
+ StandardMessageComponents,
+ TextComponent,
+ VoiceComponent,
+)
from src.common.logger import get_logger
+from src.common.utils.utils_message import MessageUtils
from src.config.config import global_config
-
-if TYPE_CHECKING:
- from src.chat.message_receive.message import SessionMessage
+from src.platform_io import DeliveryBatch, get_platform_io_manager
+from src.platform_io.route_key_factory import RouteKeyFactory
logger = get_logger("send_service")
-# =============================================================================
-# 内部实现函数
-# =============================================================================
+def _inherit_platform_io_route_metadata(target_stream: BotChatSession) -> Dict[str, object]:
+ """从目标会话继承 Platform IO 路由元数据。
+
+ Args:
+ target_stream: 当前消息要发送到的会话对象。
+
+ Returns:
+ Dict[str, object]: 可安全透传到出站消息 ``additional_config`` 中的
+ 路由辅助字段。
+ """
+ inherited_metadata: Dict[str, object] = {}
+
+ context_message = target_stream.context.message if target_stream.context else None
+ if context_message is not None:
+ additional_config = context_message.message_info.additional_config
+ if isinstance(additional_config, dict):
+ for key in (*RouteKeyFactory.ACCOUNT_ID_KEYS, *RouteKeyFactory.SCOPE_KEYS):
+ value = additional_config.get(key)
+ if value is None:
+ continue
+ normalized_value = str(value).strip()
+ if normalized_value:
+ inherited_metadata[key] = value
+
+ # 当目标会话没有可继承的上下文消息时,至少补齐当前平台账号,
+ # 让按 ``platform + account_id`` 绑定的路由仍有机会命中。
+ if not RouteKeyFactory.extract_components(inherited_metadata)[0]:
+ bot_account = get_bot_account(target_stream.platform)
+ if bot_account:
+ inherited_metadata["platform_io_account_id"] = bot_account
+
+ if target_stream.group_id and (normalized_group_id := str(target_stream.group_id).strip()):
+ inherited_metadata["platform_io_target_group_id"] = normalized_group_id
+
+ if target_stream.user_id and (normalized_user_id := str(target_stream.user_id).strip()):
+ inherited_metadata["platform_io_target_user_id"] = normalized_user_id
+
+ return inherited_metadata
+
+
+def _build_binary_component_from_base64(component_type: str, raw_data: str) -> StandardMessageComponents:
+ """根据 Base64 数据构造二进制消息组件。
+
+ Args:
+ component_type: 组件类型名称。
+ raw_data: Base64 编码后的二进制数据。
+
+ Returns:
+ StandardMessageComponents: 转换后的内部消息组件。
+
+ Raises:
+ ValueError: 当组件类型不受支持时抛出。
+ """
+ binary_data = base64.b64decode(raw_data)
+ binary_hash = hashlib.sha256(binary_data).hexdigest()
+
+ if component_type == "image":
+ return ImageComponent(binary_hash=binary_hash, binary_data=binary_data)
+ if component_type == "emoji":
+ return EmojiComponent(binary_hash=binary_hash, binary_data=binary_data)
+ if component_type == "voice":
+ return VoiceComponent(binary_hash=binary_hash, binary_data=binary_data)
+ raise ValueError(f"不支持的二进制组件类型: {component_type}")
+
+
+def _build_message_sequence_from_custom_message(
+ message_type: str,
+ content: str | Dict[str, Any],
+) -> MessageSequence:
+ """根据自定义消息类型构造内部消息组件序列。
+
+ Args:
+ message_type: 自定义消息类型。
+ content: 自定义消息内容。
+
+ Returns:
+ MessageSequence: 转换后的消息组件序列。
+ """
+ normalized_type = message_type.strip().lower()
+
+ if normalized_type == "text":
+ return MessageSequence(components=[TextComponent(text=str(content))])
+
+ if normalized_type in {"image", "emoji", "voice"}:
+ return MessageSequence(
+ components=[_build_binary_component_from_base64(normalized_type, str(content))]
+ )
+
+ if normalized_type == "at":
+ return MessageSequence(components=[AtComponent(target_user_id=str(content))])
+
+ if normalized_type == "reply":
+ return MessageSequence(components=[ReplyComponent(target_message_id=str(content))])
+
+ if normalized_type == "dict" and isinstance(content, dict):
+ return MessageSequence(components=[DictComponent(data=deepcopy(content))])
+
+ return MessageSequence(
+ components=[
+ DictComponent(
+ data={
+ "type": normalized_type,
+ "data": deepcopy(content),
+ }
+ )
+ ]
+ )
+
+
+def _clone_message_sequence(message_sequence: MessageSequence) -> MessageSequence:
+ """复制消息组件序列,避免原对象被发送流程修改。
+
+ Args:
+ message_sequence: 原始消息组件序列。
+
+ Returns:
+ MessageSequence: 深拷贝后的消息组件序列。
+ """
+ return deepcopy(message_sequence)
+
+
+def _detect_outbound_message_flags(message_sequence: MessageSequence) -> Dict[str, bool]:
+ """根据消息组件序列推断出站消息标记。
+
+ Args:
+ message_sequence: 待发送的消息组件序列。
+
+ Returns:
+ Dict[str, bool]: 包含 ``is_emoji``、``is_picture``、``is_command`` 的标记字典。
+ """
+ if len(message_sequence.components) != 1:
+ return {
+ "is_emoji": False,
+ "is_picture": False,
+ "is_command": False,
+ }
+
+ component = message_sequence.components[0]
+ is_command = False
+ if isinstance(component, DictComponent) and isinstance(component.data, dict):
+ is_command = str(component.data.get("type") or "").strip().lower() == "command"
+
+ return {
+ "is_emoji": isinstance(component, EmojiComponent),
+ "is_picture": isinstance(component, ImageComponent),
+ "is_command": is_command,
+ }
+
+
+def _describe_message_sequence(message_sequence: MessageSequence) -> str:
+ """生成消息组件序列的简短描述文本。
+
+ Args:
+ message_sequence: 待描述的消息组件序列。
+
+ Returns:
+ str: 适用于日志的简短类型描述。
+ """
+ if len(message_sequence.components) != 1:
+ return "message_sequence"
+
+ component = message_sequence.components[0]
+ if isinstance(component, DictComponent) and isinstance(component.data, dict):
+ custom_type = str(component.data.get("type") or "").strip()
+ return custom_type or "dict"
+
+ if isinstance(component, TextComponent):
+ return component.format_name
+
+ if isinstance(component, ImageComponent):
+ return component.format_name
+
+ if isinstance(component, EmojiComponent):
+ return component.format_name
+
+ if isinstance(component, VoiceComponent):
+ return component.format_name
+
+ if isinstance(component, AtComponent):
+ return component.format_name
+
+ if isinstance(component, ReplyComponent):
+ return component.format_name
+
+ if isinstance(component, ForwardNodeComponent):
+ return component.format_name
+
+ return "unknown"
+
+
+def _build_processed_plain_text(message: SessionMessage) -> str:
+ """为出站消息构造轻量纯文本摘要。
+
+ Args:
+ message: 待发送的内部消息对象。
+
+ Returns:
+ str: 适用于日志与打字时长估算的纯文本摘要。
+ """
+ processed_parts: List[str] = []
+ for component in message.raw_message.components:
+ if isinstance(component, TextComponent):
+ processed_parts.append(component.text)
+ continue
+
+ if isinstance(component, ImageComponent):
+ processed_parts.append(component.content or "[图片]")
+ continue
+
+ if isinstance(component, EmojiComponent):
+ processed_parts.append(component.content or "[表情]")
+ continue
+
+ if isinstance(component, VoiceComponent):
+ processed_parts.append(component.content or "[语音]")
+ continue
+
+ if isinstance(component, AtComponent):
+ at_target = component.target_user_cardname or component.target_user_nickname or component.target_user_id
+ processed_parts.append(f"@{at_target}")
+ continue
+
+ if isinstance(component, ReplyComponent):
+ processed_parts.append(component.target_message_content or "[回复消息]")
+ continue
+
+ if isinstance(component, DictComponent):
+ raw_type = component.data.get("type") if isinstance(component.data, dict) else None
+ if isinstance(raw_type, str) and raw_type.strip():
+ processed_parts.append(f"[{raw_type.strip()}消息]")
+ else:
+ processed_parts.append("[自定义消息]")
+ continue
+
+ return " ".join(part for part in processed_parts if part)
+
+
+def _build_outbound_session_message(
+ message_sequence: MessageSequence,
+ stream_id: str,
+ display_message: str = "",
+ reply_message: Optional[MaiMessage] = None,
+ selected_expressions: Optional[List[int]] = None,
+) -> Optional[SessionMessage]:
+ """根据目标会话构建待发送的内部消息对象。
+
+ Args:
+ message_sequence: 待发送的消息组件序列。
+ stream_id: 目标会话 ID。
+ display_message: 用于界面展示的文本内容。
+ reply_message: 被回复的锚点消息。
+ selected_expressions: 可选的表情候选索引列表。
+
+ Returns:
+ Optional[SessionMessage]: 构建成功时返回内部消息对象;若目标会话或
+ 机器人账号不存在,则返回 ``None``。
+ """
+ target_stream = _chat_manager.get_session_by_session_id(stream_id)
+ if target_stream is None:
+ logger.error(f"[SendService] 未找到聊天流: {stream_id}")
+ return None
+
+ bot_user_id = get_bot_account(target_stream.platform)
+ if not bot_user_id:
+ logger.error(f"[SendService] 平台 {target_stream.platform} 未配置机器人账号,无法发送消息")
+ return None
+
+ current_time = time.time()
+ message_id = f"send_api_{int(current_time * 1000)}"
+ anchor_message = reply_message.deepcopy() if reply_message is not None else None
+
+ group_info: Optional[GroupInfo] = None
+ if target_stream.group_id:
+ group_name = ""
+ if (
+ target_stream.context
+ and target_stream.context.message
+ and target_stream.context.message.message_info.group_info
+ ):
+ group_name = target_stream.context.message.message_info.group_info.group_name
+ group_info = GroupInfo(
+ group_id=target_stream.group_id,
+ group_name=group_name,
+ )
+
+ additional_config: Dict[str, object] = _inherit_platform_io_route_metadata(target_stream)
+ if selected_expressions is not None:
+ additional_config["selected_expressions"] = selected_expressions
+
+ outbound_message = SessionMessage(
+ message_id=message_id,
+ timestamp=datetime.fromtimestamp(current_time),
+ platform=target_stream.platform,
+ )
+ outbound_message.message_info = MessageInfo(
+ user_info=UserInfo(
+ user_id=bot_user_id,
+ user_nickname=global_config.bot.nickname,
+ ),
+ group_info=group_info,
+ additional_config=additional_config,
+ )
+ outbound_message.raw_message = _clone_message_sequence(message_sequence)
+ outbound_message.session_id = target_stream.session_id
+ outbound_message.display_message = display_message
+ outbound_message.reply_to = anchor_message.message_id if anchor_message is not None else None
+ message_flags = _detect_outbound_message_flags(outbound_message.raw_message)
+ outbound_message.is_emoji = message_flags["is_emoji"]
+ outbound_message.is_picture = message_flags["is_picture"]
+ outbound_message.is_command = message_flags["is_command"]
+ outbound_message.initialized = True
+ return outbound_message
+
+
+def _ensure_reply_component(message: SessionMessage, reply_message_id: str) -> None:
+ """为消息补充回复组件。
+
+ Args:
+ message: 待发送的内部消息对象。
+ reply_message_id: 被引用消息的 ID。
+ """
+ if message.raw_message.components:
+ first_component = message.raw_message.components[0]
+ if isinstance(first_component, ReplyComponent) and first_component.target_message_id == reply_message_id:
+ return
+
+ message.raw_message.components.insert(0, ReplyComponent(target_message_id=reply_message_id))
+
+
+async def _prepare_message_for_platform_io(
+ message: SessionMessage,
+ *,
+ typing: bool,
+ set_reply: bool,
+ reply_message_id: Optional[str],
+) -> None:
+ """为 Platform IO 发送链预处理消息。
+
+ Args:
+ message: 待发送的内部消息对象。
+ typing: 是否模拟打字等待。
+ set_reply: 是否构建引用回复组件。
+ reply_message_id: 被引用消息的 ID。
+
+ Raises:
+ ValueError: 当要求设置引用回复但缺少 ``reply_message_id`` 时抛出。
+ """
+ if set_reply:
+ if not reply_message_id:
+ raise ValueError("set_reply=True 时必须提供 reply_message_id")
+ _ensure_reply_component(message, reply_message_id)
+
+ message.processed_plain_text = _build_processed_plain_text(message)
+ if typing:
+ typing_time = calculate_typing_time(
+ input_string=message.processed_plain_text or "",
+ is_emoji=message.is_emoji,
+ )
+ await asyncio.sleep(typing_time)
+
+
+def _store_sent_message(message: SessionMessage) -> None:
+ """将已成功发送的消息写入数据库。
+
+ Args:
+ message: 已成功发送的内部消息对象。
+ """
+ MessageUtils.store_message_to_db(message)
+
+
+def _log_platform_io_failures(delivery_batch: DeliveryBatch) -> None:
+ """输出 Platform IO 批量发送失败详情。
+
+ Args:
+ delivery_batch: Platform IO 返回的批量回执。
+ """
+ failed_details = "; ".join(
+ f"driver={receipt.driver_id} status={receipt.status} error={receipt.error}"
+ for receipt in delivery_batch.failed_receipts
+ ) or "未命中任何发送路由"
+ logger.warning(
+ "[SendService] Platform IO 发送失败: platform=%s %s",
+ delivery_batch.route_key.platform,
+ failed_details,
+ )
+
+
+async def _send_via_platform_io(
+ message: SessionMessage,
+ *,
+ typing: bool,
+ set_reply: bool,
+ reply_message_id: Optional[str],
+ storage_message: bool,
+ show_log: bool,
+) -> bool:
+ """通过 Platform IO 发送消息。
+
+ Args:
+ message: 待发送的内部消息对象。
+ typing: 是否模拟打字等待。
+ set_reply: 是否设置引用回复。
+ reply_message_id: 被引用消息的 ID。
+ storage_message: 发送成功后是否写入数据库。
+ show_log: 是否输出发送成功日志。
+
+ Returns:
+ bool: 发送成功时返回 ``True``。
+ """
+ platform_io_manager = get_platform_io_manager()
+ try:
+ await platform_io_manager.ensure_send_pipeline_ready()
+ except Exception as exc:
+ logger.error(f"[SendService] 准备 Platform IO 发送管线失败: {exc}")
+ logger.debug(traceback.format_exc())
+ return False
+
+ try:
+ route_key = platform_io_manager.build_route_key_from_message(message)
+ except Exception as exc:
+ logger.warning(f"[SendService] 根据消息构造 Platform IO 路由键失败: {exc}")
+ return False
+
+ try:
+ await _prepare_message_for_platform_io(
+ message,
+ typing=typing,
+ set_reply=set_reply,
+ reply_message_id=reply_message_id,
+ )
+ delivery_batch = await platform_io_manager.send_message(
+ message,
+ route_key,
+ metadata={"show_log": False},
+ )
+ except Exception as exc:
+ logger.error(f"[SendService] Platform IO 发送异常: {exc}")
+ logger.debug(traceback.format_exc())
+ return False
+
+ if delivery_batch.has_success:
+ if storage_message:
+ _store_sent_message(message)
+ if show_log:
+ successful_driver_ids = [
+ receipt.driver_id or "unknown"
+ for receipt in delivery_batch.sent_receipts
+ ]
+ logger.info(
+ "[SendService] 已通过 Platform IO 将消息发往平台 '%s' (drivers: %s)",
+ route_key.platform,
+ ", ".join(successful_driver_ids),
+ )
+ return True
+
+ _log_platform_io_failures(delivery_batch)
+ return False
+
+
+async def send_session_message(
+ message: SessionMessage,
+ *,
+ typing: bool = False,
+ set_reply: bool = False,
+ reply_message_id: Optional[str] = None,
+ storage_message: bool = True,
+ show_log: bool = True,
+) -> bool:
+ """统一发送一条内部消息。
+
+ 该方法是内部模块的统一发送入口:
+
+ 1. 构造并维护内部消息对象。
+ 2. 由 Platform IO 统一决定走插件链还是 legacy 旧链。
+ 3. send service 不再自行判断底层发送路径。
+
+ Args:
+ message: 待发送的内部消息对象。
+ typing: 是否模拟打字等待。
+ set_reply: 是否设置引用回复。
+ reply_message_id: 被引用消息的 ID。
+ storage_message: 发送成功后是否写入数据库。
+ show_log: 是否输出发送日志。
+
+ Returns:
+ bool: 发送成功时返回 ``True``,否则返回 ``False``。
+ """
+ if not message.message_id:
+ logger.error("[SendService] 消息缺少 message_id,无法发送")
+ raise ValueError("消息缺少 message_id,无法发送")
+
+ return await _send_via_platform_io(
+ message,
+ typing=typing,
+ set_reply=set_reply,
+ reply_message_id=reply_message_id,
+ storage_message=storage_message,
+ show_log=show_log,
+ )
async def _send_to_target(
- message_segment: Seg,
+ message_sequence: MessageSequence,
stream_id: str,
display_message: str = "",
typing: bool = False,
set_reply: bool = False,
- reply_message: Optional["SessionMessage"] = None,
+ reply_message: Optional[MaiMessage] = None,
storage_message: bool = True,
show_log: bool = True,
selected_expressions: Optional[List[int]] = None,
) -> bool:
- """向指定目标发送消息的内部实现"""
+ """向指定目标构建并发送消息。
+
+ Args:
+ message_sequence: 待发送的消息组件序列。
+ stream_id: 目标会话 ID。
+ display_message: 用于界面展示的文本内容。
+ typing: 是否显示输入中状态。
+ set_reply: 是否在发送时附带引用回复。
+ reply_message: 被回复的消息对象。
+ storage_message: 是否将发送结果写入消息存储。
+ show_log: 是否输出发送日志。
+ selected_expressions: 可选的表情候选索引列表。
+
+ Returns:
+ bool: 发送成功返回 ``True``,否则返回 ``False``。
+ """
try:
- if set_reply and not reply_message:
+ if set_reply and reply_message is None:
logger.warning("[SendService] 使用引用回复,但未提供回复消息")
return False
if show_log:
- logger.debug(f"[SendService] 发送{message_segment.type}消息到 {stream_id}")
+ logger.debug(f"[SendService] 发送{_describe_message_sequence(message_sequence)}消息到 {stream_id}")
- target_stream = _chat_manager.get_session_by_session_id(stream_id)
- if not target_stream:
- logger.error(f"[SendService] 未找到聊天流: {stream_id}")
- return False
-
- message_sender = UniversalMessageSender()
-
- current_time = time.time()
- message_id = f"send_api_{int(current_time * 1000)}"
-
- anchor_message: Optional[MaiMessage] = None
- if reply_message:
- anchor_message = reply_message.deepcopy()
- if anchor_message:
- logger.debug(
- f"[SendService] 找到匹配的回复消息,发送者: {anchor_message.message_info.user_info.user_id}"
- )
-
- group_info = None
- if target_stream.group_id:
- group_name = ""
- if target_stream.context and target_stream.context.message and target_stream.context.message.message_info.group_info:
- group_name = target_stream.context.message.message_info.group_info.group_name
- group_info = MaimGroupInfo(
- group_id=target_stream.group_id,
- group_name=group_name,
- platform=target_stream.platform,
- )
-
- additional_config: dict[str, object] = {}
- if selected_expressions is not None:
- additional_config["selected_expressions"] = selected_expressions
- bot_user_id = get_bot_account(target_stream.platform)
- if not bot_user_id:
- logger.error(f"[SendService] 平台 {target_stream.platform} 未配置机器人账号,无法发送消息")
- return False
-
- maim_message = MessageBase(
- message_info=BaseMessageInfo(
- platform=target_stream.platform,
- message_id=message_id,
- time=current_time,
- user_info=MaimUserInfo(
- user_id=bot_user_id,
- user_nickname=global_config.bot.nickname,
- platform=target_stream.platform,
- ),
- group_info=group_info,
- additional_config=additional_config,
- ),
- message_segment=message_segment,
+ outbound_message = _build_outbound_session_message(
+ message_sequence=message_sequence,
+ stream_id=stream_id,
+ display_message=display_message,
+ reply_message=reply_message,
+ selected_expressions=selected_expressions,
)
- bot_message = SessionMessage.from_maim_message(maim_message)
- bot_message.session_id = target_stream.session_id
- bot_message.display_message = display_message
- bot_message.reply_to = anchor_message.message_id if anchor_message else None
- bot_message.is_emoji = message_segment.type == "emoji"
- bot_message.is_picture = message_segment.type == "image"
- bot_message.is_command = message_segment.type == "command"
+ if outbound_message is None:
+ return False
- sent_msg = await message_sender.send_message(
- bot_message,
+ sent = await send_session_message(
+ outbound_message,
typing=typing,
set_reply=set_reply,
- reply_message_id=anchor_message.message_id if anchor_message else None,
+ reply_message_id=reply_message.message_id if reply_message is not None else None,
storage_message=storage_message,
show_log=show_log,
)
-
- if sent_msg:
+ if sent:
logger.debug(f"[SendService] 成功发送消息到 {stream_id}")
return True
- else:
- logger.error("[SendService] 发送消息失败")
- return False
- except Exception as e:
- logger.error(f"[SendService] 发送消息时出错: {e}")
+ logger.error("[SendService] 发送消息失败")
+ return False
+ except Exception as exc:
+ logger.error(f"[SendService] 发送消息时出错: {exc}")
traceback.print_exc()
return False
-# =============================================================================
-# 公共函数 - 预定义类型的发送函数
-# =============================================================================
-
-
async def text_to_stream(
text: str,
stream_id: str,
typing: bool = False,
set_reply: bool = False,
- reply_message: Optional["SessionMessage"] = None,
+ reply_message: Optional[MaiMessage] = None,
storage_message: bool = True,
selected_expressions: Optional[List[int]] = None,
) -> bool:
- """向指定流发送文本消息"""
+ """向指定流发送文本消息。
+
+ Args:
+ text: 要发送的文本内容。
+ stream_id: 目标会话 ID。
+ typing: 是否显示输入中状态。
+ set_reply: 是否附带引用回复。
+ reply_message: 被回复的消息对象。
+ storage_message: 是否在发送成功后写入数据库。
+ selected_expressions: 可选的表情候选索引列表。
+
+ Returns:
+ bool: 发送成功时返回 ``True``。
+ """
return await _send_to_target(
- message_segment=Seg(type="text", data=text),
+ message_sequence=MessageSequence(components=[TextComponent(text=text)]),
stream_id=stream_id,
display_message="",
typing=typing,
@@ -165,11 +650,22 @@ async def emoji_to_stream(
stream_id: str,
storage_message: bool = True,
set_reply: bool = False,
- reply_message: Optional["SessionMessage"] = None,
+ reply_message: Optional[MaiMessage] = None,
) -> bool:
- """向指定流发送表情包"""
+ """向指定流发送表情消息。
+
+ Args:
+ emoji_base64: 表情图片的 Base64 内容。
+ stream_id: 目标会话 ID。
+ storage_message: 是否在发送成功后写入数据库。
+ set_reply: 是否附带引用回复。
+ reply_message: 被回复的消息对象。
+
+ Returns:
+ bool: 发送成功时返回 ``True``。
+ """
return await _send_to_target(
- message_segment=Seg(type="emoji", data=emoji_base64),
+ message_sequence=_build_message_sequence_from_custom_message("emoji", emoji_base64),
stream_id=stream_id,
display_message="",
typing=False,
@@ -184,11 +680,22 @@ async def image_to_stream(
stream_id: str,
storage_message: bool = True,
set_reply: bool = False,
- reply_message: Optional["SessionMessage"] = None,
+ reply_message: Optional[MaiMessage] = None,
) -> bool:
- """向指定流发送图片"""
+ """向指定流发送图片消息。
+
+ Args:
+ image_base64: 图片的 Base64 内容。
+ stream_id: 目标会话 ID。
+ storage_message: 是否在发送成功后写入数据库。
+ set_reply: 是否附带引用回复。
+ reply_message: 被回复的消息对象。
+
+ Returns:
+ bool: 发送成功时返回 ``True``。
+ """
return await _send_to_target(
- message_segment=Seg(type="image", data=image_base64),
+ message_sequence=_build_message_sequence_from_custom_message("image", image_base64),
stream_id=stream_id,
display_message="",
typing=False,
@@ -200,18 +707,33 @@ async def image_to_stream(
async def custom_to_stream(
message_type: str,
- content: str | Dict,
+ content: str | Dict[str, Any],
stream_id: str,
display_message: str = "",
typing: bool = False,
- reply_message: Optional["SessionMessage"] = None,
+ reply_message: Optional[MaiMessage] = None,
set_reply: bool = False,
storage_message: bool = True,
show_log: bool = True,
) -> bool:
- """向指定流发送自定义类型消息"""
+ """向指定流发送自定义类型消息。
+
+ Args:
+ message_type: 自定义消息类型。
+ content: 自定义消息内容。
+ stream_id: 目标会话 ID。
+ display_message: 用于展示的文本内容。
+ typing: 是否显示输入中状态。
+ reply_message: 被回复的消息对象。
+ set_reply: 是否附带引用回复。
+ storage_message: 是否在发送成功后写入数据库。
+ show_log: 是否输出发送日志。
+
+ Returns:
+ bool: 发送成功时返回 ``True``。
+ """
return await _send_to_target(
- message_segment=Seg(type=message_type, data=content), # type: ignore
+ message_sequence=_build_message_sequence_from_custom_message(message_type, content),
stream_id=stream_id,
display_message=display_message,
typing=typing,
@@ -227,31 +749,33 @@ async def custom_reply_set_to_stream(
stream_id: str,
display_message: str = "",
typing: bool = False,
- reply_message: Optional["SessionMessage"] = None,
+ reply_message: Optional[MaiMessage] = None,
set_reply: bool = False,
storage_message: bool = True,
show_log: bool = True,
) -> bool:
- """向指定流发送消息组件序列。"""
- flag: bool = True
- for component in reply_set.components:
- if isinstance(component, DictComponent):
- message_seg = Seg(type="dict", data=component.data) # type: ignore
- else:
- message_seg = await component.to_seg()
- status = await _send_to_target(
- message_segment=message_seg,
- stream_id=stream_id,
- display_message=display_message,
- typing=typing,
- reply_message=reply_message,
- set_reply=set_reply,
- storage_message=storage_message,
- show_log=show_log,
- )
- if not status:
- flag = False
- logger.error(f"[SendService] 发送消息组件失败,组件类型:{type(component).__name__}")
- set_reply = False
+ """向指定流发送消息组件序列。
- return flag
+ Args:
+ reply_set: 待发送的消息组件序列。
+ stream_id: 目标会话 ID。
+ display_message: 用于展示的文本内容。
+ typing: 是否显示输入中状态。
+ reply_message: 被回复的消息对象。
+ set_reply: 是否附带引用回复。
+ storage_message: 是否在发送成功后写入数据库。
+ show_log: 是否输出发送日志。
+
+ Returns:
+ bool: 发送成功时返回 ``True``。
+ """
+ return await _send_to_target(
+ message_sequence=reply_set,
+ stream_id=stream_id,
+ display_message=display_message,
+ typing=typing,
+ reply_message=reply_message,
+ set_reply=set_reply,
+ storage_message=storage_message,
+ show_log=show_log,
+ )
diff --git a/src/webui/routers/chat/serializers.py b/src/webui/routers/chat/serializers.py
new file mode 100644
index 00000000..32104f88
--- /dev/null
+++ b/src/webui/routers/chat/serializers.py
@@ -0,0 +1,175 @@
+"""提供 WebUI 聊天路由使用的消息序列化能力。"""
+
+from typing import Any, Dict, List, Optional
+
+import base64
+
+from src.common.data_models.message_component_data_model import (
+ AtComponent,
+ DictComponent,
+ EmojiComponent,
+ ForwardComponent,
+ ForwardNodeComponent,
+ ImageComponent,
+ MessageSequence,
+ ReplyComponent,
+ StandardMessageComponents,
+ TextComponent,
+ VoiceComponent,
+)
+
+
+def serialize_message_sequence(message_sequence: MessageSequence) -> List[Dict[str, Any]]:
+ """将内部统一消息组件序列转换为 WebUI 富文本消息段。
+
+ Args:
+ message_sequence: 内部统一消息组件序列。
+
+ Returns:
+ List[Dict[str, Any]]: 可直接广播给 WebUI 前端的消息段列表。
+ """
+ serialized_segments: List[Dict[str, Any]] = []
+ for component in message_sequence.components:
+ serialized_segment = serialize_message_component(component)
+ if serialized_segment is not None:
+ serialized_segments.append(serialized_segment)
+ return serialized_segments
+
+
+def serialize_message_component(component: StandardMessageComponents) -> Optional[Dict[str, Any]]:
+ """将单个内部消息组件转换为 WebUI 消息段。
+
+ Args:
+ component: 待序列化的内部消息组件。
+
+ Returns:
+ Optional[Dict[str, Any]]: 序列化后的 WebUI 消息段;若组件不应展示则返回 ``None``。
+ """
+ if isinstance(component, TextComponent):
+ return {"type": "text", "data": component.text}
+
+ if isinstance(component, ImageComponent):
+ return _serialize_binary_component(
+ segment_type="image",
+ mime_type="image/png",
+ binary_data=component.binary_data,
+ fallback_text=component.content,
+ )
+
+ if isinstance(component, EmojiComponent):
+ return _serialize_binary_component(
+ segment_type="emoji",
+ mime_type="image/gif",
+ binary_data=component.binary_data,
+ fallback_text=component.content,
+ )
+
+ if isinstance(component, VoiceComponent):
+ return _serialize_binary_component(
+ segment_type="voice",
+ mime_type="audio/wav",
+ binary_data=component.binary_data,
+ fallback_text=component.content,
+ )
+
+ if isinstance(component, AtComponent):
+ return {
+ "type": "at",
+ "data": {
+ "target_user_id": component.target_user_id,
+ "target_user_nickname": component.target_user_nickname,
+ "target_user_cardname": component.target_user_cardname,
+ },
+ }
+
+ if isinstance(component, ReplyComponent):
+ return {
+ "type": "reply",
+ "data": {
+ "target_message_id": component.target_message_id,
+ "target_message_content": component.target_message_content,
+ "target_message_sender_id": component.target_message_sender_id,
+ "target_message_sender_nickname": component.target_message_sender_nickname,
+ "target_message_sender_cardname": component.target_message_sender_cardname,
+ },
+ }
+
+ if isinstance(component, ForwardNodeComponent):
+ return {
+ "type": "forward",
+ "data": [_serialize_forward_component(item) for item in component.forward_components],
+ }
+
+ if isinstance(component, DictComponent):
+ return _serialize_dict_component(component.data)
+
+ return {"type": "unknown", "data": str(component)}
+
+
+def _serialize_binary_component(
+ segment_type: str,
+ mime_type: str,
+ binary_data: bytes,
+ fallback_text: str,
+) -> Dict[str, Any]:
+ """序列化带二进制负载的消息组件。
+
+ Args:
+ segment_type: WebUI 消息段类型。
+ mime_type: 对应的数据 MIME 类型。
+ binary_data: 组件二进制数据。
+ fallback_text: 二进制缺失时可退化展示的文本。
+
+ Returns:
+ Dict[str, Any]: 序列化后的 WebUI 消息段。
+ """
+ if binary_data:
+ encoded_payload = base64.b64encode(binary_data).decode()
+ return {"type": segment_type, "data": f"data:{mime_type};base64,{encoded_payload}"}
+
+ if fallback_text:
+ return {"type": "text", "data": fallback_text}
+
+ return {"type": "unknown", "original_type": segment_type, "data": ""}
+
+
+def _serialize_forward_component(component: ForwardComponent) -> Dict[str, Any]:
+ """序列化单个转发节点。
+
+ Args:
+ component: 待序列化的转发节点组件。
+
+ Returns:
+ Dict[str, Any]: WebUI 可消费的转发节点字典。
+ """
+ return {
+ "message_id": component.message_id,
+ "user_id": component.user_id,
+ "user_nickname": component.user_nickname,
+ "user_cardname": component.user_cardname,
+ "content": serialize_message_sequence(MessageSequence(component.content)),
+ }
+
+
+def _serialize_dict_component(data: Dict[str, Any]) -> Dict[str, Any]:
+ """最佳努力地序列化非标准字典组件。
+
+ Args:
+ data: 原始字典组件内容。
+
+ Returns:
+ Dict[str, Any]: 序列化后的 WebUI 消息段。
+ """
+ raw_type = str(data.get("type") or "dict").strip()
+ raw_payload = data.get("data", data)
+
+ if raw_type in {"text", "image", "emoji", "voice", "video", "file", "music", "face"}:
+ return {"type": raw_type, "data": raw_payload}
+
+ if raw_type == "reply":
+ return {"type": "reply", "data": raw_payload}
+
+ if raw_type == "forward" and isinstance(raw_payload, list):
+ return {"type": "forward", "data": raw_payload}
+
+ return {"type": "unknown", "original_type": raw_type, "data": raw_payload}