Files
mai-bot/src/plugin_runtime/runner/runner_main.py

473 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Runner 主循环
作为独立子进程运行,负责:
1. 从环境变量读取 IPC 地址和会话令牌
2. 连接 Host 并完成握手
3. 加载所有插件
4. 注册组件到 Host
5. 处理 Host 的调用请求
6. 转发插件的能力调用到 Host
"""
import logging as stdlib_logging
from typing import List, Optional
import asyncio
import contextlib
import inspect
import os
import signal
import sys
import time
from typing import Any
from src.common.logger import get_logger, initialize_logging
from src.plugin_runtime.protocol.envelope import (
ComponentDeclaration,
Envelope,
HealthPayload,
InvokePayload,
InvokeResultPayload,
RegisterComponentsPayload,
)
from src.plugin_runtime.protocol.errors import ErrorCode
from src.plugin_runtime.runner.log_handler import RunnerIPCLogHandler
from src.plugin_runtime.runner.plugin_loader import PluginLoader, PluginMeta
from src.plugin_runtime.runner.rpc_client import RPCClient
# Module-level logger shared by everything in this file.
logger = get_logger("plugin_runtime.runner.main")
class PluginRunner:
    """Plugin runner that lives in a dedicated subprocess.

    The runner connects to the Host over IPC, loads plugins, registers their
    components with the Host and then serves Host-initiated requests until a
    shutdown is requested.
    """

    def __init__(
        self,
        host_address: str,
        session_token: str,
        plugin_dirs: List[str],
    ) -> None:
        """Store connection parameters and create the RPC client and loader.

        Args:
            host_address: IPC address of the Host process.
            session_token: Session token handed to the RPC client.
            plugin_dirs: Directories passed to the plugin loader.
        """
        self._host_address: str = host_address
        self._session_token: str = session_token
        self._plugin_dirs: list[str] = plugin_dirs
        self._rpc_client: RPCClient = RPCClient(host_address, session_token)
        self._loader: PluginLoader = PluginLoader()
        self._start_time: float = time.monotonic()
        self._shutting_down: bool = False
        # IPC log handler; installed only after a successful handshake so
        # that stdlib logging records can be forwarded to the Host.
        self._log_handler: Optional[RunnerIPCLogHandler] = None

    async def run(self) -> None:
        """Run the complete runner lifecycle.

        Connects to the Host, installs the IPC log handler, registers the RPC
        method handlers, loads the plugins, injects their context, calls
        their ``on_load`` hooks, registers their components and then idles
        until ``_shutting_down`` becomes true. The log handler is removed and
        the RPC connection is closed on the way out, even when start-up fails
        part-way through.
        """
        # 1. Connect to the Host and perform the handshake.
        logger.info(f"Runner 启动,连接 Host: {self._host_address}")
        ok = await self._rpc_client.connect_and_handshake()
        if not ok:
            logger.error("握手失败,退出")
            return

        try:
            # 2. Take over runner-side logging as soon as the link is up.
            self._install_log_handler()

            # 3. Register the RPC method handlers.
            self._register_handlers()

            # 4. Discover and load plugins.
            plugins = self._loader.discover_and_load(self._plugin_dirs)
            logger.info(f"已加载 {len(plugins)} 个插件")

            # 5. Inject the PluginContext and call the on_load lifecycle hook.
            for meta in plugins:
                instance = meta.instance
                self._inject_context(meta.plugin_id, instance)
                if hasattr(instance, "on_load"):
                    try:
                        ret = instance.on_load()
                        # on_load may be synchronous or asynchronous.
                        if inspect.isawaitable(ret):
                            await ret
                    except Exception as e:
                        logger.error(f"插件 {meta.plugin_id} on_load 失败: {e}", exc_info=True)

            # 6. Register every plugin's components with the Host.
            for meta in plugins:
                await self._register_plugin(meta)

            # 7. Idle until a shutdown is requested. The flag is polled
            #    because it is set from RPC handlers and signal handlers.
            with contextlib.suppress(asyncio.CancelledError):
                while not self._shutting_down:
                    await asyncio.sleep(1.0)
        finally:
            # 8. Remove the IPC log handler (flushing what is buffered)
            #    before the connection it relies on is closed.
            logger.info("Runner 开始关停")
            await self._uninstall_log_handler()
            await self._rpc_client.disconnect()
            logger.info("Runner 已退出")

    def _install_log_handler(self) -> None:
        """Attach a ``RunnerIPCLogHandler`` to the root stdlib logger.

        The handler is started with the RPC client and the running event loop
        so that log records emitted in this process are forwarded over IPC.
        Must be called from within a running event loop.
        """
        loop = asyncio.get_running_loop()
        handler = RunnerIPCLogHandler()
        handler.start(self._rpc_client, loop)
        stdlib_logging.root.addHandler(handler)
        self._log_handler = handler
        logger.debug("RunnerIPCLogHandler \u5df2\u5b89\u88c3\uff0c\u63d2\u4ef6\u65e5\u5fd7\u5c06\u901a\u8fc7 IPC \u8f6c\u53d1\u5230\u4e3b\u8fdb\u7a0b")

    async def _uninstall_log_handler(self) -> None:
        """Detach the IPC log handler from the root logger and stop it.

        Has to run before ``disconnect()`` so that the last batch of records
        can still be sent. Does nothing when no handler is installed.
        """
        if self._log_handler is None:
            return
        stdlib_logging.root.removeHandler(self._log_handler)
        await self._log_handler.stop()
        self._log_handler = None
        logger.debug("RunnerIPCLogHandler \u5df2\u5378\u8f7d")

    def _inject_context(self, plugin_id: str, instance: object) -> None:
        """Create a ``PluginContext`` and hand it to the plugin instance.

        Only instances exposing ``_set_context`` receive a context. When
        ``maibot_sdk`` cannot be imported a warning is logged and the plugin
        is left without a context.

        Args:
            plugin_id: Identifier of the plugin that owns the context.
            instance: The loaded plugin instance.
        """
        if not hasattr(instance, "_set_context"):
            return
        try:
            from maibot_sdk.context import PluginContext
        except ImportError:
            logger.warning(f"maibot_sdk 不可用,无法为插件 {plugin_id} 创建 PluginContext")
            return

        rpc_client = self._rpc_client
        owner_plugin_id = plugin_id

        async def _rpc_call(method: str, plugin_id: str = "", payload: Optional[dict] = None) -> Any:
            """Bridge a capability call from the context to ``RPCClient.send_request``."""
            resp = await rpc_client.send_request(
                method=method,
                # Fall back to the owning plugin so the request never goes
                # out with an empty plugin ID.
                plugin_id=plugin_id or owner_plugin_id,
                payload=payload or {},
            )
            # Unwrap the business result from the response envelope.
            if resp.error:
                raise RuntimeError(resp.error.get("message", "能力调用失败"))
            return resp.payload.get("result")

        ctx = PluginContext(plugin_id=plugin_id, rpc_call=_rpc_call)
        instance._set_context(ctx)
        logger.debug(f"已为插件 {plugin_id} 注入 PluginContext")

    def _register_handlers(self) -> None:
        """Register the RPC methods the Host may call on this runner."""
        register = self._rpc_client.register_method
        register("plugin.invoke_command", self._handle_invoke)
        register("plugin.invoke_action", self._handle_invoke)
        register("plugin.invoke_tool", self._handle_invoke)
        register("plugin.emit_event", self._handle_invoke)
        register("plugin.invoke_workflow_step", self._handle_workflow_step)
        register("plugin.health", self._handle_health)
        register("plugin.prepare_shutdown", self._handle_prepare_shutdown)
        register("plugin.shutdown", self._handle_shutdown)
        register("plugin.config_updated", self._handle_config_updated)

    async def _register_plugin(self, meta: PluginMeta) -> None:
        """Register one plugin and its components with the Host.

        Failures (a raising ``get_components``, a transport error or an error
        response from the Host) are logged and do not propagate, so one bad
        plugin cannot stop the remaining plugins from being registered.

        Args:
            meta: Metadata of the loaded plugin.
        """
        try:
            # Collect the component declarations. SDK plugins are expected
            # to expose them through get_components().
            components: List[ComponentDeclaration] = []
            instance = meta.instance
            if hasattr(instance, "get_components"):
                components.extend(
                    ComponentDeclaration(
                        name=comp_info.get("name", ""),
                        component_type=comp_info.get("type", ""),
                        plugin_id=meta.plugin_id,
                        metadata=comp_info.get("metadata", {}),
                    )
                    for comp_info in instance.get_components()
                )
            reg_payload = RegisterComponentsPayload(
                plugin_id=meta.plugin_id,
                plugin_version=meta.version,
                components=components,
                capabilities_required=meta.capabilities_required,
            )
            resp = await self._rpc_client.send_request(
                "plugin.register_components",
                plugin_id=meta.plugin_id,
                payload=reg_payload.model_dump(),
                timeout_ms=10000,
            )
        except Exception as e:
            logger.error(f"插件 {meta.plugin_id} 注册失败: {e}")
            return

        # An error envelope means the Host rejected the registration.
        if resp.error:
            reason = resp.error.get("message", resp.error)
            logger.error(f"插件 {meta.plugin_id} 注册失败: {reason}")
            return
        logger.info(f"插件 {meta.plugin_id} 注册完成")

    @staticmethod
    def _find_handler(instance: Any, component_name: str) -> Any:
        """Return the callable serving *component_name* on *instance*, or ``None``.

        ``handle_<name>`` is preferred; the bare ``<name>`` attribute is the
        fallback. A non-callable attribute is treated as missing.
        """
        handler = getattr(instance, f"handle_{component_name}", None)
        if handler is None:
            handler = getattr(instance, component_name, None)
        return handler if callable(handler) else None

    @staticmethod
    async def _call_handler(handler: Any, args: dict) -> Any:
        """Call *handler* with keyword *args*, awaiting the result when needed.

        Inspecting the returned value (rather than the function) also covers
        synchronous wrappers that return a coroutine.
        """
        outcome = handler(**args)
        if inspect.isawaitable(outcome):
            outcome = await outcome
        return outcome

    @staticmethod
    def _normalize_workflow_result(raw: Any) -> dict:
        """Normalize a workflow step's return value into a result dict.

        A string becomes ``{"hook_result": <string>}``; a dict is copied and
        gets ``hook_result="continue"`` when the key is absent; anything else
        becomes ``{"hook_result": "continue"}``.
        """
        if isinstance(raw, str):
            return {"hook_result": raw}
        if isinstance(raw, dict):
            # Copy so that the plugin's own dict is not mutated.
            result = dict(raw)
            result.setdefault("hook_result", "continue")
            return result
        return {"hook_result": "continue"}

    async def _handle_invoke(self, envelope: Envelope) -> Envelope:
        """Handle a component invocation request from the Host.

        Args:
            envelope: Request envelope whose payload validates as ``InvokePayload``.

        Returns:
            An error response for a bad payload, an unknown plugin or a
            missing component; otherwise a response wrapping an
            ``InvokeResultPayload`` (``success=False`` with the exception
            text when the component raised).
        """
        try:
            invoke = InvokePayload.model_validate(envelope.payload)
        except Exception as e:
            return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e))

        plugin_id = envelope.plugin_id
        meta = self._loader.get_plugin(plugin_id)
        if meta is None:
            return envelope.make_error_response(
                ErrorCode.E_PLUGIN_NOT_FOUND.value,
                f"插件 {plugin_id} 未加载",
            )

        instance = meta.instance
        component_name = invoke.component_name
        handler_method = self._find_handler(instance, component_name)

        # Fallback: instances exposing invoke_component (legacy adapters)
        # dispatch every component through that single entry point.
        if handler_method is None and hasattr(instance, "invoke_component"):
            try:
                result = await instance.invoke_component(component_name, **invoke.args)
                resp_payload = InvokeResultPayload(success=True, result=result)
                return envelope.make_response(payload=resp_payload.model_dump())
            except Exception as e:
                logger.error(f"插件 {plugin_id} 组件 {component_name} (legacy) 执行异常: {e}", exc_info=True)
                resp_payload = InvokeResultPayload(success=False, result=str(e))
                return envelope.make_response(payload=resp_payload.model_dump())

        if handler_method is None:
            return envelope.make_error_response(
                ErrorCode.E_METHOD_NOT_ALLOWED.value,
                f"插件 {plugin_id} 无组件: {component_name}",
            )

        try:
            result = await self._call_handler(handler_method, invoke.args)
            resp_payload = InvokeResultPayload(success=True, result=result)
            return envelope.make_response(payload=resp_payload.model_dump())
        except Exception as e:
            logger.error(f"插件 {plugin_id} 组件 {component_name} 执行异常: {e}", exc_info=True)
            resp_payload = InvokeResultPayload(success=False, result=str(e))
            return envelope.make_response(payload=resp_payload.model_dump())

    async def _handle_workflow_step(self, envelope: Envelope) -> Envelope:
        """Handle a workflow step invocation request.

        Unlike the generic invoke path, the handler's return value is
        normalized into a dict that always carries a ``hook_result`` key.

        Args:
            envelope: Request envelope whose payload validates as ``InvokePayload``.

        Returns:
            An error response for a bad payload, an unknown plugin or a
            missing component; otherwise a response wrapping an
            ``InvokeResultPayload``.
        """
        try:
            invoke = InvokePayload.model_validate(envelope.payload)
        except Exception as e:
            return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e))

        plugin_id = envelope.plugin_id
        meta = self._loader.get_plugin(plugin_id)
        if meta is None:
            return envelope.make_error_response(
                ErrorCode.E_PLUGIN_NOT_FOUND.value,
                f"插件 {plugin_id} 未加载",
            )

        component_name = invoke.component_name
        handler_method = self._find_handler(meta.instance, component_name)
        if handler_method is None:
            return envelope.make_error_response(
                ErrorCode.E_METHOD_NOT_ALLOWED.value,
                f"插件 {plugin_id} 无组件: {component_name}",
            )

        try:
            raw = await self._call_handler(handler_method, invoke.args)
            result = self._normalize_workflow_result(raw)
            resp_payload = InvokeResultPayload(success=True, result=result)
            return envelope.make_response(payload=resp_payload.model_dump())
        except Exception as e:
            logger.error(f"插件 {plugin_id} workflow_step {component_name} 执行异常: {e}", exc_info=True)
            resp_payload = InvokeResultPayload(success=False, result=str(e))
            return envelope.make_response(payload=resp_payload.model_dump())

    async def _handle_health(self, envelope: Envelope) -> Envelope:
        """Answer a health check with the loaded plugins and the uptime in ms."""
        uptime_ms = int((time.monotonic() - self._start_time) * 1000)
        health = HealthPayload(
            healthy=True,
            loaded_plugins=self._loader.list_plugins(),
            uptime_ms=uptime_ms,
        )
        return envelope.make_response(payload=health.model_dump())

    async def _handle_prepare_shutdown(self, envelope: Envelope) -> Envelope:
        """Acknowledge a prepare-shutdown notification; only logs it."""
        logger.info("收到 prepare_shutdown 信号")
        return envelope.make_response(payload={"acknowledged": True})

    async def _handle_shutdown(self, envelope: Envelope) -> Envelope:
        """Call every plugin's ``on_unload`` hook, then request the main loop to stop.

        A failing hook is logged and does not prevent the remaining plugins
        from being unloaded.
        """
        logger.info("收到 shutdown 信号,开始调用 on_unload")
        for plugin_id in self._loader.list_plugins():
            meta = self._loader.get_plugin(plugin_id)
            if meta and hasattr(meta.instance, "on_unload"):
                try:
                    ret = meta.instance.on_unload()
                    # on_unload may be synchronous or asynchronous.
                    if inspect.isawaitable(ret):
                        await ret
                except Exception as e:
                    logger.error(f"插件 {plugin_id} on_unload 失败: {e}", exc_info=True)
        self._shutting_down = True
        return envelope.make_response(payload={"acknowledged": True})

    async def _handle_config_updated(self, envelope: Envelope) -> Envelope:
        """Forward a configuration update to the plugin's ``on_config_update`` hook.

        Returns:
            An ``E_UNKNOWN`` error response when the hook raises; otherwise an
            acknowledgement (also when the plugin is unknown or has no hook).
        """
        plugin_id = envelope.plugin_id
        meta = self._loader.get_plugin(plugin_id)
        if meta and hasattr(meta.instance, "on_config_update"):
            try:
                config_data = envelope.payload.get("config_data", {})
                config_version = envelope.payload.get("config_version", "")
                ret = meta.instance.on_config_update(config_data, config_version)
                # Support both synchronous and asynchronous implementations.
                if inspect.isawaitable(ret):
                    await ret
            except Exception as e:
                logger.error(f"插件 {plugin_id} 配置更新失败: {e}", exc_info=True)
                return envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
        return envelope.make_response(payload={"acknowledged": True})

    def request_capability(self) -> RPCClient:
        """Return the RPC client (used by the SDK to issue capability calls)."""
        return self._rpc_client
# ─── sys.path 隔离 ────────────────────────────────────────
def _isolate_sys_path(plugin_dirs: List[str]) -> None:
    """Restrict what the runner subprocess is able to import.

    ``sys.path`` is filtered in place so that only standard-library
    locations, ``site-packages``/``dist-packages`` directories, the given
    plugin directories and the project root survive; entries are only ever
    removed, never added. An import blocker is then placed at the front of
    ``sys.meta_path``; it rejects every ``src.*`` module except
    ``src.plugin_runtime`` and ``src.common`` (and their submodules).

    Args:
        plugin_dirs: Plugin directories that must stay on ``sys.path`` when
            they are already present there.
    """
    import importlib.abc
    import sysconfig

    def _is_within(path: str, root: str) -> bool:
        # Compare on a path-separator boundary so that a sibling directory
        # sharing a textual prefix ("python3.11-extra") is not accepted.
        return path == root or path.startswith(root.rstrip(os.sep) + os.sep)

    # Interpreter-owned roots: the standard library plus site-packages.
    interpreter_roots = set()
    for key in ("stdlib", "platstdlib", "purelib", "platlib"):
        if path := sysconfig.get_path(key):
            interpreter_roots.add(os.path.normpath(path))

    # Everything in this set is a normalized path.
    allowed = set()
    for entry in sys.path:
        norm = os.path.normpath(entry)
        if any(_is_within(norm, root) for root in interpreter_roots):
            allowed.add(norm)
        # Third-party libraries and the SDK live here.
        if "site-packages" in norm or "dist-packages" in norm:
            allowed.add(norm)

    # Plugin directories.
    allowed.update(os.path.normpath(d) for d in plugin_dirs)

    # Project root, so that src.plugin_runtime / src.common stay importable.
    runtime_root = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
    allowed.add(runtime_root)

    # Compare normalized forms on both sides so that differently spelled but
    # equivalent entries are matched consistently.
    sys.path[:] = [p for p in sys.path if os.path.normpath(p) in allowed]

    class _PluginImportBlocker(importlib.abc.MetaPathFinder):
        """Meta path finder that refuses main-program modules.

        Only ``src.plugin_runtime`` and ``src.common`` are let through; any
        other ``src.*`` module is rejected with ``ImportError``.
        """

        _ALLOWED_SRC_PREFIXES = ("src.plugin_runtime", "src.common")

        def find_spec(self, fullname, path=None, target=None):
            # find_spec is the hook the import system actually consults;
            # Python 3.12+ ignores the legacy find_module/load_module pair,
            # so the block has to be enforced here.
            if self._should_block(fullname):
                raise ImportError(
                    f"Runner 子进程不允许导入主程序模块: {fullname}",
                    name=fullname,
                )
            return None

        def find_module(self, fullname, path=None):
            # Legacy protocol, kept for callers that still use it directly.
            if self._should_block(fullname):
                return self
            return None

        def load_module(self, fullname):
            raise ImportError(
                f"Runner 子进程不允许导入主程序模块: {fullname}"
            )

        def _should_block(self, fullname: str) -> bool:
            # Anything that is not a submodule of "src" (including "src"
            # itself) is outside the blocker's scope.
            if not fullname.startswith("src."):
                return False
            # Whitelisted packages and their submodules pass through.
            for prefix in self._ALLOWED_SRC_PREFIXES:
                if fullname == prefix or fullname.startswith(prefix + "."):
                    return False
            return True

    sys.meta_path.insert(0, _PluginImportBlocker())
# ─── 进程入口 ──────────────────────────────────────────────
async def _async_main() -> None:
    """Asynchronous entry point of the runner process.

    Reads ``MAIBOT_IPC_ADDRESS``, ``MAIBOT_SESSION_TOKEN`` and
    ``MAIBOT_PLUGIN_DIRS`` (an ``os.pathsep``-separated list) from the
    environment, isolates ``sys.path``, installs SIGTERM/SIGINT handlers that
    request a shutdown and then runs the ``PluginRunner``.

    Exits the process with status 1 when the address or the token is missing.
    """
    host_address = os.environ.get("MAIBOT_IPC_ADDRESS", "")
    session_token = os.environ.get("MAIBOT_SESSION_TOKEN", "")
    plugin_dirs_str = os.environ.get("MAIBOT_PLUGIN_DIRS", "")
    if not host_address or not session_token:
        logger.error("缺少必要的环境变量: MAIBOT_IPC_ADDRESS, MAIBOT_SESSION_TOKEN")
        sys.exit(1)
    plugin_dirs = [d for d in plugin_dirs_str.split(os.pathsep) if d]

    # sys.path isolation: keep only the standard library, SDK packages and
    # the plugin directories.
    _isolate_sys_path(plugin_dirs)

    runner = PluginRunner(host_address, session_token, plugin_dirs)

    def _request_shutdown(*_args: object) -> None:
        # Accepts no arguments (asyncio callback) or ``(signum, frame)``
        # (signal.signal handler); the runner's main loop polls this flag.
        setattr(runner, "_shutting_down", True)

    # We are inside a coroutine, so the running loop is the one to use.
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGTERM, signal.SIGINT):
        try:
            loop.add_signal_handler(sig, _request_shutdown)
        except NotImplementedError:
            # add_signal_handler is only available on Unix event loops;
            # fall back to a plain signal handler elsewhere (e.g. Windows).
            signal.signal(sig, _request_shutdown)

    await runner.run()
def main() -> None:
    """Process entry point: ``python -m src.plugin_runtime.runner.runner_main``.

    Initializes logging in non-verbose mode and drives ``_async_main`` to
    completion on a fresh event loop.
    """
    initialize_logging(verbose=False)
    entry_coro = _async_main()
    asyncio.run(entry_coro)
# Start the runner only when this module is executed directly
# (for example via ``python -m``), not when it is imported.
if __name__ == "__main__":
    main()