feat: 增强插件能力检查,支持 generation 校验并添加清理功能

This commit is contained in:
DrSmoothl
2026-03-12 21:22:23 +08:00
parent df39fa7584
commit d0b56abdab
8 changed files with 466 additions and 51 deletions

View File

@@ -65,10 +65,11 @@ class CapabilityService:
capability = req.capability
# 1. 权限校验
allowed, reason = self._policy.check_capability(plugin_id, capability)
allowed, reason = self._policy.check_capability(plugin_id, capability, envelope.generation)
if not allowed:
error_code = ErrorCode.E_GENERATION_MISMATCH if "generation 不匹配" in reason else ErrorCode.E_CAPABILITY_DENIED
return envelope.make_error_response(
ErrorCode.E_CAPABILITY_DENIED.value,
error_code.value,
reason,
)

View File

@@ -73,6 +73,13 @@ class ComponentRegistry:
# 按插件索引
self._by_plugin: Dict[str, List[RegisteredComponent]] = {}
def clear(self) -> None:
"""清空全部组件注册状态。"""
self._components.clear()
for type_dict in self._by_type.values():
type_dict.clear()
self._by_plugin.clear()
# ──── 注册 / 注销 ─────────────────────────────────────────
def register_component(

View File

@@ -44,7 +44,11 @@ class PolicyEngine:
"""撤销插件的能力令牌"""
self._tokens.pop(plugin_id, None)
def check_capability(self, plugin_id: str, capability: str) -> Tuple[bool, str]:
def clear(self) -> None:
"""清空所有能力令牌。"""
self._tokens.clear()
def check_capability(self, plugin_id: str, capability: str, generation: Optional[int] = None) -> Tuple[bool, str]:
"""检查插件是否有权调用某项能力
Returns:
@@ -57,6 +61,9 @@ class PolicyEngine:
if capability not in token.capabilities:
return False, f"插件 {plugin_id} 未获授权能力: {capability}"
if generation is not None and token.generation != generation:
return False, f"插件 {plugin_id} generation 不匹配: {generation} != {token.generation}"
return True, ""
def get_token(self, plugin_id: str) -> Optional[CapabilityToken]:

View File

@@ -73,6 +73,10 @@ class RPCServer:
def session_token(self) -> str:
return self._session_token
@property
def runner_generation(self) -> int:
return self._runner_generation
@property
def is_connected(self) -> bool:
return self._connection is not None and not self._connection.is_closed
@@ -206,18 +210,23 @@ class RPCServer:
await conn.close()
return
# 握手成功,保存连接
old_connection = self._connection
self._connection = conn
logger.info(f"Runner 握手成功: runner_id={self._runner_id}, generation={self._runner_generation}")
if old_connection and old_connection is not conn and not old_connection.is_closed:
logger.info("检测到新 Runner 已接管连接,关闭旧连接")
await old_connection.close()
# 启动消息接收循环
try:
await self._recv_loop(conn)
except Exception as e:
logger.error(f"连接异常断开: {e}")
finally:
self._connection = None
self._runner_id = None
if self._connection is conn:
self._connection = None
self._runner_id = None
async def _handle_handshake(self, conn: Connection) -> bool:
"""处理 runner.hello 握手"""
@@ -295,17 +304,35 @@ class RPCServer:
if envelope.is_response():
self._handle_response(envelope)
elif envelope.is_request():
if not self._is_current_generation(envelope):
error_resp = envelope.make_error_response(
ErrorCode.E_GENERATION_MISMATCH.value,
f"过期 generation: {envelope.generation} != {self._runner_generation}",
)
await conn.send_frame(self._codec.encode_envelope(error_resp))
continue
# 异步处理请求Runner 发来的能力调用)
task = asyncio.create_task(self._handle_request(envelope, conn))
self._tasks.append(task)
task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None)
elif envelope.is_event():
if not self._is_current_generation(envelope):
logger.warning(
f"忽略过期 generation 事件 {envelope.method}: {envelope.generation} != {self._runner_generation}"
)
continue
task = asyncio.create_task(self._handle_event(envelope))
self._tasks.append(task)
task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None)
def _handle_response(self, envelope: Envelope) -> None:
"""处理来自 Runner 的响应"""
if not self._is_current_generation(envelope):
logger.warning(
f"忽略过期 generation 响应 {envelope.method}: {envelope.generation} != {self._runner_generation}"
)
return
future = self._pending_requests.pop(envelope.request_id, None)
if future and not future.done():
if envelope.error:
@@ -313,6 +340,9 @@ class RPCServer:
else:
future.set_result(envelope)
def _is_current_generation(self, envelope: Envelope) -> bool:
return envelope.generation == self._runner_generation
async def _handle_request(self, envelope: Envelope, conn: Connection) -> None:
"""处理来自 Runner 的请求(通常是能力调用 cap.*"""
handler = self._method_handlers.get(envelope.method)

View File

@@ -10,6 +10,7 @@
from typing import Any, Dict, List, Optional, Tuple
import asyncio
import contextlib
import os
import sys
@@ -75,6 +76,7 @@ class PluginSupervisor:
# 后台任务
self._health_task: Optional[asyncio.Task] = None
self._runner_output_tasks: List[asyncio.Task] = []
self._running = False
# 注册内部 RPC 方法
@@ -224,40 +226,26 @@ class PluginSupervisor:
# 保存旧进程引用
old_process = self._runner_process
old_registered_plugins = dict(self._registered_plugins)
expected_generation = self._rpc_server.runner_generation + 1
# 清理旧的组件注册,防止幽灵组件残留
for plugin_id in list(self._registered_plugins.keys()):
self._component_registry.remove_components_by_plugin(plugin_id)
self._policy.revoke_plugin(plugin_id)
self._registered_plugins.clear()
self._clear_runtime_state()
# 拉起新 Runner
await self._spawn_runner()
# 等待新 Runner 连接并完成握手
for _ in range(30): # 最多等待 30 秒
if self._rpc_server.is_connected:
break
await asyncio.sleep(1.0)
else:
logger.error("新 Runner 连接超时,回滚")
# 回滚:终止新进程
if self._runner_process and self._runner_process != old_process:
self._runner_process.terminate()
self._runner_process = old_process
return
# 健康检查
try:
await self._spawn_runner()
await self._wait_for_runner_generation(expected_generation, timeout_sec=30.0)
resp = await self._rpc_server.send_request("plugin.health", timeout_ms=5000)
health = HealthPayload.model_validate(resp.payload)
if not health.healthy:
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "新 Runner 健康检查失败")
except Exception as e:
logger.error(f"新 Runner 健康检查失败: {e},回滚")
if self._runner_process and self._runner_process != old_process:
self._runner_process.terminate()
await self._terminate_process(self._runner_process, old_process)
self._runner_process = old_process
self._registered_plugins = dict(old_registered_plugins)
self._rebuild_runtime_state()
return
# 关停旧 Runner
@@ -286,13 +274,19 @@ class PluginSupervisor:
except Exception as e:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e))
if envelope.generation != self._rpc_server.runner_generation:
return envelope.make_error_response(
ErrorCode.E_GENERATION_MISMATCH.value,
f"组件注册 generation 过期: {envelope.generation} != {self._rpc_server.runner_generation}",
)
# 记录注册信息
self._registered_plugins[reg.plugin_id] = reg
# 在策略引擎中注册插件
self._policy.register_plugin(
plugin_id=reg.plugin_id,
generation=self._runner_generation,
generation=envelope.generation,
capabilities=reg.capabilities_required or [],
)
@@ -329,7 +323,8 @@ class PluginSupervisor:
stderr=asyncio.subprocess.PIPE,
)
self._runner_generation += 1
self._attach_runner_output_tasks(self._runner_process)
self._runner_generation = self._rpc_server.runner_generation
logger.info(f"Runner 子进程已启动: pid={self._runner_process.pid}, generation={self._runner_generation}")
async def _shutdown_runner(self) -> None:
@@ -362,6 +357,8 @@ class PluginSupervisor:
self._runner_process.kill()
await self._runner_process.wait()
await self._cleanup_runner_output_tasks()
async def _health_check_loop(self) -> None:
"""周期性健康检查 + 崩溃自动重启"""
while self._running:
@@ -382,6 +379,7 @@ class PluginSupervisor:
self._registered_plugins.clear()
try:
self._clear_runtime_state()
await self._spawn_runner()
except Exception as e:
logger.error(f"Runner 重启失败: {e}", exc_info=True)
@@ -407,3 +405,98 @@ class PluginSupervisor:
break
except Exception as e:
logger.error(f"健康检查异常: {e}")
async def _wait_for_runner_generation(self, expected_generation: int, timeout_sec: float) -> None:
"""等待指定代际的 Runner 完成连接。"""
deadline = asyncio.get_running_loop().time() + timeout_sec
while asyncio.get_running_loop().time() < deadline:
if self._rpc_server.is_connected and self._rpc_server.runner_generation >= expected_generation:
self._runner_generation = self._rpc_server.runner_generation
return
await asyncio.sleep(0.1)
raise TimeoutError(f"等待 Runner generation {expected_generation} 超时")
def _clear_runtime_state(self) -> None:
"""清空当前插件注册态。"""
self._component_registry.clear()
self._policy.clear()
self._registered_plugins.clear()
def _rebuild_runtime_state(self) -> None:
"""根据已记录的插件注册信息重建运行时状态。"""
self._component_registry.clear()
self._policy.clear()
for reg in self._registered_plugins.values():
self._policy.register_plugin(
plugin_id=reg.plugin_id,
generation=self._rpc_server.runner_generation,
capabilities=reg.capabilities_required or [],
)
self._component_registry.register_plugin_components(
plugin_id=reg.plugin_id,
components=[c.model_dump() for c in reg.components],
)
def _attach_runner_output_tasks(self, process: asyncio.subprocess.Process) -> None:
"""为 Runner 输出流创建排空任务,避免 PIPE 填满阻塞子进程。"""
streams = (
(process.stdout, "stdout"),
(process.stderr, "stderr"),
)
for stream, stream_name in streams:
if stream is None:
continue
task = asyncio.create_task(self._drain_runner_stream(stream, stream_name, process.pid))
self._runner_output_tasks.append(task)
task.add_done_callback(
lambda done_task: self._runner_output_tasks.remove(done_task)
if done_task in self._runner_output_tasks
else None
)
async def _drain_runner_stream(
self,
stream: asyncio.StreamReader,
stream_name: str,
pid: int,
) -> None:
"""持续消费 Runner 输出,避免 PIPE 回压导致子进程阻塞。"""
try:
while True:
line = await stream.readline()
if not line:
break
message = line.decode(errors="replace").rstrip()
if message:
logger.debug(f"[runner:{pid}:{stream_name}] {message}")
except asyncio.CancelledError:
raise
except Exception as e:
logger.debug(f"读取 Runner {stream_name} 失败: {e}")
async def _cleanup_runner_output_tasks(self) -> None:
"""等待并清理 Runner 输出任务。"""
tasks = list(self._runner_output_tasks)
self._runner_output_tasks.clear()
for task in tasks:
if not task.done():
task.cancel()
if tasks:
with contextlib.suppress(Exception):
await asyncio.gather(*tasks, return_exceptions=True)
@staticmethod
async def _terminate_process(
process: Optional[asyncio.subprocess.Process],
keep_process: Optional[asyncio.subprocess.Process] = None,
) -> None:
"""终止指定进程,但跳过需要保留的旧进程引用。"""
if process is None or process is keep_process or process.returncode is not None:
return
process.terminate()
try:
await asyncio.wait_for(process.wait(), timeout=10.0)
except asyncio.TimeoutError:
process.kill()
await process.wait()