feat: 实现插件能力令牌的多版本支持,优化插件热重载逻辑

This commit is contained in:
DrSmoothl
2026-03-13 16:54:01 +08:00
parent 324432ff92
commit 8da1b6d93f
6 changed files with 389 additions and 22 deletions

View File

@@ -24,7 +24,7 @@ class PolicyEngine:
"""
def __init__(self) -> None:
self._tokens: Dict[str, CapabilityToken] = {}
self._tokens: Dict[str, Dict[int, CapabilityToken]] = {}
def register_plugin(
self,
@@ -38,12 +38,22 @@ class PolicyEngine:
generation=generation,
capabilities=set(capabilities),
)
self._tokens[plugin_id] = token
self._tokens.setdefault(plugin_id, {})[generation] = token
return token
def revoke_plugin(self, plugin_id: str) -> None:
"""撤销插件的能力令牌"""
self._tokens.pop(plugin_id, None)
def revoke_plugin(self, plugin_id: str, generation: Optional[int] = None) -> None:
"""Revoke capability tokens held by a plugin.

With ``generation`` omitted (``None``) every token stored for ``plugin_id``
is dropped. Otherwise only the token of that generation is removed.
Unknown plugins and unknown generations are ignored silently.
"""
# No generation given: drop the plugin's whole generation -> token map.
if generation is None:
self._tokens.pop(plugin_id, None)
return
generations = self._tokens.get(plugin_id)
if generations is None:
return
generations.pop(generation, None)
# Remove the now-empty per-plugin map so no empty entry is left behind.
if not generations:
self._tokens.pop(plugin_id, None)
def clear(self) -> None:
"""清空所有能力令牌。"""
@@ -55,10 +65,18 @@ class PolicyEngine:
Returns:
(allowed, reason)
"""
token = self._tokens.get(plugin_id)
if token is None:
generations = self._tokens.get(plugin_id)
if not generations:
return False, f"插件 {plugin_id} 未注册能力令牌"
if generation is None:
token = generations[max(generations)]
else:
token = generations.get(generation)
if token is None:
active_generation = max(generations)
return False, f"插件 {plugin_id} generation 不匹配: {generation} != {active_generation}"
if capability not in token.capabilities:
return False, f"插件 {plugin_id} 未获授权能力: {capability}"
@@ -69,7 +87,10 @@ class PolicyEngine:
def get_token(self, plugin_id: str) -> Optional[CapabilityToken]:
"""获取插件的能力令牌"""
return self._tokens.get(plugin_id)
generations = self._tokens.get(plugin_id)
if not generations:
return None
return generations[max(generations)]
def list_plugins(self) -> List[str]:
"""列出所有已注册的插件"""

View File

@@ -25,11 +25,13 @@ from src.plugin_runtime.host.policy_engine import PolicyEngine
from src.plugin_runtime.host.rpc_server import RPCServer
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor, WorkflowContext, WorkflowResult
from src.plugin_runtime.protocol.envelope import (
BootstrapPluginPayload,
ConfigUpdatedPayload,
Envelope,
HealthPayload,
LogBatchPayload,
RegisterComponentsPayload,
RunnerReadyPayload,
ShutdownPayload,
)
from src.plugin_runtime.protocol.errors import ErrorCode, RPCError
@@ -137,6 +139,8 @@ class PluginSupervisor:
# 已注册的插件组件信息
self._registered_plugins: Dict[str, RegisterComponentsPayload] = {}
self._staged_registered_plugins: Dict[str, RegisterComponentsPayload] = {}
self._runner_ready_events: Dict[int, asyncio.Event] = {}
self._runner_ready_payloads: Dict[int, RunnerReadyPayload] = {}
# 后台任务
self._health_task: Optional[asyncio.Task] = None
@@ -255,12 +259,15 @@ class PluginSupervisor:
# 拉起 Runner 进程
await self._spawn_runner()
# 等待 Runner 完成连接,避免 start() 返回时 Runner 尚未就绪
# 等待 Runner 完成连接和初始化,避免 start() 返回时 Runner 尚未就绪
try:
await self._wait_for_runner_generation(expected_generation, timeout_sec=self._runner_spawn_timeout)
await self._wait_for_runner_ready(expected_generation, timeout_sec=self._runner_spawn_timeout)
except TimeoutError:
if not self._rpc_server.is_connected:
logger.warning(f"Runner 未在 {self._runner_spawn_timeout}s 内完成连接,后续操作可能失败")
else:
logger.warning(f"Runner 未在 {self._runner_spawn_timeout}s 内完成初始化,后续操作可能失败")
# 启动健康检查
self._health_task = asyncio.create_task(self._health_check_loop())
@@ -306,7 +313,7 @@ class PluginSupervisor:
timeout_ms=timeout_ms,
)
async def reload_plugins(self, reason: str = "manual") -> None:
async def reload_plugins(self, reason: str = "manual") -> bool:
"""热重载所有插件(进程级 generation 切换)
1. 拉起新 Runner
@@ -341,6 +348,7 @@ class PluginSupervisor:
timeout_sec=self._runner_spawn_timeout,
allow_staged=True,
)
await self._wait_for_runner_ready(expected_generation, timeout_sec=self._runner_spawn_timeout)
resp = await self._rpc_server.send_request(
"plugin.health",
timeout_ms=5000,
@@ -359,7 +367,7 @@ class PluginSupervisor:
self._staged_registered_plugins.clear()
self._registered_plugins = dict(old_registered_plugins)
self._rebuild_runtime_state()
return
return False
self._runner_generation = self._rpc_server.runner_generation
self._registered_plugins = dict(self._staged_registered_plugins)
@@ -375,6 +383,7 @@ class PluginSupervisor:
old_process.kill()
logger.info("热重载完成")
return True
async def notify_plugin_config_updated(
self,
@@ -405,11 +414,39 @@ class PluginSupervisor:
"""注册 Host 端的 RPC 方法处理器"""
# Runner -> Host 的能力调用统一走 capability_service
self._rpc_server.register_method("cap.request", self._capability_service.handle_capability_request)
self._rpc_server.register_method("plugin.bootstrap", self._handle_bootstrap_plugin)
# 插件注册
self._rpc_server.register_method("plugin.register_components", self._handle_register_components)
self._rpc_server.register_method("runner.ready", self._handle_runner_ready)
# Runner 日志批量上报
self._rpc_server.register_method("runner.log_batch", self._log_bridge.handle_log_batch)
async def _handle_bootstrap_plugin(self, envelope: Envelope) -> Envelope:
"""Handle a ``plugin.bootstrap`` request; only the capability token is synced.

