feat: 增强插件管理和连接稳定性,添加会话令牌重置和组件清理功能
This commit is contained in:
@@ -93,6 +93,14 @@ class ComponentRegistry:
|
|||||||
comp = RegisteredComponent(name, component_type, plugin_id, metadata)
|
comp = RegisteredComponent(name, component_type, plugin_id, metadata)
|
||||||
if comp.full_name in self._components:
|
if comp.full_name in self._components:
|
||||||
logger.warning(f"组件 {comp.full_name} 已存在,覆盖")
|
logger.warning(f"组件 {comp.full_name} 已存在,覆盖")
|
||||||
|
# 从 _by_plugin 列表中移除旧条目,防止幽灵组件堆积
|
||||||
|
old_comp = self._components[comp.full_name]
|
||||||
|
old_list = self._by_plugin.get(old_comp.plugin_id)
|
||||||
|
if old_list is not None:
|
||||||
|
try:
|
||||||
|
old_list.remove(old_comp)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
self._components[comp.full_name] = comp
|
self._components[comp.full_name] = comp
|
||||||
|
|
||||||
|
|||||||
@@ -73,6 +73,11 @@ class RPCServer:
|
|||||||
def session_token(self) -> str:
|
def session_token(self) -> str:
|
||||||
return self._session_token
|
return self._session_token
|
||||||
|
|
||||||
|
def reset_session_token(self) -> str:
|
||||||
|
"""重新生成会话令牌(热重载时调用,防止旧 Runner 重连)"""
|
||||||
|
self._session_token = secrets.token_hex(32)
|
||||||
|
return self._session_token
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def runner_generation(self) -> int:
|
def runner_generation(self) -> int:
|
||||||
return self._runner_generation
|
return self._runner_generation
|
||||||
@@ -155,7 +160,7 @@ class RPCServer:
|
|||||||
raise RPCError(ErrorCode.E_BACKPRESSURE, "发送队列已满")
|
raise RPCError(ErrorCode.E_BACKPRESSURE, "发送队列已满")
|
||||||
|
|
||||||
# 注册 pending future
|
# 注册 pending future
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_running_loop()
|
||||||
future: asyncio.Future[Envelope] = loop.create_future()
|
future: asyncio.Future[Envelope] = loop.create_future()
|
||||||
self._pending_requests[request_id] = future
|
self._pending_requests[request_id] = future
|
||||||
|
|
||||||
@@ -227,6 +232,11 @@ class RPCServer:
|
|||||||
if self._connection is conn:
|
if self._connection is conn:
|
||||||
self._connection = None
|
self._connection = None
|
||||||
self._runner_id = None
|
self._runner_id = None
|
||||||
|
# 连接断开时,立即让所有等待中的请求失败,避免挂起至超时
|
||||||
|
for req_id, future in list(self._pending_requests.items()):
|
||||||
|
if not future.done():
|
||||||
|
future.set_exception(RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已断开"))
|
||||||
|
self._pending_requests.clear()
|
||||||
|
|
||||||
async def _handle_handshake(self, conn: Connection) -> bool:
|
async def _handle_handshake(self, conn: Connection) -> bool:
|
||||||
"""处理 runner.hello 握手"""
|
"""处理 runner.hello 握手"""
|
||||||
@@ -369,7 +379,12 @@ class RPCServer:
|
|||||||
"""处理来自 Runner 的事件"""
|
"""处理来自 Runner 的事件"""
|
||||||
if handler := self._method_handlers.get(envelope.method):
|
if handler := self._method_handlers.get(envelope.method):
|
||||||
try:
|
try:
|
||||||
await handler(envelope)
|
result = await handler(envelope)
|
||||||
|
# 检查 handler 返回的信封是否包含错误信息
|
||||||
|
if result is not None and isinstance(result, Envelope) and result.error:
|
||||||
|
logger.warning(
|
||||||
|
f"事件 {envelope.method} handler 返回错误: {result.error.get('message', '')}"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True)
|
logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True)
|
||||||
|
|
||||||
|
|||||||
@@ -222,9 +222,19 @@ class PluginSupervisor:
|
|||||||
# 启动 RPC Server
|
# 启动 RPC Server
|
||||||
await self._rpc_server.start()
|
await self._rpc_server.start()
|
||||||
|
|
||||||
|
# 计算预期 generation(与 reload_plugins 保持一致)
|
||||||
|
expected_generation = self._rpc_server.runner_generation + 1
|
||||||
|
|
||||||
# 拉起 Runner 进程
|
# 拉起 Runner 进程
|
||||||
await self._spawn_runner()
|
await self._spawn_runner()
|
||||||
|
|
||||||
|
# 等待 Runner 完成连接,避免 start() 返回时 Runner 尚未就绪
|
||||||
|
try:
|
||||||
|
await self._wait_for_runner_generation(expected_generation, timeout_sec=30.0)
|
||||||
|
except TimeoutError:
|
||||||
|
if not self._rpc_server.is_connected:
|
||||||
|
logger.warning("Runner 未在 30s 内完成连接,后续操作可能失败")
|
||||||
|
|
||||||
# 启动健康检查
|
# 启动健康检查
|
||||||
self._health_task = asyncio.create_task(self._health_check_loop())
|
self._health_task = asyncio.create_task(self._health_check_loop())
|
||||||
|
|
||||||
@@ -283,6 +293,9 @@ class PluginSupervisor:
|
|||||||
old_registered_plugins = dict(self._registered_plugins)
|
old_registered_plugins = dict(self._registered_plugins)
|
||||||
expected_generation = self._rpc_server.runner_generation + 1
|
expected_generation = self._rpc_server.runner_generation + 1
|
||||||
|
|
||||||
|
# 重新生成 session token,防止被终止的旧 Runner 重连
|
||||||
|
self._rpc_server.reset_session_token()
|
||||||
|
|
||||||
# 清理旧的组件注册,防止幽灵组件残留
|
# 清理旧的组件注册,防止幽灵组件残留
|
||||||
self._clear_runtime_state()
|
self._clear_runtime_state()
|
||||||
|
|
||||||
|
|||||||
@@ -162,7 +162,7 @@ class PluginRuntimeManager:
|
|||||||
try:
|
try:
|
||||||
cont, mod = await sv.dispatch_event(
|
cont, mod = await sv.dispatch_event(
|
||||||
event_type=new_event_type,
|
event_type=new_event_type,
|
||||||
message=message_dict,
|
message=modified or message_dict,
|
||||||
extra_args=extra_args,
|
extra_args=extra_args,
|
||||||
)
|
)
|
||||||
if mod is not None:
|
if mod is not None:
|
||||||
@@ -184,7 +184,14 @@ class PluginRuntimeManager:
|
|||||||
for sv in self.supervisors:
|
for sv in self.supervisors:
|
||||||
result = sv.component_registry.find_command_by_text(text)
|
result = sv.component_registry.find_command_by_text(text)
|
||||||
if result is not None:
|
if result is not None:
|
||||||
return result
|
return {
|
||||||
|
"name": result.name,
|
||||||
|
"full_name": result.full_name,
|
||||||
|
"component_type": result.component_type,
|
||||||
|
"plugin_id": result.plugin_id,
|
||||||
|
"metadata": result.metadata,
|
||||||
|
"enabled": result.enabled,
|
||||||
|
}
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# ─── 能力实现注册 ──────────────────────────────────────────
|
# ─── 能力实现注册 ──────────────────────────────────────────
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ class RunnerIPCLogHandler(logging.Handler):
|
|||||||
FLUSH_BATCH_SIZE: int = 20
|
FLUSH_BATCH_SIZE: int = 20
|
||||||
|
|
||||||
#: 仅转发 logger name 以这些前缀开头的日志,第三方库日志将被忽略
|
#: 仅转发 logger name 以这些前缀开头的日志,第三方库日志将被忽略
|
||||||
ALLOWED_LOGGER_PREFIXES: tuple[str, ...] = ("plugin.",)
|
ALLOWED_LOGGER_PREFIXES: tuple[str, ...] = ("plugin.", "plugin_runtime.")
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
@@ -165,7 +165,7 @@ class RPCClient:
|
|||||||
payload=payload or {},
|
payload=payload or {},
|
||||||
)
|
)
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_running_loop()
|
||||||
future: asyncio.Future[Envelope] = loop.create_future()
|
future: asyncio.Future[Envelope] = loop.create_future()
|
||||||
self._pending_requests[request_id] = future
|
self._pending_requests[request_id] = future
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ import signal
|
|||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from src.common.logger import get_logger, initialize_logging
|
from src.common.logger import get_logger, initialize_logging
|
||||||
from src.plugin_runtime.protocol.envelope import (
|
from src.plugin_runtime.protocol.envelope import (
|
||||||
ComponentDeclaration,
|
ComponentDeclaration,
|
||||||
@@ -80,14 +82,10 @@ class PluginRunner:
|
|||||||
plugins = self._loader.discover_and_load(self._plugin_dirs)
|
plugins = self._loader.discover_and_load(self._plugin_dirs)
|
||||||
logger.info(f"已加载 {len(plugins)} 个插件")
|
logger.info(f"已加载 {len(plugins)} 个插件")
|
||||||
|
|
||||||
# 4. 调用 on_load 生命周期钩子 + 注入 RPC 客户端供 SDK context 使用
|
# 4. 注入 PluginContext + 调用 on_load 生命周期钩子
|
||||||
for meta in plugins:
|
for meta in plugins:
|
||||||
instance = meta.instance
|
instance = meta.instance
|
||||||
# 注入 _rpc_client 以便 PluginContext 可以发起能力调用
|
self._inject_context(meta.plugin_id, instance)
|
||||||
if hasattr(instance, "_ctx"):
|
|
||||||
ctx = instance._ctx
|
|
||||||
if hasattr(ctx, "_set_rpc_client"):
|
|
||||||
ctx._set_rpc_client(self._rpc_client)
|
|
||||||
if hasattr(instance, "on_load"):
|
if hasattr(instance, "on_load"):
|
||||||
try:
|
try:
|
||||||
ret = instance.on_load()
|
ret = instance.on_load()
|
||||||
@@ -136,6 +134,39 @@ class PluginRunner:
|
|||||||
self._log_handler = None
|
self._log_handler = None
|
||||||
logger.debug("RunnerIPCLogHandler \u5df2\u5378\u8f7d")
|
logger.debug("RunnerIPCLogHandler \u5df2\u5378\u8f7d")
|
||||||
|
|
||||||
|
def _inject_context(self, plugin_id: str, instance: object) -> None:
|
||||||
|
"""为插件实例创建并注入 PluginContext。
|
||||||
|
|
||||||
|
对新版 MaiBotPlugin(具有 _set_context 方法):创建 PluginContext 并注入。
|
||||||
|
对旧版 LegacyPluginAdapter(具有 _set_context 方法,由适配器代理):同上。
|
||||||
|
"""
|
||||||
|
if not hasattr(instance, "_set_context"):
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
from maibot_sdk.context import PluginContext
|
||||||
|
except ImportError:
|
||||||
|
logger.warning(f"maibot_sdk 不可用,无法为插件 {plugin_id} 创建 PluginContext")
|
||||||
|
return
|
||||||
|
|
||||||
|
rpc_client = self._rpc_client
|
||||||
|
|
||||||
|
async def _rpc_call(method: str, plugin_id: str = "", payload: dict = None) -> Any:
|
||||||
|
"""桥接 PluginContext.call_capability → RPCClient.send_request"""
|
||||||
|
resp = await rpc_client.send_request(
|
||||||
|
method=method,
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
payload=payload or {},
|
||||||
|
)
|
||||||
|
# 从响应信封中提取业务结果
|
||||||
|
if resp.error:
|
||||||
|
raise RuntimeError(resp.error.get("message", "能力调用失败"))
|
||||||
|
return resp.payload.get("result")
|
||||||
|
|
||||||
|
ctx = PluginContext(plugin_id=plugin_id, rpc_call=_rpc_call)
|
||||||
|
instance._set_context(ctx)
|
||||||
|
logger.debug(f"已为插件 {plugin_id} 注入 PluginContext")
|
||||||
|
|
||||||
def _register_handlers(self) -> None:
|
def _register_handlers(self) -> None:
|
||||||
"""注册方法处理器"""
|
"""注册方法处理器"""
|
||||||
self._rpc_client.register_method("plugin.invoke_command", self._handle_invoke)
|
self._rpc_client.register_method("plugin.invoke_command", self._handle_invoke)
|
||||||
@@ -203,10 +234,22 @@ class PluginRunner:
|
|||||||
instance = meta.instance
|
instance = meta.instance
|
||||||
component_name = invoke.component_name
|
component_name = invoke.component_name
|
||||||
|
|
||||||
|
# 优先查找 handle_<name> 或直接 <name> 方法(新版 SDK 插件)
|
||||||
handler_method = getattr(instance, f"handle_{component_name}", None)
|
handler_method = getattr(instance, f"handle_{component_name}", None)
|
||||||
if handler_method is None:
|
if handler_method is None:
|
||||||
handler_method = getattr(instance, component_name, None)
|
handler_method = getattr(instance, component_name, None)
|
||||||
|
|
||||||
|
# 回退: 旧版 LegacyPluginAdapter 通过 invoke_component 统一桥接
|
||||||
|
if (handler_method is None or not callable(handler_method)) and hasattr(instance, "invoke_component"):
|
||||||
|
try:
|
||||||
|
result = await instance.invoke_component(component_name, **invoke.args)
|
||||||
|
resp_payload = InvokeResultPayload(success=True, result=result)
|
||||||
|
return envelope.make_response(payload=resp_payload.model_dump())
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"插件 {plugin_id} 组件 {component_name} (legacy) 执行异常: {e}", exc_info=True)
|
||||||
|
resp_payload = InvokeResultPayload(success=False, result=str(e))
|
||||||
|
return envelope.make_response(payload=resp_payload.model_dump())
|
||||||
|
|
||||||
if handler_method is None or not callable(handler_method):
|
if handler_method is None or not callable(handler_method):
|
||||||
return envelope.make_error_response(
|
return envelope.make_error_response(
|
||||||
ErrorCode.E_METHOD_NOT_ALLOWED.value,
|
ErrorCode.E_METHOD_NOT_ALLOWED.value,
|
||||||
@@ -326,6 +369,7 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
|
|||||||
|
|
||||||
防止插件代码 import 主程序模块读取运行时数据。
|
防止插件代码 import 主程序模块读取运行时数据。
|
||||||
"""
|
"""
|
||||||
|
import importlib.abc
|
||||||
import sysconfig
|
import sysconfig
|
||||||
|
|
||||||
# 保留: 标准库路径 + site-packages(含 SDK 和依赖)
|
# 保留: 标准库路径 + site-packages(含 SDK 和依赖)
|
||||||
@@ -348,12 +392,45 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
|
|||||||
for d in plugin_dirs:
|
for d in plugin_dirs:
|
||||||
allowed.add(os.path.normpath(d))
|
allowed.add(os.path.normpath(d))
|
||||||
|
|
||||||
# 添加当前 runner 模块所在路径(使得 src.plugin_runtime 可导入)
|
# 添加项目根目录(使得 src.plugin_runtime / src.common 可导入)
|
||||||
runtime_root = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
runtime_root = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||||
allowed.add(runtime_root)
|
allowed.add(runtime_root)
|
||||||
|
|
||||||
sys.path[:] = [p for p in sys.path if p in allowed]
|
sys.path[:] = [p for p in sys.path if p in allowed]
|
||||||
|
|
||||||
|
# 安装 import 钩子,阻止插件导入主程序核心模块
|
||||||
|
# 仅允许 src.plugin_runtime 和 src.common,拒绝其他 src.* 子包
|
||||||
|
class _PluginImportBlocker(importlib.abc.MetaPathFinder):
|
||||||
|
"""阻止 Runner 子进程导入主程序核心模块。
|
||||||
|
|
||||||
|
只放行 src.plugin_runtime 和 src.common,
|
||||||
|
拒绝 src.chat_module / src.services 等主程序内部包。
|
||||||
|
"""
|
||||||
|
|
||||||
|
_ALLOWED_SRC_PREFIXES = ("src.plugin_runtime", "src.common")
|
||||||
|
|
||||||
|
def find_module(self, fullname, path=None):
|
||||||
|
if self._should_block(fullname):
|
||||||
|
return self
|
||||||
|
return None
|
||||||
|
|
||||||
|
def load_module(self, fullname):
|
||||||
|
raise ImportError(
|
||||||
|
f"Runner 子进程不允许导入主程序模块: {fullname}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _should_block(self, fullname: str) -> bool:
|
||||||
|
# 放行非 src.* 的导入、以及 "src" 本身
|
||||||
|
if not fullname.startswith("src.") or fullname == "src":
|
||||||
|
return False
|
||||||
|
# 放行白名单前缀
|
||||||
|
for prefix in self._ALLOWED_SRC_PREFIXES:
|
||||||
|
if fullname == prefix or fullname.startswith(prefix + "."):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
sys.meta_path.insert(0, _PluginImportBlocker())
|
||||||
|
|
||||||
|
|
||||||
# ─── 进程入口 ──────────────────────────────────────────────
|
# ─── 进程入口 ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ class Connection(ABC):
|
|||||||
self._reader = reader
|
self._reader = reader
|
||||||
self._writer = writer
|
self._writer = writer
|
||||||
self._closed = False
|
self._closed = False
|
||||||
|
self._write_lock = asyncio.Lock() # 保护并发写入的帧完整性
|
||||||
|
|
||||||
async def send_frame(self, data: bytes) -> None:
|
async def send_frame(self, data: bytes) -> None:
|
||||||
"""发送一帧数据(4-byte length prefix + payload)"""
|
"""发送一帧数据(4-byte length prefix + payload)"""
|
||||||
@@ -42,8 +43,9 @@ class Connection(ABC):
|
|||||||
if length > MAX_FRAME_SIZE:
|
if length > MAX_FRAME_SIZE:
|
||||||
raise ValueError(f"帧大小 {length} 超过最大限制 {MAX_FRAME_SIZE}")
|
raise ValueError(f"帧大小 {length} 超过最大限制 {MAX_FRAME_SIZE}")
|
||||||
header = struct.pack(">I", length)
|
header = struct.pack(">I", length)
|
||||||
self._writer.write(header + data)
|
async with self._write_lock:
|
||||||
await self._writer.drain()
|
self._writer.write(header + data)
|
||||||
|
await self._writer.drain()
|
||||||
|
|
||||||
async def recv_frame(self) -> bytes:
|
async def recv_frame(self) -> bytes:
|
||||||
"""接收一帧数据"""
|
"""接收一帧数据"""
|
||||||
|
|||||||
Reference in New Issue
Block a user