feat: 增强插件能力检查,支持 generation 校验并添加清理功能
This commit is contained in:
@@ -8,14 +8,13 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
import contextlib
|
||||
from dataclasses import fields
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.core.types import EventType, MaiMessages
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.llm_data_model import LLMGenerationDataModel
|
||||
|
||||
logger = get_logger("event_bus")
|
||||
|
||||
# Handler 签名:接收 MaiMessages,返回 (continue, modified_message)
|
||||
@@ -127,8 +126,7 @@ class EventBus:
|
||||
async def cancel_handler_tasks(self, handler_name: str) -> None:
    """Cancel every still-running task that belongs to one handler.

    The handler's entry is removed from ``self._running_tasks``; tasks in it
    that have not finished are cancelled and then awaited. An unknown
    handler name is a no-op.

    Args:
        handler_name: Key of the handler in ``self._running_tasks``.
    """
    tasks = self._running_tasks.pop(handler_name, [])
    if remaining := [t for t in tasks if not t.done()]:
        for t in remaining:
            t.cancel()
        # return_exceptions=True collects the resulting CancelledError (and
        # any handler exception) instead of raising it into the caller.
        await asyncio.gather(*remaining, return_exceptions=True)
|
||||
@@ -156,17 +154,14 @@ class EventBus:
|
||||
try:
|
||||
if task.cancelled():
|
||||
return
|
||||
exc = task.exception()
|
||||
if exc:
|
||||
if exc := task.exception():
|
||||
logger.error(f"handler {handler_name} 异步任务异常: {exc}")
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
task_list = self._running_tasks.get(handler_name, [])
|
||||
try:
|
||||
with contextlib.suppress(ValueError):
|
||||
task_list.remove(task)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
async def _bridge_to_ipc_runtime(
|
||||
self,
|
||||
@@ -188,17 +183,29 @@ class EventBus:
|
||||
event_value = event_type.value if isinstance(event_type, EventType) else str(event_type)
|
||||
message_dict = message.to_dict() if message and hasattr(message, "to_dict") else None
|
||||
|
||||
new_continue, _ = await prm.bridge_event(
|
||||
new_continue, modified_dict = await prm.bridge_event(
|
||||
event_type_value=event_value,
|
||||
message_dict=message_dict,
|
||||
)
|
||||
if not new_continue:
|
||||
continue_flag = False
|
||||
if modified_dict is not None and message is not None:
|
||||
message = self._apply_ipc_message_update(message, modified_dict)
|
||||
except Exception as e:
|
||||
logger.warning(f"桥接事件到 IPC 运行时失败: {e}")
|
||||
|
||||
return continue_flag, message
|
||||
|
||||
@staticmethod
def _apply_ipc_message_update(message: MaiMessages, modified_dict: Dict[str, Any]) -> MaiMessages:
    """Write the message dict returned over IPC back onto a MaiMessages copy.

    Args:
        message: The current message; the update is applied to the object
            returned by ``message.deepcopy()``.
        modified_dict: Mapping of field name to new value.

    Returns:
        The copy, with every key that names a ``MaiMessages`` dataclass field
        overwritten. Keys that are not dataclass fields are ignored.
    """
    updated_message = message.deepcopy()
    valid_fields = {field.name for field in fields(MaiMessages)}
    for key, value in modified_dict.items():
        if key in valid_fields:
            # NOTE(review): values are assigned as received; nested fields are
            # not converted back from plain dicts here — confirm callers expect that.
            setattr(updated_message, key, value)
    return updated_message
|
||||
|
||||
|
||||
class _HandlerEntry:
|
||||
"""内部 handler 条目"""
|
||||
|
||||
@@ -65,10 +65,11 @@ class CapabilityService:
|
||||
capability = req.capability
|
||||
|
||||
# 1. 权限校验
|
||||
allowed, reason = self._policy.check_capability(plugin_id, capability)
|
||||
allowed, reason = self._policy.check_capability(plugin_id, capability, envelope.generation)
|
||||
if not allowed:
|
||||
error_code = ErrorCode.E_GENERATION_MISMATCH if "generation 不匹配" in reason else ErrorCode.E_CAPABILITY_DENIED
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_CAPABILITY_DENIED.value,
|
||||
error_code.value,
|
||||
reason,
|
||||
)
|
||||
|
||||
|
||||
@@ -73,6 +73,13 @@ class ComponentRegistry:
|
||||
# 按插件索引
|
||||
self._by_plugin: Dict[str, List[RegisteredComponent]] = {}
|
||||
|
||||
def clear(self) -> None:
    """Drop all component registration state.

    Empties the component table and the per-plugin index. Each per-type
    bucket is emptied in place, so the keys of ``self._by_type`` are kept.
    """
    self._components.clear()
    for type_dict in self._by_type.values():
        type_dict.clear()
    self._by_plugin.clear()
|
||||
|
||||
# ──── 注册 / 注销 ─────────────────────────────────────────
|
||||
|
||||
def register_component(
|
||||
|
||||
@@ -44,7 +44,11 @@ class PolicyEngine:
|
||||
"""撤销插件的能力令牌"""
|
||||
self._tokens.pop(plugin_id, None)
|
||||
|
||||
def check_capability(self, plugin_id: str, capability: str) -> Tuple[bool, str]:
|
||||
def clear(self) -> None:
    """Remove every capability token held in ``self._tokens``."""
    self._tokens.clear()
|
||||
|
||||
def check_capability(self, plugin_id: str, capability: str, generation: Optional[int] = None) -> Tuple[bool, str]:
|
||||
"""检查插件是否有权调用某项能力
|
||||
|
||||
Returns:
|
||||
@@ -57,6 +61,9 @@ class PolicyEngine:
|
||||
if capability not in token.capabilities:
|
||||
return False, f"插件 {plugin_id} 未获授权能力: {capability}"
|
||||
|
||||
if generation is not None and token.generation != generation:
|
||||
return False, f"插件 {plugin_id} generation 不匹配: {generation} != {token.generation}"
|
||||
|
||||
return True, ""
|
||||
|
||||
def get_token(self, plugin_id: str) -> Optional[CapabilityToken]:
|
||||
|
||||
@@ -73,6 +73,10 @@ class RPCServer:
|
||||
def session_token(self) -> str:
|
||||
return self._session_token
|
||||
|
||||
@property
def runner_generation(self) -> int:
    """Return the runner generation currently stored on this server."""
    return self._runner_generation
|
||||
|
||||
@property
def is_connected(self) -> bool:
    """Return True when a connection is stored and it is not closed."""
    return self._connection is not None and not self._connection.is_closed
|
||||
@@ -206,18 +210,23 @@ class RPCServer:
|
||||
await conn.close()
|
||||
return
|
||||
|
||||
# 握手成功,保存连接
|
||||
old_connection = self._connection
|
||||
self._connection = conn
|
||||
logger.info(f"Runner 握手成功: runner_id={self._runner_id}, generation={self._runner_generation}")
|
||||
|
||||
if old_connection and old_connection is not conn and not old_connection.is_closed:
|
||||
logger.info("检测到新 Runner 已接管连接,关闭旧连接")
|
||||
await old_connection.close()
|
||||
|
||||
# 启动消息接收循环
|
||||
try:
|
||||
await self._recv_loop(conn)
|
||||
except Exception as e:
|
||||
logger.error(f"连接异常断开: {e}")
|
||||
finally:
|
||||
self._connection = None
|
||||
self._runner_id = None
|
||||
if self._connection is conn:
|
||||
self._connection = None
|
||||
self._runner_id = None
|
||||
|
||||
async def _handle_handshake(self, conn: Connection) -> bool:
|
||||
"""处理 runner.hello 握手"""
|
||||
@@ -295,17 +304,35 @@ class RPCServer:
|
||||
if envelope.is_response():
|
||||
self._handle_response(envelope)
|
||||
elif envelope.is_request():
|
||||
if not self._is_current_generation(envelope):
|
||||
error_resp = envelope.make_error_response(
|
||||
ErrorCode.E_GENERATION_MISMATCH.value,
|
||||
f"过期 generation: {envelope.generation} != {self._runner_generation}",
|
||||
)
|
||||
await conn.send_frame(self._codec.encode_envelope(error_resp))
|
||||
continue
|
||||
# 异步处理请求(Runner 发来的能力调用)
|
||||
task = asyncio.create_task(self._handle_request(envelope, conn))
|
||||
self._tasks.append(task)
|
||||
task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None)
|
||||
elif envelope.is_event():
|
||||
if not self._is_current_generation(envelope):
|
||||
logger.warning(
|
||||
f"忽略过期 generation 事件 {envelope.method}: {envelope.generation} != {self._runner_generation}"
|
||||
)
|
||||
continue
|
||||
task = asyncio.create_task(self._handle_event(envelope))
|
||||
self._tasks.append(task)
|
||||
task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None)
|
||||
|
||||
def _handle_response(self, envelope: Envelope) -> None:
|
||||
"""处理来自 Runner 的响应"""
|
||||
if not self._is_current_generation(envelope):
|
||||
logger.warning(
|
||||
f"忽略过期 generation 响应 {envelope.method}: {envelope.generation} != {self._runner_generation}"
|
||||
)
|
||||
return
|
||||
|
||||
future = self._pending_requests.pop(envelope.request_id, None)
|
||||
if future and not future.done():
|
||||
if envelope.error:
|
||||
@@ -313,6 +340,9 @@ class RPCServer:
|
||||
else:
|
||||
future.set_result(envelope)
|
||||
|
||||
def _is_current_generation(self, envelope: Envelope) -> bool:
    """Return True when the envelope's generation equals the stored runner generation."""
    return envelope.generation == self._runner_generation
|
||||
|
||||
async def _handle_request(self, envelope: Envelope, conn: Connection) -> None:
|
||||
"""处理来自 Runner 的请求(通常是能力调用 cap.*)"""
|
||||
handler = self._method_handlers.get(envelope.method)
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import os
|
||||
import sys
|
||||
|
||||
@@ -75,6 +76,7 @@ class PluginSupervisor:
|
||||
|
||||
# 后台任务
|
||||
self._health_task: Optional[asyncio.Task] = None
|
||||
self._runner_output_tasks: List[asyncio.Task] = []
|
||||
self._running = False
|
||||
|
||||
# 注册内部 RPC 方法
|
||||
@@ -224,40 +226,26 @@ class PluginSupervisor:
|
||||
|
||||
# 保存旧进程引用
|
||||
old_process = self._runner_process
|
||||
old_registered_plugins = dict(self._registered_plugins)
|
||||
expected_generation = self._rpc_server.runner_generation + 1
|
||||
|
||||
# 清理旧的组件注册,防止幽灵组件残留
|
||||
for plugin_id in list(self._registered_plugins.keys()):
|
||||
self._component_registry.remove_components_by_plugin(plugin_id)
|
||||
self._policy.revoke_plugin(plugin_id)
|
||||
self._registered_plugins.clear()
|
||||
self._clear_runtime_state()
|
||||
|
||||
# 拉起新 Runner
|
||||
await self._spawn_runner()
|
||||
|
||||
# 等待新 Runner 连接并完成握手
|
||||
for _ in range(30): # 最多等待 30 秒
|
||||
if self._rpc_server.is_connected:
|
||||
break
|
||||
await asyncio.sleep(1.0)
|
||||
else:
|
||||
logger.error("新 Runner 连接超时,回滚")
|
||||
# 回滚:终止新进程
|
||||
if self._runner_process and self._runner_process != old_process:
|
||||
self._runner_process.terminate()
|
||||
self._runner_process = old_process
|
||||
return
|
||||
|
||||
# 健康检查
|
||||
try:
|
||||
await self._spawn_runner()
|
||||
await self._wait_for_runner_generation(expected_generation, timeout_sec=30.0)
|
||||
resp = await self._rpc_server.send_request("plugin.health", timeout_ms=5000)
|
||||
health = HealthPayload.model_validate(resp.payload)
|
||||
if not health.healthy:
|
||||
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "新 Runner 健康检查失败")
|
||||
except Exception as e:
|
||||
logger.error(f"新 Runner 健康检查失败: {e},回滚")
|
||||
if self._runner_process and self._runner_process != old_process:
|
||||
self._runner_process.terminate()
|
||||
await self._terminate_process(self._runner_process, old_process)
|
||||
self._runner_process = old_process
|
||||
self._registered_plugins = dict(old_registered_plugins)
|
||||
self._rebuild_runtime_state()
|
||||
return
|
||||
|
||||
# 关停旧 Runner
|
||||
@@ -286,13 +274,19 @@ class PluginSupervisor:
|
||||
except Exception as e:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e))
|
||||
|
||||
if envelope.generation != self._rpc_server.runner_generation:
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_GENERATION_MISMATCH.value,
|
||||
f"组件注册 generation 过期: {envelope.generation} != {self._rpc_server.runner_generation}",
|
||||
)
|
||||
|
||||
# 记录注册信息
|
||||
self._registered_plugins[reg.plugin_id] = reg
|
||||
|
||||
# 在策略引擎中注册插件
|
||||
self._policy.register_plugin(
|
||||
plugin_id=reg.plugin_id,
|
||||
generation=self._runner_generation,
|
||||
generation=envelope.generation,
|
||||
capabilities=reg.capabilities_required or [],
|
||||
)
|
||||
|
||||
@@ -329,7 +323,8 @@ class PluginSupervisor:
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
|
||||
self._runner_generation += 1
|
||||
self._attach_runner_output_tasks(self._runner_process)
|
||||
self._runner_generation = self._rpc_server.runner_generation
|
||||
logger.info(f"Runner 子进程已启动: pid={self._runner_process.pid}, generation={self._runner_generation}")
|
||||
|
||||
async def _shutdown_runner(self) -> None:
|
||||
@@ -362,6 +357,8 @@ class PluginSupervisor:
|
||||
self._runner_process.kill()
|
||||
await self._runner_process.wait()
|
||||
|
||||
await self._cleanup_runner_output_tasks()
|
||||
|
||||
async def _health_check_loop(self) -> None:
|
||||
"""周期性健康检查 + 崩溃自动重启"""
|
||||
while self._running:
|
||||
@@ -382,6 +379,7 @@ class PluginSupervisor:
|
||||
self._registered_plugins.clear()
|
||||
|
||||
try:
|
||||
self._clear_runtime_state()
|
||||
await self._spawn_runner()
|
||||
except Exception as e:
|
||||
logger.error(f"Runner 重启失败: {e}", exc_info=True)
|
||||
@@ -407,3 +405,98 @@ class PluginSupervisor:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"健康检查异常: {e}")
|
||||
|
||||
async def _wait_for_runner_generation(self, expected_generation: int, timeout_sec: float) -> None:
|
||||
"""等待指定代际的 Runner 完成连接。"""
|
||||
deadline = asyncio.get_running_loop().time() + timeout_sec
|
||||
while asyncio.get_running_loop().time() < deadline:
|
||||
if self._rpc_server.is_connected and self._rpc_server.runner_generation >= expected_generation:
|
||||
self._runner_generation = self._rpc_server.runner_generation
|
||||
return
|
||||
await asyncio.sleep(0.1)
|
||||
raise TimeoutError(f"等待 Runner generation {expected_generation} 超时")
|
||||
|
||||
def _clear_runtime_state(self) -> None:
|
||||
"""清空当前插件注册态。"""
|
||||
self._component_registry.clear()
|
||||
self._policy.clear()
|
||||
self._registered_plugins.clear()
|
||||
|
||||
def _rebuild_runtime_state(self) -> None:
|
||||
"""根据已记录的插件注册信息重建运行时状态。"""
|
||||
self._component_registry.clear()
|
||||
self._policy.clear()
|
||||
for reg in self._registered_plugins.values():
|
||||
self._policy.register_plugin(
|
||||
plugin_id=reg.plugin_id,
|
||||
generation=self._rpc_server.runner_generation,
|
||||
capabilities=reg.capabilities_required or [],
|
||||
)
|
||||
self._component_registry.register_plugin_components(
|
||||
plugin_id=reg.plugin_id,
|
||||
components=[c.model_dump() for c in reg.components],
|
||||
)
|
||||
|
||||
def _attach_runner_output_tasks(self, process: asyncio.subprocess.Process) -> None:
|
||||
"""为 Runner 输出流创建排空任务,避免 PIPE 填满阻塞子进程。"""
|
||||
streams = (
|
||||
(process.stdout, "stdout"),
|
||||
(process.stderr, "stderr"),
|
||||
)
|
||||
for stream, stream_name in streams:
|
||||
if stream is None:
|
||||
continue
|
||||
task = asyncio.create_task(self._drain_runner_stream(stream, stream_name, process.pid))
|
||||
self._runner_output_tasks.append(task)
|
||||
task.add_done_callback(
|
||||
lambda done_task: self._runner_output_tasks.remove(done_task)
|
||||
if done_task in self._runner_output_tasks
|
||||
else None
|
||||
)
|
||||
|
||||
async def _drain_runner_stream(
|
||||
self,
|
||||
stream: asyncio.StreamReader,
|
||||
stream_name: str,
|
||||
pid: int,
|
||||
) -> None:
|
||||
"""持续消费 Runner 输出,避免 PIPE 回压导致子进程阻塞。"""
|
||||
try:
|
||||
while True:
|
||||
line = await stream.readline()
|
||||
if not line:
|
||||
break
|
||||
message = line.decode(errors="replace").rstrip()
|
||||
if message:
|
||||
logger.debug(f"[runner:{pid}:{stream_name}] {message}")
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.debug(f"读取 Runner {stream_name} 失败: {e}")
|
||||
|
||||
async def _cleanup_runner_output_tasks(self) -> None:
|
||||
"""等待并清理 Runner 输出任务。"""
|
||||
tasks = list(self._runner_output_tasks)
|
||||
self._runner_output_tasks.clear()
|
||||
for task in tasks:
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
if tasks:
|
||||
with contextlib.suppress(Exception):
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
@staticmethod
|
||||
async def _terminate_process(
|
||||
process: Optional[asyncio.subprocess.Process],
|
||||
keep_process: Optional[asyncio.subprocess.Process] = None,
|
||||
) -> None:
|
||||
"""终止指定进程,但跳过需要保留的旧进程引用。"""
|
||||
if process is None or process is keep_process or process.returncode is not None:
|
||||
return
|
||||
|
||||
process.terminate()
|
||||
try:
|
||||
await asyncio.wait_for(process.wait(), timeout=10.0)
|
||||
except asyncio.TimeoutError:
|
||||
process.kill()
|
||||
await process.wait()
|
||||
|
||||
@@ -90,21 +90,22 @@ class PluginRuntimeManager:
|
||||
)
|
||||
self._register_capability_impls(self._thirdparty_supervisor)
|
||||
|
||||
# 并行启动
|
||||
coros = []
|
||||
if self._builtin_supervisor:
|
||||
coros.append(self._builtin_supervisor.start())
|
||||
if self._thirdparty_supervisor:
|
||||
coros.append(self._thirdparty_supervisor.start())
|
||||
|
||||
started_supervisors = []
|
||||
try:
|
||||
await asyncio.gather(*coros)
|
||||
if self._builtin_supervisor:
|
||||
await self._builtin_supervisor.start()
|
||||
started_supervisors.append(self._builtin_supervisor)
|
||||
if self._thirdparty_supervisor:
|
||||
await self._thirdparty_supervisor.start()
|
||||
started_supervisors.append(self._thirdparty_supervisor)
|
||||
self._started = True
|
||||
logger.info(
|
||||
f"插件运行时已启动 — 内置: {builtin_dirs or '无'}, 第三方: {thirdparty_dirs or '无'}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"插件运行时启动失败: {e}", exc_info=True)
|
||||
await asyncio.gather(*(sv.stop() for sv in started_supervisors), return_exceptions=True)
|
||||
self._started = False
|
||||
self._builtin_supervisor = None
|
||||
self._thirdparty_supervisor = None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user