Returns an ``E_BAD_PAYLOAD`` error response when the payload fails
validation, an ``E_GENERATION_MISMATCH`` error response when the envelope's
generation is neither the active nor the staged one, and otherwise a
response with ``{"accepted": True}``.
"""
try:
bootstrap = BootstrapPluginPayload.model_validate(envelope.payload)
except Exception as e:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e))
# Accept the request only from the active runner or the one being staged.
active_generation = self._rpc_server.runner_generation
staged_generation = self._rpc_server.staged_generation
if envelope.generation not in {active_generation, staged_generation}:
return envelope.make_error_response(
ErrorCode.E_GENERATION_MISMATCH.value,
f"插件 bootstrap generation 过期: {envelope.generation} 不在已知代际中",
)
# A non-empty capability list (re)issues the token for this generation;
# an empty list revokes whatever token this generation held.
if bootstrap.capabilities_required:
self._policy.register_plugin(
plugin_id=bootstrap.plugin_id,
generation=envelope.generation,
capabilities=bootstrap.capabilities_required,
)
else:
self._policy.revoke_plugin(bootstrap.plugin_id, generation=envelope.generation)
return envelope.make_response(payload={"accepted": True})
async def _handle_register_components(self, envelope: Envelope) -> Envelope:
"""处理插件组件注册请求"""
try:
@@ -458,6 +495,22 @@ class PluginSupervisor:
return envelope.make_response(payload={"accepted": True})
async def _handle_runner_ready(self, envelope: Envelope) -> Envelope:
    """Handle the ``runner.ready`` signal sent once a Runner finished initializing.

    The validated payload is stored under the envelope's generation and the
    matching ready event is set, so a coroutine waiting on that generation
    can proceed.

    Args:
        envelope: Incoming request; ``envelope.payload`` must validate as
            ``RunnerReadyPayload``.

    Returns:
        An ``E_BAD_PAYLOAD`` error response when validation fails, otherwise
        a response carrying ``{"accepted": True}``.
    """
    try:
        ready = RunnerReadyPayload.model_validate(envelope.payload)
    except Exception as e:
        return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e))

    # The event may already exist if a waiter registered it first; reuse it.
    event = self._runner_ready_events.setdefault(envelope.generation, asyncio.Event())
    # Store the payload before setting the event so a woken waiter sees it.
    self._runner_ready_payloads[envelope.generation] = ready
    event.set()

    # Adjacent f-strings are concatenated verbatim, so the separator between
    # the two counters has to be written explicitly.
    logger.info(
        f"Runner generation={envelope.generation} 已就绪,成功插件数: {len(ready.loaded_plugins)},"
        f"失败插件数: {len(ready.failed_plugins)}"
    )
    return envelope.make_response(payload={"accepted": True})
async def _spawn_runner(self) -> None:
"""拉起 Runner 子进程"""
runner_module = "src.plugin_runtime.runner.runner_main"
@@ -582,6 +635,12 @@ class PluginSupervisor:
await asyncio.sleep(0.1)
raise TimeoutError(f"等待 Runner generation {expected_generation} 超时")
async def _wait_for_runner_ready(self, expected_generation: int, timeout_sec: float) -> RunnerReadyPayload:
    """Wait until the Runner of the given generation reports it is initialized.

    Args:
        expected_generation: Generation whose ready event is awaited.
        timeout_sec: Maximum time to wait, in seconds.

    Returns:
        The payload stored for that generation, or an empty
        ``RunnerReadyPayload`` if none was stored.

    Raises:
        TimeoutError: The ready event was not set within ``timeout_sec``.
    """
    # Create the event up front so a ready signal arriving later finds it.
    event = self._runner_ready_events.setdefault(expected_generation, asyncio.Event())
    try:
        await asyncio.wait_for(event.wait(), timeout=timeout_sec)
    except asyncio.TimeoutError as exc:
        # wait_for raises without a message (and, before Python 3.11, with a
        # type distinct from the builtin). Re-raise the builtin with context,
        # matching the generation-wait helper's error style.
        raise TimeoutError(f"等待 Runner generation {expected_generation} 就绪超时") from exc
    return self._runner_ready_payloads.get(expected_generation, RunnerReadyPayload())
def _clear_runtime_state(self) -> None:
"""清空当前插件注册态。"""
self._component_registry.clear()

View File

@@ -1742,8 +1742,10 @@ class PluginRuntimeManager:
for sv in mgr.supervisors:
if plugin_name in sv._registered_plugins:
try:
await sv.reload_plugins(reason=f"load {plugin_name}")
return {"success": True, "count": 1}
reloaded = await sv.reload_plugins(reason=f"load {plugin_name}")
if reloaded:
return {"success": True, "count": 1}
return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
except Exception as e:
logger.error(f"[cap.component.load_plugin] 热重载失败: {e}")
return {"success": False, "error": str(e)}
@@ -1753,8 +1755,10 @@ class PluginRuntimeManager:
for pdir in sv._plugin_dirs:
if os.path.isdir(os.path.join(pdir, plugin_name)):
try:
await sv.reload_plugins(reason=f"load {plugin_name}")
return {"success": True, "count": 1}
reloaded = await sv.reload_plugins(reason=f"load {plugin_name}")
if reloaded:
return {"success": True, "count": 1}
return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
except Exception as e:
logger.error(f"[cap.component.load_plugin] 热重载失败: {e}")
return {"success": False, "error": str(e)}
@@ -1783,8 +1787,10 @@ class PluginRuntimeManager:
for sv in mgr.supervisors:
if plugin_name in sv._registered_plugins:
try:
await sv.reload_plugins(reason=f"reload {plugin_name}")
return {"success": True}
reloaded = await sv.reload_plugins(reason=f"reload {plugin_name}")
if reloaded:
return {"success": True}
return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
except Exception as e:
logger.error(f"[cap.component.reload_plugin] 热重载失败: {e}")
return {"success": False, "error": str(e)}

View File

@@ -146,6 +146,14 @@ class RegisterComponentsPayload(BaseModel):
capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表")
class BootstrapPluginPayload(BaseModel):
"""Request payload of the ``plugin.bootstrap`` method."""
# Identifier of the plugin being bootstrapped (required, no default).
plugin_id: str = Field(description="插件 ID")
# Plugin version string; defaults to "1.0.0" when omitted.
plugin_version: str = Field(default="1.0.0", description="插件版本")
# Capabilities the plugin asks for; defaults to an empty list.
capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表")
# ─── 调用消息 ──────────────────────────────────────────────────────
@@ -191,6 +199,13 @@ class HealthPayload(BaseModel):
uptime_ms: int = Field(default=0, description="运行时长(ms)")
class RunnerReadyPayload(BaseModel):
"""Request payload of the ``runner.ready`` method."""
# Plugins that completed initialization; defaults to an empty list.
loaded_plugins: List[str] = Field(default_factory=list, description="已完成初始化的插件列表")
# Plugins whose initialization failed; defaults to an empty list.
failed_plugins: List[str] = Field(default_factory=list, description="初始化失败的插件列表")
# ─── 配置更新 ──────────────────────────────────────────────────────

View File

@@ -26,12 +26,14 @@ import tomllib
from src.common.logger import get_console_handler, get_logger, initialize_logging
from src.plugin_runtime import ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN
from src.plugin_runtime.protocol.envelope import (
BootstrapPluginPayload,
ComponentDeclaration,
Envelope,
HealthPayload,
InvokePayload,
InvokeResultPayload,
RegisterComponentsPayload,
RunnerReadyPayload,
)
from src.plugin_runtime.protocol.errors import ErrorCode
from src.plugin_runtime.runner.log_handler import RunnerIPCLogHandler
@@ -99,6 +101,9 @@ class PluginRunner:
instance = meta.instance
self._inject_context(meta.plugin_id, instance)
self._apply_plugin_config(meta)
if not await self._bootstrap_plugin(meta):
failed_plugins.add(meta.plugin_id)
continue
if hasattr(instance, "on_load"):
try:
ret = instance.on_load()
@@ -107,12 +112,19 @@ class PluginRunner:
except Exception as e:
logger.error(f"插件 {meta.plugin_id} on_load 失败,跳过注册: {e}", exc_info=True)
failed_plugins.add(meta.plugin_id)
await self._deactivate_plugin(meta)
# 5. 向 Host 注册所有插件的组件(跳过 on_load 失败的插件)
for meta in plugins:
if meta.plugin_id in failed_plugins:
continue
await self._register_plugin(meta)
ok = await self._register_plugin(meta)
if not ok:
failed_plugins.add(meta.plugin_id)
await self._deactivate_plugin(meta)
successful_plugins = [meta.plugin_id for meta in plugins if meta.plugin_id not in failed_plugins]
await self._notify_ready(successful_plugins, sorted(failed_plugins))
# 5. 等待直到收到关停信号
with contextlib.suppress(asyncio.CancelledError):
@@ -256,7 +268,33 @@ class PluginRunner:
self._rpc_client.register_method("plugin.shutdown", self._handle_shutdown)
self._rpc_client.register_method("plugin.config_updated", self._handle_config_updated)
async def _register_plugin(self, meta: PluginMeta) -> None:
async def _bootstrap_plugin(self, meta: PluginMeta, capabilities_required: Optional[List[str]] = None) -> bool:
"""Send a ``plugin.bootstrap`` request to the Host to sync the plugin's capability token.

