feat: 实现插件注册的分阶段接入与切换机制,优化 RPC 连接管理
This commit is contained in:
@@ -677,6 +677,28 @@ class TestComponentRegistry:
|
|||||||
assert removed == 2
|
assert removed == 2
|
||||||
assert reg.get_stats()["total"] == 1
|
assert reg.get_stats()["total"] == 1
|
||||||
|
|
||||||
|
def test_reregister_same_plugin_replaces_component_set(self):
    """Check that only the newly declared components remain after re-registration.

    Registers ``a1`` and ``a2`` for plugin ``p1``, removes every component
    of ``p1``, registers ``a1`` alone, and asserts that ``p1.a1`` resolves
    while ``p1.a2`` no longer does.
    """
    from src.plugin_runtime.host.component_registry import ComponentRegistry

    reg = ComponentRegistry()
    reg.register_plugin_components(
        "p1",
        [
            {"name": "a1", "component_type": "action", "metadata": {}},
            {"name": "a2", "component_type": "action", "metadata": {}},
        ],
    )
    # NOTE(review): the old set is removed explicitly before re-registering,
    # so this does not exercise implicit replacement by
    # register_plugin_components itself -- confirm that is the intent.
    reg.remove_components_by_plugin("p1")
    reg.register_plugin_components(
        "p1",
        [
            {"name": "a1", "component_type": "action", "metadata": {}},
        ],
    )

    assert reg.get_component("p1.a1") is not None
    assert reg.get_component("p1.a2") is None
