feat: 增强组件注册和事件分发,添加会话令牌恢复功能,优化工作流执行超时处理

This commit is contained in:
DrSmoothl
2026-03-12 23:53:15 +08:00
parent 4b7ee3923c
commit c620040191
6 changed files with 34 additions and 13 deletions

View File

@@ -93,14 +93,17 @@ class ComponentRegistry:
comp = RegisteredComponent(name, component_type, plugin_id, metadata)
if comp.full_name in self._components:
logger.warning(f"组件 {comp.full_name} 已存在,覆盖")
# 从 _by_plugin 列表中移除旧条目,防止幽灵组件堆积
old_comp = self._components[comp.full_name]
# 从 _by_plugin 列表中移除旧条目,防止幽灵组件堆积
old_list = self._by_plugin.get(old_comp.plugin_id)
if old_list is not None:
try:
old_list.remove(old_comp)
except ValueError:
pass
# 从旧类型索引中移除,防止类型变更时幽灵残留
if old_type_dict := self._by_type.get(old_comp.component_type):
old_type_dict.pop(comp.full_name, None)
self._components[comp.full_name] = comp

View File

@@ -51,6 +51,8 @@ class EventDispatcher:
self._registry: ComponentRegistry = registry
self._result_history: Dict[str, List[EventResult]] = {}
self._history_enabled: Set[str] = set()
# 保持 fire-and-forget task 的强引用,防止被 GC 回收
self._background_tasks: Set[asyncio.Task] = set()
def enable_history(self, event_type: str) -> None:
self._history_enabled.add(event_type)
@@ -87,7 +89,6 @@ class EventDispatcher:
should_continue = True
modified_message: Optional[Dict[str, Any]] = None
fire_and_forget_tasks: List[asyncio.Task] = []
for handler in handlers:
intercept = handler.metadata.get("intercept_message", False)
@@ -105,16 +106,12 @@ class EventDispatcher:
if result and result.modified_message:
modified_message = result.modified_message
else:
# 非阻塞
# 非阻塞:保持实例级强引用,防止 task 被 GC 回收
task = asyncio.create_task(
self._invoke_handler(invoke_fn, handler, args, event_type)
)
fire_and_forget_tasks.append(task)
# 不等待 fire-and-forget 任务(但不丢弃引用以防 GC
if fire_and_forget_tasks:
for t in fire_and_forget_tasks:
t.add_done_callback(lambda _t: None)
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
return should_continue, modified_message

View File

@@ -78,6 +78,10 @@ class RPCServer:
self._session_token = secrets.token_hex(32)
return self._session_token
def restore_session_token(self, token: str) -> None:
"""恢复指定的会话令牌(热重载回滚时调用)"""
self._session_token = token
@property
def runner_generation(self) -> int:
return self._runner_generation

View File

@@ -288,9 +288,10 @@ class PluginSupervisor:
"""
logger.info(f"开始热重载插件,原因: {reason}")
# 保存旧进程引用
# 保存旧进程引用和旧 session token回滚时需要恢复
old_process = self._runner_process
old_registered_plugins = dict(self._registered_plugins)
old_session_token = self._rpc_server.session_token
expected_generation = self._rpc_server.runner_generation + 1
# 重新生成 session token防止被终止的旧 Runner 重连
@@ -313,6 +314,8 @@ class PluginSupervisor:
logger.error(f"新 Runner 健康检查失败: {e},回滚")
await self._terminate_process(self._runner_process, old_process)
self._runner_process = old_process
# 恢复旧 session token使旧 Runner 的连接仍可正常工作
self._rpc_server.restore_session_token(old_session_token)
self._registered_plugins = dict(old_registered_plugins)
self._rebuild_runtime_state()
return

View File

@@ -43,6 +43,9 @@ HOOK_CONTINUE = "continue"
HOOK_SKIP_STAGE = "skip_stage"
HOOK_ABORT = "abort"
# blocking hook 全局最大超时(秒):即使 hook 声明 timeout_ms=0 也不会无限等待
GLOBAL_BLOCKING_TIMEOUT_SEC = 120.0
class ModificationRecord:
"""消息修改记录"""
@@ -296,7 +299,8 @@ class WorkflowExecutor:
(hook_result, modified_message, error_string_or_None)
"""
timeout_ms = step.metadata.get("timeout_ms", 0)
timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else None
# 使用 hook 声明的超时,但不超过全局安全阀
timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else GLOBAL_BLOCKING_TIMEOUT_SEC
step_key = f"{stage}:{step.full_name}"
step_start = time.perf_counter()
@@ -307,7 +311,7 @@ class WorkflowExecutor:
"message": message,
"stage_outputs": ctx.stage_outputs,
})
resp = await asyncio.wait_for(coro, timeout=timeout_sec) if timeout_sec else await coro
resp = await asyncio.wait_for(coro, timeout=timeout_sec)
ctx.timings[step_key] = time.perf_counter() - step_start
hook_result = resp.get("hook_result", HOOK_CONTINUE)

View File

@@ -32,7 +32,17 @@ logger = get_logger("plugin_runtime.runner.rpc_client")
# RPC 方法处理器类型
MethodHandler = Callable[[Envelope], Awaitable[Envelope]]
SDK_VERSION = "1.0.0"
def _get_sdk_version() -> str:
"""从 maibot_sdk 包元数据中读取实际版本号,失败时回退到 1.0.0。"""
try:
from importlib.metadata import version
return version("maibot-plugin-sdk")
except Exception:
return "1.0.0"
SDK_VERSION = _get_sdk_version()
class RPCClient: