From d0b56abdabcf448c5c8f497f9d9c246307c7e807 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Thu, 12 Mar 2026 21:22:23 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=A2=9E=E5=BC=BA=E6=8F=92=E4=BB=B6?= =?UTF-8?q?=E8=83=BD=E5=8A=9B=E6=A3=80=E6=9F=A5=EF=BC=8C=E6=94=AF=E6=8C=81?= =?UTF-8?q?=20generation=20=E6=A0=A1=E9=AA=8C=E5=B9=B6=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E6=B8=85=E7=90=86=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pytests/test_plugin_runtime.py | 271 +++++++++++++++++- src/core/event_bus.py | 31 +- src/plugin_runtime/host/capability_service.py | 5 +- src/plugin_runtime/host/component_registry.py | 7 + src/plugin_runtime/host/policy_engine.py | 9 +- src/plugin_runtime/host/rpc_server.py | 36 ++- src/plugin_runtime/host/supervisor.py | 141 +++++++-- src/plugin_runtime/integration.py | 17 +- 8 files changed, 466 insertions(+), 51 deletions(-) diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py index 2688f1c4..aafd3d0f 100644 --- a/pytests/test_plugin_runtime.py +++ b/pytests/test_plugin_runtime.py @@ -3,9 +3,11 @@ 验证协议层、传输层、RPC 通信链路的正确性。 """ +from types import SimpleNamespace + import asyncio -import sys import os +import sys import pytest @@ -227,6 +229,10 @@ class TestHost: ok, reason = engine.check_capability("unknown", "send.text") assert not ok + ok, reason = engine.check_capability("test_plugin", "send.text", generation=2) + assert not ok + assert "generation 不匹配" in reason + def test_circuit_breaker_removed(self): """熔断器已移除,验证 supervisor 不依赖它""" pass @@ -702,6 +708,49 @@ class TestEventDispatcher: assert modified["plain_text"] == "filtered" +class TestEventBus: + """核心事件总线与 IPC 桥接测试""" + + @pytest.mark.asyncio + async def test_bridge_preserves_modified_message(self, monkeypatch): + import types + + fake_message_data_model = types.ModuleType("src.common.data_models.message_data_model") + fake_message_data_model.ReplyContentType = object + fake_message_data_model.ReplyContent = object + fake_message_data_model.ForwardNode = object + fake_message_data_model.ReplySetModel = object + monkeypatch.setitem(sys.modules, "src.common.data_models.message_data_model", fake_message_data_model) + + from src.core.event_bus import EventBus + from src.core.types import EventType, MaiMessages + from src.plugin_runtime import integration as integration_module + + bus = EventBus() + + async def noop_handler(message): + return True, message + + bus.subscribe(EventType.ON_MESSAGE, noop_handler, name="noop", intercept=True) + + class FakeManager: + is_running = True + + async def bridge_event(self, event_type_value, message_dict=None, extra_args=None): + assert event_type_value == EventType.ON_MESSAGE.value + return True, {"plain_text": "modified by ipc"} + + monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager()) + + original = MaiMessages(plain_text="original") + continue_flag, modified = await bus.emit(EventType.ON_MESSAGE, original) + + assert continue_flag is True + assert modified is not None + assert modified.plain_text == "modified by ipc" + assert original.plain_text == "original" + + # ─── MaiMessages 测试 ───────────────────────────────────── class TestMaiMessages: @@ -1050,3 +1099,223 @@ class TestWorkflowExecutor: ) assert result.status == "completed" assert ctx.get_stage_output("ingress", "parsed_intent") == "greeting" + + +class TestRPCServer: + """RPC Server 代际保护测试""" + + def test_ignore_stale_generation_response(self): + from src.plugin_runtime.host.rpc_server import RPCServer + from src.plugin_runtime.protocol.envelope import Envelope, MessageType + + class DummyTransport: + async def start(self, handler): + return None + + async def stop(self): + return None + + def get_address(self): + return "dummy" + + server = RPCServer(transport=DummyTransport()) + server._runner_generation = 2 + + loop = asyncio.new_event_loop() + try: + future = loop.create_future() + server._pending_requests[1] = future + + stale_response = Envelope( + request_id=1, + message_type=MessageType.RESPONSE, + method="plugin.health", + generation=1, + payload={"healthy": True}, + ) + server._handle_response(stale_response) + + assert not future.done() + assert 1 in server._pending_requests + finally: + loop.close() + + +class TestSupervisor: + """Supervisor 生命周期边界测试""" + + @staticmethod + def _build_register_payload(plugin_id: str = "plugin_a"): + from src.plugin_runtime.protocol.envelope import ComponentDeclaration, RegisterComponentsPayload + + return RegisterComponentsPayload( + plugin_id=plugin_id, + plugin_version="1.0.0", + components=[ + ComponentDeclaration( + name="handler", + component_type="event_handler", + plugin_id=plugin_id, + metadata={"event_type": "on_message"}, + ) + ], + capabilities_required=["send.text"], + ) + + @staticmethod + def _make_process(pid: int): + class FakeProcess: + def __init__(self): + self.pid = pid + self.returncode = None + self.stdout = None + self.stderr = None + self.terminated = False + self.killed = False + + def terminate(self): + self.terminated = True + self.returncode = 0 + + def kill(self): + self.killed = True + self.returncode = -9 + + async def wait(self): + return self.returncode + + return FakeProcess() + + @pytest.mark.asyncio + async def test_reload_waits_for_target_generation(self, monkeypatch): + from src.plugin_runtime.host.supervisor import PluginSupervisor + from src.plugin_runtime.protocol.envelope import HealthPayload + + supervisor = PluginSupervisor(plugin_dirs=[]) + old_process = self._make_process(1) + new_process = self._make_process(2) + + class FakeRPCServer: + def __init__(self): + self.runner_generation = 1 + self.is_connected = True + + async def send_request(self, method, timeout_ms=5000, **kwargs): + assert self.runner_generation == 2 + return SimpleNamespace(payload=HealthPayload(healthy=True).model_dump()) + + supervisor._rpc_server = FakeRPCServer() + supervisor._runner_process = old_process + + async def fake_spawn_runner(): + supervisor._runner_process = new_process + + async def advance_generation(): + await asyncio.sleep(0.01) + supervisor._rpc_server.runner_generation = 2 + + asyncio.create_task(advance_generation()) + + monkeypatch.setattr(supervisor, "_spawn_runner", fake_spawn_runner) + + await supervisor.reload_plugins("test") + + assert supervisor._runner_process is new_process + assert old_process.terminated is True + + @pytest.mark.asyncio + async def test_reload_restores_runtime_state_on_failure(self, monkeypatch): + from src.plugin_runtime.host.supervisor import PluginSupervisor + + supervisor = PluginSupervisor(plugin_dirs=[]) + old_process = self._make_process(1) + new_process = self._make_process(2) + old_reg = self._build_register_payload() + + supervisor._runner_process = old_process + supervisor._registered_plugins[old_reg.plugin_id] = old_reg + supervisor._rebuild_runtime_state() + + class FakeRPCServer: + def __init__(self): + self.runner_generation = 1 + self.is_connected = True + + async def send_request(self, method, timeout_ms=5000, **kwargs): + raise RuntimeError("new runner unhealthy") + + supervisor._rpc_server = FakeRPCServer() + + async def fake_spawn_runner(): + supervisor._runner_process = new_process + supervisor._rpc_server.runner_generation = 2 + + monkeypatch.setattr(supervisor, "_spawn_runner", fake_spawn_runner) + + await supervisor.reload_plugins("test") + + assert supervisor._runner_process is old_process + assert old_reg.plugin_id in supervisor._registered_plugins + assert supervisor.component_registry.get_component("plugin_a.handler") is not None + + @pytest.mark.asyncio + async def test_attach_runner_output_tasks_drains_streams(self): + from src.plugin_runtime.host.supervisor import PluginSupervisor + + supervisor = PluginSupervisor(plugin_dirs=[]) + + stdout = asyncio.StreamReader() + stdout.feed_data(b"hello stdout\n") + stdout.feed_eof() + + stderr = asyncio.StreamReader() + stderr.feed_data(b"hello stderr\n") + stderr.feed_eof() + + process = SimpleNamespace(pid=99, stdout=stdout, stderr=stderr) + supervisor._attach_runner_output_tasks(process) + + await asyncio.sleep(0.05) + + assert not supervisor._runner_output_tasks + + +class TestIntegration: + """运行时集成层启动/清理测试""" + + @pytest.mark.asyncio + async def test_start_cleans_up_started_supervisors_on_failure(self, monkeypatch): + from src.plugin_runtime import integration as integration_module + + instances = [] + + class FakeCapabilityService: + def register_capability(self, name, impl): + return None + + class FakeSupervisor: + def __init__(self, plugin_dirs=None, socket_path=None): + self.plugin_dirs = plugin_dirs or [] + self.capability_service = FakeCapabilityService() + self.stopped = False + instances.append(self) + + async def start(self): + if len(instances) == 2 and self is instances[1]: + raise RuntimeError("boom") + + async def stop(self): + self.stopped = True + + monkeypatch.setattr(integration_module.PluginRuntimeManager, "_get_builtin_plugin_dirs", staticmethod(lambda: ["builtin"])) + monkeypatch.setattr(integration_module.PluginRuntimeManager, "_get_thirdparty_plugin_dirs", staticmethod(lambda: ["thirdparty"])) + + import src.plugin_runtime.host.supervisor as supervisor_module + monkeypatch.setattr(supervisor_module, "PluginSupervisor", FakeSupervisor) + + manager = integration_module.PluginRuntimeManager() + await manager.start() + + assert manager.is_running is False + assert len(instances) == 2 + assert instances[0].stopped is True diff --git a/src/core/event_bus.py b/src/core/event_bus.py index f4dbed46..c84a86b0 100644 --- a/src/core/event_bus.py +++ b/src/core/event_bus.py @@ -8,14 +8,13 @@ """ import asyncio -from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING +import contextlib +from dataclasses import fields +from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple from src.common.logger import get_logger from src.core.types import EventType, MaiMessages -if TYPE_CHECKING: - from src.common.data_models.llm_data_model import LLMGenerationDataModel - logger = get_logger("event_bus") # Handler 签名:接收 MaiMessages,返回 (continue, modified_message) @@ -127,8 +126,7 @@ class EventBus: async def cancel_handler_tasks(self, handler_name: str) -> None: """取消某个 handler 的所有运行中任务""" tasks = self._running_tasks.pop(handler_name, []) - remaining = [t for t in tasks if not t.done()] - if remaining: + if remaining := [t for t in tasks if not t.done()]: for t in remaining: t.cancel() await asyncio.gather(*remaining, return_exceptions=True) @@ -156,17 +154,14 @@ class EventBus: try: if task.cancelled(): return - exc = task.exception() - if exc: + if exc := task.exception(): logger.error(f"handler {handler_name} 异步任务异常: {exc}") except Exception: pass finally: task_list = self._running_tasks.get(handler_name, []) - try: + with contextlib.suppress(ValueError): task_list.remove(task) - except ValueError: - pass async def _bridge_to_ipc_runtime( self, @@ -188,17 +183,29 @@ class EventBus: event_value = event_type.value if isinstance(event_type, EventType) else str(event_type) message_dict = message.to_dict() if message and hasattr(message, "to_dict") else None - new_continue, _ = await prm.bridge_event( + new_continue, modified_dict = await prm.bridge_event( event_type_value=event_value, message_dict=message_dict, ) if not new_continue: continue_flag = False + if modified_dict is not None and message is not None: + message = self._apply_ipc_message_update(message, modified_dict) except Exception as e: logger.warning(f"桥接事件到 IPC 运行时失败: {e}") return continue_flag, message + @staticmethod + def _apply_ipc_message_update(message: MaiMessages, modified_dict: Dict[str, Any]) -> MaiMessages: + """将 IPC 返回的消息字典回写到当前 MaiMessages。""" + updated_message = message.deepcopy() + valid_fields = {field.name for field in fields(MaiMessages)} + for key, value in modified_dict.items(): + if key in valid_fields: + setattr(updated_message, key, value) + return updated_message + class _HandlerEntry: """内部 handler 条目""" diff --git a/src/plugin_runtime/host/capability_service.py b/src/plugin_runtime/host/capability_service.py index 5fe8cefe..643fef95 100644 --- a/src/plugin_runtime/host/capability_service.py +++ b/src/plugin_runtime/host/capability_service.py @@ -65,10 +65,11 @@ class CapabilityService: capability = req.capability # 1. 权限校验 - allowed, reason = self._policy.check_capability(plugin_id, capability) + allowed, reason = self._policy.check_capability(plugin_id, capability, envelope.generation) if not allowed: + error_code = ErrorCode.E_GENERATION_MISMATCH if "generation 不匹配" in reason else ErrorCode.E_CAPABILITY_DENIED return envelope.make_error_response( - ErrorCode.E_CAPABILITY_DENIED.value, + error_code.value, reason, ) diff --git a/src/plugin_runtime/host/component_registry.py b/src/plugin_runtime/host/component_registry.py index e473937e..fc2a4421 100644 --- a/src/plugin_runtime/host/component_registry.py +++ b/src/plugin_runtime/host/component_registry.py @@ -73,6 +73,13 @@ class ComponentRegistry: # 按插件索引 self._by_plugin: Dict[str, List[RegisteredComponent]] = {} + def clear(self) -> None: + """清空全部组件注册状态。""" + self._components.clear() + for type_dict in self._by_type.values(): + type_dict.clear() + self._by_plugin.clear() + # ──── 注册 / 注销 ───────────────────────────────────────── def register_component( diff --git a/src/plugin_runtime/host/policy_engine.py b/src/plugin_runtime/host/policy_engine.py index 05f25bb8..8327e56c 100644 --- a/src/plugin_runtime/host/policy_engine.py +++ b/src/plugin_runtime/host/policy_engine.py @@ -44,7 +44,11 @@ class PolicyEngine: """撤销插件的能力令牌""" self._tokens.pop(plugin_id, None) - def check_capability(self, plugin_id: str, capability: str) -> Tuple[bool, str]: + def clear(self) -> None: + """清空所有能力令牌。""" + self._tokens.clear() + + def check_capability(self, plugin_id: str, capability: str, generation: Optional[int] = None) -> Tuple[bool, str]: """检查插件是否有权调用某项能力 Returns: @@ -57,6 +61,9 @@ class PolicyEngine: if capability not in token.capabilities: return False, f"插件 {plugin_id} 未获授权能力: {capability}" + if generation is not None and token.generation != generation: + return False, f"插件 {plugin_id} generation 不匹配: {generation} != {token.generation}" + return True, "" def get_token(self, plugin_id: str) -> Optional[CapabilityToken]: diff --git a/src/plugin_runtime/host/rpc_server.py b/src/plugin_runtime/host/rpc_server.py index 7ac2bed4..78efbfdd 100644 --- a/src/plugin_runtime/host/rpc_server.py +++ b/src/plugin_runtime/host/rpc_server.py @@ -73,6 +73,10 @@ class RPCServer: def session_token(self) -> str: return self._session_token + @property + def runner_generation(self) -> int: + return self._runner_generation + @property def is_connected(self) -> bool: return self._connection is not None and not self._connection.is_closed @@ -206,18 +210,23 @@ class RPCServer: await conn.close() return - # 握手成功,保存连接 + old_connection = self._connection self._connection = conn logger.info(f"Runner 握手成功: runner_id={self._runner_id}, generation={self._runner_generation}") + if old_connection and old_connection is not conn and not old_connection.is_closed: + logger.info("检测到新 Runner 已接管连接,关闭旧连接") + await old_connection.close() + # 启动消息接收循环 try: await self._recv_loop(conn) except Exception as e: logger.error(f"连接异常断开: {e}") finally: - self._connection = None - self._runner_id = None + if self._connection is conn: + self._connection = None + self._runner_id = None async def _handle_handshake(self, conn: Connection) -> bool: """处理 runner.hello 握手""" @@ -295,17 +304,35 @@ class RPCServer: if envelope.is_response(): self._handle_response(envelope) elif envelope.is_request(): + if not self._is_current_generation(envelope): + error_resp = envelope.make_error_response( + ErrorCode.E_GENERATION_MISMATCH.value, + f"过期 generation: {envelope.generation} != {self._runner_generation}", + ) + await conn.send_frame(self._codec.encode_envelope(error_resp)) + continue # 异步处理请求(Runner 发来的能力调用) task = asyncio.create_task(self._handle_request(envelope, conn)) self._tasks.append(task) task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None) elif envelope.is_event(): + if not self._is_current_generation(envelope): + logger.warning( + f"忽略过期 generation 事件 {envelope.method}: {envelope.generation} != {self._runner_generation}" + ) + continue task = asyncio.create_task(self._handle_event(envelope)) self._tasks.append(task) task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None) def _handle_response(self, envelope: Envelope) -> None: """处理来自 Runner 的响应""" + if not self._is_current_generation(envelope): + logger.warning( + f"忽略过期 generation 响应 {envelope.method}: {envelope.generation} != {self._runner_generation}" + ) + return + future = self._pending_requests.pop(envelope.request_id, None) if future and not future.done(): if envelope.error: @@ -313,6 +340,9 @@ class RPCServer: else: future.set_result(envelope) + def _is_current_generation(self, envelope: Envelope) -> bool: + return envelope.generation == self._runner_generation + async def _handle_request(self, envelope: Envelope, conn: Connection) -> None: """处理来自 Runner 的请求(通常是能力调用 cap.*)""" handler = self._method_handlers.get(envelope.method) diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py index 150c7b9c..5f9f78e6 100644 --- a/src/plugin_runtime/host/supervisor.py +++ b/src/plugin_runtime/host/supervisor.py @@ -10,6 +10,7 @@ from typing import Any, Dict, List, Optional, Tuple import asyncio +import contextlib import os import sys @@ -75,6 +76,7 @@ class PluginSupervisor: # 后台任务 self._health_task: Optional[asyncio.Task] = None + self._runner_output_tasks: List[asyncio.Task] = [] self._running = False # 注册内部 RPC 方法 @@ -224,40 +226,26 @@ class PluginSupervisor: # 保存旧进程引用 old_process = self._runner_process + old_registered_plugins = dict(self._registered_plugins) + expected_generation = self._rpc_server.runner_generation + 1 # 清理旧的组件注册,防止幽灵组件残留 - for plugin_id in list(self._registered_plugins.keys()): - self._component_registry.remove_components_by_plugin(plugin_id) - self._policy.revoke_plugin(plugin_id) - self._registered_plugins.clear() + self._clear_runtime_state() # 拉起新 Runner - await self._spawn_runner() - - # 等待新 Runner 连接并完成握手 - for _ in range(30): # 最多等待 30 秒 - if self._rpc_server.is_connected: - break - await asyncio.sleep(1.0) - else: - logger.error("新 Runner 连接超时,回滚") - # 回滚:终止新进程 - if self._runner_process and self._runner_process != old_process: - self._runner_process.terminate() - self._runner_process = old_process - return - - # 健康检查 try: + await self._spawn_runner() + await self._wait_for_runner_generation(expected_generation, timeout_sec=30.0) resp = await self._rpc_server.send_request("plugin.health", timeout_ms=5000) health = HealthPayload.model_validate(resp.payload) if not health.healthy: raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "新 Runner 健康检查失败") except Exception as e: logger.error(f"新 Runner 健康检查失败: {e},回滚") - if self._runner_process and self._runner_process != old_process: - self._runner_process.terminate() + await self._terminate_process(self._runner_process, old_process) self._runner_process = old_process + self._registered_plugins = dict(old_registered_plugins) + self._rebuild_runtime_state() return # 关停旧 Runner @@ -286,13 +274,19 @@ class PluginSupervisor: except Exception as e: return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e)) + if envelope.generation != self._rpc_server.runner_generation: + return envelope.make_error_response( + ErrorCode.E_GENERATION_MISMATCH.value, + f"组件注册 generation 过期: {envelope.generation} != {self._rpc_server.runner_generation}", + ) + # 记录注册信息 self._registered_plugins[reg.plugin_id] = reg # 在策略引擎中注册插件 self._policy.register_plugin( plugin_id=reg.plugin_id, - generation=self._runner_generation, + generation=envelope.generation, capabilities=reg.capabilities_required or [], ) @@ -329,7 +323,8 @@ class PluginSupervisor: stderr=asyncio.subprocess.PIPE, ) - self._runner_generation += 1 + self._attach_runner_output_tasks(self._runner_process) + self._runner_generation = self._rpc_server.runner_generation logger.info(f"Runner 子进程已启动: pid={self._runner_process.pid}, generation={self._runner_generation}") async def _shutdown_runner(self) -> None: @@ -362,6 +357,8 @@ class PluginSupervisor: self._runner_process.kill() await self._runner_process.wait() + await self._cleanup_runner_output_tasks() + async def _health_check_loop(self) -> None: """周期性健康检查 + 崩溃自动重启""" while self._running: @@ -382,6 +379,7 @@ class PluginSupervisor: self._registered_plugins.clear() try: + self._clear_runtime_state() await self._spawn_runner() except Exception as e: logger.error(f"Runner 重启失败: {e}", exc_info=True) @@ -407,3 +405,98 @@ class PluginSupervisor: break except Exception as e: logger.error(f"健康检查异常: {e}") + + async def _wait_for_runner_generation(self, expected_generation: int, timeout_sec: float) -> None: + """等待指定代际的 Runner 完成连接。""" + deadline = asyncio.get_running_loop().time() + timeout_sec + while asyncio.get_running_loop().time() < deadline: + if self._rpc_server.is_connected and self._rpc_server.runner_generation >= expected_generation: + self._runner_generation = self._rpc_server.runner_generation + return + await asyncio.sleep(0.1) + raise TimeoutError(f"等待 Runner generation {expected_generation} 超时") + + def _clear_runtime_state(self) -> None: + """清空当前插件注册态。""" + self._component_registry.clear() + self._policy.clear() + self._registered_plugins.clear() + + def _rebuild_runtime_state(self) -> None: + """根据已记录的插件注册信息重建运行时状态。""" + self._component_registry.clear() + self._policy.clear() + for reg in self._registered_plugins.values(): + self._policy.register_plugin( + plugin_id=reg.plugin_id, + generation=self._rpc_server.runner_generation, + capabilities=reg.capabilities_required or [], + ) + self._component_registry.register_plugin_components( + plugin_id=reg.plugin_id, + components=[c.model_dump() for c in reg.components], + ) + + def _attach_runner_output_tasks(self, process: asyncio.subprocess.Process) -> None: + """为 Runner 输出流创建排空任务,避免 PIPE 填满阻塞子进程。""" + streams = ( + (process.stdout, "stdout"), + (process.stderr, "stderr"), + ) + for stream, stream_name in streams: + if stream is None: + continue + task = asyncio.create_task(self._drain_runner_stream(stream, stream_name, process.pid)) + self._runner_output_tasks.append(task) + task.add_done_callback( + lambda done_task: self._runner_output_tasks.remove(done_task) + if done_task in self._runner_output_tasks + else None + ) + + async def _drain_runner_stream( + self, + stream: asyncio.StreamReader, + stream_name: str, + pid: int, + ) -> None: + """持续消费 Runner 输出,避免 PIPE 回压导致子进程阻塞。""" + try: + while True: + line = await stream.readline() + if not line: + break + message = line.decode(errors="replace").rstrip() + if message: + logger.debug(f"[runner:{pid}:{stream_name}] {message}") + except asyncio.CancelledError: + raise + except Exception as e: + logger.debug(f"读取 Runner {stream_name} 失败: {e}") + + async def _cleanup_runner_output_tasks(self) -> None: + """等待并清理 Runner 输出任务。""" + tasks = list(self._runner_output_tasks) + self._runner_output_tasks.clear() + for task in tasks: + if not task.done(): + task.cancel() + if tasks: + with contextlib.suppress(Exception): + await asyncio.gather(*tasks, return_exceptions=True) + + @staticmethod + async def _terminate_process( + process: Optional[asyncio.subprocess.Process], + keep_process: Optional[asyncio.subprocess.Process] = None, + ) -> None: + """终止指定进程,但跳过需要保留的旧进程引用。""" + if process is None or process is keep_process or process.returncode is not None: + return + + process.terminate() + try: + await asyncio.wait_for(process.wait(), timeout=10.0) + except asyncio.TimeoutError: + process.kill() + await process.wait() diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py index a5ecd684..0c1b2268 100644 --- a/src/plugin_runtime/integration.py +++ b/src/plugin_runtime/integration.py @@ -90,21 +90,22 @@ class PluginRuntimeManager: ) self._register_capability_impls(self._thirdparty_supervisor) - # 并行启动 - coros = [] - if self._builtin_supervisor: - coros.append(self._builtin_supervisor.start()) - if self._thirdparty_supervisor: - coros.append(self._thirdparty_supervisor.start()) - + started_supervisors = [] try: - await asyncio.gather(*coros) + if self._builtin_supervisor: + await self._builtin_supervisor.start() + started_supervisors.append(self._builtin_supervisor) + if self._thirdparty_supervisor: + await self._thirdparty_supervisor.start() + started_supervisors.append(self._thirdparty_supervisor) self._started = True logger.info( f"插件运行时已启动 — 内置: {builtin_dirs or '无'}, 第三方: {thirdparty_dirs or '无'}" ) except Exception as e: logger.error(f"插件运行时启动失败: {e}", exc_info=True) + await asyncio.gather(*(sv.stop() for sv in started_supervisors), return_exceptions=True) + self._started = False self._builtin_supervisor = None self._thirdparty_supervisor = None