|
||||||
|
|
||||||
def test_event_handlers_sorted_by_weight(self):
|
def test_event_handlers_sorted_by_weight(self):
|
||||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||||
|
|
||||||
@@ -1257,7 +1279,7 @@ class TestRPCServer:
|
|||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
try:
|
try:
|
||||||
future = loop.create_future()
|
future = loop.create_future()
|
||||||
server._pending_requests[1] = future
|
server._pending_requests[1] = (future, 2)
|
||||||
|
|
||||||
stale_response = Envelope(
|
stale_response = Envelope(
|
||||||
request_id=1,
|
request_id=1,
|
||||||
@@ -1274,23 +1296,53 @@ class TestRPCServer:
|
|||||||
loop.close()
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
|
class TestRPCClient:
    """Lifecycle tests for background tasks owned by the runner-side RPCClient."""

    @pytest.mark.asyncio
    async def test_background_tasks_retained_and_cancelled_on_disconnect(self):
        """A tracked task stays registered while pending and is cancelled by disconnect()."""
        from src.plugin_runtime.runner.rpc_client import RPCClient

        client = RPCClient(host_address="dummy", session_token="token")
        # Never set: the task can only finish by being cancelled.
        release = asyncio.Event()

        async def pending_task():
            await release.wait()

        task = asyncio.create_task(pending_task())
        client._track_background_task(task)

        assert task in client._background_tasks

        # Yield once so the task starts running; it must still be tracked.
        await asyncio.sleep(0)
        assert task in client._background_tasks

        await client.disconnect()

        assert task.cancelled() is True
        assert not client._background_tasks
|
||||||
|
|
||||||
|
|
||||||
class TestSupervisor:
|
class TestSupervisor:
|
||||||
"""Supervisor 生命周期边界测试"""
|
"""Supervisor 生命周期边界测试"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _build_register_payload(plugin_id: str = "plugin_a"):
|
def _build_register_payload(plugin_id: str = "plugin_a", component_names=None):
|
||||||
from src.plugin_runtime.protocol.envelope import ComponentDeclaration, RegisterComponentsPayload
|
from src.plugin_runtime.protocol.envelope import ComponentDeclaration, RegisterComponentsPayload
|
||||||
|
|
||||||
|
component_names = component_names or ["handler"]
|
||||||
|
|
||||||
return RegisterComponentsPayload(
|
return RegisterComponentsPayload(
|
||||||
plugin_id=plugin_id,
|
plugin_id=plugin_id,
|
||||||
plugin_version="1.0.0",
|
plugin_version="1.0.0",
|
||||||
components=[
|
components=[
|
||||||
ComponentDeclaration(
|
ComponentDeclaration(
|
||||||
name="handler",
|
name=name,
|
||||||
component_type="event_handler",
|
component_type="event_handler",
|
||||||
plugin_id=plugin_id,
|
plugin_id=plugin_id,
|
||||||
metadata={"event_type": "on_message"},
|
metadata={"event_type": "on_message"},
|
||||||
)
|
)
|
||||||
|
for name in component_names
|
||||||
],
|
],
|
||||||
capabilities_required=["send.text"],
|
capabilities_required=["send.text"],
|
||||||
)
|
)
|
||||||
@@ -1331,8 +1383,11 @@ class TestSupervisor:
|
|||||||
class FakeRPCServer:
|
class FakeRPCServer:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.runner_generation = 1
|
self.runner_generation = 1
|
||||||
|
self.staged_generation = 0
|
||||||
self.is_connected = True
|
self.is_connected = True
|
||||||
self.session_token = "fake-token"
|
self.session_token = "fake-token"
|
||||||
|
self.committed = False
|
||||||
|
self.staging_started = False
|
||||||
|
|
||||||
def reset_session_token(self):
|
def reset_session_token(self):
|
||||||
self.session_token = "new-fake-token"
|
self.session_token = "new-fake-token"
|
||||||
@@ -1341,8 +1396,23 @@ class TestSupervisor:
|
|||||||
def restore_session_token(self, token):
|
def restore_session_token(self, token):
|
||||||
self.session_token = token
|
self.session_token = token
|
||||||
|
|
||||||
async def send_request(self, method, timeout_ms=5000, **kwargs):
|
def begin_staged_takeover(self):
|
||||||
assert self.runner_generation == 2
|
self.staging_started = True
|
||||||
|
self.staged_generation = 2
|
||||||
|
|
||||||
|
async def commit_staged_takeover(self):
|
||||||
|
self.runner_generation = self.staged_generation
|
||||||
|
self.staged_generation = 0
|
||||||
|
self.committed = True
|
||||||
|
|
||||||
|
async def rollback_staged_takeover(self):
|
||||||
|
self.staged_generation = 0
|
||||||
|
|
||||||
|
def has_generation(self, generation):
|
||||||
|
return generation in {self.runner_generation, self.staged_generation}
|
||||||
|
|
||||||
|
async def send_request(self, method, timeout_ms=5000, target_generation=None, **kwargs):
|
||||||
|
assert target_generation == 2
|
||||||
return SimpleNamespace(payload=HealthPayload(healthy=True).model_dump())
|
return SimpleNamespace(payload=HealthPayload(healthy=True).model_dump())
|
||||||
|
|
||||||
supervisor._rpc_server = FakeRPCServer()
|
supervisor._rpc_server = FakeRPCServer()
|
||||||
@@ -1350,18 +1420,14 @@ class TestSupervisor:
|
|||||||
|
|
||||||
async def fake_spawn_runner():
|
async def fake_spawn_runner():
|
||||||
supervisor._runner_process = new_process
|
supervisor._runner_process = new_process
|
||||||
|
supervisor._staged_registered_plugins["plugin_a"] = self._build_register_payload("plugin_a")
|
||||||
async def advance_generation():
|
|
||||||
await asyncio.sleep(0.01)
|
|
||||||
supervisor._rpc_server.runner_generation = 2
|
|
||||||
|
|
||||||
asyncio.create_task(advance_generation())
|
|
||||||
|
|
||||||
monkeypatch.setattr(supervisor, "_spawn_runner", fake_spawn_runner)
|
monkeypatch.setattr(supervisor, "_spawn_runner", fake_spawn_runner)
|
||||||
|
|
||||||
await supervisor.reload_plugins("test")
|
await supervisor.reload_plugins("test")
|
||||||
|
|
||||||
assert supervisor._runner_process is new_process
|
assert supervisor._runner_process is new_process
|
||||||
|
assert supervisor._rpc_server.committed is True
|
||||||
assert old_process.terminated is True
|
assert old_process.terminated is True
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -1380,6 +1446,69 @@ class TestSupervisor:
|
|||||||
class FakeRPCServer:
|
class FakeRPCServer:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.runner_generation = 1
|
self.runner_generation = 1
|
||||||
|
self.staged_generation = 0
|
||||||
|
self.is_connected = True
|
||||||
|
self.session_token = "fake-token"
|
||||||
|
self.rolled_back = False
|
||||||
|
|
||||||
|
def reset_session_token(self):
|
||||||
|
self.session_token = "new-fake-token"
|
||||||
|
return self.session_token
|
||||||
|
|
||||||
|
def restore_session_token(self, token):
|
||||||
|
self.session_token = token
|
||||||
|
|
||||||
|
def begin_staged_takeover(self):
|
||||||
|
self.staged_generation = 2
|
||||||
|
|
||||||
|
async def commit_staged_takeover(self):
|
||||||
|
self.runner_generation = self.staged_generation
|
||||||
|
self.staged_generation = 0
|
||||||
|
|
||||||
|
async def rollback_staged_takeover(self):
|
||||||
|
self.rolled_back = True
|
||||||
|
self.staged_generation = 0
|
||||||
|
|
||||||
|
def has_generation(self, generation):
|
||||||
|
return generation in {self.runner_generation, self.staged_generation}
|
||||||
|
|
||||||
|
async def send_request(self, method, timeout_ms=5000, target_generation=None, **kwargs):
|
||||||
|
raise RuntimeError("new runner unhealthy")
|
||||||
|
|
||||||
|
supervisor._rpc_server = FakeRPCServer()
|
||||||
|
|
||||||
|
async def fake_spawn_runner():
|
||||||
|
supervisor._runner_process = new_process
|
||||||
|
supervisor._staged_registered_plugins["plugin_a"] = self._build_register_payload("plugin_a")
|
||||||
|
|
||||||
|
monkeypatch.setattr(supervisor, "_spawn_runner", fake_spawn_runner)
|
||||||
|
|
||||||
|
await supervisor.reload_plugins("test")
|
||||||
|
|
||||||
|
assert supervisor._runner_process is old_process
|
||||||
|
assert supervisor._rpc_server.rolled_back is True
|
||||||
|
assert old_reg.plugin_id in supervisor._registered_plugins
|
||||||
|
assert supervisor.component_registry.get_component("plugin_a.handler") is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reload_rebuilds_exact_component_set(self, monkeypatch):
|
||||||
|
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
||||||
|
from src.plugin_runtime.protocol.envelope import HealthPayload
|
||||||
|
|
||||||
|
supervisor = PluginSupervisor(plugin_dirs=[])
|
||||||
|
old_process = self._make_process(1)
|
||||||
|
new_process = self._make_process(2)
|
||||||
|
old_reg = self._build_register_payload("plugin_a", component_names=["handler", "obsolete"])
|
||||||
|
new_reg = self._build_register_payload("plugin_a", component_names=["handler"])
|
||||||
|
|
||||||
|
supervisor._runner_process = old_process
|
||||||
|
supervisor._registered_plugins[old_reg.plugin_id] = old_reg
|
||||||
|
supervisor._rebuild_runtime_state()
|
||||||
|
|
||||||
|
class FakeRPCServer:
|
||||||
|
def __init__(self):
|
||||||
|
self.runner_generation = 1
|
||||||
|
self.staged_generation = 0
|
||||||
self.is_connected = True
|
self.is_connected = True
|
||||||
self.session_token = "fake-token"
|
self.session_token = "fake-token"
|
||||||
|
|
||||||
@@ -1390,22 +1519,34 @@ class TestSupervisor:
|
|||||||
def restore_session_token(self, token):
|
def restore_session_token(self, token):
|
||||||
self.session_token = token
|
self.session_token = token
|
||||||
|
|
||||||
async def send_request(self, method, timeout_ms=5000, **kwargs):
|
def begin_staged_takeover(self):
|
||||||
raise RuntimeError("new runner unhealthy")
|
self.staged_generation = 2
|
||||||
|
|
||||||
|
async def commit_staged_takeover(self):
|
||||||
|
self.runner_generation = self.staged_generation
|
||||||
|
self.staged_generation = 0
|
||||||
|
|
||||||
|
async def rollback_staged_takeover(self):
|
||||||
|
self.staged_generation = 0
|
||||||
|
|
||||||
|
def has_generation(self, generation):
|
||||||
|
return generation in {self.runner_generation, self.staged_generation}
|
||||||
|
|
||||||
|
async def send_request(self, method, timeout_ms=5000, target_generation=None, **kwargs):
|
||||||
|
return SimpleNamespace(payload=HealthPayload(healthy=True).model_dump())
|
||||||
|
|
||||||
supervisor._rpc_server = FakeRPCServer()
|
supervisor._rpc_server = FakeRPCServer()
|
||||||
|
|
||||||
async def fake_spawn_runner():
|
async def fake_spawn_runner():
|
||||||
supervisor._runner_process = new_process
|
supervisor._runner_process = new_process
|
||||||
supervisor._rpc_server.runner_generation = 2
|
supervisor._staged_registered_plugins[new_reg.plugin_id] = new_reg
|
||||||
|
|
||||||
monkeypatch.setattr(supervisor, "_spawn_runner", fake_spawn_runner)
|
monkeypatch.setattr(supervisor, "_spawn_runner", fake_spawn_runner)
|
||||||
|
|
||||||
await supervisor.reload_plugins("test")
|
await supervisor.reload_plugins("test")
|
||||||
|
|
||||||
assert supervisor._runner_process is old_process
|
|
||||||
assert old_reg.plugin_id in supervisor._registered_plugins
|
|
||||||
assert supervisor.component_registry.get_component("plugin_a.handler") is not None
|
assert supervisor.component_registry.get_component("plugin_a.handler") is not None
|
||||||
|
assert supervisor.component_registry.get_component("plugin_a.obsolete") is None
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_attach_stderr_drain_drains_stream(self):
|
async def test_attach_stderr_drain_drains_stream(self):
|
||||||
|
|||||||
@@ -7,7 +7,7 @@
|
|||||||
4. 请求-响应关联与超时管理
|
4. 请求-响应关联与超时管理
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import secrets
|
import secrets
|
||||||
@@ -55,12 +55,16 @@ class RPCServer:
|
|||||||
self._connection: Optional[Connection] = None # 当前活跃的 Runner 连接
|
self._connection: Optional[Connection] = None # 当前活跃的 Runner 连接
|
||||||
self._runner_id: Optional[str] = None
|
self._runner_id: Optional[str] = None
|
||||||
self._runner_generation: int = 0
|
self._runner_generation: int = 0
|
||||||
|
self._staged_connection: Optional[Connection] = None
|
||||||
|
self._staged_runner_id: Optional[str] = None
|
||||||
|
self._staged_runner_generation: int = 0
|
||||||
|
self._staging_takeover: bool = False
|
||||||
|
|
||||||
# 方法处理器注册表
|
# 方法处理器注册表
|
||||||
self._method_handlers: Dict[str, MethodHandler] = {}
|
self._method_handlers: Dict[str, MethodHandler] = {}
|
||||||
|
|
||||||
# 等待响应的 pending 请求: request_id -> Future
|
# 等待响应的 pending 请求: request_id -> (Future, target_generation)
|
||||||
self._pending_requests: Dict[int, asyncio.Future] = {}
|
self._pending_requests: Dict[int, Tuple[asyncio.Future, int]] = {}
|
||||||
|
|
||||||
# 发送队列(背压控制)
|
# 发送队列(背压控制)
|
||||||
self._send_queue: Optional[asyncio.Queue[bytes]] = None
|
self._send_queue: Optional[asyncio.Queue[bytes]] = None
|
||||||
@@ -86,10 +90,72 @@ class RPCServer:
|
|||||||
def runner_generation(self) -> int:
|
def runner_generation(self) -> int:
|
||||||
return self._runner_generation
|
return self._runner_generation
|
||||||
|
|
||||||
|
@property
def staged_generation(self) -> int:
    """Return the generation recorded for the staged runner.

    The value is whatever ``_staged_runner_generation`` currently holds;
    the commit and rollback paths reset it to ``0``.
    """
    return self._staged_runner_generation
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_connected(self) -> bool:
|
def is_connected(self) -> bool:
|
||||||
return self._connection is not None and not self._connection.is_closed
|
return self._connection is not None and not self._connection.is_closed
|
||||||
|
|
||||||
|
def has_generation(self, generation: int) -> bool:
    """Report whether ``generation`` belongs to the active or the staged runner.

    Args:
        generation: Generation number to look up.

    Returns:
        ``True`` when ``generation`` equals the active runner generation
        (the active connection state is not inspected here), or when it
        equals the staged generation and a staged connection exists and
        is not closed. ``False`` otherwise.
    """
    return generation == self._runner_generation or (
        self._staged_connection is not None
        and not self._staged_connection.is_closed
        and generation == self._staged_runner_generation
    )
|
||||||
|
|
||||||
|
def begin_staged_takeover(self) -> None:
    """Allow the next runner to attach as a staged connection.

    Only raises the ``_staging_takeover`` flag; the staged runner becomes
    the active connection later, once the supervisor has validated it.
    """
    self._staging_takeover = True
|
||||||
|
|
||||||
|
async def commit_staged_takeover(self) -> None:
    """Promote the staged runner to active and close the previous active connection.

    Raises:
        RPCError: With ``E_PLUGIN_CRASHED`` when there is no staged
            connection or the staged connection is already closed.
    """
    if self._staged_connection is None or self._staged_connection.is_closed:
        raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "没有可提交的新 Runner 连接")

    # Remember the outgoing runner before the slots are overwritten.
    old_connection = self._connection
    old_generation = self._runner_generation

    # Swap the staged runner into the active slots.
    self._connection = self._staged_connection
    self._runner_id = self._staged_runner_id
    self._runner_generation = self._staged_runner_generation

    # Clear the staging slots and leave staging mode.
    self._staged_connection = None
    self._staged_runner_id = None
    self._staged_runner_generation = 0
    self._staging_takeover = False

    # Requests addressed to the old generation are failed right away
    # instead of being left to run into their timeout.
    stale_count = self._fail_pending_requests(
        ErrorCode.E_PLUGIN_CRASHED,
        "Runner 连接已被新 generation 接管",
        generation=old_generation,
    )
    if stale_count:
        logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求")

    # Close the previous connection only if it is a different, still-open one.
    if old_connection and old_connection is not self._connection and not old_connection.is_closed:
        await old_connection.close()
|
||||||
|
|
||||||
|
async def rollback_staged_takeover(self) -> None:
    """Discard the staged runner and keep the current active connection.

    Clears the staging slots, fails every pending request that targeted
    the staged generation with ``E_PLUGIN_CRASHED``, and closes the staged
    connection if one exists and is still open.
    """
    # Capture the staged state before it is reset below.
    staged_connection = self._staged_connection
    staged_generation = self._staged_runner_generation

    self._staged_connection = None
    self._staged_runner_id = None
    self._staged_runner_generation = 0
    self._staging_takeover = False

    self._fail_pending_requests(
        ErrorCode.E_PLUGIN_CRASHED,
        "新 Runner 预热失败,已回滚",
        generation=staged_generation,
    )

    if staged_connection and not staged_connection.is_closed:
        await staged_connection.close()
|
||||||
|
|
||||||
def register_method(self, method: str, handler: MethodHandler) -> None:
|
def register_method(self, method: str, handler: MethodHandler) -> None:
|
||||||
"""注册 RPC 方法处理器"""
|
"""注册 RPC 方法处理器"""
|
||||||
self._method_handlers[method] = handler
|
self._method_handlers[method] = handler
|
||||||
@@ -106,7 +172,7 @@ class RPCServer:
|
|||||||
self._running = False
|
self._running = False
|
||||||
|
|
||||||
# 取消所有 pending 请求
|
# 取消所有 pending 请求
|
||||||
for future in self._pending_requests.values():
|
for future, _generation in self._pending_requests.values():
|
||||||
if not future.done():
|
if not future.done():
|
||||||
future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
|
future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
|
||||||
self._pending_requests.clear()
|
self._pending_requests.clear()
|
||||||
@@ -121,6 +187,10 @@ class RPCServer:
|
|||||||
await self._connection.close()
|
await self._connection.close()
|
||||||
self._connection = None
|
self._connection = None
|
||||||
|
|
||||||
|
if self._staged_connection:
|
||||||
|
await self._staged_connection.close()
|
||||||
|
self._staged_connection = None
|
||||||
|
|
||||||
await self._transport.stop()
|
await self._transport.stop()
|
||||||
logger.info("RPC Server 已停止")
|
logger.info("RPC Server 已停止")
|
||||||
|
|
||||||
@@ -130,6 +200,7 @@ class RPCServer:
|
|||||||
plugin_id: str = "",
|
plugin_id: str = "",
|
||||||
payload: Optional[Dict[str, Any]] = None,
|
payload: Optional[Dict[str, Any]] = None,
|
||||||
timeout_ms: int = 30000,
|
timeout_ms: int = 30000,
|
||||||
|
target_generation: Optional[int] = None,
|
||||||
) -> Envelope:
|
) -> Envelope:
|
||||||
"""向 Runner 发送 RPC 请求并等待响应
|
"""向 Runner 发送 RPC 请求并等待响应
|
||||||
|
|
||||||
@@ -145,7 +216,9 @@ class RPCServer:
|
|||||||
Raises:
|
Raises:
|
||||||
RPCError: 调用失败
|
RPCError: 调用失败
|
||||||
"""
|
"""
|
||||||
if not self.is_connected:
|
generation = target_generation or self._runner_generation
|
||||||
|
conn = self._get_connection_for_generation(generation)
|
||||||
|
if conn is None or conn.is_closed:
|
||||||
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
|
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
|
||||||
|
|
||||||
request_id = self._id_gen.next()
|
request_id = self._id_gen.next()
|
||||||
@@ -154,7 +227,7 @@ class RPCServer:
|
|||||||
message_type=MessageType.REQUEST,
|
message_type=MessageType.REQUEST,
|
||||||
method=method,
|
method=method,
|
||||||
plugin_id=plugin_id,
|
plugin_id=plugin_id,
|
||||||
generation=self._runner_generation,
|
generation=generation,
|
||||||
timeout_ms=timeout_ms,
|
timeout_ms=timeout_ms,
|
||||||
payload=payload or {},
|
payload=payload or {},
|
||||||
)
|
)
|
||||||
@@ -166,12 +239,12 @@ class RPCServer:
|
|||||||
# 注册 pending future
|
# 注册 pending future
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
future: asyncio.Future[Envelope] = loop.create_future()
|
future: asyncio.Future[Envelope] = loop.create_future()
|
||||||
self._pending_requests[request_id] = future
|
self._pending_requests[request_id] = (future, generation)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 发送请求
|
# 发送请求
|
||||||
data = self._codec.encode_envelope(envelope)
|
data = self._codec.encode_envelope(envelope)
|
||||||
await self._connection.send_frame(data)
|
await conn.send_frame(data)
|
||||||
|
|
||||||
# 等待响应
|
# 等待响应
|
||||||
timeout_sec = timeout_ms / 1000.0
|
timeout_sec = timeout_ms / 1000.0
|
||||||
@@ -207,11 +280,13 @@ class RPCServer:
|
|||||||
async def _handle_connection(self, conn: Connection) -> None:
|
async def _handle_connection(self, conn: Connection) -> None:
|
||||||
"""处理新的 Runner 连接"""
|
"""处理新的 Runner 连接"""
|
||||||
logger.info("收到 Runner 连接")
|
logger.info("收到 Runner 连接")
|
||||||
|
previous_connection = self._connection
|
||||||
|
previous_generation = self._runner_generation
|
||||||
|
|
||||||
# 第一条消息必须是 runner.hello 握手
|
# 第一条消息必须是 runner.hello 握手
|
||||||
try:
|
try:
|
||||||
handshake_ok = await self._handle_handshake(conn)
|
role = await self._handle_handshake(conn)
|
||||||
if not handshake_ok:
|
if role is None:
|
||||||
await conn.close()
|
await conn.close()
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -219,41 +294,52 @@ class RPCServer:
|
|||||||
await conn.close()
|
await conn.close()
|
||||||
return
|
return
|
||||||
|
|
||||||
old_connection = self._connection
|
if role == "staged":
|
||||||
self._connection = conn
|
expected_generation = self._staged_runner_generation
|
||||||
logger.info(f"Runner 握手成功: runner_id={self._runner_id}, generation={self._runner_generation}")
|
logger.info(
|
||||||
|
f"Runner staged 握手成功: runner_id={self._staged_runner_id}, generation={self._staged_runner_generation}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._connection = conn
|
||||||
|
expected_generation = self._runner_generation
|
||||||
|
logger.info(f"Runner 握手成功: runner_id={self._runner_id}, generation={self._runner_generation}")
|
||||||
|
|
||||||
if old_connection and old_connection is not conn and not old_connection.is_closed:
|
if previous_connection and previous_connection is not conn and not previous_connection.is_closed:
|
||||||
logger.info("检测到新 Runner 已接管连接,关闭旧连接")
|
logger.info("检测到新 Runner 已接管连接,关闭旧连接")
|
||||||
# 新连接接管后,旧 Runner 的 in-flight 请求不会再收到响应
|
stale_count = self._fail_pending_requests(
|
||||||
# (过期 generation 响应会被 _handle_response 丢弃),
|
ErrorCode.E_PLUGIN_CRASHED,
|
||||||
# 在此处立即 fail-fast 所有 pending 请求,避免挂到超时
|
"Runner 连接已被新 generation 接管",
|
||||||
stale_count = 0
|
generation=previous_generation,
|
||||||
for _req_id, future in list(self._pending_requests.items()):
|
)
|
||||||
if not future.done():
|
if stale_count:
|
||||||
future.set_exception(RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已被新 generation 接管"))
|
logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求")
|
||||||
stale_count += 1
|
await previous_connection.close()
|
||||||
self._pending_requests.clear()
|
|
||||||
if stale_count:
|
|
||||||
logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求")
|
|
||||||
await old_connection.close()
|
|
||||||
|
|
||||||
# 启动消息接收循环
|
# 启动消息接收循环
|
||||||
try:
|
try:
|
||||||
await self._recv_loop(conn)
|
await self._recv_loop(conn, expected_generation=expected_generation)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"连接异常断开: {e}")
|
logger.error(f"连接异常断开: {e}")
|
||||||
finally:
|
finally:
|
||||||
if self._connection is conn:
|
if self._connection is conn:
|
||||||
self._connection = None
|
self._connection = None
|
||||||
self._runner_id = None
|
self._runner_id = None
|
||||||
# 连接断开时,立即让所有等待中的请求失败,避免挂起至超时
|
self._fail_pending_requests(
|
||||||
for _req_id, future in list(self._pending_requests.items()):
|
ErrorCode.E_PLUGIN_CRASHED,
|
||||||
if not future.done():
|
"Runner 连接已断开",
|
||||||
future.set_exception(RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已断开"))
|
generation=expected_generation,
|
||||||
self._pending_requests.clear()
|
)
|
||||||
|
elif self._staged_connection is conn:
|
||||||
|
self._staged_connection = None
|
||||||
|
self._staged_runner_id = None
|
||||||
|
self._staged_runner_generation = 0
|
||||||
|
self._fail_pending_requests(
|
||||||
|
ErrorCode.E_PLUGIN_CRASHED,
|
||||||
|
"Staged Runner 连接已断开",
|
||||||
|
generation=expected_generation,
|
||||||
|
)
|
||||||
|
|
||||||
async def _handle_handshake(self, conn: Connection) -> bool:
|
async def _handle_handshake(self, conn: Connection) -> Optional[str]:
|
||||||
"""处理 runner.hello 握手"""
|
"""处理 runner.hello 握手"""
|
||||||
# 接收握手请求
|
# 接收握手请求
|
||||||
data = await asyncio.wait_for(conn.recv_frame(), timeout=10.0)
|
data = await asyncio.wait_for(conn.recv_frame(), timeout=10.0)
|
||||||
@@ -266,7 +352,7 @@ class RPCServer:
|
|||||||
"首条消息必须为 runner.hello",
|
"首条消息必须为 runner.hello",
|
||||||
)
|
)
|
||||||
await conn.send_frame(self._codec.encode_envelope(error_resp))
|
await conn.send_frame(self._codec.encode_envelope(error_resp))
|
||||||
return False
|
return None
|
||||||
|
|
||||||
# 解析握手 payload
|
# 解析握手 payload
|
||||||
hello = HelloPayload.model_validate(envelope.payload)
|
hello = HelloPayload.model_validate(envelope.payload)
|
||||||
@@ -280,7 +366,7 @@ class RPCServer:
|
|||||||
)
|
)
|
||||||
resp = envelope.make_response(payload=resp_payload.model_dump())
|
resp = envelope.make_response(payload=resp_payload.model_dump())
|
||||||
await conn.send_frame(self._codec.encode_envelope(resp))
|
await conn.send_frame(self._codec.encode_envelope(resp))
|
||||||
return False
|
return None
|
||||||
|
|
||||||
# 校验 SDK 版本
|
# 校验 SDK 版本
|
||||||
if not self._check_sdk_version(hello.sdk_version):
|
if not self._check_sdk_version(hello.sdk_version):
|
||||||
@@ -291,23 +377,31 @@ class RPCServer:
|
|||||||
)
|
)
|
||||||
resp = envelope.make_response(payload=resp_payload.model_dump())
|
resp = envelope.make_response(payload=resp_payload.model_dump())
|
||||||
await conn.send_frame(self._codec.encode_envelope(resp))
|
await conn.send_frame(self._codec.encode_envelope(resp))
|
||||||
return False
|
return None
|
||||||
|
|
||||||
# 握手成功
|
# 握手成功
|
||||||
self._runner_id = hello.runner_id
|
role = "active"
|
||||||
self._runner_generation += 1
|
assigned_generation = self._runner_generation + 1
|
||||||
|
if self._staging_takeover and self.is_connected:
|
||||||
|
role = "staged"
|
||||||
|
self._staged_connection = conn
|
||||||
|
self._staged_runner_id = hello.runner_id
|
||||||
|
self._staged_runner_generation = assigned_generation
|
||||||
|
else:
|
||||||
|
self._runner_id = hello.runner_id
|
||||||
|
self._runner_generation = assigned_generation
|
||||||
|
|
||||||
resp_payload = HelloResponsePayload(
|
resp_payload = HelloResponsePayload(
|
||||||
accepted=True,
|
accepted=True,
|
||||||
host_version=PROTOCOL_VERSION,
|
host_version=PROTOCOL_VERSION,
|
||||||
assigned_generation=self._runner_generation,
|
assigned_generation=assigned_generation,
|
||||||
)
|
)
|
||||||
resp = envelope.make_response(payload=resp_payload.model_dump())
|
resp = envelope.make_response(payload=resp_payload.model_dump())
|
||||||
await conn.send_frame(self._codec.encode_envelope(resp))
|
await conn.send_frame(self._codec.encode_envelope(resp))
|
||||||
|
|
||||||
return True
|
return role
|
||||||
|
|
||||||
async def _recv_loop(self, conn: Connection) -> None:
|
async def _recv_loop(self, conn: Connection, expected_generation: int) -> None:
|
||||||
"""消息接收主循环"""
|
"""消息接收主循环"""
|
||||||
while self._running and not conn.is_closed:
|
while self._running and not conn.is_closed:
|
||||||
try:
|
try:
|
||||||
@@ -329,10 +423,10 @@ class RPCServer:
|
|||||||
if envelope.is_response():
|
if envelope.is_response():
|
||||||
self._handle_response(envelope)
|
self._handle_response(envelope)
|
||||||
elif envelope.is_request():
|
elif envelope.is_request():
|
||||||
if not self._is_current_generation(envelope):
|
if envelope.generation != expected_generation:
|
||||||
error_resp = envelope.make_error_response(
|
error_resp = envelope.make_error_response(
|
||||||
ErrorCode.E_GENERATION_MISMATCH.value,
|
ErrorCode.E_GENERATION_MISMATCH.value,
|
||||||
f"过期 generation: {envelope.generation} != {self._runner_generation}",
|
f"过期 generation: {envelope.generation} != {expected_generation}",
|
||||||
)
|
)
|
||||||
await conn.send_frame(self._codec.encode_envelope(error_resp))
|
await conn.send_frame(self._codec.encode_envelope(error_resp))
|
||||||
continue
|
continue
|
||||||
@@ -341,9 +435,9 @@ class RPCServer:
|
|||||||
self._tasks.append(task)
|
self._tasks.append(task)
|
||||||
task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None)
|
task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None)
|
||||||
elif envelope.is_event():
|
elif envelope.is_event():
|
||||||
if not self._is_current_generation(envelope):
|
if envelope.generation != expected_generation:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"忽略过期 generation 事件 {envelope.method}: {envelope.generation} != {self._runner_generation}"
|
f"忽略过期 generation 事件 {envelope.method}: {envelope.generation} != {expected_generation}"
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
task = asyncio.create_task(self._handle_event(envelope))
|
task = asyncio.create_task(self._handle_event(envelope))
|
||||||
@@ -352,22 +446,24 @@ class RPCServer:
|
|||||||
|
|
||||||
def _handle_response(self, envelope: Envelope) -> None:
|
def _handle_response(self, envelope: Envelope) -> None:
|
||||||
"""处理来自 Runner 的响应"""
|
"""处理来自 Runner 的响应"""
|
||||||
if not self._is_current_generation(envelope):
|
pending = self._pending_requests.get(envelope.request_id)
|
||||||
|
if pending is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
future, expected_generation = pending
|
||||||
|
if envelope.generation != expected_generation:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"忽略过期 generation 响应 {envelope.method}: {envelope.generation} != {self._runner_generation}"
|
f"忽略过期 generation 响应 {envelope.method}: {envelope.generation} != {expected_generation}"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
future = self._pending_requests.pop(envelope.request_id, None)
|
self._pending_requests.pop(envelope.request_id, None)
|
||||||
if future and not future.done():
|
if not future.done():
|
||||||
if envelope.error:
|
if envelope.error:
|
||||||
future.set_exception(RPCError.from_dict(envelope.error))
|
future.set_exception(RPCError.from_dict(envelope.error))
|
||||||
else:
|
else:
|
||||||
future.set_result(envelope)
|
future.set_result(envelope)
|
||||||
|
|
||||||
def _is_current_generation(self, envelope: Envelope) -> bool:
|
|
||||||
return envelope.generation == self._runner_generation
|
|
||||||
|
|
||||||
async def _handle_request(self, envelope: Envelope, conn: Connection) -> None:
|
async def _handle_request(self, envelope: Envelope, conn: Connection) -> None:
|
||||||
"""处理来自 Runner 的请求(通常是能力调用 cap.*)"""
|
"""处理来自 Runner 的请求(通常是能力调用 cap.*)"""
|
||||||
handler = self._method_handlers.get(envelope.method)
|
handler = self._method_handlers.get(envelope.method)
|
||||||
@@ -411,3 +507,26 @@ class RPCServer:
|
|||||||
return min_parts <= sdk_parts <= max_parts
|
return min_parts <= sdk_parts <= max_parts
|
||||||
except (ValueError, AttributeError):
|
except (ValueError, AttributeError):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def _get_connection_for_generation(self, generation: int) -> Optional[Connection]:
    """Return the connection slot that corresponds to ``generation``.

    The active generation is checked first, so when the active and staged
    generation numbers are equal the active connection wins.

    Args:
        generation: Generation number the caller wants to talk to.

    Returns:
        The active connection for the active generation, the staged
        connection for the staged generation (either may be ``None``),
        or ``None`` when ``generation`` matches neither.
    """
    if generation == self._runner_generation:
        return self._connection
    if generation == self._staged_runner_generation:
        return self._staged_connection
    return None
|
||||||
|
|
||||||
|
def _fail_pending_requests(
|
||||||
|
self,
|
||||||
|
error_code: ErrorCode,
|
||||||
|
message: str,
|
||||||
|
generation: Optional[int] = None,
|
||||||
|
) -> int:
|
||||||
|
stale_count = 0
|
||||||
|
for request_id, (future, request_generation) in list(self._pending_requests.items()):
|
||||||
|
if generation is not None and request_generation != generation:
|
||||||
|
continue
|
||||||
|
if not future.done():
|
||||||
|
future.set_exception(RPCError(error_code, message))
|
||||||
|
stale_count += 1
|
||||||
|
self._pending_requests.pop(request_id, None)
|
||||||
|
return stale_count
|
||||||
|
|||||||
@@ -135,6 +135,7 @@ class PluginSupervisor:
|
|||||||
|
|
||||||
# 已注册的插件组件信息
|
# 已注册的插件组件信息
|
||||||
self._registered_plugins: Dict[str, RegisterComponentsPayload] = {}
|
self._registered_plugins: Dict[str, RegisterComponentsPayload] = {}
|
||||||
|
self._staged_registered_plugins: Dict[str, RegisterComponentsPayload] = {}
|
||||||
|
|
||||||
# 后台任务
|
# 后台任务
|
||||||
self._health_task: Optional[asyncio.Task] = None
|
self._health_task: Optional[asyncio.Task] = None
|
||||||
@@ -319,6 +320,10 @@ class PluginSupervisor:
|
|||||||
old_session_token = self._rpc_server.session_token
|
old_session_token = self._rpc_server.session_token
|
||||||
expected_generation = self._rpc_server.runner_generation + 1
|
expected_generation = self._rpc_server.runner_generation + 1
|
||||||
|
|
||||||
|
# 允许新 Runner 以 staged 方式接入,验证通过后再切换活跃连接
|
||||||
|
self._rpc_server.begin_staged_takeover()
|
||||||
|
self._staged_registered_plugins.clear()
|
||||||
|
|
||||||
# 重新生成 session token,防止被终止的旧 Runner 重连
|
# 重新生成 session token,防止被终止的旧 Runner 重连
|
||||||
self._rpc_server.reset_session_token()
|
self._rpc_server.reset_session_token()
|
||||||
|
|
||||||
@@ -330,27 +335,35 @@ class PluginSupervisor:
|
|||||||
# 拉起新 Runner
|
# 拉起新 Runner
|
||||||
try:
|
try:
|
||||||
await self._spawn_runner()
|
await self._spawn_runner()
|
||||||
await self._wait_for_runner_generation(expected_generation, timeout_sec=self._runner_spawn_timeout)
|
await self._wait_for_runner_generation(
|
||||||
resp = await self._rpc_server.send_request("plugin.health", timeout_ms=5000)
|
expected_generation,
|
||||||
|
timeout_sec=self._runner_spawn_timeout,
|
||||||
|
allow_staged=True,
|
||||||
|
)
|
||||||
|
resp = await self._rpc_server.send_request(
|
||||||
|
"plugin.health",
|
||||||
|
timeout_ms=5000,
|
||||||
|
target_generation=expected_generation,
|
||||||
|
)
|
||||||
health = HealthPayload.model_validate(resp.payload)
|
health = HealthPayload.model_validate(resp.payload)
|
||||||
if not health.healthy:
|
if not health.healthy:
|
||||||
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "新 Runner 健康检查失败")
|
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "新 Runner 健康检查失败")
|
||||||
|
await self._rpc_server.commit_staged_takeover()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"新 Runner 健康检查失败: {e},回滚")
|
logger.error(f"新 Runner 健康检查失败: {e},回滚")
|
||||||
await self._terminate_process(self._runner_process, old_process)
|
await self._terminate_process(self._runner_process, old_process)
|
||||||
|
await self._rpc_server.rollback_staged_takeover()
|
||||||
self._runner_process = old_process
|
self._runner_process = old_process
|
||||||
# 恢复旧 session token,使旧 Runner 的连接仍可正常工作
|
|
||||||
self._rpc_server.restore_session_token(old_session_token)
|
self._rpc_server.restore_session_token(old_session_token)
|
||||||
|
self._staged_registered_plugins.clear()
|
||||||
self._registered_plugins = dict(old_registered_plugins)
|
self._registered_plugins = dict(old_registered_plugins)
|
||||||
self._rebuild_runtime_state()
|
self._rebuild_runtime_state()
|
||||||
return
|
return
|
||||||
|
|
||||||
# 新 Runner 健康且已完成组件注册,现在清理旧的幽灵组件
|
self._runner_generation = self._rpc_server.runner_generation
|
||||||
# 只移除不再存在于新注册表中的旧插件组件
|
self._registered_plugins = dict(self._staged_registered_plugins)
|
||||||
for old_pid in list(old_registered_plugins.keys()):
|
self._staged_registered_plugins.clear()
|
||||||
if old_pid not in self._registered_plugins:
|
self._rebuild_runtime_state()
|
||||||
self._component_registry.remove_components_by_plugin(old_pid)
|
|
||||||
self._policy.revoke_plugin(old_pid)
|
|
||||||
|
|
||||||
# 关停旧 Runner
|
# 关停旧 Runner
|
||||||
if old_process and old_process.returncode is None:
|
if old_process and old_process.returncode is None:
|
||||||
@@ -380,13 +393,22 @@ class PluginSupervisor:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e))
|
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e))
|
||||||
|
|
||||||
if envelope.generation != self._rpc_server.runner_generation:
|
active_generation = self._rpc_server.runner_generation
|
||||||
|
staged_generation = self._rpc_server.staged_generation
|
||||||
|
if envelope.generation not in {active_generation, staged_generation}:
|
||||||
return envelope.make_error_response(
|
return envelope.make_error_response(
|
||||||
ErrorCode.E_GENERATION_MISMATCH.value,
|
ErrorCode.E_GENERATION_MISMATCH.value,
|
||||||
f"组件注册 generation 过期: {envelope.generation} != {self._rpc_server.runner_generation}",
|
f"组件注册 generation 过期: {envelope.generation} 不在已知代际中",
|
||||||
)
|
)
|
||||||
|
|
||||||
# 记录注册信息
|
if envelope.generation == staged_generation and staged_generation != 0:
|
||||||
|
self._staged_registered_plugins[reg.plugin_id] = reg
|
||||||
|
logger.info(
|
||||||
|
f"插件 {reg.plugin_id} v{reg.plugin_version} staged 注册成功,"
|
||||||
|
f"组件数: {len(reg.components)}, 能力需求: {reg.capabilities_required}"
|
||||||
|
)
|
||||||
|
return envelope.make_response(payload={"accepted": True, "staged": True})
|
||||||
|
|
||||||
self._registered_plugins[reg.plugin_id] = reg
|
self._registered_plugins[reg.plugin_id] = reg
|
||||||
|
|
||||||
# 在策略引擎中注册插件
|
# 在策略引擎中注册插件
|
||||||
@@ -396,7 +418,8 @@ class PluginSupervisor:
|
|||||||
capabilities=reg.capabilities_required or [],
|
capabilities=reg.capabilities_required or [],
|
||||||
)
|
)
|
||||||
|
|
||||||
# 在 ComponentRegistry 中注册组件
|
# 同 generation 下重新注册时,以本次声明为准,避免残留幽灵组件
|
||||||
|
self._component_registry.remove_components_by_plugin(reg.plugin_id)
|
||||||
self._component_registry.register_plugin_components(
|
self._component_registry.register_plugin_components(
|
||||||
plugin_id=reg.plugin_id,
|
plugin_id=reg.plugin_id,
|
||||||
components=[c.model_dump() for c in reg.components],
|
components=[c.model_dump() for c in reg.components],
|
||||||
@@ -518,10 +541,17 @@ class PluginSupervisor:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"健康检查异常: {e}")
|
logger.error(f"健康检查异常: {e}")
|
||||||
|
|
||||||
async def _wait_for_runner_generation(self, expected_generation: int, timeout_sec: float) -> None:
|
async def _wait_for_runner_generation(
|
||||||
|
self,
|
||||||
|
expected_generation: int,
|
||||||
|
timeout_sec: float,
|
||||||
|
allow_staged: bool = False,
|
||||||
|
) -> None:
|
||||||
"""等待指定代际的 Runner 完成连接。"""
|
"""等待指定代际的 Runner 完成连接。"""
|
||||||
deadline = asyncio.get_running_loop().time() + timeout_sec
|
deadline = asyncio.get_running_loop().time() + timeout_sec
|
||||||
while asyncio.get_running_loop().time() < deadline:
|
while asyncio.get_running_loop().time() < deadline:
|
||||||
|
if allow_staged and self._rpc_server.has_generation(expected_generation):
|
||||||
|
return
|
||||||
if self._rpc_server.is_connected and self._rpc_server.runner_generation >= expected_generation:
|
if self._rpc_server.is_connected and self._rpc_server.runner_generation >= expected_generation:
|
||||||
self._runner_generation = self._rpc_server.runner_generation
|
self._runner_generation = self._rpc_server.runner_generation
|
||||||
return
|
return
|
||||||
@@ -533,6 +563,7 @@ class PluginSupervisor:
|
|||||||
self._component_registry.clear()
|
self._component_registry.clear()
|
||||||
self._policy.clear()
|
self._policy.clear()
|
||||||
self._registered_plugins.clear()
|
self._registered_plugins.clear()
|
||||||
|
self._staged_registered_plugins.clear()
|
||||||
|
|
||||||
def _rebuild_runtime_state(self) -> None:
|
def _rebuild_runtime_state(self) -> None:
|
||||||
"""根据已记录的插件注册信息重建运行时状态。"""
|
"""根据已记录的插件注册信息重建运行时状态。"""
|
||||||
|
|||||||
@@ -430,8 +430,10 @@ class PluginRuntimeManager:
|
|||||||
"""
|
"""
|
||||||
from src.services import send_service as send_api
|
from src.services import send_service as send_api
|
||||||
|
|
||||||
message_type: str = args.get("message_type", "")
|
message_type: str = args.get("message_type", "") or args.get("custom_type", "")
|
||||||
content = args.get("content", "")
|
content = args.get("content")
|
||||||
|
if content is None:
|
||||||
|
content = args.get("data", "")
|
||||||
stream_id: str = args.get("stream_id", "")
|
stream_id: str = args.get("stream_id", "")
|
||||||
if not message_type or not stream_id:
|
if not message_type or not stream_id:
|
||||||
return {"success": False, "error": "缺少必要参数 message_type 或 stream_id"}
|
return {"success": False, "error": "缺少必要参数 message_type 或 stream_id"}
|
||||||
|
|||||||
@@ -76,6 +76,7 @@ class RPCClient:
|
|||||||
# 运行状态
|
# 运行状态
|
||||||
self._running = False
|
self._running = False
|
||||||
self._recv_task: Optional[asyncio.Task] = None
|
self._recv_task: Optional[asyncio.Task] = None
|
||||||
|
self._background_tasks: set[asyncio.Task] = set()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def generation(self) -> int:
|
def generation(self) -> int:
|
||||||
@@ -144,6 +145,13 @@ class RPCClient:
|
|||||||
await self._recv_task
|
await self._recv_task
|
||||||
self._recv_task = None
|
self._recv_task = None
|
||||||
|
|
||||||
|
for task in list(self._background_tasks):
|
||||||
|
task.cancel()
|
||||||
|
if self._background_tasks:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
await asyncio.gather(*self._background_tasks, return_exceptions=True)
|
||||||
|
self._background_tasks.clear()
|
||||||
|
|
||||||
# 取消所有 pending 请求
|
# 取消所有 pending 请求
|
||||||
for future in self._pending_requests.values():
|
for future in self._pending_requests.values():
|
||||||
if not future.done():
|
if not future.done():
|
||||||
@@ -248,9 +256,9 @@ class RPCClient:
|
|||||||
if envelope.is_response():
|
if envelope.is_response():
|
||||||
self._handle_response(envelope)
|
self._handle_response(envelope)
|
||||||
elif envelope.is_request():
|
elif envelope.is_request():
|
||||||
asyncio.create_task(self._handle_request(envelope))
|
self._track_background_task(asyncio.create_task(self._handle_request(envelope)))
|
||||||
elif envelope.is_event():
|
elif envelope.is_event():
|
||||||
asyncio.create_task(self._handle_event(envelope))
|
self._track_background_task(asyncio.create_task(self._handle_event(envelope)))
|
||||||
|
|
||||||
def _handle_response(self, envelope: Envelope) -> None:
|
def _handle_response(self, envelope: Envelope) -> None:
|
||||||
"""处理来自 Host 的响应"""
|
"""处理来自 Host 的响应"""
|
||||||
@@ -290,3 +298,8 @@ class RPCClient:
|
|||||||
await handler(envelope)
|
await handler(envelope)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True)
|
logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True)
|
||||||
|
|
||||||
|
def _track_background_task(self, task: asyncio.Task) -> None:
|
||||||
|
"""保持后台任务强引用,直到其完成或被取消。"""
|
||||||
|
self._background_tasks.add(task)
|
||||||
|
task.add_done_callback(self._background_tasks.discard)
|
||||||
|
|||||||
Reference in New Issue
Block a user