feat: 增强插件能力检查,支持 generation 校验并添加清理功能

This commit is contained in:
DrSmoothl
2026-03-12 21:22:23 +08:00
parent df39fa7584
commit d0b56abdab
8 changed files with 466 additions and 51 deletions

View File

@@ -8,14 +8,13 @@
"""
import asyncio
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
import contextlib
from dataclasses import fields
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
from src.common.logger import get_logger
from src.core.types import EventType, MaiMessages
if TYPE_CHECKING:
from src.common.data_models.llm_data_model import LLMGenerationDataModel
logger = get_logger("event_bus")
# Handler 签名:接收 MaiMessages,返回 (continue, modified_message)
@@ -127,8 +126,7 @@ class EventBus:
async def cancel_handler_tasks(self, handler_name: str) -> None:
"""取消某个 handler 的所有运行中任务"""
tasks = self._running_tasks.pop(handler_name, [])
remaining = [t for t in tasks if not t.done()]
if remaining:
if remaining := [t for t in tasks if not t.done()]:
for t in remaining:
t.cancel()
await asyncio.gather(*remaining, return_exceptions=True)
@@ -156,17 +154,14 @@ class EventBus:
try:
if task.cancelled():
return
exc = task.exception()
if exc:
if exc := task.exception():
logger.error(f"handler {handler_name} 异步任务异常: {exc}")
except Exception:
pass
finally:
task_list = self._running_tasks.get(handler_name, [])
try:
with contextlib.suppress(ValueError):
task_list.remove(task)
except ValueError:
pass
async def _bridge_to_ipc_runtime(
self,
@@ -188,17 +183,29 @@ class EventBus:
event_value = event_type.value if isinstance(event_type, EventType) else str(event_type)
message_dict = message.to_dict() if message and hasattr(message, "to_dict") else None
new_continue, _ = await prm.bridge_event(
new_continue, modified_dict = await prm.bridge_event(
event_type_value=event_value,
message_dict=message_dict,
)
if not new_continue:
continue_flag = False
if modified_dict is not None and message is not None:
message = self._apply_ipc_message_update(message, modified_dict)
except Exception as e:
logger.warning(f"桥接事件到 IPC 运行时失败: {e}")
return continue_flag, message
@staticmethod
def _apply_ipc_message_update(message: MaiMessages, modified_dict: Dict[str, Any]) -> MaiMessages:
    """Write the message dict returned over IPC back onto a copy of ``message``.

    Args:
        message: The message to start from; the update is applied to the
            object returned by ``message.deepcopy()``.
        modified_dict: Mapping of attribute name to new value, as returned
            by the IPC side.

    Returns:
        The copy, with every entry of ``modified_dict`` whose key names a
        dataclass field of ``MaiMessages`` assigned via ``setattr``. Keys
        that are not fields are ignored.
    """
    result = message.deepcopy()
    known_names = {spec.name for spec in fields(MaiMessages)}
    for name in modified_dict:
        # Only declared fields may be overwritten; unknown keys are dropped.
        if name in known_names:
            setattr(result, name, modified_dict[name])
    return result
class _HandlerEntry:
"""内部 handler 条目"""

View File

@@ -65,10 +65,11 @@ class CapabilityService:
capability = req.capability
# 1. 权限校验
allowed, reason = self._policy.check_capability(plugin_id, capability)
allowed, reason = self._policy.check_capability(plugin_id, capability, envelope.generation)
if not allowed:
error_code = ErrorCode.E_GENERATION_MISMATCH if "generation 不匹配" in reason else ErrorCode.E_CAPABILITY_DENIED
return envelope.make_error_response(
ErrorCode.E_CAPABILITY_DENIED.value,
error_code.value,
reason,
)

View File

@@ -73,6 +73,13 @@ class ComponentRegistry:
# 按插件索引
self._by_plugin: Dict[str, List[RegisteredComponent]] = {}
def clear(self) -> None:
    """Empty all component registration state.

    Clears, in order, the component map, each per-type map held in
    ``self._by_type`` (the outer mapping and its keys are kept), and the
    per-plugin index. All containers are emptied in place.
    """
    for mapping in (self._components, *self._by_type.values(), self._by_plugin):
        mapping.clear()
# ──── 注册 / 注销 ─────────────────────────────────────────
def register_component(

View File

@@ -44,7 +44,11 @@ class PolicyEngine:
"""撤销插件的能力令牌"""
self._tokens.pop(plugin_id, None)
def check_capability(self, plugin_id: str, capability: str) -> Tuple[bool, str]:
def clear(self) -> None:
    """Remove every capability token, emptying ``self._tokens`` in place."""
    tokens = self._tokens
    tokens.clear()
def check_capability(self, plugin_id: str, capability: str, generation: Optional[int] = None) -> Tuple[bool, str]:
"""检查插件是否有权调用某项能力
Returns:
@@ -57,6 +61,9 @@ class PolicyEngine:
if capability not in token.capabilities:
return False, f"插件 {plugin_id} 未获授权能力: {capability}"
if generation is not None and token.generation != generation:
return False, f"插件 {plugin_id} generation 不匹配: {generation} != {token.generation}"
return True, ""
def get_token(self, plugin_id: str) -> Optional[CapabilityToken]:

