feat: 实现插件注册的分阶段接入与切换机制,优化 RPC 连接管理
This commit is contained in:
@@ -677,6 +677,28 @@ class TestComponentRegistry:
|
||||
assert removed == 2
|
||||
assert reg.get_stats()["total"] == 1
|
||||
|
||||
def test_reregister_same_plugin_replaces_component_set(self):
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_plugin_components(
|
||||
"p1",
|
||||
[
|
||||
{"name": "a1", "component_type": "action", "metadata": {}},
|
||||
{"name": "a2", "component_type": "action", "metadata": {}},
|
||||
],
|
||||
)
|
||||
reg.remove_components_by_plugin("p1")
|
||||
reg.register_plugin_components(
|
||||
"p1",
|
||||
[
|
||||
{"name": "a1", "component_type": "action", "metadata": {}},
|
||||
],
|
||||
)
|
||||
|
||||
assert reg.get_component("p1.a1") is not None
|
||||
assert reg.get_component("p1.a2") is None
|
||||
|
||||
def test_event_handlers_sorted_by_weight(self):
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
|
||||
@@ -1257,7 +1279,7 @@ class TestRPCServer:
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
future = loop.create_future()
|
||||
server._pending_requests[1] = future
|
||||
server._pending_requests[1] = (future, 2)
|
||||
|
||||
stale_response = Envelope(
|
||||
request_id=1,
|
||||
@@ -1274,23 +1296,53 @@ class TestRPCServer:
|
||||
loop.close()
|
||||
|
||||
|
||||
class TestRPCClient:
|
||||
"""Runner RPCClient 后台任务生命周期测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_background_tasks_retained_and_cancelled_on_disconnect(self):
|
||||
from src.plugin_runtime.runner.rpc_client import RPCClient
|
||||
|
||||
client = RPCClient(host_address="dummy", session_token="token")
|
||||
release = asyncio.Event()
|
||||
|
||||
async def pending_task():
|
||||
await release.wait()
|
||||
|
||||
task = asyncio.create_task(pending_task())
|
||||
client._track_background_task(task)
|
||||
|
||||
assert task in client._background_tasks
|
||||
|
||||
await asyncio.sleep(0)
|
||||
assert task in client._background_tasks
|
||||
|
||||
await client.disconnect()
|
||||
|
||||
assert task.cancelled() is True
|
||||
assert not client._background_tasks
|
||||
|
||||
|
||||
class TestSupervisor:
|
||||
"""Supervisor 生命周期边界测试"""
|
||||
|
||||
@staticmethod
|
||||
def _build_register_payload(plugin_id: str = "plugin_a"):
|
||||
def _build_register_payload(plugin_id: str = "plugin_a", component_names=None):
|
||||
from src.plugin_runtime.protocol.envelope import ComponentDeclaration, RegisterComponentsPayload
|
||||
|
||||
component_names = component_names or ["handler"]
|
||||
|
||||
return RegisterComponentsPayload(
|
||||
plugin_id=plugin_id,
|
||||
plugin_version="1.0.0",
|
||||
components=[
|
||||
ComponentDeclaration(
|
||||
name="handler",
|
||||
name=name,
|
||||
component_type="event_handler",
|
||||
plugin_id=plugin_id,
|
||||
metadata={"event_type": "on_message"},
|
||||
)
|
||||
for name in component_names
|
||||
],
|
||||
capabilities_required=["send.text"],
|
||||
)
|
||||
@@ -1331,8 +1383,11 @@ class TestSupervisor:
|
||||
class FakeRPCServer:
|
||||
def __init__(self):
|
||||
self.runner_generation = 1
|
||||
self.staged_generation = 0
|
||||
self.is_connected = True
|
||||
self.session_token = "fake-token"
|
||||
self.committed = False
|
||||
self.staging_started = False
|
||||
|
||||
def reset_session_token(self):
|
||||
self.session_token = "new-fake-token"
|
||||
@@ -1341,8 +1396,23 @@ class TestSupervisor:
|
||||
def restore_session_token(self, token):
|
||||
self.session_token = token
|
||||
|
||||
async def send_request(self, method, timeout_ms=5000, **kwargs):
|
||||
assert self.runner_generation == 2
|
||||
def begin_staged_takeover(self):
|
||||
self.staging_started = True
|
||||
self.staged_generation = 2
|
||||
|
||||
async def commit_staged_takeover(self):
|
||||
self.runner_generation = self.staged_generation
|
||||
self.staged_generation = 0
|
||||
self.committed = True
|
||||
|
||||
async def rollback_staged_takeover(self):
|
||||
self.staged_generation = 0
|
||||
|
||||
def has_generation(self, generation):
|
||||
return generation in {self.runner_generation, self.staged_generation}
|
||||
|
||||
async def send_request(self, method, timeout_ms=5000, target_generation=None, **kwargs):
|
||||
assert target_generation == 2
|
||||
return SimpleNamespace(payload=HealthPayload(healthy=True).model_dump())
|
||||
|
||||
supervisor._rpc_server = FakeRPCServer()
|
||||
@@ -1350,18 +1420,14 @@ class TestSupervisor:
|
||||
|
||||
async def fake_spawn_runner():
|
||||
supervisor._runner_process = new_process
|
||||
|
||||
async def advance_generation():
|
||||
await asyncio.sleep(0.01)
|
||||
supervisor._rpc_server.runner_generation = 2
|
||||
|
||||
asyncio.create_task(advance_generation())
|
||||
supervisor._staged_registered_plugins["plugin_a"] = self._build_register_payload("plugin_a")
|
||||
|
||||
monkeypatch.setattr(supervisor, "_spawn_runner", fake_spawn_runner)
|
||||
|
||||
await supervisor.reload_plugins("test")
|
||||
|
||||
assert supervisor._runner_process is new_process
|
||||
assert supervisor._rpc_server.committed is True
|
||||
assert old_process.terminated is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -1380,6 +1446,69 @@ class TestSupervisor:
|
||||
class FakeRPCServer:
|
||||
def __init__(self):
|
||||
self.runner_generation = 1
|
||||
self.staged_generation = 0
|
||||
self.is_connected = True
|
||||
self.session_token = "fake-token"
|
||||
self.rolled_back = False
|
||||
|
||||
def reset_session_token(self):
|
||||
self.session_token = "new-fake-token"
|
||||
return self.session_token
|
||||
|
||||
def restore_session_token(self, token):
|
||||
self.session_token = token
|
||||
|
||||
def begin_staged_takeover(self):
|
||||
self.staged_generation = 2
|
||||
|
||||
async def commit_staged_takeover(self):
|
||||
self.runner_generation = self.staged_generation
|
||||
self.staged_generation = 0
|
||||
|
||||
async def rollback_staged_takeover(self):
|
||||
self.rolled_back = True
|
||||
self.staged_generation = 0
|
||||
|
||||
def has_generation(self, generation):
|
||||
return generation in {self.runner_generation, self.staged_generation}
|
||||
|
||||
async def send_request(self, method, timeout_ms=5000, target_generation=None, **kwargs):
|
||||
raise RuntimeError("new runner unhealthy")
|
||||
|
||||
supervisor._rpc_server = FakeRPCServer()
|
||||
|
||||
async def fake_spawn_runner():
|
||||
supervisor._runner_process = new_process
|
||||
supervisor._staged_registered_plugins["plugin_a"] = self._build_register_payload("plugin_a")
|
||||
|
||||
monkeypatch.setattr(supervisor, "_spawn_runner", fake_spawn_runner)
|
||||
|
||||
await supervisor.reload_plugins("test")
|
||||
|
||||
assert supervisor._runner_process is old_process
|
||||
assert supervisor._rpc_server.rolled_back is True
|
||||
assert old_reg.plugin_id in supervisor._registered_plugins
|
||||
assert supervisor.component_registry.get_component("plugin_a.handler") is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reload_rebuilds_exact_component_set(self, monkeypatch):
|
||||
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
||||
from src.plugin_runtime.protocol.envelope import HealthPayload
|
||||
|
||||
supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
old_process = self._make_process(1)
|
||||
new_process = self._make_process(2)
|
||||
old_reg = self._build_register_payload("plugin_a", component_names=["handler", "obsolete"])
|
||||
new_reg = self._build_register_payload("plugin_a", component_names=["handler"])
|
||||
|
||||
supervisor._runner_process = old_process
|
||||
supervisor._registered_plugins[old_reg.plugin_id] = old_reg
|
||||
supervisor._rebuild_runtime_state()
|
||||
|
||||
class FakeRPCServer:
|
||||
def __init__(self):
|
||||
self.runner_generation = 1
|
||||
self.staged_generation = 0
|
||||
self.is_connected = True
|
||||
self.session_token = "fake-token"
|
||||
|
||||
@@ -1390,22 +1519,34 @@ class TestSupervisor:
|
||||
def restore_session_token(self, token):
|
||||
self.session_token = token
|
||||
|
||||
async def send_request(self, method, timeout_ms=5000, **kwargs):
|
||||
raise RuntimeError("new runner unhealthy")
|
||||
def begin_staged_takeover(self):
|
||||
self.staged_generation = 2
|
||||
|
||||
async def commit_staged_takeover(self):
|
||||
self.runner_generation = self.staged_generation
|
||||
self.staged_generation = 0
|
||||
|
||||
async def rollback_staged_takeover(self):
|
||||
self.staged_generation = 0
|
||||
|
||||
def has_generation(self, generation):
|
||||
return generation in {self.runner_generation, self.staged_generation}
|
||||
|
||||
async def send_request(self, method, timeout_ms=5000, target_generation=None, **kwargs):
|
||||
return SimpleNamespace(payload=HealthPayload(healthy=True).model_dump())
|
||||
|
||||
supervisor._rpc_server = FakeRPCServer()
|
||||
|
||||
async def fake_spawn_runner():
|
||||
supervisor._runner_process = new_process
|
||||
supervisor._rpc_server.runner_generation = 2
|
||||
supervisor._staged_registered_plugins[new_reg.plugin_id] = new_reg
|
||||
|
||||
monkeypatch.setattr(supervisor, "_spawn_runner", fake_spawn_runner)
|
||||
|
||||
await supervisor.reload_plugins("test")
|
||||
|
||||
assert supervisor._runner_process is old_process
|
||||
assert old_reg.plugin_id in supervisor._registered_plugins
|
||||
assert supervisor.component_registry.get_component("plugin_a.handler") is not None
|
||||
assert supervisor.component_registry.get_component("plugin_a.obsolete") is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_attach_stderr_drain_drains_stream(self):
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
4. 请求-响应关联与超时管理
|
||||
"""
|
||||
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import asyncio
|
||||
import secrets
|
||||
@@ -55,12 +55,16 @@ class RPCServer:
|
||||
self._connection: Optional[Connection] = None # 当前活跃的 Runner 连接
|
||||
self._runner_id: Optional[str] = None
|
||||
self._runner_generation: int = 0
|
||||
self._staged_connection: Optional[Connection] = None
|
||||
self._staged_runner_id: Optional[str] = None
|
||||
self._staged_runner_generation: int = 0
|
||||
self._staging_takeover: bool = False
|
||||
|
||||
# 方法处理器注册表
|
||||
self._method_handlers: Dict[str, MethodHandler] = {}
|
||||
|
||||
# 等待响应的 pending 请求: request_id -> Future
|
||||
self._pending_requests: Dict[int, asyncio.Future] = {}
|
||||
# 等待响应的 pending 请求: request_id -> (Future, target_generation)
|
||||
self._pending_requests: Dict[int, Tuple[asyncio.Future, int]] = {}
|
||||
|
||||
# 发送队列(背压控制)
|
||||
self._send_queue: Optional[asyncio.Queue[bytes]] = None
|
||||
@@ -86,10 +90,72 @@ class RPCServer:
|
||||
def runner_generation(self) -> int:
|
||||
return self._runner_generation
|
||||
|
||||
@property
|
||||
def staged_generation(self) -> int:
|
||||
return self._staged_runner_generation
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._connection is not None and not self._connection.is_closed
|
||||
|
||||
def has_generation(self, generation: int) -> bool:
|
||||
return generation == self._runner_generation or (
|
||||
self._staged_connection is not None
|
||||
and not self._staged_connection.is_closed
|
||||
and generation == self._staged_runner_generation
|
||||
)
|
||||
|
||||
def begin_staged_takeover(self) -> None:
|
||||
"""允许新 Runner 以 staged 方式接入,待 Supervisor 验证后再切换为活跃连接。"""
|
||||
self._staging_takeover = True
|
||||
|
||||
async def commit_staged_takeover(self) -> None:
|
||||
"""提交 staged Runner,原活跃连接在提交后被关闭。"""
|
||||
if self._staged_connection is None or self._staged_connection.is_closed:
|
||||
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "没有可提交的新 Runner 连接")
|
||||
|
||||
old_connection = self._connection
|
||||
old_generation = self._runner_generation
|
||||
|
||||
self._connection = self._staged_connection
|
||||
self._runner_id = self._staged_runner_id
|
||||
self._runner_generation = self._staged_runner_generation
|
||||
|
||||
self._staged_connection = None
|
||||
self._staged_runner_id = None
|
||||
self._staged_runner_generation = 0
|
||||
self._staging_takeover = False
|
||||
|
||||
stale_count = self._fail_pending_requests(
|
||||
ErrorCode.E_PLUGIN_CRASHED,
|
||||
"Runner 连接已被新 generation 接管",
|
||||
generation=old_generation,
|
||||
)
|
||||
if stale_count:
|
||||
logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求")
|
||||
|
||||
if old_connection and old_connection is not self._connection and not old_connection.is_closed:
|
||||
await old_connection.close()
|
||||
|
||||
async def rollback_staged_takeover(self) -> None:
|
||||
"""放弃 staged Runner,保留当前活跃连接。"""
|
||||
staged_connection = self._staged_connection
|
||||
staged_generation = self._staged_runner_generation
|
||||
|
||||
self._staged_connection = None
|
||||
self._staged_runner_id = None
|
||||
self._staged_runner_generation = 0
|
||||
self._staging_takeover = False
|
||||
|
||||
self._fail_pending_requests(
|
||||
ErrorCode.E_PLUGIN_CRASHED,
|
||||
"新 Runner 预热失败,已回滚",
|
||||
generation=staged_generation,
|
||||
)
|
||||
|
||||
if staged_connection and not staged_connection.is_closed:
|
||||
await staged_connection.close()
|
||||
|
||||
def register_method(self, method: str, handler: MethodHandler) -> None:
|
||||
"""注册 RPC 方法处理器"""
|
||||
self._method_handlers[method] = handler
|
||||
@@ -106,7 +172,7 @@ class RPCServer:
|
||||
self._running = False
|
||||
|
||||
# 取消所有 pending 请求
|
||||
for future in self._pending_requests.values():
|
||||
for future, _generation in self._pending_requests.values():
|
||||
if not future.done():
|
||||
future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
|
||||
self._pending_requests.clear()
|
||||
@@ -121,6 +187,10 @@ class RPCServer:
|
||||
await self._connection.close()
|
||||
self._connection = None
|
||||
|
||||
if self._staged_connection:
|
||||
await self._staged_connection.close()
|
||||
self._staged_connection = None
|
||||
|
||||
await self._transport.stop()
|
||||
logger.info("RPC Server 已停止")
|
||||
|
||||
@@ -130,6 +200,7 @@ class RPCServer:
|
||||
plugin_id: str = "",
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
timeout_ms: int = 30000,
|
||||
target_generation: Optional[int] = None,
|
||||
) -> Envelope:
|
||||
"""向 Runner 发送 RPC 请求并等待响应
|
||||
|
||||
@@ -145,7 +216,9 @@ class RPCServer:
|
||||
Raises:
|
||||
RPCError: 调用失败
|
||||
"""
|
||||
if not self.is_connected:
|
||||
generation = target_generation or self._runner_generation
|
||||
conn = self._get_connection_for_generation(generation)
|
||||
if conn is None or conn.is_closed:
|
||||
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
|
||||
|
||||
request_id = self._id_gen.next()
|
||||
@@ -154,7 +227,7 @@ class RPCServer:
|
||||
message_type=MessageType.REQUEST,
|
||||
method=method,
|
||||
plugin_id=plugin_id,
|
||||
generation=self._runner_generation,
|
||||
generation=generation,
|
||||
timeout_ms=timeout_ms,
|
||||
payload=payload or {},
|
||||
)
|
||||
@@ -166,12 +239,12 @@ class RPCServer:
|
||||
# 注册 pending future
|
||||
loop = asyncio.get_running_loop()
|
||||
future: asyncio.Future[Envelope] = loop.create_future()
|
||||
self._pending_requests[request_id] = future
|
||||
self._pending_requests[request_id] = (future, generation)
|
||||
|
||||
try:
|
||||
# 发送请求
|
||||
data = self._codec.encode_envelope(envelope)
|
||||
await self._connection.send_frame(data)
|
||||
await conn.send_frame(data)
|
||||
|
||||
# 等待响应
|
||||
timeout_sec = timeout_ms / 1000.0
|
||||
@@ -207,11 +280,13 @@ class RPCServer:
|
||||
async def _handle_connection(self, conn: Connection) -> None:
|
||||
"""处理新的 Runner 连接"""
|
||||
logger.info("收到 Runner 连接")
|
||||
previous_connection = self._connection
|
||||
previous_generation = self._runner_generation
|
||||
|
||||
# 第一条消息必须是 runner.hello 握手
|
||||
try:
|
||||
handshake_ok = await self._handle_handshake(conn)
|
||||
if not handshake_ok:
|
||||
role = await self._handle_handshake(conn)
|
||||
if role is None:
|
||||
await conn.close()
|
||||
return
|
||||
except Exception as e:
|
||||
@@ -219,41 +294,52 @@ class RPCServer:
|
||||
await conn.close()
|
||||
return
|
||||
|
||||
old_connection = self._connection
|
||||
self._connection = conn
|
||||
logger.info(f"Runner 握手成功: runner_id={self._runner_id}, generation={self._runner_generation}")
|
||||
if role == "staged":
|
||||
expected_generation = self._staged_runner_generation
|
||||
logger.info(
|
||||
f"Runner staged 握手成功: runner_id={self._staged_runner_id}, generation={self._staged_runner_generation}"
|
||||
)
|
||||
else:
|
||||
self._connection = conn
|
||||
expected_generation = self._runner_generation
|
||||
logger.info(f"Runner 握手成功: runner_id={self._runner_id}, generation={self._runner_generation}")
|
||||
|
||||
if old_connection and old_connection is not conn and not old_connection.is_closed:
|
||||
logger.info("检测到新 Runner 已接管连接,关闭旧连接")
|
||||
# 新连接接管后,旧 Runner 的 in-flight 请求不会再收到响应
|
||||
# (过期 generation 响应会被 _handle_response 丢弃),
|
||||
# 在此处立即 fail-fast 所有 pending 请求,避免挂到超时
|
||||
stale_count = 0
|
||||
for _req_id, future in list(self._pending_requests.items()):
|
||||
if not future.done():
|
||||
future.set_exception(RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已被新 generation 接管"))
|
||||
stale_count += 1
|
||||
self._pending_requests.clear()
|
||||
if stale_count:
|
||||
logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求")
|
||||
await old_connection.close()
|
||||
if previous_connection and previous_connection is not conn and not previous_connection.is_closed:
|
||||
logger.info("检测到新 Runner 已接管连接,关闭旧连接")
|
||||
stale_count = self._fail_pending_requests(
|
||||
ErrorCode.E_PLUGIN_CRASHED,
|
||||
"Runner 连接已被新 generation 接管",
|
||||
generation=previous_generation,
|
||||
)
|
||||
if stale_count:
|
||||
logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求")
|
||||
await previous_connection.close()
|
||||
|
||||
# 启动消息接收循环
|
||||
try:
|
||||
await self._recv_loop(conn)
|
||||
await self._recv_loop(conn, expected_generation=expected_generation)
|
||||
except Exception as e:
|
||||
logger.error(f"连接异常断开: {e}")
|
||||
finally:
|
||||
if self._connection is conn:
|
||||
self._connection = None
|
||||
self._runner_id = None
|
||||
# 连接断开时,立即让所有等待中的请求失败,避免挂起至超时
|
||||
for _req_id, future in list(self._pending_requests.items()):
|
||||
if not future.done():
|
||||
future.set_exception(RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已断开"))
|
||||
self._pending_requests.clear()
|
||||
self._fail_pending_requests(
|
||||
ErrorCode.E_PLUGIN_CRASHED,
|
||||
"Runner 连接已断开",
|
||||
generation=expected_generation,
|
||||
)
|
||||
elif self._staged_connection is conn:
|
||||
self._staged_connection = None
|
||||
self._staged_runner_id = None
|
||||
self._staged_runner_generation = 0
|
||||
self._fail_pending_requests(
|
||||
ErrorCode.E_PLUGIN_CRASHED,
|
||||
"Staged Runner 连接已断开",
|
||||
generation=expected_generation,
|
||||
)
|
||||
|
||||
async def _handle_handshake(self, conn: Connection) -> bool:
|
||||
async def _handle_handshake(self, conn: Connection) -> Optional[str]:
|
||||
"""处理 runner.hello 握手"""
|
||||
# 接收握手请求
|
||||
data = await asyncio.wait_for(conn.recv_frame(), timeout=10.0)
|
||||
@@ -266,7 +352,7 @@ class RPCServer:
|
||||
"首条消息必须为 runner.hello",
|
||||
)
|
||||
await conn.send_frame(self._codec.encode_envelope(error_resp))
|
||||
return False
|
||||
return None
|
||||
|
||||
# 解析握手 payload
|
||||
hello = HelloPayload.model_validate(envelope.payload)
|
||||
@@ -280,7 +366,7 @@ class RPCServer:
|
||||
)
|
||||
resp = envelope.make_response(payload=resp_payload.model_dump())
|
||||
await conn.send_frame(self._codec.encode_envelope(resp))
|
||||
return False
|
||||
return None
|
||||
|
||||
# 校验 SDK 版本
|
||||
if not self._check_sdk_version(hello.sdk_version):
|
||||
@@ -291,23 +377,31 @@ class RPCServer:
|
||||
)
|
||||
resp = envelope.make_response(payload=resp_payload.model_dump())
|
||||
await conn.send_frame(self._codec.encode_envelope(resp))
|
||||
return False
|
||||
return None
|
||||
|
||||
# 握手成功
|
||||
self._runner_id = hello.runner_id
|
||||
self._runner_generation += 1
|
||||
role = "active"
|
||||
assigned_generation = self._runner_generation + 1
|
||||
if self._staging_takeover and self.is_connected:
|
||||
role = "staged"
|
||||
self._staged_connection = conn
|
||||
self._staged_runner_id = hello.runner_id
|
||||
self._staged_runner_generation = assigned_generation
|
||||
else:
|
||||
self._runner_id = hello.runner_id
|
||||
self._runner_generation = assigned_generation
|
||||
|
||||
resp_payload = HelloResponsePayload(
|
||||
accepted=True,
|
||||
host_version=PROTOCOL_VERSION,
|
||||
assigned_generation=self._runner_generation,
|
||||
assigned_generation=assigned_generation,
|
||||
)
|
||||
resp = envelope.make_response(payload=resp_payload.model_dump())
|
||||
await conn.send_frame(self._codec.encode_envelope(resp))
|
||||
|
||||
return True
|
||||
return role
|
||||
|
||||
async def _recv_loop(self, conn: Connection) -> None:
|
||||
async def _recv_loop(self, conn: Connection, expected_generation: int) -> None:
|
||||
"""消息接收主循环"""
|
||||
while self._running and not conn.is_closed:
|
||||
try:
|
||||
@@ -329,10 +423,10 @@ class RPCServer:
|
||||
if envelope.is_response():
|
||||
self._handle_response(envelope)
|
||||
elif envelope.is_request():
|
||||
if not self._is_current_generation(envelope):
|
||||
if envelope.generation != expected_generation:
|
||||
error_resp = envelope.make_error_response(
|
||||
ErrorCode.E_GENERATION_MISMATCH.value,
|
||||
f"过期 generation: {envelope.generation} != {self._runner_generation}",
|
||||
f"过期 generation: {envelope.generation} != {expected_generation}",
|
||||
)
|
||||
await conn.send_frame(self._codec.encode_envelope(error_resp))
|
||||
continue
|
||||
@@ -341,9 +435,9 @@ class RPCServer:
|
||||
self._tasks.append(task)
|
||||
task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None)
|
||||
elif envelope.is_event():
|
||||
if not self._is_current_generation(envelope):
|
||||
if envelope.generation != expected_generation:
|
||||
logger.warning(
|
||||
f"忽略过期 generation 事件 {envelope.method}: {envelope.generation} != {self._runner_generation}"
|
||||
f"忽略过期 generation 事件 {envelope.method}: {envelope.generation} != {expected_generation}"
|
||||
)
|
||||
continue
|
||||
task = asyncio.create_task(self._handle_event(envelope))
|
||||
@@ -352,22 +446,24 @@ class RPCServer:
|
||||
|
||||
def _handle_response(self, envelope: Envelope) -> None:
|
||||
"""处理来自 Runner 的响应"""
|
||||
if not self._is_current_generation(envelope):
|
||||
pending = self._pending_requests.get(envelope.request_id)
|
||||
if pending is None:
|
||||
return
|
||||
|
||||
future, expected_generation = pending
|
||||
if envelope.generation != expected_generation:
|
||||
logger.warning(
|
||||
f"忽略过期 generation 响应 {envelope.method}: {envelope.generation} != {self._runner_generation}"
|
||||
f"忽略过期 generation 响应 {envelope.method}: {envelope.generation} != {expected_generation}"
|
||||
)
|
||||
return
|
||||
|
||||
future = self._pending_requests.pop(envelope.request_id, None)
|
||||
if future and not future.done():
|
||||
self._pending_requests.pop(envelope.request_id, None)
|
||||
if not future.done():
|
||||
if envelope.error:
|
||||
future.set_exception(RPCError.from_dict(envelope.error))
|
||||
else:
|
||||
future.set_result(envelope)
|
||||
|
||||
def _is_current_generation(self, envelope: Envelope) -> bool:
|
||||
return envelope.generation == self._runner_generation
|
||||
|
||||
async def _handle_request(self, envelope: Envelope, conn: Connection) -> None:
|
||||
"""处理来自 Runner 的请求(通常是能力调用 cap.*)"""
|
||||
handler = self._method_handlers.get(envelope.method)
|
||||
@@ -411,3 +507,26 @@ class RPCServer:
|
||||
return min_parts <= sdk_parts <= max_parts
|
||||
except (ValueError, AttributeError):
|
||||
return False
|
||||
|
||||
def _get_connection_for_generation(self, generation: int) -> Optional[Connection]:
|
||||
if generation == self._runner_generation:
|
||||
return self._connection
|
||||
if generation == self._staged_runner_generation:
|
||||
return self._staged_connection
|
||||
return None
|
||||
|
||||
def _fail_pending_requests(
|
||||
self,
|
||||
error_code: ErrorCode,
|
||||
message: str,
|
||||
generation: Optional[int] = None,
|
||||
) -> int:
|
||||
stale_count = 0
|
||||
for request_id, (future, request_generation) in list(self._pending_requests.items()):
|
||||
if generation is not None and request_generation != generation:
|
||||
continue
|
||||
if not future.done():
|
||||
future.set_exception(RPCError(error_code, message))
|
||||
stale_count += 1
|
||||
self._pending_requests.pop(request_id, None)
|
||||
return stale_count
|
||||
|
||||
@@ -135,6 +135,7 @@ class PluginSupervisor:
|
||||
|
||||
# 已注册的插件组件信息
|
||||
self._registered_plugins: Dict[str, RegisterComponentsPayload] = {}
|
||||
self._staged_registered_plugins: Dict[str, RegisterComponentsPayload] = {}
|
||||
|
||||
# 后台任务
|
||||
self._health_task: Optional[asyncio.Task] = None
|
||||
@@ -319,6 +320,10 @@ class PluginSupervisor:
|
||||
old_session_token = self._rpc_server.session_token
|
||||
expected_generation = self._rpc_server.runner_generation + 1
|
||||
|
||||
# 允许新 Runner 以 staged 方式接入,验证通过后再切换活跃连接
|
||||
self._rpc_server.begin_staged_takeover()
|
||||
self._staged_registered_plugins.clear()
|
||||
|
||||
# 重新生成 session token,防止被终止的旧 Runner 重连
|
||||
self._rpc_server.reset_session_token()
|
||||
|
||||
@@ -330,27 +335,35 @@ class PluginSupervisor:
|
||||
# 拉起新 Runner
|
||||
try:
|
||||
await self._spawn_runner()
|
||||
await self._wait_for_runner_generation(expected_generation, timeout_sec=self._runner_spawn_timeout)
|
||||
resp = await self._rpc_server.send_request("plugin.health", timeout_ms=5000)
|
||||
await self._wait_for_runner_generation(
|
||||
expected_generation,
|
||||
timeout_sec=self._runner_spawn_timeout,
|
||||
allow_staged=True,
|
||||
)
|
||||
resp = await self._rpc_server.send_request(
|
||||
"plugin.health",
|
||||
timeout_ms=5000,
|
||||
target_generation=expected_generation,
|
||||
)
|
||||
health = HealthPayload.model_validate(resp.payload)
|
||||
if not health.healthy:
|
||||
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "新 Runner 健康检查失败")
|
||||
await self._rpc_server.commit_staged_takeover()
|
||||
except Exception as e:
|
||||
logger.error(f"新 Runner 健康检查失败: {e},回滚")
|
||||
await self._terminate_process(self._runner_process, old_process)
|
||||
await self._rpc_server.rollback_staged_takeover()
|
||||
self._runner_process = old_process
|
||||
# 恢复旧 session token,使旧 Runner 的连接仍可正常工作
|
||||
self._rpc_server.restore_session_token(old_session_token)
|
||||
self._staged_registered_plugins.clear()
|
||||
self._registered_plugins = dict(old_registered_plugins)
|
||||
self._rebuild_runtime_state()
|
||||
return
|
||||
|
||||
# 新 Runner 健康且已完成组件注册,现在清理旧的幽灵组件
|
||||
# 只移除不再存在于新注册表中的旧插件组件
|
||||
for old_pid in list(old_registered_plugins.keys()):
|
||||
if old_pid not in self._registered_plugins:
|
||||
self._component_registry.remove_components_by_plugin(old_pid)
|
||||
self._policy.revoke_plugin(old_pid)
|
||||
self._runner_generation = self._rpc_server.runner_generation
|
||||
self._registered_plugins = dict(self._staged_registered_plugins)
|
||||
self._staged_registered_plugins.clear()
|
||||
self._rebuild_runtime_state()
|
||||
|
||||
# 关停旧 Runner
|
||||
if old_process and old_process.returncode is None:
|
||||
@@ -380,13 +393,22 @@ class PluginSupervisor:
|
||||
except Exception as e:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e))
|
||||
|
||||
if envelope.generation != self._rpc_server.runner_generation:
|
||||
active_generation = self._rpc_server.runner_generation
|
||||
staged_generation = self._rpc_server.staged_generation
|
||||
if envelope.generation not in {active_generation, staged_generation}:
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_GENERATION_MISMATCH.value,
|
||||
f"组件注册 generation 过期: {envelope.generation} != {self._rpc_server.runner_generation}",
|
||||
f"组件注册 generation 过期: {envelope.generation} 不在已知代际中",
|
||||
)
|
||||
|
||||
# 记录注册信息
|
||||
if envelope.generation == staged_generation and staged_generation != 0:
|
||||
self._staged_registered_plugins[reg.plugin_id] = reg
|
||||
logger.info(
|
||||
f"插件 {reg.plugin_id} v{reg.plugin_version} staged 注册成功,"
|
||||
f"组件数: {len(reg.components)}, 能力需求: {reg.capabilities_required}"
|
||||
)
|
||||
return envelope.make_response(payload={"accepted": True, "staged": True})
|
||||
|
||||
self._registered_plugins[reg.plugin_id] = reg
|
||||
|
||||
# 在策略引擎中注册插件
|
||||
@@ -396,7 +418,8 @@ class PluginSupervisor:
|
||||
capabilities=reg.capabilities_required or [],
|
||||
)
|
||||
|
||||
# 在 ComponentRegistry 中注册组件
|
||||
# 同 generation 下重新注册时,以本次声明为准,避免残留幽灵组件
|
||||
self._component_registry.remove_components_by_plugin(reg.plugin_id)
|
||||
self._component_registry.register_plugin_components(
|
||||
plugin_id=reg.plugin_id,
|
||||
components=[c.model_dump() for c in reg.components],
|
||||
@@ -518,10 +541,17 @@ class PluginSupervisor:
|
||||
except Exception as e:
|
||||
logger.error(f"健康检查异常: {e}")
|
||||
|
||||
async def _wait_for_runner_generation(self, expected_generation: int, timeout_sec: float) -> None:
|
||||
async def _wait_for_runner_generation(
|
||||
self,
|
||||
expected_generation: int,
|
||||
timeout_sec: float,
|
||||
allow_staged: bool = False,
|
||||
) -> None:
|
||||
"""等待指定代际的 Runner 完成连接。"""
|
||||
deadline = asyncio.get_running_loop().time() + timeout_sec
|
||||
while asyncio.get_running_loop().time() < deadline:
|
||||
if allow_staged and self._rpc_server.has_generation(expected_generation):
|
||||
return
|
||||
if self._rpc_server.is_connected and self._rpc_server.runner_generation >= expected_generation:
|
||||
self._runner_generation = self._rpc_server.runner_generation
|
||||
return
|
||||
@@ -533,6 +563,7 @@ class PluginSupervisor:
|
||||
self._component_registry.clear()
|
||||
self._policy.clear()
|
||||
self._registered_plugins.clear()
|
||||
self._staged_registered_plugins.clear()
|
||||
|
||||
def _rebuild_runtime_state(self) -> None:
|
||||
"""根据已记录的插件注册信息重建运行时状态。"""
|
||||
|
||||
@@ -430,8 +430,10 @@ class PluginRuntimeManager:
|
||||
"""
|
||||
from src.services import send_service as send_api
|
||||
|
||||
message_type: str = args.get("message_type", "")
|
||||
content = args.get("content", "")
|
||||
message_type: str = args.get("message_type", "") or args.get("custom_type", "")
|
||||
content = args.get("content")
|
||||
if content is None:
|
||||
content = args.get("data", "")
|
||||
stream_id: str = args.get("stream_id", "")
|
||||
if not message_type or not stream_id:
|
||||
return {"success": False, "error": "缺少必要参数 message_type 或 stream_id"}
|
||||
|
||||
@@ -76,6 +76,7 @@ class RPCClient:
|
||||
# 运行状态
|
||||
self._running = False
|
||||
self._recv_task: Optional[asyncio.Task] = None
|
||||
self._background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
@property
|
||||
def generation(self) -> int:
|
||||
@@ -144,6 +145,13 @@ class RPCClient:
|
||||
await self._recv_task
|
||||
self._recv_task = None
|
||||
|
||||
for task in list(self._background_tasks):
|
||||
task.cancel()
|
||||
if self._background_tasks:
|
||||
with contextlib.suppress(Exception):
|
||||
await asyncio.gather(*self._background_tasks, return_exceptions=True)
|
||||
self._background_tasks.clear()
|
||||
|
||||
# 取消所有 pending 请求
|
||||
for future in self._pending_requests.values():
|
||||
if not future.done():
|
||||
@@ -248,9 +256,9 @@ class RPCClient:
|
||||
if envelope.is_response():
|
||||
self._handle_response(envelope)
|
||||
elif envelope.is_request():
|
||||
asyncio.create_task(self._handle_request(envelope))
|
||||
self._track_background_task(asyncio.create_task(self._handle_request(envelope)))
|
||||
elif envelope.is_event():
|
||||
asyncio.create_task(self._handle_event(envelope))
|
||||
self._track_background_task(asyncio.create_task(self._handle_event(envelope)))
|
||||
|
||||
def _handle_response(self, envelope: Envelope) -> None:
|
||||
"""处理来自 Host 的响应"""
|
||||
@@ -290,3 +298,8 @@ class RPCClient:
|
||||
await handler(envelope)
|
||||
except Exception as e:
|
||||
logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True)
|
||||
|
||||
def _track_background_task(self, task: asyncio.Task) -> None:
|
||||
"""保持后台任务强引用,直到其完成或被取消。"""
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
Reference in New Issue
Block a user