feat: 增强组件注册和事件分发,添加会话令牌恢复功能,优化工作流执行超时处理
This commit is contained in:
@@ -93,14 +93,17 @@ class ComponentRegistry:
|
||||
comp = RegisteredComponent(name, component_type, plugin_id, metadata)
|
||||
if comp.full_name in self._components:
|
||||
logger.warning(f"组件 {comp.full_name} 已存在,覆盖")
|
||||
# 从 _by_plugin 列表中移除旧条目,防止幽灵组件堆积
|
||||
old_comp = self._components[comp.full_name]
|
||||
# 从 _by_plugin 列表中移除旧条目,防止幽灵组件堆积
|
||||
old_list = self._by_plugin.get(old_comp.plugin_id)
|
||||
if old_list is not None:
|
||||
try:
|
||||
old_list.remove(old_comp)
|
||||
except ValueError:
|
||||
pass
|
||||
# 从旧类型索引中移除,防止类型变更时幽灵残留
|
||||
if old_type_dict := self._by_type.get(old_comp.component_type):
|
||||
old_type_dict.pop(comp.full_name, None)
|
||||
|
||||
self._components[comp.full_name] = comp
|
||||
|
||||
|
||||
@@ -51,6 +51,8 @@ class EventDispatcher:
|
||||
self._registry: ComponentRegistry = registry
|
||||
self._result_history: Dict[str, List[EventResult]] = {}
|
||||
self._history_enabled: Set[str] = set()
|
||||
# 保持 fire-and-forget task 的强引用,防止被 GC 回收
|
||||
self._background_tasks: Set[asyncio.Task] = set()
|
||||
|
||||
def enable_history(self, event_type: str) -> None:
|
||||
self._history_enabled.add(event_type)
|
||||
@@ -87,7 +89,6 @@ class EventDispatcher:
|
||||
|
||||
should_continue = True
|
||||
modified_message: Optional[Dict[str, Any]] = None
|
||||
fire_and_forget_tasks: List[asyncio.Task] = []
|
||||
|
||||
for handler in handlers:
|
||||
intercept = handler.metadata.get("intercept_message", False)
|
||||
@@ -105,16 +106,12 @@ class EventDispatcher:
|
||||
if result and result.modified_message:
|
||||
modified_message = result.modified_message
|
||||
else:
|
||||
# 非阻塞
|
||||
# 非阻塞:保持实例级强引用,防止 task 被 GC 回收
|
||||
task = asyncio.create_task(
|
||||
self._invoke_handler(invoke_fn, handler, args, event_type)
|
||||
)
|
||||
fire_and_forget_tasks.append(task)
|
||||
|
||||
# 不等待 fire-and-forget 任务(但不丢弃引用以防 GC)
|
||||
if fire_and_forget_tasks:
|
||||
for t in fire_and_forget_tasks:
|
||||
t.add_done_callback(lambda _t: None)
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
return should_continue, modified_message
|
||||
|
||||
|
||||
@@ -78,6 +78,10 @@ class RPCServer:
|
||||
self._session_token = secrets.token_hex(32)
|
||||
return self._session_token
|
||||
|
||||
def restore_session_token(self, token: str) -> None:
|
||||
"""恢复指定的会话令牌(热重载回滚时调用)"""
|
||||
self._session_token = token
|
||||
|
||||
@property
|
||||
def runner_generation(self) -> int:
|
||||
return self._runner_generation
|
||||
|
||||
@@ -288,9 +288,10 @@ class PluginSupervisor:
|
||||
"""
|
||||
logger.info(f"开始热重载插件,原因: {reason}")
|
||||
|
||||
# 保存旧进程引用
|
||||
# 保存旧进程引用和旧 session token(回滚时需要恢复)
|
||||
old_process = self._runner_process
|
||||
old_registered_plugins = dict(self._registered_plugins)
|
||||
old_session_token = self._rpc_server.session_token
|
||||
expected_generation = self._rpc_server.runner_generation + 1
|
||||
|
||||
# 重新生成 session token,防止被终止的旧 Runner 重连
|
||||
@@ -313,6 +314,8 @@ class PluginSupervisor:
|
||||
logger.error(f"新 Runner 健康检查失败: {e},回滚")
|
||||
await self._terminate_process(self._runner_process, old_process)
|
||||
self._runner_process = old_process
|
||||
# 恢复旧 session token,使旧 Runner 的连接仍可正常工作
|
||||
self._rpc_server.restore_session_token(old_session_token)
|
||||
self._registered_plugins = dict(old_registered_plugins)
|
||||
self._rebuild_runtime_state()
|
||||
return
|
||||
|
||||
@@ -43,6 +43,9 @@ HOOK_CONTINUE = "continue"
|
||||
HOOK_SKIP_STAGE = "skip_stage"
|
||||
HOOK_ABORT = "abort"
|
||||
|
||||
# blocking hook 全局最大超时(秒):即使 hook 声明 timeout_ms=0 也不会无限等待
|
||||
GLOBAL_BLOCKING_TIMEOUT_SEC = 120.0
|
||||
|
||||
|
||||
class ModificationRecord:
|
||||
"""消息修改记录"""
|
||||
@@ -296,7 +299,8 @@ class WorkflowExecutor:
|
||||
(hook_result, modified_message, error_string_or_None)
|
||||
"""
|
||||
timeout_ms = step.metadata.get("timeout_ms", 0)
|
||||
timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else None
|
||||
# 使用 hook 声明的超时,但不超过全局安全阀
|
||||
timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else GLOBAL_BLOCKING_TIMEOUT_SEC
|
||||
step_key = f"{stage}:{step.full_name}"
|
||||
step_start = time.perf_counter()
|
||||
|
||||
@@ -307,7 +311,7 @@ class WorkflowExecutor:
|
||||
"message": message,
|
||||
"stage_outputs": ctx.stage_outputs,
|
||||
})
|
||||
resp = await asyncio.wait_for(coro, timeout=timeout_sec) if timeout_sec else await coro
|
||||
resp = await asyncio.wait_for(coro, timeout=timeout_sec)
|
||||
ctx.timings[step_key] = time.perf_counter() - step_start
|
||||
|
||||
hook_result = resp.get("hook_result", HOOK_CONTINUE)
|
||||
|
||||
@@ -32,7 +32,17 @@ logger = get_logger("plugin_runtime.runner.rpc_client")
|
||||
# RPC 方法处理器类型
|
||||
MethodHandler = Callable[[Envelope], Awaitable[Envelope]]
|
||||
|
||||
SDK_VERSION = "1.0.0"
|
||||
|
||||
def _get_sdk_version() -> str:
|
||||
"""从 maibot_sdk 包元数据中读取实际版本号,失败时回退到 1.0.0。"""
|
||||
try:
|
||||
from importlib.metadata import version
|
||||
return version("maibot-plugin-sdk")
|
||||
except Exception:
|
||||
return "1.0.0"
|
||||
|
||||
|
||||
SDK_VERSION = _get_sdk_version()
|
||||
|
||||
|
||||
class RPCClient:
|
||||
|
||||
Reference in New Issue
Block a user