feat: 实现插件注册的分阶段接入与切换机制,优化 RPC 连接管理
This commit is contained in:
@@ -7,7 +7,7 @@
|
||||
4. 请求-响应关联与超时管理
|
||||
"""
|
||||
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import asyncio
|
||||
import secrets
|
||||
@@ -55,12 +55,16 @@ class RPCServer:
|
||||
self._connection: Optional[Connection] = None # 当前活跃的 Runner 连接
|
||||
self._runner_id: Optional[str] = None
|
||||
self._runner_generation: int = 0
|
||||
self._staged_connection: Optional[Connection] = None
|
||||
self._staged_runner_id: Optional[str] = None
|
||||
self._staged_runner_generation: int = 0
|
||||
self._staging_takeover: bool = False
|
||||
|
||||
# 方法处理器注册表
|
||||
self._method_handlers: Dict[str, MethodHandler] = {}
|
||||
|
||||
# 等待响应的 pending 请求: request_id -> Future
|
||||
self._pending_requests: Dict[int, asyncio.Future] = {}
|
||||
# 等待响应的 pending 请求: request_id -> (Future, target_generation)
|
||||
self._pending_requests: Dict[int, Tuple[asyncio.Future, int]] = {}
|
||||
|
||||
# 发送队列(背压控制)
|
||||
self._send_queue: Optional[asyncio.Queue[bytes]] = None
|
||||
@@ -86,10 +90,72 @@ class RPCServer:
|
||||
def runner_generation(self) -> int:
|
||||
return self._runner_generation
|
||||
|
||||
@property
|
||||
def staged_generation(self) -> int:
|
||||
return self._staged_runner_generation
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._connection is not None and not self._connection.is_closed
|
||||
|
||||
def has_generation(self, generation: int) -> bool:
|
||||
return generation == self._runner_generation or (
|
||||
self._staged_connection is not None
|
||||
and not self._staged_connection.is_closed
|
||||
and generation == self._staged_runner_generation
|
||||
)
|
||||
|
||||
def begin_staged_takeover(self) -> None:
|
||||
"""允许新 Runner 以 staged 方式接入,待 Supervisor 验证后再切换为活跃连接。"""
|
||||
self._staging_takeover = True
|
||||
|
||||
async def commit_staged_takeover(self) -> None:
|
||||
"""提交 staged Runner,原活跃连接在提交后被关闭。"""
|
||||
if self._staged_connection is None or self._staged_connection.is_closed:
|
||||
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "没有可提交的新 Runner 连接")
|
||||
|
||||
old_connection = self._connection
|
||||
old_generation = self._runner_generation
|
||||
|
||||
self._connection = self._staged_connection
|
||||
self._runner_id = self._staged_runner_id
|
||||
self._runner_generation = self._staged_runner_generation
|
||||
|
||||
self._staged_connection = None
|
||||
self._staged_runner_id = None
|
||||
self._staged_runner_generation = 0
|
||||
self._staging_takeover = False
|
||||
|
||||
stale_count = self._fail_pending_requests(
|
||||
ErrorCode.E_PLUGIN_CRASHED,
|
||||
"Runner 连接已被新 generation 接管",
|
||||
generation=old_generation,
|
||||
)
|
||||
if stale_count:
|
||||
logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求")
|
||||
|
||||
if old_connection and old_connection is not self._connection and not old_connection.is_closed:
|
||||
await old_connection.close()
|
||||
|
||||
async def rollback_staged_takeover(self) -> None:
|
||||
"""放弃 staged Runner,保留当前活跃连接。"""
|
||||
staged_connection = self._staged_connection
|
||||
staged_generation = self._staged_runner_generation
|
||||
|
||||
self._staged_connection = None
|
||||
self._staged_runner_id = None
|
||||
self._staged_runner_generation = 0
|
||||
self._staging_takeover = False
|
||||
|
||||
self._fail_pending_requests(
|
||||
ErrorCode.E_PLUGIN_CRASHED,
|
||||
"新 Runner 预热失败,已回滚",
|
||||
generation=staged_generation,
|
||||
)
|
||||
|
||||
if staged_connection and not staged_connection.is_closed:
|
||||
await staged_connection.close()
|
||||
|
||||
def register_method(self, method: str, handler: MethodHandler) -> None:
|
||||
"""注册 RPC 方法处理器"""
|
||||
self._method_handlers[method] = handler
|
||||
@@ -106,7 +172,7 @@ class RPCServer:
|
||||
self._running = False
|
||||
|
||||
# 取消所有 pending 请求
|
||||
for future in self._pending_requests.values():
|
||||
for future, _generation in self._pending_requests.values():
|
||||
if not future.done():
|
||||
future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
|
||||
self._pending_requests.clear()
|
||||
@@ -121,6 +187,10 @@ class RPCServer:
|
||||
await self._connection.close()
|
||||
self._connection = None
|
||||
|
||||
if self._staged_connection:
|
||||
await self._staged_connection.close()
|
||||
self._staged_connection = None
|
||||
|
||||
await self._transport.stop()
|
||||
logger.info("RPC Server 已停止")
|
||||
|
||||
@@ -130,6 +200,7 @@ class RPCServer:
|
||||
plugin_id: str = "",
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
timeout_ms: int = 30000,
|
||||
target_generation: Optional[int] = None,
|
||||
) -> Envelope:
|
||||
"""向 Runner 发送 RPC 请求并等待响应
|
||||
|
||||
@@ -145,7 +216,9 @@ class RPCServer:
|
||||
Raises:
|
||||
RPCError: 调用失败
|
||||
"""
|
||||
if not self.is_connected:
|
||||
generation = target_generation or self._runner_generation
|
||||
conn = self._get_connection_for_generation(generation)
|
||||
if conn is None or conn.is_closed:
|
||||
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
|
||||
|
||||
request_id = self._id_gen.next()
|
||||
@@ -154,7 +227,7 @@ class RPCServer:
|
||||
message_type=MessageType.REQUEST,
|
||||
method=method,
|
||||
plugin_id=plugin_id,
|
||||
generation=self._runner_generation,
|
||||
generation=generation,
|
||||
timeout_ms=timeout_ms,
|
||||
payload=payload or {},
|
||||
)
|
||||
@@ -166,12 +239,12 @@ class RPCServer:
|
||||
# 注册 pending future
|
||||
loop = asyncio.get_running_loop()
|
||||
future: asyncio.Future[Envelope] = loop.create_future()
|
||||
self._pending_requests[request_id] = future
|
||||
self._pending_requests[request_id] = (future, generation)
|
||||
|
||||
try:
|
||||
# 发送请求
|
||||
data = self._codec.encode_envelope(envelope)
|
||||
await self._connection.send_frame(data)
|
||||
await conn.send_frame(data)
|
||||
|
||||
# 等待响应
|
||||
timeout_sec = timeout_ms / 1000.0
|
||||
@@ -207,11 +280,13 @@ class RPCServer:
|
||||
async def _handle_connection(self, conn: Connection) -> None:
|
||||
"""处理新的 Runner 连接"""
|
||||
logger.info("收到 Runner 连接")
|
||||
previous_connection = self._connection
|
||||
previous_generation = self._runner_generation
|
||||
|
||||
# 第一条消息必须是 runner.hello 握手
|
||||
try:
|
||||
handshake_ok = await self._handle_handshake(conn)
|
||||
if not handshake_ok:
|
||||
role = await self._handle_handshake(conn)
|
||||
if role is None:
|
||||
await conn.close()
|
||||
return
|
||||
except Exception as e:
|
||||
@@ -219,41 +294,52 @@ class RPCServer:
|
||||
await conn.close()
|
||||
return
|
||||
|
||||
old_connection = self._connection
|
||||
self._connection = conn
|
||||
logger.info(f"Runner 握手成功: runner_id={self._runner_id}, generation={self._runner_generation}")
|
||||
if role == "staged":
|
||||
expected_generation = self._staged_runner_generation
|
||||
logger.info(
|
||||
f"Runner staged 握手成功: runner_id={self._staged_runner_id}, generation={self._staged_runner_generation}"
|
||||
)
|
||||
else:
|
||||
self._connection = conn
|
||||
expected_generation = self._runner_generation
|
||||
logger.info(f"Runner 握手成功: runner_id={self._runner_id}, generation={self._runner_generation}")
|
||||
|
||||
if old_connection and old_connection is not conn and not old_connection.is_closed:
|
||||
logger.info("检测到新 Runner 已接管连接,关闭旧连接")
|
||||
# 新连接接管后,旧 Runner 的 in-flight 请求不会再收到响应
|
||||
# (过期 generation 响应会被 _handle_response 丢弃),
|
||||
# 在此处立即 fail-fast 所有 pending 请求,避免挂到超时
|
||||
stale_count = 0
|
||||
for _req_id, future in list(self._pending_requests.items()):
|
||||
if not future.done():
|
||||
future.set_exception(RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已被新 generation 接管"))
|
||||
stale_count += 1
|
||||
self._pending_requests.clear()
|
||||
if stale_count:
|
||||
logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求")
|
||||
await old_connection.close()
|
||||
if previous_connection and previous_connection is not conn and not previous_connection.is_closed:
|
||||
logger.info("检测到新 Runner 已接管连接,关闭旧连接")
|
||||
stale_count = self._fail_pending_requests(
|
||||
ErrorCode.E_PLUGIN_CRASHED,
|
||||
"Runner 连接已被新 generation 接管",
|
||||
generation=previous_generation,
|
||||
)
|
||||
if stale_count:
|
||||
logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求")
|
||||
await previous_connection.close()
|
||||
|
||||
# 启动消息接收循环
|
||||
try:
|
||||
await self._recv_loop(conn)
|
||||
await self._recv_loop(conn, expected_generation=expected_generation)
|
||||
except Exception as e:
|
||||
logger.error(f"连接异常断开: {e}")
|
||||
finally:
|
||||
if self._connection is conn:
|
||||
self._connection = None
|
||||
self._runner_id = None
|
||||
# 连接断开时,立即让所有等待中的请求失败,避免挂起至超时
|
||||
for _req_id, future in list(self._pending_requests.items()):
|
||||
if not future.done():
|
||||
future.set_exception(RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已断开"))
|
||||
self._pending_requests.clear()
|
||||
self._fail_pending_requests(
|
||||
ErrorCode.E_PLUGIN_CRASHED,
|
||||
"Runner 连接已断开",
|
||||
generation=expected_generation,
|
||||
)
|
||||
elif self._staged_connection is conn:
|
||||
self._staged_connection = None
|
||||
self._staged_runner_id = None
|
||||
self._staged_runner_generation = 0
|
||||
self._fail_pending_requests(
|
||||
ErrorCode.E_PLUGIN_CRASHED,
|
||||
"Staged Runner 连接已断开",
|
||||
generation=expected_generation,
|
||||
)
|
||||
|
||||
async def _handle_handshake(self, conn: Connection) -> bool:
|
||||
async def _handle_handshake(self, conn: Connection) -> Optional[str]:
|
||||
"""处理 runner.hello 握手"""
|
||||
# 接收握手请求
|
||||
data = await asyncio.wait_for(conn.recv_frame(), timeout=10.0)
|
||||
@@ -266,7 +352,7 @@ class RPCServer:
|
||||
"首条消息必须为 runner.hello",
|
||||
)
|
||||
await conn.send_frame(self._codec.encode_envelope(error_resp))
|
||||
return False
|
||||
return None
|
||||
|
||||
# 解析握手 payload
|
||||
hello = HelloPayload.model_validate(envelope.payload)
|
||||
@@ -280,7 +366,7 @@ class RPCServer:
|
||||
)
|
||||
resp = envelope.make_response(payload=resp_payload.model_dump())
|
||||
await conn.send_frame(self._codec.encode_envelope(resp))
|
||||
return False
|
||||
return None
|
||||
|
||||
# 校验 SDK 版本
|
||||
if not self._check_sdk_version(hello.sdk_version):
|
||||
@@ -291,23 +377,31 @@ class RPCServer:
|
||||
)
|
||||
resp = envelope.make_response(payload=resp_payload.model_dump())
|
||||
await conn.send_frame(self._codec.encode_envelope(resp))
|
||||
return False
|
||||
return None
|
||||
|
||||
# 握手成功
|
||||
self._runner_id = hello.runner_id
|
||||
self._runner_generation += 1
|
||||
role = "active"
|
||||
assigned_generation = self._runner_generation + 1
|
||||
if self._staging_takeover and self.is_connected:
|
||||
role = "staged"
|
||||
self._staged_connection = conn
|
||||
self._staged_runner_id = hello.runner_id
|
||||
self._staged_runner_generation = assigned_generation
|
||||
else:
|
||||
self._runner_id = hello.runner_id
|
||||
self._runner_generation = assigned_generation
|
||||
|
||||
resp_payload = HelloResponsePayload(
|
||||
accepted=True,
|
||||
host_version=PROTOCOL_VERSION,
|
||||
assigned_generation=self._runner_generation,
|
||||
assigned_generation=assigned_generation,
|
||||
)
|
||||
resp = envelope.make_response(payload=resp_payload.model_dump())
|
||||
await conn.send_frame(self._codec.encode_envelope(resp))
|
||||
|
||||
return True
|
||||
return role
|
||||
|
||||
async def _recv_loop(self, conn: Connection) -> None:
|
||||
async def _recv_loop(self, conn: Connection, expected_generation: int) -> None:
|
||||
"""消息接收主循环"""
|
||||
while self._running and not conn.is_closed:
|
||||
try:
|
||||
@@ -329,10 +423,10 @@ class RPCServer:
|
||||
if envelope.is_response():
|
||||
self._handle_response(envelope)
|
||||
elif envelope.is_request():
|
||||
if not self._is_current_generation(envelope):
|
||||
if envelope.generation != expected_generation:
|
||||
error_resp = envelope.make_error_response(
|
||||
ErrorCode.E_GENERATION_MISMATCH.value,
|
||||
f"过期 generation: {envelope.generation} != {self._runner_generation}",
|
||||
f"过期 generation: {envelope.generation} != {expected_generation}",
|
||||
)
|
||||
await conn.send_frame(self._codec.encode_envelope(error_resp))
|
||||
continue
|
||||
@@ -341,9 +435,9 @@ class RPCServer:
|
||||
self._tasks.append(task)
|
||||
task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None)
|
||||
elif envelope.is_event():
|
||||
if not self._is_current_generation(envelope):
|
||||
if envelope.generation != expected_generation:
|
||||
logger.warning(
|
||||
f"忽略过期 generation 事件 {envelope.method}: {envelope.generation} != {self._runner_generation}"
|
||||
f"忽略过期 generation 事件 {envelope.method}: {envelope.generation} != {expected_generation}"
|
||||
)
|
||||
continue
|
||||
task = asyncio.create_task(self._handle_event(envelope))
|
||||
@@ -352,22 +446,24 @@ class RPCServer:
|
||||
|
||||
def _handle_response(self, envelope: Envelope) -> None:
|
||||
"""处理来自 Runner 的响应"""
|
||||
if not self._is_current_generation(envelope):
|
||||
pending = self._pending_requests.get(envelope.request_id)
|
||||
if pending is None:
|
||||
return
|
||||
|
||||
future, expected_generation = pending
|
||||
if envelope.generation != expected_generation:
|
||||
logger.warning(
|
||||
f"忽略过期 generation 响应 {envelope.method}: {envelope.generation} != {self._runner_generation}"
|
||||
f"忽略过期 generation 响应 {envelope.method}: {envelope.generation} != {expected_generation}"
|
||||
)
|
||||
return
|
||||
|
||||
future = self._pending_requests.pop(envelope.request_id, None)
|
||||
if future and not future.done():
|
||||
self._pending_requests.pop(envelope.request_id, None)
|
||||
if not future.done():
|
||||
if envelope.error:
|
||||
future.set_exception(RPCError.from_dict(envelope.error))
|
||||
else:
|
||||
future.set_result(envelope)
|
||||
|
||||
def _is_current_generation(self, envelope: Envelope) -> bool:
|
||||
return envelope.generation == self._runner_generation
|
||||
|
||||
async def _handle_request(self, envelope: Envelope, conn: Connection) -> None:
|
||||
"""处理来自 Runner 的请求(通常是能力调用 cap.*)"""
|
||||
handler = self._method_handlers.get(envelope.method)
|
||||
@@ -411,3 +507,26 @@ class RPCServer:
|
||||
return min_parts <= sdk_parts <= max_parts
|
||||
except (ValueError, AttributeError):
|
||||
return False
|
||||
|
||||
def _get_connection_for_generation(self, generation: int) -> Optional[Connection]:
|
||||
if generation == self._runner_generation:
|
||||
return self._connection
|
||||
if generation == self._staged_runner_generation:
|
||||
return self._staged_connection
|
||||
return None
|
||||
|
||||
def _fail_pending_requests(
|
||||
self,
|
||||
error_code: ErrorCode,
|
||||
message: str,
|
||||
generation: Optional[int] = None,
|
||||
) -> int:
|
||||
stale_count = 0
|
||||
for request_id, (future, request_generation) in list(self._pending_requests.items()):
|
||||
if generation is not None and request_generation != generation:
|
||||
continue
|
||||
if not future.done():
|
||||
future.set_exception(RPCError(error_code, message))
|
||||
stale_count += 1
|
||||
self._pending_requests.pop(request_id, None)
|
||||
return stale_count
|
||||
|
||||
@@ -135,6 +135,7 @@ class PluginSupervisor:
|
||||
|
||||
# 已注册的插件组件信息
|
||||
self._registered_plugins: Dict[str, RegisterComponentsPayload] = {}
|
||||
self._staged_registered_plugins: Dict[str, RegisterComponentsPayload] = {}
|
||||
|
||||
# 后台任务
|
||||
self._health_task: Optional[asyncio.Task] = None
|
||||
@@ -319,6 +320,10 @@ class PluginSupervisor:
|
||||
old_session_token = self._rpc_server.session_token
|
||||
expected_generation = self._rpc_server.runner_generation + 1
|
||||
|
||||
# 允许新 Runner 以 staged 方式接入,验证通过后再切换活跃连接
|
||||
self._rpc_server.begin_staged_takeover()
|
||||
self._staged_registered_plugins.clear()
|
||||
|
||||
# 重新生成 session token,防止被终止的旧 Runner 重连
|
||||
self._rpc_server.reset_session_token()
|
||||
|
||||
@@ -330,27 +335,35 @@ class PluginSupervisor:
|
||||
# 拉起新 Runner
|
||||
try:
|
||||
await self._spawn_runner()
|
||||
await self._wait_for_runner_generation(expected_generation, timeout_sec=self._runner_spawn_timeout)
|
||||
resp = await self._rpc_server.send_request("plugin.health", timeout_ms=5000)
|
||||
await self._wait_for_runner_generation(
|
||||
expected_generation,
|
||||
timeout_sec=self._runner_spawn_timeout,
|
||||
allow_staged=True,
|
||||
)
|
||||
resp = await self._rpc_server.send_request(
|
||||
"plugin.health",
|
||||
timeout_ms=5000,
|
||||
target_generation=expected_generation,
|
||||
)
|
||||
health = HealthPayload.model_validate(resp.payload)
|
||||
if not health.healthy:
|
||||
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "新 Runner 健康检查失败")
|
||||
await self._rpc_server.commit_staged_takeover()
|
||||
except Exception as e:
|
||||
logger.error(f"新 Runner 健康检查失败: {e},回滚")
|
||||
await self._terminate_process(self._runner_process, old_process)
|
||||
await self._rpc_server.rollback_staged_takeover()
|
||||
self._runner_process = old_process
|
||||
# 恢复旧 session token,使旧 Runner 的连接仍可正常工作
|
||||
self._rpc_server.restore_session_token(old_session_token)
|
||||
self._staged_registered_plugins.clear()
|
||||
self._registered_plugins = dict(old_registered_plugins)
|
||||
self._rebuild_runtime_state()
|
||||
return
|
||||
|
||||
# 新 Runner 健康且已完成组件注册,现在清理旧的幽灵组件
|
||||
# 只移除不再存在于新注册表中的旧插件组件
|
||||
for old_pid in list(old_registered_plugins.keys()):
|
||||
if old_pid not in self._registered_plugins:
|
||||
self._component_registry.remove_components_by_plugin(old_pid)
|
||||
self._policy.revoke_plugin(old_pid)
|
||||
self._runner_generation = self._rpc_server.runner_generation
|
||||
self._registered_plugins = dict(self._staged_registered_plugins)
|
||||
self._staged_registered_plugins.clear()
|
||||
self._rebuild_runtime_state()
|
||||
|
||||
# 关停旧 Runner
|
||||
if old_process and old_process.returncode is None:
|
||||
@@ -380,13 +393,22 @@ class PluginSupervisor:
|
||||
except Exception as e:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e))
|
||||
|
||||
if envelope.generation != self._rpc_server.runner_generation:
|
||||
active_generation = self._rpc_server.runner_generation
|
||||
staged_generation = self._rpc_server.staged_generation
|
||||
if envelope.generation not in {active_generation, staged_generation}:
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_GENERATION_MISMATCH.value,
|
||||
f"组件注册 generation 过期: {envelope.generation} != {self._rpc_server.runner_generation}",
|
||||
f"组件注册 generation 过期: {envelope.generation} 不在已知代际中",
|
||||
)
|
||||
|
||||
# 记录注册信息
|
||||
if envelope.generation == staged_generation and staged_generation != 0:
|
||||
self._staged_registered_plugins[reg.plugin_id] = reg
|
||||
logger.info(
|
||||
f"插件 {reg.plugin_id} v{reg.plugin_version} staged 注册成功,"
|
||||
f"组件数: {len(reg.components)}, 能力需求: {reg.capabilities_required}"
|
||||
)
|
||||
return envelope.make_response(payload={"accepted": True, "staged": True})
|
||||
|
||||
self._registered_plugins[reg.plugin_id] = reg
|
||||
|
||||
# 在策略引擎中注册插件
|
||||
@@ -396,7 +418,8 @@ class PluginSupervisor:
|
||||
capabilities=reg.capabilities_required or [],
|
||||
)
|
||||
|
||||
# 在 ComponentRegistry 中注册组件
|
||||
# 同 generation 下重新注册时,以本次声明为准,避免残留幽灵组件
|
||||
self._component_registry.remove_components_by_plugin(reg.plugin_id)
|
||||
self._component_registry.register_plugin_components(
|
||||
plugin_id=reg.plugin_id,
|
||||
components=[c.model_dump() for c in reg.components],
|
||||
@@ -518,10 +541,17 @@ class PluginSupervisor:
|
||||
except Exception as e:
|
||||
logger.error(f"健康检查异常: {e}")
|
||||
|
||||
async def _wait_for_runner_generation(self, expected_generation: int, timeout_sec: float) -> None:
|
||||
async def _wait_for_runner_generation(
|
||||
self,
|
||||
expected_generation: int,
|
||||
timeout_sec: float,
|
||||
allow_staged: bool = False,
|
||||
) -> None:
|
||||
"""等待指定代际的 Runner 完成连接。"""
|
||||
deadline = asyncio.get_running_loop().time() + timeout_sec
|
||||
while asyncio.get_running_loop().time() < deadline:
|
||||
if allow_staged and self._rpc_server.has_generation(expected_generation):
|
||||
return
|
||||
if self._rpc_server.is_connected and self._rpc_server.runner_generation >= expected_generation:
|
||||
self._runner_generation = self._rpc_server.runner_generation
|
||||
return
|
||||
@@ -533,6 +563,7 @@ class PluginSupervisor:
|
||||
self._component_registry.clear()
|
||||
self._policy.clear()
|
||||
self._registered_plugins.clear()
|
||||
self._staged_registered_plugins.clear()
|
||||
|
||||
def _rebuild_runtime_state(self) -> None:
|
||||
"""根据已记录的插件注册信息重建运行时状态。"""
|
||||
|
||||
@@ -430,8 +430,10 @@ class PluginRuntimeManager:
|
||||
"""
|
||||
from src.services import send_service as send_api
|
||||
|
||||
message_type: str = args.get("message_type", "")
|
||||
content = args.get("content", "")
|
||||
message_type: str = args.get("message_type", "") or args.get("custom_type", "")
|
||||
content = args.get("content")
|
||||
if content is None:
|
||||
content = args.get("data", "")
|
||||
stream_id: str = args.get("stream_id", "")
|
||||
if not message_type or not stream_id:
|
||||
return {"success": False, "error": "缺少必要参数 message_type 或 stream_id"}
|
||||
|
||||
@@ -76,6 +76,7 @@ class RPCClient:
|
||||
# 运行状态
|
||||
self._running = False
|
||||
self._recv_task: Optional[asyncio.Task] = None
|
||||
self._background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
@property
|
||||
def generation(self) -> int:
|
||||
@@ -144,6 +145,13 @@ class RPCClient:
|
||||
await self._recv_task
|
||||
self._recv_task = None
|
||||
|
||||
for task in list(self._background_tasks):
|
||||
task.cancel()
|
||||
if self._background_tasks:
|
||||
with contextlib.suppress(Exception):
|
||||
await asyncio.gather(*self._background_tasks, return_exceptions=True)
|
||||
self._background_tasks.clear()
|
||||
|
||||
# 取消所有 pending 请求
|
||||
for future in self._pending_requests.values():
|
||||
if not future.done():
|
||||
@@ -248,9 +256,9 @@ class RPCClient:
|
||||
if envelope.is_response():
|
||||
self._handle_response(envelope)
|
||||
elif envelope.is_request():
|
||||
asyncio.create_task(self._handle_request(envelope))
|
||||
self._track_background_task(asyncio.create_task(self._handle_request(envelope)))
|
||||
elif envelope.is_event():
|
||||
asyncio.create_task(self._handle_event(envelope))
|
||||
self._track_background_task(asyncio.create_task(self._handle_event(envelope)))
|
||||
|
||||
def _handle_response(self, envelope: Envelope) -> None:
|
||||
"""处理来自 Host 的响应"""
|
||||
@@ -290,3 +298,8 @@ class RPCClient:
|
||||
await handler(envelope)
|
||||
except Exception as e:
|
||||
logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True)
|
||||
|
||||
def _track_background_task(self, task: asyncio.Task) -> None:
|
||||
"""保持后台任务强引用,直到其完成或被取消。"""
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
Reference in New Issue
Block a user