View File

@@ -73,6 +73,10 @@ class RPCServer:
def session_token(self) -> str:
return self._session_token
@property
def runner_generation(self) -> int:
    """Read-only view of the generation number stored in ``self._runner_generation``."""
    generation = self._runner_generation
    return generation
@property
def is_connected(self) -> bool:
return self._connection is not None and not self._connection.is_closed
@@ -206,18 +210,23 @@ class RPCServer:
await conn.close()
return
# 握手成功,保存连接
old_connection = self._connection
self._connection = conn
logger.info(f"Runner 握手成功: runner_id={self._runner_id}, generation={self._runner_generation}")
if old_connection and old_connection is not conn and not old_connection.is_closed:
logger.info("检测到新 Runner 已接管连接,关闭旧连接")
await old_connection.close()
# 启动消息接收循环
try:
await self._recv_loop(conn)
except Exception as e:
logger.error(f"连接异常断开: {e}")
finally:
self._connection = None
self._runner_id = None
if self._connection is conn:
self._connection = None
self._runner_id = None
async def _handle_handshake(self, conn: Connection) -> bool:
"""处理 runner.hello 握手"""
@@ -295,17 +304,35 @@ class RPCServer:
if envelope.is_response():
self._handle_response(envelope)
elif envelope.is_request():
if not self._is_current_generation(envelope):
error_resp = envelope.make_error_response(
ErrorCode.E_GENERATION_MISMATCH.value,
f"过期 generation: {envelope.generation} != {self._runner_generation}",
)
await conn.send_frame(self._codec.encode_envelope(error_resp))
continue
# 异步处理请求(Runner 发来的能力调用)
task = asyncio.create_task(self._handle_request(envelope, conn))
self._tasks.append(task)
task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None)
elif envelope.is_event():
if not self._is_current_generation(envelope):
logger.warning(
f"忽略过期 generation 事件 {envelope.method}: {envelope.generation} != {self._runner_generation}"
)
continue
task = asyncio.create_task(self._handle_event(envelope))
self._tasks.append(task)
task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None)
def _handle_response(self, envelope: Envelope) -> None:
"""处理来自 Runner 的响应"""
if not self._is_current_generation(envelope):
logger.warning(
f"忽略过期 generation 响应 {envelope.method}: {envelope.generation} != {self._runner_generation}"
)
return
future = self._pending_requests.pop(envelope.request_id, None)
if future and not future.done():
if envelope.error:
@@ -313,6 +340,9 @@ class RPCServer:
else:
future.set_result(envelope)
def _is_current_generation(self, envelope: Envelope) -> bool:
return envelope.generation == self._runner_generation
async def _handle_request(self, envelope: Envelope, conn: Connection) -> None:
"""处理来自 Runner 的请求(通常是能力调用 cap.*)"""
handler = self._method_handlers.get(envelope.method)

View File

