From 780cd4f767eb7f8f6bc2357ef062891bb2de2ca2 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Sat, 21 Mar 2026 00:46:34 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E6=9B=B4=E6=96=B0=E6=8F=92?= =?UTF-8?q?=E4=BB=B6=E5=92=8C=20RPC=20=E6=9C=8D=E5=8A=A1=E5=99=A8=E9=80=BB?= =?UTF-8?q?=E8=BE=91=EF=BC=8C=E5=A2=9E=E5=BC=BA=E6=8F=A1=E6=89=8B=E7=8A=B6?= =?UTF-8?q?=E6=80=81=E7=AE=A1=E7=90=86=E4=B8=8E=E9=85=8D=E7=BD=AE=E6=A0=A1?= =?UTF-8?q?=E9=AA=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 2 +- src/plugin_runtime/host/rpc_server.py | 23 +- src/plugin_runtime/host/supervisor.py | 79 +++++- src/plugin_runtime/integration.py | 162 ++++++++--- src/plugin_runtime/protocol/envelope.py | 2 +- src/plugins/built_in/napcat_adapter/plugin.py | 251 +++++++++++++++++- 6 files changed, 457 insertions(+), 62 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f41e9448..9887ac24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ dependencies = [ "jieba>=0.42.1", "json-repair>=0.47.6", "maim-message>=0.6.2", - "maibot-plugin-sdk>=1.2.3,<2.0.0", + "maibot-plugin-sdk>=2.0.0", "msgpack>=1.1.2", "numpy>=2.2.6", "openai>=1.95.0", diff --git a/src/plugin_runtime/host/rpc_server.py b/src/plugin_runtime/host/rpc_server.py index 75ef9b2a..2c422775 100644 --- a/src/plugin_runtime/host/rpc_server.py +++ b/src/plugin_runtime/host/rpc_server.py @@ -69,6 +69,7 @@ class RPCServer: # 运行状态 self._running: bool = False self._tasks: List[asyncio.Task[None]] = [] + self._last_handshake_rejection_reason: str = "" @property def session_token(self) -> str: @@ -78,6 +79,15 @@ class RPCServer: def is_connected(self) -> bool: return self._connection is not None and not self._connection.is_closed + @property + def last_handshake_rejection_reason(self) -> str: + """返回最近一次握手被拒绝的原因。""" + return self._last_handshake_rejection_reason + + def clear_handshake_state(self) -> None: + """清空最近一次握手拒绝状态。""" + self._last_handshake_rejection_reason = "" + def register_method(self, method: str, handler: MethodHandler) -> None: """注册 RPC 方法处理器""" self._method_handlers[method] = handler @@ -85,6 +95,7 @@ class RPCServer: async def start(self) -> None: """启动 RPC 服务器""" self._running = True + self.clear_handshake_state() self._send_queue = asyncio.Queue(maxsize=self._send_queue_size) self._send_worker_task = asyncio.create_task(self._send_loop()) await self._transport.start(self._handle_connection) @@ -93,6 +104,7 @@ class RPCServer: async def stop(self) -> None: """停止 RPC 服务器""" self._running = False + self.clear_handshake_state() self._fail_pending_requests(ErrorCode.E_SHUTTING_DOWN, "服务器正在关闭") self._fail_queued_sends(ErrorCode.E_SHUTTING_DOWN, "服务器正在关闭") @@ -204,6 +216,7 @@ class RPCServer: async def _handle_connection(self, conn: Connection) -> None: """处理新的 Runner 连接""" logger.info("收到 Runner 连接") + self.clear_handshake_state() # 第一条消息必须是 runner.hello 握手 try: success = await self._handle_handshake(conn) @@ -232,6 +245,7 @@ class RPCServer: envelope = self._codec.decode_envelope(data) if envelope.method != "runner.hello": logger.error(f"期望 runner.hello,收到 {envelope.method}") + self._last_handshake_rejection_reason = "首条消息必须为 runner.hello" error_resp = envelope.make_error_response( ErrorCode.E_PROTOCOL_MISMATCH.value, "首条消息必须为 runner.hello", @@ -244,7 +258,8 @@ class RPCServer: # 校验会话令牌 if hello.session_token != self._session_token: logger.error("会话令牌不匹配") - resp_payload = HelloResponsePayload(accepted=False, reason="会话令牌无效") + self._last_handshake_rejection_reason = "会话令牌无效" + resp_payload = HelloResponsePayload(accepted=False, reason=self._last_handshake_rejection_reason) resp = envelope.make_response(payload=resp_payload.model_dump()) await conn.send_frame(self._codec.encode_envelope(resp)) return False @@ -252,15 +267,19 @@ class RPCServer: # 校验 SDK 版本 if not self._check_sdk_version(hello.sdk_version): logger.error(f"SDK 版本不兼容: {hello.sdk_version}") + self._last_handshake_rejection_reason = ( + f"SDK 版本 {hello.sdk_version} 不在支持范围 [{MIN_SDK_VERSION}, {MAX_SDK_VERSION}]" + ) resp_payload = HelloResponsePayload( accepted=False, - reason=f"SDK 版本 {hello.sdk_version} 不在支持范围 [{MIN_SDK_VERSION}, {MAX_SDK_VERSION}]", + reason=self._last_handshake_rejection_reason, ) resp = envelope.make_response(payload=resp_payload.model_dump()) await conn.send_frame(self._codec.encode_envelope(resp)) return False # 发送响应 + self.clear_handshake_state() resp_payload = HelloResponsePayload(accepted=True, host_version=PROTOCOL_VERSION) resp = envelope.make_response(payload=resp_payload.model_dump()) await conn.send_frame(self._codec.encode_envelope(resp)) diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py index cdf3d4ee..33091d5a 100644 --- a/src/plugin_runtime/host/supervisor.py +++ b/src/plugin_runtime/host/supervisor.py @@ -202,17 +202,24 @@ class PluginRunnerSupervisor: self._restart_count = 0 self._clear_runner_state() - await self._rpc_server.start() - await self._spawn_runner() - try: - await self._wait_for_runner_connection(timeout_sec=self._runner_spawn_timeout) - await self._wait_for_runner_ready(timeout_sec=self._runner_spawn_timeout) - except TimeoutError: - if not self._rpc_server.is_connected: - logger.warning("Runner 未在限定时间内完成连接,后续操作可能失败") - else: - logger.warning("Runner 未在限定时间内完成初始化,后续操作可能失败") + await self._rpc_server.start() + await self._spawn_runner() + + try: + await self._wait_for_runner_connection(timeout_sec=self._runner_spawn_timeout) + await self._wait_for_runner_ready(timeout_sec=self._runner_spawn_timeout) + except TimeoutError: + if not self._rpc_server.is_connected: + logger.warning("Runner 未在限定时间内完成连接,后续操作可能失败") + else: + logger.warning("Runner 未在限定时间内完成初始化,后续操作可能失败") + except Exception: + await self._shutdown_runner(reason="startup_failed") + await self._rpc_server.stop() + self._clear_runner_state() + self._running = False + raise self._health_task = asyncio.create_task(self._health_check_loop(), name="PluginRunnerSupervisor.health") logger.info("PluginRunnerSupervisor 已启动") @@ -387,7 +394,16 @@ class PluginRunnerSupervisor: async def wait_for_connection() -> None: """轮询等待 RPC 连接建立。""" - while self._running and not self._rpc_server.is_connected: + while True: + if self._rpc_server.is_connected: + return + + if not self._running: + raise RuntimeError("Supervisor 已停止,等待 Runner 连接已取消") + + if failure_reason := self._get_runner_startup_failure_reason(): + raise RuntimeError(f"等待 Runner 连接失败: {failure_reason}") + await asyncio.sleep(0.1) try: @@ -408,10 +424,27 @@ class PluginRunnerSupervisor: Raises: TimeoutError: 在超时时间内 Runner 未完成初始化。 """ + async def wait_for_ready() -> RunnerReadyPayload: + """轮询等待 Runner 上报就绪。""" + while True: + if self._runner_ready_events.is_set(): + return self._runner_ready_payloads + + if not self._running: + raise RuntimeError("Supervisor 已停止,等待 Runner 就绪已取消") + + if failure_reason := self._get_runner_startup_failure_reason(): + raise RuntimeError(f"等待 Runner 就绪失败: {failure_reason}") + + if not self._rpc_server.is_connected: + raise RuntimeError("等待 Runner 就绪失败: Runner RPC 连接已断开") + + await asyncio.sleep(0.1) + try: - await asyncio.wait_for(self._runner_ready_events.wait(), timeout=timeout_sec) + payload = await asyncio.wait_for(wait_for_ready(), timeout=timeout_sec) logger.info("Runner 已完成初始化并上报就绪") - return self._runner_ready_payloads + return payload except asyncio.TimeoutError as exc: raise TimeoutError(f"等待 Runner 就绪超时({timeout_sec}s)") from exc @@ -923,6 +956,7 @@ class PluginRunnerSupervisor: await self._wait_for_runner_connection(timeout_sec=self._runner_spawn_timeout) await self._wait_for_runner_ready(timeout_sec=self._runner_spawn_timeout) except Exception as exc: + await self._shutdown_runner(reason="restart_failed") logger.error(f"Runner 重启失败: {exc}", exc_info=True) return False @@ -938,6 +972,25 @@ class PluginRunnerSupervisor: self._registered_adapters.clear() self._runner_ready_events = asyncio.Event() self._runner_ready_payloads = RunnerReadyPayload() + self._rpc_server.clear_handshake_state() + + def _get_runner_startup_failure_reason(self) -> Optional[str]: + """获取 Runner 在启动阶段已经暴露出的失败原因。 + + Returns: + Optional[str]: 若已检测到失败则返回失败原因,否则返回 ``None``。 + """ + if handshake_reason := self._rpc_server.last_handshake_rejection_reason: + return f"握手被拒绝: {handshake_reason}" + + process = self._runner_process + if process is None: + return "Runner 进程不存在" + + if process.returncode is not None: + return f"Runner 进程已退出,退出码 {process.returncode}" + + return None PluginSupervisor = PluginRunnerSupervisor diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py index 30a3c150..24cf09fc 100644 --- a/src/plugin_runtime/integration.py +++ b/src/plugin_runtime/integration.py @@ -68,6 +68,7 @@ class PluginRuntimeManager( self._plugin_file_watcher: Optional[FileWatcher] = None self._plugin_source_watcher_subscription_id: Optional[str] = None self._plugin_config_watcher_subscriptions: Dict[str, Tuple[Path, str]] = {} + self._plugin_path_cache: Dict[str, Path] = {} async def _dispatch_platform_inbound(self, envelope: InboundMessageEnvelope) -> None: """接收 Platform IO 审核后的入站消息并送入主消息链。 @@ -215,6 +216,7 @@ class PluginRuntimeManager( self._started = False self._builtin_supervisor = None self._third_party_supervisor = None + self._plugin_path_cache.clear() @property def is_running(self) -> bool: @@ -254,7 +256,7 @@ class PluginRuntimeManager( config_payload = ( config_data if config_data is not None - else self._load_plugin_config_for_supervisor(plugin_id, plugin_dirs=sv._plugin_dirs) + else self._load_plugin_config_for_supervisor(sv, plugin_id) ) await sv.notify_plugin_config_updated( plugin_id=plugin_id, @@ -452,6 +454,7 @@ class PluginRuntimeManager( async def _stop_plugin_file_watcher(self) -> None: """停止插件文件监视器,并清理所有已注册订阅。""" if self._plugin_file_watcher is None: + self._plugin_path_cache.clear() return for _plugin_id, (_config_path, subscription_id) in list(self._plugin_config_watcher_subscriptions.items()): self._plugin_file_watcher.unsubscribe(subscription_id) @@ -461,12 +464,95 @@ class PluginRuntimeManager( self._plugin_source_watcher_subscription_id = None await self._plugin_file_watcher.stop() self._plugin_file_watcher = None + self._plugin_path_cache.clear() def _iter_plugin_dirs(self) -> Iterable[Path]: """迭代所有 Supervisor 当前管理的插件根目录。""" for supervisor in self.supervisors: yield from getattr(supervisor, "_plugin_dirs", []) + @staticmethod + def _iter_candidate_plugin_paths(plugin_dirs: Iterable[Path]) -> Iterable[Path]: + """迭代所有可能的插件目录路径。 + + Args: + plugin_dirs: 一个或多个插件根目录。 + + Yields: + Path: 单个插件目录路径。 + """ + for plugin_dir in plugin_dirs: + plugin_root = Path(plugin_dir).resolve() + if not plugin_root.is_dir(): + continue + for entry in plugin_root.iterdir(): + if entry.is_dir(): + yield entry.resolve() + + @staticmethod + def _read_plugin_id_from_plugin_path(plugin_path: Path) -> Optional[str]: + """从单个插件目录中读取 manifest 声明的插件 ID。 + + Args: + plugin_path: 单个插件目录路径。 + + Returns: + Optional[str]: 解析成功时返回插件 ID,否则返回 ``None``。 + """ + manifest_path = plugin_path / "_manifest.json" + entrypoint_path = plugin_path / "plugin.py" + if not manifest_path.is_file() or not entrypoint_path.is_file(): + return None + + try: + with open(manifest_path, "r", encoding="utf-8") as manifest_file: + manifest = json.load(manifest_file) + except Exception: + return None + + if not isinstance(manifest, dict): + return None + + plugin_id = str(manifest.get("name", plugin_path.name)).strip() or plugin_path.name + return plugin_id or None + + def _iter_discovered_plugin_paths(self, plugin_dirs: Iterable[Path]) -> Iterable[Tuple[str, Path]]: + """迭代目录中可解析到的插件 ID 与实际目录路径。 + + Args: + plugin_dirs: 一个或多个插件根目录。 + + Yields: + Tuple[str, Path]: ``(plugin_id, plugin_path)`` 二元组。 + """ + for plugin_path in self._iter_candidate_plugin_paths(plugin_dirs): + if plugin_id := self._read_plugin_id_from_plugin_path(plugin_path): + yield plugin_id, plugin_path + + def _get_plugin_path_for_supervisor(self, supervisor: Any, plugin_id: str) -> Optional[Path]: + """为指定 Supervisor 定位某个插件的实际目录。 + + Args: + supervisor: 目标 Supervisor。 + plugin_id: 插件 ID。 + + Returns: + Optional[Path]: 插件目录路径;未找到时返回 ``None``。 + """ + cached_path = self._plugin_path_cache.get(plugin_id) + if cached_path is not None: + for plugin_dir in getattr(supervisor, "_plugin_dirs", []): + if self._plugin_dir_matches(cached_path, Path(plugin_dir)): + return cached_path + + for candidate_plugin_id, plugin_path in self._iter_discovered_plugin_paths(getattr(supervisor, "_plugin_dirs", [])): + if candidate_plugin_id != plugin_id: + continue + self._plugin_path_cache[plugin_id] = plugin_path + return plugin_path + + return None + def _refresh_plugin_config_watch_subscriptions(self) -> None: """按当前已注册插件集合刷新 config.toml 的单插件订阅。 @@ -476,7 +562,11 @@ class PluginRuntimeManager( if self._plugin_file_watcher is None: return - desired_config_paths = dict(self._iter_registered_plugin_config_paths()) + desired_plugin_paths = dict(self._iter_registered_plugin_paths()) + self._plugin_path_cache = desired_plugin_paths.copy() + desired_config_paths = { + plugin_id: plugin_path / "config.toml" for plugin_id, plugin_path in desired_plugin_paths.items() + } for plugin_id, (_old_path, subscription_id) in list(self._plugin_config_watcher_subscriptions.items()): if desired_config_paths.get(plugin_id) == self._plugin_config_watcher_subscriptions[plugin_id][0]: @@ -509,21 +599,17 @@ class PluginRuntimeManager( return _callback - def _iter_registered_plugin_config_paths(self) -> Iterable[Tuple[str, Path]]: - """迭代当前所有已注册插件的 config.toml 路径。""" + def _iter_registered_plugin_paths(self) -> Iterable[Tuple[str, Path]]: + """迭代当前所有已注册插件的实际目录路径。""" for supervisor in self.supervisors: for plugin_id in getattr(supervisor, "_registered_plugins", {}).keys(): - if config_path := self._get_plugin_config_path_for_supervisor(supervisor, plugin_id): - yield plugin_id, config_path + if plugin_path := self._get_plugin_path_for_supervisor(supervisor, plugin_id): + yield plugin_id, plugin_path def _get_plugin_config_path_for_supervisor(self, supervisor: Any, plugin_id: str) -> Optional[Path]: """从指定 Supervisor 的插件目录中定位某个插件的 config.toml。""" - for plugin_dir in getattr(supervisor, "_plugin_dirs", []): - plugin_dir = Path(plugin_dir) - plugin_path = plugin_dir.resolve() / plugin_id - if plugin_path.is_dir(): - return plugin_path / "config.toml" - return None + plugin_path = self._get_plugin_path_for_supervisor(supervisor, plugin_id) + return None if plugin_path is None else plugin_path / "config.toml" async def _handle_plugin_config_changes(self, plugin_id: str, changes: Sequence[FileChange]) -> None: """处理单个插件配置文件变化,并仅向目标插件推送配置更新。""" @@ -542,7 +628,7 @@ class PluginRuntimeManager( try: await supervisor.notify_plugin_config_updated( plugin_id=plugin_id, - config_data=self._load_plugin_config_for_supervisor(plugin_id, getattr(supervisor, "_plugin_dirs", [])), + config_data=self._load_plugin_config_for_supervisor(supervisor, plugin_id), ) except Exception as exc: logger.warning(f"插件 {plugin_id} 配置热更新通知失败: {exc}") @@ -591,32 +677,38 @@ class PluginRuntimeManager( def _match_plugin_id_for_supervisor(self, supervisor: Any, path: Path) -> Optional[str]: """根据变更路径为指定 Supervisor 推断受影响的插件 ID。""" - for plugin_id, _reg in getattr(supervisor, "_registered_plugins", {}).items(): - for plugin_dir in getattr(supervisor, "_plugin_dirs", []): - plugin_dir = Path(plugin_dir) - candidate_dir = plugin_dir.resolve() / plugin_id - if path == candidate_dir or path.is_relative_to(candidate_dir): - return plugin_id + resolved_path = path.resolve() + + for plugin_id in getattr(supervisor, "_registered_plugins", {}).keys(): + plugin_path = self._get_plugin_path_for_supervisor(supervisor, plugin_id) + if plugin_path is not None and (resolved_path == plugin_path or resolved_path.is_relative_to(plugin_path)): + return plugin_id + + for plugin_id, plugin_path in self._plugin_path_cache.items(): + if not any(self._plugin_dir_matches(plugin_path, Path(plugin_dir)) for plugin_dir in getattr(supervisor, "_plugin_dirs", [])): + continue + if resolved_path == plugin_path or resolved_path.is_relative_to(plugin_path): + return plugin_id + + for plugin_id, plugin_path in self._iter_discovered_plugin_paths(getattr(supervisor, "_plugin_dirs", [])): + if resolved_path == plugin_path or resolved_path.is_relative_to(plugin_path): + self._plugin_path_cache[plugin_id] = plugin_path + return plugin_id - for plugin_dir in getattr(supervisor, "_plugin_dirs", []): - plugin_dir = Path(plugin_dir) - plugin_root = plugin_dir.resolve() - if self._plugin_dir_matches(path, plugin_dir) and (relative_parts := path.relative_to(plugin_root).parts): - return relative_parts[0] return None - @staticmethod - def _load_plugin_config_for_supervisor(plugin_id: str, plugin_dirs: Iterable[Path]) -> Dict[str, Any]: + def _load_plugin_config_for_supervisor(self, supervisor: Any, plugin_id: str) -> Dict[str, Any]: """从给定插件目录集合中读取目标插件的配置内容。""" - for plugin_dir in plugin_dirs: - plugin_path = plugin_dir.resolve() / plugin_id - if plugin_path.is_dir(): - config_path = plugin_path / "config.toml" - if not config_path.exists(): - return {} - with open(config_path, "r", encoding="utf-8") as handle: - return tomlkit.load(handle).unwrap() - return {} + plugin_path = self._get_plugin_path_for_supervisor(supervisor, plugin_id) + if plugin_path is None: + return {} + + config_path = plugin_path / "config.toml" + if not config_path.exists(): + return {} + + with open(config_path, "r", encoding="utf-8") as handle: + return tomlkit.load(handle).unwrap() # ─── 能力实现注册 ────────────────────────────────────────── diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py index 0dfc6656..d71e02c5 100644 --- a/src/plugin_runtime/protocol/envelope.py +++ b/src/plugin_runtime/protocol/envelope.py @@ -17,7 +17,7 @@ from pydantic import BaseModel, Field PROTOCOL_VERSION = "1.0.0" # 支持的 SDK 版本范围(Host 在握手时校验) MIN_SDK_VERSION = "1.0.0" -MAX_SDK_VERSION = "1.99.99" +MAX_SDK_VERSION = "2.99.99" # ====== 消息类型 ====== diff --git a/src/plugins/built_in/napcat_adapter/plugin.py b/src/plugins/built_in/napcat_adapter/plugin.py index 3eff518d..a481101f 100644 --- a/src/plugins/built_in/napcat_adapter/plugin.py +++ b/src/plugins/built_in/napcat_adapter/plugin.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast +from typing import Any, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, cast from uuid import uuid4 import asyncio @@ -42,6 +42,13 @@ if not TYPE_CHECKING: AiohttpClientWebSocketResponse = Any +SUPPORTED_CONFIG_VERSION = "0.1.0" +DEFAULT_RECONNECT_DELAY_SEC = 5.0 +DEFAULT_HEARTBEAT_SEC = 30.0 +DEFAULT_ACTION_TIMEOUT_SEC = 15.0 +DEFAULT_CHAT_LIST_TYPE = "whitelist" + + @Adapter(platform="qq", protocol="napcat", send_method="send_to_platform") class NapCatAdapterPlugin(MaiBotPlugin): """NapCat 适配器 MVP 实现。""" @@ -52,7 +59,7 @@ class NapCatAdapterPlugin(MaiBotPlugin): self._plugin_config: Dict[str, Any] = {} self._connection_task: Optional[asyncio.Task[None]] = None self._pending_actions: Dict[str, asyncio.Future[Dict[str, Any]]] = {} - self._background_tasks: set[asyncio.Task[Any]] = set() + self._background_tasks: Set[asyncio.Task[Any]] = set() self._send_lock = asyncio.Lock() self._ws: Optional[AiohttpClientWebSocketResponse] = None @@ -80,8 +87,9 @@ class NapCatAdapterPlugin(MaiBotPlugin): new_config: 最新的插件配置。 version: 配置版本号。 """ - del version self.set_plugin_config(new_config) + if version: + self.ctx.logger.debug(f"NapCat 适配器收到配置更新通知: {version}") await self._restart_connection_if_needed() async def send_to_platform( @@ -139,6 +147,8 @@ class NapCatAdapterPlugin(MaiBotPlugin): if not self._should_connect(): self.ctx.logger.info("NapCat 适配器保持空闲状态,因为插件或配置未启用") return + if not self._validate_current_config(): + return if not AIOHTTP_AVAILABLE: self.ctx.logger.error("NapCat 适配器依赖 aiohttp,但当前环境未安装该依赖") return @@ -185,7 +195,7 @@ class NapCatAdapterPlugin(MaiBotPlugin): headers = self._build_headers() timeout = ClientTimeout(total=None, connect=10) - heartbeat = self._get_positive_float(self._connection_config(), "heartbeat_sec", 30.0) + heartbeat = self._get_positive_float(self._connection_config(), "heartbeat_sec", DEFAULT_HEARTBEAT_SEC) try: async with ClientSession(headers=headers, timeout=timeout) as session: @@ -204,7 +214,13 @@ class NapCatAdapterPlugin(MaiBotPlugin): if not self._should_connect(): break - await asyncio.sleep(self._get_positive_float(self._connection_config(), "reconnect_delay_sec", 5.0)) + await asyncio.sleep( + self._get_positive_float( + self._connection_config(), + "reconnect_delay_sec", + DEFAULT_RECONNECT_DELAY_SEC, + ) + ) async def _receive_loop(self, ws: AiohttpClientWebSocketResponse) -> None: """持续消费 WebSocket 消息并分发处理。 @@ -250,8 +266,11 @@ class NapCatAdapterPlugin(MaiBotPlugin): if not sender_user_id: return + group_id = str(payload.get("group_id") or "").strip() if self_id and sender_user_id == self_id and self._get_bool(self._filters_config(), "ignore_self_message", True): return + if not self._is_inbound_chat_allowed(sender_user_id, group_id): + return message_dict = self._build_inbound_message_dict(payload, self_id, sender_user_id, sender) route_metadata: Dict[str, Any] = {} @@ -339,7 +358,7 @@ class NapCatAdapterPlugin(MaiBotPlugin): "display_message": plain_text, } - def _convert_inbound_segments(self, message_payload: Any, self_id: str) -> tuple[List[Dict[str, Any]], bool]: + def _convert_inbound_segments(self, message_payload: Any, self_id: str) -> Tuple[List[Dict[str, Any]], bool]: """将 OneBot 消息段转换为 Host 消息段结构。 Args: @@ -347,7 +366,7 @@ class NapCatAdapterPlugin(MaiBotPlugin): self_id: 当前机器人账号 ID。 Returns: - tuple[List[Dict[str, Any]], bool]: 转换后的消息段列表,以及是否 @ 到当前机器人。 + Tuple[List[Dict[str, Any]], bool]: 转换后的消息段列表,以及是否 @ 到当前机器人。 """ if isinstance(message_payload, str): normalized_text = message_payload.strip() @@ -412,7 +431,7 @@ class NapCatAdapterPlugin(MaiBotPlugin): self, message: Dict[str, Any], route: Dict[str, Any], - ) -> tuple[str, Dict[str, Any]]: + ) -> Tuple[str, Dict[str, Any]]: """为 Host 出站消息构造 OneBot 动作。 Args: @@ -420,7 +439,7 @@ class NapCatAdapterPlugin(MaiBotPlugin): route: Platform IO 路由信息。 Returns: - tuple[str, Dict[str, Any]]: 动作名称与参数字典。 + Tuple[str, Dict[str, Any]]: 动作名称与参数字典。 """ message_info = message.get("message_info", {}) if not isinstance(message_info, dict): @@ -519,7 +538,11 @@ class NapCatAdapterPlugin(MaiBotPlugin): try: async with self._send_lock: await ws.send_str(json.dumps(request_payload, ensure_ascii=False)) - timeout_seconds = self._get_positive_float(self._connection_config(), "action_timeout_sec", 15.0) + timeout_seconds = self._get_positive_float( + self._connection_config(), + "action_timeout_sec", + DEFAULT_ACTION_TIMEOUT_SEC, + ) return await asyncio.wait_for(response_future, timeout=timeout_seconds) finally: self._pending_actions.pop(echo_id, None) @@ -626,6 +649,173 @@ class NapCatAdapterPlugin(MaiBotPlugin): filters_config = self._plugin_config.get("filters", {}) return filters_config if isinstance(filters_config, dict) else {} + def _chat_config(self) -> Dict[str, Any]: + """读取插件配置中的 ``chat`` 段。 + + Returns: + Dict[str, Any]: ``chat`` 配置字典。 + """ + chat_config = self._plugin_config.get("chat", {}) + return chat_config if isinstance(chat_config, dict) else {} + + def _is_inbound_chat_allowed(self, sender_user_id: str, group_id: str) -> bool: + """检查入站消息是否通过聊天名单过滤。 + + Args: + sender_user_id: 发送者用户 ID。 + group_id: 群聊 ID;私聊时为空字符串。 + + Returns: + bool: 若消息允许继续进入 Host,则返回 ``True``。 + """ + chat_config = self._chat_config() + banned_user_ids = self._get_string_list(chat_config, "ban_user_id") + if sender_user_id in banned_user_ids: + self.ctx.logger.warning(f"NapCat 用户 {sender_user_id} 在全局禁止名单中,消息被丢弃") + return False + + if group_id: + group_list_type = self._get_list_mode(chat_config, "group_list_type", DEFAULT_CHAT_LIST_TYPE) + group_id_list = self._get_string_list(chat_config, "group_list") + if not self._is_id_allowed_by_list_policy(group_id, group_list_type, group_id_list): + self.ctx.logger.warning(f"NapCat 群聊 {group_id} 未通过聊天名单过滤,消息被丢弃") + return False + return True + + private_list_type = self._get_list_mode(chat_config, "private_list_type", DEFAULT_CHAT_LIST_TYPE) + private_id_list = self._get_string_list(chat_config, "private_list") + if not self._is_id_allowed_by_list_policy(sender_user_id, private_list_type, private_id_list): + self.ctx.logger.warning(f"NapCat 私聊用户 {sender_user_id} 未通过聊天名单过滤,消息被丢弃") + return False + return True + + def _is_id_allowed_by_list_policy( + self, + target_id: str, + list_type: str, + configured_ids: Set[str], + ) -> bool: + """根据白名单或黑名单规则判断目标 ID 是否允许通过。 + + Args: + target_id: 待检查的目标 ID。 + list_type: 名单模式,仅支持 ``whitelist`` 或 ``blacklist``。 + configured_ids: 配置中的 ID 集合。 + + Returns: + bool: 若目标 ID 允许通过,则返回 ``True``。 + """ + if list_type == "whitelist": + return target_id in configured_ids + return target_id not in configured_ids + + def _validate_current_config(self) -> bool: + """校验当前配置是否满足启动连接的前提条件。 + + Returns: + bool: 配置可用于启动连接时返回 ``True``。 + """ + if not self._validate_plugin_config_version(): + return False + + connection_config = self._connection_config() + ws_url = self._get_string(connection_config, "ws_url") + if not ws_url: + self.ctx.logger.warning("NapCat 适配器已启用,但 connection.ws_url 为空") + return False + + self._validate_positive_float_setting( + connection_config, + "connection", + "reconnect_delay_sec", + DEFAULT_RECONNECT_DELAY_SEC, + ) + self._validate_positive_float_setting( + connection_config, + "connection", + "heartbeat_sec", + DEFAULT_HEARTBEAT_SEC, + ) + self._validate_positive_float_setting( + connection_config, + "connection", + "action_timeout_sec", + DEFAULT_ACTION_TIMEOUT_SEC, + ) + self._validate_list_mode_setting(self._chat_config(), "chat", "group_list_type", DEFAULT_CHAT_LIST_TYPE) + self._validate_list_mode_setting(self._chat_config(), "chat", "private_list_type", DEFAULT_CHAT_LIST_TYPE) + return True + + def _validate_plugin_config_version(self) -> bool: + """校验插件配置版本是否与当前实现兼容。 + + Returns: + bool: 版本兼容时返回 ``True``。 + """ + config_version = self._get_string(self._plugin_section(), "config_version") + if not config_version: + self.ctx.logger.error( + f"NapCat 适配器配置缺少 plugin.config_version,当前插件要求版本 {SUPPORTED_CONFIG_VERSION}" + ) + return False + + if config_version != SUPPORTED_CONFIG_VERSION: + self.ctx.logger.error( + "NapCat 适配器配置版本不兼容: " + f"当前为 {config_version},当前插件要求 {SUPPORTED_CONFIG_VERSION}" + ) + return False + + return True + + def _validate_positive_float_setting( + self, + mapping: Dict[str, Any], + section_name: str, + key: str, + default: float, + ) -> None: + """校验正浮点数配置项,并在非法时输出告警日志。 + + Args: + mapping: 待读取的配置字典。 + section_name: 当前配置段名称。 + key: 目标配置键名。 + default: 配置非法时实际使用的默认值。 + """ + value = mapping.get(key, default) + if isinstance(value, (int, float)) and float(value) > 0: + return + + self.ctx.logger.warning( + "NapCat 适配器配置项取值无效,已回退到默认值: " + f"{section_name}.{key}={value!r},默认值为 {default}" + ) + + def _validate_list_mode_setting( + self, + mapping: Dict[str, Any], + section_name: str, + key: str, + default: str, + ) -> None: + """校验名单模式配置项,并在非法时输出告警日志。 + + Args: + mapping: 待读取的配置字典。 + section_name: 当前配置段名称。 + key: 目标配置键名。 + default: 配置非法时实际使用的默认值。 + """ + value = mapping.get(key, default) + if isinstance(value, str) and value.strip() in {"whitelist", "blacklist"}: + return + + self.ctx.logger.warning( + "NapCat 适配器配置项取值无效,已回退到默认值: " + f"{section_name}.{key}={value!r},默认值为 {default}" + ) + def _should_connect(self) -> bool: """判断当前配置下是否应当启动连接。 @@ -680,6 +870,47 @@ class NapCatAdapterPlugin(MaiBotPlugin): value = mapping.get(key) return "" if value is None else str(value).strip() + @staticmethod + def _get_list_mode(mapping: Dict[str, Any], key: str, default: str) -> str: + """安全读取名单模式配置值。 + + Args: + mapping: 待读取的配置字典。 + key: 目标键名。 + default: 读取失败时的默认值。 + + Returns: + str: 合法的名单模式字符串。 + """ + value = mapping.get(key, default) + if isinstance(value, str): + normalized_value = value.strip() + if normalized_value in {"whitelist", "blacklist"}: + return normalized_value + return default + + @staticmethod + def _get_string_list(mapping: Dict[str, Any], key: str) -> Set[str]: + """安全读取 ID 列表配置值。 + + Args: + mapping: 待读取的配置字典。 + key: 目标键名。 + + Returns: + Set[str]: 去重后的字符串 ID 集合。 + """ + value = mapping.get(key, []) + if not isinstance(value, list): + return set() + + normalized_values: Set[str] = set() + for item in value: + item_text = "" if item is None else str(item).strip() + if item_text: + normalized_values.add(item_text) + return normalized_values + def create_plugin() -> NapCatAdapterPlugin: """创建插件实例。