feat: 增强插件能力检查,支持 generation 校验并添加清理功能

This commit is contained in:
DrSmoothl
2026-03-12 21:22:23 +08:00
parent df39fa7584
commit d0b56abdab
8 changed files with 466 additions and 51 deletions

View File

@@ -10,6 +10,7 @@
from typing import Any, Dict, List, Optional, Tuple
import asyncio
import contextlib
import os
import sys
@@ -75,6 +76,7 @@ class PluginSupervisor:
# 后台任务
self._health_task: Optional[asyncio.Task] = None
self._runner_output_tasks: List[asyncio.Task] = []
self._running = False
# 注册内部 RPC 方法
@@ -224,40 +226,26 @@ class PluginSupervisor:
# 保存旧进程引用
old_process = self._runner_process
old_registered_plugins = dict(self._registered_plugins)
expected_generation = self._rpc_server.runner_generation + 1
# 清理旧的组件注册,防止幽灵组件残留
for plugin_id in list(self._registered_plugins.keys()):
self._component_registry.remove_components_by_plugin(plugin_id)
self._policy.revoke_plugin(plugin_id)
self._registered_plugins.clear()
self._clear_runtime_state()
# 拉起新 Runner
await self._spawn_runner()
# 等待新 Runner 连接并完成握手
for _ in range(30): # 最多等待 30 秒
if self._rpc_server.is_connected:
break
await asyncio.sleep(1.0)
else:
logger.error("新 Runner 连接超时,回滚")
# 回滚:终止新进程
if self._runner_process and self._runner_process != old_process:
self._runner_process.terminate()
self._runner_process = old_process
return
# 健康检查
try:
await self._spawn_runner()
await self._wait_for_runner_generation(expected_generation, timeout_sec=30.0)
resp = await self._rpc_server.send_request("plugin.health", timeout_ms=5000)
health = HealthPayload.model_validate(resp.payload)
if not health.healthy:
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "新 Runner 健康检查失败")
except Exception as e:
logger.error(f"新 Runner 健康检查失败: {e},回滚")
if self._runner_process and self._runner_process != old_process:
self._runner_process.terminate()
await self._terminate_process(self._runner_process, old_process)
self._runner_process = old_process
self._registered_plugins = dict(old_registered_plugins)
self._rebuild_runtime_state()
return
# 关停旧 Runner
@@ -286,13 +274,19 @@ class PluginSupervisor:
except Exception as e:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e))
if envelope.generation != self._rpc_server.runner_generation:
return envelope.make_error_response(
ErrorCode.E_GENERATION_MISMATCH.value,
f"组件注册 generation 过期: {envelope.generation} != {self._rpc_server.runner_generation}",
)
# 记录注册信息
self._registered_plugins[reg.plugin_id] = reg
# 在策略引擎中注册插件
self._policy.register_plugin(
plugin_id=reg.plugin_id,
generation=self._runner_generation,
generation=envelope.generation,
capabilities=reg.capabilities_required or [],
)
@@ -329,7 +323,8 @@ class PluginSupervisor:
stderr=asyncio.subprocess.PIPE,
)
self._runner_generation += 1
self._attach_runner_output_tasks(self._runner_process)
self._runner_generation = self._rpc_server.runner_generation
logger.info(f"Runner 子进程已启动: pid={self._runner_process.pid}, generation={self._runner_generation}")
async def _shutdown_runner(self) -> None:
@@ -362,6 +357,8 @@ class PluginSupervisor:
self._runner_process.kill()
await self._runner_process.wait()
await self._cleanup_runner_output_tasks()
async def _health_check_loop(self) -> None:
"""周期性健康检查 + 崩溃自动重启"""
while self._running:
@@ -382,6 +379,7 @@ class PluginSupervisor:
self._registered_plugins.clear()
try:
self._clear_runtime_state()
await self._spawn_runner()
except Exception as e:
logger.error(f"Runner 重启失败: {e}", exc_info=True)
@@ -407,3 +405,98 @@ class PluginSupervisor:
break
except Exception as e:
logger.error(f"健康检查异常: {e}")
async def _wait_for_runner_generation(self, expected_generation: int, timeout_sec: float) -> None:
"""等待指定代际的 Runner 完成连接。"""
deadline = asyncio.get_running_loop().time() + timeout_sec
while asyncio.get_running_loop().time() < deadline:
if self._rpc_server.is_connected and self._rpc_server.runner_generation >= expected_generation:
self._runner_generation = self._rpc_server.runner_generation
return
await asyncio.sleep(0.1)
raise TimeoutError(f"等待 Runner generation {expected_generation} 超时")
def _clear_runtime_state(self) -> None:
"""清空当前插件注册态。"""
self._component_registry.clear()
self._policy.clear()
self._registered_plugins.clear()
def _rebuild_runtime_state(self) -> None:
"""根据已记录的插件注册信息重建运行时状态。"""
self._component_registry.clear()
self._policy.clear()
for reg in self._registered_plugins.values():
self._policy.register_plugin(
plugin_id=reg.plugin_id,
generation=self._rpc_server.runner_generation,
capabilities=reg.capabilities_required or [],
)
self._component_registry.register_plugin_components(
plugin_id=reg.plugin_id,
components=[c.model_dump() for c in reg.components],
)
def _attach_runner_output_tasks(self, process: asyncio.subprocess.Process) -> None:
"""为 Runner 输出流创建排空任务,避免 PIPE 填满阻塞子进程。"""
streams = (
(process.stdout, "stdout"),
(process.stderr, "stderr"),
)
for stream, stream_name in streams:
if stream is None:
continue
task = asyncio.create_task(self._drain_runner_stream(stream, stream_name, process.pid))
self._runner_output_tasks.append(task)
task.add_done_callback(
lambda done_task: self._runner_output_tasks.remove(done_task)
if done_task in self._runner_output_tasks
else None
)
async def _drain_runner_stream(
self,
stream: asyncio.StreamReader,
stream_name: str,
pid: int,
) -> None:
"""持续消费 Runner 输出,避免 PIPE 回压导致子进程阻塞。"""
try:
while True:
line = await stream.readline()
if not line:
break
message = line.decode(errors="replace").rstrip()
if message:
logger.debug(f"[runner:{pid}:{stream_name}] {message}")
except asyncio.CancelledError:
raise
except Exception as e:
logger.debug(f"读取 Runner {stream_name} 失败: {e}")
async def _cleanup_runner_output_tasks(self) -> None:
"""等待并清理 Runner 输出任务。"""
tasks = list(self._runner_output_tasks)
self._runner_output_tasks.clear()
for task in tasks:
if not task.done():
task.cancel()
if tasks:
with contextlib.suppress(Exception):
await asyncio.gather(*tasks, return_exceptions=True)
@staticmethod
async def _terminate_process(
process: Optional[asyncio.subprocess.Process],
keep_process: Optional[asyncio.subprocess.Process] = None,
) -> None:
"""终止指定进程,但跳过需要保留的旧进程引用。"""
if process is None or process is keep_process or process.returncode is not None:
return
process.terminate()
try:
await asyncio.wait_for(process.wait(), timeout=10.0)
except asyncio.TimeoutError:
process.kill()
await process.wait()