feat: 实现插件注册的分阶段接入与切换机制,优化 RPC 连接管理

This commit is contained in:
DrSmoothl
2026-03-13 15:21:40 +08:00
parent 5d30b3a908
commit 44a9e9ecd7
5 changed files with 393 additions and 87 deletions

View File

@@ -677,6 +677,28 @@ class TestComponentRegistry:
assert removed == 2
assert reg.get_stats()["total"] == 1
def test_reregister_same_plugin_replaces_component_set(self):
from src.plugin_runtime.host.component_registry import ComponentRegistry
reg = ComponentRegistry()
reg.register_plugin_components(
"p1",
[
{"name": "a1", "component_type": "action", "metadata": {}},
{"name": "a2", "component_type": "action", "metadata": {}},
],
)
reg.remove_components_by_plugin("p1")
reg.register_plugin_components(
"p1",
[
{"name": "a1", "component_type": "action", "metadata": {}},
],
)
assert reg.get_component("p1.a1") is not None
assert reg.get_component("p1.a2") is None
def test_event_handlers_sorted_by_weight(self):
from src.plugin_runtime.host.component_registry import ComponentRegistry
@@ -1257,7 +1279,7 @@ class TestRPCServer:
loop = asyncio.new_event_loop()
try:
future = loop.create_future()
server._pending_requests[1] = future
server._pending_requests[1] = (future, 2)
stale_response = Envelope(
request_id=1,
@@ -1274,23 +1296,53 @@ class TestRPCServer:
loop.close()
class TestRPCClient:
"""Runner RPCClient 后台任务生命周期测试"""
@pytest.mark.asyncio
async def test_background_tasks_retained_and_cancelled_on_disconnect(self):
from src.plugin_runtime.runner.rpc_client import RPCClient
client = RPCClient(host_address="dummy", session_token="token")
release = asyncio.Event()
async def pending_task():
await release.wait()
task = asyncio.create_task(pending_task())
client._track_background_task(task)
assert task in client._background_tasks
await asyncio.sleep(0)
assert task in client._background_tasks
await client.disconnect()
assert task.cancelled() is True
assert not client._background_tasks
class TestSupervisor:
"""Supervisor 生命周期边界测试"""
@staticmethod
def _build_register_payload(plugin_id: str = "plugin_a"):
def _build_register_payload(plugin_id: str = "plugin_a", component_names=None):
from src.plugin_runtime.protocol.envelope import ComponentDeclaration, RegisterComponentsPayload
component_names = component_names or ["handler"]
return RegisterComponentsPayload(
plugin_id=plugin_id,
plugin_version="1.0.0",
components=[
ComponentDeclaration(
name="handler",
name=name,
component_type="event_handler",
plugin_id=plugin_id,
metadata={"event_type": "on_message"},
)
for name in component_names
],
capabilities_required=["send.text"],
)
@@ -1331,8 +1383,11 @@ class TestSupervisor:
class FakeRPCServer:
def __init__(self):
self.runner_generation = 1
self.staged_generation = 0
self.is_connected = True
self.session_token = "fake-token"
self.committed = False
self.staging_started = False
def reset_session_token(self):
self.session_token = "new-fake-token"
@@ -1341,8 +1396,23 @@ class TestSupervisor:
def restore_session_token(self, token):
self.session_token = token
async def send_request(self, method, timeout_ms=5000, **kwargs):
assert self.runner_generation == 2
def begin_staged_takeover(self):
self.staging_started = True
self.staged_generation = 2
async def commit_staged_takeover(self):
self.runner_generation = self.staged_generation
self.staged_generation = 0
self.committed = True
async def rollback_staged_takeover(self):
self.staged_generation = 0
def has_generation(self, generation):
return generation in {self.runner_generation, self.staged_generation}
async def send_request(self, method, timeout_ms=5000, target_generation=None, **kwargs):
assert target_generation == 2
return SimpleNamespace(payload=HealthPayload(healthy=True).model_dump())
supervisor._rpc_server = FakeRPCServer()
@@ -1350,18 +1420,14 @@ class TestSupervisor:
async def fake_spawn_runner():
supervisor._runner_process = new_process
async def advance_generation():
await asyncio.sleep(0.01)
supervisor._rpc_server.runner_generation = 2
asyncio.create_task(advance_generation())
supervisor._staged_registered_plugins["plugin_a"] = self._build_register_payload("plugin_a")
monkeypatch.setattr(supervisor, "_spawn_runner", fake_spawn_runner)
await supervisor.reload_plugins("test")
assert supervisor._runner_process is new_process
assert supervisor._rpc_server.committed is True
assert old_process.terminated is True
@pytest.mark.asyncio
@@ -1380,6 +1446,69 @@ class TestSupervisor:
class FakeRPCServer:
def __init__(self):
self.runner_generation = 1
self.staged_generation = 0
self.is_connected = True
self.session_token = "fake-token"
self.rolled_back = False
def reset_session_token(self):
self.session_token = "new-fake-token"
return self.session_token
def restore_session_token(self, token):
self.session_token = token
def begin_staged_takeover(self):
self.staged_generation = 2
async def commit_staged_takeover(self):
self.runner_generation = self.staged_generation
self.staged_generation = 0
async def rollback_staged_takeover(self):
self.rolled_back = True
self.staged_generation = 0
def has_generation(self, generation):
return generation in {self.runner_generation, self.staged_generation}
async def send_request(self, method, timeout_ms=5000, target_generation=None, **kwargs):
raise RuntimeError("new runner unhealthy")
supervisor._rpc_server = FakeRPCServer()
async def fake_spawn_runner():
supervisor._runner_process = new_process
supervisor._staged_registered_plugins["plugin_a"] = self._build_register_payload("plugin_a")
monkeypatch.setattr(supervisor, "_spawn_runner", fake_spawn_runner)
await supervisor.reload_plugins("test")
assert supervisor._runner_process is old_process
assert supervisor._rpc_server.rolled_back is True
assert old_reg.plugin_id in supervisor._registered_plugins
assert supervisor.component_registry.get_component("plugin_a.handler") is not None
@pytest.mark.asyncio
async def test_reload_rebuilds_exact_component_set(self, monkeypatch):
from src.plugin_runtime.host.supervisor import PluginSupervisor
from src.plugin_runtime.protocol.envelope import HealthPayload
supervisor = PluginSupervisor(plugin_dirs=[])
old_process = self._make_process(1)
new_process = self._make_process(2)
old_reg = self._build_register_payload("plugin_a", component_names=["handler", "obsolete"])
new_reg = self._build_register_payload("plugin_a", component_names=["handler"])
supervisor._runner_process = old_process
supervisor._registered_plugins[old_reg.plugin_id] = old_reg
supervisor._rebuild_runtime_state()
class FakeRPCServer:
def __init__(self):
self.runner_generation = 1
self.staged_generation = 0
self.is_connected = True
self.session_token = "fake-token"
@@ -1390,22 +1519,34 @@ class TestSupervisor:
def restore_session_token(self, token):
self.session_token = token
async def send_request(self, method, timeout_ms=5000, **kwargs):
raise RuntimeError("new runner unhealthy")
def begin_staged_takeover(self):
self.staged_generation = 2
async def commit_staged_takeover(self):
self.runner_generation = self.staged_generation
self.staged_generation = 0
async def rollback_staged_takeover(self):
self.staged_generation = 0
def has_generation(self, generation):
return generation in {self.runner_generation, self.staged_generation}
async def send_request(self, method, timeout_ms=5000, target_generation=None, **kwargs):
return SimpleNamespace(payload=HealthPayload(healthy=True).model_dump())
supervisor._rpc_server = FakeRPCServer()
async def fake_spawn_runner():
supervisor._runner_process = new_process
supervisor._rpc_server.runner_generation = 2
supervisor._staged_registered_plugins[new_reg.plugin_id] = new_reg
monkeypatch.setattr(supervisor, "_spawn_runner", fake_spawn_runner)
await supervisor.reload_plugins("test")
assert supervisor._runner_process is old_process
assert old_reg.plugin_id in supervisor._registered_plugins
assert supervisor.component_registry.get_component("plugin_a.handler") is not None
assert supervisor.component_registry.get_component("plugin_a.obsolete") is None
@pytest.mark.asyncio
async def test_attach_stderr_drain_drains_stream(self):

View File

@@ -7,7 +7,7 @@
4. 请求-响应关联与超时管理
"""
from typing import Any, Awaitable, Callable, Dict, List, Optional
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
import asyncio
import secrets
@@ -55,12 +55,16 @@ class RPCServer:
self._connection: Optional[Connection] = None # 当前活跃的 Runner 连接
self._runner_id: Optional[str] = None
self._runner_generation: int = 0
self._staged_connection: Optional[Connection] = None
self._staged_runner_id: Optional[str] = None
self._staged_runner_generation: int = 0
self._staging_takeover: bool = False
# 方法处理器注册表
self._method_handlers: Dict[str, MethodHandler] = {}
# 等待响应的 pending 请求: request_id -> Future
self._pending_requests: Dict[int, asyncio.Future] = {}
# 等待响应的 pending 请求: request_id -> (Future, target_generation)
self._pending_requests: Dict[int, Tuple[asyncio.Future, int]] = {}
# 发送队列(背压控制)
self._send_queue: Optional[asyncio.Queue[bytes]] = None
@@ -86,10 +90,72 @@ class RPCServer:
def runner_generation(self) -> int:
return self._runner_generation
@property
def staged_generation(self) -> int:
return self._staged_runner_generation
@property
def is_connected(self) -> bool:
return self._connection is not None and not self._connection.is_closed
def has_generation(self, generation: int) -> bool:
return generation == self._runner_generation or (
self._staged_connection is not None
and not self._staged_connection.is_closed
and generation == self._staged_runner_generation
)
def begin_staged_takeover(self) -> None:
"""允许新 Runner 以 staged 方式接入,待 Supervisor 验证后再切换为活跃连接。"""
self._staging_takeover = True
async def commit_staged_takeover(self) -> None:
"""提交 staged Runner原活跃连接在提交后被关闭。"""
if self._staged_connection is None or self._staged_connection.is_closed:
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "没有可提交的新 Runner 连接")
old_connection = self._connection
old_generation = self._runner_generation
self._connection = self._staged_connection
self._runner_id = self._staged_runner_id
self._runner_generation = self._staged_runner_generation
self._staged_connection = None
self._staged_runner_id = None
self._staged_runner_generation = 0
self._staging_takeover = False
stale_count = self._fail_pending_requests(
ErrorCode.E_PLUGIN_CRASHED,
"Runner 连接已被新 generation 接管",
generation=old_generation,
)
if stale_count:
logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求")
if old_connection and old_connection is not self._connection and not old_connection.is_closed:
await old_connection.close()
async def rollback_staged_takeover(self) -> None:
"""放弃 staged Runner保留当前活跃连接。"""
staged_connection = self._staged_connection
staged_generation = self._staged_runner_generation
self._staged_connection = None
self._staged_runner_id = None
self._staged_runner_generation = 0
self._staging_takeover = False
self._fail_pending_requests(
ErrorCode.E_PLUGIN_CRASHED,
"新 Runner 预热失败,已回滚",
generation=staged_generation,
)
if staged_connection and not staged_connection.is_closed:
await staged_connection.close()
def register_method(self, method: str, handler: MethodHandler) -> None:
"""注册 RPC 方法处理器"""
self._method_handlers[method] = handler
@@ -106,7 +172,7 @@ class RPCServer:
self._running = False
# 取消所有 pending 请求
for future in self._pending_requests.values():
for future, _generation in self._pending_requests.values():
if not future.done():
future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
self._pending_requests.clear()
@@ -121,6 +187,10 @@ class RPCServer:
await self._connection.close()
self._connection = None
if self._staged_connection:
await self._staged_connection.close()
self._staged_connection = None
await self._transport.stop()
logger.info("RPC Server 已停止")
@@ -130,6 +200,7 @@ class RPCServer:
plugin_id: str = "",
payload: Optional[Dict[str, Any]] = None,
timeout_ms: int = 30000,
target_generation: Optional[int] = None,
) -> Envelope:
"""向 Runner 发送 RPC 请求并等待响应
@@ -145,7 +216,9 @@ class RPCServer:
Raises:
RPCError: 调用失败
"""
if not self.is_connected:
generation = target_generation or self._runner_generation
conn = self._get_connection_for_generation(generation)
if conn is None or conn.is_closed:
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
request_id = self._id_gen.next()
@@ -154,7 +227,7 @@ class RPCServer:
message_type=MessageType.REQUEST,
method=method,
plugin_id=plugin_id,
generation=self._runner_generation,
generation=generation,
timeout_ms=timeout_ms,
payload=payload or {},
)
@@ -166,12 +239,12 @@ class RPCServer:
# 注册 pending future
loop = asyncio.get_running_loop()
future: asyncio.Future[Envelope] = loop.create_future()
self._pending_requests[request_id] = future
self._pending_requests[request_id] = (future, generation)
try:
# 发送请求
data = self._codec.encode_envelope(envelope)
await self._connection.send_frame(data)
await conn.send_frame(data)
# 等待响应
timeout_sec = timeout_ms / 1000.0
@@ -207,11 +280,13 @@ class RPCServer:
async def _handle_connection(self, conn: Connection) -> None:
"""处理新的 Runner 连接"""
logger.info("收到 Runner 连接")
previous_connection = self._connection
previous_generation = self._runner_generation
# 第一条消息必须是 runner.hello 握手
try:
handshake_ok = await self._handle_handshake(conn)
if not handshake_ok:
role = await self._handle_handshake(conn)
if role is None:
await conn.close()
return
except Exception as e:
@@ -219,41 +294,52 @@ class RPCServer:
await conn.close()
return
old_connection = self._connection
self._connection = conn
logger.info(f"Runner 握手成功: runner_id={self._runner_id}, generation={self._runner_generation}")
if role == "staged":
expected_generation = self._staged_runner_generation
logger.info(
f"Runner staged 握手成功: runner_id={self._staged_runner_id}, generation={self._staged_runner_generation}"
)
else:
self._connection = conn
expected_generation = self._runner_generation
logger.info(f"Runner 握手成功: runner_id={self._runner_id}, generation={self._runner_generation}")
if old_connection and old_connection is not conn and not old_connection.is_closed:
logger.info("检测到新 Runner 已接管连接,关闭旧连接")
# 新连接接管后,旧 Runner 的 in-flight 请求不会再收到响应
# (过期 generation 响应会被 _handle_response 丢弃),
# 在此处立即 fail-fast 所有 pending 请求,避免挂到超时
stale_count = 0
for _req_id, future in list(self._pending_requests.items()):
if not future.done():
future.set_exception(RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已被新 generation 接管"))
stale_count += 1
self._pending_requests.clear()
if stale_count:
logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求")
await old_connection.close()
if previous_connection and previous_connection is not conn and not previous_connection.is_closed:
logger.info("检测到新 Runner 已接管连接,关闭旧连接")
stale_count = self._fail_pending_requests(
ErrorCode.E_PLUGIN_CRASHED,
"Runner 连接已被新 generation 接管",
generation=previous_generation,
)
if stale_count:
logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求")
await previous_connection.close()
# 启动消息接收循环
try:
await self._recv_loop(conn)
await self._recv_loop(conn, expected_generation=expected_generation)
except Exception as e:
logger.error(f"连接异常断开: {e}")
finally:
if self._connection is conn:
self._connection = None
self._runner_id = None
# 连接断开时,立即让所有等待中的请求失败,避免挂起至超时
for _req_id, future in list(self._pending_requests.items()):
if not future.done():
future.set_exception(RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已断开"))
self._pending_requests.clear()
self._fail_pending_requests(
ErrorCode.E_PLUGIN_CRASHED,
"Runner 连接已断开",
generation=expected_generation,
)
elif self._staged_connection is conn:
self._staged_connection = None
self._staged_runner_id = None
self._staged_runner_generation = 0
self._fail_pending_requests(
ErrorCode.E_PLUGIN_CRASHED,
"Staged Runner 连接已断开",
generation=expected_generation,
)
async def _handle_handshake(self, conn: Connection) -> bool:
async def _handle_handshake(self, conn: Connection) -> Optional[str]:
"""处理 runner.hello 握手"""
# 接收握手请求
data = await asyncio.wait_for(conn.recv_frame(), timeout=10.0)
@@ -266,7 +352,7 @@ class RPCServer:
"首条消息必须为 runner.hello",
)
await conn.send_frame(self._codec.encode_envelope(error_resp))
return False
return None
# 解析握手 payload
hello = HelloPayload.model_validate(envelope.payload)
@@ -280,7 +366,7 @@ class RPCServer:
)
resp = envelope.make_response(payload=resp_payload.model_dump())
await conn.send_frame(self._codec.encode_envelope(resp))
return False
return None
# 校验 SDK 版本
if not self._check_sdk_version(hello.sdk_version):
@@ -291,23 +377,31 @@ class RPCServer:
)
resp = envelope.make_response(payload=resp_payload.model_dump())
await conn.send_frame(self._codec.encode_envelope(resp))
return False
return None
# 握手成功
self._runner_id = hello.runner_id
self._runner_generation += 1
role = "active"
assigned_generation = self._runner_generation + 1
if self._staging_takeover and self.is_connected:
role = "staged"
self._staged_connection = conn
self._staged_runner_id = hello.runner_id
self._staged_runner_generation = assigned_generation
else:
self._runner_id = hello.runner_id
self._runner_generation = assigned_generation
resp_payload = HelloResponsePayload(
accepted=True,
host_version=PROTOCOL_VERSION,
assigned_generation=self._runner_generation,
assigned_generation=assigned_generation,
)
resp = envelope.make_response(payload=resp_payload.model_dump())
await conn.send_frame(self._codec.encode_envelope(resp))
return True
return role
async def _recv_loop(self, conn: Connection) -> None:
async def _recv_loop(self, conn: Connection, expected_generation: int) -> None:
"""消息接收主循环"""
while self._running and not conn.is_closed:
try:
@@ -329,10 +423,10 @@ class RPCServer:
if envelope.is_response():
self._handle_response(envelope)
elif envelope.is_request():
if not self._is_current_generation(envelope):
if envelope.generation != expected_generation:
error_resp = envelope.make_error_response(
ErrorCode.E_GENERATION_MISMATCH.value,
f"过期 generation: {envelope.generation} != {self._runner_generation}",
f"过期 generation: {envelope.generation} != {expected_generation}",
)
await conn.send_frame(self._codec.encode_envelope(error_resp))
continue
@@ -341,9 +435,9 @@ class RPCServer:
self._tasks.append(task)
task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None)
elif envelope.is_event():
if not self._is_current_generation(envelope):
if envelope.generation != expected_generation:
logger.warning(
f"忽略过期 generation 事件 {envelope.method}: {envelope.generation} != {self._runner_generation}"
f"忽略过期 generation 事件 {envelope.method}: {envelope.generation} != {expected_generation}"
)
continue
task = asyncio.create_task(self._handle_event(envelope))
@@ -352,22 +446,24 @@ class RPCServer:
def _handle_response(self, envelope: Envelope) -> None:
"""处理来自 Runner 的响应"""
if not self._is_current_generation(envelope):
pending = self._pending_requests.get(envelope.request_id)
if pending is None:
return
future, expected_generation = pending
if envelope.generation != expected_generation:
logger.warning(
f"忽略过期 generation 响应 {envelope.method}: {envelope.generation} != {self._runner_generation}"
f"忽略过期 generation 响应 {envelope.method}: {envelope.generation} != {expected_generation}"
)
return
future = self._pending_requests.pop(envelope.request_id, None)
if future and not future.done():
self._pending_requests.pop(envelope.request_id, None)
if not future.done():
if envelope.error:
future.set_exception(RPCError.from_dict(envelope.error))
else:
future.set_result(envelope)
def _is_current_generation(self, envelope: Envelope) -> bool:
return envelope.generation == self._runner_generation
async def _handle_request(self, envelope: Envelope, conn: Connection) -> None:
"""处理来自 Runner 的请求(通常是能力调用 cap.*"""
handler = self._method_handlers.get(envelope.method)
@@ -411,3 +507,26 @@ class RPCServer:
return min_parts <= sdk_parts <= max_parts
except (ValueError, AttributeError):
return False
def _get_connection_for_generation(self, generation: int) -> Optional[Connection]:
if generation == self._runner_generation:
return self._connection
if generation == self._staged_runner_generation:
return self._staged_connection
return None
def _fail_pending_requests(
self,
error_code: ErrorCode,
message: str,
generation: Optional[int] = None,
) -> int:
stale_count = 0
for request_id, (future, request_generation) in list(self._pending_requests.items()):
if generation is not None and request_generation != generation:
continue
if not future.done():
future.set_exception(RPCError(error_code, message))
stale_count += 1
self._pending_requests.pop(request_id, None)
return stale_count

View File

@@ -135,6 +135,7 @@ class PluginSupervisor:
# 已注册的插件组件信息
self._registered_plugins: Dict[str, RegisterComponentsPayload] = {}
self._staged_registered_plugins: Dict[str, RegisterComponentsPayload] = {}
# 后台任务
self._health_task: Optional[asyncio.Task] = None
@@ -319,6 +320,10 @@ class PluginSupervisor:
old_session_token = self._rpc_server.session_token
expected_generation = self._rpc_server.runner_generation + 1
# 允许新 Runner 以 staged 方式接入,验证通过后再切换活跃连接
self._rpc_server.begin_staged_takeover()
self._staged_registered_plugins.clear()
# 重新生成 session token防止被终止的旧 Runner 重连
self._rpc_server.reset_session_token()
@@ -330,27 +335,35 @@ class PluginSupervisor:
# 拉起新 Runner
try:
await self._spawn_runner()
await self._wait_for_runner_generation(expected_generation, timeout_sec=self._runner_spawn_timeout)
resp = await self._rpc_server.send_request("plugin.health", timeout_ms=5000)
await self._wait_for_runner_generation(
expected_generation,
timeout_sec=self._runner_spawn_timeout,
allow_staged=True,
)
resp = await self._rpc_server.send_request(
"plugin.health",
timeout_ms=5000,
target_generation=expected_generation,
)
health = HealthPayload.model_validate(resp.payload)
if not health.healthy:
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "新 Runner 健康检查失败")
await self._rpc_server.commit_staged_takeover()
except Exception as e:
logger.error(f"新 Runner 健康检查失败: {e},回滚")
await self._terminate_process(self._runner_process, old_process)
await self._rpc_server.rollback_staged_takeover()
self._runner_process = old_process
# 恢复旧 session token使旧 Runner 的连接仍可正常工作
self._rpc_server.restore_session_token(old_session_token)
self._staged_registered_plugins.clear()
self._registered_plugins = dict(old_registered_plugins)
self._rebuild_runtime_state()
return
# 新 Runner 健康且已完成组件注册,现在清理旧的幽灵组件
# 只移除不再存在于新注册表中的旧插件组件
for old_pid in list(old_registered_plugins.keys()):
if old_pid not in self._registered_plugins:
self._component_registry.remove_components_by_plugin(old_pid)
self._policy.revoke_plugin(old_pid)
self._runner_generation = self._rpc_server.runner_generation
self._registered_plugins = dict(self._staged_registered_plugins)
self._staged_registered_plugins.clear()
self._rebuild_runtime_state()
# 关停旧 Runner
if old_process and old_process.returncode is None:
@@ -380,13 +393,22 @@ class PluginSupervisor:
except Exception as e:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e))
if envelope.generation != self._rpc_server.runner_generation:
active_generation = self._rpc_server.runner_generation
staged_generation = self._rpc_server.staged_generation
if envelope.generation not in {active_generation, staged_generation}:
return envelope.make_error_response(
ErrorCode.E_GENERATION_MISMATCH.value,
f"组件注册 generation 过期: {envelope.generation} != {self._rpc_server.runner_generation}",
f"组件注册 generation 过期: {envelope.generation} 不在已知代际中",
)
# 记录注册信息
if envelope.generation == staged_generation and staged_generation != 0:
self._staged_registered_plugins[reg.plugin_id] = reg
logger.info(
f"插件 {reg.plugin_id} v{reg.plugin_version} staged 注册成功,"
f"组件数: {len(reg.components)}, 能力需求: {reg.capabilities_required}"
)
return envelope.make_response(payload={"accepted": True, "staged": True})
self._registered_plugins[reg.plugin_id] = reg
# 在策略引擎中注册插件
@@ -396,7 +418,8 @@ class PluginSupervisor:
capabilities=reg.capabilities_required or [],
)
# 在 ComponentRegistry 中注册组件
# 同 generation 下重新注册时,以本次声明为准,避免残留幽灵组件
self._component_registry.remove_components_by_plugin(reg.plugin_id)
self._component_registry.register_plugin_components(
plugin_id=reg.plugin_id,
components=[c.model_dump() for c in reg.components],
@@ -518,10 +541,17 @@ class PluginSupervisor:
except Exception as e:
logger.error(f"健康检查异常: {e}")
async def _wait_for_runner_generation(self, expected_generation: int, timeout_sec: float) -> None:
async def _wait_for_runner_generation(
self,
expected_generation: int,
timeout_sec: float,
allow_staged: bool = False,
) -> None:
"""等待指定代际的 Runner 完成连接。"""
deadline = asyncio.get_running_loop().time() + timeout_sec
while asyncio.get_running_loop().time() < deadline:
if allow_staged and self._rpc_server.has_generation(expected_generation):
return
if self._rpc_server.is_connected and self._rpc_server.runner_generation >= expected_generation:
self._runner_generation = self._rpc_server.runner_generation
return
@@ -533,6 +563,7 @@ class PluginSupervisor:
self._component_registry.clear()
self._policy.clear()
self._registered_plugins.clear()
self._staged_registered_plugins.clear()
def _rebuild_runtime_state(self) -> None:
"""根据已记录的插件注册信息重建运行时状态。"""

View File

@@ -430,8 +430,10 @@ class PluginRuntimeManager:
"""
from src.services import send_service as send_api
message_type: str = args.get("message_type", "")
content = args.get("content", "")
message_type: str = args.get("message_type", "") or args.get("custom_type", "")
content = args.get("content")
if content is None:
content = args.get("data", "")
stream_id: str = args.get("stream_id", "")
if not message_type or not stream_id:
return {"success": False, "error": "缺少必要参数 message_type 或 stream_id"}

View File

@@ -76,6 +76,7 @@ class RPCClient:
# 运行状态
self._running = False
self._recv_task: Optional[asyncio.Task] = None
self._background_tasks: set[asyncio.Task] = set()
@property
def generation(self) -> int:
@@ -144,6 +145,13 @@ class RPCClient:
await self._recv_task
self._recv_task = None
for task in list(self._background_tasks):
task.cancel()
if self._background_tasks:
with contextlib.suppress(Exception):
await asyncio.gather(*self._background_tasks, return_exceptions=True)
self._background_tasks.clear()
# 取消所有 pending 请求
for future in self._pending_requests.values():
if not future.done():
@@ -248,9 +256,9 @@ class RPCClient:
if envelope.is_response():
self._handle_response(envelope)
elif envelope.is_request():
asyncio.create_task(self._handle_request(envelope))
self._track_background_task(asyncio.create_task(self._handle_request(envelope)))
elif envelope.is_event():
asyncio.create_task(self._handle_event(envelope))
self._track_background_task(asyncio.create_task(self._handle_event(envelope)))
def _handle_response(self, envelope: Envelope) -> None:
"""处理来自 Host 的响应"""
@@ -290,3 +298,8 @@ class RPCClient:
await handler(envelope)
except Exception as e:
logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True)
def _track_background_task(self, task: asyncio.Task) -> None:
"""保持后台任务强引用,直到其完成或被取消。"""
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)