@@ -10,6 +10,7 @@
from typing import Any, Dict, List, Optional, Tuple
import asyncio
import contextlib
import os
import sys
@@ -75,6 +76,7 @@ class PluginSupervisor:
# 后台任务
self._health_task: Optional[asyncio.Task] = None
self._runner_output_tasks: List[asyncio.Task] = []
self._running = False
# 注册内部 RPC 方法
@@ -224,40 +226,26 @@ class PluginSupervisor:
# 保存旧进程引用
old_process = self._runner_process
old_registered_plugins = dict(self._registered_plugins)
expected_generation = self._rpc_server.runner_generation + 1
# 清理旧的组件注册,防止幽灵组件残留
for plugin_id in list(self._registered_plugins.keys()):
self._component_registry.remove_components_by_plugin(plugin_id)
self._policy.revoke_plugin(plugin_id)
self._registered_plugins.clear()
self._clear_runtime_state()
# 拉起新 Runner
await self._spawn_runner()
# 等待新 Runner 连接并完成握手
for _ in range(30): # 最多等待 30 秒
if self._rpc_server.is_connected:
break
await asyncio.sleep(1.0)
else:
logger.error("新 Runner 连接超时,回滚")
# 回滚:终止新进程
if self._runner_process and self._runner_process != old_process:
self._runner_process.terminate()
self._runner_process = old_process
return
# 健康检查
try:
await self._spawn_runner()
await self._wait_for_runner_generation(expected_generation, timeout_sec=30.0)
resp = await self._rpc_server.send_request("plugin.health", timeout_ms=5000)
health = HealthPayload.model_validate(resp.payload)
if not health.healthy:
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "新 Runner 健康检查失败")
except Exception as e:
logger.error(f"新 Runner 健康检查失败: {e},回滚")
if self._runner_process and self._runner_process != old_process:
self._runner_process.terminate()
await self._terminate_process(self._runner_process, old_process)
self._runner_process = old_process
self._registered_plugins = dict(old_registered_plugins)
self._rebuild_runtime_state()
return
# 关停旧 Runner
@@ -286,13 +274,19 @@ class PluginSupervisor:
except Exception as e:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e))
if envelope.generation != self._rpc_server.runner_generation:
return envelope.make_error_response(
ErrorCode.E_GENERATION_MISMATCH.value,
f"组件注册 generation 过期: {envelope.generation} != {self._rpc_server.runner_generation}",
)
# 记录注册信息
self._registered_plugins[reg.plugin_id] = reg
# 在策略引擎中注册插件
self._policy.register_plugin(
plugin_id=reg.plugin_id,
generation=self._runner_generation,
generation=envelope.generation,
capabilities=reg.capabilities_required or [],
)
@@ -329,7 +323,8 @@ class PluginSupervisor:
stderr=asyncio.subprocess.PIPE,
)
self._runner_generation += 1
self._attach_runner_output_tasks(self._runner_process)
self._runner_generation = self._rpc_server.runner_generation
logger.info(f"Runner 子进程已启动: pid={self._runner_process.pid}, generation={self._runner_generation}")
async def _shutdown_runner(self) -> None:
@@ -362,6 +357,8 @@ class PluginSupervisor:
self._runner_process.kill()
await self._runner_process.wait()
await self._cleanup_runner_output_tasks()
async def _health_check_loop(self) -> None:
"""周期性健康检查 + 崩溃自动重启"""
while self._running:
@@ -382,6 +379,7 @@ class PluginSupervisor:
self._registered_plugins.clear()
try:
self._clear_runtime_state()
await self._spawn_runner()
except Exception as e:
logger.error(f"Runner 重启失败: {e}", exc_info=True)
@@ -407,3 +405,98 @@ class PluginSupervisor:
break
except Exception as e:
logger.error(f"健康检查异常: {e}")
async def _wait_for_runner_generation(self, expected_generation: int, timeout_sec: float) -> None:
"""等待指定代际的 Runner 完成连接。"""
deadline = asyncio.get_running_loop().time() + timeout_sec
while asyncio.get_running_loop().time() < deadline:
if self._rpc_server.is_connected and self._rpc_server.runner_generation >= expected_generation:
self._runner_generation = self._rpc_server.runner_generation
return
await asyncio.sleep(0.1)
raise TimeoutError(f"等待 Runner generation {expected_generation} 超时")
def _clear_runtime_state(self) -> None:
"""清空当前插件注册态。"""
self._component_registry.clear()
self._policy.clear()
self._registered_plugins.clear()
def _rebuild_runtime_state(self) -> None:
"""根据已记录的插件注册信息重建运行时状态。"""
self._component_registry.clear()
self._policy.clear()
for reg in self._registered_plugins.values():
self._policy.register_plugin(
plugin_id=reg.plugin_id,
generation=self._rpc_server.runner_generation,
capabilities=reg.capabilities_required or [],
)
self._component_registry.register_plugin_components(
plugin_id=reg.plugin_id,
components=[c.model_dump() for c in reg.components],
)
def _attach_runner_output_tasks(self, process: asyncio.subprocess.Process) -> None:
"""为 Runner 输出流创建排空任务,避免 PIPE 填满阻塞子进程。"""
streams = (
(process.stdout, "stdout"),
(process.stderr, "stderr"),
)
for stream, stream_name in streams:
if stream is None:
continue
task = asyncio.create_task(self._drain_runner_stream(stream, stream_name, process.pid))
self._runner_output_tasks.append(task)
task.add_done_callback(
lambda done_task: self._runner_output_tasks.remove(done_task)
if done_task in self._runner_output_tasks
else None
)
async def _drain_runner_stream(
self,
stream: asyncio.StreamReader,
stream_name: str,
pid: int,
) -> None:
"""持续消费 Runner 输出,避免 PIPE 回压导致子进程阻塞。"""
try:
while True:
line = await stream.readline()
if not line:
break
message = line.decode(errors="replace").rstrip()
if message:
logger.debug(f"[runner:{pid}:{stream_name}] {message}")
except asyncio.CancelledError:
raise
except Exception as e:
logger.debug(f"读取 Runner {stream_name} 失败: {e}")
async def _cleanup_runner_output_tasks(self) -> None:
"""等待并清理 Runner 输出任务。"""
tasks = list(self._runner_output_tasks)
self._runner_output_tasks.clear()
for task in tasks:
if not task.done():
task.cancel()
if tasks:
with contextlib.suppress(Exception):
await asyncio.gather(*tasks, return_exceptions=True)
@staticmethod
async def _terminate_process(
process: Optional[asyncio.subprocess.Process],
keep_process: Optional[asyncio.subprocess.Process] = None,
) -> None:
"""终止指定进程,但跳过需要保留的旧进程引用。"""
if process is None or process is keep_process or process.returncode is not None:
return
process.terminate()
try:
await asyncio.wait_for(process.wait(), timeout=10.0)
except asyncio.TimeoutError:
process.kill()
await process.wait()

View File

@@ -90,21 +90,22 @@ class PluginRuntimeManager:
)
self._register_capability_impls(self._thirdparty_supervisor)
# 并行启动
coros = []
if self._builtin_supervisor:
coros.append(self._builtin_supervisor.start())
if self._thirdparty_supervisor:
coros.append(self._thirdparty_supervisor.start())
started_supervisors = []
try:
await asyncio.gather(*coros)
if self._builtin_supervisor:
await self._builtin_supervisor.start()
started_supervisors.append(self._builtin_supervisor)
if self._thirdparty_supervisor:
await self._thirdparty_supervisor.start()
started_supervisors.append(self._thirdparty_supervisor)
self._started = True
logger.info(
f"插件运行时已启动 — 内置: {builtin_dirs or ''}, 第三方: {thirdparty_dirs or ''}"
)
except Exception as e:
logger.error(f"插件运行时启动失败: {e}", exc_info=True)
await asyncio.gather(*(sv.stop() for sv in started_supervisors), return_exceptions=True)
self._started = False
self._builtin_supervisor = None
self._thirdparty_supervisor = None