From 03ed59e388237b4fc9edaea5c1dcb453717063c6 Mon Sep 17 00:00:00 2001
From: SengokuCola <1026294844@qq.com>
Date: Tue, 24 Mar 2026 11:36:26 +0800
Subject: [PATCH 01/45] =?UTF-8?q?=E6=9B=B4=E6=94=B9=E6=96=87=E4=BB=B6?=
=?UTF-8?q?=E7=BB=93=E6=9E=84?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
README.md | 137 +++++++--------
docs/README_CN.md | 162 ++++++++++++++++++
docs/README_EN.md | 2 +-
docs/minimal-cross-platform-plan.md | 6 +-
src/chat/brain_chat/brain_chat.py | 4 +-
src/chat/heart_flow/heartFC_chat - 副本.py | 4 +-
src/chat/heart_flow/heartFC_chat.py | 4 +-
src/chat/replyer/group_generator.py | 4 +-
src/chat/replyer/private_generator.py | 4 +-
.../expression_auto_check_task.py | 2 +-
.../expression_learner.py | 0
.../expression_review_store.py | 0
.../expression_selector.py | 2 +-
.../expression_utils.py | 0
.../jargon_explainer.py | 0
.../jargon_explainer_old.py | 4 +-
src/{bw_learner => learners}/jargon_miner.py | 0
src/{bw_learner => learners}/learner_utils.py | 0
.../learner_utils_old.py | 0
src/main.py | 2 +-
src/memory_system/memory_retrieval.py | 2 +-
.../retrieval_tools/query_words.py | 2 +-
22 files changed, 248 insertions(+), 93 deletions(-)
create mode 100644 docs/README_CN.md
rename src/{bw_learner => learners}/expression_auto_check_task.py (98%)
rename src/{bw_learner => learners}/expression_learner.py (100%)
rename src/{bw_learner => learners}/expression_review_store.py (100%)
rename src/{bw_learner => learners}/expression_selector.py (99%)
rename src/{bw_learner => learners}/expression_utils.py (100%)
rename src/{bw_learner => learners}/jargon_explainer.py (100%)
rename src/{bw_learner => learners}/jargon_explainer_old.py (99%)
rename src/{bw_learner => learners}/jargon_miner.py (100%)
rename src/{bw_learner => learners}/learner_utils.py (100%)
rename src/{bw_learner => learners}/learner_utils_old.py (100%)
diff --git a/README.md b/README.md
index 7c41c8cf..3f3851b7 100644
--- a/README.md
+++ b/README.md
@@ -1,21 +1,21 @@
-
简体中文 |
English
+
简体中文 |
English
diff --git a/docs/minimal-cross-platform-plan.md b/docs/minimal-cross-platform-plan.md
index d0b6707b..2f0a86bd 100644
--- a/docs/minimal-cross-platform-plan.md
+++ b/docs/minimal-cross-platform-plan.md
@@ -41,7 +41,7 @@ This plan is based on the checked-in code, not on assumptions from previous draf
| `src/person_info/person_info.py:247` | `_is_bot_self(self, platform, user_id)` | Duplicate logic with same QQ fallback |
Wrong-order call sites (8 total):
-- `src/bw_learner/expression_learner.py` x3 (lines 158, 241, 301)
+- `src/learners/expression_learner.py` x3 (lines 158, 241, 301)
- `src/common/utils/utils_message.py` x4 (lines 370, 440, 476, 515)
- `src/webui/routers/chat/support.py` x1 (line 65)
@@ -122,7 +122,7 @@ Make `src/chat/utils/utils.py::is_bot_self(platform, user_id)` the only real imp
- `src/common/utils/system_utils.py`
- `src/chat/utils/utils.py`
- `src/person_info/person_info.py`
-- `src/bw_learner/expression_learner.py`
+- `src/learners/expression_learner.py`
- `src/common/utils/utils_message.py`
- `src/webui/routers/chat/support.py`
- tests
@@ -468,7 +468,7 @@ When stopping, name: the exact file(s), the blocking mismatch, why it is outside
| Phase | Allowed files |
|-------|---------------|
-| Phase 0 | `src/common/utils/system_utils.py`, `src/chat/utils/utils.py`, `src/person_info/person_info.py`, `src/bw_learner/expression_learner.py`, `src/common/utils/utils_message.py`, `src/webui/routers/chat/support.py`, tests (including `pytests/utils_test/message_utils_test.py`) |
+| Phase 0 | `src/common/utils/system_utils.py`, `src/chat/utils/utils.py`, `src/person_info/person_info.py`, `src/learners/expression_learner.py`, `src/common/utils/utils_message.py`, `src/webui/routers/chat/support.py`, tests (including `pytests/utils_test/message_utils_test.py`) |
| Phase 1 | `src/chat/utils/utils.py`, `src/chat/planner_actions/planner.py`, `src/chat/utils/statistic.py`, `src/common/message_repository.py`, `src/webui/routers/chat/support.py`, `src/services/send_service.py`, `src/chat/replyer/group_generator.py`, `src/chat/replyer/private_generator.py`, `src/chat/brain_chat/PFC/message_sender.py`, `src/person_info/person_info.py`, tests |
### INVALID OUTPUT EXAMPLES
diff --git a/src/chat/brain_chat/brain_chat.py b/src/chat/brain_chat/brain_chat.py
index 2b4863ac..1e9e648a 100644
--- a/src/chat/brain_chat/brain_chat.py
+++ b/src/chat/brain_chat/brain_chat.py
@@ -8,8 +8,8 @@ from rich.traceback import install
from src.config.config import global_config
from src.common.logger import get_logger
from src.common.utils.utils_config import ExpressionConfigUtils
-from src.bw_learner.expression_learner import ExpressionLearner
-from src.bw_learner.jargon_miner import JargonMiner
+from src.learners.expression_learner import ExpressionLearner
+from src.learners.jargon_miner import JargonMiner
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.chat.message_receive.message import SessionMessage
diff --git a/src/chat/heart_flow/heartFC_chat - 副本.py b/src/chat/heart_flow/heartFC_chat - 副本.py
index c805597d..02f70281 100644
--- a/src/chat/heart_flow/heartFC_chat - 副本.py
+++ b/src/chat/heart_flow/heartFC_chat - 副本.py
@@ -16,9 +16,9 @@ from src.chat.planner_actions.planner import ActionPlanner
from src.chat.planner_actions.action_modifier import ActionModifier
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.heart_flow.hfc_utils import CycleDetail
-from src.bw_learner.expression_learner import expression_learner_manager
+from src.learners.expression_learner import expression_learner_manager
from src.chat.heart_flow.frequency_control import frequency_control_manager
-from src.bw_learner.message_recorder import extract_and_distribute_messages
+from src.learners.message_recorder import extract_and_distribute_messages
from src.person_info.person_info import Person
from src.plugin_system.base.component_types import EventType, ActionInfo
from src.plugin_system.core import events_manager
diff --git a/src/chat/heart_flow/heartFC_chat.py b/src/chat/heart_flow/heartFC_chat.py
index af0beb4e..2c1eb162 100644
--- a/src/chat/heart_flow/heartFC_chat.py
+++ b/src/chat/heart_flow/heartFC_chat.py
@@ -7,8 +7,8 @@ import traceback
from rich.traceback import install
-from src.bw_learner.expression_learner import ExpressionLearner
-from src.bw_learner.jargon_miner import JargonMiner
+from src.learners.expression_learner import ExpressionLearner
+from src.learners.jargon_miner import JargonMiner
from src.chat.event_helpers import build_event_message
from src.chat.logger.plan_reply_logger import PlanReplyLogger
from src.chat.message_receive.chat_manager import BotChatSession
diff --git a/src/chat/replyer/group_generator.py b/src/chat/replyer/group_generator.py
index 74b324be..e10aa147 100644
--- a/src/chat/replyer/group_generator.py
+++ b/src/chat/replyer/group_generator.py
@@ -27,7 +27,7 @@ from src.services.message_service import (
replace_user_references,
translate_pid_to_description,
)
-from src.bw_learner.expression_selector import expression_selector
+from src.learners.expression_selector import expression_selector
# from src.memory_system.memory_activator import MemoryActivator
from src.person_info.person_info import Person
@@ -36,7 +36,7 @@ from src.services import llm_service as llm_api
from src.chat.logger.plan_reply_logger import PlanReplyLogger
from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt
-from src.bw_learner.jargon_explainer_old import explain_jargon_in_context, retrieve_concepts_with_jargon
+from src.learners.jargon_explainer_old import explain_jargon_in_context, retrieve_concepts_with_jargon
from src.chat.utils.common_utils import TempMethodsExpression
init_memory_retrieval_sys()
diff --git a/src/chat/replyer/private_generator.py b/src/chat/replyer/private_generator.py
index 3b70bb2c..ccd8e1e4 100644
--- a/src/chat/replyer/private_generator.py
+++ b/src/chat/replyer/private_generator.py
@@ -27,13 +27,13 @@ from src.services.message_service import (
replace_user_references,
translate_pid_to_description,
)
-from src.bw_learner.expression_selector import expression_selector
+from src.learners.expression_selector import expression_selector
# from src.memory_system.memory_activator import MemoryActivator
from src.person_info.person_info import Person, is_person_known
from src.core.types import ActionInfo, EventType
from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt
-from src.bw_learner.jargon_explainer_old import explain_jargon_in_context
+from src.learners.jargon_explainer_old import explain_jargon_in_context
init_memory_retrieval_sys()
diff --git a/src/bw_learner/expression_auto_check_task.py b/src/learners/expression_auto_check_task.py
similarity index 98%
rename from src/bw_learner/expression_auto_check_task.py
rename to src/learners/expression_auto_check_task.py
index d90eb4da..53b151b2 100644
--- a/src/bw_learner/expression_auto_check_task.py
+++ b/src/learners/expression_auto_check_task.py
@@ -15,7 +15,7 @@ import random
from sqlmodel import select
-from src.bw_learner.expression_review_store import get_review_state, set_review_state
+from src.learners.expression_review_store import get_review_state, set_review_state
from src.common.database.database import get_db_session
from src.common.database.database_model import Expression
from src.common.logger import get_logger
diff --git a/src/bw_learner/expression_learner.py b/src/learners/expression_learner.py
similarity index 100%
rename from src/bw_learner/expression_learner.py
rename to src/learners/expression_learner.py
diff --git a/src/bw_learner/expression_review_store.py b/src/learners/expression_review_store.py
similarity index 100%
rename from src/bw_learner/expression_review_store.py
rename to src/learners/expression_review_store.py
diff --git a/src/bw_learner/expression_selector.py b/src/learners/expression_selector.py
similarity index 99%
rename from src/bw_learner/expression_selector.py
rename to src/learners/expression_selector.py
index c6cfe469..c96e84cf 100644
--- a/src/bw_learner/expression_selector.py
+++ b/src/learners/expression_selector.py
@@ -9,7 +9,7 @@ from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.prompt.prompt_manager import prompt_manager
-from src.bw_learner.learner_utils_old import weighted_sample
+from src.learners.learner_utils_old import weighted_sample
from src.chat.utils.common_utils import TempMethodsExpression
logger = get_logger("expression_selector")
diff --git a/src/bw_learner/expression_utils.py b/src/learners/expression_utils.py
similarity index 100%
rename from src/bw_learner/expression_utils.py
rename to src/learners/expression_utils.py
diff --git a/src/bw_learner/jargon_explainer.py b/src/learners/jargon_explainer.py
similarity index 100%
rename from src/bw_learner/jargon_explainer.py
rename to src/learners/jargon_explainer.py
diff --git a/src/bw_learner/jargon_explainer_old.py b/src/learners/jargon_explainer_old.py
similarity index 99%
rename from src/bw_learner/jargon_explainer_old.py
rename to src/learners/jargon_explainer_old.py
index 4d144b2c..0cfafa82 100644
--- a/src/bw_learner/jargon_explainer_old.py
+++ b/src/learners/jargon_explainer_old.py
@@ -7,8 +7,8 @@ from src.common.database.database_model import Jargon
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.prompt.prompt_manager import prompt_manager
-from src.bw_learner.jargon_miner_old import search_jargon
-from src.bw_learner.learner_utils_old import (
+from src.learners.jargon_miner_old import search_jargon
+from src.learners.learner_utils_old import (
is_bot_message,
contains_bot_self_name,
parse_chat_id_list,
diff --git a/src/bw_learner/jargon_miner.py b/src/learners/jargon_miner.py
similarity index 100%
rename from src/bw_learner/jargon_miner.py
rename to src/learners/jargon_miner.py
diff --git a/src/bw_learner/learner_utils.py b/src/learners/learner_utils.py
similarity index 100%
rename from src/bw_learner/learner_utils.py
rename to src/learners/learner_utils.py
diff --git a/src/bw_learner/learner_utils_old.py b/src/learners/learner_utils_old.py
similarity index 100%
rename from src/bw_learner/learner_utils_old.py
rename to src/learners/learner_utils_old.py
diff --git a/src/main.py b/src/main.py
index 587c5634..30d1c86d 100644
--- a/src/main.py
+++ b/src/main.py
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING
import asyncio
import time
-from src.bw_learner.expression_auto_check_task import ExpressionAutoCheckTask
+from src.learners.expression_auto_check_task import ExpressionAutoCheckTask
from src.chat.emoji_system.emoji_manager import emoji_manager
from src.chat.knowledge import lpmm_start_up
from src.chat.message_receive.bot import chat_bot
diff --git a/src/memory_system/memory_retrieval.py b/src/memory_system/memory_retrieval.py
index 4193a16a..982db166 100644
--- a/src/memory_system/memory_retrieval.py
+++ b/src/memory_system/memory_retrieval.py
@@ -14,7 +14,7 @@ from src.common.database.database_model import ThinkingQuestion
from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
-from src.bw_learner.jargon_explainer_old import retrieve_concepts_with_jargon
+from src.learners.jargon_explainer_old import retrieve_concepts_with_jargon
logger = get_logger("memory_retrieval")
diff --git a/src/memory_system/retrieval_tools/query_words.py b/src/memory_system/retrieval_tools/query_words.py
index 66fb3c46..ee28b934 100644
--- a/src/memory_system/retrieval_tools/query_words.py
+++ b/src/memory_system/retrieval_tools/query_words.py
@@ -4,7 +4,7 @@
"""
from src.common.logger import get_logger
-from src.bw_learner.jargon_explainer_old import retrieve_concepts_with_jargon
+from src.learners.jargon_explainer_old import retrieve_concepts_with_jargon
from .tool_registry import register_memory_retrieval_tool
logger = get_logger("memory_retrieval_tools")
From 668f41431ab5d88f147964de61b96691ed1a1705 Mon Sep 17 00:00:00 2001
From: UnCLAS-Prommer
Date: Mon, 16 Mar 2026 16:57:14 +0800
Subject: [PATCH 02/45] =?UTF-8?q?=E7=A7=BB=E9=99=A4=E6=97=A0=E7=94=A8?=
=?UTF-8?q?=E9=85=8D=E7=BD=AE?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
src/config/official_configs.py | 23 ++++-------------------
1 file changed, 4 insertions(+), 19 deletions(-)
diff --git a/src/config/official_configs.py b/src/config/official_configs.py
index 6ed5c452..e6907611 100644
--- a/src/config/official_configs.py
+++ b/src/config/official_configs.py
@@ -1615,24 +1615,6 @@ class PluginRuntimeConfig(ConfigBase):
)
"""启用插件系统"""
- builtin_plugin_dir: str = Field(
- default="src/plugins/built_in",
- json_schema_extra={
- "x-widget": "input",
- "x-icon": "folder",
- },
- )
- """内置插件目录(相对于项目根目录)"""
-
- thirdparty_plugin_dir: str = Field(
- default="plugins",
- json_schema_extra={
- "x-widget": "input",
- "x-icon": "folder-open",
- },
- )
- """第三方插件目录(相对于项目根目录)"""
-
health_check_interval_sec: float = Field(
default=30.0,
json_schema_extra={
@@ -1676,4 +1658,7 @@ class PluginRuntimeConfig(ConfigBase):
"x-icon": "link",
},
)
- """_wrap_\n 自定义 IPC Socket 路径(仅 Linux/macOS 生效)\n 留空则自动生成临时路径"""
+ """
+ 自定义 IPC Socket 路径(仅 Linux/macOS 生效)
+ 留空则自动生成临时路径
+ """
From 34190755992cd4f0685619a601f41a381f23aaef Mon Sep 17 00:00:00 2001
From: UnCLAS-Prommer
Date: Mon, 16 Mar 2026 18:18:40 +0800
Subject: [PATCH 03/45] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E6=B3=A8?=
=?UTF-8?q?=E9=87=8A;=E6=B7=BB=E5=8A=A0=E6=97=A5=E5=BF=97=E9=A2=9C?=
=?UTF-8?q?=E8=89=B2=E8=87=AA=E5=AE=9A=E4=B9=89?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
src/plugin_runtime/protocol/envelope.py | 160 +++++++++++++-----------
1 file changed, 90 insertions(+), 70 deletions(-)
diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py
index bcfb2758..56bb4582 100644
--- a/src/plugin_runtime/protocol/envelope.py
+++ b/src/plugin_runtime/protocol/envelope.py
@@ -13,46 +13,38 @@ import logging as stdlib_logging
import time
-# ─── 协议常量 ──────────────────────────────────────────────────────
-
-PROTOCOL_VERSION = "1.0"
-
+# ====== 协议常量 ======
+PROTOCOL_VERSION = "1.0.0"
# 支持的 SDK 版本范围(Host 在握手时校验)
MIN_SDK_VERSION = "1.0.0"
MAX_SDK_VERSION = "1.99.99"
-# ─── 消息类型 ──────────────────────────────────────────────────────
-
-
+# ====== 消息类型 ======
class MessageType(str, Enum):
"""RPC 消息类型"""
REQUEST = "request"
RESPONSE = "response"
- EVENT = "event"
-
-
-# ─── 请求 ID 生成器 ───────────────────────────────────────────────
+ BROADCAST = "broadcast"
+# ====== 请求 ID 生成器 ======
class RequestIdGenerator:
- """单调递增 int64 请求 ID 生成器(线程安全由调用方保证或使用 asyncio)"""
+ """单调递增 int64 请求 ID 生成器"""
def __init__(self, start: int = 1) -> None:
self._counter = start
- def next(self) -> int:
+ async def next(self) -> int:
current = self._counter
self._counter += 1
return current
-# ─── Envelope 模型 ─────────────────────────────────────────────────
-
-
+# ====== Envelope 模型 ======
class Envelope(BaseModel):
- """RPC 统一信封
+ """RPC 统一消息封装
所有 Host <-> Runner 消息均封装为此格式。
序列化流程:Envelope -> .model_dump() -> MsgPack encode
@@ -60,15 +52,25 @@ class Envelope(BaseModel):
"""
protocol_version: str = Field(default=PROTOCOL_VERSION, description="协议版本")
+ """协议版本"""
request_id: int = Field(description="单调递增请求 ID")
+ """单调递增请求 ID"""
message_type: MessageType = Field(description="消息类型")
+ """消息类型"""
method: str = Field(default="", description="RPC 方法名")
+ """RPC 方法名"""
plugin_id: str = Field(default="", description="目标插件 ID")
- timestamp_ms: int = Field(default_factory=lambda: int(time.time() * 1000), description="发送时间戳(ms)")
- timeout_ms: int = Field(default=30000, description="相对超时(ms)")
+ """目标插件 ID"""
+ timestamp_ms: int = Field(default_factory=lambda: int(time.time() * 1000), description="发送时间戳 (ms)")
+ """发送时间戳 (ms)"""
+ timeout_ms: int = Field(default=30000, description="相对超时 (ms)")
+ """相对超时 (ms)"""
generation: int = Field(default=0, description="Runner generation 编号")
+ """Runner generation 编号"""
payload: Dict[str, Any] = Field(default_factory=dict, description="业务数据")
- error: Optional[Dict[str, Any]] = Field(default=None, description="错误信息(仅 response)")
+ """业务数据"""
+ error: Optional[Dict[str, Any]] = Field(default=None, description="错误信息 (仅 response)")
+ """错误信息 (仅 response)"""
def is_request(self) -> bool:
return self.message_type == MessageType.REQUEST
@@ -76,8 +78,8 @@ class Envelope(BaseModel):
def is_response(self) -> bool:
return self.message_type == MessageType.RESPONSE
- def is_event(self) -> bool:
- return self.message_type == MessageType.EVENT
+ def is_broadcast(self) -> bool:
+ return self.message_type == MessageType.BROADCAST
def make_response(
self, payload: Optional[Dict[str, Any]] = None, error: Optional[Dict[str, Any]] = None
@@ -105,153 +107,172 @@ class Envelope(BaseModel):
)
-# ─── 握手消息 ──────────────────────────────────────────────────────
-
-
+# ====== 握手请求与响应 ======
class HelloPayload(BaseModel):
"""runner.hello 握手请求 payload"""
runner_id: str = Field(description="Runner 进程唯一标识")
+ """Runner 进程唯一标识"""
sdk_version: str = Field(description="SDK 版本号")
+ """SDK 版本号"""
session_token: str = Field(description="一次性会话令牌")
+ """一次性会话令牌"""
class HelloResponsePayload(BaseModel):
"""runner.hello 握手响应 payload"""
accepted: bool = Field(description="是否接受连接")
+ """是否接受连接"""
host_version: str = Field(default="", description="Host 版本号")
+ """Host 版本号"""
assigned_generation: int = Field(default=0, description="分配的 generation 编号")
- reason: str = Field(default="", description="拒绝原因(若 accepted=False)")
-
-
-# ─── 组件注册消息 ──────────────────────────────────────────────────
+ """分配的 generation 编号"""
+ reason: str = Field(default="", description="拒绝原因 (若 accepted=False)")
+ """拒绝原因 (若 `accepted`=`False`)"""
+# ====== 组件注册消息 ======
class ComponentDeclaration(BaseModel):
"""单个组件声明"""
name: str = Field(description="组件名称")
- component_type: str = Field(description="组件类型: action/command/tool/event_handler")
+ """组件名称"""
+ component_type: str = Field(
+ description="组件类型:action/command/tool/event_handler/workflow_handler/message_gateway"
+ )
+ """组件类型:`action`/`command`/`tool`/`event_handler`/`workflow_handler`/`message_gateway`"""
plugin_id: str = Field(description="所属插件 ID")
+ """所属插件 ID"""
metadata: Dict[str, Any] = Field(default_factory=dict, description="组件元数据")
+ """组件元数据"""
-class RegisterComponentsPayload(BaseModel):
- """plugin.register_components 请求 payload"""
+class RegisterPluginPayload(BaseModel):
+ """plugin.register_plugin 请求 payload"""
plugin_id: str = Field(description="插件 ID")
+ """插件 ID"""
plugin_version: str = Field(default="1.0.0", description="插件版本")
+ """插件版本"""
components: List[ComponentDeclaration] = Field(default_factory=list, description="组件列表")
+ """组件列表"""
capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表")
+ """所需能力列表"""
class BootstrapPluginPayload(BaseModel):
"""plugin.bootstrap 请求 payload"""
plugin_id: str = Field(description="插件 ID")
+ """插件 ID"""
plugin_version: str = Field(default="1.0.0", description="插件版本")
+ """插件版本"""
capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表")
+ """所需能力列表"""
-# ─── 调用消息 ──────────────────────────────────────────────────────
-
-
+# ====== 插件调用请求和响应 ======
class InvokePayload(BaseModel):
- """plugin.invoke_* 请求 payload"""
+ """plugin.invoke.* 请求 payload"""
component_name: str = Field(description="要调用的组件名称")
+ """要调用的组件名称"""
args: Dict[str, Any] = Field(default_factory=dict, description="调用参数")
+ """调用参数"""
class InvokeResultPayload(BaseModel):
- """plugin.invoke_* 响应 payload"""
+ """plugin.invoke.* 响应 payload"""
success: bool = Field(description="是否成功")
+ """是否成功"""
result: Any = Field(default=None, description="返回值")
+ """返回值"""
-# ─── 能力调用消息 ──────────────────────────────────────────────────
-
-
+# ====== 能力调用消息 ======
class CapabilityRequestPayload(BaseModel):
"""cap.* 请求 payload(插件 -> Host 能力调用)"""
capability: str = Field(description="能力名称,如 send.text, db.query")
+ """能力名称,如 send.text, db.query"""
args: Dict[str, Any] = Field(default_factory=dict, description="调用参数")
+ """调用参数"""
class CapabilityResponsePayload(BaseModel):
"""cap.* 响应 payload"""
success: bool = Field(description="是否成功")
+ """是否成功"""
result: Any = Field(default=None, description="返回值")
+ """返回值"""
-# ─── 健康检查 ──────────────────────────────────────────────────────
-
-
+# ====== 健康检查 ======
class HealthPayload(BaseModel):
"""plugin.health 响应 payload"""
healthy: bool = Field(description="是否健康")
+ """是否健康"""
loaded_plugins: List[str] = Field(default_factory=list, description="已加载的插件列表")
- uptime_ms: int = Field(default=0, description="运行时长(ms)")
+ """已加载的插件列表"""
+ uptime_ms: int = Field(default=0, description="运行时长 (ms)")
+ """运行时长 (ms)"""
class RunnerReadyPayload(BaseModel):
"""runner.ready 请求 payload"""
loaded_plugins: List[str] = Field(default_factory=list, description="已完成初始化的插件列表")
+ """已完成初始化的插件列表"""
failed_plugins: List[str] = Field(default_factory=list, description="初始化失败的插件列表")
+ """初始化失败的插件列表"""
-# ─── 配置更新 ──────────────────────────────────────────────────────
-
-
-# Host 侧现已支持配置更新推送:
-# - 总配置热重载完成后,PluginRuntimeManager 会向已加载插件推送配置更新事件。
-# - 插件目录下的 config.toml 变化由现有 FileWatcher 监听并转发为 plugin.config_updated。
+# ====== 配置更新 ======
class ConfigUpdatedPayload(BaseModel):
"""plugin.config_updated 事件 payload"""
plugin_id: str = Field(description="插件 ID")
+ """插件 ID"""
config_version: str = Field(description="新配置版本")
+ """新配置版本"""
config_data: Dict[str, Any] = Field(default_factory=dict, description="配置内容")
+ """配置内容"""
-# ─── 关停 ──────────────────────────────────────────────────────────
-
-
+# ====== 关停 ======
class ShutdownPayload(BaseModel):
"""plugin.shutdown / plugin.prepare_shutdown payload"""
reason: str = Field(default="normal", description="关停原因")
- drain_timeout_ms: int = Field(default=5000, description="排空超时(ms)")
+ """关停原因"""
+ drain_timeout_ms: int = Field(default=5000, description="排空超时 (ms)")
+ """排空超时 (ms)"""
-# ─── 日志传输 ──────────────────────────────────────────────────────
+# ====== 日志传输 ======
class LogEntry(BaseModel):
"""单条日志记录(Runner → Host 传输格式)"""
- timestamp_ms: int = Field(
- description="日志时间戳,Unix epoch 毫秒",
- )
- level: int = Field(
- description=("stdlib logging 整数级别: 10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL"),
- )
- logger_name: str = Field(
- description="Logger 名称,如 plugin.my_plugin.submodule",
- )
- message: str = Field(
- description="经 Formatter 格式化后的完整日志消息(含 exc_info 文本)",
- )
+ timestamp_ms: int = Field(description="日志时间戳,Unix epoch 毫秒")
+ """日志时间戳,Unix epoch 毫秒"""
+ level: int = Field(description="stdlib logging 整数级别:10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL")
+ """stdlib logging 整数级别:10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL"""
+ logger_name: str = Field(description="Logger 名称,如 plugin.my_plugin.submodule")
+ """Logger 名称,如 plugin.my_plugin.submodule"""
+ message: str = Field(description="经 Formatter 格式化后的完整日志消息(含 exc_info 文本)")
+ """经 Formatter 格式化后的完整日志消息(含 exc_info 文本)"""
exception_text: str = Field(
default="",
description="原始异常摘要(exc_text),供结构化消费;已嵌入 message 中",
)
+ """原始异常摘要(exc_text),供结构化消费;已嵌入 message 中"""
+ log_color_in_hex: Optional[str] = Field(default=None, description="日志颜色的十六进制字符串(如 #RRGGBB)")
@property
def levelname(self) -> str:
@@ -262,6 +283,5 @@ class LogEntry(BaseModel):
class LogBatchPayload(BaseModel):
"""runner.log_batch 事件 payload:Runner 端向 Host 批量推送日志记录"""
- entries: List[LogEntry] = Field(
- description="本批次日志记录列表,按时间升序排列",
- )
+ entries: List[LogEntry] = Field(description="本批次日志记录列表,按时间升序排列")
+ """本批次日志记录列表,按时间升序排列"""
From e1b2ecb5b13a01b41e5405e9c805b7add664a4b3 Mon Sep 17 00:00:00 2001
From: UnCLAS-Prommer
Date: Mon, 16 Mar 2026 19:37:58 +0800
Subject: [PATCH 04/45] =?UTF-8?q?fix:=20(AI)=20=E6=9B=B4robust=E7=9A=84?=
=?UTF-8?q?=E4=BC=A0=E8=BE=93?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
src/plugin_runtime/transport/named_pipe.py | 91 ++++++++++++++++++----
src/plugin_runtime/transport/uds.py | 62 +++++++++++++--
2 files changed, 129 insertions(+), 24 deletions(-)
diff --git a/src/plugin_runtime/transport/named_pipe.py b/src/plugin_runtime/transport/named_pipe.py
index a759507d..7fd39bc9 100644
--- a/src/plugin_runtime/transport/named_pipe.py
+++ b/src/plugin_runtime/transport/named_pipe.py
@@ -1,6 +1,9 @@
"""Windows Named Pipe 传输实现。
适用于 Windows 平台,使用 asyncio ProactorEventLoop 的 named pipe 支持。
+
+注意:Named Pipe 是 Windows 特有的 IPC 机制,
+在 Linux/macOS 平台上不可用。Unix-like 平台请使用 UDS 传输。
"""
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, cast
@@ -18,10 +21,12 @@ _DEFAULT_PIPE_PREFIX = "maibot-plugin"
class _NamedPipeServerHandle(Protocol):
+ """Named Pipe 服务端句柄的协议定义。"""
def close(self) -> None: ...
class _NamedPipeEventLoop(Protocol):
+ """ProactorEventLoop 的协议定义,提供 named pipe 相关方法。"""
async def start_serving_pipe(
self,
protocol_factory: Callable[[], asyncio.BaseProtocol],
@@ -40,6 +45,15 @@ class _NamedPipeEventLoop(Protocol):
def _normalize_pipe_address(pipe_name: Optional[str] = None) -> str:
+ """规范化 Named Pipe 地址。
+
+ Args:
+ pipe_name: 管道名称。如果以 '\\\\.\\pipe\\' 开头则直接使用,
+ 否则会自动添加前缀。如果为 None 则生成随机名称。
+
+ Returns:
+ 规范化的管道地址(格式:\\\\.\\pipe\\name)
+ """
if pipe_name and pipe_name.startswith(_PIPE_PREFIX):
return pipe_name
@@ -55,12 +69,21 @@ def _normalize_pipe_address(pipe_name: Optional[str] = None) -> str:
class NamedPipeConnection(Connection):
- """基于 Windows Named Pipe 的连接。"""
+ """基于 Windows Named Pipe 的连接。
+
+ 封装了底层 StreamReader/StreamWriter,提供分帧读写能力。
+ """
- pass
+ def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
+ super().__init__(reader, writer)
class _NamedPipeServerProtocol(asyncio.StreamReaderProtocol):
+ """Named Pipe 服务端协议实现。
+
+ 处理客户端连接的生命周期,包括连接建立、数据处理和连接关闭。
+ """
+
def __init__(self, handler: ConnectionHandler, loop: asyncio.AbstractEventLoop) -> None:
self._reader: asyncio.StreamReader = asyncio.StreamReader()
super().__init__(self._reader)
@@ -69,39 +92,58 @@ class _NamedPipeServerProtocol(asyncio.StreamReaderProtocol):
self._handler_task: Optional[asyncio.Task[None]] = None
def connection_made(self, transport: asyncio.BaseTransport) -> None:
+ """连接建立时的回调。"""
super().connection_made(transport)
writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), self, self._reader, self._loop)
connection = NamedPipeConnection(self._reader, writer)
- self._handler_task = self._loop.create_task(self._run_handler(connection))
+ # 使用 asyncio.create_task 确保任务正确调度
+ self._handler_task = asyncio.create_task(self._run_handler(connection))
self._handler_task.add_done_callback(self._on_handler_done)
async def _run_handler(self, connection: NamedPipeConnection) -> None:
+ """运行连接处理器。"""
try:
await self._handler(connection)
finally:
await connection.close()
def _on_handler_done(self, task: asyncio.Task[None]) -> None:
+ """连接处理器完成时的回调。"""
if task.cancelled():
return
if exc := task.exception():
- self._loop.call_exception_handler(
- {
- "message": "Named pipe 连接处理失败",
- "exception": exc,
- "protocol": self,
- }
- )
+ try:
+ self._loop.call_exception_handler(
+ {
+ "message": "Named pipe 连接处理失败",
+ "exception": exc,
+ "protocol": self,
+ }
+ )
+ except Exception:
+ # 如果 loop 已经关闭,忽略异常
+ pass
class NamedPipeTransportServer(TransportServer):
- """Windows Named Pipe 传输服务端。"""
+ """Windows Named Pipe 传输服务端。
+
+ 使用 ProactorEventLoop 的 start_serving_pipe 方法监听客户端连接。
+ """
def __init__(self, pipe_name: Optional[str] = None) -> None:
self._address: str = _normalize_pipe_address(pipe_name)
self._servers: List[_NamedPipeServerHandle] = []
async def start(self, handler: ConnectionHandler) -> None:
+ """启动 Named Pipe 服务端。
+
+ Args:
+ handler: 新连接到来时的回调函数
+
+ Raises:
+ RuntimeError: 当在非 Windows 平台或事件循环不支持时
+ """
if sys.platform != "win32":
raise RuntimeError("Named pipe 仅支持 Windows")
@@ -116,32 +158,49 @@ class NamedPipeTransportServer(TransportServer):
)
async def stop(self) -> None:
+ """停止 Named Pipe 服务端并清理资源。"""
for server in self._servers:
server.close()
+ # 等待所有服务器句柄完全关闭
+ await asyncio.gather(
+ *[asyncio.sleep(0.1) for _ in self._servers],
+ return_exceptions=True
+ )
self._servers.clear()
- await asyncio.sleep(0)
def get_address(self) -> str:
return self._address
class NamedPipeTransportClient(TransportClient):
- """Windows Named Pipe 传输客户端。"""
+ """Windows Named Pipe 传输客户端。
+
+ 用于主动连接到 Named Pipe 服务端。
+ """
def __init__(self, address: str) -> None:
self._address: str = _normalize_pipe_address(address)
async def connect(self) -> Connection:
+ """建立到 Named Pipe 服务端的连接。
+
+ Returns:
+ NamedPipeConnection: 连接对象
+
+ Raises:
+ NotImplementedError: 当在非 Windows 平台或事件循环不支持时
+ """
if sys.platform != "win32":
- raise RuntimeError("Named pipe 仅支持 Windows")
+ raise NotImplementedError("Named pipe 仅支持 Windows")
loop = asyncio.get_running_loop()
if not hasattr(loop, "create_pipe_connection"):
- raise RuntimeError("当前事件循环不支持 Windows named pipe")
+ raise NotImplementedError("当前事件循环不支持 Windows named pipe")
pipe_loop = cast(_NamedPipeEventLoop, loop)
reader = asyncio.StreamReader()
protocol = asyncio.StreamReaderProtocol(reader)
transport, _protocol = await pipe_loop.create_pipe_connection(lambda: protocol, self._address)
- writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), protocol, reader, loop)
+ # 使用返回的 protocol 创建 StreamWriter
+ writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), _protocol, reader, loop)
return NamedPipeConnection(reader, writer)
\ No newline at end of file
diff --git a/src/plugin_runtime/transport/uds.py b/src/plugin_runtime/transport/uds.py
index 47bf033b..af71ea5d 100644
--- a/src/plugin_runtime/transport/uds.py
+++ b/src/plugin_runtime/transport/uds.py
@@ -1,6 +1,9 @@
"""Unix Domain Socket 传输实现
适用于 Linux / macOS 平台。
+
+注意:UDS (Unix Domain Socket) 是 Unix-like 系统特有的 IPC 机制,
+在 Windows 平台上不可用。Windows 平台请使用 Named Pipe 传输。
"""
from pathlib import Path
@@ -8,20 +11,30 @@ from typing import Optional
import asyncio
import os
+import sys
import tempfile
from .base import Connection, ConnectionHandler, TransportClient, TransportServer
class UDSConnection(Connection):
- """基于 UDS 的连接"""
+ """基于 UDS 的连接
+
+ 封装了底层 StreamReader/StreamWriter,提供分帧读写能力。
+ """
- pass # 直接复用 Connection 基类的分帧读写
+ def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
+ super().__init__(reader, writer)
# Unix domain socket 路径的系统限制(sun_path 字段长度)
-# Linux: 108 字节, macOS: 104 字节
-_UDS_PATH_MAX = 104
+# Linux: 108 字节,macOS: 104 字节,其他 Unix: 通常 104 字节
+if sys.platform == "linux":
+ _UDS_PATH_MAX = 108
+elif sys.platform == "darwin": # macOS
+ _UDS_PATH_MAX = 104
+else:
+ _UDS_PATH_MAX = 104 # 保守默认值
class UDSTransportServer(TransportServer):
@@ -44,6 +57,18 @@ class UDSTransportServer(TransportServer):
self._server: Optional[asyncio.AbstractServer] = None
async def start(self, handler: ConnectionHandler) -> None:
+ """启动 UDS 服务端
+
+ Args:
+ handler: 新连接到来时的回调函数
+
+ Raises:
+ RuntimeError: 当在非 Unix 平台(如 Windows)上调用时
+ """
+ # 平台检查:UDS 仅在 Unix-like 系统上可用
+ if sys.platform == "win32":
+ raise RuntimeError("UDS 不支持 Windows 平台,请使用 Named Pipe")
+
# 清理残留 socket 文件
if self._socket_path.exists():
self._socket_path.unlink()
@@ -58,10 +83,16 @@ class UDSTransportServer(TransportServer):
finally:
await conn.close()
- self._server = await asyncio.start_unix_server(_on_connect, path=str(self._socket_path))
+ try:
+ self._server = await asyncio.start_unix_server(_on_connect, path=str(self._socket_path))
- # 设置文件权限为仅当前用户可访问
- self._socket_path.chmod(0o600)
+ # 设置文件权限为仅当前用户可访问
+ self._socket_path.chmod(0o600)
+ except Exception:
+ # 启动失败时清理可能创建的目录和 socket 文件
+ if self._socket_path.exists():
+ self._socket_path.unlink()
+ raise
async def stop(self) -> None:
if self._server:
@@ -77,11 +108,26 @@ class UDSTransportServer(TransportServer):
class UDSTransportClient(TransportClient):
- """UDS 传输客户端"""
+ """UDS 传输客户端
+
+ 用于主动连接到 UDS 服务端。
+ """
def __init__(self, socket_path: Path) -> None:
self._socket_path: Path = socket_path
async def connect(self) -> Connection:
+ """建立到 UDS 服务端的连接
+
+ Returns:
+ UDSConnection: 连接对象
+
+ Raises:
+ RuntimeError: 当在非 Unix 平台(如 Windows)上调用时
+ """
+ # 平台检查:UDS 仅在 Unix-like 系统上可用
+ if sys.platform == "win32":
+ raise RuntimeError("UDS 不支持 Windows 平台,请使用 Named Pipe")
+
reader, writer = await asyncio.open_unix_connection(str(self._socket_path))
return UDSConnection(reader, writer)
From 49b620219de2e1333e1107aead43deda9e33d0dc Mon Sep 17 00:00:00 2001
From: UnCLAS-Prommer
Date: Tue, 17 Mar 2026 01:30:31 +0800
Subject: [PATCH 05/45] =?UTF-8?q?refcator:=20=E9=87=8D=E5=91=BD=E5=90=8Dpo?=
=?UTF-8?q?licy=E4=B8=BAauthorization;=E7=A7=BB=E9=99=A4envelope=E7=9A=84g?=
=?UTF-8?q?eneration(runner=E4=B8=8D=E5=86=8D=E9=87=8D=E8=BD=BD);?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
src/plugin_runtime/host/authorization.py | 61 ++++++++++++
src/plugin_runtime/host/capability_service.py | 14 +--
src/plugin_runtime/host/policy_engine.py | 97 -------------------
src/plugin_runtime/protocol/envelope.py | 5 -
4 files changed, 69 insertions(+), 108 deletions(-)
create mode 100644 src/plugin_runtime/host/authorization.py
delete mode 100644 src/plugin_runtime/host/policy_engine.py
diff --git a/src/plugin_runtime/host/authorization.py b/src/plugin_runtime/host/authorization.py
new file mode 100644
index 00000000..d746c4d2
--- /dev/null
+++ b/src/plugin_runtime/host/authorization.py
@@ -0,0 +1,61 @@
+"""授权管理器
+
+负责管理插件的能力授权以及校验
+每个插件在 manifest 中声明能力需求,Host 启动时签发能力令牌。
+"""
+
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Set, Tuple
+
+
+@dataclass
+class CapabilityPermissionToken:
+ """能力令牌"""
+
+ plugin_id: str
+ capabilities: Set[str] = field(default_factory=set)
+
+
+class AuthorizationManager:
+ """授权管理器
+
+ 管理所有插件的能力令牌,提供授权校验。
+ """
+
+ def __init__(self) -> None:
+ self._permission_tokens: Dict[str, CapabilityPermissionToken] = {}
+
+ def register_plugin(self, plugin_id: str, capabilities: List[str]) -> CapabilityPermissionToken:
+ """为插件签发能力令牌"""
+ token = CapabilityPermissionToken(plugin_id=plugin_id, capabilities=set(capabilities))
+ self._permission_tokens[plugin_id] = token
+ return token
+
+ def revoke_permission_token(self, plugin_id: str):
+ """移除插件的能力令牌。"""
+ self._permission_tokens.pop(plugin_id, None)
+
+ def clear(self) -> None:
+ """清空所有能力令牌。"""
+ self._permission_tokens.clear()
+
+ def check_capability(self, plugin_id: str, capability: str) -> Tuple[bool, str]:
+ """检查插件是否有权调用某项能力
+
+ Returns:
+ return (bool, str): (是否有此能力, 原因)
+ """
+ token = self._permission_tokens.get(plugin_id)
+ if not token:
+ return False, f"插件 {plugin_id} 未注册能力令牌"
+ if capability not in token.capabilities:
+ return False, f"插件 {plugin_id} 未获授权能力: {capability}"
+ return True, ""
+
+ def get_token(self, plugin_id: str) -> Optional[CapabilityPermissionToken]:
+ """获取插件的能力令牌"""
+ return self._permission_tokens.get(plugin_id)
+
+ def list_plugins(self) -> List[str]:
+ """列出所有已注册的插件"""
+ return list(self._permission_tokens.keys())
diff --git a/src/plugin_runtime/host/capability_service.py b/src/plugin_runtime/host/capability_service.py
index 6685ff60..e0c56c2b 100644
--- a/src/plugin_runtime/host/capability_service.py
+++ b/src/plugin_runtime/host/capability_service.py
@@ -4,10 +4,9 @@ Host 端实现的能力服务,处理来自插件的 cap.* 请求。
每个能力方法被注册到 RPC Server,接收 Runner 转发的请求并执行实际操作。
"""
-from typing import Any, Awaitable, Callable, Dict, List
+from typing import Any, Callable, Dict, List, Coroutine, TYPE_CHECKING
from src.common.logger import get_logger
-from src.plugin_runtime.host.policy_engine import PolicyEngine
from src.plugin_runtime.protocol.envelope import (
CapabilityRequestPayload,
CapabilityResponsePayload,
@@ -15,10 +14,13 @@ from src.plugin_runtime.protocol.envelope import (
)
from src.plugin_runtime.protocol.errors import ErrorCode, RPCError
+if TYPE_CHECKING:
+ from src.plugin_runtime.host.authorization import AuthorizationManager
+
logger = get_logger("plugin_runtime.host.capability_service")
# 能力实现函数类型: (plugin_id, capability, args) -> result
-CapabilityImpl = Callable[[str, str, Dict[str, Any]], Awaitable[Any]]
+CapabilityImpl = Callable[[str, str, Dict[str, Any]], Coroutine[Any, Any, Any]]
class CapabilityService:
@@ -31,8 +33,8 @@ class CapabilityService:
4. 执行实际操作并返回结果
"""
- def __init__(self, policy_engine: PolicyEngine) -> None:
- self._policy = policy_engine
+ def __init__(self, authorization: "AuthorizationManager") -> None:
+ self._authorization = authorization
# capability_name -> implementation
self._implementations: Dict[str, CapabilityImpl] = {}
@@ -65,7 +67,7 @@ class CapabilityService:
capability = req.capability
# 1. 权限校验
- allowed, reason = self._policy.check_capability(plugin_id, capability, envelope.generation)
+ allowed, reason = self._authorization.check_capability(plugin_id, capability)
if not allowed:
error_code = (
ErrorCode.E_GENERATION_MISMATCH if "generation 不匹配" in reason else ErrorCode.E_CAPABILITY_DENIED
diff --git a/src/plugin_runtime/host/policy_engine.py b/src/plugin_runtime/host/policy_engine.py
deleted file mode 100644
index 61b32480..00000000
--- a/src/plugin_runtime/host/policy_engine.py
+++ /dev/null
@@ -1,97 +0,0 @@
-"""策略引擎
-
-负责能力授权校验。
-每个插件在 manifest 中声明能力需求,Host 启动时签发能力令牌。
-"""
-
-from dataclasses import dataclass, field
-from typing import Dict, List, Optional, Set, Tuple
-
-
-@dataclass
-class CapabilityToken:
- """能力令牌"""
-
- plugin_id: str
- generation: int
- capabilities: Set[str] = field(default_factory=set)
-
-
-class PolicyEngine:
- """策略引擎
-
- 管理所有插件的能力令牌,提供授权校验。
- """
-
- def __init__(self) -> None:
- self._tokens: Dict[str, Dict[int, CapabilityToken]] = {}
-
- def register_plugin(
- self,
- plugin_id: str,
- generation: int,
- capabilities: List[str],
- ) -> CapabilityToken:
- """为插件签发能力令牌"""
- token = CapabilityToken(
- plugin_id=plugin_id,
- generation=generation,
- capabilities=set(capabilities),
- )
- self._tokens.setdefault(plugin_id, {})[generation] = token
- return token
-
- def revoke_plugin(self, plugin_id: str, generation: Optional[int] = None) -> None:
- """撤销插件的能力令牌。"""
- if generation is None:
- self._tokens.pop(plugin_id, None)
- return
-
- generations = self._tokens.get(plugin_id)
- if generations is None:
- return
-
- generations.pop(generation, None)
- if not generations:
- self._tokens.pop(plugin_id, None)
-
- def clear(self) -> None:
- """清空所有能力令牌。"""
- self._tokens.clear()
-
- def check_capability(self, plugin_id: str, capability: str, generation: Optional[int] = None) -> Tuple[bool, str]:
- """检查插件是否有权调用某项能力
-
- Returns:
- (allowed, reason)
- """
- generations = self._tokens.get(plugin_id)
- if not generations:
- return False, f"插件 {plugin_id} 未注册能力令牌"
-
- if generation is None:
- token = generations[max(generations)]
- else:
- token = generations.get(generation)
- if token is None:
- active_generation = max(generations)
- return False, f"插件 {plugin_id} generation 不匹配: {generation} != {active_generation}"
-
- if capability not in token.capabilities:
- return False, f"插件 {plugin_id} 未获授权能力: {capability}"
-
- if generation is not None and token.generation != generation:
- return False, f"插件 {plugin_id} generation 不匹配: {generation} != {token.generation}"
-
- return True, ""
-
- def get_token(self, plugin_id: str) -> Optional[CapabilityToken]:
- """获取插件的能力令牌"""
- generations = self._tokens.get(plugin_id)
- if not generations:
- return None
- return generations[max(generations)]
-
- def list_plugins(self) -> List[str]:
- """列出所有已注册的插件"""
- return list(self._tokens.keys())
diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py
index 56bb4582..ca9e8005 100644
--- a/src/plugin_runtime/protocol/envelope.py
+++ b/src/plugin_runtime/protocol/envelope.py
@@ -65,8 +65,6 @@ class Envelope(BaseModel):
"""发送时间戳 (ms)"""
timeout_ms: int = Field(default=30000, description="相对超时 (ms)")
"""相对超时 (ms)"""
- generation: int = Field(default=0, description="Runner generation 编号")
- """Runner generation 编号"""
payload: Dict[str, Any] = Field(default_factory=dict, description="业务数据")
"""业务数据"""
error: Optional[Dict[str, Any]] = Field(default=None, description="错误信息 (仅 response)")
@@ -91,7 +89,6 @@ class Envelope(BaseModel):
message_type=MessageType.RESPONSE,
method=self.method,
plugin_id=self.plugin_id,
- generation=self.generation,
payload=payload or {},
error=error,
)
@@ -126,8 +123,6 @@ class HelloResponsePayload(BaseModel):
"""是否接受连接"""
host_version: str = Field(default="", description="Host 版本号")
"""Host 版本号"""
- assigned_generation: int = Field(default=0, description="分配的 generation 编号")
- """分配的 generation 编号"""
reason: str = Field(default="", description="拒绝原因 (若 accepted=False)")
"""拒绝原因 (若 `accepted`=`False`)"""
From 84a6524bd9c8ed66eca35801b0d952b552194148 Mon Sep 17 00:00:00 2001
From: UnCLAS-Prommer
Date: Tue, 17 Mar 2026 20:00:19 +0800
Subject: [PATCH 06/45] =?UTF-8?q?refactor:=20=E7=A7=BB=E9=99=A4generation;?=
=?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=96=B0=E7=9A=84ErrorCode;=E4=BF=AE?=
=?UTF-8?q?=E6=94=B9ErrorCode=E7=9A=84=E4=B8=80=E4=B8=AA=E5=90=8D=E7=A7=B0?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
src/plugin_runtime/host/authorization.py | 1 +
src/plugin_runtime/host/capability_service.py | 29 +-
src/plugin_runtime/host/rpc_server.py | 444 +++++-------------
src/plugin_runtime/protocol/errors.py | 16 +-
4 files changed, 138 insertions(+), 352 deletions(-)
diff --git a/src/plugin_runtime/host/authorization.py b/src/plugin_runtime/host/authorization.py
index d746c4d2..3fb48c6a 100644
--- a/src/plugin_runtime/host/authorization.py
+++ b/src/plugin_runtime/host/authorization.py
@@ -40,6 +40,7 @@ class AuthorizationManager:
self._permission_tokens.clear()
def check_capability(self, plugin_id: str, capability: str) -> Tuple[bool, str]:
+ # sourcery skip: assign-if-exp, reintroduce-else, swap-if-else-branches, use-named-expression
"""检查插件是否有权调用某项能力
Returns:
diff --git a/src/plugin_runtime/host/capability_service.py b/src/plugin_runtime/host/capability_service.py
index e0c56c2b..98366a07 100644
--- a/src/plugin_runtime/host/capability_service.py
+++ b/src/plugin_runtime/host/capability_service.py
@@ -7,11 +7,7 @@ Host 端实现的能力服务,处理来自插件的 cap.* 请求。
from typing import Any, Callable, Dict, List, Coroutine, TYPE_CHECKING
from src.common.logger import get_logger
-from src.plugin_runtime.protocol.envelope import (
- CapabilityRequestPayload,
- CapabilityResponsePayload,
- Envelope,
-)
+from src.plugin_runtime.protocol.envelope import CapabilityRequestPayload, CapabilityResponsePayload, Envelope
from src.plugin_runtime.protocol.errors import ErrorCode, RPCError
if TYPE_CHECKING:
@@ -59,31 +55,19 @@ class CapabilityService:
try:
req = CapabilityRequestPayload.model_validate(envelope.payload)
except Exception as e:
- return envelope.make_error_response(
- ErrorCode.E_BAD_PAYLOAD.value,
- f"能力调用 payload 格式错误: {e}",
- )
+ return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, f"能力调用 payload 格式错误: {e}")
capability = req.capability
# 1. 权限校验
allowed, reason = self._authorization.check_capability(plugin_id, capability)
if not allowed:
- error_code = (
- ErrorCode.E_GENERATION_MISMATCH if "generation 不匹配" in reason else ErrorCode.E_CAPABILITY_DENIED
- )
- return envelope.make_error_response(
- error_code.value,
- reason,
- )
+ return envelope.make_error_response(ErrorCode.E_CAPABILITY_DENIED.value, reason)
# 2. 查找实现
impl = self._implementations.get(capability)
if impl is None:
- return envelope.make_error_response(
- ErrorCode.E_METHOD_NOT_ALLOWED.value,
- f"未注册的能力: {capability}",
- )
+ return envelope.make_error_response(ErrorCode.E_METHOD_NOT_ALLOWED.value, f"未注册的能力: {capability}")
# 3. 执行
try:
@@ -94,10 +78,7 @@ class CapabilityService:
return envelope.make_error_response(e.code.value, e.message, e.details)
except Exception as e:
logger.error(f"能力 {capability} 执行异常: {e}", exc_info=True)
- return envelope.make_error_response(
- ErrorCode.E_CAPABILITY_FAILED.value,
- str(e),
- )
+ return envelope.make_error_response(ErrorCode.E_CAPABILITY_FAILED.value, str(e))
def list_capabilities(self) -> List[str]:
"""列出所有已注册的能力"""
diff --git a/src/plugin_runtime/host/rpc_server.py b/src/plugin_runtime/host/rpc_server.py
index 79fe0d9a..75ef9b2a 100644
--- a/src/plugin_runtime/host/rpc_server.py
+++ b/src/plugin_runtime/host/rpc_server.py
@@ -7,7 +7,7 @@
4. 请求-响应关联与超时管理
"""
-from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Coroutine
import asyncio
import contextlib
@@ -32,7 +32,7 @@ from src.plugin_runtime.transport.base import Connection, TransportServer
logger = get_logger("plugin_runtime.host.rpc_server")
# RPC 方法处理器类型
-MethodHandler = Callable[[Envelope], Awaitable[Envelope]]
+MethodHandler = Callable[[Envelope], Coroutine[Any, Any, Envelope]]
class RPCServer:
@@ -55,109 +55,29 @@ class RPCServer:
self._id_gen = RequestIdGenerator()
self._connection: Optional[Connection] = None # 当前活跃的 Runner 连接
- self._runner_id: Optional[str] = None
- self._runner_generation: int = 0
- self._staged_connection: Optional[Connection] = None
- self._staged_runner_id: Optional[str] = None
- self._staged_runner_generation: int = 0
- self._staging_takeover: bool = False
# 方法处理器注册表
self._method_handlers: Dict[str, MethodHandler] = {}
- # 等待响应的 pending 请求: request_id -> (Future, target_generation)
- self._pending_requests: Dict[int, Tuple[asyncio.Future, int]] = {}
+ # 等待响应的 pending 请求: request_id -> Future
+ self._pending_requests: Dict[int, asyncio.Future[Envelope]] = {}
# 发送队列(背压控制)
self._send_queue: Optional[asyncio.Queue[Tuple[Connection, bytes, asyncio.Future[None]]]] = None
- self._send_worker_task: Optional[asyncio.Task] = None
+ self._send_worker_task: Optional[asyncio.Task[None]] = None
# 运行状态
self._running: bool = False
- self._tasks: List[asyncio.Task] = []
+ self._tasks: List[asyncio.Task[None]] = []
@property
def session_token(self) -> str:
return self._session_token
- def reset_session_token(self) -> str:
- """重新生成会话令牌(热重载时调用,防止旧 Runner 重连)"""
- self._session_token = secrets.token_hex(32)
- return self._session_token
-
- def restore_session_token(self, token: str) -> None:
- """恢复指定的会话令牌(热重载回滚时调用)"""
- self._session_token = token
-
- @property
- def runner_generation(self) -> int:
- return self._runner_generation
-
- @property
- def staged_generation(self) -> int:
- return self._staged_runner_generation
-
@property
def is_connected(self) -> bool:
return self._connection is not None and not self._connection.is_closed
- def has_generation(self, generation: int) -> bool:
- return generation == self._runner_generation or (
- self._staged_connection is not None
- and not self._staged_connection.is_closed
- and generation == self._staged_runner_generation
- )
-
- def begin_staged_takeover(self) -> None:
- """允许新 Runner 以 staged 方式接入,待 Supervisor 验证后再切换为活跃连接。"""
- self._staging_takeover = True
-
- async def commit_staged_takeover(self) -> None:
- """提交 staged Runner,原活跃连接在提交后被关闭。"""
- if self._staged_connection is None or self._staged_connection.is_closed:
- raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "没有可提交的新 Runner 连接")
-
- old_connection = self._connection
- old_generation = self._runner_generation
-
- self._connection = self._staged_connection
- self._runner_id = self._staged_runner_id
- self._runner_generation = self._staged_runner_generation
-
- self._staged_connection = None
- self._staged_runner_id = None
- self._staged_runner_generation = 0
- self._staging_takeover = False
-
- if stale_count := self._fail_pending_requests(
- ErrorCode.E_PLUGIN_CRASHED,
- "Runner 连接已被新 generation 接管",
- generation=old_generation,
- ):
- logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求")
-
- if old_connection and old_connection is not self._connection and not old_connection.is_closed:
- await old_connection.close()
-
- async def rollback_staged_takeover(self) -> None:
- """放弃 staged Runner,保留当前活跃连接。"""
- staged_connection = self._staged_connection
- staged_generation = self._staged_runner_generation
-
- self._staged_connection = None
- self._staged_runner_id = None
- self._staged_runner_generation = 0
- self._staging_takeover = False
-
- self._fail_pending_requests(
- ErrorCode.E_PLUGIN_CRASHED,
- "新 Runner 预热失败,已回滚",
- generation=staged_generation,
- )
-
- if staged_connection and not staged_connection.is_closed:
- await staged_connection.close()
-
def register_method(self, method: str, handler: MethodHandler) -> None:
"""注册 RPC 方法处理器"""
self._method_handlers[method] = handler
@@ -173,14 +93,8 @@ class RPCServer:
async def stop(self) -> None:
"""停止 RPC 服务器"""
self._running = False
-
- # 取消所有 pending 请求
- for future, _generation in self._pending_requests.values():
- if not future.done():
- future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
- self._pending_requests.clear()
-
- self._fail_queued_sends(ErrorCode.E_TIMEOUT, "服务器关闭")
+ self._fail_pending_requests(ErrorCode.E_SHUTTING_DOWN, "服务器正在关闭")
+ self._fail_queued_sends(ErrorCode.E_SHUTTING_DOWN, "服务器正在关闭")
if self._send_worker_task:
self._send_worker_task.cancel()
@@ -198,10 +112,6 @@ class RPCServer:
await self._connection.close()
self._connection = None
- if self._staged_connection:
- await self._staged_connection.close()
- self._staged_connection = None
-
await self._transport.stop()
logger.info("RPC Server 已停止")
@@ -211,7 +121,6 @@ class RPCServer:
plugin_id: str = "",
payload: Optional[Dict[str, Any]] = None,
timeout_ms: int = 30000,
- target_generation: Optional[int] = None,
) -> Envelope:
"""向 Runner 发送 RPC 请求并等待响应
@@ -227,18 +136,14 @@ class RPCServer:
Raises:
RPCError: 调用失败
"""
- generation = target_generation or self._runner_generation
- conn = self._get_connection_for_generation(generation)
- if conn is None or conn.is_closed:
+ if not self._connection or self._connection.is_closed:
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
-
- request_id = self._id_gen.next()
+ request_id = await self._id_gen.next()
envelope = Envelope(
request_id=request_id,
message_type=MessageType.REQUEST,
method=method,
plugin_id=plugin_id,
- generation=generation,
timeout_ms=timeout_ms,
payload=payload or {},
)
@@ -246,12 +151,12 @@ class RPCServer:
# 注册 pending future
loop = asyncio.get_running_loop()
future: asyncio.Future[Envelope] = loop.create_future()
- self._pending_requests[request_id] = (future, generation)
+ self._pending_requests[request_id] = future
try:
# 发送请求
data = self._codec.encode_envelope(envelope)
- await self._enqueue_send(conn, data)
+ await self._enqueue_send(self._connection, data)
# 等待响应
timeout_sec = timeout_ms / 1000.0
@@ -265,93 +170,66 @@ class RPCServer:
raise
raise RPCError(ErrorCode.E_UNKNOWN, str(e)) from e
- async def send_event(self, method: str, plugin_id: str = "", payload: Optional[Dict[str, Any]] = None) -> None:
- """向 Runner 发送单向事件(不等待响应)"""
- conn = self._connection
- if conn is None or conn.is_closed:
- return
+ # ============ 内部方法 ============
+ # ========= 发送循环 =========
+ async def _send_loop(self) -> None:
+ """后台发送循环:串行消费发送队列,统一执行连接写入。"""
+ if self._send_queue is None:
+ raise RuntimeError("没有消息队列")
- request_id = self._id_gen.next()
- envelope = Envelope(
- request_id=request_id,
- message_type=MessageType.EVENT,
- method=method,
- plugin_id=plugin_id,
- generation=self._runner_generation,
- payload=payload or {},
- )
- data = self._codec.encode_envelope(envelope)
- await self._enqueue_send(conn, data)
+ while True:
+ try:
+ conn, data, send_future = await self._send_queue.get()
+ except asyncio.CancelledError:
+ break
- # ─── 内部方法 ──────────────────────────────────────────────
+ try:
+ if conn.is_closed:
+ raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
+ await conn.send_frame(data)
+ if not send_future.done():
+ send_future.set_result(None)
+ except asyncio.CancelledError:
+ if not send_future.done():
+ send_future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
+ raise
+ except Exception as e:
+ send_error = RPCError.from_exception(e, {ConnectionError: ErrorCode.E_PLUGIN_CRASHED})
+ if not send_future.done():
+ send_future.set_exception(send_error)
+ finally:
+ self._send_queue.task_done()
+ # ====== 发送循环方法 ======
async def _handle_connection(self, conn: Connection) -> None:
"""处理新的 Runner 连接"""
logger.info("收到 Runner 连接")
- previous_connection = self._connection
- previous_generation = self._runner_generation
-
# 第一条消息必须是 runner.hello 握手
try:
- role = await self._handle_handshake(conn)
- if role is None:
+ success = await self._handle_handshake(conn)
+ if not success:
await conn.close()
return
except Exception as e:
logger.error(f"握手失败: {e}")
await conn.close()
return
-
- if role == "staged":
- expected_generation = self._staged_runner_generation
- logger.info(
- f"Runner staged 握手成功: runner_id={self._staged_runner_id}, generation={self._staged_runner_generation}"
- )
- else:
- self._connection = conn
- expected_generation = self._runner_generation
- logger.info(f"Runner 握手成功: runner_id={self._runner_id}, generation={self._runner_generation}")
-
- if previous_connection and previous_connection is not conn and not previous_connection.is_closed:
- logger.info("检测到新 Runner 已接管连接,关闭旧连接")
- if stale_count := self._fail_pending_requests(
- ErrorCode.E_PLUGIN_CRASHED,
- "Runner 连接已被新 generation 接管",
- generation=previous_generation,
- ):
- logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求")
- await previous_connection.close()
-
+ logger.info("Runner staged 握手成功")
+ self._connection = conn
# 启动消息接收循环
try:
- await self._recv_loop(conn, expected_generation=expected_generation)
+ await self._recv_loop(conn)
except Exception as e:
logger.error(f"连接异常断开: {e}")
finally:
- if self._connection is conn:
- self._connection = None
- self._runner_id = None
- self._fail_pending_requests(
- ErrorCode.E_PLUGIN_CRASHED,
- "Runner 连接已断开",
- generation=expected_generation,
- )
- elif self._staged_connection is conn:
- self._staged_connection = None
- self._staged_runner_id = None
- self._staged_runner_generation = 0
- self._fail_pending_requests(
- ErrorCode.E_PLUGIN_CRASHED,
- "Staged Runner 连接已断开",
- generation=expected_generation,
- )
+ self._connection = None
+ self._fail_pending_requests(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已断开")
- async def _handle_handshake(self, conn: Connection) -> Optional[str]:
+ async def _handle_handshake(self, conn: Connection) -> bool:
"""处理 runner.hello 握手"""
# 接收握手请求
data = await asyncio.wait_for(conn.recv_frame(), timeout=10.0)
envelope = self._codec.decode_envelope(data)
-
if envelope.method != "runner.hello":
logger.error(f"期望 runner.hello,收到 {envelope.method}")
error_resp = envelope.make_error_response(
@@ -359,21 +237,17 @@ class RPCServer:
"首条消息必须为 runner.hello",
)
await conn.send_frame(self._codec.encode_envelope(error_resp))
- return None
+ return False
# 解析握手 payload
hello = HelloPayload.model_validate(envelope.payload)
-
# 校验会话令牌
if hello.session_token != self._session_token:
logger.error("会话令牌不匹配")
- resp_payload = HelloResponsePayload(
- accepted=False,
- reason="会话令牌无效",
- )
+ resp_payload = HelloResponsePayload(accepted=False, reason="会话令牌无效")
resp = envelope.make_response(payload=resp_payload.model_dump())
await conn.send_frame(self._codec.encode_envelope(resp))
- return None
+ return False
# 校验 SDK 版本
if not self._check_sdk_version(hello.sdk_version):
@@ -384,31 +258,26 @@ class RPCServer:
)
resp = envelope.make_response(payload=resp_payload.model_dump())
await conn.send_frame(self._codec.encode_envelope(resp))
- return None
+ return False
- # 握手成功
- role = "active"
- assigned_generation = self._runner_generation + 1
- if self._staging_takeover and self.is_connected:
- role = "staged"
- self._staged_connection = conn
- self._staged_runner_id = hello.runner_id
- self._staged_runner_generation = assigned_generation
- else:
- self._runner_id = hello.runner_id
- self._runner_generation = assigned_generation
-
- resp_payload = HelloResponsePayload(
- accepted=True,
- host_version=PROTOCOL_VERSION,
- assigned_generation=assigned_generation,
- )
+ # 发送响应
+ resp_payload = HelloResponsePayload(accepted=True, host_version=PROTOCOL_VERSION)
resp = envelope.make_response(payload=resp_payload.model_dump())
await conn.send_frame(self._codec.encode_envelope(resp))
+ return True
- return role
+ def _check_sdk_version(self, sdk_version: str) -> bool:
+ """检查 SDK 版本是否在支持范围内"""
+ try:
+ sdk_parts = _parse_version_tuple(sdk_version)
+ min_parts = _parse_version_tuple(MIN_SDK_VERSION)
+ max_parts = _parse_version_tuple(MAX_SDK_VERSION)
+ return min_parts <= sdk_parts <= max_parts
+ except (ValueError, AttributeError):
+ return False
- async def _recv_loop(self, conn: Connection, expected_generation: int) -> None:
+ # ========= 接收循环 =========
+ async def _recv_loop(self, conn: Connection) -> None:
"""消息接收主循环"""
while self._running and not conn.is_closed:
try:
@@ -430,109 +299,40 @@ class RPCServer:
if envelope.is_response():
self._handle_response(envelope)
elif envelope.is_request():
- if envelope.generation != expected_generation:
- error_resp = envelope.make_error_response(
- ErrorCode.E_GENERATION_MISMATCH.value,
- f"过期 generation: {envelope.generation} != {expected_generation}",
- )
- await conn.send_frame(self._codec.encode_envelope(error_resp))
- continue
# 异步处理请求(Runner 发来的能力调用)
task = asyncio.create_task(self._handle_request(envelope, conn))
self._tasks.append(task)
task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None)
- elif envelope.is_event():
- if envelope.generation != expected_generation:
- logger.warning(
- f"忽略过期 generation 事件 {envelope.method}: {envelope.generation} != {expected_generation}"
- )
- continue
- task = asyncio.create_task(self._handle_event(envelope))
+ elif envelope.is_broadcast():
+ task = asyncio.create_task(self._handle_broadcast(envelope))
self._tasks.append(task)
task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None)
+ else:
+ logger.warning(f"未知的消息类型: {envelope.message_type}")
+ continue
+ # ====== 接收循环内部方法 ======
def _handle_response(self, envelope: Envelope) -> None:
"""处理来自 Runner 的响应"""
- pending = self._pending_requests.get(envelope.request_id)
- if pending is None:
+ pending_future = self._pending_requests.pop(envelope.request_id, None)
+ if pending_future is None:
return
-
- future, expected_generation = pending
- if envelope.generation != expected_generation:
- logger.warning(
- f"忽略过期 generation 响应 {envelope.method}: {envelope.generation} != {expected_generation}"
- )
- return
-
- self._pending_requests.pop(envelope.request_id, None)
- if not future.done():
+ if not pending_future.done():
if envelope.error:
- future.set_exception(RPCError.from_dict(envelope.error))
+ pending_future.set_exception(RPCError.from_dict(envelope.error))
else:
- future.set_result(envelope)
-
- async def _enqueue_send(self, conn: Connection, data: bytes) -> None:
- """通过发送队列串行发送消息,提供真实背压。"""
- if conn.is_closed:
- raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
-
- if self._send_queue is None:
- await conn.send_frame(data)
- return
-
- loop = asyncio.get_running_loop()
- send_future: asyncio.Future[None] = loop.create_future()
-
- try:
- self._send_queue.put_nowait((conn, data, send_future))
- except asyncio.QueueFull:
- raise RPCError(ErrorCode.E_BACKPRESSURE, "发送队列已满") from None
-
- await send_future
-
- async def _send_loop(self) -> None:
- """后台发送循环:串行消费发送队列,统一执行连接写入。"""
- if self._send_queue is None:
- return
-
- while True:
- try:
- conn, data, send_future = await self._send_queue.get()
- except asyncio.CancelledError:
- break
-
- try:
- if conn.is_closed:
- raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
- await conn.send_frame(data)
- if not send_future.done():
- send_future.set_result(None)
- except asyncio.CancelledError:
- if not send_future.done():
- send_future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
- raise
- except Exception as e:
- send_error = e if isinstance(e, RPCError) else self._normalize_send_exception(e)
- if not send_future.done():
- send_future.set_exception(send_error)
- finally:
- self._send_queue.task_done()
-
- @staticmethod
- def _normalize_send_exception(error: Exception) -> RPCError:
- if isinstance(error, ConnectionError):
- return RPCError(ErrorCode.E_PLUGIN_CRASHED, str(error))
- return RPCError(ErrorCode.E_UNKNOWN, str(error))
+ pending_future.set_result(envelope)
async def _handle_request(self, envelope: Envelope, conn: Connection) -> None:
"""处理来自 Runner 的请求(通常是能力调用 cap.*)"""
- handler = self._method_handlers.get(envelope.method)
- if handler is None:
- error_resp = envelope.make_error_response(
+ target_method = envelope.method
+ handler = self._method_handlers.get(target_method)
+ if not handler:
+ error_response = envelope.make_error_response(
ErrorCode.E_METHOD_NOT_ALLOWED.value,
f"未注册的方法: {envelope.method}",
)
- await conn.send_frame(self._codec.encode_envelope(error_resp))
+ await conn.send_frame(self._codec.encode_envelope(error_response))
return
try:
@@ -546,59 +346,25 @@ class RPCServer:
error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
await conn.send_frame(self._codec.encode_envelope(error_resp))
- async def _handle_event(self, envelope: Envelope) -> None:
- """处理来自 Runner 的事件"""
+ async def _handle_broadcast(self, envelope: Envelope) -> None:
if handler := self._method_handlers.get(envelope.method):
try:
result = await handler(envelope)
# 检查 handler 返回的信封是否包含错误信息
- if result is not None and isinstance(result, Envelope) and result.error:
+ if result.error:
logger.warning(f"事件 {envelope.method} handler 返回错误: {result.error.get('message', '')}")
except Exception as e:
logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True)
- @staticmethod
- def _check_sdk_version(sdk_version: str) -> bool:
- """检查 SDK 版本是否在支持范围内"""
- try:
- sdk_parts = RPCServer._parse_version_tuple(sdk_version)
- min_parts = RPCServer._parse_version_tuple(MIN_SDK_VERSION)
- max_parts = RPCServer._parse_version_tuple(MAX_SDK_VERSION)
- return min_parts <= sdk_parts <= max_parts
- except (ValueError, AttributeError):
- return False
-
- @staticmethod
- def _parse_version_tuple(version: str) -> Tuple[int, int, int]:
- base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version or "", flags=re.IGNORECASE)[0]
- base_version = base_version.split("+", 1)[0]
- parts = [part for part in base_version.split(".") if part != ""]
- while len(parts) < 3:
- parts.append("0")
- return (int(parts[0]), int(parts[1]), int(parts[2]))
-
- def _get_connection_for_generation(self, generation: int) -> Optional[Connection]:
- if generation == self._runner_generation:
- return self._connection
- if generation == self._staged_runner_generation:
- return self._staged_connection
- return None
-
- def _fail_pending_requests(
- self,
- error_code: ErrorCode,
- message: str,
- generation: Optional[int] = None,
- ) -> int:
- stale_count = 0
- for request_id, (future, request_generation) in list(self._pending_requests.items()):
- if generation is not None and request_generation != generation:
- continue
+ def _fail_pending_requests(self, error_code: ErrorCode, message: str) -> int:
+ """失败所有等待中的请求(如连接断开时)"""
+ aborted_request_count = 0
+ for future in self._pending_requests.values():
if not future.done():
future.set_exception(RPCError(error_code, message))
- stale_count += 1
- self._pending_requests.pop(request_id, None)
- return stale_count
+ aborted_request_count += 1
+ self._pending_requests.clear()
+ return aborted_request_count
def _fail_queued_sends(self, error_code: ErrorCode, message: str) -> int:
if self._send_queue is None:
@@ -617,3 +383,31 @@ class RPCServer:
self._send_queue.task_done()
return failed_count
+
+ async def _enqueue_send(self, conn: Connection, data: bytes) -> None:
+ """通过发送队列串行发送消息,提供真实背压。"""
+ if conn.is_closed:
+ raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
+
+ if self._send_queue is None:
+ await conn.send_frame(data)
+ return
+
+ loop = asyncio.get_running_loop()
+ send_future: asyncio.Future[None] = loop.create_future()
+
+ try:
+ self._send_queue.put_nowait((conn, data, send_future))
+ except asyncio.QueueFull:
+ raise RPCError(ErrorCode.E_BACK_PRESSURE, "发送队列已满") from None
+
+ await send_future
+
+
+def _parse_version_tuple(version: str) -> Tuple[int, int, int]:
+ base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version or "", flags=re.IGNORECASE)[0]
+ base_version = base_version.split("+", 1)[0]
+ parts = [part for part in base_version.split(".") if part != ""]
+ while len(parts) < 3:
+ parts.append("0")
+ return (int(parts[0]), int(parts[1]), int(parts[2]))
diff --git a/src/plugin_runtime/protocol/errors.py b/src/plugin_runtime/protocol/errors.py
index dcae6b8f..ed19760d 100644
--- a/src/plugin_runtime/protocol/errors.py
+++ b/src/plugin_runtime/protocol/errors.py
@@ -7,7 +7,7 @@ from enum import Enum
from typing import Any, Dict, Optional
-class ErrorCode(str, Enum):
+class ErrorCode(Enum):
"""RPC 错误码枚举"""
# 通用
@@ -18,17 +18,17 @@ class ErrorCode(str, Enum):
E_TIMEOUT = "E_TIMEOUT"
E_BAD_PAYLOAD = "E_BAD_PAYLOAD"
E_PROTOCOL_MISMATCH = "E_PROTOCOL_MISMATCH"
+ E_SHUTTING_DOWN = "E_SHUTTING_DOWN"
# 权限与策略
E_UNAUTHORIZED = "E_UNAUTHORIZED"
E_METHOD_NOT_ALLOWED = "E_METHOD_NOT_ALLOWED"
- E_BACKPRESSURE = "E_BACKPRESSURE"
+ E_BACK_PRESSURE = "E_BACK_PRESSURE"
E_HOST_OVERLOADED = "E_HOST_OVERLOADED"
# 插件生命周期
E_PLUGIN_CRASHED = "E_PLUGIN_CRASHED"
E_PLUGIN_NOT_FOUND = "E_PLUGIN_NOT_FOUND"
- E_GENERATION_MISMATCH = "E_GENERATION_MISMATCH"
E_RELOAD_IN_PROGRESS = "E_RELOAD_IN_PROGRESS"
# 能力调用
@@ -65,3 +65,13 @@ class RPCError(Exception):
message=data.get("message", ""),
details=data.get("details", {}),
)
+
+ @classmethod
+ def from_exception(cls, exception: Exception, code_mapping: Optional[Dict[type[Exception], ErrorCode]] = None):
+ if isinstance(exception, cls):
+ return exception
+ if code_mapping:
+ for exception_type, code in code_mapping.items():
+ if isinstance(exception, exception_type):
+ return cls(code=code, message=str(exception))
+ return cls(ErrorCode.E_UNKNOWN, str(exception))
From ca6fd96d4c2e9eef1a469fd7653317ee8755be8e Mon Sep 17 00:00:00 2001
From: UnCLAS-Prommer
Date: Tue, 17 Mar 2026 21:39:06 +0800
Subject: [PATCH 07/45] =?UTF-8?q?refactor:=20=E7=A1=AE=E8=AE=A4ErrorCode?=
=?UTF-8?q?=E5=8F=AF=E4=BB=A5=E7=BB=A7=E6=89=BFstr=EF=BC=8C=E6=81=A2?=
=?UTF-8?q?=E5=A4=8D=E5=8E=9F=E6=9D=A5=E8=AE=BE=E8=AE=A1?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
src/plugin_runtime/protocol/errors.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/plugin_runtime/protocol/errors.py b/src/plugin_runtime/protocol/errors.py
index ed19760d..d2b9228b 100644
--- a/src/plugin_runtime/protocol/errors.py
+++ b/src/plugin_runtime/protocol/errors.py
@@ -7,7 +7,7 @@ from enum import Enum
from typing import Any, Dict, Optional
-class ErrorCode(Enum):
+class ErrorCode(str, Enum):
"""RPC 错误码枚举"""
# 通用
From 14a0c21cbffdb80a3ff276bd736e0431ae772b10 Mon Sep 17 00:00:00 2001
From: UnCLAS-Prommer
Date: Wed, 18 Mar 2026 02:08:13 +0800
Subject: [PATCH 08/45] =?UTF-8?q?refactor:=20component=5Fregistry=E6=9B=B4?=
=?UTF-8?q?=E6=98=93=E7=90=86=E8=A7=A3?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
src/plugin_runtime/host/component_registry.py | 427 +++++++++++++-----
src/plugin_runtime/host/logger_bridge.py | 45 ++
2 files changed, 347 insertions(+), 125 deletions(-)
create mode 100644 src/plugin_runtime/host/logger_bridge.py
diff --git a/src/plugin_runtime/host/component_registry.py b/src/plugin_runtime/host/component_registry.py
index 220a19c0..89ec82d2 100644
--- a/src/plugin_runtime/host/component_registry.py
+++ b/src/plugin_runtime/host/component_registry.py
@@ -1,7 +1,7 @@
"""Host-side ComponentRegistry
对齐旧系统 component_registry.py 的核心能力:
-- 按类型注册组件(action / command / tool / event_handler / workflow_step)
+- 按类型注册组件(action / command / tool / event_handler / workflow_handler / message_gateway)
- 命名空间 (plugin_id.component_name)
- 命令正则匹配
- 组件启用/禁用
@@ -9,8 +9,10 @@
- 注册统计
"""
-from typing import Any, Dict, List, Optional
+from enum import Enum
+from typing import Any, Dict, List, Optional, Set, TypedDict, Tuple
+import contextlib
import re
from src.common.logger import get_logger
@@ -18,8 +20,28 @@ from src.common.logger import get_logger
logger = get_logger("plugin_runtime.host.component_registry")
-class RegisteredComponent:
- """已注册的组件条目"""
+class ComponentTypes(str, Enum):
+ ACTION = "ACTION"
+ COMMAND = "COMMAND"
+ TOOL = "TOOL"
+ EVENT_HANDLER = "EVENT_HANDLER"
+ WORKFLOW_HANDLER = "WORKFLOW_HANDLER"
+ MESSAGE_GATEWAY = "MESSAGE_GATEWAY"
+
+
+class StatusDict(TypedDict):
+ total: int
+ ACTION: int
+ COMMAND: int
+ TOOL: int
+ EVENT_HANDLER: int
+ WORKFLOW_HANDLER: int
+ MESSAGE_GATEWAY: int
+ plugins: int
+
+
+class ComponentEntry:
+ """组件条目"""
__slots__ = (
"name",
@@ -28,31 +50,74 @@ class RegisteredComponent:
"plugin_id",
"metadata",
"enabled",
- "_compiled_pattern",
+ "compiled_pattern",
+ "disabled_session",
)
- def __init__(
- self,
- name: str,
- component_type: str,
- plugin_id: str,
- metadata: Dict[str, Any],
- ) -> None:
- self.name = name
- self.full_name = f"{plugin_id}.{name}"
- self.component_type = component_type
- self.plugin_id = plugin_id
- self.metadata = metadata
- self.enabled = metadata.get("enabled", True)
+ def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
+ self.name: str = name
+ self.full_name: str = f"{plugin_id}.{name}"
+ self.component_type: ComponentTypes = ComponentTypes(component_type)
+ self.plugin_id: str = plugin_id
+ self.metadata: Dict[str, Any] = metadata
+ self.enabled: bool = metadata.get("enabled", True)
+ self.disabled_session: Set[str] = set()
- # 预编译命令正则(仅 command 类型)
- self._compiled_pattern: Optional[re.Pattern] = None
- if component_type == "command":
- if pattern := metadata.get("command_pattern", ""):
- try:
- self._compiled_pattern = re.compile(pattern)
- except re.error as e:
- logger.warning(f"命令 {self.full_name} 正则编译失败: {e}")
+
+class ActionEntry(ComponentEntry):
+ """Action 组件条目"""
+
+ def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
+ super().__init__(name, component_type, plugin_id, metadata)
+
+
+class CommandEntry(ComponentEntry):
+ """Command 组件条目"""
+
+ def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
+ self.compiled_pattern: Optional[re.Pattern] = None
+ self.aliases: List[str] = metadata.get("aliases", [])
+ if pattern := metadata.get("command_pattern", ""):
+ try:
+ self.compiled_pattern = re.compile(pattern)
+ except re.error as e:
+ logger.warning(f"命令 {self.full_name} 正则编译失败: {e}")
+ super().__init__(name, component_type, plugin_id, metadata)
+
+
+class ToolEntry(ComponentEntry):
+ """Tool 组件条目"""
+
+ def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
+ self.description: str = metadata.get("description", "")
+ self.parameters: List[Dict[str, Any]] = metadata.get("parameters", [])
+ self.parameters_raw: List[Dict[str, Any]] = metadata.get("parameters_raw", [])
+ super().__init__(name, component_type, plugin_id, metadata)
+
+
+class EventHandlerEntry(ComponentEntry):
+ """EventHandler 组件条目"""
+
+ def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
+ self.event_type: str = metadata.get("event_type", "")
+ self.weight: int = metadata.get("weight", 0)
+ super().__init__(name, component_type, plugin_id, metadata)
+
+
+class WorkflowHandlerEntry(ComponentEntry):
+ """WorkflowHandler 组件条目"""
+
+ def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
+ self.stage: str = metadata.get("stage", "")
+ self.priority: int = metadata.get("priority", 0)
+ super().__init__(name, component_type, plugin_id, metadata)
+
+
+class MessageGatewayEntry(ComponentEntry):
+ """MessageGateway 组件条目"""
+
+ def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
+ super().__init__(name, component_type, plugin_id, metadata)
class ComponentRegistry:
@@ -64,19 +129,15 @@ class ComponentRegistry:
def __init__(self) -> None:
# 全量索引
- self._components: Dict[str, RegisteredComponent] = {} # full_name -> comp
+ self._components: Dict[str, ComponentEntry] = {} # full_name -> comp
# 按类型索引
- self._by_type: Dict[str, Dict[str, RegisteredComponent]] = {
- "action": {},
- "command": {},
- "tool": {},
- "event_handler": {},
- "workflow_step": {},
- }
+ self._by_type: Dict[ComponentTypes, Dict[str, ComponentEntry]] = {
+ comp_type: {} for comp_type in ComponentTypes
+ } # component_type -> (full_name -> comp)
# 按插件索引
- self._by_plugin: Dict[str, List[RegisteredComponent]] = {}
+ self._by_plugin: Dict[str, List[ComponentEntry]] = {}
def clear(self) -> None:
"""清空全部组件注册状态。"""
@@ -85,47 +146,63 @@ class ComponentRegistry:
type_dict.clear()
self._by_plugin.clear()
- # ──── 注册 / 注销 ─────────────────────────────────────────
+ # ====== 注册 / 注销 ======
+ def register_component(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> bool:
+ """注册单个组件
+
+ Args:
+ name: 组件名称(不含插件id前缀)
+ component_type: 组件类型(如 `ACTION`、`COMMAND` 等)
+ plugin_id: 插件id
+ metadata: 组件元数据
+ Returns:
+ success (bool): 是否成功注册(失败原因通常是组件类型无效)
+ """
+ try:
+ if component_type == ComponentTypes.ACTION:
+ comp = ActionEntry(name, component_type, plugin_id, metadata)
+ elif component_type == ComponentTypes.COMMAND:
+ comp = CommandEntry(name, component_type, plugin_id, metadata)
+ elif component_type == ComponentTypes.TOOL:
+ comp = ToolEntry(name, component_type, plugin_id, metadata)
+ elif component_type == ComponentTypes.EVENT_HANDLER:
+ comp = EventHandlerEntry(name, component_type, plugin_id, metadata)
+ elif component_type == ComponentTypes.WORKFLOW_HANDLER:
+ comp = WorkflowHandlerEntry(name, component_type, plugin_id, metadata)
+ elif component_type == ComponentTypes.MESSAGE_GATEWAY:
+ comp = MessageGatewayEntry(name, component_type, plugin_id, metadata)
+ else:
+ raise ValueError(f"组件类型 {component_type} 不存在")
+ except ValueError:
+ logger.error(f"组件类型 {component_type} 不存在")
+ return False
- def register_component(
- self,
- name: str,
- component_type: str,
- plugin_id: str,
- metadata: Dict[str, Any],
- ) -> bool:
- """注册单个组件。"""
- comp = RegisteredComponent(name, component_type, plugin_id, metadata)
if comp.full_name in self._components:
logger.warning(f"组件 {comp.full_name} 已存在,覆盖")
old_comp = self._components[comp.full_name]
# 从 _by_plugin 列表中移除旧条目,防止幽灵组件堆积
old_list = self._by_plugin.get(old_comp.plugin_id)
if old_list is not None:
- try:
+ with contextlib.suppress(ValueError):
old_list.remove(old_comp)
- except ValueError:
- pass
# 从旧类型索引中移除,防止类型变更时幽灵残留
if old_type_dict := self._by_type.get(old_comp.component_type):
old_type_dict.pop(comp.full_name, None)
self._components[comp.full_name] = comp
-
- if component_type not in self._by_type:
- self._by_type[component_type] = {}
- self._by_type[component_type][comp.full_name] = comp
-
+ self._by_type[comp.component_type][comp.full_name] = comp
self._by_plugin.setdefault(plugin_id, []).append(comp)
return True
- def register_plugin_components(
- self,
- plugin_id: str,
- components: List[Dict[str, Any]],
- ) -> int:
- """批量注册一个插件的所有组件,返回成功注册数。"""
+ def register_plugin_components(self, plugin_id: str, components: List[Dict[str, Any]]) -> int:
+ """批量注册一个插件的所有组件,返回成功注册数。
+ Args:
+ plugin_id (str): 插件id
+ components (List[Dict[str, Any]]): 组件字典列表,每个组件包含 name, component_type, metadata 等字段
+ Returns:
+ count (int): 成功注册的组件数量
+ """
count = 0
for comp_data in components:
ok = self.register_component(
@@ -139,7 +216,13 @@ class ComponentRegistry:
return count
def remove_components_by_plugin(self, plugin_id: str) -> int:
- """移除某个插件的所有组件,返回移除数量。"""
+ """移除某个插件的所有组件,返回移除数量。
+
+ Args:
+ plugin_id (str): 插件id
+ Returns:
+ count (int): 移除的组件数量
+ """
comps = self._by_plugin.pop(plugin_id, [])
for comp in comps:
self._components.pop(comp.full_name, None)
@@ -147,106 +230,200 @@ class ComponentRegistry:
type_dict.pop(comp.full_name, None)
return len(comps)
- # ──── 启用 / 禁用 ─────────────────────────────────────────
+ # ====== 启用 / 禁用 ======
+ def check_component_enabled(self, component: ComponentEntry, session_id: Optional[str] = None):
+ if session_id and session_id in component.disabled_session:
+ return False
+ return component.enabled
- def set_component_enabled(self, full_name: str, enabled: bool) -> bool:
- """启用或禁用指定组件。"""
+ def toggle_component_status(self, full_name: str, enabled: bool, session_id: Optional[str] = None) -> bool:
+ """启用或禁用指定组件。
+
+ Args:
+ full_name (str): 组件全名
+ enabled (bool): 使能情况
+ session_id (Optional[str]): 可选的会话ID,仅对该会话禁用(如果提供)
+ Returns:
+ success (bool): 是否成功设置(失败原因通常是组件不存在)
+ """
comp = self._components.get(full_name)
if comp is None:
return False
- comp.enabled = enabled
+ if session_id:
+ if enabled:
+ comp.disabled_session.discard(session_id)
+ else:
+ comp.disabled_session.add(session_id)
+ else:
+ comp.enabled = enabled
return True
- def set_plugin_enabled(self, plugin_id: str, enabled: bool) -> int:
- """批量启用或禁用某插件的所有组件。"""
+ def toggle_plugin_status(self, plugin_id: str, enabled: bool, session_id: Optional[str] = None) -> int:
+ """批量启用或禁用某插件的所有组件。
+
+ Args:
+ plugin_id (str): 插件id
+ enabled (bool): 使能情况
+ session_id (Optional[str]): 可选的会话ID,仅对该会话禁用(如果提供)
+ Returns:
+ count (int): 成功设置的组件数量(失败原因通常是插件不存在)
+ """
comps = self._by_plugin.get(plugin_id, [])
for comp in comps:
- comp.enabled = enabled
+ if session_id:
+ if enabled:
+ comp.disabled_session.discard(session_id)
+ else:
+ comp.disabled_session.add(session_id)
+ else:
+ comp.enabled = enabled
return len(comps)
- # ──── 查询方法 ─────────────────────────────────────────────
+ def get_component(self, full_name: str) -> Optional[ComponentEntry]:
+ """按全名查询。
- def get_component(self, full_name: str) -> Optional[RegisteredComponent]:
- """按全名查询。"""
+ Args:
+ full_name (str): 组件全名
+ Returns:
+ component (Optional[ComponentEntry]): 组件条目,未找到时为 None
+ """
return self._components.get(full_name)
- def get_components_by_type(self, component_type: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
- """按类型查询。"""
- type_dict = self._by_type.get(component_type, {})
+ def get_components_by_type(
+ self, component_type: str, *, enabled_only: bool = True, session_id: Optional[str] = None
+ ) -> List[ComponentEntry]:
+ """按类型查询组件
+
+ Args:
+ component_type (str): 组件类型(如 `ACTION`、`COMMAND` 等)
+ enabled_only (bool): 是否仅返回启用的组件
+ session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态
+ Returns:
+ components (List[ComponentEntry]): 组件条目列表
+ """
+ try:
+ comp_type = ComponentTypes(component_type)
+ except ValueError:
+ logger.error(f"组件类型 {component_type} 不存在")
+ raise
+ type_dict = self._by_type.get(comp_type, {})
if enabled_only:
- return [c for c in type_dict.values() if c.enabled]
+ return [c for c in type_dict.values() if self.check_component_enabled(c, session_id)]
return list(type_dict.values())
- def get_components_by_plugin(self, plugin_id: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
- """按插件查询。"""
- comps = self._by_plugin.get(plugin_id, [])
- return [c for c in comps if c.enabled] if enabled_only else list(comps)
+ def get_components_by_plugin(
+ self, plugin_id: str, *, enabled_only: bool = True, session_id: Optional[str] = None
+ ) -> List[ComponentEntry]:
+ """按插件查询组件。
- def find_command_by_text(self, text: str) -> Optional[tuple[RegisteredComponent, Dict[str, Any]]]:
+ Args:
+ plugin_id (str): 插件ID
+ enabled_only (bool): 是否仅返回启用的组件
+ session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态
+ Returns:
+ components (List[ComponentEntry]): 组件条目列表
+ """
+ comps = self._by_plugin.get(plugin_id, [])
+ return [c for c in comps if self.check_component_enabled(c, session_id)] if enabled_only else list(comps)
+
+ def find_command_by_text(
+ self, text: str, session_id: Optional[str] = None
+ ) -> Optional[Tuple[ComponentEntry, Dict[str, Any]]]:
"""通过文本匹配命令正则,返回 (组件, matched_groups) 元组。
matched_groups 为正则命名捕获组 dict,别名匹配时为空 dict。
+ Args:
+ text (str): 待匹配文本
+ session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态
+ Returns:
+ result (Optional[tuple[ComponentEntry, Dict[str, Any]]]): 匹配到的组件及正则捕获组,未找到时为 None
"""
- for comp in self._by_type.get("command", {}).values():
- if not comp.enabled:
+ for comp in self._by_type.get(ComponentTypes.COMMAND, {}).values():
+ if not self.check_component_enabled(comp, session_id):
continue
- if comp._compiled_pattern:
- m = comp._compiled_pattern.search(text)
- if m:
+ if not isinstance(comp, CommandEntry):
+ continue
+ if comp.compiled_pattern:
+ if m := comp.compiled_pattern.search(text):
return comp, m.groupdict()
# 别名匹配
- aliases = comp.metadata.get("aliases", [])
- for alias in aliases:
+ for alias in comp.aliases:
if text.startswith(alias):
return comp, {}
return None
- def get_event_handlers(self, event_type: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
- """获取特定事件类型的所有 event_handler,按 weight 降序排列。"""
- handlers = []
- for comp in self._by_type.get("event_handler", {}).values():
- if enabled_only and not comp.enabled:
+ def get_event_handlers(
+ self, event_type: str, *, enabled_only: bool = True, session_id: Optional[str] = None
+ ) -> List[EventHandlerEntry]:
+ """查询指定事件类型的事件处理器组件。
+
+ Args:
+ event_type (str): 事件类型
+ enabled_only (bool): 是否仅返回启用的组件
+ session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态
+ Returns:
+ handlers (List[EventHandlerEntry]): 符合条件的 EventHandler 组件列表,按 weight 降序排序
+ """
+ handlers: List[EventHandlerEntry] = []
+ for comp in self._by_type.get(ComponentTypes.EVENT_HANDLER, {}).values():
+ if enabled_only and not self.check_component_enabled(comp, session_id):
continue
- if comp.metadata.get("event_type") == event_type:
+ if not isinstance(comp, EventHandlerEntry):
+ continue
+ if comp.event_type == event_type:
handlers.append(comp)
- handlers.sort(key=lambda c: c.metadata.get("weight", 0), reverse=True)
+ handlers.sort(key=lambda c: c.weight, reverse=True)
return handlers
- def get_workflow_steps(self, stage: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
- """获取特定 workflow 阶段的所有步骤,按 priority 降序。"""
- steps = []
- for comp in self._by_type.get("workflow_step", {}).values():
- if enabled_only and not comp.enabled:
+ def get_workflow_handlers(
+ self, stage: str, *, enabled_only: bool = True, session_id: Optional[str] = None
+ ) -> List[WorkflowHandlerEntry]:
+ """获取特定 workflow 阶段的所有步骤,按 priority 降序。
+
+ Args:
+ stage: workflow 阶段名称
+ enabled_only: 是否仅返回启用的组件
+ session_id: 可选的会话ID,若提供则考虑会话禁用状态
+ Returns:
+ handlers (List[WorkflowHandlerEntry]): 符合条件的 WorkflowHandler 组件列表,按 priority 降序排序
+ """
+ handlers: List[WorkflowHandlerEntry] = []
+ for comp in self._by_type.get(ComponentTypes.WORKFLOW_HANDLER, {}).values():
+ if enabled_only and not self.check_component_enabled(comp, session_id):
continue
- if comp.metadata.get("stage") == stage:
- steps.append(comp)
- steps.sort(key=lambda c: c.metadata.get("priority", 0), reverse=True)
- return steps
+ if not isinstance(comp, WorkflowHandlerEntry):
+ continue
+ if comp.stage == stage:
+ handlers.append(comp)
+ handlers.sort(key=lambda c: c.priority, reverse=True)
+ return handlers
- def get_tools_for_llm(self, *, enabled_only: bool = True) -> List[Dict[str, Any]]:
- """获取可供 LLM 使用的工具列表(openai function-calling 格式预览)。"""
- result: List[Dict[str, Any]] = []
- for comp in self.get_components_by_type("tool", enabled_only=enabled_only):
- tool_def: Dict[str, Any] = {
- "name": comp.full_name,
- "description": comp.metadata.get("description", ""),
- }
- # 从结构化参数或原始参数构建 parameters
- params = comp.metadata.get("parameters", [])
- params_raw = comp.metadata.get("parameters_raw", {})
- if params:
- tool_def["parameters"] = params
- elif params_raw:
- tool_def["parameters"] = params_raw
- result.append(tool_def)
- return result
+ def get_tools(self, *, enabled_only: bool = True, session_id: Optional[str] = None) -> List[ToolEntry]:
+ """查询所有工具组件。
+
+ Args:
+ enabled_only (bool): 是否仅返回启用的组件
+ session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态
+ Returns:
+ tools (List[ToolEntry]): 符合条件的 Tool 组件列表
+ """
+ tools: List[ToolEntry] = []
+ for comp in self._by_type.get(ComponentTypes.TOOL, {}).values():
+ if enabled_only and not self.check_component_enabled(comp, session_id):
+ continue
+ if isinstance(comp, ToolEntry):
+ tools.append(comp)
+ return tools
- # ──── 统计 ─────────────────────────────────────────────────
-
- def get_stats(self) -> Dict[str, int]:
- """获取注册统计。"""
- stats: Dict[str, int] = {"total": len(self._components)}
+ # ====== 统计信息 ======
+ def get_stats(self) -> StatusDict:
+ """获取注册统计。
+
+ Returns:
+ stats (StatusDict): 组件统计信息,包括总数、各类型数量、插件数量等
+ """
+ stats: StatusDict = {"total": len(self._components)} # type: ignore
for comp_type, type_dict in self._by_type.items():
- stats[comp_type] = len(type_dict)
+ stats[comp_type.value] = len(type_dict)
stats["plugins"] = len(self._by_plugin)
return stats
diff --git a/src/plugin_runtime/host/logger_bridge.py b/src/plugin_runtime/host/logger_bridge.py
new file mode 100644
index 00000000..f2213dfe
--- /dev/null
+++ b/src/plugin_runtime/host/logger_bridge.py
@@ -0,0 +1,45 @@
+import logging as stdlib_logging
+from src.plugin_runtime.protocol.errors import ErrorCode
+from src.plugin_runtime.protocol.envelope import Envelope, LogBatchPayload
+class RunnerLogBridge:
+ """将 Runner 进程上报的批量日志重放到主进程的 Logger 中。
+
+ Runner 通过 ``runner.log_batch`` IPC 事件批量到达。
+ 每条 LogEntry 被重建为一个真实的 :class:`logging.LogRecord` 并直接
+ 调用 ``logging.getLogger(entry.logger_name).handle(record)``,
+ 从而接入主进程已配置好的 structlog Handler 链。
+ """
+
+ async def handle_log_batch(self, envelope: Envelope) -> Envelope:
+ """IPC 事件处理器:解析批量日志并重放到主进程 Logger。
+
+ Args:
+ envelope: 方法名为 ``runner.log_batch`` 的 IPC 事件信封。
+
+ Returns:
+ 空响应信封(事件模式下将被忽略)。
+ """
+ try:
+ batch = LogBatchPayload.model_validate(envelope.payload)
+ except Exception as exc:
+ return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
+
+ for entry in batch.entries:
+ # 重建一个与原始日志尽量相符的 LogRecord
+ record = stdlib_logging.LogRecord(
+ name=entry.logger_name,
+ level=entry.level,
+ pathname="",
+ lineno=0,
+ msg=entry.message,
+ args=(),
+ exc_info=None,
+ )
+ record.created = entry.timestamp_ms / 1000.0
+ record.msecs = entry.timestamp_ms % 1000
+ if entry.exception_text:
+ record.exc_text = entry.exception_text
+
+ stdlib_logging.getLogger(entry.logger_name).handle(record)
+
+ return envelope.make_response(payload={"accepted": True, "count": len(batch.entries)})
\ No newline at end of file
From 32519c688bcfb3a540b235d3f337629c490a9e6d Mon Sep 17 00:00:00 2001
From: UnCLAS-Prommer
Date: Wed, 18 Mar 2026 15:29:08 +0800
Subject: [PATCH 09/45] refactor: event_dispatcher
---
src/plugin_runtime/host/component_registry.py | 7 +-
src/plugin_runtime/host/event_dispatcher.py | 114 ++++++++++--------
2 files changed, 65 insertions(+), 56 deletions(-)
diff --git a/src/plugin_runtime/host/component_registry.py b/src/plugin_runtime/host/component_registry.py
index 89ec82d2..22f9a7e0 100644
--- a/src/plugin_runtime/host/component_registry.py
+++ b/src/plugin_runtime/host/component_registry.py
@@ -101,6 +101,7 @@ class EventHandlerEntry(ComponentEntry):
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
self.event_type: str = metadata.get("event_type", "")
self.weight: int = metadata.get("weight", 0)
+ self.intercept_message: bool = metadata.get("intercept_message", False)
super().__init__(name, component_type, plugin_id, metadata)
@@ -356,7 +357,7 @@ class ComponentRegistry:
self, event_type: str, *, enabled_only: bool = True, session_id: Optional[str] = None
) -> List[EventHandlerEntry]:
"""查询指定事件类型的事件处理器组件。
-
+
Args:
event_type (str): 事件类型
enabled_only (bool): 是否仅返回启用的组件
@@ -400,7 +401,7 @@ class ComponentRegistry:
def get_tools(self, *, enabled_only: bool = True, session_id: Optional[str] = None) -> List[ToolEntry]:
"""查询所有工具组件。
-
+
Args:
enabled_only (bool): 是否仅返回启用的组件
session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态
@@ -418,7 +419,7 @@ class ComponentRegistry:
# ====== 统计信息 ======
def get_stats(self) -> StatusDict:
"""获取注册统计。
-
+
Returns:
stats (StatusDict): 组件统计信息,包括总数、各类型数量、插件数量等
"""
diff --git a/src/plugin_runtime/host/event_dispatcher.py b/src/plugin_runtime/host/event_dispatcher.py
index 720e93d7..f08591a8 100644
--- a/src/plugin_runtime/host/event_dispatcher.py
+++ b/src/plugin_runtime/host/event_dispatcher.py
@@ -4,40 +4,38 @@
1. 按事件类型查询已注册的 event_handler(通过 ComponentRegistry)
2. 按 weight 排序,依次通过 RPC 调用 Runner 中的处理器
3. 支持阻塞(intercept_message)和非阻塞分发
-4. 事件结果历史记录
+4. 事件结果历史记录(有上限)
"""
-from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple
+from dataclasses import dataclass, field
+from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING
import asyncio
from src.common.logger import get_logger
-from src.plugin_runtime.host.component_registry import ComponentRegistry, RegisteredComponent
+
+
+if TYPE_CHECKING:
+ from .supervisor import PluginRunnerSupervisor
+ from .component_registry import ComponentRegistry, EventHandlerEntry
logger = get_logger("plugin_runtime.host.event_dispatcher")
# invoke_fn 类型: async (plugin_id, component_name, args) -> response_payload dict
InvokeFn = Callable[[str, str, Dict[str, Any]], Awaitable[Dict[str, Any]]]
+# 每个事件类型的最大历史记录数量,防止内存无限增长
+_MAX_HISTORY_LENGTH = 100
+@dataclass
class EventResult:
"""单个 EventHandler 的执行结果"""
- __slots__ = ("handler_name", "success", "continue_processing", "modified_message", "custom_result")
-
- def __init__(
- self,
- handler_name: str,
- success: bool = True,
- continue_processing: bool = True,
- modified_message: Optional[Dict[str, Any]] = None,
- custom_result: Any = None,
- ):
- self.handler_name = handler_name
- self.success = success
- self.continue_processing = continue_processing
- self.modified_message = modified_message
- self.custom_result = custom_result
+ handler_name: str
+ success: bool = field(default=True)
+ continue_processing: bool = field(default=True)
+ modified_message: Optional[Dict[str, Any]] = field(default=None)
+ custom_result: Any = field(default=None)
class EventDispatcher:
@@ -48,8 +46,8 @@ class EventDispatcher:
再通过提供的 invoke_fn 回调 RPC 到 Runner 执行。
"""
- def __init__(self, registry: ComponentRegistry) -> None:
- self._registry: ComponentRegistry = registry
+ def __init__(self, component_registry: "ComponentRegistry") -> None:
+ self._component_registry: "ComponentRegistry" = component_registry
self._result_history: Dict[str, List[EventResult]] = {}
self._history_enabled: Set[str] = set()
# 保持 fire-and-forget task 的强引用,防止被 GC 回收
@@ -59,6 +57,10 @@ class EventDispatcher:
self._history_enabled.add(event_type)
self._result_history.setdefault(event_type, [])
+ def disable_history(self, event_type: str) -> None:
+ self._history_enabled.discard(event_type)
+ self._result_history.pop(event_type, None)
+
def get_history(self, event_type: str) -> List[EventResult]:
return self._result_history.get(event_type, [])
@@ -69,44 +71,46 @@ class EventDispatcher:
async def dispatch_event(
self,
event_type: str,
- invoke_fn: InvokeFn,
- message: Optional[Dict[str, Any]] = None,
+ supervisor: "PluginRunnerSupervisor",
+ message_dict: Optional[Dict[str, Any]] = None,
extra_args: Optional[Dict[str, Any]] = None,
) -> Tuple[bool, Optional[Dict[str, Any]]]:
- """分发事件到所有对应 handler。
+ """分发事件到所有对应 handler 的便捷方法。
+
+ 内置了通过 PluginSupervisor.invoke_plugin 调用 plugin.emit_event 的逻辑,
+ 无需调用方手动构造 invoke_fn 闭包。
Args:
event_type: 事件类型字符串
- invoke_fn: 异步回调,签名 (plugin_id, component_name, args) -> response_payload dict
+ supervisor: PluginSupervisor 实例,用于调用 invoke_plugin
message: MaiMessages 序列化后的 dict(可选)
extra_args: 额外参数
Returns:
- (should_continue, modified_message_dict)
+ (should_continue, modified_message_dict) (bool, Dict[str, Any] | None): (是否继续后续执行, 可选的修改后的消息字典)
"""
- handlers = self._registry.get_event_handlers(event_type)
- if not handlers:
+ handler_entries = self._component_registry.get_event_handlers(event_type)
+ if not handler_entries:
return True, None
should_continue = True
- modified_message: Optional[Dict[str, Any]] = None
- intercept_handlers: List[RegisteredComponent] = []
- async_handlers: List[RegisteredComponent] = []
+ modified_message: Optional[Dict[str, Any]] = message_dict
+ intercept_handlers: List["EventHandlerEntry"] = []
+ non_blocking_handlers: List["EventHandlerEntry"] = []
- for handler in handlers:
- if handler.metadata.get("intercept_message", False):
- intercept_handlers.append(handler)
+ for entry in handler_entries:
+ if entry.intercept_message:
+ intercept_handlers.append(entry)
else:
- async_handlers.append(handler)
+ non_blocking_handlers.append(entry)
- for handler in intercept_handlers:
+ for entry in intercept_handlers:
args = {
"event_type": event_type,
- "message": modified_message or message,
+ "message": modified_message,
**(extra_args or {}),
}
-
- result = await self._invoke_handler(invoke_fn, handler, args, event_type)
+ result = await self._invoke_handler(supervisor, entry, args, event_type)
if result and not result.continue_processing:
should_continue = False
break
@@ -114,16 +118,16 @@ class EventDispatcher:
modified_message = result.modified_message
if should_continue:
- final_message = modified_message or message
- for handler in async_handlers:
- async_message = final_message.copy() if isinstance(final_message, dict) else final_message
+ final_message = modified_message
+ for entry in non_blocking_handlers:
+ async_message = final_message.copy() if final_message else final_message
args = {
"event_type": event_type,
"message": async_message,
**(extra_args or {}),
}
# 非阻塞:保持实例级强引用,防止 task 被 GC 回收
- task = asyncio.create_task(self._invoke_handler(invoke_fn, handler, args, event_type))
+ task = asyncio.create_task(self._invoke_handler(supervisor, entry, args, event_type))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
@@ -131,30 +135,34 @@ class EventDispatcher:
async def _invoke_handler(
self,
- invoke_fn: InvokeFn,
- handler: RegisteredComponent,
+ supervisor: "PluginRunnerSupervisor",
+ handler_entry: "EventHandlerEntry",
args: Dict[str, Any],
event_type: str,
) -> Optional[EventResult]:
"""调用单个 handler 并收集结果。"""
try:
- resp = await invoke_fn(handler.plugin_id, handler.name, args)
+ resp_envelope = await supervisor.invoke_plugin(
+ "plugin.emit_event", handler_entry.plugin_id, handler_entry.name, args
+ )
+ resp = resp_envelope.payload
result = EventResult(
- handler_name=handler.full_name,
+ handler_name=handler_entry.full_name,
success=resp.get("success", True),
continue_processing=resp.get("continue_processing", True),
modified_message=resp.get("modified_message"),
custom_result=resp.get("custom_result"),
)
except Exception as e:
- logger.error(f"EventHandler {handler.full_name} 执行失败: {e}", exc_info=True)
- result = EventResult(
- handler_name=handler.full_name,
- success=False,
- continue_processing=True,
- )
+ logger.error(f"EventHandler {handler_entry.full_name} 执行失败: {e}", exc_info=True)
+ result = EventResult(handler_name=handler_entry.full_name, success=False, continue_processing=True)
if event_type in self._history_enabled:
- self._result_history.setdefault(event_type, []).append(result)
+ history_list = self._result_history.setdefault(event_type, [])
+ history_list.append(result)
+ # 自动清理超出限制的旧记录,防止内存无限增长
+ if len(history_list) > _MAX_HISTORY_LENGTH:
+ # 保留最新的 _MAX_HISTORY_LENGTH 条记录
+ self._result_history[event_type] = history_list[-_MAX_HISTORY_LENGTH:]
return result
From 17248a4cbc2fcdda721a879c7d84f2520e54ea87 Mon Sep 17 00:00:00 2001
From: UnCLAS-Prommer
Date: Wed, 18 Mar 2026 20:18:11 +0800
Subject: [PATCH 10/45] =?UTF-8?q?=E6=B7=BB=E5=8A=A0message=20gateway?=
=?UTF-8?q?=E7=BB=84=E4=BB=B6=E7=B1=BB=E5=9E=8B?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
src/plugin_runtime/host/component_registry.py | 25 ++
src/plugin_runtime/host/message_gateway.py | 309 ++++++++++++++++++
2 files changed, 334 insertions(+)
create mode 100644 src/plugin_runtime/host/message_gateway.py
diff --git a/src/plugin_runtime/host/component_registry.py b/src/plugin_runtime/host/component_registry.py
index 22f9a7e0..7ac7a518 100644
--- a/src/plugin_runtime/host/component_registry.py
+++ b/src/plugin_runtime/host/component_registry.py
@@ -118,6 +118,10 @@ class MessageGatewayEntry(ComponentEntry):
"""MessageGateway 组件条目"""
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
+ platform = metadata.get("platform")
+ if not platform or not isinstance(platform, str):
+ raise ValueError(f"MessageGateway 组件 {plugin_id}.{name} 缺少有效的 platform 字段")
+ self.platform: str = platform
super().__init__(name, component_type, plugin_id, metadata)
@@ -399,6 +403,27 @@ class ComponentRegistry:
handlers.sort(key=lambda c: c.priority, reverse=True)
return handlers
+ def get_message_gateways(
+ self, platform: str, *, enabled_only: bool = True, session_id: Optional[str] = None
+ ) -> Optional[MessageGatewayEntry]:
+ """查询消息网关组件。
+
+ Args:
+ platform (str): 平台名称
+ enabled_only (bool): 是否仅返回启用的组件
+ session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态
+ Returns:
+ gateway (Optional[MessageGatewayEntry]): 符合条件的 MessageGateway 组件,可能不存在
+ """
+
+ for comp in self._by_type.get(ComponentTypes.MESSAGE_GATEWAY, {}).values():
+ if not isinstance(comp, MessageGatewayEntry):
+ continue
+ if enabled_only and not self.check_component_enabled(comp, session_id):
+ continue
+ if comp.platform == platform:
+ return comp # 返回第一个
+
def get_tools(self, *, enabled_only: bool = True, session_id: Optional[str] = None) -> List[ToolEntry]:
"""查询所有工具组件。
diff --git a/src/plugin_runtime/host/message_gateway.py b/src/plugin_runtime/host/message_gateway.py
new file mode 100644
index 00000000..e995ed01
--- /dev/null
+++ b/src/plugin_runtime/host/message_gateway.py
@@ -0,0 +1,309 @@
+"""
+Message Gateway 模块
+适配器专用,用于将其他平台的消息转换为系统内部的消息格式,并将系统消息转换为其他平台的格式。
+"""
+
+from datetime import datetime
+from typing import Dict, Any, TYPE_CHECKING, TypedDict, Optional, List
+
+from src.common.logger import get_logger
+from src.chat.message_receive.message import SessionMessage
+from src.common.data_models.mai_message_data_model import UserInfo, GroupInfo, MessageInfo
+from src.common.data_models.message_component_data_model import MessageSequence
+
+if TYPE_CHECKING:
+ from .component_registry import ComponentRegistry
+ from .supervisor import PluginRunnerSupervisor
+
+logger = get_logger("plugin_runtime.host.message_gateway")
+
+
+class UserInfoDict(TypedDict, total=False):
+ user_id: str
+ user_nickname: str
+ user_cardname: Optional[str]
+
+
+class GroupInfoDict(TypedDict, total=False):
+ group_id: str
+ group_name: str
+
+
+class MessageInfoDict(TypedDict, total=False):
+ user_info: UserInfoDict
+ group_info: Optional[GroupInfoDict]
+ additional_config: Dict[str, Any]
+
+
+class MessageDict(TypedDict, total=False):
+ message_id: str
+ timestamp: str
+ platform: str
+ message_info: MessageInfoDict
+ raw_message: List[Dict[str, Any]]
+ is_mentioned: bool
+ is_at: bool
+ is_emoji: bool
+ is_picture: bool
+ is_command: bool
+ is_notify: bool
+ session_id: str
+ reply_to: Optional[str]
+ processed_plain_text: Optional[str]
+ display_message: Optional[str]
+
+
+class MessageGateway:
+ def __init__(self, component_registry: "ComponentRegistry") -> None:
+ self._component_registry = component_registry
+
+ async def receive_external_message(self, external_message: Dict[str, Any]):
+ """
+ 接收外部消息,转换为系统内部格式,并返回转换结果
+
+ Args:
+ external_message: 外部消息的字典格式数据
+
+ Returns:
+ 转换后的 SessionMessage 对象
+ """
+ # 使用递归函数将外部消息字典转换为 SessionMessage
+ try:
+ session_message = self._build_session_message_from_dict(external_message)
+ except Exception as e:
+ logger.error(f"转换外部消息失败: {e}")
+ return
+ from src.chat.message_receive.bot import chat_bot
+
+ await chat_bot.receive_message(session_message)
+
+ async def send_message_to_external(
+ self,
+ internal_message: SessionMessage,
+ supervisor: "PluginRunnerSupervisor",
+ *,
+ enabled_only: bool = True,
+ save_to_db: bool = True,
+ ) -> bool:
+ """
+ 接收系统内部消息,转换为外部格式,并返回转换结果
+
+ Args:
+ internal_message: 系统内部的 SessionMessage 对象
+
+ Returns:
+ 转换是否成功
+ """
+ try:
+ # 将 SessionMessage 转换为字典格式
+ message_dict = self._session_message_to_dict(internal_message)
+ except Exception as e:
+ logger.error(f"转换内部消息失败:{e}")
+ return False
+ gateway_entry = self._component_registry.get_message_gateways(
+ internal_message.platform,
+ enabled_only=enabled_only,
+ session_id=internal_message.session_id,
+ )
+ if not gateway_entry:
+ logger.warning(f"未找到适配平台 {internal_message.platform} 的消息网关组件,无法发送消息到外部平台")
+ return False
+ args = {"platform": internal_message.platform, "message": message_dict}
+ try:
+ resp_envelope = await supervisor.invoke_plugin(
+ "plugin.emit_event", gateway_entry.plugin_id, gateway_entry.name, args
+ )
+ logger.debug("信息发送成功")
+ except Exception as e:
+ logger.error(f"调用消息网关组件失败:{e}")
+ return False
+
+ # 更新为实际id(如果组件返回了新的id)
+ actual_message_id = resp_envelope.payload.get("message_id")
+ try:
+ actual_message_id = str(actual_message_id)
+ except Exception:
+ actual_message_id = None
+ internal_message.message_id = actual_message_id or internal_message.message_id
+ if save_to_db:
+ try:
+ from src.common.utils.utils_message import MessageUtils
+
+ MessageUtils.store_message_to_db(internal_message)
+ except Exception as e:
+ logger.error(f"保存消息到数据库失败: {e}")
+ return True
+
+ def _message_info_to_dict(self, message_info: MessageInfo) -> MessageInfoDict:
+ """
+ 将 MessageInfo 对象转换为字典格式
+
+ Args:
+ message_info: MessageInfo 对象
+
+ Returns:
+ 字典格式的消息信息
+ """
+ user_info_dict = UserInfoDict(
+ user_id=message_info.user_info.user_id,
+ user_nickname=message_info.user_info.user_nickname,
+ user_cardname=message_info.user_info.user_cardname,
+ )
+
+ group_info_dict: Optional[GroupInfoDict] = None
+ if message_info.group_info:
+ group_info_dict = GroupInfoDict(
+ group_id=message_info.group_info.group_id,
+ group_name=message_info.group_info.group_name,
+ )
+
+ return MessageInfoDict(
+ user_info=user_info_dict,
+ group_info=group_info_dict,
+ additional_config=message_info.additional_config,
+ )
+
+ def _session_message_to_dict(self, session_message: SessionMessage) -> MessageDict:
+ """
+ 将 SessionMessage 对象转换为字典格式(复用 MessageSequence.to_dict 方法)
+
+ Args:
+ session_message: SessionMessage 对象
+
+ Returns:
+ 字典格式的消息
+ """
+ # 转换基本信息
+ message_dict = MessageDict(
+ message_id=session_message.message_id,
+ timestamp=str(session_message.timestamp.timestamp()), # 转换为时间戳字符串
+ platform=session_message.platform,
+ message_info=self._message_info_to_dict(session_message.message_info),
+ raw_message=session_message.raw_message.to_dict(), # 复用 MessageSequence.to_dict()
+ is_mentioned=session_message.is_mentioned,
+ is_at=session_message.is_at,
+ is_emoji=session_message.is_emoji,
+ is_picture=session_message.is_picture,
+ is_command=session_message.is_command,
+ is_notify=session_message.is_notify,
+ session_id=session_message.session_id,
+ )
+
+ # 添加可选字段
+ if session_message.reply_to is not None:
+ message_dict["reply_to"] = session_message.reply_to
+ if session_message.processed_plain_text is not None:
+ message_dict["processed_plain_text"] = session_message.processed_plain_text
+ if session_message.display_message is not None:
+ message_dict["display_message"] = session_message.display_message
+
+ return message_dict
+
+ def _build_message_info_from_dict(self, message_info_dict: Dict[str, Any]) -> MessageInfo:
+ """
+ 从字典构建 MessageInfo 对象
+
+ Args:
+ message_info_dict: 包含消息信息的字典
+
+ Returns:
+ MessageInfo 对象
+ """
+ # 构建用户信息
+ user_info_dict = message_info_dict.get("user_info")
+ if not user_info_dict or not isinstance(user_info_dict, dict):
+ raise ValueError("消息字典中 'user_info' 字段无效")
+ user_id = user_info_dict.get("user_id")
+ user_nickname = user_info_dict.get("user_nickname")
+ user_cardname = user_info_dict.get("user_cardname")
+ if not isinstance(user_id, str) or not isinstance(user_nickname, str) or not user_id or not user_nickname:
+ raise ValueError("消息字典中 'user_info' 字段缺少有效的 'user_id' 或 'user_nickname'")
+ user_cardname = str(user_cardname) if user_cardname is not None else None
+ user_info = UserInfo(user_id=user_id, user_nickname=user_nickname, user_cardname=user_cardname)
+
+ # 构建群信息
+ if group_info_dict := message_info_dict.get("group_info"):
+ group_id = group_info_dict.get("group_id")
+ group_name = group_info_dict.get("group_name")
+ if not isinstance(group_id, str) or not isinstance(group_name, str) or not group_id or not group_name:
+ raise ValueError("消息字典中 'group_info' 字段缺少有效的 'group_id' 或 'group_name'")
+ group_info = GroupInfo(group_id=group_id, group_name=group_name)
+ else:
+ group_info = None
+
+ # 获取额外配置
+ additional_config: Dict[str, Any] = message_info_dict.get("additional_config", {})
+
+ return MessageInfo(user_info=user_info, group_info=group_info, additional_config=additional_config)
+
+ def _build_session_message_from_dict(self, message_dict: Dict[str, Any]) -> SessionMessage:
+ """
+ 从字典构建 SessionMessage 对象(递归处理消息组件)
+
+ Args:
+ message_dict: 包含消息完整信息的字典
+
+ Returns:
+ SessionMessage 对象
+ """
+ # 提取基本信息
+ message_id = message_dict["message_id"]
+ timestamp_str: str = message_dict.get("timestamp", "")
+ platform = message_dict["platform"]
+ if not isinstance(message_id, str) or not message_id:
+ raise ValueError("消息字典中缺少有效的 'message_id' 字段")
+ if not isinstance(platform, str) or not platform:
+ raise ValueError("消息字典中缺少有效的 'platform' 字段")
+
+ # 解析时间戳
+ try:
+ timestamp_float = float(timestamp_str)
+ timestamp = datetime.fromtimestamp(timestamp_float)
+ except (ValueError, TypeError):
+ timestamp = datetime.now() # 如果解析失败,使用当前时间
+
+ # 创建 SessionMessage 实例
+ session_message = SessionMessage(message_id=message_id, timestamp=timestamp, platform=platform)
+
+ # 构建消息信息
+ session_message.message_info = self._build_message_info_from_dict(message_dict["message_info"])
+
+ # 构建原始消息组件序列(复用 MessageSequence.from_dict 方法)
+ raw_message_data = message_dict["raw_message"]
+ if isinstance(raw_message_data, list):
+ session_message.raw_message = MessageSequence.from_dict(raw_message_data)
+ else:
+ raise ValueError("消息字典中 'raw_message' 字段必须是一个列表")
+
+ # 设置其他可选属性
+ session_message.is_mentioned = message_dict.get("is_mentioned", False)
+ if not isinstance(session_message.is_mentioned, bool):
+ session_message.is_mentioned = False
+ session_message.is_at = message_dict.get("is_at", False)
+ if not isinstance(session_message.is_at, bool):
+ session_message.is_at = False
+ session_message.is_emoji = message_dict.get("is_emoji", False)
+ if not isinstance(session_message.is_emoji, bool):
+ session_message.is_emoji = False
+ session_message.is_picture = message_dict.get("is_picture", False)
+ if not isinstance(session_message.is_picture, bool):
+ session_message.is_picture = False
+ session_message.is_command = message_dict.get("is_command", False)
+ if not isinstance(session_message.is_command, bool):
+ session_message.is_command = False
+ session_message.is_notify = message_dict.get("is_notify", False)
+ if not isinstance(session_message.is_notify, bool):
+ session_message.is_notify = False
+ session_message.reply_to = message_dict.get("reply_to")
+ if session_message.reply_to is not None and not isinstance(session_message.reply_to, str):
+ session_message.reply_to = None
+ session_message.processed_plain_text = message_dict.get("processed_plain_text")
+ if session_message.processed_plain_text is not None and not isinstance(
+ session_message.processed_plain_text, str
+ ):
+ session_message.processed_plain_text = None
+ session_message.display_message = message_dict.get("display_message")
+ if session_message.display_message is not None and not isinstance(session_message.display_message, str):
+ session_message.display_message = None
+
+ return session_message
From 593400c0aacc543fe8385bd10fcabe76dc184ee4 Mon Sep 17 00:00:00 2001
From: UnCLAS-Prommer
Date: Wed, 18 Mar 2026 20:24:07 +0800
Subject: [PATCH 11/45] =?UTF-8?q?bot.py=E6=94=AF=E6=8C=81gateway=E7=9A=84?=
=?UTF-8?q?=E4=BF=AE=E6=94=B9?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
src/chat/message_receive/bot.py | 16 ++++++++++++----
1 file changed, 12 insertions(+), 4 deletions(-)
diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py
index 60586406..23e7de6e 100644
--- a/src/chat/message_receive/bot.py
+++ b/src/chat/message_receive/bot.py
@@ -310,6 +310,14 @@ class ChatBot:
# logger.debug(str(message_data))
maim_raw_message = MessageBase.from_dict(message_data)
message = SessionMessage.from_maim_message(maim_raw_message)
+ await self.receive_message(message)
+
+ except Exception as e:
+ logger.error(f"预处理消息失败: {e}")
+ traceback.print_exc()
+
+ async def receive_message(self, message: SessionMessage):
+ try:
group_info = message.message_info.group_info
user_info = message.message_info.user_info
@@ -366,11 +374,11 @@ class ChatBot:
# 命令处理 - 使用新插件系统检查并处理命令
# 注意:命令返回的 response 当前只用于日志记录和流程判断,
# 不会在这里自动作为回复消息发送回会话。
- is_command, cmd_result, continue_process = await self._process_commands(message)
+ # is_command, cmd_result, continue_process = await self._process_commands(message)
- # 如果是命令且不需要继续处理,则直接返回
- if is_command and await self._handle_command_processing_result(message, cmd_result, continue_process):
- return
+ # # 如果是命令且不需要继续处理,则直接返回
+ # if is_command and await self._handle_command_processing_result(message, cmd_result, continue_process):
+ # return
# continue_flag, modified_message = await events_manager.handle_mai_events(EventType.ON_MESSAGE, message)
# if not continue_flag:
From 310d7798ba81cfd01981dee9dc87ef8100a2e2c9 Mon Sep 17 00:00:00 2001
From: UnCLAS-Prommer
Date: Thu, 19 Mar 2026 00:35:19 +0800
Subject: [PATCH 12/45] =?UTF-8?q?refactor:=20hook=5Fdispatcher=E7=9B=B8?=
=?UTF-8?q?=E5=85=B3=E7=9A=84=E4=BF=AE=E6=94=B9?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
src/common/logger_color_and_mapping.py | 2 +-
src/config/official_configs.py | 6 +-
src/plugin_runtime/host/component_registry.py | 27 +-
src/plugin_runtime/host/event_dispatcher.py | 8 +-
src/plugin_runtime/host/hook_dispatcher.py | 166 +++++++
src/plugin_runtime/host/workflow_executor.py | 422 ------------------
src/plugin_runtime/protocol/envelope.py | 4 +-
7 files changed, 193 insertions(+), 442 deletions(-)
create mode 100644 src/plugin_runtime/host/hook_dispatcher.py
delete mode 100644 src/plugin_runtime/host/workflow_executor.py
diff --git a/src/common/logger_color_and_mapping.py b/src/common/logger_color_and_mapping.py
index 1aabafbc..f84caa21 100644
--- a/src/common/logger_color_and_mapping.py
+++ b/src/common/logger_color_and_mapping.py
@@ -58,7 +58,7 @@ MODULE_COLORS: Dict[str, Tuple[str, Optional[str], bool]] = {
"plugin_runtime.host.component_registry": ("#ffaf00", None, False),
"plugin_runtime.host.capability_service": ("#ffd700", None, False),
"plugin_runtime.host.event_dispatcher": ("#87d700", None, False),
- "plugin_runtime.host.workflow_executor": ("#5fd7af", None, False),
+ "plugin_runtime.host.hook_dispatcher": ("#5fd7af", None, False),
"plugin_runtime.runner.main": ("#d787ff", None, False),
"plugin_runtime.runner.rpc_client": ("#8787ff", None, False),
"plugin_runtime.runner.manifest_validator": ("#5fafff", None, False),
diff --git a/src/config/official_configs.py b/src/config/official_configs.py
index e6907611..35360217 100644
--- a/src/config/official_configs.py
+++ b/src/config/official_configs.py
@@ -1642,14 +1642,14 @@ class PluginRuntimeConfig(ConfigBase):
)
"""等待 Runner 子进程启动并注册的超时时间(秒)"""
- workflow_blocking_timeout_sec: float = Field(
- default=120.0,
+ hook_blocking_timeout_sec: float = Field(
+ default=30,
json_schema_extra={
"x-widget": "number",
"x-icon": "timer",
},
)
- """Workflow 阻塞步骤的全局超时上限(秒)"""
+ """Hook 阻塞步骤的全局超时上限(秒)"""
ipc_socket_path: str = Field(
default="",
diff --git a/src/plugin_runtime/host/component_registry.py b/src/plugin_runtime/host/component_registry.py
index 7ac7a518..95da0052 100644
--- a/src/plugin_runtime/host/component_registry.py
+++ b/src/plugin_runtime/host/component_registry.py
@@ -25,7 +25,7 @@ class ComponentTypes(str, Enum):
COMMAND = "COMMAND"
TOOL = "TOOL"
EVENT_HANDLER = "EVENT_HANDLER"
- WORKFLOW_HANDLER = "WORKFLOW_HANDLER"
+ HOOK_HANDLER = "HOOK_HANDLER"
MESSAGE_GATEWAY = "MESSAGE_GATEWAY"
@@ -35,7 +35,7 @@ class StatusDict(TypedDict):
COMMAND: int
TOOL: int
EVENT_HANDLER: int
- WORKFLOW_HANDLER: int
+ HOOK_HANDLER: int
MESSAGE_GATEWAY: int
plugins: int
@@ -105,12 +105,13 @@ class EventHandlerEntry(ComponentEntry):
super().__init__(name, component_type, plugin_id, metadata)
-class WorkflowHandlerEntry(ComponentEntry):
+class HookHandlerEntry(ComponentEntry):
"""WorkflowHandler 组件条目"""
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
self.stage: str = metadata.get("stage", "")
self.priority: int = metadata.get("priority", 0)
+ self.blocking: bool = metadata.get("blocking", False)
super().__init__(name, component_type, plugin_id, metadata)
@@ -172,8 +173,8 @@ class ComponentRegistry:
comp = ToolEntry(name, component_type, plugin_id, metadata)
elif component_type == ComponentTypes.EVENT_HANDLER:
comp = EventHandlerEntry(name, component_type, plugin_id, metadata)
- elif component_type == ComponentTypes.WORKFLOW_HANDLER:
- comp = WorkflowHandlerEntry(name, component_type, plugin_id, metadata)
+ elif component_type == ComponentTypes.HOOK_HANDLER:
+ comp = HookHandlerEntry(name, component_type, plugin_id, metadata)
elif component_type == ComponentTypes.MESSAGE_GATEWAY:
comp = MessageGatewayEntry(name, component_type, plugin_id, metadata)
else:
@@ -380,23 +381,23 @@ class ComponentRegistry:
handlers.sort(key=lambda c: c.weight, reverse=True)
return handlers
- def get_workflow_handlers(
+ def get_hook_handlers(
self, stage: str, *, enabled_only: bool = True, session_id: Optional[str] = None
- ) -> List[WorkflowHandlerEntry]:
- """获取特定 workflow 阶段的所有步骤,按 priority 降序。
+ ) -> List[HookHandlerEntry]:
+ """获取特定 hook 阶段的所有步骤,按 priority 降序。
Args:
- stage: workflow 阶段名称
+ stage: hook 名称
enabled_only: 是否仅返回启用的组件
session_id: 可选的会话ID,若提供则考虑会话禁用状态
Returns:
- handlers (List[WorkflowHandlerEntry]): 符合条件的 WorkflowHandler 组件列表,按 priority 降序排序
+ handlers (List[HookHandlerEntry]): 符合条件的 HookHandler 组件列表,按 priority 降序排序
"""
- handlers: List[WorkflowHandlerEntry] = []
- for comp in self._by_type.get(ComponentTypes.WORKFLOW_HANDLER, {}).values():
+ handlers: List[HookHandlerEntry] = []
+ for comp in self._by_type.get(ComponentTypes.HOOK_HANDLER, {}).values():
if enabled_only and not self.check_component_enabled(comp, session_id):
continue
- if not isinstance(comp, WorkflowHandlerEntry):
+ if not isinstance(comp, HookHandlerEntry):
continue
if comp.stage == stage:
handlers.append(comp)
diff --git a/src/plugin_runtime/host/event_dispatcher.py b/src/plugin_runtime/host/event_dispatcher.py
index f08591a8..29ae530b 100644
--- a/src/plugin_runtime/host/event_dispatcher.py
+++ b/src/plugin_runtime/host/event_dispatcher.py
@@ -50,7 +50,6 @@ class EventDispatcher:
self._component_registry: "ComponentRegistry" = component_registry
self._result_history: Dict[str, List[EventResult]] = {}
self._history_enabled: Set[str] = set()
- # 保持 fire-and-forget task 的强引用,防止被 GC 回收
self._background_tasks: Set[asyncio.Task] = set()
def enable_history(self, event_type: str) -> None:
@@ -68,6 +67,13 @@ class EventDispatcher:
if event_type in self._result_history:
self._result_history[event_type] = []
+ async def stop(self):
+ """停止 EventDispatcher,取消所有未完成的后台任务"""
+ for task in self._background_tasks:
+ task.cancel()
+ await asyncio.gather(*self._background_tasks, return_exceptions=True)
+ self._background_tasks.clear()
+
async def dispatch_event(
self,
event_type: str,
diff --git a/src/plugin_runtime/host/hook_dispatcher.py b/src/plugin_runtime/host/hook_dispatcher.py
new file mode 100644
index 00000000..d5e88448
--- /dev/null
+++ b/src/plugin_runtime/host/hook_dispatcher.py
@@ -0,0 +1,166 @@
+"""
+Hook Dispatch 系统
+
+插件可以注册自己的Hook,当特定函数被调用时,Hook Dispatch系统会将调用转发给插件的Hook处理函数。
+每个Hook的参数随Hook点位确定,因此参数是易变的。插件开发者需要根据Hook点位的定义来编写Hook处理函数。
+在参数/返回值匹配的情况下允许修改参数/返回值。
+
+HookDispatcher 负责:
+1. 按 stage 查询已注册的 hook_handler(通过 ComponentRegistry)
+2. 按 priority 排序,区分 blocking 和非 blocking 模式
+3. blocking 模式:依次同步调用,支持修改参数/提前终止
+4. 非 blocking 模式:异步调用,不阻塞主流程
+5. 支持通过 global_config.plugin_runtime.hook_blocking_timeout_sec 设置超时上限
+"""
+
+import asyncio
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING
+
+from src.common.logger import get_logger
+from src.config.config import global_config
+
+
+if TYPE_CHECKING:
+ from .supervisor import PluginRunnerSupervisor
+ from .component_registry import ComponentRegistry, HookHandlerEntry
+
+logger = get_logger("plugin_runtime.host.hook_dispatcher")
+
+
+@dataclass
+class HookResult:
+ """单个 HookHandler 的执行结果"""
+
+ handler_name: str
+ success: bool = field(default=True)
+ continue_processing: bool = field(default=True)
+ modified_kwargs: Optional[Dict[str, Any]] = field(default=None)
+ custom_result: Any = field(default=None)
+
+
+class HookDispatcher:
+ """Host-side Hook 分发器
+
+ 由业务层调用 hook_dispatch(),
+ 内部通过 ComponentRegistry 查询 handler,
+ 再通过提供的 invoke_fn 回调 RPC 到 Runner 执行。
+ """
+
+ def __init__(self, component_registry: "ComponentRegistry") -> None:
+ """初始化 HookDispatcher
+
+ Args:
+ component_registry: ComponentRegistry 实例,用于查询已注册的 hook_handler
+ """
+ self._component_registry: "ComponentRegistry" = component_registry
+ self._background_tasks: Set[asyncio.Task] = set()
+
+ async def stop(self) -> None:
+ """停止 HookDispatcher,取消所有未完成的后台任务"""
+ for task in self._background_tasks:
+ task.cancel()
+ await asyncio.gather(*self._background_tasks, return_exceptions=True)
+ self._background_tasks.clear()
+
+ async def hook_dispatch(
+ self,
+ stage: str,
+ supervisor: "PluginRunnerSupervisor",
+ **kwargs: Any,
+ ) -> Dict[str, Any]:
+ """分发 hook 到所有对应 handler 的便捷方法。
+
+ 内置了通过 PluginRunnerSupervisor.invoke_plugin 调用 plugin 的逻辑,
+ 无需调用方手动构造 invoke_fn 闭包。
+
+ Args:
+ stage: hook 名称
+ supervisor: PluginRunnerSupervisor 实例,用于调用 invoke_plugin
+ **kwargs: 关键字参数,会展开传递给 handler
+
+ Returns:
+ modified_kwargs (Dict[str, Any]): 经过所有 handler 修改后的关键字参数
+ """
+ handler_entries = self._component_registry.get_hook_handlers(stage)
+ if not handler_entries:
+ return kwargs
+
+ current_kwargs = kwargs.copy()
+ blocking_handlers: List["HookHandlerEntry"] = []
+ non_blocking_handlers: List["HookHandlerEntry"] = []
+
+ # 分离 blocking 和非 blocking handler
+ for entry in handler_entries:
+ if entry.blocking:
+ blocking_handlers.append(entry)
+ else:
+ non_blocking_handlers.append(entry)
+
+ # 处理 blocking handlers(同步调用,支持修改参数/提前终止)
+ timeout = global_config.plugin_runtime.hook_blocking_timeout_sec or 30.0
+ for entry in blocking_handlers:
+ hook_args = {"stage": stage, **current_kwargs}
+ try:
+ # 应用超时控制
+ result = await asyncio.wait_for(
+ self._invoke_handler(supervisor, entry, hook_args),
+ timeout=timeout,
+ )
+ except asyncio.TimeoutError:
+ logger.error(f"Blocking HookHandler {entry.full_name} 执行超时 (>{timeout}秒),跳过")
+ result = HookResult(handler_name=entry.full_name, success=False, continue_processing=True)
+
+ if result:
+ if result.modified_kwargs is not None:
+ current_kwargs = result.modified_kwargs
+ if not result.continue_processing:
+ logger.info(f"HookHandler {entry.full_name} 终止了后续处理")
+ break
+
+ # 处理 non-blocking handlers(异步调用,不阻塞主流程)
+ for entry in non_blocking_handlers:
+ async_kwargs = current_kwargs.copy()
+ hook_args = {"stage": stage, **async_kwargs}
+ task = asyncio.create_task(
+ asyncio.wait_for(self._invoke_handler(supervisor, entry, hook_args), timeout=timeout)
+ )
+ self._background_tasks.add(task)
+ task.add_done_callback(self._background_tasks.discard)
+
+ return current_kwargs
+
+ async def _invoke_handler(
+ self,
+ supervisor: "PluginRunnerSupervisor",
+ handler_entry: "HookHandlerEntry",
+ args: Dict[str, Any],
+ ) -> Optional[HookResult]:
+ """调用单个 handler 并收集结果。
+
+ Args:
+ supervisor: PluginRunnerSupervisor 实例
+ handler_entry: HookHandlerEntry 实例
+ args: 传递给 handler 的参数字典
+ stage: hook 名称
+
+ Returns:
+ Optional[HookResult]: 执行结果,如果执行失败则返回 None
+ """
+ try:
+ resp_envelope = await supervisor.invoke_plugin(
+ "plugin.invoke_hook", handler_entry.plugin_id, handler_entry.name, args
+ )
+ resp = resp_envelope.payload
+ result = HookResult(
+ handler_name=handler_entry.full_name,
+ success=resp.get("success", True),
+ continue_processing=resp.get("continue_processing", True),
+ modified_kwargs=resp.get("modified_kwargs"),
+ custom_result=resp.get("custom_result"),
+ )
+ except Exception as e:
+ logger.error(f"HookHandler {handler_entry.full_name} 执行失败:{e}", exc_info=True)
+ result = HookResult(handler_name=handler_entry.full_name, success=False, continue_processing=True)
+
+ return result
diff --git a/src/plugin_runtime/host/workflow_executor.py b/src/plugin_runtime/host/workflow_executor.py
deleted file mode 100644
index 3037e9dd..00000000
--- a/src/plugin_runtime/host/workflow_executor.py
+++ /dev/null
@@ -1,422 +0,0 @@
-"""Host-side WorkflowExecutor
-
-6 阶段线性流转(INGRESS → PRE_PROCESS → PLAN → TOOL_EXECUTE → POST_PROCESS → EGRESS)
-
-每个阶段执行顺序:
-1. Host-side pre-filter: 根据 hook filter 条件过滤不相关的 hook
-2. 按 priority 降序排列
-3. 串行执行 blocking hook(可修改 message,返回 HookResult)
-4. 并发执行 non-blocking hook(只读)
-5. 检查是否有 SKIP_STAGE 或 ABORT
-6. PLAN 阶段内置 Command 匹配路由
-
-支持:
-- HookResult: CONTINUE / SKIP_STAGE / ABORT
-- ErrorPolicy: ABORT / SKIP / LOG (per-hook)
-- stage_outputs: 阶段间带命名空间的数据传递
-- modification_log: 消息修改审计
-"""
-
-from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple
-
-import asyncio
-import time
-import uuid
-
-from src.common.logger import get_logger
-from src.config.config import global_config
-from src.plugin_runtime.host.component_registry import ComponentRegistry, RegisteredComponent
-
-logger = get_logger("plugin_runtime.host.workflow_executor")
-
-# 阶段顺序
-STAGE_SEQUENCE: List[str] = [
- "ingress",
- "pre_process",
- "plan",
- "tool_execute",
- "post_process",
- "egress",
-]
-
-# HookResult 常量(与 SDK HookResult enum 值对应)
-HOOK_CONTINUE = "continue"
-HOOK_SKIP_STAGE = "skip_stage"
-HOOK_ABORT = "abort"
-
-
-# blocking hook 全局最大超时(秒):即使 hook 声明 timeout_ms=0 也不会无限等待
-# 从配置文件读取,允许用户调整
-def _get_blocking_timeout() -> float:
- return global_config.plugin_runtime.workflow_blocking_timeout_sec
-
-
-class ModificationRecord:
- """消息修改记录"""
-
- __slots__ = ("stage", "hook_name", "timestamp", "fields_changed")
-
- def __init__(self, stage: str, hook_name: str, fields_changed: List[str]) -> None:
- self.stage = stage
- self.hook_name = hook_name
- self.timestamp = time.perf_counter()
- self.fields_changed = fields_changed
-
-
-class WorkflowContext:
- """Workflow 执行上下文"""
-
- def __init__(self, trace_id: Optional[str] = None, stream_id: Optional[str] = None) -> None:
- self.trace_id = trace_id or uuid.uuid4().hex
- self.stream_id = stream_id
- self.timings: Dict[str, float] = {}
- self.errors: List[str] = []
- # 阶段间数据传递(按 stage 命名空间隔离)
- self.stage_outputs: Dict[str, Dict[str, Any]] = {}
- # 消息修改审计日志
- self.modification_log: List[ModificationRecord] = []
- # PLAN 阶段命令匹配结果
- self.matched_command: Optional[str] = None
-
- def set_stage_output(self, stage: str, key: str, value: Any) -> None:
- self.stage_outputs.setdefault(stage, {})[key] = value
-
- def get_stage_output(self, stage: str, key: str, default: Any = None) -> Any:
- return self.stage_outputs.get(stage, {}).get(key, default)
-
-
-class WorkflowResult:
- """Workflow 执行结果"""
-
- def __init__(
- self,
- status: str = "completed", # completed / aborted / failed
- return_message: str = "",
- stopped_at: str = "",
- diagnostics: Optional[Dict[str, Any]] = None,
- ) -> None:
- self.status = status
- self.return_message = return_message
- self.stopped_at = stopped_at
- self.diagnostics = diagnostics or {}
-
-
-# invoke_fn 签名
-InvokeFn = Callable[[str, str, Dict[str, Any]], Awaitable[Dict[str, Any]]]
-
-
-class WorkflowExecutor:
- """Host-side Workflow 执行器
-
- 实现 stage-based pipeline + per-stage hook chain with priority + early return。
- """
-
- def __init__(self, registry: ComponentRegistry) -> None:
- self._registry = registry
- self._background_tasks: Set[asyncio.Task] = set()
-
- async def execute(
- self,
- invoke_fn: InvokeFn,
- message: Optional[Dict[str, Any]] = None,
- stream_id: Optional[str] = None,
- context: Optional[WorkflowContext] = None,
- command_invoke_fn: Optional[InvokeFn] = None,
- ) -> Tuple[WorkflowResult, Optional[Dict[str, Any]], WorkflowContext]:
- """执行 workflow pipeline。
-
- Args:
- invoke_fn: 用于 workflow_step 的回调
- command_invoke_fn: 用于 command 的回调(走 plugin.invoke_command),
- 未传则复用 invoke_fn
-
- Returns:
- (result, final_message, context)
- """
- ctx = context or WorkflowContext(stream_id=stream_id)
- current_message = dict(message) if message else None
-
- for stage in STAGE_SEQUENCE:
- stage_start = time.perf_counter()
-
- try:
- # PLAN 阶段: 先做 Command 路由
- if stage == "plan" and current_message:
- cmd_result = await self._route_command(command_invoke_fn or invoke_fn, current_message, ctx)
- if cmd_result is not None:
- # 命令匹配成功,跳过 PLAN 阶段的 hook,直接存结果进 stage_outputs
- ctx.set_stage_output("plan", "command_result", cmd_result)
- ctx.timings[stage] = time.perf_counter() - stage_start
- continue
-
- # 获取该阶段所有 hook(已按 priority 降序排列)
- all_steps = self._registry.get_workflow_steps(stage)
- if not all_steps:
- ctx.timings[stage] = time.perf_counter() - stage_start
- continue
-
- # 1. Pre-filter
- filtered_steps = self._pre_filter(all_steps, current_message)
-
- # 2. 分离 blocking 和 non-blocking
- blocking_steps = [s for s in filtered_steps if s.metadata.get("blocking", True)]
- nonblocking_steps = [s for s in filtered_steps if not s.metadata.get("blocking", True)]
-
- # 3. 串行执行 blocking hook
- skip_stage = False
- for step in blocking_steps:
- hook_result, modified, step_error = await self._invoke_step(
- invoke_fn, step, stage, ctx, current_message
- )
-
- if step_error:
- error_policy = step.metadata.get("error_policy", "abort")
- ctx.errors.append(f"{step.full_name}: {step_error}")
-
- if error_policy == "abort":
- ctx.timings[stage] = time.perf_counter() - stage_start
- return (
- WorkflowResult(
- status="failed",
- return_message=step_error,
- stopped_at=stage,
- diagnostics={"step": step.full_name, "trace_id": ctx.trace_id},
- ),
- current_message,
- ctx,
- )
- elif error_policy == "skip":
- logger.warning(f"[{ctx.trace_id}] hook {step.full_name} 异常(skip): {step_error}")
- continue
- else: # log
- logger.warning(f"[{ctx.trace_id}] hook {step.full_name} 异常(log): {step_error}")
- continue
-
- # 更新消息(仅 blocking hook 有权修改)
- if modified:
- changed_fields = (
- _diff_keys(current_message, modified) if current_message else list(modified.keys())
- )
- ctx.modification_log.append(ModificationRecord(stage, step.full_name, changed_fields))
- current_message = modified
-
- if hook_result == HOOK_ABORT:
- ctx.timings[stage] = time.perf_counter() - stage_start
- return (
- WorkflowResult(
- status="aborted",
- return_message=f"aborted by {step.full_name}",
- stopped_at=stage,
- diagnostics={"step": step.full_name, "trace_id": ctx.trace_id},
- ),
- current_message,
- ctx,
- )
-
- if hook_result == HOOK_SKIP_STAGE:
- skip_stage = True
- break
-
- # 4. 并发执行 non-blocking hook(只读,忽略返回值中的 modified_message)
- if nonblocking_steps and not skip_stage:
- for step in nonblocking_steps:
- self._track_background_task(
- asyncio.create_task(
- self._invoke_step_fire_and_forget(invoke_fn, step, stage, ctx, current_message)
- )
- )
-
- ctx.timings[stage] = time.perf_counter() - stage_start
-
- except Exception as e:
- ctx.timings[stage] = time.perf_counter() - stage_start
- ctx.errors.append(f"{stage}: {e}")
- logger.error(f"[{ctx.trace_id}] 阶段 {stage} 未捕获异常: {e}", exc_info=True)
- return (
- WorkflowResult(
- status="failed",
- return_message=str(e),
- stopped_at=stage,
- diagnostics={"trace_id": ctx.trace_id},
- ),
- current_message,
- ctx,
- )
-
- return (
- WorkflowResult(
- status="completed",
- return_message="workflow completed",
- diagnostics={"trace_id": ctx.trace_id},
- ),
- current_message,
- ctx,
- )
-
- def _track_background_task(self, task: asyncio.Task) -> None:
- """保持 non-blocking workflow task 的强引用,直到任务结束。"""
- self._background_tasks.add(task)
- task.add_done_callback(self._background_tasks.discard)
-
- # ─── 内部方法 ──────────────────────────────────────────────
-
- def _pre_filter(
- self,
- steps: List[RegisteredComponent],
- message: Optional[Dict[str, Any]],
- ) -> List[RegisteredComponent]:
- """根据 hook 声明的 filter 条件预过滤,避免无意义的 IPC 调用。"""
- if not message:
- return steps
-
- result = []
- for step in steps:
- filter_cond = step.metadata.get("filter", {})
- if not filter_cond:
- result.append(step)
- continue
- if self._match_filter(filter_cond, message):
- result.append(step)
- return result
-
- @staticmethod
- def _match_filter(filter_cond: Dict[str, Any], message: Dict[str, Any]) -> bool:
- """简单 key-value 匹配过滤。
-
- filter 中的每个 key 必须在 message 中存在且值相等,
- 全部匹配才通过。
- """
- for key, expected in filter_cond.items():
- actual = message.get(key)
- if (isinstance(expected, list) and actual not in expected) or (
- not isinstance(expected, list) and actual != expected
- ):
- return False
- return True
-
- async def _invoke_step(
- self,
- invoke_fn: InvokeFn,
- step: RegisteredComponent,
- stage: str,
- ctx: WorkflowContext,
- message: Optional[Dict[str, Any]],
- ) -> Tuple[str, Optional[Dict[str, Any]], Optional[str]]:
- """调用单个 blocking hook。
-
- Returns:
- (hook_result, modified_message, error_string_or_None)
- """
- timeout_ms = step.metadata.get("timeout_ms", 0)
- # 使用 hook 声明的超时,但不超过全局安全阀
- timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else _get_blocking_timeout()
- step_key = f"{stage}:{step.full_name}"
- step_start = time.perf_counter()
-
- try:
- coro = invoke_fn(
- step.plugin_id,
- step.name,
- {
- "stage": stage,
- "trace_id": ctx.trace_id,
- "message": message,
- "stage_outputs": ctx.stage_outputs,
- },
- )
- resp = await asyncio.wait_for(coro, timeout=timeout_sec)
- ctx.timings[step_key] = time.perf_counter() - step_start
-
- hook_result = resp.get("hook_result", HOOK_CONTINUE)
- modified_message = resp.get("modified_message")
- # 存 stage output(如果 hook 提供了)
- stage_out = resp.get("stage_output")
- if isinstance(stage_out, dict):
- for k, v in stage_out.items():
- ctx.set_stage_output(stage, k, v)
-
- return hook_result, modified_message, None
-
- except asyncio.TimeoutError:
- ctx.timings[step_key] = time.perf_counter() - step_start
- return HOOK_CONTINUE, None, f"timeout after {timeout_ms}ms"
-
- except Exception as e:
- ctx.timings[step_key] = time.perf_counter() - step_start
- return HOOK_CONTINUE, None, str(e)
-
- async def _invoke_step_fire_and_forget(
- self,
- invoke_fn: InvokeFn,
- step: RegisteredComponent,
- stage: str,
- ctx: WorkflowContext,
- message: Optional[Dict[str, Any]],
- ) -> None:
- """Non-blocking hook 调用,只读,忽略结果。"""
- timeout_ms = step.metadata.get("timeout_ms", 0)
- # 使用 hook 声明的超时,但无声明时回退到全局安全阀,防止 task 泄漏
- timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else _get_blocking_timeout()
-
- try:
- coro = invoke_fn(
- step.plugin_id,
- step.name,
- {
- "stage": stage,
- "trace_id": ctx.trace_id,
- "message": message,
- "stage_outputs": ctx.stage_outputs,
- },
- )
- await asyncio.wait_for(coro, timeout=timeout_sec)
- except asyncio.TimeoutError:
- logger.warning(f"[{ctx.trace_id}] non-blocking hook {step.full_name} 超时 ({timeout_sec}s)")
- except Exception as e:
- logger.debug(f"[{ctx.trace_id}] non-blocking hook {step.full_name}: {e}")
-
- async def _route_command(
- self,
- invoke_fn: InvokeFn,
- message: Dict[str, Any],
- ctx: WorkflowContext,
- ) -> Optional[Dict[str, Any]]:
- """PLAN 阶段内置 Command 路由。
-
- 在 registry 中查找匹配的 command 组件,
- 匹配到则直接路由到对应 command handler,返回执行结果。
- 不匹配则返回 None,让 PLAN 阶段的 hook 继续执行。
- """
- plain_text = message.get("plain_text", "")
- if not plain_text:
- return None
-
- match_result = self._registry.find_command_by_text(plain_text)
- if match_result is None:
- return None
-
- matched, matched_groups = match_result
-
- ctx.matched_command = matched.full_name
- logger.info(f"[{ctx.trace_id}] 命令匹配: {matched.full_name}")
-
- try:
- return await invoke_fn(
- matched.plugin_id,
- matched.name,
- {
- "text": plain_text,
- "message": message,
- "trace_id": ctx.trace_id,
- "matched_groups": matched_groups,
- },
- )
- except Exception as e:
- logger.error(f"[{ctx.trace_id}] 命令 {matched.full_name} 执行失败: {e}", exc_info=True)
- ctx.errors.append(f"command:{matched.full_name}: {e}")
- return None
-
-
-def _diff_keys(old: Dict[str, Any], new: Dict[str, Any]) -> List[str]:
- """返回 new 中与 old 不同的 key 列表。"""
- return [k for k, v in new.items() if k not in old or old[k] != v]
diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py
index ca9e8005..e81df019 100644
--- a/src/plugin_runtime/protocol/envelope.py
+++ b/src/plugin_runtime/protocol/envelope.py
@@ -134,9 +134,9 @@ class ComponentDeclaration(BaseModel):
name: str = Field(description="组件名称")
"""组件名称"""
component_type: str = Field(
- description="组件类型:action/command/tool/event_handler/workflow_handler/message_gateway"
+ description="组件类型:action/command/tool/event_handler/hook_handler/message_gateway"
)
- """组件类型:`action`/`command`/`tool`/`event_handler`/`workflow_handler`/`message_gateway`"""
+ """组件类型:`action`/`command`/`tool`/`event_handler`/`hook_handler`/`message_gateway`"""
plugin_id: str = Field(description="所属插件 ID")
"""所属插件 ID"""
metadata: Dict[str, Any] = Field(default_factory=dict, description="组件元数据")
From 6cc7e37b1e91eae0ed439b7e7c343ac8d32c4e00 Mon Sep 17 00:00:00 2001
From: UnCLAS-Prommer
Date: Thu, 19 Mar 2026 20:30:09 +0800
Subject: [PATCH 13/45] =?UTF-8?q?refactor:=20=E6=8F=90=E5=8F=96=E9=83=A8?=
=?UTF-8?q?=E5=88=86=E5=85=B1=E5=90=8C=E6=96=B9=E6=B3=95=EF=BC=8C=E9=A2=84?=
=?UTF-8?q?=E5=A4=87supervisor?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
src/plugin_runtime/host/event_dispatcher.py | 24 ++-
src/plugin_runtime/host/message_gateway.py | 223 +------------------
src/plugin_runtime/host/message_utils.py | 224 ++++++++++++++++++++
3 files changed, 247 insertions(+), 224 deletions(-)
create mode 100644 src/plugin_runtime/host/message_utils.py
diff --git a/src/plugin_runtime/host/event_dispatcher.py b/src/plugin_runtime/host/event_dispatcher.py
index 29ae530b..d252b6ee 100644
--- a/src/plugin_runtime/host/event_dispatcher.py
+++ b/src/plugin_runtime/host/event_dispatcher.py
@@ -14,10 +14,12 @@ import asyncio
from src.common.logger import get_logger
+from .message_utils import PluginMessageUtils, MessageDict
if TYPE_CHECKING:
from .supervisor import PluginRunnerSupervisor
from .component_registry import ComponentRegistry, EventHandlerEntry
+ from src.chat.message_receive.message import SessionMessage
logger = get_logger("plugin_runtime.host.event_dispatcher")
@@ -34,7 +36,7 @@ class EventResult:
handler_name: str
success: bool = field(default=True)
continue_processing: bool = field(default=True)
- modified_message: Optional[Dict[str, Any]] = field(default=None)
+ modified_message: Optional[MessageDict] = field(default=None)
custom_result: Any = field(default=None)
@@ -78,9 +80,9 @@ class EventDispatcher:
self,
event_type: str,
supervisor: "PluginRunnerSupervisor",
- message_dict: Optional[Dict[str, Any]] = None,
+ message: Optional["SessionMessage"] = None,
extra_args: Optional[Dict[str, Any]] = None,
- ) -> Tuple[bool, Optional[Dict[str, Any]]]:
+ ) -> Tuple[bool, Optional["SessionMessage"]]:
"""分发事件到所有对应 handler 的便捷方法。
内置了通过 PluginSupervisor.invoke_plugin 调用 plugin.emit_event 的逻辑,
@@ -93,14 +95,16 @@ class EventDispatcher:
extra_args: 额外参数
Returns:
- (should_continue, modified_message_dict) (bool, Dict[str, Any] | None): (是否继续后续执行, 可选的修改后的消息字典)
+ (should_continue, modified_message_dict) (bool, SessionMessage | None): (是否继续后续执行, 可选的修改后的消息)
"""
handler_entries = self._component_registry.get_event_handlers(event_type)
if not handler_entries:
return True, None
should_continue = True
- modified_message: Optional[Dict[str, Any]] = message_dict
+ modified_message: Optional[MessageDict] = (
+ PluginMessageUtils._session_message_to_dict(message) if message else None
+ )
intercept_handlers: List["EventHandlerEntry"] = []
non_blocking_handlers: List["EventHandlerEntry"] = []
@@ -136,8 +140,14 @@ class EventDispatcher:
task = asyncio.create_task(self._invoke_handler(supervisor, entry, args, event_type))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
-
- return should_continue, modified_message
+ try:
+ modified_message_obj = (
+ PluginMessageUtils._build_session_message_from_dict(modified_message) if modified_message else None # type: ignore
+ )
+ except Exception as e:
+ logger.error(f"构建修改后的 SessionMessage 失败: {e}")
+ modified_message_obj = None
+ return should_continue, modified_message_obj
async def _invoke_handler(
self,
diff --git a/src/plugin_runtime/host/message_gateway.py b/src/plugin_runtime/host/message_gateway.py
index e995ed01..43777286 100644
--- a/src/plugin_runtime/host/message_gateway.py
+++ b/src/plugin_runtime/host/message_gateway.py
@@ -3,56 +3,19 @@ Message Gateway 模块
适配器专用,用于将其他平台的消息转换为系统内部的消息格式,并将系统消息转换为其他平台的格式。
"""
-from datetime import datetime
-from typing import Dict, Any, TYPE_CHECKING, TypedDict, Optional, List
+from typing import Dict, Any, TYPE_CHECKING
from src.common.logger import get_logger
-from src.chat.message_receive.message import SessionMessage
-from src.common.data_models.mai_message_data_model import UserInfo, GroupInfo, MessageInfo
-from src.common.data_models.message_component_data_model import MessageSequence
+from .message_utils import PluginMessageUtils
if TYPE_CHECKING:
+ from src.chat.message_receive.message import SessionMessage
from .component_registry import ComponentRegistry
from .supervisor import PluginRunnerSupervisor
logger = get_logger("plugin_runtime.host.message_gateway")
-class UserInfoDict(TypedDict, total=False):
- user_id: str
- user_nickname: str
- user_cardname: Optional[str]
-
-
-class GroupInfoDict(TypedDict, total=False):
- group_id: str
- group_name: str
-
-
-class MessageInfoDict(TypedDict, total=False):
- user_info: UserInfoDict
- group_info: Optional[GroupInfoDict]
- additional_config: Dict[str, Any]
-
-
-class MessageDict(TypedDict, total=False):
- message_id: str
- timestamp: str
- platform: str
- message_info: MessageInfoDict
- raw_message: List[Dict[str, Any]]
- is_mentioned: bool
- is_at: bool
- is_emoji: bool
- is_picture: bool
- is_command: bool
- is_notify: bool
- session_id: str
- reply_to: Optional[str]
- processed_plain_text: Optional[str]
- display_message: Optional[str]
-
-
class MessageGateway:
def __init__(self, component_registry: "ComponentRegistry") -> None:
self._component_registry = component_registry
@@ -69,7 +32,7 @@ class MessageGateway:
"""
# 使用递归函数将外部消息字典转换为 SessionMessage
try:
- session_message = self._build_session_message_from_dict(external_message)
+ session_message = PluginMessageUtils._build_session_message_from_dict(external_message)
except Exception as e:
logger.error(f"转换外部消息失败: {e}")
return
@@ -79,7 +42,7 @@ class MessageGateway:
async def send_message_to_external(
self,
- internal_message: SessionMessage,
+ internal_message: "SessionMessage",
supervisor: "PluginRunnerSupervisor",
*,
enabled_only: bool = True,
@@ -96,7 +59,7 @@ class MessageGateway:
"""
try:
# 将 SessionMessage 转换为字典格式
- message_dict = self._session_message_to_dict(internal_message)
+ message_dict = PluginMessageUtils._session_message_to_dict(internal_message)
except Exception as e:
logger.error(f"转换内部消息失败:{e}")
return False
@@ -133,177 +96,3 @@ class MessageGateway:
except Exception as e:
logger.error(f"保存消息到数据库失败: {e}")
return True
-
- def _message_info_to_dict(self, message_info: MessageInfo) -> MessageInfoDict:
- """
- 将 MessageInfo 对象转换为字典格式
-
- Args:
- message_info: MessageInfo 对象
-
- Returns:
- 字典格式的消息信息
- """
- user_info_dict = UserInfoDict(
- user_id=message_info.user_info.user_id,
- user_nickname=message_info.user_info.user_nickname,
- user_cardname=message_info.user_info.user_cardname,
- )
-
- group_info_dict: Optional[GroupInfoDict] = None
- if message_info.group_info:
- group_info_dict = GroupInfoDict(
- group_id=message_info.group_info.group_id,
- group_name=message_info.group_info.group_name,
- )
-
- return MessageInfoDict(
- user_info=user_info_dict,
- group_info=group_info_dict,
- additional_config=message_info.additional_config,
- )
-
- def _session_message_to_dict(self, session_message: SessionMessage) -> MessageDict:
- """
- 将 SessionMessage 对象转换为字典格式(复用 MessageSequence.to_dict 方法)
-
- Args:
- session_message: SessionMessage 对象
-
- Returns:
- 字典格式的消息
- """
- # 转换基本信息
- message_dict = MessageDict(
- message_id=session_message.message_id,
- timestamp=str(session_message.timestamp.timestamp()), # 转换为时间戳字符串
- platform=session_message.platform,
- message_info=self._message_info_to_dict(session_message.message_info),
- raw_message=session_message.raw_message.to_dict(), # 复用 MessageSequence.to_dict()
- is_mentioned=session_message.is_mentioned,
- is_at=session_message.is_at,
- is_emoji=session_message.is_emoji,
- is_picture=session_message.is_picture,
- is_command=session_message.is_command,
- is_notify=session_message.is_notify,
- session_id=session_message.session_id,
- )
-
- # 添加可选字段
- if session_message.reply_to is not None:
- message_dict["reply_to"] = session_message.reply_to
- if session_message.processed_plain_text is not None:
- message_dict["processed_plain_text"] = session_message.processed_plain_text
- if session_message.display_message is not None:
- message_dict["display_message"] = session_message.display_message
-
- return message_dict
-
- def _build_message_info_from_dict(self, message_info_dict: Dict[str, Any]) -> MessageInfo:
- """
- 从字典构建 MessageInfo 对象
-
- Args:
- message_info_dict: 包含消息信息的字典
-
- Returns:
- MessageInfo 对象
- """
- # 构建用户信息
- user_info_dict = message_info_dict.get("user_info")
- if not user_info_dict or not isinstance(user_info_dict, dict):
- raise ValueError("消息字典中 'user_info' 字段无效")
- user_id = user_info_dict.get("user_id")
- user_nickname = user_info_dict.get("user_nickname")
- user_cardname = user_info_dict.get("user_cardname")
- if not isinstance(user_id, str) or not isinstance(user_nickname, str) or not user_id or not user_nickname:
- raise ValueError("消息字典中 'user_info' 字段缺少有效的 'user_id' 或 'user_nickname'")
- user_cardname = str(user_cardname) if user_cardname is not None else None
- user_info = UserInfo(user_id=user_id, user_nickname=user_nickname, user_cardname=user_cardname)
-
- # 构建群信息
- if group_info_dict := message_info_dict.get("group_info"):
- group_id = group_info_dict.get("group_id")
- group_name = group_info_dict.get("group_name")
- if not isinstance(group_id, str) or not isinstance(group_name, str) or not group_id or not group_name:
- raise ValueError("消息字典中 'group_info' 字段缺少有效的 'group_id' 或 'group_name'")
- group_info = GroupInfo(group_id=group_id, group_name=group_name)
- else:
- group_info = None
-
- # 获取额外配置
- additional_config: Dict[str, Any] = message_info_dict.get("additional_config", {})
-
- return MessageInfo(user_info=user_info, group_info=group_info, additional_config=additional_config)
-
- def _build_session_message_from_dict(self, message_dict: Dict[str, Any]) -> SessionMessage:
- """
- 从字典构建 SessionMessage 对象(递归处理消息组件)
-
- Args:
- message_dict: 包含消息完整信息的字典
-
- Returns:
- SessionMessage 对象
- """
- # 提取基本信息
- message_id = message_dict["message_id"]
- timestamp_str: str = message_dict.get("timestamp", "")
- platform = message_dict["platform"]
- if not isinstance(message_id, str) or not message_id:
- raise ValueError("消息字典中缺少有效的 'message_id' 字段")
- if not isinstance(platform, str) or not platform:
- raise ValueError("消息字典中缺少有效的 'platform' 字段")
-
- # 解析时间戳
- try:
- timestamp_float = float(timestamp_str)
- timestamp = datetime.fromtimestamp(timestamp_float)
- except (ValueError, TypeError):
- timestamp = datetime.now() # 如果解析失败,使用当前时间
-
- # 创建 SessionMessage 实例
- session_message = SessionMessage(message_id=message_id, timestamp=timestamp, platform=platform)
-
- # 构建消息信息
- session_message.message_info = self._build_message_info_from_dict(message_dict["message_info"])
-
- # 构建原始消息组件序列(复用 MessageSequence.from_dict 方法)
- raw_message_data = message_dict["raw_message"]
- if isinstance(raw_message_data, list):
- session_message.raw_message = MessageSequence.from_dict(raw_message_data)
- else:
- raise ValueError("消息字典中 'raw_message' 字段必须是一个列表")
-
- # 设置其他可选属性
- session_message.is_mentioned = message_dict.get("is_mentioned", False)
- if not isinstance(session_message.is_mentioned, bool):
- session_message.is_mentioned = False
- session_message.is_at = message_dict.get("is_at", False)
- if not isinstance(session_message.is_at, bool):
- session_message.is_at = False
- session_message.is_emoji = message_dict.get("is_emoji", False)
- if not isinstance(session_message.is_emoji, bool):
- session_message.is_emoji = False
- session_message.is_picture = message_dict.get("is_picture", False)
- if not isinstance(session_message.is_picture, bool):
- session_message.is_picture = False
- session_message.is_command = message_dict.get("is_command", False)
- if not isinstance(session_message.is_command, bool):
- session_message.is_command = False
- session_message.is_notify = message_dict.get("is_notify", False)
- if not isinstance(session_message.is_notify, bool):
- session_message.is_notify = False
- session_message.reply_to = message_dict.get("reply_to")
- if session_message.reply_to is not None and not isinstance(session_message.reply_to, str):
- session_message.reply_to = None
- session_message.processed_plain_text = message_dict.get("processed_plain_text")
- if session_message.processed_plain_text is not None and not isinstance(
- session_message.processed_plain_text, str
- ):
- session_message.processed_plain_text = None
- session_message.display_message = message_dict.get("display_message")
- if session_message.display_message is not None and not isinstance(session_message.display_message, str):
- session_message.display_message = None
-
- return session_message
diff --git a/src/plugin_runtime/host/message_utils.py b/src/plugin_runtime/host/message_utils.py
new file mode 100644
index 00000000..428e3c48
--- /dev/null
+++ b/src/plugin_runtime/host/message_utils.py
@@ -0,0 +1,224 @@
+from datetime import datetime
+from typing import Dict, Any, TypedDict, Optional, List
+
+from src.common.logger import get_logger
+from src.chat.message_receive.message import SessionMessage
+from src.common.data_models.mai_message_data_model import UserInfo, GroupInfo, MessageInfo
+from src.common.data_models.message_component_data_model import MessageSequence
+
+logger = get_logger("plugin_runtime.host.message_utils")
+
+
+class UserInfoDict(TypedDict, total=False):
+ user_id: str
+ user_nickname: str
+ user_cardname: Optional[str]
+
+
+class GroupInfoDict(TypedDict, total=False):
+ group_id: str
+ group_name: str
+
+
+class MessageInfoDict(TypedDict, total=False):
+ user_info: UserInfoDict
+ group_info: Optional[GroupInfoDict]
+ additional_config: Dict[str, Any]
+
+
+class MessageDict(TypedDict, total=False):
+ message_id: str
+ timestamp: str
+ platform: str
+ message_info: MessageInfoDict
+ raw_message: List[Dict[str, Any]]
+ is_mentioned: bool
+ is_at: bool
+ is_emoji: bool
+ is_picture: bool
+ is_command: bool
+ is_notify: bool
+ session_id: str
+ reply_to: Optional[str]
+ processed_plain_text: Optional[str]
+ display_message: Optional[str]
+
+
+class PluginMessageUtils:
+ @staticmethod
+ def _message_info_to_dict(message_info: MessageInfo) -> MessageInfoDict:
+ """
+ 将 MessageInfo 对象转换为字典格式
+
+ Args:
+ message_info: MessageInfo 对象
+
+ Returns:
+ 字典格式的消息信息
+ """
+ user_info_dict = UserInfoDict(
+ user_id=message_info.user_info.user_id,
+ user_nickname=message_info.user_info.user_nickname,
+ user_cardname=message_info.user_info.user_cardname,
+ )
+
+ group_info_dict: Optional[GroupInfoDict] = None
+ if message_info.group_info:
+ group_info_dict = GroupInfoDict(
+ group_id=message_info.group_info.group_id,
+ group_name=message_info.group_info.group_name,
+ )
+
+ return MessageInfoDict(
+ user_info=user_info_dict,
+ group_info=group_info_dict,
+ additional_config=message_info.additional_config,
+ )
+
+ @staticmethod
+ def _session_message_to_dict(session_message: SessionMessage) -> MessageDict:
+ """
+ 将 SessionMessage 对象转换为字典格式(复用 MessageSequence.to_dict 方法)
+
+ Args:
+ session_message: SessionMessage 对象
+
+ Returns:
+ 字典格式的消息
+ """
+ # 转换基本信息
+ message_dict = MessageDict(
+ message_id=session_message.message_id,
+ timestamp=str(session_message.timestamp.timestamp()), # 转换为时间戳字符串
+ platform=session_message.platform,
+ message_info=PluginMessageUtils._message_info_to_dict(session_message.message_info),
+ raw_message=session_message.raw_message.to_dict(), # 复用 MessageSequence.to_dict()
+ is_mentioned=session_message.is_mentioned,
+ is_at=session_message.is_at,
+ is_emoji=session_message.is_emoji,
+ is_picture=session_message.is_picture,
+ is_command=session_message.is_command,
+ is_notify=session_message.is_notify,
+ session_id=session_message.session_id,
+ )
+
+ # 添加可选字段
+ if session_message.reply_to is not None:
+ message_dict["reply_to"] = session_message.reply_to
+ if session_message.processed_plain_text is not None:
+ message_dict["processed_plain_text"] = session_message.processed_plain_text
+ if session_message.display_message is not None:
+ message_dict["display_message"] = session_message.display_message
+
+ return message_dict
+
+ @staticmethod
+ def _build_message_info_from_dict(message_info_dict: Dict[str, Any]) -> MessageInfo:
+ """
+ 从字典构建 MessageInfo 对象
+
+ Args:
+ message_info_dict: 包含消息信息的字典
+
+ Returns:
+ MessageInfo 对象
+ """
+ # 构建用户信息
+ user_info_dict = message_info_dict.get("user_info")
+ if not user_info_dict or not isinstance(user_info_dict, dict):
+ raise ValueError("消息字典中 'user_info' 字段无效")
+ user_id = user_info_dict.get("user_id")
+ user_nickname = user_info_dict.get("user_nickname")
+ user_cardname = user_info_dict.get("user_cardname")
+ if not isinstance(user_id, str) or not isinstance(user_nickname, str) or not user_id or not user_nickname:
+ raise ValueError("消息字典中 'user_info' 字段缺少有效的 'user_id' 或 'user_nickname'")
+ user_cardname = str(user_cardname) if user_cardname is not None else None
+ user_info = UserInfo(user_id=user_id, user_nickname=user_nickname, user_cardname=user_cardname)
+
+ # 构建群信息
+ if group_info_dict := message_info_dict.get("group_info"):
+ group_id = group_info_dict.get("group_id")
+ group_name = group_info_dict.get("group_name")
+ if not isinstance(group_id, str) or not isinstance(group_name, str) or not group_id or not group_name:
+ raise ValueError("消息字典中 'group_info' 字段缺少有效的 'group_id' 或 'group_name'")
+ group_info = GroupInfo(group_id=group_id, group_name=group_name)
+ else:
+ group_info = None
+
+ # 获取额外配置
+ additional_config: Dict[str, Any] = message_info_dict.get("additional_config", {})
+
+ return MessageInfo(user_info=user_info, group_info=group_info, additional_config=additional_config)
+
+ @staticmethod
+ def _build_session_message_from_dict(message_dict: Dict[str, Any]) -> SessionMessage:
+ """
+ 从字典构建 SessionMessage 对象(递归处理消息组件)
+
+ Args:
+ message_dict: 包含消息完整信息的字典
+
+ Returns:
+ SessionMessage 对象
+ """
+ # 提取基本信息
+ message_id = message_dict["message_id"]
+ timestamp_str: str = message_dict.get("timestamp", "")
+ platform = message_dict["platform"]
+ if not isinstance(message_id, str) or not message_id:
+ raise ValueError("消息字典中缺少有效的 'message_id' 字段")
+ if not isinstance(platform, str) or not platform:
+ raise ValueError("消息字典中缺少有效的 'platform' 字段")
+
+ # 解析时间戳
+ try:
+ timestamp_float = float(timestamp_str)
+ timestamp = datetime.fromtimestamp(timestamp_float)
+ except (ValueError, TypeError):
+ timestamp = datetime.now() # 如果解析失败,使用当前时间
+
+ # 创建 SessionMessage 实例
+ session_message = SessionMessage(message_id=message_id, timestamp=timestamp, platform=platform)
+
+ # 构建消息信息
+ session_message.message_info = PluginMessageUtils._build_message_info_from_dict(message_dict["message_info"])
+
+ # 构建原始消息组件序列(复用 MessageSequence.from_dict 方法)
+ raw_message_data = message_dict["raw_message"]
+ if isinstance(raw_message_data, list):
+ session_message.raw_message = MessageSequence.from_dict(raw_message_data)
+ else:
+ raise ValueError("消息字典中 'raw_message' 字段必须是一个列表")
+
+ # 设置其他可选属性
+ session_message.is_mentioned = message_dict.get("is_mentioned", False)
+ if not isinstance(session_message.is_mentioned, bool):
+ session_message.is_mentioned = False
+ session_message.is_at = message_dict.get("is_at", False)
+ if not isinstance(session_message.is_at, bool):
+ session_message.is_at = False
+ session_message.is_emoji = message_dict.get("is_emoji", False)
+ if not isinstance(session_message.is_emoji, bool):
+ session_message.is_emoji = False
+ session_message.is_picture = message_dict.get("is_picture", False)
+ if not isinstance(session_message.is_picture, bool):
+ session_message.is_picture = False
+ session_message.is_command = message_dict.get("is_command", False)
+ if not isinstance(session_message.is_command, bool):
+ session_message.is_command = False
+ session_message.is_notify = message_dict.get("is_notify", False)
+ if not isinstance(session_message.is_notify, bool):
+ session_message.is_notify = False
+ session_message.reply_to = message_dict.get("reply_to")
+ if session_message.reply_to is not None and not isinstance(session_message.reply_to, str):
+ session_message.reply_to = None
+ session_message.processed_plain_text = message_dict.get("processed_plain_text")
+ if session_message.processed_plain_text is not None and not isinstance(
+ session_message.processed_plain_text, str
+ ):
+ session_message.processed_plain_text = None
+ session_message.display_message = message_dict.get("display_message")
+ if session_message.display_message is not None and not isinstance(session_message.display_message, str):
+ session_message.display_message = None
+
+ return session_message
From 3d22657707b757e73cb77bf52ee0e5ea924f7832 Mon Sep 17 00:00:00 2001
From: UnCLAS-Prommer
Date: Fri, 20 Mar 2026 21:39:19 +0800
Subject: [PATCH 14/45] =?UTF-8?q?refactor:=20supervisor=E9=83=A8=E5=88=86?=
=?UTF-8?q?=E6=96=B9=E6=B3=95=E9=87=8D=E5=86=99?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
src/plugin_runtime/host/supervisor.py | 681 ++++----------------------
1 file changed, 87 insertions(+), 594 deletions(-)
diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py
index bfa00cbf..82a5970b 100644
--- a/src/plugin_runtime/host/supervisor.py
+++ b/src/plugin_runtime/host/supervisor.py
@@ -1,98 +1,41 @@
-"""Supervisor - 插件生命周期管理
-
-负责:
-1. 拉起 Runner 子进程
-2. 健康检查 + 崩溃自动重启
-3. 代码热重载(generation 切换)
-4. 优雅关停
-"""
-
-from typing import Any, Dict, List, Optional, Tuple
+from pathlib import Path
+from typing import Optional, List, Dict, Any, Tuple, TYPE_CHECKING
import asyncio
-import contextlib
-import logging as stdlib_logging
-import os
-import sys
-from pathlib import Path
+
from src.common.logger import get_logger
-from src.config.config import MMC_VERSION, global_config
-from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN
-from src.plugin_runtime.host.capability_service import CapabilityService
-from src.plugin_runtime.host.component_registry import ComponentRegistry
-from src.plugin_runtime.host.event_dispatcher import EventDispatcher
-from src.plugin_runtime.host.policy_engine import PolicyEngine
-from src.plugin_runtime.host.rpc_server import RPCServer
-from src.plugin_runtime.host.workflow_executor import WorkflowExecutor, WorkflowContext, WorkflowResult
+from src.config.config import global_config
+from src.plugin_runtime.transport.factory import create_transport_server
from src.plugin_runtime.protocol.envelope import (
BootstrapPluginPayload,
ConfigUpdatedPayload,
Envelope,
HealthPayload,
LogBatchPayload,
- RegisterComponentsPayload,
+ RegisterPluginPayload,
RunnerReadyPayload,
ShutdownPayload,
)
-from src.plugin_runtime.protocol.errors import ErrorCode, RPCError
-from src.plugin_runtime.transport.factory import create_transport_server
-logger = get_logger("plugin_runtime.host.supervisor")
+from .authorization import AuthorizationManager
+from .capability_service import CapabilityService
+from .rpc_server import RPCServer
+from .logger_bridge import RunnerLogBridge
+from .component_registry import ComponentRegistry
+from .event_dispatcher import EventDispatcher
+from .hook_dispatcher import HookDispatcher
+from .message_gateway import MessageGateway
+from .message_utils import PluginMessageUtils
+
+if TYPE_CHECKING:
+ from src.chat.message_receive.message import SessionMessage
+
+logger = get_logger("plugin_runtime.host.runner_manager")
-# ─── 日志桥 ──────────────────────────────────────────────────────
-
-
-class RunnerLogBridge:
- """将 Runner 进程上报的批量日志重放到主进程的 Logger 中。
-
- Runner 通过 ``runner.log_batch`` IPC 事件批量到达。
- 每条 LogEntry 被重建为一个真实的 :class:`logging.LogRecord` 并直接
- 调用 ``logging.getLogger(entry.logger_name).handle(record)``,
- 从而接入主进程已配置好的 structlog Handler 链。
- """
-
- async def handle_log_batch(self, envelope: Envelope) -> Envelope:
- """IPC 事件处理器:解析批量日志并重放到主进程 Logger。
-
- Args:
- envelope: 方法名为 ``runner.log_batch`` 的 IPC 事件信封。
-
- Returns:
- 空响应信封(事件模式下将被忽略)。
- """
- try:
- batch = LogBatchPayload.model_validate(envelope.payload)
- except Exception as exc:
- return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
-
- for entry in batch.entries:
- # 重建一个与原始日志尽量相符的 LogRecord
- record = stdlib_logging.LogRecord(
- name=entry.logger_name,
- level=entry.level,
- pathname="",
- lineno=0,
- msg=entry.message,
- args=(),
- exc_info=None,
- )
- record.created = entry.timestamp_ms / 1000.0
- record.msecs = entry.timestamp_ms % 1000
- if entry.exception_text:
- record.exc_text = entry.exception_text
-
- stdlib_logging.getLogger(entry.logger_name).handle(record)
-
- return envelope.make_response(payload={"accepted": True, "count": len(batch.entries)})
-
-
-class PluginSupervisor:
- """插件 Supervisor
-
- Host 端的核心管理器,负责整个插件 Runner 进程的生命周期。
- """
+class PluginRunnerSupervisor:
+ """插件的Runner管理器,负责管理Runner的生命周期"""
def __init__(
self,
@@ -103,45 +46,34 @@ class PluginSupervisor:
runner_spawn_timeout_sec: Optional[float] = None,
):
_cfg = global_config.plugin_runtime
- self._plugin_dirs = plugin_dirs or []
- self._health_interval = (
- health_check_interval_sec if health_check_interval_sec is not None else _cfg.health_check_interval_sec
- )
- self._runner_spawn_timeout = (
- runner_spawn_timeout_sec if runner_spawn_timeout_sec is not None else _cfg.runner_spawn_timeout_sec
- )
+ self._plugin_dirs: List[Path] = plugin_dirs or []
+ self._health_interval = health_check_interval_sec or _cfg.health_check_interval_sec or 30.0
+ self._runner_spawn_timeout = runner_spawn_timeout_sec or _cfg.runner_spawn_timeout_sec or 30.0
# 基础设施
self._transport = create_transport_server(socket_path=socket_path)
- self._policy = PolicyEngine()
- self._capability_service = CapabilityService(self._policy)
+ self._authorization = AuthorizationManager()
+ self._capability_service = CapabilityService(self._authorization)
self._component_registry = ComponentRegistry()
self._event_dispatcher = EventDispatcher(self._component_registry)
- self._workflow_executor = WorkflowExecutor(self._component_registry)
+ self._hook_dispatcher = HookDispatcher(self._component_registry)
+ self._message_gateway = MessageGateway(self._component_registry)
- # 编解码
+ # 编解码和服务器
from src.plugin_runtime.protocol.codec import MsgPackCodec
codec = MsgPackCodec()
-
- self._rpc_server = RPCServer(
- transport=self._transport,
- codec=codec,
- )
+ self._rpc_server = RPCServer(transport=self._transport, codec=codec)
# Runner 子进程
self._runner_process: Optional[asyncio.subprocess.Process] = None
- self._runner_generation: int = 0
- self._max_restart_attempts: int = (
- max_restart_attempts if max_restart_attempts is not None else _cfg.max_restart_attempts
- )
+ self._max_restart_attempts: int = max_restart_attempts or _cfg.max_restart_attempts or 3
self._restart_count: int = 0
# 已注册的插件组件信息
- self._registered_plugins: Dict[str, RegisterComponentsPayload] = {}
- self._staged_registered_plugins: Dict[str, RegisterComponentsPayload] = {}
- self._runner_ready_events: Dict[int, asyncio.Event] = {}
- self._runner_ready_payloads: Dict[int, RunnerReadyPayload] = {}
+ self._registered_plugins: Dict[str, RegisterPluginPayload] = {}
+ self._runner_ready_events: asyncio.Event = asyncio.Event()
+ self._runner_ready_payloads: RunnerReadyPayload = RunnerReadyPayload()
# 后台任务
self._health_task: Optional[asyncio.Task] = None
@@ -153,11 +85,11 @@ class PluginSupervisor:
self._log_bridge: RunnerLogBridge = RunnerLogBridge()
# 注册内部 RPC 方法
- self._register_internal_methods()
+ self._register_internal_methods() # TODO: 完成内部方法注册
@property
- def policy_engine(self) -> PolicyEngine:
- return self._policy
+ def authorization_manager(self) -> AuthorizationManager:
+ return self._authorization
@property
def capability_service(self) -> CapabilityService:
@@ -172,8 +104,12 @@ class PluginSupervisor:
return self._event_dispatcher
@property
- def workflow_executor(self) -> WorkflowExecutor:
- return self._workflow_executor
+ def hook_dispatcher(self) -> HookDispatcher:
+ return self._hook_dispatcher
+
+ @property
+ def message_gateway(self) -> MessageGateway:
+ return self._message_gateway
@property
def rpc_server(self) -> RPCServer:
@@ -182,64 +118,26 @@ class PluginSupervisor:
async def dispatch_event(
self,
event_type: str,
- message: Optional[Dict[str, Any]] = None,
+ message: Optional["SessionMessage"] = None,
extra_args: Optional[Dict[str, Any]] = None,
- ) -> Tuple[bool, Optional[Dict[str, Any]]]:
+ ) -> Tuple[bool, Optional["SessionMessage"]]:
"""分发事件到所有对应 handler 的快捷方法。"""
+ return await self._event_dispatcher.dispatch_event(event_type, self, message, extra_args)
- async def _invoke(plugin_id: str, component_name: str, args: Dict[str, Any]) -> Dict[str, Any]:
- resp = await self.invoke_plugin(
- method="plugin.emit_event",
- plugin_id=plugin_id,
- component_name=component_name,
- args=args,
- )
- return resp.payload
+ async def dispatch_hook(self, stage: str, **kwargs):
+ """分发Hook事件到所有对应 handler 的快捷方法。"""
+ return await self._hook_dispatcher.hook_dispatch(stage, self, **kwargs)
- return await self._event_dispatcher.dispatch_event(
- event_type=event_type,
- invoke_fn=_invoke,
- message=message,
- extra_args=extra_args,
- )
-
- async def execute_workflow(
+ async def send_message_to_external(
self,
- message: Optional[Dict[str, Any]] = None,
- stream_id: Optional[str] = None,
- context: Optional[WorkflowContext] = None,
- ) -> Tuple[WorkflowResult, Optional[Dict[str, Any]], WorkflowContext]:
- """执行 Workflow Pipeline 的快捷方法。"""
-
- async def _invoke(plugin_id: str, component_name: str, args: Dict[str, Any]) -> Dict[str, Any]:
- resp = await self.invoke_plugin(
- method="plugin.invoke_workflow_step",
- plugin_id=plugin_id,
- component_name=component_name,
- args=args,
- )
- payload = resp.payload
- if payload.get("success"):
- result = payload.get("result")
- return result if isinstance(result, dict) else {}
- raise RuntimeError(payload.get("result", "workflow step invoke failed"))
-
- async def _command_invoke(plugin_id: str, component_name: str, args: Dict[str, Any]) -> Dict[str, Any]:
- """命令走 plugin.invoke_command,保留原始返回值结构。"""
- resp = await self.invoke_plugin(
- method="plugin.invoke_command",
- plugin_id=plugin_id,
- component_name=component_name,
- args=args,
- )
- return resp.payload
-
- return await self._workflow_executor.execute(
- invoke_fn=_invoke,
- message=message,
- stream_id=stream_id,
- context=context,
- command_invoke_fn=_command_invoke,
+ internal_message: "SessionMessage",
+ *,
+ enabled_only: bool = True,
+ save_to_db: bool = True,
+ ) -> bool:
+ """发送系统内部消息到外部平台的快捷方法。"""
+ return await self._message_gateway.send_message_to_external(
+ internal_message, self, enabled_only=enabled_only, save_to_db=save_to_db
)
async def start(self) -> None:
@@ -253,17 +151,13 @@ class PluginSupervisor:
# 启动 RPC Server
await self._rpc_server.start()
-
- # 计算预期 generation(与 reload_plugins 保持一致)
- expected_generation = self._rpc_server.runner_generation + 1
-
# 拉起 Runner 进程
await self._spawn_runner()
# 等待 Runner 完成连接和初始化,避免 start() 返回时 Runner 尚未就绪
try:
- await self._wait_for_runner_generation(expected_generation, timeout_sec=self._runner_spawn_timeout)
- await self._wait_for_runner_ready(expected_generation, timeout_sec=self._runner_spawn_timeout)
+ await self._wait_for_runner_connection(timeout_sec=self._runner_spawn_timeout)
+ await self._wait_for_runner_ready(timeout_sec=self._runner_spawn_timeout)
except TimeoutError:
if not self._rpc_server.is_connected:
logger.warning(f"Runner 未在 {self._runner_spawn_timeout}s 内完成连接,后续操作可能失败")
@@ -279,6 +173,10 @@ class PluginSupervisor:
"""停止 Supervisor"""
self._running = False
+ # 停止组件
+ await self._event_dispatcher.stop()
+ await self._hook_dispatcher.stop()
+
# 停止健康检查
if self._health_task:
self._health_task.cancel()
@@ -305,439 +203,34 @@ class PluginSupervisor:
由主进程业务逻辑调用,通过 RPC 转发给 Runner。
"""
return await self._rpc_server.send_request(
- method=method,
- plugin_id=plugin_id,
- payload={
- "component_name": component_name,
- "args": args or {},
- },
- timeout_ms=timeout_ms,
+ method,
+ plugin_id,
+ {"component_name": component_name, "args": args or {}},
+ timeout_ms,
)
- async def reload_plugins(self, reason: str = "manual") -> bool:
- """热重载所有插件(进程级 generation 切换)
+ async def reload_plugin(self, plugin_id: str, reason: str = "manual") -> bool:
+ raise NotImplementedError("等待SDK完成") # TODO: 完成对应的调用和请求逻辑
- 1. 拉起新 Runner
- 2. 等待新 Runner 完成注册和健康检查
- 3. 关停旧 Runner
- """
- logger.info(f"开始热重载插件,原因: {reason}")
+ async def _wait_for_runner_connection(self, timeout_sec: float) -> None:
+ """等待 Runner 连接上 RPC Server"""
- # 保存旧进程引用和旧 session token(回滚时需要恢复)
- old_process = self._runner_process
- old_registered_plugins = dict(self._registered_plugins)
- old_session_token = self._rpc_server.session_token
- expected_generation = self._rpc_server.runner_generation + 1
+ async def wait_for_connection():
+ while self._running and not self._rpc_server.is_connected:
+ await asyncio.sleep(0.1)
- # 允许新 Runner 以 staged 方式接入,验证通过后再切换活跃连接
- self._rpc_server.begin_staged_takeover()
- self._staged_registered_plugins.clear()
-
- # 重新生成 session token,防止被终止的旧 Runner 重连
- self._rpc_server.reset_session_token()
-
- # 注意:不在此处调用 _clear_runtime_state()。
- # 旧组件在新 Runner 完成注册前继续提供服务,避免热重载窗口期内
- # dispatch_event / execute_workflow 找不到任何组件导致消息静默丢失。
- # ComponentRegistry.register_component 对同名组件是覆盖式写入,安全。
-
- # 拉起新 Runner
try:
- await self._spawn_runner()
- await self._wait_for_runner_generation(
- expected_generation,
- timeout_sec=self._runner_spawn_timeout,
- allow_staged=True,
- )
- await self._wait_for_runner_ready(expected_generation, timeout_sec=self._runner_spawn_timeout)
- resp = await self._rpc_server.send_request(
- "plugin.health",
- timeout_ms=5000,
- target_generation=expected_generation,
- )
- health = HealthPayload.model_validate(resp.payload)
- if not health.healthy:
- raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "新 Runner 健康检查失败")
- await self._rpc_server.commit_staged_takeover()
- except Exception as e:
- logger.error(f"新 Runner 健康检查失败: {e},回滚")
- await self._terminate_process(self._runner_process, old_process)
- await self._rpc_server.rollback_staged_takeover()
- self._runner_process = old_process
- self._rpc_server.restore_session_token(old_session_token)
- self._staged_registered_plugins.clear()
- self._registered_plugins = dict(old_registered_plugins)
- self._rebuild_runtime_state()
- return False
+ await asyncio.wait_for(wait_for_connection(), timeout=timeout_sec)
+ logger.info("Runner 已连接到 RPC Server")
+ except asyncio.TimeoutError as e:
+ raise TimeoutError(f"等待 Runner 连接超时({timeout_sec}s)") from e
- self._runner_generation = self._rpc_server.runner_generation
- self._registered_plugins = dict(self._staged_registered_plugins)
- self._staged_registered_plugins.clear()
- self._rebuild_runtime_state()
+ async def _wait_for_runner_ready(self, timeout_sec: float = 30.0) -> RunnerReadyPayload:
+ """等待 Runner 完成初始化并上报就绪"""
- # 关停旧 Runner
- if old_process and old_process.returncode is None:
- try:
- old_process.terminate()
- await asyncio.wait_for(old_process.wait(), timeout=10.0)
- except asyncio.TimeoutError:
- old_process.kill()
-
- logger.info("热重载完成")
- return True
-
- async def notify_plugin_config_updated(
- self,
- plugin_id: str,
- config_data: Dict[str, Any],
- config_version: str = "",
- ) -> bool:
- """通知指定插件其配置已更新。"""
- if plugin_id not in self._registered_plugins:
- return False
-
- payload = ConfigUpdatedPayload(
- plugin_id=plugin_id,
- config_version=config_version,
- config_data=config_data,
- )
- await self._rpc_server.send_request(
- "plugin.config_updated",
- plugin_id=plugin_id,
- payload=payload.model_dump(),
- timeout_ms=5000,
- )
- return True
-
- # ─── 内部方法 ──────────────────────────────────────────────
-
- def _register_internal_methods(self) -> None:
- """注册 Host 端的 RPC 方法处理器"""
- # Runner -> Host 的能力调用统一走 capability_service
- self._rpc_server.register_method("cap.request", self._capability_service.handle_capability_request)
- self._rpc_server.register_method("plugin.bootstrap", self._handle_bootstrap_plugin)
- # 插件注册
- self._rpc_server.register_method("plugin.register_components", self._handle_register_components)
- self._rpc_server.register_method("runner.ready", self._handle_runner_ready)
- # Runner 日志批量上报
- self._rpc_server.register_method("runner.log_batch", self._log_bridge.handle_log_batch)
-
- async def _handle_bootstrap_plugin(self, envelope: Envelope) -> Envelope:
- """处理插件 bootstrap 请求,仅同步能力令牌。"""
try:
- bootstrap = BootstrapPluginPayload.model_validate(envelope.payload)
- except Exception as e:
- return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e))
-
- active_generation = self._rpc_server.runner_generation
- staged_generation = self._rpc_server.staged_generation
- if envelope.generation not in {active_generation, staged_generation}:
- return envelope.make_error_response(
- ErrorCode.E_GENERATION_MISMATCH.value,
- f"插件 bootstrap generation 过期: {envelope.generation} 不在已知代际中",
- )
-
- if bootstrap.capabilities_required:
- self._policy.register_plugin(
- plugin_id=bootstrap.plugin_id,
- generation=envelope.generation,
- capabilities=bootstrap.capabilities_required,
- )
- else:
- self._policy.revoke_plugin(bootstrap.plugin_id, generation=envelope.generation)
-
- return envelope.make_response(payload={"accepted": True})
-
- async def _handle_register_components(self, envelope: Envelope) -> Envelope:
- """处理插件组件注册请求"""
- try:
- reg = RegisterComponentsPayload.model_validate(envelope.payload)
- except Exception as e:
- return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e))
-
- active_generation = self._rpc_server.runner_generation
- staged_generation = self._rpc_server.staged_generation
- if envelope.generation not in {active_generation, staged_generation}:
- return envelope.make_error_response(
- ErrorCode.E_GENERATION_MISMATCH.value,
- f"组件注册 generation 过期: {envelope.generation} 不在已知代际中",
- )
-
- if envelope.generation == staged_generation and staged_generation != 0:
- self._staged_registered_plugins[reg.plugin_id] = reg
- logger.info(
- f"插件 {reg.plugin_id} v{reg.plugin_version} staged 注册成功,"
- f"组件数: {len(reg.components)}, 能力需求: {reg.capabilities_required}"
- )
- return envelope.make_response(payload={"accepted": True, "staged": True})
-
- self._registered_plugins[reg.plugin_id] = reg
-
- # 在策略引擎中注册插件
- self._policy.register_plugin(
- plugin_id=reg.plugin_id,
- generation=envelope.generation,
- capabilities=reg.capabilities_required or [],
- )
-
- # 同 generation 下重新注册时,以本次声明为准,避免残留幽灵组件
- self._component_registry.remove_components_by_plugin(reg.plugin_id)
- self._component_registry.register_plugin_components(
- plugin_id=reg.plugin_id,
- components=[c.model_dump() for c in reg.components],
- )
-
- stats = self._component_registry.get_stats()
- logger.info(
- f"插件 {reg.plugin_id} v{reg.plugin_version} 注册成功,"
- f"组件数: {len(reg.components)}, 能力需求: {reg.capabilities_required},"
- f"注册表总计: {stats}"
- )
-
- return envelope.make_response(payload={"accepted": True})
-
- async def _handle_runner_ready(self, envelope: Envelope) -> Envelope:
- """处理 Runner 初始化完成信号。"""
- try:
- ready = RunnerReadyPayload.model_validate(envelope.payload)
- except Exception as e:
- return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e))
-
- event = self._runner_ready_events.setdefault(envelope.generation, asyncio.Event())
- self._runner_ready_payloads[envelope.generation] = ready
- event.set()
- logger.info(
- f"Runner generation={envelope.generation} 已就绪,成功插件数: {len(ready.loaded_plugins)},"
- f"失败插件数: {len(ready.failed_plugins)}"
- )
- return envelope.make_response(payload={"accepted": True})
-
- async def _spawn_runner(self) -> None:
- """拉起 Runner 子进程"""
- runner_module = "src.plugin_runtime.runner.runner_main"
- address = self._transport.get_address()
- token = self._rpc_server.session_token
-
- env = os.environ.copy()
- env[ENV_IPC_ADDRESS] = address
- env[ENV_SESSION_TOKEN] = token
- env[ENV_PLUGIN_DIRS] = os.pathsep.join(str(p) for p in self._plugin_dirs)
- env[ENV_HOST_VERSION] = MMC_VERSION
-
- self._runner_process = await asyncio.create_subprocess_exec(
- sys.executable,
- "-m",
- runner_module,
- env=env,
- # stdout 不捕获:Runner 的日志均通过 IPC 传㛹(RunnerIPCLogHandler)
- stdout=None,
- # stderr 捕获为 PIPE,仅用于 IPC 建立前的进程级致命错误输出
- stderr=asyncio.subprocess.PIPE,
- )
-
- self._attach_stderr_drain(self._runner_process)
- self._runner_generation = self._rpc_server.runner_generation
- logger.info(f"Runner 子进程已启动: pid={self._runner_process.pid}, generation={self._runner_generation}")
-
- async def _shutdown_runner(self) -> None:
- """优雅关停 Runner"""
- if not self._runner_process or self._runner_process.returncode is not None:
- return
-
- # 发送 prepare_shutdown
- try:
- if self._rpc_server.is_connected:
- shutdown_payload = ShutdownPayload(reason="host_shutdown", drain_timeout_ms=5000)
- await self._rpc_server.send_request(
- "plugin.prepare_shutdown",
- payload=shutdown_payload.model_dump(),
- timeout_ms=5000,
- )
- await self._rpc_server.send_request(
- "plugin.shutdown",
- payload=shutdown_payload.model_dump(),
- timeout_ms=5000,
- )
- except Exception as e:
- logger.warning(f"发送关停命令失败: {e}")
-
- # 等待进程退出
- try:
- await asyncio.wait_for(self._runner_process.wait(), timeout=10.0)
- except asyncio.TimeoutError:
- logger.warning("Runner 未在超时内退出,强制终止")
- self._runner_process.kill()
- await self._runner_process.wait()
-
- await self._cleanup_stderr_drain()
-
- async def _health_check_loop(self) -> None:
- """周期性健康检查 + 崩溃自动重启"""
- while self._running:
- await asyncio.sleep(self._health_interval)
-
- # 检查 Runner 进程是否意外退出
- if self._runner_process and self._runner_process.returncode is not None:
- exit_code = self._runner_process.returncode
- logger.warning(f"Runner 进程已退出 (exit_code={exit_code})")
-
- if self._restart_count < self._max_restart_attempts:
- self._restart_count += 1
- logger.info(f"尝试重启 Runner ({self._restart_count}/{self._max_restart_attempts})")
- # 清理旧的组件注册
- for plugin_id in list(self._registered_plugins.keys()):
- self._component_registry.remove_components_by_plugin(plugin_id)
- self._policy.revoke_plugin(plugin_id)
- self._registered_plugins.clear()
-
- try:
- self._clear_runtime_state()
- # 重新生成 session token,防止旧 Runner 僵尸进程用旧 token 重连
- self._rpc_server.reset_session_token()
- await self._spawn_runner()
- except Exception as e:
- logger.error(f"Runner 重启失败: {e}", exc_info=True)
- else:
- logger.error(f"Runner 连续崩溃 {self._max_restart_attempts} 次,停止重启")
- continue
-
- if not self._rpc_server.is_connected:
- logger.warning("Runner 未连接,跳过健康检查")
- continue
-
- try:
- resp = await self._rpc_server.send_request("plugin.health", timeout_ms=5000)
- health = HealthPayload.model_validate(resp.payload)
- if not health.healthy:
- logger.warning(f"Runner 健康检查异常: {health}")
- else:
- # 健康检查成功,重置重启计数
- self._restart_count = 0
- except RPCError as e:
- logger.error(f"健康检查失败: {e}")
- except asyncio.CancelledError:
- break
- except Exception as e:
- logger.error(f"健康检查异常: {e}")
-
- async def _wait_for_runner_generation(
- self,
- expected_generation: int,
- timeout_sec: float,
- allow_staged: bool = False,
- ) -> None:
- """等待指定代际的 Runner 完成连接。"""
- deadline = asyncio.get_running_loop().time() + timeout_sec
- while asyncio.get_running_loop().time() < deadline:
- if allow_staged and self._rpc_server.has_generation(expected_generation):
- return
- if self._rpc_server.is_connected and self._rpc_server.runner_generation >= expected_generation:
- self._runner_generation = self._rpc_server.runner_generation
- return
- await asyncio.sleep(0.1)
- raise TimeoutError(f"等待 Runner generation {expected_generation} 超时")
-
- async def _wait_for_runner_ready(self, expected_generation: int, timeout_sec: float) -> RunnerReadyPayload:
- """等待指定代际的 Runner 完成初始化。"""
- event = self._runner_ready_events.setdefault(expected_generation, asyncio.Event())
- await asyncio.wait_for(event.wait(), timeout=timeout_sec)
- return self._runner_ready_payloads.get(expected_generation, RunnerReadyPayload())
-
- def _clear_runtime_state(self) -> None:
- """清空当前插件注册态。"""
- self._component_registry.clear()
- self._policy.clear()
- self._registered_plugins.clear()
- self._staged_registered_plugins.clear()
-
- def _rebuild_runtime_state(self) -> None:
- """根据已记录的插件注册信息重建运行时状态。"""
- self._component_registry.clear()
- self._policy.clear()
- for reg in self._registered_plugins.values():
- self._policy.register_plugin(
- plugin_id=reg.plugin_id,
- generation=self._rpc_server.runner_generation,
- capabilities=reg.capabilities_required or [],
- )
- self._component_registry.register_plugin_components(
- plugin_id=reg.plugin_id,
- components=[c.model_dump() for c in reg.components],
- )
-
- def _attach_stderr_drain(self, process: asyncio.subprocess.Process) -> None:
- """为 Runner stderr 创建排空任务,捕获 IPC 建立前的进程级错误输出。
-
- stderr 中的内容通常是:
- - Runner 启动早期(握手完成之前)的日志
- - 进程级致命错误(ImportError、SyntaxError等)
- - 异常进程退出前的最后输出
-
- 握手成功后,插件的所有日志均经由 RunnerIPCLogHandler 通过 IPC 传输。
- """
- if process.stderr is None:
- return
- task = asyncio.create_task(
- self._drain_runner_stderr(process.stderr, process.pid),
- name=f"runner_stderr_drain:{process.pid}",
- )
- self._stderr_drain_task = task
- task.add_done_callback(
- lambda done_task: None if self._stderr_drain_task is not done_task else self._clear_stderr_drain_task()
- )
-
- def _clear_stderr_drain_task(self) -> None:
- self._stderr_drain_task = None
-
- async def _drain_runner_stderr(
- self,
- stream: asyncio.StreamReader,
- pid: int,
- ) -> None:
- """持续读取 Runner stderr 并转发到 Host Logger,防止 PIPE 锡死子进程。
-
- Args:
- stream: Runner 子进程的 stderr 流。
- pid: 子进程 PID,仅用于日志上下文。
- """
- try:
- while True:
- line = await stream.readline()
- if not line:
- break
- if message := line.decode(errors="replace").rstrip():
- # 将 stderr 输出以 WARNING 级展示:
- # 如果 Runner 正常运行,此流应当无输出;
- # 有输出说明进程级错误发生,需要出现在主进程日志中
- logger.warning(f"[runner:{pid}:stderr] {message}")
- except asyncio.CancelledError:
- raise
- except Exception as exc:
- logger.debug(f"读取 Runner stderr 失败 (pid={pid}): {exc}")
-
- async def _cleanup_stderr_drain(self) -> None:
- """等待并取消 stderr 排空任务。"""
- if self._stderr_drain_task is None:
- return
- task = self._stderr_drain_task
- self._stderr_drain_task = None
- if not task.done():
- task.cancel()
- with contextlib.suppress(Exception):
- await asyncio.gather(task, return_exceptions=True)
-
- @staticmethod
- async def _terminate_process(
- process: Optional[asyncio.subprocess.Process],
- keep_process: Optional[asyncio.subprocess.Process] = None,
- ) -> None:
- """终止指定进程,但跳过需要保留的旧进程引用。"""
- if process is None or process is keep_process or process.returncode is not None:
- return
-
- process.terminate()
- try:
- await asyncio.wait_for(process.wait(), timeout=10.0)
- except asyncio.TimeoutError:
- process.kill()
- await process.wait()
+ await asyncio.wait_for(self._runner_ready_events.wait(), timeout=timeout_sec)
+ logger.info("Runner 已完成初始化并上报就绪")
+ return self._runner_ready_payloads
+ except asyncio.TimeoutError as e:
+ raise TimeoutError(f"等待 Runner 就绪超时({timeout_sec}s)") from e
From 04f260e570070a792ff00906f5694877a001637e Mon Sep 17 00:00:00 2001
From: DrSmoothl <1787882683@qq.com>
Date: Fri, 20 Mar 2026 01:15:17 +0800
Subject: [PATCH 15/45] =?UTF-8?q?=E5=AE=9E=E7=8E=B0=E5=AE=8C=E6=95=B4?=
=?UTF-8?q?=E7=9A=84=E6=B6=88=E6=81=AF=E4=B8=AD=E9=97=B4=E5=B1=82=E5=9C=B0?=
=?UTF-8?q?=E5=9F=BA=EF=BC=8C=E6=9A=82=E6=9C=AA=E6=8E=A5=E5=85=A5=E5=AE=9E?=
=?UTF-8?q?=E9=99=85=E7=9A=84=E6=B6=88=E6=81=AF=E6=B5=81?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
src/platform_io/__init__.py | 35 ++
src/platform_io/dedupe.py | 133 ++++++
src/platform_io/drivers/__init__.py | 11 +
src/platform_io/drivers/base.py | 104 +++++
src/platform_io/drivers/legacy_driver.py | 61 +++
src/platform_io/drivers/plugin_driver.py | 64 +++
src/platform_io/manager.py | 529 +++++++++++++++++++++++
src/platform_io/outbound_tracker.py | 242 +++++++++++
src/platform_io/registry.py | 70 +++
src/platform_io/route_key_factory.py | 150 +++++++
src/platform_io/routing.py | 202 +++++++++
src/platform_io/types.py | 240 ++++++++++
12 files changed, 1841 insertions(+)
create mode 100644 src/platform_io/__init__.py
create mode 100644 src/platform_io/dedupe.py
create mode 100644 src/platform_io/drivers/__init__.py
create mode 100644 src/platform_io/drivers/base.py
create mode 100644 src/platform_io/drivers/legacy_driver.py
create mode 100644 src/platform_io/drivers/plugin_driver.py
create mode 100644 src/platform_io/manager.py
create mode 100644 src/platform_io/outbound_tracker.py
create mode 100644 src/platform_io/registry.py
create mode 100644 src/platform_io/route_key_factory.py
create mode 100644 src/platform_io/routing.py
create mode 100644 src/platform_io/types.py
diff --git a/src/platform_io/__init__.py b/src/platform_io/__init__.py
new file mode 100644
index 00000000..380ecbb6
--- /dev/null
+++ b/src/platform_io/__init__.py
@@ -0,0 +1,35 @@
+"""导出 Platform IO 层的公开入口。
+
+当前仍处于地基阶段,调用方应优先从这里导入共享类型和全局管理器,
+而不是直接依赖更底层的私有子模块。
+"""
+
+from .manager import PlatformIOManager, get_platform_io_manager
+from .route_key_factory import RouteKeyFactory
+from .routing import RouteBindingConflictError, RouteTable
+from .types import (
+ DeliveryReceipt,
+ DeliveryStatus,
+ DriverDescriptor,
+ DriverKind,
+ InboundMessageEnvelope,
+ RouteBinding,
+ RouteKey,
+ RouteMode,
+)
+
+__all__ = [
+ "DeliveryReceipt",
+ "DeliveryStatus",
+ "DriverDescriptor",
+ "DriverKind",
+ "InboundMessageEnvelope",
+ "PlatformIOManager",
+ "RouteKeyFactory",
+ "RouteBinding",
+ "RouteBindingConflictError",
+ "RouteKey",
+ "RouteMode",
+ "RouteTable",
+ "get_platform_io_manager",
+]
diff --git a/src/platform_io/dedupe.py b/src/platform_io/dedupe.py
new file mode 100644
index 00000000..4c5c55a2
--- /dev/null
+++ b/src/platform_io/dedupe.py
@@ -0,0 +1,133 @@
+"""提供 Platform IO 的轻量入站消息去重能力。
+
+当前实现基于 ``dict + heapq``:
+- ``dict`` 保存去重键到过期时间的映射
+- ``heapq`` 维护按过期时间排序的小顶堆
+
+这样就不需要在每次检查时全表扫描,而是通过懒清理逐步弹出已经过期
+或已经失效的堆节点。
+"""
+
+from typing import Dict, List, Tuple
+
+import heapq
+import time
+
+
+class MessageDeduplicator:
+ """使用基于 TTL 的内存缓存进行入站消息去重。
+
+ 主要用于解决同一条外部消息被重复送入 Core 的问题,例如双路径并存、
+ 适配器重试、重连或重复回调等场景。Broker 可以借助这个组件在进入
+ Core 前先拦住重复投递,避免重复处理、重复回复和重复入库。
+
+ 当前实现使用 ``dict + heapq`` 维护过期时间:
+ - ``dict`` 负责 ``O(1)`` 级别的去重键查找
+ - ``heapq`` 负责按过期时间顺序做懒清理
+
+ 这比“每次调用都全表扫描过期项”的实现更适合高吞吐消息场景。
+
+ Notes:
+ 复杂度说明如下,设 ``n`` 为当前缓存中的有效去重键数量:
+
+ - 单次 ``mark_seen()`` 在常见路径下的时间复杂度接近 ``O(log n)``
+ - 从长期摊还角度看,``mark_seen()`` 的时间复杂度也接近 ``O(log n)``
+ - 如果某次调用恰好触发一批过期键的集中清理,则该次调用的最坏时间复杂度
+ 可达到 ``O(k log n)``,其中 ``k`` 为本次被弹出或清理的键数量
+ - 空间复杂度为 ``O(n)``
+ """
+
+ def __init__(self, ttl_seconds: float = 300.0, max_entries: int = 10000) -> None:
+ """初始化去重器。
+
+ Args:
+ ttl_seconds: 每个去重键在缓存中的保留时长,单位为秒。
+ max_entries: 缓存允许保留的最大有效键数量,超出后会触发
+ 机会性淘汰。
+
+ Raises:
+ ValueError: 当 ``ttl_seconds`` 或 ``max_entries`` 非正数时抛出。
+ """
+ if ttl_seconds <= 0:
+ raise ValueError("ttl_seconds 必须大于 0")
+ if max_entries <= 0:
+ raise ValueError("max_entries 必须大于 0")
+
+ self._ttl_seconds = ttl_seconds
+ self._max_entries = max_entries
+ self._expire_heap: List[Tuple[float, str]] = []
+ self._seen: Dict[str, float] = {}
+
+ def mark_seen(self, dedupe_key: str) -> bool:
+ """标记一条去重键已经出现过。
+
+ Args:
+ dedupe_key: 能稳定标识一条外部入站消息的去重键。
+
+ Returns:
+ bool: 若该键在当前 TTL 窗口内首次出现则返回 ``True``,
+ 否则返回 ``False``。
+
+ Notes:
+ 方法会先基于小顶堆做一次懒清理,再判断当前键是否仍在有效期内。
+ 如果缓存已达到上限,则会优先淘汰“最早过期的仍然有效的键”。
+
+ 复杂度方面,常见路径下该方法接近 ``O(log n)``;如果恰好需要
+ 集中清理一批过期键,则单次调用最坏可达到 ``O(k log n)``。
+ """
+ now = time.monotonic()
+ self._purge_expired(now)
+
+ expires_at = self._seen.get(dedupe_key)
+ if expires_at is not None and expires_at > now:
+ return False
+
+ if len(self._seen) >= self._max_entries:
+ self._evict_earliest_live()
+
+ expires_at = now + self._ttl_seconds
+ self._seen[dedupe_key] = expires_at
+ heapq.heappush(self._expire_heap, (expires_at, dedupe_key))
+ return True
+
+ def clear(self) -> None:
+ """清空全部去重缓存。"""
+ self._expire_heap.clear()
+ self._seen.clear()
+
+ def _purge_expired(self, now: float) -> None:
+ """从缓存中清理已经过期的去重键。
+
+ Args:
+ now: 当前单调时钟时间戳。
+
+ Notes:
+ 堆中可能存在旧版本节点。例如同一个 ``dedupe_key`` 被重新写入后,
+ 旧的过期时间节点仍会留在堆里。这里会通过和 ``dict`` 中当前值比对,
+ 跳过这类失效节点。
+ """
+ while self._expire_heap and self._expire_heap[0][0] <= now:
+ expires_at, dedupe_key = heapq.heappop(self._expire_heap)
+ current_expires_at = self._seen.get(dedupe_key)
+ if current_expires_at is None:
+ continue
+ if current_expires_at != expires_at:
+ continue
+ self._seen.pop(dedupe_key, None)
+
+ def _evict_earliest_live(self) -> None:
+ """当缓存达到容量上限时,淘汰一条最早过期的有效键。
+
+ Notes:
+ 堆顶可能是已经过期或已失效的旧节点,因此这里同样需要循环弹出,
+ 直到找到一条当前仍然在 ``dict`` 中生效的键。
+ """
+ while self._expire_heap:
+ expires_at, dedupe_key = heapq.heappop(self._expire_heap)
+ current_expires_at = self._seen.get(dedupe_key)
+ if current_expires_at is None:
+ continue
+ if current_expires_at != expires_at:
+ continue
+ self._seen.pop(dedupe_key, None)
+ return
diff --git a/src/platform_io/drivers/__init__.py b/src/platform_io/drivers/__init__.py
new file mode 100644
index 00000000..b12120cf
--- /dev/null
+++ b/src/platform_io/drivers/__init__.py
@@ -0,0 +1,11 @@
+"""导出 Platform IO 层的公开驱动类型。"""
+
+from .base import PlatformIODriver
+from .legacy_driver import LegacyPlatformDriver
+from .plugin_driver import PluginPlatformDriver
+
+__all__ = [
+ "LegacyPlatformDriver",
+ "PlatformIODriver",
+ "PluginPlatformDriver",
+]
diff --git a/src/platform_io/drivers/base.py b/src/platform_io/drivers/base.py
new file mode 100644
index 00000000..c6173d8c
--- /dev/null
+++ b/src/platform_io/drivers/base.py
@@ -0,0 +1,104 @@
+"""定义 Platform IO 传输驱动的基础抽象协议。"""
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional
+
+from src.platform_io.types import DeliveryReceipt, DriverDescriptor, InboundMessageEnvelope, RouteKey
+
+if TYPE_CHECKING:
+ from src.chat.message_receive.message import SessionMessage
+
+InboundHandler = Callable[[InboundMessageEnvelope], Awaitable[bool]]
+
+
+class PlatformIODriver(ABC):
+ """定义所有 Platform IO 驱动都必须实现的最小契约。
+
+ 当前实现故意保持接口很小,让中间层可以先落地,再逐步把 legacy
+ 与 plugin 路径的真实收发能力迁入这套协议之下。
+ """
+
+ def __init__(self, descriptor: DriverDescriptor) -> None:
+ """使用驱动描述对象初始化驱动。
+
+ Args:
+ descriptor: 注册到 Broker 中的静态驱动元数据。
+ """
+ self._descriptor = descriptor
+ self._inbound_handler: Optional[InboundHandler] = None
+
+ @property
+ def descriptor(self) -> DriverDescriptor:
+ """返回当前驱动的描述对象。
+
+ Returns:
+ DriverDescriptor: 当前驱动实例对应的描述对象。
+ """
+ return self._descriptor
+
+ @property
+ def driver_id(self) -> str:
+ """返回驱动标识。
+
+ Returns:
+ str: 当前驱动的唯一 ID。
+ """
+ return self._descriptor.driver_id
+
+ def set_inbound_handler(self, handler: InboundHandler) -> None:
+ """注册入站消息交回 Broker 的回调函数。
+
+ Args:
+ handler: 将规范化入站封装继续转发给 Broker 的异步回调。
+ """
+ self._inbound_handler = handler
+
+ def clear_inbound_handler(self) -> None:
+ """清除当前注册的入站回调函数。"""
+ self._inbound_handler = None
+
+ async def emit_inbound(self, envelope: InboundMessageEnvelope) -> bool:
+ """将一条入站封装转交给 Broker 回调。
+
+ Args:
+ envelope: 由驱动产出的规范化入站封装。
+
+ Returns:
+ bool: 若 Broker 接受该入站消息则返回 ``True``,否则返回 ``False``。
+ """
+
+ if self._inbound_handler is None:
+ return False
+ return await self._inbound_handler(envelope)
+
+ async def start(self) -> None:
+ """启动驱动生命周期。
+
+ 子类后续若需要初始化逻辑,可以覆盖这个钩子。
+ """
+ return None
+
+ async def stop(self) -> None:
+ """停止驱动生命周期。
+
+ 子类后续若需要清理逻辑,可以覆盖这个钩子。
+ """
+ return None
+
+ @abstractmethod
+ async def send_message(
+ self,
+ message: "SessionMessage",
+ route_key: RouteKey,
+ metadata: Optional[Dict[str, Any]] = None,
+ ) -> DeliveryReceipt:
+ """通过具体驱动发送一条消息。
+
+ Args:
+ message: 要投递的内部会话消息。
+ route_key: Broker 为本次投递选中的路由键。
+ metadata: 本次出站投递可选的 Broker 侧元数据。
+
+ Returns:
+ DeliveryReceipt: 规范化后的投递结果。
+ """
diff --git a/src/platform_io/drivers/legacy_driver.py b/src/platform_io/drivers/legacy_driver.py
new file mode 100644
index 00000000..bd74d8c7
--- /dev/null
+++ b/src/platform_io/drivers/legacy_driver.py
@@ -0,0 +1,61 @@
+"""提供 Platform IO 的 legacy 传输驱动骨架。"""
+
+from typing import TYPE_CHECKING, Any, Dict, Optional
+
+from src.platform_io.drivers.base import PlatformIODriver
+from src.platform_io.types import DeliveryReceipt, DriverDescriptor, DriverKind, RouteKey
+
+if TYPE_CHECKING:
+ from src.chat.message_receive.message import SessionMessage
+
+
+class LegacyPlatformDriver(PlatformIODriver):
+ """面向 ``maim_message`` 旧链路的 Platform IO 驱动骨架。"""
+
+ def __init__(
+ self,
+ driver_id: str,
+ platform: str,
+ account_id: Optional[str] = None,
+ scope: Optional[str] = None,
+ metadata: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ """初始化一个 legacy 驱动描述对象。
+
+ Args:
+ driver_id: Broker 内的唯一驱动 ID。
+ platform: 该 legacy 适配器链路负责的平台。
+ account_id: 可选的账号 ID 或 self ID。
+ scope: 可选的额外路由作用域。
+ metadata: 可选的额外驱动元数据。
+ """
+ descriptor = DriverDescriptor(
+ driver_id=driver_id,
+ kind=DriverKind.LEGACY,
+ platform=platform,
+ account_id=account_id,
+ scope=scope,
+ metadata=metadata or {},
+ )
+ super().__init__(descriptor)
+
+ async def send_message(
+ self,
+ message: "SessionMessage",
+ route_key: RouteKey,
+ metadata: Optional[Dict[str, Any]] = None,
+ ) -> DeliveryReceipt:
+ """通过 legacy 传输路径发送消息。
+
+ Args:
+ message: 要投递的内部会话消息。
+ route_key: Broker 为本次投递选择的路由键。
+ metadata: 本次出站投递可选的 Broker 侧元数据。
+
+ Returns:
+ DeliveryReceipt: 由驱动返回的规范化回执。
+
+ Raises:
+ NotImplementedError: 当前仍处于骨架阶段,尚未真正接入旧发送链。
+ """
+ raise NotImplementedError("LegacyPlatformDriver 仅完成地基实现,尚未接入旧发送链")
diff --git a/src/platform_io/drivers/plugin_driver.py b/src/platform_io/drivers/plugin_driver.py
new file mode 100644
index 00000000..9c139309
--- /dev/null
+++ b/src/platform_io/drivers/plugin_driver.py
@@ -0,0 +1,64 @@
+"""提供 Platform IO 的 plugin 传输驱动骨架。"""
+
+from typing import TYPE_CHECKING, Any, Dict, Optional
+
+from src.platform_io.drivers.base import PlatformIODriver
+from src.platform_io.types import DeliveryReceipt, DriverDescriptor, DriverKind, RouteKey
+
+if TYPE_CHECKING:
+ from src.chat.message_receive.message import SessionMessage
+
+
+class PluginPlatformDriver(PlatformIODriver):
+ """面向 ``MessageGateway`` 插件链路的 Platform IO 驱动骨架。"""
+
+ def __init__(
+ self,
+ driver_id: str,
+ platform: str,
+ account_id: Optional[str] = None,
+ scope: Optional[str] = None,
+ plugin_id: Optional[str] = None,
+ metadata: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ """初始化一个 plugin 驱动描述对象。
+
+ Args:
+ driver_id: Broker 内的唯一驱动 ID。
+ platform: 该 plugin 适配器链路负责的平台。
+ account_id: 可选的账号 ID 或 self ID。
+ scope: 可选的额外路由作用域。
+ plugin_id: 拥有该适配器实现的插件 ID,可为空。
+ metadata: 可选的额外驱动元数据。
+ """
+ descriptor = DriverDescriptor(
+ driver_id=driver_id,
+ kind=DriverKind.PLUGIN,
+ platform=platform,
+ account_id=account_id,
+ scope=scope,
+ plugin_id=plugin_id,
+ metadata=metadata or {},
+ )
+ super().__init__(descriptor)
+
+ async def send_message(
+ self,
+ message: "SessionMessage",
+ route_key: RouteKey,
+ metadata: Optional[Dict[str, Any]] = None,
+ ) -> DeliveryReceipt:
+ """通过 plugin 传输路径发送消息。
+
+ Args:
+ message: 要投递的内部会话消息。
+ route_key: Broker 为本次投递选择的路由键。
+ metadata: 本次出站投递可选的 Broker 侧元数据。
+
+ Returns:
+ DeliveryReceipt: 由驱动返回的规范化回执。
+
+ Raises:
+ NotImplementedError: 当前仍处于骨架阶段,尚未真正接入 MessageGateway。
+ """
+ raise NotImplementedError("PluginPlatformDriver 仅完成地基实现,尚未接入 MessageGateway")
diff --git a/src/platform_io/manager.py b/src/platform_io/manager.py
new file mode 100644
index 00000000..6135a567
--- /dev/null
+++ b/src/platform_io/manager.py
@@ -0,0 +1,529 @@
+"""提供 Platform IO 层的中心 Broker 管理器。"""
+
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional
+
+import hashlib
+import json
+
+from src.common.logger import get_logger
+from src.platform_io.drivers.base import PlatformIODriver
+
+from .dedupe import MessageDeduplicator
+from .outbound_tracker import OutboundTracker
+from .route_key_factory import RouteKeyFactory
+from .registry import DriverRegistry
+from .routing import RouteTable
+from .types import DeliveryReceipt, DeliveryStatus, InboundMessageEnvelope, RouteBinding, RouteKey
+
+if TYPE_CHECKING:
+ from src.chat.message_receive.message import SessionMessage
+
+logger = get_logger("platform_io.manager")
+
+InboundDispatcher = Callable[[InboundMessageEnvelope], Awaitable[None]]
+
+
+class PlatformIOManager:
+ """统一协调双路径平台消息 IO 的路由、去重与状态跟踪。
+
+ 这个管理器预期会成为 legacy 适配器链路与 plugin 适配器链路之间的
+ 唯一裁决点。当前地基阶段,它只提供共享状态和 Broker 侧契约,还没有
+ 真正把生产流量切到新中间层。
+ """
+
+ def __init__(self) -> None:
+ """初始化 Broker 管理器及其内存状态。"""
+ self._driver_registry = DriverRegistry()
+ self._route_table = RouteTable()
+ self._deduplicator = MessageDeduplicator()
+ self._outbound_tracker = OutboundTracker()
+ self._inbound_dispatcher: Optional[InboundDispatcher] = None
+ self._started = False
+
+ @property
+ def is_started(self) -> bool:
+ """返回 Broker 当前是否已进入运行态。
+
+ Returns:
+ bool: 若 Broker 已启动则返回 ``True``。
+ """
+ return self._started
+
+ async def start(self) -> None:
+ """启动 Broker,并依次启动当前已注册的全部驱动。
+
+ Raises:
+ Exception: 当某个驱动启动失败时,异常会继续上抛;已成功启动的驱动
+ 会被自动回滚停止。
+ """
+ if self._started:
+ return
+
+ started_drivers: list[PlatformIODriver] = []
+ try:
+ for driver in self._driver_registry.list():
+ await driver.start()
+ started_drivers.append(driver)
+ except Exception:
+ for driver in reversed(started_drivers):
+ try:
+ await driver.stop()
+ except Exception:
+ logger.exception("回滚驱动停止失败: driver_id=%s", driver.driver_id)
+ raise
+
+ self._started = True
+
+ async def stop(self) -> None:
+ """停止 Broker,并按逆序停止全部已注册驱动。
+
+ 停止完成后,会同步清空仅对当前运行周期有效的去重缓存和出站跟踪状态,
+ 避免下一次启动时继续沿用上一个运行周期的瞬时内存数据。
+
+ Raises:
+ RuntimeError: 当一个或多个驱动停止失败时抛出汇总异常。
+ """
+ if not self._started:
+ return
+
+ stop_errors: list[str] = []
+ for driver in reversed(self._driver_registry.list()):
+ try:
+ await driver.stop()
+ except Exception as exc:
+ stop_errors.append(f"{driver.driver_id}: {exc}")
+ logger.exception("驱动停止失败: driver_id=%s", driver.driver_id)
+
+ self._started = False
+ self._deduplicator.clear()
+ self._outbound_tracker.clear()
+ if stop_errors:
+ raise RuntimeError(f"部分驱动停止失败: {'; '.join(stop_errors)}")
+
+ async def add_driver(self, driver: PlatformIODriver) -> None:
+ """向运行中的 Broker 注册并启动一个驱动。
+
+ 如果 Broker 尚未启动,则该方法等价于 ``register_driver()``。
+
+ Args:
+ driver: 要添加的驱动实例。
+
+ Raises:
+ Exception: 当驱动启动失败时,注册会自动回滚,异常继续上抛。
+ """
+ self._register_driver_internal(driver)
+ if not self._started:
+ return
+
+ try:
+ await driver.start()
+ except Exception:
+ self._unregister_driver_internal(driver.driver_id)
+ raise
+
+ async def remove_driver(self, driver_id: str) -> Optional[PlatformIODriver]:
+ """从运行中的 Broker 停止并移除一个驱动。
+
+ 如果 Broker 尚未启动,则该方法等价于 ``unregister_driver()``。
+
+ Args:
+ driver_id: 要移除的驱动 ID。
+
+ Returns:
+ Optional[PlatformIODriver]: 若驱动存在,则返回被移除的驱动实例。
+
+ Raises:
+ Exception: 当 Broker 运行中且驱动停止失败时,异常会继续上抛。
+ """
+ if not self._started:
+ return self.unregister_driver(driver_id)
+
+ driver = self._driver_registry.get(driver_id)
+ if driver is None:
+ return None
+
+ await driver.stop()
+ return self._unregister_driver_internal(driver_id)
+
+ @property
+ def driver_registry(self) -> DriverRegistry:
+ """返回管理器持有的驱动注册表。
+
+ Returns:
+ DriverRegistry: 用于保存全部已注册驱动的注册表。
+ """
+ return self._driver_registry
+
+ @property
+ def route_table(self) -> RouteTable:
+ """返回管理器持有的路由绑定表。
+
+ Returns:
+ RouteTable: 用于归属解析的路由绑定表。
+ """
+ return self._route_table
+
+ @property
+ def deduplicator(self) -> MessageDeduplicator:
+ """返回管理器持有的入站去重器。
+
+ Returns:
+ MessageDeduplicator: 用于抑制重复入站的去重器。
+ """
+ return self._deduplicator
+
+ @property
+ def outbound_tracker(self) -> OutboundTracker:
+ """返回管理器持有的出站跟踪器。
+
+ Returns:
+ OutboundTracker: 用于记录出站 pending 状态与回执的跟踪器。
+ """
+ return self._outbound_tracker
+
+ def set_inbound_dispatcher(self, dispatcher: InboundDispatcher) -> None:
+ """设置统一的入站分发回调。
+
+ Args:
+ dispatcher: 接收已通过 Broker 审核的入站封装,并继续送入
+ Core 下一处理阶段的异步回调。
+ """
+
+ self._inbound_dispatcher = dispatcher
+
+ def clear_inbound_dispatcher(self) -> None:
+ """清除当前的入站分发回调。"""
+ self._inbound_dispatcher = None
+
+ @property
+ def has_inbound_dispatcher(self) -> bool:
+ """返回当前是否已经配置入站分发回调。
+
+ Returns:
+ bool: 若已经配置入站分发回调则返回 ``True``。
+ """
+ return self._inbound_dispatcher is not None
+
+ def register_driver(self, driver: PlatformIODriver) -> None:
+ """注册驱动,并把它的入站回调挂到 Broker。
+
+ Args:
+ driver: 要注册的驱动实例。
+
+ Raises:
+ RuntimeError: 当 Broker 已经处于运行态时抛出。此时应改用
+ ``add_driver()`` 以保证驱动生命周期和注册状态一致。
+ """
+ if self._started:
+ raise RuntimeError("Broker 运行中不允许直接 register_driver,请改用 add_driver()")
+
+ self._register_driver_internal(driver)
+
+ def _register_driver_internal(self, driver: PlatformIODriver) -> None:
+ """执行不带运行态限制的内部驱动注册。
+
+ Args:
+ driver: 要注册的驱动实例。
+ """
+ driver.set_inbound_handler(self.accept_inbound)
+ self._driver_registry.register(driver)
+
+ def unregister_driver(self, driver_id: str) -> Optional[PlatformIODriver]:
+ """从 Broker 注销一个驱动。
+
+ Args:
+ driver_id: 要移除的驱动 ID。
+
+ Returns:
+ Optional[PlatformIODriver]: 若驱动存在,则返回被移除的驱动实例。
+
+ Raises:
+ RuntimeError: 当 Broker 已经处于运行态时抛出。此时应改用
+ ``remove_driver()``,避免驱动停止与路由解绑脱节。
+ """
+ if self._started:
+ raise RuntimeError("Broker 运行中不允许直接 unregister_driver,请改用 remove_driver()")
+
+ return self._unregister_driver_internal(driver_id)
+
+ def _unregister_driver_internal(self, driver_id: str) -> Optional[PlatformIODriver]:
+ """执行不带运行态限制的内部驱动注销。
+
+ Args:
+ driver_id: 要移除的驱动 ID。
+
+ Returns:
+ Optional[PlatformIODriver]: 若驱动存在,则返回被移除的驱动实例。
+ """
+ removed_driver = self._driver_registry.unregister(driver_id)
+ if removed_driver is None:
+ return None
+
+ removed_driver.clear_inbound_handler()
+ self._route_table.remove_bindings_by_driver(driver_id)
+ return removed_driver
+
+ def bind_route(self, binding: RouteBinding, *, replace: bool = False) -> None:
+ """为某个路由键绑定驱动。
+
+ Args:
+ binding: 要保存的路由绑定。
+ replace: 是否允许替换已有的精确 active owner。
+
+ Raises:
+ ValueError: 当绑定引用了不存在的驱动,或者绑定与驱动描述不一致时抛出。
+ """
+ driver = self._driver_registry.get(binding.driver_id)
+ if driver is None:
+ raise ValueError(f"驱动 {binding.driver_id} 未注册,无法绑定路由")
+
+ self._validate_binding_against_driver(binding, driver)
+ self._route_table.bind(binding, replace=replace)
+
+ def unbind_route(self, route_key: RouteKey, driver_id: Optional[str] = None) -> None:
+ """移除一个或多个路由绑定。
+
+ Args:
+ route_key: 要移除绑定的路由键。
+ driver_id: 可选的特定驱动 ID。
+ """
+ self._route_table.unbind(route_key, driver_id)
+
+ def resolve_driver(self, route_key: RouteKey) -> Optional[PlatformIODriver]:
+ """解析某个路由键当前的 active 驱动。
+
+ Args:
+ route_key: 要解析的路由键。
+
+ Returns:
+ Optional[PlatformIODriver]: 若存在 active 驱动,则返回该驱动实例。
+ """
+ active_binding = self._route_table.get_active_binding(route_key)
+ if active_binding is None:
+ return None
+ return self._driver_registry.get(active_binding.driver_id)
+
+ @staticmethod
+ def build_route_key_from_message(message: "SessionMessage") -> RouteKey:
+ """根据 ``SessionMessage`` 构造路由键。
+
+ Args:
+ message: 内部会话消息对象。
+
+ Returns:
+ RouteKey: 由消息内容提取出的规范化路由键。
+ """
+ return RouteKeyFactory.from_session_message(message)
+
+ @staticmethod
+ def build_route_key_from_message_dict(message_dict: Dict[str, Any]) -> RouteKey:
+ """根据消息字典构造路由键。
+
+ Args:
+ message_dict: Host 与插件之间传输的消息字典。
+
+ Returns:
+ RouteKey: 由消息字典提取出的规范化路由键。
+ """
+ return RouteKeyFactory.from_message_dict(message_dict)
+
+ async def accept_inbound(self, envelope: InboundMessageEnvelope) -> bool:
+ """处理一条由驱动上报的入站封装。
+
+ Args:
+ envelope: 由传输驱动产出的入站封装。
+
+ Returns:
+ bool: 若消息被接受并继续转发给入站分发器,则返回 ``True``,
+ 否则返回 ``False``。
+ """
+
+ if not self._route_table.accepts_inbound(envelope.route_key, envelope.driver_id):
+ logger.info(
+ "忽略非 active owner 的入站消息: route=%s driver=%s",
+ envelope.route_key,
+ envelope.driver_id,
+ )
+ return False
+
+ if self._inbound_dispatcher is None:
+ logger.debug("PlatformIOManager 尚未配置 inbound dispatcher,暂不继续分发")
+ return False
+
+ dedupe_key = self._build_inbound_dedupe_key(envelope)
+ if dedupe_key is not None:
+ if not self._deduplicator.mark_seen(dedupe_key):
+ logger.info("忽略重复入站消息: dedupe_key=%s", dedupe_key)
+ return False
+
+ await self._inbound_dispatcher(envelope)
+ return True
+
+ async def send_message(
+ self,
+ message: "SessionMessage",
+ route_key: RouteKey,
+ metadata: Optional[Dict[str, Any]] = None,
+ ) -> DeliveryReceipt:
+ """通过 Broker 选中的驱动发送一条消息。
+
+ Args:
+ message: 要投递的内部会话消息。
+ route_key: 本次出站投递选择的路由键。
+ metadata: 可选的额外 Broker 侧元数据。
+
+ Returns:
+ DeliveryReceipt: 规范化后的出站回执。若路由不存在、驱动缺失,
+ 或同一消息已存在未完成的出站跟踪,也会返回失败回执而不是抛异常。
+ """
+
+ active_binding = self._route_table.get_active_binding(route_key)
+ if active_binding is None:
+ return DeliveryReceipt(
+ internal_message_id=message.message_id,
+ route_key=route_key,
+ status=DeliveryStatus.FAILED,
+ error="未找到 active 路由绑定",
+ )
+
+ driver = self._driver_registry.get(active_binding.driver_id)
+ if driver is None:
+ return DeliveryReceipt(
+ internal_message_id=message.message_id,
+ route_key=route_key,
+ status=DeliveryStatus.FAILED,
+ driver_id=active_binding.driver_id,
+ driver_kind=active_binding.driver_kind,
+ error="active 路由绑定对应的驱动不存在",
+ )
+
+ try:
+ self._outbound_tracker.begin_tracking(
+ internal_message_id=message.message_id,
+ route_key=route_key,
+ driver_id=driver.driver_id,
+ metadata=metadata,
+ )
+ except ValueError as exc:
+ return DeliveryReceipt(
+ internal_message_id=message.message_id,
+ route_key=route_key,
+ status=DeliveryStatus.FAILED,
+ driver_id=driver.driver_id,
+ driver_kind=driver.descriptor.kind,
+ error=str(exc),
+ )
+
+ try:
+ receipt = await driver.send_message(message=message, route_key=route_key, metadata=metadata)
+ except Exception as exc:
+ receipt = DeliveryReceipt(
+ internal_message_id=message.message_id,
+ route_key=route_key,
+ status=DeliveryStatus.FAILED,
+ driver_id=driver.driver_id,
+ driver_kind=driver.descriptor.kind,
+ error=str(exc),
+ )
+
+ self._outbound_tracker.finish_tracking(receipt)
+ return receipt
+
+ @staticmethod
+ def _build_inbound_dedupe_key(envelope: InboundMessageEnvelope) -> Optional[str]:
+ """构造用于入站抑制的去重键。
+
+ Args:
+ envelope: 当前正在处理的入站封装。
+
+ Returns:
+ Optional[str]: 若可以构造稳定去重键则返回该键,否则返回 ``None``。
+ """
+ raw_dedupe_key = envelope.dedupe_key or envelope.external_message_id
+ if raw_dedupe_key is None and envelope.session_message is not None:
+ raw_dedupe_key = envelope.session_message.message_id
+ if raw_dedupe_key is None and envelope.payload is not None:
+ raw_dedupe_key = PlatformIOManager._build_payload_fingerprint(envelope.payload)
+ if raw_dedupe_key is None:
+ return None
+
+ normalized_dedupe_key = str(raw_dedupe_key).strip()
+ if not normalized_dedupe_key:
+ return None
+
+ return f"{envelope.route_key.to_dedupe_scope()}:{normalized_dedupe_key}"
+
+ @staticmethod
+ def _build_payload_fingerprint(payload: Dict[str, Any]) -> Optional[str]:
+ """根据消息载荷构造稳定指纹。
+
+ Args:
+ payload: 待构造指纹的原始载荷字典。
+
+ Returns:
+ Optional[str]: 若成功生成指纹则返回十六进制摘要,否则返回 ``None``。
+ """
+ try:
+ serialized_payload = json.dumps(
+ payload,
+ default=str,
+ ensure_ascii=True,
+ separators=(",", ":"),
+ sort_keys=True,
+ )
+ except Exception:
+ return None
+
+ return hashlib.sha256(serialized_payload.encode()).hexdigest()
+
+ @staticmethod
+ def _validate_binding_against_driver(binding: RouteBinding, driver: PlatformIODriver) -> None:
+ """校验路由绑定与驱动描述是否一致。
+
+ Args:
+ binding: 待校验的路由绑定。
+ driver: 被绑定的驱动实例。
+
+ Raises:
+ ValueError: 当绑定类型、平台或更细粒度路由维度与驱动描述冲突时抛出。
+ """
+ descriptor = driver.descriptor
+ if binding.driver_kind != descriptor.kind:
+ raise ValueError(
+ f"路由绑定的 driver_kind={binding.driver_kind} 与驱动 {driver.driver_id} 的类型 "
+ f"{descriptor.kind} 不一致"
+ )
+
+ if binding.route_key.platform != descriptor.platform:
+ raise ValueError(
+ f"路由绑定的平台 {binding.route_key.platform} 与驱动 {driver.driver_id} 的平台 "
+ f"{descriptor.platform} 不一致"
+ )
+
+ if descriptor.account_id is not None and binding.route_key.account_id not in (None, descriptor.account_id):
+ raise ValueError(
+ f"路由绑定的 account_id={binding.route_key.account_id} 与驱动 {driver.driver_id} 的 "
+ f"account_id={descriptor.account_id} 冲突"
+ )
+
+ if descriptor.scope is not None and binding.route_key.scope not in (None, descriptor.scope):
+ raise ValueError(
+ f"路由绑定的 scope={binding.route_key.scope} 与驱动 {driver.driver_id} 的 "
+ f"scope={descriptor.scope} 冲突"
+ )
+
+
+_platform_io_manager: Optional[PlatformIOManager] = None
+
+
+def get_platform_io_manager() -> PlatformIOManager:
+ """返回全局 ``PlatformIOManager`` 单例。
+
+ Returns:
+ PlatformIOManager: 进程级共享的 Broker 管理器实例。
+ """
+
+ global _platform_io_manager
+ if _platform_io_manager is None:
+ _platform_io_manager = PlatformIOManager()
+ return _platform_io_manager
diff --git a/src/platform_io/outbound_tracker.py b/src/platform_io/outbound_tracker.py
new file mode 100644
index 00000000..438aa566
--- /dev/null
+++ b/src/platform_io/outbound_tracker.py
@@ -0,0 +1,242 @@
+"""跟踪 Platform IO 层的出站投递状态。
+
+当前实现基于两组 ``dict + heapq``:
+- ``_pending`` 和 ``_pending_expire_heap`` 负责管理待完成的出站记录
+- ``_receipts_by_external_id`` 和 ``_receipt_expire_heap`` 负责管理已完成回执索引
+
+这样就不需要在每次读写时全表扫描过期项,而是通过懒清理逐步弹出已经过期
+或已经失效的堆节点。
+"""
+
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Tuple
+
+import heapq
+import time
+
+from .types import DeliveryReceipt, RouteKey
+
+
+@dataclass(slots=True)
+class PendingOutboundRecord:
+ """表示一条仍在等待完成的出站投递记录。
+
+ Attributes:
+ internal_message_id: 正在跟踪的内部 ``SessionMessage.message_id``。
+ route_key: 该出站投递开始时使用的路由键。
+ driver_id: 负责这次出站投递的驱动 ID。
+ created_at: 开始跟踪时记录的单调时钟时间戳。
+ expires_at: 该待完成记录预计过期的单调时钟时间戳。
+ metadata: 与待完成记录一同保留的额外 Broker 侧元数据。
+ """
+
+ internal_message_id: str
+ route_key: RouteKey
+ driver_id: str
+ created_at: float = field(default_factory=time.monotonic)
+ expires_at: float = 0.0
+ metadata: Dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass(slots=True)
+class StoredDeliveryReceipt:
+ """表示一条已完成并暂存的出站回执。
+
+ Attributes:
+ receipt: 规范化后的出站投递回执。
+ stored_at: 回执被写入索引时记录的单调时钟时间戳。
+ expires_at: 该回执索引预计过期的单调时钟时间戳。
+ """
+
+ receipt: DeliveryReceipt
+ stored_at: float = field(default_factory=time.monotonic)
+ expires_at: float = 0.0
+
+
+class OutboundTracker:
+ """统一跟踪出站消息的 pending 状态与最终回执。
+
+ 主要用于解决出站消息在发送过程中“状态散落在不同路径里”的问题:
+ - 发送开始后,需要在最终回执返回前保留一份 pending 状态
+ - 平台返回 ``external_message_id`` 后,需要保留一段时间的回执索引
+
+ 当前实现使用 ``dict + heapq`` 做 TTL 管理:
+ - ``dict`` 提供 ``O(1)`` 级别的主键查询
+ - ``heapq`` 提供按过期时间排序的懒清理能力
+
+ 这比“每次 begin/finish/get 都全表扫描”的实现更适合高吞吐出站场景。
+
+ Notes:
+ 复杂度说明如下,设 ``p`` 为当前有效 pending 数量,``r`` 为当前有效回执数量:
+
+ - ``begin_tracking()``、``finish_tracking()`` 的常见路径时间复杂度接近
+ ``O(log p)`` 或 ``O(log r)``
+ - ``get_pending()``、``get_receipt_by_external_id()`` 的查询本身是 ``O(1)``
+ ,连同懒清理一起看,长期摊还复杂度接近 ``O(log n)``
+ - 如果某次调用恰好触发一批过期节点的集中清理,则该次调用的最坏时间复杂度
+ 可达到 ``O(k log n)``,其中 ``k`` 为本次被弹出的节点数量
+ - 空间复杂度为 ``O(p + r)``
+ """
+
+ def __init__(self, ttl_seconds: float = 1800.0) -> None:
+ """初始化出站跟踪器。
+
+ Args:
+ ttl_seconds: 待完成记录与按外部消息 ID 建立的回执索引保留时长,
+ 单位为秒。
+
+ Raises:
+ ValueError: 当 ``ttl_seconds`` 非正数时抛出。
+ """
+ if ttl_seconds <= 0:
+ raise ValueError("ttl_seconds 必须大于 0")
+
+ self._ttl_seconds = ttl_seconds
+ self._pending: Dict[str, PendingOutboundRecord] = {}
+ self._pending_expire_heap: List[Tuple[float, str]] = []
+ self._receipts_by_external_id: Dict[str, StoredDeliveryReceipt] = {}
+ self._receipt_expire_heap: List[Tuple[float, str]] = []
+
+ def begin_tracking(
+ self,
+ internal_message_id: str,
+ route_key: RouteKey,
+ driver_id: str,
+ metadata: Optional[Dict[str, Any]] = None,
+ ) -> PendingOutboundRecord:
+ """开始跟踪一次出站投递。
+
+ Args:
+ internal_message_id: 正在投递的内部消息 ID。
+ route_key: 这次出站投递选择的路由键。
+ driver_id: 负责本次投递的驱动 ID。
+ metadata: 可选的额外元数据,会一并保存在待完成记录中。
+
+ Returns:
+ PendingOutboundRecord: 新创建的待完成记录。
+
+ Raises:
+ ValueError: 当同一个 ``internal_message_id`` 已经存在未完成记录时抛出。
+ """
+ now = time.monotonic()
+ self._cleanup_expired(now)
+
+ if internal_message_id in self._pending:
+ raise ValueError(f"消息 {internal_message_id} 已存在未完成的出站跟踪记录")
+
+ expires_at = now + self._ttl_seconds
+ record = PendingOutboundRecord(
+ internal_message_id=internal_message_id,
+ route_key=route_key,
+ driver_id=driver_id,
+ created_at=now,
+ expires_at=expires_at,
+ metadata=metadata or {},
+ )
+ self._pending[internal_message_id] = record
+ heapq.heappush(self._pending_expire_heap, (expires_at, internal_message_id))
+ return record
+
+ def finish_tracking(self, receipt: DeliveryReceipt) -> Optional[PendingOutboundRecord]:
+ """使用最终回执结束一条出站跟踪。
+
+ Args:
+ receipt: 规范化后的最终投递回执。
+
+ Returns:
+ Optional[PendingOutboundRecord]: 若此前存在待完成记录,则返回该记录。
+ """
+ now = time.monotonic()
+ self._cleanup_expired(now)
+
+ pending_record = self._pending.pop(receipt.internal_message_id, None)
+ if receipt.external_message_id:
+ expires_at = now + self._ttl_seconds
+ self._receipts_by_external_id[receipt.external_message_id] = StoredDeliveryReceipt(
+ receipt=receipt,
+ stored_at=now,
+ expires_at=expires_at,
+ )
+ heapq.heappush(self._receipt_expire_heap, (expires_at, receipt.external_message_id))
+ return pending_record
+
+ def get_pending(self, internal_message_id: str) -> Optional[PendingOutboundRecord]:
+ """根据内部消息 ID 查询待完成记录。
+
+ Args:
+ internal_message_id: 要查询的内部消息 ID。
+
+ Returns:
+ Optional[PendingOutboundRecord]: 若记录仍存在,则返回对应待完成记录。
+ """
+ self._cleanup_expired(time.monotonic())
+ return self._pending.get(internal_message_id)
+
+ def get_receipt_by_external_id(self, external_message_id: str) -> Optional[DeliveryReceipt]:
+ """根据外部平台消息 ID 查询已完成回执。
+
+ Args:
+ external_message_id: 要查询的平台侧消息 ID。
+
+ Returns:
+ Optional[DeliveryReceipt]: 若存在对应回执,则返回该回执。
+ """
+ self._cleanup_expired(time.monotonic())
+ stored_receipt = self._receipts_by_external_id.get(external_message_id)
+ return stored_receipt.receipt if stored_receipt else None
+
+ def clear(self) -> None:
+ """清空全部待完成记录与已保存回执。"""
+ self._pending.clear()
+ self._pending_expire_heap.clear()
+ self._receipts_by_external_id.clear()
+ self._receipt_expire_heap.clear()
+
+ def _cleanup_expired(self, now: float) -> None:
+ """清理内存中已经过期的待完成记录与已保存回执。
+
+ Args:
+ now: 当前单调时钟时间戳。
+ """
+ self._cleanup_expired_pending(now)
+ self._cleanup_expired_receipts(now)
+
+ def _cleanup_expired_pending(self, now: float) -> None:
+ """清理已经过期的待完成记录。
+
+ Args:
+ now: 当前单调时钟时间戳。
+
+ Notes:
+ 堆中可能存在已经失效的旧节点。例如某条记录提前 ``finish`` 后,
+ 它原本的过期节点仍可能留在堆里。这里会通过和 ``dict`` 中当前记录的
+ ``expires_at`` 对比,跳过这类旧节点。
+ """
+ while self._pending_expire_heap and self._pending_expire_heap[0][0] <= now:
+ expires_at, internal_message_id = heapq.heappop(self._pending_expire_heap)
+ current_record = self._pending.get(internal_message_id)
+ if current_record is None:
+ continue
+ if current_record.expires_at != expires_at:
+ continue
+ self._pending.pop(internal_message_id, None)
+
+ def _cleanup_expired_receipts(self, now: float) -> None:
+ """清理已经过期的回执索引。
+
+ Args:
+ now: 当前单调时钟时间戳。
+
+ Notes:
+ 同一个 ``external_message_id`` 在极端情况下可能被重复写入索引,
+ 因此这里同样需要通过 ``expires_at`` 和当前 ``dict`` 中的值比对,
+ 跳过已经失效的旧堆节点。
+ """
+ while self._receipt_expire_heap and self._receipt_expire_heap[0][0] <= now:
+ expires_at, external_message_id = heapq.heappop(self._receipt_expire_heap)
+ current_receipt = self._receipts_by_external_id.get(external_message_id)
+ if current_receipt is None:
+ continue
+ if current_receipt.expires_at != expires_at:
+ continue
+ self._receipts_by_external_id.pop(external_message_id, None)
diff --git a/src/platform_io/registry.py b/src/platform_io/registry.py
new file mode 100644
index 00000000..9ad8ea8a
--- /dev/null
+++ b/src/platform_io/registry.py
@@ -0,0 +1,70 @@
+"""提供 Platform IO 的驱动注册与查询能力。"""
+
+from typing import Dict, List, Optional
+
+from src.platform_io.drivers.base import PlatformIODriver
+from src.platform_io.types import DriverKind
+
+
+class DriverRegistry:
+ """集中保存已注册的 Platform IO 驱动,并提供基础查询接口。"""
+
+ def __init__(self) -> None:
+ """初始化一个空的驱动注册表。"""
+ self._drivers: Dict[str, PlatformIODriver] = {}
+
+ def register(self, driver: PlatformIODriver) -> None:
+ """注册一个驱动实例。
+
+ Args:
+ driver: 要注册的驱动实例。
+
+ Raises:
+ ValueError: 当驱动 ID 已经存在时抛出。
+ """
+ if driver.driver_id in self._drivers:
+ raise ValueError(f"驱动 {driver.driver_id} 已注册")
+ self._drivers[driver.driver_id] = driver
+
+ def unregister(self, driver_id: str) -> Optional[PlatformIODriver]:
+ """按驱动 ID 注销一个驱动。
+
+ Args:
+ driver_id: 要移除的驱动 ID。
+
+ Returns:
+ Optional[PlatformIODriver]: 若驱动存在,则返回被移除的驱动实例。
+ """
+ return self._drivers.pop(driver_id, None)
+
+ def get(self, driver_id: str) -> Optional[PlatformIODriver]:
+ """按驱动 ID 获取驱动实例。
+
+ Args:
+ driver_id: 要查询的驱动 ID。
+
+ Returns:
+ Optional[PlatformIODriver]: 若存在匹配驱动,则返回该驱动实例。
+ """
+ return self._drivers.get(driver_id)
+
+ def list(self, *, kind: Optional[DriverKind] = None, platform: Optional[str] = None) -> List[PlatformIODriver]:
+ """列出已注册驱动,并支持可选过滤。
+
+ Args:
+ kind: 可选的驱动类型过滤条件。
+ platform: 可选的平台名称过滤条件。
+
+ Returns:
+ List[PlatformIODriver]: 符合过滤条件的驱动列表。
+ """
+ drivers = list(self._drivers.values())
+ if kind is not None:
+ drivers = [driver for driver in drivers if driver.descriptor.kind == kind]
+ if platform is not None:
+ drivers = [driver for driver in drivers if driver.descriptor.platform == platform]
+ return drivers
+
+ def clear(self) -> None:
+ """清空全部已注册驱动。"""
+ self._drivers.clear()
diff --git a/src/platform_io/route_key_factory.py b/src/platform_io/route_key_factory.py
new file mode 100644
index 00000000..05bac6e8
--- /dev/null
+++ b/src/platform_io/route_key_factory.py
@@ -0,0 +1,150 @@
+"""提供 Platform IO 路由键的统一提取与构造能力。
+
+这层的目标不是直接接入具体消息链,而是先把“未来接线时用什么字段构造
+RouteKey”约定下来,避免 legacy 和 plugin 两条链路各自发明一套隐式规则。
+"""
+
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
+
+from .types import RouteKey
+
+if TYPE_CHECKING:
+ from src.chat.message_receive.message import SessionMessage
+
+
+class RouteKeyFactory:
+ """统一构造 ``RouteKey`` 的工厂。
+
+ 当前约定会优先从消息字典顶层、``message_info``、``additional_config`` 或传入 metadata 中提取
+ 以下字段:
+
+ - account_id: ``platform_io_account_id`` / ``account_id`` / ``self_id`` / ``bot_account``
+ - scope: ``platform_io_scope`` / ``route_scope`` / ``adapter_scope`` / ``connection_id``
+
+ 这样即使上游主链暂时还没有正式的 ``self_id`` 字段,中间层也能先统一
+ 约定提取口径,等具体消息链接入时直接复用。
+ """
+
+ ACCOUNT_ID_KEYS = (
+ "platform_io_account_id",
+ "account_id",
+ "self_id",
+ "bot_account",
+ )
+ SCOPE_KEYS = (
+ "platform_io_scope",
+ "route_scope",
+ "adapter_scope",
+ "connection_id",
+ )
+
+ @classmethod
+ def from_platform(
+ cls,
+ platform: str,
+ *,
+ account_id: Optional[str] = None,
+ scope: Optional[str] = None,
+ metadata: Optional[Dict[str, Any]] = None,
+ ) -> RouteKey:
+ """根据平台名和可选 metadata 构造 ``RouteKey``。
+
+ Args:
+ platform: 平台名称。
+ account_id: 显式传入的账号 ID;若为空,则尝试从 metadata 提取。
+ scope: 显式传入的路由作用域;若为空,则尝试从 metadata 提取。
+ metadata: 可选的元数据字典。
+
+ Returns:
+ RouteKey: 构造出的规范化路由键。
+ """
+ extracted_account_id, extracted_scope = cls.extract_components(metadata)
+ return RouteKey(
+ platform=platform,
+ account_id=account_id or extracted_account_id,
+ scope=scope or extracted_scope,
+ )
+
+ @classmethod
+ def from_message_dict(cls, message_dict: Dict[str, Any]) -> RouteKey:
+ """从消息字典中提取 ``RouteKey``。
+
+ Args:
+ message_dict: Host 与插件之间传输的消息字典。
+
+ Returns:
+ RouteKey: 构造出的规范化路由键。
+
+ Raises:
+ ValueError: 当消息字典缺少有效 ``platform`` 字段时抛出。
+ """
+ platform = str(message_dict.get("platform") or "").strip()
+ if not platform:
+ raise ValueError("消息字典缺少有效的 platform 字段,无法构造 RouteKey")
+
+ message_info = message_dict.get("message_info", {})
+ additional_config = {}
+ if isinstance(message_info, dict):
+ raw_additional_config = message_info.get("additional_config", {})
+ if isinstance(raw_additional_config, dict):
+ additional_config = raw_additional_config
+
+ explicit_account_id, explicit_scope = cls.extract_components(message_dict)
+ message_info_account_id, message_info_scope = cls.extract_components(message_info)
+ metadata_account_id, metadata_scope = cls.extract_components(additional_config)
+ return RouteKey(
+ platform=platform,
+ account_id=explicit_account_id or message_info_account_id or metadata_account_id,
+ scope=explicit_scope or message_info_scope or metadata_scope,
+ )
+
+ @classmethod
+ def from_session_message(cls, message: "SessionMessage") -> RouteKey:
+ """从 ``SessionMessage`` 中提取 ``RouteKey``。
+
+ Args:
+ message: 内部会话消息对象。
+
+ Returns:
+ RouteKey: 构造出的规范化路由键。
+ """
+ additional_config = message.message_info.additional_config or {}
+ metadata = additional_config if isinstance(additional_config, dict) else {}
+ return cls.from_platform(message.platform, metadata=metadata)
+
+ @classmethod
+ def extract_components(cls, mapping: Optional[Dict[str, Any]]) -> Tuple[Optional[str], Optional[str]]:
+ """从任意字典中提取 ``account_id`` 与 ``scope``。
+
+ Args:
+ mapping: 待提取的字典;若为空或不是字典,则返回空结果。
+
+ Returns:
+ Tuple[Optional[str], Optional[str]]: ``(account_id, scope)``。
+ """
+ if not mapping or not isinstance(mapping, dict):
+ return None, None
+
+ account_id = cls._pick_string(mapping, cls.ACCOUNT_ID_KEYS)
+ scope = cls._pick_string(mapping, cls.SCOPE_KEYS)
+ return account_id, scope
+
+ @staticmethod
+ def _pick_string(mapping: Dict[str, Any], keys: Tuple[str, ...]) -> Optional[str]:
+ """按优先级从字典里挑选第一个有效字符串。
+
+ Args:
+ mapping: 待查询的字典。
+ keys: 按优先级排列的候选键名。
+
+ Returns:
+ Optional[str]: 第一个规范化后非空的字符串值;若不存在则返回 ``None``。
+ """
+ for key in keys:
+ value = mapping.get(key)
+ if value is None:
+ continue
+ normalized = str(value).strip()
+ if normalized:
+ return normalized
+ return None
diff --git a/src/platform_io/routing.py b/src/platform_io/routing.py
new file mode 100644
index 00000000..7f85bbfa
--- /dev/null
+++ b/src/platform_io/routing.py
@@ -0,0 +1,202 @@
+"""提供 Platform IO 的路由绑定存储与归属解析能力。"""
+
+from typing import Dict, List, Optional
+
+from .types import RouteBinding, RouteKey, RouteMode
+
+
+class RouteBindingConflictError(ValueError):
+ """当同一路由键出现多个 active owner 竞争时抛出。"""
+
+
+class RouteTable:
+ """维护路由绑定并解析路由归属。
+
+ 这个表刻意保持轻量,只负责归属规则本身,不掺杂具体发送或接收逻辑。
+ 它决定某个路由键当前由哪个驱动 active 接管,哪些驱动仅以 shadow
+ 方式旁路观测。
+ """
+
+ def __init__(self) -> None:
+ """初始化一个空的路由绑定表。"""
+ self._bindings: Dict[RouteKey, Dict[str, RouteBinding]] = {}
+
+ def bind(self, binding: RouteBinding, *, replace: bool = False) -> None:
+ """注册或更新一条路由绑定。
+
+ Args:
+ binding: 要注册的绑定对象。
+ replace: 当精确路由键上已经存在 active owner 时,是否允许替换。
+
+ Raises:
+ RouteBindingConflictError: 当精确路由键上已存在其他 active owner,
+ 且 ``replace`` 为 ``False`` 时抛出。
+ """
+
+ if binding.mode == RouteMode.DISABLED:
+ self.unbind(binding.route_key, binding.driver_id)
+ return
+
+ if binding.mode == RouteMode.ACTIVE:
+ active_binding = self.get_active_binding(binding.route_key, exact_only=True)
+ if active_binding and active_binding.driver_id != binding.driver_id:
+ if not replace:
+ raise RouteBindingConflictError(
+ f"RouteKey {binding.route_key} 已由 {active_binding.driver_id} 接管,"
+ f"拒绝绑定到 {binding.driver_id}"
+ )
+ self.unbind(binding.route_key, active_binding.driver_id)
+
+ self._bindings.setdefault(binding.route_key, {})[binding.driver_id] = binding
+
+ def unbind(self, route_key: RouteKey, driver_id: Optional[str] = None) -> List[RouteBinding]:
+ """移除指定路由键上的绑定。
+
+ Args:
+ route_key: 要移除绑定的路由键。
+ driver_id: 可选的特定驱动 ID;若为空,则移除该路由键上的全部绑定。
+
+ Returns:
+ List[RouteBinding]: 被移除的绑定列表。
+ """
+
+ binding_map = self._bindings.get(route_key)
+ if not binding_map:
+ return []
+
+ if driver_id is None:
+ removed = list(binding_map.values())
+ self._bindings.pop(route_key, None)
+ return removed
+
+ removed_binding = binding_map.pop(driver_id, None)
+ if not binding_map:
+ self._bindings.pop(route_key, None)
+ return [removed_binding] if removed_binding else []
+
+ def remove_bindings_by_driver(self, driver_id: str) -> List[RouteBinding]:
+ """移除某个驱动在所有路由键上的绑定。
+
+ Args:
+ driver_id: 要移除绑定的驱动 ID。
+
+ Returns:
+ List[RouteBinding]: 被移除的绑定列表。
+ """
+ removed_bindings: List[RouteBinding] = []
+ empty_route_keys: List[RouteKey] = []
+
+ for route_key, binding_map in self._bindings.items():
+ removed_binding = binding_map.pop(driver_id, None)
+ if removed_binding is not None:
+ removed_bindings.append(removed_binding)
+ if not binding_map:
+ empty_route_keys.append(route_key)
+
+ for route_key in empty_route_keys:
+ self._bindings.pop(route_key, None)
+
+ return self._sort_bindings(removed_bindings)
+
+ def list_bindings(self, route_key: Optional[RouteKey] = None) -> List[RouteBinding]:
+ """列出当前绑定。
+
+ Args:
+ route_key: 可选的路由键过滤条件;若为空,则返回全部路由键上的绑定。
+
+ Returns:
+ List[RouteBinding]: 按优先级降序排列的绑定列表。
+ """
+
+ if route_key is None:
+ bindings: List[RouteBinding] = []
+ for binding_map in self._bindings.values():
+ bindings.extend(binding_map.values())
+ return self._sort_bindings(bindings)
+
+ binding_map = self._bindings.get(route_key, {})
+ return self._sort_bindings(list(binding_map.values()))
+
+ def get_active_binding(self, route_key: RouteKey, *, exact_only: bool = False) -> Optional[RouteBinding]:
+ """获取某个路由键当前生效的 active 绑定。
+
+ Args:
+ route_key: 要解析的路由键。
+ exact_only: 是否只检查精确路由键而不做回退解析。
+
+ Returns:
+ Optional[RouteBinding]: 若存在 active owner,则返回对应绑定。
+ """
+
+ candidate_keys = [route_key] if exact_only else route_key.resolution_order()
+ for candidate_key in candidate_keys:
+ binding_map = self._bindings.get(candidate_key, {})
+ active_binding = self._pick_best_binding(binding_map, RouteMode.ACTIVE)
+ if active_binding is not None:
+ return active_binding
+ return None
+
+ def get_shadow_bindings(self, route_key: RouteKey) -> List[RouteBinding]:
+ """获取某个精确路由键上的 shadow 绑定。
+
+ Args:
+ route_key: 要查看的路由键。
+
+ Returns:
+ List[RouteBinding]: 按优先级降序排列的 shadow 绑定列表。
+ """
+ binding_map = self._bindings.get(route_key, {})
+ shadow_bindings = [binding for binding in binding_map.values() if binding.mode == RouteMode.SHADOW]
+ return self._sort_bindings(shadow_bindings)
+
+ def accepts_inbound(self, route_key: RouteKey, driver_id: str) -> bool:
+ """判断某个驱动是否是当前允许入 Core 的 active owner。
+
+ Args:
+ route_key: 入站消息对应的路由键。
+ driver_id: 希望将消息送入 Core 的驱动 ID。
+
+ Returns:
+ bool: 若该驱动是解析结果中的 active owner,则返回 ``True``。
+ """
+
+ active_binding = self.get_active_binding(route_key)
+ return active_binding is not None and active_binding.driver_id == driver_id
+
+ @staticmethod
+ def _sort_bindings(bindings: List[RouteBinding]) -> List[RouteBinding]:
+ """按优先级降序排列绑定列表。
+
+ Args:
+ bindings: 待排序的绑定列表。
+
+ Returns:
+ List[RouteBinding]: 排序后的绑定列表。
+ """
+ return sorted(bindings, key=lambda item: item.priority, reverse=True)
+
+ @staticmethod
+ def _pick_best_binding(
+ binding_map: Dict[str, RouteBinding],
+ mode: RouteMode,
+ ) -> Optional[RouteBinding]:
+ """从绑定映射中挑选指定模式下优先级最高的一条绑定。
+
+ Args:
+ binding_map: 某个精确 ``RouteKey`` 对应的绑定映射。
+ mode: 需要挑选的绑定模式。
+
+ Returns:
+ Optional[RouteBinding]: 若存在匹配模式的绑定,则返回优先级最高的一条。
+
+ Notes:
+ 这里使用单次线性扫描代替“先过滤成列表再排序”的做法,以减少
+ 高频路由解析路径上的临时对象分配和排序开销。
+ """
+ best_binding: Optional[RouteBinding] = None
+ for binding in binding_map.values():
+ if binding.mode != mode:
+ continue
+ if best_binding is None or binding.priority > best_binding.priority:
+ best_binding = binding
+ return best_binding
diff --git a/src/platform_io/types.py b/src/platform_io/types.py
new file mode 100644
index 00000000..c74dc246
--- /dev/null
+++ b/src/platform_io/types.py
@@ -0,0 +1,240 @@
+"""定义 Platform IO 中间层共享的核心类型。
+
+本模块放置路由、驱动、入站与出站等规范化数据结构,供 Broker
+层在 legacy 适配器链路和 plugin 适配器链路之间复用。
+"""
+
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+if TYPE_CHECKING:
+ from src.chat.message_receive.message import SessionMessage
+
+
+class DriverKind(str, Enum):
+ """底层收发驱动类型枚举。"""
+
+ LEGACY = "legacy"
+ PLUGIN = "plugin"
+
+
+class RouteMode(str, Enum):
+ """路由归属模式枚举。"""
+
+ ACTIVE = "active"
+ SHADOW = "shadow"
+ DISABLED = "disabled"
+
+
+class DeliveryStatus(str, Enum):
+ """统一出站回执状态枚举。"""
+
+ PENDING = "pending"
+ SENT = "sent"
+ FAILED = "failed"
+ DROPPED = "dropped"
+
+
+@dataclass(frozen=True, slots=True)
+class RouteKey:
+ """用于 Platform IO 路由决策的唯一键。
+
+ 路由解析会按照“从最具体到最宽泛”的顺序进行回退,这样同一平台
+ 后续就能自然支持按账号、自定义 scope 等更细粒度的归属控制。
+
+ Attributes:
+ platform: 平台名称,例如 ``qq``。
+ account_id: 机器人账号 ID 或 self ID,用于区分同平台多身份。
+ scope: 额外路由作用域,预留给未来的连接实例、租户或子通道等维度。
+ """
+
+ platform: str
+ account_id: Optional[str] = None
+ scope: Optional[str] = None
+
+ def __post_init__(self) -> None:
+ """规范化并校验路由键字段。
+
+ Raises:
+ ValueError: 当 ``platform`` 规范化后为空时抛出。
+ """
+ platform = str(self.platform).strip()
+ account_id = str(self.account_id).strip() if self.account_id is not None else None
+ scope = str(self.scope).strip() if self.scope is not None else None
+
+ if not platform:
+ raise ValueError("RouteKey.platform 不能为空")
+
+ object.__setattr__(self, "platform", platform)
+ object.__setattr__(self, "account_id", account_id or None)
+ object.__setattr__(self, "scope", scope or None)
+
+ def resolution_order(self) -> List["RouteKey"]:
+ """返回从最具体到最宽泛的路由匹配顺序。
+
+ Returns:
+ List[RouteKey]: 按回退优先级排序的候选路由键列表。
+ """
+
+ keys: List[RouteKey] = [self]
+
+ if self.account_id is not None and self.scope is not None:
+ keys.append(RouteKey(platform=self.platform, account_id=self.account_id, scope=None))
+ keys.append(RouteKey(platform=self.platform, account_id=None, scope=self.scope))
+ elif self.account_id is not None:
+ keys.append(RouteKey(platform=self.platform, account_id=None, scope=None))
+ elif self.scope is not None:
+ keys.append(RouteKey(platform=self.platform, account_id=None, scope=None))
+
+ default_key = RouteKey(platform=self.platform, account_id=None, scope=None)
+ if default_key not in keys:
+ keys.append(default_key)
+
+ return keys
+
+ def to_dedupe_scope(self) -> str:
+ """生成跨驱动共享的去重作用域字符串。
+
+ Returns:
+ str: 用于入站消息去重的稳定文本作用域键。
+ """
+
+ account_id = self.account_id or "*"
+ scope = self.scope or "*"
+ return f"{self.platform}:{account_id}:{scope}"
+
+
+@dataclass(frozen=True, slots=True)
+class DriverDescriptor:
+ """描述一个已注册的 Platform IO 驱动。
+
+ Attributes:
+ driver_id: Broker 层内全局唯一的驱动标识。
+ kind: 驱动实现类型,例如 legacy 或 plugin。
+ platform: 驱动负责的平台名称。
+ account_id: 可选的账号 ID 或 self ID。
+ scope: 可选的额外路由作用域。
+ plugin_id: 当驱动来自插件适配器时,对应的插件 ID。
+ metadata: 预留给路由策略或观测能力的额外驱动元数据。
+ """
+
+ driver_id: str
+ kind: DriverKind
+ platform: str
+ account_id: Optional[str] = None
+ scope: Optional[str] = None
+ plugin_id: Optional[str] = None
+ metadata: Dict[str, Any] = field(default_factory=dict)
+
+ def __post_init__(self) -> None:
+ """规范化并校验驱动描述字段。
+
+ Raises:
+ ValueError: 当 ``driver_id`` 或 ``platform`` 规范化后为空时抛出。
+ """
+ driver_id = str(self.driver_id).strip()
+ platform = str(self.platform).strip()
+ plugin_id = str(self.plugin_id).strip() if self.plugin_id is not None else None
+
+ if not driver_id:
+ raise ValueError("DriverDescriptor.driver_id 不能为空")
+ if not platform:
+ raise ValueError("DriverDescriptor.platform 不能为空")
+
+ object.__setattr__(self, "driver_id", driver_id)
+ object.__setattr__(self, "platform", platform)
+ object.__setattr__(self, "plugin_id", plugin_id or None)
+
+ @property
+ def route_key(self) -> RouteKey:
+ """构造该驱动默认代表的路由键。
+
+ Returns:
+ RouteKey: 当前驱动描述对应的规范化路由键。
+ """
+ return RouteKey(platform=self.platform, account_id=self.account_id, scope=self.scope)
+
+
+@dataclass(frozen=True, slots=True)
+class RouteBinding:
+ """表示一条从路由键到驱动的归属绑定关系。
+
+ Attributes:
+ route_key: 该绑定覆盖的路由键。
+ driver_id: 拥有或旁路观察该路由的驱动 ID。
+ driver_kind: 绑定驱动的类型。
+ mode: 绑定模式,例如 active owner 或 shadow observer。
+ priority: 当同模式下存在多条绑定时使用的相对优先级。
+ metadata: 预留给未来路由策略的额外绑定元数据。
+ """
+
+ route_key: RouteKey
+ driver_id: str
+ driver_kind: DriverKind
+ mode: RouteMode = RouteMode.ACTIVE
+ priority: int = 0
+ metadata: Dict[str, Any] = field(default_factory=dict)
+
+ def __post_init__(self) -> None:
+ """规范化并校验绑定字段。
+
+ Raises:
+ ValueError: 当 ``driver_id`` 规范化后为空时抛出。
+ """
+ driver_id = str(self.driver_id).strip()
+ if not driver_id:
+ raise ValueError("RouteBinding.driver_id 不能为空")
+ object.__setattr__(self, "driver_id", driver_id)
+
+
+@dataclass(slots=True)
+class InboundMessageEnvelope:
+ """封装一次由驱动产出的规范化入站消息。
+
+ Attributes:
+ route_key: 该入站消息解析出的路由键。
+ driver_id: 产出该消息的驱动 ID。
+ driver_kind: 产出该消息的驱动类型。
+ external_message_id: 可选的平台侧消息 ID,用于去重。
+ dedupe_key: 可选的显式去重键。当外部消息没有稳定 ``message_id`` 时,
+ 可由上游驱动提供消息指纹。若这里为空,中间层仍可能继续回退到
+ ``session_message.message_id`` 或 ``payload`` 指纹。
+ session_message: 可选的、已经完成规范化的 ``SessionMessage`` 对象。
+ payload: 可选的原始字典载荷,供延迟转换或调试使用。
+ metadata: 额外入站元数据,例如连接信息或追踪上下文。
+ """
+
+ route_key: RouteKey
+ driver_id: str
+ driver_kind: DriverKind
+ external_message_id: Optional[str] = None
+ dedupe_key: Optional[str] = None
+ session_message: Optional["SessionMessage"] = None
+ payload: Optional[Dict[str, Any]] = None
+ metadata: Dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass(slots=True)
+class DeliveryReceipt:
+ """表示一次出站投递尝试的统一结果。
+
+ Attributes:
+ internal_message_id: Broker 跟踪的内部 ``SessionMessage.message_id``。
+ route_key: 本次投递使用的路由键。
+ status: 规范化后的投递状态。
+ driver_id: 实际处理该投递的驱动 ID,可为空。
+ driver_kind: 实际处理该投递的驱动类型,可为空。
+ external_message_id: 驱动或适配器返回的平台侧消息 ID,可为空。
+ error: 投递失败时的错误信息,可为空。
+ metadata: 预留给回执、时间戳或平台特有信息的额外元数据。
+ """
+
+ internal_message_id: str
+ route_key: RouteKey
+ status: DeliveryStatus
+ driver_id: Optional[str] = None
+ driver_kind: Optional[DriverKind] = None
+ external_message_id: Optional[str] = None
+ error: Optional[str] = None
+ metadata: Dict[str, Any] = field(default_factory=dict)
From 07256182fbf2f649bf55bf68d2ac2ac3bc0e1593 Mon Sep 17 00:00:00 2001
From: DrSmoothl <1787882683@qq.com>
Date: Fri, 20 Mar 2026 01:20:22 +0800
Subject: [PATCH 16/45] =?UTF-8?q?refactor(manager):=20=E4=BD=BF=E7=94=A8?=
=?UTF-8?q?=20List=20=E7=B1=BB=E5=9E=8B=E6=9B=BF=E4=BB=A3=20list=EF=BC=8C?=
=?UTF-8?q?=E5=A2=9E=E5=BC=BA=E7=B1=BB=E5=9E=8B=E4=B8=80=E8=87=B4=E6=80=A7?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
src/platform_io/manager.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/src/platform_io/manager.py b/src/platform_io/manager.py
index 6135a567..97835667 100644
--- a/src/platform_io/manager.py
+++ b/src/platform_io/manager.py
@@ -1,6 +1,6 @@
"""提供 Platform IO 层的中心 Broker 管理器。"""
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional
import hashlib
import json
@@ -59,7 +59,7 @@ class PlatformIOManager:
if self._started:
return
- started_drivers: list[PlatformIODriver] = []
+ started_drivers: List[PlatformIODriver] = []
try:
for driver in self._driver_registry.list():
await driver.start()
@@ -86,7 +86,7 @@ class PlatformIOManager:
if not self._started:
return
- stop_errors: list[str] = []
+ stop_errors: List[str] = []
for driver in reversed(self._driver_registry.list()):
try:
await driver.stop()
From e4850c469feb68c8bd0dfbb66593b5dd959fbdd6 Mon Sep 17 00:00:00 2001
From: DrSmoothl <1787882683@qq.com>
Date: Fri, 20 Mar 2026 22:23:47 +0800
Subject: [PATCH 17/45] feat: Enhance plugin loading and management
- Added module_name parameter to PluginMeta for better module tracking.
- Improved documentation for PluginMeta and PluginLoader methods.
- Introduced methods for managing loaded plugins: set_loaded_plugin, remove_loaded_plugin, and purge_plugin_modules.
- Enhanced dependency resolution in PluginLoader with resolve_dependencies method.
- Implemented candidate discovery and loading in PluginLoader.
- Added support for plugin reloading with _reload_plugin_by_id in PluginRunner.
- Improved error handling and logging throughout the RPCClient and PluginRunner.
- Added support for handling hook invocations in PluginRunner.
- Refactored plugin registration and unregistration processes for clarity and efficiency.
---
src/plugin_runtime/capabilities/components.py | 15 +-
src/plugin_runtime/capabilities/registry.py | 115 ++--
src/plugin_runtime/host/capability_service.py | 21 +-
src/plugin_runtime/host/supervisor.py | 590 +++++++++++++++---
src/plugin_runtime/integration.py | 30 +-
src/plugin_runtime/protocol/envelope.py | 39 +-
src/plugin_runtime/runner/plugin_loader.py | 188 +++++-
src/plugin_runtime/runner/rpc_client.py | 241 +++----
src/plugin_runtime/runner/runner_main.py | 445 +++++++++++--
9 files changed, 1351 insertions(+), 333 deletions(-)
diff --git a/src/plugin_runtime/capabilities/components.py b/src/plugin_runtime/capabilities/components.py
index aa7ceb46..4223525f 100644
--- a/src/plugin_runtime/capabilities/components.py
+++ b/src/plugin_runtime/capabilities/components.py
@@ -174,7 +174,10 @@ class RuntimeComponentCapabilityMixin:
if registered_supervisor is not None:
try:
- reloaded = await registered_supervisor.reload_plugins(reason=f"load {plugin_name}")
+ reloaded = await registered_supervisor.reload_plugins(
+ plugin_ids=[plugin_name],
+ reason=f"load {plugin_name}",
+ )
if reloaded:
return {"success": True, "count": 1}
return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
@@ -186,7 +189,10 @@ class RuntimeComponentCapabilityMixin:
for pdir in sv._plugin_dirs:
if (pdir / plugin_name).is_dir():
try:
- reloaded = await sv.reload_plugins(reason=f"load {plugin_name}")
+ reloaded = await sv.reload_plugins(
+ plugin_ids=[plugin_name],
+ reason=f"load {plugin_name}",
+ )
if reloaded:
return {"success": True, "count": 1}
return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
@@ -222,7 +228,10 @@ class RuntimeComponentCapabilityMixin:
if sv is not None:
try:
- reloaded = await sv.reload_plugins(reason=f"reload {plugin_name}")
+ reloaded = await sv.reload_plugins(
+ plugin_ids=[plugin_name],
+ reason=f"reload {plugin_name}",
+ )
if reloaded:
return {"success": True}
return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
diff --git a/src/plugin_runtime/capabilities/registry.py b/src/plugin_runtime/capabilities/registry.py
index abce97dc..ead5876a 100644
--- a/src/plugin_runtime/capabilities/registry.py
+++ b/src/plugin_runtime/capabilities/registry.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
from src.common.logger import get_logger
from src.plugin_runtime.host.supervisor import PluginSupervisor
@@ -12,67 +12,78 @@ logger = get_logger("plugin_runtime.integration")
def register_capability_impls(manager: "PluginRuntimeManager", supervisor: PluginSupervisor) -> None:
"""向指定 Supervisor 注册主程序提供的能力实现。"""
cap_service = supervisor.capability_service
+ rpc_server = supervisor.rpc_server
- cap_service.register_capability("send.text", manager._cap_send_text)
- cap_service.register_capability("send.emoji", manager._cap_send_emoji)
- cap_service.register_capability("send.image", manager._cap_send_image)
- cap_service.register_capability("send.command", manager._cap_send_command)
- cap_service.register_capability("send.custom", manager._cap_send_custom)
+ def _register(name: str, impl: Any) -> None:
+ """注册单个能力实现及其 RPC 入口。
- cap_service.register_capability("llm.generate", manager._cap_llm_generate)
- cap_service.register_capability("llm.generate_with_tools", manager._cap_llm_generate_with_tools)
- cap_service.register_capability("llm.get_available_models", manager._cap_llm_get_available_models)
+ Args:
+ name: 能力名称。
+ impl: 能力实现函数。
+ """
+ cap_service.register_capability(name, impl)
+ rpc_server.register_method(name, cap_service.handle_capability_request)
- cap_service.register_capability("config.get", manager._cap_config_get)
- cap_service.register_capability("config.get_plugin", manager._cap_config_get_plugin)
- cap_service.register_capability("config.get_all", manager._cap_config_get_all)
+ _register("send.text", manager._cap_send_text)
+ _register("send.emoji", manager._cap_send_emoji)
+ _register("send.image", manager._cap_send_image)
+ _register("send.command", manager._cap_send_command)
+ _register("send.custom", manager._cap_send_custom)
- cap_service.register_capability("database.query", manager._cap_database_query)
- cap_service.register_capability("database.save", manager._cap_database_save)
- cap_service.register_capability("database.get", manager._cap_database_get)
- cap_service.register_capability("database.delete", manager._cap_database_delete)
- cap_service.register_capability("database.count", manager._cap_database_count)
+ _register("llm.generate", manager._cap_llm_generate)
+ _register("llm.generate_with_tools", manager._cap_llm_generate_with_tools)
+ _register("llm.get_available_models", manager._cap_llm_get_available_models)
- cap_service.register_capability("chat.get_all_streams", manager._cap_chat_get_all_streams)
- cap_service.register_capability("chat.get_group_streams", manager._cap_chat_get_group_streams)
- cap_service.register_capability("chat.get_private_streams", manager._cap_chat_get_private_streams)
- cap_service.register_capability("chat.get_stream_by_group_id", manager._cap_chat_get_stream_by_group_id)
- cap_service.register_capability("chat.get_stream_by_user_id", manager._cap_chat_get_stream_by_user_id)
+ _register("config.get", manager._cap_config_get)
+ _register("config.get_plugin", manager._cap_config_get_plugin)
+ _register("config.get_all", manager._cap_config_get_all)
- cap_service.register_capability("message.get_by_time", manager._cap_message_get_by_time)
- cap_service.register_capability("message.get_by_time_in_chat", manager._cap_message_get_by_time_in_chat)
- cap_service.register_capability("message.get_recent", manager._cap_message_get_recent)
- cap_service.register_capability("message.count_new", manager._cap_message_count_new)
- cap_service.register_capability("message.build_readable", manager._cap_message_build_readable)
+ _register("database.query", manager._cap_database_query)
+ _register("database.save", manager._cap_database_save)
+ _register("database.get", manager._cap_database_get)
+ _register("database.delete", manager._cap_database_delete)
+ _register("database.count", manager._cap_database_count)
- cap_service.register_capability("person.get_id", manager._cap_person_get_id)
- cap_service.register_capability("person.get_value", manager._cap_person_get_value)
- cap_service.register_capability("person.get_id_by_name", manager._cap_person_get_id_by_name)
+ _register("chat.get_all_streams", manager._cap_chat_get_all_streams)
+ _register("chat.get_group_streams", manager._cap_chat_get_group_streams)
+ _register("chat.get_private_streams", manager._cap_chat_get_private_streams)
+ _register("chat.get_stream_by_group_id", manager._cap_chat_get_stream_by_group_id)
+ _register("chat.get_stream_by_user_id", manager._cap_chat_get_stream_by_user_id)
- cap_service.register_capability("emoji.get_by_description", manager._cap_emoji_get_by_description)
- cap_service.register_capability("emoji.get_random", manager._cap_emoji_get_random)
- cap_service.register_capability("emoji.get_count", manager._cap_emoji_get_count)
- cap_service.register_capability("emoji.get_emotions", manager._cap_emoji_get_emotions)
- cap_service.register_capability("emoji.get_all", manager._cap_emoji_get_all)
- cap_service.register_capability("emoji.get_info", manager._cap_emoji_get_info)
- cap_service.register_capability("emoji.register", manager._cap_emoji_register)
- cap_service.register_capability("emoji.delete", manager._cap_emoji_delete)
+ _register("message.get_by_time", manager._cap_message_get_by_time)
+ _register("message.get_by_time_in_chat", manager._cap_message_get_by_time_in_chat)
+ _register("message.get_recent", manager._cap_message_get_recent)
+ _register("message.count_new", manager._cap_message_count_new)
+ _register("message.build_readable", manager._cap_message_build_readable)
- cap_service.register_capability("frequency.get_current_talk_value", manager._cap_frequency_get_current_talk_value)
- cap_service.register_capability("frequency.set_adjust", manager._cap_frequency_set_adjust)
- cap_service.register_capability("frequency.get_adjust", manager._cap_frequency_get_adjust)
+ _register("person.get_id", manager._cap_person_get_id)
+ _register("person.get_value", manager._cap_person_get_value)
+ _register("person.get_id_by_name", manager._cap_person_get_id_by_name)
- cap_service.register_capability("tool.get_definitions", manager._cap_tool_get_definitions)
+ _register("emoji.get_by_description", manager._cap_emoji_get_by_description)
+ _register("emoji.get_random", manager._cap_emoji_get_random)
+ _register("emoji.get_count", manager._cap_emoji_get_count)
+ _register("emoji.get_emotions", manager._cap_emoji_get_emotions)
+ _register("emoji.get_all", manager._cap_emoji_get_all)
+ _register("emoji.get_info", manager._cap_emoji_get_info)
+ _register("emoji.register", manager._cap_emoji_register)
+ _register("emoji.delete", manager._cap_emoji_delete)
- cap_service.register_capability("component.get_all_plugins", manager._cap_component_get_all_plugins)
- cap_service.register_capability("component.get_plugin_info", manager._cap_component_get_plugin_info)
- cap_service.register_capability("component.list_loaded_plugins", manager._cap_component_list_loaded_plugins)
- cap_service.register_capability("component.list_registered_plugins", manager._cap_component_list_registered_plugins)
- cap_service.register_capability("component.enable", manager._cap_component_enable)
- cap_service.register_capability("component.disable", manager._cap_component_disable)
- cap_service.register_capability("component.load_plugin", manager._cap_component_load_plugin)
- cap_service.register_capability("component.unload_plugin", manager._cap_component_unload_plugin)
- cap_service.register_capability("component.reload_plugin", manager._cap_component_reload_plugin)
+ _register("frequency.get_current_talk_value", manager._cap_frequency_get_current_talk_value)
+ _register("frequency.set_adjust", manager._cap_frequency_set_adjust)
+ _register("frequency.get_adjust", manager._cap_frequency_get_adjust)
- cap_service.register_capability("knowledge.search", manager._cap_knowledge_search)
+ _register("tool.get_definitions", manager._cap_tool_get_definitions)
+
+ _register("component.get_all_plugins", manager._cap_component_get_all_plugins)
+ _register("component.get_plugin_info", manager._cap_component_get_plugin_info)
+ _register("component.list_loaded_plugins", manager._cap_component_list_loaded_plugins)
+ _register("component.list_registered_plugins", manager._cap_component_list_registered_plugins)
+ _register("component.enable", manager._cap_component_enable)
+ _register("component.disable", manager._cap_component_disable)
+ _register("component.load_plugin", manager._cap_component_load_plugin)
+ _register("component.unload_plugin", manager._cap_component_unload_plugin)
+ _register("component.reload_plugin", manager._cap_component_reload_plugin)
+
+ _register("knowledge.search", manager._cap_knowledge_search)
logger.debug("已注册全部主程序能力实现")
diff --git a/src/plugin_runtime/host/capability_service.py b/src/plugin_runtime/host/capability_service.py
index 98366a07..761b20ca 100644
--- a/src/plugin_runtime/host/capability_service.py
+++ b/src/plugin_runtime/host/capability_service.py
@@ -30,6 +30,11 @@ class CapabilityService:
"""
def __init__(self, authorization: "AuthorizationManager") -> None:
+ """初始化能力服务。
+
+ Args:
+ authorization: 能力授权管理器。
+ """
self._authorization = authorization
# capability_name -> implementation
self._implementations: Dict[str, CapabilityImpl] = {}
@@ -51,13 +56,19 @@ class CapabilityService:
校验权限后调用对应实现。
"""
plugin_id = envelope.plugin_id
+ payload = envelope.payload if isinstance(envelope.payload, dict) else {}
try:
- req = CapabilityRequestPayload.model_validate(envelope.payload)
- except Exception as e:
- return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, f"能力调用 payload 格式错误: {e}")
+ req = CapabilityRequestPayload.model_validate(payload)
+ capability = req.capability
+ args = req.args
+ except Exception:
+ capability = envelope.method
+ raw_args = payload.get("args", payload)
+ args = raw_args if isinstance(raw_args, dict) else {}
- capability = req.capability
+ if not capability:
+ return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, "能力调用缺少 capability")
# 1. 权限校验
allowed, reason = self._authorization.check_capability(plugin_id, capability)
@@ -71,7 +82,7 @@ class CapabilityService:
# 3. 执行
try:
- result = await impl(plugin_id, capability, req.args)
+ result = await impl(plugin_id, capability, args)
resp_payload = CapabilityResponsePayload(success=True, result=result)
return envelope.make_response(payload=resp_payload.model_dump())
except RPCError as e:
diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py
index 82a5970b..5ae3bdee 100644
--- a/src/plugin_runtime/host/supervisor.py
+++ b/src/plugin_runtime/host/supervisor.py
@@ -1,32 +1,38 @@
from pathlib import Path
-from typing import Optional, List, Dict, Any, Tuple, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import asyncio
-
+import contextlib
+import os
+import sys
from src.common.logger import get_logger
from src.config.config import global_config
-from src.plugin_runtime.transport.factory import create_transport_server
+from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN
from src.plugin_runtime.protocol.envelope import (
BootstrapPluginPayload,
ConfigUpdatedPayload,
Envelope,
HealthPayload,
- LogBatchPayload,
+ PROTOCOL_VERSION,
RegisterPluginPayload,
+ ReloadPluginResultPayload,
RunnerReadyPayload,
ShutdownPayload,
+ UnregisterPluginPayload,
)
+from src.plugin_runtime.protocol.codec import MsgPackCodec
+from src.plugin_runtime.protocol.errors import ErrorCode, RPCError
+from src.plugin_runtime.transport.factory import create_transport_server
from .authorization import AuthorizationManager
from .capability_service import CapabilityService
-from .rpc_server import RPCServer
-from .logger_bridge import RunnerLogBridge
from .component_registry import ComponentRegistry
from .event_dispatcher import EventDispatcher
from .hook_dispatcher import HookDispatcher
+from .logger_bridge import RunnerLogBridge
from .message_gateway import MessageGateway
-from .message_utils import PluginMessageUtils
+from .rpc_server import RPCServer
if TYPE_CHECKING:
from src.chat.message_receive.message import SessionMessage
@@ -35,7 +41,11 @@ logger = get_logger("plugin_runtime.host.runner_manager")
class PluginRunnerSupervisor:
- """插件的Runner管理器,负责管理Runner的生命周期"""
+ """插件 Runner 监督器。
+
+ 负责 Host 侧与单个 Runner 子进程之间的生命周期、内部 RPC、
+ 健康检查和插件级重载协调。
+ """
def __init__(
self,
@@ -44,13 +54,24 @@ class PluginRunnerSupervisor:
health_check_interval_sec: Optional[float] = None,
max_restart_attempts: Optional[int] = None,
runner_spawn_timeout_sec: Optional[float] = None,
- ):
- _cfg = global_config.plugin_runtime
- self._plugin_dirs: List[Path] = plugin_dirs or []
- self._health_interval = health_check_interval_sec or _cfg.health_check_interval_sec or 30.0
- self._runner_spawn_timeout = runner_spawn_timeout_sec or _cfg.runner_spawn_timeout_sec or 30.0
+ ) -> None:
+ """初始化 Supervisor。
+
+ Args:
+ plugin_dirs: 由当前 Runner 负责加载的插件目录列表。
+ socket_path: 自定义 IPC 地址;留空时由传输层自动生成。
+ health_check_interval_sec: 健康检查间隔,单位秒。
+ max_restart_attempts: 自动重启 Runner 的最大次数。
+ runner_spawn_timeout_sec: 等待 Runner 建连并就绪的超时时间,单位秒。
+ """
+ runtime_config = global_config.plugin_runtime
+ self._plugin_dirs: List[Path] = plugin_dirs or []
+ self._health_interval: float = health_check_interval_sec or runtime_config.health_check_interval_sec or 30.0
+ self._runner_spawn_timeout: float = (
+ runner_spawn_timeout_sec or runtime_config.runner_spawn_timeout_sec or 30.0
+ )
+ self._max_restart_attempts: int = max_restart_attempts or runtime_config.max_restart_attempts or 3
- # 基础设施
self._transport = create_transport_server(socket_path=socket_path)
self._authorization = AuthorizationManager()
self._capability_service = CapabilityService(self._authorization)
@@ -58,61 +79,55 @@ class PluginRunnerSupervisor:
self._event_dispatcher = EventDispatcher(self._component_registry)
self._hook_dispatcher = HookDispatcher(self._component_registry)
self._message_gateway = MessageGateway(self._component_registry)
-
- # 编解码和服务器
- from src.plugin_runtime.protocol.codec import MsgPackCodec
+ self._log_bridge = RunnerLogBridge()
codec = MsgPackCodec()
self._rpc_server = RPCServer(transport=self._transport, codec=codec)
- # Runner 子进程
self._runner_process: Optional[asyncio.subprocess.Process] = None
- self._max_restart_attempts: int = max_restart_attempts or _cfg.max_restart_attempts or 3
- self._restart_count: int = 0
-
- # 已注册的插件组件信息
self._registered_plugins: Dict[str, RegisterPluginPayload] = {}
self._runner_ready_events: asyncio.Event = asyncio.Event()
self._runner_ready_payloads: RunnerReadyPayload = RunnerReadyPayload()
+ self._health_task: Optional[asyncio.Task[None]] = None
+ self._stderr_drain_task: Optional[asyncio.Task[None]] = None
+ self._restart_count: int = 0
+ self._running: bool = False
- # 后台任务
- self._health_task: Optional[asyncio.Task] = None
- # Runner stderr 流排空任务(仅保留 stderr,用于 IPC 建立前的启动日志倒空、致命错误输出等场景)
- self._stderr_drain_task: Optional[asyncio.Task] = None
- self._running = False
-
- # Runner 日志桥(将 Runner 上报的批量日志重放到主进程 Logger)
- self._log_bridge: RunnerLogBridge = RunnerLogBridge()
-
- # 注册内部 RPC 方法
- self._register_internal_methods() # TODO: 完成内部方法注册
+ self._register_internal_methods()
@property
def authorization_manager(self) -> AuthorizationManager:
+ """返回授权管理器。"""
return self._authorization
@property
def capability_service(self) -> CapabilityService:
+ """返回能力服务。"""
return self._capability_service
@property
def component_registry(self) -> ComponentRegistry:
+ """返回组件注册表。"""
return self._component_registry
@property
def event_dispatcher(self) -> EventDispatcher:
+ """返回事件分发器。"""
return self._event_dispatcher
@property
def hook_dispatcher(self) -> HookDispatcher:
+ """返回 Hook 分发器。"""
return self._hook_dispatcher
@property
def message_gateway(self) -> MessageGateway:
+ """返回消息网关。"""
return self._message_gateway
@property
def rpc_server(self) -> RPCServer:
+ """返回底层 RPC 服务端。"""
return self._rpc_server
async def dispatch_event(
@@ -121,11 +136,28 @@ class PluginRunnerSupervisor:
message: Optional["SessionMessage"] = None,
extra_args: Optional[Dict[str, Any]] = None,
) -> Tuple[bool, Optional["SessionMessage"]]:
- """分发事件到所有对应 handler 的快捷方法。"""
+ """分发事件到已注册的事件处理器。
+
+ Args:
+ event_type: 事件类型。
+ message: 可选的消息对象。
+ extra_args: 附加参数。
+
+ Returns:
+ Tuple[bool, Optional[SessionMessage]]: 是否继续处理,以及插件可能修改后的消息。
+ """
return await self._event_dispatcher.dispatch_event(event_type, self, message, extra_args)
- async def dispatch_hook(self, stage: str, **kwargs):
- """分发Hook事件到所有对应 handler 的快捷方法。"""
+ async def dispatch_hook(self, stage: str, **kwargs: Any) -> Dict[str, Any]:
+ """分发 Hook 到已注册的 Hook 处理器。
+
+ Args:
+ stage: Hook 阶段名称。
+ **kwargs: 传递给 Hook 的关键字参数。
+
+ Returns:
+ Dict[str, Any]: 经 Hook 修改后的参数字典。
+ """
return await self._hook_dispatcher.hook_dispatch(stage, self, **kwargs)
async def send_message_to_external(
@@ -135,60 +167,68 @@ class PluginRunnerSupervisor:
enabled_only: bool = True,
save_to_db: bool = True,
) -> bool:
- """发送系统内部消息到外部平台的快捷方法。"""
+ """通过插件消息网关发送外部消息。
+
+ Args:
+ internal_message: 系统内部消息对象。
+ enabled_only: 是否仅使用启用的网关组件。
+ save_to_db: 发送成功后是否写入数据库。
+
+ Returns:
+ bool: 是否发送成功。
+ """
return await self._message_gateway.send_message_to_external(
- internal_message, self, enabled_only=enabled_only, save_to_db=save_to_db
+ internal_message,
+ self,
+ enabled_only=enabled_only,
+ save_to_db=save_to_db,
)
async def start(self) -> None:
- """启动 Supervisor
+ """启动 Supervisor。"""
+ if self._running:
+ logger.warning("PluginRunnerSupervisor 已在运行,跳过重复启动")
+ return
- 1. 启动 RPC Server
- 2. 拉起 Runner 子进程
- 3. 启动健康检查
- """
self._running = True
+ self._restart_count = 0
+ self._clear_runner_state()
- # 启动 RPC Server
await self._rpc_server.start()
- # 拉起 Runner 进程
await self._spawn_runner()
- # 等待 Runner 完成连接和初始化,避免 start() 返回时 Runner 尚未就绪
try:
await self._wait_for_runner_connection(timeout_sec=self._runner_spawn_timeout)
await self._wait_for_runner_ready(timeout_sec=self._runner_spawn_timeout)
except TimeoutError:
if not self._rpc_server.is_connected:
- logger.warning(f"Runner 未在 {self._runner_spawn_timeout}s 内完成连接,后续操作可能失败")
+ logger.warning("Runner 未在限定时间内完成连接,后续操作可能失败")
else:
- logger.warning(f"Runner 未在 {self._runner_spawn_timeout}s 内完成初始化,后续操作可能失败")
+ logger.warning("Runner 未在限定时间内完成初始化,后续操作可能失败")
- # 启动健康检查
- self._health_task = asyncio.create_task(self._health_check_loop())
-
- logger.info("PluginSupervisor 已启动")
+ self._health_task = asyncio.create_task(self._health_check_loop(), name="PluginRunnerSupervisor.health")
+ logger.info("PluginRunnerSupervisor 已启动")
async def stop(self) -> None:
- """停止 Supervisor"""
+ """停止 Supervisor。"""
+ if not self._running:
+ return
+
self._running = False
- # 停止组件
- await self._event_dispatcher.stop()
- await self._hook_dispatcher.stop()
-
- # 停止健康检查
- if self._health_task:
+ if self._health_task is not None:
self._health_task.cancel()
+ with contextlib.suppress(asyncio.CancelledError):
+ await self._health_task
self._health_task = None
- # 优雅关停 Runner
- await self._shutdown_runner()
-
- # 停止 RPC Server
+ await self._event_dispatcher.stop()
+ await self._hook_dispatcher.stop()
+ await self._shutdown_runner(reason="host_stop")
await self._rpc_server.stop()
+ self._clear_runner_state()
- logger.info("PluginSupervisor 已停止")
+ logger.info("PluginRunnerSupervisor 已停止")
async def invoke_plugin(
self,
@@ -198,9 +238,17 @@ class PluginRunnerSupervisor:
args: Optional[Dict[str, Any]] = None,
timeout_ms: int = 30000,
) -> Envelope:
- """调用插件组件
+ """调用 Runner 内的插件组件。
- 由主进程业务逻辑调用,通过 RPC 转发给 Runner。
+ Args:
+ method: RPC 方法名。
+ plugin_id: 目标插件 ID。
+ component_name: 组件名。
+ args: 调用参数。
+ timeout_ms: RPC 超时时间,单位毫秒。
+
+ Returns:
+ Envelope: RPC 响应信封。
"""
return await self._rpc_server.send_request(
method,
@@ -210,27 +258,421 @@ class PluginRunnerSupervisor:
)
async def reload_plugin(self, plugin_id: str, reason: str = "manual") -> bool:
- raise NotImplementedError("等待SDK完成") # TODO: 完成对应的调用和请求逻辑
+ """按插件 ID 触发精确重载。
+
+ Args:
+ plugin_id: 目标插件 ID。
+ reason: 重载原因。
+
+ Returns:
+ bool: 是否重载成功。
+ """
+ try:
+ response = await self._rpc_server.send_request(
+ "plugin.reload",
+ plugin_id=plugin_id,
+ payload={"plugin_id": plugin_id, "reason": reason},
+ timeout_ms=max(int(self._runner_spawn_timeout * 1000), 10000),
+ )
+ except Exception as exc:
+ logger.error(f"插件 {plugin_id} 重载请求失败: {exc}")
+ return False
+
+ result = ReloadPluginResultPayload.model_validate(response.payload)
+ if not result.success:
+ logger.warning(f"插件 {plugin_id} 重载失败: {result.failed_plugins}")
+ return result.success
+
+ async def reload_plugins(
+ self,
+ plugin_ids: Optional[List[str]] = None,
+ reason: str = "manual",
+ ) -> bool:
+ """批量重载插件。
+
+ Args:
+ plugin_ids: 目标插件 ID 列表;为空时重载当前已注册的全部插件。
+ reason: 重载原因。
+
+ Returns:
+ bool: 是否全部重载成功。
+ """
+ target_plugin_ids = plugin_ids or list(self._registered_plugins.keys())
+ ordered_plugin_ids = list(dict.fromkeys(target_plugin_ids))
+ success = True
+
+ for plugin_id in ordered_plugin_ids:
+ reloaded = await self.reload_plugin(plugin_id=plugin_id, reason=reason)
+ success = success and reloaded
+
+ return success
+
+ async def notify_plugin_config_updated(
+ self,
+ plugin_id: str,
+ config_data: Optional[Dict[str, Any]] = None,
+ config_version: str = "",
+ ) -> bool:
+ """向 Runner 推送插件配置更新。
+
+ Args:
+ plugin_id: 目标插件 ID。
+ config_data: 配置内容。
+ config_version: 配置版本号。
+
+ Returns:
+ bool: 请求是否成功送达并被 Runner 接受。
+ """
+ payload = ConfigUpdatedPayload(
+ plugin_id=plugin_id,
+ config_version=config_version,
+ config_data=config_data or {},
+ )
+ try:
+ response = await self._rpc_server.send_request(
+ "plugin.config_updated",
+ plugin_id=plugin_id,
+ payload=payload.model_dump(),
+ timeout_ms=10000,
+ )
+ except Exception as exc:
+ logger.warning(f"插件 {plugin_id} 配置更新通知失败: {exc}")
+ return False
+
+ return bool(response.payload.get("acknowledged", False))
async def _wait_for_runner_connection(self, timeout_sec: float) -> None:
- """等待 Runner 连接上 RPC Server"""
+ """等待 Runner 建立 RPC 连接。
- async def wait_for_connection():
+ Args:
+ timeout_sec: 超时时间,单位秒。
+
+ Raises:
+ TimeoutError: 在超时时间内 Runner 未完成连接。
+ """
+
+ async def wait_for_connection() -> None:
+ """轮询等待 RPC 连接建立。"""
while self._running and not self._rpc_server.is_connected:
await asyncio.sleep(0.1)
try:
await asyncio.wait_for(wait_for_connection(), timeout=timeout_sec)
logger.info("Runner 已连接到 RPC Server")
- except asyncio.TimeoutError as e:
- raise TimeoutError(f"等待 Runner 连接超时({timeout_sec}s)") from e
+ except asyncio.TimeoutError as exc:
+ raise TimeoutError(f"等待 Runner 连接超时({timeout_sec}s)") from exc
async def _wait_for_runner_ready(self, timeout_sec: float = 30.0) -> RunnerReadyPayload:
- """等待 Runner 完成初始化并上报就绪"""
+ """等待 Runner 完成启动初始化。
+ Args:
+ timeout_sec: 超时时间,单位秒。
+
+ Returns:
+ RunnerReadyPayload: Runner 上报的就绪信息。
+
+ Raises:
+ TimeoutError: 在超时时间内 Runner 未完成初始化。
+ """
try:
await asyncio.wait_for(self._runner_ready_events.wait(), timeout=timeout_sec)
logger.info("Runner 已完成初始化并上报就绪")
return self._runner_ready_payloads
- except asyncio.TimeoutError as e:
- raise TimeoutError(f"等待 Runner 就绪超时({timeout_sec}s)") from e
+ except asyncio.TimeoutError as exc:
+ raise TimeoutError(f"等待 Runner 就绪超时({timeout_sec}s)") from exc
+
+ def _register_internal_methods(self) -> None:
+ """注册 Host 侧内部 RPC 方法。"""
+ self._rpc_server.register_method("cap.call", self._capability_service.handle_capability_request)
+ self._rpc_server.register_method("plugin.bootstrap", self._handle_bootstrap_plugin)
+ self._rpc_server.register_method("plugin.register_components", self._handle_register_plugin)
+ self._rpc_server.register_method("plugin.register_plugin", self._handle_register_plugin)
+ self._rpc_server.register_method("plugin.unregister", self._handle_unregister_plugin)
+ self._rpc_server.register_method("runner.log_batch", self._log_bridge.handle_log_batch)
+ self._rpc_server.register_method("runner.ready", self._handle_runner_ready)
+
+ async def _handle_bootstrap_plugin(self, envelope: Envelope) -> Envelope:
+ """处理插件 bootstrap 请求。
+
+ Args:
+ envelope: RPC 请求信封。
+
+ Returns:
+ Envelope: RPC 响应信封。
+ """
+ try:
+ payload = BootstrapPluginPayload.model_validate(envelope.payload)
+ except Exception as exc:
+ return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
+
+ if payload.capabilities_required:
+ self._authorization.register_plugin(payload.plugin_id, payload.capabilities_required)
+ else:
+ self._authorization.revoke_permission_token(payload.plugin_id)
+
+ return envelope.make_response(payload={"accepted": True, "plugin_id": payload.plugin_id})
+
+ async def _handle_register_plugin(self, envelope: Envelope) -> Envelope:
+ """处理插件组件注册请求。
+
+ Args:
+ envelope: RPC 请求信封。
+
+ Returns:
+ Envelope: RPC 响应信封。
+ """
+ try:
+ payload = RegisterPluginPayload.model_validate(envelope.payload)
+ except Exception as exc:
+ return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
+
+ self._component_registry.remove_components_by_plugin(payload.plugin_id)
+ registered_count = self._component_registry.register_plugin_components(
+ payload.plugin_id,
+ [component.model_dump() for component in payload.components],
+ )
+ self._registered_plugins[payload.plugin_id] = payload
+
+ return envelope.make_response(
+ payload={
+ "accepted": True,
+ "plugin_id": payload.plugin_id,
+ "registered_components": registered_count,
+ }
+ )
+
+ async def _handle_unregister_plugin(self, envelope: Envelope) -> Envelope:
+ """处理插件注销请求。
+
+ Args:
+ envelope: RPC 请求信封。
+
+ Returns:
+ Envelope: RPC 响应信封。
+ """
+ try:
+ payload = UnregisterPluginPayload.model_validate(envelope.payload)
+ except Exception as exc:
+ return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
+
+ removed_components = self._component_registry.remove_components_by_plugin(payload.plugin_id)
+ self._authorization.revoke_permission_token(payload.plugin_id)
+ removed_registration = self._registered_plugins.pop(payload.plugin_id, None) is not None
+
+ return envelope.make_response(
+ payload={
+ "accepted": True,
+ "plugin_id": payload.plugin_id,
+ "reason": payload.reason,
+ "removed_components": removed_components,
+ "removed_registration": removed_registration,
+ }
+ )
+
+ async def _handle_runner_ready(self, envelope: Envelope) -> Envelope:
+ """处理 Runner 就绪通知。
+
+ Args:
+ envelope: RPC 请求信封。
+
+ Returns:
+ Envelope: RPC 响应信封。
+ """
+ try:
+ payload = RunnerReadyPayload.model_validate(envelope.payload)
+ except Exception as exc:
+ return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
+
+ self._runner_ready_payloads = payload
+ self._runner_ready_events.set()
+ return envelope.make_response(payload={"accepted": True})
+
+ def _build_runner_environment(self) -> Dict[str, str]:
+ """构建拉起 Runner 所需的环境变量。
+
+ Returns:
+ Dict[str, str]: 传递给 Runner 进程的环境变量映射。
+ """
+ return {
+ ENV_HOST_VERSION: PROTOCOL_VERSION,
+ ENV_IPC_ADDRESS: self._transport.get_address(),
+ ENV_PLUGIN_DIRS: os.pathsep.join(str(path) for path in self._plugin_dirs),
+ ENV_SESSION_TOKEN: self._rpc_server.session_token,
+ }
+
+ async def _spawn_runner(self) -> None:
+ """拉起 Runner 子进程。"""
+ if self._runner_process is not None and self._runner_process.returncode is None:
+ logger.warning("Runner 已在运行,跳过重复拉起")
+ return
+
+ self._clear_runner_state()
+
+ env = os.environ.copy()
+ env.update(self._build_runner_environment())
+
+ self._runner_process = await asyncio.create_subprocess_exec(
+ sys.executable,
+ "-m",
+ "src.plugin_runtime.runner.runner_main",
+ env=env,
+ stdout=asyncio.subprocess.DEVNULL,
+ stderr=asyncio.subprocess.PIPE,
+ )
+
+ if self._runner_process.stderr is not None:
+ self._stderr_drain_task = asyncio.create_task(
+ self._drain_runner_stderr(self._runner_process.stderr),
+ name="PluginRunnerSupervisor.stderr",
+ )
+
+ logger.info(f"Runner 已拉起,pid={self._runner_process.pid}")
+
+ async def _drain_runner_stderr(self, stream: asyncio.StreamReader) -> None:
+ """持续排空 Runner 的 stderr。
+
+ Args:
+ stream: Runner 的 stderr 流。
+ """
+ try:
+ while True:
+ line = await stream.readline()
+ if not line:
+ return
+ message = line.decode("utf-8", errors="replace").rstrip()
+ if message:
+ logger.warning(f"[runner-stderr] {message}")
+ except asyncio.CancelledError:
+ raise
+ except Exception as exc:
+ logger.warning(f"排空 Runner stderr 失败: {exc}")
+
+ async def _shutdown_runner(self, reason: str = "normal") -> None:
+ """优雅关闭 Runner 子进程。
+
+ Args:
+ reason: 关停原因。
+ """
+ process = self._runner_process
+ if process is None:
+ return
+
+ payload = ShutdownPayload(reason=reason)
+
+ if process.returncode is None and self._rpc_server.is_connected:
+ with contextlib.suppress(Exception):
+ await self._rpc_server.send_request(
+ "plugin.prepare_shutdown",
+ payload=payload.model_dump(),
+ timeout_ms=payload.drain_timeout_ms,
+ )
+ with contextlib.suppress(Exception):
+ await self._rpc_server.send_request(
+ "plugin.shutdown",
+ payload=payload.model_dump(),
+ timeout_ms=payload.drain_timeout_ms,
+ )
+
+ if process.returncode is None:
+ try:
+ await asyncio.wait_for(process.wait(), timeout=max(payload.drain_timeout_ms / 1000.0, 1.0))
+ except asyncio.TimeoutError:
+ logger.warning("Runner 优雅退出超时,尝试 terminate")
+ process.terminate()
+ try:
+ await asyncio.wait_for(process.wait(), timeout=5.0)
+ except asyncio.TimeoutError:
+ logger.warning("Runner terminate 超时,尝试 kill")
+ process.kill()
+ with contextlib.suppress(Exception):
+ await asyncio.wait_for(process.wait(), timeout=5.0)
+
+ self._runner_process = None
+
+ if self._stderr_drain_task is not None:
+ self._stderr_drain_task.cancel()
+ with contextlib.suppress(asyncio.CancelledError):
+ await self._stderr_drain_task
+ self._stderr_drain_task = None
+
+ self._clear_runner_state()
+
+ async def _health_check_loop(self) -> None:
+ """周期性检查 Runner 健康状态,并在必要时重启。"""
+ timeout_ms = max(int(self._health_interval * 1000), 1000)
+
+ while self._running:
+ try:
+ await asyncio.sleep(self._health_interval)
+ except asyncio.CancelledError:
+ return
+
+ if not self._running:
+ return
+
+ process = self._runner_process
+ if process is None or process.returncode is not None:
+ reason = "runner_process_exited" if process is not None else "runner_process_missing"
+ restarted = await self._restart_runner(reason=reason)
+ if not restarted:
+ return
+ continue
+
+ try:
+ response = await self._rpc_server.send_request("plugin.health", timeout_ms=timeout_ms)
+ health = HealthPayload.model_validate(response.payload)
+ if not health.healthy:
+ restarted = await self._restart_runner(reason="health_check_unhealthy")
+ if not restarted:
+ return
+ except asyncio.CancelledError:
+ return
+ except (RPCError, Exception) as exc:
+ logger.warning(f"Runner 健康检查失败: {exc}")
+ restarted = await self._restart_runner(reason="health_check_failed")
+ if not restarted:
+ return
+
+ async def _restart_runner(self, reason: str) -> bool:
+ """在 Runner 异常时执行整进程级重启。
+
+ Args:
+ reason: 触发重启的原因。
+
+ Returns:
+ bool: 是否重启成功。
+ """
+ if not self._running:
+ return False
+
+ if self._restart_count >= self._max_restart_attempts:
+ logger.error(f"Runner 自动重启次数已达上限,停止重启。reason={reason}")
+ return False
+
+ self._restart_count += 1
+ logger.warning(f"准备重启 Runner,第 {self._restart_count} 次,reason={reason}")
+
+ await self._shutdown_runner(reason=reason)
+
+ try:
+ await self._spawn_runner()
+ await self._wait_for_runner_connection(timeout_sec=self._runner_spawn_timeout)
+ await self._wait_for_runner_ready(timeout_sec=self._runner_spawn_timeout)
+ except Exception as exc:
+ logger.error(f"Runner 重启失败: {exc}", exc_info=True)
+ return False
+
+ self._restart_count = 0
+ logger.info("Runner 已成功重启")
+ return True
+
+ def _clear_runner_state(self) -> None:
+ """清理当前 Runner 对应的 Host 侧注册状态。"""
+ self._authorization.clear()
+ self._component_registry.clear()
+ self._registered_plugins.clear()
+ self._runner_ready_events = asyncio.Event()
+ self._runner_ready_payloads = RunnerReadyPayload()
+
+
+PluginSupervisor = PluginRunnerSupervisor
diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py
index 04c8e324..730da3e1 100644
--- a/src/plugin_runtime/integration.py
+++ b/src/plugin_runtime/integration.py
@@ -23,8 +23,10 @@ from src.plugin_runtime.capabilities import (
RuntimeDataCapabilityMixin,
)
from src.plugin_runtime.capabilities.registry import register_capability_impls
+from src.plugin_runtime.host.message_utils import MessageDict, PluginMessageUtils
if TYPE_CHECKING:
+ from src.chat.message_receive.message import SessionMessage
from src.plugin_runtime.host.supervisor import PluginSupervisor
logger = get_logger("plugin_runtime.integration")
@@ -223,9 +225,9 @@ class PluginRuntimeManager(
async def bridge_event(
self,
event_type_value: str,
- message_dict: Optional[Dict[str, Any]] = None,
+ message_dict: Optional[MessageDict] = None,
extra_args: Optional[Dict[str, Any]] = None,
- ) -> Tuple[bool, Optional[Dict[str, Any]]]:
+ ) -> Tuple[bool, Optional[MessageDict]]:
"""将事件分发到所有 Supervisor
Returns:
@@ -235,17 +237,23 @@ class PluginRuntimeManager(
return True, None
new_event_type: str = _EVENT_TYPE_MAP.get(event_type_value, event_type_value)
- modified: Optional[Dict[str, Any]] = None
+ modified: Optional[MessageDict] = None
+ current_message: Optional["SessionMessage"] = (
+ PluginMessageUtils._build_session_message_from_dict(dict(message_dict))
+ if message_dict is not None
+ else None
+ )
for sv in self.supervisors:
try:
cont, mod = await sv.dispatch_event(
event_type=new_event_type,
- message=modified or message_dict,
+ message=current_message,
extra_args=extra_args,
)
if mod is not None:
- modified = mod
+ current_message = mod
+ modified = PluginMessageUtils._session_message_to_dict(mod)
if not cont:
return False, modified
except Exception as e:
@@ -477,7 +485,7 @@ class PluginRuntimeManager(
logger.error(f"检测到重复插件 ID,跳过本次插件热重载: {details}")
return
- reload_supervisors: List[Any] = []
+ reload_supervisors: Dict[Any, List[str]] = {}
changed_paths = [change.path.resolve() for change in changes]
for supervisor in self.supervisors:
@@ -485,11 +493,13 @@ class PluginRuntimeManager(
plugin_id = self._match_plugin_id_for_supervisor(supervisor, path)
if plugin_id is None:
continue
- if (path.name in {"plugin.py", "_manifest.json"} or path.suffix == ".py") and supervisor not in reload_supervisors:
- reload_supervisors.append(supervisor)
+ if path.name in {"plugin.py", "_manifest.json"} or path.suffix == ".py":
+ reload_supervisors.setdefault(supervisor, [])
+ if plugin_id not in reload_supervisors[supervisor]:
+ reload_supervisors[supervisor].append(plugin_id)
- for supervisor in reload_supervisors:
- await supervisor.reload_plugins(reason="file_watcher")
+ for supervisor, plugin_ids in reload_supervisors.items():
+ await supervisor.reload_plugins(plugin_ids=plugin_ids, reason="file_watcher")
if reload_supervisors:
self._refresh_plugin_config_watch_subscriptions()
diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py
index e81df019..6f95f97f 100644
--- a/src/plugin_runtime/protocol/envelope.py
+++ b/src/plugin_runtime/protocol/envelope.py
@@ -144,7 +144,11 @@ class ComponentDeclaration(BaseModel):
class RegisterPluginPayload(BaseModel):
- """plugin.register_plugin 请求 payload"""
+ """插件组件注册请求载荷。
+
+ 该模型同时用于 ``plugin.register_components`` 与兼容旧命名的
+ ``plugin.register_plugin`` 请求。
+ """
plugin_id: str = Field(description="插件 ID")
"""插件 ID"""
@@ -248,6 +252,39 @@ class ShutdownPayload(BaseModel):
"""排空超时 (ms)"""
+class UnregisterPluginPayload(BaseModel):
+ """插件注销请求载荷。"""
+
+ plugin_id: str = Field(description="插件 ID")
+ """插件 ID"""
+ reason: str = Field(default="manual", description="注销原因")
+ """注销原因"""
+
+
+class ReloadPluginPayload(BaseModel):
+ """插件重载请求载荷。"""
+
+ plugin_id: str = Field(description="目标插件 ID")
+ """目标插件 ID"""
+ reason: str = Field(default="manual", description="重载原因")
+ """重载原因"""
+
+
+class ReloadPluginResultPayload(BaseModel):
+ """插件重载结果载荷。"""
+
+ success: bool = Field(description="是否重载成功")
+ """是否重载成功"""
+ requested_plugin_id: str = Field(description="请求重载的插件 ID")
+ """请求重载的插件 ID"""
+ reloaded_plugins: List[str] = Field(default_factory=list, description="成功完成重载的插件列表")
+ """成功完成重载的插件列表"""
+ unloaded_plugins: List[str] = Field(default_factory=list, description="本次已卸载的插件列表")
+ """本次已卸载的插件列表"""
+ failed_plugins: Dict[str, str] = Field(default_factory=dict, description="重载失败的插件及原因")
+ """重载失败的插件及原因"""
+
+
# ====== 日志传输 ======
diff --git a/src/plugin_runtime/runner/plugin_loader.py b/src/plugin_runtime/runner/plugin_loader.py
index 11ba45e7..90c8bf47 100644
--- a/src/plugin_runtime/runner/plugin_loader.py
+++ b/src/plugin_runtime/runner/plugin_loader.py
@@ -32,11 +32,22 @@ class PluginMeta:
self,
plugin_id: str,
plugin_dir: str,
+ module_name: str,
plugin_instance: Any,
manifest: Dict[str, Any],
) -> None:
+ """初始化插件元数据。
+
+ Args:
+ plugin_id: 插件 ID。
+ plugin_dir: 插件目录绝对路径。
+ module_name: 插件入口模块名。
+ plugin_instance: 插件实例对象。
+ manifest: 解析后的 manifest 内容。
+ """
self.plugin_id = plugin_id
self.plugin_dir = plugin_dir
+ self.module_name = module_name
self.instance = plugin_instance
self.manifest = manifest
self.version = manifest.get("version", "1.0.0")
@@ -45,6 +56,14 @@ class PluginMeta:
@staticmethod
def _extract_dependencies(manifest: Dict[str, Any]) -> List[str]:
+ """从 manifest 中提取依赖列表。
+
+ Args:
+ manifest: 插件 manifest。
+
+ Returns:
+ List[str]: 规范化后的依赖插件 ID 列表。
+ """
raw = manifest.get("dependencies", [])
result: List[str] = []
for dep in raw:
@@ -66,19 +85,24 @@ class PluginLoader:
"""
def __init__(self, host_version: str = "") -> None:
+ """初始化插件加载器。
+
+ Args:
+ host_version: Host 版本号,用于 manifest 兼容性校验。
+ """
self._loaded_plugins: Dict[str, PluginMeta] = {}
self._failed_plugins: Dict[str, str] = {}
self._manifest_validator = ManifestValidator(host_version=host_version)
self._compat_hook_installed = False
def discover_and_load(self, plugin_dirs: List[str]) -> List[PluginMeta]:
- """扫描多个目录并加载所有插件(含依赖排序和 manifest 校验)
+ """扫描多个目录并加载所有插件。
Args:
- plugin_dirs: 插件目录列表
+ plugin_dirs: 插件目录列表。
Returns:
- 成功加载的插件元数据列表(按依赖顺序)
+ List[PluginMeta]: 成功加载的插件元数据列表,按依赖顺序排列。
"""
candidates, duplicate_candidates = self._discover_candidates(plugin_dirs)
self._record_duplicate_candidates(duplicate_candidates)
@@ -90,6 +114,18 @@ class PluginLoader:
# 第三阶段:按依赖顺序加载
return self._load_plugins_in_order(load_order, candidates)
+ def discover_candidates(self, plugin_dirs: List[str]) -> Tuple[Dict[str, PluginCandidate], Dict[str, List[Path]]]:
+ """扫描插件目录并返回候选插件。
+
+ Args:
+ plugin_dirs: 需要扫描的插件根目录列表。
+
+ Returns:
+ Tuple[Dict[str, PluginCandidate], Dict[str, List[Path]]]:
+ 候选插件映射和重复插件 ID 冲突映射。
+ """
+ return self._discover_candidates(plugin_dirs)
+
def _discover_candidates(self, plugin_dirs: List[str]) -> Tuple[Dict[str, PluginCandidate], Dict[str, List[Path]]]:
"""扫描插件目录并收集候选插件。"""
candidates: Dict[str, PluginCandidate] = {}
@@ -170,7 +206,6 @@ class PluginLoader:
plugin_dir, manifest, plugin_path = candidates[plugin_id]
try:
if meta := self._load_single_plugin(plugin_id, plugin_dir, manifest, plugin_path):
- self._loaded_plugins[meta.plugin_id] = meta
results.append(meta)
except Exception as e:
self._failed_plugins[plugin_id] = str(e)
@@ -182,22 +217,109 @@ class PluginLoader:
"""获取已加载的插件"""
return self._loaded_plugins.get(plugin_id)
+ def set_loaded_plugin(self, meta: PluginMeta) -> None:
+ """登记一个已经完成初始化的插件。
+
+ Args:
+ meta: 待登记的插件元数据。
+ """
+ self._loaded_plugins[meta.plugin_id] = meta
+
+ def remove_loaded_plugin(self, plugin_id: str) -> Optional[PluginMeta]:
+ """移除一个已加载插件的元数据。
+
+ Args:
+ plugin_id: 待移除的插件 ID。
+
+ Returns:
+ Optional[PluginMeta]: 被移除的插件元数据;不存在时返回 ``None``。
+ """
+ return self._loaded_plugins.pop(plugin_id, None)
+
+ def purge_plugin_modules(self, plugin_id: str, plugin_dir: str) -> List[str]:
+ """清理指定插件目录下的模块缓存。
+
+ Args:
+ plugin_id: 插件 ID。
+ plugin_dir: 插件目录绝对路径。
+
+ Returns:
+ List[str]: 已从 ``sys.modules`` 中移除的模块名列表。
+ """
+ removed_modules: List[str] = []
+ plugin_path = Path(plugin_dir).resolve()
+ synthetic_module_name = f"_maibot_plugin_{plugin_id}"
+
+ for module_name, module in list(sys.modules.items()):
+ if module_name == synthetic_module_name:
+ removed_modules.append(module_name)
+ sys.modules.pop(module_name, None)
+ continue
+
+ module_file = getattr(module, "__file__", None)
+ if module_file is None:
+ continue
+
+ try:
+ module_path = Path(module_file).resolve()
+ except Exception:
+ continue
+
+ if module_path.is_relative_to(plugin_path):
+ removed_modules.append(module_name)
+ sys.modules.pop(module_name, None)
+
+ importlib.invalidate_caches()
+ return removed_modules
+
def list_plugins(self) -> List[str]:
"""列出所有已加载的插件 ID"""
return list(self._loaded_plugins.keys())
@property
def failed_plugins(self) -> Dict[str, str]:
+ """返回当前记录的失败插件原因映射。"""
return dict(self._failed_plugins)
# ──── 依赖解析 ────────────────────────────────────────────
+ def resolve_dependencies(
+ self,
+ candidates: Dict[str, PluginCandidate],
+ extra_available: Optional[Set[str]] = None,
+ ) -> Tuple[List[str], Dict[str, str]]:
+ """解析候选插件的依赖顺序。
+
+ Args:
+ candidates: 待加载的候选插件集合。
+ extra_available: 视为已满足的外部依赖插件 ID 集合。
+
+ Returns:
+ Tuple[List[str], Dict[str, str]]: 可加载顺序和失败原因映射。
+ """
+ return self._resolve_dependencies(candidates, extra_available=extra_available)
+
+ def load_candidate(self, plugin_id: str, candidate: PluginCandidate) -> Optional[PluginMeta]:
+ """加载单个候选插件模块。
+
+ Args:
+ plugin_id: 插件 ID。
+ candidate: 候选插件三元组。
+
+ Returns:
+ Optional[PluginMeta]: 加载成功的插件元数据;失败时返回 ``None``。
+ """
+ plugin_dir, manifest, plugin_path = candidate
+ return self._load_single_plugin(plugin_id, plugin_dir, manifest, plugin_path)
+
def _resolve_dependencies(
self,
candidates: Dict[str, PluginCandidate],
+ extra_available: Optional[Set[str]] = None,
) -> Tuple[List[str], Dict[str, str]]:
"""拓扑排序解析加载顺序,返回 (有序列表, 失败项 {id: reason})。"""
available = set(candidates.keys())
+ satisfied_dependencies = set(extra_available or set())
dep_graph: Dict[str, Set[str]] = {}
failed: Dict[str, str] = {}
@@ -212,6 +334,8 @@ class PluginLoader:
continue
if dep_name in available:
resolved.add(dep_name)
+ elif dep_name in satisfied_dependencies:
+ continue
else:
missing.append(dep_name)
if missing:
@@ -271,33 +395,39 @@ class PluginLoader:
sys.modules[module_name] = module
plugin_parent_dir = plugin_dir.parent
- with self._temporary_sys_path_entry(plugin_parent_dir):
- spec.loader.exec_module(module)
+ try:
+ with self._temporary_sys_path_entry(plugin_parent_dir):
+ spec.loader.exec_module(module)
- # 优先使用新版 create_plugin 工厂函数
- create_plugin = getattr(module, "create_plugin", None)
- if create_plugin is not None:
- instance = create_plugin()
- logger.info(f"插件 {plugin_id} v{manifest.get('version', '?')} 加载成功")
- return PluginMeta(
- plugin_id=plugin_id,
- plugin_dir=str(plugin_dir),
- plugin_instance=instance,
- manifest=manifest,
- )
+ # 优先使用新版 create_plugin 工厂函数
+ create_plugin = getattr(module, "create_plugin", None)
+ if create_plugin is not None:
+ instance = create_plugin()
+ logger.info(f"插件 {plugin_id} v{manifest.get('version', '?')} 加载成功")
+ return PluginMeta(
+ plugin_id=plugin_id,
+ plugin_dir=str(plugin_dir),
+ module_name=module_name,
+ plugin_instance=instance,
+ manifest=manifest,
+ )
- # 回退:检测旧版 @register_plugin 标记的 BasePlugin 子类
- instance = self._try_load_legacy_plugin(module, plugin_id)
- if instance is not None:
- logger.info(
- f"插件 {plugin_id} v{manifest.get('version', '?')} 通过旧版兼容层加载成功(请尽快迁移到 maibot_sdk)"
- )
- return PluginMeta(
- plugin_id=plugin_id,
- plugin_dir=str(plugin_dir),
- plugin_instance=instance,
- manifest=manifest,
- )
+ # 回退:检测旧版 @register_plugin 标记的 BasePlugin 子类
+ instance = self._try_load_legacy_plugin(module, plugin_id)
+ if instance is not None:
+ logger.info(
+ f"插件 {plugin_id} v{manifest.get('version', '?')} 通过旧版兼容层加载成功(请尽快迁移到 maibot_sdk)"
+ )
+ return PluginMeta(
+ plugin_id=plugin_id,
+ plugin_dir=str(plugin_dir),
+ module_name=module_name,
+ plugin_instance=instance,
+ manifest=manifest,
+ )
+ except Exception:
+ sys.modules.pop(module_name, None)
+ raise
logger.error(f"插件 {plugin_id} 缺少 create_plugin 工厂函数且未检测到旧版 BasePlugin")
return None
diff --git a/src/plugin_runtime/runner/rpc_client.py b/src/plugin_runtime/runner/rpc_client.py
index 6a1d59d5..dc917cc8 100644
--- a/src/plugin_runtime/runner/rpc_client.py
+++ b/src/plugin_runtime/runner/rpc_client.py
@@ -1,14 +1,6 @@
-"""Runner 端 RPC Client
+"""Runner 端 RPC 客户端。"""
-负责:
-1. 连接 Host RPC Server
-2. 发送握手(runner.hello)
-3. 发送组件注册请求
-4. 接收并分发 Host 的调用请求
-5. 发送能力调用请求到 Host
-"""
-
-from typing import Any, Awaitable, Callable, Dict, Optional, cast
+from typing import Any, Awaitable, Callable, Dict, Optional, Set, cast
import asyncio
import contextlib
@@ -29,12 +21,15 @@ from src.plugin_runtime.transport.factory import create_transport_client
logger = get_logger("plugin_runtime.runner.rpc_client")
-# RPC 方法处理器类型
MethodHandler = Callable[[Envelope], Awaitable[Envelope]]
def _get_sdk_version() -> str:
- """从 maibot_sdk 包元数据中读取实际版本号,失败时回退到 1.0.0。"""
+ """读取 SDK 版本号。
+
+ Returns:
+ str: 已安装的 SDK 版本;读取失败时回退到 ``1.0.0``。
+ """
try:
from importlib.metadata import version
@@ -47,73 +42,78 @@ SDK_VERSION = _get_sdk_version()
class RPCClient:
- """Runner 端 RPC 客户端
-
- 管理与 Host 的 IPC 连接,支持双向 RPC 调用。
- """
+ """Runner 端 RPC 客户端。"""
def __init__(
self,
host_address: str,
session_token: str,
codec: Optional[Codec] = None,
- ):
- self._host_address = host_address
- self._session_token = session_token
- self._codec = codec or MsgPackCodec()
+ ) -> None:
+ """初始化 RPC 客户端。
+
+ Args:
+ host_address: Host 的 IPC 地址。
+ session_token: 握手用会话令牌。
+ codec: 可选的编解码器实现。
+ """
+ self._host_address: str = host_address
+ self._session_token: str = session_token
+ self._codec: Codec = codec or MsgPackCodec()
self._id_gen = RequestIdGenerator()
self._connection: Optional[Connection] = None
- self._runner_id = str(uuid.uuid4())
- self._generation: int = 0
-
- # 方法处理器注册表(Host 发来的调用)
+ self._runner_id: str = str(uuid.uuid4())
self._method_handlers: Dict[str, MethodHandler] = {}
-
- # 等待响应的 pending 请求: request_id -> Future
- self._pending_requests: Dict[int, asyncio.Future] = {}
-
- # 运行状态
- self._running = False
- self._recv_task: Optional[asyncio.Task] = None
- self._background_tasks: set[asyncio.Task] = set()
-
- @property
- def generation(self) -> int:
- return self._generation
+ self._pending_requests: Dict[int, asyncio.Future[Envelope]] = {}
+ self._running: bool = False
+ self._recv_task: Optional[asyncio.Task[None]] = None
+ self._background_tasks: Set[asyncio.Task[Any]] = set()
@property
def is_connected(self) -> bool:
+ """返回当前连接是否可用。"""
return self._connection is not None and not self._connection.is_closed
def register_method(self, method: str, handler: MethodHandler) -> None:
- """注册方法处理器(处理 Host 发来的请求)"""
+ """注册 Host -> Runner 的 RPC 处理器。
+
+ Args:
+ method: RPC 方法名。
+ handler: 方法处理函数。
+ """
self._method_handlers[method] = handler
def _require_connection(self) -> Connection:
- """返回当前可用连接;若连接不可用则抛出 RPCError。"""
+ """返回当前可用连接。
+
+ Returns:
+ Connection: 当前连接对象。
+
+ Raises:
+ RPCError: 当前未连接到 Host。
+ """
connection = self._connection
if connection is None or connection.is_closed:
raise RPCError(ErrorCode.E_UNKNOWN, "未连接到 Host")
return cast(Connection, connection)
async def connect_and_handshake(self) -> bool:
- """连接 Host 并完成握手
+ """连接 Host 并完成握手。
Returns:
- 是否握手成功
+ bool: 是否握手成功。
"""
client = create_transport_client(self._host_address)
self._connection = await client.connect()
connection = self._require_connection()
- # 发送 runner.hello
hello = HelloPayload(
runner_id=self._runner_id,
sdk_version=SDK_VERSION,
session_token=self._session_token,
)
- request_id = self._id_gen.next()
+ request_id = await self._id_gen.next()
envelope = Envelope(
request_id=request_id,
message_type=MessageType.REQUEST,
@@ -121,33 +121,27 @@ class RPCClient:
payload=hello.model_dump(),
)
- data = self._codec.encode_envelope(envelope)
- await connection.send_frame(data)
+ await connection.send_frame(self._codec.encode_envelope(envelope))
- # 接收握手响应
resp_data = await asyncio.wait_for(connection.recv_frame(), timeout=10.0)
- resp = self._codec.decode_envelope(resp_data)
+ response = self._codec.decode_envelope(resp_data)
+ resp_payload = HelloResponsePayload.model_validate(response.payload)
- resp_payload = HelloResponsePayload.model_validate(resp.payload)
if not resp_payload.accepted:
logger.error(f"握手被拒绝: {resp_payload.reason}")
- await self._connection.close()
- self._connection = None
+ await self.disconnect()
return False
- self._generation = resp_payload.assigned_generation
- logger.info(f"握手成功: generation={self._generation}, host_version={resp_payload.host_version}")
-
- # 启动消息接收循环
+ logger.info(f"握手成功: host_version={resp_payload.host_version}")
self._running = True
- self._recv_task = asyncio.create_task(self._recv_loop())
-
+ self._recv_task = asyncio.create_task(self._recv_loop(), name="RPCClient.recv")
return True
async def disconnect(self) -> None:
- """断开连接"""
+ """断开与 Host 的连接并清理状态。"""
self._running = False
- if self._recv_task:
+
+ if self._recv_task is not None:
self._recv_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._recv_task
@@ -160,13 +154,12 @@ class RPCClient:
await asyncio.gather(*self._background_tasks, return_exceptions=True)
self._background_tasks.clear()
- # 取消所有 pending 请求
for future in self._pending_requests.values():
if not future.done():
future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "连接关闭"))
self._pending_requests.clear()
- if self._connection:
+ if self._connection is not None:
await self._connection.close()
self._connection = None
@@ -177,16 +170,27 @@ class RPCClient:
payload: Optional[Dict[str, Any]] = None,
timeout_ms: int = 30000,
) -> Envelope:
- """向 Host 发送 RPC 请求并等待响应"""
- connection = self._require_connection()
+ """向 Host 发送 RPC 请求并等待响应。
- request_id = self._id_gen.next()
+ Args:
+ method: RPC 方法名。
+ plugin_id: 目标插件 ID。
+ payload: 请求载荷。
+ timeout_ms: 超时时间,单位毫秒。
+
+ Returns:
+ Envelope: Host 返回的响应信封。
+
+ Raises:
+ RPCError: 发送失败、超时或连接异常。
+ """
+ connection = self._require_connection()
+ request_id = await self._id_gen.next()
envelope = Envelope(
request_id=request_id,
message_type=MessageType.REQUEST,
method=method,
plugin_id=plugin_id,
- generation=self._generation,
timeout_ms=timeout_ms,
payload=payload or {},
)
@@ -196,21 +200,16 @@ class RPCClient:
self._pending_requests[request_id] = future
try:
- data = self._codec.encode_envelope(envelope)
- await connection.send_frame(data)
-
- timeout_sec = timeout_ms / 1000.0
- return await asyncio.wait_for(future, timeout=timeout_sec)
+ await connection.send_frame(self._codec.encode_envelope(envelope))
+ return await asyncio.wait_for(future, timeout=timeout_ms / 1000.0)
except asyncio.TimeoutError:
self._pending_requests.pop(request_id, None)
raise RPCError(ErrorCode.E_TIMEOUT, f"请求 {method} 超时 ({timeout_ms}ms)") from None
- except Exception as e:
+ except Exception as exc:
self._pending_requests.pop(request_id, None)
- if isinstance(e, RPCError):
+ if isinstance(exc, RPCError):
raise
- raise RPCError(ErrorCode.E_UNKNOWN, str(e)) from e
-
- # ─── 内部方法 ──────────────────────────────────────────────
+ raise RPCError(ErrorCode.E_UNKNOWN, str(exc)) from exc
async def send_event(
self,
@@ -218,33 +217,30 @@ class RPCClient:
plugin_id: str = "",
payload: Optional[Dict[str, Any]] = None,
) -> None:
- """向 Host 发送单向事件(fire-and-forget,不等待响应)。
+ """向 Host 发送单向广播消息。
Args:
- method: RPC 方法名,如 "runner.log_batch"。
- plugin_id: 目标插件 ID(可为空,表示 Runner 级消息)。
- payload: 事件数据。
+ method: RPC 方法名。
+ plugin_id: 目标插件 ID。
+ payload: 广播载荷。
"""
if not self.is_connected:
return
connection = self._require_connection()
-
- request_id = self._id_gen.next()
+ request_id = await self._id_gen.next()
envelope = Envelope(
request_id=request_id,
- message_type=MessageType.EVENT,
+ message_type=MessageType.BROADCAST,
method=method,
plugin_id=plugin_id,
- generation=self._generation,
payload=payload or {},
)
- data = self._codec.encode_envelope(envelope)
- await connection.send_frame(data)
+ await connection.send_frame(self._codec.encode_envelope(envelope))
async def _recv_loop(self) -> None:
- """消息接收主循环"""
- while self._running and self._connection and not self._connection.is_closed:
+ """持续接收 Host 发来的消息并分发。"""
+ while self._running and self._connection is not None and not self._connection.is_closed:
try:
data = await self._connection.recv_frame()
except (asyncio.IncompleteReadError, ConnectionError):
@@ -252,39 +248,47 @@ class RPCClient:
break
except asyncio.CancelledError:
break
- except Exception as e:
- logger.error(f"接收帧失败: {e}")
+ except Exception as exc:
+ logger.error(f"接收帧失败: {exc}")
break
try:
envelope = self._codec.decode_envelope(data)
- except Exception as e:
- logger.error(f"解码消息失败: {e}")
+ except Exception as exc:
+ logger.error(f"解码消息失败: {exc}")
continue
if envelope.is_response():
self._handle_response(envelope)
elif envelope.is_request():
self._track_background_task(asyncio.create_task(self._handle_request(envelope)))
- elif envelope.is_event():
- self._track_background_task(asyncio.create_task(self._handle_event(envelope)))
+ elif envelope.is_broadcast():
+ self._track_background_task(asyncio.create_task(self._handle_broadcast(envelope)))
def _handle_response(self, envelope: Envelope) -> None:
- """处理来自 Host 的响应"""
+ """处理 Host 返回的响应。
+
+ Args:
+ envelope: 响应信封。
+ """
future = self._pending_requests.pop(envelope.request_id, None)
- if future and not future.done():
- if envelope.error:
- future.set_exception(RPCError.from_dict(envelope.error))
- else:
- future.set_result(envelope)
+ if future is None or future.done():
+ return
+ if envelope.error:
+ future.set_exception(RPCError.from_dict(envelope.error))
+ else:
+ future.set_result(envelope)
async def _handle_request(self, envelope: Envelope) -> None:
- """处理来自 Host 的请求(调用插件组件)"""
+ """处理 Host 发来的请求。
+
+ Args:
+ envelope: 请求信封。
+ """
connection = self._connection
if connection is None or connection.is_closed:
logger.warning(f"处理请求 {envelope.method} 时连接已关闭,跳过响应")
return
- connection = cast(Connection, connection)
handler = self._method_handlers.get(envelope.method)
if handler is None:
@@ -298,23 +302,34 @@ class RPCClient:
try:
response = await handler(envelope)
await connection.send_frame(self._codec.encode_envelope(response))
- except RPCError as e:
- error_resp = envelope.make_error_response(e.code.value, e.message, e.details)
+ except RPCError as exc:
+ error_resp = envelope.make_error_response(exc.code.value, exc.message, exc.details)
await connection.send_frame(self._codec.encode_envelope(error_resp))
- except Exception as e:
- logger.error(f"处理请求 {envelope.method} 异常: {e}", exc_info=True)
- error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
+ except Exception as exc:
+ logger.error(f"处理请求 {envelope.method} 异常: {exc}", exc_info=True)
+ error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(exc))
await connection.send_frame(self._codec.encode_envelope(error_resp))
- async def _handle_event(self, envelope: Envelope) -> None:
- """处理来自 Host 的事件"""
- if handler := self._method_handlers.get(envelope.method):
- try:
- await handler(envelope)
- except Exception as e:
- logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True)
+ async def _handle_broadcast(self, envelope: Envelope) -> None:
+ """处理 Host 发来的广播事件。
- def _track_background_task(self, task: asyncio.Task) -> None:
- """保持后台任务强引用,直到其完成或被取消。"""
+ Args:
+ envelope: 广播信封。
+ """
+ handler = self._method_handlers.get(envelope.method)
+ if handler is None:
+ return
+
+ try:
+ await handler(envelope)
+ except Exception as exc:
+ logger.error(f"处理广播 {envelope.method} 异常: {exc}", exc_info=True)
+
+ def _track_background_task(self, task: asyncio.Task[Any]) -> None:
+ """持有后台任务强引用直到其结束。
+
+ Args:
+ task: 后台任务。
+ """
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py
index dae1cfa1..771e685f 100644
--- a/src/plugin_runtime/runner/runner_main.py
+++ b/src/plugin_runtime/runner/runner_main.py
@@ -9,7 +9,7 @@
6. 转发插件的能力调用到 Host
"""
-from typing import Any, Callable, List, Optional, Protocol, cast
+from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, cast
from pathlib import Path
@@ -32,8 +32,11 @@ from src.plugin_runtime.protocol.envelope import (
HealthPayload,
InvokePayload,
InvokeResultPayload,
- RegisterComponentsPayload,
+ RegisterPluginPayload,
+ ReloadPluginPayload,
+ ReloadPluginResultPayload,
RunnerReadyPayload,
+ UnregisterPluginPayload,
)
from src.plugin_runtime.protocol.errors import ErrorCode
from src.plugin_runtime.runner.log_handler import RunnerIPCLogHandler
@@ -44,7 +47,8 @@ logger = get_logger("plugin_runtime.runner.main")
class _ContextAwarePlugin(Protocol):
- def _set_context(self, context: Any) -> None: ...
+ def _set_context(self, context: Any) -> None:
+ """为插件注入上下文对象。"""
def _install_shutdown_signal_handlers(
@@ -90,21 +94,29 @@ class PluginRunner:
session_token: str,
plugin_dirs: List[str],
) -> None:
+ """初始化 Runner。
+
+ Args:
+ host_address: Host 的 IPC 地址。
+ session_token: 握手用会话令牌。
+ plugin_dirs: 当前 Runner 负责扫描的插件目录列表。
+ """
self._host_address: str = host_address
self._session_token: str = session_token
- self._plugin_dirs: list[str] = plugin_dirs
+ self._plugin_dirs: List[str] = plugin_dirs
self._rpc_client: RPCClient = RPCClient(host_address, session_token)
self._loader: PluginLoader = PluginLoader(host_version=os.getenv(ENV_HOST_VERSION, ""))
self._start_time: float = time.monotonic()
self._shutting_down: bool = False
+ self._reload_lock: asyncio.Lock = asyncio.Lock()
# IPC 日志 Handler:握手成功后安装,将所有 stdlib logging 转发到 Host
self._log_handler: Optional[RunnerIPCLogHandler] = None
- self._suspended_console_handlers: list[stdlib_logging.Handler] = []
+ self._suspended_console_handlers: List[stdlib_logging.Handler] = []
async def run(self) -> None:
- """Runner 主入口"""
+ """运行 Runner 主循环。"""
# 1. 连接 Host
logger.info(f"Runner 启动,连接 Host: {self._host_address}")
ok = await self._rpc_client.connect_and_handshake()
@@ -123,32 +135,11 @@ class PluginRunner:
logger.info(f"已加载 {len(plugins)} 个插件")
# 4. 注入 PluginContext + 调用 on_load 生命周期钩子
- failed_plugins: set[str] = set()
+ failed_plugins: Set[str] = set(self._loader.failed_plugins.keys())
for meta in plugins:
- instance = meta.instance
- self._inject_context(meta.plugin_id, instance)
- self._apply_plugin_config(meta)
- if not await self._bootstrap_plugin(meta):
- failed_plugins.add(meta.plugin_id)
- continue
- if hasattr(instance, "on_load"):
- try:
- ret = instance.on_load()
- if asyncio.iscoroutine(ret):
- await ret
- except Exception as e:
- logger.error(f"插件 {meta.plugin_id} on_load 失败,跳过注册: {e}", exc_info=True)
- failed_plugins.add(meta.plugin_id)
- await self._deactivate_plugin(meta)
-
- # 5. 向 Host 注册所有插件的组件(跳过 on_load 失败的插件)
- for meta in plugins:
- if meta.plugin_id in failed_plugins:
- continue
- ok = await self._register_plugin(meta)
+ ok = await self._activate_plugin(meta)
if not ok:
failed_plugins.add(meta.plugin_id)
- await self._deactivate_plugin(meta)
successful_plugins = [meta.plugin_id for meta in plugins if meta.plugin_id not in failed_plugins]
await self._notify_ready(successful_plugins, sorted(failed_plugins))
@@ -232,7 +223,9 @@ class PluginRunner:
bound_plugin_id = plugin_id
async def _rpc_call(
- method: str, plugin_id: str = "", payload: Optional[dict[str, Any]] = None
+ method: str,
+ plugin_id: str = "",
+ payload: Optional[Dict[str, Any]] = None,
) -> Any:
"""桥接 PluginContext.call_capability → RPCClient.send_request。
@@ -257,7 +250,7 @@ class PluginRunner:
cast(_ContextAwarePlugin, instance)._set_context(ctx)
logger.debug(f"已为插件 {plugin_id} 注入 PluginContext")
- def _apply_plugin_config(self, meta: PluginMeta, config_data: Optional[dict[str, Any]] = None) -> None:
+ def _apply_plugin_config(self, meta: PluginMeta, config_data: Optional[Dict[str, Any]] = None) -> None:
"""在 Runner 侧为插件实例注入当前插件配置。"""
instance = meta.instance
if not hasattr(instance, "set_plugin_config"):
@@ -270,7 +263,7 @@ class PluginRunner:
logger.warning(f"插件 {meta.plugin_id} 配置注入失败: {exc}")
@staticmethod
- def _load_plugin_config(plugin_dir: str) -> dict[str, Any]:
+ def _load_plugin_config(plugin_dir: str) -> Dict[str, Any]:
"""从插件目录读取 config.toml。"""
config_path = Path(plugin_dir) / "config.toml"
if not config_path.exists():
@@ -286,16 +279,18 @@ class PluginRunner:
return loaded if isinstance(loaded, dict) else {}
def _register_handlers(self) -> None:
- """注册方法处理器"""
+ """注册 Host -> Runner 的方法处理器。"""
self._rpc_client.register_method("plugin.invoke_command", self._handle_invoke)
self._rpc_client.register_method("plugin.invoke_action", self._handle_invoke)
self._rpc_client.register_method("plugin.invoke_tool", self._handle_invoke)
self._rpc_client.register_method("plugin.emit_event", self._handle_event_invoke)
+ self._rpc_client.register_method("plugin.invoke_hook", self._handle_hook_invoke)
self._rpc_client.register_method("plugin.invoke_workflow_step", self._handle_workflow_step)
self._rpc_client.register_method("plugin.health", self._handle_health)
self._rpc_client.register_method("plugin.prepare_shutdown", self._handle_prepare_shutdown)
self._rpc_client.register_method("plugin.shutdown", self._handle_shutdown)
self._rpc_client.register_method("plugin.config_updated", self._handle_config_updated)
+ self._rpc_client.register_method("plugin.reload", self._handle_reload_plugin)
async def _bootstrap_plugin(self, meta: PluginMeta, capabilities_required: Optional[List[str]] = None) -> bool:
"""向 Host 同步插件 bootstrap 能力令牌。"""
@@ -324,7 +319,14 @@ class PluginRunner:
await self._bootstrap_plugin(meta, capabilities_required=[])
async def _register_plugin(self, meta: PluginMeta) -> bool:
- """向 Host 注册单个插件"""
+ """向 Host 注册单个插件。
+
+ Args:
+ meta: 待注册的插件元数据。
+
+ Returns:
+ bool: 是否注册成功。
+ """
# 收集插件组件声明
components: List[ComponentDeclaration] = []
instance = meta.instance
@@ -341,7 +343,7 @@ class PluginRunner:
for comp_info in instance.get_components()
)
- reg_payload = RegisterComponentsPayload(
+ reg_payload = RegisterPluginPayload(
plugin_id=meta.plugin_id,
plugin_version=meta.version,
components=components,
@@ -361,8 +363,281 @@ class PluginRunner:
logger.error(f"插件 {meta.plugin_id} 注册失败: {e}")
return False
+ async def _unregister_plugin(self, plugin_id: str, reason: str) -> None:
+ """通知 Host 注销指定插件。
+
+ Args:
+ plugin_id: 目标插件 ID。
+ reason: 注销原因。
+ """
+ payload = UnregisterPluginPayload(plugin_id=plugin_id, reason=reason)
+ try:
+ await self._rpc_client.send_request(
+ "plugin.unregister",
+ plugin_id=plugin_id,
+ payload=payload.model_dump(),
+ timeout_ms=10000,
+ )
+ except Exception as exc:
+ logger.warning(f"插件 {plugin_id} 注销通知失败: {exc}")
+
+ async def _invoke_plugin_on_load(self, meta: PluginMeta) -> bool:
+ """执行插件的 ``on_load`` 生命周期。
+
+ Args:
+ meta: 待初始化的插件元数据。
+
+ Returns:
+ bool: 生命周期是否执行成功。
+ """
+ instance = meta.instance
+ if not hasattr(instance, "on_load"):
+ return True
+
+ try:
+ result = instance.on_load()
+ if asyncio.iscoroutine(result):
+ await result
+ return True
+ except Exception as exc:
+ logger.error(f"插件 {meta.plugin_id} on_load 失败: {exc}", exc_info=True)
+ return False
+
+ async def _invoke_plugin_on_unload(self, meta: PluginMeta) -> None:
+ """执行插件的 ``on_unload`` 生命周期。
+
+ Args:
+ meta: 待卸载的插件元数据。
+ """
+ instance = meta.instance
+ if not hasattr(instance, "on_unload"):
+ return
+
+ try:
+ result = instance.on_unload()
+ if asyncio.iscoroutine(result):
+ await result
+ except Exception as exc:
+ logger.error(f"插件 {meta.plugin_id} on_unload 失败: {exc}", exc_info=True)
+
+ async def _activate_plugin(self, meta: PluginMeta) -> bool:
+ """完成插件注入、授权、生命周期和组件注册。
+
+ Args:
+ meta: 待激活的插件元数据。
+
+ Returns:
+ bool: 是否激活成功。
+ """
+ self._inject_context(meta.plugin_id, meta.instance)
+ self._apply_plugin_config(meta)
+
+ if not await self._bootstrap_plugin(meta):
+ self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
+ return False
+
+ if not await self._invoke_plugin_on_load(meta):
+ await self._deactivate_plugin(meta)
+ self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
+ return False
+
+ if not await self._register_plugin(meta):
+ await self._invoke_plugin_on_unload(meta)
+ await self._deactivate_plugin(meta)
+ self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
+ return False
+
+ self._loader.set_loaded_plugin(meta)
+ return True
+
+ async def _unload_plugin(self, meta: PluginMeta, reason: str) -> None:
+ """卸载单个插件并清理 Host/Runner 两侧状态。
+
+ Args:
+ meta: 待卸载的插件元数据。
+ reason: 卸载原因。
+ """
+ await self._invoke_plugin_on_unload(meta)
+ await self._unregister_plugin(meta.plugin_id, reason)
+ await self._deactivate_plugin(meta)
+ self._loader.remove_loaded_plugin(meta.plugin_id)
+ self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
+
+ def _collect_reverse_dependents(self, plugin_id: str) -> Set[str]:
+ """收集依赖指定插件的所有已加载插件。
+
+ Args:
+ plugin_id: 根插件 ID。
+
+ Returns:
+ Set[str]: 目标插件及其所有反向依赖插件集合。
+ """
+ impacted_plugins: Set[str] = {plugin_id}
+ changed = True
+
+ while changed:
+ changed = False
+ for loaded_plugin_id in self._loader.list_plugins():
+ if loaded_plugin_id in impacted_plugins:
+ continue
+
+ meta = self._loader.get_plugin(loaded_plugin_id)
+ if meta is None:
+ continue
+
+ if any(dependency in impacted_plugins for dependency in meta.dependencies):
+ impacted_plugins.add(loaded_plugin_id)
+ changed = True
+
+ return impacted_plugins
+
+ def _build_unload_order(self, plugin_ids: Set[str]) -> List[str]:
+ """构建受影响插件的卸载顺序。
+
+ Args:
+ plugin_ids: 需要卸载的插件集合。
+
+ Returns:
+ List[str]: 依赖方优先的卸载顺序。
+ """
+ dependency_graph: Dict[str, Set[str]] = {}
+ for plugin_id in plugin_ids:
+ meta = self._loader.get_plugin(plugin_id)
+ if meta is None:
+ dependency_graph[plugin_id] = set()
+ continue
+ dependency_graph[plugin_id] = {dependency for dependency in meta.dependencies if dependency in plugin_ids}
+
+ indegree: Dict[str, int] = {plugin_id: len(dependencies) for plugin_id, dependencies in dependency_graph.items()}
+ reverse_graph: Dict[str, Set[str]] = {plugin_id: set() for plugin_id in dependency_graph}
+
+ for plugin_id, dependencies in dependency_graph.items():
+ for dependency in dependencies:
+ reverse_graph.setdefault(dependency, set()).add(plugin_id)
+
+ queue: List[str] = sorted(plugin_id for plugin_id, degree in indegree.items() if degree == 0)
+ load_order: List[str] = []
+
+ while queue:
+ current_plugin_id = queue.pop(0)
+ load_order.append(current_plugin_id)
+ for dependent_plugin_id in sorted(reverse_graph.get(current_plugin_id, set())):
+ indegree[dependent_plugin_id] -= 1
+ if indegree[dependent_plugin_id] == 0:
+ queue.append(dependent_plugin_id)
+ queue.sort()
+
+ return list(reversed(load_order))
+
+ async def _reload_plugin_by_id(self, plugin_id: str, reason: str) -> ReloadPluginResultPayload:
+ """按插件 ID 在 Runner 进程内执行精确重载。
+
+ Args:
+ plugin_id: 目标插件 ID。
+ reason: 重载原因。
+
+ Returns:
+ ReloadPluginResultPayload: 结构化重载结果。
+ """
+ candidates, duplicate_candidates = self._loader.discover_candidates(self._plugin_dirs)
+ failed_plugins: Dict[str, str] = {}
+
+ if plugin_id in duplicate_candidates:
+ conflict_paths = ", ".join(str(path) for path in duplicate_candidates[plugin_id])
+ return ReloadPluginResultPayload(
+ success=False,
+ requested_plugin_id=plugin_id,
+ failed_plugins={plugin_id: f"检测到重复插件 ID: {conflict_paths}"},
+ )
+
+ loaded_plugin_ids = set(self._loader.list_plugins())
+ plugin_is_loaded = plugin_id in loaded_plugin_ids
+ plugin_has_candidate = plugin_id in candidates
+
+ if not plugin_is_loaded and not plugin_has_candidate:
+ return ReloadPluginResultPayload(
+ success=False,
+ requested_plugin_id=plugin_id,
+ failed_plugins={plugin_id: "插件不存在或未找到合法的 manifest/plugin.py"},
+ )
+
+ target_plugin_ids: Set[str] = {plugin_id}
+ if plugin_is_loaded:
+ target_plugin_ids = self._collect_reverse_dependents(plugin_id)
+
+ unload_order = self._build_unload_order(target_plugin_ids & loaded_plugin_ids)
+ unloaded_plugins: List[str] = []
+ retained_plugin_ids = loaded_plugin_ids - set(unload_order)
+
+ for unload_plugin_id in unload_order:
+ meta = self._loader.get_plugin(unload_plugin_id)
+ if meta is None:
+ continue
+ await self._unload_plugin(meta, reason=reason)
+ unloaded_plugins.append(unload_plugin_id)
+
+ reload_candidates: Dict[str, Tuple[Path, Dict[str, Any], Path]] = {}
+ for target_plugin_id in target_plugin_ids:
+ candidate = candidates.get(target_plugin_id)
+ if candidate is None:
+ failed_plugins[target_plugin_id] = "插件目录已不存在,已保持卸载状态"
+ continue
+ reload_candidates[target_plugin_id] = candidate
+
+ load_order, dependency_failures = self._loader.resolve_dependencies(
+ reload_candidates,
+ extra_available=retained_plugin_ids,
+ )
+ failed_plugins.update(dependency_failures)
+
+ available_plugins = set(retained_plugin_ids)
+ reloaded_plugins: List[str] = []
+
+ for load_plugin_id in load_order:
+ if load_plugin_id in failed_plugins:
+ continue
+
+ candidate = reload_candidates.get(load_plugin_id)
+ if candidate is None:
+ continue
+
+ _, manifest, _ = candidate
+ dependencies = PluginMeta._extract_dependencies(manifest)
+ missing_dependencies = [dependency for dependency in dependencies if dependency not in available_plugins]
+ if missing_dependencies:
+ failed_plugins[load_plugin_id] = f"依赖未满足: {', '.join(missing_dependencies)}"
+ continue
+
+ meta = self._loader.load_candidate(load_plugin_id, candidate)
+ if meta is None:
+ failed_plugins[load_plugin_id] = "插件模块加载失败"
+ continue
+
+ activated = await self._activate_plugin(meta)
+ if not activated:
+ failed_plugins[load_plugin_id] = "插件初始化失败"
+ continue
+
+ available_plugins.add(load_plugin_id)
+ reloaded_plugins.append(load_plugin_id)
+
+ requested_plugin_success = plugin_id in reloaded_plugins and not failed_plugins
+
+ return ReloadPluginResultPayload(
+ success=requested_plugin_success,
+ requested_plugin_id=plugin_id,
+ reloaded_plugins=reloaded_plugins,
+ unloaded_plugins=unloaded_plugins,
+ failed_plugins=failed_plugins,
+ )
+
async def _notify_ready(self, loaded_plugins: List[str], failed_plugins: List[str]) -> None:
- """通知 Host 当前 generation 已完成插件初始化。"""
+ """通知 Host 当前 Runner 已完成插件初始化。
+
+ Args:
+ loaded_plugins: 成功初始化的插件列表。
+ failed_plugins: 初始化失败的插件列表。
+ """
payload = RunnerReadyPayload(
loaded_plugins=loaded_plugins,
failed_plugins=failed_plugins,
@@ -487,6 +762,61 @@ class PluginRunner:
logger.error(f"插件 {plugin_id} event_handler {component_name} 执行异常: {e}", exc_info=True)
return envelope.make_response(payload={"success": False, "continue_processing": True})
+ async def _handle_hook_invoke(self, envelope: Envelope) -> Envelope:
+ """处理 HookHandler 调用请求。
+
+ Args:
+ envelope: RPC 请求信封。
+
+ Returns:
+ Envelope: 标准化后的 Hook 调用结果。
+ """
+ try:
+ invoke = InvokePayload.model_validate(envelope.payload)
+ except Exception as exc:
+ return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
+
+ plugin_id = envelope.plugin_id
+ meta = self._loader.get_plugin(plugin_id)
+ if meta is None:
+ return envelope.make_error_response(
+ ErrorCode.E_PLUGIN_NOT_FOUND.value,
+ f"插件 {plugin_id} 未加载",
+ )
+
+ instance = meta.instance
+ component_name = invoke.component_name
+ handler_method = getattr(instance, f"handle_{component_name}", None) or getattr(instance, component_name, None)
+ if handler_method is None or not callable(handler_method):
+ return envelope.make_error_response(
+ ErrorCode.E_METHOD_NOT_ALLOWED.value,
+ f"插件 {plugin_id} 无组件: {component_name}",
+ )
+
+ try:
+ raw = (
+ await handler_method(**invoke.args)
+ if inspect.iscoroutinefunction(handler_method)
+ else handler_method(**invoke.args)
+ )
+ except Exception as exc:
+ logger.error(f"插件 {plugin_id} hook_handler {component_name} 执行异常: {exc}", exc_info=True)
+ return envelope.make_response(payload={"success": False, "continue_processing": True})
+
+ if raw is None:
+ result = {"success": True, "continue_processing": True}
+ elif isinstance(raw, dict):
+ result = {
+ "success": True,
+ "continue_processing": raw.get("continue_processing", True),
+ "modified_kwargs": raw.get("modified_kwargs"),
+ "custom_result": raw.get("custom_result"),
+ }
+ else:
+ result = {"success": True, "continue_processing": True, "custom_result": raw}
+
+ return envelope.make_response(payload=result)
+
async def _handle_workflow_step(self, envelope: Envelope) -> Envelope:
"""处理 WorkflowStep 调用请求
@@ -557,15 +887,10 @@ class PluginRunner:
async def _handle_shutdown(self, envelope: Envelope) -> Envelope:
"""处理关停 — 调用所有插件的 on_unload 后退出"""
logger.info("收到 shutdown 信号,开始调用 on_unload")
- for plugin_id in self._loader.list_plugins():
+ for plugin_id in list(self._loader.list_plugins()):
meta = self._loader.get_plugin(plugin_id)
- if meta and hasattr(meta.instance, "on_unload"):
- try:
- ret = meta.instance.on_unload()
- if asyncio.iscoroutine(ret):
- await ret
- except Exception as e:
- logger.error(f"插件 {plugin_id} on_unload 失败: {e}", exc_info=True)
+ if meta is not None:
+ await self._unload_plugin(meta, reason="runner_shutdown")
self._shutting_down = True
return envelope.make_response(payload={"acknowledged": True})
@@ -587,6 +912,30 @@ class PluginRunner:
return envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
return envelope.make_response(payload={"acknowledged": True})
+ async def _handle_reload_plugin(self, envelope: Envelope) -> Envelope:
+ """处理按插件 ID 的精确重载请求。
+
+ Args:
+ envelope: RPC 请求信封。
+
+ Returns:
+ Envelope: 结构化重载结果。
+ """
+ try:
+ payload = ReloadPluginPayload.model_validate(envelope.payload)
+ except Exception as exc:
+ return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
+
+ if self._reload_lock.locked():
+ return envelope.make_error_response(
+ ErrorCode.E_RELOAD_IN_PROGRESS.value,
+ f"插件 {payload.plugin_id} 重载请求被拒绝:已有重载任务正在执行",
+ )
+
+ async with self._reload_lock:
+ result = await self._reload_plugin_by_id(payload.plugin_id, payload.reason)
+ return envelope.make_response(payload=result.model_dump())
+
def request_capability(self) -> RPCClient:
"""获取 RPC 客户端(供 SDK 使用,发起能力调用)"""
return self._rpc_client
@@ -652,13 +1001,16 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
_ALLOWED_SRC_PREFIXES = ("src.plugin_runtime", "src.common")
- def find_module(self, fullname, path=None):
+ def find_module(self, fullname: str, path: Any = None) -> Any:
+ """决定是否拦截指定模块导入。"""
return self if self._should_block(fullname) else None
- def load_module(self, fullname):
+ def load_module(self, fullname: str) -> None:
+ """阻止被拦截模块继续导入。"""
raise ImportError(f"Runner 子进程不允许导入主程序模块: {fullname}")
def _should_block(self, fullname: str) -> bool:
+ """判断给定模块名是否应被阻止导入。"""
# 放行非 src.* 的导入、以及 "src" 本身
if not fullname.startswith("src.") or fullname == "src":
return False
@@ -692,6 +1044,7 @@ async def _async_main() -> None:
# 注册信号处理
def _mark_runner_shutting_down() -> None:
+ """标记 Runner 即将进入关停流程。"""
runner._shutting_down = True
_install_shutdown_signal_handlers(_mark_runner_shutting_down)
From 75cd50ee0f2daff6aef647a7b17471731da55992 Mon Sep 17 00:00:00 2001
From: DrSmoothl <1787882683@qq.com>
Date: Fri, 20 Mar 2026 22:35:24 +0800
Subject: [PATCH 18/45] =?UTF-8?q?refactor:=20=E6=9B=B4=E6=96=B0=E8=83=BD?=
=?UTF-8?q?=E5=8A=9B=E5=AE=9E=E7=8E=B0=E6=B3=A8=E5=86=8C=E5=92=8C=E8=AF=B7?=
=?UTF-8?q?=E6=B1=82=E5=A4=84=E7=90=86=EF=BC=8C=E5=A2=9E=E5=BC=BA=E7=B1=BB?=
=?UTF-8?q?=E5=9E=8B=E4=B8=80=E8=87=B4=E6=80=A7?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
src/plugin_runtime/capabilities/registry.py | 9 ++++-----
src/plugin_runtime/host/capability_service.py | 15 +++++----------
src/plugin_runtime/runner/runner_main.py | 7 +++++--
3 files changed, 14 insertions(+), 17 deletions(-)
diff --git a/src/plugin_runtime/capabilities/registry.py b/src/plugin_runtime/capabilities/registry.py
index ead5876a..96b190b4 100644
--- a/src/plugin_runtime/capabilities/registry.py
+++ b/src/plugin_runtime/capabilities/registry.py
@@ -1,6 +1,7 @@
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING
from src.common.logger import get_logger
+from src.plugin_runtime.host.capability_service import CapabilityImpl
from src.plugin_runtime.host.supervisor import PluginSupervisor
if TYPE_CHECKING:
@@ -12,17 +13,15 @@ logger = get_logger("plugin_runtime.integration")
def register_capability_impls(manager: "PluginRuntimeManager", supervisor: PluginSupervisor) -> None:
"""向指定 Supervisor 注册主程序提供的能力实现。"""
cap_service = supervisor.capability_service
- rpc_server = supervisor.rpc_server
- def _register(name: str, impl: Any) -> None:
- """注册单个能力实现及其 RPC 入口。
+ def _register(name: str, impl: CapabilityImpl) -> None:
+ """注册单个能力实现。
Args:
name: 能力名称。
impl: 能力实现函数。
"""
cap_service.register_capability(name, impl)
- rpc_server.register_method(name, cap_service.handle_capability_request)
_register("send.text", manager._cap_send_text)
_register("send.emoji", manager._cap_send_emoji)
diff --git a/src/plugin_runtime/host/capability_service.py b/src/plugin_runtime/host/capability_service.py
index 761b20ca..0ff31fe1 100644
--- a/src/plugin_runtime/host/capability_service.py
+++ b/src/plugin_runtime/host/capability_service.py
@@ -56,19 +56,14 @@ class CapabilityService:
校验权限后调用对应实现。
"""
plugin_id = envelope.plugin_id
- payload = envelope.payload if isinstance(envelope.payload, dict) else {}
try:
- req = CapabilityRequestPayload.model_validate(payload)
- capability = req.capability
- args = req.args
- except Exception:
- capability = envelope.method
- raw_args = payload.get("args", payload)
- args = raw_args if isinstance(raw_args, dict) else {}
+ req = CapabilityRequestPayload.model_validate(envelope.payload)
+ except Exception as exc:
+ return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, f"能力调用 payload 非法: {exc}")
- if not capability:
- return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, "能力调用缺少 capability")
+ capability = req.capability
+ args = req.args
# 1. 权限校验
allowed, reason = self._authorization.check_capability(plugin_id, capability)
diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py
index 771e685f..bf36a05c 100644
--- a/src/plugin_runtime/runner/runner_main.py
+++ b/src/plugin_runtime/runner/runner_main.py
@@ -237,9 +237,12 @@ class PluginRunner:
f"插件 {bound_plugin_id} 尝试以 {plugin_id} 身份发起 RPC,已强制绑定回自身身份"
)
resp = await rpc_client.send_request(
- method=method,
+ method="cap.call",
plugin_id=bound_plugin_id,
- payload=payload or {},
+ payload={
+ "capability": method,
+ "args": payload or {},
+ },
)
# 从响应信封中提取业务结果
if resp.error:
From 85f060621d7b48f40de790a0d366c8d8be67c2c2 Mon Sep 17 00:00:00 2001
From: DrSmoothl <1787882683@qq.com>
Date: Sat, 21 Mar 2026 00:18:28 +0800
Subject: [PATCH 19/45] feat: Add NapCat adapter plugin and enhance message
handling
- Introduced a built-in NapCat adapter plugin for MVP message forwarding.
- Implemented core functionalities for connecting to NapCat/OneBot v11 WebSocket service.
- Added message serialization capabilities for WebUI chat routes.
- Enhanced the RegisterPluginPayload to include optional adapter declarations.
- Implemented methods for handling external messages and adapter declarations in the PluginRunner.
- Improved the send_service to inherit platform IO route metadata for outgoing messages.
---
.../message_receive/uni_message_sender.py | 173 ++---
src/chat/replyer/group_generator.py | 5 +-
src/chat/replyer/private_generator.py | 4 +-
src/platform_io/drivers/plugin_driver.py | 154 +++-
src/plugin_runtime/host/message_gateway.py | 102 +--
src/plugin_runtime/host/message_utils.py | 3 +
src/plugin_runtime/host/supervisor.py | 265 +++++++
src/plugin_runtime/integration.py | 92 ++-
src/plugin_runtime/protocol/envelope.py | 48 +-
src/plugin_runtime/runner/runner_main.py | 55 +-
.../built_in/napcat_adapter/_manifest.json | 30 +
src/plugins/built_in/napcat_adapter/plugin.py | 690 ++++++++++++++++++
src/services/send_service.py | 66 +-
src/webui/routers/chat/serializers.py | 175 +++++
14 files changed, 1683 insertions(+), 179 deletions(-)
create mode 100644 src/plugins/built_in/napcat_adapter/_manifest.json
create mode 100644 src/plugins/built_in/napcat_adapter/plugin.py
create mode 100644 src/webui/routers/chat/serializers.py
diff --git a/src/chat/message_receive/uni_message_sender.py b/src/chat/message_receive/uni_message_sender.py
index 894af238..17d5d6d5 100644
--- a/src/chat/message_receive/uni_message_sender.py
+++ b/src/chat/message_receive/uni_message_sender.py
@@ -1,31 +1,37 @@
-from rich.traceback import install
-from typing import Optional
+from typing import Any, Optional, Tuple
import asyncio
+import traceback
+from rich.traceback import install
-from src.common.message_server.api import get_global_api
-from src.common.logger import get_logger
-from src.common.database.database import get_db_session
from src.chat.message_receive.message import SessionMessage
+from src.chat.utils.utils import calculate_typing_time, truncate_message
from src.common.data_models.message_component_data_model import ReplyComponent
-from src.chat.utils.utils import truncate_message
-from src.chat.utils.utils import calculate_typing_time
+from src.common.database.database import get_db_session
+from src.common.logger import get_logger
+from src.common.message_server.api import get_global_api
+from src.webui.routers.chat.serializers import serialize_message_sequence
install(extra_lines=3)
logger = get_logger("sender")
# WebUI 聊天室的消息广播器(延迟导入避免循环依赖)
-_webui_chat_broadcaster = None
+_webui_chat_broadcaster: Optional[Tuple[Any, Optional[str]]] = None
# 虚拟群 ID 前缀(与 chat_routes.py 保持一致)
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
# TODO: 重构完成后完成webui相关
-def get_webui_chat_broadcaster():
- """获取 WebUI 聊天室广播器"""
+def get_webui_chat_broadcaster() -> Tuple[Any, Optional[str]]:
+ """获取 WebUI 聊天室广播器。
+
+ Returns:
+ Tuple[Any, Optional[str]]: ``(chat_manager, platform_name)`` 二元组;
+ 若 WebUI 相关模块不可用,则元素会退化为 ``None``。
+ """
global _webui_chat_broadcaster
if _webui_chat_broadcaster is None:
try:
@@ -38,102 +44,36 @@ def get_webui_chat_broadcaster():
def is_webui_virtual_group(group_id: str) -> bool:
- """检查是否是 WebUI 虚拟群"""
- return group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX)
-
-
-def parse_message_segments(segment) -> list:
- """解析消息段,转换为 WebUI 可用的格式
-
- 参考 NapCat 适配器的消息解析逻辑
+ """检查是否是 WebUI 虚拟群。
Args:
- segment: Seg 消息段对象
+ group_id: 待判断的群 ID。
Returns:
- list: 消息段列表,每个元素为 {"type": "...", "data": ...}
+ bool: 若群 ID 属于 WebUI 虚拟群则返回 ``True``。
"""
-
- result = []
-
- if segment is None:
- return result
-
- if segment.type == "seglist":
- # 处理消息段列表
- if segment.data:
- for seg in segment.data:
- result.extend(parse_message_segments(seg))
- elif segment.type == "text":
- # 文本消息
- if segment.data:
- result.append({"type": "text", "data": segment.data})
- elif segment.type == "image":
- # 图片消息(base64)
- if segment.data:
- result.append({"type": "image", "data": f"data:image/png;base64,{segment.data}"})
- elif segment.type == "emoji":
- # 表情包消息(base64)
- if segment.data:
- result.append({"type": "emoji", "data": f"data:image/gif;base64,{segment.data}"})
- elif segment.type == "imageurl":
- # 图片链接消息
- if segment.data:
- result.append({"type": "image", "data": segment.data})
- elif segment.type == "face":
- # 原生表情
- result.append({"type": "face", "data": segment.data})
- elif segment.type == "voice":
- # 语音消息(base64)
- if segment.data:
- result.append({"type": "voice", "data": f"data:audio/wav;base64,{segment.data}"})
- elif segment.type == "voiceurl":
- # 语音链接
- if segment.data:
- result.append({"type": "voice", "data": segment.data})
- elif segment.type == "video":
- # 视频消息(base64)
- if segment.data:
- result.append({"type": "video", "data": f"data:video/mp4;base64,{segment.data}"})
- elif segment.type == "videourl":
- # 视频链接
- if segment.data:
- result.append({"type": "video", "data": segment.data})
- elif segment.type == "music":
- # 音乐消息
- result.append({"type": "music", "data": segment.data})
- elif segment.type == "file":
- # 文件消息
- result.append({"type": "file", "data": segment.data})
- elif segment.type == "reply":
- # 回复消息
- result.append({"type": "reply", "data": segment.data})
- elif segment.type == "forward":
- # 转发消息
- forward_items = []
- if segment.data:
- for item in segment.data:
- forward_items.append(
- {
- "content": parse_message_segments(item.get("message_segment", {}))
- if isinstance(item, dict)
- else []
- }
- )
- result.append({"type": "forward", "data": forward_items})
- else:
- # 未知类型,尝试作为文本处理
- if segment.data:
- result.append({"type": "unknown", "original_type": segment.type, "data": str(segment.data)})
-
- return result
+ return bool(group_id) and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX)
-async def _send_message(message: MessageSending, show_log=True) -> bool:
- """合并后的消息发送函数,包含WS发送和日志记录"""
+async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
+ """执行统一的消息发送流程。
+
+ 发送顺序为:
+ 1. WebUI 特殊链路
+ 2. Platform IO 适配器链路
+ 3. 旧版 ``maim_message`` / API Server 链路
+
+ Args:
+ message: 待发送的内部会话消息。
+ show_log: 是否输出发送成功日志。
+
+ Returns:
+ bool: 是否最终发送成功。
+ """
message_preview = truncate_message(message.processed_plain_text, max_length=200)
platform = message.platform
- group_id = message.session.group_id
+ group_info = message.message_info.group_info
+ group_id = group_info.group_id if group_info is not None else ""
try:
# 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息
@@ -146,7 +86,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
from src.config.config import global_config
# 解析消息段,获取富文本内容
- message_segments = parse_message_segments(message.message_segment)
+ message_segments = serialize_message_sequence(message.raw_message)
# 判断消息类型
# 如果只有一个文本段,使用简单的 text 类型
@@ -184,8 +124,38 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
logger.info(f"已将消息 '{message_preview}' 发往 WebUI 聊天室")
return True
+ try:
+ from src.platform_io import DeliveryStatus
+ from src.plugin_runtime.integration import get_plugin_runtime_manager
+
+ receipt = await get_plugin_runtime_manager().try_send_message_via_platform_io(message)
+ if receipt is not None:
+ if receipt.status == DeliveryStatus.SENT:
+ if show_log:
+ logger.info(
+ f"已通过 Platform IO 将消息 '{message_preview}' 发往平台'{platform}' "
+ f"(driver: {receipt.driver_id or 'unknown'})"
+ )
+ return True
+
+ logger.warning(
+ f"Platform IO 发送失败: platform={platform} driver={receipt.driver_id} "
+ f"status={receipt.status} error={receipt.error}"
+ )
+ return False
+ except Exception as exc:
+ logger.warning(f"检查 Platform IO 出站链路时出现异常,将回退旧发送链: {exc}")
+
# Fallback 逻辑: 尝试通过 API Server 发送
- async def send_with_new_api(legacy_exception=None):
+ async def send_with_new_api(legacy_exception: Optional[Exception] = None) -> bool:
+ """通过 API Server 回退链路发送消息。
+
+ Args:
+ legacy_exception: 旧发送链已经抛出的异常;若回退也失败,则重新抛出。
+
+ Returns:
+ bool: 回退链路是否发送成功。
+ """
try:
from src.config.config import global_config
@@ -289,7 +259,8 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
class UniversalMessageSender:
"""管理消息的注册、即时处理、发送和存储,并跟踪思考状态。"""
- def __init__(self):
+ def __init__(self) -> None:
+ """初始化统一消息发送器。"""
pass
async def send_message(
@@ -300,7 +271,7 @@ class UniversalMessageSender:
reply_message_id: Optional[str] = None,
storage_message: bool = True,
show_log: bool = True,
- ):
+ ) -> bool:
"""
处理、发送并存储一条消息。
diff --git a/src/chat/replyer/group_generator.py b/src/chat/replyer/group_generator.py
index e10aa147..75563df7 100644
--- a/src/chat/replyer/group_generator.py
+++ b/src/chat/replyer/group_generator.py
@@ -1129,7 +1129,10 @@ class DefaultReplyer:
user_id=bot_user_id,
user_nickname=global_config.bot.nickname,
),
- additional_config={},
+ additional_config={
+ "platform_io_target_group_id": self.chat_stream.group_id,
+ "platform_io_target_user_id": self.chat_stream.user_id,
+ },
),
message_segment=message_segment,
)
diff --git a/src/chat/replyer/private_generator.py b/src/chat/replyer/private_generator.py
index ccd8e1e4..f642dd69 100644
--- a/src/chat/replyer/private_generator.py
+++ b/src/chat/replyer/private_generator.py
@@ -970,7 +970,9 @@ class PrivateReplyer:
user_nickname=global_config.bot.nickname,
),
group_info=None,
- additional_config={},
+ additional_config={
+ "platform_io_target_user_id": self.chat_stream.user_id,
+ },
),
message_segment=message_segment,
)
diff --git a/src/platform_io/drivers/plugin_driver.py b/src/platform_io/drivers/plugin_driver.py
index 9c139309..dff980f8 100644
--- a/src/platform_io/drivers/plugin_driver.py
+++ b/src/platform_io/drivers/plugin_driver.py
@@ -1,34 +1,51 @@
-"""提供 Platform IO 的 plugin 传输驱动骨架。"""
+"""提供 Platform IO 的插件适配器驱动实现。"""
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, Optional, Protocol
from src.platform_io.drivers.base import PlatformIODriver
-from src.platform_io.types import DeliveryReceipt, DriverDescriptor, DriverKind, RouteKey
+from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, RouteKey
if TYPE_CHECKING:
from src.chat.message_receive.message import SessionMessage
+class _AdapterSupervisorProtocol(Protocol):
+ """适配器驱动依赖的 Supervisor 最小协议。"""
+
+ async def invoke_adapter(
+ self,
+ plugin_id: str,
+ method_name: str,
+ args: Optional[Dict[str, Any]] = None,
+ timeout_ms: int = 30000,
+ ) -> Any:
+ """调用适配器插件专用方法。"""
+
+
class PluginPlatformDriver(PlatformIODriver):
- """面向 ``MessageGateway`` 插件链路的 Platform IO 驱动骨架。"""
+ """面向适配器插件链路的 Platform IO 驱动。"""
def __init__(
self,
driver_id: str,
platform: str,
+ supervisor: _AdapterSupervisorProtocol,
+ send_method: str = "send_to_platform",
account_id: Optional[str] = None,
scope: Optional[str] = None,
plugin_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> None:
- """初始化一个 plugin 驱动描述对象。
+ """初始化一个插件适配器驱动。
Args:
driver_id: Broker 内的唯一驱动 ID。
- platform: 该 plugin 适配器链路负责的平台。
+ platform: 该适配器负责的平台名称。
+ supervisor: 持有该适配器插件的 Supervisor。
+ send_method: 出站发送时要调用的插件方法名。
account_id: 可选的账号 ID 或 self ID。
scope: 可选的额外路由作用域。
- plugin_id: 拥有该适配器实现的插件 ID,可为空。
+ plugin_id: 拥有该适配器实现的插件 ID。
metadata: 可选的额外驱动元数据。
"""
descriptor = DriverDescriptor(
@@ -41,6 +58,8 @@ class PluginPlatformDriver(PlatformIODriver):
metadata=metadata or {},
)
super().__init__(descriptor)
+ self._supervisor = supervisor
+ self._send_method = send_method
async def send_message(
self,
@@ -48,7 +67,7 @@ class PluginPlatformDriver(PlatformIODriver):
route_key: RouteKey,
metadata: Optional[Dict[str, Any]] = None,
) -> DeliveryReceipt:
- """通过 plugin 传输路径发送消息。
+ """通过适配器插件发送消息。
Args:
message: 要投递的内部会话消息。
@@ -57,8 +76,119 @@ class PluginPlatformDriver(PlatformIODriver):
Returns:
DeliveryReceipt: 由驱动返回的规范化回执。
-
- Raises:
- NotImplementedError: 当前仍处于骨架阶段,尚未真正接入 MessageGateway。
"""
- raise NotImplementedError("PluginPlatformDriver 仅完成地基实现,尚未接入 MessageGateway")
+ from src.plugin_runtime.host.message_utils import PluginMessageUtils
+
+ plugin_id = self.descriptor.plugin_id or ""
+ if not plugin_id:
+ return DeliveryReceipt(
+ internal_message_id=message.message_id,
+ route_key=route_key,
+ status=DeliveryStatus.FAILED,
+ driver_id=self.driver_id,
+ driver_kind=self.descriptor.kind,
+ error="插件适配器驱动缺少 plugin_id",
+ )
+
+ try:
+ message_dict = PluginMessageUtils._session_message_to_dict(message)
+ response = await self._supervisor.invoke_adapter(
+ plugin_id=plugin_id,
+ method_name=self._send_method,
+ args={
+ "message": message_dict,
+ "route": {
+ "platform": route_key.platform,
+ "account_id": route_key.account_id,
+ "scope": route_key.scope,
+ },
+ "metadata": metadata or {},
+ },
+ timeout_ms=30000,
+ )
+ except Exception as exc:
+ return DeliveryReceipt(
+ internal_message_id=message.message_id,
+ route_key=route_key,
+ status=DeliveryStatus.FAILED,
+ driver_id=self.driver_id,
+ driver_kind=self.descriptor.kind,
+ error=str(exc),
+ )
+
+ return self._build_receipt(message.message_id, route_key, response)
+
+ def _build_receipt(self, internal_message_id: str, route_key: RouteKey, response: Any) -> DeliveryReceipt:
+ """将适配器调用响应归一化为出站回执。
+
+ Args:
+ internal_message_id: 内部消息 ID。
+ route_key: 本次投递的路由键。
+ response: Supervisor 返回的 RPC 响应对象。
+
+ Returns:
+ DeliveryReceipt: 标准化后的出站回执。
+ """
+ if getattr(response, "error", None):
+ error = response.error.get("message", "适配器发送失败")
+ return DeliveryReceipt(
+ internal_message_id=internal_message_id,
+ route_key=route_key,
+ status=DeliveryStatus.FAILED,
+ driver_id=self.driver_id,
+ driver_kind=self.descriptor.kind,
+ error=error,
+ )
+
+ payload = getattr(response, "payload", {})
+ invoke_success = bool(payload.get("success", False)) if isinstance(payload, dict) else False
+ if not invoke_success:
+ return DeliveryReceipt(
+ internal_message_id=internal_message_id,
+ route_key=route_key,
+ status=DeliveryStatus.FAILED,
+ driver_id=self.driver_id,
+ driver_kind=self.descriptor.kind,
+ error=str(payload.get("result", "适配器发送失败")) if isinstance(payload, dict) else "适配器发送失败",
+ )
+
+ result = payload.get("result") if isinstance(payload, dict) else None
+ if isinstance(result, dict):
+ if result.get("success") is False:
+ return DeliveryReceipt(
+ internal_message_id=internal_message_id,
+ route_key=route_key,
+ status=DeliveryStatus.FAILED,
+ driver_id=self.driver_id,
+ driver_kind=self.descriptor.kind,
+ error=str(result.get("error", "适配器发送失败")),
+ metadata=result.get("metadata", {}) if isinstance(result.get("metadata"), dict) else {},
+ )
+ external_message_id = str(result.get("external_message_id") or result.get("message_id") or "") or None
+ return DeliveryReceipt(
+ internal_message_id=internal_message_id,
+ route_key=route_key,
+ status=DeliveryStatus.SENT,
+ driver_id=self.driver_id,
+ driver_kind=self.descriptor.kind,
+ external_message_id=external_message_id,
+ metadata=result.get("metadata", {}) if isinstance(result.get("metadata"), dict) else {},
+ )
+
+ if isinstance(result, str) and result.strip():
+ return DeliveryReceipt(
+ internal_message_id=internal_message_id,
+ route_key=route_key,
+ status=DeliveryStatus.SENT,
+ driver_id=self.driver_id,
+ driver_kind=self.descriptor.kind,
+ external_message_id=result.strip(),
+ )
+
+ return DeliveryReceipt(
+ internal_message_id=internal_message_id,
+ route_key=route_key,
+ status=DeliveryStatus.SENT,
+ driver_id=self.driver_id,
+ driver_kind=self.descriptor.kind,
+ )
diff --git a/src/plugin_runtime/host/message_gateway.py b/src/plugin_runtime/host/message_gateway.py
index 43777286..9e8e9be6 100644
--- a/src/plugin_runtime/host/message_gateway.py
+++ b/src/plugin_runtime/host/message_gateway.py
@@ -3,9 +3,11 @@ Message Gateway 模块
适配器专用,用于将其他平台的消息转换为系统内部的消息格式,并将系统消息转换为其他平台的格式。
"""
-from typing import Dict, Any, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Dict
from src.common.logger import get_logger
+from src.platform_io import DeliveryStatus, get_platform_io_manager
+
from .message_utils import PluginMessageUtils
if TYPE_CHECKING:
@@ -17,25 +19,53 @@ logger = get_logger("plugin_runtime.host.message_gateway")
class MessageGateway:
- def __init__(self, component_registry: "ComponentRegistry") -> None:
- self._component_registry = component_registry
+ """Host 侧消息网关包装器。"""
- async def receive_external_message(self, external_message: Dict[str, Any]):
- """
- 接收外部消息,转换为系统内部格式,并返回转换结果
+ def __init__(self, component_registry: "ComponentRegistry") -> None:
+ """初始化消息网关。
Args:
- external_message: 外部消息的字典格式数据
+ component_registry: 组件注册表。
+ """
+ self._component_registry = component_registry
+
+ def build_session_message(self, external_message: Dict[str, Any]) -> "SessionMessage":
+ """将标准消息字典转换为 ``SessionMessage``。
+
+ Args:
+ external_message: 外部消息的字典格式数据。
Returns:
- 转换后的 SessionMessage 对象
+ SessionMessage: 转换后的内部消息对象。
+
+ Raises:
+ ValueError: 消息字典不合法时抛出。
+ """
+ return PluginMessageUtils._build_session_message_from_dict(external_message)
+
+ def build_message_dict(self, internal_message: "SessionMessage") -> Dict[str, Any]:
+ """将 ``SessionMessage`` 转换为标准消息字典。
+
+ Args:
+ internal_message: 内部消息对象。
+
+ Returns:
+ Dict[str, Any]: 供适配器插件消费的标准消息字典。
+ """
+ return dict(PluginMessageUtils._session_message_to_dict(internal_message))
+
+ async def receive_external_message(self, external_message: Dict[str, Any]) -> None:
+ """接收外部消息并送入主消息链。
+
+ Args:
+ external_message: 外部消息的字典格式数据。
"""
- # 使用递归函数将外部消息字典转换为 SessionMessage
try:
- session_message = PluginMessageUtils._build_session_message_from_dict(external_message)
+ session_message = self.build_session_message(external_message)
except Exception as e:
logger.error(f"转换外部消息失败: {e}")
return
+
from src.chat.message_receive.bot import chat_bot
await chat_bot.receive_message(session_message)
@@ -48,46 +78,32 @@ class MessageGateway:
enabled_only: bool = True,
save_to_db: bool = True,
) -> bool:
- """
- 接收系统内部消息,转换为外部格式,并返回转换结果
+ """将内部消息通过 Platform IO 发送到外部平台。
Args:
- internal_message: 系统内部的 SessionMessage 对象
+ internal_message: 系统内部的 ``SessionMessage`` 对象。
+ supervisor: 当前持有该消息网关的 Supervisor。
+ enabled_only: 兼容旧签名的保留参数,当前由 Platform IO 统一裁决。
+ save_to_db: 发送成功后是否写入数据库。
Returns:
- 转换是否成功
+ bool: 是否发送成功。
"""
- try:
- # 将 SessionMessage 转换为字典格式
- message_dict = PluginMessageUtils._session_message_to_dict(internal_message)
- except Exception as e:
- logger.error(f"转换内部消息失败:{e}")
- return False
- gateway_entry = self._component_registry.get_message_gateways(
- internal_message.platform,
- enabled_only=enabled_only,
- session_id=internal_message.session_id,
- )
- if not gateway_entry:
- logger.warning(f"未找到适配平台 {internal_message.platform} 的消息网关组件,无法发送消息到外部平台")
- return False
- args = {"platform": internal_message.platform, "message": message_dict}
- try:
- resp_envelope = await supervisor.invoke_plugin(
- "plugin.emit_event", gateway_entry.plugin_id, gateway_entry.name, args
- )
- logger.debug("信息发送成功")
- except Exception as e:
- logger.error(f"调用消息网关组件失败:{e}")
+ del enabled_only
+ del supervisor
+
+ platform_io_manager = get_platform_io_manager()
+ if not platform_io_manager.is_started:
+ logger.warning("Platform IO 尚未启动,无法通过适配器链路发送消息")
return False
- # 更新为实际id(如果组件返回了新的id)
- actual_message_id = resp_envelope.payload.get("message_id")
- try:
- actual_message_id = str(actual_message_id)
- except Exception:
- actual_message_id = None
- internal_message.message_id = actual_message_id or internal_message.message_id
+ route_key = platform_io_manager.build_route_key_from_message(internal_message)
+ receipt = await platform_io_manager.send_message(internal_message, route_key)
+ if receipt.status != DeliveryStatus.SENT:
+ logger.warning(f"通过适配器链路发送消息失败: {receipt.error or receipt.status}")
+ return False
+
+ internal_message.message_id = receipt.external_message_id or internal_message.message_id
if save_to_db:
try:
from src.common.utils.utils_message import MessageUtils
diff --git a/src/plugin_runtime/host/message_utils.py b/src/plugin_runtime/host/message_utils.py
index 428e3c48..aaebb529 100644
--- a/src/plugin_runtime/host/message_utils.py
+++ b/src/plugin_runtime/host/message_utils.py
@@ -209,6 +209,9 @@ class PluginMessageUtils:
session_message.is_notify = message_dict.get("is_notify", False)
if not isinstance(session_message.is_notify, bool):
session_message.is_notify = False
+ session_message.session_id = message_dict.get("session_id", "")
+ if not isinstance(session_message.session_id, str):
+ session_message.session_id = ""
session_message.reply_to = message_dict.get("reply_to")
if session_message.reply_to is not None and not isinstance(session_message.reply_to, str):
session_message.reply_to = None
diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py
index 5ae3bdee..cdf3d4ee 100644
--- a/src/plugin_runtime/host/supervisor.py
+++ b/src/plugin_runtime/host/supervisor.py
@@ -8,13 +8,20 @@ import sys
from src.common.logger import get_logger
from src.config.config import global_config
+from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, get_platform_io_manager
+from src.platform_io.drivers import PluginPlatformDriver
+from src.platform_io.route_key_factory import RouteKeyFactory
+from src.platform_io.routing import RouteBindingConflictError
from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN
from src.plugin_runtime.protocol.envelope import (
+ AdapterDeclarationPayload,
BootstrapPluginPayload,
ConfigUpdatedPayload,
Envelope,
HealthPayload,
PROTOCOL_VERSION,
+ ReceiveExternalMessagePayload,
+ ReceiveExternalMessageResultPayload,
RegisterPluginPayload,
ReloadPluginResultPayload,
RunnerReadyPayload,
@@ -86,6 +93,7 @@ class PluginRunnerSupervisor:
self._runner_process: Optional[asyncio.subprocess.Process] = None
self._registered_plugins: Dict[str, RegisterPluginPayload] = {}
+ self._registered_adapters: Dict[str, AdapterDeclarationPayload] = {}
self._runner_ready_events: asyncio.Event = asyncio.Event()
self._runner_ready_payloads: RunnerReadyPayload = RunnerReadyPayload()
self._health_task: Optional[asyncio.Task[None]] = None
@@ -257,6 +265,32 @@ class PluginRunnerSupervisor:
timeout_ms,
)
+ async def invoke_adapter(
+ self,
+ plugin_id: str,
+ method_name: str,
+ args: Optional[Dict[str, Any]] = None,
+ timeout_ms: int = 30000,
+ ) -> Envelope:
+ """调用适配器插件的专用方法。
+
+ Args:
+ plugin_id: 目标适配器插件 ID。
+ method_name: 要调用的插件方法名,例如 ``send_to_platform``。
+ args: 传递给插件方法的关键字参数。
+ timeout_ms: RPC 超时时间,单位毫秒。
+
+ Returns:
+ Envelope: Runner 返回的响应信封。
+ """
+ return await self.invoke_plugin(
+ method="plugin.invoke_adapter",
+ plugin_id=plugin_id,
+ component_name=method_name,
+ args=args,
+ timeout_ms=timeout_ms,
+ )
+
async def reload_plugin(self, plugin_id: str, reason: str = "manual") -> bool:
"""按插件 ID 触发精确重载。
@@ -384,6 +418,7 @@ class PluginRunnerSupervisor:
def _register_internal_methods(self) -> None:
"""注册 Host 侧内部 RPC 方法。"""
self._rpc_server.register_method("cap.call", self._capability_service.handle_capability_request)
+ self._rpc_server.register_method("host.receive_external_message", self._handle_receive_external_message)
self._rpc_server.register_method("plugin.bootstrap", self._handle_bootstrap_plugin)
self._rpc_server.register_method("plugin.register_components", self._handle_register_plugin)
self._rpc_server.register_method("plugin.register_plugin", self._handle_register_plugin)
@@ -427,6 +462,17 @@ class PluginRunnerSupervisor:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
self._component_registry.remove_components_by_plugin(payload.plugin_id)
+ if payload.plugin_id in self._registered_adapters:
+ await self._unregister_adapter_driver(payload.plugin_id)
+
+ try:
+ if payload.adapter is not None:
+ await self._register_adapter_driver(payload.plugin_id, payload.adapter)
+ except RouteBindingConflictError as exc:
+ return envelope.make_error_response(ErrorCode.E_METHOD_NOT_ALLOWED.value, str(exc))
+ except Exception as exc:
+ return envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(exc))
+
registered_count = self._component_registry.register_plugin_components(
payload.plugin_id,
[component.model_dump() for component in payload.components],
@@ -438,6 +484,7 @@ class PluginRunnerSupervisor:
"accepted": True,
"plugin_id": payload.plugin_id,
"registered_components": registered_count,
+ "adapter_registered": payload.adapter is not None,
}
)
@@ -458,6 +505,7 @@ class PluginRunnerSupervisor:
removed_components = self._component_registry.remove_components_by_plugin(payload.plugin_id)
self._authorization.revoke_permission_token(payload.plugin_id)
removed_registration = self._registered_plugins.pop(payload.plugin_id, None) is not None
+ await self._unregister_adapter_driver(payload.plugin_id)
return envelope.make_response(
payload={
@@ -469,6 +517,221 @@ class PluginRunnerSupervisor:
}
)
+ @staticmethod
+ def _build_adapter_driver_id(plugin_id: str) -> str:
+ """构造适配器驱动 ID。
+
+ Args:
+ plugin_id: 适配器插件 ID。
+
+ Returns:
+ str: 对应 Platform IO 中的驱动 ID。
+ """
+ return f"adapter:{plugin_id}"
+
+ async def _register_adapter_driver(self, plugin_id: str, adapter: AdapterDeclarationPayload) -> None:
+ """将适配器插件注册到 Platform IO。
+
+ Args:
+ plugin_id: 适配器插件 ID。
+ adapter: 经过校验的适配器声明。
+
+ Raises:
+ ValueError: 适配器路由冲突或驱动注册失败时抛出。
+ """
+ await self._unregister_adapter_driver(plugin_id)
+
+ platform_io_manager = get_platform_io_manager()
+ driver = PluginPlatformDriver(
+ driver_id=self._build_adapter_driver_id(plugin_id),
+ platform=adapter.platform,
+ account_id=adapter.account_id or None,
+ scope=adapter.scope or None,
+ plugin_id=plugin_id,
+ send_method=adapter.send_method,
+ supervisor=self,
+ metadata={
+ "protocol": adapter.protocol,
+ **adapter.metadata,
+ },
+ )
+ binding = RouteBinding(
+ route_key=driver.descriptor.route_key,
+ driver_id=driver.driver_id,
+ driver_kind=DriverKind.PLUGIN,
+ metadata={
+ "plugin_id": plugin_id,
+ "protocol": adapter.protocol,
+ },
+ )
+
+ try:
+ if platform_io_manager.is_started:
+ await platform_io_manager.add_driver(driver)
+ else:
+ platform_io_manager.register_driver(driver)
+ platform_io_manager.bind_route(binding)
+ except Exception:
+ with contextlib.suppress(Exception):
+ if platform_io_manager.is_started:
+ await platform_io_manager.remove_driver(driver.driver_id)
+ else:
+ platform_io_manager.unregister_driver(driver.driver_id)
+ raise
+
+ self._registered_adapters[plugin_id] = adapter
+
+ async def _unregister_adapter_driver(self, plugin_id: str) -> None:
+ """从 Platform IO 注销一个适配器驱动。
+
+ Args:
+ plugin_id: 适配器插件 ID。
+ """
+ platform_io_manager = get_platform_io_manager()
+ driver_id = self._build_adapter_driver_id(plugin_id)
+
+ with contextlib.suppress(Exception):
+ if platform_io_manager.is_started:
+ await platform_io_manager.remove_driver(driver_id)
+ else:
+ platform_io_manager.unregister_driver(driver_id)
+
+ self._registered_adapters.pop(plugin_id, None)
+
+ async def _unregister_all_adapter_drivers(self) -> None:
+ """注销当前 Supervisor 管理的全部适配器驱动。"""
+ plugin_ids = list(self._registered_adapters.keys())
+ for plugin_id in plugin_ids:
+ await self._unregister_adapter_driver(plugin_id)
+
+ @staticmethod
+ def _attach_inbound_route_metadata(
+ session_message: "SessionMessage",
+ route_key: RouteKey,
+ route_metadata: Dict[str, Any],
+ ) -> None:
+ """将入站路由信息写回消息的 ``additional_config``。
+
+ Args:
+ session_message: 已构造好的内部消息对象。
+ route_key: Host 为该消息解析出的标准路由键。
+ route_metadata: 适配器通过 RPC 补充的原始路由辅助元数据。
+ """
+ additional_config = session_message.message_info.additional_config
+ if not isinstance(additional_config, dict):
+ additional_config = {}
+ session_message.message_info.additional_config = additional_config
+
+ for key, value in route_metadata.items():
+ if value is None:
+ continue
+ normalized_value = str(value).strip()
+ if normalized_value:
+ additional_config[key] = value
+
+ if route_key.account_id:
+ additional_config.setdefault("platform_io_account_id", route_key.account_id)
+ if route_key.scope:
+ additional_config.setdefault("platform_io_scope", route_key.scope)
+
+ def _build_inbound_route_key(
+ self,
+ adapter: AdapterDeclarationPayload,
+ message: Dict[str, Any],
+ route_metadata: Dict[str, Any],
+ ) -> RouteKey:
+ """为适配器入站消息构造归一路由键。
+
+ Args:
+ adapter: 当前适配器声明。
+ message: 标准消息字典。
+ route_metadata: 插件补充的路由辅助元数据。
+
+ Returns:
+ RouteKey: 供 Platform IO 使用的规范化路由键。
+
+ Raises:
+ ValueError: 消息平台字段与适配器平台声明不一致时抛出。
+ """
+ message_platform = str(message.get("platform") or adapter.platform).strip()
+ if message_platform != adapter.platform:
+ raise ValueError(
+ f"外部消息平台 {message_platform} 与适配器 {adapter.platform} 不一致"
+ )
+
+ try:
+ route_key = RouteKeyFactory.from_message_dict(message)
+ except Exception:
+ route_key = RouteKey(platform=message_platform)
+
+ route_account_id, route_scope = RouteKeyFactory.extract_components(route_metadata)
+ account_id = route_key.account_id or route_account_id or adapter.account_id or None
+ scope = route_key.scope or route_scope or adapter.scope or None
+ return RouteKey(
+ platform=message_platform,
+ account_id=account_id,
+ scope=scope,
+ )
+
+ async def _handle_receive_external_message(self, envelope: Envelope) -> Envelope:
+ """处理适配器插件上报的外部入站消息。
+
+ Args:
+ envelope: RPC 请求信封。
+
+ Returns:
+ Envelope: 注入结果响应。
+ """
+ try:
+ payload = ReceiveExternalMessagePayload.model_validate(envelope.payload)
+ except Exception as exc:
+ return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
+
+ adapter = self._registered_adapters.get(envelope.plugin_id)
+ if adapter is None:
+ return envelope.make_error_response(
+ ErrorCode.E_METHOD_NOT_ALLOWED.value,
+ f"插件 {envelope.plugin_id} 未声明为适配器,不能注入外部消息",
+ )
+
+ try:
+ route_key = self._build_inbound_route_key(
+ adapter=adapter,
+ message=payload.message,
+ route_metadata=payload.route_metadata,
+ )
+ session_message = self._message_gateway.build_session_message(payload.message)
+ self._attach_inbound_route_metadata(session_message, route_key, payload.route_metadata)
+ except Exception as exc:
+ return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
+
+ platform_io_manager = get_platform_io_manager()
+ accepted = await platform_io_manager.accept_inbound(
+ InboundMessageEnvelope(
+ route_key=route_key,
+ driver_id=self._build_adapter_driver_id(envelope.plugin_id),
+ driver_kind=DriverKind.PLUGIN,
+ external_message_id=payload.external_message_id or str(payload.message.get("message_id") or "") or None,
+ dedupe_key=payload.dedupe_key or None,
+ session_message=session_message,
+ payload=payload.message,
+ metadata={
+ "plugin_id": envelope.plugin_id,
+ "protocol": adapter.protocol,
+ **payload.route_metadata,
+ },
+ )
+ )
+ response = ReceiveExternalMessageResultPayload(
+ accepted=accepted,
+ route_key={
+ "platform": route_key.platform,
+ "account_id": route_key.account_id,
+ "scope": route_key.scope,
+ },
+ )
+ return envelope.make_response(payload=response.model_dump())
+
async def _handle_runner_ready(self, envelope: Envelope) -> Envelope:
"""处理 Runner 就绪通知。
@@ -595,6 +858,7 @@ class PluginRunnerSupervisor:
await self._stderr_drain_task
self._stderr_drain_task = None
+ await self._unregister_all_adapter_drivers()
self._clear_runner_state()
async def _health_check_loop(self) -> None:
@@ -671,6 +935,7 @@ class PluginRunnerSupervisor:
self._authorization.clear()
self._component_registry.clear()
self._registered_plugins.clear()
+ self._registered_adapters.clear()
self._runner_ready_events = asyncio.Event()
self._runner_ready_payloads = RunnerReadyPayload()
diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py
index 730da3e1..30a3c150 100644
--- a/src/plugin_runtime/integration.py
+++ b/src/plugin_runtime/integration.py
@@ -12,11 +12,13 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Dict, Ite
import asyncio
import json
+
import tomlkit
from src.common.logger import get_logger
from src.config.config import global_config
from src.config.file_watcher import FileChange, FileWatcher
+from src.platform_io import DeliveryReceipt, InboundMessageEnvelope, get_platform_io_manager
from src.plugin_runtime.capabilities import (
RuntimeComponentCapabilityMixin,
RuntimeCoreCapabilityMixin,
@@ -57,6 +59,7 @@ class PluginRuntimeManager(
"""
def __init__(self) -> None:
+ """初始化插件运行时管理器。"""
from src.plugin_runtime.host.supervisor import PluginSupervisor
self._builtin_supervisor: Optional[PluginSupervisor] = None
@@ -66,6 +69,22 @@ class PluginRuntimeManager(
self._plugin_source_watcher_subscription_id: Optional[str] = None
self._plugin_config_watcher_subscriptions: Dict[str, Tuple[Path, str]] = {}
+ async def _dispatch_platform_inbound(self, envelope: InboundMessageEnvelope) -> None:
+ """接收 Platform IO 审核后的入站消息并送入主消息链。
+
+ Args:
+ envelope: Platform IO 产出的入站封装。
+ """
+ session_message = envelope.session_message
+ if session_message is None and envelope.payload is not None:
+ session_message = PluginMessageUtils._build_session_message_from_dict(dict(envelope.payload))
+ if session_message is None:
+ raise ValueError("Platform IO 入站封装缺少可用的 SessionMessage 或 payload")
+
+ from src.chat.message_receive.bot import chat_bot
+
+ await chat_bot.receive_message(session_message)
+
# ─── 插件目录 ─────────────────────────────────────────────
@staticmethod
@@ -110,6 +129,8 @@ class PluginRuntimeManager(
logger.info("未找到任何插件目录,跳过插件运行时启动")
return
+ platform_io_manager = get_platform_io_manager()
+
# 从配置读取自定义 IPC socket 路径(留空则自动生成)
socket_path_base = _cfg.ipc_socket_path or None
@@ -134,6 +155,9 @@ class PluginRuntimeManager(
started_supervisors: List[PluginSupervisor] = []
try:
+ platform_io_manager.set_inbound_dispatcher(self._dispatch_platform_inbound)
+ await platform_io_manager.start()
+
if self._builtin_supervisor:
await self._builtin_supervisor.start()
started_supervisors.append(self._builtin_supervisor)
@@ -147,6 +171,11 @@ class PluginRuntimeManager(
logger.error(f"插件运行时启动失败: {e}", exc_info=True)
await self._stop_plugin_file_watcher()
await asyncio.gather(*(sv.stop() for sv in started_supervisors), return_exceptions=True)
+ platform_io_manager.clear_inbound_dispatcher()
+ try:
+ await platform_io_manager.stop()
+ except Exception as platform_io_exc:
+ logger.warning(f"Platform IO 停止失败: {platform_io_exc}")
self._started = False
self._builtin_supervisor = None
self._third_party_supervisor = None
@@ -156,6 +185,7 @@ class PluginRuntimeManager(
if not self._started:
return
+ platform_io_manager = get_platform_io_manager()
await self._stop_plugin_file_watcher()
coroutines: List[Coroutine[Any, Any, None]] = []
@@ -164,11 +194,23 @@ class PluginRuntimeManager(
if self._third_party_supervisor:
coroutines.append(self._third_party_supervisor.stop())
+ stop_errors: List[str] = []
try:
- await asyncio.gather(*coroutines, return_exceptions=True)
- logger.info("插件运行时已停止")
- except Exception as e:
- logger.error(f"插件运行时停止失败: {e}", exc_info=True)
+ results = await asyncio.gather(*coroutines, return_exceptions=True)
+ for result in results:
+ if isinstance(result, Exception):
+ stop_errors.append(str(result))
+
+ platform_io_manager.clear_inbound_dispatcher()
+ try:
+ await platform_io_manager.stop()
+ except Exception as exc:
+ stop_errors.append(f"Platform IO: {exc}")
+
+ if stop_errors:
+ logger.error(f"插件运行时停止过程中存在错误: {'; '.join(stop_errors)}")
+ else:
+ logger.info("插件运行时已停止")
finally:
self._started = False
self._builtin_supervisor = None
@@ -176,6 +218,7 @@ class PluginRuntimeManager(
@property
def is_running(self) -> bool:
+ """返回插件运行时是否处于启动状态。"""
return self._started
@property
@@ -303,6 +346,37 @@ class PluginRuntimeManager(
timeout_ms=timeout_ms,
)
+ async def try_send_message_via_platform_io(
+ self,
+ message: "SessionMessage",
+ ) -> Optional[DeliveryReceipt]:
+ """尝试通过 Platform IO 中间层发送消息。
+
+ Args:
+ message: 待发送的内部会话消息。
+
+ Returns:
+ Optional[DeliveryReceipt]: 若当前消息存在 active 路由,则返回实际发送
+ 结果;若没有可用路由或 Platform IO 尚未启动,则返回 ``None``。
+ """
+ if not self._started:
+ return None
+
+ platform_io_manager = get_platform_io_manager()
+ if not platform_io_manager.is_started:
+ return None
+
+ try:
+ route_key = platform_io_manager.build_route_key_from_message(message)
+ except Exception as exc:
+ logger.warning(f"根据消息构造 Platform IO 路由键失败: {exc}")
+ return None
+
+ if platform_io_manager.resolve_driver(route_key) is None:
+ return None
+
+ return await platform_io_manager.send_message(message, route_key)
+
def _get_supervisors_for_plugin(self, plugin_id: str) -> List["PluginSupervisor"]:
"""返回当前持有指定插件的所有 Supervisor。
@@ -426,6 +500,11 @@ class PluginRuntimeManager(
"""为指定插件生成配置文件变更回调。"""
async def _callback(changes: Sequence[FileChange]) -> None:
+ """将 watcher 事件转发到指定插件的配置处理逻辑。
+
+ Args:
+ changes: 当前批次收集到的文件变更列表。
+ """
await self._handle_plugin_config_changes(plugin_id, changes)
return _callback
@@ -542,6 +621,11 @@ class PluginRuntimeManager(
# ─── 能力实现注册 ──────────────────────────────────────────
def _register_capability_impls(self, supervisor: "PluginSupervisor") -> None:
+ """向指定 Supervisor 注册主程序能力实现。
+
+ Args:
+ supervisor: 需要注册能力实现的目标 Supervisor。
+ """
register_capability_impls(self, supervisor)
diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py
index 6f95f97f..0dfc6656 100644
--- a/src/plugin_runtime/protocol/envelope.py
+++ b/src/plugin_runtime/protocol/envelope.py
@@ -7,11 +7,11 @@
from enum import Enum
from typing import Any, Dict, List, Optional
-from pydantic import BaseModel, Field
-
import logging as stdlib_logging
import time
+from pydantic import BaseModel, Field
+
# ====== 协议常量 ======
PROTOCOL_VERSION = "1.0.0"
@@ -156,6 +156,8 @@ class RegisterPluginPayload(BaseModel):
"""插件版本"""
components: List[ComponentDeclaration] = Field(default_factory=list, description="组件列表")
"""组件列表"""
+ adapter: Optional["AdapterDeclarationPayload"] = Field(default=None, description="可选的适配器声明")
+ """可选的适配器声明"""
capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表")
"""所需能力列表"""
@@ -285,6 +287,48 @@ class ReloadPluginResultPayload(BaseModel):
"""重载失败的插件及原因"""
+class AdapterDeclarationPayload(BaseModel):
+ """适配器插件声明载荷。"""
+
+ platform: str = Field(description="适配器负责的平台名称,例如 qq")
+ """适配器负责的平台名称,例如 qq"""
+ protocol: str = Field(default="", description="接入协议或实现名称,例如 napcat")
+ """接入协议或实现名称,例如 napcat"""
+ account_id: str = Field(default="", description="可选的账号 ID 或 self_id")
+ """可选的账号 ID 或 self_id"""
+ scope: str = Field(default="", description="可选的路由作用域")
+ """可选的路由作用域"""
+ send_method: str = Field(default="send_to_platform", description="Host 出站调用的插件方法名")
+ """Host 出站调用的插件方法名"""
+ metadata: Dict[str, Any] = Field(default_factory=dict, description="适配器附加元数据")
+ """适配器附加元数据"""
+
+
+class ReceiveExternalMessagePayload(BaseModel):
+ """适配器插件向 Host 注入外部消息的请求载荷。"""
+
+ message: Dict[str, Any] = Field(description="符合 MessageDict 结构的标准消息字典")
+ """符合 MessageDict 结构的标准消息字典"""
+ route_metadata: Dict[str, Any] = Field(default_factory=dict, description="可选的路由辅助元数据")
+ """可选的路由辅助元数据"""
+ external_message_id: str = Field(default="", description="可选的外部平台消息 ID")
+ """可选的外部平台消息 ID"""
+ dedupe_key: str = Field(default="", description="可选的显式去重键")
+ """可选的显式去重键"""
+
+
+class ReceiveExternalMessageResultPayload(BaseModel):
+ """外部消息注入结果载荷。"""
+
+ accepted: bool = Field(description="Host 是否接受了本次消息注入")
+ """Host 是否接受了本次消息注入"""
+ route_key: Dict[str, Any] = Field(default_factory=dict, description="本次消息使用的归一路由键")
+ """本次消息使用的归一路由键"""
+
+
+RegisterPluginPayload.model_rebuild()
+
+
# ====== 日志传输 ======
diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py
index bf36a05c..3ffb6b4b 100644
--- a/src/plugin_runtime/runner/runner_main.py
+++ b/src/plugin_runtime/runner/runner_main.py
@@ -9,9 +9,8 @@
6. 转发插件的能力调用到 Host
"""
-from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, cast
-
from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, cast
import asyncio
import contextlib
@@ -26,6 +25,7 @@ import tomllib
from src.common.logger import get_console_handler, get_logger, initialize_logging
from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN
from src.plugin_runtime.protocol.envelope import (
+ AdapterDeclarationPayload,
BootstrapPluginPayload,
ComponentDeclaration,
Envelope,
@@ -227,7 +227,7 @@ class PluginRunner:
plugin_id: str = "",
payload: Optional[Dict[str, Any]] = None,
) -> Any:
- """桥接 PluginContext.call_capability → RPCClient.send_request。
+ """桥接 PluginContext 的原始 RPC 调用到 Host。
无论调用方传入何种 plugin_id,实际发往 Host 的 plugin_id
始终绑定为当前插件实例,避免伪造其他插件身份申请能力。
@@ -237,17 +237,13 @@ class PluginRunner:
f"插件 {bound_plugin_id} 尝试以 {plugin_id} 身份发起 RPC,已强制绑定回自身身份"
)
resp = await rpc_client.send_request(
- method="cap.call",
+ method=method,
plugin_id=bound_plugin_id,
- payload={
- "capability": method,
- "args": payload or {},
- },
+ payload=payload or {},
)
- # 从响应信封中提取业务结果
if resp.error:
raise RuntimeError(resp.error.get("message", "能力调用失败"))
- return resp.payload.get("result")
+ return resp.payload
ctx = PluginContext(plugin_id=plugin_id, rpc_call=_rpc_call)
cast(_ContextAwarePlugin, instance)._set_context(ctx)
@@ -286,6 +282,7 @@ class PluginRunner:
self._rpc_client.register_method("plugin.invoke_command", self._handle_invoke)
self._rpc_client.register_method("plugin.invoke_action", self._handle_invoke)
self._rpc_client.register_method("plugin.invoke_tool", self._handle_invoke)
+ self._rpc_client.register_method("plugin.invoke_adapter", self._handle_invoke)
self._rpc_client.register_method("plugin.emit_event", self._handle_event_invoke)
self._rpc_client.register_method("plugin.invoke_hook", self._handle_hook_invoke)
self._rpc_client.register_method("plugin.invoke_workflow_step", self._handle_workflow_step)
@@ -306,12 +303,14 @@ class PluginRunner:
)
try:
- await self._rpc_client.send_request(
+ response = await self._rpc_client.send_request(
"plugin.bootstrap",
plugin_id=meta.plugin_id,
payload=payload.model_dump(),
timeout_ms=10000,
)
+ if response.error:
+ raise RuntimeError(response.error.get("message", "插件 bootstrap 失败"))
return True
except Exception as e:
logger.error(f"插件 {meta.plugin_id} bootstrap 失败: {e}")
@@ -321,6 +320,29 @@ class PluginRunner:
"""撤销 bootstrap 期间为插件签发的能力令牌。"""
await self._bootstrap_plugin(meta, capabilities_required=[])
+ def _collect_adapter_declaration(self, meta: PluginMeta) -> Optional[AdapterDeclarationPayload]:
+ """从插件实例中提取适配器声明。
+
+ Args:
+ meta: 待提取声明的插件元数据。
+
+ Returns:
+ Optional[AdapterDeclarationPayload]: 若插件声明了适配器角色,则返回
+ 经过校验的适配器声明;否则返回 ``None``。
+
+ Raises:
+ ValueError: 插件导出的适配器声明结构非法时抛出。
+ """
+ instance = meta.instance
+ if not hasattr(instance, "get_adapter_info"):
+ return None
+
+ adapter_info = instance.get_adapter_info()
+ if adapter_info is None:
+ return None
+
+ return AdapterDeclarationPayload.model_validate(adapter_info)
+
async def _register_plugin(self, meta: PluginMeta) -> bool:
"""向 Host 注册单个插件。
@@ -346,20 +368,29 @@ class PluginRunner:
for comp_info in instance.get_components()
)
+ try:
+ adapter = self._collect_adapter_declaration(meta)
+ except Exception as exc:
+ logger.error(f"插件 {meta.plugin_id} 适配器声明非法: {exc}", exc_info=True)
+ return False
+
reg_payload = RegisterPluginPayload(
plugin_id=meta.plugin_id,
plugin_version=meta.version,
components=components,
+ adapter=adapter,
capabilities_required=meta.capabilities_required,
)
try:
- _resp = await self._rpc_client.send_request(
+ response = await self._rpc_client.send_request(
"plugin.register_components",
plugin_id=meta.plugin_id,
payload=reg_payload.model_dump(),
timeout_ms=10000,
)
+ if response.error:
+ raise RuntimeError(response.error.get("message", "插件注册失败"))
logger.info(f"插件 {meta.plugin_id} 注册完成")
return True
except Exception as e:
diff --git a/src/plugins/built_in/napcat_adapter/_manifest.json b/src/plugins/built_in/napcat_adapter/_manifest.json
new file mode 100644
index 00000000..6f7e68fd
--- /dev/null
+++ b/src/plugins/built_in/napcat_adapter/_manifest.json
@@ -0,0 +1,30 @@
+{
+ "manifest_version": 1,
+ "name": "napcat_adapter_builtin",
+ "version": "0.1.0",
+ "description": "Built-in NapCat adapter plugin for MVP message forwarding.",
+ "author": {
+ "name": "OpenAI Codex"
+ },
+ "license": "GPL-v3.0-or-later",
+ "host_application": {
+ "min_version": "1.0.0"
+ },
+ "keywords": [
+ "adapter",
+ "built-in",
+ "napcat",
+ "onebot",
+ "qq"
+ ],
+ "categories": [
+ "Adapter",
+ "Built-in"
+ ],
+ "default_locale": "en-US",
+ "plugin_info": {
+ "is_built_in": true,
+ "plugin_type": "adapter"
+ },
+ "capabilities": []
+}
diff --git a/src/plugins/built_in/napcat_adapter/plugin.py b/src/plugins/built_in/napcat_adapter/plugin.py
new file mode 100644
index 00000000..3eff518d
--- /dev/null
+++ b/src/plugins/built_in/napcat_adapter/plugin.py
@@ -0,0 +1,690 @@
+"""内置 NapCat 适配器插件。
+
+当前实现是一个 MVP 版本,目标仅限于跑通基础消息收发链路:
+1. 作为客户端连接 NapCat / OneBot v11 WebSocket 服务。
+2. 将入站消息事件转换为 Host 侧的 ``MessageDict``。
+3. 将 Host 出站消息转换为 OneBot 动作并发送。
+
+当前范围刻意收敛为:
+- 单连接
+- 文本、@、reply 基础转发
+- 暂不处理 ``notice`` / ``meta_event``
+- 暂不支持图片、语音、文件等复杂媒体
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
+from uuid import uuid4
+
+import asyncio
+import contextlib
+import json
+import time
+
+from maibot_sdk import Adapter, MaiBotPlugin
+
+if TYPE_CHECKING:
+ from aiohttp import ClientWebSocketResponse as AiohttpClientWebSocketResponse
+
+try:
+ from aiohttp import ClientSession, ClientTimeout, ClientWebSocketResponse, WSMsgType
+
+ AIOHTTP_AVAILABLE = True
+except ImportError:
+ ClientSession = cast(Any, None)
+ ClientTimeout = cast(Any, None)
+ ClientWebSocketResponse = cast(Any, None)
+ WSMsgType = cast(Any, None)
+ AIOHTTP_AVAILABLE = False
+
+if not TYPE_CHECKING:
+ AiohttpClientWebSocketResponse = Any
+
+
+@Adapter(platform="qq", protocol="napcat", send_method="send_to_platform")
+class NapCatAdapterPlugin(MaiBotPlugin):
+ """NapCat 适配器 MVP 实现。"""
+
+ def __init__(self) -> None:
+ """初始化 NapCat 适配器插件实例。"""
+ super().__init__()
+ self._plugin_config: Dict[str, Any] = {}
+ self._connection_task: Optional[asyncio.Task[None]] = None
+ self._pending_actions: Dict[str, asyncio.Future[Dict[str, Any]]] = {}
+ self._background_tasks: set[asyncio.Task[Any]] = set()
+ self._send_lock = asyncio.Lock()
+ self._ws: Optional[AiohttpClientWebSocketResponse] = None
+
+ def set_plugin_config(self, config: Dict[str, Any]) -> None:
+ """设置插件配置内容。
+
+ Args:
+ config: Runner 注入的 ``config.toml`` 解析结果。
+ """
+ self._plugin_config = config if isinstance(config, dict) else {}
+
+ async def on_load(self) -> None:
+ """在插件加载时根据配置决定是否启动连接。"""
+ await self._restart_connection_if_needed()
+
+ async def on_unload(self) -> None:
+ """在插件卸载时关闭连接并清理后台任务。"""
+ await self._stop_connection()
+ await self._cancel_background_tasks()
+
+ async def on_config_update(self, new_config: Dict[str, Any], version: str) -> None:
+ """在配置更新后重载连接状态。
+
+ Args:
+ new_config: 最新的插件配置。
+ version: 配置版本号。
+ """
+ del version
+ self.set_plugin_config(new_config)
+ await self._restart_connection_if_needed()
+
+ async def send_to_platform(
+ self,
+ message: Dict[str, Any],
+ route: Optional[Dict[str, Any]] = None,
+ metadata: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ) -> Dict[str, Any]:
+ """将 Host 出站消息发送到 NapCat。
+
+ Args:
+ message: Host 侧标准 ``MessageDict``。
+ route: Platform IO 生成的路由信息。
+ metadata: Platform IO 附带的投递元数据。
+ **kwargs: 预留的扩展参数。
+
+ Returns:
+ Dict[str, Any]: 标准化后的发送结果。
+ """
+ del metadata
+ del kwargs
+
+ ws = self._ws
+ if ws is None or ws.closed:
+ return {"success": False, "error": "NapCat is not connected"}
+
+ try:
+ action_name, params = self._build_outbound_action(message, route or {})
+ response = await self._call_action(action_name, params)
+ except Exception as exc:
+ return {"success": False, "error": str(exc)}
+
+ if str(response.get("status", "")).lower() != "ok":
+ return {
+ "success": False,
+ "error": str(response.get("wording") or response.get("message") or "NapCat send failed"),
+ "metadata": {"retcode": response.get("retcode")},
+ }
+
+ response_data = response.get("data", {})
+ external_message_id = ""
+ if isinstance(response_data, dict):
+ external_message_id = str(response_data.get("message_id") or "")
+
+ return {
+ "success": True,
+ "external_message_id": external_message_id or None,
+ "metadata": {"action": action_name},
+ }
+
+ async def _restart_connection_if_needed(self) -> None:
+ """根据当前配置重启连接循环。"""
+ await self._stop_connection()
+ if not self._should_connect():
+ self.ctx.logger.info("NapCat 适配器保持空闲状态,因为插件或配置未启用")
+ return
+ if not AIOHTTP_AVAILABLE:
+ self.ctx.logger.error("NapCat 适配器依赖 aiohttp,但当前环境未安装该依赖")
+ return
+ self._connection_task = asyncio.create_task(self._connection_loop(), name="napcat_adapter.connection")
+
+ async def _stop_connection(self) -> None:
+ """停止当前连接并让所有等待中的动作失败返回。"""
+ connection_task = self._connection_task
+ self._connection_task = None
+
+ ws = self._ws
+ if ws is not None and not ws.closed:
+ with contextlib.suppress(Exception):
+ await ws.close()
+ self._ws = None
+
+ if connection_task is not None:
+ connection_task.cancel()
+ with contextlib.suppress(asyncio.CancelledError):
+ await connection_task
+
+ self._fail_pending_actions("NapCat connection closed")
+
+ async def _cancel_background_tasks(self) -> None:
+ """取消所有仍在运行的入站后台任务。"""
+ background_tasks = list(self._background_tasks)
+ for task in background_tasks:
+ task.cancel()
+ if background_tasks:
+ with contextlib.suppress(Exception):
+ await asyncio.gather(*background_tasks, return_exceptions=True)
+ self._background_tasks.clear()
+
+ async def _connection_loop(self) -> None:
+ """维护单个 WebSocket 连接,并在断开后按配置重连。"""
+ assert ClientSession is not None
+ assert ClientTimeout is not None
+
+ while self._should_connect():
+ ws_url = self._get_string(self._connection_config(), "ws_url")
+ if not ws_url:
+ self.ctx.logger.warning("NapCat 适配器已启用,但 connection.ws_url 为空")
+ return
+
+ headers = self._build_headers()
+ timeout = ClientTimeout(total=None, connect=10)
+ heartbeat = self._get_positive_float(self._connection_config(), "heartbeat_sec", 30.0)
+
+ try:
+ async with ClientSession(headers=headers, timeout=timeout) as session:
+ async with session.ws_connect(ws_url, heartbeat=heartbeat or None) as ws:
+ self._ws = ws
+ self.ctx.logger.info(f"NapCat 适配器已连接: {ws_url}")
+ await self._receive_loop(ws)
+ except asyncio.CancelledError:
+ raise
+ except Exception as exc:
+ self.ctx.logger.warning(f"NapCat 适配器连接失败: {exc}")
+ finally:
+ self._ws = None
+ self._fail_pending_actions("NapCat connection interrupted")
+
+ if not self._should_connect():
+ break
+
+ await asyncio.sleep(self._get_positive_float(self._connection_config(), "reconnect_delay_sec", 5.0))
+
+ async def _receive_loop(self, ws: AiohttpClientWebSocketResponse) -> None:
+ """持续消费 WebSocket 消息并分发处理。
+
+ Args:
+ ws: 当前活跃的 WebSocket 连接对象。
+ """
+ assert WSMsgType is not None
+
+ async for ws_message in ws:
+ if ws_message.type != WSMsgType.TEXT:
+ if ws_message.type in {WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.ERROR}:
+ break
+ continue
+
+ payload = self._parse_json_message(ws_message.data)
+ if payload is None:
+ continue
+
+ if echo_id := str(payload.get("echo") or "").strip():
+ self._resolve_pending_action(echo_id, payload)
+ continue
+
+ if str(payload.get("post_type") or "").strip() != "message":
+ continue
+
+ task = asyncio.create_task(self._handle_inbound_message(payload), name="napcat_adapter.inbound")
+ self._background_tasks.add(task)
+ task.add_done_callback(self._background_tasks.discard)
+
+ async def _handle_inbound_message(self, payload: Dict[str, Any]) -> None:
+ """处理单条 NapCat 入站消息并注入 Host。
+
+ Args:
+ payload: NapCat / OneBot 推送的原始事件数据。
+ """
+ self_id = str(payload.get("self_id") or "").strip()
+ sender = payload.get("sender", {})
+ if not isinstance(sender, dict):
+ sender = {}
+
+ sender_user_id = str(payload.get("user_id") or sender.get("user_id") or "").strip()
+ if not sender_user_id:
+ return
+
+ if self_id and sender_user_id == self_id and self._get_bool(self._filters_config(), "ignore_self_message", True):
+ return
+
+ message_dict = self._build_inbound_message_dict(payload, self_id, sender_user_id, sender)
+ route_metadata: Dict[str, Any] = {}
+ if self_id:
+ route_metadata["self_id"] = self_id
+ if connection_id := self._get_string(self._connection_config(), "connection_id"):
+ route_metadata["connection_id"] = connection_id
+
+ external_message_id = str(payload.get("message_id") or "").strip()
+ accepted = await self.ctx.adapter.receive_external_message(
+ message_dict,
+ route_metadata=route_metadata,
+ external_message_id=external_message_id,
+ dedupe_key=external_message_id,
+ )
+ if not accepted:
+ self.ctx.logger.debug(f"Host 丢弃了 NapCat 入站消息: {external_message_id or '无消息 ID'}")
+
+ def _build_inbound_message_dict(
+ self,
+ payload: Dict[str, Any],
+ self_id: str,
+ sender_user_id: str,
+ sender: Dict[str, Any],
+ ) -> Dict[str, Any]:
+ """构造 Host 侧可接受的 ``MessageDict``。
+
+ Args:
+ payload: NapCat 原始消息事件。
+ self_id: 当前机器人账号 ID。
+ sender_user_id: 发送者用户 ID。
+ sender: 发送者信息字典。
+
+ Returns:
+ Dict[str, Any]: 规范化后的 ``MessageDict``。
+ """
+ message_type = str(payload.get("message_type") or "").strip() or "private"
+ group_id = str(payload.get("group_id") or "").strip()
+ group_name = str(payload.get("group_name") or "").strip() or (f"group_{group_id}" if group_id else "")
+ user_nickname = str(sender.get("nickname") or sender.get("card") or sender_user_id).strip() or sender_user_id
+ user_cardname = str(sender.get("card") or "").strip() or None
+
+ raw_message, is_at = self._convert_inbound_segments(payload.get("message"), self_id)
+ raw_message_text = str(payload.get("raw_message") or "").strip()
+ if not raw_message:
+ raw_message = [{"type": "text", "data": raw_message_text or "[unsupported]"}]
+
+ plain_text = self._build_plain_text(raw_message, raw_message_text)
+ timestamp_seconds = payload.get("time")
+ if not isinstance(timestamp_seconds, (int, float)):
+ timestamp_seconds = time.time()
+
+ additional_config: Dict[str, Any] = {"self_id": self_id, "napcat_message_type": message_type}
+ if group_id:
+ additional_config["platform_io_target_group_id"] = group_id
+ else:
+ additional_config["platform_io_target_user_id"] = sender_user_id
+
+ message_info: Dict[str, Any] = {
+ "user_info": {
+ "user_id": sender_user_id,
+ "user_nickname": user_nickname,
+ "user_cardname": user_cardname,
+ },
+ "additional_config": additional_config,
+ }
+ if group_id:
+ message_info["group_info"] = {"group_id": group_id, "group_name": group_name}
+
+ message_id = str(payload.get("message_id") or f"napcat-{uuid4().hex}").strip()
+ return {
+ "message_id": message_id,
+ "timestamp": str(float(timestamp_seconds)),
+ "platform": "qq",
+ "message_info": message_info,
+ "raw_message": raw_message,
+ "is_mentioned": is_at,
+ "is_at": is_at,
+ "is_emoji": False,
+ "is_picture": False,
+ "is_command": plain_text.startswith("/"),
+ "is_notify": False,
+ "session_id": "",
+ "processed_plain_text": plain_text,
+ "display_message": plain_text,
+ }
+
+ def _convert_inbound_segments(self, message_payload: Any, self_id: str) -> tuple[List[Dict[str, Any]], bool]:
+ """将 OneBot 消息段转换为 Host 消息段结构。
+
+ Args:
+ message_payload: OneBot 原始 ``message`` 字段。
+ self_id: 当前机器人账号 ID。
+
+ Returns:
+ tuple[List[Dict[str, Any]], bool]: 转换后的消息段列表,以及是否 @ 到当前机器人。
+ """
+ if isinstance(message_payload, str):
+ normalized_text = message_payload.strip()
+ return ([{"type": "text", "data": normalized_text}] if normalized_text else []), False
+
+ if not isinstance(message_payload, list):
+ return [], False
+
+ converted_segments: List[Dict[str, Any]] = []
+ is_at = False
+ placeholder_texts = {
+ "face": "[face]",
+ "file": "[file]",
+ "image": "[image]",
+ "json": "[json]",
+ "record": "[voice]",
+ "video": "[video]",
+ "xml": "[xml]",
+ }
+
+ for segment in message_payload:
+ if not isinstance(segment, dict):
+ continue
+
+ segment_type = str(segment.get("type") or "").strip()
+ segment_data = segment.get("data", {})
+ if not isinstance(segment_data, dict):
+ segment_data = {}
+
+ if segment_type == "text":
+ if text_value := str(segment_data.get("text") or ""):
+ converted_segments.append({"type": "text", "data": text_value})
+ continue
+
+ if segment_type == "at":
+ if target_user_id := str(segment_data.get("qq") or "").strip():
+ converted_segments.append(
+ {
+ "type": "at",
+ "data": {
+ "target_user_id": target_user_id,
+ "target_user_nickname": None,
+ "target_user_cardname": None,
+ },
+ }
+ )
+ if self_id and target_user_id == self_id:
+ is_at = True
+ continue
+
+ if segment_type == "reply":
+ if target_message_id := str(segment_data.get("id") or "").strip():
+ converted_segments.append({"type": "reply", "data": target_message_id})
+ continue
+
+ if placeholder := placeholder_texts.get(segment_type):
+ converted_segments.append({"type": "text", "data": placeholder})
+
+ return converted_segments, is_at
+
+ def _build_outbound_action(
+ self,
+ message: Dict[str, Any],
+ route: Dict[str, Any],
+ ) -> tuple[str, Dict[str, Any]]:
+ """为 Host 出站消息构造 OneBot 动作。
+
+ Args:
+ message: Host 侧标准 ``MessageDict``。
+ route: Platform IO 路由信息。
+
+ Returns:
+ tuple[str, Dict[str, Any]]: 动作名称与参数字典。
+ """
+ message_info = message.get("message_info", {})
+ if not isinstance(message_info, dict):
+ message_info = {}
+
+ group_info = message_info.get("group_info", {})
+ if not isinstance(group_info, dict):
+ group_info = {}
+
+ additional_config = message_info.get("additional_config", {})
+ if not isinstance(additional_config, dict):
+ additional_config = {}
+
+ raw_message = message.get("raw_message", [])
+ segments = self._convert_outbound_segments(raw_message)
+
+ if target_group_id := str(
+ group_info.get("group_id") or additional_config.get("platform_io_target_group_id") or ""
+ ).strip():
+ return "send_group_msg", {"group_id": target_group_id, "message": segments}
+
+ if not (
+ target_user_id := str(
+ additional_config.get("platform_io_target_user_id")
+ or additional_config.get("target_user_id")
+ or route.get("target_user_id")
+ or ""
+ ).strip()
+ ):
+ raise ValueError("Outbound private message is missing target_user_id")
+
+ return "send_private_msg", {"message": segments, "user_id": target_user_id}
+
+ def _convert_outbound_segments(self, raw_message: Any) -> List[Dict[str, Any]]:
+ """将 Host 消息段转换为 OneBot 消息段。
+
+ Args:
+ raw_message: Host 侧 ``raw_message`` 字段。
+
+ Returns:
+ List[Dict[str, Any]]: OneBot 消息段列表。
+ """
+ if not isinstance(raw_message, list):
+ return [{"type": "text", "data": {"text": ""}}]
+
+ outbound_segments: List[Dict[str, Any]] = []
+ for item in raw_message:
+ if not isinstance(item, dict):
+ continue
+
+ item_type = str(item.get("type") or "").strip()
+ item_data = item.get("data")
+
+ if item_type == "text":
+ text_value = str(item_data or "")
+ outbound_segments.append({"type": "text", "data": {"text": text_value}})
+ continue
+
+ if item_type == "at" and isinstance(item_data, dict):
+ if target_user_id := str(item_data.get("target_user_id") or "").strip():
+ outbound_segments.append({"type": "at", "data": {"qq": target_user_id}})
+ continue
+
+ if item_type == "reply":
+ if target_message_id := str(item_data or "").strip():
+ outbound_segments.append({"type": "reply", "data": {"id": target_message_id}})
+ continue
+
+ fallback_text = f"[unsupported:{item_type or 'unknown'}]"
+ outbound_segments.append({"type": "text", "data": {"text": fallback_text}})
+
+ if not outbound_segments:
+ outbound_segments.append({"type": "text", "data": {"text": ""}})
+ return outbound_segments
+
+ async def _call_action(self, action_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
+ """发送 OneBot 动作并等待对应的 echo 响应。
+
+ Args:
+ action_name: OneBot 动作名称。
+ params: 动作参数。
+
+ Returns:
+ Dict[str, Any]: NapCat 返回的原始响应字典。
+ """
+ ws = self._ws
+ if ws is None or ws.closed:
+ raise RuntimeError("NapCat is not connected")
+
+ echo_id = uuid4().hex
+ loop = asyncio.get_running_loop()
+ response_future: asyncio.Future[Dict[str, Any]] = loop.create_future()
+ self._pending_actions[echo_id] = response_future
+
+ request_payload = {"action": action_name, "params": params, "echo": echo_id}
+ try:
+ async with self._send_lock:
+ await ws.send_str(json.dumps(request_payload, ensure_ascii=False))
+ timeout_seconds = self._get_positive_float(self._connection_config(), "action_timeout_sec", 15.0)
+ return await asyncio.wait_for(response_future, timeout=timeout_seconds)
+ finally:
+ self._pending_actions.pop(echo_id, None)
+
+ def _resolve_pending_action(self, echo_id: str, payload: Dict[str, Any]) -> None:
+ """解析等待中的动作响应。
+
+ Args:
+ echo_id: 动作请求对应的 echo 标识。
+ payload: NapCat 返回的响应载荷。
+ """
+ response_future = self._pending_actions.get(echo_id)
+ if response_future is None or response_future.done():
+ return
+ response_future.set_result(payload)
+
+ def _fail_pending_actions(self, error_message: str) -> None:
+ """让所有等待中的动作以异常方式结束。
+
+ Args:
+ error_message: 写入异常中的错误信息。
+ """
+ for response_future in self._pending_actions.values():
+ if not response_future.done():
+ response_future.set_exception(RuntimeError(error_message))
+ self._pending_actions.clear()
+
+ def _build_headers(self) -> Dict[str, str]:
+ """构造连接 NapCat 所需的请求头。
+
+ Returns:
+ Dict[str, str]: WebSocket 握手请求头。
+ """
+ access_token = self._get_string(self._connection_config(), "access_token")
+ return {"Authorization": f"Bearer {access_token}"} if access_token else {}
+
+ def _parse_json_message(self, data: Any) -> Optional[Dict[str, Any]]:
+ """解析 WebSocket 文本消息中的 JSON 数据。
+
+ Args:
+ data: WebSocket 收到的原始文本数据。
+
+ Returns:
+ Optional[Dict[str, Any]]: 成功时返回字典,失败时返回 ``None``。
+ """
+ try:
+ payload = json.loads(str(data))
+ except Exception as exc:
+ self.ctx.logger.warning(f"NapCat 适配器解析 JSON 载荷失败: {exc}")
+ return None
+
+ return payload if isinstance(payload, dict) else None
+
+ def _build_plain_text(self, raw_message: List[Dict[str, Any]], fallback_text: str) -> str:
+ """从标准消息段中提取可展示的纯文本。
+
+ Args:
+ raw_message: 标准化后的消息段列表。
+ fallback_text: 当无法拼出文本时使用的回退文本。
+
+ Returns:
+ str: 用于 Host 展示和命令判断的纯文本内容。
+ """
+ plain_text_parts: List[str] = []
+ for item in raw_message:
+ if not isinstance(item, dict):
+ continue
+ item_type = str(item.get("type") or "").strip()
+ item_data = item.get("data")
+ if item_type == "text":
+ plain_text_parts.append(str(item_data or ""))
+ elif item_type == "at" and isinstance(item_data, dict):
+ plain_text_parts.append(f"@{item_data.get('target_user_id') or ''}")
+ elif item_type == "reply":
+ plain_text_parts.append("[reply]")
+
+ plain_text = "".join(part for part in plain_text_parts if part).strip()
+ return plain_text or fallback_text or "[unsupported]"
+
+ def _plugin_section(self) -> Dict[str, Any]:
+ """读取插件配置中的 ``plugin`` 段。
+
+ Returns:
+ Dict[str, Any]: ``plugin`` 配置字典。
+ """
+ plugin_section = self._plugin_config.get("plugin", {})
+ return plugin_section if isinstance(plugin_section, dict) else {}
+
+ def _connection_config(self) -> Dict[str, Any]:
+ """读取插件配置中的 ``connection`` 段。
+
+ Returns:
+ Dict[str, Any]: ``connection`` 配置字典。
+ """
+ connection_config = self._plugin_config.get("connection", {})
+ return connection_config if isinstance(connection_config, dict) else {}
+
+ def _filters_config(self) -> Dict[str, Any]:
+ """读取插件配置中的 ``filters`` 段。
+
+ Returns:
+ Dict[str, Any]: ``filters`` 配置字典。
+ """
+ filters_config = self._plugin_config.get("filters", {})
+ return filters_config if isinstance(filters_config, dict) else {}
+
+ def _should_connect(self) -> bool:
+ """判断当前配置下是否应当启动连接。
+
+ Returns:
+ bool: 若启用了插件连接则返回 ``True``。
+ """
+ return self._get_bool(self._plugin_section(), "enabled", False)
+
+ @staticmethod
+ def _get_bool(mapping: Dict[str, Any], key: str, default: bool) -> bool:
+ """安全读取布尔配置值。
+
+ Args:
+ mapping: 待读取的配置字典。
+ key: 目标键名。
+ default: 读取失败时的默认值。
+
+ Returns:
+ bool: 解析后的布尔值。
+ """
+ value = mapping.get(key, default)
+ return value if isinstance(value, bool) else default
+
+ @staticmethod
+ def _get_positive_float(mapping: Dict[str, Any], key: str, default: float) -> float:
+ """安全读取正浮点数配置值。
+
+ Args:
+ mapping: 待读取的配置字典。
+ key: 目标键名。
+ default: 读取失败时的默认值。
+
+ Returns:
+ float: 合法的正浮点数;否则返回默认值。
+ """
+ value = mapping.get(key, default)
+ if isinstance(value, (int, float)) and float(value) > 0:
+ return float(value)
+ return default
+
+ @staticmethod
+ def _get_string(mapping: Dict[str, Any], key: str) -> str:
+ """安全读取字符串配置值。
+
+ Args:
+ mapping: 待读取的配置字典。
+ key: 目标键名。
+
+ Returns:
+ str: 去除首尾空白后的字符串值。
+ """
+ value = mapping.get(key)
+ return "" if value is None else str(value).strip()
+
+
+def create_plugin() -> NapCatAdapterPlugin:
+ """创建插件实例。
+
+ Returns:
+ NapCatAdapterPlugin: NapCat 内置适配器插件实例。
+ """
+ return NapCatAdapterPlugin()
diff --git a/src/services/send_service.py b/src/services/send_service.py
index 7af55716..6ca7d005 100644
--- a/src/services/send_service.py
+++ b/src/services/send_service.py
@@ -4,7 +4,7 @@
提供发送各种类型消息的核心功能。
"""
-from typing import Dict, List, Optional, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
import time
import traceback
@@ -19,6 +19,7 @@ from src.common.data_models.mai_message_data_model import MaiMessage
from src.common.data_models.message_component_data_model import DictComponent, MessageSequence
from src.common.logger import get_logger
from src.config.config import global_config
+from src.platform_io.route_key_factory import RouteKeyFactory
if TYPE_CHECKING:
from src.chat.message_receive.message import SessionMessage
@@ -31,6 +32,50 @@ logger = get_logger("send_service")
# =============================================================================
+def _inherit_platform_io_route_metadata(target_stream: Any) -> Dict[str, object]:
+ """从目标会话上下文继承 Platform IO 路由元数据。
+
+ Args:
+ target_stream: 当前消息要发送到的会话对象。
+
+ Returns:
+ Dict[str, object]: 可安全透传到出站消息 ``additional_config`` 中的
+ 路由辅助字段。
+ """
+ inherited_metadata: Dict[str, object] = {}
+
+ context = getattr(target_stream, "context", None)
+ context_message = getattr(context, "message", None)
+ if context_message is None:
+ return inherited_metadata
+
+ additional_config = getattr(context_message.message_info, "additional_config", {})
+ if not isinstance(additional_config, dict):
+ return inherited_metadata
+
+ for key in (*RouteKeyFactory.ACCOUNT_ID_KEYS, *RouteKeyFactory.SCOPE_KEYS):
+ value = additional_config.get(key)
+ if value is None:
+ continue
+ normalized_value = str(value).strip()
+ if normalized_value:
+ inherited_metadata[key] = value
+
+ target_group_id = getattr(target_stream, "group_id", None)
+ if target_group_id is not None:
+ normalized_group_id = str(target_group_id).strip()
+ if normalized_group_id:
+ inherited_metadata["platform_io_target_group_id"] = normalized_group_id
+
+ target_user_id = getattr(target_stream, "user_id", None)
+ if target_user_id is not None:
+ normalized_user_id = str(target_user_id).strip()
+ if normalized_user_id:
+ inherited_metadata["platform_io_target_user_id"] = normalized_user_id
+
+ return inherited_metadata
+
+
async def _send_to_target(
message_segment: Seg,
stream_id: str,
@@ -42,7 +87,22 @@ async def _send_to_target(
show_log: bool = True,
selected_expressions: Optional[List[int]] = None,
) -> bool:
- """向指定目标发送消息的内部实现"""
+ """向指定目标发送消息。
+
+ Args:
+ message_segment: 待发送的消息段。
+ stream_id: 目标会话 ID。
+ display_message: 用于界面展示的文本内容。
+ typing: 是否显示输入中状态。
+ set_reply: 是否在发送时附带引用回复。
+ reply_message: 被回复的消息对象。
+ storage_message: 是否将发送结果写入消息存储。
+ show_log: 是否输出发送日志。
+ selected_expressions: 可选的表情候选索引列表。
+
+ Returns:
+ bool: 发送成功返回 ``True``,否则返回 ``False``。
+ """
try:
if set_reply and not reply_message:
logger.warning("[SendService] 使用引用回复,但未提供回复消息")
@@ -80,7 +140,7 @@ async def _send_to_target(
platform=target_stream.platform,
)
- additional_config: dict[str, object] = {}
+ additional_config: Dict[str, object] = _inherit_platform_io_route_metadata(target_stream)
if selected_expressions is not None:
additional_config["selected_expressions"] = selected_expressions
bot_user_id = get_bot_account(target_stream.platform)
diff --git a/src/webui/routers/chat/serializers.py b/src/webui/routers/chat/serializers.py
new file mode 100644
index 00000000..32104f88
--- /dev/null
+++ b/src/webui/routers/chat/serializers.py
@@ -0,0 +1,175 @@
+"""提供 WebUI 聊天路由使用的消息序列化能力。"""
+
+from typing import Any, Dict, List, Optional
+
+import base64
+
+from src.common.data_models.message_component_data_model import (
+ AtComponent,
+ DictComponent,
+ EmojiComponent,
+ ForwardComponent,
+ ForwardNodeComponent,
+ ImageComponent,
+ MessageSequence,
+ ReplyComponent,
+ StandardMessageComponents,
+ TextComponent,
+ VoiceComponent,
+)
+
+
+def serialize_message_sequence(message_sequence: MessageSequence) -> List[Dict[str, Any]]:
+ """将内部统一消息组件序列转换为 WebUI 富文本消息段。
+
+ Args:
+ message_sequence: 内部统一消息组件序列。
+
+ Returns:
+ List[Dict[str, Any]]: 可直接广播给 WebUI 前端的消息段列表。
+ """
+ serialized_segments: List[Dict[str, Any]] = []
+ for component in message_sequence.components:
+ serialized_segment = serialize_message_component(component)
+ if serialized_segment is not None:
+ serialized_segments.append(serialized_segment)
+ return serialized_segments
+
+
+def serialize_message_component(component: StandardMessageComponents) -> Optional[Dict[str, Any]]:
+ """将单个内部消息组件转换为 WebUI 消息段。
+
+ Args:
+ component: 待序列化的内部消息组件。
+
+ Returns:
+ Optional[Dict[str, Any]]: 序列化后的 WebUI 消息段;若组件不应展示则返回 ``None``。
+ """
+ if isinstance(component, TextComponent):
+ return {"type": "text", "data": component.text}
+
+ if isinstance(component, ImageComponent):
+ return _serialize_binary_component(
+ segment_type="image",
+ mime_type="image/png",
+ binary_data=component.binary_data,
+ fallback_text=component.content,
+ )
+
+ if isinstance(component, EmojiComponent):
+ return _serialize_binary_component(
+ segment_type="emoji",
+ mime_type="image/gif",
+ binary_data=component.binary_data,
+ fallback_text=component.content,
+ )
+
+ if isinstance(component, VoiceComponent):
+ return _serialize_binary_component(
+ segment_type="voice",
+ mime_type="audio/wav",
+ binary_data=component.binary_data,
+ fallback_text=component.content,
+ )
+
+ if isinstance(component, AtComponent):
+ return {
+ "type": "at",
+ "data": {
+ "target_user_id": component.target_user_id,
+ "target_user_nickname": component.target_user_nickname,
+ "target_user_cardname": component.target_user_cardname,
+ },
+ }
+
+ if isinstance(component, ReplyComponent):
+ return {
+ "type": "reply",
+ "data": {
+ "target_message_id": component.target_message_id,
+ "target_message_content": component.target_message_content,
+ "target_message_sender_id": component.target_message_sender_id,
+ "target_message_sender_nickname": component.target_message_sender_nickname,
+ "target_message_sender_cardname": component.target_message_sender_cardname,
+ },
+ }
+
+ if isinstance(component, ForwardNodeComponent):
+ return {
+ "type": "forward",
+ "data": [_serialize_forward_component(item) for item in component.forward_components],
+ }
+
+ if isinstance(component, DictComponent):
+ return _serialize_dict_component(component.data)
+
+ return {"type": "unknown", "data": str(component)}
+
+
+def _serialize_binary_component(
+ segment_type: str,
+ mime_type: str,
+ binary_data: bytes,
+ fallback_text: str,
+) -> Dict[str, Any]:
+ """序列化带二进制负载的消息组件。
+
+ Args:
+ segment_type: WebUI 消息段类型。
+ mime_type: 对应的数据 MIME 类型。
+ binary_data: 组件二进制数据。
+ fallback_text: 二进制缺失时可退化展示的文本。
+
+ Returns:
+ Dict[str, Any]: 序列化后的 WebUI 消息段。
+ """
+ if binary_data:
+ encoded_payload = base64.b64encode(binary_data).decode()
+ return {"type": segment_type, "data": f"data:{mime_type};base64,{encoded_payload}"}
+
+ if fallback_text:
+ return {"type": "text", "data": fallback_text}
+
+ return {"type": "unknown", "original_type": segment_type, "data": ""}
+
+
+def _serialize_forward_component(component: ForwardComponent) -> Dict[str, Any]:
+ """序列化单个转发节点。
+
+ Args:
+ component: 待序列化的转发节点组件。
+
+ Returns:
+ Dict[str, Any]: WebUI 可消费的转发节点字典。
+ """
+ return {
+ "message_id": component.message_id,
+ "user_id": component.user_id,
+ "user_nickname": component.user_nickname,
+ "user_cardname": component.user_cardname,
+ "content": serialize_message_sequence(MessageSequence(component.content)),
+ }
+
+
+def _serialize_dict_component(data: Dict[str, Any]) -> Dict[str, Any]:
+ """最佳努力地序列化非标准字典组件。
+
+ Args:
+ data: 原始字典组件内容。
+
+ Returns:
+ Dict[str, Any]: 序列化后的 WebUI 消息段。
+ """
+ raw_type = str(data.get("type") or "dict").strip()
+ raw_payload = data.get("data", data)
+
+ if raw_type in {"text", "image", "emoji", "voice", "video", "file", "music", "face"}:
+ return {"type": raw_type, "data": raw_payload}
+
+ if raw_type == "reply":
+ return {"type": "reply", "data": raw_payload}
+
+ if raw_type == "forward" and isinstance(raw_payload, list):
+ return {"type": "forward", "data": raw_payload}
+
+ return {"type": "unknown", "original_type": raw_type, "data": raw_payload}
From 780cd4f767eb7f8f6bc2357ef062891bb2de2ca2 Mon Sep 17 00:00:00 2001
From: DrSmoothl <1787882683@qq.com>
Date: Sat, 21 Mar 2026 00:46:34 +0800
Subject: [PATCH 20/45] =?UTF-8?q?refactor:=20=E6=9B=B4=E6=96=B0=E6=8F=92?=
=?UTF-8?q?=E4=BB=B6=E5=92=8C=20RPC=20=E6=9C=8D=E5=8A=A1=E5=99=A8=E9=80=BB?=
=?UTF-8?q?=E8=BE=91=EF=BC=8C=E5=A2=9E=E5=BC=BA=E6=8F=A1=E6=89=8B=E7=8A=B6?=
=?UTF-8?q?=E6=80=81=E7=AE=A1=E7=90=86=E4=B8=8E=E9=85=8D=E7=BD=AE=E6=A0=A1?=
=?UTF-8?q?=E9=AA=8C?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
pyproject.toml | 2 +-
src/plugin_runtime/host/rpc_server.py | 23 +-
src/plugin_runtime/host/supervisor.py | 79 +++++-
src/plugin_runtime/integration.py | 162 ++++++++---
src/plugin_runtime/protocol/envelope.py | 2 +-
src/plugins/built_in/napcat_adapter/plugin.py | 251 +++++++++++++++++-
6 files changed, 457 insertions(+), 62 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index f41e9448..9887ac24 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,7 +19,7 @@ dependencies = [
"jieba>=0.42.1",
"json-repair>=0.47.6",
"maim-message>=0.6.2",
- "maibot-plugin-sdk>=1.2.3,<2.0.0",
+ "maibot-plugin-sdk>=2.0.0",
"msgpack>=1.1.2",
"numpy>=2.2.6",
"openai>=1.95.0",
diff --git a/src/plugin_runtime/host/rpc_server.py b/src/plugin_runtime/host/rpc_server.py
index 75ef9b2a..2c422775 100644
--- a/src/plugin_runtime/host/rpc_server.py
+++ b/src/plugin_runtime/host/rpc_server.py
@@ -69,6 +69,7 @@ class RPCServer:
# 运行状态
self._running: bool = False
self._tasks: List[asyncio.Task[None]] = []
+ self._last_handshake_rejection_reason: str = ""
@property
def session_token(self) -> str:
@@ -78,6 +79,15 @@ class RPCServer:
def is_connected(self) -> bool:
return self._connection is not None and not self._connection.is_closed
+ @property
+ def last_handshake_rejection_reason(self) -> str:
+ """返回最近一次握手被拒绝的原因。"""
+ return self._last_handshake_rejection_reason
+
+ def clear_handshake_state(self) -> None:
+ """清空最近一次握手拒绝状态。"""
+ self._last_handshake_rejection_reason = ""
+
def register_method(self, method: str, handler: MethodHandler) -> None:
"""注册 RPC 方法处理器"""
self._method_handlers[method] = handler
@@ -85,6 +95,7 @@ class RPCServer:
async def start(self) -> None:
"""启动 RPC 服务器"""
self._running = True
+ self.clear_handshake_state()
self._send_queue = asyncio.Queue(maxsize=self._send_queue_size)
self._send_worker_task = asyncio.create_task(self._send_loop())
await self._transport.start(self._handle_connection)
@@ -93,6 +104,7 @@ class RPCServer:
async def stop(self) -> None:
"""停止 RPC 服务器"""
self._running = False
+ self.clear_handshake_state()
self._fail_pending_requests(ErrorCode.E_SHUTTING_DOWN, "服务器正在关闭")
self._fail_queued_sends(ErrorCode.E_SHUTTING_DOWN, "服务器正在关闭")
@@ -204,6 +216,7 @@ class RPCServer:
async def _handle_connection(self, conn: Connection) -> None:
"""处理新的 Runner 连接"""
logger.info("收到 Runner 连接")
+ self.clear_handshake_state()
# 第一条消息必须是 runner.hello 握手
try:
success = await self._handle_handshake(conn)
@@ -232,6 +245,7 @@ class RPCServer:
envelope = self._codec.decode_envelope(data)
if envelope.method != "runner.hello":
logger.error(f"期望 runner.hello,收到 {envelope.method}")
+ self._last_handshake_rejection_reason = "首条消息必须为 runner.hello"
error_resp = envelope.make_error_response(
ErrorCode.E_PROTOCOL_MISMATCH.value,
"首条消息必须为 runner.hello",
@@ -244,7 +258,8 @@ class RPCServer:
# 校验会话令牌
if hello.session_token != self._session_token:
logger.error("会话令牌不匹配")
- resp_payload = HelloResponsePayload(accepted=False, reason="会话令牌无效")
+ self._last_handshake_rejection_reason = "会话令牌无效"
+ resp_payload = HelloResponsePayload(accepted=False, reason=self._last_handshake_rejection_reason)
resp = envelope.make_response(payload=resp_payload.model_dump())
await conn.send_frame(self._codec.encode_envelope(resp))
return False
@@ -252,15 +267,19 @@ class RPCServer:
# 校验 SDK 版本
if not self._check_sdk_version(hello.sdk_version):
logger.error(f"SDK 版本不兼容: {hello.sdk_version}")
+ self._last_handshake_rejection_reason = (
+ f"SDK 版本 {hello.sdk_version} 不在支持范围 [{MIN_SDK_VERSION}, {MAX_SDK_VERSION}]"
+ )
resp_payload = HelloResponsePayload(
accepted=False,
- reason=f"SDK 版本 {hello.sdk_version} 不在支持范围 [{MIN_SDK_VERSION}, {MAX_SDK_VERSION}]",
+ reason=self._last_handshake_rejection_reason,
)
resp = envelope.make_response(payload=resp_payload.model_dump())
await conn.send_frame(self._codec.encode_envelope(resp))
return False
# 发送响应
+ self.clear_handshake_state()
resp_payload = HelloResponsePayload(accepted=True, host_version=PROTOCOL_VERSION)
resp = envelope.make_response(payload=resp_payload.model_dump())
await conn.send_frame(self._codec.encode_envelope(resp))
diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py
index cdf3d4ee..33091d5a 100644
--- a/src/plugin_runtime/host/supervisor.py
+++ b/src/plugin_runtime/host/supervisor.py
@@ -202,17 +202,24 @@ class PluginRunnerSupervisor:
self._restart_count = 0
self._clear_runner_state()
- await self._rpc_server.start()
- await self._spawn_runner()
-
try:
- await self._wait_for_runner_connection(timeout_sec=self._runner_spawn_timeout)
- await self._wait_for_runner_ready(timeout_sec=self._runner_spawn_timeout)
- except TimeoutError:
- if not self._rpc_server.is_connected:
- logger.warning("Runner 未在限定时间内完成连接,后续操作可能失败")
- else:
- logger.warning("Runner 未在限定时间内完成初始化,后续操作可能失败")
+ await self._rpc_server.start()
+ await self._spawn_runner()
+
+ try:
+ await self._wait_for_runner_connection(timeout_sec=self._runner_spawn_timeout)
+ await self._wait_for_runner_ready(timeout_sec=self._runner_spawn_timeout)
+ except TimeoutError:
+ if not self._rpc_server.is_connected:
+ logger.warning("Runner 未在限定时间内完成连接,后续操作可能失败")
+ else:
+ logger.warning("Runner 未在限定时间内完成初始化,后续操作可能失败")
+ except Exception:
+ await self._shutdown_runner(reason="startup_failed")
+ await self._rpc_server.stop()
+ self._clear_runner_state()
+ self._running = False
+ raise
self._health_task = asyncio.create_task(self._health_check_loop(), name="PluginRunnerSupervisor.health")
logger.info("PluginRunnerSupervisor 已启动")
@@ -387,7 +394,16 @@ class PluginRunnerSupervisor:
async def wait_for_connection() -> None:
"""轮询等待 RPC 连接建立。"""
- while self._running and not self._rpc_server.is_connected:
+ while True:
+ if self._rpc_server.is_connected:
+ return
+
+ if not self._running:
+ raise RuntimeError("Supervisor 已停止,等待 Runner 连接已取消")
+
+ if failure_reason := self._get_runner_startup_failure_reason():
+ raise RuntimeError(f"等待 Runner 连接失败: {failure_reason}")
+
await asyncio.sleep(0.1)
try:
@@ -408,10 +424,27 @@ class PluginRunnerSupervisor:
Raises:
TimeoutError: 在超时时间内 Runner 未完成初始化。
"""
+ async def wait_for_ready() -> RunnerReadyPayload:
+ """轮询等待 Runner 上报就绪。"""
+ while True:
+ if self._runner_ready_events.is_set():
+ return self._runner_ready_payloads
+
+ if not self._running:
+ raise RuntimeError("Supervisor 已停止,等待 Runner 就绪已取消")
+
+ if failure_reason := self._get_runner_startup_failure_reason():
+ raise RuntimeError(f"等待 Runner 就绪失败: {failure_reason}")
+
+ if not self._rpc_server.is_connected:
+ raise RuntimeError("等待 Runner 就绪失败: Runner RPC 连接已断开")
+
+ await asyncio.sleep(0.1)
+
try:
- await asyncio.wait_for(self._runner_ready_events.wait(), timeout=timeout_sec)
+ payload = await asyncio.wait_for(wait_for_ready(), timeout=timeout_sec)
logger.info("Runner 已完成初始化并上报就绪")
- return self._runner_ready_payloads
+ return payload
except asyncio.TimeoutError as exc:
raise TimeoutError(f"等待 Runner 就绪超时({timeout_sec}s)") from exc
@@ -923,6 +956,7 @@ class PluginRunnerSupervisor:
await self._wait_for_runner_connection(timeout_sec=self._runner_spawn_timeout)
await self._wait_for_runner_ready(timeout_sec=self._runner_spawn_timeout)
except Exception as exc:
+ await self._shutdown_runner(reason="restart_failed")
logger.error(f"Runner 重启失败: {exc}", exc_info=True)
return False
@@ -938,6 +972,25 @@ class PluginRunnerSupervisor:
self._registered_adapters.clear()
self._runner_ready_events = asyncio.Event()
self._runner_ready_payloads = RunnerReadyPayload()
+ self._rpc_server.clear_handshake_state()
+
+ def _get_runner_startup_failure_reason(self) -> Optional[str]:
+ """获取 Runner 在启动阶段已经暴露出的失败原因。
+
+ Returns:
+ Optional[str]: 若已检测到失败则返回失败原因,否则返回 ``None``。
+ """
+ if handshake_reason := self._rpc_server.last_handshake_rejection_reason:
+ return f"握手被拒绝: {handshake_reason}"
+
+ process = self._runner_process
+ if process is None:
+ return "Runner 进程不存在"
+
+ if process.returncode is not None:
+ return f"Runner 进程已退出,退出码 {process.returncode}"
+
+ return None
PluginSupervisor = PluginRunnerSupervisor
diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py
index 30a3c150..24cf09fc 100644
--- a/src/plugin_runtime/integration.py
+++ b/src/plugin_runtime/integration.py
@@ -68,6 +68,7 @@ class PluginRuntimeManager(
self._plugin_file_watcher: Optional[FileWatcher] = None
self._plugin_source_watcher_subscription_id: Optional[str] = None
self._plugin_config_watcher_subscriptions: Dict[str, Tuple[Path, str]] = {}
+ self._plugin_path_cache: Dict[str, Path] = {}
async def _dispatch_platform_inbound(self, envelope: InboundMessageEnvelope) -> None:
"""接收 Platform IO 审核后的入站消息并送入主消息链。
@@ -215,6 +216,7 @@ class PluginRuntimeManager(
self._started = False
self._builtin_supervisor = None
self._third_party_supervisor = None
+ self._plugin_path_cache.clear()
@property
def is_running(self) -> bool:
@@ -254,7 +256,7 @@ class PluginRuntimeManager(
config_payload = (
config_data
if config_data is not None
- else self._load_plugin_config_for_supervisor(plugin_id, plugin_dirs=sv._plugin_dirs)
+ else self._load_plugin_config_for_supervisor(sv, plugin_id)
)
await sv.notify_plugin_config_updated(
plugin_id=plugin_id,
@@ -452,6 +454,7 @@ class PluginRuntimeManager(
async def _stop_plugin_file_watcher(self) -> None:
"""停止插件文件监视器,并清理所有已注册订阅。"""
if self._plugin_file_watcher is None:
+ self._plugin_path_cache.clear()
return
for _plugin_id, (_config_path, subscription_id) in list(self._plugin_config_watcher_subscriptions.items()):
self._plugin_file_watcher.unsubscribe(subscription_id)
@@ -461,12 +464,95 @@ class PluginRuntimeManager(
self._plugin_source_watcher_subscription_id = None
await self._plugin_file_watcher.stop()
self._plugin_file_watcher = None
+ self._plugin_path_cache.clear()
def _iter_plugin_dirs(self) -> Iterable[Path]:
"""迭代所有 Supervisor 当前管理的插件根目录。"""
for supervisor in self.supervisors:
yield from getattr(supervisor, "_plugin_dirs", [])
+ @staticmethod
+ def _iter_candidate_plugin_paths(plugin_dirs: Iterable[Path]) -> Iterable[Path]:
+ """迭代所有可能的插件目录路径。
+
+ Args:
+ plugin_dirs: 一个或多个插件根目录。
+
+ Yields:
+ Path: 单个插件目录路径。
+ """
+ for plugin_dir in plugin_dirs:
+ plugin_root = Path(plugin_dir).resolve()
+ if not plugin_root.is_dir():
+ continue
+ for entry in plugin_root.iterdir():
+ if entry.is_dir():
+ yield entry.resolve()
+
+ @staticmethod
+ def _read_plugin_id_from_plugin_path(plugin_path: Path) -> Optional[str]:
+ """从单个插件目录中读取 manifest 声明的插件 ID。
+
+ Args:
+ plugin_path: 单个插件目录路径。
+
+ Returns:
+ Optional[str]: 解析成功时返回插件 ID,否则返回 ``None``。
+ """
+ manifest_path = plugin_path / "_manifest.json"
+ entrypoint_path = plugin_path / "plugin.py"
+ if not manifest_path.is_file() or not entrypoint_path.is_file():
+ return None
+
+ try:
+ with open(manifest_path, "r", encoding="utf-8") as manifest_file:
+ manifest = json.load(manifest_file)
+ except Exception:
+ return None
+
+ if not isinstance(manifest, dict):
+ return None
+
+ plugin_id = str(manifest.get("name", plugin_path.name)).strip() or plugin_path.name
+ return plugin_id or None
+
+ def _iter_discovered_plugin_paths(self, plugin_dirs: Iterable[Path]) -> Iterable[Tuple[str, Path]]:
+ """迭代目录中可解析到的插件 ID 与实际目录路径。
+
+ Args:
+ plugin_dirs: 一个或多个插件根目录。
+
+ Yields:
+ Tuple[str, Path]: ``(plugin_id, plugin_path)`` 二元组。
+ """
+ for plugin_path in self._iter_candidate_plugin_paths(plugin_dirs):
+ if plugin_id := self._read_plugin_id_from_plugin_path(plugin_path):
+ yield plugin_id, plugin_path
+
+ def _get_plugin_path_for_supervisor(self, supervisor: Any, plugin_id: str) -> Optional[Path]:
+ """为指定 Supervisor 定位某个插件的实际目录。
+
+ Args:
+ supervisor: 目标 Supervisor。
+ plugin_id: 插件 ID。
+
+ Returns:
+ Optional[Path]: 插件目录路径;未找到时返回 ``None``。
+ """
+ cached_path = self._plugin_path_cache.get(plugin_id)
+ if cached_path is not None:
+ for plugin_dir in getattr(supervisor, "_plugin_dirs", []):
+ if self._plugin_dir_matches(cached_path, Path(plugin_dir)):
+ return cached_path
+
+ for candidate_plugin_id, plugin_path in self._iter_discovered_plugin_paths(getattr(supervisor, "_plugin_dirs", [])):
+ if candidate_plugin_id != plugin_id:
+ continue
+ self._plugin_path_cache[plugin_id] = plugin_path
+ return plugin_path
+
+ return None
+
def _refresh_plugin_config_watch_subscriptions(self) -> None:
"""按当前已注册插件集合刷新 config.toml 的单插件订阅。
@@ -476,7 +562,11 @@ class PluginRuntimeManager(
if self._plugin_file_watcher is None:
return
- desired_config_paths = dict(self._iter_registered_plugin_config_paths())
+ desired_plugin_paths = dict(self._iter_registered_plugin_paths())
+ self._plugin_path_cache = desired_plugin_paths.copy()
+ desired_config_paths = {
+ plugin_id: plugin_path / "config.toml" for plugin_id, plugin_path in desired_plugin_paths.items()
+ }
for plugin_id, (_old_path, subscription_id) in list(self._plugin_config_watcher_subscriptions.items()):
if desired_config_paths.get(plugin_id) == self._plugin_config_watcher_subscriptions[plugin_id][0]:
@@ -509,21 +599,17 @@ class PluginRuntimeManager(
return _callback
- def _iter_registered_plugin_config_paths(self) -> Iterable[Tuple[str, Path]]:
- """迭代当前所有已注册插件的 config.toml 路径。"""
+ def _iter_registered_plugin_paths(self) -> Iterable[Tuple[str, Path]]:
+ """迭代当前所有已注册插件的实际目录路径。"""
for supervisor in self.supervisors:
for plugin_id in getattr(supervisor, "_registered_plugins", {}).keys():
- if config_path := self._get_plugin_config_path_for_supervisor(supervisor, plugin_id):
- yield plugin_id, config_path
+ if plugin_path := self._get_plugin_path_for_supervisor(supervisor, plugin_id):
+ yield plugin_id, plugin_path
def _get_plugin_config_path_for_supervisor(self, supervisor: Any, plugin_id: str) -> Optional[Path]:
"""从指定 Supervisor 的插件目录中定位某个插件的 config.toml。"""
- for plugin_dir in getattr(supervisor, "_plugin_dirs", []):
- plugin_dir = Path(plugin_dir)
- plugin_path = plugin_dir.resolve() / plugin_id
- if plugin_path.is_dir():
- return plugin_path / "config.toml"
- return None
+ plugin_path = self._get_plugin_path_for_supervisor(supervisor, plugin_id)
+ return None if plugin_path is None else plugin_path / "config.toml"
async def _handle_plugin_config_changes(self, plugin_id: str, changes: Sequence[FileChange]) -> None:
"""处理单个插件配置文件变化,并仅向目标插件推送配置更新。"""
@@ -542,7 +628,7 @@ class PluginRuntimeManager(
try:
await supervisor.notify_plugin_config_updated(
plugin_id=plugin_id,
- config_data=self._load_plugin_config_for_supervisor(plugin_id, getattr(supervisor, "_plugin_dirs", [])),
+ config_data=self._load_plugin_config_for_supervisor(supervisor, plugin_id),
)
except Exception as exc:
logger.warning(f"插件 {plugin_id} 配置热更新通知失败: {exc}")
@@ -591,32 +677,38 @@ class PluginRuntimeManager(
def _match_plugin_id_for_supervisor(self, supervisor: Any, path: Path) -> Optional[str]:
"""根据变更路径为指定 Supervisor 推断受影响的插件 ID。"""
- for plugin_id, _reg in getattr(supervisor, "_registered_plugins", {}).items():
- for plugin_dir in getattr(supervisor, "_plugin_dirs", []):
- plugin_dir = Path(plugin_dir)
- candidate_dir = plugin_dir.resolve() / plugin_id
- if path == candidate_dir or path.is_relative_to(candidate_dir):
- return plugin_id
+ resolved_path = path.resolve()
+
+ for plugin_id in getattr(supervisor, "_registered_plugins", {}).keys():
+ plugin_path = self._get_plugin_path_for_supervisor(supervisor, plugin_id)
+ if plugin_path is not None and (resolved_path == plugin_path or resolved_path.is_relative_to(plugin_path)):
+ return plugin_id
+
+ for plugin_id, plugin_path in self._plugin_path_cache.items():
+ if not any(self._plugin_dir_matches(plugin_path, Path(plugin_dir)) for plugin_dir in getattr(supervisor, "_plugin_dirs", [])):
+ continue
+ if resolved_path == plugin_path or resolved_path.is_relative_to(plugin_path):
+ return plugin_id
+
+ for plugin_id, plugin_path in self._iter_discovered_plugin_paths(getattr(supervisor, "_plugin_dirs", [])):
+ if resolved_path == plugin_path or resolved_path.is_relative_to(plugin_path):
+ self._plugin_path_cache[plugin_id] = plugin_path
+ return plugin_id
- for plugin_dir in getattr(supervisor, "_plugin_dirs", []):
- plugin_dir = Path(plugin_dir)
- plugin_root = plugin_dir.resolve()
- if self._plugin_dir_matches(path, plugin_dir) and (relative_parts := path.relative_to(plugin_root).parts):
- return relative_parts[0]
return None
- @staticmethod
- def _load_plugin_config_for_supervisor(plugin_id: str, plugin_dirs: Iterable[Path]) -> Dict[str, Any]:
+ def _load_plugin_config_for_supervisor(self, supervisor: Any, plugin_id: str) -> Dict[str, Any]:
"""从给定插件目录集合中读取目标插件的配置内容。"""
- for plugin_dir in plugin_dirs:
- plugin_path = plugin_dir.resolve() / plugin_id
- if plugin_path.is_dir():
- config_path = plugin_path / "config.toml"
- if not config_path.exists():
- return {}
- with open(config_path, "r", encoding="utf-8") as handle:
- return tomlkit.load(handle).unwrap()
- return {}
+ plugin_path = self._get_plugin_path_for_supervisor(supervisor, plugin_id)
+ if plugin_path is None:
+ return {}
+
+ config_path = plugin_path / "config.toml"
+ if not config_path.exists():
+ return {}
+
+ with open(config_path, "r", encoding="utf-8") as handle:
+ return tomlkit.load(handle).unwrap()
# ─── 能力实现注册 ──────────────────────────────────────────
diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py
index 0dfc6656..d71e02c5 100644
--- a/src/plugin_runtime/protocol/envelope.py
+++ b/src/plugin_runtime/protocol/envelope.py
@@ -17,7 +17,7 @@ from pydantic import BaseModel, Field
PROTOCOL_VERSION = "1.0.0"
# 支持的 SDK 版本范围(Host 在握手时校验)
MIN_SDK_VERSION = "1.0.0"
-MAX_SDK_VERSION = "1.99.99"
+MAX_SDK_VERSION = "2.99.99"
# ====== 消息类型 ======
diff --git a/src/plugins/built_in/napcat_adapter/plugin.py b/src/plugins/built_in/napcat_adapter/plugin.py
index 3eff518d..a481101f 100644
--- a/src/plugins/built_in/napcat_adapter/plugin.py
+++ b/src/plugins/built_in/napcat_adapter/plugin.py
@@ -14,7 +14,7 @@
from __future__ import annotations
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
+from typing import Any, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, cast
from uuid import uuid4
import asyncio
@@ -42,6 +42,13 @@ if not TYPE_CHECKING:
AiohttpClientWebSocketResponse = Any
+SUPPORTED_CONFIG_VERSION = "0.1.0"
+DEFAULT_RECONNECT_DELAY_SEC = 5.0
+DEFAULT_HEARTBEAT_SEC = 30.0
+DEFAULT_ACTION_TIMEOUT_SEC = 15.0
+DEFAULT_CHAT_LIST_TYPE = "whitelist"
+
+
@Adapter(platform="qq", protocol="napcat", send_method="send_to_platform")
class NapCatAdapterPlugin(MaiBotPlugin):
"""NapCat 适配器 MVP 实现。"""
@@ -52,7 +59,7 @@ class NapCatAdapterPlugin(MaiBotPlugin):
self._plugin_config: Dict[str, Any] = {}
self._connection_task: Optional[asyncio.Task[None]] = None
self._pending_actions: Dict[str, asyncio.Future[Dict[str, Any]]] = {}
- self._background_tasks: set[asyncio.Task[Any]] = set()
+ self._background_tasks: Set[asyncio.Task[Any]] = set()
self._send_lock = asyncio.Lock()
self._ws: Optional[AiohttpClientWebSocketResponse] = None
@@ -80,8 +87,9 @@ class NapCatAdapterPlugin(MaiBotPlugin):
new_config: 最新的插件配置。
version: 配置版本号。
"""
- del version
self.set_plugin_config(new_config)
+ if version:
+ self.ctx.logger.debug(f"NapCat 适配器收到配置更新通知: {version}")
await self._restart_connection_if_needed()
async def send_to_platform(
@@ -139,6 +147,8 @@ class NapCatAdapterPlugin(MaiBotPlugin):
if not self._should_connect():
self.ctx.logger.info("NapCat 适配器保持空闲状态,因为插件或配置未启用")
return
+ if not self._validate_current_config():
+ return
if not AIOHTTP_AVAILABLE:
self.ctx.logger.error("NapCat 适配器依赖 aiohttp,但当前环境未安装该依赖")
return
@@ -185,7 +195,7 @@ class NapCatAdapterPlugin(MaiBotPlugin):
headers = self._build_headers()
timeout = ClientTimeout(total=None, connect=10)
- heartbeat = self._get_positive_float(self._connection_config(), "heartbeat_sec", 30.0)
+ heartbeat = self._get_positive_float(self._connection_config(), "heartbeat_sec", DEFAULT_HEARTBEAT_SEC)
try:
async with ClientSession(headers=headers, timeout=timeout) as session:
@@ -204,7 +214,13 @@ class NapCatAdapterPlugin(MaiBotPlugin):
if not self._should_connect():
break
- await asyncio.sleep(self._get_positive_float(self._connection_config(), "reconnect_delay_sec", 5.0))
+ await asyncio.sleep(
+ self._get_positive_float(
+ self._connection_config(),
+ "reconnect_delay_sec",
+ DEFAULT_RECONNECT_DELAY_SEC,
+ )
+ )
async def _receive_loop(self, ws: AiohttpClientWebSocketResponse) -> None:
"""持续消费 WebSocket 消息并分发处理。
@@ -250,8 +266,11 @@ class NapCatAdapterPlugin(MaiBotPlugin):
if not sender_user_id:
return
+ group_id = str(payload.get("group_id") or "").strip()
if self_id and sender_user_id == self_id and self._get_bool(self._filters_config(), "ignore_self_message", True):
return
+ if not self._is_inbound_chat_allowed(sender_user_id, group_id):
+ return
message_dict = self._build_inbound_message_dict(payload, self_id, sender_user_id, sender)
route_metadata: Dict[str, Any] = {}
@@ -339,7 +358,7 @@ class NapCatAdapterPlugin(MaiBotPlugin):
"display_message": plain_text,
}
- def _convert_inbound_segments(self, message_payload: Any, self_id: str) -> tuple[List[Dict[str, Any]], bool]:
+ def _convert_inbound_segments(self, message_payload: Any, self_id: str) -> Tuple[List[Dict[str, Any]], bool]:
"""将 OneBot 消息段转换为 Host 消息段结构。
Args:
@@ -347,7 +366,7 @@ class NapCatAdapterPlugin(MaiBotPlugin):
self_id: 当前机器人账号 ID。
Returns:
- tuple[List[Dict[str, Any]], bool]: 转换后的消息段列表,以及是否 @ 到当前机器人。
+ Tuple[List[Dict[str, Any]], bool]: 转换后的消息段列表,以及是否 @ 到当前机器人。
"""
if isinstance(message_payload, str):
normalized_text = message_payload.strip()
@@ -412,7 +431,7 @@ class NapCatAdapterPlugin(MaiBotPlugin):
self,
message: Dict[str, Any],
route: Dict[str, Any],
- ) -> tuple[str, Dict[str, Any]]:
+ ) -> Tuple[str, Dict[str, Any]]:
"""为 Host 出站消息构造 OneBot 动作。
Args:
@@ -420,7 +439,7 @@ class NapCatAdapterPlugin(MaiBotPlugin):
route: Platform IO 路由信息。
Returns:
- tuple[str, Dict[str, Any]]: 动作名称与参数字典。
+ Tuple[str, Dict[str, Any]]: 动作名称与参数字典。
"""
message_info = message.get("message_info", {})
if not isinstance(message_info, dict):
@@ -519,7 +538,11 @@ class NapCatAdapterPlugin(MaiBotPlugin):
try:
async with self._send_lock:
await ws.send_str(json.dumps(request_payload, ensure_ascii=False))
- timeout_seconds = self._get_positive_float(self._connection_config(), "action_timeout_sec", 15.0)
+ timeout_seconds = self._get_positive_float(
+ self._connection_config(),
+ "action_timeout_sec",
+ DEFAULT_ACTION_TIMEOUT_SEC,
+ )
return await asyncio.wait_for(response_future, timeout=timeout_seconds)
finally:
self._pending_actions.pop(echo_id, None)
@@ -626,6 +649,173 @@ class NapCatAdapterPlugin(MaiBotPlugin):
filters_config = self._plugin_config.get("filters", {})
return filters_config if isinstance(filters_config, dict) else {}
+ def _chat_config(self) -> Dict[str, Any]:
+ """读取插件配置中的 ``chat`` 段。
+
+ Returns:
+ Dict[str, Any]: ``chat`` 配置字典。
+ """
+ chat_config = self._plugin_config.get("chat", {})
+ return chat_config if isinstance(chat_config, dict) else {}
+
+ def _is_inbound_chat_allowed(self, sender_user_id: str, group_id: str) -> bool:
+ """检查入站消息是否通过聊天名单过滤。
+
+ Args:
+ sender_user_id: 发送者用户 ID。
+ group_id: 群聊 ID;私聊时为空字符串。
+
+ Returns:
+ bool: 若消息允许继续进入 Host,则返回 ``True``。
+ """
+ chat_config = self._chat_config()
+ banned_user_ids = self._get_string_list(chat_config, "ban_user_id")
+ if sender_user_id in banned_user_ids:
+ self.ctx.logger.warning(f"NapCat 用户 {sender_user_id} 在全局禁止名单中,消息被丢弃")
+ return False
+
+ if group_id:
+ group_list_type = self._get_list_mode(chat_config, "group_list_type", DEFAULT_CHAT_LIST_TYPE)
+ group_id_list = self._get_string_list(chat_config, "group_list")
+ if not self._is_id_allowed_by_list_policy(group_id, group_list_type, group_id_list):
+ self.ctx.logger.warning(f"NapCat 群聊 {group_id} 未通过聊天名单过滤,消息被丢弃")
+ return False
+ return True
+
+ private_list_type = self._get_list_mode(chat_config, "private_list_type", DEFAULT_CHAT_LIST_TYPE)
+ private_id_list = self._get_string_list(chat_config, "private_list")
+ if not self._is_id_allowed_by_list_policy(sender_user_id, private_list_type, private_id_list):
+ self.ctx.logger.warning(f"NapCat 私聊用户 {sender_user_id} 未通过聊天名单过滤,消息被丢弃")
+ return False
+ return True
+
+ def _is_id_allowed_by_list_policy(
+ self,
+ target_id: str,
+ list_type: str,
+ configured_ids: Set[str],
+ ) -> bool:
+ """根据白名单或黑名单规则判断目标 ID 是否允许通过。
+
+ Args:
+ target_id: 待检查的目标 ID。
+ list_type: 名单模式,仅支持 ``whitelist`` 或 ``blacklist``。
+ configured_ids: 配置中的 ID 集合。
+
+ Returns:
+ bool: 若目标 ID 允许通过,则返回 ``True``。
+ """
+ if list_type == "whitelist":
+ return target_id in configured_ids
+ return target_id not in configured_ids
+
+ def _validate_current_config(self) -> bool:
+ """校验当前配置是否满足启动连接的前提条件。
+
+ Returns:
+ bool: 配置可用于启动连接时返回 ``True``。
+ """
+ if not self._validate_plugin_config_version():
+ return False
+
+ connection_config = self._connection_config()
+ ws_url = self._get_string(connection_config, "ws_url")
+ if not ws_url:
+ self.ctx.logger.warning("NapCat 适配器已启用,但 connection.ws_url 为空")
+ return False
+
+ self._validate_positive_float_setting(
+ connection_config,
+ "connection",
+ "reconnect_delay_sec",
+ DEFAULT_RECONNECT_DELAY_SEC,
+ )
+ self._validate_positive_float_setting(
+ connection_config,
+ "connection",
+ "heartbeat_sec",
+ DEFAULT_HEARTBEAT_SEC,
+ )
+ self._validate_positive_float_setting(
+ connection_config,
+ "connection",
+ "action_timeout_sec",
+ DEFAULT_ACTION_TIMEOUT_SEC,
+ )
+ self._validate_list_mode_setting(self._chat_config(), "chat", "group_list_type", DEFAULT_CHAT_LIST_TYPE)
+ self._validate_list_mode_setting(self._chat_config(), "chat", "private_list_type", DEFAULT_CHAT_LIST_TYPE)
+ return True
+
+ def _validate_plugin_config_version(self) -> bool:
+ """校验插件配置版本是否与当前实现兼容。
+
+ Returns:
+ bool: 版本兼容时返回 ``True``。
+ """
+ config_version = self._get_string(self._plugin_section(), "config_version")
+ if not config_version:
+ self.ctx.logger.error(
+ f"NapCat 适配器配置缺少 plugin.config_version,当前插件要求版本 {SUPPORTED_CONFIG_VERSION}"
+ )
+ return False
+
+ if config_version != SUPPORTED_CONFIG_VERSION:
+ self.ctx.logger.error(
+ "NapCat 适配器配置版本不兼容: "
+ f"当前为 {config_version},当前插件要求 {SUPPORTED_CONFIG_VERSION}"
+ )
+ return False
+
+ return True
+
+ def _validate_positive_float_setting(
+ self,
+ mapping: Dict[str, Any],
+ section_name: str,
+ key: str,
+ default: float,
+ ) -> None:
+ """校验正浮点数配置项,并在非法时输出告警日志。
+
+ Args:
+ mapping: 待读取的配置字典。
+ section_name: 当前配置段名称。
+ key: 目标配置键名。
+ default: 配置非法时实际使用的默认值。
+ """
+ value = mapping.get(key, default)
+ if isinstance(value, (int, float)) and float(value) > 0:
+ return
+
+ self.ctx.logger.warning(
+ "NapCat 适配器配置项取值无效,已回退到默认值: "
+ f"{section_name}.{key}={value!r},默认值为 {default}"
+ )
+
+ def _validate_list_mode_setting(
+ self,
+ mapping: Dict[str, Any],
+ section_name: str,
+ key: str,
+ default: str,
+ ) -> None:
+ """校验名单模式配置项,并在非法时输出告警日志。
+
+ Args:
+ mapping: 待读取的配置字典。
+ section_name: 当前配置段名称。
+ key: 目标配置键名。
+ default: 配置非法时实际使用的默认值。
+ """
+ value = mapping.get(key, default)
+ if isinstance(value, str) and value.strip() in {"whitelist", "blacklist"}:
+ return
+
+ self.ctx.logger.warning(
+ "NapCat 适配器配置项取值无效,已回退到默认值: "
+ f"{section_name}.{key}={value!r},默认值为 {default}"
+ )
+
def _should_connect(self) -> bool:
"""判断当前配置下是否应当启动连接。
@@ -680,6 +870,47 @@ class NapCatAdapterPlugin(MaiBotPlugin):
value = mapping.get(key)
return "" if value is None else str(value).strip()
+ @staticmethod
+ def _get_list_mode(mapping: Dict[str, Any], key: str, default: str) -> str:
+ """安全读取名单模式配置值。
+
+ Args:
+ mapping: 待读取的配置字典。
+ key: 目标键名。
+ default: 读取失败时的默认值。
+
+ Returns:
+ str: 合法的名单模式字符串。
+ """
+ value = mapping.get(key, default)
+ if isinstance(value, str):
+ normalized_value = value.strip()
+ if normalized_value in {"whitelist", "blacklist"}:
+ return normalized_value
+ return default
+
+ @staticmethod
+ def _get_string_list(mapping: Dict[str, Any], key: str) -> Set[str]:
+ """安全读取 ID 列表配置值。
+
+ Args:
+ mapping: 待读取的配置字典。
+ key: 目标键名。
+
+ Returns:
+ Set[str]: 去重后的字符串 ID 集合。
+ """
+ value = mapping.get(key, [])
+ if not isinstance(value, list):
+ return set()
+
+ normalized_values: Set[str] = set()
+ for item in value:
+ item_text = "" if item is None else str(item).strip()
+ if item_text:
+ normalized_values.add(item_text)
+ return normalized_values
+
def create_plugin() -> NapCatAdapterPlugin:
"""创建插件实例。
From a1859027efbf3fa5fb9178c815b350a2f7e2ae02 Mon Sep 17 00:00:00 2001
From: DrSmoothl <1787882683@qq.com>
Date: Sat, 21 Mar 2026 00:53:05 +0800
Subject: [PATCH 21/45] =?UTF-8?q?refactor:=20=E4=BC=98=E5=8C=96=E9=A2=9C?=
=?UTF-8?q?=E8=89=B2=E6=98=A0=E5=B0=84=E5=92=8C=E5=88=AB=E5=90=8D=E5=AE=9A?=
=?UTF-8?q?=E4=B9=89=EF=BC=8C=E5=A2=9E=E5=BC=BA=E6=A8=A1=E5=9D=97=E4=B8=80?=
=?UTF-8?q?=E8=87=B4=E6=80=A7?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
src/common/logger_color_and_mapping.py | 12 ++++++++++--
1 file changed, 10 insertions(+), 2 deletions(-)
diff --git a/src/common/logger_color_and_mapping.py b/src/common/logger_color_and_mapping.py
index f84caa21..4044d2dc 100644
--- a/src/common/logger_color_and_mapping.py
+++ b/src/common/logger_color_and_mapping.py
@@ -1,9 +1,8 @@
# 定义模块颜色映射
-from typing import Optional, Tuple, Dict
-
import itertools
import os
import sys
+from typing import Dict, Optional, Tuple
MODULE_COLORS: Dict[str, Tuple[str, Optional[str], bool]] = {
@@ -54,15 +53,19 @@ MODULE_COLORS: Dict[str, Tuple[str, Optional[str], bool]] = {
"component_registry": ("#ffaf00", None, False),
"plugin_runtime.integration": ("#d75f00", None, False),
"plugin_runtime.host.supervisor": ("#ff5f00", None, False),
+ "plugin_runtime.host.runner_manager": ("#ff5f00", None, False),
"plugin_runtime.host.rpc_server": ("#ff8700", None, False),
"plugin_runtime.host.component_registry": ("#ffaf00", None, False),
"plugin_runtime.host.capability_service": ("#ffd700", None, False),
"plugin_runtime.host.event_dispatcher": ("#87d700", None, False),
"plugin_runtime.host.hook_dispatcher": ("#5fd7af", None, False),
+ "plugin_runtime.host.message_gateway": ("#5fd7d7", None, False),
+ "plugin_runtime.host.message_utils": ("#5faf87", None, False),
"plugin_runtime.runner.main": ("#d787ff", None, False),
"plugin_runtime.runner.rpc_client": ("#8787ff", None, False),
"plugin_runtime.runner.manifest_validator": ("#5fafff", None, False),
"plugin_runtime.runner.plugin_loader": ("#00afaf", None, False),
+ "plugin.napcat_adapter_builtin": ("#00af87", None, False),
"webui": ("#5f87ff", None, False),
"webui.app": ("#5f87d7", None, False),
"webui.api": ("#5fafff", None, False),
@@ -157,15 +160,20 @@ MODULE_ALIASES = {
"chat_history_summarizer": "聊天概括器",
"plugin_runtime.integration": "IPC插件系统",
"plugin_runtime.host.supervisor": "插件监督器",
+ "plugin_runtime.host.runner_manager": "插件监督器",
"plugin_runtime.host.rpc_server": "插件RPC服务",
"plugin_runtime.host.component_registry": "插件组件注册",
"plugin_runtime.host.capability_service": "插件能力服务",
"plugin_runtime.host.event_dispatcher": "插件事件分发",
+ "plugin_runtime.host.hook_dispatcher": "插件Hook分发",
+ "plugin_runtime.host.message_gateway": "插件消息网关",
+ "plugin_runtime.host.message_utils": "插件消息工具",
"plugin_runtime.host.workflow_executor": "插件工作流",
"plugin_runtime.runner.main": "插件运行器",
"plugin_runtime.runner.rpc_client": "插件RPC客户端",
"plugin_runtime.runner.manifest_validator": "插件清单校验",
"plugin_runtime.runner.plugin_loader": "插件加载器",
+ "plugin.napcat_adapter_builtin": "NapCat内置适配器",
"webui": "WebUI",
"webui.app": "WebUI应用",
"webui.api": "WebUI接口",
From dd20cd4992b1c1bc250f0367317e9eef8587328f Mon Sep 17 00:00:00 2001
From: DrSmoothl <1787882683@qq.com>
Date: Sat, 21 Mar 2026 00:59:21 +0800
Subject: [PATCH 22/45] =?UTF-8?q?refactor:=20=E5=A2=9E=E5=BC=BA=E6=96=87?=
=?UTF-8?q?=E6=A1=A3=E6=B3=A8=E9=87=8A?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
src/plugin_runtime/runner/log_handler.py | 6 ++
.../runner/manifest_validator.py | 69 +++++++++++++++++++
src/plugin_runtime/runner/runner_main.py | 13 +++-
3 files changed, 87 insertions(+), 1 deletion(-)
diff --git a/src/plugin_runtime/runner/log_handler.py b/src/plugin_runtime/runner/log_handler.py
index b5a0a328..6f42940f 100644
--- a/src/plugin_runtime/runner/log_handler.py
+++ b/src/plugin_runtime/runner/log_handler.py
@@ -66,6 +66,12 @@ class RunnerIPCLogHandler(logging.Handler):
ALLOWED_LOGGER_PREFIXES: tuple[str, ...] = ("plugin.", "plugin_runtime.", "_maibot_plugin_")
def __init__(self) -> None:
+ """初始化 Runner 端日志转发处理器。
+
+ 创建有界日志缓冲区,并准备与 RPC 客户端绑定的后台刷新任务。
+ 此时不会启动任何异步任务;真正开始转发要等到 :meth:`start`
+ 被调用后才会发生。
+ """
super().__init__()
# deque(maxlen=N): append/popleft 在 CPython GIL 保护下线程安全
self._buffer: collections.deque[LogEntry] = collections.deque(maxlen=self.QUEUE_MAX)
diff --git a/src/plugin_runtime/runner/manifest_validator.py b/src/plugin_runtime/runner/manifest_validator.py
index b6990850..32429e01 100644
--- a/src/plugin_runtime/runner/manifest_validator.py
+++ b/src/plugin_runtime/runner/manifest_validator.py
@@ -18,6 +18,15 @@ class VersionComparator:
@staticmethod
def normalize_version(version: str) -> str:
+ """将版本号规范化为三段式语义版本字符串。
+
+ Args:
+ version: 原始版本号字符串。
+
+ Returns:
+ str: 规范化后的 ``major.minor.patch`` 形式版本号。
+ 当输入为空或格式非法时返回 ``0.0.0``。
+ """
if not version:
return "0.0.0"
normalized = re.sub(r"-snapshot\.\d+", "", version.strip())
@@ -30,6 +39,15 @@ class VersionComparator:
@staticmethod
def parse_version(version: str) -> Tuple[int, int, int]:
+ """将版本字符串解析为可比较的整数元组。
+
+ Args:
+ version: 原始版本号字符串。
+
+ Returns:
+ Tuple[int, int, int]: 三段式版本号对应的整数元组。
+ 当解析失败时返回 ``(0, 0, 0)``。
+ """
normalized = VersionComparator.normalize_version(version)
try:
parts = normalized.split(".")
@@ -39,6 +57,16 @@ class VersionComparator:
@staticmethod
def compare(v1: str, v2: str) -> int:
+ """比较两个版本号的大小关系。
+
+ Args:
+ v1: 第一个版本号。
+ v2: 第二个版本号。
+
+ Returns:
+ int: ``-1`` 表示 ``v1 < v2``,``1`` 表示 ``v1 > v2``,
+ ``0`` 表示两者相等。
+ """
t1 = VersionComparator.parse_version(v1)
t2 = VersionComparator.parse_version(v2)
if t1 < t2:
@@ -49,6 +77,17 @@ class VersionComparator:
@staticmethod
def is_in_range(version: str, min_version: str = "", max_version: str = "") -> Tuple[bool, str]:
+ """判断版本号是否落在给定闭区间内。
+
+ Args:
+ version: 待检查的版本号。
+ min_version: 允许的最小版本号,留空表示不限制下界。
+ max_version: 允许的最大版本号,留空表示不限制上界。
+
+ Returns:
+ Tuple[bool, str]: 第一项表示是否满足要求,第二项为失败原因;
+ 当校验通过时第二项为空字符串。
+ """
if not min_version and not max_version:
return True, ""
vn = VersionComparator.normalize_version(version)
@@ -71,6 +110,11 @@ class ManifestValidator:
SUPPORTED_MANIFEST_VERSIONS = [1, 2]
def __init__(self, host_version: str = "") -> None:
+ """初始化 Manifest 校验器。
+
+ Args:
+ host_version: 当前 Host 版本号,用于校验插件声明的兼容区间。
+ """
self._host_version = host_version
self.errors: List[str] = []
self.warnings: List[str] = []
@@ -96,6 +140,11 @@ class ManifestValidator:
return len(self.errors) == 0
def _check_required_fields(self, manifest: Dict[str, Any]) -> None:
+ """检查 Manifest 中的必填字段是否存在且非空。
+
+ Args:
+ manifest: 待校验的 Manifest 数据。
+ """
for field in self.REQUIRED_FIELDS:
if field not in manifest:
self.errors.append(f"缺少必需字段: {field}")
@@ -103,11 +152,21 @@ class ManifestValidator:
self.errors.append(f"必需字段不能为空: {field}")
def _check_manifest_version(self, manifest: Dict[str, Any]) -> None:
+ """检查 Manifest 版本号是否在当前 Runner 支持范围内。
+
+ Args:
+ manifest: 待校验的 Manifest 数据。
+ """
mv = manifest.get("manifest_version")
if mv is not None and mv not in self.SUPPORTED_MANIFEST_VERSIONS:
self.errors.append(f"不支持的 manifest_version: {mv},支持: {self.SUPPORTED_MANIFEST_VERSIONS}")
def _check_author(self, manifest: Dict[str, Any]) -> None:
+ """校验 ``author`` 字段的结构与内容。
+
+ Args:
+ manifest: 待校验的 Manifest 数据。
+ """
author = manifest.get("author")
if author is None:
return
@@ -121,6 +180,11 @@ class ManifestValidator:
self.errors.append("author 应为字符串或 {name, url} 对象")
def _check_host_compatibility(self, manifest: Dict[str, Any]) -> None:
+ """检查插件声明的 Host 兼容范围是否包含当前 Host 版本。
+
+ Args:
+ manifest: 待校验的 Manifest 数据。
+ """
host_app = manifest.get("host_application")
if not isinstance(host_app, dict) or not self._host_version:
return
@@ -131,6 +195,11 @@ class ManifestValidator:
self.errors.append(f"Host 版本不兼容: {msg} (当前 Host: {self._host_version})")
def _check_recommended(self, manifest: Dict[str, Any]) -> None:
+ """检查推荐字段是否齐备,并记录为警告而非错误。
+
+ Args:
+ manifest: 待校验的 Manifest 数据。
+ """
for field in self.RECOMMENDED_FIELDS:
if field not in manifest or not manifest[field]:
self.warnings.append(f"建议填写字段: {field}")
diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py
index 3ffb6b4b..88f92494 100644
--- a/src/plugin_runtime/runner/runner_main.py
+++ b/src/plugin_runtime/runner/runner_main.py
@@ -47,8 +47,19 @@ logger = get_logger("plugin_runtime.runner.main")
class _ContextAwarePlugin(Protocol):
+ """支持注入运行时上下文的插件协议。
+
+ 该协议用于描述 Runner 在激活插件时依赖的最小接口。
+ 只要插件实例实现了 ``_set_context`` 方法,就可以被 Runner
+ 注入 ``PluginContext`` 或兼容层上下文对象。
+ """
+
def _set_context(self, context: Any) -> None:
- """为插件注入上下文对象。"""
+ """为插件实例注入运行时上下文。
+
+ Args:
+ context: 由 Runner 构造的上下文对象。
+ """
def _install_shutdown_signal_handlers(
From 4e2e7a279e42d7f33a41bdeaf7c9b1bedd4f9953 Mon Sep 17 00:00:00 2001
From: DrSmoothl <1787882683@qq.com>
Date: Sat, 21 Mar 2026 21:47:22 +0800
Subject: [PATCH 23/45] feat: Implement adapter runtime state management and
update handling
- Added support for adapter runtime state updates in the PluginRunnerSupervisor.
- Introduced new payload classes: AdapterStateUpdatePayload and AdapterStateUpdateResultPayload for handling state updates.
- Implemented methods to bind and unbind routes based on adapter connection status.
- Enhanced the NapCat adapter to report connection state and manage runtime state.
- Added tests for adapter runtime state synchronization and database session behavior in the statistic module.
- Updated existing methods to ensure proper handling of adapter state and route bindings.
---
.../test_person_info_group_cardname.py | 355 ++++++++++++++++++
pytests/test_adapter_runtime_state.py | 162 ++++++++
pytests/test_plugin_runtime.py | 6 +-
pytests/utils_test/statistic_test.py | 115 ++++++
src/chat/utils/statistic.py | 12 +-
.../data_models/person_info_data_model.py | 96 ++++-
src/person_info/person_info.py | 79 ++--
src/plugin_runtime/host/supervisor.py | 274 +++++++++++++-
src/plugin_runtime/protocol/envelope.py | 24 ++
src/plugin_runtime/runner/runner_main.py | 7 +-
src/plugins/built_in/napcat_adapter/plugin.py | 168 ++++++++-
11 files changed, 1219 insertions(+), 79 deletions(-)
create mode 100644 pytests/common_test/test_person_info_group_cardname.py
create mode 100644 pytests/test_adapter_runtime_state.py
create mode 100644 pytests/utils_test/statistic_test.py
diff --git a/pytests/common_test/test_person_info_group_cardname.py b/pytests/common_test/test_person_info_group_cardname.py
new file mode 100644
index 00000000..62a63f43
--- /dev/null
+++ b/pytests/common_test/test_person_info_group_cardname.py
@@ -0,0 +1,355 @@
+"""人物信息群名片字段兼容测试。"""
+
+from __future__ import annotations
+
+from importlib.util import module_from_spec, spec_from_file_location
+from pathlib import Path
+from types import ModuleType, SimpleNamespace
+from typing import Any
+
+import json
+import sys
+
+import pytest
+
+from src.common.data_models.person_info_data_model import dump_group_cardname_records, parse_group_cardname_json
+
+
+class _DummyLogger:
+ """模拟日志记录器。"""
+
+ def debug(self, message: str) -> None:
+ """记录调试日志。
+
+ Args:
+ message: 日志内容。
+ """
+ del message
+
+ def info(self, message: str) -> None:
+ """记录信息日志。
+
+ Args:
+ message: 日志内容。
+ """
+ del message
+
+ def warning(self, message: str) -> None:
+ """记录警告日志。
+
+ Args:
+ message: 日志内容。
+ """
+ del message
+
+ def error(self, message: str) -> None:
+ """记录错误日志。
+
+ Args:
+ message: 日志内容。
+ """
+ del message
+
+
+class _DummyStatement:
+ """模拟 SQL 查询语句对象。"""
+
+ def where(self, condition: Any) -> "_DummyStatement":
+ """附加过滤条件。
+
+ Args:
+ condition: 过滤条件。
+
+ Returns:
+ _DummyStatement: 当前语句对象。
+ """
+ del condition
+ return self
+
+ def limit(self, value: int) -> "_DummyStatement":
+ """限制返回条数。
+
+ Args:
+ value: 条数限制。
+
+ Returns:
+ _DummyStatement: 当前语句对象。
+ """
+ del value
+ return self
+
+
+class _DummyColumn:
+ """模拟 SQLModel 列对象。"""
+
+ def is_not(self, value: Any) -> "_DummyColumn":
+ """模拟 `IS NOT` 条件构造。
+
+ Args:
+ value: 比较值。
+
+ Returns:
+ _DummyColumn: 当前列对象。
+ """
+ del value
+ return self
+
+ def __eq__(self, other: Any) -> "_DummyColumn":
+ """模拟等值条件构造。
+
+ Args:
+ other: 比较值。
+
+ Returns:
+ _DummyColumn: 当前列对象。
+ """
+ del other
+ return self
+
+
+class _DummyResult:
+ """模拟数据库查询结果。"""
+
+ def __init__(self, record: Any) -> None:
+ """初始化查询结果。
+
+ Args:
+ record: 待返回的首条记录。
+ """
+ self._record = record
+
+ def first(self) -> Any:
+ """返回第一条记录。
+
+ Returns:
+ Any: 首条记录。
+ """
+ return self._record
+
+ def all(self) -> list[Any]:
+ """返回全部结果。
+
+ Returns:
+ list[Any]: 结果列表。
+ """
+ if self._record is None:
+ return []
+ return self._record if isinstance(self._record, list) else [self._record]
+
+
+class _DummySession:
+ """模拟数据库 Session。"""
+
+ def __init__(self, record: Any) -> None:
+ """初始化 Session。
+
+ Args:
+ record: `first()` 应返回的记录。
+ """
+ self.record = record
+ self.added_records: list[Any] = []
+
+ def __enter__(self) -> "_DummySession":
+ """进入上下文管理器。
+
+ Returns:
+ _DummySession: 当前 Session。
+ """
+ return self
+
+ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+ """退出上下文管理器。
+
+ Args:
+ exc_type: 异常类型。
+ exc_val: 异常值。
+ exc_tb: 异常回溯。
+ """
+ del exc_type
+ del exc_val
+ del exc_tb
+
+ def exec(self, statement: Any) -> _DummyResult:
+ """执行查询。
+
+ Args:
+ statement: 查询语句。
+
+ Returns:
+ _DummyResult: 模拟结果对象。
+ """
+ del statement
+ return _DummyResult(self.record)
+
+ def add(self, record: Any) -> None:
+ """记录被添加的对象。
+
+ Args:
+ record: 被写入 Session 的对象。
+ """
+ self.added_records.append(record)
+
+
+class _DummyPersonInfoRecord:
+ """模拟 `PersonInfo` ORM 模型。"""
+
+ person_id = "person_id"
+ person_name = "person_name"
+
+ def __init__(self, **kwargs: Any) -> None:
+ """使用关键字参数初始化记录对象。
+
+ Args:
+ **kwargs: 字段值。
+ """
+ for key, value in kwargs.items():
+ setattr(self, key, value)
+
+
+def _load_person_module(monkeypatch: pytest.MonkeyPatch, session: _DummySession) -> ModuleType:
+ """加载带依赖桩的 `person_info` 模块。
+
+ Args:
+ monkeypatch: Pytest monkeypatch 工具。
+ session: 提供给模块使用的假数据库 Session。
+
+ Returns:
+ ModuleType: 加载后的模块对象。
+ """
+ logger_module = ModuleType("src.common.logger")
+ logger_module.get_logger = lambda name: _DummyLogger()
+ monkeypatch.setitem(sys.modules, "src.common.logger", logger_module)
+
+ database_module = ModuleType("src.common.database.database")
+ database_module.get_db_session = lambda: session
+ monkeypatch.setitem(sys.modules, "src.common.database.database", database_module)
+
+ database_model_module = ModuleType("src.common.database.database_model")
+ database_model_module.PersonInfo = _DummyPersonInfoRecord
+ monkeypatch.setitem(sys.modules, "src.common.database.database_model", database_model_module)
+
+ llm_module = ModuleType("src.llm_models.utils_model")
+
+ class _DummyLLMRequest:
+ """模拟 LLMRequest。"""
+
+ def __init__(self, model_set: Any, request_type: str) -> None:
+ """初始化假请求对象。
+
+ Args:
+ model_set: 模型配置。
+ request_type: 请求类型。
+ """
+ del model_set
+ del request_type
+
+ llm_module.LLMRequest = _DummyLLMRequest
+ monkeypatch.setitem(sys.modules, "src.llm_models.utils_model", llm_module)
+
+ config_module = ModuleType("src.config.config")
+ config_module.global_config = SimpleNamespace(bot=SimpleNamespace(nickname="MaiBot"))
+ config_module.model_config = SimpleNamespace(model_task_config=SimpleNamespace(tool_use="tool_use", utils="utils"))
+ monkeypatch.setitem(sys.modules, "src.config.config", config_module)
+
+ chat_manager_module = ModuleType("src.chat.message_receive.chat_manager")
+ chat_manager_module.chat_manager = SimpleNamespace()
+ monkeypatch.setitem(sys.modules, "src.chat.message_receive.chat_manager", chat_manager_module)
+
+ module_path = Path(__file__).resolve().parents[2] / "src" / "person_info" / "person_info.py"
+ spec = spec_from_file_location("person_info_group_cardname_test_module", module_path)
+ assert spec is not None and spec.loader is not None
+
+ module = module_from_spec(spec)
+ monkeypatch.setitem(sys.modules, spec.name, module)
+ spec.loader.exec_module(module)
+
+ monkeypatch.setattr(module, "select", lambda *args: _DummyStatement())
+ monkeypatch.setattr(module, "col", lambda field: _DummyColumn())
+ return module
+
+
+def test_parse_group_cardname_json_uses_canonical_key() -> None:
+ """群名片 JSON 解析应只使用 `group_cardname` 键名。"""
+ parsed = parse_group_cardname_json(
+ json.dumps(
+ [
+ {"group_id": "1001", "group_cardname": "现行字段"},
+ ],
+ ensure_ascii=False,
+ )
+ )
+
+ assert parsed is not None
+ assert [(item.group_id, item.group_cardname) for item in parsed] == [
+ ("1001", "现行字段"),
+ ]
+
+
+def test_dump_group_cardname_records_uses_canonical_key() -> None:
+ """群名片序列化应输出 `group_cardname` 键名。"""
+ dumped = dump_group_cardname_records(
+ [
+ {"group_id": "1001", "group_cardname": "群昵称"},
+ ]
+ )
+
+ assert json.loads(dumped) == [{"group_id": "1001", "group_cardname": "群昵称"}]
+
+
+def test_person_sync_to_database_uses_group_cardname_field(monkeypatch: pytest.MonkeyPatch) -> None:
+ """同步人物信息时应写入数据库模型的 `group_cardname` 字段。"""
+ record = _DummyPersonInfoRecord()
+ session = _DummySession(record)
+ module = _load_person_module(monkeypatch, session)
+
+ person = module.Person.__new__(module.Person)
+ person.is_known = True
+ person.person_id = "person-1"
+ person.platform = "qq"
+ person.user_id = "10001"
+ person.nickname = "看番的龙"
+ person.person_name = "看番的龙"
+ person.name_reason = "测试"
+ person.know_times = 1
+ person.know_since = 1700000000.0
+ person.last_know = 1700000100.0
+ person.memory_points = ["喜好:番剧:0.8"]
+ person.group_cardname_list = [{"group_id": "20001", "group_cardname": "白泽大人"}]
+
+ person.sync_to_database()
+
+ assert record.group_cardname == '[{"group_id": "20001", "group_cardname": "白泽大人"}]'
+ assert not hasattr(record, "group_nickname")
+
+
+def test_person_load_from_database_normalizes_group_cardname_payload(monkeypatch: pytest.MonkeyPatch) -> None:
+ """从数据库加载人物信息时应读取标准 `group_cardname` 结构。"""
+ record = _DummyPersonInfoRecord(
+ user_id="10001",
+ platform="qq",
+ is_known=True,
+ user_nickname="看番的龙",
+ person_name="看番的龙",
+ name_reason=None,
+ know_counts=2,
+ memory_points='["喜好:番剧:0.8"]',
+ group_cardname=json.dumps(
+ [
+ {"group_id": "20001", "group_cardname": "白泽大人"},
+ ],
+ ensure_ascii=False,
+ ),
+ )
+ session = _DummySession(record)
+ module = _load_person_module(monkeypatch, session)
+
+ person = module.Person.__new__(module.Person)
+ person.person_id = "person-1"
+ person.memory_points = []
+ person.group_cardname_list = []
+
+ person.load_from_database()
+
+ assert person.group_cardname_list == [
+ {"group_id": "20001", "group_cardname": "白泽大人"},
+ ]
diff --git a/pytests/test_adapter_runtime_state.py b/pytests/test_adapter_runtime_state.py
new file mode 100644
index 00000000..e82f4c8c
--- /dev/null
+++ b/pytests/test_adapter_runtime_state.py
@@ -0,0 +1,162 @@
+"""适配器运行时状态同步测试。"""
+
+from typing import Any, Dict
+
+import pytest
+
+from src.platform_io.manager import PlatformIOManager
+from src.platform_io.types import RouteKey
+from src.plugin_runtime.host.supervisor import PluginSupervisor
+from src.plugin_runtime.protocol.envelope import (
+ AdapterDeclarationPayload,
+ Envelope,
+ MessageType,
+)
+
+
+def _make_request(plugin_id: str, payload: Dict[str, Any]) -> Envelope:
+ """构造一个适配器状态更新 RPC 请求。
+
+ Args:
+ plugin_id: 目标适配器插件 ID。
+ payload: 请求载荷。
+
+ Returns:
+ Envelope: 标准 RPC 请求信封。
+ """
+ return Envelope(
+ request_id=1,
+ message_type=MessageType.REQUEST,
+ method="host.update_adapter_state",
+ plugin_id=plugin_id,
+ payload=payload,
+ )
+
+
+@pytest.mark.asyncio
+async def test_adapter_runtime_state_binds_and_unbinds_routes(monkeypatch: pytest.MonkeyPatch) -> None:
+ """连接建立后应绑定路由,断开后应撤销路由。"""
+ import src.plugin_runtime.host.supervisor as supervisor_module
+
+ platform_io_manager = PlatformIOManager()
+ monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager)
+
+ supervisor = PluginSupervisor(plugin_dirs=[])
+ adapter = AdapterDeclarationPayload(platform="qq", protocol="napcat")
+ await supervisor._register_adapter_driver("napcat_adapter_builtin", adapter)
+
+ response = await supervisor._handle_update_adapter_state(
+ _make_request(
+ "napcat_adapter_builtin",
+ {
+ "connected": True,
+ "account_id": "10001",
+ "scope": "",
+ "metadata": {},
+ },
+ )
+ )
+
+ assert response.error is None
+ assert response.payload["accepted"] is True
+ assert (
+ platform_io_manager.route_table.get_active_binding(
+ RouteKey(platform="qq", account_id="10001"),
+ exact_only=True,
+ ).driver_id
+ == "adapter:napcat_adapter_builtin"
+ )
+ assert (
+ platform_io_manager.route_table.get_active_binding(
+ RouteKey(platform="qq"),
+ exact_only=True,
+ ).driver_id
+ == "adapter:napcat_adapter_builtin"
+ )
+
+ response = await supervisor._handle_update_adapter_state(
+ _make_request(
+ "napcat_adapter_builtin",
+ {
+ "connected": False,
+ "account_id": "",
+ "scope": "",
+ "metadata": {},
+ },
+ )
+ )
+
+ assert response.error is None
+ assert response.payload["accepted"] is True
+ assert platform_io_manager.route_table.get_active_binding(
+ RouteKey(platform="qq", account_id="10001"),
+ exact_only=True,
+ ) is None
+ assert platform_io_manager.route_table.get_active_binding(RouteKey(platform="qq"), exact_only=True) is None
+
+
+@pytest.mark.asyncio
+async def test_platform_default_route_is_removed_when_multiple_exact_routes_exist(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """同一平台存在多个精确路由时不应保留默认平台路由。"""
+ import src.plugin_runtime.host.supervisor as supervisor_module
+
+ platform_io_manager = PlatformIOManager()
+ monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager)
+
+ supervisor = PluginSupervisor(plugin_dirs=[])
+ adapter = AdapterDeclarationPayload(platform="qq", protocol="napcat")
+ await supervisor._register_adapter_driver("adapter_a", adapter)
+ await supervisor._register_adapter_driver("adapter_b", adapter)
+
+ await supervisor._handle_update_adapter_state(
+ _make_request(
+ "adapter_a",
+ {
+ "connected": True,
+ "account_id": "10001",
+ "scope": "",
+ "metadata": {},
+ },
+ )
+ )
+ assert (
+ platform_io_manager.route_table.get_active_binding(
+ RouteKey(platform="qq"),
+ exact_only=True,
+ ).driver_id
+ == "adapter:adapter_a"
+ )
+
+ await supervisor._handle_update_adapter_state(
+ _make_request(
+ "adapter_b",
+ {
+ "connected": True,
+ "account_id": "10002",
+ "scope": "",
+ "metadata": {},
+ },
+ )
+ )
+ assert platform_io_manager.route_table.get_active_binding(RouteKey(platform="qq"), exact_only=True) is None
+
+ await supervisor._handle_update_adapter_state(
+ _make_request(
+ "adapter_b",
+ {
+ "connected": False,
+ "account_id": "",
+ "scope": "",
+ "metadata": {},
+ },
+ )
+ )
+ assert (
+ platform_io_manager.route_table.get_active_binding(
+ RouteKey(platform="qq"),
+ exact_only=True,
+ ).driver_id
+ == "adapter:adapter_a"
+ )
diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py
index 2c703161..5ab16c85 100644
--- a/pytests/test_plugin_runtime.py
+++ b/pytests/test_plugin_runtime.py
@@ -486,10 +486,10 @@ class TestSDK:
"timeout_ms": timeout_ms,
}
)
- if method == "cap.request":
+ if method == "cap.call":
bootstrap_methods = [call["method"] for call in self.calls[:-1]]
assert "plugin.bootstrap" in bootstrap_methods
- return SimpleNamespace(error=None, payload={"result": {"success": True}})
+ return SimpleNamespace(error=None, payload={"success": True})
return SimpleNamespace(error=None, payload={"accepted": True})
async def disconnect(self):
@@ -529,7 +529,7 @@ class TestSDK:
await runner.run()
methods = [call["method"] for call in runner._rpc_client.calls]
- assert methods == ["plugin.bootstrap", "cap.request", "plugin.register_components", "runner.ready"]
+ assert methods == ["plugin.bootstrap", "plugin.register_components", "cap.call", "runner.ready"]
class TestPluginSdkUsage:
diff --git a/pytests/utils_test/statistic_test.py b/pytests/utils_test/statistic_test.py
new file mode 100644
index 00000000..d3d8c18a
--- /dev/null
+++ b/pytests/utils_test/statistic_test.py
@@ -0,0 +1,115 @@
+"""统计模块数据库会话行为测试。"""
+
+from __future__ import annotations
+
+from contextlib import contextmanager
+from datetime import datetime, timedelta
+from types import ModuleType
+from typing import Any, Callable, Iterator
+
+import sys
+
+import pytest
+
+from src.chat.utils import statistic
+
+
+class _DummyResult:
+ """模拟 SQLModel 查询结果对象。"""
+
+ def all(self) -> list[Any]:
+ """返回空结果集。
+
+ Returns:
+ list[Any]: 空列表。
+ """
+ return []
+
+
+class _DummySession:
+ """模拟数据库 Session。"""
+
+ def exec(self, statement: Any) -> _DummyResult:
+ """执行查询语句并返回空结果。
+
+ Args:
+ statement: 待执行的查询语句。
+
+ Returns:
+ _DummyResult: 空结果对象。
+ """
+ del statement
+ return _DummyResult()
+
+
+def _build_fake_get_db_session(calls: list[bool]) -> Callable[[bool], Iterator[_DummySession]]:
+ """构造一个记录 auto_commit 参数的假会话工厂。
+
+ Args:
+ calls: 用于记录每次调用 auto_commit 参数的列表。
+
+ Returns:
+ Callable[[bool], Iterator[_DummySession]]: 可替换 `get_db_session` 的上下文管理器工厂。
+ """
+
+ @contextmanager
+ def _fake_get_db_session(auto_commit: bool = True) -> Iterator[_DummySession]:
+ """记录会话参数并返回假 Session。
+
+ Args:
+ auto_commit: 是否启用自动提交。
+
+ Yields:
+ Iterator[_DummySession]: 假 Session 对象。
+ """
+ calls.append(auto_commit)
+ yield _DummySession()
+
+ return _fake_get_db_session
+
+
+def _build_statistic_task() -> statistic.StatisticOutputTask:
+ """构造一个最小可用的统计任务实例。
+
+ Returns:
+ statistic.StatisticOutputTask: 跳过 `__init__` 的测试实例。
+ """
+ task = statistic.StatisticOutputTask.__new__(statistic.StatisticOutputTask)
+ task.name_mapping = {}
+ return task
+
+
+def _is_bot_self(platform: str, user_id: str) -> bool:
+ """返回固定的非机器人身份判断结果。
+
+ Args:
+ platform: 平台名称。
+ user_id: 用户 ID。
+
+ Returns:
+ bool: 始终返回 ``False``。
+ """
+ del platform
+ del user_id
+ return False
+
+
+def test_statistic_read_queries_disable_auto_commit(monkeypatch: pytest.MonkeyPatch) -> None:
+ """统计模块的纯读查询应关闭自动提交,避免 Session 退出后对象被 expire。"""
+ calls: list[bool] = []
+ now = datetime.now()
+ task = _build_statistic_task()
+
+ monkeypatch.setattr(statistic, "get_db_session", _build_fake_get_db_session(calls))
+
+ utils_module = ModuleType("src.chat.utils.utils")
+ utils_module.is_bot_self = _is_bot_self
+ monkeypatch.setitem(sys.modules, "src.chat.utils.utils", utils_module)
+
+ statistic.StatisticOutputTask._fetch_online_time_since(now)
+ statistic.StatisticOutputTask._fetch_model_usage_since(now)
+ task._collect_message_count_for_period([("last_hour", now - timedelta(hours=1))])
+ task._collect_interval_data(now, hours=1, interval_minutes=60)
+ task._collect_metrics_interval_data(now, hours=1, interval_hours=1)
+
+ assert calls == [False] * 9
diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py
index ede10a41..51e5e643 100644
--- a/src/chat/utils/statistic.py
+++ b/src/chat/utils/statistic.py
@@ -436,14 +436,14 @@ class StatisticOutputTask(AsyncTask):
@staticmethod
def _fetch_online_time_since(query_start_time: datetime) -> list[tuple[datetime, datetime]]:
- with get_db_session() as session:
+ with get_db_session(auto_commit=False) as session:
statement = select(OnlineTime).where(col(OnlineTime.end_timestamp) >= query_start_time)
records = session.exec(statement).all()
return [(record.start_timestamp, record.end_timestamp) for record in records]
@staticmethod
def _fetch_model_usage_since(query_start_time: datetime) -> list[dict[str, object]]:
- with get_db_session() as session:
+ with get_db_session(auto_commit=False) as session:
statement = select(ModelUsage).where(col(ModelUsage.timestamp) >= query_start_time)
records = session.exec(statement).all()
return [
@@ -664,7 +664,7 @@ class StatisticOutputTask(AsyncTask):
}
query_start_timestamp = collect_period[-1][1]
- with get_db_session() as session:
+ with get_db_session(auto_commit=False) as session:
statement = select(Messages).where(col(Messages.timestamp) >= query_start_timestamp)
messages = session.exec(statement).all()
for message in messages:
@@ -713,7 +713,7 @@ class StatisticOutputTask(AsyncTask):
# 使用 ActionRecords 中的 reply 动作次数作为回复数基准
try:
action_query_start_timestamp = collect_period[-1][1]
- with get_db_session() as session:
+ with get_db_session(auto_commit=False) as session:
statement = select(ActionRecord).where(col(ActionRecord.timestamp) >= action_query_start_timestamp)
actions = session.exec(statement).all()
for action in actions:
@@ -1750,7 +1750,7 @@ class StatisticOutputTask(AsyncTask):
# 查询消息记录
query_start_timestamp = start_time.timestamp()
- with get_db_session() as session:
+ with get_db_session(auto_commit=False) as session:
statement = select(Messages).where(col(Messages.timestamp) >= start_time)
messages = session.exec(statement).all()
for message in messages:
@@ -2131,7 +2131,7 @@ class StatisticOutputTask(AsyncTask):
# 查询消息记录
query_start_timestamp = start_time.timestamp()
- with get_db_session() as session:
+ with get_db_session(auto_commit=False) as session:
statement = select(Messages).where(col(Messages.timestamp) >= start_time)
messages = session.exec(statement).all()
for message in messages:
diff --git a/src/common/data_models/person_info_data_model.py b/src/common/data_models/person_info_data_model.py
index 4cbb62d8..1b239356 100644
--- a/src/common/data_models/person_info_data_model.py
+++ b/src/common/data_models/person_info_data_model.py
@@ -1,6 +1,6 @@
-from dataclasses import dataclass
+from dataclasses import asdict, dataclass
from datetime import datetime
-from typing import Optional, List
+from typing import Any, List, Mapping, Optional, Sequence
import json
@@ -15,6 +15,76 @@ class GroupCardnameInfo:
group_cardname: str
+def _normalize_group_cardname_item(raw_item: Mapping[str, Any]) -> Optional[GroupCardnameInfo]:
+ """将单条群名片数据规范化为统一结构。
+
+ Args:
+ raw_item: 原始群名片字典,必须包含 `group_id` 和 `group_cardname`。
+
+ Returns:
+ Optional[GroupCardnameInfo]: 规范化后的群名片信息;若数据不完整则返回 ``None``。
+ """
+ group_id = str(raw_item.get("group_id") or "").strip()
+ group_cardname = str(raw_item.get("group_cardname") or "").strip()
+ if not group_id or not group_cardname:
+ return None
+ return GroupCardnameInfo(group_id=group_id, group_cardname=group_cardname)
+
+
+def parse_group_cardname_json(group_cardname_json: Optional[str]) -> Optional[List[GroupCardnameInfo]]:
+ """解析数据库中的群名片 JSON 字段。
+
+ Args:
+ group_cardname_json: 数据库存储的群名片 JSON 字符串。
+
+ Returns:
+ Optional[List[GroupCardnameInfo]]: 解析并规范化后的群名片列表;若字段为空或无有效项则返回 ``None``。
+
+ Raises:
+ json.JSONDecodeError: 当 JSON 文本格式非法时抛出。
+ TypeError: 当输入值类型不符合 `json.loads()` 要求时抛出。
+ """
+ if not group_cardname_json:
+ return None
+
+ raw_items = json.loads(group_cardname_json)
+ if not isinstance(raw_items, list):
+ return None
+
+ normalized_items: List[GroupCardnameInfo] = []
+ for raw_item in raw_items:
+ if not isinstance(raw_item, Mapping):
+ continue
+ if normalized_item := _normalize_group_cardname_item(raw_item):
+ normalized_items.append(normalized_item)
+
+ return normalized_items or None
+
+
+def dump_group_cardname_records(
+ group_cardname_records: Optional[Sequence[GroupCardnameInfo | Mapping[str, Any]]],
+) -> str:
+ """将群名片列表序列化为数据库使用的标准 JSON 字符串。
+
+ Args:
+ group_cardname_records: 待序列化的群名片列表,支持 `GroupCardnameInfo`
+ 对象和包含 `group_id` / `group_cardname` 的字典。
+
+ Returns:
+ str: 统一使用 `group_cardname` 键名的 JSON 字符串。
+ """
+ normalized_items: List[GroupCardnameInfo] = []
+ for raw_item in group_cardname_records or []:
+ if isinstance(raw_item, GroupCardnameInfo):
+ normalized_items.append(raw_item)
+ continue
+ if isinstance(raw_item, Mapping):
+ if normalized_item := _normalize_group_cardname_item(raw_item):
+ normalized_items.append(normalized_item)
+
+ return json.dumps([asdict(item) for item in normalized_items], ensure_ascii=False)
+
+
class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]):
def __init__(
self,
@@ -58,9 +128,16 @@ class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]):
"""最后一次被认识的时间"""
@classmethod
- def from_db_instance(cls, db_record: "PersonInfo"):
- nickname_json = json.loads(db_record.group_cardname) if db_record.group_cardname else None
- group_cardname_list = [GroupCardnameInfo(**item) for item in nickname_json] if nickname_json else None
+ def from_db_instance(cls, db_record: "PersonInfo") -> "MaiPersonInfo":
+ """从数据库记录构造人物信息数据模型。
+
+ Args:
+ db_record: 数据库中的人物信息记录。
+
+ Returns:
+ MaiPersonInfo: 转换后的数据模型对象。
+ """
+ group_cardname_list = parse_group_cardname_json(db_record.group_cardname)
memory_points = json.loads(db_record.memory_points) if db_record.memory_points else None
return cls(
is_known=db_record.is_known,
@@ -78,9 +155,12 @@ class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]):
)
def to_db_instance(self) -> "PersonInfo":
- group_cardname = (
- json.dumps([gc.__dict__ for gc in self.group_cardname_list]) if self.group_cardname_list else None
- )
+ """将当前数据模型转换为数据库记录对象。
+
+ Returns:
+ PersonInfo: 可直接写入数据库的模型实例。
+ """
+ group_cardname = dump_group_cardname_records(self.group_cardname_list)
return PersonInfo(
is_known=self.is_known,
person_id=self.person_id,
diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py
index 799f56a0..15ef0049 100644
--- a/src/person_info/person_info.py
+++ b/src/person_info/person_info.py
@@ -1,22 +1,24 @@
-import hashlib
+from datetime import datetime
+from typing import Dict, Optional, Union
+
import asyncio
+import hashlib
import json
-import time
-import random
import math
+import random
+import time
from json_repair import repair_json
-from typing import Union, Optional, Dict
-from datetime import datetime
from sqlmodel import col, select
-from src.common.logger import get_logger
+from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
+from src.common.data_models.person_info_data_model import dump_group_cardname_records, parse_group_cardname_json
from src.common.database.database import get_db_session
from src.common.database.database_model import PersonInfo
-from src.llm_models.utils_model import LLMRequest
+from src.common.logger import get_logger
from src.config.config import global_config, model_config
-from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
+from src.llm_models.utils_model import LLMRequest
logger = get_logger("person_info")
@@ -26,6 +28,32 @@ relation_selection_model = LLMRequest(
)
+def _to_group_cardname_records(group_cardname_json: Optional[str]) -> list[dict[str, str]]:
+ """将数据库中的群名片 JSON 转换为 `Person` 内部使用的结构。
+
+ Args:
+ group_cardname_json: 数据库存储的群名片 JSON 字符串。
+
+ Returns:
+ list[dict[str, str]]: 统一使用 `group_cardname` 键名的群名片列表。
+
+ Raises:
+ json.JSONDecodeError: 当 JSON 文本格式非法时抛出。
+ TypeError: 当输入值类型不符合 `json.loads()` 要求时抛出。
+ """
+ group_cardname_list = parse_group_cardname_json(group_cardname_json)
+ if not group_cardname_list:
+ return []
+
+ return [
+ {
+ "group_id": group_cardname.group_id,
+ "group_cardname": group_cardname.group_cardname,
+ }
+ for group_cardname in group_cardname_list
+ ]
+
+
def get_person_id(platform: str, user_id: Union[int, str]) -> str:
"""获取唯一id"""
if "-" in platform:
@@ -231,7 +259,7 @@ class Person:
person.know_since = time.time()
person.last_know = time.time()
person.memory_points = []
- person.group_nick_name = [] # 初始化群昵称列表
+ person.group_cardname_list = [] # 初始化群名片列表
# 如果是群聊,添加群昵称
if group_id and group_nick_name:
@@ -269,7 +297,7 @@ class Person:
self.platform = platform
self.nickname = global_config.bot.nickname
self.person_name = global_config.bot.nickname
- self.group_nick_name: list[dict[str, str]] = []
+ self.group_cardname_list: list[dict[str, str]] = []
return
self.user_id = ""
@@ -308,7 +336,7 @@ class Person:
self.know_since = None
self.last_know: Optional[float] = None
self.memory_points = []
- self.group_nick_name: list[dict[str, str]] = [] # 群昵称列表,存储 {"group_id": str, "group_nick_name": str}
+ self.group_cardname_list: list[dict[str, str]] = [] # 群名片列表,存储 {"group_id": str, "group_cardname": str}
# 从数据库加载数据
self.load_from_database()
@@ -408,16 +436,16 @@ class Person:
return
# 检查是否已存在该群号的记录
- for item in self.group_nick_name:
+ for item in self.group_cardname_list:
if item.get("group_id") == group_id:
# 更新现有记录
- item["group_nick_name"] = group_nick_name
+ item["group_cardname"] = group_nick_name
self.sync_to_database()
logger.debug(f"更新用户 {self.person_id} 在群 {group_id} 的群昵称为 {group_nick_name}")
return
# 添加新记录
- self.group_nick_name.append({"group_id": group_id, "group_nick_name": group_nick_name})
+ self.group_cardname_list.append({"group_id": group_id, "group_cardname": group_nick_name})
self.sync_to_database()
logger.debug(f"添加用户 {self.person_id} 在群 {group_id} 的群昵称 {group_nick_name}")
@@ -452,20 +480,15 @@ class Person:
else:
self.memory_points = []
- # 处理group_nick_name字段(JSON格式的列表)
+ # 处理 group_cardname 字段(JSON 格式的列表)
if record.group_cardname:
try:
- loaded_group_nick_names = json.loads(record.group_cardname)
- # 确保是列表格式
- if isinstance(loaded_group_nick_names, list):
- self.group_nick_name = loaded_group_nick_names
- else:
- self.group_nick_name = []
+ self.group_cardname_list = _to_group_cardname_records(record.group_cardname)
except (json.JSONDecodeError, TypeError):
logger.warning(f"解析用户 {self.person_id} 的group_cardname字段失败,使用默认值")
- self.group_nick_name = []
+ self.group_cardname_list = []
else:
- self.group_nick_name = []
+ self.group_cardname_list = []
logger.debug(f"已从数据库加载用户 {self.person_id} 的信息")
else:
@@ -486,11 +509,7 @@ class Person:
if self.memory_points
else json.dumps([], ensure_ascii=False)
)
- group_nickname_value = (
- json.dumps(self.group_nick_name, ensure_ascii=False)
- if self.group_nick_name
- else json.dumps([], ensure_ascii=False)
- )
+ group_cardname_value = dump_group_cardname_records(self.group_cardname_list)
first_known_time = datetime.fromtimestamp(self.know_since) if self.know_since else None
last_known_time = datetime.fromtimestamp(self.last_know) if self.last_know else None
@@ -510,7 +529,7 @@ class Person:
record.first_known_time = first_known_time
record.last_known_time = last_known_time
record.memory_points = memory_points_value
- record.group_nickname = group_nickname_value
+ record.group_cardname = group_cardname_value
session.add(record)
logger.debug(f"已同步用户 {self.person_id} 的信息到数据库")
else:
@@ -526,7 +545,7 @@ class Person:
first_known_time=first_known_time,
last_known_time=last_known_time,
memory_points=memory_points_value,
- group_nickname=group_nickname_value,
+ group_cardname=group_cardname_value,
)
session.add(record)
logger.debug(f"已创建用户 {self.person_id} 的信息到数据库")
diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py
index 33091d5a..8a26af11 100644
--- a/src/plugin_runtime/host/supervisor.py
+++ b/src/plugin_runtime/host/supervisor.py
@@ -1,3 +1,4 @@
+from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
@@ -8,13 +9,15 @@ import sys
from src.common.logger import get_logger
from src.config.config import global_config
-from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, get_platform_io_manager
+from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, RouteMode, get_platform_io_manager
from src.platform_io.drivers import PluginPlatformDriver
from src.platform_io.route_key_factory import RouteKeyFactory
from src.platform_io.routing import RouteBindingConflictError
from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN
from src.plugin_runtime.protocol.envelope import (
AdapterDeclarationPayload,
+ AdapterStateUpdatePayload,
+ AdapterStateUpdateResultPayload,
BootstrapPluginPayload,
ConfigUpdatedPayload,
Envelope,
@@ -46,6 +49,19 @@ if TYPE_CHECKING:
logger = get_logger("plugin_runtime.host.runner_manager")
+_ADAPTER_BINDING_ROLE_RUNTIME_EXACT = "runtime_exact"
+_ADAPTER_BINDING_ROLE_PLATFORM_DEFAULT = "platform_default"
+
+
+@dataclass(slots=True)
+class _AdapterRuntimeState:
+ """保存适配器插件当前的运行时连接状态。"""
+
+ connected: bool = False
+ account_id: Optional[str] = None
+ scope: Optional[str] = None
+ metadata: Dict[str, Any] = field(default_factory=dict)
+
class PluginRunnerSupervisor:
"""插件 Runner 监督器。
@@ -94,6 +110,7 @@ class PluginRunnerSupervisor:
self._runner_process: Optional[asyncio.subprocess.Process] = None
self._registered_plugins: Dict[str, RegisterPluginPayload] = {}
self._registered_adapters: Dict[str, AdapterDeclarationPayload] = {}
+ self._adapter_runtime_states: Dict[str, _AdapterRuntimeState] = {}
self._runner_ready_events: asyncio.Event = asyncio.Event()
self._runner_ready_payloads: RunnerReadyPayload = RunnerReadyPayload()
self._health_task: Optional[asyncio.Task[None]] = None
@@ -452,6 +469,7 @@ class PluginRunnerSupervisor:
"""注册 Host 侧内部 RPC 方法。"""
self._rpc_server.register_method("cap.call", self._capability_service.handle_capability_request)
self._rpc_server.register_method("host.receive_external_message", self._handle_receive_external_message)
+ self._rpc_server.register_method("host.update_adapter_state", self._handle_update_adapter_state)
self._rpc_server.register_method("plugin.bootstrap", self._handle_bootstrap_plugin)
self._rpc_server.register_method("plugin.register_components", self._handle_register_plugin)
self._rpc_server.register_method("plugin.register_plugin", self._handle_register_plugin)
@@ -563,14 +581,14 @@ class PluginRunnerSupervisor:
return f"adapter:{plugin_id}"
async def _register_adapter_driver(self, plugin_id: str, adapter: AdapterDeclarationPayload) -> None:
- """将适配器插件注册到 Platform IO。
+ """将适配器插件驱动注册到 Platform IO。
Args:
plugin_id: 适配器插件 ID。
adapter: 经过校验的适配器声明。
Raises:
- ValueError: 适配器路由冲突或驱动注册失败时抛出。
+ ValueError: 当驱动注册失败时抛出。
"""
await self._unregister_adapter_driver(plugin_id)
@@ -588,22 +606,12 @@ class PluginRunnerSupervisor:
**adapter.metadata,
},
)
- binding = RouteBinding(
- route_key=driver.descriptor.route_key,
- driver_id=driver.driver_id,
- driver_kind=DriverKind.PLUGIN,
- metadata={
- "plugin_id": plugin_id,
- "protocol": adapter.protocol,
- },
- )
try:
if platform_io_manager.is_started:
await platform_io_manager.add_driver(driver)
else:
platform_io_manager.register_driver(driver)
- platform_io_manager.bind_route(binding)
except Exception:
with contextlib.suppress(Exception):
if platform_io_manager.is_started:
@@ -613,6 +621,7 @@ class PluginRunnerSupervisor:
raise
self._registered_adapters[plugin_id] = adapter
+ self._adapter_runtime_states[plugin_id] = _AdapterRuntimeState()
async def _unregister_adapter_driver(self, plugin_id: str) -> None:
"""从 Platform IO 注销一个适配器驱动。
@@ -622,6 +631,9 @@ class PluginRunnerSupervisor:
"""
platform_io_manager = get_platform_io_manager()
driver_id = self._build_adapter_driver_id(plugin_id)
+ adapter = self._registered_adapters.get(plugin_id)
+
+ self._remove_adapter_route_bindings(plugin_id)
with contextlib.suppress(Exception):
if platform_io_manager.is_started:
@@ -629,7 +641,11 @@ class PluginRunnerSupervisor:
else:
platform_io_manager.unregister_driver(driver_id)
+ if adapter is not None:
+ self._refresh_platform_default_route(adapter.platform)
+
self._registered_adapters.pop(plugin_id, None)
+ self._adapter_runtime_states.pop(plugin_id, None)
async def _unregister_all_adapter_drivers(self) -> None:
"""注销当前 Supervisor 管理的全部适配器驱动。"""
@@ -637,6 +653,198 @@ class PluginRunnerSupervisor:
for plugin_id in plugin_ids:
await self._unregister_adapter_driver(plugin_id)
+ def _remove_adapter_route_bindings(self, plugin_id: str) -> None:
+ """移除某个适配器驱动当前持有的全部路由绑定。
+
+ Args:
+ plugin_id: 适配器插件 ID。
+ """
+ platform_io_manager = get_platform_io_manager()
+ platform_io_manager.route_table.remove_bindings_by_driver(self._build_adapter_driver_id(plugin_id))
+
+ @staticmethod
+ def _normalize_runtime_route_value(value: str) -> Optional[str]:
+ """规范化适配器运行时路由字段。
+
+ Args:
+ value: 待规范化的原始字符串。
+
+ Returns:
+ Optional[str]: 规范化后非空则返回字符串,否则返回 ``None``。
+ """
+ normalized_value = str(value).strip()
+ return normalized_value or None
+
+ def _build_runtime_route_key(
+ self,
+ adapter: AdapterDeclarationPayload,
+ payload: AdapterStateUpdatePayload,
+ ) -> RouteKey:
+ """根据运行时状态更新构造适配器生效路由键。
+
+ Args:
+ adapter: 当前适配器声明。
+ payload: 适配器上报的运行时状态。
+
+ Returns:
+ RouteKey: 当前连接应接管的精确路由键。
+
+ Raises:
+ ValueError: 当静态声明与运行时上报的身份信息冲突时抛出。
+ """
+ runtime_account_id = self._normalize_runtime_route_value(payload.account_id)
+ runtime_scope = self._normalize_runtime_route_value(payload.scope)
+
+ if adapter.account_id and runtime_account_id and adapter.account_id != runtime_account_id:
+ raise ValueError(
+ f"适配器声明的 account_id={adapter.account_id} 与运行时上报的 {runtime_account_id} 不一致"
+ )
+ if adapter.scope and runtime_scope and adapter.scope != runtime_scope:
+ raise ValueError(f"适配器声明的 scope={adapter.scope} 与运行时上报的 {runtime_scope} 不一致")
+
+ return RouteKey(
+ platform=adapter.platform,
+ account_id=runtime_account_id or adapter.account_id or None,
+ scope=runtime_scope or adapter.scope or None,
+ )
+
+ def _bind_runtime_exact_route(
+ self,
+ plugin_id: str,
+ adapter: AdapterDeclarationPayload,
+ route_key: RouteKey,
+ ) -> None:
+ """为适配器连接绑定精确生效路由。
+
+ Args:
+ plugin_id: 适配器插件 ID。
+ adapter: 当前适配器声明。
+ route_key: 当前连接对应的精确路由键。
+
+ Raises:
+ RouteBindingConflictError: 当目标路由已被其他 active owner 占用时抛出。
+ """
+ platform_io_manager = get_platform_io_manager()
+ platform_io_manager.bind_route(
+ RouteBinding(
+ route_key=route_key,
+ driver_id=self._build_adapter_driver_id(plugin_id),
+ driver_kind=DriverKind.PLUGIN,
+ metadata={
+ "plugin_id": plugin_id,
+ "protocol": adapter.protocol,
+ "binding_role": _ADAPTER_BINDING_ROLE_RUNTIME_EXACT,
+ },
+ )
+ )
+
+ def _list_runtime_exact_bindings(self, platform: str) -> List[RouteBinding]:
+ """列出某个平台上由 Host 动态维护的精确适配器绑定。
+
+ Args:
+ platform: 目标平台名称。
+
+ Returns:
+ List[RouteBinding]: 当前平台上全部动态精确绑定。
+ """
+ platform_io_manager = get_platform_io_manager()
+ return [
+ binding
+ for binding in platform_io_manager.route_table.list_bindings()
+ if binding.mode == RouteMode.ACTIVE
+ and binding.route_key.platform == platform
+ and binding.metadata.get("binding_role") == _ADAPTER_BINDING_ROLE_RUNTIME_EXACT
+ ]
+
+ def _refresh_platform_default_route(self, platform: str) -> None:
+ """根据当前精确绑定数量刷新平台级默认路由。
+
+ 当某个平台恰好只存在一个动态精确绑定时,会为该绑定额外创建一条
+ ``RouteKey(platform=)`` 形式的默认路由,方便缺少账号维度的
+ 出站消息继续找到唯一 owner。若精确绑定数量变为 0 或大于 1,则撤销
+ 由 Host 自动维护的默认路由,避免出现隐式歧义。
+
+ Args:
+ platform: 目标平台名称。
+ """
+ platform_io_manager = get_platform_io_manager()
+ default_route_key = RouteKey(platform=platform)
+ existing_default_binding = platform_io_manager.route_table.get_active_binding(default_route_key, exact_only=True)
+
+ if existing_default_binding is not None:
+ binding_role = existing_default_binding.metadata.get("binding_role")
+ if binding_role != _ADAPTER_BINDING_ROLE_PLATFORM_DEFAULT:
+ return
+ platform_io_manager.unbind_route(default_route_key, existing_default_binding.driver_id)
+
+ exact_bindings = self._list_runtime_exact_bindings(platform)
+ if len(exact_bindings) != 1:
+ return
+
+ exact_binding = exact_bindings[0]
+ if exact_binding.route_key == default_route_key:
+ return
+
+ platform_io_manager.bind_route(
+ RouteBinding(
+ route_key=default_route_key,
+ driver_id=exact_binding.driver_id,
+ driver_kind=exact_binding.driver_kind,
+ metadata={
+ "plugin_id": exact_binding.metadata.get("plugin_id", ""),
+ "protocol": exact_binding.metadata.get("protocol", ""),
+ "binding_role": _ADAPTER_BINDING_ROLE_PLATFORM_DEFAULT,
+ },
+ ),
+ replace=True,
+ )
+
+ def _apply_adapter_runtime_state(
+ self,
+ plugin_id: str,
+ adapter: AdapterDeclarationPayload,
+ payload: AdapterStateUpdatePayload,
+ ) -> Tuple[_AdapterRuntimeState, Dict[str, Any]]:
+ """应用适配器运行时状态,并同步 Platform IO 路由。
+
+ Args:
+ plugin_id: 适配器插件 ID。
+ adapter: 当前适配器声明。
+ payload: 适配器上报的运行时状态。
+
+ Returns:
+ Tuple[_AdapterRuntimeState, Dict[str, Any]]: 更新后的运行时状态,以及
+ 供 RPC 响应返回的路由键字典。
+
+ Raises:
+ RouteBindingConflictError: 当新的精确路由与其他 active owner 冲突时抛出。
+ ValueError: 当运行时路由信息不合法时抛出。
+ """
+ if not payload.connected:
+ self._remove_adapter_route_bindings(plugin_id)
+ self._refresh_platform_default_route(adapter.platform)
+ runtime_state = _AdapterRuntimeState(connected=False, metadata=dict(payload.metadata))
+ self._adapter_runtime_states[plugin_id] = runtime_state
+ return runtime_state, {}
+
+ route_key = self._build_runtime_route_key(adapter, payload)
+ self._remove_adapter_route_bindings(plugin_id)
+ self._bind_runtime_exact_route(plugin_id, adapter, route_key)
+ self._refresh_platform_default_route(adapter.platform)
+
+ runtime_state = _AdapterRuntimeState(
+ connected=True,
+ account_id=route_key.account_id,
+ scope=route_key.scope,
+ metadata=dict(payload.metadata),
+ )
+ self._adapter_runtime_states[plugin_id] = runtime_state
+ return runtime_state, {
+ "platform": route_key.platform,
+ "account_id": route_key.account_id,
+ "scope": route_key.scope,
+ }
+
@staticmethod
def _attach_inbound_route_metadata(
session_message: "SessionMessage",
@@ -706,6 +914,45 @@ class PluginRunnerSupervisor:
scope=scope,
)
+ async def _handle_update_adapter_state(self, envelope: Envelope) -> Envelope:
+ """处理适配器插件上报的运行时状态更新。
+
+ Args:
+ envelope: RPC 请求信封。
+
+ Returns:
+ Envelope: 状态更新处理结果。
+ """
+ try:
+ payload = AdapterStateUpdatePayload.model_validate(envelope.payload)
+ except Exception as exc:
+ return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
+
+ adapter = self._registered_adapters.get(envelope.plugin_id)
+ if adapter is None:
+ return envelope.make_error_response(
+ ErrorCode.E_METHOD_NOT_ALLOWED.value,
+ f"插件 {envelope.plugin_id} 未声明为适配器,不能更新运行时状态",
+ )
+
+ try:
+ runtime_state, route_key_dict = self._apply_adapter_runtime_state(
+ plugin_id=envelope.plugin_id,
+ adapter=adapter,
+ payload=payload,
+ )
+ except RouteBindingConflictError as exc:
+ return envelope.make_error_response(ErrorCode.E_METHOD_NOT_ALLOWED.value, str(exc))
+ except Exception as exc:
+ return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
+
+ response = AdapterStateUpdateResultPayload(
+ accepted=True,
+ connected=runtime_state.connected,
+ route_key=route_key_dict,
+ )
+ return envelope.make_response(payload=response.model_dump())
+
async def _handle_receive_external_message(self, envelope: Envelope) -> Envelope:
"""处理适配器插件上报的外部入站消息。
@@ -970,6 +1217,7 @@ class PluginRunnerSupervisor:
self._component_registry.clear()
self._registered_plugins.clear()
self._registered_adapters.clear()
+ self._adapter_runtime_states.clear()
self._runner_ready_events = asyncio.Event()
self._runner_ready_payloads = RunnerReadyPayload()
self._rpc_server.clear_handshake_state()
diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py
index d71e02c5..f68657fa 100644
--- a/src/plugin_runtime/protocol/envelope.py
+++ b/src/plugin_runtime/protocol/envelope.py
@@ -304,6 +304,30 @@ class AdapterDeclarationPayload(BaseModel):
"""适配器附加元数据"""
+class AdapterStateUpdatePayload(BaseModel):
+ """适配器运行时状态更新载荷。"""
+
+ connected: bool = Field(description="适配器当前是否已连接并准备接管路由")
+ """适配器当前是否已连接并准备接管路由"""
+ account_id: str = Field(default="", description="当前连接对应的账号 ID 或 self_id")
+ """当前连接对应的账号 ID 或 self_id"""
+ scope: str = Field(default="", description="当前连接对应的可选路由作用域")
+ """当前连接对应的可选路由作用域"""
+ metadata: Dict[str, Any] = Field(default_factory=dict, description="可选的运行时状态元数据")
+ """可选的运行时状态元数据"""
+
+
+class AdapterStateUpdateResultPayload(BaseModel):
+ """适配器运行时状态更新结果载荷。"""
+
+ accepted: bool = Field(description="Host 是否接受了本次状态更新")
+ """Host 是否接受了本次状态更新"""
+ connected: bool = Field(description="Host 记录的当前连接状态")
+ """Host 记录的当前连接状态"""
+ route_key: Dict[str, Any] = Field(default_factory=dict, description="当前生效的路由键")
+ """当前生效的路由键"""
+
+
class ReceiveExternalMessagePayload(BaseModel):
"""适配器插件向 Host 注入外部消息的请求载荷。"""
diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py
index 88f92494..8078c88b 100644
--- a/src/plugin_runtime/runner/runner_main.py
+++ b/src/plugin_runtime/runner/runner_main.py
@@ -481,13 +481,14 @@ class PluginRunner:
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
return False
- if not await self._invoke_plugin_on_load(meta):
+ if not await self._register_plugin(meta):
+ await self._invoke_plugin_on_unload(meta)
await self._deactivate_plugin(meta)
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
return False
- if not await self._register_plugin(meta):
- await self._invoke_plugin_on_unload(meta)
+ if not await self._invoke_plugin_on_load(meta):
+ await self._unregister_plugin(meta.plugin_id, reason="on_load_failed")
await self._deactivate_plugin(meta)
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
return False
diff --git a/src/plugins/built_in/napcat_adapter/plugin.py b/src/plugins/built_in/napcat_adapter/plugin.py
index a481101f..c8bb837b 100644
--- a/src/plugins/built_in/napcat_adapter/plugin.py
+++ b/src/plugins/built_in/napcat_adapter/plugin.py
@@ -60,6 +60,9 @@ class NapCatAdapterPlugin(MaiBotPlugin):
self._connection_task: Optional[asyncio.Task[None]] = None
self._pending_actions: Dict[str, asyncio.Future[Dict[str, Any]]] = {}
self._background_tasks: Set[asyncio.Task[Any]] = set()
+ self._reported_account_id: Optional[str] = None
+ self._reported_scope: Optional[str] = None
+ self._runtime_state_connected: bool = False
self._send_lock = asyncio.Lock()
self._ws: Optional[AiohttpClientWebSocketResponse] = None
@@ -170,6 +173,7 @@ class NapCatAdapterPlugin(MaiBotPlugin):
with contextlib.suppress(asyncio.CancelledError):
await connection_task
+ await self._report_adapter_disconnected()
self._fail_pending_actions("NapCat connection closed")
async def _cancel_background_tasks(self) -> None:
@@ -209,6 +213,7 @@ class NapCatAdapterPlugin(MaiBotPlugin):
self.ctx.logger.warning(f"NapCat 适配器连接失败: {exc}")
finally:
self._ws = None
+ await self._report_adapter_disconnected()
self._fail_pending_actions("NapCat connection interrupted")
if not self._should_connect():
@@ -230,26 +235,39 @@ class NapCatAdapterPlugin(MaiBotPlugin):
"""
assert WSMsgType is not None
- async for ws_message in ws:
- if ws_message.type != WSMsgType.TEXT:
- if ws_message.type in {WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.ERROR}:
- break
- continue
+ bootstrap_task = asyncio.create_task(
+ self._bootstrap_adapter_runtime_state(),
+ name="napcat_adapter.bootstrap",
+ )
+ self._background_tasks.add(bootstrap_task)
+ bootstrap_task.add_done_callback(self._background_tasks.discard)
- payload = self._parse_json_message(ws_message.data)
- if payload is None:
- continue
+ try:
+ async for ws_message in ws:
+ if ws_message.type != WSMsgType.TEXT:
+ if ws_message.type in {WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.ERROR}:
+ break
+ continue
- if echo_id := str(payload.get("echo") or "").strip():
- self._resolve_pending_action(echo_id, payload)
- continue
+ payload = self._parse_json_message(ws_message.data)
+ if payload is None:
+ continue
- if str(payload.get("post_type") or "").strip() != "message":
- continue
+ if echo_id := str(payload.get("echo") or "").strip():
+ self._resolve_pending_action(echo_id, payload)
+ continue
- task = asyncio.create_task(self._handle_inbound_message(payload), name="napcat_adapter.inbound")
- self._background_tasks.add(task)
- task.add_done_callback(self._background_tasks.discard)
+ if str(payload.get("post_type") or "").strip() != "message":
+ continue
+
+ task = asyncio.create_task(self._handle_inbound_message(payload), name="napcat_adapter.inbound")
+ self._background_tasks.add(task)
+ task.add_done_callback(self._background_tasks.discard)
+ finally:
+ if not bootstrap_task.done():
+ bootstrap_task.cancel()
+ with contextlib.suppress(asyncio.CancelledError):
+ await bootstrap_task
async def _handle_inbound_message(self, payload: Dict[str, Any]) -> None:
"""处理单条 NapCat 入站消息并注入 Host。
@@ -258,6 +276,9 @@ class NapCatAdapterPlugin(MaiBotPlugin):
payload: NapCat / OneBot 推送的原始事件数据。
"""
self_id = str(payload.get("self_id") or "").strip()
+ if self_id:
+ await self._report_adapter_connected(self_id)
+
sender = payload.get("sender", {})
if not isinstance(sender, dict):
sender = {}
@@ -570,6 +591,121 @@ class NapCatAdapterPlugin(MaiBotPlugin):
response_future.set_exception(RuntimeError(error_message))
self._pending_actions.clear()
+ async def _bootstrap_adapter_runtime_state(self) -> None:
+ """在连接建立后主动获取账号信息并激活适配器路由。
+
+ 该步骤会在 WebSocket 接收循环启动后异步执行,确保 `_call_action()`
+ 发出的 `get_login_info` 请求能够被同一连接上的接收循环消费到 echo
+ 响应,从而在真正收到业务消息前就完成 Host 侧 route 激活。
+ """
+ max_attempts = 3
+ last_error: Optional[Exception] = None
+ for attempt in range(1, max_attempts + 1):
+ ws = self._ws
+ if ws is None or ws.closed:
+ return
+
+ try:
+ response = await self._call_action("get_login_info", {})
+ self_id = self._extract_self_id_from_login_response(response)
+ await self._report_adapter_connected(self_id)
+ return
+ except asyncio.CancelledError:
+ raise
+ except Exception as exc:
+ last_error = exc
+ self.ctx.logger.warning(
+ f"NapCat 适配器获取登录信息失败,第 {attempt}/{max_attempts} 次重试: {exc}"
+ )
+ if attempt < max_attempts:
+ await asyncio.sleep(1.0)
+
+ if last_error is not None:
+ self.ctx.logger.error(f"NapCat 适配器未能完成路由激活,连接将保持只接收状态: {last_error}")
+
+ @staticmethod
+ def _extract_self_id_from_login_response(response: Dict[str, Any]) -> str:
+ """从 `get_login_info` 响应中提取当前账号 ID。
+
+ Args:
+ response: NapCat 返回的原始动作响应。
+
+ Returns:
+ str: 规范化后的 `self_id` 字符串。
+
+ Raises:
+ ValueError: 当响应中缺少有效账号 ID 时抛出。
+ """
+ if str(response.get("status") or "").lower() != "ok":
+ raise ValueError(str(response.get("wording") or response.get("message") or "get_login_info failed"))
+
+ response_data = response.get("data", {})
+ if not isinstance(response_data, dict):
+ raise ValueError("get_login_info 响应缺少 data 字段")
+
+ self_id = str(response_data.get("user_id") or "").strip()
+ if not self_id:
+ raise ValueError("get_login_info 响应缺少有效的 user_id")
+ return self_id
+
+ async def _report_adapter_connected(self, account_id: str) -> None:
+ """向 Host 上报当前连接已就绪。
+
+ Args:
+ account_id: 当前 NapCat 连接对应的机器人账号 ID。
+ """
+ normalized_account_id = str(account_id).strip()
+ if not normalized_account_id:
+ return
+
+ scope = self._get_string(self._connection_config(), "connection_id").strip()
+ if (
+ self._runtime_state_connected
+ and self._reported_account_id == normalized_account_id
+ and self._reported_scope == (scope or None)
+ ):
+ return
+
+ accepted = False
+ try:
+ accepted = await self.ctx.adapter.update_runtime_state(
+ connected=True,
+ account_id=normalized_account_id,
+ scope=scope,
+ metadata={"ws_url": self._get_string(self._connection_config(), "ws_url")},
+ )
+ except Exception as exc:
+ self.ctx.logger.warning(f"NapCat 适配器上报连接就绪状态失败: {exc}")
+ return
+
+ if not accepted:
+ self.ctx.logger.warning("NapCat 适配器连接已建立,但 Host 未接受运行时状态更新")
+ return
+
+ self._runtime_state_connected = True
+ self._reported_account_id = normalized_account_id
+ self._reported_scope = scope or None
+ self.ctx.logger.info(
+ f"NapCat 适配器已激活路由: platform=qq account_id={normalized_account_id} "
+ f"scope={self._reported_scope or '*'}"
+ )
+
+ async def _report_adapter_disconnected(self) -> None:
+ """向 Host 上报当前连接已断开,并撤销适配器路由。"""
+ if not self._runtime_state_connected:
+ self._reported_account_id = None
+ self._reported_scope = None
+ return
+
+ try:
+ await self.ctx.adapter.update_runtime_state(connected=False)
+ except Exception as exc:
+ self.ctx.logger.warning(f"NapCat 适配器上报断开状态失败: {exc}")
+ finally:
+ self._runtime_state_connected = False
+ self._reported_account_id = None
+ self._reported_scope = None
+
def _build_headers(self) -> Dict[str, str]:
"""构造连接 NapCat 所需的请求头。
From baabe4463ebea09d864b1dff8b78e04e78823751 Mon Sep 17 00:00:00 2001
From: DrSmoothl <1787882683@qq.com>
Date: Sun, 22 Mar 2026 00:19:26 +0800
Subject: [PATCH 24/45] feat: add NapCat built-in adapter with configuration,
filters, and transport layer
- Implemented configuration parsing for NapCat adapter including server, chat, and filter settings.
- Added message filtering logic to handle inbound chat messages based on user and group lists.
- Developed a transport layer for WebSocket communication with the NapCat server.
- Created a query service for fetching user and group information from the QQ platform.
- Implemented runtime state management to report connection status to the host.
- Added notice handling for various QQ platform events.
---
pytests/test_napcat_adapter_codec.py | 70 ++
pytests/test_napcat_adapter_config.py | 91 ++
pytests/test_platform_io_dedupe.py | 164 +++
pytests/test_plugin_message_utils_runtime.py | 87 ++
src/platform_io/manager.py | 35 +-
src/platform_io/types.py | 5 +-
src/plugin_runtime/host/message_utils.py | 268 ++++-
.../built_in/napcat_adapter/__init__.py | 1 +
.../built_in/napcat_adapter/codec_inbound.py | 414 +++++++
.../built_in/napcat_adapter/codec_outbound.py | 192 +++
src/plugins/built_in/napcat_adapter/config.py | 398 +++++++
.../built_in/napcat_adapter/constants.py | 9 +
.../built_in/napcat_adapter/filters.py | 68 ++
src/plugins/built_in/napcat_adapter/plugin.py | 1049 +++--------------
.../built_in/napcat_adapter/qq_notice.py | 224 ++++
.../built_in/napcat_adapter/qq_queries.py | 170 +++
.../built_in/napcat_adapter/runtime_state.py | 85 ++
.../built_in/napcat_adapter/transport.py | 322 +++++
18 files changed, 2755 insertions(+), 897 deletions(-)
create mode 100644 pytests/test_napcat_adapter_codec.py
create mode 100644 pytests/test_napcat_adapter_config.py
create mode 100644 pytests/test_platform_io_dedupe.py
create mode 100644 pytests/test_plugin_message_utils_runtime.py
create mode 100644 src/plugins/built_in/napcat_adapter/__init__.py
create mode 100644 src/plugins/built_in/napcat_adapter/codec_inbound.py
create mode 100644 src/plugins/built_in/napcat_adapter/codec_outbound.py
create mode 100644 src/plugins/built_in/napcat_adapter/config.py
create mode 100644 src/plugins/built_in/napcat_adapter/constants.py
create mode 100644 src/plugins/built_in/napcat_adapter/filters.py
create mode 100644 src/plugins/built_in/napcat_adapter/qq_notice.py
create mode 100644 src/plugins/built_in/napcat_adapter/qq_queries.py
create mode 100644 src/plugins/built_in/napcat_adapter/runtime_state.py
create mode 100644 src/plugins/built_in/napcat_adapter/transport.py
diff --git a/pytests/test_napcat_adapter_codec.py b/pytests/test_napcat_adapter_codec.py
new file mode 100644
index 00000000..6f557e08
--- /dev/null
+++ b/pytests/test_napcat_adapter_codec.py
@@ -0,0 +1,70 @@
+from pathlib import Path
+from typing import Any, Dict
+
+import importlib
+import sys
+
+
+BUILT_IN_PLUGIN_ROOT = Path(__file__).resolve().parents[1] / "src" / "plugins" / "built_in"
+if str(BUILT_IN_PLUGIN_ROOT) not in sys.path:
+ sys.path.insert(0, str(BUILT_IN_PLUGIN_ROOT))
+
+NapCatOutboundCodec = importlib.import_module("napcat_adapter.codec_outbound").NapCatOutboundCodec
+
+
+def test_napcat_outbound_codec_supports_binary_and_forward_segments() -> None:
+ codec = NapCatOutboundCodec()
+ raw_message = [
+ {"type": "text", "data": "hello"},
+ {"type": "image", "data": "", "hash": "h1", "binary_data_base64": "aW1hZ2U="},
+ {"type": "emoji", "data": "", "hash": "h2", "binary_data_base64": "ZW1vamk="},
+ {"type": "voice", "data": "", "hash": "h3", "binary_data_base64": "dm9pY2U="},
+ {
+ "type": "reply",
+ "data": {
+ "target_message_id": "origin-1",
+ "target_message_content": "origin text",
+ },
+ },
+ {
+ "type": "forward",
+ "data": [
+ {
+ "user_id": "42",
+ "user_nickname": "alice",
+ "user_cardname": "Alice",
+ "message_id": "fwd-1",
+ "content": [{"type": "text", "data": "node-text"}],
+ }
+ ],
+ },
+ ]
+
+ converted = codec.convert_segments(raw_message)
+
+ assert converted[0] == {"type": "text", "data": {"text": "hello"}}
+ assert converted[1]["type"] == "image"
+ assert converted[1]["data"]["file"] == "base64://aW1hZ2U="
+ assert converted[2]["type"] == "image"
+ assert converted[2]["data"]["subtype"] == 1
+ assert converted[3] == {"type": "record", "data": {"file": "base64://dm9pY2U="}}
+ assert converted[4] == {"type": "reply", "data": {"id": "origin-1"}}
+ assert converted[5]["type"] == "node"
+ assert converted[5]["data"]["name"] == "alice"
+ assert converted[5]["data"]["content"] == [{"type": "text", "data": {"text": "node-text"}}]
+
+
+def test_napcat_outbound_codec_builds_private_action_from_route_metadata() -> None:
+ codec = NapCatOutboundCodec()
+ message: Dict[str, Any] = {
+ "message_info": {
+ "user_info": {"user_id": "10001", "user_nickname": "tester"},
+ "additional_config": {},
+ },
+ "raw_message": [{"type": "text", "data": "hello"}],
+ }
+
+ action_name, params = codec.build_outbound_action(message, {"target_user_id": "30001"})
+
+ assert action_name == "send_private_msg"
+ assert params == {"message": [{"type": "text", "data": {"text": "hello"}}], "user_id": "30001"}
diff --git a/pytests/test_napcat_adapter_config.py b/pytests/test_napcat_adapter_config.py
new file mode 100644
index 00000000..688b1a48
--- /dev/null
+++ b/pytests/test_napcat_adapter_config.py
@@ -0,0 +1,91 @@
+from pathlib import Path
+from typing import List
+
+import importlib
+import sys
+
+
+BUILT_IN_PLUGIN_ROOT = Path(__file__).resolve().parents[1] / "src" / "plugins" / "built_in"
+if str(BUILT_IN_PLUGIN_ROOT) not in sys.path:
+ sys.path.insert(0, str(BUILT_IN_PLUGIN_ROOT))
+
+NapCatPluginSettings = importlib.import_module("napcat_adapter.config").NapCatPluginSettings
+
+
+class DummyLogger:
+ """用于测试的轻量日志对象。"""
+
+ def __init__(self) -> None:
+ """初始化测试日志对象。"""
+ self.warnings: List[str] = []
+ self.errors: List[str] = []
+
+ def warning(self, message: str) -> None:
+ """记录警告日志。
+
+ Args:
+ message: 待记录的日志内容。
+ """
+ self.warnings.append(message)
+
+ def error(self, message: str) -> None:
+ """记录错误日志。
+
+ Args:
+ message: 待记录的日志内容。
+ """
+ self.errors.append(message)
+
+
+def test_parse_new_napcat_server_config() -> None:
+ logger = DummyLogger()
+ settings = NapCatPluginSettings.from_mapping(
+ {
+ "plugin": {"enabled": True, "config_version": "0.1.0"},
+ "napcat_server": {
+ "host": "localhost",
+ "port": 8095,
+ "token": "secret",
+ "heartbeat_interval": 45,
+ "reconnect_delay_sec": 7,
+ "action_timeout_sec": 18,
+ "connection_id": "main",
+ },
+ },
+ logger,
+ )
+
+ assert settings.should_connect() is True
+ assert settings.napcat_server.host == "localhost"
+ assert settings.napcat_server.port == 8095
+ assert settings.napcat_server.token == "secret"
+ assert settings.napcat_server.heartbeat_interval == 45.0
+ assert settings.napcat_server.reconnect_delay_sec == 7.0
+ assert settings.napcat_server.action_timeout_sec == 18.0
+ assert settings.napcat_server.connection_id == "main"
+ assert settings.napcat_server.build_ws_url() == "ws://localhost:8095"
+ assert settings.validate(logger) is True
+
+
+def test_parse_legacy_connection_ws_url_fallback() -> None:
+ logger = DummyLogger()
+ settings = NapCatPluginSettings.from_mapping(
+ {
+ "plugin": {"enabled": True, "config_version": "0.1.0"},
+ "connection": {
+ "ws_url": "ws://127.0.0.1:3001",
+ "access_token": "legacy-token",
+ "heartbeat_sec": 35,
+ "action_timeout_sec": 12,
+ },
+ },
+ logger,
+ )
+
+ assert settings.napcat_server.host == "127.0.0.1"
+ assert settings.napcat_server.port == 3001
+ assert settings.napcat_server.token == "legacy-token"
+ assert settings.napcat_server.heartbeat_interval == 35.0
+ assert settings.napcat_server.action_timeout_sec == 12.0
+ assert settings.validate(logger) is True
+ assert logger.warnings
diff --git a/pytests/test_platform_io_dedupe.py b/pytests/test_platform_io_dedupe.py
new file mode 100644
index 00000000..4a3cbb44
--- /dev/null
+++ b/pytests/test_platform_io_dedupe.py
@@ -0,0 +1,164 @@
+"""Platform IO 入站去重策略测试。"""
+
+from types import SimpleNamespace
+from typing import Any, Dict, List, Optional
+
+import pytest
+
+from src.platform_io.drivers.base import PlatformIODriver
+from src.platform_io.manager import PlatformIOManager
+from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey
+
+
+def _build_envelope(
+ *,
+ dedupe_key: str | None = None,
+ external_message_id: str | None = None,
+ session_message_id: str | None = None,
+ payload: Optional[Dict[str, Any]] = None,
+) -> InboundMessageEnvelope:
+ """构造测试用入站信封。
+
+ Args:
+ dedupe_key: 显式去重键。
+ external_message_id: 平台侧消息 ID。
+ session_message_id: 规范化消息对象上的消息 ID。
+ payload: 原始载荷。
+
+ Returns:
+ InboundMessageEnvelope: 测试用入站消息信封。
+ """
+ session_message = None
+ if session_message_id is not None:
+ session_message = SimpleNamespace(message_id=session_message_id)
+
+ return InboundMessageEnvelope(
+ route_key=RouteKey(platform="qq", account_id="10001", scope="main"),
+ driver_id="plugin.napcat",
+ driver_kind=DriverKind.PLUGIN,
+ dedupe_key=dedupe_key,
+ external_message_id=external_message_id,
+ session_message=session_message,
+ payload=payload,
+ )
+
+
+class _StubPlatformIODriver(PlatformIODriver):
+ """测试用 Platform IO 驱动。"""
+
+ async def send_message(
+ self,
+ message: Any,
+ route_key: RouteKey,
+ metadata: Optional[Dict[str, Any]] = None,
+ ) -> DeliveryReceipt:
+ """返回一个固定的成功回执。
+
+ Args:
+ message: 待发送的消息对象。
+ route_key: 本次发送使用的路由键。
+ metadata: 额外发送元数据。
+
+ Returns:
+ DeliveryReceipt: 固定的成功回执。
+ """
+ return DeliveryReceipt(
+ internal_message_id=str(getattr(message, "message_id", "stub-message-id")),
+ route_key=route_key,
+ status=DeliveryStatus.SENT,
+ driver_id=self.driver_id,
+ driver_kind=self.descriptor.kind,
+ )
+
+
+def _build_manager() -> PlatformIOManager:
+ """构造带有最小 active owner 的 Broker 管理器。
+
+ Returns:
+ PlatformIOManager: 已注册测试驱动并绑定活动路由的 Broker。
+ """
+ manager = PlatformIOManager()
+ driver = _StubPlatformIODriver(
+ DriverDescriptor(
+ driver_id="plugin.napcat",
+ kind=DriverKind.PLUGIN,
+ platform="qq",
+ account_id="10001",
+ scope="main",
+ )
+ )
+ manager.register_driver(driver)
+ manager.bind_route(
+ RouteBinding(
+ route_key=RouteKey(platform="qq", account_id="10001", scope="main"),
+ driver_id=driver.driver_id,
+ driver_kind=driver.descriptor.kind,
+ )
+ )
+ return manager
+
+
+class TestPlatformIODedupe:
+ """Platform IO 去重测试。"""
+
+ @pytest.mark.asyncio
+ async def test_accept_inbound_dedupes_by_external_message_id(self) -> None:
+ """相同平台消息 ID 的重复入站应被抑制。"""
+ manager = _build_manager()
+ accepted_envelopes: List[InboundMessageEnvelope] = []
+
+ async def dispatcher(envelope: InboundMessageEnvelope) -> None:
+ """记录被成功接收的入站消息。
+
+ Args:
+ envelope: 被 Broker 接受的入站消息。
+ """
+ accepted_envelopes.append(envelope)
+
+ manager.set_inbound_dispatcher(dispatcher)
+
+ first_envelope = _build_envelope(
+ external_message_id="msg-1",
+ payload={"message": "hello"},
+ )
+ second_envelope = _build_envelope(
+ external_message_id="msg-1",
+ payload={"message": "hello"},
+ )
+
+ assert await manager.accept_inbound(first_envelope) is True
+ assert await manager.accept_inbound(second_envelope) is False
+ assert len(accepted_envelopes) == 1
+
+ @pytest.mark.asyncio
+ async def test_accept_inbound_without_stable_identity_does_not_guess_duplicate(self) -> None:
+ """缺少稳定身份时,不应仅凭 payload 内容猜测重复消息。"""
+ manager = _build_manager()
+ accepted_envelopes: List[InboundMessageEnvelope] = []
+
+ async def dispatcher(envelope: InboundMessageEnvelope) -> None:
+ """记录被成功接收的入站消息。
+
+ Args:
+ envelope: 被 Broker 接受的入站消息。
+ """
+ accepted_envelopes.append(envelope)
+
+ manager.set_inbound_dispatcher(dispatcher)
+
+ first_envelope = _build_envelope(payload={"message": "same-payload"})
+ second_envelope = _build_envelope(payload={"message": "same-payload"})
+
+ assert await manager.accept_inbound(first_envelope) is True
+ assert await manager.accept_inbound(second_envelope) is True
+ assert len(accepted_envelopes) == 2
+
+ def test_build_inbound_dedupe_key_prefers_explicit_identity(self) -> None:
+ """去重键应只来自显式或稳定的技术身份。"""
+ explicit_envelope = _build_envelope(dedupe_key="dedupe-1", external_message_id="msg-1")
+ session_message_envelope = _build_envelope(session_message_id="session-1")
+ payload_only_envelope = _build_envelope(payload={"message": "hello"})
+
+ assert PlatformIOManager._build_inbound_dedupe_key(explicit_envelope) == "qq:10001:main:dedupe-1"
+ assert PlatformIOManager._build_inbound_dedupe_key(session_message_envelope) == "qq:10001:main:session-1"
+ assert PlatformIOManager._build_inbound_dedupe_key(payload_only_envelope) is None
diff --git a/pytests/test_plugin_message_utils_runtime.py b/pytests/test_plugin_message_utils_runtime.py
new file mode 100644
index 00000000..cb4b5341
--- /dev/null
+++ b/pytests/test_plugin_message_utils_runtime.py
@@ -0,0 +1,87 @@
+from datetime import datetime
+from pathlib import Path
+
+import sys
+
+from src.chat.message_receive.message import SessionMessage
+from src.common.data_models.mai_message_data_model import GroupInfo, MessageInfo, UserInfo
+from src.common.data_models.message_component_data_model import (
+ ForwardComponent,
+ ForwardNodeComponent,
+ ImageComponent,
+ MessageSequence,
+ ReplyComponent,
+ TextComponent,
+ VoiceComponent,
+)
+from src.plugin_runtime.host.message_utils import PluginMessageUtils
+
+
+PROJECT_ROOT = Path(__file__).resolve().parents[1]
+if str(PROJECT_ROOT) not in sys.path:
+ sys.path.insert(0, str(PROJECT_ROOT))
+
+
+def test_plugin_message_utils_preserves_binary_components_and_reply_metadata() -> None:
+ message = SessionMessage(message_id="msg-1", timestamp=datetime.now(), platform="qq")
+ message.message_info = MessageInfo(
+ user_info=UserInfo(user_id="10001", user_nickname="tester"),
+ group_info=GroupInfo(group_id="20001", group_name="group"),
+ additional_config={"self_id": "999"},
+ )
+ message.session_id = "qq:20001:10001"
+ message.processed_plain_text = "binary payload"
+ message.display_message = "binary payload"
+ message.raw_message = MessageSequence(
+ components=[
+ TextComponent("hello"),
+ ImageComponent(binary_hash="", binary_data=b"image-bytes", content=""),
+ VoiceComponent(binary_hash="", binary_data=b"voice-bytes", content=""),
+ ReplyComponent(
+ target_message_id="origin-1",
+ target_message_content="origin text",
+ target_message_sender_id="42",
+ target_message_sender_nickname="alice",
+ target_message_sender_cardname="Alice",
+ ),
+ ForwardNodeComponent(
+ forward_components=[
+ ForwardComponent(
+ user_nickname="bob",
+ user_id="43",
+ user_cardname="Bob",
+ message_id="forward-1",
+ content=[
+ TextComponent("node-text"),
+ ImageComponent(binary_hash="", binary_data=b"node-image", content=""),
+ ],
+ )
+ ]
+ ),
+ ]
+ )
+
+ message_dict = PluginMessageUtils._session_message_to_dict(message)
+ rebuilt_message = PluginMessageUtils._build_session_message_from_dict(dict(message_dict))
+
+ image_component = rebuilt_message.raw_message.components[1]
+ voice_component = rebuilt_message.raw_message.components[2]
+ reply_component = rebuilt_message.raw_message.components[3]
+ forward_component = rebuilt_message.raw_message.components[4]
+
+ assert isinstance(image_component, ImageComponent)
+ assert image_component.binary_data == b"image-bytes"
+
+ assert isinstance(voice_component, VoiceComponent)
+ assert voice_component.binary_data == b"voice-bytes"
+
+ assert isinstance(reply_component, ReplyComponent)
+ assert reply_component.target_message_id == "origin-1"
+ assert reply_component.target_message_content == "origin text"
+ assert reply_component.target_message_sender_id == "42"
+ assert reply_component.target_message_sender_nickname == "alice"
+ assert reply_component.target_message_sender_cardname == "Alice"
+
+ assert isinstance(forward_component, ForwardNodeComponent)
+ assert isinstance(forward_component.forward_components[0].content[1], ImageComponent)
+ assert forward_component.forward_components[0].content[1].binary_data == b"node-image"
diff --git a/src/platform_io/manager.py b/src/platform_io/manager.py
index 97835667..b1fe3bdc 100644
--- a/src/platform_io/manager.py
+++ b/src/platform_io/manager.py
@@ -2,9 +2,6 @@
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional
-import hashlib
-import json
-
from src.common.logger import get_logger
from src.platform_io.drivers.base import PlatformIODriver
@@ -438,12 +435,17 @@ class PlatformIOManager:
Returns:
Optional[str]: 若可以构造稳定去重键则返回该键,否则返回 ``None``。
+
+ Notes:
+ 这里仅接受上游显式提供的稳定消息身份,例如 ``dedupe_key``、
+ 平台侧 ``external_message_id`` 或已经完成规范化的
+ ``session_message.message_id``。Broker 不再根据 ``payload`` 内容
+ 猜测语义去重键,避免把“短时间内两条内容刚好完全相同”的合法消息
+ 误判为重复入站。
"""
raw_dedupe_key = envelope.dedupe_key or envelope.external_message_id
if raw_dedupe_key is None and envelope.session_message is not None:
raw_dedupe_key = envelope.session_message.message_id
- if raw_dedupe_key is None and envelope.payload is not None:
- raw_dedupe_key = PlatformIOManager._build_payload_fingerprint(envelope.payload)
if raw_dedupe_key is None:
return None
@@ -453,29 +455,6 @@ class PlatformIOManager:
return f"{envelope.route_key.to_dedupe_scope()}:{normalized_dedupe_key}"
- @staticmethod
- def _build_payload_fingerprint(payload: Dict[str, Any]) -> Optional[str]:
- """根据消息载荷构造稳定指纹。
-
- Args:
- payload: 待构造指纹的原始载荷字典。
-
- Returns:
- Optional[str]: 若成功生成指纹则返回十六进制摘要,否则返回 ``None``。
- """
- try:
- serialized_payload = json.dumps(
- payload,
- default=str,
- ensure_ascii=True,
- separators=(",", ":"),
- sort_keys=True,
- )
- except Exception:
- return None
-
- return hashlib.sha256(serialized_payload.encode()).hexdigest()
-
@staticmethod
def _validate_binding_against_driver(binding: RouteBinding, driver: PlatformIODriver) -> None:
"""校验路由绑定与驱动描述是否一致。
diff --git a/src/platform_io/types.py b/src/platform_io/types.py
index c74dc246..8729b637 100644
--- a/src/platform_io/types.py
+++ b/src/platform_io/types.py
@@ -198,8 +198,9 @@ class InboundMessageEnvelope:
driver_kind: 产出该消息的驱动类型。
external_message_id: 可选的平台侧消息 ID,用于去重。
dedupe_key: 可选的显式去重键。当外部消息没有稳定 ``message_id`` 时,
- 可由上游驱动提供消息指纹。若这里为空,中间层仍可能继续回退到
- ``session_message.message_id`` 或 ``payload`` 指纹。
+ 可由上游驱动提供稳定的技术性幂等键。若这里为空,中间层仅会继续
+ 回退到 ``external_message_id`` 或 ``session_message.message_id``,
+ 不会再根据 ``payload`` 内容猜测语义去重键。
session_message: 可选的、已经完成规范化的 ``SessionMessage`` 对象。
payload: 可选的原始字典载荷,供延迟转换或调试使用。
metadata: 额外入站元数据,例如连接信息或追踪上下文。
diff --git a/src/plugin_runtime/host/message_utils.py b/src/plugin_runtime/host/message_utils.py
index aaebb529..2f6aa01b 100644
--- a/src/plugin_runtime/host/message_utils.py
+++ b/src/plugin_runtime/host/message_utils.py
@@ -1,10 +1,25 @@
from datetime import datetime
-from typing import Dict, Any, TypedDict, Optional, List
+from typing import Any, Dict, List, Optional, TypedDict
+
+import base64
+import hashlib
from src.common.logger import get_logger
from src.chat.message_receive.message import SessionMessage
from src.common.data_models.mai_message_data_model import UserInfo, GroupInfo, MessageInfo
-from src.common.data_models.message_component_data_model import MessageSequence
+from src.common.data_models.message_component_data_model import (
+ AtComponent,
+ DictComponent,
+ EmojiComponent,
+ ForwardComponent,
+ ForwardNodeComponent,
+ ImageComponent,
+ MessageSequence,
+ ReplyComponent,
+ StandardMessageComponents,
+ TextComponent,
+ VoiceComponent,
+)
logger = get_logger("plugin_runtime.host.message_utils")
@@ -45,6 +60,251 @@ class MessageDict(TypedDict, total=False):
class PluginMessageUtils:
+ @staticmethod
+ def _message_sequence_to_dict(message_sequence: MessageSequence) -> List[Dict[str, Any]]:
+ """将消息组件序列转换为插件运行时使用的字典结构。
+
+ Args:
+ message_sequence: 待转换的消息组件序列。
+
+ Returns:
+ List[Dict[str, Any]]: 供插件运行时协议使用的消息段字典列表。
+ """
+ return [PluginMessageUtils._component_to_dict(component) for component in message_sequence.components]
+
+ @staticmethod
+ def _component_to_dict(component: StandardMessageComponents) -> Dict[str, Any]:
+ """将单个消息组件转换为插件运行时字典结构。
+
+ Args:
+ component: 待转换的消息组件。
+
+ Returns:
+ Dict[str, Any]: 序列化后的消息组件字典。
+ """
+ if isinstance(component, TextComponent):
+ return {"type": "text", "data": component.text}
+
+ if isinstance(component, ImageComponent):
+ serialized = {
+ "type": "image",
+ "data": component.content,
+ "hash": component.binary_hash,
+ }
+ if component.binary_data:
+ serialized["binary_data_base64"] = base64.b64encode(component.binary_data).decode("utf-8")
+ return serialized
+
+ if isinstance(component, EmojiComponent):
+ serialized = {
+ "type": "emoji",
+ "data": component.content,
+ "hash": component.binary_hash,
+ }
+ if component.binary_data:
+ serialized["binary_data_base64"] = base64.b64encode(component.binary_data).decode("utf-8")
+ return serialized
+
+ if isinstance(component, VoiceComponent):
+ serialized = {
+ "type": "voice",
+ "data": component.content,
+ "hash": component.binary_hash,
+ }
+ if component.binary_data:
+ serialized["binary_data_base64"] = base64.b64encode(component.binary_data).decode("utf-8")
+ return serialized
+
+ if isinstance(component, AtComponent):
+ return {
+ "type": "at",
+ "data": {
+ "target_user_id": component.target_user_id,
+ "target_user_nickname": component.target_user_nickname,
+ "target_user_cardname": component.target_user_cardname,
+ },
+ }
+
+ if isinstance(component, ReplyComponent):
+ return {
+ "type": "reply",
+ "data": {
+ "target_message_id": component.target_message_id,
+ "target_message_content": component.target_message_content,
+ "target_message_sender_id": component.target_message_sender_id,
+ "target_message_sender_nickname": component.target_message_sender_nickname,
+ "target_message_sender_cardname": component.target_message_sender_cardname,
+ },
+ }
+
+ if isinstance(component, ForwardNodeComponent):
+ return {
+ "type": "forward",
+ "data": [PluginMessageUtils._forward_component_to_dict(item) for item in component.forward_components],
+ }
+
+ return {"type": "dict", "data": component.data}
+
+ @staticmethod
+ def _forward_component_to_dict(component: ForwardComponent) -> Dict[str, Any]:
+ """将单个转发节点组件转换为字典结构。
+
+ Args:
+ component: 待转换的转发节点组件。
+
+ Returns:
+ Dict[str, Any]: 序列化后的转发节点字典。
+ """
+ return {
+ "user_id": component.user_id,
+ "user_nickname": component.user_nickname,
+ "user_cardname": component.user_cardname,
+ "message_id": component.message_id,
+ "content": [PluginMessageUtils._component_to_dict(item) for item in component.content],
+ }
+
+ @staticmethod
+ def _message_sequence_from_dict(raw_message_data: List[Dict[str, Any]]) -> MessageSequence:
+ """从插件运行时字典结构恢复消息组件序列。
+
+ Args:
+ raw_message_data: 插件运行时消息段字典列表。
+
+ Returns:
+ MessageSequence: 恢复后的消息组件序列。
+ """
+ components = [PluginMessageUtils._component_from_dict(item) for item in raw_message_data]
+ return MessageSequence(components=components)
+
+ @staticmethod
+ def _component_from_dict(item: Dict[str, Any]) -> StandardMessageComponents:
+ """从插件运行时字典结构恢复单个消息组件。
+
+ Args:
+ item: 单个消息组件的字典表示。
+
+ Returns:
+ StandardMessageComponents: 恢复后的内部消息组件对象。
+ """
+ item_type = str(item.get("type") or "").strip()
+ if item_type == "text":
+ return TextComponent(text=str(item.get("data") or ""))
+
+ if item_type == "image":
+ return PluginMessageUtils._build_binary_component(ImageComponent, item)
+
+ if item_type == "emoji":
+ return PluginMessageUtils._build_binary_component(EmojiComponent, item)
+
+ if item_type == "voice":
+ return PluginMessageUtils._build_binary_component(VoiceComponent, item)
+
+ if item_type == "at":
+ item_data = item.get("data", {})
+ if not isinstance(item_data, dict):
+ item_data = {}
+ return AtComponent(
+ target_user_id=str(item_data.get("target_user_id") or ""),
+ target_user_nickname=PluginMessageUtils._normalize_optional_string(item_data.get("target_user_nickname")),
+ target_user_cardname=PluginMessageUtils._normalize_optional_string(item_data.get("target_user_cardname")),
+ )
+
+ if item_type == "reply":
+ reply_data = item.get("data")
+ if isinstance(reply_data, dict):
+ return ReplyComponent(
+ target_message_id=str(reply_data.get("target_message_id") or ""),
+ target_message_content=PluginMessageUtils._normalize_optional_string(
+ reply_data.get("target_message_content")
+ ),
+ target_message_sender_id=PluginMessageUtils._normalize_optional_string(
+ reply_data.get("target_message_sender_id")
+ ),
+ target_message_sender_nickname=PluginMessageUtils._normalize_optional_string(
+ reply_data.get("target_message_sender_nickname")
+ ),
+ target_message_sender_cardname=PluginMessageUtils._normalize_optional_string(
+ reply_data.get("target_message_sender_cardname")
+ ),
+ )
+ return ReplyComponent(target_message_id=str(reply_data or ""))
+
+ if item_type == "forward":
+ forward_nodes: List[ForwardComponent] = []
+ raw_forward_nodes = item.get("data", [])
+ if isinstance(raw_forward_nodes, list):
+ for node in raw_forward_nodes:
+ if not isinstance(node, dict):
+ continue
+ raw_content = node.get("content", [])
+ node_components: List[StandardMessageComponents] = []
+ if isinstance(raw_content, list):
+ node_components = [
+ PluginMessageUtils._component_from_dict(content)
+ for content in raw_content
+ if isinstance(content, dict)
+ ]
+ if not node_components:
+ node_components = [TextComponent(text="[empty forward node]")]
+ forward_nodes.append(
+ ForwardComponent(
+ user_nickname=str(node.get("user_nickname") or "未知用户"),
+ user_id=PluginMessageUtils._normalize_optional_string(node.get("user_id")),
+ user_cardname=PluginMessageUtils._normalize_optional_string(node.get("user_cardname")),
+ message_id=str(node.get("message_id") or ""),
+ content=node_components,
+ )
+ )
+ if not forward_nodes:
+ return DictComponent(data={"type": "forward", "data": item.get("data", [])})
+ return ForwardNodeComponent(forward_components=forward_nodes)
+
+ component_data = item.get("data")
+ if isinstance(component_data, dict):
+ return DictComponent(data=component_data)
+ return DictComponent(data=item)
+
+ @staticmethod
+ def _build_binary_component(component_cls: Any, item: Dict[str, Any]) -> StandardMessageComponents:
+ """从字典构造带二进制负载的消息组件。
+
+ Args:
+ component_cls: 目标组件类型。
+ item: 消息组件字典。
+
+ Returns:
+ StandardMessageComponents: 构造后的组件对象。
+ """
+ content = str(item.get("data") or "")
+ binary_hash = str(item.get("hash") or "")
+ raw_binary_base64 = item.get("binary_data_base64")
+ binary_data = b""
+ if isinstance(raw_binary_base64, str) and raw_binary_base64:
+ try:
+ binary_data = base64.b64decode(raw_binary_base64)
+ except Exception:
+ binary_data = b""
+
+ if not binary_hash and binary_data:
+ binary_hash = hashlib.sha256(binary_data).hexdigest()
+
+ return component_cls(binary_hash=binary_hash, content=content, binary_data=binary_data)
+
+ @staticmethod
+ def _normalize_optional_string(value: Any) -> Optional[str]:
+ """将任意值规范化为可选字符串。
+
+ Args:
+ value: 待规范化的值。
+
+ Returns:
+ Optional[str]: 规范化后的字符串;若值为空则返回 ``None``。
+ """
+ if value is None:
+ return None
+ normalized_value = str(value)
+ return normalized_value if normalized_value else None
+
@staticmethod
def _message_info_to_dict(message_info: MessageInfo) -> MessageInfoDict:
"""
@@ -92,7 +352,7 @@ class PluginMessageUtils:
timestamp=str(session_message.timestamp.timestamp()), # 转换为时间戳字符串
platform=session_message.platform,
message_info=PluginMessageUtils._message_info_to_dict(session_message.message_info),
- raw_message=session_message.raw_message.to_dict(), # 复用 MessageSequence.to_dict()
+ raw_message=PluginMessageUtils._message_sequence_to_dict(session_message.raw_message),
is_mentioned=session_message.is_mentioned,
is_at=session_message.is_at,
is_emoji=session_message.is_emoji,
@@ -186,7 +446,7 @@ class PluginMessageUtils:
# 构建原始消息组件序列(复用 MessageSequence.from_dict 方法)
raw_message_data = message_dict["raw_message"]
if isinstance(raw_message_data, list):
- session_message.raw_message = MessageSequence.from_dict(raw_message_data)
+ session_message.raw_message = PluginMessageUtils._message_sequence_from_dict(raw_message_data)
else:
raise ValueError("消息字典中 'raw_message' 字段必须是一个列表")
diff --git a/src/plugins/built_in/napcat_adapter/__init__.py b/src/plugins/built_in/napcat_adapter/__init__.py
new file mode 100644
index 00000000..fa82860f
--- /dev/null
+++ b/src/plugins/built_in/napcat_adapter/__init__.py
@@ -0,0 +1 @@
+"""NapCat 内置适配器插件包。"""
diff --git a/src/plugins/built_in/napcat_adapter/codec_inbound.py b/src/plugins/built_in/napcat_adapter/codec_inbound.py
new file mode 100644
index 00000000..b8065585
--- /dev/null
+++ b/src/plugins/built_in/napcat_adapter/codec_inbound.py
@@ -0,0 +1,414 @@
+"""NapCat 入站消息编解码。"""
+
+from typing import Any, Dict, List, Mapping, Optional, Tuple
+from uuid import uuid4
+
+import hashlib
+import json
+import time
+
+from napcat_adapter.qq_queries import NapCatQueryService
+
+
+class NapCatInboundCodec:
+ """NapCat 入站消息编码器。"""
+
+ def __init__(self, logger: Any, query_service: NapCatQueryService) -> None:
+ """初始化入站消息编码器。
+
+ Args:
+ logger: 插件日志对象。
+ query_service: QQ 查询服务。
+ """
+ self._logger = logger
+ self._query_service = query_service
+
+ async def build_message_dict(
+ self,
+ payload: Mapping[str, Any],
+ self_id: str,
+ sender_user_id: str,
+ sender: Mapping[str, Any],
+ ) -> Dict[str, Any]:
+ """构造 Host 侧可接受的 ``MessageDict``。
+
+ Args:
+ payload: NapCat 原始消息事件。
+ self_id: 当前机器人账号 ID。
+ sender_user_id: 发送者用户 ID。
+ sender: 发送者信息字典。
+
+ Returns:
+ Dict[str, Any]: 规范化后的 ``MessageDict``。
+ """
+ message_type = str(payload.get("message_type") or "").strip() or "private"
+ group_id = str(payload.get("group_id") or "").strip()
+ group_name = str(payload.get("group_name") or "").strip() or (f"group_{group_id}" if group_id else "")
+ user_nickname = str(sender.get("nickname") or sender.get("card") or sender_user_id).strip() or sender_user_id
+ user_cardname = str(sender.get("card") or "").strip() or None
+
+ raw_message, is_at = await self.convert_segments(payload, self_id)
+ raw_message_text = str(payload.get("raw_message") or "").strip()
+ if not raw_message:
+ raw_message = [{"type": "text", "data": raw_message_text or "[unsupported]"}]
+
+ plain_text = self.build_plain_text(raw_message, raw_message_text)
+ timestamp_seconds = payload.get("time")
+ if not isinstance(timestamp_seconds, (int, float)):
+ timestamp_seconds = time.time()
+
+ additional_config: Dict[str, Any] = {"self_id": self_id, "napcat_message_type": message_type}
+ if group_id:
+ additional_config["platform_io_target_group_id"] = group_id
+ else:
+ additional_config["platform_io_target_user_id"] = sender_user_id
+
+ message_info: Dict[str, Any] = {
+ "user_info": {
+ "user_id": sender_user_id,
+ "user_nickname": user_nickname,
+ "user_cardname": user_cardname,
+ },
+ "additional_config": additional_config,
+ }
+ if group_id:
+ message_info["group_info"] = {"group_id": group_id, "group_name": group_name}
+
+ message_id = str(payload.get("message_id") or f"napcat-{uuid4().hex}").strip()
+ return {
+ "message_id": message_id,
+ "timestamp": str(float(timestamp_seconds)),
+ "platform": "qq",
+ "message_info": message_info,
+ "raw_message": raw_message,
+ "is_mentioned": is_at,
+ "is_at": is_at,
+ "is_emoji": False,
+ "is_picture": False,
+ "is_command": plain_text.startswith("/"),
+ "is_notify": False,
+ "session_id": "",
+ "processed_plain_text": plain_text,
+ "display_message": plain_text,
+ }
+
+ async def convert_segments(self, payload: Mapping[str, Any], self_id: str) -> Tuple[List[Dict[str, Any]], bool]:
+ """将 OneBot 消息段转换为 Host 消息段结构。
+
+ Args:
+ payload: OneBot 原始消息事件。
+ self_id: 当前机器人账号 ID。
+
+ Returns:
+ Tuple[List[Dict[str, Any]], bool]: 转换后的消息段列表,以及是否 @ 到当前机器人。
+ """
+ message_payload = payload.get("message")
+ if isinstance(message_payload, str):
+ normalized_text = message_payload.strip()
+ return ([{"type": "text", "data": normalized_text}] if normalized_text else []), False
+
+ if not isinstance(message_payload, list):
+ return [], False
+
+ converted_segments: List[Dict[str, Any]] = []
+ is_at = False
+ for segment in message_payload:
+ if not isinstance(segment, Mapping):
+ continue
+
+ segment_type = str(segment.get("type") or "").strip()
+ segment_data = segment.get("data", {})
+ if not isinstance(segment_data, Mapping):
+ segment_data = {}
+
+ if segment_type == "text":
+ if text_value := str(segment_data.get("text") or ""):
+ converted_segments.append({"type": "text", "data": text_value})
+ continue
+
+ if segment_type == "at":
+ if target_user_id := str(segment_data.get("qq") or "").strip():
+ converted_segments.append(
+ {
+ "type": "at",
+ "data": {
+ "target_user_id": target_user_id,
+ "target_user_nickname": None,
+ "target_user_cardname": None,
+ },
+ }
+ )
+ if self_id and target_user_id == self_id:
+ is_at = True
+ continue
+
+ if segment_type == "reply":
+ if reply_segment := await self._build_reply_segment(segment_data):
+ converted_segments.append(reply_segment)
+ continue
+
+ if segment_type == "face":
+ converted_segments.append({"type": "text", "data": "[face]"})
+ continue
+
+ if segment_type == "image":
+ converted_segments.append(await self._build_image_like_segment(segment_data, is_emoji=False))
+ continue
+
+ if segment_type == "record":
+ converted_segments.append(await self._build_record_segment(segment_data))
+ continue
+
+ if segment_type == "video":
+ converted_segments.append({"type": "text", "data": "[video]"})
+ continue
+
+ if segment_type == "file":
+ converted_segments.append({"type": "text", "data": "[file]"})
+ continue
+
+ if segment_type == "json":
+ converted_segments.append(self._build_json_text_segment(segment_data))
+ continue
+
+ if segment_type == "forward":
+ if forward_segment := await self._build_forward_segment(segment_data):
+ converted_segments.append(forward_segment)
+ continue
+
+ if segment_type in {"xml", "share"}:
+ converted_segments.append({"type": "text", "data": f"[{segment_type}]"})
+
+ return converted_segments, is_at
+
+ async def _build_reply_segment(self, segment_data: Mapping[str, Any]) -> Optional[Dict[str, Any]]:
+ """构造回复消息段。
+
+ Args:
+ segment_data: OneBot ``reply`` 段的 ``data`` 字典。
+
+ Returns:
+ Optional[Dict[str, Any]]: 转换后的回复消息段;缺少消息 ID 时返回 ``None``。
+ """
+ target_message_id = str(segment_data.get("id") or "").strip()
+ if not target_message_id:
+ return None
+
+ message_detail = await self._query_service.get_message_detail(target_message_id)
+ reply_payload: Dict[str, Any] = {"target_message_id": target_message_id}
+ if message_detail is not None:
+ sender = message_detail.get("sender", {})
+ if not isinstance(sender, Mapping):
+ sender = {}
+ reply_payload["target_message_content"] = str(message_detail.get("raw_message") or "").strip() or None
+ reply_payload["target_message_sender_id"] = str(
+ message_detail.get("user_id") or sender.get("user_id") or ""
+ ).strip() or None
+ reply_payload["target_message_sender_nickname"] = str(sender.get("nickname") or "").strip() or None
+ reply_payload["target_message_sender_cardname"] = str(sender.get("card") or "").strip() or None
+
+ return {"type": "reply", "data": reply_payload}
+
+ async def _build_image_like_segment(
+ self,
+ segment_data: Mapping[str, Any],
+ is_emoji: bool,
+ ) -> Dict[str, Any]:
+ """构造图片或表情消息段。
+
+ Args:
+ segment_data: OneBot ``image`` 段的 ``data`` 字典。
+ is_emoji: 是否按表情组件处理。
+
+ Returns:
+ Dict[str, Any]: 转换后的图片或表情消息段。
+ """
+ subtype = segment_data.get("sub_type")
+ actual_is_emoji = is_emoji or (isinstance(subtype, int) and subtype not in {0, 4, 9})
+
+ image_url = str(segment_data.get("url") or "").strip()
+ binary_data = await self._query_service.download_binary(image_url)
+ if not binary_data:
+ return {"type": "text", "data": "[emoji]" if actual_is_emoji else "[image]"}
+
+ return {
+ "type": "emoji" if actual_is_emoji else "image",
+ "data": "",
+ "hash": hashlib.sha256(binary_data).hexdigest(),
+ "binary_data_base64": self._encode_binary(binary_data),
+ }
+
+ async def _build_record_segment(self, segment_data: Mapping[str, Any]) -> Dict[str, Any]:
+ """构造语音消息段。
+
+ Args:
+ segment_data: OneBot ``record`` 段的 ``data`` 字典。
+
+ Returns:
+ Dict[str, Any]: 转换后的语音或占位文本消息段。
+ """
+ file_name = str(segment_data.get("file") or "").strip()
+ file_id = str(segment_data.get("file_id") or "").strip() or None
+ if not file_name:
+ return {"type": "text", "data": "[voice]"}
+
+ record_detail = await self._query_service.get_record_detail(file_name=file_name, file_id=file_id)
+ if record_detail is None:
+ return {"type": "text", "data": "[voice]"}
+
+ record_base64 = str(record_detail.get("base64") or "").strip()
+ if not record_base64:
+ return {"type": "text", "data": "[voice]"}
+
+ try:
+ binary_data = self._decode_binary(record_base64)
+ except Exception:
+ return {"type": "text", "data": "[voice]"}
+
+ return {
+ "type": "voice",
+ "data": "",
+ "hash": hashlib.sha256(binary_data).hexdigest(),
+ "binary_data_base64": self._encode_binary(binary_data),
+ }
+
+ async def _build_forward_segment(self, segment_data: Mapping[str, Any]) -> Optional[Dict[str, Any]]:
+ """构造合并转发消息段。
+
+ Args:
+ segment_data: OneBot ``forward`` 段的 ``data`` 字典。
+
+ Returns:
+ Optional[Dict[str, Any]]: 转换后的合并转发消息段;失败时返回 ``None``。
+ """
+ message_id = str(segment_data.get("id") or "").strip()
+ if not message_id:
+ return None
+
+ forward_detail = await self._query_service.get_forward_message(message_id)
+ if forward_detail is None:
+ return {"type": "text", "data": "[forward]"}
+
+ messages = forward_detail.get("messages", [])
+ if not isinstance(messages, list):
+ return {"type": "text", "data": "[forward]"}
+
+ forward_nodes: List[Dict[str, Any]] = []
+ for forward_message in messages:
+ if not isinstance(forward_message, Mapping):
+ continue
+ raw_content = forward_message.get("content", [])
+ content_segments = await self._convert_forward_content(raw_content, "")
+ sender = forward_message.get("sender", {})
+ if not isinstance(sender, Mapping):
+ sender = {}
+ forward_nodes.append(
+ {
+ "user_id": str(sender.get("user_id") or sender.get("uin") or "").strip() or None,
+ "user_nickname": str(sender.get("nickname") or sender.get("name") or "未知用户"),
+ "user_cardname": str(sender.get("card") or "").strip() or None,
+ "message_id": str(forward_message.get("message_id") or uuid4().hex),
+ "content": content_segments or [{"type": "text", "data": "[empty]"}],
+ }
+ )
+
+ if not forward_nodes:
+ return {"type": "text", "data": "[forward]"}
+ return {"type": "forward", "data": forward_nodes}
+
+ async def _convert_forward_content(self, raw_content: Any, self_id: str) -> List[Dict[str, Any]]:
+ """转换转发节点内部的消息段列表。
+
+ Args:
+ raw_content: 转发节点原始内容。
+ self_id: 当前机器人账号 ID。
+
+ Returns:
+ List[Dict[str, Any]]: 转换后的消息段列表。
+ """
+ pseudo_payload: Dict[str, Any] = {"message": raw_content}
+ segments, _ = await self.convert_segments(pseudo_payload, self_id)
+ return segments
+
+ def _build_json_text_segment(self, segment_data: Mapping[str, Any]) -> Dict[str, Any]:
+ """将 JSON 卡片最佳努力转换为文本占位。
+
+ Args:
+ segment_data: OneBot ``json`` 段的 ``data`` 字典。
+
+ Returns:
+ Dict[str, Any]: 转换后的文本消息段。
+ """
+ json_data = str(segment_data.get("data") or "").strip()
+ if not json_data:
+ return {"type": "text", "data": "[json]"}
+
+ try:
+ parsed_json = json.loads(json_data)
+ except Exception:
+ return {"type": "text", "data": "[json]"}
+
+ app_name = str(parsed_json.get("app") or "").strip()
+ prompt = ""
+ if isinstance(parsed_json.get("meta"), Mapping):
+ prompt = str(parsed_json["meta"].get("prompt") or "").strip()
+ text = prompt or app_name or "json"
+ return {"type": "text", "data": f"[json:{text}]"}
+
+ @staticmethod
+ def _encode_binary(binary_data: bytes) -> str:
+ """将二进制内容编码为 Base64 字符串。
+
+ Args:
+ binary_data: 待编码的二进制内容。
+
+ Returns:
+ str: Base64 编码字符串。
+ """
+ import base64
+
+ return base64.b64encode(binary_data).decode("utf-8")
+
+ @staticmethod
+ def _decode_binary(binary_base64: str) -> bytes:
+ """将 Base64 字符串解码为二进制内容。
+
+ Args:
+ binary_base64: Base64 字符串。
+
+ Returns:
+ bytes: 解码后的二进制内容。
+ """
+ import base64
+
+ return base64.b64decode(binary_base64)
+
+ def build_plain_text(self, raw_message: List[Dict[str, Any]], fallback_text: str) -> str:
+ """从标准消息段中提取可展示的纯文本。
+
+ Args:
+ raw_message: 标准化后的消息段列表。
+ fallback_text: 当无法拼出文本时使用的回退文本。
+
+ Returns:
+ str: 用于 Host 展示和命令判断的纯文本内容。
+ """
+ plain_text_parts: List[str] = []
+ for item in raw_message:
+ if not isinstance(item, Mapping):
+ continue
+ item_type = str(item.get("type") or "").strip()
+ item_data = item.get("data")
+ if item_type == "text":
+ plain_text_parts.append(str(item_data or ""))
+ elif item_type == "at" and isinstance(item_data, Mapping):
+ plain_text_parts.append(f"@{item_data.get('target_user_id') or ''}")
+ elif item_type == "reply":
+ plain_text_parts.append("[reply]")
+ elif item_type == "forward":
+ plain_text_parts.append("[forward]")
+ elif item_type in {"image", "emoji", "voice"}:
+ plain_text_parts.append(f"[{item_type}]")
+
+ plain_text = "".join(part for part in plain_text_parts if part).strip()
+ return plain_text or fallback_text or "[unsupported]"
diff --git a/src/plugins/built_in/napcat_adapter/codec_outbound.py b/src/plugins/built_in/napcat_adapter/codec_outbound.py
new file mode 100644
index 00000000..6adcb622
--- /dev/null
+++ b/src/plugins/built_in/napcat_adapter/codec_outbound.py
@@ -0,0 +1,192 @@
+"""NapCat 出站消息编解码。"""
+
+from typing import Any, Dict, List, Mapping, Tuple
+
+
+class NapCatOutboundCodec:
+ """NapCat 出站消息编码器。"""
+
+ def build_outbound_action(
+ self,
+ message: Mapping[str, Any],
+ route: Mapping[str, Any],
+ ) -> Tuple[str, Dict[str, Any]]:
+ """为 Host 出站消息构造 OneBot 动作。
+
+ Args:
+ message: Host 侧标准 ``MessageDict``。
+ route: Platform IO 路由信息。
+
+ Returns:
+ Tuple[str, Dict[str, Any]]: 动作名称与参数字典。
+
+ Raises:
+ ValueError: 当私聊出站缺少目标用户 ID 时抛出。
+ """
+ message_info = message.get("message_info", {})
+ if not isinstance(message_info, Mapping):
+ message_info = {}
+
+ group_info = message_info.get("group_info", {})
+ if not isinstance(group_info, Mapping):
+ group_info = {}
+
+ additional_config = message_info.get("additional_config", {})
+ if not isinstance(additional_config, Mapping):
+ additional_config = {}
+
+ raw_message = message.get("raw_message", [])
+ segments = self.convert_segments(raw_message)
+
+ if target_group_id := str(
+ group_info.get("group_id") or additional_config.get("platform_io_target_group_id") or ""
+ ).strip():
+ return "send_group_msg", {"group_id": target_group_id, "message": segments}
+
+ target_user_id = str(
+ additional_config.get("platform_io_target_user_id")
+ or additional_config.get("target_user_id")
+ or route.get("target_user_id")
+ or ""
+ ).strip()
+ if not target_user_id:
+ raise ValueError("Outbound private message is missing target_user_id")
+
+ return "send_private_msg", {"message": segments, "user_id": target_user_id}
+
+ def convert_segments(self, raw_message: Any) -> List[Dict[str, Any]]:
+ """将 Host 消息段转换为 OneBot 消息段。
+
+ Args:
+ raw_message: Host 侧 ``raw_message`` 字段。
+
+ Returns:
+ List[Dict[str, Any]]: OneBot 消息段列表。
+ """
+ if not isinstance(raw_message, list):
+ return [{"type": "text", "data": {"text": ""}}]
+
+ outbound_segments: List[Dict[str, Any]] = []
+ for item in raw_message:
+ if not isinstance(item, Mapping):
+ continue
+
+ item_type = str(item.get("type") or "").strip()
+ item_data = item.get("data")
+
+ if item_type == "text":
+ text_value = str(item_data or "")
+ outbound_segments.append({"type": "text", "data": {"text": text_value}})
+ continue
+
+ if item_type == "at" and isinstance(item_data, Mapping):
+ if target_user_id := str(item_data.get("target_user_id") or "").strip():
+ outbound_segments.append({"type": "at", "data": {"qq": target_user_id}})
+ continue
+
+ if item_type == "reply":
+ if isinstance(item_data, Mapping):
+ target_message_id = str(item_data.get("target_message_id") or "").strip()
+ else:
+ target_message_id = str(item_data or "").strip()
+ if target_message_id:
+ outbound_segments.append({"type": "reply", "data": {"id": target_message_id}})
+ continue
+
+ if item_type == "image":
+ binary_base64 = str(item.get("binary_data_base64") or "").strip()
+ if binary_base64:
+ outbound_segments.append(
+ {
+ "type": "image",
+ "data": {"file": f"base64://{binary_base64}", "subtype": 0},
+ }
+ )
+ else:
+ outbound_segments.append({"type": "text", "data": {"text": "[image]"}})
+ continue
+
+ if item_type == "emoji":
+ binary_base64 = str(item.get("binary_data_base64") or "").strip()
+ if binary_base64:
+ outbound_segments.append(
+ {
+ "type": "image",
+ "data": {
+ "file": f"base64://{binary_base64}",
+ "subtype": 1,
+ "summary": "[动画表情]",
+ },
+ }
+ )
+ else:
+ outbound_segments.append({"type": "text", "data": {"text": "[emoji]"}})
+ continue
+
+ if item_type == "voice":
+ binary_base64 = str(item.get("binary_data_base64") or "").strip()
+ if binary_base64:
+ outbound_segments.append({"type": "record", "data": {"file": f"base64://{binary_base64}"}})
+ else:
+ outbound_segments.append({"type": "text", "data": {"text": "[voice]"}})
+ continue
+
+ if item_type == "forward" and isinstance(item_data, list):
+ outbound_segments.extend(self._build_forward_nodes(item_data))
+ continue
+
+ if item_type == "dict" and isinstance(item_data, Mapping):
+ if dict_segment := self._build_dict_component_segment(item_data):
+ outbound_segments.append(dict_segment)
+ continue
+
+ fallback_text = f"[unsupported:{item_type or 'unknown'}]"
+ outbound_segments.append({"type": "text", "data": {"text": fallback_text}})
+
+ if not outbound_segments:
+ outbound_segments.append({"type": "text", "data": {"text": ""}})
+ return outbound_segments
+
+ def _build_forward_nodes(self, forward_nodes: List[Any]) -> List[Dict[str, Any]]:
+ """构造 NapCat 转发节点列表。
+
+ Args:
+ forward_nodes: 内部转发节点列表。
+
+ Returns:
+ List[Dict[str, Any]]: NapCat 转发节点列表。
+ """
+ built_nodes: List[Dict[str, Any]] = []
+ for node in forward_nodes:
+ if not isinstance(node, Mapping):
+ continue
+ raw_content = node.get("content", [])
+ node_segments = self.convert_segments(raw_content)
+ built_nodes.append(
+ {
+ "type": "node",
+ "data": {
+ "name": str(node.get("user_nickname") or node.get("user_cardname") or "QQ用户"),
+ "uin": str(node.get("user_id") or ""),
+ "content": node_segments,
+ },
+ }
+ )
+ return built_nodes
+
+ def _build_dict_component_segment(self, item_data: Mapping[str, Any]) -> Dict[str, Any]:
+ """尽力将 ``DictComponent`` 转换为 NapCat 消息段。
+
+ Args:
+ item_data: ``DictComponent`` 原始数据。
+
+ Returns:
+ Dict[str, Any]: NapCat 消息段;不支持时返回占位文本段。
+ """
+ raw_type = str(item_data.get("type") or "").strip()
+ raw_payload = item_data.get("data", item_data)
+ if raw_type in {"file", "music", "video", "face"} and isinstance(raw_payload, Mapping):
+ return {"type": raw_type, "data": dict(raw_payload)}
+ if raw_type in {"image", "record", "reply", "at"} and isinstance(raw_payload, Mapping):
+ return {"type": raw_type, "data": dict(raw_payload)}
+ return {"type": "text", "data": {"text": f"[unsupported:{raw_type or 'dict'}]"}}
diff --git a/src/plugins/built_in/napcat_adapter/config.py b/src/plugins/built_in/napcat_adapter/config.py
new file mode 100644
index 00000000..eeb4acab
--- /dev/null
+++ b/src/plugins/built_in/napcat_adapter/config.py
@@ -0,0 +1,398 @@
+"""NapCat 内置适配器配置解析。"""
+
+from dataclasses import dataclass, field
+from typing import Any, Dict, Mapping, Optional, Set, Tuple
+from urllib.parse import urlparse
+
+from napcat_adapter.constants import (
+ DEFAULT_ACTION_TIMEOUT_SEC,
+ DEFAULT_CHAT_LIST_TYPE,
+ DEFAULT_HEARTBEAT_INTERVAL_SEC,
+ DEFAULT_NAPCAT_HOST,
+ DEFAULT_NAPCAT_PORT,
+ DEFAULT_RECONNECT_DELAY_SEC,
+ SUPPORTED_CONFIG_VERSION,
+)
+
+
+@dataclass(frozen=True)
+class NapCatPluginOptions:
+ """插件级配置。"""
+
+ enabled: bool = False
+ config_version: str = ""
+
+ def should_connect(self) -> bool:
+ """判断当前配置下是否应当启动连接。
+
+ Returns:
+ bool: 若插件连接已启用,则返回 ``True``。
+ """
+ return self.enabled
+
+
+@dataclass(frozen=True)
+class NapCatServerConfig:
+ """NapCat 正向 WebSocket 连接配置。"""
+
+ host: str = DEFAULT_NAPCAT_HOST
+ port: int = DEFAULT_NAPCAT_PORT
+ token: str = ""
+ heartbeat_interval: float = DEFAULT_HEARTBEAT_INTERVAL_SEC
+ reconnect_delay_sec: float = DEFAULT_RECONNECT_DELAY_SEC
+ action_timeout_sec: float = DEFAULT_ACTION_TIMEOUT_SEC
+ connection_id: str = ""
+
+ def build_ws_url(self) -> str:
+ """构造正向 WebSocket 地址。
+
+ Returns:
+ str: 供适配器作为客户端连接的 NapCat WebSocket 地址。
+ """
+ return f"ws://{self.host}:{self.port}"
+
+
+@dataclass(frozen=True)
+class NapCatChatConfig:
+ """聊天名单配置。"""
+
+ group_list_type: str = DEFAULT_CHAT_LIST_TYPE
+ group_list: Set[str] = field(default_factory=set)
+ private_list_type: str = DEFAULT_CHAT_LIST_TYPE
+ private_list: Set[str] = field(default_factory=set)
+ ban_user_id: Set[str] = field(default_factory=set)
+
+
+@dataclass(frozen=True)
+class NapCatFilterConfig:
+ """消息过滤配置。"""
+
+ ignore_self_message: bool = True
+
+
+@dataclass(frozen=True)
+class NapCatPluginSettings:
+ """NapCat 插件完整配置。"""
+
+ plugin: NapCatPluginOptions = field(default_factory=NapCatPluginOptions)
+ napcat_server: NapCatServerConfig = field(default_factory=NapCatServerConfig)
+ chat: NapCatChatConfig = field(default_factory=NapCatChatConfig)
+ filters: NapCatFilterConfig = field(default_factory=NapCatFilterConfig)
+
+ @classmethod
+ def from_mapping(cls, raw_config: Mapping[str, Any], logger: Any) -> "NapCatPluginSettings":
+ """从 Runner 注入的原始配置字典解析插件配置。
+
+ Args:
+ raw_config: Runner 注入的原始配置内容。
+ logger: 插件日志对象。
+
+ Returns:
+ NapCatPluginSettings: 规范化后的插件配置。
+ """
+ plugin_section = _as_mapping(raw_config.get("plugin"))
+ server_section = _as_mapping(raw_config.get("napcat_server"))
+ legacy_connection_section = _as_mapping(raw_config.get("connection"))
+ chat_section = _as_mapping(raw_config.get("chat"))
+ filters_section = _as_mapping(raw_config.get("filters"))
+
+ if not server_section and legacy_connection_section:
+ logger.warning("NapCat 适配器检测到旧版 [connection] 配置段,请尽快迁移到 [napcat_server]")
+ server_section = legacy_connection_section
+
+ legacy_host, legacy_port = _read_legacy_host_port(server_section, legacy_connection_section, logger)
+ parsed_host = _read_string(server_section, "host") or legacy_host or DEFAULT_NAPCAT_HOST
+ parsed_port = _read_positive_int(
+ mapping=server_section,
+ key="port",
+ default=legacy_port or DEFAULT_NAPCAT_PORT,
+ logger=logger,
+ setting_name="napcat_server.port",
+ )
+
+ return cls(
+ plugin=NapCatPluginOptions(
+ enabled=_read_bool(plugin_section, "enabled", False),
+ config_version=_read_string(plugin_section, "config_version"),
+ ),
+ napcat_server=NapCatServerConfig(
+ host=parsed_host,
+ port=parsed_port,
+ token=_read_string(server_section, "token") or _read_string(server_section, "access_token"),
+ heartbeat_interval=_read_positive_float(
+ mapping=server_section,
+ key="heartbeat_interval",
+ default=_read_positive_float(
+ mapping=server_section,
+ key="heartbeat_sec",
+ default=DEFAULT_HEARTBEAT_INTERVAL_SEC,
+ logger=logger,
+ setting_name="napcat_server.heartbeat_interval",
+ ),
+ logger=logger,
+ setting_name="napcat_server.heartbeat_interval",
+ ),
+ reconnect_delay_sec=_read_positive_float(
+ mapping=server_section,
+ key="reconnect_delay_sec",
+ default=DEFAULT_RECONNECT_DELAY_SEC,
+ logger=logger,
+ setting_name="napcat_server.reconnect_delay_sec",
+ ),
+ action_timeout_sec=_read_positive_float(
+ mapping=server_section,
+ key="action_timeout_sec",
+ default=DEFAULT_ACTION_TIMEOUT_SEC,
+ logger=logger,
+ setting_name="napcat_server.action_timeout_sec",
+ ),
+ connection_id=_read_string(server_section, "connection_id"),
+ ),
+ chat=NapCatChatConfig(
+ group_list_type=_read_list_mode(
+ mapping=chat_section,
+ key="group_list_type",
+ default=DEFAULT_CHAT_LIST_TYPE,
+ logger=logger,
+ setting_name="chat.group_list_type",
+ ),
+ group_list=_read_string_set(chat_section, "group_list"),
+ private_list_type=_read_list_mode(
+ mapping=chat_section,
+ key="private_list_type",
+ default=DEFAULT_CHAT_LIST_TYPE,
+ logger=logger,
+ setting_name="chat.private_list_type",
+ ),
+ private_list=_read_string_set(chat_section, "private_list"),
+ ban_user_id=_read_string_set(chat_section, "ban_user_id"),
+ ),
+ filters=NapCatFilterConfig(
+ ignore_self_message=_read_bool(filters_section, "ignore_self_message", True),
+ ),
+ )
+
+ def should_connect(self) -> bool:
+ """判断当前配置下是否应当启动连接。
+
+ Returns:
+ bool: 若插件连接已启用,则返回 ``True``。
+ """
+ return self.plugin.should_connect()
+
+ def validate(self, logger: Any) -> bool:
+ """校验当前配置是否满足启动连接的前提条件。
+
+ Args:
+ logger: 插件日志对象。
+
+ Returns:
+ bool: 若配置满足启动连接的前提条件,则返回 ``True``。
+ """
+ config_version = self.plugin.config_version
+ if not config_version:
+ logger.error(
+ f"NapCat 适配器配置缺少 plugin.config_version,当前插件要求版本 {SUPPORTED_CONFIG_VERSION}"
+ )
+ return False
+
+ if config_version != SUPPORTED_CONFIG_VERSION:
+ logger.error(
+ "NapCat 适配器配置版本不兼容: "
+ f"当前为 {config_version},当前插件要求 {SUPPORTED_CONFIG_VERSION}"
+ )
+ return False
+
+ if not self.napcat_server.host:
+ logger.warning("NapCat 适配器已启用,但 napcat_server.host 为空")
+ return False
+
+ if self.napcat_server.port <= 0:
+ logger.warning("NapCat 适配器已启用,但 napcat_server.port 不是正整数")
+ return False
+
+ return True
+
+
+def _as_mapping(value: Any) -> Dict[str, Any]:
+ """将任意值安全转换为字典。
+
+ Args:
+ value: 待转换的值。
+
+ Returns:
+ Dict[str, Any]: 若原值是映射,则返回普通字典;否则返回空字典。
+ """
+ return dict(value) if isinstance(value, Mapping) else {}
+
+
+def _read_bool(mapping: Mapping[str, Any], key: str, default: bool) -> bool:
+ """安全读取布尔配置值。
+
+ Args:
+ mapping: 待读取的配置字典。
+ key: 目标键名。
+ default: 读取失败时的默认值。
+
+ Returns:
+ bool: 解析后的布尔值。
+ """
+ value = mapping.get(key, default)
+ return value if isinstance(value, bool) else default
+
+
+def _read_string(mapping: Mapping[str, Any], key: str) -> str:
+ """安全读取字符串配置值。
+
+ Args:
+ mapping: 待读取的配置字典。
+ key: 目标键名。
+
+ Returns:
+ str: 去除首尾空白后的字符串值。
+ """
+ value = mapping.get(key)
+ return "" if value is None else str(value).strip()
+
+
+def _read_positive_float(
+ mapping: Mapping[str, Any],
+ key: str,
+ default: float,
+ logger: Any,
+ setting_name: str,
+) -> float:
+ """安全读取正浮点数配置值。
+
+ Args:
+ mapping: 待读取的配置字典。
+ key: 目标键名。
+ default: 读取失败时的默认值。
+ logger: 插件日志对象。
+ setting_name: 用于日志输出的完整配置名。
+
+ Returns:
+ float: 合法的正浮点数;否则返回默认值。
+ """
+ value = mapping.get(key, default)
+ if isinstance(value, (int, float)) and float(value) > 0:
+ return float(value)
+
+ if key in mapping:
+ logger.warning(f"NapCat 适配器配置项取值无效,已回退到默认值: {setting_name}={value!r},默认值为 {default}")
+ return default
+
+
+def _read_positive_int(
+ mapping: Mapping[str, Any],
+ key: str,
+ default: int,
+ logger: Any,
+ setting_name: str,
+) -> int:
+ """安全读取正整数配置值。
+
+ Args:
+ mapping: 待读取的配置字典。
+ key: 目标键名。
+ default: 读取失败时的默认值。
+ logger: 插件日志对象。
+ setting_name: 用于日志输出的完整配置名。
+
+ Returns:
+ int: 合法的正整数;否则返回默认值。
+ """
+ value = mapping.get(key, default)
+ if isinstance(value, int) and value > 0:
+ return value
+
+ if isinstance(value, str) and value.isdigit() and int(value) > 0:
+ return int(value)
+
+ if key in mapping:
+ logger.warning(f"NapCat 适配器配置项取值无效,已回退到默认值: {setting_name}={value!r},默认值为 {default}")
+ return default
+
+
+def _read_list_mode(
+ mapping: Mapping[str, Any],
+ key: str,
+ default: str,
+ logger: Any,
+ setting_name: str,
+) -> str:
+ """安全读取名单模式配置值。
+
+ Args:
+ mapping: 待读取的配置字典。
+ key: 目标键名。
+ default: 读取失败时的默认值。
+ logger: 插件日志对象。
+ setting_name: 用于日志输出的完整配置名。
+
+ Returns:
+ str: 合法的名单模式字符串。
+ """
+ value = mapping.get(key, default)
+ if isinstance(value, str):
+ normalized_value = value.strip()
+ if normalized_value in {"whitelist", "blacklist"}:
+ return normalized_value
+
+ if key in mapping:
+ logger.warning(f"NapCat 适配器配置项取值无效,已回退到默认值: {setting_name}={value!r},默认值为 {default}")
+ return default
+
+
+def _read_string_set(mapping: Mapping[str, Any], key: str) -> Set[str]:
+ """安全读取字符串集合配置值。
+
+ Args:
+ mapping: 待读取的配置字典。
+ key: 目标键名。
+
+ Returns:
+ Set[str]: 规范化后的字符串集合。
+ """
+ value = mapping.get(key, [])
+ if not isinstance(value, list):
+ return set()
+
+ normalized_values: Set[str] = set()
+ for item in value:
+ item_text = "" if item is None else str(item).strip()
+ if item_text:
+ normalized_values.add(item_text)
+ return normalized_values
+
+
+def _read_legacy_host_port(
+ server_section: Mapping[str, Any],
+ legacy_connection_section: Mapping[str, Any],
+ logger: Any,
+) -> Tuple[str, Optional[int]]:
+ """从旧版 ``ws_url`` 配置中提取主机与端口。
+
+ Args:
+ server_section: 新版 ``napcat_server`` 配置段。
+ legacy_connection_section: 旧版 ``connection`` 配置段。
+ logger: 插件日志对象。
+
+ Returns:
+ Tuple[str, Optional[int]]: 解析到的主机与端口;若未找到,则返回空主机与 ``None``。
+ """
+ legacy_ws_url = _read_string(server_section, "ws_url") or _read_string(legacy_connection_section, "ws_url")
+ if not legacy_ws_url:
+ return "", None
+
+ parsed_url = urlparse(legacy_ws_url)
+ parsed_host = parsed_url.hostname or ""
+ parsed_port = parsed_url.port
+
+ logger.warning(
+ "NapCat 适配器检测到旧版 ws_url 配置,已临时兼容解析,请尽快迁移到 napcat_server.host/port"
+ )
+ if parsed_url.path not in {"", "/"}:
+ logger.warning("NapCat 适配器旧版 ws_url 包含路径,新的 napcat_server 配置不会保留该路径")
+
+ return parsed_host, parsed_port
diff --git a/src/plugins/built_in/napcat_adapter/constants.py b/src/plugins/built_in/napcat_adapter/constants.py
new file mode 100644
index 00000000..bdddde6f
--- /dev/null
+++ b/src/plugins/built_in/napcat_adapter/constants.py
@@ -0,0 +1,9 @@
+"""NapCat 内置适配器共享常量。"""
+
+SUPPORTED_CONFIG_VERSION = "0.1.0"
+DEFAULT_NAPCAT_HOST = "127.0.0.1"
+DEFAULT_NAPCAT_PORT = 3001
+DEFAULT_RECONNECT_DELAY_SEC = 5.0
+DEFAULT_HEARTBEAT_INTERVAL_SEC = 30.0
+DEFAULT_ACTION_TIMEOUT_SEC = 15.0
+DEFAULT_CHAT_LIST_TYPE = "whitelist"
diff --git a/src/plugins/built_in/napcat_adapter/filters.py b/src/plugins/built_in/napcat_adapter/filters.py
new file mode 100644
index 00000000..141cda85
--- /dev/null
+++ b/src/plugins/built_in/napcat_adapter/filters.py
@@ -0,0 +1,68 @@
+"""NapCat 入站消息过滤。"""
+
+from typing import Any, Set
+
+from napcat_adapter.config import NapCatChatConfig
+
+
+class NapCatChatFilter:
+ """NapCat 聊天名单过滤器。"""
+
+ def __init__(self, logger: Any) -> None:
+ """初始化聊天名单过滤器。
+
+ Args:
+ logger: 插件日志对象。
+ """
+ self._logger = logger
+
+ def is_inbound_chat_allowed(
+ self,
+ sender_user_id: str,
+ group_id: str,
+ chat_config: NapCatChatConfig,
+ ) -> bool:
+ """检查入站消息是否通过聊天名单过滤。
+
+ Args:
+ sender_user_id: 发送者用户 ID。
+ group_id: 群聊 ID;私聊时为空字符串。
+ chat_config: 当前生效的聊天配置。
+
+ Returns:
+ bool: 若消息允许继续进入 Host,则返回 ``True``。
+ """
+ if sender_user_id in chat_config.ban_user_id:
+ self._logger.warning(f"NapCat 用户 {sender_user_id} 在全局禁止名单中,消息被丢弃")
+ return False
+
+ if group_id:
+ if not self._is_id_allowed_by_list_policy(group_id, chat_config.group_list_type, chat_config.group_list):
+ self._logger.warning(f"NapCat 群聊 {group_id} 未通过聊天名单过滤,消息被丢弃")
+ return False
+ return True
+
+ if not self._is_id_allowed_by_list_policy(
+ sender_user_id,
+ chat_config.private_list_type,
+ chat_config.private_list,
+ ):
+ self._logger.warning(f"NapCat 私聊用户 {sender_user_id} 未通过聊天名单过滤,消息被丢弃")
+ return False
+ return True
+
+ @staticmethod
+ def _is_id_allowed_by_list_policy(target_id: str, list_type: str, configured_ids: Set[str]) -> bool:
+ """根据白名单或黑名单规则判断目标 ID 是否允许通过。
+
+ Args:
+ target_id: 待检查的目标 ID。
+ list_type: 名单模式,仅支持 ``whitelist`` 或 ``blacklist``。
+ configured_ids: 配置中的 ID 集合。
+
+ Returns:
+ bool: 若目标 ID 允许通过,则返回 ``True``。
+ """
+ if list_type == "whitelist":
+ return target_id in configured_ids
+ return target_id not in configured_ids
diff --git a/src/plugins/built_in/napcat_adapter/plugin.py b/src/plugins/built_in/napcat_adapter/plugin.py
index c8bb837b..b1e9bc8c 100644
--- a/src/plugins/built_in/napcat_adapter/plugin.py
+++ b/src/plugins/built_in/napcat_adapter/plugin.py
@@ -1,6 +1,6 @@
"""内置 NapCat 适配器插件。
-当前实现是一个 MVP 版本,目标仅限于跑通基础消息收发链路:
+当前实现维持 MVP 范围,目标是跑通基础消息收发链路:
1. 作为客户端连接 NapCat / OneBot v11 WebSocket 服务。
2. 将入站消息事件转换为 Host 侧的 ``MessageDict``。
3. 将 Host 出站消息转换为 OneBot 动作并发送。
@@ -8,45 +8,26 @@
当前范围刻意收敛为:
- 单连接
- 文本、@、reply 基础转发
-- 暂不处理 ``notice`` / ``meta_event``
+- 暂不处理 ``notice`` / ``meta_event`` 的完整语义归一化
- 暂不支持图片、语音、文件等复杂媒体
"""
from __future__ import annotations
-from typing import Any, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, cast
-from uuid import uuid4
+from typing import Any, Dict, Mapping, Optional
import asyncio
-import contextlib
-import json
-import time
from maibot_sdk import Adapter, MaiBotPlugin
-if TYPE_CHECKING:
- from aiohttp import ClientWebSocketResponse as AiohttpClientWebSocketResponse
-
-try:
- from aiohttp import ClientSession, ClientTimeout, ClientWebSocketResponse, WSMsgType
-
- AIOHTTP_AVAILABLE = True
-except ImportError:
- ClientSession = cast(Any, None)
- ClientTimeout = cast(Any, None)
- ClientWebSocketResponse = cast(Any, None)
- WSMsgType = cast(Any, None)
- AIOHTTP_AVAILABLE = False
-
-if not TYPE_CHECKING:
- AiohttpClientWebSocketResponse = Any
-
-
-SUPPORTED_CONFIG_VERSION = "0.1.0"
-DEFAULT_RECONNECT_DELAY_SEC = 5.0
-DEFAULT_HEARTBEAT_SEC = 30.0
-DEFAULT_ACTION_TIMEOUT_SEC = 15.0
-DEFAULT_CHAT_LIST_TYPE = "whitelist"
+from napcat_adapter.codec_inbound import NapCatInboundCodec
+from napcat_adapter.codec_outbound import NapCatOutboundCodec
+from napcat_adapter.config import NapCatPluginSettings
+from napcat_adapter.filters import NapCatChatFilter
+from napcat_adapter.qq_notice import NapCatNoticeCodec
+from napcat_adapter.qq_queries import NapCatQueryService
+from napcat_adapter.runtime_state import NapCatRuntimeStateManager
+from napcat_adapter.transport import NapCatTransportClient
@Adapter(platform="qq", protocol="napcat", send_method="send_to_platform")
@@ -57,14 +38,14 @@ class NapCatAdapterPlugin(MaiBotPlugin):
"""初始化 NapCat 适配器插件实例。"""
super().__init__()
self._plugin_config: Dict[str, Any] = {}
- self._connection_task: Optional[asyncio.Task[None]] = None
- self._pending_actions: Dict[str, asyncio.Future[Dict[str, Any]]] = {}
- self._background_tasks: Set[asyncio.Task[Any]] = set()
- self._reported_account_id: Optional[str] = None
- self._reported_scope: Optional[str] = None
- self._runtime_state_connected: bool = False
- self._send_lock = asyncio.Lock()
- self._ws: Optional[AiohttpClientWebSocketResponse] = None
+ self._settings: Optional[NapCatPluginSettings] = None
+ self._inbound_codec: Optional[NapCatInboundCodec] = None
+ self._outbound_codec = NapCatOutboundCodec()
+ self._chat_filter: Optional[NapCatChatFilter] = None
+ self._query_service: Optional[NapCatQueryService] = None
+ self._notice_codec: Optional[NapCatNoticeCodec] = None
+ self._runtime_state: Optional[NapCatRuntimeStateManager] = None
+ self._transport: Optional[NapCatTransportClient] = None
def set_plugin_config(self, config: Dict[str, Any]) -> None:
"""设置插件配置内容。
@@ -79,9 +60,8 @@ class NapCatAdapterPlugin(MaiBotPlugin):
await self._restart_connection_if_needed()
async def on_unload(self) -> None:
- """在插件卸载时关闭连接并清理后台任务。"""
+ """在插件卸载时关闭连接并清理运行时状态。"""
await self._stop_connection()
- await self._cancel_background_tasks()
async def on_config_update(self, new_config: Dict[str, Any], version: str) -> None:
"""在配置更新后重载连接状态。
@@ -116,13 +96,14 @@ class NapCatAdapterPlugin(MaiBotPlugin):
del metadata
del kwargs
- ws = self._ws
- if ws is None or ws.closed:
- return {"success": False, "error": "NapCat is not connected"}
+ self._ensure_runtime_components()
+ transport = self._transport
+ if transport is None:
+ return {"success": False, "error": "NapCat transport is not initialized"}
try:
- action_name, params = self._build_outbound_action(message, route or {})
- response = await self._call_action(action_name, params)
+ action_name, params = self._outbound_codec.build_outbound_action(message, route or {})
+ response = await transport.call_action(action_name, params)
except Exception as exc:
return {"success": False, "error": str(exc)}
@@ -135,7 +116,7 @@ class NapCatAdapterPlugin(MaiBotPlugin):
response_data = response.get("data", {})
external_message_id = ""
- if isinstance(response_data, dict):
+ if isinstance(response_data, Mapping):
external_message_id = str(response_data.get("message_id") or "")
return {
@@ -144,143 +125,109 @@ class NapCatAdapterPlugin(MaiBotPlugin):
"metadata": {"action": action_name},
}
- async def _restart_connection_if_needed(self) -> None:
- """根据当前配置重启连接循环。"""
- await self._stop_connection()
- if not self._should_connect():
- self.ctx.logger.info("NapCat 适配器保持空闲状态,因为插件或配置未启用")
- return
- if not self._validate_current_config():
- return
- if not AIOHTTP_AVAILABLE:
- self.ctx.logger.error("NapCat 适配器依赖 aiohttp,但当前环境未安装该依赖")
- return
- self._connection_task = asyncio.create_task(self._connection_loop(), name="napcat_adapter.connection")
+ def _ensure_runtime_components(self) -> None:
+ """确保运行时依赖对象已经完成初始化。"""
+ if self._chat_filter is None:
+ self._chat_filter = NapCatChatFilter(self.ctx.logger)
- async def _stop_connection(self) -> None:
- """停止当前连接并让所有等待中的动作失败返回。"""
- connection_task = self._connection_task
- self._connection_task = None
-
- ws = self._ws
- if ws is not None and not ws.closed:
- with contextlib.suppress(Exception):
- await ws.close()
- self._ws = None
-
- if connection_task is not None:
- connection_task.cancel()
- with contextlib.suppress(asyncio.CancelledError):
- await connection_task
-
- await self._report_adapter_disconnected()
- self._fail_pending_actions("NapCat connection closed")
-
- async def _cancel_background_tasks(self) -> None:
- """取消所有仍在运行的入站后台任务。"""
- background_tasks = list(self._background_tasks)
- for task in background_tasks:
- task.cancel()
- if background_tasks:
- with contextlib.suppress(Exception):
- await asyncio.gather(*background_tasks, return_exceptions=True)
- self._background_tasks.clear()
-
- async def _connection_loop(self) -> None:
- """维护单个 WebSocket 连接,并在断开后按配置重连。"""
- assert ClientSession is not None
- assert ClientTimeout is not None
-
- while self._should_connect():
- ws_url = self._get_string(self._connection_config(), "ws_url")
- if not ws_url:
- self.ctx.logger.warning("NapCat 适配器已启用,但 connection.ws_url 为空")
- return
-
- headers = self._build_headers()
- timeout = ClientTimeout(total=None, connect=10)
- heartbeat = self._get_positive_float(self._connection_config(), "heartbeat_sec", DEFAULT_HEARTBEAT_SEC)
-
- try:
- async with ClientSession(headers=headers, timeout=timeout) as session:
- async with session.ws_connect(ws_url, heartbeat=heartbeat or None) as ws:
- self._ws = ws
- self.ctx.logger.info(f"NapCat 适配器已连接: {ws_url}")
- await self._receive_loop(ws)
- except asyncio.CancelledError:
- raise
- except Exception as exc:
- self.ctx.logger.warning(f"NapCat 适配器连接失败: {exc}")
- finally:
- self._ws = None
- await self._report_adapter_disconnected()
- self._fail_pending_actions("NapCat connection interrupted")
-
- if not self._should_connect():
- break
-
- await asyncio.sleep(
- self._get_positive_float(
- self._connection_config(),
- "reconnect_delay_sec",
- DEFAULT_RECONNECT_DELAY_SEC,
- )
+ if self._transport is None:
+ self._transport = NapCatTransportClient(
+ logger=self.ctx.logger,
+ on_connection_opened=self._bootstrap_adapter_runtime_state,
+ on_connection_closed=self._handle_transport_disconnected,
+ on_payload=self._handle_transport_payload,
)
- async def _receive_loop(self, ws: AiohttpClientWebSocketResponse) -> None:
- """持续消费 WebSocket 消息并分发处理。
+ if self._query_service is None:
+ self._query_service = NapCatQueryService(self.ctx.logger, self._transport)
+
+ if self._inbound_codec is None:
+ self._inbound_codec = NapCatInboundCodec(self.ctx.logger, self._query_service)
+
+ if self._notice_codec is None:
+ self._notice_codec = NapCatNoticeCodec(self.ctx.logger, self._query_service)
+
+ if self._runtime_state is None:
+ self._runtime_state = NapCatRuntimeStateManager(self.ctx.adapter, self.ctx.logger)
+
+ def _reload_settings(self) -> NapCatPluginSettings:
+ """重新解析当前插件配置。
+
+ Returns:
+ NapCatPluginSettings: 最新的规范化配置。
+ """
+ self._settings = NapCatPluginSettings.from_mapping(self._plugin_config, self.ctx.logger)
+ return self._settings
+
+ async def _restart_connection_if_needed(self) -> None:
+ """根据当前配置重启连接循环。"""
+ self._ensure_runtime_components()
+ settings = self._reload_settings()
+
+ await self._stop_connection()
+ if not settings.should_connect():
+ self.ctx.logger.info("NapCat 适配器保持空闲状态,因为插件或配置未启用")
+ return
+ if not settings.validate(self.ctx.logger):
+ return
+
+ transport = self._transport
+ assert transport is not None
+ if not transport.is_available():
+ self.ctx.logger.error("NapCat 适配器依赖 aiohttp,但当前环境未安装该依赖")
+ return
+
+ transport.configure(settings.napcat_server)
+ await transport.start()
+
+ async def _stop_connection(self) -> None:
+ """停止当前连接。"""
+ transport = self._transport
+ if transport is not None:
+ await transport.stop()
+ return
+
+ runtime_state = self._runtime_state
+ if runtime_state is not None:
+ await runtime_state.report_disconnected()
+
+ async def _handle_transport_payload(self, payload: Dict[str, Any]) -> None:
+ """处理来自传输层的非 echo 载荷。
Args:
- ws: 当前活跃的 WebSocket 连接对象。
+ payload: NapCat 推送的原始事件数据。
"""
- assert WSMsgType is not None
-
- bootstrap_task = asyncio.create_task(
- self._bootstrap_adapter_runtime_state(),
- name="napcat_adapter.bootstrap",
- )
- self._background_tasks.add(bootstrap_task)
- bootstrap_task.add_done_callback(self._background_tasks.discard)
-
- try:
- async for ws_message in ws:
- if ws_message.type != WSMsgType.TEXT:
- if ws_message.type in {WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.ERROR}:
- break
- continue
-
- payload = self._parse_json_message(ws_message.data)
- if payload is None:
- continue
-
- if echo_id := str(payload.get("echo") or "").strip():
- self._resolve_pending_action(echo_id, payload)
- continue
-
- if str(payload.get("post_type") or "").strip() != "message":
- continue
-
- task = asyncio.create_task(self._handle_inbound_message(payload), name="napcat_adapter.inbound")
- self._background_tasks.add(task)
- task.add_done_callback(self._background_tasks.discard)
- finally:
- if not bootstrap_task.done():
- bootstrap_task.cancel()
- with contextlib.suppress(asyncio.CancelledError):
- await bootstrap_task
+ post_type = str(payload.get("post_type") or "").strip()
+ if post_type == "message":
+ await self._handle_inbound_message(payload)
+ return
+ if post_type == "notice":
+ await self._handle_notice_event(payload)
+ return
+ if post_type == "meta_event":
+ await self._handle_meta_event(payload)
async def _handle_inbound_message(self, payload: Dict[str, Any]) -> None:
"""处理单条 NapCat 入站消息并注入 Host。
Args:
- payload: NapCat / OneBot 推送的原始事件数据。
+ payload: NapCat / OneBot 推送的原始消息事件。
"""
+ self._ensure_runtime_components()
+ settings = self._settings or self._reload_settings()
+ chat_filter = self._chat_filter
+ inbound_codec = self._inbound_codec
+ runtime_state = self._runtime_state
+ assert chat_filter is not None
+ assert inbound_codec is not None
+ assert runtime_state is not None
+
self_id = str(payload.get("self_id") or "").strip()
if self_id:
- await self._report_adapter_connected(self_id)
+ await runtime_state.report_connected(self_id, settings.napcat_server)
sender = payload.get("sender", {})
- if not isinstance(sender, dict):
+ if not isinstance(sender, Mapping):
sender = {}
sender_user_id = str(payload.get("user_id") or sender.get("user_id") or "").strip()
@@ -288,17 +235,17 @@ class NapCatAdapterPlugin(MaiBotPlugin):
return
group_id = str(payload.get("group_id") or "").strip()
- if self_id and sender_user_id == self_id and self._get_bool(self._filters_config(), "ignore_self_message", True):
+ if self_id and sender_user_id == self_id and settings.filters.ignore_self_message:
return
- if not self._is_inbound_chat_allowed(sender_user_id, group_id):
+ if not chat_filter.is_inbound_chat_allowed(sender_user_id, group_id, settings.chat):
return
- message_dict = self._build_inbound_message_dict(payload, self_id, sender_user_id, sender)
+ message_dict = await inbound_codec.build_message_dict(payload, self_id, sender_user_id, sender)
route_metadata: Dict[str, Any] = {}
if self_id:
route_metadata["self_id"] = self_id
- if connection_id := self._get_string(self._connection_config(), "connection_id"):
- route_metadata["connection_id"] = connection_id
+ if settings.napcat_server.connection_id:
+ route_metadata["connection_id"] = settings.napcat_server.connection_id
external_message_id = str(payload.get("message_id") or "").strip()
accepted = await self.ctx.adapter.receive_external_message(
@@ -310,305 +257,78 @@ class NapCatAdapterPlugin(MaiBotPlugin):
if not accepted:
self.ctx.logger.debug(f"Host 丢弃了 NapCat 入站消息: {external_message_id or '无消息 ID'}")
- def _build_inbound_message_dict(
- self,
- payload: Dict[str, Any],
- self_id: str,
- sender_user_id: str,
- sender: Dict[str, Any],
- ) -> Dict[str, Any]:
- """构造 Host 侧可接受的 ``MessageDict``。
+ async def _handle_notice_event(self, payload: Dict[str, Any]) -> None:
+ """处理 NapCat ``notice`` 事件并注入 Host。
Args:
- payload: NapCat 原始消息事件。
- self_id: 当前机器人账号 ID。
- sender_user_id: 发送者用户 ID。
- sender: 发送者信息字典。
-
- Returns:
- Dict[str, Any]: 规范化后的 ``MessageDict``。
+ payload: NapCat 推送的通知事件。
"""
- message_type = str(payload.get("message_type") or "").strip() or "private"
- group_id = str(payload.get("group_id") or "").strip()
- group_name = str(payload.get("group_name") or "").strip() or (f"group_{group_id}" if group_id else "")
- user_nickname = str(sender.get("nickname") or sender.get("card") or sender_user_id).strip() or sender_user_id
- user_cardname = str(sender.get("card") or "").strip() or None
+ self._ensure_runtime_components()
+ notice_codec = self._notice_codec
+ runtime_state = self._runtime_state
+ settings = self._settings or self._reload_settings()
+ assert notice_codec is not None
+ assert runtime_state is not None
- raw_message, is_at = self._convert_inbound_segments(payload.get("message"), self_id)
- raw_message_text = str(payload.get("raw_message") or "").strip()
- if not raw_message:
- raw_message = [{"type": "text", "data": raw_message_text or "[unsupported]"}]
+ self_id = str(payload.get("self_id") or "").strip()
+ if self_id:
+ await runtime_state.report_connected(self_id, settings.napcat_server)
- plain_text = self._build_plain_text(raw_message, raw_message_text)
- timestamp_seconds = payload.get("time")
- if not isinstance(timestamp_seconds, (int, float)):
- timestamp_seconds = time.time()
-
- additional_config: Dict[str, Any] = {"self_id": self_id, "napcat_message_type": message_type}
- if group_id:
- additional_config["platform_io_target_group_id"] = group_id
- else:
- additional_config["platform_io_target_user_id"] = sender_user_id
-
- message_info: Dict[str, Any] = {
- "user_info": {
- "user_id": sender_user_id,
- "user_nickname": user_nickname,
- "user_cardname": user_cardname,
- },
- "additional_config": additional_config,
- }
- if group_id:
- message_info["group_info"] = {"group_id": group_id, "group_name": group_name}
-
- message_id = str(payload.get("message_id") or f"napcat-{uuid4().hex}").strip()
- return {
- "message_id": message_id,
- "timestamp": str(float(timestamp_seconds)),
- "platform": "qq",
- "message_info": message_info,
- "raw_message": raw_message,
- "is_mentioned": is_at,
- "is_at": is_at,
- "is_emoji": False,
- "is_picture": False,
- "is_command": plain_text.startswith("/"),
- "is_notify": False,
- "session_id": "",
- "processed_plain_text": plain_text,
- "display_message": plain_text,
- }
-
- def _convert_inbound_segments(self, message_payload: Any, self_id: str) -> Tuple[List[Dict[str, Any]], bool]:
- """将 OneBot 消息段转换为 Host 消息段结构。
-
- Args:
- message_payload: OneBot 原始 ``message`` 字段。
- self_id: 当前机器人账号 ID。
-
- Returns:
- Tuple[List[Dict[str, Any]], bool]: 转换后的消息段列表,以及是否 @ 到当前机器人。
- """
- if isinstance(message_payload, str):
- normalized_text = message_payload.strip()
- return ([{"type": "text", "data": normalized_text}] if normalized_text else []), False
-
- if not isinstance(message_payload, list):
- return [], False
-
- converted_segments: List[Dict[str, Any]] = []
- is_at = False
- placeholder_texts = {
- "face": "[face]",
- "file": "[file]",
- "image": "[image]",
- "json": "[json]",
- "record": "[voice]",
- "video": "[video]",
- "xml": "[xml]",
- }
-
- for segment in message_payload:
- if not isinstance(segment, dict):
- continue
-
- segment_type = str(segment.get("type") or "").strip()
- segment_data = segment.get("data", {})
- if not isinstance(segment_data, dict):
- segment_data = {}
-
- if segment_type == "text":
- if text_value := str(segment_data.get("text") or ""):
- converted_segments.append({"type": "text", "data": text_value})
- continue
-
- if segment_type == "at":
- if target_user_id := str(segment_data.get("qq") or "").strip():
- converted_segments.append(
- {
- "type": "at",
- "data": {
- "target_user_id": target_user_id,
- "target_user_nickname": None,
- "target_user_cardname": None,
- },
- }
- )
- if self_id and target_user_id == self_id:
- is_at = True
- continue
-
- if segment_type == "reply":
- if target_message_id := str(segment_data.get("id") or "").strip():
- converted_segments.append({"type": "reply", "data": target_message_id})
- continue
-
- if placeholder := placeholder_texts.get(segment_type):
- converted_segments.append({"type": "text", "data": placeholder})
-
- return converted_segments, is_at
-
- def _build_outbound_action(
- self,
- message: Dict[str, Any],
- route: Dict[str, Any],
- ) -> Tuple[str, Dict[str, Any]]:
- """为 Host 出站消息构造 OneBot 动作。
-
- Args:
- message: Host 侧标准 ``MessageDict``。
- route: Platform IO 路由信息。
-
- Returns:
- Tuple[str, Dict[str, Any]]: 动作名称与参数字典。
- """
- message_info = message.get("message_info", {})
- if not isinstance(message_info, dict):
- message_info = {}
-
- group_info = message_info.get("group_info", {})
- if not isinstance(group_info, dict):
- group_info = {}
-
- additional_config = message_info.get("additional_config", {})
- if not isinstance(additional_config, dict):
- additional_config = {}
-
- raw_message = message.get("raw_message", [])
- segments = self._convert_outbound_segments(raw_message)
-
- if target_group_id := str(
- group_info.get("group_id") or additional_config.get("platform_io_target_group_id") or ""
- ).strip():
- return "send_group_msg", {"group_id": target_group_id, "message": segments}
-
- if not (
- target_user_id := str(
- additional_config.get("platform_io_target_user_id")
- or additional_config.get("target_user_id")
- or route.get("target_user_id")
- or ""
- ).strip()
- ):
- raise ValueError("Outbound private message is missing target_user_id")
-
- return "send_private_msg", {"message": segments, "user_id": target_user_id}
-
- def _convert_outbound_segments(self, raw_message: Any) -> List[Dict[str, Any]]:
- """将 Host 消息段转换为 OneBot 消息段。
-
- Args:
- raw_message: Host 侧 ``raw_message`` 字段。
-
- Returns:
- List[Dict[str, Any]]: OneBot 消息段列表。
- """
- if not isinstance(raw_message, list):
- return [{"type": "text", "data": {"text": ""}}]
-
- outbound_segments: List[Dict[str, Any]] = []
- for item in raw_message:
- if not isinstance(item, dict):
- continue
-
- item_type = str(item.get("type") or "").strip()
- item_data = item.get("data")
-
- if item_type == "text":
- text_value = str(item_data or "")
- outbound_segments.append({"type": "text", "data": {"text": text_value}})
- continue
-
- if item_type == "at" and isinstance(item_data, dict):
- if target_user_id := str(item_data.get("target_user_id") or "").strip():
- outbound_segments.append({"type": "at", "data": {"qq": target_user_id}})
- continue
-
- if item_type == "reply":
- if target_message_id := str(item_data or "").strip():
- outbound_segments.append({"type": "reply", "data": {"id": target_message_id}})
- continue
-
- fallback_text = f"[unsupported:{item_type or 'unknown'}]"
- outbound_segments.append({"type": "text", "data": {"text": fallback_text}})
-
- if not outbound_segments:
- outbound_segments.append({"type": "text", "data": {"text": ""}})
- return outbound_segments
-
- async def _call_action(self, action_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
- """发送 OneBot 动作并等待对应的 echo 响应。
-
- Args:
- action_name: OneBot 动作名称。
- params: 动作参数。
-
- Returns:
- Dict[str, Any]: NapCat 返回的原始响应字典。
- """
- ws = self._ws
- if ws is None or ws.closed:
- raise RuntimeError("NapCat is not connected")
-
- echo_id = uuid4().hex
- loop = asyncio.get_running_loop()
- response_future: asyncio.Future[Dict[str, Any]] = loop.create_future()
- self._pending_actions[echo_id] = response_future
-
- request_payload = {"action": action_name, "params": params, "echo": echo_id}
- try:
- async with self._send_lock:
- await ws.send_str(json.dumps(request_payload, ensure_ascii=False))
- timeout_seconds = self._get_positive_float(
- self._connection_config(),
- "action_timeout_sec",
- DEFAULT_ACTION_TIMEOUT_SEC,
- )
- return await asyncio.wait_for(response_future, timeout=timeout_seconds)
- finally:
- self._pending_actions.pop(echo_id, None)
-
- def _resolve_pending_action(self, echo_id: str, payload: Dict[str, Any]) -> None:
- """解析等待中的动作响应。
-
- Args:
- echo_id: 动作请求对应的 echo 标识。
- payload: NapCat 返回的响应载荷。
- """
- response_future = self._pending_actions.get(echo_id)
- if response_future is None or response_future.done():
+ message_dict = await notice_codec.build_notice_message_dict(payload)
+ if message_dict is None:
return
- response_future.set_result(payload)
- def _fail_pending_actions(self, error_message: str) -> None:
- """让所有等待中的动作以异常方式结束。
+ route_metadata: Dict[str, Any] = {}
+ if self_id:
+ route_metadata["self_id"] = self_id
+ if settings.napcat_server.connection_id:
+ route_metadata["connection_id"] = settings.napcat_server.connection_id
+
+ external_message_id = str(payload.get("message_id") or payload.get("notice_type") or "").strip()
+ accepted = await self.ctx.adapter.receive_external_message(
+ message_dict,
+ route_metadata=route_metadata,
+ external_message_id=external_message_id or None,
+ dedupe_key=external_message_id or None,
+ )
+ if not accepted:
+ self.ctx.logger.debug(f"Host 丢弃了 NapCat 通知事件: {external_message_id or '无消息 ID'}")
+
+ async def _handle_meta_event(self, payload: Dict[str, Any]) -> None:
+ """处理 NapCat ``meta_event`` 事件。
Args:
- error_message: 写入异常中的错误信息。
+ payload: NapCat 推送的元事件。
"""
- for response_future in self._pending_actions.values():
- if not response_future.done():
- response_future.set_exception(RuntimeError(error_message))
- self._pending_actions.clear()
+ self._ensure_runtime_components()
+ notice_codec = self._notice_codec
+ runtime_state = self._runtime_state
+ settings = self._settings or self._reload_settings()
+ assert notice_codec is not None
+ assert runtime_state is not None
+
+ self_id = str(payload.get("self_id") or "").strip()
+ if self_id:
+ await runtime_state.report_connected(self_id, settings.napcat_server)
+
+ await notice_codec.handle_meta_event(payload)
async def _bootstrap_adapter_runtime_state(self) -> None:
- """在连接建立后主动获取账号信息并激活适配器路由。
+ """在连接建立后主动获取账号信息并激活适配器路由。"""
+ transport = self._transport
+ query_service = self._query_service
+ runtime_state = self._runtime_state
+ settings = self._settings or self._reload_settings()
+ if transport is None or query_service is None or runtime_state is None:
+ return
- 该步骤会在 WebSocket 接收循环启动后异步执行,确保 `_call_action()`
- 发出的 `get_login_info` 请求能够被同一连接上的接收循环消费到 echo
- 响应,从而在真正收到业务消息前就完成 Host 侧 route 激活。
- """
max_attempts = 3
last_error: Optional[Exception] = None
for attempt in range(1, max_attempts + 1):
- ws = self._ws
- if ws is None or ws.closed:
- return
-
try:
- response = await self._call_action("get_login_info", {})
- self_id = self._extract_self_id_from_login_response(response)
- await self._report_adapter_connected(self_id)
+ login_info = await query_service.get_login_info()
+ self_id = self._extract_self_id_from_login_response(login_info)
+ await runtime_state.report_connected(self_id, settings.napcat_server)
return
except asyncio.CancelledError:
raise
@@ -623,430 +343,33 @@ class NapCatAdapterPlugin(MaiBotPlugin):
if last_error is not None:
self.ctx.logger.error(f"NapCat 适配器未能完成路由激活,连接将保持只接收状态: {last_error}")
+ async def _handle_transport_disconnected(self) -> None:
+ """处理传输层断开事件。"""
+ runtime_state = self._runtime_state
+ if runtime_state is not None:
+ await runtime_state.report_disconnected()
+
@staticmethod
- def _extract_self_id_from_login_response(response: Dict[str, Any]) -> str:
- """从 `get_login_info` 响应中提取当前账号 ID。
+ def _extract_self_id_from_login_response(response: Optional[Dict[str, Any]]) -> str:
+ """从 ``get_login_info`` 查询结果中提取当前账号 ID。
Args:
- response: NapCat 返回的原始动作响应。
+ response: NapCat 返回的登录信息字典。
Returns:
- str: 规范化后的 `self_id` 字符串。
+ str: 规范化后的账号 ID 字符串。
Raises:
ValueError: 当响应中缺少有效账号 ID 时抛出。
"""
- if str(response.get("status") or "").lower() != "ok":
- raise ValueError(str(response.get("wording") or response.get("message") or "get_login_info failed"))
-
- response_data = response.get("data", {})
- if not isinstance(response_data, dict):
+ if not isinstance(response, Mapping):
raise ValueError("get_login_info 响应缺少 data 字段")
- self_id = str(response_data.get("user_id") or "").strip()
+ self_id = str(response.get("user_id") or "").strip()
if not self_id:
raise ValueError("get_login_info 响应缺少有效的 user_id")
return self_id
- async def _report_adapter_connected(self, account_id: str) -> None:
- """向 Host 上报当前连接已就绪。
-
- Args:
- account_id: 当前 NapCat 连接对应的机器人账号 ID。
- """
- normalized_account_id = str(account_id).strip()
- if not normalized_account_id:
- return
-
- scope = self._get_string(self._connection_config(), "connection_id").strip()
- if (
- self._runtime_state_connected
- and self._reported_account_id == normalized_account_id
- and self._reported_scope == (scope or None)
- ):
- return
-
- accepted = False
- try:
- accepted = await self.ctx.adapter.update_runtime_state(
- connected=True,
- account_id=normalized_account_id,
- scope=scope,
- metadata={"ws_url": self._get_string(self._connection_config(), "ws_url")},
- )
- except Exception as exc:
- self.ctx.logger.warning(f"NapCat 适配器上报连接就绪状态失败: {exc}")
- return
-
- if not accepted:
- self.ctx.logger.warning("NapCat 适配器连接已建立,但 Host 未接受运行时状态更新")
- return
-
- self._runtime_state_connected = True
- self._reported_account_id = normalized_account_id
- self._reported_scope = scope or None
- self.ctx.logger.info(
- f"NapCat 适配器已激活路由: platform=qq account_id={normalized_account_id} "
- f"scope={self._reported_scope or '*'}"
- )
-
- async def _report_adapter_disconnected(self) -> None:
- """向 Host 上报当前连接已断开,并撤销适配器路由。"""
- if not self._runtime_state_connected:
- self._reported_account_id = None
- self._reported_scope = None
- return
-
- try:
- await self.ctx.adapter.update_runtime_state(connected=False)
- except Exception as exc:
- self.ctx.logger.warning(f"NapCat 适配器上报断开状态失败: {exc}")
- finally:
- self._runtime_state_connected = False
- self._reported_account_id = None
- self._reported_scope = None
-
- def _build_headers(self) -> Dict[str, str]:
- """构造连接 NapCat 所需的请求头。
-
- Returns:
- Dict[str, str]: WebSocket 握手请求头。
- """
- access_token = self._get_string(self._connection_config(), "access_token")
- return {"Authorization": f"Bearer {access_token}"} if access_token else {}
-
- def _parse_json_message(self, data: Any) -> Optional[Dict[str, Any]]:
- """解析 WebSocket 文本消息中的 JSON 数据。
-
- Args:
- data: WebSocket 收到的原始文本数据。
-
- Returns:
- Optional[Dict[str, Any]]: 成功时返回字典,失败时返回 ``None``。
- """
- try:
- payload = json.loads(str(data))
- except Exception as exc:
- self.ctx.logger.warning(f"NapCat 适配器解析 JSON 载荷失败: {exc}")
- return None
-
- return payload if isinstance(payload, dict) else None
-
- def _build_plain_text(self, raw_message: List[Dict[str, Any]], fallback_text: str) -> str:
- """从标准消息段中提取可展示的纯文本。
-
- Args:
- raw_message: 标准化后的消息段列表。
- fallback_text: 当无法拼出文本时使用的回退文本。
-
- Returns:
- str: 用于 Host 展示和命令判断的纯文本内容。
- """
- plain_text_parts: List[str] = []
- for item in raw_message:
- if not isinstance(item, dict):
- continue
- item_type = str(item.get("type") or "").strip()
- item_data = item.get("data")
- if item_type == "text":
- plain_text_parts.append(str(item_data or ""))
- elif item_type == "at" and isinstance(item_data, dict):
- plain_text_parts.append(f"@{item_data.get('target_user_id') or ''}")
- elif item_type == "reply":
- plain_text_parts.append("[reply]")
-
- plain_text = "".join(part for part in plain_text_parts if part).strip()
- return plain_text or fallback_text or "[unsupported]"
-
- def _plugin_section(self) -> Dict[str, Any]:
- """读取插件配置中的 ``plugin`` 段。
-
- Returns:
- Dict[str, Any]: ``plugin`` 配置字典。
- """
- plugin_section = self._plugin_config.get("plugin", {})
- return plugin_section if isinstance(plugin_section, dict) else {}
-
- def _connection_config(self) -> Dict[str, Any]:
- """读取插件配置中的 ``connection`` 段。
-
- Returns:
- Dict[str, Any]: ``connection`` 配置字典。
- """
- connection_config = self._plugin_config.get("connection", {})
- return connection_config if isinstance(connection_config, dict) else {}
-
- def _filters_config(self) -> Dict[str, Any]:
- """读取插件配置中的 ``filters`` 段。
-
- Returns:
- Dict[str, Any]: ``filters`` 配置字典。
- """
- filters_config = self._plugin_config.get("filters", {})
- return filters_config if isinstance(filters_config, dict) else {}
-
- def _chat_config(self) -> Dict[str, Any]:
- """读取插件配置中的 ``chat`` 段。
-
- Returns:
- Dict[str, Any]: ``chat`` 配置字典。
- """
- chat_config = self._plugin_config.get("chat", {})
- return chat_config if isinstance(chat_config, dict) else {}
-
- def _is_inbound_chat_allowed(self, sender_user_id: str, group_id: str) -> bool:
- """检查入站消息是否通过聊天名单过滤。
-
- Args:
- sender_user_id: 发送者用户 ID。
- group_id: 群聊 ID;私聊时为空字符串。
-
- Returns:
- bool: 若消息允许继续进入 Host,则返回 ``True``。
- """
- chat_config = self._chat_config()
- banned_user_ids = self._get_string_list(chat_config, "ban_user_id")
- if sender_user_id in banned_user_ids:
- self.ctx.logger.warning(f"NapCat 用户 {sender_user_id} 在全局禁止名单中,消息被丢弃")
- return False
-
- if group_id:
- group_list_type = self._get_list_mode(chat_config, "group_list_type", DEFAULT_CHAT_LIST_TYPE)
- group_id_list = self._get_string_list(chat_config, "group_list")
- if not self._is_id_allowed_by_list_policy(group_id, group_list_type, group_id_list):
- self.ctx.logger.warning(f"NapCat 群聊 {group_id} 未通过聊天名单过滤,消息被丢弃")
- return False
- return True
-
- private_list_type = self._get_list_mode(chat_config, "private_list_type", DEFAULT_CHAT_LIST_TYPE)
- private_id_list = self._get_string_list(chat_config, "private_list")
- if not self._is_id_allowed_by_list_policy(sender_user_id, private_list_type, private_id_list):
- self.ctx.logger.warning(f"NapCat 私聊用户 {sender_user_id} 未通过聊天名单过滤,消息被丢弃")
- return False
- return True
-
- def _is_id_allowed_by_list_policy(
- self,
- target_id: str,
- list_type: str,
- configured_ids: Set[str],
- ) -> bool:
- """根据白名单或黑名单规则判断目标 ID 是否允许通过。
-
- Args:
- target_id: 待检查的目标 ID。
- list_type: 名单模式,仅支持 ``whitelist`` 或 ``blacklist``。
- configured_ids: 配置中的 ID 集合。
-
- Returns:
- bool: 若目标 ID 允许通过,则返回 ``True``。
- """
- if list_type == "whitelist":
- return target_id in configured_ids
- return target_id not in configured_ids
-
- def _validate_current_config(self) -> bool:
- """校验当前配置是否满足启动连接的前提条件。
-
- Returns:
- bool: 配置可用于启动连接时返回 ``True``。
- """
- if not self._validate_plugin_config_version():
- return False
-
- connection_config = self._connection_config()
- ws_url = self._get_string(connection_config, "ws_url")
- if not ws_url:
- self.ctx.logger.warning("NapCat 适配器已启用,但 connection.ws_url 为空")
- return False
-
- self._validate_positive_float_setting(
- connection_config,
- "connection",
- "reconnect_delay_sec",
- DEFAULT_RECONNECT_DELAY_SEC,
- )
- self._validate_positive_float_setting(
- connection_config,
- "connection",
- "heartbeat_sec",
- DEFAULT_HEARTBEAT_SEC,
- )
- self._validate_positive_float_setting(
- connection_config,
- "connection",
- "action_timeout_sec",
- DEFAULT_ACTION_TIMEOUT_SEC,
- )
- self._validate_list_mode_setting(self._chat_config(), "chat", "group_list_type", DEFAULT_CHAT_LIST_TYPE)
- self._validate_list_mode_setting(self._chat_config(), "chat", "private_list_type", DEFAULT_CHAT_LIST_TYPE)
- return True
-
- def _validate_plugin_config_version(self) -> bool:
- """校验插件配置版本是否与当前实现兼容。
-
- Returns:
- bool: 版本兼容时返回 ``True``。
- """
- config_version = self._get_string(self._plugin_section(), "config_version")
- if not config_version:
- self.ctx.logger.error(
- f"NapCat 适配器配置缺少 plugin.config_version,当前插件要求版本 {SUPPORTED_CONFIG_VERSION}"
- )
- return False
-
- if config_version != SUPPORTED_CONFIG_VERSION:
- self.ctx.logger.error(
- "NapCat 适配器配置版本不兼容: "
- f"当前为 {config_version},当前插件要求 {SUPPORTED_CONFIG_VERSION}"
- )
- return False
-
- return True
-
- def _validate_positive_float_setting(
- self,
- mapping: Dict[str, Any],
- section_name: str,
- key: str,
- default: float,
- ) -> None:
- """校验正浮点数配置项,并在非法时输出告警日志。
-
- Args:
- mapping: 待读取的配置字典。
- section_name: 当前配置段名称。
- key: 目标配置键名。
- default: 配置非法时实际使用的默认值。
- """
- value = mapping.get(key, default)
- if isinstance(value, (int, float)) and float(value) > 0:
- return
-
- self.ctx.logger.warning(
- "NapCat 适配器配置项取值无效,已回退到默认值: "
- f"{section_name}.{key}={value!r},默认值为 {default}"
- )
-
- def _validate_list_mode_setting(
- self,
- mapping: Dict[str, Any],
- section_name: str,
- key: str,
- default: str,
- ) -> None:
- """校验名单模式配置项,并在非法时输出告警日志。
-
- Args:
- mapping: 待读取的配置字典。
- section_name: 当前配置段名称。
- key: 目标配置键名。
- default: 配置非法时实际使用的默认值。
- """
- value = mapping.get(key, default)
- if isinstance(value, str) and value.strip() in {"whitelist", "blacklist"}:
- return
-
- self.ctx.logger.warning(
- "NapCat 适配器配置项取值无效,已回退到默认值: "
- f"{section_name}.{key}={value!r},默认值为 {default}"
- )
-
- def _should_connect(self) -> bool:
- """判断当前配置下是否应当启动连接。
-
- Returns:
- bool: 若启用了插件连接则返回 ``True``。
- """
- return self._get_bool(self._plugin_section(), "enabled", False)
-
- @staticmethod
- def _get_bool(mapping: Dict[str, Any], key: str, default: bool) -> bool:
- """安全读取布尔配置值。
-
- Args:
- mapping: 待读取的配置字典。
- key: 目标键名。
- default: 读取失败时的默认值。
-
- Returns:
- bool: 解析后的布尔值。
- """
- value = mapping.get(key, default)
- return value if isinstance(value, bool) else default
-
- @staticmethod
- def _get_positive_float(mapping: Dict[str, Any], key: str, default: float) -> float:
- """安全读取正浮点数配置值。
-
- Args:
- mapping: 待读取的配置字典。
- key: 目标键名。
- default: 读取失败时的默认值。
-
- Returns:
- float: 合法的正浮点数;否则返回默认值。
- """
- value = mapping.get(key, default)
- if isinstance(value, (int, float)) and float(value) > 0:
- return float(value)
- return default
-
- @staticmethod
- def _get_string(mapping: Dict[str, Any], key: str) -> str:
- """安全读取字符串配置值。
-
- Args:
- mapping: 待读取的配置字典。
- key: 目标键名。
-
- Returns:
- str: 去除首尾空白后的字符串值。
- """
- value = mapping.get(key)
- return "" if value is None else str(value).strip()
-
- @staticmethod
- def _get_list_mode(mapping: Dict[str, Any], key: str, default: str) -> str:
- """安全读取名单模式配置值。
-
- Args:
- mapping: 待读取的配置字典。
- key: 目标键名。
- default: 读取失败时的默认值。
-
- Returns:
- str: 合法的名单模式字符串。
- """
- value = mapping.get(key, default)
- if isinstance(value, str):
- normalized_value = value.strip()
- if normalized_value in {"whitelist", "blacklist"}:
- return normalized_value
- return default
-
- @staticmethod
- def _get_string_list(mapping: Dict[str, Any], key: str) -> Set[str]:
- """安全读取 ID 列表配置值。
-
- Args:
- mapping: 待读取的配置字典。
- key: 目标键名。
-
- Returns:
- Set[str]: 去重后的字符串 ID 集合。
- """
- value = mapping.get(key, [])
- if not isinstance(value, list):
- return set()
-
- normalized_values: Set[str] = set()
- for item in value:
- item_text = "" if item is None else str(item).strip()
- if item_text:
- normalized_values.add(item_text)
- return normalized_values
-
def create_plugin() -> NapCatAdapterPlugin:
"""创建插件实例。
diff --git a/src/plugins/built_in/napcat_adapter/qq_notice.py b/src/plugins/built_in/napcat_adapter/qq_notice.py
new file mode 100644
index 00000000..f577cf98
--- /dev/null
+++ b/src/plugins/built_in/napcat_adapter/qq_notice.py
@@ -0,0 +1,224 @@
+"""NapCat QQ 平台通知与元事件处理。"""
+
+from typing import Any, Dict, Mapping, Optional
+from uuid import uuid4
+
+import time
+
+from napcat_adapter.qq_queries import NapCatQueryService
+
+
+class NapCatNoticeCodec:
+ """NapCat QQ 通知事件编码器。"""
+
+ def __init__(self, logger: Any, query_service: NapCatQueryService) -> None:
+ """初始化通知事件编码器。
+
+ Args:
+ logger: 插件日志对象。
+ query_service: QQ 查询服务。
+ """
+ self._logger = logger
+ self._query_service = query_service
+
+ async def build_notice_message_dict(self, payload: Mapping[str, Any]) -> Optional[Dict[str, Any]]:
+ """将 NapCat ``notice`` 事件转换为 Host 可接受的消息字典。
+
+ Args:
+ payload: NapCat 推送的原始通知事件。
+
+ Returns:
+ Optional[Dict[str, Any]]: 成功时返回标准 ``MessageDict``;无法识别时返回 ``None``。
+ """
+ notice_type = str(payload.get("notice_type") or "").strip()
+ if not notice_type:
+ return None
+
+ group_id = str(payload.get("group_id") or "").strip()
+ user_id = str(payload.get("user_id") or payload.get("operator_id") or "").strip()
+ self_id = str(payload.get("self_id") or "").strip()
+
+ user_info = await self._build_user_info(group_id=group_id, user_id=user_id)
+ group_info = await self._build_group_info(group_id)
+ notice_text = self._build_notice_text(payload, user_info.get("user_nickname", user_id or "系统"))
+ if not notice_text:
+ return None
+
+ additional_config: Dict[str, Any] = {
+ "self_id": self_id,
+ "napcat_notice_type": notice_type,
+ "napcat_notice_sub_type": str(payload.get("sub_type") or "").strip(),
+ "napcat_notice_payload": dict(payload),
+ }
+ if group_id:
+ additional_config["platform_io_target_group_id"] = group_id
+ elif user_id:
+ additional_config["platform_io_target_user_id"] = user_id
+
+ message_info: Dict[str, Any] = {"user_info": user_info, "additional_config": additional_config}
+ if group_info is not None:
+ message_info["group_info"] = group_info
+
+ timestamp_seconds = payload.get("time")
+ if not isinstance(timestamp_seconds, (int, float)):
+ timestamp_seconds = time.time()
+
+ return {
+ "message_id": f"napcat-notice-{uuid4().hex}",
+ "timestamp": str(float(timestamp_seconds)),
+ "platform": "qq",
+ "message_info": message_info,
+ "raw_message": [{"type": "text", "data": notice_text}],
+ "is_mentioned": False,
+ "is_at": False,
+ "is_emoji": False,
+ "is_picture": False,
+ "is_command": False,
+ "is_notify": True,
+ "session_id": "",
+ "processed_plain_text": notice_text,
+ "display_message": notice_text,
+ }
+
+ async def handle_meta_event(self, payload: Mapping[str, Any]) -> None:
+ """处理 ``meta_event`` 事件的日志与状态观测。
+
+ Args:
+ payload: NapCat 推送的原始元事件。
+ """
+ meta_event_type = str(payload.get("meta_event_type") or "").strip()
+ self_id = str(payload.get("self_id") or "").strip() or "unknown"
+
+ if meta_event_type == "lifecycle":
+ sub_type = str(payload.get("sub_type") or "").strip()
+ if sub_type == "connect":
+ self._logger.info(f"NapCat 元事件:Bot {self_id} 已建立连接")
+ else:
+ self._logger.debug(f"NapCat 生命周期事件: self_id={self_id} sub_type={sub_type}")
+ return
+
+ if meta_event_type == "heartbeat":
+ status = payload.get("status", {})
+ if not isinstance(status, Mapping):
+ status = {}
+ is_online = bool(status.get("online", False))
+ is_good = bool(status.get("good", False))
+ interval_ms = payload.get("interval")
+ self._logger.debug(
+ f"NapCat 心跳事件: self_id={self_id} online={is_online} good={is_good} interval={interval_ms}"
+ )
+ if not is_online:
+ self._logger.warning(f"NapCat 心跳显示 Bot {self_id} 已离线")
+ elif not is_good:
+ self._logger.warning(f"NapCat 心跳显示 Bot {self_id} 状态异常")
+
+ async def _build_user_info(self, group_id: str, user_id: str) -> Dict[str, Optional[str]]:
+ """构造通知消息的用户信息。
+
+ Args:
+ group_id: 群号;私聊或系统通知时为空字符串。
+ user_id: 事件关联用户号。
+
+ Returns:
+ Dict[str, Optional[str]]: 规范化后的用户信息字典。
+ """
+ if not user_id:
+ return {
+ "user_id": "notice",
+ "user_nickname": "系统通知",
+ "user_cardname": None,
+ }
+
+ member_info: Optional[Dict[str, Any]]
+ if group_id:
+ member_info = await self._query_service.get_group_member_info(group_id, user_id)
+ else:
+ member_info = await self._query_service.get_stranger_info(user_id)
+
+ if member_info is None:
+ return {
+ "user_id": user_id,
+ "user_nickname": user_id,
+ "user_cardname": None,
+ }
+
+ return {
+ "user_id": user_id,
+ "user_nickname": str(member_info.get("nickname") or user_id),
+ "user_cardname": self._normalize_optional_string(member_info.get("card")),
+ }
+
+ async def _build_group_info(self, group_id: str) -> Optional[Dict[str, str]]:
+ """构造通知消息的群信息。
+
+ Args:
+ group_id: 群号。
+
+ Returns:
+ Optional[Dict[str, str]]: 群信息字典;若不是群通知则返回 ``None``。
+ """
+ if not group_id:
+ return None
+
+ group_info = await self._query_service.get_group_info(group_id)
+ group_name = str(group_info.get("group_name") or f"group_{group_id}") if group_info else f"group_{group_id}"
+ return {"group_id": group_id, "group_name": group_name}
+
+ def _build_notice_text(self, payload: Mapping[str, Any], actor_name: str) -> str:
+ """根据 NapCat 通知事件生成可读文本。
+
+ Args:
+ payload: 原始通知事件。
+ actor_name: 事件操作者显示名。
+
+ Returns:
+ str: 生成的可读通知文本。
+ """
+ notice_type = str(payload.get("notice_type") or "").strip()
+ sub_type = str(payload.get("sub_type") or "").strip()
+ target_id = str(payload.get("target_id") or "").strip()
+
+ if notice_type in {"group_recall", "friend_recall"}:
+ return f"{actor_name} 撤回了一条消息"
+ if notice_type == "notify" and sub_type == "poke":
+ target_text = f" -> {target_id}" if target_id else ""
+ return f"{actor_name} 发起了戳一戳{target_text}"
+ if notice_type == "notify" and sub_type == "group_name":
+ return f"{actor_name} 修改了群名称"
+ if notice_type == "group_ban" and sub_type == "ban":
+ duration = payload.get("duration")
+ return f"{actor_name} 触发了群禁言,时长 {duration} 秒"
+ if notice_type == "group_ban" and sub_type == "lift_ban":
+ return f"{actor_name} 触发了解除禁言"
+ if notice_type == "group_upload":
+ file_info = payload.get("file", {})
+ file_name = ""
+ if isinstance(file_info, Mapping):
+ file_name = str(file_info.get("name") or "").strip()
+ return f"{actor_name} 上传了文件{f':{file_name}' if file_name else ''}"
+ if notice_type == "group_increase":
+ return f"{actor_name} 加入了群聊"
+ if notice_type == "group_decrease":
+ return f"{actor_name} 离开了群聊"
+ if notice_type == "group_admin":
+ return f"{actor_name} 的群管理员状态发生变化"
+ if notice_type == "essence":
+ return f"{actor_name} 触发了精华消息事件"
+ if notice_type == "group_msg_emoji_like":
+ return f"{actor_name} 给一条消息添加了表情回应"
+ return f"[notice] {notice_type}.{sub_type}".strip(".")
+
+ @staticmethod
+ def _normalize_optional_string(value: Any) -> Optional[str]:
+ """将任意值规范化为可选字符串。
+
+ Args:
+ value: 待规范化的值。
+
+ Returns:
+ Optional[str]: 规范化后的字符串;若值为空则返回 ``None``。
+ """
+ if value is None:
+ return None
+ normalized_value = str(value).strip()
+ return normalized_value if normalized_value else None
diff --git a/src/plugins/built_in/napcat_adapter/qq_queries.py b/src/plugins/built_in/napcat_adapter/qq_queries.py
new file mode 100644
index 00000000..7d29803a
--- /dev/null
+++ b/src/plugins/built_in/napcat_adapter/qq_queries.py
@@ -0,0 +1,170 @@
+"""NapCat QQ 平台查询能力。"""
+
+from typing import TYPE_CHECKING, Any, Dict, Optional
+
+import asyncio
+
+if TYPE_CHECKING:
+ from napcat_adapter.transport import NapCatTransportClient
+
+try:
+ from aiohttp import ClientSession, ClientTimeout
+
+ AIOHTTP_AVAILABLE = True
+except ImportError:
+ ClientSession = None # type: ignore[assignment]
+ ClientTimeout = None # type: ignore[assignment]
+ AIOHTTP_AVAILABLE = False
+
+
+class NapCatQueryService:
+ """NapCat QQ 平台查询服务。"""
+
+ def __init__(self, logger: Any, transport: "NapCatTransportClient") -> None:
+ """初始化查询服务。
+
+ Args:
+ logger: 插件日志对象。
+ transport: NapCat 传输层客户端。
+ """
+ self._logger = logger
+ self._transport = transport
+
+ async def get_login_info(self) -> Optional[Dict[str, Any]]:
+ """获取当前登录账号信息。
+
+ Returns:
+ Optional[Dict[str, Any]]: 登录信息字典;失败时返回 ``None``。
+ """
+ return await self._call_query("get_login_info", {})
+
+ async def get_group_info(self, group_id: str) -> Optional[Dict[str, Any]]:
+ """获取群信息。
+
+ Args:
+ group_id: 群号。
+
+ Returns:
+ Optional[Dict[str, Any]]: 群信息字典;失败时返回 ``None``。
+ """
+ return await self._call_query("get_group_info", {"group_id": group_id})
+
+ async def get_group_member_info(self, group_id: str, user_id: str) -> Optional[Dict[str, Any]]:
+ """获取群成员信息。
+
+ Args:
+ group_id: 群号。
+ user_id: 用户号。
+
+ Returns:
+ Optional[Dict[str, Any]]: 群成员信息字典;失败时返回 ``None``。
+ """
+ return await self._call_query(
+ "get_group_member_info",
+ {"group_id": group_id, "user_id": user_id, "no_cache": True},
+ )
+
+ async def get_stranger_info(self, user_id: str) -> Optional[Dict[str, Any]]:
+ """获取陌生人信息。
+
+ Args:
+ user_id: 用户号。
+
+ Returns:
+ Optional[Dict[str, Any]]: 陌生人信息字典;失败时返回 ``None``。
+ """
+ return await self._call_query("get_stranger_info", {"user_id": user_id})
+
+ async def get_message_detail(self, message_id: str) -> Optional[Dict[str, Any]]:
+ """获取消息详情。
+
+ Args:
+ message_id: 消息 ID。
+
+ Returns:
+ Optional[Dict[str, Any]]: 消息详情字典;失败时返回 ``None``。
+ """
+ return await self._call_query("get_msg", {"message_id": message_id})
+
+ async def get_forward_message(self, message_id: str) -> Optional[Dict[str, Any]]:
+ """获取合并转发消息详情。
+
+ Args:
+ message_id: 转发消息 ID。
+
+ Returns:
+ Optional[Dict[str, Any]]: 合并转发消息详情;失败时返回 ``None``。
+ """
+ return await self._call_query("get_forward_msg", {"message_id": message_id})
+
+ async def get_record_detail(self, file_name: str, file_id: Optional[str] = None) -> Optional[Dict[str, Any]]:
+ """获取语音文件详情。
+
+ Args:
+ file_name: 语音文件名。
+ file_id: 可选文件 ID。
+
+ Returns:
+ Optional[Dict[str, Any]]: 语音详情字典;失败时返回 ``None``。
+ """
+ params: Dict[str, Any] = {"file": file_name, "out_format": "wav"}
+ if file_id:
+ params["file_id"] = file_id
+ return await self._call_query("get_record", params)
+
+ async def download_binary(self, url: str) -> Optional[bytes]:
+ """下载远程二进制资源。
+
+ Args:
+ url: 资源 URL。
+
+ Returns:
+ Optional[bytes]: 下载到的二进制内容;失败时返回 ``None``。
+ """
+ if not url:
+ return None
+ if not AIOHTTP_AVAILABLE or ClientSession is None or ClientTimeout is None:
+ self._logger.warning("NapCat 查询层缺少 aiohttp,无法下载远程资源")
+ return None
+
+ try:
+ timeout = ClientTimeout(total=15)
+ async with ClientSession(timeout=timeout) as session:
+ async with session.get(url) as response:
+ if response.status != 200:
+ self._logger.warning(f"NapCat 远程资源下载失败: status={response.status} url={url}")
+ return None
+ return await response.read()
+ except asyncio.CancelledError:
+ raise
+ except Exception as exc:
+ self._logger.warning(f"NapCat 远程资源下载失败: {exc}")
+ return None
+
+ async def _call_query(self, action_name: str, params: Dict[str, Any]) -> Optional[Dict[str, Any]]:
+ """调用 OneBot 查询动作并提取 ``data`` 字段。
+
+ Args:
+ action_name: OneBot 动作名。
+ params: 动作参数。
+
+ Returns:
+ Optional[Dict[str, Any]]: 查询结果中的 ``data`` 字段;失败时返回 ``None``。
+ """
+ try:
+ response = await self._transport.call_action(action_name, params)
+ except asyncio.CancelledError:
+ raise
+ except Exception as exc:
+ self._logger.warning(f"NapCat 查询动作执行失败: action={action_name} error={exc}")
+ return None
+
+ if str(response.get("status") or "").lower() != "ok":
+ self._logger.warning(
+ f"NapCat 查询动作返回失败: action={action_name} "
+ f"message={response.get('wording') or response.get('message') or 'unknown'}"
+ )
+ return None
+
+ response_data = response.get("data")
+ return response_data if isinstance(response_data, dict) else None
diff --git a/src/plugins/built_in/napcat_adapter/runtime_state.py b/src/plugins/built_in/napcat_adapter/runtime_state.py
new file mode 100644
index 00000000..b4dbfa09
--- /dev/null
+++ b/src/plugins/built_in/napcat_adapter/runtime_state.py
@@ -0,0 +1,85 @@
+"""NapCat 运行时路由状态管理。"""
+
+from typing import Any, Optional
+
+from napcat_adapter.config import NapCatServerConfig
+
+
+class NapCatRuntimeStateManager:
+ """NapCat 适配器路由状态上报器。"""
+
+ def __init__(self, adapter_capability: Any, logger: Any) -> None:
+ """初始化运行时状态管理器。
+
+ Args:
+ adapter_capability: SDK 提供的适配器能力对象。
+ logger: 插件日志对象。
+ """
+ self._adapter_capability = adapter_capability
+ self._logger = logger
+ self._runtime_state_connected: bool = False
+ self._reported_account_id: Optional[str] = None
+ self._reported_scope: Optional[str] = None
+
+ async def report_connected(self, account_id: str, server_config: NapCatServerConfig) -> bool:
+ """向 Host 上报当前连接已就绪。
+
+ Args:
+ account_id: 当前 NapCat 连接对应的机器人账号 ID。
+ server_config: 当前生效的 NapCat 服务端配置。
+
+ Returns:
+ bool: 若 Host 接受了运行时状态更新,则返回 ``True``。
+ """
+ normalized_account_id = str(account_id).strip()
+ if not normalized_account_id:
+ return False
+
+ scope = server_config.connection_id or None
+ if (
+ self._runtime_state_connected
+ and self._reported_account_id == normalized_account_id
+ and self._reported_scope == scope
+ ):
+ return True
+
+ accepted = False
+ try:
+ accepted = await self._adapter_capability.update_runtime_state(
+ connected=True,
+ account_id=normalized_account_id,
+ scope=server_config.connection_id,
+ metadata={"ws_url": server_config.build_ws_url()},
+ )
+ except Exception as exc:
+ self._logger.warning(f"NapCat 适配器上报连接就绪状态失败: {exc}")
+ return False
+
+ if not accepted:
+ self._logger.warning("NapCat 适配器连接已建立,但 Host 未接受运行时状态更新")
+ return False
+
+ self._runtime_state_connected = True
+ self._reported_account_id = normalized_account_id
+ self._reported_scope = scope
+ self._logger.info(
+ f"NapCat 适配器已激活路由: platform=qq account_id={normalized_account_id} "
+ f"scope={self._reported_scope or '*'}"
+ )
+ return True
+
+ async def report_disconnected(self) -> None:
+ """向 Host 上报当前连接已断开,并撤销适配器路由。"""
+ if not self._runtime_state_connected:
+ self._reported_account_id = None
+ self._reported_scope = None
+ return
+
+ try:
+ await self._adapter_capability.update_runtime_state(connected=False)
+ except Exception as exc:
+ self._logger.warning(f"NapCat 适配器上报断开状态失败: {exc}")
+ finally:
+ self._runtime_state_connected = False
+ self._reported_account_id = None
+ self._reported_scope = None
diff --git a/src/plugins/built_in/napcat_adapter/transport.py b/src/plugins/built_in/napcat_adapter/transport.py
new file mode 100644
index 00000000..d20de097
--- /dev/null
+++ b/src/plugins/built_in/napcat_adapter/transport.py
@@ -0,0 +1,322 @@
+"""NapCat 正向 WebSocket 传输层。"""
+
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Set, cast
+from uuid import uuid4
+
+import asyncio
+import contextlib
+import json
+
+from napcat_adapter.config import NapCatServerConfig
+
+if TYPE_CHECKING:
+ from aiohttp import ClientWebSocketResponse as AiohttpClientWebSocketResponse
+
+try:
+ from aiohttp import ClientSession, ClientTimeout, WSMsgType
+
+ AIOHTTP_AVAILABLE = True
+except ImportError:
+ ClientSession = cast(Any, None)
+ ClientTimeout = cast(Any, None)
+ WSMsgType = cast(Any, None)
+ AIOHTTP_AVAILABLE = False
+
+if not TYPE_CHECKING:
+ AiohttpClientWebSocketResponse = Any
+
+
+class NapCatTransportClient:
+ """NapCat 正向 WebSocket 客户端。"""
+
+ def __init__(
+ self,
+ logger: Any,
+ on_connection_opened: Callable[[], Awaitable[None]],
+ on_connection_closed: Callable[[], Awaitable[None]],
+ on_payload: Callable[[Dict[str, Any]], Awaitable[None]],
+ ) -> None:
+ """初始化传输层客户端。
+
+ Args:
+ logger: 插件日志对象。
+ on_connection_opened: 连接建立后的异步回调。
+ on_connection_closed: 连接断开后的异步回调。
+ on_payload: 收到非 echo 载荷后的异步回调。
+ """
+ self._logger = logger
+ self._on_connection_opened = on_connection_opened
+ self._on_connection_closed = on_connection_closed
+ self._on_payload = on_payload
+ self._server_config: Optional[NapCatServerConfig] = None
+ self._connection_task: Optional[asyncio.Task[None]] = None
+ self._pending_actions: Dict[str, asyncio.Future[Dict[str, Any]]] = {}
+ self._background_tasks: Set[asyncio.Task[Any]] = set()
+ self._send_lock = asyncio.Lock()
+ self._ws: Optional[AiohttpClientWebSocketResponse] = None
+ self._stop_requested: bool = False
+ self._connection_active: bool = False
+
+ @classmethod
+ def is_available(cls) -> bool:
+ """判断当前环境是否安装了传输层依赖。
+
+ Returns:
+ bool: 若已安装 ``aiohttp``,则返回 ``True``。
+ """
+ return AIOHTTP_AVAILABLE
+
+ def configure(self, server_config: NapCatServerConfig) -> None:
+ """更新当前传输层使用的 NapCat 服务端配置。
+
+ Args:
+ server_config: 最新生效的 NapCat 服务端配置。
+ """
+ self._server_config = server_config
+
+ async def start(self) -> None:
+ """启动 NapCat 正向 WebSocket 连接循环。
+
+ Raises:
+ RuntimeError: 当缺少配置或依赖时抛出。
+ """
+ if not self.is_available():
+ raise RuntimeError("NapCat 适配器依赖 aiohttp,但当前环境未安装该依赖")
+ if self._server_config is None:
+ raise RuntimeError("NapCat 适配器尚未配置 napcat_server")
+ if self._connection_task is not None and not self._connection_task.done():
+ return
+
+ self._stop_requested = False
+ self._connection_task = asyncio.create_task(self._connection_loop(), name="napcat_adapter.connection")
+
+ async def stop(self) -> None:
+ """停止当前连接并清理所有后台任务。"""
+ self._stop_requested = True
+ connection_task = self._connection_task
+ self._connection_task = None
+
+ ws = self._ws
+ if ws is not None and not ws.closed:
+ with contextlib.suppress(Exception):
+ await ws.close()
+ self._ws = None
+
+ if connection_task is not None:
+ connection_task.cancel()
+ with contextlib.suppress(asyncio.CancelledError):
+ await connection_task
+
+ await self._cancel_background_tasks()
+ await self._notify_connection_closed()
+ self._fail_pending_actions("NapCat connection closed")
+
+ async def call_action(self, action_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
+ """发送 OneBot 动作并等待对应的 echo 响应。
+
+ Args:
+ action_name: OneBot 动作名称。
+ params: 动作参数。
+
+ Returns:
+ Dict[str, Any]: NapCat 返回的原始响应字典。
+
+ Raises:
+ RuntimeError: 当连接不可用时抛出。
+ """
+ ws = self._ws
+ server_config = self._server_config
+ if ws is None or ws.closed or server_config is None:
+ raise RuntimeError("NapCat is not connected")
+
+ echo_id = uuid4().hex
+ loop = asyncio.get_running_loop()
+ response_future: asyncio.Future[Dict[str, Any]] = loop.create_future()
+ self._pending_actions[echo_id] = response_future
+
+ request_payload = {"action": action_name, "params": params, "echo": echo_id}
+ try:
+ async with self._send_lock:
+ await ws.send_str(json.dumps(request_payload, ensure_ascii=False))
+ return await asyncio.wait_for(response_future, timeout=server_config.action_timeout_sec)
+ finally:
+ self._pending_actions.pop(echo_id, None)
+
+ async def _connection_loop(self) -> None:
+ """维护单个 WebSocket 连接,并在断开后按配置重连。"""
+ assert ClientSession is not None
+ assert ClientTimeout is not None
+
+ while not self._stop_requested:
+ server_config = self._server_config
+ if server_config is None:
+ return
+
+ ws_url = server_config.build_ws_url()
+ timeout = ClientTimeout(total=None, connect=10)
+
+ try:
+ async with ClientSession(headers=self._build_headers(server_config), timeout=timeout) as session:
+ async with session.ws_connect(ws_url, heartbeat=server_config.heartbeat_interval or None) as ws:
+ self._ws = ws
+ self._logger.info(f"NapCat 适配器已连接: {ws_url}")
+ await self._receive_loop(ws)
+ except asyncio.CancelledError:
+ raise
+ except Exception as exc:
+ self._logger.warning(f"NapCat 适配器连接失败: {exc}")
+ finally:
+ self._ws = None
+ await self._notify_connection_closed()
+ self._fail_pending_actions("NapCat connection interrupted")
+
+ if self._stop_requested:
+ break
+
+ await asyncio.sleep(server_config.reconnect_delay_sec)
+
+ async def _receive_loop(self, ws: AiohttpClientWebSocketResponse) -> None:
+ """持续消费 WebSocket 消息并分发处理。
+
+ Args:
+ ws: 当前活跃的 WebSocket 连接对象。
+ """
+ assert WSMsgType is not None
+
+ bootstrap_task = self._create_background_task(
+ self._notify_connection_opened(),
+ "napcat_adapter.bootstrap",
+ )
+ try:
+ async for ws_message in ws:
+ if ws_message.type != WSMsgType.TEXT:
+ if ws_message.type in {WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.ERROR}:
+ break
+ continue
+
+ payload = self._parse_json_message(ws_message.data)
+ if payload is None:
+ continue
+
+ if echo_id := str(payload.get("echo") or "").strip():
+ self._resolve_pending_action(echo_id, payload)
+ continue
+
+ self._create_background_task(self._on_payload(payload), "napcat_adapter.payload")
+ finally:
+ if bootstrap_task is not None and not bootstrap_task.done():
+ bootstrap_task.cancel()
+ with contextlib.suppress(asyncio.CancelledError):
+ await bootstrap_task
+
+ def _create_background_task(self, coroutine: Awaitable[Any], name: str) -> asyncio.Task[Any]:
+ """创建并跟踪一个后台任务。
+
+ Args:
+ coroutine: 待执行的协程对象。
+ name: 任务名。
+
+ Returns:
+ asyncio.Task[Any]: 已创建的后台任务。
+ """
+ task = asyncio.create_task(coroutine, name=name)
+ self._background_tasks.add(task)
+ task.add_done_callback(self._handle_background_task_completion)
+ return task
+
+ def _handle_background_task_completion(self, task: asyncio.Task[Any]) -> None:
+ """处理后台任务结束后的清理与异常记录。
+
+ Args:
+ task: 已结束的后台任务。
+ """
+ self._background_tasks.discard(task)
+ if task.cancelled():
+ return
+
+ exception = task.exception()
+ if exception is not None:
+ self._logger.error(f"NapCat 适配器后台任务异常: {exception}", exc_info=True)
+
+ async def _cancel_background_tasks(self) -> None:
+ """取消所有仍在运行的后台任务。"""
+ background_tasks = list(self._background_tasks)
+ for task in background_tasks:
+ task.cancel()
+ if background_tasks:
+ with contextlib.suppress(Exception):
+ await asyncio.gather(*background_tasks, return_exceptions=True)
+ self._background_tasks.clear()
+
+ async def _notify_connection_opened(self) -> None:
+ """在连接建立后触发上层回调。"""
+ if self._connection_active:
+ return
+
+ self._connection_active = True
+ try:
+ await self._on_connection_opened()
+ except Exception as exc:
+ self._logger.warning(f"NapCat 适配器连接建立回调失败: {exc}")
+
+ async def _notify_connection_closed(self) -> None:
+ """在连接断开后触发上层回调。"""
+ if not self._connection_active:
+ return
+
+ self._connection_active = False
+ try:
+ await self._on_connection_closed()
+ except Exception as exc:
+ self._logger.warning(f"NapCat 适配器断连回调失败: {exc}")
+
+ def _resolve_pending_action(self, echo_id: str, payload: Dict[str, Any]) -> None:
+ """解析等待中的动作响应。
+
+ Args:
+ echo_id: 动作请求对应的 echo 标识。
+ payload: NapCat 返回的响应载荷。
+ """
+ response_future = self._pending_actions.get(echo_id)
+ if response_future is None or response_future.done():
+ return
+ response_future.set_result(payload)
+
+ def _fail_pending_actions(self, error_message: str) -> None:
+ """让所有等待中的动作以异常方式结束。
+
+ Args:
+ error_message: 写入异常中的错误信息。
+ """
+ for response_future in self._pending_actions.values():
+ if not response_future.done():
+ response_future.set_exception(RuntimeError(error_message))
+ self._pending_actions.clear()
+
+ def _build_headers(self, server_config: NapCatServerConfig) -> Dict[str, str]:
+ """构造连接 NapCat 所需的请求头。
+
+ Args:
+ server_config: 当前生效的 NapCat 服务端配置。
+
+ Returns:
+ Dict[str, str]: WebSocket 握手请求头。
+ """
+ return {"Authorization": f"Bearer {server_config.token}"} if server_config.token else {}
+
+ def _parse_json_message(self, data: Any) -> Optional[Dict[str, Any]]:
+ """解析 WebSocket 文本消息中的 JSON 数据。
+
+ Args:
+ data: WebSocket 收到的原始文本数据。
+
+ Returns:
+ Optional[Dict[str, Any]]: 成功时返回字典,失败时返回 ``None``。
+ """
+ try:
+ payload = json.loads(str(data))
+ except Exception as exc:
+ self._logger.warning(f"NapCat 适配器解析 JSON 载荷失败: {exc}")
+ return None
+
+ return payload if isinstance(payload, dict) else None
From 56a6d2fd8ce13459395334c64d96ff6c991aef9c Mon Sep 17 00:00:00 2001
From: DrSmoothl <1787882683@qq.com>
Date: Sun, 22 Mar 2026 00:22:24 +0800
Subject: [PATCH 25/45] =?UTF-8?q?refactor:=20=E4=BC=98=E5=8C=96=E6=95=B0?=
=?UTF-8?q?=E6=8D=AE=E5=BA=93=E6=93=8D=E4=BD=9C=E5=92=8C=E6=A8=A1=E5=9E=8B?=
=?UTF-8?q?=E5=AE=9A=E4=B9=89=EF=BC=8C=E5=A2=9E=E5=BC=BA=E8=A1=A8=E8=BE=BE?=
=?UTF-8?q?=E6=96=B9=E5=BC=8F=E5=92=8C=E9=BB=91=E8=AF=9D=E8=A1=A8=E7=9A=84?=
=?UTF-8?q?=E6=8F=92=E5=85=A5=E9=80=BB=E8=BE=91?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
pytests/common_test/test_expression_schema.py | 78 +++++++++++++++++
pytests/common_test/test_jargon_schema.py | 84 +++++++++++++++++++
src/common/database/database_model.py | 15 ++--
src/learners/expression_learner.py | 17 +++-
src/learners/jargon_miner.py | 18 ++--
5 files changed, 195 insertions(+), 17 deletions(-)
create mode 100644 pytests/common_test/test_expression_schema.py
create mode 100644 pytests/common_test/test_jargon_schema.py
diff --git a/pytests/common_test/test_expression_schema.py b/pytests/common_test/test_expression_schema.py
new file mode 100644
index 00000000..31fcd98f
--- /dev/null
+++ b/pytests/common_test/test_expression_schema.py
@@ -0,0 +1,78 @@
+"""测试表达方式表结构和基础插入行为。"""
+
+from typing import Generator
+
+import pytest
+from sqlalchemy.pool import StaticPool
+from sqlmodel import Session, SQLModel, create_engine
+
+from src.common.database.database_model import Expression
+
+
+@pytest.fixture(name="expression_engine")
+def expression_engine_fixture() -> Generator:
+ """创建仅用于表达方式表测试的内存数据库引擎。
+
+ Yields:
+ Generator: 供测试使用的 SQLite 内存引擎。
+ """
+ engine = create_engine(
+ "sqlite://",
+ connect_args={"check_same_thread": False},
+ poolclass=StaticPool,
+ )
+ SQLModel.metadata.create_all(engine)
+ yield engine
+
+
+def test_expression_insert_assigns_auto_increment_id(expression_engine) -> None:
+ """表达方式表在新库中应能自动分配自增主键。"""
+ with Session(expression_engine) as session:
+ expression = Expression(
+ situation="表达情绪高涨或生理反应",
+ style="发送💦表情符号",
+ content_list='["表达情绪高涨或生理反应"]',
+ count=1,
+ session_id="session-a",
+ checked=False,
+ rejected=False,
+ )
+ session.add(expression)
+ session.commit()
+ session.refresh(expression)
+
+ assert expression.id is not None
+ assert expression.id > 0
+
+
+def test_expression_insert_allows_same_situation_style(expression_engine) -> None:
+ """相同情景和风格的表达方式记录不应再被错误绑定到复合主键。"""
+ with Session(expression_engine) as session:
+ first_expression = Expression(
+ situation="对重复行为的默契响应",
+ style="持续性跟发相同内容",
+ content_list='["对重复行为的默契响应"]',
+ count=1,
+ session_id="session-a",
+ checked=False,
+ rejected=False,
+ )
+ second_expression = Expression(
+ situation="对重复行为的默契响应",
+ style="持续性跟发相同内容",
+ content_list='["对重复行为的默契响应-变体"]',
+ count=2,
+ session_id="session-b",
+ checked=False,
+ rejected=False,
+ )
+
+ session.add(first_expression)
+ session.add(second_expression)
+ session.commit()
+ session.refresh(first_expression)
+ session.refresh(second_expression)
+
+ assert first_expression.id is not None
+ assert second_expression.id is not None
+ assert first_expression.id != second_expression.id
diff --git a/pytests/common_test/test_jargon_schema.py b/pytests/common_test/test_jargon_schema.py
new file mode 100644
index 00000000..909392ab
--- /dev/null
+++ b/pytests/common_test/test_jargon_schema.py
@@ -0,0 +1,84 @@
+"""测试黑话表结构和基础插入行为。"""
+
+from typing import Generator
+
+import pytest
+from sqlalchemy.pool import StaticPool
+from sqlmodel import Session, SQLModel, create_engine
+
+from src.common.database.database_model import Jargon
+
+
+@pytest.fixture(name="jargon_engine")
+def jargon_engine_fixture() -> Generator:
+ """创建仅用于黑话表测试的内存数据库引擎。
+
+ Yields:
+ Generator: 供测试使用的 SQLite 内存引擎。
+ """
+ engine = create_engine(
+ "sqlite://",
+ connect_args={"check_same_thread": False},
+ poolclass=StaticPool,
+ )
+ SQLModel.metadata.create_all(engine)
+ yield engine
+
+
+def test_jargon_insert_assigns_auto_increment_id(jargon_engine) -> None:
+ """黑话表在新库中应能自动分配自增主键。"""
+ with Session(jargon_engine) as session:
+ jargon = Jargon(
+ content="VF8V4L",
+ raw_content='["[1] test"]',
+ meaning="",
+ session_id_dict='{"session-a": 1}',
+ count=1,
+ is_jargon=True,
+ is_complete=False,
+ is_global=True,
+ last_inference_count=0,
+ )
+ session.add(jargon)
+ session.commit()
+ session.refresh(jargon)
+
+ assert jargon.id is not None
+ assert jargon.id > 0
+
+
+def test_jargon_insert_allows_same_content_with_different_rows(jargon_engine) -> None:
+ """黑话内容不应再被错误地绑成复合主键的一部分。"""
+ with Session(jargon_engine) as session:
+ first_jargon = Jargon(
+ content="表情1",
+ raw_content='["[1] first"]',
+ meaning="",
+ session_id_dict='{"session-a": 1}',
+ count=1,
+ is_jargon=True,
+ is_complete=False,
+ is_global=False,
+ last_inference_count=0,
+ )
+ second_jargon = Jargon(
+ content="表情1",
+ raw_content='["[1] second"]',
+ meaning="",
+ session_id_dict='{"session-b": 1}',
+ count=1,
+ is_jargon=True,
+ is_complete=False,
+ is_global=False,
+ last_inference_count=0,
+ )
+
+ session.add(first_jargon)
+ session.add(second_jargon)
+ session.commit()
+ session.refresh(first_jargon)
+ session.refresh(second_jargon)
+
+ assert first_jargon.id is not None
+ assert second_jargon.id is not None
+ assert first_jargon.id != second_jargon.id
diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py
index a0993a77..5b274c43 100644
--- a/src/common/database/database_model.py
+++ b/src/common/database/database_model.py
@@ -1,8 +1,9 @@
-from typing import Optional
-from sqlalchemy import Column, Float, Enum as SQLEnum, DateTime
-from sqlmodel import SQLModel, Field, LargeBinary
-from enum import Enum
from datetime import datetime
+from enum import Enum
+from typing import Optional
+
+from sqlalchemy import Column, DateTime, Enum as SQLEnum, Float
+from sqlmodel import Field, LargeBinary, SQLModel
class ModelUser(str, Enum):
@@ -172,8 +173,8 @@ class Expression(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True) # 自增主键
- situation: str = Field(index=True, max_length=255, primary_key=True) # 情景
- style: str = Field(index=True, max_length=255, primary_key=True) # 风格
+ situation: str = Field(index=True, max_length=255) # 情景
+ style: str = Field(index=True, max_length=255) # 风格
# context: str # 上下文
# up_content: str
@@ -200,7 +201,7 @@ class Jargon(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True) # 自增主键
- content: str = Field(index=True, max_length=255, primary_key=True) # 黑话内容
+ content: str = Field(index=True, max_length=255) # 黑话内容
raw_content: Optional[str] = Field(default=None, nullable=True) # 原始内容,未处理的黑话内容,为List[str]
meaning: str # 黑话含义
diff --git a/src/learners/expression_learner.py b/src/learners/expression_learner.py
index 43e4ee7d..156fedc5 100644
--- a/src/learners/expression_learner.py
+++ b/src/learners/expression_learner.py
@@ -329,7 +329,13 @@ class ExpressionLearner:
return filtered_expressions
# ====== DB 操作相关 ======
- async def _upsert_expression_to_db(self, situation: str, style: str):
+ async def _upsert_expression_to_db(self, situation: str, style: str) -> None:
+ """将表达方式写入数据库,存在时更新,不存在时新增。
+
+ Args:
+ situation: 表达方式对应的使用情景。
+ style: 表达方式风格。
+ """
expr, similarity = self._find_similar_expression(situation) or (None, 0)
if expr:
# 根据相似度决定是否使用 LLM 总结
@@ -340,7 +346,13 @@ class ExpressionLearner:
# 没有找到匹配的记录,创建新记录
self._create_expression(situation, style)
- def _create_expression(self, situation: str, style: str):
+ def _create_expression(self, situation: str, style: str) -> None:
+ """创建新的表达方式记录。
+
+ Args:
+ situation: 表达方式对应的使用情景。
+ style: 表达方式风格。
+ """
content_list = [situation]
try:
with get_db_session() as db:
@@ -353,6 +365,7 @@ class ExpressionLearner:
last_active_time=datetime.now(),
)
db.add(new_expr)
+ db.flush()
except Exception as e:
logger.error(f"创建表达方式失败: {e}")
diff --git a/src/learners/jargon_miner.py b/src/learners/jargon_miner.py
index 2fbf8a2e..674e5cc0 100644
--- a/src/learners/jargon_miner.py
+++ b/src/learners/jargon_miner.py
@@ -1,17 +1,18 @@
from collections import OrderedDict
-from json_repair import repair_json
-from sqlmodel import select
-from typing import List, Optional, Dict, Callable, TypedDict, Set
+from typing import Callable, Dict, List, Optional, Set, TypedDict
import asyncio
import json
import random
-from src.common.logger import get_logger
+from json_repair import repair_json
+from sqlmodel import select
+
+from src.common.data_models.jargon_data_model import MaiJargon
from src.common.database.database import get_db_session
from src.common.database.database_model import Jargon
-from src.common.data_models.jargon_data_model import MaiJargon
-from src.config.config import model_config, global_config
+from src.common.logger import get_logger
+from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.prompt.prompt_manager import prompt_manager
@@ -273,11 +274,12 @@ class JargonMiner:
try:
with get_db_session() as session:
session.add(new_jargon)
+ session.flush()
+ saved += 1
+ self._add_to_cache(content)
except Exception as e:
logger.error(f"保存新黑话 '{content}' 失败: {e}")
continue
- finally:
- self._add_to_cache(content)
# 固定输出提取的jargon结果,格式化为可读形式(只要有提取结果就输出)
if uniq_entries:
# 收集所有提取的jargon内容
From 89df7ccf6b213916aa787290e8c3f3ac97a4fbb0 Mon Sep 17 00:00:00 2001
From: DrSmoothl <1787882683@qq.com>
Date: Sun, 22 Mar 2026 00:43:34 +0800
Subject: [PATCH 26/45] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=20NapCat=20?=
=?UTF-8?q?=E9=80=82=E9=85=8D=E5=99=A8=E7=9A=84=E5=85=A5=E7=AB=99=E6=B6=88?=
=?UTF-8?q?=E6=81=AF=E7=BC=96=E8=A7=A3=E7=A0=81=E5=8A=9F=E8=83=BD=EF=BC=8C?=
=?UTF-8?q?=E5=A2=9E=E5=BC=BA=E6=8F=92=E4=BB=B6=E9=85=8D=E7=BD=AE=E6=9B=B4?=
=?UTF-8?q?=E6=96=B0=E9=80=BB=E8=BE=91=E5=92=8C=E6=95=B0=E6=8D=AE=E5=BA=93?=
=?UTF-8?q?=E4=BA=A4=E4=BA=92=E6=B5=8B=E8=AF=95?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../common_test/test_expression_learner.py | 81 ++++++++++++++
pytests/common_test/test_jargon_miner.py | 90 +++++++++++++++
pytests/test_napcat_adapter_codec.py | 81 ++++++++++++++
pytests/test_napcat_adapter_plugin.py | 60 ++++++++++
pytests/test_plugin_runtime.py | 35 ++++--
src/learners/expression_learner.py | 44 +++++---
src/learners/jargon_miner.py | 24 +++-
src/plugin_runtime/integration.py | 26 ++++-
.../built_in/napcat_adapter/codec_inbound.py | 104 +++++++++++++++++-
src/plugins/built_in/napcat_adapter/plugin.py | 1 +
10 files changed, 511 insertions(+), 35 deletions(-)
create mode 100644 pytests/common_test/test_expression_learner.py
create mode 100644 pytests/common_test/test_jargon_miner.py
create mode 100644 pytests/test_napcat_adapter_plugin.py
diff --git a/pytests/common_test/test_expression_learner.py b/pytests/common_test/test_expression_learner.py
new file mode 100644
index 00000000..951aa424
--- /dev/null
+++ b/pytests/common_test/test_expression_learner.py
@@ -0,0 +1,81 @@
+"""测试表达方式学习器的数据库读取行为。"""
+
+from contextlib import contextmanager
+from typing import Generator
+
+import pytest
+from sqlalchemy.pool import StaticPool
+from sqlmodel import Session, SQLModel, create_engine
+
+from src.bw_learner.expression_learner import ExpressionLearner
+from src.common.database.database_model import Expression
+
+
+@pytest.fixture(name="expression_learner_engine")
+def expression_learner_engine_fixture() -> Generator:
+ """创建用于表达方式学习器测试的内存数据库引擎。
+
+ Yields:
+ Generator: 供测试使用的 SQLite 内存引擎。
+ """
+ engine = create_engine(
+ "sqlite://",
+ connect_args={"check_same_thread": False},
+ poolclass=StaticPool,
+ )
+ SQLModel.metadata.create_all(engine)
+ yield engine
+
+
+def test_find_similar_expression_uses_read_only_session_and_history_content(
+ monkeypatch: pytest.MonkeyPatch,
+ expression_learner_engine,
+) -> None:
+ """查找相似表达方式时,应能在离开会话后安全使用结果,并比较历史情景内容。"""
+ import src.bw_learner.expression_learner as expression_learner_module
+
+ with Session(expression_learner_engine) as session:
+ session.add(
+ Expression(
+ situation="发送汗滴表情",
+ style="发送💦表情符号",
+ content_list='["表达情绪高涨或生理反应"]',
+ count=1,
+ session_id="session-a",
+ checked=False,
+ rejected=False,
+ )
+ )
+ session.commit()
+
+ @contextmanager
+ def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
+ """构造带自动提交语义的测试会话工厂。
+
+ Args:
+ auto_commit: 退出上下文时是否自动提交。
+
+ Yields:
+ Generator[Session, None, None]: SQLModel 会话对象。
+ """
+ session = Session(expression_learner_engine)
+ try:
+ yield session
+ if auto_commit:
+ session.commit()
+ except Exception:
+ session.rollback()
+ raise
+ finally:
+ session.close()
+
+ monkeypatch.setattr(expression_learner_module, "get_db_session", fake_get_db_session)
+
+ learner = ExpressionLearner(session_id="session-a")
+ result = learner._find_similar_expression("表达情绪高涨或生理反应")
+
+ assert result is not None
+ expression, similarity = result
+ assert expression.item_id is not None
+ assert expression.style == "发送💦表情符号"
+ assert similarity == pytest.approx(1.0)
diff --git a/pytests/common_test/test_jargon_miner.py b/pytests/common_test/test_jargon_miner.py
new file mode 100644
index 00000000..bf81e4d2
--- /dev/null
+++ b/pytests/common_test/test_jargon_miner.py
@@ -0,0 +1,90 @@
+"""测试黑话学习器的数据库读取行为。"""
+
+from contextlib import contextmanager
+from typing import Generator
+
+import pytest
+from sqlalchemy.pool import StaticPool
+from sqlmodel import Session, SQLModel, create_engine, select
+
+from src.bw_learner.jargon_miner import JargonMiner
+from src.common.database.database_model import Jargon
+
+
+@pytest.fixture(name="jargon_miner_engine")
+def jargon_miner_engine_fixture() -> Generator:
+ """创建用于黑话学习器测试的内存数据库引擎。
+
+ Yields:
+ Generator: 供测试使用的 SQLite 内存引擎。
+ """
+ engine = create_engine(
+ "sqlite://",
+ connect_args={"check_same_thread": False},
+ poolclass=StaticPool,
+ )
+ SQLModel.metadata.create_all(engine)
+ yield engine
+
+
+@pytest.mark.asyncio
+async def test_process_extracted_entries_updates_existing_jargon_without_detached_session(
+ monkeypatch: pytest.MonkeyPatch,
+ jargon_miner_engine,
+) -> None:
+ """更新已有黑话时,不应因会话关闭导致 ORM 实例失效。"""
+ import src.bw_learner.jargon_miner as jargon_miner_module
+
+ with Session(jargon_miner_engine) as session:
+ session.add(
+ Jargon(
+ content="VF8V4L",
+ raw_content='["[1] first"]',
+ meaning="",
+ session_id_dict='{"session-a": 1}',
+ count=0,
+ is_jargon=True,
+ is_complete=False,
+ is_global=False,
+ last_inference_count=0,
+ )
+ )
+ session.commit()
+
+ @contextmanager
+ def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
+ """构造带自动提交语义的测试会话工厂。
+
+ Args:
+ auto_commit: 退出上下文时是否自动提交。
+
+ Yields:
+ Generator[Session, None, None]: SQLModel 会话对象。
+ """
+ session = Session(jargon_miner_engine)
+ try:
+ yield session
+ if auto_commit:
+ session.commit()
+ except Exception:
+ session.rollback()
+ raise
+ finally:
+ session.close()
+
+ monkeypatch.setattr(jargon_miner_module, "get_db_session", fake_get_db_session)
+
+ jargon_miner = JargonMiner(session_id="session-a", session_name="测试群")
+ await jargon_miner.process_extracted_entries(
+ [{"content": "VF8V4L", "raw_content": {"[2] second"}}],
+ )
+
+ with Session(jargon_miner_engine) as session:
+ db_jargon = session.exec(select(Jargon).where(Jargon.content == "VF8V4L")).one()
+
+ assert db_jargon.count == 1
+ assert db_jargon.session_id_dict == '{"session-a": 2}'
+ assert sorted(db_jargon.raw_content and __import__("json").loads(db_jargon.raw_content)) == [
+ "[1] first",
+ "[2] second",
+ ]
diff --git a/pytests/test_napcat_adapter_codec.py b/pytests/test_napcat_adapter_codec.py
index 6f557e08..97ed1d9e 100644
--- a/pytests/test_napcat_adapter_codec.py
+++ b/pytests/test_napcat_adapter_codec.py
@@ -3,12 +3,16 @@ from typing import Any, Dict
import importlib
import sys
+from types import SimpleNamespace
+
+import pytest
BUILT_IN_PLUGIN_ROOT = Path(__file__).resolve().parents[1] / "src" / "plugins" / "built_in"
if str(BUILT_IN_PLUGIN_ROOT) not in sys.path:
sys.path.insert(0, str(BUILT_IN_PLUGIN_ROOT))
+NapCatInboundCodec = importlib.import_module("napcat_adapter.codec_inbound").NapCatInboundCodec
NapCatOutboundCodec = importlib.import_module("napcat_adapter.codec_outbound").NapCatOutboundCodec
@@ -68,3 +72,80 @@ def test_napcat_outbound_codec_builds_private_action_from_route_metadata() -> No
assert action_name == "send_private_msg"
assert params == {"message": [{"type": "text", "data": {"text": "hello"}}], "user_id": "30001"}
+
+
+class DummyQueryService:
+ """用于测试的轻量查询服务。"""
+
+ async def download_binary(self, url: str) -> bytes:
+ """返回固定图片二进制。
+
+ Args:
+ url: 图片地址。
+
+ Returns:
+ bytes: 固定测试图片二进制。
+ """
+ if url:
+ return b"image-bytes"
+ return b""
+
+ async def get_message_detail(self, message_id: str) -> Dict[str, Any] | None:
+ """返回空消息详情。
+
+ Args:
+ message_id: 目标消息 ID。
+
+ Returns:
+ Dict[str, Any] | None: 固定空结果。
+ """
+ del message_id
+ return None
+
+ async def get_record_detail(self, file_name: str, file_id: str | None = None) -> Dict[str, Any] | None:
+ """返回空语音详情。
+
+ Args:
+ file_name: 语音文件名。
+ file_id: 可选文件 ID。
+
+ Returns:
+ Dict[str, Any] | None: 固定空结果。
+ """
+ del file_name
+ del file_id
+ return None
+
+ async def get_forward_message(self, message_id: str) -> Dict[str, Any] | None:
+ """返回空转发详情。
+
+ Args:
+ message_id: 转发消息 ID。
+
+ Returns:
+ Dict[str, Any] | None: 固定空结果。
+ """
+ del message_id
+ return None
+
+
+@pytest.mark.asyncio
+async def test_napcat_inbound_codec_parses_cq_string_image_segments() -> None:
+ codec = NapCatInboundCodec(SimpleNamespace(debug=lambda message: None), DummyQueryService())
+ payload = {
+ "message": "[CQ:image,file=test.png,sub_type=0,url=https://example.com/test.png][CQ:at,qq=10001] 看到是国人直接给你封了",
+ }
+
+ raw_message, is_at = await codec.convert_segments(payload, "10001")
+
+ assert raw_message[0]["type"] == "image"
+ assert raw_message[1] == {
+ "type": "at",
+ "data": {
+ "target_user_id": "10001",
+ "target_user_nickname": None,
+ "target_user_cardname": None,
+ },
+ }
+ assert raw_message[2] == {"type": "text", "data": " 看到是国人直接给你封了"}
+ assert is_at is True
diff --git a/pytests/test_napcat_adapter_plugin.py b/pytests/test_napcat_adapter_plugin.py
new file mode 100644
index 00000000..ca550a39
--- /dev/null
+++ b/pytests/test_napcat_adapter_plugin.py
@@ -0,0 +1,60 @@
+"""NapCat 插件入口行为测试。"""
+
+from pathlib import Path
+from typing import List
+from types import SimpleNamespace
+
+import importlib
+import sys
+
+import pytest
+
+
+BUILT_IN_PLUGIN_ROOT = Path(__file__).resolve().parents[1] / "src" / "plugins" / "built_in"
+if str(BUILT_IN_PLUGIN_ROOT) not in sys.path:
+ sys.path.insert(0, str(BUILT_IN_PLUGIN_ROOT))
+
+NapCatAdapterPlugin = importlib.import_module("napcat_adapter.plugin").NapCatAdapterPlugin
+
+
+class DummyLogger:
+ """用于测试的轻量日志对象。"""
+
+ def __init__(self) -> None:
+ """初始化测试日志对象。"""
+ self.debug_messages: List[str] = []
+
+ def debug(self, message: str) -> None:
+ """记录调试日志。
+
+ Args:
+ message: 待记录的日志内容。
+ """
+ self.debug_messages.append(message)
+
+
+@pytest.mark.asyncio
+async def test_on_config_update_refreshes_settings_and_restarts(monkeypatch: pytest.MonkeyPatch) -> None:
+ """配置更新时应刷新插件配置、清空旧 settings,并触发连接重启。"""
+ plugin = NapCatAdapterPlugin()
+ plugin._ctx = SimpleNamespace(logger=DummyLogger())
+ plugin._settings = object()
+
+ restart_calls: List[dict] = []
+
+ async def fake_restart() -> None:
+ """记录一次重启调用。"""
+ restart_calls.append(dict(plugin._plugin_config))
+
+ monkeypatch.setattr(plugin, "_restart_connection_if_needed", fake_restart)
+
+ new_config = {
+ "plugin": {"enabled": True, "config_version": "0.1.0"},
+ "napcat_server": {"host": "127.0.0.1", "port": 3001},
+ }
+ await plugin.on_config_update(new_config, "v2")
+
+ assert plugin._plugin_config == new_config
+ assert plugin._settings is None
+ assert restart_calls == [new_config]
+ assert plugin.ctx.logger.debug_messages == ["NapCat 适配器收到配置更新通知: v2"]
diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py
index 5ab16c85..20cceb82 100644
--- a/pytests/test_plugin_runtime.py
+++ b/pytests/test_plugin_runtime.py
@@ -2238,6 +2238,7 @@ class TestIntegration:
async def test_handle_plugin_source_changes_only_reload_matching_supervisor(self, monkeypatch, tmp_path):
from src.config.file_watcher import FileChange
from src.plugin_runtime import integration as integration_module
+ import json
builtin_root = tmp_path / "src" / "plugins" / "built_in"
thirdparty_root = tmp_path / "plugins"
@@ -2247,6 +2248,10 @@ class TestIntegration:
beta_dir.mkdir(parents=True)
(alpha_dir / "config.toml").write_text("enabled = true\n", encoding="utf-8")
(beta_dir / "config.toml").write_text("enabled = false\n", encoding="utf-8")
+ (alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
+ (beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
+ (alpha_dir / "_manifest.json").write_text(json.dumps({"name": "alpha"}), encoding="utf-8")
+ (beta_dir / "_manifest.json").write_text(json.dumps({"name": "beta"}), encoding="utf-8")
monkeypatch.chdir(tmp_path)
@@ -2257,8 +2262,8 @@ class TestIntegration:
self.reload_reasons = []
self.config_updates = []
- async def reload_plugins(self, reason="manual"):
- self.reload_reasons.append(reason)
+ async def reload_plugins(self, plugin_ids=None, reason="manual"):
+ self.reload_reasons.append((plugin_ids, reason))
async def notify_plugin_config_updated(self, plugin_id, config_data, config_version=""):
self.config_updates.append((plugin_id, config_data, config_version))
@@ -2283,13 +2288,13 @@ class TestIntegration:
await manager._handle_plugin_source_changes(changes)
assert manager._builtin_supervisor.reload_reasons == []
- assert manager._third_party_supervisor.reload_reasons == ["file_watcher"]
+ assert manager._third_party_supervisor.reload_reasons == [(["beta"], "file_watcher")]
assert manager._builtin_supervisor.config_updates == []
assert manager._third_party_supervisor.config_updates == []
assert refresh_calls == [True]
@pytest.mark.asyncio
- async def test_handle_plugin_config_changes_only_notify_target_plugin(self, monkeypatch, tmp_path):
+ async def test_handle_plugin_config_changes_only_reload_target_plugin(self, monkeypatch, tmp_path):
from src.plugin_runtime import integration as integration_module
from src.config.file_watcher import FileChange
@@ -2308,27 +2313,35 @@ class TestIntegration:
def __init__(self, plugin_dirs, plugins):
self._plugin_dirs = plugin_dirs
self._registered_plugins = {plugin_id: object() for plugin_id in plugins}
- self.config_updates = []
+ self.reload_calls = []
- async def notify_plugin_config_updated(self, plugin_id, config_data, config_version=""):
- self.config_updates.append((plugin_id, config_data, config_version))
+ async def reload_plugin(self, plugin_id, reason="manual"):
+ self.reload_calls.append((plugin_id, reason))
return True
manager = integration_module.PluginRuntimeManager()
manager._started = True
manager._builtin_supervisor = FakeSupervisor([builtin_root], ["alpha"])
manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["beta"])
+ refresh_calls = []
+
+ def fake_refresh() -> None:
+ refresh_calls.append(True)
+
+ manager._refresh_plugin_config_watch_subscriptions = fake_refresh
await manager._handle_plugin_config_changes(
"alpha",
[FileChange(change_type=1, path=alpha_dir / "config.toml")],
)
- assert manager._builtin_supervisor.config_updates == [("alpha", {"enabled": True}, "")]
- assert manager._third_party_supervisor.config_updates == []
+ assert manager._builtin_supervisor.reload_calls == [("alpha", "config_file_changed")]
+ assert manager._third_party_supervisor.reload_calls == []
+ assert refresh_calls == [True]
def test_refresh_plugin_config_watch_subscriptions_registers_per_plugin(self, tmp_path):
from src.plugin_runtime import integration as integration_module
+ import json
builtin_root = tmp_path / "src" / "plugins" / "built_in"
thirdparty_root = tmp_path / "plugins"
@@ -2336,6 +2349,10 @@ class TestIntegration:
beta_dir = thirdparty_root / "beta"
alpha_dir.mkdir(parents=True)
beta_dir.mkdir(parents=True)
+ (alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
+ (beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
+ (alpha_dir / "_manifest.json").write_text(json.dumps({"name": "alpha"}), encoding="utf-8")
+ (beta_dir / "_manifest.json").write_text(json.dumps({"name": "beta"}), encoding="utf-8")
class FakeWatcher:
def __init__(self):
diff --git a/src/learners/expression_learner.py b/src/learners/expression_learner.py
index 156fedc5..b82ae1fa 100644
--- a/src/learners/expression_learner.py
+++ b/src/learners/expression_learner.py
@@ -461,25 +461,43 @@ class ExpressionLearner:
def _find_similar_expression(
self, situation: str, similarity_threshold: float = 0.75
) -> Optional[Tuple[MaiExpression, float]]:
- """在数据库中查找相似的表达方式"""
+ """在数据库中查找相似的表达方式。
+
+ Args:
+ situation: 当前待匹配的情景描述。
+ similarity_threshold: 认定为相似表达方式的最低相似度阈值。
+
+ Returns:
+ Optional[Tuple[MaiExpression, float]]: 若找到最相似的表达方式,则返回
+ ``(表达方式对象, 相似度)``;否则返回 ``None``。
+ """
try:
- with get_db_session() as session:
+ with get_db_session(auto_commit=False) as session:
statement = select(Expression).filter_by(session_id=self.session_id)
expressions = session.exec(statement).all()
- best_match: Optional[Expression] = None
- best_similarity = 0.0
+ best_match: Optional[MaiExpression] = None
+ best_similarity = 0.0
+
+ for db_expression in expressions:
+ expression = MaiExpression.from_db_instance(db_expression)
+ candidate_situations = [expression.situation, *expression.content]
+ for candidate_situation in candidate_situations:
+ normalized_candidate_situation = candidate_situation.strip()
+ if not normalized_candidate_situation:
+ continue
+ similarity = difflib.SequenceMatcher(
+ None,
+ situation,
+ normalized_candidate_situation,
+ ).ratio()
+ if similarity > similarity_threshold and similarity > best_similarity:
+ best_similarity = similarity
+ best_match = expression
- for expr in expressions:
- content_list = json.loads(expr.content_list)
- for situation in content_list:
- similarity = difflib.SequenceMatcher(None, situation, expr.situation).ratio()
- if similarity > similarity_threshold and similarity > best_similarity:
- best_similarity = similarity
- best_match = expr
if best_match:
- logger.debug(f"找到相似表达方式情景 [ID: {best_match.id}],相似度: {best_similarity:.2f}")
- return MaiExpression.from_db_instance(best_match), best_similarity
+ logger.debug(f"找到相似表达方式情景 [ID: {best_match.item_id}],相似度: {best_similarity:.2f}")
+ return best_match, best_similarity
except Exception as e:
logger.error(f"查找相似表达方式失败: {e}")
diff --git a/src/learners/jargon_miner.py b/src/learners/jargon_miner.py
index 674e5cc0..32926894 100644
--- a/src/learners/jargon_miner.py
+++ b/src/learners/jargon_miner.py
@@ -199,7 +199,7 @@ class JargonMiner:
async def process_extracted_entries(
self, entries: List[JargonEntry], person_name_filter: Optional[Callable[[str], bool]] = None
- ):
+ ) -> None:
"""
处理已提取的黑话条目(从 expression_learner 路由过来的)
@@ -230,7 +230,7 @@ class JargonMiner:
content = entry["content"]
raw_content_set = entry["raw_content"]
try:
- with get_db_session() as session:
+ with get_db_session(auto_commit=False) as session:
jargon_items = session.exec(select(Jargon).filter_by(content=content)).all()
except Exception as e:
logger.error(f"查询黑话 '{content}' 失败: {e}")
@@ -306,7 +306,13 @@ class JargonMiner:
removed_content, _ = self.cache.popitem(last=False)
logger.debug(f"缓存已满,移除最旧的黑话: {removed_content}")
- def _update_jargon(self, db_jargon: Jargon, raw_content_set: Set[str]):
+ def _update_jargon(self, db_jargon: Jargon, raw_content_set: Set[str]) -> None:
+ """更新已有黑话记录并写回数据库。
+
+ Args:
+ db_jargon: 已命中的黑话 ORM 对象。
+ raw_content_set: 本次新增的原始上下文集合。
+ """
db_jargon.count += 1
existing_raw_content: List[str] = []
if db_jargon.raw_content:
@@ -328,7 +334,17 @@ class JargonMiner:
try:
with get_db_session() as session:
- session.add(db_jargon)
+ if db_jargon.id is None:
+ raise ValueError("黑话记录缺少 id,无法更新数据库")
+ statement = select(Jargon).filter_by(id=db_jargon.id).limit(1)
+ if persisted_jargon := session.exec(statement).first():
+ persisted_jargon.count = db_jargon.count
+ persisted_jargon.raw_content = db_jargon.raw_content
+ persisted_jargon.session_id_dict = db_jargon.session_id_dict
+ persisted_jargon.is_global = db_jargon.is_global
+ session.add(persisted_jargon)
+ else:
+ logger.warning(f"黑话 ID {db_jargon.id} 在数据库中未找到,无法更新")
except Exception as e:
logger.error(f"更新黑话 '{db_jargon.content}' 失败: {e}")
diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py
index 24cf09fc..bf85669b 100644
--- a/src/plugin_runtime/integration.py
+++ b/src/plugin_runtime/integration.py
@@ -612,7 +612,17 @@ class PluginRuntimeManager(
return None if plugin_path is None else plugin_path / "config.toml"
async def _handle_plugin_config_changes(self, plugin_id: str, changes: Sequence[FileChange]) -> None:
- """处理单个插件配置文件变化,并仅向目标插件推送配置更新。"""
+ """处理单个插件配置文件变化,并精确重载目标插件。
+
+ Args:
+ plugin_id: 发生配置变更的插件 ID。
+ changes: 当前批次收集到的配置文件变更列表。
+
+ Notes:
+ 这里选择“精确重载该插件”,而不是仅推送软性的配置更新通知。
+ 这样可以保证没有实现 ``on_config_update()`` 的插件也能重新执行
+ ``on_load()``,让磁盘上的 ``config.toml`` 修改对插件运行态真正生效。
+ """
if not self._started or not changes:
return
@@ -626,18 +636,24 @@ class PluginRuntimeManager(
return
try:
- await supervisor.notify_plugin_config_updated(
+ self._load_plugin_config_for_supervisor(supervisor, plugin_id)
+ reload_success = await supervisor.reload_plugin(
plugin_id=plugin_id,
- config_data=self._load_plugin_config_for_supervisor(supervisor, plugin_id),
+ reason="config_file_changed",
)
+ if reload_success:
+ self._refresh_plugin_config_watch_subscriptions()
+ else:
+ logger.warning(f"插件 {plugin_id} 配置文件变更后重载失败")
except Exception as exc:
- logger.warning(f"插件 {plugin_id} 配置热更新通知失败: {exc}")
+ logger.warning(f"插件 {plugin_id} 配置文件变更处理失败: {exc}")
async def _handle_plugin_source_changes(self, changes: Sequence[FileChange]) -> None:
"""处理插件源码相关变化。
这里仅负责源码、清单等会影响插件装载状态的文件;配置文件的变化会由
- 单独的 per-plugin watcher 处理,避免把单插件配置更新放大成全量 reload。
+ 单独的 per-plugin watcher 处理,并精确重载对应插件,避免放大成
+ 不必要的跨插件 reload。
"""
if not self._started or not changes:
return
diff --git a/src/plugins/built_in/napcat_adapter/codec_inbound.py b/src/plugins/built_in/napcat_adapter/codec_inbound.py
index b8065585..8fb020dc 100644
--- a/src/plugins/built_in/napcat_adapter/codec_inbound.py
+++ b/src/plugins/built_in/napcat_adapter/codec_inbound.py
@@ -5,11 +5,15 @@ from uuid import uuid4
import hashlib
import json
+import re
import time
from napcat_adapter.qq_queries import NapCatQueryService
+_CQ_SEGMENT_PATTERN = re.compile(r"\[CQ:(?P[a-zA-Z0-9_]+)(?P(?:,[^\]]*)?)\]")
+
+
class NapCatInboundCodec:
"""NapCat 入站消息编码器。"""
@@ -104,8 +108,12 @@ class NapCatInboundCodec:
"""
message_payload = payload.get("message")
if isinstance(message_payload, str):
- normalized_text = message_payload.strip()
- return ([{"type": "text", "data": normalized_text}] if normalized_text else []), False
+ parsed_message_payload = self._parse_cq_message_text(message_payload)
+ if parsed_message_payload:
+ message_payload = parsed_message_payload
+ else:
+ normalized_text = self._decode_cq_entities(message_payload).strip()
+ return ([{"type": "text", "data": normalized_text}] if normalized_text else []), False
if not isinstance(message_payload, list):
return [], False
@@ -223,8 +231,8 @@ class NapCatInboundCodec:
Returns:
Dict[str, Any]: 转换后的图片或表情消息段。
"""
- subtype = segment_data.get("sub_type")
- actual_is_emoji = is_emoji or (isinstance(subtype, int) and subtype not in {0, 4, 9})
+ subtype = self._normalize_numeric_segment_value(segment_data.get("sub_type"))
+ actual_is_emoji = is_emoji or (subtype is not None and subtype not in {0, 4, 9})
image_url = str(segment_data.get("url") or "").strip()
binary_data = await self._query_service.download_binary(image_url)
@@ -412,3 +420,91 @@ class NapCatInboundCodec:
plain_text = "".join(part for part in plain_text_parts if part).strip()
return plain_text or fallback_text or "[unsupported]"
+
+ def _parse_cq_message_text(self, message_text: str) -> List[Dict[str, Any]]:
+ """将 CQ 码字符串解析为 OneBot 风格消息段列表。
+
+ Args:
+ message_text: NapCat 在字符串模式下返回的消息内容。
+
+ Returns:
+ List[Dict[str, Any]]: 解析后的 OneBot 风格消息段列表。
+ """
+ parsed_segments: List[Dict[str, Any]] = []
+ current_index = 0
+
+ for match in _CQ_SEGMENT_PATTERN.finditer(message_text):
+ prefix_text = self._decode_cq_entities(message_text[current_index : match.start()])
+ if prefix_text:
+ parsed_segments.append({"type": "text", "data": {"text": prefix_text}})
+
+ segment_type = str(match.group("type") or "").strip()
+ segment_data = self._parse_cq_segment_data(match.group("params") or "")
+ if segment_type:
+ parsed_segments.append({"type": segment_type, "data": segment_data})
+ current_index = match.end()
+
+ suffix_text = self._decode_cq_entities(message_text[current_index:])
+ if suffix_text:
+ parsed_segments.append({"type": "text", "data": {"text": suffix_text}})
+
+ return parsed_segments
+
+ def _parse_cq_segment_data(self, raw_params: str) -> Dict[str, Any]:
+ """解析单个 CQ 段中的参数串。
+
+ Args:
+ raw_params: 形如 ``,key=value,key2=value2`` 的原始参数字符串。
+
+ Returns:
+ Dict[str, Any]: 解析后的参数字典。
+ """
+ parsed_data: Dict[str, Any] = {}
+ if not raw_params:
+ return parsed_data
+
+ for item in raw_params.lstrip(",").split(","):
+ if not item or "=" not in item:
+ continue
+ key, value = item.split("=", 1)
+ normalized_key = key.strip()
+ if not normalized_key:
+ continue
+ decoded_value = self._decode_cq_entities(value)
+ parsed_data[normalized_key] = self._normalize_numeric_segment_value(decoded_value)
+
+ return parsed_data
+
+ @staticmethod
+ def _decode_cq_entities(text: str) -> str:
+ """解码 CQ 码中的 HTML 风格转义实体。
+
+ Args:
+ text: 待解码的 CQ 文本。
+
+ Returns:
+ str: 解码后的普通文本。
+ """
+ return (
+ text.replace("&", "&")
+ .replace("[", "[")
+ .replace("]", "]")
+ .replace(",", ",")
+ )
+
+ @staticmethod
+ def _normalize_numeric_segment_value(value: Any) -> Any:
+ """将可安全识别的数字字符串转为整数。
+
+ Args:
+ value: 原始字段值。
+
+ Returns:
+ Any: 规范化后的字段值。
+ """
+ if isinstance(value, str):
+ stripped_value = value.strip()
+ if stripped_value.isdigit():
+ return int(stripped_value)
+ return stripped_value
+ return value
diff --git a/src/plugins/built_in/napcat_adapter/plugin.py b/src/plugins/built_in/napcat_adapter/plugin.py
index b1e9bc8c..50900c5d 100644
--- a/src/plugins/built_in/napcat_adapter/plugin.py
+++ b/src/plugins/built_in/napcat_adapter/plugin.py
@@ -71,6 +71,7 @@ class NapCatAdapterPlugin(MaiBotPlugin):
version: 配置版本号。
"""
self.set_plugin_config(new_config)
+ self._settings = None
if version:
self.ctx.logger.debug(f"NapCat 适配器收到配置更新通知: {version}")
await self._restart_connection_if_needed()
From a0c653de4532cdba8e9a9d99e85b9433c8a87b46 Mon Sep 17 00:00:00 2001
From: DrSmoothl <1787882683@qq.com>
Date: Sun, 22 Mar 2026 01:04:29 +0800
Subject: [PATCH 27/45] =?UTF-8?q?docs:=20=E6=9B=B4=E6=96=B0=E6=B3=A8?=
=?UTF-8?q?=E9=87=8A=E8=A7=84=E8=8C=83=E5=92=8C=E8=AF=AD=E8=A8=80=E8=A7=84?=
=?UTF-8?q?=E8=8C=83=EF=BC=8C=E5=BC=BA=E8=B0=83=E4=BD=BF=E7=94=A8=20Google?=
=?UTF-8?q?=20DocStr=20=E6=A0=BC=E5=BC=8F=E5=92=8C=E7=AE=80=E4=BD=93?=
=?UTF-8?q?=E4=B8=AD=E6=96=87?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
AGENTS.md | 5 +++++
1 file changed, 5 insertions(+)
diff --git a/AGENTS.md b/AGENTS.md
index b4caaaf1..b3456610 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -17,6 +17,7 @@
1. 尽量保持良好的注释
2. 如果原来的代码中有注释,则重构的时候,除非这部分代码被删除,否则相同功能的代码应该保留注释(可以对注释进行修改以保持准确性,但不应该删除注释)。
3. 如果原来的代码中没有注释,则重构的时候,如果某个功能块的代码较长或者逻辑较为复杂,则应该添加注释来解释这部分代码的功能和逻辑。
+4. 对于类,方法以及模块的注释,首选使用的注释格式为 Google DocStr 格式,但保证语言为简体中文
## 类型注解规范
1. 重构代码时,如果原来的代码中有类型注解,则相同功能的代码应该保留类型注解(可以对类型注解进行修改以保持准确性,但不应该删除类型注解)。
2. 重构代码时,如果原来的代码中没有类型注解,则重构的时候,如果某个函数的功能较为复杂或者参数较多,则应该添加类型注解来提高代码的可读性和可维护性。(对于简单的变量,可以不添加类型注解)
@@ -35,3 +36,7 @@
# 运行/调试/构建/测试/依赖
优先使用uv
依赖项以 pyproject.toml 为准
+
+# 语言规范
+
+项目的首选语言为简体中文,无论是注释语言,日志展示语言,还是 WebUI 展示语言都应该首要以简体中文为首要实现目标
From 0066224251a644697f8c1ebeede83b565fdedfbf Mon Sep 17 00:00:00 2001
From: UnCLAS-Prommer
Date: Sun, 22 Mar 2026 12:50:09 +0800
Subject: [PATCH 28/45] fix: remove nc ada
---
.../built_in/napcat_adapter/__init__.py | 1 -
.../built_in/napcat_adapter/_manifest.json | 30 --
.../built_in/napcat_adapter/codec_inbound.py | 510 ------------------
.../built_in/napcat_adapter/codec_outbound.py | 192 -------
src/plugins/built_in/napcat_adapter/config.py | 398 --------------
.../built_in/napcat_adapter/constants.py | 9 -
.../built_in/napcat_adapter/filters.py | 68 ---
src/plugins/built_in/napcat_adapter/plugin.py | 381 -------------
.../built_in/napcat_adapter/qq_notice.py | 224 --------
.../built_in/napcat_adapter/qq_queries.py | 170 ------
.../built_in/napcat_adapter/runtime_state.py | 85 ---
.../built_in/napcat_adapter/transport.py | 322 -----------
12 files changed, 2390 deletions(-)
delete mode 100644 src/plugins/built_in/napcat_adapter/__init__.py
delete mode 100644 src/plugins/built_in/napcat_adapter/_manifest.json
delete mode 100644 src/plugins/built_in/napcat_adapter/codec_inbound.py
delete mode 100644 src/plugins/built_in/napcat_adapter/codec_outbound.py
delete mode 100644 src/plugins/built_in/napcat_adapter/config.py
delete mode 100644 src/plugins/built_in/napcat_adapter/constants.py
delete mode 100644 src/plugins/built_in/napcat_adapter/filters.py
delete mode 100644 src/plugins/built_in/napcat_adapter/plugin.py
delete mode 100644 src/plugins/built_in/napcat_adapter/qq_notice.py
delete mode 100644 src/plugins/built_in/napcat_adapter/qq_queries.py
delete mode 100644 src/plugins/built_in/napcat_adapter/runtime_state.py
delete mode 100644 src/plugins/built_in/napcat_adapter/transport.py
diff --git a/src/plugins/built_in/napcat_adapter/__init__.py b/src/plugins/built_in/napcat_adapter/__init__.py
deleted file mode 100644
index fa82860f..00000000
--- a/src/plugins/built_in/napcat_adapter/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-"""NapCat 内置适配器插件包。"""
diff --git a/src/plugins/built_in/napcat_adapter/_manifest.json b/src/plugins/built_in/napcat_adapter/_manifest.json
deleted file mode 100644
index 6f7e68fd..00000000
--- a/src/plugins/built_in/napcat_adapter/_manifest.json
+++ /dev/null
@@ -1,30 +0,0 @@
-{
- "manifest_version": 1,
- "name": "napcat_adapter_builtin",
- "version": "0.1.0",
- "description": "Built-in NapCat adapter plugin for MVP message forwarding.",
- "author": {
- "name": "OpenAI Codex"
- },
- "license": "GPL-v3.0-or-later",
- "host_application": {
- "min_version": "1.0.0"
- },
- "keywords": [
- "adapter",
- "built-in",
- "napcat",
- "onebot",
- "qq"
- ],
- "categories": [
- "Adapter",
- "Built-in"
- ],
- "default_locale": "en-US",
- "plugin_info": {
- "is_built_in": true,
- "plugin_type": "adapter"
- },
- "capabilities": []
-}
diff --git a/src/plugins/built_in/napcat_adapter/codec_inbound.py b/src/plugins/built_in/napcat_adapter/codec_inbound.py
deleted file mode 100644
index 8fb020dc..00000000
--- a/src/plugins/built_in/napcat_adapter/codec_inbound.py
+++ /dev/null
@@ -1,510 +0,0 @@
-"""NapCat 入站消息编解码。"""
-
-from typing import Any, Dict, List, Mapping, Optional, Tuple
-from uuid import uuid4
-
-import hashlib
-import json
-import re
-import time
-
-from napcat_adapter.qq_queries import NapCatQueryService
-
-
-_CQ_SEGMENT_PATTERN = re.compile(r"\[CQ:(?P[a-zA-Z0-9_]+)(?P(?:,[^\]]*)?)\]")
-
-
-class NapCatInboundCodec:
- """NapCat 入站消息编码器。"""
-
- def __init__(self, logger: Any, query_service: NapCatQueryService) -> None:
- """初始化入站消息编码器。
-
- Args:
- logger: 插件日志对象。
- query_service: QQ 查询服务。
- """
- self._logger = logger
- self._query_service = query_service
-
- async def build_message_dict(
- self,
- payload: Mapping[str, Any],
- self_id: str,
- sender_user_id: str,
- sender: Mapping[str, Any],
- ) -> Dict[str, Any]:
- """构造 Host 侧可接受的 ``MessageDict``。
-
- Args:
- payload: NapCat 原始消息事件。
- self_id: 当前机器人账号 ID。
- sender_user_id: 发送者用户 ID。
- sender: 发送者信息字典。
-
- Returns:
- Dict[str, Any]: 规范化后的 ``MessageDict``。
- """
- message_type = str(payload.get("message_type") or "").strip() or "private"
- group_id = str(payload.get("group_id") or "").strip()
- group_name = str(payload.get("group_name") or "").strip() or (f"group_{group_id}" if group_id else "")
- user_nickname = str(sender.get("nickname") or sender.get("card") or sender_user_id).strip() or sender_user_id
- user_cardname = str(sender.get("card") or "").strip() or None
-
- raw_message, is_at = await self.convert_segments(payload, self_id)
- raw_message_text = str(payload.get("raw_message") or "").strip()
- if not raw_message:
- raw_message = [{"type": "text", "data": raw_message_text or "[unsupported]"}]
-
- plain_text = self.build_plain_text(raw_message, raw_message_text)
- timestamp_seconds = payload.get("time")
- if not isinstance(timestamp_seconds, (int, float)):
- timestamp_seconds = time.time()
-
- additional_config: Dict[str, Any] = {"self_id": self_id, "napcat_message_type": message_type}
- if group_id:
- additional_config["platform_io_target_group_id"] = group_id
- else:
- additional_config["platform_io_target_user_id"] = sender_user_id
-
- message_info: Dict[str, Any] = {
- "user_info": {
- "user_id": sender_user_id,
- "user_nickname": user_nickname,
- "user_cardname": user_cardname,
- },
- "additional_config": additional_config,
- }
- if group_id:
- message_info["group_info"] = {"group_id": group_id, "group_name": group_name}
-
- message_id = str(payload.get("message_id") or f"napcat-{uuid4().hex}").strip()
- return {
- "message_id": message_id,
- "timestamp": str(float(timestamp_seconds)),
- "platform": "qq",
- "message_info": message_info,
- "raw_message": raw_message,
- "is_mentioned": is_at,
- "is_at": is_at,
- "is_emoji": False,
- "is_picture": False,
- "is_command": plain_text.startswith("/"),
- "is_notify": False,
- "session_id": "",
- "processed_plain_text": plain_text,
- "display_message": plain_text,
- }
-
- async def convert_segments(self, payload: Mapping[str, Any], self_id: str) -> Tuple[List[Dict[str, Any]], bool]:
- """将 OneBot 消息段转换为 Host 消息段结构。
-
- Args:
- payload: OneBot 原始消息事件。
- self_id: 当前机器人账号 ID。
-
- Returns:
- Tuple[List[Dict[str, Any]], bool]: 转换后的消息段列表,以及是否 @ 到当前机器人。
- """
- message_payload = payload.get("message")
- if isinstance(message_payload, str):
- parsed_message_payload = self._parse_cq_message_text(message_payload)
- if parsed_message_payload:
- message_payload = parsed_message_payload
- else:
- normalized_text = self._decode_cq_entities(message_payload).strip()
- return ([{"type": "text", "data": normalized_text}] if normalized_text else []), False
-
- if not isinstance(message_payload, list):
- return [], False
-
- converted_segments: List[Dict[str, Any]] = []
- is_at = False
- for segment in message_payload:
- if not isinstance(segment, Mapping):
- continue
-
- segment_type = str(segment.get("type") or "").strip()
- segment_data = segment.get("data", {})
- if not isinstance(segment_data, Mapping):
- segment_data = {}
-
- if segment_type == "text":
- if text_value := str(segment_data.get("text") or ""):
- converted_segments.append({"type": "text", "data": text_value})
- continue
-
- if segment_type == "at":
- if target_user_id := str(segment_data.get("qq") or "").strip():
- converted_segments.append(
- {
- "type": "at",
- "data": {
- "target_user_id": target_user_id,
- "target_user_nickname": None,
- "target_user_cardname": None,
- },
- }
- )
- if self_id and target_user_id == self_id:
- is_at = True
- continue
-
- if segment_type == "reply":
- if reply_segment := await self._build_reply_segment(segment_data):
- converted_segments.append(reply_segment)
- continue
-
- if segment_type == "face":
- converted_segments.append({"type": "text", "data": "[face]"})
- continue
-
- if segment_type == "image":
- converted_segments.append(await self._build_image_like_segment(segment_data, is_emoji=False))
- continue
-
- if segment_type == "record":
- converted_segments.append(await self._build_record_segment(segment_data))
- continue
-
- if segment_type == "video":
- converted_segments.append({"type": "text", "data": "[video]"})
- continue
-
- if segment_type == "file":
- converted_segments.append({"type": "text", "data": "[file]"})
- continue
-
- if segment_type == "json":
- converted_segments.append(self._build_json_text_segment(segment_data))
- continue
-
- if segment_type == "forward":
- if forward_segment := await self._build_forward_segment(segment_data):
- converted_segments.append(forward_segment)
- continue
-
- if segment_type in {"xml", "share"}:
- converted_segments.append({"type": "text", "data": f"[{segment_type}]"})
-
- return converted_segments, is_at
-
- async def _build_reply_segment(self, segment_data: Mapping[str, Any]) -> Optional[Dict[str, Any]]:
- """构造回复消息段。
-
- Args:
- segment_data: OneBot ``reply`` 段的 ``data`` 字典。
-
- Returns:
- Optional[Dict[str, Any]]: 转换后的回复消息段;缺少消息 ID 时返回 ``None``。
- """
- target_message_id = str(segment_data.get("id") or "").strip()
- if not target_message_id:
- return None
-
- message_detail = await self._query_service.get_message_detail(target_message_id)
- reply_payload: Dict[str, Any] = {"target_message_id": target_message_id}
- if message_detail is not None:
- sender = message_detail.get("sender", {})
- if not isinstance(sender, Mapping):
- sender = {}
- reply_payload["target_message_content"] = str(message_detail.get("raw_message") or "").strip() or None
- reply_payload["target_message_sender_id"] = str(
- message_detail.get("user_id") or sender.get("user_id") or ""
- ).strip() or None
- reply_payload["target_message_sender_nickname"] = str(sender.get("nickname") or "").strip() or None
- reply_payload["target_message_sender_cardname"] = str(sender.get("card") or "").strip() or None
-
- return {"type": "reply", "data": reply_payload}
-
- async def _build_image_like_segment(
- self,
- segment_data: Mapping[str, Any],
- is_emoji: bool,
- ) -> Dict[str, Any]:
- """构造图片或表情消息段。
-
- Args:
- segment_data: OneBot ``image`` 段的 ``data`` 字典。
- is_emoji: 是否按表情组件处理。
-
- Returns:
- Dict[str, Any]: 转换后的图片或表情消息段。
- """
- subtype = self._normalize_numeric_segment_value(segment_data.get("sub_type"))
- actual_is_emoji = is_emoji or (subtype is not None and subtype not in {0, 4, 9})
-
- image_url = str(segment_data.get("url") or "").strip()
- binary_data = await self._query_service.download_binary(image_url)
- if not binary_data:
- return {"type": "text", "data": "[emoji]" if actual_is_emoji else "[image]"}
-
- return {
- "type": "emoji" if actual_is_emoji else "image",
- "data": "",
- "hash": hashlib.sha256(binary_data).hexdigest(),
- "binary_data_base64": self._encode_binary(binary_data),
- }
-
- async def _build_record_segment(self, segment_data: Mapping[str, Any]) -> Dict[str, Any]:
- """构造语音消息段。
-
- Args:
- segment_data: OneBot ``record`` 段的 ``data`` 字典。
-
- Returns:
- Dict[str, Any]: 转换后的语音或占位文本消息段。
- """
- file_name = str(segment_data.get("file") or "").strip()
- file_id = str(segment_data.get("file_id") or "").strip() or None
- if not file_name:
- return {"type": "text", "data": "[voice]"}
-
- record_detail = await self._query_service.get_record_detail(file_name=file_name, file_id=file_id)
- if record_detail is None:
- return {"type": "text", "data": "[voice]"}
-
- record_base64 = str(record_detail.get("base64") or "").strip()
- if not record_base64:
- return {"type": "text", "data": "[voice]"}
-
- try:
- binary_data = self._decode_binary(record_base64)
- except Exception:
- return {"type": "text", "data": "[voice]"}
-
- return {
- "type": "voice",
- "data": "",
- "hash": hashlib.sha256(binary_data).hexdigest(),
- "binary_data_base64": self._encode_binary(binary_data),
- }
-
- async def _build_forward_segment(self, segment_data: Mapping[str, Any]) -> Optional[Dict[str, Any]]:
- """构造合并转发消息段。
-
- Args:
- segment_data: OneBot ``forward`` 段的 ``data`` 字典。
-
- Returns:
- Optional[Dict[str, Any]]: 转换后的合并转发消息段;失败时返回 ``None``。
- """
- message_id = str(segment_data.get("id") or "").strip()
- if not message_id:
- return None
-
- forward_detail = await self._query_service.get_forward_message(message_id)
- if forward_detail is None:
- return {"type": "text", "data": "[forward]"}
-
- messages = forward_detail.get("messages", [])
- if not isinstance(messages, list):
- return {"type": "text", "data": "[forward]"}
-
- forward_nodes: List[Dict[str, Any]] = []
- for forward_message in messages:
- if not isinstance(forward_message, Mapping):
- continue
- raw_content = forward_message.get("content", [])
- content_segments = await self._convert_forward_content(raw_content, "")
- sender = forward_message.get("sender", {})
- if not isinstance(sender, Mapping):
- sender = {}
- forward_nodes.append(
- {
- "user_id": str(sender.get("user_id") or sender.get("uin") or "").strip() or None,
- "user_nickname": str(sender.get("nickname") or sender.get("name") or "未知用户"),
- "user_cardname": str(sender.get("card") or "").strip() or None,
- "message_id": str(forward_message.get("message_id") or uuid4().hex),
- "content": content_segments or [{"type": "text", "data": "[empty]"}],
- }
- )
-
- if not forward_nodes:
- return {"type": "text", "data": "[forward]"}
- return {"type": "forward", "data": forward_nodes}
-
- async def _convert_forward_content(self, raw_content: Any, self_id: str) -> List[Dict[str, Any]]:
- """转换转发节点内部的消息段列表。
-
- Args:
- raw_content: 转发节点原始内容。
- self_id: 当前机器人账号 ID。
-
- Returns:
- List[Dict[str, Any]]: 转换后的消息段列表。
- """
- pseudo_payload: Dict[str, Any] = {"message": raw_content}
- segments, _ = await self.convert_segments(pseudo_payload, self_id)
- return segments
-
- def _build_json_text_segment(self, segment_data: Mapping[str, Any]) -> Dict[str, Any]:
- """将 JSON 卡片最佳努力转换为文本占位。
-
- Args:
- segment_data: OneBot ``json`` 段的 ``data`` 字典。
-
- Returns:
- Dict[str, Any]: 转换后的文本消息段。
- """
- json_data = str(segment_data.get("data") or "").strip()
- if not json_data:
- return {"type": "text", "data": "[json]"}
-
- try:
- parsed_json = json.loads(json_data)
- except Exception:
- return {"type": "text", "data": "[json]"}
-
- app_name = str(parsed_json.get("app") or "").strip()
- prompt = ""
- if isinstance(parsed_json.get("meta"), Mapping):
- prompt = str(parsed_json["meta"].get("prompt") or "").strip()
- text = prompt or app_name or "json"
- return {"type": "text", "data": f"[json:{text}]"}
-
- @staticmethod
- def _encode_binary(binary_data: bytes) -> str:
- """将二进制内容编码为 Base64 字符串。
-
- Args:
- binary_data: 待编码的二进制内容。
-
- Returns:
- str: Base64 编码字符串。
- """
- import base64
-
- return base64.b64encode(binary_data).decode("utf-8")
-
- @staticmethod
- def _decode_binary(binary_base64: str) -> bytes:
- """将 Base64 字符串解码为二进制内容。
-
- Args:
- binary_base64: Base64 字符串。
-
- Returns:
- bytes: 解码后的二进制内容。
- """
- import base64
-
- return base64.b64decode(binary_base64)
-
- def build_plain_text(self, raw_message: List[Dict[str, Any]], fallback_text: str) -> str:
- """从标准消息段中提取可展示的纯文本。
-
- Args:
- raw_message: 标准化后的消息段列表。
- fallback_text: 当无法拼出文本时使用的回退文本。
-
- Returns:
- str: 用于 Host 展示和命令判断的纯文本内容。
- """
- plain_text_parts: List[str] = []
- for item in raw_message:
- if not isinstance(item, Mapping):
- continue
- item_type = str(item.get("type") or "").strip()
- item_data = item.get("data")
- if item_type == "text":
- plain_text_parts.append(str(item_data or ""))
- elif item_type == "at" and isinstance(item_data, Mapping):
- plain_text_parts.append(f"@{item_data.get('target_user_id') or ''}")
- elif item_type == "reply":
- plain_text_parts.append("[reply]")
- elif item_type == "forward":
- plain_text_parts.append("[forward]")
- elif item_type in {"image", "emoji", "voice"}:
- plain_text_parts.append(f"[{item_type}]")
-
- plain_text = "".join(part for part in plain_text_parts if part).strip()
- return plain_text or fallback_text or "[unsupported]"
-
- def _parse_cq_message_text(self, message_text: str) -> List[Dict[str, Any]]:
- """将 CQ 码字符串解析为 OneBot 风格消息段列表。
-
- Args:
- message_text: NapCat 在字符串模式下返回的消息内容。
-
- Returns:
- List[Dict[str, Any]]: 解析后的 OneBot 风格消息段列表。
- """
- parsed_segments: List[Dict[str, Any]] = []
- current_index = 0
-
- for match in _CQ_SEGMENT_PATTERN.finditer(message_text):
- prefix_text = self._decode_cq_entities(message_text[current_index : match.start()])
- if prefix_text:
- parsed_segments.append({"type": "text", "data": {"text": prefix_text}})
-
- segment_type = str(match.group("type") or "").strip()
- segment_data = self._parse_cq_segment_data(match.group("params") or "")
- if segment_type:
- parsed_segments.append({"type": segment_type, "data": segment_data})
- current_index = match.end()
-
- suffix_text = self._decode_cq_entities(message_text[current_index:])
- if suffix_text:
- parsed_segments.append({"type": "text", "data": {"text": suffix_text}})
-
- return parsed_segments
-
- def _parse_cq_segment_data(self, raw_params: str) -> Dict[str, Any]:
- """解析单个 CQ 段中的参数串。
-
- Args:
- raw_params: 形如 ``,key=value,key2=value2`` 的原始参数字符串。
-
- Returns:
- Dict[str, Any]: 解析后的参数字典。
- """
- parsed_data: Dict[str, Any] = {}
- if not raw_params:
- return parsed_data
-
- for item in raw_params.lstrip(",").split(","):
- if not item or "=" not in item:
- continue
- key, value = item.split("=", 1)
- normalized_key = key.strip()
- if not normalized_key:
- continue
- decoded_value = self._decode_cq_entities(value)
- parsed_data[normalized_key] = self._normalize_numeric_segment_value(decoded_value)
-
- return parsed_data
-
- @staticmethod
- def _decode_cq_entities(text: str) -> str:
- """解码 CQ 码中的 HTML 风格转义实体。
-
- Args:
- text: 待解码的 CQ 文本。
-
- Returns:
- str: 解码后的普通文本。
- """
- return (
- text.replace("&", "&")
- .replace("[", "[")
- .replace("]", "]")
- .replace(",", ",")
- )
-
- @staticmethod
- def _normalize_numeric_segment_value(value: Any) -> Any:
- """将可安全识别的数字字符串转为整数。
-
- Args:
- value: 原始字段值。
-
- Returns:
- Any: 规范化后的字段值。
- """
- if isinstance(value, str):
- stripped_value = value.strip()
- if stripped_value.isdigit():
- return int(stripped_value)
- return stripped_value
- return value
diff --git a/src/plugins/built_in/napcat_adapter/codec_outbound.py b/src/plugins/built_in/napcat_adapter/codec_outbound.py
deleted file mode 100644
index 6adcb622..00000000
--- a/src/plugins/built_in/napcat_adapter/codec_outbound.py
+++ /dev/null
@@ -1,192 +0,0 @@
-"""NapCat 出站消息编解码。"""
-
-from typing import Any, Dict, List, Mapping, Tuple
-
-
-class NapCatOutboundCodec:
- """NapCat 出站消息编码器。"""
-
- def build_outbound_action(
- self,
- message: Mapping[str, Any],
- route: Mapping[str, Any],
- ) -> Tuple[str, Dict[str, Any]]:
- """为 Host 出站消息构造 OneBot 动作。
-
- Args:
- message: Host 侧标准 ``MessageDict``。
- route: Platform IO 路由信息。
-
- Returns:
- Tuple[str, Dict[str, Any]]: 动作名称与参数字典。
-
- Raises:
- ValueError: 当私聊出站缺少目标用户 ID 时抛出。
- """
- message_info = message.get("message_info", {})
- if not isinstance(message_info, Mapping):
- message_info = {}
-
- group_info = message_info.get("group_info", {})
- if not isinstance(group_info, Mapping):
- group_info = {}
-
- additional_config = message_info.get("additional_config", {})
- if not isinstance(additional_config, Mapping):
- additional_config = {}
-
- raw_message = message.get("raw_message", [])
- segments = self.convert_segments(raw_message)
-
- if target_group_id := str(
- group_info.get("group_id") or additional_config.get("platform_io_target_group_id") or ""
- ).strip():
- return "send_group_msg", {"group_id": target_group_id, "message": segments}
-
- target_user_id = str(
- additional_config.get("platform_io_target_user_id")
- or additional_config.get("target_user_id")
- or route.get("target_user_id")
- or ""
- ).strip()
- if not target_user_id:
- raise ValueError("Outbound private message is missing target_user_id")
-
- return "send_private_msg", {"message": segments, "user_id": target_user_id}
-
- def convert_segments(self, raw_message: Any) -> List[Dict[str, Any]]:
- """将 Host 消息段转换为 OneBot 消息段。
-
- Args:
- raw_message: Host 侧 ``raw_message`` 字段。
-
- Returns:
- List[Dict[str, Any]]: OneBot 消息段列表。
- """
- if not isinstance(raw_message, list):
- return [{"type": "text", "data": {"text": ""}}]
-
- outbound_segments: List[Dict[str, Any]] = []
- for item in raw_message:
- if not isinstance(item, Mapping):
- continue
-
- item_type = str(item.get("type") or "").strip()
- item_data = item.get("data")
-
- if item_type == "text":
- text_value = str(item_data or "")
- outbound_segments.append({"type": "text", "data": {"text": text_value}})
- continue
-
- if item_type == "at" and isinstance(item_data, Mapping):
- if target_user_id := str(item_data.get("target_user_id") or "").strip():
- outbound_segments.append({"type": "at", "data": {"qq": target_user_id}})
- continue
-
- if item_type == "reply":
- if isinstance(item_data, Mapping):
- target_message_id = str(item_data.get("target_message_id") or "").strip()
- else:
- target_message_id = str(item_data or "").strip()
- if target_message_id:
- outbound_segments.append({"type": "reply", "data": {"id": target_message_id}})
- continue
-
- if item_type == "image":
- binary_base64 = str(item.get("binary_data_base64") or "").strip()
- if binary_base64:
- outbound_segments.append(
- {
- "type": "image",
- "data": {"file": f"base64://{binary_base64}", "subtype": 0},
- }
- )
- else:
- outbound_segments.append({"type": "text", "data": {"text": "[image]"}})
- continue
-
- if item_type == "emoji":
- binary_base64 = str(item.get("binary_data_base64") or "").strip()
- if binary_base64:
- outbound_segments.append(
- {
- "type": "image",
- "data": {
- "file": f"base64://{binary_base64}",
- "subtype": 1,
- "summary": "[动画表情]",
- },
- }
- )
- else:
- outbound_segments.append({"type": "text", "data": {"text": "[emoji]"}})
- continue
-
- if item_type == "voice":
- binary_base64 = str(item.get("binary_data_base64") or "").strip()
- if binary_base64:
- outbound_segments.append({"type": "record", "data": {"file": f"base64://{binary_base64}"}})
- else:
- outbound_segments.append({"type": "text", "data": {"text": "[voice]"}})
- continue
-
- if item_type == "forward" and isinstance(item_data, list):
- outbound_segments.extend(self._build_forward_nodes(item_data))
- continue
-
- if item_type == "dict" and isinstance(item_data, Mapping):
- if dict_segment := self._build_dict_component_segment(item_data):
- outbound_segments.append(dict_segment)
- continue
-
- fallback_text = f"[unsupported:{item_type or 'unknown'}]"
- outbound_segments.append({"type": "text", "data": {"text": fallback_text}})
-
- if not outbound_segments:
- outbound_segments.append({"type": "text", "data": {"text": ""}})
- return outbound_segments
-
- def _build_forward_nodes(self, forward_nodes: List[Any]) -> List[Dict[str, Any]]:
- """构造 NapCat 转发节点列表。
-
- Args:
- forward_nodes: 内部转发节点列表。
-
- Returns:
- List[Dict[str, Any]]: NapCat 转发节点列表。
- """
- built_nodes: List[Dict[str, Any]] = []
- for node in forward_nodes:
- if not isinstance(node, Mapping):
- continue
- raw_content = node.get("content", [])
- node_segments = self.convert_segments(raw_content)
- built_nodes.append(
- {
- "type": "node",
- "data": {
- "name": str(node.get("user_nickname") or node.get("user_cardname") or "QQ用户"),
- "uin": str(node.get("user_id") or ""),
- "content": node_segments,
- },
- }
- )
- return built_nodes
-
- def _build_dict_component_segment(self, item_data: Mapping[str, Any]) -> Dict[str, Any]:
- """尽力将 ``DictComponent`` 转换为 NapCat 消息段。
-
- Args:
- item_data: ``DictComponent`` 原始数据。
-
- Returns:
- Dict[str, Any]: NapCat 消息段;不支持时返回占位文本段。
- """
- raw_type = str(item_data.get("type") or "").strip()
- raw_payload = item_data.get("data", item_data)
- if raw_type in {"file", "music", "video", "face"} and isinstance(raw_payload, Mapping):
- return {"type": raw_type, "data": dict(raw_payload)}
- if raw_type in {"image", "record", "reply", "at"} and isinstance(raw_payload, Mapping):
- return {"type": raw_type, "data": dict(raw_payload)}
- return {"type": "text", "data": {"text": f"[unsupported:{raw_type or 'dict'}]"}}
diff --git a/src/plugins/built_in/napcat_adapter/config.py b/src/plugins/built_in/napcat_adapter/config.py
deleted file mode 100644
index eeb4acab..00000000
--- a/src/plugins/built_in/napcat_adapter/config.py
+++ /dev/null
@@ -1,398 +0,0 @@
-"""NapCat 内置适配器配置解析。"""
-
-from dataclasses import dataclass, field
-from typing import Any, Dict, Mapping, Optional, Set, Tuple
-from urllib.parse import urlparse
-
-from napcat_adapter.constants import (
- DEFAULT_ACTION_TIMEOUT_SEC,
- DEFAULT_CHAT_LIST_TYPE,
- DEFAULT_HEARTBEAT_INTERVAL_SEC,
- DEFAULT_NAPCAT_HOST,
- DEFAULT_NAPCAT_PORT,
- DEFAULT_RECONNECT_DELAY_SEC,
- SUPPORTED_CONFIG_VERSION,
-)
-
-
-@dataclass(frozen=True)
-class NapCatPluginOptions:
- """插件级配置。"""
-
- enabled: bool = False
- config_version: str = ""
-
- def should_connect(self) -> bool:
- """判断当前配置下是否应当启动连接。
-
- Returns:
- bool: 若插件连接已启用,则返回 ``True``。
- """
- return self.enabled
-
-
-@dataclass(frozen=True)
-class NapCatServerConfig:
- """NapCat 正向 WebSocket 连接配置。"""
-
- host: str = DEFAULT_NAPCAT_HOST
- port: int = DEFAULT_NAPCAT_PORT
- token: str = ""
- heartbeat_interval: float = DEFAULT_HEARTBEAT_INTERVAL_SEC
- reconnect_delay_sec: float = DEFAULT_RECONNECT_DELAY_SEC
- action_timeout_sec: float = DEFAULT_ACTION_TIMEOUT_SEC
- connection_id: str = ""
-
- def build_ws_url(self) -> str:
- """构造正向 WebSocket 地址。
-
- Returns:
- str: 供适配器作为客户端连接的 NapCat WebSocket 地址。
- """
- return f"ws://{self.host}:{self.port}"
-
-
-@dataclass(frozen=True)
-class NapCatChatConfig:
- """聊天名单配置。"""
-
- group_list_type: str = DEFAULT_CHAT_LIST_TYPE
- group_list: Set[str] = field(default_factory=set)
- private_list_type: str = DEFAULT_CHAT_LIST_TYPE
- private_list: Set[str] = field(default_factory=set)
- ban_user_id: Set[str] = field(default_factory=set)
-
-
-@dataclass(frozen=True)
-class NapCatFilterConfig:
- """消息过滤配置。"""
-
- ignore_self_message: bool = True
-
-
-@dataclass(frozen=True)
-class NapCatPluginSettings:
- """NapCat 插件完整配置。"""
-
- plugin: NapCatPluginOptions = field(default_factory=NapCatPluginOptions)
- napcat_server: NapCatServerConfig = field(default_factory=NapCatServerConfig)
- chat: NapCatChatConfig = field(default_factory=NapCatChatConfig)
- filters: NapCatFilterConfig = field(default_factory=NapCatFilterConfig)
-
- @classmethod
- def from_mapping(cls, raw_config: Mapping[str, Any], logger: Any) -> "NapCatPluginSettings":
- """从 Runner 注入的原始配置字典解析插件配置。
-
- Args:
- raw_config: Runner 注入的原始配置内容。
- logger: 插件日志对象。
-
- Returns:
- NapCatPluginSettings: 规范化后的插件配置。
- """
- plugin_section = _as_mapping(raw_config.get("plugin"))
- server_section = _as_mapping(raw_config.get("napcat_server"))
- legacy_connection_section = _as_mapping(raw_config.get("connection"))
- chat_section = _as_mapping(raw_config.get("chat"))
- filters_section = _as_mapping(raw_config.get("filters"))
-
- if not server_section and legacy_connection_section:
- logger.warning("NapCat 适配器检测到旧版 [connection] 配置段,请尽快迁移到 [napcat_server]")
- server_section = legacy_connection_section
-
- legacy_host, legacy_port = _read_legacy_host_port(server_section, legacy_connection_section, logger)
- parsed_host = _read_string(server_section, "host") or legacy_host or DEFAULT_NAPCAT_HOST
- parsed_port = _read_positive_int(
- mapping=server_section,
- key="port",
- default=legacy_port or DEFAULT_NAPCAT_PORT,
- logger=logger,
- setting_name="napcat_server.port",
- )
-
- return cls(
- plugin=NapCatPluginOptions(
- enabled=_read_bool(plugin_section, "enabled", False),
- config_version=_read_string(plugin_section, "config_version"),
- ),
- napcat_server=NapCatServerConfig(
- host=parsed_host,
- port=parsed_port,
- token=_read_string(server_section, "token") or _read_string(server_section, "access_token"),
- heartbeat_interval=_read_positive_float(
- mapping=server_section,
- key="heartbeat_interval",
- default=_read_positive_float(
- mapping=server_section,
- key="heartbeat_sec",
- default=DEFAULT_HEARTBEAT_INTERVAL_SEC,
- logger=logger,
- setting_name="napcat_server.heartbeat_interval",
- ),
- logger=logger,
- setting_name="napcat_server.heartbeat_interval",
- ),
- reconnect_delay_sec=_read_positive_float(
- mapping=server_section,
- key="reconnect_delay_sec",
- default=DEFAULT_RECONNECT_DELAY_SEC,
- logger=logger,
- setting_name="napcat_server.reconnect_delay_sec",
- ),
- action_timeout_sec=_read_positive_float(
- mapping=server_section,
- key="action_timeout_sec",
- default=DEFAULT_ACTION_TIMEOUT_SEC,
- logger=logger,
- setting_name="napcat_server.action_timeout_sec",
- ),
- connection_id=_read_string(server_section, "connection_id"),
- ),
- chat=NapCatChatConfig(
- group_list_type=_read_list_mode(
- mapping=chat_section,
- key="group_list_type",
- default=DEFAULT_CHAT_LIST_TYPE,
- logger=logger,
- setting_name="chat.group_list_type",
- ),
- group_list=_read_string_set(chat_section, "group_list"),
- private_list_type=_read_list_mode(
- mapping=chat_section,
- key="private_list_type",
- default=DEFAULT_CHAT_LIST_TYPE,
- logger=logger,
- setting_name="chat.private_list_type",
- ),
- private_list=_read_string_set(chat_section, "private_list"),
- ban_user_id=_read_string_set(chat_section, "ban_user_id"),
- ),
- filters=NapCatFilterConfig(
- ignore_self_message=_read_bool(filters_section, "ignore_self_message", True),
- ),
- )
-
- def should_connect(self) -> bool:
- """判断当前配置下是否应当启动连接。
-
- Returns:
- bool: 若插件连接已启用,则返回 ``True``。
- """
- return self.plugin.should_connect()
-
- def validate(self, logger: Any) -> bool:
- """校验当前配置是否满足启动连接的前提条件。
-
- Args:
- logger: 插件日志对象。
-
- Returns:
- bool: 若配置满足启动连接的前提条件,则返回 ``True``。
- """
- config_version = self.plugin.config_version
- if not config_version:
- logger.error(
- f"NapCat 适配器配置缺少 plugin.config_version,当前插件要求版本 {SUPPORTED_CONFIG_VERSION}"
- )
- return False
-
- if config_version != SUPPORTED_CONFIG_VERSION:
- logger.error(
- "NapCat 适配器配置版本不兼容: "
- f"当前为 {config_version},当前插件要求 {SUPPORTED_CONFIG_VERSION}"
- )
- return False
-
- if not self.napcat_server.host:
- logger.warning("NapCat 适配器已启用,但 napcat_server.host 为空")
- return False
-
- if self.napcat_server.port <= 0:
- logger.warning("NapCat 适配器已启用,但 napcat_server.port 不是正整数")
- return False
-
- return True
-
-
-def _as_mapping(value: Any) -> Dict[str, Any]:
- """将任意值安全转换为字典。
-
- Args:
- value: 待转换的值。
-
- Returns:
- Dict[str, Any]: 若原值是映射,则返回普通字典;否则返回空字典。
- """
- return dict(value) if isinstance(value, Mapping) else {}
-
-
-def _read_bool(mapping: Mapping[str, Any], key: str, default: bool) -> bool:
- """安全读取布尔配置值。
-
- Args:
- mapping: 待读取的配置字典。
- key: 目标键名。
- default: 读取失败时的默认值。
-
- Returns:
- bool: 解析后的布尔值。
- """
- value = mapping.get(key, default)
- return value if isinstance(value, bool) else default
-
-
-def _read_string(mapping: Mapping[str, Any], key: str) -> str:
- """安全读取字符串配置值。
-
- Args:
- mapping: 待读取的配置字典。
- key: 目标键名。
-
- Returns:
- str: 去除首尾空白后的字符串值。
- """
- value = mapping.get(key)
- return "" if value is None else str(value).strip()
-
-
-def _read_positive_float(
- mapping: Mapping[str, Any],
- key: str,
- default: float,
- logger: Any,
- setting_name: str,
-) -> float:
- """安全读取正浮点数配置值。
-
- Args:
- mapping: 待读取的配置字典。
- key: 目标键名。
- default: 读取失败时的默认值。
- logger: 插件日志对象。
- setting_name: 用于日志输出的完整配置名。
-
- Returns:
- float: 合法的正浮点数;否则返回默认值。
- """
- value = mapping.get(key, default)
- if isinstance(value, (int, float)) and float(value) > 0:
- return float(value)
-
- if key in mapping:
- logger.warning(f"NapCat 适配器配置项取值无效,已回退到默认值: {setting_name}={value!r},默认值为 {default}")
- return default
-
-
-def _read_positive_int(
- mapping: Mapping[str, Any],
- key: str,
- default: int,
- logger: Any,
- setting_name: str,
-) -> int:
- """安全读取正整数配置值。
-
- Args:
- mapping: 待读取的配置字典。
- key: 目标键名。
- default: 读取失败时的默认值。
- logger: 插件日志对象。
- setting_name: 用于日志输出的完整配置名。
-
- Returns:
- int: 合法的正整数;否则返回默认值。
- """
- value = mapping.get(key, default)
- if isinstance(value, int) and value > 0:
- return value
-
- if isinstance(value, str) and value.isdigit() and int(value) > 0:
- return int(value)
-
- if key in mapping:
- logger.warning(f"NapCat 适配器配置项取值无效,已回退到默认值: {setting_name}={value!r},默认值为 {default}")
- return default
-
-
-def _read_list_mode(
- mapping: Mapping[str, Any],
- key: str,
- default: str,
- logger: Any,
- setting_name: str,
-) -> str:
- """安全读取名单模式配置值。
-
- Args:
- mapping: 待读取的配置字典。
- key: 目标键名。
- default: 读取失败时的默认值。
- logger: 插件日志对象。
- setting_name: 用于日志输出的完整配置名。
-
- Returns:
- str: 合法的名单模式字符串。
- """
- value = mapping.get(key, default)
- if isinstance(value, str):
- normalized_value = value.strip()
- if normalized_value in {"whitelist", "blacklist"}:
- return normalized_value
-
- if key in mapping:
- logger.warning(f"NapCat 适配器配置项取值无效,已回退到默认值: {setting_name}={value!r},默认值为 {default}")
- return default
-
-
-def _read_string_set(mapping: Mapping[str, Any], key: str) -> Set[str]:
- """安全读取字符串集合配置值。
-
- Args:
- mapping: 待读取的配置字典。
- key: 目标键名。
-
- Returns:
- Set[str]: 规范化后的字符串集合。
- """
- value = mapping.get(key, [])
- if not isinstance(value, list):
- return set()
-
- normalized_values: Set[str] = set()
- for item in value:
- item_text = "" if item is None else str(item).strip()
- if item_text:
- normalized_values.add(item_text)
- return normalized_values
-
-
-def _read_legacy_host_port(
- server_section: Mapping[str, Any],
- legacy_connection_section: Mapping[str, Any],
- logger: Any,
-) -> Tuple[str, Optional[int]]:
- """从旧版 ``ws_url`` 配置中提取主机与端口。
-
- Args:
- server_section: 新版 ``napcat_server`` 配置段。
- legacy_connection_section: 旧版 ``connection`` 配置段。
- logger: 插件日志对象。
-
- Returns:
- Tuple[str, Optional[int]]: 解析到的主机与端口;若未找到,则返回空主机与 ``None``。
- """
- legacy_ws_url = _read_string(server_section, "ws_url") or _read_string(legacy_connection_section, "ws_url")
- if not legacy_ws_url:
- return "", None
-
- parsed_url = urlparse(legacy_ws_url)
- parsed_host = parsed_url.hostname or ""
- parsed_port = parsed_url.port
-
- logger.warning(
- "NapCat 适配器检测到旧版 ws_url 配置,已临时兼容解析,请尽快迁移到 napcat_server.host/port"
- )
- if parsed_url.path not in {"", "/"}:
- logger.warning("NapCat 适配器旧版 ws_url 包含路径,新的 napcat_server 配置不会保留该路径")
-
- return parsed_host, parsed_port
diff --git a/src/plugins/built_in/napcat_adapter/constants.py b/src/plugins/built_in/napcat_adapter/constants.py
deleted file mode 100644
index bdddde6f..00000000
--- a/src/plugins/built_in/napcat_adapter/constants.py
+++ /dev/null
@@ -1,9 +0,0 @@
-"""NapCat 内置适配器共享常量。"""
-
-SUPPORTED_CONFIG_VERSION = "0.1.0"
-DEFAULT_NAPCAT_HOST = "127.0.0.1"
-DEFAULT_NAPCAT_PORT = 3001
-DEFAULT_RECONNECT_DELAY_SEC = 5.0
-DEFAULT_HEARTBEAT_INTERVAL_SEC = 30.0
-DEFAULT_ACTION_TIMEOUT_SEC = 15.0
-DEFAULT_CHAT_LIST_TYPE = "whitelist"
diff --git a/src/plugins/built_in/napcat_adapter/filters.py b/src/plugins/built_in/napcat_adapter/filters.py
deleted file mode 100644
index 141cda85..00000000
--- a/src/plugins/built_in/napcat_adapter/filters.py
+++ /dev/null
@@ -1,68 +0,0 @@
-"""NapCat 入站消息过滤。"""
-
-from typing import Any, Set
-
-from napcat_adapter.config import NapCatChatConfig
-
-
-class NapCatChatFilter:
- """NapCat 聊天名单过滤器。"""
-
- def __init__(self, logger: Any) -> None:
- """初始化聊天名单过滤器。
-
- Args:
- logger: 插件日志对象。
- """
- self._logger = logger
-
- def is_inbound_chat_allowed(
- self,
- sender_user_id: str,
- group_id: str,
- chat_config: NapCatChatConfig,
- ) -> bool:
- """检查入站消息是否通过聊天名单过滤。
-
- Args:
- sender_user_id: 发送者用户 ID。
- group_id: 群聊 ID;私聊时为空字符串。
- chat_config: 当前生效的聊天配置。
-
- Returns:
- bool: 若消息允许继续进入 Host,则返回 ``True``。
- """
- if sender_user_id in chat_config.ban_user_id:
- self._logger.warning(f"NapCat 用户 {sender_user_id} 在全局禁止名单中,消息被丢弃")
- return False
-
- if group_id:
- if not self._is_id_allowed_by_list_policy(group_id, chat_config.group_list_type, chat_config.group_list):
- self._logger.warning(f"NapCat 群聊 {group_id} 未通过聊天名单过滤,消息被丢弃")
- return False
- return True
-
- if not self._is_id_allowed_by_list_policy(
- sender_user_id,
- chat_config.private_list_type,
- chat_config.private_list,
- ):
- self._logger.warning(f"NapCat 私聊用户 {sender_user_id} 未通过聊天名单过滤,消息被丢弃")
- return False
- return True
-
- @staticmethod
- def _is_id_allowed_by_list_policy(target_id: str, list_type: str, configured_ids: Set[str]) -> bool:
- """根据白名单或黑名单规则判断目标 ID 是否允许通过。
-
- Args:
- target_id: 待检查的目标 ID。
- list_type: 名单模式,仅支持 ``whitelist`` 或 ``blacklist``。
- configured_ids: 配置中的 ID 集合。
-
- Returns:
- bool: 若目标 ID 允许通过,则返回 ``True``。
- """
- if list_type == "whitelist":
- return target_id in configured_ids
- return target_id not in configured_ids
diff --git a/src/plugins/built_in/napcat_adapter/plugin.py b/src/plugins/built_in/napcat_adapter/plugin.py
deleted file mode 100644
index 50900c5d..00000000
--- a/src/plugins/built_in/napcat_adapter/plugin.py
+++ /dev/null
@@ -1,381 +0,0 @@
-"""内置 NapCat 适配器插件。
-
-当前实现维持 MVP 范围,目标是跑通基础消息收发链路:
-1. 作为客户端连接 NapCat / OneBot v11 WebSocket 服务。
-2. 将入站消息事件转换为 Host 侧的 ``MessageDict``。
-3. 将 Host 出站消息转换为 OneBot 动作并发送。
-
-当前范围刻意收敛为:
-- 单连接
-- 文本、@、reply 基础转发
-- 暂不处理 ``notice`` / ``meta_event`` 的完整语义归一化
-- 暂不支持图片、语音、文件等复杂媒体
-"""
-
-from __future__ import annotations
-
-from typing import Any, Dict, Mapping, Optional
-
-import asyncio
-
-from maibot_sdk import Adapter, MaiBotPlugin
-
-from napcat_adapter.codec_inbound import NapCatInboundCodec
-from napcat_adapter.codec_outbound import NapCatOutboundCodec
-from napcat_adapter.config import NapCatPluginSettings
-from napcat_adapter.filters import NapCatChatFilter
-from napcat_adapter.qq_notice import NapCatNoticeCodec
-from napcat_adapter.qq_queries import NapCatQueryService
-from napcat_adapter.runtime_state import NapCatRuntimeStateManager
-from napcat_adapter.transport import NapCatTransportClient
-
-
-@Adapter(platform="qq", protocol="napcat", send_method="send_to_platform")
-class NapCatAdapterPlugin(MaiBotPlugin):
- """NapCat 适配器 MVP 实现。"""
-
- def __init__(self) -> None:
- """初始化 NapCat 适配器插件实例。"""
- super().__init__()
- self._plugin_config: Dict[str, Any] = {}
- self._settings: Optional[NapCatPluginSettings] = None
- self._inbound_codec: Optional[NapCatInboundCodec] = None
- self._outbound_codec = NapCatOutboundCodec()
- self._chat_filter: Optional[NapCatChatFilter] = None
- self._query_service: Optional[NapCatQueryService] = None
- self._notice_codec: Optional[NapCatNoticeCodec] = None
- self._runtime_state: Optional[NapCatRuntimeStateManager] = None
- self._transport: Optional[NapCatTransportClient] = None
-
- def set_plugin_config(self, config: Dict[str, Any]) -> None:
- """设置插件配置内容。
-
- Args:
- config: Runner 注入的 ``config.toml`` 解析结果。
- """
- self._plugin_config = config if isinstance(config, dict) else {}
-
- async def on_load(self) -> None:
- """在插件加载时根据配置决定是否启动连接。"""
- await self._restart_connection_if_needed()
-
- async def on_unload(self) -> None:
- """在插件卸载时关闭连接并清理运行时状态。"""
- await self._stop_connection()
-
- async def on_config_update(self, new_config: Dict[str, Any], version: str) -> None:
- """在配置更新后重载连接状态。
-
- Args:
- new_config: 最新的插件配置。
- version: 配置版本号。
- """
- self.set_plugin_config(new_config)
- self._settings = None
- if version:
- self.ctx.logger.debug(f"NapCat 适配器收到配置更新通知: {version}")
- await self._restart_connection_if_needed()
-
- async def send_to_platform(
- self,
- message: Dict[str, Any],
- route: Optional[Dict[str, Any]] = None,
- metadata: Optional[Dict[str, Any]] = None,
- **kwargs: Any,
- ) -> Dict[str, Any]:
- """将 Host 出站消息发送到 NapCat。
-
- Args:
- message: Host 侧标准 ``MessageDict``。
- route: Platform IO 生成的路由信息。
- metadata: Platform IO 附带的投递元数据。
- **kwargs: 预留的扩展参数。
-
- Returns:
- Dict[str, Any]: 标准化后的发送结果。
- """
- del metadata
- del kwargs
-
- self._ensure_runtime_components()
- transport = self._transport
- if transport is None:
- return {"success": False, "error": "NapCat transport is not initialized"}
-
- try:
- action_name, params = self._outbound_codec.build_outbound_action(message, route or {})
- response = await transport.call_action(action_name, params)
- except Exception as exc:
- return {"success": False, "error": str(exc)}
-
- if str(response.get("status", "")).lower() != "ok":
- return {
- "success": False,
- "error": str(response.get("wording") or response.get("message") or "NapCat send failed"),
- "metadata": {"retcode": response.get("retcode")},
- }
-
- response_data = response.get("data", {})
- external_message_id = ""
- if isinstance(response_data, Mapping):
- external_message_id = str(response_data.get("message_id") or "")
-
- return {
- "success": True,
- "external_message_id": external_message_id or None,
- "metadata": {"action": action_name},
- }
-
- def _ensure_runtime_components(self) -> None:
- """确保运行时依赖对象已经完成初始化。"""
- if self._chat_filter is None:
- self._chat_filter = NapCatChatFilter(self.ctx.logger)
-
- if self._transport is None:
- self._transport = NapCatTransportClient(
- logger=self.ctx.logger,
- on_connection_opened=self._bootstrap_adapter_runtime_state,
- on_connection_closed=self._handle_transport_disconnected,
- on_payload=self._handle_transport_payload,
- )
-
- if self._query_service is None:
- self._query_service = NapCatQueryService(self.ctx.logger, self._transport)
-
- if self._inbound_codec is None:
- self._inbound_codec = NapCatInboundCodec(self.ctx.logger, self._query_service)
-
- if self._notice_codec is None:
- self._notice_codec = NapCatNoticeCodec(self.ctx.logger, self._query_service)
-
- if self._runtime_state is None:
- self._runtime_state = NapCatRuntimeStateManager(self.ctx.adapter, self.ctx.logger)
-
- def _reload_settings(self) -> NapCatPluginSettings:
- """重新解析当前插件配置。
-
- Returns:
- NapCatPluginSettings: 最新的规范化配置。
- """
- self._settings = NapCatPluginSettings.from_mapping(self._plugin_config, self.ctx.logger)
- return self._settings
-
- async def _restart_connection_if_needed(self) -> None:
- """根据当前配置重启连接循环。"""
- self._ensure_runtime_components()
- settings = self._reload_settings()
-
- await self._stop_connection()
- if not settings.should_connect():
- self.ctx.logger.info("NapCat 适配器保持空闲状态,因为插件或配置未启用")
- return
- if not settings.validate(self.ctx.logger):
- return
-
- transport = self._transport
- assert transport is not None
- if not transport.is_available():
- self.ctx.logger.error("NapCat 适配器依赖 aiohttp,但当前环境未安装该依赖")
- return
-
- transport.configure(settings.napcat_server)
- await transport.start()
-
- async def _stop_connection(self) -> None:
- """停止当前连接。"""
- transport = self._transport
- if transport is not None:
- await transport.stop()
- return
-
- runtime_state = self._runtime_state
- if runtime_state is not None:
- await runtime_state.report_disconnected()
-
- async def _handle_transport_payload(self, payload: Dict[str, Any]) -> None:
- """处理来自传输层的非 echo 载荷。
-
- Args:
- payload: NapCat 推送的原始事件数据。
- """
- post_type = str(payload.get("post_type") or "").strip()
- if post_type == "message":
- await self._handle_inbound_message(payload)
- return
- if post_type == "notice":
- await self._handle_notice_event(payload)
- return
- if post_type == "meta_event":
- await self._handle_meta_event(payload)
-
- async def _handle_inbound_message(self, payload: Dict[str, Any]) -> None:
- """处理单条 NapCat 入站消息并注入 Host。
-
- Args:
- payload: NapCat / OneBot 推送的原始消息事件。
- """
- self._ensure_runtime_components()
- settings = self._settings or self._reload_settings()
- chat_filter = self._chat_filter
- inbound_codec = self._inbound_codec
- runtime_state = self._runtime_state
- assert chat_filter is not None
- assert inbound_codec is not None
- assert runtime_state is not None
-
- self_id = str(payload.get("self_id") or "").strip()
- if self_id:
- await runtime_state.report_connected(self_id, settings.napcat_server)
-
- sender = payload.get("sender", {})
- if not isinstance(sender, Mapping):
- sender = {}
-
- sender_user_id = str(payload.get("user_id") or sender.get("user_id") or "").strip()
- if not sender_user_id:
- return
-
- group_id = str(payload.get("group_id") or "").strip()
- if self_id and sender_user_id == self_id and settings.filters.ignore_self_message:
- return
- if not chat_filter.is_inbound_chat_allowed(sender_user_id, group_id, settings.chat):
- return
-
- message_dict = await inbound_codec.build_message_dict(payload, self_id, sender_user_id, sender)
- route_metadata: Dict[str, Any] = {}
- if self_id:
- route_metadata["self_id"] = self_id
- if settings.napcat_server.connection_id:
- route_metadata["connection_id"] = settings.napcat_server.connection_id
-
- external_message_id = str(payload.get("message_id") or "").strip()
- accepted = await self.ctx.adapter.receive_external_message(
- message_dict,
- route_metadata=route_metadata,
- external_message_id=external_message_id,
- dedupe_key=external_message_id,
- )
- if not accepted:
- self.ctx.logger.debug(f"Host 丢弃了 NapCat 入站消息: {external_message_id or '无消息 ID'}")
-
- async def _handle_notice_event(self, payload: Dict[str, Any]) -> None:
- """处理 NapCat ``notice`` 事件并注入 Host。
-
- Args:
- payload: NapCat 推送的通知事件。
- """
- self._ensure_runtime_components()
- notice_codec = self._notice_codec
- runtime_state = self._runtime_state
- settings = self._settings or self._reload_settings()
- assert notice_codec is not None
- assert runtime_state is not None
-
- self_id = str(payload.get("self_id") or "").strip()
- if self_id:
- await runtime_state.report_connected(self_id, settings.napcat_server)
-
- message_dict = await notice_codec.build_notice_message_dict(payload)
- if message_dict is None:
- return
-
- route_metadata: Dict[str, Any] = {}
- if self_id:
- route_metadata["self_id"] = self_id
- if settings.napcat_server.connection_id:
- route_metadata["connection_id"] = settings.napcat_server.connection_id
-
- external_message_id = str(payload.get("message_id") or payload.get("notice_type") or "").strip()
- accepted = await self.ctx.adapter.receive_external_message(
- message_dict,
- route_metadata=route_metadata,
- external_message_id=external_message_id or None,
- dedupe_key=external_message_id or None,
- )
- if not accepted:
- self.ctx.logger.debug(f"Host 丢弃了 NapCat 通知事件: {external_message_id or '无消息 ID'}")
-
- async def _handle_meta_event(self, payload: Dict[str, Any]) -> None:
- """处理 NapCat ``meta_event`` 事件。
-
- Args:
- payload: NapCat 推送的元事件。
- """
- self._ensure_runtime_components()
- notice_codec = self._notice_codec
- runtime_state = self._runtime_state
- settings = self._settings or self._reload_settings()
- assert notice_codec is not None
- assert runtime_state is not None
-
- self_id = str(payload.get("self_id") or "").strip()
- if self_id:
- await runtime_state.report_connected(self_id, settings.napcat_server)
-
- await notice_codec.handle_meta_event(payload)
-
- async def _bootstrap_adapter_runtime_state(self) -> None:
- """在连接建立后主动获取账号信息并激活适配器路由。"""
- transport = self._transport
- query_service = self._query_service
- runtime_state = self._runtime_state
- settings = self._settings or self._reload_settings()
- if transport is None or query_service is None or runtime_state is None:
- return
-
- max_attempts = 3
- last_error: Optional[Exception] = None
- for attempt in range(1, max_attempts + 1):
- try:
- login_info = await query_service.get_login_info()
- self_id = self._extract_self_id_from_login_response(login_info)
- await runtime_state.report_connected(self_id, settings.napcat_server)
- return
- except asyncio.CancelledError:
- raise
- except Exception as exc:
- last_error = exc
- self.ctx.logger.warning(
- f"NapCat 适配器获取登录信息失败,第 {attempt}/{max_attempts} 次重试: {exc}"
- )
- if attempt < max_attempts:
- await asyncio.sleep(1.0)
-
- if last_error is not None:
- self.ctx.logger.error(f"NapCat 适配器未能完成路由激活,连接将保持只接收状态: {last_error}")
-
- async def _handle_transport_disconnected(self) -> None:
- """处理传输层断开事件。"""
- runtime_state = self._runtime_state
- if runtime_state is not None:
- await runtime_state.report_disconnected()
-
- @staticmethod
- def _extract_self_id_from_login_response(response: Optional[Dict[str, Any]]) -> str:
- """从 ``get_login_info`` 查询结果中提取当前账号 ID。
-
- Args:
- response: NapCat 返回的登录信息字典。
-
- Returns:
- str: 规范化后的账号 ID 字符串。
-
- Raises:
- ValueError: 当响应中缺少有效账号 ID 时抛出。
- """
- if not isinstance(response, Mapping):
- raise ValueError("get_login_info 响应缺少 data 字段")
-
- self_id = str(response.get("user_id") or "").strip()
- if not self_id:
- raise ValueError("get_login_info 响应缺少有效的 user_id")
- return self_id
-
-
-def create_plugin() -> NapCatAdapterPlugin:
- """创建插件实例。
-
- Returns:
- NapCatAdapterPlugin: NapCat 内置适配器插件实例。
- """
- return NapCatAdapterPlugin()
diff --git a/src/plugins/built_in/napcat_adapter/qq_notice.py b/src/plugins/built_in/napcat_adapter/qq_notice.py
deleted file mode 100644
index f577cf98..00000000
--- a/src/plugins/built_in/napcat_adapter/qq_notice.py
+++ /dev/null
@@ -1,224 +0,0 @@
-"""NapCat QQ 平台通知与元事件处理。"""
-
-from typing import Any, Dict, Mapping, Optional
-from uuid import uuid4
-
-import time
-
-from napcat_adapter.qq_queries import NapCatQueryService
-
-
-class NapCatNoticeCodec:
- """NapCat QQ 通知事件编码器。"""
-
- def __init__(self, logger: Any, query_service: NapCatQueryService) -> None:
- """初始化通知事件编码器。
-
- Args:
- logger: 插件日志对象。
- query_service: QQ 查询服务。
- """
- self._logger = logger
- self._query_service = query_service
-
- async def build_notice_message_dict(self, payload: Mapping[str, Any]) -> Optional[Dict[str, Any]]:
- """将 NapCat ``notice`` 事件转换为 Host 可接受的消息字典。
-
- Args:
- payload: NapCat 推送的原始通知事件。
-
- Returns:
- Optional[Dict[str, Any]]: 成功时返回标准 ``MessageDict``;无法识别时返回 ``None``。
- """
- notice_type = str(payload.get("notice_type") or "").strip()
- if not notice_type:
- return None
-
- group_id = str(payload.get("group_id") or "").strip()
- user_id = str(payload.get("user_id") or payload.get("operator_id") or "").strip()
- self_id = str(payload.get("self_id") or "").strip()
-
- user_info = await self._build_user_info(group_id=group_id, user_id=user_id)
- group_info = await self._build_group_info(group_id)
- notice_text = self._build_notice_text(payload, user_info.get("user_nickname", user_id or "系统"))
- if not notice_text:
- return None
-
- additional_config: Dict[str, Any] = {
- "self_id": self_id,
- "napcat_notice_type": notice_type,
- "napcat_notice_sub_type": str(payload.get("sub_type") or "").strip(),
- "napcat_notice_payload": dict(payload),
- }
- if group_id:
- additional_config["platform_io_target_group_id"] = group_id
- elif user_id:
- additional_config["platform_io_target_user_id"] = user_id
-
- message_info: Dict[str, Any] = {"user_info": user_info, "additional_config": additional_config}
- if group_info is not None:
- message_info["group_info"] = group_info
-
- timestamp_seconds = payload.get("time")
- if not isinstance(timestamp_seconds, (int, float)):
- timestamp_seconds = time.time()
-
- return {
- "message_id": f"napcat-notice-{uuid4().hex}",
- "timestamp": str(float(timestamp_seconds)),
- "platform": "qq",
- "message_info": message_info,
- "raw_message": [{"type": "text", "data": notice_text}],
- "is_mentioned": False,
- "is_at": False,
- "is_emoji": False,
- "is_picture": False,
- "is_command": False,
- "is_notify": True,
- "session_id": "",
- "processed_plain_text": notice_text,
- "display_message": notice_text,
- }
-
- async def handle_meta_event(self, payload: Mapping[str, Any]) -> None:
- """处理 ``meta_event`` 事件的日志与状态观测。
-
- Args:
- payload: NapCat 推送的原始元事件。
- """
- meta_event_type = str(payload.get("meta_event_type") or "").strip()
- self_id = str(payload.get("self_id") or "").strip() or "unknown"
-
- if meta_event_type == "lifecycle":
- sub_type = str(payload.get("sub_type") or "").strip()
- if sub_type == "connect":
- self._logger.info(f"NapCat 元事件:Bot {self_id} 已建立连接")
- else:
- self._logger.debug(f"NapCat 生命周期事件: self_id={self_id} sub_type={sub_type}")
- return
-
- if meta_event_type == "heartbeat":
- status = payload.get("status", {})
- if not isinstance(status, Mapping):
- status = {}
- is_online = bool(status.get("online", False))
- is_good = bool(status.get("good", False))
- interval_ms = payload.get("interval")
- self._logger.debug(
- f"NapCat 心跳事件: self_id={self_id} online={is_online} good={is_good} interval={interval_ms}"
- )
- if not is_online:
- self._logger.warning(f"NapCat 心跳显示 Bot {self_id} 已离线")
- elif not is_good:
- self._logger.warning(f"NapCat 心跳显示 Bot {self_id} 状态异常")
-
- async def _build_user_info(self, group_id: str, user_id: str) -> Dict[str, Optional[str]]:
- """构造通知消息的用户信息。
-
- Args:
- group_id: 群号;私聊或系统通知时为空字符串。
- user_id: 事件关联用户号。
-
- Returns:
- Dict[str, Optional[str]]: 规范化后的用户信息字典。
- """
- if not user_id:
- return {
- "user_id": "notice",
- "user_nickname": "系统通知",
- "user_cardname": None,
- }
-
- member_info: Optional[Dict[str, Any]]
- if group_id:
- member_info = await self._query_service.get_group_member_info(group_id, user_id)
- else:
- member_info = await self._query_service.get_stranger_info(user_id)
-
- if member_info is None:
- return {
- "user_id": user_id,
- "user_nickname": user_id,
- "user_cardname": None,
- }
-
- return {
- "user_id": user_id,
- "user_nickname": str(member_info.get("nickname") or user_id),
- "user_cardname": self._normalize_optional_string(member_info.get("card")),
- }
-
- async def _build_group_info(self, group_id: str) -> Optional[Dict[str, str]]:
- """构造通知消息的群信息。
-
- Args:
- group_id: 群号。
-
- Returns:
- Optional[Dict[str, str]]: 群信息字典;若不是群通知则返回 ``None``。
- """
- if not group_id:
- return None
-
- group_info = await self._query_service.get_group_info(group_id)
- group_name = str(group_info.get("group_name") or f"group_{group_id}") if group_info else f"group_{group_id}"
- return {"group_id": group_id, "group_name": group_name}
-
- def _build_notice_text(self, payload: Mapping[str, Any], actor_name: str) -> str:
- """根据 NapCat 通知事件生成可读文本。
-
- Args:
- payload: 原始通知事件。
- actor_name: 事件操作者显示名。
-
- Returns:
- str: 生成的可读通知文本。
- """
- notice_type = str(payload.get("notice_type") or "").strip()
- sub_type = str(payload.get("sub_type") or "").strip()
- target_id = str(payload.get("target_id") or "").strip()
-
- if notice_type in {"group_recall", "friend_recall"}:
- return f"{actor_name} 撤回了一条消息"
- if notice_type == "notify" and sub_type == "poke":
- target_text = f" -> {target_id}" if target_id else ""
- return f"{actor_name} 发起了戳一戳{target_text}"
- if notice_type == "notify" and sub_type == "group_name":
- return f"{actor_name} 修改了群名称"
- if notice_type == "group_ban" and sub_type == "ban":
- duration = payload.get("duration")
- return f"{actor_name} 触发了群禁言,时长 {duration} 秒"
- if notice_type == "group_ban" and sub_type == "lift_ban":
- return f"{actor_name} 触发了解除禁言"
- if notice_type == "group_upload":
- file_info = payload.get("file", {})
- file_name = ""
- if isinstance(file_info, Mapping):
- file_name = str(file_info.get("name") or "").strip()
- return f"{actor_name} 上传了文件{f':{file_name}' if file_name else ''}"
- if notice_type == "group_increase":
- return f"{actor_name} 加入了群聊"
- if notice_type == "group_decrease":
- return f"{actor_name} 离开了群聊"
- if notice_type == "group_admin":
- return f"{actor_name} 的群管理员状态发生变化"
- if notice_type == "essence":
- return f"{actor_name} 触发了精华消息事件"
- if notice_type == "group_msg_emoji_like":
- return f"{actor_name} 给一条消息添加了表情回应"
- return f"[notice] {notice_type}.{sub_type}".strip(".")
-
- @staticmethod
- def _normalize_optional_string(value: Any) -> Optional[str]:
- """将任意值规范化为可选字符串。
-
- Args:
- value: 待规范化的值。
-
- Returns:
- Optional[str]: 规范化后的字符串;若值为空则返回 ``None``。
- """
- if value is None:
- return None
- normalized_value = str(value).strip()
- return normalized_value if normalized_value else None
diff --git a/src/plugins/built_in/napcat_adapter/qq_queries.py b/src/plugins/built_in/napcat_adapter/qq_queries.py
deleted file mode 100644
index 7d29803a..00000000
--- a/src/plugins/built_in/napcat_adapter/qq_queries.py
+++ /dev/null
@@ -1,170 +0,0 @@
-"""NapCat QQ 平台查询能力。"""
-
-from typing import TYPE_CHECKING, Any, Dict, Optional
-
-import asyncio
-
-if TYPE_CHECKING:
- from napcat_adapter.transport import NapCatTransportClient
-
-try:
- from aiohttp import ClientSession, ClientTimeout
-
- AIOHTTP_AVAILABLE = True
-except ImportError:
- ClientSession = None # type: ignore[assignment]
- ClientTimeout = None # type: ignore[assignment]
- AIOHTTP_AVAILABLE = False
-
-
-class NapCatQueryService:
- """NapCat QQ 平台查询服务。"""
-
- def __init__(self, logger: Any, transport: "NapCatTransportClient") -> None:
- """初始化查询服务。
-
- Args:
- logger: 插件日志对象。
- transport: NapCat 传输层客户端。
- """
- self._logger = logger
- self._transport = transport
-
- async def get_login_info(self) -> Optional[Dict[str, Any]]:
- """获取当前登录账号信息。
-
- Returns:
- Optional[Dict[str, Any]]: 登录信息字典;失败时返回 ``None``。
- """
- return await self._call_query("get_login_info", {})
-
- async def get_group_info(self, group_id: str) -> Optional[Dict[str, Any]]:
- """获取群信息。
-
- Args:
- group_id: 群号。
-
- Returns:
- Optional[Dict[str, Any]]: 群信息字典;失败时返回 ``None``。
- """
- return await self._call_query("get_group_info", {"group_id": group_id})
-
- async def get_group_member_info(self, group_id: str, user_id: str) -> Optional[Dict[str, Any]]:
- """获取群成员信息。
-
- Args:
- group_id: 群号。
- user_id: 用户号。
-
- Returns:
- Optional[Dict[str, Any]]: 群成员信息字典;失败时返回 ``None``。
- """
- return await self._call_query(
- "get_group_member_info",
- {"group_id": group_id, "user_id": user_id, "no_cache": True},
- )
-
- async def get_stranger_info(self, user_id: str) -> Optional[Dict[str, Any]]:
- """获取陌生人信息。
-
- Args:
- user_id: 用户号。
-
- Returns:
- Optional[Dict[str, Any]]: 陌生人信息字典;失败时返回 ``None``。
- """
- return await self._call_query("get_stranger_info", {"user_id": user_id})
-
- async def get_message_detail(self, message_id: str) -> Optional[Dict[str, Any]]:
- """获取消息详情。
-
- Args:
- message_id: 消息 ID。
-
- Returns:
- Optional[Dict[str, Any]]: 消息详情字典;失败时返回 ``None``。
- """
- return await self._call_query("get_msg", {"message_id": message_id})
-
- async def get_forward_message(self, message_id: str) -> Optional[Dict[str, Any]]:
- """获取合并转发消息详情。
-
- Args:
- message_id: 转发消息 ID。
-
- Returns:
- Optional[Dict[str, Any]]: 合并转发消息详情;失败时返回 ``None``。
- """
- return await self._call_query("get_forward_msg", {"message_id": message_id})
-
- async def get_record_detail(self, file_name: str, file_id: Optional[str] = None) -> Optional[Dict[str, Any]]:
- """获取语音文件详情。
-
- Args:
- file_name: 语音文件名。
- file_id: 可选文件 ID。
-
- Returns:
- Optional[Dict[str, Any]]: 语音详情字典;失败时返回 ``None``。
- """
- params: Dict[str, Any] = {"file": file_name, "out_format": "wav"}
- if file_id:
- params["file_id"] = file_id
- return await self._call_query("get_record", params)
-
- async def download_binary(self, url: str) -> Optional[bytes]:
- """下载远程二进制资源。
-
- Args:
- url: 资源 URL。
-
- Returns:
- Optional[bytes]: 下载到的二进制内容;失败时返回 ``None``。
- """
- if not url:
- return None
- if not AIOHTTP_AVAILABLE or ClientSession is None or ClientTimeout is None:
- self._logger.warning("NapCat 查询层缺少 aiohttp,无法下载远程资源")
- return None
-
- try:
- timeout = ClientTimeout(total=15)
- async with ClientSession(timeout=timeout) as session:
- async with session.get(url) as response:
- if response.status != 200:
- self._logger.warning(f"NapCat 远程资源下载失败: status={response.status} url={url}")
- return None
- return await response.read()
- except asyncio.CancelledError:
- raise
- except Exception as exc:
- self._logger.warning(f"NapCat 远程资源下载失败: {exc}")
- return None
-
- async def _call_query(self, action_name: str, params: Dict[str, Any]) -> Optional[Dict[str, Any]]:
- """调用 OneBot 查询动作并提取 ``data`` 字段。
-
- Args:
- action_name: OneBot 动作名。
- params: 动作参数。
-
- Returns:
- Optional[Dict[str, Any]]: 查询结果中的 ``data`` 字段;失败时返回 ``None``。
- """
- try:
- response = await self._transport.call_action(action_name, params)
- except asyncio.CancelledError:
- raise
- except Exception as exc:
- self._logger.warning(f"NapCat 查询动作执行失败: action={action_name} error={exc}")
- return None
-
- if str(response.get("status") or "").lower() != "ok":
- self._logger.warning(
- f"NapCat 查询动作返回失败: action={action_name} "
- f"message={response.get('wording') or response.get('message') or 'unknown'}"
- )
- return None
-
- response_data = response.get("data")
- return response_data if isinstance(response_data, dict) else None
diff --git a/src/plugins/built_in/napcat_adapter/runtime_state.py b/src/plugins/built_in/napcat_adapter/runtime_state.py
deleted file mode 100644
index b4dbfa09..00000000
--- a/src/plugins/built_in/napcat_adapter/runtime_state.py
+++ /dev/null
@@ -1,85 +0,0 @@
-"""NapCat 运行时路由状态管理。"""
-
-from typing import Any, Optional
-
-from napcat_adapter.config import NapCatServerConfig
-
-
-class NapCatRuntimeStateManager:
- """NapCat 适配器路由状态上报器。"""
-
- def __init__(self, adapter_capability: Any, logger: Any) -> None:
- """初始化运行时状态管理器。
-
- Args:
- adapter_capability: SDK 提供的适配器能力对象。
- logger: 插件日志对象。
- """
- self._adapter_capability = adapter_capability
- self._logger = logger
- self._runtime_state_connected: bool = False
- self._reported_account_id: Optional[str] = None
- self._reported_scope: Optional[str] = None
-
- async def report_connected(self, account_id: str, server_config: NapCatServerConfig) -> bool:
- """向 Host 上报当前连接已就绪。
-
- Args:
- account_id: 当前 NapCat 连接对应的机器人账号 ID。
- server_config: 当前生效的 NapCat 服务端配置。
-
- Returns:
- bool: 若 Host 接受了运行时状态更新,则返回 ``True``。
- """
- normalized_account_id = str(account_id).strip()
- if not normalized_account_id:
- return False
-
- scope = server_config.connection_id or None
- if (
- self._runtime_state_connected
- and self._reported_account_id == normalized_account_id
- and self._reported_scope == scope
- ):
- return True
-
- accepted = False
- try:
- accepted = await self._adapter_capability.update_runtime_state(
- connected=True,
- account_id=normalized_account_id,
- scope=server_config.connection_id,
- metadata={"ws_url": server_config.build_ws_url()},
- )
- except Exception as exc:
- self._logger.warning(f"NapCat 适配器上报连接就绪状态失败: {exc}")
- return False
-
- if not accepted:
- self._logger.warning("NapCat 适配器连接已建立,但 Host 未接受运行时状态更新")
- return False
-
- self._runtime_state_connected = True
- self._reported_account_id = normalized_account_id
- self._reported_scope = scope
- self._logger.info(
- f"NapCat 适配器已激活路由: platform=qq account_id={normalized_account_id} "
- f"scope={self._reported_scope or '*'}"
- )
- return True
-
- async def report_disconnected(self) -> None:
- """向 Host 上报当前连接已断开,并撤销适配器路由。"""
- if not self._runtime_state_connected:
- self._reported_account_id = None
- self._reported_scope = None
- return
-
- try:
- await self._adapter_capability.update_runtime_state(connected=False)
- except Exception as exc:
- self._logger.warning(f"NapCat 适配器上报断开状态失败: {exc}")
- finally:
- self._runtime_state_connected = False
- self._reported_account_id = None
- self._reported_scope = None
diff --git a/src/plugins/built_in/napcat_adapter/transport.py b/src/plugins/built_in/napcat_adapter/transport.py
deleted file mode 100644
index d20de097..00000000
--- a/src/plugins/built_in/napcat_adapter/transport.py
+++ /dev/null
@@ -1,322 +0,0 @@
-"""NapCat 正向 WebSocket 传输层。"""
-
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Set, cast
-from uuid import uuid4
-
-import asyncio
-import contextlib
-import json
-
-from napcat_adapter.config import NapCatServerConfig
-
-if TYPE_CHECKING:
- from aiohttp import ClientWebSocketResponse as AiohttpClientWebSocketResponse
-
-try:
- from aiohttp import ClientSession, ClientTimeout, WSMsgType
-
- AIOHTTP_AVAILABLE = True
-except ImportError:
- ClientSession = cast(Any, None)
- ClientTimeout = cast(Any, None)
- WSMsgType = cast(Any, None)
- AIOHTTP_AVAILABLE = False
-
-if not TYPE_CHECKING:
- AiohttpClientWebSocketResponse = Any
-
-
-class NapCatTransportClient:
- """NapCat 正向 WebSocket 客户端。"""
-
- def __init__(
- self,
- logger: Any,
- on_connection_opened: Callable[[], Awaitable[None]],
- on_connection_closed: Callable[[], Awaitable[None]],
- on_payload: Callable[[Dict[str, Any]], Awaitable[None]],
- ) -> None:
- """初始化传输层客户端。
-
- Args:
- logger: 插件日志对象。
- on_connection_opened: 连接建立后的异步回调。
- on_connection_closed: 连接断开后的异步回调。
- on_payload: 收到非 echo 载荷后的异步回调。
- """
- self._logger = logger
- self._on_connection_opened = on_connection_opened
- self._on_connection_closed = on_connection_closed
- self._on_payload = on_payload
- self._server_config: Optional[NapCatServerConfig] = None
- self._connection_task: Optional[asyncio.Task[None]] = None
- self._pending_actions: Dict[str, asyncio.Future[Dict[str, Any]]] = {}
- self._background_tasks: Set[asyncio.Task[Any]] = set()
- self._send_lock = asyncio.Lock()
- self._ws: Optional[AiohttpClientWebSocketResponse] = None
- self._stop_requested: bool = False
- self._connection_active: bool = False
-
- @classmethod
- def is_available(cls) -> bool:
- """判断当前环境是否安装了传输层依赖。
-
- Returns:
- bool: 若已安装 ``aiohttp``,则返回 ``True``。
- """
- return AIOHTTP_AVAILABLE
-
- def configure(self, server_config: NapCatServerConfig) -> None:
- """更新当前传输层使用的 NapCat 服务端配置。
-
- Args:
- server_config: 最新生效的 NapCat 服务端配置。
- """
- self._server_config = server_config
-
- async def start(self) -> None:
- """启动 NapCat 正向 WebSocket 连接循环。
-
- Raises:
- RuntimeError: 当缺少配置或依赖时抛出。
- """
- if not self.is_available():
- raise RuntimeError("NapCat 适配器依赖 aiohttp,但当前环境未安装该依赖")
- if self._server_config is None:
- raise RuntimeError("NapCat 适配器尚未配置 napcat_server")
- if self._connection_task is not None and not self._connection_task.done():
- return
-
- self._stop_requested = False
- self._connection_task = asyncio.create_task(self._connection_loop(), name="napcat_adapter.connection")
-
- async def stop(self) -> None:
- """停止当前连接并清理所有后台任务。"""
- self._stop_requested = True
- connection_task = self._connection_task
- self._connection_task = None
-
- ws = self._ws
- if ws is not None and not ws.closed:
- with contextlib.suppress(Exception):
- await ws.close()
- self._ws = None
-
- if connection_task is not None:
- connection_task.cancel()
- with contextlib.suppress(asyncio.CancelledError):
- await connection_task
-
- await self._cancel_background_tasks()
- await self._notify_connection_closed()
- self._fail_pending_actions("NapCat connection closed")
-
- async def call_action(self, action_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
- """发送 OneBot 动作并等待对应的 echo 响应。
-
- Args:
- action_name: OneBot 动作名称。
- params: 动作参数。
-
- Returns:
- Dict[str, Any]: NapCat 返回的原始响应字典。
-
- Raises:
- RuntimeError: 当连接不可用时抛出。
- """
- ws = self._ws
- server_config = self._server_config
- if ws is None or ws.closed or server_config is None:
- raise RuntimeError("NapCat is not connected")
-
- echo_id = uuid4().hex
- loop = asyncio.get_running_loop()
- response_future: asyncio.Future[Dict[str, Any]] = loop.create_future()
- self._pending_actions[echo_id] = response_future
-
- request_payload = {"action": action_name, "params": params, "echo": echo_id}
- try:
- async with self._send_lock:
- await ws.send_str(json.dumps(request_payload, ensure_ascii=False))
- return await asyncio.wait_for(response_future, timeout=server_config.action_timeout_sec)
- finally:
- self._pending_actions.pop(echo_id, None)
-
- async def _connection_loop(self) -> None:
- """维护单个 WebSocket 连接,并在断开后按配置重连。"""
- assert ClientSession is not None
- assert ClientTimeout is not None
-
- while not self._stop_requested:
- server_config = self._server_config
- if server_config is None:
- return
-
- ws_url = server_config.build_ws_url()
- timeout = ClientTimeout(total=None, connect=10)
-
- try:
- async with ClientSession(headers=self._build_headers(server_config), timeout=timeout) as session:
- async with session.ws_connect(ws_url, heartbeat=server_config.heartbeat_interval or None) as ws:
- self._ws = ws
- self._logger.info(f"NapCat 适配器已连接: {ws_url}")
- await self._receive_loop(ws)
- except asyncio.CancelledError:
- raise
- except Exception as exc:
- self._logger.warning(f"NapCat 适配器连接失败: {exc}")
- finally:
- self._ws = None
- await self._notify_connection_closed()
- self._fail_pending_actions("NapCat connection interrupted")
-
- if self._stop_requested:
- break
-
- await asyncio.sleep(server_config.reconnect_delay_sec)
-
- async def _receive_loop(self, ws: AiohttpClientWebSocketResponse) -> None:
- """持续消费 WebSocket 消息并分发处理。
-
- Args:
- ws: 当前活跃的 WebSocket 连接对象。
- """
- assert WSMsgType is not None
-
- bootstrap_task = self._create_background_task(
- self._notify_connection_opened(),
- "napcat_adapter.bootstrap",
- )
- try:
- async for ws_message in ws:
- if ws_message.type != WSMsgType.TEXT:
- if ws_message.type in {WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.ERROR}:
- break
- continue
-
- payload = self._parse_json_message(ws_message.data)
- if payload is None:
- continue
-
- if echo_id := str(payload.get("echo") or "").strip():
- self._resolve_pending_action(echo_id, payload)
- continue
-
- self._create_background_task(self._on_payload(payload), "napcat_adapter.payload")
- finally:
- if bootstrap_task is not None and not bootstrap_task.done():
- bootstrap_task.cancel()
- with contextlib.suppress(asyncio.CancelledError):
- await bootstrap_task
-
- def _create_background_task(self, coroutine: Awaitable[Any], name: str) -> asyncio.Task[Any]:
- """创建并跟踪一个后台任务。
-
- Args:
- coroutine: 待执行的协程对象。
- name: 任务名。
-
- Returns:
- asyncio.Task[Any]: 已创建的后台任务。
- """
- task = asyncio.create_task(coroutine, name=name)
- self._background_tasks.add(task)
- task.add_done_callback(self._handle_background_task_completion)
- return task
-
- def _handle_background_task_completion(self, task: asyncio.Task[Any]) -> None:
- """处理后台任务结束后的清理与异常记录。
-
- Args:
- task: 已结束的后台任务。
- """
- self._background_tasks.discard(task)
- if task.cancelled():
- return
-
- exception = task.exception()
- if exception is not None:
- self._logger.error(f"NapCat 适配器后台任务异常: {exception}", exc_info=True)
-
- async def _cancel_background_tasks(self) -> None:
- """取消所有仍在运行的后台任务。"""
- background_tasks = list(self._background_tasks)
- for task in background_tasks:
- task.cancel()
- if background_tasks:
- with contextlib.suppress(Exception):
- await asyncio.gather(*background_tasks, return_exceptions=True)
- self._background_tasks.clear()
-
- async def _notify_connection_opened(self) -> None:
- """在连接建立后触发上层回调。"""
- if self._connection_active:
- return
-
- self._connection_active = True
- try:
- await self._on_connection_opened()
- except Exception as exc:
- self._logger.warning(f"NapCat 适配器连接建立回调失败: {exc}")
-
- async def _notify_connection_closed(self) -> None:
- """在连接断开后触发上层回调。"""
- if not self._connection_active:
- return
-
- self._connection_active = False
- try:
- await self._on_connection_closed()
- except Exception as exc:
- self._logger.warning(f"NapCat 适配器断连回调失败: {exc}")
-
- def _resolve_pending_action(self, echo_id: str, payload: Dict[str, Any]) -> None:
- """解析等待中的动作响应。
-
- Args:
- echo_id: 动作请求对应的 echo 标识。
- payload: NapCat 返回的响应载荷。
- """
- response_future = self._pending_actions.get(echo_id)
- if response_future is None or response_future.done():
- return
- response_future.set_result(payload)
-
- def _fail_pending_actions(self, error_message: str) -> None:
- """让所有等待中的动作以异常方式结束。
-
- Args:
- error_message: 写入异常中的错误信息。
- """
- for response_future in self._pending_actions.values():
- if not response_future.done():
- response_future.set_exception(RuntimeError(error_message))
- self._pending_actions.clear()
-
- def _build_headers(self, server_config: NapCatServerConfig) -> Dict[str, str]:
- """构造连接 NapCat 所需的请求头。
-
- Args:
- server_config: 当前生效的 NapCat 服务端配置。
-
- Returns:
- Dict[str, str]: WebSocket 握手请求头。
- """
- return {"Authorization": f"Bearer {server_config.token}"} if server_config.token else {}
-
- def _parse_json_message(self, data: Any) -> Optional[Dict[str, Any]]:
- """解析 WebSocket 文本消息中的 JSON 数据。
-
- Args:
- data: WebSocket 收到的原始文本数据。
-
- Returns:
- Optional[Dict[str, Any]]: 成功时返回字典,失败时返回 ``None``。
- """
- try:
- payload = json.loads(str(data))
- except Exception as exc:
- self._logger.warning(f"NapCat 适配器解析 JSON 载荷失败: {exc}")
- return None
-
- return payload if isinstance(payload, dict) else None
From d07e8f90ef3458419ad37d2c132a638f71771a81 Mon Sep 17 00:00:00 2001
From: UnCLAS-Prommer
Date: Sun, 22 Mar 2026 18:32:14 +0800
Subject: [PATCH 29/45] fix: remove nc ada pytest
---
pytests/test_napcat_adapter_codec.py | 151 --------------------------
pytests/test_napcat_adapter_config.py | 91 ----------------
pytests/test_napcat_adapter_plugin.py | 60 ----------
3 files changed, 302 deletions(-)
delete mode 100644 pytests/test_napcat_adapter_codec.py
delete mode 100644 pytests/test_napcat_adapter_config.py
delete mode 100644 pytests/test_napcat_adapter_plugin.py
diff --git a/pytests/test_napcat_adapter_codec.py b/pytests/test_napcat_adapter_codec.py
deleted file mode 100644
index 97ed1d9e..00000000
--- a/pytests/test_napcat_adapter_codec.py
+++ /dev/null
@@ -1,151 +0,0 @@
-from pathlib import Path
-from typing import Any, Dict
-
-import importlib
-import sys
-from types import SimpleNamespace
-
-import pytest
-
-
-BUILT_IN_PLUGIN_ROOT = Path(__file__).resolve().parents[1] / "src" / "plugins" / "built_in"
-if str(BUILT_IN_PLUGIN_ROOT) not in sys.path:
- sys.path.insert(0, str(BUILT_IN_PLUGIN_ROOT))
-
-NapCatInboundCodec = importlib.import_module("napcat_adapter.codec_inbound").NapCatInboundCodec
-NapCatOutboundCodec = importlib.import_module("napcat_adapter.codec_outbound").NapCatOutboundCodec
-
-
-def test_napcat_outbound_codec_supports_binary_and_forward_segments() -> None:
- codec = NapCatOutboundCodec()
- raw_message = [
- {"type": "text", "data": "hello"},
- {"type": "image", "data": "", "hash": "h1", "binary_data_base64": "aW1hZ2U="},
- {"type": "emoji", "data": "", "hash": "h2", "binary_data_base64": "ZW1vamk="},
- {"type": "voice", "data": "", "hash": "h3", "binary_data_base64": "dm9pY2U="},
- {
- "type": "reply",
- "data": {
- "target_message_id": "origin-1",
- "target_message_content": "origin text",
- },
- },
- {
- "type": "forward",
- "data": [
- {
- "user_id": "42",
- "user_nickname": "alice",
- "user_cardname": "Alice",
- "message_id": "fwd-1",
- "content": [{"type": "text", "data": "node-text"}],
- }
- ],
- },
- ]
-
- converted = codec.convert_segments(raw_message)
-
- assert converted[0] == {"type": "text", "data": {"text": "hello"}}
- assert converted[1]["type"] == "image"
- assert converted[1]["data"]["file"] == "base64://aW1hZ2U="
- assert converted[2]["type"] == "image"
- assert converted[2]["data"]["subtype"] == 1
- assert converted[3] == {"type": "record", "data": {"file": "base64://dm9pY2U="}}
- assert converted[4] == {"type": "reply", "data": {"id": "origin-1"}}
- assert converted[5]["type"] == "node"
- assert converted[5]["data"]["name"] == "alice"
- assert converted[5]["data"]["content"] == [{"type": "text", "data": {"text": "node-text"}}]
-
-
-def test_napcat_outbound_codec_builds_private_action_from_route_metadata() -> None:
- codec = NapCatOutboundCodec()
- message: Dict[str, Any] = {
- "message_info": {
- "user_info": {"user_id": "10001", "user_nickname": "tester"},
- "additional_config": {},
- },
- "raw_message": [{"type": "text", "data": "hello"}],
- }
-
- action_name, params = codec.build_outbound_action(message, {"target_user_id": "30001"})
-
- assert action_name == "send_private_msg"
- assert params == {"message": [{"type": "text", "data": {"text": "hello"}}], "user_id": "30001"}
-
-
-class DummyQueryService:
- """用于测试的轻量查询服务。"""
-
- async def download_binary(self, url: str) -> bytes:
- """返回固定图片二进制。
-
- Args:
- url: 图片地址。
-
- Returns:
- bytes: 固定测试图片二进制。
- """
- if url:
- return b"image-bytes"
- return b""
-
- async def get_message_detail(self, message_id: str) -> Dict[str, Any] | None:
- """返回空消息详情。
-
- Args:
- message_id: 目标消息 ID。
-
- Returns:
- Dict[str, Any] | None: 固定空结果。
- """
- del message_id
- return None
-
- async def get_record_detail(self, file_name: str, file_id: str | None = None) -> Dict[str, Any] | None:
- """返回空语音详情。
-
- Args:
- file_name: 语音文件名。
- file_id: 可选文件 ID。
-
- Returns:
- Dict[str, Any] | None: 固定空结果。
- """
- del file_name
- del file_id
- return None
-
- async def get_forward_message(self, message_id: str) -> Dict[str, Any] | None:
- """返回空转发详情。
-
- Args:
- message_id: 转发消息 ID。
-
- Returns:
- Dict[str, Any] | None: 固定空结果。
- """
- del message_id
- return None
-
-
-@pytest.mark.asyncio
-async def test_napcat_inbound_codec_parses_cq_string_image_segments() -> None:
- codec = NapCatInboundCodec(SimpleNamespace(debug=lambda message: None), DummyQueryService())
- payload = {
- "message": "[CQ:image,file=test.png,sub_type=0,url=https://example.com/test.png][CQ:at,qq=10001] 看到是国人直接给你封了",
- }
-
- raw_message, is_at = await codec.convert_segments(payload, "10001")
-
- assert raw_message[0]["type"] == "image"
- assert raw_message[1] == {
- "type": "at",
- "data": {
- "target_user_id": "10001",
- "target_user_nickname": None,
- "target_user_cardname": None,
- },
- }
- assert raw_message[2] == {"type": "text", "data": " 看到是国人直接给你封了"}
- assert is_at is True
diff --git a/pytests/test_napcat_adapter_config.py b/pytests/test_napcat_adapter_config.py
deleted file mode 100644
index 688b1a48..00000000
--- a/pytests/test_napcat_adapter_config.py
+++ /dev/null
@@ -1,91 +0,0 @@
-from pathlib import Path
-from typing import List
-
-import importlib
-import sys
-
-
-BUILT_IN_PLUGIN_ROOT = Path(__file__).resolve().parents[1] / "src" / "plugins" / "built_in"
-if str(BUILT_IN_PLUGIN_ROOT) not in sys.path:
- sys.path.insert(0, str(BUILT_IN_PLUGIN_ROOT))
-
-NapCatPluginSettings = importlib.import_module("napcat_adapter.config").NapCatPluginSettings
-
-
-class DummyLogger:
- """用于测试的轻量日志对象。"""
-
- def __init__(self) -> None:
- """初始化测试日志对象。"""
- self.warnings: List[str] = []
- self.errors: List[str] = []
-
- def warning(self, message: str) -> None:
- """记录警告日志。
-
- Args:
- message: 待记录的日志内容。
- """
- self.warnings.append(message)
-
- def error(self, message: str) -> None:
- """记录错误日志。
-
- Args:
- message: 待记录的日志内容。
- """
- self.errors.append(message)
-
-
-def test_parse_new_napcat_server_config() -> None:
- logger = DummyLogger()
- settings = NapCatPluginSettings.from_mapping(
- {
- "plugin": {"enabled": True, "config_version": "0.1.0"},
- "napcat_server": {
- "host": "localhost",
- "port": 8095,
- "token": "secret",
- "heartbeat_interval": 45,
- "reconnect_delay_sec": 7,
- "action_timeout_sec": 18,
- "connection_id": "main",
- },
- },
- logger,
- )
-
- assert settings.should_connect() is True
- assert settings.napcat_server.host == "localhost"
- assert settings.napcat_server.port == 8095
- assert settings.napcat_server.token == "secret"
- assert settings.napcat_server.heartbeat_interval == 45.0
- assert settings.napcat_server.reconnect_delay_sec == 7.0
- assert settings.napcat_server.action_timeout_sec == 18.0
- assert settings.napcat_server.connection_id == "main"
- assert settings.napcat_server.build_ws_url() == "ws://localhost:8095"
- assert settings.validate(logger) is True
-
-
-def test_parse_legacy_connection_ws_url_fallback() -> None:
- logger = DummyLogger()
- settings = NapCatPluginSettings.from_mapping(
- {
- "plugin": {"enabled": True, "config_version": "0.1.0"},
- "connection": {
- "ws_url": "ws://127.0.0.1:3001",
- "access_token": "legacy-token",
- "heartbeat_sec": 35,
- "action_timeout_sec": 12,
- },
- },
- logger,
- )
-
- assert settings.napcat_server.host == "127.0.0.1"
- assert settings.napcat_server.port == 3001
- assert settings.napcat_server.token == "legacy-token"
- assert settings.napcat_server.heartbeat_interval == 35.0
- assert settings.napcat_server.action_timeout_sec == 12.0
- assert settings.validate(logger) is True
- assert logger.warnings
diff --git a/pytests/test_napcat_adapter_plugin.py b/pytests/test_napcat_adapter_plugin.py
deleted file mode 100644
index ca550a39..00000000
--- a/pytests/test_napcat_adapter_plugin.py
+++ /dev/null
@@ -1,60 +0,0 @@
-"""NapCat 插件入口行为测试。"""
-
-from pathlib import Path
-from typing import List
-from types import SimpleNamespace
-
-import importlib
-import sys
-
-import pytest
-
-
-BUILT_IN_PLUGIN_ROOT = Path(__file__).resolve().parents[1] / "src" / "plugins" / "built_in"
-if str(BUILT_IN_PLUGIN_ROOT) not in sys.path:
- sys.path.insert(0, str(BUILT_IN_PLUGIN_ROOT))
-
-NapCatAdapterPlugin = importlib.import_module("napcat_adapter.plugin").NapCatAdapterPlugin
-
-
-class DummyLogger:
- """用于测试的轻量日志对象。"""
-
- def __init__(self) -> None:
- """初始化测试日志对象。"""
- self.debug_messages: List[str] = []
-
- def debug(self, message: str) -> None:
- """记录调试日志。
-
- Args:
- message: 待记录的日志内容。
- """
- self.debug_messages.append(message)
-
-
-@pytest.mark.asyncio
-async def test_on_config_update_refreshes_settings_and_restarts(monkeypatch: pytest.MonkeyPatch) -> None:
- """配置更新时应刷新插件配置、清空旧 settings,并触发连接重启。"""
- plugin = NapCatAdapterPlugin()
- plugin._ctx = SimpleNamespace(logger=DummyLogger())
- plugin._settings = object()
-
- restart_calls: List[dict] = []
-
- async def fake_restart() -> None:
- """记录一次重启调用。"""
- restart_calls.append(dict(plugin._plugin_config))
-
- monkeypatch.setattr(plugin, "_restart_connection_if_needed", fake_restart)
-
- new_config = {
- "plugin": {"enabled": True, "config_version": "0.1.0"},
- "napcat_server": {"host": "127.0.0.1", "port": 3001},
- }
- await plugin.on_config_update(new_config, "v2")
-
- assert plugin._plugin_config == new_config
- assert plugin._settings is None
- assert restart_calls == [new_config]
- assert plugin.ctx.logger.debug_messages == ["NapCat 适配器收到配置更新通知: v2"]
From e26b27c28707ee8007f40aee6f9a394e459092f4 Mon Sep 17 00:00:00 2001
From: DrSmoothl <1787882683@qq.com>
Date: Mon, 23 Mar 2026 10:54:29 +0800
Subject: [PATCH 30/45] refactor: update message gateway handling and remove
adapter references
- Changed the message sending method to return DeliveryBatch instead of DeliveryReceipt in integration.py.
- Removed AdapterDeclarationPayload and related references from envelope.py, replacing them with MessageGatewayStateUpdatePayload and MessageGatewayStateUpdateResultPayload.
- Updated runner_main.py to remove adapter-related logic and methods, focusing on message gateway functionality.
- Added tests for message gateway runtime state synchronization and action bridge functionality in test files.
---
pytests/test_adapter_runtime_state.py | 162 ----
pytests/test_message_gateway_runtime.py | 170 ++++
pytests/test_platform_io_dedupe.py | 49 +-
pytests/test_plugin_runtime_action_bridge.py | 138 ++++
.../message_receive/uni_message_sender.py | 22 +-
src/platform_io/__init__.py | 7 +-
src/platform_io/drivers/plugin_driver.py | 67 +-
src/platform_io/manager.py | 218 ++++--
src/platform_io/routing.py | 135 +---
src/platform_io/types.py | 49 +-
src/plugin_runtime/host/component_registry.py | 111 ++-
src/plugin_runtime/host/message_gateway.py | 20 +-
src/plugin_runtime/host/supervisor.py | 735 +++++++++++-------
src/plugin_runtime/integration.py | 10 +-
src/plugin_runtime/protocol/envelope.py | 53 +-
src/plugin_runtime/runner/runner_main.py | 35 +-
16 files changed, 1221 insertions(+), 760 deletions(-)
delete mode 100644 pytests/test_adapter_runtime_state.py
create mode 100644 pytests/test_message_gateway_runtime.py
create mode 100644 pytests/test_plugin_runtime_action_bridge.py
diff --git a/pytests/test_adapter_runtime_state.py b/pytests/test_adapter_runtime_state.py
deleted file mode 100644
index e82f4c8c..00000000
--- a/pytests/test_adapter_runtime_state.py
+++ /dev/null
@@ -1,162 +0,0 @@
-"""适配器运行时状态同步测试。"""
-
-from typing import Any, Dict
-
-import pytest
-
-from src.platform_io.manager import PlatformIOManager
-from src.platform_io.types import RouteKey
-from src.plugin_runtime.host.supervisor import PluginSupervisor
-from src.plugin_runtime.protocol.envelope import (
- AdapterDeclarationPayload,
- Envelope,
- MessageType,
-)
-
-
-def _make_request(plugin_id: str, payload: Dict[str, Any]) -> Envelope:
- """构造一个适配器状态更新 RPC 请求。
-
- Args:
- plugin_id: 目标适配器插件 ID。
- payload: 请求载荷。
-
- Returns:
- Envelope: 标准 RPC 请求信封。
- """
- return Envelope(
- request_id=1,
- message_type=MessageType.REQUEST,
- method="host.update_adapter_state",
- plugin_id=plugin_id,
- payload=payload,
- )
-
-
-@pytest.mark.asyncio
-async def test_adapter_runtime_state_binds_and_unbinds_routes(monkeypatch: pytest.MonkeyPatch) -> None:
- """连接建立后应绑定路由,断开后应撤销路由。"""
- import src.plugin_runtime.host.supervisor as supervisor_module
-
- platform_io_manager = PlatformIOManager()
- monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager)
-
- supervisor = PluginSupervisor(plugin_dirs=[])
- adapter = AdapterDeclarationPayload(platform="qq", protocol="napcat")
- await supervisor._register_adapter_driver("napcat_adapter_builtin", adapter)
-
- response = await supervisor._handle_update_adapter_state(
- _make_request(
- "napcat_adapter_builtin",
- {
- "connected": True,
- "account_id": "10001",
- "scope": "",
- "metadata": {},
- },
- )
- )
-
- assert response.error is None
- assert response.payload["accepted"] is True
- assert (
- platform_io_manager.route_table.get_active_binding(
- RouteKey(platform="qq", account_id="10001"),
- exact_only=True,
- ).driver_id
- == "adapter:napcat_adapter_builtin"
- )
- assert (
- platform_io_manager.route_table.get_active_binding(
- RouteKey(platform="qq"),
- exact_only=True,
- ).driver_id
- == "adapter:napcat_adapter_builtin"
- )
-
- response = await supervisor._handle_update_adapter_state(
- _make_request(
- "napcat_adapter_builtin",
- {
- "connected": False,
- "account_id": "",
- "scope": "",
- "metadata": {},
- },
- )
- )
-
- assert response.error is None
- assert response.payload["accepted"] is True
- assert platform_io_manager.route_table.get_active_binding(
- RouteKey(platform="qq", account_id="10001"),
- exact_only=True,
- ) is None
- assert platform_io_manager.route_table.get_active_binding(RouteKey(platform="qq"), exact_only=True) is None
-
-
-@pytest.mark.asyncio
-async def test_platform_default_route_is_removed_when_multiple_exact_routes_exist(
- monkeypatch: pytest.MonkeyPatch,
-) -> None:
- """同一平台存在多个精确路由时不应保留默认平台路由。"""
- import src.plugin_runtime.host.supervisor as supervisor_module
-
- platform_io_manager = PlatformIOManager()
- monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager)
-
- supervisor = PluginSupervisor(plugin_dirs=[])
- adapter = AdapterDeclarationPayload(platform="qq", protocol="napcat")
- await supervisor._register_adapter_driver("adapter_a", adapter)
- await supervisor._register_adapter_driver("adapter_b", adapter)
-
- await supervisor._handle_update_adapter_state(
- _make_request(
- "adapter_a",
- {
- "connected": True,
- "account_id": "10001",
- "scope": "",
- "metadata": {},
- },
- )
- )
- assert (
- platform_io_manager.route_table.get_active_binding(
- RouteKey(platform="qq"),
- exact_only=True,
- ).driver_id
- == "adapter:adapter_a"
- )
-
- await supervisor._handle_update_adapter_state(
- _make_request(
- "adapter_b",
- {
- "connected": True,
- "account_id": "10002",
- "scope": "",
- "metadata": {},
- },
- )
- )
- assert platform_io_manager.route_table.get_active_binding(RouteKey(platform="qq"), exact_only=True) is None
-
- await supervisor._handle_update_adapter_state(
- _make_request(
- "adapter_b",
- {
- "connected": False,
- "account_id": "",
- "scope": "",
- "metadata": {},
- },
- )
- )
- assert (
- platform_io_manager.route_table.get_active_binding(
- RouteKey(platform="qq"),
- exact_only=True,
- ).driver_id
- == "adapter:adapter_a"
- )
diff --git a/pytests/test_message_gateway_runtime.py b/pytests/test_message_gateway_runtime.py
new file mode 100644
index 00000000..9650bc10
--- /dev/null
+++ b/pytests/test_message_gateway_runtime.py
@@ -0,0 +1,170 @@
+"""消息网关运行时状态同步测试。"""
+
+from typing import Any, Dict
+
+import pytest
+
+from src.platform_io.manager import PlatformIOManager
+from src.platform_io.types import RouteKey
+from src.plugin_runtime.host.supervisor import PluginSupervisor
+from src.plugin_runtime.protocol.envelope import Envelope, MessageType
+
+
+def _make_request(method: str, plugin_id: str, payload: Dict[str, Any]) -> Envelope:
+ """构造一个 RPC 请求信封。
+
+ Args:
+ method: RPC 方法名。
+ plugin_id: 目标插件 ID。
+ payload: 请求载荷。
+
+ Returns:
+ Envelope: 标准 RPC 请求信封。
+ """
+
+ return Envelope(
+ request_id=1,
+ message_type=MessageType.REQUEST,
+ method=method,
+ plugin_id=plugin_id,
+ payload=payload,
+ )
+
+
+@pytest.mark.asyncio
+async def test_message_gateway_runtime_state_binds_send_and_receive_routes(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """消息网关就绪后应同时绑定发送表和接收表。"""
+
+ import src.plugin_runtime.host.supervisor as supervisor_module
+
+ platform_io_manager = PlatformIOManager()
+ monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager)
+
+ supervisor = PluginSupervisor(plugin_dirs=[])
+ register_response = await supervisor._handle_register_plugin(
+ _make_request(
+ "plugin.register_components",
+ "napcat_plugin",
+ {
+ "plugin_id": "napcat_plugin",
+ "plugin_version": "1.0.0",
+ "components": [
+ {
+ "name": "napcat_gateway",
+ "component_type": "MESSAGE_GATEWAY",
+ "plugin_id": "napcat_plugin",
+ "metadata": {
+ "route_type": "duplex",
+ "platform": "qq",
+ "protocol": "napcat",
+ },
+ }
+ ],
+ "capabilities_required": [],
+ },
+ )
+ )
+
+ assert register_response.error is None
+ response = await supervisor._handle_update_message_gateway_state(
+ _make_request(
+ "host.update_message_gateway_state",
+ "napcat_plugin",
+ {
+ "gateway_name": "napcat_gateway",
+ "ready": True,
+ "platform": "qq",
+ "account_id": "10001",
+ "scope": "primary",
+ "metadata": {},
+ },
+ )
+ )
+
+ assert response.error is None
+ assert response.payload["accepted"] is True
+
+ send_bindings = platform_io_manager.send_route_table.resolve_bindings(
+ RouteKey(platform="qq", account_id="10001", scope="primary")
+ )
+ receive_bindings = platform_io_manager.receive_route_table.resolve_bindings(
+ RouteKey(platform="qq", account_id="10001", scope="primary")
+ )
+
+ assert [binding.driver_id for binding in send_bindings] == ["gateway:napcat_plugin:napcat_gateway"]
+ assert [binding.driver_id for binding in receive_bindings] == ["gateway:napcat_plugin:napcat_gateway"]
+
+
+@pytest.mark.asyncio
+async def test_message_gateway_runtime_state_unbinds_routes_when_not_ready(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """消息网关断开后应撤销发送表和接收表中的绑定。"""
+
+ import src.plugin_runtime.host.supervisor as supervisor_module
+
+ platform_io_manager = PlatformIOManager()
+ monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager)
+
+ supervisor = PluginSupervisor(plugin_dirs=[])
+ await supervisor._handle_register_plugin(
+ _make_request(
+ "plugin.register_components",
+ "napcat_plugin",
+ {
+ "plugin_id": "napcat_plugin",
+ "plugin_version": "1.0.0",
+ "components": [
+ {
+ "name": "napcat_gateway",
+ "component_type": "MESSAGE_GATEWAY",
+ "plugin_id": "napcat_plugin",
+ "metadata": {
+ "route_type": "duplex",
+ "platform": "qq",
+ "protocol": "napcat",
+ },
+ }
+ ],
+ "capabilities_required": [],
+ },
+ )
+ )
+
+ await supervisor._handle_update_message_gateway_state(
+ _make_request(
+ "host.update_message_gateway_state",
+ "napcat_plugin",
+ {
+ "gateway_name": "napcat_gateway",
+ "ready": True,
+ "platform": "qq",
+ "account_id": "10001",
+ "scope": "primary",
+ "metadata": {},
+ },
+ )
+ )
+ response = await supervisor._handle_update_message_gateway_state(
+ _make_request(
+ "host.update_message_gateway_state",
+ "napcat_plugin",
+ {
+ "gateway_name": "napcat_gateway",
+ "ready": False,
+ "platform": "qq",
+ "account_id": "",
+ "scope": "",
+ "metadata": {},
+ },
+ )
+ )
+
+ assert response.error is None
+ assert response.payload["accepted"] is True
+ assert platform_io_manager.send_route_table.resolve_bindings(RouteKey(platform="qq", account_id="10001")) == []
+ assert (
+ platform_io_manager.receive_route_table.resolve_bindings(RouteKey(platform="qq", account_id="10001")) == []
+ )
diff --git a/pytests/test_platform_io_dedupe.py b/pytests/test_platform_io_dedupe.py
index 4a3cbb44..68ae95c6 100644
--- a/pytests/test_platform_io_dedupe.py
+++ b/pytests/test_platform_io_dedupe.py
@@ -159,6 +159,51 @@ class TestPlatformIODedupe:
session_message_envelope = _build_envelope(session_message_id="session-1")
payload_only_envelope = _build_envelope(payload={"message": "hello"})
- assert PlatformIOManager._build_inbound_dedupe_key(explicit_envelope) == "qq:10001:main:dedupe-1"
- assert PlatformIOManager._build_inbound_dedupe_key(session_message_envelope) == "qq:10001:main:session-1"
+ assert PlatformIOManager._build_inbound_dedupe_key(explicit_envelope) == "plugin.napcat:dedupe-1"
+ assert PlatformIOManager._build_inbound_dedupe_key(session_message_envelope) == "plugin.napcat:session-1"
assert PlatformIOManager._build_inbound_dedupe_key(payload_only_envelope) is None
+
+ @pytest.mark.asyncio
+ async def test_send_message_fans_out_to_all_matching_routes(self) -> None:
+ """同一路由命中多条发送链路时应全部发送。"""
+
+ manager = PlatformIOManager()
+ first_driver = _StubPlatformIODriver(
+ DriverDescriptor(
+ driver_id="plugin.gateway_a",
+ kind=DriverKind.PLUGIN,
+ platform="qq",
+ )
+ )
+ second_driver = _StubPlatformIODriver(
+ DriverDescriptor(
+ driver_id="plugin.gateway_b",
+ kind=DriverKind.PLUGIN,
+ platform="qq",
+ )
+ )
+ manager.register_driver(first_driver)
+ manager.register_driver(second_driver)
+ manager.bind_send_route(
+ RouteBinding(
+ route_key=RouteKey(platform="qq"),
+ driver_id=first_driver.driver_id,
+ driver_kind=first_driver.descriptor.kind,
+ )
+ )
+ manager.bind_send_route(
+ RouteBinding(
+ route_key=RouteKey(platform="qq"),
+ driver_id=second_driver.driver_id,
+ driver_kind=second_driver.descriptor.kind,
+ )
+ )
+
+ message = SimpleNamespace(message_id="internal-msg-1")
+ result = await manager.send_message(message, RouteKey(platform="qq"))
+
+ assert result.has_success is True
+ assert [receipt.driver_id for receipt in result.sent_receipts] == [
+ "plugin.gateway_a",
+ "plugin.gateway_b",
+ ]
diff --git a/pytests/test_plugin_runtime_action_bridge.py b/pytests/test_plugin_runtime_action_bridge.py
new file mode 100644
index 00000000..f2364094
--- /dev/null
+++ b/pytests/test_plugin_runtime_action_bridge.py
@@ -0,0 +1,138 @@
+from types import SimpleNamespace
+from typing import Any
+
+import pytest
+
+from src.core.component_registry import component_registry as core_component_registry
+from src.plugin_runtime.host.supervisor import PluginSupervisor
+from src.plugin_runtime.protocol.envelope import ComponentDeclaration, RegisterPluginPayload
+
+
+def _build_action_payload(plugin_id: str, action_name: str) -> RegisterPluginPayload:
+ """构造用于测试的 runtime Action 注册载荷。
+
+ Args:
+ plugin_id: 插件 ID。
+ action_name: Action 名称。
+
+ Returns:
+ RegisterPluginPayload: 测试用注册载荷。
+ """
+ return RegisterPluginPayload(
+ plugin_id=plugin_id,
+ plugin_version="1.0.0",
+ components=[
+ ComponentDeclaration(
+ name=action_name,
+ component_type="ACTION",
+ plugin_id=plugin_id,
+ metadata={
+ "description": "发送一个测试回复",
+ "enabled": True,
+ "activation_type": "keyword",
+ "activation_probability": 0.25,
+ "activation_keywords": ["测试", "hello"],
+ "action_parameters": {"target": "目标对象"},
+ "action_require": ["需要发送回复时使用"],
+ "associated_types": ["text"],
+ "parallel_action": True,
+ },
+ )
+ ],
+ )
+
+
+@pytest.mark.asyncio
+async def test_runtime_actions_are_mirrored_into_core_registry_and_invoked(monkeypatch: pytest.MonkeyPatch) -> None:
+ """运行时 Action 应镜像到旧核心注册表,并可由旧 Planner 执行。"""
+ plugin_id = "runtime_action_bridge_plugin"
+ action_name = "runtime_action_bridge_test"
+ payload = _build_action_payload(plugin_id=plugin_id, action_name=action_name)
+ supervisor = PluginSupervisor(plugin_dirs=[])
+ captured: dict[str, Any] = {}
+
+ core_component_registry.remove_action(action_name)
+
+ async def fake_invoke_plugin(
+ method: str,
+ plugin_id: str,
+ component_name: str,
+ args: dict[str, Any] | None = None,
+ timeout_ms: int = 30000,
+ ) -> Any:
+ """模拟 plugin runtime Action 调用。
+
+ Args:
+ method: RPC 方法名。
+ plugin_id: 插件 ID。
+ component_name: 组件名称。
+ args: 调用参数。
+ timeout_ms: RPC 超时时间。
+
+ Returns:
+ Any: 伪造的 RPC 响应对象。
+ """
+ captured["method"] = method
+ captured["plugin_id"] = plugin_id
+ captured["component_name"] = component_name
+ captured["args"] = args or {}
+ captured["timeout_ms"] = timeout_ms
+ return SimpleNamespace(payload={"success": True, "result": (True, "runtime action executed")})
+
+ monkeypatch.setattr(supervisor, "invoke_plugin", fake_invoke_plugin)
+
+ try:
+ supervisor._mirror_runtime_actions_to_core_registry(payload)
+
+ action_info = core_component_registry.get_action_info(action_name)
+ assert action_info is not None
+ assert action_info.plugin_name == plugin_id
+ assert action_info.description == "发送一个测试回复"
+ assert action_info.activation_keywords == ["测试", "hello"]
+ assert action_info.random_activation_probability == 0.25
+ assert action_info.parallel_action is True
+
+ executor = core_component_registry.get_action_executor(action_name)
+ assert executor is not None
+
+ success, reason = await executor(
+ action_data={"target": "MaiBot"},
+ action_reasoning="当前适合使用这个动作",
+ cycle_timers={"planner": 0.1},
+ thinking_id="tid-1",
+ chat_stream=SimpleNamespace(session_id="stream-1"),
+ log_prefix="[test]",
+ shutting_down=False,
+ plugin_config={"enabled": True},
+ )
+
+ assert success is True
+ assert reason == "runtime action executed"
+ assert captured["method"] == "plugin.invoke_action"
+ assert captured["plugin_id"] == plugin_id
+ assert captured["component_name"] == action_name
+ assert captured["args"]["stream_id"] == "stream-1"
+ assert captured["args"]["chat_id"] == "stream-1"
+ assert captured["args"]["reasoning"] == "当前适合使用这个动作"
+ assert captured["args"]["target"] == "MaiBot"
+ assert captured["args"]["action_data"] == {"target": "MaiBot"}
+ finally:
+ supervisor._remove_core_action_mirrors(plugin_id)
+ core_component_registry.remove_action(action_name)
+
+
+def test_clear_runner_state_removes_mirrored_runtime_actions() -> None:
+ """清理 Runner 状态时应同步移除旧核心注册表中的镜像 Action。"""
+ plugin_id = "runtime_action_bridge_cleanup_plugin"
+ action_name = "runtime_action_bridge_cleanup_test"
+ payload = _build_action_payload(plugin_id=plugin_id, action_name=action_name)
+ supervisor = PluginSupervisor(plugin_dirs=[])
+
+ core_component_registry.remove_action(action_name)
+
+ supervisor._mirror_runtime_actions_to_core_registry(payload)
+ assert core_component_registry.get_action_info(action_name) is not None
+
+ supervisor._clear_runner_state()
+
+ assert core_component_registry.get_action_info(action_name) is None
diff --git a/src/chat/message_receive/uni_message_sender.py b/src/chat/message_receive/uni_message_sender.py
index 17d5d6d5..df74e459 100644
--- a/src/chat/message_receive/uni_message_sender.py
+++ b/src/chat/message_receive/uni_message_sender.py
@@ -125,23 +125,27 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
return True
try:
- from src.platform_io import DeliveryStatus
from src.plugin_runtime.integration import get_plugin_runtime_manager
- receipt = await get_plugin_runtime_manager().try_send_message_via_platform_io(message)
- if receipt is not None:
- if receipt.status == DeliveryStatus.SENT:
+ delivery_batch = await get_plugin_runtime_manager().try_send_message_via_platform_io(message)
+ if delivery_batch is not None:
+ if delivery_batch.has_success:
+ successful_driver_ids = [
+ receipt.driver_id or "unknown"
+ for receipt in delivery_batch.sent_receipts
+ ]
if show_log:
logger.info(
f"已通过 Platform IO 将消息 '{message_preview}' 发往平台'{platform}' "
- f"(driver: {receipt.driver_id or 'unknown'})"
+ f"(drivers: {', '.join(successful_driver_ids)})"
)
return True
- logger.warning(
- f"Platform IO 发送失败: platform={platform} driver={receipt.driver_id} "
- f"status={receipt.status} error={receipt.error}"
- )
+ failed_details = "; ".join(
+ f"driver={receipt.driver_id} status={receipt.status} error={receipt.error}"
+ for receipt in delivery_batch.failed_receipts
+ ) or "未命中任何发送路由"
+ logger.warning(f"Platform IO 发送失败: platform={platform} {failed_details}")
return False
except Exception as exc:
logger.warning(f"检查 Platform IO 出站链路时出现异常,将回退旧发送链: {exc}")
diff --git a/src/platform_io/__init__.py b/src/platform_io/__init__.py
index 380ecbb6..c91535d1 100644
--- a/src/platform_io/__init__.py
+++ b/src/platform_io/__init__.py
@@ -6,8 +6,9 @@
from .manager import PlatformIOManager, get_platform_io_manager
from .route_key_factory import RouteKeyFactory
-from .routing import RouteBindingConflictError, RouteTable
+from .routing import RouteTable
from .types import (
+ DeliveryBatch,
DeliveryReceipt,
DeliveryStatus,
DriverDescriptor,
@@ -15,10 +16,10 @@ from .types import (
InboundMessageEnvelope,
RouteBinding,
RouteKey,
- RouteMode,
)
__all__ = [
+ "DeliveryBatch",
"DeliveryReceipt",
"DeliveryStatus",
"DriverDescriptor",
@@ -27,9 +28,7 @@ __all__ = [
"PlatformIOManager",
"RouteKeyFactory",
"RouteBinding",
- "RouteBindingConflictError",
"RouteKey",
- "RouteMode",
"RouteTable",
"get_platform_io_manager",
]
diff --git a/src/platform_io/drivers/plugin_driver.py b/src/platform_io/drivers/plugin_driver.py
index dff980f8..c03204ad 100644
--- a/src/platform_io/drivers/plugin_driver.py
+++ b/src/platform_io/drivers/plugin_driver.py
@@ -1,4 +1,4 @@
-"""提供 Platform IO 的插件适配器驱动实现。"""
+"""提供 Platform IO 的插件消息网关驱动实现。"""
from typing import TYPE_CHECKING, Any, Dict, Optional, Protocol
@@ -9,45 +9,49 @@ if TYPE_CHECKING:
from src.chat.message_receive.message import SessionMessage
-class _AdapterSupervisorProtocol(Protocol):
- """适配器驱动依赖的 Supervisor 最小协议。"""
+class _GatewaySupervisorProtocol(Protocol):
+ """消息网关驱动依赖的 Supervisor 最小协议。"""
- async def invoke_adapter(
+ async def invoke_message_gateway(
self,
plugin_id: str,
- method_name: str,
+ component_name: str,
args: Optional[Dict[str, Any]] = None,
timeout_ms: int = 30000,
) -> Any:
- """调用适配器插件专用方法。"""
+ """调用插件声明的消息网关方法。"""
class PluginPlatformDriver(PlatformIODriver):
- """面向适配器插件链路的 Platform IO 驱动。"""
+ """面向插件消息网关链路的 Platform IO 驱动。"""
def __init__(
self,
driver_id: str,
platform: str,
- supervisor: _AdapterSupervisorProtocol,
- send_method: str = "send_to_platform",
+ supervisor: _GatewaySupervisorProtocol,
+ component_name: str,
+ *,
+ supports_send: bool,
account_id: Optional[str] = None,
scope: Optional[str] = None,
plugin_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> None:
- """初始化一个插件适配器驱动。
+ """初始化一个插件消息网关驱动。
Args:
driver_id: Broker 内的唯一驱动 ID。
- platform: 该适配器负责的平台名称。
- supervisor: 持有该适配器插件的 Supervisor。
- send_method: 出站发送时要调用的插件方法名。
+ platform: 该消息网关负责的平台名称。
+ supervisor: 持有该插件的 Supervisor。
+ component_name: 出站时要调用的网关组件名称。
+ supports_send: 当前驱动是否具备出站能力。
account_id: 可选的账号 ID 或 self ID。
scope: 可选的额外路由作用域。
- plugin_id: 拥有该适配器实现的插件 ID。
+ plugin_id: 拥有该实现的插件 ID。
metadata: 可选的额外驱动元数据。
"""
+
descriptor = DriverDescriptor(
driver_id=driver_id,
kind=DriverKind.PLUGIN,
@@ -59,7 +63,8 @@ class PluginPlatformDriver(PlatformIODriver):
)
super().__init__(descriptor)
self._supervisor = supervisor
- self._send_method = send_method
+ self._component_name = component_name
+ self._supports_send = supports_send
async def send_message(
self,
@@ -67,16 +72,27 @@ class PluginPlatformDriver(PlatformIODriver):
route_key: RouteKey,
metadata: Optional[Dict[str, Any]] = None,
) -> DeliveryReceipt:
- """通过适配器插件发送消息。
+ """通过插件消息网关发送消息。
Args:
message: 要投递的内部会话消息。
route_key: Broker 为本次投递选择的路由键。
- metadata: 本次出站投递可选的 Broker 侧元数据。
+ metadata: 可选的发送元数据。
Returns:
- DeliveryReceipt: 由驱动返回的规范化回执。
+ DeliveryReceipt: 规范化后的发送回执。
"""
+
+ if not self._supports_send:
+ return DeliveryReceipt(
+ internal_message_id=message.message_id,
+ route_key=route_key,
+ status=DeliveryStatus.FAILED,
+ driver_id=self.driver_id,
+ driver_kind=self.descriptor.kind,
+ error="当前消息网关仅支持接收,不支持发送",
+ )
+
from src.plugin_runtime.host.message_utils import PluginMessageUtils
plugin_id = self.descriptor.plugin_id or ""
@@ -87,14 +103,14 @@ class PluginPlatformDriver(PlatformIODriver):
status=DeliveryStatus.FAILED,
driver_id=self.driver_id,
driver_kind=self.descriptor.kind,
- error="插件适配器驱动缺少 plugin_id",
+ error="插件消息网关驱动缺少 plugin_id",
)
try:
message_dict = PluginMessageUtils._session_message_to_dict(message)
- response = await self._supervisor.invoke_adapter(
+ response = await self._supervisor.invoke_message_gateway(
plugin_id=plugin_id,
- method_name=self._send_method,
+ component_name=self._component_name,
args={
"message": message_dict,
"route": {
@@ -119,7 +135,7 @@ class PluginPlatformDriver(PlatformIODriver):
return self._build_receipt(message.message_id, route_key, response)
def _build_receipt(self, internal_message_id: str, route_key: RouteKey, response: Any) -> DeliveryReceipt:
- """将适配器调用响应归一化为出站回执。
+ """将网关调用响应归一化为出站回执。
Args:
internal_message_id: 内部消息 ID。
@@ -129,8 +145,9 @@ class PluginPlatformDriver(PlatformIODriver):
Returns:
DeliveryReceipt: 标准化后的出站回执。
"""
+
if getattr(response, "error", None):
- error = response.error.get("message", "适配器发送失败")
+ error = response.error.get("message", "消息网关发送失败")
return DeliveryReceipt(
internal_message_id=internal_message_id,
route_key=route_key,
@@ -149,7 +166,7 @@ class PluginPlatformDriver(PlatformIODriver):
status=DeliveryStatus.FAILED,
driver_id=self.driver_id,
driver_kind=self.descriptor.kind,
- error=str(payload.get("result", "适配器发送失败")) if isinstance(payload, dict) else "适配器发送失败",
+ error=str(payload.get("result", "消息网关发送失败")) if isinstance(payload, dict) else "消息网关发送失败",
)
result = payload.get("result") if isinstance(payload, dict) else None
@@ -161,7 +178,7 @@ class PluginPlatformDriver(PlatformIODriver):
status=DeliveryStatus.FAILED,
driver_id=self.driver_id,
driver_kind=self.descriptor.kind,
- error=str(result.get("error", "适配器发送失败")),
+ error=str(result.get("error", "消息网关发送失败")),
metadata=result.get("metadata", {}) if isinstance(result.get("metadata"), dict) else {},
)
external_message_id = str(result.get("external_message_id") or result.get("message_id") or "") or None
diff --git a/src/platform_io/manager.py b/src/platform_io/manager.py
index b1fe3bdc..c96a9ddd 100644
--- a/src/platform_io/manager.py
+++ b/src/platform_io/manager.py
@@ -10,7 +10,7 @@ from .outbound_tracker import OutboundTracker
from .route_key_factory import RouteKeyFactory
from .registry import DriverRegistry
from .routing import RouteTable
-from .types import DeliveryReceipt, DeliveryStatus, InboundMessageEnvelope, RouteBinding, RouteKey
+from .types import DeliveryBatch, DeliveryReceipt, DeliveryStatus, InboundMessageEnvelope, RouteBinding, RouteKey
if TYPE_CHECKING:
from src.chat.message_receive.message import SessionMessage
@@ -21,17 +21,21 @@ InboundDispatcher = Callable[[InboundMessageEnvelope], Awaitable[None]]
class PlatformIOManager:
- """统一协调双路径平台消息 IO 的路由、去重与状态跟踪。
+ """统一协调平台消息 IO 的路由、去重与状态跟踪。
- 这个管理器预期会成为 legacy 适配器链路与 plugin 适配器链路之间的
- 唯一裁决点。当前地基阶段,它只提供共享状态和 Broker 侧契约,还没有
- 真正把生产流量切到新中间层。
+ 与旧实现不同,这个管理器不再负责“多条链路谁该接管平台”的裁决,
+ 只维护发送表和接收表两张轻量路由表:
+
+ - 发送时:解析所有命中的发送绑定并全部投递。
+ - 接收时:只校验当前驱动是否已登记为可接收链路,然后全部放行给上层。
+ - 去重时:仅对单条链路做技术性重放抑制,不做跨链路语义去重。
"""
def __init__(self) -> None:
"""初始化 Broker 管理器及其内存状态。"""
self._driver_registry = DriverRegistry()
- self._route_table = RouteTable()
+ self._send_route_table = RouteTable()
+ self._receive_route_table = RouteTable()
self._deduplicator = MessageDeduplicator()
self._outbound_tracker = OutboundTracker()
self._inbound_dispatcher: Optional[InboundDispatcher] = None
@@ -152,13 +156,22 @@ class PlatformIOManager:
return self._driver_registry
@property
- def route_table(self) -> RouteTable:
- """返回管理器持有的路由绑定表。
+ def send_route_table(self) -> RouteTable:
+ """返回发送路由表。"""
- Returns:
- RouteTable: 用于归属解析的路由绑定表。
- """
- return self._route_table
+ return self._send_route_table
+
+ @property
+ def receive_route_table(self) -> RouteTable:
+ """返回接收路由表。"""
+
+ return self._receive_route_table
+
+ @property
+ def route_table(self) -> RouteTable:
+ """兼容旧接口,返回发送路由表。"""
+
+ return self._send_route_table
@property
def deduplicator(self) -> MessageDeduplicator:
@@ -257,15 +270,15 @@ class PlatformIOManager:
return None
removed_driver.clear_inbound_handler()
- self._route_table.remove_bindings_by_driver(driver_id)
+ self._send_route_table.remove_bindings_by_driver(driver_id)
+ self._receive_route_table.remove_bindings_by_driver(driver_id)
return removed_driver
- def bind_route(self, binding: RouteBinding, *, replace: bool = False) -> None:
- """为某个路由键绑定驱动。
+ def bind_send_route(self, binding: RouteBinding) -> None:
+ """为某个路由键绑定发送驱动。
Args:
binding: 要保存的路由绑定。
- replace: 是否允许替换已有的精确 active owner。
Raises:
ValueError: 当绑定引用了不存在的驱动,或者绑定与驱动描述不一致时抛出。
@@ -275,30 +288,78 @@ class PlatformIOManager:
raise ValueError(f"驱动 {binding.driver_id} 未注册,无法绑定路由")
self._validate_binding_against_driver(binding, driver)
- self._route_table.bind(binding, replace=replace)
+ self._send_route_table.bind(binding)
- def unbind_route(self, route_key: RouteKey, driver_id: Optional[str] = None) -> None:
- """移除一个或多个路由绑定。
+ def bind_receive_route(self, binding: RouteBinding) -> None:
+ """为某个路由键绑定接收驱动。
+
+ Args:
+ binding: 要保存的路由绑定。
+
+ Raises:
+ ValueError: 当绑定引用了不存在的驱动,或者绑定与驱动描述不一致时抛出。
+ """
+ driver = self._driver_registry.get(binding.driver_id)
+ if driver is None:
+ raise ValueError(f"驱动 {binding.driver_id} 未注册,无法绑定路由")
+
+ self._validate_binding_against_driver(binding, driver)
+ self._receive_route_table.bind(binding)
+
+ def bind_route(self, binding: RouteBinding) -> None:
+ """兼容旧接口,默认同时绑定发送表和接收表。"""
+
+ self.bind_send_route(binding)
+ self.bind_receive_route(binding)
+
+ def unbind_send_route(self, route_key: RouteKey, driver_id: Optional[str] = None) -> None:
+ """移除发送路由绑定。
Args:
route_key: 要移除绑定的路由键。
driver_id: 可选的特定驱动 ID。
"""
- self._route_table.unbind(route_key, driver_id)
- def resolve_driver(self, route_key: RouteKey) -> Optional[PlatformIODriver]:
- """解析某个路由键当前的 active 驱动。
+ self._send_route_table.unbind(route_key, driver_id)
+
+ def unbind_receive_route(self, route_key: RouteKey, driver_id: Optional[str] = None) -> None:
+ """移除接收路由绑定。
+
+ Args:
+ route_key: 要移除绑定的路由键。
+ driver_id: 可选的特定驱动 ID。
+ """
+
+ self._receive_route_table.unbind(route_key, driver_id)
+
+ def unbind_route(self, route_key: RouteKey, driver_id: Optional[str] = None) -> None:
+ """兼容旧接口,默认同时从发送表和接收表解绑。"""
+
+ self.unbind_send_route(route_key, driver_id)
+ self.unbind_receive_route(route_key, driver_id)
+
+ def resolve_drivers(self, route_key: RouteKey) -> List[PlatformIODriver]:
+ """解析某个路由键当前命中的全部发送驱动。
Args:
route_key: 要解析的路由键。
Returns:
- Optional[PlatformIODriver]: 若存在 active 驱动,则返回该驱动实例。
+ List[PlatformIODriver]: 当前命中的全部发送驱动。
"""
- active_binding = self._route_table.get_active_binding(route_key)
- if active_binding is None:
- return None
- return self._driver_registry.get(active_binding.driver_id)
+
+ drivers: List[PlatformIODriver] = []
+ for binding in self._send_route_table.resolve_bindings(route_key):
+ driver = self._driver_registry.get(binding.driver_id)
+ if driver is not None:
+ drivers.append(driver)
+ return drivers
+
+ def resolve_driver(self, route_key: RouteKey) -> Optional[PlatformIODriver]:
+ """兼容旧接口,返回首个命中的发送驱动。"""
+
+ drivers = self.resolve_drivers(route_key)
+ return drivers[0] if drivers else None
@staticmethod
def build_route_key_from_message(message: "SessionMessage") -> RouteKey:
@@ -335,9 +396,9 @@ class PlatformIOManager:
否则返回 ``False``。
"""
- if not self._route_table.accepts_inbound(envelope.route_key, envelope.driver_id):
+ if not self._receive_route_table.has_binding_for_driver(envelope.route_key, envelope.driver_id):
logger.info(
- "忽略非 active owner 的入站消息: route=%s driver=%s",
+ "忽略未登记到接收路由表的入站消息: route=%s driver=%s",
envelope.route_key,
envelope.driver_id,
)
@@ -361,8 +422,8 @@ class PlatformIOManager:
message: "SessionMessage",
route_key: RouteKey,
metadata: Optional[Dict[str, Any]] = None,
- ) -> DeliveryReceipt:
- """通过 Broker 选中的驱动发送一条消息。
+ ) -> DeliveryBatch:
+ """通过 Broker 选中的全部发送驱动广播一条消息。
Args:
message: 要投递的内部会话消息。
@@ -370,61 +431,54 @@ class PlatformIOManager:
metadata: 可选的额外 Broker 侧元数据。
Returns:
- DeliveryReceipt: 规范化后的出站回执。若路由不存在、驱动缺失,
- 或同一消息已存在未完成的出站跟踪,也会返回失败回执而不是抛异常。
+ DeliveryBatch: 规范化后的批量出站回执。
"""
+ drivers = self.resolve_drivers(route_key)
+ if not drivers:
+ return DeliveryBatch(internal_message_id=message.message_id, route_key=route_key)
- active_binding = self._route_table.get_active_binding(route_key)
- if active_binding is None:
- return DeliveryReceipt(
- internal_message_id=message.message_id,
- route_key=route_key,
- status=DeliveryStatus.FAILED,
- error="未找到 active 路由绑定",
- )
+ receipts: List[DeliveryReceipt] = []
+ for driver in drivers:
+ try:
+ self._outbound_tracker.begin_tracking(
+ internal_message_id=message.message_id,
+ route_key=route_key,
+ driver_id=driver.driver_id,
+ metadata=metadata,
+ )
+ except ValueError as exc:
+ receipts.append(
+ DeliveryReceipt(
+ internal_message_id=message.message_id,
+ route_key=route_key,
+ status=DeliveryStatus.FAILED,
+ driver_id=driver.driver_id,
+ driver_kind=driver.descriptor.kind,
+ error=str(exc),
+ )
+ )
+ continue
- driver = self._driver_registry.get(active_binding.driver_id)
- if driver is None:
- return DeliveryReceipt(
- internal_message_id=message.message_id,
- route_key=route_key,
- status=DeliveryStatus.FAILED,
- driver_id=active_binding.driver_id,
- driver_kind=active_binding.driver_kind,
- error="active 路由绑定对应的驱动不存在",
- )
+ try:
+ receipt = await driver.send_message(message=message, route_key=route_key, metadata=metadata)
+ except Exception as exc:
+ receipt = DeliveryReceipt(
+ internal_message_id=message.message_id,
+ route_key=route_key,
+ status=DeliveryStatus.FAILED,
+ driver_id=driver.driver_id,
+ driver_kind=driver.descriptor.kind,
+ error=str(exc),
+ )
- try:
- self._outbound_tracker.begin_tracking(
- internal_message_id=message.message_id,
- route_key=route_key,
- driver_id=driver.driver_id,
- metadata=metadata,
- )
- except ValueError as exc:
- return DeliveryReceipt(
- internal_message_id=message.message_id,
- route_key=route_key,
- status=DeliveryStatus.FAILED,
- driver_id=driver.driver_id,
- driver_kind=driver.descriptor.kind,
- error=str(exc),
- )
+ self._outbound_tracker.finish_tracking(receipt)
+ receipts.append(receipt)
- try:
- receipt = await driver.send_message(message=message, route_key=route_key, metadata=metadata)
- except Exception as exc:
- receipt = DeliveryReceipt(
- internal_message_id=message.message_id,
- route_key=route_key,
- status=DeliveryStatus.FAILED,
- driver_id=driver.driver_id,
- driver_kind=driver.descriptor.kind,
- error=str(exc),
- )
-
- self._outbound_tracker.finish_tracking(receipt)
- return receipt
+ return DeliveryBatch(
+ internal_message_id=message.message_id,
+ route_key=route_key,
+ receipts=receipts,
+ )
@staticmethod
def _build_inbound_dedupe_key(envelope: InboundMessageEnvelope) -> Optional[str]:
@@ -453,7 +507,7 @@ class PlatformIOManager:
if not normalized_dedupe_key:
return None
- return f"{envelope.route_key.to_dedupe_scope()}:{normalized_dedupe_key}"
+ return f"{envelope.driver_id}:{normalized_dedupe_key}"
@staticmethod
def _validate_binding_against_driver(binding: RouteBinding, driver: PlatformIODriver) -> None:
diff --git a/src/platform_io/routing.py b/src/platform_io/routing.py
index 7f85bbfa..2a9b41ef 100644
--- a/src/platform_io/routing.py
+++ b/src/platform_io/routing.py
@@ -1,52 +1,29 @@
-"""提供 Platform IO 的路由绑定存储与归属解析能力。"""
+"""提供 Platform IO 的轻量路由绑定表。"""
from typing import Dict, List, Optional
-from .types import RouteBinding, RouteKey, RouteMode
-
-
-class RouteBindingConflictError(ValueError):
- """当同一路由键出现多个 active owner 竞争时抛出。"""
+from .types import RouteBinding, RouteKey
class RouteTable:
- """维护路由绑定并解析路由归属。
+ """维护单张路由绑定表。
- 这个表刻意保持轻量,只负责归属规则本身,不掺杂具体发送或接收逻辑。
- 它决定某个路由键当前由哪个驱动 active 接管,哪些驱动仅以 shadow
- 方式旁路观测。
+ 该实现不负责裁决“唯一 owner”,只负责保存绑定,并按
+ ``RouteKey.resolution_order()`` 解析出候选绑定列表。
"""
def __init__(self) -> None:
- """初始化一个空的路由绑定表。"""
+ """初始化空路由绑定表。"""
+
self._bindings: Dict[RouteKey, Dict[str, RouteBinding]] = {}
- def bind(self, binding: RouteBinding, *, replace: bool = False) -> None:
+ def bind(self, binding: RouteBinding) -> None:
"""注册或更新一条路由绑定。
Args:
- binding: 要注册的绑定对象。
- replace: 当精确路由键上已经存在 active owner 时,是否允许替换。
-
- Raises:
- RouteBindingConflictError: 当精确路由键上已存在其他 active owner,
- 且 ``replace`` 为 ``False`` 时抛出。
+ binding: 要保存的路由绑定。
"""
- if binding.mode == RouteMode.DISABLED:
- self.unbind(binding.route_key, binding.driver_id)
- return
-
- if binding.mode == RouteMode.ACTIVE:
- active_binding = self.get_active_binding(binding.route_key, exact_only=True)
- if active_binding and active_binding.driver_id != binding.driver_id:
- if not replace:
- raise RouteBindingConflictError(
- f"RouteKey {binding.route_key} 已由 {active_binding.driver_id} 接管,"
- f"拒绝绑定到 {binding.driver_id}"
- )
- self.unbind(binding.route_key, active_binding.driver_id)
-
self._bindings.setdefault(binding.route_key, {})[binding.driver_id] = binding
def unbind(self, route_key: RouteKey, driver_id: Optional[str] = None) -> List[RouteBinding]:
@@ -54,7 +31,7 @@ class RouteTable:
Args:
route_key: 要移除绑定的路由键。
- driver_id: 可选的特定驱动 ID;若为空,则移除该路由键上的全部绑定。
+ driver_id: 可选的驱动 ID;为空时移除该路由键下全部绑定。
Returns:
List[RouteBinding]: 被移除的绑定列表。
@@ -67,15 +44,15 @@ class RouteTable:
if driver_id is None:
removed = list(binding_map.values())
self._bindings.pop(route_key, None)
- return removed
+ return self._sort_bindings(removed)
removed_binding = binding_map.pop(driver_id, None)
if not binding_map:
self._bindings.pop(route_key, None)
- return [removed_binding] if removed_binding else []
+ return [removed_binding] if removed_binding is not None else []
def remove_bindings_by_driver(self, driver_id: str) -> List[RouteBinding]:
- """移除某个驱动在所有路由键上的绑定。
+ """移除某个驱动在整张表上的全部绑定。
Args:
driver_id: 要移除绑定的驱动 ID。
@@ -83,9 +60,9 @@ class RouteTable:
Returns:
List[RouteBinding]: 被移除的绑定列表。
"""
+
removed_bindings: List[RouteBinding] = []
empty_route_keys: List[RouteKey] = []
-
for route_key, binding_map in self._bindings.items():
removed_binding = binding_map.pop(driver_id, None)
if removed_binding is not None:
@@ -99,13 +76,13 @@ class RouteTable:
return self._sort_bindings(removed_bindings)
def list_bindings(self, route_key: Optional[RouteKey] = None) -> List[RouteBinding]:
- """列出当前绑定。
+ """列出当前路由表中的绑定。
Args:
- route_key: 可选的路由键过滤条件;若为空,则返回全部路由键上的绑定。
+ route_key: 可选的路由键过滤条件。
Returns:
- List[RouteBinding]: 按优先级降序排列的绑定列表。
+ List[RouteBinding]: 当前绑定列表。
"""
if route_key is None:
@@ -117,51 +94,38 @@ class RouteTable:
binding_map = self._bindings.get(route_key, {})
return self._sort_bindings(list(binding_map.values()))
- def get_active_binding(self, route_key: RouteKey, *, exact_only: bool = False) -> Optional[RouteBinding]:
- """获取某个路由键当前生效的 active 绑定。
+ def resolve_bindings(self, route_key: RouteKey) -> List[RouteBinding]:
+ """按从具体到宽泛的顺序解析路由候选绑定。
Args:
- route_key: 要解析的路由键。
- exact_only: 是否只检查精确路由键而不做回退解析。
+ route_key: 待解析的路由键。
Returns:
- Optional[RouteBinding]: 若存在 active owner,则返回对应绑定。
+ List[RouteBinding]: 去重后的候选绑定列表。
"""
- candidate_keys = [route_key] if exact_only else route_key.resolution_order()
- for candidate_key in candidate_keys:
- binding_map = self._bindings.get(candidate_key, {})
- active_binding = self._pick_best_binding(binding_map, RouteMode.ACTIVE)
- if active_binding is not None:
- return active_binding
- return None
+ resolved_bindings: List[RouteBinding] = []
+ seen_driver_ids: set[str] = set()
+ for candidate_key in route_key.resolution_order():
+ for binding in self.list_bindings(candidate_key):
+ if binding.driver_id in seen_driver_ids:
+ continue
+ seen_driver_ids.add(binding.driver_id)
+ resolved_bindings.append(binding)
+ return resolved_bindings
- def get_shadow_bindings(self, route_key: RouteKey) -> List[RouteBinding]:
- """获取某个精确路由键上的 shadow 绑定。
+ def has_binding_for_driver(self, route_key: RouteKey, driver_id: str) -> bool:
+ """判断指定驱动是否在当前路由键解析结果中。
Args:
- route_key: 要查看的路由键。
+ route_key: 待解析的路由键。
+ driver_id: 目标驱动 ID。
Returns:
- List[RouteBinding]: 按优先级降序排列的 shadow 绑定列表。
- """
- binding_map = self._bindings.get(route_key, {})
- shadow_bindings = [binding for binding in binding_map.values() if binding.mode == RouteMode.SHADOW]
- return self._sort_bindings(shadow_bindings)
-
- def accepts_inbound(self, route_key: RouteKey, driver_id: str) -> bool:
- """判断某个驱动是否是当前允许入 Core 的 active owner。
-
- Args:
- route_key: 入站消息对应的路由键。
- driver_id: 希望将消息送入 Core 的驱动 ID。
-
- Returns:
- bool: 若该驱动是解析结果中的 active owner,则返回 ``True``。
+ bool: 若驱动存在于解析结果中则返回 ``True``。
"""
- active_binding = self.get_active_binding(route_key)
- return active_binding is not None and active_binding.driver_id == driver_id
+ return any(binding.driver_id == driver_id for binding in self.resolve_bindings(route_key))
@staticmethod
def _sort_bindings(bindings: List[RouteBinding]) -> List[RouteBinding]:
@@ -173,30 +137,5 @@ class RouteTable:
Returns:
List[RouteBinding]: 排序后的绑定列表。
"""
+
return sorted(bindings, key=lambda item: item.priority, reverse=True)
-
- @staticmethod
- def _pick_best_binding(
- binding_map: Dict[str, RouteBinding],
- mode: RouteMode,
- ) -> Optional[RouteBinding]:
- """从绑定映射中挑选指定模式下优先级最高的一条绑定。
-
- Args:
- binding_map: 某个精确 ``RouteKey`` 对应的绑定映射。
- mode: 需要挑选的绑定模式。
-
- Returns:
- Optional[RouteBinding]: 若存在匹配模式的绑定,则返回优先级最高的一条。
-
- Notes:
- 这里使用单次线性扫描代替“先过滤成列表再排序”的做法,以减少
- 高频路由解析路径上的临时对象分配和排序开销。
- """
- best_binding: Optional[RouteBinding] = None
- for binding in binding_map.values():
- if binding.mode != mode:
- continue
- if best_binding is None or binding.priority > best_binding.priority:
- best_binding = binding
- return best_binding
diff --git a/src/platform_io/types.py b/src/platform_io/types.py
index 8729b637..200eca51 100644
--- a/src/platform_io/types.py
+++ b/src/platform_io/types.py
@@ -19,14 +19,6 @@ class DriverKind(str, Enum):
PLUGIN = "plugin"
-class RouteMode(str, Enum):
- """路由归属模式枚举。"""
-
- ACTIVE = "active"
- SHADOW = "shadow"
- DISABLED = "disabled"
-
-
class DeliveryStatus(str, Enum):
"""统一出站回执状态枚举。"""
@@ -158,21 +150,19 @@ class DriverDescriptor:
@dataclass(frozen=True, slots=True)
class RouteBinding:
- """表示一条从路由键到驱动的归属绑定关系。
+ """表示一条从路由键到驱动的绑定关系。
Attributes:
route_key: 该绑定覆盖的路由键。
- driver_id: 拥有或旁路观察该路由的驱动 ID。
+ driver_id: 拥有该路由的驱动 ID。
driver_kind: 绑定驱动的类型。
- mode: 绑定模式,例如 active owner 或 shadow observer。
- priority: 当同模式下存在多条绑定时使用的相对优先级。
+ priority: 当同一路由键存在多条绑定时使用的相对优先级。
metadata: 预留给未来路由策略的额外绑定元数据。
"""
route_key: RouteKey
driver_id: str
driver_kind: DriverKind
- mode: RouteMode = RouteMode.ACTIVE
priority: int = 0
metadata: Dict[str, Any] = field(default_factory=dict)
@@ -239,3 +229,36 @@ class DeliveryReceipt:
external_message_id: Optional[str] = None
error: Optional[str] = None
metadata: Dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass(slots=True)
+class DeliveryBatch:
+ """表示一次广播式出站投递的批量结果。
+
+ Attributes:
+ internal_message_id: 内部消息 ID。
+ route_key: 本次投递使用的路由键。
+ receipts: 各条路由的独立投递回执列表。
+ """
+
+ internal_message_id: str
+ route_key: RouteKey
+ receipts: List[DeliveryReceipt] = field(default_factory=list)
+
+ @property
+ def sent_receipts(self) -> List[DeliveryReceipt]:
+ """返回全部发送成功的回执。"""
+
+ return [receipt for receipt in self.receipts if receipt.status == DeliveryStatus.SENT]
+
+ @property
+ def failed_receipts(self) -> List[DeliveryReceipt]:
+ """返回全部发送失败的回执。"""
+
+ return [receipt for receipt in self.receipts if receipt.status != DeliveryStatus.SENT]
+
+ @property
+ def has_success(self) -> bool:
+ """返回当前批量投递是否至少命中一条成功回执。"""
+
+ return bool(self.sent_receipts)
diff --git a/src/plugin_runtime/host/component_registry.py b/src/plugin_runtime/host/component_registry.py
index 95da0052..08b0ea3b 100644
--- a/src/plugin_runtime/host/component_registry.py
+++ b/src/plugin_runtime/host/component_registry.py
@@ -119,12 +119,52 @@ class MessageGatewayEntry(ComponentEntry):
"""MessageGateway 组件条目"""
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
- platform = metadata.get("platform")
- if not platform or not isinstance(platform, str):
- raise ValueError(f"MessageGateway 组件 {plugin_id}.{name} 缺少有效的 platform 字段")
- self.platform: str = platform
+ self.route_type: str = self._normalize_route_type(metadata.get("route_type", ""))
+ self.platform: str = str(metadata.get("platform", "") or "").strip()
+ self.protocol: str = str(metadata.get("protocol", "") or "").strip()
+ self.account_id: str = str(metadata.get("account_id", "") or "").strip()
+ self.scope: str = str(metadata.get("scope", "") or "").strip()
super().__init__(name, component_type, plugin_id, metadata)
+ @staticmethod
+ def _normalize_route_type(raw_value: Any) -> str:
+ """规范化消息网关路由类型。
+
+ Args:
+ raw_value: 原始路由类型值。
+
+ Returns:
+ str: 规范化后的路由类型。
+
+ Raises:
+ ValueError: 当路由类型不受支持时抛出。
+ """
+
+ normalized_value = str(raw_value or "").strip().lower()
+ route_type_aliases = {
+ "send": "send",
+ "receive": "receive",
+ "recv": "receive",
+ "recive": "receive",
+ "duplex": "duplex",
+ }
+ route_type = route_type_aliases.get(normalized_value)
+ if route_type is None:
+ raise ValueError(f"MessageGateway 路由类型不合法: {raw_value}")
+ return route_type
+
+ @property
+ def supports_send(self) -> bool:
+ """返回当前网关是否支持出站。"""
+
+ return self.route_type in {"send", "duplex"}
+
+ @property
+ def supports_receive(self) -> bool:
+ """返回当前网关是否支持入站。"""
+
+ return self.route_type in {"receive", "duplex"}
+
class ComponentRegistry:
"""Host-side 组件注册表
@@ -404,26 +444,71 @@ class ComponentRegistry:
handlers.sort(key=lambda c: c.priority, reverse=True)
return handlers
- def get_message_gateways(
- self, platform: str, *, enabled_only: bool = True, session_id: Optional[str] = None
+ def get_message_gateway(
+ self,
+ plugin_id: str,
+ name: str,
+ *,
+ enabled_only: bool = True,
+ session_id: Optional[str] = None,
) -> Optional[MessageGatewayEntry]:
- """查询消息网关组件。
+ """按插件和组件名获取单个消息网关。
Args:
- platform (str): 平台名称
- enabled_only (bool): 是否仅返回启用的组件
- session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态
+ plugin_id: 插件 ID。
+ name: 网关组件名称。
+ enabled_only: 是否仅返回启用的组件。
+ session_id: 可选的会话 ID。
+
Returns:
- gateway (Optional[MessageGatewayEntry]): 符合条件的 MessageGateway 组件,可能不存在
+ Optional[MessageGatewayEntry]: 若存在则返回消息网关条目。
"""
+ component = self._components.get(f"{plugin_id}.{name}")
+ if not isinstance(component, MessageGatewayEntry):
+ return None
+ if enabled_only and not self.check_component_enabled(component, session_id):
+ return None
+ return component
+
+ def get_message_gateways(
+ self,
+ *,
+ plugin_id: Optional[str] = None,
+ platform: str = "",
+ route_type: str = "",
+ enabled_only: bool = True,
+ session_id: Optional[str] = None,
+ ) -> List[MessageGatewayEntry]:
+ """查询消息网关组件列表。
+
+ Args:
+ plugin_id: 可选的插件 ID 过滤条件。
+ platform: 可选的平台过滤条件。
+ route_type: 可选的路由类型过滤条件。
+ enabled_only: 是否仅返回启用的组件。
+ session_id: 可选的会话 ID。
+
+ Returns:
+ List[MessageGatewayEntry]: 符合条件的消息网关组件列表。
+ """
+
+ normalized_platform = str(platform or "").strip()
+ normalized_route_type = str(route_type or "").strip().lower()
+ gateways: List[MessageGatewayEntry] = []
for comp in self._by_type.get(ComponentTypes.MESSAGE_GATEWAY, {}).values():
if not isinstance(comp, MessageGatewayEntry):
continue
+ if plugin_id and comp.plugin_id != plugin_id:
+ continue
if enabled_only and not self.check_component_enabled(comp, session_id):
continue
- if comp.platform == platform:
- return comp # 返回第一个
+ if normalized_platform and comp.platform != normalized_platform:
+ continue
+ if normalized_route_type and comp.route_type != normalized_route_type:
+ continue
+ gateways.append(comp)
+ return gateways
def get_tools(self, *, enabled_only: bool = True, session_id: Optional[str] = None) -> List[ToolEntry]:
"""查询所有工具组件。
diff --git a/src/plugin_runtime/host/message_gateway.py b/src/plugin_runtime/host/message_gateway.py
index 9e8e9be6..90f94493 100644
--- a/src/plugin_runtime/host/message_gateway.py
+++ b/src/plugin_runtime/host/message_gateway.py
@@ -1,12 +1,9 @@
-"""
-Message Gateway 模块
-适配器专用,用于将其他平台的消息转换为系统内部的消息格式,并将系统消息转换为其他平台的格式。
-"""
+"""Host 侧消息网关包装器。"""
from typing import TYPE_CHECKING, Any, Dict
from src.common.logger import get_logger
-from src.platform_io import DeliveryStatus, get_platform_io_manager
+from src.platform_io import get_platform_io_manager
from .message_utils import PluginMessageUtils
@@ -50,7 +47,7 @@ class MessageGateway:
internal_message: 内部消息对象。
Returns:
- Dict[str, Any]: 供适配器插件消费的标准消息字典。
+ Dict[str, Any]: 供消息网关插件消费的标准消息字典。
"""
return dict(PluginMessageUtils._session_message_to_dict(internal_message))
@@ -83,7 +80,7 @@ class MessageGateway:
Args:
internal_message: 系统内部的 ``SessionMessage`` 对象。
supervisor: 当前持有该消息网关的 Supervisor。
- enabled_only: 兼容旧签名的保留参数,当前由 Platform IO 统一裁决。
+ enabled_only: 兼容旧签名的保留参数,当前未使用。
save_to_db: 发送成功后是否写入数据库。
Returns:
@@ -98,12 +95,13 @@ class MessageGateway:
return False
route_key = platform_io_manager.build_route_key_from_message(internal_message)
- receipt = await platform_io_manager.send_message(internal_message, route_key)
- if receipt.status != DeliveryStatus.SENT:
- logger.warning(f"通过适配器链路发送消息失败: {receipt.error or receipt.status}")
+ delivery_batch = await platform_io_manager.send_message(internal_message, route_key)
+ if not delivery_batch.has_success:
+ logger.warning("通过消息网关链路发送消息失败: 未命中任何成功回执")
return False
- internal_message.message_id = receipt.external_message_id or internal_message.message_id
+ first_successful_receipt = delivery_batch.sent_receipts[0]
+ internal_message.message_id = first_successful_receipt.external_message_id or internal_message.message_id
if save_to_db:
try:
from src.common.utils.utils_message import MessageUtils
diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py
index 8a26af11..3588934e 100644
--- a/src/plugin_runtime/host/supervisor.py
+++ b/src/plugin_runtime/host/supervisor.py
@@ -9,24 +9,24 @@ import sys
from src.common.logger import get_logger
from src.config.config import global_config
-from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, RouteMode, get_platform_io_manager
+from src.core.component_registry import component_registry as core_component_registry
+from src.core.types import ActionActivationType, ActionInfo, ComponentType as CoreComponentType
+from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, get_platform_io_manager
from src.platform_io.drivers import PluginPlatformDriver
from src.platform_io.route_key_factory import RouteKeyFactory
-from src.platform_io.routing import RouteBindingConflictError
from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN
from src.plugin_runtime.protocol.envelope import (
- AdapterDeclarationPayload,
- AdapterStateUpdatePayload,
- AdapterStateUpdateResultPayload,
BootstrapPluginPayload,
ConfigUpdatedPayload,
Envelope,
HealthPayload,
+ MessageGatewayStateUpdatePayload,
+ MessageGatewayStateUpdateResultPayload,
PROTOCOL_VERSION,
- ReceiveExternalMessagePayload,
ReceiveExternalMessageResultPayload,
RegisterPluginPayload,
ReloadPluginResultPayload,
+ RouteMessagePayload,
RunnerReadyPayload,
ShutdownPayload,
UnregisterPluginPayload,
@@ -49,15 +49,12 @@ if TYPE_CHECKING:
logger = get_logger("plugin_runtime.host.runner_manager")
-_ADAPTER_BINDING_ROLE_RUNTIME_EXACT = "runtime_exact"
-_ADAPTER_BINDING_ROLE_PLATFORM_DEFAULT = "platform_default"
-
-
@dataclass(slots=True)
-class _AdapterRuntimeState:
- """保存适配器插件当前的运行时连接状态。"""
+class _MessageGatewayRuntimeState:
+ """保存消息网关当前的运行时连接状态。"""
- connected: bool = False
+ ready: bool = False
+ platform: Optional[str] = None
account_id: Optional[str] = None
scope: Optional[str] = None
metadata: Dict[str, Any] = field(default_factory=dict)
@@ -109,8 +106,8 @@ class PluginRunnerSupervisor:
self._runner_process: Optional[asyncio.subprocess.Process] = None
self._registered_plugins: Dict[str, RegisterPluginPayload] = {}
- self._registered_adapters: Dict[str, AdapterDeclarationPayload] = {}
- self._adapter_runtime_states: Dict[str, _AdapterRuntimeState] = {}
+ self._message_gateway_states: Dict[str, Dict[str, _MessageGatewayRuntimeState]] = {}
+ self._mirrored_core_actions: Dict[str, List[str]] = {}
self._runner_ready_events: asyncio.Event = asyncio.Event()
self._runner_ready_payloads: RunnerReadyPayload = RunnerReadyPayload()
self._health_task: Optional[asyncio.Task[None]] = None
@@ -289,28 +286,29 @@ class PluginRunnerSupervisor:
timeout_ms,
)
- async def invoke_adapter(
+ async def invoke_message_gateway(
self,
plugin_id: str,
- method_name: str,
+ component_name: str,
args: Optional[Dict[str, Any]] = None,
timeout_ms: int = 30000,
) -> Envelope:
- """调用适配器插件的专用方法。
+ """调用插件声明的消息网关方法。
Args:
- plugin_id: 目标适配器插件 ID。
- method_name: 要调用的插件方法名,例如 ``send_to_platform``。
- args: 传递给插件方法的关键字参数。
+ plugin_id: 目标插件 ID。
+ component_name: 消息网关组件名称。
+ args: 传递给网关方法的关键字参数。
timeout_ms: RPC 超时时间,单位毫秒。
Returns:
Envelope: Runner 返回的响应信封。
"""
+
return await self.invoke_plugin(
- method="plugin.invoke_adapter",
+ method="plugin.invoke_message_gateway",
plugin_id=plugin_id,
- component_name=method_name,
+ component_name=component_name,
args=args,
timeout_ms=timeout_ms,
)
@@ -468,8 +466,8 @@ class PluginRunnerSupervisor:
def _register_internal_methods(self) -> None:
"""注册 Host 侧内部 RPC 方法。"""
self._rpc_server.register_method("cap.call", self._capability_service.handle_capability_request)
- self._rpc_server.register_method("host.receive_external_message", self._handle_receive_external_message)
- self._rpc_server.register_method("host.update_adapter_state", self._handle_update_adapter_state)
+ self._rpc_server.register_method("host.route_message", self._handle_route_message)
+ self._rpc_server.register_method("host.update_message_gateway_state", self._handle_update_message_gateway_state)
self._rpc_server.register_method("plugin.bootstrap", self._handle_bootstrap_plugin)
self._rpc_server.register_method("plugin.register_components", self._handle_register_plugin)
self._rpc_server.register_method("plugin.register_plugin", self._handle_register_plugin)
@@ -512,30 +510,26 @@ class PluginRunnerSupervisor:
except Exception as exc:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
+ self._remove_core_action_mirrors(payload.plugin_id)
self._component_registry.remove_components_by_plugin(payload.plugin_id)
- if payload.plugin_id in self._registered_adapters:
- await self._unregister_adapter_driver(payload.plugin_id)
-
- try:
- if payload.adapter is not None:
- await self._register_adapter_driver(payload.plugin_id, payload.adapter)
- except RouteBindingConflictError as exc:
- return envelope.make_error_response(ErrorCode.E_METHOD_NOT_ALLOWED.value, str(exc))
- except Exception as exc:
- return envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(exc))
+ await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id)
registered_count = self._component_registry.register_plugin_components(
payload.plugin_id,
[component.model_dump() for component in payload.components],
)
self._registered_plugins[payload.plugin_id] = payload
+ self._message_gateway_states[payload.plugin_id] = {}
+ self._mirror_runtime_actions_to_core_registry(payload)
return envelope.make_response(
payload={
"accepted": True,
"plugin_id": payload.plugin_id,
"registered_components": registered_count,
- "adapter_registered": payload.adapter is not None,
+ "message_gateways": len(
+ self._component_registry.get_message_gateways(plugin_id=payload.plugin_id, enabled_only=False)
+ ),
}
)
@@ -556,7 +550,9 @@ class PluginRunnerSupervisor:
removed_components = self._component_registry.remove_components_by_plugin(payload.plugin_id)
self._authorization.revoke_permission_token(payload.plugin_id)
removed_registration = self._registered_plugins.pop(payload.plugin_id, None) is not None
- await self._unregister_adapter_driver(payload.plugin_id)
+ self._remove_core_action_mirrors(payload.plugin_id)
+ await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id)
+ self._message_gateway_states.pop(payload.plugin_id, None)
return envelope.make_response(
payload={
@@ -569,41 +565,321 @@ class PluginRunnerSupervisor:
)
@staticmethod
- def _build_adapter_driver_id(plugin_id: str) -> str:
- """构造适配器驱动 ID。
+ def _coerce_action_activation_type(raw_value: Any) -> ActionActivationType:
+ """将运行时 Action 激活类型转换为旧核心枚举。
Args:
- plugin_id: 适配器插件 ID。
+ raw_value: 插件运行时声明中的激活类型值。
+
+ Returns:
+ ActionActivationType: 可供旧 Planner 使用的激活类型枚举。
+ """
+ normalized_value = str(raw_value or ActionActivationType.ALWAYS.value).strip().lower()
+ try:
+ return ActionActivationType(normalized_value)
+ except ValueError:
+ return ActionActivationType.ALWAYS
+
+ @staticmethod
+ def _coerce_float(value: Any, default: float = 0.0) -> float:
+ """将任意输入尽量转换为浮点数。
+
+ Args:
+ value: 待转换的值。
+ default: 转换失败时使用的默认值。
+
+ Returns:
+ float: 转换结果。
+ """
+ try:
+ return float(value)
+ except (TypeError, ValueError):
+ return default
+
+ @staticmethod
+ def _build_core_action_info(plugin_id: str, component_name: str, metadata: Dict[str, Any]) -> ActionInfo:
+ """将运行时 Action 元数据映射为旧核心 ActionInfo。
+
+ Args:
+ plugin_id: 插件 ID。
+ component_name: 组件名称。
+ metadata: 运行时组件元数据。
+
+ Returns:
+ ActionInfo: 兼容旧 Planner 的动作定义。
+ """
+ activation_keywords = [
+ str(item)
+ for item in (metadata.get("activation_keywords") or [])
+ if item is not None and str(item).strip()
+ ]
+ action_require = [
+ str(item)
+ for item in (metadata.get("action_require") or [])
+ if item is not None and str(item).strip()
+ ]
+ associated_types = [
+ str(item)
+ for item in (metadata.get("associated_types") or [])
+ if item is not None and str(item).strip()
+ ]
+ raw_action_parameters = metadata.get("action_parameters") or {}
+ action_parameters = {
+ str(param_name): str(param_description)
+ for param_name, param_description in raw_action_parameters.items()
+ } if isinstance(raw_action_parameters, dict) else {}
+
+ return ActionInfo(
+ name=component_name,
+ component_type=CoreComponentType.ACTION,
+ description=str(metadata.get("description", "") or ""),
+ enabled=bool(metadata.get("enabled", True)),
+ plugin_name=plugin_id,
+ metadata=dict(metadata),
+ action_parameters=action_parameters,
+ action_require=action_require,
+ associated_types=associated_types,
+ activation_type=PluginRunnerSupervisor._coerce_action_activation_type(metadata.get("activation_type")),
+ random_activation_probability=PluginRunnerSupervisor._coerce_float(
+ metadata.get("activation_probability"),
+ 0.0,
+ ),
+ activation_keywords=activation_keywords,
+ parallel_action=bool(metadata.get("parallel_action", False)),
+ )
+
+ @staticmethod
+ def _extract_stream_id_from_action_kwargs(kwargs: Dict[str, Any]) -> str:
+ """从旧 ActionManager 传入参数中提取聊天流 ID。
+
+ Args:
+ kwargs: 旧动作执行器收到的关键字参数。
+
+ Returns:
+ str: 可用于新运行时 Action 的 ``stream_id``。
+ """
+ chat_stream = kwargs.get("chat_stream")
+ if chat_stream is not None:
+ try:
+ return str(chat_stream.session_id)
+ except AttributeError:
+ pass
+
+ raw_stream_id = kwargs.get("stream_id", "")
+ return str(raw_stream_id or "")
+
+ def _build_runtime_action_executor(
+ self,
+ plugin_id: str,
+ component_name: str,
+ ) -> Any:
+ """构造一个转发到 plugin runtime 的旧核心 Action 执行器。
+
+ Args:
+ plugin_id: 目标插件 ID。
+ component_name: 目标 Action 组件名称。
+
+ Returns:
+ Callable[..., Coroutine[Any, Any, tuple[bool, str]]]: 兼容旧 ActionManager 的执行器。
+ """
+
+ async def _executor(**kwargs: Any) -> tuple[bool, str]:
+ """将旧 Planner 的动作调用桥接到 plugin runtime。
+
+ Args:
+ **kwargs: 旧 ActionManager 传入的运行时上下文参数。
+
+ Returns:
+ tuple[bool, str]: ``(是否成功, 动作说明)``。
+ """
+ invoke_args: Dict[str, Any] = {}
+ action_data = kwargs.get("action_data")
+ if isinstance(action_data, dict):
+ invoke_args.update(action_data)
+
+ stream_id = self._extract_stream_id_from_action_kwargs(kwargs)
+ invoke_args["action_data"] = action_data if isinstance(action_data, dict) else {}
+ invoke_args["stream_id"] = stream_id
+ invoke_args["chat_id"] = stream_id
+ invoke_args["reasoning"] = str(kwargs.get("action_reasoning", "") or "")
+
+ thinking_id = kwargs.get("thinking_id")
+ if thinking_id is not None:
+ invoke_args["thinking_id"] = str(thinking_id)
+
+ cycle_timers = kwargs.get("cycle_timers")
+ if isinstance(cycle_timers, dict):
+ invoke_args["cycle_timers"] = cycle_timers
+
+ plugin_config = kwargs.get("plugin_config")
+ if isinstance(plugin_config, dict):
+ invoke_args["plugin_config"] = plugin_config
+
+ log_prefix = kwargs.get("log_prefix")
+ if isinstance(log_prefix, str):
+ invoke_args["log_prefix"] = log_prefix
+
+ shutting_down = kwargs.get("shutting_down")
+ if isinstance(shutting_down, bool):
+ invoke_args["shutting_down"] = shutting_down
+
+ try:
+ response = await self.invoke_plugin(
+ method="plugin.invoke_action",
+ plugin_id=plugin_id,
+ component_name=component_name,
+ args=invoke_args,
+ timeout_ms=30000,
+ )
+ except Exception as exc:
+ logger.error(f"运行时 Action {plugin_id}.{component_name} 执行失败: {exc}", exc_info=True)
+ return False, str(exc)
+
+ payload = response.payload if isinstance(response.payload, dict) else {}
+ success = bool(payload.get("success", False))
+ result = payload.get("result")
+
+ if isinstance(result, (list, tuple)):
+ if len(result) >= 2:
+ return bool(result[0]), "" if result[1] is None else str(result[1])
+ if len(result) == 1:
+ return bool(result[0]), ""
+
+ if success:
+ return True, "" if result is None else str(result)
+ return False, "" if result is None else str(result)
+
+ return _executor
+
+ def _mirror_runtime_actions_to_core_registry(self, payload: RegisterPluginPayload) -> None:
+ """将 plugin runtime 中声明的 Action 镜像到旧核心注册表。
+
+ Args:
+ payload: 当前插件的注册载荷。
+ """
+ mirrored_action_names: List[str] = []
+
+ for component in payload.components:
+ if str(component.component_type).upper() != CoreComponentType.ACTION.name:
+ continue
+
+ action_info = self._build_core_action_info(
+ plugin_id=payload.plugin_id,
+ component_name=component.name,
+ metadata=component.metadata,
+ )
+ action_executor = self._build_runtime_action_executor(
+ plugin_id=payload.plugin_id,
+ component_name=component.name,
+ )
+ registered = core_component_registry.register_action(action_info, action_executor)
+ if not registered:
+ logger.warning(
+ f"运行时 Action {payload.plugin_id}.{component.name} 无法镜像到旧核心注册表,"
+ "可能与现有 Action 重名"
+ )
+ continue
+ mirrored_action_names.append(component.name)
+
+ if mirrored_action_names:
+ self._mirrored_core_actions[payload.plugin_id] = mirrored_action_names
+
+ def _remove_core_action_mirrors(self, plugin_id: str) -> None:
+ """移除某个插件镜像到旧核心注册表的所有 Action。
+
+ Args:
+ plugin_id: 目标插件 ID。
+ """
+ mirrored_action_names = self._mirrored_core_actions.pop(plugin_id, [])
+ for action_name in mirrored_action_names:
+ core_component_registry.remove_action(action_name)
+
+ @staticmethod
+ def _build_message_gateway_driver_id(plugin_id: str, gateway_name: str) -> str:
+ """构造消息网关驱动 ID。
+
+ Args:
+ plugin_id: 插件 ID。
+ gateway_name: 网关组件名称。
Returns:
str: 对应 Platform IO 中的驱动 ID。
"""
- return f"adapter:{plugin_id}"
- async def _register_adapter_driver(self, plugin_id: str, adapter: AdapterDeclarationPayload) -> None:
- """将适配器插件驱动注册到 Platform IO。
+ return f"gateway:{plugin_id}:{gateway_name}"
+
+ @staticmethod
+ def _normalize_runtime_route_value(value: str) -> Optional[str]:
+ """规范化运行时路由字段。
Args:
- plugin_id: 适配器插件 ID。
- adapter: 经过校验的适配器声明。
+ value: 待规范化的原始字符串。
- Raises:
- ValueError: 当驱动注册失败时抛出。
+ Returns:
+ Optional[str]: 规范化后非空则返回字符串,否则返回 ``None``。
"""
- await self._unregister_adapter_driver(plugin_id)
+
+ normalized_value = str(value or "").strip()
+ return normalized_value or None
+
+ def _resolve_message_gateway_entry(
+ self,
+ plugin_id: str,
+ gateway_name: str,
+ ) -> Optional[Any]:
+ """解析指定插件的消息网关组件。
+
+ Args:
+ plugin_id: 插件 ID。
+ gateway_name: 网关组件名称;为空时按兼容规则推断。
+
+ Returns:
+ Optional[Any]: 匹配到的消息网关组件条目。
+ """
+
+ if gateway_name:
+ return self._component_registry.get_message_gateway(
+ plugin_id=plugin_id,
+ name=gateway_name,
+ enabled_only=False,
+ )
+
+ gateways = self._component_registry.get_message_gateways(plugin_id=plugin_id, enabled_only=False)
+ if len(gateways) == 1:
+ return gateways[0]
+
+ return None
+
+ async def _register_message_gateway_driver(
+ self,
+ plugin_id: str,
+ gateway_entry: Any,
+ route_key: RouteKey,
+ ) -> None:
+ """为消息网关注册驱动并绑定发送/接收路由。
+
+ Args:
+ plugin_id: 插件 ID。
+ gateway_entry: 消息网关组件条目。
+ route_key: 当前链路对应的路由键。
+ """
+
+ await self._unregister_message_gateway_driver(plugin_id, gateway_entry.name)
platform_io_manager = get_platform_io_manager()
driver = PluginPlatformDriver(
- driver_id=self._build_adapter_driver_id(plugin_id),
- platform=adapter.platform,
- account_id=adapter.account_id or None,
- scope=adapter.scope or None,
+ driver_id=self._build_message_gateway_driver_id(plugin_id, gateway_entry.name),
+ platform=route_key.platform,
+ account_id=route_key.account_id,
+ scope=route_key.scope,
plugin_id=plugin_id,
- send_method=adapter.send_method,
+ component_name=gateway_entry.name,
+ supports_send=bool(gateway_entry.supports_send),
supervisor=self,
metadata={
- "protocol": adapter.protocol,
- **adapter.metadata,
+ "protocol": gateway_entry.protocol,
+ "route_type": gateway_entry.route_type,
+ **gateway_entry.metadata,
},
)
@@ -620,20 +896,36 @@ class PluginRunnerSupervisor:
platform_io_manager.unregister_driver(driver.driver_id)
raise
- self._registered_adapters[plugin_id] = adapter
- self._adapter_runtime_states[plugin_id] = _AdapterRuntimeState()
+ binding_metadata = {
+ "plugin_id": plugin_id,
+ "gateway_name": gateway_entry.name,
+ "protocol": gateway_entry.protocol,
+ "route_type": gateway_entry.route_type,
+ **gateway_entry.metadata,
+ }
+ binding = RouteBinding(
+ route_key=route_key,
+ driver_id=driver.driver_id,
+ driver_kind=DriverKind.PLUGIN,
+ metadata=binding_metadata,
+ )
+ if gateway_entry.supports_send:
+ platform_io_manager.bind_send_route(binding)
+ if gateway_entry.supports_receive:
+ platform_io_manager.bind_receive_route(binding)
- async def _unregister_adapter_driver(self, plugin_id: str) -> None:
- """从 Platform IO 注销一个适配器驱动。
+ async def _unregister_message_gateway_driver(self, plugin_id: str, gateway_name: str) -> None:
+ """从 Platform IO 注销单个消息网关驱动。
Args:
- plugin_id: 适配器插件 ID。
+ plugin_id: 插件 ID。
+ gateway_name: 网关组件名称。
"""
- platform_io_manager = get_platform_io_manager()
- driver_id = self._build_adapter_driver_id(plugin_id)
- adapter = self._registered_adapters.get(plugin_id)
- self._remove_adapter_route_bindings(plugin_id)
+ platform_io_manager = get_platform_io_manager()
+ driver_id = self._build_message_gateway_driver_id(plugin_id, gateway_name)
+ platform_io_manager.send_route_table.remove_bindings_by_driver(driver_id)
+ platform_io_manager.receive_route_table.remove_bindings_by_driver(driver_id)
with contextlib.suppress(Exception):
if platform_io_manager.is_started:
@@ -641,204 +933,83 @@ class PluginRunnerSupervisor:
else:
platform_io_manager.unregister_driver(driver_id)
- if adapter is not None:
- self._refresh_platform_default_route(adapter.platform)
-
- self._registered_adapters.pop(plugin_id, None)
- self._adapter_runtime_states.pop(plugin_id, None)
-
- async def _unregister_all_adapter_drivers(self) -> None:
- """注销当前 Supervisor 管理的全部适配器驱动。"""
- plugin_ids = list(self._registered_adapters.keys())
- for plugin_id in plugin_ids:
- await self._unregister_adapter_driver(plugin_id)
-
- def _remove_adapter_route_bindings(self, plugin_id: str) -> None:
- """移除某个适配器驱动当前持有的全部路由绑定。
+ async def _unregister_all_message_gateway_drivers_for_plugin(self, plugin_id: str) -> None:
+ """注销指定插件的全部消息网关驱动。
Args:
- plugin_id: 适配器插件 ID。
+ plugin_id: 插件 ID。
"""
- platform_io_manager = get_platform_io_manager()
- platform_io_manager.route_table.remove_bindings_by_driver(self._build_adapter_driver_id(plugin_id))
- @staticmethod
- def _normalize_runtime_route_value(value: str) -> Optional[str]:
- """规范化适配器运行时路由字段。
+ gateway_names = list(self._message_gateway_states.get(plugin_id, {}).keys())
+ for gateway_name in gateway_names:
+ await self._unregister_message_gateway_driver(plugin_id, gateway_name)
- Args:
- value: 待规范化的原始字符串。
-
- Returns:
- Optional[str]: 规范化后非空则返回字符串,否则返回 ``None``。
- """
- normalized_value = str(value).strip()
- return normalized_value or None
-
- def _build_runtime_route_key(
+ def _build_message_gateway_route_key(
self,
- adapter: AdapterDeclarationPayload,
- payload: AdapterStateUpdatePayload,
+ gateway_entry: Any,
+ payload: MessageGatewayStateUpdatePayload,
) -> RouteKey:
- """根据运行时状态更新构造适配器生效路由键。
+ """根据消息网关运行时状态构造路由键。
Args:
- adapter: 当前适配器声明。
- payload: 适配器上报的运行时状态。
+ gateway_entry: 消息网关组件条目。
+ payload: 网关上报的运行时状态。
Returns:
- RouteKey: 当前连接应接管的精确路由键。
+ RouteKey: 当前链路对应的路由键。
Raises:
- ValueError: 当静态声明与运行时上报的身份信息冲突时抛出。
+ ValueError: 当平台信息缺失时抛出。
"""
- runtime_account_id = self._normalize_runtime_route_value(payload.account_id)
- runtime_scope = self._normalize_runtime_route_value(payload.scope)
- if adapter.account_id and runtime_account_id and adapter.account_id != runtime_account_id:
- raise ValueError(
- f"适配器声明的 account_id={adapter.account_id} 与运行时上报的 {runtime_account_id} 不一致"
- )
- if adapter.scope and runtime_scope and adapter.scope != runtime_scope:
- raise ValueError(f"适配器声明的 scope={adapter.scope} 与运行时上报的 {runtime_scope} 不一致")
+ platform = str(payload.platform or gateway_entry.platform or "").strip()
+ if not platform:
+ raise ValueError(f"消息网关 {gateway_entry.full_name} 未提供有效的平台名称")
return RouteKey(
- platform=adapter.platform,
- account_id=runtime_account_id or adapter.account_id or None,
- scope=runtime_scope or adapter.scope or None,
+ platform=platform,
+ account_id=self._normalize_runtime_route_value(payload.account_id) or gateway_entry.account_id or None,
+ scope=self._normalize_runtime_route_value(payload.scope) or gateway_entry.scope or None,
)
- def _bind_runtime_exact_route(
+ def _apply_message_gateway_state(
self,
plugin_id: str,
- adapter: AdapterDeclarationPayload,
- route_key: RouteKey,
- ) -> None:
- """为适配器连接绑定精确生效路由。
+ gateway_entry: Any,
+ payload: MessageGatewayStateUpdatePayload,
+ ) -> Tuple[_MessageGatewayRuntimeState, Dict[str, Any]]:
+ """应用消息网关运行时状态,并同步 Platform IO 路由。
Args:
- plugin_id: 适配器插件 ID。
- adapter: 当前适配器声明。
- route_key: 当前连接对应的精确路由键。
+ plugin_id: 插件 ID。
+ gateway_entry: 消息网关组件条目。
+ payload: 网关上报的运行时状态。
- Raises:
- RouteBindingConflictError: 当目标路由已被其他 active owner 占用时抛出。
+ Returns:
+ Tuple[_MessageGatewayRuntimeState, Dict[str, Any]]: 更新后的状态与路由键字典。
"""
- platform_io_manager = get_platform_io_manager()
- platform_io_manager.bind_route(
- RouteBinding(
- route_key=route_key,
- driver_id=self._build_adapter_driver_id(plugin_id),
- driver_kind=DriverKind.PLUGIN,
- metadata={
- "plugin_id": plugin_id,
- "protocol": adapter.protocol,
- "binding_role": _ADAPTER_BINDING_ROLE_RUNTIME_EXACT,
- },
+
+ plugin_states = self._message_gateway_states.setdefault(plugin_id, {})
+ if not payload.ready:
+ runtime_state = _MessageGatewayRuntimeState(
+ ready=False,
+ platform=self._normalize_runtime_route_value(payload.platform) or gateway_entry.platform or None,
+ account_id=self._normalize_runtime_route_value(payload.account_id) or gateway_entry.account_id or None,
+ scope=self._normalize_runtime_route_value(payload.scope) or gateway_entry.scope or None,
+ metadata=dict(payload.metadata),
)
- )
-
- def _list_runtime_exact_bindings(self, platform: str) -> List[RouteBinding]:
- """列出某个平台上由 Host 动态维护的精确适配器绑定。
-
- Args:
- platform: 目标平台名称。
-
- Returns:
- List[RouteBinding]: 当前平台上全部动态精确绑定。
- """
- platform_io_manager = get_platform_io_manager()
- return [
- binding
- for binding in platform_io_manager.route_table.list_bindings()
- if binding.mode == RouteMode.ACTIVE
- and binding.route_key.platform == platform
- and binding.metadata.get("binding_role") == _ADAPTER_BINDING_ROLE_RUNTIME_EXACT
- ]
-
- def _refresh_platform_default_route(self, platform: str) -> None:
- """根据当前精确绑定数量刷新平台级默认路由。
-
- 当某个平台恰好只存在一个动态精确绑定时,会为该绑定额外创建一条
- ``RouteKey(platform=)`` 形式的默认路由,方便缺少账号维度的
- 出站消息继续找到唯一 owner。若精确绑定数量变为 0 或大于 1,则撤销
- 由 Host 自动维护的默认路由,避免出现隐式歧义。
-
- Args:
- platform: 目标平台名称。
- """
- platform_io_manager = get_platform_io_manager()
- default_route_key = RouteKey(platform=platform)
- existing_default_binding = platform_io_manager.route_table.get_active_binding(default_route_key, exact_only=True)
-
- if existing_default_binding is not None:
- binding_role = existing_default_binding.metadata.get("binding_role")
- if binding_role != _ADAPTER_BINDING_ROLE_PLATFORM_DEFAULT:
- return
- platform_io_manager.unbind_route(default_route_key, existing_default_binding.driver_id)
-
- exact_bindings = self._list_runtime_exact_bindings(platform)
- if len(exact_bindings) != 1:
- return
-
- exact_binding = exact_bindings[0]
- if exact_binding.route_key == default_route_key:
- return
-
- platform_io_manager.bind_route(
- RouteBinding(
- route_key=default_route_key,
- driver_id=exact_binding.driver_id,
- driver_kind=exact_binding.driver_kind,
- metadata={
- "plugin_id": exact_binding.metadata.get("plugin_id", ""),
- "protocol": exact_binding.metadata.get("protocol", ""),
- "binding_role": _ADAPTER_BINDING_ROLE_PLATFORM_DEFAULT,
- },
- ),
- replace=True,
- )
-
- def _apply_adapter_runtime_state(
- self,
- plugin_id: str,
- adapter: AdapterDeclarationPayload,
- payload: AdapterStateUpdatePayload,
- ) -> Tuple[_AdapterRuntimeState, Dict[str, Any]]:
- """应用适配器运行时状态,并同步 Platform IO 路由。
-
- Args:
- plugin_id: 适配器插件 ID。
- adapter: 当前适配器声明。
- payload: 适配器上报的运行时状态。
-
- Returns:
- Tuple[_AdapterRuntimeState, Dict[str, Any]]: 更新后的运行时状态,以及
- 供 RPC 响应返回的路由键字典。
-
- Raises:
- RouteBindingConflictError: 当新的精确路由与其他 active owner 冲突时抛出。
- ValueError: 当运行时路由信息不合法时抛出。
- """
- if not payload.connected:
- self._remove_adapter_route_bindings(plugin_id)
- self._refresh_platform_default_route(adapter.platform)
- runtime_state = _AdapterRuntimeState(connected=False, metadata=dict(payload.metadata))
- self._adapter_runtime_states[plugin_id] = runtime_state
+ plugin_states[gateway_entry.name] = runtime_state
return runtime_state, {}
- route_key = self._build_runtime_route_key(adapter, payload)
- self._remove_adapter_route_bindings(plugin_id)
- self._bind_runtime_exact_route(plugin_id, adapter, route_key)
- self._refresh_platform_default_route(adapter.platform)
-
- runtime_state = _AdapterRuntimeState(
- connected=True,
+ route_key = self._build_message_gateway_route_key(gateway_entry, payload)
+ runtime_state = _MessageGatewayRuntimeState(
+ ready=True,
+ platform=route_key.platform,
account_id=route_key.account_id,
scope=route_key.scope,
metadata=dict(payload.metadata),
)
- self._adapter_runtime_states[plugin_id] = runtime_state
+ plugin_states[gateway_entry.name] = runtime_state
return runtime_state, {
"platform": route_key.platform,
"account_id": route_key.account_id,
@@ -856,8 +1027,9 @@ class PluginRunnerSupervisor:
Args:
session_message: 已构造好的内部消息对象。
route_key: Host 为该消息解析出的标准路由键。
- route_metadata: 适配器通过 RPC 补充的原始路由辅助元数据。
+ route_metadata: 插件通过 RPC 补充的原始路由辅助元数据。
"""
+
additional_config = session_message.message_info.additional_config
if not isinstance(additional_config, dict):
additional_config = {}
@@ -877,45 +1049,49 @@ class PluginRunnerSupervisor:
def _build_inbound_route_key(
self,
- adapter: AdapterDeclarationPayload,
+ gateway_entry: Any,
+ runtime_state: _MessageGatewayRuntimeState,
message: Dict[str, Any],
route_metadata: Dict[str, Any],
) -> RouteKey:
- """为适配器入站消息构造归一路由键。
+ """为入站消息构造归一路由键。
Args:
- adapter: 当前适配器声明。
+ gateway_entry: 接收消息的网关组件条目。
+ runtime_state: 当前网关的运行时状态。
message: 标准消息字典。
route_metadata: 插件补充的路由辅助元数据。
Returns:
RouteKey: 供 Platform IO 使用的规范化路由键。
-
- Raises:
- ValueError: 消息平台字段与适配器平台声明不一致时抛出。
"""
- message_platform = str(message.get("platform") or adapter.platform).strip()
- if message_platform != adapter.platform:
- raise ValueError(
- f"外部消息平台 {message_platform} 与适配器 {adapter.platform} 不一致"
- )
+
+ platform = str(
+ message.get("platform")
+ or route_metadata.get("platform")
+ or runtime_state.platform
+ or gateway_entry.platform
+ or ""
+ ).strip()
+ if not platform:
+ raise ValueError(f"消息网关 {gateway_entry.full_name} 的入站消息缺少平台信息")
try:
route_key = RouteKeyFactory.from_message_dict(message)
except Exception:
- route_key = RouteKey(platform=message_platform)
+ route_key = RouteKey(platform=platform)
route_account_id, route_scope = RouteKeyFactory.extract_components(route_metadata)
- account_id = route_key.account_id or route_account_id or adapter.account_id or None
- scope = route_key.scope or route_scope or adapter.scope or None
+ account_id = route_key.account_id or route_account_id or runtime_state.account_id or gateway_entry.account_id or None
+ scope = route_key.scope or route_scope or runtime_state.scope or gateway_entry.scope or None
return RouteKey(
- platform=message_platform,
+ platform=platform,
account_id=account_id,
scope=scope,
)
- async def _handle_update_adapter_state(self, envelope: Envelope) -> Envelope:
- """处理适配器插件上报的运行时状态更新。
+ async def _handle_update_message_gateway_state(self, envelope: Envelope) -> Envelope:
+ """处理消息网关上报的运行时状态更新。
Args:
envelope: RPC 请求信封。
@@ -923,38 +1099,42 @@ class PluginRunnerSupervisor:
Returns:
Envelope: 状态更新处理结果。
"""
+
try:
- payload = AdapterStateUpdatePayload.model_validate(envelope.payload)
+ payload = MessageGatewayStateUpdatePayload.model_validate(envelope.payload)
except Exception as exc:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
- adapter = self._registered_adapters.get(envelope.plugin_id)
- if adapter is None:
+ gateway_entry = self._resolve_message_gateway_entry(envelope.plugin_id, payload.gateway_name)
+ if gateway_entry is None:
return envelope.make_error_response(
ErrorCode.E_METHOD_NOT_ALLOWED.value,
- f"插件 {envelope.plugin_id} 未声明为适配器,不能更新运行时状态",
+ f"插件 {envelope.plugin_id} 未声明消息网关 {payload.gateway_name or ''}",
)
try:
- runtime_state, route_key_dict = self._apply_adapter_runtime_state(
+ if payload.ready:
+ route_key = self._build_message_gateway_route_key(gateway_entry, payload)
+ await self._register_message_gateway_driver(envelope.plugin_id, gateway_entry, route_key)
+ else:
+ await self._unregister_message_gateway_driver(envelope.plugin_id, gateway_entry.name)
+ runtime_state, route_key_dict = self._apply_message_gateway_state(
plugin_id=envelope.plugin_id,
- adapter=adapter,
+ gateway_entry=gateway_entry,
payload=payload,
)
- except RouteBindingConflictError as exc:
- return envelope.make_error_response(ErrorCode.E_METHOD_NOT_ALLOWED.value, str(exc))
except Exception as exc:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
- response = AdapterStateUpdateResultPayload(
+ response = MessageGatewayStateUpdateResultPayload(
accepted=True,
- connected=runtime_state.connected,
+ ready=runtime_state.ready,
route_key=route_key_dict,
)
return envelope.make_response(payload=response.model_dump())
- async def _handle_receive_external_message(self, envelope: Envelope) -> Envelope:
- """处理适配器插件上报的外部入站消息。
+ async def _handle_route_message(self, envelope: Envelope) -> Envelope:
+ """处理消息网关上报的外部入站消息。
Args:
envelope: RPC 请求信封。
@@ -962,21 +1142,33 @@ class PluginRunnerSupervisor:
Returns:
Envelope: 注入结果响应。
"""
+
try:
- payload = ReceiveExternalMessagePayload.model_validate(envelope.payload)
+ payload = RouteMessagePayload.model_validate(envelope.payload)
except Exception as exc:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
- adapter = self._registered_adapters.get(envelope.plugin_id)
- if adapter is None:
+ gateway_entry = self._resolve_message_gateway_entry(envelope.plugin_id, payload.gateway_name)
+ if gateway_entry is None or not bool(gateway_entry.supports_receive):
return envelope.make_error_response(
ErrorCode.E_METHOD_NOT_ALLOWED.value,
- f"插件 {envelope.plugin_id} 未声明为适配器,不能注入外部消息",
+ f"插件 {envelope.plugin_id} 未声明可接收的消息网关 {payload.gateway_name}",
+ )
+
+ runtime_state = self._message_gateway_states.get(envelope.plugin_id, {}).get(
+ gateway_entry.name,
+ _MessageGatewayRuntimeState(),
+ )
+ if not runtime_state.ready:
+ return envelope.make_error_response(
+ ErrorCode.E_METHOD_NOT_ALLOWED.value,
+ f"消息网关 {gateway_entry.full_name} 尚未就绪,不能注入外部消息",
)
try:
route_key = self._build_inbound_route_key(
- adapter=adapter,
+ gateway_entry=gateway_entry,
+ runtime_state=runtime_state,
message=payload.message,
route_metadata=payload.route_metadata,
)
@@ -989,7 +1181,7 @@ class PluginRunnerSupervisor:
accepted = await platform_io_manager.accept_inbound(
InboundMessageEnvelope(
route_key=route_key,
- driver_id=self._build_adapter_driver_id(envelope.plugin_id),
+ driver_id=self._build_message_gateway_driver_id(envelope.plugin_id, gateway_entry.name),
driver_kind=DriverKind.PLUGIN,
external_message_id=payload.external_message_id or str(payload.message.get("message_id") or "") or None,
dedupe_key=payload.dedupe_key or None,
@@ -997,7 +1189,8 @@ class PluginRunnerSupervisor:
payload=payload.message,
metadata={
"plugin_id": envelope.plugin_id,
- "protocol": adapter.protocol,
+ "gateway_name": gateway_entry.name,
+ "protocol": gateway_entry.protocol,
**payload.route_metadata,
},
)
@@ -1138,7 +1331,8 @@ class PluginRunnerSupervisor:
await self._stderr_drain_task
self._stderr_drain_task = None
- await self._unregister_all_adapter_drivers()
+ for plugin_id in list(self._message_gateway_states.keys()):
+ await self._unregister_all_message_gateway_drivers_for_plugin(plugin_id)
self._clear_runner_state()
async def _health_check_loop(self) -> None:
@@ -1213,11 +1407,12 @@ class PluginRunnerSupervisor:
def _clear_runner_state(self) -> None:
"""清理当前 Runner 对应的 Host 侧注册状态。"""
+ for plugin_id in list(self._mirrored_core_actions.keys()):
+ self._remove_core_action_mirrors(plugin_id)
self._authorization.clear()
self._component_registry.clear()
self._registered_plugins.clear()
- self._registered_adapters.clear()
- self._adapter_runtime_states.clear()
+ self._message_gateway_states.clear()
self._runner_ready_events = asyncio.Event()
self._runner_ready_payloads = RunnerReadyPayload()
self._rpc_server.clear_handshake_state()
diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py
index bf85669b..b74b2d46 100644
--- a/src/plugin_runtime/integration.py
+++ b/src/plugin_runtime/integration.py
@@ -18,7 +18,7 @@ import tomlkit
from src.common.logger import get_logger
from src.config.config import global_config
from src.config.file_watcher import FileChange, FileWatcher
-from src.platform_io import DeliveryReceipt, InboundMessageEnvelope, get_platform_io_manager
+from src.platform_io import DeliveryBatch, InboundMessageEnvelope, get_platform_io_manager
from src.plugin_runtime.capabilities import (
RuntimeComponentCapabilityMixin,
RuntimeCoreCapabilityMixin,
@@ -351,15 +351,15 @@ class PluginRuntimeManager(
async def try_send_message_via_platform_io(
self,
message: "SessionMessage",
- ) -> Optional[DeliveryReceipt]:
+ ) -> Optional[DeliveryBatch]:
"""尝试通过 Platform IO 中间层发送消息。
Args:
message: 待发送的内部会话消息。
Returns:
- Optional[DeliveryReceipt]: 若当前消息存在 active 路由,则返回实际发送
- 结果;若没有可用路由或 Platform IO 尚未启动,则返回 ``None``。
+ Optional[DeliveryBatch]: 若当前消息命中了至少一条发送路由,则返回
+ 实际发送结果;若没有可用路由或 Platform IO 尚未启动,则返回 ``None``。
"""
if not self._started:
return None
@@ -374,7 +374,7 @@ class PluginRuntimeManager(
logger.warning(f"根据消息构造 Platform IO 路由键失败: {exc}")
return None
- if platform_io_manager.resolve_driver(route_key) is None:
+ if not platform_io_manager.resolve_drivers(route_key):
return None
return await platform_io_manager.send_message(message, route_key)
diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py
index f68657fa..cbbb71be 100644
--- a/src/plugin_runtime/protocol/envelope.py
+++ b/src/plugin_runtime/protocol/envelope.py
@@ -156,8 +156,6 @@ class RegisterPluginPayload(BaseModel):
"""插件版本"""
components: List[ComponentDeclaration] = Field(default_factory=list, description="组件列表")
"""组件列表"""
- adapter: Optional["AdapterDeclarationPayload"] = Field(default=None, description="可选的适配器声明")
- """可选的适配器声明"""
capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表")
"""所需能力列表"""
@@ -287,50 +285,39 @@ class ReloadPluginResultPayload(BaseModel):
"""重载失败的插件及原因"""
-class AdapterDeclarationPayload(BaseModel):
- """适配器插件声明载荷。"""
+class MessageGatewayStateUpdatePayload(BaseModel):
+ """消息网关运行时状态更新载荷。"""
- platform: str = Field(description="适配器负责的平台名称,例如 qq")
- """适配器负责的平台名称,例如 qq"""
- protocol: str = Field(default="", description="接入协议或实现名称,例如 napcat")
- """接入协议或实现名称,例如 napcat"""
- account_id: str = Field(default="", description="可选的账号 ID 或 self_id")
- """可选的账号 ID 或 self_id"""
- scope: str = Field(default="", description="可选的路由作用域")
- """可选的路由作用域"""
- send_method: str = Field(default="send_to_platform", description="Host 出站调用的插件方法名")
- """Host 出站调用的插件方法名"""
- metadata: Dict[str, Any] = Field(default_factory=dict, description="适配器附加元数据")
- """适配器附加元数据"""
-
-
-class AdapterStateUpdatePayload(BaseModel):
- """适配器运行时状态更新载荷。"""
-
- connected: bool = Field(description="适配器当前是否已连接并准备接管路由")
- """适配器当前是否已连接并准备接管路由"""
- account_id: str = Field(default="", description="当前连接对应的账号 ID 或 self_id")
- """当前连接对应的账号 ID 或 self_id"""
- scope: str = Field(default="", description="当前连接对应的可选路由作用域")
- """当前连接对应的可选路由作用域"""
+ gateway_name: str = Field(description="消息网关组件名称")
+ """消息网关组件名称"""
+ ready: bool = Field(description="当前链路是否已经就绪")
+ """当前链路是否已经就绪"""
+ platform: str = Field(default="", description="当前链路负责的平台名称")
+ """当前链路负责的平台名称"""
+ account_id: str = Field(default="", description="当前链路对应的账号 ID 或 self_id")
+ """当前链路对应的账号 ID 或 self_id"""
+ scope: str = Field(default="", description="当前链路对应的可选路由作用域")
+ """当前链路对应的可选路由作用域"""
metadata: Dict[str, Any] = Field(default_factory=dict, description="可选的运行时状态元数据")
"""可选的运行时状态元数据"""
-class AdapterStateUpdateResultPayload(BaseModel):
- """适配器运行时状态更新结果载荷。"""
+class MessageGatewayStateUpdateResultPayload(BaseModel):
+ """消息网关运行时状态更新结果载荷。"""
accepted: bool = Field(description="Host 是否接受了本次状态更新")
"""Host 是否接受了本次状态更新"""
- connected: bool = Field(description="Host 记录的当前连接状态")
- """Host 记录的当前连接状态"""
+ ready: bool = Field(description="Host 记录的当前就绪状态")
+ """Host 记录的当前就绪状态"""
route_key: Dict[str, Any] = Field(default_factory=dict, description="当前生效的路由键")
"""当前生效的路由键"""
-class ReceiveExternalMessagePayload(BaseModel):
- """适配器插件向 Host 注入外部消息的请求载荷。"""
+class RouteMessagePayload(BaseModel):
+ """消息网关向 Host 路由外部消息的请求载荷。"""
+ gateway_name: str = Field(description="接收消息的网关组件名称")
+ """接收消息的网关组件名称"""
message: Dict[str, Any] = Field(description="符合 MessageDict 结构的标准消息字典")
"""符合 MessageDict 结构的标准消息字典"""
route_metadata: Dict[str, Any] = Field(default_factory=dict, description="可选的路由辅助元数据")
diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py
index 8078c88b..3a50e2f7 100644
--- a/src/plugin_runtime/runner/runner_main.py
+++ b/src/plugin_runtime/runner/runner_main.py
@@ -25,7 +25,6 @@ import tomllib
from src.common.logger import get_console_handler, get_logger, initialize_logging
from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN
from src.plugin_runtime.protocol.envelope import (
- AdapterDeclarationPayload,
BootstrapPluginPayload,
ComponentDeclaration,
Envelope,
@@ -219,7 +218,7 @@ class PluginRunner:
"""为插件实例创建并注入 PluginContext。
对新版 MaiBotPlugin(具有 _set_context 方法):创建 PluginContext 并注入。
- 对旧版 LegacyPluginAdapter(具有 _set_context 方法,由适配器代理):同上。
+ 对旧版 LegacyPluginAdapter(具有 _set_context 方法,由兼容代理封装):同上。
"""
if not hasattr(instance, "_set_context"):
return
@@ -293,7 +292,7 @@ class PluginRunner:
self._rpc_client.register_method("plugin.invoke_command", self._handle_invoke)
self._rpc_client.register_method("plugin.invoke_action", self._handle_invoke)
self._rpc_client.register_method("plugin.invoke_tool", self._handle_invoke)
- self._rpc_client.register_method("plugin.invoke_adapter", self._handle_invoke)
+ self._rpc_client.register_method("plugin.invoke_message_gateway", self._handle_invoke)
self._rpc_client.register_method("plugin.emit_event", self._handle_event_invoke)
self._rpc_client.register_method("plugin.invoke_hook", self._handle_hook_invoke)
self._rpc_client.register_method("plugin.invoke_workflow_step", self._handle_workflow_step)
@@ -331,29 +330,6 @@ class PluginRunner:
"""撤销 bootstrap 期间为插件签发的能力令牌。"""
await self._bootstrap_plugin(meta, capabilities_required=[])
- def _collect_adapter_declaration(self, meta: PluginMeta) -> Optional[AdapterDeclarationPayload]:
- """从插件实例中提取适配器声明。
-
- Args:
- meta: 待提取声明的插件元数据。
-
- Returns:
- Optional[AdapterDeclarationPayload]: 若插件声明了适配器角色,则返回
- 经过校验的适配器声明;否则返回 ``None``。
-
- Raises:
- ValueError: 插件导出的适配器声明结构非法时抛出。
- """
- instance = meta.instance
- if not hasattr(instance, "get_adapter_info"):
- return None
-
- adapter_info = instance.get_adapter_info()
- if adapter_info is None:
- return None
-
- return AdapterDeclarationPayload.model_validate(adapter_info)
-
async def _register_plugin(self, meta: PluginMeta) -> bool:
"""向 Host 注册单个插件。
@@ -379,17 +355,10 @@ class PluginRunner:
for comp_info in instance.get_components()
)
- try:
- adapter = self._collect_adapter_declaration(meta)
- except Exception as exc:
- logger.error(f"插件 {meta.plugin_id} 适配器声明非法: {exc}", exc_info=True)
- return False
-
reg_payload = RegisterPluginPayload(
plugin_id=meta.plugin_id,
plugin_version=meta.version,
components=components,
- adapter=adapter,
capabilities_required=meta.capabilities_required,
)
From d07915eea04b6b60b6f15b462ac93528e3338869 Mon Sep 17 00:00:00 2001
From: DrSmoothl <1787882683@qq.com>
Date: Mon, 23 Mar 2026 11:38:46 +0800
Subject: [PATCH 31/45] Refactor message sending architecture and implement
legacy driver support
- Removed UniversalMessageSender from group_generator.py and private_generator.py.
- Updated PlatformIOManager to manage legacy send drivers and ensure send pipeline readiness.
- Enhanced LegacyPlatformDriver to utilize prepared messages for sending.
- Refactored send_service to unify message sending logic and integrate with Platform IO.
- Added regression tests for Platform IO legacy driver and send service functionality.
---
pytests/test_platform_io_legacy_driver.py | 124 ++++
pytests/test_send_service.py | 141 ++++
src/chat/brain_chat/PFC/message_sender.py | 84 +--
.../message_receive/uni_message_sender.py | 63 +-
src/chat/replyer/group_generator.py | 8 +-
src/chat/replyer/private_generator.py | 8 +-
src/common/message_server/server.py | 2 +-
src/platform_io/drivers/legacy_driver.py | 51 +-
src/platform_io/manager.py | 77 ++-
src/plugin_runtime/integration.py | 2 +-
src/services/send_service.py | 636 ++++++++++++++----
11 files changed, 967 insertions(+), 229 deletions(-)
create mode 100644 pytests/test_platform_io_legacy_driver.py
create mode 100644 pytests/test_send_service.py
diff --git a/pytests/test_platform_io_legacy_driver.py b/pytests/test_platform_io_legacy_driver.py
new file mode 100644
index 00000000..2e94c1fc
--- /dev/null
+++ b/pytests/test_platform_io_legacy_driver.py
@@ -0,0 +1,124 @@
+"""Platform IO legacy driver 回归测试。"""
+
+from typing import Any, Dict, Optional
+
+import pytest
+
+from src.chat.utils import utils as chat_utils
+from src.chat.message_receive import uni_message_sender
+from src.platform_io.drivers.base import PlatformIODriver
+from src.platform_io.drivers.legacy_driver import LegacyPlatformDriver
+from src.platform_io.manager import PlatformIOManager
+from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, RouteBinding, RouteKey
+
+
+class _PluginDriver(PlatformIODriver):
+ """测试用插件发送驱动。"""
+
+ def __init__(self, driver_id: str, platform: str) -> None:
+ """初始化测试驱动。
+
+ Args:
+ driver_id: 驱动 ID。
+ platform: 负责的平台名称。
+ """
+ super().__init__(
+ DriverDescriptor(
+ driver_id=driver_id,
+ kind=DriverKind.PLUGIN,
+ platform=platform,
+ plugin_id="test.plugin",
+ )
+ )
+
+ async def send_message(
+ self,
+ message: Any,
+ route_key: RouteKey,
+ metadata: Optional[Dict[str, Any]] = None,
+ ) -> DeliveryReceipt:
+ """返回一个固定成功回执。
+
+ Args:
+ message: 待发送消息。
+ route_key: 当前路由键。
+ metadata: 发送元数据。
+
+ Returns:
+ DeliveryReceipt: 固定成功回执。
+ """
+ del metadata
+ return DeliveryReceipt(
+ internal_message_id=str(message.message_id),
+ route_key=route_key,
+ status=DeliveryStatus.SENT,
+ driver_id=self.driver_id,
+ driver_kind=self.descriptor.kind,
+ )
+
+
+@pytest.mark.asyncio
+async def test_platform_io_uses_legacy_driver_when_no_explicit_send_route(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """没有显式发送路由时,应由 Platform IO 回退到 legacy driver。"""
+ manager = PlatformIOManager()
+ monkeypatch.setattr(chat_utils, "get_all_bot_accounts", lambda: {"qq": "bot-qq"})
+
+ try:
+ await manager.ensure_send_pipeline_ready()
+
+ fallback_drivers = manager.resolve_drivers(RouteKey(platform="qq"))
+ assert [driver.driver_id for driver in fallback_drivers] == ["legacy.send.qq"]
+
+ plugin_driver = _PluginDriver(driver_id="plugin.qq.sender", platform="qq")
+ await manager.add_driver(plugin_driver)
+ manager.bind_send_route(
+ RouteBinding(
+ route_key=RouteKey(platform="qq"),
+ driver_id=plugin_driver.driver_id,
+ driver_kind=plugin_driver.descriptor.kind,
+ )
+ )
+
+ explicit_drivers = manager.resolve_drivers(RouteKey(platform="qq"))
+ assert [driver.driver_id for driver in explicit_drivers] == ["plugin.qq.sender"]
+ finally:
+ await manager.stop()
+
+
+@pytest.mark.asyncio
+async def test_legacy_platform_driver_uses_prepared_universal_sender(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """legacy driver 应复用已预处理消息的旧链发送函数。"""
+ calls: list[dict[str, Any]] = []
+
+ async def _fake_send_prepared_message_to_platform(message: Any, show_log: bool = True) -> bool:
+ """记录 legacy driver 调用。"""
+ calls.append({"message": message, "show_log": show_log})
+ return True
+
+ monkeypatch.setattr(
+ uni_message_sender,
+ "send_prepared_message_to_platform",
+ _fake_send_prepared_message_to_platform,
+ )
+
+ driver = LegacyPlatformDriver(
+ driver_id="legacy.send.qq",
+ platform="qq",
+ account_id="bot-qq",
+ )
+ message = type("FakeMessage", (), {"message_id": "message-1"})()
+ receipt = await driver.send_message(
+ message=message,
+ route_key=RouteKey(platform="qq"),
+ metadata={"show_log": False},
+ )
+
+ assert len(calls) == 1
+ assert calls[0]["message"] is message
+ assert calls[0]["show_log"] is False
+ assert receipt.status == DeliveryStatus.SENT
+ assert receipt.driver_id == "legacy.send.qq"
diff --git a/pytests/test_send_service.py b/pytests/test_send_service.py
new file mode 100644
index 00000000..4ddd4fa1
--- /dev/null
+++ b/pytests/test_send_service.py
@@ -0,0 +1,141 @@
+"""发送服务回归测试。"""
+
+from types import SimpleNamespace
+from typing import Any, Dict, List
+
+import pytest
+
+from src.chat.message_receive.chat_manager import BotChatSession
+from src.services import send_service
+
+
+class _FakePlatformIOManager:
+ """用于测试的 Platform IO 管理器假对象。"""
+
+ def __init__(self, delivery_batch: Any) -> None:
+ """初始化假 Platform IO 管理器。
+
+ Args:
+ delivery_batch: 发送时返回的批量回执。
+ """
+ self._delivery_batch = delivery_batch
+ self.ensure_calls = 0
+ self.sent_messages: List[Dict[str, Any]] = []
+
+ async def ensure_send_pipeline_ready(self) -> None:
+ """记录发送管线准备调用次数。"""
+ self.ensure_calls += 1
+
+ def build_route_key_from_message(self, message: Any) -> Any:
+ """根据消息构造假的路由键。
+
+ Args:
+ message: 待发送的内部消息对象。
+
+ Returns:
+ Any: 简化后的路由键对象。
+ """
+ del message
+ return SimpleNamespace(platform="qq")
+
+ async def send_message(self, message: Any, route_key: Any, metadata: Dict[str, Any]) -> Any:
+ """记录发送请求并返回预设回执。
+
+ Args:
+ message: 待发送的内部消息对象。
+ route_key: 本次发送使用的路由键。
+ metadata: 发送元数据。
+
+ Returns:
+ Any: 预设的批量发送回执。
+ """
+ self.sent_messages.append(
+ {
+ "message": message,
+ "route_key": route_key,
+ "metadata": metadata,
+ }
+ )
+ return self._delivery_batch
+
+
+def _build_target_stream() -> BotChatSession:
+ """构造一个最小可用的目标会话对象。
+
+ Returns:
+ BotChatSession: 测试用会话对象。
+ """
+ return BotChatSession(
+ session_id="test-session",
+ platform="qq",
+ user_id="target-user",
+ group_id=None,
+ )
+
+
+@pytest.mark.asyncio
+async def test_text_to_stream_delegates_to_platform_io(monkeypatch: pytest.MonkeyPatch) -> None:
+ """send service 应将发送职责统一交给 Platform IO。"""
+ fake_manager = _FakePlatformIOManager(
+ delivery_batch=SimpleNamespace(
+ has_success=True,
+ sent_receipts=[SimpleNamespace(driver_id="plugin.qq.sender")],
+ failed_receipts=[],
+ route_key=SimpleNamespace(platform="qq"),
+ )
+ )
+ stored_messages: List[Any] = []
+
+ monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_manager)
+ monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
+ monkeypatch.setattr(
+ send_service._chat_manager,
+ "get_session_by_session_id",
+ lambda stream_id: _build_target_stream() if stream_id == "test-session" else None,
+ )
+ monkeypatch.setattr(
+ send_service.MessageUtils,
+ "store_message_to_db",
+ lambda message: stored_messages.append(message),
+ )
+
+ result = await send_service.text_to_stream(text="你好", stream_id="test-session")
+
+ assert result is True
+ assert fake_manager.ensure_calls == 1
+ assert len(fake_manager.sent_messages) == 1
+ assert fake_manager.sent_messages[0]["metadata"] == {"show_log": False}
+ assert len(stored_messages) == 1
+
+
+@pytest.mark.asyncio
+async def test_text_to_stream_returns_false_when_platform_io_fails(monkeypatch: pytest.MonkeyPatch) -> None:
+ """Platform IO 批量发送全部失败时,应直接向上返回失败。"""
+ fake_manager = _FakePlatformIOManager(
+ delivery_batch=SimpleNamespace(
+ has_success=False,
+ sent_receipts=[],
+ failed_receipts=[
+ SimpleNamespace(
+ driver_id="plugin.qq.sender",
+ status="failed",
+ error="network error",
+ )
+ ],
+ route_key=SimpleNamespace(platform="qq"),
+ )
+ )
+
+ monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_manager)
+ monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
+ monkeypatch.setattr(
+ send_service._chat_manager,
+ "get_session_by_session_id",
+ lambda stream_id: _build_target_stream() if stream_id == "test-session" else None,
+ )
+
+ result = await send_service.text_to_stream(text="发送失败", stream_id="test-session")
+
+ assert result is False
+ assert fake_manager.ensure_calls == 1
+ assert len(fake_manager.sent_messages) == 1
diff --git a/src/chat/brain_chat/PFC/message_sender.py b/src/chat/brain_chat/PFC/message_sender.py
index ec5fb5ba..b9da905c 100644
--- a/src/chat/brain_chat/PFC/message_sender.py
+++ b/src/chat/brain_chat/PFC/message_sender.py
@@ -1,27 +1,28 @@
-import time
+"""PFC 侧消息发送封装。"""
+
from typing import Optional
-from maim_message import Seg
from rich.traceback import install
-from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
from src.chat.message_receive.chat_manager import BotChatSession
-from src.chat.message_receive.message import MessageSending
-from src.chat.message_receive.uni_message_sender import UniversalMessageSender
-from src.chat.utils.utils import get_bot_account
+from src.common.data_models.mai_message_data_model import MaiMessage
from src.common.logger import get_logger
-from src.config.config import global_config
+from src.services import send_service as send_api
install(extra_lines=3)
-
logger = get_logger("message_sender")
class DirectMessageSender:
- """直接消息发送器"""
+ """直接消息发送器。"""
- def __init__(self, private_name: str):
+ def __init__(self, private_name: str) -> None:
+ """初始化直接消息发送器。
+
+ Args:
+ private_name: 当前私聊实例的名称。
+ """
self.private_name = private_name
async def send_message(
@@ -30,58 +31,31 @@ class DirectMessageSender:
content: str,
reply_to_message: Optional[MaiMessage] = None,
) -> None:
- """发送消息到聊天流
+ """发送文本消息到聊天流。
Args:
- chat_stream: 聊天会话
- content: 消息内容
- reply_to_message: 要回复的消息(可选)
+ chat_stream: 目标聊天会话。
+ content: 待发送的文本内容。
+ reply_to_message: 可选的引用回复锚点消息。
+
+ Raises:
+ RuntimeError: 当消息发送失败时抛出。
"""
try:
- # 创建消息内容
- segments = Seg(type="seglist", data=[Seg(type="text", data=content)])
-
- # 获取麦麦的信息
- bot_user_id = get_bot_account(chat_stream.platform)
- if not bot_user_id:
- logger.error(f"[私聊][{self.private_name}]平台 {chat_stream.platform} 未配置机器人账号,无法发送消息")
- raise RuntimeError(f"平台 {chat_stream.platform} 未配置机器人账号")
- bot_user_info = UserInfo(
- user_id=bot_user_id,
- user_nickname=global_config.bot.nickname,
+ sent = await send_api.text_to_stream(
+ text=content,
+ stream_id=chat_stream.session_id,
+ set_reply=reply_to_message is not None,
+ reply_message=reply_to_message,
+ storage_message=True,
)
- # 用当前时间作为message_id,和之前那套sender一样
- message_id = f"dm{round(time.time(), 2)}"
-
- # 构建发送者信息(私聊时为接收者)
- sender_info = None
- if reply_to_message and reply_to_message.message_info and reply_to_message.message_info.user_info:
- sender_info = reply_to_message.message_info.user_info
-
- # 构建消息对象
- message = MessageSending(
- message_id=message_id,
- session=chat_stream,
- bot_user_info=bot_user_info,
- sender_info=sender_info,
- message_segment=segments,
- reply=reply_to_message,
- is_head=True,
- is_emoji=False,
- thinking_start_time=time.time(),
- )
-
- # 发送消息
- message_sender = UniversalMessageSender()
- sent = await message_sender.send_message(message, typing=False, set_reply=False, storage_message=True)
-
if sent:
logger.info(f"[私聊][{self.private_name}]PFC消息已发送: {content}")
- else:
- logger.error(f"[私聊][{self.private_name}]PFC消息发送失败")
- raise RuntimeError("消息发送失败")
+ return
- except Exception as e:
- logger.error(f"[私聊][{self.private_name}]PFC消息发送失败: {str(e)}")
+ logger.error(f"[私聊][{self.private_name}]PFC消息发送失败")
+ raise RuntimeError("消息发送失败")
+ except Exception as exc:
+ logger.error(f"[私聊][{self.private_name}]PFC消息发送失败: {exc}")
raise
diff --git a/src/chat/message_receive/uni_message_sender.py b/src/chat/message_receive/uni_message_sender.py
index df74e459..cf42e092 100644
--- a/src/chat/message_receive/uni_message_sender.py
+++ b/src/chat/message_receive/uni_message_sender.py
@@ -60,8 +60,7 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
发送顺序为:
1. WebUI 特殊链路
- 2. Platform IO 适配器链路
- 3. 旧版 ``maim_message`` / API Server 链路
+ 2. 旧版 ``maim_message`` / API Server 链路
Args:
message: 待发送的内部会话消息。
@@ -124,32 +123,6 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
logger.info(f"已将消息 '{message_preview}' 发往 WebUI 聊天室")
return True
- try:
- from src.plugin_runtime.integration import get_plugin_runtime_manager
-
- delivery_batch = await get_plugin_runtime_manager().try_send_message_via_platform_io(message)
- if delivery_batch is not None:
- if delivery_batch.has_success:
- successful_driver_ids = [
- receipt.driver_id or "unknown"
- for receipt in delivery_batch.sent_receipts
- ]
- if show_log:
- logger.info(
- f"已通过 Platform IO 将消息 '{message_preview}' 发往平台'{platform}' "
- f"(drivers: {', '.join(successful_driver_ids)})"
- )
- return True
-
- failed_details = "; ".join(
- f"driver={receipt.driver_id} status={receipt.status} error={receipt.error}"
- for receipt in delivery_batch.failed_receipts
- ) or "未命中任何发送路由"
- logger.warning(f"Platform IO 发送失败: platform={platform} {failed_details}")
- return False
- except Exception as exc:
- logger.warning(f"检查 Platform IO 出站链路时出现异常,将回退旧发送链: {exc}")
-
# Fallback 逻辑: 尝试通过 API Server 发送
async def send_with_new_api(legacy_exception: Optional[Exception] = None) -> bool:
"""通过 API Server 回退链路发送消息。
@@ -260,8 +233,21 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
raise e # 重新抛出其他异常
+async def send_prepared_message_to_platform(message: SessionMessage, show_log: bool = True) -> bool:
+ """发送一条已完成预处理的消息到底层平台。
+
+ Args:
+ message: 已经完成回复组件注入、文本处理等预处理的消息对象。
+ show_log: 是否输出发送成功日志。
+
+ Returns:
+ bool: 发送成功时返回 ``True``。
+ """
+ return await _send_message(message, show_log=show_log)
+
+
class UniversalMessageSender:
- """管理消息的注册、即时处理、发送和存储,并跟踪思考状态。"""
+ """旧链与 WebUI 的底层发送器。"""
def __init__(self) -> None:
"""初始化统一消息发送器。"""
@@ -276,17 +262,18 @@ class UniversalMessageSender:
storage_message: bool = True,
show_log: bool = True,
) -> bool:
- """
- 处理、发送并存储一条消息。
+ """通过旧链或 WebUI 发送并存储一条消息。
- 参数:
- message: MessageSession 对象,待发送的消息。
+ Args:
+ message: 待发送的内部消息对象。
typing: 是否模拟打字等待。
- set_reply: 是否构建回复引用消息。
+ set_reply: 是否构建引用回复消息。
+ reply_message_id: 被引用消息的 ID。
+ storage_message: 是否在发送成功后写入数据库。
+ show_log: 是否输出发送日志。
-
- 用法:
- - typing=True 时,发送前会有打字等待。
+ Returns:
+ bool: 发送成功时返回 ``True``。
"""
if not message.message_id:
logger.error("消息缺少 message_id,无法发送")
@@ -339,7 +326,7 @@ class UniversalMessageSender:
)
await asyncio.sleep(typing_time)
- sent_msg = await _send_message(message, show_log=show_log)
+ sent_msg = await send_prepared_message_to_platform(message, show_log=show_log)
if not sent_msg:
return False
diff --git a/src/chat/replyer/group_generator.py b/src/chat/replyer/group_generator.py
index 75563df7..4ffa14a7 100644
--- a/src/chat/replyer/group_generator.py
+++ b/src/chat/replyer/group_generator.py
@@ -17,7 +17,6 @@ from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo as MaimUser
from src.common.data_models.mai_message_data_model import MaiMessage
from src.chat.message_receive.message import SessionMessage
from src.chat.message_receive.chat_manager import BotChatSession
-from src.chat.message_receive.uni_message_sender import UniversalMessageSender
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
from src.chat.utils.utils import get_bot_account, get_chat_type_and_target_info, is_bot_self
from src.prompt.prompt_manager import prompt_manager
@@ -51,10 +50,15 @@ class DefaultReplyer:
chat_stream: BotChatSession,
request_type: str = "replyer",
):
+ """初始化群聊回复器。
+
+ Args:
+ chat_stream: 当前绑定的聊天会话。
+ request_type: LLM 请求类型标识。
+ """
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
self.chat_stream = chat_stream
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id)
- self.heart_fc_sender = UniversalMessageSender()
from src.chat.tool_executor import ToolExecutor
diff --git a/src/chat/replyer/private_generator.py b/src/chat/replyer/private_generator.py
index f642dd69..c125a42f 100644
--- a/src/chat/replyer/private_generator.py
+++ b/src/chat/replyer/private_generator.py
@@ -16,7 +16,6 @@ from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo as MaimUser
from src.common.data_models.mai_message_data_model import MaiMessage
from src.chat.message_receive.message import SessionMessage
from src.chat.message_receive.chat_manager import BotChatSession
-from src.chat.message_receive.uni_message_sender import UniversalMessageSender
from src.chat.utils.timer_calculator import Timer
from src.chat.utils.utils import get_bot_account, get_chat_type_and_target_info, is_bot_self
from src.prompt.prompt_manager import prompt_manager
@@ -47,10 +46,15 @@ class PrivateReplyer:
chat_stream: BotChatSession,
request_type: str = "replyer",
):
+ """初始化私聊回复器。
+
+ Args:
+ chat_stream: 当前绑定的聊天会话。
+ request_type: LLM 请求类型标识。
+ """
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
self.chat_stream = chat_stream
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id)
- self.heart_fc_sender = UniversalMessageSender()
# self.memory_activator = MemoryActivator()
from src.chat.tool_executor import ToolExecutor
diff --git a/src/common/message_server/server.py b/src/common/message_server/server.py
index 77a931e5..e75da4e7 100644
--- a/src/common/message_server/server.py
+++ b/src/common/message_server/server.py
@@ -21,7 +21,7 @@ class Server:
self._server: Optional[UvicornServer] = None
self.set_address(host, port)
- def register_router(self, router: APIRouter, prefix: str = ""):
+ def register_router(self, router: APIRouter, prefix: str = ""):
"""注册路由
APIRouter 用于对相关的路由端点进行分组和模块化管理:
diff --git a/src/platform_io/drivers/legacy_driver.py b/src/platform_io/drivers/legacy_driver.py
index bd74d8c7..ef90c772 100644
--- a/src/platform_io/drivers/legacy_driver.py
+++ b/src/platform_io/drivers/legacy_driver.py
@@ -1,16 +1,16 @@
-"""提供 Platform IO 的 legacy 传输驱动骨架。"""
+"""提供 Platform IO 的 legacy 传输驱动实现。"""
from typing import TYPE_CHECKING, Any, Dict, Optional
from src.platform_io.drivers.base import PlatformIODriver
-from src.platform_io.types import DeliveryReceipt, DriverDescriptor, DriverKind, RouteKey
+from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, RouteKey
if TYPE_CHECKING:
from src.chat.message_receive.message import SessionMessage
class LegacyPlatformDriver(PlatformIODriver):
- """面向 ``maim_message`` 旧链路的 Platform IO 驱动骨架。"""
+ """面向 ``UniversalMessageSender`` 旧链的 Platform IO 驱动。"""
def __init__(
self,
@@ -25,7 +25,7 @@ class LegacyPlatformDriver(PlatformIODriver):
Args:
driver_id: Broker 内的唯一驱动 ID。
platform: 该 legacy 适配器链路负责的平台。
- account_id: 可选的账号 ID 或 self ID。
+ account_id: 可选的账号 ID。
scope: 可选的额外路由作用域。
metadata: 可选的额外驱动元数据。
"""
@@ -45,7 +45,7 @@ class LegacyPlatformDriver(PlatformIODriver):
route_key: RouteKey,
metadata: Optional[Dict[str, Any]] = None,
) -> DeliveryReceipt:
- """通过 legacy 传输路径发送消息。
+ """通过旧链发送一条已经过预处理的消息。
Args:
message: 要投递的内部会话消息。
@@ -53,9 +53,40 @@ class LegacyPlatformDriver(PlatformIODriver):
metadata: 本次出站投递可选的 Broker 侧元数据。
Returns:
- DeliveryReceipt: 由驱动返回的规范化回执。
-
- Raises:
- NotImplementedError: 当前仍处于骨架阶段,尚未真正接入旧发送链。
+ DeliveryReceipt: 规范化后的发送回执。
"""
- raise NotImplementedError("LegacyPlatformDriver 仅完成地基实现,尚未接入旧发送链")
+ from src.chat.message_receive.uni_message_sender import send_prepared_message_to_platform
+
+ show_log = False
+ if isinstance(metadata, dict):
+ show_log = bool(metadata.get("show_log", False))
+
+ try:
+ sent = await send_prepared_message_to_platform(message, show_log=show_log)
+ except Exception as exc:
+ return DeliveryReceipt(
+ internal_message_id=message.message_id,
+ route_key=route_key,
+ status=DeliveryStatus.FAILED,
+ driver_id=self.driver_id,
+ driver_kind=self.descriptor.kind,
+ error=str(exc),
+ )
+
+ if not sent:
+ return DeliveryReceipt(
+ internal_message_id=message.message_id,
+ route_key=route_key,
+ status=DeliveryStatus.FAILED,
+ driver_id=self.driver_id,
+ driver_kind=self.descriptor.kind,
+ error="旧链发送失败",
+ )
+
+ return DeliveryReceipt(
+ internal_message_id=message.message_id,
+ route_key=route_key,
+ status=DeliveryStatus.SENT,
+ driver_id=self.driver_id,
+ driver_kind=self.descriptor.kind,
+ )
diff --git a/src/platform_io/manager.py b/src/platform_io/manager.py
index c96a9ddd..cb5996b4 100644
--- a/src/platform_io/manager.py
+++ b/src/platform_io/manager.py
@@ -36,6 +36,7 @@ class PlatformIOManager:
self._driver_registry = DriverRegistry()
self._send_route_table = RouteTable()
self._receive_route_table = RouteTable()
+ self._legacy_send_drivers: Dict[str, PlatformIODriver] = {}
self._deduplicator = MessageDeduplicator()
self._outbound_tracker = OutboundTracker()
self._inbound_dispatcher: Optional[InboundDispatcher] = None
@@ -75,6 +76,16 @@ class PlatformIOManager:
self._started = True
+ async def ensure_send_pipeline_ready(self) -> None:
+ """确保出站发送管线已准备就绪。
+
+ 该方法会先同步 legacy fallback driver,再在需要时启动 Broker。
+ send service 应只调用这一层准备入口,而不是自行判断旧链或插件链。
+ """
+ await self._sync_legacy_send_drivers()
+ if not self._started:
+ await self.start()
+
async def stop(self) -> None:
"""停止 Broker,并按逆序停止全部已注册驱动。
@@ -272,8 +283,60 @@ class PlatformIOManager:
removed_driver.clear_inbound_handler()
self._send_route_table.remove_bindings_by_driver(driver_id)
self._receive_route_table.remove_bindings_by_driver(driver_id)
+ self._legacy_send_drivers = {
+ platform: driver
+ for platform, driver in self._legacy_send_drivers.items()
+ if driver.driver_id != driver_id
+ }
return removed_driver
+ async def _sync_legacy_send_drivers(self) -> None:
+ """根据当前配置同步 legacy fallback driver。"""
+ from src.chat.utils.utils import get_all_bot_accounts
+ from src.platform_io.drivers.legacy_driver import LegacyPlatformDriver
+
+ desired_accounts = get_all_bot_accounts()
+ desired_platforms = set(desired_accounts.keys())
+ current_platforms = set(self._legacy_send_drivers.keys())
+
+ for platform in sorted(current_platforms - desired_platforms):
+ await self._remove_legacy_send_driver(platform)
+
+ for platform, account_id in desired_accounts.items():
+ existing_driver = self._legacy_send_drivers.get(platform)
+ if existing_driver is not None and existing_driver.descriptor.account_id == account_id:
+ continue
+
+ if existing_driver is not None:
+ await self._remove_legacy_send_driver(platform)
+
+ driver = LegacyPlatformDriver(
+ driver_id=f"legacy.send.{platform}",
+ platform=platform,
+ account_id=account_id,
+ )
+ if self._started:
+ await self.add_driver(driver)
+ else:
+ self.register_driver(driver)
+ self._legacy_send_drivers[platform] = driver
+
+ async def _remove_legacy_send_driver(self, platform: str) -> None:
+ """移除指定平台的 legacy fallback driver。
+
+ Args:
+ platform: 要移除的目标平台。
+ """
+ driver = self._legacy_send_drivers.get(platform)
+ if driver is None:
+ return
+
+ if self._started:
+ await self.remove_driver(driver.driver_id)
+ else:
+ self.unregister_driver(driver.driver_id)
+ self._legacy_send_drivers.pop(platform, None)
+
def bind_send_route(self, binding: RouteBinding) -> None:
"""为某个路由键绑定发送驱动。
@@ -353,7 +416,19 @@ class PlatformIOManager:
driver = self._driver_registry.get(binding.driver_id)
if driver is not None:
drivers.append(driver)
- return drivers
+ if drivers:
+ return drivers
+
+ fallback_driver = self._legacy_send_drivers.get(route_key.platform)
+ if fallback_driver is None:
+ return []
+
+ descriptor = fallback_driver.descriptor
+ if descriptor.account_id is not None and route_key.account_id not in (None, descriptor.account_id):
+ return []
+ if descriptor.scope is not None and route_key.scope not in (None, descriptor.scope):
+ return []
+ return [fallback_driver]
def resolve_driver(self, route_key: RouteKey) -> Optional[PlatformIODriver]:
"""兼容旧接口,返回首个命中的发送驱动。"""
diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py
index b74b2d46..ff51f419 100644
--- a/src/plugin_runtime/integration.py
+++ b/src/plugin_runtime/integration.py
@@ -157,7 +157,7 @@ class PluginRuntimeManager(
started_supervisors: List[PluginSupervisor] = []
try:
platform_io_manager.set_inbound_dispatcher(self._dispatch_platform_inbound)
- await platform_io_manager.start()
+ await platform_io_manager.ensure_send_pipeline_ready()
if self._builtin_supervisor:
await self._builtin_supervisor.start()
diff --git a/src/services/send_service.py b/src/services/send_service.py
index 6ca7d005..7903cdeb 100644
--- a/src/services/send_service.py
+++ b/src/services/send_service.py
@@ -1,39 +1,51 @@
"""
-发送服务模块
+发送服务模块。
-提供发送各种类型消息的核心功能。
+统一封装内部模块的出站消息发送逻辑:
+
+1. 内部模块统一调用本模块。
+2. send service 只负责构造和预处理消息。
+3. 具体走插件链还是 legacy 旧链,由 Platform IO 内部统一决策。
"""
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional
+from maim_message import Seg
+
+import asyncio
+import base64
+import hashlib
import time
import traceback
+from datetime import datetime
-from maim_message import BaseMessageInfo, GroupInfo as MaimGroupInfo, MessageBase, Seg, UserInfo as MaimUserInfo
-
+from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.chat.message_receive.message import SessionMessage
-from src.chat.message_receive.uni_message_sender import UniversalMessageSender
-from src.chat.utils.utils import get_bot_account
-from src.common.data_models.mai_message_data_model import MaiMessage
-from src.common.data_models.message_component_data_model import DictComponent, MessageSequence
+from src.chat.utils.utils import calculate_typing_time, get_bot_account
+from src.common.data_models.mai_message_data_model import GroupInfo, MaiMessage, MessageInfo, UserInfo
+from src.common.data_models.message_component_data_model import (
+ AtComponent,
+ DictComponent,
+ EmojiComponent,
+ ImageComponent,
+ MessageSequence,
+ ReplyComponent,
+ StandardMessageComponents,
+ TextComponent,
+ VoiceComponent,
+)
from src.common.logger import get_logger
+from src.common.utils.utils_message import MessageUtils
from src.config.config import global_config
+from src.platform_io import DeliveryBatch, get_platform_io_manager
from src.platform_io.route_key_factory import RouteKeyFactory
-if TYPE_CHECKING:
- from src.chat.message_receive.message import SessionMessage
-
logger = get_logger("send_service")
-# =============================================================================
-# 内部实现函数
-# =============================================================================
-
-
-def _inherit_platform_io_route_metadata(target_stream: Any) -> Dict[str, object]:
- """从目标会话上下文继承 Platform IO 路由元数据。
+def _inherit_platform_io_route_metadata(target_stream: BotChatSession) -> Dict[str, object]:
+ """从目标会话继承 Platform IO 路由元数据。
Args:
target_stream: 当前消息要发送到的会话对象。
@@ -44,12 +56,11 @@ def _inherit_platform_io_route_metadata(target_stream: Any) -> Dict[str, object]
"""
inherited_metadata: Dict[str, object] = {}
- context = getattr(target_stream, "context", None)
- context_message = getattr(context, "message", None)
+ context_message = target_stream.context.message if target_stream.context else None
if context_message is None:
return inherited_metadata
- additional_config = getattr(context_message.message_info, "additional_config", {})
+ additional_config = context_message.message_info.additional_config
if not isinstance(additional_config, dict):
return inherited_metadata
@@ -61,33 +72,412 @@ def _inherit_platform_io_route_metadata(target_stream: Any) -> Dict[str, object]
if normalized_value:
inherited_metadata[key] = value
- target_group_id = getattr(target_stream, "group_id", None)
- if target_group_id is not None:
- normalized_group_id = str(target_group_id).strip()
+ if target_stream.group_id:
+ normalized_group_id = str(target_stream.group_id).strip()
if normalized_group_id:
inherited_metadata["platform_io_target_group_id"] = normalized_group_id
- target_user_id = getattr(target_stream, "user_id", None)
- if target_user_id is not None:
- normalized_user_id = str(target_user_id).strip()
+ if target_stream.user_id:
+ normalized_user_id = str(target_stream.user_id).strip()
if normalized_user_id:
inherited_metadata["platform_io_target_user_id"] = normalized_user_id
return inherited_metadata
+def _build_component_from_seg(message_segment: Seg) -> StandardMessageComponents:
+ """将单个消息段转换为内部消息组件。
+
+ Args:
+ message_segment: 待转换的消息段。
+
+ Returns:
+ StandardMessageComponents: 转换后的内部消息组件。
+ """
+ segment_type = str(message_segment.type or "").strip().lower()
+ segment_data = message_segment.data
+
+ if segment_type == "text":
+ return TextComponent(text=str(segment_data or ""))
+
+ if segment_type == "image":
+ image_binary = base64.b64decode(str(segment_data or ""))
+ return ImageComponent(
+ binary_hash=hashlib.sha256(image_binary).hexdigest(),
+ binary_data=image_binary,
+ )
+
+ if segment_type == "emoji":
+ emoji_binary = base64.b64decode(str(segment_data or ""))
+ return EmojiComponent(
+ binary_hash=hashlib.sha256(emoji_binary).hexdigest(),
+ binary_data=emoji_binary,
+ )
+
+ if segment_type == "voice":
+ voice_binary = base64.b64decode(str(segment_data or ""))
+ return VoiceComponent(
+ binary_hash=hashlib.sha256(voice_binary).hexdigest(),
+ binary_data=voice_binary,
+ )
+
+ if segment_type == "at":
+ return AtComponent(target_user_id=str(segment_data or ""))
+
+ if segment_type == "reply":
+ return ReplyComponent(target_message_id=str(segment_data or ""))
+
+ if segment_type == "dict" and isinstance(segment_data, dict):
+ return DictComponent(data=segment_data)
+
+ return DictComponent(data={"type": segment_type, "data": segment_data})
+
+
+def _build_message_sequence_from_seg(message_segment: Seg) -> MessageSequence:
+ """将消息段转换为内部消息组件序列。
+
+ Args:
+ message_segment: 待转换的消息段。
+
+ Returns:
+ MessageSequence: 转换后的消息组件序列。
+ """
+ if str(message_segment.type or "").strip().lower() == "seglist":
+ raw_segments = message_segment.data
+ if not isinstance(raw_segments, list):
+ raise ValueError("seglist 类型的消息段数据必须是列表")
+ components = [
+ _build_component_from_seg(item)
+ for item in raw_segments
+ if isinstance(item, Seg)
+ ]
+ return MessageSequence(components=components)
+
+ return MessageSequence(components=[_build_component_from_seg(message_segment)])
+
+
+def _build_processed_plain_text(message: SessionMessage) -> str:
+ """为出站消息构造轻量纯文本摘要。
+
+ Args:
+ message: 待发送的内部消息对象。
+
+ Returns:
+ str: 适用于日志与打字时长估算的纯文本摘要。
+ """
+ processed_parts: List[str] = []
+ for component in message.raw_message.components:
+ if isinstance(component, TextComponent):
+ processed_parts.append(component.text)
+ continue
+
+ if isinstance(component, ImageComponent):
+ processed_parts.append(component.content or "[图片]")
+ continue
+
+ if isinstance(component, EmojiComponent):
+ processed_parts.append(component.content or "[表情]")
+ continue
+
+ if isinstance(component, VoiceComponent):
+ processed_parts.append(component.content or "[语音]")
+ continue
+
+ if isinstance(component, AtComponent):
+ at_target = component.target_user_cardname or component.target_user_nickname or component.target_user_id
+ processed_parts.append(f"@{at_target}")
+ continue
+
+ if isinstance(component, ReplyComponent):
+ processed_parts.append(component.target_message_content or "[回复消息]")
+ continue
+
+ if isinstance(component, DictComponent):
+ raw_type = component.data.get("type") if isinstance(component.data, dict) else None
+ if isinstance(raw_type, str) and raw_type.strip():
+ processed_parts.append(f"[{raw_type.strip()}消息]")
+ else:
+ processed_parts.append("[自定义消息]")
+ continue
+
+ return " ".join(part for part in processed_parts if part)
+
+
+def _build_outbound_session_message(
+ message_segment: Seg,
+ stream_id: str,
+ display_message: str = "",
+ reply_message: Optional[MaiMessage] = None,
+ selected_expressions: Optional[List[int]] = None,
+) -> Optional[SessionMessage]:
+ """根据目标会话构建待发送的内部消息对象。
+
+ Args:
+ message_segment: 待发送的消息段。
+ stream_id: 目标会话 ID。
+ display_message: 用于界面展示的文本内容。
+ reply_message: 被回复的锚点消息。
+ selected_expressions: 可选的表情候选索引列表。
+
+ Returns:
+ Optional[SessionMessage]: 构建成功时返回内部消息对象;若目标会话或
+ 机器人账号不存在,则返回 ``None``。
+ """
+ target_stream = _chat_manager.get_session_by_session_id(stream_id)
+ if target_stream is None:
+ logger.error(f"[SendService] 未找到聊天流: {stream_id}")
+ return None
+
+ bot_user_id = get_bot_account(target_stream.platform)
+ if not bot_user_id:
+ logger.error(f"[SendService] 平台 {target_stream.platform} 未配置机器人账号,无法发送消息")
+ return None
+
+ current_time = time.time()
+ message_id = f"send_api_{int(current_time * 1000)}"
+ anchor_message = reply_message.deepcopy() if reply_message is not None else None
+
+ group_info: Optional[GroupInfo] = None
+ if target_stream.group_id:
+ group_name = ""
+ if (
+ target_stream.context
+ and target_stream.context.message
+ and target_stream.context.message.message_info.group_info
+ ):
+ group_name = target_stream.context.message.message_info.group_info.group_name
+ group_info = GroupInfo(
+ group_id=target_stream.group_id,
+ group_name=group_name,
+ )
+
+ additional_config: Dict[str, object] = _inherit_platform_io_route_metadata(target_stream)
+ if selected_expressions is not None:
+ additional_config["selected_expressions"] = selected_expressions
+
+ outbound_message = SessionMessage(
+ message_id=message_id,
+ timestamp=datetime.fromtimestamp(current_time),
+ platform=target_stream.platform,
+ )
+ outbound_message.message_info = MessageInfo(
+ user_info=UserInfo(
+ user_id=bot_user_id,
+ user_nickname=global_config.bot.nickname,
+ ),
+ group_info=group_info,
+ additional_config=additional_config,
+ )
+ outbound_message.raw_message = _build_message_sequence_from_seg(message_segment)
+ outbound_message.session_id = target_stream.session_id
+ outbound_message.display_message = display_message
+ outbound_message.reply_to = anchor_message.message_id if anchor_message is not None else None
+ outbound_message.is_emoji = message_segment.type == "emoji"
+ outbound_message.is_picture = message_segment.type == "image"
+ outbound_message.is_command = message_segment.type == "command"
+ outbound_message.initialized = True
+ return outbound_message
+
+
+def _ensure_reply_component(message: SessionMessage, reply_message_id: str) -> None:
+ """为消息补充回复组件。
+
+ Args:
+ message: 待发送的内部消息对象。
+ reply_message_id: 被引用消息的 ID。
+ """
+ if message.raw_message.components:
+ first_component = message.raw_message.components[0]
+ if isinstance(first_component, ReplyComponent) and first_component.target_message_id == reply_message_id:
+ return
+
+ message.raw_message.components.insert(0, ReplyComponent(target_message_id=reply_message_id))
+
+
+async def _prepare_message_for_platform_io(
+ message: SessionMessage,
+ *,
+ typing: bool,
+ set_reply: bool,
+ reply_message_id: Optional[str],
+) -> None:
+ """为 Platform IO 发送链预处理消息。
+
+ Args:
+ message: 待发送的内部消息对象。
+ typing: 是否模拟打字等待。
+ set_reply: 是否构建引用回复组件。
+ reply_message_id: 被引用消息的 ID。
+
+ Raises:
+ ValueError: 当要求设置引用回复但缺少 ``reply_message_id`` 时抛出。
+ """
+ if set_reply:
+ if not reply_message_id:
+ raise ValueError("set_reply=True 时必须提供 reply_message_id")
+ _ensure_reply_component(message, reply_message_id)
+
+ message.processed_plain_text = _build_processed_plain_text(message)
+ if typing:
+ typing_time = calculate_typing_time(
+ input_string=message.processed_plain_text or "",
+ is_emoji=message.is_emoji,
+ )
+ await asyncio.sleep(typing_time)
+
+
+def _store_sent_message(message: SessionMessage) -> None:
+ """将已成功发送的消息写入数据库。
+
+ Args:
+ message: 已成功发送的内部消息对象。
+ """
+ MessageUtils.store_message_to_db(message)
+
+
+def _log_platform_io_failures(delivery_batch: DeliveryBatch) -> None:
+ """输出 Platform IO 批量发送失败详情。
+
+ Args:
+ delivery_batch: Platform IO 返回的批量回执。
+ """
+ failed_details = "; ".join(
+ f"driver={receipt.driver_id} status={receipt.status} error={receipt.error}"
+ for receipt in delivery_batch.failed_receipts
+ ) or "未命中任何发送路由"
+ logger.warning(
+ "[SendService] Platform IO 发送失败: platform=%s %s",
+ delivery_batch.route_key.platform,
+ failed_details,
+ )
+
+
+async def _send_via_platform_io(
+ message: SessionMessage,
+ *,
+ typing: bool,
+ set_reply: bool,
+ reply_message_id: Optional[str],
+ storage_message: bool,
+ show_log: bool,
+) -> bool:
+ """通过 Platform IO 发送消息。
+
+ Args:
+ message: 待发送的内部消息对象。
+ typing: 是否模拟打字等待。
+ set_reply: 是否设置引用回复。
+ reply_message_id: 被引用消息的 ID。
+ storage_message: 发送成功后是否写入数据库。
+ show_log: 是否输出发送成功日志。
+
+ Returns:
+ bool: 发送成功时返回 ``True``。
+ """
+ platform_io_manager = get_platform_io_manager()
+ try:
+ await platform_io_manager.ensure_send_pipeline_ready()
+ except Exception as exc:
+ logger.error(f"[SendService] 准备 Platform IO 发送管线失败: {exc}")
+ logger.debug(traceback.format_exc())
+ return False
+
+ try:
+ route_key = platform_io_manager.build_route_key_from_message(message)
+ except Exception as exc:
+ logger.warning(f"[SendService] 根据消息构造 Platform IO 路由键失败: {exc}")
+ return False
+
+ try:
+ await _prepare_message_for_platform_io(
+ message,
+ typing=typing,
+ set_reply=set_reply,
+ reply_message_id=reply_message_id,
+ )
+ delivery_batch = await platform_io_manager.send_message(
+ message,
+ route_key,
+ metadata={"show_log": False},
+ )
+ except Exception as exc:
+ logger.error(f"[SendService] Platform IO 发送异常: {exc}")
+ logger.debug(traceback.format_exc())
+ return False
+
+ if delivery_batch.has_success:
+ if storage_message:
+ _store_sent_message(message)
+ if show_log:
+ successful_driver_ids = [
+ receipt.driver_id or "unknown"
+ for receipt in delivery_batch.sent_receipts
+ ]
+ logger.info(
+ "[SendService] 已通过 Platform IO 将消息发往平台 '%s' (drivers: %s)",
+ route_key.platform,
+ ", ".join(successful_driver_ids),
+ )
+ return True
+
+ _log_platform_io_failures(delivery_batch)
+ return False
+
+
+async def send_session_message(
+ message: SessionMessage,
+ *,
+ typing: bool = False,
+ set_reply: bool = False,
+ reply_message_id: Optional[str] = None,
+ storage_message: bool = True,
+ show_log: bool = True,
+) -> bool:
+ """统一发送一条内部消息。
+
+ 该方法是内部模块的统一发送入口:
+
+ 1. 构造并维护内部消息对象。
+ 2. 由 Platform IO 统一决定走插件链还是 legacy 旧链。
+ 3. send service 不再自行判断底层发送路径。
+
+ Args:
+ message: 待发送的内部消息对象。
+ typing: 是否模拟打字等待。
+ set_reply: 是否设置引用回复。
+ reply_message_id: 被引用消息的 ID。
+ storage_message: 发送成功后是否写入数据库。
+ show_log: 是否输出发送日志。
+
+ Returns:
+ bool: 发送成功时返回 ``True``,否则返回 ``False``。
+ """
+ if not message.message_id:
+ logger.error("[SendService] 消息缺少 message_id,无法发送")
+ raise ValueError("消息缺少 message_id,无法发送")
+
+ return await _send_via_platform_io(
+ message,
+ typing=typing,
+ set_reply=set_reply,
+ reply_message_id=reply_message_id,
+ storage_message=storage_message,
+ show_log=show_log,
+ )
+
+
async def _send_to_target(
message_segment: Seg,
stream_id: str,
display_message: str = "",
typing: bool = False,
set_reply: bool = False,
- reply_message: Optional["SessionMessage"] = None,
+ reply_message: Optional[MaiMessage] = None,
storage_message: bool = True,
show_log: bool = True,
selected_expressions: Optional[List[int]] = None,
) -> bool:
- """向指定目标发送消息。
+ """向指定目标构建并发送消息。
Args:
message_segment: 待发送的消息段。
@@ -104,110 +494,66 @@ async def _send_to_target(
bool: 发送成功返回 ``True``,否则返回 ``False``。
"""
try:
- if set_reply and not reply_message:
+ if set_reply and reply_message is None:
logger.warning("[SendService] 使用引用回复,但未提供回复消息")
return False
if show_log:
logger.debug(f"[SendService] 发送{message_segment.type}消息到 {stream_id}")
- target_stream = _chat_manager.get_session_by_session_id(stream_id)
- if not target_stream:
- logger.error(f"[SendService] 未找到聊天流: {stream_id}")
- return False
-
- message_sender = UniversalMessageSender()
-
- current_time = time.time()
- message_id = f"send_api_{int(current_time * 1000)}"
-
- anchor_message: Optional[MaiMessage] = None
- if reply_message:
- anchor_message = reply_message.deepcopy()
- if anchor_message:
- logger.debug(
- f"[SendService] 找到匹配的回复消息,发送者: {anchor_message.message_info.user_info.user_id}"
- )
-
- group_info = None
- if target_stream.group_id:
- group_name = ""
- if target_stream.context and target_stream.context.message and target_stream.context.message.message_info.group_info:
- group_name = target_stream.context.message.message_info.group_info.group_name
- group_info = MaimGroupInfo(
- group_id=target_stream.group_id,
- group_name=group_name,
- platform=target_stream.platform,
- )
-
- additional_config: Dict[str, object] = _inherit_platform_io_route_metadata(target_stream)
- if selected_expressions is not None:
- additional_config["selected_expressions"] = selected_expressions
- bot_user_id = get_bot_account(target_stream.platform)
- if not bot_user_id:
- logger.error(f"[SendService] 平台 {target_stream.platform} 未配置机器人账号,无法发送消息")
- return False
-
- maim_message = MessageBase(
- message_info=BaseMessageInfo(
- platform=target_stream.platform,
- message_id=message_id,
- time=current_time,
- user_info=MaimUserInfo(
- user_id=bot_user_id,
- user_nickname=global_config.bot.nickname,
- platform=target_stream.platform,
- ),
- group_info=group_info,
- additional_config=additional_config,
- ),
+ outbound_message = _build_outbound_session_message(
message_segment=message_segment,
+ stream_id=stream_id,
+ display_message=display_message,
+ reply_message=reply_message,
+ selected_expressions=selected_expressions,
)
- bot_message = SessionMessage.from_maim_message(maim_message)
- bot_message.session_id = target_stream.session_id
- bot_message.display_message = display_message
- bot_message.reply_to = anchor_message.message_id if anchor_message else None
- bot_message.is_emoji = message_segment.type == "emoji"
- bot_message.is_picture = message_segment.type == "image"
- bot_message.is_command = message_segment.type == "command"
+ if outbound_message is None:
+ return False
- sent_msg = await message_sender.send_message(
- bot_message,
+ sent = await send_session_message(
+ outbound_message,
typing=typing,
set_reply=set_reply,
- reply_message_id=anchor_message.message_id if anchor_message else None,
+ reply_message_id=reply_message.message_id if reply_message is not None else None,
storage_message=storage_message,
show_log=show_log,
)
-
- if sent_msg:
+ if sent:
logger.debug(f"[SendService] 成功发送消息到 {stream_id}")
return True
- else:
- logger.error("[SendService] 发送消息失败")
- return False
- except Exception as e:
- logger.error(f"[SendService] 发送消息时出错: {e}")
+ logger.error("[SendService] 发送消息失败")
+ return False
+ except Exception as exc:
+ logger.error(f"[SendService] 发送消息时出错: {exc}")
traceback.print_exc()
return False
-# =============================================================================
-# 公共函数 - 预定义类型的发送函数
-# =============================================================================
-
-
async def text_to_stream(
text: str,
stream_id: str,
typing: bool = False,
set_reply: bool = False,
- reply_message: Optional["SessionMessage"] = None,
+ reply_message: Optional[MaiMessage] = None,
storage_message: bool = True,
selected_expressions: Optional[List[int]] = None,
) -> bool:
- """向指定流发送文本消息"""
+ """向指定流发送文本消息。
+
+ Args:
+ text: 要发送的文本内容。
+ stream_id: 目标会话 ID。
+ typing: 是否显示输入中状态。
+ set_reply: 是否附带引用回复。
+ reply_message: 被回复的消息对象。
+ storage_message: 是否在发送成功后写入数据库。
+ selected_expressions: 可选的表情候选索引列表。
+
+ Returns:
+ bool: 发送成功时返回 ``True``。
+ """
return await _send_to_target(
message_segment=Seg(type="text", data=text),
stream_id=stream_id,
@@ -225,9 +571,20 @@ async def emoji_to_stream(
stream_id: str,
storage_message: bool = True,
set_reply: bool = False,
- reply_message: Optional["SessionMessage"] = None,
+ reply_message: Optional[MaiMessage] = None,
) -> bool:
- """向指定流发送表情包"""
+ """向指定流发送表情消息。
+
+ Args:
+ emoji_base64: 表情图片的 Base64 内容。
+ stream_id: 目标会话 ID。
+ storage_message: 是否在发送成功后写入数据库。
+ set_reply: 是否附带引用回复。
+ reply_message: 被回复的消息对象。
+
+ Returns:
+ bool: 发送成功时返回 ``True``。
+ """
return await _send_to_target(
message_segment=Seg(type="emoji", data=emoji_base64),
stream_id=stream_id,
@@ -244,9 +601,20 @@ async def image_to_stream(
stream_id: str,
storage_message: bool = True,
set_reply: bool = False,
- reply_message: Optional["SessionMessage"] = None,
+ reply_message: Optional[MaiMessage] = None,
) -> bool:
- """向指定流发送图片"""
+ """向指定流发送图片消息。
+
+ Args:
+ image_base64: 图片的 Base64 内容。
+ stream_id: 目标会话 ID。
+ storage_message: 是否在发送成功后写入数据库。
+ set_reply: 是否附带引用回复。
+ reply_message: 被回复的消息对象。
+
+ Returns:
+ bool: 发送成功时返回 ``True``。
+ """
return await _send_to_target(
message_segment=Seg(type="image", data=image_base64),
stream_id=stream_id,
@@ -260,18 +628,33 @@ async def image_to_stream(
async def custom_to_stream(
message_type: str,
- content: str | Dict,
+ content: str | Dict[str, Any],
stream_id: str,
display_message: str = "",
typing: bool = False,
- reply_message: Optional["SessionMessage"] = None,
+ reply_message: Optional[MaiMessage] = None,
set_reply: bool = False,
storage_message: bool = True,
show_log: bool = True,
) -> bool:
- """向指定流发送自定义类型消息"""
+ """向指定流发送自定义类型消息。
+
+ Args:
+ message_type: 自定义消息类型。
+ content: 自定义消息内容。
+ stream_id: 目标会话 ID。
+ display_message: 用于展示的文本内容。
+ typing: 是否显示输入中状态。
+ reply_message: 被回复的消息对象。
+ set_reply: 是否附带引用回复。
+ storage_message: 是否在发送成功后写入数据库。
+ show_log: 是否输出发送日志。
+
+ Returns:
+ bool: 发送成功时返回 ``True``。
+ """
return await _send_to_target(
- message_segment=Seg(type=message_type, data=content), # type: ignore
+ message_segment=Seg(type=message_type, data=content), # type: ignore[arg-type]
stream_id=stream_id,
display_message=display_message,
typing=typing,
@@ -287,18 +670,33 @@ async def custom_reply_set_to_stream(
stream_id: str,
display_message: str = "",
typing: bool = False,
- reply_message: Optional["SessionMessage"] = None,
+ reply_message: Optional[MaiMessage] = None,
set_reply: bool = False,
storage_message: bool = True,
show_log: bool = True,
) -> bool:
- """向指定流发送消息组件序列。"""
- flag: bool = True
+ """向指定流发送消息组件序列。
+
+ Args:
+ reply_set: 待发送的消息组件序列。
+ stream_id: 目标会话 ID。
+ display_message: 用于展示的文本内容。
+ typing: 是否显示输入中状态。
+ reply_message: 被回复的消息对象。
+ set_reply: 是否附带引用回复。
+ storage_message: 是否在发送成功后写入数据库。
+ show_log: 是否输出发送日志。
+
+ Returns:
+ bool: 全部组件发送成功时返回 ``True``。
+ """
+ success = True
for component in reply_set.components:
if isinstance(component, DictComponent):
- message_seg = Seg(type="dict", data=component.data) # type: ignore
+ message_seg = Seg(type="dict", data=component.data) # type: ignore[arg-type]
else:
message_seg = await component.to_seg()
+
status = await _send_to_target(
message_segment=message_seg,
stream_id=stream_id,
@@ -310,8 +708,8 @@ async def custom_reply_set_to_stream(
show_log=show_log,
)
if not status:
- flag = False
+ success = False
logger.error(f"[SendService] 发送消息组件失败,组件类型:{type(component).__name__}")
set_reply = False
- return flag
+ return success
From 18a0e7664ad23a1b582610557501f3d81923fd8a Mon Sep 17 00:00:00 2001
From: DrSmoothl <1787882683@qq.com>
Date: Mon, 23 Mar 2026 16:14:13 +0800
Subject: [PATCH 32/45] Refactor plugin runtime components and enhance message
handling
- Removed unused core action mirror functionality from PluginRunnerSupervisor.
- Simplified action and command execution logic in send_service.py.
- Introduced ComponentQueryService for unified component querying in plugin runtime.
- Enhanced message component handling with new binary component support.
- Improved message sequence construction and detection of outbound message flags.
- Updated methods for sending messages to streamline the process and improve readability.
---
pytests/test_platform_io_dedupe.py | 6 +-
pytests/test_plugin_runtime_action_bridge.py | 328 +++++---
src/chat/brain_chat/brain_planner.py | 32 +-
src/chat/message_receive/bot.py | 114 +--
src/chat/planner_actions/action_manager.py | 14 +-
src/chat/planner_actions/planner.py | 41 +-
src/chat/tool_executor.py | 14 +-
src/core/component_registry.py | 239 ------
src/platform_io/manager.py | 24 -
src/plugin_runtime/capabilities/core.py | 8 +-
src/plugin_runtime/capabilities/data.py | 4 +-
src/plugin_runtime/component_query.py | 709 ++++++++++++++++++
src/plugin_runtime/host/component_registry.py | 72 +-
src/plugin_runtime/host/supervisor.py | 238 ------
src/services/send_service.py | 265 ++++---
15 files changed, 1255 insertions(+), 853 deletions(-)
delete mode 100644 src/core/component_registry.py
create mode 100644 src/plugin_runtime/component_query.py
diff --git a/pytests/test_platform_io_dedupe.py b/pytests/test_platform_io_dedupe.py
index 68ae95c6..d6bdd1dd 100644
--- a/pytests/test_platform_io_dedupe.py
+++ b/pytests/test_platform_io_dedupe.py
@@ -72,10 +72,10 @@ class _StubPlatformIODriver(PlatformIODriver):
def _build_manager() -> PlatformIOManager:
- """构造带有最小 active owner 的 Broker 管理器。
+ """构造带有最小接收路由的 Broker 管理器。
Returns:
- PlatformIOManager: 已注册测试驱动并绑定活动路由的 Broker。
+ PlatformIOManager: 已注册测试驱动并绑定接收路由的 Broker。
"""
manager = PlatformIOManager()
driver = _StubPlatformIODriver(
@@ -88,7 +88,7 @@ def _build_manager() -> PlatformIOManager:
)
)
manager.register_driver(driver)
- manager.bind_route(
+ manager.bind_receive_route(
RouteBinding(
route_key=RouteKey(platform="qq", account_id="10001", scope="main"),
driver_id=driver.driver_id,
diff --git a/pytests/test_plugin_runtime_action_bridge.py b/pytests/test_plugin_runtime_action_bridge.py
index f2364094..e13dfaf3 100644
--- a/pytests/test_plugin_runtime_action_bridge.py
+++ b/pytests/test_plugin_runtime_action_bridge.py
@@ -1,57 +1,109 @@
+"""核心组件查询层与插件运行时聚合测试。"""
+
from types import SimpleNamespace
from typing import Any
import pytest
-from src.core.component_registry import component_registry as core_component_registry
+import src.plugin_runtime.integration as integration_module
+
+from src.core.types import ActionInfo, ToolInfo
+from src.plugin_runtime.component_query import component_query_service
from src.plugin_runtime.host.supervisor import PluginSupervisor
-from src.plugin_runtime.protocol.envelope import ComponentDeclaration, RegisterPluginPayload
-def _build_action_payload(plugin_id: str, action_name: str) -> RegisterPluginPayload:
- """构造用于测试的 runtime Action 注册载荷。
+class _FakeRuntimeManager:
+ """测试用插件运行时管理器。"""
+
+ def __init__(self, supervisor: PluginSupervisor, plugin_id: str, plugin_config: dict[str, Any]) -> None:
+ """初始化测试用运行时管理器。
+
+ Args:
+ supervisor: 持有测试组件的监督器。
+ plugin_id: 目标插件 ID。
+ plugin_config: 需要返回的插件配置。
+ """
+
+ self.supervisors = [supervisor]
+ self._plugin_id = plugin_id
+ self._plugin_config = plugin_config
+
+ def _get_supervisor_for_plugin(self, plugin_id: str) -> PluginSupervisor | None:
+ """按插件 ID 返回对应监督器。
+
+ Args:
+ plugin_id: 目标插件 ID。
+
+ Returns:
+ PluginSupervisor | None: 命中时返回监督器。
+ """
+
+ return self.supervisors[0] if plugin_id == self._plugin_id else None
+
+ def _load_plugin_config_for_supervisor(self, supervisor: Any, plugin_id: str) -> dict[str, Any]:
+ """返回测试配置。
+
+ Args:
+ supervisor: 监督器实例。
+ plugin_id: 目标插件 ID。
+
+ Returns:
+ dict[str, Any]: 测试配置内容。
+ """
+
+ del supervisor
+ if plugin_id != self._plugin_id:
+ return {}
+ return dict(self._plugin_config)
+
+
+def _install_runtime_manager(
+ monkeypatch: pytest.MonkeyPatch,
+ supervisor: PluginSupervisor,
+ plugin_id: str,
+ plugin_config: dict[str, Any] | None = None,
+) -> None:
+ """为测试安装假的运行时管理器。
Args:
- plugin_id: 插件 ID。
- action_name: Action 名称。
-
- Returns:
- RegisterPluginPayload: 测试用注册载荷。
+ monkeypatch: pytest monkeypatch 对象。
+ supervisor: 持有测试组件的监督器。
+ plugin_id: 测试插件 ID。
+ plugin_config: 可选的测试配置内容。
"""
- return RegisterPluginPayload(
- plugin_id=plugin_id,
- plugin_version="1.0.0",
- components=[
- ComponentDeclaration(
- name=action_name,
- component_type="ACTION",
- plugin_id=plugin_id,
- metadata={
- "description": "发送一个测试回复",
- "enabled": True,
- "activation_type": "keyword",
- "activation_probability": 0.25,
- "activation_keywords": ["测试", "hello"],
- "action_parameters": {"target": "目标对象"},
- "action_require": ["需要发送回复时使用"],
- "associated_types": ["text"],
- "parallel_action": True,
- },
- )
- ],
- )
+
+ fake_manager = _FakeRuntimeManager(supervisor, plugin_id, plugin_config or {"enabled": True})
+ monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: fake_manager)
@pytest.mark.asyncio
-async def test_runtime_actions_are_mirrored_into_core_registry_and_invoked(monkeypatch: pytest.MonkeyPatch) -> None:
- """运行时 Action 应镜像到旧核心注册表,并可由旧 Planner 执行。"""
+async def test_core_component_registry_reads_runtime_action_and_executor(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """核心查询层应直接读取运行时 Action,并返回 RPC 执行闭包。"""
+
plugin_id = "runtime_action_bridge_plugin"
action_name = "runtime_action_bridge_test"
- payload = _build_action_payload(plugin_id=plugin_id, action_name=action_name)
supervisor = PluginSupervisor(plugin_dirs=[])
captured: dict[str, Any] = {}
- core_component_registry.remove_action(action_name)
+ supervisor.component_registry.register_component(
+ name=action_name,
+ component_type="ACTION",
+ plugin_id=plugin_id,
+ metadata={
+ "description": "发送一个测试回复",
+ "enabled": True,
+ "activation_type": "keyword",
+ "activation_probability": 0.25,
+ "activation_keywords": ["测试", "hello"],
+ "action_parameters": {"target": "目标对象"},
+ "action_require": ["需要发送回复时使用"],
+ "associated_types": ["text"],
+ "parallel_action": True,
+ },
+ )
+ _install_runtime_manager(monkeypatch, supervisor, plugin_id, {"enabled": True, "mode": "test"})
async def fake_invoke_plugin(
method: str,
@@ -60,18 +112,8 @@ async def test_runtime_actions_are_mirrored_into_core_registry_and_invoked(monke
args: dict[str, Any] | None = None,
timeout_ms: int = 30000,
) -> Any:
- """模拟 plugin runtime Action 调用。
+ """模拟动作 RPC 调用。"""
- Args:
- method: RPC 方法名。
- plugin_id: 插件 ID。
- component_name: 组件名称。
- args: 调用参数。
- timeout_ms: RPC 超时时间。
-
- Returns:
- Any: 伪造的 RPC 响应对象。
- """
captured["method"] = method
captured["plugin_id"] = plugin_id
captured["component_name"] = component_name
@@ -81,58 +123,162 @@ async def test_runtime_actions_are_mirrored_into_core_registry_and_invoked(monke
monkeypatch.setattr(supervisor, "invoke_plugin", fake_invoke_plugin)
- try:
- supervisor._mirror_runtime_actions_to_core_registry(payload)
+ action_info = component_query_service.get_action_info(action_name)
+ assert isinstance(action_info, ActionInfo)
+ assert action_info.plugin_name == plugin_id
+ assert action_info.description == "发送一个测试回复"
+ assert action_info.activation_keywords == ["测试", "hello"]
+ assert action_info.random_activation_probability == 0.25
+ assert action_info.parallel_action is True
+ assert action_name in component_query_service.get_default_actions()
+ assert component_query_service.get_plugin_config(plugin_id) == {"enabled": True, "mode": "test"}
- action_info = core_component_registry.get_action_info(action_name)
- assert action_info is not None
- assert action_info.plugin_name == plugin_id
- assert action_info.description == "发送一个测试回复"
- assert action_info.activation_keywords == ["测试", "hello"]
- assert action_info.random_activation_probability == 0.25
- assert action_info.parallel_action is True
+ executor = component_query_service.get_action_executor(action_name)
+ assert executor is not None
- executor = core_component_registry.get_action_executor(action_name)
- assert executor is not None
+ success, reason = await executor(
+ action_data={"target": "MaiBot"},
+ action_reasoning="当前适合使用这个动作",
+ cycle_timers={"planner": 0.1},
+ thinking_id="tid-1",
+ chat_stream=SimpleNamespace(session_id="stream-1"),
+ log_prefix="[test]",
+ shutting_down=False,
+ plugin_config={"enabled": True},
+ )
- success, reason = await executor(
- action_data={"target": "MaiBot"},
- action_reasoning="当前适合使用这个动作",
- cycle_timers={"planner": 0.1},
- thinking_id="tid-1",
- chat_stream=SimpleNamespace(session_id="stream-1"),
- log_prefix="[test]",
- shutting_down=False,
- plugin_config={"enabled": True},
- )
-
- assert success is True
- assert reason == "runtime action executed"
- assert captured["method"] == "plugin.invoke_action"
- assert captured["plugin_id"] == plugin_id
- assert captured["component_name"] == action_name
- assert captured["args"]["stream_id"] == "stream-1"
- assert captured["args"]["chat_id"] == "stream-1"
- assert captured["args"]["reasoning"] == "当前适合使用这个动作"
- assert captured["args"]["target"] == "MaiBot"
- assert captured["args"]["action_data"] == {"target": "MaiBot"}
- finally:
- supervisor._remove_core_action_mirrors(plugin_id)
- core_component_registry.remove_action(action_name)
+ assert success is True
+ assert reason == "runtime action executed"
+ assert captured["method"] == "plugin.invoke_action"
+ assert captured["plugin_id"] == plugin_id
+ assert captured["component_name"] == action_name
+ assert captured["args"]["stream_id"] == "stream-1"
+ assert captured["args"]["chat_id"] == "stream-1"
+ assert captured["args"]["reasoning"] == "当前适合使用这个动作"
+ assert captured["args"]["target"] == "MaiBot"
+ assert captured["args"]["action_data"] == {"target": "MaiBot"}
-def test_clear_runner_state_removes_mirrored_runtime_actions() -> None:
- """清理 Runner 状态时应同步移除旧核心注册表中的镜像 Action。"""
- plugin_id = "runtime_action_bridge_cleanup_plugin"
- action_name = "runtime_action_bridge_cleanup_test"
- payload = _build_action_payload(plugin_id=plugin_id, action_name=action_name)
+@pytest.mark.asyncio
+async def test_core_component_registry_reads_runtime_command_and_executor(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """核心查询层应直接使用运行时命令匹配与执行闭包。"""
+
+ plugin_id = "runtime_command_bridge_plugin"
+ command_name = "runtime_command_bridge_test"
+ supervisor = PluginSupervisor(plugin_dirs=[])
+ captured: dict[str, Any] = {}
+
+ supervisor.component_registry.register_component(
+ name=command_name,
+ component_type="COMMAND",
+ plugin_id=plugin_id,
+ metadata={
+ "description": "测试命令",
+ "enabled": True,
+ "command_pattern": r"^/test(?:\s+.+)?$",
+ "aliases": ["/hello"],
+ "intercept_message_level": 1,
+ },
+ )
+ _install_runtime_manager(monkeypatch, supervisor, plugin_id, {"mode": "command"})
+
+ async def fake_invoke_plugin(
+ method: str,
+ plugin_id: str,
+ component_name: str,
+ args: dict[str, Any] | None = None,
+ timeout_ms: int = 30000,
+ ) -> Any:
+ """模拟命令 RPC 调用。"""
+
+ captured["method"] = method
+ captured["plugin_id"] = plugin_id
+ captured["component_name"] = component_name
+ captured["args"] = args or {}
+ captured["timeout_ms"] = timeout_ms
+ return SimpleNamespace(payload={"success": True, "result": (True, "command ok", True)})
+
+ monkeypatch.setattr(supervisor, "invoke_plugin", fake_invoke_plugin)
+
+ matched = component_query_service.find_command_by_text("/test hello")
+ assert matched is not None
+ command_executor, matched_groups, command_info = matched
+
+ assert matched_groups == {}
+ assert command_info.plugin_name == plugin_id
+ assert command_info.command_pattern == r"^/test(?:\s+.+)?$"
+
+ success, response_text, intercept = await command_executor(
+ message=SimpleNamespace(processed_plain_text="/test hello", session_id="stream-2"),
+ plugin_config={"mode": "command"},
+ matched_groups=matched_groups,
+ )
+
+ assert success is True
+ assert response_text == "command ok"
+ assert intercept is True
+ assert captured["method"] == "plugin.invoke_command"
+ assert captured["plugin_id"] == plugin_id
+ assert captured["component_name"] == command_name
+ assert captured["args"]["text"] == "/test hello"
+ assert captured["args"]["stream_id"] == "stream-2"
+ assert captured["args"]["plugin_config"] == {"mode": "command"}
+
+
+@pytest.mark.asyncio
+async def test_core_component_registry_reads_runtime_tools_and_executor(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """核心查询层应直接读取运行时 Tool,并返回 RPC 执行闭包。"""
+
+ plugin_id = "runtime_tool_bridge_plugin"
+ tool_name = "runtime_tool_bridge_test"
supervisor = PluginSupervisor(plugin_dirs=[])
- core_component_registry.remove_action(action_name)
+ supervisor.component_registry.register_component(
+ name=tool_name,
+ component_type="TOOL",
+ plugin_id=plugin_id,
+ metadata={
+ "description": "测试工具",
+ "enabled": True,
+ "parameters": [
+ {
+ "name": "query",
+ "param_type": "string",
+ "description": "查询词",
+ "required": True,
+ }
+ ],
+ },
+ )
+ _install_runtime_manager(monkeypatch, supervisor, plugin_id)
- supervisor._mirror_runtime_actions_to_core_registry(payload)
- assert core_component_registry.get_action_info(action_name) is not None
+ async def fake_invoke_plugin(
+ method: str,
+ plugin_id: str,
+ component_name: str,
+ args: dict[str, Any] | None = None,
+ timeout_ms: int = 30000,
+ ) -> Any:
+ """模拟工具 RPC 调用。"""
- supervisor._clear_runner_state()
+ del timeout_ms
+ assert method == "plugin.invoke_tool"
+ assert plugin_id == "runtime_tool_bridge_plugin"
+ assert component_name == "runtime_tool_bridge_test"
+ assert args == {"query": "MaiBot"}
+ return SimpleNamespace(payload={"success": True, "result": {"content": "tool ok"}})
- assert core_component_registry.get_action_info(action_name) is None
+ monkeypatch.setattr(supervisor, "invoke_plugin", fake_invoke_plugin)
+
+ tool_info = component_query_service.get_tool_info(tool_name)
+ assert isinstance(tool_info, ToolInfo)
+ assert tool_info.tool_description == "测试工具"
+ assert tool_name in component_query_service.get_llm_available_tools()
+
+ executor = component_query_service.get_tool_executor(tool_name)
+ assert executor is not None
+ assert await executor({"query": "MaiBot"}) == {"content": "tool ok"}
diff --git a/src/chat/brain_chat/brain_planner.py b/src/chat/brain_chat/brain_planner.py
index 12b103a0..709be8ee 100644
--- a/src/chat/brain_chat/brain_planner.py
+++ b/src/chat/brain_chat/brain_planner.py
@@ -1,30 +1,32 @@
+from datetime import datetime
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+
import json
-import time
-import traceback
import random
import re
-from typing import Dict, Optional, Tuple, List, TYPE_CHECKING
-from rich.traceback import install
-from datetime import datetime
-from json_repair import repair_json
+import time
+import traceback
+
+from json_repair import repair_json
+from rich.traceback import install
-from src.llm_models.utils_model import LLMRequest
-from src.config.config import global_config, model_config
-from src.common.logger import get_logger
from src.chat.logger.plan_reply_logger import PlanReplyLogger
+from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
+from src.chat.planner_actions.action_manager import ActionManager
+from src.chat.utils.utils import get_chat_type_and_target_info
from src.common.data_models.info_data_model import ActionPlannerInfo
+from src.common.logger import get_logger
from src.common.utils.utils_action import ActionUtils
+from src.config.config import global_config, model_config
+from src.core.types import ActionActivationType, ActionInfo, ComponentType
+from src.llm_models.utils_model import LLMRequest
+from src.plugin_runtime.component_query import component_query_service
from src.prompt.prompt_manager import prompt_manager
from src.services.message_service import (
build_readable_messages_with_id,
get_actions_by_timestamp_with_chat,
get_messages_before_time_in_chat,
)
-from src.chat.utils.utils import get_chat_type_and_target_info
-from src.chat.planner_actions.action_manager import ActionManager
-from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
-from src.core.types import ActionActivationType, ActionInfo, ComponentType
-from src.core.component_registry import component_registry
if TYPE_CHECKING:
from src.common.data_models.info_data_model import TargetPersonInfo
@@ -320,7 +322,7 @@ class BrainPlanner:
current_available_actions_dict = self.action_manager.get_using_actions()
# 获取完整的动作信息
- all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore
+ all_registered_actions: Dict[str, ActionInfo] = component_query_service.get_components_by_type( # type: ignore
ComponentType.ACTION
)
current_available_actions = {}
diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py
index 23e7de6e..025150fc 100644
--- a/src/chat/message_receive/bot.py
+++ b/src/chat/message_receive/bot.py
@@ -1,19 +1,19 @@
from contextlib import suppress
-import traceback
-import os
-
-from maim_message import MessageBase
from typing import Any, Dict, Optional
+import os
+import traceback
+from maim_message import MessageBase
+
+from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
from src.common.logger import get_logger
from src.common.utils.utils_message import MessageUtils
from src.common.utils.utils_session import SessionUtils
-from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
# from src.chat.brain_chat.PFC.pfc_manager import PFCManager
from src.core.announcement_manager import global_announcement_manager
-from src.core.component_registry import component_registry
+from src.plugin_runtime.component_query import component_query_service
from .message import SessionMessage
from .chat_manager import chat_manager
@@ -58,16 +58,22 @@ class ChatBot:
logger.error(f"创建PFC聊天失败: {e}")
logger.error(traceback.format_exc())
- async def _process_commands(self, message: SessionMessage):
- # sourcery skip: use-named-expression
- """使用新插件系统处理命令"""
+ async def _process_commands(self, message: SessionMessage) -> tuple[bool, Optional[str], bool]:
+ """使用统一组件注册表处理命令。
+
+ Args:
+ message: 当前待处理的会话消息。
+
+ Returns:
+ tuple[bool, Optional[str], bool]: ``(是否命中命令, 命令响应文本, 是否继续后续处理)``。
+ """
if not message.processed_plain_text:
return False, None, True # 没有文本内容,继续处理消息
try:
text = message.processed_plain_text
- # 使用核心组件注册表查找命令
- command_result = component_registry.find_command_by_text(text)
+ # 使用插件运行时统一查询服务查找命令
+ command_result = component_query_service.find_command_by_text(text)
if command_result:
command_executor, matched_groups, command_info = command_result
plugin_name = command_info.plugin_name
@@ -81,7 +87,7 @@ class ChatBot:
message.is_command = True
# 获取插件配置
- plugin_config = component_registry.get_plugin_config(plugin_name)
+ plugin_config = component_query_service.get_plugin_config(plugin_name)
try:
# 调用命令执行器
@@ -112,88 +118,32 @@ class ChatBot:
# 命令出错时,根据命令的拦截设置决定是否继续处理消息
return True, str(e), False # 出错时继续处理消息
- # 没有找到旧系统命令,尝试新版本插件运行时
- new_cmd_result = await self._process_new_runtime_command(message)
- return new_cmd_result if new_cmd_result is not None else (False, None, True)
+ return False, None, True
except Exception as e:
logger.error(f"处理命令时出错: {e}")
return False, None, True # 出错时继续处理消息
- async def _process_new_runtime_command(self, message: SessionMessage):
- """尝试在新版本插件运行时中查找并执行命令
-
- Returns:
- (found, response, continue_processing) 三元组,
- 或 None 表示新运行时中也未找到匹配命令。
- """
- from src.plugin_runtime.integration import get_plugin_runtime_manager
-
- prm = get_plugin_runtime_manager()
- if not prm.is_running:
- return None
-
- matched = prm.find_command_by_text(message.processed_plain_text)
- if matched is None:
- return None
-
- command_name = matched["name"]
- if message.session_id and command_name in global_announcement_manager.get_disabled_chat_commands(
- message.session_id
- ):
- logger.info(f"[新运行时] 用户禁用的命令,跳过处理: {matched['full_name']}")
- return False, None, True
-
- message.is_command = True
- logger.info(f"[新运行时] 匹配命令: {matched['full_name']}")
-
- try:
- resp = await prm.invoke_plugin(
- method="plugin.invoke_command",
- plugin_id=matched["plugin_id"],
- component_name=matched["name"],
- args={
- "text": message.processed_plain_text,
- "stream_id": message.session_id or "",
- "matched_groups": matched.get("matched_groups") or {},
- },
- timeout_ms=30000,
- )
-
- payload = resp.payload
- success = payload.get("success", False)
- cmd_result = payload.get("result")
-
- # 拦截位优先从命令返回值中获取(支持运行时动态决定),
- # 回退到组件 metadata 中的静态声明
- if isinstance(cmd_result, (list, tuple)) and len(cmd_result) >= 3:
- # 命令返回 (found, response_text, intercept_bool) 三元组
- response_text = cmd_result[1] if cmd_result[1] is not None else ""
- intercept = bool(cmd_result[2])
- else:
- response_text = cmd_result if cmd_result is not None else ""
- intercept = bool(matched["metadata"].get("intercept_message_level", 0))
-
- self._mark_command_message(message, int(intercept))
-
- if success:
- logger.info(f"[新运行时] 命令执行成功: {matched['full_name']}")
- else:
- logger.warning(f"[新运行时] 命令执行失败: {matched['full_name']} - {response_text}")
-
- return True, response_text, not intercept
-
- except Exception as e:
- logger.error(f"[新运行时] 执行命令 {matched['full_name']} 异常: {e}", exc_info=True)
- return True, str(e), True
-
@staticmethod
def _mark_command_message(message: SessionMessage, intercept_message_level: int) -> None:
+ """标记消息已经被命令链消费。
+
+ Args:
+ message: 待标记的会话消息。
+ intercept_message_level: 命令设置的拦截级别。
+ """
+
message.is_command = True
message.message_info.additional_config["intercept_message_level"] = intercept_message_level
@staticmethod
def _store_intercepted_command_message(message: SessionMessage) -> None:
+ """将被命令链拦截的消息写入数据库。
+
+ Args:
+ message: 已完成命令处理的会话消息。
+ """
+
MessageUtils.store_message_to_db(message)
async def _handle_command_processing_result(
diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py
index 167cdcab..8133ac18 100644
--- a/src/chat/planner_actions/action_manager.py
+++ b/src/chat/planner_actions/action_manager.py
@@ -3,8 +3,8 @@ from typing import Dict, Optional, Tuple
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.message_receive.message import SessionMessage
from src.common.logger import get_logger
-from src.core.component_registry import component_registry, ActionExecutor
from src.core.types import ActionInfo
+from src.plugin_runtime.component_query import ActionExecutor, component_query_service
logger = get_logger("action_manager")
@@ -28,7 +28,7 @@ class ActionManager:
"""
动作管理器,用于管理各种类型的动作
- 使用核心组件注册表的 executor-based 模式。
+ 使用插件运行时统一查询服务的 executor-based 模式。
"""
def __init__(self):
@@ -38,7 +38,7 @@ class ActionManager:
self._using_actions: Dict[str, ActionInfo] = {}
# 初始化时将默认动作加载到使用中的动作
- self._using_actions = component_registry.get_default_actions()
+ self._using_actions = component_query_service.get_default_actions()
# === 执行Action方法 ===
@@ -72,17 +72,17 @@ class ActionManager:
Optional[ActionHandle]: 执行句柄,如果动作未注册则返回 None
"""
try:
- executor = component_registry.get_action_executor(action_name)
+ executor = component_query_service.get_action_executor(action_name)
if not executor:
logger.warning(f"{log_prefix} 未找到Action组件: {action_name}")
return None
- info = component_registry.get_action_info(action_name)
+ info = component_query_service.get_action_info(action_name)
if not info:
logger.warning(f"{log_prefix} 未找到Action组件信息: {action_name}")
return None
- plugin_config = component_registry.get_plugin_config(info.plugin_name) or {}
+ plugin_config = component_query_service.get_plugin_config(info.plugin_name) or {}
handle = ActionHandle(
executor,
@@ -133,5 +133,5 @@ class ActionManager:
def restore_actions(self) -> None:
"""恢复到默认动作集"""
actions_to_restore = list(self._using_actions.keys())
- self._using_actions = component_registry.get_default_actions()
+ self._using_actions = component_query_service.get_default_actions()
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")
diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py
index 5184abcb..b21efa6b 100644
--- a/src/chat/planner_actions/planner.py
+++ b/src/chat/planner_actions/planner.py
@@ -1,33 +1,36 @@
+from collections import OrderedDict
+from datetime import datetime
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+import contextlib
import json
-import time
-import traceback
import random
import re
-import contextlib
-from typing import Dict, Optional, Tuple, List, TYPE_CHECKING, Union
-from collections import OrderedDict
-from rich.traceback import install
-from datetime import datetime
+import time
+import traceback
+
from json_repair import repair_json
-from src.llm_models.utils_model import LLMRequest
-from src.config.config import global_config, model_config
-from src.common.logger import get_logger
+from rich.traceback import install
+
from src.chat.logger.plan_reply_logger import PlanReplyLogger
+from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
+from src.chat.message_receive.message import SessionMessage
+from src.chat.planner_actions.action_manager import ActionManager
+from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
from src.common.data_models.info_data_model import ActionPlannerInfo
+from src.common.logger import get_logger
+from src.config.config import global_config, model_config
+from src.core.types import ActionActivationType, ActionInfo, ComponentType
+from src.llm_models.utils_model import LLMRequest
+from src.person_info.person_info import Person
+from src.plugin_runtime.component_query import component_query_service
from src.prompt.prompt_manager import prompt_manager
from src.services.message_service import (
build_readable_messages_with_id,
- replace_user_references,
get_messages_before_time_in_chat,
+ replace_user_references,
translate_pid_to_description,
)
-from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
-from src.chat.planner_actions.action_manager import ActionManager
-from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
-from src.chat.message_receive.message import SessionMessage
-from src.core.types import ActionActivationType, ActionInfo, ComponentType
-from src.core.component_registry import component_registry
-from src.person_info.person_info import Person
if TYPE_CHECKING:
from src.common.data_models.info_data_model import TargetPersonInfo
@@ -634,7 +637,7 @@ class ActionPlanner:
current_available_actions_dict = self.action_manager.get_using_actions()
# 获取完整的动作信息
- all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore
+ all_registered_actions: Dict[str, ActionInfo] = component_query_service.get_components_by_type( # type: ignore
ComponentType.ACTION
)
current_available_actions = {}
diff --git a/src/chat/tool_executor.py b/src/chat/tool_executor.py
index d449f7a1..aa99fce8 100644
--- a/src/chat/tool_executor.py
+++ b/src/chat/tool_executor.py
@@ -1,22 +1,20 @@
-"""
-工具执行器
+"""工具执行器。
独立的工具执行组件,可以直接输入聊天消息内容,
自动判断并执行相应的工具,返回结构化的工具执行结果。
-
-从 src.plugin_system.core.tool_use 迁移,使用新的核心组件注册表。
"""
+from typing import Any, Dict, List, Optional, Tuple
+
import hashlib
import time
-from typing import Any, Dict, List, Optional, Tuple
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.core.announcement_manager import global_announcement_manager
-from src.core.component_registry import component_registry
from src.llm_models.payload_content import ToolCall
from src.llm_models.utils_model import LLMRequest
+from src.plugin_runtime.component_query import component_query_service
from src.prompt.prompt_manager import prompt_manager
logger = get_logger("tool_use")
@@ -89,7 +87,7 @@ class ToolExecutor:
def _get_tool_definitions(self) -> List[Dict[str, Any]]:
"""获取 LLM 可用的工具定义列表"""
- all_tools = component_registry.get_llm_available_tools()
+ all_tools = component_query_service.get_llm_available_tools()
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
return [info.get_llm_definition() for name, info in all_tools.items() if name not in user_disabled_tools]
@@ -152,7 +150,7 @@ class ToolExecutor:
function_args = tool_call.args or {}
function_args["llm_called"] = True
- executor = component_registry.get_tool_executor(function_name)
+ executor = component_query_service.get_tool_executor(function_name)
if not executor:
logger.warning(f"未知工具名称: {function_name}")
return None
diff --git a/src/core/component_registry.py b/src/core/component_registry.py
deleted file mode 100644
index bb58682a..00000000
--- a/src/core/component_registry.py
+++ /dev/null
@@ -1,239 +0,0 @@
-"""
-核心组件注册表
-
-面向最终架构的组件管理:
-- Action:注册 ActionInfo + 执行器(本地 callable 或 IPC 路由)
-- Command:注册正则模式 + 执行器
-- Tool:注册工具定义 + 执行器
-
-不依赖任何插件基类,组件执行器是纯 async callable。
-"""
-
-import re
-from typing import Any, Awaitable, Callable, Dict, Optional, Pattern, Tuple
-
-from src.common.logger import get_logger
-from src.core.types import (
- ActionInfo,
- CommandInfo,
- ComponentInfo,
- ComponentType,
- ToolInfo,
-)
-
-logger = get_logger("component_registry")
-
-# 执行器类型
-ActionExecutor = Callable[..., Awaitable[Any]]
-CommandExecutor = Callable[..., Awaitable[Tuple[bool, Optional[str], bool]]]
-ToolExecutor = Callable[..., Awaitable[Any]]
-
-
-class ComponentRegistry:
- """核心组件注册表
-
- 管理 action、command、tool 三类组件。
- 每个组件由「元信息 + 执行器」构成,执行器是 async callable,
- 不需要继承任何基类。
- """
-
- def __init__(self):
- # Action 注册
- self._actions: Dict[str, ActionInfo] = {}
- self._action_executors: Dict[str, ActionExecutor] = {}
- self._default_actions: Dict[str, ActionInfo] = {}
-
- # Command 注册
- self._commands: Dict[str, CommandInfo] = {}
- self._command_executors: Dict[str, CommandExecutor] = {}
- self._command_patterns: Dict[Pattern, str] = {}
-
- # Tool 注册
- self._tools: Dict[str, ToolInfo] = {}
- self._tool_executors: Dict[str, ToolExecutor] = {}
- self._llm_available_tools: Dict[str, ToolInfo] = {}
-
- # 插件配置(plugin_name -> config dict)
- self._plugin_configs: Dict[str, dict] = {}
-
- logger.info("核心组件注册表初始化完成")
-
- # ========== Action ==========
-
- def register_action(
- self,
- info: ActionInfo,
- executor: ActionExecutor,
- ) -> bool:
- """注册 action
-
- Args:
- info: action 元信息
- executor: 执行器,async callable
- """
- name = info.name
- if name in self._actions:
- logger.warning(f"Action {name} 已存在,跳过注册")
- return False
-
- self._actions[name] = info
- self._action_executors[name] = executor
-
- if info.enabled:
- self._default_actions[name] = info
-
- logger.debug(f"注册 Action: {name}")
- return True
-
- def get_action_info(self, name: str) -> Optional[ActionInfo]:
- return self._actions.get(name)
-
- def get_action_executor(self, name: str) -> Optional[ActionExecutor]:
- return self._action_executors.get(name)
-
- def get_default_actions(self) -> Dict[str, ActionInfo]:
- return self._default_actions.copy()
-
- def get_all_actions(self) -> Dict[str, ActionInfo]:
- return self._actions.copy()
-
- def remove_action(self, name: str) -> bool:
- if name not in self._actions:
- return False
- del self._actions[name]
- self._action_executors.pop(name, None)
- self._default_actions.pop(name, None)
- logger.debug(f"移除 Action: {name}")
- return True
-
- # ========== Command ==========
-
- def register_command(
- self,
- info: CommandInfo,
- executor: CommandExecutor,
- ) -> bool:
- """注册 command"""
- name = info.name
- if name in self._commands:
- logger.warning(f"Command {name} 已存在,跳过注册")
- return False
-
- self._commands[name] = info
- self._command_executors[name] = executor
-
- if info.enabled and info.command_pattern:
- pattern = re.compile(info.command_pattern, re.IGNORECASE | re.DOTALL)
- self._command_patterns[pattern] = name
-
- logger.debug(f"注册 Command: {name}")
- return True
-
- def find_command_by_text(self, text: str) -> Optional[Tuple[CommandExecutor, dict, CommandInfo]]:
- """根据文本查找匹配的命令
-
- Returns:
- (executor, matched_groups, command_info) 或 None
- """
- candidates = [p for p in self._command_patterns if p.match(text)]
- if not candidates:
- return None
- if len(candidates) > 1:
- logger.warning(f"文本 '{text[:50]}' 匹配到多个命令模式,使用第一个")
- pattern = candidates[0]
- name = self._command_patterns[pattern]
- return (
- self._command_executors[name],
- pattern.match(text).groupdict(), # type: ignore
- self._commands[name],
- )
-
- def remove_command(self, name: str) -> bool:
- if name not in self._commands:
- return False
- del self._commands[name]
- self._command_executors.pop(name, None)
- self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != name}
- logger.debug(f"移除 Command: {name}")
- return True
-
- # ========== Tool ==========
-
- def register_tool(
- self,
- info: ToolInfo,
- executor: ToolExecutor,
- ) -> bool:
- """注册 tool"""
- name = info.name
- if name in self._tools:
- logger.warning(f"Tool {name} 已存在,跳过注册")
- return False
-
- self._tools[name] = info
- self._tool_executors[name] = executor
-
- if info.enabled:
- self._llm_available_tools[name] = info
-
- logger.debug(f"注册 Tool: {name}")
- return True
-
- def get_tool_info(self, name: str) -> Optional[ToolInfo]:
- return self._tools.get(name)
-
- def get_tool_executor(self, name: str) -> Optional[ToolExecutor]:
- return self._tool_executors.get(name)
-
- def get_llm_available_tools(self) -> Dict[str, ToolInfo]:
- return self._llm_available_tools.copy()
-
- def get_all_tools(self) -> Dict[str, ToolInfo]:
- return self._tools.copy()
-
- def remove_tool(self, name: str) -> bool:
- if name not in self._tools:
- return False
- del self._tools[name]
- self._tool_executors.pop(name, None)
- self._llm_available_tools.pop(name, None)
- logger.debug(f"移除 Tool: {name}")
- return True
-
- # ========== 通用查询 ==========
-
- def get_component_info(self, name: str, component_type: ComponentType) -> Optional[ComponentInfo]:
- """获取组件元信息"""
- match component_type:
- case ComponentType.ACTION:
- return self._actions.get(name)
- case ComponentType.COMMAND:
- return self._commands.get(name)
- case ComponentType.TOOL:
- return self._tools.get(name)
- case _:
- return None
-
- def get_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
- """获取某类型的所有组件"""
- match component_type:
- case ComponentType.ACTION:
- return dict(self._actions)
- case ComponentType.COMMAND:
- return dict(self._commands)
- case ComponentType.TOOL:
- return dict(self._tools)
- case _:
- return {}
-
- # ========== 插件配置 ==========
-
- def set_plugin_config(self, plugin_name: str, config: dict) -> None:
- self._plugin_configs[plugin_name] = config
-
- def get_plugin_config(self, plugin_name: str) -> Optional[dict]:
- return self._plugin_configs.get(plugin_name)
-
-
-# 全局单例
-component_registry = ComponentRegistry()
diff --git a/src/platform_io/manager.py b/src/platform_io/manager.py
index cb5996b4..be03e35d 100644
--- a/src/platform_io/manager.py
+++ b/src/platform_io/manager.py
@@ -178,12 +178,6 @@ class PlatformIOManager:
return self._receive_route_table
- @property
- def route_table(self) -> RouteTable:
- """兼容旧接口,返回发送路由表。"""
-
- return self._send_route_table
-
@property
def deduplicator(self) -> MessageDeduplicator:
"""返回管理器持有的入站去重器。
@@ -369,12 +363,6 @@ class PlatformIOManager:
self._validate_binding_against_driver(binding, driver)
self._receive_route_table.bind(binding)
- def bind_route(self, binding: RouteBinding) -> None:
- """兼容旧接口,默认同时绑定发送表和接收表。"""
-
- self.bind_send_route(binding)
- self.bind_receive_route(binding)
-
def unbind_send_route(self, route_key: RouteKey, driver_id: Optional[str] = None) -> None:
"""移除发送路由绑定。
@@ -395,12 +383,6 @@ class PlatformIOManager:
self._receive_route_table.unbind(route_key, driver_id)
- def unbind_route(self, route_key: RouteKey, driver_id: Optional[str] = None) -> None:
- """兼容旧接口,默认同时从发送表和接收表解绑。"""
-
- self.unbind_send_route(route_key, driver_id)
- self.unbind_receive_route(route_key, driver_id)
-
def resolve_drivers(self, route_key: RouteKey) -> List[PlatformIODriver]:
"""解析某个路由键当前命中的全部发送驱动。
@@ -430,12 +412,6 @@ class PlatformIOManager:
return []
return [fallback_driver]
- def resolve_driver(self, route_key: RouteKey) -> Optional[PlatformIODriver]:
- """兼容旧接口,返回首个命中的发送驱动。"""
-
- drivers = self.resolve_drivers(route_key)
- return drivers[0] if drivers else None
-
@staticmethod
def build_route_key_from_message(message: "SessionMessage") -> RouteKey:
"""根据 ``SessionMessage`` 构造路由键。
diff --git a/src/plugin_runtime/capabilities/core.py b/src/plugin_runtime/capabilities/core.py
index def5f03d..9bb1755b 100644
--- a/src/plugin_runtime/capabilities/core.py
+++ b/src/plugin_runtime/capabilities/core.py
@@ -238,14 +238,14 @@ class RuntimeCoreCapabilityMixin:
return {"success": False, "value": None, "error": str(e)}
async def _cap_config_get_plugin(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
- from src.core.component_registry import component_registry as core_registry
+ from src.plugin_runtime.component_query import component_query_service
plugin_name: str = args.get("plugin_name", plugin_id)
key: str = args.get("key", "")
default = args.get("default")
try:
- config = core_registry.get_plugin_config(plugin_name)
+ config = component_query_service.get_plugin_config(plugin_name)
if config is None:
return {"success": False, "value": default, "error": f"未找到插件 {plugin_name} 的配置"}
@@ -258,11 +258,11 @@ class RuntimeCoreCapabilityMixin:
return {"success": False, "value": default, "error": str(e)}
async def _cap_config_get_all(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
- from src.core.component_registry import component_registry as core_registry
+ from src.plugin_runtime.component_query import component_query_service
plugin_name: str = args.get("plugin_name", plugin_id)
try:
- config = core_registry.get_plugin_config(plugin_name)
+ config = component_query_service.get_plugin_config(plugin_name)
if config is None:
return {"success": True, "value": {}}
return {"success": True, "value": config}
diff --git a/src/plugin_runtime/capabilities/data.py b/src/plugin_runtime/capabilities/data.py
index c4ae0a56..fdf8d898 100644
--- a/src/plugin_runtime/capabilities/data.py
+++ b/src/plugin_runtime/capabilities/data.py
@@ -648,10 +648,10 @@ class RuntimeDataCapabilityMixin:
return {"success": False, "error": str(e)}
async def _cap_tool_get_definitions(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
- from src.core.component_registry import component_registry as core_registry
+ from src.plugin_runtime.component_query import component_query_service
try:
- tools = core_registry.get_llm_available_tools()
+ tools = component_query_service.get_llm_available_tools()
return {
"success": True,
"tools": [{"name": name, "definition": info.get_llm_definition()} for name, info in tools.items()],
diff --git a/src/plugin_runtime/component_query.py b/src/plugin_runtime/component_query.py
new file mode 100644
index 00000000..7d23d202
--- /dev/null
+++ b/src/plugin_runtime/component_query.py
@@ -0,0 +1,709 @@
+"""插件运行时统一组件查询服务。
+
+该模块统一从插件运行时的 Host ComponentRegistry 中聚合只读视图,
+供 HFC/PFC、Planner、ToolExecutor 和运行时能力层查询与调用。
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tuple
+
+from src.common.logger import get_logger
+from src.core.types import ActionActivationType, ActionInfo, CommandInfo, ComponentInfo, ComponentType, ToolInfo
+from src.llm_models.payload_content.tool_option import ToolParamType
+
+if TYPE_CHECKING:
+ from src.plugin_runtime.host.component_registry import ActionEntry, CommandEntry, ComponentEntry, ToolEntry
+ from src.plugin_runtime.host.supervisor import PluginSupervisor
+ from src.plugin_runtime.integration import PluginRuntimeManager
+
+logger = get_logger("plugin_runtime.component_query")
+
+ActionExecutor = Callable[..., Awaitable[Any]]
+CommandExecutor = Callable[..., Awaitable[Tuple[bool, Optional[str], bool]]]
+ToolExecutor = Callable[..., Awaitable[Any]]
+
+_HOST_COMPONENT_TYPE_MAP: Dict[ComponentType, str] = {
+ ComponentType.ACTION: "ACTION",
+ ComponentType.COMMAND: "COMMAND",
+ ComponentType.TOOL: "TOOL",
+}
+_TOOL_PARAM_TYPE_MAP: Dict[str, ToolParamType] = {
+ "string": ToolParamType.STRING,
+ "integer": ToolParamType.INTEGER,
+ "float": ToolParamType.FLOAT,
+ "boolean": ToolParamType.BOOLEAN,
+ "bool": ToolParamType.BOOLEAN,
+}
+
+
+class ComponentQueryService:
+ """插件运行时统一组件查询服务。
+
+ 该对象不维护独立状态,只读取插件系统中的注册结果。
+ 所有注册、删除、配置写入等写操作都被显式禁用。
+ """
+
+ @staticmethod
+ def _get_runtime_manager() -> "PluginRuntimeManager":
+ """获取插件运行时管理器单例。
+
+ Returns:
+ PluginRuntimeManager: 当前全局插件运行时管理器。
+ """
+
+ from src.plugin_runtime.integration import get_plugin_runtime_manager
+
+ return get_plugin_runtime_manager()
+
+ def _iter_supervisors(self) -> list["PluginSupervisor"]:
+ """获取当前所有活跃的插件运行时监督器。
+
+ Returns:
+ list[PluginSupervisor]: 当前运行中的监督器列表。
+ """
+
+ runtime_manager = self._get_runtime_manager()
+ return list(runtime_manager.supervisors)
+
+ def _iter_component_entries(
+ self,
+ component_type: ComponentType,
+ *,
+ enabled_only: bool = True,
+ ) -> list[tuple["PluginSupervisor", "ComponentEntry"]]:
+ """遍历指定类型的全部组件条目。
+
+ Args:
+ component_type: 目标组件类型。
+ enabled_only: 是否仅返回启用状态的组件。
+
+ Returns:
+ list[tuple[PluginSupervisor, ComponentEntry]]: ``(监督器, 组件条目)`` 列表。
+ """
+
+ host_component_type = _HOST_COMPONENT_TYPE_MAP.get(component_type)
+ if host_component_type is None:
+ return []
+
+ collected_entries: list[tuple["PluginSupervisor", "ComponentEntry"]] = []
+ for supervisor in self._iter_supervisors():
+ for component in supervisor.component_registry.get_components_by_type(
+ host_component_type,
+ enabled_only=enabled_only,
+ ):
+ collected_entries.append((supervisor, component))
+ return collected_entries
+
+ @staticmethod
+ def _coerce_action_activation_type(raw_value: Any) -> ActionActivationType:
+ """规范化动作激活类型。
+
+ Args:
+ raw_value: 原始激活类型值。
+
+ Returns:
+ ActionActivationType: 规范化后的激活类型枚举。
+ """
+
+ normalized_value = str(raw_value or "").strip().lower()
+ if normalized_value == ActionActivationType.NEVER.value:
+ return ActionActivationType.NEVER
+ if normalized_value == ActionActivationType.RANDOM.value:
+ return ActionActivationType.RANDOM
+ if normalized_value == ActionActivationType.KEYWORD.value:
+ return ActionActivationType.KEYWORD
+ return ActionActivationType.ALWAYS
+
+ @staticmethod
+ def _coerce_float(value: Any, default: float = 0.0) -> float:
+ """将任意值安全转换为浮点数。
+
+ Args:
+ value: 待转换的输入值。
+ default: 转换失败时返回的默认值。
+
+ Returns:
+ float: 转换后的浮点结果。
+ """
+
+ try:
+ return float(value)
+ except (TypeError, ValueError):
+ return default
+
+ @staticmethod
+ def _build_action_info(entry: "ActionEntry") -> ActionInfo:
+ """将运行时 Action 条目转换为核心动作信息。
+
+ Args:
+ entry: 插件运行时中的 Action 条目。
+
+ Returns:
+ ActionInfo: 供核心 Planner 使用的动作信息。
+ """
+
+ metadata = dict(entry.metadata)
+ raw_action_parameters = metadata.get("action_parameters")
+ action_parameters = (
+ {
+ str(param_name): str(param_description)
+ for param_name, param_description in raw_action_parameters.items()
+ }
+ if isinstance(raw_action_parameters, dict)
+ else {}
+ )
+ action_require = [
+ str(item)
+ for item in (metadata.get("action_require") or [])
+ if item is not None and str(item).strip()
+ ]
+ associated_types = [
+ str(item)
+ for item in (metadata.get("associated_types") or [])
+ if item is not None and str(item).strip()
+ ]
+ activation_keywords = [
+ str(item)
+ for item in (metadata.get("activation_keywords") or [])
+ if item is not None and str(item).strip()
+ ]
+
+ return ActionInfo(
+ name=entry.name,
+ component_type=ComponentType.ACTION,
+ description=str(metadata.get("description", "") or ""),
+ enabled=bool(entry.enabled),
+ plugin_name=entry.plugin_id,
+ metadata=metadata,
+ action_parameters=action_parameters,
+ action_require=action_require,
+ associated_types=associated_types,
+ activation_type=ComponentQueryService._coerce_action_activation_type(metadata.get("activation_type")),
+ random_activation_probability=ComponentQueryService._coerce_float(
+ metadata.get("activation_probability"),
+ 0.0,
+ ),
+ activation_keywords=activation_keywords,
+ parallel_action=bool(metadata.get("parallel_action", False)),
+ )
+
+ @staticmethod
+ def _build_command_info(entry: "CommandEntry") -> CommandInfo:
+ """将运行时 Command 条目转换为核心命令信息。
+
+ Args:
+ entry: 插件运行时中的 Command 条目。
+
+ Returns:
+ CommandInfo: 供核心命令链使用的命令信息。
+ """
+
+ metadata = dict(entry.metadata)
+ return CommandInfo(
+ name=entry.name,
+ component_type=ComponentType.COMMAND,
+ description=str(metadata.get("description", "") or ""),
+ enabled=bool(entry.enabled),
+ plugin_name=entry.plugin_id,
+ metadata=metadata,
+ command_pattern=str(metadata.get("command_pattern", "") or ""),
+ )
+
+ @staticmethod
+ def _coerce_tool_param_type(raw_value: Any) -> ToolParamType:
+ """规范化工具参数类型。
+
+ Args:
+ raw_value: 原始工具参数类型值。
+
+ Returns:
+ ToolParamType: 规范化后的工具参数类型。
+ """
+
+ normalized_value = str(raw_value or "").strip().lower()
+ return _TOOL_PARAM_TYPE_MAP.get(normalized_value, ToolParamType.STRING)
+
+ @staticmethod
+ def _build_tool_parameters(entry: "ToolEntry") -> list[tuple[str, ToolParamType, str, bool, list[str] | None]]:
+ """将运行时工具参数元数据转换为核心 ToolInfo 参数列表。
+
+ Args:
+ entry: 插件运行时中的 Tool 条目。
+
+ Returns:
+ list[tuple[str, ToolParamType, str, bool, list[str] | None]]: 转换后的参数列表。
+ """
+
+ structured_parameters = entry.parameters if isinstance(entry.parameters, list) else []
+ if not structured_parameters and isinstance(entry.parameters_raw, dict):
+ structured_parameters = [
+ {"name": key, **value}
+ for key, value in entry.parameters_raw.items()
+ if isinstance(value, dict)
+ ]
+
+ normalized_parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = []
+ for parameter in structured_parameters:
+ if not isinstance(parameter, dict):
+ continue
+
+ parameter_name = str(parameter.get("name", "") or "").strip()
+ if not parameter_name:
+ continue
+
+ enum_values = parameter.get("enum")
+ normalized_enum_values = (
+ [str(item) for item in enum_values if item is not None]
+ if isinstance(enum_values, list)
+ else None
+ )
+ normalized_parameters.append(
+ (
+ parameter_name,
+ ComponentQueryService._coerce_tool_param_type(parameter.get("param_type") or parameter.get("type")),
+ str(parameter.get("description", "") or ""),
+ bool(parameter.get("required", True)),
+ normalized_enum_values,
+ )
+ )
+ return normalized_parameters
+
+ @staticmethod
+ def _build_tool_info(entry: "ToolEntry") -> ToolInfo:
+ """将运行时 Tool 条目转换为核心工具信息。
+
+ Args:
+ entry: 插件运行时中的 Tool 条目。
+
+ Returns:
+ ToolInfo: 供 ToolExecutor 与能力层使用的工具信息。
+ """
+
+ return ToolInfo(
+ name=entry.name,
+ component_type=ComponentType.TOOL,
+ description=entry.description,
+ enabled=bool(entry.enabled),
+ plugin_name=entry.plugin_id,
+ metadata=dict(entry.metadata),
+ tool_parameters=ComponentQueryService._build_tool_parameters(entry),
+ tool_description=entry.description,
+ )
+
+ @staticmethod
+ def _log_duplicate_component(component_type: ComponentType, component_name: str) -> None:
+ """记录重复组件名称冲突。
+
+ Args:
+ component_type: 组件类型。
+ component_name: 发生冲突的组件名称。
+ """
+
+ logger.warning(f"检测到重复{component_type.value}名称 {component_name},将只保留首个匹配项")
+
+ def _get_unique_component_entry(
+ self,
+ component_type: ComponentType,
+ name: str,
+ ) -> Optional[tuple["PluginSupervisor", "ComponentEntry"]]:
+ """按组件短名解析唯一条目。
+
+ Args:
+ component_type: 目标组件类型。
+ name: 组件短名。
+
+ Returns:
+ Optional[tuple[PluginSupervisor, ComponentEntry]]: 唯一命中的组件条目。
+ """
+
+ matched_entries = [
+ (supervisor, entry)
+ for supervisor, entry in self._iter_component_entries(component_type)
+ if entry.name == name
+ ]
+ if not matched_entries:
+ return None
+ if len(matched_entries) > 1:
+ self._log_duplicate_component(component_type, name)
+ return matched_entries[0]
+
+ def _collect_unique_component_infos(
+ self,
+ component_type: ComponentType,
+ ) -> Dict[str, ComponentInfo]:
+ """收集某类组件的唯一信息视图。
+
+ Args:
+ component_type: 目标组件类型。
+
+ Returns:
+ Dict[str, ComponentInfo]: 组件名到核心组件信息的映射。
+ """
+
+ collected_components: Dict[str, ComponentInfo] = {}
+ for _supervisor, entry in self._iter_component_entries(component_type):
+ if entry.name in collected_components:
+ self._log_duplicate_component(component_type, entry.name)
+ continue
+
+ if component_type == ComponentType.ACTION:
+ collected_components[entry.name] = self._build_action_info(entry) # type: ignore[arg-type]
+ elif component_type == ComponentType.COMMAND:
+ collected_components[entry.name] = self._build_command_info(entry) # type: ignore[arg-type]
+ elif component_type == ComponentType.TOOL:
+ collected_components[entry.name] = self._build_tool_info(entry) # type: ignore[arg-type]
+ return collected_components
+
+ @staticmethod
+ def _extract_stream_id_from_action_kwargs(kwargs: Dict[str, Any]) -> str:
+ """从旧 ActionManager 参数中提取聊天流 ID。
+
+ Args:
+ kwargs: 旧动作执行器收到的关键字参数。
+
+ Returns:
+ str: 提取出的 ``stream_id``。
+ """
+
+ chat_stream = kwargs.get("chat_stream")
+ if chat_stream is not None:
+ try:
+ return str(chat_stream.session_id)
+ except AttributeError:
+ pass
+
+ return str(kwargs.get("stream_id", "") or "")
+
+ @staticmethod
+ def _build_action_executor(supervisor: "PluginSupervisor", plugin_id: str, component_name: str) -> ActionExecutor:
+ """构造动作执行 RPC 闭包。
+
+ Args:
+ supervisor: 负责该组件的监督器。
+ plugin_id: 插件 ID。
+ component_name: 组件名称。
+
+ Returns:
+ ActionExecutor: 兼容旧 Planner 的异步执行器。
+ """
+
+ async def _executor(**kwargs: Any) -> tuple[bool, str]:
+ """将核心动作调用桥接到插件运行时。
+
+ Args:
+ **kwargs: 旧 ActionManager 传入的上下文参数。
+
+ Returns:
+ tuple[bool, str]: ``(是否成功, 结果说明)``。
+ """
+
+ invoke_args: Dict[str, Any] = {}
+ action_data = kwargs.get("action_data")
+ if isinstance(action_data, dict):
+ invoke_args.update(action_data)
+
+ stream_id = ComponentQueryService._extract_stream_id_from_action_kwargs(kwargs)
+ invoke_args["action_data"] = action_data if isinstance(action_data, dict) else {}
+ invoke_args["stream_id"] = stream_id
+ invoke_args["chat_id"] = stream_id
+ invoke_args["reasoning"] = str(kwargs.get("action_reasoning", "") or "")
+
+ if (thinking_id := kwargs.get("thinking_id")) is not None:
+ invoke_args["thinking_id"] = str(thinking_id)
+ if isinstance(kwargs.get("cycle_timers"), dict):
+ invoke_args["cycle_timers"] = kwargs["cycle_timers"]
+ if isinstance(kwargs.get("plugin_config"), dict):
+ invoke_args["plugin_config"] = kwargs["plugin_config"]
+ if isinstance(kwargs.get("log_prefix"), str):
+ invoke_args["log_prefix"] = kwargs["log_prefix"]
+ if isinstance(kwargs.get("shutting_down"), bool):
+ invoke_args["shutting_down"] = kwargs["shutting_down"]
+
+ try:
+ response = await supervisor.invoke_plugin(
+ method="plugin.invoke_action",
+ plugin_id=plugin_id,
+ component_name=component_name,
+ args=invoke_args,
+ timeout_ms=30000,
+ )
+ except Exception as exc:
+ logger.error(f"运行时 Action {plugin_id}.{component_name} 执行失败: {exc}", exc_info=True)
+ return False, str(exc)
+
+ payload = response.payload if isinstance(response.payload, dict) else {}
+ success = bool(payload.get("success", False))
+ result = payload.get("result")
+ if isinstance(result, (list, tuple)):
+ if len(result) >= 2:
+ return bool(result[0]), "" if result[1] is None else str(result[1])
+ if len(result) == 1:
+ return bool(result[0]), ""
+ if success:
+ return True, "" if result is None else str(result)
+ return False, "" if result is None else str(result)
+
+ return _executor
+
+ @staticmethod
+ def _build_command_executor(
+ supervisor: "PluginSupervisor",
+ plugin_id: str,
+ component_name: str,
+ metadata: Dict[str, Any],
+ ) -> CommandExecutor:
+ """构造命令执行 RPC 闭包。
+
+ Args:
+ supervisor: 负责该组件的监督器。
+ plugin_id: 插件 ID。
+ component_name: 组件名称。
+ metadata: 命令组件元数据。
+
+ Returns:
+ CommandExecutor: 兼容旧消息命令链的执行器。
+ """
+
+ async def _executor(**kwargs: Any) -> tuple[bool, Optional[str], bool]:
+ """将核心命令调用桥接到插件运行时。
+
+ Args:
+ **kwargs: 命令执行上下文参数。
+
+ Returns:
+ tuple[bool, Optional[str], bool]: ``(是否成功, 返回文本, 是否拦截后续消息)``。
+ """
+
+ message = kwargs.get("message")
+ matched_groups = kwargs.get("matched_groups")
+ plugin_config = kwargs.get("plugin_config")
+ invoke_args: Dict[str, Any] = {
+ "text": str(getattr(message, "processed_plain_text", "") or ""),
+ "stream_id": str(getattr(message, "session_id", "") or ""),
+ "matched_groups": matched_groups if isinstance(matched_groups, dict) else {},
+ }
+ if isinstance(plugin_config, dict):
+ invoke_args["plugin_config"] = plugin_config
+
+ try:
+ response = await supervisor.invoke_plugin(
+ method="plugin.invoke_command",
+ plugin_id=plugin_id,
+ component_name=component_name,
+ args=invoke_args,
+ timeout_ms=30000,
+ )
+ except Exception as exc:
+ logger.error(f"运行时 Command {plugin_id}.{component_name} 执行失败: {exc}", exc_info=True)
+ return False, str(exc), True
+
+ payload = response.payload if isinstance(response.payload, dict) else {}
+ success = bool(payload.get("success", False))
+ result = payload.get("result")
+ intercept = bool(metadata.get("intercept_message_level", 0))
+ response_text: Optional[str]
+
+ if isinstance(result, (list, tuple)) and len(result) >= 3:
+ response_text = None if result[1] is None else str(result[1])
+ intercept = bool(result[2])
+ else:
+ response_text = None if result is None else str(result)
+
+ return success, response_text, intercept
+
+ return _executor
+
+ @staticmethod
+ def _build_tool_executor(supervisor: "PluginSupervisor", plugin_id: str, component_name: str) -> ToolExecutor:
+ """构造工具执行 RPC 闭包。
+
+ Args:
+ supervisor: 负责该组件的监督器。
+ plugin_id: 插件 ID。
+ component_name: 组件名称。
+
+ Returns:
+ ToolExecutor: 兼容旧 ToolExecutor 的异步执行器。
+ """
+
+ async def _executor(function_args: Dict[str, Any]) -> Any:
+ """将核心工具调用桥接到插件运行时。
+
+ Args:
+ function_args: 工具调用参数。
+
+ Returns:
+ Any: 插件工具返回结果;若结果不是字典,则会包装为 ``{"content": ...}``。
+ """
+
+ try:
+ response = await supervisor.invoke_plugin(
+ method="plugin.invoke_tool",
+ plugin_id=plugin_id,
+ component_name=component_name,
+ args=function_args,
+ timeout_ms=30000,
+ )
+ except Exception as exc:
+ logger.error(f"运行时 Tool {plugin_id}.{component_name} 执行失败: {exc}", exc_info=True)
+ return {"content": f"工具 {component_name} 执行失败: {exc}"}
+
+ payload = response.payload if isinstance(response.payload, dict) else {}
+ result = payload.get("result")
+ if isinstance(result, dict):
+ return result
+ return {"content": "" if result is None else str(result)}
+
+ return _executor
+
+ def get_action_info(self, name: str) -> Optional[ActionInfo]:
+ """获取指定动作的信息。
+
+ Args:
+ name: 动作名称。
+
+ Returns:
+ Optional[ActionInfo]: 匹配到的动作信息。
+ """
+
+ matched_entry = self._get_unique_component_entry(ComponentType.ACTION, name)
+ if matched_entry is None:
+ return None
+ _supervisor, entry = matched_entry
+ return self._build_action_info(entry) # type: ignore[arg-type]
+
+ def get_action_executor(self, name: str) -> Optional[ActionExecutor]:
+ """获取指定动作的执行器。
+
+ Args:
+ name: 动作名称。
+
+ Returns:
+ Optional[ActionExecutor]: 运行时 RPC 执行闭包。
+ """
+
+ matched_entry = self._get_unique_component_entry(ComponentType.ACTION, name)
+ if matched_entry is None:
+ return None
+ supervisor, entry = matched_entry
+ return self._build_action_executor(supervisor, entry.plugin_id, entry.name)
+
+ def get_default_actions(self) -> Dict[str, ActionInfo]:
+ """获取当前默认启用的动作集合。
+
+ Returns:
+ Dict[str, ActionInfo]: 动作名到动作信息的映射。
+ """
+
+ action_infos = self._collect_unique_component_infos(ComponentType.ACTION)
+ return {name: info for name, info in action_infos.items() if isinstance(info, ActionInfo) and info.enabled}
+
+ def find_command_by_text(self, text: str) -> Optional[Tuple[CommandExecutor, dict, CommandInfo]]:
+ """根据文本查找匹配的命令。
+
+ Args:
+ text: 待匹配的文本内容。
+
+ Returns:
+ Optional[Tuple[CommandExecutor, dict, CommandInfo]]: 匹配结果。
+ """
+
+ for supervisor in self._iter_supervisors():
+ match_result = supervisor.component_registry.find_command_by_text(text)
+ if match_result is None:
+ continue
+
+ entry, matched_groups = match_result
+ command_info = self._build_command_info(entry) # type: ignore[arg-type]
+ command_executor = self._build_command_executor(
+ supervisor,
+ entry.plugin_id,
+ entry.name,
+ dict(entry.metadata),
+ )
+ return command_executor, matched_groups, command_info
+ return None
+
+ def get_tool_info(self, name: str) -> Optional[ToolInfo]:
+ """获取指定工具的信息。
+
+ Args:
+ name: 工具名称。
+
+ Returns:
+ Optional[ToolInfo]: 匹配到的工具信息。
+ """
+
+ matched_entry = self._get_unique_component_entry(ComponentType.TOOL, name)
+ if matched_entry is None:
+ return None
+ _supervisor, entry = matched_entry
+ return self._build_tool_info(entry) # type: ignore[arg-type]
+
+ def get_tool_executor(self, name: str) -> Optional[ToolExecutor]:
+ """获取指定工具的执行器。
+
+ Args:
+ name: 工具名称。
+
+ Returns:
+ Optional[ToolExecutor]: 运行时 RPC 执行闭包。
+ """
+
+ matched_entry = self._get_unique_component_entry(ComponentType.TOOL, name)
+ if matched_entry is None:
+ return None
+ supervisor, entry = matched_entry
+ return self._build_tool_executor(supervisor, entry.plugin_id, entry.name)
+
+ def get_llm_available_tools(self) -> Dict[str, ToolInfo]:
+ """获取当前可供 LLM 选择的工具集合。
+
+ Returns:
+ Dict[str, ToolInfo]: 工具名到工具信息的映射。
+ """
+
+ tool_infos = self._collect_unique_component_infos(ComponentType.TOOL)
+ return {name: info for name, info in tool_infos.items() if isinstance(info, ToolInfo) and info.enabled}
+
+ def get_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
+ """获取某类组件的全部信息。
+
+ Args:
+ component_type: 组件类型。
+
+ Returns:
+ Dict[str, ComponentInfo]: 组件名到组件信息的映射。
+ """
+
+ return self._collect_unique_component_infos(component_type)
+
+ def get_plugin_config(self, plugin_name: str) -> Optional[dict]:
+ """读取指定插件的配置文件内容。
+
+ Args:
+ plugin_name: 插件名称。
+
+ Returns:
+ Optional[dict]: 读取成功时返回配置字典;未找到时返回 ``None``。
+ """
+
+ runtime_manager = self._get_runtime_manager()
+ try:
+ supervisor = runtime_manager._get_supervisor_for_plugin(plugin_name)
+ except RuntimeError as exc:
+ logger.error(f"读取插件配置失败: {exc}")
+ return None
+
+ if supervisor is None:
+ return None
+
+ try:
+ return runtime_manager._load_plugin_config_for_supervisor(supervisor, plugin_name)
+ except Exception as exc:
+ logger.error(f"读取插件 {plugin_name} 配置失败: {exc}", exc_info=True)
+ return None
+
+
+component_query_service = ComponentQueryService()
diff --git a/src/plugin_runtime/host/component_registry.py b/src/plugin_runtime/host/component_registry.py
index 08b0ea3b..1c073490 100644
--- a/src/plugin_runtime/host/component_registry.py
+++ b/src/plugin_runtime/host/component_registry.py
@@ -31,12 +31,12 @@ class ComponentTypes(str, Enum):
class StatusDict(TypedDict):
total: int
- ACTION: int
- COMMAND: int
- TOOL: int
- EVENT_HANDLER: int
- HOOK_HANDLER: int
- MESSAGE_GATEWAY: int
+ action: int
+ command: int
+ tool: int
+ event_handler: int
+ hook_handler: int
+ message_gateway: int
plugins: int
@@ -185,6 +185,23 @@ class ComponentRegistry:
# 按插件索引
self._by_plugin: Dict[str, List[ComponentEntry]] = {}
+ @staticmethod
+ def _normalize_component_type(component_type: str) -> ComponentTypes:
+ """规范化组件类型输入。
+
+ Args:
+ component_type: 原始组件类型字符串。
+
+ Returns:
+ ComponentTypes: 规范化后的组件类型枚举。
+
+ Raises:
+ ValueError: 当组件类型不受支持时抛出。
+ """
+
+ normalized_value = str(component_type or "").strip().upper()
+ return ComponentTypes(normalized_value)
+
def clear(self) -> None:
"""清空全部组件注册状态。"""
self._components.clear()
@@ -205,18 +222,19 @@ class ComponentRegistry:
success (bool): 是否成功注册(失败原因通常是组件类型无效)
"""
try:
- if component_type == ComponentTypes.ACTION:
- comp = ActionEntry(name, component_type, plugin_id, metadata)
- elif component_type == ComponentTypes.COMMAND:
- comp = CommandEntry(name, component_type, plugin_id, metadata)
- elif component_type == ComponentTypes.TOOL:
- comp = ToolEntry(name, component_type, plugin_id, metadata)
- elif component_type == ComponentTypes.EVENT_HANDLER:
- comp = EventHandlerEntry(name, component_type, plugin_id, metadata)
- elif component_type == ComponentTypes.HOOK_HANDLER:
- comp = HookHandlerEntry(name, component_type, plugin_id, metadata)
- elif component_type == ComponentTypes.MESSAGE_GATEWAY:
- comp = MessageGatewayEntry(name, component_type, plugin_id, metadata)
+ normalized_type = self._normalize_component_type(component_type)
+ if normalized_type == ComponentTypes.ACTION:
+ comp = ActionEntry(name, normalized_type.value, plugin_id, metadata)
+ elif normalized_type == ComponentTypes.COMMAND:
+ comp = CommandEntry(name, normalized_type.value, plugin_id, metadata)
+ elif normalized_type == ComponentTypes.TOOL:
+ comp = ToolEntry(name, normalized_type.value, plugin_id, metadata)
+ elif normalized_type == ComponentTypes.EVENT_HANDLER:
+ comp = EventHandlerEntry(name, normalized_type.value, plugin_id, metadata)
+ elif normalized_type == ComponentTypes.HOOK_HANDLER:
+ comp = HookHandlerEntry(name, normalized_type.value, plugin_id, metadata)
+ elif normalized_type == ComponentTypes.MESSAGE_GATEWAY:
+ comp = MessageGatewayEntry(name, normalized_type.value, plugin_id, metadata)
else:
raise ValueError(f"组件类型 {component_type} 不存在")
except ValueError:
@@ -304,6 +322,20 @@ class ComponentRegistry:
comp.enabled = enabled
return True
+ def set_component_enabled(self, full_name: str, enabled: bool, session_id: Optional[str] = None) -> bool:
+ """设置指定组件的启用状态。
+
+ Args:
+ full_name: 组件全名。
+ enabled: 目标启用状态。
+ session_id: 可选的会话 ID,仅对该会话生效。
+
+ Returns:
+ bool: 是否设置成功。
+ """
+
+ return self.toggle_component_status(full_name, enabled, session_id=session_id)
+
def toggle_plugin_status(self, plugin_id: str, enabled: bool, session_id: Optional[str] = None) -> int:
"""批量启用或禁用某插件的所有组件。
@@ -348,7 +380,7 @@ class ComponentRegistry:
components (List[ComponentEntry]): 组件条目列表
"""
try:
- comp_type = ComponentTypes(component_type)
+ comp_type = self._normalize_component_type(component_type)
except ValueError:
logger.error(f"组件类型 {component_type} 不存在")
raise
@@ -536,6 +568,6 @@ class ComponentRegistry:
"""
stats: StatusDict = {"total": len(self._components)} # type: ignore
for comp_type, type_dict in self._by_type.items():
- stats[comp_type.value] = len(type_dict)
+ stats[comp_type.value.lower()] = len(type_dict)
stats["plugins"] = len(self._by_plugin)
return stats
diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py
index 3588934e..4a9885f8 100644
--- a/src/plugin_runtime/host/supervisor.py
+++ b/src/plugin_runtime/host/supervisor.py
@@ -9,8 +9,6 @@ import sys
from src.common.logger import get_logger
from src.config.config import global_config
-from src.core.component_registry import component_registry as core_component_registry
-from src.core.types import ActionActivationType, ActionInfo, ComponentType as CoreComponentType
from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, get_platform_io_manager
from src.platform_io.drivers import PluginPlatformDriver
from src.platform_io.route_key_factory import RouteKeyFactory
@@ -107,7 +105,6 @@ class PluginRunnerSupervisor:
self._runner_process: Optional[asyncio.subprocess.Process] = None
self._registered_plugins: Dict[str, RegisterPluginPayload] = {}
self._message_gateway_states: Dict[str, Dict[str, _MessageGatewayRuntimeState]] = {}
- self._mirrored_core_actions: Dict[str, List[str]] = {}
self._runner_ready_events: asyncio.Event = asyncio.Event()
self._runner_ready_payloads: RunnerReadyPayload = RunnerReadyPayload()
self._health_task: Optional[asyncio.Task[None]] = None
@@ -510,7 +507,6 @@ class PluginRunnerSupervisor:
except Exception as exc:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
- self._remove_core_action_mirrors(payload.plugin_id)
self._component_registry.remove_components_by_plugin(payload.plugin_id)
await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id)
@@ -520,7 +516,6 @@ class PluginRunnerSupervisor:
)
self._registered_plugins[payload.plugin_id] = payload
self._message_gateway_states[payload.plugin_id] = {}
- self._mirror_runtime_actions_to_core_registry(payload)
return envelope.make_response(
payload={
@@ -550,7 +545,6 @@ class PluginRunnerSupervisor:
removed_components = self._component_registry.remove_components_by_plugin(payload.plugin_id)
self._authorization.revoke_permission_token(payload.plugin_id)
removed_registration = self._registered_plugins.pop(payload.plugin_id, None) is not None
- self._remove_core_action_mirrors(payload.plugin_id)
await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id)
self._message_gateway_states.pop(payload.plugin_id, None)
@@ -564,236 +558,6 @@ class PluginRunnerSupervisor:
}
)
- @staticmethod
- def _coerce_action_activation_type(raw_value: Any) -> ActionActivationType:
- """将运行时 Action 激活类型转换为旧核心枚举。
-
- Args:
- raw_value: 插件运行时声明中的激活类型值。
-
- Returns:
- ActionActivationType: 可供旧 Planner 使用的激活类型枚举。
- """
- normalized_value = str(raw_value or ActionActivationType.ALWAYS.value).strip().lower()
- try:
- return ActionActivationType(normalized_value)
- except ValueError:
- return ActionActivationType.ALWAYS
-
- @staticmethod
- def _coerce_float(value: Any, default: float = 0.0) -> float:
- """将任意输入尽量转换为浮点数。
-
- Args:
- value: 待转换的值。
- default: 转换失败时使用的默认值。
-
- Returns:
- float: 转换结果。
- """
- try:
- return float(value)
- except (TypeError, ValueError):
- return default
-
- @staticmethod
- def _build_core_action_info(plugin_id: str, component_name: str, metadata: Dict[str, Any]) -> ActionInfo:
- """将运行时 Action 元数据映射为旧核心 ActionInfo。
-
- Args:
- plugin_id: 插件 ID。
- component_name: 组件名称。
- metadata: 运行时组件元数据。
-
- Returns:
- ActionInfo: 兼容旧 Planner 的动作定义。
- """
- activation_keywords = [
- str(item)
- for item in (metadata.get("activation_keywords") or [])
- if item is not None and str(item).strip()
- ]
- action_require = [
- str(item)
- for item in (metadata.get("action_require") or [])
- if item is not None and str(item).strip()
- ]
- associated_types = [
- str(item)
- for item in (metadata.get("associated_types") or [])
- if item is not None and str(item).strip()
- ]
- raw_action_parameters = metadata.get("action_parameters") or {}
- action_parameters = {
- str(param_name): str(param_description)
- for param_name, param_description in raw_action_parameters.items()
- } if isinstance(raw_action_parameters, dict) else {}
-
- return ActionInfo(
- name=component_name,
- component_type=CoreComponentType.ACTION,
- description=str(metadata.get("description", "") or ""),
- enabled=bool(metadata.get("enabled", True)),
- plugin_name=plugin_id,
- metadata=dict(metadata),
- action_parameters=action_parameters,
- action_require=action_require,
- associated_types=associated_types,
- activation_type=PluginRunnerSupervisor._coerce_action_activation_type(metadata.get("activation_type")),
- random_activation_probability=PluginRunnerSupervisor._coerce_float(
- metadata.get("activation_probability"),
- 0.0,
- ),
- activation_keywords=activation_keywords,
- parallel_action=bool(metadata.get("parallel_action", False)),
- )
-
- @staticmethod
- def _extract_stream_id_from_action_kwargs(kwargs: Dict[str, Any]) -> str:
- """从旧 ActionManager 传入参数中提取聊天流 ID。
-
- Args:
- kwargs: 旧动作执行器收到的关键字参数。
-
- Returns:
- str: 可用于新运行时 Action 的 ``stream_id``。
- """
- chat_stream = kwargs.get("chat_stream")
- if chat_stream is not None:
- try:
- return str(chat_stream.session_id)
- except AttributeError:
- pass
-
- raw_stream_id = kwargs.get("stream_id", "")
- return str(raw_stream_id or "")
-
- def _build_runtime_action_executor(
- self,
- plugin_id: str,
- component_name: str,
- ) -> Any:
- """构造一个转发到 plugin runtime 的旧核心 Action 执行器。
-
- Args:
- plugin_id: 目标插件 ID。
- component_name: 目标 Action 组件名称。
-
- Returns:
- Callable[..., Coroutine[Any, Any, tuple[bool, str]]]: 兼容旧 ActionManager 的执行器。
- """
-
- async def _executor(**kwargs: Any) -> tuple[bool, str]:
- """将旧 Planner 的动作调用桥接到 plugin runtime。
-
- Args:
- **kwargs: 旧 ActionManager 传入的运行时上下文参数。
-
- Returns:
- tuple[bool, str]: ``(是否成功, 动作说明)``。
- """
- invoke_args: Dict[str, Any] = {}
- action_data = kwargs.get("action_data")
- if isinstance(action_data, dict):
- invoke_args.update(action_data)
-
- stream_id = self._extract_stream_id_from_action_kwargs(kwargs)
- invoke_args["action_data"] = action_data if isinstance(action_data, dict) else {}
- invoke_args["stream_id"] = stream_id
- invoke_args["chat_id"] = stream_id
- invoke_args["reasoning"] = str(kwargs.get("action_reasoning", "") or "")
-
- thinking_id = kwargs.get("thinking_id")
- if thinking_id is not None:
- invoke_args["thinking_id"] = str(thinking_id)
-
- cycle_timers = kwargs.get("cycle_timers")
- if isinstance(cycle_timers, dict):
- invoke_args["cycle_timers"] = cycle_timers
-
- plugin_config = kwargs.get("plugin_config")
- if isinstance(plugin_config, dict):
- invoke_args["plugin_config"] = plugin_config
-
- log_prefix = kwargs.get("log_prefix")
- if isinstance(log_prefix, str):
- invoke_args["log_prefix"] = log_prefix
-
- shutting_down = kwargs.get("shutting_down")
- if isinstance(shutting_down, bool):
- invoke_args["shutting_down"] = shutting_down
-
- try:
- response = await self.invoke_plugin(
- method="plugin.invoke_action",
- plugin_id=plugin_id,
- component_name=component_name,
- args=invoke_args,
- timeout_ms=30000,
- )
- except Exception as exc:
- logger.error(f"运行时 Action {plugin_id}.{component_name} 执行失败: {exc}", exc_info=True)
- return False, str(exc)
-
- payload = response.payload if isinstance(response.payload, dict) else {}
- success = bool(payload.get("success", False))
- result = payload.get("result")
-
- if isinstance(result, (list, tuple)):
- if len(result) >= 2:
- return bool(result[0]), "" if result[1] is None else str(result[1])
- if len(result) == 1:
- return bool(result[0]), ""
-
- if success:
- return True, "" if result is None else str(result)
- return False, "" if result is None else str(result)
-
- return _executor
-
- def _mirror_runtime_actions_to_core_registry(self, payload: RegisterPluginPayload) -> None:
- """将 plugin runtime 中声明的 Action 镜像到旧核心注册表。
-
- Args:
- payload: 当前插件的注册载荷。
- """
- mirrored_action_names: List[str] = []
-
- for component in payload.components:
- if str(component.component_type).upper() != CoreComponentType.ACTION.name:
- continue
-
- action_info = self._build_core_action_info(
- plugin_id=payload.plugin_id,
- component_name=component.name,
- metadata=component.metadata,
- )
- action_executor = self._build_runtime_action_executor(
- plugin_id=payload.plugin_id,
- component_name=component.name,
- )
- registered = core_component_registry.register_action(action_info, action_executor)
- if not registered:
- logger.warning(
- f"运行时 Action {payload.plugin_id}.{component.name} 无法镜像到旧核心注册表,"
- "可能与现有 Action 重名"
- )
- continue
- mirrored_action_names.append(component.name)
-
- if mirrored_action_names:
- self._mirrored_core_actions[payload.plugin_id] = mirrored_action_names
-
- def _remove_core_action_mirrors(self, plugin_id: str) -> None:
- """移除某个插件镜像到旧核心注册表的所有 Action。
-
- Args:
- plugin_id: 目标插件 ID。
- """
- mirrored_action_names = self._mirrored_core_actions.pop(plugin_id, [])
- for action_name in mirrored_action_names:
- core_component_registry.remove_action(action_name)
-
@staticmethod
def _build_message_gateway_driver_id(plugin_id: str, gateway_name: str) -> str:
"""构造消息网关驱动 ID。
@@ -1407,8 +1171,6 @@ class PluginRunnerSupervisor:
def _clear_runner_state(self) -> None:
"""清理当前 Runner 对应的 Host 侧注册状态。"""
- for plugin_id in list(self._mirrored_core_actions.keys()):
- self._remove_core_action_mirrors(plugin_id)
self._authorization.clear()
self._component_registry.clear()
self._registered_plugins.clear()
diff --git a/src/services/send_service.py b/src/services/send_service.py
index 7903cdeb..54f2a9de 100644
--- a/src/services/send_service.py
+++ b/src/services/send_service.py
@@ -8,10 +8,9 @@
3. 具体走插件链还是 legacy 旧链,由 Platform IO 内部统一决策。
"""
+from copy import deepcopy
from typing import Any, Dict, List, Optional
-from maim_message import Seg
-
import asyncio
import base64
import hashlib
@@ -28,6 +27,7 @@ from src.common.data_models.message_component_data_model import (
AtComponent,
DictComponent,
EmojiComponent,
+ ForwardNodeComponent,
ImageComponent,
MessageSequence,
ReplyComponent,
@@ -72,88 +72,163 @@ def _inherit_platform_io_route_metadata(target_stream: BotChatSession) -> Dict[s
if normalized_value:
inherited_metadata[key] = value
- if target_stream.group_id:
- normalized_group_id = str(target_stream.group_id).strip()
- if normalized_group_id:
- inherited_metadata["platform_io_target_group_id"] = normalized_group_id
+ if target_stream.group_id and (normalized_group_id := str(target_stream.group_id).strip()):
+ inherited_metadata["platform_io_target_group_id"] = normalized_group_id
- if target_stream.user_id:
- normalized_user_id = str(target_stream.user_id).strip()
- if normalized_user_id:
- inherited_metadata["platform_io_target_user_id"] = normalized_user_id
+ if target_stream.user_id and (normalized_user_id := str(target_stream.user_id).strip()):
+ inherited_metadata["platform_io_target_user_id"] = normalized_user_id
return inherited_metadata
-def _build_component_from_seg(message_segment: Seg) -> StandardMessageComponents:
- """将单个消息段转换为内部消息组件。
+def _build_binary_component_from_base64(component_type: str, raw_data: str) -> StandardMessageComponents:
+ """根据 Base64 数据构造二进制消息组件。
Args:
- message_segment: 待转换的消息段。
+ component_type: 组件类型名称。
+ raw_data: Base64 编码后的二进制数据。
Returns:
StandardMessageComponents: 转换后的内部消息组件。
+
+ Raises:
+ ValueError: 当组件类型不受支持时抛出。
"""
- segment_type = str(message_segment.type or "").strip().lower()
- segment_data = message_segment.data
+ binary_data = base64.b64decode(raw_data)
+ binary_hash = hashlib.sha256(binary_data).hexdigest()
- if segment_type == "text":
- return TextComponent(text=str(segment_data or ""))
-
- if segment_type == "image":
- image_binary = base64.b64decode(str(segment_data or ""))
- return ImageComponent(
- binary_hash=hashlib.sha256(image_binary).hexdigest(),
- binary_data=image_binary,
- )
-
- if segment_type == "emoji":
- emoji_binary = base64.b64decode(str(segment_data or ""))
- return EmojiComponent(
- binary_hash=hashlib.sha256(emoji_binary).hexdigest(),
- binary_data=emoji_binary,
- )
-
- if segment_type == "voice":
- voice_binary = base64.b64decode(str(segment_data or ""))
- return VoiceComponent(
- binary_hash=hashlib.sha256(voice_binary).hexdigest(),
- binary_data=voice_binary,
- )
-
- if segment_type == "at":
- return AtComponent(target_user_id=str(segment_data or ""))
-
- if segment_type == "reply":
- return ReplyComponent(target_message_id=str(segment_data or ""))
-
- if segment_type == "dict" and isinstance(segment_data, dict):
- return DictComponent(data=segment_data)
-
- return DictComponent(data={"type": segment_type, "data": segment_data})
+ if component_type == "image":
+ return ImageComponent(binary_hash=binary_hash, binary_data=binary_data)
+ if component_type == "emoji":
+ return EmojiComponent(binary_hash=binary_hash, binary_data=binary_data)
+ if component_type == "voice":
+ return VoiceComponent(binary_hash=binary_hash, binary_data=binary_data)
+ raise ValueError(f"不支持的二进制组件类型: {component_type}")
-def _build_message_sequence_from_seg(message_segment: Seg) -> MessageSequence:
- """将消息段转换为内部消息组件序列。
+def _build_message_sequence_from_custom_message(
+ message_type: str,
+ content: str | Dict[str, Any],
+) -> MessageSequence:
+ """根据自定义消息类型构造内部消息组件序列。
Args:
- message_segment: 待转换的消息段。
+ message_type: 自定义消息类型。
+ content: 自定义消息内容。
Returns:
MessageSequence: 转换后的消息组件序列。
"""
- if str(message_segment.type or "").strip().lower() == "seglist":
- raw_segments = message_segment.data
- if not isinstance(raw_segments, list):
- raise ValueError("seglist 类型的消息段数据必须是列表")
- components = [
- _build_component_from_seg(item)
- for item in raw_segments
- if isinstance(item, Seg)
- ]
- return MessageSequence(components=components)
+ normalized_type = message_type.strip().lower()
- return MessageSequence(components=[_build_component_from_seg(message_segment)])
+ if normalized_type == "text":
+ return MessageSequence(components=[TextComponent(text=str(content))])
+
+ if normalized_type in {"image", "emoji", "voice"}:
+ return MessageSequence(
+ components=[_build_binary_component_from_base64(normalized_type, str(content))]
+ )
+
+ if normalized_type == "at":
+ return MessageSequence(components=[AtComponent(target_user_id=str(content))])
+
+ if normalized_type == "reply":
+ return MessageSequence(components=[ReplyComponent(target_message_id=str(content))])
+
+ if normalized_type == "dict" and isinstance(content, dict):
+ return MessageSequence(components=[DictComponent(data=deepcopy(content))])
+
+ return MessageSequence(
+ components=[
+ DictComponent(
+ data={
+ "type": normalized_type,
+ "data": deepcopy(content),
+ }
+ )
+ ]
+ )
+
+
+def _clone_message_sequence(message_sequence: MessageSequence) -> MessageSequence:
+ """复制消息组件序列,避免原对象被发送流程修改。
+
+ Args:
+ message_sequence: 原始消息组件序列。
+
+ Returns:
+ MessageSequence: 深拷贝后的消息组件序列。
+ """
+ return deepcopy(message_sequence)
+
+
+def _detect_outbound_message_flags(message_sequence: MessageSequence) -> Dict[str, bool]:
+ """根据消息组件序列推断出站消息标记。
+
+ Args:
+ message_sequence: 待发送的消息组件序列。
+
+ Returns:
+ Dict[str, bool]: 包含 ``is_emoji``、``is_picture``、``is_command`` 的标记字典。
+ """
+ if len(message_sequence.components) != 1:
+ return {
+ "is_emoji": False,
+ "is_picture": False,
+ "is_command": False,
+ }
+
+ component = message_sequence.components[0]
+ is_command = False
+ if isinstance(component, DictComponent) and isinstance(component.data, dict):
+ is_command = str(component.data.get("type") or "").strip().lower() == "command"
+
+ return {
+ "is_emoji": isinstance(component, EmojiComponent),
+ "is_picture": isinstance(component, ImageComponent),
+ "is_command": is_command,
+ }
+
+
+def _describe_message_sequence(message_sequence: MessageSequence) -> str:
+ """生成消息组件序列的简短描述文本。
+
+ Args:
+ message_sequence: 待描述的消息组件序列。
+
+ Returns:
+ str: 适用于日志的简短类型描述。
+ """
+ if len(message_sequence.components) != 1:
+ return "message_sequence"
+
+ component = message_sequence.components[0]
+ if isinstance(component, DictComponent) and isinstance(component.data, dict):
+ custom_type = str(component.data.get("type") or "").strip()
+ return custom_type or "dict"
+
+ if isinstance(component, TextComponent):
+ return component.format_name
+
+ if isinstance(component, ImageComponent):
+ return component.format_name
+
+ if isinstance(component, EmojiComponent):
+ return component.format_name
+
+ if isinstance(component, VoiceComponent):
+ return component.format_name
+
+ if isinstance(component, AtComponent):
+ return component.format_name
+
+ if isinstance(component, ReplyComponent):
+ return component.format_name
+
+ if isinstance(component, ForwardNodeComponent):
+ return component.format_name
+
+ return "unknown"
def _build_processed_plain_text(message: SessionMessage) -> str:
@@ -204,7 +279,7 @@ def _build_processed_plain_text(message: SessionMessage) -> str:
def _build_outbound_session_message(
- message_segment: Seg,
+ message_sequence: MessageSequence,
stream_id: str,
display_message: str = "",
reply_message: Optional[MaiMessage] = None,
@@ -213,7 +288,7 @@ def _build_outbound_session_message(
"""根据目标会话构建待发送的内部消息对象。
Args:
- message_segment: 待发送的消息段。
+ message_sequence: 待发送的消息组件序列。
stream_id: 目标会话 ID。
display_message: 用于界面展示的文本内容。
reply_message: 被回复的锚点消息。
@@ -268,13 +343,14 @@ def _build_outbound_session_message(
group_info=group_info,
additional_config=additional_config,
)
- outbound_message.raw_message = _build_message_sequence_from_seg(message_segment)
+ outbound_message.raw_message = _clone_message_sequence(message_sequence)
outbound_message.session_id = target_stream.session_id
outbound_message.display_message = display_message
outbound_message.reply_to = anchor_message.message_id if anchor_message is not None else None
- outbound_message.is_emoji = message_segment.type == "emoji"
- outbound_message.is_picture = message_segment.type == "image"
- outbound_message.is_command = message_segment.type == "command"
+ message_flags = _detect_outbound_message_flags(outbound_message.raw_message)
+ outbound_message.is_emoji = message_flags["is_emoji"]
+ outbound_message.is_picture = message_flags["is_picture"]
+ outbound_message.is_command = message_flags["is_command"]
outbound_message.initialized = True
return outbound_message
@@ -467,7 +543,7 @@ async def send_session_message(
async def _send_to_target(
- message_segment: Seg,
+ message_sequence: MessageSequence,
stream_id: str,
display_message: str = "",
typing: bool = False,
@@ -480,7 +556,7 @@ async def _send_to_target(
"""向指定目标构建并发送消息。
Args:
- message_segment: 待发送的消息段。
+ message_sequence: 待发送的消息组件序列。
stream_id: 目标会话 ID。
display_message: 用于界面展示的文本内容。
typing: 是否显示输入中状态。
@@ -499,10 +575,10 @@ async def _send_to_target(
return False
if show_log:
- logger.debug(f"[SendService] 发送{message_segment.type}消息到 {stream_id}")
+ logger.debug(f"[SendService] 发送{_describe_message_sequence(message_sequence)}消息到 {stream_id}")
outbound_message = _build_outbound_session_message(
- message_segment=message_segment,
+ message_sequence=message_sequence,
stream_id=stream_id,
display_message=display_message,
reply_message=reply_message,
@@ -555,7 +631,7 @@ async def text_to_stream(
bool: 发送成功时返回 ``True``。
"""
return await _send_to_target(
- message_segment=Seg(type="text", data=text),
+ message_sequence=MessageSequence(components=[TextComponent(text=text)]),
stream_id=stream_id,
display_message="",
typing=typing,
@@ -586,7 +662,7 @@ async def emoji_to_stream(
bool: 发送成功时返回 ``True``。
"""
return await _send_to_target(
- message_segment=Seg(type="emoji", data=emoji_base64),
+ message_sequence=_build_message_sequence_from_custom_message("emoji", emoji_base64),
stream_id=stream_id,
display_message="",
typing=False,
@@ -616,7 +692,7 @@ async def image_to_stream(
bool: 发送成功时返回 ``True``。
"""
return await _send_to_target(
- message_segment=Seg(type="image", data=image_base64),
+ message_sequence=_build_message_sequence_from_custom_message("image", image_base64),
stream_id=stream_id,
display_message="",
typing=False,
@@ -654,7 +730,7 @@ async def custom_to_stream(
bool: 发送成功时返回 ``True``。
"""
return await _send_to_target(
- message_segment=Seg(type=message_type, data=content), # type: ignore[arg-type]
+ message_sequence=_build_message_sequence_from_custom_message(message_type, content),
stream_id=stream_id,
display_message=display_message,
typing=typing,
@@ -688,28 +764,15 @@ async def custom_reply_set_to_stream(
show_log: 是否输出发送日志。
Returns:
- bool: 全部组件发送成功时返回 ``True``。
+ bool: 发送成功时返回 ``True``。
"""
- success = True
- for component in reply_set.components:
- if isinstance(component, DictComponent):
- message_seg = Seg(type="dict", data=component.data) # type: ignore[arg-type]
- else:
- message_seg = await component.to_seg()
-
- status = await _send_to_target(
- message_segment=message_seg,
- stream_id=stream_id,
- display_message=display_message,
- typing=typing,
- reply_message=reply_message,
- set_reply=set_reply,
- storage_message=storage_message,
- show_log=show_log,
- )
- if not status:
- success = False
- logger.error(f"[SendService] 发送消息组件失败,组件类型:{type(component).__name__}")
- set_reply = False
-
- return success
+ return await _send_to_target(
+ message_sequence=reply_set,
+ stream_id=stream_id,
+ display_message=display_message,
+ typing=typing,
+ reply_message=reply_message,
+ set_reply=set_reply,
+ storage_message=storage_message,
+ show_log=show_log,
+ )
From 9dea6b0e6fdeae1be119eaeb0a1449ac64d10900 Mon Sep 17 00:00:00 2001
From: DrSmoothl <1787882683@qq.com>
Date: Mon, 23 Mar 2026 17:18:05 +0800
Subject: [PATCH 33/45] feat: implement dedicated API registry and enhance API
handling capabilities
- Added APIEntry and APIRegistry classes for managing plugin APIs.
- Updated PluginRunnerSupervisor to include API registry and methods for invoking APIs.
- Enhanced PluginRuntimeManager to support API registration and invocation.
- Created tests for API registration, invocation, and visibility between plugins.
- Refactored component handling to distinguish between runtime components and APIs.
---
pytests/test_plugin_runtime.py | 9 +-
pytests/test_plugin_runtime_api.py | 294 +++++++++++++++
src/plugin_runtime/capabilities/components.py | 344 +++++++++++++++++-
src/plugin_runtime/capabilities/registry.py | 4 +
src/plugin_runtime/host/api_registry.py | 290 +++++++++++++++
src/plugin_runtime/host/supervisor.py | 81 ++++-
src/plugin_runtime/runner/runner_main.py | 1 +
7 files changed, 1012 insertions(+), 11 deletions(-)
create mode 100644 pytests/test_plugin_runtime_api.py
create mode 100644 src/plugin_runtime/host/api_registry.py
diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py
index 20cceb82..9dfc34d8 100644
--- a/pytests/test_plugin_runtime.py
+++ b/pytests/test_plugin_runtime.py
@@ -2152,8 +2152,11 @@ class TestIntegration:
self.supervisors = [FakeSupervisor("plugin_a"), FakeSupervisor("plugin_b")]
monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager())
+ manager = integration_module.PluginRuntimeManager()
+ manager._builtin_supervisor = FakeSupervisor("plugin_a")
+ manager._third_party_supervisor = FakeSupervisor("plugin_b")
- result = await integration_module.PluginRuntimeManager._cap_component_enable(
+ result = await manager._cap_component_enable(
"plugin_a",
"component.enable",
{"name": "shared", "component_type": "tool", "scope": "global", "stream_id": ""},
@@ -2182,8 +2185,10 @@ class TestIntegration:
self.supervisors = [FakeSupervisor()]
monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager())
+ manager = integration_module.PluginRuntimeManager()
+ manager._builtin_supervisor = FakeSupervisor()
- result = await integration_module.PluginRuntimeManager._cap_component_disable(
+ result = await manager._cap_component_disable(
"plugin_a",
"component.disable",
{"name": "plugin_a.handler", "component_type": "tool", "scope": "stream", "stream_id": "s1"},
diff --git a/pytests/test_plugin_runtime_api.py b/pytests/test_plugin_runtime_api.py
new file mode 100644
index 00000000..fca7736a
--- /dev/null
+++ b/pytests/test_plugin_runtime_api.py
@@ -0,0 +1,294 @@
+"""插件 API 注册与调用测试。"""
+
+from types import SimpleNamespace
+from typing import Any, Dict, List
+
+import pytest
+
+from src.plugin_runtime.integration import PluginRuntimeManager
+from src.plugin_runtime.host.supervisor import PluginSupervisor
+from src.plugin_runtime.protocol.envelope import (
+ ComponentDeclaration,
+ Envelope,
+ MessageType,
+ RegisterPluginPayload,
+ UnregisterPluginPayload,
+)
+
+
+def _build_manager(*supervisors: PluginSupervisor) -> PluginRuntimeManager:
+ """构造一个最小可用的插件运行时管理器。
+
+ Args:
+ *supervisors: 需要挂载的监督器列表。
+
+ Returns:
+ PluginRuntimeManager: 已注入监督器的运行时管理器。
+ """
+
+ manager = PluginRuntimeManager()
+ if supervisors:
+ manager._builtin_supervisor = supervisors[0]
+ if len(supervisors) > 1:
+ manager._third_party_supervisor = supervisors[1]
+ return manager
+
+
+async def _register_plugin(
+ supervisor: PluginSupervisor,
+ plugin_id: str,
+ components: List[Dict[str, Any]],
+) -> Envelope:
+ """通过 Supervisor 注册测试插件。
+
+ Args:
+ supervisor: 目标监督器。
+ plugin_id: 测试插件 ID。
+ components: 组件声明列表。
+
+ Returns:
+ Envelope: 注册响应信封。
+ """
+
+ payload = RegisterPluginPayload(
+ plugin_id=plugin_id,
+ plugin_version="1.0.0",
+ components=[
+ ComponentDeclaration(
+ name=str(component.get("name", "") or ""),
+ component_type=str(component.get("component_type", "") or ""),
+ plugin_id=plugin_id,
+ metadata=component.get("metadata", {}) if isinstance(component.get("metadata"), dict) else {},
+ )
+ for component in components
+ ],
+ )
+ return await supervisor._handle_register_plugin(
+ Envelope(
+ request_id=1,
+ message_type=MessageType.REQUEST,
+ method="plugin.register_components",
+ plugin_id=plugin_id,
+ payload=payload.model_dump(),
+ )
+ )
+
+
+async def _unregister_plugin(supervisor: PluginSupervisor, plugin_id: str) -> Envelope:
+ """通过 Supervisor 注销测试插件。
+
+ Args:
+ supervisor: 目标监督器。
+ plugin_id: 测试插件 ID。
+
+ Returns:
+ Envelope: 注销响应信封。
+ """
+
+ payload = UnregisterPluginPayload(plugin_id=plugin_id, reason="test")
+ return await supervisor._handle_unregister_plugin(
+ Envelope(
+ request_id=2,
+ message_type=MessageType.REQUEST,
+ method="plugin.unregister",
+ plugin_id=plugin_id,
+ payload=payload.model_dump(),
+ )
+ )
+
+
+@pytest.mark.asyncio
+async def test_register_plugin_syncs_dedicated_api_registry() -> None:
+ """插件注册时应将 API 同步到独立注册表,而不是通用组件表。"""
+
+ supervisor = PluginSupervisor(plugin_dirs=[])
+ response = await _register_plugin(
+ supervisor,
+ "provider",
+ [
+ {
+ "name": "render_html",
+ "component_type": "API",
+ "metadata": {
+ "description": "渲染 HTML",
+ "version": "1",
+ "public": True,
+ },
+ }
+ ],
+ )
+
+ assert response.payload["accepted"] is True
+ assert response.payload["registered_components"] == 0
+ assert response.payload["registered_apis"] == 1
+ assert supervisor.api_registry.get_api("provider", "render_html") is not None
+ assert supervisor.component_registry.get_component("provider.render_html") is None
+
+ unregister_response = await _unregister_plugin(supervisor, "provider")
+ assert unregister_response.payload["removed_apis"] == 1
+ assert supervisor.api_registry.get_api("provider", "render_html") is None
+
+
+@pytest.mark.asyncio
+async def test_api_call_allows_public_api_between_plugins(monkeypatch: pytest.MonkeyPatch) -> None:
+ """公开 API 应允许其他插件通过 Host 转发调用。"""
+
+ provider_supervisor = PluginSupervisor(plugin_dirs=[])
+ consumer_supervisor = PluginSupervisor(plugin_dirs=[])
+ await _register_plugin(
+ provider_supervisor,
+ "provider",
+ [
+ {
+ "name": "render_html",
+ "component_type": "API",
+ "metadata": {
+ "description": "渲染 HTML",
+ "version": "1",
+ "public": True,
+ },
+ }
+ ],
+ )
+ await _register_plugin(consumer_supervisor, "consumer", [])
+
+ captured: Dict[str, Any] = {}
+
+ async def fake_invoke_api(
+ plugin_id: str,
+ component_name: str,
+ args: Dict[str, Any] | None = None,
+ timeout_ms: int = 30000,
+ ) -> Any:
+ """模拟 API RPC 调用。"""
+
+ captured["plugin_id"] = plugin_id
+ captured["component_name"] = component_name
+ captured["args"] = args or {}
+ captured["timeout_ms"] = timeout_ms
+ return SimpleNamespace(error=None, payload={"success": True, "result": {"image": "ok"}})
+
+ monkeypatch.setattr(provider_supervisor, "invoke_api", fake_invoke_api)
+
+ manager = _build_manager(provider_supervisor, consumer_supervisor)
+ result = await manager._cap_api_call(
+ "consumer",
+ "api.call",
+ {
+ "api_name": "provider.render_html",
+ "version": "1",
+ "args": {"html": "Hello
"},
+ },
+ )
+
+ assert result == {"success": True, "result": {"image": "ok"}}
+ assert captured["plugin_id"] == "provider"
+ assert captured["component_name"] == "render_html"
+ assert captured["args"] == {"html": "Hello
"}
+
+
+@pytest.mark.asyncio
+async def test_api_call_rejects_private_api_between_plugins() -> None:
+ """未公开的 API 默认不允许跨插件调用。"""
+
+ provider_supervisor = PluginSupervisor(plugin_dirs=[])
+ consumer_supervisor = PluginSupervisor(plugin_dirs=[])
+ await _register_plugin(
+ provider_supervisor,
+ "provider",
+ [
+ {
+ "name": "secret_api",
+ "component_type": "API",
+ "metadata": {
+ "description": "私有 API",
+ "version": "1",
+ "public": False,
+ },
+ }
+ ],
+ )
+ await _register_plugin(consumer_supervisor, "consumer", [])
+
+ manager = _build_manager(provider_supervisor, consumer_supervisor)
+ result = await manager._cap_api_call(
+ "consumer",
+ "api.call",
+ {
+ "api_name": "provider.secret_api",
+ "args": {},
+ },
+ )
+
+ assert result["success"] is False
+ assert "未公开" in str(result["error"])
+
+
+@pytest.mark.asyncio
+async def test_api_list_and_component_toggle_use_dedicated_registry() -> None:
+ """API 列表与组件启停应直接作用于独立 API 注册表。"""
+
+ provider_supervisor = PluginSupervisor(plugin_dirs=[])
+ consumer_supervisor = PluginSupervisor(plugin_dirs=[])
+ await _register_plugin(
+ provider_supervisor,
+ "provider",
+ [
+ {
+ "name": "public_api",
+ "component_type": "API",
+ "metadata": {"version": "1", "public": True},
+ },
+ {
+ "name": "private_api",
+ "component_type": "API",
+ "metadata": {"version": "1", "public": False},
+ },
+ ],
+ )
+ await _register_plugin(
+ consumer_supervisor,
+ "consumer",
+ [
+ {
+ "name": "self_private_api",
+ "component_type": "API",
+ "metadata": {"version": "1", "public": False},
+ }
+ ],
+ )
+
+ manager = _build_manager(provider_supervisor, consumer_supervisor)
+ list_result = await manager._cap_api_list("consumer", "api.list", {})
+
+ assert list_result["success"] is True
+ api_names = {(item["plugin_id"], item["name"]) for item in list_result["apis"]}
+ assert ("provider", "public_api") in api_names
+ assert ("provider", "private_api") not in api_names
+ assert ("consumer", "self_private_api") in api_names
+
+ disable_result = await manager._cap_component_disable(
+ "consumer",
+ "component.disable",
+ {
+ "name": "provider.public_api",
+ "component_type": "API",
+ "scope": "global",
+ "stream_id": "",
+ },
+ )
+ assert disable_result["success"] is True
+ assert provider_supervisor.api_registry.get_api("provider", "public_api", enabled_only=True) is None
+
+ enable_result = await manager._cap_component_enable(
+ "consumer",
+ "component.enable",
+ {
+ "name": "provider.public_api",
+ "component_type": "API",
+ "scope": "global",
+ "stream_id": "",
+ },
+ )
+ assert enable_result["success"] is True
+ assert provider_supervisor.api_registry.get_api("provider", "public_api", enabled_only=True) is not None
diff --git a/src/plugin_runtime/capabilities/components.py b/src/plugin_runtime/capabilities/components.py
index 4223525f..2eede108 100644
--- a/src/plugin_runtime/capabilities/components.py
+++ b/src/plugin_runtime/capabilities/components.py
@@ -6,7 +6,8 @@ from src.common.logger import get_logger
logger = get_logger("plugin_runtime.integration")
if TYPE_CHECKING:
- from src.plugin_runtime.host.component_registry import RegisteredComponent
+ from src.plugin_runtime.host.api_registry import APIEntry
+ from src.plugin_runtime.host.component_registry import ComponentEntry
from src.plugin_runtime.host.supervisor import PluginSupervisor
@@ -18,7 +19,7 @@ class _RuntimeComponentManagerProtocol(Protocol):
def _resolve_component_toggle_target(
self, name: str, component_type: str
- ) -> tuple[Optional["RegisteredComponent"], Optional[str]]: ...
+ ) -> tuple[Optional["ComponentEntry"], Optional[str]]: ...
def _find_duplicate_plugin_ids(self, plugin_dirs: List[Path]) -> Dict[str, List[Path]]: ...
@@ -26,6 +27,203 @@ class _RuntimeComponentManagerProtocol(Protocol):
class RuntimeComponentCapabilityMixin:
+ @staticmethod
+ def _normalize_component_type(component_type: str) -> str:
+ """规范化组件类型名称。
+
+ Args:
+ component_type: 原始组件类型。
+
+ Returns:
+ str: 统一转为大写后的组件类型名。
+ """
+
+ return str(component_type or "").strip().upper()
+
+ @classmethod
+ def _is_api_component_type(cls, component_type: str) -> bool:
+ """判断组件类型是否为 API。
+
+ Args:
+ component_type: 原始组件类型。
+
+ Returns:
+ bool: 是否为 API 组件类型。
+ """
+
+ return cls._normalize_component_type(component_type) == "API"
+
+ @staticmethod
+ def _serialize_api_entry(entry: "APIEntry") -> Dict[str, Any]:
+ """将 API 组件条目序列化为能力返回值。
+
+ Args:
+ entry: API 组件条目。
+
+ Returns:
+ Dict[str, Any]: 适合通过能力层返回给插件的 API 元信息。
+ """
+
+ return {
+ "name": entry.name,
+ "full_name": entry.full_name,
+ "plugin_id": entry.plugin_id,
+ "description": entry.description,
+ "version": entry.version,
+ "public": entry.public,
+ "enabled": entry.enabled,
+ "metadata": dict(entry.metadata),
+ }
+
+ @classmethod
+ def _serialize_api_component_entry(cls, entry: "APIEntry") -> Dict[str, Any]:
+ """将 API 条目序列化为通用组件视图。
+
+ Args:
+ entry: API 组件条目。
+
+ Returns:
+ Dict[str, Any]: 适合 ``component.get_all_plugins`` 返回的组件结构。
+ """
+
+ serialized_entry = cls._serialize_api_entry(entry)
+ return {
+ "name": serialized_entry["name"],
+ "full_name": serialized_entry["full_name"],
+ "type": "API",
+ "enabled": serialized_entry["enabled"],
+ "metadata": serialized_entry["metadata"],
+ }
+
+ @staticmethod
+ def _is_api_visible_to_plugin(entry: "APIEntry", caller_plugin_id: str) -> bool:
+ """判断某个 API 是否对调用方可见。
+
+ Args:
+ entry: 目标 API 组件条目。
+ caller_plugin_id: 调用方插件 ID。
+
+ Returns:
+ bool: 是否允许当前插件可见并调用。
+ """
+
+ return entry.plugin_id == caller_plugin_id or entry.public
+
+ def _resolve_api_target(
+ self: _RuntimeComponentManagerProtocol,
+ caller_plugin_id: str,
+ api_name: str,
+ version: str = "",
+ ) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]:
+ """解析 API 名称到唯一可调用的目标组件。
+
+ Args:
+ caller_plugin_id: 调用方插件 ID。
+ api_name: API 名称,支持 ``plugin_id.api_name`` 或唯一短名。
+ version: 可选的 API 版本。
+
+ Returns:
+ tuple[Optional[PluginSupervisor], Optional[APIEntry], Optional[str]]:
+ 解析成功时返回 ``(监督器, API 条目, None)``,失败时返回错误信息。
+ """
+
+ normalized_api_name = str(api_name or "").strip()
+ normalized_version = str(version or "").strip()
+ if not normalized_api_name:
+ return None, None, "缺少必要参数 api_name"
+
+ if "." in normalized_api_name:
+ target_plugin_id, target_api_name = normalized_api_name.split(".", 1)
+ try:
+ supervisor = self._get_supervisor_for_plugin(target_plugin_id)
+ except RuntimeError as exc:
+ return None, None, str(exc)
+
+ if supervisor is None:
+ return None, None, f"未找到 API 提供方插件: {target_plugin_id}"
+
+ entry = supervisor.api_registry.get_api(
+ plugin_id=target_plugin_id,
+ name=target_api_name,
+ enabled_only=True,
+ )
+ if entry is None:
+ return None, None, f"未找到 API: {normalized_api_name}"
+ if normalized_version and entry.version != normalized_version:
+ return None, None, f"未找到版本为 {normalized_version} 的 API: {normalized_api_name}"
+ if not self._is_api_visible_to_plugin(entry, caller_plugin_id):
+ return None, None, f"API {normalized_api_name} 未公开,禁止跨插件调用"
+ return supervisor, entry, None
+
+ visible_matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
+ hidden_match_exists = False
+ for supervisor in self.supervisors:
+ for entry in supervisor.api_registry.get_apis(name=normalized_api_name, enabled_only=True):
+ if normalized_version and entry.version != normalized_version:
+ continue
+ if self._is_api_visible_to_plugin(entry, caller_plugin_id):
+ visible_matches.append((supervisor, entry))
+ else:
+ hidden_match_exists = True
+
+ if len(visible_matches) == 1:
+ return visible_matches[0][0], visible_matches[0][1], None
+ if len(visible_matches) > 1:
+ return None, None, f"API 名称不唯一: {normalized_api_name},请使用 plugin_id.api_name"
+ if hidden_match_exists:
+ return None, None, f"API {normalized_api_name} 未公开,禁止跨插件调用"
+ if normalized_version:
+ return None, None, f"未找到版本为 {normalized_version} 的 API: {normalized_api_name}"
+ return None, None, f"未找到 API: {normalized_api_name}"
+
+ def _resolve_api_toggle_target(
+ self: _RuntimeComponentManagerProtocol,
+ name: str,
+ ) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]:
+ """解析需要启用或禁用的 API 组件。
+
+ Args:
+ name: API 名称,支持 ``plugin_id.api_name`` 或唯一短名。
+
+ Returns:
+ tuple[Optional[PluginSupervisor], Optional[APIEntry], Optional[str]]:
+ 解析成功时返回 ``(监督器, API 条目, None)``,失败时返回错误信息。
+ """
+
+ normalized_name = str(name or "").strip()
+ if not normalized_name:
+ return None, None, "缺少必要参数 name"
+
+ if "." in normalized_name:
+ plugin_id, api_name = normalized_name.split(".", 1)
+ try:
+ supervisor = self._get_supervisor_for_plugin(plugin_id)
+ except RuntimeError as exc:
+ return None, None, str(exc)
+
+ if supervisor is None:
+ return None, None, f"未找到 API 提供方插件: {plugin_id}"
+
+ entry = supervisor.api_registry.get_api(
+ plugin_id=plugin_id,
+ name=api_name,
+ enabled_only=False,
+ )
+ if entry is None:
+ return None, None, f"未找到 API: {normalized_name}"
+ return supervisor, entry, None
+
+ matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
+ for supervisor in self.supervisors:
+ for entry in supervisor.api_registry.get_apis(name=normalized_name, enabled_only=False):
+ matches.append((supervisor, entry))
+
+ if len(matches) == 1:
+ return matches[0][0], matches[0][1], None
+ if len(matches) > 1:
+ return None, None, f"API 名称不唯一: {normalized_name},请使用 plugin_id.api_name"
+ return None, None, f"未找到 API: {normalized_name}"
+
async def _cap_component_get_all_plugins(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
) -> Any:
@@ -46,6 +244,10 @@ class RuntimeComponentCapabilityMixin:
}
for component in comps
]
+ components_list.extend(
+ self._serialize_api_component_entry(entry)
+ for entry in sv.api_registry.get_apis(plugin_id=pid, enabled_only=False)
+ )
result[pid] = {
"name": pid,
"version": reg.plugin_version,
@@ -96,24 +298,28 @@ class RuntimeComponentCapabilityMixin:
def _resolve_component_toggle_target(
self: _RuntimeComponentManagerProtocol, name: str, component_type: str
- ) -> tuple[Optional["RegisteredComponent"], Optional[str]]:
- short_name_matches: List["RegisteredComponent"] = []
+ ) -> tuple[Optional["ComponentEntry"], Optional[str]]:
+ normalized_component_type = self._normalize_component_type(component_type)
+ short_name_matches: List["ComponentEntry"] = []
for sv in self.supervisors:
comp = sv.component_registry.get_component(name)
- if comp is not None and comp.component_type == component_type:
+ if comp is not None and comp.component_type == normalized_component_type:
return comp, None
short_name_matches.extend(
candidate
- for candidate in sv.component_registry.get_components_by_type(component_type, enabled_only=False)
+ for candidate in sv.component_registry.get_components_by_type(
+ normalized_component_type,
+ enabled_only=False,
+ )
if candidate.name == name
)
if len(short_name_matches) == 1:
return short_name_matches[0], None
if len(short_name_matches) > 1:
- return None, f"组件名不唯一: {name} ({component_type}),请使用完整名 plugin_id.component_name"
- return None, f"未找到组件: {name} ({component_type})"
+ return None, f"组件名不唯一: {name} ({normalized_component_type}),请使用完整名 plugin_id.component_name"
+ return None, f"未找到组件: {name} ({normalized_component_type})"
async def _cap_component_enable(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
@@ -127,6 +333,13 @@ class RuntimeComponentCapabilityMixin:
if scope != "global" or stream_id:
return {"success": False, "error": "当前仅支持全局组件启用,不支持 scope/stream_id 定位"}
+ if self._is_api_component_type(component_type):
+ supervisor, api_entry, error = self._resolve_api_toggle_target(name)
+ if supervisor is None or api_entry is None:
+ return {"success": False, "error": error or f"未找到 API: {name}"}
+ supervisor.api_registry.toggle_api_status(api_entry.full_name, True)
+ return {"success": True}
+
comp, error = self._resolve_component_toggle_target(name, component_type)
if comp is None:
return {"success": False, "error": error or f"未找到组件: {name} ({component_type})"}
@@ -146,6 +359,13 @@ class RuntimeComponentCapabilityMixin:
if scope != "global" or stream_id:
return {"success": False, "error": "当前仅支持全局组件禁用,不支持 scope/stream_id 定位"}
+ if self._is_api_component_type(component_type):
+ supervisor, api_entry, error = self._resolve_api_toggle_target(name)
+ if supervisor is None or api_entry is None:
+ return {"success": False, "error": error or f"未找到 API: {name}"}
+ supervisor.api_registry.toggle_api_status(api_entry.full_name, False)
+ return {"success": True}
+
comp, error = self._resolve_component_toggle_target(name, component_type)
if comp is None:
return {"success": False, "error": error or f"未找到组件: {name} ({component_type})"}
@@ -239,3 +459,111 @@ class RuntimeComponentCapabilityMixin:
logger.error(f"[cap.component.reload_plugin] 热重载失败: {e}")
return {"success": False, "error": str(e)}
return {"success": False, "error": f"未找到插件: {plugin_name}"}
+
+ async def _cap_api_call(
+ self: _RuntimeComponentManagerProtocol,
+ plugin_id: str,
+ capability: str,
+ args: Dict[str, Any],
+ ) -> Any:
+ """调用其他插件公开的 API。
+
+ Args:
+ plugin_id: 当前调用方插件 ID。
+ capability: 能力名称。
+ args: 能力参数。
+
+ Returns:
+ Any: API 调用结果。
+ """
+
+ del capability
+ api_name = str(args.get("api_name", "") or "").strip()
+ version = str(args.get("version", "") or "").strip()
+ api_args = args.get("args", {})
+ if not isinstance(api_args, dict):
+ return {"success": False, "error": "参数 args 必须为字典"}
+
+ supervisor, entry, error = self._resolve_api_target(plugin_id, api_name, version)
+ if supervisor is None or entry is None:
+ return {"success": False, "error": error or "API 解析失败"}
+
+ try:
+ response = await supervisor.invoke_api(
+ plugin_id=entry.plugin_id,
+ component_name=entry.name,
+ args=api_args,
+ timeout_ms=30000,
+ )
+ except Exception as exc:
+ logger.error(f"[cap.api.call] 调用 API {entry.full_name} 失败: {exc}", exc_info=True)
+ return {"success": False, "error": str(exc)}
+
+ if response.error:
+ return {"success": False, "error": response.error.get("message", "API 调用失败")}
+
+ payload = response.payload if isinstance(response.payload, dict) else {}
+ if not bool(payload.get("success", False)):
+ result = payload.get("result")
+ return {"success": False, "error": "" if result is None else str(result)}
+ return {"success": True, "result": payload.get("result")}
+
+ async def _cap_api_get(
+ self: _RuntimeComponentManagerProtocol,
+ plugin_id: str,
+ capability: str,
+ args: Dict[str, Any],
+ ) -> Any:
+ """获取当前插件可见的单个 API 元信息。
+
+ Args:
+ plugin_id: 当前调用方插件 ID。
+ capability: 能力名称。
+ args: 能力参数。
+
+ Returns:
+ Any: API 元信息或 ``None``。
+ """
+
+ del capability
+ api_name = str(args.get("api_name", "") or "").strip()
+ version = str(args.get("version", "") or "").strip()
+ if not api_name:
+ return {"success": False, "error": "缺少必要参数 api_name"}
+
+ supervisor, entry, _error = self._resolve_api_target(plugin_id, api_name, version)
+ if supervisor is None or entry is None:
+ return {"success": True, "api": None}
+ return {"success": True, "api": self._serialize_api_entry(entry)}
+
+ async def _cap_api_list(
+ self: _RuntimeComponentManagerProtocol,
+ plugin_id: str,
+ capability: str,
+ args: Dict[str, Any],
+ ) -> Any:
+ """列出当前插件可见的 API 列表。
+
+ Args:
+ plugin_id: 当前调用方插件 ID。
+ capability: 能力名称。
+ args: 能力参数。
+
+ Returns:
+ Any: API 元信息列表。
+ """
+
+ del capability
+ target_plugin_id = str(args.get("plugin_id", "") or "").strip()
+ apis: List[Dict[str, Any]] = []
+ for supervisor in self.supervisors:
+ for entry in supervisor.api_registry.get_apis(
+ plugin_id=target_plugin_id or None,
+ enabled_only=True,
+ ):
+ if not self._is_api_visible_to_plugin(entry, plugin_id):
+ continue
+ apis.append(self._serialize_api_entry(entry))
+
+ apis.sort(key=lambda item: (str(item["plugin_id"]), str(item["name"]), str(item["version"])))
+ return {"success": True, "apis": apis}
diff --git a/src/plugin_runtime/capabilities/registry.py b/src/plugin_runtime/capabilities/registry.py
index 96b190b4..31693833 100644
--- a/src/plugin_runtime/capabilities/registry.py
+++ b/src/plugin_runtime/capabilities/registry.py
@@ -74,6 +74,10 @@ def register_capability_impls(manager: "PluginRuntimeManager", supervisor: Plugi
_register("tool.get_definitions", manager._cap_tool_get_definitions)
+ _register("api.call", manager._cap_api_call)
+ _register("api.get", manager._cap_api_get)
+ _register("api.list", manager._cap_api_list)
+
_register("component.get_all_plugins", manager._cap_component_get_all_plugins)
_register("component.get_plugin_info", manager._cap_component_get_plugin_info)
_register("component.list_loaded_plugins", manager._cap_component_list_loaded_plugins)
diff --git a/src/plugin_runtime/host/api_registry.py b/src/plugin_runtime/host/api_registry.py
new file mode 100644
index 00000000..84578ca5
--- /dev/null
+++ b/src/plugin_runtime/host/api_registry.py
@@ -0,0 +1,290 @@
+"""Host 侧插件 API 动态注册表。"""
+
+from typing import Any, Dict, List, Optional, Set
+
+from src.common.logger import get_logger
+
+logger = get_logger("plugin_runtime.host.api_registry")
+
+
+class APIEntry:
+ """API 组件条目。"""
+
+ __slots__ = (
+ "description",
+ "disabled_session",
+ "enabled",
+ "full_name",
+ "metadata",
+ "name",
+ "plugin_id",
+ "public",
+ "version",
+ )
+
+ def __init__(self, name: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
+ """初始化 API 组件条目。
+
+ Args:
+ name: API 名称。
+ plugin_id: 所属插件 ID。
+ metadata: API 元数据。
+ """
+
+ self.name: str = name
+ self.full_name: str = f"{plugin_id}.{name}"
+ self.plugin_id: str = plugin_id
+ self.description: str = str(metadata.get("description", "") or "")
+ self.version: str = str(metadata.get("version", "1") or "1").strip() or "1"
+ self.public: bool = bool(metadata.get("public", False))
+ self.metadata: Dict[str, Any] = dict(metadata)
+ self.enabled: bool = bool(metadata.get("enabled", True))
+ self.disabled_session: Set[str] = set()
+
+
+class APIRegistry:
+ """Host 侧插件 API 动态注册表。
+
+ 该注册表不直接面向 Runner,而是复用插件组件注册/卸载事件,
+ 维护面向 API 调用场景的专用索引。
+ """
+
+ def __init__(self) -> None:
+ """初始化 API 注册表。"""
+
+ self._apis: Dict[str, APIEntry] = {}
+ self._by_plugin: Dict[str, List[APIEntry]] = {}
+ self._by_name: Dict[str, List[APIEntry]] = {}
+
+ def clear(self) -> None:
+ """清空全部 API 注册状态。"""
+
+ self._apis.clear()
+ self._by_plugin.clear()
+ self._by_name.clear()
+
+ @staticmethod
+ def _is_api_component(component_type: Any) -> bool:
+ """判断组件声明是否属于 API。
+
+ Args:
+ component_type: 原始组件类型值。
+
+ Returns:
+ bool: 是否为 API 组件。
+ """
+
+ return str(component_type or "").strip().upper() == "API"
+
+ @staticmethod
+ def check_api_enabled(entry: APIEntry, session_id: Optional[str] = None) -> bool:
+ """判断 API 条目当前是否处于启用状态。
+
+ Args:
+ entry: 待检查的 API 条目。
+ session_id: 可选的会话 ID。
+
+ Returns:
+ bool: 当前是否可用。
+ """
+
+ if session_id and session_id in entry.disabled_session:
+ return False
+ return entry.enabled
+
+ def register_api(self, name: str, plugin_id: str, metadata: Dict[str, Any]) -> bool:
+ """注册单个 API 条目。
+
+ Args:
+ name: API 名称。
+ plugin_id: 所属插件 ID。
+ metadata: API 元数据。
+
+ Returns:
+ bool: 是否成功注册。
+ """
+
+ normalized_name = str(name or "").strip()
+ if not normalized_name:
+ logger.warning(f"插件 {plugin_id} 存在空 API 名称声明,已忽略")
+ return False
+
+ entry = APIEntry(name=normalized_name, plugin_id=plugin_id, metadata=metadata)
+ if entry.full_name in self._apis:
+ logger.warning(f"API {entry.full_name} 已存在,覆盖旧条目")
+ self._remove_entry(self._apis[entry.full_name])
+
+ self._apis[entry.full_name] = entry
+ self._by_plugin.setdefault(plugin_id, []).append(entry)
+ self._by_name.setdefault(entry.name, []).append(entry)
+ return True
+
+ def register_plugin_apis(self, plugin_id: str, components: List[Dict[str, Any]]) -> int:
+ """批量注册某个插件声明的全部 API。
+
+ Args:
+ plugin_id: 插件 ID。
+ components: 插件组件声明列表。
+
+ Returns:
+ int: 成功注册的 API 数量。
+ """
+
+ count = 0
+ for component in components:
+ if not self._is_api_component(component.get("component_type")):
+ continue
+ if self.register_api(
+ name=str(component.get("name", "") or ""),
+ plugin_id=plugin_id,
+ metadata=component.get("metadata", {}) if isinstance(component.get("metadata"), dict) else {},
+ ):
+ count += 1
+ return count
+
+ def _remove_entry(self, entry: APIEntry) -> None:
+ """从全部索引中移除单个 API 条目。
+
+ Args:
+ entry: 待移除的 API 条目。
+ """
+
+ self._apis.pop(entry.full_name, None)
+ plugin_entries = self._by_plugin.get(entry.plugin_id)
+ if plugin_entries is not None:
+ self._by_plugin[entry.plugin_id] = [candidate for candidate in plugin_entries if candidate is not entry]
+ if not self._by_plugin[entry.plugin_id]:
+ self._by_plugin.pop(entry.plugin_id, None)
+
+ name_entries = self._by_name.get(entry.name)
+ if name_entries is not None:
+ self._by_name[entry.name] = [candidate for candidate in name_entries if candidate is not entry]
+ if not self._by_name[entry.name]:
+ self._by_name.pop(entry.name, None)
+
+ def remove_apis_by_plugin(self, plugin_id: str) -> int:
+ """移除某个插件的全部 API。
+
+ Args:
+ plugin_id: 目标插件 ID。
+
+ Returns:
+ int: 被移除的 API 数量。
+ """
+
+ entries = list(self._by_plugin.get(plugin_id, []))
+ for entry in entries:
+ self._remove_entry(entry)
+ return len(entries)
+
+ def get_api_by_full_name(
+ self,
+ full_name: str,
+ *,
+ enabled_only: bool = True,
+ session_id: Optional[str] = None,
+ ) -> Optional[APIEntry]:
+ """按完整名查询单个 API。
+
+ Args:
+ full_name: API 完整名,格式为 ``plugin_id.api_name``。
+ enabled_only: 是否仅返回启用状态的 API。
+ session_id: 可选的会话 ID。
+
+ Returns:
+ Optional[APIEntry]: 命中时返回 API 条目。
+ """
+
+ entry = self._apis.get(full_name)
+ if entry is None:
+ return None
+ if enabled_only and not self.check_api_enabled(entry, session_id):
+ return None
+ return entry
+
+ def get_api(
+ self,
+ plugin_id: str,
+ name: str,
+ *,
+ enabled_only: bool = True,
+ session_id: Optional[str] = None,
+ ) -> Optional[APIEntry]:
+ """按插件 ID 和短名查询单个 API。
+
+ Args:
+ plugin_id: 提供方插件 ID。
+ name: API 短名。
+ enabled_only: 是否仅返回启用状态的 API。
+ session_id: 可选的会话 ID。
+
+ Returns:
+ Optional[APIEntry]: 命中时返回 API 条目。
+ """
+
+ return self.get_api_by_full_name(
+ f"{plugin_id}.{name}",
+ enabled_only=enabled_only,
+ session_id=session_id,
+ )
+
+ def get_apis(
+ self,
+ *,
+ plugin_id: Optional[str] = None,
+ name: str = "",
+ enabled_only: bool = True,
+ session_id: Optional[str] = None,
+ ) -> List[APIEntry]:
+ """查询 API 列表。
+
+ Args:
+ plugin_id: 可选的插件 ID 过滤条件。
+ name: 可选的 API 名称过滤条件。
+ enabled_only: 是否仅返回启用状态的 API。
+ session_id: 可选的会话 ID。
+
+ Returns:
+ List[APIEntry]: 符合条件的 API 条目列表。
+ """
+
+ normalized_name = str(name or "").strip()
+ if plugin_id:
+ candidates = list(self._by_plugin.get(plugin_id, []))
+ elif normalized_name:
+ candidates = list(self._by_name.get(normalized_name, []))
+ else:
+ candidates = list(self._apis.values())
+
+ filtered_entries: List[APIEntry] = []
+ for entry in candidates:
+ if normalized_name and entry.name != normalized_name:
+ continue
+ if enabled_only and not self.check_api_enabled(entry, session_id):
+ continue
+ filtered_entries.append(entry)
+ return filtered_entries
+
+ def toggle_api_status(self, full_name: str, enabled: bool, session_id: Optional[str] = None) -> bool:
+ """设置指定 API 的启用状态。
+
+ Args:
+ full_name: API 完整名。
+ enabled: 目标启用状态。
+ session_id: 可选的会话 ID,仅对该会话生效。
+
+ Returns:
+ bool: 是否设置成功。
+ """
+
+ entry = self._apis.get(full_name)
+ if entry is None:
+ return False
+ if session_id:
+ if enabled:
+ entry.disabled_session.discard(session_id)
+ else:
+ entry.disabled_session.add(session_id)
+ else:
+ entry.enabled = enabled
+ return True
diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py
index 4a9885f8..1add64c6 100644
--- a/src/plugin_runtime/host/supervisor.py
+++ b/src/plugin_runtime/host/supervisor.py
@@ -34,6 +34,7 @@ from src.plugin_runtime.protocol.errors import ErrorCode, RPCError
from src.plugin_runtime.transport.factory import create_transport_server
from .authorization import AuthorizationManager
+from .api_registry import APIRegistry
from .capability_service import CapabilityService
from .component_registry import ComponentRegistry
from .event_dispatcher import EventDispatcher
@@ -93,6 +94,7 @@ class PluginRunnerSupervisor:
self._transport = create_transport_server(socket_path=socket_path)
self._authorization = AuthorizationManager()
self._capability_service = CapabilityService(self._authorization)
+ self._api_registry = APIRegistry()
self._component_registry = ComponentRegistry()
self._event_dispatcher = EventDispatcher(self._component_registry)
self._hook_dispatcher = HookDispatcher(self._component_registry)
@@ -124,6 +126,11 @@ class PluginRunnerSupervisor:
"""返回能力服务。"""
return self._capability_service
+ @property
+ def api_registry(self) -> APIRegistry:
+ """返回 API 专用注册表。"""
+ return self._api_registry
+
@property
def component_registry(self) -> ComponentRegistry:
"""返回组件注册表。"""
@@ -310,6 +317,33 @@ class PluginRunnerSupervisor:
timeout_ms=timeout_ms,
)
+ async def invoke_api(
+ self,
+ plugin_id: str,
+ component_name: str,
+ args: Optional[Dict[str, Any]] = None,
+ timeout_ms: int = 30000,
+ ) -> Envelope:
+ """调用插件声明的 API 方法。
+
+ Args:
+ plugin_id: 目标插件 ID。
+ component_name: API 组件名称。
+ args: 传递给 API 方法的关键字参数。
+ timeout_ms: RPC 超时时间,单位毫秒。
+
+ Returns:
+ Envelope: Runner 返回的响应信封。
+ """
+
+ return await self.invoke_plugin(
+ method="plugin.invoke_api",
+ plugin_id=plugin_id,
+ component_name=component_name,
+ args=args,
+ timeout_ms=timeout_ms,
+ )
+
async def reload_plugin(self, plugin_id: str, reason: str = "manual") -> bool:
"""按插件 ID 触发精确重载。
@@ -507,13 +541,17 @@ class PluginRunnerSupervisor:
except Exception as exc:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
+ component_declarations = [component.model_dump() for component in payload.components]
+ runtime_components, api_components = self._split_component_declarations(component_declarations)
self._component_registry.remove_components_by_plugin(payload.plugin_id)
+ self._api_registry.remove_apis_by_plugin(payload.plugin_id)
await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id)
registered_count = self._component_registry.register_plugin_components(
payload.plugin_id,
- [component.model_dump() for component in payload.components],
+ runtime_components,
)
+ registered_api_count = self._api_registry.register_plugin_apis(payload.plugin_id, api_components)
self._registered_plugins[payload.plugin_id] = payload
self._message_gateway_states[payload.plugin_id] = {}
@@ -522,6 +560,7 @@ class PluginRunnerSupervisor:
"accepted": True,
"plugin_id": payload.plugin_id,
"registered_components": registered_count,
+ "registered_apis": registered_api_count,
"message_gateways": len(
self._component_registry.get_message_gateways(plugin_id=payload.plugin_id, enabled_only=False)
),
@@ -543,6 +582,7 @@ class PluginRunnerSupervisor:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
removed_components = self._component_registry.remove_components_by_plugin(payload.plugin_id)
+ removed_apis = self._api_registry.remove_apis_by_plugin(payload.plugin_id)
self._authorization.revoke_permission_token(payload.plugin_id)
removed_registration = self._registered_plugins.pop(payload.plugin_id, None) is not None
await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id)
@@ -554,10 +594,48 @@ class PluginRunnerSupervisor:
"plugin_id": payload.plugin_id,
"reason": payload.reason,
"removed_components": removed_components,
+ "removed_apis": removed_apis,
"removed_registration": removed_registration,
}
)
+ @staticmethod
+ def _is_api_component(component: Dict[str, Any]) -> bool:
+ """判断组件声明是否属于 API。
+
+ Args:
+ component: 原始组件声明字典。
+
+ Returns:
+ bool: 是否为 API 组件。
+ """
+
+ return str(component.get("component_type", "") or "").strip().upper() == "API"
+
+ def _split_component_declarations(
+ self,
+ components: List[Dict[str, Any]],
+ ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
+ """拆分通用组件声明和 API 声明。
+
+ Args:
+ components: Runner 上报的原始组件声明列表。
+
+ Returns:
+ Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
+ 第一个列表为需要进入通用组件表的声明,
+ 第二个列表为需要进入 API 专用表的声明。
+ """
+
+ runtime_components: List[Dict[str, Any]] = []
+ api_components: List[Dict[str, Any]] = []
+ for component in components:
+ if self._is_api_component(component):
+ api_components.append(component)
+ else:
+ runtime_components.append(component)
+ return runtime_components, api_components
+
@staticmethod
def _build_message_gateway_driver_id(plugin_id: str, gateway_name: str) -> str:
"""构造消息网关驱动 ID。
@@ -1172,6 +1250,7 @@ class PluginRunnerSupervisor:
def _clear_runner_state(self) -> None:
"""清理当前 Runner 对应的 Host 侧注册状态。"""
self._authorization.clear()
+ self._api_registry.clear()
self._component_registry.clear()
self._registered_plugins.clear()
self._message_gateway_states.clear()
diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py
index 3a50e2f7..4bee714c 100644
--- a/src/plugin_runtime/runner/runner_main.py
+++ b/src/plugin_runtime/runner/runner_main.py
@@ -291,6 +291,7 @@ class PluginRunner:
"""注册 Host -> Runner 的方法处理器。"""
self._rpc_client.register_method("plugin.invoke_command", self._handle_invoke)
self._rpc_client.register_method("plugin.invoke_action", self._handle_invoke)
+ self._rpc_client.register_method("plugin.invoke_api", self._handle_invoke)
self._rpc_client.register_method("plugin.invoke_tool", self._handle_invoke)
self._rpc_client.register_method("plugin.invoke_message_gateway", self._handle_invoke)
self._rpc_client.register_method("plugin.emit_event", self._handle_event_invoke)
From d13767ee21ea762c3006a9d74b161b69a098dd00 Mon Sep 17 00:00:00 2001
From: DrSmoothl <1787882683@qq.com>
Date: Mon, 23 Mar 2026 20:06:12 +0800
Subject: [PATCH 34/45] feat: Enhance plugin configuration management and SDK
integration
- Add support for configuration reload scopes in the plugin runtime.
- Implement validation for SDK plugins to ensure required lifecycle methods are overridden.
- Update the configuration update handling to include scope information.
- Introduce tests for expression auto-check task and NapCat adapter SDK integration.
- Refactor configuration management to support callbacks with variable arguments.
- Improve plugin loading and error handling for configuration updates.
- Ensure that plugins can manage their own configuration updates effectively.
---
plugins/ChatFrequency/plugin.py | 29 +-
plugins/emoji_manage_plugin/plugin.py | 31 ++-
plugins/hello_world_plugin/plugin.py | 33 ++-
pyproject.toml | 2 +
.../test_expression_auto_check_task.py | 89 ++++++
pytests/test_napcat_adapter_sdk.py | 132 +++++++++
pytests/test_plugin_runtime.py | 263 ++++++++++++++++--
src/config/config.py | 135 ++++++++-
src/learners/expression_auto_check_task.py | 7 +-
src/plugin_runtime/host/supervisor.py | 19 ++
src/plugin_runtime/integration.py | 99 +++++--
src/plugin_runtime/protocol/envelope.py | 12 +
src/plugin_runtime/runner/plugin_loader.py | 30 ++
src/plugin_runtime/runner/runner_main.py | 33 ++-
src/plugins/built_in/emoji_plugin/plugin.py | 35 ++-
.../built_in/plugin_management/plugin.py | 29 +-
16 files changed, 907 insertions(+), 71 deletions(-)
create mode 100644 pytests/common_test/test_expression_auto_check_task.py
create mode 100644 pytests/test_napcat_adapter_sdk.py
diff --git a/plugins/ChatFrequency/plugin.py b/plugins/ChatFrequency/plugin.py
index b3f69384..0e9f5a0c 100644
--- a/plugins/ChatFrequency/plugin.py
+++ b/plugins/ChatFrequency/plugin.py
@@ -3,12 +3,18 @@
通过 /chat 命令设置和查看聊天频率。
"""
-from maibot_sdk import MaiBotPlugin, Command
+from maibot_sdk import Command, MaiBotPlugin
class BetterFrequencyPlugin(MaiBotPlugin):
"""聊天频率控制插件"""
+ async def on_load(self) -> None:
+ """处理插件加载。"""
+
+ async def on_unload(self) -> None:
+ """处理插件卸载。"""
+
@Command(
"set_talk_frequency",
description="设置当前聊天的talk_frequency值:/chat talk_frequency <数字> 或 /chat t <数字>",
@@ -80,6 +86,25 @@ class BetterFrequencyPlugin(MaiBotPlugin):
await self.ctx.send.text(status_msg, stream_id)
return True, None, False
+ async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None:
+ """处理配置热重载事件。
+
+ Args:
+ scope: 配置变更范围。
+ config_data: 最新配置数据。
+ version: 配置版本号。
+ """
+
+ del scope
+ del config_data
+ del version
+
+
+def create_plugin() -> BetterFrequencyPlugin:
+ """创建聊天频率插件实例。
+
+ Returns:
+ BetterFrequencyPlugin: 新的聊天频率插件实例。
+ """
-def create_plugin():
return BetterFrequencyPlugin()
diff --git a/plugins/emoji_manage_plugin/plugin.py b/plugins/emoji_manage_plugin/plugin.py
index f3c5f677..9362c828 100644
--- a/plugins/emoji_manage_plugin/plugin.py
+++ b/plugins/emoji_manage_plugin/plugin.py
@@ -3,17 +3,23 @@
通过 /emoji 命令管理表情包的添加、列表和删除。
"""
+from maibot_sdk import Command, MaiBotPlugin
+
import base64
import datetime
import hashlib
import re
-from maibot_sdk import MaiBotPlugin, Command
-
class EmojiManagePlugin(MaiBotPlugin):
"""表情包管理插件"""
+ async def on_load(self) -> None:
+ """处理插件加载。"""
+
+ async def on_unload(self) -> None:
+ """处理插件卸载。"""
+
# ===== 工具方法 =====
@staticmethod
@@ -208,6 +214,25 @@ class EmojiManagePlugin(MaiBotPlugin):
await self.ctx.send.forward(messages, stream_id)
return True, "已发送随机表情包", True
+ async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None:
+ """处理配置热重载事件。
+
+ Args:
+ scope: 配置变更范围。
+ config_data: 最新配置数据。
+ version: 配置版本号。
+ """
+
+ del scope
+ del config_data
+ del version
+
+
+def create_plugin() -> EmojiManagePlugin:
+ """创建表情包管理插件实例。
+
+ Returns:
+ EmojiManagePlugin: 新的表情包管理插件实例。
+ """
-def create_plugin():
return EmojiManagePlugin()
diff --git a/plugins/hello_world_plugin/plugin.py b/plugins/hello_world_plugin/plugin.py
index fbba9d10..4d1f37af 100644
--- a/plugins/hello_world_plugin/plugin.py
+++ b/plugins/hello_world_plugin/plugin.py
@@ -3,16 +3,22 @@
你的第一个 MaiCore 插件,包含问候功能、时间查询等基础示例。
"""
+from maibot_sdk import Action, Command, EventHandler, MaiBotPlugin, Tool
+from maibot_sdk.types import ActivationType, EventType, ToolParameterInfo, ToolParamType
+
import datetime
import random
-from maibot_sdk import MaiBotPlugin, Action, Command, Tool, EventHandler
-from maibot_sdk.types import ActivationType, EventType, ToolParameterInfo, ToolParamType
-
class HelloWorldPlugin(MaiBotPlugin):
"""Hello World 示例插件"""
+ async def on_load(self) -> None:
+ """处理插件加载。"""
+
+ async def on_unload(self) -> None:
+ """处理插件卸载。"""
+
# ===== Tool 组件 =====
@Tool(
@@ -146,6 +152,25 @@ class HelloWorldPlugin(MaiBotPlugin):
return True, True, None, None, None
+ async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None:
+ """处理配置热重载事件。
+
+ Args:
+ scope: 配置变更范围。
+ config_data: 最新配置数据。
+ version: 配置版本号。
+ """
+
+ del scope
+ del config_data
+ del version
+
+
+def create_plugin() -> HelloWorldPlugin:
+ """创建 Hello World 示例插件实例。
+
+ Returns:
+ HelloWorldPlugin: 新的示例插件实例。
+ """
-def create_plugin():
return HelloWorldPlugin()
diff --git a/pyproject.toml b/pyproject.toml
index 9887ac24..95c92acd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -54,6 +54,8 @@ dev = [
[tool.uv]
index-url = "https://pypi.tuna.tsinghua.edu.cn/simple"
+[tool.uv.sources]
+maibot-plugin-sdk = { path = "packages/maibot-plugin-sdk", editable = true }
[tool.ruff]
diff --git a/pytests/common_test/test_expression_auto_check_task.py b/pytests/common_test/test_expression_auto_check_task.py
new file mode 100644
index 00000000..da8c59e1
--- /dev/null
+++ b/pytests/common_test/test_expression_auto_check_task.py
@@ -0,0 +1,89 @@
+"""测试表达方式自动检查任务的数据库读取行为。"""
+
+from contextlib import contextmanager
+from typing import Generator
+
+import pytest
+from sqlalchemy.pool import StaticPool
+from sqlmodel import Session, SQLModel, create_engine
+
+from src.bw_learner.expression_auto_check_task import ExpressionAutoCheckTask
+from src.common.database.database_model import Expression
+
+
+@pytest.fixture(name="expression_auto_check_engine")
+def expression_auto_check_engine_fixture() -> Generator:
+ """创建用于表达方式自动检查任务测试的内存数据库引擎。
+
+ Yields:
+ Generator: 供测试使用的 SQLite 内存引擎。
+ """
+
+ engine = create_engine(
+ "sqlite://",
+ connect_args={"check_same_thread": False},
+ poolclass=StaticPool,
+ )
+ SQLModel.metadata.create_all(engine)
+ yield engine
+
+
+@pytest.mark.asyncio
+async def test_select_expressions_uses_read_only_session(
+ monkeypatch: pytest.MonkeyPatch,
+ expression_auto_check_engine,
+) -> None:
+ """选择表达方式时应使用只读会话,并在离开会话后安全读取 ORM 字段。"""
+
+ import src.bw_learner.expression_auto_check_task as expression_auto_check_task_module
+
+ with Session(expression_auto_check_engine) as session:
+ session.add(
+ Expression(
+ situation="表达情绪高涨或生理反应",
+ style="发送💦表情符号",
+ content_list='["表达情绪高涨或生理反应"]',
+ count=1,
+ session_id="session-a",
+ checked=False,
+ rejected=False,
+ )
+ )
+ session.commit()
+
+ auto_commit_calls: list[bool] = []
+
+ @contextmanager
+ def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
+ """构造带自动提交语义的测试会话工厂。
+
+ Args:
+ auto_commit: 退出上下文时是否自动提交。
+
+ Yields:
+ Generator[Session, None, None]: SQLModel 会话对象。
+ """
+
+ auto_commit_calls.append(auto_commit)
+ session = Session(expression_auto_check_engine)
+ try:
+ yield session
+ if auto_commit:
+ session.commit()
+ except Exception:
+ session.rollback()
+ raise
+ finally:
+ session.close()
+
+ monkeypatch.setattr(expression_auto_check_task_module, "get_db_session", fake_get_db_session)
+ monkeypatch.setattr(expression_auto_check_task_module.random, "sample", lambda entries, _count: list(entries))
+
+ task = ExpressionAutoCheckTask()
+ expressions = await task._select_expressions(1)
+
+ assert auto_commit_calls == [False]
+ assert len(expressions) == 1
+ assert expressions[0].id is not None
+ assert expressions[0].situation == "表达情绪高涨或生理反应"
+ assert expressions[0].style == "发送💦表情符号"
diff --git a/pytests/test_napcat_adapter_sdk.py b/pytests/test_napcat_adapter_sdk.py
new file mode 100644
index 00000000..c6b1fdbd
--- /dev/null
+++ b/pytests/test_napcat_adapter_sdk.py
@@ -0,0 +1,132 @@
+"""NapCat 插件与新 SDK 对接测试。"""
+
+from pathlib import Path
+from typing import Any, Dict, List
+
+import importlib
+import logging
+import sys
+
+import pytest
+
+PROJECT_ROOT = Path(__file__).resolve().parents[1]
+PLUGINS_ROOT = PROJECT_ROOT / "plugins"
+SDK_ROOT = PROJECT_ROOT / "packages" / "maibot-plugin-sdk"
+
+for import_path in (str(PLUGINS_ROOT), str(SDK_ROOT)):
+ if import_path not in sys.path:
+ sys.path.insert(0, import_path)
+
+
+class _FakeGatewayCapability:
+ """用于捕获消息网关状态上报的测试替身。"""
+
+ def __init__(self) -> None:
+ """初始化测试替身。"""
+
+ self.calls: List[Dict[str, Any]] = []
+
+ async def update_state(
+ self,
+ gateway_name: str,
+ *,
+ ready: bool,
+ platform: str = "",
+ account_id: str = "",
+ scope: str = "",
+ metadata: Dict[str, Any] | None = None,
+ ) -> bool:
+ """记录一次状态上报请求。
+
+ Args:
+ gateway_name: 网关组件名称。
+ ready: 当前是否就绪。
+ platform: 平台名称。
+ account_id: 账号 ID。
+ scope: 路由作用域。
+ metadata: 附加元数据。
+
+ Returns:
+ bool: 始终返回 ``True``,模拟 Host 接受状态更新。
+ """
+
+ self.calls.append(
+ {
+ "gateway_name": gateway_name,
+ "ready": ready,
+ "platform": platform,
+ "account_id": account_id,
+ "scope": scope,
+ "metadata": metadata or {},
+ }
+ )
+ return True
+
+
+def _load_napcat_sdk_symbols() -> tuple[Any, Any, Any, Any]:
+ """动态加载 NapCat 插件测试所需的符号。
+
+ Returns:
+ tuple[Any, Any, Any, Any]:
+ 依次返回网关名常量、配置类、插件类和运行时状态管理器类。
+ """
+
+ constants_module = importlib.import_module("napcat_adapter.constants")
+ config_module = importlib.import_module("napcat_adapter.config")
+ plugin_module = importlib.import_module("napcat_adapter.plugin")
+ runtime_state_module = importlib.import_module("napcat_adapter.runtime_state")
+ return (
+ constants_module.NAPCAT_GATEWAY_NAME,
+ config_module.NapCatServerConfig,
+ plugin_module.NapCatAdapterPlugin,
+ runtime_state_module.NapCatRuntimeStateManager,
+ )
+
+
+def test_napcat_plugin_collects_duplex_message_gateway() -> None:
+ """NapCat 插件应声明新的双工消息网关组件。"""
+
+ napcat_gateway_name, _napcat_server_config, napcat_plugin_cls, _runtime_state_cls = _load_napcat_sdk_symbols()
+ plugin = napcat_plugin_cls()
+ components = plugin.get_components()
+ gateway_components = [
+ component
+ for component in components
+ if component.get("type") == "MESSAGE_GATEWAY"
+ ]
+
+ assert len(gateway_components) == 1
+ gateway_component = gateway_components[0]
+ assert gateway_component["name"] == napcat_gateway_name
+ assert gateway_component["metadata"]["route_type"] == "duplex"
+ assert gateway_component["metadata"]["platform"] == "qq"
+ assert gateway_component["metadata"]["protocol"] == "napcat"
+
+
+@pytest.mark.asyncio
+async def test_runtime_state_reports_via_gateway_capability() -> None:
+ """NapCat 运行时状态应通过新的消息网关能力上报。"""
+
+ napcat_gateway_name, napcat_server_config_cls, _napcat_plugin_cls, runtime_state_cls = _load_napcat_sdk_symbols()
+ gateway_capability = _FakeGatewayCapability()
+ runtime_state_manager = runtime_state_cls(
+ gateway_capability=gateway_capability,
+ logger=logging.getLogger("test.napcat_adapter"),
+ gateway_name=napcat_gateway_name,
+ )
+
+ connected = await runtime_state_manager.report_connected(
+ "10001",
+ napcat_server_config_cls(connection_id="primary"),
+ )
+ await runtime_state_manager.report_disconnected()
+
+ assert connected is True
+ assert gateway_capability.calls[0]["gateway_name"] == napcat_gateway_name
+ assert gateway_capability.calls[0]["ready"] is True
+ assert gateway_capability.calls[0]["platform"] == "qq"
+ assert gateway_capability.calls[0]["account_id"] == "10001"
+ assert gateway_capability.calls[0]["scope"] == "primary"
+ assert gateway_capability.calls[1]["gateway_name"] == napcat_gateway_name
+ assert gateway_capability.calls[1]["ready"] is False
+ assert gateway_capability.calls[1]["platform"] == "qq"
diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py
index 9dfc34d8..9b46f897 100644
--- a/pytests/test_plugin_runtime.py
+++ b/pytests/test_plugin_runtime.py
@@ -441,8 +441,8 @@ class TestSDK:
def set_plugin_config(self, config):
self.configs.append(config)
- async def on_config_update(self, config, version):
- self.updates.append((config, version, list(self.configs)))
+ async def on_config_update(self, scope, config, version):
+ self.updates.append((scope, config, version, list(self.configs)))
runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[])
plugin = DummyPlugin()
@@ -453,14 +453,60 @@ class TestSDK:
message_type=MessageType.REQUEST,
method="plugin.config_updated",
plugin_id="demo_plugin",
- payload={"config_data": {"enabled": True}, "config_version": "v2"},
+ payload={
+ "plugin_id": "demo_plugin",
+ "config_scope": "self",
+ "config_data": {"enabled": True},
+ "config_version": "v2",
+ },
)
response = await runner._handle_config_updated(envelope)
assert response.payload["acknowledged"] is True
assert plugin.configs == [{"enabled": True}]
- assert plugin.updates == [({"enabled": True}, "v2", [{"enabled": True}])]
+ assert plugin.updates == [("self", {"enabled": True}, "v2", [{"enabled": True}])]
+
+ @pytest.mark.asyncio
+ async def test_runner_global_config_update_does_not_override_plugin_config(self):
+ """bot/model 广播不应覆盖插件自身配置缓存。"""
+ from src.plugin_runtime.protocol.envelope import Envelope, MessageType
+ from src.plugin_runtime.runner.runner_main import PluginRunner
+
+ class DummyPlugin:
+ def __init__(self):
+ self.configs = []
+ self.updates = []
+
+ def set_plugin_config(self, config):
+ self.configs.append(config)
+
+ async def on_config_update(self, scope, config, version):
+ self.updates.append((scope, config, version, list(self.configs)))
+
+ runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[])
+ plugin = DummyPlugin()
+ runner._loader._loaded_plugins["demo_plugin"] = SimpleNamespace(instance=plugin)
+ plugin.set_plugin_config({"plugin_enabled": True})
+
+ envelope = Envelope(
+ request_id=1,
+ message_type=MessageType.REQUEST,
+ method="plugin.config_updated",
+ plugin_id="demo_plugin",
+ payload={
+ "plugin_id": "demo_plugin",
+ "config_scope": "model",
+ "config_data": {"models": []},
+ "config_version": "",
+ },
+ )
+
+ response = await runner._handle_config_updated(envelope)
+
+ assert response.payload["acknowledged"] is True
+ assert plugin.configs == [{"plugin_enabled": True}]
+ assert plugin.updates == [("model", {"models": []}, "", [{"plugin_enabled": True}])]
@pytest.mark.asyncio
async def test_runner_bootstraps_capabilities_before_on_load(self, monkeypatch):
@@ -911,6 +957,120 @@ class TestDependencyResolution:
assert loader.failed_plugins == {}
assert loaded[0].instance.answer() == 42
+ def test_loader_requires_sdk_plugin_to_override_on_config_update(self, tmp_path):
+ from src.plugin_runtime.runner.plugin_loader import PluginLoader
+
+ plugin_root = tmp_path / "plugins"
+ plugin_root.mkdir()
+ plugin_dir = plugin_root / "demo_plugin"
+ plugin_dir.mkdir()
+
+ (plugin_dir / "_manifest.json").write_text(
+ json.dumps(
+ {
+ "name": "demo_plugin",
+ "version": "1.0.0",
+ "description": "demo",
+ "author": "tester",
+ }
+ ),
+ encoding="utf-8",
+ )
+ (plugin_dir / "plugin.py").write_text(
+ "from maibot_sdk import MaiBotPlugin\n\n"
+ "class DemoPlugin(MaiBotPlugin):\n"
+ " async def on_load(self):\n"
+ " pass\n\n"
+ " async def on_unload(self):\n"
+ " pass\n\n"
+ "def create_plugin():\n"
+ " return DemoPlugin()\n",
+ encoding="utf-8",
+ )
+
+ loader = PluginLoader()
+ loaded = loader.discover_and_load([str(plugin_root)])
+
+ assert loaded == []
+ assert "demo_plugin" in loader.failed_plugins
+ assert "on_config_update" in loader.failed_plugins["demo_plugin"]
+
+ def test_loader_requires_sdk_plugin_to_override_on_load(self, tmp_path):
+ from src.plugin_runtime.runner.plugin_loader import PluginLoader
+
+ plugin_root = tmp_path / "plugins"
+ plugin_root.mkdir()
+ plugin_dir = plugin_root / "demo_plugin"
+ plugin_dir.mkdir()
+
+ (plugin_dir / "_manifest.json").write_text(
+ json.dumps(
+ {
+ "name": "demo_plugin",
+ "version": "1.0.0",
+ "description": "demo",
+ "author": "tester",
+ }
+ ),
+ encoding="utf-8",
+ )
+ (plugin_dir / "plugin.py").write_text(
+ "from maibot_sdk import MaiBotPlugin\n\n"
+ "class DemoPlugin(MaiBotPlugin):\n"
+ " async def on_unload(self):\n"
+ " pass\n\n"
+ " async def on_config_update(self, scope, config_data, version):\n"
+ " pass\n\n"
+ "def create_plugin():\n"
+ " return DemoPlugin()\n",
+ encoding="utf-8",
+ )
+
+ loader = PluginLoader()
+ loaded = loader.discover_and_load([str(plugin_root)])
+
+ assert loaded == []
+ assert "demo_plugin" in loader.failed_plugins
+ assert "on_load" in loader.failed_plugins["demo_plugin"]
+
+ def test_loader_requires_sdk_plugin_to_override_on_unload(self, tmp_path):
+ from src.plugin_runtime.runner.plugin_loader import PluginLoader
+
+ plugin_root = tmp_path / "plugins"
+ plugin_root.mkdir()
+ plugin_dir = plugin_root / "demo_plugin"
+ plugin_dir.mkdir()
+
+ (plugin_dir / "_manifest.json").write_text(
+ json.dumps(
+ {
+ "name": "demo_plugin",
+ "version": "1.0.0",
+ "description": "demo",
+ "author": "tester",
+ }
+ ),
+ encoding="utf-8",
+ )
+ (plugin_dir / "plugin.py").write_text(
+ "from maibot_sdk import MaiBotPlugin\n\n"
+ "class DemoPlugin(MaiBotPlugin):\n"
+ " async def on_load(self):\n"
+ " pass\n\n"
+ " async def on_config_update(self, scope, config_data, version):\n"
+ " pass\n\n"
+ "def create_plugin():\n"
+ " return DemoPlugin()\n",
+ encoding="utf-8",
+ )
+
+ loader = PluginLoader()
+ loaded = loader.discover_and_load([str(plugin_root)])
+
+ assert loaded == []
+ assert "demo_plugin" in loader.failed_plugins
+ assert "on_unload" in loader.failed_plugins["demo_plugin"]
+
def test_isolate_sys_path_preserves_plugin_dirs(self):
from src.plugin_runtime.runner import runner_main
@@ -2299,9 +2459,10 @@ class TestIntegration:
assert refresh_calls == [True]
@pytest.mark.asyncio
- async def test_handle_plugin_config_changes_only_reload_target_plugin(self, monkeypatch, tmp_path):
+ async def test_handle_plugin_config_changes_only_notify_target_plugin(self, monkeypatch, tmp_path):
from src.plugin_runtime import integration as integration_module
from src.config.file_watcher import FileChange
+ import json
builtin_root = tmp_path / "src" / "plugins" / "built_in"
thirdparty_root = tmp_path / "plugins"
@@ -2311,6 +2472,10 @@ class TestIntegration:
beta_dir.mkdir(parents=True)
(alpha_dir / "config.toml").write_text("enabled = true\n", encoding="utf-8")
(beta_dir / "config.toml").write_text("enabled = false\n", encoding="utf-8")
+ (alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
+ (beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
+ (alpha_dir / "_manifest.json").write_text(json.dumps({"name": "alpha"}), encoding="utf-8")
+ (beta_dir / "_manifest.json").write_text(json.dumps({"name": "beta"}), encoding="utf-8")
monkeypatch.chdir(tmp_path)
@@ -2318,31 +2483,95 @@ class TestIntegration:
def __init__(self, plugin_dirs, plugins):
self._plugin_dirs = plugin_dirs
self._registered_plugins = {plugin_id: object() for plugin_id in plugins}
- self.reload_calls = []
+ self.config_updates = []
- async def reload_plugin(self, plugin_id, reason="manual"):
- self.reload_calls.append((plugin_id, reason))
+ async def notify_plugin_config_updated(
+ self,
+ plugin_id,
+ config_data,
+ config_version="",
+ config_scope="self",
+ ):
+ self.config_updates.append((plugin_id, config_data, config_version, config_scope))
return True
manager = integration_module.PluginRuntimeManager()
manager._started = True
manager._builtin_supervisor = FakeSupervisor([builtin_root], ["alpha"])
manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["beta"])
- refresh_calls = []
-
- def fake_refresh() -> None:
- refresh_calls.append(True)
-
- manager._refresh_plugin_config_watch_subscriptions = fake_refresh
await manager._handle_plugin_config_changes(
"alpha",
[FileChange(change_type=1, path=alpha_dir / "config.toml")],
)
- assert manager._builtin_supervisor.reload_calls == [("alpha", "config_file_changed")]
- assert manager._third_party_supervisor.reload_calls == []
- assert refresh_calls == [True]
+ assert manager._builtin_supervisor.config_updates == [("alpha", {"enabled": True}, "", "self")]
+ assert manager._third_party_supervisor.config_updates == []
+
+ @pytest.mark.asyncio
+ async def test_handle_main_config_reload_only_notifies_subscribers(self, monkeypatch):
+ from src.plugin_runtime import integration as integration_module
+
+ class FakeRegistration:
+ def __init__(self, subscriptions):
+ self.config_reload_subscriptions = subscriptions
+
+ class FakeSupervisor:
+ def __init__(self, registrations):
+ self._registered_plugins = registrations
+ self.config_updates = []
+
+ def get_config_reload_subscribers(self, scope):
+ matched_plugins = []
+ for plugin_id, registration in self._registered_plugins.items():
+ if scope in registration.config_reload_subscriptions:
+ matched_plugins.append(plugin_id)
+ return matched_plugins
+
+ async def notify_plugin_config_updated(
+ self,
+ plugin_id,
+ config_data,
+ config_version="",
+ config_scope="self",
+ ):
+ self.config_updates.append((plugin_id, config_data, config_version, config_scope))
+ return True
+
+ fake_global = SimpleNamespace(plugin_runtime=SimpleNamespace(enabled=True))
+ monkeypatch.setattr(
+ integration_module.config_manager,
+ "get_global_config",
+ lambda: SimpleNamespace(model_dump=lambda: {"bot": {"name": "MaiBot"}}, plugin_runtime=fake_global.plugin_runtime),
+ )
+ monkeypatch.setattr(
+ integration_module.config_manager,
+ "get_model_config",
+ lambda: SimpleNamespace(model_dump=lambda: {"models": [{"name": "demo"}]}),
+ )
+
+ manager = integration_module.PluginRuntimeManager()
+ manager._started = True
+ manager._builtin_supervisor = FakeSupervisor(
+ {
+ "alpha": FakeRegistration(["bot"]),
+ "beta": FakeRegistration([]),
+ }
+ )
+ manager._third_party_supervisor = FakeSupervisor(
+ {
+ "gamma": FakeRegistration(["model"]),
+ }
+ )
+
+ await manager._handle_main_config_reload(["bot", "model"])
+
+ assert manager._builtin_supervisor.config_updates == [
+ ("alpha", {"bot": {"name": "MaiBot"}}, "", "bot")
+ ]
+ assert manager._third_party_supervisor.config_updates == [
+ ("gamma", {"models": [{"name": "demo"}]}, "", "model")
+ ]
def test_refresh_plugin_config_watch_subscriptions_registers_per_plugin(self, tmp_path):
from src.plugin_runtime import integration as integration_module
diff --git a/src/config/config.py b/src/config/config.py
index ff5941bf..bee81efb 100644
--- a/src/config/config.py
+++ b/src/config/config.py
@@ -4,6 +4,7 @@ from typing import Any, Callable, Mapping, Sequence, TypeVar
import asyncio
import copy
+import inspect
import sys
import tomlkit
@@ -61,6 +62,7 @@ MODEL_CONFIG_VERSION: str = "1.12.0"
logger = get_logger("config")
T = TypeVar("T", bound="ConfigBase")
+ConfigReloadCallback = Callable[[Sequence[str]], object] | Callable[[], object]
class Config(ConfigBase):
@@ -190,7 +192,7 @@ class ConfigManager:
self.global_config: Config | None = None
self.model_config: ModelConfig | None = None
self._reload_lock: asyncio.Lock = asyncio.Lock()
- self._reload_callbacks: list[Callable[[], object]] = []
+ self._reload_callbacks: list[ConfigReloadCallback] = []
self._file_watcher: FileWatcher | None = None
self._file_watcher_subscription_id: str | None = None
self._hot_reload_min_interval_s: float = 1.0
@@ -226,16 +228,125 @@ class ConfigManager:
raise RuntimeError(t("config.model_not_initialized"))
return self.model_config
- def register_reload_callback(self, callback: Callable[[], object]) -> None:
+ def register_reload_callback(self, callback: ConfigReloadCallback) -> None:
+ """注册配置热重载回调。
+
+ Args:
+ callback: 配置热重载回调。允许无参回调,也允许接收
+ ``Sequence[str]`` 类型的变更范围列表。
+ """
+
self._reload_callbacks.append(callback)
- def unregister_reload_callback(self, callback: Callable[[], object]) -> None:
+ def unregister_reload_callback(self, callback: ConfigReloadCallback) -> None:
+ """注销配置热重载回调。
+
+ Args:
+ callback: 先前注册过的回调对象。
+ """
+
try:
self._reload_callbacks.remove(callback)
except ValueError:
return
- async def reload_config(self) -> bool:
+ @staticmethod
+ def _normalize_changed_scopes(changed_scopes: Sequence[str] | None) -> tuple[str, ...]:
+ """规范化配置变更范围列表。
+
+ Args:
+ changed_scopes: 原始配置变更范围。
+
+ Returns:
+ tuple[str, ...]: 去重后的配置变更范围元组。
+ """
+
+ if not changed_scopes:
+ return ("bot", "model")
+
+ normalized_scopes: list[str] = []
+ for scope in changed_scopes:
+ normalized_scope = str(scope or "").strip().lower()
+ if normalized_scope not in {"bot", "model"}:
+ continue
+ if normalized_scope not in normalized_scopes:
+ normalized_scopes.append(normalized_scope)
+ return tuple(normalized_scopes)
+
+ @staticmethod
+ def _resolve_changed_scopes(changes: Sequence[FileChange]) -> tuple[str, ...]:
+ """根据文件变更列表推断配置变更范围。
+
+ Args:
+ changes: 文件监听器返回的变更列表。
+
+ Returns:
+ tuple[str, ...]: 命中的配置变更范围元组。
+ """
+
+ changed_scopes: list[str] = []
+ for change in changes:
+ file_name = change.path.name
+ if file_name == "bot_config.toml" and "bot" not in changed_scopes:
+ changed_scopes.append("bot")
+ if file_name == "model_config.toml" and "model" not in changed_scopes:
+ changed_scopes.append("model")
+ return tuple(changed_scopes)
+
+ @staticmethod
+ def _callback_accepts_scopes(callback: ConfigReloadCallback) -> bool:
+ """判断回调是否接收配置变更范围参数。
+
+ Args:
+ callback: 待检测的回调对象。
+
+ Returns:
+ bool: 若回调可接收一个位置参数或可变位置参数,则返回 ``True``。
+ """
+
+ try:
+ parameters = inspect.signature(callback).parameters.values()
+ except (TypeError, ValueError):
+ return False
+
+ positional_params = {
+ inspect.Parameter.POSITIONAL_ONLY,
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ }
+ for parameter in parameters:
+ if parameter.kind == inspect.Parameter.VAR_POSITIONAL:
+ return True
+ if parameter.kind in positional_params:
+ return True
+ return False
+
+ async def _invoke_reload_callback(
+ self,
+ callback: ConfigReloadCallback,
+ changed_scopes: Sequence[str],
+ ) -> None:
+ """执行单个配置热重载回调。
+
+ Args:
+ callback: 要执行的回调对象。
+ changed_scopes: 本次热重载命中的配置范围。
+ """
+
+ result = callback(changed_scopes) if self._callback_accepts_scopes(callback) else callback()
+ if asyncio.iscoroutine(result):
+ await result
+
+ async def reload_config(self, changed_scopes: Sequence[str] | None = None) -> bool:
+ """重新加载主配置和模型配置。
+
+ Args:
+ changed_scopes: 本次触发热重载的配置范围。
+
+ Returns:
+ bool: 是否重载成功。
+ """
+
+ normalized_scopes = self._normalize_changed_scopes(changed_scopes)
async with self._reload_lock:
try:
global_config_new, global_updated = load_config_from_file(
@@ -265,9 +376,7 @@ class ConfigManager:
for callback in list(self._reload_callbacks):
try:
- result = callback()
- if asyncio.iscoroutine(result):
- await result
+ await self._invoke_reload_callback(callback, normalized_scopes)
except Exception as exc:
logger.warning(t("config.reload_callback_failed", error=exc))
return True
@@ -312,6 +421,12 @@ class ConfigManager:
self._file_watcher = None
async def _handle_file_changes(self, changes: Sequence[FileChange]) -> None:
+ """处理主配置与模型配置文件变更。
+
+ Args:
+ changes: 当前批次收集到的文件变更列表。
+ """
+
if not changes:
return
now_monotonic = asyncio.get_running_loop().time()
@@ -321,7 +436,11 @@ class ConfigManager:
self._last_hot_reload_monotonic = now_monotonic
logger.info(t("config.file_change_detected"))
try:
- await asyncio.wait_for(self.reload_config(), timeout=self._hot_reload_timeout_s)
+ changed_scopes = self._resolve_changed_scopes(changes)
+ await asyncio.wait_for(
+ self.reload_config(changed_scopes=changed_scopes),
+ timeout=self._hot_reload_timeout_s,
+ )
except asyncio.TimeoutError:
logger.error(t("config.reload_timeout", timeout_seconds=self._hot_reload_timeout_s))
diff --git a/src/learners/expression_auto_check_task.py b/src/learners/expression_auto_check_task.py
index 53b151b2..e5af1057 100644
--- a/src/learners/expression_auto_check_task.py
+++ b/src/learners/expression_auto_check_task.py
@@ -3,15 +3,15 @@
功能:
1. 定期随机选取指定数量的表达方式
-2. 使用LLM进行评估
+2. 使用 LLM 进行评估
3. 通过评估的:rejected=0, checked=1
4. 未通过评估的:rejected=1, checked=1
"""
-from typing import List
import asyncio
import json
import random
+from typing import List
from sqlmodel import select
@@ -146,7 +146,8 @@ class ExpressionAutoCheckTask(AsyncTask):
选中的表达方式列表
"""
try:
- with get_db_session() as session:
+ # 这里只做查询,避免退出上下文时自动提交导致 ORM 实例过期。
+ with get_db_session(auto_commit=False) as session:
statement = select(Expression)
all_expressions = session.exec(statement).all()
diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py
index 1add64c6..afe944e5 100644
--- a/src/plugin_runtime/host/supervisor.py
+++ b/src/plugin_runtime/host/supervisor.py
@@ -399,6 +399,7 @@ class PluginRunnerSupervisor:
plugin_id: str,
config_data: Optional[Dict[str, Any]] = None,
config_version: str = "",
+ config_scope: str = "self",
) -> bool:
"""向 Runner 推送插件配置更新。
@@ -406,12 +407,14 @@ class PluginRunnerSupervisor:
plugin_id: 目标插件 ID。
config_data: 配置内容。
config_version: 配置版本号。
+ config_scope: 配置变更范围。
Returns:
bool: 请求是否成功送达并被 Runner 接受。
"""
payload = ConfigUpdatedPayload(
plugin_id=plugin_id,
+ config_scope=config_scope,
config_version=config_version,
config_data=config_data or {},
)
@@ -428,6 +431,22 @@ class PluginRunnerSupervisor:
return bool(response.payload.get("acknowledged", False))
+ def get_config_reload_subscribers(self, scope: str) -> List[str]:
+ """返回订阅指定全局配置广播的插件列表。
+
+ Args:
+ scope: 配置变更范围,仅支持 ``bot`` 或 ``model``。
+
+ Returns:
+ List[str]: 已声明订阅该范围的插件 ID 列表。
+ """
+
+ matched_plugins: List[str] = []
+ for plugin_id, registration in self._registered_plugins.items():
+ if scope in registration.config_reload_subscriptions:
+ matched_plugins.append(plugin_id)
+ return matched_plugins
+
async def _wait_for_runner_connection(self, timeout_sec: float) -> None:
"""等待 Runner 建立 RPC 连接。
diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py
index ff51f419..e45b40de 100644
--- a/src/plugin_runtime/integration.py
+++ b/src/plugin_runtime/integration.py
@@ -16,7 +16,7 @@ import json
import tomlkit
from src.common.logger import get_logger
-from src.config.config import global_config
+from src.config.config import config_manager
from src.config.file_watcher import FileChange, FileWatcher
from src.platform_io import DeliveryBatch, InboundMessageEnvelope, get_platform_io_manager
from src.plugin_runtime.capabilities import (
@@ -69,6 +69,8 @@ class PluginRuntimeManager(
self._plugin_source_watcher_subscription_id: Optional[str] = None
self._plugin_config_watcher_subscriptions: Dict[str, Tuple[Path, str]] = {}
self._plugin_path_cache: Dict[str, Path] = {}
+ self._config_reload_callback: Callable[[Sequence[str]], Awaitable[None]] = self._handle_main_config_reload
+ self._config_reload_callback_registered: bool = False
async def _dispatch_platform_inbound(self, envelope: InboundMessageEnvelope) -> None:
"""接收 Platform IO 审核后的入站消息并送入主消息链。
@@ -108,7 +110,7 @@ class PluginRuntimeManager(
logger.warning("PluginRuntimeManager 已在运行中,跳过重复启动")
return
- _cfg = global_config.plugin_runtime
+ _cfg = config_manager.get_global_config().plugin_runtime
if not _cfg.enabled:
logger.info("插件运行时已在配置中禁用,跳过启动")
return
@@ -166,11 +168,16 @@ class PluginRuntimeManager(
await self._third_party_supervisor.start()
started_supervisors.append(self._third_party_supervisor)
await self._start_plugin_file_watcher()
+ config_manager.register_reload_callback(self._config_reload_callback)
+ self._config_reload_callback_registered = True
self._started = True
logger.info(f"插件运行时已启动 — 内置: {builtin_dirs or '无'}, 第三方: {third_party_dirs or '无'}")
except Exception as e:
logger.error(f"插件运行时启动失败: {e}", exc_info=True)
await self._stop_plugin_file_watcher()
+ if self._config_reload_callback_registered:
+ config_manager.unregister_reload_callback(self._config_reload_callback)
+ self._config_reload_callback_registered = False
await asyncio.gather(*(sv.stop() for sv in started_supervisors), return_exceptions=True)
platform_io_manager.clear_inbound_dispatcher()
try:
@@ -188,6 +195,9 @@ class PluginRuntimeManager(
platform_io_manager = get_platform_io_manager()
await self._stop_plugin_file_watcher()
+ if self._config_reload_callback_registered:
+ config_manager.unregister_reload_callback(self._config_reload_callback)
+ self._config_reload_callback_registered = False
coroutines: List[Coroutine[Any, Any, None]] = []
if self._builtin_supervisor:
@@ -233,6 +243,7 @@ class PluginRuntimeManager(
plugin_id: str,
config_data: Optional[Dict[str, Any]] = None,
config_version: str = "",
+ config_scope: str = "self",
) -> bool:
"""向拥有该插件的 Supervisor 推送配置更新事件。
@@ -240,6 +251,7 @@ class PluginRuntimeManager(
plugin_id: 插件 ID
config_data: 可选的配置数据(如果为 None 则由 Supervisor 从磁盘加载)
config_version: 可选的配置版本字符串,供 Supervisor 进行版本控制
+ config_scope: 配置变更范围。
"""
if not self._started:
return False
@@ -258,12 +270,67 @@ class PluginRuntimeManager(
if config_data is not None
else self._load_plugin_config_for_supervisor(sv, plugin_id)
)
- await sv.notify_plugin_config_updated(
+ return await sv.notify_plugin_config_updated(
plugin_id=plugin_id,
config_data=config_payload,
config_version=config_version,
+ config_scope=config_scope,
)
- return True
+
+ @staticmethod
+ def _normalize_config_reload_scopes(changed_scopes: Sequence[str]) -> tuple[str, ...]:
+ """规范化配置热重载范围列表。
+
+ Args:
+ changed_scopes: 原始配置热重载范围列表。
+
+ Returns:
+ tuple[str, ...]: 去重后的有效配置范围元组。
+ """
+
+ normalized_scopes: list[str] = []
+ for scope in changed_scopes:
+ normalized_scope = str(scope or "").strip().lower()
+ if normalized_scope not in {"bot", "model"}:
+ continue
+ if normalized_scope not in normalized_scopes:
+ normalized_scopes.append(normalized_scope)
+ return tuple(normalized_scopes)
+
+ async def _broadcast_config_reload(self, scope: str, config_data: Dict[str, Any]) -> None:
+ """向订阅指定范围的插件广播配置热重载。
+
+ Args:
+ scope: 配置变更范围,仅支持 ``bot`` 或 ``model``。
+ config_data: 最新配置数据。
+ """
+
+ for supervisor in self.supervisors:
+ for plugin_id in supervisor.get_config_reload_subscribers(scope):
+ delivered = await supervisor.notify_plugin_config_updated(
+ plugin_id=plugin_id,
+ config_data=config_data,
+ config_version="",
+ config_scope=scope,
+ )
+ if not delivered:
+ logger.warning(f"向插件 {plugin_id} 广播 {scope} 配置热重载失败")
+
+ async def _handle_main_config_reload(self, changed_scopes: Sequence[str]) -> None:
+ """处理 bot/model 主配置热重载广播。
+
+ Args:
+ changed_scopes: 本次热重载命中的配置范围列表。
+ """
+
+ if not self._started:
+ return
+
+ normalized_scopes = self._normalize_config_reload_scopes(changed_scopes)
+ if "bot" in normalized_scopes:
+ await self._broadcast_config_reload("bot", config_manager.get_global_config().model_dump())
+ if "model" in normalized_scopes:
+ await self._broadcast_config_reload("model", config_manager.get_model_config().model_dump())
# ─── 事件桥接 ──────────────────────────────────────────────
@@ -612,16 +679,12 @@ class PluginRuntimeManager(
return None if plugin_path is None else plugin_path / "config.toml"
async def _handle_plugin_config_changes(self, plugin_id: str, changes: Sequence[FileChange]) -> None:
- """处理单个插件配置文件变化,并精确重载目标插件。
+ """处理单个插件配置文件变化,并定向派发自配置热更新。
Args:
plugin_id: 发生配置变更的插件 ID。
changes: 当前批次收集到的配置文件变更列表。
- Notes:
- 这里选择“精确重载该插件”,而不是仅推送软性的配置更新通知。
- 这样可以保证没有实现 ``on_config_update()`` 的插件也能重新执行
- ``on_load()``,让磁盘上的 ``config.toml`` 修改对插件运行态真正生效。
"""
if not self._started or not changes:
return
@@ -636,15 +699,15 @@ class PluginRuntimeManager(
return
try:
- self._load_plugin_config_for_supervisor(supervisor, plugin_id)
- reload_success = await supervisor.reload_plugin(
+ config_payload = self._load_plugin_config_for_supervisor(supervisor, plugin_id)
+ delivered = await supervisor.notify_plugin_config_updated(
plugin_id=plugin_id,
- reason="config_file_changed",
+ config_data=config_payload,
+ config_version="",
+ config_scope="self",
)
- if reload_success:
- self._refresh_plugin_config_watch_subscriptions()
- else:
- logger.warning(f"插件 {plugin_id} 配置文件变更后重载失败")
+ if not delivered:
+ logger.warning(f"插件 {plugin_id} 配置文件变更后通知失败")
except Exception as exc:
logger.warning(f"插件 {plugin_id} 配置文件变更处理失败: {exc}")
@@ -652,8 +715,8 @@ class PluginRuntimeManager(
"""处理插件源码相关变化。
这里仅负责源码、清单等会影响插件装载状态的文件;配置文件的变化会由
- 单独的 per-plugin watcher 处理,并精确重载对应插件,避免放大成
- 不必要的跨插件 reload。
+ 单独的 per-plugin watcher 处理,并定向派发给目标插件的
+ ``on_config_update()``,避免放大成不必要的跨插件 reload。
"""
if not self._started or not changes:
return
diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py
index cbbb71be..6078e4dc 100644
--- a/src/plugin_runtime/protocol/envelope.py
+++ b/src/plugin_runtime/protocol/envelope.py
@@ -29,6 +29,14 @@ class MessageType(str, Enum):
BROADCAST = "broadcast"
+class ConfigReloadScope(str, Enum):
+ """配置热重载范围。"""
+
+ SELF = "self"
+ BOT = "bot"
+ MODEL = "model"
+
+
# ====== 请求 ID 生成器 ======
class RequestIdGenerator:
"""单调递增 int64 请求 ID 生成器"""
@@ -158,6 +166,8 @@ class RegisterPluginPayload(BaseModel):
"""组件列表"""
capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表")
"""所需能力列表"""
+ config_reload_subscriptions: List[str] = Field(default_factory=list, description="订阅的全局配置热重载范围")
+ """订阅的全局配置热重载范围"""
class BootstrapPluginPayload(BaseModel):
@@ -236,6 +246,8 @@ class ConfigUpdatedPayload(BaseModel):
plugin_id: str = Field(description="插件 ID")
"""插件 ID"""
+ config_scope: ConfigReloadScope = Field(description="配置变更范围")
+ """配置变更范围"""
config_version: str = Field(description="新配置版本")
"""新配置版本"""
config_data: Dict[str, Any] = Field(default_factory=dict, description="配置内容")
diff --git a/src/plugin_runtime/runner/plugin_loader.py b/src/plugin_runtime/runner/plugin_loader.py
index 90c8bf47..a766eb04 100644
--- a/src/plugin_runtime/runner/plugin_loader.py
+++ b/src/plugin_runtime/runner/plugin_loader.py
@@ -403,6 +403,7 @@ class PluginLoader:
create_plugin = getattr(module, "create_plugin", None)
if create_plugin is not None:
instance = create_plugin()
+ self._validate_sdk_plugin_contract(plugin_id, instance)
logger.info(f"插件 {plugin_id} v{manifest.get('version', '?')} 加载成功")
return PluginMeta(
plugin_id=plugin_id,
@@ -432,6 +433,35 @@ class PluginLoader:
logger.error(f"插件 {plugin_id} 缺少 create_plugin 工厂函数且未检测到旧版 BasePlugin")
return None
+ @staticmethod
+ def _validate_sdk_plugin_contract(plugin_id: str, instance: Any) -> None:
+ """校验 SDK 插件的基础契约。
+
+ Args:
+ plugin_id: 当前插件 ID。
+ instance: ``create_plugin()`` 返回的插件实例。
+
+ Raises:
+ TypeError: 当插件未覆盖必需生命周期方法或订阅声明不合法时抛出。
+ """
+
+ try:
+ from maibot_sdk.plugin import MaiBotPlugin
+ except ImportError:
+ return
+
+ if not isinstance(instance, MaiBotPlugin):
+ return
+
+ if type(instance).on_load is MaiBotPlugin.on_load:
+ raise TypeError(f"插件 {plugin_id} 必须实现 on_load()")
+ if type(instance).on_unload is MaiBotPlugin.on_unload:
+ raise TypeError(f"插件 {plugin_id} 必须实现 on_unload()")
+ if type(instance).on_config_update is MaiBotPlugin.on_config_update:
+ raise TypeError(f"插件 {plugin_id} 必须实现 on_config_update()")
+
+ instance.get_config_reload_subscriptions()
+
@staticmethod
@contextlib.contextmanager
def _temporary_sys_path_entry(path: Path) -> Iterator[None]:
diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py
index 4bee714c..b94b01d1 100644
--- a/src/plugin_runtime/runner/runner_main.py
+++ b/src/plugin_runtime/runner/runner_main.py
@@ -27,6 +27,7 @@ from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIR
from src.plugin_runtime.protocol.envelope import (
BootstrapPluginPayload,
ComponentDeclaration,
+ ConfigUpdatedPayload,
Envelope,
HealthPayload,
InvokePayload,
@@ -342,6 +343,7 @@ class PluginRunner:
"""
# 收集插件组件声明
components: List[ComponentDeclaration] = []
+ config_reload_subscriptions: List[str] = []
instance = meta.instance
# 从插件实例获取组件声明(SDK 插件须实现 get_components 方法)
@@ -355,12 +357,15 @@ class PluginRunner:
)
for comp_info in instance.get_components()
)
+ if hasattr(instance, "get_config_reload_subscriptions"):
+ config_reload_subscriptions = list(instance.get_config_reload_subscriptions())
reg_payload = RegisterPluginPayload(
plugin_id=meta.plugin_id,
plugin_version=meta.version,
components=components,
capabilities_required=meta.capabilities_required,
+ config_reload_subscriptions=config_reload_subscriptions,
)
try:
@@ -911,18 +916,28 @@ class PluginRunner:
return envelope.make_response(payload={"acknowledged": True})
async def _handle_config_updated(self, envelope: Envelope) -> Envelope:
- """处理配置更新事件"""
+ """处理配置更新事件。"""
+ try:
+ payload = ConfigUpdatedPayload.model_validate(envelope.payload)
+ except Exception as exc:
+ return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
+
plugin_id = envelope.plugin_id
if meta := self._loader.get_plugin(plugin_id):
try:
- config_data = envelope.payload.get("config_data", {})
- config_version = envelope.payload.get("config_version", "")
- self._apply_plugin_config(meta, config_data=config_data)
- if hasattr(meta.instance, "on_config_update"):
- ret = meta.instance.on_config_update(config_data, config_version)
- # 兼容同步和异步的 on_config_update 实现
- if asyncio.iscoroutine(ret):
- await ret
+ config_scope = payload.config_scope.value
+ if config_scope == "self":
+ self._apply_plugin_config(meta, config_data=payload.config_data)
+ if not hasattr(meta.instance, "on_config_update"):
+ raise AttributeError("插件缺少 on_config_update() 实现")
+
+ ret = meta.instance.on_config_update(
+ config_scope,
+ payload.config_data,
+ payload.config_version,
+ )
+ if asyncio.iscoroutine(ret):
+ await ret
except Exception as e:
logger.error(f"插件 {plugin_id} 配置更新失败: {e}")
return envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
diff --git a/src/plugins/built_in/emoji_plugin/plugin.py b/src/plugins/built_in/emoji_plugin/plugin.py
index b946931b..cc6b87c5 100644
--- a/src/plugins/built_in/emoji_plugin/plugin.py
+++ b/src/plugins/built_in/emoji_plugin/plugin.py
@@ -3,11 +3,11 @@
根据聊天上下文的情感,使用 LLM 选择并发送合适的表情包。
"""
-import random
-
-from maibot_sdk import MaiBotPlugin, Action
+from maibot_sdk import Action, MaiBotPlugin
from maibot_sdk.types import ActivationType
+import random
+
class EmojiPlugin(MaiBotPlugin):
"""表情包插件"""
@@ -95,10 +95,35 @@ class EmojiPlugin(MaiBotPlugin):
return True, f"成功发送表情包:[表情包:{chosen_emotion}]"
return False, "发送表情包失败"
- async def on_load(self):
+ async def on_load(self) -> None:
+ """处理插件加载。"""
+
# 从插件配置读取 emoji_chance 来覆盖默认概率
await self.ctx.config.get("emoji.emoji_chance")
+ async def on_unload(self) -> None:
+ """处理插件卸载。"""
+
+ async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None:
+ """处理配置热重载事件。
+
+ Args:
+ scope: 配置变更范围。
+ config_data: 最新配置数据。
+ version: 配置版本号。
+ """
+
+ del config_data
+ del version
+ if scope == "self":
+ await self.ctx.config.get("emoji.emoji_chance")
+
+
+def create_plugin() -> EmojiPlugin:
+ """创建 Emoji 插件实例。
+
+ Returns:
+ EmojiPlugin: 新的 Emoji 插件实例。
+ """
-def create_plugin():
return EmojiPlugin()
diff --git a/src/plugins/built_in/plugin_management/plugin.py b/src/plugins/built_in/plugin_management/plugin.py
index fe0888c6..aa2da795 100644
--- a/src/plugins/built_in/plugin_management/plugin.py
+++ b/src/plugins/built_in/plugin_management/plugin.py
@@ -3,7 +3,7 @@
通过 /pm 命令管理插件和组件的生命周期。
"""
-from maibot_sdk import MaiBotPlugin, Command
+from maibot_sdk import Command, MaiBotPlugin
_VALID_COMPONENT_TYPES = ("action", "command", "event_handler")
@@ -44,6 +44,12 @@ HELP_COMPONENT = (
class PluginManagementPlugin(MaiBotPlugin):
"""插件和组件管理插件"""
+ async def on_load(self) -> None:
+ """处理插件加载。"""
+
+ async def on_unload(self) -> None:
+ """处理插件卸载。"""
+
@Command(
"management",
description="管理插件和组件的生命周期",
@@ -268,6 +274,25 @@ class PluginManagementPlugin(MaiBotPlugin):
return components
return []
+ async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None:
+ """处理配置热重载事件。
+
+ Args:
+ scope: 配置变更范围。
+ config_data: 最新配置数据。
+ version: 配置版本号。
+ """
+
+ del scope
+ del config_data
+ del version
+
+
+def create_plugin() -> PluginManagementPlugin:
+ """创建插件管理插件实例。
+
+ Returns:
+ PluginManagementPlugin: 新的插件管理插件实例。
+ """
-def create_plugin():
return PluginManagementPlugin()
From 7a304ba54964273e52f8d6fa3d6aee7612164be0 Mon Sep 17 00:00:00 2001
From: DrSmoothl <1787882683@qq.com>
Date: Mon, 23 Mar 2026 21:01:55 +0800
Subject: [PATCH 35/45] feat: Enhance API and Outbound Tracking Functionality
- Add test for fallback to bot account in platform IO route metadata when context message is absent.
- Improve PlatformIOManager to avoid duplicate driver entries and streamline fallback driver handling.
- Refactor OutboundTracker to support tracking by both internal message ID and driver ID, enhancing the uniqueness of pending records.
- Introduce dynamic API capabilities in RuntimeComponent, allowing plugins to replace their dynamic API lists.
- Update APIRegistry to manage dynamic APIs more effectively, including registration and toggling of API statuses.
- Implement authorization checks for dynamic API capabilities to ensure proper permissions.
- Restrict direct calls to certain host RPC methods from plugins for enhanced security.
- Refactor send_service to ensure fallback to current platform account when no context message is available.
---
pytests/test_platform_io_legacy_driver.py | 56 ++-
pytests/test_plugin_runtime_api.py | 230 +++++++++++++
pytests/test_send_service.py | 13 +
src/platform_io/manager.py | 21 +-
src/platform_io/outbound_tracker.py | 70 +++-
src/plugin_runtime/capabilities/components.py | 211 ++++++++++--
src/plugin_runtime/capabilities/registry.py | 1 +
src/plugin_runtime/host/api_registry.py | 319 +++++++++++-------
src/plugin_runtime/host/authorization.py | 5 +
src/plugin_runtime/runner/runner_main.py | 16 +-
src/services/send_service.py | 29 +-
11 files changed, 771 insertions(+), 200 deletions(-)
diff --git a/pytests/test_platform_io_legacy_driver.py b/pytests/test_platform_io_legacy_driver.py
index 2e94c1fc..76f14d8f 100644
--- a/pytests/test_platform_io_legacy_driver.py
+++ b/pytests/test_platform_io_legacy_driver.py
@@ -82,7 +82,61 @@ async def test_platform_io_uses_legacy_driver_when_no_explicit_send_route(
)
explicit_drivers = manager.resolve_drivers(RouteKey(platform="qq"))
- assert [driver.driver_id for driver in explicit_drivers] == ["plugin.qq.sender"]
+ assert [driver.driver_id for driver in explicit_drivers] == ["plugin.qq.sender", "legacy.send.qq"]
+ finally:
+ await manager.stop()
+
+
+@pytest.mark.asyncio
+async def test_platform_io_broadcasts_to_plugin_and_legacy_driver(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """同一路由命中插件驱动与 legacy driver 时,应同时广播发送。"""
+
+ manager = PlatformIOManager()
+ legacy_calls: list[dict[str, Any]] = []
+ monkeypatch.setattr(chat_utils, "get_all_bot_accounts", lambda: {"qq": "bot-qq"})
+
+ async def _fake_send_prepared_message_to_platform(message: Any, show_log: bool = True) -> bool:
+ """记录 legacy driver 调用。"""
+
+ legacy_calls.append({"message": message, "show_log": show_log})
+ return True
+
+ monkeypatch.setattr(
+ uni_message_sender,
+ "send_prepared_message_to_platform",
+ _fake_send_prepared_message_to_platform,
+ )
+
+ try:
+ await manager.ensure_send_pipeline_ready()
+
+ plugin_driver = _PluginDriver(driver_id="plugin.qq.sender", platform="qq")
+ await manager.add_driver(plugin_driver)
+ manager.bind_send_route(
+ RouteBinding(
+ route_key=RouteKey(platform="qq"),
+ driver_id=plugin_driver.driver_id,
+ driver_kind=plugin_driver.descriptor.kind,
+ )
+ )
+
+ message = type("FakeMessage", (), {"message_id": "message-1"})()
+ batch = await manager.send_message(
+ message=message,
+ route_key=RouteKey(platform="qq"),
+ metadata={"show_log": False},
+ )
+
+ assert sorted(receipt.driver_id for receipt in batch.sent_receipts) == [
+ "legacy.send.qq",
+ "plugin.qq.sender",
+ ]
+ assert batch.failed_receipts == []
+ assert len(legacy_calls) == 1
+ assert legacy_calls[0]["message"] is message
+ assert legacy_calls[0]["show_log"] is False
finally:
await manager.stop()
diff --git a/pytests/test_plugin_runtime_api.py b/pytests/test_plugin_runtime_api.py
index fca7736a..58a8e6ba 100644
--- a/pytests/test_plugin_runtime_api.py
+++ b/pytests/test_plugin_runtime_api.py
@@ -292,3 +292,233 @@ async def test_api_list_and_component_toggle_use_dedicated_registry() -> None:
)
assert enable_result["success"] is True
assert provider_supervisor.api_registry.get_api("provider", "public_api", enabled_only=True) is not None
+
+
+@pytest.mark.asyncio
+async def test_api_registry_supports_multiple_versions_with_distinct_handlers(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """同名 API 不同版本应可并存,并按版本路由到不同处理器。"""
+
+ provider_supervisor = PluginSupervisor(plugin_dirs=[])
+ consumer_supervisor = PluginSupervisor(plugin_dirs=[])
+ await _register_plugin(
+ provider_supervisor,
+ "provider",
+ [
+ {
+ "name": "render_html",
+ "component_type": "API",
+ "metadata": {
+ "description": "渲染 HTML v1",
+ "version": "1",
+ "public": True,
+ "handler_name": "handle_render_html_v1",
+ },
+ },
+ {
+ "name": "render_html",
+ "component_type": "API",
+ "metadata": {
+ "description": "渲染 HTML v2",
+ "version": "2",
+ "public": True,
+ "handler_name": "handle_render_html_v2",
+ },
+ },
+ ],
+ )
+ await _register_plugin(consumer_supervisor, "consumer", [])
+
+ captured: Dict[str, Any] = {}
+
+ async def fake_invoke_api(
+ plugin_id: str,
+ component_name: str,
+ args: Dict[str, Any] | None = None,
+ timeout_ms: int = 30000,
+ ) -> Any:
+ """模拟多版本 API 调用。"""
+
+ captured["plugin_id"] = plugin_id
+ captured["component_name"] = component_name
+ captured["args"] = args or {}
+ captured["timeout_ms"] = timeout_ms
+ return SimpleNamespace(error=None, payload={"success": True, "result": {"image": "ok"}})
+
+ monkeypatch.setattr(provider_supervisor, "invoke_api", fake_invoke_api)
+ manager = _build_manager(provider_supervisor, consumer_supervisor)
+
+ ambiguous_result = await manager._cap_api_call(
+ "consumer",
+ "api.call",
+ {
+ "api_name": "provider.render_html",
+ "args": {"html": "Hello
"},
+ },
+ )
+ assert ambiguous_result["success"] is False
+ assert "多个版本" in str(ambiguous_result["error"])
+
+ disable_ambiguous_result = await manager._cap_component_disable(
+ "consumer",
+ "component.disable",
+ {
+ "name": "provider.render_html",
+ "component_type": "API",
+ "scope": "global",
+ "stream_id": "",
+ },
+ )
+ assert disable_ambiguous_result["success"] is False
+ assert "多个版本" in str(disable_ambiguous_result["error"])
+
+ disable_v1_result = await manager._cap_component_disable(
+ "consumer",
+ "component.disable",
+ {
+ "name": "provider.render_html",
+ "component_type": "API",
+ "scope": "global",
+ "stream_id": "",
+ "version": "1",
+ },
+ )
+ assert disable_v1_result["success"] is True
+ assert provider_supervisor.api_registry.get_api("provider", "render_html", version="1", enabled_only=True) is None
+ assert provider_supervisor.api_registry.get_api("provider", "render_html", version="2", enabled_only=True) is not None
+
+ result = await manager._cap_api_call(
+ "consumer",
+ "api.call",
+ {
+ "api_name": "provider.render_html",
+ "version": "2",
+ "args": {"html": "Hello
"},
+ },
+ )
+
+ assert result == {"success": True, "result": {"image": "ok"}}
+ assert captured["plugin_id"] == "provider"
+ assert captured["component_name"] == "handle_render_html_v2"
+ assert captured["args"] == {"html": "Hello
"}
+
+
+@pytest.mark.asyncio
+async def test_api_replace_dynamic_can_offline_removed_entries(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """动态 API 替换后,被移除的 API 应返回明确下线错误。"""
+
+ supervisor = PluginSupervisor(plugin_dirs=[])
+ await _register_plugin(supervisor, "provider", [])
+ manager = _build_manager(supervisor)
+
+ captured: Dict[str, Any] = {}
+
+ async def fake_invoke_api(
+ plugin_id: str,
+ component_name: str,
+ args: Dict[str, Any] | None = None,
+ timeout_ms: int = 30000,
+ ) -> Any:
+ """模拟动态 API 调用。"""
+
+ captured["plugin_id"] = plugin_id
+ captured["component_name"] = component_name
+ captured["args"] = args or {}
+ captured["timeout_ms"] = timeout_ms
+ return SimpleNamespace(error=None, payload={"success": True, "result": {"ok": True}})
+
+ monkeypatch.setattr(supervisor, "invoke_api", fake_invoke_api)
+
+ replace_result = await manager._cap_api_replace_dynamic(
+ "provider",
+ "api.replace_dynamic",
+ {
+ "apis": [
+ {
+ "name": "mcp.search",
+ "type": "API",
+ "metadata": {
+ "version": "1",
+ "public": True,
+ "handler_name": "dynamic_search",
+ },
+ },
+ {
+ "name": "mcp.read",
+ "type": "API",
+ "metadata": {
+ "version": "1",
+ "public": True,
+ "handler_name": "dynamic_read",
+ },
+ },
+ ],
+ "offline_reason": "MCP 服务器已关闭",
+ },
+ )
+
+ assert replace_result["success"] is True
+ assert replace_result["count"] == 2
+ list_result = await manager._cap_api_list("provider", "api.list", {"plugin_id": "provider"})
+ assert {(item["name"], item["version"]) for item in list_result["apis"]} == {
+ ("mcp.read", "1"),
+ ("mcp.search", "1"),
+ }
+
+ call_result = await manager._cap_api_call(
+ "provider",
+ "api.call",
+ {
+ "api_name": "provider.mcp.search",
+ "version": "1",
+ "args": {"query": "hello"},
+ },
+ )
+ assert call_result == {"success": True, "result": {"ok": True}}
+ assert captured["component_name"] == "dynamic_search"
+ assert captured["args"]["query"] == "hello"
+ assert captured["args"]["__maibot_api_name__"] == "mcp.search"
+ assert captured["args"]["__maibot_api_version__"] == "1"
+
+ second_replace_result = await manager._cap_api_replace_dynamic(
+ "provider",
+ "api.replace_dynamic",
+ {
+ "apis": [
+ {
+ "name": "mcp.read",
+ "type": "API",
+ "metadata": {
+ "version": "1",
+ "public": True,
+ "handler_name": "dynamic_read",
+ },
+ }
+ ],
+ "offline_reason": "MCP 服务器已关闭",
+ },
+ )
+
+ assert second_replace_result["success"] is True
+ assert second_replace_result["count"] == 1
+ assert second_replace_result["offlined"] == 1
+
+ offlined_call_result = await manager._cap_api_call(
+ "provider",
+ "api.call",
+ {
+ "api_name": "provider.mcp.search",
+ "version": "1",
+ "args": {},
+ },
+ )
+ assert offlined_call_result["success"] is False
+ assert "MCP 服务器已关闭" in str(offlined_call_result["error"])
+
+ list_after_replace = await manager._cap_api_list("provider", "api.list", {"plugin_id": "provider"})
+ assert {(item["name"], item["version"]) for item in list_after_replace["apis"]} == {
+ ("mcp.read", "1"),
+ }
diff --git a/pytests/test_send_service.py b/pytests/test_send_service.py
index 4ddd4fa1..16aad080 100644
--- a/pytests/test_send_service.py
+++ b/pytests/test_send_service.py
@@ -73,6 +73,19 @@ def _build_target_stream() -> BotChatSession:
)
+def test_inherit_platform_io_route_metadata_falls_back_to_bot_account(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """没有上下文消息时,也应回填当前平台账号用于账号级路由命中。"""
+
+ monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq" if platform == "qq" else "")
+
+ metadata = send_service._inherit_platform_io_route_metadata(_build_target_stream())
+
+ assert metadata["platform_io_account_id"] == "bot-qq"
+ assert metadata["platform_io_target_user_id"] == "target-user"
+
+
@pytest.mark.asyncio
async def test_text_to_stream_delegates_to_platform_io(monkeypatch: pytest.MonkeyPatch) -> None:
"""send service 应将发送职责统一交给 Platform IO。"""
diff --git a/src/platform_io/manager.py b/src/platform_io/manager.py
index be03e35d..ab1b11e5 100644
--- a/src/platform_io/manager.py
+++ b/src/platform_io/manager.py
@@ -394,23 +394,22 @@ class PlatformIOManager:
"""
drivers: List[PlatformIODriver] = []
+ seen_driver_ids: set[str] = set()
for binding in self._send_route_table.resolve_bindings(route_key):
driver = self._driver_registry.get(binding.driver_id)
- if driver is not None:
+ if driver is not None and driver.driver_id not in seen_driver_ids:
drivers.append(driver)
- if drivers:
- return drivers
+ seen_driver_ids.add(driver.driver_id)
fallback_driver = self._legacy_send_drivers.get(route_key.platform)
- if fallback_driver is None:
- return []
+ if fallback_driver is not None:
+ descriptor = fallback_driver.descriptor
+ account_matches = descriptor.account_id is None or route_key.account_id in (None, descriptor.account_id)
+ scope_matches = descriptor.scope is None or route_key.scope in (None, descriptor.scope)
+ if account_matches and scope_matches and fallback_driver.driver_id not in seen_driver_ids:
+ drivers.append(fallback_driver)
- descriptor = fallback_driver.descriptor
- if descriptor.account_id is not None and route_key.account_id not in (None, descriptor.account_id):
- return []
- if descriptor.scope is not None and route_key.scope not in (None, descriptor.scope):
- return []
- return [fallback_driver]
+ return drivers
@staticmethod
def build_route_key_from_message(message: "SessionMessage") -> RouteKey:
diff --git a/src/platform_io/outbound_tracker.py b/src/platform_io/outbound_tracker.py
index 438aa566..3725691f 100644
--- a/src/platform_io/outbound_tracker.py
+++ b/src/platform_io/outbound_tracker.py
@@ -92,11 +92,24 @@ class OutboundTracker:
raise ValueError("ttl_seconds 必须大于 0")
self._ttl_seconds = ttl_seconds
- self._pending: Dict[str, PendingOutboundRecord] = {}
- self._pending_expire_heap: List[Tuple[float, str]] = []
+ self._pending: Dict[Tuple[str, str], PendingOutboundRecord] = {}
+ self._pending_expire_heap: List[Tuple[float, str, str]] = []
self._receipts_by_external_id: Dict[str, StoredDeliveryReceipt] = {}
self._receipt_expire_heap: List[Tuple[float, str]] = []
+ @staticmethod
+ def _build_pending_key(internal_message_id: str, driver_id: str) -> Tuple[str, str]:
+ """构造单条出站跟踪记录的唯一键。
+
+ Args:
+ internal_message_id: 内部消息 ID。
+ driver_id: 负责当前投递的驱动 ID。
+
+ Returns:
+ Tuple[str, str]: ``(internal_message_id, driver_id)`` 组合键。
+ """
+ return internal_message_id, driver_id
+
def begin_tracking(
self,
internal_message_id: str,
@@ -116,13 +129,15 @@ class OutboundTracker:
PendingOutboundRecord: 新创建的待完成记录。
Raises:
- ValueError: 当同一个 ``internal_message_id`` 已经存在未完成记录时抛出。
+ ValueError: 当同一个 ``internal_message_id`` 与 ``driver_id`` 组合已经存在
+ 未完成记录时抛出。
"""
now = time.monotonic()
self._cleanup_expired(now)
+ pending_key = self._build_pending_key(internal_message_id, driver_id)
- if internal_message_id in self._pending:
- raise ValueError(f"消息 {internal_message_id} 已存在未完成的出站跟踪记录")
+ if pending_key in self._pending:
+ raise ValueError(f"消息 {internal_message_id} 在驱动 {driver_id} 上已存在未完成的出站跟踪记录")
expires_at = now + self._ttl_seconds
record = PendingOutboundRecord(
@@ -133,8 +148,8 @@ class OutboundTracker:
expires_at=expires_at,
metadata=metadata or {},
)
- self._pending[internal_message_id] = record
- heapq.heappush(self._pending_expire_heap, (expires_at, internal_message_id))
+ self._pending[pending_key] = record
+ heapq.heappush(self._pending_expire_heap, (expires_at, internal_message_id, driver_id))
return record
def finish_tracking(self, receipt: DeliveryReceipt) -> Optional[PendingOutboundRecord]:
@@ -149,7 +164,19 @@ class OutboundTracker:
now = time.monotonic()
self._cleanup_expired(now)
- pending_record = self._pending.pop(receipt.internal_message_id, None)
+ pending_record: Optional[PendingOutboundRecord] = None
+ if receipt.driver_id:
+ pending_key = self._build_pending_key(receipt.internal_message_id, receipt.driver_id)
+ pending_record = self._pending.pop(pending_key, None)
+ else:
+ matched_records = [
+ key
+ for key, record in self._pending.items()
+ if record.internal_message_id == receipt.internal_message_id
+ ]
+ if len(matched_records) == 1:
+ pending_record = self._pending.pop(matched_records[0], None)
+
if receipt.external_message_id:
expires_at = now + self._ttl_seconds
self._receipts_by_external_id[receipt.external_message_id] = StoredDeliveryReceipt(
@@ -160,17 +187,33 @@ class OutboundTracker:
heapq.heappush(self._receipt_expire_heap, (expires_at, receipt.external_message_id))
return pending_record
- def get_pending(self, internal_message_id: str) -> Optional[PendingOutboundRecord]:
+ def get_pending(
+ self,
+ internal_message_id: str,
+ driver_id: Optional[str] = None,
+ ) -> Optional[PendingOutboundRecord]:
"""根据内部消息 ID 查询待完成记录。
Args:
internal_message_id: 要查询的内部消息 ID。
+ driver_id: 可选的驱动 ID;提供后仅返回该驱动上的待完成记录。
Returns:
Optional[PendingOutboundRecord]: 若记录仍存在,则返回对应待完成记录。
"""
self._cleanup_expired(time.monotonic())
- return self._pending.get(internal_message_id)
+
+ if driver_id:
+ return self._pending.get(self._build_pending_key(internal_message_id, driver_id))
+
+ matched_records = [
+ record
+ for record in self._pending.values()
+ if record.internal_message_id == internal_message_id
+ ]
+ if len(matched_records) == 1:
+ return matched_records[0]
+ return None
def get_receipt_by_external_id(self, external_message_id: str) -> Optional[DeliveryReceipt]:
"""根据外部平台消息 ID 查询已完成回执。
@@ -213,13 +256,14 @@ class OutboundTracker:
``expires_at`` 对比,跳过这类旧节点。
"""
while self._pending_expire_heap and self._pending_expire_heap[0][0] <= now:
- expires_at, internal_message_id = heapq.heappop(self._pending_expire_heap)
- current_record = self._pending.get(internal_message_id)
+ expires_at, internal_message_id, driver_id = heapq.heappop(self._pending_expire_heap)
+ pending_key = self._build_pending_key(internal_message_id, driver_id)
+ current_record = self._pending.get(pending_key)
if current_record is None:
continue
if current_record.expires_at != expires_at:
continue
- self._pending.pop(internal_message_id, None)
+ self._pending.pop(pending_key, None)
def _cleanup_expired_receipts(self, now: float) -> None:
"""清理已经过期的回执索引。
diff --git a/src/plugin_runtime/capabilities/components.py b/src/plugin_runtime/capabilities/components.py
index 2eede108..67033fdd 100644
--- a/src/plugin_runtime/capabilities/components.py
+++ b/src/plugin_runtime/capabilities/components.py
@@ -72,6 +72,8 @@ class RuntimeComponentCapabilityMixin:
"version": entry.version,
"public": entry.public,
"enabled": entry.enabled,
+ "dynamic": entry.dynamic,
+ "offline_reason": entry.offline_reason,
"metadata": dict(entry.metadata),
}
@@ -109,6 +111,32 @@ class RuntimeComponentCapabilityMixin:
return entry.plugin_id == caller_plugin_id or entry.public
+ @staticmethod
+ def _normalize_api_reference(api_name: str, version: str = "") -> tuple[str, str]:
+ """规范化 API 名称与版本参数。
+
+ 支持在 ``api_name`` 中直接携带 ``@version`` 后缀。
+ """
+
+ normalized_api_name = str(api_name or "").strip()
+ normalized_version = str(version or "").strip()
+ if normalized_api_name and not normalized_version and "@" in normalized_api_name:
+ candidate_name, candidate_version = normalized_api_name.rsplit("@", 1)
+ candidate_name = candidate_name.strip()
+ candidate_version = candidate_version.strip()
+ if candidate_name and candidate_version:
+ normalized_api_name = candidate_name
+ normalized_version = candidate_version
+ return normalized_api_name, normalized_version
+
+ @staticmethod
+ def _build_api_unavailable_error(entry: "APIEntry") -> str:
+ """构造 API 当前不可用时的错误信息。"""
+
+ if entry.offline_reason:
+ return entry.offline_reason
+ return f"API {entry.registry_key} 当前不可用"
+
def _resolve_api_target(
self: _RuntimeComponentManagerProtocol,
caller_plugin_id: str,
@@ -127,8 +155,7 @@ class RuntimeComponentCapabilityMixin:
解析成功时返回 ``(监督器, API 条目, None)``,失败时返回错误信息。
"""
- normalized_api_name = str(api_name or "").strip()
- normalized_version = str(version or "").strip()
+ normalized_api_name, normalized_version = self._normalize_api_reference(api_name, version)
if not normalized_api_name:
return None, None, "缺少必要参数 api_name"
@@ -142,34 +169,61 @@ class RuntimeComponentCapabilityMixin:
if supervisor is None:
return None, None, f"未找到 API 提供方插件: {target_plugin_id}"
- entry = supervisor.api_registry.get_api(
+ entries = supervisor.api_registry.get_apis(
plugin_id=target_plugin_id,
name=target_api_name,
- enabled_only=True,
+ version=normalized_version,
+ enabled_only=False,
)
- if entry is None:
- return None, None, f"未找到 API: {normalized_api_name}"
- if normalized_version and entry.version != normalized_version:
- return None, None, f"未找到版本为 {normalized_version} 的 API: {normalized_api_name}"
- if not self._is_api_visible_to_plugin(entry, caller_plugin_id):
+ visible_enabled_entries = [
+ entry
+ for entry in entries
+ if self._is_api_visible_to_plugin(entry, caller_plugin_id) and entry.enabled
+ ]
+ visible_disabled_entries = [
+ entry
+ for entry in entries
+ if self._is_api_visible_to_plugin(entry, caller_plugin_id) and not entry.enabled
+ ]
+ if len(visible_enabled_entries) == 1:
+ return supervisor, visible_enabled_entries[0], None
+ if len(visible_enabled_entries) > 1:
+ return None, None, f"API {normalized_api_name} 存在多个版本,请显式指定 version"
+ if visible_disabled_entries:
+ if len(visible_disabled_entries) == 1:
+ return None, None, self._build_api_unavailable_error(visible_disabled_entries[0])
+ return None, None, f"API {normalized_api_name} 存在多个已下线版本,请显式指定 version"
+ if any(not self._is_api_visible_to_plugin(entry, caller_plugin_id) for entry in entries):
return None, None, f"API {normalized_api_name} 未公开,禁止跨插件调用"
- return supervisor, entry, None
+ if normalized_version:
+ return None, None, f"未找到版本为 {normalized_version} 的 API: {normalized_api_name}"
+ return None, None, f"未找到 API: {normalized_api_name}"
- visible_matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
+ visible_enabled_matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
+ visible_disabled_matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
hidden_match_exists = False
for supervisor in self.supervisors:
- for entry in supervisor.api_registry.get_apis(name=normalized_api_name, enabled_only=True):
- if normalized_version and entry.version != normalized_version:
- continue
+ for entry in supervisor.api_registry.get_apis(
+ name=normalized_api_name,
+ version=normalized_version,
+ enabled_only=False,
+ ):
if self._is_api_visible_to_plugin(entry, caller_plugin_id):
- visible_matches.append((supervisor, entry))
+ if entry.enabled:
+ visible_enabled_matches.append((supervisor, entry))
+ else:
+ visible_disabled_matches.append((supervisor, entry))
else:
hidden_match_exists = True
- if len(visible_matches) == 1:
- return visible_matches[0][0], visible_matches[0][1], None
- if len(visible_matches) > 1:
- return None, None, f"API 名称不唯一: {normalized_api_name},请使用 plugin_id.api_name"
+ if len(visible_enabled_matches) == 1:
+ return visible_enabled_matches[0][0], visible_enabled_matches[0][1], None
+ if len(visible_enabled_matches) > 1:
+ return None, None, f"API 名称不唯一: {normalized_api_name},请使用 plugin_id.api_name 或显式指定 version"
+ if visible_disabled_matches:
+ if len(visible_disabled_matches) == 1:
+ return None, None, self._build_api_unavailable_error(visible_disabled_matches[0][1])
+ return None, None, f"API {normalized_api_name} 存在多个已下线版本,请使用 plugin_id.api_name@version"
if hidden_match_exists:
return None, None, f"API {normalized_api_name} 未公开,禁止跨插件调用"
if normalized_version:
@@ -179,18 +233,20 @@ class RuntimeComponentCapabilityMixin:
def _resolve_api_toggle_target(
self: _RuntimeComponentManagerProtocol,
name: str,
+ version: str = "",
) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]:
"""解析需要启用或禁用的 API 组件。
Args:
name: API 名称,支持 ``plugin_id.api_name`` 或唯一短名。
+ version: 可选的 API 版本。
Returns:
tuple[Optional[PluginSupervisor], Optional[APIEntry], Optional[str]]:
解析成功时返回 ``(监督器, API 条目, None)``,失败时返回错误信息。
"""
- normalized_name = str(name or "").strip()
+ normalized_name, normalized_version = self._normalize_api_reference(name, version)
if not normalized_name:
return None, None, "缺少必要参数 name"
@@ -204,24 +260,31 @@ class RuntimeComponentCapabilityMixin:
if supervisor is None:
return None, None, f"未找到 API 提供方插件: {plugin_id}"
- entry = supervisor.api_registry.get_api(
+ entries = supervisor.api_registry.get_apis(
plugin_id=plugin_id,
name=api_name,
+ version=normalized_version,
enabled_only=False,
)
- if entry is None:
+ if not entries:
return None, None, f"未找到 API: {normalized_name}"
- return supervisor, entry, None
+ if len(entries) > 1:
+ return None, None, f"API {normalized_name} 存在多个版本,请显式指定 version"
+ return supervisor, entries[0], None
matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
for supervisor in self.supervisors:
- for entry in supervisor.api_registry.get_apis(name=normalized_name, enabled_only=False):
+ for entry in supervisor.api_registry.get_apis(
+ name=normalized_name,
+ version=normalized_version,
+ enabled_only=False,
+ ):
matches.append((supervisor, entry))
if len(matches) == 1:
return matches[0][0], matches[0][1], None
if len(matches) > 1:
- return None, None, f"API 名称不唯一: {normalized_name},请使用 plugin_id.api_name"
+ return None, None, f"API 名称不唯一: {normalized_name},请使用 plugin_id.api_name 或显式指定 version"
return None, None, f"未找到 API: {normalized_name}"
async def _cap_component_get_all_plugins(
@@ -326,6 +389,7 @@ class RuntimeComponentCapabilityMixin:
) -> Any:
name: str = args.get("name", "")
component_type: str = args.get("component_type", "")
+ version: str = args.get("version", "")
scope: str = args.get("scope", "global")
stream_id: str = args.get("stream_id", "")
if not name or not component_type:
@@ -334,10 +398,10 @@ class RuntimeComponentCapabilityMixin:
return {"success": False, "error": "当前仅支持全局组件启用,不支持 scope/stream_id 定位"}
if self._is_api_component_type(component_type):
- supervisor, api_entry, error = self._resolve_api_toggle_target(name)
+ supervisor, api_entry, error = self._resolve_api_toggle_target(name, version)
if supervisor is None or api_entry is None:
return {"success": False, "error": error or f"未找到 API: {name}"}
- supervisor.api_registry.toggle_api_status(api_entry.full_name, True)
+ supervisor.api_registry.toggle_api_status(api_entry.registry_key, True)
return {"success": True}
comp, error = self._resolve_component_toggle_target(name, component_type)
@@ -352,6 +416,7 @@ class RuntimeComponentCapabilityMixin:
) -> Any:
name: str = args.get("name", "")
component_type: str = args.get("component_type", "")
+ version: str = args.get("version", "")
scope: str = args.get("scope", "global")
stream_id: str = args.get("stream_id", "")
if not name or not component_type:
@@ -360,10 +425,10 @@ class RuntimeComponentCapabilityMixin:
return {"success": False, "error": "当前仅支持全局组件禁用,不支持 scope/stream_id 定位"}
if self._is_api_component_type(component_type):
- supervisor, api_entry, error = self._resolve_api_toggle_target(name)
+ supervisor, api_entry, error = self._resolve_api_toggle_target(name, version)
if supervisor is None or api_entry is None:
return {"success": False, "error": error or f"未找到 API: {name}"}
- supervisor.api_registry.toggle_api_status(api_entry.full_name, False)
+ supervisor.api_registry.toggle_api_status(api_entry.registry_key, False)
return {"success": True}
comp, error = self._resolve_component_toggle_target(name, component_type)
@@ -488,11 +553,17 @@ class RuntimeComponentCapabilityMixin:
if supervisor is None or entry is None:
return {"success": False, "error": error or "API 解析失败"}
+ invoke_args = dict(api_args)
+ if entry.dynamic:
+ invoke_args.setdefault("__maibot_api_name__", entry.name)
+ invoke_args.setdefault("__maibot_api_full_name__", entry.full_name)
+ invoke_args.setdefault("__maibot_api_version__", entry.version)
+
try:
response = await supervisor.invoke_api(
plugin_id=entry.plugin_id,
- component_name=entry.name,
- args=api_args,
+ component_name=entry.handler_name,
+ args=invoke_args,
timeout_ms=30000,
)
except Exception as exc:
@@ -555,10 +626,16 @@ class RuntimeComponentCapabilityMixin:
del capability
target_plugin_id = str(args.get("plugin_id", "") or "").strip()
+ api_name, version = self._normalize_api_reference(
+ str(args.get("api_name", args.get("name", "")) or ""),
+ str(args.get("version", "") or ""),
+ )
apis: List[Dict[str, Any]] = []
for supervisor in self.supervisors:
for entry in supervisor.api_registry.get_apis(
plugin_id=target_plugin_id or None,
+ name=api_name,
+ version=version,
enabled_only=True,
):
if not self._is_api_visible_to_plugin(entry, plugin_id):
@@ -567,3 +644,75 @@ class RuntimeComponentCapabilityMixin:
apis.sort(key=lambda item: (str(item["plugin_id"]), str(item["name"]), str(item["version"])))
return {"success": True, "apis": apis}
+
+ async def _cap_api_replace_dynamic(
+ self: _RuntimeComponentManagerProtocol,
+ plugin_id: str,
+ capability: str,
+ args: Dict[str, Any],
+ ) -> Any:
+ """替换插件自行维护的动态 API 列表。"""
+
+ del capability
+ raw_apis = args.get("apis", [])
+ offline_reason = str(args.get("offline_reason", "") or "").strip() or "动态 API 已下线"
+ if not isinstance(raw_apis, list):
+ return {"success": False, "error": "参数 apis 必须为列表"}
+
+ try:
+ supervisor = self._get_supervisor_for_plugin(plugin_id)
+ except RuntimeError as exc:
+ return {"success": False, "error": str(exc)}
+
+ if supervisor is None:
+ return {"success": False, "error": f"未找到插件: {plugin_id}"}
+
+ normalized_components: List[Dict[str, Any]] = []
+ seen_registry_keys: set[str] = set()
+ for index, raw_api in enumerate(raw_apis):
+ if not isinstance(raw_api, dict):
+ return {"success": False, "error": f"apis[{index}] 必须为字典"}
+
+ api_name = str(raw_api.get("name", "") or "").strip()
+ component_type = str(raw_api.get("component_type", raw_api.get("type", "API")) or "").strip()
+ if not api_name:
+ return {"success": False, "error": f"apis[{index}] 缺少 name"}
+ if not self._is_api_component_type(component_type):
+ return {"success": False, "error": f"apis[{index}] 不是 API 组件"}
+
+ metadata = raw_api.get("metadata", {}) if isinstance(raw_api.get("metadata"), dict) else {}
+ normalized_metadata = dict(metadata)
+ normalized_metadata["dynamic"] = True
+ version = str(normalized_metadata.get("version", "1") or "1").strip() or "1"
+ registry_key = supervisor.api_registry.build_registry_key(plugin_id, api_name, version)
+ if registry_key in seen_registry_keys:
+ return {"success": False, "error": f"动态 API 重复声明: {registry_key}"}
+ seen_registry_keys.add(registry_key)
+
+ existing_entry = supervisor.api_registry.get_api(
+ plugin_id,
+ api_name,
+ version=version,
+ enabled_only=False,
+ )
+ if existing_entry is not None and not existing_entry.dynamic:
+ return {"success": False, "error": f"动态 API 不能覆盖静态 API: {registry_key}"}
+
+ normalized_components.append(
+ {
+ "name": api_name,
+ "component_type": "API",
+ "metadata": normalized_metadata,
+ }
+ )
+
+ registered_count, offlined_count = supervisor.api_registry.replace_plugin_dynamic_apis(
+ plugin_id,
+ normalized_components,
+ offline_reason=offline_reason,
+ )
+ return {
+ "success": True,
+ "count": registered_count,
+ "offlined": offlined_count,
+ }
diff --git a/src/plugin_runtime/capabilities/registry.py b/src/plugin_runtime/capabilities/registry.py
index 31693833..7f87604d 100644
--- a/src/plugin_runtime/capabilities/registry.py
+++ b/src/plugin_runtime/capabilities/registry.py
@@ -77,6 +77,7 @@ def register_capability_impls(manager: "PluginRuntimeManager", supervisor: Plugi
_register("api.call", manager._cap_api_call)
_register("api.get", manager._cap_api_get)
_register("api.list", manager._cap_api_list)
+ _register("api.replace_dynamic", manager._cap_api_replace_dynamic)
_register("component.get_all_plugins", manager._cap_component_get_all_plugins)
_register("component.get_plugin_info", manager._cap_component_get_plugin_info)
diff --git a/src/plugin_runtime/host/api_registry.py b/src/plugin_runtime/host/api_registry.py
index 84578ca5..1cbc05f6 100644
--- a/src/plugin_runtime/host/api_registry.py
+++ b/src/plugin_runtime/host/api_registry.py
@@ -1,45 +1,60 @@
"""Host 侧插件 API 动态注册表。"""
-from typing import Any, Dict, List, Optional, Set
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Set, Tuple
from src.common.logger import get_logger
logger = get_logger("plugin_runtime.host.api_registry")
+@dataclass(slots=True)
class APIEntry:
"""API 组件条目。"""
- __slots__ = (
- "description",
- "disabled_session",
- "enabled",
- "full_name",
- "metadata",
- "name",
- "plugin_id",
- "public",
- "version",
- )
+ name: str
+ plugin_id: str
+ description: str = ""
+ version: str = "1"
+ public: bool = False
+ metadata: Dict[str, Any] = field(default_factory=dict)
+ enabled: bool = True
+ handler_name: str = ""
+ dynamic: bool = False
+ offline_reason: str = ""
+ disabled_session: Set[str] = field(default_factory=set)
+ full_name: str = field(init=False)
+ registry_key: str = field(init=False)
- def __init__(self, name: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
- """初始化 API 组件条目。
+ def __post_init__(self) -> None:
+ """规范化 API 条目字段。"""
- Args:
- name: API 名称。
- plugin_id: 所属插件 ID。
- metadata: API 元数据。
- """
+ self.name = str(self.name or "").strip()
+ self.plugin_id = str(self.plugin_id or "").strip()
+ self.description = str(self.description or "").strip()
+ self.version = str(self.version or "1").strip() or "1"
+ self.handler_name = str(self.handler_name or self.name).strip() or self.name
+ self.offline_reason = str(self.offline_reason or "").strip()
+ self.full_name = f"{self.plugin_id}.{self.name}"
+ self.registry_key = APIRegistry.build_registry_key(self.plugin_id, self.name, self.version)
- self.name: str = name
- self.full_name: str = f"{plugin_id}.{name}"
- self.plugin_id: str = plugin_id
- self.description: str = str(metadata.get("description", "") or "")
- self.version: str = str(metadata.get("version", "1") or "1").strip() or "1"
- self.public: bool = bool(metadata.get("public", False))
- self.metadata: Dict[str, Any] = dict(metadata)
- self.enabled: bool = bool(metadata.get("enabled", True))
- self.disabled_session: Set[str] = set()
+ @classmethod
+ def from_metadata(cls, name: str, plugin_id: str, metadata: Dict[str, Any]) -> "APIEntry":
+ """根据 Runner 上报的元数据构造 API 条目。"""
+
+ safe_metadata = dict(metadata)
+ return cls(
+ name=name,
+ plugin_id=plugin_id,
+ description=str(safe_metadata.get("description", "") or ""),
+ version=str(safe_metadata.get("version", "1") or "1"),
+ public=bool(safe_metadata.get("public", False)),
+ metadata=safe_metadata,
+ enabled=bool(safe_metadata.get("enabled", True)),
+ handler_name=str(safe_metadata.get("handler_name", name) or name),
+ dynamic=bool(safe_metadata.get("dynamic", False)),
+ offline_reason=str(safe_metadata.get("offline_reason", "") or ""),
+ )
class APIRegistry:
@@ -53,6 +68,7 @@ class APIRegistry:
"""初始化 API 注册表。"""
self._apis: Dict[str, APIEntry] = {}
+ self._by_full_name: Dict[str, List[APIEntry]] = {}
self._by_plugin: Dict[str, List[APIEntry]] = {}
self._by_name: Dict[str, List[APIEntry]] = {}
@@ -60,75 +76,75 @@ class APIRegistry:
"""清空全部 API 注册状态。"""
self._apis.clear()
+ self._by_full_name.clear()
self._by_plugin.clear()
self._by_name.clear()
@staticmethod
def _is_api_component(component_type: Any) -> bool:
- """判断组件声明是否属于 API。
-
- Args:
- component_type: 原始组件类型值。
-
- Returns:
- bool: 是否为 API 组件。
- """
+ """判断组件声明是否属于 API。"""
return str(component_type or "").strip().upper() == "API"
+ @staticmethod
+ def _normalize_query_version(version: Any) -> str:
+ """规范化查询使用的版本字符串。"""
+
+ return str(version or "").strip()
+
+ @classmethod
+ def _split_reference(cls, reference: str, version: Any = "") -> Tuple[str, str]:
+ """解析可能带 ``@version`` 后缀的 API 引用。"""
+
+ normalized_reference = str(reference or "").strip()
+ normalized_version = cls._normalize_query_version(version)
+ if normalized_reference and not normalized_version and "@" in normalized_reference:
+ candidate_reference, candidate_version = normalized_reference.rsplit("@", 1)
+ candidate_reference = candidate_reference.strip()
+ candidate_version = candidate_version.strip()
+ if candidate_reference and candidate_version:
+ normalized_reference = candidate_reference
+ normalized_version = candidate_version
+ return normalized_reference, normalized_version
+
+ @staticmethod
+ def build_registry_key(plugin_id: str, name: str, version: str) -> str:
+ """构造 API 注册表唯一键。"""
+
+ normalized_full_name = f"{str(plugin_id or '').strip()}.{str(name or '').strip()}"
+ normalized_version = str(version or "1").strip() or "1"
+ return f"{normalized_full_name}@{normalized_version}"
+
@staticmethod
def check_api_enabled(entry: APIEntry, session_id: Optional[str] = None) -> bool:
- """判断 API 条目当前是否处于启用状态。
-
- Args:
- entry: 待检查的 API 条目。
- session_id: 可选的会话 ID。
-
- Returns:
- bool: 当前是否可用。
- """
+ """判断 API 条目当前是否处于启用状态。"""
if session_id and session_id in entry.disabled_session:
return False
return entry.enabled
def register_api(self, name: str, plugin_id: str, metadata: Dict[str, Any]) -> bool:
- """注册单个 API 条目。
-
- Args:
- name: API 名称。
- plugin_id: 所属插件 ID。
- metadata: API 元数据。
-
- Returns:
- bool: 是否成功注册。
- """
+ """注册单个 API 条目。"""
normalized_name = str(name or "").strip()
if not normalized_name:
logger.warning(f"插件 {plugin_id} 存在空 API 名称声明,已忽略")
return False
- entry = APIEntry(name=normalized_name, plugin_id=plugin_id, metadata=metadata)
- if entry.full_name in self._apis:
- logger.warning(f"API {entry.full_name} 已存在,覆盖旧条目")
- self._remove_entry(self._apis[entry.full_name])
+ entry = APIEntry.from_metadata(name=normalized_name, plugin_id=plugin_id, metadata=metadata)
+ existing_entry = self._apis.get(entry.registry_key)
+ if existing_entry is not None:
+ logger.warning(f"API {entry.registry_key} 已存在,覆盖旧条目")
+ self._remove_entry(existing_entry)
- self._apis[entry.full_name] = entry
+ self._apis[entry.registry_key] = entry
+ self._by_full_name.setdefault(entry.full_name, []).append(entry)
self._by_plugin.setdefault(plugin_id, []).append(entry)
self._by_name.setdefault(entry.name, []).append(entry)
return True
def register_plugin_apis(self, plugin_id: str, components: List[Dict[str, Any]]) -> int:
- """批量注册某个插件声明的全部 API。
-
- Args:
- plugin_id: 插件 ID。
- components: 插件组件声明列表。
-
- Returns:
- int: 成功注册的 API 数量。
- """
+ """批量注册某个插件声明的全部 API。"""
count = 0
for component in components:
@@ -142,14 +158,60 @@ class APIRegistry:
count += 1
return count
+ def replace_plugin_dynamic_apis(
+ self,
+ plugin_id: str,
+ components: List[Dict[str, Any]],
+ *,
+ offline_reason: str = "动态 API 已下线",
+ ) -> Tuple[int, int]:
+ """替换指定插件当前声明的动态 API 集合。"""
+
+ normalized_offline_reason = str(offline_reason or "").strip() or "动态 API 已下线"
+ desired_registry_keys: Set[str] = set()
+ registered_count = 0
+
+ for component in components:
+ if not self._is_api_component(component.get("component_type")):
+ continue
+ metadata = component.get("metadata", {}) if isinstance(component.get("metadata"), dict) else {}
+ dynamic_metadata = dict(metadata)
+ dynamic_metadata["dynamic"] = True
+ dynamic_metadata.pop("offline_reason", None)
+
+ entry = APIEntry.from_metadata(
+ name=str(component.get("name", "") or ""),
+ plugin_id=plugin_id,
+ metadata=dynamic_metadata,
+ )
+ desired_registry_keys.add(entry.registry_key)
+ if self.register_api(entry.name, plugin_id, dynamic_metadata):
+ registered_count += 1
+
+ offlined_count = 0
+ for entry in list(self._by_plugin.get(plugin_id, [])):
+ if not entry.dynamic or entry.registry_key in desired_registry_keys:
+ continue
+ entry.enabled = False
+ entry.offline_reason = normalized_offline_reason
+ entry.metadata["offline_reason"] = normalized_offline_reason
+ offlined_count += 1
+
+ return registered_count, offlined_count
+
def _remove_entry(self, entry: APIEntry) -> None:
- """从全部索引中移除单个 API 条目。
+ """从全部索引中移除单个 API 条目。"""
- Args:
- entry: 待移除的 API 条目。
- """
+ self._apis.pop(entry.registry_key, None)
+
+ full_name_entries = self._by_full_name.get(entry.full_name)
+ if full_name_entries is not None:
+ self._by_full_name[entry.full_name] = [
+ candidate for candidate in full_name_entries if candidate is not entry
+ ]
+ if not self._by_full_name[entry.full_name]:
+ self._by_full_name.pop(entry.full_name, None)
- self._apis.pop(entry.full_name, None)
plugin_entries = self._by_plugin.get(entry.plugin_id)
if plugin_entries is not None:
self._by_plugin[entry.plugin_id] = [candidate for candidate in plugin_entries if candidate is not entry]
@@ -163,14 +225,7 @@ class APIRegistry:
self._by_name.pop(entry.name, None)
def remove_apis_by_plugin(self, plugin_id: str) -> int:
- """移除某个插件的全部 API。
-
- Args:
- plugin_id: 目标插件 ID。
-
- Returns:
- int: 被移除的 API 数量。
- """
+ """移除某个插件的全部 API。"""
entries = list(self._by_plugin.get(plugin_id, []))
for entry in entries:
@@ -181,49 +236,48 @@ class APIRegistry:
self,
full_name: str,
*,
+ version: str = "",
enabled_only: bool = True,
session_id: Optional[str] = None,
) -> Optional[APIEntry]:
- """按完整名查询单个 API。
+ """按完整名查询单个 API。"""
- Args:
- full_name: API 完整名,格式为 ``plugin_id.api_name``。
- enabled_only: 是否仅返回启用状态的 API。
- session_id: 可选的会话 ID。
-
- Returns:
- Optional[APIEntry]: 命中时返回 API 条目。
- """
-
- entry = self._apis.get(full_name)
- if entry is None:
+ normalized_full_name, normalized_version = self._split_reference(full_name, version)
+ if not normalized_full_name:
return None
- if enabled_only and not self.check_api_enabled(entry, session_id):
+
+ if normalized_version:
+ entry = self._apis.get(f"{normalized_full_name}@{normalized_version}")
+ if entry is None:
+ return None
+ if enabled_only and not self.check_api_enabled(entry, session_id):
+ return None
+ return entry
+
+ candidates = list(self._by_full_name.get(normalized_full_name, []))
+ filtered_entries = [
+ entry
+ for entry in candidates
+ if not enabled_only or self.check_api_enabled(entry, session_id)
+ ]
+ if len(filtered_entries) != 1:
return None
- return entry
+ return filtered_entries[0]
def get_api(
self,
plugin_id: str,
name: str,
*,
+ version: str = "",
enabled_only: bool = True,
session_id: Optional[str] = None,
) -> Optional[APIEntry]:
- """按插件 ID 和短名查询单个 API。
-
- Args:
- plugin_id: 提供方插件 ID。
- name: API 短名。
- enabled_only: 是否仅返回启用状态的 API。
- session_id: 可选的会话 ID。
-
- Returns:
- Optional[APIEntry]: 命中时返回 API 条目。
- """
+ """按插件 ID、短名与版本查询单个 API。"""
return self.get_api_by_full_name(
f"{plugin_id}.{name}",
+ version=version,
enabled_only=enabled_only,
session_id=session_id,
)
@@ -233,22 +287,15 @@ class APIRegistry:
*,
plugin_id: Optional[str] = None,
name: str = "",
+ version: str = "",
enabled_only: bool = True,
session_id: Optional[str] = None,
) -> List[APIEntry]:
- """查询 API 列表。
-
- Args:
- plugin_id: 可选的插件 ID 过滤条件。
- name: 可选的 API 名称过滤条件。
- enabled_only: 是否仅返回启用状态的 API。
- session_id: 可选的会话 ID。
-
- Returns:
- List[APIEntry]: 符合条件的 API 条目列表。
- """
+ """查询 API 列表。"""
normalized_name = str(name or "").strip()
+ normalized_version = self._normalize_query_version(version)
+
if plugin_id:
candidates = list(self._by_plugin.get(plugin_id, []))
elif normalized_name:
@@ -258,26 +305,35 @@ class APIRegistry:
filtered_entries: List[APIEntry] = []
for entry in candidates:
+ if plugin_id and entry.plugin_id != plugin_id:
+ continue
if normalized_name and entry.name != normalized_name:
continue
+ if normalized_version and entry.version != normalized_version:
+ continue
if enabled_only and not self.check_api_enabled(entry, session_id):
continue
filtered_entries.append(entry)
+
+ filtered_entries.sort(key=lambda entry: (entry.plugin_id, entry.name, entry.version))
return filtered_entries
- def toggle_api_status(self, full_name: str, enabled: bool, session_id: Optional[str] = None) -> bool:
- """设置指定 API 的启用状态。
+ def toggle_api_status(
+ self,
+ full_name: str,
+ enabled: bool,
+ *,
+ version: str = "",
+ session_id: Optional[str] = None,
+ ) -> bool:
+ """设置指定 API 的启用状态。"""
- Args:
- full_name: API 完整名。
- enabled: 目标启用状态。
- session_id: 可选的会话 ID,仅对该会话生效。
-
- Returns:
- bool: 是否设置成功。
- """
-
- entry = self._apis.get(full_name)
+ entry = self.get_api_by_full_name(
+ full_name,
+ version=version,
+ enabled_only=False,
+ session_id=session_id,
+ )
if entry is None:
return False
if session_id:
@@ -287,4 +343,7 @@ class APIRegistry:
entry.disabled_session.add(session_id)
else:
entry.enabled = enabled
+ if enabled:
+ entry.offline_reason = ""
+ entry.metadata.pop("offline_reason", None)
return True
diff --git a/src/plugin_runtime/host/authorization.py b/src/plugin_runtime/host/authorization.py
index 3fb48c6a..70593768 100644
--- a/src/plugin_runtime/host/authorization.py
+++ b/src/plugin_runtime/host/authorization.py
@@ -7,6 +7,8 @@
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set, Tuple
+_ALWAYS_ALLOWED_CAPABILITIES = frozenset({"api.replace_dynamic"})
+
@dataclass
class CapabilityPermissionToken:
@@ -46,6 +48,9 @@ class AuthorizationManager:
Returns:
return (bool, str): (是否有此能力, 原因)
"""
+ if capability in _ALWAYS_ALLOWED_CAPABILITIES:
+ return True, ""
+
token = self._permission_tokens.get(plugin_id)
if not token:
return False, f"插件 {plugin_id} 未注册能力令牌"
diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py
index b94b01d1..b38946d6 100644
--- a/src/plugin_runtime/runner/runner_main.py
+++ b/src/plugin_runtime/runner/runner_main.py
@@ -45,6 +45,14 @@ from src.plugin_runtime.runner.rpc_client import RPCClient
logger = get_logger("plugin_runtime.runner.main")
+_PLUGIN_ALLOWED_RAW_HOST_METHODS = frozenset(
+ {
+ "cap.call",
+ "host.route_message",
+ "host.update_message_gateway_state",
+ }
+)
+
class _ContextAwarePlugin(Protocol):
"""支持注入运行时上下文的插件协议。
@@ -247,8 +255,14 @@ class PluginRunner:
logger.warning(
f"插件 {bound_plugin_id} 尝试以 {plugin_id} 身份发起 RPC,已强制绑定回自身身份"
)
+ normalized_method = str(method or "").strip()
+ if normalized_method not in _PLUGIN_ALLOWED_RAW_HOST_METHODS:
+ raise PermissionError(
+ f"插件 {bound_plugin_id} 不允许直接调用 Host 原始 RPC 方法: "
+ f"{normalized_method or ''}"
+ )
resp = await rpc_client.send_request(
- method=method,
+ method=normalized_method,
plugin_id=bound_plugin_id,
payload=payload or {},
)
diff --git a/src/services/send_service.py b/src/services/send_service.py
index 54f2a9de..134fb15e 100644
--- a/src/services/send_service.py
+++ b/src/services/send_service.py
@@ -57,20 +57,23 @@ def _inherit_platform_io_route_metadata(target_stream: BotChatSession) -> Dict[s
inherited_metadata: Dict[str, object] = {}
context_message = target_stream.context.message if target_stream.context else None
- if context_message is None:
- return inherited_metadata
+ if context_message is not None:
+ additional_config = context_message.message_info.additional_config
+ if isinstance(additional_config, dict):
+ for key in (*RouteKeyFactory.ACCOUNT_ID_KEYS, *RouteKeyFactory.SCOPE_KEYS):
+ value = additional_config.get(key)
+ if value is None:
+ continue
+ normalized_value = str(value).strip()
+ if normalized_value:
+ inherited_metadata[key] = value
- additional_config = context_message.message_info.additional_config
- if not isinstance(additional_config, dict):
- return inherited_metadata
-
- for key in (*RouteKeyFactory.ACCOUNT_ID_KEYS, *RouteKeyFactory.SCOPE_KEYS):
- value = additional_config.get(key)
- if value is None:
- continue
- normalized_value = str(value).strip()
- if normalized_value:
- inherited_metadata[key] = value
+ # 当目标会话没有可继承的上下文消息时,至少补齐当前平台账号,
+ # 让按 ``platform + account_id`` 绑定的路由仍有机会命中。
+ if not RouteKeyFactory.extract_components(inherited_metadata)[0]:
+ bot_account = get_bot_account(target_stream.platform)
+ if bot_account:
+ inherited_metadata["platform_io_account_id"] = bot_account
if target_stream.group_id and (normalized_group_id := str(target_stream.group_id).strip()):
inherited_metadata["platform_io_target_group_id"] = normalized_group_id
From 0c508995ddbfa42020485bdc23f4d5a62b1e54f1 Mon Sep 17 00:00:00 2001
From: DrSmoothl <1787882683@qq.com>
Date: Mon, 23 Mar 2026 21:48:19 +0800
Subject: [PATCH 36/45] feat: enhance session ID calculation and plugin
management
- Updated `calculate_session_id` method in `SessionUtils` to include optional `account_id` and `scope` parameters for more granular session ID generation.
- Added new environment variables in `plugin_runtime` for external plugin dependencies and global configuration snapshots.
- Introduced methods in `RuntimeComponentManagerProtocol` for loading and reloading plugins globally, accommodating external dependencies.
- Enhanced `PluginRunnerSupervisor` to manage external available plugin IDs during plugin reloads.
- Implemented dependency extraction and management in `PluginRuntimeManager` to handle cross-supervisor dependencies.
- Added tests for session ID calculation and message registration in `ChatManager` to ensure correct behavior with new parameters.
---
pytests/test_plugin_runtime.py | 114 ++++---
pytests/utils_test/test_session_utils.py | 42 +++
src/chat/message_receive/bot.py | 16 +-
src/chat/message_receive/chat_manager.py | 58 +++-
src/common/utils/utils_session.py | 22 +-
src/plugin_runtime/__init__.py | 6 +
src/plugin_runtime/capabilities/components.py | 136 ++++----
src/plugin_runtime/host/supervisor.py | 86 +++--
src/plugin_runtime/integration.py | 302 +++++++++++++++++-
src/plugin_runtime/protocol/envelope.py | 4 +
src/plugin_runtime/runner/plugin_loader.py | 9 +-
src/plugin_runtime/runner/runner_main.py | 140 +++++++-
12 files changed, 765 insertions(+), 170 deletions(-)
create mode 100644 pytests/utils_test/test_session_utils.py
diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py
index 9b46f897..e094d85b 100644
--- a/pytests/test_plugin_runtime.py
+++ b/pytests/test_plugin_runtime.py
@@ -3,6 +3,7 @@
验证协议层、传输层、RPC 通信链路的正确性。
"""
+from pathlib import Path
from types import SimpleNamespace
import asyncio
@@ -2362,6 +2363,8 @@ class TestIntegration:
from src.plugin_runtime import integration as integration_module
instances = []
+ builtin_dir = Path("builtin")
+ thirdparty_dir = Path("thirdparty")
class FakeCapabilityService:
def register_capability(self, name, impl):
@@ -2369,11 +2372,18 @@ class TestIntegration:
class FakeSupervisor:
def __init__(self, plugin_dirs=None, socket_path=None):
- self.plugin_dirs = plugin_dirs or []
+ self._plugin_dirs = plugin_dirs or []
self.capability_service = FakeCapabilityService()
+ self.external_plugin_ids = []
self.stopped = False
instances.append(self)
+ def set_external_available_plugin_ids(self, plugin_ids):
+ self.external_plugin_ids = list(plugin_ids)
+
+ def get_loaded_plugin_ids(self):
+ return []
+
async def start(self):
if len(instances) == 2 and self is instances[1]:
raise RuntimeError("boom")
@@ -2382,10 +2392,10 @@ class TestIntegration:
self.stopped = True
monkeypatch.setattr(
- integration_module.PluginRuntimeManager, "_get_builtin_plugin_dirs", staticmethod(lambda: ["builtin"])
+ integration_module.PluginRuntimeManager, "_get_builtin_plugin_dirs", staticmethod(lambda: [builtin_dir])
)
monkeypatch.setattr(
- integration_module.PluginRuntimeManager, "_get_thirdparty_plugin_dirs", staticmethod(lambda: ["thirdparty"])
+ integration_module.PluginRuntimeManager, "_get_third_party_plugin_dirs", staticmethod(lambda: [thirdparty_dir])
)
import src.plugin_runtime.host.supervisor as supervisor_module
@@ -2427,8 +2437,11 @@ class TestIntegration:
self.reload_reasons = []
self.config_updates = []
- async def reload_plugins(self, plugin_ids=None, reason="manual"):
- self.reload_reasons.append((plugin_ids, reason))
+ def get_loaded_plugin_ids(self):
+ return sorted(self._registered_plugins.keys())
+
+ async def reload_plugins(self, plugin_ids=None, reason="manual", external_available_plugins=None):
+ self.reload_reasons.append((plugin_ids, reason, external_available_plugins or []))
async def notify_plugin_config_updated(self, plugin_id, config_data, config_version=""):
self.config_updates.append((plugin_id, config_data, config_version))
@@ -2453,11 +2466,59 @@ class TestIntegration:
await manager._handle_plugin_source_changes(changes)
assert manager._builtin_supervisor.reload_reasons == []
- assert manager._third_party_supervisor.reload_reasons == [(["beta"], "file_watcher")]
+ assert manager._third_party_supervisor.reload_reasons == [(["beta"], "file_watcher", ["alpha"])]
assert manager._builtin_supervisor.config_updates == []
assert manager._third_party_supervisor.config_updates == []
assert refresh_calls == [True]
+ @pytest.mark.asyncio
+ async def test_reload_plugins_globally_warns_and_skips_cross_supervisor_dependents(self, monkeypatch):
+ from src.plugin_runtime import integration as integration_module
+
+ class FakeRegistration:
+ def __init__(self, dependencies):
+ self.dependencies = dependencies
+
+ class FakeSupervisor:
+ def __init__(self, registrations):
+ self._registered_plugins = registrations
+ self.reload_calls = []
+
+ def get_loaded_plugin_ids(self):
+ return sorted(self._registered_plugins.keys())
+
+ async def reload_plugins(self, plugin_ids=None, reason="manual", external_available_plugins=None):
+ self.reload_calls.append((plugin_ids, reason, sorted(external_available_plugins or [])))
+ return True
+
+ builtin_supervisor = FakeSupervisor({"alpha": FakeRegistration([])})
+ third_party_supervisor = FakeSupervisor(
+ {
+ "beta": FakeRegistration(["alpha"]),
+ "gamma": FakeRegistration(["beta"]),
+ }
+ )
+
+ manager = integration_module.PluginRuntimeManager()
+ manager._builtin_supervisor = builtin_supervisor
+ manager._third_party_supervisor = third_party_supervisor
+ warning_messages = []
+
+ monkeypatch.setattr(
+ integration_module.logger,
+ "warning",
+ lambda message: warning_messages.append(message),
+ )
+
+ reloaded = await manager.reload_plugins_globally(["alpha"], reason="manual")
+
+ assert reloaded is True
+ assert builtin_supervisor.reload_calls == [(["alpha"], "manual", ["beta", "gamma"])]
+ assert third_party_supervisor.reload_calls == []
+ assert len(warning_messages) == 1
+ assert "beta, gamma" in warning_messages[0]
+ assert "跨 Supervisor API 调用仍然可用" in warning_messages[0]
+
@pytest.mark.asyncio
async def test_handle_plugin_config_changes_only_notify_target_plugin(self, monkeypatch, tmp_path):
from src.plugin_runtime import integration as integration_module
@@ -2623,55 +2684,30 @@ class TestIntegration:
async def test_component_reload_plugin_returns_failure_when_reload_rolls_back(self, monkeypatch):
from src.plugin_runtime import integration as integration_module
- class FakeSupervisor:
- def __init__(self):
- self._registered_plugins = {"alpha": object()}
+ manager = integration_module.PluginRuntimeManager()
+ monkeypatch.setattr(manager, "reload_plugins_globally", lambda plugin_ids, reason="manual": asyncio.sleep(0, False))
- async def reload_plugins(self, reason="manual"):
- return False
-
- class FakeManager:
- def __init__(self):
- self.supervisors = [FakeSupervisor()]
-
- monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager())
-
- result = await integration_module.PluginRuntimeManager._cap_component_reload_plugin(
+ result = await manager._cap_component_reload_plugin(
"plugin_a",
"component.reload_plugin",
{"plugin_name": "alpha"},
)
assert result["success"] is False
- assert "已回滚" in result["error"]
+ assert result["error"] == "插件 alpha 热重载失败"
@pytest.mark.asyncio
async def test_component_load_plugin_returns_failure_when_reload_rolls_back(self, monkeypatch, tmp_path):
from src.plugin_runtime import integration as integration_module
- plugin_root = tmp_path / "plugins"
- plugin_root.mkdir()
- (plugin_root / "alpha").mkdir()
+ manager = integration_module.PluginRuntimeManager()
+ monkeypatch.setattr(manager, "load_plugin_globally", lambda plugin_id, reason="manual": asyncio.sleep(0, False))
- class FakeSupervisor:
- def __init__(self):
- self._registered_plugins = {}
- self._plugin_dirs = [str(plugin_root)]
-
- async def reload_plugins(self, reason="manual"):
- return False
-
- class FakeManager:
- def __init__(self):
- self.supervisors = [FakeSupervisor()]
-
- monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager())
-
- result = await integration_module.PluginRuntimeManager._cap_component_load_plugin(
+ result = await manager._cap_component_load_plugin(
"plugin_a",
"component.load_plugin",
{"plugin_name": "alpha"},
)
assert result["success"] is False
- assert "已回滚" in result["error"]
+ assert result["error"] == "插件 alpha 热重载失败"
diff --git a/pytests/utils_test/test_session_utils.py b/pytests/utils_test/test_session_utils.py
new file mode 100644
index 00000000..c44e2eba
--- /dev/null
+++ b/pytests/utils_test/test_session_utils.py
@@ -0,0 +1,42 @@
+from types import SimpleNamespace
+
+from src.chat.message_receive.chat_manager import ChatManager
+from src.common.utils.utils_session import SessionUtils
+
+
+def test_calculate_session_id_distinguishes_account_and_scope() -> None:
+ base_session_id = SessionUtils.calculate_session_id("qq", user_id="42")
+ same_base_session_id = SessionUtils.calculate_session_id("qq", user_id="42")
+ account_scoped_session_id = SessionUtils.calculate_session_id("qq", user_id="42", account_id="123")
+ route_scoped_session_id = SessionUtils.calculate_session_id("qq", user_id="42", account_id="123", scope="main")
+
+ assert base_session_id == same_base_session_id
+ assert account_scoped_session_id != base_session_id
+ assert route_scoped_session_id != account_scoped_session_id
+
+
+def test_chat_manager_register_message_uses_route_metadata() -> None:
+ chat_manager = ChatManager()
+ message = SimpleNamespace(
+ platform="qq",
+ session_id="",
+ message_info=SimpleNamespace(
+ user_info=SimpleNamespace(user_id="42"),
+ group_info=SimpleNamespace(group_id="1000"),
+ additional_config={
+ "platform_io_account_id": "123",
+ "platform_io_scope": "main",
+ },
+ ),
+ )
+
+ chat_manager.register_message(message)
+
+ assert message.session_id == SessionUtils.calculate_session_id(
+ "qq",
+ user_id="42",
+ group_id="1000",
+ account_id="123",
+ scope="main",
+ )
+ assert chat_manager.last_messages[message.session_id] is message
diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py
index 025150fc..1fc4ef53 100644
--- a/src/chat/message_receive/bot.py
+++ b/src/chat/message_receive/bot.py
@@ -10,6 +10,7 @@ from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiv
from src.common.logger import get_logger
from src.common.utils.utils_message import MessageUtils
from src.common.utils.utils_session import SessionUtils
+from src.platform_io.route_key_factory import RouteKeyFactory
# from src.chat.brain_chat.PFC.pfc_manager import PFCManager
from src.core.announcement_manager import global_announcement_manager
@@ -270,11 +271,18 @@ class ChatBot:
try:
group_info = message.message_info.group_info
user_info = message.message_info.user_info
+ account_id = None
+ scope = None
+ additional_config = message.message_info.additional_config
+ if isinstance(additional_config, dict):
+ account_id, scope = RouteKeyFactory.extract_components(additional_config)
session_id = SessionUtils.calculate_session_id(
message.platform,
user_id=message.message_info.user_info.user_id,
group_id=group_info.group_id if group_info else None,
+ account_id=account_id,
+ scope=scope,
)
message.session_id = session_id # 正确初始化session_id
@@ -317,7 +325,13 @@ class ChatBot:
platform = message.platform
user_id = user_info.user_id
group_id = group_info.group_id if group_info else None
- _ = await chat_manager.get_or_create_session(platform, user_id, group_id) # 确保会话存在
+ _ = await chat_manager.get_or_create_session(
+ platform,
+ user_id,
+ group_id,
+ account_id=account_id,
+ scope=scope,
+ ) # 确保会话存在
# message.update_chat_stream(chat)
diff --git a/src/chat/message_receive/chat_manager.py b/src/chat/message_receive/chat_manager.py
index b11d233c..48d89956 100644
--- a/src/chat/message_receive/chat_manager.py
+++ b/src/chat/message_receive/chat_manager.py
@@ -1,15 +1,16 @@
+import asyncio
from datetime import datetime
+from typing import TYPE_CHECKING, Dict, List, Optional
+
from rich.traceback import install
from sqlmodel import select
-from typing import Optional, TYPE_CHECKING, List, Dict
-import asyncio
-
-from src.common.logger import get_logger
from src.common.data_models.chat_session_data_model import MaiChatSession
-from src.common.database.database_model import ChatSession
from src.common.database.database import get_db_session
+from src.common.database.database_model import ChatSession
+from src.common.logger import get_logger
from src.common.utils.utils_session import SessionUtils
+from src.platform_io.route_key_factory import RouteKeyFactory
if TYPE_CHECKING:
from .message import SessionMessage
@@ -82,7 +83,12 @@ class ChatManager:
logger.error(f"初始化聊天管理器出现错误: {e}")
async def get_or_create_session(
- self, platform: str, user_id: str, group_id: Optional[str] = None
+ self,
+ platform: str,
+ user_id: str,
+ group_id: Optional[str] = None,
+ account_id: Optional[str] = None,
+ scope: Optional[str] = None,
) -> BotChatSession:
"""获取会话,如果不存在则创建一个新会话;一个封装方法。
@@ -90,12 +96,20 @@ class ChatManager:
platform: 平台
user_id: 用户ID
group_id: 群ID(如果是群聊)
+ account_id: 平台账号 ID
+ scope: 路由作用域
Returns:
return (BotChatSession) 会话对象
Raises:
Exception: 获取或创建会话时发生错误
"""
- session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
+ session_id = SessionUtils.calculate_session_id(
+ platform,
+ user_id=user_id,
+ group_id=group_id,
+ account_id=account_id,
+ scope=scope,
+ )
if session := self.get_session_by_session_id(session_id):
session.update_active_time()
return session
@@ -131,7 +145,18 @@ class ChatManager:
raise ValueError("消息缺少平台信息")
user_id = message.message_info.user_info.user_id
group_id = message.message_info.group_info.group_id if message.message_info.group_info else None
- session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
+ account_id = None
+ scope = None
+ additional_config = message.message_info.additional_config
+ if isinstance(additional_config, dict):
+ account_id, scope = RouteKeyFactory.extract_components(additional_config)
+ session_id = SessionUtils.calculate_session_id(
+ platform,
+ user_id=user_id,
+ group_id=group_id,
+ account_id=account_id,
+ scope=scope,
+ )
message.session_id = session_id # 确保消息的session_id正确设置
self.last_messages[session_id] = message
@@ -188,7 +213,12 @@ class ChatManager:
return None
def get_session_by_info(
- self, platform: str, user_id: Optional[str] = None, group_id: Optional[str] = None
+ self,
+ platform: str,
+ user_id: Optional[str] = None,
+ group_id: Optional[str] = None,
+ account_id: Optional[str] = None,
+ scope: Optional[str] = None,
) -> Optional[BotChatSession]:
"""根据平台、用户ID和群ID获取对应的会话
@@ -196,10 +226,18 @@ class ChatManager:
platform: 平台
user_id: 用户ID
group_id: 群ID(如果是群聊)
+ account_id: 平台账号 ID
+ scope: 路由作用域
Returns:
return (Optional[BotChatSession]): 会话对象,如果不存在则返回None
"""
- session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
+ session_id = SessionUtils.calculate_session_id(
+ platform,
+ user_id=user_id,
+ group_id=group_id,
+ account_id=account_id,
+ scope=scope,
+ )
return self.get_session_by_session_id(session_id)
def get_session_by_session_id(self, session_id: str) -> Optional[BotChatSession]:
diff --git a/src/common/utils/utils_session.py b/src/common/utils/utils_session.py
index a383f5a2..1b6d8f72 100644
--- a/src/common/utils/utils_session.py
+++ b/src/common/utils/utils_session.py
@@ -5,13 +5,22 @@ import hashlib
class SessionUtils:
@staticmethod
- def calculate_session_id(platform: str, *, user_id: Optional[str] = None, group_id: Optional[str] = None) -> str:
+ def calculate_session_id(
+ platform: str,
+ *,
+ user_id: Optional[str] = None,
+ group_id: Optional[str] = None,
+ account_id: Optional[str] = None,
+ scope: Optional[str] = None,
+ ) -> str:
"""计算session_id
Args:
platform: 平台名称
user_id: 用户ID(如果是私聊)
group_id: 群ID(如果是群聊)
+ account_id: 当前平台账号 ID,可选
+ scope: 当前路由作用域,可选
Returns:
str: 计算得到的会话ID
Raises:
@@ -19,8 +28,15 @@ class SessionUtils:
"""
if not user_id and not group_id:
raise ValueError("UserID 或 GroupID 必须提供其一")
+
+ route_components = []
+ if account_id:
+ route_components.append(f"account:{account_id}")
+ if scope:
+ route_components.append(f"scope:{scope}")
+
if group_id:
- components = [platform, group_id]
+ components = [platform, *route_components, group_id]
else:
- components = [platform, user_id, "private"]
+ components = [platform, *route_components, user_id, "private"]
return hashlib.md5("_".join(components).encode()).hexdigest()
diff --git a/src/plugin_runtime/__init__.py b/src/plugin_runtime/__init__.py
index a881d399..704ce514 100644
--- a/src/plugin_runtime/__init__.py
+++ b/src/plugin_runtime/__init__.py
@@ -16,3 +16,9 @@ ENV_PLUGIN_DIRS = "MAIBOT_PLUGIN_DIRS"
ENV_HOST_VERSION = "MAIBOT_HOST_VERSION"
"""Runner 读取的 Host 应用版本号,用于 manifest 兼容性校验"""
+
+ENV_EXTERNAL_PLUGIN_IDS = "MAIBOT_EXTERNAL_PLUGIN_IDS"
+"""Runner 启动时可视为已满足的外部插件依赖列表(JSON 数组)"""
+
+ENV_GLOBAL_CONFIG_SNAPSHOT = "MAIBOT_GLOBAL_CONFIG_SNAPSHOT"
+"""Runner 启动时注入的全局配置快照(JSON 对象)"""
diff --git a/src/plugin_runtime/capabilities/components.py b/src/plugin_runtime/capabilities/components.py
index 67033fdd..2e4c111c 100644
--- a/src/plugin_runtime/capabilities/components.py
+++ b/src/plugin_runtime/capabilities/components.py
@@ -1,5 +1,5 @@
from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Protocol
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Protocol, Sequence
from src.common.logger import get_logger
@@ -15,8 +15,35 @@ class _RuntimeComponentManagerProtocol(Protocol):
@property
def supervisors(self) -> List["PluginSupervisor"]: ...
+ def _normalize_component_type(self, component_type: str) -> str: ...
+
+ def _is_api_component_type(self, component_type: str) -> bool: ...
+
+ def _serialize_api_entry(self, entry: "APIEntry") -> Dict[str, Any]: ...
+
+ def _serialize_api_component_entry(self, entry: "APIEntry") -> Dict[str, Any]: ...
+
+ def _is_api_visible_to_plugin(self, entry: "APIEntry", caller_plugin_id: str) -> bool: ...
+
+ def _normalize_api_reference(self, api_name: str, version: str = "") -> tuple[str, str]: ...
+
+ def _build_api_unavailable_error(self, entry: "APIEntry") -> str: ...
+
def _get_supervisor_for_plugin(self, plugin_id: str) -> Optional["PluginSupervisor"]: ...
+ def _resolve_api_target(
+ self,
+ caller_plugin_id: str,
+ api_name: str,
+ version: str = "",
+ ) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]: ...
+
+ def _resolve_api_toggle_target(
+ self,
+ name: str,
+ version: str = "",
+ ) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]: ...
+
def _resolve_component_toggle_target(
self, name: str, component_type: str
) -> tuple[Optional["ComponentEntry"], Optional[str]]: ...
@@ -25,6 +52,10 @@ class _RuntimeComponentManagerProtocol(Protocol):
def _iter_plugin_dirs(self) -> Iterable[Path]: ...
+ async def load_plugin_globally(self, plugin_id: str, reason: str = "manual") -> bool: ...
+
+ async def reload_plugins_globally(self, plugin_ids: Sequence[str], reason: str = "manual") -> bool: ...
+
class RuntimeComponentCapabilityMixin:
@staticmethod
@@ -266,20 +297,22 @@ class RuntimeComponentCapabilityMixin:
version=normalized_version,
enabled_only=False,
)
- if not entries:
- return None, None, f"未找到 API: {normalized_name}"
- if len(entries) > 1:
+ if len(entries) == 1:
+ return supervisor, entries[0], None
+ if entries:
return None, None, f"API {normalized_name} 存在多个版本,请显式指定 version"
- return supervisor, entries[0], None
+ return None, None, f"未找到 API: {normalized_name}"
matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
for supervisor in self.supervisors:
- for entry in supervisor.api_registry.get_apis(
- name=normalized_name,
- version=normalized_version,
- enabled_only=False,
- ):
- matches.append((supervisor, entry))
+ matches.extend(
+ (supervisor, entry)
+ for entry in supervisor.api_registry.get_apis(
+ name=normalized_name,
+ version=normalized_version,
+ enabled_only=False,
+ )
+ )
if len(matches) == 1:
return matches[0][0], matches[0][1], None
@@ -453,39 +486,14 @@ class RuntimeComponentCapabilityMixin:
return {"success": False, "error": f"检测到重复插件 ID,拒绝热重载: {details}"}
try:
- registered_supervisor = self._get_supervisor_for_plugin(plugin_name)
- except RuntimeError as exc:
- return {"success": False, "error": str(exc)}
+ loaded = await self.load_plugin_globally(plugin_name, reason=f"load {plugin_name}")
+ except Exception as e:
+ logger.error(f"[cap.component.load_plugin] 热重载失败: {e}")
+ return {"success": False, "error": str(e)}
- if registered_supervisor is not None:
- try:
- reloaded = await registered_supervisor.reload_plugins(
- plugin_ids=[plugin_name],
- reason=f"load {plugin_name}",
- )
- if reloaded:
- return {"success": True, "count": 1}
- return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
- except Exception as e:
- logger.error(f"[cap.component.load_plugin] 热重载失败: {e}")
- return {"success": False, "error": str(e)}
-
- for sv in self.supervisors:
- for pdir in sv._plugin_dirs:
- if (pdir / plugin_name).is_dir():
- try:
- reloaded = await sv.reload_plugins(
- plugin_ids=[plugin_name],
- reason=f"load {plugin_name}",
- )
- if reloaded:
- return {"success": True, "count": 1}
- return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
- except Exception as e:
- logger.error(f"[cap.component.load_plugin] 热重载失败: {e}")
- return {"success": False, "error": str(e)}
-
- return {"success": False, "error": f"未找到插件: {plugin_name}"}
+ if loaded:
+ return {"success": True, "count": 1}
+ return {"success": False, "error": f"插件 {plugin_name} 热重载失败"}
async def _cap_component_unload_plugin(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
@@ -507,23 +515,14 @@ class RuntimeComponentCapabilityMixin:
return {"success": False, "error": f"检测到重复插件 ID,拒绝热重载: {details}"}
try:
- sv = self._get_supervisor_for_plugin(plugin_name)
- except RuntimeError as exc:
- return {"success": False, "error": str(exc)}
+ reloaded = await self.reload_plugins_globally([plugin_name], reason=f"reload {plugin_name}")
+ except Exception as e:
+ logger.error(f"[cap.component.reload_plugin] 热重载失败: {e}")
+ return {"success": False, "error": str(e)}
- if sv is not None:
- try:
- reloaded = await sv.reload_plugins(
- plugin_ids=[plugin_name],
- reason=f"reload {plugin_name}",
- )
- if reloaded:
- return {"success": True}
- return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
- except Exception as e:
- logger.error(f"[cap.component.reload_plugin] 热重载失败: {e}")
- return {"success": False, "error": str(e)}
- return {"success": False, "error": f"未找到插件: {plugin_name}"}
+ if reloaded:
+ return {"success": True}
+ return {"success": False, "error": f"插件 {plugin_name} 热重载失败"}
async def _cap_api_call(
self: _RuntimeComponentManagerProtocol,
@@ -632,15 +631,16 @@ class RuntimeComponentCapabilityMixin:
)
apis: List[Dict[str, Any]] = []
for supervisor in self.supervisors:
- for entry in supervisor.api_registry.get_apis(
- plugin_id=target_plugin_id or None,
- name=api_name,
- version=version,
- enabled_only=True,
- ):
- if not self._is_api_visible_to_plugin(entry, plugin_id):
- continue
- apis.append(self._serialize_api_entry(entry))
+ apis.extend(
+ self._serialize_api_entry(entry)
+ for entry in supervisor.api_registry.get_apis(
+ plugin_id=target_plugin_id or None,
+ name=api_name,
+ version=version,
+ enabled_only=True,
+ )
+ if self._is_api_visible_to_plugin(entry, plugin_id)
+ )
apis.sort(key=lambda item: (str(item["plugin_id"]), str(item["name"]), str(item["version"])))
return {"success": True, "apis": apis}
diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py
index afe944e5..693eae51 100644
--- a/src/plugin_runtime/host/supervisor.py
+++ b/src/plugin_runtime/host/supervisor.py
@@ -4,17 +4,26 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import asyncio
import contextlib
+import json
import os
import sys
from src.common.logger import get_logger
-from src.config.config import global_config
+from src.config.config import config_manager, global_config
from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, get_platform_io_manager
from src.platform_io.drivers import PluginPlatformDriver
from src.platform_io.route_key_factory import RouteKeyFactory
-from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN
+from src.plugin_runtime import (
+ ENV_EXTERNAL_PLUGIN_IDS,
+ ENV_GLOBAL_CONFIG_SNAPSHOT,
+ ENV_HOST_VERSION,
+ ENV_IPC_ADDRESS,
+ ENV_PLUGIN_DIRS,
+ ENV_SESSION_TOKEN,
+)
from src.plugin_runtime.protocol.envelope import (
BootstrapPluginPayload,
+ ConfigReloadScope,
ConfigUpdatedPayload,
Envelope,
HealthPayload,
@@ -107,6 +116,7 @@ class PluginRunnerSupervisor:
self._runner_process: Optional[asyncio.subprocess.Process] = None
self._registered_plugins: Dict[str, RegisterPluginPayload] = {}
self._message_gateway_states: Dict[str, Dict[str, _MessageGatewayRuntimeState]] = {}
+ self._external_available_plugin_ids: List[str] = []
self._runner_ready_events: asyncio.Event = asyncio.Event()
self._runner_ready_payloads: RunnerReadyPayload = RunnerReadyPayload()
self._health_task: Optional[asyncio.Task[None]] = None
@@ -156,6 +166,21 @@ class PluginRunnerSupervisor:
"""返回底层 RPC 服务端。"""
return self._rpc_server
+ def set_external_available_plugin_ids(self, plugin_ids: List[str]) -> None:
+ """设置当前 Runner 启动/重载时可视为已满足的外部依赖列表。"""
+
+ normalized_plugin_ids = {
+ str(plugin_id or "").strip()
+ for plugin_id in plugin_ids
+ if str(plugin_id or "").strip()
+ }
+ self._external_available_plugin_ids = sorted(normalized_plugin_ids)
+
+ def get_loaded_plugin_ids(self) -> List[str]:
+ """返回当前 Supervisor 已注册的插件 ID 列表。"""
+
+ return sorted(self._registered_plugins.keys())
+
async def dispatch_event(
self,
event_type: str,
@@ -344,12 +369,18 @@ class PluginRunnerSupervisor:
timeout_ms=timeout_ms,
)
- async def reload_plugin(self, plugin_id: str, reason: str = "manual") -> bool:
+ async def reload_plugin(
+ self,
+ plugin_id: str,
+ reason: str = "manual",
+ external_available_plugins: Optional[List[str]] = None,
+ ) -> bool:
"""按插件 ID 触发精确重载。
Args:
plugin_id: 目标插件 ID。
reason: 重载原因。
+ external_available_plugins: 视为已满足的外部依赖插件 ID 列表。
Returns:
bool: 是否重载成功。
@@ -358,7 +389,11 @@ class PluginRunnerSupervisor:
response = await self._rpc_server.send_request(
"plugin.reload",
plugin_id=plugin_id,
- payload={"plugin_id": plugin_id, "reason": reason},
+ payload={
+ "plugin_id": plugin_id,
+ "reason": reason,
+ "external_available_plugins": external_available_plugins or self._external_available_plugin_ids,
+ },
timeout_ms=max(int(self._runner_spawn_timeout * 1000), 10000),
)
except Exception as exc:
@@ -374,12 +409,14 @@ class PluginRunnerSupervisor:
self,
plugin_ids: Optional[List[str]] = None,
reason: str = "manual",
+ external_available_plugins: Optional[List[str]] = None,
) -> bool:
"""批量重载插件。
Args:
plugin_ids: 目标插件 ID 列表;为空时重载当前已注册的全部插件。
reason: 重载原因。
+ external_available_plugins: 视为已满足的外部依赖插件 ID 列表。
Returns:
bool: 是否全部重载成功。
@@ -389,7 +426,11 @@ class PluginRunnerSupervisor:
success = True
for plugin_id in ordered_plugin_ids:
- reloaded = await self.reload_plugin(plugin_id=plugin_id, reason=reason)
+ reloaded = await self.reload_plugin(
+ plugin_id=plugin_id,
+ reason=reason,
+ external_available_plugins=external_available_plugins,
+ )
success = success and reloaded
return success
@@ -399,7 +440,7 @@ class PluginRunnerSupervisor:
plugin_id: str,
config_data: Optional[Dict[str, Any]] = None,
config_version: str = "",
- config_scope: str = "self",
+ config_scope: str | ConfigReloadScope = "self",
) -> bool:
"""向 Runner 推送插件配置更新。
@@ -412,9 +453,15 @@ class PluginRunnerSupervisor:
Returns:
bool: 请求是否成功送达并被 Runner 接受。
"""
+ try:
+ normalized_scope = ConfigReloadScope(config_scope)
+ except ValueError:
+ logger.warning(f"插件 {plugin_id} 配置更新通知失败: 非法的 config_scope={config_scope}")
+ return False
+
payload = ConfigUpdatedPayload(
plugin_id=plugin_id,
- config_scope=config_scope,
+ config_scope=normalized_scope,
config_version=config_version,
config_data=config_data or {},
)
@@ -441,11 +488,11 @@ class PluginRunnerSupervisor:
List[str]: 已声明订阅该范围的插件 ID 列表。
"""
- matched_plugins: List[str] = []
- for plugin_id, registration in self._registered_plugins.items():
- if scope in registration.config_reload_subscriptions:
- matched_plugins.append(plugin_id)
- return matched_plugins
+ return [
+ plugin_id
+ for plugin_id, registration in self._registered_plugins.items()
+ if scope in registration.config_reload_subscriptions
+ ]
async def _wait_for_runner_connection(self, timeout_sec: float) -> None:
"""等待 Runner 建立 RPC 连接。
@@ -706,10 +753,7 @@ class PluginRunnerSupervisor:
)
gateways = self._component_registry.get_message_gateways(plugin_id=plugin_id, enabled_only=False)
- if len(gateways) == 1:
- return gateways[0]
-
- return None
+ return gateways[0] if len(gateways) == 1 else None
async def _register_message_gateway_driver(
self,
@@ -823,8 +867,7 @@ class PluginRunnerSupervisor:
ValueError: 当平台信息缺失时抛出。
"""
- platform = str(payload.platform or gateway_entry.platform or "").strip()
- if not platform:
+ if not (platform := str(payload.platform or gateway_entry.platform or "").strip()):
raise ValueError(f"消息网关 {gateway_entry.full_name} 未提供有效的平台名称")
return RouteKey(
@@ -1090,7 +1133,11 @@ class PluginRunnerSupervisor:
Returns:
Dict[str, str]: 传递给 Runner 进程的环境变量映射。
"""
+ global_config_snapshot = config_manager.get_global_config().model_dump()
+ global_config_snapshot["model"] = config_manager.get_model_config().model_dump()
return {
+ ENV_EXTERNAL_PLUGIN_IDS: json.dumps(self._external_available_plugin_ids, ensure_ascii=False),
+ ENV_GLOBAL_CONFIG_SNAPSHOT: json.dumps(global_config_snapshot, ensure_ascii=False),
ENV_HOST_VERSION: PROTOCOL_VERSION,
ENV_IPC_ADDRESS: self._transport.get_address(),
ENV_PLUGIN_DIRS: os.pathsep.join(str(path) for path in self._plugin_dirs),
@@ -1136,8 +1183,7 @@ class PluginRunnerSupervisor:
line = await stream.readline()
if not line:
return
- message = line.decode("utf-8", errors="replace").rstrip()
- if message:
+ if message := line.decode("utf-8", errors="replace").rstrip():
logger.warning(f"[runner-stderr] {message}")
except asyncio.CancelledError:
raise
diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py
index e45b40de..d48260e5 100644
--- a/src/plugin_runtime/integration.py
+++ b/src/plugin_runtime/integration.py
@@ -8,7 +8,7 @@
"""
from pathlib import Path
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Dict, Iterable, List, Optional, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Dict, Iterable, List, Optional, Sequence, Set, Tuple
import asyncio
import json
@@ -102,6 +102,77 @@ class PluginRuntimeManager(
candidate = Path("plugins").resolve()
return [candidate] if candidate.is_dir() else []
+ @staticmethod
+ def _extract_manifest_dependencies(manifest: Dict[str, Any]) -> List[str]:
+ """从插件 manifest 中提取规范化后的依赖插件 ID 列表。"""
+
+ dependencies: List[str] = []
+ for dependency in manifest.get("dependencies", []):
+ if isinstance(dependency, str):
+ normalized_dependency = dependency.strip()
+ elif isinstance(dependency, dict):
+ normalized_dependency = str(dependency.get("name", "") or "").strip()
+ else:
+ normalized_dependency = ""
+
+ if normalized_dependency:
+ dependencies.append(normalized_dependency)
+ return dependencies
+
+ @classmethod
+ def _discover_plugin_dependency_map(cls, plugin_dirs: Iterable[Path]) -> Dict[str, List[str]]:
+ """扫描指定插件目录集合,返回 ``plugin_id -> dependencies`` 映射。"""
+
+ dependency_map: Dict[str, List[str]] = {}
+ for plugin_dir in cls._iter_candidate_plugin_paths(plugin_dirs):
+ manifest_path = plugin_dir / "_manifest.json"
+ entrypoint_path = plugin_dir / "plugin.py"
+ if not manifest_path.is_file() or not entrypoint_path.is_file():
+ continue
+
+ try:
+ with manifest_path.open("r", encoding="utf-8") as manifest_file:
+ manifest = json.load(manifest_file)
+ except Exception:
+ continue
+
+ if not isinstance(manifest, dict):
+ continue
+
+ plugin_id = str(manifest.get("name", plugin_dir.name) or "").strip() or plugin_dir.name
+ dependency_map[plugin_id] = cls._extract_manifest_dependencies(manifest)
+ return dependency_map
+
+ @classmethod
+ def _build_group_start_order(
+ cls,
+ builtin_dirs: Sequence[Path],
+ third_party_dirs: Sequence[Path],
+ ) -> List[str]:
+ """根据跨 Supervisor 依赖关系决定 Runner 启动顺序。"""
+
+ builtin_dependencies = cls._discover_plugin_dependency_map(builtin_dirs)
+ third_party_dependencies = cls._discover_plugin_dependency_map(third_party_dirs)
+ builtin_plugin_ids = set(builtin_dependencies)
+ third_party_plugin_ids = set(third_party_dependencies)
+
+ builtin_needs_third_party = any(
+ dependency in third_party_plugin_ids
+ for dependencies in builtin_dependencies.values()
+ for dependency in dependencies
+ )
+ third_party_needs_builtin = any(
+ dependency in builtin_plugin_ids
+ for dependencies in third_party_dependencies.values()
+ for dependency in dependencies
+ )
+
+ if builtin_needs_third_party and third_party_needs_builtin:
+ raise RuntimeError("检测到跨 Supervisor 循环依赖,当前无法安全启动独立 Runner")
+ if builtin_needs_third_party:
+ return ["third_party", "builtin"]
+ return ["builtin", "third_party"]
+
# ─── 生命周期 ─────────────────────────────────────────────
async def start(self) -> None:
@@ -161,12 +232,26 @@ class PluginRuntimeManager(
platform_io_manager.set_inbound_dispatcher(self._dispatch_platform_inbound)
await platform_io_manager.ensure_send_pipeline_ready()
- if self._builtin_supervisor:
- await self._builtin_supervisor.start()
- started_supervisors.append(self._builtin_supervisor)
- if self._third_party_supervisor:
- await self._third_party_supervisor.start()
- started_supervisors.append(self._third_party_supervisor)
+ supervisor_groups: Dict[str, Optional[PluginSupervisor]] = {
+ "builtin": self._builtin_supervisor,
+ "third_party": self._third_party_supervisor,
+ }
+ start_order = self._build_group_start_order(builtin_dirs, third_party_dirs)
+
+ for group_name in start_order:
+ supervisor = supervisor_groups.get(group_name)
+ if supervisor is None:
+ continue
+
+ external_plugin_ids = [
+ plugin_id
+ for started_supervisor in started_supervisors
+ for plugin_id in started_supervisor.get_loaded_plugin_ids()
+ ]
+ supervisor.set_external_available_plugin_ids(external_plugin_ids)
+ await supervisor.start()
+ started_supervisors.append(supervisor)
+
await self._start_plugin_file_watcher()
config_manager.register_reload_callback(self._config_reload_callback)
self._config_reload_callback_registered = True
@@ -238,6 +323,171 @@ class PluginRuntimeManager(
"""获取所有活跃的 Supervisor"""
return [s for s in (self._builtin_supervisor, self._third_party_supervisor) if s is not None]
+ def _build_registered_dependency_map(self) -> Dict[str, Set[str]]:
+ """根据当前已注册插件构建全局依赖图。"""
+
+ dependency_map: Dict[str, Set[str]] = {}
+ for supervisor in self.supervisors:
+ for plugin_id, registration in getattr(supervisor, "_registered_plugins", {}).items():
+ dependency_map[plugin_id] = {
+ str(dependency or "").strip()
+ for dependency in getattr(registration, "dependencies", [])
+ if str(dependency or "").strip()
+ }
+ return dependency_map
+
+ @staticmethod
+ def _collect_reverse_dependents(
+ plugin_ids: Set[str],
+ dependency_map: Dict[str, Set[str]],
+ ) -> Set[str]:
+ """根据依赖图收集反向依赖闭包。"""
+
+ impacted_plugins: Set[str] = set(plugin_ids)
+ changed = True
+
+ while changed:
+ changed = False
+ for registered_plugin_id, dependencies in dependency_map.items():
+ if registered_plugin_id in impacted_plugins:
+ continue
+ if dependencies & impacted_plugins:
+ impacted_plugins.add(registered_plugin_id)
+ changed = True
+
+ return impacted_plugins
+
+ def _build_registered_supervisor_map(self) -> Dict[str, "PluginSupervisor"]:
+ """构建当前已注册插件到所属 Supervisor 的映射。"""
+
+ return {
+ plugin_id: supervisor
+ for supervisor in self.supervisors
+ for plugin_id in supervisor.get_loaded_plugin_ids()
+ }
+
+ def _build_external_available_plugins_for_supervisor(self, target_supervisor: "PluginSupervisor") -> List[str]:
+ """收集某个 Supervisor 可用的外部插件 ID 列表。"""
+
+ external_plugin_ids: Set[str] = set()
+ for supervisor in self.supervisors:
+ if supervisor is target_supervisor:
+ continue
+ external_plugin_ids.update(supervisor.get_loaded_plugin_ids())
+ return sorted(external_plugin_ids)
+
+ def _find_supervisor_by_plugin_directory(self, plugin_id: str) -> Optional["PluginSupervisor"]:
+ """根据插件目录推断应负责该插件重载的 Supervisor。"""
+
+ for supervisor in self.supervisors:
+ for plugin_dir in supervisor._plugin_dirs:
+ if (Path(plugin_dir) / plugin_id).is_dir():
+ return supervisor
+ return None
+
+ def _warn_skipped_cross_supervisor_reload(
+ self,
+ requested_loaded_plugin_ids: Set[str],
+ dependency_map: Dict[str, Set[str]],
+ supervisor_by_plugin: Dict[str, "PluginSupervisor"],
+ ) -> None:
+ """记录因跨 Supervisor 边界而未参与联动重载的插件。"""
+
+ if not requested_loaded_plugin_ids:
+ return
+
+ handled_plugin_ids: Set[str] = set()
+ for supervisor in self.supervisors:
+ local_requested_plugin_ids = {
+ plugin_id
+ for plugin_id in requested_loaded_plugin_ids
+ if supervisor_by_plugin.get(plugin_id) is supervisor
+ }
+ if not local_requested_plugin_ids:
+ continue
+
+ local_plugin_ids = set(supervisor.get_loaded_plugin_ids())
+ local_dependency_map = {
+ plugin_id: {
+ dependency
+ for dependency in dependency_map.get(plugin_id, set())
+ if dependency in local_plugin_ids
+ }
+ for plugin_id in local_plugin_ids
+ }
+ handled_plugin_ids.update(
+ self._collect_reverse_dependents(local_requested_plugin_ids, local_dependency_map)
+ )
+
+ impacted_plugin_ids = self._collect_reverse_dependents(requested_loaded_plugin_ids, dependency_map)
+ skipped_plugin_ids = sorted(impacted_plugin_ids - handled_plugin_ids)
+ if not skipped_plugin_ids:
+ return
+
+ logger.warning(
+ f"插件 {', '.join(sorted(requested_loaded_plugin_ids))} 存在跨 Supervisor 依赖方未联动重载: "
+ f"{', '.join(skipped_plugin_ids)}。当前仅在单个 Supervisor 内执行联动重载;"
+ "跨 Supervisor API 调用仍然可用。如需联动重载,请将相关插件放在同一个 Supervisor 内。"
+ )
+
+ async def reload_plugins_globally(self, plugin_ids: Sequence[str], reason: str = "manual") -> bool:
+ """按 Supervisor 分组执行精确重载。
+
+ 仅在单个 Supervisor 内执行依赖联动;跨 Supervisor 依赖方仅记录告警,
+ 不再自动参与本次热重载。
+ """
+
+ normalized_plugin_ids = [
+ normalized_plugin_id
+ for plugin_id in plugin_ids
+ if (normalized_plugin_id := str(plugin_id or "").strip())
+ ]
+ if not normalized_plugin_ids:
+ return True
+
+ dependency_map = self._build_registered_dependency_map()
+ supervisor_by_plugin = self._build_registered_supervisor_map()
+ supervisor_roots: Dict["PluginSupervisor", List[str]] = {}
+ requested_loaded_plugin_ids: Set[str] = set()
+ missing_plugin_ids: List[str] = []
+
+ for plugin_id in normalized_plugin_ids:
+ supervisor = supervisor_by_plugin.get(plugin_id)
+ if supervisor is not None:
+ requested_loaded_plugin_ids.add(plugin_id)
+ else:
+ supervisor = self._find_supervisor_by_plugin_directory(plugin_id)
+
+ if supervisor is None:
+ missing_plugin_ids.append(plugin_id)
+ continue
+
+ if plugin_id not in supervisor_roots.setdefault(supervisor, []):
+ supervisor_roots[supervisor].append(plugin_id)
+
+ if missing_plugin_ids:
+ logger.warning(f"以下插件未找到可重载的 Supervisor,已跳过: {', '.join(sorted(missing_plugin_ids))}")
+
+ self._warn_skipped_cross_supervisor_reload(
+ requested_loaded_plugin_ids=requested_loaded_plugin_ids,
+ dependency_map=dependency_map,
+ supervisor_by_plugin=supervisor_by_plugin,
+ )
+
+ success = True
+ for supervisor, root_plugin_ids in supervisor_roots.items():
+ if not root_plugin_ids:
+ continue
+
+ reloaded = await supervisor.reload_plugins(
+ plugin_ids=root_plugin_ids,
+ reason=reason,
+ external_available_plugins=self._build_external_available_plugins_for_supervisor(supervisor),
+ )
+ success = success and reloaded
+
+ return success and not missing_plugin_ids
+
async def notify_plugin_config_updated(
self,
plugin_id: str,
@@ -465,6 +715,31 @@ class PluginRuntimeManager(
raise RuntimeError(f"插件 {plugin_id} 同时存在于多个 Supervisor 中,无法安全路由")
return matches[0] if matches else None
+ async def load_plugin_globally(self, plugin_id: str, reason: str = "manual") -> bool:
+ """加载或重载单个插件,并为其补齐跨 Supervisor 外部依赖。"""
+
+ normalized_plugin_id = str(plugin_id or "").strip()
+ if not normalized_plugin_id:
+ return False
+
+ try:
+ registered_supervisor = self._get_supervisor_for_plugin(normalized_plugin_id)
+ except RuntimeError:
+ return False
+
+ if registered_supervisor is not None:
+ return await self.reload_plugins_globally([normalized_plugin_id], reason=reason)
+
+ supervisor = self._find_supervisor_by_plugin_directory(normalized_plugin_id)
+ if supervisor is None:
+ return False
+
+ return await supervisor.reload_plugins(
+ plugin_ids=[normalized_plugin_id],
+ reason=reason,
+ external_available_plugins=self._build_external_available_plugins_for_supervisor(supervisor),
+ )
+
@staticmethod
def _find_duplicate_plugin_ids(plugin_dirs: List[Path]) -> Dict[str, List[Path]]:
"""扫描插件目录,找出被多个目录重复声明的插件 ID。"""
@@ -729,7 +1004,7 @@ class PluginRuntimeManager(
logger.error(f"检测到重复插件 ID,跳过本次插件热重载: {details}")
return
- reload_supervisors: Dict[Any, List[str]] = {}
+ changed_plugin_ids: List[str] = []
changed_paths = [change.path.resolve() for change in changes]
for supervisor in self.supervisors:
@@ -738,14 +1013,11 @@ class PluginRuntimeManager(
if plugin_id is None:
continue
if path.name in {"plugin.py", "_manifest.json"} or path.suffix == ".py":
- reload_supervisors.setdefault(supervisor, [])
- if plugin_id not in reload_supervisors[supervisor]:
- reload_supervisors[supervisor].append(plugin_id)
+ if plugin_id not in changed_plugin_ids:
+ changed_plugin_ids.append(plugin_id)
- for supervisor, plugin_ids in reload_supervisors.items():
- await supervisor.reload_plugins(plugin_ids=plugin_ids, reason="file_watcher")
-
- if reload_supervisors:
+ if changed_plugin_ids:
+ await self.reload_plugins_globally(changed_plugin_ids, reason="file_watcher")
self._refresh_plugin_config_watch_subscriptions()
@staticmethod
diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py
index 6078e4dc..ce40d855 100644
--- a/src/plugin_runtime/protocol/envelope.py
+++ b/src/plugin_runtime/protocol/envelope.py
@@ -166,6 +166,8 @@ class RegisterPluginPayload(BaseModel):
"""组件列表"""
capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表")
"""所需能力列表"""
+ dependencies: List[str] = Field(default_factory=list, description="插件级依赖插件 ID 列表")
+ """插件级依赖插件 ID 列表"""
config_reload_subscriptions: List[str] = Field(default_factory=list, description="订阅的全局配置热重载范围")
"""订阅的全局配置热重载范围"""
@@ -280,6 +282,8 @@ class ReloadPluginPayload(BaseModel):
"""目标插件 ID"""
reason: str = Field(default="manual", description="重载原因")
"""重载原因"""
+ external_available_plugins: List[str] = Field(default_factory=list, description="可视为已满足的外部依赖插件 ID")
+ """可视为已满足的外部依赖插件 ID"""
class ReloadPluginResultPayload(BaseModel):
diff --git a/src/plugin_runtime/runner/plugin_loader.py b/src/plugin_runtime/runner/plugin_loader.py
index a766eb04..f07eb593 100644
--- a/src/plugin_runtime/runner/plugin_loader.py
+++ b/src/plugin_runtime/runner/plugin_loader.py
@@ -95,11 +95,16 @@ class PluginLoader:
self._manifest_validator = ManifestValidator(host_version=host_version)
self._compat_hook_installed = False
- def discover_and_load(self, plugin_dirs: List[str]) -> List[PluginMeta]:
+ def discover_and_load(
+ self,
+ plugin_dirs: List[str],
+ extra_available: Optional[Set[str]] = None,
+ ) -> List[PluginMeta]:
"""扫描多个目录并加载所有插件。
Args:
plugin_dirs: 插件目录列表。
+ extra_available: 额外视为已满足的外部依赖插件 ID 集合。
Returns:
List[PluginMeta]: 成功加载的插件元数据列表,按依赖顺序排列。
@@ -108,7 +113,7 @@ class PluginLoader:
self._record_duplicate_candidates(duplicate_candidates)
# 第二阶段:依赖解析(拓扑排序)
- load_order, failed_deps = self._resolve_dependencies(candidates)
+ load_order, failed_deps = self._resolve_dependencies(candidates, extra_available=extra_available)
self._record_failed_dependencies(failed_deps)
# 第三阶段:按依赖顺序加载
diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py
index b38946d6..c0f5e771 100644
--- a/src/plugin_runtime/runner/runner_main.py
+++ b/src/plugin_runtime/runner/runner_main.py
@@ -15,6 +15,7 @@ from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, ca
import asyncio
import contextlib
import inspect
+import json
import logging as stdlib_logging
import os
import signal
@@ -23,7 +24,13 @@ import time
import tomllib
from src.common.logger import get_console_handler, get_logger, initialize_logging
-from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN
+from src.plugin_runtime import (
+ ENV_EXTERNAL_PLUGIN_IDS,
+ ENV_HOST_VERSION,
+ ENV_IPC_ADDRESS,
+ ENV_PLUGIN_DIRS,
+ ENV_SESSION_TOKEN,
+)
from src.plugin_runtime.protocol.envelope import (
BootstrapPluginPayload,
ComponentDeclaration,
@@ -112,6 +119,7 @@ class PluginRunner:
host_address: str,
session_token: str,
plugin_dirs: List[str],
+ external_available_plugin_ids: Optional[List[str]] = None,
) -> None:
"""初始化 Runner。
@@ -119,10 +127,16 @@ class PluginRunner:
host_address: Host 的 IPC 地址。
session_token: 握手用会话令牌。
plugin_dirs: 当前 Runner 负责扫描的插件目录列表。
+ external_available_plugin_ids: 视为已满足的外部依赖插件 ID 列表。
"""
self._host_address: str = host_address
self._session_token: str = session_token
self._plugin_dirs: List[str] = plugin_dirs
+ self._external_available_plugin_ids: Set[str] = {
+ str(plugin_id or "").strip()
+ for plugin_id in (external_available_plugin_ids or [])
+ if str(plugin_id or "").strip()
+ }
self._rpc_client: RPCClient = RPCClient(host_address, session_token)
self._loader: PluginLoader = PluginLoader(host_version=os.getenv(ENV_HOST_VERSION, ""))
@@ -150,7 +164,10 @@ class PluginRunner:
self._register_handlers()
# 3. 加载插件
- plugins = self._loader.discover_and_load(self._plugin_dirs)
+ plugins = self._loader.discover_and_load(
+ self._plugin_dirs,
+ extra_available=self._external_available_plugin_ids,
+ )
logger.info(f"已加载 {len(plugins)} 个插件")
# 4. 注入 PluginContext + 调用 on_load 生命周期钩子
@@ -379,6 +396,7 @@ class PluginRunner:
plugin_version=meta.version,
components=components,
capabilities_required=meta.capabilities_required,
+ dependencies=meta.dependencies,
config_reload_subscriptions=config_reload_subscriptions,
)
@@ -485,18 +503,20 @@ class PluginRunner:
self._loader.set_loaded_plugin(meta)
return True
- async def _unload_plugin(self, meta: PluginMeta, reason: str) -> None:
+ async def _unload_plugin(self, meta: PluginMeta, reason: str, *, purge_modules: bool = True) -> None:
"""卸载单个插件并清理 Host/Runner 两侧状态。
Args:
meta: 待卸载的插件元数据。
reason: 卸载原因。
+ purge_modules: 是否在卸载完成后清理插件模块缓存。
"""
await self._invoke_plugin_on_unload(meta)
await self._unregister_plugin(meta.plugin_id, reason)
await self._deactivate_plugin(meta)
self._loader.remove_loaded_plugin(meta.plugin_id)
- self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
+ if purge_modules:
+ self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
def _collect_reverse_dependents(self, plugin_id: str) -> Set[str]:
"""收集依赖指定插件的所有已加载插件。
@@ -564,18 +584,52 @@ class PluginRunner:
return list(reversed(load_order))
- async def _reload_plugin_by_id(self, plugin_id: str, reason: str) -> ReloadPluginResultPayload:
+ @staticmethod
+ def _finalize_failed_reload_messages(
+ failed_plugins: Dict[str, str],
+ rollback_failures: Dict[str, str],
+ ) -> Dict[str, str]:
+ """在重载失败后补充回滚结果说明。"""
+
+ finalized_failures: Dict[str, str] = {}
+ for failed_plugin_id, failure_reason in failed_plugins.items():
+ rollback_failure = rollback_failures.get(failed_plugin_id)
+ if rollback_failure:
+ finalized_failures[failed_plugin_id] = (
+ f"{failure_reason};且旧版本恢复失败: {rollback_failure}"
+ )
+ else:
+ finalized_failures[failed_plugin_id] = f"{failure_reason}(已恢复旧版本)"
+
+ for failed_plugin_id, rollback_failure in rollback_failures.items():
+ if failed_plugin_id not in finalized_failures:
+ finalized_failures[failed_plugin_id] = f"旧版本恢复失败: {rollback_failure}"
+
+ return finalized_failures
+
+ async def _reload_plugin_by_id(
+ self,
+ plugin_id: str,
+ reason: str,
+ external_available_plugins: Optional[Set[str]] = None,
+ ) -> ReloadPluginResultPayload:
"""按插件 ID 在 Runner 进程内执行精确重载。
Args:
plugin_id: 目标插件 ID。
reason: 重载原因。
+ external_available_plugins: 视为已满足的外部依赖插件 ID 集合。
Returns:
ReloadPluginResultPayload: 结构化重载结果。
"""
candidates, duplicate_candidates = self._loader.discover_candidates(self._plugin_dirs)
failed_plugins: Dict[str, str] = {}
+ normalized_external_available = {
+ str(candidate_plugin_id or "").strip()
+ for candidate_plugin_id in (external_available_plugins or set())
+ if str(candidate_plugin_id or "").strip()
+ }
if plugin_id in duplicate_candidates:
conflict_paths = ", ".join(str(path) for path in duplicate_candidates[plugin_id])
@@ -603,29 +657,32 @@ class PluginRunner:
unload_order = self._build_unload_order(target_plugin_ids & loaded_plugin_ids)
unloaded_plugins: List[str] = []
retained_plugin_ids = loaded_plugin_ids - set(unload_order)
+ rollback_metas: Dict[str, PluginMeta] = {}
for unload_plugin_id in unload_order:
meta = self._loader.get_plugin(unload_plugin_id)
if meta is None:
continue
- await self._unload_plugin(meta, reason=reason)
+ rollback_metas[unload_plugin_id] = meta
+ await self._unload_plugin(meta, reason=reason, purge_modules=False)
+ self._loader.purge_plugin_modules(unload_plugin_id, meta.plugin_dir)
unloaded_plugins.append(unload_plugin_id)
reload_candidates: Dict[str, Tuple[Path, Dict[str, Any], Path]] = {}
for target_plugin_id in target_plugin_ids:
candidate = candidates.get(target_plugin_id)
if candidate is None:
- failed_plugins[target_plugin_id] = "插件目录已不存在,已保持卸载状态"
+ failed_plugins[target_plugin_id] = "插件目录已不存在"
continue
reload_candidates[target_plugin_id] = candidate
load_order, dependency_failures = self._loader.resolve_dependencies(
reload_candidates,
- extra_available=retained_plugin_ids,
+ extra_available=retained_plugin_ids | normalized_external_available,
)
failed_plugins.update(dependency_failures)
- available_plugins = set(retained_plugin_ids)
+ available_plugins = set(retained_plugin_ids) | normalized_external_available
reloaded_plugins: List[str] = []
for load_plugin_id in load_order:
@@ -656,7 +713,48 @@ class PluginRunner:
available_plugins.add(load_plugin_id)
reloaded_plugins.append(load_plugin_id)
- requested_plugin_success = plugin_id in reloaded_plugins and not failed_plugins
+ if failed_plugins:
+ rollback_failures: Dict[str, str] = {}
+
+ for reloaded_plugin_id in reversed(reloaded_plugins):
+ reloaded_meta = self._loader.get_plugin(reloaded_plugin_id)
+ if reloaded_meta is None:
+ continue
+
+ try:
+ await self._unload_plugin(
+ reloaded_meta,
+ reason=f"{reason}_rollback_cleanup",
+ purge_modules=False,
+ )
+ except Exception as exc:
+ rollback_failures[reloaded_plugin_id] = f"清理失败: {exc}"
+ finally:
+ self._loader.purge_plugin_modules(reloaded_plugin_id, reloaded_meta.plugin_dir)
+
+ for rollback_plugin_id in reversed(unload_order):
+ rollback_meta = rollback_metas.get(rollback_plugin_id)
+ if rollback_meta is None:
+ continue
+
+ try:
+ restored = await self._activate_plugin(rollback_meta)
+ except Exception as exc:
+ rollback_failures[rollback_plugin_id] = str(exc)
+ continue
+
+ if not restored:
+ rollback_failures[rollback_plugin_id] = "无法重新激活旧版本"
+
+ return ReloadPluginResultPayload(
+ success=False,
+ requested_plugin_id=plugin_id,
+ reloaded_plugins=[],
+ unloaded_plugins=unloaded_plugins,
+ failed_plugins=self._finalize_failed_reload_messages(failed_plugins, rollback_failures),
+ )
+
+ requested_plugin_success = plugin_id in reloaded_plugins
return ReloadPluginResultPayload(
success=requested_plugin_success,
@@ -978,7 +1076,11 @@ class PluginRunner:
)
async with self._reload_lock:
- result = await self._reload_plugin_by_id(payload.plugin_id, payload.reason)
+ result = await self._reload_plugin_by_id(
+ payload.plugin_id,
+ payload.reason,
+ external_available_plugins=set(payload.external_available_plugins),
+ )
return envelope.make_response(payload=result.model_dump())
def request_capability(self) -> RPCClient:
@@ -1073,6 +1175,7 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
async def _async_main() -> None:
"""异步主入口"""
host_address = os.environ.get(ENV_IPC_ADDRESS, "")
+ external_plugin_ids_raw = os.environ.get(ENV_EXTERNAL_PLUGIN_IDS, "")
session_token = os.environ.get(ENV_SESSION_TOKEN, "")
plugin_dirs_str = os.environ.get(ENV_PLUGIN_DIRS, "")
@@ -1081,11 +1184,24 @@ async def _async_main() -> None:
sys.exit(1)
plugin_dirs = [d for d in plugin_dirs_str.split(os.pathsep) if d]
+ try:
+ external_plugin_ids = json.loads(external_plugin_ids_raw) if external_plugin_ids_raw else []
+ except json.JSONDecodeError:
+ logger.warning("解析外部依赖插件列表失败,已回退为空列表")
+ external_plugin_ids = []
+ if not isinstance(external_plugin_ids, list):
+ logger.warning("外部依赖插件列表格式非法,已回退为空列表")
+ external_plugin_ids = []
# sys.path 隔离: 只保留标准库、SDK 包、插件目录
_isolate_sys_path(plugin_dirs)
- runner = PluginRunner(host_address, session_token, plugin_dirs)
+ runner = PluginRunner(
+ host_address,
+ session_token,
+ plugin_dirs,
+ external_available_plugin_ids=[str(plugin_id) for plugin_id in external_plugin_ids],
+ )
# 注册信号处理
def _mark_runner_shutting_down() -> None:
From 1f02171a635e555b62bff2480412c6a4c44a3ce5 Mon Sep 17 00:00:00 2001
From: DrSmoothl <1787882683@qq.com>
Date: Mon, 23 Mar 2026 22:59:01 +0800
Subject: [PATCH 37/45] Refactor plugin loader and runner to support enhanced
manifest structure
- Updated the PluginMeta class to utilize a strongly typed PluginManifest, improving type safety and clarity.
- Refactored dependency extraction logic to streamline the handling of plugin dependencies.
- Modified the PluginLoader to accommodate new manifest versioning and validation processes.
- Enhanced the PluginRunner to work with a dictionary for external available plugins, allowing for version mapping.
- Updated built-in plugins' manifest files to version 2, adding URLs and SDK versioning for better integration and documentation.
- Improved error handling and logging for plugin loading and dependency resolution processes.
---
plugins/ChatFrequency/_manifest.json | 76 +-
plugins/MaiBot_MCPBridgePlugin/_manifest.json | 83 +-
plugins/emoji_manage_plugin/_manifest.json | 88 +-
plugins/hello_world_plugin/_manifest.json | 107 +-
pytests/test_plugin_runtime.py | 387 ++++--
src/plugin_runtime/__init__.py | 2 +-
src/plugin_runtime/capabilities/components.py | 4 +-
src/plugin_runtime/host/supervisor.py | 41 +-
src/plugin_runtime/integration.py | 114 +-
src/plugin_runtime/protocol/envelope.py | 7 +-
.../runner/manifest_validator.py | 1113 +++++++++++++++--
src/plugin_runtime/runner/plugin_loader.py | 166 ++-
src/plugin_runtime/runner/runner_main.py | 79 +-
.../built_in/emoji_plugin/_manifest.json | 43 +-
.../built_in/plugin_management/_manifest.json | 77 +-
15 files changed, 1676 insertions(+), 711 deletions(-)
diff --git a/plugins/ChatFrequency/_manifest.json b/plugins/ChatFrequency/_manifest.json
index 241242ed..56417665 100644
--- a/plugins/ChatFrequency/_manifest.json
+++ b/plugins/ChatFrequency/_manifest.json
@@ -1,58 +1,40 @@
{
- "manifest_version": 1,
- "name": "发言频率控制插件|BetterFrequency Plugin",
+ "manifest_version": 2,
"version": "2.0.0",
- "description": "控制聊天频率,支持设置focus_value和talk_frequency调整值,提供命令",
+ "name": "发言频率控制插件|BetterFrequency Plugin",
+ "description": "控制聊天频率,支持设置 focus_value 和 talk_frequency 调整值,并提供命令入口。",
"author": {
"name": "SengokuCola",
"url": "https://github.com/MaiM-with-u"
},
"license": "GPL-v3.0-or-later",
- "host_application": {
- "min_version": "1.0.0"
+ "urls": {
+ "repository": "https://github.com/SengokuCola/BetterFrequency",
+ "homepage": "https://github.com/SengokuCola/BetterFrequency",
+ "documentation": "https://github.com/SengokuCola/BetterFrequency",
+ "issues": "https://github.com/SengokuCola/BetterFrequency/issues"
},
- "homepage_url": "https://github.com/SengokuCola/BetterFrequency",
- "repository_url": "https://github.com/SengokuCola/BetterFrequency",
- "keywords": [
- "frequency",
- "control",
- "talk_frequency",
- "plugin",
- "shortcut"
+ "host_application": {
+ "min_version": "1.0.0",
+ "max_version": "1.0.0"
+ },
+ "sdk": {
+ "min_version": "2.0.0",
+ "max_version": "2.99.99"
+ },
+ "dependencies": [],
+ "capabilities": [
+ "send.text",
+ "frequency.set_adjust",
+ "frequency.get_current_talk_value",
+ "frequency.get_adjust"
],
- "categories": [
- "Chat",
- "Frequency",
- "Control"
- ],
- "default_locale": "zh-CN",
- "locales_path": "_locales",
- "plugin_info": {
- "is_built_in": false,
- "plugin_type": "frequency",
- "components": [
- {
- "type": "command",
- "name": "set_talk_frequency",
- "description": "设置当前聊天的talk_frequency调整值",
- "pattern": "/chat talk_frequency <数字> 或 /chat t <数字>"
- },
- {
- "type": "command",
- "name": "show_frequency",
- "description": "显示当前聊天的频率控制状态",
- "pattern": "/chat show 或 /chat s"
- }
- ],
- "features": [
- "设置talk_frequency调整值",
- "调整当前聊天的发言频率",
- "显示当前频率控制状态",
- "实时频率控制调整",
- "命令执行反馈(不保存消息)",
- "支持完整命令和简化命令",
- "快速操作支持"
+ "i18n": {
+ "default_locale": "zh-CN",
+ "locales_path": "_locales",
+ "supported_locales": [
+ "zh-CN"
]
},
- "id": "SengokuCola.BetterFrequency"
-}
\ No newline at end of file
+ "id": "sengokucola.betterfrequency"
+}
diff --git a/plugins/MaiBot_MCPBridgePlugin/_manifest.json b/plugins/MaiBot_MCPBridgePlugin/_manifest.json
index 85225a43..d2e08ab4 100644
--- a/plugins/MaiBot_MCPBridgePlugin/_manifest.json
+++ b/plugins/MaiBot_MCPBridgePlugin/_manifest.json
@@ -1,67 +1,42 @@
{
- "manifest_version": 1,
- "name": "MCP桥接插件",
+ "manifest_version": 2,
"version": "2.0.0",
- "description": "将 MCP (Model Context Protocol) 服务器的工具桥接到 MaiBot,使麦麦能够调用外部 MCP 工具",
+ "name": "MCP桥接插件",
+ "description": "将 MCP (Model Context Protocol) 服务器的工具桥接到 MaiBot,使麦麦能够调用外部 MCP 工具。",
"author": {
"name": "CharTyr",
"url": "https://github.com/CharTyr"
},
"license": "AGPL-3.0",
- "host_application": {
- "min_version": "0.11.6"
+ "urls": {
+ "repository": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin",
+ "homepage": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin",
+ "documentation": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin",
+ "issues": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin/issues"
},
- "homepage_url": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin",
- "repository_url": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin",
- "keywords": [
- "mcp",
- "bridge",
- "tool",
- "integration",
- "resources",
- "prompts",
- "post-process",
- "cache",
- "trace",
- "permissions",
- "import",
- "export",
- "claude-desktop",
- "workflow",
- "react",
- "agent"
+ "host_application": {
+ "min_version": "0.11.6",
+ "max_version": "1.0.0"
+ },
+ "sdk": {
+ "min_version": "2.0.0",
+ "max_version": "2.99.99"
+ },
+ "dependencies": [
+ {
+ "type": "python_package",
+ "name": "mcp",
+ "version_spec": ">=0.0.0"
+ }
],
- "categories": [
- "工具扩展",
- "外部集成"
+ "capabilities": [
+ "send.text"
],
- "default_locale": "zh-CN",
- "plugin_info": {
- "is_built_in": false,
- "components": [],
- "features": [
- "支持多个 MCP 服务器",
- "自动发现并注册 MCP 工具",
- "支持 stdio、SSE、HTTP、Streamable HTTP 四种传输方式",
- "工具参数自动转换",
- "心跳检测与自动重连",
- "调用统计(次数、成功率、耗时)",
- "WebUI 配置支持",
- "Resources 支持(实验性)",
- "Prompts 支持(实验性)",
- "结果后处理(LLM 摘要提炼)",
- "工具禁用管理",
- "调用链路追踪",
- "工具调用缓存(LRU)",
- "工具权限控制(群/用户级别)",
- "配置导入导出(Claude Desktop mcpServers)",
- "断路器模式(故障快速失败)",
- "状态实时刷新",
- "Workflow 硬流程(顺序执行多个工具)",
- "Workflow 快速添加(表单式配置)",
- "ReAct 软流程(LLM 自主多轮调用)",
- "双轨制架构(软流程 + 硬流程)"
+ "i18n": {
+ "default_locale": "zh-CN",
+ "supported_locales": [
+ "zh-CN"
]
},
- "id": "MaiBot Community.MCPBridgePlugin"
+ "id": "chartyr.mcpbridge-plugin"
}
diff --git a/plugins/emoji_manage_plugin/_manifest.json b/plugins/emoji_manage_plugin/_manifest.json
index 3af69023..998cb7da 100644
--- a/plugins/emoji_manage_plugin/_manifest.json
+++ b/plugins/emoji_manage_plugin/_manifest.json
@@ -1,68 +1,44 @@
{
- "manifest_version": 1,
- "name": "BetterEmoji",
+ "manifest_version": 2,
"version": "2.0.0",
+ "name": "BetterEmoji",
"description": "更好的表情包管理插件",
"author": {
"name": "SengokuCola",
"url": "https://github.com/SengokuCola"
},
"license": "GPL-v3.0-or-later",
- "host_application": {
- "min_version": "1.0.0"
+ "urls": {
+ "repository": "https://github.com/SengokuCola/BetterEmoji",
+ "homepage": "https://github.com/SengokuCola/BetterEmoji",
+ "documentation": "https://github.com/SengokuCola/BetterEmoji",
+ "issues": "https://github.com/SengokuCola/BetterEmoji/issues"
},
- "homepage_url": "https://github.com/SengokuCola/BetterEmoji",
- "repository_url": "https://github.com/SengokuCola/BetterEmoji",
- "keywords": [
- "emoji",
- "manage",
- "plugin"
+ "host_application": {
+ "min_version": "1.0.0",
+ "max_version": "1.0.0"
+ },
+ "sdk": {
+ "min_version": "2.0.0",
+ "max_version": "2.99.99"
+ },
+ "dependencies": [],
+ "capabilities": [
+ "emoji.get_random",
+ "emoji.get_count",
+ "emoji.get_info",
+ "emoji.get_all",
+ "emoji.register_emoji",
+ "emoji.delete_emoji",
+ "send.text",
+ "send.forward"
],
- "categories": [
- "Emoji",
- "Management"
- ],
- "default_locale": "zh-CN",
- "locales_path": "_locales",
- "plugin_info": {
- "is_built_in": false,
- "plugin_type": "emoji_manage",
- "capabilities": [
- "emoji.get_random",
- "emoji.get_count",
- "emoji.get_info",
- "emoji.get_all",
- "emoji.register_emoji",
- "emoji.delete_emoji",
- "send.text",
- "send.forward"
- ],
- "components": [
- {
- "type": "command",
- "name": "add_emoji",
- "description": "添加表情包",
- "pattern": "/emoji add"
- },
- {
- "type": "command",
- "name": "emoji_list",
- "description": "列表表情包",
- "pattern": "/emoji list"
- },
- {
- "type": "command",
- "name": "delete_emoji",
- "description": "删除表情包",
- "pattern": "/emoji delete"
- },
- {
- "type": "command",
- "name": "random_emojis",
- "description": "发送多张随机表情包",
- "pattern": "/random_emojis"
- }
+ "i18n": {
+ "default_locale": "zh-CN",
+ "locales_path": "_locales",
+ "supported_locales": [
+ "zh-CN"
]
},
- "id": "SengokuCola.BetterEmoji"
-}
\ No newline at end of file
+ "id": "sengokucola.betteremoji"
+}
diff --git a/plugins/hello_world_plugin/_manifest.json b/plugins/hello_world_plugin/_manifest.json
index dc9fc474..e2bc694d 100644
--- a/plugins/hello_world_plugin/_manifest.json
+++ b/plugins/hello_world_plugin/_manifest.json
@@ -1,88 +1,41 @@
{
- "manifest_version": 1,
- "name": "Hello World 示例插件 (Hello World Plugin)",
+ "manifest_version": 2,
"version": "2.0.0",
- "description": "我的第一个MaiCore插件,包含问候功能和时间查询等基础示例",
+ "name": "Hello World 示例插件 (Hello World Plugin)",
+ "description": "我的第一个 MaiCore 插件,包含问候功能和时间查询等基础示例",
"author": {
"name": "MaiBot开发团队",
"url": "https://github.com/MaiM-with-u"
},
"license": "GPL-v3.0-or-later",
- "host_application": {
- "min_version": "1.0.0"
+ "urls": {
+ "repository": "https://github.com/MaiM-with-u/maibot",
+ "homepage": "https://github.com/MaiM-with-u/maibot",
+ "documentation": "https://github.com/MaiM-with-u/maibot",
+ "issues": "https://github.com/MaiM-with-u/maibot/issues"
},
- "homepage_url": "https://github.com/MaiM-with-u/maibot",
- "repository_url": "https://github.com/MaiM-with-u/maibot",
- "keywords": [
- "demo",
- "example",
- "hello",
- "greeting",
- "tutorial"
+ "host_application": {
+ "min_version": "1.0.0",
+ "max_version": "1.0.0"
+ },
+ "sdk": {
+ "min_version": "2.0.0",
+ "max_version": "2.99.99"
+ },
+ "dependencies": [],
+ "capabilities": [
+ "send.text",
+ "send.forward",
+ "send.hybrid",
+ "emoji.get_random",
+ "config.get"
],
- "categories": [
- "Examples",
- "Tutorial"
- ],
- "default_locale": "zh-CN",
- "locales_path": "_locales",
- "plugin_info": {
- "is_built_in": false,
- "plugin_type": "example",
- "capabilities": [
- "send.text",
- "send.forward",
- "send.hybrid",
- "emoji.get_random",
- "config.get"
- ],
- "components": [
- {
- "type": "tool",
- "name": "compare_numbers",
- "description": "比较两个数的大小"
- },
- {
- "type": "action",
- "name": "hello_greeting",
- "description": "向用户发送问候消息"
- },
- {
- "type": "action",
- "name": "bye_greeting",
- "description": "向用户发送告别消息",
- "activation_modes": ["keyword"],
- "keywords": ["再见", "bye", "88", "拜拜"]
- },
- {
- "type": "command",
- "name": "time",
- "description": "查询当前时间",
- "pattern": "/time"
- },
- {
- "type": "command",
- "name": "random_emojis",
- "description": "发送多张随机表情包",
- "pattern": "/random_emojis"
- },
- {
- "type": "command",
- "name": "test",
- "description": "测试命令",
- "pattern": "/test"
- },
- {
- "type": "event_handler",
- "name": "print_message_handler",
- "description": "打印接收到的消息"
- },
- {
- "type": "event_handler",
- "name": "forward_messages_handler",
- "description": "把接收到的消息转发到指定聊天ID"
- }
+ "i18n": {
+ "default_locale": "zh-CN",
+ "locales_path": "_locales",
+ "supported_locales": [
+ "zh-CN"
]
},
- "id": "MaiBot开发团队.maibot"
-}
\ No newline at end of file
+ "id": "maibot-team.hello-world-plugin"
+}
diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py
index e094d85b..1d93ae24 100644
--- a/pytests/test_plugin_runtime.py
+++ b/pytests/test_plugin_runtime.py
@@ -19,6 +19,104 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "packages", "maibot-plugin-sdk"))
+def build_test_manifest(
+ plugin_id: str,
+ *,
+ version: str = "1.0.0",
+ name: str = "测试插件",
+ description: str = "测试插件描述",
+ dependencies: list[dict[str, str]] | None = None,
+ capabilities: list[str] | None = None,
+ host_min_version: str = "0.12.0",
+ host_max_version: str = "1.0.0",
+ sdk_min_version: str = "2.0.0",
+ sdk_max_version: str = "2.99.99",
+) -> dict[str, object]:
+ """构造一个合法的 Manifest v2 测试样例。
+
+ Args:
+ plugin_id: 插件 ID。
+ version: 插件版本。
+ name: 展示名称。
+ description: 插件描述。
+ dependencies: 依赖声明列表。
+ capabilities: 能力声明列表。
+ host_min_version: Host 最低支持版本。
+ host_max_version: Host 最高支持版本。
+ sdk_min_version: SDK 最低支持版本。
+ sdk_max_version: SDK 最高支持版本。
+
+ Returns:
+ dict[str, object]: 可直接序列化为 ``_manifest.json`` 的字典。
+ """
+ return {
+ "manifest_version": 2,
+ "version": version,
+ "name": name,
+ "description": description,
+ "author": {
+ "name": "tester",
+ "url": "https://example.com/tester",
+ },
+ "license": "MIT",
+ "urls": {
+ "repository": f"https://example.com/{plugin_id}",
+ },
+ "host_application": {
+ "min_version": host_min_version,
+ "max_version": host_max_version,
+ },
+ "sdk": {
+ "min_version": sdk_min_version,
+ "max_version": sdk_max_version,
+ },
+ "dependencies": dependencies or [],
+ "capabilities": capabilities or [],
+ "i18n": {
+ "default_locale": "zh-CN",
+ "supported_locales": ["zh-CN"],
+ },
+ "id": plugin_id,
+ }
+
+
+def build_test_manifest_model(
+ plugin_id: str,
+ *,
+ version: str = "1.0.0",
+ dependencies: list[dict[str, str]] | None = None,
+ capabilities: list[str] | None = None,
+ host_version: str = "1.0.0",
+ sdk_version: str = "2.0.1",
+) -> object:
+ """构造一个已经通过校验的强类型 Manifest 测试对象。
+
+ Args:
+ plugin_id: 插件 ID。
+ version: 插件版本。
+ dependencies: 依赖声明列表。
+ capabilities: 能力声明列表。
+ host_version: 当前测试使用的 Host 版本。
+ sdk_version: 当前测试使用的 SDK 版本。
+
+ Returns:
+ object: ``PluginManifest`` 实例。
+ """
+ from src.plugin_runtime.runner.manifest_validator import ManifestValidator
+
+ validator = ManifestValidator(host_version=host_version, sdk_version=sdk_version)
+ manifest = validator.parse_manifest(
+ build_test_manifest(
+ plugin_id,
+ version=version,
+ dependencies=dependencies,
+ capabilities=capabilities,
+ )
+ )
+ assert manifest is not None
+ return manifest
+
+
# ─── 协议层测试 ───────────────────────────────────────────
@@ -759,65 +857,77 @@ class TestManifestValidator:
def test_valid_manifest(self):
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
- validator = ManifestValidator()
- manifest = {
- "manifest_version": 1,
- "name": "test_plugin",
- "version": "1.0.0",
- "description": "测试插件",
- "author": "test",
- }
+ validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1")
+ manifest = build_test_manifest("test.valid-plugin", capabilities=["send.text"])
assert validator.validate(manifest) is True
assert len(validator.errors) == 0
+ assert validator.warnings == []
def test_missing_required_fields(self):
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
- validator = ManifestValidator()
- manifest = {"manifest_version": 1}
+ validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1")
+ manifest = {"manifest_version": 2}
assert validator.validate(manifest) is False
- assert len(validator.errors) >= 4 # name, version, description, author
+ assert len(validator.errors) >= 6
+ assert any("缺少必需字段" in error for error in validator.errors)
def test_unsupported_manifest_version(self):
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
- validator = ManifestValidator()
- manifest = {
- "manifest_version": 999,
- "name": "test",
- "version": "1.0",
- "description": "d",
- "author": "a",
- }
+ validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1")
+ manifest = build_test_manifest("test.invalid-version")
+ manifest["manifest_version"] = 999
assert validator.validate(manifest) is False
assert any("manifest_version" in e for e in validator.errors)
def test_host_version_compatibility(self):
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
- validator = ManifestValidator(host_version="0.8.5")
- manifest = {
- "name": "test",
- "version": "1.0",
- "description": "d",
- "author": "a",
- "host_application": {"min_version": "0.9.0"},
- }
+ validator = ManifestValidator(host_version="0.8.5", sdk_version="2.0.1")
+ manifest = build_test_manifest(
+ "test.host-check",
+ host_min_version="0.9.0",
+ host_max_version="1.0.0",
+ )
assert validator.validate(manifest) is False
assert any("Host 版本不兼容" in e for e in validator.errors)
- def test_recommended_fields_warning(self):
+ def test_sdk_version_compatibility(self):
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
- validator = ManifestValidator()
- manifest = {
- "name": "test",
- "version": "1.0",
- "description": "d",
- "author": "a",
- }
- validator.validate(manifest)
- assert len(validator.warnings) >= 3 # license, keywords, categories
+ validator = ManifestValidator(host_version="1.0.0", sdk_version="1.9.9")
+ manifest = build_test_manifest("test.sdk-check")
+ assert validator.validate(manifest) is False
+ assert any("SDK 版本不兼容" in e for e in validator.errors)
+
+ def test_extra_fields_are_rejected(self):
+ from src.plugin_runtime.runner.manifest_validator import ManifestValidator
+
+ validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1")
+ manifest = build_test_manifest("test.extra-field")
+ manifest["unexpected"] = True
+
+ assert validator.validate(manifest) is False
+ assert any("存在未声明字段" in error for error in validator.errors)
+
+ def test_python_package_conflict_rejects_manifest(self):
+ from src.plugin_runtime.runner.manifest_validator import ManifestValidator
+
+ validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1")
+ manifest = build_test_manifest(
+ "test.numpy-conflict",
+ dependencies=[
+ {
+ "type": "python_package",
+ "name": "numpy",
+ "version_spec": ">=999.0.0",
+ }
+ ],
+ )
+
+ assert validator.validate(manifest) is False
+ assert any("Python 包依赖冲突" in error for error in validator.errors)
class TestVersionComparator:
@@ -859,59 +969,83 @@ class TestDependencyResolution:
loader = PluginLoader()
candidates = {
- "core": ("dir_core", {"name": "core", "version": "1.0", "description": "d", "author": "a"}, "plugin.py"),
- "auth": (
- "dir_auth",
- {"name": "auth", "version": "1.0", "description": "d", "author": "a", "dependencies": ["core"]},
+ "test.core": (
+ "dir_core",
+ build_test_manifest_model("test.core"),
"plugin.py",
),
- "api": (
+ "test.auth": (
+ "dir_auth",
+ build_test_manifest_model(
+ "test.auth",
+ dependencies=[
+ {"type": "plugin", "id": "test.core", "version_spec": ">=1.0.0,<2.0.0"},
+ ],
+ ),
+ "plugin.py",
+ ),
+ "test.api": (
"dir_api",
- {"name": "api", "version": "1.0", "description": "d", "author": "a", "dependencies": ["core", "auth"]},
+ build_test_manifest_model(
+ "test.api",
+ dependencies=[
+ {"type": "plugin", "id": "test.core", "version_spec": ">=1.0.0,<2.0.0"},
+ {"type": "plugin", "id": "test.auth", "version_spec": ">=1.0.0,<2.0.0"},
+ ],
+ ),
"plugin.py",
),
}
order, failed = loader._resolve_dependencies(candidates)
assert len(failed) == 0
- assert order.index("core") < order.index("auth")
- assert order.index("auth") < order.index("api")
+ assert order.index("test.core") < order.index("test.auth")
+ assert order.index("test.auth") < order.index("test.api")
def test_missing_dependency(self):
from src.plugin_runtime.runner.plugin_loader import PluginLoader
loader = PluginLoader()
candidates = {
- "plugin_a": (
+ "test.plugin-a": (
"dir_a",
- {
- "name": "plugin_a",
- "version": "1.0",
- "description": "d",
- "author": "a",
- "dependencies": ["nonexistent"],
- },
+ build_test_manifest_model(
+ "test.plugin-a",
+ dependencies=[
+ {"type": "plugin", "id": "test.nonexistent", "version_spec": ">=1.0.0,<2.0.0"},
+ ],
+ ),
"plugin.py",
),
}
order, failed = loader._resolve_dependencies(candidates)
- assert "plugin_a" in failed
- assert "缺少依赖" in failed["plugin_a"]
+ assert "test.plugin-a" in failed
+ assert "依赖未满足" in failed["test.plugin-a"]
def test_circular_dependency(self):
from src.plugin_runtime.runner.plugin_loader import PluginLoader
loader = PluginLoader()
candidates = {
- "a": (
+ "test.a": (
"dir_a",
- {"name": "a", "version": "1.0", "description": "d", "author": "x", "dependencies": ["b"]},
+ build_test_manifest_model(
+ "test.a",
+ dependencies=[
+ {"type": "plugin", "id": "test.b", "version_spec": ">=1.0.0,<2.0.0"},
+ ],
+ ),
"p.py",
),
- "b": (
+ "test.b": (
"dir_b",
- {"name": "b", "version": "1.0", "description": "d", "author": "x", "dependencies": ["a"]},
+ build_test_manifest_model(
+ "test.b",
+ dependencies=[
+ {"type": "plugin", "id": "test.a", "version_spec": ">=1.0.0,<2.0.0"},
+ ],
+ ),
"p.py",
),
}
@@ -929,12 +1063,11 @@ class TestDependencyResolution:
(plugin_dir / "_manifest.json").write_text(
json.dumps(
- {
- "name": "grok_search_plugin",
- "version": "1.0.0",
- "description": "demo",
- "author": "tester",
- }
+ build_test_manifest(
+ "test.grok-search-plugin",
+ name="grok_search_plugin",
+ description="demo",
+ )
),
encoding="utf-8",
)
@@ -954,7 +1087,7 @@ class TestDependencyResolution:
loader = PluginLoader()
loaded = loader.discover_and_load([str(plugin_root)])
- assert [meta.plugin_id for meta in loaded] == ["grok_search_plugin"]
+ assert [meta.plugin_id for meta in loaded] == ["test.grok-search-plugin"]
assert loader.failed_plugins == {}
assert loaded[0].instance.answer() == 42
@@ -968,12 +1101,11 @@ class TestDependencyResolution:
(plugin_dir / "_manifest.json").write_text(
json.dumps(
- {
- "name": "demo_plugin",
- "version": "1.0.0",
- "description": "demo",
- "author": "tester",
- }
+ build_test_manifest(
+ "test.demo-plugin",
+ name="demo_plugin",
+ description="demo",
+ )
),
encoding="utf-8",
)
@@ -993,8 +1125,8 @@ class TestDependencyResolution:
loaded = loader.discover_and_load([str(plugin_root)])
assert loaded == []
- assert "demo_plugin" in loader.failed_plugins
- assert "on_config_update" in loader.failed_plugins["demo_plugin"]
+ assert "test.demo-plugin" in loader.failed_plugins
+ assert "on_config_update" in loader.failed_plugins["test.demo-plugin"]
def test_loader_requires_sdk_plugin_to_override_on_load(self, tmp_path):
from src.plugin_runtime.runner.plugin_loader import PluginLoader
@@ -1006,12 +1138,11 @@ class TestDependencyResolution:
(plugin_dir / "_manifest.json").write_text(
json.dumps(
- {
- "name": "demo_plugin",
- "version": "1.0.0",
- "description": "demo",
- "author": "tester",
- }
+ build_test_manifest(
+ "test.demo-plugin",
+ name="demo_plugin",
+ description="demo",
+ )
),
encoding="utf-8",
)
@@ -1031,8 +1162,8 @@ class TestDependencyResolution:
loaded = loader.discover_and_load([str(plugin_root)])
assert loaded == []
- assert "demo_plugin" in loader.failed_plugins
- assert "on_load" in loader.failed_plugins["demo_plugin"]
+ assert "test.demo-plugin" in loader.failed_plugins
+ assert "on_load" in loader.failed_plugins["test.demo-plugin"]
def test_loader_requires_sdk_plugin_to_override_on_unload(self, tmp_path):
from src.plugin_runtime.runner.plugin_loader import PluginLoader
@@ -1044,12 +1175,11 @@ class TestDependencyResolution:
(plugin_dir / "_manifest.json").write_text(
json.dumps(
- {
- "name": "demo_plugin",
- "version": "1.0.0",
- "description": "demo",
- "author": "tester",
- }
+ build_test_manifest(
+ "test.demo-plugin",
+ name="demo_plugin",
+ description="demo",
+ )
),
encoding="utf-8",
)
@@ -1069,8 +1199,8 @@ class TestDependencyResolution:
loaded = loader.discover_and_load([str(plugin_root)])
assert loaded == []
- assert "demo_plugin" in loader.failed_plugins
- assert "on_unload" in loader.failed_plugins["demo_plugin"]
+ assert "test.demo-plugin" in loader.failed_plugins
+ assert "on_unload" in loader.failed_plugins["test.demo-plugin"]
def test_isolate_sys_path_preserves_plugin_dirs(self):
from src.plugin_runtime.runner import runner_main
@@ -2374,16 +2504,19 @@ class TestIntegration:
def __init__(self, plugin_dirs=None, socket_path=None):
self._plugin_dirs = plugin_dirs or []
self.capability_service = FakeCapabilityService()
- self.external_plugin_ids = []
+ self.external_plugin_versions = {}
self.stopped = False
instances.append(self)
- def set_external_available_plugin_ids(self, plugin_ids):
- self.external_plugin_ids = list(plugin_ids)
+ def set_external_available_plugins(self, plugin_versions):
+ self.external_plugin_versions = dict(plugin_versions)
def get_loaded_plugin_ids(self):
return []
+ def get_loaded_plugin_versions(self):
+ return {}
+
async def start(self):
if len(instances) == 2 and self is instances[1]:
raise RuntimeError("boom")
@@ -2425,8 +2558,8 @@ class TestIntegration:
(beta_dir / "config.toml").write_text("enabled = false\n", encoding="utf-8")
(alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
(beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
- (alpha_dir / "_manifest.json").write_text(json.dumps({"name": "alpha"}), encoding="utf-8")
- (beta_dir / "_manifest.json").write_text(json.dumps({"name": "beta"}), encoding="utf-8")
+ (alpha_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.alpha")), encoding="utf-8")
+ (beta_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.beta")), encoding="utf-8")
monkeypatch.chdir(tmp_path)
@@ -2440,8 +2573,11 @@ class TestIntegration:
def get_loaded_plugin_ids(self):
return sorted(self._registered_plugins.keys())
+ def get_loaded_plugin_versions(self):
+ return {plugin_id: "1.0.0" for plugin_id in self._registered_plugins}
+
async def reload_plugins(self, plugin_ids=None, reason="manual", external_available_plugins=None):
- self.reload_reasons.append((plugin_ids, reason, external_available_plugins or []))
+ self.reload_reasons.append((plugin_ids, reason, external_available_plugins or {}))
async def notify_plugin_config_updated(self, plugin_id, config_data, config_version=""):
self.config_updates.append((plugin_id, config_data, config_version))
@@ -2449,8 +2585,8 @@ class TestIntegration:
manager = integration_module.PluginRuntimeManager()
manager._started = True
- manager._builtin_supervisor = FakeSupervisor([builtin_root], {"alpha": object()})
- manager._third_party_supervisor = FakeSupervisor([thirdparty_root], {"beta": object()})
+ manager._builtin_supervisor = FakeSupervisor([builtin_root], {"test.alpha": object()})
+ manager._third_party_supervisor = FakeSupervisor([thirdparty_root], {"test.beta": object()})
changes = [
FileChange(change_type=1, path=beta_dir / "plugin.py"),
@@ -2466,7 +2602,9 @@ class TestIntegration:
await manager._handle_plugin_source_changes(changes)
assert manager._builtin_supervisor.reload_reasons == []
- assert manager._third_party_supervisor.reload_reasons == [(["beta"], "file_watcher", ["alpha"])]
+ assert manager._third_party_supervisor.reload_reasons == [
+ (["test.beta"], "file_watcher", {"test.alpha": "1.0.0"})
+ ]
assert manager._builtin_supervisor.config_updates == []
assert manager._third_party_supervisor.config_updates == []
assert refresh_calls == [True]
@@ -2487,15 +2625,18 @@ class TestIntegration:
def get_loaded_plugin_ids(self):
return sorted(self._registered_plugins.keys())
+ def get_loaded_plugin_versions(self):
+ return {plugin_id: "1.0.0" for plugin_id in self._registered_plugins}
+
async def reload_plugins(self, plugin_ids=None, reason="manual", external_available_plugins=None):
- self.reload_calls.append((plugin_ids, reason, sorted(external_available_plugins or [])))
+ self.reload_calls.append((plugin_ids, reason, dict(sorted((external_available_plugins or {}).items()))))
return True
- builtin_supervisor = FakeSupervisor({"alpha": FakeRegistration([])})
+ builtin_supervisor = FakeSupervisor({"test.alpha": FakeRegistration([])})
third_party_supervisor = FakeSupervisor(
{
- "beta": FakeRegistration(["alpha"]),
- "gamma": FakeRegistration(["beta"]),
+ "test.beta": FakeRegistration(["test.alpha"]),
+ "test.gamma": FakeRegistration(["test.beta"]),
}
)
@@ -2510,13 +2651,15 @@ class TestIntegration:
lambda message: warning_messages.append(message),
)
- reloaded = await manager.reload_plugins_globally(["alpha"], reason="manual")
+ reloaded = await manager.reload_plugins_globally(["test.alpha"], reason="manual")
assert reloaded is True
- assert builtin_supervisor.reload_calls == [(["alpha"], "manual", ["beta", "gamma"])]
+ assert builtin_supervisor.reload_calls == [
+ (["test.alpha"], "manual", {"test.beta": "1.0.0", "test.gamma": "1.0.0"})
+ ]
assert third_party_supervisor.reload_calls == []
assert len(warning_messages) == 1
- assert "beta, gamma" in warning_messages[0]
+ assert "test.beta, test.gamma" in warning_messages[0]
assert "跨 Supervisor API 调用仍然可用" in warning_messages[0]
@pytest.mark.asyncio
@@ -2535,8 +2678,8 @@ class TestIntegration:
(beta_dir / "config.toml").write_text("enabled = false\n", encoding="utf-8")
(alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
(beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
- (alpha_dir / "_manifest.json").write_text(json.dumps({"name": "alpha"}), encoding="utf-8")
- (beta_dir / "_manifest.json").write_text(json.dumps({"name": "beta"}), encoding="utf-8")
+ (alpha_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.alpha")), encoding="utf-8")
+ (beta_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.beta")), encoding="utf-8")
monkeypatch.chdir(tmp_path)
@@ -2558,15 +2701,15 @@ class TestIntegration:
manager = integration_module.PluginRuntimeManager()
manager._started = True
- manager._builtin_supervisor = FakeSupervisor([builtin_root], ["alpha"])
- manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["beta"])
+ manager._builtin_supervisor = FakeSupervisor([builtin_root], ["test.alpha"])
+ manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["test.beta"])
await manager._handle_plugin_config_changes(
- "alpha",
+ "test.alpha",
[FileChange(change_type=1, path=alpha_dir / "config.toml")],
)
- assert manager._builtin_supervisor.config_updates == [("alpha", {"enabled": True}, "", "self")]
+ assert manager._builtin_supervisor.config_updates == [("test.alpha", {"enabled": True}, "", "self")]
assert manager._third_party_supervisor.config_updates == []
@pytest.mark.asyncio
@@ -2615,23 +2758,23 @@ class TestIntegration:
manager._started = True
manager._builtin_supervisor = FakeSupervisor(
{
- "alpha": FakeRegistration(["bot"]),
- "beta": FakeRegistration([]),
+ "test.alpha": FakeRegistration(["bot"]),
+ "test.beta": FakeRegistration([]),
}
)
manager._third_party_supervisor = FakeSupervisor(
{
- "gamma": FakeRegistration(["model"]),
+ "test.gamma": FakeRegistration(["model"]),
}
)
await manager._handle_main_config_reload(["bot", "model"])
assert manager._builtin_supervisor.config_updates == [
- ("alpha", {"bot": {"name": "MaiBot"}}, "", "bot")
+ ("test.alpha", {"bot": {"name": "MaiBot"}}, "", "bot")
]
assert manager._third_party_supervisor.config_updates == [
- ("gamma", {"models": [{"name": "demo"}]}, "", "model")
+ ("test.gamma", {"models": [{"name": "demo"}]}, "", "model")
]
def test_refresh_plugin_config_watch_subscriptions_registers_per_plugin(self, tmp_path):
@@ -2646,8 +2789,8 @@ class TestIntegration:
beta_dir.mkdir(parents=True)
(alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
(beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
- (alpha_dir / "_manifest.json").write_text(json.dumps({"name": "alpha"}), encoding="utf-8")
- (beta_dir / "_manifest.json").write_text(json.dumps({"name": "beta"}), encoding="utf-8")
+ (alpha_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.alpha")), encoding="utf-8")
+ (beta_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.beta")), encoding="utf-8")
class FakeWatcher:
def __init__(self):
@@ -2670,12 +2813,12 @@ class TestIntegration:
manager = integration_module.PluginRuntimeManager()
manager._plugin_file_watcher = FakeWatcher()
- manager._builtin_supervisor = FakeSupervisor([builtin_root], ["alpha"])
- manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["beta"])
+ manager._builtin_supervisor = FakeSupervisor([builtin_root], ["test.alpha"])
+ manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["test.beta"])
manager._refresh_plugin_config_watch_subscriptions()
- assert set(manager._plugin_config_watcher_subscriptions.keys()) == {"alpha", "beta"}
+ assert set(manager._plugin_config_watcher_subscriptions.keys()) == {"test.alpha", "test.beta"}
assert {
subscription["paths"][0] for subscription in manager._plugin_file_watcher.subscriptions
} == {alpha_dir / "config.toml", beta_dir / "config.toml"}
diff --git a/src/plugin_runtime/__init__.py b/src/plugin_runtime/__init__.py
index 704ce514..7f2d789f 100644
--- a/src/plugin_runtime/__init__.py
+++ b/src/plugin_runtime/__init__.py
@@ -18,7 +18,7 @@ ENV_HOST_VERSION = "MAIBOT_HOST_VERSION"
"""Runner 读取的 Host 应用版本号,用于 manifest 兼容性校验"""
ENV_EXTERNAL_PLUGIN_IDS = "MAIBOT_EXTERNAL_PLUGIN_IDS"
-"""Runner 启动时可视为已满足的外部插件依赖列表(JSON 数组)"""
+"""Runner 启动时可视为已满足的外部插件依赖版本映射(JSON 对象)"""
ENV_GLOBAL_CONFIG_SNAPSHOT = "MAIBOT_GLOBAL_CONFIG_SNAPSHOT"
"""Runner 启动时注入的全局配置快照(JSON 对象)"""
diff --git a/src/plugin_runtime/capabilities/components.py b/src/plugin_runtime/capabilities/components.py
index 2e4c111c..33b54c64 100644
--- a/src/plugin_runtime/capabilities/components.py
+++ b/src/plugin_runtime/capabilities/components.py
@@ -191,7 +191,7 @@ class RuntimeComponentCapabilityMixin:
return None, None, "缺少必要参数 api_name"
if "." in normalized_api_name:
- target_plugin_id, target_api_name = normalized_api_name.split(".", 1)
+ target_plugin_id, target_api_name = normalized_api_name.rsplit(".", 1)
try:
supervisor = self._get_supervisor_for_plugin(target_plugin_id)
except RuntimeError as exc:
@@ -282,7 +282,7 @@ class RuntimeComponentCapabilityMixin:
return None, None, "缺少必要参数 name"
if "." in normalized_name:
- plugin_id, api_name = normalized_name.split(".", 1)
+ plugin_id, api_name = normalized_name.rsplit(".", 1)
try:
supervisor = self._get_supervisor_for_plugin(plugin_id)
except RuntimeError as exc:
diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py
index 693eae51..ac953bb3 100644
--- a/src/plugin_runtime/host/supervisor.py
+++ b/src/plugin_runtime/host/supervisor.py
@@ -116,7 +116,7 @@ class PluginRunnerSupervisor:
self._runner_process: Optional[asyncio.subprocess.Process] = None
self._registered_plugins: Dict[str, RegisterPluginPayload] = {}
self._message_gateway_states: Dict[str, Dict[str, _MessageGatewayRuntimeState]] = {}
- self._external_available_plugin_ids: List[str] = []
+ self._external_available_plugins: Dict[str, str] = {}
self._runner_ready_events: asyncio.Event = asyncio.Event()
self._runner_ready_payloads: RunnerReadyPayload = RunnerReadyPayload()
self._health_task: Optional[asyncio.Task[None]] = None
@@ -166,21 +166,34 @@ class PluginRunnerSupervisor:
"""返回底层 RPC 服务端。"""
return self._rpc_server
- def set_external_available_plugin_ids(self, plugin_ids: List[str]) -> None:
- """设置当前 Runner 启动/重载时可视为已满足的外部依赖列表。"""
+ def set_external_available_plugins(self, plugin_versions: Dict[str, str]) -> None:
+ """设置当前 Runner 启动/重载时可视为已满足的外部依赖版本映射。
- normalized_plugin_ids = {
- str(plugin_id or "").strip()
- for plugin_id in plugin_ids
- if str(plugin_id or "").strip()
+ Args:
+ plugin_versions: 外部插件版本映射,键为插件 ID,值为插件版本。
+ """
+ self._external_available_plugins = {
+ str(plugin_id or "").strip(): str(plugin_version or "").strip()
+ for plugin_id, plugin_version in plugin_versions.items()
+ if str(plugin_id or "").strip() and str(plugin_version or "").strip()
}
- self._external_available_plugin_ids = sorted(normalized_plugin_ids)
def get_loaded_plugin_ids(self) -> List[str]:
"""返回当前 Supervisor 已注册的插件 ID 列表。"""
return sorted(self._registered_plugins.keys())
+ def get_loaded_plugin_versions(self) -> Dict[str, str]:
+ """返回当前 Supervisor 已注册插件的版本映射。
+
+ Returns:
+ Dict[str, str]: 已注册插件版本映射,键为插件 ID,值为插件版本。
+ """
+ return {
+ plugin_id: registration.plugin_version
+ for plugin_id, registration in self._registered_plugins.items()
+ }
+
async def dispatch_event(
self,
event_type: str,
@@ -373,14 +386,14 @@ class PluginRunnerSupervisor:
self,
plugin_id: str,
reason: str = "manual",
- external_available_plugins: Optional[List[str]] = None,
+ external_available_plugins: Optional[Dict[str, str]] = None,
) -> bool:
"""按插件 ID 触发精确重载。
Args:
plugin_id: 目标插件 ID。
reason: 重载原因。
- external_available_plugins: 视为已满足的外部依赖插件 ID 列表。
+ external_available_plugins: 视为已满足的外部依赖插件版本映射。
Returns:
bool: 是否重载成功。
@@ -392,7 +405,7 @@ class PluginRunnerSupervisor:
payload={
"plugin_id": plugin_id,
"reason": reason,
- "external_available_plugins": external_available_plugins or self._external_available_plugin_ids,
+ "external_available_plugins": external_available_plugins or self._external_available_plugins,
},
timeout_ms=max(int(self._runner_spawn_timeout * 1000), 10000),
)
@@ -409,14 +422,14 @@ class PluginRunnerSupervisor:
self,
plugin_ids: Optional[List[str]] = None,
reason: str = "manual",
- external_available_plugins: Optional[List[str]] = None,
+ external_available_plugins: Optional[Dict[str, str]] = None,
) -> bool:
"""批量重载插件。
Args:
plugin_ids: 目标插件 ID 列表;为空时重载当前已注册的全部插件。
reason: 重载原因。
- external_available_plugins: 视为已满足的外部依赖插件 ID 列表。
+ external_available_plugins: 视为已满足的外部依赖插件版本映射。
Returns:
bool: 是否全部重载成功。
@@ -1136,7 +1149,7 @@ class PluginRunnerSupervisor:
global_config_snapshot = config_manager.get_global_config().model_dump()
global_config_snapshot["model"] = config_manager.get_model_config().model_dump()
return {
- ENV_EXTERNAL_PLUGIN_IDS: json.dumps(self._external_available_plugin_ids, ensure_ascii=False),
+ ENV_EXTERNAL_PLUGIN_IDS: json.dumps(self._external_available_plugins, ensure_ascii=False),
ENV_GLOBAL_CONFIG_SNAPSHOT: json.dumps(global_config_snapshot, ensure_ascii=False),
ENV_HOST_VERSION: PROTOCOL_VERSION,
ENV_IPC_ADDRESS: self._transport.get_address(),
diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py
index d48260e5..092b9597 100644
--- a/src/plugin_runtime/integration.py
+++ b/src/plugin_runtime/integration.py
@@ -11,7 +11,6 @@ from pathlib import Path
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Dict, Iterable, List, Optional, Sequence, Set, Tuple
import asyncio
-import json
import tomlkit
@@ -26,6 +25,7 @@ from src.plugin_runtime.capabilities import (
)
from src.plugin_runtime.capabilities.registry import register_capability_impls
from src.plugin_runtime.host.message_utils import MessageDict, PluginMessageUtils
+from src.plugin_runtime.runner.manifest_validator import ManifestValidator
if TYPE_CHECKING:
from src.chat.message_receive.message import SessionMessage
@@ -69,6 +69,7 @@ class PluginRuntimeManager(
self._plugin_source_watcher_subscription_id: Optional[str] = None
self._plugin_config_watcher_subscriptions: Dict[str, Tuple[Path, str]] = {}
self._plugin_path_cache: Dict[str, Path] = {}
+ self._manifest_validator: ManifestValidator = ManifestValidator()
self._config_reload_callback: Callable[[Sequence[str]], Awaitable[None]] = self._handle_main_config_reload
self._config_reload_callback_registered: bool = False
@@ -102,46 +103,11 @@ class PluginRuntimeManager(
candidate = Path("plugins").resolve()
return [candidate] if candidate.is_dir() else []
- @staticmethod
- def _extract_manifest_dependencies(manifest: Dict[str, Any]) -> List[str]:
- """从插件 manifest 中提取规范化后的依赖插件 ID 列表。"""
-
- dependencies: List[str] = []
- for dependency in manifest.get("dependencies", []):
- if isinstance(dependency, str):
- normalized_dependency = dependency.strip()
- elif isinstance(dependency, dict):
- normalized_dependency = str(dependency.get("name", "") or "").strip()
- else:
- normalized_dependency = ""
-
- if normalized_dependency:
- dependencies.append(normalized_dependency)
- return dependencies
-
@classmethod
def _discover_plugin_dependency_map(cls, plugin_dirs: Iterable[Path]) -> Dict[str, List[str]]:
"""扫描指定插件目录集合,返回 ``plugin_id -> dependencies`` 映射。"""
-
- dependency_map: Dict[str, List[str]] = {}
- for plugin_dir in cls._iter_candidate_plugin_paths(plugin_dirs):
- manifest_path = plugin_dir / "_manifest.json"
- entrypoint_path = plugin_dir / "plugin.py"
- if not manifest_path.is_file() or not entrypoint_path.is_file():
- continue
-
- try:
- with manifest_path.open("r", encoding="utf-8") as manifest_file:
- manifest = json.load(manifest_file)
- except Exception:
- continue
-
- if not isinstance(manifest, dict):
- continue
-
- plugin_id = str(manifest.get("name", plugin_dir.name) or "").strip() or plugin_dir.name
- dependency_map[plugin_id] = cls._extract_manifest_dependencies(manifest)
- return dependency_map
+ validator = ManifestValidator()
+ return validator.build_plugin_dependency_map(plugin_dirs)
@classmethod
def _build_group_start_order(
@@ -243,12 +209,12 @@ class PluginRuntimeManager(
if supervisor is None:
continue
- external_plugin_ids = [
- plugin_id
+ external_plugin_versions = {
+ plugin_id: plugin_version
for started_supervisor in started_supervisors
- for plugin_id in started_supervisor.get_loaded_plugin_ids()
- ]
- supervisor.set_external_available_plugin_ids(external_plugin_ids)
+ for plugin_id, plugin_version in started_supervisor.get_loaded_plugin_versions().items()
+ }
+ supervisor.set_external_available_plugins(external_plugin_versions)
await supervisor.start()
started_supervisors.append(supervisor)
@@ -366,23 +332,22 @@ class PluginRuntimeManager(
for plugin_id in supervisor.get_loaded_plugin_ids()
}
- def _build_external_available_plugins_for_supervisor(self, target_supervisor: "PluginSupervisor") -> List[str]:
- """收集某个 Supervisor 可用的外部插件 ID 列表。"""
+ def _build_external_available_plugins_for_supervisor(self, target_supervisor: "PluginSupervisor") -> Dict[str, str]:
+ """收集某个 Supervisor 可用的外部插件版本映射。"""
- external_plugin_ids: Set[str] = set()
+ external_plugin_versions: Dict[str, str] = {}
for supervisor in self.supervisors:
if supervisor is target_supervisor:
continue
- external_plugin_ids.update(supervisor.get_loaded_plugin_ids())
- return sorted(external_plugin_ids)
+ external_plugin_versions.update(supervisor.get_loaded_plugin_versions())
+ return external_plugin_versions
def _find_supervisor_by_plugin_directory(self, plugin_id: str) -> Optional["PluginSupervisor"]:
"""根据插件目录推断应负责该插件重载的 Supervisor。"""
for supervisor in self.supervisors:
- for plugin_dir in supervisor._plugin_dirs:
- if (Path(plugin_dir) / plugin_id).is_dir():
- return supervisor
+ if self._get_plugin_path_for_supervisor(supervisor, plugin_id) is not None:
+ return supervisor
return None
def _warn_skipped_cross_supervisor_reload(
@@ -740,30 +705,13 @@ class PluginRuntimeManager(
external_available_plugins=self._build_external_available_plugins_for_supervisor(supervisor),
)
- @staticmethod
- def _find_duplicate_plugin_ids(plugin_dirs: List[Path]) -> Dict[str, List[Path]]:
+ @classmethod
+ def _find_duplicate_plugin_ids(cls, plugin_dirs: List[Path]) -> Dict[str, List[Path]]:
"""扫描插件目录,找出被多个目录重复声明的插件 ID。"""
plugin_locations: Dict[str, List[Path]] = {}
- for base_dir in plugin_dirs:
- if not base_dir.is_dir():
- continue
- for entry in base_dir.iterdir():
- if not entry.is_dir():
- continue
- manifest_path = entry / "_manifest.json"
- plugin_path = entry / "plugin.py"
- if not manifest_path.exists() or not plugin_path.exists():
- continue
-
- plugin_id = entry.name
- try:
- with open(manifest_path, "r", encoding="utf-8") as manifest_file:
- manifest = json.load(manifest_file)
- plugin_id = str(manifest.get("name", entry.name)).strip() or entry.name
- except Exception:
- continue
-
- plugin_locations.setdefault(plugin_id, []).append(entry)
+ validator = ManifestValidator()
+ for plugin_path, manifest in validator.iter_plugin_manifests(plugin_dirs):
+ plugin_locations.setdefault(manifest.id, []).append(plugin_path)
return {
plugin_id: sorted(dict.fromkeys(paths), key=lambda p: str(p))
@@ -831,8 +779,7 @@ class PluginRuntimeManager(
if entry.is_dir():
yield entry.resolve()
- @staticmethod
- def _read_plugin_id_from_plugin_path(plugin_path: Path) -> Optional[str]:
+ def _read_plugin_id_from_plugin_path(self, plugin_path: Path) -> Optional[str]:
"""从单个插件目录中读取 manifest 声明的插件 ID。
Args:
@@ -841,22 +788,7 @@ class PluginRuntimeManager(
Returns:
Optional[str]: 解析成功时返回插件 ID,否则返回 ``None``。
"""
- manifest_path = plugin_path / "_manifest.json"
- entrypoint_path = plugin_path / "plugin.py"
- if not manifest_path.is_file() or not entrypoint_path.is_file():
- return None
-
- try:
- with open(manifest_path, "r", encoding="utf-8") as manifest_file:
- manifest = json.load(manifest_file)
- except Exception:
- return None
-
- if not isinstance(manifest, dict):
- return None
-
- plugin_id = str(manifest.get("name", plugin_path.name)).strip() or plugin_path.name
- return plugin_id or None
+ return self._manifest_validator.read_plugin_id_from_plugin_path(plugin_path)
def _iter_discovered_plugin_paths(self, plugin_dirs: Iterable[Path]) -> Iterable[Tuple[str, Path]]:
"""迭代目录中可解析到的插件 ID 与实际目录路径。
diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py
index ce40d855..c2c89a0f 100644
--- a/src/plugin_runtime/protocol/envelope.py
+++ b/src/plugin_runtime/protocol/envelope.py
@@ -282,8 +282,11 @@ class ReloadPluginPayload(BaseModel):
"""目标插件 ID"""
reason: str = Field(default="manual", description="重载原因")
"""重载原因"""
- external_available_plugins: List[str] = Field(default_factory=list, description="可视为已满足的外部依赖插件 ID")
- """可视为已满足的外部依赖插件 ID"""
+ external_available_plugins: Dict[str, str] = Field(
+ default_factory=dict,
+ description="可视为已满足的外部依赖插件版本映射",
+ )
+ """可视为已满足的外部依赖插件版本映射"""
class ReloadPluginResultPayload(BaseModel):
diff --git a/src/plugin_runtime/runner/manifest_validator.py b/src/plugin_runtime/runner/manifest_validator.py
index 32429e01..33c2b1e5 100644
--- a/src/plugin_runtime/runner/manifest_validator.py
+++ b/src/plugin_runtime/runner/manifest_validator.py
@@ -1,20 +1,36 @@
-"""Manifest 校验与版本兼容性
+"""Manifest 校验与解析。
-从旧系统的 ManifestValidator / VersionComparator 对齐移植,
-适配新 plugin_runtime 的 _manifest.json 格式。
+集中负责插件 ``_manifest.json`` 的读取、结构校验、运行时兼容性判断,
+以及插件依赖/Python 包依赖的解析逻辑。
"""
-from typing import Any, Dict, List, Tuple
+from functools import lru_cache
+from importlib import metadata as importlib_metadata
+from pathlib import Path
+from typing import Annotated, Any, Dict, Iterable, List, Literal, Optional, Tuple, Union
+import json
import re
+import tomllib
+
+from packaging.requirements import InvalidRequirement, Requirement
+from packaging.specifiers import InvalidSpecifier, SpecifierSet
+from packaging.utils import canonicalize_name
+from packaging.version import InvalidVersion, Version
+from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator, model_validator
from src.common.logger import get_logger
logger = get_logger("plugin_runtime.runner.manifest_validator")
+_SEMVER_PATTERN = re.compile(r"^\d+\.\d+\.\d+$")
+_PLUGIN_ID_PATTERN = re.compile(r"^[a-z0-9]+(?:[.-][a-z0-9]+)+$")
+_PACKAGE_NAME_PATTERN = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*$")
+_HTTP_URL_PATTERN = re.compile(r"^https?://.+$")
+
class VersionComparator:
- """语义化版本号比较器"""
+ """语义化版本号比较器。"""
@staticmethod
def normalize_version(version: str) -> str:
@@ -25,13 +41,15 @@ class VersionComparator:
Returns:
str: 规范化后的 ``major.minor.patch`` 形式版本号。
- 当输入为空或格式非法时返回 ``0.0.0``。
+ 当输入为空或格式非法时返回 ``0.0.0``。
"""
if not version:
return "0.0.0"
- normalized = re.sub(r"-snapshot\.\d+", "", version.strip())
+
+ normalized = re.sub(r"-snapshot\.\d+", "", str(version).strip())
if not re.match(r"^\d+(\.\d+){0,2}$", normalized):
return "0.0.0"
+
parts = normalized.split(".")
while len(parts) < 3:
parts.append("0")
@@ -46,7 +64,7 @@ class VersionComparator:
Returns:
Tuple[int, int, int]: 三段式版本号对应的整数元组。
- 当解析失败时返回 ``(0, 0, 0)``。
+ 当解析失败时返回 ``(0, 0, 0)``。
"""
normalized = VersionComparator.normalize_version(version)
try:
@@ -65,13 +83,13 @@ class VersionComparator:
Returns:
int: ``-1`` 表示 ``v1 < v2``,``1`` 表示 ``v1 > v2``,
- ``0`` 表示两者相等。
+ ``0`` 表示两者相等。
"""
t1 = VersionComparator.parse_version(v1)
t2 = VersionComparator.parse_version(v2)
if t1 < t2:
return -1
- elif t1 > t2:
+ if t1 > t2:
return 1
return 0
@@ -86,120 +104,1043 @@ class VersionComparator:
Returns:
Tuple[bool, str]: 第一项表示是否满足要求,第二项为失败原因;
- 当校验通过时第二项为空字符串。
+ 当校验通过时第二项为空字符串。
"""
if not min_version and not max_version:
return True, ""
- vn = VersionComparator.normalize_version(version)
+
+ normalized_version = VersionComparator.normalize_version(version)
if min_version:
- mn = VersionComparator.normalize_version(min_version)
- if VersionComparator.compare(vn, mn) < 0:
- return False, f"版本 {vn} 低于最小要求 {mn}"
+ normalized_min_version = VersionComparator.normalize_version(min_version)
+ if VersionComparator.compare(normalized_version, normalized_min_version) < 0:
+ return False, f"版本 {normalized_version} 低于最小要求 {normalized_min_version}"
if max_version:
- mx = VersionComparator.normalize_version(max_version)
- if VersionComparator.compare(vn, mx) > 0:
- return False, f"版本 {vn} 高于最大支持 {mx}"
+ normalized_max_version = VersionComparator.normalize_version(max_version)
+ if VersionComparator.compare(normalized_version, normalized_max_version) > 0:
+ return False, f"版本 {normalized_version} 高于最大支持 {normalized_max_version}"
return True, ""
+ @staticmethod
+ def is_valid_semver(version: str) -> bool:
+ """判断字符串是否为严格三段式语义版本号。
+
+ Args:
+ version: 待检查的版本号字符串。
+
+ Returns:
+ bool: 是否满足 ``X.Y.Z`` 格式。
+ """
+ return bool(_SEMVER_PATTERN.fullmatch(str(version or "").strip()))
+
+
+class _StrictManifestModel(BaseModel):
+ """Manifest 解析使用的严格基类模型。"""
+
+ model_config = ConfigDict(extra="forbid", frozen=True, str_strip_whitespace=True)
+
+
+class ManifestAuthor(_StrictManifestModel):
+ """插件作者信息。"""
+
+ name: str = Field(description="作者名称")
+ url: str = Field(description="作者主页地址")
+
+ @field_validator("name")
+ @classmethod
+ def _validate_name(cls, value: str) -> str:
+ """校验作者名称。
+
+ Args:
+ value: 原始作者名称。
+
+ Returns:
+ str: 规范化后的作者名称。
+
+ Raises:
+ ValueError: 当字段为空时抛出。
+ """
+ if not value:
+ raise ValueError("不能为空")
+ return value
+
+ @field_validator("url")
+ @classmethod
+ def _validate_url(cls, value: str) -> str:
+ """校验作者主页地址。
+
+ Args:
+ value: 原始主页地址。
+
+ Returns:
+ str: 规范化后的主页地址。
+
+ Raises:
+ ValueError: 当字段为空或不是 HTTP/HTTPS URL 时抛出。
+ """
+ if not value:
+ raise ValueError("不能为空")
+ if not _HTTP_URL_PATTERN.fullmatch(value):
+ raise ValueError("必须为 http:// 或 https:// 开头的 URL")
+ return value
+
+
+class ManifestUrls(_StrictManifestModel):
+ """插件相关链接集合。"""
+
+ repository: str = Field(description="插件仓库地址")
+ homepage: Optional[str] = Field(default=None, description="插件主页地址")
+ documentation: Optional[str] = Field(default=None, description="插件文档地址")
+ issues: Optional[str] = Field(default=None, description="插件问题反馈地址")
+
+ @field_validator("repository")
+ @classmethod
+ def _validate_repository(cls, value: str) -> str:
+ """校验仓库地址。
+
+ Args:
+ value: 原始仓库地址。
+
+ Returns:
+ str: 规范化后的仓库地址。
+
+ Raises:
+ ValueError: 当字段为空或不是 HTTP/HTTPS URL 时抛出。
+ """
+ if not value:
+ raise ValueError("不能为空")
+ if not _HTTP_URL_PATTERN.fullmatch(value):
+ raise ValueError("必须为 http:// 或 https:// 开头的 URL")
+ return value
+
+ @field_validator("homepage", "documentation", "issues")
+ @classmethod
+ def _validate_optional_url(cls, value: Optional[str]) -> Optional[str]:
+ """校验可选链接字段。
+
+ Args:
+ value: 原始链接值。
+
+ Returns:
+ Optional[str]: 合法的链接值。
+
+ Raises:
+ ValueError: 当提供的值不是 HTTP/HTTPS URL 时抛出。
+ """
+ if value is None:
+ return None
+ if not value:
+ raise ValueError("不能为空字符串")
+ if not _HTTP_URL_PATTERN.fullmatch(value):
+ raise ValueError("必须为 http:// 或 https:// 开头的 URL")
+ return value
+
+
+class ManifestVersionRange(_StrictManifestModel):
+ """版本闭区间声明。"""
+
+ min_version: str = Field(description="最小版本,闭区间")
+ max_version: str = Field(description="最大版本,闭区间")
+
+ @field_validator("min_version", "max_version")
+ @classmethod
+ def _validate_version(cls, value: str) -> str:
+ """校验版本号格式。
+
+ Args:
+ value: 原始版本号。
+
+ Returns:
+ str: 合法的版本号。
+
+ Raises:
+ ValueError: 当版本号不是严格三段式语义版本时抛出。
+ """
+ if not VersionComparator.is_valid_semver(value):
+ raise ValueError("必须为严格三段式版本号,例如 1.0.0")
+ return value
+
+ @model_validator(mode="after")
+ def _validate_range(self) -> "ManifestVersionRange":
+ """校验版本区间上下界关系。
+
+ Returns:
+ ManifestVersionRange: 当前对象本身。
+
+ Raises:
+ ValueError: 当最小版本大于最大版本时抛出。
+ """
+ if VersionComparator.compare(self.min_version, self.max_version) > 0:
+ raise ValueError("min_version 不能大于 max_version")
+ return self
+
+
+class ManifestI18n(_StrictManifestModel):
+ """国际化配置。"""
+
+ default_locale: str = Field(description="默认语言")
+ locales_path: Optional[str] = Field(default=None, description="语言资源目录")
+ supported_locales: List[str] = Field(default_factory=list, description="支持的语言列表")
+
+ @field_validator("default_locale")
+ @classmethod
+ def _validate_default_locale(cls, value: str) -> str:
+ """校验默认语言。
+
+ Args:
+ value: 原始默认语言。
+
+ Returns:
+ str: 规范化后的默认语言。
+
+ Raises:
+ ValueError: 当字段为空时抛出。
+ """
+ if not value:
+ raise ValueError("不能为空")
+ return value
+
+ @field_validator("locales_path")
+ @classmethod
+ def _validate_locales_path(cls, value: Optional[str]) -> Optional[str]:
+ """校验语言资源目录。
+
+ Args:
+ value: 原始语言资源目录。
+
+ Returns:
+ Optional[str]: 合法的目录值。
+
+ Raises:
+ ValueError: 当值为空字符串时抛出。
+ """
+ if value is None:
+ return None
+ if not value:
+ raise ValueError("不能为空字符串")
+ return value
+
+ @field_validator("supported_locales")
+ @classmethod
+ def _validate_supported_locales(cls, value: List[str]) -> List[str]:
+ """校验支持语言列表。
+
+ Args:
+ value: 原始语言列表。
+
+ Returns:
+ List[str]: 去重后的语言列表。
+
+ Raises:
+ ValueError: 当列表项为空时抛出。
+ """
+ normalized_locales: List[str] = []
+ for locale in value:
+ normalized_locale = str(locale or "").strip()
+ if not normalized_locale:
+ raise ValueError("语言列表中存在空值")
+ if normalized_locale not in normalized_locales:
+ normalized_locales.append(normalized_locale)
+ return normalized_locales
+
+ @model_validator(mode="after")
+ def _validate_default_locale_membership(self) -> "ManifestI18n":
+ """校验默认语言是否位于支持列表中。
+
+ Returns:
+ ManifestI18n: 当前对象本身。
+
+ Raises:
+ ValueError: 当 ``supported_locales`` 非空但未包含 ``default_locale`` 时抛出。
+ """
+ if self.supported_locales and self.default_locale not in self.supported_locales:
+ raise ValueError("default_locale 必须包含在 supported_locales 中")
+ return self
+
+
+class PluginDependencyDefinition(_StrictManifestModel):
+ """插件级依赖声明。"""
+
+ type: Literal["plugin"] = Field(description="依赖类型")
+ id: str = Field(description="依赖插件 ID")
+ version_spec: str = Field(description="版本约束表达式")
+
+ @field_validator("id")
+ @classmethod
+ def _validate_id(cls, value: str) -> str:
+ """校验依赖插件 ID。
+
+ Args:
+ value: 原始依赖插件 ID。
+
+ Returns:
+ str: 合法的依赖插件 ID。
+
+ Raises:
+ ValueError: 当 ID 不符合规则时抛出。
+ """
+ if not _PLUGIN_ID_PATTERN.fullmatch(value):
+ raise ValueError("必须使用小写字母/数字,并以点号或横线分隔,例如 github.author.plugin")
+ return value
+
+ @field_validator("version_spec")
+ @classmethod
+ def _validate_version_spec(cls, value: str) -> str:
+ """校验插件依赖版本约束。
+
+ Args:
+ value: 原始版本约束表达式。
+
+ Returns:
+ str: 合法的版本约束表达式。
+
+ Raises:
+ ValueError: 当表达式无效时抛出。
+ """
+ if not value:
+ raise ValueError("不能为空")
+ try:
+ SpecifierSet(value)
+ except InvalidSpecifier as exc:
+ raise ValueError(f"无效的版本约束: {exc}") from exc
+ return value
+
+
+class PythonPackageDependencyDefinition(_StrictManifestModel):
+ """Python 包依赖声明。"""
+
+ type: Literal["python_package"] = Field(description="依赖类型")
+ name: str = Field(description="Python 包名")
+ version_spec: str = Field(description="版本约束表达式")
+
+ @field_validator("name")
+ @classmethod
+ def _validate_name(cls, value: str) -> str:
+ """校验 Python 包名。
+
+ Args:
+ value: 原始包名。
+
+ Returns:
+ str: 合法的包名。
+
+ Raises:
+ ValueError: 当包名不合法时抛出。
+ """
+ if not _PACKAGE_NAME_PATTERN.fullmatch(value):
+ raise ValueError("包名只能包含字母、数字、点号、下划线和横线")
+ return value
+
+ @field_validator("version_spec")
+ @classmethod
+ def _validate_version_spec(cls, value: str) -> str:
+ """校验 Python 包版本约束。
+
+ Args:
+ value: 原始版本约束表达式。
+
+ Returns:
+ str: 合法的版本约束表达式。
+
+ Raises:
+ ValueError: 当表达式无效时抛出。
+ """
+ if not value:
+ raise ValueError("不能为空")
+ try:
+ Requirement(f"placeholder{value}")
+ except InvalidRequirement as exc:
+ raise ValueError(f"无效的版本约束: {exc}") from exc
+ return value
+
+
+ManifestDependencyDefinition = Annotated[
+ Union[PluginDependencyDefinition, PythonPackageDependencyDefinition],
+ Field(discriminator="type"),
+]
+
+
+class PluginManifest(_StrictManifestModel):
+ """插件 Manifest v2 强类型模型。"""
+
+ manifest_version: Literal[2] = Field(description="Manifest 协议版本")
+ version: str = Field(description="插件版本")
+ name: str = Field(description="插件展示名称")
+ description: str = Field(description="插件描述")
+ author: ManifestAuthor = Field(description="插件作者信息")
+ license: str = Field(description="插件协议")
+ urls: ManifestUrls = Field(description="插件相关链接")
+ host_application: ManifestVersionRange = Field(description="Host 兼容区间")
+ sdk: ManifestVersionRange = Field(description="SDK 兼容区间")
+ dependencies: List[ManifestDependencyDefinition] = Field(default_factory=list, description="依赖声明")
+ capabilities: List[str] = Field(description="插件声明的能力请求")
+ i18n: ManifestI18n = Field(description="国际化配置")
+ id: str = Field(description="稳定插件 ID")
+
+ @field_validator("version")
+ @classmethod
+ def _validate_version(cls, value: str) -> str:
+ """校验插件版本号格式。
+
+ Args:
+ value: 原始插件版本号。
+
+ Returns:
+ str: 合法的插件版本号。
+
+ Raises:
+ ValueError: 当版本号不是严格三段式语义版本时抛出。
+ """
+ if not VersionComparator.is_valid_semver(value):
+ raise ValueError("必须为严格三段式版本号,例如 1.0.0")
+ return value
+
+ @field_validator("name", "description", "license", "id")
+ @classmethod
+ def _validate_required_string(cls, value: str, info: Any) -> str:
+ """校验必填字符串字段。
+
+ Args:
+ value: 原始字段值。
+ info: Pydantic 字段上下文。
+
+ Returns:
+ str: 合法的字段值。
+
+ Raises:
+ ValueError: 当字段为空或格式不合法时抛出。
+ """
+ if not value:
+ raise ValueError("不能为空")
+ if info.field_name == "id" and not _PLUGIN_ID_PATTERN.fullmatch(value):
+ raise ValueError("必须使用小写字母/数字,并以点号或横线分隔,例如 github.author.plugin")
+ return value
+
+ @field_validator("capabilities")
+ @classmethod
+ def _validate_capabilities(cls, value: List[str]) -> List[str]:
+ """校验能力声明列表。
+
+ Args:
+ value: 原始能力声明列表。
+
+ Returns:
+ List[str]: 去重后的能力列表。
+
+ Raises:
+ ValueError: 当列表为空项或能力名为空时抛出。
+ """
+ normalized_capabilities: List[str] = []
+ for capability in value:
+ normalized_capability = str(capability or "").strip()
+ if not normalized_capability:
+ raise ValueError("capabilities 中存在空能力名")
+ if normalized_capability not in normalized_capabilities:
+ normalized_capabilities.append(normalized_capability)
+ return normalized_capabilities
+
+ @model_validator(mode="after")
+ def _validate_dependencies(self) -> "PluginManifest":
+ """校验依赖声明集合。
+
+ Returns:
+ PluginManifest: 当前对象本身。
+
+ Raises:
+ ValueError: 当依赖项重复或插件依赖自身时抛出。
+ """
+ plugin_dependency_ids: set[str] = set()
+ python_package_names: set[str] = set()
+
+ for dependency in self.dependencies:
+ if isinstance(dependency, PluginDependencyDefinition):
+ if dependency.id == self.id:
+ raise ValueError("dependencies 中的插件依赖不能依赖自身")
+ if dependency.id in plugin_dependency_ids:
+ raise ValueError(f"存在重复的插件依赖声明: {dependency.id}")
+ plugin_dependency_ids.add(dependency.id)
+ continue
+
+ normalized_package_name = canonicalize_name(dependency.name)
+ if normalized_package_name in python_package_names:
+ raise ValueError(f"存在重复的 Python 包依赖声明: {dependency.name}")
+ python_package_names.add(normalized_package_name)
+
+ return self
+
+ @property
+ def plugin_dependencies(self) -> List[PluginDependencyDefinition]:
+ """返回插件级依赖列表。
+
+ Returns:
+ List[PluginDependencyDefinition]: 所有 ``type=plugin`` 的依赖项。
+ """
+ return [dependency for dependency in self.dependencies if isinstance(dependency, PluginDependencyDefinition)]
+
+ @property
+ def python_package_dependencies(self) -> List[PythonPackageDependencyDefinition]:
+ """返回 Python 包依赖列表。
+
+ Returns:
+ List[PythonPackageDependencyDefinition]: 所有 ``type=python_package`` 的依赖项。
+ """
+ return [
+ dependency
+ for dependency in self.dependencies
+ if isinstance(dependency, PythonPackageDependencyDefinition)
+ ]
+
+ @property
+ def plugin_dependency_ids(self) -> List[str]:
+ """返回插件级依赖的插件 ID 列表。
+
+ Returns:
+ List[str]: 所有插件级依赖的插件 ID。
+ """
+ return [dependency.id for dependency in self.plugin_dependencies]
+
class ManifestValidator:
- """_manifest.json 校验器"""
+ """严格的插件 Manifest v2 校验器。"""
- REQUIRED_FIELDS = ["name", "version", "description", "author"]
- RECOMMENDED_FIELDS = ["license", "keywords", "categories"]
- SUPPORTED_MANIFEST_VERSIONS = [1, 2]
+ SUPPORTED_MANIFEST_VERSIONS = [2]
- def __init__(self, host_version: str = "") -> None:
+ def __init__(
+ self,
+ host_version: str = "",
+ sdk_version: str = "",
+ project_root: Optional[Path] = None,
+ ) -> None:
"""初始化 Manifest 校验器。
Args:
- host_version: 当前 Host 版本号,用于校验插件声明的兼容区间。
+ host_version: 当前 Host 版本号;留空时自动从主程序 ``pyproject.toml`` 读取。
+ sdk_version: 当前 SDK 版本号;留空时自动从运行环境中探测。
+ project_root: 项目根目录;留空时自动推断。
"""
- self._host_version = host_version
+ self._project_root: Path = project_root or self._resolve_project_root()
+ self._host_version: str = host_version or self._detect_default_host_version(self._project_root)
+ self._sdk_version: str = sdk_version or self._detect_default_sdk_version(self._project_root)
self.errors: List[str] = []
self.warnings: List[str] = []
def validate(self, manifest: Dict[str, Any]) -> bool:
- """校验 manifest 数据,返回是否通过(errors 为空即通过)。"""
+ """校验 manifest 数据,返回是否通过。
+
+ Args:
+ manifest: 待校验的 Manifest 原始字典。
+
+ Returns:
+ bool: 校验是否通过。
+ """
+ return self.parse_manifest(manifest) is not None
+
+ def parse_manifest(self, manifest: Dict[str, Any]) -> Optional[PluginManifest]:
+ """解析并校验 manifest 字典。
+
+ Args:
+ manifest: 待解析的 Manifest 原始字典。
+
+ Returns:
+ Optional[PluginManifest]: 解析成功时返回强类型 Manifest;失败时返回 ``None``。
+ """
self.errors.clear()
self.warnings.clear()
- self._check_required_fields(manifest)
- self._check_manifest_version(manifest)
- self._check_author(manifest)
- self._check_host_compatibility(manifest)
- self._check_recommended(manifest)
+ try:
+ parsed_manifest = PluginManifest.model_validate(manifest)
+ except ValidationError as exc:
+ self.errors.extend(self._format_validation_errors(exc))
+ self._log_errors()
+ return None
+ self._validate_runtime_compatibility(parsed_manifest)
if self.errors:
- for e in self.errors:
- logger.error(f"Manifest 校验失败: {e}")
- if self.warnings:
- for w in self.warnings:
- logger.warning(f"Manifest 警告: {w}")
+ self._log_errors()
+ return None
- return len(self.errors) == 0
+ return parsed_manifest
- def _check_required_fields(self, manifest: Dict[str, Any]) -> None:
- """检查 Manifest 中的必填字段是否存在且非空。
+ def load_from_plugin_path(self, plugin_path: Path, require_entrypoint: bool = True) -> Optional[PluginManifest]:
+ """从插件目录读取并解析 manifest。
Args:
- manifest: 待校验的 Manifest 数据。
- """
- for field in self.REQUIRED_FIELDS:
- if field not in manifest:
- self.errors.append(f"缺少必需字段: {field}")
- elif not manifest[field]:
- self.errors.append(f"必需字段不能为空: {field}")
+ plugin_path: 单个插件目录路径。
+ require_entrypoint: 是否要求目录内存在 ``plugin.py`` 入口文件。
- def _check_manifest_version(self, manifest: Dict[str, Any]) -> None:
- """检查 Manifest 版本号是否在当前 Runner 支持范围内。
+ Returns:
+ Optional[PluginManifest]: 解析成功时返回强类型 Manifest;失败时返回 ``None``。
+ """
+ self.errors.clear()
+ self.warnings.clear()
+
+ manifest_path = plugin_path / "_manifest.json"
+ entrypoint_path = plugin_path / "plugin.py"
+
+ if not manifest_path.is_file():
+ self.errors.append("缺少 _manifest.json")
+ return None
+ if require_entrypoint and not entrypoint_path.is_file():
+ self.errors.append("缺少 plugin.py")
+ return None
+
+ try:
+ with manifest_path.open("r", encoding="utf-8") as manifest_file:
+ manifest_data = json.load(manifest_file)
+ except Exception as exc:
+ self.errors.append(f"manifest 解析失败: {exc}")
+ self._log_errors()
+ return None
+
+ if not isinstance(manifest_data, dict):
+ self.errors.append("manifest 顶层必须为 JSON 对象")
+ self._log_errors()
+ return None
+
+ return self.parse_manifest(manifest_data)
+
+ def iter_plugin_manifests(
+ self,
+ plugin_dirs: Iterable[Path],
+ require_entrypoint: bool = True,
+ ) -> Iterable[Tuple[Path, PluginManifest]]:
+ """扫描插件根目录并迭代所有可成功解析的 Manifest。
Args:
- manifest: 待校验的 Manifest 数据。
- """
- mv = manifest.get("manifest_version")
- if mv is not None and mv not in self.SUPPORTED_MANIFEST_VERSIONS:
- self.errors.append(f"不支持的 manifest_version: {mv},支持: {self.SUPPORTED_MANIFEST_VERSIONS}")
+ plugin_dirs: 一个或多个插件根目录。
+ require_entrypoint: 是否要求每个插件目录内存在 ``plugin.py``。
- def _check_author(self, manifest: Dict[str, Any]) -> None:
- """校验 ``author`` 字段的结构与内容。
+ Yields:
+ Tuple[Path, PluginManifest]: ``(插件目录路径, 解析结果)`` 二元组。
+ """
+ for plugin_root in plugin_dirs:
+ normalized_root = Path(plugin_root).resolve()
+ if not normalized_root.is_dir():
+ continue
+
+ for candidate_path in sorted(entry.resolve() for entry in normalized_root.iterdir() if entry.is_dir()):
+ parsed_manifest = self.load_from_plugin_path(candidate_path, require_entrypoint=require_entrypoint)
+ if parsed_manifest is None:
+ continue
+ yield candidate_path, parsed_manifest
+
+ def build_plugin_dependency_map(
+ self,
+ plugin_dirs: Iterable[Path],
+ require_entrypoint: bool = True,
+ ) -> Dict[str, List[str]]:
+ """扫描目录并构建 ``plugin_id -> 依赖插件 ID 列表`` 映射。
Args:
- manifest: 待校验的 Manifest 数据。
- """
- author = manifest.get("author")
- if author is None:
- return
- if isinstance(author, dict):
- if "name" not in author or not author["name"]:
- self.errors.append("author 对象缺少 name 字段")
- elif isinstance(author, str):
- if not author.strip():
- self.errors.append("author 不能为空")
- else:
- self.errors.append("author 应为字符串或 {name, url} 对象")
+ plugin_dirs: 一个或多个插件根目录。
+ require_entrypoint: 是否要求每个插件目录内存在 ``plugin.py``。
- def _check_host_compatibility(self, manifest: Dict[str, Any]) -> None:
- """检查插件声明的 Host 兼容范围是否包含当前 Host 版本。
+ Returns:
+ Dict[str, List[str]]: 所有成功解析到的插件依赖映射。
+ """
+ dependency_map: Dict[str, List[str]] = {}
+ for _plugin_path, manifest in self.iter_plugin_manifests(plugin_dirs, require_entrypoint=require_entrypoint):
+ dependency_map[manifest.id] = manifest.plugin_dependency_ids
+ return dependency_map
+
+ def read_plugin_id_from_plugin_path(self, plugin_path: Path, require_entrypoint: bool = True) -> Optional[str]:
+ """从单个插件目录中读取规范化后的插件 ID。
Args:
- manifest: 待校验的 Manifest 数据。
- """
- host_app = manifest.get("host_application")
- if not isinstance(host_app, dict) or not self._host_version:
- return
- min_v = host_app.get("min_version", "")
- max_v = host_app.get("max_version", "")
- ok, msg = VersionComparator.is_in_range(self._host_version, min_v, max_v)
- if not ok:
- self.errors.append(f"Host 版本不兼容: {msg} (当前 Host: {self._host_version})")
+ plugin_path: 单个插件目录路径。
+ require_entrypoint: 是否要求目录内存在 ``plugin.py``。
- def _check_recommended(self, manifest: Dict[str, Any]) -> None:
- """检查推荐字段是否齐备,并记录为警告而非错误。
+ Returns:
+ Optional[str]: 解析成功时返回插件 ID,否则返回 ``None``。
+ """
+ manifest = self.load_from_plugin_path(plugin_path, require_entrypoint=require_entrypoint)
+ if manifest is None:
+ return None
+ return manifest.id
+
+ def get_unsatisfied_plugin_dependencies(
+ self,
+ manifest: PluginManifest,
+ available_plugin_versions: Dict[str, str],
+ ) -> List[str]:
+ """返回当前 Manifest 尚未满足的插件依赖项。
Args:
- manifest: 待校验的 Manifest 数据。
+ manifest: 目标插件的强类型 Manifest。
+ available_plugin_versions: 当前可用插件版本映射,键为插件 ID,值为插件版本。
+
+ Returns:
+ List[str]: 未满足依赖的错误描述列表。
"""
- for field in self.RECOMMENDED_FIELDS:
- if field not in manifest or not manifest[field]:
- self.warnings.append(f"建议填写字段: {field}")
+ unsatisfied_dependencies: List[str] = []
+ for dependency in manifest.plugin_dependencies:
+ dependency_version = available_plugin_versions.get(dependency.id)
+ if not dependency_version:
+ unsatisfied_dependencies.append(f"{dependency.id} (未找到依赖插件)")
+ continue
+
+ if not self._version_matches_specifier(dependency_version, dependency.version_spec):
+ unsatisfied_dependencies.append(
+ f"{dependency.id} (需要 {dependency.version_spec},当前 {dependency_version})"
+ )
+
+ return unsatisfied_dependencies
+
+ def is_plugin_dependency_satisfied(
+ self,
+ dependency: PluginDependencyDefinition,
+ plugin_version: str,
+ ) -> bool:
+ """判断单个插件依赖是否被指定版本满足。
+
+ Args:
+ dependency: 插件级依赖声明。
+ plugin_version: 当前可用的插件版本号。
+
+ Returns:
+ bool: 是否满足版本约束。
+ """
+ return self._version_matches_specifier(plugin_version, dependency.version_spec)
+
+ def _validate_runtime_compatibility(self, manifest: PluginManifest) -> None:
+ """校验运行时版本兼容性与 Python 包依赖。
+
+ Args:
+ manifest: 已通过结构校验的强类型 Manifest。
+ """
+ host_ok, host_message = VersionComparator.is_in_range(
+ self._host_version,
+ manifest.host_application.min_version,
+ manifest.host_application.max_version,
+ )
+ if not host_ok:
+ self.errors.append(f"Host 版本不兼容: {host_message} (当前 Host: {self._host_version})")
+
+ sdk_ok, sdk_message = VersionComparator.is_in_range(
+ self._sdk_version,
+ manifest.sdk.min_version,
+ manifest.sdk.max_version,
+ )
+ if not sdk_ok:
+ self.errors.append(f"SDK 版本不兼容: {sdk_message} (当前 SDK: {self._sdk_version})")
+
+ self._validate_python_package_dependencies(manifest)
+
+ def _validate_python_package_dependencies(self, manifest: PluginManifest) -> None:
+ """校验 Python 包依赖与主程序运行环境是否冲突。
+
+ Args:
+ manifest: 已通过结构校验的强类型 Manifest。
+ """
+ host_requirements = self._load_host_dependency_requirements(self._project_root)
+
+ for dependency in manifest.python_package_dependencies:
+ normalized_package_name = canonicalize_name(dependency.name)
+ package_specifier = self._build_specifier_set(dependency.version_spec)
+ if package_specifier is None:
+ self.errors.append(
+ f"Python 包依赖 {dependency.name} 的版本约束无效: {dependency.version_spec}"
+ )
+ continue
+
+ installed_version = self._get_installed_package_version(dependency.name)
+ host_requirement = host_requirements.get(normalized_package_name)
+
+ if installed_version is not None and not self._version_matches_specifier(
+ installed_version,
+ dependency.version_spec,
+ ):
+ self.errors.append(
+ f"Python 包依赖冲突: {dependency.name} 需要 {dependency.version_spec},"
+ f"当前运行环境为 {installed_version}"
+ )
+ continue
+
+ if host_requirement is None:
+ continue
+
+ if not self._requirements_may_overlap(host_requirement.specifier, package_specifier):
+ host_specifier = str(host_requirement.specifier or "")
+ self.errors.append(
+ f"Python 包依赖冲突: {dependency.name} 需要 {dependency.version_spec},"
+ f"主程序依赖约束为 {host_specifier or '任意版本'}"
+ )
+
+ def _log_errors(self) -> None:
+ """输出当前累计的 Manifest 校验错误。"""
+ for error_message in self.errors:
+ logger.error(f"Manifest 校验失败: {error_message}")
+
+ @classmethod
+ def _resolve_project_root(cls) -> Path:
+ """推断当前项目根目录。
+
+ Returns:
+ Path: 项目根目录路径。
+ """
+ return Path(__file__).resolve().parents[3]
+
+ @classmethod
+ @lru_cache(maxsize=None)
+ def _detect_default_host_version(cls, project_root: Path) -> str:
+ """从主程序 ``pyproject.toml`` 探测 Host 版本号。
+
+ Args:
+ project_root: 项目根目录。
+
+ Returns:
+ str: 探测到的 Host 版本号;失败时返回空字符串。
+ """
+ pyproject_path = project_root / "pyproject.toml"
+ try:
+ with pyproject_path.open("rb") as pyproject_file:
+ pyproject_data = tomllib.load(pyproject_file)
+ except Exception:
+ return ""
+
+ project_data = pyproject_data.get("project", {})
+ if not isinstance(project_data, dict):
+ return ""
+
+ raw_version = str(project_data.get("version", "") or "").strip()
+ if VersionComparator.is_valid_semver(raw_version):
+ return raw_version
+ return ""
+
+ @classmethod
+ @lru_cache(maxsize=None)
+ def _detect_default_sdk_version(cls, project_root: Path) -> str:
+ """探测当前运行环境中的 SDK 版本号。
+
+ Args:
+ project_root: 项目根目录。
+
+ Returns:
+ str: 探测到的 SDK 版本号;失败时返回空字符串。
+ """
+ try:
+ raw_version = importlib_metadata.version("maibot-plugin-sdk")
+ if VersionComparator.is_valid_semver(raw_version):
+ return raw_version
+ except importlib_metadata.PackageNotFoundError:
+ pass
+
+ sdk_pyproject_path = project_root / "packages" / "maibot-plugin-sdk" / "pyproject.toml"
+ try:
+ with sdk_pyproject_path.open("rb") as pyproject_file:
+ pyproject_data = tomllib.load(pyproject_file)
+ except Exception:
+ return ""
+
+ project_data = pyproject_data.get("project", {})
+ if not isinstance(project_data, dict):
+ return ""
+
+ raw_version = str(project_data.get("version", "") or "").strip()
+ if VersionComparator.is_valid_semver(raw_version):
+ return raw_version
+ return ""
+
+ @classmethod
+ @lru_cache(maxsize=None)
+ def _load_host_dependency_requirements(cls, project_root: Path) -> Dict[str, Requirement]:
+ """加载主程序 ``pyproject.toml`` 中声明的依赖约束。
+
+ Args:
+ project_root: 项目根目录。
+
+ Returns:
+ Dict[str, Requirement]: 以规范化包名为键的 Requirement 映射。
+ """
+ pyproject_path = project_root / "pyproject.toml"
+ try:
+ with pyproject_path.open("rb") as pyproject_file:
+ pyproject_data = tomllib.load(pyproject_file)
+ except Exception:
+ return {}
+
+ project_data = pyproject_data.get("project", {})
+ if not isinstance(project_data, dict):
+ return {}
+
+ raw_dependencies = project_data.get("dependencies", [])
+ if not isinstance(raw_dependencies, list):
+ return {}
+
+ requirements: Dict[str, Requirement] = {}
+ for raw_dependency in raw_dependencies:
+ dependency_text = str(raw_dependency or "").strip()
+ if not dependency_text:
+ continue
+
+ try:
+ requirement = Requirement(dependency_text)
+ except InvalidRequirement:
+ continue
+
+ requirements[canonicalize_name(requirement.name)] = requirement
+
+ return requirements
+
+ @staticmethod
+ def _get_installed_package_version(package_name: str) -> Optional[str]:
+ """获取当前运行环境中指定 Python 包的安装版本。
+
+ Args:
+ package_name: 待查询的包名。
+
+ Returns:
+ Optional[str]: 已安装版本号;未安装时返回 ``None``。
+ """
+ try:
+ return importlib_metadata.version(package_name)
+ except importlib_metadata.PackageNotFoundError:
+ return None
+
+ @staticmethod
+ def _build_specifier_set(version_spec: str) -> Optional[SpecifierSet]:
+ """构造版本约束对象。
+
+ Args:
+ version_spec: 版本约束字符串。
+
+ Returns:
+ Optional[SpecifierSet]: 构造成功时返回约束对象,否则返回 ``None``。
+ """
+ try:
+ return SpecifierSet(version_spec)
+ except InvalidSpecifier:
+ return None
+
+ @staticmethod
+ def _version_matches_specifier(version: str, version_spec: str) -> bool:
+ """判断版本是否满足给定约束。
+
+ Args:
+ version: 待判断的版本号。
+ version_spec: 版本约束表达式。
+
+ Returns:
+ bool: 是否满足约束。
+ """
+ try:
+ normalized_version = Version(version)
+ specifier_set = SpecifierSet(version_spec)
+ except (InvalidVersion, InvalidSpecifier):
+ return False
+ return specifier_set.contains(normalized_version, prereleases=True)
+
+ @classmethod
+ def _requirements_may_overlap(cls, left: SpecifierSet, right: SpecifierSet) -> bool:
+ """粗略判断两个版本约束是否存在交集。
+
+ Args:
+ left: 左侧版本约束。
+ right: 右侧版本约束。
+
+ Returns:
+ bool: 若可能存在交集则返回 ``True``,否则返回 ``False``。
+ """
+ candidate_versions = cls._build_candidate_versions(left, right)
+ for candidate_version in candidate_versions:
+ if left.contains(candidate_version, prereleases=True) and right.contains(candidate_version, prereleases=True):
+ return True
+ return False
+
+ @classmethod
+ def _build_candidate_versions(cls, left: SpecifierSet, right: SpecifierSet) -> List[Version]:
+ """为两个版本约束构造一组用于交集探测的候选版本。
+
+ Args:
+ left: 左侧版本约束。
+ right: 右侧版本约束。
+
+ Returns:
+ List[Version]: 去重后的候选版本列表。
+ """
+ candidate_versions: List[Version] = [Version("0.0.0")]
+ for specifier in tuple(left) + tuple(right):
+ for candidate_version in cls._expand_candidate_versions(specifier.version):
+ if candidate_version not in candidate_versions:
+ candidate_versions.append(candidate_version)
+ return candidate_versions
+
+ @staticmethod
+ def _expand_candidate_versions(raw_version: str) -> List[Version]:
+ """根据边界版本扩展出一组邻近候选版本。
+
+ Args:
+ raw_version: 约束中出现的边界版本字符串。
+
+ Returns:
+ List[Version]: 可用于交集探测的候选版本列表。
+ """
+ normalized_text = raw_version.replace("*", "0")
+ try:
+ boundary_version = Version(normalized_text)
+ except InvalidVersion:
+ return []
+
+ release_parts = list(boundary_version.release[:3])
+ while len(release_parts) < 3:
+ release_parts.append(0)
+ major, minor, patch = release_parts[:3]
+
+ candidates = {
+ Version(f"{major}.{minor}.{patch}"),
+ Version(f"{major}.{minor}.{patch + 1}"),
+ }
+ if patch > 0:
+ candidates.add(Version(f"{major}.{minor}.{patch - 1}"))
+ elif minor > 0:
+ candidates.add(Version(f"{major}.{minor - 1}.999"))
+ elif major > 0:
+ candidates.add(Version(f"{major - 1}.999.999"))
+
+ return sorted(candidates)
+
+ @classmethod
+ def _format_validation_errors(cls, exc: ValidationError) -> List[str]:
+ """将 Pydantic 校验错误转换为中文错误列表。
+
+ Args:
+ exc: Pydantic 抛出的校验异常。
+
+ Returns:
+ List[str]: 中文错误描述列表。
+ """
+ error_messages: List[str] = []
+ for error in exc.errors():
+ location = cls._format_error_location(error.get("loc", ()))
+ error_type = str(error.get("type", ""))
+ error_input = error.get("input")
+ error_context = error.get("ctx", {}) or {}
+
+ if error_type == "missing":
+ error_messages.append(f"缺少必需字段: {location}")
+ elif error_type == "extra_forbidden":
+ error_messages.append(f"存在未声明字段: {location}")
+ elif error_type == "literal_error":
+ expected_values = error_context.get("expected")
+ error_messages.append(f"字段 {location} 的值不合法,必须为 {expected_values}")
+ elif error_type == "model_type":
+ error_messages.append(f"字段 {location} 必须为对象")
+ elif error_type.endswith("_type"):
+ error_messages.append(f"字段 {location} 的类型不正确")
+ elif error_type == "value_error":
+ error_messages.append(f"字段 {location} 校验失败: {error_context.get('error')}")
+ else:
+ error_messages.append(f"字段 {location} 校验失败: {error.get('msg', error_input)}")
+
+ return error_messages
+
+ @staticmethod
+ def _format_error_location(location: Tuple[Any, ...]) -> str:
+ """格式化校验错误字段路径。
+
+ Args:
+ location: Pydantic 提供的字段路径元组。
+
+ Returns:
+ str: 点号连接后的字段路径。
+ """
+ return ".".join(str(item) for item in location) if location else ""
diff --git a/src/plugin_runtime/runner/plugin_loader.py b/src/plugin_runtime/runner/plugin_loader.py
index f07eb593..3eaf9f23 100644
--- a/src/plugin_runtime/runner/plugin_loader.py
+++ b/src/plugin_runtime/runner/plugin_loader.py
@@ -13,16 +13,16 @@ from typing import Any, Dict, Iterator, List, Optional, Set, Tuple
import contextlib
import importlib
import importlib.util
-import json
import os
+import re
import sys
from src.common.logger import get_logger
-from src.plugin_runtime.runner.manifest_validator import ManifestValidator
+from src.plugin_runtime.runner.manifest_validator import ManifestValidator, PluginManifest
logger = get_logger("plugin_runtime.runner.plugin_loader")
-PluginCandidate = Tuple[Path, Dict[str, Any], Path]
+PluginCandidate = Tuple[Path, PluginManifest, Path]
class PluginMeta:
@@ -34,7 +34,7 @@ class PluginMeta:
plugin_dir: str,
module_name: str,
plugin_instance: Any,
- manifest: Dict[str, Any],
+ manifest: PluginManifest,
) -> None:
"""初始化插件元数据。
@@ -43,36 +43,16 @@ class PluginMeta:
plugin_dir: 插件目录绝对路径。
module_name: 插件入口模块名。
plugin_instance: 插件实例对象。
- manifest: 解析后的 manifest 内容。
+ manifest: 解析后的强类型 Manifest。
"""
self.plugin_id = plugin_id
self.plugin_dir = plugin_dir
self.module_name = module_name
self.instance = plugin_instance
self.manifest = manifest
- self.version = manifest.get("version", "1.0.0")
- self.capabilities_required = manifest.get("capabilities", [])
- self.dependencies: List[str] = self._extract_dependencies(manifest)
-
- @staticmethod
- def _extract_dependencies(manifest: Dict[str, Any]) -> List[str]:
- """从 manifest 中提取依赖列表。
-
- Args:
- manifest: 插件 manifest。
-
- Returns:
- List[str]: 规范化后的依赖插件 ID 列表。
- """
- raw = manifest.get("dependencies", [])
- result: List[str] = []
- for dep in raw:
- if isinstance(dep, str):
- result.append(dep.strip())
- elif isinstance(dep, dict):
- if name := str(dep.get("name", "")).strip():
- result.append(name)
- return result
+ self.version = manifest.version
+ self.capabilities_required = list(manifest.capabilities)
+ self.dependencies: List[str] = list(manifest.plugin_dependency_ids)
class PluginLoader:
@@ -98,13 +78,13 @@ class PluginLoader:
def discover_and_load(
self,
plugin_dirs: List[str],
- extra_available: Optional[Set[str]] = None,
+ extra_available: Optional[Dict[str, str]] = None,
) -> List[PluginMeta]:
"""扫描多个目录并加载所有插件。
Args:
plugin_dirs: 插件目录列表。
- extra_available: 额外视为已满足的外部依赖插件 ID 集合。
+ extra_available: 额外视为已满足的外部依赖插件版本映射。
Returns:
List[PluginMeta]: 成功加载的插件元数据列表,按依赖顺序排列。
@@ -164,26 +144,17 @@ class PluginLoader:
def _discover_single_candidate(self, plugin_dir: Path) -> Optional[Tuple[str, PluginCandidate]]:
"""发现并校验单个插件目录。"""
- manifest_path = plugin_dir / "_manifest.json"
plugin_path = plugin_dir / "plugin.py"
-
- if not manifest_path.exists() or not plugin_path.exists():
+ if not plugin_path.exists():
return None
- try:
- with manifest_path.open("r", encoding="utf-8") as manifest_file:
- manifest: Dict[str, Any] = json.load(manifest_file)
- except Exception as e:
- self._failed_plugins[plugin_dir.name] = f"manifest 解析失败: {e}"
- logger.error(f"插件 {plugin_dir.name} manifest 解析失败: {e}")
- return None
-
- if not self._manifest_validator.validate(manifest):
+ manifest = self._manifest_validator.load_from_plugin_path(plugin_dir)
+ if manifest is None:
errors = "; ".join(self._manifest_validator.errors)
self._failed_plugins[plugin_dir.name] = f"manifest 校验失败: {errors}"
return None
- plugin_id = str(manifest.get("name", plugin_dir.name)).strip() or plugin_dir.name
+ plugin_id = manifest.id
return plugin_id, (plugin_dir, manifest, plugin_path)
def _record_duplicate_candidates(self, duplicate_candidates: Dict[str, List[Path]]) -> None:
@@ -253,7 +224,7 @@ class PluginLoader:
"""
removed_modules: List[str] = []
plugin_path = Path(plugin_dir).resolve()
- synthetic_module_name = f"_maibot_plugin_{plugin_id}"
+ synthetic_module_name = self._build_safe_module_name(plugin_id)
for module_name, module in list(sys.modules.items()):
if module_name == synthetic_module_name:
@@ -277,6 +248,21 @@ class PluginLoader:
importlib.invalidate_caches()
return removed_modules
+ @staticmethod
+ def _build_safe_module_name(plugin_id: str) -> str:
+ """将插件 ID 转换为可用于动态导入的安全模块名。
+
+ Args:
+ plugin_id: 原始插件 ID。
+
+ Returns:
+ str: 仅包含字母、数字和下划线的合成模块名。
+ """
+ normalized_plugin_id = re.sub(r"[^0-9A-Za-z_]", "_", str(plugin_id or "").strip())
+ if normalized_plugin_id and normalized_plugin_id[0].isdigit():
+ normalized_plugin_id = f"_{normalized_plugin_id}"
+ return f"_maibot_plugin_{normalized_plugin_id or 'plugin'}"
+
def list_plugins(self) -> List[str]:
"""列出所有已加载的插件 ID"""
return list(self._loaded_plugins.keys())
@@ -286,18 +272,27 @@ class PluginLoader:
"""返回当前记录的失败插件原因映射。"""
return dict(self._failed_plugins)
+ @property
+ def manifest_validator(self) -> ManifestValidator:
+ """返回当前加载器持有的 Manifest 校验器。
+
+ Returns:
+ ManifestValidator: 当前使用的 Manifest 校验器实例。
+ """
+ return self._manifest_validator
+
# ──── 依赖解析 ────────────────────────────────────────────
def resolve_dependencies(
self,
candidates: Dict[str, PluginCandidate],
- extra_available: Optional[Set[str]] = None,
+ extra_available: Optional[Dict[str, str]] = None,
) -> Tuple[List[str], Dict[str, str]]:
"""解析候选插件的依赖顺序。
Args:
candidates: 待加载的候选插件集合。
- extra_available: 视为已满足的外部依赖插件 ID 集合。
+ extra_available: 视为已满足的外部依赖插件版本映射。
Returns:
Tuple[List[str], Dict[str, str]]: 可加载顺序和失败原因映射。
@@ -320,36 +315,71 @@ class PluginLoader:
def _resolve_dependencies(
self,
candidates: Dict[str, PluginCandidate],
- extra_available: Optional[Set[str]] = None,
+ extra_available: Optional[Dict[str, str]] = None,
) -> Tuple[List[str], Dict[str, str]]:
"""拓扑排序解析加载顺序,返回 (有序列表, 失败项 {id: reason})。"""
available = set(candidates.keys())
- satisfied_dependencies = set(extra_available or set())
+ satisfied_dependencies = {
+ str(plugin_id or "").strip(): str(plugin_version or "").strip()
+ for plugin_id, plugin_version in (extra_available or {}).items()
+ if str(plugin_id or "").strip() and str(plugin_version or "").strip()
+ }
dep_graph: Dict[str, Set[str]] = {}
failed: Dict[str, str] = {}
for pid, (_, manifest, _) in candidates.items():
- raw_deps = manifest.get("dependencies", [])
resolved: Set[str] = set()
- missing: List[str] = []
- for dep in raw_deps:
- dep_name = dep if isinstance(dep, str) else str(dep.get("name", ""))
- dep_name = dep_name.strip()
- if not dep_name or dep_name == pid:
+ missing_or_incompatible: List[str] = []
+
+ for dependency in manifest.plugin_dependencies:
+ dependency_id = dependency.id
+ if dependency_id in available:
+ dependency_manifest = candidates[dependency_id][1]
+ if not self._manifest_validator.is_plugin_dependency_satisfied(
+ dependency,
+ dependency_manifest.version,
+ ):
+ missing_or_incompatible.append(
+ f"{dependency_id} (需要 {dependency.version_spec},当前 {dependency_manifest.version})"
+ )
+ continue
+ resolved.add(dependency_id)
continue
- if dep_name in available:
- resolved.add(dep_name)
- elif dep_name in satisfied_dependencies:
+
+ external_dependency_version = satisfied_dependencies.get(dependency_id)
+ if external_dependency_version is None:
+ missing_or_incompatible.append(f"{dependency_id} (未找到依赖插件)")
continue
- else:
- missing.append(dep_name)
- if missing:
- failed[pid] = f"缺少依赖: {', '.join(missing)}"
+
+ if not self._manifest_validator.is_plugin_dependency_satisfied(
+ dependency,
+ external_dependency_version,
+ ):
+ missing_or_incompatible.append(
+ f"{dependency_id} (需要 {dependency.version_spec},当前 {external_dependency_version})"
+ )
+
+ if missing_or_incompatible:
+ failed[pid] = f"依赖未满足: {', '.join(missing_or_incompatible)}"
dep_graph[pid] = resolved
- # 移除失败项
- for pid in failed:
- dep_graph.pop(pid, None)
+ # 迭代传播“依赖自身加载失败”到上游依赖方,避免误报为循环依赖
+ changed = True
+ while changed:
+ changed = False
+ failed_plugin_ids = set(failed)
+ for pid, dependencies in list(dep_graph.items()):
+ if pid in failed:
+ dep_graph.pop(pid, None)
+ continue
+
+ failed_dependencies = sorted(dependency for dependency in dependencies if dependency in failed_plugin_ids)
+ if not failed_dependencies:
+ continue
+
+ failed[pid] = f"依赖未满足: {', '.join(f'{dependency} (依赖插件加载失败)' for dependency in failed_dependencies)}"
+ dep_graph.pop(pid, None)
+ changed = True
# Kahn 拓扑排序
indegree = {pid: len(deps) for pid, deps in dep_graph.items()}
@@ -382,7 +412,7 @@ class PluginLoader:
self,
plugin_id: str,
plugin_dir: Path,
- manifest: Dict[str, Any],
+ manifest: PluginManifest,
plugin_path: Path,
) -> Optional[PluginMeta]:
"""加载单个插件"""
@@ -390,7 +420,7 @@ class PluginLoader:
self._ensure_compat_hook()
# 动态导入插件模块
- module_name = f"_maibot_plugin_{plugin_id}"
+ module_name = self._build_safe_module_name(plugin_id)
spec = importlib.util.spec_from_file_location(module_name, str(plugin_path))
if spec is None or spec.loader is None:
logger.error(f"无法创建模块 spec: {plugin_path}")
@@ -409,7 +439,7 @@ class PluginLoader:
if create_plugin is not None:
instance = create_plugin()
self._validate_sdk_plugin_contract(plugin_id, instance)
- logger.info(f"插件 {plugin_id} v{manifest.get('version', '?')} 加载成功")
+ logger.info(f"插件 {plugin_id} v{manifest.version} 加载成功")
return PluginMeta(
plugin_id=plugin_id,
plugin_dir=str(plugin_dir),
@@ -422,7 +452,7 @@ class PluginLoader:
instance = self._try_load_legacy_plugin(module, plugin_id)
if instance is not None:
logger.info(
- f"插件 {plugin_id} v{manifest.get('version', '?')} 通过旧版兼容层加载成功(请尽快迁移到 maibot_sdk)"
+ f"插件 {plugin_id} v{manifest.version} 通过旧版兼容层加载成功(请尽快迁移到 maibot_sdk)"
)
return PluginMeta(
plugin_id=plugin_id,
diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py
index c0f5e771..f5c32f7e 100644
--- a/src/plugin_runtime/runner/runner_main.py
+++ b/src/plugin_runtime/runner/runner_main.py
@@ -10,7 +10,7 @@
"""
from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, cast
+from typing import Any, Callable, Dict, List, Optional, Protocol, Set, cast
import asyncio
import contextlib
@@ -47,7 +47,7 @@ from src.plugin_runtime.protocol.envelope import (
)
from src.plugin_runtime.protocol.errors import ErrorCode
from src.plugin_runtime.runner.log_handler import RunnerIPCLogHandler
-from src.plugin_runtime.runner.plugin_loader import PluginLoader, PluginMeta
+from src.plugin_runtime.runner.plugin_loader import PluginCandidate, PluginLoader, PluginMeta
from src.plugin_runtime.runner.rpc_client import RPCClient
logger = get_logger("plugin_runtime.runner.main")
@@ -119,7 +119,7 @@ class PluginRunner:
host_address: str,
session_token: str,
plugin_dirs: List[str],
- external_available_plugin_ids: Optional[List[str]] = None,
+ external_available_plugins: Optional[Dict[str, str]] = None,
) -> None:
"""初始化 Runner。
@@ -127,15 +127,15 @@ class PluginRunner:
host_address: Host 的 IPC 地址。
session_token: 握手用会话令牌。
plugin_dirs: 当前 Runner 负责扫描的插件目录列表。
- external_available_plugin_ids: 视为已满足的外部依赖插件 ID 列表。
+ external_available_plugins: 视为已满足的外部依赖插件版本映射。
"""
self._host_address: str = host_address
self._session_token: str = session_token
self._plugin_dirs: List[str] = plugin_dirs
- self._external_available_plugin_ids: Set[str] = {
- str(plugin_id or "").strip()
- for plugin_id in (external_available_plugin_ids or [])
- if str(plugin_id or "").strip()
+ self._external_available_plugins: Dict[str, str] = {
+ str(plugin_id or "").strip(): str(plugin_version or "").strip()
+ for plugin_id, plugin_version in (external_available_plugins or {}).items()
+ if str(plugin_id or "").strip() and str(plugin_version or "").strip()
}
self._rpc_client: RPCClient = RPCClient(host_address, session_token)
@@ -166,7 +166,7 @@ class PluginRunner:
# 3. 加载插件
plugins = self._loader.discover_and_load(
self._plugin_dirs,
- extra_available=self._external_available_plugin_ids,
+ extra_available=self._external_available_plugins,
)
logger.info(f"已加载 {len(plugins)} 个插件")
@@ -611,14 +611,14 @@ class PluginRunner:
self,
plugin_id: str,
reason: str,
- external_available_plugins: Optional[Set[str]] = None,
+ external_available_plugins: Optional[Dict[str, str]] = None,
) -> ReloadPluginResultPayload:
"""按插件 ID 在 Runner 进程内执行精确重载。
Args:
plugin_id: 目标插件 ID。
reason: 重载原因。
- external_available_plugins: 视为已满足的外部依赖插件 ID 集合。
+ external_available_plugins: 视为已满足的外部依赖插件版本映射。
Returns:
ReloadPluginResultPayload: 结构化重载结果。
@@ -626,9 +626,9 @@ class PluginRunner:
candidates, duplicate_candidates = self._loader.discover_candidates(self._plugin_dirs)
failed_plugins: Dict[str, str] = {}
normalized_external_available = {
- str(candidate_plugin_id or "").strip()
- for candidate_plugin_id in (external_available_plugins or set())
- if str(candidate_plugin_id or "").strip()
+ str(candidate_plugin_id or "").strip(): str(candidate_plugin_version or "").strip()
+ for candidate_plugin_id, candidate_plugin_version in (external_available_plugins or {}).items()
+ if str(candidate_plugin_id or "").strip() and str(candidate_plugin_version or "").strip()
}
if plugin_id in duplicate_candidates:
@@ -668,7 +668,7 @@ class PluginRunner:
self._loader.purge_plugin_modules(unload_plugin_id, meta.plugin_dir)
unloaded_plugins.append(unload_plugin_id)
- reload_candidates: Dict[str, Tuple[Path, Dict[str, Any], Path]] = {}
+ reload_candidates: Dict[str, PluginCandidate] = {}
for target_plugin_id in target_plugin_ids:
candidate = candidates.get(target_plugin_id)
if candidate is None:
@@ -678,11 +678,25 @@ class PluginRunner:
load_order, dependency_failures = self._loader.resolve_dependencies(
reload_candidates,
- extra_available=retained_plugin_ids | normalized_external_available,
+ extra_available={
+ **normalized_external_available,
+ **{
+ retained_plugin_id: retained_meta.version
+ for retained_plugin_id in retained_plugin_ids
+ if (retained_meta := self._loader.get_plugin(retained_plugin_id)) is not None
+ },
+ },
)
failed_plugins.update(dependency_failures)
- available_plugins = set(retained_plugin_ids) | normalized_external_available
+ available_plugins = {
+ **normalized_external_available,
+ **{
+ retained_plugin_id: retained_meta.version
+ for retained_plugin_id in retained_plugin_ids
+ if (retained_meta := self._loader.get_plugin(retained_plugin_id)) is not None
+ },
+ }
reloaded_plugins: List[str] = []
for load_plugin_id in load_order:
@@ -694,10 +708,12 @@ class PluginRunner:
continue
_, manifest, _ = candidate
- dependencies = PluginMeta._extract_dependencies(manifest)
- missing_dependencies = [dependency for dependency in dependencies if dependency not in available_plugins]
- if missing_dependencies:
- failed_plugins[load_plugin_id] = f"依赖未满足: {', '.join(missing_dependencies)}"
+ unsatisfied_dependencies = self._loader.manifest_validator.get_unsatisfied_plugin_dependencies(
+ manifest,
+ available_plugin_versions=available_plugins,
+ )
+ if unsatisfied_dependencies:
+ failed_plugins[load_plugin_id] = f"依赖未满足: {', '.join(unsatisfied_dependencies)}"
continue
meta = self._loader.load_candidate(load_plugin_id, candidate)
@@ -710,7 +726,7 @@ class PluginRunner:
failed_plugins[load_plugin_id] = "插件初始化失败"
continue
- available_plugins.add(load_plugin_id)
+ available_plugins[load_plugin_id] = meta.version
reloaded_plugins.append(load_plugin_id)
if failed_plugins:
@@ -1079,7 +1095,7 @@ class PluginRunner:
result = await self._reload_plugin_by_id(
payload.plugin_id,
payload.reason,
- external_available_plugins=set(payload.external_available_plugins),
+ external_available_plugins=dict(payload.external_available_plugins),
)
return envelope.make_response(payload=result.model_dump())
@@ -1185,13 +1201,13 @@ async def _async_main() -> None:
plugin_dirs = [d for d in plugin_dirs_str.split(os.pathsep) if d]
try:
- external_plugin_ids = json.loads(external_plugin_ids_raw) if external_plugin_ids_raw else []
+ external_plugin_ids = json.loads(external_plugin_ids_raw) if external_plugin_ids_raw else {}
except json.JSONDecodeError:
- logger.warning("解析外部依赖插件列表失败,已回退为空列表")
- external_plugin_ids = []
- if not isinstance(external_plugin_ids, list):
- logger.warning("外部依赖插件列表格式非法,已回退为空列表")
- external_plugin_ids = []
+ logger.warning("解析外部依赖插件版本映射失败,已回退为空映射")
+ external_plugin_ids = {}
+ if not isinstance(external_plugin_ids, dict):
+ logger.warning("外部依赖插件版本映射格式非法,已回退为空映射")
+ external_plugin_ids = {}
# sys.path 隔离: 只保留标准库、SDK 包、插件目录
_isolate_sys_path(plugin_dirs)
@@ -1200,7 +1216,10 @@ async def _async_main() -> None:
host_address,
session_token,
plugin_dirs,
- external_available_plugin_ids=[str(plugin_id) for plugin_id in external_plugin_ids],
+ external_available_plugins={
+ str(plugin_id): str(plugin_version)
+ for plugin_id, plugin_version in external_plugin_ids.items()
+ },
)
# 注册信号处理
diff --git a/src/plugins/built_in/emoji_plugin/_manifest.json b/src/plugins/built_in/emoji_plugin/_manifest.json
index d4d262e7..5b53abad 100644
--- a/src/plugins/built_in/emoji_plugin/_manifest.json
+++ b/src/plugins/built_in/emoji_plugin/_manifest.json
@@ -1,32 +1,28 @@
{
- "manifest_version": 1,
- "name": "Emoji插件 (Emoji Actions)",
+ "manifest_version": 2,
"version": "2.0.0",
- "description": "可以发送和管理Emoji",
+ "name": "Emoji插件 (Emoji Actions)",
+ "description": "可以发送和管理 Emoji",
"author": {
"name": "SengokuCola",
"url": "https://github.com/MaiM-with-u"
},
"license": "GPL-v3.0-or-later",
+ "urls": {
+ "repository": "https://github.com/MaiM-with-u/maibot",
+ "homepage": "https://github.com/MaiM-with-u/maibot",
+ "documentation": "https://github.com/MaiM-with-u/maibot",
+ "issues": "https://github.com/MaiM-with-u/maibot/issues"
+ },
"host_application": {
- "min_version": "1.0.0"
+ "min_version": "1.0.0",
+ "max_version": "1.0.0"
},
- "homepage_url": "https://github.com/MaiM-with-u/maibot",
- "repository_url": "https://github.com/MaiM-with-u/maibot",
- "keywords": ["emoji", "action", "built-in"],
- "categories": ["Emoji"],
- "default_locale": "zh-CN",
- "plugin_info": {
- "is_built_in": true,
- "plugin_type": "action_provider",
- "components": [
- {
- "type": "action",
- "name": "emoji",
- "description": "发送表情包辅助表达情绪"
- }
- ]
+ "sdk": {
+ "min_version": "2.0.0",
+ "max_version": "2.99.99"
},
+ "dependencies": [],
"capabilities": [
"emoji.get_random",
"message.get_recent",
@@ -34,5 +30,12 @@
"llm.generate",
"send.emoji",
"config.get"
- ]
+ ],
+ "i18n": {
+ "default_locale": "zh-CN",
+ "supported_locales": [
+ "zh-CN"
+ ]
+ },
+ "id": "builtin.emoji-plugin"
}
diff --git a/src/plugins/built_in/plugin_management/_manifest.json b/src/plugins/built_in/plugin_management/_manifest.json
index a5b52835..a2bfa9ce 100644
--- a/src/plugins/built_in/plugin_management/_manifest.json
+++ b/src/plugins/built_in/plugin_management/_manifest.json
@@ -1,51 +1,46 @@
{
- "manifest_version": 1,
- "name": "插件和组件管理 (Plugin and Component Management)",
+ "manifest_version": 2,
"version": "2.0.0",
- "description": "通过系统API管理插件和组件的生命周期,包括加载、卸载、启用和禁用等操作。",
+ "name": "插件和组件管理 (Plugin and Component Management)",
+ "description": "通过系统 API 管理插件和组件的生命周期,包括加载、卸载、启用和禁用等操作。",
"author": {
"name": "MaiBot团队",
"url": "https://github.com/MaiM-with-u"
},
"license": "GPL-v3.0-or-later",
- "host_application": {
- "min_version": "1.0.0"
+ "urls": {
+ "repository": "https://github.com/MaiM-with-u/maibot",
+ "homepage": "https://github.com/MaiM-with-u/maibot",
+ "documentation": "https://github.com/MaiM-with-u/maibot",
+ "issues": "https://github.com/MaiM-with-u/maibot/issues"
},
- "homepage_url": "https://github.com/MaiM-with-u/maibot",
- "repository_url": "https://github.com/MaiM-with-u/maibot",
- "keywords": [
- "plugins",
- "components",
- "management",
- "built-in"
+ "host_application": {
+ "min_version": "1.0.0",
+ "max_version": "1.0.0"
+ },
+ "sdk": {
+ "min_version": "2.0.0",
+ "max_version": "2.99.99"
+ },
+ "dependencies": [],
+ "capabilities": [
+ "component.get_all_plugins",
+ "component.list_loaded_plugins",
+ "component.list_registered_plugins",
+ "component.enable",
+ "component.disable",
+ "component.load_plugin",
+ "component.unload_plugin",
+ "component.reload_plugin",
+ "send.text",
+ "config.get"
],
- "categories": [
- "Core System",
- "Plugin Management"
- ],
- "default_locale": "zh-CN",
- "locales_path": "_locales",
- "plugin_info": {
- "is_built_in": true,
- "plugin_type": "plugin_management",
- "capabilities": [
- "component.get_all_plugins",
- "component.list_loaded_plugins",
- "component.list_registered_plugins",
- "component.enable",
- "component.disable",
- "component.load_plugin",
- "component.unload_plugin",
- "component.reload_plugin",
- "send.text",
- "config.get"
- ],
- "components": [
- {
- "type": "command",
- "name": "management",
- "description": "管理插件和组件的生命周期,包括加载、卸载、启用和禁用等操作。"
- }
+ "i18n": {
+ "default_locale": "zh-CN",
+ "locales_path": "_locales",
+ "supported_locales": [
+ "zh-CN"
]
- }
-}
\ No newline at end of file
+ },
+ "id": "builtin.plugin-management"
+}
From a61b124c93e456196216fc5724112e62867aedac Mon Sep 17 00:00:00 2001
From: DrSmoothl <1787882683@qq.com>
Date: Mon, 23 Mar 2026 23:05:05 +0800
Subject: [PATCH 38/45] feat: enhance global config handling and component
resolution in plugin runtime
---
src/plugin_runtime/host/supervisor.py | 4 +-
src/plugin_runtime/integration.py | 4 +-
src/plugin_runtime/runner/plugin_loader.py | 7 +-
src/plugin_runtime/runner/runner_main.py | 90 ++++++++++++++++------
4 files changed, 75 insertions(+), 30 deletions(-)
diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py
index ac953bb3..d12014f6 100644
--- a/src/plugin_runtime/host/supervisor.py
+++ b/src/plugin_runtime/host/supervisor.py
@@ -1146,8 +1146,8 @@ class PluginRunnerSupervisor:
Returns:
Dict[str, str]: 传递给 Runner 进程的环境变量映射。
"""
- global_config_snapshot = config_manager.get_global_config().model_dump()
- global_config_snapshot["model"] = config_manager.get_model_config().model_dump()
+ global_config_snapshot = config_manager.get_global_config().model_dump(mode="json")
+ global_config_snapshot["model"] = config_manager.get_model_config().model_dump(mode="json")
return {
ENV_EXTERNAL_PLUGIN_IDS: json.dumps(self._external_available_plugins, ensure_ascii=False),
ENV_GLOBAL_CONFIG_SNAPSHOT: json.dumps(global_config_snapshot, ensure_ascii=False),
diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py
index 092b9597..c34f5ef5 100644
--- a/src/plugin_runtime/integration.py
+++ b/src/plugin_runtime/integration.py
@@ -543,9 +543,9 @@ class PluginRuntimeManager(
normalized_scopes = self._normalize_config_reload_scopes(changed_scopes)
if "bot" in normalized_scopes:
- await self._broadcast_config_reload("bot", config_manager.get_global_config().model_dump())
+ await self._broadcast_config_reload("bot", config_manager.get_global_config().model_dump(mode="json"))
if "model" in normalized_scopes:
- await self._broadcast_config_reload("model", config_manager.get_model_config().model_dump())
+ await self._broadcast_config_reload("model", config_manager.get_model_config().model_dump(mode="json"))
# ─── 事件桥接 ──────────────────────────────────────────────
diff --git a/src/plugin_runtime/runner/plugin_loader.py b/src/plugin_runtime/runner/plugin_loader.py
index 3eaf9f23..6e85714b 100644
--- a/src/plugin_runtime/runner/plugin_loader.py
+++ b/src/plugin_runtime/runner/plugin_loader.py
@@ -53,6 +53,7 @@ class PluginMeta:
self.version = manifest.version
self.capabilities_required = list(manifest.capabilities)
self.dependencies: List[str] = list(manifest.plugin_dependency_ids)
+ self.component_handlers: Dict[str, str] = {}
class PluginLoader:
@@ -421,7 +422,11 @@ class PluginLoader:
# 动态导入插件模块
module_name = self._build_safe_module_name(plugin_id)
- spec = importlib.util.spec_from_file_location(module_name, str(plugin_path))
+ spec = importlib.util.spec_from_file_location(
+ module_name,
+ str(plugin_path),
+ submodule_search_locations=[str(plugin_dir)],
+ )
if spec is None or spec.loader is None:
logger.error(f"无法创建模块 spec: {plugin_path}")
return None
diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py
index f5c32f7e..39c741bd 100644
--- a/src/plugin_runtime/runner/runner_main.py
+++ b/src/plugin_runtime/runner/runner_main.py
@@ -335,6 +335,45 @@ class PluginRunner:
self._rpc_client.register_method("plugin.config_updated", self._handle_config_updated)
self._rpc_client.register_method("plugin.reload", self._handle_reload_plugin)
+ @staticmethod
+ def _resolve_component_handler_name(meta: PluginMeta, component_name: str) -> str:
+ """解析组件名对应的真实处理函数名。
+
+ Args:
+ meta: 已加载插件的元数据。
+ component_name: Host 侧请求中的组件声明名。
+
+ Returns:
+ str: 实际应在插件实例上查找的方法名。
+ """
+ return str(meta.component_handlers.get(component_name, component_name) or component_name)
+
+ def _resolve_component_handler(self, meta: PluginMeta, component_name: str) -> Any:
+ """根据组件声明名解析插件实例上的可调用处理函数。
+
+ Args:
+ meta: 已加载插件的元数据。
+ component_name: Host 侧请求中的组件声明名。
+
+ Returns:
+ Any: 解析到的可调用对象;未找到时返回 ``None``。
+ """
+ instance = meta.instance
+ handler_name = self._resolve_component_handler_name(meta, component_name)
+ handler_method = getattr(instance, handler_name, None)
+ if handler_method is not None:
+ return handler_method
+
+ if handler_name != component_name:
+ legacy_style_handler = getattr(instance, f"handle_{component_name}", None)
+ if legacy_style_handler is not None:
+ return legacy_style_handler
+
+ prefixed_handler = getattr(instance, f"handle_{component_name}", None)
+ if prefixed_handler is not None:
+ return prefixed_handler
+ return getattr(instance, component_name, None)
+
async def _bootstrap_plugin(self, meta: PluginMeta, capabilities_required: Optional[List[str]] = None) -> bool:
"""向 Host 同步插件 bootstrap 能力令牌。"""
payload = BootstrapPluginPayload(
@@ -379,15 +418,27 @@ class PluginRunner:
# 从插件实例获取组件声明(SDK 插件须实现 get_components 方法)
if hasattr(instance, "get_components"):
- components.extend(
- ComponentDeclaration(
- name=comp_info.get("name", ""),
- component_type=comp_info.get("type", ""),
- plugin_id=meta.plugin_id,
- metadata=comp_info.get("metadata", {}),
+ meta.component_handlers.clear()
+ for comp_info in instance.get_components():
+ if not isinstance(comp_info, dict):
+ continue
+
+ component_name = str(comp_info.get("name", "") or "").strip()
+ raw_metadata = comp_info.get("metadata", {})
+ component_metadata = raw_metadata if isinstance(raw_metadata, dict) else {}
+ handler_name = str(component_metadata.get("handler_name", component_name) or component_name).strip()
+
+ if component_name:
+ meta.component_handlers[component_name] = handler_name or component_name
+
+ components.append(
+ ComponentDeclaration(
+ name=component_name,
+ component_type=str(comp_info.get("type", "") or "").strip(),
+ plugin_id=meta.plugin_id,
+ metadata=component_metadata,
+ )
)
- for comp_info in instance.get_components()
- )
if hasattr(instance, "get_config_reload_subscriptions"):
config_reload_subscriptions = list(instance.get_config_reload_subscriptions())
@@ -812,19 +863,13 @@ class PluginRunner:
f"插件 {plugin_id} 未加载",
)
- # 调用插件实例的组件方法
- instance = meta.instance
component_name = invoke.component_name
-
- # 优先查找 handle_ 或直接 方法(新版 SDK 插件)
- handler_method = getattr(instance, f"handle_{component_name}", None)
- if handler_method is None:
- handler_method = getattr(instance, component_name, None)
+ handler_method = self._resolve_component_handler(meta, component_name)
# 回退: 旧版 LegacyPluginAdapter 通过 invoke_component 统一桥接
- if (handler_method is None or not callable(handler_method)) and hasattr(instance, "invoke_component"):
+ if (handler_method is None or not callable(handler_method)) and hasattr(meta.instance, "invoke_component"):
try:
- result = await instance.invoke_component(component_name, **invoke.args)
+ result = await meta.instance.invoke_component(component_name, **invoke.args)
resp_payload = InvokeResultPayload(success=True, result=result)
return envelope.make_response(payload=resp_payload.model_dump())
except Exception as e:
@@ -871,11 +916,8 @@ class PluginRunner:
f"插件 {plugin_id} 未加载",
)
- instance = meta.instance
component_name = invoke.component_name
- handler_method = getattr(instance, f"handle_{component_name}", None)
- if handler_method is None:
- handler_method = getattr(instance, component_name, None)
+ handler_method = self._resolve_component_handler(meta, component_name)
if handler_method is None or not callable(handler_method):
return envelope.make_error_response(
@@ -933,9 +975,8 @@ class PluginRunner:
f"插件 {plugin_id} 未加载",
)
- instance = meta.instance
component_name = invoke.component_name
- handler_method = getattr(instance, f"handle_{component_name}", None) or getattr(instance, component_name, None)
+ handler_method = self._resolve_component_handler(meta, component_name)
if handler_method is None or not callable(handler_method):
return envelope.make_error_response(
ErrorCode.E_METHOD_NOT_ALLOWED.value,
@@ -985,9 +1026,8 @@ class PluginRunner:
f"插件 {plugin_id} 未加载",
)
- instance = meta.instance
component_name = invoke.component_name
- handler_method = getattr(instance, f"handle_{component_name}", None) or getattr(instance, component_name, None)
+ handler_method = self._resolve_component_handler(meta, component_name)
if handler_method is None or not callable(handler_method):
return envelope.make_error_response(
From 17b7306188dc351240d69f0444e537d5b9324a59 Mon Sep 17 00:00:00 2001
From: DrSmoothl <1787882683@qq.com>
Date: Mon, 23 Mar 2026 23:11:57 +0800
Subject: [PATCH 39/45] =?UTF-8?q?fix:=20=E4=BD=BF=E7=94=A8=20f-string=20?=
=?UTF-8?q?=E6=94=B9=E8=BF=9B=E6=97=A5=E5=BF=97=E8=AE=B0=E5=BD=95=E6=A0=BC?=
=?UTF-8?q?=E5=BC=8F?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
src/platform_io/manager.py | 11 +++++------
1 file changed, 5 insertions(+), 6 deletions(-)
diff --git a/src/platform_io/manager.py b/src/platform_io/manager.py
index ab1b11e5..dee553a6 100644
--- a/src/platform_io/manager.py
+++ b/src/platform_io/manager.py
@@ -71,7 +71,7 @@ class PlatformIOManager:
try:
await driver.stop()
except Exception:
- logger.exception("回滚驱动停止失败: driver_id=%s", driver.driver_id)
+ logger.exception(f"回滚驱动停止失败: driver_id={driver.driver_id}")
raise
self._started = True
@@ -104,7 +104,7 @@ class PlatformIOManager:
await driver.stop()
except Exception as exc:
stop_errors.append(f"{driver.driver_id}: {exc}")
- logger.exception("驱动停止失败: driver_id=%s", driver.driver_id)
+ logger.exception(f"驱动停止失败: driver_id={driver.driver_id}")
self._started = False
self._deduplicator.clear()
@@ -448,9 +448,8 @@ class PlatformIOManager:
if not self._receive_route_table.has_binding_for_driver(envelope.route_key, envelope.driver_id):
logger.info(
- "忽略未登记到接收路由表的入站消息: route=%s driver=%s",
- envelope.route_key,
- envelope.driver_id,
+ f"忽略未登记到接收路由表的入站消息: route={envelope.route_key} "
+ f"driver={envelope.driver_id}"
)
return False
@@ -461,7 +460,7 @@ class PlatformIOManager:
dedupe_key = self._build_inbound_dedupe_key(envelope)
if dedupe_key is not None:
if not self._deduplicator.mark_seen(dedupe_key):
- logger.info("忽略重复入站消息: dedupe_key=%s", dedupe_key)
+ logger.info(f"忽略重复入站消息: dedupe_key={dedupe_key}")
return False
await self._inbound_dispatcher(envelope)
From 78858f70043c9c6b7e90b9fffffcee0dec1bfb5c Mon Sep 17 00:00:00 2001
From: DrSmoothl <1787882683@qq.com>
Date: Mon, 23 Mar 2026 23:15:51 +0800
Subject: [PATCH 40/45] =?UTF-8?q?fix:=20=E6=9B=B4=E6=96=B0=E6=8F=92?=
=?UTF-8?q?=E4=BB=B6=E9=80=82=E9=85=8D=E5=99=A8=E5=90=8D=E7=A7=B0=E4=BB=A5?=
=?UTF-8?q?=E5=8F=8D=E6=98=A0=E6=96=B0=E7=9A=84=E5=91=BD=E5=90=8D=E7=BA=A6?=
=?UTF-8?q?=E5=AE=9A?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
src/common/logger_color_and_mapping.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/src/common/logger_color_and_mapping.py b/src/common/logger_color_and_mapping.py
index 4044d2dc..863c9c1e 100644
--- a/src/common/logger_color_and_mapping.py
+++ b/src/common/logger_color_and_mapping.py
@@ -65,7 +65,7 @@ MODULE_COLORS: Dict[str, Tuple[str, Optional[str], bool]] = {
"plugin_runtime.runner.rpc_client": ("#8787ff", None, False),
"plugin_runtime.runner.manifest_validator": ("#5fafff", None, False),
"plugin_runtime.runner.plugin_loader": ("#00afaf", None, False),
- "plugin.napcat_adapter_builtin": ("#00af87", None, False),
+ "plugin.maibot-team.napcat-adapter": ("#00af87", None, False),
"webui": ("#5f87ff", None, False),
"webui.app": ("#5f87d7", None, False),
"webui.api": ("#5fafff", None, False),
@@ -173,7 +173,7 @@ MODULE_ALIASES = {
"plugin_runtime.runner.rpc_client": "插件RPC客户端",
"plugin_runtime.runner.manifest_validator": "插件清单校验",
"plugin_runtime.runner.plugin_loader": "插件加载器",
- "plugin.napcat_adapter_builtin": "NapCat内置适配器",
+ "plugin.maibot-team.napcat-adapter": "NapCat内置适配器",
"webui": "WebUI",
"webui.app": "WebUI应用",
"webui.api": "WebUI接口",
From 1b61e515541a74ee6faa8b99423bcaea8ae6d344 Mon Sep 17 00:00:00 2001
From: DrSmoothl <1787882683@qq.com>
Date: Tue, 24 Mar 2026 10:55:58 +0800
Subject: [PATCH 41/45] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E6=89=B9?=
=?UTF-8?q?=E9=87=8F=E6=8F=92=E4=BB=B6=E9=87=8D=E8=BD=BD=E5=8A=9F=E8=83=BD?=
=?UTF-8?q?=E5=8F=8A=E7=9B=B8=E5=85=B3=E6=B5=8B=E8=AF=95?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
pytests/test_plugin_runtime.py | 171 ++++++++++++++++++
src/plugin_runtime/host/component_registry.py | 6 +-
src/plugin_runtime/host/supervisor.py | 67 ++++++-
src/plugin_runtime/protocol/envelope.py | 29 +++
src/plugin_runtime/runner/runner_main.py | 166 ++++++++++++++---
5 files changed, 398 insertions(+), 41 deletions(-)
diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py
index 1d93ae24..f3e1e7ce 100644
--- a/pytests/test_plugin_runtime.py
+++ b/pytests/test_plugin_runtime.py
@@ -676,6 +676,101 @@ class TestSDK:
methods = [call["method"] for call in runner._rpc_client.calls]
assert methods == ["plugin.bootstrap", "plugin.register_components", "cap.call", "runner.ready"]
+ @pytest.mark.asyncio
+ async def test_runner_batch_reload_merges_overlapping_reverse_dependents(self, monkeypatch):
+ """批量重载应只对重叠依赖闭包执行一次 unload/load。"""
+ from src.plugin_runtime.runner.runner_main import PluginRunner
+
+ runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[])
+ plugin_a_id = "test.plugin-a"
+ plugin_b_id = "test.plugin-b"
+ plugin_c_id = "test.plugin-c"
+
+ def build_meta(plugin_id: str, dependencies: list[str]) -> SimpleNamespace:
+ return SimpleNamespace(
+ plugin_id=plugin_id,
+ dependencies=dependencies,
+ plugin_dir=f"/tmp/{plugin_id}",
+ version="1.0.0",
+ instance=SimpleNamespace(),
+ )
+
+ loaded_metas = {
+ plugin_a_id: build_meta(plugin_a_id, []),
+ plugin_b_id: build_meta(plugin_b_id, [plugin_a_id]),
+ plugin_c_id: build_meta(plugin_c_id, [plugin_b_id]),
+ }
+ reloaded_metas = {
+ plugin_id: build_meta(plugin_id, list(meta.dependencies))
+ for plugin_id, meta in loaded_metas.items()
+ }
+ candidates = {
+ plugin_a_id: (
+ "dir_plugin_a",
+ build_test_manifest_model(plugin_a_id),
+ "plugin_a/plugin.py",
+ ),
+ plugin_b_id: (
+ "dir_plugin_b",
+ build_test_manifest_model(
+ plugin_b_id,
+ dependencies=[{"type": "plugin", "id": plugin_a_id, "version_spec": ">=1.0.0,<2.0.0"}],
+ ),
+ "plugin_b/plugin.py",
+ ),
+ plugin_c_id: (
+ "dir_plugin_c",
+ build_test_manifest_model(
+ plugin_c_id,
+ dependencies=[{"type": "plugin", "id": plugin_b_id, "version_spec": ">=1.0.0,<2.0.0"}],
+ ),
+ "plugin_c/plugin.py",
+ ),
+ }
+ unloaded_plugins: list[str] = []
+ activated_plugins: list[str] = []
+
+ monkeypatch.setattr(runner._loader, "discover_candidates", lambda plugin_dirs: (candidates, {}))
+ monkeypatch.setattr(runner._loader, "list_plugins", lambda: sorted(loaded_metas.keys()))
+ monkeypatch.setattr(runner._loader, "get_plugin", lambda plugin_id: loaded_metas.get(plugin_id))
+ monkeypatch.setattr(
+ runner._loader,
+ "remove_loaded_plugin",
+ lambda plugin_id: loaded_metas.pop(plugin_id, None),
+ )
+ monkeypatch.setattr(runner._loader, "purge_plugin_modules", lambda plugin_id, plugin_dir: [])
+ monkeypatch.setattr(
+ runner._loader,
+ "resolve_dependencies",
+ lambda reload_candidates, extra_available=None: (sorted(reload_candidates.keys()), {}),
+ )
+ monkeypatch.setattr(
+ runner._loader,
+ "load_candidate",
+ lambda plugin_id, candidate: reloaded_metas[plugin_id],
+ )
+
+ async def fake_unload_plugin(meta, reason, purge_modules=False):
+ del reason, purge_modules
+ unloaded_plugins.append(meta.plugin_id)
+ loaded_metas.pop(meta.plugin_id, None)
+
+ async def fake_activate_plugin(meta):
+ activated_plugins.append(meta.plugin_id)
+ loaded_metas[meta.plugin_id] = meta
+ return True
+
+ monkeypatch.setattr(runner, "_unload_plugin", fake_unload_plugin)
+ monkeypatch.setattr(runner, "_activate_plugin", fake_activate_plugin)
+
+ result = await runner._reload_plugins_by_ids([plugin_a_id, plugin_b_id], reason="manual")
+
+ assert result.success is True
+ assert result.requested_plugin_ids == [plugin_a_id, plugin_b_id]
+ assert unloaded_plugins == [plugin_c_id, plugin_b_id, plugin_a_id]
+ assert activated_plugins == [plugin_a_id, plugin_b_id, plugin_c_id]
+ assert result.reloaded_plugins == [plugin_a_id, plugin_b_id, plugin_c_id]
+
class TestPluginSdkUsage:
"""验证仓库内插件按新 SDK 归一化返回值工作。"""
@@ -1220,6 +1315,25 @@ class TestDependencyResolution:
sys.path[:] = original_path
sys.meta_path[:] = original_meta_path
+ def test_isolate_sys_path_blocks_disallowed_src_imports(self):
+ import importlib
+
+ from src.plugin_runtime.runner import runner_main
+
+ original_path = list(sys.path)
+ original_meta_path = list(sys.meta_path)
+ sys.modules.pop("src.forbidden_demo", None)
+
+ try:
+ runner_main._isolate_sys_path([])
+
+ with pytest.raises(ImportError, match="不允许导入主程序模块"):
+ importlib.import_module("src.forbidden_demo")
+ finally:
+ sys.path[:] = original_path
+ sys.meta_path[:] = original_meta_path
+ sys.modules.pop("src.forbidden_demo", None)
+
# ─── Host-side ComponentRegistry 测试 ──────────────────────
@@ -1264,6 +1378,30 @@ class TestComponentRegistry:
assert stats["command"] == 1
assert stats["tool"] == 1
+ def test_register_command_with_invalid_regex_only_warns(self, monkeypatch):
+ from src.plugin_runtime.host.component_registry import ComponentRegistry
+
+ reg = ComponentRegistry()
+ warnings: list[str] = []
+ monkeypatch.setattr(
+ "src.plugin_runtime.host.component_registry.logger.warning",
+ lambda message: warnings.append(str(message)),
+ )
+
+ success = reg.register_component(
+ "broken",
+ "command",
+ "plugin_a",
+ {
+ "command_pattern": "[",
+ },
+ )
+
+ assert success is True
+ assert reg.get_component("plugin_a.broken") is not None
+ assert warnings
+ assert "plugin_a.broken" in warnings[0]
+
def test_query_by_type(self):
from src.plugin_runtime.host.component_registry import ComponentRegistry
@@ -2303,6 +2441,39 @@ class TestSupervisor:
assert supervisor.component_registry.get_component("plugin_a.handler") is not None
assert supervisor.component_registry.get_component("plugin_a.obsolete") is None
+ @pytest.mark.asyncio
+ async def test_reload_plugins_uses_batch_rpc_for_multiple_roots(self):
+ from src.plugin_runtime.host.supervisor import PluginSupervisor
+ from src.plugin_runtime.protocol.envelope import ReloadPluginsResultPayload
+
+ supervisor = PluginSupervisor(plugin_dirs=[])
+ sent_requests: list[tuple[str, dict[str, object], int]] = []
+
+ class FakeRPCServer:
+ async def send_request(self, method, payload, timeout_ms=5000, **kwargs):
+ del kwargs
+ sent_requests.append((method, payload, timeout_ms))
+ return SimpleNamespace(
+ payload=ReloadPluginsResultPayload(
+ success=True,
+ requested_plugin_ids=["plugin_a", "plugin_b"],
+ reloaded_plugins=["plugin_a", "plugin_b", "plugin_c"],
+ unloaded_plugins=["plugin_c", "plugin_b", "plugin_a"],
+ ).model_dump()
+ )
+
+ supervisor._rpc_server = FakeRPCServer()
+
+ reloaded = await supervisor.reload_plugins(["plugin_a", "plugin_b", "plugin_a"], reason="manual")
+
+ assert reloaded is True
+ assert len(sent_requests) == 1
+ method, payload, timeout_ms = sent_requests[0]
+ assert method == "plugin.reload_batch"
+ assert payload["plugin_ids"] == ["plugin_a", "plugin_b"]
+ assert payload["reason"] == "manual"
+ assert timeout_ms >= 10000
+
@pytest.mark.asyncio
async def test_reload_rolls_back_when_runner_ready_not_received(self, monkeypatch):
from src.plugin_runtime.host.supervisor import PluginSupervisor
diff --git a/src/plugin_runtime/host/component_registry.py b/src/plugin_runtime/host/component_registry.py
index 1c073490..97fdca30 100644
--- a/src/plugin_runtime/host/component_registry.py
+++ b/src/plugin_runtime/host/component_registry.py
@@ -75,14 +75,14 @@ class CommandEntry(ComponentEntry):
"""Command 组件条目"""
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
- self.compiled_pattern: Optional[re.Pattern] = None
+ super().__init__(name, component_type, plugin_id, metadata)
self.aliases: List[str] = metadata.get("aliases", [])
+ self.compiled_pattern: Optional[re.Pattern] = None
if pattern := metadata.get("command_pattern", ""):
try:
self.compiled_pattern = re.compile(pattern)
- except re.error as e:
+ except (re.error, TypeError) as e:
logger.warning(f"命令 {self.full_name} 正则编译失败: {e}")
- super().__init__(name, component_type, plugin_id, metadata)
class ToolEntry(ComponentEntry):
diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py
index d12014f6..08638d16 100644
--- a/src/plugin_runtime/host/supervisor.py
+++ b/src/plugin_runtime/host/supervisor.py
@@ -33,6 +33,8 @@ from src.plugin_runtime.protocol.envelope import (
ReceiveExternalMessageResultPayload,
RegisterPluginPayload,
ReloadPluginResultPayload,
+ ReloadPluginsPayload,
+ ReloadPluginsResultPayload,
RouteMessagePayload,
RunnerReadyPayload,
ShutdownPayload,
@@ -194,6 +196,35 @@ class PluginRunnerSupervisor:
for plugin_id, registration in self._registered_plugins.items()
}
+ @staticmethod
+ def _normalize_reload_plugin_ids(plugin_ids: Optional[List[str] | str]) -> List[str]:
+ """规范化批量重载入参。
+
+ Args:
+ plugin_ids: 原始插件 ID 列表或单个插件 ID。
+
+ Returns:
+ List[str]: 去重且去空白后的插件 ID 列表。
+ """
+
+ raw_plugin_ids: List[str]
+ if plugin_ids is None:
+ raw_plugin_ids = []
+ elif isinstance(plugin_ids, str):
+ raw_plugin_ids = [plugin_ids]
+ else:
+ raw_plugin_ids = list(plugin_ids)
+
+ normalized_plugin_ids: List[str] = []
+ seen_plugin_ids: set[str] = set()
+ for plugin_id in raw_plugin_ids:
+ normalized_plugin_id = str(plugin_id or "").strip()
+ if not normalized_plugin_id or normalized_plugin_id in seen_plugin_ids:
+ continue
+ seen_plugin_ids.add(normalized_plugin_id)
+ normalized_plugin_ids.append(normalized_plugin_id)
+ return normalized_plugin_ids
+
async def dispatch_event(
self,
event_type: str,
@@ -420,7 +451,7 @@ class PluginRunnerSupervisor:
async def reload_plugins(
self,
- plugin_ids: Optional[List[str]] = None,
+ plugin_ids: Optional[List[str] | str] = None,
reason: str = "manual",
external_available_plugins: Optional[Dict[str, str]] = None,
) -> bool:
@@ -434,19 +465,37 @@ class PluginRunnerSupervisor:
Returns:
bool: 是否全部重载成功。
"""
- target_plugin_ids = plugin_ids or list(self._registered_plugins.keys())
- ordered_plugin_ids = list(dict.fromkeys(target_plugin_ids))
- success = True
+ ordered_plugin_ids = self._normalize_reload_plugin_ids(plugin_ids)
+ if not ordered_plugin_ids:
+ ordered_plugin_ids = list(self._registered_plugins.keys())
+ if not ordered_plugin_ids:
+ return True
- for plugin_id in ordered_plugin_ids:
- reloaded = await self.reload_plugin(
- plugin_id=plugin_id,
+ if len(ordered_plugin_ids) == 1:
+ return await self.reload_plugin(
+ plugin_id=ordered_plugin_ids[0],
reason=reason,
external_available_plugins=external_available_plugins,
)
- success = success and reloaded
- return success
+ try:
+ response = await self._rpc_server.send_request(
+ "plugin.reload_batch",
+ payload=ReloadPluginsPayload(
+ plugin_ids=ordered_plugin_ids,
+ reason=reason,
+ external_available_plugins=external_available_plugins or self._external_available_plugins,
+ ).model_dump(),
+ timeout_ms=max(int(self._runner_spawn_timeout * 1000), 10000),
+ )
+ except Exception as exc:
+ logger.error(f"插件批量重载请求失败: {exc}")
+ return False
+
+ result = ReloadPluginsResultPayload.model_validate(response.payload)
+ if not result.success:
+ logger.warning(f"插件批量重载失败: {result.failed_plugins}")
+ return result.success
async def notify_plugin_config_updated(
self,
diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py
index c2c89a0f..e738d019 100644
--- a/src/plugin_runtime/protocol/envelope.py
+++ b/src/plugin_runtime/protocol/envelope.py
@@ -289,6 +289,20 @@ class ReloadPluginPayload(BaseModel):
"""可视为已满足的外部依赖插件版本映射"""
+class ReloadPluginsPayload(BaseModel):
+ """批量插件重载请求载荷。"""
+
+ plugin_ids: List[str] = Field(default_factory=list, description="目标插件 ID 列表")
+ """目标插件 ID 列表"""
+ reason: str = Field(default="manual", description="重载原因")
+ """重载原因"""
+ external_available_plugins: Dict[str, str] = Field(
+ default_factory=dict,
+ description="可视为已满足的外部依赖插件版本映射",
+ )
+ """可视为已满足的外部依赖插件版本映射"""
+
+
class ReloadPluginResultPayload(BaseModel):
"""插件重载结果载荷。"""
@@ -304,6 +318,21 @@ class ReloadPluginResultPayload(BaseModel):
"""重载失败的插件及原因"""
+class ReloadPluginsResultPayload(BaseModel):
+ """批量插件重载结果载荷。"""
+
+ success: bool = Field(description="是否重载成功")
+ """是否重载成功"""
+ requested_plugin_ids: List[str] = Field(default_factory=list, description="请求重载的插件 ID 列表")
+ """请求重载的插件 ID 列表"""
+ reloaded_plugins: List[str] = Field(default_factory=list, description="成功完成重载的插件列表")
+ """成功完成重载的插件列表"""
+ unloaded_plugins: List[str] = Field(default_factory=list, description="本次已卸载的插件列表")
+ """本次已卸载的插件列表"""
+ failed_plugins: Dict[str, str] = Field(default_factory=dict, description="重载失败的插件及原因")
+ """重载失败的插件及原因"""
+
+
class MessageGatewayStateUpdatePayload(BaseModel):
"""消息网关运行时状态更新载荷。"""
diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py
index 39c741bd..e66d2fab 100644
--- a/src/plugin_runtime/runner/runner_main.py
+++ b/src/plugin_runtime/runner/runner_main.py
@@ -42,6 +42,8 @@ from src.plugin_runtime.protocol.envelope import (
RegisterPluginPayload,
ReloadPluginPayload,
ReloadPluginResultPayload,
+ ReloadPluginsPayload,
+ ReloadPluginsResultPayload,
RunnerReadyPayload,
UnregisterPluginPayload,
)
@@ -334,6 +336,7 @@ class PluginRunner:
self._rpc_client.register_method("plugin.shutdown", self._handle_shutdown)
self._rpc_client.register_method("plugin.config_updated", self._handle_config_updated)
self._rpc_client.register_method("plugin.reload", self._handle_reload_plugin)
+ self._rpc_client.register_method("plugin.reload_batch", self._handle_reload_plugins)
@staticmethod
def _resolve_component_handler_name(meta: PluginMeta, component_name: str) -> str:
@@ -597,6 +600,21 @@ class PluginRunner:
return impacted_plugins
+ def _collect_reverse_dependents_for_roots(self, plugin_ids: Set[str]) -> Set[str]:
+ """收集多个根插件对应的反向依赖并集。
+
+ Args:
+ plugin_ids: 根插件 ID 集合。
+
+ Returns:
+ Set[str]: 所有根插件及其反向依赖并集。
+ """
+
+ impacted_plugins: Set[str] = set()
+ for plugin_id in sorted(plugin_ids):
+ impacted_plugins.update(self._collect_reverse_dependents(plugin_id))
+ return impacted_plugins
+
def _build_unload_order(self, plugin_ids: Set[str]) -> List[str]:
"""构建受影响插件的卸载顺序。
@@ -635,6 +653,20 @@ class PluginRunner:
return list(reversed(load_order))
+ @staticmethod
+ def _normalize_requested_plugin_ids(plugin_ids: List[str]) -> List[str]:
+ """规范化批量重载请求中的插件 ID 列表。"""
+
+ normalized_plugin_ids: List[str] = []
+ seen_plugin_ids: Set[str] = set()
+ for plugin_id in plugin_ids:
+ normalized_plugin_id = str(plugin_id or "").strip()
+ if not normalized_plugin_id or normalized_plugin_id in seen_plugin_ids:
+ continue
+ seen_plugin_ids.add(normalized_plugin_id)
+ normalized_plugin_ids.append(normalized_plugin_id)
+ return normalized_plugin_ids
+
@staticmethod
def _finalize_failed_reload_messages(
failed_plugins: Dict[str, str],
@@ -674,6 +706,31 @@ class PluginRunner:
Returns:
ReloadPluginResultPayload: 结构化重载结果。
"""
+ batch_result = await self._reload_plugins_by_ids(
+ [plugin_id],
+ reason,
+ external_available_plugins=external_available_plugins,
+ )
+ return ReloadPluginResultPayload(
+ success=batch_result.success,
+ requested_plugin_id=plugin_id,
+ reloaded_plugins=batch_result.reloaded_plugins,
+ unloaded_plugins=batch_result.unloaded_plugins,
+ failed_plugins=batch_result.failed_plugins,
+ )
+
+ async def _reload_plugins_by_ids(
+ self,
+ plugin_ids: List[str],
+ reason: str,
+ external_available_plugins: Optional[Dict[str, str]] = None,
+ ) -> ReloadPluginsResultPayload:
+ """按插件 ID 列表在 Runner 进程内执行一次批量重载。"""
+
+ normalized_plugin_ids = self._normalize_requested_plugin_ids(plugin_ids)
+ if not normalized_plugin_ids:
+ return ReloadPluginsResultPayload(success=True, requested_plugin_ids=[])
+
candidates, duplicate_candidates = self._loader.discover_candidates(self._plugin_dirs)
failed_plugins: Dict[str, str] = {}
normalized_external_available = {
@@ -682,28 +739,35 @@ class PluginRunner:
if str(candidate_plugin_id or "").strip() and str(candidate_plugin_version or "").strip()
}
- if plugin_id in duplicate_candidates:
- conflict_paths = ", ".join(str(path) for path in duplicate_candidates[plugin_id])
- return ReloadPluginResultPayload(
- success=False,
- requested_plugin_id=plugin_id,
- failed_plugins={plugin_id: f"检测到重复插件 ID: {conflict_paths}"},
- )
-
loaded_plugin_ids = set(self._loader.list_plugins())
- plugin_is_loaded = plugin_id in loaded_plugin_ids
- plugin_has_candidate = plugin_id in candidates
+ reload_root_ids: Set[str] = set()
+ for plugin_id in normalized_plugin_ids:
+ if plugin_id in duplicate_candidates:
+ conflict_paths = ", ".join(str(path) for path in duplicate_candidates[plugin_id])
+ failed_plugins[plugin_id] = f"检测到重复插件 ID: {conflict_paths}"
+ continue
- if not plugin_is_loaded and not plugin_has_candidate:
- return ReloadPluginResultPayload(
+ plugin_is_loaded = plugin_id in loaded_plugin_ids
+ plugin_has_candidate = plugin_id in candidates
+ if not plugin_is_loaded and not plugin_has_candidate:
+ failed_plugins[plugin_id] = "插件不存在或未找到合法的 manifest/plugin.py"
+ continue
+
+ reload_root_ids.add(plugin_id)
+
+ if not reload_root_ids:
+ return ReloadPluginsResultPayload(
success=False,
- requested_plugin_id=plugin_id,
- failed_plugins={plugin_id: "插件不存在或未找到合法的 manifest/plugin.py"},
+ requested_plugin_ids=normalized_plugin_ids,
+ failed_plugins=failed_plugins,
)
- target_plugin_ids: Set[str] = {plugin_id}
- if plugin_is_loaded:
- target_plugin_ids = self._collect_reverse_dependents(plugin_id)
+ target_plugin_ids: Set[str] = {
+ plugin_id for plugin_id in reload_root_ids if plugin_id not in loaded_plugin_ids
+ }
+ loaded_root_plugin_ids = reload_root_ids & loaded_plugin_ids
+ if loaded_root_plugin_ids:
+ target_plugin_ids.update(self._collect_reverse_dependents_for_roots(loaded_root_plugin_ids))
unload_order = self._build_unload_order(target_plugin_ids & loaded_plugin_ids)
unloaded_plugins: List[str] = []
@@ -813,19 +877,19 @@ class PluginRunner:
if not restored:
rollback_failures[rollback_plugin_id] = "无法重新激活旧版本"
- return ReloadPluginResultPayload(
+ return ReloadPluginsResultPayload(
success=False,
- requested_plugin_id=plugin_id,
+ requested_plugin_ids=normalized_plugin_ids,
reloaded_plugins=[],
unloaded_plugins=unloaded_plugins,
failed_plugins=self._finalize_failed_reload_messages(failed_plugins, rollback_failures),
)
- requested_plugin_success = plugin_id in reloaded_plugins
+ requested_plugin_success = all(plugin_id in reloaded_plugins for plugin_id in reload_root_ids)
- return ReloadPluginResultPayload(
- success=requested_plugin_success,
- requested_plugin_id=plugin_id,
+ return ReloadPluginsResultPayload(
+ success=requested_plugin_success and not failed_plugins,
+ requested_plugin_ids=normalized_plugin_ids,
reloaded_plugins=reloaded_plugins,
unloaded_plugins=unloaded_plugins,
failed_plugins=failed_plugins,
@@ -1139,6 +1203,29 @@ class PluginRunner:
)
return envelope.make_response(payload=result.model_dump())
+ async def _handle_reload_plugins(self, envelope: Envelope) -> Envelope:
+ """处理批量插件重载请求。"""
+
+ try:
+ payload = ReloadPluginsPayload.model_validate(envelope.payload)
+ except Exception as exc:
+ return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
+
+ if self._reload_lock.locked():
+ requested_plugin_ids = ", ".join(self._normalize_requested_plugin_ids(payload.plugin_ids)) or ""
+ return envelope.make_error_response(
+ ErrorCode.E_RELOAD_IN_PROGRESS.value,
+ f"插件 {requested_plugin_ids} 批量重载请求被拒绝:已有重载任务正在执行",
+ )
+
+ async with self._reload_lock:
+ result = await self._reload_plugins_by_ids(
+ list(payload.plugin_ids),
+ payload.reason,
+ external_available_plugins=dict(payload.external_available_plugins),
+ )
+ return envelope.make_response(payload=result.model_dump())
+
def request_capability(self) -> RPCClient:
"""获取 RPC 客户端(供 SDK 使用,发起能力调用)"""
return self._rpc_client
@@ -1153,6 +1240,7 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
防止插件代码 import 主程序模块读取运行时数据。
"""
import importlib.abc
+ from importlib.machinery import ModuleSpec
import sysconfig
# 保留: 标准库路径 + site-packages(含 SDK 和依赖)
@@ -1195,6 +1283,20 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
# 安装 import 钩子,阻止插件导入主程序核心模块
# 仅允许 src.plugin_runtime 和 src.common,拒绝其他 src.* 子包
+ class _BlockedSrcModuleLoader(importlib.abc.Loader):
+ """阻止被 Runner 允许列表之外的主程序模块完成导入。"""
+
+ def __init__(self, fullname: str) -> None:
+ self._fullname = fullname
+
+ def create_module(self, spec: ModuleSpec) -> None:
+ del spec
+ return None
+
+ def exec_module(self, module: Any) -> None:
+ del module
+ raise ImportError(f"Runner 子进程不允许导入主程序模块: {self._fullname}")
+
class _PluginImportBlocker(importlib.abc.MetaPathFinder):
"""阻止 Runner 子进程导入主程序核心模块。
@@ -1203,14 +1305,15 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
"""
_ALLOWED_SRC_PREFIXES = ("src.plugin_runtime", "src.common")
+ __maibot_runner_plugin_import_blocker__ = True
- def find_module(self, fullname: str, path: Any = None) -> Any:
+ def find_spec(self, fullname: str, path: Any = None, target: Any = None) -> ModuleSpec | None:
"""决定是否拦截指定模块导入。"""
- return self if self._should_block(fullname) else None
-
- def load_module(self, fullname: str) -> None:
- """阻止被拦截模块继续导入。"""
- raise ImportError(f"Runner 子进程不允许导入主程序模块: {fullname}")
+ del path, target
+ if not self._should_block(fullname):
+ return None
+ # Python 3.13+/3.14 会优先走 find_spec,不再依赖 find_module。
+ return ModuleSpec(fullname, _BlockedSrcModuleLoader(fullname), is_package=True)
def _should_block(self, fullname: str) -> bool:
"""判断给定模块名是否应被阻止导入。"""
@@ -1222,6 +1325,11 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
fullname == prefix or fullname.startswith(f"{prefix}.") for prefix in self._ALLOWED_SRC_PREFIXES
)
+ sys.meta_path[:] = [
+ finder
+ for finder in sys.meta_path
+ if not getattr(finder, "__maibot_runner_plugin_import_blocker__", False)
+ ]
sys.meta_path.insert(0, _PluginImportBlocker())
From f4a9afc452edc08809f93da32d23f8945b1f67d2 Mon Sep 17 00:00:00 2001
From: DrSmoothl <1787882683@qq.com>
Date: Tue, 24 Mar 2026 11:43:23 +0800
Subject: [PATCH 42/45] =?UTF-8?q?feat:=20=E5=A2=9E=E5=BC=BA=20RPC=20?=
=?UTF-8?q?=E6=9C=8D=E5=8A=A1=E5=99=A8=E8=BF=9E=E6=8E=A5=E5=A4=84=E7=90=86?=
=?UTF-8?q?=EF=BC=8C=E6=B7=BB=E5=8A=A0=E8=BF=9E=E6=8E=A5=E9=94=81=E4=BB=A5?=
=?UTF-8?q?=E9=98=B2=E6=AD=A2=E5=B9=B6=E5=8F=91=E8=BF=9E=E6=8E=A5=E9=97=AE?=
=?UTF-8?q?=E9=A2=98?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
pytests/test_plugin_runtime.py | 146 +++++++++++++++++++++++
src/plugin_runtime/host/rpc_server.py | 36 ++++--
src/plugin_runtime/runner/runner_main.py | 103 ++++++++++++----
3 files changed, 252 insertions(+), 33 deletions(-)
diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py
index f3e1e7ce..29227658 100644
--- a/pytests/test_plugin_runtime.py
+++ b/pytests/test_plugin_runtime.py
@@ -1298,9 +1298,12 @@ class TestDependencyResolution:
assert "on_unload" in loader.failed_plugins["test.demo-plugin"]
def test_isolate_sys_path_preserves_plugin_dirs(self):
+ import builtins
+
from src.plugin_runtime.runner import runner_main
plugin_root = os.path.normpath("/tmp/maibot-plugin-root")
+ original_import = builtins.__import__
original_path = list(sys.path)
original_meta_path = list(sys.meta_path)
@@ -1312,14 +1315,17 @@ class TestDependencyResolution:
assert plugin_root in sys.path
finally:
+ builtins.__import__ = original_import
sys.path[:] = original_path
sys.meta_path[:] = original_meta_path
def test_isolate_sys_path_blocks_disallowed_src_imports(self):
+ import builtins
import importlib
from src.plugin_runtime.runner import runner_main
+ original_import = builtins.__import__
original_path = list(sys.path)
original_meta_path = list(sys.meta_path)
sys.modules.pop("src.forbidden_demo", None)
@@ -1330,10 +1336,89 @@ class TestDependencyResolution:
with pytest.raises(ImportError, match="不允许导入主程序模块"):
importlib.import_module("src.forbidden_demo")
finally:
+ builtins.__import__ = original_import
sys.path[:] = original_path
sys.meta_path[:] = original_meta_path
sys.modules.pop("src.forbidden_demo", None)
+ def test_isolate_sys_path_blocks_preloaded_runtime_modules(self):
+ import builtins
+ import importlib
+
+ from src.plugin_runtime.runner import runner_main
+
+ original_import = builtins.__import__
+ original_path = list(sys.path)
+ original_meta_path = list(sys.meta_path)
+
+ try:
+ runner_main._isolate_sys_path([])
+
+ with pytest.raises(ImportError, match="rpc_client"):
+ importlib.import_module("src.plugin_runtime.runner.rpc_client")
+ finally:
+ builtins.__import__ = original_import
+ sys.path[:] = original_path
+ sys.meta_path[:] = original_meta_path
+
+ def test_isolate_sys_path_keeps_legacy_logger_import_available(self):
+ import builtins
+ import importlib
+
+ from src.plugin_runtime.runner import runner_main
+
+ original_import = builtins.__import__
+ original_path = list(sys.path)
+ original_meta_path = list(sys.meta_path)
+
+ try:
+ runner_main._isolate_sys_path([])
+
+ logger_module = importlib.import_module("src.common.logger")
+ assert callable(logger_module.get_logger)
+ finally:
+ builtins.__import__ = original_import
+ sys.path[:] = original_path
+ sys.meta_path[:] = original_meta_path
+
+ @pytest.mark.asyncio
+ async def test_async_main_removes_sensitive_runtime_env_vars(self, monkeypatch):
+ from src.plugin_runtime.runner import runner_main
+
+ captured = {}
+
+ class FakeRunner:
+ def __init__(
+ self,
+ host_address: str,
+ session_token: str,
+ plugin_dirs: list[str],
+ external_available_plugins: dict[str, str] | None = None,
+ ) -> None:
+ captured["host_address"] = host_address
+ captured["session_token"] = session_token
+ captured["plugin_dirs"] = plugin_dirs
+ captured["external_available_plugins"] = external_available_plugins or {}
+
+ async def run(self) -> None:
+ assert os.environ.get(runner_main.ENV_IPC_ADDRESS) is None
+ assert os.environ.get(runner_main.ENV_SESSION_TOKEN) is None
+
+ monkeypatch.setenv(runner_main.ENV_IPC_ADDRESS, "tcp://127.0.0.1:9999")
+ monkeypatch.setenv(runner_main.ENV_SESSION_TOKEN, "secret-token")
+ monkeypatch.setenv(runner_main.ENV_PLUGIN_DIRS, "/tmp/plugins")
+ monkeypatch.setenv(runner_main.ENV_EXTERNAL_PLUGIN_IDS, '{"demo.plugin":"1.0.0"}')
+ monkeypatch.setattr(runner_main, "_install_shutdown_signal_handlers", lambda callback: None)
+ monkeypatch.setattr(runner_main, "_isolate_sys_path", lambda plugin_dirs: None)
+ monkeypatch.setattr(runner_main, "PluginRunner", FakeRunner)
+
+ await runner_main._async_main()
+
+ assert captured["host_address"] == "tcp://127.0.0.1:9999"
+ assert captured["session_token"] == "secret-token"
+ assert captured["plugin_dirs"] == ["/tmp/plugins"]
+ assert captured["external_available_plugins"] == {"demo.plugin": "1.0.0"}
+
# ─── Host-side ComponentRegistry 测试 ──────────────────────
@@ -2093,6 +2178,67 @@ class TestWorkflowExecutor:
class TestRPCServer:
"""RPC Server 代际保护测试"""
+ @pytest.mark.asyncio
+ async def test_reject_second_active_runner_connection(self):
+ from src.plugin_runtime.host.rpc_server import RPCServer
+ from src.plugin_runtime.protocol.codec import MsgPackCodec
+ from src.plugin_runtime.protocol.envelope import Envelope, HelloPayload, HelloResponsePayload, MessageType
+
+ class DummyTransport:
+ async def start(self, handler):
+ return None
+
+ async def stop(self):
+ return None
+
+ def get_address(self):
+ return "dummy"
+
+ class FakeConnection:
+ def __init__(self, incoming_frames: list[bytes]):
+ self._incoming_frames = list(incoming_frames)
+ self.sent_frames: list[bytes] = []
+ self.is_closed = False
+
+ async def recv_frame(self):
+ return self._incoming_frames.pop(0)
+
+ async def send_frame(self, data):
+ self.sent_frames.append(data)
+
+ async def close(self):
+ self.is_closed = True
+
+ codec = MsgPackCodec()
+ server = RPCServer(transport=DummyTransport(), session_token="session-token")
+ active_conn = SimpleNamespace(is_closed=False)
+ server._connection = active_conn
+
+ hello = HelloPayload(
+ runner_id="runner-b",
+ sdk_version="1.0.0",
+ session_token="session-token",
+ )
+ envelope = Envelope(
+ request_id=1,
+ message_type=MessageType.REQUEST,
+ method="runner.hello",
+ payload=hello.model_dump(),
+ )
+ incoming_conn = FakeConnection([codec.encode_envelope(envelope)])
+
+ await server._handle_connection(incoming_conn)
+
+ assert incoming_conn.is_closed is True
+ assert server._connection is active_conn
+ assert server.last_handshake_rejection_reason == "已有活跃 Runner 连接,拒绝新的握手"
+ assert len(incoming_conn.sent_frames) == 1
+
+ response = codec.decode_envelope(incoming_conn.sent_frames[0])
+ response_payload = HelloResponsePayload.model_validate(response.payload)
+ assert response_payload.accepted is False
+ assert response_payload.reason == "已有活跃 Runner 连接,拒绝新的握手"
+
def test_ignore_stale_generation_response(self):
from src.plugin_runtime.host.rpc_server import RPCServer
from src.plugin_runtime.protocol.envelope import Envelope, MessageType
diff --git a/src/plugin_runtime/host/rpc_server.py b/src/plugin_runtime/host/rpc_server.py
index 2c422775..eb6768c2 100644
--- a/src/plugin_runtime/host/rpc_server.py
+++ b/src/plugin_runtime/host/rpc_server.py
@@ -70,6 +70,7 @@ class RPCServer:
self._running: bool = False
self._tasks: List[asyncio.Task[None]] = []
self._last_handshake_rejection_reason: str = ""
+ self._connection_lock: asyncio.Lock = asyncio.Lock()
@property
def session_token(self) -> str:
@@ -216,27 +217,33 @@ class RPCServer:
async def _handle_connection(self, conn: Connection) -> None:
"""处理新的 Runner 连接"""
logger.info("收到 Runner 连接")
- self.clear_handshake_state()
- # 第一条消息必须是 runner.hello 握手
try:
- success = await self._handle_handshake(conn)
- if not success:
- await conn.close()
- return
+ async with self._connection_lock:
+ self.clear_handshake_state()
+ success = await self._handle_handshake(conn)
+ if not success:
+ await conn.close()
+ return
+ logger.info("Runner staged 握手成功")
+ self._connection = conn
except Exception as e:
logger.error(f"握手失败: {e}")
await conn.close()
return
- logger.info("Runner staged 握手成功")
- self._connection = conn
+
# 启动消息接收循环
try:
await self._recv_loop(conn)
except Exception as e:
logger.error(f"连接异常断开: {e}")
finally:
- self._connection = None
- self._fail_pending_requests(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已断开")
+ should_fail_pending_requests = False
+ async with self._connection_lock:
+ if self._connection is conn:
+ self._connection = None
+ should_fail_pending_requests = True
+ if should_fail_pending_requests:
+ self._fail_pending_requests(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已断开")
async def _handle_handshake(self, conn: Connection) -> bool:
"""处理 runner.hello 握手"""
@@ -264,6 +271,15 @@ class RPCServer:
await conn.send_frame(self._codec.encode_envelope(resp))
return False
+ # 若已有活跃连接,直接拒绝新的握手,避免后来的连接抢占当前通道。
+ if self.is_connected:
+ logger.warning("拒绝新的 Runner 连接:已有活跃连接")
+ self._last_handshake_rejection_reason = "已有活跃 Runner 连接,拒绝新的握手"
+ resp_payload = HelloResponsePayload(accepted=False, reason=self._last_handshake_rejection_reason)
+ resp = envelope.make_response(payload=resp_payload.model_dump())
+ await conn.send_frame(self._codec.encode_envelope(resp))
+ return False
+
# 校验 SDK 版本
if not self._check_sdk_version(hello.sdk_version):
logger.error(f"SDK 版本不兼容: {hello.sdk_version}")
diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py
index e66d2fab..e4e47c68 100644
--- a/src/plugin_runtime/runner/runner_main.py
+++ b/src/plugin_runtime/runner/runner_main.py
@@ -1237,11 +1237,14 @@ class PluginRunner:
def _isolate_sys_path(plugin_dirs: List[str]) -> None:
"""清理 sys.path,限制 Runner 子进程只能访问标准库、SDK 和插件目录。
- 防止插件代码 import 主程序模块读取运行时数据。
+ 同时移除插件可直接访问的主程序内部模块缓存,避免通过 ``sys.modules``
+ 或常规导入绕过 SDK / capability 边界。
"""
+ import builtins
import importlib.abc
from importlib.machinery import ModuleSpec
import sysconfig
+ from types import ModuleType
# 保留: 标准库路径 + site-packages(含 SDK 和依赖)
stdlib_paths = set()
@@ -1271,18 +1274,68 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
for d in plugin_dir_paths:
allowed.add(d)
- # 添加项目根目录(使得 src.plugin_runtime / src.common 可导入)
- runtime_root = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
- allowed.add(runtime_root)
-
preserved_paths = [p for p in sys.path if p in allowed]
- for extra_path in [*plugin_dir_paths, runtime_root]:
+ for extra_path in plugin_dir_paths:
if extra_path not in preserved_paths:
preserved_paths.append(extra_path)
sys.path[:] = preserved_paths
- # 安装 import 钩子,阻止插件导入主程序核心模块
- # 仅允许 src.plugin_runtime 和 src.common,拒绝其他 src.* 子包
+ # 仅为旧版插件兼容层保留极小的 src.* 可见面:
+ # - src.plugin_system.*: 通过 maibot_sdk.compat 导入钩子重定向
+ # - src.common.logger: 仓库内仍有少量旧插件沿用该日志入口
+ allowed_src_exact_modules = frozenset(
+ {
+ "src",
+ "src.common",
+ "src.common.logger",
+ "src.common.logger_color_and_mapping",
+ }
+ )
+ allowed_src_prefixes = ("src.plugin_system",)
+
+ def _is_allowed_src_module(fullname: str) -> bool:
+ """判断给定 src.* 模块是否在 Runner 允许列表中。"""
+ if fullname in allowed_src_exact_modules:
+ return True
+ return any(fullname == prefix or fullname.startswith(f"{prefix}.") for prefix in allowed_src_prefixes)
+
+ def _format_block_message(fullname: str) -> str:
+ """构造统一的拒绝导入错误信息。"""
+ return (
+ f"Runner 子进程不允许导入主程序模块: {fullname}。"
+ "请改用 maibot_sdk 或 src.plugin_system 兼容层提供的接口。"
+ )
+
+ def _detach_module_from_parent(fullname: str, module: ModuleType) -> None:
+ """从父模块上移除已清理模块的属性引用。"""
+ parent_name, _, child_name = fullname.rpartition(".")
+ if not parent_name or not child_name:
+ return
+
+ parent_module = sys.modules.get(parent_name)
+ if parent_module is None:
+ return
+ if getattr(parent_module, child_name, None) is module:
+ with contextlib.suppress(AttributeError):
+ delattr(parent_module, child_name)
+
+ # 清理主程序内部模块缓存,避免插件经由 sys.modules 直接拿到高权限对象。
+ existing_src_modules = sorted(
+ (
+ (module_name, module)
+ for module_name, module in list(sys.modules.items())
+ if module_name == "src" or module_name.startswith("src.")
+ ),
+ key=lambda item: item[0].count("."),
+ reverse=True,
+ )
+ for module_name, module in existing_src_modules:
+ if _is_allowed_src_module(module_name):
+ continue
+ _detach_module_from_parent(module_name, module)
+ sys.modules.pop(module_name, None)
+
+ # 安装 import 钩子,阻止再次导入被清理掉的主程序内部模块。
class _BlockedSrcModuleLoader(importlib.abc.Loader):
"""阻止被 Runner 允许列表之外的主程序模块完成导入。"""
@@ -1295,16 +1348,11 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
def exec_module(self, module: Any) -> None:
del module
- raise ImportError(f"Runner 子进程不允许导入主程序模块: {self._fullname}")
+ raise ImportError(_format_block_message(self._fullname))
class _PluginImportBlocker(importlib.abc.MetaPathFinder):
- """阻止 Runner 子进程导入主程序核心模块。
+ """阻止 Runner 子进程重新导入主程序内部 src.* 模块。"""
- 只放行 src.plugin_runtime 和 src.common,
- 拒绝 src.chat_module / src.services 等主程序内部包。
- """
-
- _ALLOWED_SRC_PREFIXES = ("src.plugin_runtime", "src.common")
__maibot_runner_plugin_import_blocker__ = True
def find_spec(self, fullname: str, path: Any = None, target: Any = None) -> ModuleSpec | None:
@@ -1317,13 +1365,9 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
def _should_block(self, fullname: str) -> bool:
"""判断给定模块名是否应被阻止导入。"""
- # 放行非 src.* 的导入、以及 "src" 本身
- if not fullname.startswith("src.") or fullname == "src":
+ if not fullname.startswith("src"):
return False
- # 放行白名单前缀
- return not any(
- fullname == prefix or fullname.startswith(f"{prefix}.") for prefix in self._ALLOWED_SRC_PREFIXES
- )
+ return not _is_allowed_src_module(fullname)
sys.meta_path[:] = [
finder
@@ -1332,15 +1376,28 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
]
sys.meta_path.insert(0, _PluginImportBlocker())
+ # ``import`` 语句在模块已存在于 sys.modules 时不会再经过 finder,
+ # 因此还需要在入口处补一层兜底。
+ original_import = getattr(builtins, "__maibot_runner_original_import__", builtins.__import__)
+ builtins.__maibot_runner_original_import__ = original_import
+
+ def _guarded_import(name: str, globals: Any = None, locals: Any = None, fromlist: Any = (), level: int = 0) -> Any:
+ if level == 0 and name.startswith("src") and not _is_allowed_src_module(name):
+ raise ImportError(_format_block_message(name))
+ return original_import(name, globals, locals, fromlist, level)
+
+ _guarded_import.__maibot_runner_plugin_import_guard__ = True
+ builtins.__import__ = _guarded_import
+
# ─── 进程入口 ──────────────────────────────────────────────
async def _async_main() -> None:
"""异步主入口"""
- host_address = os.environ.get(ENV_IPC_ADDRESS, "")
+ host_address = os.environ.pop(ENV_IPC_ADDRESS, "")
external_plugin_ids_raw = os.environ.get(ENV_EXTERNAL_PLUGIN_IDS, "")
- session_token = os.environ.get(ENV_SESSION_TOKEN, "")
+ session_token = os.environ.pop(ENV_SESSION_TOKEN, "")
plugin_dirs_str = os.environ.get(ENV_PLUGIN_DIRS, "")
if not host_address or not session_token:
From d5581a1a970c8c0ae23194ba2f1bbcc2e2b283b9 Mon Sep 17 00:00:00 2001
From: DrSmoothl <1787882683@qq.com>
Date: Tue, 24 Mar 2026 11:49:40 +0800
Subject: [PATCH 43/45] =?UTF-8?q?feat:=20=E5=A2=9E=E5=BC=BA=E6=8F=92?=
=?UTF-8?q?=E4=BB=B6=E5=AF=BC=E5=85=A5=E7=AE=A1=E7=90=86=EF=BC=8C=E6=B7=BB?=
=?UTF-8?q?=E5=8A=A0=E5=AF=BC=E5=85=A5=E8=AF=B7=E6=B1=82=E9=AA=8C=E8=AF=81?=
=?UTF-8?q?=E5=92=8C=E6=A8=A1=E5=9D=97=E8=AE=BF=E9=97=AE=E6=8E=A7=E5=88=B6?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
pytests/test_plugin_runtime.py | 55 +++++++++-
src/plugin_runtime/runner/runner_main.py | 124 +++++++++++++----------
2 files changed, 124 insertions(+), 55 deletions(-)
diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py
index 29227658..e3247f05 100644
--- a/pytests/test_plugin_runtime.py
+++ b/pytests/test_plugin_runtime.py
@@ -1299,11 +1299,13 @@ class TestDependencyResolution:
def test_isolate_sys_path_preserves_plugin_dirs(self):
import builtins
+ import importlib
from src.plugin_runtime.runner import runner_main
plugin_root = os.path.normpath("/tmp/maibot-plugin-root")
original_import = builtins.__import__
+ original_import_module = importlib.import_module
original_path = list(sys.path)
original_meta_path = list(sys.meta_path)
@@ -1316,6 +1318,7 @@ class TestDependencyResolution:
assert plugin_root in sys.path
finally:
builtins.__import__ = original_import
+ importlib.import_module = original_import_module
sys.path[:] = original_path
sys.meta_path[:] = original_meta_path
@@ -1326,17 +1329,24 @@ class TestDependencyResolution:
from src.plugin_runtime.runner import runner_main
original_import = builtins.__import__
+ original_import_module = importlib.import_module
original_path = list(sys.path)
original_meta_path = list(sys.meta_path)
sys.modules.pop("src.forbidden_demo", None)
try:
runner_main._isolate_sys_path([])
+ plugin_globals = {
+ "__name__": "_maibot_plugin_demo",
+ "__package__": "_maibot_plugin_demo",
+ "importlib": importlib,
+ }
with pytest.raises(ImportError, match="不允许导入主程序模块"):
- importlib.import_module("src.forbidden_demo")
+ exec('importlib.import_module("src.forbidden_demo")', plugin_globals)
finally:
builtins.__import__ = original_import
+ importlib.import_module = original_import_module
sys.path[:] = original_path
sys.meta_path[:] = original_meta_path
sys.modules.pop("src.forbidden_demo", None)
@@ -1348,16 +1358,23 @@ class TestDependencyResolution:
from src.plugin_runtime.runner import runner_main
original_import = builtins.__import__
+ original_import_module = importlib.import_module
original_path = list(sys.path)
original_meta_path = list(sys.meta_path)
try:
runner_main._isolate_sys_path([])
+ plugin_globals = {
+ "__name__": "_maibot_plugin_demo",
+ "__package__": "_maibot_plugin_demo",
+ "importlib": importlib,
+ }
with pytest.raises(ImportError, match="rpc_client"):
- importlib.import_module("src.plugin_runtime.runner.rpc_client")
+ exec('importlib.import_module("src.plugin_runtime.runner.rpc_client")', plugin_globals)
finally:
builtins.__import__ = original_import
+ importlib.import_module = original_import_module
sys.path[:] = original_path
sys.meta_path[:] = original_meta_path
@@ -1368,16 +1385,46 @@ class TestDependencyResolution:
from src.plugin_runtime.runner import runner_main
original_import = builtins.__import__
+ original_import_module = importlib.import_module
+ original_path = list(sys.path)
+ original_meta_path = list(sys.meta_path)
+
+ try:
+ runner_main._isolate_sys_path([])
+ plugin_globals = {
+ "__name__": "_maibot_plugin_demo",
+ "__package__": "_maibot_plugin_demo",
+ "importlib": importlib,
+ }
+
+ exec('logger_module = importlib.import_module("src.common.logger")', plugin_globals)
+ logger_module = plugin_globals["logger_module"]
+ assert callable(logger_module.get_logger)
+ finally:
+ builtins.__import__ = original_import
+ importlib.import_module = original_import_module
+ sys.path[:] = original_path
+ sys.meta_path[:] = original_meta_path
+
+ def test_isolate_sys_path_keeps_runtime_imports_working(self):
+ import builtins
+ import importlib
+
+ from src.plugin_runtime.runner import runner_main
+
+ original_import = builtins.__import__
+ original_import_module = importlib.import_module
original_path = list(sys.path)
original_meta_path = list(sys.meta_path)
try:
runner_main._isolate_sys_path([])
- logger_module = importlib.import_module("src.common.logger")
- assert callable(logger_module.get_logger)
+ uds_module = importlib.import_module("src.plugin_runtime.transport.uds")
+ assert hasattr(uds_module, "UDSTransportClient")
finally:
builtins.__import__ = original_import
+ importlib.import_module = original_import_module
sys.path[:] = original_path
sys.meta_path[:] = original_meta_path
diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py
index e4e47c68..4dac4e05 100644
--- a/src/plugin_runtime/runner/runner_main.py
+++ b/src/plugin_runtime/runner/runner_main.py
@@ -1237,12 +1237,11 @@ class PluginRunner:
def _isolate_sys_path(plugin_dirs: List[str]) -> None:
"""清理 sys.path,限制 Runner 子进程只能访问标准库、SDK 和插件目录。
- 同时移除插件可直接访问的主程序内部模块缓存,避免通过 ``sys.modules``
- 或常规导入绕过 SDK / capability 边界。
+ 同时阻止插件代码直接导入主程序内部 ``src.*`` 模块,并清理可直接从
+ ``sys.modules`` 摸到的高权限叶子模块,避免绕过 SDK / capability 边界。
"""
import builtins
- import importlib.abc
- from importlib.machinery import ModuleSpec
+ import importlib
import sysconfig
from types import ModuleType
@@ -1292,6 +1291,7 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
}
)
allowed_src_prefixes = ("src.plugin_system",)
+ plugin_module_prefix = "_maibot_plugin_"
def _is_allowed_src_module(fullname: str) -> bool:
"""判断给定 src.* 模块是否在 Runner 允许列表中。"""
@@ -1299,6 +1299,35 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
return True
return any(fullname == prefix or fullname.startswith(f"{prefix}.") for prefix in allowed_src_prefixes)
+ def _resolve_requester_name(import_globals: Any = None) -> str:
+ """解析当前导入请求的发起模块名。"""
+ if isinstance(import_globals, dict):
+ for key in ("__name__", "__package__"):
+ value = import_globals.get(key)
+ if isinstance(value, str) and value:
+ return value
+
+ frame = inspect.currentframe()
+ try:
+ current = frame.f_back if frame is not None else None
+ while current is not None:
+ module_name = current.f_globals.get("__name__", "")
+ if not isinstance(module_name, str) or not module_name:
+ current = current.f_back
+ continue
+ if module_name == __name__ or module_name.startswith("importlib"):
+ current = current.f_back
+ continue
+ return module_name
+ return ""
+ finally:
+ del frame
+
+ def _is_plugin_import_request(import_globals: Any = None) -> bool:
+ """判断当前导入是否由插件模块直接发起。"""
+ requester_name = _resolve_requester_name(import_globals)
+ return requester_name.startswith(plugin_module_prefix)
+
def _format_block_message(fullname: str) -> str:
"""构造统一的拒绝导入错误信息。"""
return (
@@ -1306,6 +1335,30 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
"请改用 maibot_sdk 或 src.plugin_system 兼容层提供的接口。"
)
+ def _iter_requested_src_modules(name: str, fromlist: Any) -> List[str]:
+ """展开本次导入请求涉及的 src.* 模块名。"""
+ requested_modules = [name]
+ if not name.startswith("src") or not fromlist:
+ return requested_modules
+
+ for item in fromlist:
+ if not isinstance(item, str) or not item or item == "*":
+ continue
+ requested_modules.append(f"{name}.{item}")
+ return requested_modules
+
+ def _assert_plugin_import_allowed(name: str, import_globals: Any = None, fromlist: Any = ()) -> None:
+ """在插件发起导入时校验目标 src.* 模块是否允许访问。"""
+ if not _is_plugin_import_request(import_globals):
+ return
+
+ for requested_module in _iter_requested_src_modules(name, fromlist):
+ if not requested_module.startswith("src"):
+ continue
+ if _is_allowed_src_module(requested_module):
+ continue
+ raise ImportError(_format_block_message(requested_module))
+
def _detach_module_from_parent(fullname: str, module: ModuleType) -> None:
"""从父模块上移除已清理模块的属性引用。"""
parent_name, _, child_name = fullname.rpartition(".")
@@ -1319,7 +1372,7 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
with contextlib.suppress(AttributeError):
delattr(parent_module, child_name)
- # 清理主程序内部模块缓存,避免插件经由 sys.modules 直接拿到高权限对象。
+ # 仅清理已加载的叶子模块,保留包对象给 Runner 自己的延迟导入和相对导入使用。
existing_src_modules = sorted(
(
(module_name, module)
@@ -1330,65 +1383,34 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
reverse=True,
)
for module_name, module in existing_src_modules:
- if _is_allowed_src_module(module_name):
+ if _is_allowed_src_module(module_name) or hasattr(module, "__path__"):
continue
_detach_module_from_parent(module_name, module)
sys.modules.pop(module_name, None)
- # 安装 import 钩子,阻止再次导入被清理掉的主程序内部模块。
- class _BlockedSrcModuleLoader(importlib.abc.Loader):
- """阻止被 Runner 允许列表之外的主程序模块完成导入。"""
-
- def __init__(self, fullname: str) -> None:
- self._fullname = fullname
-
- def create_module(self, spec: ModuleSpec) -> None:
- del spec
- return None
-
- def exec_module(self, module: Any) -> None:
- del module
- raise ImportError(_format_block_message(self._fullname))
-
- class _PluginImportBlocker(importlib.abc.MetaPathFinder):
- """阻止 Runner 子进程重新导入主程序内部 src.* 模块。"""
-
- __maibot_runner_plugin_import_blocker__ = True
-
- def find_spec(self, fullname: str, path: Any = None, target: Any = None) -> ModuleSpec | None:
- """决定是否拦截指定模块导入。"""
- del path, target
- if not self._should_block(fullname):
- return None
- # Python 3.13+/3.14 会优先走 find_spec,不再依赖 find_module。
- return ModuleSpec(fullname, _BlockedSrcModuleLoader(fullname), is_package=True)
-
- def _should_block(self, fullname: str) -> bool:
- """判断给定模块名是否应被阻止导入。"""
- if not fullname.startswith("src"):
- return False
- return not _is_allowed_src_module(fullname)
-
- sys.meta_path[:] = [
- finder
- for finder in sys.meta_path
- if not getattr(finder, "__maibot_runner_plugin_import_blocker__", False)
- ]
- sys.meta_path.insert(0, _PluginImportBlocker())
-
- # ``import`` 语句在模块已存在于 sys.modules 时不会再经过 finder,
- # 因此还需要在入口处补一层兜底。
+ # ``import`` 语句与 ``importlib.import_module`` 走的是不同入口,因此两边都需要兜底。
original_import = getattr(builtins, "__maibot_runner_original_import__", builtins.__import__)
builtins.__maibot_runner_original_import__ = original_import
def _guarded_import(name: str, globals: Any = None, locals: Any = None, fromlist: Any = (), level: int = 0) -> Any:
- if level == 0 and name.startswith("src") and not _is_allowed_src_module(name):
- raise ImportError(_format_block_message(name))
+ if level == 0:
+ _assert_plugin_import_allowed(name, import_globals=globals, fromlist=fromlist)
return original_import(name, globals, locals, fromlist, level)
_guarded_import.__maibot_runner_plugin_import_guard__ = True
builtins.__import__ = _guarded_import
+ original_import_module = getattr(importlib, "__maibot_runner_original_import_module__", importlib.import_module)
+ importlib.__maibot_runner_original_import_module__ = original_import_module
+
+ def _guarded_import_module(name: str, package: Optional[str] = None) -> Any:
+ resolved_name = importlib.util.resolve_name(name, package) if name.startswith(".") else name
+ _assert_plugin_import_allowed(resolved_name)
+ return original_import_module(name, package)
+
+ _guarded_import_module.__maibot_runner_plugin_import_guard__ = True
+ importlib.import_module = _guarded_import_module
+
# ─── 进程入口 ──────────────────────────────────────────────
From b8224bdb3c6d8b2846cc7b594345d4ed09c0ff12 Mon Sep 17 00:00:00 2001
From: DrSmoothl <1787882683@qq.com>
Date: Tue, 24 Mar 2026 11:51:15 +0800
Subject: [PATCH 44/45] =?UTF-8?q?refactor:=20=E4=BC=98=E5=8C=96=E6=8F=92?=
=?UTF-8?q?=E4=BB=B6=E5=A4=84=E7=90=86=E9=80=BB=E8=BE=91=EF=BC=8C=E6=94=B9?=
=?UTF-8?q?=E8=BF=9B=E5=AF=BC=E5=85=A5=E4=BF=9D=E6=8A=A4=E6=9C=BA=E5=88=B6?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
src/plugin_runtime/runner/runner_main.py | 30 +++++++++++++-----------
1 file changed, 16 insertions(+), 14 deletions(-)
diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py
index 4dac4e05..d1ebc064 100644
--- a/src/plugin_runtime/runner/runner_main.py
+++ b/src/plugin_runtime/runner/runner_main.py
@@ -429,9 +429,9 @@ class PluginRunner:
component_name = str(comp_info.get("name", "") or "").strip()
raw_metadata = comp_info.get("metadata", {})
component_metadata = raw_metadata if isinstance(raw_metadata, dict) else {}
- handler_name = str(component_metadata.get("handler_name", component_name) or component_name).strip()
if component_name:
+ handler_name = str(component_metadata.get("handler_name", component_name) or component_name).strip()
meta.component_handlers[component_name] = handler_name or component_name
components.append(
@@ -765,8 +765,7 @@ class PluginRunner:
target_plugin_ids: Set[str] = {
plugin_id for plugin_id in reload_root_ids if plugin_id not in loaded_plugin_ids
}
- loaded_root_plugin_ids = reload_root_ids & loaded_plugin_ids
- if loaded_root_plugin_ids:
+ if loaded_root_plugin_ids := reload_root_ids & loaded_plugin_ids:
target_plugin_ids.update(self._collect_reverse_dependents_for_roots(loaded_root_plugin_ids))
unload_order = self._build_unload_order(target_plugin_ids & loaded_plugin_ids)
@@ -823,11 +822,10 @@ class PluginRunner:
continue
_, manifest, _ = candidate
- unsatisfied_dependencies = self._loader.manifest_validator.get_unsatisfied_plugin_dependencies(
+ if unsatisfied_dependencies := self._loader.manifest_validator.get_unsatisfied_plugin_dependencies(
manifest,
available_plugin_versions=available_plugins,
- )
- if unsatisfied_dependencies:
+ ):
failed_plugins[load_plugin_id] = f"依赖未满足: {', '.join(unsatisfied_dependencies)}"
continue
@@ -1240,10 +1238,12 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
同时阻止插件代码直接导入主程序内部 ``src.*`` 模块,并清理可直接从
``sys.modules`` 摸到的高权限叶子模块,避免绕过 SDK / capability 边界。
"""
+ from importlib import util as importlib_util
+ from types import ModuleType
+
import builtins
import importlib
import sysconfig
- from types import ModuleType
# 保留: 标准库路径 + site-packages(含 SDK 和依赖)
stdlib_paths = set()
@@ -1389,26 +1389,28 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
sys.modules.pop(module_name, None)
# ``import`` 语句与 ``importlib.import_module`` 走的是不同入口,因此两边都需要兜底。
- original_import = getattr(builtins, "__maibot_runner_original_import__", builtins.__import__)
- builtins.__maibot_runner_original_import__ = original_import
+ builtins_module = cast(Any, builtins)
+ original_import = getattr(builtins_module, "__maibot_runner_original_import__", builtins.__import__)
+ builtins_module.__maibot_runner_original_import__ = original_import
def _guarded_import(name: str, globals: Any = None, locals: Any = None, fromlist: Any = (), level: int = 0) -> Any:
if level == 0:
_assert_plugin_import_allowed(name, import_globals=globals, fromlist=fromlist)
return original_import(name, globals, locals, fromlist, level)
- _guarded_import.__maibot_runner_plugin_import_guard__ = True
+ cast(Any, _guarded_import).__maibot_runner_plugin_import_guard__ = True
builtins.__import__ = _guarded_import
- original_import_module = getattr(importlib, "__maibot_runner_original_import_module__", importlib.import_module)
- importlib.__maibot_runner_original_import_module__ = original_import_module
+ importlib_module = cast(Any, importlib)
+ original_import_module = getattr(importlib_module, "__maibot_runner_original_import_module__", importlib.import_module)
+ importlib_module.__maibot_runner_original_import_module__ = original_import_module
def _guarded_import_module(name: str, package: Optional[str] = None) -> Any:
- resolved_name = importlib.util.resolve_name(name, package) if name.startswith(".") else name
+ resolved_name = importlib_util.resolve_name(name, package) if name.startswith(".") else name
_assert_plugin_import_allowed(resolved_name)
return original_import_module(name, package)
- _guarded_import_module.__maibot_runner_plugin_import_guard__ = True
+ cast(Any, _guarded_import_module).__maibot_runner_plugin_import_guard__ = True
importlib.import_module = _guarded_import_module
From 2c279f703ca386306a4f066cd8a6b33df5510adb Mon Sep 17 00:00:00 2001
From: DrSmoothl <1787882683@qq.com>
Date: Tue, 24 Mar 2026 15:32:09 +0800
Subject: [PATCH 45/45] =?UTF-8?q?Revert=20"feat=EF=BC=9A=E5=B0=9D=E8=AF=95?=
=?UTF-8?q?=E5=BB=BA=E7=AB=8Bhfc=E9=80=BB=E8=BE=91"?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
This reverts commit bfc9781c4f10212ecfffc3129115ecc6fbd3fa89.
---
src/chat/heart_flow/heartFC_chat - 副本.py | 734 ------------------
src/chat/heart_flow/heartFC_chat.py | 823 ++++-----------------
src/chat/heart_flow/heartflow.py | 42 --
3 files changed, 160 insertions(+), 1439 deletions(-)
delete mode 100644 src/chat/heart_flow/heartFC_chat - 副本.py
delete mode 100644 src/chat/heart_flow/heartflow.py
diff --git a/src/chat/heart_flow/heartFC_chat - 副本.py b/src/chat/heart_flow/heartFC_chat - 副本.py
deleted file mode 100644
index 02f70281..00000000
--- a/src/chat/heart_flow/heartFC_chat - 副本.py
+++ /dev/null
@@ -1,734 +0,0 @@
-import asyncio
-import time
-import traceback
-import random
-from typing import List, Optional, Dict, Any, Tuple, TYPE_CHECKING
-from rich.traceback import install
-
-from src.config.config import global_config
-from src.common.logger import get_logger
-from src.common.data_models.info_data_model import ActionPlannerInfo
-from src.common.data_models.message_data_model import ReplyContentType
-from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
-from src.chat.utils.prompt_builder import global_prompt_manager
-from src.chat.utils.timer_calculator import Timer
-from src.chat.planner_actions.planner import ActionPlanner
-from src.chat.planner_actions.action_modifier import ActionModifier
-from src.chat.planner_actions.action_manager import ActionManager
-from src.chat.heart_flow.hfc_utils import CycleDetail
-from src.learners.expression_learner import expression_learner_manager
-from src.chat.heart_flow.frequency_control import frequency_control_manager
-from src.learners.message_recorder import extract_and_distribute_messages
-from src.person_info.person_info import Person
-from src.plugin_system.base.component_types import EventType, ActionInfo
-from src.plugin_system.core import events_manager
-from src.plugin_system.apis import generator_api, send_api, message_api, database_api
-from src.chat.utils.chat_message_builder import (
- build_readable_messages_with_id,
- get_raw_msg_before_timestamp_with_chat,
-)
-from src.chat.utils.utils import record_replyer_action_temp
-from src.memory_system.chat_history_summarizer import ChatHistorySummarizer
-
-if TYPE_CHECKING:
- from src.common.data_models.database_data_model import DatabaseMessages
- from src.common.data_models.message_data_model import ReplySetModel
-
-
-ERROR_LOOP_INFO = {
- "loop_plan_info": {
- "action_result": {
- "action_type": "error",
- "action_data": {},
- "reasoning": "循环处理失败",
- },
- },
- "loop_action_info": {
- "action_taken": False,
- "reply_text": "",
- "command": "",
- "taken_time": time.time(),
- },
-}
-
-
-install(extra_lines=3)
-
-# 注释:原来的动作修改超时常量已移除,因为改为顺序执行
-
-logger = get_logger("hfc") # Logger Name Changed
-
-
-class HeartFChatting:
- """
- 管理一个连续的Focus Chat循环
- 用于在特定聊天流中生成回复。
- 其生命周期现在由其关联的 SubHeartflow 的 FOCUSED 状态控制。
- """
-
- def __init__(self, chat_id: str):
- """
- HeartFChatting 初始化函数
-
- 参数:
- chat_id: 聊天流唯一标识符(如stream_id)
- on_stop_focus_chat: 当收到stop_focus_chat命令时调用的回调函数
- performance_version: 性能记录版本号,用于区分不同启动版本
- """
- # 基础属性
- self.stream_id: str = chat_id # 聊天流ID
- self.chat_stream: ChatStream = get_chat_manager().get_stream(self.stream_id) # type: ignore
- if not self.chat_stream:
- raise ValueError(f"无法找到聊天流: {self.stream_id}")
- self.log_prefix = f"[{get_chat_manager().get_stream_name(self.stream_id) or self.stream_id}]"
-
- self.expression_learner = expression_learner_manager.get_expression_learner(self.stream_id)
-
- self.action_manager = ActionManager()
- self.action_planner = ActionPlanner(chat_id=self.stream_id, action_manager=self.action_manager)
- self.action_modifier = ActionModifier(action_manager=self.action_manager, chat_id=self.stream_id)
-
- # 循环控制内部状态
- self.running: bool = False
- self._loop_task: Optional[asyncio.Task] = None # 主循环任务
-
- # 添加循环信息管理相关的属性
- self.history_loop: List[CycleDetail] = []
- self._cycle_counter = 0
- self._current_cycle_detail: CycleDetail = None # type: ignore
-
- self.last_read_time = time.time() - 2
-
- self.is_mute = False
-
- self.last_active_time = time.time() # 记录上一次非noreply时间
-
- self.question_probability_multiplier = 1
- self.questioned = False
-
- # 跟踪连续 no_reply 次数,用于动态调整阈值
- self.consecutive_no_reply_count = 0
-
- # 聊天内容概括器
- self.chat_history_summarizer = ChatHistorySummarizer(chat_id=self.stream_id)
-
- async def start(self):
- """检查是否需要启动主循环,如果未激活则启动。"""
-
- # 如果循环已经激活,直接返回
- if self.running:
- logger.debug(f"{self.log_prefix} HeartFChatting 已激活,无需重复启动")
- return
-
- try:
- # 标记为活动状态,防止重复启动
- self.running = True
-
- self._loop_task = asyncio.create_task(self._main_chat_loop())
- self._loop_task.add_done_callback(self._handle_loop_completion)
-
- # 启动聊天内容概括器的后台定期检查循环
- await self.chat_history_summarizer.start()
-
- logger.info(f"{self.log_prefix} HeartFChatting 启动完成")
-
- except Exception as e:
- # 启动失败时重置状态
- self.running = False
- self._loop_task = None
- logger.error(f"{self.log_prefix} HeartFChatting 启动失败: {e}")
- raise
-
- def _handle_loop_completion(self, task: asyncio.Task):
- """当 _hfc_loop 任务完成时执行的回调。"""
- try:
- if exception := task.exception():
- logger.error(f"{self.log_prefix} HeartFChatting: 脱离了聊天(异常): {exception}")
- logger.error(traceback.format_exc()) # Log full traceback for exceptions
- else:
- logger.info(f"{self.log_prefix} HeartFChatting: 脱离了聊天 (外部停止)")
- except asyncio.CancelledError:
- logger.info(f"{self.log_prefix} HeartFChatting: 结束了聊天")
-
- def start_cycle(self) -> Tuple[Dict[str, float], str]:
- self._cycle_counter += 1
- self._current_cycle_detail = CycleDetail(self._cycle_counter)
- self._current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}"
- cycle_timers = {}
- return cycle_timers, self._current_cycle_detail.thinking_id
-
- def end_cycle(self, loop_info, cycle_timers):
- self._current_cycle_detail.set_loop_info(loop_info)
- self.history_loop.append(self._current_cycle_detail)
- self._current_cycle_detail.timers = cycle_timers
- self._current_cycle_detail.end_time = time.time()
-
- def print_cycle_info(self, cycle_timers):
- # 记录循环信息和计时器结果
- timer_strings = []
- for name, elapsed in cycle_timers.items():
- if elapsed < 0.1:
- # 不显示小于0.1秒的计时器
- continue
- formatted_time = f"{elapsed:.2f}秒"
- timer_strings.append(f"{name}: {formatted_time}")
-
- logger.info(
- f"{self.log_prefix} 第{self._current_cycle_detail.cycle_id}次思考,"
- f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒;" # type: ignore
- + (f"详情: {'; '.join(timer_strings)}" if timer_strings else "")
- )
-
- async def _loopbody(self):
- recent_messages_list = message_api.get_messages_by_time_in_chat(
- chat_id=self.stream_id,
- start_time=self.last_read_time,
- end_time=time.time(),
- limit=20,
- limit_mode="latest",
- filter_mai=True,
- filter_command=False,
- filter_intercept_message_level=0,
- )
-
- # 根据连续 no_reply 次数动态调整阈值
- # 3次 no_reply 时,阈值调高到 1.5(50%概率为1,50%概率为2)
- # 5次 no_reply 时,提高到 2(大于等于两条消息的阈值)
- if self.consecutive_no_reply_count >= 5:
- threshold = 2
- elif self.consecutive_no_reply_count >= 3:
- # 1.5 的含义:50%概率为1,50%概率为2
- threshold = 2 if random.random() < 0.5 else 1
- else:
- threshold = 1
-
- if len(recent_messages_list) >= threshold:
- # for message in recent_messages_list:
- # print(message.processed_plain_text)
-
- self.last_read_time = time.time()
-
- # !此处使at或者提及必定回复
- mentioned_message = None
- for message in recent_messages_list:
- if (message.is_mentioned or message.is_at) and global_config.chat.mentioned_bot_reply:
- mentioned_message = message
-
- # logger.info(f"{self.log_prefix} 当前talk_value: {global_config.chat.get_talk_value(self.stream_id)}")
-
- # *控制频率用
- if mentioned_message:
- await self._observe(recent_messages_list=recent_messages_list, force_reply_message=mentioned_message)
- elif (
- random.random()
- < global_config.chat.get_talk_value(self.stream_id)
- * frequency_control_manager.get_or_create_frequency_control(self.stream_id).get_talk_frequency_adjust()
- ):
- await self._observe(recent_messages_list=recent_messages_list)
- else:
- # 没有提到,继续保持沉默,等待5秒防止频繁触发
- await asyncio.sleep(10)
- return True
- else:
- await asyncio.sleep(0.2)
- return True
- return True
-
- async def _send_and_store_reply(
- self,
- response_set: "ReplySetModel",
- action_message: "DatabaseMessages",
- cycle_timers: Dict[str, float],
- thinking_id,
- actions,
- selected_expressions: Optional[List[int]] = None,
- quote_message: Optional[bool] = None,
- ) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
- with Timer("回复发送", cycle_timers):
- reply_text = await self._send_response(
- reply_set=response_set,
- message_data=action_message,
- selected_expressions=selected_expressions,
- quote_message=quote_message,
- )
-
- # 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
- platform = action_message.chat_info.platform
- if platform is None:
- platform = getattr(self.chat_stream, "platform", "unknown")
-
- person = Person(platform=platform, user_id=action_message.user_info.user_id)
- person_name = person.person_name
- action_prompt_display = f"你对{person_name}进行了回复:{reply_text}"
-
- await database_api.store_action_info(
- chat_stream=self.chat_stream,
- action_build_into_prompt=False,
- action_prompt_display=action_prompt_display,
- action_done=True,
- thinking_id=thinking_id,
- action_data={"reply_text": reply_text},
- action_name="reply",
- )
-
- # 构建循环信息
- loop_info: Dict[str, Any] = {
- "loop_plan_info": {
- "action_result": actions,
- },
- "loop_action_info": {
- "action_taken": True,
- "reply_text": reply_text,
- "command": "",
- "taken_time": time.time(),
- },
- }
-
- return loop_info, reply_text, cycle_timers
-
- async def _observe(
- self, # interest_value: float = 0.0,
- recent_messages_list: Optional[List["DatabaseMessages"]] = None,
- force_reply_message: Optional["DatabaseMessages"] = None,
- ) -> bool: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
- if recent_messages_list is None:
- recent_messages_list = []
- _reply_text = "" # 初始化reply_text变量,避免UnboundLocalError
-
- start_time = time.time()
- async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
- # 通过 MessageRecorder 统一提取消息并分发给 expression_learner 和 jargon_miner
- # 在 replyer 执行时触发,统一管理时间窗口,避免重复获取消息
- asyncio.create_task(extract_and_distribute_messages(self.stream_id))
-
- # 添加curious检测任务 - 检测聊天记录中的矛盾、冲突或需要提问的内容
- # asyncio.create_task(check_and_make_question(self.stream_id))
- # 添加聊天内容概括任务 - 累积、打包和压缩聊天记录
- # 注意:后台循环已在start()中启动,这里作为额外触发点,在有思考时立即处理
- # asyncio.create_task(self.chat_history_summarizer.process())
-
- cycle_timers, thinking_id = self.start_cycle()
- logger.info(
- f"{self.log_prefix} 开始第{self._cycle_counter}次思考(频率: {global_config.chat.get_talk_value(self.stream_id)})"
- )
-
- # 第一步:动作检查
- available_actions: Dict[str, ActionInfo] = {}
- try:
- await self.action_modifier.modify_actions()
- available_actions = self.action_manager.get_using_actions()
- except Exception as e:
- logger.error(f"{self.log_prefix} 动作修改失败: {e}")
-
- # 执行planner
- is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info()
-
- message_list_before_now = get_raw_msg_before_timestamp_with_chat(
- chat_id=self.stream_id,
- timestamp=time.time(),
- limit=int(global_config.chat.max_context_size * 0.6),
- filter_intercept_message_level=1,
- )
- chat_content_block, message_id_list = build_readable_messages_with_id(
- messages=message_list_before_now,
- timestamp_mode="normal_no_YMD",
- read_mark=self.action_planner.last_obs_time_mark,
- truncate=True,
- show_actions=True,
- )
-
- prompt_info = await self.action_planner.build_planner_prompt(
- is_group_chat=is_group_chat,
- chat_target_info=chat_target_info,
- current_available_actions=available_actions,
- chat_content_block=chat_content_block,
- message_id_list=message_id_list,
- )
- continue_flag, modified_message = await events_manager.handle_mai_events(
- EventType.ON_PLAN, None, prompt_info[0], None, self.chat_stream.stream_id
- )
- if not continue_flag:
- return False
- if modified_message and modified_message._modify_flags.modify_llm_prompt:
- prompt_info = (modified_message.llm_prompt, prompt_info[1])
-
- with Timer("规划器", cycle_timers):
- action_to_use_info = await self.action_planner.plan(
- loop_start_time=self.last_read_time,
- available_actions=available_actions,
- force_reply_message=force_reply_message,
- )
-
- logger.info(
- f"{self.log_prefix} 决定执行{len(action_to_use_info)}个动作: {' '.join([a.action_type for a in action_to_use_info])}"
- )
-
- # 3. 并行执行所有动作
- action_tasks = [
- asyncio.create_task(
- self._execute_action(action, action_to_use_info, thinking_id, available_actions, cycle_timers)
- )
- for action in action_to_use_info
- ]
-
- # 并行执行所有任务
- results = await asyncio.gather(*action_tasks, return_exceptions=True)
-
- # 处理执行结果
- reply_loop_info = None
- reply_text_from_reply = ""
- action_success = False
- action_reply_text = ""
-
- excute_result_str = ""
- for result in results:
- excute_result_str += f"{result['action_type']} 执行结果:{result['result']}\n"
-
- if isinstance(result, BaseException):
- logger.error(f"{self.log_prefix} 动作执行异常: {result}")
- continue
-
- if result["action_type"] != "reply":
- action_success = result["success"]
- action_reply_text = result["result"]
- elif result["action_type"] == "reply":
- if result["success"]:
- reply_loop_info = result["loop_info"]
- reply_text_from_reply = result["result"]
- else:
- logger.warning(f"{self.log_prefix} 回复动作执行失败")
-
- self.action_planner.add_plan_excute_log(result=excute_result_str)
-
- # 构建最终的循环信息
- if reply_loop_info:
- # 如果有回复信息,使用回复的loop_info作为基础
- loop_info = reply_loop_info
- # 更新动作执行信息
- loop_info["loop_action_info"].update(
- {
- "action_taken": action_success,
- "taken_time": time.time(),
- }
- )
- _reply_text = reply_text_from_reply
- else:
- # 没有回复信息,构建纯动作的loop_info
- loop_info = {
- "loop_plan_info": {
- "action_result": action_to_use_info,
- },
- "loop_action_info": {
- "action_taken": action_success,
- "reply_text": action_reply_text,
- "taken_time": time.time(),
- },
- }
- _reply_text = action_reply_text
-
- self.end_cycle(loop_info, cycle_timers)
- self.print_cycle_info(cycle_timers)
-
- end_time = time.time()
- if end_time - start_time < global_config.chat.planner_smooth:
- wait_time = global_config.chat.planner_smooth - (end_time - start_time)
- await asyncio.sleep(wait_time)
- else:
- await asyncio.sleep(0.1)
- return True
-
- async def _main_chat_loop(self):
- """主循环,持续进行计划并可能回复消息,直到被外部取消。"""
- try:
- while self.running:
- # 主循环
- success = await self._loopbody()
- await asyncio.sleep(0.1)
- if not success:
- break
- except asyncio.CancelledError:
- # 设置了关闭标志位后被取消是正常流程
- logger.info(f"{self.log_prefix} 麦麦已关闭聊天")
- except Exception:
- logger.error(f"{self.log_prefix} 麦麦聊天意外错误,将于3s后尝试重新启动")
- print(traceback.format_exc())
- await asyncio.sleep(3)
- self._loop_task = asyncio.create_task(self._main_chat_loop())
- logger.error(f"{self.log_prefix} 结束了当前聊天循环")
-
- async def _handle_action(
- self,
- action: str,
- action_reasoning: str,
- action_data: dict,
- cycle_timers: Dict[str, float],
- thinking_id: str,
- action_message: Optional["DatabaseMessages"] = None,
- ) -> tuple[bool, str, str]:
- """
- 处理规划动作,使用动作工厂创建相应的动作处理器
-
- 参数:
- action: 动作类型
- action_reasoning: 决策理由
- action_data: 动作数据,包含不同动作需要的参数
- cycle_timers: 计时器字典
- thinking_id: 思考ID
- action_message: 消息数据
- 返回:
- tuple[bool, str, str]: (是否执行了动作, 思考消息ID, 命令)
- """
- try:
- # 使用工厂创建动作处理器实例
- try:
- action_handler = self.action_manager.create_action(
- action_name=action,
- action_data=action_data,
- cycle_timers=cycle_timers,
- thinking_id=thinking_id,
- chat_stream=self.chat_stream,
- log_prefix=self.log_prefix,
- action_reasoning=action_reasoning,
- action_message=action_message,
- )
- except Exception as e:
- logger.error(f"{self.log_prefix} 创建动作处理器时出错: {e}")
- traceback.print_exc()
- return False, ""
-
- # 处理动作并获取结果(固定记录一次动作信息)
- result = await action_handler.execute()
- success, action_text = result
-
- return success, action_text
-
- except Exception as e:
- logger.error(f"{self.log_prefix} 处理{action}时出错: {e}")
- traceback.print_exc()
- return False, ""
-
- async def _send_response(
- self,
- reply_set: "ReplySetModel",
- message_data: "DatabaseMessages",
- selected_expressions: Optional[List[int]] = None,
- quote_message: Optional[bool] = None,
- ) -> str:
- # 根据 llm_quote 配置决定是否使用 quote_message 参数
- if global_config.chat.llm_quote:
- # 如果配置为 true,使用 llm_quote 参数决定是否引用回复
- if quote_message is None:
- logger.warning(f"{self.log_prefix} quote_message 参数为空,不引用")
- need_reply = False
- else:
- need_reply = quote_message
- if need_reply:
- logger.info(f"{self.log_prefix} LLM 决定使用引用回复")
- else:
- # 如果配置为 false,使用原来的模式
- new_message_count = message_api.count_new_messages(
- chat_id=self.chat_stream.stream_id, start_time=self.last_read_time, end_time=time.time()
- )
- need_reply = new_message_count >= random.randint(2, 3) or time.time() - self.last_read_time > 90
- if need_reply:
- logger.info(f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,使用引用回复,或者上次回复时间超过90秒")
-
- reply_text = ""
- first_replied = False
- for reply_content in reply_set.reply_data:
- if reply_content.content_type != ReplyContentType.TEXT:
- continue
- data: str = reply_content.content # type: ignore
- if not first_replied:
- await send_api.text_to_stream(
- text=data,
- stream_id=self.chat_stream.stream_id,
- reply_message=message_data,
- set_reply=need_reply,
- typing=False,
- selected_expressions=selected_expressions,
- )
- first_replied = True
- else:
- await send_api.text_to_stream(
- text=data,
- stream_id=self.chat_stream.stream_id,
- reply_message=message_data,
- set_reply=False,
- typing=True,
- selected_expressions=selected_expressions,
- )
- reply_text += data
-
- return reply_text
-
- async def _execute_action(
- self,
- action_planner_info: ActionPlannerInfo,
- chosen_action_plan_infos: List[ActionPlannerInfo],
- thinking_id: str,
- available_actions: Dict[str, ActionInfo],
- cycle_timers: Dict[str, float],
- ):
- """执行单个动作的通用函数"""
- try:
- with Timer(f"动作{action_planner_info.action_type}", cycle_timers):
- # 直接当场执行no_reply逻辑
- if action_planner_info.action_type == "no_reply":
- # 直接处理no_reply逻辑,不再通过动作系统
- reason = action_planner_info.reasoning or "选择不回复"
- # logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}")
-
- # 增加连续 no_reply 计数
- self.consecutive_no_reply_count += 1
-
- await database_api.store_action_info(
- chat_stream=self.chat_stream,
- action_build_into_prompt=False,
- action_prompt_display=reason,
- action_done=True,
- thinking_id=thinking_id,
- action_data={},
- action_name="no_reply",
- action_reasoning=reason,
- )
-
- return {"action_type": "no_reply", "success": True, "result": "选择不回复", "command": ""}
-
- elif action_planner_info.action_type == "reply":
- # 直接当场执行reply逻辑
- self.questioned = False
- # 刷新主动发言状态
- # 重置连续 no_reply 计数
- self.consecutive_no_reply_count = 0
-
- reason = action_planner_info.reasoning or ""
- # 根据 think_mode 配置决定 think_level 的值
- think_mode = global_config.chat.think_mode
- if think_mode == "default":
- think_level = 0
- elif think_mode == "deep":
- think_level = 1
- elif think_mode == "dynamic":
- # dynamic 模式:从 planner 返回的 action_data 中获取
- think_level = action_planner_info.action_data.get("think_level", 1)
- else:
- # 默认使用 default 模式
- think_level = 0
- # 使用 action_reasoning(planner 的整体思考理由)作为 reply_reason
- planner_reasoning = action_planner_info.action_reasoning or reason
-
- record_replyer_action_temp(
- chat_id=self.stream_id,
- reason=reason,
- think_level=think_level,
- )
-
- await database_api.store_action_info(
- chat_stream=self.chat_stream,
- action_build_into_prompt=False,
- action_prompt_display=reason,
- action_done=True,
- thinking_id=thinking_id,
- action_data={},
- action_name="reply",
- action_reasoning=reason,
- )
-
- # 从 Planner 的 action_data 中提取未知词语列表(仅在 reply 时使用)
- unknown_words = None
- quote_message = None
- if isinstance(action_planner_info.action_data, dict):
- uw = action_planner_info.action_data.get("unknown_words")
- if isinstance(uw, list):
- cleaned_uw: List[str] = []
- for item in uw:
- if isinstance(item, str):
- s = item.strip()
- if s:
- cleaned_uw.append(s)
- if cleaned_uw:
- unknown_words = cleaned_uw
-
- # 从 Planner 的 action_data 中提取 quote_message 参数
- qm = action_planner_info.action_data.get("quote")
- if qm is not None:
- # 支持多种格式:true/false, "true"/"false", 1/0
- if isinstance(qm, bool):
- quote_message = qm
- elif isinstance(qm, str):
- quote_message = qm.lower() in ("true", "1", "yes")
- elif isinstance(qm, (int, float)):
- quote_message = bool(qm)
-
- logger.info(f"{self.log_prefix} {qm}引用回复设置: {quote_message}")
-
- success, llm_response = await generator_api.generate_reply(
- chat_stream=self.chat_stream,
- reply_message=action_planner_info.action_message,
- available_actions=available_actions,
- chosen_actions=chosen_action_plan_infos,
- reply_reason=planner_reasoning,
- unknown_words=unknown_words,
- enable_tool=global_config.tool.enable_tool,
- request_type="replyer",
- from_plugin=False,
- reply_time_point=action_planner_info.action_data.get("loop_start_time", time.time()),
- think_level=think_level,
- )
-
- if not success or not llm_response or not llm_response.reply_set:
- if action_planner_info.action_message:
- logger.info(f"对 {action_planner_info.action_message.processed_plain_text} 的回复生成失败")
- else:
- logger.info("回复生成失败")
- return {"action_type": "reply", "success": False, "result": "回复生成失败", "loop_info": None}
-
- response_set = llm_response.reply_set
- selected_expressions = llm_response.selected_expressions
- loop_info, reply_text, _ = await self._send_and_store_reply(
- response_set=response_set,
- action_message=action_planner_info.action_message, # type: ignore
- cycle_timers=cycle_timers,
- thinking_id=thinking_id,
- actions=chosen_action_plan_infos,
- selected_expressions=selected_expressions,
- quote_message=quote_message,
- )
- self.last_active_time = time.time()
- return {
- "action_type": "reply",
- "success": True,
- "result": f"你使用reply动作,对' {action_planner_info.action_message.processed_plain_text} '这句话进行了回复,回复内容为: '{reply_text}'",
- "loop_info": loop_info,
- }
-
- else:
- # 执行普通动作
- with Timer("动作执行", cycle_timers):
- success, result = await self._handle_action(
- action=action_planner_info.action_type,
- action_reasoning=action_planner_info.action_reasoning or "",
- action_data=action_planner_info.action_data or {},
- cycle_timers=cycle_timers,
- thinking_id=thinking_id,
- action_message=action_planner_info.action_message,
- )
-
- self.last_active_time = time.time()
- return {
- "action_type": action_planner_info.action_type,
- "success": success,
- "result": result,
- }
-
- except Exception as e:
- logger.error(f"{self.log_prefix} 执行动作时出错: {e}")
- logger.error(f"{self.log_prefix} 错误信息: {traceback.format_exc()}")
- return {
- "action_type": action_planner_info.action_type,
- "success": False,
- "result": "",
- "loop_info": None,
- "error": str(e),
- }
diff --git a/src/chat/heart_flow/heartFC_chat.py b/src/chat/heart_flow/heartFC_chat.py
index 2c1eb162..74d94773 100644
--- a/src/chat/heart_flow/heartFC_chat.py
+++ b/src/chat/heart_flow/heartFC_chat.py
@@ -1,377 +1,231 @@
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from rich.traceback import install
+from typing import List, Optional, TYPE_CHECKING
import asyncio
import random
import time
import traceback
-from rich.traceback import install
-
-from src.learners.expression_learner import ExpressionLearner
-from src.learners.jargon_miner import JargonMiner
-from src.chat.event_helpers import build_event_message
-from src.chat.logger.plan_reply_logger import PlanReplyLogger
-from src.chat.message_receive.chat_manager import BotChatSession
-from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
-from src.chat.planner_actions.action_manager import ActionManager
-from src.chat.planner_actions.action_modifier import ActionModifier
-from src.chat.planner_actions.planner import ActionPlanner
-from src.chat.utils.prompt_builder import global_prompt_manager
-from src.chat.utils.timer_calculator import Timer
-from src.chat.utils.utils import record_replyer_action_temp
-from src.common.data_models.info_data_model import ActionPlannerInfo
-from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
+from src.chat.message_receive.chat_manager import chat_manager
from src.common.logger import get_logger
from src.common.utils.utils_config import ChatConfigUtils, ExpressionConfigUtils
from src.config.config import global_config
from src.config.file_watcher import FileChange
-from src.core.event_bus import event_bus
-from src.core.types import ActionInfo, EventType
-from src.person_info.person_info import Person
-from src.services import (
- database_service as database_api,
- generator_service as generator_api,
- message_service as message_api,
- send_service as send_api,
-)
-from src.services.message_service import build_readable_messages_with_id, get_messages_before_time_in_chat
+from src.learners.expression_learner import ExpressionLearner
+from src.learners.jargon_miner import JargonMiner
from .heartFC_utils import CycleDetail
if TYPE_CHECKING:
from src.chat.message_receive.message import SessionMessage
-
install(extra_lines=5)
logger = get_logger("heartFC_chat")
class HeartFChatting:
- """管理一个持续运行的 Focus Chat 会话。"""
+ """
+ 管理一个连续的Focus Chat聊天会话
+ 用于在特定的聊天会话里面生成回复
+ """
def __init__(self, session_id: str):
- self.session_id = session_id
- self.chat_stream: BotChatSession = _chat_manager.get_session_by_session_id(self.session_id) # type: ignore[assignment]
- if not self.chat_stream:
- raise ValueError(f"无法找到聊天会话 {self.session_id}")
+ """
+ 初始化 HeartFChatting 实例
- session_name = _chat_manager.get_session_name(session_id) or session_id
+ Args:
+ session_id: 聊天会话ID
+ """
+ # 基础属性
+ self.session_id = session_id
+ session_name = chat_manager.get_session_name(session_id) or session_id
self.log_prefix = f"[{session_name}]"
self.session_name = session_name
- self.action_manager = ActionManager()
- self.action_planner = ActionPlanner(chat_id=self.session_id, action_manager=self.action_manager)
- self.action_modifier = ActionModifier(action_manager=self.action_manager, chat_id=self.session_id)
-
+ # 系统运行状态
self._running: bool = False
self._loop_task: Optional[asyncio.Task] = None
+ self._cycle_counter: int = 0
+ self._hfc_lock: asyncio.Lock = asyncio.Lock() # 用于保护 _hfc_func 的并发访问
+ # 聊天频率相关
+ self._consecutive_no_reply_count = 0 # 跟踪连续 no_reply 次数,用于动态调整阈值
+ self._talk_frequency_adjust: float = 1.0 # 发言频率修正值,默认为1.0,可以根据需要调整
+
+ # HFC内消息缓存
+ self.message_cache: List[SessionMessage] = []
+
+ # Asyncio Event 用于控制循环的开始和结束
self._cycle_event = asyncio.Event()
- self._hfc_lock = asyncio.Lock()
-
- self._cycle_counter = 0
- self._current_cycle_detail: Optional[CycleDetail] = None
- self.history_loop: List[CycleDetail] = []
-
- self.last_read_time = time.time() - 2
- self.last_active_time = time.time()
- self._talk_frequency_adjust = 1.0
- self._consecutive_no_reply_count = 0
-
- self.message_cache: List["SessionMessage"] = []
-
- self._min_messages_for_extraction = 30
- self._min_extraction_interval = 60
- self._last_extraction_time = 0.0
+ # 表达方式相关内容
+ self._min_messages_for_extraction = 30 # 最少提取消息数
+ self._min_extraction_interval = 60 # 最小提取时间间隔,单位为秒
+ self._last_extraction_time: float = 0.0 # 上次提取的时间戳
expr_use, jargon_learn, expr_learn = ExpressionConfigUtils.get_expression_config_for_chat(session_id)
- self._enable_expression_use = expr_use
- self._enable_expression_learning = expr_learn
- self._enable_jargon_learning = jargon_learn
- self._expression_learner = ExpressionLearner(session_id)
- self._jargon_miner = JargonMiner(session_id, session_name=session_name)
+ self._enable_expression_use = expr_use # 允许使用表达方式,但不一定启用学习
+ self._enable_expression_learning = expr_learn # 允许学习表达方式
+ self._enable_jargon_learning = jargon_learn # 允许学习黑话
+ # 表达学习器
+ self._expression_learner: ExpressionLearner = ExpressionLearner(session_id)
+ # 黑话挖掘器
+ self._jargon_miner: JargonMiner = JargonMiner(session_id, session_name=session_name)
+
+ # TODO: ChatSummarizer 聊天总结器重构
+
+ # ====== 公开方法 ======
async def start(self):
+ """启动 HeartFChatting 的主循环"""
+ # 先检查是否已经启动运行
if self._running:
- logger.debug(f"{self.log_prefix} HeartFChatting 已在运行中")
+ logger.debug(f"{self.log_prefix} 已经在运行中,无需重复启动")
return
try:
self._running = True
- self._cycle_event.clear()
+ self._cycle_event.clear() # 确保事件初始状态为未设置
+
self._loop_task = asyncio.create_task(self.main_loop())
self._loop_task.add_done_callback(self._handle_loop_completion)
+
logger.info(f"{self.log_prefix} HeartFChatting 启动完成")
- except Exception as exc:
- logger.error(f"{self.log_prefix} HeartFChatting 启动失败: {exc}", exc_info=True)
- self._running = False
- self._cycle_event.set()
- self._loop_task = None
+ except Exception as e:
+ logger.error(f"{self.log_prefix} 启动 HeartFChatting 失败: {e}", exc_info=True)
+ self._running = False # 确保状态正确
+ self._cycle_event.set() # 确保事件被设置,避免死锁
+ self._loop_task = None # 确保任务引用被清理
raise
async def stop(self):
+ """停止 HeartFChatting 的主循环"""
if not self._running:
- logger.debug(f"{self.log_prefix} HeartFChatting 已停止")
+ logger.debug(f"{self.log_prefix} HeartFChatting 已经停止,无需重复停止")
return
self._running = False
- self._cycle_event.set()
+ self._cycle_event.set() # 触发事件,通知循环结束
if self._loop_task:
- self._loop_task.cancel()
+ self._loop_task.cancel() # 取消主循环任务
try:
- await self._loop_task
+ await self._loop_task # 等待任务完成
except asyncio.CancelledError:
- logger.info(f"{self.log_prefix} HeartFChatting 主循环已取消")
- except Exception as exc:
- logger.error(f"{self.log_prefix} 停止 HeartFChatting 时发生错误: {exc}", exc_info=True)
+ logger.info(f"{self.log_prefix} HeartFChatting 主循环已成功取消")
+ except Exception as e:
+ logger.error(f"{self.log_prefix} 停止 HeartFChatting 时发生错误: {e}", exc_info=True)
finally:
- self._loop_task = None
+ self._loop_task = None # 确保任务引用被清理
logger.info(f"{self.log_prefix} HeartFChatting 已停止")
def adjust_talk_frequency(self, new_value: float):
+ """调整发言频率的调整值
+
+ Args:
+ new_value: 新的修正值,必须为非负数。值越大,修正发言频率越高;值越小,修正发言频率越低。
+ """
self._talk_frequency_adjust = max(0.0, new_value)
async def register_message(self, message: "SessionMessage"):
+ """注册一条消息到 HeartFChatting 的缓存中,并检测其是否产生提及,决定是否唤醒聊天
+
+ Args:
+ message: 待注册的消息对象
+ """
self.message_cache.append(message)
-
+ # 先检查at必回复
if global_config.chat.inevitable_at_reply and message.is_at:
- self.last_read_time = time.time()
- async with self._hfc_lock:
- await self._judge_and_response(mentioned_message=message, recent_messages_list=[message])
- return
-
+ async with self._hfc_lock: # 确保与主循环逻辑的互斥访问
+ await self._judge_and_response(message)
+ return # 直接返回,避免同一条消息被主循环再次处理
+ # 再检查提及必回复
if global_config.chat.mentioned_bot_reply and message.is_mentioned:
- self.last_read_time = time.time()
- async with self._hfc_lock:
- await self._judge_and_response(mentioned_message=message, recent_messages_list=[message])
+ # 直接获取锁,确保一定一定触发回复逻辑,不受当前是否正在执行主循环的影响
+ async with self._hfc_lock: # 确保与主循环逻辑的互斥访问
+ await self._judge_and_response(message)
return
async def main_loop(self):
try:
while self._running and not self._cycle_event.is_set():
if not self._hfc_lock.locked():
- async with self._hfc_lock:
+ async with self._hfc_lock: # 确保主循环逻辑的互斥访问
await self._hfc_func()
- await asyncio.sleep(0.1)
+ await asyncio.sleep(5)
except asyncio.CancelledError:
- logger.info(f"{self.log_prefix} HeartFChatting: 主循环被取消")
- except Exception as exc:
- logger.error(f"{self.log_prefix} HeartFChatting: 主循环异常: {exc}", exc_info=True)
- await self.stop()
+ logger.info(f"{self.log_prefix} HeartFChatting: 主循环被取消,正在关闭")
+ except Exception as e:
+ logger.error(f"{self.log_prefix} 麦麦聊天意外错误: {e},将于3s后尝试重新启动")
+ await self.stop() # 确保状态正确
await asyncio.sleep(3)
- await self.start()
+ await self.start() # 尝试重新启动
async def _config_callback(self, file_change: Optional[FileChange] = None):
- del file_change
- expr_use, jargon_learn, expr_learn = ExpressionConfigUtils.get_expression_config_for_chat(self.session_id)
- self._enable_expression_use = expr_use
- self._enable_expression_learning = expr_learn
- self._enable_jargon_learning = jargon_learn
+ """配置文件变更回调函数"""
+ # TODO: 根据配置文件变动重新计算相关参数:
+ """
+ 需要计算的参数:
+ self._enable_expression_use = expr_use # 允许使用表达方式,但不一定启用学习
+ self._enable_expression_learning = expr_learn # 允许学习表达方式
+ self._enable_jargon_learning = jargon_learn # 允许学习黑话
+ """
- async def _hfc_func(self):
- recent_messages_list = message_api.get_messages_by_time_in_chat(
- chat_id=self.session_id,
- start_time=self.last_read_time,
- end_time=time.time(),
- limit=20,
- limit_mode="latest",
- filter_mai=True,
- filter_command=False,
- filter_intercept_message_level=1,
- )
+ # ====== 心流聊天核心逻辑 ======
+ async def _hfc_func(self, mentioned_message: Optional["SessionMessage"] = None):
+ """心流聊天的主循环逻辑"""
+ if self._consecutive_no_reply_count >= 5:
+ threshold = 2
+ elif self._consecutive_no_reply_count >= 3:
+ threshold = 2 if random.random() < 0.5 else 1
+ else:
+ threshold = 1
- if len(recent_messages_list) < 1:
+ if len(self.message_cache) < threshold:
await asyncio.sleep(0.2)
return True
- self.last_read_time = time.time()
-
- mentioned_message: Optional["SessionMessage"] = None
- for message in recent_messages_list:
- if global_config.chat.inevitable_at_reply and message.is_at:
- mentioned_message = message
- elif global_config.chat.mentioned_bot_reply and message.is_mentioned:
- mentioned_message = message
-
- talk_value = ChatConfigUtils.get_talk_value(self.session_id) * self._talk_frequency_adjust
- if mentioned_message:
- await self._judge_and_response(mentioned_message=mentioned_message, recent_messages_list=recent_messages_list)
- elif random.random() < talk_value:
- await self._judge_and_response(recent_messages_list=recent_messages_list)
+ talk_value_threshold = (
+ random.random() * ChatConfigUtils.get_talk_value(self.session_id) * self._talk_frequency_adjust
+ )
+ if mentioned_message and global_config.chat.mentioned_bot_reply:
+ await self._judge_and_response(mentioned_message)
+ elif random.random() < talk_value_threshold:
+ await self._judge_and_response()
return True
- async def _judge_and_response(
- self,
- mentioned_message: Optional["SessionMessage"] = None,
- recent_messages_list: Optional[List["SessionMessage"]] = None,
- ):
- recent_messages = list(recent_messages_list or self.message_cache[-20:])
- if recent_messages:
- asyncio.create_task(self._trigger_expression_learning(recent_messages))
-
- cycle_timers, thinking_id = self._start_cycle()
+ async def _judge_and_response(self, mentioned_message: Optional["SessionMessage"] = None):
+ """判定和生成回复"""
+ asyncio.create_task(self._trigger_expression_learning(self.message_cache))
+ # TODO: 完成反思器之后的逻辑
+ start_time = time.time()
+ current_cycle_detail = self._start_cycle()
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
- try:
- async with global_prompt_manager.async_message_scope(self._get_template_name()):
- available_actions: Dict[str, ActionInfo] = {}
- try:
- await self.action_modifier.modify_actions()
- available_actions = self.action_manager.get_using_actions()
- except Exception as exc:
- logger.error(f"{self.log_prefix} 动作修改失败: {exc}", exc_info=True)
+ # TODO: 动作检查逻辑
+ # TODO: Planner逻辑
+ # TODO: 动作执行逻辑
- is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info()
- message_list_before_now = get_messages_before_time_in_chat(
- chat_id=self.session_id,
- timestamp=time.time(),
- limit=int(global_config.chat.max_context_size * 0.6),
- filter_intercept_message_level=1,
- )
- chat_content_block, message_id_list = build_readable_messages_with_id(
- messages=message_list_before_now,
- timestamp_mode="normal_no_YMD",
- read_mark=self.action_planner.last_obs_time_mark,
- truncate=True,
- show_actions=True,
- )
-
- prompt, filtered_actions = await self._build_planner_prompt_with_event(
- available_actions=available_actions,
- is_group_chat=is_group_chat,
- chat_target_info=chat_target_info,
- chat_content_block=chat_content_block,
- message_id_list=message_id_list,
- )
- if prompt is None:
- return False
-
- with Timer("规划器", cycle_timers):
- reasoning, action_to_use_info, llm_raw_output, llm_reasoning, llm_duration_ms = (
- await self.action_planner._execute_main_planner(
- prompt=prompt,
- message_id_list=message_id_list,
- filtered_actions=filtered_actions,
- available_actions=available_actions,
- loop_start_time=self.last_read_time,
- )
- )
-
- action_to_use_info = self._ensure_force_reply_action(
- actions=action_to_use_info,
- force_reply_message=mentioned_message,
- available_actions=available_actions,
- )
- self.action_planner.add_plan_log(reasoning, action_to_use_info)
- self.action_planner.last_obs_time_mark = time.time()
- self._log_plan(
- prompt=prompt,
- reasoning=reasoning,
- llm_raw_output=llm_raw_output,
- llm_reasoning=llm_reasoning,
- llm_duration_ms=llm_duration_ms,
- actions=action_to_use_info,
- )
-
- logger.info(
- f"{self.log_prefix} 决定执行{len(action_to_use_info)}个动作: {' '.join([a.action_type for a in action_to_use_info])}"
- )
-
- action_tasks = [
- asyncio.create_task(
- self._execute_action(
- action,
- action_to_use_info,
- thinking_id,
- available_actions,
- cycle_timers,
- )
- )
- for action in action_to_use_info
- ]
- results = await asyncio.gather(*action_tasks, return_exceptions=True)
-
- reply_loop_info = None
- reply_text_from_reply = ""
- action_success = False
- action_reply_text = ""
- execute_result_str = ""
-
- for result in results:
- if isinstance(result, BaseException):
- logger.error(f"{self.log_prefix} 动作执行异常: {result}", exc_info=True)
- continue
-
- execute_result_str += f"{result['action_type']} 执行结果:{result['result']}\n"
- if result["action_type"] == "reply":
- if result["success"]:
- reply_loop_info = result["loop_info"]
- reply_text_from_reply = result["result"]
- else:
- logger.warning(f"{self.log_prefix} reply 动作执行失败")
- else:
- action_success = result["success"]
- action_reply_text = result["result"]
-
- self.action_planner.add_plan_excute_log(result=execute_result_str)
-
- if reply_loop_info:
- loop_info = reply_loop_info
- loop_info["loop_action_info"].update(
- {
- "action_taken": action_success,
- "taken_time": time.time(),
- }
- )
- else:
- loop_info = {
- "loop_plan_info": {
- "action_result": action_to_use_info,
- },
- "loop_action_info": {
- "action_taken": action_success,
- "reply_text": action_reply_text,
- "taken_time": time.time(),
- },
- }
- reply_text_from_reply = action_reply_text
-
- current_cycle_detail = self._end_cycle(self._current_cycle_detail, loop_info)
- logger.debug(f"{self.log_prefix} 本轮最终输出: {reply_text_from_reply}")
- return current_cycle_detail is not None
- except Exception as exc:
- logger.error(f"{self.log_prefix} 判定与回复流程失败: {exc}", exc_info=True)
- if self._current_cycle_detail:
- self._end_cycle(
- self._current_cycle_detail,
- {
- "loop_plan_info": {"action_result": []},
- "loop_action_info": {
- "action_taken": False,
- "reply_text": "",
- "taken_time": time.time(),
- "error": str(exc),
- },
- },
- )
- return False
+ cycle_detail = self._end_cycle(current_cycle_detail)
+ if wait_time := global_config.chat.planner_smooth - (time.time() - start_time) > 0:
+ await asyncio.sleep(wait_time)
+ else:
+ await asyncio.sleep(0.1) # 最小等待时间,避免过快循环
+ return True
def _handle_loop_completion(self, task: asyncio.Task):
+ """当 _hfc_func 任务完成时执行的回调。"""
try:
if exception := task.exception():
- logger.error(f"{self.log_prefix} HeartFChatting: 主循环异常退出: {exception}")
- logger.error(traceback.format_exc())
+ logger.error(f"{self.log_prefix} HeartFChatting: 脱离了聊天(异常): {exception}")
+ logger.error(traceback.format_exc()) # Log full traceback for exceptions
else:
- logger.info(f"{self.log_prefix} HeartFChatting: 主循环已退出")
+ logger.info(f"{self.log_prefix} HeartFChatting: 脱离了聊天 (外部停止)")
except asyncio.CancelledError:
- logger.info(f"{self.log_prefix} HeartFChatting: 聊天已结束")
+ logger.info(f"{self.log_prefix} HeartFChatting: 结束了聊天")
+ # ====== 学习器触发逻辑 ======
async def _trigger_expression_learning(self, messages: List["SessionMessage"]):
- if not messages:
- return
-
self._expression_learner.add_messages(messages)
if time.time() - self._last_extraction_time < self._min_extraction_interval:
return
@@ -379,14 +233,12 @@ class HeartFChatting:
return
if not self._enable_expression_learning:
return
-
extraction_end_time = time.time()
logger.info(
f"聊天流 {self.session_name} 提取到 {len(messages)} 条消息,"
f"时间窗口: {self._last_extraction_time:.2f} - {extraction_end_time:.2f}"
)
self._last_extraction_time = extraction_end_time
-
try:
jargon_miner = self._jargon_miner if self._enable_jargon_learning else None
learnt_style = await self._expression_learner.learn(jargon_miner)
@@ -394,398 +246,43 @@ class HeartFChatting:
logger.info(f"{self.log_prefix} 表达学习完成")
else:
logger.debug(f"{self.log_prefix} 表达学习未获得有效结果")
- except Exception as exc:
- logger.error(f"{self.log_prefix} 表达学习失败: {exc}", exc_info=True)
+ except Exception as e:
+ logger.error(f"{self.log_prefix} 表达学习失败: {e}", exc_info=True)
- def _start_cycle(self) -> Tuple[Dict[str, float], str]:
+ # ====== 记录循环执行信息相关逻辑 ======
+ def _start_cycle(self) -> CycleDetail:
self._cycle_counter += 1
- self._current_cycle_detail = CycleDetail(cycle_id=self._cycle_counter)
- self._current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}"
- return self._current_cycle_detail.time_records, self._current_cycle_detail.thinking_id
+ current_cycle_detail = CycleDetail(cycle_id=self._cycle_counter)
+ current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}"
+ return current_cycle_detail
- def _end_cycle(self, cycle_detail: Optional[CycleDetail], loop_info: Optional[Dict[str, Any]] = None):
- if cycle_detail is None:
- return None
-
- cycle_detail.loop_plan_info = (loop_info or {}).get("loop_plan_info")
- cycle_detail.loop_action_info = (loop_info or {}).get("loop_action_info")
+ def _end_cycle(self, cycle_detail: CycleDetail, only_long_execution: bool = True):
cycle_detail.end_time = time.time()
- self.history_loop.append(cycle_detail)
-
- timer_strings = [
+ timer_strings: List[str] = [
f"{name}: {duration:.2f}s"
for name, duration in cycle_detail.time_records.items()
- if duration >= 0.1
+ if not only_long_execution or duration >= 0.1
]
logger.info(
- f"{self.log_prefix} 第{cycle_detail.cycle_id} 个心流循环完成,"
- f"耗时: {cycle_detail.end_time - cycle_detail.start_time:.2f}s;"
+ f"{self.log_prefix} 第 {cycle_detail.cycle_id} 个心流循环完成"
+ f"耗时: {cycle_detail.end_time - cycle_detail.start_time:.2f}秒\n"
f"详细计时: {', '.join(timer_strings) if timer_strings else '无'}"
)
+
return cycle_detail
- async def _execute_action(
- self,
- action_planner_info: ActionPlannerInfo,
- chosen_action_plan_infos: List[ActionPlannerInfo],
- thinking_id: str,
- available_actions: Dict[str, ActionInfo],
- cycle_timers: Dict[str, float],
- ):
- try:
- with Timer(f"动作{action_planner_info.action_type}", cycle_timers):
- if action_planner_info.action_type == "no_reply":
- reason = action_planner_info.reasoning or "选择不回复"
- self._consecutive_no_reply_count += 1
- await database_api.store_action_info(
- chat_stream=self.chat_stream,
- display_prompt=reason,
- thinking_id=thinking_id,
- action_data={},
- action_name="no_reply",
- action_reasoning=reason,
- )
- return {
- "action_type": "no_reply",
- "success": True,
- "result": "选择不回复",
- "loop_info": None,
- }
+ # ====== Action相关逻辑 ======
+ async def _execute_action(self, *args, **kwargs):
+ """原ExecuteAction"""
+ raise NotImplementedError("执行动作的逻辑尚未实现") # TODO: 实现动作执行的逻辑,替换掉*args, **kwargs*占位符
- if action_planner_info.action_type == "reply":
- self._consecutive_no_reply_count = 0
- reason = action_planner_info.reasoning or ""
- think_level = self._get_think_level(action_planner_info)
- planner_reasoning = action_planner_info.action_reasoning or reason
+ async def _execute_other_actions(self, *args, **kwargs):
+ """原HandleAction"""
+ raise NotImplementedError(
+ "执行其他动作的逻辑尚未实现"
+ ) # TODO: 实现其他动作执行的逻辑, 替换掉*args, **kwargs*占位符
- record_replyer_action_temp(
- chat_id=self.session_id,
- reason=reason,
- think_level=think_level,
- )
- await database_api.store_action_info(
- chat_stream=self.chat_stream,
- display_prompt=reason,
- thinking_id=thinking_id,
- action_data={},
- action_name="reply",
- action_reasoning=reason,
- )
-
- unknown_words, quote_message = self._extract_reply_metadata(action_planner_info)
- success, llm_response = await generator_api.generate_reply(
- chat_stream=self.chat_stream,
- reply_message=action_planner_info.action_message,
- available_actions=available_actions,
- chosen_actions=chosen_action_plan_infos,
- reply_reason=planner_reasoning,
- unknown_words=unknown_words,
- enable_tool=global_config.tool.enable_tool,
- request_type="replyer",
- from_plugin=False,
- reply_time_point=action_planner_info.action_data.get("loop_start_time", time.time())
- if action_planner_info.action_data
- else time.time(),
- think_level=think_level,
- )
- if not success or not llm_response or not llm_response.reply_set:
- if action_planner_info.action_message:
- logger.info(
- f"对 {action_planner_info.action_message.processed_plain_text} 的回复生成失败"
- )
- else:
- logger.info(f"{self.log_prefix} 回复生成失败")
- return {
- "action_type": "reply",
- "success": False,
- "result": "回复生成失败",
- "loop_info": None,
- }
-
- loop_info, reply_text, _ = await self._send_and_store_reply(
- response_set=llm_response.reply_set,
- action_message=action_planner_info.action_message, # type: ignore[arg-type]
- cycle_timers=cycle_timers,
- thinking_id=thinking_id,
- actions=chosen_action_plan_infos,
- selected_expressions=llm_response.selected_expressions,
- quote_message=quote_message,
- )
- self.last_active_time = time.time()
- return {
- "action_type": "reply",
- "success": True,
- "result": reply_text,
- "loop_info": loop_info,
- }
-
- with Timer("动作执行", cycle_timers):
- success, result = await self._handle_action(
- action=action_planner_info.action_type,
- action_reasoning=action_planner_info.action_reasoning or "",
- action_data=action_planner_info.action_data or {},
- cycle_timers=cycle_timers,
- thinking_id=thinking_id,
- action_message=action_planner_info.action_message,
- )
- if success:
- self.last_active_time = time.time()
- return {
- "action_type": action_planner_info.action_type,
- "success": success,
- "result": result,
- "loop_info": None,
- }
- except Exception as exc:
- logger.error(f"{self.log_prefix} 执行动作时出错: {exc}", exc_info=True)
- return {
- "action_type": action_planner_info.action_type,
- "success": False,
- "result": "",
- "loop_info": None,
- "error": str(exc),
- }
-
- async def _handle_action(
- self,
- action: str,
- action_reasoning: str,
- action_data: dict,
- cycle_timers: Dict[str, float],
- thinking_id: str,
- action_message: Optional["SessionMessage"] = None,
- ) -> Tuple[bool, str]:
- try:
- action_handler = self.action_manager.create_action(
- action_name=action,
- action_data=action_data,
- action_reasoning=action_reasoning,
- cycle_timers=cycle_timers,
- thinking_id=thinking_id,
- chat_stream=self.chat_stream,
- log_prefix=self.log_prefix,
- action_message=action_message,
- )
- if not action_handler:
- logger.warning(f"{self.log_prefix} 未能创建动作处理器: {action}")
- return False, ""
-
- success, action_text = await action_handler.execute()
- return success, action_text
- except Exception as exc:
- logger.error(f"{self.log_prefix} 处理动作 {action} 时出错: {exc}", exc_info=True)
- return False, ""
-
- async def _send_and_store_reply(
- self,
- response_set: MessageSequence,
- action_message: "SessionMessage",
- cycle_timers: Dict[str, float],
- thinking_id: str,
- actions: List[ActionPlannerInfo],
- selected_expressions: Optional[List[int]] = None,
- quote_message: Optional[bool] = None,
- ) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
- with Timer("回复发送", cycle_timers):
- reply_text = await self._send_response(
- reply_set=response_set,
- message_data=action_message,
- selected_expressions=selected_expressions,
- quote_message=quote_message,
- )
-
- platform = action_message.platform or getattr(self.chat_stream, "platform", "unknown")
- person = Person(platform=platform, user_id=action_message.message_info.user_info.user_id)
- action_prompt_display = f"你对{person.person_name}进行了回复:{reply_text}"
- await database_api.store_action_info(
- chat_stream=self.chat_stream,
- display_prompt=action_prompt_display,
- thinking_id=thinking_id,
- action_data={"reply_text": reply_text},
- action_name="reply",
- )
-
- loop_info: Dict[str, Any] = {
- "loop_plan_info": {
- "action_result": actions,
- },
- "loop_action_info": {
- "action_taken": True,
- "reply_text": reply_text,
- "command": "",
- "taken_time": time.time(),
- },
- }
- return loop_info, reply_text, cycle_timers
-
- async def _send_response(
- self,
- reply_set: MessageSequence,
- message_data: "SessionMessage",
- selected_expressions: Optional[List[int]] = None,
- quote_message: Optional[bool] = None,
- ) -> str:
- if global_config.chat.llm_quote:
- need_reply = bool(quote_message)
- else:
- new_message_count = message_api.count_new_messages(
- chat_id=self.session_id,
- start_time=self.last_read_time,
- end_time=time.time(),
- )
- need_reply = new_message_count >= random.randint(2, 3) or time.time() - self.last_read_time > 90
-
- reply_text = ""
- first_replied = False
- for component in reply_set.components:
- if not isinstance(component, TextComponent):
- continue
- data = component.text
- if not first_replied:
- await send_api.text_to_stream(
- text=data,
- stream_id=self.session_id,
- reply_message=message_data,
- set_reply=need_reply,
- typing=False,
- selected_expressions=selected_expressions,
- )
- first_replied = True
- else:
- await send_api.text_to_stream(
- text=data,
- stream_id=self.session_id,
- reply_message=message_data,
- set_reply=False,
- typing=True,
- selected_expressions=selected_expressions,
- )
- reply_text += data
- return reply_text
-
- async def _build_planner_prompt_with_event(
- self,
- available_actions: Dict[str, ActionInfo],
- is_group_chat: bool,
- chat_target_info: Any,
- chat_content_block: str,
- message_id_list: List[Tuple[str, "SessionMessage"]],
- ) -> Tuple[Optional[str], Dict[str, ActionInfo]]:
- filtered_actions = self.action_planner._filter_actions_by_activation_type(available_actions, chat_content_block)
- prompt, _ = await self.action_planner.build_planner_prompt(
- is_group_chat=is_group_chat,
- chat_target_info=chat_target_info,
- current_available_actions=filtered_actions,
- chat_content_block=chat_content_block,
- message_id_list=message_id_list,
- )
- event_message = build_event_message(EventType.ON_PLAN, llm_prompt=prompt, stream_id=self.session_id)
- continue_flag, modified_message = await event_bus.emit(EventType.ON_PLAN, event_message)
- if not continue_flag:
- logger.info(f"{self.log_prefix} ON_PLAN 事件中止了本轮 HFC")
- return None, filtered_actions
- if modified_message and modified_message._modify_flags.modify_llm_prompt and modified_message.llm_prompt:
- prompt = modified_message.llm_prompt
- return prompt, filtered_actions
-
- def _ensure_force_reply_action(
- self,
- actions: List[ActionPlannerInfo],
- force_reply_message: Optional["SessionMessage"],
- available_actions: Dict[str, ActionInfo],
- ) -> List[ActionPlannerInfo]:
- if not force_reply_message:
- return actions
-
- has_reply_to_force_message = any(
- action.action_type == "reply"
- and action.action_message
- and action.action_message.message_id == force_reply_message.message_id
- for action in actions
- )
- if has_reply_to_force_message:
- return actions
-
- actions = [action for action in actions if action.action_type != "no_reply"]
- actions.insert(
- 0,
- ActionPlannerInfo(
- action_type="reply",
- reasoning="用户提及了我,必须回复该消息",
- action_data={"loop_start_time": self.last_read_time},
- action_message=force_reply_message,
- available_actions=available_actions,
- action_reasoning=None,
- ),
- )
- logger.info(f"{self.log_prefix} 检测到强制回复消息,已补充 reply 动作")
- return actions
-
- def _log_plan(
- self,
- prompt: str,
- reasoning: str,
- llm_raw_output: Optional[str],
- llm_reasoning: Optional[str],
- llm_duration_ms: Optional[float],
- actions: List[ActionPlannerInfo],
- ) -> None:
- try:
- PlanReplyLogger.log_plan(
- chat_id=self.session_id,
- prompt=prompt,
- reasoning=reasoning,
- raw_output=llm_raw_output,
- raw_reasoning=llm_reasoning,
- actions=actions,
- timing={
- "llm_duration_ms": round(llm_duration_ms, 2) if llm_duration_ms is not None else None,
- "loop_start_time": self.last_read_time,
- },
- extra=None,
- )
- except Exception:
- logger.exception(f"{self.log_prefix} 记录 plan 日志失败")
-
- def _extract_reply_metadata(
- self,
- action_planner_info: ActionPlannerInfo,
- ) -> Tuple[Optional[List[str]], Optional[bool]]:
- unknown_words: Optional[List[str]] = None
- quote_message: Optional[bool] = None
- action_data = action_planner_info.action_data or {}
-
- raw_unknown_words = action_data.get("unknown_words")
- if isinstance(raw_unknown_words, list):
- cleaned_unknown_words = []
- for item in raw_unknown_words:
- if isinstance(item, str) and (cleaned_item := item.strip()):
- cleaned_unknown_words.append(cleaned_item)
- if cleaned_unknown_words:
- unknown_words = cleaned_unknown_words
-
- raw_quote = action_data.get("quote")
- if isinstance(raw_quote, bool):
- quote_message = raw_quote
- elif isinstance(raw_quote, str):
- quote_message = raw_quote.lower() in {"true", "1", "yes"}
- elif isinstance(raw_quote, (int, float)):
- quote_message = bool(raw_quote)
-
- return unknown_words, quote_message
-
- def _get_think_level(self, action_planner_info: ActionPlannerInfo) -> int:
- think_mode = global_config.chat.think_mode
- if think_mode == "default":
- return 0
- if think_mode == "deep":
- return 1
- if think_mode == "dynamic":
- action_data = action_planner_info.action_data or {}
- return int(action_data.get("think_level", 1))
- return 0
-
- def _get_template_name(self) -> Optional[str]:
- if self.chat_stream.context:
- return self.chat_stream.context.template_name
- return None
+ # ====== 响应发送相关方法 ======
+ async def _send_response(self, *args, **kwargs):
+ raise NotImplementedError("发送回复的逻辑尚未实现") # TODO: 实现发送回复的逻辑,替换掉*args, **kwargs*占位符
+ # 传入的消息至少应该是个MessageSequence实例,最好是SessionMessage实例,随后可直接转化为MessageSending实例
diff --git a/src/chat/heart_flow/heartflow.py b/src/chat/heart_flow/heartflow.py
deleted file mode 100644
index febff2d5..00000000
--- a/src/chat/heart_flow/heartflow.py
+++ /dev/null
@@ -1,42 +0,0 @@
-import traceback
-from typing import Any, Optional, Dict
-
-from src.chat.message_receive.chat_stream import get_chat_manager
-from src.common.logger import get_logger
-from src.chat.heart_flow.heartFC_chat import HeartFChatting
-from src.chat.brain_chat.brain_chat import BrainChatting
-from src.chat.message_receive.chat_stream import ChatStream
-
-logger = get_logger("heartflow")
-
-
-class Heartflow:
- """主心流协调器,负责初始化并协调聊天"""
-
- def __init__(self):
- self.heartflow_chat_list: Dict[Any, HeartFChatting | BrainChatting] = {}
-
- async def get_or_create_heartflow_chat(self, chat_id: Any) -> Optional[HeartFChatting | BrainChatting]:
- """获取或创建一个新的HeartFChatting实例"""
- try:
- if chat_id in self.heartflow_chat_list:
- if chat := self.heartflow_chat_list.get(chat_id):
- return chat
- else:
- chat_stream: ChatStream | None = get_chat_manager().get_stream(chat_id)
- if not chat_stream:
- raise ValueError(f"未找到 chat_id={chat_id} 的聊天流")
- if chat_stream.group_info:
- new_chat = HeartFChatting(chat_id=chat_id)
- else:
- new_chat = BrainChatting(chat_id=chat_id)
- await new_chat.start()
- self.heartflow_chat_list[chat_id] = new_chat
- return new_chat
- except Exception as e:
- logger.error(f"创建心流聊天 {chat_id} 失败: {e}", exc_info=True)
- traceback.print_exc()
- return None
-
-
-heartflow = Heartflow()