Args:
    meta: Plugin whose id, version and declared capabilities are sent.
    capabilities_required: Explicit capability list to send instead of the
        one declared on ``meta``. ``None`` means "use ``meta``"; an empty
        list is sent as-is.

Returns:
    ``True`` when the request completed without raising, ``False`` when any
    exception occurred (the error is logged, not re-raised).
"""
# `is not None` (rather than truthiness) keeps an explicit empty list
# distinct from "not provided".
payload = BootstrapPluginPayload(
plugin_id=meta.plugin_id,
plugin_version=meta.version,
capabilities_required=capabilities_required
if capabilities_required is not None
else list(meta.capabilities_required or []),
)
try:
await self._rpc_client.send_request(
"plugin.bootstrap",
plugin_id=meta.plugin_id,
payload=payload.model_dump(),
timeout_ms=10000,
)
return True
except Exception as e:
logger.error(f"插件 {meta.plugin_id} bootstrap 失败: {e}")
return False
async def _deactivate_plugin(self, meta: PluginMeta) -> None:
"""Revoke the capability token issued for the plugin during bootstrap.

Implemented by re-sending the bootstrap request with an empty capability
list. The boolean result is discarded, so a failed revocation is only
visible through the error logged by ``_bootstrap_plugin``.
"""
await self._bootstrap_plugin(meta, capabilities_required=[])
async def _register_plugin(self, meta: PluginMeta) -> bool:
"""向 Host 注册单个插件"""
# 收集插件组件声明
components: List[ComponentDeclaration] = []
@@ -289,8 +327,22 @@ class PluginRunner:
timeout_ms=10000,
)
logger.info(f"插件 {meta.plugin_id} 注册完成")
return True
except Exception as e:
logger.error(f"插件 {meta.plugin_id} 注册失败: {e}")
return False
async def _notify_ready(self, loaded_plugins: List[str], failed_plugins: List[str]) -> None:
"""Tell the Host that plugin initialization for the current generation is done.

Args:
    loaded_plugins: Ids of plugins reported as successfully initialized.
    failed_plugins: Ids of plugins reported as failed.
"""
payload = RunnerReadyPayload(
loaded_plugins=loaded_plugins,
failed_plugins=failed_plugins,
)
# NOTE(review): unlike the bootstrap/register helpers, nothing is caught
# here, so any exception from send_request propagates to the caller —
# confirm that is the intended behavior.
await self._rpc_client.send_request(
"runner.ready",
payload=payload.model_dump(),
timeout_ms=10000,
)
async def _handle_invoke(self, envelope: Envelope) -> Envelope:
"""处理组件调用请求"""