refactor: 更新插件和 RPC 服务器逻辑,增强握手状态管理与配置校验
This commit is contained in:
@@ -19,7 +19,7 @@ dependencies = [
|
||||
"jieba>=0.42.1",
|
||||
"json-repair>=0.47.6",
|
||||
"maim-message>=0.6.2",
|
||||
"maibot-plugin-sdk>=1.2.3,<2.0.0",
|
||||
"maibot-plugin-sdk>=2.0.0",
|
||||
"msgpack>=1.1.2",
|
||||
"numpy>=2.2.6",
|
||||
"openai>=1.95.0",
|
||||
|
||||
@@ -69,6 +69,7 @@ class RPCServer:
|
||||
# 运行状态
|
||||
self._running: bool = False
|
||||
self._tasks: List[asyncio.Task[None]] = []
|
||||
self._last_handshake_rejection_reason: str = ""
|
||||
|
||||
@property
|
||||
def session_token(self) -> str:
|
||||
@@ -78,6 +79,15 @@ class RPCServer:
|
||||
def is_connected(self) -> bool:
|
||||
return self._connection is not None and not self._connection.is_closed
|
||||
|
||||
@property
|
||||
def last_handshake_rejection_reason(self) -> str:
|
||||
"""返回最近一次握手被拒绝的原因。"""
|
||||
return self._last_handshake_rejection_reason
|
||||
|
||||
def clear_handshake_state(self) -> None:
|
||||
"""清空最近一次握手拒绝状态。"""
|
||||
self._last_handshake_rejection_reason = ""
|
||||
|
||||
def register_method(self, method: str, handler: MethodHandler) -> None:
|
||||
"""注册 RPC 方法处理器"""
|
||||
self._method_handlers[method] = handler
|
||||
@@ -85,6 +95,7 @@ class RPCServer:
|
||||
async def start(self) -> None:
|
||||
"""启动 RPC 服务器"""
|
||||
self._running = True
|
||||
self.clear_handshake_state()
|
||||
self._send_queue = asyncio.Queue(maxsize=self._send_queue_size)
|
||||
self._send_worker_task = asyncio.create_task(self._send_loop())
|
||||
await self._transport.start(self._handle_connection)
|
||||
@@ -93,6 +104,7 @@ class RPCServer:
|
||||
async def stop(self) -> None:
|
||||
"""停止 RPC 服务器"""
|
||||
self._running = False
|
||||
self.clear_handshake_state()
|
||||
self._fail_pending_requests(ErrorCode.E_SHUTTING_DOWN, "服务器正在关闭")
|
||||
self._fail_queued_sends(ErrorCode.E_SHUTTING_DOWN, "服务器正在关闭")
|
||||
|
||||
@@ -204,6 +216,7 @@ class RPCServer:
|
||||
async def _handle_connection(self, conn: Connection) -> None:
|
||||
"""处理新的 Runner 连接"""
|
||||
logger.info("收到 Runner 连接")
|
||||
self.clear_handshake_state()
|
||||
# 第一条消息必须是 runner.hello 握手
|
||||
try:
|
||||
success = await self._handle_handshake(conn)
|
||||
@@ -232,6 +245,7 @@ class RPCServer:
|
||||
envelope = self._codec.decode_envelope(data)
|
||||
if envelope.method != "runner.hello":
|
||||
logger.error(f"期望 runner.hello,收到 {envelope.method}")
|
||||
self._last_handshake_rejection_reason = "首条消息必须为 runner.hello"
|
||||
error_resp = envelope.make_error_response(
|
||||
ErrorCode.E_PROTOCOL_MISMATCH.value,
|
||||
"首条消息必须为 runner.hello",
|
||||
@@ -244,7 +258,8 @@ class RPCServer:
|
||||
# 校验会话令牌
|
||||
if hello.session_token != self._session_token:
|
||||
logger.error("会话令牌不匹配")
|
||||
resp_payload = HelloResponsePayload(accepted=False, reason="会话令牌无效")
|
||||
self._last_handshake_rejection_reason = "会话令牌无效"
|
||||
resp_payload = HelloResponsePayload(accepted=False, reason=self._last_handshake_rejection_reason)
|
||||
resp = envelope.make_response(payload=resp_payload.model_dump())
|
||||
await conn.send_frame(self._codec.encode_envelope(resp))
|
||||
return False
|
||||
@@ -252,15 +267,19 @@ class RPCServer:
|
||||
# 校验 SDK 版本
|
||||
if not self._check_sdk_version(hello.sdk_version):
|
||||
logger.error(f"SDK 版本不兼容: {hello.sdk_version}")
|
||||
self._last_handshake_rejection_reason = (
|
||||
f"SDK 版本 {hello.sdk_version} 不在支持范围 [{MIN_SDK_VERSION}, {MAX_SDK_VERSION}]"
|
||||
)
|
||||
resp_payload = HelloResponsePayload(
|
||||
accepted=False,
|
||||
reason=f"SDK 版本 {hello.sdk_version} 不在支持范围 [{MIN_SDK_VERSION}, {MAX_SDK_VERSION}]",
|
||||
reason=self._last_handshake_rejection_reason,
|
||||
)
|
||||
resp = envelope.make_response(payload=resp_payload.model_dump())
|
||||
await conn.send_frame(self._codec.encode_envelope(resp))
|
||||
return False
|
||||
|
||||
# 发送响应
|
||||
self.clear_handshake_state()
|
||||
resp_payload = HelloResponsePayload(accepted=True, host_version=PROTOCOL_VERSION)
|
||||
resp = envelope.make_response(payload=resp_payload.model_dump())
|
||||
await conn.send_frame(self._codec.encode_envelope(resp))
|
||||
|
||||
@@ -202,17 +202,24 @@ class PluginRunnerSupervisor:
|
||||
self._restart_count = 0
|
||||
self._clear_runner_state()
|
||||
|
||||
await self._rpc_server.start()
|
||||
await self._spawn_runner()
|
||||
|
||||
try:
|
||||
await self._wait_for_runner_connection(timeout_sec=self._runner_spawn_timeout)
|
||||
await self._wait_for_runner_ready(timeout_sec=self._runner_spawn_timeout)
|
||||
except TimeoutError:
|
||||
if not self._rpc_server.is_connected:
|
||||
logger.warning("Runner 未在限定时间内完成连接,后续操作可能失败")
|
||||
else:
|
||||
logger.warning("Runner 未在限定时间内完成初始化,后续操作可能失败")
|
||||
await self._rpc_server.start()
|
||||
await self._spawn_runner()
|
||||
|
||||
try:
|
||||
await self._wait_for_runner_connection(timeout_sec=self._runner_spawn_timeout)
|
||||
await self._wait_for_runner_ready(timeout_sec=self._runner_spawn_timeout)
|
||||
except TimeoutError:
|
||||
if not self._rpc_server.is_connected:
|
||||
logger.warning("Runner 未在限定时间内完成连接,后续操作可能失败")
|
||||
else:
|
||||
logger.warning("Runner 未在限定时间内完成初始化,后续操作可能失败")
|
||||
except Exception:
|
||||
await self._shutdown_runner(reason="startup_failed")
|
||||
await self._rpc_server.stop()
|
||||
self._clear_runner_state()
|
||||
self._running = False
|
||||
raise
|
||||
|
||||
self._health_task = asyncio.create_task(self._health_check_loop(), name="PluginRunnerSupervisor.health")
|
||||
logger.info("PluginRunnerSupervisor 已启动")
|
||||
@@ -387,7 +394,16 @@ class PluginRunnerSupervisor:
|
||||
|
||||
async def wait_for_connection() -> None:
|
||||
"""轮询等待 RPC 连接建立。"""
|
||||
while self._running and not self._rpc_server.is_connected:
|
||||
while True:
|
||||
if self._rpc_server.is_connected:
|
||||
return
|
||||
|
||||
if not self._running:
|
||||
raise RuntimeError("Supervisor 已停止,等待 Runner 连接已取消")
|
||||
|
||||
if failure_reason := self._get_runner_startup_failure_reason():
|
||||
raise RuntimeError(f"等待 Runner 连接失败: {failure_reason}")
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
try:
|
||||
@@ -408,10 +424,27 @@ class PluginRunnerSupervisor:
|
||||
Raises:
|
||||
TimeoutError: 在超时时间内 Runner 未完成初始化。
|
||||
"""
|
||||
async def wait_for_ready() -> RunnerReadyPayload:
|
||||
"""轮询等待 Runner 上报就绪。"""
|
||||
while True:
|
||||
if self._runner_ready_events.is_set():
|
||||
return self._runner_ready_payloads
|
||||
|
||||
if not self._running:
|
||||
raise RuntimeError("Supervisor 已停止,等待 Runner 就绪已取消")
|
||||
|
||||
if failure_reason := self._get_runner_startup_failure_reason():
|
||||
raise RuntimeError(f"等待 Runner 就绪失败: {failure_reason}")
|
||||
|
||||
if not self._rpc_server.is_connected:
|
||||
raise RuntimeError("等待 Runner 就绪失败: Runner RPC 连接已断开")
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(self._runner_ready_events.wait(), timeout=timeout_sec)
|
||||
payload = await asyncio.wait_for(wait_for_ready(), timeout=timeout_sec)
|
||||
logger.info("Runner 已完成初始化并上报就绪")
|
||||
return self._runner_ready_payloads
|
||||
return payload
|
||||
except asyncio.TimeoutError as exc:
|
||||
raise TimeoutError(f"等待 Runner 就绪超时({timeout_sec}s)") from exc
|
||||
|
||||
@@ -923,6 +956,7 @@ class PluginRunnerSupervisor:
|
||||
await self._wait_for_runner_connection(timeout_sec=self._runner_spawn_timeout)
|
||||
await self._wait_for_runner_ready(timeout_sec=self._runner_spawn_timeout)
|
||||
except Exception as exc:
|
||||
await self._shutdown_runner(reason="restart_failed")
|
||||
logger.error(f"Runner 重启失败: {exc}", exc_info=True)
|
||||
return False
|
||||
|
||||
@@ -938,6 +972,25 @@ class PluginRunnerSupervisor:
|
||||
self._registered_adapters.clear()
|
||||
self._runner_ready_events = asyncio.Event()
|
||||
self._runner_ready_payloads = RunnerReadyPayload()
|
||||
self._rpc_server.clear_handshake_state()
|
||||
|
||||
def _get_runner_startup_failure_reason(self) -> Optional[str]:
|
||||
"""获取 Runner 在启动阶段已经暴露出的失败原因。
|
||||
|
||||
Returns:
|
||||
Optional[str]: 若已检测到失败则返回失败原因,否则返回 ``None``。
|
||||
"""
|
||||
if handshake_reason := self._rpc_server.last_handshake_rejection_reason:
|
||||
return f"握手被拒绝: {handshake_reason}"
|
||||
|
||||
process = self._runner_process
|
||||
if process is None:
|
||||
return "Runner 进程不存在"
|
||||
|
||||
if process.returncode is not None:
|
||||
return f"Runner 进程已退出,退出码 {process.returncode}"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
PluginSupervisor = PluginRunnerSupervisor
|
||||
|
||||
@@ -68,6 +68,7 @@ class PluginRuntimeManager(
|
||||
self._plugin_file_watcher: Optional[FileWatcher] = None
|
||||
self._plugin_source_watcher_subscription_id: Optional[str] = None
|
||||
self._plugin_config_watcher_subscriptions: Dict[str, Tuple[Path, str]] = {}
|
||||
self._plugin_path_cache: Dict[str, Path] = {}
|
||||
|
||||
async def _dispatch_platform_inbound(self, envelope: InboundMessageEnvelope) -> None:
|
||||
"""接收 Platform IO 审核后的入站消息并送入主消息链。
|
||||
@@ -215,6 +216,7 @@ class PluginRuntimeManager(
|
||||
self._started = False
|
||||
self._builtin_supervisor = None
|
||||
self._third_party_supervisor = None
|
||||
self._plugin_path_cache.clear()
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
@@ -254,7 +256,7 @@ class PluginRuntimeManager(
|
||||
config_payload = (
|
||||
config_data
|
||||
if config_data is not None
|
||||
else self._load_plugin_config_for_supervisor(plugin_id, plugin_dirs=sv._plugin_dirs)
|
||||
else self._load_plugin_config_for_supervisor(sv, plugin_id)
|
||||
)
|
||||
await sv.notify_plugin_config_updated(
|
||||
plugin_id=plugin_id,
|
||||
@@ -452,6 +454,7 @@ class PluginRuntimeManager(
|
||||
async def _stop_plugin_file_watcher(self) -> None:
|
||||
"""停止插件文件监视器,并清理所有已注册订阅。"""
|
||||
if self._plugin_file_watcher is None:
|
||||
self._plugin_path_cache.clear()
|
||||
return
|
||||
for _plugin_id, (_config_path, subscription_id) in list(self._plugin_config_watcher_subscriptions.items()):
|
||||
self._plugin_file_watcher.unsubscribe(subscription_id)
|
||||
@@ -461,12 +464,95 @@ class PluginRuntimeManager(
|
||||
self._plugin_source_watcher_subscription_id = None
|
||||
await self._plugin_file_watcher.stop()
|
||||
self._plugin_file_watcher = None
|
||||
self._plugin_path_cache.clear()
|
||||
|
||||
def _iter_plugin_dirs(self) -> Iterable[Path]:
|
||||
"""迭代所有 Supervisor 当前管理的插件根目录。"""
|
||||
for supervisor in self.supervisors:
|
||||
yield from getattr(supervisor, "_plugin_dirs", [])
|
||||
|
||||
@staticmethod
|
||||
def _iter_candidate_plugin_paths(plugin_dirs: Iterable[Path]) -> Iterable[Path]:
|
||||
"""迭代所有可能的插件目录路径。
|
||||
|
||||
Args:
|
||||
plugin_dirs: 一个或多个插件根目录。
|
||||
|
||||
Yields:
|
||||
Path: 单个插件目录路径。
|
||||
"""
|
||||
for plugin_dir in plugin_dirs:
|
||||
plugin_root = Path(plugin_dir).resolve()
|
||||
if not plugin_root.is_dir():
|
||||
continue
|
||||
for entry in plugin_root.iterdir():
|
||||
if entry.is_dir():
|
||||
yield entry.resolve()
|
||||
|
||||
@staticmethod
|
||||
def _read_plugin_id_from_plugin_path(plugin_path: Path) -> Optional[str]:
|
||||
"""从单个插件目录中读取 manifest 声明的插件 ID。
|
||||
|
||||
Args:
|
||||
plugin_path: 单个插件目录路径。
|
||||
|
||||
Returns:
|
||||
Optional[str]: 解析成功时返回插件 ID,否则返回 ``None``。
|
||||
"""
|
||||
manifest_path = plugin_path / "_manifest.json"
|
||||
entrypoint_path = plugin_path / "plugin.py"
|
||||
if not manifest_path.is_file() or not entrypoint_path.is_file():
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(manifest_path, "r", encoding="utf-8") as manifest_file:
|
||||
manifest = json.load(manifest_file)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
if not isinstance(manifest, dict):
|
||||
return None
|
||||
|
||||
plugin_id = str(manifest.get("name", plugin_path.name)).strip() or plugin_path.name
|
||||
return plugin_id or None
|
||||
|
||||
def _iter_discovered_plugin_paths(self, plugin_dirs: Iterable[Path]) -> Iterable[Tuple[str, Path]]:
|
||||
"""迭代目录中可解析到的插件 ID 与实际目录路径。
|
||||
|
||||
Args:
|
||||
plugin_dirs: 一个或多个插件根目录。
|
||||
|
||||
Yields:
|
||||
Tuple[str, Path]: ``(plugin_id, plugin_path)`` 二元组。
|
||||
"""
|
||||
for plugin_path in self._iter_candidate_plugin_paths(plugin_dirs):
|
||||
if plugin_id := self._read_plugin_id_from_plugin_path(plugin_path):
|
||||
yield plugin_id, plugin_path
|
||||
|
||||
def _get_plugin_path_for_supervisor(self, supervisor: Any, plugin_id: str) -> Optional[Path]:
|
||||
"""为指定 Supervisor 定位某个插件的实际目录。
|
||||
|
||||
Args:
|
||||
supervisor: 目标 Supervisor。
|
||||
plugin_id: 插件 ID。
|
||||
|
||||
Returns:
|
||||
Optional[Path]: 插件目录路径;未找到时返回 ``None``。
|
||||
"""
|
||||
cached_path = self._plugin_path_cache.get(plugin_id)
|
||||
if cached_path is not None:
|
||||
for plugin_dir in getattr(supervisor, "_plugin_dirs", []):
|
||||
if self._plugin_dir_matches(cached_path, Path(plugin_dir)):
|
||||
return cached_path
|
||||
|
||||
for candidate_plugin_id, plugin_path in self._iter_discovered_plugin_paths(getattr(supervisor, "_plugin_dirs", [])):
|
||||
if candidate_plugin_id != plugin_id:
|
||||
continue
|
||||
self._plugin_path_cache[plugin_id] = plugin_path
|
||||
return plugin_path
|
||||
|
||||
return None
|
||||
|
||||
def _refresh_plugin_config_watch_subscriptions(self) -> None:
|
||||
"""按当前已注册插件集合刷新 config.toml 的单插件订阅。
|
||||
|
||||
@@ -476,7 +562,11 @@ class PluginRuntimeManager(
|
||||
if self._plugin_file_watcher is None:
|
||||
return
|
||||
|
||||
desired_config_paths = dict(self._iter_registered_plugin_config_paths())
|
||||
desired_plugin_paths = dict(self._iter_registered_plugin_paths())
|
||||
self._plugin_path_cache = desired_plugin_paths.copy()
|
||||
desired_config_paths = {
|
||||
plugin_id: plugin_path / "config.toml" for plugin_id, plugin_path in desired_plugin_paths.items()
|
||||
}
|
||||
|
||||
for plugin_id, (_old_path, subscription_id) in list(self._plugin_config_watcher_subscriptions.items()):
|
||||
if desired_config_paths.get(plugin_id) == self._plugin_config_watcher_subscriptions[plugin_id][0]:
|
||||
@@ -509,21 +599,17 @@ class PluginRuntimeManager(
|
||||
|
||||
return _callback
|
||||
|
||||
def _iter_registered_plugin_config_paths(self) -> Iterable[Tuple[str, Path]]:
|
||||
"""迭代当前所有已注册插件的 config.toml 路径。"""
|
||||
def _iter_registered_plugin_paths(self) -> Iterable[Tuple[str, Path]]:
|
||||
"""迭代当前所有已注册插件的实际目录路径。"""
|
||||
for supervisor in self.supervisors:
|
||||
for plugin_id in getattr(supervisor, "_registered_plugins", {}).keys():
|
||||
if config_path := self._get_plugin_config_path_for_supervisor(supervisor, plugin_id):
|
||||
yield plugin_id, config_path
|
||||
if plugin_path := self._get_plugin_path_for_supervisor(supervisor, plugin_id):
|
||||
yield plugin_id, plugin_path
|
||||
|
||||
def _get_plugin_config_path_for_supervisor(self, supervisor: Any, plugin_id: str) -> Optional[Path]:
|
||||
"""从指定 Supervisor 的插件目录中定位某个插件的 config.toml。"""
|
||||
for plugin_dir in getattr(supervisor, "_plugin_dirs", []):
|
||||
plugin_dir = Path(plugin_dir)
|
||||
plugin_path = plugin_dir.resolve() / plugin_id
|
||||
if plugin_path.is_dir():
|
||||
return plugin_path / "config.toml"
|
||||
return None
|
||||
plugin_path = self._get_plugin_path_for_supervisor(supervisor, plugin_id)
|
||||
return None if plugin_path is None else plugin_path / "config.toml"
|
||||
|
||||
async def _handle_plugin_config_changes(self, plugin_id: str, changes: Sequence[FileChange]) -> None:
|
||||
"""处理单个插件配置文件变化,并仅向目标插件推送配置更新。"""
|
||||
@@ -542,7 +628,7 @@ class PluginRuntimeManager(
|
||||
try:
|
||||
await supervisor.notify_plugin_config_updated(
|
||||
plugin_id=plugin_id,
|
||||
config_data=self._load_plugin_config_for_supervisor(plugin_id, getattr(supervisor, "_plugin_dirs", [])),
|
||||
config_data=self._load_plugin_config_for_supervisor(supervisor, plugin_id),
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(f"插件 {plugin_id} 配置热更新通知失败: {exc}")
|
||||
@@ -591,32 +677,38 @@ class PluginRuntimeManager(
|
||||
|
||||
def _match_plugin_id_for_supervisor(self, supervisor: Any, path: Path) -> Optional[str]:
|
||||
"""根据变更路径为指定 Supervisor 推断受影响的插件 ID。"""
|
||||
for plugin_id, _reg in getattr(supervisor, "_registered_plugins", {}).items():
|
||||
for plugin_dir in getattr(supervisor, "_plugin_dirs", []):
|
||||
plugin_dir = Path(plugin_dir)
|
||||
candidate_dir = plugin_dir.resolve() / plugin_id
|
||||
if path == candidate_dir or path.is_relative_to(candidate_dir):
|
||||
return plugin_id
|
||||
resolved_path = path.resolve()
|
||||
|
||||
for plugin_id in getattr(supervisor, "_registered_plugins", {}).keys():
|
||||
plugin_path = self._get_plugin_path_for_supervisor(supervisor, plugin_id)
|
||||
if plugin_path is not None and (resolved_path == plugin_path or resolved_path.is_relative_to(plugin_path)):
|
||||
return plugin_id
|
||||
|
||||
for plugin_id, plugin_path in self._plugin_path_cache.items():
|
||||
if not any(self._plugin_dir_matches(plugin_path, Path(plugin_dir)) for plugin_dir in getattr(supervisor, "_plugin_dirs", [])):
|
||||
continue
|
||||
if resolved_path == plugin_path or resolved_path.is_relative_to(plugin_path):
|
||||
return plugin_id
|
||||
|
||||
for plugin_id, plugin_path in self._iter_discovered_plugin_paths(getattr(supervisor, "_plugin_dirs", [])):
|
||||
if resolved_path == plugin_path or resolved_path.is_relative_to(plugin_path):
|
||||
self._plugin_path_cache[plugin_id] = plugin_path
|
||||
return plugin_id
|
||||
|
||||
for plugin_dir in getattr(supervisor, "_plugin_dirs", []):
|
||||
plugin_dir = Path(plugin_dir)
|
||||
plugin_root = plugin_dir.resolve()
|
||||
if self._plugin_dir_matches(path, plugin_dir) and (relative_parts := path.relative_to(plugin_root).parts):
|
||||
return relative_parts[0]
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _load_plugin_config_for_supervisor(plugin_id: str, plugin_dirs: Iterable[Path]) -> Dict[str, Any]:
|
||||
def _load_plugin_config_for_supervisor(self, supervisor: Any, plugin_id: str) -> Dict[str, Any]:
|
||||
"""从给定插件目录集合中读取目标插件的配置内容。"""
|
||||
for plugin_dir in plugin_dirs:
|
||||
plugin_path = plugin_dir.resolve() / plugin_id
|
||||
if plugin_path.is_dir():
|
||||
config_path = plugin_path / "config.toml"
|
||||
if not config_path.exists():
|
||||
return {}
|
||||
with open(config_path, "r", encoding="utf-8") as handle:
|
||||
return tomlkit.load(handle).unwrap()
|
||||
return {}
|
||||
plugin_path = self._get_plugin_path_for_supervisor(supervisor, plugin_id)
|
||||
if plugin_path is None:
|
||||
return {}
|
||||
|
||||
config_path = plugin_path / "config.toml"
|
||||
if not config_path.exists():
|
||||
return {}
|
||||
|
||||
with open(config_path, "r", encoding="utf-8") as handle:
|
||||
return tomlkit.load(handle).unwrap()
|
||||
|
||||
# ─── 能力实现注册 ──────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from pydantic import BaseModel, Field
|
||||
PROTOCOL_VERSION = "1.0.0"
|
||||
# 支持的 SDK 版本范围(Host 在握手时校验)
|
||||
MIN_SDK_VERSION = "1.0.0"
|
||||
MAX_SDK_VERSION = "1.99.99"
|
||||
MAX_SDK_VERSION = "2.99.99"
|
||||
|
||||
|
||||
# ====== 消息类型 ======
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import asyncio
|
||||
@@ -42,6 +42,13 @@ if not TYPE_CHECKING:
|
||||
AiohttpClientWebSocketResponse = Any
|
||||
|
||||
|
||||
SUPPORTED_CONFIG_VERSION = "0.1.0"
|
||||
DEFAULT_RECONNECT_DELAY_SEC = 5.0
|
||||
DEFAULT_HEARTBEAT_SEC = 30.0
|
||||
DEFAULT_ACTION_TIMEOUT_SEC = 15.0
|
||||
DEFAULT_CHAT_LIST_TYPE = "whitelist"
|
||||
|
||||
|
||||
@Adapter(platform="qq", protocol="napcat", send_method="send_to_platform")
|
||||
class NapCatAdapterPlugin(MaiBotPlugin):
|
||||
"""NapCat 适配器 MVP 实现。"""
|
||||
@@ -52,7 +59,7 @@ class NapCatAdapterPlugin(MaiBotPlugin):
|
||||
self._plugin_config: Dict[str, Any] = {}
|
||||
self._connection_task: Optional[asyncio.Task[None]] = None
|
||||
self._pending_actions: Dict[str, asyncio.Future[Dict[str, Any]]] = {}
|
||||
self._background_tasks: set[asyncio.Task[Any]] = set()
|
||||
self._background_tasks: Set[asyncio.Task[Any]] = set()
|
||||
self._send_lock = asyncio.Lock()
|
||||
self._ws: Optional[AiohttpClientWebSocketResponse] = None
|
||||
|
||||
@@ -80,8 +87,9 @@ class NapCatAdapterPlugin(MaiBotPlugin):
|
||||
new_config: 最新的插件配置。
|
||||
version: 配置版本号。
|
||||
"""
|
||||
del version
|
||||
self.set_plugin_config(new_config)
|
||||
if version:
|
||||
self.ctx.logger.debug(f"NapCat 适配器收到配置更新通知: {version}")
|
||||
await self._restart_connection_if_needed()
|
||||
|
||||
async def send_to_platform(
|
||||
@@ -139,6 +147,8 @@ class NapCatAdapterPlugin(MaiBotPlugin):
|
||||
if not self._should_connect():
|
||||
self.ctx.logger.info("NapCat 适配器保持空闲状态,因为插件或配置未启用")
|
||||
return
|
||||
if not self._validate_current_config():
|
||||
return
|
||||
if not AIOHTTP_AVAILABLE:
|
||||
self.ctx.logger.error("NapCat 适配器依赖 aiohttp,但当前环境未安装该依赖")
|
||||
return
|
||||
@@ -185,7 +195,7 @@ class NapCatAdapterPlugin(MaiBotPlugin):
|
||||
|
||||
headers = self._build_headers()
|
||||
timeout = ClientTimeout(total=None, connect=10)
|
||||
heartbeat = self._get_positive_float(self._connection_config(), "heartbeat_sec", 30.0)
|
||||
heartbeat = self._get_positive_float(self._connection_config(), "heartbeat_sec", DEFAULT_HEARTBEAT_SEC)
|
||||
|
||||
try:
|
||||
async with ClientSession(headers=headers, timeout=timeout) as session:
|
||||
@@ -204,7 +214,13 @@ class NapCatAdapterPlugin(MaiBotPlugin):
|
||||
if not self._should_connect():
|
||||
break
|
||||
|
||||
await asyncio.sleep(self._get_positive_float(self._connection_config(), "reconnect_delay_sec", 5.0))
|
||||
await asyncio.sleep(
|
||||
self._get_positive_float(
|
||||
self._connection_config(),
|
||||
"reconnect_delay_sec",
|
||||
DEFAULT_RECONNECT_DELAY_SEC,
|
||||
)
|
||||
)
|
||||
|
||||
async def _receive_loop(self, ws: AiohttpClientWebSocketResponse) -> None:
|
||||
"""持续消费 WebSocket 消息并分发处理。
|
||||
@@ -250,8 +266,11 @@ class NapCatAdapterPlugin(MaiBotPlugin):
|
||||
if not sender_user_id:
|
||||
return
|
||||
|
||||
group_id = str(payload.get("group_id") or "").strip()
|
||||
if self_id and sender_user_id == self_id and self._get_bool(self._filters_config(), "ignore_self_message", True):
|
||||
return
|
||||
if not self._is_inbound_chat_allowed(sender_user_id, group_id):
|
||||
return
|
||||
|
||||
message_dict = self._build_inbound_message_dict(payload, self_id, sender_user_id, sender)
|
||||
route_metadata: Dict[str, Any] = {}
|
||||
@@ -339,7 +358,7 @@ class NapCatAdapterPlugin(MaiBotPlugin):
|
||||
"display_message": plain_text,
|
||||
}
|
||||
|
||||
def _convert_inbound_segments(self, message_payload: Any, self_id: str) -> tuple[List[Dict[str, Any]], bool]:
|
||||
def _convert_inbound_segments(self, message_payload: Any, self_id: str) -> Tuple[List[Dict[str, Any]], bool]:
|
||||
"""将 OneBot 消息段转换为 Host 消息段结构。
|
||||
|
||||
Args:
|
||||
@@ -347,7 +366,7 @@ class NapCatAdapterPlugin(MaiBotPlugin):
|
||||
self_id: 当前机器人账号 ID。
|
||||
|
||||
Returns:
|
||||
tuple[List[Dict[str, Any]], bool]: 转换后的消息段列表,以及是否 @ 到当前机器人。
|
||||
Tuple[List[Dict[str, Any]], bool]: 转换后的消息段列表,以及是否 @ 到当前机器人。
|
||||
"""
|
||||
if isinstance(message_payload, str):
|
||||
normalized_text = message_payload.strip()
|
||||
@@ -412,7 +431,7 @@ class NapCatAdapterPlugin(MaiBotPlugin):
|
||||
self,
|
||||
message: Dict[str, Any],
|
||||
route: Dict[str, Any],
|
||||
) -> tuple[str, Dict[str, Any]]:
|
||||
) -> Tuple[str, Dict[str, Any]]:
|
||||
"""为 Host 出站消息构造 OneBot 动作。
|
||||
|
||||
Args:
|
||||
@@ -420,7 +439,7 @@ class NapCatAdapterPlugin(MaiBotPlugin):
|
||||
route: Platform IO 路由信息。
|
||||
|
||||
Returns:
|
||||
tuple[str, Dict[str, Any]]: 动作名称与参数字典。
|
||||
Tuple[str, Dict[str, Any]]: 动作名称与参数字典。
|
||||
"""
|
||||
message_info = message.get("message_info", {})
|
||||
if not isinstance(message_info, dict):
|
||||
@@ -519,7 +538,11 @@ class NapCatAdapterPlugin(MaiBotPlugin):
|
||||
try:
|
||||
async with self._send_lock:
|
||||
await ws.send_str(json.dumps(request_payload, ensure_ascii=False))
|
||||
timeout_seconds = self._get_positive_float(self._connection_config(), "action_timeout_sec", 15.0)
|
||||
timeout_seconds = self._get_positive_float(
|
||||
self._connection_config(),
|
||||
"action_timeout_sec",
|
||||
DEFAULT_ACTION_TIMEOUT_SEC,
|
||||
)
|
||||
return await asyncio.wait_for(response_future, timeout=timeout_seconds)
|
||||
finally:
|
||||
self._pending_actions.pop(echo_id, None)
|
||||
@@ -626,6 +649,173 @@ class NapCatAdapterPlugin(MaiBotPlugin):
|
||||
filters_config = self._plugin_config.get("filters", {})
|
||||
return filters_config if isinstance(filters_config, dict) else {}
|
||||
|
||||
def _chat_config(self) -> Dict[str, Any]:
|
||||
"""读取插件配置中的 ``chat`` 段。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: ``chat`` 配置字典。
|
||||
"""
|
||||
chat_config = self._plugin_config.get("chat", {})
|
||||
return chat_config if isinstance(chat_config, dict) else {}
|
||||
|
||||
def _is_inbound_chat_allowed(self, sender_user_id: str, group_id: str) -> bool:
|
||||
"""检查入站消息是否通过聊天名单过滤。
|
||||
|
||||
Args:
|
||||
sender_user_id: 发送者用户 ID。
|
||||
group_id: 群聊 ID;私聊时为空字符串。
|
||||
|
||||
Returns:
|
||||
bool: 若消息允许继续进入 Host,则返回 ``True``。
|
||||
"""
|
||||
chat_config = self._chat_config()
|
||||
banned_user_ids = self._get_string_list(chat_config, "ban_user_id")
|
||||
if sender_user_id in banned_user_ids:
|
||||
self.ctx.logger.warning(f"NapCat 用户 {sender_user_id} 在全局禁止名单中,消息被丢弃")
|
||||
return False
|
||||
|
||||
if group_id:
|
||||
group_list_type = self._get_list_mode(chat_config, "group_list_type", DEFAULT_CHAT_LIST_TYPE)
|
||||
group_id_list = self._get_string_list(chat_config, "group_list")
|
||||
if not self._is_id_allowed_by_list_policy(group_id, group_list_type, group_id_list):
|
||||
self.ctx.logger.warning(f"NapCat 群聊 {group_id} 未通过聊天名单过滤,消息被丢弃")
|
||||
return False
|
||||
return True
|
||||
|
||||
private_list_type = self._get_list_mode(chat_config, "private_list_type", DEFAULT_CHAT_LIST_TYPE)
|
||||
private_id_list = self._get_string_list(chat_config, "private_list")
|
||||
if not self._is_id_allowed_by_list_policy(sender_user_id, private_list_type, private_id_list):
|
||||
self.ctx.logger.warning(f"NapCat 私聊用户 {sender_user_id} 未通过聊天名单过滤,消息被丢弃")
|
||||
return False
|
||||
return True
|
||||
|
||||
def _is_id_allowed_by_list_policy(
|
||||
self,
|
||||
target_id: str,
|
||||
list_type: str,
|
||||
configured_ids: Set[str],
|
||||
) -> bool:
|
||||
"""根据白名单或黑名单规则判断目标 ID 是否允许通过。
|
||||
|
||||
Args:
|
||||
target_id: 待检查的目标 ID。
|
||||
list_type: 名单模式,仅支持 ``whitelist`` 或 ``blacklist``。
|
||||
configured_ids: 配置中的 ID 集合。
|
||||
|
||||
Returns:
|
||||
bool: 若目标 ID 允许通过,则返回 ``True``。
|
||||
"""
|
||||
if list_type == "whitelist":
|
||||
return target_id in configured_ids
|
||||
return target_id not in configured_ids
|
||||
|
||||
def _validate_current_config(self) -> bool:
|
||||
"""校验当前配置是否满足启动连接的前提条件。
|
||||
|
||||
Returns:
|
||||
bool: 配置可用于启动连接时返回 ``True``。
|
||||
"""
|
||||
if not self._validate_plugin_config_version():
|
||||
return False
|
||||
|
||||
connection_config = self._connection_config()
|
||||
ws_url = self._get_string(connection_config, "ws_url")
|
||||
if not ws_url:
|
||||
self.ctx.logger.warning("NapCat 适配器已启用,但 connection.ws_url 为空")
|
||||
return False
|
||||
|
||||
self._validate_positive_float_setting(
|
||||
connection_config,
|
||||
"connection",
|
||||
"reconnect_delay_sec",
|
||||
DEFAULT_RECONNECT_DELAY_SEC,
|
||||
)
|
||||
self._validate_positive_float_setting(
|
||||
connection_config,
|
||||
"connection",
|
||||
"heartbeat_sec",
|
||||
DEFAULT_HEARTBEAT_SEC,
|
||||
)
|
||||
self._validate_positive_float_setting(
|
||||
connection_config,
|
||||
"connection",
|
||||
"action_timeout_sec",
|
||||
DEFAULT_ACTION_TIMEOUT_SEC,
|
||||
)
|
||||
self._validate_list_mode_setting(self._chat_config(), "chat", "group_list_type", DEFAULT_CHAT_LIST_TYPE)
|
||||
self._validate_list_mode_setting(self._chat_config(), "chat", "private_list_type", DEFAULT_CHAT_LIST_TYPE)
|
||||
return True
|
||||
|
||||
def _validate_plugin_config_version(self) -> bool:
|
||||
"""校验插件配置版本是否与当前实现兼容。
|
||||
|
||||
Returns:
|
||||
bool: 版本兼容时返回 ``True``。
|
||||
"""
|
||||
config_version = self._get_string(self._plugin_section(), "config_version")
|
||||
if not config_version:
|
||||
self.ctx.logger.error(
|
||||
f"NapCat 适配器配置缺少 plugin.config_version,当前插件要求版本 {SUPPORTED_CONFIG_VERSION}"
|
||||
)
|
||||
return False
|
||||
|
||||
if config_version != SUPPORTED_CONFIG_VERSION:
|
||||
self.ctx.logger.error(
|
||||
"NapCat 适配器配置版本不兼容: "
|
||||
f"当前为 {config_version},当前插件要求 {SUPPORTED_CONFIG_VERSION}"
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _validate_positive_float_setting(
|
||||
self,
|
||||
mapping: Dict[str, Any],
|
||||
section_name: str,
|
||||
key: str,
|
||||
default: float,
|
||||
) -> None:
|
||||
"""校验正浮点数配置项,并在非法时输出告警日志。
|
||||
|
||||
Args:
|
||||
mapping: 待读取的配置字典。
|
||||
section_name: 当前配置段名称。
|
||||
key: 目标配置键名。
|
||||
default: 配置非法时实际使用的默认值。
|
||||
"""
|
||||
value = mapping.get(key, default)
|
||||
if isinstance(value, (int, float)) and float(value) > 0:
|
||||
return
|
||||
|
||||
self.ctx.logger.warning(
|
||||
"NapCat 适配器配置项取值无效,已回退到默认值: "
|
||||
f"{section_name}.{key}={value!r},默认值为 {default}"
|
||||
)
|
||||
|
||||
def _validate_list_mode_setting(
|
||||
self,
|
||||
mapping: Dict[str, Any],
|
||||
section_name: str,
|
||||
key: str,
|
||||
default: str,
|
||||
) -> None:
|
||||
"""校验名单模式配置项,并在非法时输出告警日志。
|
||||
|
||||
Args:
|
||||
mapping: 待读取的配置字典。
|
||||
section_name: 当前配置段名称。
|
||||
key: 目标配置键名。
|
||||
default: 配置非法时实际使用的默认值。
|
||||
"""
|
||||
value = mapping.get(key, default)
|
||||
if isinstance(value, str) and value.strip() in {"whitelist", "blacklist"}:
|
||||
return
|
||||
|
||||
self.ctx.logger.warning(
|
||||
"NapCat 适配器配置项取值无效,已回退到默认值: "
|
||||
f"{section_name}.{key}={value!r},默认值为 {default}"
|
||||
)
|
||||
|
||||
def _should_connect(self) -> bool:
|
||||
"""判断当前配置下是否应当启动连接。
|
||||
|
||||
@@ -680,6 +870,47 @@ class NapCatAdapterPlugin(MaiBotPlugin):
|
||||
value = mapping.get(key)
|
||||
return "" if value is None else str(value).strip()
|
||||
|
||||
@staticmethod
|
||||
def _get_list_mode(mapping: Dict[str, Any], key: str, default: str) -> str:
|
||||
"""安全读取名单模式配置值。
|
||||
|
||||
Args:
|
||||
mapping: 待读取的配置字典。
|
||||
key: 目标键名。
|
||||
default: 读取失败时的默认值。
|
||||
|
||||
Returns:
|
||||
str: 合法的名单模式字符串。
|
||||
"""
|
||||
value = mapping.get(key, default)
|
||||
if isinstance(value, str):
|
||||
normalized_value = value.strip()
|
||||
if normalized_value in {"whitelist", "blacklist"}:
|
||||
return normalized_value
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def _get_string_list(mapping: Dict[str, Any], key: str) -> Set[str]:
|
||||
"""安全读取 ID 列表配置值。
|
||||
|
||||
Args:
|
||||
mapping: 待读取的配置字典。
|
||||
key: 目标键名。
|
||||
|
||||
Returns:
|
||||
Set[str]: 去重后的字符串 ID 集合。
|
||||
"""
|
||||
value = mapping.get(key, [])
|
||||
if not isinstance(value, list):
|
||||
return set()
|
||||
|
||||
normalized_values: Set[str] = set()
|
||||
for item in value:
|
||||
item_text = "" if item is None else str(item).strip()
|
||||
if item_text:
|
||||
normalized_values.add(item_text)
|
||||
return normalized_values
|
||||
|
||||
|
||||
def create_plugin() -> NapCatAdapterPlugin:
|
||||
"""创建插件实例。
|
||||
|
||||
Reference in New Issue
Block a user